This is an automated email from the ASF dual-hosted git repository.
vincbeck pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 2c44fdd12c8 Removed unnecessary `aws_conn_id` param from operators constructors (#51236)
2c44fdd12c8 is described below
commit 2c44fdd12c829b298bcdd9efe0673748ad43e131
Author: dominikhei <[email protected]>
AuthorDate: Mon Jun 9 16:32:05 2025 +0200
Removed unnecessary `aws_conn_id` param from operators constructors (#51236)
* Removed unnecessary aws_conn_id param from operators constructors
* Added regression tests to operators and renamed no_conn test to
default_conn
---
.../amazon/aws/operators/cloud_formation.py | 2 -
.../providers/amazon/aws/operators/comprehend.py | 2 -
.../airflow/providers/amazon/aws/operators/dms.py | 2 -
.../airflow/providers/amazon/aws/operators/glue.py | 6 --
.../amazon/aws/operators/test_cloud_formation.py | 17 +++++
.../unit/amazon/aws/operators/test_comprehend.py | 23 ++++++
.../tests/unit/amazon/aws/operators/test_dms.py | 21 ++++++
.../tests/unit/amazon/aws/operators/test_glue.py | 84 ++++++++++++++++++++++
.../tests/unit/amazon/aws/operators/test_rds.py | 4 +-
9 files changed, 147 insertions(+), 14 deletions(-)
diff --git a/providers/amazon/src/airflow/providers/amazon/aws/operators/cloud_formation.py b/providers/amazon/src/airflow/providers/amazon/aws/operators/cloud_formation.py
index b3168dccc46..e4f3d211bea 100644
--- a/providers/amazon/src/airflow/providers/amazon/aws/operators/cloud_formation.py
+++ b/providers/amazon/src/airflow/providers/amazon/aws/operators/cloud_formation.py
@@ -98,13 +98,11 @@ class CloudFormationDeleteStackOperator(AwsBaseOperator[CloudFormationHook]):
*,
stack_name: str,
cloudformation_parameters: dict | None = None,
- aws_conn_id: str | None = "aws_default",
**kwargs,
):
super().__init__(**kwargs)
self.cloudformation_parameters = cloudformation_parameters or {}
self.stack_name = stack_name
- self.aws_conn_id = aws_conn_id
def execute(self, context: Context):
self.log.info("CloudFormation Parameters: %s",
self.cloudformation_parameters)
diff --git a/providers/amazon/src/airflow/providers/amazon/aws/operators/comprehend.py b/providers/amazon/src/airflow/providers/amazon/aws/operators/comprehend.py
index e8bc64973c7..c1b459bc34e 100644
--- a/providers/amazon/src/airflow/providers/amazon/aws/operators/comprehend.py
+++ b/providers/amazon/src/airflow/providers/amazon/aws/operators/comprehend.py
@@ -289,7 +289,6 @@ class ComprehendCreateDocumentClassifierOperator(AwsBaseOperator[ComprehendHook]
waiter_delay: int = 60,
waiter_max_attempts: int = 20,
deferrable: bool = conf.getboolean("operators", "default_deferrable",
fallback=False),
- aws_conn_id: str | None = "aws_default",
**kwargs,
):
super().__init__(**kwargs)
@@ -305,7 +304,6 @@ class ComprehendCreateDocumentClassifierOperator(AwsBaseOperator[ComprehendHook]
self.waiter_delay = waiter_delay
self.waiter_max_attempts = waiter_max_attempts
self.deferrable = deferrable
- self.aws_conn_id = aws_conn_id
def execute(self, context: Context) -> str:
if self.output_data_config:
diff --git a/providers/amazon/src/airflow/providers/amazon/aws/operators/dms.py b/providers/amazon/src/airflow/providers/amazon/aws/operators/dms.py
index 75897f2b897..b9af469a506 100644
--- a/providers/amazon/src/airflow/providers/amazon/aws/operators/dms.py
+++ b/providers/amazon/src/airflow/providers/amazon/aws/operators/dms.py
@@ -91,7 +91,6 @@ class DmsCreateTaskOperator(AwsBaseOperator[DmsHook]):
table_mappings: dict,
migration_type: str = "full-load",
create_task_kwargs: dict | None = None,
- aws_conn_id: str | None = "aws_default",
**kwargs,
):
super().__init__(**kwargs)
@@ -102,7 +101,6 @@ class DmsCreateTaskOperator(AwsBaseOperator[DmsHook]):
self.migration_type = migration_type
self.table_mappings = table_mappings
self.create_task_kwargs = create_task_kwargs or {}
- self.aws_conn_id = aws_conn_id
def execute(self, context: Context):
"""
diff --git a/providers/amazon/src/airflow/providers/amazon/aws/operators/glue.py b/providers/amazon/src/airflow/providers/amazon/aws/operators/glue.py
index 2b7496de78a..4b80d47b046 100644
--- a/providers/amazon/src/airflow/providers/amazon/aws/operators/glue.py
+++ b/providers/amazon/src/airflow/providers/amazon/aws/operators/glue.py
@@ -313,7 +313,6 @@ class GlueDataQualityOperator(AwsBaseOperator[GlueDataQualityHook]):
description: str = "AWS Glue Data Quality Rule Set With Airflow",
update_rule_set: bool = False,
data_quality_ruleset_kwargs: dict | None = None,
- aws_conn_id: str | None = "aws_default",
**kwargs,
):
super().__init__(**kwargs)
@@ -322,7 +321,6 @@ class GlueDataQualityOperator(AwsBaseOperator[GlueDataQualityHook]):
self.description = description
self.update_rule_set = update_rule_set
self.data_quality_ruleset_kwargs = data_quality_ruleset_kwargs or {}
- self.aws_conn_id = aws_conn_id
def validate_inputs(self) -> None:
if not self.ruleset.startswith("Rules") or not
self.ruleset.endswith("]"):
@@ -421,7 +419,6 @@ class GlueDataQualityRuleSetEvaluationRunOperator(AwsBaseOperator[GlueDataQualit
waiter_delay: int = 60,
waiter_max_attempts: int = 20,
deferrable: bool = conf.getboolean("operators", "default_deferrable",
fallback=False),
- aws_conn_id: str | None = "aws_default",
**kwargs,
):
super().__init__(**kwargs)
@@ -437,7 +434,6 @@ class GlueDataQualityRuleSetEvaluationRunOperator(AwsBaseOperator[GlueDataQualit
self.waiter_delay = waiter_delay
self.waiter_max_attempts = waiter_max_attempts
self.deferrable = deferrable
- self.aws_conn_id = aws_conn_id
def validate_inputs(self) -> None:
glue_table = self.datasource.get("GlueTable", {})
@@ -584,7 +580,6 @@ class GlueDataQualityRuleRecommendationRunOperator(AwsBaseOperator[GlueDataQuali
waiter_delay: int = 60,
waiter_max_attempts: int = 20,
deferrable: bool = conf.getboolean("operators", "default_deferrable",
fallback=False),
- aws_conn_id: str | None = "aws_default",
**kwargs,
):
super().__init__(**kwargs)
@@ -598,7 +593,6 @@ class GlueDataQualityRuleRecommendationRunOperator(AwsBaseOperator[GlueDataQuali
self.waiter_delay = waiter_delay
self.waiter_max_attempts = waiter_max_attempts
self.deferrable = deferrable
- self.aws_conn_id = aws_conn_id
def execute(self, context: Context) -> str:
glue_table = self.datasource.get("GlueTable", {})
diff --git a/providers/amazon/tests/unit/amazon/aws/operators/test_cloud_formation.py b/providers/amazon/tests/unit/amazon/aws/operators/test_cloud_formation.py
index 1230e5d27fb..bbcd41f2176 100644
--- a/providers/amazon/tests/unit/amazon/aws/operators/test_cloud_formation.py
+++ b/providers/amazon/tests/unit/amazon/aws/operators/test_cloud_formation.py
@@ -103,6 +103,23 @@ class TestCloudFormationCreateStackOperator:
validate_template_fields(op)
+ def test_overwritten_conn_passed_to_hook(self):
+ OVERWRITTEN_CONN = "new-conn-id"
+ op = CloudFormationCreateStackOperator(
+ task_id="cf_create_stack_pass_conn",
+ stack_name="fake-stack",
+ cloudformation_parameters={},
+ aws_conn_id=OVERWRITTEN_CONN,
+ )
+ assert op.hook.aws_conn_id == OVERWRITTEN_CONN
+
+ def test_default_conn_passed_to_hook(self):
+ DEFAULT_CONN = "aws_default"
+ op = CloudFormationCreateStackOperator(
+ task_id="cf_create_stack_pass_default_conn", stack_name="fake-stack", cloudformation_parameters={}
+ )
+ assert op.hook.aws_conn_id == DEFAULT_CONN
+
class TestCloudFormationDeleteStackOperator:
def test_init(self):
diff --git a/providers/amazon/tests/unit/amazon/aws/operators/test_comprehend.py b/providers/amazon/tests/unit/amazon/aws/operators/test_comprehend.py
index 5d500f7637e..3e74c7896fa 100644
--- a/providers/amazon/tests/unit/amazon/aws/operators/test_comprehend.py
+++ b/providers/amazon/tests/unit/amazon/aws/operators/test_comprehend.py
@@ -80,6 +80,29 @@ class TestComprehendBaseOperator:
assert comprehend_base_op.client == mocked_client
comprehend_base_operator_mock_hook.assert_called_once()
+ def test_overwritten_conn_passed_to_hook(self):
+ OVERWRITTEN_CONN = "new-conn-id"
+ op = ComprehendBaseOperator(
+ task_id="comprehend_base_operator",
+ input_data_config=INPUT_DATA_CONFIG,
+ output_data_config=OUTPUT_DATA_CONFIG,
+ language_code=LANGUAGE_CODE,
+ data_access_role_arn=ROLE_ARN,
+ aws_conn_id=OVERWRITTEN_CONN,
+ )
+ assert op.hook.aws_conn_id == OVERWRITTEN_CONN
+
+ def test_default_conn_passed_to_hook(self):
+ DEFAULT_CONN = "aws_default"
+ op = ComprehendBaseOperator(
+ task_id="comprehend_base_operator",
+ input_data_config=INPUT_DATA_CONFIG,
+ output_data_config=OUTPUT_DATA_CONFIG,
+ language_code=LANGUAGE_CODE,
+ data_access_role_arn=ROLE_ARN,
+ )
+ assert op.hook.aws_conn_id == DEFAULT_CONN
+
class TestComprehendStartPiiEntitiesDetectionJobOperator:
JOB_ID = "random-job-id-1234567"
diff --git a/providers/amazon/tests/unit/amazon/aws/operators/test_dms.py b/providers/amazon/tests/unit/amazon/aws/operators/test_dms.py
index 3cb97cad9d2..5771483cd32 100644
--- a/providers/amazon/tests/unit/amazon/aws/operators/test_dms.py
+++ b/providers/amazon/tests/unit/amazon/aws/operators/test_dms.py
@@ -149,6 +149,27 @@ class TestDmsCreateTaskOperator:
validate_template_fields(op)
+ def test_overwritten_conn_passed_to_hook(self):
+ OVERWRITTEN_CONN = "new-conn-id"
+ op = DmsCreateTaskOperator(
+ task_id="dms_create_task_operator",
+ **self.TASK_DATA,
+ aws_conn_id=OVERWRITTEN_CONN,
+ verify=True,
+ botocore_config={"read_timeout": 42},
+ )
+ assert op.hook.aws_conn_id == OVERWRITTEN_CONN
+
+ def test_default_conn_passed_to_hook(self):
+ DEFAULT_CONN = "aws_default"
+ op = DmsCreateTaskOperator(
+ task_id="dms_create_task_operator",
+ **self.TASK_DATA,
+ verify=True,
+ botocore_config={"read_timeout": 42},
+ )
+ assert op.hook.aws_conn_id == DEFAULT_CONN
+
class TestDmsDeleteTaskOperator:
TASK_DATA = {
diff --git a/providers/amazon/tests/unit/amazon/aws/operators/test_glue.py b/providers/amazon/tests/unit/amazon/aws/operators/test_glue.py
index 21ab76a7317..7a0210cf130 100644
--- a/providers/amazon/tests/unit/amazon/aws/operators/test_glue.py
+++ b/providers/amazon/tests/unit/amazon/aws/operators/test_glue.py
@@ -408,6 +408,25 @@ class TestGlueJobOperator:
)
validate_template_fields(operator)
+ def test_overwritten_conn_passed_to_hook(self):
+ OVERWRITTEN_CONN = "new-conn-id"
+ op = GlueJobOperator(
+ task_id=TASK_ID,
+ aws_conn_id=OVERWRITTEN_CONN,
+ iam_role_name="role_arn",
+ replace_script_file=True,
+ )
+ assert op.hook.aws_conn_id == OVERWRITTEN_CONN
+
+ def test_default_conn_passed_to_hook(self):
+ DEFAULT_CONN = "aws_default"
+ op = GlueJobOperator(
+ task_id=TASK_ID,
+ iam_role_name="role_arn",
+ replace_script_file=True,
+ )
+ assert op.hook.aws_conn_id == DEFAULT_CONN
+
class TestGlueDataQualityOperator:
RULE_SET_NAME = "TestRuleSet"
@@ -542,6 +561,23 @@ class TestGlueDataQualityOperator:
)
validate_template_fields(operator)
+ def test_overwritten_conn_passed_to_hook(self):
+ OVERWRITTEN_CONN = "new-conn-id"
+ op = GlueDataQualityOperator(
+ task_id="test_overwritten_conn_passed_to_hook",
+ name=self.RULE_SET_NAME,
+ ruleset=self.RULE_SET,
+ aws_conn_id=OVERWRITTEN_CONN,
+ )
+ assert op.hook.aws_conn_id == OVERWRITTEN_CONN
+
+ def test_default_conn_passed_to_hook(self):
+ DEFAULT_CONN = "aws_default"
+ op = GlueDataQualityOperator(
+ task_id="test_default_conn_passed_to_hook", name=self.RULE_SET_NAME, ruleset=self.RULE_SET
+ )
+ assert op.hook.aws_conn_id == DEFAULT_CONN
+
class TestGlueDataQualityRuleSetEvaluationRunOperator:
RUN_ID = "1234567890"
@@ -648,6 +684,29 @@ class TestGlueDataQualityRuleSetEvaluationRunOperator:
def test_template_fields(self):
validate_template_fields(self.operator)
+ def test_overwritten_conn_passed_to_hook(self):
+ OVERWRITTEN_CONN = "new-conn-id"
+ op = GlueDataQualityRuleSetEvaluationRunOperator(
+ task_id="test_overwritten_conn_passed_to_hook",
+ datasource=self.DATA_SOURCE,
+ role=self.ROLE,
+ rule_set_names=self.RULE_SET_NAMES,
+ show_results=False,
+ aws_conn_id=OVERWRITTEN_CONN,
+ )
+ assert op.hook.aws_conn_id == OVERWRITTEN_CONN
+
+ def test_default_conn_passed_to_hook(self):
+ DEFAULT_CONN = "aws_default"
+ op = GlueDataQualityRuleSetEvaluationRunOperator(
+ task_id="test_default_conn_passed_to_hook",
+ datasource=self.DATA_SOURCE,
+ role=self.ROLE,
+ rule_set_names=self.RULE_SET_NAMES,
+ show_results=False,
+ )
+ assert op.hook.aws_conn_id == DEFAULT_CONN
+
class TestGlueDataQualityRuleRecommendationRunOperator:
RUN_ID = "1234567890"
@@ -756,3 +815,28 @@ class TestGlueDataQualityRuleRecommendationRunOperator:
def test_template_fields(self):
validate_template_fields(self.operator)
+
+ def test_overwritten_conn_passed_to_hook(self):
+ OVERWRITTEN_CONN = "new-conn-id"
+ op = GlueDataQualityRuleRecommendationRunOperator(
+ task_id="test_overwritten_conn_passed_to_hook",
+ datasource=self.DATA_SOURCE,
+ role=self.ROLE,
+ number_of_workers=10,
+ timeout=1000,
+ recommendation_run_kwargs={"CreatedRulesetName": "test-ruleset"},
+ aws_conn_id=OVERWRITTEN_CONN,
+ )
+ assert op.hook.aws_conn_id == OVERWRITTEN_CONN
+
+ def test_default_conn_passed_to_hook(self):
+ DEFAULT_CONN = "aws_default"
+ op = GlueDataQualityRuleRecommendationRunOperator(
+ task_id="test_default_conn_passed_to_hook",
+ datasource=self.DATA_SOURCE,
+ role=self.ROLE,
+ number_of_workers=10,
+ timeout=1000,
+ recommendation_run_kwargs={"CreatedRulesetName": "test-ruleset"},
+ )
+ assert op.hook.aws_conn_id == DEFAULT_CONN
diff --git a/providers/amazon/tests/unit/amazon/aws/operators/test_rds.py b/providers/amazon/tests/unit/amazon/aws/operators/test_rds.py
index 412dbc1094a..565780db09a 100644
--- a/providers/amazon/tests/unit/amazon/aws/operators/test_rds.py
+++ b/providers/amazon/tests/unit/amazon/aws/operators/test_rds.py
@@ -181,9 +181,9 @@ class TestBaseRdsOperator:
)
assert op.hook.aws_conn_id == OVERWRITTEN_CONN
- def test_no_conn_passed_to_hook(self):
+ def test_default_conn_passed_to_hook(self):
DEFAULT_CONN = "aws_default"
- op = RdsBaseOperator(task_id="test_no_conn_passed_to_hook_task", dag=self.dag)
+ op = RdsBaseOperator(task_id="test_default_conn_passed_to_hook_task", dag=self.dag)
assert op.hook.aws_conn_id == DEFAULT_CONN