Fix some issues with exporting. #1160

Merged (4 commits) on Feb 2, 2024
2 changes: 1 addition & 1 deletion lilac/data/dataset_duckdb.py
@@ -3381,7 +3381,7 @@ def to_json(
           file.write(orjson.dumps(row))
           file.write('\n'.encode('utf-8'))
       else:
-        file.write(orjson.dumps(rows))
+        file.write(orjson.dumps(list(rows)))
     log(f'Dataset exported to {filepath}')
 
   @override
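The one-line change above matters because orjson serializes lists and dicts but not arbitrary iterators; when `rows` arrives as a lazy generator, the non-jsonl branch presumably failed to serialize. A minimal sketch of that failure mode, assuming `rows` is a generator (the variable setup here is illustrative, not from the PR):

import orjson

rows = ({'text': t} for t in ['hello', 'everybody'])  # lazy generator, not a list

try:
  orjson.dumps(rows)  # orjson rejects generators with JSONEncodeError (a TypeError subclass)
except TypeError as e:
  print(e)

print(orjson.dumps(list(rows)).decode())  # materializing into a list first works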
81 changes: 68 additions & 13 deletions lilac/data/dataset_export_test.py
@@ -101,15 +101,15 @@ def test_export_to_json(make_test_data: TestDataMaker, tmp_path: pathlib.Path) -> None:
   filepath = tmp_path / 'dataset.json'
   dataset.to_json(filepath)
 
-  with open(filepath) as f:
+  with open(filepath, 'r') as f:
     parsed_items = [json.loads(line) for line in f.readlines()]
 
   assert parsed_items == [{'text': 'hello'}, {'text': 'everybody'}]
 
   # Include signals.
   dataset.to_json(filepath, include_signals=True)
 
-  with open(filepath) as f:
+  with open(filepath, 'r') as f:
     parsed_items = [json.loads(line) for line in f.readlines()]
 
   assert parsed_items == [
@@ -126,7 +126,7 @@ def test_export_to_json(make_test_data: TestDataMaker, tmp_path: pathlib.Path) -> None:
     include_signals=True,
   )
 
-  with open(filepath) as f:
+  with open(filepath, 'r') as f:
     parsed_items = [json.loads(line) for line in f.readlines()]
 
   assert parsed_items == [
@@ -138,7 +138,62 @@ def test_export_to_json(make_test_data: TestDataMaker, tmp_path: pathlib.Path) -> None:
     filepath, filters=[('text.test_signal.flen', 'less_equal', '5')], include_signals=True
   )
 
-  with open(filepath) as f:
+  with open(filepath, 'r') as f:
     parsed_items = [json.loads(line) for line in f.readlines()]
 
   assert parsed_items == [{'text': {VALUE_KEY: 'hello', 'test_signal': {'flen': 5.0, 'len': 5}}}]
+
+
+def test_export_to_jsonl(make_test_data: TestDataMaker, tmp_path: pathlib.Path) -> None:
+  dataset = make_test_data([{'text': 'hello'}, {'text': 'everybody'}])
+  dataset.compute_signal(TestSignal(), 'text')
+
+  # Download all columns.
+  filepath = tmp_path / 'dataset.json'
+  dataset.to_json(filepath, jsonl=True)
+
+  with open(filepath, 'r') as f:
+    parsed_items = [json.loads(line) for line in f.readlines()]
+
+  assert parsed_items == [{'text': 'hello'}, {'text': 'everybody'}]
+
+  # Include signals.
+  dataset.to_json(filepath, jsonl=True, include_signals=True)
+
+  with open(filepath, 'r') as f:
+    parsed_items = [json.loads(line) for line in f.readlines()]
+
+  assert parsed_items == [
+    {'text': {VALUE_KEY: 'hello', 'test_signal': {'flen': 5.0, 'len': 5}}},
+    {'text': {VALUE_KEY: 'everybody', 'test_signal': {'flen': 9.0, 'len': 9}}},
+  ]
+
+  # Download a subset of columns with filter.
+  filepath = tmp_path / 'dataset2.json'
+  dataset.to_json(
+    filepath,
+    jsonl=True,
+    columns=['text', 'text.test_signal'],
+    filters=[('text.test_signal.len', 'greater', '6')],
+    include_signals=True,
+  )
+
+  with open(filepath, 'r') as f:
+    parsed_items = [json.loads(line) for line in f.readlines()]
+
+  assert parsed_items == [
+    {'text': {VALUE_KEY: 'everybody', 'test_signal': {'flen': 9.0, 'len': 9}}}
+  ]
+
+  filepath = tmp_path / 'dataset3.json'
+  dataset.to_json(
+    filepath,
+    jsonl=True,
+    filters=[('text.test_signal.flen', 'less_equal', '5')],
+    include_signals=True,
+  )
+
+  with open(filepath, 'r') as f:
+    parsed_items = [json.loads(line) for line in f.readlines()]
+
+  assert parsed_items == [{'text': {VALUE_KEY: 'hello', 'test_signal': {'flen': 5.0, 'len': 5}}}]
@@ -152,7 +207,7 @@ def test_export_to_csv(make_test_data: TestDataMaker, tmp_path: pathlib.Path) -> None:
   filepath = tmp_path / 'dataset.csv'
   dataset.to_csv(filepath)
 
-  with open(filepath) as f:
+  with open(filepath, 'r') as f:
     rows = list(csv.reader(f))
 
   assert rows == [
@@ -172,7 +227,7 @@ def test_export_to_csv_include_signals(
   filepath = tmp_path / 'dataset.csv'
   dataset.to_csv(filepath, include_signals=True)
 
-  with open(filepath) as f:
+  with open(filepath, 'r') as f:
     rows = list(csv.reader(f))
 
   assert rows == [
@@ -196,7 +251,7 @@ def test_export_to_csv_subset_source_columns(
   filepath = tmp_path / 'dataset.csv'
   dataset.to_csv(filepath, columns=['age', 'metric'])
 
-  with open(filepath) as f:
+  with open(filepath, 'r') as f:
     rows = list(csv.reader(f))
 
   assert rows == [
@@ -232,7 +287,7 @@ def test_export_to_csv_subset_of_nested_data(
   filepath = tmp_path / 'dataset.csv'
   dataset.to_csv(filepath, columns=['doc.content', 'doc.paragraphs.*.text'])
 
-  with open(filepath) as f:
+  with open(filepath, 'r') as f:
     rows = list(csv.reader(f))
 
   assert rows == [
@@ -323,7 +378,7 @@ def test_label_and_export_by_excluding(
   filepath = tmp_path / 'dataset.json'
   dataset.to_json(filepath)
 
-  with open(filepath) as f:
+  with open(filepath, 'r') as f:
     parsed_items = [json.loads(line) for line in f.readlines()]
 
   assert parsed_items == [{f'{DELETED_LABEL_NAME}': None, 'text': 'a'}]
@@ -332,7 +387,7 @@ def test_label_and_export_by_excluding(
   filepath = tmp_path / 'dataset.json'
   dataset.to_json(filepath, include_deleted=True)
 
-  with open(filepath) as f:
+  with open(filepath, 'r') as f:
     parsed_items = [json.loads(line) for line in f.readlines()]
 
   assert parsed_items == [
@@ -357,7 +412,7 @@ def test_include_multiple_labels(make_test_data: TestDataMaker, tmp_path: pathlib.Path) -> None:
   filepath = tmp_path / 'dataset.json'
   dataset.to_json(filepath, columns=['text'], include_labels=['good', 'very_good'])
 
-  with open(filepath) as f:
+  with open(filepath, 'r') as f:
     parsed_items = [json.loads(line) for line in f.readlines()]
 
   parsed_items = sorted(parsed_items, key=lambda x: x['text'])
@@ -373,7 +428,7 @@ def test_exclude_multiple_labels(make_test_data: TestDataMaker, tmp_path: pathlib.Path) -> None:
   filepath = tmp_path / 'dataset.json'
   dataset.to_json(filepath, columns=['text'], exclude_labels=['bad', 'very_bad'])
 
-  with open(filepath) as f:
+  with open(filepath, 'r') as f:
     parsed_items = [json.loads(line) for line in f.readlines()]
 
   parsed_items = sorted(parsed_items, key=lambda x: x['text'])
@@ -389,7 +444,7 @@ def test_exclude_trumps_include(make_test_data: TestDataMaker, tmp_path: pathlib.Path) -> None:
   filepath = tmp_path / 'dataset.json'
   dataset.to_json(filepath, columns=['text'], include_labels=['good'], exclude_labels=['bad'])
 
-  with open(filepath) as f:
+  with open(filepath, 'r') as f:
     parsed_items = [json.loads(line) for line in f.readlines()]
 
   assert parsed_items == [{'text': 'b'}]
30 changes: 22 additions & 8 deletions lilac/router_dataset.py
@@ -155,6 +155,7 @@ class SelectRowsOptions(BaseModel):
   offset: Optional[int] = None
   combine_columns: Optional[bool] = None
   include_deleted: bool = False
+  exclude_signals: bool = False
 
 
 class SelectRowsSchemaOptions(BaseModel):
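The new `exclude_signals` flag is threaded through to `dataset.select_rows` in the hunk below. A hedged sketch of how a client might set it in a select-rows request body (field names are taken from this diff; the endpoint path and the other values are assumptions):

options = {
  'columns': ['text'],
  'combine_columns': True,
  'include_deleted': False,
  'exclude_signals': True,  # new in this PR: omit signal-derived columns from results
}
# e.g. POST /api/v1/datasets/{namespace}/{dataset_name}/select_rows with this JSON body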
@@ -206,6 +207,7 @@ def select_rows(
     offset=options.offset,
     combine_columns=options.combine_columns or False,
     include_deleted=options.include_deleted,
+    exclude_signals=options.exclude_signals,
     user=user,
   )
 
@@ -303,6 +305,7 @@ class ExportOptions(BaseModel):
   columns: Sequence[Path] = []
   include_labels: Sequence[str] = []
   exclude_labels: Sequence[str] = []
+  include_signals: bool = False
   # Note: "__deleted__" is "just" another label, and the UI
   # will default to adding the "__deleted__" label to the exclude_labels list. If the user wants
   # to include deleted items, they can remove the "__deleted__" label from the exclude_labels list.
@@ -328,20 +331,31 @@ def export_dataset(namespace: str, dataset_name: str, options: ExportOptions) -> None:
 
   if options.format == 'csv':
     dataset.to_csv(
-      options.filepath, options.columns, [], options.include_labels, options.exclude_labels
+      filepath=options.filepath,
+      columns=options.columns,
+      filters=[],
+      include_labels=options.include_labels,
+      exclude_labels=options.exclude_labels,
+      include_signals=options.include_signals,
     )
   elif options.format == 'json':
     dataset.to_json(
-      options.filepath,
-      options.jsonl or False,
-      options.columns,
-      [],
-      options.include_labels,
-      options.exclude_labels,
+      filepath=options.filepath,
+      jsonl=options.jsonl or False,
+      columns=options.columns,
+      filters=[],
+      include_labels=options.include_labels,
+      exclude_labels=options.exclude_labels,
+      include_signals=options.include_signals,
     )
   elif options.format == 'parquet':
     dataset.to_parquet(
-      options.filepath, options.columns, [], options.include_labels, options.exclude_labels
+      filepath=options.filepath,
+      columns=options.columns,
+      filters=[],
+      include_labels=options.include_labels,
+      exclude_labels=options.exclude_labels,
+      include_signals=options.include_signals,
     )
   else:
     raise ValueError(f'Unknown format: {options.format}')
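Taken together, the router now forwards every export option by keyword, and `include_signals` defaults to False, so signal columns stay out of exports unless explicitly requested. A hedged sketch of the Python-API equivalent, assuming a dataset named 'my_dataset' under the 'local' namespace (the dataset itself is hypothetical; method and parameter names come from this diff):

import lilac as ll

dataset = ll.get_dataset('local', 'my_dataset')  # hypothetical namespace/name
dataset.to_json(
  filepath='export.jsonl',
  jsonl=True,             # one JSON object per line
  include_signals=True,   # new flag from this PR; omit it to export source data only
)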