Skip to content

Commit

Permalink
Update batch_driver.py (#2345)
Browse files Browse the repository at this point in the history
* Update batch_driver.py

* black

* cli copy

* fix: comments

* fix: comments

* format

---------

Co-authored-by: santiagxf <[email protected]>
  • Loading branch information
santiagxf and santiagxf authored Jun 29, 2023
1 parent 7e360ef commit a8c4255
Show file tree
Hide file tree
Showing 4 changed files with 86 additions and 44 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import glob
import mlflow
import pandas as pd
import logging


def init():
Expand All @@ -19,18 +20,23 @@ def init():

# Load the model, it's input types and output names
model = mlflow.pyfunc.load(model_path)
if model.metadata.signature.inputs:
model_input_types = dict(
zip(
model.metadata.signature.inputs.input_names(),
model.metadata.signature.inputs.pandas_types(),
if model.metadata and model.metadata.signature:
if model.metadata.signature.inputs:
model_input_types = dict(
zip(
model.metadata.signature.inputs.input_names(),
model.metadata.signature.inputs.pandas_types(),
)
)
if model.metadata.signature.outputs:
if model.metadata.signature.outputs.has_input_names():
model_output_names = model.metadata.signature.outputs.input_names()
elif len(model.metadata.signature.outputs.input_names()) == 1:
model_output_names = ["prediction"]
else:
logging.warning(
"Model doesn't contain a signature. Input data types won't be enforced."
)
if model.metadata.signature.outputs:
if model.metadata.signature.outputs.has_input_names():
model_output_names = model.metadata.signature.outputs.input_names()
elif len(model.metadata.signature.outputs.input_names()) == 1:
model_output_names = ["prediction"]


def run(mini_batch):
Expand All @@ -41,10 +47,12 @@ def run(mini_batch):
lambda fp: pd.read_csv(fp).assign(filename=os.path.basename(fp)), mini_batch
)
)

if model_input_types:
data = data.astype(model_input_types)

pred = model.predict(data)
# Predict over the input data, minus the column filename which is not part of the model.
pred = model.predict(data.drop("filename", axis=1))

if pred is not pd.DataFrame:
if not model_output_names:
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
# Copyright (c) Microsoft. All rights reserved.
# Licensed under the MIT license.

import os
import glob
import mlflow
import pandas as pd
import logging
from pathlib import Path


Expand All @@ -19,23 +23,32 @@ def init():

# Load the model, it's input types and output names
model = mlflow.pyfunc.load(model_path)
if model.metadata.signature.inputs:
model_input_types = dict(
zip(
model.metadata.signature.inputs.input_names(),
model.metadata.signature.inputs.pandas_types(),
if model.metadata and model.metadata.signature:
if model.metadata.signature.inputs:
model_input_types = dict(
zip(
model.metadata.signature.inputs.input_names(),
model.metadata.signature.inputs.pandas_types(),
)
)
if model.metadata.signature.outputs:
if model.metadata.signature.outputs.has_input_names():
model_output_names = model.metadata.signature.outputs.input_names()
elif len(model.metadata.signature.outputs.input_names()) == 1:
model_output_names = ["prediction"]
else:
logging.warning(
"Model doesn't contain a signature. Input data types won't be enforced."
)
if model.metadata.signature.outputs:
if model.metadata.signature.outputs.has_input_names():
model_output_names = model.metadata.signature.outputs.input_names()
elif len(model.metadata.signature.outputs.input_names()) == 1:
model_output_names = ["prediction"]


def run(mini_batch):
for file_path in mini_batch:
data = pd.read_csv(file_path).astype(model_input_types)
data = pd.read_csv(file_path)

if model_input_types:
data = data.astype(model_input_types)

pred = model.predict(data)

if pred is not pd.DataFrame:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import glob
import mlflow
import pandas as pd
import logging


def init():
Expand All @@ -19,18 +20,23 @@ def init():

# Load the model, it's input types and output names
model = mlflow.pyfunc.load(model_path)
if model.metadata.signature.inputs:
model_input_types = dict(
zip(
model.metadata.signature.inputs.input_names(),
model.metadata.signature.inputs.pandas_types(),
if model.metadata and model.metadata.signature:
if model.metadata.signature.inputs:
model_input_types = dict(
zip(
model.metadata.signature.inputs.input_names(),
model.metadata.signature.inputs.pandas_types(),
)
)
if model.metadata.signature.outputs:
if model.metadata.signature.outputs.has_input_names():
model_output_names = model.metadata.signature.outputs.input_names()
elif len(model.metadata.signature.outputs.input_names()) == 1:
model_output_names = ["prediction"]
else:
logging.warning(
"Model doesn't contain a signature. Input data types won't be enforced."
)
if model.metadata.signature.outputs:
if model.metadata.signature.outputs.has_input_names():
model_output_names = model.metadata.signature.outputs.input_names()
elif len(model.metadata.signature.outputs.input_names()) == 1:
model_output_names = ["prediction"]


def run(mini_batch):
Expand All @@ -41,10 +47,12 @@ def run(mini_batch):
lambda fp: pd.read_csv(fp).assign(filename=os.path.basename(fp)), mini_batch
)
)

if model_input_types:
data = data.astype(model_input_types)

pred = model.predict(data)
# Predict over the input data, minus the column filename which is not part of the model.
pred = model.predict(data.drop("filename", axis=1))

if pred is not pd.DataFrame:
if not model_output_names:
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
# Copyright (c) Microsoft. All rights reserved.
# Licensed under the MIT license.

import os
import glob
import mlflow
import pandas as pd
import logging
from pathlib import Path


Expand All @@ -19,23 +23,32 @@ def init():

# Load the model, it's input types and output names
model = mlflow.pyfunc.load(model_path)
if model.metadata.signature.inputs:
model_input_types = dict(
zip(
model.metadata.signature.inputs.input_names(),
model.metadata.signature.inputs.pandas_types(),
if model.metadata and model.metadata.signature:
if model.metadata.signature.inputs:
model_input_types = dict(
zip(
model.metadata.signature.inputs.input_names(),
model.metadata.signature.inputs.pandas_types(),
)
)
if model.metadata.signature.outputs:
if model.metadata.signature.outputs.has_input_names():
model_output_names = model.metadata.signature.outputs.input_names()
elif len(model.metadata.signature.outputs.input_names()) == 1:
model_output_names = ["prediction"]
else:
logging.warning(
"Model doesn't contain a signature. Input data types won't be enforced."
)
if model.metadata.signature.outputs:
if model.metadata.signature.outputs.has_input_names():
model_output_names = model.metadata.signature.outputs.input_names()
elif len(model.metadata.signature.outputs.input_names()) == 1:
model_output_names = ["prediction"]


def run(mini_batch):
for file_path in mini_batch:
data = pd.read_csv(file_path).astype(model_input_types)
data = pd.read_csv(file_path)

if model_input_types:
data = data.astype(model_input_types)

pred = model.predict(data)

if pred is not pd.DataFrame:
Expand Down

0 comments on commit a8c4255

Please sign in to comment.