我开始研究pyo3,作为一种测试,我正试图使用pyo3包装一个RUST库;然而,当lambda函数作为参数传递时,我遇到了一些性能问题.

假设我有一个rust库,它的函数以回调作为参数.此函数多次计算回调函数,然后返回结果,例如:

pub fn test_function<F: Fn(f64) -> f64>(cb: F) -> f64 {
    //Not actually an implementation this trivial, it is just to execute the callback a number of times
    (0..225_000).map(|i| cb(i as f64)).sum::<f64>()
}

通过传递回调并计算平均时间,我try 通过执行此函数1000次来测量此函数所用的时间.

use std::time::SystemTime;
use pyo3test::test_function;

pub fn main() {
    let mut sum = 0.0f64;
    let reps = 1_000;
    let start = SystemTime::now();
    for _ in 0..reps {
        sum += test_function(|x| x);
    }
    let end = start.elapsed().unwrap();
    println!("Result: {sum}");
    println!("Duration: {:?}", end.checked_div(reps).unwrap());
}

在我的机器上,当在发行版中运行时,每次执行test_function大约需要250微秒.

然后,我try 使用pyo3以下面的方式包装这个函数

use pyo3::{PyAny, pyfunction, pymodule, PyResult, Python, wrap_pyfunction};
use pyo3::prelude::PyModule;
use crate::test_function;

#[pymodule]
fn pyo3test(_py: Python, m: &PyModule) -> PyResult<()> {
    m.add_function(wrap_pyfunction!(py_test_function, m)?)?;
    Ok(())
}

#[pyfunction]
pub fn py_test_function(function: &PyAny) -> f64 {
    assert!(function.is_callable());
    let cb = move |x| function.call1((x, )).unwrap().extract::<f64>().unwrap();
    test_function(cb)
}

我编译了所有内容(在发行版中),将模块导入到了Python中,然后通过1000次执行此函数并计算平均时间,再次测量了所需时间

import pyo3test
from time import time

reps = 1000
cb = lambda x: x
summation = 0
start = time()
for _ in range(reps):
    summation += pyo3test.py_test_function(cb)
end = time()

duration = end-start
avg = duration/reps
print(avg)

在这种情况下,平均执行时间大约为20毫秒,几乎比纯铁 rust 情况多2个数量级.由于GIL,我没有想到会有相同的执行时间,但我会猜到更接近毫秒的值.

这真的是意料之中的吗,还是我错过了什么?有可能在不改变纯铁 rust 实现的情况下改进这一点吗?

我试着查看文档,虽然it suggestsextract很慢,但我认为我在这里做不了什么,因为downcast不能在这里使用. 还有什么办法吗?


UPDATE

为了遵循S的建议,我又写了一个铁 rust 函数:

pub fn test_function_alt(values: &[f64]) -> f64 {
    values.iter().sum::<f64>()
}

我用PYO3包好了

#[pyfunction]
pub fn py_test_function_alt(values: Vec<f64>) -> f64 {
    test_function_alt(&values)
}

然后,我编写了以下的python函数

import numpy as np

def foobar(cb):
    vals = cb(np.arange(225000))
    return pyo3test.py_test_function_alt(vals)

此函数仍在大约20毫秒内执行.

推荐答案

这里的GIL不是您的问题,python解释器很慢,在没有任何ffi的情况下,用纯python语言调用回调需要20毫秒.

重新设计您的API接口,使其不进行225_000次调用,取而代之的是使用专门为将数据传递给本机API而设计的pythonarraynumpy arrays.

Python相关问答推荐

我在使用fill_between()将最大和最小带应用到我的图表中时遇到问题

比较两个数据帧并并排附加结果(获取性能警告)

根据不同列的值在收件箱中移动数据

PywinAuto在Windows 11上引发了Memory错误,但在Windows 10上未引发

管道冻结和管道卸载

当从Docker的--env-file参数读取Python中的环境变量时,每个\n都会添加一个\'.如何没有额外的?

如何在Raspberry Pi上检测USB并使用Python访问它?

导入...从...混乱

Python脚本使用蓝牙运行在Windows 11与raspberry pi4

如何合并两个列表,并获得每个索引值最高的列表名称?

如何使用SentenceTransformers创建矢量嵌入?

跳过嵌套JSON中的级别并转换为Pandas Rame

如何在海上配对图中使某些标记周围的黑色边框

使用python playwright从 Select 子菜单中 Select 值

如何从比较函数生成ngroup?

解决Geopandas和Altair中的正图和投影问题

如何训练每一个pandaprame行的线性回归并生成斜率

为什么后跟inplace方法的`.rename(Columns={';b';:';b';},Copy=False)`没有更新原始数据帧?

在MongoDB文档中仅返回数组字段

组颠倒大Pandas 数据帧