Skip to content

Commit

Permalink
Add reset
Browse files Browse the repository at this point in the history
  • Loading branch information
nkaz001 committed Apr 7, 2023
1 parent d17d249 commit b92bbcf
Show file tree
Hide file tree
Showing 10 changed files with 310 additions and 64 deletions.
181 changes: 117 additions & 64 deletions hftbacktest/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
'Stat',
'validate_data', 'correct_local_timestamp', 'correct_exch_timestamp', 'correct',)

__version__ = '1.4.2'
__version__ = '1.5.0'


def HftBacktest(
Expand All @@ -50,79 +50,31 @@ def HftBacktest(
):
cache = Cache()

if isinstance(data, pd.DataFrame):
if isinstance(data, list):
local_reader = DataReader(cache)
local_reader.add_data(data.to_numpy())

exch_reader = DataReader(cache)
exch_reader.add_data(data.to_numpy())
elif isinstance(data, np.ndarray):
local_reader = DataReader(cache)
local_reader.add_data(data)

exch_reader = DataReader(cache)
exch_reader.add_data(data)
for item in data:
if isinstance(item, str):
local_reader.add_file(item)
exch_reader.add_file(item)
elif isinstance(item, pd.DataFrame) or isinstance(item, np.ndarray):
local_reader.add_data(item)
exch_reader.add_data(item)
else:
raise ValueError('Unsupported data type')
elif isinstance(data, str):
local_reader = DataReader(cache)
local_reader.add_file(data)

exch_reader = DataReader(cache)
exch_reader.add_file(data)
elif isinstance(data, list):
local_reader = DataReader(cache)
exch_reader = DataReader(cache)
for filepath in data:
if isinstance(filepath, str):
local_reader.add_file(filepath)
exch_reader.add_file(filepath)
elif isinstance(filepath, pd.DataFrame) or isinstance(filepath, np.ndarray):
local_reader.add_data(filepath)
exch_reader.add_data(filepath)
else:
raise ValueError('Unsupported data type')
else:
raise ValueError('Unsupported data type')
data = __load_data(data)
local_reader = DataReader(cache)
local_reader.add_data(data)

if isinstance(snapshot, pd.DataFrame):
assert (snapshot.columns[:6] == [
'event',
'exch_timestamp',
'local_timestamp',
'side',
'price',
'qty'
]).all()
snapshot = snapshot.to_numpy()
elif isinstance(snapshot, np.ndarray):
assert snapshot.shape[1] >= 6
elif isinstance(snapshot, str):
if snapshot.endswith('.npy'):
snapshot = np.load(snapshot)
elif snapshot.endswith('.npz'):
tmp = np.load(snapshot)
if 'data' in tmp:
snapshot = tmp['data']
assert snapshot.shape[1] >= 6
else:
k = list(tmp.keys())[0]
print("Snapshot is loaded from %s instead of 'data'" % k)
snapshot = tmp[k]
assert snapshot.shape[1] >= 6
else:
df = pd.read_pickle(snapshot, compression='gzip')
assert (df.columns[:6] == [
'event',
'exch_timestamp',
'local_timestamp',
'side',
'price',
'qty'
]).all()
snapshot = df.to_numpy()
elif snapshot is None:
pass
else:
raise ValueError('Unsupported snapshot type')
exch_reader = DataReader(cache)
exch_reader.add_data(data)

if queue_model is None:
queue_model = RiskAverseQueueModel()
Expand All @@ -131,6 +83,7 @@ def HftBacktest(
exch_market_depth = MarketDepth(tick_size, lot_size)

if snapshot is not None:
snapshot = __load_data(snapshot)
local_market_depth.apply_snapshot(snapshot)
exch_market_depth.apply_snapshot(snapshot)

Expand Down Expand Up @@ -178,3 +131,103 @@ def HftBacktest(
)

return SingleAssetHftBacktest(local, exch)


def reset(
hbt,
data,
tick_size=None,
lot_size=None,
maker_fee=None,
taker_fee=None,
snapshot=None,
start_position=0,
start_balance=0,
start_fee=0,
trade_list_size=None,
):
cache = Cache()

if isinstance(data, list):
local_reader = DataReader(cache)
exch_reader = DataReader(cache)
for item in data:
if isinstance(item, str):
local_reader.add_file(item)
exch_reader.add_file(item)
elif isinstance(item, pd.DataFrame) or isinstance(item, np.ndarray):
local_reader.add_data(item)
exch_reader.add_data(item)
else:
raise ValueError('Unsupported data type')
elif isinstance(data, str):
local_reader = DataReader(cache)
local_reader.add_file(data)

exch_reader = DataReader(cache)
exch_reader.add_file(data)
else:
data = __load_data(data)
local_reader = DataReader(cache)
local_reader.add_data(data)

exch_reader = DataReader(cache)
exch_reader.add_data(data)

snapshot = __load_data(snapshot) if snapshot is not None else None

hbt.reset(
local_reader,
exch_reader,
start_position,
start_balance,
start_fee,
maker_fee,
taker_fee,
tick_size,
lot_size,
snapshot,
trade_list_size,
)


def __load_data(data):
if isinstance(data, pd.DataFrame):
assert (data.columns[:6] == [
'event',
'exch_timestamp',
'local_timestamp',
'side',
'price',
'qty'
]).all()
data = data.to_numpy()
elif isinstance(data, np.ndarray):
assert data.shape[1] >= 6
elif isinstance(data, str):
if data.endswith('.npy'):
data = np.load(data)
elif data.endswith('.npz'):
tmp = np.load(data)
if 'data' in tmp:
data = tmp['data']
assert data.shape[1] >= 6
else:
k = list(tmp.keys())[0]
print("Data is loaded from %s instead of 'data'" % k)
data = tmp[k]
assert data.shape[1] >= 6
else:
df = pd.read_pickle(data, compression='gzip')
assert (df.columns[:6] == [
'event',
'exch_timestamp',
'local_timestamp',
'side',
'price',
'qty'
]).all()
data = df.to_numpy()
else:
raise ValueError('Unsupported data type')
return data
40 changes: 40 additions & 0 deletions hftbacktest/backtest.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,46 @@ def goto(self, timestamp, wait_order_response=WAIT_ORDER_RESPONSE_NONE):
return False
return True

def reset(
self,
local_reader,
exch_reader,
start_position,
start_balance,
start_fee,
maker_fee,
taker_fee,
tick_size,
lot_size,
snapshot,
trade_list_size,
):
self.local.reader = local_reader
self.exch.reader = exch_reader

self.local.reset(
start_position,
start_balance,
start_fee,
maker_fee,
taker_fee,
tick_size,
lot_size,
snapshot,
trade_list_size,
)
self.exch.reset(
start_position,
start_balance,
start_fee,
maker_fee,
taker_fee,
tick_size,
lot_size,
snapshot
)
self.current_timestamp = self.local.next_data[0, COL_LOCAL_TIMESTAMP]
self.run = True

def SingleAssetHftBacktest(local, exch):
jitted = jitclass(spec=[
Expand Down
16 changes: 16 additions & 0 deletions hftbacktest/models/latencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@ def entry(self, timestamp, order, proc):
def response(self, timestamp, order, proc):
return self.response_latency

def reset(self):
pass


@jitclass
class FeedLatency:
Expand Down Expand Up @@ -71,6 +74,9 @@ def entry(self, timestamp, order, proc):
def response(self, timestamp, order, proc):
return self.response_latency + self.resp_latency_mul * self.__latency(proc)

def reset(self):
pass


@jitclass
class ForwardFeedLatency:
Expand Down Expand Up @@ -105,6 +111,9 @@ def entry(self, timestamp, order, proc):
def response(self, timestamp, order, proc):
return self.response_latency + self.resp_latency_mul * self.__latency(proc)

def reset(self):
pass


@jitclass
class BackwardFeedLatency:
Expand Down Expand Up @@ -139,6 +148,9 @@ def entry(self, timestamp, order, proc):
def response(self, timestamp, order, proc):
return self.response_latency + self.resp_latency_mul * self.__latency(proc)

def reset(self):
pass


@jitclass
class IntpOrderLatency:
Expand Down Expand Up @@ -199,3 +211,7 @@ def response(self, timestamp, order, proc):
lat2 = next_resp_local_timestamp - next_exch_timestamp
return self.__intp(timestamp, exch_timestamp, lat1, next_exch_timestamp, lat2)
raise ValueError

def reset(self):
self.entry_rn = 0
self.resp_rn = 0
6 changes: 6 additions & 0 deletions hftbacktest/models/queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,9 @@ def depth(self, order, prev_qty, new_qty, proc):
def is_filled(self, order, proc):
return round(order.q[0] / proc.lot_size) < 0

def reset(self):
pass


class ProbQueueModel:
def __init__(self):
Expand Down Expand Up @@ -69,6 +72,9 @@ def is_filled(self, order, proc):
def prob(self, front, back):
return np.divide(self.f(back), self.f(back) + self.f(front))

def reset(self):
pass


@jitclass
class LogProbQueueModel(ProbQueueModel):
Expand Down
5 changes: 5 additions & 0 deletions hftbacktest/order.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,11 @@ def get(self, order_id):
return recv_timestamp
raise KeyError

def reset(self):
self.order_list.clear()
self.orders.clear()
self.frontmost_timestamp = 0

def __getitem__(self, key):
return self.order_list[key]

Expand Down
29 changes: 29 additions & 0 deletions hftbacktest/proc/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,35 @@ def __init__(
self.last_trades = np.full((trade_list_size, self.data.shape[1]), np.nan, np.float64)
self.user_data = np.full((20, self.data.shape[1]), np.nan, np.float64)

def reset(
self,
start_position,
start_balance,
start_fee,
maker_fee,
taker_fee,
tick_size,
lot_size,
snapshot,
trade_list_size,
):
self._proc_reset(
start_position,
start_balance,
start_fee,
maker_fee,
taker_fee,
tick_size,
lot_size,
snapshot
)
self.trade_len = 0
if trade_list_size is not None:
self.last_trades = np.full((trade_list_size, self.data.shape[1]), np.nan, np.float64)
else:
self.last_trades[:, :] = np.nan
self.user_data[:, :] = np.nan

def _next_data_timestamp(self):
return self._next_data_timestamp_column(COL_LOCAL_TIMESTAMP)

Expand Down
Loading

0 comments on commit b92bbcf

Please sign in to comment.