diff --git a/token_benchmark_ray.py b/token_benchmark_ray.py
index 63216b1..7f9f4bd 100644
--- a/token_benchmark_ray.py
+++ b/token_benchmark_ray.py
@@ -24,7 +24,7 @@
 )
 from tqdm import tqdm
 
-from transformers import LlamaTokenizerFast
+from transformers import AutoTokenizer
 
 def get_token_throughput_latencies(
     model: str,
@@ -34,6 +34,7 @@ def get_token_throughput_latencies(
     stddev_output_tokens: int,
     additional_sampling_params: Optional[Dict[str, Any]] = None,
     num_concurrent_requests: int = 1,
+    tokenizer: str = "hf-internal-testing/llama-tokenizer",
     max_num_completed_requests: int = 500,
     test_timeout_s=90,
     llm_api="openai",
@@ -60,9 +61,7 @@ def get_token_throughput_latencies(
     """
     random.seed(11111)
 
-    tokenizer = LlamaTokenizerFast.from_pretrained(
-        "hf-internal-testing/llama-tokenizer"
-    )
+    tokenizer = AutoTokenizer.from_pretrained(tokenizer)
     get_token_length = lambda text: len(tokenizer.encode(text))
 
     if not additional_sampling_params:
@@ -292,6 +291,7 @@ def run_token_benchmark(
     additional_sampling_params: str,
     results_dir: str,
     user_metadata: Dict[str, Any],
+    tokenizer: str
 ):
     """
     Args:
@@ -327,6 +327,7 @@ def run_token_benchmark(
         stddev_output_tokens=stddev_output_tokens,
         num_concurrent_requests=num_concurrent_requests,
         additional_sampling_params=json.loads(additional_sampling_params),
+        tokenizer=tokenizer
     )
 
     if results_dir:
@@ -462,6 +463,14 @@ def run_token_benchmark(
             "name=foo,bar=1. These will be added to the metadata field of the results. "
         ),
     )
+args.add_argument(
+    "--tokenizer",
+    type=str,
+    default="hf-internal-testing/llama-tokenizer",
+    help=(
+        "Tokenizer to use for counting tokens"
+    ),
+)
 
 if __name__ == "__main__":
     env_vars = dict(os.environ)
@@ -488,4 +497,5 @@ def run_token_benchmark(
         additional_sampling_params=args.additional_sampling_params,
         results_dir=args.results_dir,
         user_metadata=user_metadata,
+        tokenizer=args.tokenizer
     )
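
For reference, a minimal sketch of the token-counting path after this change: the new `tokenizer` argument is forwarded to `AutoTokenizer.from_pretrained`, so any Hugging Face tokenizer id can be used for counting, with "hf-internal-testing/llama-tokenizer" kept as the default. This is an illustration only, not code from the patch; `make_token_counter` is a hypothetical helper and the snippet assumes `transformers` is installed.

    # Illustration only (not part of the patch).
    from transformers import AutoTokenizer

    def make_token_counter(tokenizer_id: str = "hf-internal-testing/llama-tokenizer"):
        # AutoTokenizer resolves the concrete tokenizer class from the hub config,
        # so non-Llama models are counted with their own vocabulary.
        tokenizer = AutoTokenizer.from_pretrained(tokenizer_id)
        return lambda text: len(tokenizer.encode(text))

    get_token_length = make_token_counter()
    print(get_token_length("Hello, world!"))  # token count under the chosen tokenizer

On the command line the same value is supplied through the new `--tokenizer` flag; when it is omitted, behaviour matches the previously hard-coded Llama tokenizer.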