Here is a Numba version that computes the OHLC columns the same way as your code, but significantly faster:
from numba import njit
@njit
def compute_ohlc(floor_15_min, O, H, L, C, O_out, H_out, L_out, C_out):
    """Fill bucket-to-date OHLC values into the ``*_out`` arrays in place.

    Rows are processed in order. A new bucket starts whenever the label in
    ``floor_15_min`` differs from the previous row's label. The rows of one
    bucket are therefore expected to be contiguous (e.g. a sorted time index).
    A label that reappears later is treated as a brand-new bucket.

    For each row ``i``:
      * ``O_out[i]`` is ``O`` at the first row of the current bucket,
      * ``H_out[i]`` is the max of ``H`` from the bucket start through ``i``,
      * ``L_out[i]`` is the min of ``L`` from the bucket start through ``i``,
      * ``C_out[i]`` is ``C[i]``.

    Parameters
    ----------
    floor_15_min : array of bucket labels, one per row (compared with ``!=``).
    O, H, L, C : input arrays with the same length as ``floor_15_min``.
    O_out, H_out, L_out, C_out : writable arrays of the same length; filled
        in place. Nothing is returned.
    """
    # Guard: the seeding reads below index row 0, which would be an
    # out-of-bounds read on empty arrays (Numba does not bounds-check by
    # default).
    if len(floor_15_min) == 0:
        return
    first, curr_max, curr_min, last = O[0], H[0], L[0], C[0]
    last_v = floor_15_min[0]
    for i, v in enumerate(floor_15_min):
        if v != last_v:
            # Bucket changed: restart all four aggregates from this row.
            first, curr_max, curr_min, last = O[i], H[i], L[i], C[i]
            last_v = v
        else:
            curr_max = max(curr_max, H[i])
            curr_min = min(curr_min, L[i])
            last = C[i]
        O_out[i] = first
        H_out[i] = curr_max
        L_out[i] = curr_min
        C_out[i] = last
def compute_numba(df):
    """Add bucket-to-date 15-minute OHLC columns to ``df`` in place.

    Adds ``15_min_floor_2`` (the index floored to 15 minutes). It also adds
    ``Open_15_2``, ``High_15_2``, ``Low_15_2`` and ``Close_15_2``, which
    ``compute_ohlc`` computes from the ``Open``/``High``/``Low``/``Close``
    columns. Returns None.

    Assumes ``df`` has a DatetimeIndex, since ``.floor`` is called on it.
    NOTE(review): ``compute_ohlc`` starts a new bucket whenever the label
    changes, so the index presumably needs to be sorted. Confirm with callers.
    """
    import numpy as np

    out_cols = ("Open_15_2", "High_15_2", "Low_15_2", "Close_15_2")
    df["15_min_floor_2"] = df.index.floor("15 min")
    # Compute into freshly allocated, writable arrays and assign them as
    # columns afterwards. Writing through ``df[col].values`` only reaches the
    # frame if pandas hands back a writable view. Under Copy-on-Write it is
    # read-only, so the Numba kernel would fail or the writes would be lost.
    outputs = [np.full(len(df), np.nan) for _ in out_cols]
    if len(df):  # compute_ohlc seeds its running state from row 0
        compute_ohlc(
            df["15_min_floor_2"].values,
            df["Open"].values,
            df["High"].values,
            df["Low"].values,
            df["Close"].values,
            *outputs,
        )
    for col, values in zip(out_cols, outputs):
        df[col] = values
# Example call: adds the *_15_2 columns to an existing DataFrame ``df``.
# ``df`` is assumed to be defined beforehand, with a DatetimeIndex and
# Open/High/Low/Close columns.
compute_numba(df=df)
在432001行随机DataFrame上的基准测试:
from timeit import timeit
import pandas as pd
from numba import njit
# Generate some random data: one row per second from 2023-01-01 00:00:00
# through 2023-01-06 00:00:00 inclusive (432001 rows). Each price column is
# drawn uniformly from [50, 150).
# numpy was not imported by this snippet; without it np.random raises NameError.
import numpy as np

np.random.seed(42)
idx = pd.date_range("1-1-2023", "1-6-2023", freq="1000ms")
# The columns are drawn in the same order as before (Open, High, Low, Close),
# so the seeded values are unchanged.
df = pd.DataFrame(
    {
        column: 50 + np.random.random(len(idx)) * 100
        for column in ("Open", "High", "Low", "Close")
    },
    index=idx,
)
def get_result_df(df):
    """Baseline pandas implementation of bucket-to-date 15-minute OHLC.

    Side effect: adds a ``15_min_floor`` column (the index floored to
    15 minutes) to ``df`` itself.

    Rows are grouped by that bucket, and a time-based rolling window of
    ``"15min"`` is applied within each group. Every row of a bucket lies
    within 15 minutes of the bucket start, so each window covers the bucket's
    rows up to the current one.

    Returns a new DataFrame: ``df``'s columns followed by ``Open_15``
    (first value), ``High_15`` (max), ``Low_15`` (min) and ``Close_15``
    (last value).
    """
    # Label every row with the start of its 15-minute bucket.
    df["15_min_floor"] = df.index.map(lambda ts: ts.floor("15min"))

    aggregations = {
        "Open": lambda window: window.iloc[0],
        "High": "max",
        "Low": "min",
        "Close": lambda window: window.iloc[-1],
    }
    windows = df.groupby("15_min_floor").rolling(window="15min")
    # Drop the group-key index level so the result aligns with df's index.
    running = windows.agg(aggregations).reset_index(level=0, drop=True)
    running = running.add_suffix("_15")

    return pd.concat([df, running], axis=1)
@njit
def compute_ohlc(floor_15_min, O, H, L, C, O_out, H_out, L_out, C_out):
    """Fill bucket-to-date OHLC values into the ``*_out`` arrays in place.

    Rows are processed in order. A new bucket starts whenever the label in
    ``floor_15_min`` differs from the previous row's label. The rows of one
    bucket are therefore expected to be contiguous (e.g. a sorted time index).
    A label that reappears later is treated as a brand-new bucket.

    For each row ``i``:
      * ``O_out[i]`` is ``O`` at the first row of the current bucket,
      * ``H_out[i]`` is the max of ``H`` from the bucket start through ``i``,
      * ``L_out[i]`` is the min of ``L`` from the bucket start through ``i``,
      * ``C_out[i]`` is ``C[i]``.

    Parameters
    ----------
    floor_15_min : array of bucket labels, one per row (compared with ``!=``).
    O, H, L, C : input arrays with the same length as ``floor_15_min``.
    O_out, H_out, L_out, C_out : writable arrays of the same length; filled
        in place. Nothing is returned.
    """
    # Guard: the seeding reads below index row 0, which would be an
    # out-of-bounds read on empty arrays (Numba does not bounds-check by
    # default).
    if len(floor_15_min) == 0:
        return
    first, curr_max, curr_min, last = O[0], H[0], L[0], C[0]
    last_v = floor_15_min[0]
    for i, v in enumerate(floor_15_min):
        if v != last_v:
            # Bucket changed: restart all four aggregates from this row.
            first, curr_max, curr_min, last = O[i], H[i], L[i], C[i]
            last_v = v
        else:
            curr_max = max(curr_max, H[i])
            curr_min = min(curr_min, L[i])
            last = C[i]
        O_out[i] = first
        H_out[i] = curr_max
        L_out[i] = curr_min
        C_out[i] = last
def compute_numba(df):
    """Add bucket-to-date 15-minute OHLC columns to ``df`` in place.

    Adds ``15_min_floor_2`` (the index floored to 15 minutes). It also adds
    ``Open_15_2``, ``High_15_2``, ``Low_15_2`` and ``Close_15_2``, which
    ``compute_ohlc`` computes from the ``Open``/``High``/``Low``/``Close``
    columns. Returns None.

    Assumes ``df`` has a DatetimeIndex, since ``.floor`` is called on it.
    NOTE(review): ``compute_ohlc`` starts a new bucket whenever the label
    changes, so the index presumably needs to be sorted. Confirm with callers.
    """
    import numpy as np

    out_cols = ("Open_15_2", "High_15_2", "Low_15_2", "Close_15_2")
    df["15_min_floor_2"] = df.index.floor("15 min")
    # Compute into freshly allocated, writable arrays and assign them as
    # columns afterwards. Writing through ``df[col].values`` only reaches the
    # frame if pandas hands back a writable view. Under Copy-on-Write it is
    # read-only, so the Numba kernel would fail or the writes would be lost.
    outputs = [np.full(len(df), np.nan) for _ in out_cols]
    if len(df):  # compute_ohlc seeds its running state from row 0
        compute_ohlc(
            df["15_min_floor_2"].values,
            df["Open"].values,
            df["High"].values,
            df["Low"].values,
            df["Close"].values,
            *outputs,
        )
    for col, values in zip(out_cols, outputs):
        df[col] = values
# Warm-up: the first call of the @njit kernel triggers Numba's one-off JIT
# compilation. Run it once on a tiny copy so the timed call below measures
# execution only.
compute_numba(df.iloc[:2].copy())

# Both functions add columns to the DataFrame they receive. Each run gets its
# own copy, made in ``setup`` so the copy is not part of the timed statement.
t1 = timeit("get_result_df(d)", setup="d = df.copy()", number=1, globals=globals())
t2 = timeit("compute_numba(d)", setup="d = df.copy()", number=1, globals=globals())
print(f"Time normal = {t1}")
print(f"Time numba = {t2}")
在我的电脑(AMD 5700X)上运行(432001行),输出如下:
Time normal = 29.57983471499756
Time numba = 0.2751060768496245
对于用 pd.date_range("1-1-2004", "1-1-2024", freq="1000ms") 生成的数据帧(约6.31亿行),结果是:
Time numba = 11.551695882808417