最初的问题是使用np.linspace with arrays作为开始和停止参数,尽管现在我对我提出的解决方法有问题.

采取以下措施:

from numba import njit
import numpy as np

@njit
def f1():
  start = np.array([0.1, 1.0], np.float32)
  stop = np.array([1.0, 10.0], np.float32)
  return np.linspace(start, stop, 10)

f1()

这将引发一个错误,因为linspace中的though documented as supporting "only the 3-argument form",它们实际上的意思是"具有标量值的三参数形式,用于开始和停止".

因此,我提出了以下解决方法:

import numpy as np
from numba import njit

@njit
def f2():
  start = np.array([0.1, 1.0], np.float32)
  stop = np.array([1.0, 10.0], np.float32)
  pts_0 = np.linspace(start[0], stop[0], 10).astype(np.float32) # works
  pts_1 = np.linspace(start[1], stop[1], 10).astype(np.float32) # works
  return np.stack([pts_0, pts_1]).T                             # error

这会引发此错误:

---------------------------------------------------------------------------
TypingError                               Traceback (most recent call last)
c:\Users\X\Desktop\X\data_analysis.ipynb Cell 46' in <cell line: 18>()
     15   pts_1 = np.linspace(start[1], stop[1], 10).astype(np.float32)
     16   return np.stack([pts_0, pts_1]).T
---> 18 r = f2()

File c:\Users\X\miniconda3\envs\X\lib\site-packages\numba\core\dispatcher.py:468, in _DispatcherBase._compile_for_args(self, *args, **kws)
    464         msg = (f"{str(e).rstrip()} \n\nThis error may have been caused "
    465                f"by the following argument(s):\n{args_str}\n")
    466         e.patch_message(msg)
--> 468     error_rewrite(e, 'typing')
    469 except errors.UnsupportedError as e:
    470     # Something unsupported is present in the user code, add help info
    471     error_rewrite(e, 'unsupported_error')

File c:\Users\X\miniconda3\envs\X\lib\site-packages\numba\core\dispatcher.py:409, in _DispatcherBase._compile_for_args.<locals>.error_rewrite(e, issue_type)
    407     raise e
    408 else:
--> 409     raise e.with_traceback(None)

TypingError: Failed in nopython mode pipeline (step: nopython frontend)
No implementation of function Function(<function stack at 0x00000186F280CAF0>) found for signature:
 
 >>> stack(list(array(float32, 1d, C))<iv=None>)

同样,根据documentationnp.stack是受支持的(这一侧也不受支持).

我错过了什么?

推荐答案

np.stack是受支持的,但到目前为止它是expect a tuple而不是列表.以下是固定代码:

@njit
def f2():
  start = np.array([0.1, 1.0], np.float32)
  stop = np.array([1.0, 10.0], np.float32)
  pts_0 = np.linspace(start[0], stop[0], 10).astype(np.float32) # works
  pts_1 = np.linspace(start[1], stop[1], 10).astype(np.float32) # works
  return np.stack((pts_0, pts_1)).T                             # works

顺便说一下,请注意np.stack((pts_0, pts_1)).T不是很有效,因为它创建temporary arrays和非连续视图.由于使用Numba的目的是加快代码的速度,因此可以考虑使用应该更快的基本循环.同样的道理也适用于astype(np.float32):循环可以就地转换值.内存和分配都很昂贵,这通常是Numpy速度较慢的原因(也就是缺少特定用途的函数).这样的事情在future 会变得更慢(更多信息,请考虑阅读更多关于"memory wall"的内容),因此需要避免它们.

这是一个具有基本循环的速度更快的版本:

@njit
def f2():
    start1, start2 = np.float32(0.1), np.float32(1.0)
    stop1, stop2 = np.float32(1.0), np.float32(10.0)
    steps = 10
    delta = np.float32(1 / (steps - 1))
    res = np.empty((steps, 2), dtype=np.float32)
    for i in range(steps):
        res[i, 0] = start1 + (stop1 - start1) * (delta * i)
        res[i, 1] = start2 + (stop2 - start2) * (delta * i)
    return res

请注意,由于32位FP舍入,结果可能略有不同.

Python相关问答推荐

对Numpy函数进行载体化

ModuleNotFound错误:没有名为Crypto Windows 11、Python 3.11.6的模块

如何删除索引过go 的lexsort深度可能会影响性能?' &>

如何在虚拟Python环境中运行Python程序?

如何将一个动态分配的C数组转换为Numpy数组,并在C扩展模块中返回给Python

如何获得每个组的时间戳差异?

在vscode上使用Python虚拟环境时((env))

将JSON对象转换为Dataframe

如何在Python中找到线性依赖mod 2

字符串合并语法在哪里记录

判断solve_ivp中的事件

如何在PySide/Qt QColumbnView中删除列

在极中解析带有数字和SI前缀的字符串

在Python中从嵌套的for循环中获取插值

python的文件. truncate()意外地没有截断'

使用np.fft.fft2和cv2.dft重现相位谱.为什么结果并不相似呢?

如何从一个维基页面中抓取和存储多个表格?

有没有一种方法可以根据不同索引集的数组从2D数组的对称子矩阵高效地构造3D数组?

GEKKO中若干参数的线性插值动态优化

搜索结果未显示.我的URL选项卡显示:http://127.0.0.1:8000/search?";,而不是这个:";http://127.0.0.1:8000/search?q=name";