diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 56a0a5e..71c99d5 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -1,4 +1,4 @@ -name: ML Model CI/CD +name: MNIST Model Tests on: [push] @@ -20,18 +20,20 @@ jobs: pip install --index-url https://download.pytorch.org/whl/cpu torch torchvision --no-cache-dir pip install pytest - - name: Train model + - name: Check Model Parameters run: | - cd mnist_project/src - python train.py + cd mnist_project + python -m pytest src/tests/test_model.py::test_model_parameters -v - - name: Run tests + - name: Train and Test Model run: | - cd mnist_project - python -m pytest src/tests/ + cd mnist_project/src + python train.py + cd .. + python -m pytest src/tests/test_training.py::test_model_accuracy -v - name: Upload model artifact uses: actions/upload-artifact@v4 with: name: trained-model - path: mnist_project/src/model_mnist_*.pth \ No newline at end of file + path: mnist_project/src/model_mnist_latest.pth \ No newline at end of file diff --git a/README.md b/README.md index 85fc5ad..e70743d 100644 --- a/README.md +++ b/README.md @@ -1,14 +1,26 @@ -# ERA-V3 MNIST Project +# MNIST Classification Project -This project implements a CNN model for MNIST digit classification with CI/CD pipeline. +A PyTorch implementation of MNIST digit classification that achieves >95% accuracy in one epoch with less than 25,000 parameters. -## Project Structure -- `mnist_project/`: Contains the main project code - - `src/`: Source code directory - - `model.py`: CNN model architecture - - `train.py`: Training script - - `utils.py`: Utility functions - - `tests/`: Test files +## Model Architecture +- Input Layer: 28x28x1 +- Conv1: 8 filters with BatchNorm and ReLU +- Conv2: 16 filters with BatchNorm and ReLU +- Conv3: 20 filters with BatchNorm and ReLU +- MaxPooling layers +- Fully Connected Layer: 10 outputs +- Total Parameters: <25,000 -## Local Setup -1. Clone the repository: \ No newline at end of file +## Key Features +- Achieves >95% accuracy in 1 epoch +- Lightweight architecture (<25K parameters) +- Uses BatchNormalization for faster convergence +- Implements dropout for regularization + +## GitHub Actions Tests +The CI/CD pipeline automatically verifies: +1. Model has less than 25,000 parameters +2. Achieves accuracy greater than 95% in one epoch + +## Setup and Training +1. Install dependencies: \ No newline at end of file diff --git a/mnist_project/src/model.py b/mnist_project/src/model.py index a0ec4f0..5230fe8 100644 --- a/mnist_project/src/model.py +++ b/mnist_project/src/model.py @@ -1,20 +1,25 @@ import torch import torch.nn as nn +import torch.nn.functional as F class MNISTModel(nn.Module): def __init__(self): super(MNISTModel, self).__init__() - self.conv1 = nn.Conv2d(1, 8, kernel_size=3, padding=1) - self.conv2 = nn.Conv2d(8, 16, kernel_size=3, padding=1) - self.fc1 = nn.Linear(16 * 7 * 7, 64) - self.fc2 = nn.Linear(64, 10) + self.conv1 = nn.Conv2d(1, 8, kernel_size=3, padding=1) # 28x28x8 + self.bn1 = nn.BatchNorm2d(8) + self.conv2 = nn.Conv2d(8, 16, kernel_size=3, padding=1) # 28x28x16 + self.bn2 = nn.BatchNorm2d(16) + self.conv3 = nn.Conv2d(16, 20, kernel_size=3, padding=1) # 14x14x20 + self.bn3 = nn.BatchNorm2d(20) self.pool = nn.MaxPool2d(2, 2) - self.relu = nn.ReLU() + self.fc1 = nn.Linear(20 * 7 * 7, 10) + self.dropout = nn.Dropout(0.1) def forward(self, x): - x = self.pool(self.relu(self.conv1(x))) # 14x14 - x = self.pool(self.relu(self.conv2(x))) # 7x7 - x = x.view(-1, 16 * 7 * 7) - x = self.relu(self.fc1(x)) - x = self.fc2(x) - return x \ No newline at end of file + x = self.pool(F.relu(self.bn1(self.conv1(x)))) # 14x14x8 + x = F.relu(self.bn2(self.conv2(x))) # 14x14x16 + x = self.pool(F.relu(self.bn3(self.conv3(x)))) # 7x7x20 + x = x.view(-1, 20 * 7 * 7) + x = self.dropout(x) + x = self.fc1(x) + return F.log_softmax(x, dim=1) \ No newline at end of file diff --git a/mnist_project/src/tests/test_model.py b/mnist_project/src/tests/test_model.py index cbe09dd..ee1c12d 100644 --- a/mnist_project/src/tests/test_model.py +++ b/mnist_project/src/tests/test_model.py @@ -12,4 +12,5 @@ def test_model_architecture(): def test_model_parameters(): model = MNISTModel() total_params = sum(p.numel() for p in model.parameters()) - assert total_params < 100000, f"Model has {total_params} parameters, should be < 100000" \ No newline at end of file + print(f"Total parameters: {total_params}") + assert total_params < 25000, f"Model has {total_params} parameters, should be < 25000" \ No newline at end of file diff --git a/mnist_project/src/tests/test_training.py b/mnist_project/src/tests/test_training.py index 508e391..90999d9 100644 --- a/mnist_project/src/tests/test_training.py +++ b/mnist_project/src/tests/test_training.py @@ -4,11 +4,10 @@ from src.utils import evaluate_model def test_model_accuracy(): - # Get the src directory path current_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) model_path = os.path.join(current_dir, 'model_mnist_latest.pth') - print(f"Looking for model at: {model_path}") # Debug print + print(f"Looking for model at: {model_path}") if not os.path.exists(model_path): raise FileNotFoundError(f"No model file found at {model_path}") @@ -16,4 +15,4 @@ def test_model_accuracy(): model = MNISTModel() model.load_state_dict(torch.load(model_path, weights_only=True)) accuracy = evaluate_model(model) - assert accuracy >= 0.8, f"Model accuracy {accuracy} is below 0.8" \ No newline at end of file + assert accuracy >= 0.95, f"Model accuracy {accuracy:.2f} is below 0.95" \ No newline at end of file