This is an automated email from the ASF dual-hosted git repository.
gopidesu pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 356961f8be9 Add ModelRetry handling to DataFusionToolset (#63501)
356961f8be9 is described below
commit 356961f8be9bbcb302ce49e866b754ae2b55bb7c
Author: GPK <[email protected]>
AuthorDate: Sat Mar 14 03:38:48 2026 +0000
Add ModelRetry handling to DataFusionToolset (#63501)
* Add ModelRetry handling to DataFusionToolset
* Resolve comments
---
.../providers/common/ai/toolsets/datafusion.py | 19 +++++++
.../unit/common/ai/toolsets/test_datafusion.py | 63 ++++++++++++++++++++--
2 files changed, 79 insertions(+), 3 deletions(-)
diff --git a/providers/common/ai/src/airflow/providers/common/ai/toolsets/datafusion.py b/providers/common/ai/src/airflow/providers/common/ai/toolsets/datafusion.py
index 7c3de86241e..a83bd8bbc8c 100644
--- a/providers/common/ai/src/airflow/providers/common/ai/toolsets/datafusion.py
+++ b/providers/common/ai/src/airflow/providers/common/ai/toolsets/datafusion.py
@@ -20,6 +20,7 @@ from __future__ import annotations
import json
import logging
+import re
from typing import TYPE_CHECKING, Any
try:
@@ -31,6 +32,7 @@ except ImportError as e:
raise AirflowOptionalProviderFeatureException(e)
+from pydantic_ai.exceptions import ModelRetry
from pydantic_ai.tools import ToolDefinition
from pydantic_ai.toolsets.abstract import AbstractToolset, ToolsetTool
from pydantic_core import SchemaValidator, core_schema
@@ -66,6 +68,14 @@ _QUERY_SCHEMA: dict[str, Any] = {
"required": ["sql"],
}
+# DataFusion python bindings don't expose any native exception types, it uses rust exceptions.
+# So we have to rely on error message parsing with regex.
+_RETRYABLE_IDENTIFIER = r"""(?:['"][^'"]+['"]|\w+)"""
+_RETRYABLE_QUERY_ERROR_PATTERNS = (
+ re.compile(rf"""column\s+{_RETRYABLE_IDENTIFIER}\s+not\s+found""",
re.IGNORECASE),
+ re.compile(rf"""table\s+{_RETRYABLE_IDENTIFIER}\s+not\s+found""",
re.IGNORECASE),
+)
+
class DataFusionToolset(AbstractToolset[Any]):
"""
@@ -204,4 +214,13 @@ class DataFusionToolset(AbstractToolset[Any]):
log.warning("query failed SQL safety validation: %s", ex)
raise
except QueryExecutionException as ex:
+ if self._is_retryable_query_error(ex):
+ raise ModelRetry(
+ f"error: {ex!s}, Use get_schema and list_tables tools for
more details."
+ ) from ex
return json.dumps({"error": str(ex), "query": sql})
+
+ @staticmethod
+ def _is_retryable_query_error(error: QueryExecutionException) -> bool:
+ message = str(error)
+ return any(pattern.search(message) for pattern in
_RETRYABLE_QUERY_ERROR_PATTERNS)
diff --git a/providers/common/ai/tests/unit/common/ai/toolsets/test_datafusion.py b/providers/common/ai/tests/unit/common/ai/toolsets/test_datafusion.py
index 77bc0cc80ea..89959649cdd 100644
--- a/providers/common/ai/tests/unit/common/ai/toolsets/test_datafusion.py
+++ b/providers/common/ai/tests/unit/common/ai/toolsets/test_datafusion.py
@@ -18,13 +18,18 @@ from __future__ import annotations
import asyncio
import json
+from pathlib import Path
from unittest.mock import MagicMock, patch
import pytest
from pydantic_ai._run_context import RunContext
+from pydantic_ai.exceptions import ModelRetry
from pydantic_ai.toolsets.abstract import ToolsetTool
-from airflow.providers.common.ai.toolsets.datafusion import DataFusionToolset
+from airflow.providers.common.ai.toolsets.datafusion import (
+ _RETRYABLE_QUERY_ERROR_PATTERNS,
+ DataFusionToolset,
+)
from airflow.providers.common.ai.utils.sql_validation import SQLSafetyError
from airflow.providers.common.sql.config import DataSourceConfig
@@ -267,13 +272,27 @@ class TestDataFusionToolsetGetSchemaErrors:
class TestDataFusionToolsetQueryErrors:
+ @pytest.mark.parametrize(
+ ("message", "expected"),
+ [
+ ("DataFusion error: column 'name' not found", True),
+ ("DataFusion error: column name not found", True),
+ ("DataFusion error: table tripss not found", True),
+ ("DataFusion error: table 'tripss' not found", True),
+ ("DataFusion error: access denied", False),
+ ],
+ )
+ def test_retryable_query_error_patterns_match_expected_messages(self,
message: str, expected: bool):
+ matches = any(pattern.search(message) for pattern in
_RETRYABLE_QUERY_ERROR_PATTERNS)
+ assert matches is expected
+
def test_query_execution_exception_returns_error_json(self):
from airflow.providers.common.sql.datafusion.exceptions import
QueryExecutionException
cfg = _make_mock_datasource_config()
ts = DataFusionToolset([cfg])
engine = _make_mock_engine()
- engine.execute_query.side_effect = QueryExecutionException("execution
failed: column x not found")
+ engine.execute_query.side_effect = QueryExecutionException("execution
failed")
ts._engine = engine
result = asyncio.run(
@@ -285,9 +304,47 @@ class TestDataFusionToolsetQueryErrors:
)
)
data = json.loads(result)
- assert "column x not found" in data["error"]
+ assert "execution failed" in data["error"]
assert data["query"] == "SELECT x FROM t"
+ @pytest.mark.parametrize(
+ ("sql", "expected_substring"),
+ [
+ ("SELECT name FROM trips", "column 'name' not found"),
+ ("SELECT * FROM tripss", "table 'tripss' not found"),
+ ],
+ )
+ def test_retryable_query_errors_raise_model_retry(
+ self, tmp_path: Path, sql: str, expected_substring: str
+ ):
+ csv_path = tmp_path / "trips.csv"
+ csv_path.write_text("pickup_date,id\n2024-01-01,1\n")
+
+ ts = DataFusionToolset(
+ [
+ DataSourceConfig(
+ conn_id="",
+ table_name="trips",
+ uri=f"file://{csv_path}",
+ format="csv",
+ )
+ ]
+ )
+
+ with pytest.raises(ModelRetry) as exc_info:
+ asyncio.run(
+ ts.call_tool(
+ "query",
+ {"sql": sql},
+ ctx=MagicMock(spec=RunContext),
+ tool=MagicMock(spec=ToolsetTool),
+ )
+ )
+
+ assert expected_substring in exc_info.value.message
+ assert "get_schema" in exc_info.value.message
+ assert "list_tables" in exc_info.value.message
+
def test_unexpected_exception_propagates(self):
cfg = _make_mock_datasource_config()
ts = DataFusionToolset([cfg])