From cf512cbdac2606401a46f64f3652311ff70c77a7 Mon Sep 17 00:00:00 2001 From: Mikko Ohtamaa Date: Thu, 17 Oct 2024 10:32:06 +0200 Subject: [PATCH] Add other_data (#1062) - Add `State.other_data` where strategies can store custom variables in `decide_trades` --- tests/backtest/test_other_data.py | 171 ++++++++++++++++++++++++++++++ tradeexecutor/state/other_data.py | 102 ++++++++++++++++++ tradeexecutor/state/state.py | 4 + 3 files changed, 277 insertions(+) create mode 100644 tests/backtest/test_other_data.py create mode 100644 tradeexecutor/state/other_data.py diff --git a/tests/backtest/test_other_data.py b/tests/backtest/test_other_data.py new file mode 100644 index 000000000..fcaa42f34 --- /dev/null +++ b/tests/backtest/test_other_data.py @@ -0,0 +1,171 @@ +"""Test other_data state data structures.""" +import datetime +import random + +import pytest + +from tradingstrategy.candle import GroupedCandleUniverse +from tradingstrategy.chain import ChainId +from tradingstrategy.liquidity import GroupedLiquidityUniverse +from tradingstrategy.timebucket import TimeBucket +from tradingstrategy.universe import Universe + +from tradeexecutor.backtest.backtest_pricing import BacktestPricing +from tradeexecutor.backtest.backtest_routing import BacktestRoutingModel +from tradeexecutor.cli.log import setup_pytest_logging +from tradeexecutor.backtest.backtest_runner import run_backtest_inline +from tradeexecutor.strategy.trading_strategy_universe import TradingStrategyUniverse, create_pair_universe_from_code +from tradeexecutor.strategy.execution_context import ExecutionContext, ExecutionMode +from tradeexecutor.strategy.cycle import CycleDuration +from tradeexecutor.strategy.reserve_currency import ReserveCurrency +from tradeexecutor.state.trade import TradeExecution +from tradeexecutor.state.identifier import AssetIdentifier, TradingPairIdentifier +from tradeexecutor.strategy.pandas_trader.indicator import IndicatorSet +from tradeexecutor.strategy.pandas_trader.strategy_input import StrategyInput +from tradeexecutor.strategy.strategy_module import StrategyParameters +from tradeexecutor.strategy.tvl_size_risk import USDTVLSizeRiskModel +from tradeexecutor.testing.synthetic_ethereum_data import generate_random_ethereum_address +from tradeexecutor.testing.synthetic_exchange_data import generate_exchange, generate_simple_routing_model +from tradeexecutor.testing.synthetic_price_data import generate_ohlcv_candles, generate_tvl_candles + + + +@pytest.fixture(scope="module") +def logger(request): + """Setup test logger.""" + return setup_pytest_logging(request, mute_requests=False) + + +@pytest.fixture(scope="module") +def strategy_universe() -> TradingStrategyUniverse: + """Create ETH-USDC universe with only increasing data. + + - 1 months of data + + - Close price increase 1% every hour + + - Liquidity is a fixed 150,000 USD for the duration of the test + """ + + start_at = datetime.datetime(2021, 6, 1) + end_at = datetime.datetime(2021, 7, 1) + + # Set up fake assets + mock_chain_id = ChainId.ethereum + mock_exchange = generate_exchange( + exchange_id=random.randint(1, 1000), + chain_id=mock_chain_id, + address=generate_random_ethereum_address(), + exchange_slug="my-dex", + ) + usdc = AssetIdentifier(ChainId.ethereum.value, generate_random_ethereum_address(), "USDC", 6, 1) + weth = AssetIdentifier(ChainId.ethereum.value, generate_random_ethereum_address(), "WETH", 18, 2) + weth_usdc = TradingPairIdentifier( + weth, + usdc, + generate_random_ethereum_address(), + mock_exchange.address, + internal_id=random.randint(1, 1000), + internal_exchange_id=mock_exchange.exchange_id, + fee=0.0030, + ) + + time_bucket = TimeBucket.d1 + + pair_universe = create_pair_universe_from_code(mock_chain_id, [weth_usdc]) + + # Create 1h underlying trade signal + daily_candles = generate_ohlcv_candles( + time_bucket, + start_at, + end_at, + pair_id=weth_usdc.internal_id, + daily_drift=(1.01, 1.01), + high_drift=1.05, + low_drift=0.90, + random_seed=1, + ) + candle_universe = GroupedCandleUniverse.create_from_single_pair_dataframe(daily_candles) + universe = Universe( + time_bucket=TimeBucket.d1, + chains={mock_chain_id}, + exchanges={mock_exchange}, + pairs=pair_universe, + candles=candle_universe, + ) + universe.pairs.exchange_universe = universe.exchange_universe + + return TradingStrategyUniverse( + data_universe=universe, + reserve_assets=[usdc], + backtest_stop_loss_time_bucket=time_bucket, + ) + + +@pytest.fixture() +def routing_model(synthetic_universe) -> BacktestRoutingModel: + return generate_simple_routing_model(synthetic_universe) + + +@pytest.fixture() +def pricing_model(synthetic_universe, routing_model) -> BacktestPricing: + pricing_model = BacktestPricing( + synthetic_universe.data_universe.candles, + routing_model, + allow_missing_fees=True, + ) + return pricing_model + + +def create_indicators(timestamp: datetime.datetime, parameters: StrategyParameters, strategy_universe: TradingStrategyUniverse, execution_context: ExecutionContext): + # No indicators needed + return IndicatorSet() + + +def decide_trades(input: StrategyInput) -> list[TradeExecution]: + """Example of storing and loading custom variables.""" + + cycle = input.cycle + state = input.state + + # Saving values by cycle + state.other_data.save(cycle, "my_value", 1) + state.other_data.save(cycle, "my_value_2", [1, 2]) + state.other_data.save(cycle, "my_value_3", {1: 2}) + + if cycle >= 2: + # Loading latest values + assert state.other_data.load_latest("my_value") == 1 + assert state.other_data.load_latest("my_value_2") == [1, 2] + assert state.other_data.load_latest("my_value_3") == {1: 2} + + return [] + + +def test_other_data(strategy_universe, tmp_path): + """Test state.other_data.""" + + # Start with $1M cash, far exceeding the market size + class Parameters: + backtest_start = strategy_universe.data_universe.candles.get_timestamp_range()[0].to_pydatetime() + backtest_end = strategy_universe.data_universe.candles.get_timestamp_range()[1].to_pydatetime() + initial_cash = 1_000_000 + cycle_duration = CycleDuration.cycle_1d + + # Run the test + result = run_backtest_inline( + client=None, + decide_trades=decide_trades, + create_indicators=create_indicators, + universe=strategy_universe, + reserve_currency=ReserveCurrency.usdc, + engine_version="0.5", + parameters=StrategyParameters.from_class(Parameters), + mode=ExecutionMode.unit_testing, + ) + + # Variables are readable after the backtest + state = result.state + assert len(state.other_data.data.keys()) == 29 # We stored data for 29 decide_trades cycles + assert state.other_data.data[1]["my_value"] == 1 # We can read historic values + diff --git a/tradeexecutor/state/other_data.py b/tradeexecutor/state/other_data.py new file mode 100644 index 000000000..d33035a17 --- /dev/null +++ b/tradeexecutor/state/other_data.py @@ -0,0 +1,102 @@ +"""Storing of custom variables in the backtesting state.""" +from collections import defaultdict +from dataclasses import dataclass, field +from typing import Any, TypeAlias + +from dataclasses_json import dataclass_json + +JsonSerialisableObject: TypeAlias = Any + + +def _dict_dict(): + return defaultdict(dict) + + +@dataclass_json +@dataclass(slots=True) +class OtherData: + """Store custom variables in the backtesting state. + + - For each cycle, you can record custom variables here + + - All historical cycle values are stored + + - All values must be JSON serialisable. + + - You can then read the variables back + + - This can be used in live trade execution as well **with care**. + Because of the underlying infrastructure may crash (blockchain halt, server crash) + cycles might be skipped. + + Example of storing and loading custom variables: + + .. code-block:: python + + def decide_trades(input: StrategyInput) -> list[TradeExecution]: + cycle = input.cycle + state = input.state + + # Saving values by cycle + state.other_data.save(cycle, "my_value", 1) + state.other_data.save(cycle, "my_value_2", [1, 2]) + state.other_data.save(cycle, "my_value_3", {1: 2}) + + if cycle >= 2: + # Loading latest values + assert state.other_data.load_latest("my_value") == 1 + assert state.other_data.load_latest("my_value_2") == [1, 2] + assert state.other_data.load_latest("my_value_3") == {1: 2} + + return [] + + You can also read these variables after the backtest is complete: + + .. code-block:: + + result = run_backtest_inline( + client=None, + decide_trades=decide_trades, + create_indicators=create_indicators, + universe=strategy_universe, + reserve_currency=ReserveCurrency.usdc, + engine_version="0.5", + parameters=StrategyParameters.from_class(Parameters), + mode=ExecutionMode.unit_testing, + ) + + # Variables are readable after the backtest + state = result.state + assert len(state.other_data.data.keys()) == 29 # We stored data for 29 decide_trades cycles + assert state.other_data.data[1]["my_value"] == 1 # We can read historic values + """ + + #: Cycle number -> dict mapping + data: dict[int, JsonSerialisableObject] = field(default_factory=_dict_dict) + + def get_latest_stored_cycle(self) -> int: + """Get the cycle for which we have recorded any data. + + :return: + 0 if no data + """ + if len(self.data) == 0: + return 0 + return max(self.data.keys()) + + def save(self, cycle: int, name: str, value: JsonSerialisableObject): + """Save the value on this cycle.""" + assert type(cycle) == int, f"Got {cycle}" + assert type(name) == str, f"Got {name}" + self.data[cycle][name] = value + + def load_latest(self, name: str) -> JsonSerialisableObject | None: + """Load the latest named value from the store. + + - Take the value whatever is the last cycle + + :return: + If the last cycle did not store this var, then return `None`. + """ + latest_cycle = self.get_latest_stored_cycle() + return self.data.get(latest_cycle).get(name, None) \ No newline at end of file diff --git a/tradeexecutor/state/state.py b/tradeexecutor/state/state.py index 4a31b2e16..43cb3a534 100644 --- a/tradeexecutor/state/state.py +++ b/tradeexecutor/state/state.py @@ -16,6 +16,7 @@ from dataclasses_json import dataclass_json from dataclasses_json.core import _ExtendedEncoder +from .other_data import OtherData from .sync import Sync from .identifier import AssetIdentifier, TradingPairIdentifier, TradingPairKind from .portfolio import Portfolio @@ -183,6 +184,9 @@ class State: #: not live trading. backtest_data: BacktestData | None = None + #: Misc. backtesting variables settable by users + other_data: Optional[OtherData] = field(default_factory=OtherData) + def __repr__(self): return f""