This is an automated email from the ASF dual-hosted git repository.

mobuchowski pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 5626590406 openlineage, aws: Add OpenLineage support for AthenaOperator. (#35090)
5626590406 is described below

commit 56265904062960b681dc1d5237518b3e76b87296
Author: Jakub Dardzinski <[email protected]>
AuthorDate: Tue Nov 14 19:43:55 2023 +0100

    openlineage, aws: Add OpenLineage support for AthenaOperator. (#35090)
    
    * Add OpenLineage support for AthenaOperator.
    
    Based on: https://github.com/OpenLineage/OpenLineage/pull/1328.
    
    Signed-off-by: Jakub Dardzinski <[email protected]>
    
    * Cache `get_query_execution` in AthenaHook.
    Adjust code to catalog and output_location changes.
    
    Signed-off-by: Jakub Dardzinski <[email protected]>
    
    * Change caching implementation.
    Rename `get_query_execution` to `get_query_info`.
    
    Signed-off-by: Jakub Dardzinski <[email protected]>
    
    ---------
    
    Signed-off-by: Jakub Dardzinski <[email protected]>
---
 airflow/providers/amazon/aws/hooks/athena.py       |  27 ++++-
 airflow/providers/amazon/aws/operators/athena.py   | 115 ++++++++++++++++++++
 tests/providers/amazon/aws/hooks/test_athena.py    |  19 ++++
 .../amazon/aws/operators/athena_metadata.json      |  72 +++++++++++++
 .../providers/amazon/aws/operators/test_athena.py  | 120 ++++++++++++++++++++-
 5 files changed, 347 insertions(+), 6 deletions(-)

diff --git a/airflow/providers/amazon/aws/hooks/athena.py b/airflow/providers/amazon/aws/hooks/athena.py
index 11148bd6b1..c8af86ee3a 100644
--- a/airflow/providers/amazon/aws/hooks/athena.py
+++ b/airflow/providers/amazon/aws/hooks/athena.py
@@ -82,6 +82,7 @@ class AthenaHook(AwsBaseHook):
         else:
             self.sleep_time = 30  # previous default value
         self.log_query = log_query
+        self.__query_results: dict[str, Any] = {}
 
     def run_query(
         self,
@@ -120,7 +121,23 @@ class AthenaHook(AwsBaseHook):
         self.log.info("Query execution id: %s", query_execution_id)
         return query_execution_id
 
-    def check_query_status(self, query_execution_id: str) -> str | None:
+    def get_query_info(self, query_execution_id: str, use_cache: bool = False) -> dict:
+        """Get information about a single execution of a query.
+
+        .. seealso::
+            - :external+boto3:py:meth:`Athena.Client.get_query_execution`
+
+        :param query_execution_id: Id of submitted athena query
+        :param use_cache: If True, use execution information cache
+        """
+        if use_cache and query_execution_id in self.__query_results:
+            return self.__query_results[query_execution_id]
+        response = self.get_conn().get_query_execution(QueryExecutionId=query_execution_id)
+        if use_cache:
+            self.__query_results[query_execution_id] = response
+        return response
+
+    def check_query_status(self, query_execution_id: str, use_cache: bool = False) -> str | None:
         """Fetch the state of a submitted query.
 
         .. seealso::
@@ -130,7 +147,7 @@ class AthenaHook(AwsBaseHook):
         :return: One of valid query states, or *None* if the response is
             malformed.
         """
-        response = self.get_conn().get_query_execution(QueryExecutionId=query_execution_id)
+        response = self.get_query_info(query_execution_id=query_execution_id, use_cache=use_cache)
         state = None
         try:
             state = response["QueryExecution"]["Status"]["State"]
@@ -143,7 +160,7 @@ class AthenaHook(AwsBaseHook):
             # The error is being absorbed to implement retries.
             return state
 
-    def get_state_change_reason(self, query_execution_id: str) -> str | None:
+    def get_state_change_reason(self, query_execution_id: str, use_cache: bool = False) -> str | None:
         """
         Fetch the reason for a state change (e.g. error message). Returns None or reason string.
 
@@ -152,7 +169,7 @@ class AthenaHook(AwsBaseHook):
 
         :param query_execution_id: Id of submitted athena query
         """
-        response = self.get_conn().get_query_execution(QueryExecutionId=query_execution_id)
+        response = self.get_query_info(query_execution_id=query_execution_id, use_cache=use_cache)
         reason = None
         try:
             reason = response["QueryExecution"]["Status"]["StateChangeReason"]
@@ -277,7 +294,7 @@ class AthenaHook(AwsBaseHook):
         """
         output_location = None
         if query_execution_id:
-            response = self.get_conn().get_query_execution(QueryExecutionId=query_execution_id)
+            response = self.get_query_info(query_execution_id=query_execution_id, use_cache=True)
 
             if response:
                 try:
diff --git a/airflow/providers/amazon/aws/operators/athena.py b/airflow/providers/amazon/aws/operators/athena.py
index 1f9e527717..3bc907227d 100644
--- a/airflow/providers/amazon/aws/operators/athena.py
+++ b/airflow/providers/amazon/aws/operators/athena.py
@@ -18,6 +18,7 @@
 from __future__ import annotations
 
 from typing import TYPE_CHECKING, Any, Sequence
+from urllib.parse import urlparse
 
 from airflow.configuration import conf
 from airflow.exceptions import AirflowException
@@ -27,6 +28,10 @@ from airflow.providers.amazon.aws.triggers.athena import AthenaTrigger
 from airflow.providers.amazon.aws.utils.mixins import aws_template_fields
 
 if TYPE_CHECKING:
+    from openlineage.client.facet import BaseFacet
+    from openlineage.client.run import Dataset
+
+    from airflow.providers.openlineage.extractors.base import OperatorLineage
     from airflow.utils.context import Context
 
 
@@ -160,6 +165,9 @@ class AthenaOperator(AwsBaseOperator[AthenaHook]):
                 f"query_execution_id is {self.query_execution_id}."
             )
 
+        # Save output location from API response for later use in OpenLineage.
+        self.output_location = self.hook.get_output_location(self.query_execution_id)
+
         return self.query_execution_id
 
     def execute_complete(self, context, event=None):
@@ -187,3 +195,110 @@ class AthenaOperator(AwsBaseOperator[AthenaHook]):
                         "Polling Athena for query with id %s to reach final state", self.query_execution_id
                     )
                     self.hook.poll_query_status(self.query_execution_id, sleep_time=self.sleep_time)
+
+    def get_openlineage_facets_on_start(self) -> OperatorLineage:
+        """Retrieve OpenLineage data by parsing SQL queries and enriching them with Athena API.
+
+        In addition to CTAS query, query and calculation results are stored in S3 location.
+        For that reason additional output is attached with this location.
+        """
+        from openlineage.client.facet import ExtractionError, ExtractionErrorRunFacet, SqlJobFacet
+        from openlineage.client.run import Dataset
+
+        from airflow.providers.openlineage.extractors.base import OperatorLineage
+        from airflow.providers.openlineage.sqlparser import SQLParser
+
+        sql_parser = SQLParser(dialect="generic")
+
+        job_facets: dict[str, BaseFacet] = {"sql": SqlJobFacet(query=sql_parser.normalize_sql(self.query))}
+        parse_result = sql_parser.parse(sql=self.query)
+
+        if not parse_result:
+            return OperatorLineage(job_facets=job_facets)
+
+        run_facets: dict[str, BaseFacet] = {}
+        if parse_result.errors:
+            run_facets["extractionError"] = ExtractionErrorRunFacet(
+                totalTasks=len(self.query) if isinstance(self.query, list) else 1,
+                failedTasks=len(parse_result.errors),
+                errors=[
+                    ExtractionError(
+                        errorMessage=error.message,
+                        stackTrace=None,
+                        task=error.origin_statement,
+                        taskNumber=error.index,
+                    )
+                    for error in parse_result.errors
+                ],
+            )
+
+        inputs: list[Dataset] = list(
+            filter(
+                None,
+                [
+                    self.get_openlineage_dataset(table.schema or self.database, table.name)
+                    for table in parse_result.in_tables
+                ],
+            )
+        )
+
+        outputs: list[Dataset] = list(
+            filter(
+                None,
+                [
+                    self.get_openlineage_dataset(table.schema or self.database, table.name)
+                    for table in parse_result.out_tables
+                ],
+            )
+        )
+
+        if self.output_location:
+            parsed = urlparse(self.output_location)
+            outputs.append(Dataset(namespace=f"{parsed.scheme}://{parsed.netloc}", name=parsed.path))
+
+        return OperatorLineage(job_facets=job_facets, run_facets=run_facets, inputs=inputs, outputs=outputs)
+
+    def get_openlineage_dataset(self, database, table) -> Dataset | None:
+        from openlineage.client.facet import (
+            SchemaDatasetFacet,
+            SchemaField,
+            SymlinksDatasetFacet,
+            SymlinksDatasetFacetIdentifiers,
+        )
+        from openlineage.client.run import Dataset
+
+        client = self.hook.get_conn()
+        try:
+            table_metadata = client.get_table_metadata(
+                CatalogName=self.catalog, DatabaseName=database, TableName=table
+            )
+
+            # Dataset has also its' physical location which we can add in symlink facet.
+            s3_location = table_metadata["TableMetadata"]["Parameters"]["location"]
+            parsed_path = urlparse(s3_location)
+            facets: dict[str, BaseFacet] = {
+                "symlinks": SymlinksDatasetFacet(
+                    identifiers=[
+                        SymlinksDatasetFacetIdentifiers(
+                            namespace=f"{parsed_path.scheme}://{parsed_path.netloc}",
+                            name=str(parsed_path.path),
+                            type="TABLE",
+                        )
+                    ]
+                )
+            }
+            fields = [
+                SchemaField(name=column["Name"], type=column["Type"], description=column["Comment"])
+                for column in table_metadata["TableMetadata"]["Columns"]
+            ]
+            if fields:
+                facets["schema"] = SchemaDatasetFacet(fields=fields)
+            return Dataset(
+                namespace=f"awsathena://athena.{self.hook.region_name}.amazonaws.com",
+                name=".".join(filter(None, (self.catalog, database, table))),
+                facets=facets,
+            )
+
+        except Exception as e:
+            self.log.error("Cannot retrieve table metadata from Athena.Client. %s", e)
+            return None
diff --git a/tests/providers/amazon/aws/hooks/test_athena.py b/tests/providers/amazon/aws/hooks/test_athena.py
index 05ed6e9e30..8f224f0b2d 100644
--- a/tests/providers/amazon/aws/hooks/test_athena.py
+++ b/tests/providers/amazon/aws/hooks/test_athena.py
@@ -43,6 +43,7 @@ MOCK_QUERY_EXECUTION_OUTPUT = {
     "QueryExecution": {
         "QueryExecutionId": MOCK_DATA["query_execution_id"],
         "ResultConfiguration": {"OutputLocation": "s3://test_bucket/test.csv"},
+        "Status": {"StateChangeReason": "Terminated by user."},
     }
 }
 
@@ -195,3 +196,21 @@ class TestAthenaHook:
         mock_conn.return_value.get_query_execution.return_value = MOCK_QUERY_EXECUTION_OUTPUT
         result = self.athena.get_output_location(query_execution_id=MOCK_DATA["query_execution_id"])
         assert result == "s3://test_bucket/test.csv"
+
+    @mock.patch.object(AthenaHook, "get_conn")
+    def test_hook_get_query_info_caching(self, mock_conn):
+        mock_conn.return_value.get_query_execution.return_value = MOCK_QUERY_EXECUTION_OUTPUT
+        self.athena.get_state_change_reason(query_execution_id=MOCK_DATA["query_execution_id"])
+        assert not self.athena._AthenaHook__query_results
+        # get_output_location uses cache
+        self.athena.get_output_location(query_execution_id=MOCK_DATA["query_execution_id"])
+        assert MOCK_DATA["query_execution_id"] in self.athena._AthenaHook__query_results
+        mock_conn.return_value.get_query_execution.assert_called_with(
+            QueryExecutionId=MOCK_DATA["query_execution_id"]
+        )
+        self.athena.get_state_change_reason(
+            query_execution_id=MOCK_DATA["query_execution_id"], use_cache=False
+        )
+        mock_conn.return_value.get_query_execution.assert_called_with(
+            QueryExecutionId=MOCK_DATA["query_execution_id"]
+        )
diff --git a/tests/providers/amazon/aws/operators/athena_metadata.json b/tests/providers/amazon/aws/operators/athena_metadata.json
new file mode 100644
index 0000000000..f13b124174
--- /dev/null
+++ b/tests/providers/amazon/aws/operators/athena_metadata.json
@@ -0,0 +1,72 @@
+{
+    "DISCOUNTS": {
+      "TableMetadata": {
+        "Name": "DISCOUNTS",
+        "CreateTime": 1593559968.0,
+        "LastAccessTime": 0.0,
+        "TableType": "EXTERNAL_TABLE",
+        "Columns": [
+          {
+            "Name": "ID",
+            "Type": "int",
+            "Comment": "from deserializer"
+          },
+          {
+            "Name": "AMOUNT_OFF",
+            "Type": "int",
+            "Comment": "from deserializer"
+          },
+          {
+            "Name": "CUSTOMER_EMAIL",
+            "Type": "varchar",
+            "Comment": "from deserializer"
+          },
+          {
+            "Name": "STARTS_ON",
+            "Type": "timestamp",
+            "Comment": "from deserializer"
+          },
+          {
+            "Name": "ENDS_ON",
+            "Type": "timestamp",
+            "Comment": "from deserializer"
+          }
+        ],
+        "PartitionKeys": [],
+        "Parameters": {
+          "EXTERNAL": "TRUE",
+          "inputformat": "com.esri.json.hadoop.EnclosedJsonInputFormat",
+          "location": "s3://bucket/discount/data/path/",
+          "outputformat": "org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat",
+          "serde.param.serialization.format": "1",
+          "serde.serialization.lib": "com.esri.hadoop.hive.serde.JsonSerde",
+          "transient_lastDdlTime": "1593559968"
+        }
+      }
+    },
+    "TEST_TABLE": {
+      "TableMetadata": {
+        "Name": "TEST_TABLE",
+        "CreateTime": 1593559968.0,
+        "LastAccessTime": 0.0,
+        "TableType": "EXTERNAL_TABLE",
+        "Columns": [
+          {
+            "Name": "column",
+            "Type": "string",
+            "Comment": "from deserializer"
+          }
+        ],
+        "PartitionKeys": [],
+        "Parameters": {
+          "EXTERNAL": "TRUE",
+          "inputformat": "com.esri.json.hadoop.EnclosedJsonInputFormat",
+          "location": "s3://bucket/data/test_table/data/path",
+          "outputformat": "org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat",
+          "serde.param.serialization.format": "1",
+          "serde.serialization.lib": "com.esri.hadoop.hive.serde.JsonSerde",
+          "transient_lastDdlTime": "1593559968"
+        }
+      }
+    }
+  }
diff --git a/tests/providers/amazon/aws/operators/test_athena.py b/tests/providers/amazon/aws/operators/test_athena.py
index 497e2c2127..5820827b05 100644
--- a/tests/providers/amazon/aws/operators/test_athena.py
+++ b/tests/providers/amazon/aws/operators/test_athena.py
@@ -16,15 +16,25 @@
 # under the License.
 from __future__ import annotations
 
+import json
 from unittest import mock
 
 import pytest
+from openlineage.client.facet import (
+    SchemaDatasetFacet,
+    SchemaField,
+    SqlJobFacet,
+    SymlinksDatasetFacet,
+    SymlinksDatasetFacetIdentifiers,
+)
+from openlineage.client.run import Dataset
 
 from airflow.exceptions import TaskDeferred
 from airflow.models import DAG, DagRun, TaskInstance
 from airflow.providers.amazon.aws.hooks.athena import AthenaHook
 from airflow.providers.amazon.aws.operators.athena import AthenaOperator
 from airflow.providers.amazon.aws.triggers.athena import AthenaTrigger
+from airflow.providers.openlineage.extractors import OperatorLineage
 from airflow.utils import timezone
 from airflow.utils.timezone import datetime
 
@@ -150,7 +160,11 @@ class TestAthenaOperator:
     @mock.patch.object(AthenaHook, "run_query", return_value=ATHENA_QUERY_ID)
     @mock.patch.object(AthenaHook, "get_conn")
     def test_hook_run_failure_query(
-        self, mock_conn, mock_run_query, mock_check_query_status, mock_get_state_change_reason
+        self,
+        mock_conn,
+        mock_run_query,
+        mock_check_query_status,
+        mock_get_state_change_reason,
     ):
         with pytest.raises(Exception):
             self.athena.execute({})
@@ -226,3 +240,107 @@ class TestAthenaOperator:
             self.athena.execute(None)
 
         assert isinstance(deferred.value.trigger, AthenaTrigger)
+
+    @mock.patch.object(AthenaHook, "region_name", new_callable=mock.PropertyMock)
+    @mock.patch.object(AthenaHook, "get_conn")
+    def test_operator_openlineage_data(self, mock_conn, mock_region_name):
+        mock_region_name.return_value = "eu-west-1"
+
+        def mock_get_table_metadata(CatalogName, DatabaseName, TableName):
+            with 
open("tests/providers/amazon/aws/operators/athena_metadata.json") as f:
+                return json.load(f)[TableName]
+
+        mock_conn.return_value.get_table_metadata = mock_get_table_metadata
+
+        op = AthenaOperator(
+            task_id="test_athena_openlineage",
+            query="INSERT INTO TEST_TABLE SELECT CUSTOMER_EMAIL FROM DISCOUNTS",
+            database="TEST_DATABASE",
+            output_location="s3://test_s3_bucket/",
+            client_request_token="eac427d0-1c6d-4dfb-96aa-2835d3ac6595",
+            sleep_time=0,
+            max_polling_attempts=3,
+            dag=self.dag,
+        )
+
+        expected_lineage = OperatorLineage(
+            inputs=[
+                Dataset(
+                    namespace="awsathena://athena.eu-west-1.amazonaws.com",
+                    name="AwsDataCatalog.TEST_DATABASE.DISCOUNTS",
+                    facets={
+                        "symlinks": SymlinksDatasetFacet(
+                            identifiers=[
+                                SymlinksDatasetFacetIdentifiers(
+                                    namespace="s3://bucket",
+                                    name="/discount/data/path/",
+                                    type="TABLE",
+                                )
+                            ],
+                        ),
+                        "schema": SchemaDatasetFacet(
+                            fields=[
+                                SchemaField(
+                                    name="ID",
+                                    type="int",
+                                    description="from deserializer",
+                                ),
+                                SchemaField(
+                                    name="AMOUNT_OFF",
+                                    type="int",
+                                    description="from deserializer",
+                                ),
+                                SchemaField(
+                                    name="CUSTOMER_EMAIL",
+                                    type="varchar",
+                                    description="from deserializer",
+                                ),
+                                SchemaField(
+                                    name="STARTS_ON",
+                                    type="timestamp",
+                                    description="from deserializer",
+                                ),
+                                SchemaField(
+                                    name="ENDS_ON",
+                                    type="timestamp",
+                                    description="from deserializer",
+                                ),
+                            ],
+                        ),
+                    },
+                )
+            ],
+            outputs=[
+                Dataset(
+                    namespace="awsathena://athena.eu-west-1.amazonaws.com",
+                    name="AwsDataCatalog.TEST_DATABASE.TEST_TABLE",
+                    facets={
+                        "symlinks": SymlinksDatasetFacet(
+                            identifiers=[
+                                SymlinksDatasetFacetIdentifiers(
+                                    namespace="s3://bucket",
+                                    name="/data/test_table/data/path",
+                                    type="TABLE",
+                                )
+                            ],
+                        ),
+                        "schema": SchemaDatasetFacet(
+                            fields=[
+                                SchemaField(
+                                    name="column",
+                                    type="string",
+                                    description="from deserializer",
+                                )
+                            ],
+                        ),
+                    },
+                ),
+                Dataset(namespace="s3://test_s3_bucket", name="/"),
+            ],
+            job_facets={
+                "sql": SqlJobFacet(
+                    query="INSERT INTO TEST_TABLE SELECT CUSTOMER_EMAIL FROM DISCOUNTS",
+                )
+            },
+        )
+        assert op.get_openlineage_facets_on_start() == expected_lineage

Reply via email to