[WIP] add multiagentRNN #1948
Conversation
Hi @vmoens, this draft PR adds a multi-agent RNN as you proposed in #1921 (comment). I have one blocking issue: the LSTMNet in TorchRL returns a tuple of (output, hidden_states...), and I am not sure what the best solution is here. Override the forward call? Add an MLP to the end of MultiAgentRNN (as that's the most typical use case)? To be honest, I'm also not fully understanding how the LSTMNet gets the previous hidden state as input. I would appreciate any clarity you could provide!
I did a bunch of fixes for this PR. We have an arg that indicates the agent dim for the input, but we should add one for the output too. Like you were mentioning, we must also allow the MARL modules to accept more than one tensor as input (e.g. in this case the input will be either a Tensor or a tuple of Tensors).
I think I got it working, but I cannot guarantee that it does what it's supposed to do.
So it's pretty much all the same, except that there's no agent dim in the hidden state when it's shared and centralized, since we flatten/expand the input/output and pass it through the net as in non-MARL cases. @matteobettini, looking for feedback here! :) I updated the gist above.
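The flatten/expand convention for the centralized, shared-params case can be sketched in plain PyTorch (shapes and names here are illustrative, not the actual MultiAgentNetBase code):

```python
import torch
import torch.nn as nn

# Illustrative sketch of the centralized + share_params path described above:
# the agent dim is flattened into the feature dim on the way in, and the
# single net's output is expanded back over agents on the way out.
B, n_agents, n_features, n_out = 8, 3, 4, 5
x = torch.randn(B, n_agents, n_features)

net = nn.Linear(n_agents * n_features, n_out)  # one shared, centralized net

flat = x.flatten(-2, -1)                         # (B, n_agents * n_features)
y = net(flat)                                    # (B, n_out): no agent dim here
y = y.unsqueeze(-2).expand(B, n_agents, n_out)   # re-create the agent dim

print(y.shape)  # torch.Size([8, 3, 5])
```

Note that after the expand, all agents share the same output tensor, which is why there is no per-agent hidden state in this configuration.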
Looks good. It would be useful to have more clarity (through code comments) in the core calls in forward. The classes have become hard to read for me.
torchrl/modules/models/multiagent.py
Outdated

```python
num_inputs: int = 1,
num_outputs: int = 1,
```
By these two, do you mean the cardinality of the inputs and outputs? This is confusing alongside n_agent_inputs and n_agent_outputs. I suggest calling them something that contains "cardinality", like input_cardinality.
Do the names input_tuple_len and output_tuple_len make sense to you? The purpose of these arguments is to allow a multi-agent net to take in and output a tuple. In the LSTM case it's a 2-len tuple (input, (h_x, c_x)), and for GRU it's a 2-len tuple (input, h_x).
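For reference, the tuple structures being mirrored here are those of the stock PyTorch recurrent modules:

```python
import torch
import torch.nn as nn

# nn.LSTM returns (output, (h_n, c_n)); nn.GRU returns (output, h_n).
# Both are 2-tuples, but the LSTM's hidden state is itself a pair,
# which is what input_tuple_len / output_tuple_len must describe.
x = torch.randn(5, 3, 10)  # (seq_len, batch, input_size)

lstm = nn.LSTM(input_size=10, hidden_size=20)
out_l, (h_n, c_n) = lstm(x)   # 2-tuple; second element is a pair

gru = nn.GRU(input_size=10, hidden_size=20)
out_g, h_g = gru(x)           # 2-tuple; second element is a single tensor

print(out_l.shape, h_n.shape, c_n.shape)  # (5, 3, 20) (1, 3, 20) (1, 3, 20)
print(out_g.shape, h_g.shape)             # (5, 3, 20) (1, 3, 20)
```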
You have a typo where sometimes it … Maybe it would be useful to declare the dimensions at the beginning of that tree. For the only case that is different (centralized=True, share_params=True), our convention was to expand the output to resemble the existence of the multi-agent dimension. Why did you decide not to do this for all outputs? It seems strange to me that we do it for y but not for the others. I would do all or nothing.
Another issue relates to how to store the hidden states in the grouping API. Imagine I have 2 agent groups with cardinalities N_A and N_B:

```
TensorDict(
    group_a: TensorDict(
        batch_size=(B, T, N_A)
    ),
    group_b: TensorDict(
        batch_size=(B, T, N_B)
    ),
    batch_size=(B, T)
)
```

Now, if both groups use LSTMs, x and y can go in their tds, but the hidden states are more problematic, as their shape makes it so that we cannot use that structuring.
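The shape problem can be seen directly with a stock nn.LSTM: its hidden states carry a leading num_layers dim rather than the (B, T, N) batch dims, so they cannot sit under a sub-TensorDict with batch_size=(B, T, N). A plain-torch sketch (shapes illustrative):

```python
import torch
import torch.nn as nn

# Why hidden states break the (B, T, N_agents) batching: nn.LSTM's h_n/c_n
# are shaped (num_layers, batch, hidden), with no time dim and the batch
# dim in a different position than the per-group tensordict expects.
B, T, N, F, H = 2, 7, 3, 4, 6

lstm = nn.LSTM(input_size=F, hidden_size=H, batch_first=True)

x = torch.randn(B * N, T, F)   # agents folded into the batch dim
y, (h_n, c_n) = lstm(x)

print(y.shape)    # (B*N, T, H): reshapable to (B, T, N, H)
print(h_n.shape)  # (1, B*N, H): leading dims are NOT (B, T, N)
```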
I did not "decide", it's just not possible. I fixed the diagram.
I agree, but that's the LSTM format, so I would not change that. The goal of these classes is to be used without tensordict, so tensordict formatting should not impact the data format. This class is supposed to be the MA version of nn.LSTM. If we want to use it with a tensordict, we can build a class similar to the LSTMNet we have for single agents.
Got it. So it is not possible to just say that when you have centralised=True, share=True, the hidden state should be indexed along the first element of the agent dim instead of flattened? Also, how difficult do you think it would be to add GRU to this PR as well? It has been found to work better than LSTM in RL, and since we already added the infrastructure, maybe it requires minimal additions.
@kfu02 went for LSTM, but GRU would be easier. We implemented something like that for indices in replay buffers and I wish we didn't! It's extremely hard to maintain.
I can switch the current code to use a GRU. I was not aware that LSTMs were outperformed by GRUs, thanks @matteobettini!
Now that we have LSTM working, I think we can do both!
At least, that is what we observed.
@vmoens, do you think we could reopen this on a new branch?
Sure, we should get this feature; I just didn't have the time to work on it yet.
I am keen on the feature, but I don't have bandwidth at this moment due to deadlines, and it seems @kfu02 won't work on this anymore. I think it just misses testing and docs and it should be good to go. My personal concern is also a loss in readability in the multi-agent modules after this change. We can reopen it in another PR, and then the first one of us that picks this up can work to ship it.
We can work on readability. This is an unmerged PR; the proposed changes have not been properly tested or documented yet, so I would not jump straight to discarding the changes because of readability issues before we get this piece to a mature stage. Having one parent class that accounts for all nets makes it easier to check that we have all the features covered across networks. I noticed multiple inconsistencies across MLP and CNN, so if we scale things up to other networks, we need to build an API that is consistent and testable for all.
I'm interested in reviving this feature for my application (TL;DR: MARL with small language models like GRUs), but I'm quite new to TorchRL and the conventions you follow, so please forgive any naivety. In general, I've been finding it quite difficult to understand everything happening with the code here. Handling of _in_dim seems very complicated, seeing as it completely depends on the network. So I was wondering if it'd be simpler to have a Factory class that accepts 1) a pre-built single network and 2) the 'in_dim' specification, and then just runs _pre_forward_checks to make sure the 'in_dim' matches the inputs and agent_dim. Example usage from a user's perspective:

```python
MultiAgentFactory(
    single_net=single_net,  # provide a pre-built network, which may already have been pre-trained
    n_agents=2,
    is_centralised=centralised,  # same as centralised, but clarifying that the network should already expect in_features * n_agents
    share_params=share_params,
    agent_dim=-2,
    map_dim=-2,  # same as in_dim, but I found this naming more intuitive from a user's perspective; also, for example, (-2,) or (-2, -2), etc.
    output_tuple_length=1,
)
```

This would:

I'm currently drafting something like this for myself, but I want to know if it'd work towards this feature in TorchRL or if I'm barking up the wrong tree. Alternative solution: the other alternative I would propose is to make _in_dim an abstract method to ensure it's defined for each network. However, this is more complicated for the user/developer.
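A minimal sketch of what the proposed check could look like (everything here is hypothetical, following the factory idea above, and not an existing TorchRL API):

```python
import torch

def pre_forward_checks(x, n_agents, agent_dim, in_features, centralised):
    # Hypothetical sketch of the validation a MultiAgentFactory-style
    # wrapper could run before forwarding: confirm the agent dim and
    # feature dim match what the wrapped single_net was built for.
    # (Names and semantics are assumptions, not TorchRL code.)
    if not centralised and x.shape[agent_dim] != n_agents:
        raise ValueError(
            f"expected agent dim of size {n_agents}, got {x.shape[agent_dim]}"
        )
    # A centralised single_net is assumed to already expect the flattened
    # in_features * n_agents, per the is_centralised comment above.
    expected = in_features * n_agents if centralised else in_features
    if x.shape[-1] != expected:
        raise ValueError(
            f"expected {expected} input features, got {x.shape[-1]}"
        )
    return x

x = torch.randn(8, 2, 4)  # (batch, n_agents, features)
pre_forward_checks(x, n_agents=2, agent_dim=-2, in_features=4, centralised=False)
```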
Description
Per #2003, adds multi-agent GRUs and LSTMs to TorchRL's multi-agent modules. Modifies the MultiAgentNetBase class to take in multiple input tensors and output tensors, which allows these recurrent multi-agent nets to input/output hidden states (e.g. (input, h_x)).

Test gist: https://gist.github.com/kfu02/87ae6c6d99e681d474f4977a9653b329
Motivation and Context
close #2003