This is an automated email from the ASF dual-hosted git repository.

onikolas pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 44b97e1687 Add OpenLineage support for Redshift SQL. (#35794)
44b97e1687 is described below

commit 44b97e168733b08b308f16b2738b6c15e8a35862
Author: Jakub Dardzinski <[email protected]>
AuthorDate: Thu Jan 4 18:49:35 2024 +0100

    Add OpenLineage support for Redshift SQL. (#35794)
    
    Add flat information schema query support in SQLParser.
    
    Signed-off-by: Jakub Dardzinski <[email protected]>
    Co-authored-by: Niko Oliveira <[email protected]>
---
 airflow/providers/amazon/aws/hooks/redshift_sql.py |  61 ++++++
 airflow/providers/openlineage/sqlparser.py         |  13 +-
 airflow/providers/openlineage/utils/sql.py         |  97 ++++++---
 .../amazon/aws/hooks/test_redshift_sql.py          |  44 ++++
 .../amazon/aws/operators/test_redshift_sql.py      | 241 +++++++++++++++++++++
 tests/providers/openlineage/utils/test_sql.py      |  62 ++++--
 6 files changed, 472 insertions(+), 46 deletions(-)

diff --git a/airflow/providers/amazon/aws/hooks/redshift_sql.py b/airflow/providers/amazon/aws/hooks/redshift_sql.py
index 66659cb0a1..580efc1443 100644
--- a/airflow/providers/amazon/aws/hooks/redshift_sql.py
+++ b/airflow/providers/amazon/aws/hooks/redshift_sql.py
@@ -30,6 +30,7 @@ from airflow.providers.common.sql.hooks.sql import DbApiHook
 
 if TYPE_CHECKING:
     from airflow.models.connection import Connection
+    from airflow.providers.openlineage.sqlparser import DatabaseInfo
 
 
 class RedshiftSQLHook(DbApiHook):
@@ -197,3 +198,63 @@ class RedshiftSQLHook(DbApiHook):
         conn_kwargs_dejson = self.conn.extra_dejson
         conn_kwargs: dict = {**conn_params, **conn_kwargs_dejson}
         return redshift_connector.connect(**conn_kwargs)
+
+    def get_openlineage_database_info(self, connection: Connection) -> DatabaseInfo:
+        """Returns Redshift specific information for OpenLineage."""
+        from airflow.providers.openlineage.sqlparser import DatabaseInfo
+
+        authority = self._get_openlineage_redshift_authority_part(connection)
+
+        return DatabaseInfo(
+            scheme="redshift",
+            authority=authority,
+            database=connection.schema,
+            information_schema_table_name="SVV_REDSHIFT_COLUMNS",
+            information_schema_columns=[
+                "schema_name",
+                "table_name",
+                "column_name",
+                "ordinal_position",
+                "data_type",
+                "database_name",
+            ],
+            is_information_schema_cross_db=True,
+            use_flat_cross_db_query=True,
+        )
+
+    def _get_openlineage_redshift_authority_part(self, connection: Connection) -> str:
+        from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
+
+        port = connection.port or 5439
+
+        cluster_identifier = None
+
+        if connection.extra_dejson.get("iam", False):
+            cluster_identifier = connection.extra_dejson.get("cluster_identifier")
+            region_name = AwsBaseHook(aws_conn_id=self.aws_conn_id).region_name
+            identifier = f"{cluster_identifier}.{region_name}"
+        if not cluster_identifier:
+            identifier = self._get_identifier_from_hostname(connection.host)
+        return f"{identifier}:{port}"
+
+    def _get_identifier_from_hostname(self, hostname: str) -> str:
+        parts = hostname.split(".")
+        if "amazonaws.com" in hostname and len(parts) == 6:
+            return f"{parts[0]}.{parts[2]}"
+        else:
+            self.log.debug(
+                """Could not parse identifier from hostname '%s'.
+            You are probably using IP to connect to Redshift cluster.
+            Expected format: 'cluster_identifier.id.region_name.redshift.amazonaws.com'
+            Falling back to whole hostname.""",
+                hostname,
+            )
+            return hostname
+
+    def get_openlineage_database_dialect(self, connection: Connection) -> str:
+        """Returns redshift dialect."""
+        return "redshift"
+
+    def get_openlineage_default_schema(self) -> str | None:
+        """Returns current schema. This is usually changed with 
``SEARCH_PATH`` parameter."""
+        return self.get_first("SELECT CURRENT_SCHEMA();")[0]
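
A minimal sketch (not part of the patch) of how the authority string above is derived for a typical Redshift hostname; the hostname is hypothetical and the default port 5439 is assumed, mirroring _get_identifier_from_hostname and the port fallback:

    hostname = "my-cluster.abc123xy.eu-north-1.redshift.amazonaws.com"  # hypothetical
    parts = hostname.split(".")
    # "<cluster_identifier>.<id>.<region>.redshift.amazonaws.com" has 6 dot-separated parts
    if "amazonaws.com" in hostname and len(parts) == 6:
        identifier = f"{parts[0]}.{parts[2]}"  # -> "my-cluster.eu-north-1"
    else:
        identifier = hostname  # fall back to the whole hostname (e.g. an IP)
    authority = f"{identifier}:5439"  # -> "my-cluster.eu-north-1:5439"
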
diff --git a/airflow/providers/openlineage/sqlparser.py b/airflow/providers/openlineage/sqlparser.py
index 41c378fc27..d54c19dbc8 100644
--- a/airflow/providers/openlineage/sqlparser.py
+++ b/airflow/providers/openlineage/sqlparser.py
@@ -67,6 +67,7 @@ class GetTableSchemasParams(TypedDict):
     is_cross_db: bool
     information_schema_columns: list[str]
     information_schema_table: str
+    use_flat_cross_db_query: bool
     is_uppercase_names: bool
     database: str | None
 
@@ -83,6 +84,8 @@ class DatabaseInfo:
     :param database: Takes precedence over parsed database name.
     :param information_schema_columns: List of columns names from information schema table.
     :param information_schema_table_name: Information schema table name.
+    :param use_flat_cross_db_query: Specifies if single information schema table should be used
+        for cross-database queries (e.g. for Redshift).
     :param is_information_schema_cross_db: Specifies if information schema contains
         cross-database data.
     :param is_uppercase_names: Specifies if database accepts only uppercase names (e.g. Snowflake).
@@ -95,6 +98,7 @@ class DatabaseInfo:
     database: str | None = None
     information_schema_columns: list[str] = DEFAULT_INFORMATION_SCHEMA_COLUMNS
     information_schema_table_name: str = DEFAULT_INFORMATION_SCHEMA_TABLE_NAME
+    use_flat_cross_db_query: bool = False
     is_information_schema_cross_db: bool = False
     is_uppercase_names: bool = False
     normalize_name_method: Callable[[str], str] = default_normalize_name_method
@@ -133,6 +137,7 @@ class SQLParser:
             "information_schema_table": 
database_info.information_schema_table_name,
             "is_uppercase_names": database_info.is_uppercase_names,
             "database": database or database_info.database,
+            "use_flat_cross_db_query": database_info.use_flat_cross_db_query,
         }
         return get_table_schemas(
             hook,
@@ -297,9 +302,10 @@ class SQLParser:
         tables: list[DbTableMeta],
         normalize_name: Callable[[str], str],
         is_cross_db: bool,
-        information_schema_columns,
-        information_schema_table,
-        is_uppercase_names,
+        information_schema_columns: list[str],
+        information_schema_table: str,
+        is_uppercase_names: bool,
+        use_flat_cross_db_query: bool,
         database: str | None = None,
         sqlalchemy_engine: Engine | None = None,
     ) -> str:
@@ -314,6 +320,7 @@ class SQLParser:
             columns=information_schema_columns,
             information_schema_table_name=information_schema_table,
             tables_hierarchy=tables_hierarchy,
+            use_flat_cross_db_query=use_flat_cross_db_query,
             uppercase_names=is_uppercase_names,
             sqlalchemy_engine=sqlalchemy_engine,
         )
diff --git a/airflow/providers/openlineage/utils/sql.py b/airflow/providers/openlineage/utils/sql.py
index 7bd6043040..da08fa68d4 100644
--- a/airflow/providers/openlineage/utils/sql.py
+++ b/airflow/providers/openlineage/utils/sql.py
@@ -24,7 +24,7 @@ from typing import TYPE_CHECKING, Dict, List, Optional
 from attrs import define
 from openlineage.client.facet import SchemaDatasetFacet, SchemaField
 from openlineage.client.run import Dataset
-from sqlalchemy import Column, MetaData, Table, and_, union_all
+from sqlalchemy import Column, MetaData, Table, and_, or_, union_all
 
 if TYPE_CHECKING:
     from sqlalchemy.engine import Engine
@@ -42,7 +42,7 @@ class ColumnIndex(IntEnum):
     ORDINAL_POSITION = 3
     # Use 'udt_name' which is the underlying type of column
     UDT_NAME = 4
-    # Database is optional as 5th column
+    # Database is optional as 6th column
     DATABASE = 5
 
 
@@ -145,55 +145,96 @@ def create_information_schema_query(
     information_schema_table_name: str,
     tables_hierarchy: TablesHierarchy,
     uppercase_names: bool = False,
+    use_flat_cross_db_query: bool = False,
     sqlalchemy_engine: Engine | None = None,
 ) -> str:
     """Creates query for getting table schemas from information schema."""
     metadata = MetaData(sqlalchemy_engine)
     select_statements = []
-    for db, schema_mapping in tables_hierarchy.items():
-        # Information schema table name is expected to be "< information_schema schema >.<view/table name>"
-        # usually "information_schema.columns". In order to use table identifier correct for various table
-        # we need to pass first part of dot-separated identifier as `schema` argument to `sqlalchemy.Table`.
-        if db:
-            # Use database as first part of table identifier.
-            schema = db
-            table_name = information_schema_table_name
-        else:
-            # When no database passed, use schema as first part of table identifier.
-            schema, table_name = information_schema_table_name.split(".")
+    # Don't iterate over tables hierarchy, just pass it to query single information schema table
+    if use_flat_cross_db_query:
         information_schema_table = Table(
-            table_name,
+            information_schema_table_name,
             metadata,
             *[Column(column) for column in columns],
-            schema=schema,
             quote=False,
         )
-        filter_clauses = create_filter_clauses(schema_mapping, information_schema_table, uppercase_names)
-        select_statements.append(information_schema_table.select().filter(*filter_clauses))
+        filter_clauses = create_filter_clauses(
+            tables_hierarchy,
+            information_schema_table,
+            uppercase_names=uppercase_names,
+        )
+        select_statements.append(information_schema_table.select().filter(filter_clauses))
+    else:
+        for db, schema_mapping in tables_hierarchy.items():
+            # Information schema table name is expected to be "< information_schema schema >.<view/table name>"
+            # usually "information_schema.columns". In order to use table identifier correct for various table
+            # we need to pass first part of dot-separated identifier as `schema` argument to `sqlalchemy.Table`.
+            if db:
+                # Use database as first part of table identifier.
+                schema = db
+                table_name = information_schema_table_name
+            else:
+                # When no database passed, use schema as first part of table identifier.
+                schema, table_name = information_schema_table_name.split(".")
+            information_schema_table = Table(
+                table_name,
+                metadata,
+                *[Column(column) for column in columns],
+                schema=schema,
+                quote=False,
+            )
+            filter_clauses = create_filter_clauses(
+                {None: schema_mapping},
+                information_schema_table,
+                uppercase_names=uppercase_names,
+            )
+            select_statements.append(information_schema_table.select().filter(filter_clauses))
     return str(
         union_all(*select_statements).compile(sqlalchemy_engine, compile_kwargs={"literal_binds": True})
     )
 
 
 def create_filter_clauses(
-    schema_mapping: dict, information_schema_table: Table, uppercase_names: bool = False
+    mapping: dict,
+    information_schema_table: Table,
+    uppercase_names: bool = False,
 ) -> ClauseElement:
     """
     Creates comprehensive filter clauses for all tables in one database.
 
-    :param schema_mapping: a dictionary of schema names and list of tables in each
+    :param mapping: a nested dictionary of database, schema names and list of tables in each
     :param information_schema_table: `sqlalchemy.Table` instance used to construct clauses
         For most SQL dbs it contains `table_name` and `table_schema` columns,
         therefore it is expected the table has them defined.
     :param uppercase_names: if True use schema and table names uppercase
     """
+    table_schema_column_name = information_schema_table.columns[ColumnIndex.SCHEMA].name
+    table_name_column_name = information_schema_table.columns[ColumnIndex.TABLE_NAME].name
+    try:
+        table_database_column_name = information_schema_table.columns[ColumnIndex.DATABASE].name
+    except IndexError:
+        table_database_column_name = ""
+
     filter_clauses = []
-    for schema, tables in schema_mapping.items():
-        filter_clause = information_schema_table.c.table_name.in_(
-            name.upper() if uppercase_names else name for name in tables
-        )
-        if schema:
-            schema = schema.upper() if uppercase_names else schema
-            filter_clause = and_(information_schema_table.c.table_schema == schema, filter_clause)
-        filter_clauses.append(filter_clause)
-    return filter_clauses
+    for db, schema_mapping in mapping.items():
+        schema_level_clauses = []
+        for schema, tables in schema_mapping.items():
+            filter_clause = information_schema_table.c[table_name_column_name].in_(
+                name.upper() if uppercase_names else name for name in tables
+            )
+            if schema:
+                schema = schema.upper() if uppercase_names else schema
+                filter_clause = and_(
+                    information_schema_table.c[table_schema_column_name] == schema, filter_clause
+                )
+            schema_level_clauses.append(filter_clause)
+        if db and table_database_column_name:
+            db = db.upper() if uppercase_names else db
+            filter_clause = and_(
+                information_schema_table.c[table_database_column_name] == db, or_(*schema_level_clauses)
+            )
+            filter_clauses.append(filter_clause)
+        else:
+            filter_clauses.extend(schema_level_clauses)
+    return or_(*filter_clauses)
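
A minimal usage sketch for the reworked builder, assuming the Redshift defaults used elsewhere in this commit; the tables_hierarchy value is hypothetical:

    from airflow.providers.openlineage.utils.sql import create_information_schema_query

    query = create_information_schema_query(
        columns=[
            "schema_name",
            "table_name",
            "column_name",
            "ordinal_position",
            "data_type",
            "database_name",
        ],
        information_schema_table_name="SVV_REDSHIFT_COLUMNS",
        tables_hierarchy={"another_db": {"another_schema": ["popular_orders_day_of_week"]}},
        use_flat_cross_db_query=True,
    )
    # With the flag set, a single SELECT targets the flat view, roughly:
    #   SELECT ... FROM SVV_REDSHIFT_COLUMNS
    #   WHERE SVV_REDSHIFT_COLUMNS.database_name = 'another_db'
    #     AND SVV_REDSHIFT_COLUMNS.schema_name = 'another_schema'
    #     AND SVV_REDSHIFT_COLUMNS.table_name IN ('popular_orders_day_of_week')
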
diff --git a/tests/providers/amazon/aws/hooks/test_redshift_sql.py b/tests/providers/amazon/aws/hooks/test_redshift_sql.py
index 4871522489..aced8cae13 100644
--- a/tests/providers/amazon/aws/hooks/test_redshift_sql.py
+++ b/tests/providers/amazon/aws/hooks/test_redshift_sql.py
@@ -31,6 +31,7 @@ LOGIN_PASSWORD = "password"
 LOGIN_HOST = "host"
 LOGIN_PORT = 5439
 LOGIN_SCHEMA = "dev"
+MOCK_REGION_NAME = "eu-north-1"
 
 
 class TestRedshiftSQLHookConn:
@@ -240,3 +241,46 @@ class TestRedshiftSQLHookConn:
                 ClusterIdentifier=expected_cluster_identifier,
                 AutoCreate=False,
             )
+
+    @mock.patch.dict("os.environ", 
AIRFLOW_CONN_AWS_DEFAULT=f"aws://?region_name={MOCK_REGION_NAME}")
+    @pytest.mark.parametrize(
+        "connection_host, connection_extra, expected_identity",
+        [
+            # test without a connection host but with a cluster_identifier in connection extra
+            (
+                None,
+                {"iam": True, "cluster_identifier": 
"cluster_identifier_from_extra"},
+                f"cluster_identifier_from_extra.{MOCK_REGION_NAME}",
+            ),
+            # test with a connection host and without a cluster_identifier in connection extra
+            (
+                "cluster_identifier_from_host.id.my_region.redshift.amazonaws.com",
+                {"iam": True},
+                "cluster_identifier_from_host.my_region",
+            ),
+            # test with both connection host and cluster_identifier in connection extra
+            (
+                "cluster_identifier_from_host.x.y",
+                {"iam": True, "cluster_identifier": 
"cluster_identifier_from_extra"},
+                f"cluster_identifier_from_extra.{MOCK_REGION_NAME}",
+            ),
+            # test when hostname doesn't match pattern
+            (
+                "1.2.3.4",
+                {},
+                "1.2.3.4",
+            ),
+        ],
+    )
+    def test_get_openlineage_redshift_authority_part(
+        self,
+        connection_host,
+        connection_extra,
+        expected_identity,
+    ):
+        self.connection.host = connection_host
+        self.connection.extra = json.dumps(connection_extra)
+
+        assert f"{expected_identity}:{LOGIN_PORT}" == 
self.db_hook._get_openlineage_redshift_authority_part(
+            self.connection
+        )
diff --git a/tests/providers/amazon/aws/operators/test_redshift_sql.py b/tests/providers/amazon/aws/operators/test_redshift_sql.py
new file mode 100644
index 0000000000..d1c6e26151
--- /dev/null
+++ b/tests/providers/amazon/aws/operators/test_redshift_sql.py
@@ -0,0 +1,241 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from unittest.mock import MagicMock, call, patch
+
+import pytest
+from openlineage.client.facet import (
+    ColumnLineageDatasetFacet,
+    ColumnLineageDatasetFacetFieldsAdditional,
+    ColumnLineageDatasetFacetFieldsAdditionalInputFields,
+    SchemaDatasetFacet,
+    SchemaField,
+    SqlJobFacet,
+)
+from openlineage.client.run import Dataset
+
+from airflow.models.connection import Connection
+from airflow.providers.amazon.aws.hooks.redshift_sql import RedshiftSQLHook
+from airflow.providers.common.sql.operators.sql import SQLExecuteQueryOperator
+
+MOCK_REGION_NAME = "eu-north-1"
+
+
+class TestRedshiftSQLOpenLineage:
+    @patch.dict("os.environ", 
AIRFLOW_CONN_AWS_DEFAULT=f"aws://?region_name={MOCK_REGION_NAME}")
+    @pytest.mark.parametrize(
+        "connection_host, connection_extra, expected_identity",
+        [
+            # test without a connection host but with a cluster_identifier in connection extra
+            (
+                None,
+                {"iam": True, "cluster_identifier": 
"cluster_identifier_from_extra"},
+                f"cluster_identifier_from_extra.{MOCK_REGION_NAME}",
+            ),
+            # test with a connection host and without a cluster_identifier in connection extra
+            (
+                "cluster_identifier_from_host.id.my_region.redshift.amazonaws.com",
+                {"iam": True},
+                "cluster_identifier_from_host.my_region",
+            ),
+            # test with both connection host and cluster_identifier in connection extra
+            (
+                "cluster_identifier_from_host.x.y",
+                {"iam": True, "cluster_identifier": 
"cluster_identifier_from_extra"},
+                f"cluster_identifier_from_extra.{MOCK_REGION_NAME}",
+            ),
+            # test when hostname doesn't match pattern
+            (
+                "1.2.3.4",
+                {},
+                "1.2.3.4",
+            ),
+        ],
+    )
+    @patch("airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook.conn")
+    def test_execute_openlineage_events(
+        self, mock_aws_hook_conn, connection_host, connection_extra, expected_identity
+    ):
+        DB_NAME = "database"
+        DB_SCHEMA_NAME = "public"
+
+        ANOTHER_DB_NAME = "another_db"
+        ANOTHER_DB_SCHEMA = "another_schema"
+
+        # Mock AWS Connection
+        mock_aws_hook_conn.get_cluster_credentials.return_value = {
+            "DbPassword": "aws_token",
+            "DbUser": "IAM:user",
+        }
+
+        class RedshiftSQLHookForTests(RedshiftSQLHook):
+            get_conn = MagicMock(name="conn")
+            get_connection = MagicMock()
+
+            def get_first(self, *_):
+                return [f"{DB_NAME}.{DB_SCHEMA_NAME}"]
+
+        dbapi_hook = RedshiftSQLHookForTests()
+
+        class RedshiftOperatorForTest(SQLExecuteQueryOperator):
+            def get_db_hook(self):
+                return dbapi_hook
+
+        sql = (
+            "INSERT INTO Test_table\n"
+            "SELECT t1.*, t2.additional_constant FROM 
ANOTHER_db.another_schema.popular_orders_day_of_week t1\n"
+            "JOIN little_table t2 ON t1.order_day_of_week = 
t2.order_day_of_week;\n"
+            "FORGOT TO COMMENT"
+        )
+        op = RedshiftOperatorForTest(task_id="redshift-operator", sql=sql)
+        rows = [
+            [
+                (
+                    ANOTHER_DB_SCHEMA,
+                    "popular_orders_day_of_week",
+                    "order_day_of_week",
+                    1,
+                    "varchar",
+                    ANOTHER_DB_NAME,
+                ),
+                (
+                    ANOTHER_DB_SCHEMA,
+                    "popular_orders_day_of_week",
+                    "order_placed_on",
+                    2,
+                    "timestamp",
+                    ANOTHER_DB_NAME,
+                ),
+                (
+                    ANOTHER_DB_SCHEMA,
+                    "popular_orders_day_of_week",
+                    "orders_placed",
+                    3,
+                    "int4",
+                    ANOTHER_DB_NAME,
+                ),
+                (DB_SCHEMA_NAME, "little_table", "order_day_of_week", 1, 
"varchar", DB_NAME),
+                (DB_SCHEMA_NAME, "little_table", "additional_constant", 2, 
"varchar", DB_NAME),
+            ],
+            [
+                (DB_SCHEMA_NAME, "test_table", "order_day_of_week", 1, 
"varchar", DB_NAME),
+                (DB_SCHEMA_NAME, "test_table", "order_placed_on", 2, 
"timestamp", DB_NAME),
+                (DB_SCHEMA_NAME, "test_table", "orders_placed", 3, "int4", 
DB_NAME),
+                (DB_SCHEMA_NAME, "test_table", "additional_constant", 4, 
"varchar", DB_NAME),
+            ],
+        ]
+        dbapi_hook.get_connection.return_value = Connection(
+            conn_id="redshift_default",
+            conn_type="redshift",
+            host=connection_host,
+            extra=connection_extra,
+        )
+        dbapi_hook.get_conn.return_value.cursor.return_value.fetchall.side_effect = rows
+
+        lineage = op.get_openlineage_facets_on_start()
+        assert dbapi_hook.get_conn.return_value.cursor.return_value.execute.mock_calls == [
+            call(
+                "SELECT SVV_REDSHIFT_COLUMNS.schema_name, "
+                "SVV_REDSHIFT_COLUMNS.table_name, "
+                "SVV_REDSHIFT_COLUMNS.column_name, "
+                "SVV_REDSHIFT_COLUMNS.ordinal_position, "
+                "SVV_REDSHIFT_COLUMNS.data_type, "
+                "SVV_REDSHIFT_COLUMNS.database_name \n"
+                "FROM SVV_REDSHIFT_COLUMNS \n"
+                "WHERE SVV_REDSHIFT_COLUMNS.table_name IN ('little_table') "
+                "OR SVV_REDSHIFT_COLUMNS.database_name = 'another_db' "
+                "AND SVV_REDSHIFT_COLUMNS.schema_name = 'another_schema' AND "
+                "SVV_REDSHIFT_COLUMNS.table_name IN 
('popular_orders_day_of_week')"
+            ),
+            call(
+                "SELECT SVV_REDSHIFT_COLUMNS.schema_name, "
+                "SVV_REDSHIFT_COLUMNS.table_name, "
+                "SVV_REDSHIFT_COLUMNS.column_name, "
+                "SVV_REDSHIFT_COLUMNS.ordinal_position, "
+                "SVV_REDSHIFT_COLUMNS.data_type, "
+                "SVV_REDSHIFT_COLUMNS.database_name \n"
+                "FROM SVV_REDSHIFT_COLUMNS \n"
+                "WHERE SVV_REDSHIFT_COLUMNS.table_name IN ('Test_table')"
+            ),
+        ]
+
+        expected_namespace = f"redshift://{expected_identity}:5439"
+
+        assert lineage.inputs == [
+            Dataset(
+                namespace=expected_namespace,
+                name=f"{ANOTHER_DB_NAME}.{ANOTHER_DB_SCHEMA}.popular_orders_day_of_week",
+                facets={
+                    "schema": SchemaDatasetFacet(
+                        fields=[
+                            SchemaField(name="order_day_of_week", 
type="varchar"),
+                            SchemaField(name="order_placed_on", 
type="timestamp"),
+                            SchemaField(name="orders_placed", type="int4"),
+                        ]
+                    )
+                },
+            ),
+            Dataset(
+                namespace=expected_namespace,
+                name=f"{DB_NAME}.{DB_SCHEMA_NAME}.little_table",
+                facets={
+                    "schema": SchemaDatasetFacet(
+                        fields=[
+                            SchemaField(name="order_day_of_week", 
type="varchar"),
+                            SchemaField(name="additional_constant", 
type="varchar"),
+                        ]
+                    )
+                },
+            ),
+        ]
+        assert lineage.outputs == [
+            Dataset(
+                namespace=expected_namespace,
+                name=f"{DB_NAME}.{DB_SCHEMA_NAME}.test_table",
+                facets={
+                    "schema": SchemaDatasetFacet(
+                        fields=[
+                            SchemaField(name="order_day_of_week", 
type="varchar"),
+                            SchemaField(name="order_placed_on", 
type="timestamp"),
+                            SchemaField(name="orders_placed", type="int4"),
+                            SchemaField(name="additional_constant", 
type="varchar"),
+                        ]
+                    ),
+                    "columnLineage": ColumnLineageDatasetFacet(
+                        fields={
+                            "additional_constant": 
ColumnLineageDatasetFacetFieldsAdditional(
+                                inputFields=[
+                                    ColumnLineageDatasetFacetFieldsAdditionalInputFields(
+                                        namespace=expected_namespace,
+                                        name="database.public.little_table",
+                                        field="additional_constant",
+                                    )
+                                ],
+                                transformationDescription="",
+                                transformationType="",
+                            )
+                        }
+                    ),
+                },
+            )
+        ]
+
+        assert lineage.job_facets == {"sql": SqlJobFacet(query=sql)}
+
+        assert lineage.run_facets["extractionError"].failedTasks == 1
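
A note on the expected SQL in the mock_calls above: AND binds tighter than OR, so the flat query's filters group per database without extra parentheses. A hedged sketch of the same grouping built with the sqlalchemy operators used in this commit (table and values mirror the test; not part of the patch):

    from sqlalchemy import Column, MetaData, Table, and_, or_

    t = Table(
        "SVV_REDSHIFT_COLUMNS",
        MetaData(),
        Column("database_name"),
        Column("schema_name"),
        Column("table_name"),
        quote=False,
    )
    clause = or_(
        t.c.table_name.in_(["little_table"]),
        and_(
            t.c.database_name == "another_db",
            t.c.schema_name == "another_schema",
            t.c.table_name.in_(["popular_orders_day_of_week"]),
        ),
    )
    print(clause.compile(compile_kwargs={"literal_binds": True}))
    # SVV_REDSHIFT_COLUMNS.table_name IN ('little_table')
    # OR SVV_REDSHIFT_COLUMNS.database_name = 'another_db'
    # AND SVV_REDSHIFT_COLUMNS.schema_name = 'another_schema'
    # AND SVV_REDSHIFT_COLUMNS.table_name IN ('popular_orders_day_of_week')
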
diff --git a/tests/providers/openlineage/utils/test_sql.py b/tests/providers/openlineage/utils/test_sql.py
index 8567920578..180defbeec 100644
--- a/tests/providers/openlineage/utils/test_sql.py
+++ b/tests/providers/openlineage/utils/test_sql.py
@@ -262,31 +262,63 @@ def test_get_table_schemas_with_other_database():
 @pytest.mark.parametrize(
     "schema_mapping, expected",
     [
-        pytest.param({None: ["C1", "C2"]}, 
["information_schema.columns.table_name IN ('C1', 'C2')"]),
+        pytest.param({None: {None: ["C1", "C2"]}}, 
"information_schema.columns.table_name IN ('C1', 'C2')"),
         pytest.param(
-            {"Schema1": ["Table1"], "Schema2": ["Table2"]},
-            [
-                "information_schema.columns.table_schema = 'Schema1' AND "
-                "information_schema.columns.table_name IN ('Table1')",
-                "information_schema.columns.table_schema = 'Schema2' AND "
-                "information_schema.columns.table_name IN ('Table2')",
-            ],
+            {None: {"Schema1": ["Table1"], "Schema2": ["Table2"]}},
+            "information_schema.columns.table_schema = 'Schema1' AND "
+            "information_schema.columns.table_name IN ('Table1') OR "
+            "information_schema.columns.table_schema = 'Schema2' AND "
+            "information_schema.columns.table_name IN ('Table2')",
         ),
         pytest.param(
-            {"Schema1": ["Table1", "Table2"]},
-            [
-                "information_schema.columns.table_schema = 'Schema1' AND "
-                "information_schema.columns.table_name IN ('Table1', 
'Table2')",
-            ],
+            {None: {"Schema1": ["Table1", "Table2"]}},
+            "information_schema.columns.table_schema = 'Schema1' AND "
+            "information_schema.columns.table_name IN ('Table1', 'Table2')",
+        ),
+        pytest.param(
+            {"Database1": {"Schema1": ["Table1", "Table2"]}},
+            "information_schema.columns.table_database = 'Database1' "
+            "AND information_schema.columns.table_schema = 'Schema1' "
+            "AND information_schema.columns.table_name IN ('Table1', 
'Table2')",
+        ),
+        pytest.param(
+            {"Database1": {"Schema1": ["Table1", "Table2"], "Schema2": 
["Table3", "Table4"]}},
+            "information_schema.columns.table_database = 'Database1' "
+            "AND (information_schema.columns.table_schema = 'Schema1' "
+            "AND information_schema.columns.table_name IN ('Table1', 'Table2') 
"
+            "OR information_schema.columns.table_schema = 'Schema2' "
+            "AND information_schema.columns.table_name IN ('Table3', 
'Table4'))",
+        ),
+        pytest.param(
+            {"Database1": {"Schema1": ["Table1", "Table2"]}, "Database2": 
{"Schema2": ["Table3", "Table4"]}},
+            "information_schema.columns.table_database = 'Database1' "
+            "AND information_schema.columns.table_schema = 'Schema1' "
+            "AND information_schema.columns.table_name IN ('Table1', 'Table2') 
OR "
+            "information_schema.columns.table_database = 'Database2' "
+            "AND information_schema.columns.table_schema = 'Schema2' "
+            "AND information_schema.columns.table_name IN ('Table3', 
'Table4')",
         ),
     ],
 )
 def test_create_filter_clauses(schema_mapping, expected):
     information_table = Table(
-        "columns", MetaData(), *[Column("table_name"), 
Column("table_schema")], schema="information_schema"
+        "columns",
+        MetaData(),
+        *[
+            Column(name)
+            for name in [
+                "table_schema",
+                "table_name",
+                "column_name",
+                "ordinal_position",
+                "udt_name",
+                "table_database",
+            ]
+        ],
+        schema="information_schema",
     )
     clauses = create_filter_clauses(schema_mapping, information_table)
-    assert [str(clause.compile(compile_kwargs={"literal_binds": True})) for clause in clauses] == expected
+    assert str(clauses.compile(compile_kwargs={"literal_binds": True})) == expected
 
 
 def test_create_create_information_schema_query():
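
A hedged sketch of the updated create_filter_clauses contract exercised above: the mapping is now nested as {database: {schema: [tables]}} and a single combined clause is returned instead of a list of clauses:

    from sqlalchemy import Column, MetaData, Table

    from airflow.providers.openlineage.utils.sql import create_filter_clauses

    table = Table(
        "columns",
        MetaData(),
        Column("table_schema"),
        Column("table_name"),
        Column("column_name"),
        Column("ordinal_position"),
        Column("udt_name"),
        Column("table_database"),
        schema="information_schema",
    )
    clause = create_filter_clauses({"Database1": {"Schema1": ["Table1"]}}, table)
    print(clause.compile(compile_kwargs={"literal_binds": True}))
    # information_schema.columns.table_database = 'Database1'
    # AND information_schema.columns.table_schema = 'Schema1'
    # AND information_schema.columns.table_name IN ('Table1')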
