This is an automated email from the ASF dual-hosted git repository.

gopidesu pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 8d43942b24f Simplify default rows limit return result (#64183)
8d43942b24f is described below

commit 8d43942b24f57eb25c47a89b73c69561a6f90d38
Author: GPK <[email protected]>
AuthorDate: Tue Mar 24 20:31:06 2026 +0000

    Simplify default rows limit return result (#64183)
    
    * Simplify default rows limit return result
    
    * Resolve comments
---
 providers/common/sql/docs/operators.rst            |  2 +-
 .../providers/common/sql/datafusion/engine.py      | 13 ++++++-
 .../providers/common/sql/operators/analytics.py    | 42 +++-------------------
 .../unit/common/sql/datafusion/test_engine.py      | 39 +++++++++++++++++++-
 .../unit/common/sql/operators/test_analytics.py    | 11 +++---
 5 files changed, 62 insertions(+), 45 deletions(-)

diff --git a/providers/common/sql/docs/operators.rst 
b/providers/common/sql/docs/operators.rst
index 773feefb4b5..88e326a25fe 100644
--- a/providers/common/sql/docs/operators.rst
+++ b/providers/common/sql/docs/operators.rst
@@ -288,7 +288,7 @@ Parameters
 ----------
 * ``datasource_configs`` (list[DataSourceConfig], required): List of 
datasource configurations
 * ``queries`` (list[str], required): List of SQL queries to run on the data
-* ``max_rows_check`` (int, optional): Maximum number of rows to check for each 
query. Default is 100. If any query returns more than this number of rows, it 
will be skipped in the results returned by the operator. This is to prevent 
returning too many rows in the results which can cause xcom rendering issues in 
Airflow UI.
+* ``max_rows_check`` (int, optional): Maximum number of rows returned for each 
query. Default is 100. If a query returns more rows, the engine logs a warning 
and returns only the first ``max_rows_check`` rows. This prevents returning too 
many rows, which can cause XCom rendering issues in the Airflow UI.
 * ``engine`` (DataFusionEngine, optional): Query engine to use. Default is 
"datafusion". Currently, only "datafusion" is supported.
 * ``result_output_format`` (str, optional): Output format for the results. 
Default is ``tabulate``. Supported formats are ``tabulate``, ``json``.
 
diff --git 
a/providers/common/sql/src/airflow/providers/common/sql/datafusion/engine.py 
b/providers/common/sql/src/airflow/providers/common/sql/datafusion/engine.py
index 4786c1f779c..21e0b72390b 100644
--- a/providers/common/sql/src/airflow/providers/common/sql/datafusion/engine.py
+++ b/providers/common/sql/src/airflow/providers/common/sql/datafusion/engine.py
@@ -101,11 +101,22 @@ class DataFusionEngine(LoggingMixin):
             datasource_config.table_name,
         )
 
-    def execute_query(self, query: str) -> dict[str, list[Any]]:
+    def execute_query(self, query: str, max_rows: int | None = None) -> 
dict[str, list[Any]]:
         """Execute a query and return the result as a dictionary."""
         try:
             self.log.info("Executing query: %s", query)
             df = self.session_context.sql(query)
+
+            if max_rows is not None:
+                result = df.limit(max_rows + 1).to_pydict()
+                if result and len(next(iter(result.values()))) > max_rows:
+                    self.log.warning(
+                        "Query returned more than %s rows. Returning first %s 
rows.",
+                        max_rows,
+                        max_rows,
+                    )
+                    return {column: values[:max_rows] for column, values in 
result.items()}
+                return result
             return df.to_pydict()
         except Exception as e:
             raise QueryExecutionException(f"Error while executing query: {e}")
diff --git 
a/providers/common/sql/src/airflow/providers/common/sql/operators/analytics.py 
b/providers/common/sql/src/airflow/providers/common/sql/operators/analytics.py
index fd814139ec7..2f24fcab3a3 100644
--- 
a/providers/common/sql/src/airflow/providers/common/sql/operators/analytics.py
+++ 
b/providers/common/sql/src/airflow/providers/common/sql/operators/analytics.py
@@ -38,7 +38,8 @@ class AnalyticsOperator(BaseOperator):
 
     :param datasource_configs: List of datasource configurations to register.
     :param queries: List of SQL queries to execute.
-    :param max_rows_check: Maximum number of rows allowed in query results. 
Queries exceeding this will be skipped.
+    :param max_rows_check: Maximum number of rows returned per query. Queries 
exceeding this return
+        only the first ``max_rows_check`` rows, and a warning is logged.
     :param engine: Optional DataFusion engine instance.
     :param result_output_format: List of output formats for results. 
Supported: 'tabulate', 'json'. Default is 'tabulate'.
     """
@@ -89,7 +90,7 @@ class AnalyticsOperator(BaseOperator):
 
         # TODO make it parallel as there is no dependency between queries
         for query in self.queries:
-            result_dict = self._df_engine.execute_query(query)
+            result_dict = self._df_engine.execute_query(query, 
max_rows=self.max_rows_check)
             results.append({"query": query, "data": result_dict})
 
         match self.result_output_format:
@@ -100,20 +101,6 @@ class AnalyticsOperator(BaseOperator):
             case _:
                 raise ValueError(f"Unsupported output format: 
{self.result_output_format}")
 
-    def _is_result_too_large(self, result_dict: dict[str, Any]) -> tuple[bool, 
int]:
-        """Check if a result exceeds the max_rows_check limit."""
-        if not result_dict:
-            return False, 0
-        num_rows = len(next(iter(result_dict.values())))
-        max_rows_exceeded = num_rows > self.max_rows_check
-        if max_rows_exceeded:
-            self.log.warning(
-                "Query returned %s rows, exceeding max_rows_check (%s). 
Skipping result output as large datasets are unsuitable for return.",
-                num_rows,
-                self.max_rows_check,
-            )
-        return max_rows_exceeded, num_rows
-
     def _build_tabulate_output(self, query_results: list[dict[str, Any]]) -> 
str:
         from tabulate import tabulate
 
@@ -121,15 +108,7 @@ class AnalyticsOperator(BaseOperator):
         for item in query_results:
             query = item["query"]
             result_dict = item["data"]
-            too_large, row_count = self._is_result_too_large(result_dict)
-
-            if too_large:
-                output_parts.append(
-                    f"\n### Results: {query}\n\n"
-                    f"**Skipped**: {row_count} rows exceed max_rows_check 
({self.max_rows_check})\n\n"
-                    f"{'-' * 40}\n"
-                )
-                continue
+            row_count = len(next(iter(result_dict.values()))) if result_dict 
else 0
 
             table_str = tabulate(
                 self._get_rows(result_dict, row_count),
@@ -151,18 +130,7 @@ class AnalyticsOperator(BaseOperator):
         for item in query_results:
             query = item["query"]
             result_dict = item["data"]
-            max_rows_exceeded, row_count = 
self._is_result_too_large(result_dict)
-
-            if max_rows_exceeded:
-                json_results.append(
-                    {
-                        "query": query,
-                        "status": "skipped_too_large",
-                        "row_count": row_count,
-                        "max_allowed": self.max_rows_check,
-                    }
-                )
-                continue
+            row_count = len(next(iter(result_dict.values()))) if result_dict 
else 0
 
             json_results.append(
                 {
diff --git 
a/providers/common/sql/tests/unit/common/sql/datafusion/test_engine.py 
b/providers/common/sql/tests/unit/common/sql/datafusion/test_engine.py
index 24606f7e7c8..d48c705046f 100644
--- a/providers/common/sql/tests/unit/common/sql/datafusion/test_engine.py
+++ b/providers/common/sql/tests/unit/common/sql/datafusion/test_engine.py
@@ -141,15 +141,52 @@ class TestDataFusionEngine:
     def test_execute_query_success(self):
         engine = DataFusionEngine()
         engine.df_ctx = MagicMock(spec=SessionContext)
-        mock_df = MagicMock()
+        mock_df = MagicMock(spec=["limit", "to_pydict"])
         mock_df.to_pydict.return_value = {"col1": [1, 2]}
         engine.df_ctx.sql.return_value = mock_df
 
         result = engine.execute_query("SELECT * FROM test_table")
 
         engine.df_ctx.sql.assert_called_once_with("SELECT * FROM test_table")
+        mock_df.limit.assert_not_called()
         assert result == {"col1": [1, 2]}
 
+    def test_execute_query_with_max_rows(self):
+        engine = DataFusionEngine()
+        engine.df_ctx = MagicMock(spec=SessionContext)
+        mock_df = MagicMock(spec=["limit", "to_pydict"])
+        limited_df = MagicMock(spec=["to_pydict"])
+        limited_df.to_pydict.return_value = {"col1": [1, 2, 3]}
+        mock_df.limit.return_value = limited_df
+        engine.df_ctx.sql.return_value = mock_df
+
+        result = engine.execute_query("SELECT * FROM test_table", max_rows=3)
+
+        engine.df_ctx.sql.assert_called_once_with("SELECT * FROM test_table")
+        mock_df.limit.assert_called_once_with(4)
+        assert result == {"col1": [1, 2, 3]}
+
+    def test_execute_query_with_max_rows_logs_warning_when_exceeded(self):
+        engine = DataFusionEngine()
+        engine.df_ctx = MagicMock(spec=SessionContext)
+        mock_df = MagicMock(spec=["limit", "to_pydict"])
+        limited_df = MagicMock(spec=["to_pydict"])
+        limited_df.to_pydict.return_value = {"col1": [1, 2, 3, 4]}
+        mock_df.limit.return_value = limited_df
+        engine.df_ctx.sql.return_value = mock_df
+
+        with patch.object(engine.log, "warning") as mock_warning:
+            result = engine.execute_query("SELECT * FROM test_table", 
max_rows=3)
+
+        engine.df_ctx.sql.assert_called_once_with("SELECT * FROM test_table")
+        mock_df.limit.assert_called_once_with(4)
+        mock_warning.assert_called_once_with(
+            "Query returned more than %s rows. Returning first %s rows.",
+            3,
+            3,
+        )
+        assert result == {"col1": [1, 2, 3]}
+
     def test_execute_query_failure(self):
         engine = DataFusionEngine()
         engine.df_ctx = MagicMock(spec=SessionContext)
diff --git 
a/providers/common/sql/tests/unit/common/sql/operators/test_analytics.py 
b/providers/common/sql/tests/unit/common/sql/operators/test_analytics.py
index c1ded538337..c116a2ad63c 100644
--- a/providers/common/sql/tests/unit/common/sql/operators/test_analytics.py
+++ b/providers/common/sql/tests/unit/common/sql/operators/test_analytics.py
@@ -53,18 +53,19 @@ class TestAnalyticsOperator:
         result = operator.execute(context={})
 
         mock_engine.register_datasource.assert_called_once()
-        mock_engine.execute_query.assert_called_once_with("SELECT * FROM 
users_data")
+        mock_engine.execute_query.assert_called_once_with("SELECT * FROM 
users_data", max_rows=100)
         assert "col1" in result
         assert "col2" in result
 
-    def test_execute_max_rows_exceeded(self, operator, mock_engine):
+    def test_execute_uses_overridden_max_rows_limit(self, operator, 
mock_engine):
         operator.max_rows_check = 3
-        mock_engine.execute_query.return_value = {"col1": [1, 2, 3, 4]}
+        mock_engine.execute_query.return_value = {"col1": [1, 2, 3]}
 
         result = operator.execute(context={})
 
-        assert "Skipped" in result
-        assert "4 rows exceed max_rows_check (3)" in result
+        mock_engine.execute_query.assert_called_once_with("SELECT * FROM 
users_data", max_rows=3)
+        assert "Skipped" not in result
+        assert "col1" in result
 
     def test_json_output_format(self, mock_engine):
         datasource_config = DataSourceConfig(

Reply via email to