[Refactor] Refactor multi-agent MLP #1497
Conversation
Name | Max | Mean | Ops | Ops on Repo HEAD | Change
---|---|---|---|---|---
test_single | 65.3593ms | 64.3205ms | 15.5471 Ops/s | 15.8124 Ops/s | |
test_sync | 0.1046s | 37.1208ms | 26.9391 Ops/s | 26.6108 Ops/s | |
test_async | 0.1167s | 34.7543ms | 28.7734 Ops/s | 29.1226 Ops/s | |
test_simple | 0.5750s | 0.5154s | 1.9403 Ops/s | 1.9326 Ops/s | |
test_transformed | 0.7228s | 0.6888s | 1.4518 Ops/s | 1.4669 Ops/s | |
test_serial | 1.5065s | 1.4622s | 0.6839 Ops/s | 0.6852 Ops/s | |
test_parallel | 1.4309s | 1.3495s | 0.7410 Ops/s | 0.7357 Ops/s | |
test_step_mdp_speed[True-True-True-True-True] | 0.6725ms | 42.2489μs | 23.6693 KOps/s | 22.5777 KOps/s | |
test_step_mdp_speed[True-True-True-True-False] | 78.0000μs | 23.8784μs | 41.8788 KOps/s | 40.2908 KOps/s | |
test_step_mdp_speed[True-True-True-False-True] | 60.4010μs | 29.1393μs | 34.3179 KOps/s | 34.3148 KOps/s | |
test_step_mdp_speed[True-True-True-False-False] | 42.3000μs | 16.3828μs | 61.0394 KOps/s | 60.9973 KOps/s | |
test_step_mdp_speed[True-True-False-True-True] | 0.2423ms | 42.9177μs | 23.3004 KOps/s | 23.0402 KOps/s | |
test_step_mdp_speed[True-True-False-True-False] | 50.0000μs | 25.6376μs | 39.0052 KOps/s | 39.1316 KOps/s | |
test_step_mdp_speed[True-True-False-False-True] | 0.3519ms | 30.9628μs | 32.2968 KOps/s | 32.3322 KOps/s | |
test_step_mdp_speed[True-True-False-False-False] | 46.8000μs | 18.3477μs | 54.5027 KOps/s | 55.4423 KOps/s | |
test_step_mdp_speed[True-False-True-True-True] | 0.2619ms | 45.2726μs | 22.0884 KOps/s | 22.1637 KOps/s | |
test_step_mdp_speed[True-False-True-True-False] | 51.4000μs | 27.7341μs | 36.0566 KOps/s | 36.6708 KOps/s | |
test_step_mdp_speed[True-False-True-False-True] | 88.7010μs | 30.9955μs | 32.2628 KOps/s | 32.4149 KOps/s | |
test_step_mdp_speed[True-False-True-False-False] | 46.0010μs | 18.2629μs | 54.7557 KOps/s | 55.1271 KOps/s | |
test_step_mdp_speed[True-False-False-True-True] | 87.7010μs | 46.7085μs | 21.4094 KOps/s | 21.3519 KOps/s | |
test_step_mdp_speed[True-False-False-True-False] | 58.1000μs | 29.2624μs | 34.1735 KOps/s | 34.3492 KOps/s | |
test_step_mdp_speed[True-False-False-False-True] | 0.1906ms | 32.2765μs | 30.9823 KOps/s | 31.0174 KOps/s | |
test_step_mdp_speed[True-False-False-False-False] | 43.3000μs | 20.1602μs | 49.6027 KOps/s | 50.1240 KOps/s | |
test_step_mdp_speed[False-True-True-True-True] | 83.0010μs | 45.1604μs | 22.1433 KOps/s | 22.2073 KOps/s | |
test_step_mdp_speed[False-True-True-True-False] | 52.3010μs | 27.4882μs | 36.3792 KOps/s | 36.8313 KOps/s | |
test_step_mdp_speed[False-True-True-False-True] | 73.2000μs | 34.0273μs | 29.3882 KOps/s | 29.6401 KOps/s | |
test_step_mdp_speed[False-True-True-False-False] | 2.7071ms | 19.8367μs | 50.4116 KOps/s | 50.6143 KOps/s | |
test_step_mdp_speed[False-True-False-True-True] | 94.3000μs | 46.2317μs | 21.6302 KOps/s | 21.4647 KOps/s | |
test_step_mdp_speed[False-True-False-True-False] | 52.9000μs | 29.2519μs | 34.1858 KOps/s | 34.7078 KOps/s | |
test_step_mdp_speed[False-True-False-False-True] | 63.9000μs | 35.0153μs | 28.5589 KOps/s | 28.0907 KOps/s | |
test_step_mdp_speed[False-True-False-False-False] | 50.1000μs | 21.7415μs | 45.9950 KOps/s | 46.6572 KOps/s | |
test_step_mdp_speed[False-False-True-True-True] | 79.1000μs | 47.9984μs | 20.8340 KOps/s | 20.6679 KOps/s | |
test_step_mdp_speed[False-False-True-True-False] | 1.5292ms | 32.0765μs | 31.1755 KOps/s | 32.5987 KOps/s | |
test_step_mdp_speed[False-False-True-False-True] | 0.1379ms | 35.3154μs | 28.3162 KOps/s | 28.0801 KOps/s | |
test_step_mdp_speed[False-False-True-False-False] | 82.9010μs | 21.7638μs | 45.9478 KOps/s | 46.8712 KOps/s | |
test_step_mdp_speed[False-False-False-True-True] | 72.4010μs | 49.2847μs | 20.2903 KOps/s | 19.7393 KOps/s | |
test_step_mdp_speed[False-False-False-True-False] | 85.4000μs | 32.8602μs | 30.4319 KOps/s | 31.1289 KOps/s | |
test_step_mdp_speed[False-False-False-False-True] | 93.2010μs | 36.4213μs | 27.4565 KOps/s | 27.3974 KOps/s | |
test_step_mdp_speed[False-False-False-False-False] | 98.7010μs | 23.1384μs | 43.2183 KOps/s | 43.6468 KOps/s | |
test_values[generalized_advantage_estimate-True-True] | 14.2447ms | 13.7471ms | 72.7425 Ops/s | 73.8813 Ops/s | |
test_values[vec_generalized_advantage_estimate-True-True] | 57.1088ms | 51.1062ms | 19.5671 Ops/s | 19.5698 Ops/s | |
test_values[td0_return_estimate-False-False] | 0.2800ms | 0.1907ms | 5.2442 KOps/s | 5.0438 KOps/s | |
test_values[td1_return_estimate-False-False] | 14.5551ms | 13.7999ms | 72.4644 Ops/s | 74.9384 Ops/s | |
test_values[vec_td1_return_estimate-False-False] | 55.9537ms | 50.3559ms | 19.8586 Ops/s | 19.9825 Ops/s | |
test_values[td_lambda_return_estimate-True-False] | 39.3804ms | 33.0829ms | 30.2271 Ops/s | 31.1808 Ops/s | |
test_values[vec_td_lambda_return_estimate-True-False] | 56.6191ms | 50.2329ms | 19.9073 Ops/s | 19.8582 Ops/s | |
test_gae_speed[generalized_advantage_estimate-False-1-512] | 12.6659ms | 12.3638ms | 80.8815 Ops/s | 83.7884 Ops/s | |
test_gae_speed[vec_generalized_advantage_estimate-True-1-512] | 8.5874ms | 2.4072ms | 415.4237 Ops/s | 393.4747 Ops/s | |
test_gae_speed[vec_generalized_advantage_estimate-False-1-512] | 0.9617ms | 0.4155ms | 2.4070 KOps/s | 2.4625 KOps/s | |
test_gae_speed[vec_generalized_advantage_estimate-True-32-512] | 49.4139ms | 48.8941ms | 20.4524 Ops/s | 18.4316 Ops/s | |
test_gae_speed[vec_generalized_advantage_estimate-False-32-512] | 10.2041ms | 3.8198ms | 261.7936 Ops/s | 256.4847 Ops/s | |
test_dqn_speed | 9.5140ms | 1.7766ms | 562.8626 Ops/s | 565.4168 Ops/s | |
test_ddpg_speed | 9.4085ms | 2.4420ms | 409.5028 Ops/s | 403.5304 Ops/s | |
test_sac_speed | 14.1843ms | 7.7718ms | 128.6708 Ops/s | 127.8974 Ops/s | |
test_redq_speed | 20.9222ms | 14.7717ms | 67.6972 Ops/s | 66.7510 Ops/s | |
test_redq_deprec_speed | 18.8325ms | 12.2582ms | 81.5779 Ops/s | 79.5969 Ops/s | |
test_td3_speed | 11.8230ms | 9.6662ms | 103.4528 Ops/s | 103.5999 Ops/s | |
test_cql_speed | 31.7515ms | 26.4079ms | 37.8675 Ops/s | 28.8467 Ops/s | |
test_a2c_speed | 11.5547ms | 5.4105ms | 184.8247 Ops/s | 187.1120 Ops/s | |
test_ppo_speed | 13.2454ms | 5.8623ms | 170.5828 Ops/s | 166.2552 Ops/s | |
test_reinforce_speed | 13.8834ms | 4.2465ms | 235.4866 Ops/s | 242.3526 Ops/s | |
test_iql_speed | 29.7489ms | 22.5171ms | 44.4108 Ops/s | 45.0195 Ops/s | |
test_sample_rb[TensorDictReplayBuffer-ListStorage-RandomSampler-4000] | 3.3395ms | 2.7043ms | 369.7863 Ops/s | 390.5615 Ops/s | |
test_sample_rb[TensorDictReplayBuffer-LazyMemmapStorage-RandomSampler-10000] | 0.1239s | 3.1539ms | 317.0630 Ops/s | 340.2017 Ops/s | |
test_sample_rb[TensorDictReplayBuffer-LazyTensorStorage-RandomSampler-10000] | 4.9175ms | 2.8330ms | 352.9774 Ops/s | 368.8233 Ops/s | |
test_sample_rb[TensorDictReplayBuffer-ListStorage-SamplerWithoutReplacement-4000] | 0.1999s | 3.2508ms | 307.6209 Ops/s | 393.7598 Ops/s | |
test_sample_rb[TensorDictReplayBuffer-LazyMemmapStorage-SamplerWithoutReplacement-10000] | 4.7319ms | 2.8249ms | 353.9889 Ops/s | 373.1334 Ops/s | |
test_sample_rb[TensorDictReplayBuffer-LazyTensorStorage-SamplerWithoutReplacement-10000] | 4.3369ms | 2.8111ms | 355.7365 Ops/s | 366.7044 Ops/s | |
test_sample_rb[TensorDictPrioritizedReplayBuffer-ListStorage-None-4000] | 3.5138ms | 2.6638ms | 375.4101 Ops/s | 388.6307 Ops/s | |
test_sample_rb[TensorDictPrioritizedReplayBuffer-LazyMemmapStorage-None-10000] | 4.5389ms | 2.7863ms | 358.9013 Ops/s | 372.9096 Ops/s | |
test_sample_rb[TensorDictPrioritizedReplayBuffer-LazyTensorStorage-None-10000] | 4.9628ms | 2.7637ms | 361.8389 Ops/s | 368.4590 Ops/s | |
test_iterate_rb[TensorDictReplayBuffer-ListStorage-RandomSampler-4000] | 3.1989ms | 2.6168ms | 382.1413 Ops/s | 391.9071 Ops/s | |
test_iterate_rb[TensorDictReplayBuffer-LazyMemmapStorage-RandomSampler-10000] | 4.3445ms | 2.7694ms | 361.0937 Ops/s | 372.8810 Ops/s | |
test_iterate_rb[TensorDictReplayBuffer-LazyTensorStorage-RandomSampler-10000] | 4.7240ms | 2.8050ms | 356.5033 Ops/s | 362.0788 Ops/s | |
test_iterate_rb[TensorDictReplayBuffer-ListStorage-SamplerWithoutReplacement-4000] | 3.3616ms | 2.7568ms | 362.7358 Ops/s | 381.5147 Ops/s | |
test_iterate_rb[TensorDictReplayBuffer-LazyMemmapStorage-SamplerWithoutReplacement-10000] | 5.0045ms | 2.7788ms | 359.8704 Ops/s | 370.4147 Ops/s | |
test_iterate_rb[TensorDictReplayBuffer-LazyTensorStorage-SamplerWithoutReplacement-10000] | 4.5547ms | 2.7860ms | 358.9433 Ops/s | 371.6851 Ops/s | |
test_iterate_rb[TensorDictPrioritizedReplayBuffer-ListStorage-None-4000] | 3.3693ms | 2.6412ms | 378.6217 Ops/s | 392.1825 Ops/s | |
test_iterate_rb[TensorDictPrioritizedReplayBuffer-LazyMemmapStorage-None-10000] | 4.5337ms | 2.7567ms | 362.7546 Ops/s | 374.3449 Ops/s | |
test_iterate_rb[TensorDictPrioritizedReplayBuffer-LazyTensorStorage-None-10000] | 4.6578ms | 2.7795ms | 359.7765 Ops/s | 367.2856 Ops/s | |
test_populate_rb[TensorDictReplayBuffer-ListStorage-RandomSampler-400] | 0.2198s | 26.5162ms | 37.7128 Ops/s | 37.2321 Ops/s | |
test_populate_rb[TensorDictReplayBuffer-LazyMemmapStorage-RandomSampler-400] | 0.1179s | 26.1802ms | 38.1969 Ops/s | 38.2202 Ops/s | |
test_populate_rb[TensorDictReplayBuffer-LazyTensorStorage-RandomSampler-400] | 0.1182s | 24.3722ms | 41.0303 Ops/s | 41.4593 Ops/s | |
test_populate_rb[TensorDictReplayBuffer-ListStorage-SamplerWithoutReplacement-400] | 0.1187s | 26.4770ms | 37.7686 Ops/s | 38.4080 Ops/s | |
test_populate_rb[TensorDictReplayBuffer-LazyMemmapStorage-SamplerWithoutReplacement-400] | 0.1195s | 24.3780ms | 41.0205 Ops/s | 41.5751 Ops/s | |
test_populate_rb[TensorDictReplayBuffer-LazyTensorStorage-SamplerWithoutReplacement-400] | 0.1187s | 26.3724ms | 37.9184 Ops/s | 38.0204 Ops/s | |
test_populate_rb[TensorDictPrioritizedReplayBuffer-ListStorage-None-400] | 0.1215s | 22.6642ms | 44.1224 Ops/s | 41.3646 Ops/s | |
test_populate_rb[TensorDictPrioritizedReplayBuffer-LazyMemmapStorage-None-400] | 0.1219s | 26.2949ms | 38.0302 Ops/s | 38.0901 Ops/s | |
test_populate_rb[TensorDictPrioritizedReplayBuffer-LazyTensorStorage-None-400] | 0.1198s | 24.3793ms | 41.0184 Ops/s | 38.0046 Ops/s |
torchrl/modules/models/multiagent.py (outdated)
```python
if self.centralised:
    self.net_call = torch.vmap(
        self.net,
        (None, 0)
    ) if not self.share_params else self.net
else:
    self.net_call = torch.vmap(self.net, (-2, 0)) if not self.share_params else self.net
```
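A standalone sketch of the pattern this snippet relies on (illustrative, not TorchRL's actual implementation): stack the per-agent parameters so they gain a leading agent dimension, then let `torch.vmap` map a single "template" network over that dimension.

```python
import torch
from torch.func import functional_call, stack_module_state, vmap

n_agents, n_inputs, n_outputs = 3, 4, 2

# One network per agent; stacking their parameters adds a leading agent dim.
nets = [torch.nn.Linear(n_inputs, n_outputs) for _ in range(n_agents)]
params, _ = stack_module_state(nets)
template = nets[0]

def call_one(p, x):
    # Run the template network with one agent's parameter slice.
    return functional_call(template, p, (x,))

x = torch.randn(8, n_agents, n_inputs)            # (batch, n_agents, n_inputs)
out = vmap(call_one, in_dims=(0, 1), out_dims=1)(params, x)
print(out.shape)                                  # torch.Size([8, 3, 2])
```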
Does this work and pass the test? Because it seems to me that we are not doing the reshaping into (n_agents * n_inputs_per_agent) in the centralised case.
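The reshape referred to here amounts to flattening the last two dims so every agent's network sees the concatenation of all agents' inputs; a minimal sketch with illustrative shapes:

```python
import torch

n_agents, n_inputs_per_agent = 3, 4
x = torch.randn(8, n_agents, n_inputs_per_agent)  # (batch, n_agents, n_inputs_per_agent)

# Centralised: each agent's network receives a single vector of size
# n_agents * n_inputs_per_agent, i.e. everyone's observations concatenated.
x_central = x.flatten(-2, -1)
print(x_central.shape)                            # torch.Size([8, 12])
```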
No it doesn't. In fact it's not ready for review :)
But feel free to comment on the general idea from the PR description!
Yeah, makes sense. One thing is that RNNs will not be able to be coded as nn.Modules, so we will need another logic for those; it might be a bit overkill to use inheritance just for MLP and CNN.
Why would RNNs need to be treated separately?
Because RNNs in torchrl are TensorDictModules, no? Yeah, I think a common multi-agent nn module can make sense; we need to be parametric to the multi-agent dim, but it should be possible.
That can be easily worked around; it's up to us to adapt our design.
@matteobettini this is ready for review.
torchrl/modules/models/multiagent.py (outdated)
```python
if not self.share_params:
    self.net_call = torch.vmap(self.net, in_dims=(-2, 0), out_dims=(-2,))
else:
    self.net_call = torch.vmap(self.net, in_dims=(-2, None), out_dims=(-2,))
```
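The `in_dims=None` branch is what encodes parameter sharing: `None` broadcasts one set of weights to every agent instead of mapping over a stack of them. A hedged sketch of the distinction with `torch.func` (illustrative names, not the PR's code):

```python
import torch
from torch.func import functional_call, stack_module_state, vmap

n_agents, d_in, d_out = 3, 4, 2
x = torch.randn(5, n_agents, d_in)                # agent dim at -2, as in the snippet

nets = [torch.nn.Linear(d_in, d_out) for _ in range(n_agents)]
stacked, _ = stack_module_state(nets)
template = nets[0]

def f(p, xi):
    return functional_call(template, p, (xi,))

# Unshared: n_agents independent parameter sets, mapped over their leading dim.
out_unshared = vmap(f, in_dims=(0, -2), out_dims=-2)(stacked, x)

# Shared: a single parameter set, broadcast (not mapped) across agents.
shared = dict(template.named_parameters())
out_shared = vmap(f, in_dims=(None, -2), out_dims=-2)(shared, x)
print(out_unshared.shape, out_shared.shape)       # both torch.Size([5, 3, 2])
```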
Is this needed? Wouldn't the MLP do vectorization by default?
Not sure all ops will support an arbitrary number of dims
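An illustration of that concern: unlike `Linear`, some ops such as `Conv2d` reject extra leading dims outright, so an explicit `vmap` over the agent dim is needed rather than relying on implicit broadcasting.

```python
import torch

conv = torch.nn.Conv2d(3, 8, kernel_size=3)
x = torch.randn(2, 5, 3, 16, 16)   # extra agent dim (size 5): Conv2d wants 4-D input

failed = False
try:
    conv(x)                        # conv2d only accepts 3-D/4-D input
except RuntimeError:
    failed = True
print("Conv2d rejects the extra dim:", failed)

# vmap maps the op over the extra dim instead of relying on broadcasting.
y = torch.vmap(conv, in_dims=1, out_dims=1)(x)
print(y.shape)                     # torch.Size([2, 5, 8, 14, 14])
```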
oh ok, let's revert to the previous version then and we are good to go
LGTM
Co-authored-by: Matteo Bettini <[email protected]>
The MAPPO test fails.
I think this is a good use case for first-class dimensions, as the bug is a bit convoluted to solve. Let me work on this and put this on hold.
Hi, I am just wondering if this issue is still being investigated, as I would like to try RNNs in the multi-agent setting (specifically for MAPPO, as that paper claims better performance for RNNs vs MLPs in many settings). Apologies if this is not the right place to ask.
This is definitely a good place to ask. I can refresh and merge this PR now, and make a CNN/RNN version too. Bear with me...
Great! I eagerly await the refresh/merge of this PR!
Yeah, sorry about the delay. I'll jump on it this afternoon.
# Conflicts:
#	test/test_modules.py
#	torchrl/modules/models/multiagent.py
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/rl/1497
Note: Links to docs will display an error until the docs builds have been completed.
❌ 6 New Failures, 32 Unrelated Failures (likely due to flakiness present on trunk), as of commit 0905f9f with merge base 67f659c.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
cc @matteobettini
By the way, rather than doing a multi-agent MLP, a multi-agent ConvNet, a multi-agent LSTM, etc., should we not consider one way of batching ops in whichever nn.Module for multi-agent, and (maybe) just subclass that for every class?
Seems a bit repetitive, no? Isn't there part of the logic we can recycle?
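One possible shape of that idea, sketched under stated assumptions (class and method names here are hypothetical, not TorchRL's API): a base class that owns the per-agent batching, with subclasses only deciding how to build a single-agent network.

```python
import torch
from torch import nn
from torch.func import functional_call, stack_module_state, vmap

class MultiAgentNetBase(nn.Module):
    """Hypothetical sketch: keep the per-agent vmap/sharing logic in one base
    class; subclasses only say how to build a single-agent network."""

    def __init__(self, n_agents: int, share_params: bool, **net_kwargs):
        super().__init__()
        self.share_params = share_params
        n_copies = 1 if share_params else n_agents
        self.agent_networks = nn.ModuleList(
            [self._build_single_net(**net_kwargs) for _ in range(n_copies)]
        )

    def _build_single_net(self, **net_kwargs) -> nn.Module:
        raise NotImplementedError

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (*batch, n_agents, n_inputs); agent dim at -2.
        template = self.agent_networks[0]
        if self.share_params:
            # Same weights for every agent: ordinary broadcasting suffices.
            return template(x)
        # Unshared: stack per-agent params and vmap over the agent dim.
        params, _ = stack_module_state(list(self.agent_networks))
        def call_one(p, xi):
            return functional_call(template, p, (xi,))
        return vmap(call_one, in_dims=(0, -2), out_dims=-2)(params, x)

class MultiAgentMLP(MultiAgentNetBase):
    def _build_single_net(self, n_inputs: int, n_outputs: int) -> nn.Module:
        return nn.Sequential(nn.Linear(n_inputs, n_outputs), nn.Tanh())

mlp = MultiAgentMLP(n_agents=3, share_params=False, n_inputs=4, n_outputs=2)
print(mlp(torch.randn(5, 3, 4)).shape)   # torch.Size([5, 3, 2])
```

A CNN or RNN variant would then only override `_build_single_net` (modules with buffers, e.g. BatchNorm, would additionally need the stacked buffers passed through `functional_call`).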
TODO: