Skip to content

Commit

Permalink
0.9.38 新增 resample_to_daily 方法
Browse files Browse the repository at this point in the history
  • Loading branch information
zengbin93 committed Nov 30, 2023
1 parent a3c59ba commit c74c013
Show file tree
Hide file tree
Showing 4 changed files with 71 additions and 1 deletion.
1 change: 1 addition & 0 deletions czsc/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@
update_tbars,
update_nbars,
risk_free_returns,
resample_to_daily,

CrossSectionalPerformance,
cross_sectional_ranker,
Expand Down
2 changes: 1 addition & 1 deletion czsc/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from .sig import check_pressure_support, check_gap_info, is_bis_down, is_bis_up, get_sub_elements, is_symmetry_zs
from .sig import same_dir_counts, fast_slow_cross, count_last_same, create_single_signal
from .plotly_plot import KlineChart
from .trade import cal_trade_price, update_nbars, update_bbars, update_tbars, risk_free_returns
from .trade import cal_trade_price, update_nbars, update_bbars, update_tbars, risk_free_returns, resample_to_daily
from .cross import CrossSectionalPerformance, cross_sectional_ranker
from .stats import daily_performance, net_value_stats, subtract_fee
from .signal_analyzer import SignalAnalyzer, SignalPerformance
Expand Down
50 changes: 50 additions & 0 deletions czsc/utils/trade.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,3 +143,53 @@ def update_tbars(da: pd.DataFrame, event_col: str) -> None:
n_seq = [int(x.strip('nb')) for x in da.columns if x[0] == 'n' and x[-1] == 'b']
for n in n_seq:
da[f't{n}b'] = da[f'n{n}b'] * da[event_col]


def resample_to_daily(df: pd.DataFrame, sdt=None, edt=None, only_trade_date=True):
    """Expand lower-frequency (e.g. weekly/monthly) data to one snapshot per day.

    Steps:

    1. Copy ``df`` and convert its ``dt`` column to datetime. The caller's
       frame is left untouched.
    2. Use the smallest / largest ``dt`` as ``sdt`` / ``edt`` when they are
       not supplied.
    3. Build the target date sequence: trading dates from
       ``czsc.utils.calendar.get_trading_dates`` when ``only_trade_date`` is
       true, otherwise every calendar day between ``sdt`` and ``edt``.
    4. With ``pd.merge_asof`` (backward search) find, for each target date,
       the most recent ``dt`` in ``df`` that is not later than that date.
       Target dates with no such ``dt`` are dropped.
    5. For each remaining target date, copy all rows of the matched ``dt``
       and overwrite their ``dt`` with the target date.
    6. Concatenate the copies and return them.

    :param df: source data; must contain a ``dt`` column
    :param sdt: start date; defaults to the minimum ``dt`` in ``df``
    :param edt: end date; defaults to the maximum ``dt`` in ``df``
    :param only_trade_date: keep only trading dates when true, otherwise
        keep every calendar day
    :return: pd.DataFrame with the same columns as ``df``; empty when ``df``
        is empty or no target date has a preceding source row
    """
    # Work on a copy: converting `dt` in place would silently modify the
    # caller's DataFrame.
    df = df.copy()
    df['dt'] = pd.to_datetime(df['dt'])
    if df.empty:
        # Nothing to expand; min()/max() would be NaT and break date building.
        return df

    sdt = df['dt'].min() if not sdt else pd.to_datetime(sdt)
    edt = df['dt'].max() if not edt else pd.to_datetime(edt)

    # Build the target date sequence
    if only_trade_date:
        # Imported lazily: the trading calendar is only needed on this branch.
        from czsc.utils.calendar import get_trading_dates

        # NOTE(review): assumes get_trading_dates returns datetime-like values
        # compatible with the `dt` column — confirm against its implementation.
        trade_dates = get_trading_dates(sdt=sdt, edt=edt)
    else:
        trade_dates = pd.date_range(sdt, edt, freq='D').tolist()
    trade_dates = pd.DataFrame({'date': trade_dates})
    trade_dates = trade_dates.sort_values('date', ascending=True).reset_index(drop=True)

    # merge_asof requires BOTH sides to be sorted on their keys; `unique()`
    # keeps the input order, so sort explicitly to support unsorted `df`.
    vdt = pd.DataFrame({'dt': df['dt'].unique()})
    vdt = vdt.sort_values('dt', ascending=True).reset_index(drop=True)
    trade_dates = pd.merge_asof(trade_dates, vdt, left_on='date', right_on='dt')
    # Dates earlier than the first source `dt` have no match (NaT): drop them.
    trade_dates = trade_dates.dropna(subset=['dt']).reset_index(drop=True)

    dt_map = {dt: dfg for dt, dfg in df.groupby('dt')}
    results = []
    for row in trade_dates.to_dict('records'):
        # The copy is required: the same source group is reused for several
        # target dates, and overwriting `dt` on a shared object corrupts data.
        df_ = dt_map[row['dt']].copy()
        df_['dt'] = row['date']
        results.append(df_)

    if not results:
        # pd.concat raises on an empty list; return an empty frame that keeps
        # the original columns instead.
        return df.iloc[0:0].reset_index(drop=True)

    dfr = pd.concat(results, ignore_index=True)
    return dfr
19 changes: 19 additions & 0 deletions test/test_trade_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
describe: 测试交易价格计算
"""
import czsc
import pandas as pd
import numpy as np
from test.test_analyze import read_1min

Expand All @@ -18,3 +19,21 @@ def test_trade_price():
close = df['close'].iloc[1:21]
vol = df['vol'].iloc[1:21]
assert df['VWAP20'].iloc[0] == round(np.average(close, weights=vol), 3)


def test_make_it_daily():
    """Weekly rows fed to ``czsc.resample_to_daily`` must come back as a daily frame."""
    weekly_index = pd.date_range(start='2022-01-01', end='2022-02-28', freq='W')
    weekly = pd.DataFrame({'dt': weekly_index, 'value': np.random.random(len(weekly_index))})

    # Default call (only_trade_date left at its default): check the shape of the output
    daily = czsc.resample_to_daily(weekly)
    assert isinstance(daily, pd.DataFrame), "Result should be a DataFrame"
    assert 'dt' in daily.columns, "Result should have a 'dt' column"
    assert daily['dt'].dtype == 'datetime64[ns]', "'dt' column should be datetime64[ns] type"
    has_null_dt = daily['dt'].isnull().any()
    assert not has_null_dt, "'dt' column should not have any null values"

    # Calendar-day call: consecutive rows may be at most one day apart
    calendar_daily = czsc.resample_to_daily(weekly, only_trade_date=False)
    day_gaps = calendar_daily['dt'].diff().dt.days
    assert (day_gaps <= 1).iloc[1:].all(), "Result should have daily data"

0 comments on commit c74c013

Please sign in to comment.