Skip to content

Commit

Permalink
Merge pull request #2 from future-xy/add-datasets
Browse files Browse the repository at this point in the history
feat: generate prompts from datasets
  • Loading branch information
future-xy authored Feb 21, 2025
2 parents f60cbbd + f769f03 commit b1a1526
Show file tree
Hide file tree
Showing 16 changed files with 499 additions and 22 deletions.
6 changes: 3 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ tracestorm --model "Qwen/Qwen2.5-1.5B-Instruct" --pattern azure_code
#### Example Command for Loading Prompts from Datasets

```bash
tracestorm --model "Qwen/Qwen2.5-1.5B-Instruct" --duration 30 --datasets-config-file ./examples/datasets_config_hf.json
tracestorm --model "Qwen/Qwen2.5-1.5B-Instruct" --duration 30 --datasets-config ./examples/datasets_config_hf.json
```


Expand All @@ -60,7 +60,7 @@ tracestorm --model "Qwen/Qwen2.5-1.5B-Instruct" --duration 30 --datasets-config-
- Refer to `./examples/datasets_config_local.json` for an example configuration.
- If you want to test loading from local files, please run `./examples/save_test_datasets.py` first to download and save two datasets.

2. Remote datasets from Hugging Face
2. Remote datasets from Hugging Face
- Refer to `./examples/datasets_config_hf.json` for an example configuration.

**Sorting Strategy**: Defines how prompts from multiple datasets are ordered
Expand All @@ -85,6 +85,6 @@ Please check `./examples/datasets_config_default.json` for required fields in `d
- `--base-url`: Optional. OpenAI Base URL (default is `http://localhost:8000/v1`).
- `--api-key`: Optional. OpenAI API Key (default is `none`).
- `--seed`: Optional. Random seed for trace pattern reproducibility (default is `none`).
- `--datasets-config-file`: Optional. Configuration file for loading prompt messages from provided datasets. Uses `DEFAULT_MESSAGES` is not specified.
- `--datasets-config`: Optional. Configuration file for loading prompt messages from provided datasets. Uses `DEFAULT_MESSAGES` if not specified.

Make sure to adjust the parameters according to your testing needs!
14 changes: 14 additions & 0 deletions examples/datasets_config_default.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
{
"sort_strategy": "random",
"dataset_1": {
"file_name": "",
"prompt_field": "",
"select_ratio": 1,
"split": "train"
},
"dataset_2": {
"file_name": "",
"prompt_field": "",
"select_ratio": 1
}
}
15 changes: 15 additions & 0 deletions examples/datasets_config_hf.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
{
"sort_strategy": "original",
"dataset_1": {
"file_name": "hf://datasets/fka/awesome-chatgpt-prompts/prompts.csv",
"prompt_field": "prompt",
"select_ratio": 2,
"split": "train"
},
"dataset_2": {
"file_name": "MAsad789565/Coding_GPT4_Data",
"prompt_field": "user",
"select_ratio": 8,
"split": "train"
}
}
13 changes: 13 additions & 0 deletions examples/datasets_config_local.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
{
"sort_strategy": "random",
"dataset_1": {
"file_name": "Conversational_dataset.jsonl",
"prompt_field": "messages",
"select_ratio": 6
},
"dataset_2": {
"file_name": "~/.cache/tracestorm/GPT4_coding_sample.csv",
"prompt_field": "user",
"select_ratio": 4
}
}
28 changes: 28 additions & 0 deletions examples/save_test_datasets.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
import os

import pandas as pd

from tracestorm.constants import DEFAULT_DATASET_FOLDER


def prepare_test_datasets():
    """Download two sample datasets and cache them locally for tests.

    Fetches a GPT-4 coding dataset (JSON) and a conversational dataset
    (JSONL) from Hugging Face, then writes them under
    DEFAULT_DATASET_FOLDER in two different on-disk formats (CSV and
    JSONL) so loaders can be exercised against both.
    """
    coding_df = pd.read_json(
        "hf://datasets/MAsad789565/Coding_GPT4_Data/Data/GPT_4_Coding.json"
    )
    conversational_df = pd.read_json(
        "hf://datasets/olathepavilion/Conversational-datasets-json/Validation.jsonl",
        lines=True,
    )

    # Make sure the target cache directory exists before writing.
    os.makedirs(DEFAULT_DATASET_FOLDER, exist_ok=True)

    # Persist in two different formats to cover both CSV and JSONL loading.
    coding_df.to_csv(
        os.path.join(DEFAULT_DATASET_FOLDER, "GPT4_coding_sample.csv"),
        index=False,
    )
    conversational_df.to_json(
        os.path.join(DEFAULT_DATASET_FOLDER, "Conversational_dataset.jsonl"),
        orient="records",
        lines=True,
    )


if __name__ == "__main__":
    prepare_test_datasets()
7 changes: 6 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,11 @@ authors = [
]

dependencies = [
"datasets>=3.3.2",
"openai>=1.58.0",
"numpy>=1.26.4",
"pandas>=2.2.3",
"requests>=2.31.0",
"seaborn>=0.13.2",
"matplotlib>=3.9",
"click>=8.1.8"
Expand Down Expand Up @@ -43,4 +45,7 @@ ignore = ["B007"] # Loop control variable not used within loop body

[tool.isort]
use_parentheses = true
skip_gitignore = true
skip_gitignore = true

[tool.setuptools]
packages = { find = { exclude = ["examples"] } }
3 changes: 3 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
click>=8.1.8
datasets>=3.3.2
matplotlib>=3.9
numpy>=1.26.4
openai>=1.58.0
pandas>=2.2.3
requests>=2.31.0
seaborn>=0.13.2
2 changes: 1 addition & 1 deletion tests/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ def test_cli_invalid_pattern(self):
)

self.assertNotEqual(result.exit_code, 0)
self.assertIn("Invalid pattern", result.output)
self.assertIn("Invalid value for '--pattern'", result.output)


if __name__ == "__main__":
Expand Down
63 changes: 63 additions & 0 deletions tests/test_data_loader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
import os
import unittest

import pandas as pd

from tracestorm.constants import DEFAULT_DATASET_FOLDER
from tracestorm.data_loader import Dataset, load_datasets


class TestDataLoader(unittest.TestCase):
    """Tests for tracestorm.data_loader.load_datasets (remote and local files)."""

    def _assert_loaded(self, datasets, sort, expected_sort, expected_ratios):
        """Shared checks for a load_datasets result.

        Verifies the container type, the returned sort strategy, the
        number of datasets, each dataset's type and select_ratio, and
        that every dataset is non-empty.
        """
        self.assertIsInstance(datasets, list)
        self.assertEqual(sort, expected_sort)
        self.assertEqual(len(datasets), len(expected_ratios))
        for dataset, expected_ratio in zip(datasets, expected_ratios):
            self.assertIsInstance(dataset, Dataset)
            self.assertEqual(dataset.select_ratio, expected_ratio)
            self.assertGreater(dataset.length, 0)

    def test_remote_files(self):
        """
        Test loading datasets from hugging face.
        There are 2 datasets, testing for:
        1. loading with datasets.load_dataset
        2. loading csv format with pandas
        """
        datasets, sort = load_datasets("examples/datasets_config_hf.json")
        self._assert_loaded(datasets, sort, "original", [2, 8])

    def test_local_files(self):
        """Test loading from local files in two formats (CSV and JSONL)."""
        os.makedirs(DEFAULT_DATASET_FOLDER, exist_ok=True)

        # testing datasets (downloaded from Hugging Face, then saved locally)
        df1 = pd.read_json(
            "hf://datasets/MAsad789565/Coding_GPT4_Data/Data/GPT_4_Coding.json"
        )
        df2 = pd.read_json(
            "hf://datasets/olathepavilion/Conversational-datasets-json/Validation.jsonl",
            lines=True,
        )

        # test with different file formats
        path1 = os.path.join(DEFAULT_DATASET_FOLDER, "GPT4_coding_sample.csv")
        path2 = os.path.join(
            DEFAULT_DATASET_FOLDER, "Conversational_dataset.jsonl"
        )

        # save the pre-processed dataset to the default folder for test
        df1.to_csv(path1, index=False)
        df2.to_json(path2, orient="records", lines=True)

        datasets, sort = load_datasets("examples/datasets_config_local.json")
        self._assert_loaded(datasets, sort, "random", [6, 4])


if __name__ == "__main__":
    unittest.main()
45 changes: 39 additions & 6 deletions tracestorm/cli.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import os
from typing import Tuple
from typing import Optional, Tuple

import click

from tracestorm.core import run_load_test
from tracestorm.data_loader import load_datasets
from tracestorm.logger import init_logger
from tracestorm.trace_generator import (
AzureTraceGenerator,
Expand All @@ -14,13 +15,13 @@
logger = init_logger(__name__)

# Valid patterns
SYNTHETIC_PATTERNS = {"uniform"}
SYNTHETIC_PATTERNS = {"uniform", "poisson", "random"}
AZURE_PATTERNS = {"azure_code", "azure_conv"}
VALID_PATTERNS = SYNTHETIC_PATTERNS | AZURE_PATTERNS


def create_trace_generator(
pattern: str, rps: int, duration: int
pattern: str, rps: int, duration: int, seed: Optional[int] = None
) -> Tuple[TraceGenerator, str]:
"""
Create appropriate trace generator based on pattern and validate parameters.
Expand All @@ -29,6 +30,7 @@ def create_trace_generator(
pattern: Pattern for trace generation
rps: Requests per second (only for synthetic patterns)
duration: Duration in seconds (only for synthetic patterns)
seed: Random seed for reproducibility of trace patterns
Returns:
Tuple of (TraceGenerator instance, Warning message or empty string)
Expand All @@ -50,7 +52,9 @@ def create_trace_generator(
raise ValueError(
"Duration must be non-negative for synthetic patterns"
)
return SyntheticTraceGenerator(rps, pattern, duration), warning_msg
return SyntheticTraceGenerator(
rps, pattern, duration, seed
), warning_msg

# Azure patterns
if rps != 1:
Expand All @@ -75,6 +79,7 @@ def create_trace_generator(
@click.option(
"--pattern",
default="uniform",
type=click.Choice(sorted(VALID_PATTERNS), case_sensitive=False),
help=f"Pattern for generating trace. Valid patterns: {sorted(VALID_PATTERNS)}",
)
@click.option(
Expand All @@ -83,6 +88,12 @@ def create_trace_generator(
default=10,
help="Duration in seconds (only used with synthetic patterns)",
)
@click.option(
"--seed",
type=int,
default=None,
help="Random seed for reproducibility of trace patterns",
)
@click.option(
"--subprocesses", type=int, default=1, help="Number of subprocesses"
)
Expand All @@ -98,21 +109,43 @@ def create_trace_generator(
default=lambda: os.environ.get("OPENAI_API_KEY", "none"),
help="OpenAI API Key",
)
def main(model, rps, pattern, duration, subprocesses, base_url, api_key):
@click.option(
"--datasets-config", default=None, help="Config file for datasets"
)
def main(
model,
rps,
pattern,
duration,
seed,
subprocesses,
base_url,
api_key,
datasets_config,
):
"""Run trace-based load testing for OpenAI API endpoints."""
try:
trace_generator, warning_msg = create_trace_generator(
pattern, rps, duration
pattern, rps, duration, seed
)
if warning_msg:
logger.warning(warning_msg)

if datasets_config is None:
datasets = []
sort_strategy = None
else:
datasets, sort_strategy = load_datasets(datasets_config)

_, result_analyzer = run_load_test(
trace_generator=trace_generator,
model=model,
subprocesses=subprocesses,
base_url=base_url,
api_key=api_key,
datasets=datasets,
sort_strategy=sort_strategy,
seed=seed,
)

print(result_analyzer)
Expand Down
4 changes: 4 additions & 0 deletions tracestorm/constants.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import os

AZURE_REPO_URL = "Azure/AzurePublicDataset"

AZURE_DATASET_PATHS = {
Expand All @@ -11,3 +13,5 @@
DEFAULT_SUBPROCESSES = 1

DEFAULT_MESSAGES = "Tell me a story"

DEFAULT_DATASET_FOLDER = os.path.expanduser("~/.cache/tracestorm")
16 changes: 14 additions & 2 deletions tracestorm/core.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import multiprocessing
from typing import List, Tuple
from typing import List, Optional, Tuple

from tracestorm.logger import init_logger
from tracestorm.request_generator import generate_request
Expand All @@ -17,6 +17,9 @@ def run_load_test(
subprocesses: int,
base_url: str,
api_key: str,
datasets: List,
sort_strategy: Optional[str] = None,
seed: Optional[int] = None,
) -> Tuple[List[Tuple], ResultAnalyzer]:
"""
Run load test with given configuration.
Expand All @@ -27,6 +30,9 @@ def run_load_test(
subprocesses: Number of subprocesses to use
base_url: Base URL for API calls
api_key: API key for authentication
datasets: List of datasets to generate prompts
sort_strategy: Sorting strategy for prompts in datasets.
seed: Random seed for sorting.
Returns:
Tuple of (List of results, ResultAnalyzer instance)
Expand All @@ -38,7 +44,13 @@ def run_load_test(
logger.warning("No requests to process. Trace is empty.")
return [], ResultAnalyzer()

requests = generate_request(model, total_requests)
requests = generate_request(
model_name=model,
nums=total_requests,
datasets=datasets,
sort_strategy=sort_strategy,
seed=seed,
)
ipc_queue = multiprocessing.Queue()
processes = []

Expand Down
Loading

0 comments on commit b1a1526

Please sign in to comment.