diff --git a/mlserver/parallel/registry.py b/mlserver/parallel/registry.py
index 15122a690..c38fd2ed7 100644
--- a/mlserver/parallel/registry.py
+++ b/mlserver/parallel/registry.py
@@ -38,6 +38,12 @@ def _get_env_tarball(model: MLModel) -> Optional[str]:
     return to_absolute_path(model_settings, env_tarball)
 
 
+def _append_gid_environment_hash(
+    env_hash: str, inference_pool_gid: Optional[str] = None
+) -> str:
+    return f"{env_hash}-{inference_pool_gid}"
+
+
 class InferencePoolRegistry:
     """
     Keeps track of the different inference pools loaded in the server.
@@ -80,14 +86,17 @@ async def _get_or_create(self, model: MLModel) -> InferencePool:
             and model.settings.parameters.environment_path
         ):
             pool = await self._get_or_create_with_existing_env(
-                model.settings.parameters.environment_path
+                model.settings.parameters.environment_path,
+                model.settings.parameters.inference_pool_gid,
             )
         else:
             pool = await self._get_or_create_with_tarball(model)
         return pool
 
     async def _get_or_create_with_existing_env(
-        self, environment_path: str
+        self,
+        environment_path: str,
+        inference_pool_gid: Optional[str],
     ) -> InferencePool:
         """
         Creates or returns the InferencePool for a model that uses an existing
@@ -98,8 +107,13 @@ async def _get_or_create_with_existing_env(
         )
         logger.info(f"Using environment {expanded_environment_path}")
         env_hash = await compute_hash_of_string(expanded_environment_path)
+
+        if inference_pool_gid is not None:
+            env_hash = _append_gid_environment_hash(env_hash, inference_pool_gid)
+
         if env_hash in self._pools:
             return self._pools[env_hash]
+
         env = Environment(
             env_path=expanded_environment_path,
             env_hash=env_hash,
@@ -114,22 +128,38 @@ async def _get_or_create_with_tarball(self, model: MLModel) -> InferencePool:
         """
         Creates or returns the InferencePool for a model that uses a
-        tarball as python environment.
+        tarball as a Python environment.
""" env_tarball = _get_env_tarball(model) + inference_pool_gid = ( + model.settings.parameters.inference_pool_gid + if model.settings.parameters + else None + ) + if not env_tarball: - return self._default_pool + return ( + self._pools.setdefault( + inference_pool_gid, + InferencePool(self._settings, on_worker_stop=self._on_worker_stop), + ) + if inference_pool_gid + else self._default_pool + ) env_hash = await compute_hash_of_file(env_tarball) + if inference_pool_gid is not None: + env_hash = _append_gid_environment_hash(env_hash, inference_pool_gid) + if env_hash in self._pools: return self._pools[env_hash] env = await self._extract_tarball(env_hash, env_tarball) - pool = InferencePool( + self._pools[env_hash] = InferencePool( self._settings, env=env, on_worker_stop=self._on_worker_stop ) - self._pools[env_hash] = pool - return pool + + return self._pools[env_hash] async def _extract_tarball(self, env_hash: str, env_tarball: str) -> Environment: env_path = self._get_env_path(env_hash) @@ -145,8 +175,17 @@ def _get_env_path(self, env_hash: str) -> str: async def _find(self, model: MLModel) -> InferencePool: env_hash = _get_environment_hash(model) + inference_pool_gid = ( + model.settings.parameters.inference_pool_gid + if model.settings.parameters + else None + ) + if not env_hash: - return self._default_pool + if not inference_pool_gid: + return self._default_pool + else: + return self._pools[inference_pool_gid] if env_hash not in self._pools: raise EnvironmentNotFound(model, env_hash) diff --git a/mlserver/settings.py b/mlserver/settings.py index f390faab9..6c873ed7b 100644 --- a/mlserver/settings.py +++ b/mlserver/settings.py @@ -1,5 +1,6 @@ import sys import os +import uuid import json import importlib import inspect @@ -14,11 +15,13 @@ no_type_check, TYPE_CHECKING, ) +from typing_extensions import Self from pydantic import ( ImportString, Field, AliasChoices, ) +from pydantic import model_validator from pydantic._internal._validators import import_string import pydantic_settings from pydantic_settings import SettingsConfigDict @@ -313,6 +316,12 @@ class ModelParameters(BaseSettings): """Path to the environment tarball which should be used to load this model.""" + inference_pool_gid: Optional[str] = None + """Inference pool group id to be used to serve this model.""" + + autogenerate_inference_pool_gid: bool = False + """Flag to autogenerate the inference pool group id for this model.""" + format: Optional[str] = None """Format of the model (only available on certain runtimes).""" @@ -323,6 +332,12 @@ class ModelParameters(BaseSettings): """Arbitrary settings, dependent on the inference runtime implementation.""" + @model_validator(mode="after") + def set_inference_pool_gid(self) -> Self: + if self.autogenerate_inference_pool_gid and self.inference_pool_gid is None: + self.inference_pool_gid = str(uuid.uuid4()) + return self + class ModelSettings(BaseSettings): model_config = SettingsConfigDict( diff --git a/tests/parallel/test_registry.py b/tests/parallel/test_registry.py index 70629eadc..a6292021e 100644 --- a/tests/parallel/test_registry.py +++ b/tests/parallel/test_registry.py @@ -1,10 +1,13 @@ import pytest import os import asyncio +from copy import deepcopy +from typing import Optional +from unittest.mock import patch from mlserver.env import Environment, compute_hash_of_file from mlserver.model import MLModel -from mlserver.settings import Settings, ModelSettings +from mlserver.settings import Settings, ModelSettings, ModelParameters from mlserver.types import InferenceRequest 
 from mlserver.codecs import StringCodec
 
 from mlserver.parallel.errors import EnvironmentNotFound
@@ -12,10 +15,11 @@
     InferencePoolRegistry,
     _set_environment_hash,
     _get_environment_hash,
+    _append_gid_environment_hash,
     ENV_HASH_ATTR,
 )
 
-from ..fixtures import EnvModel
+from ..fixtures import SumModel, EnvModel
 
 
 @pytest.fixture
@@ -71,12 +75,18 @@ async def test_default_pool(
     assert worker_count == settings.parallel_workers
 
 
+@pytest.mark.parametrize("inference_pool_gid", ["dummy_id", None])
 async def test_load_model(
     inference_pool_registry: InferencePoolRegistry,
-    sum_model: MLModel,
+    sum_model_settings: ModelSettings,
     inference_request: InferenceRequest,
+    inference_pool_gid: Optional[str],
 ):
-    sum_model.settings.name = "foo"
+    sum_model_settings = deepcopy(sum_model_settings)
+    sum_model_settings.name = "foo"
+    sum_model_settings.parameters.inference_pool_gid = inference_pool_gid
+    sum_model = SumModel(sum_model_settings)
+
     model = await inference_pool_registry.load_model(sum_model)
     inference_response = await model.predict(inference_request)
 
@@ -87,20 +97,22 @@ async def test_load_model(
     await inference_pool_registry.unload_model(sum_model)
 
 
+def check_sklearn_version(response):
+    # Note: These versions come from the `environment.yml` under
+    # `./tests/testdata/`
+    assert len(response.outputs) == 1
+    assert response.outputs[0].name == "sklearn_version"
+    [sklearn_version] = StringCodec.decode_output(response.outputs[0])
+    assert sklearn_version == "1.3.1"
+
+
 async def test_load_model_with_env(
     inference_pool_registry: InferencePoolRegistry,
     env_model: MLModel,
     inference_request: InferenceRequest,
 ):
     response = await env_model.predict(inference_request)
-
-    assert len(response.outputs) == 1
-
-    # Note: These versions come from the `environment.yml` found in
-    # `./tests/testdata/environment.yaml`
-    assert response.outputs[0].name == "sklearn_version"
-    [sklearn_version] = StringCodec.decode_output(response.outputs[0])
-    assert sklearn_version == "1.3.1"
+    check_sklearn_version(response)
 
 
 async def test_load_model_with_existing_env(
@@ -109,14 +121,7 @@ async def test_load_model_with_existing_env(
     inference_request: InferenceRequest,
 ):
     response = await existing_env_model.predict(inference_request)
-
-    assert len(response.outputs) == 1
-
-    # Note: These versions come from the `environment.yml` found in
-    # `./tests/testdata/environment.yaml`
-    assert response.outputs[0].name == "sklearn_version"
-    [sklearn_version] = StringCodec.decode_output(response.outputs[0])
-    assert sklearn_version == "1.3.1"
+    check_sklearn_version(response)
 
 
 async def test_load_creates_pool(
@@ -224,3 +229,93 @@ async def test_worker_stop(
     for _ in range(settings.parallel_workers + 2):
         inference_response = await sum_model.predict(inference_request)
         assert len(inference_response.outputs) > 0
+
+
+@pytest.mark.parametrize(
+    "env_hash, inference_pool_gid, expected_env_hash",
+    [
+        ("dummy_hash", "dummy_gid", "dummy_hash-dummy_gid"),
+    ],
+)
+def test_append_gid_environment_hash(
+    env_hash: str, inference_pool_gid: Optional[str], expected_env_hash: str
+):
+    _env_hash = _append_gid_environment_hash(env_hash, inference_pool_gid)
+    assert _env_hash == expected_env_hash
+
+
+async def test_default_and_default_gid(
+    inference_pool_registry: InferencePoolRegistry,
+    simple_model_settings: ModelSettings,
+):
+    simple_model_settings_gid = deepcopy(simple_model_settings)
+    simple_model_settings_gid.parameters.inference_pool_gid = "dummy_id"
+
+    simple_model = SumModel(simple_model_settings)
+    simple_model_gid = SumModel(simple_model_settings_gid)
+
+    model = await inference_pool_registry.load_model(simple_model)
+    model_gid = await inference_pool_registry.load_model(simple_model_gid)
+
+    assert len(inference_pool_registry._pools) == 1
+    await inference_pool_registry.unload_model(model)
+    await inference_pool_registry.unload_model(model_gid)
+
+
+async def test_env_and_env_gid(
+    inference_request: InferenceRequest,
+    inference_pool_registry: InferencePoolRegistry,
+    env_model_settings: ModelSettings,
+    env_tarball: str,
+):
+    env_model_settings = deepcopy(env_model_settings)
+    env_model_settings.parameters.environment_tarball = env_tarball
+
+    env_model_settings_gid = deepcopy(env_model_settings)
+    env_model_settings_gid.parameters.inference_pool_gid = "dummy_id"
+
+    env_model = EnvModel(env_model_settings)
+    env_model_gid = EnvModel(env_model_settings_gid)
+
+    model = await inference_pool_registry.load_model(env_model)
+    model_gid = await inference_pool_registry.load_model(env_model_gid)
+    assert len(inference_pool_registry._pools) == 2
+
+    response = await model.predict(inference_request)
+    response_gid = await model_gid.predict(inference_request)
+    check_sklearn_version(response)
+    check_sklearn_version(response_gid)
+
+    await inference_pool_registry.unload_model(model)
+    await inference_pool_registry.unload_model(model_gid)
+
+
+@pytest.mark.parametrize(
+    "inference_pool_gid, autogenerate_inference_pool_gid",
+    [
+        ("dummy_gid", False),
+        ("dummy_gid", True),
+        (None, True),
+        (None, False),
+    ],
+)
+def test_autogenerate_inference_pool_gid(
+    inference_pool_gid: Optional[str], autogenerate_inference_pool_gid: bool
+):
+    patch_uuid = "patch-uuid"
+    with patch("uuid.uuid4", return_value=patch_uuid):
+        model_settings = ModelSettings(
+            name="dummy-model",
+            implementation=MLModel,
+            parameters=ModelParameters(
+                inference_pool_gid=inference_pool_gid,
+                autogenerate_inference_pool_gid=autogenerate_inference_pool_gid,
+            ),
+        )
+
+    expected_gid = (
+        inference_pool_gid
+        if not autogenerate_inference_pool_gid
+        else (inference_pool_gid or patch_uuid)
+    )
+    assert model_settings.parameters.inference_pool_gid == expected_gid
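
For context, a minimal usage sketch of the two new `ModelParameters` fields introduced above (not part of the diff; the model names and gid values are illustrative):

    from mlserver.model import MLModel
    from mlserver.settings import ModelSettings, ModelParameters

    # Models sharing the same gid (and the same environment) are served from one
    # pool: the registry keys pools by the gid alone, or by env-hash + gid when
    # an environment tarball/path is used.
    pinned = ModelSettings(
        name="my-model",
        implementation=MLModel,
        parameters=ModelParameters(inference_pool_gid="group-a"),
    )

    # With the autogenerate flag, the `set_inference_pool_gid` validator fills in
    # a fresh UUID4 gid, effectively giving the model its own dedicated pool.
    isolated = ModelSettings(
        name="my-isolated-model",
        implementation=MLModel,
        parameters=ModelParameters(autogenerate_inference_pool_gid=True),
    )
    assert isolated.parameters.inference_pool_gid is not None  # set by the validator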