This is an automated email from the ASF dual-hosted git repository.
husseinawala pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 46470aba68 Fix assignment of template field in `__init__` in
`CloudDataTransferServiceCreateJobOperator` (#36909)
46470aba68 is described below
commit 46470aba68e5ebeee24a03dc22d012a50ee287ad
Author: rom sharon <[email protected]>
AuthorDate: Sun Feb 4 15:57:13 2024 +0200
Fix assignment of template field in `__init__` in
`CloudDataTransferServiceCreateJobOperator` (#36909)
* Fix initialization of the templated `body` field in the constructor
* Remove the operator module from the pre-commit exclude list
* Add a test for the templated field
* Change the test `body` to a realistic transfer-job definition
---
.pre-commit-config.yaml | 1 -
.../cloud/operators/cloud_storage_transfer_service.py | 4 +++-
.../operators/test_cloud_storage_transfer_service.py | 16 ++++++++++++----
3 files changed, 15 insertions(+), 6 deletions(-)
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index cd7565f0b4..ded656d9dc 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -318,7 +318,6 @@ repos:
^airflow\/providers\/google\/cloud\/operators\/bigquery\.py$|
^airflow\/providers\/amazon\/aws\/transfers\/gcs_to_s3\.py$|
^airflow\/providers\/databricks\/operators\/databricks\.py$|
-
^airflow\/providers\/google\/cloud\/operators\/cloud_storage_transfer_service\.py$|
^airflow\/providers\/google\/cloud\/transfers\/bigquery_to_mysql\.py$|
^airflow\/providers\/amazon\/aws\/transfers\/redshift_to_s3\.py$|
^airflow\/providers\/google\/cloud\/operators\/compute\.py$|
diff --git
a/airflow/providers/google/cloud/operators/cloud_storage_transfer_service.py
b/airflow/providers/google/cloud/operators/cloud_storage_transfer_service.py
index ea6a7eacdb..9f25ac58f7 100644
--- a/airflow/providers/google/cloud/operators/cloud_storage_transfer_service.py
+++ b/airflow/providers/google/cloud/operators/cloud_storage_transfer_service.py
@@ -236,7 +236,9 @@ class
CloudDataTransferServiceCreateJobOperator(GoogleCloudBaseOperator):
**kwargs,
) -> None:
super().__init__(**kwargs)
- self.body = deepcopy(body)
+ self.body = body
+ if isinstance(self.body, dict):
+ self.body = deepcopy(body)
self.aws_conn_id = aws_conn_id
self.gcp_conn_id = gcp_conn_id
self.api_version = api_version
diff --git
a/tests/providers/google/cloud/operators/test_cloud_storage_transfer_service.py
b/tests/providers/google/cloud/operators/test_cloud_storage_transfer_service.py
index 40c2b87b0a..1ef5d6b729 100644
---
a/tests/providers/google/cloud/operators/test_cloud_storage_transfer_service.py
+++
b/tests/providers/google/cloud/operators/test_cloud_storage_transfer_service.py
@@ -114,6 +114,10 @@ VALID_TRANSFER_JOB_BASE: dict = {
SCHEDULE: SCHEDULE_DICT,
TRANSFER_SPEC: {GCS_DATA_SINK: {BUCKET_NAME: GCS_BUCKET_NAME, PATH:
DESTINATION_PATH}},
}
+VALID_TRANSFER_JOB_JINJA = deepcopy(VALID_TRANSFER_JOB_BASE)
+VALID_TRANSFER_JOB_JINJA[NAME] = "{{ dag.dag_id }}"
+VALID_TRANSFER_JOB_JINJA_RENDERED = deepcopy(VALID_TRANSFER_JOB_JINJA)
+VALID_TRANSFER_JOB_JINJA_RENDERED[NAME] =
"TestGcpStorageTransferJobCreateOperator"
VALID_TRANSFER_JOB_GCS = deepcopy(VALID_TRANSFER_JOB_BASE)
VALID_TRANSFER_JOB_GCS[TRANSFER_SPEC].update(deepcopy(SOURCE_GCS))
VALID_TRANSFER_JOB_AWS = deepcopy(VALID_TRANSFER_JOB_BASE)
@@ -324,21 +328,25 @@ class TestGcpStorageTransferJobCreateOperator:
# (could be anything else) just to test if the templating works for all
# fields
@pytest.mark.db_test
+ @pytest.mark.parametrize(
+ "body, excepted",
+ [(VALID_TRANSFER_JOB_JINJA, VALID_TRANSFER_JOB_JINJA_RENDERED)],
+ )
@mock.patch(
"airflow.providers.google.cloud.operators.cloud_storage_transfer_service.CloudDataTransferServiceHook"
)
- def test_templates(self, _, create_task_instance_of_operator):
- dag_id = "TestGcpStorageTransferJobCreateOperator_test_templates"
+ def test_templates(self, _, create_task_instance_of_operator, body,
excepted):
+ dag_id = "TestGcpStorageTransferJobCreateOperator"
ti = create_task_instance_of_operator(
CloudDataTransferServiceCreateJobOperator,
dag_id=dag_id,
- body={"description": "{{ dag.dag_id }}"},
+ body=body,
gcp_conn_id="{{ dag.dag_id }}",
aws_conn_id="{{ dag.dag_id }}",
task_id="task-id",
)
ti.render_templates()
- assert dag_id == getattr(ti.task, "body")[DESCRIPTION]
+ assert excepted == getattr(ti.task, "body")
assert dag_id == getattr(ti.task, "gcp_conn_id")
assert dag_id == getattr(ti.task, "aws_conn_id")