This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
     new c75a105935 add type annotations to Amazon provider "execute_complete" methods (#36330)
c75a105935 is described below

commit c75a1059355df6e1edc42f77947150b4a4c5d51a
Author: Wei Lee <weilee...@gmail.com>
AuthorDate: Fri Feb 16 18:49:40 2024 +0800

    add type annotations to Amazon provider "execute_complete" methods (#36330)

    * style(providers/amazon): improve execute_complete type annotation

    * refactor(providers/amazon): change check_execute_complete_event to
      validate_execute_complete_event

    * style(providers/amazon): fix ecs type annotation
---
 airflow/providers/amazon/aws/operators/athena.py   |  5 +-
 airflow/providers/amazon/aws/operators/batch.py    | 11 ++---
 airflow/providers/amazon/aws/operators/ecs.py      |  8 +++-
 airflow/providers/amazon/aws/operators/eks.py      | 43 +++++++++--------
 airflow/providers/amazon/aws/operators/emr.py      | 55 ++++++++++++----------
 airflow/providers/amazon/aws/operators/glue.py     |  7 ++-
 .../providers/amazon/aws/operators/glue_crawler.py |  7 ++-
 .../amazon/aws/operators/glue_databrew.py          |  7 ++-
 .../amazon/aws/operators/lambda_function.py        |  3 ++
 airflow/providers/amazon/aws/operators/rds.py      | 33 ++++++++-----
 .../amazon/aws/operators/redshift_cluster.py       | 30 +++++-------
 .../amazon/aws/operators/redshift_data.py          |  6 +--
 .../providers/amazon/aws/operators/sagemaker.py    | 44 +++++++++--------
 .../amazon/aws/operators/step_function.py          |  5 +-
 airflow/providers/amazon/aws/sensors/ec2.py        |  6 ++-
 airflow/providers/amazon/aws/sensors/emr.py        | 19 +++++---
 .../amazon/aws/sensors/glue_catalog_partition.py   |  5 +-
 .../amazon/aws/sensors/redshift_cluster.py         |  6 +--
 airflow/providers/amazon/aws/sensors/s3.py         |  3 ++
 airflow/providers/amazon/aws/sensors/sqs.py        |  5 +-
 airflow/providers/amazon/aws/utils/__init__.py     | 10 ++++
 21 files changed, 190 insertions(+), 128 deletions(-)

diff --git a/airflow/providers/amazon/aws/operators/athena.py b/airflow/providers/amazon/aws/operators/athena.py
index 90b2e7cdba..18fb165fc2 100644
--- a/airflow/providers/amazon/aws/operators/athena.py
+++ b/airflow/providers/amazon/aws/operators/athena.py
@@ -26,6 +26,7 @@ from airflow.providers.amazon.aws.hooks.athena import AthenaHook
 from airflow.providers.amazon.aws.links.athena import AthenaQueryResultsLink
 from airflow.providers.amazon.aws.operators.base_aws import AwsBaseOperator
 from airflow.providers.amazon.aws.triggers.athena import AthenaTrigger
+from airflow.providers.amazon.aws.utils import validate_execute_complete_event
 from airflow.providers.amazon.aws.utils.mixins import aws_template_fields
 
 if TYPE_CHECKING:
@@ -179,7 +180,9 @@ class AthenaOperator(AwsBaseOperator[AthenaHook]):
 
         return self.query_execution_id
 
-    def execute_complete(self, context, event=None):
+    def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> str:
+        event = validate_execute_complete_event(event)
+
         if event["status"] != "success":
             raise AirflowException(f"Error while waiting for operation on cluster to complete: {event}")
         return event["value"]
diff --git a/airflow/providers/amazon/aws/operators/batch.py b/airflow/providers/amazon/aws/operators/batch.py
index 8a124b4027..78ad203718 100644
--- a/airflow/providers/amazon/aws/operators/batch.py
+++ b/airflow/providers/amazon/aws/operators/batch.py
@@ -44,7 +44,7 @@ from airflow.providers.amazon.aws.triggers.batch import (
     BatchCreateComputeEnvironmentTrigger,
     BatchJobTrigger,
 )
-from airflow.providers.amazon.aws.utils import trim_none_values
+from airflow.providers.amazon.aws.utils import trim_none_values, validate_execute_complete_event
 from airflow.providers.amazon.aws.utils.task_log_fetcher import AwsTaskLogFetcher
 
 if TYPE_CHECKING:
@@ -269,10 +269,7 @@ class BatchOperator(BaseOperator):
         return self.job_id
 
     def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> str:
-        if event is None:
-            err_msg = "Trigger error: event is None"
-            self.log.info(err_msg)
-            raise AirflowException(err_msg)
+        event = validate_execute_complete_event(event)
 
         if event["status"] != "success":
             raise AirflowException(f"Error while running job: {event}")
@@ -541,7 +538,9 @@ class BatchCreateComputeEnvironmentOperator(BaseOperator):
         self.log.info("AWS Batch compute environment created successfully")
         return arn
 
-    def execute_complete(self, context, event=None):
+    def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> str:
+        event = validate_execute_complete_event(event)
+
         if event["status"] != "success":
             raise AirflowException(f"Error while waiting for the compute environment to be ready: {event}")
         return event["value"]
diff --git a/airflow/providers/amazon/aws/operators/ecs.py b/airflow/providers/amazon/aws/operators/ecs.py
index f043a076e6..b5874d98d4 100644
--- a/airflow/providers/amazon/aws/operators/ecs.py
+++ b/airflow/providers/amazon/aws/operators/ecs.py
@@ -21,7 +21,7 @@ import re
 import warnings
 from datetime import timedelta
 from functools import cached_property
-from typing import TYPE_CHECKING, Sequence
+from typing import TYPE_CHECKING, Any, Sequence
 
 from airflow.configuration import conf
 from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
@@ -35,6 +35,7 @@ from airflow.providers.amazon.aws.triggers.ecs import (
     ClusterInactiveTrigger,
     TaskDoneTrigger,
 )
+from airflow.providers.amazon.aws.utils import validate_execute_complete_event
 from airflow.providers.amazon.aws.utils.identifiers import generate_uuid
 from airflow.providers.amazon.aws.utils.mixins import aws_template_fields
 from airflow.providers.amazon.aws.utils.task_log_fetcher import AwsTaskLogFetcher
@@ -580,7 +581,9 @@ class EcsRunTaskOperator(EcsBaseOperator):
         else:
             return None
 
-    def execute_complete(self, context, event=None):
+    def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> str | None:
+        event = validate_execute_complete_event(event)
+
         if event["status"] != "success":
             raise AirflowException(f"Error in task execution: {event}")
         self.arn = event["task_arn"]  # restore arn to its updated value, needed for next steps
@@ -596,6 +599,7 @@ class EcsRunTaskOperator(EcsBaseOperator):
             )
             if len(one_log["events"]) > 0:
                 return one_log["events"][0]["message"]
+        return None
 
     def _after_execution(self):
         self._check_success_task()
diff --git a/airflow/providers/amazon/aws/operators/eks.py b/airflow/providers/amazon/aws/operators/eks.py
index 70679a6100..7348d08b6f 100644
--- a/airflow/providers/amazon/aws/operators/eks.py
+++ b/airflow/providers/amazon/aws/operators/eks.py
@@ -39,6 +39,7 @@ from airflow.providers.amazon.aws.triggers.eks import (
     EksDeleteFargateProfileTrigger,
     EksDeleteNodegroupTrigger,
 )
+from airflow.providers.amazon.aws.utils import validate_execute_complete_event
 from airflow.providers.amazon.aws.utils.waiter_with_logging import wait
 from airflow.providers.cncf.kubernetes.utils.pod_manager import OnFinishAction
 
@@ -421,11 +422,10 @@ class EksCreateClusterOperator(BaseOperator):
             raise AirflowException("Error creating cluster")
 
     def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> None:
+        event = validate_execute_complete_event(event)
+
         resource = "fargate profile" if self.compute == "fargate" else self.compute
-        if event is None:
-            self.log.info("Trigger error: event is None")
-            raise AirflowException("Trigger error: event is None")
-        elif event["status"] != "success":
+        if event["status"] != "success":
             raise AirflowException(f"Error creating {resource}: {event}")
 
         self.log.info("%s created successfully", resource)
@@ -547,10 +547,11 @@ class EksCreateNodegroupOperator(BaseOperator):
             timeout=timedelta(seconds=self.waiter_max_attempts * self.waiter_delay + 60),
         )
 
-    def execute_complete(self, context, event=None):
+    def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> None:
+        event = validate_execute_complete_event(event)
+
         if event["status"] != "success":
             raise AirflowException(f"Error creating nodegroup: {event}")
-        return
 
 
 class EksCreateFargateProfileOperator(BaseOperator):
@@ -656,12 +657,13 @@ class EksCreateFargateProfileOperator(BaseOperator):
             timeout=timedelta(seconds=(self.waiter_max_attempts * self.waiter_delay + 60)),
         )
 
-    def execute_complete(self, context, event=None):
+    def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> None:
+        event = validate_execute_complete_event(event)
+
         if event["status"] != "success":
             raise AirflowException(f"Error creating Fargate profile: {event}")
-        else:
-            self.log.info("Fargate profile created successfully")
-        return
+
+        self.log.info("Fargate profile created successfully")
 
 
 class EksDeleteClusterOperator(BaseOperator):
@@ -788,10 +790,9 @@ class EksDeleteClusterOperator(BaseOperator):
 
         self.log.info(SUCCESS_MSG.format(compute=FARGATE_FULL_NAME))
 
     def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> None:
-        if event is None:
-            self.log.error("Trigger error. Event is None")
-            raise AirflowException("Trigger error. Event is None")
-        elif event["status"] == "success":
+        event = validate_execute_complete_event(event)
+
+        if event["status"] == "success":
             self.log.info("Cluster deleted successfully.")
 
@@ -879,10 +880,11 @@ class EksDeleteNodegroupOperator(BaseOperator):
             clusterName=self.cluster_name, nodegroupName=self.nodegroup_name
         )
 
-    def execute_complete(self, context, event=None):
+    def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> None:
+        event = validate_execute_complete_event(event)
+
         if event["status"] != "success":
             raise AirflowException(f"Error deleting nodegroup: {event}")
-        return
 
 
 class EksDeleteFargateProfileOperator(BaseOperator):
@@ -972,12 +974,13 @@ class EksDeleteFargateProfileOperator(BaseOperator):
             WaiterConfig={"Delay": self.waiter_delay, "MaxAttempts": self.waiter_max_attempts},
         )
 
-    def execute_complete(self, context, event=None):
+    def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> None:
+        event = validate_execute_complete_event(event)
+
         if event["status"] != "success":
             raise AirflowException(f"Error deleting Fargate profile: {event}")
-        else:
-            self.log.info("Fargate profile deleted successfully")
-        return
+
+        self.log.info("Fargate profile deleted successfully")
 
 
 class EksPodOperator(KubernetesPodOperator):
diff --git a/airflow/providers/amazon/aws/operators/emr.py b/airflow/providers/amazon/aws/operators/emr.py
index 68e1c90296..d6bdb2e318 100644
--- a/airflow/providers/amazon/aws/operators/emr.py
+++ b/airflow/providers/amazon/aws/operators/emr.py
@@ -50,6 +50,7 @@ from airflow.providers.amazon.aws.triggers.emr import (
     EmrServerlessStopApplicationTrigger,
     EmrTerminateJobFlowTrigger,
 )
+from airflow.providers.amazon.aws.utils import validate_execute_complete_event
 from airflow.providers.amazon.aws.utils.waiter import waiter
 from airflow.providers.amazon.aws.utils.waiter_with_logging import wait
 from airflow.utils.helpers import exactly_one, prune_dict
@@ -189,11 +190,13 @@ class EmrAddStepsOperator(BaseOperator):
 
         return step_ids
 
-    def execute_complete(self, context, event=None):
+    def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> str:
+        event = validate_execute_complete_event(event)
+
         if event["status"] != "success":
             raise AirflowException(f"Error while running steps: {event}")
-        else:
-            self.log.info("Steps completed successfully")
+
+        self.log.info("Steps completed successfully")
         return event["value"]
 
@@ -633,7 +636,9 @@ class EmrContainerOperator(BaseOperator):
                 f"query_execution_id is {self.job_id}. Error: {error_message}"
             )
 
-    def execute_complete(self, context, event=None):
+    def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> str:
+        event = validate_execute_complete_event(event)
+
         if event["status"] != "success":
             raise AirflowException(f"Error while running job: {event}")
 
@@ -820,11 +825,13 @@ class EmrCreateJobFlowOperator(BaseOperator):
             )
         return self._job_flow_id
 
-    def execute_complete(self, context, event=None):
+    def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> str:
+        event = validate_execute_complete_event(event)
+
         if event["status"] != "success":
             raise AirflowException(f"Error creating jobFlow: {event}")
-        else:
-            self.log.info("JobFlow created successfully")
+
+        self.log.info("JobFlow created successfully")
         return event["job_flow_id"]
 
     def on_kill(self) -> None:
@@ -983,12 +990,13 @@ class EmrTerminateJobFlowOperator(BaseOperator):
             timeout=timedelta(seconds=self.waiter_max_attempts * self.waiter_delay + 60),
         )
 
-    def execute_complete(self, context, event=None):
+    def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> None:
+        event = validate_execute_complete_event(event)
+
         if event["status"] != "success":
             raise AirflowException(f"Error terminating JobFlow: {event}")
-        else:
-            self.log.info("Jobflow terminated successfully.")
-        return
+
+        self.log.info("Jobflow terminated successfully.")
 
 
 class EmrServerlessCreateApplicationOperator(BaseOperator):
@@ -1149,7 +1157,9 @@ class EmrServerlessCreateApplicationOperator(BaseOperator):
         )
 
     def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> None:
-        if event is None or event["status"] != "success":
+        event = validate_execute_complete_event(event)
+
+        if event["status"] != "success":
             raise AirflowException(f"Trigger error: Application failed to start, event is {event}")
 
         self.log.info("Application %s started", event["application_id"])
@@ -1387,10 +1397,9 @@ class EmrServerlessStartJobOperator(BaseOperator):
 
         return self.job_id
 
     def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> None:
-        if event is None:
-            self.log.error("Trigger error: event is None")
-            raise AirflowException("Trigger error: event is None")
-        elif event["status"] == "success":
+        event = validate_execute_complete_event(event)
+
+        if event["status"] == "success":
             self.log.info("Serverless job completed")
             return event["job_id"]
@@ -1686,10 +1695,9 @@ class EmrServerlessStopApplicationOperator(BaseOperator):
         )
 
     def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> None:
-        if event is None:
-            self.log.error("Trigger error: event is None")
-            raise AirflowException("Trigger error: event is None")
-        elif event["status"] == "success":
+        event = validate_execute_complete_event(event)
+
+        if event["status"] == "success":
             self.log.info("EMR serverless application %s stopped successfully", self.application_id)
 
@@ -1815,8 +1823,7 @@ class EmrServerlessDeleteApplicationOperator(EmrServerlessStopApplicationOperato
         self.log.info("EMR serverless application deleted")
 
     def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> None:
-        if event is None:
-            self.log.error("Trigger error: event is None")
-            raise AirflowException("Trigger error: event is None")
-        elif event["status"] == "success":
+        event = validate_execute_complete_event(event)
+
+        if event["status"] == "success":
             self.log.info("EMR serverless application %s deleted successfully", self.application_id)
diff --git a/airflow/providers/amazon/aws/operators/glue.py b/airflow/providers/amazon/aws/operators/glue.py
index 97bcaba66a..369a3a02a4 100644
--- a/airflow/providers/amazon/aws/operators/glue.py
+++ b/airflow/providers/amazon/aws/operators/glue.py
@@ -20,7 +20,7 @@ from __future__ import annotations
 import os
 import urllib.parse
 from functools import cached_property
-from typing import TYPE_CHECKING, Sequence
+from typing import TYPE_CHECKING, Any, Sequence
 
 from airflow.configuration import conf
 from airflow.exceptions import AirflowException
@@ -29,6 +29,7 @@ from airflow.providers.amazon.aws.hooks.glue import GlueJobHook
 from airflow.providers.amazon.aws.hooks.s3 import S3Hook
 from airflow.providers.amazon.aws.links.glue import GlueJobRunDetailsLink
 from airflow.providers.amazon.aws.triggers.glue import GlueJobCompleteTrigger
+from airflow.providers.amazon.aws.utils import validate_execute_complete_event
 
 if TYPE_CHECKING:
     from airflow.utils.context import Context
@@ -215,7 +216,9 @@ class GlueJobOperator(BaseOperator):
         self.log.info("AWS Glue Job: %s. Run Id: %s", self.job_name, self._job_run_id)
         return self._job_run_id
 
-    def execute_complete(self, context, event=None):
+    def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> str:
+        event = validate_execute_complete_event(event)
+
         if event["status"] != "success":
             raise AirflowException(f"Error in glue job: {event}")
         return event["value"]
diff --git a/airflow/providers/amazon/aws/operators/glue_crawler.py b/airflow/providers/amazon/aws/operators/glue_crawler.py
index 660d4948d6..a962e7783e 100644
--- a/airflow/providers/amazon/aws/operators/glue_crawler.py
+++ b/airflow/providers/amazon/aws/operators/glue_crawler.py
@@ -18,11 +18,12 @@ from __future__ import annotations
 
 from functools import cached_property
-from typing import TYPE_CHECKING, Sequence
+from typing import TYPE_CHECKING, Any, Sequence
 
 from airflow.configuration import conf
 from airflow.exceptions import AirflowException
 from airflow.providers.amazon.aws.triggers.glue_crawler import GlueCrawlerCompleteTrigger
+from airflow.providers.amazon.aws.utils import validate_execute_complete_event
 
 if TYPE_CHECKING:
     from airflow.utils.context import Context
@@ -107,7 +108,9 @@ class GlueCrawlerOperator(BaseOperator):
 
         return crawler_name
 
-    def execute_complete(self, context, event=None):
+    def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> str:
+        event = validate_execute_complete_event(event)
+
         if event["status"] != "success":
             raise AirflowException(f"Error in glue crawl: {event}")
         return self.config["Name"]
diff --git a/airflow/providers/amazon/aws/operators/glue_databrew.py b/airflow/providers/amazon/aws/operators/glue_databrew.py
index 596a507397..3ef0373917 100644
--- a/airflow/providers/amazon/aws/operators/glue_databrew.py
+++ b/airflow/providers/amazon/aws/operators/glue_databrew.py
@@ -18,12 +18,13 @@ from __future__ import annotations
 
 from functools import cached_property
-from typing import TYPE_CHECKING, Sequence
+from typing import TYPE_CHECKING, Any, Sequence
 
 from airflow.configuration import conf
 from airflow.models import BaseOperator
 from airflow.providers.amazon.aws.hooks.glue_databrew import GlueDataBrewHook
 from airflow.providers.amazon.aws.triggers.glue_databrew import GlueDataBrewJobCompleteTrigger
+from airflow.providers.amazon.aws.utils import validate_execute_complete_event
 
 if TYPE_CHECKING:
     from airflow.utils.context import Context
@@ -101,7 +102,9 @@ class GlueDataBrewStartJobOperator(BaseOperator):
 
         return {"run_id": run_id}
 
-    def execute_complete(self, context: Context, event=None) -> dict[str, str]:
+    def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> dict[str, str]:
+        event = validate_execute_complete_event(event)
+
         run_id = event.get("run_id", "")
         status = event.get("status", "")
 
diff --git a/airflow/providers/amazon/aws/operators/lambda_function.py b/airflow/providers/amazon/aws/operators/lambda_function.py
index 5dec3f116f..8072c5045c 100644
--- a/airflow/providers/amazon/aws/operators/lambda_function.py
+++ b/airflow/providers/amazon/aws/operators/lambda_function.py
@@ -26,6 +26,7 @@ from airflow.exceptions import AirflowException
 from airflow.providers.amazon.aws.hooks.lambda_function import LambdaHook
 from airflow.providers.amazon.aws.operators.base_aws import AwsBaseOperator
 from airflow.providers.amazon.aws.triggers.lambda_function import LambdaCreateFunctionCompleteTrigger
+from airflow.providers.amazon.aws.utils import validate_execute_complete_event
 from airflow.providers.amazon.aws.utils.mixins import aws_template_fields
 
 if TYPE_CHECKING:
@@ -143,6 +144,8 @@ class LambdaCreateFunctionOperator(AwsBaseOperator[LambdaHook]):
         return response.get("FunctionArn")
 
     def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> str:
+        event = validate_execute_complete_event(event)
+
         if not event or event["status"] != "success":
             raise AirflowException(f"Trigger error: event is {event}")
 
diff --git a/airflow/providers/amazon/aws/operators/rds.py b/airflow/providers/amazon/aws/operators/rds.py
index 6d1f2fd7fb..af3c02ecec 100644
--- a/airflow/providers/amazon/aws/operators/rds.py
+++ b/airflow/providers/amazon/aws/operators/rds.py
@@ -32,6 +32,7 @@ from airflow.providers.amazon.aws.triggers.rds import (
     RdsDbDeletedTrigger,
     RdsDbStoppedTrigger,
 )
+from airflow.providers.amazon.aws.utils import validate_execute_complete_event
 from airflow.providers.amazon.aws.utils.rds import RdsDbType
 from airflow.providers.amazon.aws.utils.tags import format_tags
 from airflow.providers.amazon.aws.utils.waiter_with_logging import wait
@@ -637,11 +638,13 @@ class RdsCreateDbInstanceOperator(RdsBaseOperator):
         )
         return json.dumps(create_db_instance, default=str)
 
-    def execute_complete(self, context, event=None) -> str:
+    def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> str:
+        event = validate_execute_complete_event(event)
+
         if event["status"] != "success":
             raise AirflowException(f"DB instance creation failed: {event}")
-        else:
-            return json.dumps(event["response"], default=str)
+
+        return json.dumps(event["response"], default=str)
 
 
 class RdsDeleteDbInstanceOperator(RdsBaseOperator):
@@ -720,11 +723,13 @@ class RdsDeleteDbInstanceOperator(RdsBaseOperator):
         )
         return json.dumps(delete_db_instance, default=str)
 
-    def execute_complete(self, context, event=None) -> str:
+    def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> str:
+        event = validate_execute_complete_event(event)
+
         if event["status"] != "success":
             raise AirflowException(f"DB instance deletion failed: {event}")
-        else:
-            return json.dumps(event["response"], default=str)
+
+        return json.dumps(event["response"], default=str)
 
 
 class RdsStartDbOperator(RdsBaseOperator):
@@ -786,10 +791,12 @@ class RdsStartDbOperator(RdsBaseOperator):
         return json.dumps(start_db_response, default=str)
 
     def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> str:
-        if event is None or event["status"] != "success":
+        event = validate_execute_complete_event(event)
+
+        if event["status"] != "success":
             raise AirflowException(f"Failed to start DB: {event}")
-        else:
-            return json.dumps(event["response"], default=str)
+
+        return json.dumps(event["response"], default=str)
 
     def _start_db(self):
         self.log.info("Starting DB %s '%s'", self.db_type.value, self.db_identifier)
@@ -883,10 +890,12 @@ class RdsStopDbOperator(RdsBaseOperator):
         return json.dumps(stop_db_response, default=str)
 
     def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> str:
-        if event is None or event["status"] != "success":
+        event = validate_execute_complete_event(event)
+
+        if event["status"] != "success":
             raise AirflowException(f"Failed to start DB: {event}")
-        else:
-            return json.dumps(event["response"], default=str)
+
+        return json.dumps(event["response"], default=str)
 
     def _stop_db(self):
         self.log.info("Stopping DB %s '%s'", self.db_type.value, self.db_identifier)
diff --git a/airflow/providers/amazon/aws/operators/redshift_cluster.py b/airflow/providers/amazon/aws/operators/redshift_cluster.py
index 0fc70a607a..cced10fff3 100644
--- a/airflow/providers/amazon/aws/operators/redshift_cluster.py
+++ b/airflow/providers/amazon/aws/operators/redshift_cluster.py
@@ -31,6 +31,7 @@ from airflow.providers.amazon.aws.triggers.redshift_cluster import (
     RedshiftPauseClusterTrigger,
     RedshiftResumeClusterTrigger,
 )
+from airflow.providers.amazon.aws.utils import validate_execute_complete_event
 
 if TYPE_CHECKING:
     from airflow.utils.context import Context
@@ -314,10 +315,11 @@ class RedshiftCreateClusterOperator(BaseOperator):
         self.log.info("Created Redshift cluster %s", self.cluster_identifier)
         self.log.info(cluster)
 
-    def execute_complete(self, context, event=None):
+    def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> None:
+        event = validate_execute_complete_event(event)
+
         if event["status"] != "success":
             raise AirflowException(f"Error creating cluster: {event}")
-        return
 
 
 class RedshiftCreateClusterSnapshotOperator(BaseOperator):
@@ -409,12 +411,13 @@ class RedshiftCreateClusterSnapshotOperator(BaseOperator):
             },
         )
 
-    def execute_complete(self, context, event=None):
+    def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> None:
+        event = validate_execute_complete_event(event)
+
         if event["status"] != "success":
             raise AirflowException(f"Error creating snapshot: {event}")
-        else:
-            self.log.info("Cluster snapshot created.")
-        return
+
+        self.log.info("Cluster snapshot created.")
 
 
 class RedshiftDeleteClusterSnapshotOperator(BaseOperator):
@@ -569,10 +572,7 @@ class RedshiftResumeClusterOperator(BaseOperator):
         )
 
     def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> None:
-        if event is None:
-            err_msg = "Trigger error: event is None"
-            self.log.info(err_msg)
-            raise AirflowException(err_msg)
+        event = validate_execute_complete_event(event)
 
         if event["status"] != "success":
             raise AirflowException(f"Error resuming cluster: {event}")
@@ -659,10 +659,7 @@ class RedshiftPauseClusterOperator(BaseOperator):
         )
 
     def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> None:
-        if event is None:
-            err_msg = "Trigger error: event is None"
-            self.log.info(err_msg)
-            raise AirflowException(err_msg)
+        event = validate_execute_complete_event(event)
 
         if event["status"] != "success":
             raise AirflowException(f"Error pausing cluster: {event}")
@@ -767,10 +764,7 @@ class RedshiftDeleteClusterOperator(BaseOperator):
         )
 
     def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> None:
-        if event is None:
-            err_msg = "Trigger error: event is None"
-            self.log.info(err_msg)
-            raise AirflowException(err_msg)
+        event = validate_execute_complete_event(event)
 
         if event["status"] != "success":
             raise AirflowException(f"Error deleting cluster: {event}")
diff --git a/airflow/providers/amazon/aws/operators/redshift_data.py b/airflow/providers/amazon/aws/operators/redshift_data.py
index 71ee82069e..54e3c2c7ae 100644
--- a/airflow/providers/amazon/aws/operators/redshift_data.py
+++ b/airflow/providers/amazon/aws/operators/redshift_data.py
@@ -24,6 +24,7 @@ from airflow.exceptions import AirflowException
 from airflow.providers.amazon.aws.hooks.redshift_data import RedshiftDataHook
 from airflow.providers.amazon.aws.operators.base_aws import AwsBaseOperator
 from airflow.providers.amazon.aws.triggers.redshift_data import RedshiftDataTrigger
+from airflow.providers.amazon.aws.utils import validate_execute_complete_event
 from airflow.providers.amazon.aws.utils.mixins import aws_template_fields
 
 if TYPE_CHECKING:
@@ -170,10 +171,7 @@ class RedshiftDataOperator(AwsBaseOperator[RedshiftDataHook]):
     def execute_complete(
         self, context: Context, event: dict[str, Any] | None = None
     ) -> GetStatementResultResponseTypeDef | str:
-        if event is None:
-            err_msg = "Trigger error: event is None"
-            self.log.info(err_msg)
-            raise AirflowException(err_msg)
+        event = validate_execute_complete_event(event)
 
         if event["status"] == "error":
             msg = f"context: {context}, error message: {event['message']}"
diff --git a/airflow/providers/amazon/aws/operators/sagemaker.py b/airflow/providers/amazon/aws/operators/sagemaker.py
index e8cfa26b29..69eeb1e06d 100644
--- a/airflow/providers/amazon/aws/operators/sagemaker.py
+++ b/airflow/providers/amazon/aws/operators/sagemaker.py
@@ -39,7 +39,7 @@ from airflow.providers.amazon.aws.triggers.sagemaker import (
     SageMakerTrainingPrintLogTrigger,
     SageMakerTrigger,
 )
-from airflow.providers.amazon.aws.utils import trim_none_values
+from airflow.providers.amazon.aws.utils import trim_none_values, validate_execute_complete_event
 from airflow.providers.amazon.aws.utils.sagemaker import ApprovalStatus
 from airflow.providers.amazon.aws.utils.tags import format_tags
 from airflow.utils.helpers import prune_dict
@@ -315,11 +315,13 @@ class SageMakerProcessingOperator(SageMakerBaseOperator):
         self.serialized_job = serialize(self.hook.describe_processing_job(self.config["ProcessingJobName"]))
         return {"Processing": self.serialized_job}
 
-    def execute_complete(self, context, event=None):
+    def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> dict[str, dict]:
+        event = validate_execute_complete_event(event)
+
         if event["status"] != "success":
             raise AirflowException(f"Error while running job: {event}")
-        else:
-            self.log.info(event["message"])
+
+        self.log.info(event["message"])
         self.serialized_job = serialize(self.hook.describe_processing_job(self.config["ProcessingJobName"]))
         self.log.info("%s completed successfully.", self.task_id)
         return {"Processing": self.serialized_job}
@@ -566,7 +568,9 @@ class SageMakerEndpointOperator(SageMakerBaseOperator):
             "Endpoint": serialize(self.hook.describe_endpoint(endpoint_info["EndpointName"])),
         }
 
-    def execute_complete(self, context, event=None):
+    def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> dict[str, dict]:
+        event = validate_execute_complete_event(event)
+
         if event["status"] != "success":
             raise AirflowException(f"Error while running job: {event}")
         endpoint_info = self.config.get("Endpoint", self.config)
@@ -749,10 +753,7 @@ class SageMakerTransformOperator(SageMakerBaseOperator):
         return self.serialize_result()
 
     def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> dict[str, dict]:
-        if event is None:
-            err_msg = "Trigger error: event is None"
-            self.log.error(err_msg)
-            raise AirflowException(err_msg)
+        event = validate_execute_complete_event(event)
 
         self.log.info(event["message"])
         return self.serialize_result()
@@ -924,7 +925,9 @@ class SageMakerTuningOperator(SageMakerBaseOperator):
 
         return {"Tuning": serialize(description)}
 
-    def execute_complete(self, context, event=None):
+    def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> dict[str, dict]:
+        event = validate_execute_complete_event(event)
+
         if event["status"] != "success":
             raise AirflowException(f"Error while running job: {event}")
         return {
@@ -1154,10 +1157,7 @@ class SageMakerTrainingOperator(SageMakerBaseOperator):
         return self.serialize_result()
 
     def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> dict[str, dict]:
-        if event is None:
-            err_msg = "Trigger error: event is None"
-            self.log.error(err_msg)
-            raise AirflowException(err_msg)
+        event = validate_execute_complete_event(event)
 
         if event["status"] != "success":
             raise AirflowException(f"Error while running job: {event}")
@@ -1296,7 +1296,9 @@ class SageMakerStartPipelineOperator(SageMakerBaseOperator):
         return arn
 
     def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> str:
-        if event is None or event["status"] != "success":
+        event = validate_execute_complete_event(event)
+
+        if event["status"] != "success":
             raise AirflowException(f"Failure during pipeline execution: {event}")
         return event["value"]
 
@@ -1389,12 +1391,14 @@ class SageMakerStopPipelineOperator(SageMakerBaseOperator):
         return status
 
     def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> str:
-        if event is None or event["status"] != "success":
+        event = validate_execute_complete_event(event)
+
+        if event["status"] != "success":
             raise AirflowException(f"Failure during pipeline execution: {event}")
-        else:
-            # theoretically we should do a `describe` call to know this,
-            # but if we reach this point, this is the only possible status
-            return "Stopped"
+
+        # theoretically we should do a `describe` call to know this,
+        # but if we reach this point, this is the only possible status
+        return "Stopped"
 
 
 class SageMakerRegisterModelVersionOperator(SageMakerBaseOperator):
diff --git a/airflow/providers/amazon/aws/operators/step_function.py b/airflow/providers/amazon/aws/operators/step_function.py
index 067d7e4529..bffb348dfc 100644
--- a/airflow/providers/amazon/aws/operators/step_function.py
+++ b/airflow/providers/amazon/aws/operators/step_function.py
@@ -29,6 +29,7 @@ from airflow.providers.amazon.aws.links.step_function import (
 )
 from airflow.providers.amazon.aws.operators.base_aws import AwsBaseOperator
 from airflow.providers.amazon.aws.triggers.step_function import StepFunctionsExecutionCompleteTrigger
+from airflow.providers.amazon.aws.utils import validate_execute_complete_event
 from airflow.providers.amazon.aws.utils.mixins import aws_template_fields
 
 if TYPE_CHECKING:
@@ -129,7 +130,9 @@ class StepFunctionStartExecutionOperator(AwsBaseOperator[StepFunctionHook]):
         return execution_arn
 
     def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> None:
-        if event is None or event["status"] != "success":
+        event = validate_execute_complete_event(event)
+
+        if event["status"] != "success":
             raise AirflowException(f"Trigger error: event is {event}")
 
         self.log.info("State Machine execution completed successfully")
diff --git a/airflow/providers/amazon/aws/sensors/ec2.py b/airflow/providers/amazon/aws/sensors/ec2.py
index cdebd1b44a..58593cfb30 100644
--- a/airflow/providers/amazon/aws/sensors/ec2.py
+++ b/airflow/providers/amazon/aws/sensors/ec2.py
@@ -24,6 +24,7 @@ from airflow.configuration import conf
 from airflow.exceptions import AirflowException, AirflowSkipException
 from airflow.providers.amazon.aws.hooks.ec2 import EC2Hook
 from airflow.providers.amazon.aws.triggers.ec2 import EC2StateSensorTrigger
+from airflow.providers.amazon.aws.utils import validate_execute_complete_event
 from airflow.sensors.base import BaseSensorOperator
 
 if TYPE_CHECKING:
@@ -92,11 +93,12 @@ class EC2InstanceStateSensor(BaseSensorOperator):
         self.log.info("instance state: %s", instance_state)
         return instance_state == self.target_state
 
-    def execute_complete(self, context, event=None):
+    def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> None:
+        event = validate_execute_complete_event(event)
+
         if event["status"] != "success":
             # TODO: remove this if check when min_airflow_version is set to higher than 2.7.1
             message = f"Error: {event}"
             if self.soft_fail:
                 raise AirflowSkipException(message)
             raise AirflowException(message)
-        return
diff --git a/airflow/providers/amazon/aws/sensors/emr.py b/airflow/providers/amazon/aws/sensors/emr.py
index 76bce335d2..93ad355764 100644
--- a/airflow/providers/amazon/aws/sensors/emr.py
+++ b/airflow/providers/amazon/aws/sensors/emr.py
@@ -32,6 +32,7 @@ from airflow.providers.amazon.aws.triggers.emr import (
     EmrStepSensorTrigger,
     EmrTerminateJobFlowTrigger,
 )
+from airflow.providers.amazon.aws.utils import validate_execute_complete_event
 from airflow.sensors.base import BaseSensorOperator
 
 if TYPE_CHECKING:
@@ -335,15 +336,17 @@ class EmrContainerSensor(BaseSensorOperator):
             method_name="execute_complete",
         )
 
-    def execute_complete(self, context, event=None):
+    def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> None:
+        event = validate_execute_complete_event(event)
+
         if event["status"] != "success":
             # TODO: remove this if check when min_airflow_version is set to higher than 2.7.1
             message = f"Error while running job: {event}"
             if self.soft_fail:
                 raise AirflowSkipException(message)
             raise AirflowException(message)
-        else:
-            self.log.info("Job completed.")
+
+        self.log.info("Job completed.")
 
 
 class EmrNotebookExecutionSensor(EmrBaseSensor):
@@ -526,7 +529,9 @@ class EmrJobFlowSensor(EmrBaseSensor):
             method_name="execute_complete",
         )
 
-    def execute_complete(self, context: Context, event=None) -> None:
+    def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> None:
+        event = validate_execute_complete_event(event)
+
         if event["status"] != "success":
             # TODO: remove this if check when min_airflow_version is set to higher than 2.7.1
             message = f"Error while running job: {event}"
@@ -657,7 +662,9 @@ class EmrStepSensor(EmrBaseSensor):
             method_name="execute_complete",
        )
 
-    def execute_complete(self, context, event=None):
+    def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> None:
+        event = validate_execute_complete_event(event)
+
         if event["status"] != "success":
             # TODO: remove this if check when min_airflow_version is set to higher than 2.7.1
             message = f"Error while running job: {event}"
@@ -665,4 +672,4 @@ if self.soft_fail:
                 raise AirflowSkipException(message)
             raise AirflowException(message)
 
-        self.log.info("Job completed.")
+        self.log.info("Job %s completed.", self.job_flow_id)
diff --git a/airflow/providers/amazon/aws/sensors/glue_catalog_partition.py b/airflow/providers/amazon/aws/sensors/glue_catalog_partition.py
index 5245249c51..bc6a6c2560 100644
--- a/airflow/providers/amazon/aws/sensors/glue_catalog_partition.py
+++ b/airflow/providers/amazon/aws/sensors/glue_catalog_partition.py
@@ -27,6 +27,7 @@ from airflow.configuration import conf
 from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning, AirflowSkipException
 from airflow.providers.amazon.aws.hooks.glue_catalog import GlueCatalogHook
 from airflow.providers.amazon.aws.triggers.glue import GlueCatalogPartitionTrigger
+from airflow.providers.amazon.aws.utils import validate_execute_complete_event
 from airflow.sensors.base import BaseSensorOperator
 
 if TYPE_CHECKING:
@@ -111,7 +112,9 @@ class GlueCatalogPartitionSensor(BaseSensorOperator):
         return self.hook.check_for_partition(self.database_name, self.table_name, self.expression)
 
     def execute_complete(self, context: Context, event: dict | None = None) -> None:
-        if event is None or event["status"] != "success":
+        event = validate_execute_complete_event(event)
+
+        if event["status"] != "success":
             # TODO: remove this if block when min_airflow_version is set to higher than 2.7.1
             message = f"Trigger error: event is {event}"
             if self.soft_fail:
diff --git a/airflow/providers/amazon/aws/sensors/redshift_cluster.py b/airflow/providers/amazon/aws/sensors/redshift_cluster.py
index cd63bb5e1f..4d133b489e 100644
--- a/airflow/providers/amazon/aws/sensors/redshift_cluster.py
+++ b/airflow/providers/amazon/aws/sensors/redshift_cluster.py
@@ -26,6 +26,7 @@ from airflow.configuration import conf
 from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning, AirflowSkipException
 from airflow.providers.amazon.aws.hooks.redshift_cluster import RedshiftHook
 from airflow.providers.amazon.aws.triggers.redshift_cluster import RedshiftClusterTrigger
+from airflow.providers.amazon.aws.utils import validate_execute_complete_event
 from airflow.sensors.base import BaseSensorOperator
 
 if TYPE_CHECKING:
@@ -88,10 +89,7 @@ class RedshiftClusterSensor(BaseSensorOperator):
         )
 
     def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> None:
-        if event is None:
-            err_msg = "Trigger error: event is None"
-            self.log.error(err_msg)
-            raise AirflowException(err_msg)
+        event = validate_execute_complete_event(event)
 
         status = event["status"]
         if status == "error":
diff --git a/airflow/providers/amazon/aws/sensors/s3.py b/airflow/providers/amazon/aws/sensors/s3.py
index 6d55a724af..2699120cbb 100644
--- a/airflow/providers/amazon/aws/sensors/s3.py
+++ b/airflow/providers/amazon/aws/sensors/s3.py
@@ -27,6 +27,7 @@ from typing import TYPE_CHECKING, Any, Callable, Sequence, cast
 from deprecated import deprecated
 
 from airflow.configuration import conf
+from airflow.providers.amazon.aws.utils import validate_execute_complete_event
 
 if TYPE_CHECKING:
     from airflow.utils.context import Context
@@ -371,6 +372,8 @@ class S3KeysUnchangedSensor(BaseSensorOperator):
 
         Relies on trigger to throw an exception, otherwise it assumes execution was successful.
         """
+        event = validate_execute_complete_event(event)
+
         if event and event["status"] == "error":
             # TODO: remove this if block when min_airflow_version is set to higher than 2.7.1
             if self.soft_fail:
diff --git a/airflow/providers/amazon/aws/sensors/sqs.py b/airflow/providers/amazon/aws/sensors/sqs.py
index 7d9065c080..cbe12b2bd4 100644
--- a/airflow/providers/amazon/aws/sensors/sqs.py
+++ b/airflow/providers/amazon/aws/sensors/sqs.py
@@ -28,6 +28,7 @@ from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarni
 from airflow.providers.amazon.aws.hooks.sqs import SqsHook
 from airflow.providers.amazon.aws.sensors.base_aws import AwsBaseSensor
 from airflow.providers.amazon.aws.triggers.sqs import SqsSensorTrigger
+from airflow.providers.amazon.aws.utils import validate_execute_complete_event
 from airflow.providers.amazon.aws.utils.mixins import aws_template_fields
 from airflow.providers.amazon.aws.utils.sqs import MessageFilteringType, process_response
 
@@ -155,7 +156,9 @@ class SqsSensor(AwsBaseSensor[SqsHook]):
         super().execute(context=context)
 
     def execute_complete(self, context: Context, event: dict | None = None) -> None:
-        if event is None or event["status"] != "success":
+        event = validate_execute_complete_event(event)
+
+        if event["status"] != "success":
             # TODO: remove this if block when min_airflow_version is set to higher than 2.7.1
             message = f"Trigger error: event is {event}"
             if self.soft_fail:
diff --git a/airflow/providers/amazon/aws/utils/__init__.py b/airflow/providers/amazon/aws/utils/__init__.py
index 8495e1587a..2a96c3e447 100644
--- a/airflow/providers/amazon/aws/utils/__init__.py
+++ b/airflow/providers/amazon/aws/utils/__init__.py
@@ -20,7 +20,9 @@ import logging
 import re
 from datetime import datetime, timezone
 from enum import Enum
+from typing import Any
 
+from airflow.exceptions import AirflowException
 from airflow.utils.helpers import prune_dict
 from airflow.version import version
 
@@ -72,6 +74,14 @@ def get_airflow_version() -> tuple[int, ...]:
     return tuple(int(x) for x in match.groups())
 
 
+def validate_execute_complete_event(event: dict[str, Any] | None = None) -> dict[str, Any]:
+    if event is None:
+        err_msg = "Trigger error: event is None"
+        log.error(err_msg)
+        raise AirflowException(err_msg)
+    return event
+
+
 class _StringCompareEnum(Enum):
     """
     An Enum class which can be compared with regular `str` and subclasses.