If you are open to using another library, JAX can speed up the NumPy function by roughly 7x (although this may not be ideal if your arrays come in varying shapes, since JAX recompiles the function for every new shape):
import jax
import jax.numpy as jnp
from jax import config

config.update("jax_enable_x64", True)  # use float64 so results match NumPy

@jax.jit
def fast_transform_jax(img, offset, factor):
    # Tile the 2x2 offset/factor arrays up to the full image shape
    rep = (img.shape[0] // 2, img.shape[1] // 2)
    out = (img.astype(jnp.float32) - jnp.tile(offset, rep)) * jnp.tile(factor, rep)
    return out
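To see what the `jnp.tile` calls are doing: tiling the 2x2 `offset` (or `factor`) by `(H//2, W//2)` produces a full `(H, W)` array whose `[i, j]` entry equals `offset[i % 2, j % 2]`, which is exactly the per-pixel lookup the Numba versions below perform. A quick NumPy check of that equivalence (the shapes here are arbitrary, just for illustration):

```python
import numpy as np

rng = np.random.default_rng(0)
offset = rng.random((2, 2))

H, W = 6, 8  # any even dimensions
tiled = np.tile(offset, (H // 2, W // 2))  # shape (H, W)

# Every pixel (i, j) of the tiled array matches offset[i % 2, j % 2]
i, j = np.indices((H, W))
assert np.array_equal(tiled, offset[i % 2, j % 2])
```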
The Numba functions from @Andrej's answer, slightly modified so that their results pass `allclose` against the original function:
import numpy as np
import numba as nb

@nb.njit
def fast_transform_numba(img, offset, factor):
    img = img.astype(np.float32)
    out = np.empty(img.shape, dtype=np.float64)
    for i in range(img.shape[0]):
        for j in range(img.shape[1]):
            out[i, j] = (img[i, j] - offset[i % 2, j % 2]) * factor[i % 2, j % 2]
    return out
@nb.njit(parallel=True)
def fast_transform_numba_parallel(img, offset, factor):
    img = img.astype(np.float32)
    out = np.empty(img.shape, dtype=np.float64)
    for i in nb.prange(img.shape[0]):  # only the outermost prange is parallelized
        for j in range(img.shape[1]):
            out[i, j] = (img[i, j] - offset[i % 2, j % 2]) * factor[i % 2, j % 2]
    return out
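The timings below also call the original `fast_transform` from the question, which is not reproduced in this answer. Assuming it is the vectorized NumPy tiling version, a minimal sketch for completeness:

```python
import numpy as np

def fast_transform(img, offset, factor):
    # Vectorized NumPy version (assumed reference): tile the 2x2 arrays
    # up to the image shape, then subtract and multiply elementwise.
    rep = (img.shape[0] // 2, img.shape[1] // 2)
    return (img.astype(np.float32) - np.tile(offset, rep)) * np.tile(factor, rep)
```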
Timing:
rng = np.random.default_rng()
N, M = 1000, 1000
img = rng.random((N, M)) * 50
offset = rng.random((2, 2)) * 40
factor = rng.random((2, 2)) * 30
assert np.allclose(fast_transform(img, offset, factor), fast_transform_numba(img, offset, factor))
assert np.allclose(fast_transform(img, offset, factor), fast_transform_numba_parallel(img, offset, factor))
assert np.allclose(fast_transform(img, offset, factor), fast_transform_jax(img, offset, factor))
%timeit fast_transform(img, offset, factor)
%timeit fast_transform_numba(img, offset, factor)
%timeit fast_transform_numba_parallel(img, offset, factor)
%timeit fast_transform_jax(img, offset, factor).block_until_ready()
Output:
3.59 ms ± 332 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
1.39 ms ± 37.1 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
871 µs ± 47.5 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
521 µs ± 36.8 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)