This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 371833e076 Trino Hook: Add ability to read JWT from file (#31950)
371833e076 is described below

commit 371833e076d033be84f109cce980a6275032833c
Author: jhbigler <[email protected]>
AuthorDate: Sat Jun 24 09:32:48 2023 -0700

    Trino Hook: Add ability to read JWT from file (#31950)
    
    
    
    ---------
    
    Co-authored-by: Joshua H. Bigler <[email protected]>
    Co-authored-by: Joshua Bigler <[email protected]>
    Co-authored-by: Phani Kumar <[email protected]>
---
 airflow/providers/trino/hooks/trino.py             |  7 ++++-
 .../apache-airflow-providers-trino/connections.rst |  3 +++
 tests/providers/trino/hooks/test_trino.py          | 30 ++++++++++++++++++++++
 3 files changed, 39 insertions(+), 1 deletion(-)

diff --git a/airflow/providers/trino/hooks/trino.py 
b/airflow/providers/trino/hooks/trino.py
index 070adabd01..09fbe6efa6 100644
--- a/airflow/providers/trino/hooks/trino.py
+++ b/airflow/providers/trino/hooks/trino.py
@@ -95,7 +95,12 @@ class TrinoHook(DbApiHook):
         elif db.password:
             auth = trino.auth.BasicAuthentication(db.login, db.password)  # 
type: ignore[attr-defined]
         elif extra.get("auth") == "jwt":
-            auth = trino.auth.JWTAuthentication(token=extra.get("jwt__token"))
+            if "jwt__file" in extra:
+                with open(extra.get("jwt__file")) as jwt_file:
+                    token = jwt_file.read()
+            else:
+                token = extra.get("jwt__token")
+            auth = trino.auth.JWTAuthentication(token=token)
         elif extra.get("auth") == "certs":
             auth = trino.auth.CertificateAuthentication(
                 extra.get("certs__client_cert_path"),
diff --git a/docs/apache-airflow-providers-trino/connections.rst 
b/docs/apache-airflow-providers-trino/connections.rst
index 3c50cd1b45..ee25fbbce7 100644
--- a/docs/apache-airflow-providers-trino/connections.rst
+++ b/docs/apache-airflow-providers-trino/connections.rst
@@ -49,7 +49,10 @@ Extra (optional, connection parameters)
     The following extra parameters can be used to configure authentication:
 
     * ``jwt__token`` - If jwt authentication should be used, the value of 
token is given via this parameter.
+    * ``jwt__file``  - If jwt authentication should be used, the location on 
disk for the file containing the jwt token.
     * ``certs__client_cert_path``, ``certs__client_key_path``- If certificate 
authentication should be used, the path to the client certificate and key is 
given via these parameters.
     * ``kerberos__service_name``, ``kerberos__config``, 
``kerberos__mutual_authentication``, ``kerberos__force_preemptive``, 
``kerberos__hostname_override``, ``kerberos__sanitize_mutual_error_response``, 
``kerberos__principal``,``kerberos__delegate``, ``kerberos__ca_bundle`` - These 
parameters can be set when enabling ``kerberos`` authentication.
     * ``session_properties`` - JSON dictionary which allows to set 
session_properties. Example: 
``{'session_properties':{'scale_writers':true,'task_writer_count':1}}``
     * ``client_tags`` - List of comma separated tags. Example 
``{'client_tags':['sales','cluster1']}``
+
+    Note: If ``jwt__file`` and ``jwt__token`` are both given, ``jwt__file`` 
will take precedence.
diff --git a/tests/providers/trino/hooks/test_trino.py 
b/tests/providers/trino/hooks/test_trino.py
index 4a0f2e6d2f..5a9e51bf09 100644
--- a/tests/providers/trino/hooks/test_trino.py
+++ b/tests/providers/trino/hooks/test_trino.py
@@ -18,7 +18,9 @@
 from __future__ import annotations
 
 import json
+import os
 import re
+from tempfile import TemporaryDirectory
 from unittest import mock
 from unittest.mock import patch
 
@@ -37,6 +39,19 @@ JWT_AUTHENTICATION = 
"airflow.providers.trino.hooks.trino.trino.auth.JWTAuthenti
 CERT_AUTHENTICATION = 
"airflow.providers.trino.hooks.trino.trino.auth.CertificateAuthentication"
 
 
+@pytest.fixture()
+def jwt_token_file():
+    # Couldn't get this working with TemporaryFile, using TemporaryDirectory 
instead
+    # Save a phony jwt to a temporary file for the trino hook to read from
+    with TemporaryDirectory() as tmp_dir:
+        tmp_jwt_file = os.path.join(tmp_dir, "jwt.json")
+
+        with open(tmp_jwt_file, "w") as tmp_file:
+            tmp_file.write('{"phony":"jwt"}')
+
+        yield tmp_jwt_file
+
+
 class TestTrinoHookConn:
     @patch(BASIC_AUTHENTICATION)
     @patch(TRINO_DBAPI_CONNECT)
@@ -110,6 +125,21 @@ class TestTrinoHookConn:
         TrinoHook().get_conn()
         self.assert_connection_called_with(mock_connect, auth=mock_jwt_auth)
 
+    @patch(JWT_AUTHENTICATION)
+    @patch(TRINO_DBAPI_CONNECT)
+    @patch(HOOK_GET_CONNECTION)
+    def test_get_conn_jwt_file(self, mock_get_connection, mock_connect, 
mock_jwt_auth, jwt_token_file):
+        extras = {
+            "auth": "jwt",
+            "jwt__file": jwt_token_file,
+        }
+        self.set_get_connection_return_value(
+            mock_get_connection,
+            extra=json.dumps(extras),
+        )
+        TrinoHook().get_conn()
+        self.assert_connection_called_with(mock_connect, auth=mock_jwt_auth)
+
     @patch(CERT_AUTHENTICATION)
     @patch(TRINO_DBAPI_CONNECT)
     @patch(HOOK_GET_CONNECTION)

Reply via email to