This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 626d3daa9b Add OpenLineage support for Trino. (#32910)
626d3daa9b is described below

commit 626d3daa9b5348fec6dfb4d29edcff97bba20298
Author: Jakub Dardzinski <[email protected]>
AuthorDate: Thu Aug 24 14:00:26 2023 +0200

    Add OpenLineage support for Trino. (#32910)
    
    Signed-off-by: Jakub Dardzinski <[email protected]>
---
 airflow/providers/openlineage/utils/sql.py         | 17 +++++-
 airflow/providers/trino/hooks/trino.py             | 29 ++++++++++
 dev/breeze/tests/test_provider_dependencies.py     |  4 +-
 generated/provider_dependencies.json               |  3 +-
 .../providers/trino/hooks/test_trino.py            | 11 ++++
 tests/providers/openlineage/utils/test_sql.py      | 24 ++++----
 tests/providers/trino/operators/test_trino.py      | 64 ++++++++++++++++++++++
 7 files changed, 134 insertions(+), 18 deletions(-)

diff --git a/airflow/providers/openlineage/utils/sql.py b/airflow/providers/openlineage/utils/sql.py
index 3c87b04bb6..b31d8da240 100644
--- a/airflow/providers/openlineage/utils/sql.py
+++ b/airflow/providers/openlineage/utils/sql.py
@@ -155,11 +155,22 @@ def create_information_schema_query(
     metadata = MetaData(sqlalchemy_engine)
     select_statements = []
     for db, schema_mapping in tables_hierarchy.items():
-        schema, table_name = information_schema_table_name.split(".")
+        # Information schema table name is expected to be "< information_schema schema >.<view/table name>"
+        # usually "information_schema.columns". In order to use table identifier correct for various table
+        # we need to pass first part of dot-separated identifier as `schema` argument to `sqlalchemy.Table`.
         if db:
-            schema = f"{db}.{schema}"
+            # Use database as first part of table identifier.
+            schema = db
+            table_name = information_schema_table_name
+        else:
+            # When no database passed, use schema as first part of table identifier.
+            schema, table_name = information_schema_table_name.split(".")
         information_schema_table = Table(
-            table_name, metadata, *[Column(column) for column in columns], schema=schema
+            table_name,
+            metadata,
+            *[Column(column) for column in columns],
+            schema=schema,
+            quote=False,
         )
         filter_clauses = create_filter_clauses(schema_mapping, information_schema_table, uppercase_names)
         select_statements.append(information_schema_table.select().filter(*filter_clauses))
diff --git a/airflow/providers/trino/hooks/trino.py b/airflow/providers/trino/hooks/trino.py
index 14461b727d..e8a90951a7 100644
--- a/airflow/providers/trino/hooks/trino.py
+++ b/airflow/providers/trino/hooks/trino.py
@@ -233,3 +233,32 @@ class TrinoHook(DbApiHook):
         :return: The cell
         """
         return cell
+
+    def get_openlineage_database_info(self, connection):
+        """Returns Trino specific information for OpenLineage."""
+        from airflow.providers.openlineage.sqlparser import DatabaseInfo
+
+        return DatabaseInfo(
+            scheme="trino",
+            authority=DbApiHook.get_openlineage_authority_part(
+                connection, default_port=trino.constants.DEFAULT_PORT
+            ),
+            information_schema_columns=[
+                "table_schema",
+                "table_name",
+                "column_name",
+                "ordinal_position",
+                "data_type",
+                "table_catalog",
+            ],
+            database=connection.extra_dejson.get("catalog", "hive"),
+            is_information_schema_cross_db=True,
+        )
+
+    def get_openlineage_database_dialect(self, _):
+        """Returns Trino dialect."""
+        return "trino"
+
+    def get_openlineage_default_schema(self):
+        """Returns Trino default schema."""
+        return trino.constants.DEFAULT_SCHEMA
diff --git a/dev/breeze/tests/test_provider_dependencies.py b/dev/breeze/tests/test_provider_dependencies.py
index 5a1a1b0800..d532e8330c 100644
--- a/dev/breeze/tests/test_provider_dependencies.py
+++ b/dev/breeze/tests/test_provider_dependencies.py
@@ -25,7 +25,7 @@ def test_get_downstream_only():
     related_providers = get_related_providers(
         "trino", upstream_dependencies=False, downstream_dependencies=True
     )
-    assert {"google", "common.sql"} == related_providers
+    assert {"openlineage", "google", "common.sql"} == related_providers
 
 
 def test_get_upstream_only():
@@ -39,7 +39,7 @@ def test_both():
     related_providers = get_related_providers(
         "trino", upstream_dependencies=True, downstream_dependencies=True
     )
-    assert {"google", "mysql", "common.sql"} == related_providers
+    assert {"openlineage", "google", "mysql", "common.sql"} == related_providers
 
 
 def test_none():
diff --git a/generated/provider_dependencies.json b/generated/provider_dependencies.json
index f388781e18..0b1c5d5a1d 100644
--- a/generated/provider_dependencies.json
+++ b/generated/provider_dependencies.json
@@ -892,7 +892,8 @@
     ],
     "cross-providers-deps": [
       "common.sql",
-      "google"
+      "google",
+      "openlineage"
     ],
     "excluded-python-versions": []
   },
diff --git a/tests/integration/providers/trino/hooks/test_trino.py b/tests/integration/providers/trino/hooks/test_trino.py
index bb06d53887..1fc7cbc3d5 100644
--- a/tests/integration/providers/trino/hooks/test_trino.py
+++ b/tests/integration/providers/trino/hooks/test_trino.py
@@ -22,6 +22,7 @@ from unittest import mock
 import pytest
 
 from airflow.providers.trino.hooks.trino import TrinoHook
+from airflow.providers.trino.operators.trino import TrinoOperator
 
 
 @pytest.mark.integration("trino")
@@ -46,3 +47,13 @@ class TestTrinoHookIntegration:
             sql = "SELECT name FROM tpch.sf1.customer ORDER BY custkey ASC LIMIT 3"
             records = hook.get_records(sql)
             assert [["Customer#000000001"], ["Customer#000000002"], ["Customer#000000003"]] == records
+
+    @mock.patch.dict("os.environ", AIRFLOW_CONN_TRINO_DEFAULT="trino://airflow@trino:8080/")
+    def test_openlineage_methods(self):
+        op = TrinoOperator(task_id="trino_test", sql="SELECT name FROM tpch.sf1.customer LIMIT 3")
+        op.execute({})
+        lineage = op.get_openlineage_facets_on_start()
+        assert lineage.inputs[0].namespace == "trino://trino:8080"
+        assert lineage.inputs[0].name == "tpch.sf1.customer"
+        assert "schema" in lineage.inputs[0].facets
+        assert lineage.job_facets["sql"].query == "SELECT name FROM tpch.sf1.customer LIMIT 3"
diff --git a/tests/providers/openlineage/utils/test_sql.py b/tests/providers/openlineage/utils/test_sql.py
index a82ab36bda..8567920578 100644
--- a/tests/providers/openlineage/utils/test_sql.py
+++ b/tests/providers/openlineage/utils/test_sql.py
@@ -327,17 +327,17 @@ def test_create_create_information_schema_query_cross_db():
             information_schema_table_name="information_schema.columns",
             tables_hierarchy={"db": {"schema1": ["table1"]}, "db2": {"schema1": ["table2"]}},
         )
-        == 'SELECT "db.information_schema".columns.table_schema, "db.information_schema".columns.table_name, '
-        '"db.information_schema".columns.column_name, "db.information_schema".columns.ordinal_position, '
-        '"db.information_schema".columns.data_type \n'
-        'FROM "db.information_schema".columns \n'
-        "WHERE \"db.information_schema\".columns.table_schema = 'schema1' "
-        "AND \"db.information_schema\".columns.table_name IN ('table1') "
+        == "SELECT db.information_schema.columns.table_schema, db.information_schema.columns.table_name, "
+        "db.information_schema.columns.column_name, db.information_schema.columns.ordinal_position, "
+        "db.information_schema.columns.data_type \n"
+        "FROM db.information_schema.columns \n"
+        "WHERE db.information_schema.columns.table_schema = 'schema1' "
+        "AND db.information_schema.columns.table_name IN ('table1') "
         "UNION ALL "
-        'SELECT "db2.information_schema".columns.table_schema, "db2.information_schema".columns.table_name, '
-        '"db2.information_schema".columns.column_name, "db2.information_schema".columns.ordinal_position, '
-        '"db2.information_schema".columns.data_type \n'
-        'FROM "db2.information_schema".columns \n'
-        "WHERE \"db2.information_schema\".columns.table_schema = 'schema1' "
-        "AND \"db2.information_schema\".columns.table_name IN ('table2')"
+        "SELECT db2.information_schema.columns.table_schema, db2.information_schema.columns.table_name, "
+        "db2.information_schema.columns.column_name, db2.information_schema.columns.ordinal_position, "
+        "db2.information_schema.columns.data_type \n"
+        "FROM db2.information_schema.columns \n"
+        "WHERE db2.information_schema.columns.table_schema = 'schema1' "
+        "AND db2.information_schema.columns.table_name IN ('table2')"
     )
diff --git a/tests/providers/trino/operators/test_trino.py b/tests/providers/trino/operators/test_trino.py
index 4be4a27c09..a0390b262b 100644
--- a/tests/providers/trino/operators/test_trino.py
+++ b/tests/providers/trino/operators/test_trino.py
@@ -20,8 +20,12 @@ from __future__ import annotations
 from unittest import mock
 
 import pytest
+from openlineage.client.facet import SchemaDatasetFacet, SchemaField, SqlJobFacet
+from openlineage.client.run import Dataset
 
 from airflow.exceptions import AirflowProviderDeprecationWarning
+from airflow.models.connection import Connection
+from airflow.providers.trino.hooks.trino import TrinoHook
 from airflow.providers.trino.operators.trino import TrinoOperator
 
 TRINO_CONN_ID = "test_trino"
@@ -49,3 +53,63 @@ class TestTrinoOperator:
             parameters=None,
             return_last=True,
         )
+
+
+def test_execute_openlineage_events():
+    DB_NAME = "tpch"
+    DB_SCHEMA_NAME = "sf1"
+
+    class TrinoHookForTests(TrinoHook):
+        get_conn = mock.MagicMock(name="conn")
+        get_connection = mock.MagicMock()
+
+        def get_first(self, *_):
+            return [f"{DB_NAME}.{DB_SCHEMA_NAME}"]
+
+    dbapi_hook = TrinoHookForTests()
+
+    class TrinoOperatorForTest(TrinoOperator):
+        def get_db_hook(self):
+            return dbapi_hook
+
+    sql = "SELECT name FROM tpch.sf1.customer LIMIT 3"
+    op = TrinoOperatorForTest(task_id="trino-operator", sql=sql)
+    rows = [
+        (DB_SCHEMA_NAME, "customer", "custkey", 1, "bigint", DB_NAME),
+        (DB_SCHEMA_NAME, "customer", "name", 2, "varchar(25)", DB_NAME),
+        (DB_SCHEMA_NAME, "customer", "address", 3, "varchar(40)", DB_NAME),
+        (DB_SCHEMA_NAME, "customer", "nationkey", 4, "bigint", DB_NAME),
+        (DB_SCHEMA_NAME, "customer", "phone", 5, "varchar(15)", DB_NAME),
+        (DB_SCHEMA_NAME, "customer", "acctbal", 6, "double", DB_NAME),
+    ]
+    dbapi_hook.get_connection.return_value = Connection(
+        conn_id="trino_default",
+        conn_type="trino",
+        host="trino",
+        port=8080,
+    )
+    dbapi_hook.get_conn.return_value.cursor.return_value.fetchall.side_effect = [rows, []]
+
+    lineage = op.get_openlineage_facets_on_start()
+    assert lineage.inputs == [
+        Dataset(
+            namespace="trino://trino:8080",
+            name=f"{DB_NAME}.{DB_SCHEMA_NAME}.customer",
+            facets={
+                "schema": SchemaDatasetFacet(
+                    fields=[
+                        SchemaField(name="custkey", type="bigint"),
+                        SchemaField(name="name", type="varchar(25)"),
+                        SchemaField(name="address", type="varchar(40)"),
+                        SchemaField(name="nationkey", type="bigint"),
+                        SchemaField(name="phone", type="varchar(15)"),
+                        SchemaField(name="acctbal", type="double"),
+                    ]
+                )
+            },
+        )
+    ]
+
+    assert len(lineage.outputs) == 0
+
+    assert lineage.job_facets == {"sql": SqlJobFacet(query=sql)}

Reply via email to