This is an automated email from the ASF dual-hosted git repository.

shahar pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 5e943e3c8ff Add extra_credentials and roles to TrinoHook (#51298)
5e943e3c8ff is described below

commit 5e943e3c8ff89fd3c4b4f5402a4fa61d79ad1360
Author: Zach <[email protected]>
AuthorDate: Mon Jun 2 11:42:52 2025 -0400

    Add extra_credentials and roles to TrinoHook (#51298)
    
    * Add extra_credentials and roles to TrinoHook
    
    * Address Trino provider doc URL error
---
 providers/trino/docs/connections.rst                 |  2 ++
 .../trino/src/airflow/providers/trino/hooks/trino.py |  2 ++
 providers/trino/tests/unit/trino/hooks/test_trino.py | 20 ++++++++++++++++++++
 3 files changed, 24 insertions(+)

diff --git a/providers/trino/docs/connections.rst 
b/providers/trino/docs/connections.rst
index f6bd0b46c1e..32e0abe360d 100644
--- a/providers/trino/docs/connections.rst
+++ b/providers/trino/docs/connections.rst
@@ -55,5 +55,7 @@ Extra (optional, connection parameters)
     * ``session_properties`` - JSON dictionary which allows to set 
session_properties. Example: 
``{'session_properties':{'scale_writers':true,'task_writer_count':1}}``
     * ``client_tags`` - List of comma separated tags. Example 
``{'client_tags':['sales','cluster1']}``
     * ``timezone`` - The time zone for the session can be explicitly set using 
the IANA time zone name. Example: ``{'timezone':'Asia/Jerusalem'}``.
+    * ``extra_credential`` - List of key-value string pairs which are passed 
to the Trino connector. For more information, refer to the Trino client 
protocol doc page here: 
https://trino.io/docs/current/develop/client-protocol.html
+    * ``roles`` - Mapping of catalog names to their corresponding Trino 
authorization role. For more information, refer to the Trino Python client docs 
here: https://github.com/trinodb/trino-python-client?tab=readme-ov-file#roles
 
     Note: If ``jwt__file`` and ``jwt__token`` are both given, ``jwt__file`` 
will take precedence.
diff --git a/providers/trino/src/airflow/providers/trino/hooks/trino.py 
b/providers/trino/src/airflow/providers/trino/hooks/trino.py
index 1cca01a2770..35bb83cce2a 100644
--- a/providers/trino/src/airflow/providers/trino/hooks/trino.py
+++ b/providers/trino/src/airflow/providers/trino/hooks/trino.py
@@ -211,6 +211,8 @@ class TrinoHook(DbApiHook):
             session_properties=extra.get("session_properties") or None,
             client_tags=extra.get("client_tags") or None,
             timezone=extra.get("timezone") or None,
+            extra_credential=extra.get("extra_credential") or None,
+            roles=extra.get("roles") or None,
         )
 
         return trino_conn
diff --git a/providers/trino/tests/unit/trino/hooks/test_trino.py 
b/providers/trino/tests/unit/trino/hooks/test_trino.py
index 7e485a28138..02966d2e919 100644
--- a/providers/trino/tests/unit/trino/hooks/test_trino.py
+++ b/providers/trino/tests/unit/trino/hooks/test_trino.py
@@ -258,6 +258,22 @@ class TestTrinoHookConn:
         TrinoHook().get_conn()
         self.assert_connection_called_with(mock_connect, 
timezone="Asia/Jerusalem")
 
+    @patch(HOOK_GET_CONNECTION)
+    @patch(TRINO_DBAPI_CONNECT)
+    def test_get_conn_extra_credential(self, mock_connect, 
mock_get_connection):
+        extras = {"extra_credential": [["a.username", "bar"], ["a.password", 
"foo"]]}
+        self.set_get_connection_return_value(mock_get_connection, 
extra=json.dumps(extras))
+        TrinoHook().get_conn()
+        self.assert_connection_called_with(mock_connect, 
extra_credential=extras["extra_credential"])
+
+    @patch(HOOK_GET_CONNECTION)
+    @patch(TRINO_DBAPI_CONNECT)
+    def test_get_conn_roles(self, mock_connect, mock_get_connection):
+        extras = {"roles": {"catalog1": "trinoRoleA", "catalog2": 
"trinoRoleB"}}
+        self.set_get_connection_return_value(mock_get_connection, 
extra=json.dumps(extras))
+        TrinoHook().get_conn()
+        self.assert_connection_called_with(mock_connect, roles=extras["roles"])
+
     @staticmethod
     def set_get_connection_return_value(mock_get_connection, extra=None, 
password=None):
         mocked_connection = Connection(
@@ -274,6 +290,8 @@ class TestTrinoHookConn:
         session_properties=None,
         client_tags=None,
         timezone=None,
+        extra_credential=None,
+        roles=None,
     ):
         mock_connect.assert_called_once_with(
             catalog="hive",
@@ -290,6 +308,8 @@ class TestTrinoHookConn:
             session_properties=session_properties,
             client_tags=client_tags,
             timezone=timezone,
+            extra_credential=extra_credential,
+            roles=roles,
         )
 
 

Reply via email to