Skip to content

Commit

Permalink
more typing and __slots__
Browse files Browse the repository at this point in the history
  • Loading branch information
jbaron committed Aug 20, 2024
1 parent a329788 commit 1fc19fc
Show file tree
Hide file tree
Showing 14 changed files with 44 additions and 38 deletions.
1 change: 1 addition & 0 deletions .flake8
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
[flake8]
max-line-length = 127
# ignore = W291,E226,W503
per-file-ignores =
# imported but unused
__init__.py: F401
Expand Down
4 changes: 3 additions & 1 deletion roboquant/account.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@ class Account:
Only the broker updates the account and does this only during its `sync` method.
"""

__slots__ = "buying_power", "positions", "orders", "last_update", "cash"

def __init__(self, base_currency: Currency = USD):
self.buying_power: Amount = Amount(base_currency, 0.0)
self.positions: dict[Asset, Position] = {}
Expand All @@ -51,7 +53,7 @@ def __init__(self, base_currency: Currency = USD):
self.cash: Wallet = Wallet()

@property
def base_currency(self):
def base_currency(self) -> Currency:
"""Return the base currency of this account"""
return self.buying_power.currency

Expand Down
2 changes: 1 addition & 1 deletion roboquant/asset.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
@dataclass(frozen=True, slots=True)
class Asset(ABC):
"""Abstract baseclass for all types of assets, ranging from stocks to cryptocurrencies.
Every asset has always at least a `symbol` and `currency` defined.
Every asset has always at least a `symbol` and `currency` defined. Assets are immutable.
"""

symbol: str
Expand Down
2 changes: 1 addition & 1 deletion roboquant/brokers/simbroker.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ def place_orders(self, orders: list[Order]):
Orders that are placed that have already an order-id are either update- or cancellation-orders.
There is no trading simulation yet performed or account updated. This is done during the `sync` method.
Orders placed at time `t`, will be processed during time `t+1`. This protects against future bias.
Orders placed at time `t`, will be processed during time `t+1`. This protects against future bias.
"""
for order in orders:
if order.id is None:
Expand Down
2 changes: 1 addition & 1 deletion roboquant/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def __init__(self, path=None):
self.config = ConfigParser()
self.config.read_string(config_string)

def get(self, key):
def get(self, key: str) -> str:
for key2, value in os.environ.items():
final_key = key2.lower().replace("_", ".")
if final_key == key:
Expand Down
2 changes: 1 addition & 1 deletion roboquant/ml/rl.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def step(self, action):
logger.debug("time=%s action=%s", self.event.time, action)

signals = [Signal(asset, float(rating)) for asset, rating in zip(self.assets, action)]
orders = self.trader.create_orders(signals, self.event, self.account)
orders = self.trader.create_orders(signals, self.event, self.account)
self.broker.place_orders(orders)

if self.journal:
Expand Down
3 changes: 2 additions & 1 deletion roboquant/ml/strategies.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,8 @@ def create_signals(self, event: Event) -> list[Signal]:
return []

@abstractmethod
def predict(self, x: NDArray, time: datetime) -> list[Signal]: ...
def predict(self, x: NDArray, time: datetime) -> list[Signal]:
...


class SequenceDataset(Dataset):
Expand Down
12 changes: 6 additions & 6 deletions roboquant/order.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,10 +69,10 @@ def __deepcopy__(self, memo):
result.fill = self.fill
return result

def value(self):
def value(self) -> float:
return self.asset.contract_value(self.size, self.limit)

def amount(self):
def amount(self) -> Amount:
return Amount(self.asset.currency, self.value())

@property
Expand All @@ -81,22 +81,22 @@ def is_cancellation(self):
return self.size.is_zero()

@property
def is_buy(self):
def is_buy(self) -> bool:
"""Return True if this is a BUY order, False otherwise"""
return self.size > 0

@property
def is_sell(self):
def is_sell(self) -> bool:
"""Return True if this is a SELL order, False otherwise"""
return self.size < 0

@property
def completed(self):
def completed(self) -> bool:
"""Return True if the order is completed (completely filled)"""
return not self.remaining

@property
def remaining(self):
def remaining(self) -> Decimal:
"""Return the remaining order size to be filled.
In case of a sell order, the remaining will be a negative number.
Expand Down
4 changes: 2 additions & 2 deletions roboquant/strategies/emacrossover.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,11 +49,11 @@ def __init__(self, momentum1: float, momentum2: float, price: float):
self.price2 = price
self.step = 0

def is_above(self):
def is_above(self) -> bool:
"""Return True is the first momentum is above the second momentum, False otherwise"""
return self.price1 > self.price2

def add_price(self, price: float):
def add_price(self, price: float) -> int:
m1, m2 = self.momentum1, self.momentum2
self.price1 = m1 * self.price1 + (1.0 - m1) * price
self.price2 = m2 * self.price2 + (1.0 - m2) * price
Expand Down
13 changes: 11 additions & 2 deletions roboquant/strategies/multistrategy.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from itertools import groupby
from statistics import mean
from typing import Literal

from roboquant.event import Event
Expand All @@ -11,14 +13,14 @@ class MultiStrategy(Strategy):
- first: in case of multiple signals for the same symbol, the first one wins
- last: in case of multiple signals for the same symbol, the last one wins.
- avg: return the avgerage of the signals. All signals will be ENTRY and EXIT.
- mean: return the mean of the signal ratings. All signals will be ENTRY and EXIT.
- none: return all signals. This is also the default.
"""

def __init__(
self,
*strategies: Strategy,
order_filter: Literal["last", "first", "none"] = "none"
order_filter: Literal["last", "first", "none", "mean"] = "none"
):
super().__init__()
self.strategies = list(strategies)
Expand All @@ -38,5 +40,12 @@ def create_signals(self, event: Event) -> list[Signal]:
case "first":
s = {s.asset: s for s in reversed(signals)}
return list(s.values())
case "mean":
result = []
for key, group in groupby(signals, lambda signal: signal.asset):
rating = mean(signal.rating for signal in group)
if rating:
result.append(Signal(key, rating))
return result

raise ValueError("unsupported signal filter")
4 changes: 2 additions & 2 deletions roboquant/strategies/strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,12 @@


class Strategy(ABC):
"""A strategy creates signals based on incoming events and the items within these events.
"""A strategy creates signals based on incoming events and the items contained within these events.
Often the items represent market data associated with an asset, but other types of items
are also possible.
"""

@abstractmethod
def create_signals(self, event: Event) -> list[Signal]:
"""Create zero or more signals for provided event."""
"""Create zero or more signals given the provided event."""
...
4 changes: 2 additions & 2 deletions roboquant/strategies/tastrategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@ def __init__(self, size: int) -> None:
self._data: dict[Asset, OHLCVBuffer] = {}
self.size = size

def create_signals(self, event):
result = []
def create_signals(self, event) -> list[Signal]:
result: list[Signal] = []
for item in event.items:
if isinstance(item, Bar):
asset = item.asset
Expand Down
4 changes: 2 additions & 2 deletions tests/integration/test_ibkr.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,14 +37,14 @@ def test_ibkr_order(self):
self.assertEqual(asset, account.orders[0].asset)

# Update an order
update_order = order.modify(size=5, limit=limit-1)
update_order = order.modify(size=5, limit=limit - 1)
broker.place_orders([update_order])
time.sleep(5)
account = broker.sync()
print(account)
self.assertEqual(len(account.orders), 1)
self.assertEqual(account.orders[0].size, Decimal(5))
self.assertEqual(account.orders[0].limit, limit-1)
self.assertEqual(account.orders[0].limit, limit - 1)

# Cancel an order
cancel_order = update_order.cancel()
Expand Down
25 changes: 9 additions & 16 deletions tests/performance/test_delay.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from statistics import mean, stdev

from roboquant import Timeframe
from roboquant.feeds import Feed
from roboquant.alpaca import AlpacaLiveFeed


Expand All @@ -24,31 +23,25 @@ class TestDelay(unittest.TestCase):

__symbols = ["TSLA", "MSFT", "NVDA", "AMD", "AAPL", "AMZN", "META", "GOOG", "XOM", "JPM", "NLFX", "BA", "INTC", "V"]

def __run_feed(self, feed: Feed):
def test_alpaca_delay(self):
feed = AlpacaLiveFeed(market="iex")
feed.subscribe_quotes(*TestDelay.__symbols)
timeframe = Timeframe.next(minutes=1)
channel = feed.play_background(timeframe, 1000)
name = type(feed).__name__

delays = []
n = 0
while event := channel.get(70):
while event := channel.get(10):
if event.items:
n += len(event.items)
delays.append(time.time() - event.time.timestamp())

if delays:
t = (
f"feed={name} mean={mean(delays):.3f} stdev={stdev(delays):.3f} "
+ f"max={max(delays):.3f} min={min(delays):.3f} events={len(delays)} items={n}"
)
print(t)
else:
print(f"Didn't receive any items from {name}, is it perhaps outside trading hours?")
self.assertTrue(delays, "Didn't receive any quotes, is it perhaps outside trading hours?")

def test_alpaca_delay(self):
feed = AlpacaLiveFeed(market="iex")
feed.subscribe_quotes(*TestDelay.__symbols)
self.__run_feed(feed)
print(
f"delays mean={mean(delays):.3f} stdev={stdev(delays):.3f}",
f"max={max(delays):.3f} min={min(delays):.3f} events={len(delays)} items={n}"
)


if __name__ == "__main__":
Expand Down

0 comments on commit 1fc19fc

Please sign in to comment.