Script to automatically split off eval set #1525
@@ -0,0 +1,167 @@
# Copyright 2024 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0

import contextlib
import json
import logging
import os
import re
from typing import Optional

import composer.utils as utils
import datasets as hf_datasets
import numpy as np

from llmfoundry.data.finetuning.tasks import maybe_safe_download_hf_data

DELTA_JSONL_REGEX = re.compile(r"^tmp-t$")
REMOTE_OBJECT_STORE_FILE_REGEX = re.compile(
    r"^((s3|oci|gs):\/\/|dbfs:\/Volumes\/)[/a-zA-Z0-9 ()_\-.]+$"
)
HF_REGEX = re.compile(r"^[/a-zA-Z0-9 ()_\-.]+$")

TEMP_DIR = "tmp-split"

log = logging.getLogger(__name__)


def maybe_download_data_as_json(data_path_folder: str, data_path_split: str) -> str:
""" | ||
Prepares dataset as a local JSONL file. Downloads from remote object store or HF if necessary. | ||
|
||
This function is intended to be invoked by DBX Finetuning. | ||
Thus, it assumes the provided data is in one of three formats: | ||
1. A Delta table converted to JSONL at 'tmp-t/{data_path_split}-00000-of-00001.jsonl` | ||
using the 'llmfoundry.scripts.convert_delta_to_json.py' script. | ||
2. A JSONL stored as a remote object store file (e.g. S3, OCI, GCS) | ||
3. A Hugging Face dataset | ||
|
||
Args: | ||
data_path_folder (str): Path to the training dataset folder | ||
data_path_split (str): Data split | ||
|
||
Returns: | ||
str: Path to the training dataset | ||
""" | ||
os.makedirs(TEMP_DIR, exist_ok=True) | ||
|
||
if DELTA_JSONL_REGEX.match(data_path_folder): | ||
log.info(f"Dataset is converted from Delta table. Using local file {data_path_folder}") | ||
data_path = os.path.join(data_path_folder, f"{data_path_split}-00000-of-00001.jsonl") | ||
|
||
    elif REMOTE_OBJECT_STORE_FILE_REGEX.match(data_path_folder):
        log.info(
            f"Downloading dataset from remote object store: {data_path_folder}/{data_path_split}.jsonl"
        )
        remote_path = f"{data_path_folder}/{data_path_split}.jsonl"
        data_path = os.path.join(TEMP_DIR, f"{data_path_split}.jsonl")
        utils.get_file(remote_path, data_path, overwrite=True)

    elif HF_REGEX.match(data_path_folder):
        log.info(
            f"Downloading dataset from Hugging Face: {data_path_folder} with split {data_path_split}"
        )
        # TODO: maybe add support for HF kwargs
        local_hf_path = maybe_safe_download_hf_data(data_path_folder)
        # Convert the dataset split to JSONL
        dataset = hf_datasets.load_dataset(
            local_hf_path,
            split=data_path_split,
        )
        data_path = os.path.join(TEMP_DIR, f"{data_path_split}.jsonl")
        with open(data_path, "w") as f:
            for example in dataset:
                f.write(json.dumps(example) + "\n")

    else:
        raise ValueError(
            f"Encountered unknown data path format when splitting dataset: {data_path_folder} with split {data_path_split}"
        )

    if not os.path.exists(data_path):
        raise FileNotFoundError(
            f"Expected dataset file at {data_path} for splitting, but it does not exist."
        )

    return data_path
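To make the routing concrete, here is a small sketch of which branch each input format hits (the bucket and dataset names below are hypothetical):

# Hypothetical inputs and the branch of maybe_download_data_as_json they trigger:
#   "tmp-t"                    -> Delta-converted local JSONL: tmp-t/{split}-00000-of-00001.jsonl
#   "s3://my-bucket/ft-data"   -> remote object store: downloads {split}.jsonl into tmp-split/
#   "mosaicml/instruct-v3"     -> Hugging Face: safe-downloads, then re-serializes the split to JSONL
train_jsonl = maybe_download_data_as_json("s3://my-bucket/ft-data", "train")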


@contextlib.contextmanager
def temp_seed(seed: int):
    """Temporarily seed NumPy's global RNG, restoring the previous state on exit."""
    log.info(f"Setting random seed to {seed}")
    state = np.random.get_state()
    np.random.seed(seed)
    try:
        yield
    finally:
        np.random.set_state(state)
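A minimal usage sketch: draws inside the context are reproducible, and the surrounding global RNG state is left untouched:

with temp_seed(42):
    a = np.random.rand(3)
with temp_seed(42):
    b = np.random.rand(3)
assert np.allclose(a, b)  # same seed, same draws; prior RNG state restored after each block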


def split_examples(
    data_path: str,
    output_path: str,
    eval_split_ratio: float,
    max_eval_samples: Optional[int] = None,
    seed: Optional[int] = None,
) -> None:
    """
    Splits the dataset into training and evaluation sets.

    Args:
        data_path (str): Path to the training dataset (local JSONL file)
        output_path (str): Directory to save the split dataset
        eval_split_ratio (float): Ratio of the dataset to use for evaluation. The remainder will be used for training
        max_eval_samples (int): Maximum number of samples to include in the eval set. If None, all eval_split_ratio * train_dataset_size samples will be used
        seed (int): Random seed for splitting the dataset
    """
    os.makedirs(output_path, exist_ok=True)

    # First pass: count the total number of lines and determine the eval sample size
    total_lines = 0
    with open(data_path, "r") as infile:
        for _ in infile:
            total_lines += 1
    sample_size = int(eval_split_ratio * total_lines)
    if max_eval_samples is not None:
        sample_size = min(sample_size, max_eval_samples)

    with temp_seed(seed) if seed is not None else contextlib.nullcontext():
        random_numbers = np.random.rand(total_lines)
        sample_indices = set(np.argsort(random_numbers)[:sample_size])

    # Second pass: write each line to the eval or train file based on the sampled indices
    with open(data_path, "r") as infile, open(
        os.path.join(output_path, "train.jsonl"), "w"
    ) as train_outfile, open(os.path.join(output_path, "eval.jsonl"), "w") as eval_outfile:
        for idx, line in enumerate(infile):
            if idx in sample_indices:
                eval_outfile.write(line)
            else:
                train_outfile.write(line)

    log.info(
        f"Split {data_path} into train set of size {total_lines - sample_size} and eval set of size {sample_size}."
    )
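The sampling above draws one uniform number per line and keeps the indices of the sample_size smallest values, which selects eval indices uniformly without replacement in a single vectorized pass. A quick sketch of the equivalence (not part of the PR):

# Comparable to np.random.choice(10, size=3, replace=False), up to the RNG stream used:
with temp_seed(7):
    idxs = set(np.argsort(np.random.rand(10))[:3])
assert len(idxs) == 3 and all(0 <= int(i) < 10 for i in idxs)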


def split_eval_set_from_args(
    data_path_folder: str,
    data_path_split: str,
    output_path: str,
    eval_split_ratio: float,
    max_eval_samples: Optional[int] = None,
    seed: Optional[int] = None,
) -> None:
    """
    Materializes the dataset as a local JSONL file, then splits it into train and eval sets.

    Args:
        data_path_folder (str): Path to the training dataset folder
        data_path_split (str): Data split
        output_path (str): Directory to save the split dataset
        eval_split_ratio (float): Ratio of the dataset to use for evaluation. The remainder will be used for training
        max_eval_samples (int): Maximum number of samples to include in the eval set. If None, all eval_split_ratio * train_dataset_size samples will be used
        seed (int): Random seed for splitting the dataset
    """
    data_path = maybe_download_data_as_json(data_path_folder, data_path_split)
    split_examples(data_path, output_path, eval_split_ratio, max_eval_samples, seed)
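End to end, a programmatic call might look like the following (the bucket and output paths are hypothetical):

# Downloads s3://my-bucket/ft-data/train.jsonl, then writes train.jsonl and eval.jsonl
# under /tmp/split, holding out ~10% of lines (capped at 1000) for evaluation.
split_eval_set_from_args(
    data_path_folder="s3://my-bucket/ft-data",
    data_path_split="train",
    output_path="/tmp/split",
    eval_split_ratio=0.1,
    max_eval_samples=1000,
    seed=42,
)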

@@ -702,7 +702,72 @@ def state_dict(self, num_samples: int,
            num_samples=num_samples,
            from_beginning=from_beginning,
        )


def maybe_safe_download_hf_data(
    dataset_name: str,
    hf_kwargs: Optional[dict[str, Any]] = None,
) -> str:
    """
    Download a HuggingFace dataset locally if it does not already exist.

    Args:
        dataset_name (str): The name of the HuggingFace dataset to use, or a local directory
            that already contains the dataset files.
        hf_kwargs (dict, optional): Additional kwargs to pass to `datasets.load_dataset`.

    Returns:
        str: The local path to the dataset.
    """
    if hf_kwargs is None:
        hf_kwargs = {}

    if not os.path.isdir(dataset_name):
        local_dataset_dir = os.path.join(
            DOWNLOADED_FT_DATASETS_DIRPATH,
            dataset_name,
        )

        if _is_empty_or_nonexistent(dirpath=local_dataset_dir):
            # Safely load the dataset from HF Hub with restricted file types.
            hf_hub.snapshot_download(
                dataset_name,
                repo_type='dataset',
                allow_patterns=[
                    '*' + ext for ext in SUPPORTED_EXTENSIONS
                ],
                token=hf_kwargs.get('token', None),
                revision=hf_kwargs.get('revision', None),
                local_dir_use_symlinks=False,
                local_dir=local_dataset_dir,
            )
            if _is_empty_or_nonexistent(dirpath=local_dataset_dir):
                log.error("Failed to safely load the dataset from HF Hub.")
                raise InvalidFileExtensionError(
                    dataset_name,
                    SUPPORTED_EXTENSIONS,
                )
        # Set dataset_name to the downloaded location.
        dataset_name = local_dataset_dir

    # Ensure dataset_name is a local directory path (using abspath to avoid confusion).
    dataset_name = os.path.abspath(dataset_name)

    # Check that the directory contains only allowed file types.
    dataset_files = [
        f for _, _, files in os.walk(dataset_name) for f in files
    ]
    if not all(
        Path(f).suffix in SUPPORTED_EXTENSIONS +
        HUGGINGFACE_FOLDER_EXTENSIONS or f == '.gitignore'
        for f in dataset_files
    ):
        log.error("Invalid file extension found in dataset during safe load.")
        raise InvalidFileExtensionError(
            dataset_name,
            SUPPORTED_EXTENSIONS,
        )

    return dataset_name

class DatasetConstructor:

Review comment on the InvalidFileExtensionError raise above: Moved the HF safe-download logic into a separate function so I could reuse it. The code is basically unchanged. Unrelated to the meat of the PR, but per Daniel's request I tried refactoring further to get rid of the try-except block. I don't think it can be done without significant added complexity: you want to barrier/sync before throwing this InvalidFileExtensionError, yet the raise must stay nested inside the download logic so that only rank 0 encounters it. So you can't both separate out this logic and exit gracefully without a try-except. :'(
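To illustrate the constraint the comment describes, here is a minimal sketch (not the PR's code) of the rank-0 download pattern, assuming Composer's dist utilities; the download helper is hypothetical:

from typing import Optional

from composer.utils import dist

def _download_on_rank0_with_barrier(dataset_name: str) -> None:
    # Only rank 0 downloads, so only rank 0 can encounter InvalidFileExtensionError;
    # every rank must reach the barrier before the error surfaces, or the others hang.
    error: Optional[Exception] = None
    if dist.get_global_rank() == 0:
        try:
            _do_safe_download(dataset_name)  # hypothetical helper; may raise
        except Exception as e:
            error = e
    dist.barrier()  # sync all ranks before surfacing the failure
    if error is not None:
        raise error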
@@ -901,50 +966,10 @@ def build_from_hf(
        filtered_dataset = None
        try:
            if safe_load:
-               if not os.path.isdir(dataset_name):
-                   # dataset_name is not a local dir path, download if needed.
-                   local_dataset_dir = os.path.join(
-                       DOWNLOADED_FT_DATASETS_DIRPATH,
-                       dataset_name,
-                   )
-
-                   if _is_empty_or_nonexistent(dirpath=local_dataset_dir):
-                       # Safely load a dataset from HF Hub with restricted file types.
-                       hf_hub.snapshot_download(
-                           dataset_name,
-                           repo_type='dataset',
-                           allow_patterns=[
-                               '*' + ext for ext in SUPPORTED_EXTENSIONS
-                           ],
-                           token=hf_kwargs.get('token', None),
-                           revision=hf_kwargs.get('revision', None),
-                           local_dir_use_symlinks=False,
-                           local_dir=local_dataset_dir,
-                       )
-                       if _is_empty_or_nonexistent(dirpath=local_dataset_dir):
-                           raise InvalidFileExtensionError(
-                               dataset_name,
-                               SUPPORTED_EXTENSIONS,
-                           )
-                   # Set dataset_name to the downloaded location.
-                   dataset_name = local_dataset_dir
-
-               # dataset_name is a local dir path. Use the abspath to prevent confusion.
-               dataset_name = os.path.abspath(dataset_name)
-
-               # Ensure that the local dir contains only allowed file types.
-               dataset_files = [
-                   f for _, _, files in os.walk(dataset_name) for f in files
-               ]
-               if not all(
-                   Path(f).suffix in SUPPORTED_EXTENSIONS +
-                   HUGGINGFACE_FOLDER_EXTENSIONS or f == '.gitignore'
-                   for f in dataset_files
-               ):
-                   raise InvalidFileExtensionError(
-                       dataset_name,
-                       SUPPORTED_EXTENSIONS,
-                   )
+               dataset_name = maybe_safe_download_hf_data(
+                   dataset_name,
+                   hf_kwargs,
+               )

            dataset = hf_datasets.load_dataset(
                dataset_name,

@@ -0,0 +1,54 @@
# Copyright 2024 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0

from argparse import ArgumentParser

from llmfoundry.command_utils import split_eval_set_from_args


if __name__ == "__main__":
    parser = ArgumentParser(
        description="Split training dataset into train and eval sets",
    )
    parser.add_argument(
        "--data_path_folder", required=True, type=str, help="Path to the training dataset folder"
    )
    parser.add_argument(
        "--data_path_split", required=True, type=str, help="Data split of the training dataset"
    )
    parser.add_argument(
        "--output_path",
        required=True,
        type=str,
        help="Path to save the split dataset",
    )
    parser.add_argument(
        "--eval_split_ratio",
        required=False,
        type=float,
        default=0.1,
        help="Ratio of the dataset to use for evaluation. The remainder will be used for training",
    )
    parser.add_argument(
        "--max_eval_samples",
        required=False,
        type=int,
        default=None,
        help="Maximum number of samples to include in the eval set",
    )
    parser.add_argument(
        "--seed",
        required=False,
        type=int,
        default=42,
        help="Random seed for splitting the dataset",
    )
    args = parser.parse_args()
    split_eval_set_from_args(
        data_path_folder=args.data_path_folder,
        data_path_split=args.data_path_split,
        output_path=args.output_path,
        eval_split_ratio=args.eval_split_ratio,
        max_eval_samples=args.max_eval_samples,
        seed=args.seed,
    )
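The CLI form of the same end-to-end call might look like this (the script path and bucket are assumptions, since the file path is not shown in this diff):

python scripts/data_prep/split_eval_set.py \
    --data_path_folder s3://my-bucket/ft-data \
    --data_path_split train \
    --output_path /tmp/split \
    --eval_split_ratio 0.1 \
    --max_eval_samples 1000 \
    --seed 42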
Review comment: The import is formatted this way due to mocking difficulties; see https://bhfsteve.blogspot.com/2012/06/patching-tip-using-mocks-in-python-unit.html
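For context, a hedged sketch of the patching concern; the module path in the second patch target is an assumption, since the new file's location is not shown in this diff:

from unittest import mock

# Because the module does `import composer.utils as utils` and calls `utils.get_file(...)`,
# the attribute is looked up on composer.utils at call time, so a test can patch it there:
with mock.patch("composer.utils.get_file") as mock_get_file:
    ...  # code under test calls utils.get_file and receives the mock

# Had it used `from composer.utils import get_file`, the module would hold its own
# reference bound at import time, and the test would have to patch that copy instead:
# mock.patch("llmfoundry.command_utils.split_eval_set.get_file")  # hypothetical module path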