Skip to content

Commit

Permalink
0.9.27 trader 对象新增 weight_backtest
Browse files Browse the repository at this point in the history
  • Loading branch information
zengbin93 committed Aug 15, 2023
1 parent a21d233 commit b772fa6
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 3 deletions.
39 changes: 38 additions & 1 deletion czsc/traders/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from datetime import datetime, timedelta
from deprecated import deprecated
from collections import OrderedDict
from typing import Callable, List, AnyStr, Union
from typing import Callable, List, AnyStr, Union, Optional
from pyecharts.charts import Tab
from pyecharts.components import Table
from pyecharts.options import ComponentTitleOpts
Expand Down Expand Up @@ -473,3 +473,40 @@ def take_snapshot(self, file_html=None, width: str = "1400px", height: str = "58
else:
return tab

def get_ensemble_weight(self, method: Optional[Union[AnyStr, Callable]] = None):
"""获取 CzscTrader 中所有 positions 按照 method 方法集成之后的权重
:param method: str or callable
集成方法,可选值包括:'mean', 'max', 'min', 'vote'
也可以传入自定义的函数,函数的输入为 dict,key 为 position.name,value 为 position.pos, 样例输入:
{'多头策略A': 1, '多头策略B': 1, '空头策略A': -1}
:param kwargs:
:return: pd.DataFrame
columns = ['dt', 'symbol', 'weight', 'price']
"""
from czsc.traders.weight_backtest import get_ensemble_weight
method = self.__ensemble_method if not method else method
return get_ensemble_weight(self, method)

def weight_backtest(self, **kwargs):
"""执行仓位集成权重的回测
:param kwargs:
- method: str or callable,集成方法,参考 get_ensemble_weight 方法
- digits: int,权重小数点后保留的位数,例如 2 表示保留两位小数
- fee_rate: float,手续费率,例如 0.0002 表示万二
- res_path: str,回测结果保存路径
:return: 回测结果
"""
from czsc.traders.weight_backtest import WeightBacktest

method = kwargs.get("method", self.__ensemble_method)
digits = kwargs.get("digits", 2)
fee_rate = kwargs.get("fee_rate", 0.0002)
res_path = kwargs.get("res_path", "./weight_backtest")
dfw = self.get_ensemble_weight(method)
wb = WeightBacktest(dfw, digits=digits, fee_rate=fee_rate, res_path=res_path)
_res = wb.backtest()
return _res
4 changes: 2 additions & 2 deletions czsc/traders/weight_backtest.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,12 @@
import numpy as np
import pandas as pd
import plotly.express as px
from czsc.traders.base import CzscTrader
from loguru import logger
from pathlib import Path
from typing import Union, AnyStr, Callable
from czsc.utils.stats import daily_performance, evaluate_pairs
from czsc.traders.base import CzscTrader
from czsc.utils.io import save_json
from czsc.utils.stats import daily_performance, evaluate_pairs


def get_ensemble_weight(trader: CzscTrader, method: Union[AnyStr, Callable] = 'mean'):
Expand Down
7 changes: 7 additions & 0 deletions test/test_trader_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,12 @@
email: [email protected]
create_dt: 2021/11/7 21:07
"""
import os
import shutil
import pandas as pd
from copy import deepcopy
from typing import List
from czsc.utils.cache import home_path
from czsc.traders.base import CzscSignals, BarGenerator, CzscTrader
from czsc.traders.sig_parse import get_signals_config, get_signals_freqs
from czsc.objects import Signal, Factor, Event, Operate, Position
Expand Down Expand Up @@ -252,6 +255,10 @@ def _weighted_ensemble(poss):
assert ct.get_ensemble_pos('vote') == 0
assert ct.get_ensemble_pos('max') == 0
assert ct.get_ensemble_pos('mean') == 0
dfw = ct.get_ensemble_weight(method='mean')
assert len(dfw) == len(bars_right)

res = ct.weight_backtest(method='mean', res_path=os.path.join(home_path, "test_trader"))

# 通过 on_bar 执行
ct1 = CzscTrader(deepcopy(bg), signals_config=signals_config,
Expand Down

0 comments on commit b772fa6

Please sign in to comment.