diff --git a/pytorch_pfn_extras/handler/_logic.py b/pytorch_pfn_extras/handler/_logic.py
index e5b3b0110..af9f78a53 100644
--- a/pytorch_pfn_extras/handler/_logic.py
+++ b/pytorch_pfn_extras/handler/_logic.py
@@ -277,6 +277,10 @@ def train_epoch_begin(
             # Needed for `torch.utils.data.DistributedSampler`
             loader.sampler.set_epoch(epoch)  # type: ignore[attr-defined]
 
+    def train_epoch_end(self, models: Mapping[str, Any], epoch: int) -> None:
+        model = models[self.model_name]
+        model.eval()
+
     def train_step(
         self,
         models: Mapping[str, torch.nn.Module],
@@ -349,6 +353,10 @@ def train_validation_begin(
         model = models[self.model_name]
         model.eval()
 
+    def train_validation_end(self, models: Mapping[str, Any]) -> None:
+        model = models[self.model_name]
+        model.train()
+
     def eval_step(
         self,
         models: Mapping[str, torch.nn.Module],
@@ -415,6 +423,10 @@ def train_epoch_begin(
             # Needed for `torch.utils.data.DistributedSampler`
             loader.sampler.set_epoch(epoch)  # type: ignore[attr-defined]
 
+    def train_epoch_end(self, models: Mapping[str, Any], epoch: int) -> None:
+        model = models[self.model_name]
+        model.eval()
+
     def train_step(
         self,
         models: Mapping[str, torch.nn.Module],
@@ -458,6 +470,10 @@ def train_validation_begin(
         model = models[self.model_name]
         model.eval()
 
+    def train_validation_end(self, models: Mapping[str, Any]) -> None:
+        model = models[self.model_name]
+        model.train()
+
     def eval_step(
         self,
         models: Mapping[str, torch.nn.Module],
diff --git a/tests/pytorch_pfn_extras_tests/test_logic.py b/tests/pytorch_pfn_extras_tests/test_logic.py
index bbc65e5a9..6bf97fd33 100644
--- a/tests/pytorch_pfn_extras_tests/test_logic.py
+++ b/tests/pytorch_pfn_extras_tests/test_logic.py
@@ -1,10 +1,13 @@
+from typing import Any, Mapping
 from unittest import mock
 
 import pytest
 import pytorch_pfn_extras as ppe
 import torch
 from torch import nn
+from torch.nn import Module
 from torch.nn import functional as F
+from torch.optim import Optimizer
 
 
 class MyModel(torch.nn.Module):
@@ -72,3 +75,65 @@ def test_trainer(device):
     )
     trainer.run(data)
     assert backward_fn.call_count == epochs * iters_per_epoch
+
+
+@pytest.mark.parametrize(
+    "trigger",
+    [
+        (1, "epoch"),
+        (0.5, "epoch"),
+        (10, "iteration"),
+        (5, "iteration"),
+        (1, "iteration"),
+    ],
+)
+def test_train_step_mode_with_evaluator(trigger):
+    iters_per_epoch = 10
+    epochs = 20
+    model = MyModel()
+    ppe.to(model, "cpu")
+    model_with_loss = MyModelWithLossFn(model)
+    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
+    data = torch.utils.data.DataLoader(
+        [
+            (
+                torch.rand(
+                    20,
+                ),
+                torch.rand(
+                    10,
+                ),
+            )
+            for i in range(iters_per_epoch)
+        ]
+    )
+    backward_fn = mock.Mock(return_value=None)
+
+    class LogicWithTrainStepCheck(ppe.handler.Logic):
+        def train_step(
+            self,
+            models: Mapping[str, Module],
+            optimizers: Mapping[str, Optimizer],
+            batch_idx: int,
+            batch: Any,
+        ) -> Any:
+            model = models[self.model_name]
+            assert model.training
+            return super().train_step(models, optimizers, batch_idx, batch)
+
+    trainer = ppe.engine.create_trainer(
+        model_with_loss,
+        optimizer,
+        epochs,
+        logic=LogicWithTrainStepCheck(),
+        evaluator=(
+            ppe.engine.create_evaluator(
+                models=model_with_loss,
+                logic=LogicWithTrainStepCheck(),
+            ),
+            trigger,
+        ),
+        options={"backward_function": backward_fn},
+    )
+    trainer.run(data, data)
+    assert backward_fn.call_count == epochs * iters_per_epoch