One idea is to use numba via the parameter engine='numba' in Rolling.apply for faster computation:
(tmp.rolling(window=3, min_periods=1)
.apply(lambda x: x[~np.isnan(x)][-2:].mean(), engine='numba', raw=True))
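To make the applied function's behavior concrete, here is a minimal sketch on a small hypothetical Series (made-up data, not from the question): each window keeps its last two non-NaN values and averages them.

```python
import numpy as np
import pandas as pd

# Hypothetical sample data with an interspersed NaN
s = pd.Series([1.0, np.nan, 2.0, 3.0])

# Per window: drop NaNs, keep the last two values, take their mean
out = s.rolling(window=3, min_periods=1).apply(
    lambda x: x[~np.isnan(x)][-2:].mean(), raw=True)

print(out.tolist())  # -> [1.0, 1.0, 1.5, 2.5]
```

Passing raw=True hands the applied function a plain ndarray instead of a Series, which is also what allows the numba engine to compile it.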
Test performance:
tmp = pd.concat([tmp] * 100000, ignore_index=True)
In [88]: %timeit tmp.rolling(window=3, min_periods=1).apply(lambda x: x[~np.isnan(x)][-2:].mean(),engine='numba', raw=True)
901 ms ± 6.76 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
In [89]: %timeit tmp.rolling(window=3, min_periods=1).apply(lambda x: x[~np.isnan(x)][-2:].mean(), raw=True)
13 s ± 181 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
Numpy approach:
You can convert the DataFrame to a 3d array of windows with leading NaN values prepended, then shift the non-NaN values to the end of each window and take the mean:
# https://stackoverflow.com/a/44559180/2901002
def justify_nd(a, invalid_val, axis, side):
    """
    Justify ndarray for the valid elements (that are not invalid_val).

    Parameters
    ----------
    a : ndarray
        Input array to be justified
    invalid_val : scalar
        Invalid value
    axis : int
        Axis along which justification is to be made
    side : str
        Direction of justification. Must be 'front' or 'end'.
        So, with 'front', valid elements are pushed to the front and
        with 'end' valid elements are pushed to the end along specified axis.
    """
    pushax = lambda a: np.moveaxis(a, axis, -1)
    if invalid_val is np.nan:
        mask = ~np.isnan(a)
    else:
        mask = a != invalid_val

    justified_mask = np.sort(mask, axis=axis)
    if side == 'front':
        justified_mask = np.flip(justified_mask, axis=axis)

    out = np.full(a.shape, invalid_val)
    if (axis == -1) or (axis == a.ndim - 1):
        out[justified_mask] = a[mask]
    else:
        pushax(out)[pushax(justified_mask)] = pushax(a)[pushax(mask)]
    return out
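The core trick in justify_nd is that sorting a boolean validity mask along an axis pushes the True (valid) slots to the end; assigning the valid values into those slots in row-major order preserves their per-row ordering. A minimal 2D sketch with made-up data:

```python
import numpy as np

# Hypothetical 2x3 array with NaN gaps
a = np.array([[1.0, np.nan, 2.0],
              [np.nan, 3.0, np.nan]])

mask = ~np.isnan(a)               # validity mask
justified = np.sort(mask, axis=1) # False sorts before True -> valid slots at the end
out = np.full(a.shape, np.nan)
out[justified] = a[mask]          # row-major fill keeps per-row value order

print(out)
# [[nan  1.  2.]
#  [nan nan  3.]]
```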
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view as swv

window_size = 3
N = 2
a = tmp.astype(float).to_numpy()
arr = np.vstack([np.full((window_size - 1, a.shape[1]), np.nan), a])

out = np.nanmean(justify_nd(swv(arr, window_size, axis=0),
                            invalid_val=np.nan, axis=2, side='end')[:, :, -N:],
                 axis=2)
print(out)
[[nan nan nan]
[nan nan nan]
[1. 1. 1. ]
[1.5 1.5 1.5]
[2.5 2.5 2.5]
[2.5 2.5 2.5]
[3. 3. 3. ]
[nan nan nan]]
df = pd.DataFrame(out, index=tmp.index, columns=tmp.columns)
print(df)
Name A B C
Date
11.1 NaN NaN NaN
12.1 NaN NaN NaN
13.1 1.0 1.0 1.0
14.1 1.5 1.5 1.5
15.1 2.5 2.5 2.5
16.1 2.5 2.5 2.5
17.1 3.0 3.0 3.0
18.1 NaN NaN NaN
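To see how sliding_window_view shapes the data, here is a small sketch with a hypothetical 4x2 array: after padding with window_size-1 leading NaN rows, swv yields one length-3 window per original row and column, with the window values on the last axis (which is why the justification and nanmean above use axis=2).

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view as swv

# Hypothetical 4-row, 2-column data standing in for `tmp`'s values
a = np.array([[1.0, 10.0],
              [2.0, 20.0],
              [3.0, 30.0],
              [4.0, 40.0]])
window_size = 3

# Prepend window_size-1 NaN rows so the first row already has a full window
arr = np.vstack([np.full((window_size - 1, a.shape[1]), np.nan), a])

w = swv(arr, window_size, axis=0)
print(w.shape)    # (4, 2, 3): rows x columns x window values
print(w[0, 0])    # first window of column 0: [nan nan 1.]
```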
Performance:
tmp = pd.concat([tmp] * 100000, ignore_index=True)
In [99]: %%timeit
    ...: a = tmp.astype(float).to_numpy()
    ...: arr = np.vstack([np.full((window_size-1, a.shape[1]), np.nan), a])
    ...:
    ...: out = np.nanmean(justify_nd(swv(arr, window_size, axis=0),
    ...:                             invalid_val=np.nan,
    ...:                             axis=2, side='end')[:, :, -N:], axis=2)
    ...:
    ...: df = pd.DataFrame(out, index=tmp.index, columns=tmp.columns)
    ...:
338 ms ± 4.61 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)