API guide code snippets don't work #8497
According to @richardsliu, this also does not work:

I used
Thanks for flagging this issue. This is a poor, legacy naming decision that certainly needs to be corrected. Are you able to run this test: https://github.com/pytorch/xla/blob/master/test/test_train_mp_mnist.py? cc @tengyifei @mikegre-google to help address this documentation issue, ideally for 2.6.
OK, thanks. I think I got this to run, with the following code:

import torch_xla.core.xla_model as xm
from torch_xla import runtime as xr
import torch
import torch_xla.utils.utils as xu
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F


class MNIST(nn.Module):

    def __init__(self):
        super(MNIST, self).__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.bn1 = nn.BatchNorm2d(10)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.bn2 = nn.BatchNorm2d(20)
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = self.bn1(x)
        x = F.relu(F.max_pool2d(self.conv2(x), 2))
        x = self.bn2(x)
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)


batch_size = 128
momentum = 0.5
lr = 0.01

device = xm.xla_device()
model = MNIST().train().to(device)
loss_fn = nn.NLLLoss()
optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum)

# Synthetic data with the same shapes as MNIST batches.
train_loader = xu.SampleGenerator(
    data=(torch.zeros(batch_size, 1, 28, 28),
          torch.zeros(batch_size, dtype=torch.int64)),
    sample_count=60000 // batch_size // xr.world_size())

for data, target in train_loader:
    optimizer.zero_grad()
    data = data.to(device)
    target = target.to(device)
    output = model(data)
    loss = loss_fn(output, target)
    loss.backward()
    optimizer.step()
    # Materialize the lazily traced graph for this step.
    xm.mark_step()
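Not part of the guide, just a hedged sanity check I find useful: fetching a scalar back to the host after the loop forces any still-pending XLA graph to execute, so you can confirm the loop really ran on the device.

# Hedged sanity check, not from the API guide: loss.item() pulls the
# value to the host and forces execution of any pending XLA graph.
print('device:', device)
print('final loss:', loss.item())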
And the multi-device version:

import torch_xla.core.xla_model as xm
from torch_xla import runtime as xr
import torch
import torch_xla.utils.utils as xu
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch_xla
import torch_xla.distributed.parallel_loader as pl


class MNIST(nn.Module):

    def __init__(self):
        super(MNIST, self).__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.bn1 = nn.BatchNorm2d(10)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.bn2 = nn.BatchNorm2d(20)
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = self.bn1(x)
        x = F.relu(F.max_pool2d(self.conv2(x), 2))
        x = self.bn2(x)
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)


batch_size = 128
momentum = 0.5
lr = 0.01


def _mp_fn(index):
    device = xm.xla_device()
    train_loader = xu.SampleGenerator(
        data=(torch.zeros(batch_size, 1, 28, 28),
              torch.zeros(batch_size, dtype=torch.int64)),
        sample_count=60000 // batch_size // xr.world_size())
    # MpDeviceLoader moves each batch to the XLA device and inserts a
    # mark_step after every batch, so no explicit .to(device) is needed.
    mp_device_loader = pl.MpDeviceLoader(train_loader, device)
    print(device)
    model = MNIST().train().to(device)
    loss_fn = nn.NLLLoss()
    optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum)
    for data, target in mp_device_loader:
        optimizer.zero_grad()
        output = model(data)
        loss = loss_fn(output, target)
        loss.backward()
        # Averages gradients across replicas before the optimizer step.
        xm.optimizer_step(optimizer)


if __name__ == '__main__':
    torch_xla.launch(_mp_fn, args=())
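For what it's worth, my (hedged) understanding of xm.optimizer_step(optimizer) in the loop above is that it roughly amounts to all-reducing the gradients across replicas and then stepping locally, along the lines of the sketch below; the per-batch mark_step is handled by MpDeviceLoader.

# Rough, hedged equivalent of xm.optimizer_step(optimizer); not taken
# from the guide. reduce_gradients all-reduces (averages) the gradients
# across participating replicas before the local optimizer step.
xm.reduce_gradients(optimizer)
optimizer.step()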
📚 Documentation
Trying to follow the example here: https://github.com/pytorch/xla/blob/master/API_GUIDE.md#running-on-a-single-xla-device
The Python code snippet doesn't work, as MNIST(), nn, and optim are all undefined.
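For anyone else hitting this before the guide is fixed, the snippet appears to assume at least the imports below, plus an MNIST nn.Module definition like the one posted in the comments above; this is my guess at the guide's intent, not the actual documented fix.

# Hedged guess at what the guide's snippet leaves undefined.
import torch.nn as nn        # provides nn.Module, nn.NLLLoss, ...
import torch.optim as optim  # provides optim.SGD
# An MNIST nn.Module subclass must also be defined; see the full
# definition in the comments above.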