Skip to content

Commit

Permalink
0.9.30 优化 RedisWeightsClient
Browse files Browse the repository at this point in the history
  • Loading branch information
zengbin93 committed Sep 28, 2023
1 parent dc6045a commit 289ff6d
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 7 deletions.
15 changes: 8 additions & 7 deletions czsc/traders/rwc.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,11 +49,12 @@ def __init__(self, strategy_name, redis_url, **kwargs):

self.heartbeat_thread = threading.Thread(target=self.__heartbeat, daemon=True)
self.heartbeat_thread.start()
self.key_prefix = kwargs.get("key_prefix", "Weights")

def set_metadata(self, base_freq, description, author, outsample_sdt, **kwargs):
"""设置策略元数据"""
outsample_sdt = pd.to_datetime(outsample_sdt).strftime('%Y%m%d')
meta = {'name': self.strategy_name, 'base_freq': base_freq,
meta = {'name': self.strategy_name, 'base_freq': base_freq, 'key_prefix': self.key_prefix,
'description': description, 'author': author, 'outsample_sdt': outsample_sdt,
'update_time': datetime.now().strftime('%Y-%m-%d %H:%M:%S'),
'kwargs': json.dumps(kwargs)}
Expand All @@ -79,7 +80,7 @@ def publish(self, symbol, dt, weight, price=0, ref=None, overwrite=False):
dt = pd.to_datetime(dt)

udt = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
key = f'Weights:{self.strategy_name}:{symbol}:{dt.strftime("%Y%m%d%H%M%S")}'
key = f'{self.key_prefix}:{self.strategy_name}:{symbol}:{dt.strftime("%Y%m%d%H%M%S")}'
ref = ref if ref else '{}'
ref_str = json.dumps(ref) if isinstance(ref, dict) else ref
return self.lua_publish(keys=[key], args=[1 if overwrite else 0, udt, weight, price, ref_str])
Expand All @@ -103,7 +104,7 @@ def publish_dataframe(self, df, overwrite=False, batch_size=10000):

keys, args = [], []
for row in df[['symbol', 'dt', 'weight', 'price', 'ref']].to_numpy():
key = f'Weights:{self.strategy_name}:{row[0]}:{row[1].strftime("%Y%m%d%H%M%S")}'
key = f'{self.key_prefix}:{self.strategy_name}:{row[0]}:{row[1].strftime("%Y%m%d%H%M%S")}'
keys.append(key)

args.append(row[2])
Expand Down Expand Up @@ -144,7 +145,7 @@ def get_keys(self, pattern):
def clear_all(self):
"""删除该策略所有记录"""
self.r.delete(f"{self.strategy_name}:meta")
keys = self.get_keys(f'Weights:{self.strategy_name}*')
keys = self.get_keys(f'{self.key_prefix}:{self.strategy_name}*')
if keys is not None and len(keys) > 0:
self.r.delete(*keys)

Expand Down Expand Up @@ -195,7 +196,7 @@ def register_lua_publish(client):

def get_symbols(self):
"""获取策略交易的品种列表"""
keys = self.get_keys(f'Weights:{self.strategy_name}*')
keys = self.get_keys(f'{self.key_prefix}:{self.strategy_name}*')
symbols = {x.split(":")[2] for x in keys}
return list(symbols)

Expand All @@ -204,7 +205,7 @@ def get_last_weights(self, symbols=None):
symbols = symbols if symbols else self.get_symbols()
with self.r.pipeline() as pipe:
for symbol in symbols:
pipe.hgetall(f"Weights:{self.strategy_name}:{symbol}:LAST")
pipe.hgetall(f'{self.key_prefix}:{self.strategy_name}:{symbol}:LAST')
rows = pipe.execute()

dfw = pd.DataFrame(rows)
Expand All @@ -216,7 +217,7 @@ def get_hist_weights(self, symbol, sdt, edt) -> pd.DataFrame:
"""获取单个品种的持仓权重历史数据"""
start_score = pd.to_datetime(sdt).strftime('%Y%m%d%H%M%S')
end_score = pd.to_datetime(edt).strftime('%Y%m%d%H%M%S')
model_key = f'Weights:{self.strategy_name}:{symbol}'
model_key = f'{self.key_prefix}:{self.strategy_name}:{symbol}'
key_list = self.r.zrangebyscore(model_key, start_score, end_score)

if len(key_list) == 0:
Expand Down
32 changes: 32 additions & 0 deletions examples/test_offline/test_rwc.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
def test_weights_producer():
import pandas as pd
from czsc.traders import RedisWeightsClient

redis_url = 'redis://20.205.5.**:9103/1'
rwc = RedisWeightsClient('test', redis_url, key_prefix='WeightsA')

# 如果需要清理 redis 中的数据,执行
rwc.clear_all()

# 首次写入,建议设置一些策略元数据
rwc.set_metadata(description='测试策略:仅用于读写redis测试', base_freq='1分钟', author='ZB', outsample_sdt='20220101')
print(rwc.metadata)

# 写入策略持仓权重,样例数据下载:https://s0cqcxuy3p.feishu.cn/wiki/Pf1fw1woQi4iJikbKJmcYToznxb
weights = pd.read_feather(r"C:\Users\zengb\Downloads\weight_example.feather")

# 写入单条数据
rwc.publish(**weights.iloc[0].to_dict())

# 批量写入整个dataframe;样例超300万行,写入耗时约5分钟
rwc.publish_dataframe(weights, overwrite=False, batch_size=1000000)

# 获取redis中该策略有持仓权重的品种列表
symbols = rwc.get_symbols()
print(symbols)

# 获取指定品种在某个时间段的持仓权重数据
dfw = rwc.get_hist_weights('ZZSF9001', '20210101', '20230101')

# 获取所有品种最近一个时间的持仓权重
dfr = rwc.get_last_weights(symbols=symbols)

0 comments on commit 289ff6d

Please sign in to comment.