This is an automated email from the ASF dual-hosted git repository.

eladkal pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 0b1308ce45 Make `placeholder` of DbApiHook configurable in UI (#38528)
0b1308ce45 is described below

commit 0b1308ce455ed9a6c0263ae63aa0166fed278453
Author: David Blain <[email protected]>
AuthorDate: Sat Apr 6 09:10:53 2024 +0200

    Make `placeholder` of DbApiHook configurable in UI (#38528)
    
    * refactor: Moved placeholder property from OdbcHook class to parent DbApiHook class so that the same logic can also be used with the JdbcHook
    
    * refactor: Import BaseHook under type checking block
    
    * refactor: Marked test_placeholder_config_from_extra as a db test
    
    * refactor: Moved mock_conn from conftest to test_utils module under common sql
    
    * refactor: Removed unnecessary else statement in placeholder property
    
    * refactor: Default placeholder can be a class/static variable, as its only purpose is to define a default SQL placeholder; the actual placeholder will always be retrieved through the property
    
    * refactor: Updated sql test with changes from main
    
    * refactor: Reformatted test
    
    * Update airflow/providers/common/sql/hooks/sql.py
    
    Co-authored-by: Elad Kalif <[email protected]>
    
    * fix: Fixed name of constant SQL_PLACEHOLDERS being checked in placeholder property
    
    ---------
    
    Co-authored-by: David Blain <[email protected]>
    Co-authored-by: Elad Kalif <[email protected]>
---
 airflow/providers/common/sql/hooks/sql.py    | 16 ++++-
 airflow/providers/odbc/hooks/odbc.py         | 16 -----
 tests/providers/common/sql/hooks/test_sql.py | 75 ++++++++++++-----------
 tests/providers/common/sql/test_utils.py     | 55 +++++++++++++++++
 tests/providers/odbc/hooks/test_odbc.py      | 92 ++++++++++------------------
 5 files changed, 138 insertions(+), 116 deletions(-)
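
In practice, this change lets any DbApiHook subclass pick up its SQL parameter placeholder from the connection's `extra` field (editable in the UI), rather than only OdbcHook supporting it. A rough usage sketch, assuming the JDBC provider is installed and using an illustrative connection id `my_jdbc_conn` registered through an environment variable (both are assumptions, not part of this commit):

    # Illustrative sketch only; the connection id and env-var registration are assumptions.
    import json
    import os

    from airflow.providers.jdbc.hooks.jdbc import JdbcHook  # requires apache-airflow-providers-jdbc

    # Register a connection whose extra carries the desired placeholder ("%s" or "?").
    os.environ["AIRFLOW_CONN_MY_JDBC_CONN"] = json.dumps(
        {"conn_type": "jdbc", "host": "host", "extra": {"placeholder": "?"}}
    )

    hook = JdbcHook(jdbc_conn_id="my_jdbc_conn")
    # The placeholder is resolved from the connection extra, so insert_rows()
    # renders statements such as "INSERT INTO tbl (a, b) VALUES (?, ?)".
    assert hook.placeholder == "?"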

diff --git a/airflow/providers/common/sql/hooks/sql.py b/airflow/providers/common/sql/hooks/sql.py
index 7c2480bb6a..3f324e4f69 100644
--- a/airflow/providers/common/sql/hooks/sql.py
+++ b/airflow/providers/common/sql/hooks/sql.py
@@ -50,6 +50,7 @@ if TYPE_CHECKING:
     from airflow.providers.openlineage.sqlparser import DatabaseInfo
 
 T = TypeVar("T")
+SQL_PLACEHOLDERS = frozenset({"%s", "?"})
 
 
 def return_single_query_results(sql: str | Iterable[str], return_last: bool, split_statements: bool):
@@ -146,6 +147,8 @@ class DbApiHook(BaseHook):
     connector: ConnectorProtocol | None = None
     # Override with db-specific query to check connection
     _test_connection_sql = "select 1"
+    # Default SQL placeholder
+    _placeholder: str = "%s"
 
     def __init__(self, *args, schema: str | None = None, log_sql: bool = True, **kwargs):
         super().__init__()
@@ -164,7 +167,6 @@ class DbApiHook(BaseHook):
         self.__schema = schema
         self.log_sql = log_sql
         self.descriptions: list[Sequence[Sequence] | None] = []
-        self._placeholder: str = "%s"
         self._insert_statement_format: str = kwargs.get(
             "insert_statement_format", "INSERT INTO {} {} VALUES ({})"
         )
@@ -173,7 +175,17 @@ class DbApiHook(BaseHook):
         )
 
     @property
-    def placeholder(self) -> str:
+    def placeholder(self):
+        conn = self.get_connection(getattr(self, self.conn_name_attr))
+        placeholder = conn.extra_dejson.get("placeholder")
+        if placeholder in SQL_PLACEHOLDERS:
+            return placeholder
+        self.log.warning(
+            "Placeholder defined in Connection '%s' is not listed in 
'DEFAULT_SQL_PLACEHOLDERS' "
+            "and got ignored. Falling back to the default placeholder '%s'.",
+            placeholder,
+            self._placeholder,
+        )
         return self._placeholder
 
     def get_conn(self):
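
The property accepts only values listed in SQL_PLACEHOLDERS and otherwise logs a warning and falls back to the class-level default. A behaviour sketch in plain Python (the helper function below is illustrative, not part of the hook):

    # Illustrative stand-in for DbApiHook.placeholder's lookup logic.
    SQL_PLACEHOLDERS = frozenset({"%s", "?"})

    def resolve_placeholder(extra: dict, default: str = "%s") -> str:
        candidate = extra.get("placeholder")
        if candidate in SQL_PLACEHOLDERS:
            return candidate
        # Unsupported values (e.g. ":name") are ignored in favour of the default.
        return default

    assert resolve_placeholder({"placeholder": "?"}) == "?"
    assert resolve_placeholder({"placeholder": ":name"}) == "%s"
    assert resolve_placeholder({}) == "%s"
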
diff --git a/airflow/providers/odbc/hooks/odbc.py b/airflow/providers/odbc/hooks/odbc.py
index a14e64d6df..8cf95bf095 100644
--- a/airflow/providers/odbc/hooks/odbc.py
+++ b/airflow/providers/odbc/hooks/odbc.py
@@ -27,8 +27,6 @@ from pyodbc import Connection, Row, connect
 from airflow.providers.common.sql.hooks.sql import DbApiHook
 from airflow.utils.helpers import merge_dicts
 
-DEFAULT_ODBC_PLACEHOLDERS = frozenset({"%s", "?"})
-
 
 class OdbcHook(DbApiHook):
     """
@@ -202,20 +200,6 @@ class OdbcHook(DbApiHook):
         conn = connect(self.odbc_connection_string, **self.connect_kwargs)
         return conn
 
-    @property
-    def placeholder(self):
-        placeholder = self.connection.extra_dejson.get("placeholder")
-        if placeholder in DEFAULT_ODBC_PLACEHOLDERS:
-            return placeholder
-        else:
-            self.log.warning(
-                "Placeholder defined in Connection '%s' is not listed in 
'DEFAULT_ODBC_PLACEHOLDERS' "
-                "and got ignored. Falling back to the default placeholder 
'%s'.",
-                placeholder,
-                self._placeholder,
-            )
-            return self._placeholder
-
     def get_uri(self) -> str:
         """URI invoked in 
:meth:`~airflow.providers.common.sql.hooks.sql.DbApiHook.get_sqlalchemy_engine`."""
         quoted_conn_str = quote_plus(self.odbc_connection_string)
diff --git a/tests/providers/common/sql/hooks/test_sql.py b/tests/providers/common/sql/hooks/test_sql.py
index e0a7b0fccb..4bd5bdcc54 100644
--- a/tests/providers/common/sql/hooks/test_sql.py
+++ b/tests/providers/common/sql/hooks/test_sql.py
@@ -19,7 +19,6 @@
 from __future__ import annotations
 
 import warnings
-from typing import Any
 from unittest.mock import MagicMock
 
 import pytest
@@ -28,6 +27,7 @@ from airflow.exceptions import AirflowProviderDeprecationWarning
 from airflow.models import Connection
 from airflow.providers.common.sql.hooks.sql import DbApiHook, fetch_all_handler
 from airflow.utils.session import provide_session
+from tests.providers.common.sql.test_utils import mock_hook
 
 TASK_ID = "sql-operator"
 HOST = "host"
@@ -214,39 +214,40 @@ def test_query(
     dbapi_hook.get_conn.return_value.cursor.return_value.close.assert_called()
 
 
-@pytest.mark.db_test
-@pytest.mark.parametrize(
-    "empty_statement",
-    [
-        pytest.param([], id="Empty list"),
-        pytest.param("", id="Empty string"),
-        pytest.param("\n", id="Only EOL"),
-    ],
-)
-def test_no_query(empty_statement):
-    dbapi_hook = DBApiHookForTests()
-    dbapi_hook.get_conn.return_value.cursor.rowcount = 0
-    with pytest.raises(ValueError) as err:
-        dbapi_hook.run(sql=empty_statement)
-    assert err.value.args[0] == "List of SQL statements is empty"
-
-
-@pytest.mark.db_test
-def test_make_common_data_structure_hook_has_deprecated_method():
-    """If hook implements ``_make_serializable`` warning should be raised on 
call."""
-
-    class DBApiHookForMakeSerializableTests(DBApiHookForTests):
-        def _make_serializable(self, result: Any):
-            return result
-
-    hook = DBApiHookForMakeSerializableTests()
-    with pytest.warns(AirflowProviderDeprecationWarning, match="`_make_serializable` method is deprecated"):
-        hook._make_common_data_structure(["foo", "bar", "baz"])
-
-
-@pytest.mark.db_test
-def test_make_common_data_structure_no_deprecated_method():
-    """If hook not implements ``_make_serializable`` there is no warning 
should be raised on call."""
-    with warnings.catch_warnings():
-        warnings.simplefilter("error", AirflowProviderDeprecationWarning)
-        DBApiHookForTests()._make_common_data_structure(["foo", "bar", "baz"])
+class TestDbApiHook:
+    @pytest.mark.db_test
+    @pytest.mark.parametrize(
+        "empty_statement",
+        [
+            pytest.param([], id="Empty list"),
+            pytest.param("", id="Empty string"),
+            pytest.param("\n", id="Only EOL"),
+        ],
+    )
+    def test_no_query(self, empty_statement):
+        dbapi_hook = mock_hook(DbApiHook)
+        with pytest.raises(ValueError) as err:
+            dbapi_hook.run(sql=empty_statement)
+        assert err.value.args[0] == "List of SQL statements is empty"
+
+    @pytest.mark.db_test
+    def test_make_common_data_structure_hook_has_deprecated_method(self):
+        """If hook implements ``_make_serializable`` warning should be raised 
on call."""
+        hook = mock_hook(DbApiHook)
+        hook._make_serializable = lambda result: result
+        with pytest.warns(
+            AirflowProviderDeprecationWarning, match="`_make_serializable` method is deprecated"
+        ):
+            hook._make_common_data_structure(["foo", "bar", "baz"])
+
+    @pytest.mark.db_test
+    def test_make_common_data_structure_no_deprecated_method(self):
+        """If hook not implements ``_make_serializable`` there is no warning 
should be raised on call."""
+        with warnings.catch_warnings():
+            warnings.simplefilter("error", AirflowProviderDeprecationWarning)
+            mock_hook(DbApiHook)._make_common_data_structure(["foo", "bar", "baz"])
+
+    @pytest.mark.db_test
+    def test_placeholder_config_from_extra(self):
+        dbapi_hook = mock_hook(DbApiHook, conn_params={"extra": {"placeholder": "?"}})
+        assert dbapi_hook.placeholder == "?"
diff --git a/tests/providers/common/sql/test_utils.py b/tests/providers/common/sql/test_utils.py
new file mode 100644
index 0000000000..c3bc4c3565
--- /dev/null
+++ b/tests/providers/common/sql/test_utils.py
@@ -0,0 +1,55 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+from unittest import mock
+
+from airflow.models import Connection
+
+if TYPE_CHECKING:
+    from airflow.hooks.base import BaseHook
+
+
+def mock_hook(hook_class: type[BaseHook], hook_params=None, conn_params=None):
+    hook_params = hook_params or {}
+    conn_params = conn_params or {}
+    connection = Connection(
+        **{
+            **dict(login="login", password="password", host="host", 
schema="schema", port=1234),
+            **conn_params,
+        }
+    )
+
+    cursor = mock.MagicMock(
+        rowcount=0, spec=["description", "rowcount", "execute", "fetchall", "fetchone", "close"]
+    )
+    conn = mock.MagicMock()
+    conn.cursor.return_value = cursor
+
+    class MockedHook(hook_class):
+        conn_name_attr = "test_conn_id"
+
+        @classmethod
+        def get_connection(cls, conn_id: str):
+            return connection
+
+        def get_conn(self):
+            return conn
+
+    return MockedHook(**hook_params)
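
The mock_hook helper wraps the given hook class with a canned Connection and a MagicMock cursor, so hook behaviour can be tested without a metadata database; a minimal usage sketch (the extra value is just an example):

    from airflow.providers.common.sql.hooks.sql import DbApiHook
    from tests.providers.common.sql.test_utils import mock_hook

    # The mocked connection supplies the extra, so the placeholder property and
    # run() can be exercised entirely against MagicMock objects.
    hook = mock_hook(DbApiHook, conn_params={"extra": {"placeholder": "?"}})
    assert hook.placeholder == "?"
    hook.run("SELECT 1")  # executes against the mocked cursor; returns None without a handler
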
diff --git a/tests/providers/odbc/hooks/test_odbc.py b/tests/providers/odbc/hooks/test_odbc.py
index 64683d21e1..bddd2ffd99 100644
--- a/tests/providers/odbc/hooks/test_odbc.py
+++ b/tests/providers/odbc/hooks/test_odbc.py
@@ -27,8 +27,8 @@ from urllib.parse import quote_plus, urlsplit
 import pyodbc
 import pytest
 
-from airflow.models import Connection
 from airflow.providers.odbc.hooks.odbc import OdbcHook
+from tests.providers.common.sql.test_utils import mock_hook
 
 
 @pytest.fixture
@@ -77,38 +77,10 @@ def pyodbc_instancecheck():
 
 
 class TestOdbcHook:
-    def get_hook(self=None, hook_params=None, conn_params=None):
-        hook_params = hook_params or {}
-        conn_params = conn_params or {}
-        connection = Connection(
-            **{
-                **dict(login="login", password="password", host="host", 
schema="schema", port=1234),
-                **conn_params,
-            }
-        )
-
-        cursor = mock.MagicMock(
-            rowcount=0, spec=["description", "rowcount", "execute", "fetchall", "fetchone", "close"]
-        )
-        conn = mock.MagicMock()
-        conn.cursor.return_value = cursor
-
-        class UnitTestOdbcHook(OdbcHook):
-            conn_name_attr = "test_conn_id"
-
-            @classmethod
-            def get_connection(cls, conn_id: str):
-                return connection
-
-            def get_conn(self):
-                return conn
-
-        return UnitTestOdbcHook(**hook_params)
-
     def test_driver_in_extra_not_used(self):
         conn_params = dict(extra=json.dumps(dict(Driver="Fake Driver", Fake_Param="Fake Param")))
         hook_params = {"driver": "ParamDriver"}
-        hook = self.get_hook(conn_params=conn_params, hook_params=hook_params)
+        hook = mock_hook(OdbcHook, conn_params=conn_params, hook_params=hook_params)
         expected = (
             "DRIVER={ParamDriver};"
             "SERVER=host;"
@@ -123,7 +95,7 @@ class TestOdbcHook:
     def test_driver_in_both(self):
         conn_params = dict(extra=json.dumps(dict(Driver="Fake Driver", Fake_Param="Fake Param")))
         hook_params = dict(driver="ParamDriver")
-        hook = self.get_hook(hook_params=hook_params, conn_params=conn_params)
+        hook = mock_hook(OdbcHook, conn_params=conn_params, hook_params=hook_params)
         expected = (
             "DRIVER={ParamDriver};"
             "SERVER=host;"
@@ -137,7 +109,7 @@ class TestOdbcHook:
 
     def test_dsn_in_extra(self):
         conn_params = dict(extra=json.dumps(dict(DSN="MyDSN", Fake_Param="Fake 
Param")))
-        hook = self.get_hook(conn_params=conn_params)
+        hook = mock_hook(OdbcHook, conn_params=conn_params)
         expected = (
             "DSN=MyDSN;SERVER=host;DATABASE=schema;UID=login;PWD=password;PORT=1234;Fake_Param=Fake Param;"
         )
@@ -146,7 +118,7 @@ class TestOdbcHook:
     def test_dsn_in_both(self):
         conn_params = dict(extra=json.dumps(dict(DSN="MyDSN", Fake_Param="Fake 
Param")))
         hook_params = dict(driver="ParamDriver", dsn="ParamDSN")
-        hook = self.get_hook(hook_params=hook_params, conn_params=conn_params)
+        hook = mock_hook(OdbcHook, conn_params=conn_params, hook_params=hook_params)
         expected = (
             "DRIVER={ParamDriver};"
             "DSN=ParamDSN;"
@@ -162,7 +134,7 @@ class TestOdbcHook:
     def test_get_uri(self):
         conn_params = dict(extra=json.dumps(dict(DSN="MyDSN", Fake_Param="Fake 
Param")))
         hook_params = dict(dsn="ParamDSN")
-        hook = self.get_hook(hook_params=hook_params, conn_params=conn_params)
+        hook = mock_hook(OdbcHook, conn_params=conn_params, hook_params=hook_params)
         uri_param = quote_plus(
             "DSN=ParamDSN;SERVER=host;DATABASE=schema;UID=login;PWD=password;PORT=1234;Fake_Param=Fake Param;"
         )
@@ -170,7 +142,8 @@ class TestOdbcHook:
         assert hook.get_uri() == expected
 
     def test_connect_kwargs_from_hook(self):
-        hook = self.get_hook(
+        hook = mock_hook(
+            OdbcHook,
             hook_params=dict(
                 connect_kwargs={
                     "attrs_before": {
@@ -202,7 +175,7 @@ class TestOdbcHook:
             )
         )
 
-        hook = self.get_hook(conn_params=dict(extra=extra))
+        hook = mock_hook(OdbcHook, conn_params=dict(extra=extra))
         assert hook.connect_kwargs == {
             "attrs_before": {1: 2, pyodbc.SQL_TXN_ISOLATION: 
pyodbc.SQL_TXN_READ_UNCOMMITTED},
             "readonly": True,
@@ -219,7 +192,7 @@ class TestOdbcHook:
             connect_kwargs={"attrs_before": {3: 5, pyodbc.SQL_TXN_ISOLATION: 
0}, "readonly": True}
         )
 
-        hook = self.get_hook(conn_params=dict(extra=conn_extra), hook_params=hook_params)
+        hook = mock_hook(OdbcHook, conn_params=dict(extra=conn_extra), hook_params=hook_params)
         assert hook.connect_kwargs == {
             "attrs_before": {1: 2, 3: 5, pyodbc.SQL_TXN_ISOLATION: 0},
             "readonly": True,
@@ -230,74 +203,71 @@ class TestOdbcHook:
         Bools will be parsed from uri as strings
         """
         conn_extra = json.dumps(dict(connect_kwargs={"ansi": True}))
-        hook = self.get_hook(conn_params=dict(extra=conn_extra))
+        hook = mock_hook(OdbcHook, conn_params=dict(extra=conn_extra))
         assert hook.connect_kwargs == {
             "ansi": True,
         }
 
     def test_driver(self):
-        hook = self.get_hook(hook_params=dict(driver="Blah driver"))
+        hook = mock_hook(OdbcHook, hook_params=dict(driver="Blah driver"))
         assert hook.driver == "Blah driver"
-        hook = self.get_hook(hook_params=dict(driver="{Blah driver}"))
+        hook = mock_hook(OdbcHook, hook_params=dict(driver="{Blah driver}"))
         assert hook.driver == "Blah driver"
 
     def test_driver_extra_raises_warning_by_default(self, caplog):
         with caplog.at_level(logging.WARNING, logger="airflow.providers.odbc.hooks.test_odbc"):
-            driver = self.get_hook(conn_params=dict(extra='{"driver": "Blah driver"}')).driver
+            driver = mock_hook(OdbcHook, conn_params=dict(extra='{"driver": "Blah driver"}')).driver
             assert "You have supplied 'driver' via connection extra but it will not be used" in caplog.text
             assert driver is None
 
     @mock.patch.dict("os.environ", 
{"AIRFLOW__PROVIDERS_ODBC__ALLOW_DRIVER_IN_EXTRA": "TRUE"})
     def test_driver_extra_works_when_allow_driver_extra(self):
-        hook = self.get_hook(
-            conn_params=dict(extra='{"driver": "Blah driver"}'), 
hook_params=dict(allow_driver_extra=True)
+        hook = mock_hook(
+            OdbcHook,
+            conn_params=dict(extra='{"driver": "Blah driver"}'),
+            hook_params=dict(allow_driver_extra=True),
         )
         assert hook.driver == "Blah driver"
 
     def test_default_driver_set(self):
         with patch.object(OdbcHook, "default_driver", "Blah driver"):
-            hook = self.get_hook()
+            hook = mock_hook(OdbcHook)
             assert hook.driver == "Blah driver"
 
     def test_driver_extra_works_when_default_driver_set(self):
         with patch.object(OdbcHook, "default_driver", "Blah driver"):
-            hook = self.get_hook()
+            hook = mock_hook(OdbcHook)
             assert hook.driver == "Blah driver"
 
     def test_driver_none_by_default(self):
-        hook = self.get_hook()
+        hook = mock_hook(OdbcHook)
         assert hook.driver is None
 
     def test_driver_extra_raises_warning_and_returns_default_driver_by_default(self, caplog):
         with patch.object(OdbcHook, "default_driver", "Blah driver"):
             with caplog.at_level(logging.WARNING, logger="airflow.providers.odbc.hooks.test_odbc"):
-                driver = self.get_hook(conn_params=dict(extra='{"driver": 
"Blah driver2"}')).driver
+                driver = mock_hook(OdbcHook, 
conn_params=dict(extra='{"driver": "Blah driver2"}')).driver
                 assert "have supplied 'driver' via connection extra but it 
will not be used" in caplog.text
                 assert driver == "Blah driver"
 
-    def test_placeholder_config_from_extra(self):
-        conn_params = dict(extra=json.dumps(dict(placeholder="?")))
-        hook = self.get_hook(conn_params=conn_params)
-        assert hook.placeholder == "?"
-
     def test_database(self):
-        hook = self.get_hook(hook_params=dict(database="abc"))
+        hook = mock_hook(OdbcHook, hook_params=dict(database="abc"))
         assert hook.database == "abc"
-        hook = self.get_hook()
+        hook = mock_hook(OdbcHook)
         assert hook.database == "schema"
 
     def test_sqlalchemy_scheme_default(self):
-        hook = self.get_hook()
+        hook = mock_hook(OdbcHook)
         uri = hook.get_uri()
         assert urlsplit(uri).scheme == "mssql+pyodbc"
 
     def test_sqlalchemy_scheme_param(self):
-        hook = self.get_hook(hook_params=dict(sqlalchemy_scheme="my-scheme"))
+        hook = mock_hook(OdbcHook, hook_params=dict(sqlalchemy_scheme="my-scheme"))
         uri = hook.get_uri()
         assert urlsplit(uri).scheme == "my-scheme"
 
     def test_sqlalchemy_scheme_extra(self):
-        hook = self.get_hook(conn_params=dict(extra=json.dumps(dict(sqlalchemy_scheme="my-scheme"))))
+        hook = mock_hook(OdbcHook, conn_params=dict(extra=json.dumps(dict(sqlalchemy_scheme="my-scheme"))))
         uri = hook.get_uri()
         assert urlsplit(uri).scheme == "my-scheme"
 
@@ -323,7 +293,7 @@ class TestOdbcHook:
         def mock_handler(*_):
             return pyodbc_result
 
-        hook = self.get_hook()
+        hook = mock_hook(OdbcHook)
         with monkeypatch.context() as patcher:
             patcher.setattr("pyodbc.Row", pyodbc_instancecheck)
             result = hook.run("SQL", handler=mock_handler)
@@ -340,7 +310,7 @@ class TestOdbcHook:
         def mock_handler(*_):
             return pyodbc_result
 
-        hook = self.get_hook()
+        hook = mock_hook(OdbcHook)
         with monkeypatch.context() as patcher:
             patcher.setattr("pyodbc.Row", pyodbc_instancecheck)
             result = hook.run("SQL", handler=mock_handler)
@@ -359,13 +329,13 @@ class TestOdbcHook:
         def mock_handler(*_):
             return pyodbc_result
 
-        hook = self.get_hook()
+        hook = mock_hook(OdbcHook)
         with monkeypatch.context() as patcher:
             patcher.setattr("pyodbc.Row", pyodbc_instancecheck)
             result = hook.run("SQL", handler=mock_handler)
         assert hook_result == result
 
     def test_query_no_handler_return_none(self):
-        hook = self.get_hook()
+        hook = mock_hook(OdbcHook)
         result = hook.run("SQL")
         assert result is None
