This is an automated email from the ASF dual-hosted git repository.

husseinawala pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 46470aba68 Fix assignment of template field in `__init__` in `CloudDataTransferServiceCreateJobOperator` (#36909)
46470aba68 is described below

commit 46470aba68e5ebeee24a03dc22d012a50ee287ad
Author: rom sharon <[email protected]>
AuthorDate: Sun Feb 4 15:57:13 2024 +0200

    Fix assignment of template field in `__init__` in `CloudDataTransferServiceCreateJobOperator` (#36909)
    
    * fix initialization of templated field in constructor
    
    * remove file from exclude
    
    * add test for templated field
    
    * change body to be realistic
---
 .pre-commit-config.yaml                                  |  1 -
 .../cloud/operators/cloud_storage_transfer_service.py    |  4 +++-
 .../operators/test_cloud_storage_transfer_service.py     | 16 ++++++++++++----
 3 files changed, 15 insertions(+), 6 deletions(-)
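
For context (not part of the commit itself): the pre-commit exclusion removed below appears to belong to a check that requires templated fields to be assigned the raw constructor argument in `__init__`, so that Jinja rendering later operates on the original value. A minimal sketch of the pattern the fix adopts, using a hypothetical operator name:

    from copy import deepcopy

    class ExampleOperator:  # hypothetical, for illustration only
        template_fields = ("body",)

        def __init__(self, body, **kwargs):
            # Bind the raw argument first so templating (and the static
            # check) sees the field assigned verbatim.
            self.body = body
            if isinstance(self.body, dict):
                # Deep-copy only plain dicts so later in-place edits cannot
                # mutate the caller's dictionary; non-dict values (e.g. a
                # whole-field Jinja string) are stored untouched.
                self.body = deepcopy(body)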

diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index cd7565f0b4..ded656d9dc 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -318,7 +318,6 @@ repos:
               ^airflow\/providers\/google\/cloud\/operators\/bigquery\.py$|
               ^airflow\/providers\/amazon\/aws\/transfers\/gcs_to_s3\.py$|
               ^airflow\/providers\/databricks\/operators\/databricks\.py$|
-              ^airflow\/providers\/google\/cloud\/operators\/cloud_storage_transfer_service\.py$|
              ^airflow\/providers\/google\/cloud\/transfers\/bigquery_to_mysql\.py$|
               ^airflow\/providers\/amazon\/aws\/transfers\/redshift_to_s3\.py$|
               ^airflow\/providers\/google\/cloud\/operators\/compute\.py$|
diff --git a/airflow/providers/google/cloud/operators/cloud_storage_transfer_service.py b/airflow/providers/google/cloud/operators/cloud_storage_transfer_service.py
index ea6a7eacdb..9f25ac58f7 100644
--- a/airflow/providers/google/cloud/operators/cloud_storage_transfer_service.py
+++ b/airflow/providers/google/cloud/operators/cloud_storage_transfer_service.py
@@ -236,7 +236,9 @@ class CloudDataTransferServiceCreateJobOperator(GoogleCloudBaseOperator):
         **kwargs,
     ) -> None:
         super().__init__(**kwargs)
-        self.body = deepcopy(body)
+        self.body = body
+        if isinstance(self.body, dict):
+            self.body = deepcopy(body)
         self.aws_conn_id = aws_conn_id
         self.gcp_conn_id = gcp_conn_id
         self.api_version = api_version
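
A hypothetical usage sketch (assuming the google provider is installed; not part of this diff) of what the direct assignment preserves: Jinja expressions inside `body` reach the operator unchanged and are rendered before execution, because `body` is one of the operator's `template_fields`:

    from airflow.providers.google.cloud.operators.cloud_storage_transfer_service import (
        CloudDataTransferServiceCreateJobOperator,
    )

    create_job = CloudDataTransferServiceCreateJobOperator(
        task_id="create_transfer_job",
        # "{{ ds }}" is a built-in template variable; Airflow renders it
        # before execute() runs, and the conditional deepcopy leaves this
        # caller-supplied dict unmodified.
        body={"description": "nightly-transfer-{{ ds }}"},
    )
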
diff --git a/tests/providers/google/cloud/operators/test_cloud_storage_transfer_service.py b/tests/providers/google/cloud/operators/test_cloud_storage_transfer_service.py
index 40c2b87b0a..1ef5d6b729 100644
--- a/tests/providers/google/cloud/operators/test_cloud_storage_transfer_service.py
+++ b/tests/providers/google/cloud/operators/test_cloud_storage_transfer_service.py
@@ -114,6 +114,10 @@ VALID_TRANSFER_JOB_BASE: dict = {
     SCHEDULE: SCHEDULE_DICT,
     TRANSFER_SPEC: {GCS_DATA_SINK: {BUCKET_NAME: GCS_BUCKET_NAME, PATH: DESTINATION_PATH}},
 }
+VALID_TRANSFER_JOB_JINJA = deepcopy(VALID_TRANSFER_JOB_BASE)
+VALID_TRANSFER_JOB_JINJA[NAME] = "{{ dag.dag_id }}"
+VALID_TRANSFER_JOB_JINJA_RENDERED = deepcopy(VALID_TRANSFER_JOB_JINJA)
+VALID_TRANSFER_JOB_JINJA_RENDERED[NAME] = "TestGcpStorageTransferJobCreateOperator"
 VALID_TRANSFER_JOB_GCS = deepcopy(VALID_TRANSFER_JOB_BASE)
 VALID_TRANSFER_JOB_GCS[TRANSFER_SPEC].update(deepcopy(SOURCE_GCS))
 VALID_TRANSFER_JOB_AWS = deepcopy(VALID_TRANSFER_JOB_BASE)
@@ -324,21 +328,25 @@ class TestGcpStorageTransferJobCreateOperator:
     # (could be anything else) just to test if the templating works for all
     # fields
     @pytest.mark.db_test
+    @pytest.mark.parametrize(
+        "body, excepted",
+        [(VALID_TRANSFER_JOB_JINJA, VALID_TRANSFER_JOB_JINJA_RENDERED)],
+    )
    @mock.patch(
        "airflow.providers.google.cloud.operators.cloud_storage_transfer_service.CloudDataTransferServiceHook"
     )
-    def test_templates(self, _, create_task_instance_of_operator):
-        dag_id = "TestGcpStorageTransferJobCreateOperator_test_templates"
+    def test_templates(self, _, create_task_instance_of_operator, body, expected):
+        dag_id = "TestGcpStorageTransferJobCreateOperator"
         ti = create_task_instance_of_operator(
             CloudDataTransferServiceCreateJobOperator,
             dag_id=dag_id,
-            body={"description": "{{ dag.dag_id }}"},
+            body=body,
             gcp_conn_id="{{ dag.dag_id }}",
             aws_conn_id="{{ dag.dag_id }}",
             task_id="task-id",
         )
         ti.render_templates()
-        assert dag_id == getattr(ti.task, "body")[DESCRIPTION]
+        assert expected == getattr(ti.task, "body")
         assert dag_id == getattr(ti.task, "gcp_conn_id")
         assert dag_id == getattr(ti.task, "aws_conn_id")
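
For readers unfamiliar with the harness, a self-contained sketch (plain jinja2, deliberately not Airflow's internal API) of why the NAME entry nested in the body dict gets rendered: templating is applied recursively to strings inside dict-valued template fields:

    import jinja2

    def render_nested(value, context):
        # Recurse into dicts and render every string; this mirrors, in
        # simplified form, how dict-valued template fields are handled.
        if isinstance(value, dict):
            return {k: render_nested(v, context) for k, v in value.items()}
        if isinstance(value, str):
            return jinja2.Template(value).render(**context)
        return value

    class _Dag:  # stand-in for the DAG object in the template context
        dag_id = "TestGcpStorageTransferJobCreateOperator"

    body = {"name": "{{ dag.dag_id }}", "description": "static"}
    rendered = render_nested(body, {"dag": _Dag()})
    assert rendered["name"] == "TestGcpStorageTransferJobCreateOperator"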
 
