-
Notifications
You must be signed in to change notification settings - Fork 336
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[DRAFT, Example] Add MCTS example #2796
base: gh/kurtamohler/5/base
Are you sure you want to change the base?
Conversation
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/rl/2796
Note: Links to docs will display an error until the docs builds have been completed. ❌ 3 New Failures, 2 Unrelated FailuresAs of commit d90202c with merge base d4f8846 ( NEW FAILURES - The following jobs have failed:
BROKEN TRUNK - The following jobs failed but were present on the merge base:👉 Rebase onto the `viable/strict` branch to avoid these failures
This comment was automatically generated by Dr. CI and updates every 15 minutes. |
ghstack-source-id: 5dc5cbdb68a621e14617734c386aef6e91edbda3 Pull Request resolved: #2796
This seems to work, but at the moment, it is about 100x slower than the one I implemented outside of TorchRL here. I will see what I can do to speed it up. Once I improve performance, then I'll think about how to add a good API for it |
ghstack-source-id: 512f8540518396b5beb68bb74aafaf8638f44156 Pull Request resolved: #2796
|
Name | Max | Mean | Ops | Ops on Repo HEAD
|
Change |
---|---|---|---|---|---|
test_simple | 0.6175s | 0.5222s | 1.9151 Ops/s | 1.8972 Ops/s | |
test_transformed | 1.1331s | 1.0286s | 0.9722 Ops/s | 0.9593 Ops/s | |
test_serial | 1.6266s | 1.5324s | 0.6526 Ops/s | 0.6493 Ops/s | |
test_parallel | 1.3983s | 1.2987s | 0.7700 Ops/s | 0.7631 Ops/s | |
test_step_mdp_speed[True-True-True-True-True] | 0.1843ms | 30.9592μs | 32.3006 KOps/s | 32.1457 KOps/s | |
test_step_mdp_speed[True-True-True-True-False] | 40.3260μs | 17.8510μs | 56.0193 KOps/s | 54.0759 KOps/s | |
test_step_mdp_speed[True-True-True-False-True] | 75.6930μs | 17.8048μs | 56.1645 KOps/s | 56.9130 KOps/s | |
test_step_mdp_speed[True-True-True-False-False] | 64.5110μs | 9.9876μs | 100.1243 KOps/s | 95.8146 KOps/s | |
test_step_mdp_speed[True-True-False-True-True] | 0.1166ms | 32.7434μs | 30.5405 KOps/s | 30.0040 KOps/s | |
test_step_mdp_speed[True-True-False-True-False] | 52.0980μs | 19.5504μs | 51.1498 KOps/s | 48.2528 KOps/s | |
test_step_mdp_speed[True-True-False-False-True] | 75.0110μs | 19.1025μs | 52.3491 KOps/s | 52.6426 KOps/s | |
test_step_mdp_speed[True-True-False-False-False] | 36.4880μs | 11.9171μs | 83.9132 KOps/s | 81.9139 KOps/s | |
test_step_mdp_speed[True-False-True-True-True] | 93.9460μs | 34.0557μs | 29.3637 KOps/s | 28.9388 KOps/s | |
test_step_mdp_speed[True-False-True-True-False] | 65.6730μs | 21.4875μs | 46.5387 KOps/s | 44.9446 KOps/s | |
test_step_mdp_speed[True-False-True-False-True] | 0.6261ms | 19.0157μs | 52.5880 KOps/s | 52.0208 KOps/s | |
test_step_mdp_speed[True-False-True-False-False] | 68.7490μs | 11.8667μs | 84.2691 KOps/s | 81.7045 KOps/s | |
test_step_mdp_speed[True-False-False-True-True] | 79.9600μs | 35.7080μs | 28.0049 KOps/s | 27.7468 KOps/s | |
test_step_mdp_speed[True-False-False-True-False] | 77.2650μs | 23.2044μs | 43.0952 KOps/s | 41.6326 KOps/s | |
test_step_mdp_speed[True-False-False-False-True] | 73.8390μs | 20.7081μs | 48.2902 KOps/s | 47.6512 KOps/s | |
test_step_mdp_speed[True-False-False-False-False] | 48.7720μs | 13.5803μs | 73.6358 KOps/s | 71.4829 KOps/s | |
test_step_mdp_speed[False-True-True-True-True] | 0.1002ms | 34.2497μs | 29.1974 KOps/s | 28.2334 KOps/s | |
test_step_mdp_speed[False-True-True-True-False] | 46.8680μs | 21.4981μs | 46.5158 KOps/s | 44.6556 KOps/s | |
test_step_mdp_speed[False-True-True-False-True] | 72.9070μs | 21.9846μs | 45.4863 KOps/s | 45.6100 KOps/s | |
test_step_mdp_speed[False-True-True-False-False] | 42.3690μs | 13.2241μs | 75.6193 KOps/s | 73.2419 KOps/s | |
test_step_mdp_speed[False-True-False-True-True] | 96.0910μs | 36.0289μs | 27.7555 KOps/s | 27.2927 KOps/s | |
test_step_mdp_speed[False-True-False-True-False] | 80.5710μs | 23.3483μs | 42.8297 KOps/s | 41.4139 KOps/s | |
test_step_mdp_speed[False-True-False-False-True] | 2.4611ms | 23.6706μs | 42.2466 KOps/s | 41.8897 KOps/s | |
test_step_mdp_speed[False-True-False-False-False] | 66.6680μs | 14.9179μs | 67.0335 KOps/s | 64.0416 KOps/s | |
test_step_mdp_speed[False-False-True-True-True] | 0.2903ms | 37.8548μs | 26.4167 KOps/s | 26.1062 KOps/s | |
test_step_mdp_speed[False-False-True-True-False] | 78.6180μs | 24.8089μs | 40.3081 KOps/s | 38.1866 KOps/s | |
test_step_mdp_speed[False-False-True-False-True] | 75.4420μs | 23.5631μs | 42.4392 KOps/s | 41.8005 KOps/s | |
test_step_mdp_speed[False-False-True-False-False] | 43.7720μs | 15.0059μs | 66.6406 KOps/s | 64.1173 KOps/s | |
test_step_mdp_speed[False-False-False-True-True] | 0.6203ms | 39.4416μs | 25.3539 KOps/s | 24.8095 KOps/s | |
test_step_mdp_speed[False-False-False-True-False] | 54.7630μs | 26.6703μs | 37.4949 KOps/s | 34.4231 KOps/s | |
test_step_mdp_speed[False-False-False-False-True] | 62.4670μs | 25.0515μs | 39.9178 KOps/s | 39.9885 KOps/s | |
test_step_mdp_speed[False-False-False-False-False] | 71.4440μs | 16.6912μs | 59.9117 KOps/s | 58.6803 KOps/s | |
test_values[generalized_advantage_estimate-True-True] | 12.2149ms | 9.6668ms | 103.4467 Ops/s | 98.9054 Ops/s | |
test_values[vec_generalized_advantage_estimate-True-True] | 28.4316ms | 26.7785ms | 37.3434 Ops/s | 40.8891 Ops/s | |
test_values[td0_return_estimate-False-False] | 0.2452ms | 0.1863ms | 5.3680 KOps/s | 5.1882 KOps/s | |
test_values[td1_return_estimate-False-False] | 25.3956ms | 24.4179ms | 40.9536 Ops/s | 40.5239 Ops/s | |
test_values[vec_td1_return_estimate-False-False] | 28.8388ms | 26.8837ms | 37.1973 Ops/s | 40.7472 Ops/s | |
test_values[td_lambda_return_estimate-True-False] | 37.8873ms | 34.7523ms | 28.7751 Ops/s | 27.6997 Ops/s | |
test_values[vec_td_lambda_return_estimate-True-False] | 29.6251ms | 26.9493ms | 37.1068 Ops/s | 40.7798 Ops/s | |
test_gae_speed[generalized_advantage_estimate-False-1-512] | 8.7033ms | 8.5157ms | 117.4303 Ops/s | 114.7094 Ops/s | |
test_gae_speed[vec_generalized_advantage_estimate-True-1-512] | 2.4398ms | 1.9736ms | 506.7004 Ops/s | 512.0098 Ops/s | |
test_gae_speed[vec_generalized_advantage_estimate-False-1-512] | 0.5033ms | 0.3689ms | 2.7107 KOps/s | 2.6915 KOps/s | |
test_gae_speed[vec_generalized_advantage_estimate-True-32-512] | 46.0872ms | 44.7122ms | 22.3653 Ops/s | 23.6552 Ops/s | |
test_gae_speed[vec_generalized_advantage_estimate-False-32-512] | 5.2718ms | 3.4994ms | 285.7606 Ops/s | 274.2531 Ops/s | |
test_dqn_speed[False-None] | 5.9663ms | 1.4293ms | 699.6426 Ops/s | 693.3207 Ops/s | |
test_dqn_speed[False-backward] | 2.0084ms | 1.9102ms | 523.4971 Ops/s | 515.3687 Ops/s | |
test_dqn_speed[True-None] | 0.7649ms | 0.4916ms | 2.0340 KOps/s | 1.9821 KOps/s | |
test_dqn_speed[True-backward] | 0.9608ms | 0.9138ms | 1.0943 KOps/s | 1.0102 KOps/s | |
test_dqn_speed[reduce-overhead-None] | 0.7424ms | 0.4994ms | 2.0024 KOps/s | 1.9877 KOps/s | |
test_dqn_speed[reduce-overhead-backward] | 0.9778ms | 0.9272ms | 1.0785 KOps/s | 1.0844 KOps/s | |
test_ddpg_speed[False-None] | 3.4381ms | 2.9239ms | 342.0134 Ops/s | 337.4566 Ops/s | |
test_ddpg_speed[False-backward] | 4.1578ms | 4.0571ms | 246.4826 Ops/s | 239.0507 Ops/s | |
test_ddpg_speed[True-None] | 1.4357ms | 1.2553ms | 796.6201 Ops/s | 791.1489 Ops/s | |
test_ddpg_speed[True-backward] | 2.2468ms | 2.1384ms | 467.6487 Ops/s | 459.5028 Ops/s | |
test_ddpg_speed[reduce-overhead-None] | 1.4014ms | 1.2397ms | 806.6783 Ops/s | 795.5446 Ops/s | |
test_ddpg_speed[reduce-overhead-backward] | 2.1963ms | 2.1486ms | 465.4210 Ops/s | 461.3308 Ops/s | |
test_sac_speed[False-None] | 9.7606ms | 8.1672ms | 122.4413 Ops/s | 120.4289 Ops/s | |
test_sac_speed[False-backward] | 11.3081ms | 10.9282ms | 91.5066 Ops/s | 89.5492 Ops/s | |
test_sac_speed[True-None] | 2.4606ms | 2.1268ms | 470.1884 Ops/s | 451.2203 Ops/s | |
test_sac_speed[True-backward] | 4.9933ms | 3.9501ms | 253.1608 Ops/s | 251.1789 Ops/s | |
test_sac_speed[reduce-overhead-None] | 2.6814ms | 2.1122ms | 473.4483 Ops/s | 428.4997 Ops/s | |
test_sac_speed[reduce-overhead-backward] | 3.8646ms | 3.7997ms | 263.1759 Ops/s | 234.0449 Ops/s | |
test_redq_speed[False-None] | 19.1162ms | 13.7037ms | 72.9728 Ops/s | 72.7626 Ops/s | |
test_redq_speed[False-backward] | 30.3747ms | 23.2842ms | 42.9475 Ops/s | 42.8210 Ops/s | |
test_redq_speed[True-None] | 6.1085ms | 5.5606ms | 179.8376 Ops/s | 176.3052 Ops/s | |
test_redq_speed[True-backward] | 14.0947ms | 13.4996ms | 74.0765 Ops/s | 75.0943 Ops/s | |
test_redq_speed[reduce-overhead-None] | 7.2820ms | 5.8963ms | 169.5986 Ops/s | 171.6651 Ops/s | |
test_redq_speed[reduce-overhead-backward] | 15.1155ms | 13.7475ms | 72.7406 Ops/s | 73.9759 Ops/s | |
test_redq_deprec_speed[False-None] | 15.9083ms | 14.1763ms | 70.5401 Ops/s | 73.2660 Ops/s | |
test_redq_deprec_speed[False-backward] | 22.5850ms | 20.3131ms | 49.2293 Ops/s | 50.2689 Ops/s | |
test_redq_deprec_speed[True-None] | 5.5060ms | 4.4648ms | 223.9755 Ops/s | 243.8710 Ops/s | |
test_redq_deprec_speed[True-backward] | 10.0150ms | 9.3619ms | 106.8160 Ops/s | 107.6145 Ops/s | |
test_redq_deprec_speed[reduce-overhead-None] | 5.0637ms | 4.0620ms | 246.1847 Ops/s | 237.8016 Ops/s | |
test_redq_deprec_speed[reduce-overhead-backward] | 9.5742ms | 9.0558ms | 110.4263 Ops/s | 110.1511 Ops/s | |
test_td3_speed[False-None] | 10.1384ms | 8.3553ms | 119.6852 Ops/s | 119.9756 Ops/s | |
test_td3_speed[False-backward] | 11.9334ms | 11.0128ms | 90.8035 Ops/s | 91.3422 Ops/s | |
test_td3_speed[True-None] | 2.0747ms | 1.9124ms | 522.8933 Ops/s | 536.3619 Ops/s | |
test_td3_speed[True-backward] | 4.1253ms | 3.8748ms | 258.0766 Ops/s | 286.6193 Ops/s | |
test_td3_speed[reduce-overhead-None] | 2.0516ms | 1.8752ms | 533.2886 Ops/s | 526.9510 Ops/s | |
test_td3_speed[reduce-overhead-backward] | 3.8828ms | 3.6854ms | 271.3445 Ops/s | 283.0428 Ops/s | |
test_cql_speed[False-None] | 40.9889ms | 38.2333ms | 26.1552 Ops/s | 26.6875 Ops/s | |
test_cql_speed[False-backward] | 57.1261ms | 48.3787ms | 20.6702 Ops/s | 20.4153 Ops/s | |
test_cql_speed[True-None] | 17.9840ms | 16.7885ms | 59.5647 Ops/s | 59.3220 Ops/s | |
test_cql_speed[True-backward] | 24.7962ms | 23.5321ms | 42.4952 Ops/s | 41.1774 Ops/s | |
test_cql_speed[reduce-overhead-None] | 17.7399ms | 16.5337ms | 60.4825 Ops/s | 59.1388 Ops/s | |
test_cql_speed[reduce-overhead-backward] | 25.3222ms | 23.6639ms | 42.2584 Ops/s | 41.4083 Ops/s | |
test_a2c_speed[False-None] | 8.3358ms | 7.3074ms | 136.8478 Ops/s | 130.7232 Ops/s | |
test_a2c_speed[False-backward] | 15.5036ms | 14.6760ms | 68.1383 Ops/s | 63.1199 Ops/s | |
test_a2c_speed[True-None] | 4.9574ms | 3.7894ms | 263.8914 Ops/s | 259.7024 Ops/s | |
test_a2c_speed[True-backward] | 11.5226ms | 10.4817ms | 95.4048 Ops/s | 91.4282 Ops/s | |
test_a2c_speed[reduce-overhead-None] | 4.8052ms | 3.7722ms | 265.0976 Ops/s | 257.5680 Ops/s | |
test_a2c_speed[reduce-overhead-backward] | 12.3357ms | 10.5274ms | 94.9903 Ops/s | 90.6309 Ops/s | |
test_ppo_speed[False-None] | 8.6519ms | 7.6152ms | 131.3159 Ops/s | 120.8475 Ops/s | |
test_ppo_speed[False-backward] | 16.5301ms | 15.3460ms | 65.1636 Ops/s | 63.8162 Ops/s | |
test_ppo_speed[True-None] | 4.4193ms | 4.1497ms | 240.9818 Ops/s | 223.8530 Ops/s | |
test_ppo_speed[True-backward] | 10.8265ms | 10.2466ms | 97.5932 Ops/s | 91.7465 Ops/s | |
test_ppo_speed[reduce-overhead-None] | 5.1574ms | 4.1344ms | 241.8751 Ops/s | 222.9146 Ops/s | |
test_ppo_speed[reduce-overhead-backward] | 10.7295ms | 10.1501ms | 98.5213 Ops/s | 90.3343 Ops/s | |
test_reinforce_speed[False-None] | 7.4124ms | 6.6319ms | 150.7869 Ops/s | 143.0439 Ops/s | |
test_reinforce_speed[False-backward] | 10.1645ms | 9.9770ms | 100.2308 Ops/s | 94.5207 Ops/s | |
test_reinforce_speed[True-None] | 3.9429ms | 3.1376ms | 318.7148 Ops/s | 273.7543 Ops/s | |
test_reinforce_speed[True-backward] | 11.6127ms | 9.6688ms | 103.4251 Ops/s | 98.4432 Ops/s | |
test_reinforce_speed[reduce-overhead-None] | 3.7992ms | 3.1254ms | 319.9592 Ops/s | 277.2036 Ops/s | |
test_reinforce_speed[reduce-overhead-backward] | 10.0521ms | 9.1686ms | 109.0674 Ops/s | 104.0447 Ops/s | |
test_iql_speed[False-None] | 38.0960ms | 33.6432ms | 29.7237 Ops/s | 29.2264 Ops/s | |
test_iql_speed[False-backward] | 0.3730s | 52.4790ms | 19.0553 Ops/s | 21.0787 Ops/s | |
test_iql_speed[True-None] | 13.2750ms | 11.6641ms | 85.7329 Ops/s | 83.6217 Ops/s | |
test_iql_speed[True-backward] | 23.5424ms | 22.9789ms | 43.5182 Ops/s | 43.4024 Ops/s | |
test_iql_speed[reduce-overhead-None] | 16.8024ms | 12.0812ms | 82.7734 Ops/s | 85.0949 Ops/s | |
test_iql_speed[reduce-overhead-backward] | 24.6331ms | 23.1251ms | 43.2430 Ops/s | 43.6621 Ops/s | |
test_rb_sample[TensorDictReplayBuffer-ListStorage-RandomSampler-4000] | 5.6602ms | 5.2261ms | 191.3473 Ops/s | 200.5830 Ops/s | |
test_rb_sample[TensorDictReplayBuffer-LazyMemmapStorage-RandomSampler-10000] | 2.5219ms | 0.5559ms | 1.7988 KOps/s | 1.8313 KOps/s | |
test_rb_sample[TensorDictReplayBuffer-LazyTensorStorage-RandomSampler-10000] | 1.0544ms | 0.5362ms | 1.8649 KOps/s | 1.8904 KOps/s | |
test_rb_sample[TensorDictReplayBuffer-ListStorage-SamplerWithoutReplacement-4000] | 7.5757ms | 4.9707ms | 201.1781 Ops/s | 204.6013 Ops/s | |
test_rb_sample[TensorDictReplayBuffer-LazyMemmapStorage-SamplerWithoutReplacement-10000] | 1.2941ms | 0.5414ms | 1.8471 KOps/s | 1.8596 KOps/s | |
test_rb_sample[TensorDictReplayBuffer-LazyTensorStorage-SamplerWithoutReplacement-10000] | 0.7420ms | 0.5087ms | 1.9658 KOps/s | 1.9493 KOps/s | |
test_rb_sample[TensorDictReplayBuffer-LazyMemmapStorage-sampler6-10000] | 2.1870ms | 1.7107ms | 584.5624 Ops/s | 582.6143 Ops/s | |
test_rb_sample[TensorDictReplayBuffer-LazyTensorStorage-sampler7-10000] | 2.3828ms | 1.6318ms | 612.8296 Ops/s | 609.5040 Ops/s | |
test_rb_sample[TensorDictPrioritizedReplayBuffer-ListStorage-None-4000] | 5.0834ms | 4.8867ms | 204.6360 Ops/s | 202.1230 Ops/s | |
test_rb_sample[TensorDictPrioritizedReplayBuffer-LazyMemmapStorage-None-10000] | 1.3448ms | 0.6776ms | 1.4757 KOps/s | 1.4559 KOps/s | |
test_rb_sample[TensorDictPrioritizedReplayBuffer-LazyTensorStorage-None-10000] | 0.9163ms | 0.6514ms | 1.5351 KOps/s | 1.4881 KOps/s | |
test_rb_iterate[TensorDictReplayBuffer-ListStorage-RandomSampler-4000] | 5.0678ms | 4.9001ms | 204.0785 Ops/s | 201.4258 Ops/s | |
test_rb_iterate[TensorDictReplayBuffer-LazyMemmapStorage-RandomSampler-10000] | 2.2522ms | 0.5530ms | 1.8084 KOps/s | 1.8055 KOps/s | |
test_rb_iterate[TensorDictReplayBuffer-LazyTensorStorage-RandomSampler-10000] | 0.7588ms | 0.5212ms | 1.9186 KOps/s | 1.8645 KOps/s | |
test_rb_iterate[TensorDictReplayBuffer-ListStorage-SamplerWithoutReplacement-4000] | 5.6946ms | 4.7895ms | 208.7915 Ops/s | 207.2582 Ops/s | |
test_rb_iterate[TensorDictReplayBuffer-LazyMemmapStorage-SamplerWithoutReplacement-10000] | 2.4728ms | 0.5441ms | 1.8379 KOps/s | 1.8164 KOps/s | |
test_rb_iterate[TensorDictReplayBuffer-LazyTensorStorage-SamplerWithoutReplacement-10000] | 0.7444ms | 0.5090ms | 1.9648 KOps/s | 1.9594 KOps/s | |
test_rb_iterate[TensorDictPrioritizedReplayBuffer-ListStorage-None-4000] | 5.1713ms | 4.8932ms | 204.3648 Ops/s | 199.8262 Ops/s | |
test_rb_iterate[TensorDictPrioritizedReplayBuffer-LazyMemmapStorage-None-10000] | 2.8079ms | 0.6826ms | 1.4651 KOps/s | 1.4416 KOps/s | |
test_rb_iterate[TensorDictPrioritizedReplayBuffer-LazyTensorStorage-None-10000] | 0.8934ms | 0.6497ms | 1.5392 KOps/s | 1.5141 KOps/s | |
test_rb_populate[TensorDictReplayBuffer-ListStorage-RandomSampler-400] | 5.9832ms | 4.3389ms | 230.4725 Ops/s | 237.4878 Ops/s | |
test_rb_populate[TensorDictReplayBuffer-LazyMemmapStorage-RandomSampler-400] | 7.0568ms | 2.4053ms | 415.7508 Ops/s | 416.5131 Ops/s | |
test_rb_populate[TensorDictReplayBuffer-LazyTensorStorage-RandomSampler-400] | 6.7151ms | 1.4342ms | 697.2502 Ops/s | 672.7578 Ops/s | |
test_rb_populate[TensorDictReplayBuffer-ListStorage-SamplerWithoutReplacement-400] | 5.7846ms | 4.3527ms | 229.7436 Ops/s | 240.5738 Ops/s | |
test_rb_populate[TensorDictReplayBuffer-LazyMemmapStorage-SamplerWithoutReplacement-400] | 6.8833ms | 2.3859ms | 419.1350 Ops/s | 407.5892 Ops/s | |
test_rb_populate[TensorDictReplayBuffer-LazyTensorStorage-SamplerWithoutReplacement-400] | 0.4726s | 10.8645ms | 92.0427 Ops/s | 754.0772 Ops/s | |
test_rb_populate[TensorDictPrioritizedReplayBuffer-ListStorage-None-400] | 6.0097ms | 4.5155ms | 221.4573 Ops/s | 29.8296 Ops/s | |
test_rb_populate[TensorDictPrioritizedReplayBuffer-LazyMemmapStorage-None-400] | 9.6705ms | 2.5589ms | 390.7878 Ops/s | 439.2559 Ops/s | |
test_rb_populate[TensorDictPrioritizedReplayBuffer-LazyTensorStorage-None-400] | 5.0974ms | 1.5430ms | 648.0834 Ops/s | 712.3835 Ops/s | |
test_rb_extend_sample[ReplayBuffer-LazyTensorStorage-RandomSampler-10000-10000-100-True] | 12.2878ms | 11.8856ms | 84.1355 Ops/s | 84.0272 Ops/s | |
test_rb_extend_sample[ReplayBuffer-LazyTensorStorage-RandomSampler-10000-10000-100-False] | 17.3090ms | 14.4916ms | 69.0056 Ops/s | 68.9879 Ops/s | |
test_rb_extend_sample[ReplayBuffer-LazyTensorStorage-RandomSampler-100000-10000-100-True] | 21.7794ms | 20.7989ms | 48.0795 Ops/s | 46.9459 Ops/s | |
test_rb_extend_sample[ReplayBuffer-LazyTensorStorage-RandomSampler-100000-10000-100-False] | 15.8797ms | 14.7445ms | 67.8221 Ops/s | 67.1953 Ops/s | |
test_rb_extend_sample[ReplayBuffer-LazyTensorStorage-RandomSampler-1000000-10000-100-True] | 22.4090ms | 20.8891ms | 47.8719 Ops/s | 48.3941 Ops/s | |
test_rb_extend_sample[ReplayBuffer-LazyTensorStorage-RandomSampler-1000000-10000-100-False] | 17.7159ms | 16.1549ms | 61.9009 Ops/s | 62.0356 Ops/s |
|
Name | Max | Mean | Ops | Ops on Repo HEAD
|
Change |
---|---|---|---|---|---|
test_simple | 0.9024s | 0.8111s | 1.2330 Ops/s | 1.2443 Ops/s | |
test_transformed | 1.4885s | 1.3990s | 0.7148 Ops/s | 0.6881 Ops/s | |
test_serial | 2.2502s | 2.2470s | 0.4450 Ops/s | 0.4288 Ops/s | |
test_parallel | 1.9690s | 1.8759s | 0.5331 Ops/s | 0.5415 Ops/s | |
test_step_mdp_speed[True-True-True-True-True] | 0.1128ms | 37.9405μs | 26.3571 KOps/s | 24.6534 KOps/s | |
test_step_mdp_speed[True-True-True-True-False] | 56.5810μs | 22.5050μs | 44.4345 KOps/s | 42.3343 KOps/s | |
test_step_mdp_speed[True-True-True-False-True] | 49.9010μs | 21.3507μs | 46.8369 KOps/s | 44.0371 KOps/s | |
test_step_mdp_speed[True-True-True-False-False] | 0.1080ms | 12.4725μs | 80.1767 KOps/s | 76.9529 KOps/s | |
test_step_mdp_speed[True-True-False-True-True] | 77.9210μs | 40.6749μs | 24.5852 KOps/s | 23.0614 KOps/s | |
test_step_mdp_speed[True-True-False-True-False] | 96.7020μs | 24.7458μs | 40.4109 KOps/s | 38.5441 KOps/s | |
test_step_mdp_speed[True-True-False-False-True] | 58.1210μs | 24.0430μs | 41.5921 KOps/s | 40.3530 KOps/s | |
test_step_mdp_speed[True-True-False-False-False] | 42.7100μs | 15.0501μs | 66.4446 KOps/s | 63.7020 KOps/s | |
test_step_mdp_speed[True-False-True-True-True] | 71.0010μs | 44.0232μs | 22.7153 KOps/s | 21.6358 KOps/s | |
test_step_mdp_speed[True-False-True-True-False] | 55.7710μs | 27.0425μs | 36.9789 KOps/s | 35.2487 KOps/s | |
test_step_mdp_speed[True-False-True-False-True] | 51.4310μs | 24.0759μs | 41.5353 KOps/s | 39.5241 KOps/s | |
test_step_mdp_speed[True-False-True-False-False] | 85.2210μs | 14.8622μs | 67.2848 KOps/s | 63.9388 KOps/s | |
test_step_mdp_speed[True-False-False-True-True] | 81.8410μs | 45.7442μs | 21.8607 KOps/s | 21.2243 KOps/s | |
test_step_mdp_speed[True-False-False-True-False] | 56.8910μs | 29.5695μs | 33.8187 KOps/s | 32.5516 KOps/s | |
test_step_mdp_speed[True-False-False-False-True] | 56.3410μs | 25.8361μs | 38.7055 KOps/s | 37.0513 KOps/s | |
test_step_mdp_speed[True-False-False-False-False] | 52.0600μs | 16.7003μs | 59.8793 KOps/s | 56.0811 KOps/s | |
test_step_mdp_speed[False-True-True-True-True] | 72.5410μs | 43.7729μs | 22.8452 KOps/s | 22.3047 KOps/s | |
test_step_mdp_speed[False-True-True-True-False] | 57.0110μs | 26.6372μs | 37.5414 KOps/s | 34.8494 KOps/s | |
test_step_mdp_speed[False-True-True-False-True] | 2.6962ms | 27.3879μs | 36.5125 KOps/s | 34.5241 KOps/s | |
test_step_mdp_speed[False-True-True-False-False] | 43.8510μs | 16.6154μs | 60.1852 KOps/s | 57.8275 KOps/s | |
test_step_mdp_speed[False-True-False-True-True] | 0.1108ms | 45.7779μs | 21.8446 KOps/s | 21.1496 KOps/s | |
test_step_mdp_speed[False-True-False-True-False] | 60.9500μs | 29.4526μs | 33.9529 KOps/s | 32.4557 KOps/s | |
test_step_mdp_speed[False-True-False-False-True] | 61.0010μs | 29.8506μs | 33.5002 KOps/s | 31.8843 KOps/s | |
test_step_mdp_speed[False-True-False-False-False] | 46.5500μs | 18.8811μs | 52.9631 KOps/s | 52.8693 KOps/s | |
test_step_mdp_speed[False-False-True-True-True] | 74.0310μs | 47.6371μs | 20.9920 KOps/s | 19.8215 KOps/s | |
test_step_mdp_speed[False-False-True-True-False] | 65.9410μs | 31.4016μs | 31.8455 KOps/s | 29.5131 KOps/s | |
test_step_mdp_speed[False-False-True-False-True] | 91.1720μs | 29.8988μs | 33.4461 KOps/s | 32.5258 KOps/s | |
test_step_mdp_speed[False-False-True-False-False] | 47.6100μs | 18.6935μs | 53.4946 KOps/s | 51.0158 KOps/s | |
test_step_mdp_speed[False-False-False-True-True] | 90.0210μs | 49.9215μs | 20.0314 KOps/s | 19.3464 KOps/s | |
test_step_mdp_speed[False-False-False-True-False] | 0.2214ms | 33.7618μs | 29.6193 KOps/s | 27.9720 KOps/s | |
test_step_mdp_speed[False-False-False-False-True] | 71.4610μs | 31.4490μs | 31.7975 KOps/s | 30.6030 KOps/s | |
test_step_mdp_speed[False-False-False-False-False] | 75.5510μs | 20.9188μs | 47.8039 KOps/s | 45.1928 KOps/s | |
test_values[generalized_advantage_estimate-True-True] | 25.8726ms | 25.5169ms | 39.1897 Ops/s | 40.3473 Ops/s | |
test_values[vec_generalized_advantage_estimate-True-True] | 0.1152s | 3.2064ms | 311.8774 Ops/s | 322.9260 Ops/s | |
test_values[td0_return_estimate-False-False] | 0.1080ms | 79.7268μs | 12.5428 KOps/s | 12.3756 KOps/s | |
test_values[td1_return_estimate-False-False] | 61.9897ms | 58.6170ms | 17.0599 Ops/s | 17.8557 Ops/s | |
test_values[vec_td1_return_estimate-False-False] | 1.3342ms | 1.0931ms | 914.8599 Ops/s | 923.8950 Ops/s | |
test_values[td_lambda_return_estimate-True-False] | 91.4606ms | 90.8398ms | 11.0084 Ops/s | 11.2554 Ops/s | |
test_values[vec_td_lambda_return_estimate-True-False] | 1.3979ms | 1.0930ms | 914.9106 Ops/s | 925.0125 Ops/s | |
test_gae_speed[generalized_advantage_estimate-False-1-512] | 25.3682ms | 25.0268ms | 39.9571 Ops/s | 40.8619 Ops/s | |
test_gae_speed[vec_generalized_advantage_estimate-True-1-512] | 1.0638ms | 0.7637ms | 1.3095 KOps/s | 1.3240 KOps/s | |
test_gae_speed[vec_generalized_advantage_estimate-False-1-512] | 0.8192ms | 0.6760ms | 1.4792 KOps/s | 1.4904 KOps/s | |
test_gae_speed[vec_generalized_advantage_estimate-True-32-512] | 1.6391ms | 1.4927ms | 669.9117 Ops/s | 674.6612 Ops/s | |
test_gae_speed[vec_generalized_advantage_estimate-False-32-512] | 0.8543ms | 0.6898ms | 1.4497 KOps/s | 1.4528 KOps/s | |
test_dqn_speed[False-None] | 6.9847ms | 1.4936ms | 669.5298 Ops/s | 654.6117 Ops/s | |
test_dqn_speed[False-backward] | 2.2650ms | 2.1067ms | 474.6657 Ops/s | 469.1643 Ops/s | |
test_dqn_speed[True-None] | 0.1587s | 0.6561ms | 1.5242 KOps/s | 1.6830 KOps/s | |
test_dqn_speed[True-backward] | 1.2755ms | 1.2244ms | 816.6934 Ops/s | 808.5532 Ops/s | |
test_dqn_speed[reduce-overhead-None] | 0.7346ms | 0.5804ms | 1.7230 KOps/s | 1.6579 KOps/s | |
test_dqn_speed[reduce-overhead-backward] | 1.1234ms | 1.0719ms | 932.9170 Ops/s | 1.0290 KOps/s | |
test_ddpg_speed[False-None] | 3.1480ms | 2.8152ms | 355.2142 Ops/s | 348.4543 Ops/s | |
test_ddpg_speed[False-backward] | 4.2932ms | 4.1772ms | 239.3947 Ops/s | 241.5302 Ops/s | |
test_ddpg_speed[True-None] | 1.5087ms | 1.3481ms | 741.7636 Ops/s | 743.1255 Ops/s | |
test_ddpg_speed[True-backward] | 2.5445ms | 2.4071ms | 415.4351 Ops/s | 408.7144 Ops/s | |
test_ddpg_speed[reduce-overhead-None] | 1.5139ms | 1.3554ms | 737.7933 Ops/s | 735.3389 Ops/s | |
test_ddpg_speed[reduce-overhead-backward] | 1.9290ms | 1.8889ms | 529.3967 Ops/s | 526.4949 Ops/s | |
test_sac_speed[False-None] | 8.3165ms | 7.9401ms | 125.9427 Ops/s | 123.4959 Ops/s | |
test_sac_speed[False-backward] | 11.3655ms | 10.8639ms | 92.0480 Ops/s | 90.4774 Ops/s | |
test_sac_speed[True-None] | 2.0478ms | 1.8579ms | 538.2365 Ops/s | 533.2831 Ops/s | |
test_sac_speed[True-backward] | 3.7468ms | 3.5898ms | 278.5664 Ops/s | 265.3952 Ops/s | |
test_sac_speed[reduce-overhead-None] | 20.5716ms | 11.6312ms | 85.9757 Ops/s | 85.1413 Ops/s | |
test_sac_speed[reduce-overhead-backward] | 1.7871ms | 1.6411ms | 609.3609 Ops/s | 541.1511 Ops/s | |
test_redq_speed[False-None] | 7.7279ms | 7.3865ms | 135.3822 Ops/s | 131.9948 Ops/s | |
test_redq_speed[False-backward] | 11.8368ms | 11.2586ms | 88.8208 Ops/s | 84.1872 Ops/s | |
test_redq_speed[True-None] | 2.7421ms | 2.3603ms | 423.6824 Ops/s | 420.9243 Ops/s | |
test_redq_speed[True-backward] | 4.3757ms | 4.2297ms | 236.4255 Ops/s | 228.5406 Ops/s | |
test_redq_speed[reduce-overhead-None] | 2.5255ms | 2.3694ms | 422.0469 Ops/s | 415.7063 Ops/s | |
test_redq_speed[reduce-overhead-backward] | 4.3999ms | 4.2318ms | 236.3063 Ops/s | 235.3771 Ops/s | |
test_redq_deprec_speed[False-None] | 9.6127ms | 8.9474ms | 111.7646 Ops/s | 110.1424 Ops/s | |
test_redq_deprec_speed[False-backward] | 12.7964ms | 12.2596ms | 81.5687 Ops/s | 80.1379 Ops/s | |
test_redq_deprec_speed[True-None] | 2.8498ms | 2.6713ms | 374.3536 Ops/s | 375.1299 Ops/s | |
test_redq_deprec_speed[True-backward] | 4.6540ms | 4.5006ms | 222.1909 Ops/s | 220.6760 Ops/s | |
test_redq_deprec_speed[reduce-overhead-None] | 2.8527ms | 2.6450ms | 378.0692 Ops/s | 373.1767 Ops/s | |
test_redq_deprec_speed[reduce-overhead-backward] | 4.8713ms | 4.4958ms | 222.4282 Ops/s | 220.5824 Ops/s | |
test_td3_speed[False-None] | 7.9778ms | 7.8425ms | 127.5111 Ops/s | 126.1581 Ops/s | |
test_td3_speed[False-backward] | 11.0507ms | 10.3876ms | 96.2690 Ops/s | 94.8747 Ops/s | |
test_td3_speed[True-None] | 1.7693ms | 1.6879ms | 592.4470 Ops/s | 593.1977 Ops/s | |
test_td3_speed[True-backward] | 3.4222ms | 3.3740ms | 296.3884 Ops/s | 309.4337 Ops/s | |
test_td3_speed[reduce-overhead-None] | 49.3250ms | 25.3956ms | 39.3770 Ops/s | 41.0234 Ops/s | |
test_td3_speed[reduce-overhead-backward] | 1.6861ms | 1.5355ms | 651.2495 Ops/s | 715.7298 Ops/s | |
test_cql_speed[False-None] | 16.9714ms | 16.5848ms | 60.2963 Ops/s | 59.4008 Ops/s | |
test_cql_speed[False-backward] | 22.6609ms | 22.1009ms | 45.2469 Ops/s | 45.4337 Ops/s | |
test_cql_speed[True-None] | 3.7199ms | 3.4320ms | 291.3763 Ops/s | 284.1761 Ops/s | |
test_cql_speed[True-backward] | 6.0366ms | 5.6507ms | 176.9703 Ops/s | 179.3846 Ops/s | |
test_cql_speed[reduce-overhead-None] | 21.1306ms | 12.8564ms | 77.7821 Ops/s | 76.7557 Ops/s | |
test_cql_speed[reduce-overhead-backward] | 2.0089ms | 1.8657ms | 535.9866 Ops/s | 526.9308 Ops/s | |
test_a2c_speed[False-None] | 3.3172ms | 3.1486ms | 317.6052 Ops/s | 307.0070 Ops/s | |
test_a2c_speed[False-backward] | 6.5640ms | 6.0673ms | 164.8174 Ops/s | 161.8745 Ops/s | |
test_a2c_speed[True-None] | 1.5784ms | 1.3610ms | 734.7501 Ops/s | 730.9896 Ops/s | |
test_a2c_speed[True-backward] | 2.9904ms | 2.9003ms | 344.7964 Ops/s | 341.7160 Ops/s | |
test_a2c_speed[reduce-overhead-None] | 15.1148ms | 8.3339ms | 119.9916 Ops/s | 115.1253 Ops/s | |
test_a2c_speed[reduce-overhead-backward] | 1.6697ms | 1.4589ms | 685.4628 Ops/s | 671.3371 Ops/s | |
test_ppo_speed[False-None] | 3.8727ms | 3.6663ms | 272.7569 Ops/s | 266.6039 Ops/s | |
test_ppo_speed[False-backward] | 7.1045ms | 6.7950ms | 147.1664 Ops/s | 144.9106 Ops/s | |
test_ppo_speed[True-None] | 1.5968ms | 1.4194ms | 704.5428 Ops/s | 680.7189 Ops/s | |
test_ppo_speed[True-backward] | 3.0818ms | 3.0349ms | 329.4975 Ops/s | 322.6773 Ops/s | |
test_ppo_speed[reduce-overhead-None] | 1.1263ms | 0.9704ms | 1.0305 KOps/s | 1.0329 KOps/s | |
test_ppo_speed[reduce-overhead-backward] | 1.7174ms | 1.5628ms | 639.8803 Ops/s | 683.9738 Ops/s | |
test_reinforce_speed[False-None] | 2.4955ms | 2.2231ms | 449.8180 Ops/s | 435.7217 Ops/s | |
test_reinforce_speed[False-backward] | 3.7709ms | 3.3569ms | 297.8943 Ops/s | 300.9179 Ops/s | |
test_reinforce_speed[True-None] | 1.5260ms | 1.3055ms | 765.9706 Ops/s | 753.8459 Ops/s | |
test_reinforce_speed[True-backward] | 3.1496ms | 3.0530ms | 327.5500 Ops/s | 344.6038 Ops/s | |
test_reinforce_speed[reduce-overhead-None] | 17.5533ms | 9.7120ms | 102.9655 Ops/s | 102.5919 Ops/s | |
test_reinforce_speed[reduce-overhead-backward] | 1.8220ms | 1.6455ms | 607.7313 Ops/s | 589.9837 Ops/s | |
test_iql_speed[False-None] | 9.4865ms | 9.0723ms | 110.2252 Ops/s | 106.8639 Ops/s | |
test_iql_speed[False-backward] | 13.4199ms | 13.0443ms | 76.6617 Ops/s | 74.8610 Ops/s | |
test_iql_speed[True-None] | 2.4219ms | 2.2411ms | 446.2115 Ops/s | 434.3926 Ops/s | |
test_iql_speed[True-backward] | 5.2768ms | 4.8775ms | 205.0246 Ops/s | 199.7520 Ops/s | |
test_iql_speed[reduce-overhead-None] | 0.4922s | 12.5089ms | 79.9428 Ops/s | 92.9854 Ops/s | |
test_iql_speed[reduce-overhead-backward] | 2.2373ms | 2.1007ms | 476.0209 Ops/s | 462.6554 Ops/s | |
test_rb_sample[TensorDictReplayBuffer-ListStorage-RandomSampler-4000] | 7.5246ms | 6.0896ms | 164.2138 Ops/s | 162.0042 Ops/s | |
test_rb_sample[TensorDictReplayBuffer-LazyMemmapStorage-RandomSampler-10000] | 0.6269ms | 0.3277ms | 3.0513 KOps/s | 2.9986 KOps/s | |
test_rb_sample[TensorDictReplayBuffer-LazyTensorStorage-RandomSampler-10000] | 0.6024ms | 0.2947ms | 3.3928 KOps/s | 3.1750 KOps/s | |
test_rb_sample[TensorDictReplayBuffer-ListStorage-SamplerWithoutReplacement-4000] | 6.1345ms | 5.8189ms | 171.8527 Ops/s | 169.2296 Ops/s | |
test_rb_sample[TensorDictReplayBuffer-LazyMemmapStorage-SamplerWithoutReplacement-10000] | 1.6940ms | 0.2954ms | 3.3857 KOps/s | 3.2184 KOps/s | |
test_rb_sample[TensorDictReplayBuffer-LazyTensorStorage-SamplerWithoutReplacement-10000] | 0.5426ms | 0.2842ms | 3.5185 KOps/s | 3.2317 KOps/s | |
test_rb_sample[TensorDictReplayBuffer-LazyMemmapStorage-sampler6-10000] | 1.7246ms | 1.4246ms | 701.9687 Ops/s | 702.6461 Ops/s | |
test_rb_sample[TensorDictReplayBuffer-LazyTensorStorage-sampler7-10000] | 1.6187ms | 1.2870ms | 776.9716 Ops/s | 756.7469 Ops/s | |
test_rb_sample[TensorDictPrioritizedReplayBuffer-ListStorage-None-4000] | 6.1443ms | 5.9899ms | 166.9465 Ops/s | 162.8856 Ops/s | |
test_rb_sample[TensorDictPrioritizedReplayBuffer-LazyMemmapStorage-None-10000] | 1.8475ms | 0.4070ms | 2.4572 KOps/s | 2.0703 KOps/s | |
test_rb_sample[TensorDictPrioritizedReplayBuffer-LazyTensorStorage-None-10000] | 0.6858ms | 0.4423ms | 2.2609 KOps/s | 2.5725 KOps/s | |
test_rb_iterate[TensorDictReplayBuffer-ListStorage-RandomSampler-4000] | 5.9997ms | 5.8196ms | 171.8318 Ops/s | 167.2710 Ops/s | |
test_rb_iterate[TensorDictReplayBuffer-LazyMemmapStorage-RandomSampler-10000] | 1.0243ms | 0.2866ms | 3.4893 KOps/s | 3.0223 KOps/s | |
test_rb_iterate[TensorDictReplayBuffer-LazyTensorStorage-RandomSampler-10000] | 0.5544ms | 0.3708ms | 2.6971 KOps/s | 3.5840 KOps/s | |
test_rb_iterate[TensorDictReplayBuffer-ListStorage-SamplerWithoutReplacement-4000] | 8.8756ms | 5.8530ms | 170.8513 Ops/s | 170.2279 Ops/s | |
test_rb_iterate[TensorDictReplayBuffer-LazyMemmapStorage-SamplerWithoutReplacement-10000] | 0.8340ms | 0.2903ms | 3.4450 KOps/s | 2.8492 KOps/s | |
test_rb_iterate[TensorDictReplayBuffer-LazyTensorStorage-SamplerWithoutReplacement-10000] | 0.5855ms | 0.2859ms | 3.4976 KOps/s | 3.2314 KOps/s | |
test_rb_iterate[TensorDictPrioritizedReplayBuffer-ListStorage-None-4000] | 6.2519ms | 6.0007ms | 166.6478 Ops/s | 163.7038 Ops/s | |
test_rb_iterate[TensorDictPrioritizedReplayBuffer-LazyMemmapStorage-None-10000] | 1.8331ms | 0.5015ms | 1.9942 KOps/s | 2.1672 KOps/s | |
test_rb_iterate[TensorDictPrioritizedReplayBuffer-LazyTensorStorage-None-10000] | 0.6468ms | 0.4775ms | 2.0944 KOps/s | 2.0841 KOps/s | |
test_rb_populate[TensorDictReplayBuffer-ListStorage-RandomSampler-400] | 6.9246ms | 5.3408ms | 187.2375 Ops/s | 177.0212 Ops/s | |
test_rb_populate[TensorDictReplayBuffer-LazyMemmapStorage-RandomSampler-400] | 10.2242ms | 2.0637ms | 484.5624 Ops/s | 420.0322 Ops/s | |
test_rb_populate[TensorDictReplayBuffer-LazyTensorStorage-RandomSampler-400] | 6.6887ms | 1.1891ms | 840.9556 Ops/s | 755.8146 Ops/s | |
test_rb_populate[TensorDictReplayBuffer-ListStorage-SamplerWithoutReplacement-400] | 0.4696s | 14.6946ms | 68.0523 Ops/s | 176.4254 Ops/s | |
test_rb_populate[TensorDictReplayBuffer-LazyMemmapStorage-SamplerWithoutReplacement-400] | 5.6425ms | 1.9905ms | 502.3946 Ops/s | 426.3204 Ops/s | |
test_rb_populate[TensorDictReplayBuffer-LazyTensorStorage-SamplerWithoutReplacement-400] | 10.2162ms | 1.2768ms | 783.2043 Ops/s | 859.7147 Ops/s | |
test_rb_populate[TensorDictPrioritizedReplayBuffer-ListStorage-None-400] | 8.9207ms | 5.6278ms | 177.6896 Ops/s | 30.8883 Ops/s | |
test_rb_populate[TensorDictPrioritizedReplayBuffer-LazyMemmapStorage-None-400] | 8.3153ms | 2.1779ms | 459.1492 Ops/s | 424.1024 Ops/s | |
test_rb_populate[TensorDictPrioritizedReplayBuffer-LazyTensorStorage-None-400] | 8.7596ms | 1.4062ms | 711.1345 Ops/s | 784.1174 Ops/s | |
test_rb_extend_sample[ReplayBuffer-LazyTensorStorage-RandomSampler-10000-10000-100-True] | 13.1347ms | 12.8182ms | 78.0138 Ops/s | 75.5464 Ops/s | |
test_rb_extend_sample[ReplayBuffer-LazyTensorStorage-RandomSampler-10000-10000-100-False] | 18.3582ms | 16.6753ms | 59.9691 Ops/s | 57.8135 Ops/s | |
test_rb_extend_sample[ReplayBuffer-LazyTensorStorage-RandomSampler-100000-10000-100-True] | 18.2302ms | 17.8490ms | 56.0256 Ops/s | 55.2187 Ops/s | |
test_rb_extend_sample[ReplayBuffer-LazyTensorStorage-RandomSampler-100000-10000-100-False] | 18.4385ms | 16.8757ms | 59.2568 Ops/s | 57.8430 Ops/s | |
test_rb_extend_sample[ReplayBuffer-LazyTensorStorage-RandomSampler-1000000-10000-100-True] | 18.1026ms | 17.3581ms | 57.6101 Ops/s | 55.2279 Ops/s | |
test_rb_extend_sample[ReplayBuffer-LazyTensorStorage-RandomSampler-1000000-10000-100-False] | 19.7164ms | 17.9428ms | 55.7325 Ops/s | 54.5954 Ops/s |
ghstack-source-id: 899441844c058e291017de34c7be2df0f8219a31 Pull Request resolved: #2796
include_hash_inv=True, | ||
include_san=True, | ||
stateful=True, | ||
mask_actions=True, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Note to self: from performance profiling, it looks like action masking takes up a significant chunk of the runtime, which makes sense--it's a 29275 element tensor created at each step. So I'll need to try turning this off, which will require adding support for ChessEnv.all_actions
in the absence of the mask
ghstack-source-id: 899441844c058e291017de34c7be2df0f8219a31 Pull Request resolved: pytorch#2796
ghstack-source-id: 899441844c058e291017de34c7be2df0f8219a31 Pull Request resolved: pytorch#2796
ghstack-source-id: 899441844c058e291017de34c7be2df0f8219a31 Pull Request resolved: pytorch#2796
ghstack-source-id: 08ebabd8c0b3ba0776a3b45370a056a3b90b20d2 Pull Request resolved: #2796
ghstack-source-id: 08ebabd8c0b3ba0776a3b45370a056a3b90b20d2 Pull Request resolved: pytorch#2796
ghstack-source-id: bd984300c77e8ed51adf687d1d826e9c149911f0 Pull Request resolved: #2796
@kurtamohler LMK if you need help with this! |
Stack from ghstack (oldest at bottom):
MCTSForest.extend
#2795MCTSForest/Tree.to_string
#2794