diff --git a/tests/backtest/test_indicator.py b/tests/backtest/test_indicator.py index 213d739b9..e8e2f9a83 100644 --- a/tests/backtest/test_indicator.py +++ b/tests/backtest/test_indicator.py @@ -7,14 +7,14 @@ import pandas as pd import pandas_ta import pytest - +from pyasn1_modules.rfc3779 import id_pe_ipAddrBlocks from tradeexecutor.state.identifier import AssetIdentifier, TradingPairIdentifier from tradeexecutor.strategy.execution_context import ExecutionContext, unit_test_execution_context, unit_test_trading_execution_context from tradeexecutor.strategy.pandas_trader.indicator import ( IndicatorSet, DiskIndicatorStorage, IndicatorDefinition, IndicatorFunctionSignatureMismatch, calculate_and_load_indicators, IndicatorKey, IndicatorSource, IndicatorDependencyResolver, - IndicatorCalculationFailed, MemoryIndicatorStorage, calculate_and_load_indicators_inline, calculate_and_load_indicators_inline, + IndicatorCalculationFailed, MemoryIndicatorStorage, calculate_and_load_indicators_inline, ) from tradeexecutor.strategy.pandas_trader.strategy_input import StrategyInputIndicators from tradeexecutor.strategy.parameters import StrategyParameters @@ -793,9 +793,20 @@ def regime( fast_sma: pd.Series = dependency_resolver.get_indicator_data("fast_sma", pair=pair, parameters={"length": length}) return close > fast_sma + def multipair(universe: TradingStrategyUniverse, dependency_resolver: IndicatorDependencyResolver) -> pd.DataFrame: + # Test multipair data resolution + series = dependency_resolver.get_indicator_data_pairs_combined("regime") + assert isinstance(series.index, pd.MultiIndex) + assert isinstance(series, pd.Series) + # Change from pd.Series to pd.DataFrame with column "value" + df = series.to_frame(name='value') + assert df.columns == ["value"] + return df + def create_indicators(parameters: StrategyParameters, indicators: IndicatorSet, strategy_universe: TradingStrategyUniverse, execution_context: ExecutionContext): indicators.add("fast_sma", pandas_ta.sma, {"length": parameters.fast_sma}, order=1) indicators.add("regime", regime, {"length": parameters.fast_sma}, order=2) + indicators.add("multipair", multipair, {}, IndicatorSource.strategy_universe, order=3) class MyParameters: fast_sma = 20 @@ -815,7 +826,7 @@ class MyParameters: wbtc_usdc = strategy_universe.get_pair_by_human_description((ChainId.ethereum, exchange.exchange_slug, "WBTC", "USDC")) keys = list(indicator_result.keys()) - keys = sorted(keys, key=lambda k: (k.pair.internal_id, k.definition.name)) # Ensure we read set in deterministic order + keys = sorted(keys, key=lambda k: (k.pair.internal_id if k.pair else 9_999_999, k.definition.name)) # Ensure we read set in deterministic order # Check our pair x indicator matrix assert keys[0].pair== weth_usdc @@ -831,7 +842,7 @@ class MyParameters: for result in indicator_result.values(): assert not result.cached - assert isinstance(result.data, pd.Series) + assert isinstance(result.data, (pd.Series, pd.DataFrame)) assert len(result.data) > 0 # Run with higher workers count should still work since it should force single thread diff --git a/tradeexecutor/strategy/pandas_trader/indicator.py b/tradeexecutor/strategy/pandas_trader/indicator.py index d0ce06b95..93a3f8d7b 100644 --- a/tradeexecutor/strategy/pandas_trader/indicator.py +++ b/tradeexecutor/strategy/pandas_trader/indicator.py @@ -1343,6 +1343,71 @@ def match_indicator( return result + def get_indicator_data_pairs_combined( + self, + name: str, + ) -> pd.Series: + """Get a DataFrame that contains indicator data for all pairs combined. + + - Allows to access the indicator data for all pairs as a combined dataframe. + + Example: + + .. code-block:: python + + def regime( + close: pd.Series, + pair: TradingPairIdentifier, + length: int, + dependency_resolver: IndicatorDependencyResolver, + ) -> pd.Series: + fast_sma: pd.Series = dependency_resolver.get_indicator_data("fast_sma", pair=pair, parameters={"length": length}) + return close > fast_sma + + def multipair(universe: TradingStrategyUniverse, dependency_resolver: IndicatorDependencyResolver) -> pd.DataFrame: + # Test multipair data resolution + series = dependency_resolver.get_indicator_data_pairs_combined("regime") + assert isinstance(series.index, pd.MultiIndex) + assert isinstance(series, pd.Series) + # Change from pd.Series to pd.DataFrame with column "value" + df = series.to_frame(name='value') + assert df.columns == ["value"] + return df + + def create_indicators(parameters: StrategyParameters, indicators: IndicatorSet, strategy_universe: TradingStrategyUniverse, execution_context: ExecutionContext): + indicators.add("regime", regime, {"length": parameters.fast_sma}, order=2) + indicators.add("multipair", multipair, {}, IndicatorSource.strategy_universe, order=3) + + Output: + + .. code-block:: text + + pair_id timestamp + 1 2021-06-01 False + 2021-06-02 False + 2021-06-03 False + 2021-06-04 False + 2021-06-05 False + ... + 2 2021-12-27 True + 2021-12-28 True + 2021-12-29 False + 2021-12-30 False + 2021-12-31 False + + :param name: + An indicator that was previously calculated by its `order`. + + :return: + DataFrame with MultiIndex (pair_id, timestamp) + """ + + series_map = {pair.internal_id: self.get_indicator_data(name, pair=pair) for pair in self.strategy_universe.iterate_pairs()} + series_list = list(series_map.values()) + pair_ids = list(series_map.keys()) + combined = pd.concat(series_list, keys=pair_ids, names=['pair_id', 'timestamp']) + return combined + def get_indicator_data( self, name: str, diff --git a/tradeexecutor/utils/binance.py b/tradeexecutor/utils/binance.py index b0ee171e2..2b5ce5dad 100644 --- a/tradeexecutor/utils/binance.py +++ b/tradeexecutor/utils/binance.py @@ -107,7 +107,7 @@ def fetch_binance_dataset( df, {symbol: pair for symbol, pair in zip(symbols, pairs)} ) - candle_df["pair_id"].replace(spot_symbol_map, inplace=True) + candle_df["pair_id"] = candle_df["pair_id"].replace(spot_symbol_map) candle_universe, stop_loss_candle_universe = load_candle_universe_from_dataframe( df=candle_df,