gopidesupavan commented on code in PR #67438:
URL: https://github.com/apache/airflow/pull/67438#discussion_r3309375031
##########
providers/common/ai/src/airflow/providers/common/ai/hooks/pydantic_ai.py:
##########
@@ -171,25 +179,146 @@ def _provider_factory(pname: str) -> Any:
self._model = infer_model(model_name)
return self._model
- @overload
- def create_agent(
- self, output_type: type[OutputT], *, instructions: str, **agent_kwargs
- ) -> Agent[None, OutputT]: ...
+ # ------------------------------------------------------------------
+ # BaseAIHook abstract interface
+ # ------------------------------------------------------------------
+
+ def _tool_spec_to_native(self, spec: ToolSpec) -> Any:
+ """Convert a
:class:`~airflow.providers.common.ai.hooks.base_ai.ToolSpec` to a pydantic-ai
``Tool``."""
+ from pydantic_ai.tools import Tool
- @overload
- def create_agent(self, *, instructions: str, **agent_kwargs) ->
Agent[None, str]: ...
+ return Tool(
+ spec.fn,
+ name=spec.name,
+ description=spec.description,
+ sequential=spec.sequential,
+ )
- def create_agent(
- self, output_type: type[Any] = str, *, instructions: str,
**agent_kwargs
- ) -> Agent[None, Any]:
+ def create_agent(self, request: AgentRunRequest) -> Agent[None, Any]:
"""
- Create a pydantic-ai Agent configured with this hook's model.
+ Build a pydantic-ai ``Agent`` from *request*.
- :param output_type: The expected output type from the agent (default:
``str``).
- :param instructions: System-level instructions for the agent.
- :param agent_kwargs: Additional keyword arguments passed to the Agent
constructor.
+ When :attr:`~AgentRunRequest.durable_context` is set, initialises
durable
+ storage and step counter and binds them to the returned agent for use
by
+ :meth:`run_agent`.
+
+ :param request: Agent configuration — output type, instructions,
toolsets, extra params.
"""
- return Agent(self.get_conn(), output_type=output_type,
instructions=instructions, **agent_kwargs)
+ storage = counter = None
+ if request.durable_context is not None:
+ storage, counter = self._init_durable(request.durable_context)
+
+ extra_kwargs = dict(request.agent_params or {})
+ if request.toolsets:
+ from pydantic_ai.toolsets.abstract import AbstractToolset
+
+ abstract_items = [ts for ts in request.toolsets if isinstance(ts,
AbstractToolset)]
+ pipeline_items = [ts for ts in request.toolsets if not
isinstance(ts, AbstractToolset)]
+
+ if pipeline_items:
+ resolved = self._resolve_tools(
+ pipeline_items,
+ request.enable_tool_logging,
+ storage,
+ counter,
+ )
+ if resolved:
+ extra_kwargs["tools"] = resolved
Review Comment:
updated
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]