diff --git a/mnist_rnn/main.py b/mnist_rnn/main.py index 57e86cd57d..2fa64c00d6 100644 --- a/mnist_rnn/main.py +++ b/mnist_rnn/main.py @@ -91,8 +91,10 @@ def main(): help='learning rate (default: 0.1)') parser.add_argument('--gamma', type=float, default=0.7, metavar='M', help='learning rate step gamma (default: 0.7)') - parser.add_argument('--no-cuda', action='store_true', default=False, - help='disables CUDA training') + parser.add_argument('--cuda', action='store_true', default=False, + help='enables CUDA training') + parser.add_argument('--mps', action="store_true", default=False, + help="enables MPS training") parser.add_argument('--dry-run', action='store_true', default=False, help='quickly check a single pass') parser.add_argument('--seed', type=int, default=1, metavar='S', @@ -102,13 +104,19 @@ def main(): parser.add_argument('--save-model', action='store_true', default=False, help='for Saving the current Model') args = parser.parse_args() - use_cuda = not args.no_cuda and torch.cuda.is_available() - torch.manual_seed(args.seed) + if args.cuda and not args.mps: + device = "cuda" + elif args.mps and not args.cuda: + device = "mps" + else: + device = "cpu" + + device = torch.device(device) - device = torch.device("cuda" if use_cuda else "cpu") + torch.manual_seed(args.seed) - kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {} + kwargs = {'num_workers': 1, 'pin_memory': True} if args.cuda else {} train_loader = torch.utils.data.DataLoader( datasets.MNIST('../data', train=True, download=True, transform=transforms.Compose([