diff --git a/benchmarks/nn/functional_benchmarks_test.py b/benchmarks/nn/functional_benchmarks_test.py index bfc26b640..b4c7a2147 100644 --- a/benchmarks/nn/functional_benchmarks_test.py +++ b/benchmarks/nn/functional_benchmarks_test.py @@ -1,6 +1,5 @@ # we use deepcopy as our implementation modifies the modules in-place import argparse -from copy import deepcopy import pytest import torch @@ -27,11 +26,11 @@ def net(): def _functorch_make_functional(net): - functorch_make_functional(deepcopy(net)) + functorch_make_functional(net) def _make_functional(net): - make_functional(deepcopy(net)) + make_functional(net) def make_tdmodule(): @@ -129,14 +128,23 @@ def test_tdseq_dispatch(benchmark): # Creation -def test_instantiation_functorch(benchmark, net): +def test_instantiation_functorch( + benchmark, +): benchmark.pedantic( - _functorch_make_functional, args=(net,), iterations=10, rounds=100 + _functorch_make_functional, + setup=lambda: ((make_net(),), {}), + iterations=1, + rounds=10000, ) -def test_instantiation_td(benchmark, net): - benchmark.pedantic(_make_functional, args=(net,), iterations=10, rounds=100) +def test_instantiation_td( + benchmark, +): + benchmark.pedantic( + _make_functional, setup=lambda: ((make_net(),), {}), iterations=1, rounds=10000 + ) # Execution