Pydantic v2 migration (#423)

Co-authored-by: Logan Adams <[email protected]>
Co-authored-by: Abhishek Kulkarni <[email protected]>
Co-authored-by: Logan Adams <[email protected]>
4 people authored Aug 22, 2024
1 parent 9568678 commit 6c71581
Showing 17 changed files with 207 additions and 245 deletions.
docs/requirements.txt (4 changes: 2 additions & 2 deletions)
@@ -1,6 +1,6 @@
asyncio
autodoc_pydantic<2.0.0
deepspeed>=0.13.0
autodoc_pydantic>=2.0.0
deepspeed>=0.15.0
grpcio
grpcio-tools
sphinx==7.1.2
mii/api.py (10 changes: 5 additions & 5 deletions)
@@ -39,7 +39,7 @@ def _parse_kwargs_to_model_config(
# Fill model_config dict with relevant kwargs, store remaining kwargs in a new dict
remaining_kwargs = {}
for key, val in kwargs.items():
if key in ModelConfig.__dict__["__fields__"]:
if key in ModelConfig.model_fields.keys():
if key in model_config:
assert (
model_config.get(key) == val
@@ -77,7 +77,7 @@ def _parse_kwargs_to_mii_config(

# Fill mii_config dict with relevant kwargs, raise error on unknown kwargs
for key, val in remaining_kwargs.items():
if key in MIIConfig.__dict__["__fields__"]:
if key in MIIConfig.model_fields.keys():
if key in mii_config:
assert (
mii_config.get(key) == val
@@ -183,9 +183,9 @@ def serve(
mii.aml_related.utils.generate_aml_scripts(
acr_name=acr_name,
deployment_name=mii_config.deployment_name,
model_name=mii_config.model_config.model,
task_name=mii_config.model_config.task,
replica_num=mii_config.model_config.replica_num,
model_name=mii_config.model_conf.model,
task_name=mii_config.model_conf.task,
replica_num=mii_config.model_conf.replica_num,
instance_type=mii_config.instance_type,
version=mii_config.version,
)
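In Pydantic v2, the per-field metadata that v1 exposed as __fields__ lives on model_fields, which is why the kwarg-filtering loops above now test membership against ModelConfig.model_fields and MIIConfig.model_fields. A minimal sketch of that pattern, using a hypothetical stand-in model and helper rather than MII's own classes:

from typing import Dict, Tuple

from pydantic import BaseModel


class ExampleConfig(BaseModel):
    # Hypothetical stand-in for mii.config.ModelConfig; not an MII class.
    task: str = "text-generation"
    replica_num: int = 1


def split_kwargs(kwargs: Dict) -> Tuple[Dict, Dict]:
    # Keep kwargs that name a declared field, collect the rest separately.
    known, remaining = {}, {}
    for key, val in kwargs.items():
        # Pydantic v1: key in ExampleConfig.__fields__
        # Pydantic v2: key in ExampleConfig.model_fields
        if key in ExampleConfig.model_fields:
            known[key] = val
        else:
            remaining[key] = val
    return known, remaining


known, rest = split_kwargs({"replica_num": 2, "max_length": 512})
print(known, rest)  # {'replica_num': 2} {'max_length': 512}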
mii/backend/client.py (2 changes: 1 addition & 1 deletion)
@@ -37,7 +37,7 @@ class MIIClient:
"""
def __init__(self, mii_config: MIIConfig, host: str = "localhost") -> None:
self.mii_config = mii_config
self.task = mii_config.model_config.task
self.task = mii_config.model_conf.task
self.port = mii_config.port_number
self.asyncio_loop = asyncio.get_event_loop()
channel = create_channel(host, self.port)
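The attribute rename from model_config to model_conf seen here (and throughout the diff) is forced by Pydantic v2, which reserves the name model_config on every BaseModel for the model's own configuration; the old spelling survives as an input alias defined in mii/config.py below. A minimal sketch of that aliasing pattern on a hypothetical model, not MIIConfig itself:

from pydantic import BaseModel, ConfigDict, Field


class ExampleMIIConfig(BaseModel):
    # Hypothetical stand-in for mii.config.MIIConfig.
    # In Pydantic v2, "model_config" is the reserved ConfigDict attribute,
    # so the field is renamed and the old name is kept as an alias.
    model_config = ConfigDict(populate_by_name=True, protected_namespaces=())

    model_conf: dict = Field(default_factory=dict, alias="model_config")
    deployment_name: str = ""


cfg = ExampleMIIConfig(model_config={"model_name_or_path": "gpt2"})
print(cfg.model_conf)  # {'model_name_or_path': 'gpt2'}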
mii/backend/server.py (19 changes: 9 additions & 10 deletions)
@@ -20,7 +20,7 @@

def config_to_b64_str(config: DeepSpeedConfigModel) -> str:
# convert json str -> bytes
json_bytes = config.json().encode()
json_bytes = config.model_dump_json().encode()
# base64 encoded bytes
b64_config_bytes = base64.urlsafe_b64encode(json_bytes)
# bytes -> str
@@ -31,7 +31,7 @@ class MIIServer:
"""Initialize the model, setup the server for the model"""
def __init__(self, mii_config: MIIConfig) -> None:

self.task = mii_config.model_config.task
self.task = mii_config.model_conf.task
self.port_number = mii_config.port_number

if not os.path.isfile(mii_config.hostfile):
@@ -47,8 +47,7 @@ def __init__(self, mii_config: MIIConfig) -> None:
# balancer process, each DeepSpeed model replica, and optionally the
# REST API process)
processes = self._initialize_service(mii_config)
self._wait_until_server_is_live(processes,
mii_config.model_config.replica_configs)
self._wait_until_server_is_live(processes, mii_config.model_conf.replica_configs)

def _wait_until_server_is_live(self,
processes: List[subprocess.Popen],
@@ -143,15 +142,15 @@ def _initialize_service(self, mii_config: MIIConfig) -> List[subprocess.Popen]:
]

host_gpus = defaultdict(list)
for repl_config in mii_config.model_config.replica_configs:
for repl_config in mii_config.model_conf.replica_configs:
host_gpus[repl_config.hostname].extend(repl_config.gpu_indices)

use_multiple_hosts = len(
set(repl_config.hostname
for repl_config in mii_config.model_config.replica_configs)) > 1
for repl_config in mii_config.model_conf.replica_configs)) > 1

# Start replica instances
for repl_config in mii_config.model_config.replica_configs:
for repl_config in mii_config.model_conf.replica_configs:
hostfile = tempfile.NamedTemporaryFile(delete=False)
hostfile.write(
f"{repl_config.hostname} slots={max(host_gpus[repl_config.hostname])+1}\n"
@@ -161,7 +160,7 @@ def _initialize_service(self, mii_config: MIIConfig) -> List[subprocess.Popen]:
use_multiple_hosts)
processes.append(
self._launch_server_process(
mii_config.model_config,
mii_config.model_conf,
"MII server",
ds_launch_str=ds_launch_str,
server_args=server_args + [
@@ -175,15 +174,15 @@ def _initialize_service(self, mii_config: MIIConfig) -> List[subprocess.Popen]:
# expected to assign one GPU to one process.
processes.append(
self._launch_server_process(
mii_config.model_config,
mii_config.model_conf,
"load balancer",
server_args=server_args + ["--load-balancer"],
))

if mii_config.enable_restful_api:
processes.append(
self._launch_server_process(
mii_config.model_config,
mii_config.model_conf,
"restful api gateway",
server_args=server_args + ["--restful-gateway"],
))
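The change in config_to_b64_str reflects the Pydantic v2 serialization rename: .json() became .model_dump_json() (and .dict() became .model_dump()). A minimal sketch of the same base64 round trip on a hypothetical config model; the real helper operates on DeepSpeedConfigModel instances, and the decoding side with model_validate_json is shown here as an assumption, not something this diff adds:

import base64

from pydantic import BaseModel


class ExampleConfig(BaseModel):
    # Hypothetical stand-in for a DeepSpeedConfigModel subclass.
    deployment_name: str = "demo"
    port_number: int = 50050


def config_to_b64_str(config: BaseModel) -> str:
    # Pydantic v1: config.json()  ->  Pydantic v2: config.model_dump_json()
    json_bytes = config.model_dump_json().encode()
    b64_config_bytes = base64.urlsafe_b64encode(json_bytes)
    return b64_config_bytes.decode()


def b64_str_to_config(b64_str: str) -> ExampleConfig:
    json_str = base64.urlsafe_b64decode(b64_str.encode()).decode()
    # Pydantic v1: parse_raw()  ->  Pydantic v2: model_validate_json()
    return ExampleConfig.model_validate_json(json_str)


encoded = config_to_b64_str(ExampleConfig())
print(b64_str_to_config(encoded))  # deployment_name='demo' port_number=50050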
mii/config.py (125 changes: 58 additions & 67 deletions)
@@ -8,27 +8,18 @@

from deepspeed.launcher.runner import DLTS_HOSTFILE, fetch_hostfile
from deepspeed.inference import RaggedInferenceEngineConfig
from deepspeed.runtime.config_utils import DeepSpeedConfigModel
from pydantic import Field, model_validator, field_validator

from mii.constants import DeploymentType, TaskType, ModelProvider
from mii.errors import DeploymentNotFoundError
from mii.modeling.tokenizers import MIITokenizerWrapper
from mii.pydantic_v1 import BaseModel, Field, root_validator, validator, Extra
from mii.utils import generate_deployment_name, get_default_task, import_score_file
from mii.utils import generate_deployment_name, import_score_file

DEVICE_MAP_DEFAULT = "auto"


class MIIConfigModel(BaseModel):
class Config:
validate_all = True
validate_assignment = True
use_enum_values = True
allow_population_by_field_name = True
extra = "forbid"
arbitrary_types_allowed = True


class GenerateParamsConfig(MIIConfigModel):
class GenerateParamsConfig(DeepSpeedConfigModel):
"""
Options for changing text-generation behavior.
"""
@@ -39,7 +39,7 @@ class GenerateParamsConfig(MIIConfigModel):
max_length: int = 1024
""" Maximum length of ``input_tokens`` + ``generated_tokens``. """

max_new_tokens: int = None
max_new_tokens: Optional[int] = None
""" Maximum number of new tokens generated. ``max_length`` takes precedent. """

min_new_tokens: int = 0
@@ -68,24 +59,25 @@

stop: List[str] = []
""" List of strings to stop generation at."""
@validator("stop", pre=True)
@field_validator("stop", mode="before")
@classmethod
def make_stop_string_list(cls, field_value: Union[str, List[str]]) -> List[str]:
if isinstance(field_value, str):
return [field_value]
return field_value

@validator("stop")
@field_validator("stop")
@classmethod
def sort_stop_strings(cls, field_value: List[str]) -> List[str]:
return sorted(field_value)

@root_validator
def check_prompt_length(cls, values: Dict[str, Any]) -> Dict[str, Any]:
prompt_length = values.get("prompt_length")
max_length = values.get("max_length")
assert max_length > prompt_length, f"max_length ({max_length}) must be greater than prompt_length ({prompt_length})"
return values
@model_validator(mode="after")
def check_prompt_length(self) -> "GenerateParamsConfig":
assert self.max_length > self.prompt_length, f"max_length ({self.max_length}) must be greater than prompt_length ({self.prompt_length})"
return self

@root_validator
@model_validator(mode="before")
@classmethod
def set_max_new_tokens(cls, values: Dict[str, Any]) -> Dict[str, Any]:
max_length = values.get("max_length")
max_new_tokens = values.get("max_new_tokens")
@@ -94,19 +86,16 @@ def set_max_new_tokens(cls, values: Dict[str, Any]) -> Dict[str, Any]:
values["max_new_tokens"] = max_length - prompt_length
return values

class Config:
extra = Extra.forbid


class ReplicaConfig(MIIConfigModel):
class ReplicaConfig(DeepSpeedConfigModel):
hostname: str = ""
tensor_parallel_ports: List[int] = []
torch_dist_port: int = None
torch_dist_port: Optional[int] = None
gpu_indices: List[int] = []
zmq_port: int = None
zmq_port: Optional[int] = None


class ModelConfig(MIIConfigModel):
class ModelConfig(DeepSpeedConfigModel):
model_name_or_path: str
"""
Model name or path of the HuggingFace model to be deployed.
@@ -192,8 +181,9 @@ class ModelConfig(MIIConfigModel):
def provider(self) -> ModelProvider:
return ModelProvider.HUGGING_FACE

@validator("device_map", pre=True)
def make_device_map_dict(cls, v):
@field_validator("device_map", mode="before")
@classmethod
def make_device_map_dict(cls, v: Any) -> Dict:
if isinstance(v, int):
return {"localhost": [[v]]}
if isinstance(v, list) and isinstance(v[0], int):
@@ -202,36 +192,36 @@ def make_device_map_dict(cls, v):
return {"localhost": v}
return v

@root_validator
@model_validator(mode="before")
@classmethod
def auto_fill_values(cls, values: Dict[str, Any]) -> Dict[str, Any]:
assert values.get("model_name_or_path"), "model_name_or_path must be provided"
if not values.get("tokenizer"):
values["tokenizer"] = values.get("model_name_or_path")
if not values.get("task"):
values["task"] = get_default_task(values.get("model_name_or_path"))
#if not values.get("task"):
# values["task"] = get_default_task(values.get("model_name_or_path"))
values["task"] = TaskType.TEXT_GENERATION
return values

@root_validator
def propagate_tp_size(cls, values: Dict[str, Any]) -> Dict[str, Any]:
tensor_parallel = values.get("tensor_parallel")
values.get("inference_engine_config").tensor_parallel.tp_size = tensor_parallel
return values

@root_validator
def propagate_quantization_mode(cls, values: Dict[str, Any]) -> Dict[str, Any]:
quantization_mode = values.get("quantization_mode")
values.get(
"inference_engine_config").quantization.quantization_mode = quantization_mode
return values
@model_validator(mode="after")
def propagate_tp_size(self) -> "ModelConfig":
self.inference_engine_config.tensor_parallel.tp_size = self.tensor_parallel
return self

@root_validator
def check_replica_config(cls, values: Dict[str, Any]) -> Dict[str, Any]:
num_replica_config = len(values.get("replica_configs"))
@model_validator(mode="after")
def check_replica_config(self) -> "ModelConfig":
num_replica_config = len(self.replica_configs)
if num_replica_config > 0:
assert num_replica_config == values.get("replica_num"), "Number of replica configs must match replica_num"
return values
assert num_replica_config == self.replica_num, "Number of replica configs must match replica_num"
return self

@model_validator(mode="after")
def propagate_quantization_mode(self) -> "ModelConfig":
self.inference_engine_config.quantization.quantization_mode = self.quantization_mode
return self


class MIIConfig(MIIConfigModel):
class MIIConfig(DeepSpeedConfigModel):
deployment_name: str = ""
"""
Name of the deployment. Used as an identifier for obtaining an inference
Expand All @@ -245,7 +235,7 @@ class MIIConfig(MIIConfigModel):
* `AML` will generate the assets necessary to deploy on AML resources.
"""

model_config: ModelConfig
model_conf: ModelConfig = Field(alias="model_config")
"""
Configuration for the deployed model(s).
"""
@@ -290,17 +280,18 @@
"""
AML instance type to use when creating AML deployment assets.
"""
@root_validator(skip_on_failure=True)
def AML_name_valid(cls, values: Dict[str, Any]) -> Dict[str, Any]:
if values.get("deployment_type") == DeploymentType.AML:
@model_validator(mode="after")
def AML_name_valid(self) -> "MIIConfig":
if self.deployment_type == DeploymentType.AML:
allowed_chars = set(string.ascii_lowercase + string.ascii_uppercase +
string.digits + "-")
assert (
set(values.get("deployment_name")) <= allowed_chars
set(self.deployment_name) <= allowed_chars
), "AML deployment names can only contain a-z, A-Z, 0-9, and '-'."
return values
return self

@root_validator(skip_on_failure=True)
@model_validator(mode="before")
@classmethod
def check_deployment_name(cls, values: Dict[str, Any]) -> Dict[str, Any]:
deployment_name = values.get("deployment_name")
if not deployment_name:
@@ -311,14 +302,14 @@ def check_deployment_name(cls, values: Dict[str, Any]) -> Dict[str, Any]:
return values

def generate_replica_configs(self) -> None:
if self.model_config.replica_configs:
if self.model_conf.replica_configs:
return
torch_dist_port = self.model_config.torch_dist_port
tensor_parallel = self.model_config.tensor_parallel
torch_dist_port = self.model_conf.torch_dist_port
tensor_parallel = self.model_conf.tensor_parallel
replica_pool = _allocate_devices(self.hostfile,
tensor_parallel,
self.model_config.replica_num,
self.model_config.device_map)
self.model_conf.replica_num,
self.model_conf.device_map)
replica_configs = []
for i, (hostname, gpu_indices) in enumerate(replica_pool):
# Reserve a port for the LB proxy when replication is enabled
@@ -332,10 +323,10 @@ def generate_replica_configs(self) -> None:
tensor_parallel_ports=tensor_parallel_ports,
torch_dist_port=replica_torch_dist_port,
gpu_indices=gpu_indices,
zmq_port=self.model_config.zmq_port_number + i,
zmq_port=self.model_conf.zmq_port_number + i,
))

self.model_config.replica_configs = replica_configs
self.model_conf.replica_configs = replica_configs


def _allocate_devices(hostfile_path: str,
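Most of the churn in config.py is the validator migration: @validator(..., pre=True) becomes @field_validator(..., mode="before") stacked with @classmethod, a pre-style @root_validator becomes @model_validator(mode="before") over the raw values dict, a post-style @root_validator becomes @model_validator(mode="after") as an instance method returning self, and the old inner class Config options (use_enum_values, allow_population_by_field_name, etc.) are dropped in favor of whatever DeepSpeedConfigModel now supplies. A condensed sketch of the three decorator patterns on a hypothetical model, not MII's actual GenerateParamsConfig:

from typing import Any, Dict, List, Optional, Union

from pydantic import BaseModel, field_validator, model_validator


class ExampleGenerateParams(BaseModel):
    # Hypothetical stand-in for GenerateParamsConfig.
    prompt_length: int = 0
    max_length: int = 1024
    max_new_tokens: Optional[int] = None
    stop: List[str] = []

    # v1: @validator("stop", pre=True)  ->  v2: mode="before" + @classmethod
    @field_validator("stop", mode="before")
    @classmethod
    def make_stop_string_list(cls, v: Union[str, List[str]]) -> List[str]:
        return [v] if isinstance(v, str) else v

    # v1: pre-style @root_validator  ->  v2: @model_validator(mode="before"),
    # still a classmethod over the raw input dict.
    @model_validator(mode="before")
    @classmethod
    def set_max_new_tokens(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        if values.get("max_new_tokens") is None:
            values["max_new_tokens"] = (
                values.get("max_length", 1024) - values.get("prompt_length", 0)
            )
        return values

    # v1: post-style @root_validator  ->  v2: @model_validator(mode="after"),
    # an instance method that returns self.
    @model_validator(mode="after")
    def check_prompt_length(self) -> "ExampleGenerateParams":
        assert self.max_length > self.prompt_length
        return self


print(ExampleGenerateParams(stop="</s>", prompt_length=16))
# prompt_length=16 max_length=1024 max_new_tokens=1008 stop=['</s>']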
mii/legacy/client.py (2 changes: 1 addition & 1 deletion)
@@ -37,7 +37,7 @@ def mii_query_handle(deployment_name):
return MIINonPersistentClient(task, deployment_name)

mii_config = _get_mii_config(deployment_name)
return MIIClient(mii_config.model_config.task,
return MIIClient(mii_config.model_conf.task,
"localhost", # TODO: This can probably be removed
mii_config.port_number)
