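My original problem was using np.linspace with arrays as the start and stop arguments, although I now have a problem with the workaround I came up with.
Take the following: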
from numba import njit
import numpy as np

@njit
def f1():
    start = np.array([0.1, 1.0], np.float32)
    stop = np.array([1.0, 10.0], np.float32)
    return np.linspace(start, stop, 10)

f1()
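This raises an error because, although linspace is documented as supporting "only the 3-argument form", what they actually mean is "the 3-argument form with scalar values for start and stop".
So I came up with the following workaround: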
import numpy as np
from numba import njit

@njit
def f2():
    start = np.array([0.1, 1.0], np.float32)
    stop = np.array([1.0, 10.0], np.float32)
    pts_0 = np.linspace(start[0], stop[0], 10).astype(np.float32)  # works
    pts_1 = np.linspace(start[1], stop[1], 10).astype(np.float32)  # works
    return np.stack([pts_0, pts_1]).T  # error
This raises the following error:
---------------------------------------------------------------------------
TypingError Traceback (most recent call last)
c:\Users\X\Desktop\X\data_analysis.ipynb Cell 46' in <cell line: 18>()
15 pts_1 = np.linspace(start[1], stop[1], 10).astype(np.float32)
16 return np.stack([pts_0, pts_1]).T
---> 18 r = f2()
File c:\Users\X\miniconda3\envs\X\lib\site-packages\numba\core\dispatcher.py:468, in _DispatcherBase._compile_for_args(self, *args, **kws)
464 msg = (f"{str(e).rstrip()} \n\nThis error may have been caused "
465 f"by the following argument(s):\n{args_str}\n")
466 e.patch_message(msg)
--> 468 error_rewrite(e, 'typing')
469 except errors.UnsupportedError as e:
470 # Something unsupported is present in the user code, add help info
471 error_rewrite(e, 'unsupported_error')
File c:\Users\X\miniconda3\envs\X\lib\site-packages\numba\core\dispatcher.py:409, in _DispatcherBase._compile_for_args.<locals>.error_rewrite(e, issue_type)
407 raise e
408 else:
--> 409 raise e.with_traceback(None)
TypingError: Failed in nopython mode pipeline (step: nopython frontend)
No implementation of function Function(<function stack at 0x00000186F280CAF0>) found for signature:
>>> stack(list(array(float32, 1d, C))<iv=None>)
Again, according to the documentation, np.stack is supported, yet this call is rejected as well.
What am I missing?
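One lead I can see in the error itself: the rejected signature is stack(list(array(float32, 1d, C))), so the complaint may be about the list rather than about np.stack as such. Below is a minimal sketch of the same workaround with a tuple passed to np.stack. It assumes that Numba's np.stack accepts a tuple of arrays, which I have not confirmed against the documentation. The name f3 is hypothetical, and the try/except fallback is only there so the sketch also runs where Numba is not installed.

```python
import numpy as np

try:
    from numba import njit
except ImportError:
    # Fallback so the sketch still runs as plain NumPy without Numba.
    def njit(func):
        return func


@njit
def f3():  # hypothetical name, a variant of f2 above
    start = np.array([0.1, 1.0], np.float32)
    stop = np.array([1.0, 10.0], np.float32)
    # Scalar start/stop values, one linspace call per column.
    pts_0 = np.linspace(start[0], stop[0], 10).astype(np.float32)
    pts_1 = np.linspace(start[1], stop[1], 10).astype(np.float32)
    # Assumption: a tuple is accepted where the list was rejected.
    return np.stack((pts_0, pts_1)).T


result = f3()
print(result.shape, result.dtype)
```

If the assumption holds, the result should have one row per sample and one column per start/stop pair, the same layout plain NumPy gives for np.linspace(start, stop, 10) with array arguments.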