-
Notifications
You must be signed in to change notification settings - Fork 24
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Fix generation using Jetstream Pytorch (#94)
* feat(debug): add env var to skip warmup * fix(Jetstream Pt): correct generation Text generation was not correct because the weights in the model were not correctly loaded. This is not something that it was easy to spot just looking at few tokens generated, and it was something that it was actually fixed already in the Jetstream/Pytorch code, but the fix hadn't been ported to optimum-tpu. This fix implement the necessary weights changes, aligning to Jetstream Pytorch, and tests expected output has been modified accordingly. * ci: separate Jetstream Pytorch test to its own workflow The main workflow was failing due to an OS error. I suspect that being related to a problem of space. Separating the workflow will make it easier to analyse this issue. * fix(jetstream Pt): make Jetstream Pt install more reliable I was previously referencing a given git revision and install from github, but since the Jetstream Pytorch package install its dependencies from its git submodels, these are installed in temporary directories, that can disappear afterwards. This happened on CI, making the installation fail. To work around that, a dedicated install script has been added, and it is now used to install that. * fix(style): correct generator style * refactor(Jetstream Pt): avoid duplicating Llama modeling Since this is error-prone, a better solution is just to use this. This hadn't been done before mainly because in the model config we do not have some of the params anymore (ffn_dim_multiplier and multiple_of). We do have intermediate_size though, and that is enough to reconstruct parameters that end up producing the same calculation. This refactor should allow for future code to follow Jetstream/Pytorch changes in an easier way.
- Loading branch information
1 parent
4265e13
commit 094d8a8
Showing
10 changed files
with
95 additions
and
295 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
name: Optimum TPU / Test TGI on TPU / Jetstream Pytorch | ||
|
||
on: | ||
push: | ||
branches: [ main ] | ||
paths: | ||
- "text-generation-inference/**" | ||
pull_request: | ||
branches: [ main ] | ||
paths: | ||
- "text-generation-inference/**" | ||
|
||
concurrency: | ||
group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }} | ||
cancel-in-progress: true | ||
|
||
jobs: | ||
do-the-job: | ||
name: Run TGI tests - Jetstream Pytorch | ||
runs-on: optimum-tpu | ||
container: | ||
image: us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.4.0_3.10_tpuvm | ||
options: --shm-size "16gb" --ipc host --privileged | ||
env: | ||
PJRT_DEVICE: TPU | ||
steps: | ||
- name: Checkout | ||
uses: actions/checkout@v4 | ||
|
||
- name: Build and test TGI server | ||
run: | | ||
HF_TOKEN=${{ secrets.HF_TOKEN_OPTIMUM_TPU_CI }} make tgi_test_jetstream |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -133,4 +133,6 @@ dmypy.json | |
*.pt | ||
|
||
.vscode | ||
.idea/ | ||
.idea/ | ||
|
||
jetstream-pt-deps |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
#!/bin/bash | ||
deps_dir=deps | ||
rm -rf $deps_dir | ||
mkdir -p $deps_dir | ||
cd $deps_dir | ||
pwd | ||
git clone https://github.com/google/jetstream-pytorch.git | ||
cd jetstream-pytorch | ||
git checkout ec4ac8f6b180ade059a2284b8b7d843b3cab0921 | ||
git submodule update --init --recursive | ||
# We cannot install in a temporary directory because the directory should not be deleted after the script finishes, | ||
# because it will install its dependendencies from that directory. | ||
pip install -e . |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.