Use pytest, enable returning multiple prompts
helena-intel committed Mar 28, 2024
1 parent ae86dc0 commit 1c63f75
Showing 6 changed files with 147 additions and 143 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/test.yml
@@ -30,9 +30,9 @@ jobs:
          python-version: "3.10"
      - name: Install package and dependencies
        run: |
          pip install pytest parameterized sentencepiece transformers protobuf .
          pip install .
          pip install -r tests/requirements.txt
          pip freeze
      - name: Run test
        run: |
          python -m pytest tests
10 changes: 8 additions & 2 deletions README.md
@@ -13,7 +13,7 @@ To generate one or a few prompts, or to test the functionality, you can use the
pip install git+https://github.com/helena-intel/test-prompt-generator.git transformers
```

Some tokenizers may require additional dependencies. For example, ChatGLM3 requires `sentencepiece`.
Some tokenizers may require additional dependencies, for example `sentencepiece` or `protobuf`.

## Usage

@@ -29,6 +29,7 @@ Prompts are generated by truncating a given source text at the provided number of
A prefix can optionally be prepended to the text, to create prompts like "Please summarize the following text: [text]". The
prompts are returned by the function/command line app, and can also optionally be saved to a .jsonl file.


### Python API

#### Basic usage
@@ -69,6 +70,11 @@ prompt = generate_prompt(
)
```

> NOTE: When specifying one token size, the prompt is returned as a string, making it easy to copy and use in a test scenario
> where you need a single prompt. When specifying multiple token sizes, a list of dictionaries with the prompts is returned.
> The output file is always in .jsonl format, regardless of the number of generated prompts.
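
To make the return behavior concrete, here is a minimal sketch (the `opt` preset and the keyword names are taken from this commit's diff; the exact import path is assumed):

```python
from test_prompt_generator import generate_prompt

# One token size: a plain string prompt is returned.
prompt = generate_prompt(tokenizer_id="opt", num_tokens=32)

# Several token sizes: a list of dicts is returned, one per size, each with
# "prompt", "model_id" and "token_size" keys, and all prompts are written
# to the .jsonl file.
prompts = generate_prompt(tokenizer_id="opt", num_tokens=[16, 32], output_file="prompts.jsonl", overwrite=True)
```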


### Command Line App

```shell
@@ -86,7 +92,7 @@ options:
preset tokenizer id, model_id from Hugging Face hub, or path to local directory with tokenizer files. Options for presets are: ['bert', 'bloom', 'gemma', 'chatglm3', 'falcon', 'gpt-neox',
'llama', 'magicoder', 'mistral', 'opt', 'phi-2', 'pythia', 'roberta', 'qwen', 'starcoder', 't5']
-n NUM_TOKENS, --num_tokens NUM_TOKENS
Number of tokens the generated prompt should have. To specify multiple token sizes, use e.g. `-n 16 32` and include `--output_file`
Number of tokens the generated prompt should have. To specify multiple token sizes, use e.g. `-n 16 32`
-p PREFIX, --prefix PREFIX
Optional: prefix that the prompt should start with. Example: 'Translate to Dutch:'
-o OUTPUT_FILE, --output_file OUTPUT_FILE
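As a usage sketch matching the options above (the `test-prompt-generator` command name and the `-t` tokenizer flag are assumptions; only the option descriptions are visible in this excerpt):

```shell
test-prompt-generator -t opt -n 16 32 -p "Translate to Dutch:" -o prompts.jsonl
```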
53 changes: 25 additions & 28 deletions test_prompt_generator/test_prompt_generator.py
@@ -52,21 +52,22 @@ def generate_prompt(
    verbose: bool = False,
    source_text_file: str = None,
    source_text: str = None,
    return_type: str = None,
):
    ### Validate inputs
    if isinstance(num_tokens, int):
        num_tokens = [
            num_tokens,
        ]
        num_tokens = [num_tokens]
    if source_text == "":
        source_text = None
    if source_text_file == "":
        source_text_file = None
    if source_text_file is not None and source_text is not None:
        raise ValueError("Only one of `source_text` or `source_text_file` should be provided.")

    if len(num_tokens) > 1 and output_file is None:
        raise ValueError("When generating multiple prompts, output_file should be provided.")
    if return_type is None:
        return_type = "string" if len(num_tokens) == 1 else "dict"
    if return_type not in ("string", "dict"):
        raise ValueError("return_type should be 'string', 'dict' or None")

    ### Load tokenizer
    if tokenizer_id in _preset_tokenizers:
@@ -88,8 +89,7 @@
        prefix_num_tokens = len(prefix_tokens)
        if prefix_num_tokens > min(num_tokens):
            logging.warning(
                f"Requested number of tokens {min(num_tokens)} is smaller than "
                f"number of prefix tokens: {prefix_num_tokens}"
                f"Requested number of tokens {min(num_tokens)} is smaller than " f"number of prefix tokens: {prefix_num_tokens}"
            )

        source_text = prefix.strip() + " " + source_text
@@ -107,10 +107,9 @@
    tokens = inputs["input_ids"]
    total_tokens = len(tokens)
    if max(num_tokens) > total_tokens:
        raise ValueError(
            f"Cannot generate prompt with {max(num_tokens)} tokens; the source text contains {total_tokens} tokens."
        )
        raise ValueError(f"Cannot generate prompt with {max(num_tokens)} tokens; the source text contains {total_tokens} tokens.")

    prompt_dicts = []
    for i, tok in enumerate(num_tokens):
        num_tokens_from_source = tok
        if tokenizer("hello")["input_ids"][-1] in tokenizer.all_special_ids:
@@ -139,33 +138,30 @@
print("prompt", repr(prompt))
print("prompt_tokens", prompt_tokens["input_ids"])
print("source_tokens", tokens[:num_tokens_from_source])
raise RuntimeError(
f"Expected {tok} tokens, got {len(prompt_tokens['input_ids'])}. Tokenizer: {tokenizer_id}"
)
raise RuntimeError(f"Expected {tok} tokens, got {len(prompt_tokens['input_ids'])}. Tokenizer: {tokenizer_id}")

        ### Write output file

        prompt_dict = {
            "prompt": prompt,
            "model_id": tokenizer_id,
            "token_size": tok,
        }
        prompt_dicts.append(prompt_dict)

    jsonl_result = "\n".join(json.dumps(item) for item in prompt_dicts)
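    # jsonl_result holds one JSON object per line: the JSON Lines (.jsonl) format.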

        if output_file is not None:
            if i == 0:
                if (Path(output_file).exists()) and (not overwrite):
                    raise FileExistsError(f"{output_file} already exists. Set overwrite to allow overwriting.")
                Path(output_file).parent.mkdir(parents=True, exist_ok=True)
    if output_file is not None:
        if (Path(output_file).exists()) and (not overwrite):
            raise FileExistsError(f"{output_file} already exists. Set overwrite to allow overwriting.")
        Path(output_file).parent.mkdir(parents=True, exist_ok=True)

                with open(output_file, "w") as f:
                    json.dump(prompt_dict, f, ensure_ascii=False)
            else:
                with open(output_file, "a") as f:
                    f.write("\n")
                    json.dump(prompt_dict, f, ensure_ascii=False)
        with open(output_file, "w") as f:
            f.write(f"{jsonl_result}\n")

    if len(num_tokens) == 1:
        return prompt
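    # New behavior: "string" returns plain prompt text (newline-joined for
    # multiple sizes); "dict" returns the prompt dict(s) built above.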
    if return_type == "string":
        return prompt if len(num_tokens) == 1 else "\n".join(item["prompt"] for item in prompt_dicts)
    else:
        return prompt_dict if len(num_tokens) == 1 else prompt_dicts


def main():
@@ -218,7 +214,7 @@ def main():
    if args.verbose:
        logging.info(f"Command line arguments: {args}")

    return generate_prompt(
    result = generate_prompt(
        tokenizer_id=args.tokenizer,
        num_tokens=args.num_tokens,
        prefix=args.prefix,
@@ -227,6 +223,7 @@
        verbose=args.verbose,
        source_text_file=args.file,
    )
    return result


if __name__ == "__main__":
3 changes: 3 additions & 0 deletions tests/pytest.ini
@@ -0,0 +1,3 @@
[pytest]
markers =
    quicktest: Runs a quick test without all parameterized options
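
For reference, a test opts into this marker as in the following sketch (generic pytest usage, not code from this commit; the import path is assumed):

```python
import pytest

from test_prompt_generator import generate_prompt


@pytest.mark.quicktest
def test_single_prompt_is_string():
    # Runs in the quick suite, selected with: python -m pytest -m quicktest tests
    prompt = generate_prompt(tokenizer_id="opt", num_tokens=32)
    assert isinstance(prompt, str)
```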
5 changes: 5 additions & 0 deletions tests/requirements.txt
@@ -0,0 +1,5 @@
transformers
sentencepiece
protobuf
pytest
parameterized
