Skip to content

Commit

Permalink
MNIST model with <25K params and >95% accuracy - assignment 5
Browse files Browse the repository at this point in the history
  • Loading branch information
Pankaj-sk committed Dec 14, 2024
1 parent e7562ac commit cd40fa1
Show file tree
Hide file tree
Showing 5 changed files with 53 additions and 34 deletions.
18 changes: 10 additions & 8 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
name: ML Model CI/CD
name: MNIST Model Tests

on: [push]

Expand All @@ -20,18 +20,20 @@ jobs:
pip install --index-url https://download.pytorch.org/whl/cpu torch torchvision --no-cache-dir
pip install pytest
- name: Train model
- name: Check Model Parameters
run: |
cd mnist_project/src
python train.py
cd mnist_project
python -m pytest src/tests/test_model.py::test_model_parameters -v
- name: Run tests
- name: Train and Test Model
run: |
cd mnist_project
python -m pytest src/tests/
cd mnist_project/src
python train.py
cd ..
python -m pytest src/tests/test_training.py::test_model_accuracy -v
- name: Upload model artifact
uses: actions/upload-artifact@v4
with:
name: trained-model
path: mnist_project/src/model_mnist_*.pth
path: mnist_project/src/model_mnist_latest.pth
34 changes: 23 additions & 11 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,14 +1,26 @@
# ERA-V3 MNIST Project
# MNIST Classification Project

This project implements a CNN model for MNIST digit classification with CI/CD pipeline.
A PyTorch implementation of MNIST digit classification that achieves >95% accuracy in one epoch with fewer than 25,000 parameters.

## Project Structure
- `mnist_project/`: Contains the main project code
- `src/`: Source code directory
- `model.py`: CNN model architecture
- `train.py`: Training script
- `utils.py`: Utility functions
- `tests/`: Test files
## Model Architecture
- Input Layer: 28x28x1
- Conv1: 8 filters with BatchNorm and ReLU
- Conv2: 16 filters with BatchNorm and ReLU
- Conv3: 20 filters with BatchNorm and ReLU
- MaxPooling (2x2, stride 2) after Conv1 and Conv3
- Fully Connected Layer: 10 outputs
- Total Parameters: <25,000

## Local Setup
1. Clone the repository:
## Key Features
- Achieves >95% accuracy in 1 epoch
- Lightweight architecture (<25K parameters)
- Uses BatchNormalization for faster convergence
- Applies dropout (p=0.1) before the fully connected layer for regularization

## GitHub Actions Tests
The CI/CD pipeline automatically verifies:
1. The model has fewer than 25,000 parameters
2. The trained model reaches at least 95% accuracy in one epoch

## Setup and Training
1. Install dependencies:
27 changes: 16 additions & 11 deletions mnist_project/src/model.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,25 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

class MNISTModel(nn.Module):
def __init__(self):
super(MNISTModel, self).__init__()
self.conv1 = nn.Conv2d(1, 8, kernel_size=3, padding=1)
self.conv2 = nn.Conv2d(8, 16, kernel_size=3, padding=1)
self.fc1 = nn.Linear(16 * 7 * 7, 64)
self.fc2 = nn.Linear(64, 10)
self.conv1 = nn.Conv2d(1, 8, kernel_size=3, padding=1) # 28x28x8
self.bn1 = nn.BatchNorm2d(8)
self.conv2 = nn.Conv2d(8, 16, kernel_size=3, padding=1) # 14x14x16
self.bn2 = nn.BatchNorm2d(16)
self.conv3 = nn.Conv2d(16, 20, kernel_size=3, padding=1) # 14x14x20
self.bn3 = nn.BatchNorm2d(20)
self.pool = nn.MaxPool2d(2, 2)
self.relu = nn.ReLU()
self.fc1 = nn.Linear(20 * 7 * 7, 10)
self.dropout = nn.Dropout(0.1)

def forward(self, x):
x = self.pool(self.relu(self.conv1(x))) # 14x14
x = self.pool(self.relu(self.conv2(x))) # 7x7
x = x.view(-1, 16 * 7 * 7)
x = self.relu(self.fc1(x))
x = self.fc2(x)
return x
x = self.pool(F.relu(self.bn1(self.conv1(x)))) # 14x14x8
x = F.relu(self.bn2(self.conv2(x))) # 14x14x16
x = self.pool(F.relu(self.bn3(self.conv3(x)))) # 7x7x20
x = x.view(-1, 20 * 7 * 7)
x = self.dropout(x)
x = self.fc1(x)
return F.log_softmax(x, dim=1)
3 changes: 2 additions & 1 deletion mnist_project/src/tests/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,5 @@ def test_model_architecture():
def test_model_parameters():
model = MNISTModel()
total_params = sum(p.numel() for p in model.parameters())
assert total_params < 100000, f"Model has {total_params} parameters, should be < 100000"
print(f"Total parameters: {total_params}")
assert total_params < 25000, f"Model has {total_params} parameters, should be < 25000"
5 changes: 2 additions & 3 deletions mnist_project/src/tests/test_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,15 @@
from src.utils import evaluate_model

def test_model_accuracy():
# Get the src directory path
current_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
model_path = os.path.join(current_dir, 'model_mnist_latest.pth')

print(f"Looking for model at: {model_path}") # Debug print
print(f"Looking for model at: {model_path}")

if not os.path.exists(model_path):
raise FileNotFoundError(f"No model file found at {model_path}")

model = MNISTModel()
model.load_state_dict(torch.load(model_path, weights_only=True))
accuracy = evaluate_model(model)
assert accuracy >= 0.8, f"Model accuracy {accuracy} is below 0.8"
assert accuracy >= 0.95, f"Model accuracy {accuracy:.2f} is below 0.95"

0 comments on commit cd40fa1

Please sign in to comment.