我使用以下代码将JAX 2D数组的特定行设置为使用JAX数组的特定值:

zeros_array = jnp.zeros((3, 8))
value = jnp.array([1,2,3,4])
value_2 = jnp.array([1])
value_3 = jnp.array([1,2])
values = jnp.array([value,value_2,value_3])
zeros_array = zeros_array.at[0].set(values)

但是,我收到以下错误:

ValueError: All input arrays must have the same shape.

将JNP修改为NP(NumPy)时,错误消失.有什么方法可以解决这个错误吗?我知道一种解决方法是使用at[0,1].set(),at[0,2:n].set()设置2D数组中的每个单独array.

推荐答案

您所考虑的是一个"粗糙数组",不,在JAX中目前没有任何方法可以做到这一点.在较旧版本的NumPy中,这可以通过返回dtype为object的数组来实现,但在较新版本的NumPy中,这会导致错误,因为使用object数组通常不方便且效率低下(例如,如果更新存储在对象数组中,则无法有效地执行与最后一行中的索引更新操作相同的操作).

根据您的用例,可以在JAX和NumPy中使用几种解决方法,包括将数组的行存储为列表,或使用填充的2D数组表示.

我还要指出的是,JAX团队正在探索对散乱数组的本机支持(例如,参见https://github.com/google/jax/pull/16541),但它仍远未达到普遍使用的程度.

Python相关问答推荐

将两个收件箱相连导致索引的列标题消失

在pandas DataFrame上运行apply()时如何访问DateTime索引?

如何确保Flask应用程序管理面板中的项目具有单击删除功能?

在Python中根据id填写年份系列

替换字符串中的点/逗号,以便可以将其转换为浮动

在Python中管理多个OpenGVBO和VAO实例

在Windows上启动新Python项目的正确步骤顺序

是什么导致对Python脚本的jQuery Ajax调用引发500错误?

模型序列化器中未调用现场验证器

更改Seaborn条形图中的x轴日期时间限制

如何才能知道Python中2列表中的巧合.顺序很重要,但当1个失败时,其余的不应该失败或是0巧合

比较两个数据帧并并排附加结果(获取性能警告)

_repr_html_实现自定义__getattr_时未显示

根据二元组列表在pandas中创建新列

如何从.cgi网站刮一张表到rame?

django禁止直接分配到多对多集合的前端.使用user.set()

关于Python异步编程的问题和使用await/await def关键字

当点击tkinter菜单而不是菜单选项时,如何执行命令?

如何在BeautifulSoup/CSS Select 器中处理regex?

基于多个数组的多个条件将值添加到numpy数组