This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 7084429f42 Fixing template_fields for WeaviateIngestOperator (#36359)
7084429f42 is described below

commit 7084429f42d0a006e777612c07b3471100f953c9
Author: vatsrahul1001 <[email protected]>
AuthorDate: Fri Dec 22 03:26:02 2023 +0530

    Fixing template_fields for WeaviateIngestOperator (#36359)
    
    * fixing template fields for WeaviateIngestOperator
    
    * removing fixtures which are not required
    
    * marking test as db_test since it accesses the database
---
 airflow/providers/weaviate/operators/weaviate.py    |  4 ++--
 tests/providers/weaviate/operators/test_weaviate.py | 17 +++++++++++++++++
 2 files changed, 19 insertions(+), 2 deletions(-)

diff --git a/airflow/providers/weaviate/operators/weaviate.py b/airflow/providers/weaviate/operators/weaviate.py
index 4e07a59edb..d12a2c2e6c 100644
--- a/airflow/providers/weaviate/operators/weaviate.py
+++ b/airflow/providers/weaviate/operators/weaviate.py
@@ -51,7 +51,7 @@ class WeaviateIngestOperator(BaseOperator):
     :param vector_col: key/column name in which the vectors are stored.
     """
 
-    template_fields: Sequence[str] = ("input_json",)
+    template_fields: Sequence[str] = ("input_json", "input_data")
 
     def __init__(
         self,
@@ -69,7 +69,7 @@ class WeaviateIngestOperator(BaseOperator):
         self.class_name = class_name
         self.conn_id = conn_id
         self.vector_col = vector_col
-
+        self.input_json = input_json
         if input_data is not None:
             self.input_data = input_data
         elif input_json is not None:
diff --git a/tests/providers/weaviate/operators/test_weaviate.py b/tests/providers/weaviate/operators/test_weaviate.py
index 7490b64dc6..7a2c362494 100644
--- a/tests/providers/weaviate/operators/test_weaviate.py
+++ b/tests/providers/weaviate/operators/test_weaviate.py
@@ -50,3 +50,20 @@ class TestWeaviateIngestOperator:
             "my_class", {"data": "sample_data"}, vector_col="Vector", **{}
         )
         mock_log.debug.assert_called_once_with("Input data: %s", {"data": 
"sample_data"})
+
+    @pytest.mark.db_test
+    def test_templates(self, create_task_instance_of_operator):
+        dag_id = "TestWeaviateIngestOperator"
+        ti = create_task_instance_of_operator(
+            WeaviateIngestOperator,
+            dag_id=dag_id,
+            task_id="task-id",
+            conn_id="weaviate_conn",
+            class_name="my_class",
+            input_json="{{ dag.dag_id }}",
+            input_data="{{ dag.dag_id }}",
+        )
+        ti.render_templates()
+
+        assert dag_id == ti.task.input_json
+        assert dag_id == ti.task.input_data

Reply via email to