Add exclude_signals to select_rows, and include_signals to export methods.
nsthorat committed Jan 26, 2024
1 parent 7574775 commit f6ee26b
Showing 5 changed files with 445 additions and 13 deletions.
12 changes: 12 additions & 0 deletions lilac/data/dataset.py
@@ -670,6 +670,7 @@ def select_rows(
resolve_span: bool = False,
combine_columns: bool = False,
include_deleted: bool = False,
exclude_signals: bool = False,
user: Optional[UserInfo] = None,
) -> SelectRowsResult:
"""Select a set of rows that match the provided filters, analogous to SQL SELECT.
@@ -694,6 +695,7 @@ def select_rows(
combine_columns: Whether to combine columns into a single object. The object will be pruned
to only include sub-fields that correspond to the requested columns.
include_deleted: Whether to include deleted rows in the query.
exclude_signals: Whether to exclude fields produced by signals.
user: The authenticated user, if auth is enabled and the user is logged in. This is used to
apply ACL to the query, especially for concepts.
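A usage sketch of the new flag (`get_dataset` is Lilac's documented entry point; the namespace, dataset name, and column path below are illustrative, not from this commit):

import lilac as ll

ds = ll.get_dataset('local', 'my_dataset')  # illustrative namespace/name

# Rows come back without signal-derived fields such as enrichment spans.
rows = ds.select_rows(['text'], combine_columns=True, exclude_signals=True)
for row in rows:
  print(row)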
@@ -987,6 +989,7 @@ def to_huggingface(
include_labels: Optional[Sequence[str]] = None,
exclude_labels: Optional[Sequence[str]] = None,
include_deleted: bool = False,
include_signals: bool = False,
) -> HuggingFaceDataset:
"""Export the dataset to a huggingface dataset.
@@ -996,6 +999,7 @@ def to_huggingface(
include_labels: The labels to include in the export.
exclude_labels: The labels to exclude in the export.
include_deleted: Whether to include deleted rows in the export.
include_signals: Whether to include fields produced by signals.
"""
pass

@@ -1009,6 +1013,7 @@ def to_json(
include_labels: Optional[Sequence[str]] = None,
exclude_labels: Optional[Sequence[str]] = None,
include_deleted: bool = False,
include_signals: bool = False,
) -> None:
"""Export the dataset to a JSON file.
@@ -1020,6 +1025,7 @@ def to_json(
include_labels: The labels to include in the export.
exclude_labels: The labels to exclude in the export.
include_deleted: Whether to include deleted rows in the export.
include_signals: Whether to include fields produced by signals.
"""
pass

@@ -1031,6 +1037,7 @@ def to_pandas(
include_labels: Optional[Sequence[str]] = None,
exclude_labels: Optional[Sequence[str]] = None,
include_deleted: bool = False,
include_signals: bool = False,
) -> pd.DataFrame:
"""Export the dataset to a pandas DataFrame.
@@ -1040,6 +1047,7 @@ def to_pandas(
include_labels: The labels to include in the export.
exclude_labels: The labels to exclude in the export.
include_deleted: Whether to include deleted rows in the export.
include_signals: Whether to include fields produced by signals.
"""
pass

@@ -1052,6 +1060,7 @@ def to_parquet(
include_labels: Optional[Sequence[str]] = None,
exclude_labels: Optional[Sequence[str]] = None,
include_deleted: bool = False,
include_signals: bool = False,
) -> None:
"""Export the dataset to a parquet file.
@@ -1062,6 +1071,7 @@ def to_parquet(
include_labels: The labels to include in the export.
exclude_labels: The labels to exclude in the export.
include_deleted: Whether to include deleted rows in the export.
include_signals: Whether to include fields produced by signals.
"""
pass

@@ -1074,6 +1084,7 @@ def to_csv(
include_labels: Optional[Sequence[str]] = None,
exclude_labels: Optional[Sequence[str]] = None,
include_deleted: bool = False,
include_signals: bool = False,
) -> None:
"""Export the dataset to a csv file.
@@ -1084,6 +1095,7 @@ def to_csv(
include_labels: The labels to include in the export.
exclude_labels: The labels to exclude in the export.
include_deleted: Whether to include deleted rows in the export.
include_signals: Whether to include fields produced by signals.
"""
pass
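All five export methods expose the same flag and forward it to `select_rows` as `exclude_signals=not include_signals`, so signal-derived fields are dropped from exports unless explicitly requested. A short sketch, continuing the illustrative `ds` handle from the earlier example (file names are illustrative too):

# Default: signal fields are excluded from the export.
ds.to_json('rows.json', columns=['text'])

# Opt back in to exporting signal-derived fields.
ds.to_parquet('rows.parquet', include_signals=True)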

102 changes: 92 additions & 10 deletions lilac/data/dataset_duckdb.py
@@ -1,4 +1,5 @@
"""The DuckDB implementation of the dataset database."""
import copy
import csv
import functools
import gc
@@ -2051,6 +2052,7 @@ def select_rows(
resolve_span: bool = False,
combine_columns: bool = False,
include_deleted: bool = False,
exclude_signals: bool = False,
user: Optional[UserInfo] = None,
) -> SelectRowsResult:
manifest = self.manifest()
@@ -2072,6 +2074,14 @@ def select_rows(
combine_columns=True,
).data_schema

# Remove fields that are produced by signals.
if exclude_signals:
signal_paths: list[PathTuple] = []
for signal_manifest in self._signal_manifests:
signal_paths.extend(list(signal_manifest.data_schema.leafs.keys()))

cols = [col for col in cols if col.path not in signal_paths]

self._validate_columns(cols, manifest.data_schema, schema)

temp_rowid_selected = False
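The exclusion is a set-difference over leaf paths: every leaf contributed by a signal manifest is collected, and any requested column whose path matches is dropped. A standalone sketch of that filtering with simplified stand-in types (in Lilac, a `PathTuple` is a tuple of path components; the signal names are illustrative):

from typing import NamedTuple

PathTuple = tuple[str, ...]

class Col(NamedTuple):
  path: PathTuple

# Leaf paths owned by signal manifests (illustrative values).
signal_paths: list[PathTuple] = [('text', 'pii'), ('text', 'lang_detect')]

cols = [Col(('text',)), Col(('text', 'pii'))]
cols = [col for col in cols if col.path not in signal_paths]
assert cols == [Col(('text',))]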
@@ -2188,7 +2198,7 @@ def select_rows(
if final_col_name not in columns_to_merge:
columns_to_merge[final_col_name] = {}

duckdb_paths = self._column_to_duckdb_paths(column, schema, combine_columns)
duckdb_paths = self._column_to_duckdb_paths(column, schema, combine_columns, exclude_signals)
span_from = self._resolve_span(path, manifest) if resolve_span or column.signal_udf else None

for parquet_id, duckdb_path in duckdb_paths:
@@ -2415,6 +2425,7 @@ def select_rows_schema(
sort_order: Optional[SortOrder] = None,
searches: Optional[Sequence[Search]] = None,
combine_columns: bool = False,
exclude_signals: bool = False,
) -> SelectRowsSchemaResult:
"""Returns the schema of the result of `select_rows` above with the same arguments."""
if not combine_columns:
@@ -2428,11 +2439,19 @@ def select_rows_schema(
search_udfs = self._search_udfs(searches, manifest)
cols.extend([search_udf.udf for search_udf in search_udfs])

# Remove fields that are produced by signals.
if exclude_signals:
signal_paths: list[PathTuple] = []
for signal_manifest in self._signal_manifests:
signal_paths.extend(list(signal_manifest.data_schema.leafs.keys()))

cols = [col for col in cols if col.path not in signal_paths]

udfs: list[SelectRowsSchemaUDF] = []
col_schemas: list[Schema] = []
for col in cols:
dest_path = _col_destination_path(col)
if col.signal_udf:
if col.signal_udf and not exclude_signals:
udfs.append(SelectRowsSchemaUDF(path=dest_path, alias=col.alias))
field = col.signal_udf.fields()
assert field, f'Signal {col.signal_udf.name} needs `Signal.fields` defined when run as UDF.'
Expand All @@ -2442,6 +2461,13 @@ def select_rows_schema(
else:
# This column might refer to an output of a udf. We postpone validation to later.
continue

# Delete any signals from the schema if we are excluding signals.
if exclude_signals:
field = copy.deepcopy(field)
field = _remove_signals_from_field(field)
assert field is not None

col_schemas.append(_make_schema_from_path(dest_path, field))

sort_results = self._merge_sorts(search_udfs, sort_by, sort_order)
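A usage sketch of the pruned schema, reusing the illustrative `ds` handle (the signal name is hypothetical):

schema = ds.select_rows_schema(
  columns=['text'], combine_columns=True, exclude_signals=True
).data_schema

# No leaf path should originate from a signal such as a hypothetical 'pii'.
assert all('pii' not in path for path in schema.leafs)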
@@ -2630,7 +2656,7 @@ def _leaf_path_to_duckdb_path(self, leaf_path: PathTuple, schema: Schema) -> PathTuple:
return duckdb_path

def _column_to_duckdb_paths(
self, column: Column, schema: Schema, combine_columns: bool
self, column: Column, schema: Schema, combine_columns: bool, exclude_signals: bool = False
) -> list[tuple[str, PathTuple]]:
path = column.path
if path[0] in self._label_schemas:
@@ -2640,7 +2666,7 @@ def _column_to_duckdb_paths(
is_petal = schema.get_field(path).dtype is not None

# NOTE: The order of this array matters as we check the source and map manifests for fields
# before reading signal manifests, via source_or_map_has_path.
# before reading signal manifests, via source_or_map_has_path.
parquet_manifests: list[Union[SourceManifest, SignalManifest, MapManifest]] = [
self._source_manifest,
*self._map_manifests,
@@ -2654,6 +2680,8 @@
for m in parquet_manifests:
if not m.files:
continue
if exclude_signals and isinstance(m, SignalManifest):
continue
# Skip this parquet file if it doesn't contain the path.
# if not schema_contains_path(m.data_schema, path):
# continue
@@ -3206,13 +3234,18 @@ def to_huggingface(
include_labels: Optional[Sequence[str]] = None,
exclude_labels: Optional[Sequence[str]] = None,
include_deleted: bool = False,
include_signals: bool = False,
) -> HuggingFaceDataset:
filters, _ = self._normalize_filters(
filter_likes=filters, col_aliases={}, udf_aliases={}, manifest=self.manifest()
)
filters.extend(self._compile_include_exclude_filters(include_labels, exclude_labels))
rows = self.select_rows(
columns, filters=filters, combine_columns=True, include_deleted=include_deleted
columns,
filters=filters,
combine_columns=True,
include_deleted=include_deleted,
exclude_signals=not include_signals,
)

def _gen() -> Iterator[Item]:
@@ -3231,13 +3264,18 @@ def to_json(
include_labels: Optional[Sequence[str]] = None,
exclude_labels: Optional[Sequence[str]] = None,
include_deleted: bool = False,
include_signals: bool = False,
) -> None:
filters, _ = self._normalize_filters(
filter_likes=filters, col_aliases={}, udf_aliases={}, manifest=self.manifest()
)
filters.extend(self._compile_include_exclude_filters(include_labels, exclude_labels))
rows = self.select_rows(
columns, filters=filters, combine_columns=True, include_deleted=include_deleted
columns,
filters=filters,
combine_columns=True,
include_deleted=include_deleted,
exclude_signals=not include_signals,
)
filepath = os.path.expanduser(filepath)
with open_file(filepath, 'wb') as file:
@@ -3257,13 +3295,18 @@ def to_pandas(
include_labels: Optional[Sequence[str]] = None,
exclude_labels: Optional[Sequence[str]] = None,
include_deleted: bool = False,
include_signals: bool = False,
) -> pd.DataFrame:
filters, _ = self._normalize_filters(
filter_likes=filters, col_aliases={}, udf_aliases={}, manifest=self.manifest()
)
filters.extend(self._compile_include_exclude_filters(include_labels, exclude_labels))
rows = self.select_rows(
columns, filters=filters, combine_columns=True, include_deleted=include_deleted
columns,
filters=filters,
combine_columns=True,
include_deleted=include_deleted,
exclude_signals=not include_signals,
)
return pd.DataFrame.from_records(list(rows))

@@ -3276,6 +3319,7 @@ def to_csv(
include_labels: Optional[Sequence[str]] = None,
exclude_labels: Optional[Sequence[str]] = None,
include_deleted: bool = False,
include_signals: bool = False,
) -> None:
manifest = self.manifest()
filters, _ = self._normalize_filters(
@@ -3284,7 +3328,11 @@
filters.extend(self._compile_include_exclude_filters(include_labels, exclude_labels))
select_schema = self.select_rows_schema(
  columns, combine_columns=True, exclude_signals=not include_signals
)
rows = self.select_rows(
columns, filters=filters, combine_columns=True, include_deleted=include_deleted
columns,
filters=filters,
combine_columns=True,
include_deleted=include_deleted,
exclude_signals=not include_signals,
)
fieldnames = list(select_schema.data_schema.fields.keys())
filepath = os.path.expanduser(filepath)
@@ -3303,14 +3351,21 @@ def to_parquet(
include_labels: Optional[Sequence[str]] = None,
exclude_labels: Optional[Sequence[str]] = None,
include_deleted: bool = False,
include_signals: bool = False,
) -> None:
filters, _ = self._normalize_filters(
filter_likes=filters, col_aliases={}, udf_aliases={}, manifest=self.manifest()
)
filters.extend(self._compile_include_exclude_filters(include_labels, exclude_labels))
select_schema = self.select_rows_schema(columns, combine_columns=True)
select_schema = self.select_rows_schema(
columns, combine_columns=True, exclude_signals=not include_signals
)
rows = self.select_rows(
columns, filters=filters, combine_columns=True, include_deleted=include_deleted
columns,
filters=filters,
combine_columns=True,
include_deleted=include_deleted,
exclude_signals=not include_signals,
)
filepath = os.path.expanduser(filepath)
with open_file(filepath, 'wb') as f:
@@ -3704,6 +3759,33 @@ def _schema_has_spans(field: Field) -> bool:
return False


def _remove_signals_from_field(field: Field) -> Optional[Field]:
"""Remove signals from a field."""
if field.signal is not None:
return None

if field.fields:
fields: dict[str, Field] = {}
for key, sub_field in field.fields.items():
if not sub_field.signal:
sub_field = _remove_signals_from_field(sub_field)
if sub_field:
fields[key] = sub_field
return Field(fields=fields, dtype=field.dtype)

  if field.repeated_field:
    # field.signal is already known to be None here (early return above).
    sub_field = _remove_signals_from_field(field.repeated_field)
    if sub_field:
      return Field(repeated_field=sub_field, dtype=field.dtype)
    return None

return field
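A quick sanity check of the recursion on a toy tree. The `Field` below is a hypothetical dataclass stand-in for Lilac's pydantic model, kept attribute-compatible so the function above can run against it in a fresh module:

from dataclasses import dataclass
from typing import Optional

@dataclass
class Field:  # minimal stand-in: only the attributes the function reads
  dtype: Optional[str] = None
  signal: Optional[dict] = None
  fields: Optional[dict] = None
  repeated_field: Optional['Field'] = None

pii = Field(dtype='string_span', signal={'signal_name': 'pii'})
text = Field(dtype='string', fields={'pii': pii})
root = Field(fields={'text': text})

pruned = _remove_signals_from_field(root)
assert pruned is not None
assert pruned.fields['text'].fields == {}  # the signal child was removed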


def _normalize_bins(bins: Optional[Union[Sequence[Bin], Sequence[float]]]) -> Optional[list[Bin]]:
if bins is None:
return None