我有一个[n_rows, n_cols, n_channels]
大小的稀疏array.在我的代码中,我有一个循环,其中数组不断更新和裁剪:
def update(arr, row_idx, col_idx, ch_idx):
arr[row_idx, col_idx, ch_idx] += 1
arr[arr > 10] = 10
arr = np.array(n_rows, n_cols, n_channels)
while True:
update(arr, 0, 1, 2)
为了优化我的代码,我可以使用带有索引列表的缓存,并每N次迭代更新一次数组:
def update(arr, rows_list, cols_list, ch_list):
arr[rows_list, cols_list, ch_list] += 1
arr[arr > 10] = 10
arr = np.array(n_rows, n_cols, n_channels)
cache_length = 3
rows_list, cols_list, ch_list = [], [], []
while True:
rows_list.append(something1)
cols_list.append(something2)
ch_list.append(something3)
if len(row_list) == cache_length:
update(arr, rows_list, cols_list, ch_list)
rows_list, cols_list, ch_list = [], [], []
这可以节省时间,但可能会发生缓存多次包含相同的数组索引,例如:
# arr[0, 0, 6] should be updated twice
update(arr, [0, 0, 2], [3, 3, 5], [6, 6, 6])
我如何更改我的代码才能使此优化起作用?