diff --git a/czsc/traders/rwc.py b/czsc/traders/rwc.py index d1236e150..938574bb2 100644 --- a/czsc/traders/rwc.py +++ b/czsc/traders/rwc.py @@ -5,6 +5,7 @@ create_dt: 2023/9/24 15:19 describe: 策略持仓权重管理 """ +import os import time import json import redis @@ -14,18 +15,15 @@ from datetime import datetime -logger.disable(__name__) - - class RedisWeightsClient: """策略持仓权重收发客户端""" - version = "V231111" + version = "V231112" - def __init__(self, strategy_name, redis_url, **kwargs): + def __init__(self, strategy_name, redis_url=None, send_heartbeat=True, **kwargs): """ :param strategy_name: str, 策略名 - :param redis_url: str, redis连接字符串 + :param redis_url: str, redis连接字符串, 默认为None, 即从环境变量 RWC_REDIS_URL 中读取 For example:: @@ -41,19 +39,27 @@ def __init__(self, strategy_name, redis_url, **kwargs): - ``unix://``: creates a Unix Domain Socket connection. + :param send_heartbeat: boolean, 是否发送心跳 + + 如果为True,会在后台启动一个线程,每15秒向redis发送一次心跳,用于检测策略是否存活。 + 推荐在写入数据时设置为True,读取数据时设置为False,避免无用的心跳。 + + :param kwargs: dict, 其他参数 + + - key_prefix: str, redis中key的前缀,默认为 Weights + - heartbeat_prefix: str, 心跳key的前缀,默认为 heartbeat """ self.strategy_name = strategy_name - self.redis_url = redis_url + self.redis_url = redis_url if redis_url else os.getenv("RWC_REDIS_URL") self.key_prefix = kwargs.get("key_prefix", "Weights") - self.heartbeat_client = redis.from_url(redis_url, decode_responses=True) - self.heartbeat_prefix = kwargs.get("heartbeat_prefix", "heartbeat") - - thread_safe_pool = redis.BlockingConnectionPool.from_url(redis_url, decode_responses=True) + thread_safe_pool = redis.BlockingConnectionPool.from_url(self.redis_url, decode_responses=True) self.r = redis.Redis(connection_pool=thread_safe_pool) self.lua_publish = RedisWeightsClient.register_lua_publish(self.r) - if kwargs.get('send_heartbeat', True): + if send_heartbeat: + self.heartbeat_client = redis.from_url(self.redis_url, decode_responses=True) + self.heartbeat_prefix = kwargs.get("heartbeat_prefix", "heartbeat") self.heartbeat_thread = threading.Thread(target=self.__heartbeat, daemon=True) self.heartbeat_thread.start() @@ -134,7 +140,17 @@ def publish_dataframe(self, df, overwrite=False, batch_size=10000): """ df = df.copy() df['dt'] = pd.to_datetime(df['dt']) - df = df.sort_values('dt') + logger.info(f"输入数据中有 {len(df)} 条权重信号") + + # 去除单个品种下相邻时间权重相同的数据 + _res = [] + for _, dfg in df.groupby('symbol'): + dfg = dfg.sort_values('dt', ascending=True).reset_index(drop=True) + dfg = dfg[dfg['weight'].diff().fillna(1) != 0].copy() + _res.append(dfg) + df = pd.concat(_res, ignore_index=True) + df = df.sort_values(['dt']).reset_index(drop=True) + logger.info(f"去除单个品种下相邻时间权重相同的数据后,剩余 {len(df)} 条权重信号") if 'price' not in df.columns: df['price'] = 0 @@ -251,18 +267,36 @@ def get_symbols(self): symbols = {x.split(":")[2] for x in keys} return list(symbols) - def get_last_weights(self, symbols=None, ignore_zero=True): + def get_last_weights(self, symbols=None, ignore_zero=True, lua=True): """获取最近的持仓权重 :param symbols: list, 品种列表 :param ignore_zero: boolean, 是否忽略权重为0的品种 + :param lua: boolean, 是否使用 lua 脚本获取,默认为True + 如果要全量获取,推荐使用 lua 脚本,速度更快;如果要获取指定 symbols,不推荐使用 lua 脚本。 :return: pd.DataFrame """ - symbols = symbols if symbols else self.get_symbols() - with self.r.pipeline() as pipe: - for symbol in symbols: - pipe.hgetall(f'{self.key_prefix}:{self.strategy_name}:{symbol}:LAST') - rows = pipe.execute() + if lua: + lua_script = """ + local keys = redis.call('KEYS', ARGV[1]) + local results = {} + for i=1, #keys do + results[i] = redis.call('HGETALL', keys[i]) + end + return results + """ + key_pattern = self.key_prefix + ':' + self.strategy_name + ':*:LAST' + results = self.r.eval(lua_script, 0, key_pattern) + rows = [dict(zip(r[::2], r[1::2])) for r in results] # type: ignore + if symbols: + rows = [r for r in rows if r['symbol'] in symbols] + + else: + symbols = symbols if symbols else self.get_symbols() + with self.r.pipeline() as pipe: + for symbol in symbols: + pipe.hgetall(f'{self.key_prefix}:{self.strategy_name}:{symbol}:LAST') + rows = pipe.execute() dfw = pd.DataFrame(rows) dfw['weight'] = dfw['weight'].astype(float) @@ -318,16 +352,22 @@ def get_all_weights(self, sdt=None, edt=None, ignore_zero=True) -> pd.DataFrame: :param ignore_zero: boolean, 是否忽略权重为0的品种 :return: pd.DataFrame """ - keys = self.get_keys(f"{self.key_prefix}:{self.strategy_name}:*:*") - if keys is None or len(keys) == 0: # type: ignore - return pd.DataFrame() + lua_script = """ + local keys = redis.call('KEYS', ARGV[1]) + local results = {} + for i=1, #keys do + local last_part = keys[i]:match('([^:]+)$') + if #last_part == 14 and tonumber(last_part) ~= nil then + results[#results + 1] = redis.call('HGETALL', keys[i]) + end + end + return results + """ + key_pattern = self.key_prefix + ':' + self.strategy_name + ':*:*' + results = self.r.eval(lua_script, 0, key_pattern) + results = [dict(zip(r[::2], r[1::2])) for r in results] # type: ignore - keys = [x for x in keys if len(x.split(":")[-1]) == 14] # type: ignore - with self.r.pipeline() as pipe: - for key in keys: - pipe.hgetall(key) - rows = pipe.execute() - df = pd.DataFrame(rows) + df = pd.DataFrame(results) df['dt'] = pd.to_datetime(df['dt']) df['weight'] = df['weight'].astype(float) df = df.sort_values(['dt', 'symbol']).reset_index(drop=True) diff --git a/examples/test_offline/test_rwc_v231112.py b/examples/test_offline/test_rwc_v231112.py new file mode 100644 index 000000000..bcb275916 --- /dev/null +++ b/examples/test_offline/test_rwc_v231112.py @@ -0,0 +1,48 @@ +import dotenv +dotenv.load_dotenv(r"D:\ZB\git_repo\waditu\czsc\examples\test_offline\.env", override=True) + +import os +import sys +sys.path.insert(0, ".") +sys.path.insert(0, "..") + +import czsc +import pandas as pd + +assert czsc.RedisWeightsClient.version == "V231112" + +print(os.getenv("RWC_REDIS_URL")) + + +def test_writer(): + rwc = czsc.RedisWeightsClient('STK004_100', key_prefix='WeightsA', send_heartbeat=True) + # 首次写入,建议设置一些策略元数据 + rwc.set_metadata(description='测试策略:仅用于读写redis测试', base_freq='日线', author='测试', outsample_sdt='20220101') + print(rwc.metadata) + + dfw = pd.read_csv(r"C:\Users\zengb\Desktop\weights_example.csv") + # 写入单条数据 + rwc.publish(**dfw.iloc[0].to_dict()) + + # 批量写入整个dataframe;样例超100万行,写入耗时约3分钟 + rwc.publish_dataframe(dfw, overwrite=False, batch_size=10000) + + # 获取redis中该策略有持仓权重的品种列表 + symbols = rwc.get_symbols() + print(symbols) + + +def test_reader(): + # 读取redis中的数据:send_heartbeat 推荐设置为 False,否则导致心跳数据异常 + rwc = czsc.RedisWeightsClient('STK002_100', key_prefix='WeightsA', send_heartbeat=False) + symbols = rwc.get_symbols() + + # 读取单个品种的持仓历史 + df = rwc.get_hist_weights('000001.SZ', '20170101', '20230101') + + # 读取所有品种最近一个时间的持仓权重,忽略权重为0的品种 + df1 = rwc.get_last_weights(ignore_zero=True) + + # 读取策略全部持仓权重历史 + dfa1 = rwc.get_all_weights() + dfa2 = rwc.get_all_weights(ignore_zero=False)