diff --git a/.github/workflows/pythonpackage.yml b/.github/workflows/pythonpackage.yml index 92d00c571..2872d9dd9 100644 --- a/.github/workflows/pythonpackage.yml +++ b/.github/workflows/pythonpackage.yml @@ -5,7 +5,7 @@ name: Python package on: push: - branches: [ master, V0.9.41 ] + branches: [ master, V0.9.42 ] pull_request: branches: [ master ] diff --git a/czsc/__init__.py b/czsc/__init__.py index 204cd49aa..98f5dae6c 100644 --- a/czsc/__init__.py +++ b/czsc/__init__.py @@ -98,6 +98,7 @@ # streamlit 量化分析组件 from czsc.utils.st_components import ( show_daily_return, + show_yearly_stats, show_splited_daily, show_monthly_return, show_correlation, @@ -126,10 +127,14 @@ find_most_similarity, ) -__version__ = "0.9.41" +from czsc.features.utils import ( + is_event_feature, +) + +__version__ = "0.9.42" __author__ = "zengbin93" __email__ = "zeng_bin8888@163.com" -__date__ = "20240114" +__date__ = "20240121" def welcome(): diff --git a/czsc/connectors/cooperation.py b/czsc/connectors/cooperation.py index 84e924715..0c1120cb8 100644 --- a/czsc/connectors/cooperation.py +++ b/czsc/connectors/cooperation.py @@ -11,6 +11,7 @@ import czsc import pandas as pd from tqdm import tqdm +from loguru import logger from datetime import datetime from czsc import RawBar, Freq @@ -55,14 +56,27 @@ def get_symbols(name, **kwargs): :return: """ if name == "股票": - data = dc.stock_basic(nobj=1, status=1) - return data['code'].tolist() + df = dc.stock_basic(nobj=1, status=1) + symbols = [f"{row['code']}#STOCK" for _, row in df.iterrows()] + return symbols if name == "ETF": - raise NotImplementedError + df = dc.etf_basic(v=2, fields='code,name') + dfk = dc.pro_bar(trade_date="2023-11-17", asset="e", v=2) + df = df[df['code'].isin(dfk['code'])].reset_index(drop=True) + symbols = [f"{row['code']}#ETF" for _, row in df.iterrows()] + return symbols if name == "A股指数": - raise NotImplementedError + # 指数 https://s0cqcxuy3p.feishu.cn/wiki/KuSAweAAhicvsGk9VPTc1ZWKnAd + df = dc.index_basic(v=2, market='SSE,SZSE') + symbols = [f"{row['code']}#INDEX" for _, row in df.iterrows()] + return symbols + + if name == "南华指数": + df = dc.index_basic(v=2, market='NH') + symbols = [row['code'] for _, row in df.iterrows()] + return symbols if name == "期货主力": kline = dc.future_klines(trade_date="20231101") @@ -71,6 +85,28 @@ def get_symbols(name, **kwargs): raise ValueError(f"{name} 分组无法识别,获取标的列表失败!") +def get_min_future_klines(code, sdt, edt, freq='1m'): + """分段获取期货1分钟K线后合并""" + dates = pd.date_range(start=sdt, end=edt, freq='1M') + dates = [d.strftime('%Y%m%d') for d in dates] + [sdt, edt] + dates = sorted(list(set(dates))) + + rows = [] + for sdt_, edt_ in tqdm(zip(dates[:-1], dates[1:]), total=len(dates) - 1): + df = dc.future_klines(code=code, sdt=sdt_, edt=edt_, freq=freq) + if df.empty: + continue + logger.info(f"{code}获取K线范围:{df['dt'].min()} - {df['dt'].max()}") + rows.append(df) + + df = pd.concat(rows, ignore_index=True) + df.rename(columns={'code': 'symbol'}, inplace=True) + df['dt'] = pd.to_datetime(df['dt']) + + df = df.drop_duplicates(subset=['dt', 'symbol'], keep='last') + return df + + def get_raw_bars(symbol, freq, sdt, edt, fq='前复权', **kwargs): """获取 CZSC 库定义的标准 RawBar 对象列表 @@ -85,29 +121,44 @@ def get_raw_bars(symbol, freq, sdt, edt, fq='前复权', **kwargs): """ freq = czsc.Freq(freq) - if symbol.endswith(".SH") or symbol.endswith(".SZ"): + if "SH" in symbol or "SZ" in symbol: fq_map = {"前复权": "qfq", "后复权": "hfq", "不复权": None} adj = fq_map.get(fq, None) + + code, asset = symbol.split("#") + if freq.value.endswith('分钟'): - df = dc.pro_bar(code=symbol, sdt=sdt, edt=edt, freq='min', adj=adj) + df = dc.pro_bar(code=code, sdt=sdt, edt=edt, freq='min', adj=adj, asset=asset[0].lower(), v=2) df = df[~df['dt'].str.endswith("09:30:00")].reset_index(drop=True) else: - df = dc.pro_bar(code=symbol, sdt=sdt, edt=edt, freq='day', adj=adj) + df = dc.pro_bar(code=code, sdt=sdt, edt=edt, freq='day', adj=adj, asset=asset[0].lower(), v=2) + df.rename(columns={'code': 'symbol'}, inplace=True) df['dt'] = pd.to_datetime(df['dt']) return czsc.resample_bars(df, target_freq=freq) if symbol.endswith("9001"): + # https://s0cqcxuy3p.feishu.cn/wiki/WLGQwJLWQiWPCZkPV7Xc3L1engg + if fq == "前复权": + logger.warning("期货主力合约暂时不支持前复权,已自动切换为后复权") + + freq_rd = '1m' if freq.value.endswith('分钟') else '1d' if freq.value.endswith('分钟'): - df = dc.future_klines(code=symbol, sdt=sdt, edt=edt, freq='1m') + df = get_min_future_klines(code=symbol, sdt=sdt, edt=edt, freq='1m') else: - df = dc.future_klines(code=symbol, sdt=sdt, edt=edt, freq='1d') - df.rename(columns={'code': 'symbol'}, inplace=True) + df = dc.future_klines(code=symbol, sdt=sdt, edt=edt, freq=freq_rd) + df.rename(columns={'code': 'symbol'}, inplace=True) + df['amount'] = df['vol'] * df['close'] df = df[['symbol', 'dt', 'open', 'close', 'high', 'low', 'vol', 'amount']].copy().reset_index(drop=True) df['dt'] = pd.to_datetime(df['dt']) return czsc.resample_bars(df, target_freq=freq) + if symbol.endswith(".NH"): + if freq != Freq.D: + raise ValueError("南华指数只支持日线数据") + df = dc.nh_daily(code=symbol, sdt=sdt, edt=edt) + raise ValueError(f"symbol {symbol} 无法识别,获取数据失败!") diff --git a/czsc/features/__init__.py b/czsc/features/__init__.py new file mode 100644 index 000000000..08bc3518e --- /dev/null +++ b/czsc/features/__init__.py @@ -0,0 +1,27 @@ +# -*- coding: utf-8 -*- +""" +author: zengbin93 +email: zeng_bin8888@163.com +create_dt: 2024/02/14 17:48 +describe: 时序特征因子库 + +因子函数编写规范:https://s0cqcxuy3p.feishu.cn/wiki/A9yawT6o1il9SrkUoBNchtXjnBK +""" + +from .ret import ( + RET001, + RET002, + RET003, + RET004, + RET005, + RET006, + RET007, + RET008, +) + +from .vpf import ( + VPF001, + VPF002, + VPF003, + VPF004, +) \ No newline at end of file diff --git a/czsc/features/ret.py b/czsc/features/ret.py new file mode 100644 index 000000000..64b3ee618 --- /dev/null +++ b/czsc/features/ret.py @@ -0,0 +1,214 @@ +""" +用于计算未来收益相关的因子,含有未来信息,不可用于实际交易 +通常用作模型训练、因子评价的标准 +""" +import numpy as np +import pandas as pd + + +def RET001(df, **kwargs): + """用 close 价格计算未来 N 根K线的收益率 + + 参数空间: + + :param df: 标准K线数据,DataFrame结构 + :param kwargs: 其他参数 + + - tag: str, 因子字段标记 + + :return: None + """ + tag = kwargs.get('tag', 'A') + n = kwargs.get('n', 5) + + col = f'F#RET001#{tag}' + df[col] = df['close'].shift(-n) / df['close'] - 1 + df[col] = df[col].fillna(0) + + +def RET002(df, **kwargs): + """用 open 价格计算未来 N 根K线的收益率 + + 参数空间: + + :param df: 标准K线数据,DataFrame结构 + :param kwargs: 其他参数 + + - tag: str, 因子字段标记 + + :return: None + """ + tag = kwargs.get('tag', 'A') + n = kwargs.get('n', 5) + + col = f'F#RET002#{tag}' + df[col] = df['open'].shift(-n - 1) / df['open'].shift(-1) - 1 + df[col] = df[col].fillna(0) + + +def RET003(df, **kwargs): + """未来 N 根K线的收益波动率 + + 参数空间: + + :param df: 标准K线数据,DataFrame结构 + :param kwargs: 其他参数 + + - tag: str, 因子字段标记 + - n: int, 计算未来 N 根K线的收益波动率 + + :return: None + """ + tag = kwargs.get('tag', 'A') + n = kwargs.get('n', 5) + + col = f'F#RET003#{tag}' + df['tmp'] = df['close'].pct_change() + df[col] = df['tmp'].rolling(n).std().shift(-n) + df[col] = df[col].fillna(0) + df.drop(columns=['tmp'], inplace=True) + + +def RET004(df, **kwargs): + """未来 N 根K线的最大收益盈亏比 + + 注意: + 1. 约束盈亏比的范围是 [0, 10] + 2. 当未来 N 根K线内收益最小值为0时,会导致计算结果为无穷大,此时将结果设置为10 + + :param df: 标准K线数据,DataFrame结构 + :param kwargs: 其他参数 + + - tag: str, 因子字段标记 + - n: int, 计算未来 N 根K线的收益盈亏比 + + :return: None + """ + tag = kwargs.get('tag', 'A') + n = kwargs.get('n', 5) + + col = f'F#RET004#{tag}' + df['max_ret'] = df['close'].rolling(n).apply(lambda x: x.max() / x[0] - 1, raw=True) + df['min_ret'] = df['close'].rolling(n).apply(lambda x: x.min() / x[0] - 1, raw=True) + df[col] = (df['max_ret'] / df['min_ret'].abs()).shift(-n) + df[col] = df[col].fillna(0) + df[col] = df[col].clip(0, 10) + df.drop(columns=['max_ret', 'min_ret'], inplace=True) + + +def RET005(df, **kwargs): + """未来 N 根K线的逐K胜率 + + :param df: 标准K线数据,DataFrame结构 + :param kwargs: 其他参数 + + - tag: str, 因子字段标记 + - n: int, 滚动窗口大小 + + :return: None + """ + tag = kwargs.get('tag', 'A') + n = kwargs.get('n', 5) + + col = f'F#RET005#{tag}' + df['ret'] = df['close'].pct_change() + df[col] = df['ret'].rolling(n).apply(lambda x: np.sum(x > 0) / n).shift(-n) + df[col] = df[col].fillna(0) + df.drop(columns=['ret'], inplace=True) + + +def RET006(df, **kwargs): + """未来 N 根K线的逐K盈亏比 + + 注意: + 1. 约束盈亏比的范围是 [0, 10] + + :param df: 标准K线数据,DataFrame结构 + :param kwargs: 其他参数 + + - tag: str, 因子字段标记 + - n: int, 滚动窗口大小 + + :return: None + """ + tag = kwargs.get('tag', 'A') + n = kwargs.get('n', 5) + + col = f'F#RET006#{tag}' + df['ret'] = df['close'].pct_change() + df['mean_win'] = df['ret'].rolling(n).apply(lambda x: np.sum(x[x > 0]) / np.sum(x > 0)) + df['mean_loss'] = df['ret'].rolling(n).apply(lambda x: np.sum(x[x < 0]) / np.sum(x < 0)) + df[col] = (df['mean_win'] / df['mean_loss'].abs()).shift(-n) + df[col] = df[col].fillna(0) + df[col] = df[col].clip(0, 10) + df.drop(columns=['ret', 'mean_win', 'mean_loss'], inplace=True) + + +def RET007(df, **kwargs): + """未来 N 根K线的最大跌幅 + + :param df: 标准K线数据,DataFrame结构 + :param kwargs: 其他参数 + + - tag: str, 因子字段标记 + - n: int, 滚动窗口大小 + + :return: None + """ + tag = kwargs.get('tag', 'A') + n = kwargs.get('n', 5) + + col = f'F#RET007#{tag}' + df[col] = df['close'].rolling(n).apply(lambda x: np.min(x) / x[0] - 1, raw=True).shift(-n) + df[col] = df[col].fillna(0) + + +def RET008(df, **kwargs): + """未来 N 根K线的最大涨幅 + + :param df: 标准K线数据,DataFrame结构 + :param kwargs: 其他参数 + + - tag: str, 因子字段标记 + - n: int, 滚动窗口大小 + + :return: None + """ + tag = kwargs.get('tag', 'A') + n = kwargs.get('n', 5) + + col = f'F#RET008#{tag}' + df[col] = df['close'].rolling(n).apply(lambda x: np.max(x) / x[0] - 1, raw=True).shift(-n) + df[col] = df[col].fillna(0) + + +def test_ret_functions(): + from czsc.connectors import cooperation as coo + + df = coo.dc.pro_bar(code="000001.SZ", freq="day", sdt="2020-01-01", edt="2021-01-31") + df['dt'] = pd.to_datetime(df['dt']) + df.rename(columns={'code': 'symbol'}, inplace=True) + + RET001(df, tag='A') + assert 'F#RET001#A' in df.columns + + RET002(df, tag='A') + assert 'F#RET002#A' in df.columns + + RET003(df, tag='A') + assert 'F#RET003#A' in df.columns + + RET004(df, tag='A') + assert 'F#RET004#A' in df.columns + + RET005(df, tag='A') + assert 'F#RET005#A' in df.columns + + RET006(df, tag='A') + assert 'F#RET006#A' in df.columns + + RET007(df, tag='A') + assert 'F#RET007#A' in df.columns + + RET008(df, tag='A') + assert 'F#RET008#A' in df.columns diff --git a/czsc/features/utils.py b/czsc/features/utils.py new file mode 100644 index 000000000..f1a3f60bb --- /dev/null +++ b/czsc/features/utils.py @@ -0,0 +1,13 @@ +# 工具函数 + + +def is_event_feature(df, col, **kwargs): + """事件类因子的判断函数 + + 事件因子的特征:多头事件发生时,因子值为1;空头事件发生时,因子值为-1;其他情况,因子值为0。 + + :param df: DataFrame + :param col: str, 因子字段名称 + """ + unique_values = df[col].unique() + return all([x in [0, 1, -1] for x in unique_values]) diff --git a/czsc/features/vpf.py b/czsc/features/vpf.py new file mode 100644 index 000000000..4d8c31564 --- /dev/null +++ b/czsc/features/vpf.py @@ -0,0 +1,105 @@ +# 标准量价因子 +import inspect +import numpy as np +import pandas as pd + + +def VPF001(df, **kwargs): + """比较开盘价、收盘价与当日最高价和最低价的中点的关系,来判断市场的强弱 + + :param df: 标准K线数据,DataFrame结构 + :param kwargs: 其他参数 + - tag: str, defaults to 'N2' 因子字段标记 + - num: int, defaults to 2 参数值 + """ + num = kwargs.get('num', 2) + tag = kwargs.get('tag', f'N{num}') + + factor_name = inspect.stack()[0].function + factor_col = f'F#{factor_name}#{tag}' + + con = df['open'] >= 1 / num * (df['high'] + df['low']) + con &= df['close'] >= 1 / num * (df['high'] + df['low']) + + red = df['open'] < 1 / num * (df['high'] + df['low']) + red &= df['close'] < 1 / num * (df['high'] + df['low']) + + df[factor_col] = 0 + df[factor_col] = np.where(con, -1, df[factor_col]) + df[factor_col] = np.where(red, 1, df[factor_col]) + + +def VPF002(df, **kwargs): + """比较过去收益率的正负,以及当日最高价、最低价与开盘价或收盘价的关系 + + :param df: 标准K线数据,DataFrame结构 + :param kwargs: 其他参数 + + - tag: str, defaults to 'N4' 因子字段标记 + - num: int, defaults to 4 参数值 + + :return: None + """ + num = kwargs.get('num', 4) + tag = kwargs.get('tag', f'N{num}') + + factor_name = inspect.stack()[0].function + factor_col = f'F#{factor_name}#{tag}' + + df['return'] = df['close'] / df['close'].shift(1) - 1 + red1 = df['return'].rolling(window=num, min_periods=1).sum() >= 0 + red2 = (df['high'] - df['close']) / (df['close'] - df['low']) >= 1 + + df[factor_col] = np.where(red1 | red2, 1, -1) + df.drop(columns=['return'], axis=1, inplace=True) + + +def VPF003(df, **kwargs): + """比较过去N天最高价、最低价、开盘价和收盘价的比例,判断市场强弱 + + :param df: 标准K线数据,DataFrame结构 + :param kwargs: 其他参数 + + - tag: str + - num: int, defaults to 60 参数值 + """ + num = kwargs.get('num', 2) + tag = kwargs.get('tag', f'N{num}') + + factor_name = inspect.stack()[0].function + factor_col = f'F#{factor_name}#{tag}' + + df['hol'] = (df['high'] - df['open']) / (df['high'] - df['low']) + df['clh'] = (df['close'] - df['low']) / (df['high'] - df['low']) + + con = df['hol'].rolling(window=num, min_periods=1).mean() >= 0.5 + con1 = (df['high'] + df['low'] - df['open'] - df['close']) >= 0 + df[factor_col] = np.where(con | con1, 1, -1) + red = df['clh'].rolling(window=num, min_periods=1).mean() >= 0.5 + df.loc[red, factor_col] = -1 + + df.drop(['hol', 'clh'], axis=1, inplace=True) + + +def VPF004(df, **kwargs): + """EMA指标 + + :param df: 标准K线数据,DataFrame结构 + :param kwargs: 其他参数 + + - tag: str, 因子字段标记 + - n: int, EMA的周期参数 + + :return: None + """ + n = kwargs.get('n', 7) + tag = kwargs.get('tag', f'N{n}') + + factor_name = inspect.stack()[0].function + factor_col = f'F#{factor_name}#{tag}' + + ema1 = df['close'].ewm(span=n, adjust=False).mean() + ema2 = ema1.ewm(span=n, adjust=False).mean() + ema3 = ema2.ewm(span=n, adjust=False).mean() + df[factor_col] = 3 * (ema1 - ema2) + ema3 + df[factor_col] = df[factor_col].fillna(0) diff --git a/czsc/signals/__init__.py b/czsc/signals/__init__.py index 5f6084000..79d2290ce 100644 --- a/czsc/signals/__init__.py +++ b/czsc/signals/__init__.py @@ -3,7 +3,7 @@ author: zengbin93 email: zeng_bin8888@163.com create_dt: 2021/11/21 17:48 -describe: 信号系统,注意:这里仅仅只是提供一些写信号的例子,用来做策略是不太行的 +describe: 信号函数 """ # ====================================================================================================================== # 以下是 0.9.1 开始的新标准下实现的信号函数,规范定义: @@ -108,6 +108,7 @@ bar_window_std_V230731, bar_window_ps_V230731, bar_window_ps_V230801, + bar_trend_V240209, ) from czsc.signals.jcc import ( @@ -163,6 +164,7 @@ tas_ma_round_V221206, tas_double_ma_V221203, tas_double_ma_V230511, + tas_double_ma_V240208, tas_ma_system_V230513, tas_boll_power_V221112, diff --git a/czsc/signals/bar.py b/czsc/signals/bar.py index d291903ee..0e9c259da 100644 --- a/czsc/signals/bar.py +++ b/czsc/signals/bar.py @@ -16,7 +16,7 @@ from czsc.traders.base import CzscSignals from czsc.objects import RawBar from czsc.utils.sig import check_pressure_support -from czsc.signals.tas import update_ma_cache +from czsc.signals.tas import update_ma_cache, update_macd_cache from czsc.utils.bar_generator import freq_end_time from czsc.utils import single_linear, freq_end_time, get_sub_elements, create_single_signal @@ -1648,3 +1648,67 @@ def bar_window_ps_V230801(c: CZSC, **kwargs) -> OrderedDict: pcts = [int(max((x.close - L_line) / (H_line - L_line), 0) * 10) for x in c.bars_raw[-w:]] v1, v2, v3 = f"最大N{max(pcts)}", f"最小N{min(pcts)}", f"当前N{pcts[-1]}" return create_single_signal(k1=k1, k2=k2, k3=k3, v1=v1, v2=v2, v3=v3) + + +def bar_trend_V240209(c: CZSC, **kwargs) -> OrderedDict: + """趋势跟踪信号 + + 参数模板:"{freq}_D{di}N{N}趋势跟踪_BS辅助V240209" + + **信号逻辑:** + + 以多头为例: + 1. 低点出现在高点之后,且低点右侧的高点到当前K线之间的K线数量在5-30之间; + 2. 低点右侧的K线的DIF值小于前N根K线的DIF值的标准差的一半; + 3. 低点右侧的K线的最低价大于低点的最低价; + 4. 低点右侧的K线的MACD值小于前N根K线的MACD值的标准差的一半。 + + **信号列表:** + + - Signal('60分钟_D1N60趋势跟踪_BS辅助V240209_多头_任意_任意_0') + - Signal('60分钟_D1N60趋势跟踪_BS辅助V240209_空头_任意_任意_0') + + :param c: CZSC对象 + :param kwargs: 参数设置 + + - di: int, default 1, 倒数第几根K线 + - N: int, default 20, 窗口大小 + + :return: 信号识别结果 + """ + di = int(kwargs.get('di', 1)) + N = int(kwargs.get('N', 60)) + + freq = c.freq.value + k1, k2, k3 = f"{freq}_D{di}N{N}趋势跟踪_BS辅助V240209".split('_') + v1 = '其他' + cache_key = update_macd_cache(c) + bars = get_sub_elements(c.bars_raw, di=di, n=N) + max_bar = max(bars, key=lambda x: x.high) + min_bar = min(bars, key=lambda x: x.low) + dif_std = np.std([x.cache[cache_key]['dif'] for x in bars]) + macd_std = np.std([x.cache[cache_key]['macd'] for x in bars]) + + if min_bar.dt < max_bar.dt: + right_bars = [x for x in c.bars_raw if x.dt >= max_bar.dt] + right_min_bar = min(right_bars, key=lambda x: x.low) + c1 = 30 > right_min_bar.id - max_bar.id > 5 + c2 = abs(right_bars[-1].cache[cache_key]['dif']) < dif_std # type: ignore + c3 = right_min_bar.low > min_bar.low + c4 = abs(right_bars[-1].cache[cache_key]['macd']) < macd_std # type: ignore + + if c1 and c2 and c3 and c4: + return create_single_signal(k1=k1, k2=k2, k3=k3, v1="多头") + + if min_bar.dt > max_bar.dt: + right_bars = [x for x in c.bars_raw if x.dt >= min_bar.dt] + right_max_bar = max(right_bars, key=lambda x: x.high) + c1 = 30 > right_max_bar.id - min_bar.id > 5 + c2 = abs(right_bars[-1].cache[cache_key]['dif']) < dif_std # type: ignore + c3 = right_max_bar.high < max_bar.high + c4 = abs(right_bars[-1].cache[cache_key]['macd']) < macd_std # type: ignore + + if c1 and c2 and c3 and c4: + return create_single_signal(k1=k1, k2=k2, k3=k3, v1="空头") + + return create_single_signal(k1=k1, k2=k2, k3=k3, v1=v1) diff --git a/czsc/signals/tas.py b/czsc/signals/tas.py index 310ad107a..108fcf99a 100644 --- a/czsc/signals/tas.py +++ b/czsc/signals/tas.py @@ -3416,3 +3416,56 @@ def tas_slope_V231019(c: CZSC, **kwargs) -> OrderedDict: elif q < 1 - th / 100: v1 = '看空' return create_single_signal(k1=k1, k2=k2, k3=k3, v1=v1) + + +def tas_double_ma_V240208(c: CZSC, **kwargs) -> OrderedDict: + """双均线多空信号,辅助V240208 + + 参数模板:"{freq}_D{di}N{N}M{M}双均线_BS辅助V240208" + + **信号逻辑:** + + 1. 找出最近3个均线交叉点,时间上由远到近,分别为 X1,X2,X3 + 2. 以多头为例:X3 和 X1 为金叉,且 X2 的价格最高 + + **信号列表:** + + - Signal('60分钟_D1N5M21双均线_BS辅助V240208_多头_任意_任意_0') + - Signal('60分钟_D1N5M21双均线_BS辅助V240208_空头_任意_任意_0') + + :param c: CZSC对象 + :param kwargs: 参数设置 + + - di: int, default 1, 倒数第几根K线 + - N: int, default 20, 快线周期 + - M: int, default 60, 慢线周期 + + :return: 信号识别结果 + """ + di = int(kwargs.get('di', 1)) + N = int(kwargs.get('N', 20)) + M = int(kwargs.get('M', 60)) + assert N < M, "N 必须小于 M" + + freq = c.freq.value + k1, k2, k3 = f"{freq}_D{di}N{N}M{M}双均线_BS辅助V240208".split('_') + v1 = '其他' + fast_ma_key = update_ma_cache(c, ma_type='SMA', timeperiod=N) + slow_ma_key = update_ma_cache(c, ma_type='SMA', timeperiod=M) + + bars = get_sub_elements(c.bars_raw, di=di, n=M * 30) + fast_ma = [x.cache[fast_ma_key] for x in bars] + slow_ma = [x.cache[slow_ma_key] for x in bars] + cross_info = fast_slow_cross(fast_ma, slow_ma) + + if len(cross_info) < 3: + return create_single_signal(k1=k1, k2=k2, k3=k3, v1=v1) + + x1, x2, x3 = cross_info[-3:] + if x3['类型'] == "金叉" and x2['快线'] > max(x1['快线'], x3['快线']): + return create_single_signal(k1=k1, k2=k2, k3=k3, v1='多头') + + if x3['类型'] == "死叉" and x2['快线'] < min(x1['快线'], x3['快线']): + return create_single_signal(k1=k1, k2=k2, k3=k3, v1='空头') + + return create_single_signal(k1=k1, k2=k2, k3=k3, v1=v1) diff --git a/czsc/utils/cache.py b/czsc/utils/cache.py index a2b6fe4a4..24a3ca095 100644 --- a/czsc/utils/cache.py +++ b/czsc/utils/cache.py @@ -8,9 +8,10 @@ import os import time import dill +import json import shutil import hashlib -import json +import inspect import pandas as pd from pathlib import Path from loguru import logger @@ -94,6 +95,10 @@ def get(self, k: str, suffix: str = "pkl") -> Any: res = pd.read_csv(file, encoding='utf-8') elif suffix == "xlsx": res = pd.read_excel(file) + elif suffix == "feather": + res = pd.read_feather(file) + elif suffix == "parquet": + res = pd.read_parquet(file) else: raise ValueError(f"suffix {suffix} not supported") return res @@ -132,6 +137,16 @@ def set(self, k: str, v: Any, suffix: str = "pkl"): raise ValueError("suffix xlsx only support pd.DataFrame") v.to_excel(file, index=False) + elif suffix == "feather": + if not isinstance(v, pd.DataFrame): + raise ValueError("suffix feather only support pd.DataFrame") + v.to_feather(file) + + elif suffix == "parquet": + if not isinstance(v, pd.DataFrame): + raise ValueError("suffix parquet only support pd.DataFrame") + v.to_parquet(file) + else: raise ValueError(f"suffix {suffix} not supported") @@ -150,15 +165,14 @@ def disk_cache(path: str, suffix: str = "pkl", ttl: int = -1): :param suffix: 缓存文件后缀,支持 pkl, json, txt, csv, xlsx :param ttl: 缓存文件有效期,单位:秒 """ - assert suffix in ["pkl", "json", "txt", "csv", "xlsx"], "suffix not supported" - def decorator(func): nonlocal path _c = DiskCache(path=Path(path) / func.__name__) def cached_func(*args, **kwargs): hash_str = f"{func.__name__}{args}{kwargs}" - k = hashlib.md5(hash_str.encode('utf-8')).hexdigest().upper()[:8] + code_str = inspect.getsource(func) + k = hashlib.md5((code_str + hash_str).encode('utf-8')).hexdigest().upper()[:8] k = f"{k}_{func.__name__}" if _c.is_found(k, suffix=suffix, ttl=ttl): diff --git a/czsc/utils/data_client.py b/czsc/utils/data_client.py index 360bec487..778e42351 100644 --- a/czsc/utils/data_client.py +++ b/czsc/utils/data_client.py @@ -53,6 +53,8 @@ def __init__(self, token=None, url='http://api.tushare.pro', timeout=300, **kwar - cache_path: str, 缓存路径 """ + from czsc.utils.cache import get_dir_size + self.__token = token or get_url_token(url) self.__http_url = url self.__timeout = timeout @@ -60,7 +62,8 @@ def __init__(self, token=None, url='http://api.tushare.pro', timeout=300, **kwar assert self.__token, "请设置czsc_token凭证码,如果没有请联系管理员申请" self.cache_path = Path(kwargs.get("cache_path", os.path.expanduser("~/.quant_data_cache"))) self.cache_path.mkdir(exist_ok=True, parents=True) - logger.info(f"数据URL: {url} 数据缓存路径:{self.cache_path}") + + logger.info(f"数据URL: {url} 数据缓存路径:{self.cache_path} 占用磁盘空间:{get_dir_size(self.cache_path) / 1024 / 1024:.2f} MB") if kwargs.get("clear_cache", False): self.clear_cache() diff --git a/czsc/utils/plotly_plot.py b/czsc/utils/plotly_plot.py index c46bf93cc..39d33ee1d 100644 --- a/czsc/utils/plotly_plot.py +++ b/czsc/utils/plotly_plot.py @@ -182,7 +182,12 @@ def add_macd(self, kline: pd.DataFrame, row=3, **kwargs): slowperiod = kwargs.get('slowperiod', 26) signalperiod = kwargs.get('signalperiod', 9) line_width = kwargs.get('line_width', 0.6) - diff, dea, macd = MACD(df["close"], fastperiod=fastperiod, slowperiod=slowperiod, signalperiod=signalperiod) + + if 'DIFF' in df.columns and 'DEA' in df.columns and 'MACD' in df.columns: + diff, dea, macd = df['DIFF'], df['DEA'], df['MACD'] + else: + diff, dea, macd = MACD(df["close"], fastperiod=fastperiod, slowperiod=slowperiod, signalperiod=signalperiod) + macd_colors = np.where(macd > 0, self.color_red, self.color_green) self.add_scatter_indicator(df['dt'], diff, name="DIFF", row=row, line_color='white', show_legend=False, line_width=line_width) diff --git a/czsc/utils/sig.py b/czsc/utils/sig.py index 516410606..e90e8a2b8 100644 --- a/czsc/utils/sig.py +++ b/czsc/utils/sig.py @@ -160,7 +160,7 @@ def check_gap_info(bars: List[RawBar]): return gap_info -def fast_slow_cross(fast: [List, np.array], slow: [List, np.array]): +def fast_slow_cross(fast, slow): """计算 fast 和 slow 的交叉信息 :param fast: 快线 diff --git a/czsc/utils/st_components.py b/czsc/utils/st_components.py index e260b0099..ac8712deb 100644 --- a/czsc/utils/st_components.py +++ b/czsc/utils/st_components.py @@ -44,6 +44,7 @@ def _stats(df_, type_='持有日'): # stats = stats.style.background_gradient(cmap='RdYlGn_r', axis=None, subset=fmt_cols).format('{:.4f}') stats = stats.style.background_gradient(cmap='RdYlGn_r', axis=None, subset=['年化']) + stats = stats.background_gradient(cmap='RdYlGn_r', axis=None, subset=['绝对收益']) stats = stats.background_gradient(cmap='RdYlGn_r', axis=None, subset=['夏普']) stats = stats.background_gradient(cmap='RdYlGn', axis=None, subset=['最大回撤']) stats = stats.background_gradient(cmap='RdYlGn_r', axis=None, subset=['卡玛']) @@ -63,6 +64,7 @@ def _stats(df_, type_='持有日'): '年化': '{:.2%}', '夏普': '{:.2f}', '非零覆盖': '{:.2%}', + '绝对收益': '{:.2%}', '日胜率': '{:.2%}', '新高间隔': '{:.2f}', '新高占比': '{:.2%}', @@ -301,7 +303,7 @@ def show_symbol_factor_layering(df, x_col, y_col='n1b', **kwargs): tabs = st.tabs(["分层收益率", "多空组合"]) with tabs[0]: - show_daily_return(mrr) + show_daily_return(mrr, stat_hold_days=False) with tabs[1]: col1, col2 = st.columns(2) @@ -360,7 +362,7 @@ def show_weight_backtest(dfw, **kwargs): dret = wb.results['品种等权日收益'] dret.index = pd.to_datetime(dret.index) - show_daily_return(dret, legend_only_cols=dfw['symbol'].unique().tolist()) + show_daily_return(dret, legend_only_cols=dfw['symbol'].unique().tolist(), **kwargs) if kwargs.get("show_backtest_detail", False): c1, c2 = st.columns([1, 1]) @@ -419,7 +421,7 @@ def show_splited_daily(df, ret_col, **kwargs): row['开始日期'] = sdt.strftime('%Y-%m-%d') row['结束日期'] = last_dt.strftime('%Y-%m-%d') row['收益名称'] = name - row['绝对收益'] = df1[ret_col].sum() + # row['绝对收益'] = df1[ret_col].sum() rows.append(row) dfv = pd.DataFrame(rows).set_index('收益名称') cols = ['开始日期', '结束日期', '绝对收益', '年化', '夏普', '最大回撤', '卡玛', '年化波动率', '非零覆盖', '日胜率', '盈亏平衡点'] @@ -450,6 +452,66 @@ def show_splited_daily(df, ret_col, **kwargs): st.dataframe(dfv, use_container_width=True) +def show_yearly_stats(df, ret_col, **kwargs): + """按年计算日收益表现 + + :param df: pd.DataFrame,数据源 + :param ret_col: str,收益列名 + :param kwargs: + + - sub_title: str, 子标题 + """ + if not df.index.dtype == 'datetime64[ns]': + df['dt'] = pd.to_datetime(df['dt']) + df.set_index('dt', inplace=True) + + assert df.index.dtype == 'datetime64[ns]', "index必须是datetime64[ns]类型, 请先使用 pd.to_datetime 进行转换" + df = df.copy().fillna(0) + df.sort_index(inplace=True, ascending=True) + + df['年份'] = df.index.year + + _stats = [] + for year, df_ in df.groupby('年份'): + _yst = czsc.daily_performance(df_[ret_col].to_list()) + _yst['年份'] = year + _stats.append(_yst) + + stats = pd.DataFrame(_stats).set_index('年份') + + stats = stats.style.background_gradient(cmap='RdYlGn_r', axis=None, subset=['年化']) + stats = stats.background_gradient(cmap='RdYlGn_r', axis=None, subset=['夏普']) + stats = stats.background_gradient(cmap='RdYlGn_r', axis=None, subset=['绝对收益']) + stats = stats.background_gradient(cmap='RdYlGn', axis=None, subset=['最大回撤']) + stats = stats.background_gradient(cmap='RdYlGn_r', axis=None, subset=['卡玛']) + stats = stats.background_gradient(cmap='RdYlGn', axis=None, subset=['年化波动率']) + stats = stats.background_gradient(cmap='RdYlGn', axis=None, subset=['盈亏平衡点']) + stats = stats.background_gradient(cmap='RdYlGn_r', axis=None, subset=['日胜率']) + stats = stats.background_gradient(cmap='RdYlGn_r', axis=None, subset=['非零覆盖']) + stats = stats.background_gradient(cmap='RdYlGn', axis=None, subset=['新高间隔']) + stats = stats.background_gradient(cmap='RdYlGn_r', axis=None, subset=['新高占比']) + + stats = stats.format( + { + '盈亏平衡点': '{:.2f}', + '年化波动率': '{:.2%}', + '最大回撤': '{:.2%}', + '卡玛': '{:.2f}', + '年化': '{:.2%}', + '夏普': '{:.2f}', + '非零覆盖': '{:.2%}', + '绝对收益': '{:.2%}', + '日胜率': '{:.2%}', + '新高间隔': '{:.2f}', + '新高占比': '{:.2%}', + } + ) + + if kwargs.get('sub_title'): + st.subheader(kwargs.get('sub_title'), divider="rainbow") + st.dataframe(stats, use_container_width=True) + + def show_ts_rolling_corr(df, col1, col2, **kwargs): """时序上按 rolling 的方式计算相关系数 diff --git a/czsc/utils/stats.py b/czsc/utils/stats.py index 5a234838f..86434e48b 100644 --- a/czsc/utils/stats.py +++ b/czsc/utils/stats.py @@ -7,6 +7,7 @@ """ import numpy as np import pandas as pd +from collections import Counter def cal_break_even_point(seq) -> float: @@ -87,7 +88,7 @@ def daily_performance(daily_returns): daily_returns = np.array(daily_returns, dtype=np.float64) if len(daily_returns) == 0 or np.std(daily_returns) == 0 or all(x == 0 for x in daily_returns): - return {"年化": 0, "夏普": 0, "最大回撤": 0, "卡玛": 0, "日胜率": 0, + return {"绝对收益": 0, "年化": 0, "夏普": 0, "最大回撤": 0, "卡玛": 0, "日胜率": 0, "年化波动率": 0, "非零覆盖": 0, "盈亏平衡点": 0, "新高间隔": 0, "新高占比": 0} annual_returns = np.sum(daily_returns) / len(daily_returns) * 252 @@ -100,14 +101,11 @@ def daily_performance(daily_returns): annual_volatility = np.std(daily_returns) * np.sqrt(252) none_zero_cover = len(daily_returns[daily_returns != 0]) / len(daily_returns) - # 计算最大新高时间 - high_index = [i for i, x in enumerate(dd) if x == 0] - max_interval = 0 - for i in range(len(high_index) - 1): - max_interval = max(max_interval, high_index[i + 1] - high_index[i]) + # 计算最大新高间隔 + max_interval = Counter(np.maximum.accumulate(cum_returns).tolist()).most_common(1)[0][1] # 计算新高时间占比 - high_pct = len(high_index) / len(dd) + high_pct = len([i for i, x in enumerate(dd) if x == 0]) / len(dd) def __min_max(x, min_val, max_val, digits=4): if x < min_val: @@ -119,6 +117,7 @@ def __min_max(x, min_val, max_val, digits=4): return round(x1, digits) sta = { + "绝对收益": round(np.sum(daily_returns), 4), "年化": round(annual_returns, 4), "夏普": __min_max(sharpe_ratio, -5, 5, 2), "最大回撤": round(max_drawdown, 4), diff --git a/examples/signals_dev/dropit/bar_trend_V240209.py b/examples/signals_dev/dropit/bar_trend_V240209.py new file mode 100644 index 000000000..0400aa411 --- /dev/null +++ b/examples/signals_dev/dropit/bar_trend_V240209.py @@ -0,0 +1,85 @@ +import numpy as np +from collections import OrderedDict +from czsc.analyze import CZSC +from loguru import logger +from czsc.signals.tas import update_ma_cache, update_macd_cache +from czsc.utils import create_single_signal, get_sub_elements, fast_slow_cross + + +def bar_trend_V240209(c: CZSC, **kwargs) -> OrderedDict: + """趋势跟踪信号 + + 参数模板:"{freq}_D{di}N{N}趋势跟踪_BS辅助V240209" + + **信号逻辑:** + + 以多头为例: + 1. 低点出现在高点之后,且低点右侧的高点到当前K线之间的K线数量在5-30之间; + 2. 低点右侧的K线的DIF值小于前N根K线的DIF值的标准差的一半; + 3. 低点右侧的K线的最低价大于低点的最低价; + 4. 低点右侧的K线的MACD值小于前N根K线的MACD值的标准差的一半。 + + **信号列表:** + + - Signal('60分钟_D1N60趋势跟踪_BS辅助V240209_多头_任意_任意_0') + - Signal('60分钟_D1N60趋势跟踪_BS辅助V240209_空头_任意_任意_0') + + :param c: CZSC对象 + :param kwargs: 参数设置 + + - di: int, default 1, 倒数第几根K线 + - N: int, default 20, 窗口大小 + + :return: 信号识别结果 + """ + di = int(kwargs.get('di', 1)) + N = int(kwargs.get('N', 60)) + + freq = c.freq.value + k1, k2, k3 = f"{freq}_D{di}N{N}趋势跟踪_BS辅助V240209".split('_') + v1 = '其他' + cache_key = update_macd_cache(c) + bars = get_sub_elements(c.bars_raw, di=di, n=N) + max_bar = max(bars, key=lambda x: x.high) + min_bar = min(bars, key=lambda x: x.low) + dif_std = np.std([x.cache[cache_key]['dif'] for x in bars]) + macd_std = np.std([x.cache[cache_key]['macd'] for x in bars]) + + if min_bar.dt < max_bar.dt: + right_bars = [x for x in c.bars_raw if x.dt >= max_bar.dt] + right_min_bar = min(right_bars, key=lambda x: x.low) + c1 = 30 > right_min_bar.id - max_bar.id > 5 + c2 = abs(right_bars[-1].cache[cache_key]['dif']) < dif_std # type: ignore + c3 = right_min_bar.low > min_bar.low + c4 = abs(right_bars[-1].cache[cache_key]['macd']) < macd_std # type: ignore + + if c1 and c2 and c3 and c4: + return create_single_signal(k1=k1, k2=k2, k3=k3, v1="多头") + + if min_bar.dt > max_bar.dt: + right_bars = [x for x in c.bars_raw if x.dt >= min_bar.dt] + right_max_bar = max(right_bars, key=lambda x: x.high) + c1 = 30 > right_max_bar.id - min_bar.id > 5 + c2 = abs(right_bars[-1].cache[cache_key]['dif']) < dif_std # type: ignore + c3 = right_max_bar.high < max_bar.high + c4 = abs(right_bars[-1].cache[cache_key]['macd']) < macd_std # type: ignore + + if c1 and c2 and c3 and c4: + return create_single_signal(k1=k1, k2=k2, k3=k3, v1="空头") + + return create_single_signal(k1=k1, k2=k2, k3=k3, v1=v1) + + +def check(): + from czsc.connectors import research + from czsc.traders.base import check_signals_acc + + symbols = research.get_symbols('A股主要指数') + bars = research.get_raw_bars(symbols[0], '15分钟', '20181101', '20210101', fq='前复权') + + signals_config = [{'name': bar_trend_V240209, 'freq': "60分钟", 'N': 60}] + check_signals_acc(bars, signals_config=signals_config, height='780px', delta_days=1) # type: ignore + + +if __name__ == '__main__': + check() diff --git a/examples/signals_dev/dropit/tas_double_ma_V240208.py b/examples/signals_dev/dropit/tas_double_ma_V240208.py new file mode 100644 index 000000000..99e00c988 --- /dev/null +++ b/examples/signals_dev/dropit/tas_double_ma_V240208.py @@ -0,0 +1,72 @@ +from collections import OrderedDict +from czsc.analyze import CZSC +from czsc.signals.tas import update_ma_cache +from czsc.utils import create_single_signal, get_sub_elements, fast_slow_cross + + +def tas_double_ma_V240208(c: CZSC, **kwargs) -> OrderedDict: + """双均线多空信号,辅助V240208 + + 参数模板:"{freq}_D{di}N{N}M{M}双均线_BS辅助V240208" + + **信号逻辑:** + + 1. 找出最近3个均线交叉点,时间上由远到近,分别为 X1,X2,X3 + 2. 以多头为例:X3 和 X1 为金叉,且 X2 的价格最高 + + **信号列表:** + + - Signal('60分钟_D1N5M21双均线_BS辅助V240208_多头_任意_任意_0') + - Signal('60分钟_D1N5M21双均线_BS辅助V240208_空头_任意_任意_0') + + :param c: CZSC对象 + :param kwargs: 参数设置 + + - di: int, default 1, 倒数第几根K线 + - N: int, default 20, 快线周期 + - M: int, default 60, 慢线周期 + + :return: 信号识别结果 + """ + di = int(kwargs.get('di', 1)) + N = int(kwargs.get('N', 20)) + M = int(kwargs.get('M', 60)) + assert N < M, "N < M" + + freq = c.freq.value + k1, k2, k3 = f"{freq}_D{di}N{N}M{M}双均线_BS辅助V240208".split('_') + v1 = '其他' + fast_ma_key = update_ma_cache(c, ma_type='SMA', timeperiod=N) + slow_ma_key = update_ma_cache(c, ma_type='SMA', timeperiod=M) + + bars = get_sub_elements(c.bars_raw, di=di, n=M * 30) + fast_ma = [x.cache[fast_ma_key] for x in bars] + slow_ma = [x.cache[slow_ma_key] for x in bars] + cross_info = fast_slow_cross(fast_ma, slow_ma) + + if len(cross_info) < 3: + return create_single_signal(k1=k1, k2=k2, k3=k3, v1=v1) + + x1, x2, x3 = cross_info[-3:] + if x3['类型'] == "金叉" and x2['快线'] > max(x1['快线'], x3['快线']): + return create_single_signal(k1=k1, k2=k2, k3=k3, v1='多头') + + if x3['类型'] == "死叉" and x2['快线'] < min(x1['快线'], x3['快线']): + return create_single_signal(k1=k1, k2=k2, k3=k3, v1='空头') + + return create_single_signal(k1=k1, k2=k2, k3=k3, v1=v1) + + +def check(): + from czsc.connectors import research + from czsc.traders.base import check_signals_acc + + symbols = research.get_symbols('A股主要指数') + bars = research.get_raw_bars(symbols[0], '15分钟', '20181101', '20210101', fq='前复权') + + signals_config = [{'name': tas_double_ma_V240208, 'freq': "60分钟", 'N': 5, 'M': 21}] + check_signals_acc(bars, signals_config=signals_config, height='780px', delta_days=5) # type: ignore + + +if __name__ == '__main__': + check() diff --git a/examples/signals_dev/tas_macd_bc_V230803.py b/examples/signals_dev/dropit/tas_macd_bc_V230803.py similarity index 100% rename from examples/signals_dev/tas_macd_bc_V230803.py rename to examples/signals_dev/dropit/tas_macd_bc_V230803.py diff --git a/examples/signals_dev/tas_macd_bc_V230804.py b/examples/signals_dev/dropit/tas_macd_bc_V230804.py similarity index 100% rename from examples/signals_dev/tas_macd_bc_V230804.py rename to examples/signals_dev/dropit/tas_macd_bc_V230804.py diff --git a/examples/signals_dev/tas_macd_bc_ubi_V230804.py b/examples/signals_dev/dropit/tas_macd_bc_ubi_V230804.py similarity index 100% rename from examples/signals_dev/tas_macd_bc_ubi_V230804.py rename to examples/signals_dev/dropit/tas_macd_bc_ubi_V230804.py diff --git a/examples/test_offline/test_cooperation.py b/examples/test_offline/test_cooperation.py new file mode 100644 index 000000000..a1b1b447e --- /dev/null +++ b/examples/test_offline/test_cooperation.py @@ -0,0 +1,40 @@ +import sys +sys.path.insert(0, r'D:\ZB\git_repo\waditu\czsc') +import czsc +from czsc.connectors.cooperation import * + +czsc.welcome() + + +def test_cooperation(): + # 获取股票列表 + symbols = get_symbols(name="股票") + print(f"股票数量:{len(symbols)}") + + # 获取日线数据 + kline = get_raw_bars(symbol="000001.SZ#STOCK", freq="日线", sdt="20220101", edt="20230101", fq="前复权") + print(kline[-5:]) + + # 获取60分钟数据 + kline = get_raw_bars(symbol="000001.SZ#STOCK", freq="60分钟", sdt="20220101", edt="20230101", fq="前复权") + print(kline[-10:]) + + # 获取ETF列表 + symbols = get_symbols(name="ETF") + print(f"ETF数量:{len(symbols)}") + + # 获取指数列表 + symbols = get_symbols(name="A股指数") + print(f"指数数量:{len(symbols)}") + + # 获取南华指数列表 + symbols = get_symbols(name="南华指数") + print(f"南华指数数量:{len(symbols)}") + + # 获取期货主力列表 + symbols = get_symbols(name="期货主力") + print(f"期货主力数量:{len(symbols)}") + + # 获取日线数据 + kline = get_raw_bars(symbol="SFIC9001", freq="日线", sdt="20210101", edt="20230101", fq="前复权") + kline = get_raw_bars(symbol="SFIC9001", freq="30分钟", sdt="20220101", edt="20230101", fq="前复权") diff --git a/test/test_features.py b/test/test_features.py new file mode 100644 index 000000000..e5393e073 --- /dev/null +++ b/test/test_features.py @@ -0,0 +1,14 @@ +import pytest +import pandas as pd + + +def test_is_event_feature(): + from czsc.features.utils import is_event_feature + + # 测试事件类因子 + df1 = pd.DataFrame({'factor': [0, 1, -1, 0, 1, -1]}) + assert is_event_feature(df1, 'factor') is True + + # 测试非事件类因子 + df2 = pd.DataFrame({'factor': [0, 1, 2, 3, 4, 5]}) + assert is_event_feature(df2, 'factor') is False diff --git a/test/test_utils.py b/test/test_utils.py index 880048378..f35064f7d 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -96,28 +96,28 @@ def test_daily_performance(): # Test case 1: empty daily returns result = daily_performance([]) - assert result == {'年化': 0, '夏普': 0, '最大回撤': 0, '卡玛': 0, '日胜率': 0, '年化波动率': 0, + assert result == {'绝对收益': 0, '年化': 0, '夏普': 0, '最大回撤': 0, '卡玛': 0, '日胜率': 0, '年化波动率': 0, '非零覆盖': 0, '盈亏平衡点': 0, '新高间隔': 0, '新高占比': 0} # Test case 2: daily returns with zero standard deviation result = daily_performance([1, 1, 1, 1, 1]) - assert result == {"年化": 0, "夏普": 0, "最大回撤": 0, "卡玛": 0, "日胜率": 0, + assert result == {'绝对收益': 0, "年化": 0, "夏普": 0, "最大回撤": 0, "卡玛": 0, "日胜率": 0, "年化波动率": 0, "非零覆盖": 0, "盈亏平衡点": 0, '新高间隔': 0, '新高占比': 0} # Test case 3: daily returns with all zeros result = daily_performance([0, 0, 0, 0, 0]) - assert result == {"年化": 0, "夏普": 0, "最大回撤": 0, "卡玛": 0, "日胜率": 0, + assert result == {'绝对收益': 0, "年化": 0, "夏普": 0, "最大回撤": 0, "卡玛": 0, "日胜率": 0, "年化波动率": 0, "非零覆盖": 0, "盈亏平衡点": 0, '新高间隔': 0, '新高占比': 0} # Test case 4: normal daily returns daily_returns = np.array([0.01, 0.02, -0.01, 0.03, 0.02, -0.02, 0.01, -0.01, 0.02, 0.01]) result = daily_performance(daily_returns) - assert result == {'年化': 2.016, '夏普': 5, '最大回撤': 0.02, '卡玛': 10, '日胜率': 0.7, '年化波动率': 0.2439, - '非零覆盖': 1.0, '盈亏平衡点': 0.7, '新高间隔': 4, '新高占比': 0.6} + assert result == {'绝对收益': 0.08, '年化': 2.016, '夏普': 5, '最大回撤': 0.02, '卡玛': 10, '日胜率': 0.7, '年化波动率': 0.2439, + '非零覆盖': 1.0, '盈亏平衡点': 0.7, '新高间隔': 5, '新高占比': 0.6} result = daily_performance([0.01, 0.02, -0.01, 0.03, 0.02, -0.02, 0.01, -0.01, 0.02, 0.01]) - assert result == {'年化': 2.016, '夏普': 5, '最大回撤': 0.02, '卡玛': 10, '日胜率': 0.7, '年化波动率': 0.2439, - '非零覆盖': 1.0, '盈亏平衡点': 0.7, '新高间隔': 4, '新高占比': 0.6} + assert result == {'绝对收益': 0.08, '年化': 2.016, '夏普': 5, '最大回撤': 0.02, '卡玛': 10, '日胜率': 0.7, '年化波动率': 0.2439, + '非零覆盖': 1.0, '盈亏平衡点': 0.7, '新高间隔': 5, '新高占比': 0.6} def test_find_most_similarity(): diff --git a/test/test_utils_cache.py b/test/test_utils_cache.py index 15be44bf9..4a01e86b3 100644 --- a/test/test_utils_cache.py +++ b/test/test_utils_cache.py @@ -12,12 +12,34 @@ def run_func_x(x): return x * 2 +@disk_cache(path=temp_path, suffix="txt", ttl=100) +def run_func_text(x): + return f"hello {x}" + + +@disk_cache(path=temp_path, suffix="json", ttl=100) +def run_func_json(x): + return {"a": 1, "b": 2, "x": x} + + @disk_cache(path=temp_path, suffix="xlsx", ttl=100) def run_func_y(x): df = pd.DataFrame({'a': [1, 2, 3], 'b': [4, 5, 6], 'x': [x, x, x]}) return df +@disk_cache(path=temp_path, suffix="feather", ttl=100) +def run_feather(x): + df = pd.DataFrame({'a': [1, 2, 3], 'b': [4, 5, 6], 'x': [x, x, x]}) + return df + + +@disk_cache(path=temp_path, suffix="parquet", ttl=100) +def run_parquet(x): + df = pd.DataFrame({'a': [1, 2, 3], 'b': [4, 5, 6], 'x': [x, x, x]}) + return df + + def test_disk_cache(): # Call the function result = run_func_x(5) @@ -31,6 +53,24 @@ def test_disk_cache(): # Check if the output is still correct assert result == 10 + # Call the function with a different argument + result = run_func_text(6) + result = run_func_text(6) + assert result == "hello 6" + + # Call the function with a different argument + result = run_func_json(7) + result = run_func_json(7) + assert result == {"a": 1, "b": 2, "x": 7} + + result = run_feather(8) + result = run_feather(8) + assert isinstance(result, pd.DataFrame) + + result = run_parquet(9) + result = run_parquet(9) + assert isinstance(result, pd.DataFrame) + # Check if the cache file exists files = os.listdir(os.path.join(temp_path, "run_func_x")) assert len(files) == 1