From b96c07d4425996c1d38e71c0bcf7dc9ab548a36e Mon Sep 17 00:00:00 2001 From: ZiyueHuang Date: Tue, 1 Sep 2020 10:02:39 +0000 Subject: [PATCH] distributed training --- scripts/pretraining/run_electra.py | 12 +++++++++--- src/gluonnlp/utils/misc.py | 14 ++++++++++++++ 2 files changed, 23 insertions(+), 3 deletions(-) diff --git a/scripts/pretraining/run_electra.py b/scripts/pretraining/run_electra.py index 1678eeae8d..9522a6fb8c 100644 --- a/scripts/pretraining/run_electra.py +++ b/scripts/pretraining/run_electra.py @@ -22,6 +22,10 @@ import horovod.mxnet as hvd except ImportError: pass +try: + import byteps.mxnet as bps +except ImportError: + pass mx.npx.set_np() @@ -118,7 +122,7 @@ def parse_args(): help='The scale size of the generator layer') # Communication parser.add_argument('--comm_backend', type=str, default='device', - choices=['horovod', 'dist_sync_device', 'device'], + choices=['byteps', 'horovod', 'dist_sync_device', 'device'], help='Communication backend.') parser.add_argument('--gpus', type=str, default='0', help='list of gpus to run, e.g. 0 or 0,2,5. 
-1 means using cpu.') @@ -316,6 +320,8 @@ def train(args): }) if args.comm_backend == 'horovod': trainer = hvd.DistributedTrainer(param_dict, args.optimizer, optimizer_params) + elif args.comm_backend == 'byteps': + trainer = bps.DistributedTrainer(param_dict, args.optimizer, optimizer_params) else: trainer = mx.gluon.Trainer(param_dict, args.optimizer, optimizer_params, update_on_kvstore=False) @@ -414,8 +420,8 @@ def train(args): total_norm, ratio, is_finite = clip_grad_global_norm( params, args.max_grad_norm * num_workers) - if args.comm_backend == 'horovod': - # Note that horovod.trainer._scale is default to num_workers, + if args.comm_backend == 'horovod' or args.comm_backend == 'byteps': + # Note that hvd.trainer._scale and bps.trainer._scale both default to num_workers, # thus trainer.update(1) will scale the gradients by 1./num_workers trainer.update(1, ignore_stale_grad=True) else: diff --git a/src/gluonnlp/utils/misc.py b/src/gluonnlp/utils/misc.py index 38d1fa6258..7a35f221b3 100644 --- a/src/gluonnlp/utils/misc.py +++ b/src/gluonnlp/utils/misc.py @@ -637,6 +637,20 @@ def init_comm(backend, gpus): is_master_node = rank == local_rank ctx_l = [mx.gpu(local_rank)] logging.info('GPU communication supported by horovod') + elif backend == 'byteps': + try: + import byteps.mxnet as bps + except ImportError: + logging.info('BytePS must be installed.') + sys.exit(1) + bps.init() + store = None + num_workers = bps.size() + rank = bps.rank() + local_rank = bps.local_rank() + is_master_node = rank == local_rank + ctx_l = [mx.gpu(local_rank)] + logging.info('GPU communication supported by BytePS') else: store = mx.kv.create(backend) num_workers = store.num_workers