From 9231451476831eb04f955ed5639c7a00b664586b Mon Sep 17 00:00:00 2001
From: Ferdinand Hahmann
Date: Tue, 14 Apr 2020 16:41:10 +0200
Subject: [PATCH 1/3] Add a pytest check that the logits do not depend on the test batch size

---
 maml/tests/test_batchsize.py | 49 ++++++++++++++++++++++++++++++++++++
 1 file changed, 49 insertions(+)
 create mode 100644 maml/tests/test_batchsize.py

diff --git a/maml/tests/test_batchsize.py b/maml/tests/test_batchsize.py
new file mode 100644
index 0000000..60d8be2
--- /dev/null
+++ b/maml/tests/test_batchsize.py
@@ -0,0 +1,49 @@
+import pytest
+from maml.datasets import get_benchmark_by_name
+from maml.metalearners import ModelAgnosticMetaLearning
+from maml.utils import tensors_to_device
+from torchmeta.utils.data import BatchMetaDataLoader
+import torch
+
+def test_batchsize():
+    num_steps = 1
+    num_workers = 0
+    dataset = 'miniimagenet'
+    folder = 'data/miniimagenet'
+    num_ways = 4
+    num_shots = 4
+    num_shots_test = 4
+    hidden_size = 64
+    batch_size = 5
+    first_order = False
+    step_size = 0.1
+    benchmark = get_benchmark_by_name(dataset,
+                                      folder,
+                                      num_ways,
+                                      num_shots,
+                                      num_shots_test,
+                                      hidden_size=hidden_size)
+    meta_test_dataloader = BatchMetaDataLoader(benchmark.meta_test_dataset,
+                                               batch_size=batch_size,
+                                               shuffle=True,
+                                               num_workers=num_workers,
+                                               pin_memory=True)
+    metalearner = ModelAgnosticMetaLearning(benchmark.model,
+                                            first_order=first_order,
+                                            num_adaptation_steps=num_steps,
+                                            step_size=step_size,
+                                            loss_function=benchmark.loss_function,
+                                            device='cpu')
+    for batch in meta_test_dataloader:
+        batch = tensors_to_device(batch, device='cpu')
+        for task_id, (train_inputs, train_targets, test_inputs, test_targets) in enumerate(zip(*batch['train'], *batch['test'])):
+            params, _ = metalearner.adapt(train_inputs, train_targets,
+                                          is_classification_task=True,
+                                          num_adaptation_steps=metalearner.num_adaptation_steps,
+                                          step_size=metalearner.step_size, first_order=metalearner.first_order)
+            test_logits_1 = metalearner.model(test_inputs, params=params)
+            for idx in range(test_inputs.shape[0]):
+                test_logits_2 = metalearner.model(test_inputs[idx:idx + 1, ...], params=params)
+                assert torch.allclose(test_logits_1[idx:idx + 1, ...], test_logits_2, atol=1e-04)
+        break
+    return
\ No newline at end of file

From b6229be9a401da99b62b167159d02b2f610517a2 Mon Sep 17 00:00:00 2001
From: Ferdinand Hahmann
Date: Tue, 14 Apr 2020 16:42:07 +0200
Subject: [PATCH 2/3] BatchNorm layers track their running stats

---
 maml/model.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/maml/model.py b/maml/model.py
index 5e46f8b..fa2f14c 100644
--- a/maml/model.py
+++ b/maml/model.py
@@ -9,7 +9,7 @@ def conv_block(in_channels, out_channels, **kwargs):
     return MetaSequential(OrderedDict([
         ('conv', MetaConv2d(in_channels, out_channels, **kwargs)),
         ('norm', nn.BatchNorm2d(out_channels, momentum=1.,
-            track_running_stats=False)),
+            track_running_stats=True)),
         ('relu', nn.ReLU()),
         ('pool', nn.MaxPool2d(2))
     ]))
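The one-line change in PATCH 2 matters because of how BatchNorm2d handles its statistics. The following minimal sketch (plain PyTorch on toy tensors; it is not part of the patch series) illustrates the behaviour that the test in PATCH 1 checks for: with track_running_stats=False the layer normalizes with the statistics of the current batch even in eval() mode, so a sample's logits depend on which other samples share its batch, whereas with track_running_stats=True and momentum=1. the layer stores the statistics of the last training batch and reuses them for every later eval() forward pass.

import torch
import torch.nn as nn

torch.manual_seed(0)
x = torch.randn(8, 3, 4, 4)

# track_running_stats=False: eval() still normalizes with the statistics of
# whatever batch is passed in, so a single sample is normalized differently
# on its own than inside a batch of eight.
bn_batch_stats = nn.BatchNorm2d(3, momentum=1., track_running_stats=False).eval()
print(torch.allclose(bn_batch_stats(x)[:1], bn_batch_stats(x[:1]), atol=1e-4))   # False

# track_running_stats=True with momentum=1.: one train() forward stores the
# statistics of that batch, and eval() reuses them for every input, so the
# output of a sample no longer depends on the evaluation batch size.
bn_running_stats = nn.BatchNorm2d(3, momentum=1., track_running_stats=True)
bn_running_stats.train()
bn_running_stats(x)
bn_running_stats.eval()
print(torch.allclose(bn_running_stats(x)[:1], bn_running_stats(x[:1]), atol=1e-4))   # True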
From a8a31d61837fbcbfedb15c0567fc7ad7d0dcb293 Mon Sep 17 00:00:00 2001
From: Ferdinand Hahmann
Date: Tue, 14 Apr 2020 16:47:29 +0200
Subject: [PATCH 3/3] Add a "training" attribute to ModelAgnosticMetaLearning

The new attribute stores the training state of the metalearner, analogous to
"model.training". "model.training" has to stay independent of it so that the
BatchNorm layers can track their running stats during adaptation and reuse
them on the query set.
---
 maml/metalearners/maml.py    | 8 +++++---
 maml/tests/test_batchsize.py | 2 ++
 train.py                     | 2 ++
 3 files changed, 9 insertions(+), 3 deletions(-)

diff --git a/maml/metalearners/maml.py b/maml/metalearners/maml.py
index f7b98c6..d9fcc09 100644
--- a/maml/metalearners/maml.py
+++ b/maml/metalearners/maml.py
@@ -76,6 +76,7 @@ def __init__(self, model, optimizer=None, step_size=0.1, first_order=False,
         self.scheduler = scheduler
         self.loss_function = loss_function
         self.device = device
+        self.training = True
 
         if per_param_step_size:
             self.step_size = OrderedDict((name, torch.tensor(step_size,
@@ -118,6 +119,7 @@ def get_outer_loss(self, batch):
         mean_outer_loss = torch.tensor(0., device=self.device)
         for task_id, (train_inputs, train_targets, test_inputs, test_targets) \
                 in enumerate(zip(*batch['train'], *batch['test'])):
+            self.model.train(True)
             params, adaptation_results = self.adapt(train_inputs, train_targets,
                 is_classification_task=is_classification_task,
                 num_adaptation_steps=self.num_adaptation_steps,
@@ -127,7 +129,8 @@ def get_outer_loss(self, batch):
             if is_classification_task:
                 results['accuracies_before'][task_id] = adaptation_results['accuracy_before']
 
-            with torch.set_grad_enabled(self.model.training):
+            with torch.set_grad_enabled(self.training):
+                self.model.train(False)
                 test_logits = self.model(test_inputs, params=params)
                 outer_loss = self.loss_function(test_logits, test_targets)
                 results['outer_losses'][task_id] = outer_loss.item()
@@ -162,7 +165,7 @@ def adapt(self, inputs, targets, is_classification_task=None,
             self.model.zero_grad()
             params = update_parameters(self.model, inner_loss,
                 step_size=step_size, params=params,
-                first_order=(not self.model.training) or first_order)
+                first_order=(not self.training) or first_order)
 
         return params, results
 
@@ -184,7 +187,6 @@ def train_iter(self, dataloader, max_batches=500):
                 '(eg. `{0}(model, optimizer=torch.optim.SGD(model.'
                 'parameters(), lr=0.01), ...).'.format(__class__.__name__))
         num_batches = 0
-        self.model.train()
         while num_batches < max_batches:
             for batch in dataloader:
                 if num_batches >= max_batches:
diff --git a/maml/tests/test_batchsize.py b/maml/tests/test_batchsize.py
index 60d8be2..b733bd0 100644
--- a/maml/tests/test_batchsize.py
+++ b/maml/tests/test_batchsize.py
@@ -37,10 +37,12 @@ def test_batchsize():
     for batch in meta_test_dataloader:
         batch = tensors_to_device(batch, device='cpu')
         for task_id, (train_inputs, train_targets, test_inputs, test_targets) in enumerate(zip(*batch['train'], *batch['test'])):
+            metalearner.model.train(True)
             params, _ = metalearner.adapt(train_inputs, train_targets,
                                           is_classification_task=True,
                                           num_adaptation_steps=metalearner.num_adaptation_steps,
                                           step_size=metalearner.step_size, first_order=metalearner.first_order)
+            metalearner.model.train(False)
             test_logits_1 = metalearner.model(test_inputs, params=params)
             for idx in range(test_inputs.shape[0]):
                 test_logits_2 = metalearner.model(test_inputs[idx:idx + 1, ...], params=params)
diff --git a/train.py b/train.py
index e86065d..348de81 100644
--- a/train.py
+++ b/train.py
@@ -65,11 +65,13 @@ def main(args):
     # Training loop
     epoch_desc = 'Epoch {{0: <{0}d}}'.format(1 + int(math.log10(args.num_epochs)))
     for epoch in range(args.num_epochs):
+        metalearner.training = True
         metalearner.train(meta_train_dataloader,
                           max_batches=args.num_batches,
                           verbose=args.verbose,
                           desc='Training',
                           leave=False)
+        metalearner.training = False
         results = metalearner.evaluate(meta_val_dataloader,
                                        max_batches=args.num_batches,
                                        verbose=args.verbose,
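Taken together, the series makes the query-set logits independent of the evaluation batch size: the new metalearner.training flag from PATCH 3 only controls whether the outer (query) forward pass tracks gradients for the meta-update, while model.train()/model.eval() is toggled internally so the BatchNorm layers record running stats on the support set and reuse them on the query set. The sketch below restates this pattern in plain PyTorch with hypothetical names (model, training, support_inputs, query_inputs); it illustrates the idea and is not code from the repository.

import torch
import torch.nn as nn

model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8, momentum=1.), nn.ReLU())
training = True  # analogue of ModelAgnosticMetaLearning.training

support_inputs = torch.randn(4, 3, 8, 8)   # inner-loop (adaptation) data
query_inputs = torch.randn(4, 3, 8, 8)     # outer-loop (query) data

model.train(True)        # adaptation: batch statistics are computed and stored
_ = model(support_inputs)

with torch.set_grad_enabled(training):
    model.train(False)   # query pass: reuse the stored running statistics
    query_features = model(query_inputs)

print(query_features.requires_grad)   # True while meta-training, False when training is False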