This is an automated email from the ASF dual-hosted git repository.
gopidesu pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new d349918fa4b SQLToolset: Retry model on query errors (#63358)
d349918fa4b is described below
commit d349918fa4b24d178b77761224260638ab59fe2e
Author: GPK <[email protected]>
AuthorDate: Thu Mar 12 07:54:40 2026 +0000
SQLToolset: Retry model on query errors (#63358)
* Add ModelRetry mechanism for sqltoolset to retry using RETRYABLE_ERRORS
* Move SQL retry classification into SQLToolset and narrow retryable errors
* Resolve comments
* fixup tests
---
.../airflow/providers/common/ai/toolsets/sql.py | 55 +++++++-
.../ai/tests/unit/common/ai/toolsets/test_sql.py | 143 +++++++++++++++++++++
2 files changed, 197 insertions(+), 1 deletion(-)
diff --git a/providers/common/ai/src/airflow/providers/common/ai/toolsets/sql.py b/providers/common/ai/src/airflow/providers/common/ai/toolsets/sql.py
index f60f4b621c3..0902cff99f2 100644
--- a/providers/common/ai/src/airflow/providers/common/ai/toolsets/sql.py
+++ b/providers/common/ai/src/airflow/providers/common/ai/toolsets/sql.py
@@ -19,6 +19,8 @@
from __future__ import annotations
import json
+import sqlite3
+from contextlib import suppress
from typing import TYPE_CHECKING, Any
try:
@@ -29,6 +31,7 @@ except ImportError as e:
raise AirflowOptionalProviderFeatureException(e)
+from pydantic_ai.exceptions import ModelRetry
from pydantic_ai.tools import ToolDefinition
from pydantic_ai.toolsets.abstract import AbstractToolset, ToolsetTool
from pydantic_core import SchemaValidator, core_schema
@@ -70,6 +73,31 @@ _CHECK_QUERY_SCHEMA: dict[str, Any] = {
"required": ["sql"],
}
+# Driver-level exception types that mean "the query referenced a column or
+# table that does not exist" on PostgreSQL. Each driver is optional: if its
+# ``errors`` module cannot be imported, ``suppress(ImportError)`` skips that
+# block and the tuple is left unchanged (it may remain empty).
+_POSTGRES_RETRYABLE_EXCEPTIONS: tuple[type[Exception], ...] = ()
+with suppress(ImportError):
+    import psycopg2.errors as _psycopg2_errors
+
+    _POSTGRES_RETRYABLE_EXCEPTIONS += (
+        _psycopg2_errors.UndefinedColumn,
+        _psycopg2_errors.UndefinedTable,
+    )
+
+with suppress(ImportError):
+    # psycopg (v3) is checked independently of psycopg2; when both are
+    # installed, the classes from both drivers end up in the tuple.
+    from psycopg import errors as _psycopg3_errors
+
+    _POSTGRES_RETRYABLE_EXCEPTIONS += (
+        _psycopg3_errors.UndefinedColumn,
+        _psycopg3_errors.UndefinedTable,
+    )
+
+# SQLAlchemy wrapper exception used as the fallback check for connection types
+# that have no dedicated branch in ``SQLToolset._is_retryable_query_error``.
+# Stays empty when SQLAlchemy cannot be imported.
+_SQLALCHEMY_RETRYABLE_EXCEPTIONS: tuple[type[Exception], ...] = ()
+with suppress(ImportError):
+    from sqlalchemy.exc import (
+        ProgrammingError as _SQLAlchemyProgrammingError,
+    )
+
+    _SQLALCHEMY_RETRYABLE_EXCEPTIONS = (_SQLAlchemyProgrammingError,)
+
class SQLToolset(AbstractToolset[Any]):
"""
@@ -204,7 +232,14 @@ class SQLToolset(AbstractToolset[Any]):
_validate_sql(sql)
hook = self._get_db_hook()
- rows = hook.get_records(sql)
+ try:
+ rows = hook.get_records(sql)
+ except Exception as e:
+ if self._is_retryable_query_error(hook, e):
+ raise ModelRetry(
+ f"error: {e!s}, Use get_schema and list_tables tools for
more details."
+ ) from e
+ raise
# Fetch column names from cursor description.
col_names: list[str] | None = None
if hook.last_description:
@@ -223,6 +258,24 @@ class SQLToolset(AbstractToolset[Any]):
output["max_rows"] = self._max_rows
return json.dumps(output, default=str)
+    @staticmethod
+    def _is_retryable_query_error(hook: DbApiHook, error: Exception) -> bool:
+        """
+        Decide whether a failed query should be handed back to the model as ``ModelRetry``.
+
+        For ``postgres`` and ``sqlite`` connections only a missing column or table
+        counts as retryable. For any other connection type (including a hook with
+        no ``conn_type``), any SQLAlchemy ``ProgrammingError`` counts as retryable.
+        The caller re-raises every error this returns ``False`` for.
+
+        :param hook: The hook that ran the query; its ``conn_type`` attribute, if
+            present, selects the database-specific check.
+        :param error: The exception raised by ``hook.get_records``.
+        :return: ``True`` if the model should be asked to retry, ``False`` otherwise.
+        """
+        # SQLAlchemy wrapper exceptions keep the driver exception on ``.orig``;
+        # fall back to the error itself when that attribute is absent.
+        check_error = getattr(error, "orig", error)
+        # A hook without ``conn_type`` skips the postgres/sqlite branches and
+        # only gets the SQLAlchemy fallback check below.
+        conn_type = getattr(hook, "conn_type", None)
+        if conn_type == "postgres":
+            # The ``bool(...)`` guard returns False when no psycopg driver was importable.
+            return bool(_POSTGRES_RETRYABLE_EXCEPTIONS) and isinstance(
+                check_error, _POSTGRES_RETRYABLE_EXCEPTIONS
+            )
+        if conn_type == "sqlite":
+            # sqlite3 raises OperationalError for unrelated failures too (e.g.
+            # "database is locked"), so the lower-cased message text decides.
+            if isinstance(check_error, sqlite3.OperationalError):
+                message = str(check_error).lower()
+                return "no such column" in message or "no such table" in message
+            return False
+        # NOTE(review): this fallback accepts every SQLAlchemy ProgrammingError
+        # regardless of the underlying driver error, which is broader than the
+        # postgres branch above; it also tests the wrapper ``error``, not ``check_error``.
+        if _SQLALCHEMY_RETRYABLE_EXCEPTIONS and isinstance(error, _SQLALCHEMY_RETRYABLE_EXCEPTIONS):
+            return True
+        # TODO: Add support for other databases.
+        return False
+
def _check_query(self, sql: str) -> str:
try:
_validate_sql(sql)
diff --git a/providers/common/ai/tests/unit/common/ai/toolsets/test_sql.py b/providers/common/ai/tests/unit/common/ai/toolsets/test_sql.py
index 0573acd2a77..471b956385d 100644
--- a/providers/common/ai/tests/unit/common/ai/toolsets/test_sql.py
+++ b/providers/common/ai/tests/unit/common/ai/toolsets/test_sql.py
@@ -17,10 +17,13 @@
from __future__ import annotations
import asyncio
+import importlib.util
import json
+import sqlite3
from unittest.mock import MagicMock, PropertyMock, patch
import pytest
+from pydantic_ai.exceptions import ModelRetry
from airflow.providers.common.ai.toolsets.sql import SQLToolset
from airflow.providers.common.ai.utils.sql_validation import SQLSafetyError
@@ -165,6 +168,146 @@ class TestSQLToolsetQuery:
data = json.loads(result)
assert "rows" in data
+    def test_raises_model_retry_when_query_fails_with_retryable_error(self):
+        """When the query fails with a retryable error, raise ModelRetry so the model retries."""
+        ts = SQLToolset("pg_default")
+        ts._hook = _make_mock_db_hook()
+        # conn_type "sqlite" routes the error through the message-based sqlite check.
+        ts._hook.conn_type = "sqlite"
+        ts._hook.get_records.side_effect = sqlite3.OperationalError("no such column: nonexistent")
+
+        with pytest.raises(ModelRetry) as exc_info:
+            asyncio.run(
+                ts.call_tool(
+                    "query",
+                    {"sql": "SELECT id, nonexistent FROM users"},
+                    ctx=MagicMock(),
+                    tool=MagicMock(),
+                )
+            )
+        # The retry message carries the original error text plus the tool hint.
+        assert "nonexistent" in exc_info.value.message
+        assert "get_schema" in exc_info.value.message
+        assert "list_tables" in exc_info.value.message
+    def test_model_retry_message_includes_schema_hint(self):
+        """ModelRetry message tells the model to use get_schema and list_tables for more details."""
+        ts = SQLToolset("pg_default")
+        ts._hook = _make_mock_db_hook()
+        ts._hook.conn_type = "sqlite"
+        # Exercises the "no such table" variant of the sqlite message check.
+        ts._hook.get_records.side_effect = sqlite3.OperationalError("no such table: missing_table")
+
+        with pytest.raises(ModelRetry) as exc_info:
+            asyncio.run(
+                ts.call_tool("query", {"sql": "SELECT foo FROM x"}, ctx=MagicMock(), tool=MagicMock())
+            )
+        assert "get_schema" in exc_info.value.message
+        assert "list_tables" in exc_info.value.message
+
+    def test_non_retryable_error_is_propagated(self):
+        """A sqlite OperationalError that is not a missing column/table is re-raised as-is."""
+        ts = SQLToolset("pg_default")
+        ts._hook = _make_mock_db_hook()
+        ts._hook.conn_type = "sqlite"
+        ts._hook.get_records.side_effect = sqlite3.OperationalError("database is locked")
+
+        with pytest.raises(sqlite3.OperationalError, match="database is locked"):
+            asyncio.run(ts.call_tool("query", {"sql": "SELECT 1"}, ctx=MagicMock(), tool=MagicMock()))
+
+    def test_error_propagates_when_hook_conn_type_not_supported(self):
+        """A non-SQLAlchemy error from a conn_type without a dedicated check is re-raised as-is."""
+        ts = SQLToolset("pg_default")
+        ts._hook = _make_mock_db_hook()
+        ts._hook.conn_type = "mysql"
+        ts._hook.get_records.side_effect = RuntimeError("unexpected db error")
+
+        with pytest.raises(RuntimeError, match="unexpected db error"):
+            asyncio.run(ts.call_tool("query", {"sql": "SELECT 1"}, ctx=MagicMock(), tool=MagicMock()))
+
+    def test_error_propagates_when_hook_has_no_conn_type(self):
+        """A hook with no ``conn_type`` attribute still lets non-retryable errors propagate."""
+        ts = SQLToolset("pg_default")
+        # spec= limits the mock to these attributes, so getattr(hook, "conn_type", None) yields None.
+        mock_hook = MagicMock(spec=["get_records", "last_description"])
+        mock_hook.get_records.side_effect = RuntimeError("hook error")
+        type(mock_hook).last_description = PropertyMock(return_value=[])
+        ts._hook = mock_hook
+
+        with pytest.raises(RuntimeError, match="hook error"):
+            asyncio.run(ts.call_tool("query", {"sql": "SELECT 1"}, ctx=MagicMock(), tool=MagicMock()))
+
+    @pytest.mark.skipif(
+        importlib.util.find_spec("psycopg2") is None,
+        reason="psycopg2 is not available for lowest dependency tests",
+    )
+    def test_sqlalchemy_programming_error_with_psycopg2_undefined_column_orig_raises_model_retry_for_postgres(
+        self,
+    ):
+        """A SQLAlchemy ProgrammingError wrapping psycopg2 UndefinedColumn is retried for postgres."""
+        from psycopg2 import errors as psycopg2_errors
+        from sqlalchemy.exc import ProgrammingError
+
+        ts = SQLToolset("pg_default")
+        ts._hook = _make_mock_db_hook()
+        ts._hook.conn_type = "postgres"
+        # The driver error sits on ``orig``, which the retry check unwraps.
+        ts._hook.get_records.side_effect = ProgrammingError(
+            statement="SELECT id, missing FROM users",
+            params=None,
+            orig=psycopg2_errors.UndefinedColumn('column "missing" does not exist'),
+        )
+
+        # Pin the module-level tuples (normally built from whichever drivers
+        # import at module load) so the outcome does not depend on the environment.
+        with (
+            patch(
+                "airflow.providers.common.ai.toolsets.sql._POSTGRES_RETRYABLE_EXCEPTIONS",
+                (psycopg2_errors.UndefinedColumn,),
+            ),
+            patch(
+                "airflow.providers.common.ai.toolsets.sql._SQLALCHEMY_RETRYABLE_EXCEPTIONS",
+                (ProgrammingError,),
+            ),
+            pytest.raises(ModelRetry),
+        ):
+            asyncio.run(
+                ts.call_tool(
+                    "query",
+                    {"sql": "SELECT id, missing FROM users"},
+                    ctx=MagicMock(),
+                    tool=MagicMock(),
+                )
+            )
+
+    @pytest.mark.skipif(
+        importlib.util.find_spec("psycopg2") is None,
+        reason="psycopg2 is not available for lowest dependency tests",
+    )
+    def test_sqlalchemy_programming_error_with_psycopg2_insufficient_privilege_orig_is_not_retried_for_postgres(
+        self,
+    ):
+        """A SQLAlchemy ProgrammingError wrapping psycopg2 InsufficientPrivilege is re-raised for postgres."""
+        from psycopg2 import errors as psycopg2_errors
+        from sqlalchemy.exc import ProgrammingError
+
+        ts = SQLToolset("pg_default")
+        ts._hook = _make_mock_db_hook()
+        ts._hook.conn_type = "postgres"
+        ts._hook.get_records.side_effect = ProgrammingError(
+            statement="SELECT id FROM users",
+            params=None,
+            orig=psycopg2_errors.InsufficientPrivilege("permission denied for table users"),
+        )
+
+        # Even with ProgrammingError in the SQLAlchemy tuple, conn_type "postgres"
+        # only checks the unwrapped ``orig`` against the postgres tuple, so the
+        # wrapper type alone must not make this error retryable.
+        with (
+            patch(
+                "airflow.providers.common.ai.toolsets.sql._POSTGRES_RETRYABLE_EXCEPTIONS",
+                (psycopg2_errors.UndefinedColumn, psycopg2_errors.UndefinedTable),
+            ),
+            patch(
+                "airflow.providers.common.ai.toolsets.sql._SQLALCHEMY_RETRYABLE_EXCEPTIONS",
+                (ProgrammingError,),
+            ),
+            pytest.raises(ProgrammingError),
+        ):
+            asyncio.run(
+                ts.call_tool(
+                    "query",
+                    {"sql": "SELECT id FROM users"},
+                    ctx=MagicMock(),
+                    tool=MagicMock(),
+                )
+            )
+
class TestSQLToolsetCheckQuery:
def test_valid_select(self):