示例代码.第cumulative_pass列是我想要以编程方式创建的-

import pandas as pd
import pyspark.sql.functions as F
from pyspark.sql import SparkSession
from pyspark.sql import Window
import sys 

spark_session = SparkSession.builder.getOrCreate()
 

df_data = {'username': ['bob','bob', 'bob', 'bob', 'bob', 'bob', 'bob', 'bob'],
           'session': [1,2,3,4,5,6,7,8],
           'year_start': [2020,2020,2020,2020,2020,2021,2022,2023],
           'year_end': [2020,2020,2020,2020,2021,2021,2022,2023],
           'pass': [1,0,0,0,0,1,1,0],
           'cumulative_pass': [0,0,0,0,0,1,2,3],
          }
df_pandas = pd.DataFrame.from_dict(df_data)


df = spark_session.createDataFrame(df_pandas)
df.show()

最后的show个将是这个-

+--------+-------+----------+--------+----+---------------+
|username|session|year_start|year_end|pass|cumulative_pass|
+--------+-------+----------+--------+----+---------------+
|     bob|      1|      2020|    2020|   1|              0|
|     bob|      2|      2020|    2020|   0|              0|
|     bob|      3|      2020|    2020|   0|              0|
|     bob|      4|      2020|    2020|   0|              0|
|     bob|      5|      2020|    2021|   0|              0|
|     bob|      6|      2021|    2021|   1|              1|
|     bob|      7|      2022|    2022|   1|              2|
|     bob|      8|      2023|    2023|   0|              3|
+--------+-------+----------+--------+----+---------------+

当当前行的year_start是先前行的year_end时,cumulative_pass列将对所有先前行的pass列求和

我的try (由于语法原因不起作用)-

def conditional_sum(data: pd.DataFrame) -> int:
   # df = data.apply(pd.Series)  # transform dict into separate columns

    return df.loc[df['year_start'].max() > df['year_end']]['pass'].sum()

udf_conditional_sum = F.pandas_udf(conditional_sum, IntegerType())

w = Window.partitionBy("username").orderBy(F.asc("year_start"), F.desc("year_end")).rowsBetween(-sys.maxsize, -1)
df = df.withColumn("calculate_cumulative_pass", udf_conditional_sum(F.struct("year_start", "year_end", "pass")).over(w))

基于https://stackoverflow.com/a/73278159/5004050的代码

推荐答案

对数据帧执行自合并,使左侧数据帧中的year_start大于右侧数据帧中的year_end,然后将生成的数据帧按左侧数据帧中的列进行分组,并将AGG passSUM进行分组,以获得所需的累积和.

df.createOrReplaceTempView('df')
df1 = spark.sql(
"""
SELECT
    A.username, A.session, 
    A.year_start, A.year_end, A.pass, 
    COALESCE(SUM(B.pass), 0) AS cummulative_sum
FROM
    df AS A
LEFT JOIN
    df AS B
ON
    A.year_start > B.year_end
GROUP BY
    A.username, A.session, 
    A.year_start, A.year_end, A.pass
""")

df1.show()
+--------+-------+----------+--------+----+---------------+
|username|session|year_start|year_end|pass|cummulative_sum|
+--------+-------+----------+--------+----+---------------+
|     bob|      1|      2020|    2020|   1|              0|
|     bob|      2|      2020|    2020|   0|              0|
|     bob|      3|      2020|    2020|   0|              0|
|     bob|      4|      2020|    2020|   0|              0|
|     bob|      5|      2020|    2021|   0|              0|
|     bob|      6|      2021|    2021|   1|              1|
|     bob|      7|      2022|    2022|   1|              2|
|     bob|      8|      2023|    2023|   0|              3|
+--------+-------+----------+--------+----+---------------+

Python相关问答推荐

配置Sweetviz以分析对象类型列,而无需转换

滚动和,句号来自Pandas列

将jit与numpy linSpace函数一起使用时出错

删除最后一个pip安装的包

使用索引列表列表对列进行切片并获取行方向的向量长度

删除所有列值,但判断是否存在任何二元组

如何让Flask 中的请求标签发挥作用

海上重叠直方图

Pandas—在数据透视表中占总数的百分比

在ubuntu上安装dlib时出错

如果满足某些条件,则用另一个数据帧列中的值填充空数据帧或数组

如何在Python中找到线性依赖mod 2

SQLAlchemy bindparam在mssql上失败(但在mysql上工作)

Plotly Dash Creating Interactive Graph下拉列表

从源代码显示不同的输出(机器学习)(Python)

如何将一组组合框重置回无 Select tkinter?

如何设置nan值为numpy数组多条件

如何防止html代码出现在quarto gfm报告中的pandas表之上

为什么在生成时间序列时,元组索引会超出范围?

Django查询集-排除True值