diff --git a/composer/loggers/mosaicml_logger.py b/composer/loggers/mosaicml_logger.py index b3a226f2c2..d8be0c8b1f 100644 --- a/composer/loggers/mosaicml_logger.py +++ b/composer/loggers/mosaicml_logger.py @@ -18,8 +18,8 @@ import mcli import torch +from composer.core.time import TimeUnit from composer.loggers import Logger -from composer.loggers.logger import Logger from composer.loggers.logger_destination import LoggerDestination from composer.loggers.wandb_logger import WandBLogger from composer.utils import dist @@ -70,9 +70,9 @@ def __init__( if self._enabled: self.allowed_fails_left = 3 self.time_last_logged = 0 + self.train_dataloader_len = None self.time_failed_count_adjusted = 0 self.buffered_metadata: Dict[str, Any] = {} - self.run_name = os.environ.get(RUN_NAME_ENV_VAR) if self.run_name is not None: log.info(f'Logging to mosaic run {self.run_name}') @@ -88,6 +88,8 @@ def log_metrics(self, metrics: Dict[str, Any], step: Optional[int] = None) -> No self._log_metadata(metrics) def after_load(self, state: State, logger: Logger) -> None: + # Log model data downloaded and initialized for run events + self._log_metadata({'model_initialized_time': time.time()}) # Log WandB run URL if it exists. Must run on after_load as WandB is setup on event init for callback in state.callbacks: if isinstance(callback, WandBLogger): @@ -95,13 +97,62 @@ def after_load(self, state: State, logger: Logger) -> None: if run_url is not None: self._log_metadata({'wandb/run_url': run_url}) + def _get_training_progress_metrics(self, state: State) -> Dict[str, Any]: + """Calculates training progress metrics. + + If user submits max duration: + - in tokens -> format: [token=x/xx] + - in batches -> format: [batch=x/xx] + - in epoch -> format: [epoch=x/xx] [batch=x/xx] (where batch refers to batches completed in current epoch) + If batches per epoch cannot be calculated, return [epoch=x/xx] + + If no training duration given -> format: '' + """ + if not self._enabled: + return {} + + assert state.max_duration is not None + if state.max_duration.unit == TimeUnit.TOKEN: + return { + 'training_progress': f'[token={state.timestamp.token.value}/{state.max_duration.value}]', + } + if state.max_duration.unit == TimeUnit.BATCH: + return { + 'training_progress': f'[batch={state.timestamp.batch.value}/{state.max_duration.value}]', + } + training_progress_metrics = {} + if state.max_duration.unit == TimeUnit.EPOCH: + cur_batch = state.timestamp.batch_in_epoch.value + cur_epoch = state.timestamp.epoch.value + if state.timestamp.epoch.value >= 1: + batches_per_epoch = (state.timestamp.batch - + state.timestamp.batch_in_epoch).value // state.timestamp.epoch.value + curr_progress = f'[batch={cur_batch}/{batches_per_epoch}]' + elif self.train_dataloader_len is not None: + curr_progress = f'[batch={cur_batch}/{self.train_dataloader_len}]' + else: + curr_progress = f'[batch={cur_batch}]' + if cur_epoch < state.max_duration.value: + cur_epoch += 1 + training_progress_metrics = { + 'training_sub_progress': curr_progress, + 'training_progress': f'[epoch={cur_epoch}/{state.max_duration.value}]', + } + return training_progress_metrics + + def batch_start(self, state: State, logger: Logger) -> None: + if state.dataloader_len is not None and self._enabled: + self.train_dataloader_len = state.dataloader_len.value + def batch_end(self, state: State, logger: Logger) -> None: + self._log_metadata(self._get_training_progress_metrics(state)) self._flush_metadata() def epoch_end(self, state: State, logger: Logger) -> None: self._flush_metadata() def fit_end(self, state: State, logger: Logger) -> None: + self._log_metadata(self._get_training_progress_metrics(state)) self._flush_metadata(force_flush=True) def eval_end(self, state: State, logger: Logger) -> None: diff --git a/tests/loggers/test_mosaicml_logger.py b/tests/loggers/test_mosaicml_logger.py index 53211f0da5..136af6f107 100644 --- a/tests/loggers/test_mosaicml_logger.py +++ b/tests/loggers/test_mosaicml_logger.py @@ -3,13 +3,14 @@ import json from typing import Type +from unittest.mock import MagicMock import mcli import pytest import torch from torch.utils.data import DataLoader -from composer.core import Callback +from composer.core import Callback, Time, TimeUnit from composer.loggers import WandBLogger from composer.loggers.mosaicml_logger import (MOSAICML_ACCESS_TOKEN_ENV_VAR, MOSAICML_PLATFORM_ENV_VAR, MosaicMLLogger, format_data_to_json_serializable) @@ -194,3 +195,89 @@ def test_auto_add_logger(monkeypatch, platform_env_var, access_token_env_var, lo # Otherwise, no logger else: assert logger_count == 0 + + +def test_run_events_logged(monkeypatch): + '''' + Current run events include: + 1. model initialization time + 2. training progress (i.e. [batch=x/xx] at batch end) + ''' + mock_mapi = MockMAPI() + monkeypatch.setattr(mcli, 'update_run_metadata', mock_mapi.update_run_metadata) + run_name = 'test-run-name' + monkeypatch.setenv('RUN_NAME', run_name) + trainer = Trainer(model=SimpleModel(), + train_dataloader=DataLoader(RandomClassificationDataset()), + train_subset_num_batches=1, + max_duration='4ba', + loggers=[MosaicMLLogger()]) + trainer.fit() + metadata = mock_mapi.run_metadata[run_name] + assert isinstance(metadata['mosaicml/model_initialized_time'], float) + assert 'mosaicml/training_progress' in metadata + assert metadata['mosaicml/training_progress'] == '[batch=4/4]' + assert 'mosaicml/training_sub_progress' not in metadata + + +def test_token_training_progress_metrics(): + logger = MosaicMLLogger() + logger._enabled = True + state = MagicMock() + state.max_duration.unit = TimeUnit.TOKEN + state.max_duration.value = 64 + state.timestamp.token.value = 50 + training_progress = logger._get_training_progress_metrics(state) + assert 'training_progress' in training_progress + assert training_progress['training_progress'] == '[token=50/64]' + assert 'training_sub_progress' not in training_progress + + +def test_epoch_training_progress_metrics(): + logger = MosaicMLLogger() + logger._enabled = True + state = MagicMock() + state.max_duration.unit = TimeUnit.EPOCH + state.max_duration = Time(3, TimeUnit.EPOCH) + state.timestamp.epoch = Time(2, TimeUnit.EPOCH) + state.timestamp.batch = Time(11, TimeUnit.BATCH) + state.timestamp.batch_in_epoch = Time(1, TimeUnit.BATCH) + training_progress = logger._get_training_progress_metrics(state) + assert 'training_progress' in training_progress + assert training_progress['training_progress'] == '[epoch=3/3]' + assert 'training_sub_progress' in training_progress + assert training_progress['training_sub_progress'] == '[batch=1/5]' + + +def test_epoch_zero_progress_metrics(): + logger = MosaicMLLogger() + logger._enabled = True + state = MagicMock() + logger.train_dataloader_len = 5 + state.max_duration.unit = TimeUnit.EPOCH + state.max_duration = Time(3, TimeUnit.EPOCH) + state.timestamp.epoch = Time(0, TimeUnit.EPOCH) + state.timestamp.batch = Time(0, TimeUnit.BATCH) + state.timestamp.batch_in_epoch = Time(0, TimeUnit.BATCH) + training_progress = logger._get_training_progress_metrics(state) + assert 'training_progress' in training_progress + assert training_progress['training_progress'] == '[epoch=1/3]' + assert 'training_sub_progress' in training_progress + assert training_progress['training_sub_progress'] == '[batch=0/5]' + + +def test_epoch_zero_no_dataloader_progress_metrics(): + logger = MosaicMLLogger() + logger._enabled = True + state = MagicMock() + state.dataloader_len = None + state.max_duration.unit = TimeUnit.EPOCH + state.max_duration = Time(3, TimeUnit.EPOCH) + state.timestamp.epoch = Time(0, TimeUnit.EPOCH) + state.timestamp.batch = Time(1, TimeUnit.BATCH) + state.timestamp.batch_in_epoch = Time(1, TimeUnit.BATCH) + training_progress = logger._get_training_progress_metrics(state) + assert 'training_progress' in training_progress + assert training_progress['training_progress'] == '[epoch=1/3]' + assert 'training_sub_progress' in training_progress + assert training_progress['training_sub_progress'] == '[batch=1]'