This is an automated email from the ASF dual-hosted git repository.

kaxil pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new d88504564c5 Support multi-schema introspection in common.ai SQLToolset (#68103)
d88504564c5 is described below

commit d88504564c506e7edfbd64cb0589f6ac9d6263c5
Author: Kaxil Naik <[email protected]>
AuthorDate: Sat Jun 6 01:17:53 2026 +0100

    Support multi-schema introspection in common.ai SQLToolset (#68103)
    
    SQLToolset's metadata tools (list_tables, get_schema) operated against a single
    schema, so an agent over a multi-schema warehouse (common on Snowflake) could
    not discover tables across schemas. With no schema set and schema-qualified
    tables, list_tables introspected a literal "None" schema
    (SHOW TABLES IN SCHEMA "DB"."None") and failed outright.
    
    allowed_tables entries may now be schema-qualified ("SCHEMA.TABLE"). list_tables
    introspects each referenced schema and returns the matching tables fully
    qualified, and get_schema routes each qualified name to its own schema.
    Unqualified entries and the allow-all case keep the previous single-schema
    behaviour using the default schema. Table-name matching is case-insensitive,
    because databases reflect identifiers in their own case (Snowflake reflects
    unquoted names lowercased) and a byte-exact match would silently return nothing.
    Results are de-duplicated by (schema, table) so a table reachable both qualified
    and via the default schema is listed once.
---
 providers/common/ai/docs/toolsets.rst              |  29 ++++-
 .../airflow/providers/common/ai/toolsets/sql.py    |  84 +++++++++++--
 .../ai/tests/unit/common/ai/toolsets/test_sql.py   | 132 ++++++++++++++++++++-
 3 files changed, 230 insertions(+), 15 deletions(-)

diff --git a/providers/common/ai/docs/toolsets.rst b/providers/common/ai/docs/toolsets.rst
index a9f454fbc89..617c63520b1 100644
--- a/providers/common/ai/docs/toolsets.rst
+++ b/providers/common/ai/docs/toolsets.rst
@@ -146,14 +146,39 @@ Curated toolset wrapping
 The ``DbApiHook`` is resolved lazily from ``db_conn_id`` on first tool call
 via ``BaseHook.get_connection(conn_id).get_hook()``.
 
+Multi-schema warehouses
+^^^^^^^^^^^^^^^^^^^^^^^^^
+
+When an agent's tables live in several schemas of one database -- common on
+Snowflake -- list them with schema-qualified ``allowed_tables`` entries:
+
+.. code-block:: python
+
+    SQLToolset(
+        db_conn_id="snowflake_hq",
+        allowed_tables=["MODEL_ASTRO.DEPLOYMENT_IMAGE_DETAILS", "MODEL_CRM.SF_ASTRO_ORGS"],
+    )
+
+``list_tables`` then introspects each referenced schema and returns the matching
+tables fully qualified (e.g. ``MODEL_ASTRO.DEPLOYMENT_IMAGE_DETAILS``), and
+``get_schema`` routes each qualified name to its own schema. Without this, a
+single ``schema`` only covers one namespace, and leaving ``schema`` unset made
+introspection query a literal ``"None"`` schema and fail. Unqualified entries
+fall back to ``schema``, and table-name matching is case-insensitive (databases
+reflect identifiers in their own case). For tables in a different *database*, use
+a separate toolset whose connection points at that database.
+
 Parameters
 ^^^^^^^^^^
 
 - ``db_conn_id``: Airflow connection ID for the database.
 - ``allowed_tables``: Restrict which tables the agent can discover via
-  ``list_tables`` and ``get_schema``. ``None`` (default) exposes all tables.
+  ``list_tables`` and ``get_schema``. ``None`` (default) exposes all tables in
+  ``schema``. Entries may be schema-qualified (``"SCHEMA.TABLE"``) to span
+  multiple schemas; see above. Matching is case-insensitive.
   See :ref:`allowed-tables-limitation` for an important caveat.
-- ``schema``: Database schema/namespace for table listing and introspection.
+- ``schema``: Default schema/namespace for unqualified table listing and
+  introspection. Schema-qualified ``allowed_tables`` entries override it per table.
 - ``allow_writes``: Allow data-modifying SQL (INSERT, UPDATE, DELETE, etc.).
   Default ``False`` — only SELECT-family statements are permitted.
 - ``max_rows``: Maximum rows returned from the ``query`` tool. Default ``50``.
diff --git a/providers/common/ai/src/airflow/providers/common/ai/toolsets/sql.py b/providers/common/ai/src/airflow/providers/common/ai/toolsets/sql.py
index 0902cff99f2..fca07177597 100644
--- a/providers/common/ai/src/airflow/providers/common/ai/toolsets/sql.py
+++ b/providers/common/ai/src/airflow/providers/common/ai/toolsets/sql.py
@@ -111,7 +111,13 @@ class SQLToolset(AbstractToolset[Any]):
 
     :param db_conn_id: Airflow connection ID for the database.
     :param allowed_tables: Restrict which tables the agent can discover via
-        ``list_tables`` and ``get_schema``. ``None`` (default) exposes all 
tables.
+        ``list_tables`` and ``get_schema``. ``None`` (default) exposes all 
tables
+        in ``schema``. Entries may be schema-qualified (``"SCHEMA.TABLE"``) to 
span
+        multiple schemas in one database -- common on warehouses such as 
Snowflake.
+        ``list_tables`` then introspects each referenced schema and returns the
+        matching tables fully qualified, and ``get_schema`` routes to the 
table's
+        own schema. Unqualified entries use ``schema``. Matching is
+        case-insensitive, since databases reflect identifiers in their own 
case.
 
         .. note::
             ``allowed_tables`` controls metadata visibility only. It does 
**not**
@@ -120,7 +126,10 @@ class SQLToolset(AbstractToolset[Any]):
             restrictions, use database-level permissions (e.g. a read-only role
             with grants limited to specific tables).
 
-    :param schema: Database schema/namespace for table listing and 
introspection.
+    :param schema: Default schema/namespace for table listing and 
introspection,
+        used for unqualified ``allowed_tables`` entries and unqualified
+        ``get_schema`` calls. Schema-qualified ``allowed_tables`` entries 
override
+        it per table.
     :param allow_writes: Allow data-modifying SQL (INSERT, UPDATE, DELETE, 
etc.).
         Default ``False`` — only SELECT-family statements are permitted.
     :param max_rows: Maximum number of rows returned from the ``query`` tool.
@@ -138,11 +147,43 @@ class SQLToolset(AbstractToolset[Any]):
     ) -> None:
         self._db_conn_id = db_conn_id
         self._allowed_tables: frozenset[str] | None = 
frozenset(allowed_tables) if allowed_tables else None
+        # Case-folded view for membership tests: databases reflect identifiers 
in
+        # their own case (Snowflake stores unquoted names uppercase but 
reflects
+        # them lowercased), so a byte-exact match against the user's entries 
would
+        # silently miss. allowed_tables is a visibility hint, not access 
control,
+        # so case-insensitive matching is safe.
+        self._allowed_tables_ci: frozenset[str] | None = (
+            frozenset(t.casefold() for t in self._allowed_tables)
+            if self._allowed_tables is not None
+            else None
+        )
         self._schema = schema
         self._allow_writes = allow_writes
         self._max_rows = max_rows
         self._hook: DbApiHook | None = None
 
+        # Derive which schemas to introspect from schema-qualified 
allowed_tables.
+        # Qualified entries ("SCHEMA.TABLE") are listed under their own schema 
and
+        # returned fully qualified; unqualified entries (and allow-all) use the
+        # default ``schema``.
+        self._qualified_schemas: frozenset[str] = frozenset()
+        self._include_default_schema: bool = True
+        if self._allowed_tables is not None:
+            qualified_schemas: set[str] = set()
+            include_default = False
+            for entry in self._allowed_tables:
+                entry_schema, sep, _ = entry.rpartition(".")
+                if sep:
+                    qualified_schemas.add(entry_schema)
+                else:
+                    include_default = True
+            self._qualified_schemas = frozenset(qualified_schemas)
+            self._include_default_schema = include_default
+
+    def _is_table_allowed(self, name: str) -> bool:
+        """Case-insensitive membership test against ``allowed_tables`` 
(allow-all when unset)."""
+        return self._allowed_tables_ci is None or name.casefold() in 
self._allowed_tables_ci
+
     @property
     def id(self) -> str:
         return f"sql-{self._db_conn_id}"
@@ -213,18 +254,47 @@ class SQLToolset(AbstractToolset[Any]):
     # Tool implementations
     # ------------------------------------------------------------------
 
+    def _split_table_identifier(self, table_name: str) -> tuple[str | None, 
str]:
+        """Split ``"SCHEMA.TABLE"`` into ``(schema, table)``; unqualified uses 
the default schema."""
+        schema, sep, table = table_name.rpartition(".")
+        if not sep:
+            return self._schema, table_name
+        return schema, table
+
     def _list_tables(self) -> str:
         hook = self._get_db_hook()
-        tables: list[str] = hook.inspector.get_table_names(schema=self._schema)
-        if self._allowed_tables is not None:
-            tables = [t for t in tables if t in self._allowed_tables]
+        tables: list[str] = []
+        # Dedupe by (schema, table) so a table reachable both qualified and 
via the
+        # default schema (e.g. "public.users" and "users" with 
schema="public") is
+        # listed once. Case-folded because databases reflect identifiers in 
their case.
+        seen: set[tuple[str | None, str]] = set()
+
+        def add(schema: str | None, name: str, display: str) -> None:
+            key = (schema.casefold() if schema else None, name.casefold())
+            if self._is_table_allowed(display) and key not in seen:
+                seen.add(key)
+                tables.append(display)
+
+        # Schemas referenced by qualified allowed_tables entries: introspect 
each
+        # and return matching tables fully qualified so they round-trip to 
get_schema.
+        for schema in sorted(self._qualified_schemas):
+            for name in hook.inspector.get_table_names(schema=schema):
+                add(schema, name, f"{schema}.{name}")
+
+        # Default schema: used for allow-all and unqualified allowed_tables 
entries.
+        # Names stay bare to preserve the single-schema behaviour.
+        if self._include_default_schema:
+            for name in hook.inspector.get_table_names(schema=self._schema):
+                add(self._schema, name, name)
+
         return json.dumps(tables)
 
     def _get_schema(self, table_name: str) -> str:
-        if self._allowed_tables is not None and table_name not in 
self._allowed_tables:
+        if not self._is_table_allowed(table_name):
             return json.dumps({"error": f"Table {table_name!r} is not in the 
allowed tables list."})
         hook = self._get_db_hook()
-        columns = hook.get_table_schema(table_name, schema=self._schema)
+        schema, table = self._split_table_identifier(table_name)
+        columns = hook.get_table_schema(table, schema=schema)
         return json.dumps(columns)
 
     def _query(self, sql: str) -> str:
diff --git a/providers/common/ai/tests/unit/common/ai/toolsets/test_sql.py b/providers/common/ai/tests/unit/common/ai/toolsets/test_sql.py
index 471b956385d..c1aae15aad5 100644
--- a/providers/common/ai/tests/unit/common/ai/toolsets/test_sql.py
+++ b/providers/common/ai/tests/unit/common/ai/toolsets/test_sql.py
@@ -27,6 +27,7 @@ from pydantic_ai.exceptions import ModelRetry
 
 from airflow.providers.common.ai.toolsets.sql import SQLToolset
 from airflow.providers.common.ai.utils.sql_validation import SQLSafetyError
+from airflow.providers.common.sql.hooks.sql import DbApiHook
 
 
 def _make_mock_db_hook(
@@ -36,8 +37,6 @@ def _make_mock_db_hook(
     last_description: list[tuple] | None = None,
 ):
     """Create a mock DbApiHook with sensible defaults."""
-    from airflow.providers.common.sql.hooks.sql import DbApiHook
-
     mock = MagicMock(spec=DbApiHook)
     mock.inspector = MagicMock()
     mock.inspector.get_table_names.return_value = table_names or ["users", 
"orders"]
@@ -335,8 +334,6 @@ class TestSQLToolsetCheckQuery:
 class TestSQLToolsetHookResolution:
     @patch("airflow.providers.common.ai.toolsets.sql.BaseHook", autospec=True)
     def test_lazy_resolves_db_hook(self, mock_base_hook):
-        from airflow.providers.common.sql.hooks.sql import DbApiHook
-
         mock_hook = MagicMock(spec=DbApiHook)
         mock_conn = MagicMock(spec=["get_hook"])
         mock_conn.get_hook.return_value = mock_hook
@@ -361,8 +358,6 @@ class TestSQLToolsetHookResolution:
 
     @patch("airflow.providers.common.ai.toolsets.sql.BaseHook", autospec=True)
     def test_caches_hook_after_first_resolution(self, mock_base_hook):
-        from airflow.providers.common.sql.hooks.sql import DbApiHook
-
         mock_hook = MagicMock(spec=DbApiHook)
         mock_conn = MagicMock(spec=["get_hook"])
         mock_conn.get_hook.return_value = mock_hook
@@ -374,3 +369,128 @@ class TestSQLToolsetHookResolution:
 
         # Only called once because result is cached.
         mock_base_hook.get_connection.assert_called_once()
+
+
+class TestSQLToolsetMultiSchema:
+    """Schema-qualified allowed_tables span multiple schemas in one 
database."""
+
+    @staticmethod
+    def _schema_aware_hook(tables_by_schema: dict[str | None, list[str]]):
+        hook = MagicMock(spec=DbApiHook)
+        hook.inspector = MagicMock()
+        hook.inspector.get_table_names.side_effect = lambda schema=None: 
tables_by_schema.get(schema, [])
+        hook.get_table_schema.return_value = [{"name": "id", "type": 
"INTEGER"}]
+        return hook
+
+    def test_list_tables_spans_multiple_schemas(self):
+        ts = SQLToolset(
+            "sf",
+            allowed_tables=["MODEL_ASTRO.DEPLOYMENT_IMAGE_DETAILS", 
"MODEL_CRM.SF_ASTRO_ORGS"],
+        )
+        ts._hook = self._schema_aware_hook(
+            {
+                "MODEL_ASTRO": ["DEPLOYMENT_IMAGE_DETAILS", "OTHER_TABLE"],
+                "MODEL_CRM": ["SF_ASTRO_ORGS"],
+            }
+        )
+
+        result = json.loads(asyncio.run(ts.call_tool("list_tables", {}, 
ctx=MagicMock(), tool=MagicMock())))
+        assert result == ["MODEL_ASTRO.DEPLOYMENT_IMAGE_DETAILS", 
"MODEL_CRM.SF_ASTRO_ORGS"]
+
+    def test_list_tables_never_introspects_none_schema_when_all_qualified(self):
+        """Regression for the 'SHOW TABLES IN SCHEMA "DB"."None"' failure."""
+        ts = SQLToolset("sf", allowed_tables=["MODEL_ASTRO.X", "MODEL_CRM.Y"])
+        ts._hook = self._schema_aware_hook({"MODEL_ASTRO": ["X"], "MODEL_CRM": 
["Y"]})
+
+        asyncio.run(ts.call_tool("list_tables", {}, ctx=MagicMock(), 
tool=MagicMock()))
+
+        called_schemas = {c.kwargs.get("schema") for c in 
ts._hook.inspector.get_table_names.call_args_list}
+        assert called_schemas == {"MODEL_ASTRO", "MODEL_CRM"}
+        assert None not in called_schemas
+
+    def test_list_tables_mixed_qualified_and_default(self):
+        ts = SQLToolset("pg", allowed_tables=["users", "MODEL_ASTRO.X"], 
schema="public")
+        ts._hook = self._schema_aware_hook({"public": ["users", "orders"], 
"MODEL_ASTRO": ["X", "Z"]})
+
+        result = json.loads(asyncio.run(ts.call_tool("list_tables", {}, 
ctx=MagicMock(), tool=MagicMock())))
+        # Qualified schemas listed first (sorted), then the default schema.
+        assert result == ["MODEL_ASTRO.X", "users"]
+
+    def test_get_schema_routes_to_qualified_schema(self):
+        ts = SQLToolset("sf", 
allowed_tables=["MODEL_ASTRO.DEPLOYMENT_IMAGE_DETAILS"])
+        ts._hook = self._schema_aware_hook({"MODEL_ASTRO": 
["DEPLOYMENT_IMAGE_DETAILS"]})
+
+        result = json.loads(
+            asyncio.run(
+                ts.call_tool(
+                    "get_schema",
+                    {"table_name": "MODEL_ASTRO.DEPLOYMENT_IMAGE_DETAILS"},
+                    ctx=MagicMock(),
+                    tool=MagicMock(),
+                )
+            )
+        )
+        assert result == [{"name": "id", "type": "INTEGER"}]
+        
ts._hook.get_table_schema.assert_called_once_with("DEPLOYMENT_IMAGE_DETAILS", 
schema="MODEL_ASTRO")
+
+    def test_get_schema_blocks_table_outside_allowed_schema(self):
+        ts = SQLToolset("sf", allowed_tables=["MODEL_ASTRO.X"])
+        ts._hook = self._schema_aware_hook({"MODEL_ASTRO": ["X"]})
+
+        result = json.loads(
+            asyncio.run(
+                ts.call_tool(
+                    "get_schema", {"table_name": "SECRETS.PASSWORDS"}, 
ctx=MagicMock(), tool=MagicMock()
+                )
+            )
+        )
+        assert "error" in result
+        ts._hook.get_table_schema.assert_not_called()
+
+    def test_get_schema_unqualified_uses_default_schema(self):
+        ts = SQLToolset("pg", schema="public")
+        ts._hook = self._schema_aware_hook({"public": ["users"]})
+
+        asyncio.run(ts.call_tool("get_schema", {"table_name": "users"}, 
ctx=MagicMock(), tool=MagicMock()))
+        ts._hook.get_table_schema.assert_called_once_with("users", 
schema="public")
+
+    def test_list_tables_matches_case_insensitively(self):
+        """Snowflake reflects unquoted names lowercased; uppercase 
allowed_tables still match."""
+        ts = SQLToolset(
+            "sf",
+            allowed_tables=["MODEL_ASTRO.DEPLOYMENT_IMAGE_DETAILS", 
"MODEL_CRM.SF_ASTRO_ORGS"],
+        )
+        ts._hook = self._schema_aware_hook(
+            {
+                "MODEL_ASTRO": ["deployment_image_details", "other"],
+                "MODEL_CRM": ["sf_astro_orgs"],
+            }
+        )
+
+        result = json.loads(asyncio.run(ts.call_tool("list_tables", {}, 
ctx=MagicMock(), tool=MagicMock())))
+        assert result == ["MODEL_ASTRO.deployment_image_details", 
"MODEL_CRM.sf_astro_orgs"]
+
+    def test_get_schema_matches_case_insensitively(self):
+        ts = SQLToolset("sf", 
allowed_tables=["MODEL_ASTRO.DEPLOYMENT_IMAGE_DETAILS"])
+        ts._hook = self._schema_aware_hook({"MODEL_ASTRO": 
["deployment_image_details"]})
+
+        result = json.loads(
+            asyncio.run(
+                ts.call_tool(
+                    "get_schema",
+                    {"table_name": "MODEL_ASTRO.deployment_image_details"},
+                    ctx=MagicMock(),
+                    tool=MagicMock(),
+                )
+            )
+        )
+        assert "error" not in result
+        
ts._hook.get_table_schema.assert_called_once_with("deployment_image_details", 
schema="MODEL_ASTRO")
+
+    def test_list_tables_deduplicates_same_table(self):
+        """A table listed both qualified and unqualified appears once."""
+        ts = SQLToolset("pg", allowed_tables=["public.users", "users"], 
schema="public")
+        ts._hook = self._schema_aware_hook({"public": ["users"]})
+
+        result = json.loads(asyncio.run(ts.call_tool("list_tables", {}, 
ctx=MagicMock(), tool=MagicMock())))
+        assert result == ["public.users"]

Reply via email to