From 2e6dfa53fae284fc8aee5e69a4c70b4de12a34c6 Mon Sep 17 00:00:00 2001 From: zengbin93 Date: Sun, 26 Nov 2023 16:04:22 +0800 Subject: [PATCH] 0.9.37 fix check bi --- czsc/analyze.py | 78 +++++++++++++++---------- examples/test_offline/test_update_bi.py | 25 ++++++++ 2 files changed, 73 insertions(+), 30 deletions(-) create mode 100644 examples/test_offline/test_update_bi.py diff --git a/czsc/analyze.py b/czsc/analyze.py index 6455bac3f..b2d5c9718 100644 --- a/czsc/analyze.py +++ b/czsc/analyze.py @@ -75,7 +75,23 @@ def remove_include(k1: NewBar, k2: NewBar, k3: RawBar): def check_fx(k1: NewBar, k2: NewBar, k3: NewBar): - """查找分型""" + """查找分型 + + 函数计算逻辑: + + 1. 如果第二个`NewBar`对象的最高价和最低价都高于第一个和第三个`NewBar`对象的对应价格,那么它被认为是顶分型(G)。 + 在这种情况下,函数会创建一个新的`FX`对象,其标记为`Mark.G`,并将其赋值给`fx`。 + + 2. 如果第二个`NewBar`对象的最高价和最低价都低于第一个和第三个`NewBar`对象的对应价格,那么它被认为是底分型(D)。 + 在这种情况下,函数会创建一个新的`FX`对象,其标记为`Mark.D`,并将其赋值给`fx`。 + + 3. 函数最后返回`fx`,如果没有找到分型,`fx`将为`None`。 + + :param k1: 第一个`NewBar`对象 + :param k2: 第二个`NewBar`对象 + :param k3: 第三个`NewBar`对象 + :return: `FX`对象或`None` + """ fx = None if k1.high < k2.high > k3.high and k1.low < k2.low > k3.low: fx = FX(symbol=k1.symbol, dt=k2.dt, mark=Mark.G, high=k2.high, @@ -89,10 +105,24 @@ def check_fx(k1: NewBar, k2: NewBar, k3: NewBar): def check_fxs(bars: List[NewBar]) -> List[FX]: - """输入一串无包含关系K线,查找其中所有分型""" + """输入一串无包含关系K线,查找其中所有分型 + + 函数的主要步骤: + + 1. 创建一个空列表`fxs`用于存储找到的分型。 + 2. 遍历`bars`列表中的每个元素(除了第一个和最后一个),并对每三个连续的`NewBar`对象调用`check_fx`函数。 + 3. 如果`check_fx`函数返回一个`FX`对象,检查它的标记是否与`fxs`列表中最后一个`FX`对象的标记相同。如果相同,记录一个错误日志。 + 如果不同,将这个`FX`对象添加到`fxs`列表中。 + 4. 最后返回`fxs`列表,它包含了`bars`列表中所有找到的分型。 + + 这个函数的主要目的是找出`bars`列表中所有的顶分型和底分型,并确保它们是交替出现的。如果发现连续的两个分型标记相同,它会记录一个错误日志。 + + :param bars: 无包含关系K线列表 + :return: 分型列表 + """ fxs = [] - for i in range(1, len(bars)-1): - fx = check_fx(bars[i-1], bars[i], bars[i+1]) + for i in range(1, len(bars) - 1): + fx = check_fx(bars[i - 1], bars[i], bars[i + 1]) if isinstance(fx, FX): # 默认情况下,fxs本身是顶底交替的,但是对于一些特殊情况下不是这样; 临时强制要求fxs序列顶底交替 if len(fxs) >= 2 and fx.mark == fxs[-1].mark: @@ -115,32 +145,20 @@ def check_bi(bars: List[NewBar], benchmark=None): return None, bars fx_a = fxs[0] - try: - if fxs[0].mark == Mark.D: - direction = Direction.Up - fxs_b = [x for x in fxs if x.mark == Mark.G and x.dt > fx_a.dt and x.fx > fx_a.fx] - if not fxs_b: - return None, bars - - fx_b = fxs_b[0] - for fx in fxs_b[1:]: - if fx.high >= fx_b.high: - fx_b = fx - - elif fxs[0].mark == Mark.G: - direction = Direction.Down - fxs_b = [x for x in fxs if x.mark == Mark.D and x.dt > fx_a.dt and x.fx < fx_a.fx] - if not fxs_b: - return None, bars - - fx_b = fxs_b[0] - for fx in fxs_b[1:]: - if fx.low <= fx_b.low: - fx_b = fx - else: - raise ValueError - except Exception as e: - logger.exception(f"笔识别错误: {e}") + if fx_a.mark == Mark.D: + direction = Direction.Up + fxs_b = (x for x in fxs if x.mark == Mark.G and x.dt > fx_a.dt and x.fx > fx_a.fx) + fx_b = max(fxs_b, key=lambda fx: fx.high, default=None) + + elif fx_a.mark == Mark.G: + direction = Direction.Down + fxs_b = (x for x in fxs if x.mark == Mark.D and x.dt > fx_a.dt and x.fx < fx_a.fx) + fx_b = min(fxs_b, key=lambda fx: fx.low, default=None) + + else: + raise ValueError + + if fx_b is None: return None, bars bars_a = [x for x in bars if fx_a.elements[0].dt <= x.dt <= fx_b.elements[2].dt] diff --git a/examples/test_offline/test_update_bi.py b/examples/test_offline/test_update_bi.py new file mode 100644 index 000000000..38b156648 --- /dev/null +++ b/examples/test_offline/test_update_bi.py @@ -0,0 +1,25 @@ +import sys +sys.path.insert(0, '.') +sys.path.insert(0, '..') +sys.path.insert(0, '../..') +import pandas as pd +import os +from czsc import RawBar, Freq, CZSC, welcome + +welcome() + + +def read_daily(): + file_kline = os.path.join(r"D:\ZB\git_repo\waditu\czsc\test", "data/000001.SH_D.csv") + kline = pd.read_csv(file_kline, encoding="utf-8") + kline['amount'] = kline['close'] * kline['vol'] + + kline.loc[:, "dt"] = pd.to_datetime(kline.dt) + bars = [RawBar(symbol=row['symbol'], id=i, freq=Freq.D, open=row['open'], dt=row['dt'], + close=row['close'], high=row['high'], low=row['low'], vol=row['vol'], amount=row['amount']) + for i, row in kline.iterrows()] + return bars + +bars = read_daily() +c = CZSC(bars) +# %timeit c = CZSC(bars) \ No newline at end of file