diff --git a/roboquant/__init__.py b/roboquant/__init__.py index 9c5a2d4..ca48da1 100644 --- a/roboquant/__init__.py +++ b/roboquant/__init__.py @@ -1,4 +1,4 @@ -__version__ = "0.6.3" +__version__ = "0.6.4" from roboquant import brokers from roboquant import feeds diff --git a/roboquant/asset.py b/roboquant/asset.py index 0a90c4b..1f9e6bd 100644 --- a/roboquant/asset.py +++ b/roboquant/asset.py @@ -1,7 +1,7 @@ from abc import ABC from dataclasses import dataclass from decimal import Decimal -from typing import ClassVar +from typing import ClassVar, Type from roboquant.wallet import Amount @@ -77,7 +77,7 @@ def contract_value(self, size: Decimal, price: float) -> float: return float(size) * price * self.multiplier -def __default_deserializer(clazz): +def __default_deserializer(clazz: Type[Asset]): __cache: dict[str, Asset] = {} diff --git a/roboquant/feeds/__init__.py b/roboquant/feeds/__init__.py index 4684fc9..98ccc63 100644 --- a/roboquant/feeds/__init__.py +++ b/roboquant/feeds/__init__.py @@ -6,6 +6,7 @@ from .historic import HistoricFeed from .randomwalk import RandomWalk from .sqllitefeed import SQLFeed +from .parquetfeed import ParquetFeed try: from .yahoo import YahooFeed diff --git a/roboquant/feeds/parquetfeed.py b/roboquant/feeds/parquetfeed.py index c8a61f1..7f1b990 100644 --- a/roboquant/feeds/parquetfeed.py +++ b/roboquant/feeds/parquetfeed.py @@ -1,7 +1,7 @@ import logging import os.path from array import array -from typing import Any +from typing import Any, Iterable import pyarrow as pa import pyarrow.parquet as pq @@ -36,10 +36,16 @@ def exists(self): return os.path.exists(self.parquet_path) def play(self, channel: EventChannel): + # pylint: disable=too-many-locals dataset = pq.ParquetFile(self.parquet_path) last_time: Any = None items = [] - for batch in dataset.iter_batches(): + + row_group_indexes = self.__get_row_group_indexes(channel.timeframe) + + # for batch in dataset.iter_batches(): + for idx in row_group_indexes: + batch = dataset.read_row_group(idx) times = batch.column("time") assets = batch.column("asset") prices = batch.column("prices") @@ -71,6 +77,18 @@ def play(self, channel: EventChannel): event = Event(now, items) channel.put(event) + def __get_row_group_indexes(self, timeframe: Timeframe | None) -> Iterable[int]: + md = pq.read_metadata(self.parquet_path) + time = timeframe.start if timeframe else None + + if not time: + return range(0, md.num_row_groups) + for idx in range(md.num_row_groups): + stat = md.row_group(idx).column(0).statistics + if stat.max >= time: + return range(idx, md.num_row_groups) + return [] + def timeframe(self) -> Timeframe: d = pq.read_metadata(self.parquet_path).to_dict() if d["row_groups"]: @@ -80,6 +98,15 @@ def timeframe(self) -> Timeframe: return tf return EMPTY_TIMEFRAME + def assets(self) -> list[Asset]: + if not self.exists(): + return [] + + result_table = pq.read_table(self.parquet_path, columns=["asset"], schema=ParquetFeed.__schema) + assets_list = result_table["asset"].to_pylist() + assets_set = set(assets_list) + return list({Asset.deserialize(s) for s in assets_set}) + def meta(self): return pq.read_metadata(self.parquet_path) diff --git a/roboquant/strategies/basestrategy.py b/roboquant/strategies/basestrategy.py index b1ab021..0f0ae35 100644 --- a/roboquant/strategies/basestrategy.py +++ b/roboquant/strategies/basestrategy.py @@ -15,8 +15,7 @@ class BaseStrategy(Strategy): # pylint: disable=too-many-instance-attributes - """Base version of strategy that contains several methods to make it easier to create orders. - """ + """Base version of strategy that contains several methods to make it easier to manage orders.""" def __init__(self) -> None: super().__init__() @@ -24,7 +23,7 @@ def __init__(self) -> None: self.buy_price = "DEFAULT" self.sell_price = "DEFAULT" self.fractional_order_digits = 0 - self.cancel_existing_orders = True + self.cancel_orders_older_than = timedelta(days=30) self._orders: list[Order] self._order_value: float @@ -39,6 +38,7 @@ def create_orders(self, event: Event, account: Account) -> list[Order]: self._event = event self._order_value = round(account.equity_value() * self.order_value_perc, 2) + self.cancel_old_orders() self.process(event, account) return self._orders @@ -53,6 +53,8 @@ def add_buy_order(self, asset: Asset, limit: float | None = None): if size := self._get_size(asset, limit): order = Order(asset, size, limit) return self.add_order(order) + + logger.info("rejected buy order asset %s", asset) return False def add_exit_order(self, asset: Asset, limit: float | None = None): @@ -60,6 +62,8 @@ def add_exit_order(self, asset: Asset, limit: float | None = None): if limit := limit or self._get_limit(asset, size > 0): order = Order(asset, size, limit) return self.add_order(order) + + logger.info("rejected exit order asset %s", asset) return False def add_sell_order(self, asset: Asset, limit: float | None = None): @@ -67,6 +71,8 @@ def add_sell_order(self, asset: Asset, limit: float | None = None): if size := self._get_size(asset, limit) * -1: order = Order(asset, size, limit) return self.add_order(order) + + logger.info("rejected sell order asset %s", asset) return False def _get_limit(self, asset: Asset, is_buy: bool) -> float | None: @@ -91,14 +97,14 @@ def add_order(self, order: Order) -> bool: return False self._remaining_buying_power -= bp - if self.cancel_existing_orders: - self.cancel_open_orders(order.asset) self._orders.append(order) return True - def cancel_old_orders(self, older_than: timedelta): + def cancel_old_orders(self): for order in self._account.orders: - if order.created_at + older_than < self._event.time: + if not order.created_at: + continue + if order.created_at + self.cancel_orders_older_than < self._event.time: self.cancel_order(order) def cancel_open_orders(self, *assets): diff --git a/roboquant/strategies/emacrossover.py b/roboquant/strategies/emacrossover.py index 36e6c3c..5066207 100644 --- a/roboquant/strategies/emacrossover.py +++ b/roboquant/strategies/emacrossover.py @@ -1,3 +1,4 @@ +from datetime import timedelta from roboquant.account import Account from roboquant.asset import Asset from roboquant.event import Event @@ -14,6 +15,7 @@ def __init__(self, fast_period=13, slow_period=26, smoothing=2.0, price_type="DE self.slow = 1.0 - (smoothing / (slow_period + 1)) self.price_type = price_type self.min_steps = max(fast_period, slow_period) + self.cancel_orders_older_than = timedelta(days=5) def process(self, event: Event, account: Account): for asset, price in event.get_prices(self.price_type).items(): diff --git a/tests/unit/test_order.py b/tests/unit/test_order.py index f6884ac..f0b54d8 100644 --- a/tests/unit/test_order.py +++ b/tests/unit/test_order.py @@ -2,17 +2,22 @@ from decimal import Decimal from roboquant import Order +from roboquant.asset import Stock + + +apple = Stock("AAPL") class TestOrder(unittest.TestCase): def test_order_create(self): - order = Order("AAPL", 100, 120.0) + order = Order(apple, 100, 120.0) self.assertEqual(120.0, order.limit) self.assertEqual(None, order.id) + self.assertEqual(apple, order.asset) def test_order_info(self): - order = Order("AAPL", 100, 120.0, tif="ABC") + order = Order(apple, 100, 120.0, tif="ABC") info = order.info self.assertIn("tif", info) @@ -22,7 +27,7 @@ def test_order_info(self): self.assertIn("tif", info) def test_order_update(self): - order = Order("AAPL", 100, 120.0) + order = Order(apple, 100, 120.0) order.id = "update1" update_order = order.modify(size=50) @@ -31,7 +36,7 @@ def test_order_update(self): self.assertEqual(order.id, update_order.id) def test_order_cancel(self): - order = Order("AAPL", 100, 120.0) + order = Order(apple, 100, 120.0) order.id = "cancel1" cancel_order = order.cancel() diff --git a/tests/unit/test_parquetfeed.py b/tests/unit/test_parquetfeed.py new file mode 100644 index 0000000..5000c59 --- /dev/null +++ b/tests/unit/test_parquetfeed.py @@ -0,0 +1,32 @@ +import tempfile +import unittest +from pathlib import Path + +from roboquant.feeds import ParquetFeed +from tests.common import get_feed, run_price_item_feed + + +class TestParquetFeed(unittest.TestCase): + + def test_sql_feed(self): + path = tempfile.gettempdir() + db_file = Path(path).joinpath("tmp.parquet") + db_file.unlink(missing_ok=True) + self.assertFalse(db_file.exists()) + + feed = ParquetFeed(db_file) + self.assertFalse(feed.exists()) + + origin_feed = get_feed() + feed.record(origin_feed) + self.assertTrue(db_file.exists()) + + self.assertEqual(set(origin_feed.assets()), set(feed.assets())) + self.assertEqual(origin_feed.timeframe(), feed.timeframe()) + + run_price_item_feed(feed, origin_feed.assets(), self) + db_file.unlink(missing_ok=True) + + +if __name__ == "__main__": + unittest.main()