Skip to content

Commit

Permalink
0.9.36 优化策略持仓权重发布
Browse files Browse the repository at this point in the history
  • Loading branch information
zengbin93 committed Nov 12, 2023
1 parent 8802cff commit 5eb27e3
Show file tree
Hide file tree
Showing 2 changed files with 116 additions and 28 deletions.
96 changes: 68 additions & 28 deletions czsc/traders/rwc.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
create_dt: 2023/9/24 15:19
describe: 策略持仓权重管理
"""
import os
import time
import json
import redis
Expand All @@ -14,18 +15,15 @@
from datetime import datetime


logger.disable(__name__)


class RedisWeightsClient:
"""策略持仓权重收发客户端"""

version = "V231111"
version = "V231112"

def __init__(self, strategy_name, redis_url, **kwargs):
def __init__(self, strategy_name, redis_url=None, send_heartbeat=True, **kwargs):
"""
:param strategy_name: str, 策略名
:param redis_url: str, redis连接字符串
:param redis_url: str, redis连接字符串, 默认为None, 即从环境变量 RWC_REDIS_URL 中读取
For example::
Expand All @@ -41,19 +39,27 @@ def __init__(self, strategy_name, redis_url, **kwargs):
<https://www.iana.org/assignments/uri-schemes/prov/rediss>
- ``unix://``: creates a Unix Domain Socket connection.
:param send_heartbeat: boolean, 是否发送心跳
如果为True,会在后台启动一个线程,每15秒向redis发送一次心跳,用于检测策略是否存活。
推荐在写入数据时设置为True,读取数据时设置为False,避免无用的心跳。
:param kwargs: dict, 其他参数
- key_prefix: str, redis中key的前缀,默认为 Weights
- heartbeat_prefix: str, 心跳key的前缀,默认为 heartbeat
"""
self.strategy_name = strategy_name
self.redis_url = redis_url
self.redis_url = redis_url if redis_url else os.getenv("RWC_REDIS_URL")
self.key_prefix = kwargs.get("key_prefix", "Weights")

self.heartbeat_client = redis.from_url(redis_url, decode_responses=True)
self.heartbeat_prefix = kwargs.get("heartbeat_prefix", "heartbeat")

thread_safe_pool = redis.BlockingConnectionPool.from_url(redis_url, decode_responses=True)
thread_safe_pool = redis.BlockingConnectionPool.from_url(self.redis_url, decode_responses=True)
self.r = redis.Redis(connection_pool=thread_safe_pool)
self.lua_publish = RedisWeightsClient.register_lua_publish(self.r)

if kwargs.get('send_heartbeat', True):
if send_heartbeat:
self.heartbeat_client = redis.from_url(self.redis_url, decode_responses=True)
self.heartbeat_prefix = kwargs.get("heartbeat_prefix", "heartbeat")
self.heartbeat_thread = threading.Thread(target=self.__heartbeat, daemon=True)
self.heartbeat_thread.start()

Expand Down Expand Up @@ -134,7 +140,17 @@ def publish_dataframe(self, df, overwrite=False, batch_size=10000):
"""
df = df.copy()
df['dt'] = pd.to_datetime(df['dt'])
df = df.sort_values('dt')
logger.info(f"输入数据中有 {len(df)} 条权重信号")

# 去除单个品种下相邻时间权重相同的数据
_res = []
for _, dfg in df.groupby('symbol'):
dfg = dfg.sort_values('dt', ascending=True).reset_index(drop=True)
dfg = dfg[dfg['weight'].diff().fillna(1) != 0].copy()
_res.append(dfg)
df = pd.concat(_res, ignore_index=True)
df = df.sort_values(['dt']).reset_index(drop=True)
logger.info(f"去除单个品种下相邻时间权重相同的数据后,剩余 {len(df)} 条权重信号")

if 'price' not in df.columns:
df['price'] = 0
Expand Down Expand Up @@ -251,18 +267,36 @@ def get_symbols(self):
symbols = {x.split(":")[2] for x in keys}
return list(symbols)

def get_last_weights(self, symbols=None, ignore_zero=True):
def get_last_weights(self, symbols=None, ignore_zero=True, lua=True):
"""获取最近的持仓权重
:param symbols: list, 品种列表
:param ignore_zero: boolean, 是否忽略权重为0的品种
:param lua: boolean, 是否使用 lua 脚本获取,默认为True
如果要全量获取,推荐使用 lua 脚本,速度更快;如果要获取指定 symbols,不推荐使用 lua 脚本。
:return: pd.DataFrame
"""
symbols = symbols if symbols else self.get_symbols()
with self.r.pipeline() as pipe:
for symbol in symbols:
pipe.hgetall(f'{self.key_prefix}:{self.strategy_name}:{symbol}:LAST')
rows = pipe.execute()
if lua:
lua_script = """
local keys = redis.call('KEYS', ARGV[1])
local results = {}
for i=1, #keys do
results[i] = redis.call('HGETALL', keys[i])
end
return results
"""
key_pattern = self.key_prefix + ':' + self.strategy_name + ':*:LAST'
results = self.r.eval(lua_script, 0, key_pattern)
rows = [dict(zip(r[::2], r[1::2])) for r in results] # type: ignore
if symbols:
rows = [r for r in rows if r['symbol'] in symbols]

else:
symbols = symbols if symbols else self.get_symbols()
with self.r.pipeline() as pipe:
for symbol in symbols:
pipe.hgetall(f'{self.key_prefix}:{self.strategy_name}:{symbol}:LAST')
rows = pipe.execute()

dfw = pd.DataFrame(rows)
dfw['weight'] = dfw['weight'].astype(float)
Expand Down Expand Up @@ -318,16 +352,22 @@ def get_all_weights(self, sdt=None, edt=None, ignore_zero=True) -> pd.DataFrame:
:param ignore_zero: boolean, 是否忽略权重为0的品种
:return: pd.DataFrame
"""
keys = self.get_keys(f"{self.key_prefix}:{self.strategy_name}:*:*")
if keys is None or len(keys) == 0: # type: ignore
return pd.DataFrame()
lua_script = """
local keys = redis.call('KEYS', ARGV[1])
local results = {}
for i=1, #keys do
local last_part = keys[i]:match('([^:]+)$')
if #last_part == 14 and tonumber(last_part) ~= nil then
results[#results + 1] = redis.call('HGETALL', keys[i])
end
end
return results
"""
key_pattern = self.key_prefix + ':' + self.strategy_name + ':*:*'
results = self.r.eval(lua_script, 0, key_pattern)
results = [dict(zip(r[::2], r[1::2])) for r in results] # type: ignore

keys = [x for x in keys if len(x.split(":")[-1]) == 14] # type: ignore
with self.r.pipeline() as pipe:
for key in keys:
pipe.hgetall(key)
rows = pipe.execute()
df = pd.DataFrame(rows)
df = pd.DataFrame(results)
df['dt'] = pd.to_datetime(df['dt'])
df['weight'] = df['weight'].astype(float)
df = df.sort_values(['dt', 'symbol']).reset_index(drop=True)
Expand Down
48 changes: 48 additions & 0 deletions examples/test_offline/test_rwc_v231112.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
# Load environment variables (e.g. RWC_REDIS_URL, printed below) from a local
# .env file before czsc is imported; override=True lets values from the file
# replace variables that are already set in the process environment.
# NOTE(review): the path is machine-specific — adjust it when running elsewhere.
import dotenv
dotenv.load_dotenv(r"D:\ZB\git_repo\waditu\czsc\examples\test_offline\.env", override=True)

import os
import sys
# Put the current and parent directories at the front of the import path,
# presumably so the local czsc checkout wins over an installed copy — verify.
sys.path.insert(0, ".")
sys.path.insert(0, "..")

import czsc
import pandas as pd

# Fail fast unless the imported czsc ships the client version this script targets.
assert czsc.RedisWeightsClient.version == "V231112"

# Show which Redis URL was picked up from the environment.
print(os.getenv("RWC_REDIS_URL"))


def test_writer():
    """Write-side smoke test: publish sample weights for one strategy to Redis.

    No ``redis_url`` is passed, so the client falls back to the
    ``RWC_REDIS_URL`` environment variable. The heartbeat is enabled because
    this client writes data.
    """
    writer = czsc.RedisWeightsClient('STK004_100', key_prefix='WeightsA', send_heartbeat=True)

    # On the first write it is recommended to record some strategy metadata.
    strategy_meta = {
        'description': '测试策略:仅用于读写redis测试',
        'base_freq': '日线',
        'author': '测试',
        'outsample_sdt': '20220101',
    }
    writer.set_metadata(**strategy_meta)
    print(writer.metadata)

    weights = pd.read_csv(r"C:\Users\zengb\Desktop\weights_example.csv")

    # Publish a single record: the first row of the sample file.
    first_row = weights.iloc[0].to_dict()
    writer.publish(**first_row)

    # Publish the whole dataframe in batches. Per the original author's note,
    # the sample has over one million rows and takes about three minutes to write.
    writer.publish_dataframe(weights, overwrite=False, batch_size=10000)

    # List the symbols of this strategy that have weights stored in Redis.
    print(writer.get_symbols())


def test_reader():
    """Read-side smoke test: fetch stored weights for one strategy from Redis.

    The results are only assigned, never asserted on; the test passes as long
    as every call completes without raising.
    """
    # Readers should set send_heartbeat=False; otherwise the heartbeat data
    # of the strategy becomes abnormal.
    # NOTE(review): this reads 'STK002_100' while test_writer publishes
    # 'STK004_100' — confirm the difference is intentional.
    reader = czsc.RedisWeightsClient('STK002_100', key_prefix='WeightsA', send_heartbeat=False)
    stored_symbols = reader.get_symbols()

    # Weight history of a single symbol between two dates.
    sdt, edt = '20170101', '20230101'
    hist_weights = reader.get_hist_weights('000001.SZ', sdt, edt)

    # Latest weight of every symbol, skipping symbols whose weight is 0.
    last_weights = reader.get_last_weights(ignore_zero=True)

    # Full weight history of the strategy: first with the default
    # ignore_zero setting, then with zero weights included.
    all_default = reader.get_all_weights()
    all_with_zero = reader.get_all_weights(ignore_zero=False)

0 comments on commit 5eb27e3

Please sign in to comment.