generated from XpressAI/xai-component-library-template
-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathsimple_ml_predict.py
65 lines (49 loc) · 2.13 KB
/
simple_ml_predict.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
# This is an example component.
# To run this, don't forget to install the required libraries:
# pip install torch torchvision
import ast
import functools
import io
import os
import urllib.request

import torch
import torchvision.transforms as T
from PIL import Image
from torchvision.models import mobilenet_v2

from xai_components.base import InArg, OutArg, InCompArg, BaseComponent, Component, xai_component
# Remote source of the ImageNet class-index -> label mapping. This file is a
# Python dict literal (``{0: '...', 1: '...', ...}``), not one label per line.
_IMAGENET_LABELS_URL = "https://gist.githubusercontent.com/yrevar/942d3a0ac09ec9e5eb3a/raw/238f720ff059c1f82f368259d1ca4ffa5dd8f9f5/imagenet1000_clsidx_to_labels.txt"
# Local cache of the label file, resolved relative to the current working directory.
_IMAGENET_LABELS_PATH = "imagenet_classes.txt"


@functools.lru_cache(maxsize=1)
def _load_mobilenet_v2():
    """Build the ImageNet-pretrained MobileNetV2 in eval mode, cached per process."""
    try:
        from torchvision.models import MobileNet_V2_Weights
    except ImportError:
        # torchvision < 0.13 only understands the legacy ``pretrained`` flag.
        model = mobilenet_v2(pretrained=True)
    else:
        # IMAGENET1K_V1 is the checkpoint the deprecated ``pretrained=True`` selects;
        # it matches the Resize(256) / CenterCrop(224) preprocessing used below.
        model = mobilenet_v2(weights=MobileNet_V2_Weights.IMAGENET1K_V1)
    model.eval()
    return model


def _parse_labels(text):
    """Return a ``{class_index: label}`` dict parsed from label-file contents.

    Accepts either a Python dict literal mapping int -> str (the format served
    by ``_IMAGENET_LABELS_URL``) or plain text with one label per line, where
    the line number is the class index.
    """
    try:
        # literal_eval only evaluates literals, so it is safe on downloaded text.
        parsed = ast.literal_eval(text)
    except (ValueError, SyntaxError, TypeError):
        parsed = None
    if isinstance(parsed, dict):
        return {int(key): str(value) for key, value in parsed.items()}
    return {index: line.strip() for index, line in enumerate(text.splitlines())}


@functools.lru_cache(maxsize=1)
def _load_imagenet_labels(file_path=_IMAGENET_LABELS_PATH, url=_IMAGENET_LABELS_URL):
    """Download (once) and parse the ImageNet label file into an index -> label dict."""
    if not os.path.exists(file_path):
        # Download under a temporary name and rename afterwards so an interrupted
        # download never leaves a truncated file that later runs would trust.
        tmp_path = file_path + ".part"
        urllib.request.urlretrieve(url, tmp_path)
        os.replace(tmp_path, file_path)
    with open(file_path, "r", encoding="utf-8") as f:
        return _parse_labels(f.read())


@xai_component
class MobileNetV2ProcessImageData(Component):
    """Processes image data using the MobileNetV2 model and returns the predicted class label.

    ##### inPorts:
    - image_data: Input image data in bytes (any format Pillow can open).

    ##### outPorts:
    - prediction: Human-readable ImageNet class label for the input image.
    """
    image_data: InArg[bytes]
    prediction: OutArg[str]

    def execute(self, ctx) -> None:
        # Decode the bytes. Convert to RGB so grayscale, RGBA and palette images
        # yield the 3 channels that Normalize and the model expect.
        with Image.open(io.BytesIO(self.image_data.value)) as img:
            image = img.convert("RGB")

        # Standard ImageNet preprocessing for the IMAGENET1K_V1 MobileNetV2 weights.
        preprocess = T.Compose([
            T.Resize(256),
            T.CenterCrop(224),
            T.ToTensor(),
            T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])
        # unsqueeze(0) adds the batch dimension: (3, 224, 224) -> (1, 3, 224, 224).
        input_batch = preprocess(image).unsqueeze(0)

        model = _load_mobilenet_v2()
        with torch.no_grad():
            output = model(input_batch)
        prediction_idx = int(output.argmax(dim=1).item())

        # Map the winning class index to its human-readable label.
        labels = _load_imagenet_labels()
        self.prediction.value = labels[prediction_idx]