This is an automated email from the ASF dual-hosted git repository.

vincbeck pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 2c44fdd12c8 Removed unnecessary `aws_conn_id` param from operators constructors (#51236)
2c44fdd12c8 is described below

commit 2c44fdd12c829b298bcdd9efe0673748ad43e131
Author: dominikhei <[email protected]>
AuthorDate: Mon Jun 9 16:32:05 2025 +0200

    Removed unnecessary `aws_conn_id` param from operators constructors (#51236)
    
    * Removed unnecessary aws_conn_id param from operators constructors
    
    * Added regression tests to operators and renamed no_conn test to default_conn
---
 .../amazon/aws/operators/cloud_formation.py        |  2 -
 .../providers/amazon/aws/operators/comprehend.py   |  2 -
 .../airflow/providers/amazon/aws/operators/dms.py  |  2 -
 .../airflow/providers/amazon/aws/operators/glue.py |  6 --
 .../amazon/aws/operators/test_cloud_formation.py   | 17 +++++
 .../unit/amazon/aws/operators/test_comprehend.py   | 23 ++++++
 .../tests/unit/amazon/aws/operators/test_dms.py    | 21 ++++++
 .../tests/unit/amazon/aws/operators/test_glue.py   | 84 ++++++++++++++++++++++
 .../tests/unit/amazon/aws/operators/test_rds.py    |  4 +-
 9 files changed, 147 insertions(+), 14 deletions(-)

diff --git a/providers/amazon/src/airflow/providers/amazon/aws/operators/cloud_formation.py b/providers/amazon/src/airflow/providers/amazon/aws/operators/cloud_formation.py
index b3168dccc46..e4f3d211bea 100644
--- a/providers/amazon/src/airflow/providers/amazon/aws/operators/cloud_formation.py
+++ b/providers/amazon/src/airflow/providers/amazon/aws/operators/cloud_formation.py
@@ -98,13 +98,11 @@ class 
CloudFormationDeleteStackOperator(AwsBaseOperator[CloudFormationHook]):
         *,
         stack_name: str,
         cloudformation_parameters: dict | None = None,
-        aws_conn_id: str | None = "aws_default",
         **kwargs,
     ):
         super().__init__(**kwargs)
         self.cloudformation_parameters = cloudformation_parameters or {}
         self.stack_name = stack_name
-        self.aws_conn_id = aws_conn_id
 
     def execute(self, context: Context):
         self.log.info("CloudFormation Parameters: %s", 
self.cloudformation_parameters)
diff --git 
a/providers/amazon/src/airflow/providers/amazon/aws/operators/comprehend.py 
b/providers/amazon/src/airflow/providers/amazon/aws/operators/comprehend.py
index e8bc64973c7..c1b459bc34e 100644
--- a/providers/amazon/src/airflow/providers/amazon/aws/operators/comprehend.py
+++ b/providers/amazon/src/airflow/providers/amazon/aws/operators/comprehend.py
@@ -289,7 +289,6 @@ class 
ComprehendCreateDocumentClassifierOperator(AwsBaseOperator[ComprehendHook]
         waiter_delay: int = 60,
         waiter_max_attempts: int = 20,
         deferrable: bool = conf.getboolean("operators", "default_deferrable", 
fallback=False),
-        aws_conn_id: str | None = "aws_default",
         **kwargs,
     ):
         super().__init__(**kwargs)
@@ -305,7 +304,6 @@ class 
ComprehendCreateDocumentClassifierOperator(AwsBaseOperator[ComprehendHook]
         self.waiter_delay = waiter_delay
         self.waiter_max_attempts = waiter_max_attempts
         self.deferrable = deferrable
-        self.aws_conn_id = aws_conn_id
 
     def execute(self, context: Context) -> str:
         if self.output_data_config:
diff --git a/providers/amazon/src/airflow/providers/amazon/aws/operators/dms.py 
b/providers/amazon/src/airflow/providers/amazon/aws/operators/dms.py
index 75897f2b897..b9af469a506 100644
--- a/providers/amazon/src/airflow/providers/amazon/aws/operators/dms.py
+++ b/providers/amazon/src/airflow/providers/amazon/aws/operators/dms.py
@@ -91,7 +91,6 @@ class DmsCreateTaskOperator(AwsBaseOperator[DmsHook]):
         table_mappings: dict,
         migration_type: str = "full-load",
         create_task_kwargs: dict | None = None,
-        aws_conn_id: str | None = "aws_default",
         **kwargs,
     ):
         super().__init__(**kwargs)
@@ -102,7 +101,6 @@ class DmsCreateTaskOperator(AwsBaseOperator[DmsHook]):
         self.migration_type = migration_type
         self.table_mappings = table_mappings
         self.create_task_kwargs = create_task_kwargs or {}
-        self.aws_conn_id = aws_conn_id
 
     def execute(self, context: Context):
         """
diff --git 
a/providers/amazon/src/airflow/providers/amazon/aws/operators/glue.py 
b/providers/amazon/src/airflow/providers/amazon/aws/operators/glue.py
index 2b7496de78a..4b80d47b046 100644
--- a/providers/amazon/src/airflow/providers/amazon/aws/operators/glue.py
+++ b/providers/amazon/src/airflow/providers/amazon/aws/operators/glue.py
@@ -313,7 +313,6 @@ class 
GlueDataQualityOperator(AwsBaseOperator[GlueDataQualityHook]):
         description: str = "AWS Glue Data Quality Rule Set With Airflow",
         update_rule_set: bool = False,
         data_quality_ruleset_kwargs: dict | None = None,
-        aws_conn_id: str | None = "aws_default",
         **kwargs,
     ):
         super().__init__(**kwargs)
@@ -322,7 +321,6 @@ class 
GlueDataQualityOperator(AwsBaseOperator[GlueDataQualityHook]):
         self.description = description
         self.update_rule_set = update_rule_set
         self.data_quality_ruleset_kwargs = data_quality_ruleset_kwargs or {}
-        self.aws_conn_id = aws_conn_id
 
     def validate_inputs(self) -> None:
         if not self.ruleset.startswith("Rules") or not 
self.ruleset.endswith("]"):
@@ -421,7 +419,6 @@ class 
GlueDataQualityRuleSetEvaluationRunOperator(AwsBaseOperator[GlueDataQualit
         waiter_delay: int = 60,
         waiter_max_attempts: int = 20,
         deferrable: bool = conf.getboolean("operators", "default_deferrable", 
fallback=False),
-        aws_conn_id: str | None = "aws_default",
         **kwargs,
     ):
         super().__init__(**kwargs)
@@ -437,7 +434,6 @@ class 
GlueDataQualityRuleSetEvaluationRunOperator(AwsBaseOperator[GlueDataQualit
         self.waiter_delay = waiter_delay
         self.waiter_max_attempts = waiter_max_attempts
         self.deferrable = deferrable
-        self.aws_conn_id = aws_conn_id
 
     def validate_inputs(self) -> None:
         glue_table = self.datasource.get("GlueTable", {})
@@ -584,7 +580,6 @@ class 
GlueDataQualityRuleRecommendationRunOperator(AwsBaseOperator[GlueDataQuali
         waiter_delay: int = 60,
         waiter_max_attempts: int = 20,
         deferrable: bool = conf.getboolean("operators", "default_deferrable", 
fallback=False),
-        aws_conn_id: str | None = "aws_default",
         **kwargs,
     ):
         super().__init__(**kwargs)
@@ -598,7 +593,6 @@ class 
GlueDataQualityRuleRecommendationRunOperator(AwsBaseOperator[GlueDataQuali
         self.waiter_delay = waiter_delay
         self.waiter_max_attempts = waiter_max_attempts
         self.deferrable = deferrable
-        self.aws_conn_id = aws_conn_id
 
     def execute(self, context: Context) -> str:
         glue_table = self.datasource.get("GlueTable", {})
diff --git 
a/providers/amazon/tests/unit/amazon/aws/operators/test_cloud_formation.py 
b/providers/amazon/tests/unit/amazon/aws/operators/test_cloud_formation.py
index 1230e5d27fb..bbcd41f2176 100644
--- a/providers/amazon/tests/unit/amazon/aws/operators/test_cloud_formation.py
+++ b/providers/amazon/tests/unit/amazon/aws/operators/test_cloud_formation.py
@@ -103,6 +103,23 @@ class TestCloudFormationCreateStackOperator:
 
         validate_template_fields(op)
 
+    def test_overwritten_conn_passed_to_hook(self):
+        OVERWRITTEN_CONN = "new-conn-id"
+        op = CloudFormationCreateStackOperator(
+            task_id="cf_create_stack_pass_conn",
+            stack_name="fake-stack",
+            cloudformation_parameters={},
+            aws_conn_id=OVERWRITTEN_CONN,
+        )
+        assert op.hook.aws_conn_id == OVERWRITTEN_CONN
+
+    def test_default_conn_passed_to_hook(self):
+        DEFAULT_CONN = "aws_default"
+        op = CloudFormationCreateStackOperator(
+            task_id="cf_create_stack_pass_default_conn", 
stack_name="fake-stack", cloudformation_parameters={}
+        )
+        assert op.hook.aws_conn_id == DEFAULT_CONN
+
 
 class TestCloudFormationDeleteStackOperator:
     def test_init(self):
diff --git 
a/providers/amazon/tests/unit/amazon/aws/operators/test_comprehend.py 
b/providers/amazon/tests/unit/amazon/aws/operators/test_comprehend.py
index 5d500f7637e..3e74c7896fa 100644
--- a/providers/amazon/tests/unit/amazon/aws/operators/test_comprehend.py
+++ b/providers/amazon/tests/unit/amazon/aws/operators/test_comprehend.py
@@ -80,6 +80,29 @@ class TestComprehendBaseOperator:
         assert comprehend_base_op.client == mocked_client
         comprehend_base_operator_mock_hook.assert_called_once()
 
+    def test_overwritten_conn_passed_to_hook(self):
+        OVERWRITTEN_CONN = "new-conn-id"
+        op = ComprehendBaseOperator(
+            task_id="comprehend_base_operator",
+            input_data_config=INPUT_DATA_CONFIG,
+            output_data_config=OUTPUT_DATA_CONFIG,
+            language_code=LANGUAGE_CODE,
+            data_access_role_arn=ROLE_ARN,
+            aws_conn_id=OVERWRITTEN_CONN,
+        )
+        assert op.hook.aws_conn_id == OVERWRITTEN_CONN
+
+    def test_default_conn_passed_to_hook(self):
+        DEFAULT_CONN = "aws_default"
+        op = ComprehendBaseOperator(
+            task_id="comprehend_base_operator",
+            input_data_config=INPUT_DATA_CONFIG,
+            output_data_config=OUTPUT_DATA_CONFIG,
+            language_code=LANGUAGE_CODE,
+            data_access_role_arn=ROLE_ARN,
+        )
+        assert op.hook.aws_conn_id == DEFAULT_CONN
+
 
 class TestComprehendStartPiiEntitiesDetectionJobOperator:
     JOB_ID = "random-job-id-1234567"
diff --git a/providers/amazon/tests/unit/amazon/aws/operators/test_dms.py 
b/providers/amazon/tests/unit/amazon/aws/operators/test_dms.py
index 3cb97cad9d2..5771483cd32 100644
--- a/providers/amazon/tests/unit/amazon/aws/operators/test_dms.py
+++ b/providers/amazon/tests/unit/amazon/aws/operators/test_dms.py
@@ -149,6 +149,27 @@ class TestDmsCreateTaskOperator:
 
         validate_template_fields(op)
 
+    def test_overwritten_conn_passed_to_hook(self):
+        OVERWRITTEN_CONN = "new-conn-id"
+        op = DmsCreateTaskOperator(
+            task_id="dms_create_task_operator",
+            **self.TASK_DATA,
+            aws_conn_id=OVERWRITTEN_CONN,
+            verify=True,
+            botocore_config={"read_timeout": 42},
+        )
+        assert op.hook.aws_conn_id == OVERWRITTEN_CONN
+
+    def test_default_conn_passed_to_hook(self):
+        DEFAULT_CONN = "aws_default"
+        op = DmsCreateTaskOperator(
+            task_id="dms_create_task_operator",
+            **self.TASK_DATA,
+            verify=True,
+            botocore_config={"read_timeout": 42},
+        )
+        assert op.hook.aws_conn_id == DEFAULT_CONN
+
 
 class TestDmsDeleteTaskOperator:
     TASK_DATA = {
diff --git a/providers/amazon/tests/unit/amazon/aws/operators/test_glue.py 
b/providers/amazon/tests/unit/amazon/aws/operators/test_glue.py
index 21ab76a7317..7a0210cf130 100644
--- a/providers/amazon/tests/unit/amazon/aws/operators/test_glue.py
+++ b/providers/amazon/tests/unit/amazon/aws/operators/test_glue.py
@@ -408,6 +408,25 @@ class TestGlueJobOperator:
         )
         validate_template_fields(operator)
 
+    def test_overwritten_conn_passed_to_hook(self):
+        OVERWRITTEN_CONN = "new-conn-id"
+        op = GlueJobOperator(
+            task_id=TASK_ID,
+            aws_conn_id=OVERWRITTEN_CONN,
+            iam_role_name="role_arn",
+            replace_script_file=True,
+        )
+        assert op.hook.aws_conn_id == OVERWRITTEN_CONN
+
+    def test_default_conn_passed_to_hook(self):
+        DEFAULT_CONN = "aws_default"
+        op = GlueJobOperator(
+            task_id=TASK_ID,
+            iam_role_name="role_arn",
+            replace_script_file=True,
+        )
+        assert op.hook.aws_conn_id == DEFAULT_CONN
+
 
 class TestGlueDataQualityOperator:
     RULE_SET_NAME = "TestRuleSet"
@@ -542,6 +561,23 @@ class TestGlueDataQualityOperator:
         )
         validate_template_fields(operator)
 
+    def test_overwritten_conn_passed_to_hook(self):
+        OVERWRITTEN_CONN = "new-conn-id"
+        op = GlueDataQualityOperator(
+            task_id="test_overwritten_conn_passed_to_hook",
+            name=self.RULE_SET_NAME,
+            ruleset=self.RULE_SET,
+            aws_conn_id=OVERWRITTEN_CONN,
+        )
+        assert op.hook.aws_conn_id == OVERWRITTEN_CONN
+
+    def test_default_conn_passed_to_hook(self):
+        DEFAULT_CONN = "aws_default"
+        op = GlueDataQualityOperator(
+            task_id="test_default_conn_passed_to_hook", 
name=self.RULE_SET_NAME, ruleset=self.RULE_SET
+        )
+        assert op.hook.aws_conn_id == DEFAULT_CONN
+
 
 class TestGlueDataQualityRuleSetEvaluationRunOperator:
     RUN_ID = "1234567890"
@@ -648,6 +684,29 @@ class TestGlueDataQualityRuleSetEvaluationRunOperator:
     def test_template_fields(self):
         validate_template_fields(self.operator)
 
+    def test_overwritten_conn_passed_to_hook(self):
+        OVERWRITTEN_CONN = "new-conn-id"
+        op = GlueDataQualityRuleSetEvaluationRunOperator(
+            task_id="test_overwritten_conn_passed_to_hook",
+            datasource=self.DATA_SOURCE,
+            role=self.ROLE,
+            rule_set_names=self.RULE_SET_NAMES,
+            show_results=False,
+            aws_conn_id=OVERWRITTEN_CONN,
+        )
+        assert op.hook.aws_conn_id == OVERWRITTEN_CONN
+
+    def test_default_conn_passed_to_hook(self):
+        DEFAULT_CONN = "aws_default"
+        op = GlueDataQualityRuleSetEvaluationRunOperator(
+            task_id="test_default_conn_passed_to_hook",
+            datasource=self.DATA_SOURCE,
+            role=self.ROLE,
+            rule_set_names=self.RULE_SET_NAMES,
+            show_results=False,
+        )
+        assert op.hook.aws_conn_id == DEFAULT_CONN
+
 
 class TestGlueDataQualityRuleRecommendationRunOperator:
     RUN_ID = "1234567890"
@@ -756,3 +815,28 @@ class TestGlueDataQualityRuleRecommendationRunOperator:
 
     def test_template_fields(self):
         validate_template_fields(self.operator)
+
+    def test_overwritten_conn_passed_to_hook(self):
+        OVERWRITTEN_CONN = "new-conn-id"
+        op = GlueDataQualityRuleRecommendationRunOperator(
+            task_id="test_overwritten_conn_passed_to_hook",
+            datasource=self.DATA_SOURCE,
+            role=self.ROLE,
+            number_of_workers=10,
+            timeout=1000,
+            recommendation_run_kwargs={"CreatedRulesetName": "test-ruleset"},
+            aws_conn_id=OVERWRITTEN_CONN,
+        )
+        assert op.hook.aws_conn_id == OVERWRITTEN_CONN
+
+    def test_default_conn_passed_to_hook(self):
+        DEFAULT_CONN = "aws_default"
+        op = GlueDataQualityRuleRecommendationRunOperator(
+            task_id="test_default_conn_passed_to_hook",
+            datasource=self.DATA_SOURCE,
+            role=self.ROLE,
+            number_of_workers=10,
+            timeout=1000,
+            recommendation_run_kwargs={"CreatedRulesetName": "test-ruleset"},
+        )
+        assert op.hook.aws_conn_id == DEFAULT_CONN
diff --git a/providers/amazon/tests/unit/amazon/aws/operators/test_rds.py 
b/providers/amazon/tests/unit/amazon/aws/operators/test_rds.py
index 412dbc1094a..565780db09a 100644
--- a/providers/amazon/tests/unit/amazon/aws/operators/test_rds.py
+++ b/providers/amazon/tests/unit/amazon/aws/operators/test_rds.py
@@ -181,9 +181,9 @@ class TestBaseRdsOperator:
         )
         assert op.hook.aws_conn_id == OVERWRITTEN_CONN
 
-    def test_no_conn_passed_to_hook(self):
+    def test_default_conn_passed_to_hook(self):
         DEFAULT_CONN = "aws_default"
-        op = RdsBaseOperator(task_id="test_no_conn_passed_to_hook_task", 
dag=self.dag)
+        op = RdsBaseOperator(task_id="test_default_conn_passed_to_hook_task", 
dag=self.dag)
         assert op.hook.aws_conn_id == DEFAULT_CONN
 
 

Reply via email to