Skip to content

Commit

Permalink
added tests
Browse files Browse the repository at this point in the history
  • Loading branch information
jbaron committed Feb 26, 2024
1 parent ce94af7 commit d2dfd00
Show file tree
Hide file tree
Showing 13 changed files with 176 additions and 108 deletions.
22 changes: 18 additions & 4 deletions roboquant/account.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,13 +52,27 @@ def contract_value(self, symbol: str, size: Decimal, price: float) -> float:
return float(size) * price

def mkt_value(self, prices: dict[str, float]) -> float:
"""Return the market value of all the open positions in the account using the provided prices."""
return sum([self.contract_value(symbol, pos.size, prices[symbol]) for symbol, pos in self.positions.items()], 0.0)
"""Return the market value of all the open positions in the account using the provided prices.
If there is no known price provided for a position, the average price paid will be used instead
"""
return sum(
[
self.contract_value(symbol, pos.size, prices[symbol] if symbol in prices else pos.avg_price)
for symbol, pos in self.positions.items()
],
0.0,
)

def unrealized_pnl(self, prices: dict[str, float]) -> float:
"""Return the unrealized profit and loss for the open position given the provided market prices"""
"""Return the unrealized profit and loss for the open position given the provided market prices
Positions that don't have a known price, will be ignored.
"""
return sum(
[self.contract_value(symbol, pos.size, prices[symbol] - pos.avg_price) for symbol, pos in self.positions.items()],
[
self.contract_value(symbol, pos.size, prices[symbol] - pos.avg_price)
for symbol, pos in self.positions.items()
if symbol in prices
],
0.0,
)

Expand Down
4 changes: 4 additions & 0 deletions roboquant/event.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,10 @@ def price_items(self) -> dict[str, PriceItem]:
"""
return {item.symbol: item for item in self.items if isinstance(item, PriceItem)}

def get_prices(self, price_type: str = "DEFAULT") -> dict[str, float]:
"""Return all the prices of a certain price_type"""
return {k: v.price(price_type) for k, v in self.price_items.items()}

def get_price(self, symbol: str, price_type: str = "DEFAULT") -> float | None:
"""Return the price for the symbol, or None if not found."""

Expand Down
2 changes: 1 addition & 1 deletion roboquant/journals/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from roboquant.journals.basicjournal import BasicJournal
from roboquant.journals.tensorboardjournal import TensorboardJournal
from roboquant.journals.runmetric import RunMetric
from roboquant.journals.equitymetric import EquityMetric
from roboquant.journals.pnlmetric import PNLMetric
from roboquant.journals.metric import Metric
from roboquant.journals.feedmetric import FeedMetric
from roboquant.journals.pricemetric import PriceItemMetric
Original file line number Diff line number Diff line change
@@ -1,31 +1,51 @@
from roboquant.journals.metric import Metric


class EquityMetric(Metric):
"""Calculates the following equity related metrics:
- `value` equity value itself
class PNLMetric(Metric):
"""Calculates the following PNL related metrics:
- `equity` value
- `mdd` max drawdown
- `pnl` since the previous step in the run
- `total_pnl` since the beginning of the run
- `new` pnl since the previous step in the run
- `unrealized` pnl in the open positions
- `realized` pnl
- `total` pnl
"""

def __init__(self):
self.max_drawdown = 0.0
self.max_gain = 0.0
self.first_equity = None
self.prev_equity = None
self.max_equity = -10e10
self.min_equity = 10e10
self._prices = {}

def calc(self, event, account, signals, orders) -> dict[str, float]:
equity = account.equity

total, realized, unrealized = self.__get_pnls(equity, event, account)

return {
"equity/value": equity,
"equity/max_drawdown": self.__get_max_drawdown(equity),
"equity/max_gain": self.__get_max_gain(equity),
"equity/pnl": self.__get_pnl(equity),
"pnl/equity": equity,
"pnl/max_drawdown": self.__get_max_drawdown(equity),
"pnl/max_gain": self.__get_max_gain(equity),
"pnl/new": self.__get_new_pnl(equity),
"pnl/total": total,
"pnl/realized": realized,
"pnl/unrealized": unrealized,
}

def __get_pnl(self, equity):
def __get_pnls(self, equity, event, account):
if self.first_equity is None:
self.first_equity = equity

self._prices.update(event.get_prices())
unrealized = account.unrealized_pnl(self._prices)
total = equity - self.first_equity
realized = total - unrealized
return total, realized, unrealized

def __get_new_pnl(self, equity):
if self.prev_equity is None:
self.prev_equity = equity

Expand Down
Empty file added tests/performance/__init__.py
Empty file.
30 changes: 0 additions & 30 deletions tests/performance/bigcsvfeed_test.py

This file was deleted.

18 changes: 0 additions & 18 deletions tests/performance/profiling_test.py

This file was deleted.

42 changes: 42 additions & 0 deletions tests/performance/test_bigfeed.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
import os
import time
import unittest
import roboquant as rq


class TestBigFeed(unittest.TestCase):

def test_bigfeed(self):
start = time.time()
path = os.path.expanduser("~/data/nyse_stocks/")
feed = rq.feeds.CSVFeed.stooq_us_daily(path)
loadtime = time.time() - start
strategy = rq.strategies.EMACrossover(13, 26)
journal = rq.journals.BasicJournal()
start = time.time()
account = rq.run(feed, strategy, journal=journal)
runtime = time.time() - start

self.assertTrue(journal.items > 1_000_000)
self.assertTrue(journal.signals > 100_000)
self.assertTrue(journal.orders > 10_000)
self.assertTrue(journal.events > 10_000)

print(account)
print(journal)

# Print statistics
print()
print(f"load time = {loadtime:.1f}s")
print("files =", len(feed.symbols))
print(f"throughput = {len(feed.symbols) / loadtime:.0f} files/s")
print(f"run time = {runtime:.1f}s")
candles = journal.items
print(f"candles = {(candles / 1_000_000):.1f}M")
throughput = candles / (runtime * 1_000_000)
print(f"throughput = {throughput:.1f}M candles/s")
print()


if __name__ == "__main__":
unittest.main()
25 changes: 25 additions & 0 deletions tests/performance/test_profiling.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
from cProfile import Profile
import os
from pstats import Stats, SortKey
import roboquant as rq
import unittest


class TestProfile(unittest.TestCase):

def test_profile(self):
path = os.path.expanduser("~/data/nasdaq_stocks/1")
feed = rq.feeds.CSVFeed.stooq_us_daily(path)
print("timeframe =", feed.timeframe(), " symbols =", len(feed.symbols))
strategy = rq.strategies.EMACrossover(13, 26)
journal = rq.journals.BasicJournal()

# Profile the run to detect bottlenecks
with Profile() as profile:
rq.run(feed, strategy, journal=journal)
print(f"\n{journal}")
Stats(profile).sort_stats(SortKey.TIME).print_stats()


if __name__ == "__main__":
unittest.main()
52 changes: 52 additions & 0 deletions tests/performance/test_tiingodelay.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
import logging
import time

from roboquant import Timeframe
from roboquant.feeds import EventChannel, TiingoLiveFeed, feedutil
from statistics import mean, stdev
import unittest


class TestTiingoDelay(unittest.TestCase):

def test_tiingodelay(self):
"""
Measure the average delay receiving prices from IEX using Tiingo.
This includes the following paths:
- From IEX to Tiingo (New York)
- Tiingo holds it for 15ms (requirement from IEX)
- From Tiingo to the modem/access-point in your house
- From the access-point to your computer (f.e lan or Wi-Fi)
"""

logging.basicConfig(level=logging.INFO)

feed = TiingoLiveFeed(market="iex")

# subscribe to all IEX stocks for TOP of order book changes and Trades.
feed.subscribe(threshold_level=5)

timeframe = Timeframe.next(minutes=1)
channel = EventChannel(timeframe, maxsize=10_000)
feedutil.play_background(feed, channel)

delays = []
while event := channel.get():
if event.items:
delays.append(time.time() - event.time.timestamp())

if delays:
t = (
f"mean={mean(delays):.3f} stdev={stdev(delays):.3f}"
+ f"max={max(delays):.3f} min={min(delays):.3f} n={len(delays)}"
)
print(t)
else:
print("didn't receive any actions, is it perhaps outside trading hours?")

feed.close()


if __name__ == "__main__":
unittest.main()
41 changes: 0 additions & 41 deletions tests/performance/tiingodelay_test.py

This file was deleted.

4 changes: 2 additions & 2 deletions tests/samples/tensorboard_metrics.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import roboquant as rq
from roboquant.journals import TensorboardJournal, EquityMetric, RunMetric, FeedMetric, PriceItemMetric, AlphaBeta
from roboquant.journals import TensorboardJournal, PNLMetric, RunMetric, FeedMetric, PriceItemMetric, AlphaBeta
from tensorboard.summary import Writer

if __name__ == "__main__":
Expand All @@ -13,6 +13,6 @@
s = rq.strategies.EMACrossover(p1, p2)
log_dir = f"""runs/ema_{p1}_{p2}"""
writer = Writer(log_dir)
journal = TensorboardJournal(writer, EquityMetric(), RunMetric(), FeedMetric(), PriceItemMetric("JPM"), AlphaBeta(200))
journal = TensorboardJournal(writer, PNLMetric(), RunMetric(), FeedMetric(), PriceItemMetric("JPM"), AlphaBeta(200))
account = rq.run(feed, s, journal=journal)
writer.close()
4 changes: 2 additions & 2 deletions tests/unit/test_tensorboard.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import tempfile
import unittest
import roboquant as rq
from roboquant.journals import RunMetric, EquityMetric, TensorboardJournal
from roboquant.journals import RunMetric, PNLMetric, TensorboardJournal
from tests.common import get_feed
from tensorboard.summary import Writer

Expand All @@ -16,7 +16,7 @@ def test_tensorboard_journal(self):

output = Path(tmpdir).joinpath("runs")
writer = Writer(str(output))
journal = TensorboardJournal(writer, RunMetric(), EquityMetric())
journal = TensorboardJournal(writer, RunMetric(), PNLMetric())
rq.run(feed, rq.strategies.EMACrossover(), journal=journal)
writer.close()

Expand Down

0 comments on commit d2dfd00

Please sign in to comment.