Skip to content

Commit

Permalink
fix conflict
Browse files Browse the repository at this point in the history
  • Loading branch information
linshokaku committed May 24, 2023
2 parents 2543d89 + 612c607 commit 2686502
Show file tree
Hide file tree
Showing 2 changed files with 81 additions and 0 deletions.
16 changes: 16 additions & 0 deletions pytorch_pfn_extras/handler/_logic.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,6 +277,10 @@ def train_epoch_begin(
# Needed for `torch.utils.data.DistributedSampler`
loader.sampler.set_epoch(epoch) # type: ignore[attr-defined]

def train_epoch_end(self, models: Mapping[str, Any], epoch: int) -> None:
model = models[self.model_name]
model.eval()

def train_step(
self,
models: Mapping[str, torch.nn.Module],
Expand Down Expand Up @@ -349,6 +353,10 @@ def train_validation_begin(
model = models[self.model_name]
model.eval()

def train_validation_end(self, models: Mapping[str, Any]) -> None:
model = models[self.model_name]
model.train()

def eval_step(
self,
models: Mapping[str, torch.nn.Module],
Expand Down Expand Up @@ -415,6 +423,10 @@ def train_epoch_begin(
# Needed for `torch.utils.data.DistributedSampler`
loader.sampler.set_epoch(epoch) # type: ignore[attr-defined]

def train_epoch_end(self, models: Mapping[str, Any], epoch: int) -> None:
model = models[self.model_name]
model.eval()

def train_step(
self,
models: Mapping[str, torch.nn.Module],
Expand Down Expand Up @@ -458,6 +470,10 @@ def train_validation_begin(
model = models[self.model_name]
model.eval()

def train_validation_end(self, models: Mapping[str, Any]) -> None:
model = models[self.model_name]
model.train()

def eval_step(
self,
models: Mapping[str, torch.nn.Module],
Expand Down
65 changes: 65 additions & 0 deletions tests/pytorch_pfn_extras_tests/test_logic.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
from typing import Any, Mapping
from unittest import mock

import pytest
import pytorch_pfn_extras as ppe
import torch
from torch import nn
from torch.nn import Module
from torch.nn import functional as F
from torch.optim import Optimizer


class MyModel(torch.nn.Module):
Expand Down Expand Up @@ -72,3 +75,65 @@ def test_trainer(device):
)
trainer.run(data)
assert backward_fn.call_count == epochs * iters_per_epoch


@pytest.mark.parametrize(
"trigger",
[
(1, "epoch"),
(0.5, "epoch"),
(10, "iteration"),
(5, "iteration"),
(1, "iteration"),
],
)
def test_train_step_mode_with_evaluator(trigger):
iters_per_epoch = 10
epochs = 20
model = MyModel()
ppe.to(model, "cpu")
model_with_loss = MyModelWithLossFn(model)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
data = torch.utils.data.DataLoader(
[
(
torch.rand(
20,
),
torch.rand(
10,
),
)
for i in range(iters_per_epoch)
]
)
backward_fn = mock.Mock(return_value=None)

class LogicWithTrainStepCheck(ppe.handler.Logic):
def train_step(
self,
models: Mapping[str, Module],
optimizers: Mapping[str, Optimizer],
batch_idx: int,
batch: Any,
) -> Any:
model = models[self.model_name]
assert model.training
return super().train_step(models, optimizers, batch_idx, batch)

trainer = ppe.engine.create_trainer(
model_with_loss,
optimizer,
epochs,
logic=LogicWithTrainStepCheck(),
evaluator=(
ppe.engine.create_evaluator(
models=model_with_loss,
logic=LogicWithTrainStepCheck(),
),
trigger,
),
options={"backward_function": backward_fn},
)
trainer.run(data, data)
assert backward_fn.call_count == epochs * iters_per_epoch

0 comments on commit 2686502

Please sign in to comment.