This is an automated email from the ASF dual-hosted git repository.

mobuchowski pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new aaab2be5ca9 Fix OpenLineage SQL utils emitting duplicate datasets for tables in multiple schemas (#64622)
aaab2be5ca9 is described below

commit aaab2be5ca9790e866023310551e48e8c6d109ed
Author: Rahul Madan <[email protected]>
AuthorDate: Wed Apr 8 15:55:06 2026 +0530

    Fix OpenLineage SQL utils emitting duplicate datasets for tables in multiple schemas (#64622)
    
    * added fix for schema issue in sql - issue #35552
    
    Signed-off-by: Rahul Madan <[email protected]>
    
    * updated func name
    
    Signed-off-by: Rahul Madan <[email protected]>
    
    ---------
    
    Signed-off-by: Rahul Madan <[email protected]>
---
 .../src/airflow/providers/openlineage/utils/sql.py |  36 +++++++-
 .../tests/unit/openlineage/utils/test_sql.py       | 100 +++++++++++++++++++--
 2 files changed, 129 insertions(+), 7 deletions(-)

diff --git a/providers/openlineage/src/airflow/providers/openlineage/utils/sql.py b/providers/openlineage/src/airflow/providers/openlineage/utils/sql.py
index aa17ba3dc32..ee8aea70ce7 100644
--- a/providers/openlineage/src/airflow/providers/openlineage/utils/sql.py
+++ b/providers/openlineage/src/airflow/providers/openlineage/utils/sql.py
@@ -79,6 +79,32 @@ class TableSchema:
         )
 
 
+def _prefer_default_schema_for_duplicate_tables(
+    table_schemas: list[TableSchema],
+    default_schema: str,
+) -> list[TableSchema]:
+    """
+    When the same table appears in multiple schemas, keep only the default schema match.
+
+    This handles the case where a SQL query references a table by bare name (without
+    schema qualifier) and the information_schema query returns results from multiple
+    schemas. In that case, only the entry matching the connection's default schema
+    should be kept.
+    """
+    table_groups: dict[tuple, list[TableSchema]] = defaultdict(list)
+    for ts in table_schemas:
+        table_groups[(ts.database, ts.table)].append(ts)
+
+    result = []
+    for group in table_groups.values():
+        if len(group) == 1:
+            result.append(group[0])
+        else:
+            matching = [ts for ts in group if ts.schema == default_schema]
+            result.extend(matching if matching else group)
+    return result
+
+
 def get_table_schemas(
     hook: BaseHook,
     namespace: str,
@@ -101,12 +127,18 @@ def get_table_schemas(
     with closing(hook.get_conn()) as conn, closing(conn.cursor()) as cursor:
         if in_query:
             cursor.execute(in_query)
-            in_datasets = [x.to_dataset(namespace, database, schema) for x in parse_query_result(cursor)]
+            in_table_schemas = parse_query_result(cursor)
+            if schema:
+                in_table_schemas = _prefer_default_schema_for_duplicate_tables(in_table_schemas, schema)
+            in_datasets = [x.to_dataset(namespace, database, schema) for x in in_table_schemas]
         else:
             in_datasets = []
         if out_query:
             cursor.execute(out_query)
-            out_datasets = [x.to_dataset(namespace, database, schema) for x in parse_query_result(cursor)]
+            out_table_schemas = parse_query_result(cursor)
+            if schema:
+                out_table_schemas = _prefer_default_schema_for_duplicate_tables(out_table_schemas, schema)
+            out_datasets = [x.to_dataset(namespace, database, schema) for x in out_table_schemas]
         else:
             out_datasets = []
     log.debug("Got table schema query result from database.")
diff --git a/providers/openlineage/tests/unit/openlineage/utils/test_sql.py b/providers/openlineage/tests/unit/openlineage/utils/test_sql.py
index aa699d38dae..27cea0d04bd 100644
--- a/providers/openlineage/tests/unit/openlineage/utils/test_sql.py
+++ b/providers/openlineage/tests/unit/openlineage/utils/test_sql.py
@@ -181,11 +181,6 @@ def test_get_table_schemas_with_mixed_schemas():
             Dataset(
                 namespace="bigquery", name="FOOD_DELIVERY.PUBLIC.DISCOUNTS", facets={"schema": SCHEMA_FACET}
             ),
-            Dataset(
-                namespace="bigquery",
-                name="FOOD_DELIVERY.ANOTHER_DB_SCHEMA.DISCOUNTS",
-                facets={"schema": SCHEMA_FACET},
-            ),
         ],
         [],
     )
@@ -249,6 +244,101 @@ def test_get_table_schemas_with_other_database():
     )
 
 
+def test_get_table_schemas_filters_by_default_schema():
+    """When the same table exists in multiple schemas, only the default schema should be returned."""
+    hook = MagicMock()
+    ANOTHER_DB_SCHEMA_NAME = "ANOTHER_DB_SCHEMA"
+
+    rows = [
+        (DB_SCHEMA_NAME, DB_TABLE_NAME.name, "ID", 1, "int4"),
+        (DB_SCHEMA_NAME, DB_TABLE_NAME.name, "AMOUNT_OFF", 2, "int4"),
+        (DB_SCHEMA_NAME, DB_TABLE_NAME.name, "CUSTOMER_EMAIL", 3, "varchar"),
+        (DB_SCHEMA_NAME, DB_TABLE_NAME.name, "STARTS_ON", 4, "timestamp"),
+        (DB_SCHEMA_NAME, DB_TABLE_NAME.name, "ENDS_ON", 5, "timestamp"),
+        (ANOTHER_DB_SCHEMA_NAME, DB_TABLE_NAME.name, "ID", 1, "int4"),
+        (ANOTHER_DB_SCHEMA_NAME, DB_TABLE_NAME.name, "AMOUNT_OFF", 2, "int4"),
+        (ANOTHER_DB_SCHEMA_NAME, DB_TABLE_NAME.name, "CUSTOMER_EMAIL", 3, "varchar"),
+        (ANOTHER_DB_SCHEMA_NAME, DB_TABLE_NAME.name, "STARTS_ON", 4, "timestamp"),
+        (ANOTHER_DB_SCHEMA_NAME, DB_TABLE_NAME.name, "ENDS_ON", 5, "timestamp"),
+    ]
+
+    hook.get_conn.return_value.cursor.return_value.fetchall.side_effect = [rows, []]
+
+    table_schemas = get_table_schemas(
+        hook=hook,
+        namespace="bigquery",
+        database=DB_NAME,
+        schema=DB_SCHEMA_NAME,
+        in_query="fake_sql",
+        out_query="another_fake_sql",
+    )
+
+    # Only the default schema (PUBLIC) should be returned, not ANOTHER_DB_SCHEMA
+    assert table_schemas == (
+        [
+            Dataset(
+                namespace="bigquery", name="FOOD_DELIVERY.PUBLIC.DISCOUNTS", facets={"schema": SCHEMA_FACET}
+            ),
+        ],
+        [],
+    )
+
+
+def test_get_table_schemas_no_default_schema_keeps_all():
+    """When no default schema is provided, all schemas should be returned."""
+    hook = MagicMock()
+    ANOTHER_DB_SCHEMA_NAME = "ANOTHER_DB_SCHEMA"
+
+    rows = [
+        (DB_SCHEMA_NAME, DB_TABLE_NAME.name, "ID", 1, "int4"),
+        (DB_SCHEMA_NAME, DB_TABLE_NAME.name, "AMOUNT_OFF", 2, "int4"),
+        (DB_SCHEMA_NAME, DB_TABLE_NAME.name, "CUSTOMER_EMAIL", 3, "varchar"),
+        (DB_SCHEMA_NAME, DB_TABLE_NAME.name, "STARTS_ON", 4, "timestamp"),
+        (DB_SCHEMA_NAME, DB_TABLE_NAME.name, "ENDS_ON", 5, "timestamp"),
+        (ANOTHER_DB_SCHEMA_NAME, DB_TABLE_NAME.name, "ID", 1, "int4"),
+        (ANOTHER_DB_SCHEMA_NAME, DB_TABLE_NAME.name, "AMOUNT_OFF", 2, "int4"),
+        (ANOTHER_DB_SCHEMA_NAME, DB_TABLE_NAME.name, "CUSTOMER_EMAIL", 3, "varchar"),
+        (ANOTHER_DB_SCHEMA_NAME, DB_TABLE_NAME.name, "STARTS_ON", 4, "timestamp"),
+        (ANOTHER_DB_SCHEMA_NAME, DB_TABLE_NAME.name, "ENDS_ON", 5, "timestamp"),
+    ]
+
+    hook.get_conn.return_value.cursor.return_value.fetchall.side_effect = [rows, []]
+
+    table_schemas = get_table_schemas(
+        hook=hook,
+        namespace="bigquery",
+        database=DB_NAME,
+        schema=None,
+        in_query="fake_sql",
+        out_query="another_fake_sql",
+    )
+
+    another_schema_facet = schema_dataset.SchemaDatasetFacet(
+        fields=[
+            schema_dataset.SchemaDatasetFacetFields(name="ID", type="int4"),
+            schema_dataset.SchemaDatasetFacetFields(name="AMOUNT_OFF", type="int4"),
+            schema_dataset.SchemaDatasetFacetFields(name="CUSTOMER_EMAIL", type="varchar"),
+            schema_dataset.SchemaDatasetFacetFields(name="STARTS_ON", type="timestamp"),
+            schema_dataset.SchemaDatasetFacetFields(name="ENDS_ON", type="timestamp"),
+        ]
+    )
+
+    # No default schema provided, so all schemas should be returned
+    assert table_schemas == (
+        [
+            Dataset(
+                namespace="bigquery", name="FOOD_DELIVERY.PUBLIC.DISCOUNTS", facets={"schema": SCHEMA_FACET}
+            ),
+            Dataset(
+                namespace="bigquery",
+                name="FOOD_DELIVERY.ANOTHER_DB_SCHEMA.DISCOUNTS",
+                facets={"schema": another_schema_facet},
+            ),
+        ],
+        [],
+    )
+
+
 @pytest.mark.parametrize(
     ("schema_mapping", "expected"),
     [

Reply via email to