diff --git a/.github/workflows/test-pytorch-xla-tpu-tgi-jetstream.yml b/.github/workflows/test-pytorch-xla-tpu-tgi-jetstream.yml new file mode 100644 index 00000000..8640ea8c --- /dev/null +++ b/.github/workflows/test-pytorch-xla-tpu-tgi-jetstream.yml @@ -0,0 +1,32 @@ +name: Optimum TPU / Test TGI on TPU / Jetstream Pytorch + +on: + push: + branches: [ main ] + paths: + - "text-generation-inference/**" + pull_request: + branches: [ main ] + paths: + - "text-generation-inference/**" + +concurrency: + group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }} + cancel-in-progress: true + +jobs: + do-the-job: + name: Run TGI tests - Jetstream Pytorch + runs-on: optimum-tpu + container: + image: us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.4.0_3.10_tpuvm + options: --shm-size "16gb" --ipc host --privileged + env: + PJRT_DEVICE: TPU + steps: + - name: Checkout + uses: actions/checkout@v4 + + - name: Build and test TGI server + run: | + HF_TOKEN=${{ secrets.HF_TOKEN_OPTIMUM_TPU_CI }} make tgi_test_jetstream diff --git a/.github/workflows/test-pytorch-xla-tpu-tgi.yml b/.github/workflows/test-pytorch-xla-tpu-tgi.yml index 4c681941..86873328 100644 --- a/.github/workflows/test-pytorch-xla-tpu-tgi.yml +++ b/.github/workflows/test-pytorch-xla-tpu-tgi.yml @@ -19,7 +19,6 @@ jobs: name: Run TGI tests runs-on: optimum-tpu container: - # Use a nightly image that works with TPU (release was not working) image: us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.4.0_3.10_tpuvm options: --shm-size "16gb" --ipc host --privileged env: @@ -31,13 +30,3 @@ jobs: - name: Build and test TGI server run: | HF_TOKEN=${{ secrets.HF_TOKEN_OPTIMUM_TPU_CI }} make tgi_test - - # Use a different step to test the Jetstream Pytorch version, to avoid conflicts with torch-xla[tpu] - - name: Install and test TGI server (Jetstream Pytorch) - run: | - pip install -U .[jetstream-pt] \ - -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html \ - -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html \ - -f https://storage.googleapis.com/libtpu-releases/index.html - JETSTREAM_PT=1 HF_TOKEN=${{ secrets.HF_TOKEN_OPTIMUM_TPU_CI }} python -m \ - pytest -sv text-generation-inference/tests -k jetstream diff --git a/.gitignore b/.gitignore index 7beb0e7f..55b9b3cc 100644 --- a/.gitignore +++ b/.gitignore @@ -133,4 +133,6 @@ dmypy.json *.pt .vscode -.idea/ \ No newline at end of file +.idea/ + +jetstream-pt-deps \ No newline at end of file diff --git a/Makefile b/Makefile index b7723d34..d6de8e1b 100644 --- a/Makefile +++ b/Makefile @@ -39,7 +39,7 @@ $(PACKAGE_DIST) $(PACKAGE_WHEEL): $(PACKAGE_FILES) python -m build clean: - rm -rf dist + rm -rf dist deps make -C text-generation-inference/server/ clean tpu-tgi: @@ -87,6 +87,18 @@ tgi_server: make -C text-generation-inference/server clean VERSION=${VERSION} TGI_VERSION=${TGI_VERSION} make -C text-generation-inference/server gen-server +jetstream_requirements: + bash install-jetstream-pt.sh + python -m pip install .[jetstream-pt] \ + -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html \ + -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html \ + -f https://storage.googleapis.com/libtpu-releases/index.html + +tgi_test_jetstream: test_installs jetstream_requirements tgi_server + find text-generation-inference -name "text_generation_server-$(VERSION)-py3-none-any.whl" \ + -exec python -m pip install --force-reinstall {} \; + JETSTREAM_PT=1 python -m pytest -sv 
text-generation-inference/tests -k jetstream + tgi_test: test_installs tgi_server find text-generation-inference -name "text_generation_server-$(VERSION)-py3-none-any.whl" \ -exec python -m pip install --force-reinstall {} \; diff --git a/install-jetstream-pt.sh b/install-jetstream-pt.sh new file mode 100644 index 00000000..aa5bd621 --- /dev/null +++ b/install-jetstream-pt.sh @@ -0,0 +1,13 @@ +#!/bin/bash +deps_dir=deps +rm -rf $deps_dir +mkdir -p $deps_dir +cd $deps_dir +pwd +git clone https://github.com/google/jetstream-pytorch.git +cd jetstream-pytorch +git checkout ec4ac8f6b180ade059a2284b8b7d843b3cab0921 +git submodule update --init --recursive +# We cannot install from a temporary directory: the editable install below keeps resolving its +# dependencies from this directory, so it must not be deleted after the script finishes. +pip install -e . diff --git a/pyproject.toml b/pyproject.toml index 01ad1a5d..c57ba05d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -58,10 +58,10 @@ build-backend = "setuptools.build_meta" [project.optional-dependencies] tests = ["pytest", "safetensors"] quality = ["black", "ruff", "isort"] -# Jetstream/Pytorch support is experimental for now, requires installation from fixed commit. +# Jetstream/Pytorch support is experimental for now; it needs to be installed manually. # Pallas is pulled because it will install a compatible version of jax[tpu]. jetstream-pt = [ - "jetstream-pt @ git+https://github.com/google/jetstream-pytorch.git@ec4ac8f6b180ade059a2284b8b7d843b3cab0921", + "jetstream-pt", "torch-xla[pallas] == 2.4.0" ] diff --git a/text-generation-inference/server/text_generation_server/jetstream_pt_support/engine_loader.py b/text-generation-inference/server/text_generation_server/jetstream_pt_support/engine_loader.py index 93dccdd7..ab5ef782 100644 --- a/text-generation-inference/server/text_generation_server/jetstream_pt_support/engine_loader.py +++ b/text-generation-inference/server/text_generation_server/jetstream_pt_support/engine_loader.py @@ -104,11 +104,7 @@ def instantiate_model_from_repo_id( env.device = "meta" model = create_model(model_dir, env) weights = fetch_models._load_weights(model_dir) - updated_keys = model.get_hf_names_to_real_name() - for name, updated in updated_keys.items(): - if name in weights: - val = weights.pop(name) - weights[updated] = val + weights = model.convert_hf_weights(weights) model.load_state_dict(weights, assign=True, strict=False) diff --git a/text-generation-inference/server/text_generation_server/jetstream_pt_support/generator.py b/text-generation-inference/server/text_generation_server/jetstream_pt_support/generator.py index 8251c3df..9bd6f627 100644 --- a/text-generation-inference/server/text_generation_server/jetstream_pt_support/generator.py +++ b/text-generation-inference/server/text_generation_server/jetstream_pt_support/generator.py @@ -1,5 +1,6 @@ import copy import logging +import os import time from enum import Enum from typing import List, Optional, Tuple @@ -9,7 +10,7 @@ import numpy as np import torch import torch_xla2 -from jetstream.engine.token_utils import pad_tokens, take_nearest_length, DEFAULT_PREFILL_BUCKETS +from jetstream.engine.token_utils import DEFAULT_PREFILL_BUCKETS, pad_tokens, take_nearest_length from jetstream_pt.engine import PyTorchEngine from loguru import logger from transformers import AutoTokenizer, PreTrainedTokenizerBase @@ -330,6 +331,9 @@ def warmup(self, batch: Batch) -> int: # Counter-intuitively, now we ignore the input batch.
Instead, we create dummy batches to cover all possible # batch sizes and sequence lengths. seq_len = self.model.config.sequence_length + if os.environ.get("SKIP_WARMUP", "0") == "1": + logger.debug("Skipping warmup") + return batch_size * seq_len bucket_seq_len = take_nearest_length(DEFAULT_PREFILL_BUCKETS, seq_len) decode_done = False for l in reversed(DEFAULT_PREFILL_BUCKETS): diff --git a/text-generation-inference/server/text_generation_server/jetstream_pt_support/llama_model_exportable_hf.py b/text-generation-inference/server/text_generation_server/jetstream_pt_support/llama_model_exportable_hf.py index 1bab00d3..14cd0316 100644 --- a/text-generation-inference/server/text_generation_server/jetstream_pt_support/llama_model_exportable_hf.py +++ b/text-generation-inference/server/text_generation_server/jetstream_pt_support/llama_model_exportable_hf.py @@ -1,171 +1,9 @@ -from typing import Any, List, Optional -import jax -import torch -import torch.nn.functional as F -from jetstream_pt.layers import ( - Attention, - RMSNorm, - get_quantized_embedding_layer, - get_quantized_linear_layer, -) -from jetstream_pt.model_base import ModuleBase +from jetstream_pt.third_party.llama.model_exportable import Transformer, model_args from transformers import GenerationConfig, GenerationMixin, LlamaConfig -class FeedForward(ModuleBase): - """Feed-forward module, AKA LlamaMLP on HuggingFace. - - Note the main difference is that it uses intermediate_size instead of multiple_of and ffn_dim_multiplier. - The parameter dim here corresponds to hidden_size in HuggingFace's Llama model, and hidden_dim is not really used, - because intermediate_size is used instead. - """ - - def __init__( - self, - dim: int, - intermediate_size: int, - device="meta", - env=None, - ): - super().__init__() - self.env = env - - LinearLayer = get_quantized_linear_layer(env.quant_config) - linear_kwargs = {} - if LinearLayer != torch.nn.Linear: - linear_kwargs["quant_config"] = env.quant_config - - self.w1 = LinearLayer( - dim, - intermediate_size, - bias=False, - device=device, - **linear_kwargs, - ) - self.w2 = LinearLayer( - intermediate_size, - dim, - bias=False, - device=device, - **linear_kwargs, - ) - self.w3 = LinearLayer( - dim, - intermediate_size, - bias=False, - device=device, - **linear_kwargs, - ) - self.hf_name("w1", "gate_proj") - self.hf_name("w2", "down_proj") - self.hf_name("w3", "up_proj") - - self.annotate_sharding("w1.weight", 0) - self.annotate_sharding("w2.weight", 1) - self.annotate_sharding("w3.weight", 0) - - def forward(self, x): - result = self.w2(F.silu(self.w1(x)) * self.w3(x)) - return result - - -class TransformerBlockHf(ModuleBase): - """This is essentially the same as the JetstreamPytoch Transformer, but it avoids using multiple_of and - ffn_dim_multiplier that are not available in HuggingFace's Llama model, and it uses intermediate_size instead. 
- """ - - def __init__( - self, - layer_id: int, - config: LlamaConfig, - device, - env, - ): - super().__init__() - self.env = env - self.n_heads = config.num_attention_heads - self.dim = config.hidden_size - self.head_dim = config.hidden_size // config.num_attention_heads - - self.attention = Attention( - config.num_attention_heads, - config.num_key_value_heads or config.num_attention_heads, - config.hidden_size // config.num_attention_heads, - config.hidden_size, - env=env, - device=device, - layer_id=layer_id, - ) - self.feed_forward = FeedForward( - dim=config.hidden_size, - intermediate_size=config.intermediate_size, - device=device, - env=env, - ) - self.layer_id = layer_id - self.attention_norm = RMSNorm( - config.hidden_size, eps=config.rms_norm_eps, device=device - ) - self.ffn_norm = RMSNorm( - config.hidden_size, eps=config.rms_norm_eps, device=device - ) - - self.hf_name("attention", "self_attn") - self.attention.hf_name("wq", "q_proj") - self.attention.hf_name("wk", "k_proj") - self.attention.hf_name("wv", "v_proj") - self.attention.hf_name("wo", "o_proj") - - self.attention.annotate_sharding("wq.weight", 0) - self.attention.annotate_sharding("wk.weight", 0) - self.attention.annotate_sharding("wv.weight", 0) - self.attention.annotate_sharding("wo.weight", 1) - - self.hf_name("feed_forward", "mlp") - self.hf_name("attention_norm", "input_layernorm") - self.hf_name("ffn_norm", "post_attention_layernorm") - - def forward( - self, - x: torch.Tensor, - freqs_cis: torch.Tensor, - mask: Optional[torch.Tensor], - cache, - start=None, - end=None, - ragged_batch_index=None, - ragged_block_index=None, - ): - with jax.named_scope("Attention"): - attn = self.attention.forward( - self.attention_norm(x), - freqs_cis, - mask, - cache, - start, - end, - ragged_batch_index, - ragged_block_index, - ) - with jax.named_scope("ffn_norm"): - h = x + attn - ffns = self.ffn_norm(h) - - with jax.named_scope("ffn"): - out = h + self.feed_forward.forward(ffns) - return out - - -def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0) -> torch.Tensor: - freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) - t = torch.arange(end, device=freqs.device) # type: ignore - freqs = torch.outer(t, freqs).float() # type: ignore - freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 - return freqs_cis - - -class TransformerHf(ModuleBase, GenerationMixin): +class TransformerHf(Transformer, GenerationMixin): """Transformer module that uses HF LlamaConfig instead of Jetstream Pytorch ModelArgs + device. Note that this class also derives from GenerationMixin, so that we can use its methods. @@ -177,121 +15,33 @@ def __init__( device, env, ): - super().__init__() - self.env = env self.config = config self.generation_config = GenerationConfig.from_model_config(config) - self.vocab_size = config.vocab_size - self.n_layers = config.num_hidden_layers - Embedding = get_quantized_embedding_layer(env.quant_config) - self.tok_embeddings = Embedding( - config.vocab_size, - config.hidden_size, - device=device, - ) + # NOTE: these parameters are deduced from the config's intermediate_size and hidden_size, so to be compatible + # with the original Jestream/Pytorch model. 
+ ffn_dim_multiplier = config.intermediate_size / int(8 * config.hidden_size / 3) + multiple_of = 1 - self.layers = torch.nn.ModuleList() - for layer_id in range(config.num_hidden_layers): - self.layers.append(TransformerBlockHf(layer_id, config, device, env)) - self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, device=device) - - LinearLayer = get_quantized_linear_layer(env.quant_config) - linear_kwargs = {} - if LinearLayer != torch.nn.Linear: - linear_kwargs["quant_config"] = env.quant_config - - self.output = LinearLayer( - config.hidden_size, - config.vocab_size, - bias=False, - device=device, - **linear_kwargs, - ) - # TODO what to do with this - freqs_cis = precompute_freqs_cis( - config.hidden_size // config.num_attention_heads, - env.cache_len * 2, - theta=config.rope_theta, + args = model_args.ModelArgs( + dim=config.hidden_size, + n_layers=config.num_hidden_layers, + n_heads=config.num_attention_heads, + n_kv_heads=config.num_key_value_heads, + vocab_size=config.vocab_size, + multiple_of=multiple_of, + ffn_dim_multiplier=ffn_dim_multiplier, + norm_eps=config.rms_norm_eps, + max_seq_len=env.cache_len, + bf16_enable=env.bf16_enable, + rope_theta=config.rope_theta, ) + args.device = device + super().__init__(args, env) - self.register_buffer("freqs_cis", freqs_cis) - - self.hf_name("output", "lm_head") - self.hf_name("norm", "model.norm") - self.hf_name("layers", "model.layers") - self.hf_name("tok_embeddings", "model.embed_tokens") - - self.annotate_sharding("tok_embeddings.weight", 1) - self.annotate_sharding("output.weight", 0) - - @torch.no_grad() - def forward( - self, - tokens: torch.Tensor, - input_pos: torch.Tensor, - caches: List[Any], - mask, - start=None, - ragged_batch_index=None, - ragged_block_index=None, - ): - """ - tokens: the input token for decoding - input_pos: the decoding position relative to the start, which is the length of the decoding results - caches: kv caches - mask: causal mask to filter the attention results - start: the starting position for each slot - ragged_batch_index: precomputed batch index for ragged attention - ragged_block_index: precomputed block index for ragged attention - """ - - with jax.named_scope("transformer_tok"): - seqlen = tokens.shape[-1] - h = self.tok_embeddings(tokens) - - with jax.named_scope("transformer_freq"): - bsz, seqlen = tokens.shape - freqs_cis = self.freqs_cis[input_pos] - freqs_cis = freqs_cis.reshape(bsz, seqlen, -1) - - end = None if start is None else (start + input_pos) % self.env.cache_len - # For stacked case, cannot get cache inside the loop which will cause cache copy - for layer_id, layer in enumerate(self.layers): - if caches[0].stacked: - cache = caches[0] - else: - cache = caches[layer_id] - # else: # For stacked case, there is only 1 yer of kv cache - - with jax.named_scope("TransformerBlock_Layer_" + str(layer_id)): - h = layer( - h, - freqs_cis, - mask, - cache, - start, - end, - ragged_batch_index, - ragged_block_index, - ) - - with jax.named_scope("transformer_norm"): - h = self.norm(h) - output = self.output(h).float() - return output @classmethod def from_config(cls, config, env): device = "meta" model = cls(config, device, env) return model - - def drop_weight(self, key): - return key.startswith("model") - - def shard_weights(self, _weights_dict): - """Shards the weights - - Assumes the weights_dict is a list of XLATensor2 - """ diff --git a/text-generation-inference/tests/test_decode.py b/text-generation-inference/tests/test_decode.py index c7ea0f19..1743efa6 100644 --- 
a/text-generation-inference/tests/test_decode.py +++ b/text-generation-inference/tests/test_decode.py @@ -16,6 +16,7 @@ class DecodeTestParams: sequence_length: int expected_text: str do_sample: bool = False + max_new_tokens: int = 20 @pytest.mark.parametrize("params", @@ -64,7 +65,7 @@ def test_decode_single_slow(params): def _test_decode_single(params): model_path = prepare_model(params.model_id, params.sequence_length) input_text = "It was a bright cold day in April, and the clocks were striking thirteen." - max_new_tokens = 20 + max_new_tokens = params.max_new_tokens generator = AutoGenerator.from_pretrained( model_path, revision="", max_batch_size=1, max_sequence_length=params.sequence_length @@ -100,12 +101,12 @@ def _test_decode_single(params): DecodeTestParams( model_id="meta-llama/Llama-2-7b-hf", sequence_length=256, - expected_text="\nThe clocks were striking thirteen\nThe clocks were striking thirteen\nThe", + expected_text="\nWinston Smith, his chin nuzzled into his breast in an effort to escape", ), DecodeTestParams( model_id="meta-llama/Meta-Llama-3-8B", sequence_length=256, - expected_text=" Winston Smith, his chin on his hands, and the clock in the Ministry of Truth, Minit", + expected_text=" Winston Smith, his chin nuzzled into his breast in an effort to escape the vile wind,", ), ], ids=["Llama-2-7b-hf", "Meta-Llama-3-8B"], @@ -123,7 +124,8 @@ def test_decode_single_jetstream_pytorch_slow(params, do_sample): DecodeTestParams( model_id="Maykeye/TinyLLama-v0", sequence_length=256, - expected_text=" She had a big and it had a big, blue, and a big, red and a big", + expected_text=" The sun was shining and the sky was shining.\nSuddenly, a big wind came and blew the wind away.", + max_new_tokens=25, ), ], ids=["TinyLLama-v0"],
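Notes on the changes above (illustrative sketches, not part of the patch):

engine_loader.py now calls model.convert_hf_weights(weights) instead of renaming checkpoint keys inline. The removed loop translated HuggingFace state_dict names into the Jetstream Pytorch ones registered through hf_name (for example "lm_head" -> "output"). A minimal sketch of that translation, assuming a plain dict of tensors; the sample entries are taken from the hf_name registrations in the deleted TransformerHf code, and the helper name is ours, not the library API:

from typing import Dict

import torch


def convert_hf_weights_sketch(weights: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
    # Example entries mirroring the hf_name registrations that the removed code declared;
    # a real model registers one entry per parameter (attention, MLP and norm weights included).
    hf_to_jetstream = {
        "model.embed_tokens.weight": "tok_embeddings.weight",
        "model.norm.weight": "norm.weight",
        "lm_head.weight": "output.weight",
    }
    converted = dict(weights)
    for hf_name, jt_name in hf_to_jetstream.items():
        # Same move as the removed loop: pop the HF key, reinsert it under the Jetstream name.
        if hf_name in converted:
            converted[jt_name] = converted.pop(hf_name)
    return converted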
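generator.py gains an escape hatch for the warmup loop: when SKIP_WARMUP=1 is set, warmup() logs a debug message and returns the token budget (batch_size * sequence_length) immediately instead of compiling every prefill bucket. This is meant for faster local test iterations; production servers should still warm up. A sketch of how a test could opt in, assuming the pytest layout of text-generation-inference/tests (monkeypatch is pytest's built-in fixture; the test body is a placeholder):

import pytest


@pytest.fixture
def skip_warmup(monkeypatch):
    # Make the generator's warmup() return immediately instead of compiling all prefill buckets.
    monkeypatch.setenv("SKIP_WARMUP", "1")


def test_decode_smoke(skip_warmup):
    # Build the generator and a dummy batch as the existing tests do, then call
    # generator.warmup(batch); with SKIP_WARMUP=1 it returns without tracing or compiling.
    ...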
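llama_model_exportable_hf.py now reuses Jetstream Pytorch's third-party Llama Transformer and only translates the HuggingFace LlamaConfig into ModelArgs. With multiple_of = 1, ffn_dim_multiplier is chosen so that the reference FFN sizing lands back on config.intermediate_size. Below is a small round-trip check of that deduction, assuming the reference Llama sizing formula (4 * dim scaled by 2/3, then by ffn_dim_multiplier, then rounded up to a multiple of multiple_of); the helper names are ours and this is a sanity sketch, not the library's code:

def deduce_ffn_args(hidden_size: int, intermediate_size: int):
    # Same deduction as the new TransformerHf.__init__
    ffn_dim_multiplier = intermediate_size / int(8 * hidden_size / 3)
    multiple_of = 1
    return ffn_dim_multiplier, multiple_of


def reference_hidden_dim(dim: int, ffn_dim_multiplier: float, multiple_of: int) -> int:
    # Reference Llama FFN sizing that ModelArgs-based models apply
    hidden_dim = int(2 * (4 * dim) / 3)
    hidden_dim = int(ffn_dim_multiplier * hidden_dim)
    return multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)


if __name__ == "__main__":
    # Llama-2-7b shapes: hidden_size=4096, intermediate_size=11008
    mult, m = deduce_ffn_args(4096, 11008)
    assert reference_hidden_dim(4096, mult, m) == 11008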