我试图在两个值列上透视一个表,在透视的数据帧上应用一些用户定义的函数,然后取消透视(堆栈在Pandas 中).在Pandas 身上,这应该是这样的:

import pandas as pd
import polars as pl
from polars import col, lit, selectors as cs

df = pl.DataFrame(
    {
        "foo": [1, 1, 2, 2, 3, 3],
        "bar": ["y", "x", "y", "x", "y", "x"],
        "baz1": [1, 2, None, 4, 5, None],
        "baz2": [1, None, 3, 4, None, 6]
    }
)
df
'''
┌─────┬─────┬──────┬──────┐
│ foo ┆ bar ┆ baz1 ┆ baz2 │
│ --- ┆ --- ┆ ---  ┆ ---  │
│ i64 ┆ str ┆ i64  ┆ i64  │
╞═════╪═════╪══════╪══════╡
│ 1   ┆ y   ┆ 1    ┆ 1    │
│ 1   ┆ x   ┆ 2    ┆ null │
│ 2   ┆ y   ┆ null ┆ 3    │
│ 2   ┆ x   ┆ 4    ┆ 4    │
│ 3   ┆ y   ┆ 5    ┆ null │
│ 3   ┆ x   ┆ null ┆ 6    │
└─────┴─────┴──────┴──────┘
'''

pd_df = df.to_pandas()

index_col = ['foo']
columns_col = ['bar']
values_col = ['baz1', 'baz2']

def pd_udf(df): # for example purposes, let's assume the function can be more complex
    return (
        df.ffill() * 3
    )

pd_res = (
    pd_df.groupby(index_col + columns_col).first() # or .set_index(index_col + column) for same result as no duplicates
    .unstack()
    .pipe(pd_udf)
    .stack()
    .reset_index()
    .sort_values(index_col + columns_col)
    .pipe(pl.from_pandas)
)
pd_res
'''
┌─────┬─────┬──────┬──────┐
│ foo ┆ bar ┆ baz1 ┆ baz2 │
│ --- ┆ --- ┆ ---  ┆ ---  │
│ i64 ┆ str ┆ f64  ┆ f64  │
╞═════╪═════╪══════╪══════╡
│ 1   ┆ x   ┆ 6.0  ┆ null │
│ 1   ┆ y   ┆ 3.0  ┆ 3.0  │
│ 2   ┆ x   ┆ 12.0 ┆ 12.0 │
│ 2   ┆ y   ┆ 3.0  ┆ 9.0  │
│ 3   ┆ x   ┆ 12.0 ┆ 18.0 │
│ 3   ┆ y   ┆ 15.0 ┆ 9.0  │
└─────┴─────┴──────┴──────┘
'''

我找到了两种方法来达到同样的结果,如果有更好的方法来达到我想要的结果,请让我知道.

  1. With pivot, then melt, some code and another pivot

这不是最好的解决方案,由于旋转后列的f'{value_name}_{column_name}_{column_value}格式,堆栈位是通过相当多的操作实现的.

def pl_udf(df):
    return (
        df.with_columns(
            pl.exclude(index_col).forward_fill() * lit(3)
        )
    )

lazy_df_1 = (
    df
    .pivot(values = values_col,
            index = index_col,
            columns = columns_col
            )
    .lazy()
    .pipe(pl_udf)
    
    # pd stack bit
    .melt(id_vars = index_col)
    .select(
        col(index_col+['value']),
        col('variable').str.split('_').list.get(0).alias('temp'),
        col('variable').str.split('_').list.get(2).alias(columns_col[0]),
    )
    .collect()
    .pivot(values = 'value',
          index = index_col+columns_col,
          columns = 'temp')
    .lazy()
    .sort(index_col+columns_col)
)
lazy_df_1.collect()
  1. With group_by and explode with a user-defined function applied to Series

应用于系列的UDF需要将系列转换为数据帧,以便能够利用数据帧方法.再说一次,我不确定这是最好的解决方案.

def pl_udf_series(s):
    '''
    to apply to series directly
    '''
    return (
        s.to_frame() # to use with dataframe functions (my udf will use dataframe functions)
        .select(col(s.name).forward_fill() * lit(3))
        .to_series().to_list()
    )

lazy_df_2 = (
    df.lazy()
    .group_by(columns_col)
    .agg(
        col(index_col),
        col(values_col).map_elements(lambda x: pl_udf_series(x))  
    )
    .explode(columns=index_col + values_col)
    .sort(index_col+columns_col)
    .select(col(index_col+columns_col+values_col)) # reordering
)
lazy_df_2.collect()

这两种实现都提供了预期的结果:

pd_res.equals(lazy_df_1.collect())
# True
pd_res.equals(lazy_df_2.collect())
# True

在性能方面:

  1. Pandas 解决方案~1000微秒
  2. 枢轴、融化、枢轴~380微秒
  3. GROUP_BY,爆炸~450微秒(我见过一些2比1快的情况)

推荐答案

您真的需要旋转/取消旋转您的数据帧吗?您可以使用over()方法在组内应用函数:

def func1(col): # for example purposes, let's assume the function can be more complex
    return col.forward_fill().over("bar") * lit(3)

res = df.with_columns(func1(pl.all().exclude(['foo', 'bar'])))
print(res.sort(['foo','bar']))

┌─────┬─────┬──────┬──────┐
│ foo ┆ bar ┆ baz1 ┆ baz2 │
│ --- ┆ --- ┆ ---  ┆ ---  │
│ i64 ┆ str ┆ i64  ┆ i64  │
╞═════╪═════╪══════╪══════╡
│ 1   ┆ x   ┆ 6    ┆ null │
│ 1   ┆ y   ┆ 3    ┆ 3    │
│ 2   ┆ x   ┆ 12   ┆ 12   │
│ 2   ┆ y   ┆ 3    ┆ 9    │
│ 3   ┆ x   ┆ 12   ┆ 18   │
│ 3   ┆ y   ┆ 15   ┆ 9    │
└─────┴─────┴──────┴──────┘

Python-3.x相关问答推荐

在numpy. linalg的qr之后使用scipy. integrate中的solve_ivp时出现了一个奇怪的错误

Pandas 插入的速度太慢了.对于跟踪代码,什么是更快的替代方案?

Strawberry FastAPI:如何调用正确的函数?

Django 模型类方法使用错误的 `self`

如何查找以开头并替换的字符串

如何在 20 秒后重复使用 Pillow 在现有图像上创建新图像?

如何计算Pandas 列中每列唯一项目的出现次数?

Jupyter Notebook 拒绝打印一些字符串

通过 requests 库调用 API 获取访问令牌

是否有与 Laravel 4 等效的 python?

Asyncio RuntimeError:事件循环已关闭

在 Python 3 中获取所有超类

AttributeError:系列对象没有属性iterrows

Python的max函数有多高效

如何区分文件之类的对象和文件路径之类的对象

如何避免使用我的 python 包构建 C 库?

有没有一种标准方法来确保 python 脚本将由 python2 而不是 python3 解释?

有效地判断一个元素是否在列表中至少出现 n 次

注册 Celery 基于类的任务

十六进制字符串到 Python 3.2 中的带符号整数?