-
Notifications
You must be signed in to change notification settings - Fork 9.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[DOC] Update mnist.py example #1270
Comments
Proposal:
diff --git a/mnist/main.py b/mnist/main.py
index 184dc47..a3cffd1 100644
--- a/mnist/main.py
+++ b/mnist/main.py
@@ -3,13 +3,14 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
-from torchvision import datasets, transforms
+from torchvision import datasets
+from torchvision.transforms import v2 as transforms
from torch.optim.lr_scheduler import StepLR
class Net(nn.Module):
def __init__(self):
- super(Net, self).__init__()
+ super().__init__()
self.conv1 = nn.Conv2d(1, 32, 3, 1)
self.conv2 = nn.Conv2d(32, 64, 3, 1)
self.dropout1 = nn.Dropout(0.25)
@@ -33,19 +34,42 @@ class Net(nn.Module):
return output
-def train(args, model, device, train_loader, optimizer, epoch):
+def train_amp(args, model, device, train_loader, opt, epoch, scaler):
model.train()
for batch_idx, (data, target) in enumerate(train_loader):
- data, target = data.to(device), target.to(device)
- optimizer.zero_grad()
+ data, target = data.to(device, memory_format=torch.channels_last), target.to(
+ device
+ )
+ opt.zero_grad()
+ with torch.autocast(device_type=device.type):
+ output = model(data)
+ loss = F.nll_loss(output, target)
+ scaler.scale(loss).backward()
+ scaler.step(opt)
+ scaler.update()
+ if batch_idx % args.log_interval == 0:
+ print(
+ f"Train Epoch: {epoch} [{batch_idx * len(data)}/{len(train_loader.dataset)} ({100.0 * batch_idx / len(train_loader):.0f}%)]\tLoss: {loss.item():.6f}"
+ )
+ if args.dry_run:
+ break
+
+
+def train(args, model, device, train_loader, opt, epoch):
+ model.train()
+ for batch_idx, (data, target) in enumerate(train_loader):
+ data, target = data.to(device, memory_format=torch.channels_last), target.to(
+ device
+ )
+ opt.zero_grad()
output = model(data)
loss = F.nll_loss(output, target)
loss.backward()
- optimizer.step()
+ opt.step()
if batch_idx % args.log_interval == 0:
- print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
- epoch, batch_idx * len(data), len(train_loader.dataset),
- 100. * batch_idx / len(train_loader), loss.item()))
+ print(
+ f"Train Epoch: {epoch} [{batch_idx * len(data)}/{len(train_loader.dataset)} ({100.0 * batch_idx / len(train_loader):.0f}%)]\tLoss: {loss.item():.6f}"
+ )
if args.dry_run:
break
@@ -58,43 +82,125 @@ def test(model, device, test_loader):
for data, target in test_loader:
data, target = data.to(device), target.to(device)
output = model(data)
- test_loss += F.nll_loss(output, target, reduction='sum').item() # sum up batch loss
- pred = output.argmax(dim=1, keepdim=True) # get the index of the max log-probability
+ test_loss += F.nll_loss(
+ output, target, reduction="sum"
+ ).item() # sum up batch loss
+ pred = output.argmax(
+ dim=1, keepdim=True
+ ) # get the index of the max log-probability
correct += pred.eq(target.view_as(pred)).sum().item()
test_loss /= len(test_loader.dataset)
- print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
- test_loss, correct, len(test_loader.dataset),
- 100. * correct / len(test_loader.dataset)))
+ print(
+ f"\nTest set: Average loss: {test_loss:.4f}, Accuracy: {correct}/{len(test_loader.dataset)} ({100.0 * correct / len(test_loader.dataset):.0f}%)\n"
+ )
-def main():
+def parse_args():
# Training settings
- parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
- parser.add_argument('--batch-size', type=int, default=64, metavar='N',
- help='input batch size for training (default: 64)')
- parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N',
- help='input batch size for testing (default: 1000)')
- parser.add_argument('--epochs', type=int, default=14, metavar='N',
- help='number of epochs to train (default: 14)')
- parser.add_argument('--lr', type=float, default=1.0, metavar='LR',
- help='learning rate (default: 1.0)')
- parser.add_argument('--gamma', type=float, default=0.7, metavar='M',
- help='Learning rate step gamma (default: 0.7)')
- parser.add_argument('--no-cuda', action='store_true', default=False,
- help='disables CUDA training')
- parser.add_argument('--no-mps', action='store_true', default=False,
- help='disables macOS GPU training')
- parser.add_argument('--dry-run', action='store_true', default=False,
- help='quickly check a single pass')
- parser.add_argument('--seed', type=int, default=1, metavar='S',
- help='random seed (default: 1)')
- parser.add_argument('--log-interval', type=int, default=10, metavar='N',
- help='how many batches to wait before logging training status')
- parser.add_argument('--save-model', action='store_true', default=False,
- help='For Saving the current Model')
+ parser = argparse.ArgumentParser(description="PyTorch MNIST Example")
+ parser.add_argument(
+ "--batch-size",
+ type=int,
+ default=64,
+ metavar="N",
+ help="input batch size for training (default: 64)",
+ )
+ parser.add_argument(
+ "--test-batch-size",
+ type=int,
+ default=1000,
+ metavar="N",
+ help="input batch size for testing (default: 1000)",
+ )
+ parser.add_argument(
+ "--epochs",
+ type=int,
+ default=14,
+ metavar="N",
+ help="number of epochs to train (default: 14)",
+ )
+ parser.add_argument(
+ "--lr",
+ type=float,
+ default=1.0,
+ metavar="LR",
+ help="learning rate (default: 1.0)",
+ )
+ parser.add_argument(
+ "--gamma",
+ type=float,
+ default=0.7,
+ metavar="M",
+ help="Learning rate step gamma (default: 0.7)",
+ )
+ parser.add_argument(
+ "--no-cuda", action="store_true", default=False, help="disables CUDA training"
+ )
+ parser.add_argument(
+ "--no-mps",
+ action="store_true",
+ default=False,
+ help="disables macOS GPU training",
+ )
+ parser.add_argument(
+ "--dry-run",
+ action="store_true",
+ default=False,
+ help="quickly check a single pass",
+ )
+ parser.add_argument(
+ "--seed", type=int, default=1, metavar="S", help="random seed (default: 1)"
+ )
+ parser.add_argument(
+ "--log-interval",
+ type=int,
+ default=10,
+ metavar="N",
+ help="how many batches to wait before logging training status",
+ )
+ parser.add_argument(
+ "--use-amp",
+ action="store_true",
+ default=False,
+ help="use automatic mixed precision",
+ )
+ parser.add_argument(
+ "--compile-backend",
+ type=str,
+ default="inductor",
+ metavar="BACKEND",
+ help="backend to compile the model with",
+ )
+ parser.add_argument(
+ "--compile-mode",
+ type=str,
+ default="default",
+ metavar="MODE",
+ help="compilation mode",
+ )
+ parser.add_argument(
+ "--save-model",
+ action="store_true",
+ default=False,
+ help="For Saving the current Model",
+ )
+ parser.add_argument(
+ "--data-dir",
+ type=str,
+ default="../data",
+ metavar="DIR",
+ help="path to the data directory",
+ )
args = parser.parse_args()
+
+ return args
+
+
+def main():
+ args = parse_args()
+
use_cuda = not args.no_cuda and torch.cuda.is_available()
use_mps = not args.no_mps and torch.backends.mps.is_available()
@@ -107,32 +213,43 @@ def main():
else:
device = torch.device("cpu")
- train_kwargs = {'batch_size': args.batch_size}
- test_kwargs = {'batch_size': args.test_batch_size}
+ train_kwargs = {"batch_size": args.batch_size}
+ test_kwargs = {"batch_size": args.test_batch_size}
if use_cuda:
- cuda_kwargs = {'num_workers': 1,
- 'pin_memory': True,
- 'shuffle': True}
+ cuda_kwargs = {"num_workers": 1, "pin_memory": True, "shuffle": True}
train_kwargs.update(cuda_kwargs)
test_kwargs.update(cuda_kwargs)
- transform=transforms.Compose([
- transforms.ToTensor(),
- transforms.Normalize((0.1307,), (0.3081,))
- ])
- dataset1 = datasets.MNIST('../data', train=True, download=True,
- transform=transform)
- dataset2 = datasets.MNIST('../data', train=False,
- transform=transform)
- train_loader = torch.utils.data.DataLoader(dataset1,**train_kwargs)
+ transform = transforms.Compose(
+ [
+ transforms.ToImage(),
+ transforms.ToDtype(torch.float32, scale=True),
+ transforms.Normalize(mean=(0.1307,), std=(0.3081,)),
+ ]
+ )
+
+ data_dir = args.data_dir
+
+ dataset1 = datasets.MNIST(data_dir, train=True, download=True, transform=transform)
+ dataset2 = datasets.MNIST(data_dir, train=False, transform=transform)
+ train_loader = torch.utils.data.DataLoader(dataset1, **train_kwargs)
test_loader = torch.utils.data.DataLoader(dataset2, **test_kwargs)
- model = Net().to(device)
- optimizer = optim.Adadelta(model.parameters(), lr=args.lr)
+ model = Net().to(device, memory_format=torch.channels_last)
+ model = torch.compile(model, backend=args.compile_backend, mode=args.compile_mode)
+ optimizer = optim.Adadelta(model.parameters(), lr=torch.tensor(args.lr))
scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma)
+
+ scaler = None
+ if use_cuda and args.use_amp:
+ scaler = torch.GradScaler(device=device)
+
for epoch in range(1, args.epochs + 1):
- train(args, model, device, train_loader, optimizer, epoch)
+ if scaler is None:
+ train(args, model, device, train_loader, optimizer, epoch)
+ else:
+ train_amp(args, model, device, train_loader, optimizer, epoch, scaler)
test(model, device, test_loader)
scheduler.step()
@@ -140,5 +257,5 @@ def main():
torch.save(model.state_dict(), "mnist_cnn.pt")
-if __name__ == '__main__':
+if __name__ == "__main__":
main()
|
CC @svekars
Hey, can I work on this issue?
@orion160, any updates on this?
@doshi-kevin I am not a PyTorch maintainer, but it is great that you are interested in contributing. However, I see two problems. First, you are applying the commit yourself, which means there is neither authorship nor co-authorship credit in the git history. Second, your PR is about 600+ commits long, so it is possible that you did a merge and created an altered history. The usual approach would be to rebase your changes onto the head of the branch and then open the PR.
Sure, I will make the changes, or else create a new fork and a new pull request. @orion160
It seems good.
Hope this gets merged, @orion160 @msaroufim.
Update the example at https://github.com/pytorch/examples/blob/main/mnist/main.py to use torch.compile features.
The text was updated successfully, but these errors were encountered: