
[WIP] add multiagentRNN #1948

Closed
wants to merge 8 commits into from

Conversation

@kfu02 commented Feb 22, 2024

Description

Per #2003, this PR adds multi-agent GRUs and LSTMs to TorchRL's multi-agent modules.

It modifies the MultiAgentNetBase class to take in multiple input and output tensors, which allows these recurrent multi-agent nets to accept and return hidden states (e.g. (input, h_x)).
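A hedged usage sketch of the intended behaviour (the class name, constructor arguments, and call convention below are assumptions based on this thread, not the final API):

import torch

net = MultiAgentLSTM(            # hypothetical recurrent multi-agent module from this PR
    input_size=8,
    hidden_size=16,
    n_agents=3,
    centralized=False,
    share_params=True,
)

x = torch.randn(32, 10, 3, 8)    # [batch, time, agents, features]
y, (h, c) = net(x)               # the output comes together with the hidden states
y, (h, c) = net(x, (h, c))       # ... and the hidden states can be fed back on the next call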

Test gist: https://gist.github.com/kfu02/87ae6c6d99e681d474f4977a9653b329

Motivation and Context


close #2003

  • I have raised an issue to propose this change (required for new features and bug fixes)

Types of changes

What types of changes does your code introduce? Remove all that do not apply:

  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds core functionality)
  • Breaking change (fix or feature that would cause existing functionality to change)
  • Documentation (update in the documentation)
  • Example (update in the folder of examples)

Checklist

Go over all the following points, and put an x in all the boxes that apply.
If you are unsure about any of these, don't hesitate to ask. We are here to help!

  • I have read the CONTRIBUTION guide (required)
  • My change requires a change to the documentation.
  • I have updated the tests accordingly (required for a bug fix or a new feature).
  • I have updated the documentation accordingly.

pytorch-bot commented Feb 22, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/rl/1948


@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Feb 22, 2024
@kfu02 kfu02 changed the title import multiagentrnn, start coding WIP add multiagentRNN Feb 22, 2024
@kfu02 (Author) commented Feb 22, 2024

Hi @vmoens, this draft PR adds multi-agent RNNs as you proposed in #1921 (comment).

I have one blocking issue: the LSTMNet in TorchRL returns a tuple of (output, hidden_states...), but the forward() call in MultiAgentNetBase expects the network from build_single_net() to output a single tensor (as the MLP and ConvNet both do).

I am not sure what the best solution is here. Override the forward call? Add an MLP to the end of MultiAgentRNN (as that's the most typical use case)? To be honest, I'm also not sure how the LSTMNet gets the previous hidden state as input in the first place.

I would appreciate any clarity you could provide!
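A minimal sketch of the mismatch described above, using plain torch.nn modules for illustration (the shapes are arbitrary; this is not the torchrl code itself):

import torch
import torch.nn as nn

mlp = nn.Linear(8, 4)
lstm = nn.LSTM(input_size=8, hidden_size=4, batch_first=True)

x = torch.randn(32, 10, 8)    # [batch, time, features]

y = mlp(x)                    # a single tensor: the kind of output the base class expects
y, (h_n, c_n) = lstm(x)       # an (output, (h_n, c_n)) tuple: cannot be stacked/gathered
                              # like a single tensor by the current forward()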

@kfu02 kfu02 changed the title WIP add multiagentRNN [WIP] add multiagentRNN Feb 22, 2024
@vmoens (Contributor) commented Feb 26, 2024

I did a bunch of fixes for this PR. You can test it with this gist:
https://gist.github.com/vmoens/c24c36b1efcbb159638dc0bf4cb12f15

We have an arg that indicates the agent dim for the input, but we should add one for the output too.
The only reason we assume it is -2 is that CNN and MLP both end with a linear layer, which has just one feature dim.

Like you were mentioning, we must also allow the MARL modules to accept more than one tensor as input (e.g. in this case the input will be either a Tensor or a (Tensor, Tuple[Tensor, Tensor]) pair). I will make a follow-up edit for that.

@vmoens (Contributor) commented Feb 26, 2024

I think I got it working, but I cannot guarantee that it does what it's supposed to do.
Here are the signatures you should expect:

centralized = True
  |- share_params = True
  |---- input
  |           |- x [Batch, Time, Agents, Features]
  |           |- hidden[0]  [Layers, Batch, Hidden]  <====== No agents
  |           |- hidden[1] [Layers, Batch, Hidden]
  |---- output
  |           |- y [Batch, Time, Agents, Hidden]
  |           |- hidden[0]  [Layers, Batch, Hidden]
  |           |- hidden[1]  [Layers, Batch, Hidden]
  |- share_params = False
  |---- input
  |           |- x  [Batch, Time, Agents, Features]
  |           |- hidden[0]  [Layers, Batch, Agents, Hidden]
  |           |- hidden[1]  [Layers, Batch, Agents, Hidden]
  |---- output
  |           |- y  [Batch, Time, Agents, Hidden]
  |           |- hidden[0]  [Layers, Batch, Agents, Hidden]
  |           |- hidden[1]  [Layers, Batch, Agents, Hidden]
centralized = False
  |- share_params = True
  |---- input
  |           |- x  [Batch, Time, Agents, Features]
  |           |- hidden[0]  [Layers, Batch, Agents, Hidden]
  |           |- hidden[1]  [Layers, Batch, Agents, Hidden]
  |---- output
  |           |- y  [Batch, Time, Agents, Hidden]
  |           |- hidden[0]  [Layers, Batch, Agents, Hidden]
  |           |- hidden[1]  [Layers, Batch, Agents, Hidden]
  |- share_params = False
  |---- input
  |           |- x  [Batch, Time, Agents, Features]
  |           |- hidden[0]  [Layers, Batch, Agents, Hidden]
  |           |- hidden[1]  [Layers, Batch, Agents, Hidden]
  |---- output
  |           |- y  [Batch, Time, Agents, Hidden]
  |           |- hidden[0]  [Layers, Batch, Agents, Hidden]
  |           |- hidden[1]  [Layers, Batch, Agents, Hidden]
  

So pretty much all the same, except that there's no Agent dim in the hidden state when it's shared and centralized, since we flatten/expand the input/output and pass it through the net as in the non-MARL case.

@matteobettini looking for feedback here! :)

I updated the gist above
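A standalone shape sketch of the centralized + shared case above, emulating the described flatten/expand behaviour with a plain nn.LSTM (this is an illustration of the shapes, not the actual MultiAgentNetBase code):

import torch
import torch.nn as nn

B, T, N, F, H, L = 32, 10, 3, 8, 16, 1   # batch, time, agents, features, hidden, layers

x = torch.randn(B, T, N, F)              # x: [Batch, Time, Agents, Features]
h0 = torch.zeros(L, B, H)                # hidden[0]: [Layers, Batch, Hidden]  <- no agent dim
c0 = torch.zeros(L, B, H)                # hidden[1]: [Layers, Batch, Hidden]

lstm = nn.LSTM(input_size=N * F, hidden_size=H, num_layers=L, batch_first=True)

y, (h1, c1) = lstm(x.flatten(-2, -1), (h0, c0))   # flatten agents into the feature dim
y = y.unsqueeze(-2).expand(B, T, N, H)            # expand the output to restore an agent dim

assert y.shape == (B, T, N, H)           # y: [Batch, Time, Agents, Hidden]
assert h1.shape == (L, B, H)             # the hidden states keep no agent dim in this case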

@vmoens vmoens added the enhancement New feature or request label Feb 26, 2024
@matteobettini (Contributor) left a comment

Looks good. It would be useful to have more clarity (through code comments) in the core calls in forward; the classes have become hard to read for me.

Comment on lines 35 to 36
num_inputs: int = 1,
num_outputs: int = 1,

With these two, do you mean the cardinality of the inputs and outputs?

This is confusing alongside n_agent_inputs and n_agent_outputs. I suggest calling them something that contains "cardinality", like input_cardinality.

@kfu02 (Author) replied:

Do the names input_tuple_len and output_tuple_len make sense to you? The purpose of these arguments is to allow a multi-agent net to take in and output a tuple. In the LSTM case it's a length-2 tuple (input, (h_x, c_x)), and for GRU it's a length-2 tuple (input, h_x).
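A small runnable illustration of that convention (the tensors below are arbitrary placeholders):

import torch

x = torch.randn(2, 3, 8)                 # some input
h_x = c_x = torch.zeros(1, 2, 16)        # some hidden states

lstm_io = (x, (h_x, c_x))                # LSTM: the second element is itself an (h, c) pair
gru_io = (x, h_x)                        # GRU: the second element is a single hidden state

assert len(lstm_io) == len(gru_io) == 2  # both are length-2 tuples under this naming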

@matteobettini (Contributor) commented Feb 27, 2024

> So pretty much all the same, except that there's no Agent dim in the hidden state when it's shared and centralized, since we flatten/expand the input/output and pass it through the net as in the non-MARL case.

You have a typo where sometimes Hidden is the last dim of y and sometimes it is Features. Also, is hidden == Hidden?

Maybe it would be useful to declare the dimensions at the beginning of that tree.

For the only case that is different (centralized=True, share_params=True), our convention was to expand the output to resemble the existence of the multi-agent dimension. Why did you decide not to do this for all outputs? It seems strange to me that we do it for y but not for the others. I would do all or nothing.

@matteobettini (Contributor) commented:

Another issue relates to how to store the hidden states in the grouping API.

Imagine I have 2 agent groups with cardinalities N_A and N_B; then I might have a td that looks like

TensorDict(
    {
        "group_a": TensorDict(..., batch_size=(B, T, N_A)),
        "group_b": TensorDict(..., batch_size=(B, T, N_B)),
    },
    batch_size=(B, T),
)

Now, if both groups use LSTMs, x and y can go in their tds, but the hidden states are more problematic, as their shapes mean we cannot use that structure for them.
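A hedged sketch of why this is problematic: a TensorDict requires each entry's leading dims to match its batch_size, so LSTM-formatted hidden states with a leading layers dim cannot be stored as-is in the group sub-tensordict (the shapes below are illustrative):

import torch
from tensordict import TensorDict

B, T, N_A, F, H, L = 4, 5, 2, 8, 16, 1

group_a = TensorDict(
    {"x": torch.randn(B, T, N_A, F), "y": torch.randn(B, T, N_A, H)},
    batch_size=(B, T, N_A),
)                                          # x and y fit the grouping convention

group_a["h0"] = torch.zeros(L, B, N_A, H)  # fails the batch-size check: the leading dims
                                           # are (Layers, B, N_A), not (B, T, N_A)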

@vmoens (Contributor) commented Feb 28, 2024

> Why did you decide not to do this for all outputs?

I did not "decide", it's just not possible.
If you expand it, you will feed back hidden states that are expanded, flatten them, and get even bigger hidden states (if that doesn't break - actually it will break because they're too big).
You can only expand the output you're not feeding back recursively.

I fixed the diagram.

@vmoens (Contributor) commented Feb 28, 2024

> Another issue relates to how to store the hidden states in the grouping API.

I agree, but that's the LSTM format, so I would not change it. The goal of these classes is to be usable without tensordict, so tensordict formatting should not impact the data format. This class is supposed to be the MA version of nn.LSTM.

If we want to use it with a tensordict, we can build a class similar to the LSTM net we have for single agents.

@matteobettini (Contributor) commented:

> I did not "decide", it's just not possible.
> If you expand it, you will feed back hidden states that are expanded, flatten them, and get even bigger hidden states (if that doesn't break - actually it will break because they're too big).
> You can only expand the output you're not feeding back recursively.

Got it. So it is not possible to just say that when you have centralised=True, share=True, the hidden state should just be indexed along the first element of the agent dim instead of flattened?

Also, how difficult do you think it would be to add GRU to this PR as well? It has been found to work better than LSTM in RL, and since we already added the infrastructure, maybe it requires minimal additions.

@vmoens (Contributor) commented Feb 29, 2024

@kfu02 went for LSTM, but GRU would be easier.
I'm not super in favour of expanding a tensor and then taking the first index. It's confusing (it lets people think the content is different) and dangerous, both from a usage and a memory perspective.

We implemented something like that for indices in replay buffers and I wish we didn't! It's extremely hard to maintain.

@kfu02 (Author) commented Feb 29, 2024

> @kfu02 went for LSTM, but GRU would be easier. I'm not super in favour of expanding a tensor and then taking the first index. It's confusing (it lets people think the content is different) and dangerous, both from a usage and a memory perspective.
>
> We implemented something like that for indices in replay buffers and I wish we didn't! It's extremely hard to maintain.

I can switch the current code to use a GRU. I was not aware that LSTMs were outperformed by GRUs; thanks @matteobettini!

@vmoens (Contributor) commented Feb 29, 2024

Now that we have LSTM working I think we can do both!

@matteobettini (Contributor) commented:

> I was not aware that LSTMs were outperformed by GRUs

At least that is what we observed: https://arxiv.org/abs/2303.01859

@matteobettini (Contributor) commented:

@vmoens @kfu02 what is the status of this? Ready for review?

@kfu02 (Author) commented Mar 28, 2024

> @vmoens @kfu02 what is the status of this? Ready for review?

Apologies, I still have to add docstrings and unit tests, and it has fallen behind in my priorities. I will complete this over the weekend most likely.

@kfu02 kfu02 closed this by deleting the head repository May 16, 2024
@matteobettini (Contributor) commented:

@vmoens do you think we could reopen this on a new branch?

@vmoens (Contributor) commented May 20, 2024

Sure, we should get this feature; I just haven't had the time to work on it yet.
Should I try to solve it on my own? I had the impression you guys were on it.

@matteobettini (Contributor) commented:

I am keen on the feature, but I don't have bandwidth at the moment due to deadlines, and it seems @kfu02 won't work on this anymore.

I think it is just missing testing and docs and should be good to go.

My personal concern is also a loss of readability in the multi-agent modules after this change.

We can reopen it in another PR, and then the first one of us who picks this up can work to ship it.

@vmoens (Contributor) commented May 20, 2024

We can work on readability. This is an unmerged PR; the proposed changes have not been properly tested or documented yet, so I would not jump straight to discarding them over readability issues before we get this piece to a mature stage.

Having one parent class that accounts for all nets makes it easier to check that we have all the features covered across networks. I noticed multiple inconsistencies between MLP and CNN, so if we scale things up to other networks we need to build an API that is consistent and testable for all of them.
Also, there are some non-trivial tricks I want to apply in the future to make things more readable and faster to execute, but that will be considerably harder if I have to patch 4 or more classes independently.

@MorganCThomas commented Oct 31, 2024

I'm interested in reviving this feature for my application (TL;DR: MARL with small language models like GRUs), but I'm quite new to TorchRL and the conventions you follow, so please forgive any naivety. In general, I've been finding it quite difficult to understand everything happening in the code here...

Handling of _in_dim seems very complicated, seeing as it completely depends on the network. So I was wondering if it'd be simpler to have a Factory class that accepts 1) a pre-built single network and 2) the 'in_dim' specification, and then just runs _pre_forward_checks to make sure the 'in_dim' matches the inputs and agent_dim.

Example usage from a user perspective:

MultiAgentFactory(
    single_net=single_net,       # provide a pre-built network, which may already have been pre-trained
    n_agents=2,
    is_centralised=centralised,  # same as centralised, but clarifying that the network should already expect in_features * n_agents
    share_params=share_params,
    agent_dim=-2,
    map_dim=-2,                  # same as in_dim, but I found this naming more intuitive from a user perspective; could also be e.g. (-2,) or (-2, -2)
    output_tuple_length=1,
)

This would:

  • Give more responsibility to the user/developer to provide the map_dim, but simplify the API and make it more general.
  • Accept any network without the user having to write a child class, which I find more user-friendly.
  • Make it simpler to start from pre-trained networks (as is the case for me).
  • Moreover, if somebody wanted to initialize multiple agents differently, it could trivially accept a list of networks.

I'm currently drafting something like this for myself, but I want to know if it would work towards this feature in TorchRL or if I'm barking up the wrong tree.

Alternative solution: the other option I would propose is to make _in_dim an abstract method, to ensure it's defined for each network. However, this is more complicated for the user/developer.
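A minimal sketch of that alternative under stated assumptions (the _in_dim name comes from this thread; the class layout below is illustrative, not the actual torchrl code):

from abc import ABC, abstractmethod

class MultiAgentNetBase(ABC):
    @property
    @abstractmethod
    def _in_dim(self):
        """Input-dim spec the wrapped network operates on (e.g. -2 for an MLP)."""
        ...

class MultiAgentMLP(MultiAgentNetBase):
    @property
    def _in_dim(self):
        return -2   # MLPs keep per-agent features on the second-to-last dim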

Cc @matteobettini
