[REFACTOR] train.py to consolidate common logic for both single GPU and multi GPU training (HorizonRobotics#913) (HorizonRobotics#944)

* [REFACTOR] train.py to consolidate common logic for both single GPU and multi GPU training

* Address Wei's comments

* Address Haonan's comments

* Specify authoritative url and port as well

* Remove unused Optional typing
breakds authored and pd-perry committed Dec 11, 2021
1 parent 2119dbc commit c1cb3ea
Showing 1 changed file with 105 additions and 28 deletions.
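For orientation before the diff: a minimal sketch of how the reworked entry point would be launched in its two modes. Only --root_dir, --conf_param, --store_snapshot and --distributed appear in the diff below, so the --conf flag name and both example paths here are assumptions used purely for illustration.

# Hedged sketch of the two launch modes handled by the new main().
# The --conf flag name and the paths are assumptions; --distributed is the
# flag added in this commit ('none' by default, or 'multi-gpu').
single_gpu = ('python -m alf.bin.train'
              ' --conf alf/examples/my_task_conf.py'
              ' --root_dir /tmp/my_experiment')

# Same command plus the new flag: main() spawns one training_worker per
# visible GPU, or falls back to a single worker when only one GPU exists.
multi_gpu = single_gpu + ' --distributed=multi-gpu'

print(single_gpu)
print(multi_gpu)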
133 changes: 105 additions & 28 deletions alf/bin/train.py
@@ -51,6 +51,8 @@
import os
import pathlib
import torch
import torch.distributed as dist
import torch.multiprocessing as mp

from alf.utils import common
import alf.utils.external_configurables
@@ -67,54 +69,129 @@ def _define_flags():
flags.DEFINE_multi_string('conf_param', None, 'Config binding parameters.')
flags.DEFINE_bool('store_snapshot', True,
'Whether to store an ALF snapshot before training')
flags.DEFINE_enum(
'distributed', 'none', ['none', 'multi-gpu'],
'Set whether and how to run training in distributed mode.')
flags.mark_flag_as_required('root_dir')


FLAGS = flags.FLAGS


@alf.configurable
def train_eval(root_dir):
"""Train and evaluate algorithm
def _setup_logging(rank: int, log_dir: str):
"""Setup logging for each process
Args:
rank (int): The ID of the process among all of the DDP processes
log_dir (str): path to the directory where log files are written to
"""
FLAGS.alsologtostderr = True
logging.set_verbosity(logging.INFO)
logging.get_absl_handler().use_absl_log_file(log_dir=log_dir)


def _setup_device(rank: int = 0):
"""Setup the GPU device for each process
All tensors of the calling process will use the GPU with the
specified rank by default.
Args:
root_dir (str): directory for saving summary and checkpoints
rank (int): The ID of the process among all of the DDP processes
"""
trainer_conf = policy_trainer.TrainerConfig(root_dir=root_dir)
if trainer_conf.ml_type == 'rl':
trainer = policy_trainer.RLTrainer(trainer_conf)
elif trainer_conf.ml_type == 'sl':
trainer = policy_trainer.SLTrainer(trainer_conf)
else:
raise ValueError("Unsupported ml_type: %s" % trainer_conf.ml_type)
if torch.cuda.is_available():
alf.set_default_device('cuda')
torch.cuda.set_device(rank)

trainer.train()

def training_worker(rank: int, world_size: int, conf_file: str, root_dir: str):
"""An executable instance that trains and evaluate the algorithm
Args:
rank (int): The ID of the process among all of the DDP processes.
world_size (int): The number of processes in total. If set to 1, it is interpreted as "non-distributed mode".
conf_file (str): Path to the training configuration.
root_dir (str): Path to the directory for writing logs/summaries/checkpoints.
"""
try:
_setup_logging(log_dir=root_dir, rank=rank)
_setup_device(rank)
if world_size > 1:
# Specialization for distributed mode
dist.init_process_group('nccl', rank=rank, world_size=world_size)
# TODO(breakds): Remove this when DDP is finally working
# TODO(breakds): Also update the file level documentation when DDP is working
raise RuntimeError(
"Mutli-GPU DDP training is under development and temporarily unavailble"
)

# Parse the configuration file, which will also implicitly bring up the environments.
common.parse_conf_file(conf_file)
trainer_conf = policy_trainer.TrainerConfig(root_dir=root_dir)

if trainer_conf.ml_type == 'rl':
trainer = policy_trainer.RLTrainer(trainer_conf)
elif trainer_conf.ml_type == 'sl':
# NOTE: SLTrainer does not support distributed training yet
if world_size > 1:
raise RuntimeError(
"Multi-GPU DDP training does not support supervised learning"
)
trainer = policy_trainer.SLTrainer(trainer_conf)
else:
raise ValueError("Unsupported ml_type: %s" % trainer_conf.ml_type)

trainer.train()
except Exception as e:
# If the training worker is running as a process in a multiprocessing
# environment, this will make sure that the exception raised in this
# particular process is captured and shown.
logging.exception(e)
finally:
# Note that each training worker will have its own child processes
# running the environments. In the case when the training worker process
# finishes earlier (e.g. when it raises an exception), it will hang
# instead of quitting unless all child processes are killed.
alf.close_env()


def main(_):
FLAGS.alsologtostderr = True
root_dir = common.abs_path(FLAGS.root_dir)
os.makedirs(root_dir, exist_ok=True)
logging.get_absl_handler().use_absl_log_file(log_dir=root_dir)

if FLAGS.store_snapshot:
# ../<ALF_REPO>/alf/bin/train.py
file_path = os.path.abspath(__file__)
alf_root = str(pathlib.Path(file_path).parent.parent.parent.absolute())
# generate a snapshot of ALF repo as ``<root_dir>/alf``
common.generate_alf_root_snapshot(alf_root, root_dir)
common.generate_alf_root_snapshot(common.alf_root(), root_dir)

conf_file = common.get_conf_file()
try:
common.parse_conf_file(conf_file)
train_eval(root_dir)
finally:
alf.close_env()

# FLAGS.distributed is guaranteed to be one of the possible values.
if FLAGS.distributed == 'none':
training_worker(
rank=0, world_size=1, conf_file=conf_file, root_dir=root_dir)
elif FLAGS.distributed == 'multi-gpu':
world_size = torch.cuda.device_count()

if world_size == 1:
logging.warn(
'Falling back to single GPU mode as there is only one GPU')
training_worker(
rank=0, world_size=1, conf_file=conf_file, root_dir=root_dir)
return

# The other processes will communicate with the authoritative
# process via network protocol on localhost:12355.
os.environ['MASTER_ADDR'] = 'localhost'
os.environ['MASTER_PORT'] = '12355'

processes = mp.spawn(
training_worker,
args=(world_size, conf_file, root_dir),
join=True,
nprocs=world_size,
start_method='spawn')


if __name__ == '__main__':
_define_flags()
logging.set_verbosity(logging.INFO)
flags.mark_flag_as_required('root_dir')
if torch.cuda.is_available():
alf.set_default_device("cuda")
app.run(main)
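Because the multi-GPU branch above is still gated behind a RuntimeError (see the TODOs), the following is a minimal, self-contained sketch of the same mp.spawn / init_process_group bootstrap that main() wires up. It uses the gloo backend so it can run on CPU-only machines and a toy all_reduce instead of a trainer; it illustrates the general PyTorch DDP pattern and is not code from this repository.

import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def _toy_worker(rank: int, world_size: int):
    # Every worker must agree on the rendezvous address, mirroring the
    # MASTER_ADDR/MASTER_PORT setup in main(); the port here is arbitrary.
    os.environ.setdefault('MASTER_ADDR', 'localhost')
    os.environ.setdefault('MASTER_PORT', '12356')
    dist.init_process_group('gloo', rank=rank, world_size=world_size)
    try:
        # Each rank contributes its own rank value; after all_reduce every
        # rank holds the sum 0 + 1 + ... + (world_size - 1).
        t = torch.tensor([float(rank)])
        dist.all_reduce(t, op=dist.ReduceOp.SUM)
        print(f'rank {rank}: all_reduce result = {t.item()}')
    finally:
        dist.destroy_process_group()


if __name__ == '__main__':
    world_size = 2
    mp.spawn(_toy_worker, args=(world_size, ), nprocs=world_size, join=True)

As in the commit, the spawned function receives its rank as the first positional argument, and args only carries the remaining arguments.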
