Skip to content

Commit

Permalink
Ensure model is saved and test looks in correct path
Browse files Browse the repository at this point in the history
  • Loading branch information
Pankaj-sk committed Dec 6, 2024
1 parent d6f44e4 commit 82fd26d
Showing 1 changed file with 6 additions and 5 deletions.
11 changes: 6 additions & 5 deletions mnist_project/src/tests/test_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,12 @@
def test_model_accuracy():
model = MNISTModel()
# Find the most recent model file
model_files = glob.glob('model_mnist_*.pth')
model_files = glob.glob('model_mnist_latest.pth')
if not model_files:
raise FileNotFoundError("No model file found")
latest_model = max(model_files) # Gets the most recent file

model.load_state_dict(torch.load(latest_model))
accuracy = evaluate_model(model)
assert accuracy >= 0.8, f"Model accuracy {accuracy} is below 0.8"
# Load the model
model.load_state_dict(torch.load(model_files[0]))

# Test accuracy
# Add your accuracy testing code here

0 comments on commit 82fd26d

Please sign in to comment.