kaxil commented on code in PR #67438:
URL: https://github.com/apache/airflow/pull/67438#discussion_r3307560341
##########
providers/common/ai/src/airflow/providers/common/ai/hooks/pydantic_ai.py:
##########
@@ -171,25 +179,146 @@ def _provider_factory(pname: str) -> Any:
self._model = infer_model(model_name)
return self._model
- @overload
- def create_agent(
- self, output_type: type[OutputT], *, instructions: str, **agent_kwargs
- ) -> Agent[None, OutputT]: ...
+ # ------------------------------------------------------------------
+ # BaseAIHook abstract interface
+ # ------------------------------------------------------------------
+
+ def _tool_spec_to_native(self, spec: ToolSpec) -> Any:
+ """Convert a
:class:`~airflow.providers.common.ai.hooks.base_ai.ToolSpec` to a pydantic-ai
``Tool``."""
+ from pydantic_ai.tools import Tool
Review Comment:
`pydantic_ai` is already imported at module top (`from pydantic_ai import
Agent`), so there's no circular-import or optional-dep reason for this inline
`from pydantic_ai.tools import Tool`. Same for `from
pydantic_ai.toolsets.abstract import AbstractToolset` at L213, `from
pydantic_ai.messages import ToolCallPart` at L259, and the three
`airflow.providers.common.ai.durable.*` / `toolsets.logging` imports at
L231/L242/L272 (`base_ai.py` doesn't import any of those, so hoisting them to
module top is safe). Worth pulling all six to the top of the file.
##########
providers/common/ai/src/airflow/providers/common/ai/hooks/pydantic_ai.py:
##########
@@ -171,25 +179,146 @@ def _provider_factory(pname: str) -> Any:
self._model = infer_model(model_name)
return self._model
- @overload
- def create_agent(
- self, output_type: type[OutputT], *, instructions: str, **agent_kwargs
- ) -> Agent[None, OutputT]: ...
+ # ------------------------------------------------------------------
+ # BaseAIHook abstract interface
+ # ------------------------------------------------------------------
+
+ def _tool_spec_to_native(self, spec: ToolSpec) -> Any:
+ """Convert a
:class:`~airflow.providers.common.ai.hooks.base_ai.ToolSpec` to a pydantic-ai
``Tool``."""
+ from pydantic_ai.tools import Tool
- @overload
- def create_agent(self, *, instructions: str, **agent_kwargs) ->
Agent[None, str]: ...
+ return Tool(
+ spec.fn,
+ name=spec.name,
+ description=spec.description,
+ sequential=spec.sequential,
+ )
- def create_agent(
- self, output_type: type[Any] = str, *, instructions: str,
**agent_kwargs
- ) -> Agent[None, Any]:
+ def create_agent(self, request: AgentRunRequest) -> Agent[None, Any]:
Review Comment:
Breaking signature change: in 0.2.0 this was `create_agent(output_type=...,
instructions=..., **agent_kwargs)` (public, documented in the prior changelog
block); in 0.3.0 it becomes `create_agent(request: AgentRunRequest)`. Anyone
calling the hook directly between releases breaks at the call site. Worth a
short prose migration note just below the `Changelog` header in
`providers/common/ai/docs/changelog.rst` per the `NOTE TO CONTRIBUTORS` block
in that file.
##########
providers/common/ai/src/airflow/providers/common/ai/hooks/pydantic_ai.py:
##########
@@ -171,25 +179,146 @@ def _provider_factory(pname: str) -> Any:
self._model = infer_model(model_name)
return self._model
- @overload
- def create_agent(
- self, output_type: type[OutputT], *, instructions: str, **agent_kwargs
- ) -> Agent[None, OutputT]: ...
+ # ------------------------------------------------------------------
+ # BaseAIHook abstract interface
+ # ------------------------------------------------------------------
+
+ def _tool_spec_to_native(self, spec: ToolSpec) -> Any:
+ """Convert a
:class:`~airflow.providers.common.ai.hooks.base_ai.ToolSpec` to a pydantic-ai
``Tool``."""
+ from pydantic_ai.tools import Tool
- @overload
- def create_agent(self, *, instructions: str, **agent_kwargs) ->
Agent[None, str]: ...
+ return Tool(
+ spec.fn,
+ name=spec.name,
+ description=spec.description,
+ sequential=spec.sequential,
+ )
- def create_agent(
- self, output_type: type[Any] = str, *, instructions: str,
**agent_kwargs
- ) -> Agent[None, Any]:
+ def create_agent(self, request: AgentRunRequest) -> Agent[None, Any]:
"""
- Create a pydantic-ai Agent configured with this hook's model.
+ Build a pydantic-ai ``Agent`` from *request*.
- :param output_type: The expected output type from the agent (default:
``str``).
- :param instructions: System-level instructions for the agent.
- :param agent_kwargs: Additional keyword arguments passed to the Agent
constructor.
+ When :attr:`~AgentRunRequest.durable_context` is set, initialises
durable
+ storage and step counter and binds them to the returned agent for use
by
+ :meth:`run_agent`.
+
+ :param request: Agent configuration — output type, instructions,
toolsets, extra params.
"""
- return Agent(self.get_conn(), output_type=output_type,
instructions=instructions, **agent_kwargs)
+ storage = counter = None
+ if request.durable_context is not None:
+ storage, counter = self._init_durable(request.durable_context)
+
+ extra_kwargs = dict(request.agent_params or {})
+ if request.toolsets:
+ from pydantic_ai.toolsets.abstract import AbstractToolset
+
+ abstract_items = [ts for ts in request.toolsets if isinstance(ts,
AbstractToolset)]
+ pipeline_items = [ts for ts in request.toolsets if not
isinstance(ts, AbstractToolset)]
+
+ if pipeline_items:
+ resolved = self._resolve_tools(
+ pipeline_items,
+ request.enable_tool_logging,
+ storage,
+ counter,
+ )
+ if resolved:
Review Comment:
`if pipeline_items:` (L218) already guards this branch, and `_resolve_tools`
returns a list with at least one entry per input item, so `resolved` cannot be
empty here -- the `if resolved:` is dead defensive code. (The parallel `if
abstract_items:` check at L228 is what genuinely guards the `toolsets`
assignment below, so that one is fine.)
##########
providers/common/ai/src/airflow/providers/common/ai/hooks/base_ai.py:
##########
@@ -0,0 +1,399 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Shared contract for agent-framework hooks used by
:class:`~airflow.providers.common.ai.operators.agent.AgentOperator`."""
+
+from __future__ import annotations
+
+import functools
+import inspect
+import json
+import time
+from abc import ABCMeta, abstractmethod
+from collections.abc import Callable, Sequence
+from dataclasses import dataclass, field
+from typing import Any, ClassVar
+
+from airflow.providers.common.compat.sdk import BaseHook
+
+_EMPTY_OBJECT_SCHEMA: dict[str, Any] = {"type": "object", "properties": {}}
+
+# Durable storage/counter pairs keyed by ``id(agent)``.
+# pydantic-ai ``Agent`` is not hashable, so ``WeakKeyDictionary`` cannot be
used.
+# ``create_agent`` and ``run_agent`` run synchronously in the same task, so
``id()``
+# is stable until ``_pop_agent_durable`` removes the entry.
+_AGENT_DURABLE: dict[int, tuple[Any, Any]] = {}
Review Comment:
Storing `(storage, counter)` in a module-level dict keyed by `id(agent)`
works for the happy path (`create_agent` then `run_agent`), but every code path
that calls `create_agent` without ever reaching `run_agent` (exception in
between, test cleanup that drops the agent, user code that builds an agent and
never runs it) leaks the entry until process death. Simpler and leak-free:
stash on the agent itself, e.g. `agent._airflow_durable = (storage, counter)`
and pop via `delattr`. The comment says `WeakKeyDictionary` is out because
pydantic-ai `Agent` isn't hashable, but a direct attribute on the instance
sidesteps that and doesn't need the global table at all.
##########
providers/common/ai/src/airflow/providers/common/ai/hooks/pydantic_ai.py:
##########
@@ -171,25 +179,146 @@ def _provider_factory(pname: str) -> Any:
self._model = infer_model(model_name)
return self._model
- @overload
- def create_agent(
- self, output_type: type[OutputT], *, instructions: str, **agent_kwargs
- ) -> Agent[None, OutputT]: ...
+ # ------------------------------------------------------------------
+ # BaseAIHook abstract interface
+ # ------------------------------------------------------------------
+
+ def _tool_spec_to_native(self, spec: ToolSpec) -> Any:
+ """Convert a
:class:`~airflow.providers.common.ai.hooks.base_ai.ToolSpec` to a pydantic-ai
``Tool``."""
+ from pydantic_ai.tools import Tool
- @overload
- def create_agent(self, *, instructions: str, **agent_kwargs) ->
Agent[None, str]: ...
+ return Tool(
+ spec.fn,
+ name=spec.name,
+ description=spec.description,
+ sequential=spec.sequential,
+ )
- def create_agent(
- self, output_type: type[Any] = str, *, instructions: str,
**agent_kwargs
- ) -> Agent[None, Any]:
+ def create_agent(self, request: AgentRunRequest) -> Agent[None, Any]:
"""
- Create a pydantic-ai Agent configured with this hook's model.
+ Build a pydantic-ai ``Agent`` from *request*.
- :param output_type: The expected output type from the agent (default:
``str``).
- :param instructions: System-level instructions for the agent.
- :param agent_kwargs: Additional keyword arguments passed to the Agent
constructor.
+ When :attr:`~AgentRunRequest.durable_context` is set, initialises
durable
+ storage and step counter and binds them to the returned agent for use
by
+ :meth:`run_agent`.
+
+ :param request: Agent configuration — output type, instructions,
toolsets, extra params.
"""
- return Agent(self.get_conn(), output_type=output_type,
instructions=instructions, **agent_kwargs)
+ storage = counter = None
+ if request.durable_context is not None:
+ storage, counter = self._init_durable(request.durable_context)
+
+ extra_kwargs = dict(request.agent_params or {})
+ if request.toolsets:
+ from pydantic_ai.toolsets.abstract import AbstractToolset
+
+ abstract_items = [ts for ts in request.toolsets if isinstance(ts,
AbstractToolset)]
+ pipeline_items = [ts for ts in request.toolsets if not
isinstance(ts, AbstractToolset)]
+
+ if pipeline_items:
+ resolved = self._resolve_tools(
+ pipeline_items,
+ request.enable_tool_logging,
+ storage,
+ counter,
+ )
+ if resolved:
+ extra_kwargs["tools"] = resolved
Review Comment:
`extra_kwargs = dict(request.agent_params or {})` then
`extra_kwargs["tools"] = resolved` here and `extra_kwargs["toolsets"] =
processed` at L245 will silently overwrite anything the user passed via
`agent_params={"tools": [native_tool]}` or `agent_params={"toolsets": [...]}`.
Either merge (e.g. `extra_kwargs.setdefault("tools", []).extend(resolved)`) or
raise a clear error when both are supplied -- otherwise users debugging "why
isn't my native tool being called" have to dig in here to find out.
##########
providers/common/ai/src/airflow/providers/common/ai/hooks/pydantic_ai.py:
##########
@@ -171,25 +179,146 @@ def _provider_factory(pname: str) -> Any:
self._model = infer_model(model_name)
return self._model
- @overload
- def create_agent(
- self, output_type: type[OutputT], *, instructions: str, **agent_kwargs
- ) -> Agent[None, OutputT]: ...
+ # ------------------------------------------------------------------
+ # BaseAIHook abstract interface
+ # ------------------------------------------------------------------
+
+ def _tool_spec_to_native(self, spec: ToolSpec) -> Any:
+ """Convert a
:class:`~airflow.providers.common.ai.hooks.base_ai.ToolSpec` to a pydantic-ai
``Tool``."""
+ from pydantic_ai.tools import Tool
- @overload
- def create_agent(self, *, instructions: str, **agent_kwargs) ->
Agent[None, str]: ...
+ return Tool(
+ spec.fn,
+ name=spec.name,
+ description=spec.description,
+ sequential=spec.sequential,
+ )
- def create_agent(
- self, output_type: type[Any] = str, *, instructions: str,
**agent_kwargs
- ) -> Agent[None, Any]:
+ def create_agent(self, request: AgentRunRequest) -> Agent[None, Any]:
"""
- Create a pydantic-ai Agent configured with this hook's model.
+ Build a pydantic-ai ``Agent`` from *request*.
- :param output_type: The expected output type from the agent (default:
``str``).
- :param instructions: System-level instructions for the agent.
- :param agent_kwargs: Additional keyword arguments passed to the Agent
constructor.
+ When :attr:`~AgentRunRequest.durable_context` is set, initialises
durable
+ storage and step counter and binds them to the returned agent for use
by
+ :meth:`run_agent`.
+
+ :param request: Agent configuration — output type, instructions,
toolsets, extra params.
"""
- return Agent(self.get_conn(), output_type=output_type,
instructions=instructions, **agent_kwargs)
+ storage = counter = None
+ if request.durable_context is not None:
+ storage, counter = self._init_durable(request.durable_context)
+
+ extra_kwargs = dict(request.agent_params or {})
+ if request.toolsets:
+ from pydantic_ai.toolsets.abstract import AbstractToolset
+
+ abstract_items = [ts for ts in request.toolsets if isinstance(ts,
AbstractToolset)]
+ pipeline_items = [ts for ts in request.toolsets if not
isinstance(ts, AbstractToolset)]
+
+ if pipeline_items:
+ resolved = self._resolve_tools(
+ pipeline_items,
+ request.enable_tool_logging,
+ storage,
+ counter,
+ )
+ if resolved:
+ extra_kwargs["tools"] = resolved
+
+ if abstract_items:
+ processed: list[Any] = list(abstract_items)
+ if storage is not None and counter is not None:
+ from airflow.providers.common.ai.durable.caching_toolset
import CachingToolset
+
+ processed = [
+ CachingToolset(
+ wrapped=ts,
+ storage=storage,
+ counter=counter,
+ )
+ for ts in processed
+ ]
+ if request.enable_tool_logging:
+ from airflow.providers.common.ai.toolsets.logging import
LoggingToolset
+
+ processed = [LoggingToolset(wrapped=ts, logger=self.log)
for ts in processed]
+ extra_kwargs["toolsets"] = processed
+
+ agent = Agent(
+ self.get_model(),
+ output_type=request.output_type,
+ instructions=request.instructions,
+ **extra_kwargs,
+ )
+ if storage is not None and counter is not None:
+ self._bind_agent_durable(agent, storage, counter)
+ return agent
+
+ def run_agent(self, agent: Agent[None, Any], request: AgentRunRequest) ->
AgentRunResult:
+ """Run *agent* synchronously for *request* and return a normalized
:class:`~airflow.providers.common.ai.hooks.base_ai.AgentRunResult`."""
+ from pydantic_ai.messages import ToolCallPart
+
+ run_kwargs: dict[str, Any] = {}
+ if request.message_history is not None:
+ run_kwargs["message_history"] = request.message_history
+ if request.usage_limits is not None:
+ run_kwargs["usage_limits"] = request.usage_limits
+
+ durable = self._pop_agent_durable(agent)
+ storage, counter = durable if durable else (None, None)
+
+ try:
+ if storage is not None and counter is not None:
+ from airflow.providers.common.ai.durable.caching_model import
CachingModel
+
+ if agent.model is None:
+ raise ValueError("Agent model must be set when
durable=True")
+ model = agent.model
+ resolved_model = infer_model(model) if isinstance(model, str)
else model
+ caching_model = CachingModel(
+ resolved_model,
+ storage=storage,
+ counter=counter,
+ )
+ with agent.override(model=caching_model):
+ result = agent.run_sync(request.prompt, **run_kwargs)
+ else:
+ result = agent.run_sync(request.prompt, **run_kwargs)
+
+ usage = result.usage
+ tool_names: list[str] = []
+ for message in result.all_messages():
+ for part in getattr(message, "parts", []):
+ if isinstance(part, ToolCallPart):
+ tool_names.append(part.tool_name)
+
+ run_result = AgentRunResult(
+ output=result.output,
+ message_history=result.all_messages(),
+ model_name=getattr(result.response, "model_name", None),
+ usage=AgentUsage(
+ requests=usage.requests,
+ tool_calls=usage.tool_calls,
+ input_tokens=usage.input_tokens,
+ output_tokens=usage.output_tokens,
+ total_tokens=usage.total_tokens,
+ ),
+ tool_names=tool_names or None,
+ )
+
+ if counter is not None:
+ run_result.durable_stats = DurableStats(
+ replayed_model=counter.replayed_model,
+ replayed_tool=counter.replayed_tool,
+ cached_model=counter.cached_model,
+ cached_tool=counter.cached_tool,
+ )
+ except BaseException:
Review Comment:
`except BaseException: raise` is a no-op (it re-raises the same exception
unchanged), so the whole `try/except/else` collapses to "build `run_result`,
then on success call `storage.cleanup()`". That's expressible without the
`except` branch -- e.g. put the cleanup directly after the `run_result`
assignment and let exceptions propagate naturally. As written it reads like
there's an intentional exception transform happening, which there isn't.
##########
providers/common/ai/src/airflow/providers/common/ai/operators/llm.py:
##########
@@ -108,19 +107,9 @@ def __init__(
self.allow_modifications = allow_modifications
@cached_property
- def llm_hook(self) -> PydanticAIHook:
- """
- Return the correct PydanticAIHook subclass for the configured
connection.
-
- Delegates to :meth:`~PydanticAIHook.get_hook` which looks up
- the connection's ``conn_type`` and instantiates the matching subclass
- (e.g.
:class:`~airflow.providers.common.ai.hooks.pydantic_ai.PydanticAIAzureHook`
- for ``pydanticai-azure`` connections).
- """
- hook_params = {
- "model_id": self.model_id,
- }
- return PydanticAIHook.get_hook(self.llm_conn_id,
hook_params=hook_params)
+ def llm_hook(self) -> BaseAIHook:
+ """Return the agent hook for the configured connection."""
+ return BaseAIHook.get_agent_hook(self.llm_conn_id,
hook_params={"model_id": self.model_id})
def execute(self, context: Context) -> Any:
Review Comment:
`AgentOperator._validate_hook_capabilities()` rejects unsupported
`toolsets`/`usage_limits`/`durable` upfront, but `LLMOperator` (and
`LLMBranchOperator`, `LLMFileAnalysisOperator`, `LLMSchemaCompareOperator`,
`LLMSQLQueryOperator`) skip that check entirely. A future `BaseAIHook` subclass
with `supports_usage_limits = False` will silently drop `self.usage_limits`
here, because `run_agent` simply doesn't pass it through. Either factor
`_validate_hook_capabilities` into a mixin shared with `LLMOperator`, or have
`run_agent` raise when `request.usage_limits is not None` and
`supports_usage_limits` is False so the mismatch fails loudly instead of
silently.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: commits-unsubscribe@airflow.apache.org
For queries about this service, please contact Infrastructure at:
users@infra.apache.org