我有下面的代码,我想将Pandas UDF重写为纯窗口函数,用于速度优化

cumulative_pass列是我想要以编程方式创建的-

import pandas as pd
import pyspark.sql.functions as F
from pyspark.sql import SparkSession
from pyspark.sql import Window
import sys 

spark_session = SparkSession.builder.getOrCreate()
 

df_data = {'username': ['bob','bob', 'bob', 'bob', 'bob', 'bob', 'bob', 'bob'],
           'session': [1,2,3,4,5,6,7,8],
           'year_start': [2020,2020,2020,2020,2020,2021,2022,2023],
           'year_end': [2020,2020,2020,2020,2021,2021,2022,2023],
           'pass': [1,0,0,0,0,1,1,0],
           'cumulative_pass': [0,0,0,0,0,1,2,3],
          }
df_pandas = pd.DataFrame.from_dict(df_data)


df = spark_session.createDataFrame(df_pandas)
df.show()

最后的show个将是这个-

+--------+-------+----------+--------+----+---------------+
|username|session|year_start|year_end|pass|cumulative_pass|
+--------+-------+----------+--------+----+---------------+
|     bob|      1|      2020|    2020|   1|              0|
|     bob|      2|      2020|    2020|   0|              0|
|     bob|      3|      2020|    2020|   0|              0|
|     bob|      4|      2020|    2020|   0|              0|
|     bob|      5|      2020|    2021|   0|              0|
|     bob|      6|      2021|    2021|   1|              1|
|     bob|      7|      2022|    2022|   1|              2|
|     bob|      8|      2023|    2023|   0|              3|
+--------+-------+----------+--------+----+---------------+

下面的代码可以工作,但速度很慢(UDF速度很慢)

def conditional_sum(data: pd.DataFrame) -> int:
   df = data.apply(pd.Series)

    return df.loc[df['year_start'].max() > df['year_end']]['pass'].sum()

udf_conditional_sum = F.pandas_udf(conditional_sum, IntegerType())

w = Window.partitionBy("username").orderBy(F.asc("year_start")).rowsBetween(-sys.maxsize, 0)
df = df.withColumn("calculate_cumulative_pass", udf_conditional_sum(F.struct("year_start", "year_end", "pass")).over(w))

注意--我修改了w个,删除了第二个排序

推荐答案

Code

W = Window.partitionBy('username').orderBy('year_start')
df = (
    df
    .withColumn('cumulative_pass',  F.collect_list(F.struct('year_end', 'pass')).over(W))
    .withColumn('cumulative_pass',  F.expr("AGGREGATE(cumulative_pass, 0, (acc, x) -> CAST(acc + IF(x['year_end'] < year_start, x['pass'], 0) AS INT))"))
)

How this works

创建一个窗口规范,并收集前面所有行的year_endpass对值.当配对中的year_end小于当前行的year_start时,将配对和sum配对中的pass个值相加.

Result

+--------+-------+----------+--------+----+---------------+
|username|session|year_start|year_end|pass|cumulative_pass|
+--------+-------+----------+--------+----+---------------+
|bob     |1      |2020      |2020    |1   |0              |
|bob     |2      |2020      |2020    |0   |0              |
|bob     |3      |2020      |2020    |0   |0              |
|bob     |4      |2020      |2020    |0   |0              |
|bob     |5      |2020      |2021    |0   |0              |
|bob     |6      |2021      |2021    |1   |1              |
|bob     |7      |2022      |2022    |1   |2              |
|bob     |8      |2023      |2023    |0   |3              |
+--------+-------+----------+--------+----+---------------+

Python相关问答推荐

使用Curses for Python保存和恢复终端窗口内容

如何才能将每个组比上一组增加N %?

如何知道标志是否由用户传递或具有默认值?

了解shuffle在NP.random.Generator.choice()中的作用

ambda将时间戳与组内另一列的所有时间戳进行比较

使用polars .滤镜进行切片速度比pandas .loc慢

我从带有langchain的mongoDB中的vector serch获得一个空数组

点到面的Y距离

韦尔福德方差与Numpy方差不同

图像 pyramid .难以创建所需的合成图像

切片包括面具的第一个实例在内的眼镜的最佳方法是什么?

如何使用根据其他值相似的列从列表中获取的中间值填充空NaN数据

如何在WSL2中更新Python到最新版本(3.12.2)?

将输入聚合到统一词典中

连接一个rabrame和另一个1d rabrame不是问题,但当使用[...]'运算符会产生不同的结果

什么是最好的方法来切割一个相框到一个面具的第一个实例?

可以bcrypts AES—256 GCM加密损坏ZIP文件吗?

人口全部乱序 - Python—Matplotlib—映射

如何在Gekko中使用分层条件约束

在Python中控制列表中的数据步长