This is an automated email from the ASF dual-hosted git repository.
gopidesupavan pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new adf903571bb Add spec_file support to PydanticAIHook.create_agent
(#67788)
adf903571bb is described below
commit adf903571bb65dfdb0394099038035f6621f2232
Author: GPK <[email protected]>
AuthorDate: Sun Jun 7 22:52:47 2026 +0100
Add spec_file support to PydanticAIHook.create_agent (#67788)
* Add spec_file support to PydanticAIHook.create_agent
* resolve comments
* resolve comments remove version bump
* Fix mypy
---
providers/common/ai/docs/hooks/pydantic_ai.rst | 30 +++++
.../common/ai/example_dags/example_agent_spec.yaml | 27 +++++
.../ai/example_dags/example_pydantic_ai_hook.py | 43 +++++++
.../providers/common/ai/hooks/pydantic_ai.py | 76 +++++++++++-
.../tests/unit/common/ai/hooks/test_pydantic_ai.py | 128 +++++++++++++++++++++
5 files changed, 300 insertions(+), 4 deletions(-)
diff --git a/providers/common/ai/docs/hooks/pydantic_ai.rst
b/providers/common/ai/docs/hooks/pydantic_ai.rst
index 65bb53b4a38..ef4070d72c5 100644
--- a/providers/common/ai/docs/hooks/pydantic_ai.rst
+++ b/providers/common/ai/docs/hooks/pydantic_ai.rst
@@ -66,3 +66,33 @@ Define a Pydantic model for the expected output shape, then
pass it as ``output_
:language: python
:start-after: [START howto_hook_pydantic_ai_structured_output]
:end-before: [END howto_hook_pydantic_ai_structured_output]
+
+Loading Agent Config from a Spec File
+--------------------------------------
+
+Instead of hard-coding model name, instructions, and settings in Python, you can
+store them in a YAML or JSON `AgentSpec
+<https://ai.pydantic.dev/agents/#agent-spec>`__ file and pass its path via
+``spec_file``. This keeps prompt engineering separate from Dag logic and lets
+you version-control agent configs independently.
+
+.. code-block:: yaml
+ :caption: agent_spec.yaml
+
+ model: openai:gpt-4o-mini
+ instructions: >
+ You are a concise summarizer. Given any text, respond with a single
+ paragraph that captures the key points.
+ model_settings:
+ temperature: 0.3
+ retries: 2
+
+.. exampleinclude:: /../../ai/src/airflow/providers/common/ai/example_dags/example_pydantic_ai_hook.py
+ :language: python
+ :start-after: [START howto_hook_pydantic_ai_spec_file]
+ :end-before: [END howto_hook_pydantic_ai_spec_file]
+
+The model declared in the spec file is used unless ``model_id`` or the
+connection's ``model`` extra is set, in which case the hook model takes
+precedence. Passing ``instructions`` to ``create_agent`` when a ``spec_file`` is
+also given appends additional instructions to the file value.
diff --git
a/providers/common/ai/src/airflow/providers/common/ai/example_dags/example_agent_spec.yaml
b/providers/common/ai/src/airflow/providers/common/ai/example_dags/example_agent_spec.yaml
new file mode 100644
index 00000000000..03b424be282
--- /dev/null
+++
b/providers/common/ai/src/airflow/providers/common/ai/example_dags/example_agent_spec.yaml
@@ -0,0 +1,27 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# pydantic-ai AgentSpec file — referenced by example_pydantic_ai_hook.py
+---
+model: openai:gpt-4o-mini
+instructions: >
+ You are a concise summarizer. Given any text, respond with a single
+ paragraph that captures the key points. Do not add commentary.
+model_settings:
+ temperature: 0.3
+retries: 2
+end_strategy: early
diff --git
a/providers/common/ai/src/airflow/providers/common/ai/example_dags/example_pydantic_ai_hook.py
b/providers/common/ai/src/airflow/providers/common/ai/example_dags/example_pydantic_ai_hook.py
index d1790dcaba6..e3ca40bf1b4 100644
---
a/providers/common/ai/src/airflow/providers/common/ai/example_dags/example_pydantic_ai_hook.py
+++
b/providers/common/ai/src/airflow/providers/common/ai/example_dags/example_pydantic_ai_hook.py
@@ -18,6 +18,8 @@
from __future__ import annotations
+from pathlib import Path
+
from pydantic import BaseModel
from airflow.providers.common.ai.hooks.pydantic_ai import PydanticAIHook
@@ -100,3 +102,44 @@ def example_task_with_toolsets():
# [END howto_task_with_toolsets]
example_task_with_toolsets()
+
+
+# [START howto_hook_pydantic_ai_spec_file]
+@dag(schedule=None, tags=["example"])
+def example_pydantic_ai_spec_file():
+ """Load agent settings from a YAML spec file instead of inline code.
+
+ The spec file (``example_agent_spec.yaml``) declares model, instructions,
+ model_settings, retries, etc. If ``model_id`` or the connection's ``model``
+ extra is set, that hook model takes precedence over the file's model.
+ """
+
+ @task
+ def summarize_from_spec(text: str) -> str:
+ # The spec file is resolved relative to this example module's directory.
+ spec_path = Path(__file__).parent / "example_agent_spec.yaml"
+ hook = PydanticAIHook(llm_conn_id="pydanticai_default")
+ # Model, instructions, temperature, and retries all come from the
+ # YAML file; nothing but the spec path is passed to create_agent here.
+ agent = hook.create_agent(spec_file=spec_path)
+ result = agent.run_sync(text)
+ return result.output
+
+ @task
+ def summarize_with_additional_instructions(text: str) -> str:
+ """Add call-time instructions alongside the spec file instructions."""
+ spec_path = Path(__file__).parent / "example_agent_spec.yaml"
+ hook = PydanticAIHook(llm_conn_id="pydanticai_default")
+ # Same spec file as above, plus an ``instructions`` argument supplied
+ # at call time in addition to what the file declares.
+ agent = hook.create_agent(
+ spec_file=spec_path,
+ instructions="Summarize in exactly one sentence.",
+ )
+ result = agent.run_sync(text)
+ return result.output
+
+ # Both tasks are invoked with the same input text.
+ body = "Apache Airflow is an open-source platform for authoring,
scheduling..."
+ summarize_from_spec(body)
+ summarize_with_additional_instructions(body)
+
+
+# [END howto_hook_pydantic_ai_spec_file]
+
+example_pydantic_ai_spec_file()
diff --git
a/providers/common/ai/src/airflow/providers/common/ai/hooks/pydantic_ai.py
b/providers/common/ai/src/airflow/providers/common/ai/hooks/pydantic_ai.py
index 44e2436576f..eec7c5e7944 100644
--- a/providers/common/ai/src/airflow/providers/common/ai/hooks/pydantic_ai.py
+++ b/providers/common/ai/src/airflow/providers/common/ai/hooks/pydantic_ai.py
@@ -16,6 +16,7 @@
# under the License.
from __future__ import annotations
+from pathlib import Path
from typing import TYPE_CHECKING, Any, TypeVar, overload
from pydantic_ai import Agent
@@ -30,6 +31,8 @@ OutputT = TypeVar("OutputT")
if TYPE_CHECKING:
from pydantic_ai.models import KnownModelName, Model
+ from airflow.providers.common.compat.sdk import Connection
+
class PydanticAIHook(BaseHook):
"""
@@ -71,6 +74,8 @@ class PydanticAIHook(BaseHook):
self.llm_conn_id = llm_conn_id if llm_conn_id is not None else
self.default_conn_name
self.model_id = model_id
self._model: Model | None = None
+ self._conn: Connection | None = None
+ self._conn_extra_dejson: dict[str, Any] | None = None
@staticmethod
def get_ui_field_behaviour() -> dict[str, Any]:
@@ -134,9 +139,11 @@ class PydanticAIHook(BaseHook):
if self._model is not None:
return self._model
- conn = self.get_connection(self.llm_conn_id)
+ conn = self.get_connection(self.llm_conn_id) if self._conn is None
else self._conn
+ extra: dict[str, Any] = (
+ conn.extra_dejson if self._conn_extra_dejson is None else
self._conn_extra_dejson
+ )
- extra: dict[str, Any] = conn.extra_dejson
model_name: str | KnownModelName = self.model_id or extra.get("model",
"")
if not model_name:
raise ValueError(
@@ -172,6 +179,20 @@ class PydanticAIHook(BaseHook):
self._model = infer_model(model_name)
return self._model
+ def _get_conn_if_model_configured(self) -> Model | None:
+ """
+ Return the hook model only when the hook or connection explicitly configures one.
+
+ ``model_id`` on the hook is checked first. Otherwise the connection is
+ fetched and its ``model`` extra is checked.
+
+ :return: The result of ``get_conn()`` when a model is configured,
+ otherwise ``None``.
+ """
+ # A model_id on the hook is enough; no need to inspect the connection here.
+ if self.model_id:
+ return self.get_conn()
+
+ # Keep the connection and its parsed extras on the hook: get_conn()
+ # uses these cached values when they are set instead of fetching again.
+ conn = self.get_connection(self.llm_conn_id)
+ self._conn = conn
+ self._conn_extra_dejson = conn.extra_dejson
+
+ if self._conn_extra_dejson.get("model"):
+ return self.get_conn()
+
+ # Neither the hook nor the connection names a model.
+ return None
+
@overload
def create_agent(
self, output_type: type[OutputT], *, instructions: str, **agent_kwargs
@@ -180,8 +201,32 @@ class PydanticAIHook(BaseHook):
@overload
def create_agent(self, *, instructions: str, **agent_kwargs) ->
Agent[None, str]: ...
+ @overload
def create_agent(
- self, output_type: type[Any] = str, *, instructions: str,
**agent_kwargs
+ self,
+ output_type: type[OutputT],
+ *,
+ spec_file: str | Path,
+ instructions: str | None = ...,
+ **agent_kwargs,
+ ) -> Agent[None, OutputT]: ...
+
+ @overload
+ def create_agent(
+ self,
+ *,
+ spec_file: str | Path,
+ instructions: str | None = ...,
+ **agent_kwargs,
+ ) -> Agent[None, str]: ...
+
+ def create_agent(
+ self,
+ output_type: type[Any] = str,
+ *,
+ instructions: str | None = None,
+ spec_file: str | Path | None = None,
+ **agent_kwargs,
) -> Agent[None, Any]:
"""
Create a pydantic-ai Agent configured with this hook's model.
@@ -193,9 +238,32 @@ class PydanticAIHook(BaseHook):
:param output_type: The expected output type from the agent (default:
``str``).
:param instructions: System-level instructions for the agent.
+ Required when *spec_file* is not given. When *spec_file* is given,
+ this value is merged with the instructions in the file; omit it to
+ use only the file value.
+ :param spec_file: Path to a YAML or JSON ``AgentSpec`` file. When
supplied,
+ delegates to ``Agent.from_file``. If ``model_id`` or the
connection's
+ ``model`` extra is set, that model is passed to pydantic-ai;
otherwise
+ the spec file's ``model`` is used.
:param agent_kwargs: Additional keyword arguments passed to the Agent
constructor.
"""
- agent = Agent(self.get_conn(), output_type=output_type,
instructions=instructions, **agent_kwargs)
+ if spec_file is not None:
+ from_file_kwargs = dict(agent_kwargs)
+ model = self._get_conn_if_model_configured()
+ if model is not None:
+ from_file_kwargs["model"] = model
+ if instructions is not None:
+ from_file_kwargs["instructions"] = instructions
+
+ agent = Agent.from_file(
+ spec_file,
+ output_type=output_type,
+ **from_file_kwargs,
+ )
+ else:
+ if instructions is None:
+ raise ValueError("instructions is required when spec_file is
not provided.")
+ agent = Agent(self.get_conn(), output_type=output_type,
instructions=instructions, **agent_kwargs)
if "instrument" not in agent_kwargs:
# Set the public ``agent.instrument`` surface rather than the
# ``Agent(instrument=...)`` constructor kwarg, which is deprecated
in
diff --git a/providers/common/ai/tests/unit/common/ai/hooks/test_pydantic_ai.py
b/providers/common/ai/tests/unit/common/ai/hooks/test_pydantic_ai.py
index fb63b1f3e48..c203a3782a1 100644
--- a/providers/common/ai/tests/unit/common/ai/hooks/test_pydantic_ai.py
+++ b/providers/common/ai/tests/unit/common/ai/hooks/test_pydantic_ai.py
@@ -18,6 +18,7 @@ from __future__ import annotations
import json
import sys
+from pathlib import Path
from unittest.mock import MagicMock, patch
import pytest
@@ -243,6 +244,133 @@ class TestPydanticAIHookCreateAgent:
retries=3,
)
+ def test_create_agent_without_instructions_or_spec_file_raises(self):
+ # create_agent() is called with neither ``instructions`` nor
+ # ``spec_file``; the error message must mention the missing instructions.
+ hook = PydanticAIHook(llm_conn_id="test_conn",
model_id="openai:gpt-5.3")
+ with pytest.raises(ValueError, match="instructions is required"):
+ hook.create_agent()
+
+ @patch("airflow.providers.common.ai.hooks.pydantic_ai.infer_model",
autospec=True)
+ @patch("airflow.providers.common.ai.hooks.pydantic_ai.Agent")
+ def test_create_agent_with_spec_file_calls_from_file(self, mock_agent_cls,
mock_infer_model):
+ """spec_file routes to Agent.from_file with the hook model when configured."""
+ # Agent is patched, so from_file is a mock and the spec path is never read.
+ mock_model = MagicMock(spec=Model)
+ mock_infer_model.return_value = mock_model
+
+ hook = PydanticAIHook(llm_conn_id="test_conn",
model_id="openai:gpt-5.3")
+ conn = Connection(conn_id="test_conn", conn_type="pydanticai")
+ with patch.object(hook, "get_connection", return_value=conn):
+ hook.create_agent(spec_file="/path/to/agent.yaml")
+
+ mock_agent_cls.from_file.assert_called_once_with(
+ "/path/to/agent.yaml",
+ model=mock_model,
+ output_type=str,
+ )
+ # The plain Agent(...) constructor must not be used on the spec_file path.
+ mock_agent_cls.assert_not_called()
+
+ @patch("airflow.providers.common.ai.hooks.pydantic_ai.infer_model",
autospec=True)
+ @patch("airflow.providers.common.ai.hooks.pydantic_ai.Agent")
+ def
test_create_agent_with_spec_file_uses_file_model_when_hook_model_not_configured(
+ self, mock_agent_cls, mock_infer_model
+ ):
+ """spec_file model is used when neither model_id nor connection model is configured."""
+ # No model_id on the hook and no extras on the connection, so the hook
+ # has no model of its own to hand to from_file.
+ hook = PydanticAIHook(llm_conn_id="test_conn")
+ conn = Connection(conn_id="test_conn", conn_type="pydanticai")
+ with patch.object(hook, "get_connection", return_value=conn):
+ hook.create_agent(spec_file="/path/to/agent.yaml")
+
+ # The expected from_file call carries no ``model`` keyword at all.
+ mock_infer_model.assert_not_called()
+ mock_agent_cls.from_file.assert_called_once_with(
+ "/path/to/agent.yaml",
+ output_type=str,
+ )
+ mock_agent_cls.assert_not_called()
+
+ @patch("airflow.providers.common.ai.hooks.pydantic_ai.infer_model",
autospec=True)
+ @patch("airflow.providers.common.ai.hooks.pydantic_ai.Agent")
+ def test_create_agent_with_spec_file_path_object(self, mock_agent_cls,
mock_infer_model):
+ """spec_file accepts a pathlib.Path object."""
+ mock_model = MagicMock(spec=Model)
+ mock_infer_model.return_value = mock_model
+
+ hook = PydanticAIHook(llm_conn_id="test_conn",
model_id="openai:gpt-5.3")
+ conn = Connection(conn_id="test_conn", conn_type="pydanticai")
+ spec_path = Path("/path/to/agent.yaml")
+ with patch.object(hook, "get_connection", return_value=conn):
+ hook.create_agent(spec_file=spec_path)
+
+ # The assertion compares against the Path object itself, i.e. the hook
+ # is expected to forward it without converting it to a string.
+ mock_agent_cls.from_file.assert_called_once_with(
+ spec_path,
+ model=mock_model,
+ output_type=str,
+ )
+
+ @patch("airflow.providers.common.ai.hooks.pydantic_ai.infer_model",
autospec=True)
+ @patch("airflow.providers.common.ai.hooks.pydantic_ai.Agent")
+ def test_create_agent_with_spec_file_merges_additional_instructions(
+ self, mock_agent_cls, mock_infer_model
+ ):
+ """Explicit instructions are forwarded so pydantic-ai merges them with the spec."""
+ # Agent is mocked, so this test only checks that ``instructions`` is
+ # forwarded to from_file; the merge itself is not exercised here.
+ mock_model = MagicMock(spec=Model)
+ mock_infer_model.return_value = mock_model
+
+ hook = PydanticAIHook(llm_conn_id="test_conn",
model_id="openai:gpt-5.3")
+ conn = Connection(conn_id="test_conn", conn_type="pydanticai")
+ with patch.object(hook, "get_connection", return_value=conn):
+ hook.create_agent(
+ spec_file="/path/to/agent.yaml",
+ instructions="Override instructions.",
+ )
+
+ mock_agent_cls.from_file.assert_called_once_with(
+ "/path/to/agent.yaml",
+ model=mock_model,
+ output_type=str,
+ instructions="Override instructions.",
+ )
+
+ @patch("airflow.providers.common.ai.hooks.pydantic_ai.infer_model",
autospec=True)
+ @patch("airflow.providers.common.ai.hooks.pydantic_ai.Agent")
+ def test_create_agent_with_spec_file_custom_output_type(self,
mock_agent_cls, mock_infer_model):
+ """output_type is forwarded to Agent.from_file."""
+ mock_model = MagicMock(spec=Model)
+ mock_infer_model.return_value = mock_model
+
+ hook = PydanticAIHook(llm_conn_id="test_conn",
model_id="openai:gpt-5.3")
+ conn = Connection(conn_id="test_conn", conn_type="pydanticai")
+ with patch.object(hook, "get_connection", return_value=conn):
+ hook.create_agent(output_type=dict,
spec_file="/path/to/agent.yaml")
+
+ # ``dict`` stands in for any non-default output type; the other
+ # spec_file tests all expect the default ``str``.
+ mock_agent_cls.from_file.assert_called_once_with(
+ "/path/to/agent.yaml",
+ model=mock_model,
+ output_type=dict,
+ )
+
+ @patch("airflow.providers.common.ai.hooks.pydantic_ai.infer_model",
autospec=True)
+ @patch("airflow.providers.common.ai.hooks.pydantic_ai.Agent")
+ def test_create_agent_with_spec_file_forwards_agent_kwargs(self,
mock_agent_cls, mock_infer_model):
+ """Extra agent_kwargs are forwarded to Agent.from_file."""
+ mock_model = MagicMock(spec=Model)
+ mock_infer_model.return_value = mock_model
+
+ hook = PydanticAIHook(llm_conn_id="test_conn",
model_id="openai:gpt-5.3")
+ conn = Connection(conn_id="test_conn", conn_type="pydanticai")
+ with patch.object(hook, "get_connection", return_value=conn):
+ hook.create_agent(
+ spec_file="/path/to/agent.yaml",
+ retries=5,
+ end_strategy="early",
+ )
+
+ # ``retries`` and ``end_strategy`` are the extra keyword arguments under
+ # test; they are expected to reach from_file with their values unchanged.
+ mock_agent_cls.from_file.assert_called_once_with(
+ "/path/to/agent.yaml",
+ model=mock_model,
+ output_type=str,
+ retries=5,
+ end_strategy="early",
+ )
+
class TestPydanticAIHookCreateAgentInstrumentation:
"""create_agent() wires OpenTelemetry instrumentation from
observability."""