[Feature Request] Recurrent policies #40

Open
jamesheald opened this issue Apr 2, 2024 · 16 comments
Labels
enhancement (New feature or request), help wanted (Extra attention is needed)

Comments

@jamesheald
Contributor

There are recurrent (LSTM) policy options for SB3 (e.g. RecurrentPPO in SB3-Contrib). It would be great to have recurrent PPO implemented for sbx.

jamesheald added the enhancement (New feature or request) label Apr 2, 2024
@araffin
Owner

araffin commented Apr 2, 2024

Hello,
are you willing to contribute the implementation?

araffin added the help wanted (Extra attention is needed) label Apr 2, 2024
@jamesheald
Contributor Author

Yes, definitely. I'm new to the stable baselines environment and am currently implementing some more basic things for what I need, but once I'm at the stage of my project where it would be useful to have the recurrent version of PPO, I would be happy to implement this.

@corentinlger

Hello, I'd like to contribute to the implementation if possible.

Would you like it to share some code with the current implementation of PPO, or should I add a new recurrent PPO dir in https://github.com/araffin/sbx/tree/master/sbx?

@araffin
Owner

araffin commented Aug 2, 2024

Probably a new folder would be cleaner.

@corentinlger

Ok, I'll try that!

@jamesheald
Contributor Author

@corentinlger can I help?

@corentinlger

Hello, thanks for asking, because I indeed had a few questions about this implementation. I first started by looking at the LSTM section of this blog about PPO in order to better understand the specificities of the algorithm. In that implementation, the author adds an LSTM component to the agent without changing the actor and the critic networks, and only saves the first LSTM state at the beginning of a rollout. He then uses it to reconstruct the probability distributions (used during the rollouts) in the network updates (see ppo_atari_lstm.py from CleanRL).
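
If I translate that idea into a JAX-style sketch (placeholder names, not actual CleanRL or sbx code), it amounts to saving the recurrent state once at the start of the rollout and then, during the update, replaying the stored observations and dones through the LSTM to recover the per-step hidden states (and thus the distributions):

import jax
import jax.numpy as jnp

def recompute_hidden_states(lstm_step, params, initial_lstm_state, obs_seq, dones_seq):
    # obs_seq: (n_steps, n_envs, obs_dim), dones_seq: (n_steps, n_envs)
    # `lstm_step` is a placeholder for a one-step recurrent function: (params, carry, obs) -> (carry, hidden)
    def step(carry, inputs):
        obs, done = inputs
        # zero the carry for envs whose episode ended at the previous step
        carry = jax.tree_util.tree_map(lambda h: jnp.where(done[:, None], 0.0, h), carry)
        carry, hidden = lstm_step(params, carry, obs)
        return carry, hidden

    _, hidden_states = jax.lax.scan(step, initial_lstm_state, (obs_seq, dones_seq))
    return hidden_states  # (n_steps, n_envs, hidden_dim): the same states seen during the rollout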

Then I looked at the implementation of RecurrentPPO in SB3-Contrib, which does something slightly different for the networks of the agent. Here, both the actor and the critic incorporate an LSTM component (see sb3_contrib/common/recurrent/policies.py).

But the major difference is that all the LSTM states obtained during the rollouts are added to a buffer (see sb3_contrib/ppo_recurrent/ppo_recurrent.py). This buffer seems to implement a padding-and-masking mechanism, which ensures that episodes of varying lengths are padded to the same length within the buffer while still indicating where an episode ends. The mask is then used in the train function to take episode endings into account when updating the networks, but otherwise these updates seem pretty similar to the ones in vanilla PPO from SB3. For the implementation of the rollout buffer in JAX, there was this issue suggesting it might be easier to handle rollout data sequentially (instead of using the mask and padding mechanisms) and just jit the whole function (which should work because the rollout data will always have the same shape).
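
To make the masking part concrete, my understanding is that the loss is computed roughly like this (illustrative JAX sketch, not the actual SB3-Contrib code):

import jax.numpy as jnp

def masked_mean(values, mask):
    # mask is 1.0 for real transitions and 0.0 for the padding added to equalize sequence lengths
    return (values * mask).sum() / jnp.maximum(mask.sum(), 1.0)

def clipped_policy_loss(log_prob, old_log_prob, advantages, mask, clip_range=0.2):
    # same clipped objective as vanilla PPO, but averaged only over the non-padded steps
    ratio = jnp.exp(log_prob - old_log_prob)
    loss_1 = advantages * ratio
    loss_2 = advantages * jnp.clip(ratio, 1.0 - clip_range, 1.0 + clip_range)
    return -masked_mean(jnp.minimum(loss_1, loss_2), mask)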

Additionally, I am not sure I understand this part of the file:

single_hidden_state_shape = (lstm.num_layers, self.n_envs, lstm.hidden_size)
# hidden and cell states for actor and critic
self._last_lstm_states = RNNStates(
    (
        th.zeros(single_hidden_state_shape, device=self.device),
        th.zeros(single_hidden_state_shape, device=self.device),
    ),
    (
        th.zeros(single_hidden_state_shape, device=self.device),
        th.zeros(single_hidden_state_shape, device=self.device),
    ),
)

Why are both the hidden and cell states of shape (2, single_hidden_state_shape), compared to the CleanRL implementation where the shape of each is only (single_hidden_state_shape,)? Apart from that, am I missing something in this implementation of LSTM PPO (e.g. the gradient being computed in a different manner here compared to vanilla PPO)?

Sorry, I didn't think this comment was going to be that long...
But from what I understand, I should probably try to do something in the spirit of SB3-Contrib (but maybe with a simpler implementation of the buffer?). What do you think @araffin @jamesheald?

@araffin
Owner

araffin commented Aug 13, 2024

@corentinlger sorry, I was at the RL conference until today, let me try to answer in the coming days when I'm back ;)

In short: recurrent PPO in SB3 contrib is overly complex (and I'm not happy about it, so I would be glad if we can find a cleaner solution).

@corentinlger

Hello, no problem! I'll also try to think about a simpler solution (at least for a first minimal implementation).

@araffin
Owner

araffin commented Aug 28, 2024

Here, both the actor and the critic incorporate an LSTM component

Actually, there are different modes in SB3 contrib (shared LSTM, actor only, enable critic LSTM); the default one should be a separate LSTM for both the actor and the critic.
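
For reference, these modes map to policy kwargs, something like:

from sb3_contrib import RecurrentPPO

# default: separate LSTMs for the actor and the critic
model = RecurrentPPO(
    "MlpLstmPolicy",
    "CartPole-v1",
    policy_kwargs=dict(
        shared_lstm=False,        # actor and critic do not share the same LSTM
        enable_critic_lstm=True,  # the critic gets its own LSTM (False = actor-only LSTM)
    ),
)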

while still indicating where an episode ends.

correct

For the implementation of the rollout buffer in JAX, there was vwxyzjn/cleanrl#276 suggesting it might be easier to handle rollout data

if we could get rid of the for loop, that would be nice, but we need to be extra careful (Stable-Baselines-Team/stable-baselines3-contrib#239)

Why are both the hidden and cell states of shape (2, single_hidden_state_shape)?

From the code comment: one for the actor and one for the critic. Or did you have another question?

@corentinlger

if we could get rid of the for loop, that would be nice, but we need to be extra careful (Stable-Baselines-Team/stable-baselines3-contrib#239)

Actually, purejaxrl implemented a ppo_rnn using this trick for the LSTM state reset (doing it in the RNN call instead of outside):

class ScannedRNN(nn.Module):
    @nn.compact
    def __call__(self, carry, x):
        """Applies the module."""
        rnn_state = carry
        ins, resets = x
        rnn_state = jnp.where(
            resets[:, np.newaxis],  # if reset flag...
            self.initialize_carry(ins.shape[0], ins.shape[1]),  # ...initialize the carry = create a new lstm state
            rnn_state,  # else keep the current rnn state
        )
        new_rnn_state, y = nn.GRUCell()(rnn_state, ins)
        return new_rnn_state, y

It seems to simplify the code a bit and avoids using a for loop, but it requires giving the observations and the done flags to the network. What do you think of this solution?

From the code comment: one for the actor and one for the critic. Or did you have another question?

Indeed, this makes complete sense now that you say it.

@araffin
Owner

araffin commented Aug 31, 2024

giving the observations and the done flags to the network. What do you think of this solution?

we need to do that anyway, no?

@corentinlger

Sorry @araffin, I didn't have a computer for two weeks.

we need to do that anyway, no?

Indeed

I implemented a first version of the policy.py file, and I'm now trying to implement a working version of ppo_recurrent.py (updating the buffer, and the _setup_model and collect_rollouts functions).
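
Roughly, the rollout loop I have in mind looks like this (hypothetical names, not the actual sbx API):

import numpy as np

def collect_rollouts(env, policy, rollout_buffer, lstm_states, n_steps, n_envs):
    # hypothetical sketch: `policy.predict_all` and `rollout_buffer.add` are placeholder names
    last_obs = env.reset()
    last_dones = np.zeros((n_envs,), dtype=bool)
    for _ in range(n_steps):
        # the policy receives the obs, the previous dones and the recurrent state;
        # resetting the state on episode end happens inside the network (the purejaxrl trick)
        actions, values, log_probs, lstm_states = policy.predict_all(last_obs, last_dones, lstm_states)
        rollout_buffer.add(last_obs, actions, values, log_probs, last_dones)
        last_obs, rewards, last_dones, infos = env.step(actions)
    return lstm_states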

@corentinlger

Hello, I have a first version of the algorithm that runs but doesn't learn yet. Is it possible to open a draft PR so you can review it? I'd like to know if it's on the right track before I dive deeper into the details to get it fully functional.

@araffin
Owner

araffin commented Oct 18, 2024

yes you can, but I won't have time to review it for one or two weeks.

@corentinlger

Ok, thanks! I will try to keep improving it in the meantime.
