From b772fa614fccabdd075962815bb60f502b6dfe34 Mon Sep 17 00:00:00 2001 From: zengbin93 Date: Tue, 15 Aug 2023 15:47:56 +0800 Subject: [PATCH] =?UTF-8?q?0.9.27=20trader=20=E5=AF=B9=E8=B1=A1=E6=96=B0?= =?UTF-8?q?=E5=A2=9E=20weight=5Fbacktest?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- czsc/traders/base.py | 39 ++++++++++++++++++++++++++++++++- czsc/traders/weight_backtest.py | 4 ++-- test/test_trader_base.py | 7 ++++++ 3 files changed, 47 insertions(+), 3 deletions(-) diff --git a/czsc/traders/base.py b/czsc/traders/base.py index 12ff74106..f4264c1d5 100644 --- a/czsc/traders/base.py +++ b/czsc/traders/base.py @@ -14,7 +14,7 @@ from datetime import datetime, timedelta from deprecated import deprecated from collections import OrderedDict -from typing import Callable, List, AnyStr, Union +from typing import Callable, List, AnyStr, Union, Optional from pyecharts.charts import Tab from pyecharts.components import Table from pyecharts.options import ComponentTitleOpts @@ -473,3 +473,40 @@ def take_snapshot(self, file_html=None, width: str = "1400px", height: str = "58 else: return tab + def get_ensemble_weight(self, method: Optional[Union[AnyStr, Callable]] = None): + """获取 CzscTrader 中所有 positions 按照 method 方法集成之后的权重 + + :param method: str or callable + 集成方法,可选值包括:'mean', 'max', 'min', 'vote' + 也可以传入自定义的函数,函数的输入为 dict,key 为 position.name,value 为 position.pos, 样例输入: + {'多头策略A': 1, '多头策略B': 1, '空头策略A': -1} + :param kwargs: + :return: pd.DataFrame + columns = ['dt', 'symbol', 'weight', 'price'] + """ + from czsc.traders.weight_backtest import get_ensemble_weight + method = self.__ensemble_method if not method else method + return get_ensemble_weight(self, method) + + def weight_backtest(self, **kwargs): + """执行仓位集成权重的回测 + + :param kwargs: + + - method: str or callable,集成方法,参考 get_ensemble_weight 方法 + - digits: int,权重小数点后保留的位数,例如 2 表示保留两位小数 + - fee_rate: float,手续费率,例如 0.0002 表示万二 + - res_path: str,回测结果保存路径 + + :return: 回测结果 + """ + from czsc.traders.weight_backtest import WeightBacktest + + method = kwargs.get("method", self.__ensemble_method) + digits = kwargs.get("digits", 2) + fee_rate = kwargs.get("fee_rate", 0.0002) + res_path = kwargs.get("res_path", "./weight_backtest") + dfw = self.get_ensemble_weight(method) + wb = WeightBacktest(dfw, digits=digits, fee_rate=fee_rate, res_path=res_path) + _res = wb.backtest() + return _res diff --git a/czsc/traders/weight_backtest.py b/czsc/traders/weight_backtest.py index df5ead171..a99ee04b0 100644 --- a/czsc/traders/weight_backtest.py +++ b/czsc/traders/weight_backtest.py @@ -8,12 +8,12 @@ import numpy as np import pandas as pd import plotly.express as px -from czsc.traders.base import CzscTrader from loguru import logger from pathlib import Path from typing import Union, AnyStr, Callable -from czsc.utils.stats import daily_performance, evaluate_pairs +from czsc.traders.base import CzscTrader from czsc.utils.io import save_json +from czsc.utils.stats import daily_performance, evaluate_pairs def get_ensemble_weight(trader: CzscTrader, method: Union[AnyStr, Callable] = 'mean'): diff --git a/test/test_trader_base.py b/test/test_trader_base.py index 8a16653db..b9a7c843e 100644 --- a/test/test_trader_base.py +++ b/test/test_trader_base.py @@ -4,9 +4,12 @@ email: zeng_bin8888@163.com create_dt: 2021/11/7 21:07 """ +import os +import shutil import pandas as pd from copy import deepcopy from typing import List +from czsc.utils.cache import home_path from czsc.traders.base import CzscSignals, BarGenerator, CzscTrader from czsc.traders.sig_parse import get_signals_config, get_signals_freqs from czsc.objects import Signal, Factor, Event, Operate, Position @@ -252,6 +255,10 @@ def _weighted_ensemble(poss): assert ct.get_ensemble_pos('vote') == 0 assert ct.get_ensemble_pos('max') == 0 assert ct.get_ensemble_pos('mean') == 0 + dfw = ct.get_ensemble_weight(method='mean') + assert len(dfw) == len(bars_right) + + res = ct.weight_backtest(method='mean', res_path=os.path.join(home_path, "test_trader")) # 通过 on_bar 执行 ct1 = CzscTrader(deepcopy(bg), signals_config=signals_config,