I have a PySpark DataFrame that looks like this:
df = spark.createDataFrame(
    data=[
        (1, "GERMANY", "20230606", True),
        (2, "GERMANY", "20230620", False),
        (3, "GERMANY", "20230627", True),
        (4, "GERMANY", "20230705", True),
        (5, "GERMANY", "20230714", False),
        (6, "GERMANY", "20230715", True),
    ],
    schema=["ID", "COUNTRY", "DATE", "FLAG"],
)
df.show()
+---+-------+--------+-----+
| ID|COUNTRY| DATE| FLAG|
+---+-------+--------+-----+
| 1|GERMANY|20230606| true|
| 2|GERMANY|20230620|false|
| 3|GERMANY|20230627| true|
| 4|GERMANY|20230705| true|
| 5|GERMANY|20230714|false|
| 6|GERMANY|20230715| true|
+---+-------+--------+-----+
The DataFrame contains more countries. I want to create a new column COUNT_WITH_RESET following this logic:

- If FLAG=False, then COUNT_WITH_RESET=0.
- If FLAG=True, then COUNT_WITH_RESET should count the number of rows, ordered by DATE, since the last row with FLAG=False for that particular country.

This should be the output for the example above (I also sketch the same logic as a plain Python loop right after the table):
+---+-------+--------+-----+----------------+
| ID|COUNTRY| DATE| FLAG|COUNT_WITH_RESET|
+---+-------+--------+-----+----------------+
| 1|GERMANY|20230606| true| 1|
| 2|GERMANY|20230620|false| 0|
| 3|GERMANY|20230627| true| 1|
| 4|GERMANY|20230705| true| 2|
| 5|GERMANY|20230714|false| 0|
| 6|GERMANY|20230715| true| 1|
+---+-------+--------+-----+----------------+
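To pin down the semantics, here is the same computation as a plain Python loop over a single country's rows (purely illustrative; the data is the GERMANY slice from the example, not how I intend to implement it):

# One country's rows, already sorted by DATE, as (ID, FLAG) pairs.
rows = [(1, True), (2, False), (3, True), (4, True), (5, False), (6, True)]

count = 0
for row_id, flag in rows:
    # Reset on every FLAG=False row, otherwise keep counting.
    count = 0 if not flag else count + 1
    print(row_id, count)  # prints 1, 0, 1, 2, 0, 1 -> matches COUNT_WITH_RESET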
I tried putting row_number() over a window, but I could not get the count to reset. I also tried .rowsBetween(Window.unboundedPreceding, Window.currentRow). Here is my approach:
from pyspark.sql.window import Window
import pyspark.sql.functions as F

window_reset = Window.partitionBy("COUNTRY").orderBy("DATE")

df_with_reset = (
    df
    .withColumn(
        "COUNT_WITH_RESET",
        F.when(~F.col("FLAG"), 0).otherwise(F.row_number().over(window_reset)),
    )
)

df_with_reset.show()
+---+-------+--------+-----+----------------+
| ID|COUNTRY| DATE| FLAG|COUNT_WITH_RESET|
+---+-------+--------+-----+----------------+
| 1|GERMANY|20230606| true| 1|
| 2|GERMANY|20230620|false| 0|
| 3|GERMANY|20230627| true| 3|
| 4|GERMANY|20230705| true| 4|
| 5|GERMANY|20230714|false| 0|
| 6|GERMANY|20230715| true| 6|
+---+-------+--------+-----+----------------+
This is obviously wrong: my window only partitions by country, so row_number() keeps counting across the FLAG=False rows instead of resetting after them. Still, am I on the right track? Is there a specific built-in function in PySpark to achieve this, or do I need a UDF?
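One idea I have been toying with is to first derive a group ID as a running count of the FLAG=False rows seen so far, and then take row_number() within each (COUNTRY, group) partition. Below is a rough, untested sketch of what I mean; GRP and df_grouped are just names I made up:

from pyspark.sql.window import Window
import pyspark.sql.functions as F

w = Window.partitionBy("COUNTRY").orderBy("DATE")

# Count the FLAG=False rows strictly before the current row; every
# False row therefore opens a new group for the rows after it.
grp = F.sum((~F.col("FLAG")).cast("int")).over(
    w.rowsBetween(Window.unboundedPreceding, -1)
)

df_grouped = (
    df
    # The frame is empty on the first row of each country, so the sum
    # is null there; treat that as group 0.
    .withColumn("GRP", F.coalesce(grp, F.lit(0)))
    .withColumn(
        "COUNT_WITH_RESET",
        F.when(~F.col("FLAG"), 0).otherwise(
            F.row_number().over(
                Window.partitionBy("COUNTRY", "GRP").orderBy("DATE")
            )
        ),
    )
)

Is something along these lines the idiomatic way to do it, or is there a simpler approach? Any help would be appreciated.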