Skip to content

Commit

Permalink
0.9.60 优化 ta 模块
Browse files Browse the repository at this point in the history
  • Loading branch information
zengbin93 committed Oct 5, 2024
1 parent c5dbc9c commit 567b930
Show file tree
Hide file tree
Showing 4 changed files with 169 additions and 100 deletions.
1 change: 1 addition & 0 deletions czsc/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from czsc.strategies import CzscStrategyBase, CzscJsonStrategy
from czsc.sensors import holds_concepts_effect, CTAResearch, EventMatchSensor
from czsc.sensors.feature import FixedNumberSelector
from czsc.utils import ta
from czsc.traders import (
CzscTrader,
CzscSignals,
Expand Down
102 changes: 93 additions & 9 deletions czsc/utils/ta.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,12 +50,12 @@ def EMA(close: np.array, timeperiod=5):
return np.array(res, dtype=np.double).round(4)


def MACD(close: np.array, fastperiod=12, slowperiod=26, signalperiod=9):
def MACD(real: np.array, fastperiod=12, slowperiod=26, signalperiod=9):
"""MACD 异同移动平均线
https://baike.baidu.com/item/MACD%E6%8C%87%E6%A0%87/6271283
:param close: np.array
收盘价序列
:param real: np.array
价格序列
:param fastperiod: int
快周期,默认值 12
:param slowperiod: int
Expand All @@ -65,8 +65,8 @@ def MACD(close: np.array, fastperiod=12, slowperiod=26, signalperiod=9):
:return: (np.array, np.array, np.array)
diff, dea, macd
"""
ema12 = EMA(close, timeperiod=fastperiod)
ema26 = EMA(close, timeperiod=slowperiod)
ema12 = EMA(real, timeperiod=fastperiod)
ema26 = EMA(real, timeperiod=slowperiod)
diff = ema12 - ema26
dea = EMA(diff, timeperiod=signalperiod)
macd = (diff - dea) * 2
Expand Down Expand Up @@ -146,7 +146,7 @@ def RSQ(close: [np.array, list]) -> float:
return round(rsq, 4)


def plus_di(high, low, close, timeperiod=14):
def PLUS_DI(high, low, close, timeperiod=14):
"""
Calculate Plus Directional Indicator (PLUS_DI) manually.
Expand Down Expand Up @@ -179,7 +179,7 @@ def plus_di(high, low, close, timeperiod=14):
return plus_di_


def minus_di(high, low, close, timeperiod=14):
def MINUS_DI(high, low, close, timeperiod=14):
"""
Calculate Minus Directional Indicator (MINUS_DI) manually.
Expand Down Expand Up @@ -213,7 +213,7 @@ def minus_di(high, low, close, timeperiod=14):
return minus_di_


def atr(high, low, close, timeperiod=14):
def ATR(high, low, close, timeperiod=14):
"""
Calculate Average True Range (ATR).
Expand All @@ -234,7 +234,6 @@ def atr(high, low, close, timeperiod=14):

# Calculate ATR
atr_ = tr.rolling(window=timeperiod).mean()

return atr_


Expand Down Expand Up @@ -338,3 +337,88 @@ def LINEARREG_ANGLE(real, timeperiod=14):
angles[today] = np.arctan(m) * (180.0 / np.pi)

return angles


def DOUBLE_SMA_LS(series: pd.Series, n=5, m=20, **kwargs):
"""双均线多空
:param series: str, 数据源字段
:param n: int, 短周期
:param m: int, 长周期
"""
assert n < m, "短周期必须小于长周期"
return np.sign(series.rolling(window=n).mean() - series.rolling(window=m).mean()).fillna(0)


def BOLL_LS(series: pd.Series, n=5, s=0.1, **kwargs):
"""布林线多空
series 大于 n 周期均线 + s * n周期标准差,做多;小于 n 周期均线 - s * n周期标准差,做空
:param series: str, 数据源字段
:param n: int, 短周期
:param s: int, 波动率的倍数,默认为 0.1
"""
sm = series.rolling(window=n).mean()
sd = series.rolling(window=n).std()
return np.where(series > sm + s * sd, 1, np.where(series < sm - s * sd, -1, 0))


def SMA_MIN_MAX_SCALE(series: pd.Series, timeperiod=5, window=5, **kwargs):
"""均线的最大最小值归一化
:param series: str, 数据源字段
:param timeperiod: int, 均线周期
:param window: int, 窗口大小
"""
sm = series.rolling(window=timeperiod).mean()
sm_min = sm.rolling(window=window).min()
sm_max = sm.rolling(window=window).max()
res = (sm - sm_min) / (sm_max - sm_min)
res = res.fillna(0) * 2 - 1
return res


def RS_VOLATILITY(df: pd.DataFrame, timeperiod=30, **kwargs):
"""RS 波动率,值越大,波动越大
:param df: str, 标准K线数据
:param timeperiod: int, 周期
"""
log_h_c = np.log(df["high"] / df["close"])
log_h_o = np.log(df["high"] / df["open"])
log_l_c = np.log(df["low"] / df["close"])
log_l_o = np.log(df["low"] / df["open"])

x = log_h_c * log_h_o + log_l_c * log_l_o
res = np.sqrt(x.rolling(window=timeperiod).mean())
return res


def PK_VOLATILITY(df: pd.DataFrame, timeperiod=30, **kwargs):
"""PK 波动率,值越大,波动越大
:param df: str, 标准K线数据
:param timeperiod: int, 周期
"""
log_h_l = np.log(df["high"] / df["low"]).pow(2)
log_hl_mean = log_h_l.rolling(window=timeperiod).sum() / (4 * timeperiod * np.log(2))
res = np.sqrt(log_hl_mean)
return res


try:
import talib as ta

SMA = ta.SMA
EMA = ta.EMA
MACD = ta.MACD
LINEARREG_ANGLE = ta.LINEARREG_ANGLE
except ImportError:
SMA = SMA
EMA = EMA
MACD = MACD
print(
f"ta-lib 没有正确安装,将使用自定义分析函数。建议安装 ta-lib,可以大幅提升计算速度。"
f"请参考安装教程 https://blog.csdn.net/qaz2134560/article/details/98484091"
)
86 changes: 0 additions & 86 deletions czsc/utils/ta1.py

This file was deleted.

80 changes: 75 additions & 5 deletions examples/develop/对比numpy与talib的速度.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,89 @@
import sys
import time

sys.path.insert(0, r"A:\ZB\git_repo\waditu\czsc")
import talib
from czsc.utils import ta
from czsc.connectors import cooperation as coo


df = coo.get_raw_bars(symbol="SFIC9001", freq="30分钟", fq="后复权", sdt="20100101", edt="20210301", raw_bars=False)
df = coo.get_raw_bars(symbol="SFIC9001", freq="日线", fq="后复权", sdt="20100101", edt="20210301", raw_bars=False)


def test_with_numpy():
def test_vs_volatility():
df1 = df.copy()
df1["x"] = ta.LINEARREG_ANGLE(df["close"].values, 10)
df1["STD"] = df["close"].pct_change().rolling(30).std()
df1["RS"] = ta.RS_VOLATILITY(df1, timeperiod=30)
df1["PK"] = ta.PK_VOLATILITY(df1, timeperiod=30)
df1[["STD", "RS", "PK"]].corr()


def test_with_talib():
def test_compare_LINEARREG_ANGLE():
df1 = df.copy()
df1["x"] = talib.LINEARREG_ANGLE(df["close"].values, 10)
s1 = time.time()
df1["x1"] = ta.LINEARREG_ANGLE(df["close"].values, 10)
e1 = time.time() - s1

s2 = time.time()
df1["x2"] = talib.LINEARREG_ANGLE(df["close"].values, 10)
e2 = time.time() - s2

print(f"计算时间差异,ta: {e1}, talib: {e2};相差:{e1 - e2}")
df1["diff"] = df1["x1"] - df1["x2"]
assert df1["diff"].abs().max() < 1e-6
print(df1["diff"].abs().max())


def test_compare_CCI():
# 数据对不上,需要调整
df1 = df.copy()
s1 = time.time()
df1["x1"] = ta.CCI(df1["high"].values, df1["low"].values, df1["close"].values, timeperiod=14)
e1 = time.time() - s1

s2 = time.time()
df1["x2"] = talib.CCI(df1["high"].values, df1["low"].values, df1["close"].values, timeperiod=14)
e2 = time.time() - s2

print(f"计算时间差异,ta: {e1}, talib: {e2};相差:{e1 - e2}")
df1["diff"] = df1["x1"] - df1["x2"]
print(df1["diff"].abs().describe())

assert df1["diff"].abs().max() < 1e-6


def test_compare_MFI():
# 数据对不上,需要调整
df1 = df.copy()
s1 = time.time()
df1["x1"] = ta.MFI(df1["high"], df1["low"], df1["close"], df1["vol"], timeperiod=14)
e1 = time.time() - s1

s2 = time.time()
df1["x2"] = talib.MFI(df1["high"].values, df1["low"].values, df1["close"].values, df1["vol"].values, timeperiod=14)
e2 = time.time() - s2

print(f"计算时间差异,ta: {e1}, talib: {e2};相差:{e1 - e2}")
df1["diff"] = df1["x1"] - df1["x2"]
print(df1["diff"].abs().describe())

assert df1["diff"].abs().max() < 1e-6


def test_compare_PLUS_DI():
# 数据对不上,需要调整
df1 = coo.get_raw_bars(symbol="SFIC9001", freq="日线", fq="后复权", sdt="20100101", edt="20210301", raw_bars=False)

s1 = time.time()
df1["x1"] = ta.PLUS_DI(df1["high"], df1["low"], df1["close"], timeperiod=14)
e1 = time.time() - s1

s2 = time.time()
df1["x2"] = talib.PLUS_DI(df1["high"].values, df1["low"].values, df1["close"].values, timeperiod=14)
e2 = time.time() - s2

print(f"计算时间差异,ta: {e1}, talib: {e2};相差:{e1 - e2}")
df1["diff"] = df1["x1"] - df1["x2"]
print(df1["diff"].abs().describe())

assert df1["diff"].abs().max() < 1e-6

0 comments on commit 567b930

Please sign in to comment.