From 8daa12184c496a6a89da15c7a21af55ed69440cd Mon Sep 17 00:00:00 2001 From: zengbin93 Date: Wed, 1 Jan 2025 10:08:32 +0800 Subject: [PATCH] =?UTF-8?q?V0.9.62=20=E6=9B=B4=E6=96=B0=E4=B8=80=E6=89=B9?= =?UTF-8?q?=E4=BB=A3=E7=A0=81=20(#220)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * 0.9.62 start coding * 0.9.62 test rs_czsc * 0.9.62 增加增量缓存数据函数 * 0.9.62 增加增量缓存数据函数 * 0.9.62 优化 streamlit 组件 * 0.9.62 新增 show_weight_distribution 组件 * 0.9.62 新增 CZSC_DATA_API 环境变量 * 0.9.62 rs_czsc * 0.9.62 rs_czsc * 0.9.62 不再支持 python 3.7 * 0.9.62 不再支持 python 3.7 * 0.9.62 remove pandas_ta * 0.9.62 remove pandas_ta * 0.9.62 fix bug * 0.9.62 新增K线质量检查工具 * 0.9.62 update limit leverage * 0.9.62 新增 clickhouse weights client * 0.9.62 fix check_kline_quality * 0.9.62 fix Optional * 0.9.62 fix Optional * 0.9.62 fix Optional --- .github/workflows/pythonpackage.yml | 6 +- czsc/__init__.py | 14 +- czsc/connectors/cooperation.py | 185 ++++++- czsc/eda.py | 28 +- czsc/strategies.py | 2 +- czsc/traders/cwc.py | 485 ++++++++++++++++++ czsc/utils/kline_quality.py | 325 ++++++++++-- czsc/utils/st_components.py | 108 +++- czsc/utils/ta.py | 3 +- ...50\351\207\217\346\243\200\346\237\245.py" | 482 +++++++++++++++++ examples/develop/weight_backtest.py | 23 +- requirements.txt | 5 +- setup.py | 1 - test/test_kline_quality.py | 39 +- 14 files changed, 1574 insertions(+), 132 deletions(-) create mode 100644 czsc/traders/cwc.py create mode 100644 "examples/develop/K\347\272\277\346\225\260\346\215\256\350\264\250\351\207\217\346\243\200\346\237\245.py" diff --git a/.github/workflows/pythonpackage.yml b/.github/workflows/pythonpackage.yml index 30a280f95..8c618e84b 100644 --- a/.github/workflows/pythonpackage.yml +++ b/.github/workflows/pythonpackage.yml @@ -5,7 +5,7 @@ name: Python package on: push: - branches: [ master, 'V0.9.61' ] + branches: [ master, 'V0.9.62' ] pull_request: branches: [ master ] @@ -15,7 +15,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: [3.7, 3.8, 3.9, '3.10', '3.11'] + python-version: [3.8, 3.9, '3.10', '3.11', '3.12'] steps: - uses: actions/checkout@v2 @@ -30,7 +30,7 @@ jobs: - name: Install dependencies run: | python -m pip install --upgrade pip - pip install lxml==4.9.2 + pip install lxml pip install -r requirements.txt - name: Lint with flake8 run: | diff --git a/czsc/__init__.py b/czsc/__init__.py index 6035a6d82..d7e445869 100644 --- a/czsc/__init__.py +++ b/czsc/__init__.py @@ -164,6 +164,7 @@ show_classify, show_df_describe, show_date_effect, + show_weight_distribution, ) from czsc.utils.bi_info import ( @@ -196,13 +197,8 @@ ) -from czsc.utils.kline_quality import ( - check_high_low, - check_price_gap, - check_abnormal_volume, - check_zero_volume, -) - +from czsc.utils.kline_quality import check_kline_quality +from czsc.traders import cwc from czsc.utils.portfolio import ( max_sharp, @@ -225,10 +221,10 @@ ) -__version__ = "0.9.61" +__version__ = "0.9.62" __author__ = "zengbin93" __email__ = "zeng_bin8888@163.com" -__date__ = "20241101" +__date__ = "20241208" def welcome(): diff --git a/czsc/connectors/cooperation.py b/czsc/connectors/cooperation.py index 75329d9dc..0f9e5a82d 100644 --- a/czsc/connectors/cooperation.py +++ b/czsc/connectors/cooperation.py @@ -4,15 +4,15 @@ email: zeng_bin8888@163.com create_dt: 2023/11/15 20:45 describe: CZSC开源协作团队内部使用数据接口 - -接口说明:https://s0cqcxuy3p.feishu.cn/wiki/StQbwOrWdiJPpikET9EcrRVEnrd """ import os +import time import czsc import requests import loguru import pandas as pd from tqdm import tqdm 
+from pathlib import Path from datetime import datetime from czsc import RawBar, Freq @@ -20,7 +20,8 @@ # czsc.set_url_token(token='your token', url='http://zbczsc.com:9106') cache_path = os.getenv("CZSC_CACHE_PATH", os.path.expanduser("~/.quant_data_cache")) -dc = czsc.DataClient(token=os.getenv("CZSC_TOKEN"), url="http://zbczsc.com:9106", cache_path=cache_path) +url = os.getenv("CZSC_DATA_API", "http://zbczsc.com:9106") +dc = czsc.DataClient(token=os.getenv("CZSC_TOKEN"), url=url, cache_path=cache_path) def get_groups(): @@ -315,3 +316,181 @@ def get_stk_strategy(name="STK_001", **kwargs): dfw = pd.merge(dfw, dfb, on=["dt", "symbol"], how="left") dfh = dfw[["dt", "symbol", "weight", "n1b"]].copy() return dfh + + +# ====================================================================================================================== +# 增量更新本地缓存数据 +# ====================================================================================================================== +def get_all_strategies(ttl=3600 * 24 * 7, logger=loguru.logger, path=cache_path): + """获取所有策略的元数据 + + :param ttl: int, optional, 缓存时间,单位秒,默认为 7 天 + :param logger: loguru.logger, optional, 日志记录器 + :param path: str, optional, 缓存路径 + :return: pd.DataFrame, 包含字段 name, description, author, base_freq, outsample_sdt;示例如下: + + =========== ===================== ========= ========= ============ + name description author base_freq outsample_sdt + =========== ===================== ========= ========= ============ + STK_001 A股选股策略 ZB 1分钟 20220101 + STK_002 A股选股策略 ZB 1分钟 20220101 + STK_003 A股选股策略 ZB 1分钟 20220101 + =========== ===================== ========= ========= ============ + """ + path = Path(path) / "strategy" + path.mkdir(exist_ok=True, parents=True) + file_metas = path / "metas.feather" + + if file_metas.exists() and (time.time() - file_metas.stat().st_mtime) < ttl: + logger.info("【缓存命中】获取所有策略的元数据") + dfm = pd.read_feather(file_metas) + + else: + logger.info("【全量刷新】获取所有策略的元数据并刷新缓存") + dfm = dc.get_all_strategies(v=2, ttl=0) + dfm.to_feather(file_metas) + + return dfm + + +def __update_strategy_dailys(file_cache, strategy, logger=loguru.logger): + """更新策略的日收益数据""" + # 刷新缓存数据 + if file_cache.exists(): + df = pd.read_feather(file_cache) + + cache_sdt = (df["dt"].max() - pd.Timedelta(days=3)).strftime("%Y%m%d") + cache_edt = (pd.Timestamp.now() + pd.Timedelta(days=1)).strftime("%Y%m%d") + logger.info(f"【增量刷新缓存】获取策略 {strategy} 的日收益数据:{cache_sdt} - {cache_edt}") + + dfc = dc.sub_strategy_dailys(strategy=strategy, v=2, sdt=cache_sdt, edt=cache_edt, ttl=0) + dfc["dt"] = pd.to_datetime(dfc["dt"]) + df = pd.concat([df, dfc]).drop_duplicates(["dt", "symbol", "strategy"], keep="last") + + else: + cache_edt = (pd.Timestamp.now() + pd.Timedelta(days=1)).strftime("%Y%m%d") + logger.info(f"【全量刷新缓存】获取策略 {strategy} 的日收益数据:20170101 - {cache_edt}") + df = dc.sub_strategy_dailys(strategy=strategy, v=2, sdt="20170101", edt=cache_edt, ttl=0) + + df = df.reset_index(drop=True) + df["dt"] = pd.to_datetime(df["dt"]) + df.to_feather(file_cache) + return df + + +def get_strategy_dailys( + strategy="FCS001", symbol=None, sdt="20240101", edt=None, logger=loguru.logger, path=cache_path +): + """获取策略的历史日收益数据 + + :param strategy: 策略名称 + :param symbol: 品种名称 + :param sdt: 开始时间 + :param edt: 结束时间 + :param logger: loguru.logger, optional, 日志记录器 + :param path: str, optional, 缓存路径 + :return: pd.DataFrame, 包含字段 dt, symbol, strategy, returns;示例如下: + + =================== ========== ======== ========= + dt strategy symbol returns + =================== ========== ======== 
========= + 2017-01-10 00:00:00 STK_001 A股选股 0.001 + 2017-01-11 00:00:00 STK_001 A股选股 0.012 + 2017-01-12 00:00:00 STK_001 A股选股 0.011 + =================== ========== ======== ========= + """ + path = Path(path) / "strategy" / "dailys" + path.mkdir(exist_ok=True, parents=True) + file_cache = path / f"{strategy}.feather" + + if edt is None: + edt = pd.Timestamp.now().strftime("%Y%m%d %H:%M:%S") + + # 判断缓存数据是否能满足需求 + if file_cache.exists(): + df = pd.read_feather(file_cache) + + if df["dt"].max() >= pd.Timestamp(edt): + logger.info(f"【缓存命中】获取策略 {strategy} 的日收益数据:{sdt} - {edt}") + + dfd = df[(df["dt"] >= pd.Timestamp(sdt)) & (df["dt"] <= pd.Timestamp(edt))].copy() + if symbol: + dfd = dfd[dfd["symbol"] == symbol].copy() + return dfd + + # 刷新缓存数据 + logger.info(f"【缓存刷新】获取策略 {strategy} 的日收益数据:{sdt} - {edt}") + df = __update_strategy_dailys(file_cache, strategy, logger=logger) + dfd = df[(df["dt"] >= pd.Timestamp(sdt)) & (df["dt"] <= pd.Timestamp(edt))].copy() + if symbol: + dfd = dfd[dfd["symbol"] == symbol].copy() + return dfd + + +def __update_strategy_weights(file_cache, strategy, logger=loguru.logger): + """更新策略的持仓权重数据""" + # 刷新缓存数据 + if file_cache.exists(): + df = pd.read_feather(file_cache) + + cache_sdt = (df["dt"].max() - pd.Timedelta(days=3)).strftime("%Y%m%d") + cache_edt = (pd.Timestamp.now() + pd.Timedelta(days=1)).strftime("%Y%m%d") + logger.info(f"【增量刷新缓存】获取策略 {strategy} 的持仓权重数据:{cache_sdt} - {cache_edt}") + + dfc = dc.post_request(api_name=strategy, v=2, sdt=cache_sdt, edt=cache_edt, hist=1, ttl=0) + dfc["dt"] = pd.to_datetime(dfc["dt"]) + dfc["strategy"] = strategy + + df = pd.concat([df, dfc]).drop_duplicates(["dt", "symbol", "weight"], keep="last") + + else: + cache_edt = (pd.Timestamp.now() + pd.Timedelta(days=1)).strftime("%Y%m%d") + logger.info(f"【全量刷新缓存】获取策略 {strategy} 的持仓权重数据:20170101 - {cache_edt}") + df = dc.post_request(api_name=strategy, v=2, sdt="20170101", edt=cache_edt, hist=1, ttl=0) + df["dt"] = pd.to_datetime(df["dt"]) + df["strategy"] = strategy + + df = df.reset_index(drop=True) + df.to_feather(file_cache) + return df + + +def get_strategy_weights(strategy="FCS001", sdt="20240101", edt=None, logger=loguru.logger, path=cache_path): + """获取策略的历史持仓权重数据 + + :param strategy: 策略名称 + :param sdt: 开始时间 + :param edt: 结束时间 + :param logger: loguru.logger, optional, 日志记录器 + :param path: str, optional, 缓存路径 + :return: pd.DataFrame, 包含字段 dt, symbol, weight, update_time, strategy;示例如下: + + =================== ========= ======== =================== ========== + dt symbol weight update_time strategy + =================== ========= ======== =================== ========== + 2017-01-09 00:00:00 000001.SZ 0 2024-07-27 16:13:29 STK_001 + 2017-01-10 00:00:00 000001.SZ 0 2024-07-27 16:13:29 STK_001 + 2017-01-11 00:00:00 000001.SZ 0 2024-07-27 16:13:29 STK_001 + =================== ========= ======== =================== ========== + """ + path = Path(path) / "strategy" / "weights" + path.mkdir(exist_ok=True, parents=True) + file_cache = path / f"{strategy}.feather" + + if edt is None: + edt = pd.Timestamp.now().strftime("%Y%m%d %H:%M:%S") + + # 判断缓存数据是否能满足需求 + if file_cache.exists(): + df = pd.read_feather(file_cache) + + if df["dt"].max() >= pd.Timestamp(edt): + logger.info(f"【缓存命中】获取策略 {strategy} 的历史持仓权重数据:{sdt} - {edt}") + dfd = df[(df["dt"] >= pd.Timestamp(sdt)) & (df["dt"] <= pd.Timestamp(edt))].copy() + return dfd + + # 刷新缓存数据 + logger.info(f"【缓存刷新】获取策略 {strategy} 的历史持仓权重数据:{sdt} - {edt}") + df = __update_strategy_weights(file_cache, strategy, logger=logger) + dfd = df[(df["dt"] >= 
pd.Timestamp(sdt)) & (df["dt"] <= pd.Timestamp(edt))].copy() + return dfd diff --git a/czsc/eda.py b/czsc/eda.py index 86060c06c..910a157a1 100644 --- a/czsc/eda.py +++ b/czsc/eda.py @@ -554,22 +554,38 @@ def limit_leverage(df: pd.DataFrame, leverage: float = 1.0, **kwargs): - window: int, 滚动窗口,默认为 300 - min_periods: int, 最小样本数,小于该值的窗口不计算均值,默认为 50 - weight: str, 权重列名,默认为 'weight' + - method: str, 计算均值的方法,'abs_mean' 或 'abs_max',默认为 'abs_mean' + abs_mean: 计算绝对均值作为调整杠杆的标准 + abs_max: 计算绝对最大值作为调整杠杆的标准 :return: DataFrame """ window = kwargs.get("window", 300) min_periods = kwargs.get("min_periods", 50) weight = kwargs.get("weight", "weight") + method = kwargs.get("method", "abs_mean") assert weight in df.columns, f"数据中不包含权重列 {weight}" - assert df['symbol'].nunique() == 1, "数据中包含多个品种,必须单品种" - assert df['dt'].is_monotonic_increasing, "数据未按日期排序,必须升序排列" - assert df['dt'].is_unique, "数据中存在重复dt,必须唯一" if kwargs.get("copy", False): df = df.copy() - abs_mean = df[weight].abs().rolling(window=window, min_periods=min_periods).mean().fillna(leverage) - adjust_ratio = leverage / abs_mean - df[weight] = (df[weight] * adjust_ratio).clip(-leverage, leverage) + df = df.sort_values(["dt", "symbol"], ascending=True).reset_index(drop=True) + + for symbol in df['symbol'].unique(): + dfx = df[df['symbol'] == symbol].copy() + # assert dfx['dt'].is_monotonic_increasing, f"{symbol} 数据未按日期排序,必须升序排列" + assert dfx['dt'].is_unique, f"{symbol} 数据中存在重复dt,必须唯一" + + if method == "abs_mean": + bench = dfx[weight].abs().rolling(window=window, min_periods=min_periods).mean().fillna(leverage) + elif method == "abs_max": + bench = dfx[weight].abs().rolling(window=window, min_periods=min_periods).max().fillna(leverage) + else: + raise ValueError(f"不支持的 method: {method}") + + adjust_ratio = leverage / bench + df.loc[df['symbol'] == symbol, weight] = (dfx[weight] * adjust_ratio).clip(-leverage, leverage) + return df + diff --git a/czsc/strategies.py b/czsc/strategies.py index c81e0fdcc..5b437956e 100644 --- a/czsc/strategies.py +++ b/czsc/strategies.py @@ -101,7 +101,7 @@ def init_bar_generator(self, bars: List[RawBar], **kwargs): :return: """ base_freq = str(bars[0].freq.value) - bg: BarGenerator = kwargs.get("bg", None) + bg: BarGenerator = kwargs.pop("bg", None) freqs = self.sorted_freqs[1:] if base_freq in self.sorted_freqs else self.sorted_freqs if bg is None: diff --git a/czsc/traders/cwc.py b/czsc/traders/cwc.py new file mode 100644 index 000000000..df979e329 --- /dev/null +++ b/czsc/traders/cwc.py @@ -0,0 +1,485 @@ +# -*- coding: utf-8 -*- +""" +author: zengbin93 +email: zeng_bin8888@163.com +create_dt: 2024/12/30 15:19 +describe: 基于 clickhouse 的策略持仓权重管理,cwc 为 clickhouse weights client 的缩写 + +推荐在环境变量中设置 clickhouse 的连接信息,如下: + +- CLICKHOUSE_HOST: 服务器地址,如 127.0.0.1 +- CLICKHOUSE_PORT: 服务器端口,如 9000 +- CLICKHOUSE_USER: 用户名, 如 default +- CLICKHOUSE_PASS: 密码, 如果没有密码,可以设置为空字符串 + +""" +# pip install clickhouse_connect -i https://pypi.tuna.tsinghua.edu.cn/simple +import os +import loguru +import pandas as pd +import clickhouse_connect as ch +from clickhouse_connect.driver import Client +from typing import Optional + + +def __db_from_env(): + host = os.getenv("CLICKHOUSE_HOST") + port = int(os.getenv("CLICKHOUSE_PORT")) + user = os.getenv("CLICKHOUSE_USER") + password = os.getenv("CLICKHOUSE_PASS") + + if not (host and port and user and password): + raise ValueError( + """ + 请设置环境变量:CLICKHOUSE_HOST, CLICKHOUSE_PORT, CLICKHOUSE_USER, CLICKHOUSE_PASS + + - CLICKHOUSE_HOST: 服务器地址,如 127.0.0.1 + - CLICKHOUSE_PORT: 服务器端口,如 9000 + - 
CLICKHOUSE_USER: 用户名, 如 default + - CLICKHOUSE_PASS: 密码, 如果没有密码,可以设置为空字符串 + """ + ) + + db = ch.get_client(host=host, port=port, user=user, password=password) + return db + + +def init_tables(db: Optional[Client] = None, **kwargs): + """ + 创建数据库表 + + :param db: clickhouse_connect.driver.Client, 数据库连接 + :param kwargs: dict, 数据表名和建表语句 + :return: None + """ + db = db or __db_from_env() + database = "czsc_strategy" + + # 创建数据库 + db.command(f"CREATE DATABASE IF NOT EXISTS {database}") + + metas_table = f""" + CREATE TABLE IF NOT EXISTS {database}.metas ( + strategy String NOT NULL, -- 策略名(唯一且不能为空) + base_freq String, -- 周期 + description String, -- 描述 + author String, -- 作者 + outsample_sdt DateTime, -- 样本外起始时间 + create_time DateTime, -- 策略入库时间 + update_time DateTime, -- 策略更新时间 + heartbeat_time DateTime, -- 最后一次心跳时间 + weight_type String, -- 策略上传的权重类型,ts 或 cs + memo String -- 策略备忘信息 + ) + ENGINE = ReplacingMergeTree() + ORDER BY strategy; + """ + + weights_table = f""" + CREATE TABLE IF NOT EXISTS {database}.weights ( + dt DateTime, -- 持仓权重时间 + symbol String, -- 符号(例如,股票代码或其他标识符) + weight Float64, -- 策略持仓权重值 + strategy String, -- 策略名称 + update_time DateTime -- 持仓权重更新时间 + ) + ENGINE = ReplacingMergeTree() + ORDER BY (strategy, dt, symbol); + """ + + latest_weights_view = f""" + CREATE VIEW IF NOT EXISTS {database}.latest_weights AS + SELECT + strategy, + symbol, + argMax(dt, dt) as latest_dt, + argMax(weight, dt) as latest_weight, + argMax(update_time, dt) as latest_update_time + FROM {database}.weights + GROUP BY strategy, symbol; + """ + + returns_table = f""" + CREATE TABLE IF NOT EXISTS {database}.returns ( + dt DateTime, -- 时间 + symbol String, -- 符号(例如,股票代码或其他标识符) + returns Float64, -- 策略收益,从上一个 dt 到当前 dt 的收益 + strategy String, -- 策略名称 + update_time DateTime -- 更新时间 + ) + ENGINE = ReplacingMergeTree() + ORDER BY (strategy, dt, symbol); + """ + + db.command(metas_table) + db.command(weights_table) + db.command(latest_weights_view) + db.command(returns_table) + + print("数据表创建成功!") + + +def get_meta(strategy, db: Optional[Client] = None, logger=loguru.logger) -> dict: + """获取策略元数据 + + :param db: clickhouse_connect.driver.Client, 数据库连接 + :param strategy: str, 策略名称 + :param logger: loguru.logger, 日志记录器 + :return: pd.DataFrame + """ + db = db or __db_from_env() + + query = f""" + SELECT * FROM czsc_strategy.metas final WHERE strategy = '{strategy}' + """ + df = db.query_df(query) + if df.empty: + logger.warning(f"策略 {strategy} 不存在元数据") + return {} + else: + assert len(df) == 1, f"策略 {strategy} 存在多条元数据,请检查" + return df.iloc[0].to_dict() + + +def get_all_metas(db: Optional[Client] = None) -> pd.DataFrame: + """获取所有策略元数据 + + :param db: clickhouse_connect.driver.Client, 数据库连接 + :return: pd.DataFrame + """ + db = db or __db_from_env() + df = db.query_df("SELECT * FROM czsc_strategy.metas final") + return df + + +def set_meta( + strategy, + base_freq, + description, + author, + outsample_sdt, + weight_type="ts", + memo="", + logger=loguru.logger, + overwrite=False, + db: Optional[Client] = None, +): + """设置策略元数据 + + :param db: clickhouse_connect.driver.Client, 数据库连接 + :param strategy: str, 策略名 + :param base_freq: str, 周期 + :param description: str, 描述 + :param author: str, 作者 + :param outsample_sdt: str, 样本外起始时间 + :param weight_type: str, 权重类型,ts 或 cs + :param memo: str, 备注 + :param logger: loguru.logger, 日志记录器 + :param overwrite: bool, 是否覆盖已有元数据 + :return: None + """ + db = db or __db_from_env() + + outsample_sdt = pd.to_datetime(outsample_sdt).tz_localize(None) + current_time = 
pd.to_datetime("now").tz_localize(None) + meta = get_meta(db=db, strategy=strategy) + + if not overwrite and meta: + logger.warning(f"策略 {strategy} 已存在元数据,如需更新请设置 overwrite=True") + return + + # create_time 在任何情况下都不会被覆盖,只有元数据不存在时才会设置 + create_time = current_time if not meta else pd.to_datetime(meta["create_time"]) + + # 构建DataFrame用于插入 + df = pd.DataFrame( + [ + { + "strategy": strategy, + "base_freq": base_freq, + "description": description, + "author": author, + "outsample_sdt": outsample_sdt, + "create_time": create_time, + "update_time": current_time, + "heartbeat_time": current_time, + "weight_type": weight_type, + "memo": memo, + } + ] + ) + res = db.insert_df("czsc_strategy.metas", df) + logger.info(f"{strategy} set_metadata: {res.summary}") + + +def __send_heartbeat(db: ch.driver.Client, strategy, logger=loguru.logger): + """发送心跳 + + :param db: clickhouse_connect.driver.Client, 数据库连接 + :param strategy: str, 策略名称 + :param logger: loguru.logger, 日志记录器 + :return: None + """ + try: + meta = get_meta(db=db, strategy=strategy) + if not meta: + logger.warning(f"策略 {strategy} 不存在元数据,无法发送心跳") + return + + current_time = pd.to_datetime("now").strftime("%Y-%m-%d %H:%M:%S") + db.command( + f"ALTER TABLE czsc_strategy.metas UPDATE heartbeat_time = '{current_time}' WHERE strategy = '{strategy}'" + ) + logger.info(f"策略 {strategy} 发送心跳成功") + + except Exception as e: + logger.error(f"发送心跳失败: {e}") + raise + + +def get_strategy_weights(strategy, db: Optional[Client] = None, sdt=None, edt=None, symbols=None): + """获取策略持仓权重 + + :param db: clickhouse_connect.driver.Client, 数据库连接 + :param strategy: str, 策略名称 + :param sdt: str, 开始时间 + :param edt: str, 结束时间 + :param symbols: list, 符号列表 + :return: pd.DataFrame + """ + db = db or __db_from_env() + + query = f""" + SELECT * FROM czsc_strategy.weights final WHERE strategy = '{strategy}' + """ + if sdt: + query += f" AND dt >= '{sdt}'" + if edt: + query += f" AND dt <= '{edt}'" + if symbols: + if isinstance(symbols, str): + symbols = [symbols] + symbol_str = ", ".join([f"'{s}'" for s in symbols]) + query += f""" AND symbol IN ({symbol_str})""" + + df = db.query_df(query) + df = df.sort_values(["dt", "symbol"]).reset_index(drop=True) + df["dt"] = df["dt"].dt.tz_localize(None) + df["update_time"] = df["update_time"].dt.tz_localize(None) + return df + + +def get_latest_weights(db: Optional[Client] = None, strategy=None): + """获取策略最新持仓权重时间 + + :param db: clickhouse_connect.driver.Client, 数据库连接 + :param strategy: str, 策略名称, 默认 None + :return: pd.DataFrame + """ + db = db or __db_from_env() + + query = "SELECT * FROM czsc_strategy.latest_weights final" + if strategy: + query += f" WHERE strategy = '{strategy}'" + + df = db.query_df(query) + df = df.rename(columns={"latest_dt": "dt", "latest_weight": "weight", "latest_update_time": "update_time"}) + if not df.empty: + df["dt"] = df["dt"].dt.tz_localize(None) + df["update_time"] = df["update_time"].dt.tz_localize(None) + df = df.sort_values(["strategy", "dt", "symbol"]).reset_index(drop=True) + return df + + +def publish_weights( + strategy: str, df: pd.DataFrame, batch_size=100000, logger=loguru.logger, db: Optional[Client] = None +): + """发布策略持仓权重 + + :param df: pd.DataFrame, 待发布的持仓权重数据 + :param db: clickhouse_connect.driver.Client, 数据库连接 + :param strategy: str, 策略名称 + :param batch_size: int, 批量发布的大小, 默认 100000 + :param logger: loguru.logger, 日志记录器 + :return: None + """ + db = db or __db_from_env() + + __send_heartbeat(db, strategy) + df = df[["dt", "symbol", "weight"]].copy() + df["strategy"] = strategy + df["dt"] 
= pd.to_datetime(df["dt"]) + + dfl = get_latest_weights(db, strategy) + + if not dfl.empty: + dfl["dt"] = pd.to_datetime(dfl["dt"]) + symbol_dt = dfl.set_index("symbol")["dt"].to_dict() + logger.info(f"策略 {strategy} 最新时间:{dfl['dt'].max()}") + + rows = [] + for symbol, dfg in df.groupby("symbol"): + if symbol in symbol_dt: + dfg = dfg[dfg["dt"] > symbol_dt[symbol]] + rows.append(dfg) + if rows: + df = pd.concat(rows, ignore_index=True) + + logger.info(f"策略 {strategy} 共 {len(df)} 条新信号") + + df = df.sort_values(["dt", "symbol"]).reset_index(drop=True) + df["update_time"] = pd.to_datetime("now") + df = df[["strategy", "symbol", "dt", "weight", "update_time"]].copy() + df = df.drop_duplicates(["symbol", "dt", "strategy"], keep="last").reset_index(drop=True) + df["weight"] = df["weight"].astype(float) + + logger.info(f"准备发布 {len(df)} 条策略信号") + + # 批量写入 + for i in range(0, len(df), batch_size): + batch_df = df.iloc[i : i + batch_size] + res = db.insert_df("czsc_strategy.weights", batch_df) + __send_heartbeat(db, strategy) + + if res: + logger.info(f"完成批次 {i//batch_size + 1}, 发布 {len(batch_df)} 条信号") + else: + logger.error(f"批次 {i//batch_size + 1} 发布失败: {res}") + return + + logger.info(f"完成所有信号发布, 共 {len(df)} 条") + __send_heartbeat(db, strategy) + + +def publish_returns( + strategy: str, + df: pd.DataFrame, + batch_size=100000, + logger=loguru.logger, + db: Optional[Client] = None, +): + """发布策略日收益 + + :param df: pd.DataFrame, 待发布的日收益数据 + :param db: clickhouse_connect.driver.Client, 数据库连接 + :param strategy: str, 策略名称 + :param batch_size: int, 批量发布的大小, 默认 100000 + :param logger: loguru.logger, 日志记录器 + :return: None + """ + db = db or __db_from_env() + + df = df[["dt", "symbol", "returns"]].copy() + df["strategy"] = strategy + df["dt"] = pd.to_datetime(df["dt"]) + + # 查询 czsc_strategy.returns 表中,每个品种最新的时间 + dfl = db.query_df( + f"SELECT symbol, max(dt) as dt FROM czsc_strategy.returns final WHERE strategy = '{strategy}' GROUP BY symbol" + ) + + if not dfl.empty: + dfl["dt"] = dfl["dt"].dt.tz_localize(None) + symbol_dt = dfl.set_index("symbol")["dt"].to_dict() + logger.info(f"策略 {strategy} 最新时间:{dfl['dt'].max()}") + + rows = [] + for symbol, dfg in df.groupby("symbol"): + if symbol in symbol_dt: + # 允许覆盖同一天的数据 + dfg = dfg[dfg["dt"] >= symbol_dt[symbol]] + rows.append(dfg) + if rows: + df = pd.concat(rows, ignore_index=True) + + logger.info(f"策略 {strategy} 共 {len(df)} 条新日收益") + + df = df.sort_values(["dt", "symbol"]).reset_index(drop=True) + df["update_time"] = pd.to_datetime("now") + df = df[["strategy", "symbol", "dt", "returns", "update_time"]].copy() + df = df.drop_duplicates(["symbol", "dt", "strategy"], keep="last").reset_index(drop=True) + df["returns"] = df["returns"].astype(float) + + logger.info(f"准备发布 {len(df)} 条策略日收益") + + # 批量写入 + for i in range(0, len(df), batch_size): + batch_df = df.iloc[i : i + batch_size] + res = db.insert_df("czsc_strategy.returns", batch_df) + + if res: + logger.info(f"完成批次 {i//batch_size + 1}, 发布 {len(batch_df)} 条日收益") + else: + logger.error(f"批次 {i//batch_size + 1} 发布失败") + return + + logger.info(f"完成所有日收益发布, 共 {len(df)} 条") + + +def get_strategy_returns(strategy, db: Optional[Client] = None, sdt=None, edt=None, symbols=None): + """获取策略日收益 + + :param db: clickhouse_connect.driver.Client, 数据库连接 + :param strategy: str, 策略名称 + :param sdt: str, 开始时间 + :param edt: str, 结束时间 + :param symbols: list, 符号列表 + :return: pd.DataFrame + """ + db = db or __db_from_env() + + query = f""" + SELECT * FROM czsc_strategy.returns final WHERE strategy = '{strategy}' + """ + if sdt: + 
query += f" AND dt >= '{sdt}'" + if edt: + query += f" AND dt <= '{edt}'" + if symbols: + if isinstance(symbols, str): + symbols = [symbols] + symbol_str = ", ".join([f"'{s}'" for s in symbols]) + query += f""" AND symbol IN ({symbol_str})""" + + df = db.query_df(query) + df = df.sort_values(["dt", "symbol"]).reset_index(drop=True) + df["dt"] = df["dt"].dt.tz_localize(None) + df["update_time"] = df["update_time"].dt.tz_localize(None) + return df + + +def clear_strategy(strategy, db: Optional[Client] = None, logger=loguru.logger, human_confirm=True): + """清空策略 + + :param db: clickhouse_connect.driver.Client, 数据库连接 + :param strategy: str, 策略名称 + :param logger: loguru.logger, 日志记录器 + :param human_confirm: bool, 是否需要人工确认,默认 True + :return: None + """ + db = db or __db_from_env() + + if human_confirm: + confirm = input(f"确认清空策略 {strategy} 的所有数据?(y/n): ") + if confirm.lower() != "y": + logger.warning(f"取消清空策略 {strategy} 的所有数据") + return + + query = f""" + DELETE FROM czsc_strategy.metas WHERE strategy = '{strategy}' + """ + _ = db.command(query) + logger.info(f"清空策略 {strategy} 元数据成功") + + query = f""" + DELETE FROM czsc_strategy.weights WHERE strategy = '{strategy}' + """ + _ = db.command(query) + logger.info(f"清空策略 {strategy} 持仓权重成功") + + query = f""" + DELETE FROM czsc_strategy.returns WHERE strategy = '{strategy}' + """ + _ = db.command(query) + logger.info(f"清空策略 {strategy} 日收益成功") + logger.warning(f"策略 {strategy} 清空完成") diff --git a/czsc/utils/kline_quality.py b/czsc/utils/kline_quality.py index 5de1a3977..ce06a75ba 100644 --- a/czsc/utils/kline_quality.py +++ b/czsc/utils/kline_quality.py @@ -3,65 +3,306 @@ email: zeng_bin8888@163.com create_dt: 2024/4/27 15:01 describe: K线质量评估工具函数 - -https://hailuoai.com/?chat=241699282914746375 """ import pandas as pd +import numpy as np + + +# 1. 缺失值检查 +def check_missing_values(df): + """ + 检查各列是否存在缺失值,并返回有缺失值的行。 + + :param df: 单个 symbol 的 DataFrame + :return: {'description': ..., 'rows': ...} 或 {'description': '无缺失值', 'rows': None} + """ + missing = df[df.isnull().any(axis=1)] + if not missing.empty: + return {"description": f"存在 {len(missing)} 条记录包含缺失值", "rows": missing} + else: + return {"description": "无缺失值", "rows": None} + + +# 2. 
数据类型检查 +def check_data_types(df): + """ + 检查各列的数据类型是否符合预期,并返回类型不匹配的行。 + + :param df: 单个 symbol 的 DataFrame + :return: {'description': ..., 'rows': ...} 或 {'description': '数据类型均符合预期', 'rows': None} + """ + expected_types = { + "dt": "datetime64[ns]", + "symbol": "object", + "open": "float", + "close": "float", + "high": "float", + "low": "float", + "vol": "int64", + "amount": "float", + } + type_mismatches = {} + mismatch_rows = pd.DataFrame() + + for column, expected in expected_types.items(): + if column not in df.columns: + type_mismatches[column] = f"缺少列 {column}" + continue + actual_type = df[column].dtype + if expected.startswith("datetime"): + if not pd.api.types.is_datetime64_any_dtype(df[column]): + type_mismatches[column] = f"期望类型 {expected},但实际类型 {actual_type}" + mismatch_rows = pd.concat( + [mismatch_rows, df[df[column].apply(lambda x: not pd.api.types.is_datetime64_any_dtype([x]))]], + ignore_index=True, + ) + elif expected.startswith("float"): + if not pd.api.types.is_float_dtype(df[column]): + type_mismatches[column] = f"期望类型 {expected},但实际类型 {actual_type}" + mismatch_rows = pd.concat( + [mismatch_rows, df[df[column].apply(lambda x: not isinstance(x, float))]], ignore_index=True + ) + elif expected.startswith("int"): + if not pd.api.types.is_integer_dtype(df[column]): + type_mismatches[column] = f"期望类型 {expected},但实际类型 {actual_type}" + mismatch_rows = pd.concat( + [mismatch_rows, df[df[column].apply(lambda x: not isinstance(x, (int, np.integer)))]], + ignore_index=True, + ) + else: + if df[column].dtype != expected: + type_mismatches[column] = f"期望类型 {expected},但实际类型 {actual_type}" + mismatch_rows = pd.concat( + [mismatch_rows, df[df[column].apply(lambda x: not isinstance(x, str))]], ignore_index=True + ) + + if type_mismatches: + # 去重,避免同一行多次添加 + mismatch_rows = mismatch_rows.drop_duplicates() + return {"description": type_mismatches, "rows": mismatch_rows} + else: + return {"description": "数据类型均符合预期", "rows": None} + + +# 3. 日期时间顺序检查 +def check_datetime_order(df): + """ + 检查日期时间是否按升序排列,以及是否存在重复的日期时间,并返回相关的有问题的行。 + + :param df: 单个 symbol 的 DataFrame + :return: {'description': ..., 'rows': ...} 字典 + """ + results = {} + problem_rows = pd.DataFrame() + + # 检查是否按升序排列 + dt_sorted = df["dt"].is_monotonic_increasing + if not dt_sorted: + results["dt_order"] = "日期时间未按升序排列" + # 标记不按顺序的行 + sorted_df = df.sort_values("dt").reset_index(drop=True) + mismatched = df[df["dt"] != sorted_df["dt"]] + problem_rows = pd.concat([problem_rows, mismatched], ignore_index=True) + else: + results["dt_order"] = "日期时间按升序排列" + + # 检查重复的日期时间 + duplicate_dt = df.duplicated(subset=["dt"]).sum() + if duplicate_dt > 0: + results["duplicate_dt"] = f"存在 {duplicate_dt} 个重复的日期时间" + duplicates = df[df.duplicated(subset=["dt"], keep=False)] + problem_rows = pd.concat([problem_rows, duplicates], ignore_index=True) + else: + results["duplicate_dt"] = "无重复的日期时间" + + if not problem_rows.empty: + # 去重 + problem_rows = problem_rows.drop_duplicates() + return {"description": results, "rows": problem_rows} + else: + return {"description": results, "rows": None} + + +# 4. 
价格合理性检查 +def check_price_reasonableness(df): + """ + 检查价格数据的合理性,并返回有问题的行。 + + :param df: 单个 symbol 的 DataFrame + :return: {'description': ..., 'rows': ...} 或 {'description': '所有价格数据合理', 'rows': None} + """ + issues = {} + problem_rows = pd.DataFrame() + + # high >= open, close + invalid_high = df[df["high"] < df[["open", "close"]].max(axis=1)] + if not invalid_high.empty: + issues["high_less_than_open_close"] = f"存在 {len(invalid_high)} 条记录,'high' 小于 'open' 或 'close'" + problem_rows = pd.concat([problem_rows, invalid_high], ignore_index=True) + # low <= open, close + invalid_low = df[df["low"] > df[["open", "close"]].min(axis=1)] + if not invalid_low.empty: + issues["low_greater_than_open_close"] = f"存在 {len(invalid_low)} 条记录,'low' 大于 'open' 或 'close'" + problem_rows = pd.concat([problem_rows, invalid_low], ignore_index=True) -def check_high_low(df): + # 价格不为负,且不为零 + negative_prices = df[(df[["open", "close", "high", "low"]] <= 0).any(axis=1)] + if not negative_prices.empty: + issues["negative_prices"] = f"存在 {len(negative_prices)} 条记录,价格为负数或零" + problem_rows = pd.concat([problem_rows, negative_prices], ignore_index=True) + + if issues: + # 去重 + problem_rows = problem_rows.drop_duplicates() + return {"description": issues, "rows": problem_rows} + else: + return {"description": "所有价格数据合理", "rows": None} + + +# 5. 成交量和金额检查 +def check_volume_amount(df): """ - 检查是否存在 high < low 的情况。 + 检查成交量和金额的数据合理性,并返回有问题的行。 + + :param df: 单个 symbol 的 DataFrame + :return: {'description': ..., 'rows': ...} 或 {'description': '成交量和金额数据合理', 'rows': None} """ - df["high_low_error"] = df["high"] < df["low"] - error_rate = df["high_low_error"].mean() - error_klines = df[df["high_low_error"]].copy() - return error_rate, error_klines + issues = {} + problem_rows = pd.DataFrame() + + # vol 和 amount 非负 + negative_vol = df[df["vol"] < 0] + if not negative_vol.empty: + issues["negative_vol"] = f"存在 {len(negative_vol)} 条记录,'vol' 为负数" + problem_rows = pd.concat([problem_rows, negative_vol], ignore_index=True) + + negative_amount = df[df["amount"] < 0] + if not negative_amount.empty: + issues["negative_amount"] = f"存在 {len(negative_amount)} 条记录,'amount' 为负数" + problem_rows = pd.concat([problem_rows, negative_amount], ignore_index=True) + + # vol 为零时 amount 也应为零 + zero_vol_nonzero_amount = df[(df["vol"] == 0) & (df["amount"] != 0)] + if not zero_vol_nonzero_amount.empty: + issues["zero_vol_nonzero_amount"] = f"存在 {len(zero_vol_nonzero_amount)} 条记录,'vol' 为零但 'amount' 不为零" + problem_rows = pd.concat([problem_rows, zero_vol_nonzero_amount], ignore_index=True) + if issues: + # 去重 + problem_rows = problem_rows.drop_duplicates() + return {"description": issues, "rows": problem_rows} + else: + return {"description": "成交量和金额数据合理", "rows": None} -def check_price_gap(df, **kwargs): + +# 6. 
符号一致性检查 +def check_symbol_consistency(df): """ - 检查是否存在超过阈值的大幅度缺口。 + 检查符号数据的一致性和有效性,并返回有问题的行。 + + :param df: 单个 symbol 的 DataFrame + :return: {'description': ..., 'rows': ...} 或 {'description': '符号数据一致且有效', 'rows': None} """ - df = df.copy().sort_values(["dt", "symbol"]).reset_index(drop=True) - errors = [] - for symbol in df["symbol"].unique(): - symbol_df = df[df["symbol"] == symbol] - symbol_df["last_close"] = symbol_df["close"].shift(1) - symbol_df["price_gap"] = (symbol_df["open"] - symbol_df["last_close"]).abs() - gap_th = symbol_df["price_gap"].mean() + 3 * symbol_df["price_gap"].std() - error_ = symbol_df[symbol_df["price_gap"] > gap_th].copy() - if len(error_) > 0: - errors.append(error_) + # 检查符号是否为非空字符串 + invalid_symbols = df[df["symbol"].isnull() | (df["symbol"].astype(str).str.strip() == "")] + if not invalid_symbols.empty: + return {"description": f"存在 {len(invalid_symbols)} 条记录,符号为空或无效", "rows": invalid_symbols} + else: + return {"description": "符号数据一致且有效", "rows": None} - error_klines = pd.concat(errors) - error_rate = len(error_klines) / len(df) - return error_rate, error_klines +# 7. 重复记录检查 +def check_duplicate_records(df): + """ + 检查是否存在完全重复的记录,并返回重复的行。 -def check_abnormal_volume(df, **kwargs): + :param df: 单个 symbol 的 DataFrame + :return: {'description': ..., 'rows': ...} 或 {'description': '无重复记录', 'rows': None} """ - 检查是否存在异常成交量。 + duplicate_records = df[df.duplicated()] + if not duplicate_records.empty: + return {"description": f"存在 {len(duplicate_records)} 条完全重复的记录", "rows": duplicate_records} + else: + return {"description": "无重复记录", "rows": None} + + +# 8. 异常值检查 +def check_extreme_values(df, threshold=0.2): + """ + 检查价格日涨跌幅是否超过指定阈值,作为异常值,并返回有问题的行。 + + :param df: 单个 symbol 的 DataFrame + :param threshold: 涨跌幅阈值,默认为 50% + :return: {'description': ..., 'rows': ...} 或 {'description': '无异常的价格涨跌幅', 'rows': None} """ - df = df.copy().sort_values(["dt", "symbol"]).reset_index(drop=True) - errors = [] - for symbol in df["symbol"].unique(): - symbol_df = df[df["symbol"] == symbol] - volume_threshold = symbol_df["vol"].mean() + 3 * symbol_df["vol"].std() - error_ = symbol_df[symbol_df["vol"] > volume_threshold].copy() - if len(error_) > 0: - errors.append(error_) - error_klines = pd.concat(errors) - error_rate = len(error_klines) / len(df) - return error_rate, error_klines + if "close" not in df.columns: + return {"description": "缺少 'close' 列,无法进行异常值检查", "rows": None} + df = df.copy() + df["pct_change"] = df["close"].pct_change().abs() + extreme_changes = df[df["pct_change"] > threshold] + if not extreme_changes.empty: + return { + "description": f"存在 {len(extreme_changes)} 条记录,价格涨跌幅超过 {threshold*100}%", + "rows": extreme_changes, + } + else: + return {"description": "无异常的价格涨跌幅", "rows": None} -def check_zero_volume(df): + +# 主检查函数 +def check_kline_quality(df): """ - 计算零成交量的K线占比。 + 检查包含多个 symbol 的 K 线数据的质量问题,并返回有问题的行。 + + :param df: 包含 K 线数据的 DataFrame,必须包含以下列: + ['dt', 'symbol', 'open', 'close', 'high', 'low', 'vol', 'amount'] + :return: 嵌套字典,按 symbol 组织的检查结果,包含问题描述和有问题的行 """ - df = df.copy().sort_values(["dt", "symbol"]).reset_index(drop=True) - error_rate = df["vol"].eq(0).sum() / len(df) - error_klines = df[df["vol"].eq(0)].copy() - return error_rate, error_klines + required_columns = ["dt", "symbol", "open", "close", "high", "low", "vol", "amount"] + missing_columns = set(required_columns) - set(df.columns) + if missing_columns: + raise ValueError(f"输入数据缺少必要的列: {missing_columns}") + + # 确保 'dt' 列为 datetime 类型 + if not pd.api.types.is_datetime64_any_dtype(df["dt"]): + 
df["dt"] = pd.to_datetime(df["dt"], errors="coerce") + + # 按 symbol 分组 + grouped = df.groupby("symbol") + + quality_issues = {} + + for symbol, group in grouped: + symbol_issues = {} + + # 按日期排序 + group_sorted = group.sort_values("dt").reset_index(drop=True) + + # 逐个检查 + symbol_issues["missing_values"] = check_missing_values(group_sorted) + symbol_issues["type_mismatches"] = check_data_types(group_sorted) + symbol_issues["datetime_order"] = check_datetime_order(group_sorted) + symbol_issues["price_reasonableness"] = check_price_reasonableness(group_sorted) + symbol_issues["volume_amount"] = check_volume_amount(group_sorted) + symbol_issues["symbol_consistency"] = check_symbol_consistency(group_sorted) + symbol_issues["duplicate_records"] = check_duplicate_records(group_sorted) + symbol_issues["extreme_values"] = check_extreme_values(group_sorted) + + quality_issues[symbol] = symbol_issues + + # 输出检查结果 + for symbol, symbol_issues in quality_issues.items(): + for check, result in symbol_issues.items(): + + if result["rows"] is not None: + print(f"\n检查点: {symbol} - {check}") + print(f"结果描述: {result['description']}") + + print("有问题的数据行:") + print(result["rows"]) + print("\n\n") + + return quality_issues diff --git a/czsc/utils/st_components.py b/czsc/utils/st_components.py index 52a73cd46..22e2707aa 100644 --- a/czsc/utils/st_components.py +++ b/czsc/utils/st_components.py @@ -405,6 +405,26 @@ def show_factor_layering(df, factor, target="n1b", **kwargs): ) +def show_weight_distribution(dfw, abs_weight=True, **kwargs): + """展示权重分布 + + :param dfw: pd.DataFrame, 包含 symbol, dt, price, weight 列 + :param abs_weight: bool, 是否取权重的绝对值 + :param kwargs: + + - percentiles: list, 分位数 + """ + dfw = dfw.copy() + if abs_weight: + dfw["weight"] = dfw["weight"].abs() + + default_percentiles = [0.05, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 0.95] + percentiles = kwargs.get("percentiles", default_percentiles) + + dfs = dfw.groupby("symbol").apply(lambda x: x["weight"].describe(percentiles=percentiles)).reset_index() + show_df_describe(dfs) + + def show_symbol_factor_layering(df, x_col, y_col="n1b", **kwargs): """使用 streamlit 绘制单个标的上的因子分层收益率图 @@ -488,11 +508,18 @@ def show_weight_backtest(dfw, **kwargs): - n_jobs: int, 并行计算的进程数,默认为 1 """ + try: + from rs_czsc import WeightBacktest + except ImportError: + from czsc import WeightBacktest + from czsc.eda import cal_yearly_days fee = kwargs.get("fee", 2) digits = kwargs.get("digits", 2) + n_jobs = kwargs.pop("n_jobs", 1) yearly_days = kwargs.pop("yearly_days", None) + weight_type = kwargs.pop("weight_type", "ts") if not yearly_days: yearly_days = cal_yearly_days(dts=dfw["dt"].unique()) @@ -502,16 +529,16 @@ def show_weight_backtest(dfw, **kwargs): st.dataframe(dfw[dfw.isnull().sum(axis=1) > 0], use_container_width=True) st.stop() - wb = czsc.WeightBacktest( - dfw, - fee_rate=fee / 10000, - digits=digits, - n_jobs=kwargs.get("n_jobs", 1), - yearly_days=yearly_days, + wb = WeightBacktest( + dfw=dfw, fee_rate=fee / 10000, digits=digits, n_jobs=n_jobs, yearly_days=yearly_days, weight_type=weight_type ) - stat = wb.results["绩效评价"] + stat = wb.stats st.divider() + st.markdown( + f"**回测参数:** 单边手续费 {fee} BP,权重小数位数 {digits} ," + f"年交易天数 {yearly_days},品种数量:{dfw['symbol'].nunique()}" + ) c1, c2, c3, c4, c5, c6, c7, c8, c9, c10, c11 = st.columns([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]) c1.metric("盈亏平衡点", f"{stat['盈亏平衡点']:.2%}") c2.metric("单笔收益(BP)", f"{stat['单笔收益']}") @@ -524,27 +551,30 @@ def show_weight_backtest(dfw, **kwargs): c9.metric("年化波动率", f"{stat['年化波动率']:.2%}") 
c10.metric("多头占比", f"{stat['多头占比']:.2%}") c11.metric("空头占比", f"{stat['空头占比']:.2%}") - st.caption(f"回测参数:单边手续费 {fee} BP,权重小数位数 {digits} ,年交易天数 {yearly_days}") - st.divider() - dret = wb.results["品种等权日收益"].copy() + with st.popover(label="交易方向统计", help="统计多头、空头交易次数、胜率、盈亏比等信息"): + dfx = pd.DataFrame([wb.long_stats, wb.short_stats]) + dfx.index = ["多头", "空头"] + dfx.index.name = "交易方向" + st.dataframe(dfx.T, use_container_width=True) + + dret = wb.daily_return.copy() dret["dt"] = pd.to_datetime(dret["date"]) dret = dret.set_index("dt").drop(columns=["date"]) - # dret.index = pd.to_datetime(dret.index) show_daily_return(dret, legend_only_cols=dfw["symbol"].unique().tolist(), yearly_days=yearly_days, **kwargs) if kwargs.get("show_drawdowns", False): show_drawdowns(dret, ret_col="total", sub_title="") - if kwargs.get("show_backtest_detail", False): - c1, c2 = st.columns([1, 1]) - with c1.expander("品种等权日收益", expanded=False): - df_ = wb.results["品种等权日收益"].copy() - st.dataframe(df_.style.background_gradient(cmap="RdYlGn_r").format("{:.2%}"), use_container_width=True) - - with c2.expander("查看开平交易对", expanded=False): - dfp = pd.concat([v["pairs"] for k, v in wb.results.items() if k in wb.symbols], ignore_index=True) - st.dataframe(dfp, use_container_width=True) + # if kwargs.get("show_backtest_detail", False): + # c1, c2 = st.columns([1, 1]) + # with c1.expander("品种等权日收益", expanded=False): + # df_ = wb.daily_return.copy() + # st.dataframe(df_.style.background_gradient(cmap="RdYlGn_r").format("{:.2%}"), use_container_width=True) + # + # # with c2.expander("查看开平交易对", expanded=False): + # # dfp = pd.concat([v["pairs"] for k, v in wb.results.items() if k in wb.symbols], ignore_index=True) + # # st.dataframe(dfp, use_container_width=True) if kwargs.get("show_splited_daily", False): with st.expander("品种等权日收益分段表现", expanded=False): @@ -558,6 +588,10 @@ def show_weight_backtest(dfw, **kwargs): with st.expander("月度累计收益", expanded=False): show_monthly_return(dret, ret_col="total", sub_title="") + if kwargs.get("show_weight_distribution", True): + with st.expander("策略分品种的 weight 分布", expanded=False): + show_weight_distribution(dfw, abs_weight=True) + return wb @@ -1007,7 +1041,14 @@ def show_drawdowns(df: pd.DataFrame, ret_col, **kwargs): dft = czsc.top_drawdowns(df[ret_col].copy(), top=10) dft = dft.style.background_gradient(cmap="RdYlGn_r", subset=["净值回撤"]) dft = dft.background_gradient(cmap="RdYlGn", subset=["回撤天数", "恢复天数", "新高间隔"]) - dft = dft.format({"净值回撤": "{:.2%}", "回撤天数": "{:.0f}", "恢复天数": "{:.0f}", "新高间隔": "{:.0f}"}) + dft = dft.format( + { + "净值回撤": "{:.2%}", + "回撤天数": "{:.0f}", + "恢复天数": "{:.0f}", + "新高间隔": "{:.0f}", + } + ) st.dataframe(dft, use_container_width=True) # 画图: 净值回撤 @@ -1019,20 +1060,33 @@ def show_drawdowns(df: pd.DataFrame, ret_col, **kwargs): line=dict(color="salmon"), fill="tozeroy", mode="lines", - name="回测曲线", + name="回撤曲线", + opacity=0.5, ) fig = go.Figure(drawdown) + # 增加累计收益曲线,右轴 + fig.add_trace( + go.Scatter( + x=df.index, y=df["cum_ret"], mode="lines", name="累计收益", yaxis="y2", opacity=0.8, line=dict(color="red") + ) + ) + fig.update_layout(yaxis2=dict(title="累计收益", overlaying="y", side="right")) + # 增加 10% 分位数线,30% 分位数线,50% 分位数线,同时增加文本标记 for q in [0.1, 0.3, 0.5]: y1 = df["drawdown"].quantile(q) - fig.add_hline(y=y1, line_dash="dot", line_color="green", line_width=2) - fig.add_annotation(x=df.index[-1], y=y1, text=f"{q:.1%} (DD: {y1:.2%})", showarrow=False, yshift=10) + fig.add_hline(y=y1, line_dash="dot", line_color="green", line_width=1) + fig.add_annotation( + 
x=df.index.unique()[5], + y=y1, + text=f"{q:.1%} (DD: {y1:.2%})", + showarrow=False, + yshift=10, + ) fig.update_layout(margin=dict(l=0, r=0, t=0, b=0)) - fig.update_layout(title="", xaxis_title="", yaxis_title="净值回撤", legend_title="回撤曲线") - # 限制 绘制高度 - fig.update_layout(height=300) + fig.update_layout(title="", xaxis_title="", yaxis_title="净值回撤", legend_title="回撤分析", height=300) st.plotly_chart(fig, use_container_width=True) diff --git a/czsc/utils/ta.py b/czsc/utils/ta.py index e95637fe8..69fab3170 100644 --- a/czsc/utils/ta.py +++ b/czsc/utils/ta.py @@ -11,7 +11,6 @@ """ import numpy as np import pandas as pd -import pandas_ta def SMA(close: np.array, timeperiod=5): @@ -526,6 +525,8 @@ def CHOP(high, low, close, **kwargs): :return: pd.Series, New feature generated. """ + import pandas_ta + return pandas_ta.chop(high=high, low=low, close=close, **kwargs) diff --git "a/examples/develop/K\347\272\277\346\225\260\346\215\256\350\264\250\351\207\217\346\243\200\346\237\245.py" "b/examples/develop/K\347\272\277\346\225\260\346\215\256\350\264\250\351\207\217\346\243\200\346\237\245.py" new file mode 100644 index 000000000..405a53872 --- /dev/null +++ "b/examples/develop/K\347\272\277\346\225\260\346\215\256\350\264\250\351\207\217\346\243\200\346\237\245.py" @@ -0,0 +1,482 @@ +import pandas as pd +import numpy as np + + +# 1. 缺失值检查 +def check_missing_values(df): + """ + 检查各列是否存在缺失值,并返回有缺失值的行。 + + :param df: 单个 symbol 的 DataFrame + :return: {'description': ..., 'rows': ...} 或 {'description': '无缺失值', 'rows': None} + """ + missing = df[df.isnull().any(axis=1)] + if not missing.empty: + return {"description": f"存在 {len(missing)} 条记录包含缺失值", "rows": missing} + else: + return {"description": "无缺失值", "rows": None} + + +# 2. 数据类型检查 +def check_data_types(df): + """ + 检查各列的数据类型是否符合预期,并返回类型不匹配的行。 + + :param df: 单个 symbol 的 DataFrame + :return: {'description': ..., 'rows': ...} 或 {'description': '数据类型均符合预期', 'rows': None} + """ + expected_types = { + "dt": "datetime64[ns]", + "symbol": "object", + "open": "float", + "close": "float", + "high": "float", + "low": "float", + "vol": "int64", + "amount": "float", + } + type_mismatches = {} + mismatch_rows = pd.DataFrame() + + for column, expected in expected_types.items(): + if column not in df.columns: + type_mismatches[column] = f"缺少列 {column}" + continue + actual_type = df[column].dtype + if expected.startswith("datetime"): + if not pd.api.types.is_datetime64_any_dtype(df[column]): + type_mismatches[column] = f"期望类型 {expected},但实际类型 {actual_type}" + mismatch_rows = pd.concat( + [mismatch_rows, df[df[column].apply(lambda x: not pd.api.types.is_datetime64_any_dtype([x]))]], + ignore_index=True, + ) + elif expected.startswith("float"): + if not pd.api.types.is_float_dtype(df[column]): + type_mismatches[column] = f"期望类型 {expected},但实际类型 {actual_type}" + mismatch_rows = pd.concat( + [mismatch_rows, df[df[column].apply(lambda x: not isinstance(x, float))]], ignore_index=True + ) + elif expected.startswith("int"): + if not pd.api.types.is_integer_dtype(df[column]): + type_mismatches[column] = f"期望类型 {expected},但实际类型 {actual_type}" + mismatch_rows = pd.concat( + [mismatch_rows, df[df[column].apply(lambda x: not isinstance(x, (int, np.integer)))]], + ignore_index=True, + ) + else: + if df[column].dtype != expected: + type_mismatches[column] = f"期望类型 {expected},但实际类型 {actual_type}" + mismatch_rows = pd.concat( + [mismatch_rows, df[df[column].apply(lambda x: not isinstance(x, str))]], ignore_index=True + ) + + if type_mismatches: + # 去重,避免同一行多次添加 + mismatch_rows = 
mismatch_rows.drop_duplicates() + return {"description": type_mismatches, "rows": mismatch_rows} + else: + return {"description": "数据类型均符合预期", "rows": None} + + +# 3. 日期时间顺序检查 +def check_datetime_order(df): + """ + 检查日期时间是否按升序排列,以及是否存在重复的日期时间,并返回相关的有问题的行。 + + :param df: 单个 symbol 的 DataFrame + :return: {'description': ..., 'rows': ...} 字典 + """ + results = {} + problem_rows = pd.DataFrame() + + # 检查是否按升序排列 + dt_sorted = df["dt"].is_monotonic_increasing + if not dt_sorted: + results["dt_order"] = "日期时间未按升序排列" + # 标记不按顺序的行 + sorted_df = df.sort_values("dt").reset_index(drop=True) + mismatched = df[df["dt"] != sorted_df["dt"]] + problem_rows = pd.concat([problem_rows, mismatched], ignore_index=True) + else: + results["dt_order"] = "日期时间按升序排列" + + # 检查重复的日期时间 + duplicate_dt = df.duplicated(subset=["dt"]).sum() + if duplicate_dt > 0: + results["duplicate_dt"] = f"存在 {duplicate_dt} 个重复的日期时间" + duplicates = df[df.duplicated(subset=["dt"], keep=False)] + problem_rows = pd.concat([problem_rows, duplicates], ignore_index=True) + else: + results["duplicate_dt"] = "无重复的日期时间" + + if not problem_rows.empty: + # 去重 + problem_rows = problem_rows.drop_duplicates() + return {"description": results, "rows": problem_rows} + else: + return {"description": results, "rows": None} + + +# 4. 价格合理性检查 +def check_price_reasonableness(df): + """ + 检查价格数据的合理性,并返回有问题的行。 + + :param df: 单个 symbol 的 DataFrame + :return: {'description': ..., 'rows': ...} 或 {'description': '所有价格数据合理', 'rows': None} + """ + issues = {} + problem_rows = pd.DataFrame() + + # high >= open, close + invalid_high = df[df["high"] < df[["open", "close"]].max(axis=1)] + if not invalid_high.empty: + issues["high_less_than_open_close"] = f"存在 {len(invalid_high)} 条记录,'high' 小于 'open' 或 'close'" + problem_rows = pd.concat([problem_rows, invalid_high], ignore_index=True) + + # low <= open, close + invalid_low = df[df["low"] > df[["open", "close"]].min(axis=1)] + if not invalid_low.empty: + issues["low_greater_than_open_close"] = f"存在 {len(invalid_low)} 条记录,'low' 大于 'open' 或 'close'" + problem_rows = pd.concat([problem_rows, invalid_low], ignore_index=True) + + # 价格不为负,且不为零 + negative_prices = df[(df[["open", "close", "high", "low"]] <= 0).any(axis=1)] + if not negative_prices.empty: + issues["negative_prices"] = f"存在 {len(negative_prices)} 条记录,价格为负数或零" + problem_rows = pd.concat([problem_rows, negative_prices], ignore_index=True) + + if issues: + # 去重 + problem_rows = problem_rows.drop_duplicates() + return {"description": issues, "rows": problem_rows} + else: + return {"description": "所有价格数据合理", "rows": None} + + +# 5. 
成交量和金额检查 +def check_volume_amount(df): + """ + 检查成交量和金额的数据合理性,并返回有问题的行。 + + :param df: 单个 symbol 的 DataFrame + :return: {'description': ..., 'rows': ...} 或 {'description': '成交量和金额数据合理', 'rows': None} + """ + issues = {} + problem_rows = pd.DataFrame() + + # vol 和 amount 非负 + negative_vol = df[df["vol"] < 0] + if not negative_vol.empty: + issues["negative_vol"] = f"存在 {len(negative_vol)} 条记录,'vol' 为负数" + problem_rows = pd.concat([problem_rows, negative_vol], ignore_index=True) + + negative_amount = df[df["amount"] < 0] + if not negative_amount.empty: + issues["negative_amount"] = f"存在 {len(negative_amount)} 条记录,'amount' 为负数" + problem_rows = pd.concat([problem_rows, negative_amount], ignore_index=True) + + # vol 为零时 amount 也应为零 + zero_vol_nonzero_amount = df[(df["vol"] == 0) & (df["amount"] != 0)] + if not zero_vol_nonzero_amount.empty: + issues["zero_vol_nonzero_amount"] = f"存在 {len(zero_vol_nonzero_amount)} 条记录,'vol' 为零但 'amount' 不为零" + problem_rows = pd.concat([problem_rows, zero_vol_nonzero_amount], ignore_index=True) + + if issues: + # 去重 + problem_rows = problem_rows.drop_duplicates() + return {"description": issues, "rows": problem_rows} + else: + return {"description": "成交量和金额数据合理", "rows": None} + + +# 6. 符号一致性检查 +def check_symbol_consistency(df): + """ + 检查符号数据的一致性和有效性,并返回有问题的行。 + + :param df: 单个 symbol 的 DataFrame + :return: {'description': ..., 'rows': ...} 或 {'description': '符号数据一致且有效', 'rows': None} + """ + # 检查符号是否为非空字符串 + invalid_symbols = df[df["symbol"].isnull() | (df["symbol"].astype(str).str.strip() == "")] + if not invalid_symbols.empty: + return {"description": f"存在 {len(invalid_symbols)} 条记录,符号为空或无效", "rows": invalid_symbols} + else: + return {"description": "符号数据一致且有效", "rows": None} + + +# 7. 重复记录检查 +def check_duplicate_records(df): + """ + 检查是否存在完全重复的记录,并返回重复的行。 + + :param df: 单个 symbol 的 DataFrame + :return: {'description': ..., 'rows': ...} 或 {'description': '无重复记录', 'rows': None} + """ + duplicate_records = df[df.duplicated()] + if not duplicate_records.empty: + return {"description": f"存在 {len(duplicate_records)} 条完全重复的记录", "rows": duplicate_records} + else: + return {"description": "无重复记录", "rows": None} + + +# 8. 
异常值检查 +def check_extreme_values(df, threshold=0.5): + """ + 检查价格日涨跌幅是否超过指定阈值,作为异常值,并返回有问题的行。 + + :param df: 单个 symbol 的 DataFrame + :param threshold: 涨跌幅阈值,默认为 50% + :return: {'description': ..., 'rows': ...} 或 {'description': '无异常的价格涨跌幅', 'rows': None} + """ + if "close" not in df.columns: + return {"description": "缺少 'close' 列,无法进行异常值检查", "rows": None} + + df = df.copy() + df["pct_change"] = df["close"].pct_change().abs() + extreme_changes = df[df["pct_change"] > threshold] + if not extreme_changes.empty: + return { + "description": f"存在 {len(extreme_changes)} 条记录,价格涨跌幅超过 {threshold*100}%", + "rows": extreme_changes, + } + else: + return {"description": "无异常的价格涨跌幅", "rows": None} + + +# 主检查函数 +def check_kline_data_quality_multiple_symbols(df): + """ + 检查包含多个 symbol 的 K 线数据的质量问题,并返回有问题的行。 + + :param df: 包含 K 线数据的 DataFrame,必须包含以下列: + ['dt', 'symbol', 'open', 'close', 'high', 'low', 'vol', 'amount'] + :return: 嵌套字典,按 symbol 组织的检查结果,包含问题描述和有问题的行 + """ + required_columns = ["dt", "symbol", "open", "close", "high", "low", "vol", "amount"] + missing_columns = set(required_columns) - set(df.columns) + if missing_columns: + raise ValueError(f"输入数据缺少必要的列: {missing_columns}") + + # 确保 'dt' 列为 datetime 类型 + if not pd.api.types.is_datetime64_any_dtype(df["dt"]): + df["dt"] = pd.to_datetime(df["dt"], errors="coerce") + + # 按 symbol 分组 + grouped = df.groupby("symbol") + + quality_issues = {} + + for symbol, group in grouped: + symbol_issues = {} + + # 按日期排序 + group_sorted = group.sort_values("dt").reset_index(drop=True) + + # 逐个检查 + symbol_issues["missing_values"] = check_missing_values(group_sorted) + symbol_issues["type_mismatches"] = check_data_types(group_sorted) + symbol_issues["datetime_order"] = check_datetime_order(group_sorted) + symbol_issues["price_reasonableness"] = check_price_reasonableness(group_sorted) + symbol_issues["volume_amount"] = check_volume_amount(group_sorted) + symbol_issues["symbol_consistency"] = check_symbol_consistency(group_sorted) + symbol_issues["duplicate_records"] = check_duplicate_records(group_sorted) + symbol_issues["extreme_values"] = check_extreme_values(group_sorted) + + quality_issues[symbol] = symbol_issues + + return quality_issues + + +# **示例用法** + + +def test(): + import pandas as pd + + # 示例数据 + data = { + "dt": pd.date_range(start="2023-01-01", periods=10, freq="D").tolist() * 2, + "symbol": ["AAPL"] * 10 + ["GOOG"] * 10, + "open": [ + 150.0, + 152.0, + 151.0, + 153.0, + 154.0, + 155.0, + 156.0, + 157.0, + 158.0, + 159.0, + 2800.0, + 2820.0, + 2810.0, + 2830.0, + 2840.0, + 2850.0, + 2860.0, + 2870.0, + 2880.0, + 2890.0, + ], + "close": [ + 152.0, + 151.0, + 153.0, + 154.0, + 155.0, + 156.0, + 157.0, + 158.0, + 159.0, + 160.0, + 2820.0, + 2810.0, + 2830.0, + 2840.0, + 2850.0, + 2860.0, + 2870.0, + 2880.0, + 2890.0, + 2900.0, + ], + "high": [ + 153.0, + 152.5, + 154.0, + 155.0, + 156.0, + 157.0, + 158.0, + 159.0, + 160.0, + 161.0, + 2825.0, + 2815.0, + 2835.0, + 2845.0, + 2855.0, + 2865.0, + 2875.0, + 2885.0, + 2895.0, + 2905.0, + ], + "low": [ + 149.0, + 150.5, + 150.0, + 152.0, + 153.0, + 154.0, + 155.0, + 156.0, + 157.0, + 158.0, + 2795.0, + 2805.0, + 2815.0, + 2825.0, + 2835.0, + 2845.0, + 2855.0, + 2865.0, + 2875.0, + 2885.0, + ], + "vol": [ + 1000, + 1100, + 1050, + 1150, + 1200, + 1250, + 1300, + 1350, + 1400, + 1450, + 2000, + 2100, + 2050, + 2150, + 2200, + 2250, + 2300, + 2350, + 2400, + 2450, + ], + "amount": [ + 150000.0, + 165500.0, + 160650.0, + 175500.0, + 186000.0, + 193750.0, + 202000.0, + 212250.0, + 224000.0, + 232250.0, + 4200000.0, + 
4400000.0, + 4300000.0, + 4500000.0, + 4620000.0, + 4725000.0, + 4830000.0, + 4927500.0, + 5040000.0, + 5152500.0, + ], + } + + df = pd.DataFrame(data) + + # 引入缺失值和异常值进行测试 + # 对 AAPL + df.loc[2, "close"] = None # 缺失值 + df.loc[4, "high"] = 140.0 # high < open 或 close + df.loc[1, "vol"] = -500 # 负成交量 + + # 对 GOOG + df.loc[12, "low"] = 3000.0 # low > open 或 close + df.loc[15, "amount"] = -1000.0 # 负金额 + df.loc[18, "close"] = 5000.0 # 极端涨幅 + df.loc[19, "close"] = 3000.0 # 极端跌幅 + + # 执行数据质量检查 + issues = check_kline_data_quality_multiple_symbols(df) + + # 输出检查结果 + for symbol, symbol_issues in issues.items(): + print(f"\n=== 检查结果 for Symbol: {symbol} ===") + for check, result in symbol_issues.items(): + print(f"\n检查点: {check}") + print(f"结果描述: {result['description']}") + if result["rows"] is not None: + print("有问题的数据行:") + print(result["rows"]) + else: + print("无有问题的数据行。") + + +def test_new(): + df = pd.read_feather(r"C:\Users\zengb\Downloads\可转债.feather") + df["vol"] = df["vol"].astype(int) + # 执行数据质量检查 + issues = check_kline_data_quality_multiple_symbols(df) + + # 输出检查结果 + for symbol, symbol_issues in issues.items(): + for check, result in symbol_issues.items(): + + if result["rows"] is not None: + print(f"\n检查点: {symbol} - {check}") + print(f"结果描述: {result['description']}") + + print("有问题的数据行:") + print(result["rows"]) + print("\n\n") diff --git a/examples/develop/weight_backtest.py b/examples/develop/weight_backtest.py index 426aae280..37a375a83 100644 --- a/examples/develop/weight_backtest.py +++ b/examples/develop/weight_backtest.py @@ -5,6 +5,9 @@ import czsc import rs_czsc import pandas as pd +import streamlit as st + +st.set_page_config(layout="wide") def test_daily_performance(): @@ -1901,9 +1904,21 @@ def test_daily_performance(): def test_weight_backtest(): """从持仓权重样例数据中回测""" dfw = pd.read_feather(r"C:\Users\zengb\Downloads\weight_example.feather") + # dfw = pd.read_feather(r"C:\Users\zengb\Downloads\btc_weight_example.feather") + + pw = czsc.WeightBacktest(dfw.copy(), digits=2, fee_rate=0.0002, n_jobs=1, weight_type="ts") + print("\n", sorted(pw.stats.items())) + print("Python 版本方法:", dir(pw)) + + rw = rs_czsc.WeightBacktest(dfw.copy(), digits=2, fee_rate=0.0002, n_jobs=4, weight_type="ts") + print("\n", sorted(rw.stats.items())) + print("RUST 版本方法:", dir(rw)) - pw = czsc.WeightBacktest(dfw.copy(), digits=2, fee_rate=0.0002, n_jobs=1) - print(sorted(pw.stats.items())) - rw = rs_czsc.WeightBacktest(dfw.copy(), digits=2, fee_rate=0.0002, n_jobs=1) - print(sorted(rw.stats.items())) +# # dfw = pd.read_feather(r"C:\Users\zengb\Downloads\weight_example.feather") +# dfw = pd.read_feather(r"A:\量化研究\BTC策略1H持仓权重和日收益241201\BTC_1H_P01-weights.feather") +# dfw = dfw[["dt", "symbol", "weight", "price"]].copy().reset_index(drop=True) +# dfw.to_feather(r"C:\Users\zengb\Downloads\btc_weight_example.feather") +# st.dataframe(dfw.tail()) +# st.write(dfw.dtypes) +# czsc.show_weight_backtest(dfw) diff --git a/requirements.txt b/requirements.txt index da4e2e270..50ac3ed25 100644 --- a/requirements.txt +++ b/requirements.txt @@ -30,5 +30,6 @@ pytz flask scipy requests_toolbelt -pandas-ta -networkx \ No newline at end of file +networkx +rs_czsc>=0.1.2 +clickhouse_connect diff --git a/setup.py b/setup.py index 69473de35..69a99147a 100644 --- a/setup.py +++ b/setup.py @@ -28,7 +28,6 @@ package_data={"": ["utils/china_calendar.feather", "utils/minutes_split.feather"]}, classifiers=[ "License :: OSI Approved :: Apache Software License", - "Programming Language :: Python :: 3.7", "Programming Language :: Python 
:: 3.8", "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", diff --git a/test/test_kline_quality.py b/test/test_kline_quality.py index c02b79eae..16f65e59f 100644 --- a/test/test_kline_quality.py +++ b/test/test_kline_quality.py @@ -1,39 +1,12 @@ import pandas as pd -from czsc.utils.kline_quality import ( - check_high_low, - check_price_gap, - check_abnormal_volume, - check_zero_volume, -) +from czsc.utils.kline_quality import check_kline_quality from test.test_analyze import read_daily -def test_check_high_low(): - df = read_daily() - df = pd.DataFrame([x.__dict__ for x in df]) - error_rate, error_klines = check_high_low(df) - assert error_rate == 0 - - -def test_check_price_gap(): - df = read_daily() - df = pd.DataFrame([x.__dict__ for x in df]) - error_rate, error_klines = check_price_gap(df) - assert round(error_rate, 4) == 0.0183 - print(error_klines) - - -def test_check_abnormal_volume(): - df = read_daily() - df = pd.DataFrame([x.__dict__ for x in df]) - error_rate, error_klines = check_abnormal_volume(df) - assert round(error_rate, 4) == 0.0306 - print(error_klines) - - def test_check_zero_volume(): df = read_daily() - df = pd.DataFrame([x.__dict__ for x in df]) - error_rate, error_klines = check_zero_volume(df) - assert error_rate == 0 - print(error_klines) + df = pd.DataFrame([bar.__dict__ for bar in df]) + df["vol"] = df["vol"].astype(int) + # 执行数据质量检查 + df = df[["symbol", "dt", "open", "close", "high", "low", "vol", "amount"]] + issues = check_kline_quality(df)