This is an automated email from the ASF dual-hosted git repository.
kaxil pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new b193ced28a7 Enforce SQLToolset allowed_tables on queries, not just discovery (#68487)
b193ced28a7 is described below
commit b193ced28a79ba567d36d95ad26445fb4719794a
Author: Kaxil Naik <[email protected]>
AuthorDate: Tue Jun 16 13:17:10 2026 +0100
Enforce SQLToolset allowed_tables on queries, not just discovery (#68487)
allowed_tables previously restricted only metadata discovery (list_tables /
get_schema); the query and check_query tools never checked it, so an agent
could read any table by name. It is now enforced on the query tools as well:
the SQL is parsed with sqlglot and rejected before execution if it reaches a
table that is not on the list, resolved with its database/catalog and with
CTE references excluded by lexical scope.
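The matching step described above can be sketched without a parser. The snippet below is a minimal, hypothetical illustration. It assumes table references have already been extracted as (catalog, schema, table) tuples. The names canonical_ref, build_allow_list, and is_allowed are illustrative only and are not the provider's API, although the case-folded key follows the approach this commit describes.

```python
from typing import Optional

# Comparison key: (catalog, schema, table), each part case-folded.
Key = tuple[Optional[str], Optional[str], str]


def canonical_ref(catalog: Optional[str], schema: Optional[str], table: str) -> Key:
    """Case-fold each part; empty or missing qualifiers become None."""
    return (
        catalog.casefold() if catalog else None,
        schema.casefold() if schema else None,
        table.casefold(),
    )


def build_allow_list(entries: list[str], default_schema: Optional[str]) -> frozenset[Key]:
    """Turn "table" / "schema.table" entries into keys. Entries carry no catalog."""
    keys = set()
    for entry in entries:
        schema, sep, table = entry.rpartition(".")
        keys.add(canonical_ref(None, schema if sep else default_schema, table))
    return frozenset(keys)


def is_allowed(allowed: frozenset[Key], ref: tuple[str, str, str], default_schema: Optional[str]) -> bool:
    """Unqualified references resolve to the default schema before matching."""
    catalog, schema, table = ref
    return canonical_ref(catalog, schema or default_schema, table) in allowed


# Hypothetical allow-list and references, for illustration only.
allowed = build_allow_list(["orders", "SALES.Customers"], default_schema="public")
print(is_allowed(allowed, ("", "", "ORDERS"), "public"))               # unqualified, different case
print(is_allowed(allowed, ("", "sales", "customers"), "public"))       # schema-qualified
print(is_allowed(allowed, ("otherdb", "public", "orders"), "public"))  # cross-database reference
print(is_allowed(allowed, ("", "", "secrets"), "public"))              # off-list table
```

Because allow-list keys never carry a catalog, a catalog-qualified reference produces a key with a non-null first element and cannot match, which is how the cross-database case is refused in this sketch.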
Constructs a schema.table list cannot describe are rejected while a list is
active: table-valued functions, TABLE('name') row sources, the TABLE <name>
shorthand, SHOW, dynamic SQL, quoted identifiers, cross-database references,
and inline comments (where MySQL executable /*! ... */ comments hide).
This is application-level defense-in-depth, not a substitute for database
permissions: data reached through a function whose argument is itself SQL or a
path (pg_read_file, query_to_xml, scalar dblink) is out of its reach, so a
least-privilege DB role remains the hard boundary.
* Allow CTE sources in DML against allowed_tables
A CTE used as a DML source (WITH src AS (...) INSERT INTO orders SELECT * FROM
src) was falsely rejected: CTE scoping was disabled for the whole DML statement,
so src was checked against allowed_tables as if it were a base table.
Now only the DML target is exempt from CTE resolution (you cannot write to a CTE,
so a same-named CTE never shadows the target); sources follow normal lexical CTE
scoping. The target, and any off-list table a CTE body actually reads, are still
enforced.
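The target/source rule above can be illustrated with a deliberately simplified sketch. It assumes the table names have already been pulled out of the statement and that every CTE is visible to every source. The real check works on the parsed AST and also honours CTE ordering and RECURSIVE. The helper name tables_to_check is hypothetical.

```python
# Hypothetical helper for illustration; not part of the provider.
def tables_to_check(target: str, sources: list[str], cte_names: set[str]) -> list[str]:
    """Return the names that must appear on the allow-list.

    The DML target is always checked, because a CTE cannot be written to.
    A bare source that names a CTE is skipped, because it is not a base table.
    """
    folded = {name.casefold() for name in cte_names}
    checked = [target]
    for source in sources:
        if source.casefold() not in folded:
            checked.append(source)
    return checked


# WITH src AS (SELECT * FROM staging) INSERT INTO orders SELECT * FROM src
print(tables_to_check("orders", ["src", "staging"], {"src"}))

# WITH orders AS (SELECT * FROM staging) INSERT INTO orders SELECT * FROM orders
print(tables_to_check("orders", ["orders", "staging"], {"orders"}))
```

In both cases the target orders and the table staging read inside the CTE body remain subject to the allow-list, while the CTE name itself is not treated as a base table.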
---
providers/common/ai/docs/toolsets.rst | 78 ++++--
.../airflow/providers/common/ai/toolsets/sql.py | 155 ++++++++---
.../providers/common/ai/utils/sql_validation.py | 259 +++++++++++++++++--
.../ai/tests/unit/common/ai/toolsets/test_sql.py | 282 +++++++++++++++++++++
.../unit/common/ai/utils/test_sql_validation.py | 192 ++++++++++++++
5 files changed, 883 insertions(+), 83 deletions(-)
diff --git a/providers/common/ai/docs/toolsets.rst b/providers/common/ai/docs/toolsets.rst
index 33814d50644..d284227c0ac 100644
--- a/providers/common/ai/docs/toolsets.rst
+++ b/providers/common/ai/docs/toolsets.rst
@@ -155,10 +155,10 @@ the validator, so ``SHOW`` is recognized on databases that support it (Snowflake
MySQL, etc.); on databases without ``SHOW`` it stays rejected. Data-modifying
statements remain blocked -- including ones hidden behind ``DESCRIBE``/``EXPLAIN``
(e.g. ``EXPLAIN DELETE ...``, ``DESCRIBE DROP TABLE ...``), which the validator
-rejects by scanning the parsed statement for write operations. Like ``SELECT``,
-metadata statements are not scoped by ``allowed_tables`` (see
-:ref:`allowed-tables-limitation`) -- an agent can ``DESCRIBE`` a table outside the
-list, so rely on database permissions to restrict access.
+rejects by scanning the parsed statement for write operations. When
+``allowed_tables`` is set it scopes these statements too: a ``DESCRIBE`` names a
+table, so its target must be on the list, while ``SHOW`` enumerates objects beyond
+any single table and is rejected outright (see :ref:`allowed-tables-enforcement`).
Multi-schema warehouses
^^^^^^^^^^^^^^^^^^^^^^^^^
@@ -186,11 +186,13 @@ Parameters
^^^^^^^^^^
- ``db_conn_id``: Airflow connection ID for the database.
-- ``allowed_tables``: Restrict which tables the agent can discover via
- ``list_tables`` and ``get_schema``. ``None`` (default) exposes all tables in
- ``schema``. Entries may be schema-qualified (``"SCHEMA.TABLE"``) to span
- multiple schemas; see above. Matching is case-insensitive.
- See :ref:`allowed-tables-limitation` for an important caveat.
+- ``allowed_tables``: Restrict the agent to a fixed set of tables. ``None``
+ (default) exposes all tables in ``schema``. Entries may be schema-qualified
+ (``"SCHEMA.TABLE"``) to span multiple schemas; see above. Matching is
+ case-insensitive. When set, the list is enforced on ``query`` and
+ ``check_query`` as well as discovery -- every table a query references must be
+ on it. See :ref:`allowed-tables-enforcement` for what this does and does not
+ guarantee.
- ``schema``: Default schema/namespace for unqualified table listing and
introspection. Schema-qualified ``allowed_tables`` entries override it per
table.
- ``allow_writes``: Allow data-modifying SQL (INSERT, UPDATE, DELETE, etc.).
@@ -583,11 +585,14 @@ No single layer is sufficient — they work together.
INTO, and other non-SELECT statements.
- Does not prevent the agent from reading any registered data source.
* - **SQLToolset: allowed_tables**
- - Restricts which tables appear in ``list_tables`` and ``get_schema``
- responses, limiting the agent's knowledge of the schema.
- - Does **not** validate table references in SQL queries. The agent can
- still query unlisted tables if it guesses the name. See
- :ref:`allowed-tables-limitation` below.
+ - Restricts the agent to listed tables across ``list_tables``,
+ ``get_schema``, ``query``, and ``check_query``. Queries are parsed and
+ every referenced table (including via subqueries, CTEs, JOINs, and
+ ``DESCRIBE``) is checked against the list before execution.
+ - Cannot police data reached through side-effecting scalar functions
+ (e.g. ``pg_read_file``), and is only as exact as the SQL parser. Pair it
+ with least-privilege database grants. See
+ :ref:`allowed-tables-enforcement` below.
* - **SQLToolset: max_rows**
- Truncates query results to ``max_rows`` (default 50), preventing the
agent from pulling entire tables into context.
@@ -604,21 +609,38 @@ No single layer is sufficient — they work together.
- Requires explicit configuration — the default allows many rounds.
-.. _allowed-tables-limitation:
+.. _allowed-tables-enforcement:
-The ``allowed_tables`` Limitation
-"""""""""""""""""""""""""""""""""
+How ``allowed_tables`` Is Enforced
+""""""""""""""""""""""""""""""""""
-``allowed_tables`` is a **metadata filter**, not an access control mechanism.
-It hides table names from ``list_tables`` and blocks ``get_schema`` for
-unlisted tables, but does not parse SQL queries to validate table references.
+When ``allowed_tables`` is set it governs every tool, not just discovery:
-An LLM can craft ``SELECT * FROM secrets`` even when
-``allowed_tables=["orders"]``. Parsing SQL for table references (including
-CTEs, subqueries, aliases, and vendor-specific syntax) is complex and
-error-prone; we chose not to provide a false sense of security.
+- ``list_tables`` and ``get_schema`` only reveal listed tables.
+- ``query`` and ``check_query`` parse the SQL with `sqlglot
+ <https://github.com/tobymao/sqlglot>`_ and reject it before execution if it
+ references any table that is not on the list. Tables reached indirectly are
+ caught too -- through subqueries, CTEs, JOINs, set operations (``UNION`` etc.),
+ ``DESCRIBE``, catalog views such as ``information_schema``, and DML. CTE
+ references are excluded by lexical scope, so a same-named CTE in another scope
+ cannot hide a real table, and the database/catalog is part of the match, so a
+ cross-database reference like ``otherdb.public.orders`` is refused.
+- Constructs the list cannot describe are rejected outright while it is active:
+ table-valued functions (``dblink``), ``TABLE('name')`` row sources, the
+ ``TABLE <name>`` shorthand, ``SHOW``, dynamic SQL (``EXEC``), and **inline
+ comments** -- the last because parser-vs-engine differences hide in comments
+ (MySQL executes ``/*! ... */`` while sqlglot and other engines ignore it).
-For query-level restrictions, use database permissions:
+So ``SELECT * FROM secrets`` with ``allowed_tables=["orders"]`` is refused, and
+the rejection is handed back to the agent so it can re-target an allowed table.
+
+This is a strong **application-level guardrail**, but it is not a substitute for
+database permissions. It cannot police data reached through a function whose
+argument is itself SQL or a path: ``pg_read_file('/etc/passwd')`` reads a file,
+and ``query_to_xml('SELECT * FROM other_table', ...)`` or a scalar ``dblink``
+reads a table through a string the parser cannot inspect. Any query the engine
+parses differently from sqlglot is also a residual gap. For a hard boundary, also
+run the connection as a least-privilege role:
.. code-block:: sql
@@ -627,8 +649,10 @@ For query-level restrictions, use database permissions:
GRANT SELECT ON orders, customers TO airflow_agent_reader;
-- Use this role's credentials in the Airflow connection
-The Airflow connection should use a database user with the minimum privileges
-required.
+Defense in depth: the allow-list contains the agent's *intent* (and gives it a
+correctable error), while the database role is the boundary that holds even if
+the agent reaches data the parser cannot see. The connection should use a
+database user with the minimum privileges required.
HookToolset Guidelines
diff --git a/providers/common/ai/src/airflow/providers/common/ai/toolsets/sql.py b/providers/common/ai/src/airflow/providers/common/ai/toolsets/sql.py
index 45990901e15..a4435971969 100644
--- a/providers/common/ai/src/airflow/providers/common/ai/toolsets/sql.py
+++ b/providers/common/ai/src/airflow/providers/common/ai/toolsets/sql.py
@@ -24,6 +24,9 @@ from typing import TYPE_CHECKING, Any
try:
from airflow.providers.common.ai.utils.sql_validation import (
+ SQLSafetyError,
+ collect_table_references,
+ parse_sql as _parse_sql,
resolve_sqlglot_dialect,
validate_sql as _validate_sql,
)
@@ -94,21 +97,37 @@ class SQLToolset(AbstractToolset[Any]):
toolset does not inspect the error type or message.
:param db_conn_id: Airflow connection ID for the database.
- :param allowed_tables: Restrict which tables the agent can discover via
- ``list_tables`` and ``get_schema``. ``None`` (default) exposes all tables
- in ``schema``. Entries may be schema-qualified (``"SCHEMA.TABLE"``) to span
- multiple schemas in one database -- common on warehouses such as Snowflake.
- ``list_tables`` then introspects each referenced schema and returns the
- matching tables fully qualified, and ``get_schema`` routes to the table's
- own schema. Unqualified entries use ``schema``. Matching is
- case-insensitive, since databases reflect identifiers in their own case.
+ :param allowed_tables: Restrict the agent to a fixed set of tables. ``None``
+ (default) exposes every table in ``schema``. Entries may be schema-qualified
+ (``"SCHEMA.TABLE"``) to span multiple schemas in one database -- common on
+ warehouses such as Snowflake. ``list_tables`` introspects each referenced
+ schema and returns the matching tables fully qualified, and ``get_schema``
+ routes to the table's own schema. Unqualified entries use ``schema``.
+ Matching is case-insensitive, since databases reflect identifiers in their
+ own case.
+
+ When set, the list is enforced on the ``query`` and ``check_query`` tools as
+ well as on discovery: every table a query reaches -- through subqueries, CTEs,
+ JOINs, set operations, ``DESCRIBE``, catalog views such as
+ ``information_schema``, or DML -- must be on the list, resolved with its
+ database/catalog, or the query is rejected before it runs. CTE references are
+ excluded by lexical scope (a same-named CTE in another scope never hides a real
+ table). Constructs the list cannot describe are rejected outright while it is
+ active: table-valued functions (``dblink``), ``TABLE('name')`` row sources, the
+ ``TABLE <name>`` shorthand, ``SHOW``, dynamic SQL, and **inline comments**
+ (where parser-vs-engine differences such as MySQL ``/*! ... */`` executable
+ comments hide).
.. note::
- ``allowed_tables`` controls metadata visibility only. It does **not**
- parse or validate table references in SQL queries. An LLM can still
- query tables outside this list if it guesses the name. For query-level
- restrictions, use database-level permissions (e.g. a read-only role
- with grants limited to specific tables).
+ This is an application-level guardrail, enforced by parsing the SQL with
+ sqlglot. It is strong defense-in-depth but not a substitute for database
+ permissions: it cannot police data reached through a function whose
+ argument is itself SQL or a path -- ``pg_read_file('...')`` (a file) or
+ ``query_to_xml('SELECT ... FROM other_table', ...)`` and ``dblink`` in
+ scalar position (a table, read through a string the parser cannot inspect)
+ -- and any query the engine parses differently from sqlglot is a residual
+ gap. For a hard guarantee, also point ``db_conn_id`` at a least-privilege
+ role whose ``SELECT`` grants are limited to the same tables.
:param schema: Default schema/namespace for table listing and introspection,
used for unqualified ``allowed_tables`` entries and unqualified
@@ -131,42 +150,63 @@ class SQLToolset(AbstractToolset[Any]):
) -> None:
self._db_conn_id = db_conn_id
self._allowed_tables: frozenset[str] | None = frozenset(allowed_tables) if allowed_tables else None
- # Case-folded view for membership tests: databases reflect identifiers in
- # their own case (Snowflake stores unquoted names uppercase but reflects
- # them lowercased), so a byte-exact match against the user's entries would
- # silently miss. allowed_tables is a visibility hint, not access control,
- # so case-insensitive matching is safe.
- self._allowed_tables_ci: frozenset[str] | None = (
- frozenset(t.casefold() for t in self._allowed_tables)
- if self._allowed_tables is not None
- else None
- )
self._schema = schema
self._allow_writes = allow_writes
self._max_rows = max_rows
self._hook: DbApiHook | None = None
- # Derive which schemas to introspect from schema-qualified allowed_tables.
+ # Canonical ``(catalog, schema, table)`` view of allowed_tables for membership
+ # tests, plus the schemas to introspect. Built once: every reference -- a
+ # discovery hit, a get_schema arg, or a table parsed out of a query -- is
+ # normalised to the same shape and matched against this set.
+ #
+ # Identifiers are case-folded: databases reflect them in their own case
+ # (Snowflake stores unquoted names uppercase but reflects them lowercased), so
+ # a byte-exact match against the user's entries would silently miss. Unqualified
+ # entries resolve to the default ``schema`` (``None`` when unset) so that
+ # ``"orders"`` and ``"<schema>.orders"`` denote the same table. Allow-list
+ # entries carry no catalog, so any catalog-qualified reference
+ # (``otherdb.public.orders``) has a non-null catalog in its key and cannot match
+ # -- that closes cross-database access the single-connection allow-list can't
+ # describe.
+ self._allowed_canonical: frozenset[tuple[str | None, str | None, str]] | None = None
# Qualified entries ("SCHEMA.TABLE") are listed under their own schema and
# returned fully qualified; unqualified entries (and allow-all) use the
# default ``schema``.
self._qualified_schemas: frozenset[str] = frozenset()
self._include_default_schema: bool = True
if self._allowed_tables is not None:
+ canonical: set[tuple[str | None, str | None, str]] = set()
qualified_schemas: set[str] = set()
include_default = False
for entry in self._allowed_tables:
- entry_schema, sep, _ = entry.rpartition(".")
+ entry_schema, sep, table = entry.rpartition(".")
if sep:
qualified_schemas.add(entry_schema)
+ canonical.add(self._canonical_ref("", entry_schema, table))
else:
include_default = True
+ canonical.add(self._canonical_ref("", self._schema, entry))
+ self._allowed_canonical = frozenset(canonical)
self._qualified_schemas = frozenset(qualified_schemas)
self._include_default_schema = include_default
- def _is_table_allowed(self, name: str) -> bool:
- """Case-insensitive membership test against ``allowed_tables`` (allow-all when unset)."""
- return self._allowed_tables_ci is None or name.casefold() in self._allowed_tables_ci
+ @staticmethod
+ def _canonical_ref(
+ catalog: str | None, schema: str | None, table: str
+ ) -> tuple[str | None, str | None, str]:
+ """Normalise a ``(catalog, schema, table)`` reference to its case-folded comparison key."""
+ return (
+ catalog.casefold() if catalog else None,
+ schema.casefold() if schema else None,
+ table.casefold(),
+ )
+
+ def _is_ref_allowed(self, catalog: str | None, schema: str | None, table: str) -> bool:
+ """Membership test for a resolved ``(catalog, schema, table)`` reference (allow-all when unset)."""
+ if self._allowed_canonical is None:
+ return True
+ return self._canonical_ref(catalog, schema, table) in self._allowed_canonical
@property
def id(self) -> str:
@@ -263,11 +303,11 @@ class SQLToolset(AbstractToolset[Any]):
# Dedupe by (schema, table) so a table reachable both qualified and via the
# default schema (e.g. "public.users" and "users" with schema="public") is
# listed once. Case-folded because databases reflect identifiers in their case.
- seen: set[tuple[str | None, str]] = set()
+ seen: set[tuple[str | None, str | None, str]] = set()
def add(schema: str | None, name: str, display: str) -> None:
- key = (schema.casefold() if schema else None, name.casefold())
- if self._is_table_allowed(display) and key not in seen:
+ key = self._canonical_ref("", schema, name)
+ if self._is_ref_allowed("", schema, name) and key not in seen:
seen.add(key)
tables.append(display)
@@ -286,10 +326,10 @@ class SQLToolset(AbstractToolset[Any]):
return json.dumps(tables)
def _get_schema(self, table_name: str) -> str:
- if not self._is_table_allowed(table_name):
+ schema, table = self._split_table_identifier(table_name)
+ if not self._is_ref_allowed("", schema, table):
return json.dumps({"error": f"Table {table_name!r} is not in the allowed tables list."})
hook = self._get_db_hook()
- schema, table = self._split_table_identifier(table_name)
columns = hook.get_table_schema(table, schema=schema)
return json.dumps(columns)
@@ -300,15 +340,19 @@ class SQLToolset(AbstractToolset[Any]):
def _query(self, sql: str) -> str:
hook = self._get_db_hook()
+ dialect = self._dialect_for_validation()
+ statements: list[Any] | None = None
if not self._allow_writes:
# allow_read_only_metadata lets agents inspect schemas with DESCRIBE/SHOW
# (a common first move) instead of hard-failing; the deep scan still
# rejects any data-modifying statement, including EXPLAIN <write>.
- _validate_sql(
- sql,
- dialect=self._dialect_for_validation(),
- allow_read_only_metadata=True,
- )
+ statements = _validate_sql(sql, dialect=dialect, allow_read_only_metadata=True)
+ elif self._allowed_canonical is not None:
+ # Writes are allowed but tables are restricted: parse anyway so the
+ # allow-list still governs which tables a write may touch.
+ statements = _parse_sql(sql, dialect=dialect)
+ if statements is not None:
+ self._enforce_allowed_tables(statements)
rows = hook.get_records(sql)
# Fetch column names from cursor description.
@@ -336,7 +380,40 @@ class SQLToolset(AbstractToolset[Any]):
with suppress(Exception):
dialect = self._dialect_for_validation()
try:
- _validate_sql(sql, dialect=dialect, allow_read_only_metadata=True)
+ statements = _validate_sql(sql, dialect=dialect, allow_read_only_metadata=True)
+ self._enforce_allowed_tables(statements)
return json.dumps({"valid": True})
except Exception as e:
return json.dumps({"valid": False, "error": str(e)})
+
+ def _enforce_allowed_tables(self, statements: list[Any]) -> None:
+ """
+ Reject a parsed query that reaches any table outside ``allowed_tables``.
+
+ No-op when ``allowed_tables`` is unset (allow-all). Otherwise every table the
+ query references (resolved scope-correctly, including catalog) must be on the
+ list, and any construct the list cannot describe -- a table-valued function,
+ ``SHOW``, dynamic SQL, an inline comment, or the ``TABLE <name>`` shorthand --
+ is refused. Raises :class:`SQLSafetyError` -- ``call_tool`` turns it into a
+ ``ModelRetry`` so the agent can re-target an allowed table, while
+ ``check_query`` reports it invalid.
+ """
+ if self._allowed_canonical is None:
+ return
+ scan = collect_table_references(statements)
+ if scan.unverifiable_sources:
+ raise SQLSafetyError(
+ f"Query uses a data source that cannot be checked against allowed_tables: "
+ f"{'; '.join(scan.unverifiable_sources)}. Query the allowed tables directly: "
+ f"use list_tables to see them."
+ )
+ disallowed = [
+ ".".join(part for part in (catalog, schema, table) if part)
+ for catalog, schema, table in scan.tables
+ if not self._is_ref_allowed(catalog, schema or self._schema, table)
+ ]
+ if disallowed:
+ raise SQLSafetyError(
+ f"Query references tables that are not in the allowed tables list: "
+ f"{', '.join(sorted(set(disallowed)))}. Use list_tables to see the allowed tables."
+ )
diff --git a/providers/common/ai/src/airflow/providers/common/ai/utils/sql_validation.py b/providers/common/ai/src/airflow/providers/common/ai/utils/sql_validation.py
index a00b4dc11e6..3b39210710e 100644
--- a/providers/common/ai/src/airflow/providers/common/ai/utils/sql_validation.py
+++ b/providers/common/ai/src/airflow/providers/common/ai/utils/sql_validation.py
@@ -24,6 +24,8 @@ This is safer than a denylist because new/unexpected statement types
from __future__ import annotations
+from typing import NamedTuple
+
import sqlglot
from sqlglot import exp
from sqlglot.dialects import Dialects
@@ -105,6 +107,245 @@ class SQLSafetyError(Exception):
"""Generated SQL failed safety validation."""
+def parse_sql(
+ sql: str,
+ *,
+ dialect: str | None = None,
+ allow_multiple_statements: bool = False,
+) -> list[exp.Expr]:
+ """
+ Parse SQL into statements, enforcing the empty- and multi-statement guards only.
+
+ Shared by :func:`validate_sql` (which then applies statement-type checks) and by
+ callers that need the parsed AST for their own analysis -- e.g. table-reference
+ extraction for ``allowed_tables`` enforcement -- without the read-only allow-list.
+
+ :param sql: SQL string to parse.
+ :param dialect: SQL dialect for parsing (``postgres``, ``mysql``, etc.).
+ :param allow_multiple_statements: Whether to allow multiple semicolon-separated
+ statements. Default ``False`` -- multi-statement input can hide a dangerous
+ operation after a benign one.
+ :return: List of parsed sqlglot Expression objects (never empty).
+ :raises SQLSafetyError: If the SQL is empty, cannot be parsed, or contains multiple
+ statements when not permitted.
+ """
+ if not sql or not sql.strip():
+ raise SQLSafetyError("Empty SQL input.")
+
+ try:
+ statements = sqlglot.parse(sql, dialect=dialect, error_level=ErrorLevel.RAISE)
+ except sqlglot.errors.ParseError as e:
+ raise SQLSafetyError(f"SQL parse error: {e}") from e
+
+ # sqlglot.parse can return [None] for empty input
+ parsed = [s for s in statements if s is not None]
+ if not parsed:
+ raise SQLSafetyError("Empty SQL input.")
+
+ if not allow_multiple_statements and len(parsed) > 1:
+ raise SQLSafetyError(
+ f"Multiple statements detected ({len(parsed)}). Only single statements are allowed by default."
+ )
+ return parsed
+
+
+class TableScan(NamedTuple):
+ """Result of :func:`collect_table_references`."""
+
+ #: ``(catalog, schema, table)`` for every real base table referenced anywhere in
+ #: the AST. ``catalog`` and ``schema`` are ``""`` when the reference omits them.
+ #: In-scope CTE references are excluded. Catalog is reported so the caller can
+ #: reject cross-database references (``otherdb.public.orders``) that a
+ #: ``schema.table`` allow-list cannot describe.
+ tables: list[tuple[str, str, str]]
+ #: Human-readable descriptions of constructs that cannot be checked against an
+ #: allow-list and so must be rejected while one is active: table-valued functions
+ #: (``dblink``), ``TABLE('name')`` row sources, ``SHOW``, dynamic SQL
+ #: (``EXEC``/``Command``), inline comments (a parser-vs-engine differential), and
+ #: the ``TABLE <name>`` shorthand. Empty when every construct is verifiable.
+ unverifiable_sources: list[str]
+
+
+_DML_TYPES: tuple[type[exp.Expr], ...] = (exp.Insert, exp.Update, exp.Delete, exp.Merge)
+
+
+def _same_identifier(a: exp.Identifier, b: exp.Identifier) -> bool:
+ """
+ Compare two identifiers under standard identifier-folding rules.
+
+ Unquoted names fold (case-insensitive); quoted names are case-preserving and
+ distinct from unquoted ones. Used to decide whether a table reference names a CTE:
+ being *stricter* here is safe -- a near-miss falls through to the allow-list check.
+ """
+ aq, bq = bool(a.args.get("quoted")), bool(b.args.get("quoted"))
+ if not aq and not bq:
+ return str(a.this).casefold() == str(b.this).casefold()
+ if aq and bq:
+ return str(a.this) == str(b.this)
+ return False
+
+
+def _enclosing_cte(table: exp.Expr, with_: exp.With) -> exp.CTE | None:
+ """Return the CTE of ``with_`` whose *definition* contains ``table`` (else ``None``)."""
+ node = table.parent
+ while node is not None and node is not with_.parent:
+ if isinstance(node, exp.CTE) and node.parent is with_:
+ return node
+ node = node.parent
+ return None
+
+
+def _is_in_scope_cte(table: exp.Table) -> bool:
+ """
+ Report whether ``table`` is a bare reference resolved by a CTE visible at its scope.
+
+ Walks the ancestor chain (lexical scope) collecting CTE names from each enclosing
+ ``WITH``. A CTE defined in a *sibling* or *inner* subquery is not an ancestor, so a
+ real top-level table is never excluded by an unrelated same-named CTE
+ (``SELECT * FROM secret WHERE id IN (WITH secret AS (...) SELECT ...)``). A
+ non-recursive CTE is not visible inside its own definition, so
+ ``WITH secret AS (SELECT * FROM secret) ...`` still reports the real ``secret``.
+ CTE order matters too: inside one CTE's body only *earlier* siblings are in scope
+ (forward references need ``RECURSIVE``), so ``WITH a AS (SELECT * FROM secret),
+ secret AS (...) SELECT * FROM a`` still reports the real ``secret`` read by ``a``.
+ """
+ ref = table.this
+ if not isinstance(ref, exp.Identifier):
+ return False
+ node: exp.Expr | None = table.parent
+ while node is not None:
+ # A WITH attaches to its owning query (Select/Union/DML) as a sibling of the
+ # body, so the query -- an ancestor of the table -- holds it. Find it by type
+ # rather than a fixed arg key (sqlglot has used both ``with`` and ``with_``).
+ with_ = (
+ next((v for v in node.args.values() if isinstance(v, exp.With)), None)
+ if isinstance(node, exp.Expression)
+ else None
+ )
+ if isinstance(with_, exp.With):
+ recursive = bool(with_.args.get("recursive"))
+ ctes = list(with_.expressions)
+ enclosing = _enclosing_cte(table, with_)
+ # If the reference sits inside CTE E's own body, only CTEs defined *before*
+ # E are visible there (plus E itself when RECURSIVE); a CTE defined after E
+ # is not yet in scope. In the main query body every CTE is visible.
+ enclosing_idx = next((i for i, c in enumerate(ctes) if c is enclosing), None)
+ for idx, cte in enumerate(ctes):
+ if enclosing_idx is not None:
+ if idx > enclosing_idx:
+ continue
+ if idx == enclosing_idx and not recursive:
+ continue
+ alias = cte.args.get("alias")
+ cte_ident = alias.this if isinstance(alias, exp.TableAlias) else None
+ if isinstance(cte_ident, exp.Identifier) and _same_identifier(cte_ident, ref):
+ return True
+ node = node.parent
+ return False
+
+
+def collect_table_references(statements: list[exp.Expr]) -> TableScan:
+ """
+ Walk parsed statements and report every real table they reach, scope-correctly.
+
+ This is the AST half of ``allowed_tables`` enforcement: it returns the concrete
+ base tables a query reaches (including those nested in subqueries, CTEs, JOINs, set
+ operations, ``DESCRIBE``, and DML) as ``(catalog, schema, table)`` so the caller can
+ check each against its allow-list, plus a list of constructs that cannot be checked
+ and must therefore be rejected while an allow-list is active.
+
+ Handled carefully (each was a confirmed bypass before it was closed):
+
+ - **CTE references are excluded by lexical scope, not by name.** A table is treated
+ as a CTE only when a ``WITH`` *enclosing that reference* defines the name (see
+ :func:`_is_in_scope_cte`); a same-named CTE in a sibling/inner query no longer
+ hides a real top-level table. A DML *target* is always a real table (you cannot
+ write to a CTE, so a same-named CTE does not shadow it), but DML *sources* follow
+ normal CTE scoping -- a CTE used as an INSERT/UPDATE source is not flagged.
+ - **Catalog-qualified references are reported with their catalog**, so the caller
+ rejects ``otherdb.public.orders`` instead of matching it to ``public.orders``.
+ - **Unverifiable constructs are listed, not silently dropped:** nameless
+ table-valued functions (``dblink``), ``TABLE('name')`` row sources
+ (``exp.TableFromRows``), ``SHOW``, dynamic SQL (``EXEC``/``Command``), the
+ ``TABLE <name>`` shorthand (which sqlglot parses incorrectly, leaking the
+ ``TABLE`` keyword as a column), a **quoted identifier** (case-sensitive on the engine but
+ matched case-insensitively here, so ``"Orders"`` could otherwise reach a table
+ distinct from the allow-listed ``orders``), and **any inline comment** --
+ comments are where parser-vs-engine differentials hide (MySQL executable
+ ``/*! ... */``, ``--`` not followed by whitespace, ``#``).
+
+ :param statements: Parsed sqlglot statements (from :func:`parse_sql`).
+ :return: A :class:`TableScan` of real table references and unverifiable constructs.
+ """
+ tables: list[tuple[str, str, str]] = []
+ unverifiable: list[str] = []
+ for stmt in statements:
+ # SHOW enumerates objects / leaks a table's columns outside any single table.
+ if isinstance(stmt, exp.Show):
+ unverifiable.append("a SHOW statement")
+ continue
+ # Dynamic SQL and anything sqlglot can only represent as a raw Command reach
+ # data through text the parser cannot inspect.
+ if isinstance(stmt, (exp.Command, exp.Execute)):
+ unverifiable.append(f"a {type(stmt).__name__.lower()} statement")
+ continue
+
+ # A comment is a parser-vs-engine differential vector: sqlglot drops it, but the
+ # engine may execute it (MySQL `/*! ... */`) or tokenize it differently (`--`
+ # without a trailing space, `#`). sqlglot tokenizes string literals correctly,
+ # so a `--` inside a quoted string is not flagged here.
+ if any(node.comments for node in stmt.walk()):
+ unverifiable.append("an inline comment")
+ continue
+
+ # `TABLE('name')` / `TABLE($$name$$)` name a table through a string the parser
+ # cannot resolve; sqlglot models them as TableFromRows, not exp.Table.
+ if any(True for _ in stmt.find_all(exp.TableFromRows)):
+ unverifiable.append("a TABLE(...) row source")
+ continue
+
+ # `TABLE <name>` (Postgres/MySQL shorthand for SELECT * FROM <name>) is not
+ # modelled by sqlglot; it parses incorrectly, leaking the reserved word TABLE as an
+ # unquoted column identifier. No real query has an unquoted column named TABLE.
+ if any(
+ isinstance(col.this, exp.Identifier)
+ and not col.this.args.get("quoted")
+ and str(col.this.this).upper() == "TABLE"
+ for col in stmt.find_all(exp.Column)
+ ):
+ unverifiable.append("a TABLE <name> shorthand")
+ continue
+
+ # A DML statement's *target* (the table written to) is always a real table --
+ # you cannot INSERT/UPDATE/DELETE/MERGE into a CTE, so even a same-named CTE does
+ # not shadow it. Its *sources* (the SELECT/USING/subqueries) follow normal CTE
+ # scoping, so a CTE used as a source is not mistaken for a base table.
+ target = stmt.args.get("this") if isinstance(stmt, _DML_TYPES) else None
+ target_ids = {id(t) for t in target.find_all(exp.Table)} if target is not None else set()
+ for table in stmt.find_all(exp.Table):
+ name = table.name
+ if not name:
+ unverifiable.append(f"table-valued function ({table.sql()})")
+ continue
+ # A bare, non-target reference may be a CTE; a qualified one or a DML target
+ # never is.
+ if id(table) not in target_ids and not table.db and not table.catalog and _is_in_scope_cte(table):
+ continue
+ # A quoted identifier is case-sensitive on the engine, but the allow-list is
+ # matched case-insensitively (and a plain ``schema.table`` string cannot
+ # carry quoting), so a quoted reference cannot be matched soundly: on
+ # Postgres/Snowflake ``"Orders"`` is a *different* table from the allow-listed
+ # ``orders``. Reject rather than risk reaching a case-distinct table.
+ if any(
+ isinstance(part, exp.Identifier) and part.args.get("quoted")
+ for part in (table.this, table.args.get("db"), table.args.get("catalog"))
+ ):
+ unverifiable.append("a quoted identifier")
+ continue
+ tables.append((table.catalog, table.db, name))
+ return TableScan(tables=tables, unverifiable_sources=unverifiable)
+
+
def validate_sql(
sql: str,
*,
@@ -138,9 +379,6 @@ def validate_sql(
:raises SQLSafetyError: If the SQL is empty, contains disallowed statement types,
or has multiple statements when not permitted.
"""
- if not sql or not sql.strip():
- raise SQLSafetyError("Empty SQL input.")
-
# A caller-supplied ``allowed_types`` is an explicit opt-out of the curated
# read-only defaults (and the data-modifying deep scan). Otherwise we use the
# read-only defaults, optionally widened with metadata statements, and keep
@@ -154,20 +392,7 @@ def validate_sql(
types = allowed_types
run_data_modifying_scan = types == DEFAULT_ALLOWED_TYPES
- try:
- statements = sqlglot.parse(sql, dialect=dialect, error_level=ErrorLevel.RAISE)
- except sqlglot.errors.ParseError as e:
- raise SQLSafetyError(f"SQL parse error: {e}") from e
-
- # sqlglot.parse can return [None] for empty input
- parsed = [s for s in statements if s is not None]
- if not parsed:
- raise SQLSafetyError("Empty SQL input.")
-
- if not allow_multiple_statements and len(parsed) > 1:
- raise SQLSafetyError(
- f"Multiple statements detected ({len(parsed)}). Only single statements are allowed by default."
- )
+ parsed = parse_sql(sql, dialect=dialect, allow_multiple_statements=allow_multiple_statements)
for stmt in parsed:
if not isinstance(stmt, types):
diff --git a/providers/common/ai/tests/unit/common/ai/toolsets/test_sql.py b/providers/common/ai/tests/unit/common/ai/toolsets/test_sql.py
index 92e033a6b60..ed0619a1db0 100644
--- a/providers/common/ai/tests/unit/common/ai/toolsets/test_sql.py
+++ b/providers/common/ai/tests/unit/common/ai/toolsets/test_sql.py
@@ -467,3 +467,285 @@ class TestSQLToolsetMetadataStatements:
ts.call_tool("check_query", {"sql": "SELECT 1"}, ctx=MagicMock(), tool=MagicMock())
)
assert json.loads(result)["valid"] is True
+
+
+def _run_query(ts: SQLToolset, sql: str):
+ return asyncio.run(ts.call_tool("query", {"sql": sql}, ctx=MagicMock(), tool=MagicMock()))
+
+
+def _run_check(ts: SQLToolset, sql: str):
+ return json.loads(
+ asyncio.run(ts.call_tool("check_query", {"sql": sql}, ctx=MagicMock(), tool=MagicMock()))
+ )
+
+
+class TestSQLToolsetAllowedTablesQueryEnforcement:
+ """``allowed_tables`` is enforced on the query/check_query tools, not just on discovery."""
+
+ def test_query_allows_table_on_the_list(self):
+ ts = SQLToolset("pg_default", allowed_tables=["orders"])
+ ts._hook = _make_mock_db_hook(records=[(1,)], last_description=[("id",)])
+
+ result = _run_query(ts, "SELECT id FROM orders")
+
+ assert "rows" in json.loads(result)
+ ts._hook.get_records.assert_called_once_with("SELECT id FROM orders")
+
+ def test_query_blocks_table_off_the_list(self):
+ """The headline escape: querying a table that is not on the allow-list is refused."""
+ ts = SQLToolset("pg_default", allowed_tables=["orders"])
+ ts._hook = _make_mock_db_hook()
+
+ with pytest.raises(ModelRetry) as exc_info:
+ _run_query(ts, "SELECT * FROM secret_salaries")
+
+ assert "not in the allowed tables list" in exc_info.value.message
+ assert "secret_salaries" in exc_info.value.message
+ ts._hook.get_records.assert_not_called()
+
+ @pytest.mark.parametrize(
+ "sql",
+ [
+ "SELECT * FROM (SELECT * FROM secret_salaries) x",
+ "WITH s AS (SELECT * FROM secret_salaries) SELECT * FROM s",
+ "SELECT * FROM orders JOIN secret_salaries ON orders.id = secret_salaries.id",
+ "SELECT * FROM orders UNION SELECT * FROM secret_salaries",
+ "SELECT * FROM secret_salaries WHERE id IN (SELECT id FROM orders)",
+ ],
+ ids=["subquery", "cte_body", "join", "union", "where_subquery"],
+ )
+ def test_query_blocks_disallowed_table_reached_indirectly(self, sql):
+ ts = SQLToolset("pg_default", allowed_tables=["orders"])
+ ts._hook = _make_mock_db_hook()
+
+ with pytest.raises(ModelRetry) as exc_info:
+ _run_query(ts, sql)
+
+ assert "secret_salaries" in exc_info.value.message
+ ts._hook.get_records.assert_not_called()
+
+ def test_query_blocks_catalog_enumeration(self):
+ """information_schema/pg_catalog are ordinary tables, so the allow-list blocks them too."""
+ ts = SQLToolset("pg_default", allowed_tables=["orders"])
+ ts._hook = _make_mock_db_hook()
+
+ with pytest.raises(ModelRetry) as exc_info:
+ _run_query(ts, "SELECT table_name FROM information_schema.tables")
+
+ assert "information_schema.tables" in exc_info.value.message
+ ts._hook.get_records.assert_not_called()
+
+ def test_query_allows_cte_reference_not_mistaken_for_table(self):
+ """A CTE whose name is not on the list is fine as long as its body stays allowed."""
+ ts = SQLToolset("pg_default", allowed_tables=["orders"])
+ ts._hook = _make_mock_db_hook(records=[(1,)], last_description=[("id",)])
+
+ result = _run_query(ts, "WITH ranked AS (SELECT * FROM orders) SELECT * FROM ranked")
+
+ assert "rows" in json.loads(result)
+
+ def test_query_blocks_table_valued_function(self):
+ """dblink reaches data through a path the list can't describe, so it is refused."""
+ ts = SQLToolset("pg_default", allowed_tables=["orders"])
+ ts._hook = _make_mock_db_hook()
+ ts._hook.dialect_name = "postgresql"
+
+ with pytest.raises(ModelRetry) as exc_info:
+ _run_query(ts, "SELECT * FROM dblink('host=evil', 'SELECT 1') AS t(x int)")
+
+ assert "cannot be checked against allowed_tables" in exc_info.value.message
+ ts._hook.get_records.assert_not_called()
+
+ def test_query_blocks_show_when_allowlist_active(self):
+ ts = SQLToolset("sf_default", allowed_tables=["orders"])
+ ts._hook = _make_mock_db_hook()
+ ts._hook.dialect_name = "snowflake"
+
+ with pytest.raises(ModelRetry) as exc_info:
+ _run_query(ts, "SHOW TABLES")
+
+ assert "cannot be checked against allowed_tables" in exc_info.value.message
+ ts._hook.get_records.assert_not_called()
+
+ def test_query_blocks_describe_of_disallowed_table(self):
+ ts = SQLToolset("sf_default", allowed_tables=["orders"])
+ ts._hook = _make_mock_db_hook()
+ ts._hook.dialect_name = "snowflake"
+
+ with pytest.raises(ModelRetry) as exc_info:
+ _run_query(ts, "DESCRIBE TABLE secret_salaries")
+
+ assert "secret_salaries" in exc_info.value.message
+ ts._hook.get_records.assert_not_called()
+
+ def test_query_allows_describe_of_allowed_table(self):
+ ts = SQLToolset("sf_default", allowed_tables=["orders"])
+ ts._hook = _make_mock_db_hook(records=[("id", "INT")], last_description=[("name",), ("type",)])
+ ts._hook.dialect_name = "snowflake"
+
+ result = _run_query(ts, "DESCRIBE TABLE orders")
+
+ assert "rows" in json.loads(result)
+ ts._hook.get_records.assert_called_once_with("DESCRIBE TABLE orders")
+
+ def test_query_allows_schema_qualified_table_on_list(self):
+ ts = SQLToolset("sf", allowed_tables=["MODEL_CRM.SF_ASTRO_ORGS"])
+ ts._hook = _make_mock_db_hook(records=[(1,)], last_description=[("id",)])
+ ts._hook.dialect_name = "snowflake"
+
+ result = _run_query(ts, "SELECT * FROM MODEL_CRM.SF_ASTRO_ORGS")
+
+ assert "rows" in json.loads(result)
+
+ def test_query_unqualified_resolves_to_default_schema(self):
+ """``public.orders`` and ``orders`` denote the same table when schema='public'."""
+ ts = SQLToolset("pg", allowed_tables=["orders"], schema="public")
+ ts._hook = _make_mock_db_hook(records=[(1,)], last_description=[("id",)])
+
+ # Qualifying with the default schema must still match the bare allow-list entry.
+ result = _run_query(ts, "SELECT * FROM public.orders")
+ assert "rows" in json.loads(result)
+
+ def test_no_allowlist_leaves_queries_unrestricted(self):
+ """Without allowed_tables the query tool behaves exactly as before (allow-all)."""
+ ts = SQLToolset("pg_default")
+ ts._hook = _make_mock_db_hook(records=[(1,)], last_description=[("id",)])
+
+ result = _run_query(ts, "SELECT * FROM anything_at_all")
+
+ assert "rows" in json.loads(result)
+ ts._hook.get_records.assert_called_once_with("SELECT * FROM anything_at_all")
+
+ def test_check_query_reports_disallowed_table_as_invalid(self):
+ ts = SQLToolset("pg_default", allowed_tables=["orders"])
+ ts._hook = _make_mock_db_hook()
+
+ data = _run_check(ts, "SELECT * FROM secret_salaries")
+
+ assert data["valid"] is False
+ assert "secret_salaries" in data["error"]
+
+ def test_check_query_valid_for_allowed_table(self):
+ ts = SQLToolset("pg_default", allowed_tables=["orders"])
+ ts._hook = _make_mock_db_hook()
+
+ assert _run_check(ts, "SELECT * FROM orders")["valid"] is True
+
+ def test_writes_still_bounded_by_allowed_tables(self):
+ """allow_writes widens the statement types, but the allow-list still scopes the target."""
+ ts = SQLToolset("pg_default", allowed_tables=["orders"], allow_writes=True)
+ ts._hook = _make_mock_db_hook(records=[], last_description=None)
+
+ # An allowed target is written.
+ _run_query(ts, "INSERT INTO orders (id) VALUES (1)")
+ ts._hook.get_records.assert_called_once_with("INSERT INTO orders (id) VALUES (1)")
+
+ # A disallowed target is refused before execution.
+ ts._hook.get_records.reset_mock()
+ with pytest.raises(ModelRetry) as exc_info:
+ _run_query(ts, "INSERT INTO secret_salaries (id) VALUES (1)")
+ assert "secret_salaries" in exc_info.value.message
+ ts._hook.get_records.assert_not_called()
+
+ def test_writes_reject_dynamic_sql_the_parser_cannot_inspect(self):
+ """allow_writes skips the read-only validator, so the allow-list must still
+ refuse dynamic SQL (EXEC/EXECUTE) whose table access is opaque."""
+ ts = SQLToolset("mssql_default", allowed_tables=["orders"], allow_writes=True)
+ ts._hook = _make_mock_db_hook()
+ ts._hook.dialect_name = "mssql"
+
+ with pytest.raises(ModelRetry) as exc_info:
+ _run_query(ts, "EXEC sp_who")
+
+ assert "cannot be checked against allowed_tables" in exc_info.value.message
+ ts._hook.get_records.assert_not_called()
+
+
+class TestSQLToolsetAllowedTablesBypassRegressions:
+ """Regression tests for bypasses found by adversarial red-teaming of the allow-list."""
+
+ @pytest.mark.parametrize(
+ ("sql", "dialect", "allow_writes"),
+ [
+ # CTE scope: a same-named CTE in an inner/sibling scope must not hide the real table.
+ (
+ "SELECT * FROM secret_salaries WHERE id IN "
+ "(WITH secret_salaries AS (SELECT 1 id) SELECT id FROM secret_salaries)",
+ "postgresql",
+ False,
+ ),
+ # Non-recursive CTE is not in scope within its own body.
+ (
+ "WITH secret_salaries AS (SELECT * FROM secret_salaries) SELECT * FROM secret_salaries",
+ "postgresql",
+ False,
+ ),
+ # A CTE may only reference earlier siblings; a later-defined name is the real table.
+ (
+ "WITH a AS (SELECT * FROM secret_salaries), secret_salaries AS (SELECT 1 id) SELECT * FROM a",
+ "postgresql",
+ False,
+ ),
+ # Cross-database / catalog qualifier the schema.table allow-list cannot describe.
+ ("SELECT * FROM secretdb.public.orders", "snowflake", False),
+ ("SELECT * FROM secret_salaries..orders", "mssql", False),
+ # MySQL executable comments execute on the engine but sqlglot treats them as inert.
+ ("SELECT * FROM orders/*!UNION SELECT * FROM secret_salaries*/", "mysql", False),
+ ("SELECT id FROM orders /*!50000 UNION SELECT id FROM secret_salaries */", "mysql", False),
+ # TABLE <name> shorthand (mis-parsed) and TABLE('name') row source (string-named).
+ ("TABLE secret_salaries UNION SELECT * FROM orders", "postgresql", False),
+ ("SELECT * FROM TABLE('secret_salaries')", "snowflake", False),
+ # Write-mode CTE shadowing the DML target.
+ ("WITH secret_salaries AS (SELECT 1) DELETE FROM secret_salaries", "postgresql", True),
+ # Quoted identifier is case-distinct on the engine but case-folds into the list.
+ ('SELECT * FROM "Orders"', "postgresql", False),
+ # A DML source CTE whose body reads an off-list table is still caught.
+ (
+ "WITH src AS (SELECT * FROM secret_salaries) INSERT INTO orders SELECT * FROM src",
+ "postgresql",
+ True,
+ ),
+ ],
+ ids=[
+ "cte_inner_shadow",
+ "cte_self_body",
+ "cte_forward_ref",
+ "catalog_cross_db",
+ "mssql_empty_middle",
+ "mysql_exec_comment",
+ "mysql_versioned_comment",
+ "table_shorthand",
+ "table_row_source",
+ "write_cte_target",
+ "quoted_case_distinct",
+ "dml_cte_body_reads_offlist",
+ ],
+ )
+ def test_known_bypasses_are_rejected(self, sql, dialect, allow_writes):
+ ts = SQLToolset("c", allowed_tables=["orders"], allow_writes=allow_writes)
+ ts._hook = _make_mock_db_hook()
+ ts._hook.dialect_name = dialect
+
+ with pytest.raises(ModelRetry):
+ _run_query(ts, sql)
+ ts._hook.get_records.assert_not_called()
+
+ def test_legit_cte_over_allowed_table_still_runs(self):
+ """The scope-aware fix must not false-reject a genuine CTE over an allowed table."""
+ ts = SQLToolset("c", allowed_tables=["orders"])
+ ts._hook = _make_mock_db_hook(records=[(1,)], last_description=[("id",)])
+
+ result = _run_query(ts, "WITH ranked AS (SELECT * FROM orders) SELECT * FROM ranked")
+
+ assert "rows" in json.loads(result)
+ ts._hook.get_records.assert_called_once()
+
+ def test_dml_with_cte_source_over_allowed_table_runs(self):
+ """A CTE used as a DML source must not be mistaken for a disallowed base table."""
+ ts = SQLToolset("c", allowed_tables=["orders"], allow_writes=True)
+ ts._hook = _make_mock_db_hook(records=[], last_description=None)
+
+ sql = "WITH src AS (SELECT * FROM orders) INSERT INTO orders SELECT * FROM src"
+ _run_query(ts, sql)
+
+ ts._hook.get_records.assert_called_once_with(sql)
diff --git a/providers/common/ai/tests/unit/common/ai/utils/test_sql_validation.py b/providers/common/ai/tests/unit/common/ai/utils/test_sql_validation.py
index 9ca6604ba5e..257761d78e0 100644
--- a/providers/common/ai/tests/unit/common/ai/utils/test_sql_validation.py
+++ b/providers/common/ai/tests/unit/common/ai/utils/test_sql_validation.py
@@ -21,6 +21,8 @@ from sqlglot import exp
from airflow.providers.common.ai.utils.sql_validation import (
SQLSafetyError,
+ collect_table_references,
+ parse_sql,
resolve_sqlglot_dialect,
validate_sql,
)
@@ -312,3 +314,193 @@ class TestResolveSqlglotDialect:
)
def test_resolution(self, dialect_name, expected):
assert resolve_sqlglot_dialect(dialect_name) == expected
+
+
+class TestParseSQL:
+ """``parse_sql`` enforces only the empty- and multi-statement guards."""
+
+ def test_returns_statements(self):
+ parsed = parse_sql("SELECT 1")
+ assert len(parsed) == 1
+ assert isinstance(parsed[0], exp.Select)
+
+ def test_does_not_apply_type_checks(self):
+ """Unlike validate_sql, parse_sql accepts writes -- callers add their own policy."""
+ parsed = parse_sql("DELETE FROM users WHERE id = 1")
+ assert isinstance(parsed[0], exp.Delete)
+
+ @pytest.mark.parametrize("sql", ["", " ", "\n\t"])
+ def test_rejects_empty(self, sql):
+ with pytest.raises(SQLSafetyError, match="Empty SQL"):
+ parse_sql(sql)
+
+ def test_rejects_multiple_statements_by_default(self):
+ with pytest.raises(SQLSafetyError, match="Multiple statements"):
+ parse_sql("SELECT 1; SELECT 2")
+
+ def test_allows_multiple_statements_when_opted_in(self):
+ assert len(parse_sql("SELECT 1; SELECT 2", allow_multiple_statements=True)) == 2
+
+ def test_rejects_unparsable(self):
+ with pytest.raises(SQLSafetyError, match="parse error"):
+ parse_sql("SELECT FROM WHERE )(")
+
+
+class TestCollectTableReferences:
+ """``collect_table_references`` reports the real tables a query reaches."""
+
+ @pytest.mark.parametrize(
+ ("sql", "dialect", "expected"),
+ [
+ ("SELECT * FROM secret", None, [("", "", "secret")]),
+ ("SELECT * FROM model_crm.orders", None, [("", "model_crm", "orders")]),
+ ("SELECT * FROM a JOIN b ON a.id = b.id", None, [("", "", "a"), ("", "", "b")]),
+ ("SELECT * FROM (SELECT * FROM inner_t) x", None, [("", "", "inner_t")]),
+ ("SELECT * FROM a UNION SELECT * FROM b", None, [("", "", "a"), ("", "", "b")]),
+ (
+ "SELECT table_name FROM information_schema.tables",
+ "postgres",
+ [("", "information_schema", "tables")],
+ ),
+ ("DESCRIBE secret", "mysql", [("", "", "secret")]),
+ # Cross-database reference carries its catalog so the caller can reject it.
+ ("SELECT * FROM otherdb.public.orders", "snowflake", [("otherdb", "public", "orders")]),
+ ],
+ ids=["table", "qualified", "join", "subquery", "union", "catalog", "describe", "cross_db"],
+ )
+ def test_collects_real_tables(self, sql, dialect, expected):
+ scan = collect_table_references(parse_sql(sql, dialect=dialect))
+ assert sorted(scan.tables) == sorted(expected)
+ assert scan.unverifiable_sources == []
+
+ def test_excludes_cte_reference_but_keeps_its_body(self):
+ scan = collect_table_references(parse_sql("WITH s AS (SELECT * FROM base) SELECT * FROM s"))
+ assert scan.tables == [("", "", "base")] # 's' is the CTE, not a table
+
+ def test_cte_that_shadows_a_table_name_yields_no_table(self):
+ scan = collect_table_references(parse_sql("WITH secret AS (SELECT 1 AS x) SELECT * FROM secret"))
+ assert scan.tables == []
+
+ def test_schema_qualified_name_is_never_treated_as_a_cte(self):
+ scan = collect_table_references(parse_sql("WITH s AS (SELECT 1 AS x) SELECT * FROM myschema.s"))
+ assert scan.tables == [("", "myschema", "s")]
+
+ def test_inner_cte_does_not_shadow_outer_real_table(self):
+ """A same-named CTE in an inner subquery must not hide the real top-level table."""
+ sql = "SELECT * FROM secret WHERE id IN (WITH secret AS (SELECT 1 id) SELECT id FROM secret)"
+ scan = collect_table_references(parse_sql(sql))
+ assert ("", "", "secret") in scan.tables # the top-level real table is reported
+
+ def test_cte_self_body_references_real_table(self):
+ """A non-recursive CTE is not in scope within its own body, so the real table shows."""
+ scan = collect_table_references(
+ parse_sql("WITH secret AS (SELECT * FROM secret) SELECT * FROM secret")
+ )
+ assert ("", "", "secret") in scan.tables
+
+ def test_cte_forward_reference_is_real_table(self):
+ """A CTE may only reference earlier siblings; a later-defined name is the real table."""
+ sql = "WITH a AS (SELECT * FROM secret), secret AS (SELECT 1 id) SELECT * FROM a"
+ scan = collect_table_references(parse_sql(sql))
+ assert ("", "", "secret") in scan.tables
+
+ def test_legit_cte_reference_is_excluded(self):
+ """A genuine CTE reference is not reported as a base table (no false reject)."""
+ scan = collect_table_references(
+ parse_sql("WITH ranked AS (SELECT * FROM orders) SELECT * FROM ranked")
+ )
+ assert scan.tables == [("", "", "orders")]
+
+ @pytest.mark.parametrize(
+ ("sql", "dialect"),
+ [
+ ("SELECT * FROM orders/*!UNION SELECT * FROM secret*/", "mysql"),
+ ("SELECT * FROM orders /*!50000 UNION SELECT * FROM secret */", "mysql"),
+ ("SELECT * FROM orders WHERE 0--+1 OR id IN (SELECT id FROM secret)", "mysql"),
+ ],
+ ids=["exec_comment", "versioned_comment", "dashdash"],
+ )
+ def test_flags_inline_comment_as_unverifiable(self, sql, dialect):
+ """Comments hide parser-vs-engine differentials (MySQL executable comments), so reject them."""
+ scan = collect_table_references(parse_sql(sql, dialect=dialect))
+ assert scan.unverifiable_sources
+
+ @pytest.mark.parametrize(
+ "sql",
+ ["SELECT * FROM TABLE('secret')", "SELECT * FROM TABLE($$secret$$)"],
+ ids=["string", "dollar"],
+ )
+ def test_flags_table_row_source_as_unverifiable(self, sql):
+ """Snowflake TABLE('name') names a table through a string the parser can't resolve."""
+ scan = collect_table_references(parse_sql(sql, dialect="snowflake"))
+ assert scan.unverifiable_sources
+
+ def test_flags_table_shorthand_as_unverifiable(self):
+ """The TABLE <name> shorthand is mis-parsed by sqlglot, so reject it."""
+ scan = collect_table_references(
+ parse_sql("TABLE secret UNION SELECT * FROM orders", dialect="postgres")
+ )
+ assert scan.unverifiable_sources
+
+ @pytest.mark.parametrize(
+ "sql",
+ ['SELECT * FROM public."Orders"', 'SELECT * FROM "Orders"', 'SELECT * FROM "PUBLIC".orders'],
+ ids=["quoted_table", "quoted_bare", "quoted_schema"],
+ )
+ def test_flags_quoted_identifier_as_unverifiable(self, sql):
+ """A quoted identifier is case-sensitive; case-insensitive matching can't verify it."""
+ scan = collect_table_references(parse_sql(sql, dialect="postgres"))
+ assert scan.unverifiable_sources
+
+ def test_dml_target_is_real_but_cte_source_is_excluded(self):
+ """A CTE used as a DML source is not a base table; the target and CTE body are."""
+ sql = "WITH src AS (SELECT * FROM orders) INSERT INTO orders SELECT * FROM src"
+ scan = collect_table_references(parse_sql(sql, dialect="postgres"))
+ # Only the real table `orders` is reported (target + CTE body); `src` is the CTE.
+ assert {t for _, _, t in scan.tables} == {"orders"}
+ assert scan.unverifiable_sources == []
+
+ def test_dml_target_shadowed_by_cte_is_still_reported(self):
+ """A DML target is a real table even when a same-named CTE exists (can't write a CTE)."""
+ sql = "WITH secret AS (SELECT 1) DELETE FROM secret"
+ scan = collect_table_references(parse_sql(sql, dialect="postgres"))
+ assert ("", "", "secret") in scan.tables
+
+ @pytest.mark.parametrize(
+ ("sql", "dialect"),
+ [
+ ("SELECT * FROM dblink('h', 'SELECT 1') AS t(x int)", "postgres"),
+ ("SELECT * FROM generate_series(1, 10)", "postgres"),
+ ],
+ ids=["dblink", "generate_series"],
+ )
+ def test_flags_table_valued_functions_as_unverifiable(self, sql, dialect):
+ scan = collect_table_references(parse_sql(sql, dialect=dialect))
+ assert scan.tables == []
+ assert scan.unverifiable_sources
+
+ @pytest.mark.parametrize(
+ ("sql", "dialect"),
+ [("SHOW TABLES", "snowflake"), ("SHOW COLUMNS FROM secret", "mysql")],
+ ids=["show_tables", "show_columns"],
+ )
+ def test_flags_show_as_unverifiable(self, sql, dialect):
+ scan = collect_table_references(parse_sql(sql, dialect=dialect))
+ assert scan.unverifiable_sources
+
+ def test_scalar_function_without_table_has_no_references(self):
+ """A scalar function call references no table -- the allow-list does not cover it."""
+ scan = collect_table_references(parse_sql("SELECT pg_read_file('/etc/passwd')", dialect="postgres"))
+ assert scan.tables == []
+ assert scan.unverifiable_sources == []
+
+ @pytest.mark.parametrize(
+ ("sql", "dialect"),
+ [("EXEC sp_who", "tsql"), ("EXECUTE my_proc", "tsql")],
+ ids=["exec", "execute"],
+ )
+ def test_flags_dynamic_sql_as_unverifiable(self, sql, dialect):
+ """EXEC/EXECUTE hide their table access in text the parser can't read."""
+ scan = collect_table_references(parse_sql(sql, dialect=dialect))
+ assert scan.tables == []
+ assert scan.unverifiable_sources