This is an automated email from the ASF dual-hosted git repository.
mobuchowski pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 5626590406 openlineage, aws: Add OpenLineage support for
AthenaOperator. (#35090)
5626590406 is described below
commit 56265904062960b681dc1d5237518b3e76b87296
Author: Jakub Dardzinski <[email protected]>
AuthorDate: Tue Nov 14 19:43:55 2023 +0100
openlineage, aws: Add OpenLineage support for AthenaOperator. (#35090)
* Add OpenLineage support for AthenaOperator.
Based on: https://github.com/OpenLineage/OpenLineage/pull/1328.
Signed-off-by: Jakub Dardzinski <[email protected]>
* Cache `get_query_execution` in AthenaHook.
Adjust code to catalog and output_location changes.
Signed-off-by: Jakub Dardzinski <[email protected]>
* Change caching implementation.
Rename `get_query_execution` to `get_query_info`.
Signed-off-by: Jakub Dardzinski <[email protected]>
---------
Signed-off-by: Jakub Dardzinski <[email protected]>
---
airflow/providers/amazon/aws/hooks/athena.py | 27 ++++-
airflow/providers/amazon/aws/operators/athena.py | 115 ++++++++++++++++++++
tests/providers/amazon/aws/hooks/test_athena.py | 19 ++++
.../amazon/aws/operators/athena_metadata.json | 72 +++++++++++++
.../providers/amazon/aws/operators/test_athena.py | 120 ++++++++++++++++++++-
5 files changed, 347 insertions(+), 6 deletions(-)
diff --git a/airflow/providers/amazon/aws/hooks/athena.py
b/airflow/providers/amazon/aws/hooks/athena.py
index 11148bd6b1..c8af86ee3a 100644
--- a/airflow/providers/amazon/aws/hooks/athena.py
+++ b/airflow/providers/amazon/aws/hooks/athena.py
@@ -82,6 +82,7 @@ class AthenaHook(AwsBaseHook):
else:
self.sleep_time = 30 # previous default value
self.log_query = log_query
+ self.__query_results: dict[str, Any] = {}
def run_query(
self,
@@ -120,7 +121,23 @@ class AthenaHook(AwsBaseHook):
self.log.info("Query execution id: %s", query_execution_id)
return query_execution_id
- def check_query_status(self, query_execution_id: str) -> str | None:
+ def get_query_info(self, query_execution_id: str, use_cache: bool = False)
-> dict:
+ """Get information about a single execution of a query.
+
+ .. seealso::
+ - :external+boto3:py:meth:`Athena.Client.get_query_execution`
+
+ :param query_execution_id: Id of submitted athena query
+ :param use_cache: If True, use execution information cache
+ """
+ if use_cache and query_execution_id in self.__query_results:
+ return self.__query_results[query_execution_id]
+ response =
self.get_conn().get_query_execution(QueryExecutionId=query_execution_id)
+ if use_cache:
+ self.__query_results[query_execution_id] = response
+ return response
+
+ def check_query_status(self, query_execution_id: str, use_cache: bool =
False) -> str | None:
"""Fetch the state of a submitted query.
.. seealso::
@@ -130,7 +147,7 @@ class AthenaHook(AwsBaseHook):
:return: One of valid query states, or *None* if the response is
malformed.
"""
- response =
self.get_conn().get_query_execution(QueryExecutionId=query_execution_id)
+ response = self.get_query_info(query_execution_id=query_execution_id,
use_cache=use_cache)
state = None
try:
state = response["QueryExecution"]["Status"]["State"]
@@ -143,7 +160,7 @@ class AthenaHook(AwsBaseHook):
# The error is being absorbed to implement retries.
return state
- def get_state_change_reason(self, query_execution_id: str) -> str | None:
+ def get_state_change_reason(self, query_execution_id: str, use_cache: bool
= False) -> str | None:
"""
Fetch the reason for a state change (e.g. error message). Returns None
or reason string.
@@ -152,7 +169,7 @@ class AthenaHook(AwsBaseHook):
:param query_execution_id: Id of submitted athena query
"""
- response =
self.get_conn().get_query_execution(QueryExecutionId=query_execution_id)
+ response = self.get_query_info(query_execution_id=query_execution_id,
use_cache=use_cache)
reason = None
try:
reason = response["QueryExecution"]["Status"]["StateChangeReason"]
@@ -277,7 +294,7 @@ class AthenaHook(AwsBaseHook):
"""
output_location = None
if query_execution_id:
- response =
self.get_conn().get_query_execution(QueryExecutionId=query_execution_id)
+ response =
self.get_query_info(query_execution_id=query_execution_id, use_cache=True)
if response:
try:
diff --git a/airflow/providers/amazon/aws/operators/athena.py
b/airflow/providers/amazon/aws/operators/athena.py
index 1f9e527717..3bc907227d 100644
--- a/airflow/providers/amazon/aws/operators/athena.py
+++ b/airflow/providers/amazon/aws/operators/athena.py
@@ -18,6 +18,7 @@
from __future__ import annotations
from typing import TYPE_CHECKING, Any, Sequence
+from urllib.parse import urlparse
from airflow.configuration import conf
from airflow.exceptions import AirflowException
@@ -27,6 +28,10 @@ from airflow.providers.amazon.aws.triggers.athena import
AthenaTrigger
from airflow.providers.amazon.aws.utils.mixins import aws_template_fields
if TYPE_CHECKING:
+ from openlineage.client.facet import BaseFacet
+ from openlineage.client.run import Dataset
+
+ from airflow.providers.openlineage.extractors.base import OperatorLineage
from airflow.utils.context import Context
@@ -160,6 +165,9 @@ class AthenaOperator(AwsBaseOperator[AthenaHook]):
f"query_execution_id is {self.query_execution_id}."
)
+ # Save output location from API response for later use in OpenLineage.
+ self.output_location =
self.hook.get_output_location(self.query_execution_id)
+
return self.query_execution_id
def execute_complete(self, context, event=None):
@@ -187,3 +195,110 @@ class AthenaOperator(AwsBaseOperator[AthenaHook]):
"Polling Athena for query with id %s to reach final
state", self.query_execution_id
)
self.hook.poll_query_status(self.query_execution_id,
sleep_time=self.sleep_time)
+
+ def get_openlineage_facets_on_start(self) -> OperatorLineage:
+ """Retrieve OpenLineage data by parsing SQL queries and enriching them
with Athena API.
+
+ In addition to CTAS query, query and calculation results are stored in
S3 location.
+        For that reason, an additional output is attached with this
location.
+ from openlineage.client.facet import ExtractionError,
ExtractionErrorRunFacet, SqlJobFacet
+ from openlineage.client.run import Dataset
+
+ from airflow.providers.openlineage.extractors.base import
OperatorLineage
+ from airflow.providers.openlineage.sqlparser import SQLParser
+
+ sql_parser = SQLParser(dialect="generic")
+
+ job_facets: dict[str, BaseFacet] = {"sql":
SqlJobFacet(query=sql_parser.normalize_sql(self.query))}
+ parse_result = sql_parser.parse(sql=self.query)
+
+ if not parse_result:
+ return OperatorLineage(job_facets=job_facets)
+
+ run_facets: dict[str, BaseFacet] = {}
+ if parse_result.errors:
+ run_facets["extractionError"] = ExtractionErrorRunFacet(
+ totalTasks=len(self.query) if isinstance(self.query, list)
else 1,
+ failedTasks=len(parse_result.errors),
+ errors=[
+ ExtractionError(
+ errorMessage=error.message,
+ stackTrace=None,
+ task=error.origin_statement,
+ taskNumber=error.index,
+ )
+ for error in parse_result.errors
+ ],
+ )
+
+ inputs: list[Dataset] = list(
+ filter(
+ None,
+ [
+ self.get_openlineage_dataset(table.schema or
self.database, table.name)
+ for table in parse_result.in_tables
+ ],
+ )
+ )
+
+ outputs: list[Dataset] = list(
+ filter(
+ None,
+ [
+ self.get_openlineage_dataset(table.schema or
self.database, table.name)
+ for table in parse_result.out_tables
+ ],
+ )
+ )
+
+ if self.output_location:
+ parsed = urlparse(self.output_location)
+
outputs.append(Dataset(namespace=f"{parsed.scheme}://{parsed.netloc}",
name=parsed.path))
+
+ return OperatorLineage(job_facets=job_facets, run_facets=run_facets,
inputs=inputs, outputs=outputs)
+
+ def get_openlineage_dataset(self, database, table) -> Dataset | None:
+ from openlineage.client.facet import (
+ SchemaDatasetFacet,
+ SchemaField,
+ SymlinksDatasetFacet,
+ SymlinksDatasetFacetIdentifiers,
+ )
+ from openlineage.client.run import Dataset
+
+ client = self.hook.get_conn()
+ try:
+ table_metadata = client.get_table_metadata(
+ CatalogName=self.catalog, DatabaseName=database,
TableName=table
+ )
+
+            # The dataset also has its physical location, which we can add in a
symlink facet.
+ s3_location =
table_metadata["TableMetadata"]["Parameters"]["location"]
+ parsed_path = urlparse(s3_location)
+ facets: dict[str, BaseFacet] = {
+ "symlinks": SymlinksDatasetFacet(
+ identifiers=[
+ SymlinksDatasetFacetIdentifiers(
+
namespace=f"{parsed_path.scheme}://{parsed_path.netloc}",
+ name=str(parsed_path.path),
+ type="TABLE",
+ )
+ ]
+ )
+ }
+ fields = [
+ SchemaField(name=column["Name"], type=column["Type"],
description=column["Comment"])
+ for column in table_metadata["TableMetadata"]["Columns"]
+ ]
+ if fields:
+ facets["schema"] = SchemaDatasetFacet(fields=fields)
+ return Dataset(
+
namespace=f"awsathena://athena.{self.hook.region_name}.amazonaws.com",
+ name=".".join(filter(None, (self.catalog, database, table))),
+ facets=facets,
+ )
+
+ except Exception as e:
+ self.log.error("Cannot retrieve table metadata from Athena.Client.
%s", e)
+ return None
diff --git a/tests/providers/amazon/aws/hooks/test_athena.py
b/tests/providers/amazon/aws/hooks/test_athena.py
index 05ed6e9e30..8f224f0b2d 100644
--- a/tests/providers/amazon/aws/hooks/test_athena.py
+++ b/tests/providers/amazon/aws/hooks/test_athena.py
@@ -43,6 +43,7 @@ MOCK_QUERY_EXECUTION_OUTPUT = {
"QueryExecution": {
"QueryExecutionId": MOCK_DATA["query_execution_id"],
"ResultConfiguration": {"OutputLocation": "s3://test_bucket/test.csv"},
+ "Status": {"StateChangeReason": "Terminated by user."},
}
}
@@ -195,3 +196,21 @@ class TestAthenaHook:
mock_conn.return_value.get_query_execution.return_value =
MOCK_QUERY_EXECUTION_OUTPUT
result =
self.athena.get_output_location(query_execution_id=MOCK_DATA["query_execution_id"])
assert result == "s3://test_bucket/test.csv"
+
+ @mock.patch.object(AthenaHook, "get_conn")
+ def test_hook_get_query_info_caching(self, mock_conn):
+ mock_conn.return_value.get_query_execution.return_value =
MOCK_QUERY_EXECUTION_OUTPUT
+
self.athena.get_state_change_reason(query_execution_id=MOCK_DATA["query_execution_id"])
+ assert not self.athena._AthenaHook__query_results
+ # get_output_location uses cache
+
self.athena.get_output_location(query_execution_id=MOCK_DATA["query_execution_id"])
+ assert MOCK_DATA["query_execution_id"] in
self.athena._AthenaHook__query_results
+ mock_conn.return_value.get_query_execution.assert_called_with(
+ QueryExecutionId=MOCK_DATA["query_execution_id"]
+ )
+ self.athena.get_state_change_reason(
+ query_execution_id=MOCK_DATA["query_execution_id"], use_cache=False
+ )
+ mock_conn.return_value.get_query_execution.assert_called_with(
+ QueryExecutionId=MOCK_DATA["query_execution_id"]
+ )
diff --git a/tests/providers/amazon/aws/operators/athena_metadata.json
b/tests/providers/amazon/aws/operators/athena_metadata.json
new file mode 100644
index 0000000000..f13b124174
--- /dev/null
+++ b/tests/providers/amazon/aws/operators/athena_metadata.json
@@ -0,0 +1,72 @@
+{
+ "DISCOUNTS": {
+ "TableMetadata": {
+ "Name": "DISCOUNTS",
+ "CreateTime": 1593559968.0,
+ "LastAccessTime": 0.0,
+ "TableType": "EXTERNAL_TABLE",
+ "Columns": [
+ {
+ "Name": "ID",
+ "Type": "int",
+ "Comment": "from deserializer"
+ },
+ {
+ "Name": "AMOUNT_OFF",
+ "Type": "int",
+ "Comment": "from deserializer"
+ },
+ {
+ "Name": "CUSTOMER_EMAIL",
+ "Type": "varchar",
+ "Comment": "from deserializer"
+ },
+ {
+ "Name": "STARTS_ON",
+ "Type": "timestamp",
+ "Comment": "from deserializer"
+ },
+ {
+ "Name": "ENDS_ON",
+ "Type": "timestamp",
+ "Comment": "from deserializer"
+ }
+ ],
+ "PartitionKeys": [],
+ "Parameters": {
+ "EXTERNAL": "TRUE",
+ "inputformat": "com.esri.json.hadoop.EnclosedJsonInputFormat",
+ "location": "s3://bucket/discount/data/path/",
+ "outputformat":
"org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat",
+ "serde.param.serialization.format": "1",
+ "serde.serialization.lib": "com.esri.hadoop.hive.serde.JsonSerde",
+ "transient_lastDdlTime": "1593559968"
+ }
+ }
+ },
+ "TEST_TABLE": {
+ "TableMetadata": {
+ "Name": "TEST_TABLE",
+ "CreateTime": 1593559968.0,
+ "LastAccessTime": 0.0,
+ "TableType": "EXTERNAL_TABLE",
+ "Columns": [
+ {
+ "Name": "column",
+ "Type": "string",
+ "Comment": "from deserializer"
+ }
+ ],
+ "PartitionKeys": [],
+ "Parameters": {
+ "EXTERNAL": "TRUE",
+ "inputformat": "com.esri.json.hadoop.EnclosedJsonInputFormat",
+ "location": "s3://bucket/data/test_table/data/path",
+ "outputformat":
"org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat",
+ "serde.param.serialization.format": "1",
+ "serde.serialization.lib": "com.esri.hadoop.hive.serde.JsonSerde",
+ "transient_lastDdlTime": "1593559968"
+ }
+ }
+ }
+ }
diff --git a/tests/providers/amazon/aws/operators/test_athena.py
b/tests/providers/amazon/aws/operators/test_athena.py
index 497e2c2127..5820827b05 100644
--- a/tests/providers/amazon/aws/operators/test_athena.py
+++ b/tests/providers/amazon/aws/operators/test_athena.py
@@ -16,15 +16,25 @@
# under the License.
from __future__ import annotations
+import json
from unittest import mock
import pytest
+from openlineage.client.facet import (
+ SchemaDatasetFacet,
+ SchemaField,
+ SqlJobFacet,
+ SymlinksDatasetFacet,
+ SymlinksDatasetFacetIdentifiers,
+)
+from openlineage.client.run import Dataset
from airflow.exceptions import TaskDeferred
from airflow.models import DAG, DagRun, TaskInstance
from airflow.providers.amazon.aws.hooks.athena import AthenaHook
from airflow.providers.amazon.aws.operators.athena import AthenaOperator
from airflow.providers.amazon.aws.triggers.athena import AthenaTrigger
+from airflow.providers.openlineage.extractors import OperatorLineage
from airflow.utils import timezone
from airflow.utils.timezone import datetime
@@ -150,7 +160,11 @@ class TestAthenaOperator:
@mock.patch.object(AthenaHook, "run_query", return_value=ATHENA_QUERY_ID)
@mock.patch.object(AthenaHook, "get_conn")
def test_hook_run_failure_query(
- self, mock_conn, mock_run_query, mock_check_query_status,
mock_get_state_change_reason
+ self,
+ mock_conn,
+ mock_run_query,
+ mock_check_query_status,
+ mock_get_state_change_reason,
):
with pytest.raises(Exception):
self.athena.execute({})
@@ -226,3 +240,107 @@ class TestAthenaOperator:
self.athena.execute(None)
assert isinstance(deferred.value.trigger, AthenaTrigger)
+
+ @mock.patch.object(AthenaHook, "region_name",
new_callable=mock.PropertyMock)
+ @mock.patch.object(AthenaHook, "get_conn")
+ def test_operator_openlineage_data(self, mock_conn, mock_region_name):
+ mock_region_name.return_value = "eu-west-1"
+
+ def mock_get_table_metadata(CatalogName, DatabaseName, TableName):
+ with
open("tests/providers/amazon/aws/operators/athena_metadata.json") as f:
+ return json.load(f)[TableName]
+
+ mock_conn.return_value.get_table_metadata = mock_get_table_metadata
+
+ op = AthenaOperator(
+ task_id="test_athena_openlineage",
+ query="INSERT INTO TEST_TABLE SELECT CUSTOMER_EMAIL FROM
DISCOUNTS",
+ database="TEST_DATABASE",
+ output_location="s3://test_s3_bucket/",
+ client_request_token="eac427d0-1c6d-4dfb-96aa-2835d3ac6595",
+ sleep_time=0,
+ max_polling_attempts=3,
+ dag=self.dag,
+ )
+
+ expected_lineage = OperatorLineage(
+ inputs=[
+ Dataset(
+ namespace="awsathena://athena.eu-west-1.amazonaws.com",
+ name="AwsDataCatalog.TEST_DATABASE.DISCOUNTS",
+ facets={
+ "symlinks": SymlinksDatasetFacet(
+ identifiers=[
+ SymlinksDatasetFacetIdentifiers(
+ namespace="s3://bucket",
+ name="/discount/data/path/",
+ type="TABLE",
+ )
+ ],
+ ),
+ "schema": SchemaDatasetFacet(
+ fields=[
+ SchemaField(
+ name="ID",
+ type="int",
+ description="from deserializer",
+ ),
+ SchemaField(
+ name="AMOUNT_OFF",
+ type="int",
+ description="from deserializer",
+ ),
+ SchemaField(
+ name="CUSTOMER_EMAIL",
+ type="varchar",
+ description="from deserializer",
+ ),
+ SchemaField(
+ name="STARTS_ON",
+ type="timestamp",
+ description="from deserializer",
+ ),
+ SchemaField(
+ name="ENDS_ON",
+ type="timestamp",
+ description="from deserializer",
+ ),
+ ],
+ ),
+ },
+ )
+ ],
+ outputs=[
+ Dataset(
+ namespace="awsathena://athena.eu-west-1.amazonaws.com",
+ name="AwsDataCatalog.TEST_DATABASE.TEST_TABLE",
+ facets={
+ "symlinks": SymlinksDatasetFacet(
+ identifiers=[
+ SymlinksDatasetFacetIdentifiers(
+ namespace="s3://bucket",
+ name="/data/test_table/data/path",
+ type="TABLE",
+ )
+ ],
+ ),
+ "schema": SchemaDatasetFacet(
+ fields=[
+ SchemaField(
+ name="column",
+ type="string",
+ description="from deserializer",
+ )
+ ],
+ ),
+ },
+ ),
+ Dataset(namespace="s3://test_s3_bucket", name="/"),
+ ],
+ job_facets={
+ "sql": SqlJobFacet(
+ query="INSERT INTO TEST_TABLE SELECT CUSTOMER_EMAIL FROM
DISCOUNTS",
+ )
+ },
+ )
+ assert op.get_openlineage_facets_on_start() == expected_lineage