This is an automated email from the ASF dual-hosted git repository.
utkarsharma pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new ff3b8daac0 Add 'uuid_column', 'tenant' params to WeaviateIngestOperator (#36387)
ff3b8daac0 is described below
commit ff3b8daac0cbf3c885ea1479b1fb9cfcb2261f21
Author: Utkarsh Sharma <[email protected]>
AuthorDate: Sat Dec 23 14:03:23 2023 +0530
Add 'uuid_column', 'tenant' params to WeaviateIngestOperator (#36387)
* Add 'uuid_column', 'tenant' params to WeaviateIngestOperator
* Update airflow/providers/weaviate/operators/weaviate.py
Co-authored-by: Pankaj Singh <[email protected]>
* Fix test
---------
Co-authored-by: Pankaj Singh <[email protected]>
---
airflow/providers/weaviate/operators/weaviate.py | 17 +++++++++++++----
tests/providers/weaviate/operators/test_weaviate.py | 14 ++++++++++----
2 files changed, 23 insertions(+), 8 deletions(-)
diff --git a/airflow/providers/weaviate/operators/weaviate.py b/airflow/providers/weaviate/operators/weaviate.py
index d12a2c2e6c..c36e6aa0d2 100644
--- a/airflow/providers/weaviate/operators/weaviate.py
+++ b/airflow/providers/weaviate/operators/weaviate.py
@@ -60,6 +60,8 @@ class WeaviateIngestOperator(BaseOperator):
input_json: list[dict[str, Any]] | pd.DataFrame | None = None,
input_data: list[dict[str, Any]] | pd.DataFrame | None = None,
vector_col: str = "Vector",
+ uuid_column: str = "id",
+ tenant: str | None = None,
**kwargs: Any,
) -> None:
self.batch_params = kwargs.pop("batch_params", {})
@@ -70,6 +72,8 @@ class WeaviateIngestOperator(BaseOperator):
self.conn_id = conn_id
self.vector_col = vector_col
self.input_json = input_json
+ self.uuid_column = uuid_column
+ self.tenant = tenant
if input_data is not None:
self.input_data = input_data
elif input_json is not None:
@@ -87,11 +91,16 @@ class WeaviateIngestOperator(BaseOperator):
"""Return an instance of the WeaviateHook."""
return WeaviateHook(conn_id=self.conn_id, **self.hook_params)
- def execute(self, context: Context) -> None:
+ def execute(self, context: Context) -> list:
self.log.debug("Input data: %s", self.input_data)
+ insertion_errors: list = []
self.hook.batch_data(
- self.class_name,
- self.input_data,
- **self.batch_params,
+ class_name=self.class_name,
+ data=self.input_data,
+ batch_config_params=self.batch_params,
vector_col=self.vector_col,
+ insertion_errors=insertion_errors,
+ uuid_col=self.uuid_column,
+ tenant=self.tenant,
)
+ return insertion_errors
diff --git a/tests/providers/weaviate/operators/test_weaviate.py b/tests/providers/weaviate/operators/test_weaviate.py
index 7a2c362494..775a568be1 100644
--- a/tests/providers/weaviate/operators/test_weaviate.py
+++ b/tests/providers/weaviate/operators/test_weaviate.py
@@ -30,13 +30,13 @@ class TestWeaviateIngestOperator:
task_id="weaviate_task",
conn_id="weaviate_conn",
class_name="my_class",
- input_json={"data": "sample_data"},
+ input_json=[{"data": "sample_data"}],
)
def test_constructor(self, operator):
assert operator.conn_id == "weaviate_conn"
assert operator.class_name == "my_class"
- assert operator.input_data == {"data": "sample_data"}
+ assert operator.input_data == [{"data": "sample_data"}]
assert operator.batch_params == {}
assert operator.hook_params == {}
@@ -47,9 +47,15 @@ class TestWeaviateIngestOperator:
operator.execute(context=None)
operator.hook.batch_data.assert_called_once_with(
- "my_class", {"data": "sample_data"}, vector_col="Vector", **{}
+ class_name="my_class",
+ data=[{"data": "sample_data"}],
+ batch_config_params={},
+ vector_col="Vector",
+ insertion_errors=[],
+ uuid_col="id",
+ tenant=None,
)
- mock_log.debug.assert_called_once_with("Input data: %s", {"data": "sample_data"})
+ mock_log.debug.assert_called_once_with("Input data: %s", [{"data": "sample_data"}])
@pytest.mark.db_test
def test_templates(self, create_task_instance_of_operator):