diff --git a/text-generation-inference/docker/Dockerfile b/text-generation-inference/docker/Dockerfile
index 5632449c..6c348a0e 100644
--- a/text-generation-inference/docker/Dockerfile
+++ b/text-generation-inference/docker/Dockerfile
@@ -109,7 +109,10 @@ COPY . /opt/optimum-tpu
 
 # Install requirements for optimum-tpu, then for TGI then optimum-tpu
 RUN python3 -m pip install hf_transfer safetensors==${SAFETENSORS_VERSION} && \
-    python3 -m pip install -e /opt/optimum-tpu -f https://storage.googleapis.com/libtpu-releases/index.html
+    python3 -m pip install -e /opt/optimum-tpu[jetstream-pt] \
+        -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html \
+        -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html \
+        -f https://storage.googleapis.com/libtpu-releases/index.html
 
 # Install router
 COPY --from=builder /usr/src/target/release/text-generation-router /usr/local/bin/text-generation-router
diff --git a/text-generation-inference/server/text_generation_server/jetstream_pt_support/engine_loader.py b/text-generation-inference/server/text_generation_server/jetstream_pt_support/engine_loader.py
index 91cf9784..93dccdd7 100644
--- a/text-generation-inference/server/text_generation_server/jetstream_pt_support/engine_loader.py
+++ b/text-generation-inference/server/text_generation_server/jetstream_pt_support/engine_loader.py
@@ -144,6 +144,10 @@ def create_engine(
     env = JetEngineEnvironment(env_data)
     model = instantiate_model_from_repo_id(model_path, env)
+
+    # Update config with engine data
+    model.config.batch_size = batch_size
+    model.config.sequence_length = sequence_length
 
     weight_shardings = model.get_sharding_annotations()
     sharded_weights = shard_weights(env, model.state_dict(), weight_shardings)
diff --git a/text-generation-inference/server/text_generation_server/jetstream_pt_support/generator.py b/text-generation-inference/server/text_generation_server/jetstream_pt_support/generator.py
index 359d90cd..8251c3df 100644
--- a/text-generation-inference/server/text_generation_server/jetstream_pt_support/generator.py
+++ b/text-generation-inference/server/text_generation_server/jetstream_pt_support/generator.py
@@ -9,7 +9,7 @@ import numpy as np
 import torch
 import torch_xla2
-from jetstream.engine.token_utils import pad_tokens, take_nearest_length
+from jetstream.engine.token_utils import pad_tokens, take_nearest_length, DEFAULT_PREFILL_BUCKETS
 from jetstream_pt.engine import PyTorchEngine
 from loguru import logger
 from transformers import AutoTokenizer, PreTrainedTokenizerBase
@@ -36,20 +36,6 @@ optimum_logger = logging.getLogger("optimum.tpu")
 optimum_logger.setLevel("CRITICAL")
 
-# These will do some bucketing on prefill lengths to avoid too many different sizes
-PREFILL_LENGTHS = [
-    32,
-    64,
-    128,
-    256,
-    512,
-    1024,
-    2048,
-    4096,
-    8192,
-    16384,
-    32768,
-]
 
 
 class Slot:
     """Represents a slot in a static batch"""
@@ -78,6 +64,7 @@ def clear(self):
         self._generated_text = ""
         self._next_text = ""
         self._truncate = 0
+        self._seed = 0
 
     @property
     def id(self) -> int:
@@ -134,7 +121,7 @@ def assign(self, batch_id: int, request: Request, generation_config: GenerationC
         self._generation_config.do_sample = request.parameters.do_sample
         self._generation_config.repetition_penalty = request.parameters.repetition_penalty
         self._truncate = request.truncate
-        self.seed = request.parameters.seed
+        self._seed = request.parameters.seed
         # TODO: watermark
         self._generation_config.max_new_tokens = request.stopping_parameters.max_new_tokens
         self._max_new_tokens = self._generation_config.max_new_tokens
@@ -237,6 +224,20 @@ def next_token(self) -> int:
     def empty(self) -> bool:
         return len(self._tokens) == 0
 
+    @property
+    def seed(self) -> int:
+        return self._seed
+
+
+class PrefillSlot:
+    def __init__(self):
+        self._curslot = None
+
+    def set(self, slot: Slot):
+        self._curslot = slot
+
+    def select(self, logits: jnp.ndarray) -> int:
+        return self._curslot.select(logits)
 
 
 class TpuGeneratorJetStream(Generator):
     """A Generator for models running on TPU, single threaded."""
@@ -267,6 +268,7 @@ def __init__(
         self.batch_id = 0
         # Note: this index will _never_ be decremented, and that's fine.
         self.slot_index = 0
+        self.prefill_slot = PrefillSlot()
 
     @property
     def info(self) -> InfoResponse:
@@ -328,31 +330,30 @@ def warmup(self, batch: Batch) -> int:
         # Counter-intuitively, now we ignore the input batch. Instead, we create dummy batches to cover all possible
         # batch sizes and sequence lengths.
         seq_len = self.model.config.sequence_length
-        bucket_seq_len = take_nearest_length(PREFILL_LENGTHS, seq_len)
-        dummy_request = self._create_dummy_request(seq_len)
+        bucket_seq_len = take_nearest_length(DEFAULT_PREFILL_BUCKETS, seq_len)
         decode_done = False
-        for l in reversed(PREFILL_LENGTHS):
+        for l in reversed(DEFAULT_PREFILL_BUCKETS):
             # Skip all the unsupported lengths
             if l > bucket_seq_len:
                 continue
-            # Set all truncate values for all requests
-            dummy_request.truncate = l
-            dummy_request.stopping_parameters.max_new_tokens = 10
+            # Create a dummy request with the current sequence length
+            dummy_request = self._create_dummy_request(l)
+            # Request a few max_new_tokens so that at least one token is generated by prefill and another by decode.
+            MAX_NEW_TOKENS = 10
+            dummy_request.stopping_parameters.max_new_tokens = MAX_NEW_TOKENS
             warmup_batch = Batch(id=0, requests=[dummy_request], size=1, max_tokens=batch.max_tokens)
             logger.debug(f"Warmup for requests, len {l} seq_len {seq_len}")
             _generations, next_batch = self.prefill(warmup_batch)
-            if not decode_done and next_batch is not None:
+            if next_batch is not None:
                 self.decode([next_batch])
                 decode_done = True
             self.clear()
 
         if not decode_done:
             logger.debug("No decode done during warmup")
-            self.prefill(batch)
-            self.clear()
 
         elapsed = time.time() - start
         logger.debug(f"Warmup done, took {elapsed:.2f}s")
         seq_len = self.engine.env.seq_len
@@ -390,11 +391,13 @@ def _token_encode(self, text: str, max_length: int) -> Tuple[jnp.ndarray, int]:
             max_length=max_length,
             add_special_tokens=False,
         )
+        # max_prefill_length must be a power of 2
+        max_prefill_length = take_nearest_length(DEFAULT_PREFILL_BUCKETS, self.model.config.sequence_length)
         tokens, true_length = pad_tokens(input_ids[0],
                                          self.tokenizer.bos_token_id,
                                          self.tokenizer.pad_token_id,
                                          is_bos=True,
-                                         max_prefill_length=self.model.config.sequence_length,
+                                         max_prefill_length=max_prefill_length,
                                          jax_padding=True,
                                          )
         return tokens, true_length
@@ -436,6 +439,7 @@ def prefill(self, batch: Batch) -> Tuple[List[Generation], CachedBatch]:
         for request in batch.requests:
             # Dynamically create a new slot for each request
             slot = Slot(self._get_slot_id(), self.tokenizer)
+            self.prefill_slot.set(slot)
             self.slot_index += 1
             slot.assign(self.batch_id, request, self.model.generation_config)
             logger.debug(f"Request {slot.request_id} assigned to slot {slot.id}")
@@ -452,7 +456,7 @@ def prefill(self, batch: Batch) -> Tuple[List[Generation], CachedBatch]:
             )
             slot.reset(truncated_input_ids, selector)
             # To allow jit'ing the select function, we need to wrap it in a partial
-            slot_select = jax.tree_util.Partial(slot.select)
+            slot_select = jax.tree_util.Partial(self.prefill_slot.select)
             # Ask for prefill and insert
             prefill_results, _result_tokens = self.engine.prefill(
                 params=self.params,
@@ -469,6 +473,7 @@ def prefill(self, batch: Batch) -> Tuple[List[Generation], CachedBatch]:
             self.slots.append(slot)
             len_active_slots += 1
 
+        batch = None
         if len_active_slots > 0:
             # Whatever initial batch these requests came from, we always return all pending requests in a single batch
             request_ids = [slot.request_id for slot in self.slots if slot.state == Slot.State.READY]
@@ -499,6 +504,13 @@ def decode(self, batches: List[CachedBatch]) -> Tuple[List[Generation], CachedBa
         Return:
             A list of `Generation` for each request and a `CachedBatch` containing all pending requests.
         """
+
+        # In Python we would normally rely on duck typing, but if the elements passed in the list are not of the
+        # right type, this check avoids raising an error later and wasting time. Return an empty generation instead.
+        if any(not isinstance(item, CachedBatch) for item in batches):
+            logger.error("Unexpected type in decode, expected CachedBatch")
+            return [], None
+
         # batches contains a list composed of ongoing requests:
         # - the batch id returned by the last decode,
         # - the batch id(s) returned by the last prefill(s)
@@ -532,7 +544,7 @@ def decode(self, batches: List[CachedBatch]) -> Tuple[List[Generation], CachedBa
             # Get the next token.
             # Note that for now we ignore is_valid and length as we don't use them, we will re-parse these in post
             # generation.
-            next_token, _is_valid, _length = result_tokens.data[slot.id]
+            next_token = self.decode_state.tokens[slot.id].item()
 
             if slot.state != Slot.State.READY:
                 logger.error(f"Unexpected Slot {slot.id} is not ready for decoding, skipping.")
diff --git a/text-generation-inference/server/text_generation_server/jetstream_pt_support/token_selector.py b/text-generation-inference/server/text_generation_server/jetstream_pt_support/token_selector.py
index 9f9e21f8..fe31ad9c 100644
--- a/text-generation-inference/server/text_generation_server/jetstream_pt_support/token_selector.py
+++ b/text-generation-inference/server/text_generation_server/jetstream_pt_support/token_selector.py
@@ -55,6 +55,8 @@ def __init__(
         self.eos_token_ids = eos_token_ids
         self.pad_token_id = pad_token_id
         self.logits_warper = logits_warper
+        # The seed must fit in a 64-bit integer, so we take it modulo the int64 maximum in case it is bigger (that can happen!)
+        seed = seed % jnp.iinfo(jnp.int64).max
         self.key = jax.random.PRNGKey(seed)
 
     @classmethod
diff --git a/text-generation-inference/tests/test_decode.py b/text-generation-inference/tests/test_decode.py
index 1f431a3e..c7ea0f19 100644
--- a/text-generation-inference/tests/test_decode.py
+++ b/text-generation-inference/tests/test_decode.py
@@ -100,12 +100,12 @@ def _test_decode_single(params):
         DecodeTestParams(
             model_id="meta-llama/Llama-2-7b-hf",
             sequence_length=256,
-            expected_text="\n\nThe clocks were striking thirteen\nThe clocks were striking thirteen\n",
+            expected_text="\nThe clocks were striking thirteen\nThe clocks were striking thirteen\nThe",
         ),
         DecodeTestParams(
             model_id="meta-llama/Meta-Llama-3-8B",
             sequence_length=256,
-            expected_text=" Winston Winston Smith, his chin on his hands, and the clock in the Ministry of Truth, M",
+            expected_text=" Winston Smith, his chin on his hands, and the clock in the Ministry of Truth, Minit",
         ),
     ],
     ids=["Llama-2-7b-hf", "Meta-Llama-3-8B"],
@@ -123,7 +123,7 @@ def test_decode_single_jetstream_pytorch_slow(params, do_sample):
         DecodeTestParams(
             model_id="Maykeye/TinyLLama-v0",
             sequence_length=256,
-            expected_text=" She She had a big and it had a big, blue, and a big, red and a",
+            expected_text=" She had a big and it had a big, blue, and a big, red and a big",
         ),
     ],
     ids=["TinyLLama-v0"],
diff --git a/text-generation-inference/tests/test_warmup.py b/text-generation-inference/tests/test_warmup.py
new file mode 100644
index 00000000..6e33b2ee
--- /dev/null
+++ b/text-generation-inference/tests/test_warmup.py
@@ -0,0 +1,30 @@
+
+
+import pytest
+from helpers import create_request, prepare_model
+from text_generation_server.auto_generator import AutoGenerator
+from text_generation_server.pb.generate_pb2 import Batch
+
+from optimum.tpu.jetstream_pt_support import jetstream_pt_available
+
+
+def test_warmup_jetstream_pytorch():
+    if not jetstream_pt_available():
+        pytest.skip("Jetstream PyTorch is not available")
+    model_id = "Maykeye/TinyLLama-v0"
+
+    # The maximum sequence length of the model is set to 1000, but warmup will round that up to the next power of two
+    # in prefill (1024).
+    sequence_length = 1000
+
+    model_path = prepare_model(model_id, sequence_length)
+    input_text = "It was a bright cold day in April, and the clocks were striking thirteen."
+    max_new_tokens = 20
+
+    generator = AutoGenerator.from_pretrained(
+        model_path, revision="", max_batch_size=1, max_sequence_length=sequence_length
+    )
+    request = create_request(id=0, inputs=input_text, max_new_tokens=max_new_tokens, do_sample=False)
+    batch = Batch(id=0, requests=[request], size=1, max_tokens=sequence_length)
+    generator.warmup(batch)