Fix mnist example (microsoft#4926)
Thiago Crepaldi authored Aug 26, 2020
1 parent 438babd commit cac2575
Showing 1 changed file with 12 additions and 12 deletions.
24 changes: 12 additions & 12 deletions orttraining/pytorch_frontend_examples/mnist_training.py
@@ -1,10 +1,7 @@
## This code is from https://github.com/pytorch/examples/blob/master/mnist/main.py
## with modification to do training using onnxruntime as backend on cuda device.
## A private PyTorch build from https://aiinfra.visualstudio.com/Lotus/_git/pytorch (ORTTraining branch) is needed to run the demo.
## To run the demo with ORT backend:
## python mnist_training.py --use-ort

## When "--use-ort" is not given, it will run training with PyTorch as backend.
## Model testing is not complete.

from __future__ import print_function
@@ -88,20 +85,13 @@ def main():
help='number of epochs to train (default: 10)')
parser.add_argument('--lr', type=float, default=0.01, metavar='LR',
help='learning rate (default: 0.01)')
parser.add_argument('--momentum', type=float, default=0.5, metavar='M',
help='SGD momentum (default: 0.5)')
parser.add_argument('--no-cuda', action='store_true', default=False,
help='disables CUDA training')
parser.add_argument('--seed', type=int, default=1, metavar='S',
help='random seed (default: 1)')
parser.add_argument('--log-interval', type=int, default=10, metavar='N',
help='how many batches to wait before logging training status')

parser.add_argument('--save-model', action='store_true', default=False,
help='For Saving the current Model')

parser.add_argument('--use-ort', action='store_true', default=False,
help='to use onnxruntime as training backend')

args = parser.parse_args()
use_cuda = not args.no_cuda and torch.cuda.is_available()
@@ -141,8 +131,18 @@ def main():

model_desc = mnist_model_description()
# use log_interval as gradient accumulate steps
-trainer = ORTTrainer(model, my_loss, model_desc, "LambOptimizer", None, IODescription('Learning_Rate', [1,], torch.float32), device, 1, None,
-                     args.world_rank, args.world_size, use_mixed_precision=False, allreduce_post_accumulation = True)
+trainer = ORTTrainer(model,
+                     my_loss,
+                     model_desc,
+                     "SGDOptimizer",
+                     None,
+                     IODescription('Learning_Rate', [1,], torch.float32),
+                     device,
+                     1,
+                     args.world_rank,
+                     args.world_size,
+                     use_mixed_precision=False,
+                     allreduce_post_accumulation=True)
print('\nBuild ort model done.')

for epoch in range(1, args.epochs + 1):
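For reference outside the diff view, below is a condensed, self-contained sketch of how the fixed example wires up the trainer. It assumes the legacy ORTTrainer API this example targets (onnxruntime.capi.ort_trainer, as shipped around the ONNX Runtime 1.4 era); the NeuralNet, my_loss, and mnist_model_description helpers are paraphrased from the rest of mnist_training.py rather than copied from this commit, so treat names and shapes as illustrative.

import torch
import torch.nn as nn
import torch.nn.functional as F
from onnxruntime.capi.ort_trainer import ORTTrainer, IODescription, ModelDescription

class NeuralNet(nn.Module):
    # The example's simple two-layer MLP over flattened 28x28 MNIST images.
    def __init__(self, input_size, hidden_size, num_classes):
        super(NeuralNet, self).__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        return self.fc2(self.relu(self.fc1(x)))

def my_loss(x, target):
    # Negative log-likelihood over log-softmax outputs.
    return F.nll_loss(F.log_softmax(x, dim=1), target)

def mnist_model_description():
    # Names and shapes the exported ONNX graph exposes; 'batch' marks a
    # dynamic axis.
    input_desc = IODescription('input1', ['batch', 784], torch.float32)
    label_desc = IODescription('label', ['batch', ], torch.int64, num_classes=10)
    loss_desc = IODescription('loss', [], torch.float32)
    probability_desc = IODescription('probability', ['batch', 10], torch.float32)
    return ModelDescription([input_desc, label_desc], [loss_desc, probability_desc])

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = NeuralNet(input_size=784, hidden_size=500, num_classes=10)
model_desc = mnist_model_description()

# The fixed call from this commit: SGDOptimizer instead of LambOptimizer, and
# the stray positional argument between the accumulation-step count and
# world_rank is gone. The script takes rank/size from args; 0 and 1 are the
# single-process defaults.
trainer = ORTTrainer(model, my_loss, model_desc, "SGDOptimizer", None,
                     IODescription('Learning_Rate', [1, ], torch.float32),
                     device, 1, 0, 1,
                     use_mixed_precision=False,
                     allreduce_post_accumulation=True)

# One training step then looks like the following, where data is a
# [batch, 784] float tensor, target a [batch] int64 tensor, and the learning
# rate is a one-element tensor matching the Learning_Rate IODescription:
# loss = trainer.train_step(data, target, torch.tensor([0.01]))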
