diff --git a/composer/core/evaluator.py b/composer/core/evaluator.py index e1cf73fa97..6f54e01195 100644 --- a/composer/core/evaluator.py +++ b/composer/core/evaluator.py @@ -86,11 +86,10 @@ def __init__( self.label = label self.dataloader = ensure_data_spec(dataloader) - self.metric_names = [] if metric_names is not None: if not isinstance(metric_names, list): raise ValueError(f'``metric_names`` should be a list of strings, not a {type(metric_names)}') - self.metric_names = metric_names + self.metric_names = metric_names self.subset_num_batches = subset_num_batches self._eval_interval = None diff --git a/composer/trainer/trainer.py b/composer/trainer/trainer.py index 2cf79bf06c..be18de5251 100644 --- a/composer/trainer/trainer.py +++ b/composer/trainer/trainer.py @@ -106,14 +106,13 @@ def _get_default_scheduler_frequency(schedulers: Optional[Union[Scheduler, Seque def _filter_metrics(metrics: Dict[str, Metric], metric_names: Optional[List[str]]) -> Dict[str, Metric]: """Filter the metrics based on the given metric_names as regex strings (e.g. 'Accuracy', 'f1' for 'BinaryF1Score', 'Top-.' for 'Top-1 Accuracy' and 'Top-2 Accuracy', etc). If no metric_names are provided, all metrics will be returned.""" metrics = deepcopy(metrics) - if not metric_names: + if metric_names is None: return metrics - else: - filtered_metrics = {} - for name, metric in metrics.items(): - if any(re.match(f'.*{metric_name}.*', name, re.IGNORECASE) for metric_name in metric_names): - filtered_metrics[name] = metric - return filtered_metrics + filtered_metrics = {} + for name, metric in metrics.items(): + if any(re.match(f'.*{metric_name}.*', name, re.IGNORECASE) for metric_name in metric_names): + filtered_metrics[name] = metric + return filtered_metrics def _validate_precision(precision: Precision, device: Device): diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index d51a9123ac..6408c008b6 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -167,6 +167,46 @@ def test_compile_unsupported_torch_version_exception(self, caplog, model: Compos auto_log_hparams=True, compile_config={}) + def test_eval_metrics(self): + model = SimpleModel() + train_dataloader = DataLoader(RandomClassificationDataset(size=1), batch_size=1) + all_metrics = model.get_metrics(is_train=False) + + # Test default eval metrics + trainer = Trainer( + model=model, + train_dataloader=train_dataloader, + eval_dataloader=Evaluator(label='eval', + dataloader=DataLoader(RandomClassificationDataset(size=1), batch_size=1)), + ) + + assert trainer.state.eval_metrics['eval'] == all_metrics + + # Test empty eval metrics + trainer = Trainer( + model=model, + train_dataloader=train_dataloader, + eval_dataloader=Evaluator(label='eval', + dataloader=DataLoader(RandomClassificationDataset(size=1), batch_size=1), + metric_names=[]), + ) + + assert trainer.state.eval_metrics['eval'] == {} + + # Test selected eval metrics + single_metric = next(iter(all_metrics)) + trainer = Trainer( + model=model, + train_dataloader=train_dataloader, + eval_dataloader=Evaluator(label='eval', + dataloader=DataLoader(RandomClassificationDataset(size=1), batch_size=1), + metric_names=[single_metric]), + ) + + eval_metric_names = trainer.state.eval_metrics['eval'].keys() + assert len(eval_metric_names) == 1 + assert next(iter(eval_metric_names)) == single_metric + def _assert_optimizer_is_on_device(optimizer: torch.optim.Optimizer): for state in optimizer.state.values():