From 92cad218939ffe1a6805b2e9a1e5946dbc7da160 Mon Sep 17 00:00:00 2001 From: Peter Dekkers Date: Sat, 10 Aug 2024 11:48:23 +0200 Subject: [PATCH] more flex strategy --- roboquant/__init__.py | 2 +- roboquant/ml/strategies.py | 13 ++++++------- 2 files changed, 7 insertions(+), 8 deletions(-) diff --git a/roboquant/__init__.py b/roboquant/__init__.py index eaded7b..4ef78cc 100644 --- a/roboquant/__init__.py +++ b/roboquant/__init__.py @@ -1,4 +1,4 @@ -__version__ = "0.8.1" +__version__ = "0.8.2" import logging diff --git a/roboquant/ml/strategies.py b/roboquant/ml/strategies.py index 2bf895d..53f400a 100644 --- a/roboquant/ml/strategies.py +++ b/roboquant/ml/strategies.py @@ -36,12 +36,11 @@ def create_signals(self, event: Event) -> list[Signal]: h.append(row) if len(h) == h.maxlen: x = np.asarray(h, dtype=self._dtype) - if signal := self.predict(x, event.time): - return [signal] + return self.predict(x, event.time) return [] @abstractmethod - def predict(self, x: NDArray, time: datetime) -> Signal | None: ... + def predict(self, x: NDArray, time: datetime) -> list[Signal]: ... class SequenceDataset(Dataset): @@ -116,7 +115,7 @@ def __init__( self.sell_pct = sell_pct self.asset = asset - def predict(self, x, time) -> Signal | None: + def predict(self, x, time) -> list[Signal]: x = torch.asarray(x) x = torch.unsqueeze(x, dim=0) # add the batch dimension @@ -131,10 +130,10 @@ def predict(self, x, time) -> Signal | None: logger.info("prediction p=%s time=%s", p, time) if p >= self.buy_pct: - return Signal.buy(self.asset) + return [Signal.buy(self.asset)] if p <= self.sell_pct: - return Signal.sell(self.asset) - return None + return [Signal.sell(self.asset)] + return [] def _get_dataloaders(self, x, y, prediction: int, validation_split: float, batch_size: int): # what is the border between train- and validation-data