From 14dd6edaf2be0f2d5ef5b1fdf4feef7b5cb3f337 Mon Sep 17 00:00:00 2001
From: Anton Peganov
Date: Tue, 5 Sep 2023 12:37:03 -0700
Subject: [PATCH] Update Megatron example to support the latest changes in
 NeMo

---
 .../nemo_megatron_gpt_multinode/client.py   |  3 +-
 examples/nemo_megatron_gpt_multinode/gpt.py |  6 +-
 .../nemo_megatron_gpt_multinode/helpers.py  | 22 ++++++-
 .../nemo_megatron_gpt_multinode/server.py   | 59 +++++++++++++++----
 pytriton/triton.py                          | 40 ++++++++-----
 5 files changed, 96 insertions(+), 34 deletions(-)

diff --git a/examples/nemo_megatron_gpt_multinode/client.py b/examples/nemo_megatron_gpt_multinode/client.py
index 88bcd50..8178ad9 100755
--- a/examples/nemo_megatron_gpt_multinode/client.py
+++ b/examples/nemo_megatron_gpt_multinode/client.py
@@ -86,7 +86,6 @@ def main():
         action="store_true",
         help="Enable verbose logging",
     )
-
     args = parser.parse_args()
 
     log_level = logging.DEBUG if args.verbose else logging.INFO
@@ -122,7 +121,7 @@ def _param(dtype, value):
     result_dict = client.infer_batch(
         tasks=tasks,
         prompts=prompts,
-        min_length=_param(np.int32, 0),
+        min_length=_param(np.int32, 20),
         max_length=_param(np.int32, args.output_len),
         use_greedy=_param(np.bool_, True),
         temperature=_param(np.float32, 1.0),
diff --git a/examples/nemo_megatron_gpt_multinode/gpt.py b/examples/nemo_megatron_gpt_multinode/gpt.py
index d1d2904..aa4cc97 100644
--- a/examples/nemo_megatron_gpt_multinode/gpt.py
+++ b/examples/nemo_megatron_gpt_multinode/gpt.py
@@ -67,7 +67,7 @@ def _format_prompts(
 
     @batch
     @group_by_values("tasks", *_INPUT_PARAMETERS_NAMES, pad_fn=ConstantPadder(0))
-    @first_value(*_INPUT_PARAMETERS_NAMES)
+    @first_value(*_INPUT_PARAMETERS_NAMES, strict=False)
     def infer(self, **inputs: np.ndarray) -> typing.Dict[str, np.ndarray]:
         # Tell other ranks we're doing generate
         generate_num = 0
@@ -82,19 +82,17 @@ def _str_ndarray2list(str_ndarray: np.ndarray) -> typing.List[str]:
 
         tasks = _str_ndarray2list(inputs.pop("tasks"))
         prompts = _str_ndarray2list(inputs.pop("prompts"))
-
         length_params = LengthParam(**{k: v for k, v in inputs.items() if k in typing.get_type_hints(LengthParam)})
         sampling_params = SamplingParam(
             **{k: v for k, v in inputs.items() if k in typing.get_type_hints(SamplingParam)}
         )
-
         if tasks[0] == "text_generation":
             generate_fn = self._text_generate_fn
         else:
             generate_fn = self._task_generate_fn
             if generate_fn is None:
                 raise PyTritonInvalidOperationError(
-                    f"Model {self.model_name} does not support task {inputs['task']}. "
+                    f"Model {self.model_name} does not support task {tasks[0]}. "
                     "Only text_generation task is supported."
                 )
             prompts = self._format_prompts(tasks, prompts)
diff --git a/examples/nemo_megatron_gpt_multinode/helpers.py b/examples/nemo_megatron_gpt_multinode/helpers.py
index 8a6a716..3f95e6d 100644
--- a/examples/nemo_megatron_gpt_multinode/helpers.py
+++ b/examples/nemo_megatron_gpt_multinode/helpers.py
@@ -17,6 +17,7 @@
 import socket
 import typing
 import warnings
+from typing import Dict, Tuple, Type, Union
 
 import filelock
 import huggingface_hub  # pytype: disable=import-error
@@ -76,13 +77,26 @@ def _map_type(type_):
         else:
             raise PyTritonBadParameterError(f"Unknown type {type_}")
 
-    def _get_tensor_params(type_):
+    def _get_tensor_params(type_: Type) -> Dict[str, Union[Tuple[int, ...], type]]:
+        """
+        Returns the shape and dtype of a Triton tensor, both inferred from a Python type
+        annotation.
+
+        Args:
+            type_: a Python type annotation, either a plain type or a (possibly nested) ``List``. If ``type_``
+                is a plain type, the shape is ``(1,)``; if it is a nested ``List``, a ``-1`` is appended
+                for each ``List`` level, e.g. ``List[int]`` -> ``(1, -1)`` and ``List[List[int]]`` -> ``(1, -1, -1)``.
+                Note that every shape carries an additional leading ``(1,)`` dimension.
+
+        Returns:
+            a dictionary with two entries, ``"shape"`` and ``"dtype"``, where ``"dtype"`` is the numpy type
+            corresponding to ``type_``.
+        """
         count = 0
         while typing.get_origin(type_) is list:
             type_ = typing.get_args(type_)[0]
             count += 1
-        count -= 1  # we don't want to count the last dimension
-        shape = (-1,) * count if count > 0 else (1,)
+        shape = (1,) + (-1,) * count
         return {"shape": shape, "dtype": _map_type(type_)}
 
     overwrite_kwargs = overwrite_kwargs or {}
@@ -179,6 +193,8 @@ def load_model(
     LOGGER.debug(f"Loading {model_path} on {worker_name}")
 
     save_restore_connector = NLPSaveRestoreConnector()
+    if model_path.is_dir():
+        save_restore_connector.model_extracted_dir = model_path.as_posix()
     pretrained_cfg = save_restore_connector.restore_from(
         None, model_path.as_posix(), return_config=True, trainer=trainer
     )
diff --git a/examples/nemo_megatron_gpt_multinode/server.py b/examples/nemo_megatron_gpt_multinode/server.py
index 94101d1..52231dd 100755
--- a/examples/nemo_megatron_gpt_multinode/server.py
+++ b/examples/nemo_megatron_gpt_multinode/server.py
@@ -15,8 +15,10 @@
 """Text generation server with NeMo Megatron GPT model."""
 import argparse
 import logging
+from pathlib import Path
 
 import torch  # pytype: disable=import-error
+import yaml
 from nemo.collections.nlp.modules.common.text_generation_utils import generate  # pytype: disable=import-error
 from nemo.collections.nlp.parts.nlp_overrides import NLPDDPStrategy  # pytype: disable=import-error
 from pytorch_lightning.trainer.trainer import Trainer  # pytype: disable=import-error
@@ -39,6 +41,10 @@
 DEFAULT_LOG_FORMAT = "%(asctime)s - %(levelname)8s - %(process)8d - %(threadName)s - %(name)s: %(message)s"
 
 
+def _resolved_path(path_str):
+    return Path(path_str).resolve()
+
+
 def main():
     """Main function."""
     parser = argparse.ArgumentParser(description=__doc__)
@@ -56,7 +62,8 @@ def main():
         type=int,
         help="Number of nodes to load model on",
     )
-    parser.add_argument(
+    model_location = parser.add_mutually_exclusive_group()
+    model_location.add_argument(
         "--model-repo-id",
         default="nvidia/nemo-megatron-gpt-1.3B",
         help="Model repository id on HuggingFace Hub",
@@ -65,6 +72,12 @@ def main():
         "--model-filename",
         help="Path to the model nemo file in HF hub. If not provided first on the list .nemo file will be used.",
     )
+    model_location.add_argument(
+        "--model-path",
+        help="Path to the model .nemo file on the local file system. This option is "
+        "mutually exclusive with `--model-repo-id`.",
+        type=_resolved_path,
+    )
     parser.add_argument("--prompt-model-path", help="Path to the model prompt nemo file")
     parser.add_argument(
         "--timeout",
@@ -79,7 +92,22 @@ def main():
         action="store_true",
         help="Enable verbose logging",
     )
-
+    parser.add_argument(
+        "--triton-config",
+        type=_resolved_path,
+        help="Path to a YAML config for Triton. See `pytriton.triton.TritonConfig` for the allowed fields.",
+    )
+    parser.add_argument(
+        "--model-name",
+        default="GPT",
+        help="Name under which the Megatron model is served by Triton.",
+    )
+    parser.add_argument(
+        "--workspace",
+        type=_resolved_path,
+        help="Path to a directory in which the workspace will be created (optional). "
+ "If not provided workspace with random name will be created in ~/.cache/pytriton directory.", + ) args = parser.parse_args() log_level = logging.DEBUG if args.verbose else logging.INFO @@ -89,25 +117,34 @@ def main(): logger.info("Initialize trainer:") logger.info(f" devices: {args.gpus}") logger.info(f" nodes: {args.nodes}") + trainer = Trainer( strategy=NLPDDPStrategy(), devices=args.gpus, - num_nodes=args.nodes, accelerator="gpu", - logger=False, + num_nodes=args.nodes, precision=16, + logger=False, + enable_checkpointing=False, + replace_sampler_ddp=False, ) - - model_path = download_hf_model(args.model_repo_id, args.model_filename) - model = load_model(model_path, trainer, prompt_learning_model_path=args.prompt_model_path) + if args.model_path is not None: + model = load_model(args.model_path, trainer, prompt_learning_model_path=args.prompt_model_path) + else: + model_path = download_hf_model(args.model_repo_id, args.model_filename) + model = load_model(model_path, trainer, prompt_learning_model_path=args.prompt_model_path) app_state = setup_distributed_environment(trainer) if app_state.global_rank == 0: - infer_callable = NemoGptCallable(model_name="GPT", model=model) - - triton_config = TritonConfig(http_address=ENDPOINT_BIND_ADDRESS, http_port=HTTP_PORT) - with Triton(config=triton_config) as triton: + infer_callable = NemoGptCallable(model_name=args.model_name, model=model) + if args.triton_config is None: + triton_config = TritonConfig(http_address=ENDPOINT_BIND_ADDRESS, http_port=HTTP_PORT) + else: + with open(args.triton_config) as f: + data = yaml.safe_load(f) + triton_config = TritonConfig.from_dict(data) + with Triton(config=triton_config, workspace=args.workspace) as triton: triton.bind( model_name=infer_callable.model_name, infer_func=infer_callable.infer, diff --git a/pytriton/triton.py b/pytriton/triton.py index 87b490b..1834ac6 100644 --- a/pytriton/triton.py +++ b/pytriton/triton.py @@ -43,7 +43,7 @@ import threading import threading as th import typing -from typing import Callable, Dict, List, Optional, Sequence, Union +from typing import Any, Callable, Dict, List, Optional, Sequence, Union import typing_inspect @@ -226,23 +226,15 @@ def to_dict(self): return dataclasses.asdict(self) @classmethod - def from_env(cls) -> "TritonConfig": - """Creates TritonConfig from environment variables. - - Environment variables should start with `PYTRITON_TRITON_CONFIG_` prefix. For example: - - PYTRITON_TRITON_CONFIG_GRPC_PORT=45436 - PYTRITON_TRITON_CONFIG_LOG_VERBOSE=4 - - Typical use: + def from_dict(cls, config: Dict[str, Any]) -> "TritonConfig": + """Creates a ``TritonConfig`` instance from an input dictionary. Values are converted into correct types. - triton_config = TritonConfig.from_env() + Args: + config: a dictionary with all required fields Returns: - TritonConfig class instantiated from environment variables. + a ``TritonConfig`` instance """ - prefix = "PYTRITON_TRITON_CONFIG_" - config = {name[len(prefix) :].lower(): value for name, value in os.environ.items() if name.startswith(prefix)} fields: Dict[str, dataclasses.Field] = {field.name: field for field in dataclasses.fields(cls)} unknown_config_parameters = {name: value for name, value in config.items() if name not in fields} for name, value in unknown_config_parameters.items(): @@ -263,6 +255,26 @@ def _cast_value(_field, _value): } return cls(**config_with_casted_values) + @classmethod + def from_env(cls) -> "TritonConfig": + """Creates TritonConfig from environment variables. 
+
+        Environment variables should start with the `PYTRITON_TRITON_CONFIG_` prefix. For example:
+
+            PYTRITON_TRITON_CONFIG_GRPC_PORT=45436
+            PYTRITON_TRITON_CONFIG_LOG_VERBOSE=4
+
+        Typical use:
+
+            triton_config = TritonConfig.from_env()
+
+        Returns:
+            TritonConfig class instantiated from environment variables.
+        """
+        prefix = "PYTRITON_TRITON_CONFIG_"
+        config = {name[len(prefix) :].lower(): value for name, value in os.environ.items() if name.startswith(prefix)}
+        return cls.from_dict(config)
+
 
 class _LogLevelChecker:
     """Check if log level is too verbose."""
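
Note (outside the applied diff): the shape inference documented on `_get_tensor_params` is easy to sanity-check in isolation. Below is a minimal, self-contained sketch of the same logic; `_map_type` here is a simplified stand-in for the example's private type mapper (the real helper in helpers.py raises `PyTritonBadParameterError` for unknown types, as visible in the hunk context), and its mapping is illustrative only:

    import typing
    from typing import Dict, Tuple, Type, Union

    import numpy as np


    def _map_type(type_: Type) -> type:
        # Stand-in mapper for this sketch; not the real implementation.
        return {int: np.int64, float: np.float64, bool: np.bool_}[type_]


    def _get_tensor_params(type_: Type) -> Dict[str, Union[Tuple[int, ...], type]]:
        count = 0
        while typing.get_origin(type_) is list:  # peel one level per List[...]
            type_ = typing.get_args(type_)[0]
            count += 1
        shape = (1,) + (-1,) * count  # leading (1,) plus one -1 per List level
        return {"shape": shape, "dtype": _map_type(type_)}


    assert _get_tensor_params(int)["shape"] == (1,)
    assert _get_tensor_params(typing.List[int])["shape"] == (1, -1)
    assert _get_tensor_params(typing.List[typing.List[int]])["shape"] == (1, -1, -1)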
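Similarly, the `--triton-config` flow added to server.py reduces to `yaml.safe_load` followed by `TritonConfig.from_dict`. A sketch with a hypothetical config file; the file name and values are illustrative, while the field names come from this patch itself (`http_address`/`http_port` are used in server.py, `grpc_port`/`log_verbose` appear in the `from_env` docstring):

    # Hypothetical triton_config.yaml:
    #
    #     http_address: 0.0.0.0
    #     http_port: 8000
    #     grpc_port: 45436
    #     log_verbose: 4

    import yaml

    from pytriton.triton import TritonConfig

    with open("triton_config.yaml") as f:
        triton_config = TritonConfig.from_dict(yaml.safe_load(f))

Because `from_dict` casts values to the declared dataclass field types and separates keys that match no `TritonConfig` field into `unknown_config_parameters` instead of passing them to the constructor, the same loader backs both this YAML path and the refactored `from_env`.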