This is an automated email from the ASF dual-hosted git repository.
eladkal pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 0b1308ce45 Make `placeholder` of DbApiHook configurable in UI (#38528)
0b1308ce45 is described below
commit 0b1308ce455ed9a6c0263ae63aa0166fed278453
Author: David Blain <[email protected]>
AuthorDate: Sat Apr 6 09:10:53 2024 +0200
Make `placeholder` of DbApiHook configurable in UI (#38528)
* refactor: Moved placeholder property from OdbcHook class to parent
DbApiHook class so that the same logic can also be used with the JdbcHook
* refactor: Import BaseHook under type checking block
* refactor: Marked test_placeholder_config_from_extra as a db test
* refactor: Moved mock_conn from conftest to test_utils module under common
sql
* refactor: Removed unnecessary else statement in placeholder property
* refactor: Default placeholder can be a class/static variable as its only
purpose is to define a default SQL placeholder; the actual placeholder will
always be retrieved through the property
* refactor: Updated sql test with changes from main
* refactor: Reformatted test
* Update airflow/providers/common/sql/hooks/sql.py
Co-authored-by: Elad Kalif <[email protected]>
* fix: Fixed name of constant SQL_PLACEHOLDERS being checked in placeholder
property
---------
Co-authored-by: David Blain <[email protected]>
Co-authored-by: Elad Kalif <[email protected]>
---
airflow/providers/common/sql/hooks/sql.py | 16 ++++-
airflow/providers/odbc/hooks/odbc.py | 16 -----
tests/providers/common/sql/hooks/test_sql.py | 75 ++++++++++++-----------
tests/providers/common/sql/test_utils.py | 55 +++++++++++++++++
tests/providers/odbc/hooks/test_odbc.py | 92 ++++++++++------------------
5 files changed, 138 insertions(+), 116 deletions(-)
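In practice this means the bind-parameter placeholder is now read from the connection's Extra field (the "UI" part of the title) for every DbApiHook subclass, not just OdbcHook. Below is a minimal sketch of what that looks like from user code; the connection id and hook subclass are illustrative assumptions and not part of this commit:

from airflow.models import Connection
from airflow.providers.common.sql.hooks.sql import DbApiHook

# Hypothetical connection: the "placeholder" key in Extra selects qmark style
# ("?") instead of the default "%s".
_conn = Connection(conn_id="my_sql_conn", extra='{"placeholder": "?"}')


class MyHook(DbApiHook):
    # Hypothetical subclass used only for this sketch.
    conn_name_attr = "my_conn_id"

    @classmethod
    def get_connection(cls, conn_id):
        # Normally resolved from the metadata DB or a secrets backend; stubbed here.
        return _conn


hook = MyHook(my_conn_id="my_sql_conn")
assert hook.placeholder == "?"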
diff --git a/airflow/providers/common/sql/hooks/sql.py b/airflow/providers/common/sql/hooks/sql.py
index 7c2480bb6a..3f324e4f69 100644
--- a/airflow/providers/common/sql/hooks/sql.py
+++ b/airflow/providers/common/sql/hooks/sql.py
@@ -50,6 +50,7 @@ if TYPE_CHECKING:
from airflow.providers.openlineage.sqlparser import DatabaseInfo
T = TypeVar("T")
+SQL_PLACEHOLDERS = frozenset({"%s", "?"})
def return_single_query_results(sql: str | Iterable[str], return_last: bool, split_statements: bool):
@@ -146,6 +147,8 @@ class DbApiHook(BaseHook):
connector: ConnectorProtocol | None = None
# Override with db-specific query to check connection
_test_connection_sql = "select 1"
+ # Default SQL placeholder
+ _placeholder: str = "%s"
def __init__(self, *args, schema: str | None = None, log_sql: bool = True, **kwargs):
super().__init__()
@@ -164,7 +167,6 @@ class DbApiHook(BaseHook):
self.__schema = schema
self.log_sql = log_sql
self.descriptions: list[Sequence[Sequence] | None] = []
- self._placeholder: str = "%s"
self._insert_statement_format: str = kwargs.get(
"insert_statement_format", "INSERT INTO {} {} VALUES ({})"
)
@@ -173,7 +175,17 @@ class DbApiHook(BaseHook):
)
@property
- def placeholder(self) -> str:
+ def placeholder(self):
+ conn = self.get_connection(getattr(self, self.conn_name_attr))
+ placeholder = conn.extra_dejson.get("placeholder")
+ if placeholder in SQL_PLACEHOLDERS:
+ return placeholder
+ self.log.warning(
+ "Placeholder defined in Connection '%s' is not listed in
'DEFAULT_SQL_PLACEHOLDERS' "
+ "and got ignored. Falling back to the default placeholder '%s'.",
+ placeholder,
+ self._placeholder,
+ )
return self._placeholder
def get_conn(self):
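If the configured value is not one of the supported placeholders ("%s" or "?"), the property logs the warning above and falls back to the class-level default. A hedged, self-contained sketch of that branch (the hook subclass and connection id are again illustrative):

from airflow.models import Connection
from airflow.providers.common.sql.hooks.sql import DbApiHook

# Hypothetical connection asking for an unsupported (named) placeholder style.
_conn = Connection(conn_id="other_conn", extra='{"placeholder": ":name"}')


class NamedStyleHook(DbApiHook):
    # Hypothetical subclass used only for this sketch.
    conn_name_attr = "other_conn_id"

    @classmethod
    def get_connection(cls, conn_id):
        return _conn


hook = NamedStyleHook(other_conn_id="other_conn")
# ":name" is not in SQL_PLACEHOLDERS, so the warning fires and "%s" is returned.
assert hook.placeholder == "%s"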
diff --git a/airflow/providers/odbc/hooks/odbc.py b/airflow/providers/odbc/hooks/odbc.py
index a14e64d6df..8cf95bf095 100644
--- a/airflow/providers/odbc/hooks/odbc.py
+++ b/airflow/providers/odbc/hooks/odbc.py
@@ -27,8 +27,6 @@ from pyodbc import Connection, Row, connect
from airflow.providers.common.sql.hooks.sql import DbApiHook
from airflow.utils.helpers import merge_dicts
-DEFAULT_ODBC_PLACEHOLDERS = frozenset({"%s", "?"})
-
class OdbcHook(DbApiHook):
"""
@@ -202,20 +200,6 @@ class OdbcHook(DbApiHook):
conn = connect(self.odbc_connection_string, **self.connect_kwargs)
return conn
- @property
- def placeholder(self):
- placeholder = self.connection.extra_dejson.get("placeholder")
- if placeholder in DEFAULT_ODBC_PLACEHOLDERS:
- return placeholder
- else:
- self.log.warning(
- "Placeholder defined in Connection '%s' is not listed in
'DEFAULT_ODBC_PLACEHOLDERS' "
- "and got ignored. Falling back to the default placeholder
'%s'.",
- placeholder,
- self._placeholder,
- )
- return self._placeholder
-
def get_uri(self) -> str:
"""URI invoked in
:meth:`~airflow.providers.common.sql.hooks.sql.DbApiHook.get_sqlalchemy_engine`."""
quoted_conn_str = quote_plus(self.odbc_connection_string)
diff --git a/tests/providers/common/sql/hooks/test_sql.py b/tests/providers/common/sql/hooks/test_sql.py
index e0a7b0fccb..4bd5bdcc54 100644
--- a/tests/providers/common/sql/hooks/test_sql.py
+++ b/tests/providers/common/sql/hooks/test_sql.py
@@ -19,7 +19,6 @@
from __future__ import annotations
import warnings
-from typing import Any
from unittest.mock import MagicMock
import pytest
@@ -28,6 +27,7 @@ from airflow.exceptions import AirflowProviderDeprecationWarning
from airflow.models import Connection
from airflow.providers.common.sql.hooks.sql import DbApiHook, fetch_all_handler
from airflow.utils.session import provide_session
+from tests.providers.common.sql.test_utils import mock_hook
TASK_ID = "sql-operator"
HOST = "host"
@@ -214,39 +214,40 @@ def test_query(
dbapi_hook.get_conn.return_value.cursor.return_value.close.assert_called()
-@pytest.mark.db_test
-@pytest.mark.parametrize(
- "empty_statement",
- [
- pytest.param([], id="Empty list"),
- pytest.param("", id="Empty string"),
- pytest.param("\n", id="Only EOL"),
- ],
-)
-def test_no_query(empty_statement):
- dbapi_hook = DBApiHookForTests()
- dbapi_hook.get_conn.return_value.cursor.rowcount = 0
- with pytest.raises(ValueError) as err:
- dbapi_hook.run(sql=empty_statement)
- assert err.value.args[0] == "List of SQL statements is empty"
-
-
-@pytest.mark.db_test
-def test_make_common_data_structure_hook_has_deprecated_method():
- """If hook implements ``_make_serializable`` warning should be raised on
call."""
-
- class DBApiHookForMakeSerializableTests(DBApiHookForTests):
- def _make_serializable(self, result: Any):
- return result
-
- hook = DBApiHookForMakeSerializableTests()
- with pytest.warns(AirflowProviderDeprecationWarning, match="`_make_serializable` method is deprecated"):
- hook._make_common_data_structure(["foo", "bar", "baz"])
-
-
-@pytest.mark.db_test
-def test_make_common_data_structure_no_deprecated_method():
- """If hook not implements ``_make_serializable`` there is no warning
should be raised on call."""
- with warnings.catch_warnings():
- warnings.simplefilter("error", AirflowProviderDeprecationWarning)
- DBApiHookForTests()._make_common_data_structure(["foo", "bar", "baz"])
+class TestDbApiHook:
+ @pytest.mark.db_test
+ @pytest.mark.parametrize(
+ "empty_statement",
+ [
+ pytest.param([], id="Empty list"),
+ pytest.param("", id="Empty string"),
+ pytest.param("\n", id="Only EOL"),
+ ],
+ )
+ def test_no_query(self, empty_statement):
+ dbapi_hook = mock_hook(DbApiHook)
+ with pytest.raises(ValueError) as err:
+ dbapi_hook.run(sql=empty_statement)
+ assert err.value.args[0] == "List of SQL statements is empty"
+
+ @pytest.mark.db_test
+ def test_make_common_data_structure_hook_has_deprecated_method(self):
+ """If hook implements ``_make_serializable`` warning should be raised
on call."""
+ hook = mock_hook(DbApiHook)
+ hook._make_serializable = lambda result: result
+ with pytest.warns(
+ AirflowProviderDeprecationWarning, match="`_make_serializable` method is deprecated"
+ ):
+ hook._make_common_data_structure(["foo", "bar", "baz"])
+
+ @pytest.mark.db_test
+ def test_make_common_data_structure_no_deprecated_method(self):
+ """If hook not implements ``_make_serializable`` there is no warning
should be raised on call."""
+ with warnings.catch_warnings():
+ warnings.simplefilter("error", AirflowProviderDeprecationWarning)
+ mock_hook(DbApiHook)._make_common_data_structure(["foo", "bar", "baz"])
+
+ @pytest.mark.db_test
+ def test_placeholder_config_from_extra(self):
+ dbapi_hook = mock_hook(DbApiHook, conn_params={"extra": {"placeholder": "?"}})
+ assert dbapi_hook.placeholder == "?"
diff --git a/tests/providers/common/sql/test_utils.py b/tests/providers/common/sql/test_utils.py
new file mode 100644
index 0000000000..c3bc4c3565
--- /dev/null
+++ b/tests/providers/common/sql/test_utils.py
@@ -0,0 +1,55 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+from unittest import mock
+
+from airflow.models import Connection
+
+if TYPE_CHECKING:
+ from airflow.hooks.base import BaseHook
+
+
+def mock_hook(hook_class: type[BaseHook], hook_params=None, conn_params=None):
+ hook_params = hook_params or {}
+ conn_params = conn_params or {}
+ connection = Connection(
+ **{
+ **dict(login="login", password="password", host="host",
schema="schema", port=1234),
+ **conn_params,
+ }
+ )
+
+ cursor = mock.MagicMock(
+ rowcount=0, spec=["description", "rowcount", "execute", "fetchall",
"fetchone", "close"]
+ )
+ conn = mock.MagicMock()
+ conn.cursor.return_value = cursor
+
+ class MockedHook(hook_class):
+ conn_name_attr = "test_conn_id"
+
+ @classmethod
+ def get_connection(cls, conn_id: str):
+ return connection
+
+ def get_conn(self):
+ return conn
+
+ return MockedHook(**hook_params)
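This helper generalizes the former TestOdbcHook.get_hook so that any DbApiHook subclass can be wrapped with a stubbed connection and cursor. A usage sketch, runnable from the Airflow source tree where the tests package is importable (the JSON-string extra mirrors what the ODBC tests pass):

from airflow.providers.common.sql.hooks.sql import DbApiHook
from tests.providers.common.sql.test_utils import mock_hook

# Build a DbApiHook whose get_connection/get_conn are stubbed, then check that
# the placeholder configured in the connection extra is picked up.
hook = mock_hook(DbApiHook, conn_params={"extra": '{"placeholder": "?"}'})
assert hook.placeholder == "?"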
diff --git a/tests/providers/odbc/hooks/test_odbc.py b/tests/providers/odbc/hooks/test_odbc.py
index 64683d21e1..bddd2ffd99 100644
--- a/tests/providers/odbc/hooks/test_odbc.py
+++ b/tests/providers/odbc/hooks/test_odbc.py
@@ -27,8 +27,8 @@ from urllib.parse import quote_plus, urlsplit
import pyodbc
import pytest
-from airflow.models import Connection
from airflow.providers.odbc.hooks.odbc import OdbcHook
+from tests.providers.common.sql.test_utils import mock_hook
@pytest.fixture
@@ -77,38 +77,10 @@ def pyodbc_instancecheck():
class TestOdbcHook:
- def get_hook(self=None, hook_params=None, conn_params=None):
- hook_params = hook_params or {}
- conn_params = conn_params or {}
- connection = Connection(
- **{
- **dict(login="login", password="password", host="host",
schema="schema", port=1234),
- **conn_params,
- }
- )
-
- cursor = mock.MagicMock(
- rowcount=0, spec=["description", "rowcount", "execute",
"fetchall", "fetchone", "close"]
- )
- conn = mock.MagicMock()
- conn.cursor.return_value = cursor
-
- class UnitTestOdbcHook(OdbcHook):
- conn_name_attr = "test_conn_id"
-
- @classmethod
- def get_connection(cls, conn_id: str):
- return connection
-
- def get_conn(self):
- return conn
-
- return UnitTestOdbcHook(**hook_params)
-
def test_driver_in_extra_not_used(self):
conn_params = dict(extra=json.dumps(dict(Driver="Fake Driver",
Fake_Param="Fake Param")))
hook_params = {"driver": "ParamDriver"}
- hook = self.get_hook(conn_params=conn_params, hook_params=hook_params)
+ hook = mock_hook(OdbcHook, conn_params=conn_params, hook_params=hook_params)
expected = (
"DRIVER={ParamDriver};"
"SERVER=host;"
@@ -123,7 +95,7 @@ class TestOdbcHook:
def test_driver_in_both(self):
conn_params = dict(extra=json.dumps(dict(Driver="Fake Driver",
Fake_Param="Fake Param")))
hook_params = dict(driver="ParamDriver")
- hook = self.get_hook(hook_params=hook_params, conn_params=conn_params)
+ hook = mock_hook(OdbcHook, conn_params=conn_params, hook_params=hook_params)
expected = (
"DRIVER={ParamDriver};"
"SERVER=host;"
@@ -137,7 +109,7 @@ class TestOdbcHook:
def test_dsn_in_extra(self):
conn_params = dict(extra=json.dumps(dict(DSN="MyDSN", Fake_Param="Fake
Param")))
- hook = self.get_hook(conn_params=conn_params)
+ hook = mock_hook(OdbcHook, conn_params=conn_params)
expected = (
"DSN=MyDSN;SERVER=host;DATABASE=schema;UID=login;PWD=password;PORT=1234;Fake_Param=Fake
Param;"
)
@@ -146,7 +118,7 @@ class TestOdbcHook:
def test_dsn_in_both(self):
conn_params = dict(extra=json.dumps(dict(DSN="MyDSN", Fake_Param="Fake
Param")))
hook_params = dict(driver="ParamDriver", dsn="ParamDSN")
- hook = self.get_hook(hook_params=hook_params, conn_params=conn_params)
+ hook = mock_hook(OdbcHook, conn_params=conn_params, hook_params=hook_params)
expected = (
"DRIVER={ParamDriver};"
"DSN=ParamDSN;"
@@ -162,7 +134,7 @@ class TestOdbcHook:
def test_get_uri(self):
conn_params = dict(extra=json.dumps(dict(DSN="MyDSN", Fake_Param="Fake
Param")))
hook_params = dict(dsn="ParamDSN")
- hook = self.get_hook(hook_params=hook_params, conn_params=conn_params)
+ hook = mock_hook(OdbcHook, conn_params=conn_params, hook_params=hook_params)
uri_param = quote_plus(
"DSN=ParamDSN;SERVER=host;DATABASE=schema;UID=login;PWD=password;PORT=1234;Fake_Param=Fake
Param;"
)
@@ -170,7 +142,8 @@ class TestOdbcHook:
assert hook.get_uri() == expected
def test_connect_kwargs_from_hook(self):
- hook = self.get_hook(
+ hook = mock_hook(
+ OdbcHook,
hook_params=dict(
connect_kwargs={
"attrs_before": {
@@ -202,7 +175,7 @@ class TestOdbcHook:
)
)
- hook = self.get_hook(conn_params=dict(extra=extra))
+ hook = mock_hook(OdbcHook, conn_params=dict(extra=extra))
assert hook.connect_kwargs == {
"attrs_before": {1: 2, pyodbc.SQL_TXN_ISOLATION:
pyodbc.SQL_TXN_READ_UNCOMMITTED},
"readonly": True,
@@ -219,7 +192,7 @@ class TestOdbcHook:
connect_kwargs={"attrs_before": {3: 5, pyodbc.SQL_TXN_ISOLATION:
0}, "readonly": True}
)
- hook = self.get_hook(conn_params=dict(extra=conn_extra), hook_params=hook_params)
+ hook = mock_hook(OdbcHook, conn_params=dict(extra=conn_extra), hook_params=hook_params)
assert hook.connect_kwargs == {
"attrs_before": {1: 2, 3: 5, pyodbc.SQL_TXN_ISOLATION: 0},
"readonly": True,
@@ -230,74 +203,71 @@ class TestOdbcHook:
Bools will be parsed from uri as strings
"""
conn_extra = json.dumps(dict(connect_kwargs={"ansi": True}))
- hook = self.get_hook(conn_params=dict(extra=conn_extra))
+ hook = mock_hook(OdbcHook, conn_params=dict(extra=conn_extra))
assert hook.connect_kwargs == {
"ansi": True,
}
def test_driver(self):
- hook = self.get_hook(hook_params=dict(driver="Blah driver"))
+ hook = mock_hook(OdbcHook, hook_params=dict(driver="Blah driver"))
assert hook.driver == "Blah driver"
- hook = self.get_hook(hook_params=dict(driver="{Blah driver}"))
+ hook = mock_hook(OdbcHook, hook_params=dict(driver="{Blah driver}"))
assert hook.driver == "Blah driver"
def test_driver_extra_raises_warning_by_default(self, caplog):
with caplog.at_level(logging.WARNING, logger="airflow.providers.odbc.hooks.test_odbc"):
- driver = self.get_hook(conn_params=dict(extra='{"driver": "Blah driver"}')).driver
+ driver = mock_hook(OdbcHook, conn_params=dict(extra='{"driver": "Blah driver"}')).driver
assert "You have supplied 'driver' via connection extra but it
will not be used" in caplog.text
assert driver is None
@mock.patch.dict("os.environ",
{"AIRFLOW__PROVIDERS_ODBC__ALLOW_DRIVER_IN_EXTRA": "TRUE"})
def test_driver_extra_works_when_allow_driver_extra(self):
- hook = self.get_hook(
- conn_params=dict(extra='{"driver": "Blah driver"}'),
hook_params=dict(allow_driver_extra=True)
+ hook = mock_hook(
+ OdbcHook,
+ conn_params=dict(extra='{"driver": "Blah driver"}'),
+ hook_params=dict(allow_driver_extra=True),
)
assert hook.driver == "Blah driver"
def test_default_driver_set(self):
with patch.object(OdbcHook, "default_driver", "Blah driver"):
- hook = self.get_hook()
+ hook = mock_hook(OdbcHook)
assert hook.driver == "Blah driver"
def test_driver_extra_works_when_default_driver_set(self):
with patch.object(OdbcHook, "default_driver", "Blah driver"):
- hook = self.get_hook()
+ hook = mock_hook(OdbcHook)
assert hook.driver == "Blah driver"
def test_driver_none_by_default(self):
- hook = self.get_hook()
+ hook = mock_hook(OdbcHook)
assert hook.driver is None
def test_driver_extra_raises_warning_and_returns_default_driver_by_default(self, caplog):
with patch.object(OdbcHook, "default_driver", "Blah driver"):
with caplog.at_level(logging.WARNING, logger="airflow.providers.odbc.hooks.test_odbc"):
- driver = self.get_hook(conn_params=dict(extra='{"driver":
"Blah driver2"}')).driver
+ driver = mock_hook(OdbcHook,
conn_params=dict(extra='{"driver": "Blah driver2"}')).driver
assert "have supplied 'driver' via connection extra but it
will not be used" in caplog.text
assert driver == "Blah driver"
- def test_placeholder_config_from_extra(self):
- conn_params = dict(extra=json.dumps(dict(placeholder="?")))
- hook = self.get_hook(conn_params=conn_params)
- assert hook.placeholder == "?"
-
def test_database(self):
- hook = self.get_hook(hook_params=dict(database="abc"))
+ hook = mock_hook(OdbcHook, hook_params=dict(database="abc"))
assert hook.database == "abc"
- hook = self.get_hook()
+ hook = mock_hook(OdbcHook)
assert hook.database == "schema"
def test_sqlalchemy_scheme_default(self):
- hook = self.get_hook()
+ hook = mock_hook(OdbcHook)
uri = hook.get_uri()
assert urlsplit(uri).scheme == "mssql+pyodbc"
def test_sqlalchemy_scheme_param(self):
- hook = self.get_hook(hook_params=dict(sqlalchemy_scheme="my-scheme"))
+ hook = mock_hook(OdbcHook, hook_params=dict(sqlalchemy_scheme="my-scheme"))
uri = hook.get_uri()
assert urlsplit(uri).scheme == "my-scheme"
def test_sqlalchemy_scheme_extra(self):
- hook = self.get_hook(conn_params=dict(extra=json.dumps(dict(sqlalchemy_scheme="my-scheme"))))
+ hook = mock_hook(OdbcHook, conn_params=dict(extra=json.dumps(dict(sqlalchemy_scheme="my-scheme"))))
uri = hook.get_uri()
assert urlsplit(uri).scheme == "my-scheme"
@@ -323,7 +293,7 @@ class TestOdbcHook:
def mock_handler(*_):
return pyodbc_result
- hook = self.get_hook()
+ hook = mock_hook(OdbcHook)
with monkeypatch.context() as patcher:
patcher.setattr("pyodbc.Row", pyodbc_instancecheck)
result = hook.run("SQL", handler=mock_handler)
@@ -340,7 +310,7 @@ class TestOdbcHook:
def mock_handler(*_):
return pyodbc_result
- hook = self.get_hook()
+ hook = mock_hook(OdbcHook)
with monkeypatch.context() as patcher:
patcher.setattr("pyodbc.Row", pyodbc_instancecheck)
result = hook.run("SQL", handler=mock_handler)
@@ -359,13 +329,13 @@ class TestOdbcHook:
def mock_handler(*_):
return pyodbc_result
- hook = self.get_hook()
+ hook = mock_hook(OdbcHook)
with monkeypatch.context() as patcher:
patcher.setattr("pyodbc.Row", pyodbc_instancecheck)
result = hook.run("SQL", handler=mock_handler)
assert hook_result == result
def test_query_no_handler_return_none(self):
- hook = self.get_hook()
+ hook = mock_hook(OdbcHook)
result = hook.run("SQL")
assert result is None