Spyre tests (#61)
Supports issue #878

This PR comprises an initial set of 5 end-to-end tests (pytest) to
verify the operation of vLLM on Spyre.

1) verification of correct output by comparing against output generated
using HF transformers
2) verification of correct handling of multiple "overlapping" warmup
shapes
3) verification of correct handling of prompts that exceed the
maximum prompt length defined by the warmup shapes
4) verification of tensor parallel operation
5) verification that seeded sampling results in deterministic
output (covers a recent issue)

The tests are parameterized to cover the available backends ('eager',
'inductor', 'sendnn_decoder'), multiple warmup shapes and shape
combinations, and different models. These tests also generate additional
output that can be used to analyze and debug any problems efficiently.
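
For reference, each warmup shape used by the tests is a
(prompt_length, new_tokens, batch_size) tuple. Before the vLLM engine is
created, the helper in spyre_util.py (included in this diff) serializes the
configured shapes and backend into environment variables, roughly as in the
following sketch (the shape values shown here are just examples):

```
import os

# example warmup shapes: (prompt_length, new_tokens, batch_size)
warmup_shapes = [(64, 20, 4), (128, 20, 8)]

# each field is joined into a comma-separated list, one entry per shape
os.environ['VLLM_SPYRE_WARMUP_PROMPT_LENS'] = ','.join(str(s[0]) for s in warmup_shapes)
os.environ['VLLM_SPYRE_WARMUP_NEW_TOKENS'] = ','.join(str(s[1]) for s in warmup_shapes)
os.environ['VLLM_SPYRE_WARMUP_BATCH_SIZES'] = ','.join(str(s[2]) for s in warmup_shapes)
os.environ['VLLM_SPYRE_DYNAMO_BACKEND'] = 'eager'  # or 'inductor', 'sendnn_decoder'
```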

An example of this diagnostic output is shown below:

```
model:         /models/llama-194m
warmup shapes: [(64, 20, 8)]
backend:       sendnn_decoder

#prompts:      4
#HF results:   4
#vLLM results: 4

....

prompt:        'What is the weather today like?'
generated:
        HF:    ' What is the weather like tomorrow? What is the weather like today? What is the weather like tomorrow'
        vLLM:  ' What is the temperature? What is the humidity? What is the wind? What is the wind direction'  ERROR

   token id. token               logprob         token id. token               logprob
HF:     3639 ' What'             -1.592316  vLLM:     3639 ' What'             -1.583668
HF:      374 ' is'               -1.394676  vLLM:      374 ' is'               -1.388251
HF:      279 ' the'              -0.357713  vLLM:      279 ' the'              -0.350707
HF:     9282 ' weather'          -1.251681  vLLM:     9499 ' temperature'      -1.276881    ERROR
HF:     1093 ' like'             -0.403686  vLLM:       30 '?'                 -1.256650    ERROR
HF:    16986 ' tomorrow'         -1.232682  vLLM:     3639 ' What'             -0.992781    ERROR
HF:       30 '?'                 -0.236499  vLLM:      374 ' is'               -0.772252    ERROR
HF:     3639 ' What'             -0.647572  vLLM:      279 ' the'              -0.110132    ERROR
HF:      374 ' is'               -0.691557  vLLM:    38193 ' humidity'         -1.615233    ERROR
HF:      279 ' the'              -0.176384  vLLM:       30 '?'                 -0.366836    ERROR
HF:     9282 ' weather'          -0.283581  vLLM:     3639 ' What'             -0.616249    ERROR
HF:     1093 ' like'             -0.266174  vLLM:      374 ' is'               -0.546297    ERROR
HF:     3432 ' today'            -0.595149  vLLM:      279 ' the'              -0.066663    ERROR
HF:       30 '?'                 -0.449156  vLLM:    10160 ' wind'             -1.652243    ERROR
HF:     3639 ' What'             -1.047424  vLLM:       30 '?'                 -1.023496    ERROR
HF:      374 ' is'               -0.569301  vLLM:     3639 ' What'             -0.602964    ERROR
HF:      279 ' the'              -0.122663  vLLM:      374 ' is'               -0.443599    ERROR
HF:     9282 ' weather'          -0.208832  vLLM:      279 ' the'              -0.075392    ERROR
HF:     1093 ' like'             -0.267763  vLLM:    10160 ' wind'             -1.916859    ERROR
HF:    16986 ' tomorrow'         -0.968443  vLLM:     5216 ' direction'        -1.399925    ERROR

logprob absolute differences: average=0.413219  maximum=1.649096
logprob relative differences: average=1.180087  maximum=6.158796
```
As this example illustrates, the test compares the text/tokens and
corresponding logprobs generated by HF and vLLM.
In this example, the logprobs of 'weather' and 'temperature' are very
close for the 4th generated token (as can be seen above). Due to
differences in processing and precision on the CPU and the Spyre card,
'weather' had the highest probability for HF, while 'temperature' had
the highest probability for vLLM/Spyre. The output of the test makes it
possible to analyze such cases quickly. Based on this output, the test
can then be adapted to use different prompts, so that only 'real'
errors are reported.
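
As a concrete illustration, the per-token check applied by compare_results
essentially boils down to the following standalone sketch, using the HF and
vLLM logprobs of the 4th generated token from the table above:

```
import math

hf_logprob = -1.251681    # HF token 9282, ' weather'
vllm_logprob = -1.276881  # vLLM token 9499, ' temperature'

logprob_abs_diff = abs(hf_logprob - vllm_logprob)      # ~0.0252
logprob_rel_diff = abs(logprob_abs_diff / hf_logprob)  # ~0.0201

# the two logprobs agree within the 10% relative tolerance used by compare_results ...
print(math.isclose(hf_logprob, vllm_logprob, rel_tol=0.1))  # True

# ... but the token ids differ (9282 vs 9499), so the row is still flagged as ERROR
print(9282 == 9499)  # False
```

A row only passes when both the token ids match and the logprobs are within
rel_tol=0.1, which is why near-ties such as 'weather' vs. 'temperature' show
up as ERROR rows that may call for adapting the prompts rather than indicating
a real defect.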
jvlunteren authored and GitHub Enterprise committed Nov 7, 2024
1 parent b17d9a6 commit 16a97a1
Showing 6 changed files with 516 additions and 0 deletions.
153 changes: 153 additions & 0 deletions tests/spyre/spyre_util.py
@@ -0,0 +1,153 @@
# vLLM / Spyre
from vllm import LLM, SamplingParams
import os

DISABLE_ASSERTS=False # used for debugging

def generate_spyre_vllm_output(
        model: str,
        prompts: list[str],
        warmup_shapes: list[tuple([int, int, int])],
        max_model_len: int,
        block_size: int,
        sampling_params: SamplingParams,
        tensor_parallel_size: int,
        backend: str) -> list[dict[str, any]]:

    warmup_prompt_length = [t[0] for t in warmup_shapes]
    warmup_new_tokens = [t[1] for t in warmup_shapes]
    warmup_batch_size = [t[2] for t in warmup_shapes]

    os.environ['VLLM_SPYRE_WARMUP_PROMPT_LENS'] = ','.join(str(val) for val in warmup_prompt_length)
    os.environ['VLLM_SPYRE_WARMUP_NEW_TOKENS'] = ','.join(str(val) for val in warmup_new_tokens)
    os.environ['VLLM_SPYRE_WARMUP_BATCH_SIZES'] = ','.join(str(val) for val in warmup_batch_size)
    os.environ['VLLM_SPYRE_DYNAMO_BACKEND'] = backend

    vllm_model = LLM(
        model=model,
        tokenizer=model,
        max_model_len=max_model_len,
        block_size=block_size,
        tensor_parallel_size=tensor_parallel_size,
        device="sendnn"
    )

    vllm_outputs = vllm_model.generate(prompts, sampling_params)

    results = []
    for req_output in vllm_outputs:
        result = {}
        result['text'] = req_output.outputs[0].text
        result['token_ids'] = req_output.outputs[0].token_ids
        result['tokens'] = tuple([req_output.outputs[0].logprobs[i][t].decoded_token for i, t in enumerate(result['token_ids'])])
        result['logprobs'] = tuple([req_output.outputs[0].logprobs[i][t].logprob for i, t in enumerate(result['token_ids'])])
        results.append(result)

    return results



# Hugging Face
from transformers import AutoModelForCausalLM, AutoTokenizer

def generate_hf_output(
        model: str,
        prompts: list[str],
        max_new_tokens: int) -> list[dict[str, any]]:

    hf_model = AutoModelForCausalLM.from_pretrained(model)
    hf_tokenizer = AutoTokenizer.from_pretrained(model)

    results = []
    for prompt_index, prompt in enumerate(prompts):
        hf_input_tokens = hf_tokenizer(prompt, return_tensors="pt").input_ids
        hf_output = hf_model.generate(
            hf_input_tokens,
            do_sample=False,
            max_new_tokens=max_new_tokens,
            return_dict_in_generate=True,
            output_scores=True)

        # decode output tokens after first removing input tokens (prompt)
        hf_generated_text = hf_tokenizer.batch_decode(hf_output.sequences[:, len(hf_input_tokens[0]):])[0]
        hf_transition_scores = hf_model.compute_transition_scores(
            hf_output.sequences,
            hf_output.scores,
            normalize_logits=True
        )

        # return HF generated text, tokens, token ids and logprobs
        result = {}
        result['text'] = hf_generated_text
        result['token_ids'] = []
        result['tokens'] = []
        result['logprobs'] = []
        for tok_index, hf_logprob in enumerate(hf_transition_scores[0]):
            hf_token_id = hf_output.sequences[0][tok_index + len(hf_input_tokens[0])]
            result['token_ids'].append(hf_token_id.item())
            result['tokens'].append(hf_tokenizer.decode(hf_token_id))
            result['logprobs'].append(hf_logprob.item())
        result['token_ids'] = tuple(result['token_ids'])
        result['tokens'] = tuple(result['tokens'])
        result['logprobs'] = tuple(result['logprobs'])
        results.append(result)

    return results



# compare results
import math
import numpy as np

def compare_results(
        model: str,
        prompts: list[str],
        warmup_shapes: list[tuple([int, int, int])],
        tensor_parallel_size: int,
        backend: str,
        vllm_results: list[dict[str, any]],
        hf_results: list[dict[str, any]]):

    print(f"\nmodel:         {model:s}")
    print(f"warmup shapes: {warmup_shapes}")
    print(f"tp size:       {tensor_parallel_size}")
    print(f"backend:       {backend:s}")
    print(f"\n#prompts:      {len(prompts):d}")
    print(f"#HF results:   {len(hf_results):d}{'' if len(hf_results) == len(prompts) else '  ERROR':s}")
    print(f"#vLLM results: {len(vllm_results):d}{'' if len(vllm_results) == len(prompts) else '  ERROR':s}")
    print()
    assert DISABLE_ASSERTS or len(hf_results) == len(vllm_results)
    assert DISABLE_ASSERTS or len(hf_results) == len(prompts)

    for prompt_index, (prompt, hf_result, vllm_result) in enumerate(zip(prompts, hf_results, vllm_results)):
        print(f"\nprompt {prompt_index:3d}:    {repr(prompt):s}")
        print("generated:")
        print(f"        HF:    {repr(hf_result['text']):s}")
        print(f"        vLLM:  {repr(vllm_result['text']):s}{'' if hf_result['text'] == vllm_result['text'] else '  ERROR':s}")
        print()

        assert DISABLE_ASSERTS or hf_result['text'] == vllm_result['text']

        if len(hf_result['tokens']) > 0:
            print("   token id. token               logprob         token id. token               logprob")

            logprob_abs_diff_list = []
            logprob_rel_diff_list = []

            for hf_token, hf_token_id, hf_logprob, vllm_token, vllm_token_id, vllm_logprob in zip(hf_result['tokens'], hf_result['token_ids'], hf_result['logprobs'], vllm_result['tokens'], vllm_result['token_ids'], vllm_result['logprobs']):
                logprob_abs_diff = math.fabs(hf_logprob - vllm_logprob)
                logprob_abs_diff_list.append(logprob_abs_diff)
                logprob_rel_diff = math.fabs(logprob_abs_diff / hf_logprob)
                logprob_rel_diff_list.append(logprob_rel_diff)
                print(f"HF: {hf_token_id:8d} {repr(hf_token):14s} {hf_logprob:14f}  "
                      f"vLLM: {vllm_token_id:8d} {repr(vllm_token):14s} {vllm_logprob:14f}"
                      f"{'' if hf_token_id == vllm_token_id and math.isclose(hf_logprob, vllm_logprob, rel_tol=0.1) else '    ERROR':s}")
                #f"logprob absolute difference={logprob_abs_diff:f} relative difference={logprob_rel_diff:f}")
                assert DISABLE_ASSERTS or hf_token_id == vllm_token_id
                assert DISABLE_ASSERTS or math.isclose(hf_logprob, vllm_logprob, rel_tol=0.1)  # abs_tol=0.01
            print()
            print(f"logprob absolute differences: average={np.mean(logprob_abs_diff_list):f}  maximum={np.max(logprob_abs_diff_list):f}")
            print(f"logprob relative differences: average={np.mean(logprob_rel_diff_list):f}  maximum={np.max(logprob_rel_diff_list):f}")

    print()
69 changes: 69 additions & 0 deletions tests/spyre/test_spyre_basic.py
@@ -0,0 +1,69 @@
"""Verification of vLLM output by comparing with HF
Run `pytest tests/spyre/test_spyre_basic.py`.
"""

from vllm import SamplingParams

from spyre_util import generate_spyre_vllm_output, generate_hf_output, compare_results

import pytest


@pytest.mark.parametrize("model", ["/models/llama-194m"])
@pytest.mark.parametrize("prompts",
[["Provide a list of instructions for preparing chicken soup for a family of four.",
"Hello",
"What is the weather today like?",
"Who are you?"]])
@pytest.mark.parametrize("warmup_shape", [(64,20,4),(64,20,8),(128,20,4),(128,20,8)]) # (prompt_length/new_tokens/batch_size)
@pytest.mark.parametrize("backend", ["eager"]) #, "inductor", "sendnn_decoder"])
def test_output(
        model: str,
        prompts: list[str],
        warmup_shape: tuple([int, int, int]),
        backend: str,
) -> None:
    '''
    The warmup is based on a single shape. After the warmup,
    one request with the provided prompts is sent to vLLM.
    The same prompts are also passed to HF one by one, so that
    batching and padding in HF cannot affect the result.
    The generated output, including text, token ids, and logprobs,
    is verified to be identical for vLLM and HF.
    '''

    max_new_tokens = warmup_shape[1]

    vllm_sampling_params = SamplingParams(
        max_tokens=max_new_tokens,
        temperature=0,
        logprobs=0,  # return logprobs of generated tokens only
        ignore_eos=True
    )

    vllm_results = generate_spyre_vllm_output(
        model=model,
        prompts=prompts,
        warmup_shapes=[warmup_shape],
        max_model_len=2048,
        block_size=2048,
        sampling_params=vllm_sampling_params,
        tensor_parallel_size=1,
        backend=backend
    )

    hf_results = generate_hf_output(
        model=model,
        prompts=prompts,
        max_new_tokens=max_new_tokens)

    compare_results(
        model=model,
        prompts=prompts,
        warmup_shapes=[warmup_shape],
        tensor_parallel_size=1,
        backend=backend,
        vllm_results=vllm_results,
        hf_results=hf_results)

81 changes: 81 additions & 0 deletions tests/spyre/test_spyre_max_prompt_length.py
@@ -0,0 +1,81 @@
"""Verification of handling prompt length exceeding warmup shapes
Run `pytest tests/spyre/test_spyre_max_prompt_length.py`.
"""

from vllm import SamplingParams
from transformers import AutoTokenizer

from spyre_util import generate_spyre_vllm_output, generate_hf_output, compare_results

import pytest


@pytest.mark.parametrize("model", ["/models/llama-194m"])
@pytest.mark.parametrize("prompts",
[7 * ["Hello","Below is an instruction that describes a task. Write a response that appropriately completes the request. Be polite in your response to the user. Provide a list of instructions for preparing chicken soup for a family of four. Indicate if the weather forecast looks good for today. Explain in a brief summary comprised of at most 50 words what you are."]])
@pytest.mark.parametrize("warmup_shapes", [[(64,20,4)],[(64,20,4),(128,20,4)]]) # (prompt_length/new_tokens/batch_size)
@pytest.mark.parametrize("backend", ["eager"]) #, "inductor", "sendnn_decoder"])
def test_output(
        model: str,
        prompts: list[str],
        warmup_shapes: list[tuple([int, int, int])],
        backend: str,
) -> None:
    '''
    The warmup is based on one or more shapes. After the warmup,
    one request with multiple provided prompts is sent to vLLM.
    At least one of the provided prompts should be longer than the
    maximum prompt length defined by the warmup shapes. It is useful
    to define enough prompts to fill multiple batches entirely and
    partially, in order to also test the maximum prompt length check
    in relation to the position of a prompt within a batch (not
    likely that this will be an issue, but just to be sure).
    It is verified that "non-empty" output is generated only for the
    prompts that do not exceed the maximum prompt length. The output
    is verified using HF.
    '''

    max_prompt_length = max([t[0] for t in warmup_shapes])
    max_new_tokens = max([t[1] for t in warmup_shapes])

    vllm_sampling_params = SamplingParams(
        max_tokens=max_new_tokens,
        temperature=0,
        logprobs=0,  # return logprobs of generated tokens only
        ignore_eos=True
    )

    vllm_results = generate_spyre_vllm_output(
        model=model,
        prompts=prompts,
        warmup_shapes=warmup_shapes,
        max_model_len=2048,
        block_size=2048,
        sampling_params=vllm_sampling_params,
        tensor_parallel_size=1,
        backend=backend
    )

    hf_results = generate_hf_output(
        model=model,
        prompts=prompts,
        max_new_tokens=max_new_tokens)

    # for prompts longer than max_prompt_length, the corresponding output in
    # hf_results is reset to 'empty' in order to create the expected output for vLLM
    hf_tokenizer = AutoTokenizer.from_pretrained(model)
    for prompt_index, prompt in enumerate(prompts):
        hf_input_tokens = hf_tokenizer(prompt, return_tensors="pt").input_ids
        if len(hf_input_tokens[0]) > max_prompt_length:
            hf_results[prompt_index] = {'text': '', 'token_ids': (), 'tokens': (), 'logprobs': ()}

    compare_results(
        model=model,
        prompts=prompts,
        warmup_shapes=warmup_shapes,
        tensor_parallel_size=1,
        backend=backend,
        vllm_results=vllm_results,
        hf_results=hf_results)

70 changes: 70 additions & 0 deletions tests/spyre/test_spyre_seed.py
@@ -0,0 +1,70 @@
"""Verification of seeded random sampling to be deterministic
Run `pytest tests/spyre/test_spyre_seed.py`.
"""

from vllm import SamplingParams

from spyre_util import generate_spyre_vllm_output

import pytest
import math


@pytest.mark.parametrize("model", ["/models/llama-194m"])
@pytest.mark.parametrize("prompt", ["Provide a list of instructions for preparing chicken soup for a family of four."])
@pytest.mark.parametrize("temperature", [0.1, 1.0])
@pytest.mark.parametrize("seed", [42])
@pytest.mark.parametrize("warmup_shape", [(64,20,4),(64,20,8),(128,20,4),(128,20,8)]) # (prompt_length/new_tokens/batch_size)
@pytest.mark.parametrize("backend", ["eager"]) #, "inductor", "sendnn_decoder"])
def test_seed(
        model: str,
        prompt: str,
        temperature: float,
        seed: int,
        warmup_shape: tuple([int, int, int]),
        backend: str,
) -> None:
    '''
    The warmup is based on a single shape. After the warmup,
    output is generated for one request with 16 identical prompts
    using random sampling (non-zero temperature) in combination
    with a seed. The generated output, including text, token ids,
    and logprobs, is verified to be identical for all 16 sequences.
    '''

    max_new_tokens = warmup_shape[1]

    prompts = [prompt] * 16

    vllm_sampling_params = SamplingParams(
        max_tokens=max_new_tokens,
        temperature=temperature,
        logprobs=0,  # return logprobs of generated tokens only
        ignore_eos=True,
        seed=seed
    )

    vllm_results = generate_spyre_vllm_output(
        model=model,
        prompts=prompts,
        warmup_shapes=[warmup_shape],
        max_model_len=2048,
        block_size=2048,
        sampling_params=vllm_sampling_params,
        tensor_parallel_size=1,
        backend=backend
    )

    # compare all generated outputs against the first generated output
    for vllm_result in vllm_results:
        assert vllm_result['text'] == vllm_results[0]['text']

        # compare logprobs for all tokens between the current and the first sequence
        assert len(vllm_result['logprobs']) == len(vllm_results[0]['logprobs'])
        for token_id, logprob, token_id_0, logprob_0 in zip(
                vllm_result['token_ids'], vllm_result['logprobs'],
                vllm_results[0]['token_ids'], vllm_results[0]['logprobs']):
            assert token_id == token_id_0
            assert math.isclose(logprob, logprob_0, rel_tol=0.1)


