This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 9ba45b6d5f FIX: Only pass connection to sqlalchemy engine in JdbcHook (#42705)
9ba45b6d5f is described below

commit 9ba45b6d5f474f5c39c563f98dd87afa4245a115
Author: David Blain <[email protected]>
AuthorDate: Mon Oct 7 17:30:26 2024 +0200

    FIX: Only pass connection to sqlalchemy engine in JdbcHook (#42705)
    
    * refactor: Only pass connection as creator to create a sqlalchemy engine in JdbcHook, don't generalize it.
    
    * refactor: Make sure engine_kwargs is initialised
    
    * docs: Fixed type in docstring
    
    * Revert "docs: Fixed type in docstring"
    
    This reverts commit 05714bc366a3f765c2064dec7ed11e606b4df112.
    
    * refactor: Added unit test for get_sqlalchemy_engine in JdbcHook
    
    * refactor: Reformatted get_hook method in TestJdbcHook
    
    * refactor: Refactored get_hook method in TestJdbcHook
    
    * refactor: Subclassed JdbcHook to allow overriding the get_connection method and return a mocked connection
    
    ---------
    
    Co-authored-by: David Blain <[email protected]>
---
 airflow/providers/common/sql/hooks/sql.py |  1 -
 airflow/providers/jdbc/hooks/jdbc.py      | 13 +++++++++++
 tests/providers/jdbc/hooks/test_jdbc.py   | 37 ++++++++++++++++++++++++++-----
 3 files changed, 45 insertions(+), 6 deletions(-)

diff --git a/airflow/providers/common/sql/hooks/sql.py b/airflow/providers/common/sql/hooks/sql.py
index dfa8c6fc72..7983808d0d 100644
--- a/airflow/providers/common/sql/hooks/sql.py
+++ b/airflow/providers/common/sql/hooks/sql.py
@@ -275,7 +275,6 @@ class DbApiHook(BaseHook):
         """
         if engine_kwargs is None:
             engine_kwargs = {}
-        engine_kwargs["creator"] = self.get_conn
 
         try:
             url = self.sqlalchemy_url
diff --git a/airflow/providers/jdbc/hooks/jdbc.py b/airflow/providers/jdbc/hooks/jdbc.py
index 27a438ae41..356bd5d450 100644
--- a/airflow/providers/jdbc/hooks/jdbc.py
+++ b/airflow/providers/jdbc/hooks/jdbc.py
@@ -163,6 +163,19 @@ class JdbcHook(DbApiHook):
             database=conn.schema,
         )
 
+    def get_sqlalchemy_engine(self, engine_kwargs=None):
+        """
+        Get an sqlalchemy_engine object.
+
+        :param engine_kwargs: Kwargs used in :func:`~sqlalchemy.create_engine`.
+        :return: the created engine.
+        """
+        if engine_kwargs is None:
+            engine_kwargs = {}
+        engine_kwargs["creator"] = self.get_conn
+
+        return super().get_sqlalchemy_engine(engine_kwargs)
+
     def get_conn(self) -> jaydebeapi.Connection:
         conn: Connection = self.get_connection(self.get_conn_id())
         host: str = conn.host
diff --git a/tests/providers/jdbc/hooks/test_jdbc.py b/tests/providers/jdbc/hooks/test_jdbc.py
index cb38ce40ae..f26a9d7ffb 100644
--- a/tests/providers/jdbc/hooks/test_jdbc.py
+++ b/tests/providers/jdbc/hooks/test_jdbc.py
@@ -19,6 +19,7 @@ from __future__ import annotations
 
 import json
 import logging
+import sqlite3
 from unittest import mock
 from unittest.mock import Mock, patch
 
@@ -36,19 +37,30 @@ pytestmark = pytest.mark.db_test
 jdbc_conn_mock = Mock(name="jdbc_conn")
 
 
-def get_hook(hook_params=None, conn_params=None):
+def get_hook(
+    hook_params=None,
+    conn_params=None,
+    login: str | None = "login",
+    password: str | None = "password",
+    host: str | None = "host",
+    schema: str | None = "schema",
+    port: int | None = 1234,
+):
     hook_params = hook_params or {}
     conn_params = conn_params or {}
     connection = Connection(
         **{
-            **dict(login="login", password="password", host="host", 
schema="schema", port=1234),
+            **dict(login=login, password=password, host=host, schema=schema, 
port=port),
             **conn_params,
         }
     )
 
-    hook = JdbcHook(**hook_params)
-    hook.get_connection = Mock()
-    hook.get_connection.return_value = connection
+    class MockedJdbcHook(JdbcHook):
+        @classmethod
+        def get_connection(cls, conn_id: str) -> Connection:
+            return connection
+
+    hook = MockedJdbcHook(**hook_params)
     return hook
 
 
@@ -201,3 +213,18 @@ class TestJdbcHook:
         hook = get_hook(conn_params=conn_params, hook_params=hook_params)
 
         assert str(hook.sqlalchemy_url) == 
"mssql://login:password@host:1234/schema"
+
+    def test_get_sqlalchemy_engine_verify_creator_is_being_used(self):
+        jdbc_hook = get_hook(
+            conn_params=dict(extra={"sqlalchemy_scheme": "sqlite"}),
+            login=None,
+            password=None,
+            host=None,
+            schema=":memory:",
+            port=None,
+        )
+
+        with sqlite3.connect(":memory:") as connection:
+            jdbc_hook.get_conn = lambda: connection
+            engine = jdbc_hook.get_sqlalchemy_engine()
+            assert engine.connect().connection.connection == connection

Reply via email to