Skip to content

Commit

Permalink
Fixed small bug
Browse files Browse the repository at this point in the history
  • Loading branch information
jbaron committed Jul 20, 2024
1 parent 318bcff commit 55d5d90
Show file tree
Hide file tree
Showing 7 changed files with 16 additions and 66 deletions.
2 changes: 1 addition & 1 deletion roboquant/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
__version__ = "0.6.2"
__version__ = "0.6.3"

from roboquant import brokers
from roboquant import feeds
Expand Down
4 changes: 2 additions & 2 deletions roboquant/feeds/historic.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,10 @@ def _add_item(self, time: datetime, item: PriceItem):
items = self.__data[time]
items.append(item)

def assets(self):
def assets(self) -> list[Asset]:
"""Return the list of unique symbols available in this feed"""
self.__update()
return self.__assets
return list(self.__assets)

def timeline(self) -> list[datetime]:
"""Return the timeline of this feed as a list of datatime objects"""
Expand Down
49 changes: 0 additions & 49 deletions roboquant/feeds/parquetfeed.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
from typing import Any

import pyarrow as pa
import pyarrow.dataset as ds
import pyarrow.parquet as pq

from roboquant.event import Quote, Bar, Trade
Expand Down Expand Up @@ -72,45 +71,6 @@ def play(self, channel: EventChannel):
event = Event(now, items)
channel.put(event)

def _generator(self, feed: Feed, timeframe: Timeframe | None):
channel = feed.play_background(timeframe)
t_old = ""
items = []
cnt = 0
while event := channel.get():
t = event.time

if t != t_old and cnt > 10_000:
if items:
batch = pa.RecordBatch.from_pylist(items, schema=ParquetFeed.__schema)
yield batch
items = []
t_old = t
cnt = 0

for item in event.items:
match item:
case Quote():
cnt += 1
items.append({"time": t, "type": 1, "asset": item.asset.serialize(), "prices": item.data.tolist()})
case Bar():
cnt += 1
items.append({"time": t, "type": 2, "asset": item.asset.serialize(), "prices": item.ohlcv.tolist()})
case Trade():
cnt += 1
items.append(
{
"time": t,
"type": 3,
"asset": item.asset.serialize(),
"prices": [item.trade_price, item.trade_volume],
}
)

if items:
batch = pa.RecordBatch.from_pylist(items, schema=ParquetFeed.__schema)
yield batch

def timeframe(self) -> Timeframe:
d = pq.read_metadata(self.parquet_path).to_dict()
if d["row_groups"]:
Expand All @@ -123,15 +83,6 @@ def timeframe(self) -> Timeframe:
def meta(self):
return pq.read_metadata(self.parquet_path)

def record2(self, feed: Feed, timeframe: Timeframe | None = None):
ds.write_dataset(
self._generator(feed, timeframe),
self.parquet_path,
format="parquet",
schema=ParquetFeed.__schema,
existing_data_behavior="overwrite_or_ignore",
)

def __repr__(self) -> str:
return f"ParquetFeed(path={self.parquet_path})"

Expand Down
12 changes: 6 additions & 6 deletions roboquant/feeds/randomwalk.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

import numpy as np

from roboquant.asset import Stock
from roboquant.asset import Asset, Stock
from roboquant.event import Bar, Trade, Quote
from .historic import HistoricFeed

Expand Down Expand Up @@ -76,14 +76,14 @@ def __get_assets(
rnd,
n_symbols,
symbol_len,
):
symbols = set()
) -> list[Asset]:
assets = set()
alphabet = np.array(list(string.ascii_uppercase))
while len(symbols) < n_symbols:
while len(assets) < n_symbols:
symbol = "".join(rnd.choice(alphabet, size=symbol_len))
asset = Stock(symbol, "USD")
symbols.add(asset)
return symbols
assets.add(asset)
return list(assets)

@staticmethod
def __price_path(rnd, n, scale, min_price, max_price):
Expand Down
2 changes: 1 addition & 1 deletion tests/integration/test_ibkr.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ def test_ibkr_order(self):
Amount.converter = One2OneConversion()
logging.basicConfig(level=logging.DEBUG)
logging.getLogger("ibapi").setLevel(logging.WARNING)
asset = Stock("JPM", "USD")
asset = Stock("JPM")
limit = 205

broker = IBKRBroker()
Expand Down
11 changes: 5 additions & 6 deletions tests/unit/test_rnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,20 +30,19 @@ def test_lstm_model(self):
# logging.basicConfig()
# logging.getLogger("roboquant.strategies").setLevel(level=logging.INFO)
# Setup
symbol = "AAPL"
asset = Stock(symbol, "USD")
apple = Stock("AAPL")
prediction = 10
feed = get_feed()
model = _MyModel()

input_feature = CombinedFeature(
BarFeature(asset),
SMAFeature(PriceFeature(asset, price_type="HIGH"), 10)
BarFeature(apple),
SMAFeature(PriceFeature(apple, price_type="HIGH"), 10)
).returns().normalize()

label_feature = PriceFeature(asset, price_type="CLOSE").returns(prediction)
label_feature = PriceFeature(apple, price_type="CLOSE").returns(prediction)

strategy = RNNStrategy(input_feature, label_feature, model, symbol, sequences=20, buy_pct=0.01)
strategy = RNNStrategy(input_feature, label_feature, model, apple, sequences=20, buy_pct=0.01)

# Train the model with 10 years of data
tf = rq.Timeframe.fromisoformat("2010-01-01", "2020-01-01")
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/test_sqlfeed.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def test_sql_feed(self):
self.assertEqual(origin_feed.timeframe(), feed.timeframe())
feed.create_index()

self.assertEqual(origin_feed.assets(), feed.assets())
self.assertEqual(set(origin_feed.assets()), set(feed.assets()))

run_price_item_feed(feed, origin_feed.assets(), self)
db_file.unlink(missing_ok=True)
Expand Down

0 comments on commit 55d5d90

Please sign in to comment.