From a43fe63dbe2ab9c091f9a1b9417c8b6fd6c8bf4b Mon Sep 17 00:00:00 2001 From: Patrick Smyth Date: Wed, 26 Feb 2025 11:08:23 -0500 Subject: [PATCH] Whitelisted model classes when loading pickle file Ran linter Signed-off-by: Patrick Smyth --- image_classification.py | 31 +++++++++++++++++++++++++------ 1 file changed, 25 insertions(+), 6 deletions(-) diff --git a/image_classification.py b/image_classification.py index 08ac9ef..98a481c 100644 --- a/image_classification.py +++ b/image_classification.py @@ -107,8 +107,10 @@ def train_model(model, criterion, optimizer, scheduler, num_epochs=25): epoch_loss = running_loss / dataset_sizes[phase] epoch_acc = running_corrects.double() / dataset_sizes[phase] - phase_emoji = '🐳' if phase == 'train' else '🐧' - print(f"{phase_emoji} {phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}") + phase_emoji = "🐳" if phase == "train" else "🐧" + print( + f"{phase_emoji} {phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}" + ) # deep copy the model if phase == "val" and epoch_acc > best_acc: @@ -158,7 +160,7 @@ def infer_class(model, img_path): model.eval() - img = Image.open(img_path).convert('RGB') + img = Image.open(img_path).convert("RGB") img = data_transforms["val"](img) img = img.unsqueeze(0) img = img.to(device) @@ -173,12 +175,29 @@ def infer_class(model, img_path): if __name__ == "__main__": model_file = script_directory / Path("octopus_whale_penguin_model.pt") if model_file.exists(): - model = torch.load(model_file) + with torch.serialization.safe_globals( + [ + models.resnet.ResNet, + nn.modules.conv.Conv2d, + nn.modules.linear.Linear, + nn.modules.pooling.AdaptiveAvgPool2d, + models.resnet.BasicBlock, + nn.modules.container.Sequential, + nn.modules.pooling.MaxPool2d, + nn.modules.activation.ReLU, + nn.modules.batchnorm.BatchNorm2d, + ] + ): + model = torch.load(model_file) else: model = run_training() - print('Attempting to save the model as octopus_whale_penguin_model.pt in the same directory as this script...') + print( + "Attempting to save the model as octopus_whale_penguin_model.pt in the same directory as this script..." + ) torch.save(model, script_directory / Path("octopus_whale_penguin_model.pt")) - print('\nModel saved. To run inference using the stored model, rerun this script with a path to an image as an argument.') + print( + "\nModel saved. To run inference using the stored model, rerun this script with a path to an image as an argument." + ) if len(argv) >= 2: input_file = Path(argv[1])