Skip to content

Commit

Permalink
0.9.32 优化 RedisWeightsClient
Browse files Browse the repository at this point in the history
  • Loading branch information
zengbin93 committed Oct 14, 2023
1 parent b39f536 commit 24c1d85
Showing 1 changed file with 29 additions and 21 deletions.
50 changes: 29 additions & 21 deletions czsc/traders/rwc.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
class RedisWeightsClient:
"""策略持仓权重收发客户端"""

version = "V231012"
version = "V231014"

def __init__(self, strategy_name, redis_url, **kwargs):
"""
Expand Down Expand Up @@ -81,22 +81,22 @@ def metadata(self):
key = f'{self.key_prefix}:META:{self.strategy_name}'
return self.r.hgetall(key)

@property
def last_time(self):
"""获取策略最近一次发布信号的时间"""
keys = self.get_keys(f'{self.key_prefix}:{self.strategy_name}:*:LAST')
dt = None
for key in keys:
dt_ = pd.to_datetime(str(self.r.hget(key, 'dt')))
dt = dt_ if not dt else max(dt, dt_)
return dt

def get_last_time(self, symbol):
"""获取指定品种上策略最近一次发布信号的时间"""
key = f'{self.key_prefix}:{self.strategy_name}:{symbol}:LAST'
if not self.r.exists(key):
return None
return pd.to_datetime(str(self.r.hget(key, 'dt')))
def get_last_times(self, symbols=None):
"""获取所有品种上策略最近一次发布信号的时间
:param symbols: list, 品种列表, 默认为None, 即获取所有品种
:return: dict, {symbol: datetime},如{'SFIF9001': datetime(2021, 9, 24, 15, 19, 0)}
"""
if isinstance(symbols, str):
row = self.r.hgetall(f'{self.key_prefix}:{self.strategy_name}:{symbols}:LAST')
return pd.to_datetime(row['dt']) if row else None # type: ignore

symbols = symbols if symbols else self.get_symbols()
with self.r.pipeline() as pipe:
for symbol in symbols:
pipe.hgetall(f'{self.key_prefix}:{self.strategy_name}:{symbol}:LAST')
rows = pipe.execute()
return {x['symbol']: pd.to_datetime(x['dt']) for x in rows}

def publish(self, symbol, dt, weight, price=0, ref=None, overwrite=False):
"""发布单个策略持仓权重
Expand All @@ -113,7 +113,7 @@ def publish(self, symbol, dt, weight, price=0, ref=None, overwrite=False):
dt = pd.to_datetime(dt)

if not overwrite:
last_dt = self.get_last_time(symbol)
last_dt = self.get_last_times(symbol)
if last_dt is not None and dt <= last_dt:
logger.warning(f"不允许重复写入,已过滤 {symbol} {dt} 的重复信号")
return 0
Expand Down Expand Up @@ -143,9 +143,10 @@ def publish_dataframe(self, df, overwrite=False, batch_size=10000):

if not overwrite:
raw_count = len(df)
_time = self.get_last_times()
_data = []
for symbol, dfg in df.groupby('symbol'):
last_dt = self.get_last_time(symbol)
last_dt = _time.get(symbol)
if last_dt is not None:
dfg = dfg[dfg['dt'] > last_dt]
_data.append(dfg)
Expand Down Expand Up @@ -250,8 +251,13 @@ def get_symbols(self):
symbols = {x.split(":")[2] for x in keys}
return list(symbols)

def get_last_weights(self, symbols=None):
"""获取最近的持仓权重"""
def get_last_weights(self, symbols=None, ignore_zero=True):
"""获取最近的持仓权重
:param symbols: list, 品种列表
:param ignore_zero: boolean, 是否忽略权重为0的品种
:return: pd.DataFrame
"""
symbols = symbols if symbols else self.get_symbols()
with self.r.pipeline() as pipe:
for symbol in symbols:
Expand All @@ -261,6 +267,8 @@ def get_last_weights(self, symbols=None):
dfw = pd.DataFrame(rows)
dfw['weight'] = dfw['weight'].astype(float)
dfw['dt'] = pd.to_datetime(dfw['dt'])
if ignore_zero:
dfw = dfw[dfw['weight'] != 0].copy().reset_index(drop=True)
return dfw

def get_hist_weights(self, symbol, sdt, edt) -> pd.DataFrame:
Expand Down

0 comments on commit 24c1d85

Please sign in to comment.