我有一个[n_rows, n_cols, n_channels]大小的稀疏array.在我的代码中,我有一个循环,其中数组不断更新和裁剪:

def update(arr, row_idx, col_idx, ch_idx):
    arr[row_idx, col_idx, ch_idx] += 1
    arr[arr > 10] = 10

arr = np.array(n_rows, n_cols, n_channels)
while True:
    update(arr, 0, 1, 2)

为了优化我的代码,我可以使用带有索引列表的缓存,并每N次迭代更新一次数组:

def update(arr, rows_list, cols_list, ch_list):
        arr[rows_list, cols_list, ch_list] += 1
        arr[arr > 10] = 10

arr = np.array(n_rows, n_cols, n_channels)
cache_length = 3
rows_list, cols_list, ch_list = [], [], []
while True:
    rows_list.append(something1)
    cols_list.append(something2)
    ch_list.append(something3)
    if len(row_list) == cache_length:
        update(arr, rows_list, cols_list, ch_list)
        rows_list, cols_list, ch_list = [], [], []

这可以节省时间,但可能会发生缓存多次包含相同的数组索引,例如:

# arr[0, 0, 6] should be updated twice
update(arr, [0, 0, 2], [3, 3, 5], [6, 6, 6])

我如何更改我的代码才能使此优化起作用?

推荐答案

您可以使用numpy.unique进行聚合:

def update(arr, row_idx, col_idx, ch_idx):
    idx, cnt = np.unique([row_idx, col_idx, ch_idx],
                         return_counts=True, axis=1)
    arr[tuple(idx)] += cnt
    arr[arr > 10] = 10

此外,您还可以通过仅剪裁更新值(而不是整个数组)来进一步优化:

def update(arr, row_idx, col_idx, ch_idx):
    idx, cnt = np.unique([row_idx, col_idx, ch_idx],
                         return_counts=True, axis=1)
    idx = tuple(idx)
    arr[idx] = np.clip(arr[idx]+cnt, -np.inf, 10)

示例:

arr = np.zeros((2, 3, 4), dtype='int')
update(arr, [0, 0, 1], [1, 1, 2], [3, 3, 3])

# arr
array([[[0, 0, 0, 0],
        [0, 0, 0, 2],
        [0, 0, 0, 0]],

       [[0, 0, 0, 0],
        [0, 0, 0, 0],
        [0, 0, 0, 1]]])

Python-3.x相关问答推荐

丢弃重复的索引,并在多索引数据帧中保留一个

如何使用Python将嵌套的XML转换为CSV

将列表项的极列水平分解为新列

如何将从维基百科表中抓取的数据转换为字典列表?

类变量的Python子类被视为类方法

如何将 WebDriver 传输到导入的测试?

Pandas groupby 然后 for each 组添加新行

使用 python 查找标记的元素

Pandas:从 Pandas 数据框中的 1 和 0 模式中获取值和 ID 的计数

如何计算Pandas 列中每列唯一项目的出现次数?

如何在两个矩阵的比较中允许任何列的符号差异,Python3?

用于 BIG 数组计算的多处理池映射比预期的要慢

例外:使用 Pyinstaller 时找不到 PyQt5 插件目录,尽管 PyQt5 甚至没有被使用

如何并行化文件下载?

numpy.ndarray 与 pandas.DataFrame

Python 3 变量名中接受哪些 Unicode 符号?

将行附加到 DataFrame 的最快和最有效的方法是什么?

变量类型注解NameError不一致

在 linux mint 上安装 python3-venv 模块

在 Python 中生成马尔可夫转移矩阵