Skip to content

Commit

Permalink
Whitelist model classes when loading pickle file
Browse files Browse the repository at this point in the history
  • Loading branch information
smythp authored Feb 26, 2025
2 parents a5f70df + a43fe63 commit 92befc9
Showing 1 changed file with 25 additions and 6 deletions.
31 changes: 25 additions & 6 deletions image_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,8 +107,10 @@ def train_model(model, criterion, optimizer, scheduler, num_epochs=25):
epoch_loss = running_loss / dataset_sizes[phase]
epoch_acc = running_corrects.double() / dataset_sizes[phase]

phase_emoji = '🐳' if phase == 'train' else '🐧'
print(f"{phase_emoji} {phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}")
phase_emoji = "🐳" if phase == "train" else "🐧"
print(
f"{phase_emoji} {phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}"
)

# deep copy the model
if phase == "val" and epoch_acc > best_acc:
Expand Down Expand Up @@ -158,7 +160,7 @@ def infer_class(model, img_path):

model.eval()

img = Image.open(img_path).convert('RGB')
img = Image.open(img_path).convert("RGB")
img = data_transforms["val"](img)
img = img.unsqueeze(0)
img = img.to(device)
Expand All @@ -173,12 +175,29 @@ def infer_class(model, img_path):
if __name__ == "__main__":
model_file = script_directory / Path("octopus_whale_penguin_model.pt")
if model_file.exists():
model = torch.load(model_file)
with torch.serialization.safe_globals(
[
models.resnet.ResNet,
nn.modules.conv.Conv2d,
nn.modules.linear.Linear,
nn.modules.pooling.AdaptiveAvgPool2d,
models.resnet.BasicBlock,
nn.modules.container.Sequential,
nn.modules.pooling.MaxPool2d,
nn.modules.activation.ReLU,
nn.modules.batchnorm.BatchNorm2d,
]
):
model = torch.load(model_file)
else:
model = run_training()
print('Attempting to save the model as octopus_whale_penguin_model.pt in the same directory as this script...')
print(
"Attempting to save the model as octopus_whale_penguin_model.pt in the same directory as this script..."
)
torch.save(model, script_directory / Path("octopus_whale_penguin_model.pt"))
print('\nModel saved. To run inference using the stored model, rerun this script with a path to an image as an argument.')
print(
"\nModel saved. To run inference using the stored model, rerun this script with a path to an image as an argument."
)

if len(argv) >= 2:
input_file = Path(argv[1])
Expand Down

0 comments on commit 92befc9

Please sign in to comment.