This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new c75a105935 add type annotations to Amazon provider "execute_complete" methods (#36330)
c75a105935 is described below

commit c75a1059355df6e1edc42f77947150b4a4c5d51a
Author: Wei Lee <weilee...@gmail.com>
AuthorDate: Fri Feb 16 18:49:40 2024 +0800

    add type annotations to Amazon provider "execute_complete" methods (#36330)
    
    * style(providers/amazon): improve execute_complete type annotation
    
    * refactor(providers/amazon): change check_execute_complete_event to validate_execute_complete_event
    
    * style(providers/amazon): fix ecs type annotation
---
 airflow/providers/amazon/aws/operators/athena.py   |  5 +-
 airflow/providers/amazon/aws/operators/batch.py    | 11 ++---
 airflow/providers/amazon/aws/operators/ecs.py      |  8 +++-
 airflow/providers/amazon/aws/operators/eks.py      | 43 +++++++++--------
 airflow/providers/amazon/aws/operators/emr.py      | 55 ++++++++++++----------
 airflow/providers/amazon/aws/operators/glue.py     |  7 ++-
 .../providers/amazon/aws/operators/glue_crawler.py |  7 ++-
 .../amazon/aws/operators/glue_databrew.py          |  7 ++-
 .../amazon/aws/operators/lambda_function.py        |  3 ++
 airflow/providers/amazon/aws/operators/rds.py      | 33 ++++++++-----
 .../amazon/aws/operators/redshift_cluster.py       | 30 +++++-------
 .../amazon/aws/operators/redshift_data.py          |  6 +--
 .../providers/amazon/aws/operators/sagemaker.py    | 44 +++++++++--------
 .../amazon/aws/operators/step_function.py          |  5 +-
 airflow/providers/amazon/aws/sensors/ec2.py        |  6 ++-
 airflow/providers/amazon/aws/sensors/emr.py        | 19 +++++---
 .../amazon/aws/sensors/glue_catalog_partition.py   |  5 +-
 .../amazon/aws/sensors/redshift_cluster.py         |  6 +--
 airflow/providers/amazon/aws/sensors/s3.py         |  3 ++
 airflow/providers/amazon/aws/sensors/sqs.py        |  5 +-
 airflow/providers/amazon/aws/utils/__init__.py     | 10 ++++
 21 files changed, 190 insertions(+), 128 deletions(-)
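
For context, this change replaces the per-operator "if event is None: raise" boilerplate in each deferrable execute_complete method with a shared helper and tightens the event type annotations. The following is an illustrative sketch of the resulting pattern: the helper mirrors the one added in airflow/providers/amazon/aws/utils/__init__.py in the diff below, while ExampleOperator is hypothetical and only stands in for the operators touched by this commit:

    from __future__ import annotations

    import logging
    from typing import Any

    from airflow.exceptions import AirflowException

    log = logging.getLogger(__name__)

    def validate_execute_complete_event(event: dict[str, Any] | None = None) -> dict[str, Any]:
        # Fail fast if the trigger fired without a payload.
        if event is None:
            err_msg = "Trigger error: event is None"
            log.error(err_msg)
            raise AirflowException(err_msg)
        return event

    class ExampleOperator:  # hypothetical, for illustration only
        def execute_complete(self, context, event: dict[str, Any] | None = None) -> str:
            # Narrow event to a dict before indexing into it; this is what the
            # stricter "dict[str, Any] | None" annotations rely on.
            event = validate_execute_complete_event(event)
            if event["status"] != "success":
                raise AirflowException(f"Error reported by trigger: {event}")
            return event["value"]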

diff --git a/airflow/providers/amazon/aws/operators/athena.py b/airflow/providers/amazon/aws/operators/athena.py
index 90b2e7cdba..18fb165fc2 100644
--- a/airflow/providers/amazon/aws/operators/athena.py
+++ b/airflow/providers/amazon/aws/operators/athena.py
@@ -26,6 +26,7 @@ from airflow.providers.amazon.aws.hooks.athena import AthenaHook
 from airflow.providers.amazon.aws.links.athena import AthenaQueryResultsLink
 from airflow.providers.amazon.aws.operators.base_aws import AwsBaseOperator
 from airflow.providers.amazon.aws.triggers.athena import AthenaTrigger
+from airflow.providers.amazon.aws.utils import validate_execute_complete_event
 from airflow.providers.amazon.aws.utils.mixins import aws_template_fields
 
 if TYPE_CHECKING:
@@ -179,7 +180,9 @@ class AthenaOperator(AwsBaseOperator[AthenaHook]):
 
         return self.query_execution_id
 
-    def execute_complete(self, context, event=None):
+    def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> str:
+        event = validate_execute_complete_event(event)
+
         if event["status"] != "success":
             raise AirflowException(f"Error while waiting for operation on cluster to complete: {event}")
         return event["value"]
diff --git a/airflow/providers/amazon/aws/operators/batch.py b/airflow/providers/amazon/aws/operators/batch.py
index 8a124b4027..78ad203718 100644
--- a/airflow/providers/amazon/aws/operators/batch.py
+++ b/airflow/providers/amazon/aws/operators/batch.py
@@ -44,7 +44,7 @@ from airflow.providers.amazon.aws.triggers.batch import (
     BatchCreateComputeEnvironmentTrigger,
     BatchJobTrigger,
 )
-from airflow.providers.amazon.aws.utils import trim_none_values
+from airflow.providers.amazon.aws.utils import trim_none_values, validate_execute_complete_event
 from airflow.providers.amazon.aws.utils.task_log_fetcher import AwsTaskLogFetcher
 
 if TYPE_CHECKING:
@@ -269,10 +269,7 @@ class BatchOperator(BaseOperator):
         return self.job_id
 
     def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> str:
-        if event is None:
-            err_msg = "Trigger error: event is None"
-            self.log.info(err_msg)
-            raise AirflowException(err_msg)
+        event = validate_execute_complete_event(event)
 
         if event["status"] != "success":
             raise AirflowException(f"Error while running job: {event}")
@@ -541,7 +538,9 @@ class BatchCreateComputeEnvironmentOperator(BaseOperator):
         self.log.info("AWS Batch compute environment created successfully")
         return arn
 
-    def execute_complete(self, context, event=None):
+    def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> str:
+        event = validate_execute_complete_event(event)
+
         if event["status"] != "success":
             raise AirflowException(f"Error while waiting for the compute environment to be ready: {event}")
         return event["value"]
diff --git a/airflow/providers/amazon/aws/operators/ecs.py b/airflow/providers/amazon/aws/operators/ecs.py
index f043a076e6..b5874d98d4 100644
--- a/airflow/providers/amazon/aws/operators/ecs.py
+++ b/airflow/providers/amazon/aws/operators/ecs.py
@@ -21,7 +21,7 @@ import re
 import warnings
 from datetime import timedelta
 from functools import cached_property
-from typing import TYPE_CHECKING, Sequence
+from typing import TYPE_CHECKING, Any, Sequence
 
 from airflow.configuration import conf
 from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
@@ -35,6 +35,7 @@ from airflow.providers.amazon.aws.triggers.ecs import (
     ClusterInactiveTrigger,
     TaskDoneTrigger,
 )
+from airflow.providers.amazon.aws.utils import validate_execute_complete_event
 from airflow.providers.amazon.aws.utils.identifiers import generate_uuid
 from airflow.providers.amazon.aws.utils.mixins import aws_template_fields
 from airflow.providers.amazon.aws.utils.task_log_fetcher import AwsTaskLogFetcher
@@ -580,7 +581,9 @@ class EcsRunTaskOperator(EcsBaseOperator):
         else:
             return None
 
-    def execute_complete(self, context, event=None):
+    def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> str | None:
+        event = validate_execute_complete_event(event)
+
         if event["status"] != "success":
             raise AirflowException(f"Error in task execution: {event}")
         self.arn = event["task_arn"]  # restore arn to its updated value, needed for next steps
@@ -596,6 +599,7 @@ class EcsRunTaskOperator(EcsBaseOperator):
             )
             if len(one_log["events"]) > 0:
                 return one_log["events"][0]["message"]
+        return None
 
     def _after_execution(self):
         self._check_success_task()
diff --git a/airflow/providers/amazon/aws/operators/eks.py b/airflow/providers/amazon/aws/operators/eks.py
index 70679a6100..7348d08b6f 100644
--- a/airflow/providers/amazon/aws/operators/eks.py
+++ b/airflow/providers/amazon/aws/operators/eks.py
@@ -39,6 +39,7 @@ from airflow.providers.amazon.aws.triggers.eks import (
     EksDeleteFargateProfileTrigger,
     EksDeleteNodegroupTrigger,
 )
+from airflow.providers.amazon.aws.utils import validate_execute_complete_event
 from airflow.providers.amazon.aws.utils.waiter_with_logging import wait
 from airflow.providers.cncf.kubernetes.utils.pod_manager import OnFinishAction
 
@@ -421,11 +422,10 @@ class EksCreateClusterOperator(BaseOperator):
         raise AirflowException("Error creating cluster")
 
     def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> None:
+        event = validate_execute_complete_event(event)
+
         resource = "fargate profile" if self.compute == "fargate" else self.compute
-        if event is None:
-            self.log.info("Trigger error: event is None")
-            raise AirflowException("Trigger error: event is None")
-        elif event["status"] != "success":
+        if event["status"] != "success":
             raise AirflowException(f"Error creating {resource}: {event}")
 
         self.log.info("%s created successfully", resource)
@@ -547,10 +547,11 @@ class EksCreateNodegroupOperator(BaseOperator):
                 timeout=timedelta(seconds=self.waiter_max_attempts * self.waiter_delay + 60),
             )
 
-    def execute_complete(self, context, event=None):
+    def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> None:
+        event = validate_execute_complete_event(event)
+
         if event["status"] != "success":
             raise AirflowException(f"Error creating nodegroup: {event}")
-        return
 
 
 class EksCreateFargateProfileOperator(BaseOperator):
@@ -656,12 +657,13 @@ class EksCreateFargateProfileOperator(BaseOperator):
                 timeout=timedelta(seconds=(self.waiter_max_attempts * self.waiter_delay + 60)),
             )
 
-    def execute_complete(self, context, event=None):
+    def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> None:
+        event = validate_execute_complete_event(event)
+
         if event["status"] != "success":
             raise AirflowException(f"Error creating Fargate profile: {event}")
-        else:
-            self.log.info("Fargate profile created successfully")
-        return
+
+        self.log.info("Fargate profile created successfully")
 
 
 class EksDeleteClusterOperator(BaseOperator):
@@ -788,10 +790,9 @@ class EksDeleteClusterOperator(BaseOperator):
         self.log.info(SUCCESS_MSG.format(compute=FARGATE_FULL_NAME))
 
     def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> None:
-        if event is None:
-            self.log.error("Trigger error. Event is None")
-            raise AirflowException("Trigger error. Event is None")
-        elif event["status"] == "success":
+        event = validate_execute_complete_event(event)
+
+        if event["status"] == "success":
             self.log.info("Cluster deleted successfully.")
 
 
@@ -879,10 +880,11 @@ class EksDeleteNodegroupOperator(BaseOperator):
                 clusterName=self.cluster_name, nodegroupName=self.nodegroup_name
             )
 
-    def execute_complete(self, context, event=None):
+    def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> None:
+        event = validate_execute_complete_event(event)
+
         if event["status"] != "success":
             raise AirflowException(f"Error deleting nodegroup: {event}")
-        return
 
 
 class EksDeleteFargateProfileOperator(BaseOperator):
@@ -972,12 +974,13 @@ class EksDeleteFargateProfileOperator(BaseOperator):
                 WaiterConfig={"Delay": self.waiter_delay, "MaxAttempts": self.waiter_max_attempts},
             )
 
-    def execute_complete(self, context, event=None):
+    def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> None:
+        event = validate_execute_complete_event(event)
+
         if event["status"] != "success":
             raise AirflowException(f"Error deleting Fargate profile: {event}")
-        else:
-            self.log.info("Fargate profile deleted successfully")
-        return
+
+        self.log.info("Fargate profile deleted successfully")
 
 
 class EksPodOperator(KubernetesPodOperator):
diff --git a/airflow/providers/amazon/aws/operators/emr.py b/airflow/providers/amazon/aws/operators/emr.py
index 68e1c90296..d6bdb2e318 100644
--- a/airflow/providers/amazon/aws/operators/emr.py
+++ b/airflow/providers/amazon/aws/operators/emr.py
@@ -50,6 +50,7 @@ from airflow.providers.amazon.aws.triggers.emr import (
     EmrServerlessStopApplicationTrigger,
     EmrTerminateJobFlowTrigger,
 )
+from airflow.providers.amazon.aws.utils import validate_execute_complete_event
 from airflow.providers.amazon.aws.utils.waiter import waiter
 from airflow.providers.amazon.aws.utils.waiter_with_logging import wait
 from airflow.utils.helpers import exactly_one, prune_dict
@@ -189,11 +190,13 @@ class EmrAddStepsOperator(BaseOperator):
 
         return step_ids
 
-    def execute_complete(self, context, event=None):
+    def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> str:
+        event = validate_execute_complete_event(event)
+
         if event["status"] != "success":
             raise AirflowException(f"Error while running steps: {event}")
-        else:
-            self.log.info("Steps completed successfully")
+
+        self.log.info("Steps completed successfully")
         return event["value"]
 
 
@@ -633,7 +636,9 @@ class EmrContainerOperator(BaseOperator):
                 f"query_execution_id is {self.job_id}. Error: {error_message}"
             )
 
-    def execute_complete(self, context, event=None):
+    def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> str:
+        event = validate_execute_complete_event(event)
+
         if event["status"] != "success":
             raise AirflowException(f"Error while running job: {event}")
 
@@ -820,11 +825,13 @@ class EmrCreateJobFlowOperator(BaseOperator):
             )
         return self._job_flow_id
 
-    def execute_complete(self, context, event=None):
+    def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> str:
+        event = validate_execute_complete_event(event)
+
         if event["status"] != "success":
             raise AirflowException(f"Error creating jobFlow: {event}")
-        else:
-            self.log.info("JobFlow created successfully")
+
+        self.log.info("JobFlow created successfully")
         return event["job_flow_id"]
 
     def on_kill(self) -> None:
@@ -983,12 +990,13 @@ class EmrTerminateJobFlowOperator(BaseOperator):
                 timeout=timedelta(seconds=self.waiter_max_attempts * self.waiter_delay + 60),
             )
 
-    def execute_complete(self, context, event=None):
+    def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> None:
+        event = validate_execute_complete_event(event)
+
         if event["status"] != "success":
             raise AirflowException(f"Error terminating JobFlow: {event}")
-        else:
-            self.log.info("Jobflow terminated successfully.")
-        return
+
+        self.log.info("Jobflow terminated successfully.")
 
 
 class EmrServerlessCreateApplicationOperator(BaseOperator):
@@ -1149,7 +1157,9 @@ class EmrServerlessCreateApplicationOperator(BaseOperator):
         )
 
     def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> None:
-        if event is None or event["status"] != "success":
+        event = validate_execute_complete_event(event)
+
+        if event["status"] != "success":
             raise AirflowException(f"Trigger error: Application failed to start, event is {event}")
 
         self.log.info("Application %s started", event["application_id"])
@@ -1387,10 +1397,9 @@ class EmrServerlessStartJobOperator(BaseOperator):
         return self.job_id
 
     def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> None:
-        if event is None:
-            self.log.error("Trigger error: event is None")
-            raise AirflowException("Trigger error: event is None")
-        elif event["status"] == "success":
+        event = validate_execute_complete_event(event)
+
+        if event["status"] == "success":
             self.log.info("Serverless job completed")
             return event["job_id"]
 
@@ -1686,10 +1695,9 @@ class EmrServerlessStopApplicationOperator(BaseOperator):
             )
 
     def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> None:
-        if event is None:
-            self.log.error("Trigger error: event is None")
-            raise AirflowException("Trigger error: event is None")
-        elif event["status"] == "success":
+        event = validate_execute_complete_event(event)
+
+        if event["status"] == "success":
             self.log.info("EMR serverless application %s stopped 
successfully", self.application_id)
 
 
@@ -1815,8 +1823,7 @@ class EmrServerlessDeleteApplicationOperator(EmrServerlessStopApplicationOperato
         self.log.info("EMR serverless application deleted")
 
     def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> None:
-        if event is None:
-            self.log.error("Trigger error: event is None")
-            raise AirflowException("Trigger error: event is None")
-        elif event["status"] == "success":
+        event = validate_execute_complete_event(event)
+
+        if event["status"] == "success":
             self.log.info("EMR serverless application %s deleted successfully", self.application_id)
diff --git a/airflow/providers/amazon/aws/operators/glue.py b/airflow/providers/amazon/aws/operators/glue.py
index 97bcaba66a..369a3a02a4 100644
--- a/airflow/providers/amazon/aws/operators/glue.py
+++ b/airflow/providers/amazon/aws/operators/glue.py
@@ -20,7 +20,7 @@ from __future__ import annotations
 import os
 import urllib.parse
 from functools import cached_property
-from typing import TYPE_CHECKING, Sequence
+from typing import TYPE_CHECKING, Any, Sequence
 
 from airflow.configuration import conf
 from airflow.exceptions import AirflowException
@@ -29,6 +29,7 @@ from airflow.providers.amazon.aws.hooks.glue import GlueJobHook
 from airflow.providers.amazon.aws.hooks.s3 import S3Hook
 from airflow.providers.amazon.aws.links.glue import GlueJobRunDetailsLink
 from airflow.providers.amazon.aws.triggers.glue import GlueJobCompleteTrigger
+from airflow.providers.amazon.aws.utils import validate_execute_complete_event
 
 if TYPE_CHECKING:
     from airflow.utils.context import Context
@@ -215,7 +216,9 @@ class GlueJobOperator(BaseOperator):
             self.log.info("AWS Glue Job: %s. Run Id: %s", self.job_name, 
self._job_run_id)
         return self._job_run_id
 
-    def execute_complete(self, context, event=None):
+    def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> str:
+        event = validate_execute_complete_event(event)
+
         if event["status"] != "success":
             raise AirflowException(f"Error in glue job: {event}")
         return event["value"]
diff --git a/airflow/providers/amazon/aws/operators/glue_crawler.py b/airflow/providers/amazon/aws/operators/glue_crawler.py
index 660d4948d6..a962e7783e 100644
--- a/airflow/providers/amazon/aws/operators/glue_crawler.py
+++ b/airflow/providers/amazon/aws/operators/glue_crawler.py
@@ -18,11 +18,12 @@
 from __future__ import annotations
 
 from functools import cached_property
-from typing import TYPE_CHECKING, Sequence
+from typing import TYPE_CHECKING, Any, Sequence
 
 from airflow.configuration import conf
 from airflow.exceptions import AirflowException
 from airflow.providers.amazon.aws.triggers.glue_crawler import GlueCrawlerCompleteTrigger
+from airflow.providers.amazon.aws.utils import validate_execute_complete_event
 
 if TYPE_CHECKING:
     from airflow.utils.context import Context
@@ -107,7 +108,9 @@ class GlueCrawlerOperator(BaseOperator):
 
         return crawler_name
 
-    def execute_complete(self, context, event=None):
+    def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> str:
+        event = validate_execute_complete_event(event)
+
         if event["status"] != "success":
             raise AirflowException(f"Error in glue crawl: {event}")
         return self.config["Name"]
diff --git a/airflow/providers/amazon/aws/operators/glue_databrew.py b/airflow/providers/amazon/aws/operators/glue_databrew.py
index 596a507397..3ef0373917 100644
--- a/airflow/providers/amazon/aws/operators/glue_databrew.py
+++ b/airflow/providers/amazon/aws/operators/glue_databrew.py
@@ -18,12 +18,13 @@
 from __future__ import annotations
 
 from functools import cached_property
-from typing import TYPE_CHECKING, Sequence
+from typing import TYPE_CHECKING, Any, Sequence
 
 from airflow.configuration import conf
 from airflow.models import BaseOperator
 from airflow.providers.amazon.aws.hooks.glue_databrew import GlueDataBrewHook
 from airflow.providers.amazon.aws.triggers.glue_databrew import GlueDataBrewJobCompleteTrigger
+from airflow.providers.amazon.aws.utils import validate_execute_complete_event
 
 if TYPE_CHECKING:
     from airflow.utils.context import Context
@@ -101,7 +102,9 @@ class GlueDataBrewStartJobOperator(BaseOperator):
 
         return {"run_id": run_id}
 
-    def execute_complete(self, context: Context, event=None) -> dict[str, str]:
+    def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> dict[str, str]:
+        event = validate_execute_complete_event(event)
+
         run_id = event.get("run_id", "")
         status = event.get("status", "")
 
diff --git a/airflow/providers/amazon/aws/operators/lambda_function.py b/airflow/providers/amazon/aws/operators/lambda_function.py
index 5dec3f116f..8072c5045c 100644
--- a/airflow/providers/amazon/aws/operators/lambda_function.py
+++ b/airflow/providers/amazon/aws/operators/lambda_function.py
@@ -26,6 +26,7 @@ from airflow.exceptions import AirflowException
 from airflow.providers.amazon.aws.hooks.lambda_function import LambdaHook
 from airflow.providers.amazon.aws.operators.base_aws import AwsBaseOperator
 from airflow.providers.amazon.aws.triggers.lambda_function import LambdaCreateFunctionCompleteTrigger
+from airflow.providers.amazon.aws.utils import validate_execute_complete_event
 from airflow.providers.amazon.aws.utils.mixins import aws_template_fields
 
 if TYPE_CHECKING:
@@ -143,6 +144,8 @@ class LambdaCreateFunctionOperator(AwsBaseOperator[LambdaHook]):
         return response.get("FunctionArn")
 
     def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> str:
+        event = validate_execute_complete_event(event)
+
         if not event or event["status"] != "success":
             raise AirflowException(f"Trigger error: event is {event}")
 
diff --git a/airflow/providers/amazon/aws/operators/rds.py b/airflow/providers/amazon/aws/operators/rds.py
index 6d1f2fd7fb..af3c02ecec 100644
--- a/airflow/providers/amazon/aws/operators/rds.py
+++ b/airflow/providers/amazon/aws/operators/rds.py
@@ -32,6 +32,7 @@ from airflow.providers.amazon.aws.triggers.rds import (
     RdsDbDeletedTrigger,
     RdsDbStoppedTrigger,
 )
+from airflow.providers.amazon.aws.utils import validate_execute_complete_event
 from airflow.providers.amazon.aws.utils.rds import RdsDbType
 from airflow.providers.amazon.aws.utils.tags import format_tags
 from airflow.providers.amazon.aws.utils.waiter_with_logging import wait
@@ -637,11 +638,13 @@ class RdsCreateDbInstanceOperator(RdsBaseOperator):
             )
         return json.dumps(create_db_instance, default=str)
 
-    def execute_complete(self, context, event=None) -> str:
+    def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> str:
+        event = validate_execute_complete_event(event)
+
         if event["status"] != "success":
             raise AirflowException(f"DB instance creation failed: {event}")
-        else:
-            return json.dumps(event["response"], default=str)
+
+        return json.dumps(event["response"], default=str)
 
 
 class RdsDeleteDbInstanceOperator(RdsBaseOperator):
@@ -720,11 +723,13 @@ class RdsDeleteDbInstanceOperator(RdsBaseOperator):
             )
         return json.dumps(delete_db_instance, default=str)
 
-    def execute_complete(self, context, event=None) -> str:
+    def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> str:
+        event = validate_execute_complete_event(event)
+
         if event["status"] != "success":
             raise AirflowException(f"DB instance deletion failed: {event}")
-        else:
-            return json.dumps(event["response"], default=str)
+
+        return json.dumps(event["response"], default=str)
 
 
 class RdsStartDbOperator(RdsBaseOperator):
@@ -786,10 +791,12 @@ class RdsStartDbOperator(RdsBaseOperator):
         return json.dumps(start_db_response, default=str)
 
     def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> str:
-        if event is None or event["status"] != "success":
+        event = validate_execute_complete_event(event)
+
+        if event["status"] != "success":
             raise AirflowException(f"Failed to start DB: {event}")
-        else:
-            return json.dumps(event["response"], default=str)
+
+        return json.dumps(event["response"], default=str)
 
     def _start_db(self):
         self.log.info("Starting DB %s '%s'", self.db_type.value, 
self.db_identifier)
@@ -883,10 +890,12 @@ class RdsStopDbOperator(RdsBaseOperator):
         return json.dumps(stop_db_response, default=str)
 
     def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> str:
-        if event is None or event["status"] != "success":
+        event = validate_execute_complete_event(event)
+
+        if event["status"] != "success":
             raise AirflowException(f"Failed to start DB: {event}")
-        else:
-            return json.dumps(event["response"], default=str)
+
+        return json.dumps(event["response"], default=str)
 
     def _stop_db(self):
         self.log.info("Stopping DB %s '%s'", self.db_type.value, 
self.db_identifier)
diff --git a/airflow/providers/amazon/aws/operators/redshift_cluster.py b/airflow/providers/amazon/aws/operators/redshift_cluster.py
index 0fc70a607a..cced10fff3 100644
--- a/airflow/providers/amazon/aws/operators/redshift_cluster.py
+++ b/airflow/providers/amazon/aws/operators/redshift_cluster.py
@@ -31,6 +31,7 @@ from airflow.providers.amazon.aws.triggers.redshift_cluster import (
     RedshiftPauseClusterTrigger,
     RedshiftResumeClusterTrigger,
 )
+from airflow.providers.amazon.aws.utils import validate_execute_complete_event
 
 if TYPE_CHECKING:
     from airflow.utils.context import Context
@@ -314,10 +315,11 @@ class RedshiftCreateClusterOperator(BaseOperator):
         self.log.info("Created Redshift cluster %s", self.cluster_identifier)
         self.log.info(cluster)
 
-    def execute_complete(self, context, event=None):
+    def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> None:
+        event = validate_execute_complete_event(event)
+
         if event["status"] != "success":
             raise AirflowException(f"Error creating cluster: {event}")
-        return
 
 
 class RedshiftCreateClusterSnapshotOperator(BaseOperator):
@@ -409,12 +411,13 @@ class RedshiftCreateClusterSnapshotOperator(BaseOperator):
                 },
             )
 
-    def execute_complete(self, context, event=None):
+    def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> None:
+        event = validate_execute_complete_event(event)
+
         if event["status"] != "success":
             raise AirflowException(f"Error creating snapshot: {event}")
-        else:
-            self.log.info("Cluster snapshot created.")
-        return
+
+        self.log.info("Cluster snapshot created.")
 
 
 class RedshiftDeleteClusterSnapshotOperator(BaseOperator):
@@ -569,10 +572,7 @@ class RedshiftResumeClusterOperator(BaseOperator):
             )
 
     def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> None:
-        if event is None:
-            err_msg = "Trigger error: event is None"
-            self.log.info(err_msg)
-            raise AirflowException(err_msg)
+        event = validate_execute_complete_event(event)
 
         if event["status"] != "success":
             raise AirflowException(f"Error resuming cluster: {event}")
@@ -659,10 +659,7 @@ class RedshiftPauseClusterOperator(BaseOperator):
                 )
 
     def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> None:
-        if event is None:
-            err_msg = "Trigger error: event is None"
-            self.log.info(err_msg)
-            raise AirflowException(err_msg)
+        event = validate_execute_complete_event(event)
 
         if event["status"] != "success":
             raise AirflowException(f"Error pausing cluster: {event}")
@@ -767,10 +764,7 @@ class RedshiftDeleteClusterOperator(BaseOperator):
             )
 
     def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> None:
-        if event is None:
-            err_msg = "Trigger error: event is None"
-            self.log.info(err_msg)
-            raise AirflowException(err_msg)
+        event = validate_execute_complete_event(event)
 
         if event["status"] != "success":
             raise AirflowException(f"Error deleting cluster: {event}")
diff --git a/airflow/providers/amazon/aws/operators/redshift_data.py b/airflow/providers/amazon/aws/operators/redshift_data.py
index 71ee82069e..54e3c2c7ae 100644
--- a/airflow/providers/amazon/aws/operators/redshift_data.py
+++ b/airflow/providers/amazon/aws/operators/redshift_data.py
@@ -24,6 +24,7 @@ from airflow.exceptions import AirflowException
 from airflow.providers.amazon.aws.hooks.redshift_data import RedshiftDataHook
 from airflow.providers.amazon.aws.operators.base_aws import AwsBaseOperator
 from airflow.providers.amazon.aws.triggers.redshift_data import RedshiftDataTrigger
+from airflow.providers.amazon.aws.utils import validate_execute_complete_event
 from airflow.providers.amazon.aws.utils.mixins import aws_template_fields
 
 if TYPE_CHECKING:
@@ -170,10 +171,7 @@ class RedshiftDataOperator(AwsBaseOperator[RedshiftDataHook]):
     def execute_complete(
         self, context: Context, event: dict[str, Any] | None = None
     ) -> GetStatementResultResponseTypeDef | str:
-        if event is None:
-            err_msg = "Trigger error: event is None"
-            self.log.info(err_msg)
-            raise AirflowException(err_msg)
+        event = validate_execute_complete_event(event)
 
         if event["status"] == "error":
             msg = f"context: {context}, error message: {event['message']}"
diff --git a/airflow/providers/amazon/aws/operators/sagemaker.py b/airflow/providers/amazon/aws/operators/sagemaker.py
index e8cfa26b29..69eeb1e06d 100644
--- a/airflow/providers/amazon/aws/operators/sagemaker.py
+++ b/airflow/providers/amazon/aws/operators/sagemaker.py
@@ -39,7 +39,7 @@ from airflow.providers.amazon.aws.triggers.sagemaker import (
     SageMakerTrainingPrintLogTrigger,
     SageMakerTrigger,
 )
-from airflow.providers.amazon.aws.utils import trim_none_values
+from airflow.providers.amazon.aws.utils import trim_none_values, validate_execute_complete_event
 from airflow.providers.amazon.aws.utils.sagemaker import ApprovalStatus
 from airflow.providers.amazon.aws.utils.tags import format_tags
 from airflow.utils.helpers import prune_dict
@@ -315,11 +315,13 @@ class SageMakerProcessingOperator(SageMakerBaseOperator):
         self.serialized_job = serialize(self.hook.describe_processing_job(self.config["ProcessingJobName"]))
         return {"Processing": self.serialized_job}
 
-    def execute_complete(self, context, event=None):
+    def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> dict[str, dict]:
+        event = validate_execute_complete_event(event)
+
         if event["status"] != "success":
             raise AirflowException(f"Error while running job: {event}")
-        else:
-            self.log.info(event["message"])
+
+        self.log.info(event["message"])
         self.serialized_job = serialize(self.hook.describe_processing_job(self.config["ProcessingJobName"]))
         self.log.info("%s completed successfully.", self.task_id)
         return {"Processing": self.serialized_job}
@@ -566,7 +568,9 @@ class SageMakerEndpointOperator(SageMakerBaseOperator):
             "Endpoint": serialize(self.hook.describe_endpoint(endpoint_info["EndpointName"])),
         }
 
-    def execute_complete(self, context, event=None):
+    def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> dict[str, dict]:
+        event = validate_execute_complete_event(event)
+
         if event["status"] != "success":
             raise AirflowException(f"Error while running job: {event}")
         endpoint_info = self.config.get("Endpoint", self.config)
@@ -749,10 +753,7 @@ class SageMakerTransformOperator(SageMakerBaseOperator):
         return self.serialize_result()
 
     def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> dict[str, dict]:
-        if event is None:
-            err_msg = "Trigger error: event is None"
-            self.log.error(err_msg)
-            raise AirflowException(err_msg)
+        event = validate_execute_complete_event(event)
 
         self.log.info(event["message"])
         return self.serialize_result()
@@ -924,7 +925,9 @@ class SageMakerTuningOperator(SageMakerBaseOperator):
 
         return {"Tuning": serialize(description)}
 
-    def execute_complete(self, context, event=None):
+    def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> dict[str, dict]:
+        event = validate_execute_complete_event(event)
+
         if event["status"] != "success":
             raise AirflowException(f"Error while running job: {event}")
         return {
@@ -1154,10 +1157,7 @@ class SageMakerTrainingOperator(SageMakerBaseOperator):
         return self.serialize_result()
 
     def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> dict[str, dict]:
-        if event is None:
-            err_msg = "Trigger error: event is None"
-            self.log.error(err_msg)
-            raise AirflowException(err_msg)
+        event = validate_execute_complete_event(event)
 
         if event["status"] != "success":
             raise AirflowException(f"Error while running job: {event}")
@@ -1296,7 +1296,9 @@ class SageMakerStartPipelineOperator(SageMakerBaseOperator):
         return arn
 
     def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> str:
-        if event is None or event["status"] != "success":
+        event = validate_execute_complete_event(event)
+
+        if event["status"] != "success":
             raise AirflowException(f"Failure during pipeline execution: {event}")
         return event["value"]
 
@@ -1389,12 +1391,14 @@ class SageMakerStopPipelineOperator(SageMakerBaseOperator):
         return status
 
     def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> str:
-        if event is None or event["status"] != "success":
+        event = validate_execute_complete_event(event)
+
+        if event["status"] != "success":
             raise AirflowException(f"Failure during pipeline execution: {event}")
-        else:
-            # theoretically we should do a `describe` call to know this,
-            # but if we reach this point, this is the only possible status
-            return "Stopped"
+
+        # theoretically we should do a `describe` call to know this,
+        # but if we reach this point, this is the only possible status
+        return "Stopped"
 
 
 class SageMakerRegisterModelVersionOperator(SageMakerBaseOperator):
diff --git a/airflow/providers/amazon/aws/operators/step_function.py b/airflow/providers/amazon/aws/operators/step_function.py
index 067d7e4529..bffb348dfc 100644
--- a/airflow/providers/amazon/aws/operators/step_function.py
+++ b/airflow/providers/amazon/aws/operators/step_function.py
@@ -29,6 +29,7 @@ from airflow.providers.amazon.aws.links.step_function import (
 )
 from airflow.providers.amazon.aws.operators.base_aws import AwsBaseOperator
 from airflow.providers.amazon.aws.triggers.step_function import StepFunctionsExecutionCompleteTrigger
+from airflow.providers.amazon.aws.utils import validate_execute_complete_event
 from airflow.providers.amazon.aws.utils.mixins import aws_template_fields
 
 if TYPE_CHECKING:
@@ -129,7 +130,9 @@ class StepFunctionStartExecutionOperator(AwsBaseOperator[StepFunctionHook]):
         return execution_arn
 
     def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> None:
-        if event is None or event["status"] != "success":
+        event = validate_execute_complete_event(event)
+
+        if event["status"] != "success":
             raise AirflowException(f"Trigger error: event is {event}")
 
         self.log.info("State Machine execution completed successfully")
diff --git a/airflow/providers/amazon/aws/sensors/ec2.py b/airflow/providers/amazon/aws/sensors/ec2.py
index cdebd1b44a..58593cfb30 100644
--- a/airflow/providers/amazon/aws/sensors/ec2.py
+++ b/airflow/providers/amazon/aws/sensors/ec2.py
@@ -24,6 +24,7 @@ from airflow.configuration import conf
 from airflow.exceptions import AirflowException, AirflowSkipException
 from airflow.providers.amazon.aws.hooks.ec2 import EC2Hook
 from airflow.providers.amazon.aws.triggers.ec2 import EC2StateSensorTrigger
+from airflow.providers.amazon.aws.utils import validate_execute_complete_event
 from airflow.sensors.base import BaseSensorOperator
 
 if TYPE_CHECKING:
@@ -92,11 +93,12 @@ class EC2InstanceStateSensor(BaseSensorOperator):
         self.log.info("instance state: %s", instance_state)
         return instance_state == self.target_state
 
-    def execute_complete(self, context, event=None):
+    def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> None:
+        event = validate_execute_complete_event(event)
+
         if event["status"] != "success":
             # TODO: remove this if check when min_airflow_version is set to higher than 2.7.1
             message = f"Error: {event}"
             if self.soft_fail:
                 raise AirflowSkipException(message)
             raise AirflowException(message)
-        return
diff --git a/airflow/providers/amazon/aws/sensors/emr.py b/airflow/providers/amazon/aws/sensors/emr.py
index 76bce335d2..93ad355764 100644
--- a/airflow/providers/amazon/aws/sensors/emr.py
+++ b/airflow/providers/amazon/aws/sensors/emr.py
@@ -32,6 +32,7 @@ from airflow.providers.amazon.aws.triggers.emr import (
     EmrStepSensorTrigger,
     EmrTerminateJobFlowTrigger,
 )
+from airflow.providers.amazon.aws.utils import validate_execute_complete_event
 from airflow.sensors.base import BaseSensorOperator
 
 if TYPE_CHECKING:
@@ -335,15 +336,17 @@ class EmrContainerSensor(BaseSensorOperator):
                 method_name="execute_complete",
             )
 
-    def execute_complete(self, context, event=None):
+    def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> None:
+        event = validate_execute_complete_event(event)
+
         if event["status"] != "success":
             # TODO: remove this if check when min_airflow_version is set to higher than 2.7.1
             message = f"Error while running job: {event}"
             if self.soft_fail:
                 raise AirflowSkipException(message)
             raise AirflowException(message)
-        else:
-            self.log.info("Job completed.")
+
+        self.log.info("Job completed.")
 
 
 class EmrNotebookExecutionSensor(EmrBaseSensor):
@@ -526,7 +529,9 @@ class EmrJobFlowSensor(EmrBaseSensor):
                 method_name="execute_complete",
             )
 
-    def execute_complete(self, context: Context, event=None) -> None:
+    def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> None:
+        event = validate_execute_complete_event(event)
+
         if event["status"] != "success":
             # TODO: remove this if check when min_airflow_version is set to higher than 2.7.1
             message = f"Error while running job: {event}"
@@ -657,7 +662,9 @@ class EmrStepSensor(EmrBaseSensor):
                 method_name="execute_complete",
             )
 
-    def execute_complete(self, context, event=None):
+    def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> None:
+        event = validate_execute_complete_event(event)
+
         if event["status"] != "success":
             # TODO: remove this if check when min_airflow_version is set to higher than 2.7.1
             message = f"Error while running job: {event}"
@@ -665,4 +672,4 @@ class EmrStepSensor(EmrBaseSensor):
                 raise AirflowSkipException(message)
             raise AirflowException(message)
 
-        self.log.info("Job completed.")
+        self.log.info("Job %s completed.", self.job_flow_id)
diff --git a/airflow/providers/amazon/aws/sensors/glue_catalog_partition.py b/airflow/providers/amazon/aws/sensors/glue_catalog_partition.py
index 5245249c51..bc6a6c2560 100644
--- a/airflow/providers/amazon/aws/sensors/glue_catalog_partition.py
+++ b/airflow/providers/amazon/aws/sensors/glue_catalog_partition.py
@@ -27,6 +27,7 @@ from airflow.configuration import conf
 from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning, AirflowSkipException
 from airflow.providers.amazon.aws.hooks.glue_catalog import GlueCatalogHook
 from airflow.providers.amazon.aws.triggers.glue import GlueCatalogPartitionTrigger
+from airflow.providers.amazon.aws.utils import validate_execute_complete_event
 from airflow.sensors.base import BaseSensorOperator
 
 if TYPE_CHECKING:
@@ -111,7 +112,9 @@ class GlueCatalogPartitionSensor(BaseSensorOperator):
         return self.hook.check_for_partition(self.database_name, self.table_name, self.expression)
 
     def execute_complete(self, context: Context, event: dict | None = None) -> None:
-        if event is None or event["status"] != "success":
+        event = validate_execute_complete_event(event)
+
+        if event["status"] != "success":
             # TODO: remove this if block when min_airflow_version is set to higher than 2.7.1
             message = f"Trigger error: event is {event}"
             if self.soft_fail:
diff --git a/airflow/providers/amazon/aws/sensors/redshift_cluster.py b/airflow/providers/amazon/aws/sensors/redshift_cluster.py
index cd63bb5e1f..4d133b489e 100644
--- a/airflow/providers/amazon/aws/sensors/redshift_cluster.py
+++ b/airflow/providers/amazon/aws/sensors/redshift_cluster.py
@@ -26,6 +26,7 @@ from airflow.configuration import conf
 from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning, AirflowSkipException
 from airflow.providers.amazon.aws.hooks.redshift_cluster import RedshiftHook
 from airflow.providers.amazon.aws.triggers.redshift_cluster import RedshiftClusterTrigger
+from airflow.providers.amazon.aws.utils import validate_execute_complete_event
 from airflow.sensors.base import BaseSensorOperator
 
 if TYPE_CHECKING:
@@ -88,10 +89,7 @@ class RedshiftClusterSensor(BaseSensorOperator):
             )
 
     def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> None:
-        if event is None:
-            err_msg = "Trigger error: event is None"
-            self.log.error(err_msg)
-            raise AirflowException(err_msg)
+        event = validate_execute_complete_event(event)
 
         status = event["status"]
         if status == "error":
diff --git a/airflow/providers/amazon/aws/sensors/s3.py b/airflow/providers/amazon/aws/sensors/s3.py
index 6d55a724af..2699120cbb 100644
--- a/airflow/providers/amazon/aws/sensors/s3.py
+++ b/airflow/providers/amazon/aws/sensors/s3.py
@@ -27,6 +27,7 @@ from typing import TYPE_CHECKING, Any, Callable, Sequence, cast
 from deprecated import deprecated
 
 from airflow.configuration import conf
+from airflow.providers.amazon.aws.utils import validate_execute_complete_event
 
 if TYPE_CHECKING:
     from airflow.utils.context import Context
@@ -371,6 +372,8 @@ class S3KeysUnchangedSensor(BaseSensorOperator):
 
         Relies on trigger to throw an exception, otherwise it assumes execution was successful.
         """
+        event = validate_execute_complete_event(event)
+
         if event and event["status"] == "error":
             # TODO: remove this if block when min_airflow_version is set to higher than 2.7.1
             if self.soft_fail:
diff --git a/airflow/providers/amazon/aws/sensors/sqs.py b/airflow/providers/amazon/aws/sensors/sqs.py
index 7d9065c080..cbe12b2bd4 100644
--- a/airflow/providers/amazon/aws/sensors/sqs.py
+++ b/airflow/providers/amazon/aws/sensors/sqs.py
@@ -28,6 +28,7 @@ from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarni
 from airflow.providers.amazon.aws.hooks.sqs import SqsHook
 from airflow.providers.amazon.aws.sensors.base_aws import AwsBaseSensor
 from airflow.providers.amazon.aws.triggers.sqs import SqsSensorTrigger
+from airflow.providers.amazon.aws.utils import validate_execute_complete_event
 from airflow.providers.amazon.aws.utils.mixins import aws_template_fields
 from airflow.providers.amazon.aws.utils.sqs import MessageFilteringType, process_response
 
@@ -155,7 +156,9 @@ class SqsSensor(AwsBaseSensor[SqsHook]):
             super().execute(context=context)
 
     def execute_complete(self, context: Context, event: dict | None = None) -> None:
-        if event is None or event["status"] != "success":
+        event = validate_execute_complete_event(event)
+
+        if event["status"] != "success":
             # TODO: remove this if block when min_airflow_version is set to higher than 2.7.1
             message = f"Trigger error: event is {event}"
             if self.soft_fail:
diff --git a/airflow/providers/amazon/aws/utils/__init__.py b/airflow/providers/amazon/aws/utils/__init__.py
index 8495e1587a..2a96c3e447 100644
--- a/airflow/providers/amazon/aws/utils/__init__.py
+++ b/airflow/providers/amazon/aws/utils/__init__.py
@@ -20,7 +20,9 @@ import logging
 import re
 from datetime import datetime, timezone
 from enum import Enum
+from typing import Any
 
+from airflow.exceptions import AirflowException
 from airflow.utils.helpers import prune_dict
 from airflow.version import version
 
@@ -72,6 +74,14 @@ def get_airflow_version() -> tuple[int, ...]:
     return tuple(int(x) for x in match.groups())
 
 
+def validate_execute_complete_event(event: dict[str, Any] | None = None) -> dict[str, Any]:
+    if event is None:
+        err_msg = "Trigger error: event is None"
+        log.error(err_msg)
+        raise AirflowException(err_msg)
+    return event
+
+
 class _StringCompareEnum(Enum):
     """
     An Enum class which can be compared with regular `str` and subclasses.

