kaxil commented on code in PR #67450: URL: https://github.com/apache/airflow/pull/67450#discussion_r3321449583
########## providers/common/ai/src/airflow/providers/common/ai/hooks/strands_ai.py: ########## @@ -0,0 +1,218 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Hooks for LLM agents via the Strands Agents SDK.""" + +from __future__ import annotations + +import functools +from abc import abstractmethod +from typing import Any + +from strands import Agent, AgentSkills, Skill, tool as strands_tool +from strands.models.gemini import GeminiModel + +from airflow.providers.common.ai.hooks.base_ai import ( + AgentRunRequest, + AgentRunResult, + BaseAIHook, + SkillSpec, + ToolSpec, + tool_identifier, +) + + +class StrandsHook(BaseAIHook): + """ + Base hook for LLM agents via `Strands Agents <https://strandsagents.com/>`__. + + Subclasses implement :meth:`get_model` to return a configured Strands model instance + (for example :class:`strands.models.gemini.GeminiModel`). The + :meth:`create_agent`, :meth:`run_agent`, and :meth:`_tool_spec_to_native` + implementations are shared across all Strands model backends. + """ + + conn_name_attr = "llm_conn_id" + default_conn_name = "strands_default" + + supports_toolsets = True + supports_durable = False + supports_usage_limits = False + supports_skills = True + + def __init__( + self, + llm_conn_id: str | None = None, + model_id: str | None = None, + **kwargs: Any, + ) -> None: + super().__init__(**kwargs) + self.llm_conn_id = llm_conn_id if llm_conn_id is not None else self.default_conn_name + self.model_id = model_id + self._resolved_model_id: str | None = None + + @abstractmethod + def get_model(self) -> Any: + """Return a configured Strands model instance.""" + + def _tool_spec_to_native(self, spec: ToolSpec) -> Any: + """Convert a :class:`~airflow.providers.common.ai.hooks.base_ai.ToolSpec` to a Strands tool.""" + fn = spec.fn + + # Strands infers tool name from __name__ and description from __doc__. + # functools.wraps preserves __wrapped__ so inspect.signature() follows it + # for parameter schema inference, then we override name/doc from spec. + @functools.wraps(fn) + def tool_fn(*args: Any, **kwargs: Any) -> Any: + return fn(*args, **kwargs) + + tool_fn.__name__ = tool_identifier(spec.name) + tool_fn.__doc__ = spec.description + return strands_tool(tool_fn) + + def _skill_spec_to_native(self, skill: str | SkillSpec) -> Any: + """Convert a skill source to a Strands-native skill object or path.""" + if isinstance(skill, SkillSpec) and not skill.path: + return Skill( + name=skill.name, + description=skill.description, + instructions=skill.instructions, + ) + return super()._skill_spec_to_native(skill) + + def _build_skills_plugin(self, request: AgentRunRequest) -> Any | None: + """Build a Strands ``AgentSkills`` plugin when skill sources are configured.""" + sources = self._resolve_skill_sources(request) + if not sources: + return None + + skills_arg: Any = sources[0] if len(sources) == 1 else sources + return AgentSkills(skills=skills_arg, **dict(request.skills_params or {})) + + def create_agent(self, request: AgentRunRequest) -> Any: + """Build a Strands ``Agent`` from *request*.""" + self.validate_run_request(request) + + native_tools: list[Any] = [] + if request.toolsets: + native_tools = self._resolve_tools( + request.toolsets, + request.enable_tool_logging, + None, # durable execution is not supported for Strands + None, + ) + + agent_kwargs: dict[str, Any] = dict(request.agent_params or {}) + if request.instructions: + agent_kwargs["system_prompt"] = request.instructions + + plugins: list[Any] = list(agent_kwargs.pop("plugins", []) or []) + skills_plugin = self._build_skills_plugin(request) + if skills_plugin is not None: + plugins.append(skills_plugin) + if plugins: + agent_kwargs["plugins"] = plugins + + return Agent(model=self.get_model(), tools=native_tools or [], **agent_kwargs) + + def run_agent(self, agent: Any, request: AgentRunRequest) -> AgentRunResult: + """Run the Strands *agent* for *request* and return a normalized :class:`AgentRunResult`.""" + response = agent(request.prompt) + return AgentRunResult( + output=str(response), + model_name=self._resolved_model_id or self.model_id, + ) Review Comment: This constructs `AgentRunResult` with only `output` and `model_name`, so `usage` and `tool_names` stay at their `None` defaults. `log_run_summary` then skips both the token counts (`logging.py:35`) and the tool-call sequence (`logging.py:49`) for Strands even when tools ran, whereas the pydantic-ai path populates both. Strands' result exposes usage metrics and a tool trace; mapping them into `AgentUsage` / `tool_names` here would match the pydantic-ai path. Use `getattr(..., None)` guards if the metrics shape varies across SDK versions. (Low priority: no crash, just missing observability.) ########## providers/common/ai/src/airflow/providers/common/ai/hooks/strands_ai.py: ########## @@ -0,0 +1,218 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Hooks for LLM agents via the Strands Agents SDK.""" + +from __future__ import annotations + +import functools +from abc import abstractmethod +from typing import Any + +from strands import Agent, AgentSkills, Skill, tool as strands_tool +from strands.models.gemini import GeminiModel + +from airflow.providers.common.ai.hooks.base_ai import ( + AgentRunRequest, + AgentRunResult, + BaseAIHook, + SkillSpec, + ToolSpec, + tool_identifier, +) + + +class StrandsHook(BaseAIHook): + """ + Base hook for LLM agents via `Strands Agents <https://strandsagents.com/>`__. + + Subclasses implement :meth:`get_model` to return a configured Strands model instance + (for example :class:`strands.models.gemini.GeminiModel`). The + :meth:`create_agent`, :meth:`run_agent`, and :meth:`_tool_spec_to_native` + implementations are shared across all Strands model backends. + """ + + conn_name_attr = "llm_conn_id" + default_conn_name = "strands_default" + + supports_toolsets = True + supports_durable = False + supports_usage_limits = False + supports_skills = True + + def __init__( + self, + llm_conn_id: str | None = None, + model_id: str | None = None, + **kwargs: Any, + ) -> None: + super().__init__(**kwargs) + self.llm_conn_id = llm_conn_id if llm_conn_id is not None else self.default_conn_name + self.model_id = model_id + self._resolved_model_id: str | None = None + + @abstractmethod + def get_model(self) -> Any: + """Return a configured Strands model instance.""" + + def _tool_spec_to_native(self, spec: ToolSpec) -> Any: + """Convert a :class:`~airflow.providers.common.ai.hooks.base_ai.ToolSpec` to a Strands tool.""" + fn = spec.fn + + # Strands infers tool name from __name__ and description from __doc__. + # functools.wraps preserves __wrapped__ so inspect.signature() follows it + # for parameter schema inference, then we override name/doc from spec. + @functools.wraps(fn) + def tool_fn(*args: Any, **kwargs: Any) -> Any: + return fn(*args, **kwargs) + + tool_fn.__name__ = tool_identifier(spec.name) + tool_fn.__doc__ = spec.description + return strands_tool(tool_fn) + + def _skill_spec_to_native(self, skill: str | SkillSpec) -> Any: + """Convert a skill source to a Strands-native skill object or path.""" + if isinstance(skill, SkillSpec) and not skill.path: + return Skill( + name=skill.name, + description=skill.description, + instructions=skill.instructions, + ) + return super()._skill_spec_to_native(skill) + + def _build_skills_plugin(self, request: AgentRunRequest) -> Any | None: + """Build a Strands ``AgentSkills`` plugin when skill sources are configured.""" + sources = self._resolve_skill_sources(request) + if not sources: + return None + + skills_arg: Any = sources[0] if len(sources) == 1 else sources + return AgentSkills(skills=skills_arg, **dict(request.skills_params or {})) + + def create_agent(self, request: AgentRunRequest) -> Any: + """Build a Strands ``Agent`` from *request*.""" + self.validate_run_request(request) + + native_tools: list[Any] = [] + if request.toolsets: + native_tools = self._resolve_tools( + request.toolsets, + request.enable_tool_logging, + None, # durable execution is not supported for Strands + None, + ) + + agent_kwargs: dict[str, Any] = dict(request.agent_params or {}) + if request.instructions: + agent_kwargs["system_prompt"] = request.instructions + + plugins: list[Any] = list(agent_kwargs.pop("plugins", []) or []) + skills_plugin = self._build_skills_plugin(request) + if skills_plugin is not None: + plugins.append(skills_plugin) + if plugins: + agent_kwargs["plugins"] = plugins + + return Agent(model=self.get_model(), tools=native_tools or [], **agent_kwargs) + + def run_agent(self, agent: Any, request: AgentRunRequest) -> AgentRunResult: + """Run the Strands *agent* for *request* and return a normalized :class:`AgentRunResult`.""" + response = agent(request.prompt) + return AgentRunResult( + output=str(response), Review Comment: `AgentOperator` documents `output_type` for structured output and threads it into `AgentRunRequest.output_type` (`agent.py:226`), and `PydanticAIHook.create_agent` honors it. `StrandsHook.run_agent` never reads `request.output_type` and hard-coerces `output=str(response)` here, so `AgentOperator(output_type=MyModel, llm_conn_id="strands_...")` silently gets a stringified blob with no error. The operator's `if isinstance(output, BaseModel)` branch (`agent.py:283`) can never fire for Strands. Either honor `output_type` (Strands supports structured output) and return the typed result, or add a `supports_output_type` ClassVar and reject it in `validate_run_request` (mirroring `supports_usage_limits`) so misuse fails fast instead of returning wrong data. Worth a test pinning whichever contract you pick. -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected]
