Skip to content

Commit

Permalink
more flex strategy
Browse files Browse the repository at this point in the history
  • Loading branch information
jbaron committed Aug 10, 2024
1 parent 20383d2 commit 92cad21
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 8 deletions.
2 changes: 1 addition & 1 deletion roboquant/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
__version__ = "0.8.1"
__version__ = "0.8.2"

import logging

Expand Down
13 changes: 6 additions & 7 deletions roboquant/ml/strategies.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,12 +36,11 @@ def create_signals(self, event: Event) -> list[Signal]:
h.append(row)
if len(h) == h.maxlen:
x = np.asarray(h, dtype=self._dtype)
if signal := self.predict(x, event.time):
return [signal]
return self.predict(x, event.time)
return []

@abstractmethod
def predict(self, x: NDArray, time: datetime) -> Signal | None: ...
def predict(self, x: NDArray, time: datetime) -> list[Signal]: ...


class SequenceDataset(Dataset):
Expand Down Expand Up @@ -116,7 +115,7 @@ def __init__(
self.sell_pct = sell_pct
self.asset = asset

def predict(self, x, time) -> Signal | None:
def predict(self, x, time) -> list[Signal]:
x = torch.asarray(x)
x = torch.unsqueeze(x, dim=0) # add the batch dimension

Expand All @@ -131,10 +130,10 @@ def predict(self, x, time) -> Signal | None:

logger.info("prediction p=%s time=%s", p, time)
if p >= self.buy_pct:
return Signal.buy(self.asset)
return [Signal.buy(self.asset)]
if p <= self.sell_pct:
return Signal.sell(self.asset)
return None
return [Signal.sell(self.asset)]
return []

def _get_dataloaders(self, x, y, prediction: int, validation_split: float, batch_size: int):
# what is the border between train- and validation-data
Expand Down

0 comments on commit 92cad21

Please sign in to comment.