This is an automated email from the ASF dual-hosted git repository.

kaxil pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new b193ced28a7 Enforce SQLToolset allowed_tables on queries, not just discovery (#68487)
b193ced28a7 is described below

commit b193ced28a79ba567d36d95ad26445fb4719794a
Author: Kaxil Naik <[email protected]>
AuthorDate: Tue Jun 16 13:17:10 2026 +0100

    Enforce SQLToolset allowed_tables on queries, not just discovery (#68487)
    
    allowed_tables previously restricted only metadata discovery (list_tables /
    get_schema); the query and check_query tools never checked it, so an agent
    could read any table by name. It is now enforced on the query tools as well:
    the SQL is parsed with sqlglot and rejected before execution if it reaches a
    table that is not on the list, resolved with its database/catalog and with CTE
    references excluded by lexical scope.
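
As a rough illustration of the matching step described above, the sketch below case-folds each reference into a (catalog, schema, table) key and tests membership against keys built from the allow-list. It follows the shape of `_canonical_ref` in the diff, but the function names `canonical_ref`, `build_allow_list`, and `is_allowed` are hypothetical stand-ins, not the provider's API, and no SQL parsing is involved here.

```python
from typing import Optional

# Comparison key: (catalog, schema, table), each part case-folded or None.
Key = tuple[Optional[str], Optional[str], str]


def canonical_ref(catalog: Optional[str], schema: Optional[str], table: str) -> Key:
    """Case-fold a reference; empty or missing catalog/schema become None."""
    return (
        catalog.casefold() if catalog else None,
        schema.casefold() if schema else None,
        table.casefold(),
    )


def build_allow_list(entries: list[str], default_schema: Optional[str]) -> frozenset[Key]:
    """Turn "table" / "schema.table" entries into keys. Entries never carry a catalog."""
    keys: set[Key] = set()
    for entry in entries:
        schema, sep, table = entry.rpartition(".")
        # Unqualified entries fall back to the default schema.
        keys.add(canonical_ref(None, schema if sep else default_schema, table))
    return frozenset(keys)


def is_allowed(
    allowed: frozenset[Key],
    catalog: Optional[str],
    schema: Optional[str],
    table: str,
    default_schema: Optional[str],
) -> bool:
    """A reference passes only if its full key, catalog included, is on the list."""
    return canonical_ref(catalog, schema or default_schema, table) in allowed


allowed = build_allow_list(["orders", "Sales.Customers"], default_schema="public")
print(is_allowed(allowed, "", "", "ORDERS", "public"))
print(is_allowed(allowed, "otherdb", "public", "orders", "public"))
```

Because allow-list keys always have a `None` catalog, a catalog-qualified reference such as `otherdb.public.orders` can never match, which is how the cross-database case is refused.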
    
    Constructs a schema.table list cannot describe are rejected while a list is
    active: table-valued functions, TABLE('name') row sources, the TABLE <name>
    shorthand, SHOW, dynamic SQL, quoted identifiers, cross-database references,
    and inline comments (where MySQL executable /*! ... */ comments hide).
    
    This is application-level defense-in-depth, not a substitute for database
    permissions: data reached through a function whose argument is itself SQL or a
    path (pg_read_file, query_to_xml, scalar dblink) is out of its reach, so a
    least-privilege DB role remains the hard boundary.
    
    * Allow CTE sources in DML against allowed_tables
    
    A CTE used as a DML source (WITH src AS (...) INSERT INTO orders SELECT * FROM
    src) was falsely rejected: CTE scoping was disabled for the whole DML statement,
    so src was checked against allowed_tables as if it were a base table.
    
    Now only the DML target is exempt from CTE resolution (you cannot write to a CTE,
    so a same-named CTE never shadows the target); sources follow normal lexical CTE
    scoping. The target, and any off-list table a CTE body actually reads, are still
    enforced.
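
The DML rule above can be modelled without a parser. The sketch below is a toy model under stated assumptions, not the commit's implementation (which walks the sqlglot AST): a statement is described by its write target, the names its main query reads, and an ordered mapping of CTE name to the names that CTE's body reads. `DmlStatement` and `base_tables` are hypothetical names, and only non-recursive CTEs are modelled.

```python
from dataclasses import dataclass, field


@dataclass
class DmlStatement:
    target: str                 # table being written to
    sources: list[str]          # names read by the main query body
    # CTE name -> names read inside that CTE's body, in definition order.
    ctes: dict[str, list[str]] = field(default_factory=dict)


def base_tables(stmt: DmlStatement) -> list[str]:
    """Return the real tables that must be checked against the allow-list."""
    # The target is always a real table: a same-named CTE never shadows it.
    found = [stmt.target.casefold()]
    visible: set[str] = set()
    for name, body_reads in stmt.ctes.items():
        # Inside a CTE body only earlier siblings are in scope, so the CTE's
        # own name and later CTE names still resolve to real tables there.
        found += [t.casefold() for t in body_reads if t.casefold() not in visible]
        visible.add(name.casefold())
    # In the main query body every CTE is visible, so CTE sources drop out.
    found += [s.casefold() for s in stmt.sources if s.casefold() not in visible]
    return sorted(set(found))


# WITH src AS (SELECT * FROM staging) INSERT INTO orders SELECT * FROM src
stmt = DmlStatement(target="orders", sources=["src"], ctes={"src": ["staging"]})
print(base_tables(stmt))
```

In this model `src` is not reported, while `orders` (the target) and `staging` (read by the CTE body) still are, matching the behaviour the message describes.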
---
 providers/common/ai/docs/toolsets.rst              |  78 ++++--
 .../airflow/providers/common/ai/toolsets/sql.py    | 155 ++++++++---
 .../providers/common/ai/utils/sql_validation.py    | 259 +++++++++++++++++--
 .../ai/tests/unit/common/ai/toolsets/test_sql.py   | 282 +++++++++++++++++++++
 .../unit/common/ai/utils/test_sql_validation.py    | 192 ++++++++++++++
 5 files changed, 883 insertions(+), 83 deletions(-)

diff --git a/providers/common/ai/docs/toolsets.rst b/providers/common/ai/docs/toolsets.rst
index 33814d50644..d284227c0ac 100644
--- a/providers/common/ai/docs/toolsets.rst
+++ b/providers/common/ai/docs/toolsets.rst
@@ -155,10 +155,10 @@ the validator, so ``SHOW`` is recognized on databases that support it (Snowflake
 MySQL, etc.); on databases without ``SHOW`` it stays rejected. Data-modifying
 statements remain blocked -- including ones hidden behind ``DESCRIBE``/``EXPLAIN``
 (e.g. ``EXPLAIN DELETE ...``, ``DESCRIBE DROP TABLE ...``), which the validator
-rejects by scanning the parsed statement for write operations. Like ``SELECT``,
-metadata statements are not scoped by ``allowed_tables`` (see
-:ref:`allowed-tables-limitation`) -- an agent can ``DESCRIBE`` a table outside the
-list, so rely on database permissions to restrict access.
+rejects by scanning the parsed statement for write operations. When
+``allowed_tables`` is set it scopes these statements too: a ``DESCRIBE`` names a
+table, so its target must be on the list, while ``SHOW`` enumerates objects beyond
+any single table and is rejected outright (see :ref:`allowed-tables-enforcement`).
 
 Multi-schema warehouses
 ^^^^^^^^^^^^^^^^^^^^^^^^^
@@ -186,11 +186,13 @@ Parameters
 ^^^^^^^^^^
 
 - ``db_conn_id``: Airflow connection ID for the database.
-- ``allowed_tables``: Restrict which tables the agent can discover via
-  ``list_tables`` and ``get_schema``. ``None`` (default) exposes all tables in
-  ``schema``. Entries may be schema-qualified (``"SCHEMA.TABLE"``) to span
-  multiple schemas; see above. Matching is case-insensitive.
-  See :ref:`allowed-tables-limitation` for an important caveat.
+- ``allowed_tables``: Restrict the agent to a fixed set of tables. ``None``
+  (default) exposes all tables in ``schema``. Entries may be schema-qualified
+  (``"SCHEMA.TABLE"``) to span multiple schemas; see above. Matching is
+  case-insensitive. When set, the list is enforced on ``query`` and
+  ``check_query`` as well as discovery -- every table a query references must be
+  on it. See :ref:`allowed-tables-enforcement` for what this does and does not
+  guarantee.
 - ``schema``: Default schema/namespace for unqualified table listing and
   introspection. Schema-qualified ``allowed_tables`` entries override it per table.
 - ``allow_writes``: Allow data-modifying SQL (INSERT, UPDATE, DELETE, etc.).
@@ -583,11 +585,14 @@ No single layer is sufficient — they work together.
        INTO, and other non-SELECT statements.
      - Does not prevent the agent from reading any registered data source.
    * - **SQLToolset: allowed_tables**
-     - Restricts which tables appear in ``list_tables`` and ``get_schema``
-       responses, limiting the agent's knowledge of the schema.
-     - Does **not** validate table references in SQL queries. The agent can
-       still query unlisted tables if it guesses the name. See
-       :ref:`allowed-tables-limitation` below.
+     - Restricts the agent to listed tables across ``list_tables``,
+       ``get_schema``, ``query``, and ``check_query``. Queries are parsed and
+       every referenced table (including via subqueries, CTEs, JOINs, and
+       ``DESCRIBE``) is checked against the list before execution.
+     - Cannot police data reached through side-effecting scalar functions
+       (e.g. ``pg_read_file``), and is only as exact as the SQL parser. Pair it
+       with least-privilege database grants. See
+       :ref:`allowed-tables-enforcement` below.
    * - **SQLToolset: max_rows**
      - Truncates query results to ``max_rows`` (default 50), preventing the
        agent from pulling entire tables into context.
@@ -604,21 +609,38 @@ No single layer is sufficient — they work together.
      - Requires explicit configuration — the default allows many rounds.
 
 
-.. _allowed-tables-limitation:
+.. _allowed-tables-enforcement:
 
-The ``allowed_tables`` Limitation
-"""""""""""""""""""""""""""""""""
+How ``allowed_tables`` Is Enforced
+""""""""""""""""""""""""""""""""""
 
-``allowed_tables`` is a **metadata filter**, not an access control mechanism.
-It hides table names from ``list_tables`` and blocks ``get_schema`` for
-unlisted tables, but does not parse SQL queries to validate table references.
+When ``allowed_tables`` is set it governs every tool, not just discovery:
 
-An LLM can craft ``SELECT * FROM secrets`` even when
-``allowed_tables=["orders"]``. Parsing SQL for table references (including
-CTEs, subqueries, aliases, and vendor-specific syntax) is complex and
-error-prone; we chose not to provide a false sense of security.
+- ``list_tables`` and ``get_schema`` only reveal listed tables.
+- ``query`` and ``check_query`` parse the SQL with `sqlglot
+  <https://github.com/tobymao/sqlglot>`_ and reject it before execution if it
+  references any table that is not on the list. Tables reached indirectly are
+  caught too -- through subqueries, CTEs, JOINs, set operations (``UNION`` etc.),
+  ``DESCRIBE``, catalog views such as ``information_schema``, and DML. CTE
+  references are excluded by lexical scope, so a same-named CTE in another scope
+  cannot hide a real table, and the database/catalog is part of the match, so a
+  cross-database reference like ``otherdb.public.orders`` is refused.
+- Constructs the list cannot describe are rejected outright while it is active:
+  table-valued functions (``dblink``), ``TABLE('name')`` row sources, the
+  ``TABLE <name>`` shorthand, ``SHOW``, dynamic SQL (``EXEC``), and **inline
+  comments** -- the last because parser-vs-engine differences hide in comments
+  (MySQL executes ``/*! ... */`` while sqlglot and other engines ignore it).
 
-For query-level restrictions, use database permissions:
+So ``SELECT * FROM secrets`` with ``allowed_tables=["orders"]`` is refused, and
+the rejection is handed back to the agent so it can re-target an allowed table.
+
+This is a strong **application-level guardrail**, but it is not a substitute for
+database permissions. It cannot police data reached through a function whose
+argument is itself SQL or a path: ``pg_read_file('/etc/passwd')`` reads a file,
+and ``query_to_xml('SELECT * FROM other_table', ...)`` or a scalar ``dblink``
+reads a table through a string the parser cannot inspect. Any query the engine
+parses differently from sqlglot is also a residual gap. For a hard boundary, also
+run the connection as a least-privilege role:
 
 .. code-block:: sql
 
@@ -627,8 +649,10 @@ For query-level restrictions, use database permissions:
     GRANT SELECT ON orders, customers TO airflow_agent_reader;
     -- Use this role's credentials in the Airflow connection
 
-The Airflow connection should use a database user with the minimum privileges
-required.
+Defense in depth: the allow-list contains the agent's *intent* (and gives it a
+correctable error), while the database role is the boundary that holds even if
+the agent reaches data the parser cannot see. The connection should use a
+database user with the minimum privileges required.
 
 
 HookToolset Guidelines
diff --git a/providers/common/ai/src/airflow/providers/common/ai/toolsets/sql.py b/providers/common/ai/src/airflow/providers/common/ai/toolsets/sql.py
index 45990901e15..a4435971969 100644
--- a/providers/common/ai/src/airflow/providers/common/ai/toolsets/sql.py
+++ b/providers/common/ai/src/airflow/providers/common/ai/toolsets/sql.py
@@ -24,6 +24,9 @@ from typing import TYPE_CHECKING, Any
 
 try:
     from airflow.providers.common.ai.utils.sql_validation import (
+        SQLSafetyError,
+        collect_table_references,
+        parse_sql as _parse_sql,
         resolve_sqlglot_dialect,
         validate_sql as _validate_sql,
     )
@@ -94,21 +97,37 @@ class SQLToolset(AbstractToolset[Any]):
     toolset does not inspect the error type or message.
 
     :param db_conn_id: Airflow connection ID for the database.
-    :param allowed_tables: Restrict which tables the agent can discover via
-        ``list_tables`` and ``get_schema``. ``None`` (default) exposes all tables
-        in ``schema``. Entries may be schema-qualified (``"SCHEMA.TABLE"``) to span
-        multiple schemas in one database -- common on warehouses such as Snowflake.
-        ``list_tables`` then introspects each referenced schema and returns the
-        matching tables fully qualified, and ``get_schema`` routes to the table's
-        own schema. Unqualified entries use ``schema``. Matching is
-        case-insensitive, since databases reflect identifiers in their own case.
+    :param allowed_tables: Restrict the agent to a fixed set of tables. ``None``
+        (default) exposes every table in ``schema``. Entries may be schema-qualified
+        (``"SCHEMA.TABLE"``) to span multiple schemas in one database -- common on
+        warehouses such as Snowflake. ``list_tables`` introspects each referenced
+        schema and returns the matching tables fully qualified, and ``get_schema``
+        routes to the table's own schema. Unqualified entries use ``schema``.
+        Matching is case-insensitive, since databases reflect identifiers in their
+        own case.
+
+        When set, the list is enforced on the ``query`` and ``check_query`` tools as
+        well as on discovery: every table a query reaches -- through subqueries, CTEs,
+        JOINs, set operations, ``DESCRIBE``, catalog views such as
+        ``information_schema``, or DML -- must be on the list, resolved with its
+        database/catalog, or the query is rejected before it runs. CTE references are
+        excluded by lexical scope (a same-named CTE in another scope never hides a real
+        table). Constructs the list cannot describe are rejected outright while it is
+        active: table-valued functions (``dblink``), ``TABLE('name')`` row sources, the
+        ``TABLE <name>`` shorthand, ``SHOW``, dynamic SQL, and **inline comments**
+        (where parser-vs-engine differences such as MySQL ``/*! ... */`` executable
+        comments hide).
 
         .. note::
-            ``allowed_tables`` controls metadata visibility only. It does **not**
-            parse or validate table references in SQL queries. An LLM can still
-            query tables outside this list if it guesses the name. For query-level
-            restrictions, use database-level permissions (e.g. a read-only role
-            with grants limited to specific tables).
+            This is an application-level guardrail, enforced by parsing the SQL with
+            sqlglot. It is strong defense-in-depth but not a substitute for database
+            permissions: it cannot police data reached through a function whose
+            argument is itself SQL or a path -- ``pg_read_file('...')`` (a file) or
+            ``query_to_xml('SELECT ... FROM other_table', ...)`` and ``dblink`` in
+            scalar position (a table, read through a string the parser cannot inspect)
+            -- and any query the engine parses differently from sqlglot is a residual
+            gap. For a hard guarantee, also point ``db_conn_id`` at a least-privilege
+            role whose ``SELECT`` grants are limited to the same tables.
 
     :param schema: Default schema/namespace for table listing and introspection,
         used for unqualified ``allowed_tables`` entries and unqualified
@@ -131,42 +150,63 @@ class SQLToolset(AbstractToolset[Any]):
     ) -> None:
         self._db_conn_id = db_conn_id
         self._allowed_tables: frozenset[str] | None = frozenset(allowed_tables) if allowed_tables else None
-        # Case-folded view for membership tests: databases reflect identifiers in
-        # their own case (Snowflake stores unquoted names uppercase but reflects
-        # them lowercased), so a byte-exact match against the user's entries would
-        # silently miss. allowed_tables is a visibility hint, not access control,
-        # so case-insensitive matching is safe.
-        self._allowed_tables_ci: frozenset[str] | None = (
-            frozenset(t.casefold() for t in self._allowed_tables)
-            if self._allowed_tables is not None
-            else None
-        )
         self._schema = schema
         self._allow_writes = allow_writes
         self._max_rows = max_rows
         self._hook: DbApiHook | None = None
 
-        # Derive which schemas to introspect from schema-qualified allowed_tables.
+        # Canonical ``(catalog, schema, table)`` view of allowed_tables for membership
+        # tests, plus the schemas to introspect. Built once: every reference -- a
+        # discovery hit, a get_schema arg, or a table parsed out of a query -- is
+        # normalised to the same shape and matched against this set.
+        #
+        # Identifiers are case-folded: databases reflect them in their own case
+        # (Snowflake stores unquoted names uppercase but reflects them lowercased), so
+        # a byte-exact match against the user's entries would silently miss. Unqualified
+        # entries resolve to the default ``schema`` (``None`` when unset) so that
+        # ``"orders"`` and ``"<schema>.orders"`` denote the same table. Allow-list
+        # entries carry no catalog, so any catalog-qualified reference
+        # (``otherdb.public.orders``) has a non-null catalog in its key and cannot match
+        # -- that closes cross-database access the single-connection allow-list can't
+        # describe.
+        self._allowed_canonical: frozenset[tuple[str | None, str | None, str]] | None = None
         # Qualified entries ("SCHEMA.TABLE") are listed under their own schema and
         # returned fully qualified; unqualified entries (and allow-all) use the
         # default ``schema``.
         self._qualified_schemas: frozenset[str] = frozenset()
         self._include_default_schema: bool = True
         if self._allowed_tables is not None:
+            canonical: set[tuple[str | None, str | None, str]] = set()
             qualified_schemas: set[str] = set()
             include_default = False
             for entry in self._allowed_tables:
-                entry_schema, sep, _ = entry.rpartition(".")
+                entry_schema, sep, table = entry.rpartition(".")
                 if sep:
                     qualified_schemas.add(entry_schema)
+                    canonical.add(self._canonical_ref("", entry_schema, table))
                 else:
                     include_default = True
+                    canonical.add(self._canonical_ref("", self._schema, entry))
+            self._allowed_canonical = frozenset(canonical)
             self._qualified_schemas = frozenset(qualified_schemas)
             self._include_default_schema = include_default
 
-    def _is_table_allowed(self, name: str) -> bool:
-        """Case-insensitive membership test against ``allowed_tables`` (allow-all when unset)."""
-        return self._allowed_tables_ci is None or name.casefold() in self._allowed_tables_ci
+    @staticmethod
+    def _canonical_ref(
+        catalog: str | None, schema: str | None, table: str
+    ) -> tuple[str | None, str | None, str]:
+        """Normalise a ``(catalog, schema, table)`` reference to its case-folded comparison key."""
+        return (
+            catalog.casefold() if catalog else None,
+            schema.casefold() if schema else None,
+            table.casefold(),
+        )
+
+    def _is_ref_allowed(self, catalog: str | None, schema: str | None, table: str) -> bool:
+        """Membership test for a resolved ``(catalog, schema, table)`` reference (allow-all when unset)."""
+        if self._allowed_canonical is None:
+            return True
+        return self._canonical_ref(catalog, schema, table) in self._allowed_canonical
 
     @property
     def id(self) -> str:
@@ -263,11 +303,11 @@ class SQLToolset(AbstractToolset[Any]):
         # Dedupe by (schema, table) so a table reachable both qualified and via the
         # default schema (e.g. "public.users" and "users" with schema="public") is
         # listed once. Case-folded because databases reflect identifiers in their case.
-        seen: set[tuple[str | None, str]] = set()
+        seen: set[tuple[str | None, str | None, str]] = set()
 
         def add(schema: str | None, name: str, display: str) -> None:
-            key = (schema.casefold() if schema else None, name.casefold())
-            if self._is_table_allowed(display) and key not in seen:
+            key = self._canonical_ref("", schema, name)
+            if self._is_ref_allowed("", schema, name) and key not in seen:
                 seen.add(key)
                 tables.append(display)
 
@@ -286,10 +326,10 @@ class SQLToolset(AbstractToolset[Any]):
         return json.dumps(tables)
 
     def _get_schema(self, table_name: str) -> str:
-        if not self._is_table_allowed(table_name):
+        schema, table = self._split_table_identifier(table_name)
+        if not self._is_ref_allowed("", schema, table):
             return json.dumps({"error": f"Table {table_name!r} is not in the allowed tables list."})
         hook = self._get_db_hook()
-        schema, table = self._split_table_identifier(table_name)
         columns = hook.get_table_schema(table, schema=schema)
         return json.dumps(columns)
 
@@ -300,15 +340,19 @@ class SQLToolset(AbstractToolset[Any]):
 
     def _query(self, sql: str) -> str:
         hook = self._get_db_hook()
+        dialect = self._dialect_for_validation()
+        statements: list[Any] | None = None
         if not self._allow_writes:
             # allow_read_only_metadata lets agents inspect schemas with DESCRIBE/SHOW
             # (a common first move) instead of hard-failing; the deep scan still
             # rejects any data-modifying statement, including EXPLAIN <write>.
-            _validate_sql(
-                sql,
-                dialect=self._dialect_for_validation(),
-                allow_read_only_metadata=True,
-            )
+            statements = _validate_sql(sql, dialect=dialect, allow_read_only_metadata=True)
+        elif self._allowed_canonical is not None:
+            # Writes are allowed but tables are restricted: parse anyway so the
+            # allow-list still governs which tables a write may touch.
+            statements = _parse_sql(sql, dialect=dialect)
+        if statements is not None:
+            self._enforce_allowed_tables(statements)
 
         rows = hook.get_records(sql)
         # Fetch column names from cursor description.
@@ -336,7 +380,40 @@ class SQLToolset(AbstractToolset[Any]):
         with suppress(Exception):
             dialect = self._dialect_for_validation()
         try:
-            _validate_sql(sql, dialect=dialect, allow_read_only_metadata=True)
+            statements = _validate_sql(sql, dialect=dialect, allow_read_only_metadata=True)
+            self._enforce_allowed_tables(statements)
             return json.dumps({"valid": True})
         except Exception as e:
             return json.dumps({"valid": False, "error": str(e)})
+
+    def _enforce_allowed_tables(self, statements: list[Any]) -> None:
+        """
+        Reject a parsed query that reaches any table outside ``allowed_tables``.
+
+        No-op when ``allowed_tables`` is unset (allow-all). Otherwise every table the
+        query references (resolved scope-correctly, including catalog) must be on the
+        list, and any construct the list cannot describe -- a table-valued function,
+        ``SHOW``, dynamic SQL, an inline comment, or the ``TABLE <name>`` shorthand --
+        is refused. Raises :class:`SQLSafetyError` -- ``call_tool`` turns it into a
+        ``ModelRetry`` so the agent can re-target an allowed table, while
+        ``check_query`` reports it invalid.
+        """
+        if self._allowed_canonical is None:
+            return
+        scan = collect_table_references(statements)
+        if scan.unverifiable_sources:
+            raise SQLSafetyError(
+                f"Query uses a data source that cannot be checked against allowed_tables: "
+                f"{'; '.join(scan.unverifiable_sources)}. Query the allowed tables directly: "
+                f"use list_tables to see them."
+            )
+        disallowed = [
+            ".".join(part for part in (catalog, schema, table) if part)
+            for catalog, schema, table in scan.tables
+            if not self._is_ref_allowed(catalog, schema or self._schema, table)
+        ]
+        if disallowed:
+            raise SQLSafetyError(
+                f"Query references tables that are not in the allowed tables list: "
+                f"{', '.join(sorted(set(disallowed)))}. Use list_tables to see the allowed tables."
+            )
diff --git a/providers/common/ai/src/airflow/providers/common/ai/utils/sql_validation.py b/providers/common/ai/src/airflow/providers/common/ai/utils/sql_validation.py
index a00b4dc11e6..3b39210710e 100644
--- a/providers/common/ai/src/airflow/providers/common/ai/utils/sql_validation.py
+++ b/providers/common/ai/src/airflow/providers/common/ai/utils/sql_validation.py
@@ -24,6 +24,8 @@ This is safer than a denylist because new/unexpected statement types
 
 from __future__ import annotations
 
+from typing import NamedTuple
+
 import sqlglot
 from sqlglot import exp
 from sqlglot.dialects import Dialects
@@ -105,6 +107,245 @@ class SQLSafetyError(Exception):
     """Generated SQL failed safety validation."""
 
 
+def parse_sql(
+    sql: str,
+    *,
+    dialect: str | None = None,
+    allow_multiple_statements: bool = False,
+) -> list[exp.Expr]:
+    """
+    Parse SQL into statements, enforcing the empty- and multi-statement guards only.
+
+    Shared by :func:`validate_sql` (which then applies statement-type checks) and by
+    callers that need the parsed AST for their own analysis -- e.g. table-reference
+    extraction for ``allowed_tables`` enforcement -- without the read-only allow-list.
+
+    :param sql: SQL string to parse.
+    :param dialect: SQL dialect for parsing (``postgres``, ``mysql``, etc.).
+    :param allow_multiple_statements: Whether to allow multiple semicolon-separated
+        statements. Default ``False`` -- multi-statement input can hide a dangerous
+        operation after a benign one.
+    :return: List of parsed sqlglot Expression objects (never empty).
+    :raises SQLSafetyError: If the SQL is empty, cannot be parsed, or contains multiple
+        statements when not permitted.
+    """
+    if not sql or not sql.strip():
+        raise SQLSafetyError("Empty SQL input.")
+
+    try:
+        statements = sqlglot.parse(sql, dialect=dialect, error_level=ErrorLevel.RAISE)
+    except sqlglot.errors.ParseError as e:
+        raise SQLSafetyError(f"SQL parse error: {e}") from e
+
+    # sqlglot.parse can return [None] for empty input
+    parsed = [s for s in statements if s is not None]
+    if not parsed:
+        raise SQLSafetyError("Empty SQL input.")
+
+    if not allow_multiple_statements and len(parsed) > 1:
+        raise SQLSafetyError(
+            f"Multiple statements detected ({len(parsed)}). Only single statements are allowed by default."
+        )
+    return parsed
+
+
+class TableScan(NamedTuple):
+    """Result of :func:`collect_table_references`."""
+
+    #: ``(catalog, schema, table)`` for every real base table referenced anywhere in
+    #: the AST. ``catalog`` and ``schema`` are ``""`` when the reference omits them.
+    #: In-scope CTE references are excluded. Catalog is reported so the caller can
+    #: reject cross-database references (``otherdb.public.orders``) that a
+    #: ``schema.table`` allow-list cannot describe.
+    tables: list[tuple[str, str, str]]
+    #: Human-readable descriptions of constructs that cannot be checked against an
+    #: allow-list and so must be rejected while one is active: table-valued functions
+    #: (``dblink``), ``TABLE('name')`` row sources, ``SHOW``, dynamic SQL
+    #: (``EXEC``/``Command``), inline comments (a parser-vs-engine differential), and
+    #: the ``TABLE <name>`` shorthand. Empty when every construct is verifiable.
+    unverifiable_sources: list[str]
+
+
+_DML_TYPES: tuple[type[exp.Expr], ...] = (exp.Insert, exp.Update, exp.Delete, exp.Merge)
+
+
+def _same_identifier(a: exp.Identifier, b: exp.Identifier) -> bool:
+    """
+    Compare two identifiers under standard identifier-folding rules.
+
+    Unquoted names fold (case-insensitive); quoted names are case-preserving and
+    distinct from unquoted ones. Used to decide whether a table reference names a CTE:
+    being *stricter* here is safe -- a near-miss falls through to the allow-list check.
+    """
+    aq, bq = bool(a.args.get("quoted")), bool(b.args.get("quoted"))
+    if not aq and not bq:
+        return str(a.this).casefold() == str(b.this).casefold()
+    if aq and bq:
+        return str(a.this) == str(b.this)
+    return False
+
+
+def _enclosing_cte(table: exp.Expr, with_: exp.With) -> exp.CTE | None:
+    """Return the CTE of ``with_`` whose *definition* contains ``table`` (else ``None``)."""
+    node = table.parent
+    while node is not None and node is not with_.parent:
+        if isinstance(node, exp.CTE) and node.parent is with_:
+            return node
+        node = node.parent
+    return None
+
+
+def _is_in_scope_cte(table: exp.Table) -> bool:
+    """
+    Report whether ``table`` is a bare reference resolved by a CTE visible at its scope.
+
+    Walks the ancestor chain (lexical scope) collecting CTE names from each enclosing
+    ``WITH``. A CTE defined in a *sibling* or *inner* subquery is not an ancestor, so a
+    real top-level table is never excluded by an unrelated same-named CTE
+    (``SELECT * FROM secret WHERE id IN (WITH secret AS (...) SELECT ...)``). A
+    non-recursive CTE is not visible inside its own definition, so
+    ``WITH secret AS (SELECT * FROM secret) ...`` still reports the real ``secret``.
+    CTE order matters too: inside one CTE's body only *earlier* siblings are in scope
+    (forward references need ``RECURSIVE``), so ``WITH a AS (SELECT * FROM secret),
+    secret AS (...) SELECT * FROM a`` still reports the real ``secret`` read by ``a``.
+    """
+    ref = table.this
+    if not isinstance(ref, exp.Identifier):
+        return False
+    node: exp.Expr | None = table.parent
+    while node is not None:
+        # A WITH attaches to its owning query (Select/Union/DML) as a sibling of the
+        # body, so the query -- an ancestor of the table -- holds it. Find it by type
+        # rather than a fixed arg key (sqlglot has used both ``with`` and ``with_``).
+        with_ = (
+            next((v for v in node.args.values() if isinstance(v, exp.With)), None)
+            if isinstance(node, exp.Expression)
+            else None
+        )
+        if isinstance(with_, exp.With):
+            recursive = bool(with_.args.get("recursive"))
+            ctes = list(with_.expressions)
+            enclosing = _enclosing_cte(table, with_)
+            # If the reference sits inside CTE E's own body, only CTEs defined *before*
+            # E are visible there (plus E itself when RECURSIVE); a CTE defined after E
+            # is not yet in scope. In the main query body every CTE is visible.
+            enclosing_idx = next((i for i, c in enumerate(ctes) if c is enclosing), None)
+            for idx, cte in enumerate(ctes):
+                if enclosing_idx is not None:
+                    if idx > enclosing_idx:
+                        continue
+                    if idx == enclosing_idx and not recursive:
+                        continue
+                alias = cte.args.get("alias")
+                cte_ident = alias.this if isinstance(alias, exp.TableAlias) else None
+                if isinstance(cte_ident, exp.Identifier) and _same_identifier(cte_ident, ref):
+                    return True
+        node = node.parent
+    return False
+
+
+def collect_table_references(statements: list[exp.Expr]) -> TableScan:
+    """
+    Walk parsed statements and report every real table they reach, 
scope-correctly.
+
+    This is the AST half of ``allowed_tables`` enforcement: it returns the 
concrete
+    base tables a query reaches (including those nested in subqueries, CTEs, 
JOINs, set
+    operations, ``DESCRIBE``, and DML) as ``(catalog, schema, table)`` so the 
caller can
+    check each against its allow-list, plus a list of constructs that cannot 
be checked
+    and must therefore be rejected while an allow-list is active.
+
+    Handled carefully (each was a confirmed bypass before it was closed):
+
+    - **CTE references are excluded by lexical scope, not by name.** A table is treated
+      as a CTE only when a ``WITH`` *enclosing that reference* defines the name (see
+      :func:`_is_in_scope_cte`); a same-named CTE in a sibling/inner query no longer
+      hides a real top-level table. A DML *target* is always a real table (you cannot
+      write to a CTE, so a same-named CTE does not shadow it), but DML *sources* follow
+      normal CTE scoping -- a CTE used as an INSERT/UPDATE source is not flagged.
+    - **Catalog-qualified references are reported with their catalog**, so the caller
+      rejects ``otherdb.public.orders`` instead of matching it to ``public.orders``.
+    - **Unverifiable constructs are listed, not silently dropped:** nameless
+      table-valued functions (``dblink``), ``TABLE('name')`` row sources
+      (``exp.TableFromRows``), ``SHOW``, dynamic SQL (``EXEC``/``Command``), the
+      ``TABLE <name>`` shorthand (which sqlglot parses incorrectly, leaking the
+      ``TABLE`` keyword as a column), a **quoted identifier** (case-sensitive on the engine but
+      matched case-insensitively here, so ``"Orders"`` could otherwise reach a table
+      distinct from the allow-listed ``orders``), and **any inline comment** --
+      comments are where parser-vs-engine differentials hide (MySQL executable
+      ``/*! ... */``, ``--`` not followed by whitespace, ``#``).
+
+    :param statements: Parsed sqlglot statements (from :func:`parse_sql`).
+    :return: A :class:`TableScan` of real table references and unverifiable constructs.
+    """
+    tables: list[tuple[str, str, str]] = []
+    unverifiable: list[str] = []
+    for stmt in statements:
+        # SHOW enumerates objects / leaks a table's columns outside any single table.
+        if isinstance(stmt, exp.Show):
+            unverifiable.append("a SHOW statement")
+            continue
+        # Dynamic SQL and anything sqlglot can only represent as a raw Command reach
+        # data through text the parser cannot inspect.
+        if isinstance(stmt, (exp.Command, exp.Execute)):
+            unverifiable.append(f"a {type(stmt).__name__.lower()} statement")
+            continue
+
+        # A comment is a parser-vs-engine differential vector: sqlglot drops it, but the
+        # engine may execute it (MySQL `/*! ... */`) or tokenize it differently (`--`
+        # without a trailing space, `#`). sqlglot tokenizes string literals correctly,
+        # so a `--` inside a quoted string is not flagged here.
+        if any(node.comments for node in stmt.walk()):
+            unverifiable.append("an inline comment")
+            continue
+
+        # `TABLE('name')` / `TABLE($$name$$)` name a table through a string the parser
+        # cannot resolve; sqlglot models them as TableFromRows, not exp.Table.
+        if any(True for _ in stmt.find_all(exp.TableFromRows)):
+            unverifiable.append("a TABLE(...) row source")
+            continue
+
+        # `TABLE <name>` (Postgres/MySQL shorthand for SELECT * FROM <name>) is not
+        # modelled by sqlglot; it parses incorrectly, leaking the reserved word TABLE as an
+        # unquoted column identifier. No real query has an unquoted column named TABLE.
+        if any(
+            isinstance(col.this, exp.Identifier)
+            and not col.this.args.get("quoted")
+            and str(col.this.this).upper() == "TABLE"
+            for col in stmt.find_all(exp.Column)
+        ):
+            unverifiable.append("a TABLE <name> shorthand")
+            continue
+
+        # A DML statement's *target* (the table written to) is always a real table --
+        # you cannot INSERT/UPDATE/DELETE/MERGE into a CTE, so even a same-named CTE does
+        # not shadow it. Its *sources* (the SELECT/USING/subqueries) follow normal CTE
+        # scoping, so a CTE used as a source is not mistaken for a base table.
+        target = stmt.args.get("this") if isinstance(stmt, _DML_TYPES) else None
+        target_ids = {id(t) for t in target.find_all(exp.Table)} if target is not None else set()
+        for table in stmt.find_all(exp.Table):
+            name = table.name
+            if not name:
+                unverifiable.append(f"table-valued function ({table.sql()})")
+                continue
+            # A bare, non-target reference may be a CTE; a qualified one or a DML target
+            # never is.
+            if id(table) not in target_ids and not table.db and not table.catalog and _is_in_scope_cte(table):
+                continue
+            # A quoted identifier is case-sensitive on the engine, but the allow-list is
+            # matched case-insensitively (and a plain ``schema.table`` string cannot
+            # carry quoting), so a quoted reference cannot be matched soundly: on
+            # Postgres/Snowflake ``"Orders"`` is a *different* table from the allow-listed
+            # ``orders``. Reject rather than risk reaching a case-distinct table.
+            if any(
+                isinstance(part, exp.Identifier) and part.args.get("quoted")
+                for part in (table.this, table.args.get("db"), table.args.get("catalog"))
+            ):
+                unverifiable.append("a quoted identifier")
+                continue
+            tables.append((table.catalog, table.db, name))
+    return TableScan(tables=tables, unverifiable_sources=unverifiable)
+
+
 def validate_sql(
     sql: str,
     *,
@@ -138,9 +379,6 @@ def validate_sql(
     :raises SQLSafetyError: If the SQL is empty, contains disallowed statement types,
         or has multiple statements when not permitted.
     """
-    if not sql or not sql.strip():
-        raise SQLSafetyError("Empty SQL input.")
-
     # A caller-supplied ``allowed_types`` is an explicit opt-out of the curated
     # read-only defaults (and the data-modifying deep scan). Otherwise we use the
     # read-only defaults, optionally widened with metadata statements, and keep
@@ -154,20 +392,7 @@ def validate_sql(
         types = allowed_types
         run_data_modifying_scan = types == DEFAULT_ALLOWED_TYPES
 
-    try:
-        statements = sqlglot.parse(sql, dialect=dialect, error_level=ErrorLevel.RAISE)
-    except sqlglot.errors.ParseError as e:
-        raise SQLSafetyError(f"SQL parse error: {e}") from e
-
-    # sqlglot.parse can return [None] for empty input
-    parsed = [s for s in statements if s is not None]
-    if not parsed:
-        raise SQLSafetyError("Empty SQL input.")
-
-    if not allow_multiple_statements and len(parsed) > 1:
-        raise SQLSafetyError(
-            f"Multiple statements detected ({len(parsed)}). Only single statements are allowed by default."
-        )
+    parsed = parse_sql(sql, dialect=dialect, allow_multiple_statements=allow_multiple_statements)
 
     for stmt in parsed:
         if not isinstance(stmt, types):
diff --git a/providers/common/ai/tests/unit/common/ai/toolsets/test_sql.py b/providers/common/ai/tests/unit/common/ai/toolsets/test_sql.py
index 92e033a6b60..ed0619a1db0 100644
--- a/providers/common/ai/tests/unit/common/ai/toolsets/test_sql.py
+++ b/providers/common/ai/tests/unit/common/ai/toolsets/test_sql.py
@@ -467,3 +467,285 @@ class TestSQLToolsetMetadataStatements:
                 ts.call_tool("check_query", {"sql": "SELECT 1"}, ctx=MagicMock(), tool=MagicMock())
             )
         assert json.loads(result)["valid"] is True
+
+
+def _run_query(ts: SQLToolset, sql: str):
+    return asyncio.run(ts.call_tool("query", {"sql": sql}, ctx=MagicMock(), tool=MagicMock()))
+
+
+def _run_check(ts: SQLToolset, sql: str):
+    return json.loads(
+        asyncio.run(ts.call_tool("check_query", {"sql": sql}, ctx=MagicMock(), tool=MagicMock()))
+    )
+
+
+class TestSQLToolsetAllowedTablesQueryEnforcement:
+    """``allowed_tables`` is enforced on the query/check_query tools, not just on discovery."""
+
+    def test_query_allows_table_on_the_list(self):
+        ts = SQLToolset("pg_default", allowed_tables=["orders"])
+        ts._hook = _make_mock_db_hook(records=[(1,)], last_description=[("id",)])
+
+        result = _run_query(ts, "SELECT id FROM orders")
+
+        assert "rows" in json.loads(result)
+        ts._hook.get_records.assert_called_once_with("SELECT id FROM orders")
+
+    def test_query_blocks_table_off_the_list(self):
+        """The headline escape: querying a table that is not on the allow-list is refused."""
+        ts = SQLToolset("pg_default", allowed_tables=["orders"])
+        ts._hook = _make_mock_db_hook()
+
+        with pytest.raises(ModelRetry) as exc_info:
+            _run_query(ts, "SELECT * FROM secret_salaries")
+
+        assert "not in the allowed tables list" in exc_info.value.message
+        assert "secret_salaries" in exc_info.value.message
+        ts._hook.get_records.assert_not_called()
+
+    @pytest.mark.parametrize(
+        "sql",
+        [
+            "SELECT * FROM (SELECT * FROM secret_salaries) x",
+            "WITH s AS (SELECT * FROM secret_salaries) SELECT * FROM s",
+            "SELECT * FROM orders JOIN secret_salaries ON orders.id = secret_salaries.id",
+            "SELECT * FROM orders UNION SELECT * FROM secret_salaries",
+            "SELECT * FROM secret_salaries WHERE id IN (SELECT id FROM orders)",
+        ],
+        ids=["subquery", "cte_body", "join", "union", "where_subquery"],
+    )
+    def test_query_blocks_disallowed_table_reached_indirectly(self, sql):
+        ts = SQLToolset("pg_default", allowed_tables=["orders"])
+        ts._hook = _make_mock_db_hook()
+
+        with pytest.raises(ModelRetry) as exc_info:
+            _run_query(ts, sql)
+
+        assert "secret_salaries" in exc_info.value.message
+        ts._hook.get_records.assert_not_called()
+
+    def test_query_blocks_catalog_enumeration(self):
+        """information_schema/pg_catalog are ordinary tables, so the allow-list blocks them too."""
+        ts = SQLToolset("pg_default", allowed_tables=["orders"])
+        ts._hook = _make_mock_db_hook()
+
+        with pytest.raises(ModelRetry) as exc_info:
+            _run_query(ts, "SELECT table_name FROM information_schema.tables")
+
+        assert "information_schema.tables" in exc_info.value.message
+        ts._hook.get_records.assert_not_called()
+
+    def test_query_allows_cte_reference_not_mistaken_for_table(self):
+        """A CTE whose name is not on the list is fine as long as its body stays allowed."""
+        ts = SQLToolset("pg_default", allowed_tables=["orders"])
+        ts._hook = _make_mock_db_hook(records=[(1,)], last_description=[("id",)])
+
+        result = _run_query(ts, "WITH ranked AS (SELECT * FROM orders) SELECT * FROM ranked")
+
+        assert "rows" in json.loads(result)
+
+    def test_query_blocks_table_valued_function(self):
+        """dblink reaches data through a path the list can't describe, so it is refused."""
+        ts = SQLToolset("pg_default", allowed_tables=["orders"])
+        ts._hook = _make_mock_db_hook()
+        ts._hook.dialect_name = "postgresql"
+
+        with pytest.raises(ModelRetry) as exc_info:
+            _run_query(ts, "SELECT * FROM dblink('host=evil', 'SELECT 1') AS t(x int)")
+
+        assert "cannot be checked against allowed_tables" in exc_info.value.message
+        ts._hook.get_records.assert_not_called()
+
+    def test_query_blocks_show_when_allowlist_active(self):
+        ts = SQLToolset("sf_default", allowed_tables=["orders"])
+        ts._hook = _make_mock_db_hook()
+        ts._hook.dialect_name = "snowflake"
+
+        with pytest.raises(ModelRetry) as exc_info:
+            _run_query(ts, "SHOW TABLES")
+
+        assert "cannot be checked against allowed_tables" in exc_info.value.message
+        ts._hook.get_records.assert_not_called()
+
+    def test_query_blocks_describe_of_disallowed_table(self):
+        ts = SQLToolset("sf_default", allowed_tables=["orders"])
+        ts._hook = _make_mock_db_hook()
+        ts._hook.dialect_name = "snowflake"
+
+        with pytest.raises(ModelRetry) as exc_info:
+            _run_query(ts, "DESCRIBE TABLE secret_salaries")
+
+        assert "secret_salaries" in exc_info.value.message
+        ts._hook.get_records.assert_not_called()
+
+    def test_query_allows_describe_of_allowed_table(self):
+        ts = SQLToolset("sf_default", allowed_tables=["orders"])
+        ts._hook = _make_mock_db_hook(records=[("id", "INT")], last_description=[("name",), ("type",)])
+        ts._hook.dialect_name = "snowflake"
+
+        result = _run_query(ts, "DESCRIBE TABLE orders")
+
+        assert "rows" in json.loads(result)
+        ts._hook.get_records.assert_called_once_with("DESCRIBE TABLE orders")
+
+    def test_query_allows_schema_qualified_table_on_list(self):
+        ts = SQLToolset("sf", allowed_tables=["MODEL_CRM.SF_ASTRO_ORGS"])
+        ts._hook = _make_mock_db_hook(records=[(1,)], last_description=[("id",)])
+        ts._hook.dialect_name = "snowflake"
+
+        result = _run_query(ts, "SELECT * FROM MODEL_CRM.SF_ASTRO_ORGS")
+
+        assert "rows" in json.loads(result)
+
+    def test_query_unqualified_resolves_to_default_schema(self):
+        """``public.orders`` and ``orders`` denote the same table when schema='public'."""
+        ts = SQLToolset("pg", allowed_tables=["orders"], schema="public")
+        ts._hook = _make_mock_db_hook(records=[(1,)], last_description=[("id",)])
+
+        # Qualifying with the default schema must still match the bare allow-list entry.
+        result = _run_query(ts, "SELECT * FROM public.orders")
+        assert "rows" in json.loads(result)
+
+    def test_no_allowlist_leaves_queries_unrestricted(self):
+        """Without allowed_tables the query tool behaves exactly as before (allow-all)."""
+        ts = SQLToolset("pg_default")
+        ts._hook = _make_mock_db_hook(records=[(1,)], last_description=[("id",)])
+
+        result = _run_query(ts, "SELECT * FROM anything_at_all")
+
+        assert "rows" in json.loads(result)
+        ts._hook.get_records.assert_called_once_with("SELECT * FROM anything_at_all")
+
+    def test_check_query_reports_disallowed_table_as_invalid(self):
+        ts = SQLToolset("pg_default", allowed_tables=["orders"])
+        ts._hook = _make_mock_db_hook()
+
+        data = _run_check(ts, "SELECT * FROM secret_salaries")
+
+        assert data["valid"] is False
+        assert "secret_salaries" in data["error"]
+
+    def test_check_query_valid_for_allowed_table(self):
+        ts = SQLToolset("pg_default", allowed_tables=["orders"])
+        ts._hook = _make_mock_db_hook()
+
+        assert _run_check(ts, "SELECT * FROM orders")["valid"] is True
+
+    def test_writes_still_bounded_by_allowed_tables(self):
+        """allow_writes widens the statement types, but the allow-list still scopes the target."""
+        ts = SQLToolset("pg_default", allowed_tables=["orders"], allow_writes=True)
+        ts._hook = _make_mock_db_hook(records=[], last_description=None)
+
+        # An allowed target is written.
+        _run_query(ts, "INSERT INTO orders (id) VALUES (1)")
+        ts._hook.get_records.assert_called_once_with("INSERT INTO orders (id) VALUES (1)")
+
+        # A disallowed target is refused before execution.
+        ts._hook.get_records.reset_mock()
+        with pytest.raises(ModelRetry) as exc_info:
+            _run_query(ts, "INSERT INTO secret_salaries (id) VALUES (1)")
+        assert "secret_salaries" in exc_info.value.message
+        ts._hook.get_records.assert_not_called()
+
+    def test_writes_reject_dynamic_sql_the_parser_cannot_inspect(self):
+        """allow_writes skips the read-only validator, so the allow-list must still
+        refuse dynamic SQL (EXEC/EXECUTE) whose table access is opaque."""
+        ts = SQLToolset("mssql_default", allowed_tables=["orders"], allow_writes=True)
+        ts._hook = _make_mock_db_hook()
+        ts._hook.dialect_name = "mssql"
+
+        with pytest.raises(ModelRetry) as exc_info:
+            _run_query(ts, "EXEC sp_who")
+
+        assert "cannot be checked against allowed_tables" in exc_info.value.message
+        ts._hook.get_records.assert_not_called()
+
+
+class TestSQLToolsetAllowedTablesBypassRegressions:
+    """Regression tests for bypasses found by adversarial red-teaming of the allow-list."""
+
+    @pytest.mark.parametrize(
+        ("sql", "dialect", "allow_writes"),
+        [
+            # CTE scope: a same-named CTE in an inner/sibling scope must not hide the real table.
+            (
+                "SELECT * FROM secret_salaries WHERE id IN "
+                "(WITH secret_salaries AS (SELECT 1 id) SELECT id FROM secret_salaries)",
+                "postgresql",
+                False,
+            ),
+            # Non-recursive CTE is not in scope within its own body.
+            (
+                "WITH secret_salaries AS (SELECT * FROM secret_salaries) SELECT * FROM secret_salaries",
+                "postgresql",
+                False,
+            ),
+            # A CTE may only reference earlier siblings; a later-defined name is the real table.
+            (
+                "WITH a AS (SELECT * FROM secret_salaries), secret_salaries AS (SELECT 1 id) SELECT * FROM a",
+                "postgresql",
+                False,
+            ),
+            # Cross-database / catalog qualifier the schema.table allow-list cannot describe.
+            ("SELECT * FROM secretdb.public.orders", "snowflake", False),
+            ("SELECT * FROM secret_salaries..orders", "mssql", False),
+            # MySQL executable comments execute on the engine but sqlglot treats them as inert.
+            ("SELECT * FROM orders/*!UNION SELECT * FROM secret_salaries*/", "mysql", False),
+            ("SELECT id FROM orders /*!50000 UNION SELECT id FROM secret_salaries */", "mysql", False),
+            # TABLE <name> shorthand (mis-parsed) and TABLE('name') row source (string-named).
+            ("TABLE secret_salaries UNION SELECT * FROM orders", "postgresql", False),
+            ("SELECT * FROM TABLE('secret_salaries')", "snowflake", False),
+            # Write-mode CTE shadowing the DML target.
+            ("WITH secret_salaries AS (SELECT 1) DELETE FROM secret_salaries", "postgresql", True),
+            # Quoted identifier is case-distinct on the engine but case-folds into the list.
+            ('SELECT * FROM "Orders"', "postgresql", False),
+            # A DML source CTE whose body reads an off-list table is still caught.
+            (
+                "WITH src AS (SELECT * FROM secret_salaries) INSERT INTO orders SELECT * FROM src",
+                "postgresql",
+                True,
+            ),
+        ],
+        ids=[
+            "cte_inner_shadow",
+            "cte_self_body",
+            "cte_forward_ref",
+            "catalog_cross_db",
+            "mssql_empty_middle",
+            "mysql_exec_comment",
+            "mysql_versioned_comment",
+            "table_shorthand",
+            "table_row_source",
+            "write_cte_target",
+            "quoted_case_distinct",
+            "dml_cte_body_reads_offlist",
+        ],
+    )
+    def test_known_bypasses_are_rejected(self, sql, dialect, allow_writes):
+        ts = SQLToolset("c", allowed_tables=["orders"], allow_writes=allow_writes)
+        ts._hook = _make_mock_db_hook()
+        ts._hook.dialect_name = dialect
+
+        with pytest.raises(ModelRetry):
+            _run_query(ts, sql)
+        ts._hook.get_records.assert_not_called()
+
+    def test_legit_cte_over_allowed_table_still_runs(self):
+        """The scope-aware fix must not false-reject a genuine CTE over an allowed table."""
+        ts = SQLToolset("c", allowed_tables=["orders"])
+        ts._hook = _make_mock_db_hook(records=[(1,)], last_description=[("id",)])
+
+        result = _run_query(ts, "WITH ranked AS (SELECT * FROM orders) SELECT * FROM ranked")
+
+        assert "rows" in json.loads(result)
+        ts._hook.get_records.assert_called_once()
+
+    def test_dml_with_cte_source_over_allowed_table_runs(self):
+        """A CTE used as a DML source must not be mistaken for a disallowed base table."""
+        ts = SQLToolset("c", allowed_tables=["orders"], allow_writes=True)
+        ts._hook = _make_mock_db_hook(records=[], last_description=None)
+
+        sql = "WITH src AS (SELECT * FROM orders) INSERT INTO orders SELECT * FROM src"
+        _run_query(ts, sql)
+
+        ts._hook.get_records.assert_called_once_with(sql)
diff --git a/providers/common/ai/tests/unit/common/ai/utils/test_sql_validation.py b/providers/common/ai/tests/unit/common/ai/utils/test_sql_validation.py
index 9ca6604ba5e..257761d78e0 100644
--- a/providers/common/ai/tests/unit/common/ai/utils/test_sql_validation.py
+++ b/providers/common/ai/tests/unit/common/ai/utils/test_sql_validation.py
@@ -21,6 +21,8 @@ from sqlglot import exp
 
 from airflow.providers.common.ai.utils.sql_validation import (
     SQLSafetyError,
+    collect_table_references,
+    parse_sql,
     resolve_sqlglot_dialect,
     validate_sql,
 )
@@ -312,3 +314,193 @@ class TestResolveSqlglotDialect:
     )
     def test_resolution(self, dialect_name, expected):
         assert resolve_sqlglot_dialect(dialect_name) == expected
+
+
+class TestParseSQL:
+    """``parse_sql`` enforces only the empty- and multi-statement guards."""
+
+    def test_returns_statements(self):
+        parsed = parse_sql("SELECT 1")
+        assert len(parsed) == 1
+        assert isinstance(parsed[0], exp.Select)
+
+    def test_does_not_apply_type_checks(self):
+        """Unlike validate_sql, parse_sql accepts writes -- callers add their own policy."""
+        parsed = parse_sql("DELETE FROM users WHERE id = 1")
+        assert isinstance(parsed[0], exp.Delete)
+
+    @pytest.mark.parametrize("sql", ["", "   ", "\n\t"])
+    def test_rejects_empty(self, sql):
+        with pytest.raises(SQLSafetyError, match="Empty SQL"):
+            parse_sql(sql)
+
+    def test_rejects_multiple_statements_by_default(self):
+        with pytest.raises(SQLSafetyError, match="Multiple statements"):
+            parse_sql("SELECT 1; SELECT 2")
+
+    def test_allows_multiple_statements_when_opted_in(self):
+        assert len(parse_sql("SELECT 1; SELECT 2", allow_multiple_statements=True)) == 2
+
+    def test_rejects_unparsable(self):
+        with pytest.raises(SQLSafetyError, match="parse error"):
+            parse_sql("SELECT FROM WHERE )(")
+
+
+class TestCollectTableReferences:
+    """``collect_table_references`` reports the real tables a query reaches."""
+
+    @pytest.mark.parametrize(
+        ("sql", "dialect", "expected"),
+        [
+            ("SELECT * FROM secret", None, [("", "", "secret")]),
+            ("SELECT * FROM model_crm.orders", None, [("", "model_crm", "orders")]),
+            ("SELECT * FROM a JOIN b ON a.id = b.id", None, [("", "", "a"), ("", "", "b")]),
+            ("SELECT * FROM (SELECT * FROM inner_t) x", None, [("", "", "inner_t")]),
+            ("SELECT * FROM a UNION SELECT * FROM b", None, [("", "", "a"), ("", "", "b")]),
+            (
+                "SELECT table_name FROM information_schema.tables",
+                "postgres",
+                [("", "information_schema", "tables")],
+            ),
+            ("DESCRIBE secret", "mysql", [("", "", "secret")]),
+            # Cross-database reference carries its catalog so the caller can reject it.
+            ("SELECT * FROM otherdb.public.orders", "snowflake", [("otherdb", "public", "orders")]),
+        ],
+        ids=["table", "qualified", "join", "subquery", "union", "catalog", "describe", "cross_db"],
+    )
+    def test_collects_real_tables(self, sql, dialect, expected):
+        scan = collect_table_references(parse_sql(sql, dialect=dialect))
+        assert sorted(scan.tables) == sorted(expected)
+        assert scan.unverifiable_sources == []
+
+    def test_excludes_cte_reference_but_keeps_its_body(self):
+        scan = collect_table_references(parse_sql("WITH s AS (SELECT * FROM base) SELECT * FROM s"))
+        assert scan.tables == [("", "", "base")]  # 's' is the CTE, not a table
+
+    def test_cte_that_shadows_a_table_name_yields_no_table(self):
+        scan = collect_table_references(parse_sql("WITH secret AS (SELECT 1 AS x) SELECT * FROM secret"))
+        assert scan.tables == []
+
+    def test_schema_qualified_name_is_never_treated_as_a_cte(self):
+        scan = collect_table_references(parse_sql("WITH s AS (SELECT 1 AS x) SELECT * FROM myschema.s"))
+        assert scan.tables == [("", "myschema", "s")]
+
+    def test_inner_cte_does_not_shadow_outer_real_table(self):
+        """A same-named CTE in an inner subquery must not hide the real top-level table."""
+        sql = "SELECT * FROM secret WHERE id IN (WITH secret AS (SELECT 1 id) SELECT id FROM secret)"
+        scan = collect_table_references(parse_sql(sql))
+        assert ("", "", "secret") in scan.tables  # the top-level real table is reported
+
+    def test_cte_self_body_references_real_table(self):
+        """A non-recursive CTE is not in scope within its own body, so the real table shows."""
+        scan = collect_table_references(
+            parse_sql("WITH secret AS (SELECT * FROM secret) SELECT * FROM secret")
+        )
+        assert ("", "", "secret") in scan.tables
+
+    def test_cte_forward_reference_is_real_table(self):
+        """A CTE may only reference earlier siblings; a later-defined name is the real table."""
+        sql = "WITH a AS (SELECT * FROM secret), secret AS (SELECT 1 id) SELECT * FROM a"
+        scan = collect_table_references(parse_sql(sql))
+        assert ("", "", "secret") in scan.tables
+
+    def test_legit_cte_reference_is_excluded(self):
+        """A genuine CTE reference is not reported as a base table (no false reject)."""
+        scan = collect_table_references(
+            parse_sql("WITH ranked AS (SELECT * FROM orders) SELECT * FROM ranked")
+        )
+        assert scan.tables == [("", "", "orders")]
+
+    @pytest.mark.parametrize(
+        ("sql", "dialect"),
+        [
+            ("SELECT * FROM orders/*!UNION SELECT * FROM secret*/", "mysql"),
+            ("SELECT * FROM orders /*!50000 UNION SELECT * FROM secret */", "mysql"),
+            ("SELECT * FROM orders WHERE 0--+1 OR id IN (SELECT id FROM secret)", "mysql"),
+        ],
+        ids=["exec_comment", "versioned_comment", "dashdash"],
+    )
+    def test_flags_inline_comment_as_unverifiable(self, sql, dialect):
+        """Comments hide parser-vs-engine differentials (MySQL executable comments), so reject them."""
+        scan = collect_table_references(parse_sql(sql, dialect=dialect))
+        assert scan.unverifiable_sources
+
+    @pytest.mark.parametrize(
+        "sql",
+        ["SELECT * FROM TABLE('secret')", "SELECT * FROM TABLE($$secret$$)"],
+        ids=["string", "dollar"],
+    )
+    def test_flags_table_row_source_as_unverifiable(self, sql):
+        """Snowflake TABLE('name') names a table through a string the parser can't resolve."""
+        scan = collect_table_references(parse_sql(sql, dialect="snowflake"))
+        assert scan.unverifiable_sources
+
+    def test_flags_table_shorthand_as_unverifiable(self):
+        """The TABLE <name> shorthand is mis-parsed by sqlglot, so reject it."""
+        scan = collect_table_references(
+            parse_sql("TABLE secret UNION SELECT * FROM orders", dialect="postgres")
+        )
+        assert scan.unverifiable_sources
+
+    @pytest.mark.parametrize(
+        "sql",
+        ['SELECT * FROM public."Orders"', 'SELECT * FROM "Orders"', 'SELECT * FROM "PUBLIC".orders'],
+        ids=["quoted_table", "quoted_bare", "quoted_schema"],
+    )
+    def test_flags_quoted_identifier_as_unverifiable(self, sql):
+        """A quoted identifier is case-sensitive; case-insensitive matching can't verify it."""
+        scan = collect_table_references(parse_sql(sql, dialect="postgres"))
+        assert scan.unverifiable_sources
+
+    def test_dml_target_is_real_but_cte_source_is_excluded(self):
+        """A CTE used as a DML source is not a base table; the target and CTE body are."""
+        sql = "WITH src AS (SELECT * FROM orders) INSERT INTO orders SELECT * FROM src"
+        scan = collect_table_references(parse_sql(sql, dialect="postgres"))
+        # Only the real table `orders` is reported (target + CTE body); `src` is the CTE.
+        assert {t for _, _, t in scan.tables} == {"orders"}
+        assert scan.unverifiable_sources == []
+
+    def test_dml_target_shadowed_by_cte_is_still_reported(self):
+        """A DML target is a real table even when a same-named CTE exists (can't write a CTE)."""
+        sql = "WITH secret AS (SELECT 1) DELETE FROM secret"
+        scan = collect_table_references(parse_sql(sql, dialect="postgres"))
+        assert ("", "", "secret") in scan.tables
+
+    @pytest.mark.parametrize(
+        ("sql", "dialect"),
+        [
+            ("SELECT * FROM dblink('h', 'SELECT 1') AS t(x int)", "postgres"),
+            ("SELECT * FROM generate_series(1, 10)", "postgres"),
+        ],
+        ids=["dblink", "generate_series"],
+    )
+    def test_flags_table_valued_functions_as_unverifiable(self, sql, dialect):
+        scan = collect_table_references(parse_sql(sql, dialect=dialect))
+        assert scan.tables == []
+        assert scan.unverifiable_sources
+
+    @pytest.mark.parametrize(
+        ("sql", "dialect"),
+        [("SHOW TABLES", "snowflake"), ("SHOW COLUMNS FROM secret", "mysql")],
+        ids=["show_tables", "show_columns"],
+    )
+    def test_flags_show_as_unverifiable(self, sql, dialect):
+        scan = collect_table_references(parse_sql(sql, dialect=dialect))
+        assert scan.unverifiable_sources
+
+    def test_scalar_function_without_table_has_no_references(self):
+        """A scalar function call references no table -- the allow-list does not cover it."""
+        scan = collect_table_references(parse_sql("SELECT pg_read_file('/etc/passwd')", dialect="postgres"))
+        assert scan.tables == []
+        assert scan.unverifiable_sources == []
+
+    @pytest.mark.parametrize(
+        ("sql", "dialect"),
+        [("EXEC sp_who", "tsql"), ("EXECUTE my_proc", "tsql")],
+        ids=["exec", "execute"],
+    )
+    def test_flags_dynamic_sql_as_unverifiable(self, sql, dialect):
+        """EXEC/EXECUTE hide their table access in text the parser can't read."""
+        scan = collect_table_references(parse_sql(sql, dialect=dialect))
+        assert scan.tables == []
+        assert scan.unverifiable_sources
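
Note on the caller side of the check: the diff above shows collect_table_references returning (catalog, schema, table) tuples plus a list of unverifiable constructs, but the code that compares them to allowed_tables is not part of this excerpt. The following is only a minimal sketch of how such a comparison might look, assuming case-insensitive matching, bare names resolved against a default schema, and any catalog-qualified reference refused. The TableScan stand-in, find_violations, _key, and the message wording are hypothetical and are not taken from the Airflow source.

```python
from typing import NamedTuple


class TableScan(NamedTuple):
    """Hypothetical stand-in for the TableScan returned by collect_table_references."""

    tables: list  # (catalog, schema, table) tuples
    unverifiable_sources: list  # descriptions of constructs that cannot be checked


def _key(schema, table, default_schema):
    # Case-insensitive (schema, table) key; a missing schema falls back to the default.
    return ((schema or default_schema or "").lower(), table.lower())


def find_violations(scan, allowed_tables, default_schema=None):
    """Return the reasons a scan should be rejected; an empty list means it passes."""
    allowed = set()
    for entry in allowed_tables:
        # "schema.table" splits into its parts; a bare "table" gets an empty schema.
        schema, _, table = entry.rpartition(".")
        allowed.add(_key(schema, table, default_schema))

    # Anything the parser could not inspect is refused outright.
    reasons = [f"{source} cannot be checked against allowed_tables" for source in scan.unverifiable_sources]
    for catalog, schema, table in scan.tables:
        qualified = ".".join(part for part in (catalog, schema, table) if part)
        if catalog:
            # Assumption: a schema.table list cannot describe another database.
            reasons.append(f"{qualified} is a cross-database reference")
        elif _key(schema, table, default_schema) not in allowed:
            reasons.append(f"{qualified} is not in the allowed tables list")
    return reasons


scan = TableScan(
    tables=[("", "public", "orders"), ("otherdb", "public", "orders"), ("", "", "secret_salaries")],
    unverifiable_sources=["a SHOW statement"],
)
for reason in find_violations(scan, ["orders"], default_schema="public"):
    print(reason)
```

Under these assumptions the example prints three reasons (the SHOW statement, otherdb.public.orders, and secret_salaries), while public.orders passes because the bare entry "orders" resolves to the default schema. Without a default schema this sketch would refuse public.orders; whether the real toolset does the same is not shown in the diff.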
