我使用以下代码将JAX 2D数组的特定行设置为使用JAX数组的特定值:
zeros_array = jnp.zeros((3, 8))
value = jnp.array([1,2,3,4])
value_2 = jnp.array([1])
value_3 = jnp.array([1,2])
values = jnp.array([value,value_2,value_3])
zeros_array = zeros_array.at[0].set(values)
但是,我收到以下错误:
ValueError: All input arrays must have the same shape.
将JNP修改为NP(NumPy)时,错误消失.有什么方法可以解决这个错误吗?我知道一种解决方法是使用at[0,1].set(),at[0,2:n].set()设置2D数组中的每个单独array.