Skip to content

Commit

Permalink
RandomActionPolicy: fix unpacking of act, state
Browse files Browse the repository at this point in the history
  • Loading branch information
Michael Panchenko committed Aug 14, 2024
1 parent c699e1d commit 02ab06a
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions tianshou/policy/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -711,8 +711,9 @@ def forward(
batch: ObsBatchProtocol,
state: dict | BatchProtocol | np.ndarray | None = None,
**kwargs: Any,
) -> ActBatchProtocol:
return cast(ActBatchProtocol, Batch(act=self.actor(batch.obs)))
) -> ActStateBatchProtocol:
act, next_state = self.actor(batch.obs, state)
return cast(ActStateBatchProtocol, Batch(act=act, state=next_state))

def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> TrainingStats:
return TrainingStats()
Expand Down

0 comments on commit 02ab06a

Please sign in to comment.