Commit

feat: allow replacement of entire datafile when the schema lines up correctly
chebbyChefNEQ committed Feb 3, 2025
1 parent c73d717 commit 0dc4169
Showing 8 changed files with 714 additions and 24 deletions.
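
The commit adds a DataReplacement transaction operation that swaps the data file backing an existing fragment for a freshly written file with a matching schema. Below is a minimal end-to-end sketch of the new Python API, adapted from the test added in this commit; the dataset path and table contents are illustrative only.

import lance
import pyarrow as pa

# Write an initial single-fragment dataset (its only fragment has id 0).
table = pa.Table.from_pydict({"a": range(100), "b": range(100)})
dataset = lance.write_dataset(table, "/tmp/replacement_demo")  # illustrative path

# Write replacement rows with the same schema as a new, uncommitted fragment file.
new_table = pa.Table.from_pydict({"a": range(100, 200), "b": range(100, 200)})
fragment = lance.fragment.LanceFragment.create("/tmp/replacement_demo", new_table)
new_file = fragment.files[0]

# Swap fragment 0's data file for the new one and commit the transaction.
op = lance.LanceOperation.DataReplacement(
    [lance.LanceOperation.DataReplacementGroup(fragment_id=0, new_file=new_file)]
)
dataset = lance.LanceDataset.commit(dataset, op, read_version=1)
assert dataset.to_table() == new_table  # reads now come from the replacement file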
13 changes: 12 additions & 1 deletion protos/transaction.proto
@@ -173,6 +173,16 @@ message Transaction {
    }
  }

  message DataReplacementGroup {
    uint64 fragment_id = 1;
    DataFile new_file = 2;
  }

  // An operation that replaces the data in a region of the table with new data.
  message DataReplacement {
    repeated DataReplacementGroup replacements = 1;
  }

  // The operation of this transaction.
  oneof operation {
    Append append = 100;
@@ -186,11 +196,12 @@ message Transaction {
    Update update = 108;
    Project project = 109;
    UpdateConfig update_config = 110;
    DataReplacement data_replacement = 111;
  }

  // An operation to apply to the blob dataset
  oneof blob_operation {
    Append blob_append = 200;
    Overwrite blob_overwrite = 202;
  }
}
}
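
On the wire, each DataReplacementGroup pairs the id of an existing fragment with the DataFile that replaces its contents, and DataReplacement carries a repeated list of groups, so one transaction can describe several swaps. A hedged sketch of a multi-fragment replacement through the Python dataclasses added later in this commit; file_0 and file_1 stand for already-written lance.fragment.DataFile handles with matching schemas, and only the single-fragment case is exercised by the new test.

import lance

# Assumed: file_0 and file_1 are DataFile handles produced for fragments 0 and 1,
# e.g. via lance.fragment.LanceFragment.create(...).files[0] as in the test below.
op = lance.LanceOperation.DataReplacement(
    replacements=[
        lance.LanceOperation.DataReplacementGroup(fragment_id=0, new_file=file_0),
        lance.LanceOperation.DataReplacementGroup(fragment_id=1, new_file=file_1),
    ]
)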
26 changes: 21 additions & 5 deletions python/python/lance/dataset.py
@@ -45,7 +45,7 @@
)
from .dependencies import numpy as np
from .dependencies import pandas as pd
from .fragment import FragmentMetadata, LanceFragment
from .fragment import DataFile, FragmentMetadata, LanceFragment
from .lance import (
    CleanupStats,
    Compaction,
@@ -1927,7 +1927,7 @@ def create_index(
        valid_index_types = ["IVF_FLAT", "IVF_PQ", "IVF_HNSW_PQ", "IVF_HNSW_SQ"]
        if index_type not in valid_index_types:
            raise NotImplementedError(
                f"Only {valid_index_types} index types supported. " f"Got {index_type}"
                f"Only {valid_index_types} index types supported. Got {index_type}"
            )
        if index_type != "IVF_PQ" and one_pass_ivfpq:
            raise ValueError(
@@ -2247,8 +2247,7 @@ def _commit(
        commit_lock: Optional[CommitLock] = None,
    ) -> LanceDataset:
        warnings.warn(
            "LanceDataset._commit() is deprecated, use LanceDataset.commit()"
            " instead",
            "LanceDataset._commit() is deprecated, use LanceDataset.commit() instead",
            DeprecationWarning,
        )
        return LanceDataset.commit(base_uri, operation, read_version, commit_lock)
@@ -2935,6 +2934,23 @@ class CreateIndex(BaseOperation):
        dataset_version: int
        fragment_ids: Set[int]

    @dataclass
    class DataReplacementGroup:
        """
        Group of data replacements
        """

        fragment_id: int
        new_file: DataFile

    @dataclass
    class DataReplacement(BaseOperation):
        """
        Operation that replaces existing datafiles in the dataset.
        """

        replacements: List[LanceOperation.DataReplacementGroup]


class ScannerBuilder:
    def __init__(self, ds: LanceDataset):
@@ -3203,7 +3219,7 @@ def nearest(

        if q_dim != dim:
            raise ValueError(
                f"Query vector size {len(q)} does not match index column size" f" {dim}"
                f"Query vector size {len(q)} does not match index column size {dim}"
            )

        if k is not None and int(k) <= 0:
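
The new_file field on the new dataclasses is a lance.fragment.DataFile, which is why DataFile joins the imports at the top of this file. A small sketch of how such a handle can be obtained, mirroring the step used in the new test; the path is illustrative.

import pyarrow as pa
from lance.fragment import LanceFragment

# Write the replacement rows as a standalone fragment file; the returned
# FragmentMetadata lists the physical data files that were produced.
new_rows = pa.table({"a": list(range(100, 200)), "b": list(range(100, 200))})
frag_meta = LanceFragment.create("/tmp/replacement_demo", new_rows)  # illustrative path
data_file = frag_meta.files[0]  # the DataFile to pass as new_file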
2 changes: 1 addition & 1 deletion python/python/lance/file.py
@@ -134,7 +134,7 @@ def take_rows(
            if indices[i] > indices[i + 1]:
                raise ValueError(
                    f"Indices must be sorted in ascending order for \
file API, got {indices[i]} > {indices[i+1]}"
file API, got {indices[i]} > {indices[i + 1]}"
                )

return ReaderResults(
2 changes: 1 addition & 1 deletion python/python/lance/ray/sink.py
@@ -161,7 +161,7 @@ def on_write_complete(

        if len(write_results) == 0:
            warnings.warn(
                "write results is empty. please check ray version " "or internal error",
                "write results is empty. please check ray version or internal error",
                DeprecationWarning,
            )
            return
25 changes: 25 additions & 0 deletions python/python/tests/test_dataset.py
@@ -2913,3 +2913,28 @@ def test_dataset_schema(tmp_path: Path):
    ds = lance.write_dataset(table, str(tmp_path))  # noqa: F841
    ds._default_scan_options = {"with_row_id": True}
    assert ds.schema == ds.to_table().schema


def test_data_replacement(tmp_path: Path):
    table = pa.Table.from_pydict({"a": range(100), "b": range(100)})
    base_dir = tmp_path / "test"

    dataset = lance.write_dataset(table, base_dir)

    table = pa.Table.from_pydict({"a": range(100, 200), "b": range(100, 200)})
    fragment = lance.fragment.LanceFragment.create(base_dir, table)
    data_file = fragment.files[0]
    data_replacement = lance.LanceOperation.DataReplacement(
        [lance.LanceOperation.DataReplacementGroup(0, data_file)]
    )
    dataset = lance.LanceDataset.commit(dataset, data_replacement, read_version=1)

    tbl = dataset.to_table()

    expected = pa.Table.from_pydict(
        {
            "a": list(range(100, 200)),
            "b": list(range(100, 200)),
        }
    )
    assert tbl == expected
46 changes: 44 additions & 2 deletions python/src/transaction.rs
@@ -3,9 +3,11 @@

use arrow::pyarrow::PyArrowType;
use arrow_schema::Schema as ArrowSchema;
use lance::dataset::transaction::{Operation, RewriteGroup, RewrittenIndex, Transaction};
use lance::dataset::transaction::{
    DataReplacementGroup, Operation, RewriteGroup, RewrittenIndex, Transaction,
};
use lance::datatypes::Schema;
use lance_table::format::{Fragment, Index};
use lance_table::format::{DataFile, Fragment, Index};
use pyo3::exceptions::PyValueError;
use pyo3::types::PySet;
use pyo3::{intern, prelude::*};
@@ -15,6 +17,32 @@ use uuid::Uuid;
use crate::schema::LanceSchema;
use crate::utils::{class_name, export_vec, extract_vec, PyLance};

impl FromPyObject<'_> for PyLance<DataReplacementGroup> {
    fn extract_bound(ob: &Bound<'_, PyAny>) -> PyResult<Self> {
        let fragment_id = ob.getattr("fragment_id")?.extract::<u64>()?;
        let new_file = &ob.getattr("new_file")?.extract::<PyLance<DataFile>>()?;

        Ok(Self(DataReplacementGroup(fragment_id, new_file.0.clone())))
    }
}

impl ToPyObject for PyLance<&DataReplacementGroup> {
    fn to_object(&self, py: Python<'_>) -> PyObject {
        let namespace = py
            .import_bound(intern!(py, "lance"))
            .and_then(|module| module.getattr(intern!(py, "LanceOperation")))
            .expect("Failed to import LanceOperation namespace");

        let fragment_id = self.0 .0;
        let new_file = PyLance(&self.0 .1).to_object(py);

        let cls = namespace
            .getattr("DataReplacementGroup")
            .expect("Failed to get DataReplacementGroup class");
        cls.call1((fragment_id, new_file)).unwrap().to_object(py)
    }
}

impl FromPyObject<'_> for PyLance<Operation> {
    fn extract_bound(ob: &Bound<'_, PyAny>) -> PyResult<Self> {
        match class_name(ob)? {
@@ -118,6 +146,13 @@ impl FromPyObject<'_> for PyLance<Operation> {
                };
                Ok(Self(op))
            }
            "DataReplacement" => {
                let replacements = extract_vec(&ob.getattr("replacements")?)?;

                let op = Operation::DataReplacement { replacements };

                Ok(Self(op))
            }
            unsupported => Err(PyValueError::new_err(format!(
                "Unsupported operation: {unsupported}",
            ))),
@@ -172,6 +207,13 @@ impl ToPyObject for PyLance<&Operation> {
                    .unwrap()
                    .to_object(py)
            }
            Operation::DataReplacement { replacements } => {
                let replacements = export_vec(py, replacements.as_slice());
                let cls = namespace
                    .getattr("DataReplacement")
                    .expect("Failed to get DataReplacement class");
                cls.call1((replacements,)).unwrap().to_object(py)
            }
            _ => todo!(),
        }
    }
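
On the Rust side, extract_bound reads the fragment_id and new_file attributes by name, and to_object rebuilds a LanceOperation.DataReplacementGroup when exporting a transaction back to Python. A minimal sketch of the Python-side contract those lookups rely on; data_file is assumed to be an existing lance.fragment.DataFile.

import lance

# The dataclass from dataset.py exposes exactly the two attributes the Rust
# extractor asks for, so it is what crosses the Python/Rust boundary on commit.
group = lance.LanceOperation.DataReplacementGroup(fragment_id=0, new_file=data_file)
assert hasattr(group, "fragment_id") and hasattr(group, "new_file")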
