This is an automated email from the ASF dual-hosted git repository.
eladkal pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 17b97e755a Allow session properties for trino connection (#27095)
17b97e755a is described below
commit 17b97e755a1e4b10b3bb47a3f334ed2677ac5ba5
Author: Aakash Nand <[email protected]>
AuthorDate: Sat Oct 22 03:39:58 2022 +0900
Allow session properties for trino connection (#27095)
---
airflow/providers/trino/hooks/trino.py | 2 ++
docs/apache-airflow-providers-trino/connections.rst | 1 +
tests/providers/trino/hooks/test_trino.py | 21 ++++++++++++++++++++-
3 files changed, 23 insertions(+), 1 deletion(-)
diff --git a/airflow/providers/trino/hooks/trino.py
b/airflow/providers/trino/hooks/trino.py
index da79664aa4..3023c4a3bd 100644
--- a/airflow/providers/trino/hooks/trino.py
+++ b/airflow/providers/trino/hooks/trino.py
@@ -101,6 +101,7 @@ class TrinoHook(DbApiHook):
extra = db.extra_dejson
auth = None
user = db.login
+ session_properties = extra.get('session_properties')
if db.password and extra.get('auth') in ('kerberos', 'certs'):
raise AirflowException(f"The {extra.get('auth')!r} authorization
type doesn't support password.")
elif db.password:
@@ -145,6 +146,7 @@ class TrinoHook(DbApiHook):
# type: ignore[func-returns-value]
isolation_level=self.get_isolation_level(),
verify=_boolify(extra.get('verify', True)),
+ session_properties=session_properties if session_properties else
None,
)
return trino_conn
diff --git a/docs/apache-airflow-providers-trino/connections.rst
b/docs/apache-airflow-providers-trino/connections.rst
index 9e7bb7301e..3dc279332d 100644
--- a/docs/apache-airflow-providers-trino/connections.rst
+++ b/docs/apache-airflow-providers-trino/connections.rst
@@ -51,3 +51,4 @@ Extra (optional, connection parameters)
* ``jwt__token`` - If jwt authentication should be used, the value of
token is given via this parameter.
* ``certs__client_cert_path``, ``certs__client_key_path``- If certificate
authentication should be used, the path to the client certificate and key is
given via these parameters.
* ``kerberos__service_name``, ``kerberos__config``,
``kerberos__mutual_authentication``, ``kerberos__force_preemptive``,
``kerberos__hostname_override``, ``kerberos__sanitize_mutual_error_response``,
``kerberos__principal``,``kerberos__delegate``, ``kerberos__ca_bundle`` - These
parameters can be set when enabling ``kerberos`` authentication.
+ * ``session_properties`` - JSON dictionary that allows setting Trino session properties. Example: ``{"session_properties": {"scale_writers": true, "task_writer_count": 1}}``
diff --git a/tests/providers/trino/hooks/test_trino.py
b/tests/providers/trino/hooks/test_trino.py
index 11d95c2bcc..78d0370e2e 100644
--- a/tests/providers/trino/hooks/test_trino.py
+++ b/tests/providers/trino/hooks/test_trino.py
@@ -153,6 +153,22 @@ class TestTrinoHookConn:
TrinoHook().get_conn()
self.assert_connection_called_with(mock_connect, auth=mock_auth)
+ @patch(HOOK_GET_CONNECTION)
+ @patch(TRINO_DBAPI_CONNECT)
+ def test_get_conn_session_properties(self, mock_connect,
mock_get_connection):
+ extras = {
+ 'session_properties': {
+ 'scale_writers': 'true',
+ 'task_writer_count': '1',
+ 'writer_min_size': '100MB',
+ },
+ }
+
+ self.set_get_connection_return_value(mock_get_connection, extra=extras)
+ TrinoHook().get_conn()
+
+ self.assert_connection_called_with(mock_connect,
session_properties=extras['session_properties'])
+
@parameterized.expand(
[
('False', False),
@@ -178,7 +194,9 @@ class TestTrinoHookConn:
mock_get_connection.return_value = mocked_connection
@staticmethod
- def assert_connection_called_with(mock_connect, http_headers=mock.ANY,
auth=None, verify=True):
+ def assert_connection_called_with(
+ mock_connect, http_headers=mock.ANY, auth=None, verify=True,
session_properties=None
+ ):
mock_connect.assert_called_once_with(
catalog='hive',
host='host',
@@ -191,6 +209,7 @@ class TestTrinoHookConn:
isolation_level=IsolationLevel.AUTOCOMMIT,
auth=None if not auth else auth.return_value,
verify=verify,
+ session_properties=session_properties,
)