This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new d73bef2a43 Add Amazon Athena query results extra link (#36447)
d73bef2a43 is described below

commit d73bef2a435ad5bf9e482986614c1e349beb5137
Author: Andrey Anshin <[email protected]>
AuthorDate: Wed Dec 27 21:44:21 2023 +0400

    Add Amazon Athena query results extra link (#36447)
---
 airflow/providers/amazon/aws/links/athena.py       | 30 +++++++++++++++++++
 airflow/providers/amazon/aws/operators/athena.py   |  9 ++++++
 airflow/providers/amazon/provider.yaml             |  1 +
 tests/providers/amazon/aws/links/test_athena.py    | 35 ++++++++++++++++++++++
 .../providers/amazon/aws/operators/test_athena.py  | 25 +++++++++++++++-
 5 files changed, 99 insertions(+), 1 deletion(-)

diff --git a/airflow/providers/amazon/aws/links/athena.py b/airflow/providers/amazon/aws/links/athena.py
new file mode 100644
index 0000000000..99950d2580
--- /dev/null
+++ b/airflow/providers/amazon/aws/links/athena.py
@@ -0,0 +1,30 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from airflow.providers.amazon.aws.links.base_aws import BASE_AWS_CONSOLE_LINK, BaseAwsLink
+
+
+class AthenaQueryResultsLink(BaseAwsLink):
+    """Helper class for constructing Amazon Athena query results."""
+
+    name = "Query Results"
+    key = "_athena_query_results"
+    format_str = (
+        BASE_AWS_CONSOLE_LINK + "/athena/home?region={region_name}#"
+        "/query-editor/history/{query_execution_id}"
+    )
diff --git a/airflow/providers/amazon/aws/operators/athena.py b/airflow/providers/amazon/aws/operators/athena.py
index 3bc907227d..90b2e7cdba 100644
--- a/airflow/providers/amazon/aws/operators/athena.py
+++ b/airflow/providers/amazon/aws/operators/athena.py
@@ -23,6 +23,7 @@ from urllib.parse import urlparse
 from airflow.configuration import conf
 from airflow.exceptions import AirflowException
 from airflow.providers.amazon.aws.hooks.athena import AthenaHook
+from airflow.providers.amazon.aws.links.athena import AthenaQueryResultsLink
 from airflow.providers.amazon.aws.operators.base_aws import AwsBaseOperator
 from airflow.providers.amazon.aws.triggers.athena import AthenaTrigger
 from airflow.providers.amazon.aws.utils.mixins import aws_template_fields
@@ -82,6 +83,7 @@ class AthenaOperator(AwsBaseOperator[AthenaHook]):
     )
     template_ext: Sequence[str] = (".sql",)
     template_fields_renderers = {"query": "sql"}
+    operator_extra_links = (AthenaQueryResultsLink(),)
 
     def __init__(
         self,
@@ -132,6 +134,13 @@ class AthenaOperator(AwsBaseOperator[AthenaHook]):
             self.client_request_token,
             self.workgroup,
         )
+        AthenaQueryResultsLink.persist(
+            context=context,
+            operator=self,
+            region_name=self.hook.conn_region_name,
+            aws_partition=self.hook.conn_partition,
+            query_execution_id=self.query_execution_id,
+        )
 
         if self.deferrable:
             self.defer(
diff --git a/airflow/providers/amazon/provider.yaml b/airflow/providers/amazon/provider.yaml
index d96fd43a91..e6b220c080 100644
--- a/airflow/providers/amazon/provider.yaml
+++ b/airflow/providers/amazon/provider.yaml
@@ -711,6 +711,7 @@ transfers:
     python-module: airflow.providers.amazon.aws.transfers.azure_blob_to_s3
 
 extra-links:
+  - airflow.providers.amazon.aws.links.athena.AthenaQueryResultsLink
   - airflow.providers.amazon.aws.links.batch.BatchJobDefinitionLink
   - airflow.providers.amazon.aws.links.batch.BatchJobDetailsLink
   - airflow.providers.amazon.aws.links.batch.BatchJobQueueLink
diff --git a/tests/providers/amazon/aws/links/test_athena.py b/tests/providers/amazon/aws/links/test_athena.py
new file mode 100644
index 0000000000..1729fdf4e5
--- /dev/null
+++ b/tests/providers/amazon/aws/links/test_athena.py
@@ -0,0 +1,35 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from airflow.providers.amazon.aws.links.athena import AthenaQueryResultsLink
+from tests.providers.amazon.aws.links.test_base_aws import BaseAwsLinksTestCase
+
+
+class TestAthenaQueryResultsLink(BaseAwsLinksTestCase):
+    link_class = AthenaQueryResultsLink
+
+    def test_extra_link(self):
+        self.assert_extra_link_url(
+            expected_url=(
+                "https://console.aws.amazon.com/athena/home"
+                "?region=eu-west-1#/query-editor/history/00000000-0000-0000-0000-000000000000"
+            ),
+            region_name="eu-west-1",
+            aws_partition="aws",
+            query_execution_id="00000000-0000-0000-0000-000000000000",
+        )
diff --git a/tests/providers/amazon/aws/operators/test_athena.py b/tests/providers/amazon/aws/operators/test_athena.py
index 5820827b05..4c7e564ccc 100644
--- a/tests/providers/amazon/aws/operators/test_athena.py
+++ b/tests/providers/amazon/aws/operators/test_athena.py
@@ -57,7 +57,8 @@ result_configuration = {"OutputLocation": MOCK_DATA["outputLocation"]}
 
 
 class TestAthenaOperator:
-    def setup_method(self):
+    @pytest.fixture(autouse=True)
+    def setup_test_cases(self):
         args = {
             "owner": "airflow",
             "start_date": DEFAULT_DATE,
@@ -77,6 +78,10 @@ class TestAthenaOperator:
             **self.default_op_kwargs, output_location="s3://test_s3_bucket/", aws_conn_id=None, dag=self.dag
         )
 
+        with mock.patch("airflow.providers.amazon.aws.links.athena.AthenaQueryResultsLink.persist") as m:
+            self.mocked_athena_result_link = m
+            yield
+
     def test_base_aws_op_attributes(self):
         op = AthenaOperator(**self.default_op_kwargs)
         assert op.hook.aws_conn_id == "aws_default"
@@ -138,6 +143,15 @@ class TestAthenaOperator:
         )
         assert mock_check_query_status.call_count == 1
 
+        # Validate call persist Athena Query result link
+        self.mocked_athena_result_link.assert_called_once_with(
+            aws_partition=mock.ANY,
+            context=mock.ANY,
+            operator=mock.ANY,
+            region_name=mock.ANY,
+            query_execution_id=ATHENA_QUERY_ID,
+        )
+
     @mock.patch.object(
         AthenaHook,
         "check_query_status",
@@ -241,6 +255,15 @@ class TestAthenaOperator:
 
         assert isinstance(deferred.value.trigger, AthenaTrigger)
 
+        # Validate call persist Athena Query result link
+        self.mocked_athena_result_link.assert_called_once_with(
+            aws_partition=mock.ANY,
+            context=mock.ANY,
+            operator=mock.ANY,
+            region_name=mock.ANY,
+            query_execution_id=ATHENA_QUERY_ID,
+        )
+
     @mock.patch.object(AthenaHook, "region_name", new_callable=mock.PropertyMock)
     @mock.patch.object(AthenaHook, "get_conn")
     def test_operator_openlineage_data(self, mock_conn, mock_region_name):

Reply via email to