This is an automated email from the ASF dual-hosted git repository.
mobuchowski pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 6594a2f0f45 Add OpenLineage support to AthenaSQLHook (#66844)
6594a2f0f45 is described below
commit 6594a2f0f45aea1bb233c3810a7ebb9861057553
Author: Rahul Madan <[email protected]>
AuthorDate: Mon May 18 14:12:10 2026 +0530
Add OpenLineage support to AthenaSQLHook (#66844)
* Add OpenLineage support to AthenaSQLHook
Signed-off-by: Rahul Madan <[email protected]>
* Added tests for athena sql hook
Signed-off-by: Rahul Madan <[email protected]>
* Address review: hook-constructor region wins + support aws_domain extra
Signed-off-by: Rahul Madan <[email protected]>
---------
Signed-off-by: Rahul Madan <[email protected]>
---
.../providers/amazon/aws/hooks/athena_sql.py | 31 ++++++++
.../tests/unit/amazon/aws/hooks/test_athena_sql.py | 93 ++++++++++++++++++++++
2 files changed, 124 insertions(+)
diff --git a/providers/amazon/src/airflow/providers/amazon/aws/hooks/athena_sql.py b/providers/amazon/src/airflow/providers/amazon/aws/hooks/athena_sql.py
index 94348612700..a9791aec246 100644
--- a/providers/amazon/src/airflow/providers/amazon/aws/hooks/athena_sql.py
+++ b/providers/amazon/src/airflow/providers/amazon/aws/hooks/athena_sql.py
@@ -177,6 +177,37 @@ class AthenaSQLHook(AwsBaseHook, DbApiHook):
aws_domain=self.conn.extra_dejson.get("aws_domain",
"amazonaws.com"),
)
+    def get_openlineage_database_info(self, connection):
+        """Return OpenLineage ``DatabaseInfo`` (scheme, authority, catalog) for an Athena connection."""
+        from airflow.providers.openlineage.sqlparser import DatabaseInfo
+        # Explicit hook region wins; the connection's ``region_name`` extra is only a fallback.
+        region_name = self.region_name or connection.extra_dejson.get("region_name")
+        # Look up aws_domain via ``self.conn``, as the hook's own connection params do, so both agree.
+        aws_domain = self.conn.extra_dejson.get("aws_domain", "amazonaws.com")
+        authority = f"athena.{region_name}.{aws_domain}" if region_name else f"athena.{aws_domain}"
+        return DatabaseInfo(
+            scheme="awsathena",
+            authority=authority,
+            information_schema_columns=[
+                "table_schema",
+                "table_name",
+                "column_name",
+                "ordinal_position",
+                "data_type",
+                "table_catalog",
+            ],
+            database=connection.extra_dejson.get("catalog", "AwsDataCatalog"),
+            is_information_schema_cross_db=True,
+        )
+
+    def get_openlineage_database_dialect(self, _) -> str:
+        """Return ``"trino"`` regardless of the argument; Athena uses the Trino SQL engine."""
+        return "trino"
+
+    def get_openlineage_default_schema(self) -> str | None:
+        """Return the hook connection's schema, or ``"default"`` when it is unset or empty."""
+        return self.conn.schema or "default"
+
def get_uri(self) -> str:
"""Overridden to use the Athena dialect as driver name."""
from airflow.providers.common.compat.sdk import
AirflowOptionalProviderFeatureException
diff --git a/providers/amazon/tests/unit/amazon/aws/hooks/test_athena_sql.py b/providers/amazon/tests/unit/amazon/aws/hooks/test_athena_sql.py
index fc6fe82737c..7b5cec5f7bc 100644
--- a/providers/amazon/tests/unit/amazon/aws/hooks/test_athena_sql.py
+++ b/providers/amazon/tests/unit/amazon/aws/hooks/test_athena_sql.py
@@ -181,3 +181,96 @@ class TestAthenaSQLHookConn:
assert hook._verify is False
assert hook._region_name == "us-west-2"
assert hook._config is not None
+
+
+class TestAthenaSQLHookOpenLineage:
+    """Unit tests for the AthenaSQLHook OpenLineage methods, with connection lookups mocked."""
+
+    EXPECTED_INFORMATION_SCHEMA_COLUMNS = [
+        "table_schema",
+        "table_name",
+        "column_name",
+        "ordinal_position",
+        "data_type",
+        "table_catalog",
+    ]
+
+    @staticmethod
+    def _make_hook(connection: Connection, hook_region: str | None = None) -> AthenaSQLHook:
+        hook = AthenaSQLHook(region_name=hook_region)  # None means no constructor region
+        hook.get_connection = mock.Mock(return_value=connection)  # type: ignore[method-assign]
+        return hook
+
+    @pytest.mark.parametrize(
+        ("extras", "hook_region", "expected_authority"),
+        [
+            # region from connection extras when hook-constructor region not set
+            ({"region_name": "us-east-1"}, None, "athena.us-east-1.amazonaws.com"),
+            # hook-constructor region (explicit user override) wins over extras region
+            ({"region_name": "eu-west-1"}, "us-east-2", "athena.us-east-2.amazonaws.com"),
+            # hook-constructor region used when extras have none
+            ({}, "ap-south-1", "athena.ap-south-1.amazonaws.com"),
+            # graceful fallback when neither is set
+            ({}, None, "athena.amazonaws.com"),
+            # aws_domain extra changes the domain (AWS GovCloud / China / ISO partitions)
+            (
+                {"region_name": "cn-north-1", "aws_domain": "amazonaws.com.cn"},
+                None,
+                "athena.cn-north-1.amazonaws.com.cn",
+            ),
+            # aws_domain still applied when region falls back
+            ({"aws_domain": "amazonaws.com.cn"}, None, "athena.amazonaws.com.cn"),
+        ],
+    )
+    def test_get_openlineage_database_info_region_extraction(self, extras, hook_region, expected_authority):
+        conn = Connection(conn_type="athena", schema="default", extra=extras)
+        hook = self._make_hook(conn, hook_region)
+        info = hook.get_openlineage_database_info(conn)
+        assert info.authority == expected_authority
+
+    def test_get_openlineage_database_info_returns_expected_fields(self):
+        """Snapshot of the DatabaseInfo shape so accidental changes are caught."""
+        conn = Connection(
+            conn_type="athena",
+            schema="default",
+            extra={"region_name": "us-east-1"},
+        )
+        hook = self._make_hook(conn)
+        info = hook.get_openlineage_database_info(conn)
+        assert info.scheme == "awsathena"
+        assert info.authority == "athena.us-east-1.amazonaws.com"
+        assert info.database == "AwsDataCatalog"
+        assert info.is_information_schema_cross_db is True
+        assert info.information_schema_columns == self.EXPECTED_INFORMATION_SCHEMA_COLUMNS
+
+    def test_get_openlineage_database_info_custom_catalog(self):
+        conn = Connection(
+            conn_type="athena",
+            schema="default",
+            extra={"region_name": "us-east-1", "catalog": "MyCatalog"},
+        )
+        hook = self._make_hook(conn)
+        info = hook.get_openlineage_database_info(conn)
+        assert info.database == "MyCatalog"
+
+    def test_get_openlineage_database_dialect_returns_trino(self):
+        conn = Connection(conn_type="athena", extra={"region_name": "us-east-1"})
+        hook = self._make_hook(conn)
+        assert hook.get_openlineage_database_dialect(conn) == "trino"
+
+    @pytest.mark.parametrize(
+        ("connection_schema", "expected_schema"),
+        [
+            ("mydb", "mydb"),
+            (None, "default"),
+            ("", "default"),
+        ],
+    )
+    def test_get_openlineage_default_schema(self, connection_schema, expected_schema):
+        conn = Connection(
+            conn_type="athena",
+            schema=connection_schema,
+            extra={"region_name": "us-east-1"},
+        )
+        hook = self._make_hook(conn)
+        assert hook.get_openlineage_default_schema() == expected_schema