Skip to content

Commit

Permalink
Add 'uuid_column', 'tenant' params to WeaviateIngestOperator (apache#36387)
Browse files
Browse the repository at this point in the history

* Add 'uuid_column', 'tenant' params to WeaviateIngestOperator

* Update airflow/providers/weaviate/operators/weaviate.py

Co-authored-by: Pankaj Singh <[email protected]>

* Fix test

---------

Co-authored-by: Pankaj Singh <[email protected]>
  • Loading branch information
utkarsharma2 and pankajastro authored Dec 23, 2023
1 parent 75d74b1 commit ff3b8da
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 8 deletions.
17 changes: 13 additions & 4 deletions airflow/providers/weaviate/operators/weaviate.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,8 @@ def __init__(
input_json: list[dict[str, Any]] | pd.DataFrame | None = None,
input_data: list[dict[str, Any]] | pd.DataFrame | None = None,
vector_col: str = "Vector",
uuid_column: str = "id",
tenant: str | None = None,
**kwargs: Any,
) -> None:
self.batch_params = kwargs.pop("batch_params", {})
Expand All @@ -70,6 +72,8 @@ def __init__(
self.conn_id = conn_id
self.vector_col = vector_col
self.input_json = input_json
self.uuid_column = uuid_column
self.tenant = tenant
if input_data is not None:
self.input_data = input_data
elif input_json is not None:
Expand All @@ -87,11 +91,16 @@ def hook(self) -> WeaviateHook:
"""Return an instance of the WeaviateHook."""
return WeaviateHook(conn_id=self.conn_id, **self.hook_params)

-def execute(self, context: Context) -> None:
+def execute(self, context: Context) -> list:
 self.log.debug("Input data: %s", self.input_data)
+insertion_errors: list = []
 self.hook.batch_data(
-self.class_name,
-self.input_data,
-**self.batch_params,
+class_name=self.class_name,
+data=self.input_data,
+batch_config_params=self.batch_params,
 vector_col=self.vector_col,
+insertion_errors=insertion_errors,
+uuid_col=self.uuid_column,
+tenant=self.tenant,
 )
+return insertion_errors
14 changes: 10 additions & 4 deletions tests/providers/weaviate/operators/test_weaviate.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,13 +30,13 @@ def operator(self):
task_id="weaviate_task",
conn_id="weaviate_conn",
class_name="my_class",
-input_json={"data": "sample_data"},
+input_json=[{"data": "sample_data"}],
 )

 def test_constructor(self, operator):
 assert operator.conn_id == "weaviate_conn"
 assert operator.class_name == "my_class"
-assert operator.input_data == {"data": "sample_data"}
+assert operator.input_data == [{"data": "sample_data"}]
assert operator.batch_params == {}
assert operator.hook_params == {}

Expand All @@ -47,9 +47,15 @@ def test_execute_with_input_json(self, mock_log, operator):
operator.execute(context=None)

 operator.hook.batch_data.assert_called_once_with(
-"my_class", {"data": "sample_data"}, vector_col="Vector", **{}
+class_name="my_class",
+data=[{"data": "sample_data"}],
+batch_config_params={},
+vector_col="Vector",
+insertion_errors=[],
+uuid_col="id",
+tenant=None,
 )
-mock_log.debug.assert_called_once_with("Input data: %s", {"data": "sample_data"})
+mock_log.debug.assert_called_once_with("Input data: %s", [{"data": "sample_data"}])

@pytest.mark.db_test
def test_templates(self, create_task_instance_of_operator):
Expand Down

0 comments on commit ff3b8da

Please sign in to comment.