Merge branch 'release/0.3.1'
pziecina-nv committed Sep 26, 2023
2 parents 9d05bdd + f807f7b commit 7e93b40
Showing 5 changed files with 96 additions and 34 deletions.
3 changes: 1 addition & 2 deletions examples/nemo_megatron_gpt_multinode/client.py
@@ -86,7 +86,6 @@ def main():
         action="store_true",
         help="Enable verbose logging",
     )
-
     args = parser.parse_args()

     log_level = logging.DEBUG if args.verbose else logging.INFO
@@ -122,7 +121,7 @@ def _param(dtype, value):
     result_dict = client.infer_batch(
         tasks=tasks,
         prompts=prompts,
-        min_length=_param(np.int32, 0),
+        min_length=_param(np.int32, 20),
         max_length=_param(np.int32, args.output_len),
         use_greedy=_param(np.bool_, True),
         temperature=_param(np.float32, 1.0),
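Note on the hunk above: `_param` (visible only in the hunk header) wraps a scalar sampling parameter for `client.infer_batch`. Its body is not part of this diff, so the following is a hypothetical sketch, assuming Triton expects one value per batch item with shape (batch_size, 1):

import numpy as np

def _param(dtype, value, batch_size=1):
    # Hypothetical helper: replicate a scalar parameter for every item in the batch.
    # The real implementation in client.py is not shown in this diff and may differ.
    return np.full((batch_size, 1), value, dtype=dtype)

With this change the client requests at least 20 generated tokens (`min_length=_param(np.int32, 20)`) instead of allowing empty completions.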
6 changes: 2 additions & 4 deletions examples/nemo_megatron_gpt_multinode/gpt.py
@@ -67,7 +67,7 @@ def _format_prompts(

     @batch
     @group_by_values("tasks", *_INPUT_PARAMETERS_NAMES, pad_fn=ConstantPadder(0))
-    @first_value(*_INPUT_PARAMETERS_NAMES)
+    @first_value(*_INPUT_PARAMETERS_NAMES, strict=False)
     def infer(self, **inputs: np.ndarray) -> typing.Dict[str, np.ndarray]:
         # Tell other ranks we're doing generate
         generate_num = 0
@@ -82,19 +82,17 @@ def _str_ndarray2list(str_ndarray: np.ndarray) -> typing.List[str]:

         tasks = _str_ndarray2list(inputs.pop("tasks"))
         prompts = _str_ndarray2list(inputs.pop("prompts"))
-
         length_params = LengthParam(**{k: v for k, v in inputs.items() if k in typing.get_type_hints(LengthParam)})
         sampling_params = SamplingParam(
             **{k: v for k, v in inputs.items() if k in typing.get_type_hints(SamplingParam)}
         )
-
         if tasks[0] == "text_generation":
             generate_fn = self._text_generate_fn
         else:
             generate_fn = self._task_generate_fn
             if generate_fn is None:
                 raise PyTritonInvalidOperationError(
-                    f"Model {self.model_name} does not support task {inputs['task']}. "
+                    f"Model {self.model_name} does not support task {tasks[0]}. "
                     "Only text_generation task is supported."
                 )
         prompts = self._format_prompts(tasks, prompts)
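A reading of the `strict=False` change (not stated in the diff itself): `@first_value` collapses each listed input from a (batch, 1) array to its first element, and with the default `strict=True` it also asserts that every value in the batch is equal. Since `@group_by_values` has already partitioned the batch into sub-batches with identical parameter values, that assertion is redundant here, and relaxing it avoids extra validation work. A minimal sketch of the decorator stack, assuming `pytriton.decorators` exports all three decorators and `ConstantPadder`:

import numpy as np
from pytriton.decorators import ConstantPadder, batch, first_value, group_by_values

@batch
@group_by_values("temperature", pad_fn=ConstantPadder(0))
@first_value("temperature", strict=False)
def infer(**inputs: np.ndarray):
    temperature = inputs["temperature"]  # a scalar now, not a (batch, 1) array
    prompts = inputs["prompts"]
    return {"outputs": prompts}  # echo sketch; a real callable would generate text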
22 changes: 19 additions & 3 deletions examples/nemo_megatron_gpt_multinode/helpers.py
@@ -17,6 +17,7 @@
 import socket
 import typing
 import warnings
+from typing import Dict, Tuple, Type, Union

 import filelock
 import huggingface_hub  # pytype: disable=import-error
@@ -76,13 +77,26 @@ def _map_type(type_):
         else:
             raise PyTritonBadParameterError(f"Unknown type {type_}")

-    def _get_tensor_params(type_):
+    def _get_tensor_params(type_: Type) -> Dict[str, Union[Tuple[int, ...], type]]:
+        """
+        Returns a shape and a type of a Triton tensor. The shape and the type are inferred from a
+        Python typing.
+
+        Args:
+            type_: a Python typing which should be a single type or a nested ``List``. If ``type_`` is a plain
+                type, then the shape is ``(1,)``. If ``type_`` is a nested ``List``, then ``-1`` is added for each
+                ``List``, e.g. ``List[int]`` -> ``(1, -1)``, ``List[List[int]]`` -> ``(1, -1, -1)``.
+                Please note that all shapes have an additional ``(1,)`` leading dimension.
+
+        Returns:
+            a dictionary with 2 elements: ``"shape"`` and ``"dtype"``. ``"dtype"`` is the numpy type which
+            corresponds to ``type_``.
+        """
         count = 0
         while typing.get_origin(type_) is list:
             type_ = typing.get_args(type_)[0]
             count += 1
-        count -= 1  # we don't want to count the last dimension
-        shape = (-1,) * count if count > 0 else (1,)
+        shape = (1,) + (-1,) * count
         return {"shape": shape, "dtype": _map_type(type_)}

     overwrite_kwargs = overwrite_kwargs or {}
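The behavioral change in `_get_tensor_params` is easiest to see on concrete typings: the old code decremented the nesting count, so `List[int]` mapped to shape `(1,)` (indistinguishable from a plain `int`) and `List[List[int]]` to `(-1,)`; the new code keeps one `-1` per `List` behind the fixed leading batch dimension. A self-contained check, with `_map_type` stubbed since the real one is a sibling nested function:

import typing
from typing import Dict, List, Tuple, Type, Union

import numpy as np

def _map_type(type_):  # stub for illustration only
    return {int: np.int32, float: np.float32, str: bytes}[type_]

def _get_tensor_params(type_: Type) -> Dict[str, Union[Tuple[int, ...], type]]:
    count = 0
    while typing.get_origin(type_) is list:
        type_ = typing.get_args(type_)[0]
        count += 1
    shape = (1,) + (-1,) * count
    return {"shape": shape, "dtype": _map_type(type_)}

assert _get_tensor_params(int)["shape"] == (1,)
assert _get_tensor_params(List[int])["shape"] == (1, -1)  # old code returned (1,)
assert _get_tensor_params(List[List[int]])["shape"] == (1, -1, -1)  # old code returned (-1,)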
@@ -179,6 +193,8 @@ def load_model(
     LOGGER.debug(f"Loading {model_path} on {worker_name}")

     save_restore_connector = NLPSaveRestoreConnector()
+    if model_path.is_dir():
+        save_restore_connector.model_extracted_dir = model_path.as_posix()
     pretrained_cfg = save_restore_connector.restore_from(
         None, model_path.as_posix(), return_config=True, trainer=trainer
     )
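The `load_model` hunk lets callers pass either a packed .nemo archive or an already-extracted checkpoint directory: when the path is a directory, NeMo's `NLPSaveRestoreConnector` is pointed at it via `model_extracted_dir` so `restore_from` reuses the files instead of unpacking again. A sketch of the pattern; the local path below is hypothetical:

from pathlib import Path

from nemo.collections.nlp.parts.nlp_overrides import NLPSaveRestoreConnector

model_path = Path("/models/nemo-megatron-gpt-1.3B")  # hypothetical pre-extracted checkpoint
connector = NLPSaveRestoreConnector()
if model_path.is_dir():
    # Skip re-extraction of the .nemo tar archive on restore_from().
    connector.model_extracted_dir = model_path.as_posix()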
59 changes: 48 additions & 11 deletions examples/nemo_megatron_gpt_multinode/server.py
@@ -15,8 +15,10 @@
 """Text generation server with NeMo Megatron GPT model."""
 import argparse
 import logging
+from pathlib import Path

 import torch  # pytype: disable=import-error
+import yaml
 from nemo.collections.nlp.modules.common.text_generation_utils import generate  # pytype: disable=import-error
 from nemo.collections.nlp.parts.nlp_overrides import NLPDDPStrategy  # pytype: disable=import-error
 from pytorch_lightning.trainer.trainer import Trainer  # pytype: disable=import-error
@@ -39,6 +41,10 @@
 DEFAULT_LOG_FORMAT = "%(asctime)s - %(levelname)8s - %(process)8d - %(threadName)s - %(name)s: %(message)s"


+def _resolved_path(path_str):
+    return Path(path_str).resolve()
+
+
 def main():
     """Main function."""
     parser = argparse.ArgumentParser(description=__doc__)
@@ -56,7 +62,8 @@ def main():
         type=int,
         help="Number of nodes to load model on",
     )
-    parser.add_argument(
+    model_location = parser.add_mutually_exclusive_group()
+    model_location.add_argument(
         "--model-repo-id",
         default="nvidia/nemo-megatron-gpt-1.3B",
         help="Model repository id on HuggingFace Hub",
@@ -65,6 +72,12 @@
         "--model-filename",
         help="Path to the model .nemo file in the HF hub. If not provided, the first .nemo file on the list will be used.",
     )
+    model_location.add_argument(
+        "--model-path",
+        help="Path to the model .nemo file in the local file system. This argument has a higher priority "
+        "than `--model-repo-id`.",
+        type=_resolved_path,
+    )
     parser.add_argument("--prompt-model-path", help="Path to the model prompt nemo file")
     parser.add_argument(
         "--timeout",
@@ -79,7 +92,22 @@
         action="store_true",
         help="Enable verbose logging",
     )
-
+    parser.add_argument(
+        "--triton-config",
+        type=_resolved_path,
+        help="A path to a YAML config for Triton. You may find the allowed fields in `pytriton.triton.TritonConfig`",
+    )
+    parser.add_argument(
+        "--model-name",
+        default="GPT",
+        help="A name of the Megatron model inside Triton.",
+    )
+    parser.add_argument(
+        "--workspace",
+        type=_resolved_path,
+        help="Path to a directory where the workspace has to be created (optional). "
+        "If not provided, a workspace with a random name will be created in the ~/.cache/pytriton directory.",
+    )
     args = parser.parse_args()

     log_level = logging.DEBUG if args.verbose else logging.INFO
@@ -89,25 +117,34 @@
     logger.info("Initialize trainer:")
     logger.info(f" devices: {args.gpus}")
     logger.info(f" nodes: {args.nodes}")

     trainer = Trainer(
         strategy=NLPDDPStrategy(),
         devices=args.gpus,
-        num_nodes=args.nodes,
         accelerator="gpu",
-        logger=False,
+        num_nodes=args.nodes,
         precision=16,
+        logger=False,
         enable_checkpointing=False,
         replace_sampler_ddp=False,
     )

-    model_path = download_hf_model(args.model_repo_id, args.model_filename)
-    model = load_model(model_path, trainer, prompt_learning_model_path=args.prompt_model_path)
+    if args.model_path is not None:
+        model = load_model(args.model_path, trainer, prompt_learning_model_path=args.prompt_model_path)
+    else:
+        model_path = download_hf_model(args.model_repo_id, args.model_filename)
+        model = load_model(model_path, trainer, prompt_learning_model_path=args.prompt_model_path)

     app_state = setup_distributed_environment(trainer)
     if app_state.global_rank == 0:
-        infer_callable = NemoGptCallable(model_name="GPT", model=model)
-        triton_config = TritonConfig(http_address=ENDPOINT_BIND_ADDRESS, http_port=HTTP_PORT)
-        with Triton(config=triton_config) as triton:
+        infer_callable = NemoGptCallable(model_name=args.model_name, model=model)
+        if args.triton_config is None:
+            triton_config = TritonConfig(http_address=ENDPOINT_BIND_ADDRESS, http_port=HTTP_PORT)
+        else:
+            with open(args.triton_config) as f:
+                data = yaml.safe_load(f)
+            triton_config = TritonConfig.from_dict(data)
+        with Triton(config=triton_config, workspace=args.workspace) as triton:
             triton.bind(
                 model_name=infer_callable.model_name,
                 infer_func=infer_callable.infer,
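Taken together, the new flags let the server run a local checkpoint under a custom Triton configuration and workspace. An illustrative config file and invocation follow; the YAML keys must match `pytriton.triton.TritonConfig` field names, and all paths here are examples, not from the diff:

# triton_config.yaml -- illustrative values
http_address: 0.0.0.0
http_port: 8000
grpc_port: 8001
log_verbose: 1

python server.py \
    --model-path /models/gpt.nemo \
    --model-name GPT \
    --triton-config triton_config.yaml \
    --workspace /tmp/pytriton-workspace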
40 changes: 26 additions & 14 deletions pytriton/triton.py
@@ -43,7 +43,7 @@
 import threading
 import threading as th
 import typing
-from typing import Callable, Dict, List, Optional, Sequence, Union
+from typing import Any, Callable, Dict, List, Optional, Sequence, Union

 import typing_inspect

@@ -226,23 +226,15 @@ def to_dict(self):
         return dataclasses.asdict(self)

     @classmethod
-    def from_env(cls) -> "TritonConfig":
-        """Creates TritonConfig from environment variables.
-
-        Environment variables should start with `PYTRITON_TRITON_CONFIG_` prefix. For example:
-
-            PYTRITON_TRITON_CONFIG_GRPC_PORT=45436
-            PYTRITON_TRITON_CONFIG_LOG_VERBOSE=4
-
-        Typical use:
-
-            triton_config = TritonConfig.from_env()
+    def from_dict(cls, config: Dict[str, Any]) -> "TritonConfig":
+        """Creates a ``TritonConfig`` instance from an input dictionary. Values are converted into correct types.
+
+        Args:
+            config: a dictionary with all required fields

         Returns:
-            TritonConfig class instantiated from environment variables.
+            a ``TritonConfig`` instance
         """
-        prefix = "PYTRITON_TRITON_CONFIG_"
-        config = {name[len(prefix) :].lower(): value for name, value in os.environ.items() if name.startswith(prefix)}
         fields: Dict[str, dataclasses.Field] = {field.name: field for field in dataclasses.fields(cls)}
         unknown_config_parameters = {name: value for name, value in config.items() if name not in fields}
         for name, value in unknown_config_parameters.items():
@@ -263,6 +255,26 @@ def _cast_value(_field, _value):
         }
         return cls(**config_with_casted_values)

+    @classmethod
+    def from_env(cls) -> "TritonConfig":
+        """Creates TritonConfig from environment variables.
+
+        Environment variables should start with `PYTRITON_TRITON_CONFIG_` prefix. For example:
+
+            PYTRITON_TRITON_CONFIG_GRPC_PORT=45436
+            PYTRITON_TRITON_CONFIG_LOG_VERBOSE=4
+
+        Typical use:
+
+            triton_config = TritonConfig.from_env()
+
+        Returns:
+            TritonConfig class instantiated from environment variables.
+        """
+        prefix = "PYTRITON_TRITON_CONFIG_"
+        config = {name[len(prefix) :].lower(): value for name, value in os.environ.items() if name.startswith(prefix)}
+        return cls.from_dict(config)
+

 class _LogLevelChecker:
     """Check if log level is too verbose."""
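With this refactor, `from_env` becomes a thin wrapper: it collects the `PYTRITON_TRITON_CONFIG_*` variables and delegates to the new `from_dict`, so the YAML path in server.py and the environment path share the same field validation and `_cast_value` type casting. A usage sketch with illustrative values:

import os

from pytriton.triton import TritonConfig

# String values, e.g. parsed from YAML or the environment, are cast to the
# dataclass field types by from_dict.
config = TritonConfig.from_dict({"http_port": "8000", "log_verbose": "1"})

os.environ["PYTRITON_TRITON_CONFIG_GRPC_PORT"] = "45436"
config = TritonConfig.from_env()  # equivalent to from_dict over the prefixed variables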
