This is an automated email from the ASF dual-hosted git repository.

kaxilnaik pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 73d0ee88b50 Add `HookToolset` and `SQLToolset` for agentic LLM 
workflows (#62785)
73d0ee88b50 is described below

commit 73d0ee88b50ddc9c2e1e22250a6a11083b975780
Author: Kaxil Naik <[email protected]>
AuthorDate: Tue Mar 3 20:50:43 2026 +0000

    Add `HookToolset` and `SQLToolset` for agentic LLM workflows (#62785)
    
    HookToolset: Generic adapter that exposes any Airflow Hook's methods
    as pydantic-ai tools via introspection. Requires explicit
    allowed_methods list (no auto-discovery). Builds JSON Schema from
    method signatures and enriches tool descriptions from docstrings.
    
    SQLToolset: Curated 4-tool database toolset (list_tables, get_schema,
    query, check_query) wrapping DbApiHook. Read-only by default with SQL
    validation, allowed_tables metadata filtering, and max_rows truncation.
    
    Both implement pydantic-ai's AbstractToolset interface with
    sequential=True on all tool definitions to prevent concurrent sync I/O.
    
    * Fix mypy error: annotate result variable in SQLToolset._query
    
    The list comprehension in the else branch produces list[list[Any]]
    while the if branch produces list[dict[str, Any]]. Add an explicit
    type annotation to satisfy mypy.
    
    * Add toolset/agentic/ctx to spelling wordlist
    
    Sphinx autoapi generates RST from pydantic-ai's AbstractToolset base
    class docstrings. These words appear in the auto-generated docs and
    need to be in the global wordlist.
    
    Docs for HookToolset (generic hook→tools adapter) and SQLToolset
    (curated 4-tool DB toolset). Includes defense layers table,
    allowed_tables limitation, HookToolset guidelines, recommended
    configurations, and production checklist.
---
 docs/spelling_wordlist.txt                         |   7 +
 providers/common/ai/docs/index.rst                 |   1 +
 providers/common/ai/docs/toolsets.rst              | 273 ++++++++++++++++++++
 .../providers/common/ai/toolsets/__init__.py       |  35 +++
 .../airflow/providers/common/ai/toolsets/hook.py   | 267 ++++++++++++++++++++
 .../airflow/providers/common/ai/toolsets/sql.py    | 231 +++++++++++++++++
 .../ai/tests/unit/common/ai/toolsets/__init__.py   |  16 ++
 .../ai/tests/unit/common/ai/toolsets/test_hook.py  | 281 +++++++++++++++++++++
 .../ai/tests/unit/common/ai/toolsets/test_sql.py   | 233 +++++++++++++++++
 9 files changed, 1344 insertions(+)

diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt
index fbfbeda2b00..81fd9cc567e 100644
--- a/docs/spelling_wordlist.txt
+++ b/docs/spelling_wordlist.txt
@@ -1,6 +1,7 @@
 aarch
 abc
 AbstractFileSystem
+AbstractToolset
 accessor
 AccessSecretVersionResponse
 accountmaking
@@ -24,6 +25,7 @@ adobjects
 AdsInsights
 adsinsights
 afterall
+agentic
 AgentKey
 ai
 aio
@@ -375,6 +377,7 @@ Ctl
 ctl
 ctor
 Ctrl
+ctx
 cubeName
 customDataImportUids
 customizability
@@ -831,6 +834,7 @@ Gzip
 gzipped
 hadoop
 hadoopcmd
+hardcode
 hardcoded
 Harenslak
 Hashable
@@ -1918,6 +1922,9 @@ tokopedia
 tolerations
 toml
 toolchain
+toolset
+Toolsets
+toolsets
 Tooltip
 tooltip
 tooltips
diff --git a/providers/common/ai/docs/index.rst 
b/providers/common/ai/docs/index.rst
index 764ae24c4dd..87fbe994716 100644
--- a/providers/common/ai/docs/index.rst
+++ b/providers/common/ai/docs/index.rst
@@ -36,6 +36,7 @@
 
     Connection types <connections/pydantic_ai>
     Hooks <hooks/pydantic_ai>
+    Toolsets <toolsets>
     Operators <operators/index>
 
 .. toctree::
diff --git a/providers/common/ai/docs/toolsets.rst 
b/providers/common/ai/docs/toolsets.rst
new file mode 100644
index 00000000000..7334a5ae0a4
--- /dev/null
+++ b/providers/common/ai/docs/toolsets.rst
@@ -0,0 +1,273 @@
+ .. Licensed to the Apache Software Foundation (ASF) under one
+    or more contributor license agreements.  See the NOTICE file
+    distributed with this work for additional information
+    regarding copyright ownership.  The ASF licenses this file
+    to you under the Apache License, Version 2.0 (the
+    "License"); you may not use this file except in compliance
+    with the License.  You may obtain a copy of the License at
+
+ ..   http://www.apache.org/licenses/LICENSE-2.0
+
+ .. Unless required by applicable law or agreed to in writing,
+    software distributed under the License is distributed on an
+    "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+    KIND, either express or implied.  See the License for the
+    specific language governing permissions and limitations
+    under the License.
+
+.. _howto/toolsets:
+
+Toolsets — Airflow Hooks as AI Agent Tools
+==========================================
+
+Airflow's 350+ provider hooks already have typed methods, rich docstrings,
+and managed credentials. Toolsets expose them as pydantic-ai tools so that
+LLM agents can call them during multi-turn reasoning.
+
+Two toolsets are included:
+
+- :class:`~airflow.providers.common.ai.toolsets.hook.HookToolset` — generic
+  adapter for any Airflow Hook.
+- :class:`~airflow.providers.common.ai.toolsets.sql.SQLToolset` — curated
+  4-tool database toolset.
+
+Both implement pydantic-ai's
+`AbstractToolset <https://ai.pydantic.dev/toolsets/>`__ interface and can be
+passed to any pydantic-ai ``Agent``, including via
+:class:`~airflow.providers.common.ai.operators.agent.AgentOperator`.
+
+
+``HookToolset``
+---------------
+
+Generic adapter that exposes selected methods of any Airflow Hook as
+pydantic-ai tools via introspection. Requires an explicit ``allowed_methods``
+list — there is no auto-discovery.
+
+.. code-block:: python
+
+    from airflow.providers.http.hooks.http import HttpHook
+    from airflow.providers.common.ai.toolsets.hook import HookToolset
+
+    http_hook = HttpHook(http_conn_id="my_api")
+
+    toolset = HookToolset(
+        http_hook,
+        allowed_methods=["run"],
+        tool_name_prefix="http_",
+    )
+
+For each listed method, the introspection engine:
+
+1. Builds a JSON Schema from the method signature (``inspect.signature`` +
+   ``get_type_hints``).
+2. Extracts the description from the first paragraph of the docstring.
+3. Enriches parameter descriptions from Sphinx ``:param:`` or Google
+   ``Args:`` blocks.
+
+Parameters
+^^^^^^^^^^
+
+- ``hook``: An instantiated Airflow Hook.
+- ``allowed_methods``: Method names to expose as tools. Required. Methods
+  are validated with ``hasattr`` + ``callable`` at instantiation time.
+- ``tool_name_prefix``: Optional prefix prepended to each tool name
+  (e.g. ``"s3_"`` produces ``"s3_list_keys"``).
+
+
+``SQLToolset``
+--------------
+
+Curated toolset wrapping
+:class:`~airflow.providers.common.sql.hooks.sql.DbApiHook` with four tools:
+
+.. list-table::
+   :header-rows: 1
+   :widths: 20 50
+
+   * - Tool
+     - Description
+   * - ``list_tables``
+     - Lists available table names (filtered by ``allowed_tables`` if set)
+   * - ``get_schema``
+     - Returns column names and types for a table
+   * - ``query``
+     - Executes a SQL query and returns rows as JSON
+   * - ``check_query``
+     - Validates SQL syntax without executing it
+
+.. code-block:: python
+
+    from airflow.providers.common.ai.toolsets.sql import SQLToolset
+
+    toolset = SQLToolset(
+        db_conn_id="postgres_default",
+        allowed_tables=["customers", "orders"],
+        max_rows=20,
+    )
+
+The ``DbApiHook`` is resolved lazily from ``db_conn_id`` on first tool call
+via ``BaseHook.get_connection(conn_id).get_hook()``.
+
+Parameters
+^^^^^^^^^^
+
+- ``db_conn_id``: Airflow connection ID for the database.
+- ``allowed_tables``: Restrict which tables the agent can discover via
+  ``list_tables`` and ``get_schema``. ``None`` (default) exposes all tables.
+  See :ref:`allowed-tables-limitation` for an important caveat.
+- ``schema``: Database schema/namespace for table listing and introspection.
+- ``allow_writes``: Allow data-modifying SQL (INSERT, UPDATE, DELETE, etc.).
+  Default ``False`` — only SELECT-family statements are permitted.
+- ``max_rows``: Maximum rows returned from the ``query`` tool. Default ``50``.
+
+
+Security
+--------
+
+LLM agents call tools based on natural-language reasoning. This makes them
+powerful but introduces risks that don't exist with deterministic operators.
+
+Defense Layers
+^^^^^^^^^^^^^^
+
+No single layer is sufficient — they work together.
+
+.. list-table::
+   :header-rows: 1
+   :widths: 20 40 40
+
+   * - Layer
+     - What it does
+     - What it does NOT do
+   * - **Airflow Connections**
+     - Credentials are stored in Airflow's secret backend, never in DAG code.
+       The LLM agent cannot see API keys or database passwords.
+     - Does not prevent the agent from using the connection to access data
+       the connection has access to.
+   * - **HookToolset: explicit allow-list**
+     - Only methods listed in ``allowed_methods`` are exposed as tools.
+       Auto-discovery is not supported. Methods are validated at DAG parse
+       time.
+     - Does not restrict what arguments the agent passes to allowed methods.
+   * - **SQLToolset: read-only by default**
+     - ``allow_writes=False`` (default) validates every SQL query through
+       ``validate_sql()`` and rejects INSERT, UPDATE, DELETE, DROP, etc.
+     - Does not prevent the agent from reading sensitive data that the
+       database user has SELECT access to.
+   * - **SQLToolset: allowed_tables**
+     - Restricts which tables appear in ``list_tables`` and ``get_schema``
+       responses, limiting the agent's knowledge of the schema.
+     - Does **not** validate table references in SQL queries. The agent can
+       still query unlisted tables if it guesses the name. See
+       :ref:`allowed-tables-limitation` below.
+   * - **SQLToolset: max_rows**
+     - Truncates query results to ``max_rows`` (default 50), preventing the
+       agent from pulling entire tables into context.
+     - Does not limit the number of queries the agent can make.
+   * - **pydantic-ai: tool call budget**
+     - pydantic-ai's ``max_result_retries`` and ``model_settings`` control
+       how many tool-call rounds the agent can make before stopping.
+     - Requires explicit configuration — the default allows many rounds.
+
+
+.. _allowed-tables-limitation:
+
+The ``allowed_tables`` Limitation
+"""""""""""""""""""""""""""""""""
+
+``allowed_tables`` is a **metadata filter**, not an access control mechanism.
+It hides table names from ``list_tables`` and blocks ``get_schema`` for
+unlisted tables, but does not parse SQL queries to validate table references.
+
+An LLM can craft ``SELECT * FROM secrets`` even when
+``allowed_tables=["orders"]``. Parsing SQL for table references (including
+CTEs, subqueries, aliases, and vendor-specific syntax) is complex and
+error-prone; we chose not to provide a false sense of security.
+
+For query-level restrictions, use database permissions:
+
+.. code-block:: sql
+
+    -- Create a read-only role with access to specific tables only
+    CREATE ROLE airflow_agent_reader;
+    GRANT SELECT ON orders, customers TO airflow_agent_reader;
+    -- Use this role's credentials in the Airflow connection
+
+The Airflow connection should use a database user with the minimum privileges
+required.
+
+
+HookToolset Guidelines
+""""""""""""""""""""""
+
+- List only the methods the agent needs. Never expose ``run()`` or
+  ``get_connection()`` — these give broad access.
+- Prefer read-only methods (``list_*``, ``get_*``, ``describe_*``).
+- The agent controls arguments. If a method accepts a ``path`` parameter,
+  the agent can pass any path the hook has access to.
+
+.. code-block:: python
+
+    # Good: expose only list and read
+    HookToolset(
+        s3_hook,
+        allowed_methods=["list_keys", "read_key"],
+        tool_name_prefix="s3_",
+    )
+
+    # Bad: exposes delete and write operations
+    HookToolset(
+        s3_hook,
+        allowed_methods=["list_keys", "read_key", "delete_object", 
"load_string"],
+    )
+
+
+Recommended Configuration
+"""""""""""""""""""""""""
+
+**Read-only analytics** (the most common pattern):
+
+.. code-block:: python
+
+    SQLToolset(
+        db_conn_id="analytics_readonly",  # Connection with SELECT-only grants
+        allowed_tables=["orders", "customers"],  # Hide other tables from agent
+        allow_writes=False,  # Default — validates SQL
+        max_rows=50,  # Default — truncate large results
+    )
+
+**Agents that need to modify data** (use with caution):
+
+.. code-block:: python
+
+    SQLToolset(
+        db_conn_id="app_db",
+        allowed_tables=["user_preferences"],
+        allow_writes=True,  # Disables SQL validation — agent can INSERT/UPDATE
+        max_rows=100,
+    )
+
+
+Production Checklist
+""""""""""""""""""""
+
+Before deploying an agent task to production:
+
+1. **Connection credentials**: Use Airflow's secret backend. Never hardcode
+   API keys in DAG files.
+2. **Database permissions**: Create a dedicated database user with minimum
+   required grants. Don't reuse the admin connection.
+3. **Tool allow-list**: Review ``allowed_methods`` / ``allowed_tables``. The
+   agent can call any exposed tool with any arguments.
+4. **Read-only default**: Keep ``allow_writes=False`` unless the task
+   specifically requires writes.
+5. **Row limits**: Set ``max_rows`` appropriate to the use case. Large
+   result sets consume LLM context and increase cost.
+6. **Model budget**: Configure pydantic-ai's ``model_settings`` (e.g.
+   ``max_tokens``) and ``retries`` to bound cost and prevent runaway loops.
+7. **System prompt**: Include safety instructions in ``system_prompt`` (e.g.
+   "Only query tables related to the question. Never modify data.").
+8. **Prompt injection**: Be cautious when the prompt includes untrusted data
+   (user input, external API responses, upstream XCom). Consider sanitizing
+   inputs before passing them to the agent.
diff --git 
a/providers/common/ai/src/airflow/providers/common/ai/toolsets/__init__.py 
b/providers/common/ai/src/airflow/providers/common/ai/toolsets/__init__.py
new file mode 100644
index 00000000000..aba5a45ee07
--- /dev/null
+++ b/providers/common/ai/src/airflow/providers/common/ai/toolsets/__init__.py
@@ -0,0 +1,35 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Toolsets for exposing Airflow hooks as pydantic-ai agent tools."""
+
+from __future__ import annotations
+
+from airflow.providers.common.ai.toolsets.hook import HookToolset
+
+__all__ = ["HookToolset", "SQLToolset"]
+
+
def __getattr__(name: str):
    # PEP 562 lazy-attribute hook: importing SQLToolset is deferred so this
    # package imports cleanly even when the optional common.sql dependency
    # is missing.
    if name != "SQLToolset":
        raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
    try:
        from airflow.providers.common.ai.toolsets.sql import SQLToolset
    except ImportError as e:
        from airflow.providers.common.compat.sdk import AirflowOptionalProviderFeatureException

        raise AirflowOptionalProviderFeatureException(e)
    return SQLToolset
diff --git 
a/providers/common/ai/src/airflow/providers/common/ai/toolsets/hook.py 
b/providers/common/ai/src/airflow/providers/common/ai/toolsets/hook.py
new file mode 100644
index 00000000000..ae3987b6c0e
--- /dev/null
+++ b/providers/common/ai/src/airflow/providers/common/ai/toolsets/hook.py
@@ -0,0 +1,267 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Generic adapter that exposes Airflow Hook methods as pydantic-ai tools."""
+
+from __future__ import annotations
+
+import inspect
+import json
+import re
+import types
+from typing import TYPE_CHECKING, Any, Union, get_args, get_origin, 
get_type_hints
+
+from pydantic_ai.tools import ToolDefinition
+from pydantic_ai.toolsets.abstract import AbstractToolset, ToolsetTool
+from pydantic_core import SchemaValidator, core_schema
+
+if TYPE_CHECKING:
+    from collections.abc import Callable
+
+    from pydantic_ai._run_context import RunContext
+
+    from airflow.providers.common.compat.sdk import BaseHook
+
# Single shared validator — accepts any JSON-decoded dict from the LLM.
# Argument validation is deliberately a pass-through: the JSON Schema in each
# ToolDefinition guides the model, and the hook method itself rejects bad args.
_PASSTHROUGH_VALIDATOR = SchemaValidator(core_schema.any_schema())

# Maps Python types to JSON Schema fragments.
# NOTE: callers must copy a fragment before mutating it (e.g. adding a
# "description" key) — see _python_type_to_json_schema, which returns dict().
_TYPE_MAP: dict[type, dict[str, Any]] = {
    str: {"type": "string"},
    int: {"type": "integer"},
    float: {"type": "number"},
    bool: {"type": "boolean"},
    list: {"type": "array"},
    dict: {"type": "object"},
    bytes: {"type": "string"},  # bytes are surfaced to the LLM as strings
}
+
+
class HookToolset(AbstractToolset[Any]):
    """
    Adapter that turns selected methods of an Airflow Hook into pydantic-ai tools.

    Each method named in ``allowed_methods`` is introspected — signature, type
    hints, and docstring — to produce a
    :class:`~pydantic_ai.tools.ToolDefinition` that an LLM agent can call.

    :param hook: An instantiated Airflow Hook.
    :param allowed_methods: Method names to expose as tools. Required —
        auto-discovery is intentionally not supported for safety.
    :param tool_name_prefix: Optional prefix prepended to each tool name
        (e.g. ``"s3_"`` → ``"s3_list_keys"``).
    """

    def __init__(
        self,
        hook: BaseHook,
        *,
        allowed_methods: list[str],
        tool_name_prefix: str = "",
    ) -> None:
        if not allowed_methods:
            raise ValueError("allowed_methods must be a non-empty list.")

        cls_name = type(hook).__name__
        # Validate eagerly so a bad allow-list fails at instantiation
        # (DAG parse) time instead of at tool-call time.
        for candidate in allowed_methods:
            if not hasattr(hook, candidate):
                raise ValueError(
                    f"Hook {cls_name!r} has no method {candidate!r}. Check your allowed_methods list."
                )
            if not callable(getattr(hook, candidate)):
                raise ValueError(f"{cls_name}.{candidate} is not callable.")

        self._hook = hook
        self._allowed_methods = allowed_methods
        self._tool_name_prefix = tool_name_prefix
        self._id = f"hook-{cls_name}"

    @property
    def id(self) -> str:
        # Stable identifier for pydantic-ai: "hook-<HookClassName>".
        return self._id

    async def get_tools(self, ctx: RunContext[Any]) -> dict[str, ToolsetTool[Any]]:
        """Build one :class:`ToolsetTool` per allowed hook method."""
        prefix = self._tool_name_prefix
        tools: dict[str, ToolsetTool[Any]] = {}

        for method_name in self._allowed_methods:
            bound_method = getattr(self._hook, method_name)
            tool_name = prefix + method_name if prefix else method_name

            schema = _build_json_schema_from_signature(bound_method)

            # Merge docstring parameter descriptions into the schema properties.
            properties = schema.get("properties", {})
            for arg_name, arg_doc in _parse_param_docs(bound_method.__doc__ or "").items():
                if arg_name in properties:
                    properties[arg_name]["description"] = arg_doc

            tools[tool_name] = ToolsetTool(
                toolset=self,
                tool_def=ToolDefinition(
                    name=tool_name,
                    description=_extract_description(bound_method),
                    parameters_json_schema=schema,
                    # Hook methods perform synchronous I/O (network calls, DB
                    # queries) and must not run concurrently.
                    sequential=True,
                ),
                max_retries=1,
                args_validator=_PASSTHROUGH_VALIDATOR,
            )
        return tools

    async def call_tool(
        self,
        name: str,
        tool_args: dict[str, Any],
        ctx: RunContext[Any],
        tool: ToolsetTool[Any],
    ) -> Any:
        """Invoke the hook method behind *name* and serialize its result."""
        # str.removeprefix is a no-op when the prefix is empty.
        method_name = name.removeprefix(self._tool_name_prefix)
        bound_method: Callable[..., Any] = getattr(self._hook, method_name)
        return _serialize_for_llm(bound_method(**tool_args))
+
+
+# ---------------------------------------------------------------------------
+# Private introspection helpers
+# ---------------------------------------------------------------------------
+
+
def _python_type_to_json_schema(annotation: Any) -> dict[str, Any]:
    """Translate a Python type annotation into a JSON Schema fragment.

    Unknown, missing, or overly-broad annotations degrade to
    ``{"type": "string"}`` so that every parameter still gets a usable schema.
    """
    if annotation is Any or annotation is inspect.Parameter.empty:
        return {"type": "string"}

    origin = get_origin(annotation)
    members = get_args(annotation)

    # Unions: unwrap Optional[X] (covers both typing.Union and the 3.10+
    # X | Y form). Multi-type unions fall back to string.
    if origin in (Union, types.UnionType):
        concrete = [m for m in members if m is not type(None)]
        if len(concrete) == 1:
            return _python_type_to_json_schema(concrete[0])
        return {"type": "string"}

    # Parameterized containers: list[X] keeps its item schema, dict collapses.
    if origin is list:
        item_schema = _python_type_to_json_schema(members[0]) if members else {"type": "string"}
        return {"type": "array", "items": item_schema}
    if origin is dict:
        return {"type": "object"}

    # Copy the mapped fragment — callers may attach a "description" key.
    mapped = _TYPE_MAP.get(annotation)
    return dict(mapped) if mapped is not None else {"type": "string"}
+
+
def _build_json_schema_from_signature(method: Callable[..., Any]) -> dict[str, Any]:
    """Derive a JSON Schema ``object`` describing *method*'s call parameters.

    ``self``/``cls`` and variadic (``*args``/``**kwargs``) parameters are
    omitted; parameters without defaults are marked required.
    """
    signature = inspect.signature(method)

    try:
        # get_type_hints resolves string annotations; fall back to the raw
        # annotations when resolution fails (e.g. unresolvable forward refs).
        resolved_hints = get_type_hints(method)
    except Exception:
        resolved_hints = {}

    skip_kinds = (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD)
    props: dict[str, Any] = {}
    mandatory: list[str] = []

    for param_name, param in signature.parameters.items():
        if param_name in ("self", "cls") or param.kind in skip_kinds:
            continue

        annotation = resolved_hints.get(param_name, param.annotation)
        props[param_name] = _python_type_to_json_schema(annotation)

        if param.default is inspect.Parameter.empty:
            mandatory.append(param_name)

    built: dict[str, Any] = {"type": "object", "properties": props}
    if mandatory:
        built["required"] = mandatory
    return built
+
+
def _extract_description(method: Callable[..., Any]) -> str:
    """Derive a tool description from *method*'s docstring.

    Uses the first paragraph (text up to the first blank line) of the cleaned
    docstring; falls back to a capitalized form of the method name.
    """
    fallback = method.__name__.replace("_", " ").capitalize()

    doc = inspect.getdoc(method)
    if not doc:
        return fallback

    paragraph: list[str] = []
    for raw in doc.splitlines():
        text = raw.strip()
        if text:
            paragraph.append(text)
        elif paragraph:
            # A blank line after collected text ends the first paragraph;
            # leading blank lines are simply skipped.
            break
    return " ".join(paragraph) if paragraph else fallback
+
+
# Matches Sphinx-style `:param name:` and Google-style `name:` under an ``Args:`` block.
_SPHINX_PARAM_RE = re.compile(r":param\s+(\w+):\s*(.+?)(?=\n\s*:|$)", re.DOTALL)
_GOOGLE_ARGS_RE = re.compile(r"^\s{2,}(\w+)\s*(?:\(.+?\))?:\s*(.+)", re.MULTILINE)

# Google-style section headers that terminate the ``Args:`` block. Without this
# check, an inline section such as ``Returns: the data`` would be captured as a
# parameter named ``Returns``.
_GOOGLE_SECTION_HEADERS = frozenset(
    {"args", "arguments", "returns", "return", "yields", "raises", "attributes", "example", "examples", "note", "notes"}
)


def _parse_param_docs(docstring: str) -> dict[str, str]:
    """Parse parameter descriptions from Sphinx or Google-style docstrings.

    Sphinx ``:param name:`` fields take precedence; if none are found, the
    Google ``Args:`` section is scanned instead. Descriptions are collapsed to
    single-line, whitespace-normalized strings.

    :param docstring: Raw docstring text (may be empty).
    :return: Mapping of parameter name to description.
    """
    params: dict[str, str] = {}

    # Try Sphinx style first.
    for match in _SPHINX_PARAM_RE.finditer(docstring):
        params[match.group(1)] = " ".join(match.group(2).split())
    if params:
        return params

    # Fall back to Google style (``Args:`` section).
    in_args = False
    for line in docstring.splitlines():
        stripped = line.strip()
        if stripped.lower().startswith("args:"):
            in_args = True
            continue
        if not in_args:
            continue

        # A new section header (e.g. "Returns:" / "Raises: ...") ends the
        # Args block — even when it carries inline text after the colon.
        head = stripped.split(":", 1)[0].strip().lower()
        if stripped and head in _GOOGLE_SECTION_HEADERS:
            break
        # Any other non-blank, colon-free line also ends the block.
        if stripped and ":" not in stripped:
            break

        m = _GOOGLE_ARGS_RE.match(line)
        if m:
            params[m.group(1)] = " ".join(m.group(2).split())

    return params
+
+
def _serialize_for_llm(value: Any) -> str:
    """Render a hook method's return value as text for the LLM.

    Strings pass through untouched, ``None`` becomes ``"null"``, and anything
    else is JSON-encoded (falling back to ``str(value)`` when the value is not
    JSON-serializable at all).
    """
    if isinstance(value, str):
        return value
    if value is None:
        return "null"
    try:
        # default=str stringifies individual non-JSON members (dates, etc.).
        return json.dumps(value, default=str)
    except (TypeError, ValueError):
        return str(value)
diff --git 
a/providers/common/ai/src/airflow/providers/common/ai/toolsets/sql.py 
b/providers/common/ai/src/airflow/providers/common/ai/toolsets/sql.py
new file mode 100644
index 00000000000..f60f4b621c3
--- /dev/null
+++ b/providers/common/ai/src/airflow/providers/common/ai/toolsets/sql.py
@@ -0,0 +1,231 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Curated SQL toolset wrapping DbApiHook for agentic database workflows."""
+
+from __future__ import annotations
+
+import json
+from typing import TYPE_CHECKING, Any
+
+try:
+    from airflow.providers.common.ai.utils.sql_validation import validate_sql 
as _validate_sql
+    from airflow.providers.common.sql.hooks.sql import DbApiHook
+except ImportError as e:
+    from airflow.providers.common.compat.sdk import 
AirflowOptionalProviderFeatureException
+
+    raise AirflowOptionalProviderFeatureException(e)
+
+from pydantic_ai.tools import ToolDefinition
+from pydantic_ai.toolsets.abstract import AbstractToolset, ToolsetTool
+from pydantic_core import SchemaValidator, core_schema
+
+from airflow.providers.common.compat.sdk import BaseHook
+
+if TYPE_CHECKING:
+    from pydantic_ai._run_context import RunContext
+
# Accepts any JSON-decoded arguments from the model without re-validation;
# the JSON Schemas below guide the LLM but are not enforced at call time.
_PASSTHROUGH_VALIDATOR = SchemaValidator(core_schema.any_schema())

# JSON Schemas for the four SQL tools.
_LIST_TABLES_SCHEMA: dict[str, Any] = {
    "type": "object",
    "properties": {},  # list_tables takes no arguments
}

_GET_SCHEMA_SCHEMA: dict[str, Any] = {
    "type": "object",
    "properties": {
        "table_name": {"type": "string", "description": "Name of the table to inspect."},
    },
    "required": ["table_name"],
}

_QUERY_SCHEMA: dict[str, Any] = {
    "type": "object",
    "properties": {
        "sql": {"type": "string", "description": "SQL query to execute."},
    },
    "required": ["sql"],
}

_CHECK_QUERY_SCHEMA: dict[str, Any] = {
    "type": "object",
    "properties": {
        "sql": {"type": "string", "description": "SQL query to validate."},
    },
    "required": ["sql"],
}
+
+
class SQLToolset(AbstractToolset[Any]):
    """
    Curated toolset that gives an LLM agent safe access to a SQL database.

    Provides four tools — ``list_tables``, ``get_schema``, ``query``, and
    ``check_query`` — inspired by LangChain's ``SQLDatabaseToolkit`` pattern.

    Uses a :class:`~airflow.providers.common.sql.hooks.sql.DbApiHook` resolved
    lazily from the given ``db_conn_id``.

    :param db_conn_id: Airflow connection ID for the database.
    :param allowed_tables: Restrict which tables the agent can discover via
        ``list_tables`` and ``get_schema``. ``None`` (default) exposes all tables.

        .. note::
            ``allowed_tables`` controls metadata visibility only. It does **not**
            parse or validate table references in SQL queries. An LLM can still
            query tables outside this list if it guesses the name. For query-level
            restrictions, use database-level permissions (e.g. a read-only role
            with grants limited to specific tables).

    :param schema: Database schema/namespace for table listing and introspection.
    :param allow_writes: Allow data-modifying SQL (INSERT, UPDATE, DELETE, etc.).
        Default ``False`` — only SELECT-family statements are permitted.
    :param max_rows: Maximum number of rows returned from the ``query`` tool.
        Default ``50``.
    """

    def __init__(
        self,
        db_conn_id: str,
        *,
        allowed_tables: list[str] | None = None,
        schema: str | None = None,
        allow_writes: bool = False,
        max_rows: int = 50,
    ) -> None:
        self._db_conn_id = db_conn_id
        # frozenset gives O(1) membership checks; note an EMPTY list is treated
        # the same as None (no filtering) because of the truthiness test.
        self._allowed_tables: frozenset[str] | None = frozenset(allowed_tables) if allowed_tables else None
        self._schema = schema
        self._allow_writes = allow_writes
        self._max_rows = max_rows
        # Resolved lazily on first tool call — see _get_db_hook().
        self._hook: DbApiHook | None = None

    @property
    def id(self) -> str:
        # Identifier for pydantic-ai; unique per connection ID.
        return f"sql-{self._db_conn_id}"

    # ------------------------------------------------------------------
    # Lazy hook resolution
    # ------------------------------------------------------------------

    def _get_db_hook(self) -> DbApiHook:
        # Cached after the first resolution. NOTE(review): no lock guards this;
        # presumably safe because all tools are declared sequential=True — confirm.
        if self._hook is None:
            connection = BaseHook.get_connection(self._db_conn_id)
            hook = connection.get_hook()
            if not isinstance(hook, DbApiHook):
                raise ValueError(
                    f"Connection {self._db_conn_id!r} does not provide a DbApiHook. "
                    f"Got {type(hook).__name__}."
                )
            self._hook = hook
        return self._hook

    # ------------------------------------------------------------------
    # AbstractToolset interface
    # ------------------------------------------------------------------

    async def get_tools(self, ctx: RunContext[Any]) -> dict[str, ToolsetTool[Any]]:
        """Build the four fixed tool definitions with their static schemas."""
        tools: dict[str, ToolsetTool[Any]] = {}

        for name, description, schema in (
            ("list_tables", "List available table names in the database.", _LIST_TABLES_SCHEMA),
            ("get_schema", "Get column names and types for a table.", _GET_SCHEMA_SCHEMA),
            ("query", "Execute a SQL query and return rows as JSON.", _QUERY_SCHEMA),
            ("check_query", "Validate SQL syntax without executing it.", _CHECK_QUERY_SCHEMA),
        ):
            # sequential=True because all tools use a shared DbApiHook with
            # synchronous I/O — they must not run concurrently.
            tool_def = ToolDefinition(
                name=name,
                description=description,
                parameters_json_schema=schema,
                sequential=True,
            )
            tools[name] = ToolsetTool(
                toolset=self,
                tool_def=tool_def,
                max_retries=1,
                args_validator=_PASSTHROUGH_VALIDATOR,
            )
        return tools

    async def call_tool(
        self,
        name: str,
        tool_args: dict[str, Any],
        ctx: RunContext[Any],
        tool: ToolsetTool[Any],
    ) -> Any:
        # Dispatch to the private implementation; unknown names indicate a
        # mismatch between get_tools() and this table.
        if name == "list_tables":
            return self._list_tables()
        if name == "get_schema":
            return self._get_schema(tool_args["table_name"])
        if name == "query":
            return self._query(tool_args["sql"])
        if name == "check_query":
            return self._check_query(tool_args["sql"])
        raise ValueError(f"Unknown tool: {name!r}")

    # ------------------------------------------------------------------
    # Tool implementations
    # ------------------------------------------------------------------

    def _list_tables(self) -> str:
        # Table names come from SQLAlchemy inspection; the allow-list filter
        # here is metadata visibility only (see class docstring note).
        hook = self._get_db_hook()
        tables: list[str] = hook.inspector.get_table_names(schema=self._schema)
        if self._allowed_tables is not None:
            tables = [t for t in tables if t in self._allowed_tables]
        return json.dumps(tables)

    def _get_schema(self, table_name: str) -> str:
        # Errors are returned as JSON data (not raised) so the agent can read
        # them and adjust its next tool call.
        if self._allowed_tables is not None and table_name not in self._allowed_tables:
            return json.dumps({"error": f"Table {table_name!r} is not in the allowed tables list."})
        hook = self._get_db_hook()
        columns = hook.get_table_schema(table_name, schema=self._schema)
        return json.dumps(columns)

    def _query(self, sql: str) -> str:
        # Read-only mode rejects non-SELECT statements before touching the DB.
        if not self._allow_writes:
            _validate_sql(sql)

        hook = self._get_db_hook()
        rows = hook.get_records(sql)
        # Fetch column names from cursor description.
        col_names: list[str] | None = None
        if hook.last_description:
            col_names = [desc[0] for desc in hook.last_description]

        result: list[dict[str, Any]] | list[list[Any]]
        if rows and col_names:
            # Preferred shape: one JSON object per row keyed by column name.
            result = [dict(zip(col_names, row)) for row in rows[: self._max_rows]]
        else:
            # No description available — fall back to positional lists.
            result = [list(row) for row in (rows or [])[: self._max_rows]]

        # "count" reflects the FULL result size even when rows are truncated,
        # so the agent knows more data exists.
        truncated = len(rows or []) > self._max_rows
        output: dict[str, Any] = {"rows": result, "count": len(rows or [])}
        if truncated:
            output["truncated"] = True
            output["max_rows"] = self._max_rows
        # default=str stringifies non-JSON-serializable cell values (dates etc.).
        return json.dumps(output, default=str)

    def _check_query(self, sql: str) -> str:
        # Reports validation failures as data rather than raising, so the agent
        # can self-correct. NOTE(review): validation runs here even when
        # allow_writes=True — confirm that is intended.
        try:
            _validate_sql(sql)
            return json.dumps({"valid": True})
        except Exception as e:
            return json.dumps({"valid": False, "error": str(e)})
diff --git a/providers/common/ai/tests/unit/common/ai/toolsets/__init__.py 
b/providers/common/ai/tests/unit/common/ai/toolsets/__init__.py
new file mode 100644
index 00000000000..13a83393a91
--- /dev/null
+++ b/providers/common/ai/tests/unit/common/ai/toolsets/__init__.py
@@ -0,0 +1,16 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
diff --git a/providers/common/ai/tests/unit/common/ai/toolsets/test_hook.py 
b/providers/common/ai/tests/unit/common/ai/toolsets/test_hook.py
new file mode 100644
index 00000000000..2a40bdf0c4b
--- /dev/null
+++ b/providers/common/ai/tests/unit/common/ai/toolsets/test_hook.py
@@ -0,0 +1,281 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import asyncio
+from unittest.mock import MagicMock
+
+import pytest
+
+from airflow.providers.common.ai.toolsets.hook import (
+    HookToolset,
+    _build_json_schema_from_signature,
+    _extract_description,
+    _parse_param_docs,
+    _serialize_for_llm,
+)
+
+
class _FakeHook:
    """Fake hook for testing HookToolset introspection.

    The docstrings below are test fixtures: the toolset derives tool
    descriptions and parameter docs from them, so the tests assert on
    their exact content — do not reword them.
    """

    def list_keys(self, bucket: str, prefix: str = "") -> list[str]:
        """List object keys in a bucket.

        :param bucket: Name of the S3 bucket.
        :param prefix: Key prefix to filter by.
        """
        return [prefix + "file1.txt", prefix + "file2.txt"]

    def read_file(self, key: str) -> str:
        """Read a file from storage."""
        return "contents of " + key

    # Intentionally left undocumented: exercises the description fallback.
    def no_docstring(self, x: int) -> int:
        return x * 2
+
+
class TestHookToolsetInit:
    """Constructor validation for HookToolset."""

    def test_requires_non_empty_allowed_methods(self):
        # An explicit allow-list is mandatory; an empty one is a config error.
        with pytest.raises(ValueError, match="non-empty"):
            HookToolset(MagicMock(), allowed_methods=[])

    def test_rejects_nonexistent_method(self):
        hook = _FakeHook()
        with pytest.raises(ValueError, match="has no method 'nonexistent'"):
            HookToolset(hook, allowed_methods=["nonexistent"])

    def test_rejects_non_callable_attribute(self):
        # MagicMock attributes are callable by default, so use a real object
        # to exercise the "attribute exists but is not callable" branch.
        class HookWithAttr:
            data = [1, 2, 3]

        with pytest.raises(ValueError, match="not callable"):
            HookToolset(HookWithAttr(), allowed_methods=["data"])

    def test_id_includes_hook_class_name(self):
        hook = _FakeHook()
        ts = HookToolset(hook, allowed_methods=["list_keys"])
        assert "FakeHook" in ts.id
+
+
class TestHookToolsetGetTools:
    """Tool discovery and JSON-schema generation via get_tools()."""

    @staticmethod
    def _tools(toolset):
        """Resolve the toolset's tools synchronously for assertions."""
        return asyncio.run(toolset.get_tools(ctx=MagicMock()))

    def test_returns_tools_for_allowed_methods(self):
        toolset = HookToolset(_FakeHook(), allowed_methods=["list_keys", "read_file"])
        assert set(self._tools(toolset).keys()) == {"list_keys", "read_file"}

    def test_tool_definitions_have_correct_schemas(self):
        toolset = HookToolset(_FakeHook(), allowed_methods=["list_keys"])
        tool_def = self._tools(toolset)["list_keys"].tool_def
        schema = tool_def.parameters_json_schema

        assert tool_def.name == "list_keys"
        assert "bucket" in schema["properties"]
        assert "prefix" in schema["properties"]
        assert "bucket" in schema["required"]
        # prefix has a default, so it's not required
        assert "prefix" not in schema.get("required", [])

    def test_tool_name_prefix(self):
        toolset = HookToolset(_FakeHook(), allowed_methods=["list_keys"], tool_name_prefix="s3_")
        assert "s3_list_keys" in self._tools(toolset)

    def test_description_from_docstring(self):
        toolset = HookToolset(_FakeHook(), allowed_methods=["list_keys"])
        tool = self._tools(toolset)["list_keys"]
        assert tool.tool_def.description == "List object keys in a bucket."

    def test_description_fallback_for_no_docstring(self):
        toolset = HookToolset(_FakeHook(), allowed_methods=["no_docstring"])
        tool = self._tools(toolset)["no_docstring"]
        assert tool.tool_def.description == "No docstring"

    def test_tools_are_sequential(self):
        # sequential=True prevents concurrent execution of sync hook I/O.
        toolset = HookToolset(_FakeHook(), allowed_methods=["list_keys"])
        assert self._tools(toolset)["list_keys"].tool_def.sequential is True

    def test_param_docs_enriched_in_schema(self):
        toolset = HookToolset(_FakeHook(), allowed_methods=["list_keys"])
        props = self._tools(toolset)["list_keys"].tool_def.parameters_json_schema["properties"]
        assert "description" in props["bucket"]
        assert "S3 bucket" in props["bucket"]["description"]
+
+
class TestHookToolsetCallTool:
    """Dispatching tool calls to the underlying hook method."""

    def test_dispatches_to_hook_method(self):
        toolset = HookToolset(_FakeHook(), allowed_methods=["list_keys"])
        tools = asyncio.run(toolset.get_tools(ctx=MagicMock()))

        output = asyncio.run(
            toolset.call_tool(
                "list_keys",
                {"bucket": "my-bucket", "prefix": "data/"},
                ctx=MagicMock(),
                tool=tools["list_keys"],
            )
        )
        assert "data/file1.txt" in output

    def test_dispatches_with_prefix(self):
        # The prefixed public tool name must map back to the raw hook method.
        toolset = HookToolset(_FakeHook(), allowed_methods=["read_file"], tool_name_prefix="storage_")
        tools = asyncio.run(toolset.get_tools(ctx=MagicMock()))

        output = asyncio.run(
            toolset.call_tool(
                "storage_read_file",
                {"key": "test.txt"},
                ctx=MagicMock(),
                tool=tools["storage_read_file"],
            )
        )
        assert output == "contents of test.txt"
+
+
class TestBuildJsonSchemaFromSignature:
    """JSON Schema construction from Python callable signatures."""

    def test_basic_types(self):
        def fn(name: str, count: int, rate: float, active: bool):
            pass

        schema = _build_json_schema_from_signature(fn)
        expected = {
            "name": {"type": "string"},
            "count": {"type": "integer"},
            "rate": {"type": "number"},
            "active": {"type": "boolean"},
        }
        for param, prop in expected.items():
            assert schema["properties"][param] == prop
        assert set(schema["required"]) == set(expected)

    def test_optional_params_not_required(self):
        def fn(name: str, prefix: str = ""):
            pass

        # Parameters with defaults are optional in the generated schema.
        assert _build_json_schema_from_signature(fn)["required"] == ["name"]

    def test_list_type(self):
        def fn(items: list[str]):
            pass

        schema = _build_json_schema_from_signature(fn)
        assert schema["properties"]["items"] == {"type": "array", "items": {"type": "string"}}

    def test_no_annotation_defaults_to_string(self):
        def fn(x):
            pass

        # Unannotated parameters fall back to string, the safest LLM-facing type.
        assert _build_json_schema_from_signature(fn)["properties"]["x"] == {"type": "string"}

    def test_skips_self_and_cls(self):
        class Foo:
            def method(self, x: int):
                pass

        assert "self" not in _build_json_schema_from_signature(Foo().method)["properties"]

    def test_skips_var_args(self):
        def fn(x: int, *args, **kwargs):
            pass

        # *args/**kwargs cannot be expressed as fixed schema properties.
        assert set(_build_json_schema_from_signature(fn)["properties"].keys()) == {"x"}
+
+
class TestExtractDescription:
    """Extraction of tool descriptions from method docstrings.

    NOTE: the inner functions' docstrings are the parser's input — their
    exact wording and blank lines are part of the test fixture.
    """

    def test_first_paragraph(self):
        def fn():
            """First paragraph.

            Second paragraph with details.
            """

        # Only the first paragraph becomes the tool description.
        assert _extract_description(fn) == "First paragraph."

    def test_multiline_first_paragraph(self):
        def fn():
            """First line of
            the first paragraph.

            Second paragraph.
            """

        # Wrapped lines within the first paragraph are joined with spaces.
        assert _extract_description(fn) == "First line of the first paragraph."

    def test_no_docstring_uses_method_name(self):
        def some_method():
            pass

        # Without a docstring, the humanized method name is the fallback.
        assert _extract_description(some_method) == "Some method"
+
+
class TestParseParamDocs:
    """Parsing of per-parameter docs from docstrings.

    NOTE: the docstring literals are the parser's input — their exact
    indentation and markers are part of the test fixture.
    """

    def test_sphinx_style(self):
        # reST/Sphinx ":param name: ..." markers.
        docstring = """Do something.

        :param name: The name of the thing.
        :param count: How many items.
        """
        result = _parse_param_docs(docstring)
        assert result["name"] == "The name of the thing."
        assert result["count"] == "How many items."

    def test_google_style(self):
        # Google-style "Args:" section with indented entries.
        docstring = """Do something.

        Args:
            name: The name of the thing.
            count: How many items.
        """
        result = _parse_param_docs(docstring)
        assert result["name"] == "The name of the thing."
        assert result["count"] == "How many items."
+
+
class TestSerializeForLlm:
    """Serialization of hook return values into LLM-friendly strings."""

    def test_string_passthrough(self):
        assert _serialize_for_llm("hello") == "hello"

    def test_none_returns_null(self):
        assert _serialize_for_llm(None) == "null"

    def test_dict_to_json(self):
        assert _serialize_for_llm({"key": "value"}) == '{"key": "value"}'

    def test_list_to_json(self):
        assert _serialize_for_llm([1, 2, 3]) == "[1, 2, 3]"

    def test_non_serializable_falls_back_to_str(self):
        # Objects json cannot encode are rendered via str() instead of raising.
        assert "object" in _serialize_for_llm(object())
diff --git a/providers/common/ai/tests/unit/common/ai/toolsets/test_sql.py 
b/providers/common/ai/tests/unit/common/ai/toolsets/test_sql.py
new file mode 100644
index 00000000000..0573acd2a77
--- /dev/null
+++ b/providers/common/ai/tests/unit/common/ai/toolsets/test_sql.py
@@ -0,0 +1,233 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import asyncio
+import json
+from unittest.mock import MagicMock, PropertyMock, patch
+
+import pytest
+
+from airflow.providers.common.ai.toolsets.sql import SQLToolset
+from airflow.providers.common.ai.utils.sql_validation import SQLSafetyError
+
+
def _make_mock_db_hook(
    table_names: list[str] | None = None,
    table_schema: list[dict[str, str]] | None = None,
    records: list[tuple] | None = None,
    last_description: list[tuple] | None = None,
):
    """Create a mock DbApiHook with sensible defaults.

    Only ``None`` (not merely falsy) selects a default, so callers can
    pass an explicit empty list to simulate empty query results or an
    empty catalog — ``x or default`` would silently replace ``[]``.
    """
    from airflow.providers.common.sql.hooks.sql import DbApiHook

    mock = MagicMock(spec=DbApiHook)
    mock.inspector = MagicMock()
    mock.inspector.get_table_names.return_value = (
        table_names if table_names is not None else ["users", "orders"]
    )
    mock.get_table_schema.return_value = (
        table_schema
        if table_schema is not None
        else [{"name": "id", "type": "INTEGER"}, {"name": "name", "type": "VARCHAR"}]
    )
    mock.get_records.return_value = records if records is not None else [(1, "Alice"), (2, "Bob")]
    # last_description is a property on DbApiHook, so mock it as one.
    type(mock).last_description = PropertyMock(
        return_value=last_description if last_description is not None else [("id",), ("name",)]
    )
    return mock
+
+
class TestSQLToolsetInit:
    """Toolset identity."""

    def test_id_includes_conn_id(self):
        toolset = SQLToolset("my_pg")
        assert toolset.id == "sql-my_pg"
+
+
class TestSQLToolsetGetTools:
    """Tool discovery for the curated SQL toolset."""

    def test_returns_four_tools(self):
        toolset = SQLToolset("pg_default")
        tools = asyncio.run(toolset.get_tools(ctx=MagicMock()))
        assert set(tools.keys()) == {"list_tables", "get_schema", "query", "check_query"}

    def test_tool_definitions_have_descriptions(self):
        toolset = SQLToolset("pg_default")
        tools = asyncio.run(toolset.get_tools(ctx=MagicMock()))
        assert all(tool.tool_def.description for tool in tools.values())
+
+
class TestSQLToolsetListTables:
    """The list_tables tool."""

    @staticmethod
    def _list_tables(toolset):
        """Invoke the list_tables tool and decode its JSON output."""
        raw = asyncio.run(toolset.call_tool("list_tables", {}, ctx=MagicMock(), tool=MagicMock()))
        return json.loads(raw)

    def test_returns_all_tables(self):
        toolset = SQLToolset("pg_default")
        toolset._hook = _make_mock_db_hook(table_names=["users", "orders", "products"])
        assert self._list_tables(toolset) == ["users", "orders", "products"]

    def test_filters_by_allowed_tables(self):
        # Tables outside the allow-list are hidden from the LLM entirely.
        toolset = SQLToolset("pg_default", allowed_tables=["orders"])
        toolset._hook = _make_mock_db_hook(table_names=["users", "orders", "products"])
        assert self._list_tables(toolset) == ["orders"]
+
+
class TestSQLToolsetGetSchema:
    """The get_schema tool."""

    @staticmethod
    def _get_schema(toolset, table_name):
        """Invoke the get_schema tool and decode its JSON output."""
        raw = asyncio.run(
            toolset.call_tool("get_schema", {"table_name": table_name}, ctx=MagicMock(), tool=MagicMock())
        )
        return json.loads(raw)

    def test_returns_column_info(self):
        toolset = SQLToolset("pg_default")
        mock_hook = _make_mock_db_hook()
        toolset._hook = mock_hook

        columns = self._get_schema(toolset, "users")
        assert columns == [{"name": "id", "type": "INTEGER"}, {"name": "name", "type": "VARCHAR"}]
        mock_hook.get_table_schema.assert_called_once_with("users", schema=None)

    def test_blocks_table_not_in_allowed_list(self):
        # Requests outside the allow-list return a structured error payload.
        toolset = SQLToolset("pg_default", allowed_tables=["orders"])
        toolset._hook = _make_mock_db_hook()

        payload = self._get_schema(toolset, "secrets")
        assert "error" in payload
        assert "secrets" in payload["error"]
+
+
class TestSQLToolsetQuery:
    """The query tool: execution, truncation, and write protection."""

    @staticmethod
    def _run_query(toolset, sql):
        """Invoke the query tool and decode its JSON output."""
        raw = asyncio.run(toolset.call_tool("query", {"sql": sql}, ctx=MagicMock(), tool=MagicMock()))
        return json.loads(raw)

    def test_returns_rows_as_json(self):
        toolset = SQLToolset("pg_default")
        toolset._hook = _make_mock_db_hook(
            records=[(1, "Alice"), (2, "Bob")],
            last_description=[("id",), ("name",)],
        )

        payload = self._run_query(toolset, "SELECT id, name FROM users")
        assert payload["rows"] == [{"id": 1, "name": "Alice"}, {"id": 2, "name": "Bob"}]
        assert payload["count"] == 2

    def test_truncates_at_max_rows(self):
        toolset = SQLToolset("pg_default", max_rows=1)
        toolset._hook = _make_mock_db_hook(
            records=[(1, "Alice"), (2, "Bob"), (3, "Charlie")],
            last_description=[("id",), ("name",)],
        )

        payload = self._run_query(toolset, "SELECT id, name FROM users")
        assert len(payload["rows"]) == 1
        assert payload["truncated"] is True
        # count reflects the full result set, not the truncated slice.
        assert payload["count"] == 3

    def test_blocks_unsafe_sql_by_default(self):
        toolset = SQLToolset("pg_default")
        toolset._hook = _make_mock_db_hook()

        with pytest.raises(SQLSafetyError, match="not allowed"):
            asyncio.run(
                toolset.call_tool("query", {"sql": "DROP TABLE users"}, ctx=MagicMock(), tool=MagicMock())
            )

    def test_allows_writes_when_enabled(self):
        toolset = SQLToolset("pg_default", allow_writes=True)
        toolset._hook = _make_mock_db_hook(records=[(1,)], last_description=[("count",)])

        # Should not raise even with INSERT; the mock doesn't actually
        # execute, it just returns the mocked records.
        payload = self._run_query(toolset, "INSERT INTO users VALUES (3, 'Eve')")
        assert "rows" in payload
+
+
class TestSQLToolsetCheckQuery:
    """The check_query tool: validation without execution."""

    @staticmethod
    def _check(toolset, sql):
        """Invoke the check_query tool and decode its JSON verdict."""
        raw = asyncio.run(toolset.call_tool("check_query", {"sql": sql}, ctx=MagicMock(), tool=MagicMock()))
        return json.loads(raw)

    def test_valid_select(self):
        toolset = SQLToolset("pg_default")
        toolset._hook = _make_mock_db_hook()
        assert self._check(toolset, "SELECT 1")["valid"] is True

    def test_invalid_sql(self):
        toolset = SQLToolset("pg_default")
        toolset._hook = _make_mock_db_hook()

        verdict = self._check(toolset, "DROP TABLE users")
        assert verdict["valid"] is False
        assert "error" in verdict
+
+
class TestSQLToolsetHookResolution:
    """Lazy resolution and caching of the underlying DbApiHook."""

    @staticmethod
    def _wire_connection(mock_base_hook, hook):
        """Make BaseHook.get_connection yield a connection returning *hook*."""
        mock_conn = MagicMock(spec=["get_hook"])
        mock_conn.get_hook.return_value = hook
        mock_base_hook.get_connection.return_value = mock_conn

    @patch("airflow.providers.common.ai.toolsets.sql.BaseHook", autospec=True)
    def test_lazy_resolves_db_hook(self, mock_base_hook):
        from airflow.providers.common.sql.hooks.sql import DbApiHook

        db_hook = MagicMock(spec=DbApiHook)
        self._wire_connection(mock_base_hook, db_hook)

        toolset = SQLToolset("pg_default")
        assert toolset._get_db_hook() is db_hook
        mock_base_hook.get_connection.assert_called_once_with("pg_default")

    @patch("airflow.providers.common.ai.toolsets.sql.BaseHook", autospec=True)
    def test_raises_for_non_dbapi_hook(self, mock_base_hook):
        self._wire_connection(mock_base_hook, MagicMock())  # Not a DbApiHook

        toolset = SQLToolset("bad_conn")
        with pytest.raises(ValueError, match="does not provide a DbApiHook"):
            toolset._get_db_hook()

    @patch("airflow.providers.common.ai.toolsets.sql.BaseHook", autospec=True)
    def test_caches_hook_after_first_resolution(self, mock_base_hook):
        from airflow.providers.common.sql.hooks.sql import DbApiHook

        self._wire_connection(mock_base_hook, MagicMock(spec=DbApiHook))

        toolset = SQLToolset("pg_default")
        toolset._get_db_hook()
        toolset._get_db_hook()

        # Only called once because the resolved hook is cached.
        mock_base_hook.get_connection.assert_called_once()

Reply via email to