
Make ML Server serve multiple models (#460)
* Make ML Server serve multiple models

Will take the gordo-project and gordo-name
as parameters of the URL to load the model per request,
e.g. /gordo/v0/gordo-project/gordo-name/prediction would
load the model 'gordo-name' and process the request.
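
For illustration, requests against the new per-model routes might look like the
following sketch (host, port, project, and model names are hypothetical):

import requests

# Hypothetical server address, project, and model name
base = "http://localhost:5555/gordo/v0/my-project/my-model"

# Per-model metadata (this route also doubles as the model's healthcheck)
print(requests.get(f"{base}/metadata").json())

# Prediction on raw samples; "X" is a nested list of samples
resp = requests.post(f"{base}/prediction", json={"X": [[1, 2, 3], [4, 5, 6]]})
print(resp.json())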
milesgranger authored Sep 27, 2019
1 parent 28a7dc3 commit c1c50fe
Showing 11 changed files with 194 additions and 105 deletions.
45 changes: 4 additions & 41 deletions gordo_components/server/server.py
@@ -1,14 +1,10 @@
# -*- coding: utf-8 -*-
import os
import logging
import timeit
import typing
from functools import wraps

from flask import Flask, g
from sklearn.base import BaseEstimator

from gordo_components import serializer
from gordo_components.data_provider.base import GordoBaseDataProvider
from gordo_components.server import views

@@ -18,7 +14,7 @@
class Config:
    """Server config"""

    MODEL_LOCATION_ENV_VAR = "MODEL_LOCATION"
    MODEL_COLLECTION_DIR_ENV_VAR = "MODEL_COLLECTION_DIR"


def adapt_proxy_deployment(wsgi_app: typing.Callable) -> typing.Callable:
@@ -97,38 +93,6 @@ def wrapper(environ, start_response):
    return wrapper


def load_model_and_metadata(
    model_dir_env_var: str
) -> typing.Tuple[BaseEstimator, dict]:
    """
    Loads a model and metadata from the path found in ``model_dir_env_var``
    environment variable

    Parameters
    ----------
    model_dir_env_var: str
        The name of the environment variable which stores the location of the model

    Returns
    -------
    BaseEstimator, dict
        Tuple where the 0th element is the model, and the 1st element is the metadata
        associated with the model
    """
    logger.debug("Determining model location...")
    model_location = os.getenv(model_dir_env_var)
    if model_location is None:
        raise ValueError(f'Environment variable "{model_dir_env_var}" not set!')
    if not os.path.isdir(model_location):
        raise NotADirectoryError(
            f'The supplied directory: "{model_location}" does not exist!'
        )

    model = serializer.load(model_location)
    metadata = serializer.load_metadata(model_location)
    return model, metadata


def build_app(data_provider: typing.Optional[GordoBaseDataProvider] = None):
"""
Build app and any associated routes
@@ -157,10 +121,9 @@ def _log_time_taken(response):
        response.headers["Server-Timing"] = f"request_walltime_s;dur={runtime_s}"
        return response

    with app.app_context():
        app.model, app.metadata = load_model_and_metadata(
            app.config["MODEL_LOCATION_ENV_VAR"]
        )
    @app.route("/healthcheck")
    def base_healthcheck():
        return "", 200

    return app
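
With model loading now deferred to request time, a minimal sketch of serving
under the new scheme (directory layout inferred from load_model/load_metadata
in utils.py below, which join the collection directory with the model name;
paths and port are hypothetical):

import os
from gordo_components.server import server

# MODEL_COLLECTION_DIR points at a directory holding one sub-directory
# per serialized model, e.g.:
#   /gordo-models/
#       model-a/   <- served at /gordo/v0/<project>/model-a/...
#       model-b/
os.environ["MODEL_COLLECTION_DIR"] = "/gordo-models"

app = server.build_app()
app.run("0.0.0.0", 5555)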

81 changes: 79 additions & 2 deletions gordo_components/server/utils.py
@@ -2,15 +2,18 @@

import logging
import functools
import os
import timeit
import dateutil
from datetime import datetime
from typing import Union, List

import pandas as pd

from flask import request, g, jsonify, make_response, Response, current_app
from functools import lru_cache, wraps
from sklearn.base import BaseEstimator

from gordo_components import serializer
from gordo_components.dataset.datasets import TimeSeriesDataset


@@ -280,7 +283,7 @@ def wrapper_method(self, *args, **kwargs):
            data_provider=g.data_provider,
            from_ts=start - self.frequency.delta,
            to_ts=end,
            resolution=current_app.metadata["dataset"]["resolution"],
            resolution=g.metadata["dataset"]["resolution"],
            tag_list=self.tags,
            target_tag_list=self.target_tags or None,
        )
@@ -316,3 +319,77 @@ def wrapper_method(self, *args, **kwargs):
        return method(self, *args, **kwargs)

    return wrapper_method


@lru_cache(maxsize=int(os.getenv("N_CACHED_MODELS", 2)))
def load_model(directory: str, name: str) -> BaseEstimator:
    """
    Load a given model from the directory by name.

    Parameters
    ----------
    directory: str
        Directory to look for the model
    name: str
        Name of the model to load; this is the sub-directory within the
        directory parameter.

    Returns
    -------
    BaseEstimator
    """
    model = serializer.load(os.path.join(directory, name))
    return model
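
Because the lru_cache key is the (directory, name) argument pair, repeated
requests for the same model reuse the already-deserialized estimator; a rough
illustration (paths hypothetical):

# First call deserializes from disk; the second returns the cached object.
# A third distinct model would evict the least-recently-used entry once
# N_CACHED_MODELS (default 2, per the decorator above) is exceeded.
m1 = load_model("/gordo-models", "model-a")  # cache miss -> serializer.load
m2 = load_model("/gordo-models", "model-a")  # cache hit -> same object
assert m1 is m2
print(load_model.cache_info())  # e.g. CacheInfo(hits=1, misses=1, maxsize=2, currsize=1)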


@lru_cache(maxsize=20)
def load_metadata(directory: str, name: str) -> dict:
    """
    Load metadata from a directory for a given model by name.

    Parameters
    ----------
    directory: str
        Directory to look for the model's metadata
    name: str
        Name of the model to load metadata for; this is the sub-directory
        within the directory parameter.

    Returns
    -------
    dict
    """
    metadata = serializer.load_metadata(os.path.join(directory, name))
    return metadata


def metadata_required(f):
    """
    Decorate a view which has ``gordo_name`` as a URL parameter; sets
    ``g.metadata`` to that model's metadata
    """

    @wraps(f)
    def wrapper(*args: tuple, gordo_project: str, gordo_name: str, **kwargs: dict):
        collection_dir = os.environ[current_app.config["MODEL_COLLECTION_DIR_ENV_VAR"]]
        g.metadata = load_metadata(directory=collection_dir, name=gordo_name)
        return f(*args, **kwargs)

    return wrapper


def model_required(f):
    """
    Decorate a view which has ``gordo_name`` as a URL parameter; sets
    ``g.model`` to the loaded model and ``g.metadata`` to that model's
    metadata
    """

    @wraps(f)
    def wrapper(*args: tuple, gordo_project: str, gordo_name: str, **kwargs: dict):
        collection_dir = os.environ[current_app.config["MODEL_COLLECTION_DIR_ENV_VAR"]]
        g.model = load_model(directory=collection_dir, name=gordo_name)
        g.metadata = load_metadata(directory=collection_dir, name=gordo_name)
        return f(*args, **kwargs)

    return wrapper
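
For context, these decorators sit between the URL rule and the view function,
consuming the gordo_project/gordo_name URL parameters and populating flask.g;
a plain-Flask sketch (route and view are illustrative, not the project's
actual flask_restplus registration shown below):

import os
from flask import Flask, g, jsonify
from gordo_components.server.utils import model_required

os.environ.setdefault("MODEL_COLLECTION_DIR", "/gordo-models")  # hypothetical path

app = Flask(__name__)
app.config["MODEL_COLLECTION_DIR_ENV_VAR"] = "MODEL_COLLECTION_DIR"

@app.route("/gordo/v0/<gordo_project>/<gordo_name>/model-info")
@model_required
def model_info():
    # model_required consumed the URL params and set g.model / g.metadata
    return jsonify({"model-type": type(g.model).__name__})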
14 changes: 9 additions & 5 deletions gordo_components/server/views/anomaly.py
@@ -4,7 +4,7 @@
import timeit
import typing

from flask import Blueprint, make_response, jsonify, current_app, g
from flask import Blueprint, make_response, jsonify, g
from flask_restplus import fields

from gordo_components import __version__
@@ -15,7 +15,7 @@

logger = logging.getLogger(__name__)

anomaly_blueprint = Blueprint("ioc_anomaly_blueprint", __name__, url_prefix="/anomaly")
anomaly_blueprint = Blueprint("ioc_anomaly_blueprint", __name__)

api = Api(
    app=anomaly_blueprint,
@@ -99,6 +99,7 @@ class AnomalyView(BaseModelView):
"X": "Nested list of samples to predict, or single list considered as one sample"
}
)
@utils.model_required
@utils.extract_X_y
def post(self):
start_time = timeit.default_timer()
@@ -111,6 +112,7 @@ def post(self):
"end": "An ISO formatted datetime with timezone info string indicating prediction range end",
}
)
@utils.model_required
@utils.extract_X_y
def get(self):
start_time = timeit.default_timer()
@@ -147,10 +149,10 @@ def _create_anomaly_response(self, start_time: float = None):

        # Now create an anomaly dataframe from the base response dataframe
        try:
            anomaly_df = current_app.model.anomaly(g.X, g.y, frequency=self.frequency)
            anomaly_df = g.model.anomaly(g.X, g.y, frequency=self.frequency)
        except AttributeError:
            msg = {
                "message": f"Model is not an AnomalyDetector, it is of type: {type(current_app.model)}"
                "message": f"Model is not an AnomalyDetector, it is of type: {type(g.model)}"
            }
            return make_response(jsonify(msg), 422)  # 422 Unprocessable Entity

@@ -160,4 +162,6 @@
        return make_response(jsonify(context), context.pop("status-code", 200))


api.add_resource(AnomalyView, "/prediction")
api.add_resource(
    AnomalyView, "/gordo/v0/<gordo_project>/<gordo_name>/anomaly/prediction"
)
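
A hedged usage sketch for the relocated anomaly route, with start/end as the
ISO-formatted, timezone-aware datetimes documented above (host, project, and
model names hypothetical):

import requests

resp = requests.get(
    "http://localhost:5555/gordo/v0/my-project/my-model/anomaly/prediction",
    params={
        "start": "2019-01-01T00:00:00+00:00",
        "end": "2019-01-02T00:00:00+00:00",
    },
)
print(resp.status_code)  # 422 if the served model is not an AnomalyDetector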
36 changes: 20 additions & 16 deletions gordo_components/server/views/base.py
@@ -106,20 +106,16 @@ def frequency(self):
"""
The frequency the model was trained with in the dataset
"""
return pd.tseries.frequencies.to_offset(
current_app.metadata["dataset"]["resolution"]
)
return pd.tseries.frequencies.to_offset(g.metadata["dataset"]["resolution"])

    @property
    def tags(self) -> typing.List[SensorTag]:
        return normalize_sensor_tags(current_app.metadata["dataset"]["tag_list"])
        return normalize_sensor_tags(g.metadata["dataset"]["tag_list"])

    @property
    def target_tags(self) -> typing.List[SensorTag]:
        if "target_tag_list" in current_app.metadata["dataset"]:
            return normalize_sensor_tags(
                current_app.metadata["dataset"]["target_tag_list"]
            )
        if "target_tag_list" in g.metadata["dataset"]:
            return normalize_sensor_tags(g.metadata["dataset"]["target_tag_list"])
        else:
            return []

@@ -130,6 +126,7 @@ def target_tags(self) -> typing.List[SensorTag]:
"end": "An ISO formatted datetime with timezone info string indicating prediction range end",
}
)
@server_utils.model_required
@server_utils.extract_X_y
def get(self):
"""
@@ -140,6 +137,7 @@ def get(self):
    @api.response(200, "Success", API_MODEL_OUTPUT_POST)
    @api.expect(API_MODEL_INPUT_POST, validate=False)
    @api.doc(params={"X": "Nested or single list of sample(s) to predict"})
    @server_utils.model_required
    @server_utils.extract_X_y
    def post(self):
        """
@@ -174,7 +172,7 @@ def _process_request(self):
        process_request_start_time_s = timeit.default_timer()

        try:
            output = model_io.get_model_output(model=current_app.model, X=X)
            output = model_io.get_model_output(model=g.model, X=X)
        except ValueError as err:
            tb = traceback.format_exc()
            logger.error(
@@ -215,15 +213,16 @@ class MetaDataView(Resource):
    Serve model / server metadata
    """

    @server_utils.metadata_required
    def get(self):
        """
        Get metadata about this endpoint; also serves as a /healthcheck endpoint
        """
        model_location_env_var = current_app.config["MODEL_LOCATION_ENV_VAR"]
        model_collection_env_var = current_app.config["MODEL_COLLECTION_DIR_ENV_VAR"]
        return {
            "gordo-server-version": __version__,
            "metadata": current_app.metadata,
            "env": {model_location_env_var: os.environ.get(model_location_env_var)},
            "metadata": g.metadata,
            "env": {model_collection_env_var: os.environ.get(model_collection_env_var)},
        }


@@ -237,6 +236,7 @@ class DownloadModel(Resource):
    @api.doc(
        description="Download model, loadable via gordo_components.serializer.loads"
    )
    @server_utils.model_required
    def get(self):
        """
        Responds with a serialized copy of the current model being served.
@@ -246,11 +246,15 @@
        bytes
            Results from ``gordo_components.serializer.dumps()``
        """
        serialized_model = serializer.dumps(current_app.model)
        serialized_model = serializer.dumps(g.model)
        buff = io.BytesIO(serialized_model)
        return send_file(buff, attachment_filename="model.tar.gz")


api.add_resource(BaseModelView, "/prediction")
api.add_resource(MetaDataView, "/metadata", "/healthcheck")
api.add_resource(DownloadModel, "/download-model")
api.add_resource(BaseModelView, "/gordo/v0/<gordo_project>/<gordo_name>/prediction")
api.add_resource(
    MetaDataView,
    "/gordo/v0/<gordo_project>/<gordo_name>/metadata",
    "/gordo/v0/<gordo_project>/<gordo_name>/healthcheck",
)
api.add_resource(DownloadModel, "/gordo/v0/<gordo_project>/<gordo_name>/download-model")
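
And, per the DownloadModel docstring, the bytes from the per-model download
route should round-trip through gordo_components.serializer.loads (host and
names hypothetical):

import requests
from gordo_components import serializer

resp = requests.get("http://localhost:5555/gordo/v0/my-project/my-model/download-model")
model = serializer.loads(resp.content)  # counterpart of serializer.dumps()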
5 changes: 3 additions & 2 deletions setup.cfg
@@ -4,7 +4,7 @@
test = pytest --addopts "-n auto -m 'not dockertest'"

# Test _everything_
testall = pytest
testall = pytest --addopts "--ignore benchmarks"

# Only run tests which use docker
testdocker = pytest --addopts "-m 'dockertest'"
@@ -36,6 +36,7 @@ testallelse = pytest --addopts
    --ignore tests/gordo_components/server
    --ignore tests/gordo_components/util
    --ignore tests/gordo_components/watchman
    --ignore tests/test_formatting.py"
    --ignore tests/test_formatting.py
    --ignore benchmarks"

testbenchmarks = pytest --addopts "--benchmark-only benchmarks/"
