我认为这是一项简单的任务,但我在网上找不到解决方案.我有一个外部C++库,我在Python代码中使用它,向我返回ctypes.POINTER(ctypes.c_float).我想将这些指针的数组传递给jax.vmap函数.问题是jax不接受ctypes.POINTER(ctypes.c_float)类型.那么,我可以以某种方式将这个指针投射到普通的int吗?从技术上讲,这显然是可能的.但如何在Python中做到这一点呢?

这是一个例子:

lib = ctypes.cdll.LoadLibrary(lib_path)
lib.foo.argtypes = None
lib.foo.restype = ctypes.POINTER(ctypes.c_float)

bar = jax.vmap(lambda : dummy lib.foo())(jax.numpy.empty(16))

x = jax.numpy.empty(16, 256, 256, 1)
y = jax.vmap(lib.bar, in_axes = (0, 1))(x, bar)

所以,我想调用lib.foo 16次,这样我就有一个包含所有指针的数组bar.然后我想调用另一个库函数lib.bar,它期望bar和另一个(批量)参数x.

问题在于,收件箱声称ctypes.POINTER(ctypes.c_float)不是有效的收件箱类型.这就是为什么我认为解决方案是将指针投射到int并将这些int存储在bar中.

推荐答案

列表:

Here's a piece of code exemplifying how to handle pointers and their addresses. The trick is to use ctypes.addressof (documented in the 2nd URL).

code00.py:

#!/usr/bin/env python

import ctypes as cts
import sys


CType = cts.c_float
CTypePtr = cts.POINTER(CType)


def ctype_pointer(seq):  # Helper
    CTypeArr = (CType * len(seq))
    ctype_arr = CTypeArr(*seq)
    return cts.cast(ctype_arr, CTypePtr)


def pointer_elements(addr, count):  # Helper
    return tuple(CType.from_address(addr + i * cts.sizeof(CType)).value for i in range(count))


def main(*argv):
    seq = (2.718182, -3.141593, 1.618034, -0.618034, 0)
    ptr = ctype_pointer(seq)
    print(f"Pointer: {ptr}")
    print(f"\nPointer elements: {tuple(ptr[i] for i in range(len(seq)))}")  # Check if pointer has correct data
    ptr_addr = cts.addressof(ptr.contents)  # @TODO - cfati: Straightforward
    print(f"\nAddress: {ptr_addr} (0x{ptr_addr:016X})\nElements from address: {pointer_elements(ptr_addr, len(seq))}")
    ptr_addr0 = cts.cast(ptr, cts.c_void_p).value  # @TODO - cfati: Alternative
    print(f"\nAddresses match: {ptr_addr == ptr_addr0}")


if __name__ == "__main__":
    print(
        "Python {:s} {:03d}bit on {:s}\n".format(
            " ".join(elem.strip() for elem in sys.version.split("\n")),
            64 if sys.maxsize > 0x100000000 else 32,
            sys.platform,
        )
    )
    rc = main(*sys.argv[1:])
    print("\nDone.\n")
    sys.exit(rc)

Notes:

  • 尽管它增加了一点复杂性,但我引入了CType"层",以表明它应该适用于任何类型,而不仅仅是float(只要序列中的值是该类型)

  • 唯一真正相关的线是标有100的线

Output:

(py_pc064_03.08_test0_lancer) [cfati@cfati-5510-0:/mnt/e/Work/Dev/StackExchange/StackOverflow/q078366208]> python ./code00.py 
Python 3.8.19 (default, Apr  6 2024, 17:58:10) [GCC 11.4.0] 064bit on linux

Pointer: <__main__.LP_c_float object at 0x7203e97e7d40>

Pointer elements: (2.71818208694458, -3.1415929794311523, 1.6180340051651, -0.6180340051651001, 0.0)

Address: 125361127594576 (0x00007203E97A9A50)
Elements from address: (2.71818208694458, -3.1415929794311523, 1.6180340051651, -0.6180340051651001, 0.0)

Addresses match: True

Done.

Python相关问答推荐

覆盖Django rest响应,仅返回PK

如何使用矩阵在sklearn中同时对每个列执行matthews_corrcoef?

用gekko解决的ADE方程系统突然不再工作,错误消息异常:@错误:模型文件未找到.& &

Python中MongoDB的BSON时间戳

如何在图片中找到这个化学测试条?OpenCV精明边缘检测不会绘制边界框

如何将双框框列中的成对变成两个新列

抓取rotowire MLB球员新闻并使用Python形成表格

2D空间中的反旋算法

如何让这个星型模式在Python中只使用一个for循环?

使用密钥字典重新配置嵌套字典密钥名

计算每个IP的平均值

如何杀死一个进程,我的Python可执行文件以sudo启动?

合并与拼接并举

Gunicorn无法启动Flask应用,因为无法将应用解析为属性名或函数调用.'"'' "

polars:有效的方法来应用函数过滤列的字符串

使用SQLAlchemy从多线程Python应用程序在postgr中插入多行的最佳方法是什么?'

如何根据一定条件生成段id

比较两个有条件的数据帧并删除所有不合格的数据帧

Pandas:计数器的滚动和,复位

具有不匹配列的2D到3D广播