From c1cb3eac1c374ed612f5db39241e0dabc784efcc Mon Sep 17 00:00:00 2001
From: Break Yang
Date: Mon, 26 Jul 2021 16:46:08 -0700
Subject: [PATCH] [REFACTOR] train.py to consolidate common logic for both single GPU and multi GPU training (#913) (#944)

* [REFACTOR] train.py to consolidate common logic for both single GPU and
  multi GPU training

* Address Wei's comments

* Address Haonan's comments

* Specify authoritative URL and port as well

* Remove unused Optional typing
---
 alf/bin/train.py | 133 +++++++++++++++++++++++++++++++++++++----------
 1 file changed, 105 insertions(+), 28 deletions(-)

diff --git a/alf/bin/train.py b/alf/bin/train.py
index 0f2b85dac..9aeaf515e 100644
--- a/alf/bin/train.py
+++ b/alf/bin/train.py
@@ -51,6 +51,8 @@
 import os
 import pathlib
 import torch
+import torch.distributed as dist
+import torch.multiprocessing as mp
 
 from alf.utils import common
 import alf.utils.external_configurables
@@ -67,54 +69,129 @@ def _define_flags():
     flags.DEFINE_multi_string('conf_param', None, 'Config binding parameters.')
     flags.DEFINE_bool('store_snapshot', True,
                       'Whether store an ALF snapshot before training')
+    flags.DEFINE_enum(
+        'distributed', 'none', ['none', 'multi-gpu'],
+        'Set whether and how to run training in distributed mode.')
+    flags.mark_flag_as_required('root_dir')
 
 
 FLAGS = flags.FLAGS
 
 
-@alf.configurable
-def train_eval(root_dir):
-    """Train and evaluate algorithm
+def _setup_logging(rank: int, log_dir: str):
+    """Set up logging for each process.
+
+    Args:
+        rank (int): The ID of the process among all of the DDP processes.
+        log_dir (str): Path to the directory where log files are written.
+    """
+    FLAGS.alsologtostderr = True
+    logging.set_verbosity(logging.INFO)
+    logging.get_absl_handler().use_absl_log_file(log_dir=log_dir)
+
+
+def _setup_device(rank: int = 0):
+    """Set up the GPU device for each process.
+
+    All tensors of the calling process will use the GPU with the
+    specified rank by default.
 
     Args:
-        root_dir (str): directory for saving summary and checkpoints
+        rank (int): The ID of the process among all of the DDP processes.
+    """
-    trainer_conf = policy_trainer.TrainerConfig(root_dir=root_dir)
-    if trainer_conf.ml_type == 'rl':
-        trainer = policy_trainer.RLTrainer(trainer_conf)
-    elif trainer_conf.ml_type == 'sl':
-        trainer = policy_trainer.SLTrainer(trainer_conf)
-    else:
-        raise ValueError("Unsupported ml_type: %s" % trainer_conf.ml_type)
+    if torch.cuda.is_available():
+        alf.set_default_device('cuda')
+        torch.cuda.set_device(rank)
 
-    trainer.train()
+
+def training_worker(rank: int, world_size: int, conf_file: str, root_dir: str):
+    """An executable instance that trains and evaluates the algorithm.
+
+    Args:
+        rank (int): The ID of the process among all of the DDP processes.
+        world_size (int): The number of processes in total. If set to 1, it is
+            interpreted as "non-distributed mode".
+        conf_file (str): Path to the training configuration.
+        root_dir (str): Path to the directory for writing logs/summaries/checkpoints.
+    """
+    try:
+        _setup_logging(log_dir=root_dir, rank=rank)
+        _setup_device(rank)
+        if world_size > 1:
+            # Specialization for distributed mode
+            dist.init_process_group('nccl', rank=rank, world_size=world_size)
+            # TODO(breakds): Remove this when DDP is finally working
+            # TODO(breakds): Also update the file level documentation when DDP is working
+            raise RuntimeError(
+                "Multi-GPU DDP training is under development and temporarily unavailable"
+            )
+
+        # Parse the configuration file, which will also implicitly bring up the environments.
+        common.parse_conf_file(conf_file)
+        trainer_conf = policy_trainer.TrainerConfig(root_dir=root_dir)
+
+        if trainer_conf.ml_type == 'rl':
+            trainer = policy_trainer.RLTrainer(trainer_conf)
+        elif trainer_conf.ml_type == 'sl':
+            # NOTE: SLTrainer does not support distributed training yet
+            if world_size > 1:
+                raise RuntimeError(
+                    "Multi-GPU DDP training does not support supervised learning"
+                )
+            trainer = policy_trainer.SLTrainer(trainer_conf)
+        else:
+            raise ValueError("Unsupported ml_type: %s" % trainer_conf.ml_type)
+
+        trainer.train()
+    except Exception as e:
+        # If the training worker is running as a process in a multiprocessing
+        # environment, this will make sure that the exception raised in this
+        # particular process is captured and shown.
+        logging.exception(e)
+    finally:
+        # Note that each training worker will have its own child processes
+        # running the environments. In the case when the training worker
+        # process finishes earlier (e.g. when it raises an exception), it will
+        # hang instead of quitting unless all child processes are killed.
+        alf.close_env()
 
 
 def main(_):
-    FLAGS.alsologtostderr = True
     root_dir = common.abs_path(FLAGS.root_dir)
     os.makedirs(root_dir, exist_ok=True)
-    logging.get_absl_handler().use_absl_log_file(log_dir=root_dir)
     if FLAGS.store_snapshot:
-        # ..//alf/bin/train.py
-        file_path = os.path.abspath(__file__)
-        alf_root = str(pathlib.Path(file_path).parent.parent.parent.absolute())
-        # generate a snapshot of ALF repo as ``/alf``
-        common.generate_alf_root_snapshot(alf_root, root_dir)
+        common.generate_alf_root_snapshot(common.alf_root(), root_dir)
     conf_file = common.get_conf_file()
-    try:
-        common.parse_conf_file(conf_file)
-        train_eval(root_dir)
-    finally:
-        alf.close_env()
+
+    # FLAGS.distributed is guaranteed to be one of the possible values.
+    if FLAGS.distributed == 'none':
+        training_worker(
+            rank=0, world_size=1, conf_file=conf_file, root_dir=root_dir)
+    elif FLAGS.distributed == 'multi-gpu':
+        world_size = torch.cuda.device_count()
+
+        if world_size == 1:
+            logging.warn(
+                'Fallback to single GPU mode as there is only one GPU')
+            training_worker(
+                rank=0, world_size=1, conf_file=conf_file, root_dir=root_dir)
+            return
+
+        # The other processes will communicate with the authoritative
+        # process via network protocol on localhost:12355.
+        os.environ['MASTER_ADDR'] = 'localhost'
+        os.environ['MASTER_PORT'] = '12355'
+
+        processes = mp.spawn(
+            training_worker,
+            args=(world_size, conf_file, root_dir),
+            join=True,
+            nprocs=world_size,
+            start_method='spawn')
 
 
 if __name__ == '__main__':
     _define_flags()
-    logging.set_verbosity(logging.INFO)
-    flags.mark_flag_as_required('root_dir')
-    if torch.cuda.is_available():
-        alf.set_default_device("cuda")
     app.run(main)
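
Although the multi-GPU branch above still raises a ``RuntimeError`` (DDP support is under development), the bootstrap it sets up is the standard ``torch.multiprocessing.spawn`` + ``torch.distributed.init_process_group`` pattern: ``main()`` publishes the rendezvous address via ``MASTER_ADDR``/``MASTER_PORT``, ``mp.spawn`` starts one process per GPU and passes each process its rank as the first positional argument, and each worker then joins the ``nccl`` process group. The snippet below is a minimal, self-contained sketch of that pattern for reference only; the worker name ``_demo_worker`` is hypothetical and not part of this patch.

    import os

    import torch
    import torch.distributed as dist
    import torch.multiprocessing as mp


    def _demo_worker(rank: int, world_size: int):
        # ``mp.spawn`` supplies ``rank`` as the first positional argument.
        dist.init_process_group('nccl', rank=rank, world_size=world_size)
        # Bind this process to its own GPU so that collectives use one
        # device per rank.
        torch.cuda.set_device(rank)
        # ... build the model, wrap it with DistributedDataParallel,
        # and run the training loop here ...
        dist.destroy_process_group()


    if __name__ == '__main__':
        world_size = torch.cuda.device_count()
        # Every process must agree on the rendezvous address before
        # ``init_process_group`` is called; the patch uses localhost:12355.
        os.environ['MASTER_ADDR'] = 'localhost'
        os.environ['MASTER_PORT'] = '12355'
        mp.spawn(_demo_worker, args=(world_size, ), nprocs=world_size, join=True)

This is also why ``training_worker`` takes ``rank`` as its first parameter: in the single-GPU path it is not spawned at all, but simply called directly with ``rank=0`` and ``world_size=1``.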