This is an automated email from the ASF dual-hosted git repository.

shahar pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 69af18592ff refactor: OdbcHook must use it's own connection when creating a sqlalchemy engine (#43145)
69af18592ff is described below

commit 69af18592ffcb666b180fa56215200e545428bf3
Author: David Blain <[email protected]>
AuthorDate: Wed Oct 23 17:12:17 2024 +0200

    refactor: OdbcHook must use it's own connection when creating a sqlalchemy engine (#43145)
    
    Co-authored-by: David Blain <[email protected]>
---
 providers/src/airflow/providers/odbc/hooks/odbc.py | 13 +++++++++++++
 providers/tests/odbc/hooks/test_odbc.py            |  9 +++++++++
 2 files changed, 22 insertions(+)

diff --git a/providers/src/airflow/providers/odbc/hooks/odbc.py b/providers/src/airflow/providers/odbc/hooks/odbc.py
index 48dada49f88..aa3f9ce50fa 100644
--- a/providers/src/airflow/providers/odbc/hooks/odbc.py
+++ b/providers/src/airflow/providers/odbc/hooks/odbc.py
@@ -180,6 +180,19 @@ class OdbcHook(DbApiHook):
 
         return merged_connect_kwargs
 
+    def get_sqlalchemy_engine(self, engine_kwargs=None):
+        """
+        Get an sqlalchemy_engine object whose DBAPI connections come from :meth:`get_conn`.
+
+        :param engine_kwargs: Kwargs used in :func:`~sqlalchemy.create_engine`. The dict is
+            copied, so the caller's object is left untouched; its ``creator`` key, if any, is
+            replaced by ``self.get_conn``.
+        :return: the created engine.
+        """
+        # Build a new dict instead of writing "creator" into the caller's engine_kwargs.
+        engine_kwargs = {**(engine_kwargs or {}), "creator": self.get_conn}
+        return super().get_sqlalchemy_engine(engine_kwargs)
+
     def get_conn(self) -> Connection:
         """Return ``pyodbc`` connection object."""
         conn = connect(self.odbc_connection_string, **self.connect_kwargs)
diff --git a/providers/tests/odbc/hooks/test_odbc.py b/providers/tests/odbc/hooks/test_odbc.py
index 8f749aa4f76..5d2e195dcc6 100644
--- a/providers/tests/odbc/hooks/test_odbc.py
+++ b/providers/tests/odbc/hooks/test_odbc.py
@@ -19,6 +19,7 @@ from __future__ import annotations
 
 import json
 import logging
+import sqlite3
 from dataclasses import dataclass
 from unittest import mock
 from unittest.mock import patch
@@ -340,3 +341,11 @@ class TestOdbcHook:
         hook = mock_hook(OdbcHook)
         result = hook.run("SQL")
         assert result is None
+
+    def test_get_sqlalchemy_engine_verify_creator_is_being_used(self):
+        hook = mock_hook(OdbcHook, conn_params={"extra": {"sqlalchemy_scheme": "sqlite"}})
+        connection = sqlite3.connect(":memory:")
+        hook.get_conn = lambda: connection
+        with hook.get_sqlalchemy_engine().connect() as sqlalchemy_connection:
+            assert sqlalchemy_connection.connection.connection is connection
+        connection.close()  # sqlite3's "with" only ends the transaction; it never closes

Reply via email to