I have a flat dataframe representing data from multiple databases, where each database has multiple tables, each table has multiple columns, and each column has multiple values:
```python
import polars as pl

df = pl.DataFrame(
    {
        'db_id': ["db_1", "db_1", "db_1", "db_2", "db_2", "db_2"],
        'table_id': ['tab_1', 'tab_1', 'tab_2', 'tab_1', 'tab_2', 'tab_2'],
        'column_id': ['col_1', 'col_2', 'col_1', 'col_2', 'col_1', 'col_3'],
        'data': [[1, 2, 3], [10, 20, 30], [4, 5], [40, 50], [6], [60]]
    }
)
```
db_id | table_id | column_id | data |
---|---|---|---|
"db_1" | "tab_1" | "col_1" | 1, 2, 3 |
"db_1" | "tab_1" | "col_2" | 10, 20, 30 |
"db_1" | "tab_2" | "col_1" | 4, 5 |
"db_2" | "tab_1" | "col_2" | 40, 50 |
"db_2" | "tab_2" | "col_1" | 6 |
"db_2" | "tab_2" | "col_3" | 60 |
As you can see, different databases share some of the tables, and the tables share some of the columns.

From the dataframe above I want to extract one dataframe per `table_id`, where the extracted dataframe is transposed and exploded. That is, the extracted dataframe should have as its columns the set of `column_id`s corresponding to that particular `table_id` (plus `db_id`), with the corresponding values from `data` as the values. For the example above, the result should therefore be a dictionary with keys "tab_1" and "tab_2" and the following dataframes as values:
tab_1:
db_id | col_1 | col_2 |
---|---|---|
"db_1" | 1 | 10 |
"db_1" | 2 | 20 |
"db_1" | 3 | 30 |
"db_2" | null | 40 |
"db_2" | null | 50 |
tab_2:
db_id | col_1 | col_3 |
---|---|---|
"db_1" | 4 | null |
"db_1" | 5 | null |
"db_2" | 6 | 60 |
I have a working function that does this (see below), but it is somewhat slow, so I am wondering whether there is a faster way to achieve it.

Here is my current solution:
```python
from typing import Dict, Sequence


def dataframe_per_table(
    df: pl.DataFrame,
    col_name__table_id: str = "table_id",
    col_name__col_id: str = "column_id",
    col_name__values: str = "data",
    col_name__other_ids: Sequence[str] = ("db_id",),
) -> Dict[str, pl.DataFrame]:
    col_name__other_ids = list(col_name__other_ids)
    table_dfs = {}
    for (table_name, *_), table in df.groupby(
        [col_name__table_id] + col_name__other_ids
    ):
        # Pivot each (table_id, db_id, ...) group so that every column_id
        # becomes its own column holding the original list from `data`,
        # then explode those lists into individual rows.
        new_table = table.select(
            pl.col(col_name__other_ids + [col_name__col_id, col_name__values])
        ).pivot(
            index=col_name__other_ids,
            columns=col_name__col_id,
            values=col_name__values,
            aggregate_function=None,
        ).explode(
            columns=table[col_name__col_id].unique().to_list()
        )
        # Diagonal concat onto what has been accumulated for this table
        # so far, filling columns missing on either side with nulls.
        table_dfs[table_name] = pl.concat(
            [table_dfs.setdefault(table_name, pl.DataFrame()), new_table],
            how="diagonal",
        )
    return table_dfs
```
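For the example dataframe at the top of the post, calling the function should reproduce the two expected tables. A minimal usage sketch (the shape annotations are mine, and it assumes the older polars API used throughout the post, e.g. `groupby`, which newer releases rename to `group_by`):

```python
result = dataframe_per_table(df)

# Each group is pivoted into list columns and then exploded into rows,
# so the outputs should match the expected tables shown above.
print(result["tab_1"])  # 5 rows: db_id, col_1, col_2 (col_1 null for db_2)
print(result["tab_2"])  # 3 rows: db_id, col_1, col_3 (col_3 null for db_1)
```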
Update: Benchmarking/Summary of Answers
On a dataframe with roughly 2.5 million rows, my initial solution takes about 70 minutes to complete.
Disclaimer: since the execution times were too long, I only timed each solution once (i.e. 1 run, 1 loop), so the margin of error is large.
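A one-shot timing of this kind could be taken along the following lines (a minimal sketch and my own assumption; the actual benchmarking harness and the ~2.5M-row dataframe, here called `big_df`, are not shown in the post):

```python
import time

start = time.perf_counter()
result = dataframe_per_table(big_df)  # big_df: the ~2.5M-row dataframe (assumed name)
print(f"dataframe_per_table: {time.perf_counter() - start:.1f} s")  # 1 run, 1 loop
```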
However, after posting this question, I realized that I could speed it up by performing the `concat` in a separate loop, so that each final dataframe is created by a single `concat` operation instead of many:
```python
def dataframe_per_table_v2(
    df: pl.DataFrame,
    col_name__table_id: str = "table_id",
    col_name__col_id: str = "column_id",
    col_name__values: str = "data",
    col_name__other_ids: Sequence[str] = ("db_id",),
) -> Dict[str, pl.DataFrame]:
    col_name__other_ids = list(col_name__other_ids)
    table_dfs = {}
    for (table_name, *_), table in df.groupby(
        [col_name__table_id] + col_name__other_ids
    ):
        new_table = table.select(
            pl.col(col_name__other_ids + [col_name__col_id, col_name__values])
        ).pivot(
            index=col_name__other_ids,
            columns=col_name__col_id,
            values=col_name__values,
            aggregate_function=None,
        ).explode(
            columns=table[col_name__col_id].unique().to_list()
        )
        # Up until here nothing is changed.
        # Now, instead of directly concatenating, we just
        # append the new dataframe to a list.
        table_dfs.setdefault(table_name, list()).append(new_table)
    # Now, in a separate loop, each final dataframe is created
    # by concatenating all collected dataframes once.
    for table_name, table_sub_dfs in table_dfs.items():
        table_dfs[table_name] = pl.concat(
            table_sub_dfs,
            how="diagonal",
        )
    return table_dfs
```
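The gain comes from avoiding repeated copying: concatenating inside the loop re-copies everything accumulated so far on every iteration, whereas collecting the pieces in a list and concatenating once copies each piece a single time. The same pattern in isolation (a generic sketch of mine, not from the original post):

```python
import polars as pl

frames = [pl.DataFrame({"a": [i]}) for i in range(1000)]

# Slow: quadratic copying, as in the v1 function above.
acc = pl.DataFrame()
for frame in frames:
    acc = pl.concat([acc, frame], how="diagonal")

# Fast: each frame is copied once, as in v2.
acc = pl.concat(frames, how="diagonal")
```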
This brought the time down from 70 minutes to about 10 minutes; much better, but still too long.
In comparison, the answer by @jqurious took only about 5 minutes. It needs an extra step at the end to drop unneeded columns and to build a dictionary from a list, but it is still much faster.
However, the winner by far is the answer by @Dean MacGregor, which took only about 50 seconds and directly produces the expected output.

Here is their solution rewritten as a function:
```python
def dataframe_per_table_v3(
    df: pl.DataFrame,
    col_name__table_id: str = "table_id",
    col_name__col_id: str = "column_id",
    col_name__values: str = "data",
    col_name__other_ids: Sequence[str] = ("db_id",),
) -> Dict[str, pl.DataFrame]:
    table_dfs = {
        table_id: df.filter(
            pl.col(col_name__table_id) == table_id
        ).with_columns(
            # Number the elements within each list so that, after the
            # explode, values from different columns line up row by row.
            idx_data=pl.arange(0, pl.col(col_name__values).arr.lengths())
        ).explode(
            [col_name__values, 'idx_data']
        ).pivot(
            values=col_name__values,
            index=[*col_name__other_ids, 'idx_data'],
            columns=col_name__col_id,
            aggregate_function='first',
        ).drop(
            'idx_data'
        )
        for table_id in df.get_column(col_name__table_id).unique()
    }
    return table_dfs
```
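Used on the example dataframe, it should again reproduce the two expected tables. A minimal usage sketch (note that `.arr.lengths()` is the list-namespace spelling of the polars version used at the time; newer releases call it `.list.len()`):

```python
result = dataframe_per_table_v3(df)
print(result["tab_1"])  # db_id, col_1, col_2, as in the expected output above
print(result["tab_2"])  # db_id, col_1, col_3, as in the expected output above
```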