From 7b7b72157c165d21d1344e848e761c8d37685c86 Mon Sep 17 00:00:00 2001 From: Yasuhiro Matsumoto Date: Wed, 22 May 2024 00:46:06 +0900 Subject: [PATCH] type variable df --- .../machinelearning-python-bentoml/iris_classifier.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/code-samples/community/serving/machinelearning-python-bentoml/iris_classifier.py b/code-samples/community/serving/machinelearning-python-bentoml/iris_classifier.py index 3fa0c1f734..324208bf47 100644 --- a/code-samples/community/serving/machinelearning-python-bentoml/iris_classifier.py +++ b/code-samples/community/serving/machinelearning-python-bentoml/iris_classifier.py @@ -1,4 +1,8 @@ +import numpy as np import bentoml +from pydantic import Field +from bentoml.validators import Shape +from typing_extensions import Annotated import joblib @@ -10,5 +14,10 @@ def __init__(self): self.model = joblib.load(self.iris_model.path_of("model.pkl")) @bentoml.api - def predict(self, df): + def predict( + self, + df: Annotated[np.ndarray, Shape((-1, 4))] = Field( + default=[[5.2, 2.3, 5.0, 0.7]] + ), + ) -> np.ndarray: return self.artifacts.model.predict(df)