This is an automated email from the ASF dual-hosted git repository.

eladkal pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new da456065df Use base aws classes in Amazon Athena Operators/Sensors/Triggers (#35133)
da456065df is described below

commit da456065dff1c55a1cce61299cbfdb91d3583eed
Author: Andrey Anshin <[email protected]>
AuthorDate: Tue Oct 24 09:40:19 2023 +0400

    Use base aws classes in Amazon Athena Operators/Sensors/Triggers (#35133)
    
    * Use base aws classes in Amazon Athena Operators/Sensors/Triggers
    
    * Fix positional arguments in AthenaTrigger
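    
    For illustration, a minimal usage sketch of what this change enables (the
    task id, database, and S3 path below are made-up examples, not part of
    this commit):
    
        from airflow.providers.amazon.aws.operators.athena import AthenaOperator
    
        # The generic AWS parameters (aws_conn_id, region_name, verify,
        # botocore_config) are accepted because AthenaOperator now extends
        # AwsBaseOperator[AthenaHook]; in deferrable mode they are also
        # forwarded to AthenaTrigger.
        run_query = AthenaOperator(
            task_id="run_query",
            query="SELECT 1",
            database="my_database",
            output_location="s3://my-bucket/athena-results/",
            aws_conn_id="aws_default",
            region_name="eu-west-1",
            verify=True,
            botocore_config={"read_timeout": 42},
            deferrable=True,
        )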
---
 airflow/providers/amazon/aws/operators/athena.py   | 47 ++++++++-----
 airflow/providers/amazon/aws/sensors/athena.py     | 31 +++++----
 airflow/providers/amazon/aws/triggers/athena.py    | 11 +++-
 .../operators/athena.rst                           |  5 ++
 .../providers/amazon/aws/operators/test_athena.py  | 30 +++++++--
 tests/providers/amazon/aws/sensors/test_athena.py  | 76 +++++++++++++---------
 6 files changed, 135 insertions(+), 65 deletions(-)

diff --git a/airflow/providers/amazon/aws/operators/athena.py b/airflow/providers/amazon/aws/operators/athena.py
index 95340d6f8a..39b2cb9ba6 100644
--- a/airflow/providers/amazon/aws/operators/athena.py
+++ b/airflow/providers/amazon/aws/operators/athena.py
@@ -17,22 +17,22 @@
 # under the License.
 from __future__ import annotations
 
-from functools import cached_property
 from typing import TYPE_CHECKING, Any, Sequence
 
 from airflow.configuration import conf
 from airflow.exceptions import AirflowException
-from airflow.models import BaseOperator
 from airflow.providers.amazon.aws.hooks.athena import AthenaHook
+from airflow.providers.amazon.aws.operators.base_aws import AwsBaseOperator
 from airflow.providers.amazon.aws.triggers.athena import AthenaTrigger
+from airflow.providers.amazon.aws.utils.mixins import aws_template_fields
 
 if TYPE_CHECKING:
     from airflow.utils.context import Context
 
 
-class AthenaOperator(BaseOperator):
+class AthenaOperator(AwsBaseOperator[AthenaHook]):
     """
-    An operator that submits a presto query to athena.
+    An operator that submits a Trino/Presto query to Amazon Athena.
 
     .. note:: if the task is killed while it runs, it'll cancel the athena query that was launched,
         EXCEPT if running in deferrable mode.
@@ -41,11 +41,10 @@ class AthenaOperator(BaseOperator):
        For more information on how to use this operator, take a look at the guide:
         :ref:`howto/operator:AthenaOperator`
 
-    :param query: Presto to be run on athena. (templated)
+    :param query: Trino/Presto query to be run on Amazon Athena. (templated)
     :param database: Database to select. (templated)
     :param catalog: Catalog to select. (templated)
     :param output_location: s3 path to write the query results into. (templated)
-    :param aws_conn_id: aws connection to use
     :param client_request_token: Unique token created by user to avoid multiple executions of same query
     :param workgroup: Athena workgroup in which query will be run. (templated)
     :param query_execution_context: Context in which query need to be run
@@ -55,10 +54,23 @@ class AthenaOperator(BaseOperator):
         To limit task execution time, use execution_timeout.
     :param log_query: Whether to log athena query and other execution params when it's executed.
         Defaults to *True*.
+    :param aws_conn_id: The Airflow connection used for AWS credentials.
+        If this is ``None`` or empty then the default boto3 behaviour is used. If
+        running Airflow in a distributed manner and aws_conn_id is None or
+        empty, then default boto3 configuration would be used (and must be
+        maintained on each worker node).
+    :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
+    :param verify: Whether or not to verify SSL certificates. See:
+        https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
+    :param botocore_config: Configuration dictionary (key-values) for botocore client. See:
+        https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
     """
 
+    aws_hook_class = AthenaHook
     ui_color = "#44b5e2"
-    template_fields: Sequence[str] = ("query", "database", "output_location", "workgroup", "catalog")
+    template_fields: Sequence[str] = aws_template_fields(
+        "query", "database", "output_location", "workgroup", "catalog"
+    )
     template_ext: Sequence[str] = (".sql",)
     template_fields_renderers = {"query": "sql"}
 
@@ -68,7 +80,6 @@ class AthenaOperator(BaseOperator):
         query: str,
         database: str,
         output_location: str,
-        aws_conn_id: str = "aws_default",
         client_request_token: str | None = None,
         workgroup: str = "primary",
         query_execution_context: dict[str, str] | None = None,
@@ -84,7 +95,6 @@ class AthenaOperator(BaseOperator):
         self.query = query
         self.database = database
         self.output_location = output_location
-        self.aws_conn_id = aws_conn_id
         self.client_request_token = client_request_token
         self.workgroup = workgroup
         self.query_execution_context = query_execution_context or {}
@@ -96,13 +106,12 @@ class AthenaOperator(BaseOperator):
         self.deferrable = deferrable
         self.catalog: str = catalog
 
-    @cached_property
-    def hook(self) -> AthenaHook:
-        """Create and return an AthenaHook."""
-        return AthenaHook(self.aws_conn_id, log_query=self.log_query)
+    @property
+    def _hook_parameters(self) -> dict[str, Any]:
+        return {**super()._hook_parameters, "log_query": self.log_query}
 
     def execute(self, context: Context) -> str | None:
-        """Run Presto Query on Athena."""
+        """Run Trino/Presto Query on Amazon Athena."""
         self.query_execution_context["Database"] = self.database
         self.query_execution_context["Catalog"] = self.catalog
         self.result_configuration["OutputLocation"] = self.output_location
@@ -117,7 +126,13 @@ class AthenaOperator(BaseOperator):
         if self.deferrable:
             self.defer(
                 trigger=AthenaTrigger(
-                    self.query_execution_id, self.sleep_time, self.max_polling_attempts, self.aws_conn_id
+                    query_execution_id=self.query_execution_id,
+                    waiter_delay=self.sleep_time,
+                    waiter_max_attempts=self.max_polling_attempts,
+                    aws_conn_id=self.aws_conn_id,
+                    region_name=self.region_name,
+                    verify=self.verify,
+                    botocore_config=self.botocore_config,
                 ),
                 method_name="execute_complete",
             )
@@ -148,7 +163,7 @@ class AthenaOperator(BaseOperator):
         return event["value"]
 
     def on_kill(self) -> None:
-        """Cancel the submitted athena query."""
+        """Cancel the submitted Amazon Athena query."""
         if self.query_execution_id:
             self.log.info("Received a kill signal.")
             response = self.hook.stop_query(self.query_execution_id)
diff --git a/airflow/providers/amazon/aws/sensors/athena.py b/airflow/providers/amazon/aws/sensors/athena.py
index 4a1de65266..38f2bb54f8 100644
--- a/airflow/providers/amazon/aws/sensors/athena.py
+++ b/airflow/providers/amazon/aws/sensors/athena.py
@@ -17,18 +17,19 @@
 # under the License.
 from __future__ import annotations
 
-from functools import cached_property
 from typing import TYPE_CHECKING, Any, Sequence
 
+from airflow.providers.amazon.aws.sensors.base_aws import AwsBaseSensor
+from airflow.providers.amazon.aws.utils.mixins import aws_template_fields
+
 if TYPE_CHECKING:
     from airflow.utils.context import Context
 
 from airflow.exceptions import AirflowException, AirflowSkipException
 from airflow.providers.amazon.aws.hooks.athena import AthenaHook
-from airflow.sensors.base import BaseSensorOperator
 
 
-class AthenaSensor(BaseSensorOperator):
+class AthenaSensor(AwsBaseSensor[AthenaHook]):
     """
     Poll the state of the Query until it reaches a terminal state; fails if the query fails.
 
@@ -40,9 +41,18 @@ class AthenaSensor(BaseSensorOperator):
     :param query_execution_id: query_execution_id to check the state of
     :param max_retries: Number of times to poll for query state before
         returning the current state, defaults to None
-    :param aws_conn_id: aws connection to use, defaults to 'aws_default'
     :param sleep_time: Time in seconds to wait between two consecutive call to
         check query status on athena, defaults to 10
+    :param aws_conn_id: The Airflow connection used for AWS credentials.
+        If this is ``None`` or empty then the default boto3 behaviour is used. If
+        running Airflow in a distributed manner and aws_conn_id is None or
+        empty, then default boto3 configuration would be used (and must be
+        maintained on each worker node).
+    :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
+    :param verify: Whether or not to verify SSL certificates. See:
+        https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
+    :param botocore_config: Configuration dictionary (key-values) for botocore client. See:
+        https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
     """
 
     INTERMEDIATE_STATES = (
@@ -55,8 +65,10 @@ class AthenaSensor(BaseSensorOperator):
     )
     SUCCESS_STATES = ("SUCCEEDED",)
 
-    template_fields: Sequence[str] = ("query_execution_id",)
-    template_ext: Sequence[str] = ()
+    aws_hook_class = AthenaHook
+    template_fields: Sequence[str] = aws_template_fields(
+        "query_execution_id",
+    )
     ui_color = "#66c3ff"
 
     def __init__(
@@ -64,12 +76,10 @@ class AthenaSensor(BaseSensorOperator):
         *,
         query_execution_id: str,
         max_retries: int | None = None,
-        aws_conn_id: str = "aws_default",
         sleep_time: int = 10,
         **kwargs: Any,
     ) -> None:
         super().__init__(**kwargs)
-        self.aws_conn_id = aws_conn_id
         self.query_execution_id = query_execution_id
         self.sleep_time = sleep_time
         self.max_retries = max_retries
@@ -87,8 +97,3 @@ class AthenaSensor(BaseSensorOperator):
         if state in self.INTERMEDIATE_STATES:
             return False
         return True
-
-    @cached_property
-    def hook(self) -> AthenaHook:
-        """Create and return an AthenaHook."""
-        return AthenaHook(self.aws_conn_id)
diff --git a/airflow/providers/amazon/aws/triggers/athena.py b/airflow/providers/amazon/aws/triggers/athena.py
index 65c4f08000..a6ca58a2a2 100644
--- a/airflow/providers/amazon/aws/triggers/athena.py
+++ b/airflow/providers/amazon/aws/triggers/athena.py
@@ -43,7 +43,8 @@ class AthenaTrigger(AwsBaseWaiterTrigger):
         query_execution_id: str,
         waiter_delay: int,
         waiter_max_attempts: int,
-        aws_conn_id: str,
+        aws_conn_id: str | None,
+        **kwargs,
     ):
         super().__init__(
             serialized_fields={"query_execution_id": query_execution_id},
@@ -56,7 +57,13 @@ class AthenaTrigger(AwsBaseWaiterTrigger):
             waiter_delay=waiter_delay,
             waiter_max_attempts=waiter_max_attempts,
             aws_conn_id=aws_conn_id,
+            **kwargs,
         )
 
     def hook(self) -> AwsGenericHook:
-        return AthenaHook(self.aws_conn_id)
+        return AthenaHook(
+            aws_conn_id=self.aws_conn_id,
+            region_name=self.region_name,
+            verify=self.verify,
+            config=self.botocore_config,
+        )
diff --git a/docs/apache-airflow-providers-amazon/operators/athena.rst b/docs/apache-airflow-providers-amazon/operators/athena.rst
index a472279882..d82d2edf2b 100644
--- a/docs/apache-airflow-providers-amazon/operators/athena.rst
+++ b/docs/apache-airflow-providers-amazon/operators/athena.rst
@@ -30,6 +30,11 @@ Prerequisite Tasks
 
 .. include:: ../_partials/prerequisite_tasks.rst
 
+Generic Parameters
+------------------
+
+.. include:: ../_partials/generic_parameters.rst
+
 Operators
 ---------
 
diff --git a/tests/providers/amazon/aws/operators/test_athena.py b/tests/providers/amazon/aws/operators/test_athena.py
index 5208c7c683..698535e26e 100644
--- a/tests/providers/amazon/aws/operators/test_athena.py
+++ b/tests/providers/amazon/aws/operators/test_athena.py
@@ -53,9 +53,9 @@ class TestAthenaOperator:
             "start_date": DEFAULT_DATE,
         }
 
-        self.dag = DAG(f"{TEST_DAG_ID}test_schedule_dag_once", default_args=args, schedule="@once")
+        self.dag = DAG(TEST_DAG_ID, default_args=args, schedule="@once")
 
-        self.athena = AthenaOperator(
+        self.default_op_kwargs = dict(
             task_id="test_athena_operator",
             query="SELECT * FROM TEST_TABLE",
             database="TEST_DATABASE",
@@ -63,15 +63,37 @@ class TestAthenaOperator:
             client_request_token="eac427d0-1c6d-4dfb-96aa-2835d3ac6595",
             sleep_time=0,
             max_polling_attempts=3,
-            dag=self.dag,
         )
+        self.athena = AthenaOperator(**self.default_op_kwargs, aws_conn_id=None, dag=self.dag)
+
+    def test_base_aws_op_attributes(self):
+        op = AthenaOperator(**self.default_op_kwargs)
+        assert op.hook.aws_conn_id == "aws_default"
+        assert op.hook._region_name is None
+        assert op.hook._verify is None
+        assert op.hook._config is None
+        assert op.hook.log_query is True
+
+        op = AthenaOperator(
+            **self.default_op_kwargs,
+            aws_conn_id="aws-test-custom-conn",
+            region_name="eu-west-1",
+            verify=False,
+            botocore_config={"read_timeout": 42},
+            log_query=False,
+        )
+        assert op.hook.aws_conn_id == "aws-test-custom-conn"
+        assert op.hook._region_name == "eu-west-1"
+        assert op.hook._verify is False
+        assert op.hook._config is not None
+        assert op.hook._config.read_timeout == 42
+        assert op.hook.log_query is False
 
     def test_init(self):
         assert self.athena.task_id == MOCK_DATA["task_id"]
         assert self.athena.query == MOCK_DATA["query"]
         assert self.athena.database == MOCK_DATA["database"]
         assert self.athena.catalog == MOCK_DATA["catalog"]
-        assert self.athena.aws_conn_id == "aws_default"
         assert self.athena.client_request_token == MOCK_DATA["client_request_token"]
         assert self.athena.sleep_time == 0
 
diff --git a/tests/providers/amazon/aws/sensors/test_athena.py b/tests/providers/amazon/aws/sensors/test_athena.py
index 18012d81ca..a973c76a38 100644
--- a/tests/providers/amazon/aws/sensors/test_athena.py
+++ b/tests/providers/amazon/aws/sensors/test_athena.py
@@ -26,48 +26,64 @@ from airflow.providers.amazon.aws.hooks.athena import AthenaHook
 from airflow.providers.amazon.aws.sensors.athena import AthenaSensor
 
 
+@pytest.fixture
+def mock_poll_query_status():
+    with mock.patch.object(AthenaHook, "poll_query_status") as m:
+        yield m
+
+
 class TestAthenaSensor:
     def setup_method(self):
-        self.sensor = AthenaSensor(
+        self.default_op_kwargs = dict(
             task_id="test_athena_sensor",
             query_execution_id="abc",
             sleep_time=5,
             max_retries=1,
-            aws_conn_id="aws_default",
         )
+        self.sensor = AthenaSensor(**self.default_op_kwargs, aws_conn_id=None)
 
-    @mock.patch.object(AthenaHook, "poll_query_status", side_effect=("SUCCEEDED",))
-    def test_poke_success(self, mock_poll_query_status):
-        assert self.sensor.poke({}) is True
-
-    @mock.patch.object(AthenaHook, "poll_query_status", side_effect=("RUNNING",))
-    def test_poke_running(self, mock_poll_query_status):
-        assert self.sensor.poke({}) is False
+    def test_base_aws_op_attributes(self):
+        op = AthenaSensor(**self.default_op_kwargs)
+        assert op.hook.aws_conn_id == "aws_default"
+        assert op.hook._region_name is None
+        assert op.hook._verify is None
+        assert op.hook._config is None
+        assert op.hook.log_query is True
 
-    @mock.patch.object(AthenaHook, "poll_query_status", side_effect=("QUEUED",))
-    def test_poke_queued(self, mock_poll_query_status):
-        assert self.sensor.poke({}) is False
+        op = AthenaSensor(
+            **self.default_op_kwargs,
+            aws_conn_id="aws-test-custom-conn",
+            region_name="eu-west-1",
+            verify=False,
+            botocore_config={"read_timeout": 42},
+        )
+        assert op.hook.aws_conn_id == "aws-test-custom-conn"
+        assert op.hook._region_name == "eu-west-1"
+        assert op.hook._verify is False
+        assert op.hook._config is not None
+        assert op.hook._config.read_timeout == 42
 
-    @mock.patch.object(AthenaHook, "poll_query_status", side_effect=("FAILED",))
-    def test_poke_failed(self, mock_poll_query_status):
-        with pytest.raises(AirflowException) as ctx:
-            self.sensor.poke({})
-        assert "Athena sensor failed" in str(ctx.value)
+    @pytest.mark.parametrize("state", ["SUCCEEDED"])
+    def test_poke_success_states(self, state, mock_poll_query_status):
+        mock_poll_query_status.side_effect = [state]
+        assert self.sensor.poke({}) is True
 
-    @mock.patch.object(AthenaHook, "poll_query_status", side_effect=("CANCELLED",))
-    def test_poke_cancelled(self, mock_poll_query_status):
-        with pytest.raises(AirflowException) as ctx:
-            self.sensor.poke({})
-        assert "Athena sensor failed" in str(ctx.value)
+    @pytest.mark.parametrize("state", ["RUNNING", "QUEUED"])
+    def test_poke_intermediate_states(self, state, mock_poll_query_status):
+        mock_poll_query_status.side_effect = [state]
+        assert self.sensor.poke({}) is False
 
     @pytest.mark.parametrize(
-        "soft_fail, expected_exception", ((False, AirflowException), (True, AirflowSkipException))
+        "soft_fail, expected_exception",
+        [
+            pytest.param(False, AirflowException, id="not-soft-fail"),
+            pytest.param(True, AirflowSkipException, id="soft-fail"),
+        ],
     )
-    def test_fail_poke(self, soft_fail, expected_exception):
-        self.sensor.soft_fail = soft_fail
+    @pytest.mark.parametrize("state", ["FAILED", "CANCELLED"])
+    def test_poke_failure_states(self, state, soft_fail, expected_exception, mock_poll_query_status):
+        mock_poll_query_status.side_effect = [state]
+        sensor = AthenaSensor(**self.default_op_kwargs, aws_conn_id=None, soft_fail=soft_fail)
         message = "Athena sensor failed"
-        with pytest.raises(expected_exception, match=message), mock.patch(
-            "airflow.providers.amazon.aws.hooks.athena.AthenaHook.poll_query_status"
-        ) as poll_query_status:
-            poll_query_status.return_value = "FAILED"
-            self.sensor.poke(context={})
+        with pytest.raises(expected_exception, match=message):
+            sensor.poke({})
