This is an automated email from the ASF dual-hosted git repository.

mobuchowski pushed a commit to branch mobuchowski/cache-connection
in repository https://gitbox.apache.org/repos/asf/airflow.git

commit e0dc34210b89b186be5d00772cb4c154b19a8cb8
Author: Maciej Obuchowski <[email protected]>
AuthorDate: Wed Apr 8 19:16:58 2026 +0200

    fix review: cache by hook identity instead of conn_id
    
    Use id(hook) as the @cache key instead of conn_id to ensure distinct
    hook instances sharing the same conn_id but with different params
    get separate cached database info results.
    
    Signed-off-by: Maciej Obuchowski <[email protected]>
    Signed-off-by: Maciej Obuchowski <[email protected]>
---
 .../openlineage/utils/sql_hook_lineage.py          | 34 +++++++++++++---------
 .../openlineage/utils/test_sql_hook_lineage.py     | 26 +++++++++++++++++
 2 files changed, 47 insertions(+), 13 deletions(-)

diff --git 
a/providers/openlineage/src/airflow/providers/openlineage/utils/sql_hook_lineage.py
 
b/providers/openlineage/src/airflow/providers/openlineage/utils/sql_hook_lineage.py
index 1e76158b420..68ffe65dab5 100644
--- 
a/providers/openlineage/src/airflow/providers/openlineage/utils/sql_hook_lineage.py
+++ 
b/providers/openlineage/src/airflow/providers/openlineage/utils/sql_hook_lineage.py
@@ -71,25 +71,32 @@ def emit_lineage_from_sql_extras(task_instance, sql_extras: 
list, is_successful:
     events: list[RunEvent] = []
     query_count = 0
 
-    # Build conn_id -> hook mapping before iterating. Hook instances are not 
hashable so
-    # conn_id (a plain string) is used as the @cache key throughout.
-    _hook_by_conn_id = {_get_hook_conn_id(e.context): e.context for e in 
sql_extras}
+    # Build hook identity -> (hook, conn_id) mapping before iterating.
+    # Using id(hook) as cache key instead of conn_id ensures distinct hook 
instances
+    # with the same conn_id but different params are cached separately.
+    _hook_info: dict[int, tuple[object, str | None]] = {}
+    for e in sql_extras:
+        hid = id(e.context)
+        if hid not in _hook_info:
+            _hook_info[hid] = (e.context, _get_hook_conn_id(e.context))
 
     @cache
-    def _get_connection(conn_id: str):
-        return _hook_by_conn_id[conn_id].get_connection(conn_id)
+    def _get_connection(hook_id: int):
+        hook, conn_id = _hook_info[hook_id]
+        return hook.get_connection(conn_id)
 
     @cache
-    def _get_database_info(conn_id: str):
+    def _get_database_info(hook_id: int):
+        hook, conn_id = _hook_info[hook_id]
         try:
-            return 
_hook_by_conn_id[conn_id].get_openlineage_database_info(_get_connection(conn_id))
+            return hook.get_openlineage_database_info(_get_connection(hook_id))
         except Exception as e:
             log.debug("Failed to get OpenLineage database info for %s: %s", 
conn_id, e)
             return None
 
     @cache
-    def _get_namespace(conn_id: str) -> str | None:
-        db_info = _get_database_info(conn_id)
+    def _get_namespace(hook_id: int) -> str | None:
+        db_info = _get_database_info(hook_id)
         return SQLParser.create_namespace(db_info) if db_info is not None else 
None
 
     for extra_info in sql_extras:
@@ -104,11 +111,12 @@ def emit_lineage_from_sql_extras(task_instance, 
sql_extras: list, is_successful:
         query_count += 1
 
         hook = extra_info.context
-        conn_id = _get_hook_conn_id(hook)
+        hook_id = id(hook)
+        conn_id = _hook_info[hook_id][1]
 
         # Parse SQL to obtain lineage (inputs, outputs, facets)
         query_lineage: OperatorLineage | None = None
-        database_info = _get_database_info(conn_id) if conn_id else None
+        database_info = _get_database_info(hook_id) if conn_id else None
         if sql_text and conn_id and database_info is not None:
             try:
                 query_lineage = get_openlineage_facets_with_sql(
@@ -117,7 +125,7 @@ def emit_lineage_from_sql_extras(task_instance, sql_extras: 
list, is_successful:
                     conn_id=conn_id,
                     
database=value.get(SqlJobHookLineageExtra.VALUE__DEFAULT_DB.value),
                     use_connection=False,  # Temporary solution before we 
figure out timeouts for queries
-                    connection=_get_connection(conn_id),
+                    connection=_get_connection(hook_id),
                     database_info=database_info,
                 )
             except Exception as e:
@@ -131,7 +139,7 @@ def emit_lineage_from_sql_extras(task_instance, sql_extras: 
list, is_successful:
             query_lineage = OperatorLineage(job_facets=job_facets)
 
         # Enrich run facets with external query info when available.
-        namespace = _get_namespace(conn_id) if conn_id else None
+        namespace = _get_namespace(hook_id) if conn_id else None
         if job_id and namespace:
             query_lineage.run_facets.setdefault(
                 "externalQuery",
diff --git 
a/providers/openlineage/tests/unit/openlineage/utils/test_sql_hook_lineage.py 
b/providers/openlineage/tests/unit/openlineage/utils/test_sql_hook_lineage.py
index 5eafefc3526..1866013bf56 100644
--- 
a/providers/openlineage/tests/unit/openlineage/utils/test_sql_hook_lineage.py
+++ 
b/providers/openlineage/tests/unit/openlineage/utils/test_sql_hook_lineage.py
@@ -541,3 +541,29 @@ class TestEmitLineageFromSqlExtras:
         assert result is None
         call_kwargs = self.mock_event_pair.call_args.kwargs
         assert call_kwargs["run_facets"]["externalQuery"] is original_ext_query
+
+    def test_different_hooks_same_conn_id_get_separate_db_info(self):
+        """Two hooks sharing a conn_id but returning different database info 
are cached separately."""
+        mock_ti = mock.MagicMock(dag_id="dag_id", task_id="task_id")
+
+        hook_a = mock.MagicMock()
+        hook_b = mock.MagicMock()
+
+        db_info_a = mock.MagicMock()
+        db_info_b = mock.MagicMock()
+        hook_a.get_openlineage_database_info.return_value = db_info_a
+        hook_b.get_openlineage_database_info.return_value = db_info_b
+
+        self.mock_conn_id.return_value = "same_conn"
+        self.mock_ns.side_effect = lambda db_info: f"ns_{id(db_info)}"
+        self.mock_facets_fn.return_value = OperatorLineage()
+
+        extras = [
+            _make_extra(sql="SELECT 1", hook=hook_a),
+            _make_extra(sql="SELECT 2", hook=hook_b),
+        ]
+        emit_lineage_from_sql_extras(task_instance=mock_ti, sql_extras=extras)
+
+        # Despite the shared conn_id, each hook is queried for database info exactly once
+        hook_a.get_openlineage_database_info.assert_called_once()
+        hook_b.get_openlineage_database_info.assert_called_once()

Reply via email to