Skip to content

Commit

Permalink
Add other_data (#1062)
Browse files Browse the repository at this point in the history
- Add `State.other_data` where strategies can store custom variables in `decide_trades`
  • Loading branch information
miohtama authored Oct 17, 2024
1 parent b903c13 commit cf512cb
Show file tree
Hide file tree
Showing 3 changed files with 277 additions and 0 deletions.
171 changes: 171 additions & 0 deletions tests/backtest/test_other_data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,171 @@
"""Test other_data state data structures."""
import datetime
import random

import pytest

from tradingstrategy.candle import GroupedCandleUniverse
from tradingstrategy.chain import ChainId
from tradingstrategy.liquidity import GroupedLiquidityUniverse
from tradingstrategy.timebucket import TimeBucket
from tradingstrategy.universe import Universe

from tradeexecutor.backtest.backtest_pricing import BacktestPricing
from tradeexecutor.backtest.backtest_routing import BacktestRoutingModel
from tradeexecutor.cli.log import setup_pytest_logging
from tradeexecutor.backtest.backtest_runner import run_backtest_inline
from tradeexecutor.strategy.trading_strategy_universe import TradingStrategyUniverse, create_pair_universe_from_code
from tradeexecutor.strategy.execution_context import ExecutionContext, ExecutionMode
from tradeexecutor.strategy.cycle import CycleDuration
from tradeexecutor.strategy.reserve_currency import ReserveCurrency
from tradeexecutor.state.trade import TradeExecution
from tradeexecutor.state.identifier import AssetIdentifier, TradingPairIdentifier
from tradeexecutor.strategy.pandas_trader.indicator import IndicatorSet
from tradeexecutor.strategy.pandas_trader.strategy_input import StrategyInput
from tradeexecutor.strategy.strategy_module import StrategyParameters
from tradeexecutor.strategy.tvl_size_risk import USDTVLSizeRiskModel
from tradeexecutor.testing.synthetic_ethereum_data import generate_random_ethereum_address
from tradeexecutor.testing.synthetic_exchange_data import generate_exchange, generate_simple_routing_model
from tradeexecutor.testing.synthetic_price_data import generate_ohlcv_candles, generate_tvl_candles



@pytest.fixture(scope="module")
def logger(request):
"""Setup test logger."""
return setup_pytest_logging(request, mute_requests=False)


@pytest.fixture(scope="module")
def strategy_universe() -> TradingStrategyUniverse:
"""Create ETH-USDC universe with only increasing data.
- 1 months of data
- Close price increase 1% every hour
- Liquidity is a fixed 150,000 USD for the duration of the test
"""

start_at = datetime.datetime(2021, 6, 1)
end_at = datetime.datetime(2021, 7, 1)

# Set up fake assets
mock_chain_id = ChainId.ethereum
mock_exchange = generate_exchange(
exchange_id=random.randint(1, 1000),
chain_id=mock_chain_id,
address=generate_random_ethereum_address(),
exchange_slug="my-dex",
)
usdc = AssetIdentifier(ChainId.ethereum.value, generate_random_ethereum_address(), "USDC", 6, 1)
weth = AssetIdentifier(ChainId.ethereum.value, generate_random_ethereum_address(), "WETH", 18, 2)
weth_usdc = TradingPairIdentifier(
weth,
usdc,
generate_random_ethereum_address(),
mock_exchange.address,
internal_id=random.randint(1, 1000),
internal_exchange_id=mock_exchange.exchange_id,
fee=0.0030,
)

time_bucket = TimeBucket.d1

pair_universe = create_pair_universe_from_code(mock_chain_id, [weth_usdc])

# Create 1h underlying trade signal
daily_candles = generate_ohlcv_candles(
time_bucket,
start_at,
end_at,
pair_id=weth_usdc.internal_id,
daily_drift=(1.01, 1.01),
high_drift=1.05,
low_drift=0.90,
random_seed=1,
)
candle_universe = GroupedCandleUniverse.create_from_single_pair_dataframe(daily_candles)
universe = Universe(
time_bucket=TimeBucket.d1,
chains={mock_chain_id},
exchanges={mock_exchange},
pairs=pair_universe,
candles=candle_universe,
)
universe.pairs.exchange_universe = universe.exchange_universe

return TradingStrategyUniverse(
data_universe=universe,
reserve_assets=[usdc],
backtest_stop_loss_time_bucket=time_bucket,
)


@pytest.fixture()
def routing_model(synthetic_universe) -> BacktestRoutingModel:
return generate_simple_routing_model(synthetic_universe)


@pytest.fixture()
def pricing_model(synthetic_universe, routing_model) -> BacktestPricing:
pricing_model = BacktestPricing(
synthetic_universe.data_universe.candles,
routing_model,
allow_missing_fees=True,
)
return pricing_model


def create_indicators(timestamp: datetime.datetime, parameters: StrategyParameters, strategy_universe: TradingStrategyUniverse, execution_context: ExecutionContext):
# No indicators needed
return IndicatorSet()


def decide_trades(input: StrategyInput) -> list[TradeExecution]:
"""Example of storing and loading custom variables."""

cycle = input.cycle
state = input.state

# Saving values by cycle
state.other_data.save(cycle, "my_value", 1)
state.other_data.save(cycle, "my_value_2", [1, 2])
state.other_data.save(cycle, "my_value_3", {1: 2})

if cycle >= 2:
# Loading latest values
assert state.other_data.load_latest("my_value") == 1
assert state.other_data.load_latest("my_value_2") == [1, 2]
assert state.other_data.load_latest("my_value_3") == {1: 2}

return []


def test_other_data(strategy_universe, tmp_path):
"""Test state.other_data."""

# Start with $1M cash, far exceeding the market size
class Parameters:
backtest_start = strategy_universe.data_universe.candles.get_timestamp_range()[0].to_pydatetime()
backtest_end = strategy_universe.data_universe.candles.get_timestamp_range()[1].to_pydatetime()
initial_cash = 1_000_000
cycle_duration = CycleDuration.cycle_1d

# Run the test
result = run_backtest_inline(
client=None,
decide_trades=decide_trades,
create_indicators=create_indicators,
universe=strategy_universe,
reserve_currency=ReserveCurrency.usdc,
engine_version="0.5",
parameters=StrategyParameters.from_class(Parameters),
mode=ExecutionMode.unit_testing,
)

# Variables are readable after the backtest
state = result.state
assert len(state.other_data.data.keys()) == 29 # We stored data for 29 decide_trades cycles
assert state.other_data.data[1]["my_value"] == 1 # We can read historic values

102 changes: 102 additions & 0 deletions tradeexecutor/state/other_data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
"""Storing of custom variables in the backtesting state."""
from collections import defaultdict
from dataclasses import dataclass, field
from typing import Any, TypeAlias

from dataclasses_json import dataclass_json

JsonSerialisableObject: TypeAlias = Any


def _dict_dict():
return defaultdict(dict)


@dataclass_json
@dataclass(slots=True)
class OtherData:
"""Store custom variables in the backtesting state.
- For each cycle, you can record custom variables here
- All historical cycle values are stored
- All values must be JSON serialisable.
- You can then read the variables back
- This can be used in live trade execution as well **with care**.
Because of the underlying infrastructure may crash (blockchain halt, server crash)
cycles might be skipped.
Example of storing and loading custom variables:
.. code-block:: python
def decide_trades(input: StrategyInput) -> list[TradeExecution]:
cycle = input.cycle
state = input.state
# Saving values by cycle
state.other_data.save(cycle, "my_value", 1)
state.other_data.save(cycle, "my_value_2", [1, 2])
state.other_data.save(cycle, "my_value_3", {1: 2})
if cycle >= 2:
# Loading latest values
assert state.other_data.load_latest("my_value") == 1
assert state.other_data.load_latest("my_value_2") == [1, 2]
assert state.other_data.load_latest("my_value_3") == {1: 2}
return []
You can also read these variables after the backtest is complete:
.. code-block::
result = run_backtest_inline(
client=None,
decide_trades=decide_trades,
create_indicators=create_indicators,
universe=strategy_universe,
reserve_currency=ReserveCurrency.usdc,
engine_version="0.5",
parameters=StrategyParameters.from_class(Parameters),
mode=ExecutionMode.unit_testing,
)
# Variables are readable after the backtest
state = result.state
assert len(state.other_data.data.keys()) == 29 # We stored data for 29 decide_trades cycles
assert state.other_data.data[1]["my_value"] == 1 # We can read historic values
"""

#: Cycle number -> dict mapping
data: dict[int, JsonSerialisableObject] = field(default_factory=_dict_dict)

def get_latest_stored_cycle(self) -> int:
"""Get the cycle for which we have recorded any data.
:return:
0 if no data
"""
if len(self.data) == 0:
return 0
return max(self.data.keys())

def save(self, cycle: int, name: str, value: JsonSerialisableObject):
"""Save the value on this cycle."""
assert type(cycle) == int, f"Got {cycle}"
assert type(name) == str, f"Got {name}"
self.data[cycle][name] = value

def load_latest(self, name: str) -> JsonSerialisableObject | None:
"""Load the latest named value from the store.
- Take the value whatever is the last cycle
:return:
If the last cycle did not store this var, then return `None`.
"""
latest_cycle = self.get_latest_stored_cycle()
return self.data.get(latest_cycle).get(name, None)
4 changes: 4 additions & 0 deletions tradeexecutor/state/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from dataclasses_json import dataclass_json
from dataclasses_json.core import _ExtendedEncoder

from .other_data import OtherData
from .sync import Sync
from .identifier import AssetIdentifier, TradingPairIdentifier, TradingPairKind
from .portfolio import Portfolio
Expand Down Expand Up @@ -183,6 +184,9 @@ class State:
#: not live trading.
backtest_data: BacktestData | None = None

#: Misc. backtesting variables settable by users
other_data: Optional[OtherData] = field(default_factory=OtherData)

def __repr__(self):
return f"<State for {self.name}>"

Expand Down

0 comments on commit cf512cb

Please sign in to comment.