diff --git a/roboquant/__init__.py b/roboquant/__init__.py index 95e5e9d..9c5a2d4 100644 --- a/roboquant/__init__.py +++ b/roboquant/__init__.py @@ -1,4 +1,4 @@ -__version__ = "0.6.2" +__version__ = "0.6.3" from roboquant import brokers from roboquant import feeds diff --git a/roboquant/feeds/historic.py b/roboquant/feeds/historic.py index 43aef17..bc2eb57 100644 --- a/roboquant/feeds/historic.py +++ b/roboquant/feeds/historic.py @@ -36,10 +36,10 @@ def _add_item(self, time: datetime, item: PriceItem): items = self.__data[time] items.append(item) - def assets(self): + def assets(self) -> list[Asset]: """Return the list of unique symbols available in this feed""" self.__update() - return self.__assets + return list(self.__assets) def timeline(self) -> list[datetime]: """Return the timeline of this feed as a list of datatime objects""" diff --git a/roboquant/feeds/parquetfeed.py b/roboquant/feeds/parquetfeed.py index a9ff631..c8a61f1 100644 --- a/roboquant/feeds/parquetfeed.py +++ b/roboquant/feeds/parquetfeed.py @@ -4,7 +4,6 @@ from typing import Any import pyarrow as pa -import pyarrow.dataset as ds import pyarrow.parquet as pq from roboquant.event import Quote, Bar, Trade @@ -72,45 +71,6 @@ def play(self, channel: EventChannel): event = Event(now, items) channel.put(event) - def _generator(self, feed: Feed, timeframe: Timeframe | None): - channel = feed.play_background(timeframe) - t_old = "" - items = [] - cnt = 0 - while event := channel.get(): - t = event.time - - if t != t_old and cnt > 10_000: - if items: - batch = pa.RecordBatch.from_pylist(items, schema=ParquetFeed.__schema) - yield batch - items = [] - t_old = t - cnt = 0 - - for item in event.items: - match item: - case Quote(): - cnt += 1 - items.append({"time": t, "type": 1, "asset": item.asset.serialize(), "prices": item.data.tolist()}) - case Bar(): - cnt += 1 - items.append({"time": t, "type": 2, "asset": item.asset.serialize(), "prices": item.ohlcv.tolist()}) - case Trade(): - cnt += 1 - items.append( - { - "time": t, - "type": 3, - "asset": item.asset.serialize(), - "prices": [item.trade_price, item.trade_volume], - } - ) - - if items: - batch = pa.RecordBatch.from_pylist(items, schema=ParquetFeed.__schema) - yield batch - def timeframe(self) -> Timeframe: d = pq.read_metadata(self.parquet_path).to_dict() if d["row_groups"]: @@ -123,15 +83,6 @@ def timeframe(self) -> Timeframe: def meta(self): return pq.read_metadata(self.parquet_path) - def record2(self, feed: Feed, timeframe: Timeframe | None = None): - ds.write_dataset( - self._generator(feed, timeframe), - self.parquet_path, - format="parquet", - schema=ParquetFeed.__schema, - existing_data_behavior="overwrite_or_ignore", - ) - def __repr__(self) -> str: return f"ParquetFeed(path={self.parquet_path})" diff --git a/roboquant/feeds/randomwalk.py b/roboquant/feeds/randomwalk.py index 748d1b5..4942853 100644 --- a/roboquant/feeds/randomwalk.py +++ b/roboquant/feeds/randomwalk.py @@ -6,7 +6,7 @@ import numpy as np -from roboquant.asset import Stock +from roboquant.asset import Asset, Stock from roboquant.event import Bar, Trade, Quote from .historic import HistoricFeed @@ -76,14 +76,14 @@ def __get_assets( rnd, n_symbols, symbol_len, - ): - symbols = set() + ) -> list[Asset]: + assets = set() alphabet = np.array(list(string.ascii_uppercase)) - while len(symbols) < n_symbols: + while len(assets) < n_symbols: symbol = "".join(rnd.choice(alphabet, size=symbol_len)) asset = Stock(symbol, "USD") - symbols.add(asset) - return symbols + assets.add(asset) + return list(assets) @staticmethod def __price_path(rnd, n, scale, min_price, max_price): diff --git a/tests/integration/test_ibkr.py b/tests/integration/test_ibkr.py index cb75e97..b872fc0 100644 --- a/tests/integration/test_ibkr.py +++ b/tests/integration/test_ibkr.py @@ -15,7 +15,7 @@ def test_ibkr_order(self): Amount.converter = One2OneConversion() logging.basicConfig(level=logging.DEBUG) logging.getLogger("ibapi").setLevel(logging.WARNING) - asset = Stock("JPM", "USD") + asset = Stock("JPM") limit = 205 broker = IBKRBroker() diff --git a/tests/unit/test_rnn.py b/tests/unit/test_rnn.py index 7355670..9e69eb0 100644 --- a/tests/unit/test_rnn.py +++ b/tests/unit/test_rnn.py @@ -30,20 +30,19 @@ def test_lstm_model(self): # logging.basicConfig() # logging.getLogger("roboquant.strategies").setLevel(level=logging.INFO) # Setup - symbol = "AAPL" - asset = Stock(symbol, "USD") + apple = Stock("AAPL") prediction = 10 feed = get_feed() model = _MyModel() input_feature = CombinedFeature( - BarFeature(asset), - SMAFeature(PriceFeature(asset, price_type="HIGH"), 10) + BarFeature(apple), + SMAFeature(PriceFeature(apple, price_type="HIGH"), 10) ).returns().normalize() - label_feature = PriceFeature(asset, price_type="CLOSE").returns(prediction) + label_feature = PriceFeature(apple, price_type="CLOSE").returns(prediction) - strategy = RNNStrategy(input_feature, label_feature, model, symbol, sequences=20, buy_pct=0.01) + strategy = RNNStrategy(input_feature, label_feature, model, apple, sequences=20, buy_pct=0.01) # Train the model with 10 years of data tf = rq.Timeframe.fromisoformat("2010-01-01", "2020-01-01") diff --git a/tests/unit/test_sqlfeed.py b/tests/unit/test_sqlfeed.py index 57711dc..798fce4 100644 --- a/tests/unit/test_sqlfeed.py +++ b/tests/unit/test_sqlfeed.py @@ -23,7 +23,7 @@ def test_sql_feed(self): self.assertEqual(origin_feed.timeframe(), feed.timeframe()) feed.create_index() - self.assertEqual(origin_feed.assets(), feed.assets()) + self.assertEqual(set(origin_feed.assets()), set(feed.assets())) run_price_item_feed(feed, origin_feed.assets(), self) db_file.unlink(missing_ok=True)