I have built a dictionary of regression models from the training dataset d, keyed by the value of group:

import numpy as np
import pandas as pd
from sklearn.linear_model import LinearRegression
from sklearn.pipeline import Pipeline

d = pd.DataFrame({
    "group":["cat","fish","horse","cat","fish","horse","cat","horse"],
    "x":[1,4,7,2,5,8,3,9],
    "y":[10,20,14,12,12,3,12,2],
    "z":[3,5,3,5,9,1,2,3]
})

features, models =['x','z'],{}
for animal in ['horse','cat','fish']:
    models[animal] = Pipeline([("estimator",LinearRegression(fit_intercept=True))])
    x,y = d.loc[d.group==animal,features],d.loc[d.group==animal,"y"]
    models[animal].fit(x,y)

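I also have a test dataset, test_d, whose groups only partly overlap with the trained models: some groups have a model, while others (such as dog) do not.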

test_d = pd.DataFrame({
    "group":["dog","fish","horse","dog","fish","horse","dog","horse"],
    "x":[1,2,3,4,5,6,7,8],
    "z":[3,5,3,5,9,1,2,3]
})

I want to use apply on the grouped test_d, using .name to look up the right model (if one exists), and have a function f() return the predictions:

def f(g):
    try:
        predictions = models[g.name].predict(g[features])
    except:
        predictions = [None]*len(g)
    return predictions

The function "works" in the sense that it returns the correct values:

grouping_column ="group"
test_d.groupby(grouping_column, group_keys=False).apply(f)

Output:

group
dog                           [None, None, None]
fish     [20.94117647058824, 12.000000000000004]
horse                          [38.0, 15.0, 8.0]
dtype: object

Question:

How should f() be written so that I can assign its values directly to test_d? I would like to do this:

test_d["predictions"] = test_d.groupby(grouping_column, group_keys=False).apply(f)

But this does not work; every row ends up as NaN:

   group  x  z predictions
0    dog  1  3         NaN
1   fish  2  5         NaN
2  horse  3  3         NaN
3    dog  4  5         NaN
4   fish  5  9         NaN
5  horse  6  1         NaN
6    dog  7  2         NaN
7  horse  8  3         NaN

Expected result:

   group  x  z  predictions
0    dog  1  3          NaN
1   fish  2  5    20.941176
2  horse  3  3    38.000000
3    dog  4  5          NaN
4   fish  5  9    12.000000
5  horse  6  1    15.000000
6    dog  7  2          NaN
7  horse  8  3     8.000000

Recommended answer

The function f should return a Series that carries the original row index of each group:

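def f(g):
    try:
        # g.name is the group key; look up the model trained for it
        predictions = models[g.name].predict(g[features])
    except KeyError:
        # no model was trained for this group
        predictions = [None]*len(g)
    return pd.Series(predictions, index=g.index)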

test_d.groupby('group', group_keys=False).apply(f)

Output:

0         None
3         None
6         None
1    20.941176
4         12.0
2         38.0
5         15.0
7          8.0
dtype: object
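To see why the index matters, here is a minimal sketch of how pandas aligns a Series on assignment. The toy frame df and the names by_group and by_row are made up for illustration and are not part of the question's data.

```python
import pandas as pd

# Toy frame with the default integer index 0, 1, 2 (illustrative data only)
df = pd.DataFrame({"group": ["dog", "fish", "dog"], "x": [1, 2, 3]})

# A Series indexed by group label, like the output of the original f():
# none of its labels ("dog", "fish") exist in df's index, so every row is NaN
by_group = pd.Series({"dog": [10, 30], "fish": [20]})
df["by_group"] = by_group

# A Series indexed by the original row labels (in any order):
# each value lands on the row with the matching label
by_row = pd.Series([10, 30, 20], index=[0, 2, 1])
df["by_row"] = by_row

print(df)
```

Column assignment reindexes the Series to the frame's index, which is why the group-keyed result produced only NaN while a row-keyed result fills in correctly.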

Because the returned Series keeps the original row index, the values align correctly on assignment:

test_d['predictions'] = test_d.groupby('group', group_keys=False).apply(f)

Output:

   group  x  z predictions
0    dog  1  3        None
1   fish  2  5   20.941176
2  horse  3  3        38.0
3    dog  4  5        None
4   fish  5  9        12.0
5  horse  6  1        15.0
6    dog  7  2        None
7  horse  8  3         8.0
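Note that the column above has object dtype with None, whereas the expected result in the question shows NaN in a float column. One possible alternative, sketched here rather than taken from the answer, is to pre-fill the column with NaN and loop over the groups, writing predictions with .loc. For brevity this sketch fits LinearRegression directly instead of wrapping it in a Pipeline; that simplification is an assumption of the example.

```python
import numpy as np
import pandas as pd
from sklearn.linear_model import LinearRegression

d = pd.DataFrame({
    "group": ["cat", "fish", "horse", "cat", "fish", "horse", "cat", "horse"],
    "x": [1, 4, 7, 2, 5, 8, 3, 9],
    "y": [10, 20, 14, 12, 12, 3, 12, 2],
    "z": [3, 5, 3, 5, 9, 1, 2, 3],
})
test_d = pd.DataFrame({
    "group": ["dog", "fish", "horse", "dog", "fish", "horse", "dog", "horse"],
    "x": [1, 2, 3, 4, 5, 6, 7, 8],
    "z": [3, 5, 3, 5, 9, 1, 2, 3],
})

features = ["x", "z"]

# One model per training group (plain LinearRegression, no Pipeline)
models = {
    name: LinearRegression().fit(g[features], g["y"])
    for name, g in d.groupby("group")
}

# Start with NaN everywhere so the column is float64 from the outset
test_d["predictions"] = np.nan

for name, g in test_d.groupby("group"):
    model = models.get(name)  # None when no model exists for this group
    if model is not None:
        # g.index holds the original row labels, so values land on the right rows
        test_d.loc[g.index, "predictions"] = model.predict(g[features])

print(test_d)
```

Using models.get avoids the try/except entirely, and groups without a model simply keep their NaN.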
