Commit

test
IlyasMoutawwakil committed Dec 16, 2024
1 parent ecaa6c8 commit ea32802
Showing 2 changed files with 110 additions and 81 deletions.
150 changes: 99 additions & 51 deletions optimum_benchmark/backends/py_txi/backend.py
@@ -31,11 +31,6 @@ def load(self) -> None:
else:
self.logger.info("\t+ Downloading pretrained model")
self.download_pretrained_model()

if self.config.task in TEXT_GENERATION_TASKS:
self.logger.info("\t+ Preparing generation config")
self.prepare_generation_config()

self.logger.info("\t+ Loading pretrained model")
self.load_model_from_pretrained()

@@ -50,86 +45,139 @@ def download_pretrained_model(self) -> None:
with init_empty_weights(include_buffers=True):
self.automodel_loader.from_pretrained(self.config.model, **self.config.model_kwargs, cache_dir=self.volume)

def prepare_generation_config(self) -> None:
self.generation_config.eos_token_id = None
self.generation_config.pad_token_id = None
model_cache_folder = f"models/{self.config.model}".replace("/", "--")
model_cache_path = f"{self.volume}/{model_cache_folder}"
snapshot_file = f"{model_cache_path}/refs/{self.config.model_kwargs.get('revision', 'main')}"
snapshot_ref = open(snapshot_file, "r").read().strip()
model_snapshot_path = f"{model_cache_path}/snapshots/{snapshot_ref}"
self.logger.info("\t+ Saving new pretrained generation config")
self.generation_config.save_pretrained(save_directory=model_snapshot_path)
if self.config.task in TEXT_GENERATION_TASKS:
self.logger.info("\t+ Preparing generation config")
self.generation_config.eos_token_id = None
self.generation_config.pad_token_id = None
model_cache_folder = f"models/{self.config.model}".replace("/", "--")
model_cache_path = f"{self.volume}/{model_cache_folder}"
snapshot_file = f"{model_cache_path}/refs/{self.config.model_kwargs.get('revision', 'main')}"
snapshot_ref = open(snapshot_file, "r").read().strip()
model_snapshot_path = f"{model_cache_path}/snapshots/{snapshot_ref}"
self.logger.info("\t+ Saving pretrained generation config")
self.generation_config.save_pretrained(save_directory=model_snapshot_path)

def create_no_weights_model(self) -> None:
self.no_weights_model = os.path.join(self.tmpdir.name, "no_weights_model")
filename = os.path.join(self.no_weights_model, "model.safetensors")
os.makedirs(self.no_weights_model, exist_ok=True)
state_dict = torch.nn.Linear(1, 1).state_dict()
safetensor = os.path.join(self.no_weights_model, "model.safetensors")
save_file(tensors=state_dict, filename=safetensor, metadata={"format": "pt"})

save_file(tensors=torch.nn.Linear(1, 1).state_dict(), filename=filename, metadata={"format": "pt"})
with fast_weights_init():
# unlike Transformers, TXI won't accept any missing tensors so we need to materialize the model
self.pretrained_model = self.automodel_loader.from_pretrained(
self.no_weights_model, **self.config.model_kwargs, device_map="auto", _fast_init=False
)
save_file(tensors=self.pretrained_model.state_dict(), filename=safetensor, metadata={"format": "pt"})
save_file(tensors=self.pretrained_model.state_dict(), filename=filename, metadata={"format": "pt"})
del self.pretrained_model
torch.cuda.empty_cache()

if self.pretrained_config is not None:
self.pretrained_config.save_pretrained(save_directory=self.no_weights_model)
if self.pretrained_processor is not None:
self.pretrained_processor.save_pretrained(save_directory=self.no_weights_model)
if self.generation_config is not None:

if self.config.task in TEXT_GENERATION_TASKS:
self.generation_config.eos_token_id = None
self.generation_config.pad_token_id = None
self.generation_config.save_pretrained(save_directory=self.no_weights_model)

def load_model_with_no_weights(self) -> None:
self.config.volumes = (self.config.volumes, {self.tmpdir.name: {"bind": self.tmpdir.name, "mode": "rw"}})
self.config.volumes = {self.tmpdir.name: {"bind": self.tmpdir.name, "mode": "rw"}}
original_model, self.config.model = self.config.model, self.no_weights_model
self.load_model_from_pretrained()
self.config.model, self.config.volumes = original_model
self.config.model = original_model

def load_model_from_pretrained(self) -> None:
if self.config.task in TEXT_GENERATION_TASKS:
self.pretrained_model = TGI(
config=TGIConfig(
model_id=self.config.model,
gpus=self.config.gpus,
devices=self.config.devices,
volumes=self.config.volumes,
environment=self.config.environment,
ports=self.config.ports,
dtype=self.config.dtype,
sharded=self.config.sharded,
quantize=self.config.quantize,
num_shard=self.config.num_shard,
speculate=self.config.speculate,
cuda_graphs=self.config.cuda_graphs,
disable_custom_kernels=self.config.disable_custom_kernels,
trust_remote_code=self.config.trust_remote_code,
max_concurrent_requests=self.config.max_concurrent_requests,
),
config=TGIConfig(self.config.model, **self.txi_kwargs, **self.tgi_kwargs),
)

elif self.config.task in TEXT_EMBEDDING_TASKS:
self.pretrained_model = TEI(
config=TEIConfig(
model_id=self.config.model,
gpus=self.config.gpus,
devices=self.config.devices,
volumes=self.config.volumes,
environment=self.config.environment,
ports=self.config.ports,
dtype=self.config.dtype,
pooling=self.config.pooling,
max_concurrent_requests=self.config.max_concurrent_requests,
),
config=TEIConfig(self.config.model, **self.txi_kwargs, **self.tei_kwargs),
)
else:
raise NotImplementedError(f"TXI does not support task {self.config.task}")

@property
def txi_kwargs(self):
kwargs = {}

if self.config.gpus is not None:
kwargs["gpus"] = self.config.gpus

if self.config.image is not None:
kwargs["image"] = self.config.image

if self.config.ports is not None:
kwargs["ports"] = self.config.ports

if self.config.volumes is not None:
kwargs["volumes"] = self.config.volumes

if self.config.devices is not None:
kwargs["devices"] = self.config.devices

if self.config.shm_size is not None:
kwargs["shm_size"] = self.config.shm_size

if self.config.environment is not None:
kwargs["environment"] = self.config.environment

if self.config.connection_timeout is not None:
kwargs["connection_timeout"] = self.config.connection_timeout

if self.config.first_request_timeout is not None:
kwargs["first_request_timeout"] = self.config.first_request_timeout

if self.config.max_concurrent_requests is not None:
kwargs["max_concurrent_requests"] = self.config.max_concurrent_requests

return kwargs

@property
def tei_kwargs(self):
kwargs = {}

if self.config.dtype is not None:
kwargs["dtype"] = self.config.dtype

if self.config.pooling is not None:
kwargs["pooling"] = self.config.pooling

return kwargs

@property
def tgi_kwargs(self):
kwargs = {}

if self.config.dtype is not None:
kwargs["dtype"] = self.config.dtype

if self.config.sharded is not None:
kwargs["sharded"] = self.config.sharded

if self.config.quantize is not None:
kwargs["quantize"] = self.config.quantize

if self.config.num_shard is not None:
kwargs["num_shard"] = self.config.num_shard

if self.config.speculate is not None:
kwargs["speculate"] = self.config.speculate

if self.config.cuda_graphs is not None:
kwargs["cuda_graphs"] = self.config.cuda_graphs

if self.config.trust_remote_code is not None:
kwargs["trust_remote_code"] = self.config.trust_remote_code

if self.config.disable_custom_kernels is not None:
kwargs["disable_custom_kernels"] = self.config.disable_custom_kernels

return kwargs

def prepare_inputs(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
if self.config.task in TEXT_GENERATION_TASKS:
inputs = {"prompt": self.pretrained_processor.batch_decode(inputs["input_ids"].tolist())}
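For context, the generation-config preparation added above resolves the model's snapshot directory inside the Hugging Face Hub cache by reading the ref file for the requested revision. A minimal standalone sketch of that lookup follows; the cache directory and model id in the usage comment are placeholders, and huggingface_hub's own helpers (e.g. scan_cache_dir) are an alternative to doing this by hand.

import os

def resolve_snapshot_path(cache_dir: str, model_id: str, revision: str = "main") -> str:
    # Standalone sketch of the lookup performed in the generation-config preparation above.
    # Hub cache layout: <cache_dir>/models--<org>--<name>/refs/<revision> stores the commit hash.
    model_cache_folder = f"models/{model_id}".replace("/", "--")
    model_cache_path = os.path.join(cache_dir, model_cache_folder)
    with open(os.path.join(model_cache_path, "refs", revision)) as ref_file:
        snapshot_ref = ref_file.read().strip()
    # The resolved snapshot directory holds the actual config/tokenizer/model files.
    return os.path.join(model_cache_path, "snapshots", snapshot_ref)

# e.g. resolve_snapshot_path("/data", "openai-community/gpt2")
# -> "/data/models--openai-community--gpt2/snapshots/<commit-hash>"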
41 changes: 11 additions & 30 deletions optimum_benchmark/backends/py_txi/config.py
@@ -1,9 +1,7 @@
import os
from dataclasses import dataclass, field
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Union

from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE

from ...import_utils import py_txi_version
from ...system_utils import is_nvidia_system, is_rocm_system
from ...task_utils import TEXT_EMBEDDING_TASKS, TEXT_GENERATION_TASKS
@@ -16,7 +14,7 @@ class PyTXIConfig(BackendConfig):
version: Optional[str] = py_txi_version()
_target_: str = "optimum_benchmark.backends.py_txi.backend.PyTXIBackend"

# optimum benchmark specific
# optimum-benchmark specific
no_weights: bool = False

# Image to use for the container
@@ -28,27 +26,18 @@ class PyTXIConfig(BackendConfig):
# NVIDIA-docker GPU device options e.g. "all" (all) or "0,1,2,3" (ids) or 4 (count)
gpus: Optional[Union[str, int]] = None
# Things to forward to the container
ports: Dict[str, Any] = field(
default_factory=lambda: {"80/tcp": ("127.0.0.1", 0)},
metadata={"help": "Dictionary of ports to expose from the container."},
)
volumes: Dict[str, Any] = field(
default_factory=lambda: {HUGGINGFACE_HUB_CACHE: {"bind": "/data", "mode": "rw"}},
metadata={"help": "Dictionary of volumes to mount inside the container."},
)
environment: List[str] = field(
default_factory=lambda: ["HF_TOKEN"],
metadata={"help": "List of environment variables to forward to the container from the host."},
)

# first connection/request
connection_timeout: int = 60
first_request_timeout: int = 60
ports: Optional[Dict[str, Any]] = None
environment: Optional[List[str]] = None
volumes: Optional[Dict[str, Any]] = None
# First connection/request
connection_timeout: Optional[int] = None
first_request_timeout: Optional[int] = None
max_concurrent_requests: Optional[int] = None

# Common options
dtype: Optional[str] = None

# TEI specific
pooling: Optional[str] = None
# TGI specific
sharded: Optional[str] = None
quantize: Optional[str] = None
@@ -58,9 +47,6 @@ class PyTXIConfig(BackendConfig):
trust_remote_code: Optional[bool] = None
disable_custom_kernels: Optional[bool] = None

# TEI specific
pooling: Optional[str] = None

def __post_init__(self):
super().__post_init__()

@@ -76,9 +62,4 @@ def __post_init__(self):
renderDs = [file for file in os.listdir("/dev/dri") if file.startswith("renderD")]
self.devices = ["/dev/kfd"] + [f"/dev/dri/{renderDs[i]}" for i in ids]

# Common options
if self.max_concurrent_requests is None:
if self.task in TEXT_GENERATION_TASKS:
self.max_concurrent_requests = 128
elif self.task in TEXT_EMBEDDING_TASKS:
self.max_concurrent_requests = 512
self.trust_remote_code = self.model_kwargs.get("trust_remote_code", None)
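
Taken together, the two files follow a simple option-forwarding idiom: the backend config now declares container and server options as Optional fields defaulting to None, and the backend forwards only the values that were explicitly set, so anything left unset falls back to py-txi's own defaults. The sketch below illustrates that idiom in isolation; BenchConfig and ServerConfig are placeholder names for this example, not classes from optimum-benchmark or py-txi.

from dataclasses import dataclass
from typing import Optional


@dataclass
class BenchConfig:
    # Illustrative stand-in for PyTXIConfig: options left as None are not forwarded at all.
    model: str
    dtype: Optional[str] = None
    num_shard: Optional[int] = None


@dataclass
class ServerConfig:
    # Illustrative stand-in for py-txi's TGIConfig/TEIConfig:
    # downstream defaults apply whenever an option is omitted.
    model_id: str
    dtype: str = "float16"
    num_shard: int = 1


def server_kwargs(config: BenchConfig) -> dict:
    # Forward only the options that were explicitly set.
    kwargs = {}
    if config.dtype is not None:
        kwargs["dtype"] = config.dtype
    if config.num_shard is not None:
        kwargs["num_shard"] = config.num_shard
    return kwargs


bench = BenchConfig(model="gpt2", num_shard=2)
server = ServerConfig(bench.model, **server_kwargs(bench))
print(server)  # ServerConfig(model_id='gpt2', dtype='float16', num_shard=2)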
