From ae82630a71fafa7792a8f0b936276cb614c2c78b Mon Sep 17 00:00:00 2001 From: Linsho Kaku Date: Wed, 24 May 2023 15:18:38 +0900 Subject: [PATCH 1/3] add test --- tests/pytorch_pfn_extras_tests/test_logic.py | 39 +++++++++++++++++++- 1 file changed, 38 insertions(+), 1 deletion(-) diff --git a/tests/pytorch_pfn_extras_tests/test_logic.py b/tests/pytorch_pfn_extras_tests/test_logic.py index eade4d6aa..cf0db4adb 100644 --- a/tests/pytorch_pfn_extras_tests/test_logic.py +++ b/tests/pytorch_pfn_extras_tests/test_logic.py @@ -1,8 +1,10 @@ +from typing import Any, Mapping import pytest import torch from torch import nn -from torch.nn import functional as F +from torch.optim import Optimizer +from torch.nn import functional as F, Module from unittest import mock import pytorch_pfn_extras as ppe @@ -60,3 +62,38 @@ def test_trainer(device): ) trainer.run(data) assert backward_fn.call_count == epochs * iters_per_epoch + +@pytest.mark.parametrize("trigger", [(1, "epoch"), (0.5, "epoch"), (10, "iteration"), (5, "iteration"), (1, "iteration")]) +def test_train_step_mode_with_evaluator(trigger): + iters_per_epoch = 10 + epochs = 20 + model = MyModel() + ppe.to(model, "cpu") + model_with_loss = MyModelWithLossFn(model) + optimizer = torch.optim.SGD(model.parameters(), lr=0.1) + data = torch.utils.data.DataLoader( + [(torch.rand(20,), torch.rand(10,)) for i in range(iters_per_epoch)]) + backward_fn = mock.Mock(return_value=None) + + class LogicWithTrainStepCheck(ppe.handler.Logic): + def train_step(self, models: Mapping[str, Module], optimizers: Mapping[str, Optimizer], batch_idx: int, batch: Any) -> Any: + model = models[self.model_name] + assert model.training + return super().train_step(models, optimizers, batch_idx, batch) + + trainer = ppe.engine.create_trainer( + model_with_loss, + optimizer, + epochs, + logic=LogicWithTrainStepCheck(), + evaluator=( + ppe.engine.create_evaluator( + models=model_with_loss, + logic=LogicWithTrainStepCheck(), + ), + trigger, + ), + options={'backward_function': backward_fn} + ) + trainer.run(data, data) + assert backward_fn.call_count == epochs * iters_per_epoch From 707d4e21a7a8807502b480e81d14b7fac5b5d134 Mon Sep 17 00:00:00 2001 From: Linsho Kaku Date: Wed, 24 May 2023 15:27:35 +0900 Subject: [PATCH 2/3] fix test lint --- tests/pytorch_pfn_extras_tests/test_logic.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/tests/pytorch_pfn_extras_tests/test_logic.py b/tests/pytorch_pfn_extras_tests/test_logic.py index cf0db4adb..9d85f8821 100644 --- a/tests/pytorch_pfn_extras_tests/test_logic.py +++ b/tests/pytorch_pfn_extras_tests/test_logic.py @@ -63,6 +63,7 @@ def test_trainer(device): trainer.run(data) assert backward_fn.call_count == epochs * iters_per_epoch + @pytest.mark.parametrize("trigger", [(1, "epoch"), (0.5, "epoch"), (10, "iteration"), (5, "iteration"), (1, "iteration")]) def test_train_step_mode_with_evaluator(trigger): iters_per_epoch = 10 @@ -82,15 +83,15 @@ def train_step(self, models: Mapping[str, Module], optimizers: Mapping[str, Opti return super().train_step(models, optimizers, batch_idx, batch) trainer = ppe.engine.create_trainer( - model_with_loss, - optimizer, - epochs, + model_with_loss, + optimizer, + epochs, logic=LogicWithTrainStepCheck(), evaluator=( ppe.engine.create_evaluator( models=model_with_loss, logic=LogicWithTrainStepCheck(), - ), + ), trigger, ), options={'backward_function': backward_fn} From ece1018b1c0b50cb3342c691f1a05e2cd07e5157 Mon Sep 17 00:00:00 2001 From: Linsho Kaku Date: Wed, 24 May 2023 15:28:46 +0900 Subject: [PATCH 3/3] add train_validation_end methods for recover training mode --- pytorch_pfn_extras/handler/_logic.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/pytorch_pfn_extras/handler/_logic.py b/pytorch_pfn_extras/handler/_logic.py index b13e1fd5f..b9dacc663 100644 --- a/pytorch_pfn_extras/handler/_logic.py +++ b/pytorch_pfn_extras/handler/_logic.py @@ -268,6 +268,10 @@ def train_epoch_begin( # Needed for `torch.utils.data.DistributedSampler` loader.sampler.set_epoch(epoch) # type: ignore[attr-defined] + def train_epoch_end(self, models: Mapping[str, Any], epoch: int) -> None: + model = models[self.model_name] + model.eval() + def train_step( self, models: Mapping[str, torch.nn.Module], @@ -339,6 +343,10 @@ def train_validation_begin( model = models[self.model_name] model.eval() + def train_validation_end(self, models: Mapping[str, Any]) -> None: + model = models[self.model_name] + model.train() + def eval_step( self, models: Mapping[str, torch.nn.Module], @@ -404,6 +412,10 @@ def train_epoch_begin( # Needed for `torch.utils.data.DistributedSampler` loader.sampler.set_epoch(epoch) # type: ignore[attr-defined] + def train_epoch_end(self, models: Mapping[str, Any], epoch: int) -> None: + model = models[self.model_name] + model.eval() + def train_step( self, models: Mapping[str, torch.nn.Module], @@ -447,6 +459,10 @@ def train_validation_begin( model = models[self.model_name] model.eval() + def train_validation_end(self, models: Mapping[str, Any]) -> None: + model = models[self.model_name] + model.train() + def eval_step( self, models: Mapping[str, torch.nn.Module],