我有这个简单的python函数:

import numpy as np

def fast_transform(img, offset, factor):
    rep = (img.shape[0]//2, img.shape[1]//2)
    out = (img.astype(np.float32) - np.tile(offset, rep)) * np.tile(factor, rep)
    return out

该函数获得一个图像(作为nxm数字ndarray)和两个2x2数组(偏移量和因子).然后,它根据图像中每个像素在每个维度上的奇偶性来计算基本的线性变换: out[i,j] = (out[i,j] - offset[i%2,j%2]) * factor[i%2,j%2]

如您所见,我使用了np.til来try 提高函数的速度,但速度不足以满足我的需要(我认为创建虚拟的np.til数组使其不是最优的).我试着使用Numba,但它还不支持np.til.

你能帮我尽可能优化这个功能吗?我相信有一些简单的方法可以做到这一点,我错过了.

推荐答案

如果您愿意使用另一个库,您可以使用JAX使NumPy函数的速度提高约7倍(尽管如果您的数组具有不同的形状,这可能并不理想,因为JAX会针对不同的形状重新编译该函数):

from jax import config
config.update("jax_enable_x64", True)
import jax.numpy as jnp
import jax

@jax.jit
def fast_transform_jax(img, offset, factor):
    rep = (img.shape[0]//2, img.shape[1]//2)
    out = (img.astype(np.float32) - jnp.tile(offset, rep)) * jnp.tile(factor, rep)
    return out

对@Andrej的答案中的Numba函数稍作修改,使它们以运算函数通过allclose:

@nb.njit
def fast_transform_numba(img, offset, factor):
    img = img.astype(np.float32)
    out = np.empty(img.shape, dtype=np.float64)
    for i in range(img.shape[0]):
        for j in range(img.shape[1]):
            out[i, j] = (img[i, j] - offset[i % 2, j % 2]) * factor[i % 2, j % 2]
    return out

@nb.njit(parallel=True)
def fast_transform_numba_parallel(img, offset, factor):
    img = img.astype(np.float32)
    out = np.empty(img.shape, dtype=np.float64)
    for i in nb.prange(img.shape[0]):
        for j in nb.prange(img.shape[1]):
            out[i, j] = (img[i, j] - offset[i % 2, j % 2]) * factor[i % 2, j % 2]
    return out

计时:

rng = np.random.default_rng()

N, M = 1000, 1000
img = rng.random((N, M)) * 50
offset = rng.random((2, 2)) * 40
factor = rng.random((2, 2)) * 30

assert np.allclose(fast_transform(img, offset, factor), fast_transform_numba(img, offset, factor))
assert np.allclose(fast_transform(img, offset, factor), fast_transform_numba_parallel(img, offset, factor))
assert np.allclose(fast_transform(img, offset, factor), fast_transform_jax(img, offset, factor))

%timeit fast_transform(img, offset, factor)
%timeit fast_transform_numba(img, offset, factor)
%timeit fast_transform_numba_parallel(img, offset, factor)
%timeit fast_transform_jax(img, offset, factor).block_until_ready()

输出:

3.59 ms ± 332 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
1.39 ms ± 37.1 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
871 µs ± 47.5 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
521 µs ± 36.8 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)

Python相关问答推荐

Python中MongoDB的BSON时间戳

Chatgpt API不断返回错误:404未能从API获取响应

如何自动抓取以下CSV

Python daskValue错误:无法识别的区块管理器dask -必须是以下之一:[]

pandas滚动和窗口中有效观察的最大数量

如何使用数组的最小条目拆分数组

在Python中动态计算范围

Stacked bar chart from billrame

多处理队列在与Forking http.server一起使用时随机跳过项目

索引到 torch 张量,沿轴具有可变长度索引

计算天数

解决调用嵌入式函数的XSLT中表达式的语法移位/归约冲突

幂集,其中每个元素可以是正或负""""

(Python/Pandas)基于列中非缺失值的子集DataFrame

为什么调用函数的值和次数不同,递归在代码中是如何工作的?

Gekko中基于时间的间隔约束

Python 3试图访问在线程调用中实例化的类的对象

解决Geopandas和Altair中的正图和投影问题

正在try 让Python读取特定的CSV文件

PYTHON中的selenium不会打开 chromium URL