
[Refactor] Refactor multi-agent MLP #1497

Closed · wants to merge 9 commits

Conversation

@vmoens (Contributor) commented Sep 6, 2023

cc @matteobettini

By the way, rather than doing a multi-agent MLP, a multi-agent ConvNet, a multi-agent LSTM etc., should we not consider one way of batching ops in whichever nn.Module for multi-agent use and (maybe) just subclass that for every class?

Seems a bit repetitive, no? Isn't there part of the logic we can recycle?

TODO:

  • Use vmap and module ensemble for multiple nets
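
For illustration only (not the PR's implementation), a minimal sketch of the "vmap + module ensemble" idea with plain PyTorch, using torch.func.stack_module_state and torch.func.functional_call; the layer sizes and variable names are made up:

import copy
import torch
from torch import nn
from torch.func import functional_call, stack_module_state

n_agents, n_in, n_out = 3, 8, 4
# One independent MLP per agent (hypothetical sizes).
nets = [nn.Sequential(nn.Linear(n_in, 32), nn.Tanh(), nn.Linear(32, n_out)) for _ in range(n_agents)]

# Stack all per-agent parameters along a new leading "agent" dim.
params, buffers = stack_module_state(nets)
# A stateless template that only defines the forward computation.
base = copy.deepcopy(nets[0]).to("meta")

def call_single(p, b, x):
    return functional_call(base, (p, b), (x,))

x = torch.randn(n_agents, 10, n_in)  # [agent, batch, feature]
# vmap over the agent dim of both the stacked parameters and the input.
out = torch.vmap(call_single, in_dims=(0, 0, 0))(params, buffers, x)
print(out.shape)  # torch.Size([3, 10, 4])

With shared parameters the same pattern would presumably use in_dims=(None, None, 0), so every agent reuses a single set of weights.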

@facebook-github-bot added the CLA Signed label Sep 6, 2023
@github-actions bot commented Sep 6, 2023

⚠️ Warning: Result of CPU Benchmark Tests

Total benchmarks: 89. Improved: 5. Worsened: 4.

Detailed results (columns: Name | Max | Mean | Ops | Ops on Repo HEAD | Change):
test_single 65.3593ms 64.3205ms 15.5471 Ops/s 15.8124 Ops/s $\color{#d91a1a}-1.68\%$
test_sync 0.1046s 37.1208ms 26.9391 Ops/s 26.6108 Ops/s $\color{#35bf28}+1.23\%$
test_async 0.1167s 34.7543ms 28.7734 Ops/s 29.1226 Ops/s $\color{#d91a1a}-1.20\%$
test_simple 0.5750s 0.5154s 1.9403 Ops/s 1.9326 Ops/s $\color{#35bf28}+0.40\%$
test_transformed 0.7228s 0.6888s 1.4518 Ops/s 1.4669 Ops/s $\color{#d91a1a}-1.04\%$
test_serial 1.5065s 1.4622s 0.6839 Ops/s 0.6852 Ops/s $\color{#d91a1a}-0.19\%$
test_parallel 1.4309s 1.3495s 0.7410 Ops/s 0.7357 Ops/s $\color{#35bf28}+0.73\%$
test_step_mdp_speed[True-True-True-True-True] 0.6725ms 42.2489μs 23.6693 KOps/s 22.5777 KOps/s $\color{#35bf28}+4.83\%$
test_step_mdp_speed[True-True-True-True-False] 78.0000μs 23.8784μs 41.8788 KOps/s 40.2908 KOps/s $\color{#35bf28}+3.94\%$
test_step_mdp_speed[True-True-True-False-True] 60.4010μs 29.1393μs 34.3179 KOps/s 34.3148 KOps/s $+0.01\%$
test_step_mdp_speed[True-True-True-False-False] 42.3000μs 16.3828μs 61.0394 KOps/s 60.9973 KOps/s $\color{#35bf28}+0.07\%$
test_step_mdp_speed[True-True-False-True-True] 0.2423ms 42.9177μs 23.3004 KOps/s 23.0402 KOps/s $\color{#35bf28}+1.13\%$
test_step_mdp_speed[True-True-False-True-False] 50.0000μs 25.6376μs 39.0052 KOps/s 39.1316 KOps/s $\color{#d91a1a}-0.32\%$
test_step_mdp_speed[True-True-False-False-True] 0.3519ms 30.9628μs 32.2968 KOps/s 32.3322 KOps/s $\color{#d91a1a}-0.11\%$
test_step_mdp_speed[True-True-False-False-False] 46.8000μs 18.3477μs 54.5027 KOps/s 55.4423 KOps/s $\color{#d91a1a}-1.69\%$
test_step_mdp_speed[True-False-True-True-True] 0.2619ms 45.2726μs 22.0884 KOps/s 22.1637 KOps/s $\color{#d91a1a}-0.34\%$
test_step_mdp_speed[True-False-True-True-False] 51.4000μs 27.7341μs 36.0566 KOps/s 36.6708 KOps/s $\color{#d91a1a}-1.67\%$
test_step_mdp_speed[True-False-True-False-True] 88.7010μs 30.9955μs 32.2628 KOps/s 32.4149 KOps/s $\color{#d91a1a}-0.47\%$
test_step_mdp_speed[True-False-True-False-False] 46.0010μs 18.2629μs 54.7557 KOps/s 55.1271 KOps/s $\color{#d91a1a}-0.67\%$
test_step_mdp_speed[True-False-False-True-True] 87.7010μs 46.7085μs 21.4094 KOps/s 21.3519 KOps/s $\color{#35bf28}+0.27\%$
test_step_mdp_speed[True-False-False-True-False] 58.1000μs 29.2624μs 34.1735 KOps/s 34.3492 KOps/s $\color{#d91a1a}-0.51\%$
test_step_mdp_speed[True-False-False-False-True] 0.1906ms 32.2765μs 30.9823 KOps/s 31.0174 KOps/s $\color{#d91a1a}-0.11\%$
test_step_mdp_speed[True-False-False-False-False] 43.3000μs 20.1602μs 49.6027 KOps/s 50.1240 KOps/s $\color{#d91a1a}-1.04\%$
test_step_mdp_speed[False-True-True-True-True] 83.0010μs 45.1604μs 22.1433 KOps/s 22.2073 KOps/s $\color{#d91a1a}-0.29\%$
test_step_mdp_speed[False-True-True-True-False] 52.3010μs 27.4882μs 36.3792 KOps/s 36.8313 KOps/s $\color{#d91a1a}-1.23\%$
test_step_mdp_speed[False-True-True-False-True] 73.2000μs 34.0273μs 29.3882 KOps/s 29.6401 KOps/s $\color{#d91a1a}-0.85\%$
test_step_mdp_speed[False-True-True-False-False] 2.7071ms 19.8367μs 50.4116 KOps/s 50.6143 KOps/s $\color{#d91a1a}-0.40\%$
test_step_mdp_speed[False-True-False-True-True] 94.3000μs 46.2317μs 21.6302 KOps/s 21.4647 KOps/s $\color{#35bf28}+0.77\%$
test_step_mdp_speed[False-True-False-True-False] 52.9000μs 29.2519μs 34.1858 KOps/s 34.7078 KOps/s $\color{#d91a1a}-1.50\%$
test_step_mdp_speed[False-True-False-False-True] 63.9000μs 35.0153μs 28.5589 KOps/s 28.0907 KOps/s $\color{#35bf28}+1.67\%$
test_step_mdp_speed[False-True-False-False-False] 50.1000μs 21.7415μs 45.9950 KOps/s 46.6572 KOps/s $\color{#d91a1a}-1.42\%$
test_step_mdp_speed[False-False-True-True-True] 79.1000μs 47.9984μs 20.8340 KOps/s 20.6679 KOps/s $\color{#35bf28}+0.80\%$
test_step_mdp_speed[False-False-True-True-False] 1.5292ms 32.0765μs 31.1755 KOps/s 32.5987 KOps/s $\color{#d91a1a}-4.37\%$
test_step_mdp_speed[False-False-True-False-True] 0.1379ms 35.3154μs 28.3162 KOps/s 28.0801 KOps/s $\color{#35bf28}+0.84\%$
test_step_mdp_speed[False-False-True-False-False] 82.9010μs 21.7638μs 45.9478 KOps/s 46.8712 KOps/s $\color{#d91a1a}-1.97\%$
test_step_mdp_speed[False-False-False-True-True] 72.4010μs 49.2847μs 20.2903 KOps/s 19.7393 KOps/s $\color{#35bf28}+2.79\%$
test_step_mdp_speed[False-False-False-True-False] 85.4000μs 32.8602μs 30.4319 KOps/s 31.1289 KOps/s $\color{#d91a1a}-2.24\%$
test_step_mdp_speed[False-False-False-False-True] 93.2010μs 36.4213μs 27.4565 KOps/s 27.3974 KOps/s $\color{#35bf28}+0.22\%$
test_step_mdp_speed[False-False-False-False-False] 98.7010μs 23.1384μs 43.2183 KOps/s 43.6468 KOps/s $\color{#d91a1a}-0.98\%$
test_values[generalized_advantage_estimate-True-True] 14.2447ms 13.7471ms 72.7425 Ops/s 73.8813 Ops/s $\color{#d91a1a}-1.54\%$
test_values[vec_generalized_advantage_estimate-True-True] 57.1088ms 51.1062ms 19.5671 Ops/s 19.5698 Ops/s $\color{#d91a1a}-0.01\%$
test_values[td0_return_estimate-False-False] 0.2800ms 0.1907ms 5.2442 KOps/s 5.0438 KOps/s $\color{#35bf28}+3.97\%$
test_values[td1_return_estimate-False-False] 14.5551ms 13.7999ms 72.4644 Ops/s 74.9384 Ops/s $\color{#d91a1a}-3.30\%$
test_values[vec_td1_return_estimate-False-False] 55.9537ms 50.3559ms 19.8586 Ops/s 19.9825 Ops/s $\color{#d91a1a}-0.62\%$
test_values[td_lambda_return_estimate-True-False] 39.3804ms 33.0829ms 30.2271 Ops/s 31.1808 Ops/s $\color{#d91a1a}-3.06\%$
test_values[vec_td_lambda_return_estimate-True-False] 56.6191ms 50.2329ms 19.9073 Ops/s 19.8582 Ops/s $\color{#35bf28}+0.25\%$
test_gae_speed[generalized_advantage_estimate-False-1-512] 12.6659ms 12.3638ms 80.8815 Ops/s 83.7884 Ops/s $\color{#d91a1a}-3.47\%$
test_gae_speed[vec_generalized_advantage_estimate-True-1-512] 8.5874ms 2.4072ms 415.4237 Ops/s 393.4747 Ops/s $\textbf{\color{#35bf28}+5.58\%}$
test_gae_speed[vec_generalized_advantage_estimate-False-1-512] 0.9617ms 0.4155ms 2.4070 KOps/s 2.4625 KOps/s $\color{#d91a1a}-2.26\%$
test_gae_speed[vec_generalized_advantage_estimate-True-32-512] 49.4139ms 48.8941ms 20.4524 Ops/s 18.4316 Ops/s $\textbf{\color{#35bf28}+10.96\%}$
test_gae_speed[vec_generalized_advantage_estimate-False-32-512] 10.2041ms 3.8198ms 261.7936 Ops/s 256.4847 Ops/s $\color{#35bf28}+2.07\%$
test_dqn_speed 9.5140ms 1.7766ms 562.8626 Ops/s 565.4168 Ops/s $\color{#d91a1a}-0.45\%$
test_ddpg_speed 9.4085ms 2.4420ms 409.5028 Ops/s 403.5304 Ops/s $\color{#35bf28}+1.48\%$
test_sac_speed 14.1843ms 7.7718ms 128.6708 Ops/s 127.8974 Ops/s $\color{#35bf28}+0.60\%$
test_redq_speed 20.9222ms 14.7717ms 67.6972 Ops/s 66.7510 Ops/s $\color{#35bf28}+1.42\%$
test_redq_deprec_speed 18.8325ms 12.2582ms 81.5779 Ops/s 79.5969 Ops/s $\color{#35bf28}+2.49\%$
test_td3_speed 11.8230ms 9.6662ms 103.4528 Ops/s 103.5999 Ops/s $\color{#d91a1a}-0.14\%$
test_cql_speed 31.7515ms 26.4079ms 37.8675 Ops/s 28.8467 Ops/s $\textbf{\color{#35bf28}+31.27\%}$
test_a2c_speed 11.5547ms 5.4105ms 184.8247 Ops/s 187.1120 Ops/s $\color{#d91a1a}-1.22\%$
test_ppo_speed 13.2454ms 5.8623ms 170.5828 Ops/s 166.2552 Ops/s $\color{#35bf28}+2.60\%$
test_reinforce_speed 13.8834ms 4.2465ms 235.4866 Ops/s 242.3526 Ops/s $\color{#d91a1a}-2.83\%$
test_iql_speed 29.7489ms 22.5171ms 44.4108 Ops/s 45.0195 Ops/s $\color{#d91a1a}-1.35\%$
test_sample_rb[TensorDictReplayBuffer-ListStorage-RandomSampler-4000] 3.3395ms 2.7043ms 369.7863 Ops/s 390.5615 Ops/s $\textbf{\color{#d91a1a}-5.32\%}$
test_sample_rb[TensorDictReplayBuffer-LazyMemmapStorage-RandomSampler-10000] 0.1239s 3.1539ms 317.0630 Ops/s 340.2017 Ops/s $\textbf{\color{#d91a1a}-6.80\%}$
test_sample_rb[TensorDictReplayBuffer-LazyTensorStorage-RandomSampler-10000] 4.9175ms 2.8330ms 352.9774 Ops/s 368.8233 Ops/s $\color{#d91a1a}-4.30\%$
test_sample_rb[TensorDictReplayBuffer-ListStorage-SamplerWithoutReplacement-4000] 0.1999s 3.2508ms 307.6209 Ops/s 393.7598 Ops/s $\textbf{\color{#d91a1a}-21.88\%}$
test_sample_rb[TensorDictReplayBuffer-LazyMemmapStorage-SamplerWithoutReplacement-10000] 4.7319ms 2.8249ms 353.9889 Ops/s 373.1334 Ops/s $\textbf{\color{#d91a1a}-5.13\%}$
test_sample_rb[TensorDictReplayBuffer-LazyTensorStorage-SamplerWithoutReplacement-10000] 4.3369ms 2.8111ms 355.7365 Ops/s 366.7044 Ops/s $\color{#d91a1a}-2.99\%$
test_sample_rb[TensorDictPrioritizedReplayBuffer-ListStorage-None-4000] 3.5138ms 2.6638ms 375.4101 Ops/s 388.6307 Ops/s $\color{#d91a1a}-3.40\%$
test_sample_rb[TensorDictPrioritizedReplayBuffer-LazyMemmapStorage-None-10000] 4.5389ms 2.7863ms 358.9013 Ops/s 372.9096 Ops/s $\color{#d91a1a}-3.76\%$
test_sample_rb[TensorDictPrioritizedReplayBuffer-LazyTensorStorage-None-10000] 4.9628ms 2.7637ms 361.8389 Ops/s 368.4590 Ops/s $\color{#d91a1a}-1.80\%$
test_iterate_rb[TensorDictReplayBuffer-ListStorage-RandomSampler-4000] 3.1989ms 2.6168ms 382.1413 Ops/s 391.9071 Ops/s $\color{#d91a1a}-2.49\%$
test_iterate_rb[TensorDictReplayBuffer-LazyMemmapStorage-RandomSampler-10000] 4.3445ms 2.7694ms 361.0937 Ops/s 372.8810 Ops/s $\color{#d91a1a}-3.16\%$
test_iterate_rb[TensorDictReplayBuffer-LazyTensorStorage-RandomSampler-10000] 4.7240ms 2.8050ms 356.5033 Ops/s 362.0788 Ops/s $\color{#d91a1a}-1.54\%$
test_iterate_rb[TensorDictReplayBuffer-ListStorage-SamplerWithoutReplacement-4000] 3.3616ms 2.7568ms 362.7358 Ops/s 381.5147 Ops/s $\color{#d91a1a}-4.92\%$
test_iterate_rb[TensorDictReplayBuffer-LazyMemmapStorage-SamplerWithoutReplacement-10000] 5.0045ms 2.7788ms 359.8704 Ops/s 370.4147 Ops/s $\color{#d91a1a}-2.85\%$
test_iterate_rb[TensorDictReplayBuffer-LazyTensorStorage-SamplerWithoutReplacement-10000] 4.5547ms 2.7860ms 358.9433 Ops/s 371.6851 Ops/s $\color{#d91a1a}-3.43\%$
test_iterate_rb[TensorDictPrioritizedReplayBuffer-ListStorage-None-4000] 3.3693ms 2.6412ms 378.6217 Ops/s 392.1825 Ops/s $\color{#d91a1a}-3.46\%$
test_iterate_rb[TensorDictPrioritizedReplayBuffer-LazyMemmapStorage-None-10000] 4.5337ms 2.7567ms 362.7546 Ops/s 374.3449 Ops/s $\color{#d91a1a}-3.10\%$
test_iterate_rb[TensorDictPrioritizedReplayBuffer-LazyTensorStorage-None-10000] 4.6578ms 2.7795ms 359.7765 Ops/s 367.2856 Ops/s $\color{#d91a1a}-2.04\%$
test_populate_rb[TensorDictReplayBuffer-ListStorage-RandomSampler-400] 0.2198s 26.5162ms 37.7128 Ops/s 37.2321 Ops/s $\color{#35bf28}+1.29\%$
test_populate_rb[TensorDictReplayBuffer-LazyMemmapStorage-RandomSampler-400] 0.1179s 26.1802ms 38.1969 Ops/s 38.2202 Ops/s $\color{#d91a1a}-0.06\%$
test_populate_rb[TensorDictReplayBuffer-LazyTensorStorage-RandomSampler-400] 0.1182s 24.3722ms 41.0303 Ops/s 41.4593 Ops/s $\color{#d91a1a}-1.03\%$
test_populate_rb[TensorDictReplayBuffer-ListStorage-SamplerWithoutReplacement-400] 0.1187s 26.4770ms 37.7686 Ops/s 38.4080 Ops/s $\color{#d91a1a}-1.66\%$
test_populate_rb[TensorDictReplayBuffer-LazyMemmapStorage-SamplerWithoutReplacement-400] 0.1195s 24.3780ms 41.0205 Ops/s 41.5751 Ops/s $\color{#d91a1a}-1.33\%$
test_populate_rb[TensorDictReplayBuffer-LazyTensorStorage-SamplerWithoutReplacement-400] 0.1187s 26.3724ms 37.9184 Ops/s 38.0204 Ops/s $\color{#d91a1a}-0.27\%$
test_populate_rb[TensorDictPrioritizedReplayBuffer-ListStorage-None-400] 0.1215s 22.6642ms 44.1224 Ops/s 41.3646 Ops/s $\textbf{\color{#35bf28}+6.67\%}$
test_populate_rb[TensorDictPrioritizedReplayBuffer-LazyMemmapStorage-None-400] 0.1219s 26.2949ms 38.0302 Ops/s 38.0901 Ops/s $\color{#d91a1a}-0.16\%$
test_populate_rb[TensorDictPrioritizedReplayBuffer-LazyTensorStorage-None-400] 0.1198s 24.3793ms 41.0184 Ops/s 38.0046 Ops/s $\textbf{\color{#35bf28}+7.93\%}$

Comment on lines 190 to 196
if self.centralised:
    self.net_call = torch.vmap(
        self.net,
        (None, 0)
    ) if not self.share_params else self.net
else:
    self.net_call = torch.vmap(self.net, (-2, 0)) if not self.share_params else self.net
Contributor

Does this work and pass the test?

Cause it seems to me that we are not doing the reshaping into (n_agents * n_inputs_per_agent) in the centralised case
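
For reference, a toy illustration of the reshape referred to above, with made-up shapes: in the centralised case each network should see all agents' inputs concatenated along the feature dim.

import torch

batch, n_agents, n_inputs_per_agent = 10, 3, 8
x = torch.randn(batch, n_agents, n_inputs_per_agent)

# Centralised case: fold the agent dim into the feature dim so a single
# network receives n_agents * n_inputs_per_agent features per sample.
x_central = x.reshape(*x.shape[:-2], n_agents * n_inputs_per_agent)
print(x_central.shape)  # torch.Size([10, 24])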

Contributor Author

No it doesn't. In fact it's not ready for review :)

Contributor Author

But feel free to comment on the general idea from the PR description!

@matteobettini (Contributor)

> By the way, rather than doing a multi-agent MLP, a multi-agent ConvNet, a multi-agent LSTM etc., should we not consider one way of batching ops in whichever nn.Module for multi-agent use and (maybe) just subclass that for every class?

Yeah, makes sense. One thing is that RNNs will not be able to be coded as nn.Modules, so we will need different logic for those.

Might be a bit overkill to use inheritance just for MLP and CNN.

@vmoens (Contributor Author) commented Sep 6, 2023

Why would RNNs need to be treated separately?
There are also transformers, GPS, graph nets, or any custom nn.Module, no?

@matteobettini (Contributor)

Because RNNs in torchrl are TensorDictModules, no?
In fact we do not have them in the models.

Yeah, I think a common multi-agent nn.Module can make sense; we need to be parametric over the multi-agent dim, but it should be possible.

@vmoens (Contributor Author) commented Sep 6, 2023

> Because RNNs in torchrl are TensorDictModules, no?
> In fact we do not have them in the models.

That can be easily worked around, it's up to us to adapt our design.

@vmoens marked this pull request as ready for review September 6, 2023 14:14
@vmoens (Contributor Author) commented Sep 6, 2023

@matteobettini this is ready for review.
FYI it depends on pytorch/tensordict#522 (you can review that one first)

if not self.share_params:
    self.net_call = torch.vmap(self.net, in_dims=(-2, 0), out_dims=(-2,))
else:
    self.net_call = torch.vmap(self.net, in_dims=(-2, None), out_dims=(-2,))
@matteobettini (Contributor) commented Sep 6, 2023

Is this needed? Wouldn't the MLP do vectorization by default?

Contributor Author

Not sure all ops will support an arbitrary number of dims
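
A small sketch of this concern (not code from the PR): nn.Linear broadcasts over any number of leading dims, but e.g. nn.Conv2d only accepts 3-D or 4-D inputs, so an extra agent dim needs either an explicit reshape or a vmap.

import torch
from torch import nn

lin = nn.Linear(8, 4)
conv = nn.Conv2d(3, 16, kernel_size=3)

# nn.Linear accepts arbitrary leading dims, e.g. [batch, agent, token, feature].
print(lin(torch.randn(2, 5, 3, 8)).shape)  # torch.Size([2, 5, 3, 4])

# nn.Conv2d rejects a 5-D [batch, agent, C, H, W] input...
try:
    conv(torch.randn(2, 5, 3, 32, 32))
except RuntimeError as err:
    print("conv2d rejected the 5-D input:", type(err).__name__)

# ...whereas vmapping over the extra dim works without reshaping.
out = torch.vmap(conv)(torch.randn(2, 5, 3, 32, 32))
print(out.shape)  # torch.Size([2, 5, 16, 30, 30])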

Contributor

Oh OK, let's revert to the previous version then and we are good to go.

@matteobettini (Contributor) left a comment

LGTM

@vmoens added the Refactoring label Sep 7, 2023
@vmoens (Contributor Author) commented Sep 12, 2023

The MAPPO test fails.
I'm investigating that.

@vmoens (Contributor Author) commented Sep 12, 2023

I think this is a good use case for first-class dimensions, as the bug is a bit convoluted to solve. Let me work on that and put this on hold.

@vmoens marked this pull request as draft September 12, 2023 16:57
@kfu02 commented Feb 9, 2024

Hi, I am just wondering if this issue is still being investigated, as I would like to try RNNs in the multi-agent setting (specifically for MAPPO, as that paper claims better performance for RNNs vs MLPs in many settings). Apologies if this is not the right place to ask.

@vmoens (Contributor Author) commented Feb 9, 2024

This is definitely a good place to ask. I can refresh and merge this PR now, and make a CNN/RNN version too. Bear with me...

@kfu02 commented Feb 15, 2024

Great! I eagerly await the refresh/merge of this PR!

@vmoens (Contributor Author) commented Feb 16, 2024

Yeah, sorry about the delay. I'll jump on it this afternoon.

# Conflicts:
#	test/test_modules.py
#	torchrl/modules/models/multiagent.py
@pytorch-bot bot commented Feb 16, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/rl/1497

Note: Links to docs will display an error until the docs builds have been completed.

❌ 6 New Failures, 32 Unrelated Failures

As of commit 0905f9f with merge base 67f659c:

NEW FAILURES - The following jobs have failed:

FLAKY - The following jobs failed but were likely due to flakiness present on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@vmoens (Contributor Author) commented Feb 16, 2024

@kfu02

Check out #1921
It should work faster and allow us to code RNNs in a subsequent PR.
The issue right now is just that we can't account for non-initialized lazy params, as they are gathered in a single TensorDict during construction.

The solution will be a bit convoluted but doable :)
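
A tiny illustration of the lazy-parameter issue, outside the PR itself: before the first forward pass a lazy module's weight is an UninitializedParameter with no shape, so it cannot yet be stacked or gathered with the other agents' parameters.

import torch
from torch import nn

lazy = nn.LazyLinear(out_features=4)
print(type(lazy.weight))  # <class 'torch.nn.parameter.UninitializedParameter'>

# The weight only gets a concrete shape after a first forward pass.
lazy(torch.randn(2, 8))   # materializes the weight as (4, 8)
print(lazy.weight.shape)  # torch.Size([4, 8])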

@vmoens closed this Feb 17, 2024
@vmoens deleted the edit_ma_mlp branch April 3, 2024 06:04