[WIP] Optimize ray mode performance #442

Open · wants to merge 2 commits into main
71 changes: 42 additions & 29 deletions data_juicer/core/ray_data.py
@@ -1,4 +1,5 @@
import os
from functools import partial

import pyarrow as pa
from loguru import logger
@@ -14,28 +15,26 @@
from ray.data import Dataset


def is_valid_path(item, dataset_dir):
full_path = os.path.abspath(os.path.join(dataset_dir, item))
return os.path.exists(full_path)
def get_abs_path(path, dataset_dir):
full_path = os.path.abspath(os.path.join(dataset_dir, path))
if os.path.exists(full_path):
return full_path
else:
return path


def convert_to_absolute_paths(dict_with_paths, dataset_dir, path_keys):
def convert_to_absolute_paths(samples, dataset_dir, path_keys):
samples = samples.to_pydict()
for key in path_keys:
if key not in dict_with_paths:
continue
if isinstance(dict_with_paths[key], list):
dict_with_paths[key] = [
os.path.abspath(os.path.join(dataset_dir, item))
if isinstance(item, str) and is_valid_path(dataset_dir, item)
else item for item in dict_with_paths[key]
]
elif isinstance(dict_with_paths[key], str):
dict_with_paths[key] = os.path.abspath(
os.path.join(dataset_dir,
dict_with_paths[key])) if is_valid_path(
dict_with_paths[key],
dataset_dir) else dict_with_paths[key]
return dict_with_paths
for idx in range(len(samples[key])):
paths = samples[key][idx]
if isinstance(paths, str):
samples[key][idx] = get_abs_path(paths, dataset_dir)
elif isinstance(paths, list):
samples[key][idx] = [
get_abs_path(item, dataset_dir) for item in paths
]
return pa.Table.from_pydict(samples)


# TODO: check path for nestdataset
@@ -44,22 +43,26 @@ def set_dataset_to_absolute_path(dataset, dataset_path, cfg):
Set all the path in input data to absolute path.
Checks dataset_dir and project_dir for valid paths.
"""
if not (cfg.video_key in dataset.columns() or cfg.image_key
in dataset.columns() or cfg.audio_key in dataset.columns()):
return dataset
dataset_dir = os.path.dirname(dataset_path)
dataset = dataset.map(lambda item: convert_to_absolute_paths(
item, dataset_dir, [cfg.video_key, cfg.image_key, cfg.audio_key]))
logger.info(f"transfer {dataset.count()} sample's paths")
path_keys = []
columns = dataset.columns()
for key in [cfg.video_key, cfg.image_key, cfg.audio_key]:
if key in columns:
path_keys.append(key)
if len(path_keys) > 0:
dataset_dir = os.path.dirname(dataset_path)
dataset = dataset.map_batches(partial(convert_to_absolute_paths,
dataset_dir=dataset_dir,
path_keys=path_keys),
batch_format='pyarrow',
zero_copy_batch=True)
return dataset


def preprocess_dataset(dataset: Dataset, dataset_path, cfg) -> Dataset:
columns = dataset.columns()
if dataset_path:
dataset = set_dataset_to_absolute_path(dataset, dataset_path, cfg)
columns = dataset.columns()
if Fields.stats not in columns:
logger.info(f'columns {columns}')

def process_batch_arrow(table: pa.Table) -> pa.Table:
new_column_data = [{} for _ in range(len(table))]
@@ -78,6 +81,11 @@ def get_num_gpus(op, op_proc):
return 1.0 / proc_per_gpu


def filter_batch(batch, filter_func):
mask = pa.array(filter_func(batch.to_pydict()))
return batch.filter(mask)


class RayDataset(DJDataset):

def __init__(self,
@@ -123,7 +131,12 @@ def _run_single_op(self, op):
if op.stats_export_path is not None:
self.data.write_json(op.stats_export_path,
force_ascii=False)
self.data = self.data.filter(op.process)
self.data = self.data.map_batches(partial(
filter_batch, filter_func=op.process),
batch_format='pyarrow',
batch_size=batch_size,
num_gpus=num_gpus,
zero_copy_batch=True)
else:
logger.error(
'Ray executor only support Filter and Mapper OPs for now')
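
For readers skimming the diff above: the core change in `ray_data.py` replaces per-row `Dataset.map`/`Dataset.filter` calls with `map_batches` over zero-copy pyarrow batches. Below is a minimal, self-contained sketch of that pattern; the toy dataset and the `keep_long_text` predicate are hypothetical stand-ins for `op.process`, not Data-Juicer code.

```python
import pyarrow as pa
import ray


def keep_long_text(batch_dict):
    # Per-batch predicate: one boolean per row, computed from plain Python lists.
    return [len(text) > 3 for text in batch_dict["text"]]


def filter_batch(batch: pa.Table, filter_func) -> pa.Table:
    # Build one boolean mask per batch and filter the Arrow table directly,
    # instead of invoking a Python predicate once per row via Dataset.filter().
    mask = pa.array(filter_func(batch.to_pydict()))
    return batch.filter(mask)


if __name__ == "__main__":
    ds = ray.data.from_items([{"text": t} for t in ["a", "abcd", "hi", "hello"]])
    ds = ds.map_batches(
        lambda b: filter_batch(b, keep_long_text),
        batch_format="pyarrow",  # hand each batch to the UDF as a pyarrow.Table
        zero_copy_batch=True,    # let Ray pass Arrow buffers without copying
    )
    print(ds.take_all())  # only rows whose text is longer than 3 characters remain
```

Building one boolean mask per Arrow batch and calling `Table.filter` keeps the row selection on the Arrow side instead of crossing the Python boundary once per row, which is where the speedup in this PR comes from.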
104 changes: 57 additions & 47 deletions data_juicer/ops/filter/flagged_words_filter.py
@@ -28,6 +28,8 @@ class FlaggedWordFilter(Filter):
"""Filter to keep samples with flagged-word ratio less than a specific max
value."""

_batched_op = True

def __init__(self,
lang: str = 'en',
tokenization: bool = False,
@@ -76,53 +78,61 @@ def __init__(self,
self.model_key = prepare_model(model_type='sentencepiece',
lang=lang)

def compute_stats(self, sample, context=False):
def compute_stats(self, samples, context=False):
# check if it's computed already
if StatsKeys.flagged_words_ratio in sample[Fields.stats]:
return sample

# try to get words from context
samples_list = samples[self.text_key]
samples_stats = samples[Fields.stats]
words_key = f'{InterVars.words}-{self.model_key}'
if context and words_key in sample[Fields.context]:
words = sample[Fields.context][words_key]
else:
tokenizer = get_model(self.model_key)
words = get_words_from_document(
sample[self.text_key],
token_func=tokenizer.encode_as_pieces if tokenizer else None)
if context:
sample[Fields.context][words_key] = words

# try to get refined words from context
refined_words_key = f'{InterVars.refined_words}-True-SPECIAL_CHARS-' \
f'{self.use_words_aug}-' \
f'{self.words_aug_group_sizes}-' \
f'{self.words_aug_join_char}'
if context and refined_words_key in sample[Fields.context]:
words = sample[Fields.context][refined_words_key]
tokenizer = get_model(self.model_key)
for idx, stat in enumerate(samples_stats):
if StatsKeys.flagged_words_ratio in stat:
continue
if context and words_key in samples[Fields.context][idx]:
words = samples[Fields.context][idx][words_key]
else:
words = get_words_from_document(
samples_list[idx],
token_func=tokenizer.encode_as_pieces
if tokenizer else None)
if context:
samples[Fields.context][idx][words_key] = words
# try to get refined words from context
refined_words_key = f'{InterVars.refined_words}' \
'-True-SPECIAL_CHARS-' \
f'{self.use_words_aug}-' \
f'{self.words_aug_group_sizes}-' \
f'{self.words_aug_join_char}'
if context and refined_words_key in samples[Fields.context][idx]:
words = samples[Fields.context][idx][refined_words_key]
else:
words = words_refinement(
words,
lower_case=True,
strip_chars=SPECIAL_CHARACTERS,
use_words_aug=self.use_words_aug,
words_aug_group_sizes=self.words_aug_group_sizes,
words_aug_join_char=self.words_aug_join_char)
if context:
samples[Fields.context][idx][refined_words_key] = words

flagged_words_ratio = (len([
word for word in words if word in self.FLAGGED_WORDS[self.lang]
]) / len(words)) if len(words) != 0 else 0.0

if flagged_words_ratio > 1.0:
flagged_words_ratio = 1.0

samples_stats[idx][
StatsKeys.flagged_words_ratio] = flagged_words_ratio

return samples

def process(self, samples):
if isinstance(samples[Fields.stats], list):
return list(
map(
lambda stat: stat[StatsKeys.flagged_words_ratio] <= self.
max_ratio, samples[Fields.stats]))
else:
words = words_refinement(
words,
lower_case=True,
strip_chars=SPECIAL_CHARACTERS,
use_words_aug=self.use_words_aug,
words_aug_group_sizes=self.words_aug_group_sizes,
words_aug_join_char=self.words_aug_join_char)
if context:
sample[Fields.context][refined_words_key] = words

flagged_words_ratio = (len(
[word
for word in words if word in self.FLAGGED_WORDS[self.lang]]) /
len(words)) if len(words) != 0 else 0.0

if flagged_words_ratio > 1.0:
flagged_words_ratio = 1.0

sample[Fields.stats][
StatsKeys.flagged_words_ratio] = flagged_words_ratio
return sample

def process(self, sample):
return sample[Fields.stats][
StatsKeys.flagged_words_ratio] <= self.max_ratio
return samples[Fields.stats][
StatsKeys.flagged_words_ratio] <= self.max_ratio
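
The filter above is rewritten against a batched-operator contract: `compute_stats` and `process` now receive a dict of columns (one Python list per key) rather than a single sample dict. Here is a schematic, dependency-free sketch of that contract; the column names and the flagged-word list are made up for illustration and are not Data-Juicer's real `Fields`/`StatsKeys` constants.

```python
# Hypothetical column names; the real operator uses Fields/StatsKeys constants.
TEXT_KEY = "text"
STATS_KEY = "__stats__"
FLAGGED = {"foo", "bar"}


def compute_stats_batched(samples: dict) -> dict:
    """Compute the flagged-word ratio for every row of a columnar batch."""
    stats = samples[STATS_KEY]
    for idx, text in enumerate(samples[TEXT_KEY]):
        if "flagged_words_ratio" in stats[idx]:
            continue  # already computed, e.g. by an earlier run
        words = text.lower().split()
        ratio = (sum(w in FLAGGED for w in words) / len(words)) if words else 0.0
        stats[idx]["flagged_words_ratio"] = min(ratio, 1.0)
    return samples


def process_batched(samples: dict, max_ratio: float = 0.05):
    # One boolean per row, matching the mask shape expected by filter_batch above.
    return [s["flagged_words_ratio"] <= max_ratio for s in samples[STATS_KEY]]


if __name__ == "__main__":
    batch = {
        TEXT_KEY: ["foo bar baz", "clean text only"],
        STATS_KEY: [{}, {}],
    }
    compute_stats_batched(batch)
    print(process_batched(batch))  # [False, True]
```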
93 changes: 54 additions & 39 deletions data_juicer/ops/filter/image_aspect_ratio_filter.py
@@ -14,6 +14,8 @@ class ImageAspectRatioFilter(Filter):
AspectRatio = W / H.
"""

_batched_op = True

def __init__(self,
min_ratio: float = 0.333,
max_ratio: float = 3.0,
@@ -40,43 +42,56 @@ def __init__(self,
f'Can only be one of ["any", "all"].')
self.any = (any_or_all == 'any')

def compute_stats(self, sample, context=False):
# check if it's computed already
if StatsKeys.aspect_ratios in sample[Fields.stats]:
return sample

# there is no image in this sample
if self.image_key not in sample or not sample[self.image_key]:
sample[Fields.stats][StatsKeys.aspect_ratios] = np.array(
[], dtype=np.float64)
return sample

# load images
loaded_image_keys = sample[self.image_key]
sample, images = load_data_with_context(sample, context,
loaded_image_keys, load_image)

# compute aspect ratios for each image with W/H
aspect_ratios = {
key: (images[key].width / images[key].height)
for key in images
}
sample[Fields.stats][StatsKeys.aspect_ratios] = [
aspect_ratios[key] for key in loaded_image_keys
]
return sample

def process(self, sample):
aspect_ratios = sample[Fields.stats][StatsKeys.aspect_ratios]
keep_bools = np.array([
self.min_ratio <= aspect_ratio <= self.max_ratio
for aspect_ratio in aspect_ratios
])
if len(keep_bools) <= 0:
return True

# different strategies
if self.any:
return keep_bools.any()
def compute_stats(self, samples, context=False):
image_list = samples[self.image_key]
samples_stats = samples[Fields.stats]

for i, stat in enumerate(samples_stats):
# check if it's computed already
if StatsKeys.aspect_ratios in stat:
continue

# there is no image in this sample
loaded_image_keys = image_list[i]
if not loaded_image_keys:
stat[StatsKeys.aspect_ratios] = np.array([], dtype=np.float64)
continue

# load images
samples, images = load_data_with_context(samples, context,
loaded_image_keys,
load_image)

# compute aspect ratios for each image with W/H
aspect_ratios = {
key: (images[key].width / images[key].height)
for key in images
}
stat[StatsKeys.aspect_ratios] = [
aspect_ratios[key] for key in loaded_image_keys
]

return samples

def process(self, samples):

def process_single(values):
keep_bools = np.array([
self.min_ratio <= value <= self.max_ratio for value in values
])
if len(keep_bools) <= 0:
return True

# different strategies
if self.any:
return keep_bools.any()
else:
return keep_bools.all()

if isinstance(samples[Fields.stats], list):
return map(
lambda stat: process_single(stat[StatsKeys.aspect_ratios]),
samples[Fields.stats])
else:
return keep_bools.all()
return process_single(
samples[Fields.stats][StatsKeys.aspect_ratios])
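
In this filter, `process` now has to serve both a columnar batch (a list of stat dicts) and a legacy single sample. A small illustrative sketch of that dual dispatch and of the `any`/`all` keep strategy follows; the threshold values and stats layout are illustrative, not the library's exact constants.

```python
import numpy as np

# Illustrative thresholds; the real filter reads them from its configuration.
MIN_RATIO, MAX_RATIO = 0.333, 3.0


def keep_sample(aspect_ratios, require_all: bool = True) -> bool:
    """Decide whether one sample survives, given the aspect ratios of its images."""
    keep = np.array([MIN_RATIO <= r <= MAX_RATIO for r in aspect_ratios])
    if keep.size == 0:
        return True  # samples without images are kept
    return keep.all() if require_all else keep.any()


def process(stats):
    # Batched path: a list of per-sample stat dicts -> one bool per sample.
    if isinstance(stats, list):
        return [keep_sample(s["aspect_ratios"]) for s in stats]
    # Single-sample path: one stat dict -> one bool.
    return keep_sample(stats["aspect_ratios"])


if __name__ == "__main__":
    print(process([{"aspect_ratios": [1.0, 2.5]},
                   {"aspect_ratios": [5.0]},
                   {"aspect_ratios": []}]))  # [True, False, True]
```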
2 changes: 1 addition & 1 deletion data_juicer/ops/filter/perplexity_filter.py
@@ -50,6 +50,7 @@ def compute_stats(self, samples, context=False):
samples_list = samples[self.text_key]
samples_stats = samples[Fields.stats]
words_key = f'{InterVars.words}-{self.sp_model_key}'
tokenizer = get_model(self.sp_model_key)

for idx, stat in enumerate(samples_stats):
# check if it's computed already
@@ -59,7 +60,6 @@
if context and words_key in samples[Fields.context][idx]:
words = samples[Fields.context][idx][words_key]
else:
tokenizer = get_model(self.sp_model_key)
words = get_words_from_document(
samples_list[idx],
token_func=tokenizer.encode_as_pieces
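
The `perplexity_filter.py` change is small but representative: the tokenizer lookup via `get_model` is hoisted out of the per-sample loop so it runs once per batch. A toy timing sketch of why that matters; `expensive_lookup` is a hypothetical stand-in for the real model lookup.

```python
import time


def expensive_lookup():
    # Stand-in for get_model(...): pretend fetching the tokenizer costs 10 ms.
    time.sleep(0.01)
    return str.upper


def slow(texts):
    # Old shape of the loop: the model is re-fetched for every sample.
    return [expensive_lookup()(t) for t in texts]


def fast(texts):
    # New shape: fetch once, reuse for the whole batch.
    tokenize = expensive_lookup()
    return [tokenize(t) for t in texts]


if __name__ == "__main__":
    texts = ["a"] * 100
    for fn in (slow, fast):
        start = time.perf_counter()
        fn(texts)
        print(fn.__name__, f"{time.perf_counter() - start:.2f}s")
    # slow pays the lookup 100 times (~1 s here); fast pays it once.
```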