kaxil commented on code in PR #67450:
URL: https://github.com/apache/airflow/pull/67450#discussion_r3321449583


##########
providers/common/ai/src/airflow/providers/common/ai/hooks/strands_ai.py:
##########
@@ -0,0 +1,218 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Hooks for LLM agents via the Strands Agents SDK."""
+
+from __future__ import annotations
+
+import functools
+from abc import abstractmethod
+from typing import Any
+
+from strands import Agent, AgentSkills, Skill, tool as strands_tool
+from strands.models.gemini import GeminiModel
+
+from airflow.providers.common.ai.hooks.base_ai import (
+    AgentRunRequest,
+    AgentRunResult,
+    BaseAIHook,
+    SkillSpec,
+    ToolSpec,
+    tool_identifier,
+)
+
+
+class StrandsHook(BaseAIHook):
+    """
+    Base hook for LLM agents via `Strands Agents 
<https://strandsagents.com/>`__.
+
+    Subclasses implement :meth:`get_model` to return a configured Strands 
model instance
+    (for example :class:`strands.models.gemini.GeminiModel`). The
+    :meth:`create_agent`, :meth:`run_agent`, and :meth:`_tool_spec_to_native`
+    implementations are shared across all Strands model backends.
+    """
+
+    conn_name_attr = "llm_conn_id"
+    default_conn_name = "strands_default"
+
+    supports_toolsets = True
+    supports_durable = False
+    supports_usage_limits = False
+    supports_skills = True
+
+    def __init__(
+        self,
+        llm_conn_id: str | None = None,
+        model_id: str | None = None,
+        **kwargs: Any,
+    ) -> None:
+        super().__init__(**kwargs)
+        self.llm_conn_id = llm_conn_id if llm_conn_id is not None else 
self.default_conn_name
+        self.model_id = model_id
+        self._resolved_model_id: str | None = None
+
+    @abstractmethod
+    def get_model(self) -> Any:
+        """Return a configured Strands model instance."""
+
+    def _tool_spec_to_native(self, spec: ToolSpec) -> Any:
+        """Convert a 
:class:`~airflow.providers.common.ai.hooks.base_ai.ToolSpec` to a Strands 
tool."""
+        fn = spec.fn
+
+        # Strands infers tool name from __name__ and description from __doc__.
+        # functools.wraps preserves __wrapped__ so inspect.signature() follows 
it
+        # for parameter schema inference, then we override name/doc from spec.
+        @functools.wraps(fn)
+        def tool_fn(*args: Any, **kwargs: Any) -> Any:
+            return fn(*args, **kwargs)
+
+        tool_fn.__name__ = tool_identifier(spec.name)
+        tool_fn.__doc__ = spec.description
+        return strands_tool(tool_fn)
+
+    def _skill_spec_to_native(self, skill: str | SkillSpec) -> Any:
+        """Convert a skill source to a Strands-native skill object or path."""
+        if isinstance(skill, SkillSpec) and not skill.path:
+            return Skill(
+                name=skill.name,
+                description=skill.description,
+                instructions=skill.instructions,
+            )
+        return super()._skill_spec_to_native(skill)
+
+    def _build_skills_plugin(self, request: AgentRunRequest) -> Any | None:
+        """Build a Strands ``AgentSkills`` plugin when skill sources are 
configured."""
+        sources = self._resolve_skill_sources(request)
+        if not sources:
+            return None
+
+        skills_arg: Any = sources[0] if len(sources) == 1 else sources
+        return AgentSkills(skills=skills_arg, **dict(request.skills_params or 
{}))
+
+    def create_agent(self, request: AgentRunRequest) -> Any:
+        """Build a Strands ``Agent`` from *request*."""
+        self.validate_run_request(request)
+
+        native_tools: list[Any] = []
+        if request.toolsets:
+            native_tools = self._resolve_tools(
+                request.toolsets,
+                request.enable_tool_logging,
+                None,  # durable execution is not supported for Strands
+                None,
+            )
+
+        agent_kwargs: dict[str, Any] = dict(request.agent_params or {})
+        if request.instructions:
+            agent_kwargs["system_prompt"] = request.instructions
+
+        plugins: list[Any] = list(agent_kwargs.pop("plugins", []) or [])
+        skills_plugin = self._build_skills_plugin(request)
+        if skills_plugin is not None:
+            plugins.append(skills_plugin)
+        if plugins:
+            agent_kwargs["plugins"] = plugins
+
+        return Agent(model=self.get_model(), tools=native_tools or [], 
**agent_kwargs)
+
+    def run_agent(self, agent: Any, request: AgentRunRequest) -> 
AgentRunResult:
+        """Run the Strands *agent* for *request* and return a normalized 
:class:`AgentRunResult`."""
+        response = agent(request.prompt)
+        return AgentRunResult(
+            output=str(response),
+            model_name=self._resolved_model_id or self.model_id,
+        )

Review Comment:
   This constructs `AgentRunResult` with only `output` and `model_name`, so 
`usage` and `tool_names` stay at their `None` defaults. `log_run_summary` then 
skips both the token counts (`logging.py:35`) and the tool-call sequence 
(`logging.py:49`) for Strands even when tools ran, whereas the pydantic-ai path 
populates both. Strands' result exposes usage metrics and a tool trace; mapping 
them into `AgentUsage` / `tool_names` here would match the pydantic-ai path. 
Use `getattr(..., None)` guards if the metrics shape varies across SDK 
versions. (Low priority: no crash, just missing observability.)



##########
providers/common/ai/src/airflow/providers/common/ai/hooks/strands_ai.py:
##########
@@ -0,0 +1,218 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Hooks for LLM agents via the Strands Agents SDK."""
+
+from __future__ import annotations
+
+import functools
+from abc import abstractmethod
+from typing import Any
+
+from strands import Agent, AgentSkills, Skill, tool as strands_tool
+from strands.models.gemini import GeminiModel
+
+from airflow.providers.common.ai.hooks.base_ai import (
+    AgentRunRequest,
+    AgentRunResult,
+    BaseAIHook,
+    SkillSpec,
+    ToolSpec,
+    tool_identifier,
+)
+
+
+class StrandsHook(BaseAIHook):
+    """
+    Base hook for LLM agents via `Strands Agents 
<https://strandsagents.com/>`__.
+
+    Subclasses implement :meth:`get_model` to return a configured Strands 
model instance
+    (for example :class:`strands.models.gemini.GeminiModel`). The
+    :meth:`create_agent`, :meth:`run_agent`, and :meth:`_tool_spec_to_native`
+    implementations are shared across all Strands model backends.
+    """
+
+    conn_name_attr = "llm_conn_id"
+    default_conn_name = "strands_default"
+
+    supports_toolsets = True
+    supports_durable = False
+    supports_usage_limits = False
+    supports_skills = True
+
+    def __init__(
+        self,
+        llm_conn_id: str | None = None,
+        model_id: str | None = None,
+        **kwargs: Any,
+    ) -> None:
+        super().__init__(**kwargs)
+        self.llm_conn_id = llm_conn_id if llm_conn_id is not None else 
self.default_conn_name
+        self.model_id = model_id
+        self._resolved_model_id: str | None = None
+
+    @abstractmethod
+    def get_model(self) -> Any:
+        """Return a configured Strands model instance."""
+
+    def _tool_spec_to_native(self, spec: ToolSpec) -> Any:
+        """Convert a 
:class:`~airflow.providers.common.ai.hooks.base_ai.ToolSpec` to a Strands 
tool."""
+        fn = spec.fn
+
+        # Strands infers tool name from __name__ and description from __doc__.
+        # functools.wraps preserves __wrapped__ so inspect.signature() follows 
it
+        # for parameter schema inference, then we override name/doc from spec.
+        @functools.wraps(fn)
+        def tool_fn(*args: Any, **kwargs: Any) -> Any:
+            return fn(*args, **kwargs)
+
+        tool_fn.__name__ = tool_identifier(spec.name)
+        tool_fn.__doc__ = spec.description
+        return strands_tool(tool_fn)
+
+    def _skill_spec_to_native(self, skill: str | SkillSpec) -> Any:
+        """Convert a skill source to a Strands-native skill object or path."""
+        if isinstance(skill, SkillSpec) and not skill.path:
+            return Skill(
+                name=skill.name,
+                description=skill.description,
+                instructions=skill.instructions,
+            )
+        return super()._skill_spec_to_native(skill)
+
+    def _build_skills_plugin(self, request: AgentRunRequest) -> Any | None:
+        """Build a Strands ``AgentSkills`` plugin when skill sources are 
configured."""
+        sources = self._resolve_skill_sources(request)
+        if not sources:
+            return None
+
+        skills_arg: Any = sources[0] if len(sources) == 1 else sources
+        return AgentSkills(skills=skills_arg, **dict(request.skills_params or 
{}))
+
+    def create_agent(self, request: AgentRunRequest) -> Any:
+        """Build a Strands ``Agent`` from *request*."""
+        self.validate_run_request(request)
+
+        native_tools: list[Any] = []
+        if request.toolsets:
+            native_tools = self._resolve_tools(
+                request.toolsets,
+                request.enable_tool_logging,
+                None,  # durable execution is not supported for Strands
+                None,
+            )
+
+        agent_kwargs: dict[str, Any] = dict(request.agent_params or {})
+        if request.instructions:
+            agent_kwargs["system_prompt"] = request.instructions
+
+        plugins: list[Any] = list(agent_kwargs.pop("plugins", []) or [])
+        skills_plugin = self._build_skills_plugin(request)
+        if skills_plugin is not None:
+            plugins.append(skills_plugin)
+        if plugins:
+            agent_kwargs["plugins"] = plugins
+
+        return Agent(model=self.get_model(), tools=native_tools or [], 
**agent_kwargs)
+
+    def run_agent(self, agent: Any, request: AgentRunRequest) -> 
AgentRunResult:
+        """Run the Strands *agent* for *request* and return a normalized 
:class:`AgentRunResult`."""
+        response = agent(request.prompt)
+        return AgentRunResult(
+            output=str(response),

Review Comment:
   `AgentOperator` documents `output_type` for structured output and threads it 
into `AgentRunRequest.output_type` (`agent.py:226`), and 
`PydanticAIHook.create_agent` honors it. `StrandsHook.run_agent` never reads 
`request.output_type` and hard-coerces `output=str(response)` here, so 
`AgentOperator(output_type=MyModel, llm_conn_id="strands_...")` silently gets a 
stringified blob with no error. The operator's `if isinstance(output, 
BaseModel)` branch (`agent.py:283`) can never fire for Strands. Either honor 
`output_type` (Strands supports structured output) and return the typed result, 
or add a `supports_output_type` ClassVar and reject it in 
`validate_run_request` (mirroring `supports_usage_limits`) so misuse fails fast 
instead of returning wrong data. Worth a test pinning whichever contract you 
pick.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to