This is an automated email from the ASF dual-hosted git repository.
eladkal pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new da456065df Use base aws classes in Amazon Athena Operators/Sensors/Triggers (#35133)
da456065df is described below
commit da456065dff1c55a1cce61299cbfdb91d3583eed
Author: Andrey Anshin <[email protected]>
AuthorDate: Tue Oct 24 09:40:19 2023 +0400
Use base aws classes in Amazon Athena Operators/Sensors/Triggers (#35133)
* Use base aws classes in Amazon Athena Operators/Sensors/Triggers
* Fix positional arguments in AthenaTrigger
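For illustration, a task using the migrated operator together with the generic AWS parameters it now accepts might look like the sketch below (a minimal sketch; the connection id, region, query, and S3 path are placeholder values, not part of this change):

    from airflow.providers.amazon.aws.operators.athena import AthenaOperator

    # All values below are hypothetical, for illustration only.
    run_query = AthenaOperator(
        task_id="run_athena_query",
        query="SELECT * FROM test_table",
        database="test_database",
        output_location="s3://example-bucket/athena-results/",
        # Generic AWS parameters now provided by AwsBaseOperator:
        aws_conn_id="aws_default",
        region_name="eu-west-1",
        verify=True,
        botocore_config={"read_timeout": 42},
    )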
---
airflow/providers/amazon/aws/operators/athena.py | 47 ++++++++-----
airflow/providers/amazon/aws/sensors/athena.py | 31 +++++----
airflow/providers/amazon/aws/triggers/athena.py | 11 +++-
.../operators/athena.rst | 5 ++
.../providers/amazon/aws/operators/test_athena.py | 30 +++++++--
tests/providers/amazon/aws/sensors/test_athena.py | 76 +++++++++++++---------
6 files changed, 135 insertions(+), 65 deletions(-)
diff --git a/airflow/providers/amazon/aws/operators/athena.py b/airflow/providers/amazon/aws/operators/athena.py
index 95340d6f8a..39b2cb9ba6 100644
--- a/airflow/providers/amazon/aws/operators/athena.py
+++ b/airflow/providers/amazon/aws/operators/athena.py
@@ -17,22 +17,22 @@
# under the License.
from __future__ import annotations
-from functools import cached_property
from typing import TYPE_CHECKING, Any, Sequence
from airflow.configuration import conf
from airflow.exceptions import AirflowException
-from airflow.models import BaseOperator
from airflow.providers.amazon.aws.hooks.athena import AthenaHook
+from airflow.providers.amazon.aws.operators.base_aws import AwsBaseOperator
from airflow.providers.amazon.aws.triggers.athena import AthenaTrigger
+from airflow.providers.amazon.aws.utils.mixins import aws_template_fields
if TYPE_CHECKING:
from airflow.utils.context import Context
-class AthenaOperator(BaseOperator):
+class AthenaOperator(AwsBaseOperator[AthenaHook]):
"""
- An operator that submits a presto query to athena.
+ An operator that submits a Trino/Presto query to Amazon Athena.
.. note:: if the task is killed while it runs, it'll cancel the athena
query that was launched,
EXCEPT if running in deferrable mode.
@@ -41,11 +41,10 @@ class AthenaOperator(BaseOperator):
For more information on how to use this operator, take a look at the guide:
:ref:`howto/operator:AthenaOperator`
- :param query: Presto to be run on athena. (templated)
+ :param query: Trino/Presto query to be run on Amazon Athena. (templated)
:param database: Database to select. (templated)
:param catalog: Catalog to select. (templated)
:param output_location: s3 path to write the query results into. (templated)
- :param aws_conn_id: aws connection to use
:param client_request_token: Unique token created by user to avoid multiple executions of same query
:param workgroup: Athena workgroup in which query will be run. (templated)
:param query_execution_context: Context in which query needs to be run
@@ -55,10 +54,23 @@ class AthenaOperator(BaseOperator):
To limit task execution time, use execution_timeout.
:param log_query: Whether to log athena query and other execution params when it's executed.
Defaults to *True*.
+ :param aws_conn_id: The Airflow connection used for AWS credentials.
+ If this is ``None`` or empty then the default boto3 behaviour is used. If
+ running Airflow in a distributed manner and aws_conn_id is None or
+ empty, then default boto3 configuration would be used (and must be
+ maintained on each worker node).
+ :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
+ :param verify: Whether or not to verify SSL certificates. See:
+ https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
+ :param botocore_config: Configuration dictionary (key-values) for botocore client. See:
+ https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
"""
+ aws_hook_class = AthenaHook
ui_color = "#44b5e2"
- template_fields: Sequence[str] = ("query", "database", "output_location", "workgroup", "catalog")
+ template_fields: Sequence[str] = aws_template_fields(
+ "query", "database", "output_location", "workgroup", "catalog"
+ )
template_ext: Sequence[str] = (".sql",)
template_fields_renderers = {"query": "sql"}
@@ -68,7 +80,6 @@ class AthenaOperator(BaseOperator):
query: str,
database: str,
output_location: str,
- aws_conn_id: str = "aws_default",
client_request_token: str | None = None,
workgroup: str = "primary",
query_execution_context: dict[str, str] | None = None,
@@ -84,7 +95,6 @@ class AthenaOperator(BaseOperator):
self.query = query
self.database = database
self.output_location = output_location
- self.aws_conn_id = aws_conn_id
self.client_request_token = client_request_token
self.workgroup = workgroup
self.query_execution_context = query_execution_context or {}
@@ -96,13 +106,12 @@ class AthenaOperator(BaseOperator):
self.deferrable = deferrable
self.catalog: str = catalog
- @cached_property
- def hook(self) -> AthenaHook:
- """Create and return an AthenaHook."""
- return AthenaHook(self.aws_conn_id, log_query=self.log_query)
+ @property
+ def _hook_parameters(self) -> dict[str, Any]:
+ return {**super()._hook_parameters, "log_query": self.log_query}
def execute(self, context: Context) -> str | None:
- """Run Presto Query on Athena."""
+ """Run Trino/Presto Query on Amazon Athena."""
self.query_execution_context["Database"] = self.database
self.query_execution_context["Catalog"] = self.catalog
self.result_configuration["OutputLocation"] = self.output_location
@@ -117,7 +126,13 @@ class AthenaOperator(BaseOperator):
if self.deferrable:
self.defer(
trigger=AthenaTrigger(
- self.query_execution_id, self.sleep_time, self.max_polling_attempts, self.aws_conn_id
+ query_execution_id=self.query_execution_id,
+ waiter_delay=self.sleep_time,
+ waiter_max_attempts=self.max_polling_attempts,
+ aws_conn_id=self.aws_conn_id,
+ region_name=self.region_name,
+ verify=self.verify,
+ botocore_config=self.botocore_config,
),
method_name="execute_complete",
)
@@ -148,7 +163,7 @@ class AthenaOperator(BaseOperator):
return event["value"]
def on_kill(self) -> None:
- """Cancel the submitted athena query."""
+ """Cancel the submitted Amazon Athena query."""
if self.query_execution_id:
self.log.info("Received a kill signal.")
response = self.hook.stop_query(self.query_execution_id)
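The _hook_parameters override above is the extension point AwsBaseOperator exposes for hook construction: the base class builds the typed hook roughly as aws_hook_class(**self._hook_parameters), so subclasses merge in extra hook kwargs instead of instantiating the hook themselves. A minimal standalone sketch of that pattern, with illustrative class names that are not part of the provider:

    from __future__ import annotations

    from functools import cached_property
    from typing import Any

    class DemoHook:
        """Placeholder hook standing in for AthenaHook (illustration only)."""

        def __init__(self, aws_conn_id: str | None = None, log_query: bool = True) -> None:
            self.aws_conn_id = aws_conn_id
            self.log_query = log_query

    class DemoBaseOperator:
        """Stripped-down stand-in for AwsBaseOperator (illustration only)."""

        aws_hook_class: type = DemoHook

        def __init__(self, aws_conn_id: str | None = "aws_default") -> None:
            self.aws_conn_id = aws_conn_id

        @property
        def _hook_parameters(self) -> dict[str, Any]:
            return {"aws_conn_id": self.aws_conn_id}

        @cached_property
        def hook(self) -> DemoHook:
            # Assumption of this sketch: the real base class builds its hook
            # from _hook_parameters in a similar way.
            return self.aws_hook_class(**self._hook_parameters)

    class DemoAthenaOperator(DemoBaseOperator):
        def __init__(self, log_query: bool = True, **kwargs: Any) -> None:
            super().__init__(**kwargs)
            self.log_query = log_query

        @property
        def _hook_parameters(self) -> dict[str, Any]:
            # Merge the base kwargs and add the operator-specific one,
            # mirroring the AthenaOperator change above.
            return {**super()._hook_parameters, "log_query": self.log_query}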
diff --git a/airflow/providers/amazon/aws/sensors/athena.py b/airflow/providers/amazon/aws/sensors/athena.py
index 4a1de65266..38f2bb54f8 100644
--- a/airflow/providers/amazon/aws/sensors/athena.py
+++ b/airflow/providers/amazon/aws/sensors/athena.py
@@ -17,18 +17,19 @@
# under the License.
from __future__ import annotations
-from functools import cached_property
from typing import TYPE_CHECKING, Any, Sequence
+from airflow.providers.amazon.aws.sensors.base_aws import AwsBaseSensor
+from airflow.providers.amazon.aws.utils.mixins import aws_template_fields
+
if TYPE_CHECKING:
from airflow.utils.context import Context
from airflow.exceptions import AirflowException, AirflowSkipException
from airflow.providers.amazon.aws.hooks.athena import AthenaHook
-from airflow.sensors.base import BaseSensorOperator
-class AthenaSensor(BaseSensorOperator):
+class AthenaSensor(AwsBaseSensor[AthenaHook]):
"""
Poll the state of the Query until it reaches a terminal state; fails if the query fails.
@@ -40,9 +41,18 @@ class AthenaSensor(BaseSensorOperator):
:param query_execution_id: query_execution_id to check the state of
:param max_retries: Number of times to poll for query state before returning the current state, defaults to None
- :param aws_conn_id: aws connection to use, defaults to 'aws_default'
:param sleep_time: Time in seconds to wait between two consecutive calls to check query status on athena, defaults to 10
+ :param aws_conn_id: The Airflow connection used for AWS credentials.
+ If this is ``None`` or empty then the default boto3 behaviour is used. If
+ running Airflow in a distributed manner and aws_conn_id is None or
+ empty, then default boto3 configuration would be used (and must be
+ maintained on each worker node).
+ :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
+ :param verify: Whether or not to verify SSL certificates. See:
+ https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
+ :param botocore_config: Configuration dictionary (key-values) for botocore client. See:
+ https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
"""
INTERMEDIATE_STATES = (
@@ -55,8 +65,10 @@ class AthenaSensor(BaseSensorOperator):
)
SUCCESS_STATES = ("SUCCEEDED",)
- template_fields: Sequence[str] = ("query_execution_id",)
- template_ext: Sequence[str] = ()
+ aws_hook_class = AthenaHook
+ template_fields: Sequence[str] = aws_template_fields(
+ "query_execution_id",
+ )
ui_color = "#66c3ff"
def __init__(
@@ -64,12 +76,10 @@ class AthenaSensor(BaseSensorOperator):
*,
query_execution_id: str,
max_retries: int | None = None,
- aws_conn_id: str = "aws_default",
sleep_time: int = 10,
**kwargs: Any,
) -> None:
super().__init__(**kwargs)
- self.aws_conn_id = aws_conn_id
self.query_execution_id = query_execution_id
self.sleep_time = sleep_time
self.max_retries = max_retries
@@ -87,8 +97,3 @@ class AthenaSensor(BaseSensorOperator):
if state in self.INTERMEDIATE_STATES:
return False
return True
-
- @cached_property
- def hook(self) -> AthenaHook:
- """Create and return an AthenaHook."""
- return AthenaHook(self.aws_conn_id)
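For symmetry with the operator, a hedged usage sketch of the migrated sensor with the inherited generic parameters (the query execution id, connection id, and region below are placeholders):

    from airflow.providers.amazon.aws.sensors.athena import AthenaSensor

    # Hypothetical values, for illustration only.
    wait_for_query = AthenaSensor(
        task_id="wait_for_athena_query",
        query_execution_id="00000000-0000-0000-0000-000000000000",
        max_retries=20,
        sleep_time=10,
        # Generic AWS parameters now provided by AwsBaseSensor:
        aws_conn_id="aws_default",
        region_name="eu-west-1",
    )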
diff --git a/airflow/providers/amazon/aws/triggers/athena.py b/airflow/providers/amazon/aws/triggers/athena.py
index 65c4f08000..a6ca58a2a2 100644
--- a/airflow/providers/amazon/aws/triggers/athena.py
+++ b/airflow/providers/amazon/aws/triggers/athena.py
@@ -43,7 +43,8 @@ class AthenaTrigger(AwsBaseWaiterTrigger):
query_execution_id: str,
waiter_delay: int,
waiter_max_attempts: int,
- aws_conn_id: str,
+ aws_conn_id: str | None,
+ **kwargs,
):
super().__init__(
serialized_fields={"query_execution_id": query_execution_id},
@@ -56,7 +57,13 @@ class AthenaTrigger(AwsBaseWaiterTrigger):
waiter_delay=waiter_delay,
waiter_max_attempts=waiter_max_attempts,
aws_conn_id=aws_conn_id,
+ **kwargs,
)
def hook(self) -> AwsGenericHook:
- return AthenaHook(self.aws_conn_id)
+ return AthenaHook(
+ aws_conn_id=self.aws_conn_id,
+ region_name=self.region_name,
+ verify=self.verify,
+ config=self.botocore_config,
+ )
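With **kwargs forwarded to AwsBaseWaiterTrigger, the trigger can now carry the same connection settings the operator holds, which is what the deferrable path in AthenaOperator.execute() relies on. A sketch of constructing it directly (all values are placeholders):

    from airflow.providers.amazon.aws.triggers.athena import AthenaTrigger

    # Mirrors what AthenaOperator passes when deferrable=True;
    # all values are hypothetical.
    trigger = AthenaTrigger(
        query_execution_id="00000000-0000-0000-0000-000000000000",
        waiter_delay=30,
        waiter_max_attempts=10,
        aws_conn_id="aws_default",
        region_name="eu-west-1",
        verify=True,
        botocore_config={"read_timeout": 42},
    )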
diff --git a/docs/apache-airflow-providers-amazon/operators/athena.rst b/docs/apache-airflow-providers-amazon/operators/athena.rst
index a472279882..d82d2edf2b 100644
--- a/docs/apache-airflow-providers-amazon/operators/athena.rst
+++ b/docs/apache-airflow-providers-amazon/operators/athena.rst
@@ -30,6 +30,11 @@ Prerequisite Tasks
.. include:: ../_partials/prerequisite_tasks.rst
+Generic Parameters
+------------------
+
+.. include:: ../_partials/generic_parameters.rst
+
Operators
---------
diff --git a/tests/providers/amazon/aws/operators/test_athena.py b/tests/providers/amazon/aws/operators/test_athena.py
index 5208c7c683..698535e26e 100644
--- a/tests/providers/amazon/aws/operators/test_athena.py
+++ b/tests/providers/amazon/aws/operators/test_athena.py
@@ -53,9 +53,9 @@ class TestAthenaOperator:
"start_date": DEFAULT_DATE,
}
- self.dag = DAG(f"{TEST_DAG_ID}test_schedule_dag_once", default_args=args, schedule="@once")
+ self.dag = DAG(TEST_DAG_ID, default_args=args, schedule="@once")
- self.athena = AthenaOperator(
+ self.default_op_kwargs = dict(
task_id="test_athena_operator",
query="SELECT * FROM TEST_TABLE",
database="TEST_DATABASE",
@@ -63,15 +63,37 @@ class TestAthenaOperator:
client_request_token="eac427d0-1c6d-4dfb-96aa-2835d3ac6595",
sleep_time=0,
max_polling_attempts=3,
- dag=self.dag,
)
+ self.athena = AthenaOperator(**self.default_op_kwargs, aws_conn_id=None, dag=self.dag)
+
+ def test_base_aws_op_attributes(self):
+ op = AthenaOperator(**self.default_op_kwargs)
+ assert op.hook.aws_conn_id == "aws_default"
+ assert op.hook._region_name is None
+ assert op.hook._verify is None
+ assert op.hook._config is None
+ assert op.hook.log_query is True
+
+ op = AthenaOperator(
+ **self.default_op_kwargs,
+ aws_conn_id="aws-test-custom-conn",
+ region_name="eu-west-1",
+ verify=False,
+ botocore_config={"read_timeout": 42},
+ log_query=False,
+ )
+ assert op.hook.aws_conn_id == "aws-test-custom-conn"
+ assert op.hook._region_name == "eu-west-1"
+ assert op.hook._verify is False
+ assert op.hook._config is not None
+ assert op.hook._config.read_timeout == 42
+ assert op.hook.log_query is False
def test_init(self):
assert self.athena.task_id == MOCK_DATA["task_id"]
assert self.athena.query == MOCK_DATA["query"]
assert self.athena.database == MOCK_DATA["database"]
assert self.athena.catalog == MOCK_DATA["catalog"]
- assert self.athena.aws_conn_id == "aws_default"
assert self.athena.client_request_token == MOCK_DATA["client_request_token"]
assert self.athena.sleep_time == 0
diff --git a/tests/providers/amazon/aws/sensors/test_athena.py b/tests/providers/amazon/aws/sensors/test_athena.py
index 18012d81ca..a973c76a38 100644
--- a/tests/providers/amazon/aws/sensors/test_athena.py
+++ b/tests/providers/amazon/aws/sensors/test_athena.py
@@ -26,48 +26,64 @@ from airflow.providers.amazon.aws.hooks.athena import AthenaHook
from airflow.providers.amazon.aws.sensors.athena import AthenaSensor
[email protected]
+def mock_poll_query_status():
+ with mock.patch.object(AthenaHook, "poll_query_status") as m:
+ yield m
+
+
class TestAthenaSensor:
def setup_method(self):
- self.sensor = AthenaSensor(
+ self.default_op_kwargs = dict(
task_id="test_athena_sensor",
query_execution_id="abc",
sleep_time=5,
max_retries=1,
- aws_conn_id="aws_default",
)
+ self.sensor = AthenaSensor(**self.default_op_kwargs, aws_conn_id=None)
- @mock.patch.object(AthenaHook, "poll_query_status", side_effect=("SUCCEEDED",))
- def test_poke_success(self, mock_poll_query_status):
- assert self.sensor.poke({}) is True
-
- @mock.patch.object(AthenaHook, "poll_query_status", side_effect=("RUNNING",))
- def test_poke_running(self, mock_poll_query_status):
- assert self.sensor.poke({}) is False
+ def test_base_aws_op_attributes(self):
+ op = AthenaSensor(**self.default_op_kwargs)
+ assert op.hook.aws_conn_id == "aws_default"
+ assert op.hook._region_name is None
+ assert op.hook._verify is None
+ assert op.hook._config is None
+ assert op.hook.log_query is True
- @mock.patch.object(AthenaHook, "poll_query_status", side_effect=("QUEUED",))
- def test_poke_queued(self, mock_poll_query_status):
- assert self.sensor.poke({}) is False
+ op = AthenaSensor(
+ **self.default_op_kwargs,
+ aws_conn_id="aws-test-custom-conn",
+ region_name="eu-west-1",
+ verify=False,
+ botocore_config={"read_timeout": 42},
+ )
+ assert op.hook.aws_conn_id == "aws-test-custom-conn"
+ assert op.hook._region_name == "eu-west-1"
+ assert op.hook._verify is False
+ assert op.hook._config is not None
+ assert op.hook._config.read_timeout == 42
- @mock.patch.object(AthenaHook, "poll_query_status", side_effect=("FAILED",))
- def test_poke_failed(self, mock_poll_query_status):
- with pytest.raises(AirflowException) as ctx:
- self.sensor.poke({})
- assert "Athena sensor failed" in str(ctx.value)
+ @pytest.mark.parametrize("state", ["SUCCEEDED"])
+ def test_poke_success_states(self, state, mock_poll_query_status):
+ mock_poll_query_status.side_effect = [state]
+ assert self.sensor.poke({}) is True
- @mock.patch.object(AthenaHook, "poll_query_status", side_effect=("CANCELLED",))
- def test_poke_cancelled(self, mock_poll_query_status):
- with pytest.raises(AirflowException) as ctx:
- self.sensor.poke({})
- assert "Athena sensor failed" in str(ctx.value)
+ @pytest.mark.parametrize("state", ["RUNNING", "QUEUED"])
+ def test_poke_intermediate_states(self, state, mock_poll_query_status):
+ mock_poll_query_status.side_effect = [state]
+ assert self.sensor.poke({}) is False
@pytest.mark.parametrize(
- "soft_fail, expected_exception", ((False, AirflowException), (True,
AirflowSkipException))
+ "soft_fail, expected_exception",
+ [
+ pytest.param(False, AirflowException, id="not-soft-fail"),
+ pytest.param(True, AirflowSkipException, id="soft-fail"),
+ ],
)
- def test_fail_poke(self, soft_fail, expected_exception):
- self.sensor.soft_fail = soft_fail
+ @pytest.mark.parametrize("state", ["FAILED", "CANCELLED"])
+ def test_poke_failure_states(self, state, soft_fail, expected_exception, mock_poll_query_status):
+ mock_poll_query_status.side_effect = [state]
+ sensor = AthenaSensor(**self.default_op_kwargs, aws_conn_id=None, soft_fail=soft_fail)
message = "Athena sensor failed"
- with pytest.raises(expected_exception, match=message), mock.patch(
- "airflow.providers.amazon.aws.hooks.athena.AthenaHook.poll_query_status"
- ) as poll_query_status:
- poll_query_status.return_value = "FAILED"
- self.sensor.poke(context={})
+ with pytest.raises(expected_exception, match=message):
+ sensor.poke({})