Commit

Change state dict loading default to strict (mosaicml#2216)
dakinggg authored May 11, 2023
1 parent 07747f9 commit 4577174
Showing 6 changed files with 51 additions and 10 deletions.
18 changes: 14 additions & 4 deletions composer/core/state.py
@@ -977,11 +977,21 @@ def load_model_state(
         # with the `module.` prefix
         torch.nn.modules.utils.consume_prefix_in_state_dict_if_present(state_dict['model'], 'module.')

-        if self.fsdp_enabled and self.fsdp_state_dict_type is not None:
-            with fsdp_state_dict_type_context(self.model, state_dict_type=self.fsdp_state_dict_type):
+        try:
+            if self.fsdp_enabled and self.fsdp_state_dict_type is not None:
+                with fsdp_state_dict_type_context(self.model, state_dict_type=self.fsdp_state_dict_type):
+                    missing_keys, unexpected_keys = self.model.load_state_dict(state_dict['model'], strict=strict)
+            else:
+                missing_keys, unexpected_keys = self.model.load_state_dict(state_dict['model'], strict=strict)
-        else:
-            missing_keys, unexpected_keys = self.model.load_state_dict(state_dict['model'], strict=strict)
+        except RuntimeError as e:
+            if 'Missing key(s) in state_dict' in str(e) or 'Unexpected key(s) in state_dict' in str(e):
+                raise RuntimeError(
+                    textwrap.dedent('Failed to load checkpoint due to missing or unexpected keys in state_dict. '
+                                    'This is likely due to a change in the model architecture. If this is intentional, '
+                                    'you can set load_strict_model_weights=False in the Trainer.')) from e
+            else:
+                raise e

         if len(missing_keys) > 0:
             log.warning(f"Found these missing keys in the checkpoint: {', '.join(missing_keys)}")
         if len(unexpected_keys) > 0:
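For context on the except branch above: the strings it matches ('Missing key(s) in state_dict', 'Unexpected key(s) in state_dict') are the ones PyTorch's Module.load_state_dict puts into the RuntimeError it raises under strict=True. A minimal, self-contained sketch of that underlying PyTorch behavior, using toy torch modules rather than a Composer model:

import torch

src = torch.nn.Linear(4, 4)                        # state_dict keys: 'weight', 'bias'
dst = torch.nn.Sequential(torch.nn.Linear(4, 4))   # expects keys: '0.weight', '0.bias'

try:
    dst.load_state_dict(src.state_dict(), strict=True)
except RuntimeError as e:
    # The message contains both 'Missing key(s) in state_dict' and
    # 'Unexpected key(s) in state_dict'; the new code above catches this
    # and re-raises it with a pointer to load_strict_model_weights.
    print(e)

The commit's change is to surface this PyTorch error with a hint about the Trainer flag, rather than to change how the keys themselves are compared.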
4 changes: 2 additions & 2 deletions composer/trainer/trainer.py
@@ -609,7 +609,7 @@ class Trainer:
         load_weights_only (bool, optional): Whether or not to only restore the weights from the checkpoint without
             restoring the associated state. Ignored if ``load_path`` is ``None``. (default: ``False``)
         load_strict_model_weights (bool, optional): Ensure that the set of weights in the checkpoint and model must exactly match.
-            Ignored if ``load_path`` is ``None``. (default: ``False``)
+            Ignored if ``load_path`` is ``None``. (default: ``True``)
         load_progress_bar (bool, optional): Display the progress bar for downloading the checkpoint.
             Ignored if ``load_path`` is either ``None`` or a local file path. (default: ``True``)
         load_ignore_keys (List[str] | (Dict) -> None, optional): A list of paths for the ``state_dict`` of the checkpoint,
@@ -829,7 +829,7 @@ def __init__(
         load_path: Optional[str] = None,
         load_object_store: Optional[Union[ObjectStore, LoggerDestination]] = None,
         load_weights_only: bool = False,
-        load_strict_model_weights: bool = False,
+        load_strict_model_weights: bool = True,
         load_progress_bar: bool = True,
         load_ignore_keys: Optional[Union[List[str], Callable[[Dict], None]]] = None,
         load_exclude_algorithms: Optional[List[str]] = None,
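Because the default flips to True here, a fine-tuning run whose architecture intentionally differs from the checkpoint (as in the notebook change below) now has to opt out explicitly. A rough usage sketch; finetune_model and finetune_dataloader are hypothetical placeholders, not part of this commit:

from composer import Trainer

# `finetune_model` and `finetune_dataloader` are hypothetical placeholders for illustration.
trainer = Trainer(
    model=finetune_model,
    train_dataloader=finetune_dataloader,
    max_duration='1ep',
    load_path='checkpoints/pretraining/latest-rank0.pt',
    load_weights_only=True,
    load_strict_model_weights=False,  # override the new strict default when checkpoint and model keys differ
)
trainer.fit()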
3 changes: 2 additions & 1 deletion examples/pretrain_finetune_huggingface.ipynb
@@ -266,7 +266,7 @@
 " optimizers=optimizer,\n",
 " schedulers=[lr_scheduler],\n",
 " device='gpu' if torch.cuda.is_available() else 'cpu',\n",
-" #train_subset_num_batches=100, # uncomment this line to only run part of training, which will be faster\n",
+" # train_subset_num_batches=100, # uncomment this line to only run part of training, which will be faster\n",
 " precision='fp32',\n",
 " seed=17,\n",
 ")\n",
@@ -407,6 +407,7 @@
 " save_folder='checkpoints/finetuning/',\n",
 " load_path=f'checkpoints/pretraining/latest-rank0.pt',\n",
 " load_weights_only=True, # We're starting a new training run, so we just the model weights\n",
+" load_strict_model_weights=False, # We're going from the original model, which is for MaskedLM, to a new model, for SequenceClassification\n",
 " optimizers=optimizer,\n",
 " schedulers=[lr_scheduler],\n",
 " device='gpu' if torch.cuda.is_available() else 'cpu',\n",
1 change: 1 addition & 0 deletions tests/test_full_nlp.py
@@ -118,6 +118,7 @@ def finetuning_test_helper(tokenizer, model, algorithms, checkpoint_path, pretra
         save_folder='finetuning_checkpoints',
         load_path=checkpoint_path,
         load_weights_only=True,
+        load_strict_model_weights=False,
         loggers=[rud],
         max_duration='1ep',
         seed=17,
33 changes: 31 additions & 2 deletions tests/trainer/test_checkpoint.py
@@ -395,8 +395,9 @@ def _metrics_equal(self, train_metrics_1, train_metrics_2, eval_metrics_1, eval_
         except AssertionError:
             return False

-    def get_trainer(self, max_duration='2ep', **kwargs):
-        model = SimpleConvModel()
+    def get_trainer(self, model=None, max_duration='2ep', **kwargs):
+        if model is None:
+            model = SimpleConvModel()
         optimizer = torch.optim.Adam(model.parameters())

         train_dataset = RandomImageDataset()
@@ -475,6 +476,34 @@ def test_other_backends_error(self, load_path: str, monkeypatch: MonkeyPatch):
         with pytest.raises(NotImplementedError):
             self.get_trainer(load_path=load_path)

+    @pytest.mark.parametrize('missing_key', [True, False])
+    @pytest.mark.parametrize('unexpected_key', [True, False])
+    def test_strict_errors(self, missing_key: bool, unexpected_key: bool):
+        model1 = SimpleConvModel()
+        if unexpected_key:
+            model1.unexpected_dummy = torch.nn.Parameter(torch.zeros(1))
+
+        trainer_1 = self.get_trainer(model=model1, save_folder='first')
+        trainer_1.fit()
+        trainer_1.close()
+
+        model2 = SimpleConvModel()
+        if missing_key:
+            model2.missing_dummy = torch.nn.Parameter(torch.zeros(1))
+
+        last_checkpoint = os.path.join('first', 'ep2.pt')
+        if missing_key or unexpected_key:
+            error_context = pytest.raises(RuntimeError, match='Failed to load checkpoint due to')
+        else:
+            error_context = contextlib.nullcontext()
+
+        with error_context:
+            _ = self.get_trainer(
+                model=model2,
+                load_path=last_checkpoint,
+                load_weights_only=True,
+            )
+
     @device('cpu', 'gpu')
     @pytest.mark.parametrize('load_weights_only', [True, False])
     def test_load_weights(self, device, load_weights_only):
2 changes: 1 addition & 1 deletion tests/utils/test_autolog_hparams.py
@@ -144,7 +144,7 @@ def test_extract_hparams_trainer():
         'load_path': None,
         'load_object_store': None,
         'load_weights_only': False,
-        'load_strict_model_weights': False,
+        'load_strict_model_weights': True,
         'load_progress_bar': True,
         'load_ignore_keys': None,
         'load_exclude_algorithms': None,
