diff --git a/tianshou/policy/modelfree/ppo.py b/tianshou/policy/modelfree/ppo.py index 196cd72e4..1933c7d54 100644 --- a/tianshou/policy/modelfree/ppo.py +++ b/tianshou/policy/modelfree/ppo.py @@ -137,8 +137,11 @@ def process_fn( self._buffer, self._indices = buffer, indices batch = self._compute_returns(batch, buffer, indices) batch.act = to_torch_as(batch.act, batch.v_s) + logp_old = [] with torch.no_grad(): - batch.logp_old = self(batch).dist.log_prob(batch.act) + for minibatch in batch.split(self.max_batchsize, shuffle=False, merge_last=True): + logp_old.append(self(minibatch).dist.log_prob(minibatch.act)) + batch.logp_old = torch.cat(logp_old, dim=0).flatten() batch: LogpOldProtocol return batch