This is an automated email from the ASF dual-hosted git repository.
taragolis pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 2a469b3713 Remove backcompat inheritance for DbApiHook (#35754)
2a469b3713 is described below
commit 2a469b3713d95ab15df8e9090abdb9d15e50cbb9
Author: Andrey Anshin <[email protected]>
AuthorDate: Tue Nov 21 11:12:56 2023 +0400
Remove backcompat inheritance for DbApiHook (#35754)
* Remove backcompat inheritance for DbApiHook
* jwt_file > jwt__file
* simplify trino test
---
airflow/providers/apache/impala/hooks/impala.py | 3 +-
airflow/providers/common/sql/hooks/sql.py | 18 +---------
airflow/providers/common/sql/hooks/sql.pyi | 4 +--
.../providers/elasticsearch/hooks/elasticsearch.py | 4 +--
airflow/providers/trino/hooks/trino.py | 19 ++++++++---
tests/providers/odbc/hooks/test_odbc.py | 3 +-
tests/providers/trino/hooks/test_trino.py | 38 +++++++++++++++-------
7 files changed, 49 insertions(+), 40 deletions(-)
diff --git a/airflow/providers/apache/impala/hooks/impala.py b/airflow/providers/apache/impala/hooks/impala.py
index ab19865a9e..b8c79b4e25 100644
--- a/airflow/providers/apache/impala/hooks/impala.py
+++ b/airflow/providers/apache/impala/hooks/impala.py
@@ -35,7 +35,8 @@ class ImpalaHook(DbApiHook):
hook_name = "Impala"
def get_conn(self) -> Connection:
- connection = self.get_connection(self.impala_conn_id) # pylint: disable=no-member
+ conn_id: str = getattr(self, self.conn_name_attr)
+ connection = self.get_connection(conn_id)
return connect(
host=connection.host,
port=connection.port,
diff --git a/airflow/providers/common/sql/hooks/sql.py b/airflow/providers/common/sql/hooks/sql.py
index ab4eda5d8e..bb85dedc1c 100644
--- a/airflow/providers/common/sql/hooks/sql.py
+++ b/airflow/providers/common/sql/hooks/sql.py
@@ -34,12 +34,10 @@ from typing import (
from urllib.parse import urlparse
import sqlparse
-from packaging.version import Version
from sqlalchemy import create_engine
from airflow.exceptions import AirflowException
from airflow.hooks.base import BaseHook
-from airflow.version import version
if TYPE_CHECKING:
from pandas import DataFrame
@@ -120,21 +118,7 @@ class ConnectorProtocol(Protocol):
"""
-# In case we are running it on Airflow 2.4+, we should use BaseHook, but on Airflow 2.3 and below
-# We want the DbApiHook to derive from the original DbApiHook from airflow, because otherwise
-# SqlSensor and BaseSqlOperator from "airflow.operators" and "airflow.sensors" will refuse to
-# accept the new Hooks as not derived from the original DbApiHook
-if Version(version) < Version("2.4"):
- try:
- from airflow.hooks.dbapi import DbApiHook as BaseForDbApiHook
- except ImportError:
- # just in case we have a problem with circular import
- BaseForDbApiHook: type[BaseHook] = BaseHook # type: ignore[no-redef]
-else:
- BaseForDbApiHook: type[BaseHook] = BaseHook # type: ignore[no-redef]
-
-
-class DbApiHook(BaseForDbApiHook):
+class DbApiHook(BaseHook):
"""
Abstract base class for sql hooks.
diff --git a/airflow/providers/common/sql/hooks/sql.pyi b/airflow/providers/common/sql/hooks/sql.pyi
index dedac037df..41bd6ebf47 100644
--- a/airflow/providers/common/sql/hooks/sql.pyi
+++ b/airflow/providers/common/sql/hooks/sql.pyi
@@ -32,8 +32,8 @@ Definition of the public interface for airflow.providers.common.sql.hooks.sql
isort:skip_file
"""
from _typeshed import Incomplete
-from airflow.hooks.dbapi import DbApiHook as BaseForDbApiHook
-from typing import Any, Callable, Iterable, Mapping, Sequence
+from airflow.hooks.base import BaseHook as BaseForDbApiHook
+from typing import Any, Callable, Iterable, Mapping, Sequence, Union
from typing_extensions import Protocol
def return_single_query_results(
diff --git a/airflow/providers/elasticsearch/hooks/elasticsearch.py b/airflow/providers/elasticsearch/hooks/elasticsearch.py
index 6c93586892..2d9fca4a97 100644
--- a/airflow/providers/elasticsearch/hooks/elasticsearch.py
+++ b/airflow/providers/elasticsearch/hooks/elasticsearch.py
@@ -108,9 +108,7 @@ class ElasticsearchSQLHook(DbApiHook):
if conn.extra_dejson.get("timeout", False):
conn_args["timeout"] = conn.extra_dejson["timeout"]
- conn = connect(**conn_args)
-
- return conn
+ return connect(**conn_args)
def get_uri(self) -> str:
conn_id = getattr(self, self.conn_name_attr)
diff --git a/airflow/providers/trino/hooks/trino.py b/airflow/providers/trino/hooks/trino.py
index 03195fe452..798109dc3f 100644
--- a/airflow/providers/trino/hooks/trino.py
+++ b/airflow/providers/trino/hooks/trino.py
@@ -19,6 +19,7 @@ from __future__ import annotations
import json
import os
+from pathlib import Path
from typing import TYPE_CHECKING, Any, Iterable, Mapping, TypeVar
import trino
@@ -28,6 +29,7 @@ from trino.transaction import IsolationLevel
from airflow.configuration import conf
from airflow.exceptions import AirflowException
from airflow.providers.common.sql.hooks.sql import DbApiHook
+from airflow.utils.helpers import exactly_one
from airflow.utils.operator_helpers import AIRFLOW_VAR_NAME_FORMAT_MAPPING, DEFAULT_FORMAT_PREFIX
if TYPE_CHECKING:
@@ -99,11 +101,20 @@ class TrinoHook(DbApiHook):
elif db.password:
auth = trino.auth.BasicAuthentication(db.login, db.password) # type: ignore[attr-defined]
elif extra.get("auth") == "jwt":
- if "jwt__file" in extra:
- with open(extra.get("jwt__file")) as jwt_file:
- token = jwt_file.read()
+ if not exactly_one(jwt_file := "jwt__file" in extra, jwt_token := "jwt__token" in extra):
+ msg = (
+ "When auth set to 'jwt' then expected exactly one parameter 'jwt__file' or 'jwt__token'"
+ " in connection extra, but "
+ )
+ if jwt_file and jwt_token:
+ msg += "provided both."
+ else:
+ msg += "none of them provided."
+ raise ValueError(msg)
+ elif jwt_file:
+ token = Path(extra["jwt__file"]).read_text()
else:
- token = extra.get("jwt__token")
+ token = extra["jwt__token"]
auth = trino.auth.JWTAuthentication(token=token)
elif extra.get("auth") == "certs":
auth = trino.auth.CertificateAuthentication(
diff --git a/tests/providers/odbc/hooks/test_odbc.py b/tests/providers/odbc/hooks/test_odbc.py
index ad763b934b..03e09a8adf 100644
--- a/tests/providers/odbc/hooks/test_odbc.py
+++ b/tests/providers/odbc/hooks/test_odbc.py
@@ -79,7 +79,8 @@ class TestOdbcHook:
class UnitTestOdbcHook(OdbcHook):
conn_name_attr = "test_conn_id"
- def get_connection(self, conn_id: str):
+ @classmethod
+ def get_connection(cls, conn_id: str):
return connection
def get_conn(self):
diff --git a/tests/providers/trino/hooks/test_trino.py b/tests/providers/trino/hooks/test_trino.py
index 8aeb4cbe08..5d61c056bc 100644
--- a/tests/providers/trino/hooks/test_trino.py
+++ b/tests/providers/trino/hooks/test_trino.py
@@ -18,9 +18,7 @@
from __future__ import annotations
import json
-import os
import re
-from tempfile import TemporaryDirectory
from unittest import mock
from unittest.mock import patch
@@ -40,16 +38,10 @@ CERT_AUTHENTICATION = "airflow.providers.trino.hooks.trino.trino.auth.Certificat
@pytest.fixture()
-def jwt_token_file():
- # Couldn't get this working with TemporaryFile, using TemporaryDirectory instead
- # Save a phony jwt to a temporary file for the trino hook to read from
- with TemporaryDirectory() as tmp_dir:
- tmp_jwt_file = os.path.join(tmp_dir, "jwt.json")
-
- with open(tmp_jwt_file, "w") as tmp_file:
- tmp_file.write('{"phony":"jwt"}')
-
- yield tmp_jwt_file
+def jwt_token_file(tmp_path):
+ jwt_file = tmp_path / "jwt.json"
+ jwt_file.write_text('{"phony":"jwt"}')
+ yield jwt_file.__fspath__()
class TestTrinoHookConn:
@@ -140,6 +132,28 @@ class TestTrinoHookConn:
TrinoHook().get_conn()
self.assert_connection_called_with(mock_connect, auth=mock_jwt_auth)
+ @pytest.mark.parametrize(
+ "jwt_file, jwt_token, error_suffix",
+ [
pytest.param(True, True, "provided both", id="provided-both-params"),
pytest.param(False, False, "none of them provided", id="no-jwt-provided"),
+ ],
+ )
+ @patch(HOOK_GET_CONNECTION)
+ def test_exactly_one_jwt_token(
+ self, mock_get_connection, jwt_file, jwt_token, error_suffix, jwt_token_file
+ ):
+ error_match = f"When auth set to 'jwt'.*{error_suffix}"
+ extras = {"auth": "jwt"}
+ if jwt_file:
+ extras["jwt__file"] = jwt_token_file
+ if jwt_token:
+ extras["jwt__token"] = "TEST_JWT_TOKEN"
+
+ self.set_get_connection_return_value(mock_get_connection, extra=json.dumps(extras))
+ with pytest.raises(ValueError, match=error_match):
+ TrinoHook().get_conn()
+
@patch(CERT_AUTHENTICATION)
@patch(TRINO_DBAPI_CONNECT)
@patch(HOOK_GET_CONNECTION)