这是因为慢版本是not automatically vectorized(即.编译器无法生成快速SIMD代码),而快速版本是.这当然是因为Numba在第一个循环中没有优化索引换行,所以它是Numba的missed optimization.
通过分析汇编代码可以看出这一点.以下是慢速版本的热循环:
.LBB0_6:
addq %rbx, %rdx
vaddsd (%rax,%rdx,8), %xmm0, %xmm0
leaq 1(%rsi), %rdx
cmpq $1, %rbp
cmovleq %r13, %rdx
addq %rbx, %rdx
vaddsd (%rax,%rdx,8), %xmm0, %xmm0
leaq 2(%rsi), %rdx
cmpq $2, %rbp
cmovleq %r13, %rdx
addq %rbx, %rdx
vaddsd (%rax,%rdx,8), %xmm0, %xmm0
leaq 3(%rsi), %rdx
cmpq $3, %rbp
cmovleq %r13, %rdx
addq $4, %rsi
leaq -4(%rbp), %rdi
addq %rbx, %rdx
vaddsd (%rax,%rdx,8), %xmm0, %xmm0
cmpq $4, %rbp
movl $0, %edx
cmovgq %rsi, %rdx
movq %rdi, %rbp
cmpq %rsi, %r12
jne .LBB0_6
我们可以看到,Numba生成了许多无用的索引判断,这使得循环的效率非常低.我不知道有什么干净的方法来解决这个问题.这是可悲的,因为这样的问题在实践中远非罕见.使用像C和C++这样的本机语言解决了这个问题(因为在数组中没有索引包装).一种不安全/难看的方式是在Numba中使用指针,但提取Numpy数据指针并将其提供给Numba似乎是一件相当痛苦的事情(如果可能的话).
这里是一个快速的例子:
.LBB0_8:
vaddpd (%r11,%rsi,8), %ymm0, %ymm0
vaddpd 32(%r11,%rsi,8), %ymm1, %ymm1
vaddpd 64(%r11,%rsi,8), %ymm2, %ymm2
vaddpd 96(%r11,%rsi,8), %ymm3, %ymm3
vaddpd 128(%r11,%rsi,8), %ymm0, %ymm0
vaddpd 160(%r11,%rsi,8), %ymm1, %ymm1
vaddpd 192(%r11,%rsi,8), %ymm2, %ymm2
vaddpd 224(%r11,%rsi,8), %ymm3, %ymm3
vaddpd 256(%r11,%rsi,8), %ymm0, %ymm0
vaddpd 288(%r11,%rsi,8), %ymm1, %ymm1
vaddpd 320(%r11,%rsi,8), %ymm2, %ymm2
vaddpd 352(%r11,%rsi,8), %ymm3, %ymm3
vaddpd 384(%r11,%rsi,8), %ymm0, %ymm0
vaddpd 416(%r11,%rsi,8), %ymm1, %ymm1
vaddpd 448(%r11,%rsi,8), %ymm2, %ymm2
vaddpd 480(%r11,%rsi,8), %ymm3, %ymm3
addq $64, %rsi
addq $-4, %rdi
jne .LBB0_8
在这种情况下,循环得到了很好的优化.事实上,它对于大型数组几乎是最佳的.对于像您的示例中这样的小数组,它在像我这样的处理器上并不是最优的.事实上,AFAIK,展开的指令没有使用足够的寄存器来隐藏FMA单元的延迟(这是因为LLVM在内部生成次优代码).可能需要较低级别的本机代码来修复这个问题(至少,在Numba中没有简单的方法来修复这个问题).
更新
由于@Max9111提供了this link,所以可以使用无符号整数来优化慢代码.这个技巧极大地缩短了执行时间.以下是修改后的代码:
@njit(fastmath=True, cache=True)
def sum_array_faster(arr):
s = 0
for i in range(np.uint64(arr.shape[0])):
for j in range(np.uint64(arr.shape[1])):
s += arr[i, j]
return s
以下是英特尔至强W-2255处理器的性能:
slow: 9.66 µs
fastest: 1.13 µs
fast: 1.14 µs
Theoretical optimal: 0.30-0.35 µs
replacing opt=0
by opt=2
的解决方法(再次感谢@Max911)在我的机器上没有产生很好的结果:
slow: 2.12 µs
fastest: 2.17 µs
fast: 2.09 µs
更不用说编译时间也稍微长了一些.
可以实现更快的实现,以便更好地隐藏FMA指令的等待时间:
@njit(fastmath=True, cache=True)
def sum_array_fastest(arr):
s0, s1 = 0, 0
for i in range(arr.shape[1]):
s0 += arr[0, i]
s1 += arr[1, i]
return s0 + s1
这款花了1.08微秒.它更好.
生成的Numba代码仍然有两个限制因素:
- 与(短)执行时间相比,Numba的开销相当大:250-300 ns
- Numba不使用我机器上可用的AVX-512(ZMM寄存器比AVX寄存器大两倍).
注意,可以使用Numba函数的方法inspect_asm
来提取汇编代码.