Skip to content

Commit

Permalink
Add MPS device (#1197)
Browse files Browse the repository at this point in the history
* Add MPS device

* Fix device setting
  • Loading branch information
chmjkb authored Jan 12, 2024
1 parent de85c09 commit 3e56db2
Showing 1 changed file with 14 additions and 6 deletions.
20 changes: 14 additions & 6 deletions mnist_rnn/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,8 +91,10 @@ def main():
help='learning rate (default: 0.1)')
parser.add_argument('--gamma', type=float, default=0.7, metavar='M',
help='learning rate step gamma (default: 0.7)')
parser.add_argument('--no-cuda', action='store_true', default=False,
help='disables CUDA training')
parser.add_argument('--cuda', action='store_true', default=False,
help='enables CUDA training')
parser.add_argument('--mps', action="store_true", default=False,
help="enables MPS training")
parser.add_argument('--dry-run', action='store_true', default=False,
help='quickly check a single pass')
parser.add_argument('--seed', type=int, default=1, metavar='S',
Expand All @@ -102,13 +104,19 @@ def main():
parser.add_argument('--save-model', action='store_true', default=False,
help='for Saving the current Model')
args = parser.parse_args()
use_cuda = not args.no_cuda and torch.cuda.is_available()

torch.manual_seed(args.seed)
if args.cuda and not args.mps:
device = "cuda"
elif args.mps and not args.cuda:
device = "mps"
else:
device = "cpu"

device = torch.device(device)

device = torch.device("cuda" if use_cuda else "cpu")
torch.manual_seed(args.seed)

kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}
kwargs = {'num_workers': 1, 'pin_memory': True} if args.cuda else {}
train_loader = torch.utils.data.DataLoader(
datasets.MNIST('../data', train=True, download=True,
transform=transforms.Compose([
Expand Down

0 comments on commit 3e56db2

Please sign in to comment.