Skip to content

Commit

Permalink
added tests
Browse files Browse the repository at this point in the history
  • Loading branch information
jbaron committed Jul 21, 2024
1 parent 55d5d90 commit cce155a
Show file tree
Hide file tree
Showing 8 changed files with 89 additions and 16 deletions.
2 changes: 1 addition & 1 deletion roboquant/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
__version__ = "0.6.3"
__version__ = "0.6.4"

from roboquant import brokers
from roboquant import feeds
Expand Down
4 changes: 2 additions & 2 deletions roboquant/asset.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from abc import ABC
from dataclasses import dataclass
from decimal import Decimal
from typing import ClassVar
from typing import ClassVar, Type

from roboquant.wallet import Amount

Expand Down Expand Up @@ -77,7 +77,7 @@ def contract_value(self, size: Decimal, price: float) -> float:
return float(size) * price * self.multiplier


def __default_deserializer(clazz):
def __default_deserializer(clazz: Type[Asset]):

__cache: dict[str, Asset] = {}

Expand Down
1 change: 1 addition & 0 deletions roboquant/feeds/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from .historic import HistoricFeed
from .randomwalk import RandomWalk
from .sqllitefeed import SQLFeed
from .parquetfeed import ParquetFeed

try:
from .yahoo import YahooFeed
Expand Down
31 changes: 29 additions & 2 deletions roboquant/feeds/parquetfeed.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import logging
import os.path
from array import array
from typing import Any
from typing import Any, Iterable

import pyarrow as pa
import pyarrow.parquet as pq
Expand Down Expand Up @@ -36,10 +36,16 @@ def exists(self):
return os.path.exists(self.parquet_path)

def play(self, channel: EventChannel):
# pylint: disable=too-many-locals
dataset = pq.ParquetFile(self.parquet_path)
last_time: Any = None
items = []
for batch in dataset.iter_batches():

row_group_indexes = self.__get_row_group_indexes(channel.timeframe)

# for batch in dataset.iter_batches():
for idx in row_group_indexes:
batch = dataset.read_row_group(idx)
times = batch.column("time")
assets = batch.column("asset")
prices = batch.column("prices")
Expand Down Expand Up @@ -71,6 +77,18 @@ def play(self, channel: EventChannel):
event = Event(now, items)
channel.put(event)

def __get_row_group_indexes(self, timeframe: Timeframe | None) -> Iterable[int]:
md = pq.read_metadata(self.parquet_path)
time = timeframe.start if timeframe else None

if not time:
return range(0, md.num_row_groups)
for idx in range(md.num_row_groups):
stat = md.row_group(idx).column(0).statistics
if stat.max >= time:
return range(idx, md.num_row_groups)
return []

def timeframe(self) -> Timeframe:
d = pq.read_metadata(self.parquet_path).to_dict()
if d["row_groups"]:
Expand All @@ -80,6 +98,15 @@ def timeframe(self) -> Timeframe:
return tf
return EMPTY_TIMEFRAME

def assets(self) -> list[Asset]:
if not self.exists():
return []

result_table = pq.read_table(self.parquet_path, columns=["asset"], schema=ParquetFeed.__schema)
assets_list = result_table["asset"].to_pylist()
assets_set = set(assets_list)
return list({Asset.deserialize(s) for s in assets_set})

def meta(self):
return pq.read_metadata(self.parquet_path)

Expand Down
20 changes: 13 additions & 7 deletions roboquant/strategies/basestrategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,16 +15,15 @@

class BaseStrategy(Strategy):
# pylint: disable=too-many-instance-attributes
"""Base version of strategy that contains several methods to make it easier to create orders.
"""
"""Base version of strategy that contains several methods to make it easier to manage orders."""

def __init__(self) -> None:
super().__init__()
self.order_value_perc = 0.1
self.buy_price = "DEFAULT"
self.sell_price = "DEFAULT"
self.fractional_order_digits = 0
self.cancel_existing_orders = True
self.cancel_orders_older_than = timedelta(days=30)

self._orders: list[Order]
self._order_value: float
Expand All @@ -39,6 +38,7 @@ def create_orders(self, event: Event, account: Account) -> list[Order]:
self._event = event
self._order_value = round(account.equity_value() * self.order_value_perc, 2)

self.cancel_old_orders()
self.process(event, account)
return self._orders

Expand All @@ -53,20 +53,26 @@ def add_buy_order(self, asset: Asset, limit: float | None = None):
if size := self._get_size(asset, limit):
order = Order(asset, size, limit)
return self.add_order(order)

logger.info("rejected buy order asset %s", asset)
return False

def add_exit_order(self, asset: Asset, limit: float | None = None):
if size := -self._account.get_position_size(asset):
if limit := limit or self._get_limit(asset, size > 0):
order = Order(asset, size, limit)
return self.add_order(order)

logger.info("rejected exit order asset %s", asset)
return False

def add_sell_order(self, asset: Asset, limit: float | None = None):
if limit := limit or self._get_limit(asset, False):
if size := self._get_size(asset, limit) * -1:
order = Order(asset, size, limit)
return self.add_order(order)

logger.info("rejected sell order asset %s", asset)
return False

def _get_limit(self, asset: Asset, is_buy: bool) -> float | None:
Expand All @@ -91,14 +97,14 @@ def add_order(self, order: Order) -> bool:
return False

self._remaining_buying_power -= bp
if self.cancel_existing_orders:
self.cancel_open_orders(order.asset)
self._orders.append(order)
return True

def cancel_old_orders(self, older_than: timedelta):
def cancel_old_orders(self):
for order in self._account.orders:
if order.created_at + older_than < self._event.time:
if not order.created_at:
continue
if order.created_at + self.cancel_orders_older_than < self._event.time:
self.cancel_order(order)

def cancel_open_orders(self, *assets):
Expand Down
2 changes: 2 additions & 0 deletions roboquant/strategies/emacrossover.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from datetime import timedelta
from roboquant.account import Account
from roboquant.asset import Asset
from roboquant.event import Event
Expand All @@ -14,6 +15,7 @@ def __init__(self, fast_period=13, slow_period=26, smoothing=2.0, price_type="DE
self.slow = 1.0 - (smoothing / (slow_period + 1))
self.price_type = price_type
self.min_steps = max(fast_period, slow_period)
self.cancel_orders_older_than = timedelta(days=5)

def process(self, event: Event, account: Account):
for asset, price in event.get_prices(self.price_type).items():
Expand Down
13 changes: 9 additions & 4 deletions tests/unit/test_order.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,22 @@
from decimal import Decimal

from roboquant import Order
from roboquant.asset import Stock


apple = Stock("AAPL")


class TestOrder(unittest.TestCase):

def test_order_create(self):
order = Order("AAPL", 100, 120.0)
order = Order(apple, 100, 120.0)
self.assertEqual(120.0, order.limit)
self.assertEqual(None, order.id)
self.assertEqual(apple, order.asset)

def test_order_info(self):
order = Order("AAPL", 100, 120.0, tif="ABC")
order = Order(apple, 100, 120.0, tif="ABC")
info = order.info
self.assertIn("tif", info)

Expand All @@ -22,7 +27,7 @@ def test_order_info(self):
self.assertIn("tif", info)

def test_order_update(self):
order = Order("AAPL", 100, 120.0)
order = Order(apple, 100, 120.0)
order.id = "update1"

update_order = order.modify(size=50)
Expand All @@ -31,7 +36,7 @@ def test_order_update(self):
self.assertEqual(order.id, update_order.id)

def test_order_cancel(self):
order = Order("AAPL", 100, 120.0)
order = Order(apple, 100, 120.0)
order.id = "cancel1"

cancel_order = order.cancel()
Expand Down
32 changes: 32 additions & 0 deletions tests/unit/test_parquetfeed.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
import tempfile
import unittest
from pathlib import Path

from roboquant.feeds import ParquetFeed
from tests.common import get_feed, run_price_item_feed


class TestParquetFeed(unittest.TestCase):

def test_sql_feed(self):
path = tempfile.gettempdir()
db_file = Path(path).joinpath("tmp.parquet")
db_file.unlink(missing_ok=True)
self.assertFalse(db_file.exists())

feed = ParquetFeed(db_file)
self.assertFalse(feed.exists())

origin_feed = get_feed()
feed.record(origin_feed)
self.assertTrue(db_file.exists())

self.assertEqual(set(origin_feed.assets()), set(feed.assets()))
self.assertEqual(origin_feed.timeframe(), feed.timeframe())

run_price_item_feed(feed, origin_feed.assets(), self)
db_file.unlink(missing_ok=True)


if __name__ == "__main__":
unittest.main()

0 comments on commit cce155a

Please sign in to comment.