我想规范化一个向量.最简单的方法是

import numpy as np
v = np.random.rand(3)
v /= np.linalg.norm(v)

但我担心我的软件包的性能和平方和(不可避免的),取平方根,然后dividing all the vector不是一个好主意.

然后我得到了this question,这个解决方案使用sklearn.preprocessing.normalize来完成.不幸的是,它向我的包中添加了另一个需求/依赖项.

这是一个问题.

  1. 难道不应该有numpy函数来实现这一点吗?它使用fast inverse square root algorithm.还是在numpy的范围之外,不应该有这样的函数?
  2. 我应该在cython/numba中实现我自己的功能吗?
  3. 或者,如果我非常担心性能,我应该放弃python而在C/C++中开始编写代码吗?

推荐答案

难道不应该有一个numpy函数来实现这一点吗?它使用快速平方根逆算法.还是它超出了numpy的范围,不应该有这样的功能?

我不知道Numpy中有什么功能可以做到这一点.纯Numpy中需要多个函数调用.sklearn.preprocessing.normalize确实是一个很好的 Select (而AFAIK并不是唯一提供这一点的软件包).

这东西是Numpy is not designed to compute small arrays efficiently.对于小数组(比如只有3个值),Numpy调用的开销是巨大的.组合多个函数调用只会使情况变得更糟.开销主要是由于类型/形状/值判断、内部函数调用、CPython解释器以及新数组的分配.因此,即使Numpy能够提供您想要的功能,对于只有3个项目的数组来说,速度也会很慢.

我应该在cython/numba中实现我自己的功能吗?

This is a good idea,因为Numba可以用更小的开销完成这项工作.注意,尽管Nuba函数调用仍然有很小的开销,但从Nuba上下文调用它们非常便宜(本机调用).

例如,您可以使用:

# Note:
# - The signature cause an eager compilation
# - ::1 means the axis is contiguous (generate a faster code)
@nb.njit('(float64[::1],)')
def normalize(v):
    s = 0.0
    for i in range(v.size):
        s += v[i] * v[i]
    inv_norm = 1.0 / np.sqrt(s)
    for i in range(v.size):
        v[i] *= inv_norm

此函数在正常工作时不分配任何新array.此外,Numba只能在包装函数中进行最少的判断.循环速度非常快,但如果用实际大小(如3)替换v.size,则可以使循环速度更快,因为JIT可以展开循环并生成接近最优的代码.np.sqrt将是内联的,它应该生成一条快速平方根FP指令.如果使用标志fastmath=True,JIT甚至可以在x86-64平台上使用专用的更快指令来计算倒数平方根(请注意,如果使用NaN等特殊值或关心FP关联性,fastmath是不安全的).尽管如此,在主流计算机上,对于非常小的向量调用此函数的开销可能为v.size-300 ns:CPython包装函数的开销很大.删除它的唯一解决方案是在调用者函数中使用Numba/Cython.如果您需要在大多数项目中使用它们,那么直接编写C/C++代码当然更好.

或者,如果我非常担心性能,我应该放弃python,开始用C/C++编写代码吗?

It depends of your overall project但如果您想像这样操纵许多小向量,直接使用C/C++会更加高效.另一种 Select 是对当前速度较慢的内核使用Numba或Cython.

优化良好的Numba代码或Cython代码的性能可以非常接近本地编译的C/C++代码.例如,我成功地用Numba一次超越了高度优化的OpenBLAS代码(多亏了专门化).Numba中的主要开销来源之一是数组绑定判断(通常可以针对循环进行优化).C/C++的级别较低,因此您不需要支付任何隐藏成本,但代码可能更难维护.此外,您可以应用在Nuba/Cython中甚至不可能实现的较低级别优化(例如,直接使用SIMD内部函数或汇编指令,生成带有模板的专用代码).

Python相关问答推荐

双情节在单个图上切换-pPython

pandas MultiIndex是SQL复合索引的对应物吗?

使用pandas MultiIndex进行不连续 Select

计算每月过go x年的平均值

如何处理必须存在于环境中但无法安装的Python项目依赖项?

通过交换 node 对链接列表进行 Select 排序

在Pandas框架中截短至固定数量的列

时间序列分解

韦尔福德方差与Numpy方差不同

为什么tkinter框架没有被隐藏?

当使用keras.utils.Image_dataset_from_directory仅加载测试数据集时,结果不同

重新匹配{ }中包含的文本,其中文本可能包含{{var}

通过Selenium从页面获取所有H2元素

scikit-learn导入无法导入名称METRIC_MAPPING64'

如何在python xsModel库中定义一个可选[December]字段,以产生受约束的SON模式

如何在Python数据框架中加速序列的符号化

使用密钥字典重新配置嵌套字典密钥名

计算每个IP的平均值

matplotlib + python foor loop

人口全部乱序 - Python—Matplotlib—映射