This is an automated email from the ASF dual-hosted git repository.
onikolas pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 72d09a677f Use a waiter in `AthenaHook` (#31942)
72d09a677f is described below
commit 72d09a677fea22b51dbf20f3b12bae6b3c1e4792
Author: Raphaël Vandon <[email protected]>
AuthorDate: Fri Jun 23 14:20:37 2023 -0700
Use a waiter in `AthenaHook` (#31942)
* Use custom waiters for Emr Serverless operators
Update unit tests
---------
Co-authored-by: Syed Hussain <[email protected]>
Co-authored-by: Vincent <[email protected]>
---
airflow/providers/amazon/aws/hooks/athena.py | 80 ++++++++++------------
airflow/providers/amazon/aws/operators/athena.py | 5 +-
airflow/providers/amazon/aws/sensors/athena.py | 4 +-
airflow/providers/amazon/aws/waiters/athena.json | 30 ++++++++
tests/providers/amazon/aws/hooks/test_athena.py | 12 ++--
.../providers/amazon/aws/operators/test_athena.py | 63 ++---------------
6 files changed, 81 insertions(+), 113 deletions(-)
diff --git a/airflow/providers/amazon/aws/hooks/athena.py
b/airflow/providers/amazon/aws/hooks/athena.py
index f68eee9355..b0d1878507 100644
--- a/airflow/providers/amazon/aws/hooks/athena.py
+++ b/airflow/providers/amazon/aws/hooks/athena.py
@@ -24,12 +24,14 @@ This module contains AWS Athena hook.
"""
from __future__ import annotations
-from time import sleep
+import warnings
from typing import Any
from botocore.paginate import PageIterator
+from airflow.exceptions import AirflowException,
AirflowProviderDeprecationWarning
from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
+from airflow.providers.amazon.aws.utils.waiter_with_logging import wait
class AthenaHook(AwsBaseHook):
@@ -38,8 +40,7 @@ class AthenaHook(AwsBaseHook):
Provide thick wrapper around
:external+boto3:py:class:`boto3.client("athena") <Athena.Client>`.
- :param sleep_time: Time (in seconds) to wait between two consecutive calls
- to check query status on Athena.
+ :param sleep_time: obsolete, please use the parameter of
`poll_query_status` method instead
:param log_query: Whether to log athena query and other execution params
when it's executed. Defaults to *True*.
@@ -65,9 +66,20 @@ class AthenaHook(AwsBaseHook):
"CANCELLED",
)
- def __init__(self, *args: Any, sleep_time: int = 30, log_query: bool =
True, **kwargs: Any) -> None:
+ def __init__(
+ self, *args: Any, sleep_time: int | None = None, log_query: bool =
True, **kwargs: Any
+ ) -> None:
super().__init__(client_type="athena", *args, **kwargs) # type: ignore
- self.sleep_time = sleep_time
+ if sleep_time is not None:
+ self.sleep_time = sleep_time
+ warnings.warn(
+ "The `sleep_time` parameter of the Athena hook is deprecated, "
+ "please pass this parameter to the poll_query_status method
instead.",
+ AirflowProviderDeprecationWarning,
+ stacklevel=2,
+ )
+ else:
+ self.sleep_time = 30 # previous default value
self.log_query = log_query
def run_query(
@@ -229,51 +241,31 @@ class AthenaHook(AwsBaseHook):
return paginator.paginate(**result_params)
def poll_query_status(
- self,
- query_execution_id: str,
- max_polling_attempts: int | None = None,
+ self, query_execution_id: str, max_polling_attempts: int | None =
None, sleep_time: int | None = None
) -> str | None:
"""Poll the state of a submitted query until it reaches final state.
:param query_execution_id: ID of submitted athena query
- :param max_polling_attempts: Number of times to poll for query state
- before function exits
+ :param max_polling_attempts: Number of times to poll for query state
before function exits
+ :param sleep_time: Time (in seconds) to wait between two consecutive
query status checks.
:return: One of the final states
"""
- try_number = 1
- final_query_state = None # Query state when query reaches final state
or max_polling_attempts reached
- while True:
- query_state = self.check_query_status(query_execution_id)
- if query_state is None:
- self.log.info(
- "Query execution id: %s, trial %s: Invalid query state.
Retrying again",
- query_execution_id,
- try_number,
- )
- elif query_state in self.TERMINAL_STATES:
- self.log.info(
- "Query execution id: %s, trial %s: Query execution
completed. Final state is %s",
- query_execution_id,
- try_number,
- query_state,
- )
- final_query_state = query_state
- break
- else:
- self.log.info(
- "Query execution id: %s, trial %s: Query is still in
non-terminal state - %s",
- query_execution_id,
- try_number,
- query_state,
- )
- if (
- max_polling_attempts and try_number >= max_polling_attempts
- ): # Break loop if max_polling_attempts reached
- final_query_state = query_state
- break
- try_number += 1
- sleep(self.sleep_time)
- return final_query_state
+ try:
+ wait(
+ waiter=self.get_waiter("query_complete"),
+ waiter_delay=sleep_time or self.sleep_time,
+ max_attempts=max_polling_attempts or 120,
+ args={"QueryExecutionId": query_execution_id},
+ failure_message=f"Error while waiting for query
{query_execution_id} to complete",
+ status_message=f"Query execution id: {query_execution_id}, "
+ f"Query is still in non-terminal state",
+ status_args=["QueryExecution.Status.State"],
+ )
+ except AirflowException as error:
+ # this function does not raise errors to keep previous behavior.
+ self.log.warning(error)
+ finally:
+ return self.check_query_status(query_execution_id)
def get_output_location(self, query_execution_id: str) -> str:
"""Get the output location of the query results in S3 URI format.
diff --git a/airflow/providers/amazon/aws/operators/athena.py
b/airflow/providers/amazon/aws/operators/athena.py
index 1bd1a97be2..612e563ce6 100644
--- a/airflow/providers/amazon/aws/operators/athena.py
+++ b/airflow/providers/amazon/aws/operators/athena.py
@@ -88,7 +88,7 @@ class AthenaOperator(BaseOperator):
@cached_property
def hook(self) -> AthenaHook:
"""Create and return an AthenaHook."""
- return AthenaHook(self.aws_conn_id, sleep_time=self.sleep_time,
log_query=self.log_query)
+ return AthenaHook(self.aws_conn_id, log_query=self.log_query)
def execute(self, context: Context) -> str | None:
"""Run Presto Query on Athena."""
@@ -104,6 +104,7 @@ class AthenaOperator(BaseOperator):
query_status = self.hook.poll_query_status(
self.query_execution_id,
max_polling_attempts=self.max_polling_attempts,
+ sleep_time=self.sleep_time,
)
if query_status in AthenaHook.FAILURE_STATES:
@@ -139,4 +140,4 @@ class AthenaOperator(BaseOperator):
self.log.info(
"Polling Athena for query with id %s to reach final
state", self.query_execution_id
)
- self.hook.poll_query_status(self.query_execution_id)
+ self.hook.poll_query_status(self.query_execution_id,
sleep_time=self.sleep_time)
diff --git a/airflow/providers/amazon/aws/sensors/athena.py
b/airflow/providers/amazon/aws/sensors/athena.py
index f67fb3ff9a..599341092e 100644
--- a/airflow/providers/amazon/aws/sensors/athena.py
+++ b/airflow/providers/amazon/aws/sensors/athena.py
@@ -76,7 +76,7 @@ class AthenaSensor(BaseSensorOperator):
self.max_retries = max_retries
def poke(self, context: Context) -> bool:
- state = self.hook.poll_query_status(self.query_execution_id,
self.max_retries)
+ state = self.hook.poll_query_status(self.query_execution_id,
self.max_retries, self.sleep_time)
if state in self.FAILURE_STATES:
raise AirflowException("Athena sensor failed")
@@ -88,4 +88,4 @@ class AthenaSensor(BaseSensorOperator):
@cached_property
def hook(self) -> AthenaHook:
"""Create and return an AthenaHook."""
- return AthenaHook(self.aws_conn_id, sleep_time=self.sleep_time)
+ return AthenaHook(self.aws_conn_id)
diff --git a/airflow/providers/amazon/aws/waiters/athena.json
b/airflow/providers/amazon/aws/waiters/athena.json
new file mode 100644
index 0000000000..db68ce32f4
--- /dev/null
+++ b/airflow/providers/amazon/aws/waiters/athena.json
@@ -0,0 +1,30 @@
+{
+ "version": 2,
+ "waiters": {
+ "query_complete": {
+ "operation": "GetQueryExecution",
+ "delay": 30,
+ "maxAttempts": 120,
+ "acceptors": [
+ {
+ "expected": "SUCCEEDED",
+ "matcher": "path",
+ "state": "success",
+ "argument": "QueryExecution.Status.State"
+ },
+ {
+ "expected": "FAILED",
+ "matcher": "path",
+ "state": "failure",
+ "argument": "QueryExecution.Status.State"
+ },
+ {
+ "expected": "CANCELLED",
+ "matcher": "path",
+ "state": "failure",
+ "argument": "QueryExecution.Status.State"
+ }
+ ]
+ }
+ }
+}
diff --git a/tests/providers/amazon/aws/hooks/test_athena.py
b/tests/providers/amazon/aws/hooks/test_athena.py
index a65470acea..05ed6e9e30 100644
--- a/tests/providers/amazon/aws/hooks/test_athena.py
+++ b/tests/providers/amazon/aws/hooks/test_athena.py
@@ -49,11 +49,10 @@ MOCK_QUERY_EXECUTION_OUTPUT = {
class TestAthenaHook:
def setup_method(self):
- self.athena = AthenaHook(sleep_time=0)
+ self.athena = AthenaHook()
def test_init(self):
assert self.athena.aws_conn_id == "aws_default"
- assert self.athena.sleep_time == 0
@mock.patch.object(AthenaHook, "get_conn")
def test_hook_run_query_without_token(self, mock_conn):
@@ -104,7 +103,7 @@ class TestAthenaHook:
@mock.patch.object(AthenaHook, "log")
@mock.patch.object(AthenaHook, "get_conn")
def test_hook_run_query_no_log_query(self, mock_conn, log):
- athena_hook_no_log_query = AthenaHook(sleep_time=0, log_query=False)
+ athena_hook_no_log_query = AthenaHook(log_query=False)
athena_hook_no_log_query.run_query(
query=MOCK_DATA["query"],
query_context=mock_query_context,
@@ -176,7 +175,9 @@ class TestAthenaHook:
@mock.patch.object(AthenaHook, "get_conn")
def test_hook_poll_query_when_final(self, mock_conn):
mock_conn.return_value.get_query_execution.return_value =
MOCK_SUCCEEDED_QUERY_EXECUTION
- result =
self.athena.poll_query_status(query_execution_id=MOCK_DATA["query_execution_id"])
+ result = self.athena.poll_query_status(
+ query_execution_id=MOCK_DATA["query_execution_id"], sleep_time=0
+ )
mock_conn.return_value.get_query_execution.assert_called_once()
assert result == "SUCCEEDED"
@@ -184,8 +185,7 @@ class TestAthenaHook:
def test_hook_poll_query_with_timeout(self, mock_conn):
mock_conn.return_value.get_query_execution.return_value =
MOCK_RUNNING_QUERY_EXECUTION
result = self.athena.poll_query_status(
- query_execution_id=MOCK_DATA["query_execution_id"],
- max_polling_attempts=1,
+ query_execution_id=MOCK_DATA["query_execution_id"],
max_polling_attempts=1, sleep_time=0
)
mock_conn.return_value.get_query_execution.assert_called_once()
assert result == "RUNNING"
diff --git a/tests/providers/amazon/aws/operators/test_athena.py
b/tests/providers/amazon/aws/operators/test_athena.py
index e7b945d2a4..cfc7869768 100644
--- a/tests/providers/amazon/aws/operators/test_athena.py
+++ b/tests/providers/amazon/aws/operators/test_athena.py
@@ -71,8 +71,6 @@ class TestAthenaOperator:
assert self.athena.client_request_token ==
MOCK_DATA["client_request_token"]
assert self.athena.sleep_time == 0
- assert self.athena.hook.sleep_time == 0
-
@mock.patch.object(AthenaHook, "check_query_status",
side_effect=("SUCCEEDED",))
@mock.patch.object(AthenaHook, "run_query", return_value=ATHENA_QUERY_ID)
@mock.patch.object(AthenaHook, "get_conn")
@@ -90,11 +88,7 @@ class TestAthenaOperator:
@mock.patch.object(
AthenaHook,
"check_query_status",
- side_effect=(
- "RUNNING",
- "RUNNING",
- "SUCCEEDED",
- ),
+ side_effect="SUCCEEDED",
)
@mock.patch.object(AthenaHook, "run_query", return_value=ATHENA_QUERY_ID)
@mock.patch.object(AthenaHook, "get_conn")
@@ -107,39 +101,9 @@ class TestAthenaOperator:
MOCK_DATA["client_request_token"],
MOCK_DATA["workgroup"],
)
- assert mock_check_query_status.call_count == 3
-
- @mock.patch.object(
- AthenaHook,
- "check_query_status",
- side_effect=(
- None,
- None,
- ),
- )
- @mock.patch.object(AthenaHook, "run_query", return_value=ATHENA_QUERY_ID)
- @mock.patch.object(AthenaHook, "get_conn")
- def test_hook_run_failed_query_with_none(self, mock_conn, mock_run_query,
mock_check_query_status):
- with pytest.raises(Exception):
- self.athena.execute({})
- mock_run_query.assert_called_once_with(
- MOCK_DATA["query"],
- query_context,
- result_configuration,
- MOCK_DATA["client_request_token"],
- MOCK_DATA["workgroup"],
- )
- assert mock_check_query_status.call_count == 3
@mock.patch.object(AthenaHook, "get_state_change_reason")
- @mock.patch.object(
- AthenaHook,
- "check_query_status",
- side_effect=(
- "RUNNING",
- "FAILED",
- ),
- )
+ @mock.patch.object(AthenaHook, "check_query_status", return_value="FAILED")
@mock.patch.object(AthenaHook, "run_query", return_value=ATHENA_QUERY_ID)
@mock.patch.object(AthenaHook, "get_conn")
def test_hook_run_failure_query(
@@ -154,18 +118,9 @@ class TestAthenaOperator:
MOCK_DATA["client_request_token"],
MOCK_DATA["workgroup"],
)
- assert mock_check_query_status.call_count == 2
assert mock_get_state_change_reason.call_count == 1
- @mock.patch.object(
- AthenaHook,
- "check_query_status",
- side_effect=(
- "RUNNING",
- "RUNNING",
- "CANCELLED",
- ),
- )
+ @mock.patch.object(AthenaHook, "check_query_status",
return_value="CANCELLED")
@mock.patch.object(AthenaHook, "run_query", return_value=ATHENA_QUERY_ID)
@mock.patch.object(AthenaHook, "get_conn")
def test_hook_run_cancelled_query(self, mock_conn, mock_run_query,
mock_check_query_status):
@@ -178,17 +133,8 @@ class TestAthenaOperator:
MOCK_DATA["client_request_token"],
MOCK_DATA["workgroup"],
)
- assert mock_check_query_status.call_count == 3
- @mock.patch.object(
- AthenaHook,
- "check_query_status",
- side_effect=(
- "RUNNING",
- "RUNNING",
- "RUNNING",
- ),
- )
+ @mock.patch.object(AthenaHook, "check_query_status",
return_value="RUNNING")
@mock.patch.object(AthenaHook, "run_query", return_value=ATHENA_QUERY_ID)
@mock.patch.object(AthenaHook, "get_conn")
def test_hook_run_failed_query_with_max_tries(self, mock_conn,
mock_run_query, mock_check_query_status):
@@ -201,7 +147,6 @@ class TestAthenaOperator:
MOCK_DATA["client_request_token"],
MOCK_DATA["workgroup"],
)
- assert mock_check_query_status.call_count == 3
@mock.patch.object(AthenaHook, "check_query_status",
side_effect=("SUCCEEDED",))
@mock.patch.object(AthenaHook, "run_query", return_value=ATHENA_QUERY_ID)