This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 9ba45b6d5f FIX: Only pass connection to sqlalchemy engine in JdbcHook
(#42705)
9ba45b6d5f is described below
commit 9ba45b6d5f474f5c39c563f98dd87afa4245a115
Author: David Blain <[email protected]>
AuthorDate: Mon Oct 7 17:30:26 2024 +0200
FIX: Only pass connection to sqlalchemy engine in JdbcHook (#42705)
* refactor: Only pass connection as creator to create a sqlalchemy engine
in JdbcHook, don't generalize it.
* refactor: Make sure engine_kwargs is initialised
* docs: Fixed type in docstring
* Revert "docs: Fixed type in docstring"
This reverts commit 05714bc366a3f765c2064dec7ed11e606b4df112.
* refactor: Added unit test for get_sqlalchemy_engine in JdbcHook
* refactor: Reformatted get_hook method in TestJdbcHook
* refactor: Refactored get_hook method in TestJdbcHook
* refactor: Subclassed JdbcHook to allow overriding the get_connection
method and return a mocked connection
---------
Co-authored-by: David Blain <[email protected]>
---
airflow/providers/common/sql/hooks/sql.py | 1 -
airflow/providers/jdbc/hooks/jdbc.py | 13 +++++++++++
tests/providers/jdbc/hooks/test_jdbc.py | 37 ++++++++++++++++++++++++++-----
3 files changed, 45 insertions(+), 6 deletions(-)
diff --git a/airflow/providers/common/sql/hooks/sql.py
b/airflow/providers/common/sql/hooks/sql.py
index dfa8c6fc72..7983808d0d 100644
--- a/airflow/providers/common/sql/hooks/sql.py
+++ b/airflow/providers/common/sql/hooks/sql.py
@@ -275,7 +275,6 @@ class DbApiHook(BaseHook):
"""
if engine_kwargs is None:
engine_kwargs = {}
- engine_kwargs["creator"] = self.get_conn
try:
url = self.sqlalchemy_url
diff --git a/airflow/providers/jdbc/hooks/jdbc.py
b/airflow/providers/jdbc/hooks/jdbc.py
index 27a438ae41..356bd5d450 100644
--- a/airflow/providers/jdbc/hooks/jdbc.py
+++ b/airflow/providers/jdbc/hooks/jdbc.py
@@ -163,6 +163,19 @@ class JdbcHook(DbApiHook):
database=conn.schema,
)
+ def get_sqlalchemy_engine(self, engine_kwargs=None):
+ """
+ Get an sqlalchemy_engine object.
+
+ :param engine_kwargs: Kwargs used in :func:`~sqlalchemy.create_engine`.
+ :return: the created engine.
+ """
+ if engine_kwargs is None:
+ engine_kwargs = {}
+ engine_kwargs["creator"] = self.get_conn
+
+ return super().get_sqlalchemy_engine(engine_kwargs)
+
def get_conn(self) -> jaydebeapi.Connection:
conn: Connection = self.get_connection(self.get_conn_id())
host: str = conn.host
diff --git a/tests/providers/jdbc/hooks/test_jdbc.py
b/tests/providers/jdbc/hooks/test_jdbc.py
index cb38ce40ae..f26a9d7ffb 100644
--- a/tests/providers/jdbc/hooks/test_jdbc.py
+++ b/tests/providers/jdbc/hooks/test_jdbc.py
@@ -19,6 +19,7 @@ from __future__ import annotations
import json
import logging
+import sqlite3
from unittest import mock
from unittest.mock import Mock, patch
@@ -36,19 +37,30 @@ pytestmark = pytest.mark.db_test
jdbc_conn_mock = Mock(name="jdbc_conn")
-def get_hook(hook_params=None, conn_params=None):
+def get_hook(
+ hook_params=None,
+ conn_params=None,
+ login: str | None = "login",
+ password: str | None = "password",
+ host: str | None = "host",
+ schema: str | None = "schema",
+ port: int | None = 1234,
+):
+ """
+ Build a JdbcHook for tests whose ``get_connection`` returns a fixed Connection.
+
+ The Connection is built from ``login``/``password``/``host``/``schema``/``port``;
+ keys present in ``conn_params`` override those fields. ``None`` values are
+ passed through to the Connection as-is.
+
+ :param hook_params: keyword arguments passed to the hook constructor.
+ :param conn_params: extra Connection fields, merged over the individual ones.
+ :return: an instance of a JdbcHook subclass whose ``get_connection`` ignores
+ ``conn_id`` and always returns the Connection built here.
+ """
hook_params = hook_params or {}
conn_params = conn_params or {}
connection = Connection(
**{
- **dict(login="login", password="password", host="host",
schema="schema", port=1234),
+ **dict(login=login, password=password, host=host, schema=schema,
port=port),
**conn_params,
}
)
- hook = JdbcHook(**hook_params)
- hook.get_connection = Mock()
- hook.get_connection.return_value = connection
+ # Override get_connection on a subclass rather than assigning a Mock to the
+ # instance, so lookups through either the class or the instance return the
+ # Connection built above. Declared as a classmethod, presumably to match the
+ # base hook's get_connection signature -- TODO confirm.
+ class MockedJdbcHook(JdbcHook):
+ @classmethod
+ def get_connection(cls, conn_id: str) -> Connection:
+ return connection
+
+ hook = MockedJdbcHook(**hook_params)
return hook
@@ -201,3 +213,18 @@ class TestJdbcHook:
hook = get_hook(conn_params=conn_params, hook_params=hook_params)
assert str(hook.sqlalchemy_url) ==
"mssql://login:password@host:1234/schema"
+
+ def test_get_sqlalchemy_engine_verify_creator_is_being_used(self):
+ jdbc_hook = get_hook(
+ conn_params=dict(extra={"sqlalchemy_scheme": "sqlite"}),
+ login=None,
+ password=None,
+ host=None,
+ schema=":memory:",
+ port=None,
+ )
+
+ with sqlite3.connect(":memory:") as connection:
+ jdbc_hook.get_conn = lambda: connection
+ engine = jdbc_hook.get_sqlalchemy_engine()
+ assert engine.connect().connection.connection == connection