[Feature Request] Recurrent policies #40
Comments
Hello,
Yes, definitely. I'm new to the Stable Baselines ecosystem and am currently implementing some more basic things for what I need, but once my project reaches the stage where a recurrent version of PPO would be useful, I would be happy to implement this.
Hello, I'd like to contribute to the implementation if possible. Should the code be added to the current implementation of PPO, or should I add a new recurrentppo dir in https://github.com/araffin/sbx/tree/master/sbx?
Probably a new folder would be cleaner.
Ok, I'll try that!
@corentinlger can I help?
Hello, thanks for asking, because I indeed had a few questions about this implementation.

I first started by looking at the LSTM section of this blog post about PPO to better understand the specifics of the algorithm. In that implementation, the author adds an LSTM component to the agent without changing the actor and critic networks, and only saves the LSTM state from the beginning of a rollout. He then uses it to reconstruct the probability distributions (used during the rollouts) in the network updates (see ppo_atari_lstm.py from CleanRL).

Then I looked at the implementation of LSTM PPO in SB3-Contrib, which does something slightly different for the agent's networks: here the actor and the critic each incorporate an LSTM component (see sb3_contrib/common/recurrent/policies.py). But the major difference is that all the LSTM states obtained during the rollouts are added to a buffer (see sb3_contrib/ppo_recurrent/ppo_recurrent.py). This buffer seems to implement a mechanism for padding sequences with a mask: episodes of varying lengths are padded to the same length within the buffer, while the mask still indicates where each episode ends. The mask is then used when computing the losses during the network updates.

Additionally, I am not sure I understand this part of the file:

single_hidden_state_shape = (lstm.num_layers, self.n_envs, lstm.hidden_size)
# hidden and cell states for actor and critic
self._last_lstm_states = RNNStates(
    (
        th.zeros(single_hidden_state_shape, device=self.device),
        th.zeros(single_hidden_state_shape, device=self.device),
    ),
    (
        th.zeros(single_hidden_state_shape, device=self.device),
        th.zeros(single_hidden_state_shape, device=self.device),
    ),
)

Why are both the hidden and the cell state of shape (2, single_hidden_state_shape), compared to the CleanRL implementation where the shape of each is only (single_hidden_state_shape,)? Apart from that, am I missing something in this implementation of LSTM PPO (e.g. the gradient being computed in a different manner here compared to vanilla PPO)?

Sorry, I didn't think this comment was going to be that long...
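For illustration, here is a minimal JAX sketch of the padding-with-mask idea described above (my own toy example, not the sb3-contrib code): episodes of different lengths are padded to a common length and a boolean mask records which steps are real, so padded steps can be excluded from the loss.

import jax.numpy as jnp

def pad_episodes(episodes, pad_value=0.0):
    # Pad a list of per-step 1D arrays (e.g. advantages) to a common length.
    max_len = max(len(ep) for ep in episodes)
    padded = jnp.stack([
        jnp.pad(jnp.asarray(ep), (0, max_len - len(ep)), constant_values=pad_value)
        for ep in episodes
    ])
    # Boolean mask: True for real steps, False for padding.
    mask = jnp.stack([jnp.arange(max_len) < len(ep) for ep in episodes])
    return padded, mask

def masked_mean(values, mask):
    # Average only over real (non-padded) steps.
    return (values * mask).sum() / jnp.maximum(mask.sum(), 1)

padded_adv, mask = pad_episodes([[1.0, 2.0, 3.0], [4.0, 5.0]])
loss_contrib = masked_mean(padded_adv, mask)  # padding does not bias the average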
@corentinlger sorry, I was at the RL conference until today; let me try to answer in the coming days when I'm back ;) In short: recurrent PPO in SB3 contrib is overly complex (and I'm not happy about it, so I would be glad if we can find a cleaner solution).
Hello, no problem! I'll also try to think about a simpler solution (at least for the first minimal implementation).
Actually, there are different modes in SB3 contrib (shared, actor only, enable critic LSTM); the default one should be a separate LSTM for both actor and critic.
Correct.
If we could get rid of the for loop, that would be nice, but we need to be extra careful (Stable-Baselines-Team/stable-baselines3-contrib#239).
From the comment: one pair is for the actor and one for the critic. Or did you have another question?
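Concretely, a small JAX-flavored sketch (mirroring the RNNStates structure, not code from either repo): with separate LSTMs for actor and critic, four state tensors are stored in total, a (hidden, cell) pair for each network.

from typing import NamedTuple, Tuple
import jax.numpy as jnp

class RNNStates(NamedTuple):
    pi: Tuple[jnp.ndarray, jnp.ndarray]  # (hidden, cell) of the actor LSTM
    vf: Tuple[jnp.ndarray, jnp.ndarray]  # (hidden, cell) of the critic LSTM

num_layers, n_envs, hidden_size = 1, 8, 64
single_hidden_state_shape = (num_layers, n_envs, hidden_size)
zero_state = lambda: jnp.zeros(single_hidden_state_shape)

# Four tensors in total: (hidden, cell) for the actor plus (hidden, cell) for the critic,
# which is why the stored states look "doubled" compared to a single shared LSTM.
last_lstm_states = RNNStates(pi=(zero_state(), zero_state()), vf=(zero_state(), zero_state()))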
Actually, purejaxrl implemented a ppo_rnn using this trick for the LSTM state resets (doing it inside the RNN call instead of outside):

@nn.compact
def __call__(self, carry, x):
    """Applies the module."""
    rnn_state = carry
    ins, resets = x
    rnn_state = jnp.where(
        resets[:, np.newaxis],  # if reset flag
        self.initialize_carry(ins.shape[0], ins.shape[1]),  # initialize carry = create a new lstm state
        rnn_state,  # else keep the current rnn state
    )
    new_rnn_state, y = nn.GRUCell()(rnn_state, ins)
    return new_rnn_state, y

It seems to simplify the code a bit and avoids using a for loop, but it requires giving the observations and the done flags to the network. What do you think of this solution?
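Expanding on that trick, here is a minimal pure-JAX sketch (my own toy example, with a simple tanh recurrence standing in for the GRU/LSTM cell) showing how resetting the carry via jnp.where inside the step function lets the whole rollout be processed with lax.scan instead of a Python for loop:

import jax
import jax.numpy as jnp

def rnn_step(params, carry, inputs):
    obs, done = inputs
    # Reset the carry for environments whose episode just ended (same jnp.where trick).
    carry = jnp.where(done[:, None], jnp.zeros_like(carry), carry)
    # Toy recurrence standing in for a GRU/LSTM cell.
    new_carry = jnp.tanh(obs @ params["w_in"] + carry @ params["w_rec"])
    return new_carry, new_carry  # (next carry, per-step output)

def unroll(params, init_carry, obs_seq, done_seq):
    # obs_seq: (T, n_envs, obs_dim), done_seq: (T, n_envs) booleans
    step = lambda carry, x: rnn_step(params, carry, x)
    final_carry, outputs = jax.lax.scan(step, init_carry, (obs_seq, done_seq))
    return final_carry, outputs  # outputs: (T, n_envs, hidden)

key = jax.random.PRNGKey(0)
T, n_envs, obs_dim, hidden = 5, 2, 3, 8
params = {
    "w_in": 0.1 * jax.random.normal(key, (obs_dim, hidden)),
    "w_rec": 0.1 * jax.random.normal(key, (hidden, hidden)),
}
obs_seq = jax.random.normal(key, (T, n_envs, obs_dim))
done_seq = jnp.zeros((T, n_envs), dtype=bool).at[2, 0].set(True)  # env 0 ends at step 2
final_carry, outputs = unroll(params, jnp.zeros((n_envs, hidden)), obs_seq, done_seq)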
Indeed, this makes complete sense now that you say it.
We need to do that anyway, no?
Sorry @araffin, I didn't have a computer for two weeks.
Indeed, I implemented a first version of the
Hello, I have a first version of the algorithm that runs but doesn't learn yet. Would it be possible to open a draft PR so you can review it? I'd like to know if it's on the right track before I dive deeper into the details to get it fully functional.
Yes you can, but I won't have time to review it for one or two weeks.
Ok, thanks! I will try to keep improving it in the meantime.
There are recurrent (LSTM) policy options for SB3 (e.g. RecurrentPPO in sb3-contrib). It would be great to have recurrent PPO implemented for sbx.
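For reference, the existing PyTorch-based implementation in sb3-contrib is used roughly like the snippet below; the goal here would be an equivalent on the sbx/JAX side.

from sb3_contrib import RecurrentPPO

model = RecurrentPPO("MlpLstmPolicy", "CartPole-v1", verbose=1)
model.learn(total_timesteps=10_000)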