This is an automated email from the ASF dual-hosted git repository.
shahar pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 69af18592ff refactor: OdbcHook must use its own connection when
creating a sqlalchemy engine (#43145)
69af18592ff is described below
commit 69af18592ffcb666b180fa56215200e545428bf3
Author: David Blain <[email protected]>
AuthorDate: Wed Oct 23 17:12:17 2024 +0200
refactor: OdbcHook must use its own connection when creating a sqlalchemy
engine (#43145)
Co-authored-by: David Blain <[email protected]>
---
providers/src/airflow/providers/odbc/hooks/odbc.py | 13 +++++++++++++
providers/tests/odbc/hooks/test_odbc.py | 9 +++++++++
2 files changed, 22 insertions(+)
diff --git a/providers/src/airflow/providers/odbc/hooks/odbc.py b/providers/src/airflow/providers/odbc/hooks/odbc.py
index 48dada49f88..aa3f9ce50fa 100644
--- a/providers/src/airflow/providers/odbc/hooks/odbc.py
+++ b/providers/src/airflow/providers/odbc/hooks/odbc.py
@@ -180,6 +180,19 @@ class OdbcHook(DbApiHook):
return merged_connect_kwargs
+    def get_sqlalchemy_engine(self, engine_kwargs=None):
+        """
+        Get an sqlalchemy_engine object whose connections are created by :meth:`get_conn`.
+
+        :param engine_kwargs: Kwargs used in :func:`~sqlalchemy.create_engine`.
+            The mapping is copied, never mutated; any ``creator`` entry is replaced.
+        :return: the created engine.
+        """
+        # Copy rather than mutate the caller's dict, and always route connection creation
+        # through get_conn so the pyodbc connection uses odbc_connection_string/connect_kwargs.
+        engine_kwargs = {**(engine_kwargs or {}), "creator": self.get_conn}
+        return super().get_sqlalchemy_engine(engine_kwargs)
+
def get_conn(self) -> Connection:
"""Return ``pyodbc`` connection object."""
conn = connect(self.odbc_connection_string, **self.connect_kwargs)
diff --git a/providers/tests/odbc/hooks/test_odbc.py b/providers/tests/odbc/hooks/test_odbc.py
index 8f749aa4f76..5d2e195dcc6 100644
--- a/providers/tests/odbc/hooks/test_odbc.py
+++ b/providers/tests/odbc/hooks/test_odbc.py
@@ -19,6 +19,7 @@ from __future__ import annotations
import json
import logging
+import sqlite3
from dataclasses import dataclass
from unittest import mock
from unittest.mock import patch
@@ -340,3 +341,11 @@ class TestOdbcHook:
hook = mock_hook(OdbcHook)
result = hook.run("SQL")
assert result is None
+
+    def test_get_sqlalchemy_engine_verify_creator_is_being_used(self):
+        hook = mock_hook(OdbcHook, conn_params={"extra": {"sqlalchemy_scheme": "sqlite"}})
+
+        # sqlite3's connection context manager only commits/rolls back; it never closes,
+        # so close explicitly to avoid leaking the connection.
+        connection = sqlite3.connect(":memory:")
+        try:
+            hook.get_conn = lambda: connection
+            engine = hook.get_sqlalchemy_engine()
+            # The DBAPI connection behind the pooled connection must be exactly get_conn's result.
+            with engine.connect() as sa_connection:
+                assert sa_connection.connection.connection is connection
+        finally:
+            connection.close()