Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Included separate inference pool #2040

55 changes: 47 additions & 8 deletions mlserver/parallel/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,14 @@ def _get_env_tarball(model: MLModel) -> Optional[str]:
return to_absolute_path(model_settings, env_tarball)


def _get_environment_hash_gid(
env_hash: str, inference_pool_gid: Optional[str] = None
) -> str:
if inference_pool_gid:
return f"{env_hash}-{inference_pool_gid}"
return env_hash


class InferencePoolRegistry:
"""
Keeps track of the different inference pools loaded in the server.
Expand Down Expand Up @@ -80,14 +88,17 @@ async def _get_or_create(self, model: MLModel) -> InferencePool:
and model.settings.parameters.environment_path
):
pool = await self._get_or_create_with_existing_env(
model.settings.parameters.environment_path
model.settings.parameters.environment_path,
model.settings.parameters.inference_pool_gid,
)
else:
pool = await self._get_or_create_with_tarball(model)
return pool

async def _get_or_create_with_existing_env(
self, environment_path: str
self,
environment_path: str,
inference_pool_gid: Optional[str],
) -> InferencePool:
"""
Creates or returns the InferencePool for a model that uses an existing
Expand All @@ -98,8 +109,11 @@ async def _get_or_create_with_existing_env(
)
logger.info(f"Using environment {expanded_environment_path}")
env_hash = await compute_hash_of_string(expanded_environment_path)
env_hash = _get_environment_hash_gid(env_hash, inference_pool_gid)

if env_hash in self._pools:
return self._pools[env_hash]

env = Environment(
env_path=expanded_environment_path,
env_hash=env_hash,
Expand All @@ -114,22 +128,37 @@ async def _get_or_create_with_existing_env(
async def _get_or_create_with_tarball(self, model: MLModel) -> InferencePool:
"""
Creates or returns the InferencePool for a model that uses a
tarball as python environment.
tarball as a Python environment.
"""
env_tarball = _get_env_tarball(model)
inference_pool_gid = (
model.settings.parameters.inference_pool_gid
if model.settings.parameters
else None
)

if not env_tarball:
return self._default_pool
return (
self._pools.setdefault(
inference_pool_gid,
InferencePool(self._settings, on_worker_stop=self._on_worker_stop),
)
if inference_pool_gid
else self._default_pool
)

env_hash = await compute_hash_of_file(env_tarball)
env_hash = _get_environment_hash_gid(env_hash, inference_pool_gid)

if env_hash in self._pools:
return self._pools[env_hash]

env = await self._extract_tarball(env_hash, env_tarball)
pool = InferencePool(
self._pools[env_hash] = InferencePool(
self._settings, env=env, on_worker_stop=self._on_worker_stop
)
self._pools[env_hash] = pool
return pool

return self._pools[env_hash]

async def _extract_tarball(self, env_hash: str, env_tarball: str) -> Environment:
env_path = self._get_env_path(env_hash)
Expand All @@ -145,8 +174,18 @@ def _get_env_path(self, env_hash: str) -> str:

async def _find(self, model: MLModel) -> InferencePool:
env_hash = _get_environment_hash(model)
inference_pool_gid = (
model.settings.parameters.inference_pool_gid
if model.settings.parameters
else None
)

if not env_hash:
return self._default_pool
return (
self._default_pool
if not inference_pool_gid
else self._pools[inference_pool_gid]
)

if env_hash not in self._pools:
raise EnvironmentNotFound(model, env_hash)
Expand Down
15 changes: 15 additions & 0 deletions mlserver/settings.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import sys
import os
import uuid
import json
import importlib
import inspect
Expand All @@ -14,11 +15,13 @@
no_type_check,
TYPE_CHECKING,
)
from typing_extensions import Self
from pydantic import (
ImportString,
Field,
AliasChoices,
)
from pydantic import model_validator
from pydantic._internal._validators import import_string
import pydantic_settings
from pydantic_settings import SettingsConfigDict
Expand Down Expand Up @@ -313,6 +316,12 @@ class ModelParameters(BaseSettings):
"""Path to the environment tarball which should be used to load this
model."""

inference_pool_gid: Optional[str] = None
"""Inference pool group id to be used to serve this model."""

autogenerate_inference_pool_gid: bool = False
"""Flag to autogenerate the inference pool group id for this model."""

format: Optional[str] = None
"""Format of the model (only available on certain runtimes)."""

Expand All @@ -323,6 +332,12 @@ class ModelParameters(BaseSettings):
"""Arbitrary settings, dependent on the inference runtime
implementation."""

@model_validator(mode="after")
def set_inference_pool_gid(self) -> Self:
    """Autogenerate an inference pool group id when requested.

    Only fills in a fresh UUID4 when the autogenerate flag is set and no
    explicit ``inference_pool_gid`` was provided.
    """
    needs_gid = (
        self.autogenerate_inference_pool_gid
        and self.inference_pool_gid is None
    )
    if needs_gid:
        self.inference_pool_gid = str(uuid.uuid4())
    return self


class ModelSettings(BaseSettings):
model_config = SettingsConfigDict(
Expand Down
136 changes: 116 additions & 20 deletions tests/parallel/test_registry.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,25 @@
import pytest
import os
import asyncio
from copy import deepcopy
from typing import Optional
from unittest.mock import patch

from mlserver.env import Environment, compute_hash_of_file
from mlserver.model import MLModel
from mlserver.settings import Settings, ModelSettings
from mlserver.settings import Settings, ModelSettings, ModelParameters
from mlserver.types import InferenceRequest
from mlserver.codecs import StringCodec
from mlserver.parallel.errors import EnvironmentNotFound
from mlserver.parallel.registry import (
InferencePoolRegistry,
_set_environment_hash,
_get_environment_hash,
_get_environment_hash_gid,
ENV_HASH_ATTR,
)

from ..fixtures import EnvModel
from ..fixtures import SumModel, EnvModel


@pytest.fixture
Expand Down Expand Up @@ -71,12 +75,18 @@ async def test_default_pool(
assert worker_count == settings.parallel_workers


@pytest.mark.parametrize("inference_pool_gid", ["dummy_id", None])
async def test_load_model(
inference_pool_registry: InferencePoolRegistry,
sum_model: MLModel,
sum_model_settings: ModelSettings,
inference_request: InferenceRequest,
inference_pool_gid: str,
):
sum_model.settings.name = "foo"
sum_model_settings = deepcopy(sum_model_settings)
sum_model_settings.name = "foo"
sum_model_settings.parameters.inference_pool_gid = inference_pool_gid
sum_model = SumModel(sum_model_settings)

model = await inference_pool_registry.load_model(sum_model)
inference_response = await model.predict(inference_request)

Expand All @@ -87,20 +97,22 @@ async def test_load_model(
await inference_pool_registry.unload_model(sum_model)


def check_sklearn_version(response):
    """Assert the response carries the pinned sklearn version string.

    The expected version is pinned by the test environment spec in
    `./tests/testdata/environment.yml`.
    """
    assert len(response.outputs) == 1
    output = response.outputs[0]
    assert output.name == "sklearn_version"
    (sklearn_version,) = StringCodec.decode_output(output)
    assert sklearn_version == "1.3.1"


async def test_load_model_with_env(
inference_pool_registry: InferencePoolRegistry,
env_model: MLModel,
inference_request: InferenceRequest,
):
response = await env_model.predict(inference_request)

assert len(response.outputs) == 1

# Note: These versions come from the `environment.yml` found in
# `./tests/testdata/environment.yaml`
assert response.outputs[0].name == "sklearn_version"
[sklearn_version] = StringCodec.decode_output(response.outputs[0])
assert sklearn_version == "1.3.1"
check_sklearn_version(response)


async def test_load_model_with_existing_env(
Expand All @@ -109,14 +121,7 @@ async def test_load_model_with_existing_env(
inference_request: InferenceRequest,
):
response = await existing_env_model.predict(inference_request)

assert len(response.outputs) == 1

# Note: These versions come from the `environment.yml` found in
# `./tests/testdata/environment.yaml`
assert response.outputs[0].name == "sklearn_version"
[sklearn_version] = StringCodec.decode_output(response.outputs[0])
assert sklearn_version == "1.3.1"
check_sklearn_version(response)


async def test_load_creates_pool(
Expand Down Expand Up @@ -224,3 +229,94 @@ async def test_worker_stop(
for _ in range(settings.parallel_workers + 2):
inference_response = await sum_model.predict(inference_request)
assert len(inference_response.outputs) > 0


@pytest.mark.parametrize(
    "env_hash, inference_pool_gid, expected_env_hash",
    [
        ("dummy_hash", None, "dummy_hash"),
        ("dummy_hash", "dummy_gid", "dummy_hash-dummy_gid"),
    ],
)
def test__get_environment_hash_gid(
    env_hash: str, inference_pool_gid: Optional[str], expected_env_hash: str
):
    """The gid suffix is appended to the hash only when a gid is given."""
    # Fix: the function under test is pure and nothing is awaited, so the
    # test does not need to be an async coroutine.
    _env_hash = _get_environment_hash_gid(env_hash, inference_pool_gid)
    assert _env_hash == expected_env_hash


async def test_default_and_default_gid(
    inference_pool_registry: InferencePoolRegistry,
    simple_model_settings: ModelSettings,
):
    """Loading a plain model and a gid-tagged model of the same kind leaves
    exactly one entry in the registry's pool map."""
    settings_with_gid = deepcopy(simple_model_settings)
    settings_with_gid.parameters.inference_pool_gid = "dummy_id"

    plain_model = SumModel(simple_model_settings)
    gid_model = SumModel(settings_with_gid)

    loaded_plain = await inference_pool_registry.load_model(plain_model)
    loaded_gid = await inference_pool_registry.load_model(gid_model)

    assert len(inference_pool_registry._pools) == 1

    await inference_pool_registry.unload_model(loaded_plain)
    await inference_pool_registry.unload_model(loaded_gid)


async def test_env_and_env_gid(
    inference_request: InferenceRequest,
    inference_pool_registry: InferencePoolRegistry,
    env_model_settings: ModelSettings,
    env_tarball: str,
):
    """A tarball-env model with a gid gets its own pool, separate from the
    pool of the same env without a gid; both still serve predictions."""
    base_settings = deepcopy(env_model_settings)
    base_settings.parameters.environment_tarball = env_tarball

    gid_settings = deepcopy(base_settings)
    gid_settings.parameters.inference_pool_gid = "dummy_id"

    model = await inference_pool_registry.load_model(EnvModel(base_settings))
    model_gid = await inference_pool_registry.load_model(EnvModel(gid_settings))
    # One pool keyed by env hash, one keyed by env hash + gid.
    assert len(inference_pool_registry._pools) == 2

    check_sklearn_version(await model.predict(inference_request))
    check_sklearn_version(await model_gid.predict(inference_request))

    await inference_pool_registry.unload_model(model)
    await inference_pool_registry.unload_model(model_gid)


@pytest.mark.parametrize(
    "inference_pool_gid, autogenerate_inference_pool_gid",
    [
        ("dummy_gid", False),
        ("dummy_gid", True),
        (None, True),
        (None, False),
    ],
)
def test_autogenerate_inference_pool_gid(
    inference_pool_gid: Optional[str], autogenerate_inference_pool_gid: bool
):
    """A gid is autogenerated only when the flag is set and no explicit gid
    was provided; an explicit gid is never overwritten.

    Fix: the parametrized names misspelled "gid" as "grid"
    (`inference_pool_grid`), inconsistent with the `ModelParameters` fields
    they configure.
    """
    patch_uuid = "patch-uuid"
    # Patch uuid.uuid4 so the autogenerated gid is deterministic.
    with patch("uuid.uuid4", return_value=patch_uuid):
        model_settings = ModelSettings(
            name="dummy-model",
            implementation=MLModel,
            parameters=ModelParameters(
                inference_pool_gid=inference_pool_gid,
                autogenerate_inference_pool_gid=autogenerate_inference_pool_gid,
            ),
        )

    expected_gid = (
        inference_pool_gid
        if not autogenerate_inference_pool_gid
        else (inference_pool_gid or patch_uuid)
    )
    assert model_settings.parameters.inference_pool_gid == expected_gid
Loading