考虑以下两种对二维数值数组中的所有值求和的方法.

import numpy as np
from numba import njit
a = np.random.rand(2, 5000)

@njit(fastmath=True, cache=True)
def sum_array_slow(arr):
    s = 0
    for i in range(arr.shape[0]):
        for j in range(arr.shape[1]):
            s += arr[i, j]
    return s
    
@njit(fastmath=True, cache=True)
def sum_array_fast(arr):
    s = 0
    for i in range(arr.shape[1]):
        s += arr[0, i]
    for i in range(arr.shape[1]):
        s += arr[1, i]
    return s

查看sum_arrayslow中的嵌套循环,它似乎应该以与sum_arrayfast相同的顺序执行完全相同的操作.但是:

In [46]: %timeit sum_array_slow(a)
7.7 µs ± 374 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)

In [47]: %timeit sum_array_fast(a)
951 ns ± 2.63 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)

为什么SUM_ARRAY_FAST函数比SUM_ARRAY_SLOW快8倍,而它似乎会以相同的顺序执行相同的计算?

推荐答案

这是因为慢版本是not automatically vectorized(即.编译器无法生成快速SIMD代码),而快速版本是.这当然是因为Numba在第一个循环中没有优化索引换行,所以它是Numba的missed optimization.

通过分析汇编代码可以看出这一点.以下是慢速版本的热循环:

.LBB0_6:
    addq    %rbx, %rdx
    vaddsd  (%rax,%rdx,8), %xmm0, %xmm0
    leaq    1(%rsi), %rdx
    cmpq    $1, %rbp
    cmovleq %r13, %rdx
    addq    %rbx, %rdx
    vaddsd  (%rax,%rdx,8), %xmm0, %xmm0
    leaq    2(%rsi), %rdx
    cmpq    $2, %rbp
    cmovleq %r13, %rdx
    addq    %rbx, %rdx
    vaddsd  (%rax,%rdx,8), %xmm0, %xmm0
    leaq    3(%rsi), %rdx
    cmpq    $3, %rbp
    cmovleq %r13, %rdx
    addq    $4, %rsi
    leaq    -4(%rbp), %rdi
    addq    %rbx, %rdx
    vaddsd  (%rax,%rdx,8), %xmm0, %xmm0
    cmpq    $4, %rbp
    movl    $0, %edx
    cmovgq  %rsi, %rdx
    movq    %rdi, %rbp
    cmpq    %rsi, %r12
    jne .LBB0_6

我们可以看到,Numba生成了许多无用的索引判断,这使得循环的效率非常低.我不知道有什么干净的方法来解决这个问题.这是可悲的,因为这样的问题在实践中远非罕见.使用像C和C++这样的本机语言解决了这个问题(因为在数组中没有索引包装).一种不安全/难看的方式是在Numba中使用指针,但提取Numpy数据指针并将其提供给Numba似乎是一件相当痛苦的事情(如果可能的话).

这里是一个快速的例子:

.LBB0_8:
    vaddpd  (%r11,%rsi,8), %ymm0, %ymm0
    vaddpd  32(%r11,%rsi,8), %ymm1, %ymm1
    vaddpd  64(%r11,%rsi,8), %ymm2, %ymm2
    vaddpd  96(%r11,%rsi,8), %ymm3, %ymm3
    vaddpd  128(%r11,%rsi,8), %ymm0, %ymm0
    vaddpd  160(%r11,%rsi,8), %ymm1, %ymm1
    vaddpd  192(%r11,%rsi,8), %ymm2, %ymm2
    vaddpd  224(%r11,%rsi,8), %ymm3, %ymm3
    vaddpd  256(%r11,%rsi,8), %ymm0, %ymm0
    vaddpd  288(%r11,%rsi,8), %ymm1, %ymm1
    vaddpd  320(%r11,%rsi,8), %ymm2, %ymm2
    vaddpd  352(%r11,%rsi,8), %ymm3, %ymm3
    vaddpd  384(%r11,%rsi,8), %ymm0, %ymm0
    vaddpd  416(%r11,%rsi,8), %ymm1, %ymm1
    vaddpd  448(%r11,%rsi,8), %ymm2, %ymm2
    vaddpd  480(%r11,%rsi,8), %ymm3, %ymm3
    addq    $64, %rsi
    addq    $-4, %rdi
    jne .LBB0_8

在这种情况下,循环得到了很好的优化.事实上,它对于大型数组几乎是最佳的.对于像您的示例中这样的小数组,它在像我这样的处理器上并不是最优的.事实上,AFAIK,展开的指令没有使用足够的寄存器来隐藏FMA单元的延迟(这是因为LLVM在内部生成次优代码).可能需要较低级别的本机代码来修复这个问题(至少,在Numba中没有简单的方法来修复这个问题).


更新

由于@Max9111提供了this link,所以可以使用无符号整数来优化慢代码.这个技巧极大地缩短了执行时间.以下是修改后的代码:

@njit(fastmath=True, cache=True)
def sum_array_faster(arr):
    s = 0
    for i in range(np.uint64(arr.shape[0])):
        for j in range(np.uint64(arr.shape[1])):
            s += arr[i, j]
    return s

以下是英特尔至强W-2255处理器的性能:

slow:     9.66 µs
fastest:  1.13 µs
fast:     1.14 µs

Theoretical optimal:  0.30-0.35 µs

replacing opt=0 by opt=2的解决方法(再次感谢@Max911)在我的机器上没有产生很好的结果:

slow:     2.12 µs
fastest:  2.17 µs
fast:     2.09 µs

更不用说编译时间也稍微长了一些.

可以实现更快的实现,以便更好地隐藏FMA指令的等待时间:

@njit(fastmath=True, cache=True)
def sum_array_fastest(arr):
    s0, s1 = 0, 0
    for i in range(arr.shape[1]):
        s0 += arr[0, i]
        s1 += arr[1, i]
    return s0 + s1

这款花了1.08微秒.它更好.

生成的Numba代码仍然有两个限制因素:

  • 与(短)执行时间相比,Numba的开销相当大:250-300 ns
  • Numba不使用我机器上可用的AVX-512(ZMM寄存器比AVX寄存器大两倍).

注意,可以使用Numba函数的方法inspect_asm来提取汇编代码.

Python相关问答推荐

Pandas 密集排名具有相同值,按顺序排列

将列中的滚动值集转换为单元格中的单个值

重命名变量并使用载体中的字符串存储 Select 该变量

Polars -转换为PL后无法计算熵.列表

创建带有二维码的Flask应用程序,可重定向到特定端点

不允许AMBIMA API请求方法

Pandas 在时间序列中设定频率

拆分pandas列并创建包含这些拆分值计数的新列

Python中是否有方法从公共域检索搜索结果

如何从具有多个嵌入选项卡的网页中Web抓取td类元素

需要计算60,000个坐标之间的距离

Python中的嵌套Ruby哈希

使可滚动框架在tkinter环境中看起来自然

通过pandas向每个非空单元格添加子字符串

我如何使法国在 map 中完全透明的代码?

在pandas中使用group_by,但有条件

为什么Django管理页面和我的页面的其他CSS文件和图片都找不到?'

判断solve_ivp中的事件

在单次扫描中创建列表

如何获取Python synsets列表的第一个内容?