This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 9a595e2391e providers/trino: set split_statements=True by default and add unit tests (#58158)
9a595e2391e is described below

commit 9a595e2391e5cf6e27bebb22b2b91c234d8c9f5f
Author: Nikita Kalganov <[email protected]>
AuthorDate: Thu Nov 13 14:23:37 2025 +0300

    providers/trino: set split_statements=True by default and add unit tests (#58158)
    
    * providers/trino: set split_statements=True by default in TrinoHook.run()
    
    * providers/trino: add unit tests for TrinoHook.run()
    
    * providers/trino: fixed patches in split_statement tests
    
    * providers/trino fix documentation
    
    * providers/trino: Apply ruff auto-format fixes in TrinoHook
    
    * providers/trino: Apply ruff-format fixes in TrinoHook
    
    * providers/trino: re-run after rebase on upstream/main
    
    ---------
    
    Co-authored-by: Nikita <[email protected]>
    Co-authored-by: Nikita.Kalganov <[email protected]>
---
 .../src/airflow/providers/trino/hooks/trino.py     | 48 +++++++++++++++++++++-
 .../trino/tests/unit/trino/hooks/test_trino.py     | 32 +++++++++++++++
 2 files changed, 78 insertions(+), 2 deletions(-)

diff --git a/providers/trino/src/airflow/providers/trino/hooks/trino.py b/providers/trino/src/airflow/providers/trino/hooks/trino.py
index b59ba32cb59..28d23423f6c 100644
--- a/providers/trino/src/airflow/providers/trino/hooks/trino.py
+++ b/providers/trino/src/airflow/providers/trino/hooks/trino.py
@@ -19,9 +19,9 @@ from __future__ import annotations
 
 import json
 import os
-from collections.abc import Iterable, Mapping
+from collections.abc import Callable, Iterable, Mapping
 from pathlib import Path
-from typing import TYPE_CHECKING, Any, TypeVar
+from typing import TYPE_CHECKING, Any, TypeVar, overload
 from urllib.parse import quote_plus, urlencode
 
 import trino
@@ -277,6 +277,50 @@ class TrinoHook(DbApiHook):
             df = pd.DataFrame(**kwargs)
         return df
 
+    @overload
+    def run(
+        self,
+        sql: str | Iterable[str],
+        autocommit: bool = ...,
+        parameters: Iterable | Mapping[str, Any] | None = ...,
+        handler: None = ...,
+        split_statements: bool = ...,
+        return_last: bool = ...,
+    ) -> None: ...
+
+    @overload
+    def run(
+        self,
+        sql: str | Iterable[str],
+        autocommit: bool = ...,
+        parameters: Iterable | Mapping[str, Any] | None = ...,
+        handler: Callable[[Any], T] = ...,
+        split_statements: bool = ...,
+        return_last: bool = ...,
+    ) -> tuple | list[tuple] | list[list[tuple] | tuple] | None: ...
+
+    def run(
+        self,
+        sql: str | Iterable[str],
+        autocommit: bool = False,
+        parameters: Iterable | Mapping[str, Any] | None = None,
+        handler: Callable[[Any], T] | None = None,
+        split_statements: bool = True,
+        return_last: bool = True,
+    ) -> tuple | list[tuple] | list[list[tuple] | tuple] | None:
+        """
+        Override common run to set split_statements=True by default.
+
+        :param sql: SQL statement or list of statements to execute.
+        :param autocommit: Set autocommit mode before query execution.
+        :param parameters: Parameters to render the SQL query with.
+        :param handler: Optional callable to process each statement result.
+        :param split_statements: Split single SQL string into statements if True.
+        :param return_last: Return only last statement result if True.
+        :return: Query result or list of results.
+        """
+        return super().run(sql, autocommit, parameters, handler, split_statements, return_last)
+
     def _get_polars_df(self, sql: str = "", parameters=None, **kwargs):
         try:
             import polars as pl
diff --git a/providers/trino/tests/unit/trino/hooks/test_trino.py b/providers/trino/tests/unit/trino/hooks/test_trino.py
index 75bcd6c0e86..3623a4e9714 100644
--- a/providers/trino/tests/unit/trino/hooks/test_trino.py
+++ b/providers/trino/tests/unit/trino/hooks/test_trino.py
@@ -401,6 +401,38 @@ class TestTrinoHook:
         self.db_hook.run(sql, autocommit, parameters, list)
         mock_run.assert_called_once_with(sql, autocommit, parameters, handler)
 
+    @patch("airflow.providers.common.sql.hooks.sql.DbApiHook.run")
+    def test_run_defaults_no_handler(self, super_run):
+        super_run.return_value = None
+        sql = "SELECT 1"
+        result = self.db_hook.run(sql)
+        assert result is None
+        super_run.assert_called_once_with(sql, False, None, None, True, True)
+
+    @patch("airflow.providers.common.sql.hooks.sql.DbApiHook.run")
+    def test_run_with_handler_and_params(self, super_run):
+        super_run.return_value = [("ok",)]
+        sql = "SELECT 1"
+        autocommit = True
+        parameters = ("hello", "world")
+        handler = list
+        res = self.db_hook.run(
+            sql,
+            autocommit=autocommit,
+            parameters=parameters,
+            handler=handler,
+            split_statements=False,
+            return_last=False,
+        )
+        assert res == [("ok",)]
+        super_run.assert_called_once_with(sql, True, parameters, handler, False, False)
+
+    @patch("airflow.providers.common.sql.hooks.sql.DbApiHook.run")
+    def test_run_multistatement_defaults_to_split(self, super_run):
+        sql = "SELECT 1; SELECT 2"
+        self.db_hook.run(sql)
+        super_run.assert_called_once_with(sql, False, None, None, True, True)
+
     def test_connection_success(self):
         status, msg = self.db_hook.test_connection()
         assert status is True

Reply via email to