Skip to content

Commit

Permalink
0.9.37 weight backtest 增加多进程支持
Browse files Browse the repository at this point in the history
  • Loading branch information
zengbin93 committed Nov 26, 2023
1 parent 9920ad4 commit 5e4abf0
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 10 deletions.
30 changes: 23 additions & 7 deletions czsc/traders/weight_backtest.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,12 @@
import numpy as np
import pandas as pd
import plotly.express as px
from tqdm import tqdm
from loguru import logger
from pathlib import Path
from typing import Union, AnyStr, Callable
from multiprocessing import cpu_count
from concurrent.futures import ProcessPoolExecutor
from czsc.traders.base import CzscTrader
from czsc.utils.io import save_json
from czsc.utils.stats import daily_performance, evaluate_pairs
Expand Down Expand Up @@ -158,7 +161,7 @@ class WeightBacktest:
飞书文档:https://s0cqcxuy3p.feishu.cn/wiki/Pf1fw1woQi4iJikbKJmcYToznxb
"""
version = "V231104"
version = "V231126"

def __init__(self, dfw, digits=2, **kwargs) -> None:
"""持仓权重回测
Expand Down Expand Up @@ -206,7 +209,7 @@ def __init__(self, dfw, digits=2, **kwargs) -> None:
self.fee_rate = kwargs.get('fee_rate', 0.0002)
self.dfw['weight'] = self.dfw['weight'].astype('float').round(digits)
self.symbols = list(self.dfw['symbol'].unique().tolist())
self.results = self.backtest()
self.results = self.backtest(n_jobs=kwargs.get('n_jobs', int(cpu_count() / 2)))

def get_symbol_daily(self, symbol):
"""获取某个合约的每日收益率
Expand Down Expand Up @@ -353,7 +356,13 @@ def __add_operate(dt, bar_id, volume, price, operate):
df_pairs = pd.DataFrame(pairs)
return df_pairs

def backtest(self):
def process_symbol(self, symbol):
"""处理某个合约的回测数据"""
daily = self.get_symbol_daily(symbol)
pairs = self.get_symbol_pairs(symbol)
return symbol, {"daily": daily, "pairs": pairs}

def backtest(self, n_jobs=1):
"""回测所有合约的收益率
函数计算逻辑:
Expand All @@ -369,12 +378,19 @@ def backtest(self):
4. 返回结果:将合约的等权日收益数据和绩效评价结果存储在res字典中,并将该字典作为函数的返回结果。
"""
n_jobs = min(n_jobs, cpu_count())
logger.info(f"n_jobs={n_jobs},将使用 {n_jobs} 个进程进行回测")

symbols = self.symbols
res = {}
for symbol in symbols:
daily = self.get_symbol_daily(symbol)
pairs = self.get_symbol_pairs(symbol)
res[symbol] = {"daily": daily, "pairs": pairs}
if n_jobs <= 1:
for symbol in tqdm(sorted(symbols), desc="WBT进度"):
res[symbol] = self.process_symbol(symbol)[1]
else:
with ProcessPoolExecutor(n_jobs) as pool:
for symbol, res_symbol in tqdm(pool.map(self.process_symbol, sorted(symbols)),
desc="WBT进度", total=len(symbols)):
res[symbol] = res_symbol

dret = pd.concat([v['daily'] for k, v in res.items() if k in symbols], ignore_index=True)
dret = pd.pivot_table(dret, index='date', columns='symbol', values='return').fillna(0)
Expand Down
11 changes: 8 additions & 3 deletions examples/test_offline/test_weight_backtest.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,14 @@
import czsc
import pandas as pd

assert czsc.WeightBacktest.version == "V231005"
assert czsc.WeightBacktest.version == "V231126"


def run_by_weights():
"""从持仓权重样例数据中回测"""
dfw = pd.read_feather(r"C:\Users\zengb\Desktop\231005\weight_example.feather")
wb = czsc.WeightBacktest(dfw, digits=1, fee_rate=0.0002)
dfw = pd.read_feather(r"C:\Users\zengb\Downloads\weight_example.feather")
wb = czsc.WeightBacktest(dfw, digits=1, fee_rate=0.0002, n_jobs=1)
# wb = czsc.WeightBacktest(dfw, digits=1, fee_rate=0.0002)

# ------------------------------------------------------------------------------------
# 查看绩效评价
Expand Down Expand Up @@ -38,3 +39,7 @@ def run_by_weights():
print(symbol_res)

wb.report(res_path=r"C:\Users\zengb\Desktop\231005\weight_example")


if __name__ == '__main__':
run_by_weights()

0 comments on commit 5e4abf0

Please sign in to comment.