Skip to content

Commit

Permalink
Second commit after file structure fix
Browse files Browse the repository at this point in the history
  • Loading branch information
Pankaj-sk committed Dec 5, 2024
1 parent ed83974 commit ba7a3d4
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@ jobs:
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install torch torchvision pytest
pip install --index-url https://download.pytorch.org/whl/cpu torch torchvision --no-cache-dir
pip install pytest
- name: Train model
run: |
Expand All @@ -30,7 +31,7 @@ jobs:
python -m pytest src/tests/
- name: Upload model artifact
uses: actions/upload-artifact@v2
uses: actions/upload-artifact@v4
with:
name: trained-model
path: mnist_project/src/model_mnist_*.pth
14 changes: 14 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
# ERA-V3 MNIST Project

This project implements a CNN model for MNIST digit classification with CI/CD pipeline.

## Project Structure
- `mnist_project/`: Contains the main project code
- `src/`: Source code directory
- `model.py`: CNN model architecture
- `train.py`: Training script
- `utils.py`: Utility functions
- `tests/`: Test files

## Local Setup
1. Clone the repository:
9 changes: 8 additions & 1 deletion mnist_project/src/tests/test_training.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,16 @@
import torch
import glob
from src.model import MNISTModel
from src.utils import evaluate_model

def test_model_accuracy():
model = MNISTModel()
model.load_state_dict(torch.load('model_mnist_latest.pth'))
# Find the most recent model file
model_files = glob.glob('model_mnist_*.pth')
if not model_files:
raise FileNotFoundError("No model file found")
latest_model = max(model_files) # Gets the most recent file

model.load_state_dict(torch.load(latest_model))
accuracy = evaluate_model(model)
assert accuracy >= 0.8, f"Model accuracy {accuracy} is below 0.8"

0 comments on commit ba7a3d4

Please sign in to comment.