This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new b7d0bf9800 fix OpenLineage extraction for AthenaOperator (#40545)
b7d0bf9800 is described below
commit b7d0bf9800974e2029a777e20417e3498e665503
Author: Kacper Muda <[email protected]>
AuthorDate: Thu Jul 4 11:15:26 2024 +0200
fix OpenLineage extraction for AthenaOperator (#40545)
Signed-off-by: Kacper Muda <[email protected]>
---
airflow/providers/amazon/aws/operators/athena.py | 26 ++++++++++++++++------
.../providers/amazon/aws/operators/test_athena.py | 23 ++++++++++++++++++-
2 files changed, 41 insertions(+), 8 deletions(-)
diff --git a/airflow/providers/amazon/aws/operators/athena.py b/airflow/providers/amazon/aws/operators/athena.py
index 5d30b93143..0178d60a12 100644
--- a/airflow/providers/amazon/aws/operators/athena.py
+++ b/airflow/providers/amazon/aws/operators/athena.py
@@ -175,9 +175,6 @@ class AthenaOperator(AwsBaseOperator[AthenaHook]):
f"query_execution_id is {self.query_execution_id}."
)
- # Save output location from API response for later use in OpenLineage.
- self.output_location = self.hook.get_output_location(self.query_execution_id)
-
return self.query_execution_id
def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> str:
@@ -185,6 +182,9 @@ class AthenaOperator(AwsBaseOperator[AthenaHook]):
if event["status"] != "success":
raise AirflowException(f"Error while waiting for operation on cluster to complete: {event}")
+
+ # Save query_execution_id to be later used by listeners
+ self.query_execution_id = event["value"]
return event["value"]
def on_kill(self) -> None:
@@ -208,14 +208,21 @@ class AthenaOperator(AwsBaseOperator[AthenaHook]):
)
self.hook.poll_query_status(self.query_execution_id, sleep_time=self.sleep_time)
- def get_openlineage_facets_on_start(self) -> OperatorLineage:
+ def get_openlineage_facets_on_complete(self, _) -> OperatorLineage:
"""
Retrieve OpenLineage data by parsing SQL queries and enriching them with Athena API.
In addition to CTAS query, query and calculation results are stored in S3 location.
- For that reason additional output is attached with this location.
+ For that reason additional output is attached with this location. Instead of using the complete
+ path where the results are saved (user's prefix + some UUID), we are creating a dataset with the
+ user-provided path only. This should make it easier to match this dataset across different processes.
"""
- from openlineage.client.facet import ExtractionError, ExtractionErrorRunFacet, SqlJobFacet
+ from openlineage.client.facet import (
+ ExternalQueryRunFacet,
+ ExtractionError,
+ ExtractionErrorRunFacet,
+ SqlJobFacet,
+ )
from openlineage.client.run import Dataset
from airflow.providers.openlineage.extractors.base import OperatorLineage
@@ -265,6 +272,11 @@ class AthenaOperator(AwsBaseOperator[AthenaHook]):
)
)
+ if self.query_execution_id:
+ run_facets["externalQuery"] = ExternalQueryRunFacet(
+ externalQueryId=self.query_execution_id, source="awsathena"
+ )
+
if self.output_location:
parsed = urlparse(self.output_location)
outputs.append(Dataset(namespace=f"{parsed.scheme}://{parsed.netloc}", name=parsed.path or "/"))
@@ -301,7 +313,7 @@ class AthenaOperator(AwsBaseOperator[AthenaHook]):
)
}
fields = [
- SchemaField(name=column["Name"], type=column["Type"], description=column["Comment"])
+ SchemaField(name=column["Name"], type=column["Type"], description=column.get("Comment"))
for column in table_metadata["TableMetadata"]["Columns"]
]
if fields:
diff --git a/tests/providers/amazon/aws/operators/test_athena.py b/tests/providers/amazon/aws/operators/test_athena.py
index 66fb6b297f..5d5a6b88c3 100644
--- a/tests/providers/amazon/aws/operators/test_athena.py
+++ b/tests/providers/amazon/aws/operators/test_athena.py
@@ -21,6 +21,7 @@ from unittest import mock
import pytest
from openlineage.client.facet import (
+ ExternalQueryRunFacet,
SchemaDatasetFacet,
SchemaField,
SqlJobFacet,
@@ -264,6 +265,24 @@ class TestAthenaOperator:
query_execution_id=ATHENA_QUERY_ID,
)
+ def test_execute_complete_reassigns_query_execution_id_after_deferring(self):
+ """Assert that we use query_execution_id from event after deferral."""
+
+ operator = AthenaOperator(
+ task_id="test_athena_operator",
+ query="SELECT * FROM TEST_TABLE",
+ database="TEST_DATABASE",
+ deferrable=True,
+ )
+ assert operator.query_execution_id is None
+
+ query_execution_id = "123456"
+ operator.execute_complete(
+ context=None,
+ event={"status": "success", "value": query_execution_id},
+ )
+ assert operator.query_execution_id == query_execution_id
+
@mock.patch.object(AthenaHook, "region_name", new_callable=mock.PropertyMock)
@mock.patch.object(AthenaHook, "get_conn")
def test_operator_openlineage_data(self, mock_conn, mock_region_name):
@@ -285,6 +304,7 @@ class TestAthenaOperator:
max_polling_attempts=3,
dag=self.dag,
)
+ op.query_execution_id = "12345" # Mocking what will be available after execution
expected_lineage = OperatorLineage(
inputs=[
@@ -365,5 +385,6 @@ class TestAthenaOperator:
query="INSERT INTO TEST_TABLE SELECT CUSTOMER_EMAIL FROM DISCOUNTS",
)
},
+ run_facets={"externalQuery": ExternalQueryRunFacet(externalQueryId="12345", source="awsathena")},
)
- assert op.get_openlineage_facets_on_start() == expected_lineage
+ assert op.get_openlineage_facets_on_complete(None) == expected_lineage