为了说明我面临的问题,下面是一些示例代码:

a = np.round(np.random.rand(10, 15))
counta = np.count_nonzero(a, axis=-1)
print(counta)

A = np.einsum('im,mj->ijm', a, a.T)
countA = np.count_nonzero(A, axis=-1)
print(countA)

它创建一个二维数组,并沿最后一个轴计算其非零元素.然后,它创建一个3D数组,其中的非零元素沿最后一个轴再次计数.

我的问题是,我的数组a太大了,我可以执行第一步,但不能执行第二步,因为A数组会占用太多内存.

Is there any way to still get 100? That is to count the zeros in A along a given axis without actually creating the array?

推荐答案

我认为您可以简单地使用矩阵乘法(点积)来获得结果,而不需要生成庞大的3D数组A:

a = np.round(np.random.rand(10, 15)).astype(int)
counta = np.count_nonzero(a, axis=-1)

A = np.einsum('im,mj->ijm', a, a.T)
countA = np.count_nonzero(A, axis=-1)

assert np.all(countA == (a @ a.T))

这也要快得多:

a = np.round(np.random.rand(1000, 1500)).astype(int)

%timeit np.count_nonzero(np.einsum('im,mj->ijm', a, a.T), axis=-1)
3.94 s ± 38.9 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

%timeit a @ a.T
558 ms ± 6.74 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

还要注意,第一步是多余的,第二步是多余的:

assert np.all(counta == np.diag(countA))

Python相关问答推荐

在Python中,如何初始化集合列表脚本的输出

如何循环循环的每个元素并过滤掉Python rame中的条件

如何在telegram 机器人中发送音频?

Tkinter -控制调色板的位置

"Discord机器人中缺少所需的位置参数ctx

在编写要Excel的数据透视框架时修复标题行

如何使用Python中的clinicalTrials.gov API获取完整结果?

通过仅导入pandas来在for循环中进行多情节

如何计算列表列行之间的公共元素

rame中不兼容的d类型

如何标记Spacy中不包含特定符号的单词?

从numpy数组和参数创建收件箱

如何让Flask 中的请求标签发挥作用

如何在给定的条件下使numpy数组的计算速度最快?

将9个3x3矩阵按特定顺序排列成9x9矩阵

递归访问嵌套字典中的元素值

在单个对象中解析多个Python数据帧

让函数调用方程

在Python中调用变量(特别是Tkinter)

人口全部乱序 - Python—Matplotlib—映射