Skip to content

Commit

Permalink
Test mnist training (#61)
Browse files Browse the repository at this point in the history
Added an MNIST training test case.
This requires torchvision, so it was added to the requirements.

Furthermore, the CPU pass was changed to `torch-to-iree`. That change depends
on an IREE pass change that must first land and be included in an IREE
release; until then, all tests involving `empty_strided` will fail. As soon as
IREE's new release is available, the xfails can be removed.
  • Loading branch information
brucekimrokcmu authored Sep 20, 2023
1 parent ce2f5b0 commit d51df8d
Show file tree
Hide file tree
Showing 4 changed files with 165 additions and 4 deletions.
5 changes: 3 additions & 2 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
-f https://openxla.github.io/iree/pip-release-links.html

-r pytorch-cpu-requirements.txt
-r torchvision-requirements.txt

iree-compiler==20230914.645
iree-runtime==20230914.645
iree-compiler==20230920.651
iree-runtime==20230920.651
2 changes: 0 additions & 2 deletions tests/dynamo/importer_basic_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,8 +255,6 @@ def forward(self, x):

@unittest.expectedFailure
def testImportAtenFull(self):
"""Expected to fail until torch-mlir op: torch.aten.empty_strided is implemented"""

def foo(x):
return torch.full(x.size(), fill_value=float("-inf"))

Expand Down
159 changes: 159 additions & 0 deletions tests/dynamo/mninst_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,159 @@
# Copyright 2023 Nod Labs, Inc
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

import logging

import math
import unittest
from dataclasses import dataclass
from typing import Any, Optional, Tuple

import torch
import torch.nn.functional as F
from torch import nn
import torch.optim as optim

import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch.utils.data import DataLoader


# MNIST Data Loader
class MNISTDataLoader:
    """Builds ``DataLoader`` objects over the MNIST train and test splits.

    Both splits are created in the constructor with ``download=True`` under
    the relative directory ``../data``.
    """

    def __init__(self, batch_size, shuffle=True):
        """Store the loader settings and create both MNIST datasets.

        Args:
            batch_size: Batch size used by both loaders.
            shuffle: Whether the *training* loader shuffles. The test loader
                never shuffles.
        """
        self.batch_size = batch_size
        self.shuffle = shuffle

        # Per-sample preprocessing: tensor conversion, then Normalize with
        # (0.5,) for both of its arguments.
        preprocessing = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Normalize((0.5,), (0.5,)),
            ]
        )

        # Everything except the `train` flag is shared between the two splits.
        common_kwargs = dict(root='../data', download=True, transform=preprocessing)
        self.mnist_trainset = datasets.MNIST(train=True, **common_kwargs)
        self.mnist_testset = datasets.MNIST(train=False, **common_kwargs)

    def get_train_loader(self):
        """Return a ``DataLoader`` over the training split."""
        training_loader = DataLoader(
            self.mnist_trainset,
            batch_size=self.batch_size,
            shuffle=self.shuffle,
        )
        return training_loader

    def get_test_loader(self):
        """Return a ``DataLoader`` over the test split in dataset order."""
        evaluation_loader = DataLoader(
            self.mnist_testset,
            batch_size=self.batch_size,
            shuffle=False,
        )
        return evaluation_loader


# Simple CNN Model
class CNN(nn.Module):
    """Minimal classifier: one 5x5 convolution, ReLU, 2x2 max-pool, one linear layer.

    The linear layer takes ``32 * 12 * 12`` input features and produces 10
    outputs.
    """

    def __init__(self):
        super().__init__()
        # Registration order is kept as conv1, relu, maxpool, fc1 so that
        # parameter ordering and initialization order stay the same.
        self.conv1 = nn.Conv2d(1, 32, kernel_size=5)
        self.relu = nn.ReLU()
        self.maxpool = nn.MaxPool2d(kernel_size=2)
        self.fc1 = nn.Linear(32 * 12 * 12, 10)

    def forward(self, x):
        """Return the output of ``fc1`` for the input batch ``x``."""
        pooled = self.maxpool(self.relu(self.conv1(x)))
        # Collapse every non-batch dimension before the linear layer.
        flattened = pooled.view(pooled.size(0), -1)
        return self.fc1(flattened)

# Training
def train(model, images, labels, optimizer, criterion):
    """Run one optimization step on a single batch and report its metrics.

    Args:
        model: Module to train; it is put into training mode.
        images: Input batch passed to ``model``.
        labels: Targets passed to ``criterion`` and compared against the
            argmax over dim 1 of the model output.
        optimizer: Optimizer whose gradients are zeroed before the forward
            pass and which is stepped after the backward pass.
        criterion: Callable ``criterion(outputs, labels)`` returning the loss.

    Returns:
        A ``(total_loss, num_correct)`` tuple: the batch loss as a float
        (counted once) and the number of correct argmax predictions as an int.
    """
    model.train()

    optimizer.zero_grad()
    # NOTE: inputs are used as given; no device transfer is performed here.
    outputs = model(images)
    loss = criterion(outputs, labels)

    # Metrics are read before the parameter update so they describe the
    # forward pass that produced `loss`.
    num_correct = int((torch.argmax(outputs, dim=1) == labels).sum())
    total_loss = float(loss.item())

    loss.backward()
    optimizer.step()

    return total_loss, num_correct

# TODO: Implement the inference function. The draft below is wrapped in a bare
# string literal, so it is never executed.
"""
def test(model, images, labels, criterion):
model.eval()
num_correct = 0.0
total_loss = 0.0
with torch.no_grad():
# images, labels = images.to(device), labels.to(device)
with torch.inference_mode():
outputs = model(images)
loss = criterion(outputs, labels)
num_correct += int((torch.argmax(outputs, dim=1) == labels).sum())
total_loss += float(loss.item())
# acc = 100 * num_correct / (config['batch_size'] * len(test_loader))
# total_loss = float(total_loss / len(test_loader))
# return acc, total_loss
"""

def main():
    """Train the CNN on MNIST with ``train`` compiled for the ``turbine_cpu`` backend.

    Makes a single pass over the training loader; ``config['num_epochs']`` is
    defined but not consumed here.
    """
    # Example hyperparameters
    config = {
        'batch_size': 64,
        'learning_rate': 0.001,
        # 'threshold' : 0.001,
        # 'factor' : 0.1,
        'num_epochs': 10,
    }

    # Data: only the training loader is used for now (no test loader yet).
    training_batches = MNISTDataLoader(config['batch_size']).get_train_loader()

    # Model, optimizer, loss
    net = CNN()
    adam = optim.Adam(net.parameters(), lr=config['learning_rate'])
    loss_fn = nn.CrossEntropyLoss()

    # Training: one compiled step per batch.
    compiled_step = torch.compile(train, backend="turbine_cpu")
    for images, labels in training_batches:
        compiled_step(net, images, labels, adam, loss_fn)


# TODO: Inference. The draft loop below is wrapped in a bare string literal, so
# it is never executed; it relies on a `test` function that is not defined yet.
"""
test_opt = torch.compile(test, backend="turbine_cpu", mode="reduce-overhead")
for i, (images, labels) in enumerate(test_loader):
test(model, images, labels, criterion)
"""



class ModelTests(unittest.TestCase):
    """Model-level tests that run ``main`` end to end."""

    @unittest.expectedFailure
    def testMNIST(self):
        """Run MNIST training; marked as an expected failure for now.

        TODO: Fix the error below:
            failed to legalize operation 'arith.sitofp' that was explicitly marked illegal
        """
        main()


# Script entry point: turn on DEBUG-level logging, then hand control to the
# unittest runner.
if __name__ == "__main__":
    logging.basicConfig(level=logging.DEBUG)
    unittest.main()
3 changes: 3 additions & 0 deletions torchvision-requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
-f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html
--pre
torchvision==0.16.0.dev20230901

0 comments on commit d51df8d

Please sign in to comment.