
API guide code snippets don't work #8497

Open · richardsliu opened this issue Dec 17, 2024 · 5 comments

Comments

@richardsliu

📚 Documentation

Trying to follow the example here: https://github.com/pytorch/xla/blob/master/API_GUIDE.md#running-on-a-single-xla-device

The Python code snippet doesn't work, as MNIST(), nn, and optim are all undefined.
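
For reference, a minimal sketch of the imports the guide's snippet appears to assume (inferred, since the guide doesn't show them; MNIST would additionally have to resolve to a model class, as discussed below):

import torch
import torch.nn as nn              # the snippet references `nn`
import torch.optim as optim        # the snippet references `optim`
import torch_xla.core.xla_model as xm
# `MNIST` must also be defined as an nn.Module subclass (see below),
# not imported from torchvision.datasets.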

@ManfeiBai
Collaborator

According to @richardsliu, this also does not work:

Traceback (most recent call last):
  File "/home/ricliu/mnist.py", line 8, in <module>
    model = MNIST(root='./data', download=True).train().to(device)
TypeError: 'bool' object is not callable

@richardsliu
Author

I used torchvision.datasets.MNIST, but I don't think that's the right one? The snippet seems to expect MNIST to be a model rather than a dataset.
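
That is exactly the mixup behind the traceback above: torchvision.datasets.MNIST is a dataset whose train attribute is a plain bool, so chaining .train() onto it raises the TypeError. A minimal repro sketch (assuming torchvision is installed):

# torchvision.datasets.MNIST is a dataset, not an nn.Module:
# its `train` attribute is a bool flag, so `.train()` is not callable.
from torchvision.datasets import MNIST

ds = MNIST(root='./data', download=True)
print(type(ds.train))  # <class 'bool'>
ds.train()             # TypeError: 'bool' object is not callable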

@miladm
Collaborator

miladm commented Dec 18, 2024

Thanks for flagging this issue. This is a poor legacy naming decision that certainly needs to be corrected.

Are you able to run this test: https://github.com/pytorch/xla/blob/master/test/test_train_mp_mnist.py?
If you read that code, you can see how the legacy naming decision came about.

cc @tengyifei @mikegre-google to help address this documentation issue, ideally for 2.6.

@richardsliu
Author

richardsliu commented Dec 18, 2024

OK, thanks. I think I got this to run with the following code:

import torch_xla.core.xla_model as xm
from torch_xla import runtime as xr
import torch
import torch_xla.utils.utils as xu
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F


class MNIST(nn.Module):

  def __init__(self):
    super(MNIST, self).__init__()
    self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
    self.bn1 = nn.BatchNorm2d(10)
    self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
    self.bn2 = nn.BatchNorm2d(20)
    self.fc1 = nn.Linear(320, 50)
    self.fc2 = nn.Linear(50, 10)

  def forward(self, x):
    x = F.relu(F.max_pool2d(self.conv1(x), 2))
    x = self.bn1(x)
    x = F.relu(F.max_pool2d(self.conv2(x), 2))
    x = self.bn2(x)
    x = torch.flatten(x, 1)
    x = F.relu(self.fc1(x))
    x = self.fc2(x)
    return F.log_softmax(x, dim=1)

# Hyperparameters
batch_size = 128
momentum = 0.5
lr = 0.01

# Acquire the (single) XLA device and move the model onto it
device = xm.xla_device()
model = MNIST().train().to(device)
loss_fn = nn.NLLLoss()
optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum)

# Synthetic MNIST-shaped batches standing in for a real data loader
train_loader = xu.SampleGenerator(
    data=(torch.zeros(batch_size, 1, 28, 28),
          torch.zeros(batch_size, dtype=torch.int64)),
    sample_count=60000 // batch_size // xr.world_size())


for data, target in train_loader:
  optimizer.zero_grad()
  data = data.to(device)        # explicit host-to-device transfer
  target = target.to(device)
  output = model(data)
  loss = loss_fn(output, target)
  loss.backward()
  optimizer.step()
  xm.mark_step()                # execute the lazily-traced graph for this step
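
For completeness, and to tie back to the original confusion: the MNIST class above is the model, while the real dataset lives in torchvision. A sketch of swapping the synthetic SampleGenerator for the actual dataset (this assumes torchvision is installed; the normalization constants are the conventional MNIST mean/std):

# torchvision.datasets.MNIST is the *dataset*; the MNIST class above is the *model*.
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),  # conventional MNIST mean/std
])
train_dataset = torchvision.datasets.MNIST(
    root='./data', train=True, download=True, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)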

@richardsliu
Author

And the multi-device version:

import torch_xla.core.xla_model as xm
from torch_xla import runtime as xr
import torch
import torch_xla.utils.utils as xu
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch_xla
import torch_xla.distributed.parallel_loader as pl


class MNIST(nn.Module):

  def __init__(self):
    super(MNIST, self).__init__()
    self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
    self.bn1 = nn.BatchNorm2d(10)
    self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
    self.bn2 = nn.BatchNorm2d(20)
    self.fc1 = nn.Linear(320, 50)
    self.fc2 = nn.Linear(50, 10)

  def forward(self, x):
    x = F.relu(F.max_pool2d(self.conv1(x), 2))
    x = self.bn1(x)
    x = F.relu(F.max_pool2d(self.conv2(x), 2))
    x = self.bn2(x)
    x = torch.flatten(x, 1)
    x = F.relu(self.fc1(x))
    x = self.fc2(x)
    return F.log_softmax(x, dim=1)

# Hyperparameters
batch_size = 128
momentum = 0.5
lr = 0.01


def _mp_fn(index):
  # Each spawned process drives its own XLA device
  device = xm.xla_device()

  # Synthetic MNIST-shaped batches, divided evenly across replicas
  train_loader = xu.SampleGenerator(
      data=(torch.zeros(batch_size, 1, 28, 28),
            torch.zeros(batch_size, dtype=torch.int64)),
      sample_count=60000 // batch_size // xr.world_size())

  # MpDeviceLoader moves each batch to the device and inserts mark_step calls
  mp_device_loader = pl.MpDeviceLoader(train_loader, device)
  print(device)

  model = MNIST().train().to(device)
  loss_fn = nn.NLLLoss()
  optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum)


  for data, target in mp_device_loader:
    optimizer.zero_grad()
    output = model(data)
    loss = loss_fn(output, target)
    loss.backward()
    # All-reduces gradients across replicas, then runs optimizer.step()
    xm.optimizer_step(optimizer)


if __name__ == '__main__':
  # Spawns one process per local XLA device and runs _mp_fn in each
  torch_xla.launch(_mp_fn, args=())
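
A hedged note for anyone on an older torch_xla release that predates torch_xla.launch: the long-standing equivalent entry point is xmp.spawn, along these lines (availability depends on the installed version):

import torch_xla.distributed.xla_multiprocessing as xmp

if __name__ == '__main__':
  # Same effect as torch_xla.launch on older torch_xla versions
  xmp.spawn(_mp_fn, args=())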
