From bc6003877258b7d6b320cd9d19c73ed5ece888d3 Mon Sep 17 00:00:00 2001 From: Peter Dekkers Date: Fri, 23 Feb 2024 19:54:20 +0100 Subject: [PATCH] removed one more dependency --- pyproject.toml | 1 - roboquant/__init__.py | 2 +- roboquant/account.py | 45 +++-- roboquant/order.py | 25 ++- roboquant/roboquant.py | 2 +- roboquant/timeframe.py | 2 +- roboquant/trackers/__init__.py | 6 +- .../trackers/{capmtracker.py => alphabeta.py} | 14 +- roboquant/trackers/basictracker.py | 64 ++----- roboquant/trackers/equitytracker.py | 55 +++++- roboquant/trackers/markettracker.py | 52 ++++++ roboquant/trackers/standardtracker.py | 174 ------------------ roboquant/trackers/tensorboardtracker.py | 2 +- roboquant/trackers/tracker.py | 6 +- tests/data/output/account_repr.txt | 34 ++-- tests/samples/backtest.py | 9 + tests/samples/talib_strategy.py | 8 +- tests/samples/walkforward.py | 14 +- tests/unit/test_account.py | 17 +- ...apmtracker.py => test_alphabetatracker.py} | 4 +- tests/unit/test_equitytracker.py | 2 +- tests/unit/test_standardtracker.py | 19 -- 22 files changed, 210 insertions(+), 347 deletions(-) rename roboquant/trackers/{capmtracker.py => alphabeta.py} (82%) create mode 100644 roboquant/trackers/markettracker.py delete mode 100644 roboquant/trackers/standardtracker.py create mode 100644 tests/samples/backtest.py rename tests/unit/{test_capmtracker.py => test_alphabetatracker.py} (84%) delete mode 100644 tests/unit/test_standardtracker.py diff --git a/pyproject.toml b/pyproject.toml index 858553e..cf7d7e6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -32,7 +32,6 @@ classifiers = [ keywords = ["trading", "investment", "finance", "crypto", "stocks", "exchange", "forex"] dependencies = [ "numpy>=1.25.2", - "prettytable~=3.9.0", "websocket-client~=1.7.0", "requests>=2.31.0", ] diff --git a/roboquant/__init__.py b/roboquant/__init__.py index 8def650..c3be348 100644 --- a/roboquant/__init__.py +++ b/roboquant/__init__.py @@ -11,7 +11,7 @@ from roboquant.brokers import Broker, SimBroker from roboquant.traders import Trader, FlexTrader -from roboquant.trackers import Tracker, StandardTracker, BasicTracker, CAPMTracker, EquityTracker, TensorboardTracker +from roboquant.trackers import Tracker, BasicTracker, AlphaBetaTracker, EquityTracker, TensorboardTracker, MarketTracker from roboquant.strategies import ( Strategy, EMACrossover, diff --git a/roboquant/account.py b/roboquant/account.py index 6c3516a..2ddc8f2 100644 --- a/roboquant/account.py +++ b/roboquant/account.py @@ -2,7 +2,6 @@ from datetime import datetime from decimal import Decimal from roboquant.order import Order -from prettytable import PrettyTable @dataclass(slots=True, frozen=True) @@ -28,12 +27,18 @@ class Account: Only the broker updates the state of the account and does this only during its `sync` method. """ + buying_power: float + positions: dict[str, Position] + orders: list[Order] + last_update: datetime + equity: float + def __init__(self): self.buying_power: float = 0.0 self.positions: dict[str, Position] = {} self.orders: list[Order] = [] self.last_update: datetime = datetime.fromisoformat("1900-01-01T00:00:00+00:00") - self.equity = 0.0 + self.equity: float = 0.0 def contract_value(self, symbol: str, size: Decimal, price: float) -> float: """Return the total value of the provided contract size denoted in the base currency of the account. @@ -50,6 +55,7 @@ def mkt_value(self, prices: dict[str, float]) -> float: return sum([self.contract_value(symbol, pos.size, prices[symbol]) for symbol, pos in self.positions.items()], 0.0) def unrealized_pnl(self, prices: dict[str, float]) -> float: + """Return the unrealized profit and loss for the open position given the provided market prices""" return sum( [self.contract_value(symbol, pos.size, prices[symbol] - pos.avg_price) for symbol, pos in self.positions.items()], 0.0, @@ -64,6 +70,7 @@ def has_open_order(self, symbol: str) -> bool: return False def get_position_size(self, symbol) -> Decimal: + """Return the position size for the symbol""" pos = self.positions.get(symbol) return pos.size if pos else Decimal(0) @@ -72,27 +79,19 @@ def open_orders(self): return [order for order in self.orders if not order.closed] def __repr__(self) -> str: - p = PrettyTable(["account", "value"], align="r", float_format="12.2") - p.add_row(["buying power", self.buying_power]) - p.add_row(["equity", self.equity]) - p.add_row(["positions", len(self.positions)]) - p.add_row(["orders", len(self.orders)]) - p.add_row(["last update", self.last_update.strftime("%Y-%m-%d %H:%M:%S")]) - result = p.get_string() + "\n\n" - - if self.positions: - p = PrettyTable(["symbol", "position size", "avg price"], align="r", float_format="12.2") - for symbol, pos in self.positions.items(): - p.add_row([symbol, pos.size, pos.avg_price]) - result += p.get_string() + "\n\n" - - if self.orders: - p = PrettyTable(["symbol", "order size", "order id", "limit", "status", "closed"], align="r", float_format="12.2") - for order in self.orders: - p.add_row([order.symbol, order.size, order.id, order.limit, order.status.name, order.closed]) - result += p.get_string() + "\n" - - return result + p = [f"{v.size}@{k}" for k, v in self.positions.items()] + p_str = ", ".join(p) + + o = [f"{o.size}@{o.symbol}" for o in self.open_orders()] + o_str = ", ".join(o) + + return f""" + buying power : {self.buying_power:_.2f} + equity : {self.equity:_.2f} + positions : {p_str} + open orders : {o_str} + last update : {self.last_update} + """ class OptionAccount(Account): diff --git a/roboquant/order.py b/roboquant/order.py index 4ea713f..7591d1f 100644 --- a/roboquant/order.py +++ b/roboquant/order.py @@ -1,3 +1,4 @@ +from dataclasses import dataclass from decimal import Decimal from enum import Flag, auto from copy import copy @@ -31,13 +32,25 @@ def closed(self): """Return True is the status is closed, False otherwise""" return self in OrderStatus._CLOSE + def __repr__(self) -> str: + return self.name + +@dataclass class Order: """ A trading order. Default is a market order when only the size th specified. But optional a limit can be specified, - making it a limit order. The id is automatically assigned by the broker and should not be set manually. + making it a limit order. + + The id is automatically assigned by the broker and should not be set manually. """ + symbol: str + size: Decimal + limit: float | None + id: str | None + status: OrderStatus + def __init__(self, symbol: str, size: Decimal | str | int | float, limit: float | None = None): self.symbol = symbol self.size = Decimal(size) @@ -58,7 +71,7 @@ def closed(self) -> bool: return self.status.closed def cancel(self) -> "Order": - """Cancel this order. You can only cancel orders that are still open and have an id. + """Create a cancellation order. You can only cancel orders that are still open and have an id. The returned order looks like a regular order, but with its size set to zero. """ assert self.id is not None, "Can only cancel orders with an id" @@ -69,8 +82,10 @@ def cancel(self) -> "Order": return result def update(self, size: Decimal | str | int | float | None = None, limit: float | None = None) -> "Order": - """Update this order. You can only update orders that are still open and have an id. - You can update the size and/or limit of an order. The id of the order stays the same as the original order. + """Create an update-order. You can update the size and/or limit of an order. The returned order has the same id + as the original order. + + You can only update existing orders that are still open and have an id. """ assert self.id is not None, "Can only update orders with an id" @@ -100,8 +115,10 @@ def is_cancellation(self): @property def is_buy(self): + """Return True if this is a BUY order, False otherwise""" return self.size > 0 @property def is_sell(self): + """Return True if this is a SELL order, False otherwise""" return self.size < 0 diff --git a/roboquant/roboquant.py b/roboquant/roboquant.py index ef73e55..da467a3 100644 --- a/roboquant/roboquant.py +++ b/roboquant/roboquant.py @@ -66,6 +66,6 @@ def run( orders = self.trader.create_orders(signals, event, account) self.broker.place_orders(*orders) if tracker: - tracker.log(event, account, signals, orders) + tracker.trace(event, account, signals, orders) return self.broker.sync() diff --git a/roboquant/timeframe.py b/roboquant/timeframe.py index e2e38e7..b969303 100644 --- a/roboquant/timeframe.py +++ b/roboquant/timeframe.py @@ -103,7 +103,7 @@ def split(self, n: int | timedelta) -> list["Timeframe"]: period = n if isinstance(n, timedelta) else self.duration / n end = self.start result = [] - while end in self: + while end < self.end: start = end end = start + period result.append(Timeframe(start, end, False)) diff --git a/roboquant/trackers/__init__.py b/roboquant/trackers/__init__.py index 0210649..368ffdb 100644 --- a/roboquant/trackers/__init__.py +++ b/roboquant/trackers/__init__.py @@ -1,6 +1,6 @@ -from .tracker import Tracker -from roboquant.trackers.standardtracker import StandardTracker -from roboquant.trackers.capmtracker import CAPMTracker +from roboquant.trackers.tracker import Tracker +from roboquant.trackers.alphabeta import AlphaBetaTracker from roboquant.trackers.basictracker import BasicTracker from roboquant.trackers.equitytracker import EquityTracker from roboquant.trackers.tensorboardtracker import TensorboardTracker +from roboquant.trackers.markettracker import MarketTracker diff --git a/roboquant/trackers/capmtracker.py b/roboquant/trackers/alphabeta.py similarity index 82% rename from roboquant/trackers/capmtracker.py rename to roboquant/trackers/alphabeta.py index 72eae74..ae22cfa 100644 --- a/roboquant/trackers/capmtracker.py +++ b/roboquant/trackers/alphabeta.py @@ -1,13 +1,14 @@ from typing import Tuple import numpy as np -from prettytable import PrettyTable from roboquant.event import Event from roboquant.timeframe import Timeframe from .tracker import Tracker -class CAPMTracker(Tracker): +class AlphaBetaTracker(Tracker): + """Tracks the Alpha and Beta""" + def __init__(self, price_type="DEFAULT"): self.mkt_returns = [] self.acc_returns = [] @@ -27,7 +28,7 @@ def _get_market_returns(self, prices: dict[str, float]): result += prices[symbol] / self.last_prices[symbol] - 1.0 return result / cnt - def log(self, event: Event, account, signals, orders): + def trace(self, event: Event, account, signals, orders): prices = {item.symbol: item.price(self.price_type) for item in event.price_items.values()} equity = account.equity if self.init: @@ -41,13 +42,6 @@ def log(self, event: Event, account, signals, orders): self.last_equity = equity self.init = True - def __repr__(self) -> str: - alpha, beta = self.alpha_beta() - table = PrettyTable(["Metric", "Value"], float_format=".2", align="r") - table.add_row(["alpha %", alpha * 100]) - table.add_row(["beta", beta]) - return table.get_string() - def alpha_beta(self, risk_free_return=0.0) -> Tuple[float, float]: if not self.start_time or not self.end_time: return float("nan"), float("nan") diff --git a/roboquant/trackers/basictracker.py b/roboquant/trackers/basictracker.py index a855cc3..4744229 100644 --- a/roboquant/trackers/basictracker.py +++ b/roboquant/trackers/basictracker.py @@ -1,70 +1,44 @@ +from dataclasses import dataclass from datetime import datetime import logging -from .tracker import Tracker + +from roboquant.trackers.tracker import Tracker from roboquant.account import Account from roboquant.event import Event from roboquant.order import Order from roboquant.signal import Signal -from prettytable import PrettyTable logger = logging.getLogger(__name__) +@dataclass class BasicTracker(Tracker): """Tracks a number of basic metrics: - - start- and end-time - - total number of events, items, signals and orders - - equity + - last time + - total number of events, items, signals and orders until that time This tracker adds little overhead to a run, both CPU and memory wise. """ - - def __init__(self, price_type="DEFAULT"): - self.start_time = None - self.end_time = None + time: datetime | None + items: int + orders: int + signals: int + events: int + + def __init__(self, output=False): + self.time = None self.items = 0 self.orders = 0 self.signals = 0 self.events = 0 - self.equity = None - self.buying_power = 0.0 - - def log(self, event: Event, account: Account, signals: dict[str, Signal], orders: list[Order]): - - if self.start_time is None: - self.start_time = event.time + self.__output = output - self.end_time = event.time + def trace(self, event: Event, account: Account, signals: dict[str, Signal], orders: list[Order]): + self.time = event.time self.items += len(event.items) self.orders += len(orders) self.events += 1 self.signals += len(signals) - self.equity = account.equity - self.buying_power = account.buying_power - - if logger.isEnabledFor(logging.INFO): - logger.info( - "time=%s events=%s items=%s signals=%s orders=%s equity=%s, buying-power=%s", - self.end_time, - self.events, - self.items, - self.signals, - self.orders, - self.equity, - self.buying_power - ) - - def __repr__(self) -> str: - - def to_timefmt(time: datetime | None): - return "-" if time is None else time.strftime("%Y-%m-%d %H:%M:%S") - - p = PrettyTable(["metric", "value"], align="r", float_format=".2") - p.add_row(["start", to_timefmt(self.start_time)]) - p.add_row(["end", to_timefmt(self.end_time)]) - p.add_row(["events", self.events]) - p.add_row(["items", self.items]) - p.add_row(["signals", self.signals]) - p.add_row(["orders", self.orders]) - return p.get_string() + if self.__output: + print(self.__repr__() + "\n") diff --git a/roboquant/trackers/equitytracker.py b/roboquant/trackers/equitytracker.py index 1e3db77..dcdcafc 100644 --- a/roboquant/trackers/equitytracker.py +++ b/roboquant/trackers/equitytracker.py @@ -1,19 +1,56 @@ -from datetime import datetime -from .tracker import Tracker +from roboquant.timeframe import Timeframe +from roboquant.trackers.tracker import Tracker class EquityTracker(Tracker): """Tracks the time of an event and the equity at that moment. - If multiple events happen at the same time, only the first one will be registered. + If multiple events happen at the same time, only the equity for first one will be registered. """ def __init__(self): self.timeline = [] - self.equity = [] - self.last = datetime.fromisoformat("1900-01-01T00:00:00+00:00") + self.equities = [] + self.__last = None - def log(self, event, account, signals, orders): - if event.time > self.last: + def trace(self, event, account, signals, orders): + if self.__last is None or event.time > self.__last: self.timeline.append(event.time) - self.equity.append(account.equity) - self.last = event.time + self.equities.append(account.equity) + self.__last = event.time + + def timeframe(self): + return Timeframe(self.timeline[0], self.timeline[-1], True) + + def pnl(self, annualized=False): + """Return the profit & loss percentage, optionally annualized from the recorded durtion""" + pnl = self.equities[-1]/self.equities[0] - 1 + if annualized: + return self.timeframe().annualize(pnl) + else: + return pnl + + def max_drawdown(self): + max_equity = self.equities[0] + result = 0.0 + for equity in self.equities: + if equity > max_equity: + max_equity = equity + + dd = (equity - max_equity) / max_equity + if dd < result: + result = dd + + return result + + def max_gain(self): + min_equity = self.equities[0] + result = 0.0 + for equity in self.equities: + if equity < min_equity: + min_equity = equity + + gain = (equity - min_equity) / min_equity + if gain > result: + result = gain + + return result diff --git a/roboquant/trackers/markettracker.py b/roboquant/trackers/markettracker.py new file mode 100644 index 0000000..9a75cfd --- /dev/null +++ b/roboquant/trackers/markettracker.py @@ -0,0 +1,52 @@ +from roboquant.timeframe import Timeframe + + +class MarketTracker: + + class __Entry: + """Keeps track of the market returns of a single symbol""" + + __slots__ = "start_time", "end_time", "start_price", "end_price" + + def __init__(self, time, price): + self.start_time = time + self.start_price = price + self.end_time = time + self.end_price = price + + def weighted(self): + rate = self.end_price / self.start_price - 1.0 + return rate * self.duration + + @property + def duration(self): + return (self.end_time - self.start_time).total_seconds() + + def __init__(self): + self.market_returns = {} + self.price_type = "DEFAULT" + + def trace(self, event, account, signals, orders): + for symbol, item in event.price_items.items(): + price = item.price(self.price_type) + if mr := self.market_returns.get(symbol): + mr.end_time = event.time + mr.end_price = price + else: + self.market_returns[symbol] = self.__Entry(event.time, price) + + def timeframe(self): + start = min([v.start_time for v in self.market_returns.values()]) + end = max([v.end_time for v in self.market_returns.values()]) + return Timeframe(start, end, True) + + def get_market_return(self): + mr = [v for v in self.market_returns.values()] + total = sum(v.weighted() for v in mr) + sum_weights = sum(v.duration for v in mr) + avg_return = total / sum_weights if sum_weights != 0.0 else float("NaN") + tf = self.timeframe() + if tf: + return tf.annualize(avg_return) + else: + return 0.0 diff --git a/roboquant/trackers/standardtracker.py b/roboquant/trackers/standardtracker.py deleted file mode 100644 index 23aea81..0000000 --- a/roboquant/trackers/standardtracker.py +++ /dev/null @@ -1,174 +0,0 @@ -from datetime import datetime -from ..timeframe import Timeframe -from ..event import Event -from prettytable import PrettyTable -import math -from .tracker import Tracker - - -class _MarketReturn: - """Keeps track of the market returns of a single symbol""" - - __slots__ = "start_time", "end_time", "start_price", "end_price" - - def __init__(self, time, price): - self.start_time = time - self.start_price = price - self.end_time = time - self.end_price = price - - def weighted(self): - rate = self.end_price / self.start_price - 1.0 - return rate * self.duration - - @property - def duration(self): - return (self.end_time - self.start_time).total_seconds() - - -class _PropertyCalculator: - """Keeps track of the market returns of a single symbol""" - - __slots__ = "start_time", "end_time", "total" - - def __init__(self): - self.start_time = None - self.total = 0 - self.end_time = None - - def add(self, value: int, time: datetime): - if value != 0: - self.total += value - self.end_time = time - if self.start_time is None: - self.start_time = time - - -class _EquityCalculator: - """Tracks several equity metrics""" - - def __init__(self): - self.max_equity = -10e9 - self.min_equity = 10e9 - self.mdd = 0.0 - self.max_gain = 0.0 - self.start_equity = float("nan") - self.end_equity = float("nan") - - def add(self, equity): - if math.isnan(self.start_equity): - self.start_equity = equity - - self.end_equity = equity - - if equity > self.max_equity: - self.max_equity = equity - - if equity < self.min_equity: - self.min_equity = equity - - dd = (equity - self.max_equity) / self.max_equity - if dd < self.mdd: - self.mdd = dd - - gain = (equity - self.min_equity) / self.min_equity - if gain > self.max_gain: - self.max_gain = gain - - -class StandardTracker(Tracker): - """Tracks a number of key metrics: - - total, min and max of events, items, signals, orders and equity - - drawdown and gain - - annual performance - - market performance - """ - - def __init__(self, price_type="DEFAULT"): - self.properties = { - "event": _PropertyCalculator(), - "item": _PropertyCalculator(), - "signal": _PropertyCalculator(), - "order": _PropertyCalculator(), - } - self.market_returns: dict[str, _MarketReturn] = dict() - self.price_type = price_type - self.mddCalculator = _EquityCalculator() - self.max_positions = 0 - - def _update_market_returns(self, event: Event): - for symbol, item in event.price_items.items(): - price = item.price(self.price_type) - if mr := self.market_returns.get(symbol): - mr.end_time = event.time - mr.end_price = price - else: - self.market_returns[symbol] = _MarketReturn(event.time, price) - - def get_market_return(self): - mr = [v for v in self.market_returns.values()] - total = sum(v.weighted() for v in mr) - sum_weights = sum(v.duration for v in mr) - avg_return = total / sum_weights if sum_weights != 0.0 else float("NaN") - tf = self.timeframe() - if tf: - return tf.annualize(avg_return) - else: - return 0.0 - - def log(self, event, account, signals, orders): - t = event.time - prop = self.properties - - prop["event"].add(1, t) - prop["item"].add(len(event.items), t) - prop["signal"].add(len(signals), t) - prop["order"].add(len(orders), t) - - if (npositions := len(account.positions)) > self.max_positions: - self.max_positions = npositions - - self.mddCalculator.add(account.equity) - self._update_market_returns(event) - - def __repr__(self) -> str: - if self.properties["event"].total == 0: - return "no events observed" - - def to_timefmt(time: datetime | None): - return "-" if time is None else time.strftime("%Y-%m-%d %H:%M:%S") - - pnl = self.annualized_pnl() * 100 - mkt_pnl = self.get_market_return() * 100 - p = PrettyTable(["metric", "value"], align="r", float_format=".2") - for k, v in self.properties.items(): - p.add_row([f"total {k}s", v.total]) - p.add_row([f"first {k}", to_timefmt(v.start_time)]) - p.add_row([f"last {k}", to_timefmt(v.end_time)], divider=True) - - p.add_row(["max positions", self.max_positions], divider=True) - - p.add_row(["start equity", self.mddCalculator.start_equity]) - p.add_row(["end equity", self.mddCalculator.end_equity]) - p.add_row(["min equity", self.mddCalculator.min_equity]) - p.add_row(["max equity", self.mddCalculator.max_equity], divider=True) - - p.add_row(["max drawdown %", self.mddCalculator.mdd * 100]) - p.add_row(["max gain %", self.mddCalculator.max_gain * 100], divider=True) - - p.add_row(["annual pnl %", pnl]) - p.add_row(["annual mkt %", mkt_pnl]) - return p.get_string() - - def timeframe(self): - events = self.properties["event"] - if events.total > 0: - return Timeframe(events.start_time, events.end_time, inclusive=True) # type: ignore - - def annualized_pnl(self): - pnl = self.mddCalculator.end_equity / self.mddCalculator.start_equity - 1.0 - tf = self.timeframe() - if tf: - return tf.annualize(pnl) - else: - return 0.0 diff --git a/roboquant/trackers/tensorboardtracker.py b/roboquant/trackers/tensorboardtracker.py index f54095c..7e5120e 100644 --- a/roboquant/trackers/tensorboardtracker.py +++ b/roboquant/trackers/tensorboardtracker.py @@ -19,7 +19,7 @@ def __init__(self, summary_writer): self.writer = summary_writer self._step = 0 - def log(self, event: Event, account: Account, signals: dict[str, Signal], orders: list[Order]): + def trace(self, event: Event, account: Account, signals: dict[str, Signal], orders: list[Order]): self.items += len(event.items) self.signals += len(signals) self.orders += len(orders) diff --git a/roboquant/trackers/tracker.py b/roboquant/trackers/tracker.py index 923d10e..52df1d8 100644 --- a/roboquant/trackers/tracker.py +++ b/roboquant/trackers/tracker.py @@ -8,10 +8,10 @@ class Tracker(Protocol): """ - A tracker allows for tracking and/or logging of one or more metrics during a run. + A tracker allows for the tracking and/or logging of one or more metrics during a run. """ - def log(self, event: Event, account: Account, signals: dict[str, Signal], orders: list[Order]): + def trace(self, event: Event, account: Account, signals: dict[str, Signal], orders: list[Order]): """invoked at each step of a run that provides the tracker with the opportunity to - calculate metrics and log these.""" + track and log various metrics.""" ... diff --git a/tests/data/output/account_repr.txt b/tests/data/output/account_repr.txt index 3ce29a8..5da9423 100644 --- a/tests/data/output/account_repr.txt +++ b/tests/data/output/account_repr.txt @@ -1,24 +1,14 @@ -+--------------+---------------------+ -| account | value | -+--------------+---------------------+ -| buying power | 1000000.00 | -| equity | 1000000.00 | -| positions | 3 | -| orders | 2 | -| last update | 1900-01-01 00:00:00 | -+--------------+---------------------+ + buying power equity last update +-------------- ------------ ------------------------- + 1000000.00 1000000.00 1900-01-01 00:00:00+00:00 -+--------+---------------+--------------+ -| symbol | position size | avg price | -+--------+---------------+--------------+ -| STOCK0 | 10 | 100.00 | -| STOCK1 | 10 | 100.00 | -| STOCK2 | 10 | 100.00 | -+--------+---------------+--------------+ +symbol order size limit id status +-------- ------------ ------- ---- ------------------- +STOCK0 100 OrderStatus.INITIAL +STOCK1 100 OrderStatus.INITIAL -+--------+------------+----------+-------+---------+--------+ -| symbol | order size | order id | limit | status | closed | -+--------+------------+----------+-------+---------+--------+ -| STOCK0 | 100 | None | None | INITIAL | False | -| STOCK1 | 100 | None | None | INITIAL | False | -+--------+------------+----------+-------+---------+--------+ +symbol pos size avg price +-------- ---------- ----------- +STOCK0 10 100 +STOCK1 10 100 +STOCK2 10 100 \ No newline at end of file diff --git a/tests/samples/backtest.py b/tests/samples/backtest.py new file mode 100644 index 0000000..8b39b63 --- /dev/null +++ b/tests/samples/backtest.py @@ -0,0 +1,9 @@ +from roboquant import Roboquant, EMACrossover, YahooFeed + +if __name__ == "__main__": + """Minimal back test scenario""" + + feed = YahooFeed("JPM", "IBM", "F", start_date="2000-01-01") + rq = Roboquant(EMACrossover()) + account = rq.run(feed) + print(account) diff --git a/tests/samples/talib_strategy.py b/tests/samples/talib_strategy.py index a307f15..c2e02ea 100644 --- a/tests/samples/talib_strategy.py +++ b/tests/samples/talib_strategy.py @@ -1,20 +1,20 @@ import unittest import talib.stream as ta -from roboquant import CandleStrategy, OHLCVBuffer +from roboquant import CandleStrategy, OHLCVBuffer, BUY, SELL, Signal from tests.common import run_strategy class MyTaLibStrategy(CandleStrategy): """Example using talib to create a strategy""" - def _create_signal(self, _, ohlcv: OHLCVBuffer) -> float | None: + def _create_signal(self, _, ohlcv: OHLCVBuffer) -> Signal | None: close = ohlcv.close() ema12 = ta.EMA(close, 12) # type: ignore ema26 = ta.EMA(close, 26) # type: ignore if ema12 > ema26: - return 1.0 + return BUY if ema12 < ema26: - return -1.0 + return SELL class TestOHLCVStrategy(unittest.TestCase): diff --git a/tests/samples/walkforward.py b/tests/samples/walkforward.py index 3483f6a..7bd7250 100644 --- a/tests/samples/walkforward.py +++ b/tests/samples/walkforward.py @@ -1,16 +1,16 @@ -from roboquant import Roboquant, EMACrossover, StandardTracker, YahooFeed +from roboquant import Roboquant, EMACrossover, EquityTracker, YahooFeed + if __name__ == "__main__": feed = YahooFeed("JPM", "IBM", "F", start_date="2000-01-01") - # split the feed timeframe in 5 equal parts - timeframes = feed.timeframe().split(5) + # split the feed timeframe in 4 equal parts + timeframes = feed.timeframe().split(4) # run a back-test on each timeframe for timeframe in timeframes: rq = Roboquant(EMACrossover(13, 26)) - tracker = StandardTracker() + tracker = EquityTracker() rq.run(feed, tracker, timeframe) - pnl = tracker.annualized_pnl() * 100 - mkt = tracker.get_market_return() * 100 - print(f"{timeframe} portfolio-pnl = {pnl:5.2f}% mkt-pnl = {mkt:5.2f}%") + pnl = tracker.pnl(annualized=True) * 100 + print(f"{timeframe} portfolio-pnl = {pnl:5.2f}%") diff --git a/tests/unit/test_account.py b/tests/unit/test_account.py index 570c75d..845e4ad 100644 --- a/tests/unit/test_account.py +++ b/tests/unit/test_account.py @@ -1,8 +1,7 @@ import unittest from decimal import Decimal -from roboquant import Account, Position, OptionAccount, Order -from tests.common import get_output +from roboquant import Account, Position, OptionAccount class TestAccount(unittest.TestCase): @@ -34,20 +33,6 @@ def test_account_option(self): self.assertEqual(20000.0, acc.contract_value("AAPL 131101C00470000", Decimal(1), 200.0)) self.assertEqual(2000.0, acc.contract_value("AAPL7 131101C00470000", Decimal(1), 200.0)) - def test_account_repr(self): - acc = Account() - acc.buying_power = 1_000_000.0 - acc.equity = 1_000_000.0 - - for i in range(3): - acc.positions[f"STOCK{i}"] = Position(Decimal(10), 100.0) - - for i in range(2): - acc.orders.append(Order(f"STOCK{i}", 100)) - - self.maxDiff = None - self.assertEqual(acc.__repr__(), get_output("account_repr.txt")) - if __name__ == "__main__": unittest.main() diff --git a/tests/unit/test_capmtracker.py b/tests/unit/test_alphabetatracker.py similarity index 84% rename from tests/unit/test_capmtracker.py rename to tests/unit/test_alphabetatracker.py index 69c7630..056bd36 100644 --- a/tests/unit/test_capmtracker.py +++ b/tests/unit/test_alphabetatracker.py @@ -1,7 +1,7 @@ import unittest from roboquant import Roboquant from roboquant.strategies import EMACrossover -from roboquant.trackers import CAPMTracker +from roboquant.trackers import AlphaBetaTracker from tests.common import get_feed @@ -10,7 +10,7 @@ class TestCAPMTracker(unittest.TestCase): def test_capmtracker(self): rq = Roboquant(EMACrossover()) feed = get_feed() - tracker = CAPMTracker() + tracker = AlphaBetaTracker() rq.run(feed, tracker=tracker) alpha, beta = tracker.alpha_beta() self.assertGreater(alpha, -1) diff --git a/tests/unit/test_equitytracker.py b/tests/unit/test_equitytracker.py index 025cad7..2c2d278 100644 --- a/tests/unit/test_equitytracker.py +++ b/tests/unit/test_equitytracker.py @@ -11,7 +11,7 @@ def test_capmtracker(self): tracker = EquityTracker() rq.run(feed, tracker=tracker) - timeline, equity = tracker.timeline, tracker.equity + timeline, equity = tracker.timeline, tracker.equities self.assertEqual(len(timeline), len(equity)) self.assertEqual(feed.timeframe().start, timeline[0]) diff --git a/tests/unit/test_standardtracker.py b/tests/unit/test_standardtracker.py deleted file mode 100644 index d3f9757..0000000 --- a/tests/unit/test_standardtracker.py +++ /dev/null @@ -1,19 +0,0 @@ -import unittest -from roboquant import Roboquant, EMACrossover, StandardTracker -from tests.common import get_feed, get_output - - -class StandardTrackerTest(unittest.TestCase): - - def test_standardtracker(self): - rq = Roboquant(EMACrossover()) - feed = get_feed() - tracker = StandardTracker() - rq.run(feed, tracker=tracker) - - self.maxDiff = None - self.assertEqual(tracker.__repr__(), get_output("standardtracker_repr.txt")) - - -if __name__ == "__main__": - unittest.main()