Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[wip] - support for shorting #450

Open
wants to merge 9 commits into
base: master
Choose a base branch
from
13 changes: 13 additions & 0 deletions tensortrade/core/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,19 @@ def __init__(self, balance: 'Quantity', size: 'Quantity', *args) -> None:
*args
)

# class ShortAgainstTheBoxException(Exception):
# """
# Raised when shorting against the box.
# More details: https://www.investopedia.com/terms/s/sellagainstthebox.asp

# Parameters
# ----------
# """
# def __init__(self, balance: 'Quantity', size: 'Quantity', *args) -> None:
# super().__init__(
# "Must sell the remaining {} balance first before shorting {}.".format(balance, size),
# *args
# )

# =============================================================================
# Trading Pair Exceptions
Expand Down
6 changes: 6 additions & 0 deletions tensortrade/env/default/actions.py
Original file line number Diff line number Diff line change
Expand Up @@ -369,6 +369,12 @@ def get_orders(self, action: int, portfolio: 'Portfolio') -> 'List[Order]':
wallet = portfolio.get_wallet(ep.exchange.id, instrument=instrument)

balance = wallet.balance.as_float()

if balance < 0 and side == TradeSide.SELL:
# ignore sells of short positions aa it doesn't make sense
# short positions are to be bought back, not sold again.
return []

size = (balance * proportion)
size = min(balance, size)
quantity = (size * instrument).quantize()
Expand Down
3 changes: 2 additions & 1 deletion tensortrade/env/generic/environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,8 +85,9 @@ def __init__(self,
gym.envs.register(
id='TensorTrade-v0',
max_episode_steps=max_episode_steps,
entry_point="gym.envs"
)
self.spec = gym.spec(env_id='TensorTrade-v0')
self.spec = gym.spec('TensorTrade-v0')

for c in self.components.values():
c.clock = self.clock
Expand Down
2 changes: 1 addition & 1 deletion tensortrade/oms/instruments/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from .exchange_pair import ExchangePair
from .instrument import *
from .quantity import Quantity
from .quantity import Quantity, NegativeQuantity
from .trading_pair import TradingPair

33 changes: 21 additions & 12 deletions tensortrade/oms/instruments/quantity.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,8 @@ class Quantity:
Raised if the `size` of the quantity being created is negative.
"""

def __init__(self, instrument: 'Instrument', size: Decimal, path_id: str = None):
if size < 0:
def __init__(self, instrument: 'Instrument', size: Decimal, path_id: str = None, allow_negative=False):
if size < 0 and not allow_negative:
if abs(size) > Decimal(10)**(-instrument.precision):
raise InvalidNegativeQuantity(float(size))
else:
Expand All @@ -60,6 +60,7 @@ def __init__(self, instrument: 'Instrument', size: Decimal, path_id: str = None)
self.instrument = instrument
self.size = size if isinstance(size, Decimal) else Decimal(size)
self.path_id = path_id
self.allow_negative = allow_negative

@property
def is_locked(self) -> bool:
Expand All @@ -79,7 +80,7 @@ def lock_for(self, path_id: str) -> "Quantity":
`Quantity`
A locked quantity for an order path.
"""
return Quantity(self.instrument, self.size, path_id)
return Quantity(self.instrument, self.size, path_id, self.allow_negative)

def convert(self, exchange_pair: "ExchangePair") -> "Quantity":
"""Converts the quantity into the value of another instrument based
Expand All @@ -102,7 +103,7 @@ def convert(self, exchange_pair: "ExchangePair") -> "Quantity":
else:
instrument = exchange_pair.pair.base
converted_size = self.size * exchange_pair.price
return Quantity(instrument, converted_size, self.path_id)
return Quantity(instrument, converted_size, self.path_id, self.allow_negative)

def free(self) -> "Quantity":
"""Gets the free version of this quantity.
Expand All @@ -125,7 +126,8 @@ def quantize(self) -> "Quantity":
"""
return Quantity(self.instrument,
self.size.quantize(Decimal(10)**-self.instrument.precision),
self.path_id)
self.path_id,
self.allow_negative)

def as_float(self) -> float:
"""Gets the size as a `float`.
Expand Down Expand Up @@ -157,16 +159,16 @@ def contain(self, exchange_pair: "ExchangePair"):

if exchange_pair.pair.base == self.instrument:
size = self.size
return Quantity(self.instrument, min(size, options.max_trade_size), self.path_id)
return Quantity(self.instrument, min(size, options.max_trade_size), self.path_id, self.allow_negative)

size = self.size * price
if size < options.max_trade_size:
return Quantity(self.instrument, self.size, self.path_id)
if abs(size) < options.max_trade_size:
return Quantity(self.instrument, self.size, self.path_id, self.allow_negative)

max_trade_size = Decimal(options.max_trade_size)
contained_size = max_trade_size / price
contained_size = contained_size.quantize(Decimal(10)**-self.instrument.precision, rounding=ROUND_DOWN)
return Quantity(self.instrument, contained_size, self.path_id)
return Quantity(self.instrument, contained_size, self.path_id, self.allow_negative)

@staticmethod
def validate(left: "Union[Quantity, Number]",
Expand Down Expand Up @@ -216,11 +218,11 @@ def validate(left: "Union[Quantity, Number]",
return left, right

elif isinstance(left, Number) and isinstance(right, Quantity):
left = Quantity(right.instrument, left, right.path_id)
left = Quantity(right.instrument, left, right.path_id, right.allow_negative)
return left, right

elif isinstance(left, Quantity) and isinstance(right, Number):
right = Quantity(left.instrument, right, left.path_id)
right = Quantity(left.instrument, right, left.path_id, left.allow_negative)
return left, right

elif isinstance(left, Quantity):
Expand Down Expand Up @@ -277,7 +279,7 @@ def _math_op(left: "Union[Quantity, Number]",
"""
left, right = Quantity.validate(left, right)
size = op(left.size, right.size)
return Quantity(left.instrument, size, left.path_id)
return Quantity(left.instrument, size, left.path_id, left.allow_negative)

def __add__(self, other: "Quantity") -> "Quantity":
return Quantity._math_op(self, other, operator.add)
Expand Down Expand Up @@ -319,3 +321,10 @@ def __repr__(self) -> str:



@total_ordering
class NegativeQuantity(Quantity):
def __init__(self, instrument: 'Instrument', size: Union[Decimal, Number], path_id: str = None):
super().__init__(instrument, size, path_id, True)

def to_positive_quantity(self):
return Quantity(self.instrument, self.size, self.path_id, False)
2 changes: 1 addition & 1 deletion tensortrade/oms/wallets/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from .wallet import Wallet
from .wallet import Wallet, MarginWallet
from .portfolio import Portfolio


Expand Down
96 changes: 90 additions & 6 deletions tensortrade/oms/wallets/wallet.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License

from typing import Dict, Tuple
from typing import Dict, Tuple, Union
from collections import namedtuple
from decimal import Decimal
from numbers import Number

import numpy as np

Expand All @@ -25,7 +26,8 @@
DoubleUnlockedQuantity,
QuantityNotLocked
)
from tensortrade.oms.instruments import Instrument, Quantity, ExchangePair
from tensortrade.oms.instruments import Instrument, Quantity, ExchangePair, instrument
from tensortrade.oms.instruments.quantity import NegativeQuantity
from tensortrade.oms.orders import Order
from tensortrade.oms.exchanges import Exchange
from tensortrade.oms.wallets.ledger import Ledger
Expand Down Expand Up @@ -54,10 +56,13 @@ def __init__(self, exchange: 'Exchange', balance: 'Quantity'):
self.balance = balance.quantize()
self._locked = {}

def _get_quantity(self, instrument: 'Instrument', size: Union[Decimal, Number], path_id: str = None, allow_negative=False) -> 'Quantity':
return Quantity(instrument=instrument, size=size, path_id=path_id, allow_negative=allow_negative)

@property
def locked_balance(self) -> 'Quantity':
"""The total balance of the wallet locked in orders. (`Quantity`, read-only)"""
locked_balance = Quantity(self.instrument, 0)
locked_balance = self._get_quantity(self.instrument, 0)

for quantity in self._locked.values():
locked_balance += quantity.size
Expand Down Expand Up @@ -107,7 +112,7 @@ def lock(self, quantity, order: 'Order', reason: str) -> 'Quantity':
raise DoubleLockedQuantity(quantity)

if quantity > self.balance:
if (quantity-self.balance)>Decimal(10)**(-self.instrument.precision+2):
if (quantity-self.balance)>Decimal(10)**(-(self.instrument.precision+2)):
raise InsufficientFunds(self.balance, quantity)
else:
quantity = self.balance
Expand Down Expand Up @@ -332,7 +337,7 @@ def transfer(source: 'Wallet',
instrument = exchange_pair.pair.base
converted_size = quantity.size * exchange_pair.price

converted = Quantity(instrument, converted_size, quantity.path_id).quantize()
converted = target._get_quantity(instrument, converted_size, quantity.path_id).quantize()

converted = target.deposit(converted, 'TRADED {} {} @ {}'.format(quantity,
exchange_pair,
Expand Down Expand Up @@ -366,11 +371,90 @@ def transfer(source: 'Wallet',

def reset(self) -> None:
"""Resets the wallet."""
self.balance = Quantity(self.instrument, self._initial_size).quantize()
self.balance = self._get_quantity(self.instrument, self._initial_size).quantize()
self._locked = {}

def __str__(self) -> str:
return '<Wallet: balance={}, locked={}>'.format(self.balance, self.locked_balance)

def __repr__(self) -> str:
return str(self)

class ShortAgainstTheBoxException(Exception):
"""
Raised when shorting against the box.
More details: https://www.investopedia.com/terms/s/sellagainstthebox.asp

Parameters
----------
"""
def __init__(self, balance: 'Quantity', size: 'Quantity', *args) -> None:
super().__init__(
"Must sell the remaining {} balance first before shorting {}.".format(balance, size),
*args
)

class MarginWallet(Wallet):
def __init__(self, exchange: 'Exchange', balance: 'Quantity'):
super().__init__(exchange, balance)

def _get_quantity(self, instrument: 'Instrument', size: Union[Decimal, Number], path_id: str = None) -> 'Quantity':
return super()._get_quantity(instrument, size, path_id, allow_negative=True)

def lock(self, quantity, order: 'Order', reason: str) -> 'Quantity':
try:
return super().lock(quantity, order, reason)
except InsufficientFunds:
if self.balance > 0:
raise ShortAgainstTheBoxException(self.balance, quantity)

quantity = NegativeQuantity(quantity.instrument, quantity.size, quantity.path_id)
self.balance = NegativeQuantity(self.balance.instrument, self.balance.size, self.balance.path_id)
self.balance -= quantity

quantity = quantity.lock_for(order.path_id)

if quantity.path_id not in self._locked:
self._locked[quantity.path_id] = quantity
else:
self._locked[quantity.path_id] += quantity

self._locked[quantity.path_id] = self._locked[quantity.path_id].quantize()
self.balance = self.balance.quantize()

self.ledger.commit(wallet=self,
quantity=quantity,
source="{}:{}/free".format(self.exchange.name, self.instrument),
target="{}:{}/locked".format(self.exchange.name, self.instrument),
memo="LOCK ({})".format(reason))

return quantity

def withdraw(self, quantity: 'Quantity', reason: str) -> 'Quantity':
try:
return super().withdraw(quantity, reason)
except InsufficientFunds:
# try to move past this. as long as the balance is 0 and nothing is locked in an order, allow shorting
locked_quantity = 0 if quantity.path_id == None else self._locked[quantity.path_id]
if self.balance > 0 or locked_quantity > 0:
raise ShortAgainstTheBoxException(self.balance, quantity)
quantity = NegativeQuantity(quantity.instrument, quantity.size, quantity.path_id)
self.balance = NegativeQuantity(self.balance.instrument, self.balance.size, self.balance.path_id)
self.balance -= quantity
self.balance = self.balance.quantize()

self.ledger.commit(wallet=self,
quantity=quantity,
source="{}:{}/locked".format(self.exchange.name, self.instrument),
target=self.exchange.name,
memo="WITHDRAWAL ({})".format(reason))
return quantity

def deposit(self, quantity: 'Quantity', reason: str) -> 'Quantity':
qty: Quantity = super().deposit(quantity, reason)

# if type(self.balance) is NegativeQuantity and self.total_balance >= 0:
# self.balance = NegativeQuantity(self.balance).to_positive_quantity()

return qty

Loading