This is an automated email from the ASF dual-hosted git repository.

gopidesu pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 356961f8be9 Add ModelRetry handling to DataFusionToolset (#63501)
356961f8be9 is described below

commit 356961f8be9bbcb302ce49e866b754ae2b55bb7c
Author: GPK <[email protected]>
AuthorDate: Sat Mar 14 03:38:48 2026 +0000

    Add ModelRetry handling to DataFusionToolset (#63501)
    
    * Add ModelRetry handling to DataFusionToolset
    
    * Resolve comments
---
 .../providers/common/ai/toolsets/datafusion.py     | 19 +++++++
 .../unit/common/ai/toolsets/test_datafusion.py     | 63 ++++++++++++++++++++--
 2 files changed, 79 insertions(+), 3 deletions(-)

diff --git a/providers/common/ai/src/airflow/providers/common/ai/toolsets/datafusion.py b/providers/common/ai/src/airflow/providers/common/ai/toolsets/datafusion.py
index 7c3de86241e..a83bd8bbc8c 100644
--- a/providers/common/ai/src/airflow/providers/common/ai/toolsets/datafusion.py
+++ b/providers/common/ai/src/airflow/providers/common/ai/toolsets/datafusion.py
@@ -20,6 +20,7 @@ from __future__ import annotations
 
 import json
 import logging
+import re
 from typing import TYPE_CHECKING, Any
 
 try:
@@ -31,6 +32,7 @@ except ImportError as e:
 
     raise AirflowOptionalProviderFeatureException(e)
 
+from pydantic_ai.exceptions import ModelRetry
 from pydantic_ai.tools import ToolDefinition
 from pydantic_ai.toolsets.abstract import AbstractToolset, ToolsetTool
 from pydantic_core import SchemaValidator, core_schema
@@ -66,6 +68,14 @@ _QUERY_SCHEMA: dict[str, Any] = {
     "required": ["sql"],
 }
 
+# DataFusion python bindings don't expose any native exception types, it uses rust exceptions.
+# So we have to rely on error message parsing with regex.
+_RETRYABLE_IDENTIFIER = r"""(?:['"][^'"]+['"]|\w+)"""
+_RETRYABLE_QUERY_ERROR_PATTERNS = (
+    re.compile(rf"""column\s+{_RETRYABLE_IDENTIFIER}\s+not\s+found""", re.IGNORECASE),
+    re.compile(rf"""table\s+{_RETRYABLE_IDENTIFIER}\s+not\s+found""", re.IGNORECASE),
+)
+
 
 class DataFusionToolset(AbstractToolset[Any]):
     """
@@ -204,4 +214,13 @@ class DataFusionToolset(AbstractToolset[Any]):
             log.warning("query failed SQL safety validation: %s", ex)
             raise
         except QueryExecutionException as ex:
+            if self._is_retryable_query_error(ex):
+                raise ModelRetry(
+                    f"error: {ex!s}, Use get_schema and list_tables tools for more details."
+                ) from ex
             return json.dumps({"error": str(ex), "query": sql})
+
+    @staticmethod
+    def _is_retryable_query_error(error: QueryExecutionException) -> bool:
+        message = str(error)
+        return any(pattern.search(message) for pattern in _RETRYABLE_QUERY_ERROR_PATTERNS)
diff --git a/providers/common/ai/tests/unit/common/ai/toolsets/test_datafusion.py b/providers/common/ai/tests/unit/common/ai/toolsets/test_datafusion.py
index 77bc0cc80ea..89959649cdd 100644
--- a/providers/common/ai/tests/unit/common/ai/toolsets/test_datafusion.py
+++ b/providers/common/ai/tests/unit/common/ai/toolsets/test_datafusion.py
@@ -18,13 +18,18 @@ from __future__ import annotations
 
 import asyncio
 import json
+from pathlib import Path
 from unittest.mock import MagicMock, patch
 
 import pytest
 from pydantic_ai._run_context import RunContext
+from pydantic_ai.exceptions import ModelRetry
 from pydantic_ai.toolsets.abstract import ToolsetTool
 
-from airflow.providers.common.ai.toolsets.datafusion import DataFusionToolset
+from airflow.providers.common.ai.toolsets.datafusion import (
+    _RETRYABLE_QUERY_ERROR_PATTERNS,
+    DataFusionToolset,
+)
 from airflow.providers.common.ai.utils.sql_validation import SQLSafetyError
 from airflow.providers.common.sql.config import DataSourceConfig
 
@@ -267,13 +272,27 @@ class TestDataFusionToolsetGetSchemaErrors:
 
 
 class TestDataFusionToolsetQueryErrors:
+    @pytest.mark.parametrize(
+        ("message", "expected"),
+        [
+            ("DataFusion error: column 'name' not found", True),
+            ("DataFusion error: column name not found", True),
+            ("DataFusion error: table tripss not found", True),
+            ("DataFusion error: table 'tripss' not found", True),
+            ("DataFusion error: access denied", False),
+        ],
+    )
+    def test_retryable_query_error_patterns_match_expected_messages(self, message: str, expected: bool):
+        matches = any(pattern.search(message) for pattern in _RETRYABLE_QUERY_ERROR_PATTERNS)
+        assert matches is expected
+
     def test_query_execution_exception_returns_error_json(self):
         from airflow.providers.common.sql.datafusion.exceptions import QueryExecutionException
 
         cfg = _make_mock_datasource_config()
         ts = DataFusionToolset([cfg])
         engine = _make_mock_engine()
-        engine.execute_query.side_effect = QueryExecutionException("execution failed: column x not found")
+        engine.execute_query.side_effect = QueryExecutionException("execution failed")
         ts._engine = engine
 
         result = asyncio.run(
@@ -285,9 +304,47 @@ class TestDataFusionToolsetQueryErrors:
             )
         )
         data = json.loads(result)
-        assert "column x not found" in data["error"]
+        assert "execution failed" in data["error"]
         assert data["query"] == "SELECT x FROM t"
 
+    @pytest.mark.parametrize(
+        ("sql", "expected_substring"),
+        [
+            ("SELECT name FROM trips", "column 'name' not found"),
+            ("SELECT * FROM tripss", "table 'tripss' not found"),
+        ],
+    )
+    def test_retryable_query_errors_raise_model_retry(
+        self, tmp_path: Path, sql: str, expected_substring: str
+    ):
+        csv_path = tmp_path / "trips.csv"
+        csv_path.write_text("pickup_date,id\n2024-01-01,1\n")
+
+        ts = DataFusionToolset(
+            [
+                DataSourceConfig(
+                    conn_id="",
+                    table_name="trips",
+                    uri=f"file://{csv_path}",
+                    format="csv",
+                )
+            ]
+        )
+
+        with pytest.raises(ModelRetry) as exc_info:
+            asyncio.run(
+                ts.call_tool(
+                    "query",
+                    {"sql": sql},
+                    ctx=MagicMock(spec=RunContext),
+                    tool=MagicMock(spec=ToolsetTool),
+                )
+            )
+
+        assert expected_substring in exc_info.value.message
+        assert "get_schema" in exc_info.value.message
+        assert "list_tables" in exc_info.value.message
+
     def test_unexpected_exception_propagates(self):
         cfg = _make_mock_datasource_config()
         ts = DataFusionToolset([cfg])

Reply via email to