diff --git a/roboquant/brokers/ibkrbroker.py b/roboquant/brokers/ibkrbroker.py index aa0de47..05c222b 100644 --- a/roboquant/brokers/ibkrbroker.py +++ b/roboquant/brokers/ibkrbroker.py @@ -4,25 +4,21 @@ from datetime import datetime, timezone, timedelta from decimal import Decimal +from ibapi import VERSION +from ibapi.account_summary_tags import AccountSummaryTags +from ibapi.client import EClient +from ibapi.contract import Contract +from ibapi.order import Order as IBKROrder +from ibapi.wrapper import EWrapper + from roboquant.account import Account, Position from roboquant.event import Event from roboquant.order import Order, OrderStatus from .broker import Broker -logger = logging.getLogger(__name__) - -try: - from ibapi.client import EClient - from ibapi.wrapper import EWrapper - from ibapi.contract import Contract - from ibapi.order import Order as IBKROrder - from ibapi.account_summary_tags import AccountSummaryTags - from ibapi import VERSION +assert VERSION["major"] == 10 and VERSION["minor"] == 19, "Wrong version of the IBAPI found" - assert VERSION["major"] == 10 and VERSION["minor"] == 19, "Wrong version of the IBAPI found" -except ImportError: - logger.fatal("Couldn't import IBAPI package, you need to install this") - pass +logger = logging.getLogger(__name__) # noinspection PyPep8Naming @@ -183,7 +179,7 @@ def place_orders(self, *orders: Order): if order.id is None: order.id = self.__api.get_next_order_id() self.__api.orders[order.id] = order - ibkr_order = self._get_order(order) + ibkr_order = self.__get_order(order) contract = self.contract_mapping.get(order.symbol) or self._get_default_contract(order.symbol) self.__api.placeOrder(int(order.id), contract, ibkr_order) @@ -201,7 +197,8 @@ def _get_default_contract(self, symbol: str) -> Contract: c.exchange = "SMART" # use smart routing by default return c - def _get_order(self, order: Order): + @staticmethod + def __get_order(order: Order): o = IBKROrder() o.action = "BUY" if order.is_buy else "SELL" o.totalQuantity = abs(order.size) diff --git a/roboquant/feeds/__init__.py b/roboquant/feeds/__init__.py index 38e996b..ac33579 100644 --- a/roboquant/feeds/__init__.py +++ b/roboquant/feeds/__init__.py @@ -7,4 +7,8 @@ from .randomwalk import RandomWalk from .sqllitefeed import SQLFeed from .tiingofeed import TiingoLiveFeed, TiingoHistoricFeed -from .yahoofeed import YahooFeed + +try: + from .yahoofeed import YahooFeed +except ImportError: + pass diff --git a/roboquant/feeds/sqllitefeed.py b/roboquant/feeds/sqllitefeed.py index 7be5cb6..23ffc27 100644 --- a/roboquant/feeds/sqllitefeed.py +++ b/roboquant/feeds/sqllitefeed.py @@ -21,17 +21,17 @@ class SQLFeed(Feed): _sql_insert_row = "INSERT into prices VALUES(?,?,?,?,?,?,?,?)" _sql_create_index = "CREATE INDEX IF NOT EXISTS date_idx ON prices(date)" - def __init__(self, file) -> None: + def __init__(self, db_file) -> None: super().__init__() - self.file = file + self.db_file = db_file def create_index(self): - con = sqlite3.connect(self.file) + con = sqlite3.connect(self.db_file) con.execute(SQLFeed._sql_create_index) con.commit() def timeframe(self): - con = sqlite3.connect(self.file) + con = sqlite3.connect(self.db_file) result = con.execute(SQLFeed._sql_select_timeframe).fetchall() con.commit() row = result[0] @@ -39,14 +39,14 @@ def timeframe(self): return tf def symbols(self): - con = sqlite3.connect(self.file) + con = sqlite3.connect(self.db_file) result = con.execute(SQLFeed._sql_select_symbols).fetchall() con.commit() symbols = [columns[0] for columns in result] return symbols def play(self, channel: EventChannel): - con = sqlite3.connect(self.file) + con = sqlite3.connect(self.db_file) cur = con.cursor() cnt = 0 t_old = "" @@ -75,7 +75,7 @@ def play(self, channel: EventChannel): def record(self, feed, timeframe=None, append=False): """Record another feed to this SQLite database""" - con = sqlite3.connect(self.file) + con = sqlite3.connect(self.db_file) cur = con.cursor() if not append: diff --git a/roboquant/feeds/yahoofeed.py b/roboquant/feeds/yahoofeed.py index f6d6706..36344b9 100644 --- a/roboquant/feeds/yahoofeed.py +++ b/roboquant/feeds/yahoofeed.py @@ -2,6 +2,8 @@ from array import array from datetime import datetime, timezone +import yfinance + from roboquant.event import Candle from roboquant.feeds.historicfeed import HistoricFeed @@ -18,12 +20,6 @@ def __init__(self, *symbols: str, start_date="2010-01-01", end_date: str | None columns = ["Open", "High", "Low", "Close", "Volume", "Adj Close"] - try: - import yfinance - except ImportError: - logger.fatal("Couldn't import yfinance package") - return - for symbol in symbols: logger.debug("requesting symbol=%s", symbol) df = yfinance.Ticker(symbol).history( diff --git a/tests/integration/test_yahoofeed.py b/tests/integration/test_yahoofeed.py index 11efc4b..5c03920 100644 --- a/tests/integration/test_yahoofeed.py +++ b/tests/integration/test_yahoofeed.py @@ -6,7 +6,7 @@ class TestYahooFeed(unittest.TestCase): - def test_yahoofeed(self): + def test_yahoo_feed(self): feed = YahooFeed("MSFT", "JPM", start_date="2018-01-01", end_date="2020-01-01") self.assertEqual(2, len(feed.symbols)) self.assertEqual({"MSFT", "JPM"}, set(feed.symbols)) @@ -16,7 +16,7 @@ def test_yahoofeed(self): run_priceitem_feed(feed, ["MSFT", "JPM"], self) - def test_yahoofeed_wrong_symbol(self): + def test_yahoo_feed_wrong_symbol(self): # expect some error logging due to parsing an invalid symbol feed = YahooFeed("INVALID_TICKER_NAME", start_date="2010-01-01", end_date="2020-01-01") self.assertEqual(0, len(feed.symbols))