This is an automated email from the ASF dual-hosted git repository.

eladkal pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 17b97e755a Allow session properties for trino connection (#27095)
17b97e755a is described below

commit 17b97e755a1e4b10b3bb47a3f334ed2677ac5ba5
Author: Aakash Nand <[email protected]>
AuthorDate: Sat Oct 22 03:39:58 2022 +0900

    Allow session properties for trino connection (#27095)
---
 airflow/providers/trino/hooks/trino.py              |  2 ++
 docs/apache-airflow-providers-trino/connections.rst |  1 +
 tests/providers/trino/hooks/test_trino.py           | 21 ++++++++++++++++++++-
 3 files changed, 23 insertions(+), 1 deletion(-)

diff --git a/airflow/providers/trino/hooks/trino.py 
b/airflow/providers/trino/hooks/trino.py
index da79664aa4..3023c4a3bd 100644
--- a/airflow/providers/trino/hooks/trino.py
+++ b/airflow/providers/trino/hooks/trino.py
@@ -101,6 +101,7 @@ class TrinoHook(DbApiHook):
         extra = db.extra_dejson
         auth = None
         user = db.login
+        session_properties = extra.get('session_properties')
         if db.password and extra.get('auth') in ('kerberos', 'certs'):
             raise AirflowException(f"The {extra.get('auth')!r} authorization 
type doesn't support password.")
         elif db.password:
@@ -145,6 +146,7 @@ class TrinoHook(DbApiHook):
             # type: ignore[func-returns-value]
             isolation_level=self.get_isolation_level(),
             verify=_boolify(extra.get('verify', True)),
+            session_properties=session_properties if session_properties else 
None,
         )
 
         return trino_conn
diff --git a/docs/apache-airflow-providers-trino/connections.rst 
b/docs/apache-airflow-providers-trino/connections.rst
index 9e7bb7301e..3dc279332d 100644
--- a/docs/apache-airflow-providers-trino/connections.rst
+++ b/docs/apache-airflow-providers-trino/connections.rst
@@ -51,3 +51,4 @@ Extra (optional, connection parameters)
     * ``jwt__token`` - If jwt authentication should be used, the value of 
token is given via this parameter.
     * ``certs__client_cert_path``, ``certs__client_key_path``- If certificate 
authentication should be used, the path to the client certificate and key is 
given via these parameters.
     * ``kerberos__service_name``, ``kerberos__config``, 
``kerberos__mutual_authentication``, ``kerberos__force_preemptive``, 
``kerberos__hostname_override``, ``kerberos__sanitize_mutual_error_response``, 
``kerberos__principal``,``kerberos__delegate``, ``kerberos__ca_bundle`` - These 
parameters can be set when enabling ``kerberos`` authentication.
+    * ``session_properties`` - JSON dictionary which allows setting Trino 
session properties. Example: 
``{"session_properties": {"scale_writers": true, "task_writer_count": 1}}``
diff --git a/tests/providers/trino/hooks/test_trino.py 
b/tests/providers/trino/hooks/test_trino.py
index 11d95c2bcc..78d0370e2e 100644
--- a/tests/providers/trino/hooks/test_trino.py
+++ b/tests/providers/trino/hooks/test_trino.py
@@ -153,6 +153,22 @@ class TestTrinoHookConn:
         TrinoHook().get_conn()
         self.assert_connection_called_with(mock_connect, auth=mock_auth)
 
+    @patch(HOOK_GET_CONNECTION)
+    @patch(TRINO_DBAPI_CONNECT)
+    def test_get_conn_session_properties(self, mock_connect, 
mock_get_connection):
+        extras = {
+            'session_properties': {
+                'scale_writers': 'true',
+                'task_writer_count': '1',
+                'writer_min_size': '100MB',
+            },
+        }
+
+        self.set_get_connection_return_value(mock_get_connection, extra=extras)
+        TrinoHook().get_conn()
+
+        self.assert_connection_called_with(mock_connect, 
session_properties=extras['session_properties'])
+
     @parameterized.expand(
         [
             ('False', False),
@@ -178,7 +194,9 @@ class TestTrinoHookConn:
         mock_get_connection.return_value = mocked_connection
 
     @staticmethod
-    def assert_connection_called_with(mock_connect, http_headers=mock.ANY, 
auth=None, verify=True):
+    def assert_connection_called_with(
+        mock_connect, http_headers=mock.ANY, auth=None, verify=True, 
session_properties=None
+    ):
         mock_connect.assert_called_once_with(
             catalog='hive',
             host='host',
@@ -191,6 +209,7 @@ class TestTrinoHookConn:
             isolation_level=IsolationLevel.AUTOCOMMIT,
             auth=None if not auth else auth.return_value,
             verify=verify,
+            session_properties=session_properties,
         )
 
 

Reply via email to