Copilot commented on code in PR #67438:
URL: https://github.com/apache/airflow/pull/67438#discussion_r3295276666
##########
providers/common/ai/src/airflow/providers/common/ai/hooks/pydantic_ai.py:
##########
@@ -171,25 +185,134 @@ def _provider_factory(pname: str) -> Any:
self._model = infer_model(model_name)
return self._model
- @overload
- def create_agent(
- self, output_type: type[OutputT], *, instructions: str, **agent_kwargs
- ) -> Agent[None, OutputT]: ...
+ # ------------------------------------------------------------------
+ # BaseAIHook abstract interface
+ # ------------------------------------------------------------------
- @overload
- def create_agent(self, *, instructions: str, **agent_kwargs) ->
Agent[None, str]: ...
+ def _tool_spec_to_native(self, spec: ToolSpec) -> Any:
+ """Convert a
:class:`~airflow.providers.common.ai.hooks.base_ai.ToolSpec` to a pydantic-ai
``Tool``."""
+ from pydantic_ai.tools import Tool
- def create_agent(
- self, output_type: type[Any] = str, *, instructions: str,
**agent_kwargs
- ) -> Agent[None, Any]:
+ return Tool(spec.fn, name=spec.name, description=spec.description)
+
+ def create_agent(self, request: AgentRunRequest) -> Agent[None, Any]:
"""
- Create a pydantic-ai Agent configured with this hook's model.
+ Build a pydantic-ai ``Agent`` from *request*.
+
+ When :attr:`~AgentRunRequest.durable_context` is set, initialises
durable
+ storage and step counter and stores them on the instance for use by
+ :meth:`run_agent`.
- :param output_type: The expected output type from the agent (default:
``str``).
- :param instructions: System-level instructions for the agent.
- :param agent_kwargs: Additional keyword arguments passed to the Agent
constructor.
+ :param request: Agent configuration — output type, instructions,
toolsets, extra params.
"""
- return Agent(self.get_conn(), output_type=output_type,
instructions=instructions, **agent_kwargs)
+ if request.durable_context is not None:
+ storage, counter = self._init_durable(request.durable_context)
+ self._durable_storage = storage
+ self._durable_counter = counter
+ else:
+ self._durable_storage = None
+ self._durable_counter = None
+
+ extra_kwargs = dict(request.agent_params or {})
+ if request.toolsets:
+ from pydantic_ai.toolsets.abstract import AbstractToolset
+
+ abstract_items = [ts for ts in request.toolsets if isinstance(ts,
AbstractToolset)]
+ pipeline_items = [ts for ts in request.toolsets if not
isinstance(ts, AbstractToolset)]
+
+ if pipeline_items:
+ resolved = self._resolve_tools(
+ pipeline_items,
+ request.enable_tool_logging,
+ self._durable_storage,
+ self._durable_counter,
+ )
+ if resolved:
+ extra_kwargs["tools"] = resolved
+
+ if abstract_items:
+ processed: list[Any] = list(abstract_items)
+ if self._durable_storage is not None and self._durable_counter
is not None:
+ from airflow.providers.common.ai.durable.caching_toolset
import CachingToolset
+
+ processed = [
+ CachingToolset(
+ wrapped=ts,
+ storage=self._durable_storage,
+ counter=self._durable_counter,
+ )
+ for ts in processed
+ ]
+ if request.enable_tool_logging:
+ from airflow.providers.common.ai.toolsets.logging import
LoggingToolset
+
+ processed = [LoggingToolset(wrapped=ts, logger=self.log)
for ts in processed]
+ extra_kwargs["toolsets"] = processed
+
+ return Agent(
+ self.get_model(),
+ output_type=request.output_type,
+ instructions=request.instructions,
+ **extra_kwargs,
+ )
+
+ def run_agent(self, agent: Agent[None, Any], request: AgentRunRequest) ->
AgentRunResult:
+ """Run *agent* synchronously for *request* and return a normalized
:class:`~airflow.providers.common.ai.hooks.base_ai.AgentRunResult`."""
+ from pydantic_ai.messages import ToolCallPart
+
+ run_kwargs: dict[str, Any] = {}
+ if request.message_history is not None:
+ run_kwargs["message_history"] = request.message_history
+ if request.usage_limits is not None:
+ run_kwargs["usage_limits"] = request.usage_limits
+
+ if self._durable_storage is not None and self._durable_counter is not
None:
+ from airflow.providers.common.ai.durable.caching_model import
CachingModel
+
+ resolved_model = infer_model(agent.model)
+ caching_model = CachingModel(
Review Comment:
Durable cleanup and state reset only happen on the success path. If
`agent.run_sync(...)` raises (network error, model error, etc.),
`_durable_storage.cleanup()` is never called and
`_durable_storage`/`_durable_counter` stay set on the hook instance. This can
leak temporary artifacts and contaminate subsequent runs that reuse the same
hook. Wrap the durable execution path in a `try/finally` that always calls
`cleanup()` and clears `_durable_storage`/`_durable_counter`. When a result
exists, populate its `durable_stats` before clearing.
##########
providers/common/ai/docs/toolsets.rst:
##########
@@ -67,11 +79,13 @@ This works because toolsets resolve Airflow connections
lazily via
``BaseHook.get_connection()``, which is available in any task execution
context.
-This approach gives you full control over the agent lifecycle -- you can call
-``agent.run_sync()`` multiple times, swap models at runtime, or combine
-results from several agents in a single task. The tradeoff is that you lose
+This approach gives you direct control over the agent lifecycle — you can
+build and run multiple agents in a single task, or combine results from
+several runs. The tradeoff is that you lose
the durable execution (step-level caching with retry replay), HITL review
-integration, and automatic tool call logging that ``AgentOperator`` provides.
+integration, and the automatic tool call logging and routing that
+``AgentOperator`` provides via
+:class:`~airflow.providers.common.ai.toolsets.logging.LoggingToolset`.
Review Comment:
This paragraph currently attributes durable execution and all tool
routing/logging to `LoggingToolset`, which is inaccurate in two ways:
- `AgentOperator`'s durable execution is driven by the hook.
- Only pydantic-ai `AbstractToolset` items are wrapped with `LoggingToolset`. `BaseToolset` items and plain callables are wrapped at the callable level instead.

Reword the paragraph so it does not imply that `LoggingToolset` is the mechanism for durable execution or for all tool routing/logging.
##########
providers/common/ai/src/airflow/providers/common/ai/hooks/base_ai.py:
##########
@@ -0,0 +1,353 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Shared contract for agent-framework hooks used by
:class:`~airflow.providers.common.ai.operators.agent.AgentOperator`."""
+
+from __future__ import annotations
+
+import functools
+import inspect
+import json
+import time
+from abc import ABCMeta, abstractmethod
+from collections.abc import Callable
+from dataclasses import dataclass, field
+from typing import Any, ClassVar
+
+from airflow.providers.common.compat.sdk import BaseHook
+
+
+@dataclass
+class AgentUsage:
+ """Token and request usage from an agent run, when the backend exposes
it."""
+
+ requests: int = 0
+ tool_calls: int = 0
+ input_tokens: int = 0
+ output_tokens: int = 0
+ total_tokens: int = 0
+
+
+@dataclass
+class DurableStats:
+ """Step-level cache statistics from a durable agent run."""
+
+ replayed_model: int = 0
+ replayed_tool: int = 0
+ cached_model: int = 0
+ cached_tool: int = 0
+
+
+@dataclass
+class AgentRunResult:
+ """
+ Backend-neutral result from :meth:`BaseAIHook.run_agent`.
+
+ :param output: Final agent output (``str``, Pydantic model instance, etc.).
+ :param message_history: Opaque conversation state for HITL regeneration;
only pass back to the
+ same hook implementation that produced it.
+ :param model_name: Resolved model identifier, when available.
+ :param usage: Usage counters when the backend exposes them.
+ :param tool_names: Ordered tool names invoked during the run, when known.
+ :param durable_stats: Durable step-cache statistics, populated when
durable execution is enabled.
+ """
+
+ output: Any
+ message_history: Any = None
+ model_name: str | None = None
+ usage: AgentUsage | None = None
+ tool_names: list[str] | None = None
+ durable_stats: DurableStats | None = None
+
+
+@dataclass
+class ToolSpec:
+ """
+ Framework-neutral tool descriptor.
+
+ Toolsets produce :class:`ToolSpec` objects; each hook converts them to its
+ native tool representation via :meth:`BaseAIHook._tool_spec_to_native`.
+
+ :param name: Tool name exposed to the LLM.
+ :param description: Human-readable description used by the LLM to decide
when to call this tool.
+ :param parameters: JSON Schema ``object`` describing the tool's parameters.
+ :param fn: Callable that implements the tool. Must accept keyword
arguments matching *parameters*.
+ """
+
+ name: str
+ description: str
+ parameters: dict[str, Any]
+ fn: Callable[..., Any]
+
+
+@dataclass
+class DurableContext:
+ """Framework-neutral identity of the running task, used to locate the
durable cache file."""
+
+ dag_id: str
+ task_id: str
+ run_id: str
+ map_index: int = -1
+
+
+@dataclass
+class AgentRunRequest:
+ """
+ Parameter object passed to :meth:`BaseAIHook.create_agent` and
:meth:`BaseAIHook.run_agent`.
+
+ Encapsulates everything the hook needs to build and run an agent in a
single
+ framework-neutral structure, so that
:class:`~airflow.providers.common.ai.operators.agent.AgentOperator`
+ has zero framework-specific imports.
+
+ :param prompt: User prompt for this invocation.
+ :param output_type: Expected structured output type (default: ``str``).
+ :param instructions: System-level instructions for the agent.
+ :param toolsets: List of :class:`BaseToolset` instances the agent may call.
+ :param usage_limits: Backend-specific usage limits; ignored if the hook
does not support them.
+ :param message_history: Prior conversation state from a previous
:class:`AgentRunResult`.
+ :param enable_tool_logging: When ``True`` (default), wraps each tool
callable with a logging shim.
+ :param durable_context: When set, enables step-level durable caching for
the run.
+ :param agent_params: Extra keyword arguments forwarded to the underlying
agent constructor.
+ Use this escape hatch for framework-specific options.
+ """
+
+ prompt: str
+ output_type: type[Any] = str
+ instructions: str = ""
+ toolsets: list[Any] | None = None
+ usage_limits: Any = None
+ message_history: Any = None
+ enable_tool_logging: bool = True
+ durable_context: DurableContext | None = None
+ agent_params: dict[str, Any] = field(default_factory=dict)
+
+
+class BaseToolset(metaclass=ABCMeta):
+ """
+ Abstract base for framework-agnostic toolsets.
+
+ Subclasses implement :meth:`as_tools` to return a list of :class:`ToolSpec`
+ objects. Each hook converts those specs to its native tool representation
+ via :meth:`BaseAIHook._tool_spec_to_native`.
+ """
+
+ @abstractmethod
+ def as_tools(self) -> list[ToolSpec]:
+ """Return the list of tools this toolset exposes."""
+
+
+class BaseAIHook(BaseHook, metaclass=ABCMeta):
+ """
+ Abstract hook for multi-turn LLM agents.
+
+ :class:`~airflow.providers.common.ai.operators.agent.AgentOperator`
resolves the concrete hook
+ from the Airflow connection ``conn_type`` (for example ``pydanticai`` or
``pydanticai-bedrock``).
+
+ Subclasses implement :meth:`get_model`, :meth:`create_agent`,
:meth:`run_agent`, and
+ :meth:`_tool_spec_to_native`.
+
+ Shared helpers :meth:`_init_durable`, :meth:`_resolve_tools`,
:meth:`_logged_callable`, and
+ :meth:`_cached_callable` are provided for all hooks.
+ """
+
+ conn_name_attr = "llm_conn_id"
+
+ supports_toolsets: ClassVar[bool] = False
+ supports_durable: ClassVar[bool] = False
+ supports_usage_limits: ClassVar[bool] = False
+
+ @classmethod
+ def get_agent_hook(cls, conn_id: str, *, hook_params: dict[str, Any] |
None = None) -> BaseAIHook:
+ """
+ Return an agent hook for *conn_id*, verifying it implements this
contract.
+
+ Uses the connection's ``conn_type`` to select the hook class
registered in
+ ``provider.yaml``.
+ """
+ hook = cls.get_hook(conn_id, hook_params=hook_params)
+ if not isinstance(hook, BaseAIHook):
+ raise TypeError(
+ f"Connection {conn_id!r} resolved to {type(hook).__name__},
which is not a BaseAIHook. "
+ "Use a connection type registered for agent frameworks (e.g.
pydanticai, pydanticai-bedrock)."
+ )
+ return hook
+
+ @abstractmethod
+ def get_model(self) -> Any:
+ """Return the backend model/client used to construct agents."""
+
+ def get_conn(self) -> Any:
+ """Return the backend model/client. Delegates to :meth:`get_model`."""
+ return self.get_model()
+
+ @abstractmethod
+ def create_agent(self, request: AgentRunRequest) -> Any:
+ """
+ Build (but do not run) the agent described by *request*.
+
+ Responsible for resolving :attr:`AgentRunRequest.toolsets` via
+ :meth:`_resolve_tools` and constructing the framework-native agent
object
+ with the model, tools, instructions, and output type from *request*.
+
+ When :attr:`AgentRunRequest.durable_context` is set, implementations
+ should call :meth:`_init_durable` and store the returned
storage/counter
+ on the instance so that :meth:`run_agent` can use them.
+
+ :param request: All parameters needed to configure the agent.
+ :returns: Framework-native agent object, ready to be passed to
:meth:`run_agent`.
+ """
+
+ @abstractmethod
+ def run_agent(self, agent: Any, request: AgentRunRequest) ->
AgentRunResult:
+ """
+ Execute *agent* for *request* and return a normalized
:class:`AgentRunResult`.
+
+ Implementations that store durable state on the instance (set during
+ :meth:`create_agent`) should apply it here and clean up after the run.
+
+ :param agent: Framework-native agent produced by :meth:`create_agent`.
+ :param request: The same request used to create the agent (prompt,
usage
+ limits, message history, etc.).
+ """
+
+ @abstractmethod
+ def _tool_spec_to_native(self, spec: ToolSpec) -> Any:
+ """
+ Convert a :class:`ToolSpec` to the agent framework's native tool
representation.
+
+ Called once per tool inside :meth:`_resolve_tools`. The returned object
+ is collected into a list and passed to the underlying agent
constructor.
+
+ :param spec: Universal tool descriptor, with the callable already
wrapped
+ by any enabled logging / caching shims.
+ """
+
+ def _init_durable(self, ctx: DurableContext) -> tuple[Any, Any]:
+ """
+ Create and return a ``DurableStorage`` / ``DurableStepCounter`` pair
for *ctx*.
+
+ Hooks call this inside :meth:`create_agent` when
+ :attr:`AgentRunRequest.durable_context` is set.
+ """
+ from airflow.providers.common.ai.durable.step_counter import
DurableStepCounter
+ from airflow.providers.common.ai.durable.storage import DurableStorage
+
+ storage = DurableStorage(
+ dag_id=ctx.dag_id,
+ task_id=ctx.task_id,
+ run_id=ctx.run_id,
+ map_index=ctx.map_index,
+ )
+ counter = DurableStepCounter()
+ return storage, counter
+
+ def _resolve_tools(
+ self,
+ toolsets: list[Any],
+ enable_logging: bool,
+ storage: Any,
+ counter: Any,
+ ) -> list[Any]:
+ """
+ Convert a mixed list of toolsets / callables / native tools into
framework-native tools.
+
+ Three cases per item:
+
+ * :class:`BaseToolset` — calls ``as_tools()`` and processes each
:class:`ToolSpec`.
+ * Plain Python function (``def`` / ``lambda``) — auto-wraps into a
:class:`ToolSpec`
+ using ``__name__`` and ``__doc__``, then processes it the same way.
+ * Anything else — passed through unchanged (assumed to be a native
tool object already
+ constructed for the target framework).
+
+ The processing pipeline for ``BaseToolset`` and callable items:
+ *fn* → optional cache wrap → optional log wrap →
:meth:`_tool_spec_to_native`.
+
+ :param toolsets: Mix of :class:`BaseToolset` instances, plain
callables, and native tool objects.
+ :param enable_logging: When ``True``, wrap each callable with
:meth:`_logged_callable`.
+ :param storage: ``DurableStorage`` instance, or ``None`` to skip
caching.
+ :param counter: ``DurableStepCounter`` instance, or ``None`` to skip
caching.
+ """
+ native: list[Any] = []
+ for ts in toolsets:
+ if isinstance(ts, BaseToolset):
+ specs = ts.as_tools()
+ elif inspect.isfunction(ts):
+ specs = [ToolSpec(name=ts.__name__, description=ts.__doc__ or
"", parameters={}, fn=ts)]
Review Comment:
When a plain Python function is passed in `toolsets`, `_resolve_tools` wraps
it into a `ToolSpec` with `parameters={}`. `ToolSpec.parameters` is documented
as a JSON Schema ``object``. `{}` is technically a valid JSON Schema, but it is
the empty schema: it accepts any value and does not declare
`"type": "object"`. That makes it ambiguous for backends that require an
explicit object schema. To keep the contract consistent, use an explicit
empty-object schema, e.g. `{"type": "object", "properties": {}}`.
##########
providers/common/ai/tests/unit/common/ai/hooks/test_base_ai.py:
##########
@@ -0,0 +1,343 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from unittest.mock import MagicMock, patch
+
+import pytest
+
+from airflow.providers.common.ai.hooks.base_ai import (
+ AgentRunRequest,
+ AgentRunResult,
+ AgentUsage,
+ BaseAIHook,
+ BaseToolset,
+ DurableContext,
+ DurableStats,
+ ToolSpec,
+)
+from airflow.providers.common.compat.sdk import BaseHook
+
+
+class TestBaseAIHookGetAgentHook:
+ @patch("airflow.providers.common.ai.hooks.base_ai.BaseHook.get_hook",
autospec=True)
+ def test_returns_hook_when_instance_is_base_ai_hook(self, mock_get_hook):
+ mock_hook = MagicMock(spec=BaseAIHook)
+ mock_get_hook.return_value = mock_hook
+
+ result = BaseAIHook.get_agent_hook("my_conn")
+
+ assert result is mock_hook
+ mock_get_hook.assert_called_once_with("my_conn", hook_params=None)
+
+ @patch("airflow.providers.common.ai.hooks.base_ai.BaseHook.get_hook",
autospec=True)
+ def test_raises_when_hook_is_not_base_ai_hook(self, mock_get_hook):
+ mock_get_hook.return_value = MagicMock(spec=BaseHook)
+
+ with pytest.raises(TypeError, match="not a BaseAIHook"):
+ BaseAIHook.get_agent_hook("my_conn")
+
+
+class TestAgentRunResult:
+ def test_dataclass_fields(self):
+ usage = AgentUsage(requests=1, tool_calls=2, total_tokens=10)
+ result = AgentRunResult(
+ output="answer",
+ message_history=["msg"],
+ model_name="test-model",
+ usage=usage,
+ tool_names=["query"],
+ )
+ assert result.output == "answer"
+ assert result.message_history == ["msg"]
+ assert result.model_name == "test-model"
+ assert result.usage == usage
+ assert result.tool_names == ["query"]
+ assert result.durable_stats is None
+
+ def test_durable_stats_field(self):
+ stats = DurableStats(replayed_model=2, cached_model=3)
+ result = AgentRunResult(output="x", durable_stats=stats)
+ assert result.durable_stats is stats
+
+
+class TestAgentRunRequest:
+ def test_defaults(self):
+ req = AgentRunRequest(prompt="hello")
+ assert req.prompt == "hello"
+ assert req.output_type is str
+ assert req.instructions == ""
+ assert req.toolsets is None
+ assert req.usage_limits is None
+ assert req.message_history is None
+ assert req.enable_tool_logging is True
+ assert req.durable_context is None
+ assert req.agent_params == {}
+
+ def test_with_all_fields(self):
+ ctx = DurableContext(dag_id="d", task_id="t", run_id="r", map_index=2)
+ req = AgentRunRequest(
+ prompt="test",
+ output_type=int,
+ instructions="sys",
+ toolsets=["ts"],
+ usage_limits="limits",
+ message_history=["h"],
+ enable_tool_logging=False,
+ durable_context=ctx,
+ agent_params={"retries": 3},
+ )
+ assert req.output_type is int
+ assert req.instructions == "sys"
+ assert req.durable_context is ctx
+ assert req.agent_params == {"retries": 3}
+
+
+class TestBaseAIHookResolveTools:
+ def test_resolve_tools_calls_spec_to_native(self):
+ """_resolve_tools converts each ToolSpec via _tool_spec_to_native."""
+
+ class ConcreteHook(BaseAIHook):
+ conn_type = "test"
+ hook_name = "Test"
+
+ def get_model(self):
+ return None
+
+ def create_agent(self, request):
+ return None
+
+ def run_agent(self, agent, request):
+ return AgentRunResult(output="")
+
+ def _tool_spec_to_native(self, spec):
+ return {"name": spec.name, "fn": spec.fn}
+
+ hook = ConcreteHook.__new__(ConcreteHook)
+
+ def my_tool(x: int) -> str:
+ return str(x)
+
+ class MyToolset(BaseToolset):
+ def as_tools(self):
+ return [ToolSpec(name="my_tool", description="desc",
parameters={}, fn=my_tool)]
+
+ result = hook._resolve_tools([MyToolset()], enable_logging=False,
storage=None, counter=None)
+
+ assert len(result) == 1
+ assert result[0]["name"] == "my_tool"
+
+ def test_resolve_tools_wraps_with_logging(self):
+ """When enable_logging=True, callable is wrapped."""
+ mock_log = MagicMock()
+
+ class ConcreteHook(BaseAIHook):
+ conn_type = "test"
+ hook_name = "Test"
+
+ @property
+ def log(self):
+ return mock_log
+
+ def get_model(self):
+ return None
+
+ def create_agent(self, request):
+ return None
+
+ def run_agent(self, agent, request):
+ return AgentRunResult(output="")
+
+ def _tool_spec_to_native(self, spec):
+ return spec.fn
+
+ hook = ConcreteHook.__new__(ConcreteHook)
+
+ calls = []
+
+ def original():
+ calls.append("original")
+ return "result"
+
+ class SimpleToolset(BaseToolset):
+ def as_tools(self):
+ return [ToolSpec(name="original", description="",
parameters={}, fn=original)]
+
+ [wrapped_fn] = hook._resolve_tools([SimpleToolset()],
enable_logging=True, storage=None, counter=None)
+ wrapped_fn()
+
+ assert calls == ["original"]
+ mock_log.info.assert_called()
+
+ def test_resolve_tools_wraps_plain_callable(self):
+ """A bare Python function is auto-wrapped using __name__ and
__doc__."""
+
+ class ConcreteHook(BaseAIHook):
+ conn_type = "test"
+ hook_name = "Test"
+
+ def get_model(self):
+ return None
+
+ def create_agent(self, request):
+ return None
+
+ def run_agent(self, agent, request):
+ return AgentRunResult(output="")
+
+ def _tool_spec_to_native(self, spec):
+ return {"name": spec.name, "description": spec.description}
+
+ hook = ConcreteHook.__new__(ConcreteHook)
+
+ def roll_dice() -> str:
+ """Roll a six-sided die and return the result."""
+ return "4"
+
+ result = hook._resolve_tools([roll_dice], enable_logging=False,
storage=None, counter=None)
+
+ assert len(result) == 1
+ assert result[0]["name"] == "roll_dice"
+ assert result[0]["description"] == "Roll a six-sided die and return
the result."
+
+ def test_resolve_tools_passes_non_function_non_toolset_through(self):
+ """Items that are not BaseToolset and not plain functions are passed
through unchanged."""
+
+ class ConcreteHook(BaseAIHook):
+ conn_type = "test"
+ hook_name = "Test"
+
+ def get_model(self):
+ return None
+
+ def create_agent(self, request):
+ return None
+
+ def run_agent(self, agent, request):
+ return AgentRunResult(output="")
+
+ def _tool_spec_to_native(self, spec):
+ return spec.fn
+
+ hook = ConcreteHook.__new__(ConcreteHook)
+
+ native_tool_obj = object() # not a function, not a BaseToolset
+ result = hook._resolve_tools([native_tool_obj], enable_logging=True,
storage=None, counter=None)
+
+ assert result == [native_tool_obj]
+
+ def test_resolve_tools_mixes_base_toolset_and_native(self):
+ """BaseToolset items are converted; non-function native items are
passed through in order."""
+
+ class ConcreteHook(BaseAIHook):
+ conn_type = "test"
+ hook_name = "Test"
+
+ def get_model(self):
+ return None
+
+ def create_agent(self, request):
+ return None
+
+ def run_agent(self, agent, request):
+ return AgentRunResult(output="")
+
+ def _tool_spec_to_native(self, spec):
+ return f"converted:{spec.name}"
+
+ hook = ConcreteHook.__new__(ConcreteHook)
+
+ native_tool = object() # not a function, passes through unchanged
+
+ class MyToolset(BaseToolset):
+ def as_tools(self):
+ return [ToolSpec(name="greet", description="", parameters={},
fn=lambda: "hi")]
+
+ result = hook._resolve_tools(
+ [MyToolset(), native_tool], enable_logging=False, storage=None,
counter=None
+ )
+
+ assert result == ["converted:greet", native_tool]
+
+
+class TestBaseAIHookLoggedCallable:
+ def test_logged_callable_logs_and_returns(self):
+ logger = MagicMock()
+ calls = []
+
+ def fn(x):
+ calls.append(x)
+ return x * 2
+
+ wrapped = BaseAIHook._logged_callable(fn, logger)
+ result = wrapped(x=5)
+
+ assert result == 10
+ assert calls == [5]
+ logger.info.assert_called()
+
+ def test_logged_callable_logs_exception(self):
+ logger = MagicMock()
+
+ def failing():
+ raise RuntimeError("boom")
+
+ wrapped = BaseAIHook._logged_callable(failing, logger)
+ with pytest.raises(RuntimeError, match="boom"):
+ wrapped()
+
+ logger.exception.assert_called_once()
+
+
+class TestBaseAIHookCachedCallable:
+ def test_cached_callable_saves_and_returns(self):
+ storage = MagicMock()
+ counter = MagicMock()
+ counter.next_step.return_value = 1
+ storage.load_tool_result.return_value = (False, None)
+
+ calls = []
+
+ def fn():
+ calls.append(1)
+ return "computed"
+
+ wrapped = BaseAIHook._cached_callable(fn, storage, counter)
+ result = wrapped()
+
+ assert result == "computed"
+ assert calls == [1]
+ storage.save_tool_result.assert_called_once_with("tool_step_1",
"computed")
+
+ def test_cached_callable_replays_on_hit(self):
+ storage = MagicMock()
+ counter = MagicMock()
+ counter.next_step.return_value = 1
+ storage.load_tool_result.return_value = (True, "cached_value")
+
+ calls = []
+
+ def fn():
+ calls.append(1)
+ return "computed"
+
+ wrapped = BaseAIHook._cached_callable(fn, storage, counter)
+ result = wrapped()
+
+ assert result == "cached_value"
+ assert calls == []
+ counter.replayed_tool += 1
Review Comment:
`test_cached_callable_replays_on_hit` doesn’t currently assert that
`_cached_callable` increments `counter.replayed_tool`. Its final line,
`counter.replayed_tool += 1`, mutates the mock inside the test instead of
verifying the wrapper's behavior. Initialize `counter.replayed_tool` to an int
before the call (e.g. `0`) and assert its value afterwards. Optionally, also
assert that `storage.save_tool_result` is not called on a cache hit.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]