我试图在numpy数组中找到重复的行.以下代码复制了我的数组的 struct ,该数组每行有n行、m列和nz个非零条目:
import numpy as np
import random
import datetime
def create_mat(n, m, nz):
sample_mat = np.zeros((n, m), dtype='uint8')
random.seed(42)
for row in range(0, n):
counter = 0
while counter < nz:
random_col = random.randrange(0, m-1, 1)
if sample_mat[row, random_col] == 0:
sample_mat[row, random_col] = 1
counter += 1
test = np.all(np.sum(sample_mat, axis=1) == nz)
print(f'All rows have {nz} elements: {test}')
return sample_mat
我试图优化的代码如下:
if __name__ == '__main__':
threshold = 2
mat = create_mat(1800000, 108, 8)
print(f'Time: {datetime.datetime.now()}')
duplicate_rows, _, duplicate_counts = np.unique(mat, axis=0, return_counts=True, return_index=True)
duplicate_indices = [int(x) for x in np.argwhere(duplicate_counts >= threshold)]
print(f'Time: {datetime.datetime.now()}')
print(f'Duplicate rows: {len(duplicate_rows)} Sample inds: {duplicate_indices[0:5]} Sample counts: {duplicate_counts[0:5]}')
print(f'Sample rows:')
print(duplicate_rows[0:5])
我的输出如下:
All rows have 8 elements: True
Time: 2022-06-29 12:08:07.320834
Time: 2022-06-29 12:08:23.281633
Duplicate rows: 1799994 Sample inds: [508991, 553136, 930379, 1128637, 1290356] Sample counts: [1 1 1 1 1]
Sample rows:
[[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 1 1 0 0 1 1 0 0 0 1 0 0 0 0 0 0 1 0 1 0]
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 1 1 0 0 0 0 0 1 1 1 1 0 0 0 0 1 0 0 0 0 0]
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 1 0 0 1 0 0 0 0 1 0 1 0 0 1 0 0 0 1 0 1 0 1 0 0 0]
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 1 0 1 1 0 0 0 0 1 1 0 0 0 0 0 0 1 1 0 1 0 0 0 0 0]
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 1 1 0 0 0 0 1 0 0 0 1 0 1 1 0 0 0 0 0 1 0 0 0 0 1 0]]
我考虑过使用NUBA,但挑战是它不使用轴参数.类似地,转换为列表和利用集合也是一种 Select ,但随后通过循环执行重复计数似乎"不符合逻辑".
考虑到我需要多次运行此代码(因为我正在修改numpy数组,然后需要重新搜索重复项),时间至关重要.我也try 对这一步使用多处理,但np.unique似乎被阻塞了(即,即使我try 运行多个版本的unique,我最终也会限制一个线程以6%的CPU容量运行,而其他线程则处于空闲状态).