kaxil commented on code in PR #67438:
URL: https://github.com/apache/airflow/pull/67438#discussion_r3298933682


##########
providers/common/ai/src/airflow/providers/common/ai/toolsets/sql.py:
##########
@@ -236,11 +208,10 @@ def _query(self, sql: str) -> str:
             rows = hook.get_records(sql)
         except Exception as e:
             if self._is_retryable_query_error(hook, e):
-                raise ModelRetry(
+                raise ValueError(

Review Comment:
   Swapping `ModelRetry` for `ValueError` here looks like a behavior 
regression. `ModelRetry` told the pydantic-ai agent "your tool call was 
recoverable, try again with the new context I gave you" so the LLM could 
self-correct on `UndefinedColumn`/`UndefinedTable` by calling `get_schema` 
first. A plain `ValueError` propagates straight out of `agent.run_sync()` and
fails the whole task, which defeats the purpose of the
`_is_retryable_query_error` check.
   
   A framework-neutral "ask the model to retry" signal on the new `BaseToolset`
contract is worth a follow-up. Until then, the pydantic-ai-specific raise should
stay: the toolset still runs through `Tool(spec.fn, ...)` in the same agent, so
`ModelRetry` is still in scope.
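   A minimal sketch of that follow-up, assuming a hypothetical framework-neutral `ToolRetry` exception in `base_ai.py` (it is not part of this PR). Toolsets raise it, and each hook translates it into its native retry signal when wrapping `spec.fn`. For pydantic-ai the native type would be `pydantic_ai.ModelRetry`. The sketch takes the type as a parameter and uses a stand-in (`FakeModelRetry`) so it runs without pydantic-ai installed.

```python
from __future__ import annotations

import functools
from collections.abc import Callable
from typing import Any


class ToolRetry(Exception):
    """Hypothetical framework-neutral "ask the model to retry" signal for toolsets."""


def translate_retry(fn: Callable[..., Any], native_retry: type[Exception]) -> Callable[..., Any]:
    """Wrap a tool callable so ToolRetry becomes the backend's native retry exception.

    In the pydantic-ai hook, ``native_retry`` would be ``pydantic_ai.ModelRetry``.
    """

    # functools.wraps copies __name__/__doc__/__annotations__ and sets __wrapped__,
    # so inspect.signature() still reports the original parameters.
    @functools.wraps(fn)
    def wrapper(*args: Any, **kwargs: Any) -> Any:
        try:
            return fn(*args, **kwargs)
        except ToolRetry as exc:
            # Keep the message the model needs to self-correct (e.g. "call get_schema first").
            raise native_retry(str(exc)) from exc

    return wrapper


class FakeModelRetry(Exception):
    """Stand-in for pydantic_ai.ModelRetry in this sketch."""


def query(sql: str) -> str:
    # A query tool that asks for a retry instead of failing the whole task.
    raise ToolRetry(f"Column not found in {sql!r}; call get_schema first.")


wrapped_query = translate_retry(query, FakeModelRetry)
```

Any exception other than `ToolRetry` passes through unchanged, so genuine failures still fail the task as before.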



##########
providers/common/ai/src/airflow/providers/common/ai/toolsets/sql.py:
##########
@@ -164,50 +150,36 @@ def _get_db_hook(self) -> DbApiHook:
         return self._hook
 
     # ------------------------------------------------------------------
-    # AbstractToolset interface
+    # BaseToolset interface
     # ------------------------------------------------------------------
 
-    async def get_tools(self, ctx: RunContext[Any]) -> dict[str, ToolsetTool[Any]]:
-        tools: dict[str, ToolsetTool[Any]] = {}
-
-        for name, description, schema in (
-            ("list_tables", "List available table names in the database.", _LIST_TABLES_SCHEMA),
-            ("get_schema", "Get column names and types for a table.", _GET_SCHEMA_SCHEMA),
-            ("query", "Execute a SQL query and return rows as JSON.", _QUERY_SCHEMA),
-            ("check_query", "Validate SQL syntax without executing it.", _CHECK_QUERY_SCHEMA),
-        ):
-            # sequential=True because all tools use a shared DbApiHook with
-            # synchronous I/O — they must not run concurrently.
-            tool_def = ToolDefinition(
-                name=name,
-                description=description,
-                parameters_json_schema=schema,
-                sequential=True,
-            )
-            tools[name] = ToolsetTool(
-                toolset=self,
-                tool_def=tool_def,
-                max_retries=1,
-                args_validator=_PASSTHROUGH_VALIDATOR,
-            )
-        return tools
-
-    async def call_tool(
-        self,
-        name: str,
-        tool_args: dict[str, Any],
-        ctx: RunContext[Any],
-        tool: ToolsetTool[Any],
-    ) -> Any:
-        if name == "list_tables":
-            return self._list_tables()
-        if name == "get_schema":
-            return self._get_schema(tool_args["table_name"])
-        if name == "query":
-            return self._query(tool_args["sql"])
-        if name == "check_query":
-            return self._check_query(tool_args["sql"])
-        raise ValueError(f"Unknown tool: {name!r}")
+    def as_tools(self) -> list[ToolSpec]:

Review Comment:
   The previous `get_tools` set `sequential=True` on every tool because they 
all share one `DbApiHook` with sync I/O and "must not run concurrently" (per 
the comment that was removed). The new `ToolSpec` shape has no way to express 
that, and `_tool_spec_to_native` just produces `Tool(spec.fn, name=..., 
description=...)` with no sequential flag.
   
   If pydantic-ai parallelises tool calls in a single turn (it can for 
independent calls), `list_tables` / `get_schema` / `query` will share the same 
hook's cursor concurrently. It's worth either marking the tools sequential at
conversion time for toolsets that need it, or exposing a `sequential: bool`
field on `ToolSpec` so toolsets can opt in framework-neutrally.



##########
providers/common/ai/src/airflow/providers/common/ai/hooks/base_ai.py:
##########
@@ -0,0 +1,353 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Shared contract for agent-framework hooks used by 
:class:`~airflow.providers.common.ai.operators.agent.AgentOperator`."""
+
+from __future__ import annotations
+
+import functools
+import inspect
+import json
+import time
+from abc import ABCMeta, abstractmethod
+from collections.abc import Callable
+from dataclasses import dataclass, field
+from typing import Any, ClassVar
+
+from airflow.providers.common.compat.sdk import BaseHook
+
+
+@dataclass
+class AgentUsage:
+    """Token and request usage from an agent run, when the backend exposes 
it."""
+
+    requests: int = 0
+    tool_calls: int = 0
+    input_tokens: int = 0
+    output_tokens: int = 0
+    total_tokens: int = 0
+
+
+@dataclass
+class DurableStats:
+    """Step-level cache statistics from a durable agent run."""
+
+    replayed_model: int = 0
+    replayed_tool: int = 0
+    cached_model: int = 0
+    cached_tool: int = 0
+
+
+@dataclass
+class AgentRunResult:
+    """
+    Backend-neutral result from :meth:`BaseAIHook.run_agent`.
+
+    :param output: Final agent output (``str``, Pydantic model instance, etc.).
+    :param message_history: Opaque conversation state for HITL regeneration; only pass back to the
+        same hook implementation that produced it.
+    :param model_name: Resolved model identifier, when available.
+    :param usage: Usage counters when the backend exposes them.
+    :param tool_names: Ordered tool names invoked during the run, when known.
+    :param durable_stats: Durable step-cache statistics, populated when durable execution is enabled.
+    """
+
+    output: Any
+    message_history: Any = None
+    model_name: str | None = None
+    usage: AgentUsage | None = None
+    tool_names: list[str] | None = None
+    durable_stats: DurableStats | None = None
+
+
+@dataclass
+class ToolSpec:
+    """
+    Framework-neutral tool descriptor.
+
+    Toolsets produce :class:`ToolSpec` objects; each hook converts them to its
+    native tool representation via :meth:`BaseAIHook._tool_spec_to_native`.
+
+    :param name: Tool name exposed to the LLM.
+    :param description: Human-readable description used by the LLM to decide when to call this tool.
+    :param parameters: JSON Schema ``object`` describing the tool's parameters.
+    :param fn: Callable that implements the tool. Must accept keyword arguments matching *parameters*.
+    """
+
+    name: str
+    description: str
+    parameters: dict[str, Any]
+    fn: Callable[..., Any]
+
+
+@dataclass
+class DurableContext:
+    """Framework-neutral identity of the running task, used to locate the 
durable cache file."""
+
+    dag_id: str
+    task_id: str
+    run_id: str
+    map_index: int = -1
+
+
+@dataclass
+class AgentRunRequest:
+    """
+    Parameter object passed to :meth:`BaseAIHook.create_agent` and :meth:`BaseAIHook.run_agent`.
+
+    Encapsulates everything the hook needs to build and run an agent in a single
+    framework-neutral structure, so that :class:`~airflow.providers.common.ai.operators.agent.AgentOperator`
+    has zero framework-specific imports.
+
+    :param prompt: User prompt for this invocation.
+    :param output_type: Expected structured output type (default: ``str``).
+    :param instructions: System-level instructions for the agent.
+    :param toolsets: List of :class:`BaseToolset` instances the agent may call.
+    :param usage_limits: Backend-specific usage limits; ignored if the hook does not support them.
+    :param message_history: Prior conversation state from a previous :class:`AgentRunResult`.
+    :param enable_tool_logging: When ``True`` (default), wraps each tool callable with a logging shim.
+    :param durable_context: When set, enables step-level durable caching for the run.
+    :param agent_params: Extra keyword arguments forwarded to the underlying agent constructor.
+        Use this escape hatch for framework-specific options.
+    """
+
+    prompt: str
+    output_type: type[Any] = str
+    instructions: str = ""
+    toolsets: list[Any] | None = None
+    usage_limits: Any = None
+    message_history: Any = None
+    enable_tool_logging: bool = True
+    durable_context: DurableContext | None = None
+    agent_params: dict[str, Any] = field(default_factory=dict)
+
+
+class BaseToolset(metaclass=ABCMeta):
+    """
+    Abstract base for framework-agnostic toolsets.
+
+    Subclasses implement :meth:`as_tools` to return a list of :class:`ToolSpec`
+    objects.  Each hook converts those specs to its native tool representation
+    via :meth:`BaseAIHook._tool_spec_to_native`.
+    """
+
+    @abstractmethod
+    def as_tools(self) -> list[ToolSpec]:
+        """Return the list of tools this toolset exposes."""
+
+
+class BaseAIHook(BaseHook, metaclass=ABCMeta):
+    """
+    Abstract hook for multi-turn LLM agents.
+
+    :class:`~airflow.providers.common.ai.operators.agent.AgentOperator` resolves the concrete hook
+    from the Airflow connection ``conn_type`` (for example ``pydanticai`` or ``pydanticai-bedrock``).
+
+    Subclasses implement :meth:`get_model`, :meth:`create_agent`, :meth:`run_agent`, and
+    :meth:`_tool_spec_to_native`.
+
+    Shared helpers :meth:`_init_durable`, :meth:`_resolve_tools`, :meth:`_logged_callable`, and
+    :meth:`_cached_callable` are provided for all hooks.
+    """
+
+    conn_name_attr = "llm_conn_id"
+
+    supports_toolsets: ClassVar[bool] = False
+    supports_durable: ClassVar[bool] = False
+    supports_usage_limits: ClassVar[bool] = False
+
+    @classmethod
+    def get_agent_hook(cls, conn_id: str, *, hook_params: dict[str, Any] | None = None) -> BaseAIHook:
+        """
+        Return an agent hook for *conn_id*, verifying it implements this contract.
+
+        Uses the connection's ``conn_type`` to select the hook class registered in
+        ``provider.yaml``.
+        """
+        hook = cls.get_hook(conn_id, hook_params=hook_params)

Review Comment:
   `get_agent_hook` is the backend-neutral entry point, but operators call it 
with `hook_params={"model_id": self.model_id}`, and `model_id` is a 
pydantic-ai-only concept. A future `StrandsAIHook` or `ADKHook` that doesn't 
accept `model_id` in `__init__` will blow up here with `TypeError: __init__() 
got an unexpected keyword argument 'model_id'`.
   
   Either the operators should stop hardcoding `model_id` and the hook should 
read it from the connection extra, or `BaseAIHook.__init__` should accept and 
normalise common kwargs (e.g. `model_id`) and let subclasses opt out. As 
written, the abstraction leaks the pydantic-ai param straight through the 
backend-neutral resolver.
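   A minimal sketch of the second option, with Airflow's `BaseHook` left out so it runs standalone. `AgentHookBase`, `PydanticAILikeHook`, `resolve_model`, and `NoModelIdHook` are all hypothetical names. The base `__init__` accepts `model_id` as a common, optional keyword and stores it; a backend with no model-id concept forwards its leftover kwargs to the base class and simply never reads `self.model_id`, instead of raising `TypeError`.

```python
from __future__ import annotations


class AgentHookBase:
    """Stand-in for BaseAIHook; only the constructor contract is sketched."""

    def __init__(self, *, model_id: str | None = None) -> None:
        # Common kwarg, always accepted and normalised here; None means
        # "use whatever the connection (or the backend's default) says".
        self.model_id = model_id


class PydanticAILikeHook(AgentHookBase):
    """Backend that honours model_id as an override of the connection's model."""

    def resolve_model(self, conn_model: str) -> str:
        return self.model_id or conn_model


class NoModelIdHook(AgentHookBase):
    """Hypothetical backend whose own options do not include model_id."""

    def __init__(self, *, region: str = "us-east-1", **kwargs: object) -> None:
        # Common kwargs (model_id today) are absorbed by the base class
        # instead of raising TypeError in the shared resolver.
        super().__init__(**kwargs)
        self.region = region


# What the operators pass today via get_agent_hook(..., hook_params=...).
hook_params = {"model_id": "openai:gpt-4o"}
pydantic_like = PydanticAILikeHook(**hook_params)
no_model_id = NoModelIdHook(**hook_params)
```

Opting out here just means not reading `self.model_id`; the first option (reading the model from the connection extra) would remove the kwarg from the resolver call entirely.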



##########
providers/common/ai/src/airflow/providers/common/ai/hooks/pydantic_ai.py:
##########
@@ -171,25 +185,134 @@ def _provider_factory(pname: str) -> Any:
         self._model = infer_model(model_name)
         return self._model
 
-    @overload
-    def create_agent(
-        self, output_type: type[OutputT], *, instructions: str, **agent_kwargs
-    ) -> Agent[None, OutputT]: ...
+    # ------------------------------------------------------------------
+    # BaseAIHook abstract interface
+    # ------------------------------------------------------------------
 
-    @overload
-    def create_agent(self, *, instructions: str, **agent_kwargs) -> Agent[None, str]: ...
+    def _tool_spec_to_native(self, spec: ToolSpec) -> Any:
+        """Convert a 
:class:`~airflow.providers.common.ai.hooks.base_ai.ToolSpec` to a pydantic-ai 
``Tool``."""
+        from pydantic_ai.tools import Tool
 
-    def create_agent(
-        self, output_type: type[Any] = str, *, instructions: str, **agent_kwargs
-    ) -> Agent[None, Any]:
+        return Tool(spec.fn, name=spec.name, description=spec.description)
+
+    def create_agent(self, request: AgentRunRequest) -> Agent[None, Any]:
         """
-        Create a pydantic-ai Agent configured with this hook's model.
+        Build a pydantic-ai ``Agent`` from *request*.
+
+        When :attr:`~AgentRunRequest.durable_context` is set, initialises durable
+        storage and step counter and stores them on the instance for use by
+        :meth:`run_agent`.
 
-        :param output_type: The expected output type from the agent (default: ``str``).
-        :param instructions: System-level instructions for the agent.
-        :param agent_kwargs: Additional keyword arguments passed to the Agent constructor.
+        :param request: Agent configuration — output type, instructions, toolsets, extra params.
         """
-        return Agent(self.get_conn(), output_type=output_type, instructions=instructions, **agent_kwargs)
+        if request.durable_context is not None:

Review Comment:
   Stashing `_durable_storage` / `_durable_counter` on the hook instance in 
`create_agent` couples two calls that the contract says are independent. If a 
caller (or `LLMRetryPolicy` / `regenerate_with_feedback`) does 
`create_agent(req_a)` then `create_agent(req_b)` then runs them, the second 
`create_agent` overwrites state the first agent depends on, and the `else` 
branch a few lines down actively clears it to `None`.
   
   The shared helpers in `BaseAIHook` already take `storage`/`counter` as
parameters, so the cleaner fix is to attach the pair to the returned `agent` (a
small wrapper, or a `WeakKeyDictionary` keyed by the agent object itself, since
an integer id cannot be weakly referenced) and have `run_agent` pull them off the
agent rather than `self`. This shares a root cause with Copilot's
cleanup-on-exception comment at line 273, but it's worth fixing the contract
rather than only adding the `finally`.
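   A rough sketch of the `WeakKeyDictionary` variant, keyed by the agent object (weak references cannot be taken to an integer id). A stub `Agent` stands in for pydantic-ai's, and `_DurableState`, `Hook`, and the `create_agent`/`run_agent` signatures are simplified, hypothetical shapes rather than the PR's. It assumes the real agent object is weak-referenceable, which holds for ordinary Python classes without `__slots__`. Building a second agent cannot clobber the first one's state, and each entry disappears once its agent is garbage-collected.

```python
from __future__ import annotations

import weakref
from dataclasses import dataclass


class Agent:
    """Stub standing in for pydantic_ai.Agent in this sketch."""

    def __init__(self, instructions: str) -> None:
        self.instructions = instructions


@dataclass
class _DurableState:
    """Hypothetical per-agent durable state: cache location plus step counter."""

    storage: str
    counter: int = 0


class Hook:
    def __init__(self) -> None:
        # Keyed by the agent object itself; entries vanish when the agent is collected.
        self._durable: weakref.WeakKeyDictionary[Agent, _DurableState] = weakref.WeakKeyDictionary()

    def create_agent(self, instructions: str, durable_key: str | None = None) -> Agent:
        agent = Agent(instructions)
        if durable_key is not None:
            self._durable[agent] = _DurableState(storage=f"durable/{durable_key}")
        # No else-branch resetting shared state on self: other agents are untouched.
        return agent

    def run_agent(self, agent: Agent) -> str:
        state = self._durable.get(agent)  # None means "not durable" for this agent only
        if state is None:
            return f"{agent.instructions}: non-durable"
        state.counter += 1
        return f"{agent.instructions}: step {state.counter} via {state.storage}"


hook = Hook()
agent_a = hook.create_agent("a", durable_key="dag.task.run_a")
agent_b = hook.create_agent("b")  # with state on self, this would clear agent_a's state
```

The same lookup works if the pair lives on a small wrapper object instead; the dictionary just avoids changing the return type of `create_agent`.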



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.