This is an automated email from the ASF dual-hosted git repository.
eladkal pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 256438c3d6 add deferrable mode for `AthenaOperator` (#32110)
256438c3d6 is described below
commit 256438c3d6a80c989c68d2e0f3c8549108770f0e
Author: Raphaël Vandon <[email protected]>
AuthorDate: Mon Jun 26 23:51:51 2023 -0700
add deferrable mode for `AthenaOperator` (#32110)
* add deferrable mode for `AthenaOperator`
---
airflow/providers/amazon/aws/operators/athena.py | 20 +++++-
airflow/providers/amazon/aws/triggers/athena.py | 76 ++++++++++++++++++++++
airflow/providers/amazon/provider.yaml | 3 +
.../providers/amazon/aws/operators/test_athena.py | 12 ++++
tests/providers/amazon/aws/triggers/test_athena.py | 53 +++++++++++++++
5 files changed, 163 insertions(+), 1 deletion(-)
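For reference, a minimal usage sketch of the new flag in a DAG task (the query, database, S3 location, and task_id below are illustrative placeholders, not part of this commit):

    from airflow.providers.amazon.aws.operators.athena import AthenaOperator

    run_query = AthenaOperator(
        task_id="run_athena_query",
        query="SELECT * FROM my_table LIMIT 10",     # illustrative query
        database="my_database",                      # illustrative database
        output_location="s3://my-bucket/athena/",    # illustrative S3 path
        deferrable=True,  # new flag: release the worker slot while the query runs
    )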
diff --git a/airflow/providers/amazon/aws/operators/athena.py b/airflow/providers/amazon/aws/operators/athena.py
index 612e563ce6..990f2ec414 100644
--- a/airflow/providers/amazon/aws/operators/athena.py
+++ b/airflow/providers/amazon/aws/operators/athena.py
@@ -20,8 +20,10 @@ from __future__ import annotations
from functools import cached_property
from typing import TYPE_CHECKING, Any, Sequence
+from airflow import AirflowException
from airflow.models import BaseOperator
from airflow.providers.amazon.aws.hooks.athena import AthenaHook
+from airflow.providers.amazon.aws.triggers.athena import AthenaTrigger
if TYPE_CHECKING:
from airflow.utils.context import Context
@@ -69,6 +71,7 @@ class AthenaOperator(BaseOperator):
sleep_time: int = 30,
max_polling_attempts: int | None = None,
log_query: bool = True,
+ deferrable: bool = False,
**kwargs: Any,
) -> None:
super().__init__(**kwargs)
@@ -81,9 +84,10 @@ class AthenaOperator(BaseOperator):
self.query_execution_context = query_execution_context or {}
self.result_configuration = result_configuration or {}
self.sleep_time = sleep_time
- self.max_polling_attempts = max_polling_attempts
+ self.max_polling_attempts = max_polling_attempts or 999999
self.query_execution_id: str | None = None
self.log_query: bool = log_query
+ self.deferrable = deferrable
@cached_property
def hook(self) -> AthenaHook:
@@ -101,6 +105,15 @@ class AthenaOperator(BaseOperator):
self.client_request_token,
self.workgroup,
)
+
+ if self.deferrable:
+ self.defer(
+ trigger=AthenaTrigger(
+ self.query_execution_id, self.sleep_time, self.max_polling_attempts, self.aws_conn_id
+ ),
+ method_name="execute_complete",
+ )
+ # implicit else:
query_status = self.hook.poll_query_status(
self.query_execution_id,
max_polling_attempts=self.max_polling_attempts,
@@ -121,6 +134,11 @@ class AthenaOperator(BaseOperator):
return self.query_execution_id
+ def execute_complete(self, context, event=None):
+ if event["status"] != "success":
+ raise AirflowException(f"Error while waiting for query to complete: {event}")
+ return event["value"]
+
def on_kill(self) -> None:
"""Cancel the submitted athena query."""
if self.query_execution_id:
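For clarity, a hedged sketch (illustrative values, not part of the diff) of the event dict that execute_complete receives when the deferral resumes; it mirrors the TriggerEvent payload yielded by AthenaTrigger.run in the new trigger module below:

    # Shape of the event passed back to execute_complete (illustrative values).
    event = {"status": "success", "value": "1234-example-query-execution-id"}
    assert event["status"] == "success"
    query_execution_id = event["value"]  # returned as the operator's result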
diff --git a/airflow/providers/amazon/aws/triggers/athena.py b/airflow/providers/amazon/aws/triggers/athena.py
new file mode 100644
index 0000000000..780d9e9b98
--- /dev/null
+++ b/airflow/providers/amazon/aws/triggers/athena.py
@@ -0,0 +1,76 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from typing import Any
+
+from airflow.providers.amazon.aws.hooks.athena import AthenaHook
+from airflow.providers.amazon.aws.utils.waiter_with_logging import async_wait
+from airflow.triggers.base import BaseTrigger, TriggerEvent
+
+
+class AthenaTrigger(BaseTrigger):
+ """
+ Trigger for AthenaOperator.
+
+ The trigger will asynchronously poll the boto3 API and wait for the
+ Athena query to reach a terminal state.
+
+ :param query_execution_id: ID of the Athena query execution to watch
+ :param poll_interval: The amount of time in seconds to wait between attempts.
+ :param max_attempt: The maximum number of attempts to be made.
+ :param aws_conn_id: The Airflow connection used for AWS credentials.
+ """
+
+ def __init__(
+ self,
+ query_execution_id: str,
+ poll_interval: int,
+ max_attempt: int,
+ aws_conn_id: str,
+ ):
+ self.query_execution_id = query_execution_id
+ self.poll_interval = poll_interval
+ self.max_attempt = max_attempt
+ self.aws_conn_id = aws_conn_id
+
+ def serialize(self) -> tuple[str, dict[str, Any]]:
+ return (
+ self.__class__.__module__ + "." + self.__class__.__qualname__,
+ {
+ "query_execution_id": str(self.query_execution_id),
+ "poll_interval": str(self.poll_interval),
+ "max_attempt": str(self.max_attempt),
+ "aws_conn_id": str(self.aws_conn_id),
+ },
+ )
+
+ async def run(self):
+ hook = AthenaHook(self.aws_conn_id)
+ async with hook.async_conn as client:
+ waiter = hook.get_waiter("query_complete", deferrable=True, client=client)
+ await async_wait(
+ waiter=waiter,
+ waiter_delay=self.poll_interval,
+ max_attempts=self.max_attempt,
+ args={"QueryExecutionId": self.query_execution_id},
+ failure_message=f"Error while waiting for query {self.query_execution_id} to complete",
+ status_message=f"Query execution id: {self.query_execution_id}, "
+ "Query is still in non-terminal state",
+ status_args=["QueryExecution.Status.State"],
+ )
+ yield TriggerEvent({"status": "success", "value": self.query_execution_id})
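As a side note, a minimal sketch (illustrative, not from this commit) of the serialize contract the triggerer relies on to rebuild the trigger in its own process before awaiting run():

    from airflow.providers.amazon.aws.triggers.athena import AthenaTrigger

    trigger = AthenaTrigger("example-query-id", poll_interval=30, max_attempt=10, aws_conn_id="aws_default")
    classpath, kwargs = trigger.serialize()
    # The triggerer imports `classpath` and reconstructs the trigger from kwargs,
    # then iterates run() until a TriggerEvent is yielded.
    rehydrated = AthenaTrigger(**kwargs)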
diff --git a/airflow/providers/amazon/provider.yaml b/airflow/providers/amazon/provider.yaml
index 5439f9c8cb..e4f16ce398 100644
--- a/airflow/providers/amazon/provider.yaml
+++ b/airflow/providers/amazon/provider.yaml
@@ -515,6 +515,9 @@ hooks:
- airflow.providers.amazon.aws.hooks.appflow
triggers:
+ - integration-name: Amazon Athena
+ python-modules:
+ - airflow.providers.amazon.aws.triggers.athena
- integration-name: AWS Batch
python-modules:
- airflow.providers.amazon.aws.triggers.batch
diff --git a/tests/providers/amazon/aws/operators/test_athena.py b/tests/providers/amazon/aws/operators/test_athena.py
index cfc7869768..9e52852520 100644
--- a/tests/providers/amazon/aws/operators/test_athena.py
+++ b/tests/providers/amazon/aws/operators/test_athena.py
@@ -20,9 +20,11 @@ from unittest import mock
import pytest
+from airflow.exceptions import TaskDeferred
from airflow.models import DAG, DagRun, TaskInstance
from airflow.providers.amazon.aws.hooks.athena import AthenaHook
from airflow.providers.amazon.aws.operators.athena import AthenaOperator
+from airflow.providers.amazon.aws.triggers.athena import AthenaTrigger
from airflow.utils import timezone
from airflow.utils.timezone import datetime
@@ -158,3 +160,13 @@ class TestAthenaOperator:
ti.dag_run = dag_run
assert self.athena.execute(ti.get_template_context()) == ATHENA_QUERY_ID
+
+ @mock.patch.object(AthenaHook, "run_query", return_value=ATHENA_QUERY_ID)
+ def test_is_deferred(self, mock_run_query):
+ self.athena.deferrable = True
+
+ with pytest.raises(TaskDeferred) as deferred:
+ self.athena.execute(None)
+
+ assert isinstance(deferred.value.trigger, AthenaTrigger)
+ assert deferred.value.trigger.query_execution_id == ATHENA_QUERY_ID
diff --git a/tests/providers/amazon/aws/triggers/test_athena.py b/tests/providers/amazon/aws/triggers/test_athena.py
new file mode 100644
index 0000000000..04e601f439
--- /dev/null
+++ b/tests/providers/amazon/aws/triggers/test_athena.py
@@ -0,0 +1,53 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from unittest import mock
+from unittest.mock import AsyncMock
+
+import pytest
+from botocore.exceptions import WaiterError
+
+from airflow.providers.amazon.aws.hooks.athena import AthenaHook
+from airflow.providers.amazon.aws.triggers.athena import AthenaTrigger
+
+
+class TestAthenaTrigger:
+ @pytest.mark.asyncio
+ @mock.patch.object(AthenaHook, "get_waiter")
+ @mock.patch.object(AthenaHook, "async_conn")  # LatestBoto step of CI fails without this
+ async def test_run_with_error(self, conn_mock, waiter_mock):
+ waiter_mock.side_effect = WaiterError("name", "reason", {})
+
+ trigger = AthenaTrigger("query_id", 0, 5, None)
+
+ with pytest.raises(WaiterError):
+ generator = trigger.run()
+ await generator.asend(None)
+
+ @pytest.mark.asyncio
+ @mock.patch.object(AthenaHook, "get_waiter")
+ @mock.patch.object(AthenaHook, "async_conn")  # LatestBoto step of CI fails without this
+ async def test_run_success(self, conn_mock, waiter_mock):
+ waiter_mock().wait = AsyncMock()
+ trigger = AthenaTrigger("my_query_id", 0, 5, None)
+
+ generator = trigger.run()
+ event = await generator.asend(None)
+
+ assert event.payload["status"] == "success"
+ assert event.payload["value"] == "my_query_id"