我想应用一个numba UDF,它为df中的每个组生成相同长度的向量:

import numba

df = pl.DataFrame(
    {
        "group": ["A", "A", "A", "B", "B"],
        "index": [1, 3, 5, 1, 4],
    }
)

@numba.jit(nopython=True)
def UDF(array: np.ndarray, threshold: int) -> np.ndarray:
    result = np.zeros(array.shape[0])
    accumulator = 0
    
    for i, value in enumerate(array):
        accumulator += value
        if accumulator >= threshold:
            result[i] = 1
            accumulator = 0
            
    return result

df.with_columns(
    pl.col("index")
    .map_batches(
        lambda x: UDF(x.to_numpy(), 5)
        )
    .over("group")
    .cast(pl.UInt8)
    .alias("udf")
    )

灵感来自this post,其中引入了multi-processing应用程序.然而,在上面的例子中,我使用了一个over窗口函数来应用自定义框架.是否有一个有效的方法来执行上述执行?

预期输出:

shape: (6, 3)
┌───────┬───────┬─────┐
│ group ┆ index ┆ udf │
│ ---   ┆ ---   ┆ --- │
│ str   ┆ i64   ┆ u8  │
╞═══════╪═══════╪═════╡
│ A     ┆ 1     ┆ 0   │
│ A     ┆ 3     ┆ 0   │
│ A     ┆ 5     ┆ 1   │
│ B     ┆ 1     ┆ 0   │
│ B     ┆ 4     ┆ 1   │
└───────┴───────┴─────┘

推荐答案

Here is example how you can do this with + using numba's parallelization features:

from numba import njit, prange


@njit(parallel=True)
def UDF_nb_parallel(array, n, threshold):
    result = np.zeros_like(array, dtype="uint8")

    for i in prange(array.size // n):
        accumulator = 0
        for j in range(i * n, (i + 1) * n):
            value = array[j]
            accumulator += value
            if accumulator >= threshold:
                result[j] = 1
                accumulator = 0

    return result

df = df.with_columns(
    pl.Series(name="new_udf", values=UDF_nb_parallel(df["index"].to_numpy(), 3, 5))
)
print(df)

打印:

shape: (9, 3)
┌───────┬───────┬─────────┐
│ group ┆ index ┆ new_udf │
│ ---   ┆ ---   ┆ ---     │
│ str   ┆ i64   ┆ u8      │
╞═══════╪═══════╪═════════╡
│ A     ┆ 1     ┆ 0       │
│ A     ┆ 3     ┆ 0       │
│ A     ┆ 5     ┆ 1       │
│ B     ┆ 1     ┆ 0       │
│ B     ┆ 4     ┆ 1       │
│ B     ┆ 8     ┆ 1       │
│ C     ┆ 1     ┆ 0       │
│ C     ┆ 1     ┆ 0       │
│ C     ┆ 4     ┆ 1       │
└───────┴───────┴─────────┘

基准:

from timeit import timeit

import numpy as np
import polars as pl
from numba import njit, prange


def get_df(N, n):
    assert N % n == 0

    df = pl.DataFrame(
        {
            "group": [f"group_{i}" for i in range(N // n) for _ in range(n)],
            "index": np.random.randint(1, 5, size=N, dtype="uint64"),
        }
    )
    return df


@njit
def UDF(array: np.ndarray, threshold: int) -> np.ndarray:
    result = np.zeros(array.shape[0])
    accumulator = 0

    for i, value in enumerate(array):
        accumulator += value
        if accumulator >= threshold:
            result[i] = 1
            accumulator = 0

    return result


@njit(parallel=True)
def UDF_nb_parallel(array, n, threshold):
    result = np.zeros_like(array, dtype="uint8")

    for i in prange(array.size // n):
        accumulator = 0
        for j in range(i * n, (i + 1) * n):
            value = array[j]
            accumulator += value
            if accumulator >= threshold:
                result[j] = 1
                accumulator = 0

    return result


def get_udf_polars(df):
    return df.with_columns(
        pl.col("index")
        .map_batches(lambda x: UDF(x.to_numpy(), 5))
        .over("group")
        .cast(pl.UInt8)
        .alias("udf")
    )


df = get_df(3 * 33_333, 3)  # 100_000 values, length of groups 3

df = get_udf_polars(df)

df = df.with_columns(
    pl.Series(name="new_udf", values=UDF_nb_parallel(df["index"].to_numpy(), 3, 5))
)

assert np.allclose(df["udf"].to_numpy(), df["new_udf"].to_numpy())


t1 = timeit("get_udf_polars(df)", number=1, globals=globals())
t2 = timeit(
    'df.with_columns(pl.Series(name="new_udf", values=UDF_nb_parallel(df["index"].to_numpy(), 3, 5)))',
    number=1,
    globals=globals(),
)

print(t1)
print(t2)

我的机器上的打印(AMD 5700x):

2.7000599699968006
0.00025866299984045327

0.06319052699836902_000_000行/组3需要0.06319052699836902行(parallel=False需要0.2159650030080229行)


编辑:处理可变长度组:

@njit(parallel=True)
def UDF_nb_parallel_2(array, indices, amount, threshold):
    result = np.zeros_like(array, dtype="uint8")

    for i in prange(indices.size):
        accumulator = 0
        for j in range(indices[i], indices[i] + amount[i]):
            value = array[j]
            accumulator += value
            if accumulator >= threshold:
                result[j] = 1
                accumulator = 0

    return result

def get_udf_polars_nb(df):
    n = df["group"].to_numpy()
    indices = np.unique(n, return_index=True)[1]
    amount = np.diff(np.r_[indices, [n.size]])
    return df.with_columns(
        pl.Series(
            name="new_udf",
            values=UDF_nb_parallel_2(df["index"].to_numpy(), indices, amount, 5),
        )
    )

df = get_udf_polars_nb(df)

基准:

import random
from timeit import timeit

import numpy as np
import polars as pl
from numba import njit, prange


def get_df(N):
    groups = []
    cnt, group_no, running = 0, 1, True
    while running:
        for _ in range(random.randint(3, 10)):
            groups.append(group_no)
            cnt += 1
            if cnt >= N:
                running = False
                break
        group_no += 1

    df = pl.DataFrame(
        {
            "group": groups,
            "index": np.random.randint(1, 5, size=N, dtype="uint64"),
        }
    )
    return df


@njit
def UDF(array: np.ndarray, threshold: int) -> np.ndarray:
    result = np.zeros(array.shape[0])
    accumulator = 0

    for i, value in enumerate(array):
        accumulator += value
        if accumulator >= threshold:
            result[i] = 1
            accumulator = 0

    return result


@njit(parallel=True)
def UDF_nb_parallel_2(array, indices, amount, threshold):
    result = np.zeros_like(array, dtype="uint8")

    for i in prange(indices.size):
        accumulator = 0
        for j in range(indices[i], indices[i] + amount[i]):
            value = array[j]
            accumulator += value
            if accumulator >= threshold:
                result[j] = 1
                accumulator = 0

    return result


def get_udf_polars(df):
    return df.with_columns(
        pl.col("index")
        .map_batches(lambda x: UDF(x.to_numpy(), 5))
        .over("group")
        .cast(pl.UInt8)
        .alias("udf")
    )


def get_udf_polars_nb(df):
    n = df["group"].to_numpy()
    indices = np.unique(n, return_index=True)[1]
    amount = np.diff(np.r_[indices, [n.size]])
    return df.with_columns(
        pl.Series(
            name="new_udf",
            values=UDF_nb_parallel_2(df["index"].to_numpy(), indices, amount, 5),
        )
    )


df = get_df(100_000)  # 100_000 values, length of groups length 3-9

df = get_udf_polars(df)
df = get_udf_polars_nb(df)

assert np.allclose(df["udf"].to_numpy(), df["new_udf"].to_numpy())


t1 = timeit("get_udf_polars(df)", number=1, globals=globals())
t2 = timeit("get_udf_polars_nb(df)", number=1, globals=globals())

print(t1)
print(t2)

打印:

1.2675148629932664
0.0024339070077985525

Python相关问答推荐

运行回文查找器代码时发生错误:[类型错误:builtin_index_or_system对象不可订阅]

为什么带有dropna=False的groupby会阻止后续的MultiIndex.dropna()工作?

2D空间中的反旋算法

如何列举Pandigital Prime Set

如何使用根据其他值相似的列从列表中获取的中间值填充空NaN数据

计算天数

Polars asof在下一个可用日期加入

如何防止Pandas将索引标为周期?

基于多个数组的多个条件将值添加到numpy数组

使用字典或列表的值组合

语法错误:文档. evaluate:表达式不是合法表达式

使用嵌套对象字段的Qdrant过滤

从嵌套极轴列的列表中删除元素

如何在Gekko中处理跨矢量优化

python3中np. divide(x,y)和x/y有什么区别?'

Pandas 删除只有一种类型的值的行,重复或不重复

函数()参数';代码';必须是代码而不是字符串

Python:在cmd中添加参数时的语法

用LAKEF划分实木地板AWS Wrangler

如何定义一个将类型与接收该类型的参数的可调用进行映射的字典?