This is an automated email from the ASF dual-hosted git repository.

eladkal pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 256438c3d6 add deferrable mode for `AthenaOperator` (#32110)
256438c3d6 is described below

commit 256438c3d6a80c989c68d2e0f3c8549108770f0e
Author: Raphaël Vandon <[email protected]>
AuthorDate: Mon Jun 26 23:51:51 2023 -0700

    add deferrable mode for `AthenaOperator` (#32110)
    
    * add deferrable mode for `AthenaOperator`
---
 airflow/providers/amazon/aws/operators/athena.py   | 20 +++++-
 airflow/providers/amazon/aws/triggers/athena.py    | 76 ++++++++++++++++++++++
 airflow/providers/amazon/provider.yaml             |  3 +
 .../providers/amazon/aws/operators/test_athena.py  | 12 ++++
 tests/providers/amazon/aws/triggers/test_athena.py | 53 +++++++++++++++
 5 files changed, 163 insertions(+), 1 deletion(-)
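
For context, enabling the new mode from a DAG looks roughly like the sketch below. This is an illustrative example rather than part of the commit; the DAG id, task id, query, database, and S3 output location are assumptions.

    from datetime import datetime

    from airflow import DAG
    from airflow.providers.amazon.aws.operators.athena import AthenaOperator

    with DAG("athena_deferrable_example", start_date=datetime(2023, 1, 1), schedule=None):
        # With deferrable=True the task submits the query, releases its worker
        # slot, and lets the triggerer poll Athena asynchronously.
        AthenaOperator(
            task_id="run_query",
            query="SELECT 1",
            database="example_db",
            output_location="s3://example-bucket/athena-results/",
            deferrable=True,
        )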

diff --git a/airflow/providers/amazon/aws/operators/athena.py b/airflow/providers/amazon/aws/operators/athena.py
index 612e563ce6..990f2ec414 100644
--- a/airflow/providers/amazon/aws/operators/athena.py
+++ b/airflow/providers/amazon/aws/operators/athena.py
@@ -20,8 +20,10 @@ from __future__ import annotations
 from functools import cached_property
 from typing import TYPE_CHECKING, Any, Sequence
 
+from airflow import AirflowException
 from airflow.models import BaseOperator
 from airflow.providers.amazon.aws.hooks.athena import AthenaHook
+from airflow.providers.amazon.aws.triggers.athena import AthenaTrigger
 
 if TYPE_CHECKING:
     from airflow.utils.context import Context
@@ -69,6 +71,7 @@ class AthenaOperator(BaseOperator):
         sleep_time: int = 30,
         max_polling_attempts: int | None = None,
         log_query: bool = True,
+        deferrable: bool = False,
         **kwargs: Any,
     ) -> None:
         super().__init__(**kwargs)
@@ -81,9 +84,10 @@ class AthenaOperator(BaseOperator):
         self.query_execution_context = query_execution_context or {}
         self.result_configuration = result_configuration or {}
         self.sleep_time = sleep_time
-        self.max_polling_attempts = max_polling_attempts
+        self.max_polling_attempts = max_polling_attempts or 999999
         self.query_execution_id: str | None = None
         self.log_query: bool = log_query
+        self.deferrable = deferrable
 
     @cached_property
     def hook(self) -> AthenaHook:
@@ -101,6 +105,15 @@ class AthenaOperator(BaseOperator):
             self.client_request_token,
             self.workgroup,
         )
+
+        if self.deferrable:
+            self.defer(
+                trigger=AthenaTrigger(
+                    self.query_execution_id, self.sleep_time, self.max_polling_attempts, self.aws_conn_id
+                ),
+                method_name="execute_complete",
+            )
+        # implicit else:
         query_status = self.hook.poll_query_status(
             self.query_execution_id,
             max_polling_attempts=self.max_polling_attempts,
@@ -121,6 +134,11 @@ class AthenaOperator(BaseOperator):
 
         return self.query_execution_id
 
+    def execute_complete(self, context, event=None):
+        if event["status"] != "success":
+            raise AirflowException(f"Error while waiting for the Athena query to complete: {event}")
+        return event["value"]
+
     def on_kill(self) -> None:
         """Cancel the submitted athena query."""
         if self.query_execution_id:
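
In deferrable mode, the self.defer() call above raises TaskDeferred; the triggerer then runs AthenaTrigger, and the task resumes through execute_complete with the trigger's event payload. A minimal sketch of that contract (the helper function and event values are illustrative, not from the commit):

    from airflow import AirflowException

    def handle_trigger_event(event: dict) -> str:
        # Any non-success event from the trigger surfaces as a task failure.
        if event["status"] != "success":
            raise AirflowException(f"Athena query failed: {event}")
        # On success the query execution id is returned, matching the
        # return value of execute() in the non-deferrable path.
        return event["value"]

    assert handle_trigger_event({"status": "success", "value": "abc-123"}) == "abc-123"
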
diff --git a/airflow/providers/amazon/aws/triggers/athena.py b/airflow/providers/amazon/aws/triggers/athena.py
new file mode 100644
index 0000000000..780d9e9b98
--- /dev/null
+++ b/airflow/providers/amazon/aws/triggers/athena.py
@@ -0,0 +1,76 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from typing import Any
+
+from airflow.providers.amazon.aws.hooks.athena import AthenaHook
+from airflow.providers.amazon.aws.utils.waiter_with_logging import async_wait
+from airflow.triggers.base import BaseTrigger, TriggerEvent
+
+
+class AthenaTrigger(BaseTrigger):
+    """
+    Trigger for AthenaOperator.
+
+    The trigger will asynchronously poll the boto3 API and wait for the
+    Athena query to reach a terminal state.
+
+    :param query_execution_id: ID of the Athena query execution to watch
+    :param poll_interval: The amount of time in seconds to wait between attempts.
+    :param max_attempt: The maximum number of attempts to be made.
+    :param aws_conn_id: The Airflow connection used for AWS credentials.
+    """
+
+    def __init__(
+        self,
+        query_execution_id: str,
+        poll_interval: int,
+        max_attempt: int,
+        aws_conn_id: str,
+    ):
+        self.query_execution_id = query_execution_id
+        self.poll_interval = poll_interval
+        self.max_attempt = max_attempt
+        self.aws_conn_id = aws_conn_id
+
+    def serialize(self) -> tuple[str, dict[str, Any]]:
+        return (
+            self.__class__.__module__ + "." + self.__class__.__qualname__,
+            {
+                "query_execution_id": str(self.query_execution_id),
+                "poll_interval": str(self.poll_interval),
+                "max_attempt": str(self.max_attempt),
+                "aws_conn_id": str(self.aws_conn_id),
+            },
+        )
+
+    async def run(self):
+        hook = AthenaHook(self.aws_conn_id)
+        async with hook.async_conn as client:
+            waiter = hook.get_waiter("query_complete", deferrable=True, client=client)
+            await async_wait(
+                waiter=waiter,
+                waiter_delay=self.poll_interval,
+                max_attempts=self.max_attempt,
+                args={"QueryExecutionId": self.query_execution_id},
+                failure_message=f"Error while waiting for query {self.query_execution_id} to complete",
+                status_message=f"Query execution id: {self.query_execution_id}, "
+                "Query is still in non-terminal state",
+                status_args=["QueryExecution.Status.State"],
+            )
+        yield TriggerEvent({"status": "success", "value": self.query_execution_id})
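
For reference, the serialize() contract above can be exercised as follows (an illustrative sketch; the query id, interval, attempt count, and connection name are assumptions):

    from airflow.providers.amazon.aws.triggers.athena import AthenaTrigger

    trigger = AthenaTrigger("abc-123", 30, 10, "aws_default")
    classpath, kwargs = trigger.serialize()

    # The triggerer re-creates the trigger from the classpath and kwargs.
    assert classpath == "airflow.providers.amazon.aws.triggers.athena.AthenaTrigger"
    rehydrated = AthenaTrigger(**kwargs)
    assert rehydrated.query_execution_id == "abc-123"
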
diff --git a/airflow/providers/amazon/provider.yaml b/airflow/providers/amazon/provider.yaml
index 5439f9c8cb..e4f16ce398 100644
--- a/airflow/providers/amazon/provider.yaml
+++ b/airflow/providers/amazon/provider.yaml
@@ -515,6 +515,9 @@ hooks:
       - airflow.providers.amazon.aws.hooks.appflow
 
 triggers:
+  - integration-name: Amazon Athena
+    python-modules:
+      - airflow.providers.amazon.aws.triggers.athena
   - integration-name: AWS Batch
     python-modules:
       - airflow.providers.amazon.aws.triggers.batch
diff --git a/tests/providers/amazon/aws/operators/test_athena.py b/tests/providers/amazon/aws/operators/test_athena.py
index cfc7869768..9e52852520 100644
--- a/tests/providers/amazon/aws/operators/test_athena.py
+++ b/tests/providers/amazon/aws/operators/test_athena.py
@@ -20,9 +20,11 @@ from unittest import mock
 
 import pytest
 
+from airflow.exceptions import TaskDeferred
 from airflow.models import DAG, DagRun, TaskInstance
 from airflow.providers.amazon.aws.hooks.athena import AthenaHook
 from airflow.providers.amazon.aws.operators.athena import AthenaOperator
+from airflow.providers.amazon.aws.triggers.athena import AthenaTrigger
 from airflow.utils import timezone
 from airflow.utils.timezone import datetime
 
@@ -158,3 +160,13 @@ class TestAthenaOperator:
         ti.dag_run = dag_run
 
         assert self.athena.execute(ti.get_template_context()) == ATHENA_QUERY_ID
+
+    @mock.patch.object(AthenaHook, "run_query", return_value=ATHENA_QUERY_ID)
+    def test_is_deferred(self, mock_run_query):
+        self.athena.deferrable = True
+
+        with pytest.raises(TaskDeferred) as deferred:
+            self.athena.execute(None)
+
+        assert isinstance(deferred.value.trigger, AthenaTrigger)
+        assert deferred.value.trigger.query_execution_id == ATHENA_QUERY_ID
diff --git a/tests/providers/amazon/aws/triggers/test_athena.py b/tests/providers/amazon/aws/triggers/test_athena.py
new file mode 100644
index 0000000000..04e601f439
--- /dev/null
+++ b/tests/providers/amazon/aws/triggers/test_athena.py
@@ -0,0 +1,53 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from unittest import mock
+from unittest.mock import AsyncMock
+
+import pytest
+from botocore.exceptions import WaiterError
+
+from airflow.providers.amazon.aws.hooks.athena import AthenaHook
+from airflow.providers.amazon.aws.triggers.athena import AthenaTrigger
+
+
+class TestAthenaTrigger:
+    @pytest.mark.asyncio
+    @mock.patch.object(AthenaHook, "get_waiter")
+    @mock.patch.object(AthenaHook, "async_conn")  # LatestBoto step of CI fails without this
+    async def test_run_with_error(self, conn_mock, waiter_mock):
+        waiter_mock.side_effect = WaiterError("name", "reason", {})
+
+        trigger = AthenaTrigger("query_id", 0, 5, None)
+
+        with pytest.raises(WaiterError):
+            generator = trigger.run()
+            await generator.asend(None)
+
+    @pytest.mark.asyncio
+    @mock.patch.object(AthenaHook, "get_waiter")
+    @mock.patch.object(AthenaHook, "async_conn")  # LatestBoto step of CI fails without this
+    async def test_run_success(self, conn_mock, waiter_mock):
+        waiter_mock().wait = AsyncMock()
+        trigger = AthenaTrigger("my_query_id", 0, 5, None)
+
+        generator = trigger.run()
+        event = await generator.asend(None)
+
+        assert event.payload["status"] == "success"
+        assert event.payload["value"] == "my_query_id"
