Fix generation using Jetstream Pytorch (#94)
* feat(debug): add env var to skip warmup

* fix(Jetstream Pt): correct generation

Text generation was not correct because the model weights were not
loaded correctly. This was not easy to spot just by looking at a few
generated tokens, and it had actually already been fixed in the
Jetstream/Pytorch code, but the fix had not been ported to optimum-tpu.

This fix implements the necessary weight changes, aligning with
Jetstream Pytorch (as sketched below), and the tests' expected output
has been updated accordingly.
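
In essence, the manual key renaming is replaced by the converter the
model itself provides (a minimal sketch mirroring the engine-loading
diff further down; model is the Jetstream Pytorch model and weights the
Hugging Face state dict):

    # Before: only rename the weight keys the model lists explicitly.
    updated_keys = model.get_hf_names_to_real_name()
    for name, updated in updated_keys.items():
        if name in weights:
            weights[updated] = weights.pop(name)

    # After: delegate the whole HF-to-Jetstream weight conversion to the
    # model, which also applies the transformations plain renaming misses.
    weights = model.convert_hf_weights(weights)
    model.load_state_dict(weights, assign=True, strict=False)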

* ci: separate Jetstream Pytorch tests into their own workflow

The main workflow was failing with an OS error, which I suspect is
related to a lack of disk space. Separating the workflow will make this
issue easier to analyse.

* fix(Jetstream Pt): make Jetstream Pt install more reliable

The install was previously referencing a fixed git revision and
installing from GitHub, but the Jetstream Pytorch package installs its
dependencies from its git submodules, which end up in temporary
directories that can disappear afterwards. This happened on CI, making
the installation fail.

To work around that, a dedicated install script (install-jetstream-pt.sh,
shown below) has been added and is now used for the installation.

* fix(style): correct generator style

* refactor(Jetstream Pt): avoid duplicating Llama modeling

Duplicating the Llama modeling code is error-prone; a better solution
is to use the Jetstream/Pytorch implementation directly. This had not
been done before mainly because the model config no longer contains
some of the parameters (ffn_dim_multiplier and multiple_of). It does
contain intermediate_size, though, and that is enough to reconstruct
parameters that produce the same calculation (see the sketch below).

This refactor should make it easier for future code to follow
Jetstream/Pytorch changes.
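
To see why intermediate_size is enough, here is a minimal sketch
assuming the standard Llama feed-forward sizing used by the reference
modeling code (reconstruct_ffn_args is an illustrative helper, not the
commit's actual code):

    def llama_ffn_hidden_dim(dim: int, multiple_of: int, ffn_dim_multiplier=None) -> int:
        # Standard Llama FFN sizing: shrink 4*dim by a factor of 2/3, apply
        # the optional multiplier, then round up to a multiple of multiple_of.
        hidden_dim = int(2 * (4 * dim) / 3)
        if ffn_dim_multiplier is not None:
            hidden_dim = int(ffn_dim_multiplier * hidden_dim)
        return multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)

    def reconstruct_ffn_args(dim: int, intermediate_size: int):
        # Pick a multiplier that keeps the intermediate value at or below
        # intermediate_size; rounding up with multiple_of=intermediate_size
        # then lands exactly on intermediate_size.
        base = int(2 * (4 * dim) / 3)
        return intermediate_size / base, intermediate_size

    # e.g. Llama-2 7B: hidden_size=4096, intermediate_size=11008
    ffn_dim_multiplier, multiple_of = reconstruct_ffn_args(4096, 11008)
    assert llama_ffn_hidden_dim(4096, multiple_of, ffn_dim_multiplier) == 11008
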
tengomucho authored Sep 23, 2024
1 parent 4265e13 commit 094d8a8
Showing 10 changed files with 95 additions and 295 deletions.
32 changes: 32 additions & 0 deletions .github/workflows/test-pytorch-xla-tpu-tgi-jetstream.yml
@@ -0,0 +1,32 @@
+name: Optimum TPU / Test TGI on TPU / Jetstream Pytorch
+
+on:
+  push:
+    branches: [ main ]
+    paths:
+      - "text-generation-inference/**"
+  pull_request:
+    branches: [ main ]
+    paths:
+      - "text-generation-inference/**"
+
+concurrency:
+  group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}
+  cancel-in-progress: true
+
+jobs:
+  do-the-job:
+    name: Run TGI tests - Jetstream Pytorch
+    runs-on: optimum-tpu
+    container:
+      image: us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.4.0_3.10_tpuvm
+      options: --shm-size "16gb" --ipc host --privileged
+      env:
+        PJRT_DEVICE: TPU
+    steps:
+      - name: Checkout
+        uses: actions/checkout@v4
+
+      - name: Build and test TGI server
+        run: |
+          HF_TOKEN=${{ secrets.HF_TOKEN_OPTIMUM_TPU_CI }} make tgi_test_jetstream
11 changes: 0 additions & 11 deletions .github/workflows/test-pytorch-xla-tpu-tgi.yml
@@ -19,7 +19,6 @@ jobs:
     name: Run TGI tests
     runs-on: optimum-tpu
     container:
-      # Use a nightly image that works with TPU (release was not working)
       image: us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.4.0_3.10_tpuvm
       options: --shm-size "16gb" --ipc host --privileged
       env:
@@ -31,13 +30,3 @@ jobs:
       - name: Build and test TGI server
         run: |
           HF_TOKEN=${{ secrets.HF_TOKEN_OPTIMUM_TPU_CI }} make tgi_test
-
-      # Use a different step to test the Jetstream Pytorch version, to avoid conflicts with torch-xla[tpu]
-      - name: Install and test TGI server (Jetstream Pytorch)
-        run: |
-          pip install -U .[jetstream-pt] \
-            -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html \
-            -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html \
-            -f https://storage.googleapis.com/libtpu-releases/index.html
-          JETSTREAM_PT=1 HF_TOKEN=${{ secrets.HF_TOKEN_OPTIMUM_TPU_CI }} python -m \
-            pytest -sv text-generation-inference/tests -k jetstream
4 changes: 3 additions & 1 deletion .gitignore
@@ -133,4 +133,6 @@ dmypy.json
 *.pt
 
 .vscode
-.idea/
+.idea/
+
+jetstream-pt-deps
14 changes: 13 additions & 1 deletion Makefile
@@ -39,7 +39,7 @@ $(PACKAGE_DIST) $(PACKAGE_WHEEL): $(PACKAGE_FILES)
 	python -m build
 
 clean:
-	rm -rf dist
+	rm -rf dist deps
 	make -C text-generation-inference/server/ clean
 
 tpu-tgi:
@@ -87,6 +87,18 @@ tgi_server:
 	make -C text-generation-inference/server clean
 	VERSION=${VERSION} TGI_VERSION=${TGI_VERSION} make -C text-generation-inference/server gen-server
 
+jetstream_requirements:
+	bash install-jetstream-pt.sh
+	python -m pip install .[jetstream-pt] \
+		-f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html \
+		-f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html \
+		-f https://storage.googleapis.com/libtpu-releases/index.html
+
+tgi_test_jetstream: test_installs jetstream_requirements tgi_server
+	find text-generation-inference -name "text_generation_server-$(VERSION)-py3-none-any.whl" \
+		-exec python -m pip install --force-reinstall {} \;
+	JETSTREAM_PT=1 python -m pytest -sv text-generation-inference/tests -k jetstream
+
 tgi_test: test_installs tgi_server
 	find text-generation-inference -name "text_generation_server-$(VERSION)-py3-none-any.whl" \
 		-exec python -m pip install --force-reinstall {} \;
13 changes: 13 additions & 0 deletions install-jetstream-pt.sh
@@ -0,0 +1,13 @@
+#!/bin/bash
+deps_dir=deps
+rm -rf $deps_dir
+mkdir -p $deps_dir
+cd $deps_dir
+pwd
+git clone https://github.com/google/jetstream-pytorch.git
+cd jetstream-pytorch
+git checkout ec4ac8f6b180ade059a2284b8b7d843b3cab0921
+git submodule update --init --recursive
+# We cannot install from a temporary directory: pip installs the package's
+# dependencies from this directory, so it must not be deleted after the script finishes.
+pip install -e .
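
After running the script, a quick way to check that the pinned editable
install is importable (a hypothetical smoke test, not part of the commit):

    python -c "import jetstream_pt; print(jetstream_pt.__file__)"
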
4 changes: 2 additions & 2 deletions pyproject.toml
@@ -58,10 +58,10 @@ build-backend = "setuptools.build_meta"
 [project.optional-dependencies]
 tests = ["pytest", "safetensors"]
 quality = ["black", "ruff", "isort"]
-# Jetstream/Pytorch support is experimental for now, requires installation from fixed commit.
+# Jetstream/Pytorch support is experimental for now; it needs to be installed manually.
 # Pallas is pulled because it will install a compatible version of jax[tpu].
 jetstream-pt = [
-    "jetstream-pt @ git+https://github.com/google/jetstream-pytorch.git@ec4ac8f6b180ade059a2284b8b7d843b3cab0921",
+    "jetstream-pt",
     "torch-xla[pallas] == 2.4.0"
 ]
 
@@ -104,11 +104,7 @@ def instantiate_model_from_repo_id(
     env.device = "meta"
     model = create_model(model_dir, env)
     weights = fetch_models._load_weights(model_dir)
-    updated_keys = model.get_hf_names_to_real_name()
-    for name, updated in updated_keys.items():
-        if name in weights:
-            val = weights.pop(name)
-            weights[updated] = val
+    weights = model.convert_hf_weights(weights)
 
     model.load_state_dict(weights, assign=True, strict=False)
 
@@ -1,5 +1,6 @@
 import copy
 import logging
+import os
 import time
 from enum import Enum
 from typing import List, Optional, Tuple
@@ -9,7 +10,7 @@
 import numpy as np
 import torch
 import torch_xla2
-from jetstream.engine.token_utils import pad_tokens, take_nearest_length, DEFAULT_PREFILL_BUCKETS
+from jetstream.engine.token_utils import DEFAULT_PREFILL_BUCKETS, pad_tokens, take_nearest_length
 from jetstream_pt.engine import PyTorchEngine
 from loguru import logger
 from transformers import AutoTokenizer, PreTrainedTokenizerBase
@@ -330,6 +331,9 @@ def warmup(self, batch: Batch) -> int:
         # Counter-intuitively, now we ignore the input batch. Instead, we create dummy batches to cover all possible
         # batch sizes and sequence lengths.
         seq_len = self.model.config.sequence_length
+        if os.environ.get("SKIP_WARMUP", "0") == "1":
+            logger.debug("Skipping warmup")
+            return batch_size * seq_len
         bucket_seq_len = take_nearest_length(DEFAULT_PREFILL_BUCKETS, seq_len)
         decode_done = False
         for l in reversed(DEFAULT_PREFILL_BUCKETS):
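
Combined with the Makefile target added above, the new variable gives a
quick debug loop, assuming any required HF_TOKEN is already set in the
environment (a debug-only shortcut; warmup should stay enabled for real
benchmarks):

    SKIP_WARMUP=1 JETSTREAM_PT=1 python -m pytest -sv text-generation-inference/tests -k jetstream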
