Skip to content

Commit

Permalink
improved imports
Browse files Browse the repository at this point in the history
  • Loading branch information
jbaron committed Feb 27, 2024
1 parent 76334f1 commit 4f8c8dc
Show file tree
Hide file tree
Showing 5 changed files with 28 additions and 31 deletions.
27 changes: 12 additions & 15 deletions roboquant/brokers/ibkrbroker.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,25 +4,21 @@
from datetime import datetime, timezone, timedelta
from decimal import Decimal

from ibapi import VERSION
from ibapi.account_summary_tags import AccountSummaryTags
from ibapi.client import EClient
from ibapi.contract import Contract
from ibapi.order import Order as IBKROrder
from ibapi.wrapper import EWrapper

from roboquant.account import Account, Position
from roboquant.event import Event
from roboquant.order import Order, OrderStatus
from .broker import Broker

logger = logging.getLogger(__name__)

try:
from ibapi.client import EClient
from ibapi.wrapper import EWrapper
from ibapi.contract import Contract
from ibapi.order import Order as IBKROrder
from ibapi.account_summary_tags import AccountSummaryTags
from ibapi import VERSION
assert VERSION["major"] == 10 and VERSION["minor"] == 19, "Wrong version of the IBAPI found"

assert VERSION["major"] == 10 and VERSION["minor"] == 19, "Wrong version of the IBAPI found"
except ImportError:
logger.fatal("Couldn't import IBAPI package, you need to install this")
pass
logger = logging.getLogger(__name__)


# noinspection PyPep8Naming
Expand Down Expand Up @@ -183,7 +179,7 @@ def place_orders(self, *orders: Order):
if order.id is None:
order.id = self.__api.get_next_order_id()
self.__api.orders[order.id] = order
ibkr_order = self._get_order(order)
ibkr_order = self.__get_order(order)
contract = self.contract_mapping.get(order.symbol) or self._get_default_contract(order.symbol)
self.__api.placeOrder(int(order.id), contract, ibkr_order)

Expand All @@ -201,7 +197,8 @@ def _get_default_contract(self, symbol: str) -> Contract:
c.exchange = "SMART" # use smart routing by default
return c

def _get_order(self, order: Order):
@staticmethod
def __get_order(order: Order):
o = IBKROrder()
o.action = "BUY" if order.is_buy else "SELL"
o.totalQuantity = abs(order.size)
Expand Down
6 changes: 5 additions & 1 deletion roboquant/feeds/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,8 @@
from .randomwalk import RandomWalk
from .sqllitefeed import SQLFeed
from .tiingofeed import TiingoLiveFeed, TiingoHistoricFeed
from .yahoofeed import YahooFeed

try:
from .yahoofeed import YahooFeed
except ImportError:
pass
14 changes: 7 additions & 7 deletions roboquant/feeds/sqllitefeed.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,32 +21,32 @@ class SQLFeed(Feed):
_sql_insert_row = "INSERT into prices VALUES(?,?,?,?,?,?,?,?)"
_sql_create_index = "CREATE INDEX IF NOT EXISTS date_idx ON prices(date)"

def __init__(self, file) -> None:
def __init__(self, db_file) -> None:
super().__init__()
self.file = file
self.db_file = db_file

def create_index(self):
con = sqlite3.connect(self.file)
con = sqlite3.connect(self.db_file)
con.execute(SQLFeed._sql_create_index)
con.commit()

def timeframe(self):
con = sqlite3.connect(self.file)
con = sqlite3.connect(self.db_file)
result = con.execute(SQLFeed._sql_select_timeframe).fetchall()
con.commit()
row = result[0]
tf = Timeframe.fromisoformat(row[0], row[1], True)
return tf

def symbols(self):
con = sqlite3.connect(self.file)
con = sqlite3.connect(self.db_file)
result = con.execute(SQLFeed._sql_select_symbols).fetchall()
con.commit()
symbols = [columns[0] for columns in result]
return symbols

def play(self, channel: EventChannel):
con = sqlite3.connect(self.file)
con = sqlite3.connect(self.db_file)
cur = con.cursor()
cnt = 0
t_old = ""
Expand Down Expand Up @@ -75,7 +75,7 @@ def play(self, channel: EventChannel):

def record(self, feed, timeframe=None, append=False):
"""Record another feed to this SQLite database"""
con = sqlite3.connect(self.file)
con = sqlite3.connect(self.db_file)
cur = con.cursor()

if not append:
Expand Down
8 changes: 2 additions & 6 deletions roboquant/feeds/yahoofeed.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
from array import array
from datetime import datetime, timezone

import yfinance

from roboquant.event import Candle
from roboquant.feeds.historicfeed import HistoricFeed

Expand All @@ -18,12 +20,6 @@ def __init__(self, *symbols: str, start_date="2010-01-01", end_date: str | None

columns = ["Open", "High", "Low", "Close", "Volume", "Adj Close"]

try:
import yfinance
except ImportError:
logger.fatal("Couldn't import yfinance package")
return

for symbol in symbols:
logger.debug("requesting symbol=%s", symbol)
df = yfinance.Ticker(symbol).history(
Expand Down
4 changes: 2 additions & 2 deletions tests/integration/test_yahoofeed.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

class TestYahooFeed(unittest.TestCase):

def test_yahoofeed(self):
def test_yahoo_feed(self):
feed = YahooFeed("MSFT", "JPM", start_date="2018-01-01", end_date="2020-01-01")
self.assertEqual(2, len(feed.symbols))
self.assertEqual({"MSFT", "JPM"}, set(feed.symbols))
Expand All @@ -16,7 +16,7 @@ def test_yahoofeed(self):

run_priceitem_feed(feed, ["MSFT", "JPM"], self)

def test_yahoofeed_wrong_symbol(self):
def test_yahoo_feed_wrong_symbol(self):
# expect some error logging due to parsing an invalid symbol
feed = YahooFeed("INVALID_TICKER_NAME", start_date="2010-01-01", end_date="2020-01-01")
self.assertEqual(0, len(feed.symbols))
Expand Down

0 comments on commit 4f8c8dc

Please sign in to comment.