This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 371833e076 Trino Hook: Add ability to read JWT from file (#31950)
371833e076 is described below
commit 371833e076d033be84f109cce980a6275032833c
Author: jhbigler <[email protected]>
AuthorDate: Sat Jun 24 09:32:48 2023 -0700
Trino Hook: Add ability to read JWT from file (#31950)
---------
Co-authored-by: Joshua H. Bigler <[email protected]>
Co-authored-by: Joshua Bigler <[email protected]>
Co-authored-by: Phani Kumar <[email protected]>
---
airflow/providers/trino/hooks/trino.py | 7 ++++-
.../apache-airflow-providers-trino/connections.rst | 3 +++
tests/providers/trino/hooks/test_trino.py | 30 ++++++++++++++++++++++
3 files changed, 39 insertions(+), 1 deletion(-)
diff --git a/airflow/providers/trino/hooks/trino.py
b/airflow/providers/trino/hooks/trino.py
index 070adabd01..09fbe6efa6 100644
--- a/airflow/providers/trino/hooks/trino.py
+++ b/airflow/providers/trino/hooks/trino.py
@@ -95,7 +95,12 @@ class TrinoHook(DbApiHook):
elif db.password:
auth = trino.auth.BasicAuthentication(db.login, db.password) #
type: ignore[attr-defined]
elif extra.get("auth") == "jwt":
- auth = trino.auth.JWTAuthentication(token=extra.get("jwt__token"))
+ if "jwt__file" in extra:
+ with open(extra.get("jwt__file")) as jwt_file:
+ token = jwt_file.read()
+ else:
+ token = extra.get("jwt__token")
+ auth = trino.auth.JWTAuthentication(token=token)
elif extra.get("auth") == "certs":
auth = trino.auth.CertificateAuthentication(
extra.get("certs__client_cert_path"),
diff --git a/docs/apache-airflow-providers-trino/connections.rst
b/docs/apache-airflow-providers-trino/connections.rst
index 3c50cd1b45..ee25fbbce7 100644
--- a/docs/apache-airflow-providers-trino/connections.rst
+++ b/docs/apache-airflow-providers-trino/connections.rst
@@ -49,7 +49,10 @@ Extra (optional, connection parameters)
The following extra parameters can be used to configure authentication:
* ``jwt__token`` - If jwt authentication should be used, the value of
token is given via this parameter.
+ * ``jwt__file`` - If jwt authentication should be used, the path to the file on disk containing the jwt token.
* ``certs__client_cert_path``, ``certs__client_key_path``- If certificate
authentication should be used, the path to the client certificate and key is
given via these parameters.
* ``kerberos__service_name``, ``kerberos__config``,
``kerberos__mutual_authentication``, ``kerberos__force_preemptive``,
``kerberos__hostname_override``, ``kerberos__sanitize_mutual_error_response``,
``kerberos__principal``,``kerberos__delegate``, ``kerberos__ca_bundle`` - These
parameters can be set when enabling ``kerberos`` authentication.
* ``session_properties`` - JSON dictionary which allows to set
session_properties. Example:
``{'session_properties':{'scale_writers':true,'task_writer_count:1'}}``
* ``client_tags`` - List of comma separated tags. Example
``{'client_tags':['sales','cluster1']}```
+
+ Note: If ``jwt__file`` and ``jwt__token`` are both given, ``jwt__file`` will take precedence.
diff --git a/tests/providers/trino/hooks/test_trino.py
b/tests/providers/trino/hooks/test_trino.py
index 4a0f2e6d2f..5a9e51bf09 100644
--- a/tests/providers/trino/hooks/test_trino.py
+++ b/tests/providers/trino/hooks/test_trino.py
@@ -18,7 +18,9 @@
from __future__ import annotations
import json
+import os
import re
+from tempfile import TemporaryDirectory
from unittest import mock
from unittest.mock import patch
@@ -37,6 +39,19 @@ JWT_AUTHENTICATION =
"airflow.providers.trino.hooks.trino.trino.auth.JWTAuthenti
CERT_AUTHENTICATION =
"airflow.providers.trino.hooks.trino.trino.auth.CertificateAuthentication"
[email protected]()
+def jwt_token_file():
+ # Couldn't get this working with TemporaryFile, using TemporaryDirectory
instead
+ # Save a phony jwt to a temporary file for the trino hook to read from
+ with TemporaryDirectory() as tmp_dir:
+ tmp_jwt_file = os.path.join(tmp_dir, "jwt.json")
+
+ with open(tmp_jwt_file, "w") as tmp_file:
+ tmp_file.write('{"phony":"jwt"}')
+
+ yield tmp_jwt_file
+
+
class TestTrinoHookConn:
@patch(BASIC_AUTHENTICATION)
@patch(TRINO_DBAPI_CONNECT)
@@ -110,6 +125,21 @@ class TestTrinoHookConn:
TrinoHook().get_conn()
self.assert_connection_called_with(mock_connect, auth=mock_jwt_auth)
+ @patch(JWT_AUTHENTICATION)
+ @patch(TRINO_DBAPI_CONNECT)
+ @patch(HOOK_GET_CONNECTION)
+ def test_get_conn_jwt_file(self, mock_get_connection, mock_connect,
mock_jwt_auth, jwt_token_file):
+ extras = {
+ "auth": "jwt",
+ "jwt__file": jwt_token_file,
+ }
+ self.set_get_connection_return_value(
+ mock_get_connection,
+ extra=json.dumps(extras),
+ )
+ TrinoHook().get_conn()
+ self.assert_connection_called_with(mock_connect, auth=mock_jwt_auth)
+
@patch(CERT_AUTHENTICATION)
@patch(TRINO_DBAPI_CONNECT)
@patch(HOOK_GET_CONNECTION)