This is an automated email from the ASF dual-hosted git repository.
kaxil pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 6769feb82e1 Return common.ai SQLToolset errors to the agent so it
self-corrects (#68117)
6769feb82e1 is described below
commit 6769feb82e1d8e8f71b29a81691cd00680ebe41a
Author: Kaxil Naik <[email protected]>
AuthorDate: Sun Jun 7 21:20:09 2026 +0100
Return common.ai SQLToolset errors to the agent so it self-corrects (#68117)
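SQLToolset's list_tables and get_schema tools raised on any database
error, and the query tool raised on every error except a few recognized
driver-specific ones (such as a missing column or table). This failed the
whole @task.agent task. The agent never saw the error, so it could not
correct its SQL and retry within the run.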
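call_tool now surfaces any tool failure to the agent as ModelRetry. The
retry carries the underlying error's message: the database's own message
for database errors, or the safety validator's message for a rejected
statement. It does this without inspecting the error type or text. Only an
unknown tool name still raises ValueError. pydantic-ai bounds the loop with
the tool's max_retries. A fixable SQL error is therefore corrected within
the run, while an unrecoverable one (a bad connection, an auth failure)
exhausts the budget and fails the task for Airflow to retry.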
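The airflow_toolset_to_langchain_tools bridge feeds ModelRetry back to
the model as tool output. It now bounds that by the tool's max_retries
too (1 when unset). Once a tool raises ModelRetry more times in a row
than its budget allows, the ModelRetry is raised instead of being fed
back forever. The count resets after a successful call.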
Removes the previous _is_retryable_query_error classifier and its
per-driver exception-tuple imports. Matching driver-specific exception
classes and error-message substrings is brittle and cannot cover every
backend.
---
.../common/ai/toolsets/langchain_bridge.py | 38 ++++-
.../airflow/providers/common/ai/toolsets/sql.py | 88 +++-------
.../common/ai/toolsets/test_langchain_bridge.py | 34 ++++
.../ai/tests/unit/common/ai/toolsets/test_sql.py | 189 ++++++---------------
4 files changed, 143 insertions(+), 206 deletions(-)
diff --git
a/providers/common/ai/src/airflow/providers/common/ai/toolsets/langchain_bridge.py
b/providers/common/ai/src/airflow/providers/common/ai/toolsets/langchain_bridge.py
index 3f5762679ae..3c2765c1a70 100644
---
a/providers/common/ai/src/airflow/providers/common/ai/toolsets/langchain_bridge.py
+++
b/providers/common/ai/src/airflow/providers/common/ai/toolsets/langchain_bridge.py
@@ -85,6 +85,11 @@ def airflow_toolset_to_langchain_tools(
works regardless of how the agent handles tool errors. Raising instead
would
abort the run under ``create_agent``'s default tool-error handling.
+ The retry message is bounded by the tool's ``max_retries``: a tool that
keeps
+ raising ``ModelRetry`` (for example an unrecoverable connection error)
stops
+ being fed back and propagates once the budget is exhausted, so the run
fails
+ instead of looping forever. The count resets after a successful call.
+
The toolset's ``get_tools`` is invoked eagerly here to enumerate the tools.
.. warning::
@@ -148,20 +153,39 @@ def _build_structured_tool(
# the args unchanged; a typed one coerces them (e.g. "5" -> 5).
return toolset_tool.args_validator.validate_python(kwargs)
+ # ModelRetry is a "feed this back to the model and retry" signal, so the
bridge
+ # returns its message as the tool output instead of raising (see
docstring).
+ # Bound it the way native pydantic-ai does, via the tool's max_retries: a
tool
+ # that keeps raising ModelRetry (e.g. an unrecoverable connection error)
must
+ # eventually propagate so the run fails rather than looping forever. The
count
+ # resets on the first successful call.
+ max_retries = toolset_tool.max_retries if toolset_tool.max_retries is not
None else 1
+ retries = {"count": 0}
+
+ def _handle_retry(error: ModelRetry) -> str:
+ retries["count"] += 1
+ if retries["count"] > max_retries:
+ # Reset before propagating so a reused tool starts the next run
with a
+ # fresh budget instead of staying permanently exhausted.
+ retries["count"] = 0
+ raise error
+ return str(error)
+
def _sync_call(**kwargs: Any) -> Any:
try:
- return _run_coro_sync(toolset.call_tool(name, _validate(kwargs),
ctx, toolset_tool))
+ result = _run_coro_sync(toolset.call_tool(name, _validate(kwargs),
ctx, toolset_tool))
except ModelRetry as e:
- # ModelRetry is a "feed this back to the model and retry" signal,
not a
- # failure. Return the message as the tool output so the model
self-corrects
- # (see docstring); raising would abort under create_agent's
default handling.
- return str(e)
+ return _handle_retry(e)
+ retries["count"] = 0
+ return result
async def _async_call(**kwargs: Any) -> Any:
try:
- return await toolset.call_tool(name, _validate(kwargs), ctx,
toolset_tool)
+ result = await toolset.call_tool(name, _validate(kwargs), ctx,
toolset_tool)
except ModelRetry as e:
- return str(e)
+ return _handle_retry(e)
+ retries["count"] = 0
+ return result
return structured_tool_cls.from_function(
func=_sync_call,
diff --git
a/providers/common/ai/src/airflow/providers/common/ai/toolsets/sql.py
b/providers/common/ai/src/airflow/providers/common/ai/toolsets/sql.py
index ee3128705a1..45990901e15 100644
--- a/providers/common/ai/src/airflow/providers/common/ai/toolsets/sql.py
+++ b/providers/common/ai/src/airflow/providers/common/ai/toolsets/sql.py
@@ -19,7 +19,6 @@
from __future__ import annotations
import json
-import sqlite3
from contextlib import suppress
from typing import TYPE_CHECKING, Any
@@ -76,31 +75,6 @@ _CHECK_QUERY_SCHEMA: dict[str, Any] = {
"required": ["sql"],
}
-_POSTGRES_RETRYABLE_EXCEPTIONS: tuple[type[Exception], ...] = ()
-with suppress(ImportError):
- import psycopg2.errors as _psycopg2_errors
-
- _POSTGRES_RETRYABLE_EXCEPTIONS += (
- _psycopg2_errors.UndefinedColumn,
- _psycopg2_errors.UndefinedTable,
- )
-
-with suppress(ImportError):
- from psycopg import errors as _psycopg3_errors
-
- _POSTGRES_RETRYABLE_EXCEPTIONS += (
- _psycopg3_errors.UndefinedColumn,
- _psycopg3_errors.UndefinedTable,
- )
-
-_SQLALCHEMY_RETRYABLE_EXCEPTIONS: tuple[type[Exception], ...] = ()
-with suppress(ImportError):
- from sqlalchemy.exc import (
- ProgrammingError as _SQLAlchemyProgrammingError,
- )
-
- _SQLALCHEMY_RETRYABLE_EXCEPTIONS = (_SQLAlchemyProgrammingError,)
-
class SQLToolset(AbstractToolset[Any]):
"""
@@ -112,6 +86,13 @@ class SQLToolset(AbstractToolset[Any]):
Uses a :class:`~airflow.providers.common.sql.hooks.sql.DbApiHook` resolved
lazily from the given ``db_conn_id``.
+ When a tool fails, the database's error message is returned to the agent
as a
+ retry (:class:`pydantic_ai.ModelRetry`) so the model can correct its SQL
within
+ the run instead of failing the task. ``pydantic-ai`` bounds this by the
tool's
+ ``max_retries``, so an unrecoverable error -- a bad connection or an auth
+ failure -- exhausts the retries and fails the task for Airflow to retry.
The
+ toolset does not inspect the error type or message.
+
:param db_conn_id: Airflow connection ID for the database.
:param allowed_tables: Restrict which tables the agent can discover via
``list_tables`` and ``get_schema``. ``None`` (default) exposes all
tables
@@ -243,15 +224,27 @@ class SQLToolset(AbstractToolset[Any]):
ctx: RunContext[Any],
tool: ToolsetTool[Any],
) -> Any:
- if name == "list_tables":
- return self._list_tables()
- if name == "get_schema":
- return self._get_schema(tool_args["table_name"])
- if name == "query":
- return self._query(tool_args["sql"])
- if name == "check_query":
+ if name not in ("list_tables", "get_schema", "query", "check_query"):
+ raise ValueError(f"Unknown tool: {name!r}")
+ try:
+ if name == "list_tables":
+ return self._list_tables()
+ if name == "get_schema":
+ return self._get_schema(tool_args["table_name"])
+ if name == "query":
+ return self._query(tool_args["sql"])
return self._check_query(tool_args["sql"])
- raise ValueError(f"Unknown tool: {name!r}")
+ except Exception as e:
+ # Hand the database's own error back to the agent as a retry so it
can
+ # read the message and fix its SQL within the run. pydantic-ai
bounds
+ # this by the tool's max_retries, so an unrecoverable error (a bad
+ # connection, an auth failure) exhausts the budget and fails the
task
+ # for Airflow to retry, rather than being silently worked around.
+ raise ModelRetry(
+ f"The {name} tool failed: {e}\n"
+ "Use the list_tables and get_schema tools to inspect the
database, "
+ "then fix the query and try again."
+ ) from e
# ------------------------------------------------------------------
# Tool implementations
@@ -317,14 +310,7 @@ class SQLToolset(AbstractToolset[Any]):
allow_read_only_metadata=True,
)
- try:
- rows = hook.get_records(sql)
- except Exception as e:
- if self._is_retryable_query_error(hook, e):
- raise ModelRetry(
- f"error: {e!s}, Use get_schema and list_tables tools for
more details."
- ) from e
- raise
+ rows = hook.get_records(sql)
# Fetch column names from cursor description.
col_names: list[str] | None = None
if hook.last_description:
@@ -343,24 +329,6 @@ class SQLToolset(AbstractToolset[Any]):
output["max_rows"] = self._max_rows
return json.dumps(output, default=str)
- @staticmethod
- def _is_retryable_query_error(hook: DbApiHook, error: Exception) -> bool:
- check_error = getattr(error, "orig", error)
- conn_type = getattr(hook, "conn_type", None)
- if conn_type == "postgres":
- return bool(_POSTGRES_RETRYABLE_EXCEPTIONS) and isinstance(
- check_error, _POSTGRES_RETRYABLE_EXCEPTIONS
- )
- if conn_type == "sqlite":
- if isinstance(check_error, sqlite3.OperationalError):
- message = str(check_error).lower()
- return "no such column" in message or "no such table" in
message
- return False
- if _SQLALCHEMY_RETRYABLE_EXCEPTIONS and isinstance(error,
_SQLALCHEMY_RETRYABLE_EXCEPTIONS):
- return True
- # TODO: Add support for other databases.
- return False
-
def _check_query(self, sql: str) -> str:
# Resolve the dialect best-effort: if the connection can't be reached
we
# still syntax-check dialect-agnostically rather than reporting
invalid.
diff --git
a/providers/common/ai/tests/unit/common/ai/toolsets/test_langchain_bridge.py
b/providers/common/ai/tests/unit/common/ai/toolsets/test_langchain_bridge.py
index 6187f4e9d4c..89bd866772e 100644
--- a/providers/common/ai/tests/unit/common/ai/toolsets/test_langchain_bridge.py
+++ b/providers/common/ai/tests/unit/common/ai/toolsets/test_langchain_bridge.py
@@ -153,6 +153,40 @@ class TestAirflowToolsetToLangChainTools:
assert boom.invoke({}) == "fix your input and try again"
+ def test_repeated_model_retry_propagates_then_resets(self):
+ # A tool that keeps raising ModelRetry must not loop forever: once the
tool's
+ # max_retries (1 here) is exhausted, the error propagates so the run
fails
+ # instead of the bridge feeding the message back indefinitely. The
budget then
+ # resets so a reused tool is not poisoned for the next run.
+ boom = {t.name: t for t in
airflow_toolset_to_langchain_tools(FakeToolset())}["boom"]
+
+ assert boom.invoke({}) == "fix your input and try again" # fed back
+ with pytest.raises(ModelRetry, match="fix your input"): # budget
exhausted -> propagates
+ boom.invoke({})
+ assert boom.invoke({}) == "fix your input and try again" # reset ->
fed back again
+
+ def test_model_retry_count_resets_after_success(self):
+ # The retry budget resets on a successful call: fail (fed back),
succeed
+ # (reset), fail again (fed back rather than immediately propagating).
+ class FlakyToolset(FakeToolset):
+ def __init__(self) -> None:
+ super().__init__()
+ self._outcomes = iter([ModelRetry("retry 1"), "ok",
ModelRetry("retry 2")])
+
+ async def call_tool(self, name, tool_args, ctx, tool) -> Any:
+ if name == "boom":
+ outcome = next(self._outcomes)
+ if isinstance(outcome, ModelRetry):
+ raise outcome
+ return outcome
+ return await super().call_tool(name, tool_args, ctx, tool)
+
+ boom = {t.name: t for t in
airflow_toolset_to_langchain_tools(FlakyToolset())}["boom"]
+
+ assert boom.invoke({}) == "retry 1" # count 0 -> 1, returned
+ assert boom.invoke({}) == "ok" # success, count reset to 0
+ assert boom.invoke({}) == "retry 2" # count 0 -> 1 again, returned
(not propagated)
+
def test_model_retry_returned_as_tool_output_async(self):
boom = {t.name: t for t in
airflow_toolset_to_langchain_tools(FakeToolset())}["boom"]
diff --git a/providers/common/ai/tests/unit/common/ai/toolsets/test_sql.py
b/providers/common/ai/tests/unit/common/ai/toolsets/test_sql.py
index 5e425597a32..92e033a6b60 100644
--- a/providers/common/ai/tests/unit/common/ai/toolsets/test_sql.py
+++ b/providers/common/ai/tests/unit/common/ai/toolsets/test_sql.py
@@ -17,16 +17,13 @@
from __future__ import annotations
import asyncio
-import importlib.util
import json
-import sqlite3
from unittest.mock import MagicMock, PropertyMock, patch
import pytest
from pydantic_ai.exceptions import ModelRetry
from airflow.providers.common.ai.toolsets.sql import SQLToolset
-from airflow.providers.common.ai.utils.sql_validation import SQLSafetyError
from airflow.providers.common.sql.hooks.sql import DbApiHook
@@ -87,6 +84,17 @@ class TestSQLToolsetListTables:
tables = json.loads(result)
assert tables == ["orders"]
+ def test_introspection_error_raises_model_retry(self):
+ """A failure while listing tables is returned to the agent as a
retry."""
+ ts = SQLToolset("pg_default")
+ mock_hook = _make_mock_db_hook()
+ mock_hook.inspector.get_table_names.side_effect = Exception("could not
connect to server")
+ ts._hook = mock_hook
+
+ with pytest.raises(ModelRetry) as exc_info:
+ asyncio.run(ts.call_tool("list_tables", {}, ctx=MagicMock(),
tool=MagicMock()))
+ assert "could not connect to server" in exc_info.value.message
+
class TestSQLToolsetGetSchema:
def test_returns_column_info(self):
@@ -112,6 +120,18 @@ class TestSQLToolsetGetSchema:
assert "error" in data
assert "secrets" in data["error"]
+ def test_introspection_error_raises_model_retry(self):
+ """A failure while reading a table's schema is returned to the agent
as a retry."""
+ ts = SQLToolset("pg_default")
+ ts._hook = _make_mock_db_hook()
+ ts._hook.get_table_schema.side_effect = Exception('relation "users"
does not exist')
+
+ with pytest.raises(ModelRetry) as exc_info:
+ asyncio.run(
+ ts.call_tool("get_schema", {"table_name": "users"},
ctx=MagicMock(), tool=MagicMock())
+ )
+ assert "does not exist" in exc_info.value.message
+
class TestSQLToolsetQuery:
def test_returns_rows_as_json(self):
@@ -143,12 +163,14 @@ class TestSQLToolsetQuery:
assert data["truncated"] is True
assert data["count"] == 3
- def test_blocks_unsafe_sql_by_default(self):
+ def test_unsafe_sql_raises_model_retry(self):
+ """An unsafe statement is surfaced to the agent as a retry so it can
switch to a SELECT."""
ts = SQLToolset("pg_default")
ts._hook = _make_mock_db_hook()
- with pytest.raises(SQLSafetyError, match="not allowed"):
+ with pytest.raises(ModelRetry) as exc_info:
asyncio.run(ts.call_tool("query", {"sql": "DROP TABLE users"},
ctx=MagicMock(), tool=MagicMock()))
+ assert "not allowed" in exc_info.value.message
def test_allows_writes_when_enabled(self):
ts = SQLToolset("pg_default", allow_writes=True)
@@ -167,145 +189,30 @@ class TestSQLToolsetQuery:
data = json.loads(result)
assert "rows" in data
- def test_raises_model_retry_when_query_fails_with_retryable_error(self):
- """When the query fails with a retryable error, raise ModelRetry so
the model retries."""
- ts = SQLToolset("pg_default")
- ts._hook = _make_mock_db_hook()
- ts._hook.conn_type = "sqlite"
- ts._hook.get_records.side_effect = sqlite3.OperationalError("no such
column: nonexistent")
-
- with pytest.raises(ModelRetry) as exc_info:
- asyncio.run(
- ts.call_tool(
- "query",
- {"sql": "SELECT id, nonexistent FROM users"},
- ctx=MagicMock(),
- tool=MagicMock(),
- )
- )
- assert "nonexistent" in exc_info.value.message
- assert "get_schema" in exc_info.value.message
- assert "list_tables" in exc_info.value.message
-
- def test_model_retry_message_includes_schema_hint(self):
- """ModelRetry message tells the model to use get_schema and
list_tables for more details."""
- ts = SQLToolset("pg_default")
- ts._hook = _make_mock_db_hook()
- ts._hook.conn_type = "sqlite"
- ts._hook.get_records.side_effect = sqlite3.OperationalError("no such
table: missing_table")
-
- with pytest.raises(ModelRetry) as exc_info:
- asyncio.run(
- ts.call_tool("query", {"sql": "SELECT foo FROM x"},
ctx=MagicMock(), tool=MagicMock())
- )
- assert "get_schema" in exc_info.value.message
- assert "list_tables" in exc_info.value.message
-
- def test_non_retryable_error_is_propagated(self):
- ts = SQLToolset("pg_default")
- ts._hook = _make_mock_db_hook()
- ts._hook.conn_type = "sqlite"
- ts._hook.get_records.side_effect = sqlite3.OperationalError("database
is locked")
-
- with pytest.raises(sqlite3.OperationalError, match="database is
locked"):
- asyncio.run(ts.call_tool("query", {"sql": "SELECT 1"},
ctx=MagicMock(), tool=MagicMock()))
-
- def test_error_propagates_when_hook_conn_type_not_supported(self):
- ts = SQLToolset("pg_default")
- ts._hook = _make_mock_db_hook()
- ts._hook.conn_type = "mysql"
- ts._hook.get_records.side_effect = RuntimeError("unexpected db error")
-
- with pytest.raises(RuntimeError, match="unexpected db error"):
- asyncio.run(ts.call_tool("query", {"sql": "SELECT 1"},
ctx=MagicMock(), tool=MagicMock()))
-
- def test_error_propagates_when_hook_has_no_conn_type(self):
- ts = SQLToolset("pg_default")
- mock_hook = MagicMock(spec=["get_records", "last_description"])
- mock_hook.get_records.side_effect = RuntimeError("hook error")
- type(mock_hook).last_description = PropertyMock(return_value=[])
- ts._hook = mock_hook
-
- with pytest.raises(RuntimeError, match="hook error"):
- asyncio.run(ts.call_tool("query", {"sql": "SELECT 1"},
ctx=MagicMock(), tool=MagicMock()))
-
- @pytest.mark.skipif(
- importlib.util.find_spec("psycopg2") is None,
- reason="psycopg2 is not available for lowest dependency tests",
- )
- def
test_sqlalchemy_programming_error_with_psycopg2_undefined_column_orig_raises_model_retry_for_postgres(
- self,
- ):
- from psycopg2 import errors as psycopg2_errors
- from sqlalchemy.exc import ProgrammingError
-
- ts = SQLToolset("pg_default")
- ts._hook = _make_mock_db_hook()
- ts._hook.conn_type = "postgres"
- ts._hook.get_records.side_effect = ProgrammingError(
- statement="SELECT id, missing FROM users",
- params=None,
- orig=psycopg2_errors.UndefinedColumn('column "missing" does not
exist'),
- )
-
- with (
- patch(
-
"airflow.providers.common.ai.toolsets.sql._POSTGRES_RETRYABLE_EXCEPTIONS",
- (psycopg2_errors.UndefinedColumn,),
- ),
- patch(
-
"airflow.providers.common.ai.toolsets.sql._SQLALCHEMY_RETRYABLE_EXCEPTIONS",
- (ProgrammingError,),
- ),
- pytest.raises(ModelRetry),
- ):
- asyncio.run(
- ts.call_tool(
- "query",
- {"sql": "SELECT id, missing FROM users"},
- ctx=MagicMock(),
- tool=MagicMock(),
- )
- )
-
- @pytest.mark.skipif(
- importlib.util.find_spec("psycopg2") is None,
- reason="psycopg2 is not available for lowest dependency tests",
+ @pytest.mark.parametrize(
+ "error",
+ [
+ Exception("001003 (42000): SQL compilation error: unexpected
'rows'"),
+ RuntimeError("type mismatch"),
+ ConnectionError("could not connect to server"),
+ ],
)
- def
test_sqlalchemy_programming_error_with_psycopg2_insufficient_privilege_orig_is_not_retried_for_postgres(
- self,
- ):
- from psycopg2 import errors as psycopg2_errors
- from sqlalchemy.exc import ProgrammingError
-
+ def test_query_error_is_returned_to_agent_as_model_retry(self, error):
+ """Any error from the query, whatever its type, is handed back to the
agent as a retry with
+ the database's own message. The toolset never inspects the error type
or text; pydantic-ai's
+ max_retries bounds the loop, so an unrecoverable error still fails the
task."""
ts = SQLToolset("pg_default")
ts._hook = _make_mock_db_hook()
- ts._hook.conn_type = "postgres"
- ts._hook.get_records.side_effect = ProgrammingError(
- statement="SELECT id FROM users",
- params=None,
- orig=psycopg2_errors.InsufficientPrivilege("permission denied for
table users"),
- )
+ ts._hook.get_records.side_effect = error
- with (
- patch(
-
"airflow.providers.common.ai.toolsets.sql._POSTGRES_RETRYABLE_EXCEPTIONS",
- (psycopg2_errors.UndefinedColumn,
psycopg2_errors.UndefinedTable),
- ),
- patch(
-
"airflow.providers.common.ai.toolsets.sql._SQLALCHEMY_RETRYABLE_EXCEPTIONS",
- (ProgrammingError,),
- ),
- pytest.raises(ProgrammingError),
- ):
+ with pytest.raises(ModelRetry) as exc_info:
asyncio.run(
- ts.call_tool(
- "query",
- {"sql": "SELECT id FROM users"},
- ctx=MagicMock(),
- tool=MagicMock(),
- )
+ ts.call_tool("query", {"sql": "SELECT foo FROM bar"},
ctx=MagicMock(), tool=MagicMock())
)
+ message = exc_info.value.message
+ assert str(error) in message
+ assert "list_tables" in message
+ assert "get_schema" in message
class TestSQLToolsetCheckQuery:
@@ -536,8 +443,12 @@ class TestSQLToolsetMetadataStatements:
ts._hook = _make_mock_db_hook()
ts._hook.dialect_name = "postgresql"
- with pytest.raises(SQLSafetyError, match="not allowed"):
+ # The statement is rejected before execution and surfaced to the agent
as a
+ # retry; get_records is never reached, so the guardrail still holds.
+ with pytest.raises(ModelRetry) as exc_info:
asyncio.run(ts.call_tool("query", {"sql": sql}, ctx=MagicMock(),
tool=MagicMock()))
+ assert "not allowed" in exc_info.value.message
+ ts._hook.get_records.assert_not_called()
def test_check_query_accepts_describe(self):
ts = SQLToolset("pg_default")