我需要使用数据帧上的滑动窗口来计算一些度量.如果metric只需要一列,我会使用rolling.但有些人认为它无法与2+列一起工作.

def mean_squared_error(aa, bb):
    return np.sum((aa - bb) ** 2) / len(aa)

def rolling_metric(df_, col_a, col_b, window, metric_fn):
    result = []
    for i, id_ in enumerate(df_.index):
        if i < (df_.shape[0] - window + 1):
            slice_idx = df_.index[i: i+window-1]
            slice_a, slice_b = df_.loc[slice_idx, col_a], df_.loc[slice_idx, col_b]
            result.append(metric_fn(slice_a, slice_b))
        else:
            result.append(None)
    return pd.Series(data = result, index = df_.index)

df = pd.DataFrame(data=(np.random.rand(1000, 2)*10).round(2), columns = ['y_true', 'y_pred'] )

%time df2 = rolling_metric(df, 'y_true', 'y_pred', window=7, metric_fn=mean_squared_error)

仅1000行就需要将近1秒的时间.

请建议更快的矢量化方法来计算滑动窗口上的此类度量.

推荐答案

在这种情况下:

您可以预先计算平方误差,然后使用.Rolling.mean():

df['sq_error'] = (df['y_true'] - df['y_pred'])**2

%time df['sq_error'].rolling(6).mean().dropna()

请注意,在您的示例中,实际窗口大小是6(打印切片长度),这就是我在代码片段中将其设置为6的原因.

你甚至可以这样写:

%time df['y_true'].subtract(df['y_pred']).pow(2).rolling(6).mean().dropna()

一般来说:

如果不能将其缩减为一列,那么从pandas 1.3.0开始,可以使用method='table参数将函数应用于整个数据帧.然而,这有以下要求:

  • 这仅在使用numba发动机时实施.所以,你需要在apply中设置engine='numba'并安装它.
  • 您需要在apply中设置raw=True:这意味着在您的函数中,您将对numpy个数组而不是数据帧进行操作.这是前一点的结果.

因此,你的计算可以是这样的:

WIN_LEN = 6

def mean_sq_err_table(arr, min_window=WIN_LEN):
    if len(arr) < min_window:
        return np.nan
    else:
        return np.mean((arr[:, 0] - arr[:, 1])**2)
    
df.rolling(WIN_LEN, method='table').apply(mean_sq_err_table, engine='numba', raw=True).dropna()

因为它使用numba,这也是相对较快的.

Python相关问答推荐

根据二元组列表在pandas中创建新列

如何在Python脚本中附加一个Google tab(已经打开)

当独立的网络调用不应该互相阻塞时,'

如何在给定的条件下使numpy数组的计算速度最快?

在不同的帧B中判断帧A中的子字符串,每个帧的大小不同

在Python中从嵌套的for循环中获取插值

为什么Python内存中的列表大小与文档不匹配?

Odoo16:模板中使用的docs变量在哪里定义?

比Pandas 更好的 Select

如何反转一个框架中列的值?

极点替换值大于组内另一个极点数据帧的最大值

来自Airflow Connection的额外参数

将相应的值从第2列合并到第1列(Pandas )

具有不同坐标的tkinter canvs.cocords()和canvs.moveto()

某些值的数值幂和**之间的差异

基于2级列表的Pandas 切片3级多索引

为什么在安装了64位Python的64位Windows 10上以32位运行?

Python键盘模块不会立即检测到按键

为什么任何一个HTML页面在保存到文件后都会变大6个字节?

nameError_C未定义