This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 626d3daa9b Add OpenLineage support for Trino. (#32910)
626d3daa9b is described below
commit 626d3daa9b5348fec6dfb4d29edcff97bba20298
Author: Jakub Dardzinski <[email protected]>
AuthorDate: Thu Aug 24 14:00:26 2023 +0200
Add OpenLineage support for Trino. (#32910)
Signed-off-by: Jakub Dardzinski <[email protected]>
---
airflow/providers/openlineage/utils/sql.py | 17 +++++-
airflow/providers/trino/hooks/trino.py | 29 ++++++++++
dev/breeze/tests/test_provider_dependencies.py | 4 +-
generated/provider_dependencies.json | 3 +-
.../providers/trino/hooks/test_trino.py | 11 ++++
tests/providers/openlineage/utils/test_sql.py | 24 ++++----
tests/providers/trino/operators/test_trino.py | 64 ++++++++++++++++++++++
7 files changed, 134 insertions(+), 18 deletions(-)
diff --git a/airflow/providers/openlineage/utils/sql.py b/airflow/providers/openlineage/utils/sql.py
index 3c87b04bb6..b31d8da240 100644
--- a/airflow/providers/openlineage/utils/sql.py
+++ b/airflow/providers/openlineage/utils/sql.py
@@ -155,11 +155,22 @@ def create_information_schema_query(
metadata = MetaData(sqlalchemy_engine)
select_statements = []
for db, schema_mapping in tables_hierarchy.items():
- schema, table_name = information_schema_table_name.split(".")
+ # Information schema table name is expected to be "< information_schema schema >.<view/table name>"
+ # usually "information_schema.columns". In order to use table identifier correct for various table
+ # we need to pass first part of dot-separated identifier as `schema` argument to `sqlalchemy.Table`.
if db:
- schema = f"{db}.{schema}"
+ # Use database as first part of table identifier.
+ schema = db
+ table_name = information_schema_table_name
+ else:
+ # When no database passed, use schema as first part of table identifier.
+ schema, table_name = information_schema_table_name.split(".")
information_schema_table = Table(
- table_name, metadata, *[Column(column) for column in columns], schema=schema
+ table_name,
+ metadata,
+ *[Column(column) for column in columns],
+ schema=schema,
+ quote=False,
)
filter_clauses = create_filter_clauses(schema_mapping, information_schema_table, uppercase_names)
select_statements.append(information_schema_table.select().filter(*filter_clauses))
diff --git a/airflow/providers/trino/hooks/trino.py b/airflow/providers/trino/hooks/trino.py
index 14461b727d..e8a90951a7 100644
--- a/airflow/providers/trino/hooks/trino.py
+++ b/airflow/providers/trino/hooks/trino.py
@@ -233,3 +233,32 @@ class TrinoHook(DbApiHook):
:return: The cell
"""
return cell
+
+ def get_openlineage_database_info(self, connection):
+ """Returns Trino specific information for OpenLineage."""
+ from airflow.providers.openlineage.sqlparser import DatabaseInfo
+
+ return DatabaseInfo(
+ scheme="trino",
+ authority=DbApiHook.get_openlineage_authority_part(
+ connection, default_port=trino.constants.DEFAULT_PORT
+ ),
+ information_schema_columns=[
+ "table_schema",
+ "table_name",
+ "column_name",
+ "ordinal_position",
+ "data_type",
+ "table_catalog",
+ ],
+ database=connection.extra_dejson.get("catalog", "hive"),
+ is_information_schema_cross_db=True,
+ )
+
+ def get_openlineage_database_dialect(self, _):
+ """Returns Trino dialect."""
+ return "trino"
+
+ def get_openlineage_default_schema(self):
+ """Returns Trino default schema."""
+ return trino.constants.DEFAULT_SCHEMA
diff --git a/dev/breeze/tests/test_provider_dependencies.py b/dev/breeze/tests/test_provider_dependencies.py
index 5a1a1b0800..d532e8330c 100644
--- a/dev/breeze/tests/test_provider_dependencies.py
+++ b/dev/breeze/tests/test_provider_dependencies.py
@@ -25,7 +25,7 @@ def test_get_downstream_only():
related_providers = get_related_providers(
"trino", upstream_dependencies=False, downstream_dependencies=True
)
- assert {"google", "common.sql"} == related_providers
+ assert {"openlineage", "google", "common.sql"} == related_providers
def test_get_upstream_only():
@@ -39,7 +39,7 @@ def test_both():
related_providers = get_related_providers(
"trino", upstream_dependencies=True, downstream_dependencies=True
)
- assert {"google", "mysql", "common.sql"} == related_providers
+ assert {"openlineage", "google", "mysql", "common.sql"} == related_providers
def test_none():
diff --git a/generated/provider_dependencies.json b/generated/provider_dependencies.json
index f388781e18..0b1c5d5a1d 100644
--- a/generated/provider_dependencies.json
+++ b/generated/provider_dependencies.json
@@ -892,7 +892,8 @@
],
"cross-providers-deps": [
"common.sql",
- "google"
+ "google",
+ "openlineage"
],
"excluded-python-versions": []
},
diff --git a/tests/integration/providers/trino/hooks/test_trino.py b/tests/integration/providers/trino/hooks/test_trino.py
index bb06d53887..1fc7cbc3d5 100644
--- a/tests/integration/providers/trino/hooks/test_trino.py
+++ b/tests/integration/providers/trino/hooks/test_trino.py
@@ -22,6 +22,7 @@ from unittest import mock
import pytest
from airflow.providers.trino.hooks.trino import TrinoHook
+from airflow.providers.trino.operators.trino import TrinoOperator
@pytest.mark.integration("trino")
@@ -46,3 +47,13 @@ class TestTrinoHookIntegration:
sql = "SELECT name FROM tpch.sf1.customer ORDER BY custkey ASC LIMIT 3"
records = hook.get_records(sql)
assert [["Customer#000000001"], ["Customer#000000002"], ["Customer#000000003"]] == records
+
+ @mock.patch.dict("os.environ", AIRFLOW_CONN_TRINO_DEFAULT="trino://airflow@trino:8080/")
+ def test_openlineage_methods(self):
+ op = TrinoOperator(task_id="trino_test", sql="SELECT name FROM tpch.sf1.customer LIMIT 3")
+ op.execute({})
+ lineage = op.get_openlineage_facets_on_start()
+ assert lineage.inputs[0].namespace == "trino://trino:8080"
+ assert lineage.inputs[0].name == "tpch.sf1.customer"
+ assert "schema" in lineage.inputs[0].facets
+ assert lineage.job_facets["sql"].query == "SELECT name FROM tpch.sf1.customer LIMIT 3"
diff --git a/tests/providers/openlineage/utils/test_sql.py b/tests/providers/openlineage/utils/test_sql.py
index a82ab36bda..8567920578 100644
--- a/tests/providers/openlineage/utils/test_sql.py
+++ b/tests/providers/openlineage/utils/test_sql.py
@@ -327,17 +327,17 @@ def test_create_create_information_schema_query_cross_db():
information_schema_table_name="information_schema.columns",
tables_hierarchy={"db": {"schema1": ["table1"]}, "db2": {"schema1": ["table2"]}},
)
- == 'SELECT "db.information_schema".columns.table_schema, "db.information_schema".columns.table_name, '
- '"db.information_schema".columns.column_name, "db.information_schema".columns.ordinal_position, '
- '"db.information_schema".columns.data_type \n'
- 'FROM "db.information_schema".columns \n'
- "WHERE \"db.information_schema\".columns.table_schema = 'schema1' "
- "AND \"db.information_schema\".columns.table_name IN ('table1') "
+ == "SELECT db.information_schema.columns.table_schema, db.information_schema.columns.table_name, "
+ "db.information_schema.columns.column_name, db.information_schema.columns.ordinal_position, "
+ "db.information_schema.columns.data_type \n"
+ "FROM db.information_schema.columns \n"
+ "WHERE db.information_schema.columns.table_schema = 'schema1' "
+ "AND db.information_schema.columns.table_name IN ('table1') "
"UNION ALL "
- 'SELECT "db2.information_schema".columns.table_schema, "db2.information_schema".columns.table_name, '
- '"db2.information_schema".columns.column_name, "db2.information_schema".columns.ordinal_position, '
- '"db2.information_schema".columns.data_type \n'
- 'FROM "db2.information_schema".columns \n'
- "WHERE \"db2.information_schema\".columns.table_schema = 'schema1' "
- "AND \"db2.information_schema\".columns.table_name IN ('table2')"
+ "SELECT db2.information_schema.columns.table_schema, db2.information_schema.columns.table_name, "
+ "db2.information_schema.columns.column_name, db2.information_schema.columns.ordinal_position, "
+ "db2.information_schema.columns.data_type \n"
+ "FROM db2.information_schema.columns \n"
+ "WHERE db2.information_schema.columns.table_schema = 'schema1' "
+ "AND db2.information_schema.columns.table_name IN ('table2')"
)
diff --git a/tests/providers/trino/operators/test_trino.py b/tests/providers/trino/operators/test_trino.py
index 4be4a27c09..a0390b262b 100644
--- a/tests/providers/trino/operators/test_trino.py
+++ b/tests/providers/trino/operators/test_trino.py
@@ -20,8 +20,12 @@ from __future__ import annotations
from unittest import mock
import pytest
+from openlineage.client.facet import SchemaDatasetFacet, SchemaField, SqlJobFacet
+from openlineage.client.run import Dataset
from airflow.exceptions import AirflowProviderDeprecationWarning
+from airflow.models.connection import Connection
+from airflow.providers.trino.hooks.trino import TrinoHook
from airflow.providers.trino.operators.trino import TrinoOperator
TRINO_CONN_ID = "test_trino"
@@ -49,3 +53,63 @@ class TestTrinoOperator:
parameters=None,
return_last=True,
)
+
+
+def test_execute_openlineage_events():
+ DB_NAME = "tpch"
+ DB_SCHEMA_NAME = "sf1"
+
+ class TrinoHookForTests(TrinoHook):
+ get_conn = mock.MagicMock(name="conn")
+ get_connection = mock.MagicMock()
+
+ def get_first(self, *_):
+ return [f"{DB_NAME}.{DB_SCHEMA_NAME}"]
+
+ dbapi_hook = TrinoHookForTests()
+
+ class TrinoOperatorForTest(TrinoOperator):
+ def get_db_hook(self):
+ return dbapi_hook
+
+ sql = "SELECT name FROM tpch.sf1.customer LIMIT 3"
+ op = TrinoOperatorForTest(task_id="trino-operator", sql=sql)
+ rows = [
+ (DB_SCHEMA_NAME, "customer", "custkey", 1, "bigint", DB_NAME),
+ (DB_SCHEMA_NAME, "customer", "name", 2, "varchar(25)", DB_NAME),
+ (DB_SCHEMA_NAME, "customer", "address", 3, "varchar(40)", DB_NAME),
+ (DB_SCHEMA_NAME, "customer", "nationkey", 4, "bigint", DB_NAME),
+ (DB_SCHEMA_NAME, "customer", "phone", 5, "varchar(15)", DB_NAME),
+ (DB_SCHEMA_NAME, "customer", "acctbal", 6, "double", DB_NAME),
+ ]
+ dbapi_hook.get_connection.return_value = Connection(
+ conn_id="trino_default",
+ conn_type="trino",
+ host="trino",
+ port=8080,
+ )
+ dbapi_hook.get_conn.return_value.cursor.return_value.fetchall.side_effect = [rows, []]
+
+ lineage = op.get_openlineage_facets_on_start()
+ assert lineage.inputs == [
+ Dataset(
+ namespace="trino://trino:8080",
+ name=f"{DB_NAME}.{DB_SCHEMA_NAME}.customer",
+ facets={
+ "schema": SchemaDatasetFacet(
+ fields=[
+ SchemaField(name="custkey", type="bigint"),
+ SchemaField(name="name", type="varchar(25)"),
+ SchemaField(name="address", type="varchar(40)"),
+ SchemaField(name="nationkey", type="bigint"),
+ SchemaField(name="phone", type="varchar(15)"),
+ SchemaField(name="acctbal", type="double"),
+ ]
+ )
+ },
+ )
+ ]
+
+ assert len(lineage.outputs) == 0
+
+ assert lineage.job_facets == {"sql": SqlJobFacet(query=sql)}