From 06420dba01c9bcb6e6600c1833a5a973e6cd9118 Mon Sep 17 00:00:00 2001 From: zengbin93 Date: Sat, 4 Nov 2023 17:33:48 +0800 Subject: [PATCH] =?UTF-8?q?V0.9.34=20=E6=9B=B4=E6=96=B0=E4=B8=80=E6=89=B9?= =?UTF-8?q?=E4=BB=A3=E7=A0=81=20(#175)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * V0.9.34 first commit * 0.9.33 fix bug * 0.9.34 新增 data client * 0.9.34 update data_client * 0.9.34 启用 AliyunOSS * 0.9.34 fix weight backtest * 0.9.43 新增 feture_cross_layering * 0.9.34 新增 show_weight_backtest --- .github/workflows/pythonpackage.yml | 2 +- czsc/__init__.py | 11 ++- czsc/traders/weight_backtest.py | 6 +- czsc/utils/__init__.py | 2 + czsc/utils/data_client.py | 98 +++++++++++++++++++++++ czsc/utils/features.py | 40 +++++++++ czsc/utils/st_components.py | 47 +++++++++++ czsc/utils/trade.py | 4 +- docs/requirements.txt | 3 +- examples/test_offline/test_data_client.py | 17 ++++ requirements.txt | 3 +- 11 files changed, 225 insertions(+), 8 deletions(-) create mode 100644 czsc/utils/data_client.py create mode 100644 examples/test_offline/test_data_client.py diff --git a/.github/workflows/pythonpackage.yml b/.github/workflows/pythonpackage.yml index eec2601e7..3d8ac16e8 100644 --- a/.github/workflows/pythonpackage.yml +++ b/.github/workflows/pythonpackage.yml @@ -5,7 +5,7 @@ name: Python package on: push: - branches: [ master, V0.9.33 ] + branches: [ master, V0.9.34 ] pull_request: branches: [ master ] diff --git a/czsc/__init__.py b/czsc/__init__.py index 355aa2e30..2d9054dbc 100644 --- a/czsc/__init__.py +++ b/czsc/__init__.py @@ -73,6 +73,11 @@ empty_cache_path, print_df_sample, index_composition, + + AliyunOSS, + DataClient, + set_url_token, + get_url_token, ) # 交易日历工具 @@ -91,6 +96,7 @@ show_factor_returns, show_factor_layering, show_symbol_factor_layering, + show_weight_backtest, ) from czsc.utils.bi_info import ( @@ -101,12 +107,13 @@ from czsc.utils.features import ( normalize_feature, normalize_ts_feature, + feture_cross_layering, ) -__version__ = "0.9.33" +__version__ = "0.9.34" __author__ = "zengbin93" __email__ = "zeng_bin8888@163.com" -__date__ = "20231018" +__date__ = "20231022" diff --git a/czsc/traders/weight_backtest.py b/czsc/traders/weight_backtest.py index 7aaeefa09..7b6e0087a 100644 --- a/czsc/traders/weight_backtest.py +++ b/czsc/traders/weight_backtest.py @@ -137,7 +137,7 @@ class WeightBacktest: 飞书文档:https://s0cqcxuy3p.feishu.cn/wiki/Pf1fw1woQi4iJikbKJmcYToznxb """ - version = "V231005" + version = "V231104" def __init__(self, dfw, digits=2, **kwargs) -> None: """持仓权重回测 @@ -169,9 +169,11 @@ def __init__(self, dfw, digits=2, **kwargs) -> None: """ self.kwargs = kwargs self.dfw = dfw.copy() + if self.dfw.isnull().sum().sum() > 0: + raise ValueError("dfw 中存在空值, 请先处理") self.digits = digits self.fee_rate = kwargs.get('fee_rate', 0.0002) - self.dfw['weight'] = self.dfw['weight'].round(digits) + self.dfw['weight'] = self.dfw['weight'].astype('float').round(digits) self.symbols = list(self.dfw['symbol'].unique().tolist()) self.results = self.backtest() diff --git a/czsc/utils/__init__.py b/czsc/utils/__init__.py index f568b47e3..32a682607 100644 --- a/czsc/utils/__init__.py +++ b/czsc/utils/__init__.py @@ -22,6 +22,8 @@ from .signal_analyzer import SignalAnalyzer, SignalPerformance from .cache import home_path, get_dir_size, empty_cache_path from .index_composition import index_composition +from .data_client import DataClient, set_url_token, get_url_token +from .oss import AliyunOSS sorted_freqs = ['Tick', '1分钟', '2分钟', '3分钟', '4分钟', '5分钟', '6分钟', '10分钟', '12分钟', diff --git a/czsc/utils/data_client.py b/czsc/utils/data_client.py new file mode 100644 index 000000000..e9e5d9025 --- /dev/null +++ b/czsc/utils/data_client.py @@ -0,0 +1,98 @@ +import os +import hashlib +import requests +import pandas as pd +from time import time +from pathlib import Path +from loguru import logger +from functools import partial + + +def set_url_token(token, url): + """设置指定 URL 数据接口的凭证码,通常一台机器只需要设置一次即可 + + :param token: 凭证码 + :param url: 数据接口地址 + """ + hash_key = hashlib.md5(str(url).encode('utf-8')).hexdigest() + file_token = Path("~").expanduser() / f"{hash_key}.txt" + with open(file_token, 'w', encoding='utf-8') as f: + f.write(token) + logger.info(f"{url} 数据访问凭证码已保存到 {file_token}") + + +def get_url_token(url): + """获取指定 URL 数据接口的凭证码""" + hash_key = hashlib.md5(str(url).encode('utf-8')).hexdigest() + file_token = Path("~").expanduser() / f"{hash_key}.txt" + if file_token.exists(): + return open(file_token, 'r', encoding='utf-8').read() + logger.warning(f"请设置 {url} 的访问凭证码,如果没有请联系管理员申请") + return None + + +class DataClient: + def __init__(self, token=None, url='http://api.tushare.pro', timeout=30, **kwargs): + """数据接口客户端,支持缓存,默认缓存路径为 ~/.quant_data_cache;兼容Tushare数据接口 + + :param token: str API接口TOKEN,用于用户认证 + :param url: str API接口地址 + :param timeout: int, 请求超时时间 + :param kwargs: dict, 其他参数 + + - clear_cache: bool, 是否清空缓存 + - cache_path: str, 缓存路径 + + """ + self.__token = token or get_url_token(url) + self.__http_url = url + self.__timeout = timeout + assert self.__token, "请设置czsc_token凭证码,如果没有请联系管理员申请" + self.cache_path = Path(kwargs.get("cache_path", os.path.expanduser("~/.quant_data_cache"))) + self.cache_path.mkdir(exist_ok=True, parents=True) + logger.info(f"数据缓存路径:{self.cache_path}") + if kwargs.get("clear_cache", False): + self.clear_cache() + + def clear_cache(self): + """清空缓存""" + for file in self.cache_path.glob("*.pkl"): + file.unlink() + logger.info(f"{self.cache_path} 路径下的数据缓存已清空") + + def post_request(self, api_name, fields='', **kwargs): + """执行API数据查询 + + :param api_name: str, 查询接口名称 + :param fields: str, 查询字段 + :param kwargs: dict, 查询参数 + :return: pd.DataFrame + """ + stime = time() + if api_name in ['__getstate__', '__setstate__']: + return pd.DataFrame() + + req_params = {'api_name': api_name, 'token': self.__token, 'params': kwargs, 'fields': fields} + hash_key = hashlib.md5(str(req_params).encode('utf-8')).hexdigest() + file_cache = self.cache_path / f"{hash_key}.pkl" + if file_cache.exists(): + df = pd.read_pickle(file_cache) + logger.info(f"缓存命中 | API:{api_name};参数:{kwargs};数据量:{df.shape}") + return df + + res = requests.post(self.__http_url, json=req_params, timeout=self.__timeout) + if res: + result = res.json() + if result['code'] != 0: + raise Exception(f"API: {api_name} - {kwargs} 数据获取失败: {result}") + + df = pd.DataFrame(result['data']['items'], columns=result['data']['fields']) + df.to_pickle(file_cache) + else: + df = pd.DataFrame() + + logger.info(f"本次获取数据总耗时:{time() - stime:.2f}秒;API:{api_name};参数:{kwargs};数据量:{df.shape}") + return df + + def __getattr__(self, name): + return partial(self.post_request, name) diff --git a/czsc/utils/features.py b/czsc/utils/features.py index 3f640a883..1f1df8abf 100644 --- a/czsc/utils/features.py +++ b/czsc/utils/features.py @@ -83,3 +83,43 @@ def normalize_ts_feature(df, x_col, n=10, **kwargs): df[f'{x_col}分层'] = df[f'{x_col}_qcut'].apply(lambda x: f'第{str(int(x+1)).zfill(2)}层') return df + + +def feture_cross_layering(df, x_col, **kwargs): + """因子在时间截面上分层 + + :param df: 因子数据,数据样例: + + =================== ======== =========== ========== ========== + dt symbol factor01 factor02 factor03 + =================== ======== =========== ========== ========== + 2022-12-19 00:00:00 ZZUR9001 -0.0221211 0.034236 0.0793672 + 2022-12-20 00:00:00 ZZUR9001 -0.0278691 0.0275818 0.0735083 + 2022-12-21 00:00:00 ZZUR9001 -0.00617075 0.0512298 0.0990967 + 2022-12-22 00:00:00 ZZUR9001 -0.0222238 0.0320096 0.0792036 + 2022-12-23 00:00:00 ZZUR9001 -0.0375133 0.0129455 0.059491 + =================== ======== =========== ========== ========== + + :param x_col: 因子列名 + :param kwargs: + + - n: 分层数量,默认为10 + + :return: df, 添加了 x_col分层 列 + """ + n = kwargs.get("n", 10) + assert 'dt' in df.columns, "因子数据必须包含 dt 列" + assert 'symbol' in df.columns, "因子数据必须包含 symbol 列" + assert x_col in df.columns, "因子数据必须包含 {} 列".format(x_col) + assert df['symbol'].nunique() > n, "标的数量必须大于分层数量" + + if df[x_col].nunique() > n: + def _layering(x): + return pd.qcut(x, q=n, labels=False, duplicates='drop') + df[f'{x_col}分层'] = df.groupby('dt')[x_col].transform(_layering) + else: + sorted_x = sorted(df[x_col].unique()) + df[f'{x_col}分层'] = df[x_col].apply(lambda x: sorted_x.index(x)) + df[f"{x_col}分层"] = df[f"{x_col}分层"].fillna(-1) + df[f'{x_col}分层'] = df[f'{x_col}分层'].apply(lambda x: f'第{str(int(x+1)).zfill(2)}层') + return df diff --git a/czsc/utils/st_components.py b/czsc/utils/st_components.py index 0faa9f1e7..a4af06bd6 100644 --- a/czsc/utils/st_components.py +++ b/czsc/utils/st_components.py @@ -223,3 +223,50 @@ def show_symbol_factor_layering(df, x_col, y_col='n1b', **kwargs): dfr['空头'] = -dfr[short].sum(axis=1) dfr['多空'] = dfr['多头'] + dfr['空头'] show_daily_return(dfr[['多头', '空头', '多空']]) + + +@st.cache_data(ttl=3600 * 24) +def show_weight_backtest(dfw, **kwargs): + """展示权重回测结果 + + :param dfw: 回测数据,任何字段都不允许有空值;数据样例: + + =================== ======== ======== ======= + dt symbol weight price + =================== ======== ======== ======= + 2019-01-02 09:01:00 DLi9001 0.5 961.695 + 2019-01-02 09:02:00 DLi9001 0.25 960.72 + 2019-01-02 09:03:00 DLi9001 0.25 962.669 + 2019-01-02 09:04:00 DLi9001 0.25 960.72 + 2019-01-02 09:05:00 DLi9001 0.25 961.695 + =================== ======== ======== ======= + + :param kwargs: + + - fee: 单边手续费,单位为BP,默认为2BP + """ + fee = kwargs.get("fee", 2) + if (dfw.isnull().sum().sum() > 0) or (dfw.isna().sum().sum() > 0): + st.warning("数据中存在空值,请检查数据后再试") + st.stop() + + from czsc.traders.weight_backtest import WeightBacktest + + wb = WeightBacktest(dfw, fee=fee / 10000) + stat = wb.results['绩效评价'] + + st.divider() + c1, c2, c3, c4, c5, c6, c7, c8 = st.columns([1, 1, 1, 1, 1, 1, 1, 1]) + c1.metric("盈亏平衡点", f"{stat['盈亏平衡点']:.2%}") + c2.metric("单笔收益", f"{stat['单笔收益']} BP") + c3.metric("交易胜率", f"{stat['交易胜率']:.2%}") + c4.metric("持仓K线数", f"{stat['持仓K线数']}") + c5.metric("最大回撤", f"{stat['最大回撤']:.2%}") + c6.metric("年化收益率", f"{stat['年化']:.2%}") + c7.metric("夏普比率", f"{stat['夏普']:.2f}") + c8.metric("卡玛比率", f"{stat['卡玛']:.2f}") + st.divider() + + dret = wb.results['品种等权日收益'] + dret.index = pd.to_datetime(dret.index) + show_daily_return(dret, legend_only_cols=dfw['symbol'].unique().tolist()) diff --git a/czsc/utils/trade.py b/czsc/utils/trade.py index f76f965a3..a91b65d32 100644 --- a/czsc/utils/trade.py +++ b/czsc/utils/trade.py @@ -32,12 +32,14 @@ def cal_trade_price(bars: Union[List[RawBar], pd.DataFrame], decimals=3, **kwarg df[f"sum_vcp_{t}"] = df['vol_close_prod'].rolling(t).sum() df[f"VWAP{t}"] = (df[f"sum_vcp_{t}"] / df[f"sum_vol_{t}"]).shift(-t) price_cols.extend([f"TWAP{t}", f"VWAP{t}"]) + df.drop(columns=[f"sum_vol_{t}", f"sum_vcp_{t}"], inplace=True) + df.drop(columns=['vol_close_prod'], inplace=True) # 用当前K线的收盘价填充交易价中的 nan 值 for price_col in price_cols: df.loc[df[price_col].isnull(), price_col] = df[df[price_col].isnull()]['close'] - df = df[['symbol', 'dt', 'open', 'close', 'high', 'low', 'vol', 'amount'] + price_cols].round(decimals) + df[price_cols] = df[price_cols].round(decimals) return df diff --git a/docs/requirements.txt b/docs/requirements.txt index cf9fb45ae..eea0f073e 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -27,4 +27,5 @@ plotly>=5.11.0 parse>=1.19.0 lightgbm>=4.0.0 streamlit -redis \ No newline at end of file +redis +oss2 \ No newline at end of file diff --git a/examples/test_offline/test_data_client.py b/examples/test_offline/test_data_client.py new file mode 100644 index 000000000..2b671cb18 --- /dev/null +++ b/examples/test_offline/test_data_client.py @@ -0,0 +1,17 @@ +import sys +sys.path.insert(0, r"D:\ZB\git_repo\waditu\czsc") +import czsc + + +def test_tushare_pro(): + # czsc.set_url_token("******", url="http://api.tushare.pro") + dc = czsc.DataClient(url="http://api.tushare.pro", cache_path="tushare_data") + df = dc.stock_basic(exchange='', list_status='L', fields='ts_code,symbol,name,area,industry,list_date') + try: + df = dc.stock_basic_1(exchange='', list_status='L', fields='ts_code,symbol,name,area,industry,list_date') + except Exception as e: + print(e) + + +if __name__ == '__main__': + test_tushare_pro() diff --git a/requirements.txt b/requirements.txt index a0dda89ca..f999f0c79 100644 --- a/requirements.txt +++ b/requirements.txt @@ -21,4 +21,5 @@ plotly>=5.11.0 parse>=1.19.0 lightgbm>=4.0.0 streamlit -redis \ No newline at end of file +redis +oss2 \ No newline at end of file