This is an automated email from the ASF dual-hosted git repository.
shahar pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 5e943e3c8ff Add extra_credentials and roles to TrinoHook (#51298)
5e943e3c8ff is described below
commit 5e943e3c8ff89fd3c4b4f5402a4fa61d79ad1360
Author: Zach <[email protected]>
AuthorDate: Mon Jun 2 11:42:52 2025 -0400
Add extra_credentials and roles to TrinoHook (#51298)
* Add extra_credentials and roles to TrinoHook
* Address Trino provider doc URL error
---
providers/trino/docs/connections.rst | 2 ++
.../trino/src/airflow/providers/trino/hooks/trino.py | 2 ++
providers/trino/tests/unit/trino/hooks/test_trino.py | 20 ++++++++++++++++++++
3 files changed, 24 insertions(+)
diff --git a/providers/trino/docs/connections.rst
b/providers/trino/docs/connections.rst
index f6bd0b46c1e..32e0abe360d 100644
--- a/providers/trino/docs/connections.rst
+++ b/providers/trino/docs/connections.rst
@@ -55,5 +55,7 @@ Extra (optional, connection parameters)
* ``session_properties`` - JSON dictionary used to set the session
properties. Example:
``{'session_properties':{'scale_writers':true,'task_writer_count':1}}``
* ``client_tags`` - List of comma-separated tags. Example:
``{'client_tags':['sales','cluster1']}``
* ``timezone`` - The time zone for the session can be explicitly set using
the IANA time zone name. Example: ``{'timezone':'Asia/Jerusalem'}``.
+ * ``extra_credential`` - List of key-value string pairs which are passed
to the Trino connector. For more information, refer to the Trino client
protocol doc page here:
https://trino.io/docs/current/develop/client-protocol.html
+ * ``roles`` - Mapping of catalog names to their corresponding Trino
authorization role. For more information, refer to the Trino Python client docs
here: https://github.com/trinodb/trino-python-client?tab=readme-ov-file#roles
Note: If ``jwt__file`` and ``jwt__token`` are both given, ``jwt__file``
takes precedence.
diff --git a/providers/trino/src/airflow/providers/trino/hooks/trino.py
b/providers/trino/src/airflow/providers/trino/hooks/trino.py
index 1cca01a2770..35bb83cce2a 100644
--- a/providers/trino/src/airflow/providers/trino/hooks/trino.py
+++ b/providers/trino/src/airflow/providers/trino/hooks/trino.py
@@ -211,6 +211,8 @@ class TrinoHook(DbApiHook):
session_properties=extra.get("session_properties") or None,
client_tags=extra.get("client_tags") or None,
timezone=extra.get("timezone") or None,
+ extra_credential=extra.get("extra_credential") or None,
+ roles=extra.get("roles") or None,
)
return trino_conn
diff --git a/providers/trino/tests/unit/trino/hooks/test_trino.py
b/providers/trino/tests/unit/trino/hooks/test_trino.py
index 7e485a28138..02966d2e919 100644
--- a/providers/trino/tests/unit/trino/hooks/test_trino.py
+++ b/providers/trino/tests/unit/trino/hooks/test_trino.py
@@ -258,6 +258,22 @@ class TestTrinoHookConn:
TrinoHook().get_conn()
self.assert_connection_called_with(mock_connect,
timezone="Asia/Jerusalem")
+ @patch(HOOK_GET_CONNECTION)
+ @patch(TRINO_DBAPI_CONNECT)
+ def test_get_conn_extra_credential(self, mock_connect,
mock_get_connection):
+ extras = {"extra_credential": [["a.username", "bar"], ["a.password",
"foo"]]}
+ self.set_get_connection_return_value(mock_get_connection,
extra=json.dumps(extras))
+ TrinoHook().get_conn()
+ self.assert_connection_called_with(mock_connect,
extra_credential=extras["extra_credential"])
+
+ @patch(HOOK_GET_CONNECTION)
+ @patch(TRINO_DBAPI_CONNECT)
+ def test_get_conn_roles(self, mock_connect, mock_get_connection):
+ extras = {"roles": {"catalog1": "trinoRoleA", "catalog2":
"trinoRoleB"}}
+ self.set_get_connection_return_value(mock_get_connection,
extra=json.dumps(extras))
+ TrinoHook().get_conn()
+ self.assert_connection_called_with(mock_connect, roles=extras["roles"])
+
@staticmethod
def set_get_connection_return_value(mock_get_connection, extra=None,
password=None):
mocked_connection = Connection(
@@ -274,6 +290,8 @@ class TestTrinoHookConn:
session_properties=None,
client_tags=None,
timezone=None,
+ extra_credential=None,
+ roles=None,
):
mock_connect.assert_called_once_with(
catalog="hive",
@@ -290,6 +308,8 @@ class TestTrinoHookConn:
session_properties=session_properties,
client_tags=client_tags,
timezone=timezone,
+ extra_credential=extra_credential,
+ roles=roles,
)