Added SER Strategy #1672

Open · wants to merge 2 commits into base: master
207 changes: 207 additions & 0 deletions avalanche/training/supervised/ser.py
@@ -0,0 +1,207 @@
from copy import deepcopy
from typing import (
    Callable,
    List,
    Optional,
    Union,
)

import torch
import torch.nn.functional as F
from torch.nn import CrossEntropyLoss, Module
from torch.optim import Optimizer

from avalanche.training.templates.strategy_mixin_protocol import CriterionType
from avalanche.training.utils import cycle
from avalanche.core import SupervisedPlugin
from avalanche.training.plugins.evaluation import (
EvaluationPlugin,
default_evaluator,
)
from avalanche.training.templates import SupervisedTemplate
from avalanche.training.supervised.der import ClassBalancedBufferWithLogits
from avalanche.models import avalanche_forward


class SER(SupervisedTemplate):
    """
    Implements the SER strategy from the
    "Continual Learning with Strong Experience Replay" paper,
    https://arxiv.org/pdf/2305.13622

    SER combines experience replay with two consistency losses: a backward
    consistency loss that keeps the current model's logits on memory samples
    close to the logits stored in the buffer, and a forward consistency loss
    that keeps its logits on current-task data close to those of a frozen
    snapshot of the model taken before the current experience.
    """

def __init__(
self,
*,
model: Module,
optimizer: Optimizer,
criterion: CriterionType = CrossEntropyLoss(),
mem_size: int = 200,
batch_size_mem: Optional[int] = None,
alpha: float = 0.1,
beta: float = 0.5,
train_mb_size: int = 1,
train_epochs: int = 1,
eval_mb_size: Optional[int] = 1,
device: Union[str, torch.device] = "cpu",
plugins: Optional[List[SupervisedPlugin]] = None,
evaluator: Union[
EvaluationPlugin, Callable[[], EvaluationPlugin]
] = default_evaluator,
eval_every=-1,
peval_mode="epoch",
**kwargs
):
"""
:param model: PyTorch model.
:param optimizer: PyTorch optimizer.
:param criterion: loss function.
        :param mem_size: int : fixed memory size.
        :param batch_size_mem: int : size of the batch sampled from the buffer.
        :param alpha: float : hyperparameter weighting the backward consistency loss.
        :param beta: float : hyperparameter weighting the forward consistency loss.
        :param train_mb_size: mini-batch size for training.
        :param train_epochs: number of training epochs.
:param eval_mb_size: mini-batch size for eval.
:param device: PyTorch device where the model will be allocated.
:param plugins: (optional) list of StrategyPlugins.
:param evaluator: (optional) instance of EvaluationPlugin for logging
and metric computations. None to remove logging.
:param eval_every: the frequency of the calls to `eval` inside the
training loop. -1 disables the evaluation. 0 means `eval` is called
only at the end of the learning experience. Values >0 mean that
`eval` is called every `eval_every` experiences and at the end of
the learning experience.
        :param peval_mode: one of {'epoch', 'iteration'}. Decides whether
            the periodic evaluation during training should execute every
            `eval_every` epochs or iterations (Default='epoch').
"""
super().__init__(
model=model,
optimizer=optimizer,
criterion=criterion,
train_mb_size=train_mb_size,
train_epochs=train_epochs,
eval_mb_size=eval_mb_size,
device=device,
plugins=plugins,
evaluator=evaluator,
eval_every=eval_every,
peval_mode=peval_mode,
**kwargs
)
if batch_size_mem is None:
self.batch_size_mem = train_mb_size
else:
self.batch_size_mem = batch_size_mem
self.mem_size = mem_size
self.storage_policy = ClassBalancedBufferWithLogits(
self.mem_size, adaptive_size=True
)
self.replay_loader = None
self.alpha = alpha
self.beta = beta
self.old_model = None

def _before_training_exp(self, **kwargs):
buffer = self.storage_policy.buffer
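        # Enable replay only once the buffer can supply a full memory
        # mini-batch; drop_last keeps the replayed batch size constant,
        # which the loss slicing in ``training_epoch`` depends on.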
if len(buffer) >= self.batch_size_mem:
self.replay_loader = cycle(
torch.utils.data.DataLoader(
buffer,
batch_size=self.batch_size_mem,
shuffle=True,
drop_last=True,
num_workers=kwargs.get("num_workers", 0),
)
)
else:
self.replay_loader = None
        # Snapshot the model as it is before training on this experience;
        # its frozen outputs are the targets of the forward consistency loss.
        self.old_model = deepcopy(self.model)
        self.old_model.eval()

super()._before_training_exp(**kwargs)

def _after_training_exp(self, **kwargs):
self.replay_loader = None # Allow SER to be checkpointed
self.storage_policy.update(self, **kwargs)
super()._after_training_exp(**kwargs)

def _before_forward(self, **kwargs):
super()._before_forward(**kwargs)
if self.replay_loader is None:
return None

batch_x, batch_y, batch_tid, batch_logits = next(self.replay_loader)
batch_x, batch_y, batch_tid, batch_logits = (
batch_x.to(self.device),
batch_y.to(self.device),
batch_tid.to(self.device),
batch_logits.to(self.device),
)
self.mbatch[0] = torch.cat((batch_x, self.mbatch[0]))
self.mbatch[1] = torch.cat((batch_y, self.mbatch[1]))
self.mbatch[2] = torch.cat((batch_tid, self.mbatch[2]))
self.batch_logits = batch_logits

    def training_epoch(self, **kwargs):
        """Training epoch.

        Runs the standard supervised loop, adding the SER consistency
        losses on top of the classification loss whenever replay is active.
        """
for self.mbatch in self.dataloader:
if self._stop_training:
break

self._unpack_minibatch()
self._before_training_iteration(**kwargs)

self.optimizer.zero_grad()
self.loss = self._make_empty_loss()

            # Forward
            self._before_forward(**kwargs)
            self.mb_output = self.forward()
            if self.replay_loader is not None:
                # The snapshot's logits are only needed for the forward
                # consistency loss, which is active when replaying.
                with torch.no_grad():
                    old_mb_output = avalanche_forward(
                        self.old_model, self.mb_x, self.mb_task_id
                    )
            self._after_forward(**kwargs)

if self.replay_loader is not None:

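                # Total SER objective:
                #   L = L_CE  +  alpha * L_back  +  beta * L_forward
                # L_back matches the new logits on memory samples to the
                # logits stored alongside them in the buffer; L_forward
                # matches the snapshot model's logits to the new logits on
                # current-task data.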
# Classification loss on current task and memory data
                self.loss += F.cross_entropy(self.mb_output, self.mb_y)
# Backward consistency loss on memory data
self.loss += self.alpha * F.mse_loss(
self.mb_output[: self.batch_size_mem],
self.batch_logits,
)
# Forward consistency loss on current task data
self.loss += self.beta * F.mse_loss(
old_mb_output[self.batch_size_mem :],
self.mb_output[self.batch_size_mem :],
)

else:
self.loss += self.criterion()

self._before_backward(**kwargs)
self.backward()
self._after_backward(**kwargs)

# Optimization step
self._before_update(**kwargs)
self.optimizer_step()
self._after_update(**kwargs)

self._after_training_iteration(**kwargs)


__all__ = ["SER"]
18 changes: 18 additions & 0 deletions tests/training/test_strategies.py
@@ -49,6 +49,7 @@
ER_ACE,
ER_AML,
DER,
SER,
LearningToPrompt,
ExpertGateStrategy,
MER,
@@ -1113,6 +1114,23 @@ def test_der(self):
)
run_strategy(benchmark, strategy)

def test_ser(self):
# SIT scenario
model, optimizer, criterion, benchmark = self.init_scenario(multi_task=False)
with self.assertWarns(PositionalArgumentsDeprecatedWarning):
strategy = SER(
model,
optimizer,
criterion,
mem_size=1000,
batch_size_mem=10,
train_mb_size=10,
device=self.device,
eval_mb_size=50,
train_epochs=2,
)
run_strategy(benchmark, strategy)

def test_feature_distillation(self):
# SIT scenario
model, optimizer, criterion, benchmark = self.init_scenario(multi_task=False)