Skip to content

Commit

Permalink
Update Post-Evaluation Subtask2b Model.py
Browse files Browse the repository at this point in the history
Updated image path and model saving
  • Loading branch information
vemchance authored Mar 12, 2024
1 parent 211e946 commit 99d3bba
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions Subtask2b/Post-Evaluation Subtask2b Model.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@
print(num_classes, flush=True)
#### read a csv instead ####

path = "/home/h3/647518/SemEval/improved_model_2b/subtask2b_images"
path = "image_path"
images = [os.path.join(dirpath,f) for (dirpath, dirnames, filenames) in os.walk(path) for f in filenames]
images_df = pd.DataFrame(images, columns=['filepath'])
images_df['image'] = images_df['filepath'].str.split('/').str[-1]
Expand Down Expand Up @@ -239,10 +239,10 @@ def val_epoch(ensemble_model, val_dataloader):
val_accuracy, f1_micro_val, f1_macro_val, loss = val_epoch(ensemble_model, val_dataloader)
print(f"\n Epoch:{epochs + 1} / {num_epochs}, Val accuracy:{val_accuracy:.5f}, Val F1 Micro: {f1_micro_val:.5f}, Val F1 Macro:{f1_macro_val:.5f}, Validation Loss: {loss:.5f}", flush=True)

if f1_macro_val > best_loss:
if f1_macro_val < best_loss:
print('Saving Model :)')
torch.save(ensemble_model, f'2bensemble_model_improved_{epochs}e.pth')
torch.save(ensemble_model.state_dict(),f'2bensemble_model_improved_{epochs}e_weights.pth')
best_loss = f1_macro_val
else:
best_loss = best_loss
best_loss = best_loss

0 comments on commit 99d3bba

Please sign in to comment.