This is an automated email from the ASF dual-hosted git repository.

utkarsharma pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 8850715e22 Remove 'insertion_errors' as required argument (#36435)
8850715e22 is described below

commit 8850715e22dc8fd69dfc234efed805cc75708938
Author: Utkarsh Sharma <[email protected]>
AuthorDate: Tue Dec 26 15:30:52 2023 +0530

    Remove 'insertion_errors' as required argument (#36435)
---
 airflow/providers/weaviate/hooks/weaviate.py        | 4 +---
 airflow/providers/weaviate/operators/weaviate.py    | 1 -
 tests/providers/weaviate/hooks/test_weaviate.py     | 4 ++--
 tests/providers/weaviate/operators/test_weaviate.py | 1 -
 4 files changed, 3 insertions(+), 7 deletions(-)

diff --git a/airflow/providers/weaviate/hooks/weaviate.py b/airflow/providers/weaviate/hooks/weaviate.py
index d0b8db37cb..10f784e9fe 100644
--- a/airflow/providers/weaviate/hooks/weaviate.py
+++ b/airflow/providers/weaviate/hooks/weaviate.py
@@ -385,7 +385,6 @@ class WeaviateHook(BaseHook):
         self,
         class_name: str,
         data: list[dict[str, Any]] | pd.DataFrame,
-        insertion_errors: list,
         batch_config_params: dict[str, Any] | None = None,
         vector_col: str = "Vector",
         uuid_col: str = "id",
@@ -397,7 +396,6 @@ class WeaviateHook(BaseHook):
 
         :param class_name: The name of the class that objects belongs to.
         :param data: list or dataframe of objects we want to add.
-        :param insertion_errors: list to hold errors while inserting.
         :param batch_config_params: dict of batch configuration option.
             .. seealso:: `batch_config_params options <https://weaviate-python-client.readthedocs.io/en/v3.25.3/weaviate.batch.html#weaviate.batch.Batch.configure>`__
         :param vector_col: name of the column containing the vector.
@@ -408,6 +406,7 @@ class WeaviateHook(BaseHook):
         data = self._convert_dataframe_to_list(data)
         total_results = 0
         error_results = 0
+        insertion_errors: list = []
 
         def _process_batch_errors(
             results: list,
@@ -1070,7 +1069,6 @@ class WeaviateHook(BaseHook):
             insertion_errors = self.batch_data(
                 class_name=class_name,
                 data=data,
-                insertion_errors=insertion_errors,
                 batch_config_params=batch_config_params,
                 vector_col=vector_column,
                 uuid_col=uuid_column,
diff --git a/airflow/providers/weaviate/operators/weaviate.py b/airflow/providers/weaviate/operators/weaviate.py
index aa942c61ad..d23d6f3cfa 100644
--- a/airflow/providers/weaviate/operators/weaviate.py
+++ b/airflow/providers/weaviate/operators/weaviate.py
@@ -99,7 +99,6 @@ class WeaviateIngestOperator(BaseOperator):
             data=self.input_data,
             batch_config_params=self.batch_params,
             vector_col=self.vector_col,
-            insertion_errors=insertion_errors,
             uuid_col=self.uuid_column,
             tenant=self.tenant,
         )
diff --git a/tests/providers/weaviate/hooks/test_weaviate.py b/tests/providers/weaviate/hooks/test_weaviate.py
index 5b8d5ce65d..9d931b9397 100644
--- a/tests/providers/weaviate/hooks/test_weaviate.py
+++ b/tests/providers/weaviate/hooks/test_weaviate.py
@@ -428,7 +428,7 @@ def test_batch_data(data, expected_length, weaviate_hook):
     test_class_name = "TestClass"
 
     # Test the batch_data method
-    weaviate_hook.batch_data(test_class_name, data, insertion_errors=[])
+    weaviate_hook.batch_data(test_class_name, data)
 
     # Assert that the batch_data method was called with the correct arguments
     mock_client.batch.configure.assert_called_once()
@@ -446,7 +446,7 @@ def test_batch_data_retry(get_conn, weaviate_hook):
     error.response = response
     side_effect = [None, error, None, error, None]
     
     get_conn.return_value.batch.__enter__.return_value.add_data_object.side_effect = side_effect
-    weaviate_hook.batch_data("TestClass", data, insertion_errors=[])
+    weaviate_hook.batch_data("TestClass", data)
     assert get_conn.return_value.batch.__enter__.return_value.add_data_object.call_count == len(side_effect)
 
 
diff --git a/tests/providers/weaviate/operators/test_weaviate.py b/tests/providers/weaviate/operators/test_weaviate.py
index cd65366453..b675cf5e64 100644
--- a/tests/providers/weaviate/operators/test_weaviate.py
+++ b/tests/providers/weaviate/operators/test_weaviate.py
@@ -54,7 +54,6 @@ class TestWeaviateIngestOperator:
             data=[{"data": "sample_data"}],
             batch_config_params={},
             vector_col="Vector",
-            insertion_errors=[],
             uuid_col="id",
             tenant=None,
         )

Reply via email to