This is an automated email from the ASF dual-hosted git repository.
kaxil pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 531bfab9dee Add `message_history` to `AgentOperator` for multi-turn
agent sessions (#68648)
531bfab9dee is described below
commit 531bfab9dee3ccdd0cea1e679f6de23b1b88b3b6
Author: Kaxil Naik <[email protected]>
AuthorDate: Fri Jun 19 00:57:58 2026 +0100
Add `message_history` to `AgentOperator` for multi-turn agent sessions
(#68648)
AgentOperator and @task.agent ran a fresh single-turn conversation every
time. Add an opt-in message_history parameter that seeds the run with prior
turns and pushes the post-run transcript to XCom (key 'message_history') so the
next run can resume. Default None keeps single-turn behavior unchanged. Storing
the transcript under a session key stays the DAG's responsibility.
---
providers/common/ai/docs/operators/agent.rst | 48 +++++++
.../common/ai/example_dags/example_agent.py | 56 +++++++-
.../airflow/providers/common/ai/operators/agent.py | 69 +++++++++-
.../tests/unit/common/ai/operators/test_agent.py | 149 ++++++++++++++++++++-
4 files changed, 318 insertions(+), 4 deletions(-)
diff --git a/providers/common/ai/docs/operators/agent.rst b/providers/common/ai/docs/operators/agent.rst
index a79e5110048..b3805aa34b7 100644
--- a/providers/common/ai/docs/operators/agent.rst
+++ b/providers/common/ai/docs/operators/agent.rst
@@ -156,6 +156,49 @@ tasks can consume it.
:end-before: [END howto_agent_chain]
+Multi-turn Sessions
+-------------------
+
+By default each agent run is a cold, single-turn conversation. To carry a
+conversation across runs -- a chat or iterative agent where "and the third one?"
+must resolve against an earlier answer -- pass ``message_history``.
+
+When ``message_history`` is set, the operator seeds the run with those prior
+turns and, after the run, pushes the full updated transcript
+(``result.all_messages()``) to XCom under the key ``message_history``. The next
+run reads it back to resume the conversation. ``None`` (the default) keeps the
+single-turn behavior unchanged.
+
+The operator does **not** decide *where* a session is stored -- that keying is
+deployment-specific. The pattern is three tasks: load the prior transcript for
+the session, run the agent, store the updated transcript. The example keys a
+JSON file in object storage by ``session_id`` (use ``s3://`` / ``gs://`` in a
+deployment); the first run starts from an empty ``"[]"``.
+
+.. exampleinclude:: /../../ai/src/airflow/providers/common/ai/example_dags/example_agent.py
+ :language: python
+ :start-after: [START howto_agent_session]
+ :end-before: [END howto_agent_session]
+
+``message_history`` accepts a list of pydantic-ai ``ModelMessage`` objects or
+their JSON form (``str`` / ``bytes``), so the value emitted to XCom feeds
+straight back in on the next run. When pulling it via a template, pass
+``default='[]'`` (as above) so the first run -- which has no XCom yet -- starts a
+fresh session instead of trying to parse the string ``"None"``.
+
+The transcript is **cumulative**: each turn appends to it, so it grows for the
+life of the session. For long sessions, configure an object-storage XCom backend
+or trim older turns before the next run rather than feeding the whole history
+back unbounded.
+
+.. note::
+
+ ``message_history`` cannot be combined with ``enable_hitl_review`` -- the
+ operator raises at construction. The post-review (human-approved) transcript
+ is not recoverable today, so emitting the pre-review transcript would
+ silently drop the reviewed turns.
+
+
Durable Execution
-----------------
@@ -406,6 +449,11 @@ Parameters
- ``code_mode``: When ``True``, wraps the agent's tools in a single ``run_code``
tool that the model drives by writing Python, executed in the Monty sandbox.
Requires the ``code-mode`` extra. Default ``False``. See :ref:`code-mode`.
+- ``message_history``: Prior conversation to seed a multi-turn session, as a list
+ of pydantic-ai ``ModelMessage`` objects or their JSON form (``str`` / ``bytes``).
+ When set, the post-run transcript is pushed to XCom under the key
+ ``message_history`` for the next run to resume. Default ``None`` (single-turn).
+ See `Multi-turn Sessions`_.
Logging
diff --git a/providers/common/ai/src/airflow/providers/common/ai/example_dags/example_agent.py b/providers/common/ai/src/airflow/providers/common/ai/example_dags/example_agent.py
index 6fced224c92..787b0d6dce2 100644
--- a/providers/common/ai/src/airflow/providers/common/ai/example_dags/example_agent.py
+++ b/providers/common/ai/src/airflow/providers/common/ai/example_dags/example_agent.py
@@ -24,7 +24,7 @@ from pydantic import BaseModel
from airflow.providers.common.ai.operators.agent import AgentOperator
from airflow.providers.common.ai.toolsets.hook import HookToolset
-from airflow.providers.common.compat.sdk import dag, task
+from airflow.providers.common.compat.sdk import ObjectStoragePath, dag, task
try:
from airflow.providers.common.ai.toolsets.sql import SQLToolset
@@ -247,3 +247,57 @@ def example_agent_operator_code_mode():
# [END howto_operator_agent_code_mode]
example_agent_operator_code_mode()
+
+
+# ---------------------------------------------------------------------------
+# 8. Multi-turn session — resume a conversation across DAG runs
+# ---------------------------------------------------------------------------
+
+
+# [START howto_agent_session]
+@dag(tags=["example"], params={"session_id": "demo-session"})
+def example_agent_session():
+ """Resume a conversation across runs via ``message_history``.
+
+ The agent step seeds itself with the prior transcript and re-emits the
+ updated transcript to XCom (key ``message_history``). Loading and storing
+ that transcript under a session key is the DAG's job -- here, a JSON file in
+ object storage keyed by ``session_id``. Swap the path for ``s3://`` /
+ ``gs://`` in a deployment.
+ """
+ sessions_root = ObjectStoragePath("file:///tmp/airflow_agent_sessions")
+
+ @task
+ def load_history(session_id: str) -> str:
+ path = sessions_root / f"{session_id}.json"
+ # First turn: no file yet -> start a fresh session (empty transcript).
+ return path.read_text() if path.exists() else "[]"
+
+ @task.agent(
+ llm_conn_id="pydanticai_default",
+ system_prompt="You are a helpful assistant. Use the earlier turns for context.",
+ # The XComArg both wires the dependency and resolves to the JSON transcript.
+ message_history=load_history("{{ params.session_id }}"),
+ )
+ def ask(question: str) -> str:
+ return question
+
+ @task
+ def save_history(session_id: str, transcript: str) -> None:
+ # Local/fsspec object storage does not auto-create parent dirs on write.
+ sessions_root.mkdir(parents=True, exist_ok=True)
+ (sessions_root / f"{session_id}.json").write_text(transcript)
+
+ answer = ask("And what did I ask you a moment ago?")
+ saved = save_history(
+ "{{ params.session_id }}",
+ # The agent step pushes the post-run transcript under this XCom key.
+ "{{ ti.xcom_pull(task_ids='ask', key='message_history') }}",
+ )
+ # save runs after the agent so the pulled transcript is the fresh one.
+ answer >> saved
+
+
+# [END howto_agent_session]
+
+example_agent_session()
diff --git a/providers/common/ai/src/airflow/providers/common/ai/operators/agent.py b/providers/common/ai/src/airflow/providers/common/ai/operators/agent.py
index bda06ea1f56..56c9ec5bbb6 100644
--- a/providers/common/ai/src/airflow/providers/common/ai/operators/agent.py
+++ b/providers/common/ai/src/airflow/providers/common/ai/operators/agent.py
@@ -48,6 +48,7 @@ except ImportError: # pragma: no cover - cores before the worker-side registrat
if TYPE_CHECKING:
from pydantic_ai import Agent
+ from pydantic_ai.messages import ModelMessage
from pydantic_ai.toolsets.abstract import AbstractToolset
from pydantic_ai.usage import UsageLimits
@@ -166,6 +167,22 @@ class AgentOperator(BaseOperator, HITLReviewMixin):
Cannot be combined with ``durable=True`` (durable replay assumes a
stable per-step call order that code mode does not guarantee).
Default ``False``.
+ :param message_history: Prior conversation to seed the run with, for
+ multi-turn sessions that span task runs. Accepts a ``list`` of
+ pydantic-ai ``ModelMessage`` objects, or their JSON form as ``str`` /
+ ``bytes`` -- e.g.
+ ``"{{ ti.xcom_pull(task_ids='ask', key='message_history', default='[]') }}"``
+ (pass ``default='[]'`` so the first run, with no XCom yet, starts a fresh
+ session instead of failing to parse the string ``"None"``). ``None``
+ (default) is a single-turn run -- no behavior change. When set (an empty
+ ``[]`` / ``""`` starts a fresh session), the full transcript after the run
+ -- ``result.all_messages()`` -- is pushed to XCom under the key
+ ``message_history`` so the next run can resume. Persisting that transcript
+ under a session key (e.g. in object storage) is the DAG's responsibility.
+ The transcript is cumulative and grows each turn; for long sessions use an
+ object-storage XCom backend or trim old turns. Not supported together with
+ ``enable_hitl_review`` (raises) -- the post-review transcript is not yet
+ recoverable.
**HITL Review parameters** (requires the ``hitl_review`` plugin):
@@ -199,6 +216,7 @@ class AgentOperator(BaseOperator, HITLReviewMixin):
"model_id",
"system_prompt",
"agent_params",
+ "message_history",
)
operator_extra_links = (HITLReviewLink(),)
@@ -217,6 +235,7 @@ class AgentOperator(BaseOperator, HITLReviewMixin):
usage_limits: UsageLimits | None = None,
durable: bool = False,
code_mode: bool = False,
+ message_history: list[ModelMessage] | str | bytes | None = None,
# Agent feedback parameters
enable_hitl_review: bool = False,
max_hitl_iterations: int = 5,
@@ -240,6 +259,7 @@ class AgentOperator(BaseOperator, HITLReviewMixin):
self.enable_tool_logging = enable_tool_logging
self.agent_params = agent_params or {}
self.usage_limits = usage_limits
+ self.message_history = message_history
self.durable = durable
self.code_mode = code_mode
@@ -256,6 +276,13 @@ class AgentOperator(BaseOperator, HITLReviewMixin):
# replay. Reject the combination rather than silently mis-replaying.
raise ValueError("durable=True and code_mode=True cannot be used together.")
+ if message_history is not None and enable_hitl_review:
+ # The post-review transcript is not recoverable today (run_hitl_review
+ # returns only the final string), so emitting the pre-review transcript
+ # would silently drop the human-approved turns. Block until HITL can
+ # surface the final message history.
+ raise ValueError("message_history and enable_hitl_review=True cannot be used together.")
+
self.enable_hitl_review = enable_hitl_review
self.max_hitl_iterations = max_hitl_iterations
self.hitl_timeout = hitl_timeout
@@ -331,6 +358,11 @@ class AgentOperator(BaseOperator, HITLReviewMixin):
agent = self._build_agent()
+ run_kwargs: dict[str, Any] = {"usage_limits": self.usage_limits}
+ history = self._resolve_message_history()
+ if history is not None:
+ run_kwargs["message_history"] = history
+
storage = self._durable_storage
counter = self._durable_counter
if self.durable and storage is not None and counter is not None:
@@ -343,9 +375,9 @@ class AgentOperator(BaseOperator, HITLReviewMixin):
resolved_model = infer_model(agent.model)
caching_model = CachingModel(resolved_model, storage=storage, counter=counter)
with agent.override(model=caching_model):
- result = agent.run_sync(self.prompt, usage_limits=self.usage_limits)
+ result = agent.run_sync(self.prompt, **run_kwargs)
else:
- result = agent.run_sync(self.prompt, usage_limits=self.usage_limits)
+ result = agent.run_sync(self.prompt, **run_kwargs)
log_run_summary(self.log, result)
@@ -368,6 +400,9 @@ class AgentOperator(BaseOperator, HITLReviewMixin):
if self._durable_storage is not None:
self._durable_storage.cleanup()
+ if self.message_history is not None:
+ self._emit_message_history(context, result)
+
output = result.output
if self.enable_hitl_review:
@@ -391,6 +426,36 @@ class AgentOperator(BaseOperator, HITLReviewMixin):
output = output.model_dump()
return output
+ def _resolve_message_history(self) -> list[ModelMessage] | None:
+ """
+ Deserialize :attr:`message_history` into a list of pydantic-ai messages.
+
+ ``None`` means single-turn (no history passed to the run). A ``str`` /
+ ``bytes`` value is parsed as the JSON the operator emits to XCom; a list
+ (of ``ModelMessage`` objects or their dict form) is validated as-is.
+ """
+ raw = self.message_history
+ if raw is None:
+ return None
+ if isinstance(raw, (str, bytes)) and not raw.strip():
+ # A template that renders to empty (no prior XCom) starts a fresh session.
+ return []
+ # pydantic-ai is imported lazily here to match this module's pattern of
+ # keeping pydantic-ai out of DAG-parse-time imports.
+ from pydantic_ai.messages import ModelMessagesTypeAdapter
+
+ if isinstance(raw, (str, bytes)):
+ return ModelMessagesTypeAdapter.validate_json(raw)
+ return ModelMessagesTypeAdapter.validate_python(raw)
+
+ def _emit_message_history(self, context: Context, result: Any) -> None:
+ """Push the full post-run transcript to XCom for the next turn to resume."""
+ # Lazy import: see _resolve_message_history.
+ from pydantic_ai.messages import ModelMessagesTypeAdapter
+
+ transcript = ModelMessagesTypeAdapter.dump_json(result.all_messages()).decode()
+ context["task_instance"].xcom_push(key="message_history", value=transcript)
+
def regenerate_with_feedback(self, *, feedback: str, message_history: Any) -> tuple[str, Any]:
"""Re-run the agent with *feedback* appended to the conversation history."""
agent = self._build_agent()
diff --git a/providers/common/ai/tests/unit/common/ai/operators/test_agent.py b/providers/common/ai/tests/unit/common/ai/operators/test_agent.py
index a9f017b94ee..1288dbbe652 100644
--- a/providers/common/ai/tests/unit/common/ai/operators/test_agent.py
+++ b/providers/common/ai/tests/unit/common/ai/operators/test_agent.py
@@ -22,6 +22,13 @@ from unittest.mock import MagicMock, patch
import pytest
from pydantic import BaseModel
+from pydantic_ai.messages import (
+ ModelMessagesTypeAdapter,
+ ModelRequest,
+ ModelResponse,
+ TextPart,
+ UserPromptPart,
+)
from pydantic_ai.usage import UsageLimits
from airflow.providers.common.ai.operators.agent import AgentOperator, HITLReviewLink, _build_code_mode
@@ -90,7 +97,14 @@ class TestAgentOperatorValidation:
class TestAgentOperatorTemplateFields:
def test_template_fields(self):
- expected = {"prompt", "llm_conn_id", "model_id", "system_prompt", "agent_params"}
+ expected = {
+ "prompt",
+ "llm_conn_id",
+ "model_id",
+ "system_prompt",
+ "agent_params",
+ "message_history",
+ }
assert set(AgentOperator.template_fields) == expected
@@ -617,3 +631,136 @@ class TestAgentOperatorMultimodalPromptGuard:
op.execute(context=MagicMock())
mock_agent.run_sync.assert_not_called()
+
+
+def _sample_history():
+ """A minimal two-message pydantic-ai conversation for round-trip tests."""
+ return [
+ ModelRequest(parts=[UserPromptPart(content="first question")]),
+ ModelResponse(parts=[TextPart(content="first answer")]),
+ ]
+
+
+# The accepted input forms for ``message_history``, computed once at collection time.
+_SAMPLE_HISTORY_JSON = ModelMessagesTypeAdapter.dump_json(_sample_history()).decode()
+_SAMPLE_HISTORY_DICTS = ModelMessagesTypeAdapter.dump_python(_sample_history(), mode="json")
+
+
+class TestAgentOperatorMessageHistory:
+ """Multi-turn session support: seed run_sync with prior history, emit the transcript."""
+
+ @pytest.mark.parametrize(
+ ("raw", "expected_len"),
+ [
+ pytest.param([], 0, id="empty-list"),
+ pytest.param("", 0, id="empty-str"),
+ pytest.param(" ", 0, id="blank-str"),
+ pytest.param(_SAMPLE_HISTORY_JSON, 2, id="json-str"),
+ pytest.param(_SAMPLE_HISTORY_DICTS, 2, id="list-of-dicts"),
+ pytest.param(_sample_history(), 2, id="list-of-objects"),
+ ],
+ )
+ @patch("airflow.providers.common.ai.operators.agent.PydanticAIHook", autospec=True)
+ def test_message_history_seeds_run_sync(self, mock_hook_cls, raw, expected_len):
+ """Every accepted input form is deserialized and passed to run_sync; blank/empty start fresh."""
+ mock_agent = _make_mock_agent("ok")
+ mock_hook_cls.get_hook.return_value.create_agent.return_value = mock_agent
+
+ op = AgentOperator(task_id="t", prompt="run", llm_conn_id="c", message_history=raw)
+ op.execute(context=MagicMock())
+
+ passed = mock_agent.run_sync.call_args.kwargs["message_history"]
+ assert len(passed) == expected_len
+ assert all(isinstance(m, (ModelRequest, ModelResponse)) for m in passed)
+
+ @patch("airflow.providers.common.ai.operators.agent.PydanticAIHook", autospec=True)
+ def test_none_is_single_turn_no_history_no_emit(self, mock_hook_cls):
+ """Default message_history=None passes no history and pushes no transcript XCom."""
+ mock_agent = _make_mock_agent("ok")
+ mock_hook_cls.get_hook.return_value.create_agent.return_value = mock_agent
+
+ op = AgentOperator(task_id="t", prompt="run", llm_conn_id="c")
+ context = MagicMock()
+ op.execute(context=context)
+
+ assert "message_history" not in mock_agent.run_sync.call_args.kwargs
+ context["task_instance"].xcom_push.assert_not_called()
+
+ @patch("airflow.providers.common.ai.operators.agent.PydanticAIHook", autospec=True)
+ def test_transcript_emitted_to_xcom_when_history_set(self, mock_hook_cls):
+ """When message_history is set, the post-run transcript is pushed to XCom and round-trips."""
+ mock_agent = _make_mock_agent("ok")
+ mock_agent.run_sync.return_value.all_messages.return_value = _sample_history()
+ mock_hook_cls.get_hook.return_value.create_agent.return_value = mock_agent
+
+ op = AgentOperator(task_id="t", prompt="run", llm_conn_id="c", message_history=[])
+ context = MagicMock()
+ op.execute(context=context)
+
+ ti = context["task_instance"]
+ ti.xcom_push.assert_called_once()
+ push_kwargs = ti.xcom_push.call_args.kwargs
+ assert push_kwargs["key"] == "message_history"
+ restored = ModelMessagesTypeAdapter.validate_json(push_kwargs["value"])
+ assert len(restored) == 2
+
+ @patch("airflow.providers.common.ai.operators.agent.PydanticAIHook", autospec=True)
+ def test_usage_limits_still_forwarded_with_history(self, mock_hook_cls):
+ """Adding message_history does not drop usage_limits from the run_sync call."""
+ mock_agent = _make_mock_agent("ok")
+ mock_hook_cls.get_hook.return_value.create_agent.return_value = mock_agent
+
+ limits = UsageLimits(request_limit=2)
+ op = AgentOperator(
+ task_id="t", prompt="run", llm_conn_id="c", usage_limits=limits, message_history=[]
+ )
+ op.execute(context=MagicMock())
+
+ kwargs = mock_agent.run_sync.call_args.kwargs
+ assert kwargs["usage_limits"] is limits
+ assert kwargs["message_history"] == []
+
+ def test_message_history_with_hitl_review_raises(self):
+ """message_history cannot be combined with HITL review (post-review transcript is lost)."""
+ with pytest.raises(ValueError, match="message_history and enable_hitl_review"):
+ AgentOperator(
+ task_id="t",
+ prompt="run",
+ llm_conn_id="c",
+ message_history=[],
+ enable_hitl_review=True,
+ )
+
+ @patch("pydantic_ai.models.wrapper.infer_model", side_effect=lambda m: m)
+ @patch("pydantic_ai.models.infer_model", autospec=True)
+ @patch("airflow.providers.common.ai.durable.storage._get_base_path")
+ @patch("airflow.providers.common.ai.operators.agent.PydanticAIHook", autospec=True)
+ def test_durable_path_also_seeds_message_history(
+ self, mock_hook_cls, mock_base_path, mock_infer, _, tmp_path
+ ):
+ """The durable branch forwards message_history into the cached run too."""
+ from airflow.sdk import ObjectStoragePath
+
+ mock_base_path.return_value = ObjectStoragePath(f"file://{tmp_path.as_posix()}")
+
+ mock_agent = MagicMock(spec=["run_sync", "model", "override"])
+ mock_agent.run_sync.return_value = _make_mock_run_result("ok")
+ mock_agent.model = "test-model"
+ mock_agent.override.return_value.__enter__ = MagicMock(return_value=None)
+ mock_agent.override.return_value.__exit__ = MagicMock(return_value=False)
+ mock_hook_cls.get_hook.return_value.create_agent.return_value = mock_agent
+ mock_infer.return_value = MagicMock()
+
+ context = MagicMock()
+ context.__getitem__ = MagicMock(
+ return_value=MagicMock(dag_id="d", task_id="t", run_id="r", map_index=-1)
+ )
+
+ history_json = ModelMessagesTypeAdapter.dump_json(_sample_history()).decode()
+ op = AgentOperator(
+ task_id="test", prompt="test", llm_conn_id="my_llm", durable=True, message_history=history_json
+ )
+ op.execute(context=context)
+
+ passed = mock_agent.run_sync.call_args.kwargs["message_history"]
+ assert len(passed) == 2