diff --git a/czsc/utils/trade.py b/czsc/utils/trade.py index f76f965a3..a91b65d32 100644 --- a/czsc/utils/trade.py +++ b/czsc/utils/trade.py @@ -32,12 +32,14 @@ def cal_trade_price(bars: Union[List[RawBar], pd.DataFrame], decimals=3, **kwarg df[f"sum_vcp_{t}"] = df['vol_close_prod'].rolling(t).sum() df[f"VWAP{t}"] = (df[f"sum_vcp_{t}"] / df[f"sum_vol_{t}"]).shift(-t) price_cols.extend([f"TWAP{t}", f"VWAP{t}"]) + df.drop(columns=[f"sum_vol_{t}", f"sum_vcp_{t}"], inplace=True) + df.drop(columns=['vol_close_prod'], inplace=True) # 用当前K线的收盘价填充交易价中的 nan 值 for price_col in price_cols: df.loc[df[price_col].isnull(), price_col] = df[df[price_col].isnull()]['close'] - df = df[['symbol', 'dt', 'open', 'close', 'high', 'low', 'vol', 'amount'] + price_cols].round(decimals) + df[price_cols] = df[price_cols].round(decimals) return df