I have the code below, and I want to rewrite the Pandas UDF as a pure window function for speed. The cumulative_pass column is the one I want to create programmatically -
import pandas as pd
import pyspark.sql.functions as F
from pyspark.sql import SparkSession
from pyspark.sql import Window
from pyspark.sql.types import IntegerType
spark_session = SparkSession.builder.getOrCreate()
df_data = {'username': ['bob', 'bob', 'bob', 'bob', 'bob', 'bob', 'bob', 'bob'],
           'session': [1, 2, 3, 4, 5, 6, 7, 8],
           'year_start': [2020, 2020, 2020, 2020, 2020, 2021, 2022, 2023],
           'year_end': [2020, 2020, 2020, 2020, 2021, 2021, 2022, 2023],
           'pass': [1, 0, 0, 0, 0, 1, 1, 0],
           'cumulative_pass': [0, 0, 0, 0, 0, 1, 2, 3],
           }
df_pandas = pd.DataFrame.from_dict(df_data)
df = spark_session.createDataFrame(df_pandas)
df.show()
The final show produces this -
+--------+-------+----------+--------+----+---------------+
|username|session|year_start|year_end|pass|cumulative_pass|
+--------+-------+----------+--------+----+---------------+
| bob| 1| 2020| 2020| 1| 0|
| bob| 2| 2020| 2020| 0| 0|
| bob| 3| 2020| 2020| 0| 0|
| bob| 4| 2020| 2020| 0| 0|
| bob| 5| 2020| 2021| 0| 0|
| bob| 6| 2021| 2021| 1| 1|
| bob| 7| 2022| 2022| 1| 2|
| bob| 8| 2023| 2023| 0| 3|
+--------+-------+----------+--------+----+---------------+
The code below works, but it is slow (the pandas UDF is slow) -
# Grouped-aggregate pandas UDF: receives the window frame as a Series of
# structs and returns one value for the current row.
def conditional_sum(data: pd.Series) -> int:
    df = data.apply(pd.Series)  # expand the structs into year_start/year_end/pass columns
    # Sum passes from sessions that ended strictly before the latest year_start in the frame.
    return df.loc[df['year_start'].max() > df['year_end']]['pass'].sum()

udf_conditional_sum = F.pandas_udf(conditional_sum, IntegerType())

# All rows up to and including the current one, per user, ordered by year_start.
w = Window.partitionBy("username").orderBy(F.asc("year_start")).rowsBetween(Window.unboundedPreceding, Window.currentRow)
df = df.withColumn("calculate_cumulative_pass", udf_conditional_sum(F.struct("year_start", "year_end", "pass")).over(w))
Note -- I modified w above, removing the second ordering.
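
For reference, here is one possible window-only rewrite, as a minimal sketch rather than a tested drop-in replacement. Two observations drive it: since the frame is ordered by year_start ascending, F.max('year_start') over the frame is just the current row's year_start, so the UDF's condition reduces to "sum pass over sessions whose year_end is strictly less than this row's year_start"; and, assuming year_end >= year_start holds for every session (true in the sample data), that sum can be computed with one running total over a union of "end" and "query" events. The names ends, queries, events, year, pass_val, is_end, and running_pass below are all illustrative, not part of the original code.

import pyspark.sql.functions as F
from pyspark.sql import Window

# Every session contributes an "end" event (its pass value at year_end) and a
# "query" event (a request for the running total as of year_start).
ends = df.select(
    "username", "session",
    F.col("year_end").alias("year"),
    F.col("pass").cast("long").alias("pass_val"),
    F.lit(1).alias("is_end"),
)
queries = df.select(
    "username", "session",
    F.col("year_start").alias("year"),
    F.lit(0).cast("long").alias("pass_val"),
    F.lit(0).alias("is_end"),
)
events = ends.unionByName(queries)

# Order events by year; on ties, queries (is_end = 0) sort before ends, so a
# pass ending in the same year as a query is excluded (the strict > in the UDF).
w_events = (Window.partitionBy("username")
            .orderBy("year", "is_end")
            .rowsBetween(Window.unboundedPreceding, Window.currentRow))

cumulative = (events
              .withColumn("running_pass", F.sum("pass_val").over(w_events))
              .filter(F.col("is_end") == 0)
              .select("username", "session",
                      F.col("running_pass").alias("calculate_cumulative_pass")))

df = df.join(cumulative, on=["username", "session"])

On the sample data this reproduces the cumulative_pass column exactly; the trade-off is one union and one join in exchange for removing Python serialization from the hot path entirely.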