Add dataset.transform() where we pass the entire input as iterable (#…
dsmilkov authored Nov 29, 2023
1 parent b85a1ed commit 86251f6
Showing 4 changed files with 223 additions and 37 deletions.
36 changes: 35 additions & 1 deletion lilac/data/dataset.py
@@ -6,7 +6,7 @@
import pathlib
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime
from typing import Any, Iterable, Iterator, Literal, Optional, Sequence, Union
from typing import Any, Callable, Iterable, Iterator, Literal, Optional, Sequence, Union

import pandas as pd
from pydantic import (
@@ -624,6 +624,40 @@ def map(
"""
pass

@abc.abstractmethod
def transform(
self,
transform_fn: Callable[[Iterable[Item]], Iterable[Item]],
output_column: str,
input_path: Optional[Path] = None,
nest_under: Optional[Path] = None,
overwrite: bool = False,
combine_columns: bool = False,
resolve_span: bool = False,
) -> None:
"""Transforms the entire dataset (or a column) and writes the result to a new column.
Args:
transform_fn: A callable that takes a full row item dictionary, and returns an Item for the
result. The result Item can be a primitive, like a string.
output_column: The name of the output column to write to. When `nest_under` is False
(the default), this will be the name of the top-level column. When `nest_under` is True,
the output_column will be the name of the column under the path given by `nest_under`.
input_path: The path to the input column to map over. If not specified, the map function will
be called with the full row item dictionary. If specified, the map function will be called
with the value at the given path, flattened. The output column will be written in the same
shape as the input column, paralleling its nestedness.
nest_under: The path to nest the output under. This is useful when emitting annotations, like
spans, so they will get hierarchically shown in the UI.
overwrite: Set to true to overwrite this column if it already exists. If this bit is False,
an error will be thrown if the column already exists.
combine_columns: When true, the row passed to the map function will be a deeply nested object
reflecting the hierarchy of the data. When false, all columns will be flattened as top-level
fields.
resolve_span: Whether to resolve the spans into text before calling the map function.
"""
pass

@abc.abstractmethod
def to_json(
self,
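For orientation, here is a minimal usage sketch of the new method. The dataset handle (`ll.get_dataset('local', 'my_dataset')`) and its `text` column are hypothetical; the callable signatures mirror the tests added further down in this commit.

from typing import Any, Iterable, Iterator

import lilac as ll


def text_len(items: Iterable[Any]) -> Iterator[Any]:
  # The entire input arrives as a single iterable of row dicts; yield exactly one output per row.
  for item in items:
    yield len(item['text'])


def value_len(texts: Iterable[Any]) -> Iterator[Any]:
  # With input_path, the iterable holds the flattened leaf values instead of full rows.
  for text in texts:
    yield len(text)


dataset = ll.get_dataset('local', 'my_dataset')  # hypothetical namespace/name
dataset.transform(text_len, output_column='text_len')
dataset.transform(value_len, input_path='text', output_column='text_len_2')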
151 changes: 145 additions & 6 deletions lilac/data/dataset_duckdb.py
@@ -2640,6 +2640,138 @@ def map(

return DuckDBMapOutput(pyarrow_reader=reader, output_column=output_column)

@override
def transform(
self,
transform_fn: Callable[[Iterable[Item]], Iterable[Item]],
output_column: str,
input_path: Optional[Path] = None,
nest_under: Optional[Path] = None,
overwrite: bool = False,
combine_columns: bool = False,
resolve_span: bool = False,
) -> None:
manifest = self.manifest()
input_path = normalize_path(input_path) if input_path else None
if input_path:
input_field = manifest.data_schema.get_field(input_path)
if not input_field.dtype:
raise ValueError(
f'Input path {input_path} is not a leaf path. This is currently unsupported. If you '
'require this, please file an issue and we will prioritize.'
)

# Validate output_column and nest_under.
if nest_under is not None:
nest_under = normalize_path(nest_under)
if output_column is None:
raise ValueError('When using `nest_under`, you must specify an output column name.')

# Make sure nest_under does not contain any repeated values.
for path_part in nest_under:
if path_part == PATH_WILDCARD:
raise ValueError('Nesting map outputs under a repeated field is not yet supported.')

if not manifest.data_schema.has_field(nest_under):
raise ValueError(f'The `nest_under` column {nest_under} does not exist.')

if nest_under is not None:
output_path = (*nest_under, output_column)
else:
output_path = (output_column,)

parquet_filepath: Optional[str] = None
if manifest.data_schema.has_field(output_path):
if overwrite:
field = manifest.data_schema.get_field(output_path)
if field.map is None:
raise ValueError(f'{output_path} is not a map column so it cannot be overwritten.')
# Delete the parquet file and map manifest.
assert output_column is not None
parquet_filepath = os.path.join(
self.dataset_path, get_parquet_filename(output_column, shard_index=0, num_shards=1)
)
if os.path.exists(parquet_filepath):
delete_file(parquet_filepath)

map_manifest_filepath = os.path.join(
self.dataset_path, f'{output_column}.{MAP_MANIFEST_SUFFIX}'
)
if os.path.exists(map_manifest_filepath):
delete_file(map_manifest_filepath)

else:
raise ValueError(
f'Cannot map to path "{output_column}" which already exists in the dataset. '
'Use overwrite=True to overwrite the column.'
)

task_ids = []
jsonl_cache_filepaths: list[str] = []
output_col_desc_suffix = f' to "{output_column}"' if output_column else ''
task_id = get_task_manager().task_id(
name=f'[{self.namespace}/{self.dataset_name}] transform '
f'"{transform_fn.__name__}"{output_col_desc_suffix}',
)
jsonl_cache_filepath = _jsonl_cache_filepath(
namespace=self.namespace,
dataset_name=self.dataset_name,
key=output_path,
project_dir=self.project_dir,
shard_id=0,
shard_count=1,
)
get_task_manager().execute(
task_id,
self._map_worker,
transform_fn,
output_path,
jsonl_cache_filepath,
0,
1,
input_path,
overwrite,
combine_columns,
resolve_span,
(task_id, 0),
True, # entire_input
)
task_ids.append(task_id)
jsonl_cache_filepaths.append(jsonl_cache_filepath)

# Wait for the tasks to finish before reading the outputs.
get_task_manager().wait(task_ids)

_, map_schema, parquet_filepath = self._reshard_cache(
output_path=output_path, jsonl_cache_filepaths=jsonl_cache_filepaths
)

assert parquet_filepath is not None

map_field_root = map_schema.get_field(output_path)

map_field_root.map = MapInfo(
fn_name=transform_fn.__name__,
input_path=input_path,
fn_source=inspect.getsource(transform_fn),
date_created=datetime.now(),
)

parquet_dir = os.path.dirname(parquet_filepath)
map_manifest_filepath = os.path.join(parquet_dir, f'{output_column}.{MAP_MANIFEST_SUFFIX}')
parquet_filename = os.path.basename(parquet_filepath)
map_manifest = MapManifest(
files=[parquet_filename],
data_schema=map_schema,
parquet_id=get_map_parquet_id(output_path),
py_version=metadata.version('lilac'),
)
with open_file(map_manifest_filepath, 'w') as f:
f.write(map_manifest.model_dump_json(exclude_none=True, indent=2))

parquet_filepath = os.path.join(self.dataset_path, parquet_filepath)
log(f'Wrote transform output to {parquet_filepath}')

def _map_worker(
self,
map_fn: MapFn,
@@ -2652,6 +2784,7 @@ def _map_worker(
combine_columns: bool = False,
resolve_span: bool = False,
task_step_id: Optional[TaskStepId] = None,
entire_input: bool = False,
) -> None:
map_sig = inspect.signature(map_fn)
if len(map_sig.parameters) > 2 or len(map_sig.parameters) == 0:
Expand All @@ -2662,12 +2795,18 @@ def _map_worker(

has_job_id_arg = len(map_sig.parameters) == 2

def _map_iterable(items: Iterable[RichData]) -> Iterable[Optional[Item]]:
for item in items:
args: Union[tuple[Item], tuple[Item, int]] = (
(item,) if not has_job_id_arg else (item, job_id)
)
yield map_fn(*args)
if not entire_input:

def _map_iterable(items: Iterable[RichData]) -> Iterable[Optional[Item]]:
for item in items:
args: Union[tuple[Item], tuple[Item, int]] = (
(item,) if not has_job_id_arg else (item, job_id)
)
yield map_fn(*args)
else:

def _map_iterable(items: Iterable[RichData]) -> Iterable[Optional[Item]]:
return map_fn(items)

self._compute_disk_cached(
_map_iterable,
36 changes: 36 additions & 0 deletions lilac/data/dataset_map_test.py
@@ -949,3 +949,39 @@ def skip_first_and_last_letter(item: str) -> Item:
{'text': 'abcd', 'skip': span(1, 3)},
{'text': 'efghi', 'skip': span(1, 4)},
]


def test_transform(make_test_data: TestDataMaker) -> None:
def text_len(items: Iterable[Item]) -> Iterable[Item]:
for item in items:
yield len(item['text'])

dataset = make_test_data([{'text': 'abcd'}, {'text': 'efghi'}])
dataset.transform(text_len, output_column='text_len')

rows = dataset.select_rows()
assert list(rows) == [{'text': 'abcd', 'text_len': 4}, {'text': 'efghi', 'text_len': 5}]


def test_transform_with_input_path(make_test_data: TestDataMaker) -> None:
def text_len(texts: Iterable[Item]) -> Iterable[Item]:
for text in texts:
yield len(text)

dataset = make_test_data([{'text': 'abcd'}, {'text': 'efghi'}])
dataset.transform(text_len, input_path='text', output_column='text_len')

rows = dataset.select_rows()
assert list(rows) == [{'text': 'abcd', 'text_len': 4}, {'text': 'efghi', 'text_len': 5}]


def test_transform_size_mismatch(make_test_data: TestDataMaker) -> None:
def text_len(texts: Iterable[Item]) -> Iterable[Item]:
for i, text in enumerate(texts):
# Skip the first item.
if i > 0:
yield len(text)

dataset = make_test_data([{'text': 'abcd'}, {'text': 'efghi'}])
with pytest.raises(Exception):
dataset.transform(text_len, input_path='text', output_column='text_len')
37 changes: 7 additions & 30 deletions lilac/data/dataset_utils.py
@@ -310,36 +310,13 @@ def sparse_to_dense_compute(
"""Densifies the input before calling the provided `func` and sparsifies the output."""
total_size: int = 0

def densify(x: Iterator[Optional[Tin]]) -> Iterator[tuple[int, Tin]]:
nonlocal total_size
for i, value in enumerate(x):
total_size += 1
def densify(x: Iterator[Optional[Tin]]) -> Iterator[Tin]:
for value in x:
if value is not None:
yield i, value

dense_input_with_locations = densify(sparse_input)
dense_input_with_locations_0, dense_input_with_locations_1 = itertools.tee(
dense_input_with_locations, 2
)
dense_input = (value for (_, value) in dense_input_with_locations_0)
yield value

sparse_input, sparse_input_2 = itertools.tee(sparse_input, 2)
dense_input = densify(sparse_input_2)
dense_output = iter(func(dense_input))
index = 0

location_index = 0

while True:
try:
out = next(dense_output)
out_index, _ = next(dense_input_with_locations_1)
while index < out_index:
yield None
index += 1
yield out
location_index += 1
index += 1
except StopIteration:
while index < total_size:
yield None
index += 1
return
for input in sparse_input:
yield None if input is None else next(dense_output)
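The refactor keeps the same contract, sketched below under the assumption that the signature is sparse_to_dense_compute(sparse_input, func): `func` only ever sees the dense (non-None) values and must yield one output per dense input, and the None gaps are re-inserted at their original positions in the output. The toy values are illustrative only.

from typing import Iterable, Iterator

from lilac.data.dataset_utils import sparse_to_dense_compute


def double(values: Iterable[int]) -> Iterator[int]:
  # Receives only the dense values 1 and 3; never sees the Nones.
  for v in values:
    yield v * 2


sparse = iter([1, None, 3, None])
print(list(sparse_to_dense_compute(sparse, double)))  # [2, None, 6, None]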
