kaxil commented on code in PR #67438: URL: https://github.com/apache/airflow/pull/67438#discussion_r3352633688
########## providers/common/ai/src/airflow/providers/common/ai/hooks/base.py: ########## @@ -0,0 +1,415 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Shared contract for agent-framework hooks used by :class:`~airflow.providers.common.ai.operators.agent.AgentOperator`.""" + +from __future__ import annotations + +import functools +import json +import time +from abc import ABCMeta, abstractmethod +from collections.abc import Callable, Sequence +from dataclasses import dataclass, field +from enum import Enum +from typing import Any, ClassVar, Generic, TypeVar + +from airflow.providers.common.ai.utils.callables import is_async_callable +from airflow.providers.common.ai.utils.function_schema import callable_to_tool_spec +from airflow.providers.common.compat.sdk import BaseHook + +AgentT = TypeVar("AgentT") + + +class Capability(str, Enum): + """ + Capability tokens declared by concrete hook classes. + + A hook advertises its support by including the relevant tokens in its + :attr:`BaseAIHook.capabilities` frozenset. + :meth:`BaseAIHook.validate_run_request` rejects requests that use a + feature whose token is absent. + """ + + TOOLSETS = "toolsets" + USAGE_LIMITS = "usage_limits" + DURABLE = "durable" + + +@dataclass +class AgentUsage: + """Token and request usage from an agent run, when the backend exposes it.""" + + requests: int | None = None + tool_calls: int | None = None + input_tokens: int | None = None + output_tokens: int | None = None + total_tokens: int | None = None + + +@dataclass +class DurableStats: + """Step-level cache statistics from a durable agent run.""" + + replayed_model: int = 0 + replayed_tool: int = 0 + cached_model: int = 0 + cached_tool: int = 0 + + +@dataclass +class AgentRunResult: + """ + Backend-neutral result from :meth:`BaseAIHook.run_agent`. + + :param output: Final agent output (``str``, Pydantic model instance, etc.). + :param message_history: Opaque conversation state for HITL regeneration; only pass back to the + same hook implementation that produced it. + :param model_name: Resolved model identifier, when available. + :param usage: Usage counters when the backend exposes them. + :param tool_names: Ordered tool names invoked during the run, when known. + :param durable_stats: Durable step-cache statistics, populated when durable execution is enabled. + """ + + output: Any + message_history: Any = None + model_name: str | None = None + usage: AgentUsage | None = None + tool_names: list[str] | None = None + durable_stats: DurableStats | None = None + + +@dataclass +class ToolSpec: + """ + Framework-neutral tool descriptor. + + Toolsets produce :class:`ToolSpec` objects; each hook converts them to its + native tool representation via :meth:`BaseAIHook._tool_spec_to_native`. + + :param name: Tool name exposed to the LLM. + :param description: Human-readable description used by the LLM to decide when to call this tool. + :param parameters: JSON Schema ``object`` describing the tool's parameters. + :param fn: Callable that implements the tool. Must accept keyword arguments matching *parameters*. + :param sequential: When ``True``, the backend must not invoke this tool concurrently with others + in the same turn (for example when tools share a non-thread-safe connection). + """ + + name: str + description: str + parameters: dict[str, Any] + fn: Callable[..., Any] + sequential: bool = False + + +@dataclass +class DurableContext: + """Framework-neutral identity of the running task, used to locate the durable cache file.""" + + dag_id: str + task_id: str + run_id: str + map_index: int = -1 + + +@dataclass +class AgentRunRequest: + """ + Parameter object passed to :meth:`BaseAIHook.create_agent` and :meth:`BaseAIHook.run_agent`. + + Encapsulates everything the hook needs to build and run an agent in a single + framework-neutral structure, so that :class:`~airflow.providers.common.ai.operators.agent.AgentOperator` + has zero framework-specific imports. This contract is currently validated by + the pydantic-ai hook family and may evolve as more framework backends are added. + + :param prompt: User prompt for this invocation (plain ``str`` or a multimodal + ``Sequence`` accepted by the backend agent's run API). + :param output_type: Expected structured output type or backend-specific JSON schema + mapping (default: ``str``). + :param instructions: System-level instructions for the agent. + :param toolsets: List of tools/toolsets the agent may call (BaseToolset instances, plain callables, or backend-native tool objects). + :param usage_limits: Backend-specific usage limits; ignored if the hook does not support them. + :param message_history: Prior conversation state from a previous :class:`AgentRunResult`. + :param enable_tool_logging: When ``True`` (default), wraps Airflow-resolved tool callables with + a logging shim. Backend-native tool objects may be passed through unchanged by the concrete + hook and might not receive this wrapper. + :param durable_context: When set, enables step-level durable caching for the run. + :param agent_params: Extra keyword arguments forwarded to the underlying agent constructor. + Use this escape hatch for framework-specific options. + """ + + prompt: str | Sequence[Any] + output_type: type[Any] | dict[str, Any] | None = str Review Comment: Widening this to `type | dict | None` matches what I asked for earlier, thanks. One catch though: the only backend passes it straight to `Agent(output_type=...)` (pydantic_ai.py:302-303), and on pydantic-ai-slim 1.96 a raw JSON-schema dict raises `SchemaError: Unknown schema type: "object"`. No caller hits it today (both operators type `output_type: type`), so it's latent, but the annotation now invites a dict that the backend rejects. Worth either dropping the `dict` arm until a backend can consume it, or noting in the docstring that it's reserved for backends that accept a raw schema. ########## providers/common/ai/src/airflow/providers/common/ai/toolsets/sql.py: ########## @@ -99,7 +91,7 @@ _SQLALCHEMY_RETRYABLE_EXCEPTIONS = (_SQLAlchemyProgrammingError,) -class SQLToolset(AbstractToolset[Any]): +class SQLToolset(BaseToolset): Review Comment: `SQLToolset` shipped in 0.3.0 as a `pydantic_ai` `AbstractToolset` with public `get_tools`/`call_tool`; here it becomes a `BaseToolset` with only `as_tools()`. The in-tree operator and langchain-bridge paths are migrated, but anyone on 0.3.0 passing `SQLToolset(...)` directly to `pydantic_ai.Agent(toolsets=[...])` breaks since it no longer satisfies the toolset protocol. Not mandatory, but if it's cheap to keep back-compat (e.g. have `SQLToolset` still expose the `AbstractToolset` surface, or ship a thin shim) that would spare direct 0.3.0 users the break. If we decide it's not worth preserving, a changelog line below the `Changelog` header noting the protocol change would help. ########## providers/common/ai/src/airflow/providers/common/ai/toolsets/langchain_bridge.py: ########## @@ -170,3 +177,34 @@ async def _async_call(**kwargs: Any) -> Any: description=tool_def.description or name, args_schema=tool_def.parameters_json_schema, ) + + +def _build_structured_tool_from_spec( + spec: ToolSpec, + structured_tool_cls: type[StructuredTool], +) -> StructuredTool: + """Build a single LangChain ``StructuredTool`` from an Airflow ``ToolSpec``.""" + + def _sync_call(**kwargs: Any) -> Any: + try: + if asyncio.iscoroutinefunction(spec.fn): Review Comment: This uses raw `asyncio.iscoroutinefunction(spec.fn)`, but the provider already has `is_async_callable` in `utils/callables.py` which also handles async-`__call__` callable objects and unwraps `functools.partial`. Reusing it here keeps async detection consistent with `_logged_callable` / `_cached_callable`. Same at line 198. ########## providers/common/ai/src/airflow/providers/common/ai/hooks/pydantic_ai.py: ########## @@ -172,41 +196,195 @@ def _provider_factory(pname: str) -> Any: self._model = infer_model(model_name) return self._model - @overload - def create_agent( - self, output_type: type[OutputT], *, instructions: str, **agent_kwargs - ) -> Agent[None, OutputT]: ... + # ------------------------------------------------------------------ + # BaseAIHook abstract interface + # ------------------------------------------------------------------ - @overload - def create_agent(self, *, instructions: str, **agent_kwargs) -> Agent[None, str]: ... + def _tool_spec_to_native(self, spec: ToolSpec) -> Any: + """Convert a :class:`~airflow.providers.common.ai.hooks.base.ToolSpec` to a pydantic-ai ``Tool``.""" + return Tool.from_schema( + spec.fn, + name=spec.name, + description=spec.description, + sequential=spec.sequential, + json_schema=spec.parameters, + ) - def create_agent( - self, output_type: type[Any] = str, *, instructions: str, **agent_kwargs - ) -> Agent[None, Any]: + def _build_agent(self, request: AgentRunRequest) -> PydanticAgentHandle: """ - Create a pydantic-ai Agent configured with this hook's model. + Build a pydantic-ai ``Agent`` handle from *request*. - When ``[common.ai] otel_export_enabled`` is set and the worker has an - OpenTelemetry exporter configured, the agent is instrumented to emit - GenAI spans through Airflow's tracing pipeline. See + When :attr:`~AgentRunRequest.durable_context` is set, initialises durable + storage and step counter and returns them alongside the native agent for use + by :meth:`run_agent`. When ``[common.ai] otel_export_enabled`` is set and the + worker has an OpenTelemetry exporter configured, the agent is instrumented to + emit GenAI spans through Airflow's tracing pipeline. See :mod:`airflow.providers.common.ai.observability`. - :param output_type: The expected output type from the agent (default: ``str``). - :param instructions: System-level instructions for the agent. - :param agent_kwargs: Additional keyword arguments passed to the Agent constructor. + :param request: Agent configuration — output type, instructions, toolsets, extra params. + + Native pydantic-ai ``Tool`` instances supplied in ``request.toolsets`` are passed through + unchanged. Airflow tool logging and durable tool-result caching are applied to + framework-neutral callables / ``BaseToolset`` specs and pydantic-ai ``AbstractToolset`` + instances, but not to native ``Tool`` instances. """ - agent = Agent(self.get_conn(), output_type=output_type, instructions=instructions, **agent_kwargs) - if "instrument" not in agent_kwargs: - # Set the public ``agent.instrument`` surface rather than the + durable_state = None + if request.durable_context is not None: + durable_state = self._init_durable(request.durable_context) + + extra_kwargs = dict(request.agent_params or {}) + if request.toolsets: + if "tools" in extra_kwargs: + raise ValueError( + "agent_params must not include 'tools' when toolsets= is set on AgentRunRequest." + ) + if "toolsets" in extra_kwargs: + raise ValueError( + "agent_params must not include 'toolsets' when toolsets= is set on AgentRunRequest." + ) + + abstract_items = [ts for ts in request.toolsets if isinstance(ts, AbstractToolset)] + pipeline_items = [ts for ts in request.toolsets if not isinstance(ts, AbstractToolset)] Review Comment: Routing every non-`AbstractToolset` item through the callable-to-function-tool pipeline reinterprets pydantic-ai's `ToolsetFunc` (a callable returning an `AbstractToolset`, which `Agent(toolsets=...)` accepts). The native-`Tool` case is preserved at L252-257, but a bare `ToolsetFunc` isn't a `Tool` instance, so it falls to the `else` and gets rebuilt as a model-callable tool from its `__call__` signature. Previously the whole list was forwarded to the agent, so a `ToolsetFunc` worked. No in-tree consumer passes one and treating bare callables as tools looks deliberate, so just checking: was dropping the `ToolsetFunc` pass-through intentional? -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected]
