TL;DR:你可以用Numba来优化np.dot
,只操作only on binary values.更具体地说,您可以使用64-bit views一次对8个字节执行SIMD-like operations.
将列表转换为数组
首先,使用这种方法可以有效地将列表转换为相对紧凑的数组:
vector = np.fromiter(vector, np.uint8)
conflicts = np.array([np.fromiter(conflicts[i], np.uint8) for i in range(len(conflicts))])
这比使用自动Numpy转换或np.array
更快(在Numpy代码内部执行的判断更少,Numpy、Numpy知道要构建什么类型的数组,并且生成的数组在内存中更小,因此填充速度更快).此步骤可用于加速基于np.dot
的解决方案.
如果输入已经是一个Numpy数组,那么判断它们的类型是np.uint8
或np.int8
.否则,请以conflits = conflits.astype(np.uint8)
为例将其转换为该类型.
第一次try
然后,一种解决方案是使用np.packbits
将输入二进制值尽可能多地打包到内存中的位数组中,然后执行逻辑and.但事实证明np.packbits
是相当慢的.因此,这种解决方案最终不是一个好主意.事实上,任何创建形状类似于conflicts
的临时数组的解决方案都会很慢,因为在内存中写入这样的数组通常比np.dot
慢(np.dot
从内存中读取conflicts
一次).
使用麻木
由于np.dot
经过了很好的优化,所以击败它的唯一解决方案是使用优化的本机代码.得益于即时编译器,Numba可以在运行时从基于Numpy的Python代码生成本机可执行代码.其思想是在每个块vector
到conflicts
行之间执行逻辑and.判断每个块的冲突,以便尽早停止计算.通过比较两个数组的uint64视图(以SIMD友好的方式),可以按8个八位字节的组高效地比较块.
import numba as nb
@nb.njit('bool_(uint8[::1], uint8[:,::1])')
def check_valid(vector, conflicts):
n, m = conflicts.shape
assert vector.size == m
for i in range(n):
block_size = 128 # In the range: 8,16,...,248
conflicts_row = conflicts[i,:]
gsum = 0 # Global sum of conflicts
m_limit = m // block_size * block_size
for j in range(0, m_limit, block_size):
vector_block = vector[j:j+block_size].view(np.uint64)
conflicts_block = conflicts_row[j:j+block_size].view(np.uint64)
# Matching
lsum = np.uint64(0) # 8 local sums of conflicts
for k in range(block_size//8):
lsum += vector_block[k] & conflicts_block[k]
# Trick to perform the reduction of all the bytes in lsum
lsum += lsum >> 32
lsum += lsum >> 16
lsum += lsum >> 8
gsum += lsum & 0xFF
# Check if there is a conflict
if gsum >= 2:
return False
# Remaining part
for j in range(m_limit, m):
gsum += vector[j] & conflicts_row[j]
if gsum >= 2:
return False
return True
后果
在我的机器上,对于形状(16, 65536)
的大型conflicts
数组(没有冲突),这大约是9 times faster比np.dot
.两种情况下都不包括转换列表的时间.当存在冲突时,提供的解决方案会更快,因为它可以提前停止计算.
理论上,计算速度应该更快,但Numba JIT无法使用SIMD指令将循环矢量化.尽管如此,似乎同样的问题也出现在np.dot
人身上.如果数组更大,则可以并行化块的计算(如果函数返回False,则以较慢的计算为代价).