This is an automated email from the ASF dual-hosted git repository.

utkarsharma pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new ff3b8daac0 Add 'uuid_column', 'tenant' params to WeaviateIngestOperator (#36387)
ff3b8daac0 is described below

commit ff3b8daac0cbf3c885ea1479b1fb9cfcb2261f21
Author: Utkarsh Sharma <[email protected]>
AuthorDate: Sat Dec 23 14:03:23 2023 +0530

    Add 'uuid_column', 'tenant' params to WeaviateIngestOperator (#36387)
    
    * Add 'uuid_column', 'tenant' params to WeaviateIngestOperator
    
    * Update airflow/providers/weaviate/operators/weaviate.py
    
    Co-authored-by: Pankaj Singh <[email protected]>
    
    * Fix test
    
    ---------
    
    Co-authored-by: Pankaj Singh <[email protected]>
---
 airflow/providers/weaviate/operators/weaviate.py    | 17 +++++++++++++----
 tests/providers/weaviate/operators/test_weaviate.py | 14 ++++++++++----
 2 files changed, 23 insertions(+), 8 deletions(-)

diff --git a/airflow/providers/weaviate/operators/weaviate.py b/airflow/providers/weaviate/operators/weaviate.py
index d12a2c2e6c..c36e6aa0d2 100644
--- a/airflow/providers/weaviate/operators/weaviate.py
+++ b/airflow/providers/weaviate/operators/weaviate.py
@@ -60,6 +60,8 @@ class WeaviateIngestOperator(BaseOperator):
         input_json: list[dict[str, Any]] | pd.DataFrame | None = None,
         input_data: list[dict[str, Any]] | pd.DataFrame | None = None,
         vector_col: str = "Vector",
+        uuid_column: str = "id",
+        tenant: str | None = None,
         **kwargs: Any,
     ) -> None:
         self.batch_params = kwargs.pop("batch_params", {})
@@ -70,6 +72,8 @@ class WeaviateIngestOperator(BaseOperator):
         self.conn_id = conn_id
         self.vector_col = vector_col
         self.input_json = input_json
+        self.uuid_column = uuid_column
+        self.tenant = tenant
         if input_data is not None:
             self.input_data = input_data
         elif input_json is not None:
@@ -87,11 +91,16 @@ class WeaviateIngestOperator(BaseOperator):
         """Return an instance of the WeaviateHook."""
         return WeaviateHook(conn_id=self.conn_id, **self.hook_params)
 
-    def execute(self, context: Context) -> None:
+    def execute(self, context: Context) -> list:
         self.log.debug("Input data: %s", self.input_data)
+        insertion_errors: list = []
         self.hook.batch_data(
-            self.class_name,
-            self.input_data,
-            **self.batch_params,
+            class_name=self.class_name,
+            data=self.input_data,
+            batch_config_params=self.batch_params,
             vector_col=self.vector_col,
+            insertion_errors=insertion_errors,
+            uuid_col=self.uuid_column,
+            tenant=self.tenant,
         )
+        return insertion_errors
diff --git a/tests/providers/weaviate/operators/test_weaviate.py b/tests/providers/weaviate/operators/test_weaviate.py
index 7a2c362494..775a568be1 100644
--- a/tests/providers/weaviate/operators/test_weaviate.py
+++ b/tests/providers/weaviate/operators/test_weaviate.py
@@ -30,13 +30,13 @@ class TestWeaviateIngestOperator:
             task_id="weaviate_task",
             conn_id="weaviate_conn",
             class_name="my_class",
-            input_json={"data": "sample_data"},
+            input_json=[{"data": "sample_data"}],
         )
 
     def test_constructor(self, operator):
         assert operator.conn_id == "weaviate_conn"
         assert operator.class_name == "my_class"
-        assert operator.input_data == {"data": "sample_data"}
+        assert operator.input_data == [{"data": "sample_data"}]
         assert operator.batch_params == {}
         assert operator.hook_params == {}
 
@@ -47,9 +47,15 @@ class TestWeaviateIngestOperator:
         operator.execute(context=None)
 
         operator.hook.batch_data.assert_called_once_with(
-            "my_class", {"data": "sample_data"}, vector_col="Vector", **{}
+            class_name="my_class",
+            data=[{"data": "sample_data"}],
+            batch_config_params={},
+            vector_col="Vector",
+            insertion_errors=[],
+            uuid_col="id",
+            tenant=None,
         )
-        mock_log.debug.assert_called_once_with("Input data: %s", {"data": "sample_data"})
+        mock_log.debug.assert_called_once_with("Input data: %s", [{"data": "sample_data"}])
 
     @pytest.mark.db_test
     def test_templates(self, create_task_instance_of_operator):

Reply via email to