我开始研究pyo3,作为一种测试,我正试图使用pyo3包装一个RUST库;然而,当lambda函数作为参数传递时,我遇到了一些性能问题.
假设我有一个rust库,它的函数以回调作为参数.此函数多次计算回调函数,然后返回结果,例如:
pub fn test_function<F: Fn(f64) -> f64>(cb: F) -> f64 {
//Not actually an implementation this trivial, it is just to execute the callback a number of times
(0..225_000).map(|i| cb(i as f64)).sum::<f64>()
}
通过传递回调并计算平均时间,我try 通过执行此函数1000次来测量此函数所用的时间.
use std::time::SystemTime;
use pyo3test::test_function;
pub fn main() {
let mut sum = 0.0f64;
let reps = 1_000;
let start = SystemTime::now();
for _ in 0..reps {
sum += test_function(|x| x);
}
let end = start.elapsed().unwrap();
println!("Result: {sum}");
println!("Duration: {:?}", end.checked_div(reps).unwrap());
}
在我的机器上,当在发行版中运行时,每次执行test_function
大约需要250微秒.
然后,我try 使用pyo3以下面的方式包装这个函数
use pyo3::{PyAny, pyfunction, pymodule, PyResult, Python, wrap_pyfunction};
use pyo3::prelude::PyModule;
use crate::test_function;
#[pymodule]
fn pyo3test(_py: Python, m: &PyModule) -> PyResult<()> {
m.add_function(wrap_pyfunction!(py_test_function, m)?)?;
Ok(())
}
#[pyfunction]
pub fn py_test_function(function: &PyAny) -> f64 {
assert!(function.is_callable());
let cb = move |x| function.call1((x, )).unwrap().extract::<f64>().unwrap();
test_function(cb)
}
我编译了所有内容(在发行版中),将模块导入到了Python中,然后通过1000次执行此函数并计算平均时间,再次测量了所需时间
import pyo3test
from time import time
reps = 1000
cb = lambda x: x
summation = 0
start = time()
for _ in range(reps):
summation += pyo3test.py_test_function(cb)
end = time()
duration = end-start
avg = duration/reps
print(avg)
在这种情况下,平均执行时间大约为20毫秒,几乎比纯铁 rust 情况多2个数量级.由于GIL,我没有想到会有相同的执行时间,但我会猜到更接近毫秒的值.
这真的是意料之中的吗,还是我错过了什么?有可能在不改变纯铁 rust 实现的情况下改进这一点吗?
我试着查看文档,虽然it suggests和extract
很慢,但我认为我在这里做不了什么,因为downcast
不能在这里使用.
还有什么办法吗?
UPDATE
为了遵循S的建议,我又写了一个铁 rust 函数:
pub fn test_function_alt(values: &[f64]) -> f64 {
values.iter().sum::<f64>()
}
我用PYO3包好了
#[pyfunction]
pub fn py_test_function_alt(values: Vec<f64>) -> f64 {
test_function_alt(&values)
}
然后,我编写了以下的python函数
import numpy as np
def foobar(cb):
vals = cb(np.arange(225000))
return pyo3test.py_test_function_alt(vals)
此函数仍在大约20毫秒内执行.