This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 9a595e2391e providers/trino: set split_statements=True by default and
add unit tests (#58158)
9a595e2391e is described below
commit 9a595e2391e5cf6e27bebb22b2b91c234d8c9f5f
Author: Nikita Kalganov <[email protected]>
AuthorDate: Thu Nov 13 14:23:37 2025 +0300
providers/trino: set split_statements=True by default and add unit tests
(#58158)
* providers/trino: set split_statements=True by default in TrinoHook.run()
* providers/trino: add unit tests for TrinoHook.run()
* providers/trino: fixed patches in split_statement tests
* providers/trino fix documentation
* providers/trino: Apply ruff auto-format fixes in TrinoHook
* providers/trino: Apply ruff-format fixes in TrinoHook
* providers/trino: re-run after rebase on upstream/main
---------
Co-authored-by: Nikita <[email protected]>
Co-authored-by: Nikita.Kalganov <[email protected]>
---
.../src/airflow/providers/trino/hooks/trino.py | 48 +++++++++++++++++++++-
.../trino/tests/unit/trino/hooks/test_trino.py | 32 +++++++++++++++
2 files changed, 78 insertions(+), 2 deletions(-)
diff --git a/providers/trino/src/airflow/providers/trino/hooks/trino.py b/providers/trino/src/airflow/providers/trino/hooks/trino.py
index b59ba32cb59..28d23423f6c 100644
--- a/providers/trino/src/airflow/providers/trino/hooks/trino.py
+++ b/providers/trino/src/airflow/providers/trino/hooks/trino.py
@@ -19,9 +19,9 @@ from __future__ import annotations
import json
import os
-from collections.abc import Iterable, Mapping
+from collections.abc import Callable, Iterable, Mapping
from pathlib import Path
-from typing import TYPE_CHECKING, Any, TypeVar
+from typing import TYPE_CHECKING, Any, TypeVar, overload
from urllib.parse import quote_plus, urlencode
import trino
@@ -277,6 +277,50 @@ class TrinoHook(DbApiHook):
df = pd.DataFrame(**kwargs)
return df
+ @overload
+ def run(
+ self,
+ sql: str | Iterable[str],
+ autocommit: bool = ...,
+ parameters: Iterable | Mapping[str, Any] | None = ...,
+ handler: None = ...,
+ split_statements: bool = ...,
+ return_last: bool = ...,
+ ) -> None: ...
+
+ @overload
+ def run(
+ self,
+ sql: str | Iterable[str],
+ autocommit: bool = ...,
+ parameters: Iterable | Mapping[str, Any] | None = ...,
+ handler: Callable[[Any], T] = ...,
+ split_statements: bool = ...,
+ return_last: bool = ...,
+ ) -> tuple | list[tuple] | list[list[tuple] | tuple] | None: ...
+
+ def run(
+ self,
+ sql: str | Iterable[str],
+ autocommit: bool = False,
+ parameters: Iterable | Mapping[str, Any] | None = None,
+ handler: Callable[[Any], T] | None = None,
+ split_statements: bool = True,
+ return_last: bool = True,
+ ) -> tuple | list[tuple] | list[list[tuple] | tuple] | None:
+ """
+ Override common run to set split_statements=True by default.
+
+ :param sql: SQL statement or list of statements to execute.
+ :param autocommit: Set autocommit mode before query execution.
+ :param parameters: Parameters to render the SQL query with.
+ :param handler: Optional callable to process each statement result.
+ :param split_statements: Split single SQL string into statements if True.
+ :param return_last: Return only last statement result if True.
+ :return: Query result or list of results.
+ """
+ return super().run(sql, autocommit, parameters, handler, split_statements, return_last)
+
def _get_polars_df(self, sql: str = "", parameters=None, **kwargs):
try:
import polars as pl
diff --git a/providers/trino/tests/unit/trino/hooks/test_trino.py b/providers/trino/tests/unit/trino/hooks/test_trino.py
index 75bcd6c0e86..3623a4e9714 100644
--- a/providers/trino/tests/unit/trino/hooks/test_trino.py
+++ b/providers/trino/tests/unit/trino/hooks/test_trino.py
@@ -401,6 +401,38 @@ class TestTrinoHook:
self.db_hook.run(sql, autocommit, parameters, list)
mock_run.assert_called_once_with(sql, autocommit, parameters, handler)
+ @patch("airflow.providers.common.sql.hooks.sql.DbApiHook.run")
+ def test_run_defaults_no_handler(self, super_run):
+ super_run.return_value = None
+ sql = "SELECT 1"
+ result = self.db_hook.run(sql)
+ assert result is None
+ super_run.assert_called_once_with(sql, False, None, None, True, True)
+
+ @patch("airflow.providers.common.sql.hooks.sql.DbApiHook.run")
+ def test_run_with_handler_and_params(self, super_run):
+ super_run.return_value = [("ok",)]
+ sql = "SELECT 1"
+ autocommit = True
+ parameters = ("hello", "world")
+ handler = list
+ res = self.db_hook.run(
+ sql,
+ autocommit=autocommit,
+ parameters=parameters,
+ handler=handler,
+ split_statements=False,
+ return_last=False,
+ )
+ assert res == [("ok",)]
+ super_run.assert_called_once_with(sql, True, parameters, handler, False, False)
+
+ @patch("airflow.providers.common.sql.hooks.sql.DbApiHook.run")
+ def test_run_multistatement_defaults_to_split(self, super_run):
+ sql = "SELECT 1; SELECT 2"
+ self.db_hook.run(sql)
+ super_run.assert_called_once_with(sql, False, None, None, True, True)
+
def test_connection_success(self):
status, msg = self.db_hook.test_connection()
assert status is True