This is an automated email from the ASF dual-hosted git repository.
kaxil pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 6a95da5e98d Register operator-declared XCom classes from a worker-side DAG walk (#67875)
6a95da5e98d is described below
commit 6a95da5e98d755860aec60bf946017bbc5e7442e
Author: Kaxil Naik <[email protected]>
AuthorDate: Tue Jun 2 09:58:04 2026 +0100
Register operator-declared XCom classes from a worker-side DAG walk (#67875)
The common-ai typed-XCom feature registered an operator's ``output_type``
Pydantic class for XCom deserialization as a side effect of operator
``__init__``. That misses two real cases:
1. Mapped producers. ``@task.llm(output_type=X).expand(...)`` registers X only
   when the mapped task unmaps and runs (which goes through ``__init__``), in
   the producer's own process. A downstream consumer runs in a different process
   that loads the DAG with the producer still an unexpanded ``MappedOperator``
   (no ``__init__``), so X is never registered there and deserialization raises
   ``ImportError``. The shipped ``example_llm_analysis_pipeline`` has this shape.
2. Workers that reconstruct operators without calling ``__init__`` (e.g. a
   parsed-DAG cache that loads operators) never run the registration, so the
   same ``ImportError`` occurs.
Move registration to the worker: ``task_runner.parse()`` walks the loaded DAG's
tasks and registers each operator's declared deserialization classes before any
task runs. It reads the classes off real operators and off not-yet-expanded
mapped operators (``partial_kwargs``), and works regardless of how the DAG was
loaded. Operators opt in by declaring ``deserialization_allowed_class_fields``
(``("output_type",)`` for the common-ai operators); the generic Pydantic
type-tree walk (``Union``/``Optional``/``list``) lives in ``serde`` as
``iter_pydantic_models``.
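For illustration, a stdlib-only sketch of the kind of type-tree walk described
above. ``Model`` stands in for ``pydantic.BaseModel`` and ``iter_models`` is a
hypothetical name; the real ``serde.iter_pydantic_models`` is not shown here
and its signature may differ. Like the removed provider helper, it checks
``get_origin`` before treating a value as a leaf, so ``list[A]`` is unpacked
into its arguments rather than skipped.

```python
from __future__ import annotations

from collections.abc import Iterator
from typing import Any, Optional, Union, get_args, get_origin


class Model:
    """Stand-in for ``pydantic.BaseModel`` so this sketch needs only the stdlib."""


def iter_models(tp: Any, base: type = Model) -> Iterator[type]:
    """Yield every ``base`` subclass reachable from a type expression (hypothetical helper)."""
    seen: set[type] = set()
    stack: list[Any] = [tp]
    while stack:
        t = stack.pop()
        # Parameterized forms (Union, Optional, list[...], dict[...]) have an
        # origin; push their arguments instead of treating them as leaves.
        if get_origin(t) is not None:
            stack.extend(get_args(t))
            continue
        # Plain classes are leaves: keep the ones that subclass ``base``.
        if isinstance(t, type) and t not in seen:
            seen.add(t)
            if issubclass(t, base):
                yield t


class Summary(Model):
    pass


class Error(Model):
    pass


found = iter_models(list[Optional[Union[Summary, Error]]])
print(sorted(cls.__name__ for cls in found))  # ['Error', 'Summary']
```

Primitives such as ``str`` or the ``NoneType`` inside ``Optional`` are visited
but never yielded, since they do not subclass the model base.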
The ``__init__`` registration side effect is removed from the operators. They
probe ``serde.SUPPORTS_OPERATOR_DESERIALIZATION_WALKER`` and, on cores that lack
the walk, fall back to dumping the model to a ``dict``, so the value stays
deserializable without an allow-list edit. ``allow_class`` (and thus the gate)
is unchanged: only an operator-declared class, sourced from the trusted parsed
DAG, is registered, and registration stays exact-match.
---
providers/common/ai/docs/changelog.rst | 35 +++++-----
providers/common/ai/docs/operators/agent.rst | 17 ++---
providers/common/ai/docs/operators/llm.rst | 23 ++++---
.../common/ai/docs/operators/llm_file_analysis.rst | 13 ++--
.../airflow/providers/common/ai/operators/agent.py | 25 ++++----
.../airflow/providers/common/ai/operators/llm.py | 35 +++++-----
.../providers/common/ai/utils/output_type.py | 32 +--------
.../tests/unit/common/ai/decorators/test_agent.py | 14 ++--
.../tests/unit/common/ai/operators/test_agent.py | 39 +++--------
.../ai/tests/unit/common/ai/operators/test_llm.py | 47 +++++---------
.../common/ai/operators/test_llm_file_analysis.py | 18 +++---
.../tests/unit/common/ai/utils/test_output_type.py | 42 +-----------
task-sdk/src/airflow/sdk/bases/operator.py | 7 ++
.../src/airflow/sdk/execution_time/task_runner.py | 43 +++++++++++++
task-sdk/src/airflow/sdk/serde/__init__.py | 38 ++++++++++-
.../task_sdk/execution_time/test_task_runner.py | 75 ++++++++++++++++++++++
task-sdk/tests/task_sdk/serde/test_serde.py | 20 ++++++
17 files changed, 302 insertions(+), 221 deletions(-)
diff --git a/providers/common/ai/docs/changelog.rst b/providers/common/ai/docs/changelog.rst
index 4badccb0161..f935bc64395 100644
--- a/providers/common/ai/docs/changelog.rst
+++ b/providers/common/ai/docs/changelog.rst
@@ -29,24 +29,23 @@ Breaking change: operators with ``output_type=<BaseModel subclass>``
(``LLMOperator``, ``LLMAgentOperator``, ``LLMFileAnalysisOperator``, and
their ``@task.llm`` / ``@task.agent`` / ``@task.llm_file_analysis`` decorators)
now return the Pydantic model instance through XCom instead of dumping it to
-a ``dict`` when the running Airflow version provides
-``airflow.sdk.serde.allow_class``. Downstream tasks should type-hint the model
-class (``def downstream(result: MyModel)``) and use attribute access
-(``result.field``) instead of subscript access. The output class must be
-defined at **module scope** and bound to an attribute matching its
-``__name__``; operators raise ``ValueError`` at construction time when
-``output_type`` (or any ``BaseModel`` reachable from a ``Union``/``Optional``/
-``list`` of types) is nested, dynamically built, or non-importable by ``qualname``.
-
-Same-DAG downstream tasks deserialize the model without any configuration
-change because each worker re-runs the operator constructor when it parses the
-DAG. The UI XCom viewer renders the value via the ``stringify`` path and works
-without configuration (it shows ``module.MyModel@version=1(field=value,...)``
-rather than a pretty form, but no allow-list edit is required). Cross-DAG
-``xcom_pull`` consumers still need the class qualified name added to
-``[core] allowed_deserialization_classes`` -- the consumer DAG's worker only
-parses its own DAG file. On older Airflow releases that lack ``allow_class``
-the operators continue to dump to ``dict``.
+a ``dict``, on Airflow versions whose worker registers operator-declared output
+classes for deserialization. Downstream tasks should type-hint the model class
+(``def downstream(result: MyModel)``) and use attribute access (``result.field``)
+instead of subscript access. The output class must be defined at **module scope**
+and bound to an attribute matching its ``__name__``; classes that are nested,
+dynamically built, or otherwise non-importable by ``qualname`` cannot be
+re-imported and will fail to deserialize at the consumer.
+
+The worker walks the loaded DAG and registers each declared class before any
+task runs, so same-DAG downstream tasks (including mapped ``.expand(...)``
+producers) deserialize the model without any configuration change. The UI XCom
+viewer renders the value via the ``stringify`` path and works without
+configuration (it shows ``module.MyModel@version=1(field=value,...)`` rather than
+a pretty form). Cross-DAG ``xcom_pull`` consumers still need the class qualified
+name added to ``[core] allowed_deserialization_classes`` -- the consumer DAG's
+worker only loads its own DAG. On Airflow versions whose worker does not register
+declared classes, the operators dump to ``dict`` instead.
0.3.0
.....
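The module-scope rule in the changelog entry can be checked mechanically: a
class survives XCom only if importing its ``__module__`` and following its
``__qualname__`` leads back to the same object. ``reimportable`` below is a
hypothetical helper written for this note, not Airflow's deserializer, which
may differ in detail.

```python
from __future__ import annotations

import collections
import importlib


def reimportable(cls: type) -> bool:
    """Return True if ``cls`` can be found again via its module and qualname."""
    if "<locals>" in cls.__qualname__:
        return False  # defined inside a function or a ``@dag`` body
    try:
        obj: object = importlib.import_module(cls.__module__)
    except (ImportError, ValueError):
        return False
    for part in cls.__qualname__.split("."):
        obj = getattr(obj, part, None)
        if obj is None:
            return False
    # Must resolve to this exact class: a dynamically built class whose
    # ``__name__`` differs from the name it is bound to fails this check.
    return obj is cls


def build() -> type:
    class Nested:
        pass

    return Nested


Renamed = type("NotTheBoundName", (), {})

print(reimportable(collections.OrderedDict))  # True
print(reimportable(build()))  # False
print(reimportable(Renamed))  # False
```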
diff --git a/providers/common/ai/docs/operators/agent.rst b/providers/common/ai/docs/operators/agent.rst
index 51e575d139e..8e7c8ad5f39 100644
--- a/providers/common/ai/docs/operators/agent.rst
+++ b/providers/common/ai/docs/operators/agent.rst
@@ -123,14 +123,15 @@ back. The model instance is pushed to XCom unchanged so downstream tasks can
type-hint the class directly (``def downstream(result: MyModel)``) and use
attribute access (``result.field``).
-The operator auto-registers ``output_type`` (and any ``BaseModel`` reachable
-from ``Union``/``Optional``/``list`` shapes) for XCom deserialization in every
-process that parses the DAG. The Pydantic class must be defined at **module
-scope** and bound to an attribute matching its ``__name__``. Same-DAG
-downstream tasks need no configuration. The UI's XCom viewer renders the value
-via the ``stringify`` path (no configuration needed; see the ``LLMOperator``
-guide for the exact representation). Cross-DAG ``xcom_pull`` consumers still
-need the class ``qualname`` added to ``[core] allowed_deserialization_classes``.
+The declared ``output_type`` (and any ``BaseModel`` reachable from
+``Union``/``Optional``/``list`` shapes) is registered for XCom deserialization by
+the worker when it loads the DAG, before any task runs. The Pydantic class must
+be defined at **module scope** and bound to an attribute matching its
+``__name__``. Same-DAG downstream tasks need no configuration. The UI's XCom
+viewer renders the value via the ``stringify`` path (no configuration needed;
+see the ``LLMOperator`` guide for the exact representation). Cross-DAG
+``xcom_pull`` consumers still need the class ``qualname`` added to
+``[core] allowed_deserialization_classes``.
.. exampleinclude:: /../../ai/src/airflow/providers/common/ai/example_dags/example_agent.py
:language: python
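A stdlib-only sketch of how a worker-side walk can read declared classes from
both task shapes, assuming stand-in classes ``FakeOperator`` and
``FakeMappedOperator`` (the latter mimics an unexpanded mapped task that only
carries an operator class and ``partial_kwargs``). The actual
``task_runner.parse()`` change is not shown in this section, so the names and
structure here are illustrative; a real walk would also expand each value
through the ``Union``/``Optional``/``list`` type-tree walk before registering it.

```python
from __future__ import annotations

from collections.abc import Iterable, Iterator
from dataclasses import dataclass, field
from typing import Any, ClassVar


class FakeOperator:
    """Stand-in for an operator that opts in (as the common-ai operators do)."""

    deserialization_allowed_class_fields: ClassVar[tuple[str, ...]] = ("output_type",)

    def __init__(self, output_type: Any = str) -> None:
        self.output_type = output_type


@dataclass
class FakeMappedOperator:
    """Stand-in for an unexpanded mapped task: its ``__init__`` never ran."""

    operator_class: type
    partial_kwargs: dict[str, Any] = field(default_factory=dict)


def declared_values(task: Any) -> Iterator[Any]:
    """Yield the value of each declared field for either task shape."""
    if isinstance(task, FakeMappedOperator):
        # Not constructed yet: read the values from the partial kwargs.
        names = getattr(task.operator_class, "deserialization_allowed_class_fields", ())
        for name in names:
            if name in task.partial_kwargs:
                yield task.partial_kwargs[name]
    else:
        # A real operator: read the values off the instance attributes.
        names = getattr(type(task), "deserialization_allowed_class_fields", ())
        for name in names:
            yield getattr(task, name, None)


def collect_declared_classes(tasks: Iterable[Any]) -> set[type]:
    """Gather classes to register before any task in the DAG runs."""
    return {
        value
        for task in tasks
        for value in declared_values(task)
        if isinstance(value, type)
    }


class Report:
    pass


class Verdict:
    pass


tasks = [
    FakeOperator(output_type=Report),
    FakeMappedOperator(FakeOperator, {"output_type": Verdict}),
    object(),  # an operator that does not opt in contributes nothing
]
print(sorted(cls.__name__ for cls in collect_declared_classes(tasks)))  # ['Report', 'Verdict']
```

Reading from ``partial_kwargs`` is what lets the consumer's worker see the
mapped producer's class without ever constructing the producer.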
diff --git a/providers/common/ai/docs/operators/llm.rst b/providers/common/ai/docs/operators/llm.rst
index c542d0e7abc..1f5a375ca90 100644
--- a/providers/common/ai/docs/operators/llm.rst
+++ b/providers/common/ai/docs/operators/llm.rst
@@ -49,13 +49,15 @@ to return structured data, and the model instance is pushed to XCom unchanged
so downstream tasks can type-hint the class directly
(``def downstream(result: MyModel)``) and use attribute access
(``result.field``).
-The operator auto-registers ``output_type`` (and any ``BaseModel`` reachable from
-``Union``/``Optional``/``list`` shapes) for XCom deserialization in every
-process that parses the DAG. The Pydantic class must be defined at **module
-scope** and bound to an attribute matching its ``__name__`` -- classes nested
-inside a function or ``@dag``-decorated body, parameterized generics, and
-dynamically-built classes whose ``__name__`` does not match the attribute they
-are bound to are rejected at construction time with a ``ValueError``.
+The declared ``output_type`` (and any ``BaseModel`` reachable from
+``Union``/``Optional``/``list`` shapes) is registered for XCom deserialization by
+the worker when it loads the DAG, before any task runs -- so no edit to
+``[core] allowed_deserialization_classes`` is needed. The Pydantic class must be
+defined at **module scope** and bound to an attribute matching its ``__name__``;
+classes nested inside a function or ``@dag``-decorated body, parameterized
+generics, and dynamically-built classes whose ``__name__`` does not match the
+attribute they are bound to cannot be re-imported, so they are skipped with a
+warning at worker startup and the value fails to deserialize at the consumer.
.. exampleinclude:: /../../ai/src/airflow/providers/common/ai/example_dags/example_llm.py
:language: python
@@ -67,9 +69,10 @@ are bound to are rejected at construction time with a ``ValueError``.
:start-after: [START howto_operator_llm_structured]
:end-before: [END howto_operator_llm_structured]
-Auto-registration covers downstream tasks in the **same DAG** -- their workers
-parse the DAG file when starting up, which re-runs the operator constructor and
-re-populates the per-process allow-list.
+Registration covers downstream tasks in the **same DAG**: every worker walks the
+loaded DAG's tasks at startup and registers each declared class, so it also works
+for mapped producers (``.expand(...)``) and for workers that load DAGs from a
+cache that bypasses operator construction.
The Airflow UI's XCom viewer renders Pydantic instances via the
``stringify`` path, which produces a representation like
diff --git a/providers/common/ai/docs/operators/llm_file_analysis.rst b/providers/common/ai/docs/operators/llm_file_analysis.rst
index 9e207a5c963..5e38851ef86 100644
--- a/providers/common/ai/docs/operators/llm_file_analysis.rst
+++ b/providers/common/ai/docs/operators/llm_file_analysis.rst
@@ -78,12 +78,13 @@ Structured Output
Set ``output_type`` to a Pydantic ``BaseModel`` when you want a typed response
back from the LLM instead of a plain string. The model instance is pushed to
XCom unchanged so downstream tasks can type-hint the class directly. The
-operator auto-registers ``output_type`` (and any ``BaseModel`` reachable from
-``Union``/``Optional``/``list`` shapes) for deserialization in every process
-that parses the DAG. Define the class at **module scope** and bind it to an
-attribute matching its ``__name__``: nested-in-function classes and
-dynamically-built classes are rejected at construction time. Same-DAG
-downstream tasks need no configuration; the UI XCom viewer renders the value
+declared ``output_type`` (and any ``BaseModel`` reachable from
+``Union``/``Optional``/``list`` shapes) is registered for deserialization by the
+worker when it loads the DAG. Define the class at **module scope** and bind it to
+an attribute matching its ``__name__``: nested-in-function and dynamically-built
+classes cannot be re-imported, so they are skipped at worker startup and fail to
+deserialize at the consumer. Same-DAG downstream tasks need no configuration; the
+UI XCom viewer renders the value
via the ``stringify`` path (no configuration needed). Cross-DAG ``xcom_pull``
consumers still need the class ``qualname`` added to
``[core] allowed_deserialization_classes`` (see the ``LLMOperator`` guide for
diff --git a/providers/common/ai/src/airflow/providers/common/ai/operators/agent.py b/providers/common/ai/src/airflow/providers/common/ai/operators/agent.py
index 3b5f516ac2e..b41a0c54d8b 100644
--- a/providers/common/ai/src/airflow/providers/common/ai/operators/agent.py
+++ b/providers/common/ai/src/airflow/providers/common/ai/operators/agent.py
@@ -22,17 +22,14 @@ import json
from collections.abc import Sequence
from datetime import timedelta
from functools import cached_property
-from typing import TYPE_CHECKING, Any
+from typing import TYPE_CHECKING, Any, ClassVar
from pydantic import BaseModel
from airflow.providers.common.ai.hooks.pydantic_ai import PydanticAIHook
from airflow.providers.common.ai.mixins.hitl_review import HITLReviewMixin
from airflow.providers.common.ai.utils.logging import log_run_summary, wrap_toolsets_for_logging
-from airflow.providers.common.ai.utils.output_type import (
- iter_base_model_classes,
- rehydrate_pydantic_output,
-)
+from airflow.providers.common.ai.utils.output_type import rehydrate_pydantic_output
from airflow.providers.common.compat.sdk import (
AirflowOptionalProviderFeatureException,
BaseOperator,
@@ -42,9 +39,12 @@ from airflow.providers.common.compat.sdk import (
from airflow.providers.common.compat.version_compat import AIRFLOW_V_3_1_PLUS
try:
- from airflow.sdk.serde import allow_class
-except ImportError: # pragma: no cover - Airflow versions before allow_class shipped
- allow_class = None # type: ignore[assignment]
+ # See LLMOperator: new enough cores register declared ``output_type`` classes
+ # from a worker-side DAG walk, so the model instance flows through XCom; older
+ # cores dump to a dict instead.
+ from airflow.sdk.serde import SUPPORTS_OPERATOR_DESERIALIZATION_WALKER as _CORE_WALKER
+except ImportError: # pragma: no cover - cores before the worker-side registration walk
+ _CORE_WALKER = False
if TYPE_CHECKING:
from pydantic_ai import Agent
@@ -150,6 +150,8 @@ class AgentOperator(BaseOperator, HITLReviewMixin):
when a downstream consumer needs the dict shape.
"""
+ deserialization_allowed_class_fields: ClassVar[tuple[str, ...]] = ("output_type",)
+
template_fields: Sequence[str] = (
"prompt",
"llm_conn_id",
@@ -189,10 +191,9 @@ class AgentOperator(BaseOperator, HITLReviewMixin):
self.system_prompt = system_prompt
self.output_type = output_type
self.serialize_output = serialize_output
- self._serialize_model_output = serialize_output or allow_class is None
- if not serialize_output and allow_class is not None:
- for model_cls in iter_base_model_classes(output_type):
- allow_class(model_cls)
+ # See LLMOperator: instance flows when the core registers ``output_type``
+ # via its worker-side DAG walk; otherwise (or on opt-in) dump to a dict.
+ self._serialize_model_output = serialize_output or not _CORE_WALKER
self.toolsets = toolsets
self.enable_tool_logging = enable_tool_logging
self.agent_params = agent_params or {}
diff --git a/providers/common/ai/src/airflow/providers/common/ai/operators/llm.py b/providers/common/ai/src/airflow/providers/common/ai/operators/llm.py
index 9d104db1443..c9a22632f2e 100644
--- a/providers/common/ai/src/airflow/providers/common/ai/operators/llm.py
+++ b/providers/common/ai/src/airflow/providers/common/ai/operators/llm.py
@@ -21,23 +21,24 @@ from __future__ import annotations
from collections.abc import Sequence
from datetime import timedelta
from functools import cached_property
-from typing import TYPE_CHECKING, Any
+from typing import TYPE_CHECKING, Any, ClassVar
from pydantic import BaseModel
from airflow.providers.common.ai.hooks.pydantic_ai import PydanticAIHook
from airflow.providers.common.ai.mixins.approval import LLMApprovalMixin
from airflow.providers.common.ai.utils.logging import log_run_summary
-from airflow.providers.common.ai.utils.output_type import (
- iter_base_model_classes,
- rehydrate_pydantic_output,
-)
+from airflow.providers.common.ai.utils.output_type import rehydrate_pydantic_output
from airflow.providers.common.compat.sdk import BaseOperator
try:
- from airflow.sdk.serde import allow_class
-except ImportError: # pragma: no cover - Airflow versions before allow_class shipped
- allow_class = None # type: ignore[assignment]
+ # New enough cores register an operator's declared ``output_type`` classes for
+ # XCom deserialization from a worker-side walk over the loaded DAG. On those
+ # cores the model instance flows through XCom unchanged. Older cores lack that
+ # walk, so the operator dumps to a dict instead (still deserializable anywhere).
+ from airflow.sdk.serde import SUPPORTS_OPERATOR_DESERIALIZATION_WALKER as _CORE_WALKER
+except ImportError: # pragma: no cover - cores before the worker-side registration walk
+ _CORE_WALKER = False
if TYPE_CHECKING:
from pydantic_ai import Agent
@@ -95,6 +96,8 @@ class LLMOperator(BaseOperator, LLMApprovalMixin):
external system that expects JSON-style payloads).
"""
+ deserialization_allowed_class_fields: ClassVar[tuple[str, ...]] = ("output_type",)
+
template_fields: Sequence[str] = (
"prompt",
"llm_conn_id",
@@ -126,12 +129,10 @@ class LLMOperator(BaseOperator, LLMApprovalMixin):
self.system_prompt = system_prompt
self.output_type = output_type
self.serialize_output = serialize_output
- # Skip registration when the user opted into the dict form -- the wire
- # carries a plain dict in that case and never hits the allow-list gate.
- self._serialize_model_output = serialize_output or allow_class is None
- if not serialize_output and allow_class is not None:
- for model_cls in iter_base_model_classes(output_type):
- allow_class(model_cls)
+ # Return the Pydantic instance when the core can register ``output_type``
+ # for deserialization (its worker-side DAG walk); otherwise, or when the
+ # user opts in, dump to a dict so the value is deserializable anywhere.
+ self._serialize_model_output = serialize_output or not _CORE_WALKER
self.agent_params = agent_params or {}
self.usage_limits = usage_limits
self.require_approval = require_approval
@@ -173,9 +174,9 @@ class LLMOperator(BaseOperator, LLMApprovalMixin):
self.defer_for_approval(context, output) # type: ignore[misc]
if self._serialize_model_output and isinstance(output, BaseModel):
- # ``serialize_output=True`` was set explicitly, or this is an
- # older Airflow version without ``airflow.sdk.serde.allow_class``.
- # Either way, dump to dict so XCom carries a plain JSON payload.
+ # ``serialize_output=True``, or a core without the worker-side
+ # deserialization-class walk: dump to a dict so XCom carries a plain
+ # JSON payload that deserializes without an allow-list entry.
output = output.model_dump()
return output
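The import-probe-and-fallback pattern that both operators now use can be
reduced to a small function. ``finalize_output`` and ``FakeModel`` are
hypothetical, and the sketch checks ``hasattr(output, "model_dump")`` instead of
``isinstance(output, BaseModel)`` only so it runs without Pydantic installed. On
a core without ``SUPPORTS_OPERATOR_DESERIALIZATION_WALKER`` (or with no Airflow
at all) the import fails and the model is dumped to a ``dict``.

```python
from __future__ import annotations

from typing import Any

try:
    # Only cores that ship the worker-side registration walk export this flag.
    from airflow.sdk.serde import SUPPORTS_OPERATOR_DESERIALIZATION_WALKER as _CORE_WALKER
except ImportError:
    _CORE_WALKER = False


def finalize_output(output: Any, *, serialize_output: bool, core_walker: bool = _CORE_WALKER) -> Any:
    """Return the model instance only when the core can register its class."""
    dump = serialize_output or not core_walker
    # Duck-typed stand-in for ``isinstance(output, BaseModel)``.
    if dump and hasattr(output, "model_dump"):
        return output.model_dump()
    return output


class FakeModel:
    """Hypothetical stand-in for a Pydantic model instance."""

    def model_dump(self) -> dict[str, Any]:
        return {"text": "Great", "score": 0.95}


model = FakeModel()
print(finalize_output(model, serialize_output=False, core_walker=False))  # {'text': 'Great', 'score': 0.95}
print(finalize_output(model, serialize_output=False, core_walker=True) is model)  # True
print(finalize_output(model, serialize_output=True, core_walker=True))  # {'text': 'Great', 'score': 0.95}
```

An explicit ``serialize_output=True`` always wins, which matches the
``serialize_output or not _CORE_WALKER`` expression in the operators.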
diff --git a/providers/common/ai/src/airflow/providers/common/ai/utils/output_type.py b/providers/common/ai/src/airflow/providers/common/ai/utils/output_type.py
index 4d46b35609b..2ae9b6a5bc3 100644
--- a/providers/common/ai/src/airflow/providers/common/ai/utils/output_type.py
+++ b/providers/common/ai/src/airflow/providers/common/ai/utils/output_type.py
@@ -18,41 +18,11 @@
from __future__ import annotations
-from collections.abc import Iterator
-from typing import Any, get_args, get_origin
+from typing import Any
from pydantic import BaseModel, ValidationError
-def iter_base_model_classes(output_type: Any) -> Iterator[type[BaseModel]]:
- """
- Yield every Pydantic ``BaseModel`` subclass reachable from ``output_type``.
-
- pydantic-ai accepts ``output_type`` as a single class, a ``Union`` /
- ``Optional`` of classes, a list of classes (multi-output), or a parameterized
- generic such as ``list[MyModel]``. The agent may return an instance of any
- ``BaseModel`` reachable from the type expression, so each must be registered
- for XCom deserialization, not just the top-level ``output_type``.
- """
- seen: set[type] = set()
- stack: list[Any] = [output_type]
- while stack:
- t = stack.pop()
- # ``list[A]`` returns ``True`` for ``isinstance(t, type)`` on Python 3.10+
- # but has a non-None ``get_origin``; check origin first so we recurse
- # into its args instead of treating ``list[A]`` as a leaf type.
- origin = get_origin(t)
- if origin is not None:
- stack.extend(get_args(t))
- continue
- if isinstance(t, type):
- if t in seen:
- continue
- seen.add(t)
- if issubclass(t, BaseModel):
- yield t
-
-
def rehydrate_pydantic_output(
output_type: Any,
raw: str,
diff --git a/providers/common/ai/tests/unit/common/ai/decorators/test_agent.py b/providers/common/ai/tests/unit/common/ai/decorators/test_agent.py
index eb1e27ba87b..25a176e3297 100644
--- a/providers/common/ai/tests/unit/common/ai/decorators/test_agent.py
+++ b/providers/common/ai/tests/unit/common/ai/decorators/test_agent.py
@@ -26,15 +26,13 @@ from airflow.providers.common.ai.decorators.agent import _AgentDecoratedOperator
from airflow.providers.common.ai.toolsets.logging import LoggingToolset
try:
- from airflow.sdk.serde import allow_class
-
- _allow_class: object | None = allow_class
+ from airflow.sdk.serde import SUPPORTS_OPERATOR_DESERIALIZATION_WALKER as _CORE_WALKER
except ImportError:
- _allow_class = None
+ _CORE_WALKER = False
-requires_allow_class = pytest.mark.skipif(
- _allow_class is None,
- reason="Requires airflow.sdk.serde.allow_class (Airflow with typed-XCom support).",
+requires_typed_xcom = pytest.mark.skipif(
+ not _CORE_WALKER,
+ reason="Requires a core with the worker-side deserialization-class walk.",
)
@@ -175,7 +173,7 @@ class TestAgentDecoratedOperator:
assert isinstance(passed_toolsets[0], LoggingToolset)
assert passed_toolsets[0].wrapped is mock_toolset
- @requires_allow_class
+ @requires_typed_xcom
@patch("airflow.providers.common.ai.operators.agent.PydanticAIHook", autospec=True)
def test_execute_structured_output(self, mock_hook_cls):
"""BaseModel output flows through XCom as the Pydantic instance."""
diff --git a/providers/common/ai/tests/unit/common/ai/operators/test_agent.py b/providers/common/ai/tests/unit/common/ai/operators/test_agent.py
index b934fb77edc..4a0b08cf148 100644
--- a/providers/common/ai/tests/unit/common/ai/operators/test_agent.py
+++ b/providers/common/ai/tests/unit/common/ai/operators/test_agent.py
@@ -29,15 +29,13 @@ from airflow.providers.common.ai.toolsets.logging import LoggingToolset
from tests_common.test_utils.version_compat import AIRFLOW_V_3_1_PLUS
try:
- from airflow.sdk.serde import allow_class
-
- _allow_class: object | None = allow_class
+ from airflow.sdk.serde import SUPPORTS_OPERATOR_DESERIALIZATION_WALKER as _CORE_WALKER
except ImportError:
- _allow_class = None
+ _CORE_WALKER = False
-requires_allow_class = pytest.mark.skipif(
- _allow_class is None,
- reason="Requires airflow.sdk.serde.allow_class (Airflow with typed-XCom support).",
+requires_typed_xcom = pytest.mark.skipif(
+ not _CORE_WALKER,
+ reason="Requires a core with the worker-side deserialization-class walk.",
)
@@ -210,7 +208,7 @@ class TestAgentOperatorExecute:
assert create_call[1]["retries"] == 3
assert create_call[1]["model_settings"] == {"temperature": 0}
- @requires_allow_class
+ @requires_typed_xcom
@patch("airflow.providers.common.ai.operators.agent.PydanticAIHook", autospec=True)
def test_execute_structured_output(self, mock_hook_cls):
"""Structured output keeps the Pydantic instance so downstream tasks can type-hint it."""
@@ -230,26 +228,9 @@ class TestAgentOperatorExecute:
assert result.text == "Great"
assert result.score == 0.95
- @requires_allow_class
- def test_init_rejects_nested_output_type(self):
- """A BaseModel defined inside a function carries ``<locals>`` and can't survive XCom."""
-
- def _build():
- class Nested(BaseModel):
- v: int
-
- return AgentOperator(task_id="t", prompt="p", llm_conn_id="c", output_type=Nested)
-
- with pytest.raises(ValueError, match="defined inside a function"):
- _build()
-
- @requires_allow_class
- def test_init_registers_output_type_in_extra_allowed(self):
- from airflow.sdk.module_loading import qualname
- from airflow.sdk.serde import _extra_allowed
-
- AgentOperator(task_id="t", prompt="p", llm_conn_id="c", output_type=Summary)
- assert qualname(Summary) in _extra_allowed
+ def test_declares_output_type_for_deserialization(self):
+ """Declares ``output_type`` so the worker-side DAG walk registers it for deserialization."""
+ assert "output_type" in AgentOperator.deserialization_allowed_class_fields
@patch("airflow.providers.common.ai.operators.agent.PydanticAIHook", autospec=True)
def test_execute_with_model_id(self, mock_hook_cls):
@@ -294,7 +275,7 @@ class TestAgentOperatorExecute:
assert result == "Approved output"
mock_run_hitl.assert_called_once_with(op, context, "Initial output", message_history=msg_history)
- @requires_allow_class
+ @requires_typed_xcom
@pytest.mark.skipif(
not AIRFLOW_V_3_1_PLUS, reason="Human in the loop is only compatible with Airflow >= 3.1.0"
)
diff --git a/providers/common/ai/tests/unit/common/ai/operators/test_llm.py b/providers/common/ai/tests/unit/common/ai/operators/test_llm.py
index cfa4cf3e191..2a707752fdf 100644
--- a/providers/common/ai/tests/unit/common/ai/operators/test_llm.py
+++ b/providers/common/ai/tests/unit/common/ai/operators/test_llm.py
@@ -32,15 +32,16 @@ from airflow.providers.common.ai.operators.llm import LLMOperator
from tests_common.test_utils.version_compat import AIRFLOW_V_3_1_PLUS
try:
- from airflow.sdk.serde import allow_class
-
- _allow_class: object | None = allow_class
+ from airflow.sdk.serde import SUPPORTS_OPERATOR_DESERIALIZATION_WALKER as _CORE_WALKER
except ImportError:
- _allow_class = None
-
-requires_allow_class = pytest.mark.skipif(
- _allow_class is None,
- reason="Requires airflow.sdk.serde.allow_class (Airflow with typed-XCom support).",
+ _CORE_WALKER = False
+
+# Returning the Pydantic instance through XCom (rather than a dict) only happens
+# on cores that register declared ``output_type`` classes from the worker-side
+# DAG walk. On older cores the operator dumps to a dict, so these tests skip.
+requires_typed_xcom = pytest.mark.skipif(
+ not _CORE_WALKER,
+ reason="Requires a core with the worker-side deserialization-class walk.",
)
@@ -104,7 +105,7 @@ class TestLLMOperator:
mock_agent.run_sync.assert_called_once_with("Summarize", usage_limits=limits)
- @requires_allow_class
+ @requires_typed_xcom
@patch("airflow.providers.common.ai.operators.llm.PydanticAIHook", autospec=True)
def test_execute_structured_output_with_all_params(self, mock_hook_cls):
"""Structured output returns the Pydantic instance unchanged so downstream tasks keep the type."""
@@ -133,27 +134,13 @@ class TestLLMOperator:
model_settings={"temperature": 0.9},
)
- @requires_allow_class
- def test_init_rejects_nested_output_type(self):
- """output_type defined inside a function carries ``<locals>`` and can't survive XCom."""
-
- def _build_op():
- class Nested(BaseModel):
- v: int
-
- return LLMOperator(task_id="t", prompt="p", llm_conn_id="c", output_type=Nested)
-
- with pytest.raises(ValueError, match="defined inside a function"):
- _build_op()
-
- @requires_allow_class
- def test_init_registers_output_type_in_extra_allowed(self):
- """A module-scope BaseModel output_type is auto-registered for XCom deserialization."""
- from airflow.sdk.module_loading import qualname
- from airflow.sdk.serde import _extra_allowed
+ def test_declares_output_type_for_deserialization(self):
+ """Declares ``output_type`` so the worker-side DAG walk registers it for deserialization.
- LLMOperator(task_id="t", prompt="p", llm_conn_id="c", output_type=Entities)
- assert qualname(Entities) in _extra_allowed
+ Registration happens in the core walk over the loaded DAG (covered by the
+ task-runner tests), not as an ``__init__`` side effect.
+ """
+ assert "output_type" in LLMOperator.deserialization_allowed_class_fields
@patch("airflow.providers.common.ai.operators.llm.PydanticAIHook", autospec=True)
def test_execute_serialize_output_returns_dict(self, mock_hook_cls):
@@ -354,7 +341,7 @@ class TestLLMOperatorApproval:
assert result == "edited"
- @requires_allow_class
+ @requires_typed_xcom
def test_execute_complete_rehydrates_pydantic_for_structured_output(self):
"""When output_type is a BaseModel, execute_complete returns the model, not the JSON string."""
op = LLMOperator(task_id="t", prompt="p", llm_conn_id="c", output_type=Summary)
diff --git a/providers/common/ai/tests/unit/common/ai/operators/test_llm_file_analysis.py b/providers/common/ai/tests/unit/common/ai/operators/test_llm_file_analysis.py
index 7b32e27451d..6c970e43263 100644
--- a/providers/common/ai/tests/unit/common/ai/operators/test_llm_file_analysis.py
+++ b/providers/common/ai/tests/unit/common/ai/operators/test_llm_file_analysis.py
@@ -29,15 +29,13 @@ from airflow.providers.common.ai.utils.file_analysis import FileAnalysisRequest
from tests_common.test_utils.version_compat import AIRFLOW_V_3_1_PLUS
try:
- from airflow.sdk.serde import allow_class
-
- _allow_class: object | None = allow_class
+ from airflow.sdk.serde import SUPPORTS_OPERATOR_DESERIALIZATION_WALKER as _CORE_WALKER
except ImportError:
- _allow_class = None
+ _CORE_WALKER = False
-requires_allow_class = pytest.mark.skipif(
- _allow_class is None,
- reason="Requires airflow.sdk.serde.allow_class (Airflow with typed-XCom support).",
+requires_typed_xcom = pytest.mark.skipif(
+ not _CORE_WALKER,
+ reason="Requires a core with the worker-side deserialization-class walk.",
)
@@ -119,7 +117,7 @@ class TestLLMFileAnalysisOperator:
)
mock_agent.run_sync.assert_called_once_with("prepared prompt", usage_limits=None)
- @requires_allow_class
+ @requires_typed_xcom
@patch("airflow.providers.common.ai.operators.llm.PydanticAIHook", autospec=True)
@patch(
"airflow.providers.common.ai.operators.llm_file_analysis.build_file_analysis_request", autospec=True
@@ -272,7 +270,7 @@ class TestLLMFileAnalysisOperatorApproval:
assert exc_info.value.kwargs["generated_output"] == '{"findings":["error spike"]}'
mock_upsert.assert_called_once()
- @requires_allow_class
+ @requires_typed_xcom
def test_execute_complete_with_approval_restores_structured_output(self):
op = LLMFileAnalysisOperator(
task_id="approval_complete_test",
@@ -289,7 +287,7 @@ class TestLLMFileAnalysisOperatorApproval:
assert isinstance(result, Summary)
assert result.findings == ["error spike"]
- @requires_allow_class
+ @requires_typed_xcom
def test_execute_complete_with_approval_restores_modified_structured_output(self):
op = LLMFileAnalysisOperator(
task_id="approval_complete_modified_test",
diff --git a/providers/common/ai/tests/unit/common/ai/utils/test_output_type.py b/providers/common/ai/tests/unit/common/ai/utils/test_output_type.py
index 45971f46f60..4c3dcae5ee2 100644
--- a/providers/common/ai/tests/unit/common/ai/utils/test_output_type.py
+++ b/providers/common/ai/tests/unit/common/ai/utils/test_output_type.py
@@ -18,53 +18,13 @@ from __future__ import annotations
from pydantic import BaseModel
-from airflow.providers.common.ai.utils.output_type import (
- iter_base_model_classes,
- rehydrate_pydantic_output,
-)
+from airflow.providers.common.ai.utils.output_type import rehydrate_pydantic_output
class A(BaseModel):
x: int
-class B(BaseModel):
- y: str
-
-
-class C(BaseModel):
- z: float
-
-
-class TestIterBaseModelClasses:
- def test_single_class(self):
- assert set(iter_base_model_classes(A)) == {A}
-
- def test_str_skipped(self):
- assert set(iter_base_model_classes(str)) == set()
-
- def test_optional(self):
- assert set(iter_base_model_classes(A | None)) == {A}
-
- def test_union(self):
- assert set(iter_base_model_classes(A | B)) == {A, B}
-
- def test_list_of_models(self):
- assert set(iter_base_model_classes(list[A])) == {A}
-
- def test_dict_with_model_values(self):
- assert set(iter_base_model_classes(dict[str, A])) == {A}
-
- def test_nested_union_list_optional(self):
- assert set(iter_base_model_classes(list[A | B | None])) == {A, B}
-
- def test_mixed_with_primitives(self):
- assert set(iter_base_model_classes(A | str | int | B)) == {A, B}
-
- def test_three_models(self):
- assert set(iter_base_model_classes(A | B | C)) == {A, B, C}
-
-
class TestRehydratePydanticOutput:
def test_returns_model_instance(self):
result = rehydrate_pydantic_output(A, '{"x": 7}', serialize_output=False)
diff --git a/task-sdk/src/airflow/sdk/bases/operator.py b/task-sdk/src/airflow/sdk/bases/operator.py
index 8d6de54eb6d..97da8869686 100644
--- a/task-sdk/src/airflow/sdk/bases/operator.py
+++ b/task-sdk/src/airflow/sdk/bases/operator.py
@@ -929,6 +929,13 @@ class BaseOperator(AbstractOperator, metaclass=BaseOperatorMeta):
template_fields_renderers: ClassVar[dict[str, str]] = {}
+ # Names of constructor fields whose values may contain Pydantic model classes
+ # this operator can emit to XCom (e.g. ``("output_type",)``). The worker
+ # registers those classes in the deserialization allow-list before running
+ # tasks, so downstream consumers can deserialize the instances without an
+ # ``[core] allowed_deserialization_classes`` edit. Empty by default.
+ deserialization_allowed_class_fields: ClassVar[tuple[str, ...]] = ()
+
operator_extra_links: Collection[BaseOperatorLink] = ()
# Defines the color in the UI
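The comment above describes the opt-in contract. Below is a minimal sketch of it that does not depend on Airflow. `FakeBaseOperator`, `MyStructuredOperator`, `Summary` and `declared_values` are hypothetical stand-ins:

- The base class declares an empty tuple.
- A subclass names the constructor fields that may hold model classes.
- A reader pulls those fields back off an instance with `getattr` and skips unset (`None`) values, as the worker-side walk does.

```python
from typing import Any, ClassVar


class FakeBaseOperator:
    """Hypothetical stand-in for BaseOperator: declares nothing by default."""

    deserialization_allowed_class_fields: ClassVar[tuple[str, ...]] = ()

    def __init__(self, task_id: str) -> None:
        self.task_id = task_id


class MyStructuredOperator(FakeBaseOperator):
    """Hypothetical operator opting in for its ``output_type`` constructor field."""

    deserialization_allowed_class_fields = ("output_type",)

    def __init__(self, task_id: str, output_type: Any = str) -> None:
        super().__init__(task_id)
        self.output_type = output_type


def declared_values(op: FakeBaseOperator) -> list[Any]:
    """Collect the non-None values of every field the operator declares."""
    names = getattr(op, "deserialization_allowed_class_fields", ())
    values = [getattr(op, name, None) for name in names]
    return [value for value in values if value is not None]


class Summary:  # placeholder for a Pydantic model class
    pass


print(declared_values(FakeBaseOperator("plain")))  # []
print(declared_values(MyStructuredOperator("llm", output_type=Summary)) == [Summary])  # True
```

The default `str` would also be returned by `declared_values`. In the diff, `iter_pydantic_models` is what filters out non-model types before anything is registered.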
diff --git a/task-sdk/src/airflow/sdk/execution_time/task_runner.py b/task-sdk/src/airflow/sdk/execution_time/task_runner.py
index 0ff9be4e8c9..26f05d3a9d3 100644
--- a/task-sdk/src/airflow/sdk/execution_time/task_runner.py
+++ b/task-sdk/src/airflow/sdk/execution_time/task_runner.py
@@ -137,6 +137,7 @@ from airflow.sdk.execution_time.sentry import Sentry
from airflow.sdk.execution_time.xcom import XCom
from airflow.sdk.listener import get_listener_manager
from airflow.sdk.observability.metrics import stats_utils
+from airflow.sdk.serde import allow_class, iter_pydantic_models
from airflow.sdk.state import TaskScope
from airflow.sdk.timezone import coerce_datetime
@@ -837,6 +838,42 @@ def _maybe_reschedule_startup_failure(
)
+def _register_deserialization_allowed_classes(dag, log: Logger) -> None:
+ """
+ Register every operator-declared XCom model class in the deserialization allow-list.
+
+ Runs once per task-run startup, walking the whole DAG, so a consumer task can
+ deserialize a producer's structured output even though only the producer
+ constructs the value. Reads the declared classes off real operators
+ (``getattr``) and off not-yet-expanded mapped operators (``partial_kwargs``),
+ and works regardless of how the DAG was loaded -- a fresh parse, or a cache
+ that reconstructs operators without running ``__init__``.
+
+ Failures to register a single class are logged and skipped rather than failing
+ task startup; a genuinely undeserializable declaration surfaces later as the
+ normal allow-list ImportError at consume time.
+ """
+ for op in dag.tasks:
+ if isinstance(op, MappedOperator):
+ fields = getattr(op.operator_class, "deserialization_allowed_class_fields", ())
+ values = [op.partial_kwargs.get(field) for field in fields]
+ else:
+ fields = getattr(op, "deserialization_allowed_class_fields", ())
+ values = [getattr(op, field, None) for field in fields]
+ for value in values:
+ if value is None:
+ continue
+ for model_cls in iter_pydantic_models(value):
+ try:
+ allow_class(model_cls)
+ except ValueError as exc:
+ log.warning(
+ "Skipping XCom deserialization registration for a model class",
+ task_id=op.task_id,
+ error=str(exc),
+ )
+
+
@detail_span("parse")
def parse(what: StartupDetails, log: Logger) -> RuntimeTaskInstance:
# TODO: Task-SDK:
@@ -895,6 +932,12 @@ def parse(what: StartupDetails, log: Logger) -> RuntimeTaskInstance:
f"task is of the wrong type, got {type(task)}, wanted {BaseOperator} or {MappedOperator}"
)
+ # Register operator-declared XCom model classes (e.g. ``output_type``) for
+ # the whole DAG so this task can deserialize structured output produced by
+ # any other task -- including mapped producers and DAGs loaded from a cache
+ # that bypasses operator ``__init__``.
+ _register_deserialization_allowed_classes(dag, log)
+
# Surface the post-RUNNING startup breakdown so support engineers and DAG authors can
# attribute apparent slow startup to bundle prep (Airflow-side, e.g. git fetch) vs.
# DAG file parse (user code). Emitted before return so it lands in the task log.
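The walk has two branches. Real operators are read via `getattr`. Unexpanded mapped operators are read via `operator_class` plus `partial_kwargs`. Both branches can be sketched without Airflow.

This is a simplified illustration, not the task runner's code:

- `Op`, `TypedOp`, `Mapped` and `register_declared` are hypothetical.
- It registers each declared value directly instead of expanding it with `iter_pydantic_models`.
- It logs with the standard `logging` module rather than structlog.

```python
import logging
from collections.abc import Callable, Iterable
from dataclasses import dataclass, field
from typing import Any, ClassVar

log = logging.getLogger(__name__)


class Op:
    """Hypothetical non-mapped operator: declared fields are instance attributes."""

    deserialization_allowed_class_fields: ClassVar[tuple[str, ...]] = ()

    def __init__(self, task_id: str, **attrs: Any) -> None:
        self.task_id = task_id
        for name, value in attrs.items():
            setattr(self, name, value)


class TypedOp(Op):
    deserialization_allowed_class_fields = ("output_type",)


@dataclass
class Mapped:
    """Hypothetical unexpanded mapped task: only its operator class and kwargs."""

    task_id: str
    operator_class: type
    partial_kwargs: dict[str, Any] = field(default_factory=dict)


def register_declared(tasks: Iterable[Any], register: Callable[[Any], None]) -> None:
    for op in tasks:
        if isinstance(op, Mapped):
            # __init__ never ran: read the declaration off the class and
            # the values off the partial kwargs.
            names = getattr(op.operator_class, "deserialization_allowed_class_fields", ())
            values = [op.partial_kwargs.get(name) for name in names]
        else:
            names = getattr(op, "deserialization_allowed_class_fields", ())
            values = [getattr(op, name, None) for name in names]
        for value in values:
            if value is None:
                continue
            try:
                register(value)
            except ValueError as exc:
                # One bad declaration must not fail task startup.
                log.warning("Skipping registration for task %s: %s", op.task_id, exc)


class ModelA: ...


class ModelB: ...


registered: list[Any] = []
register_declared(
    [
        TypedOp("real", output_type=ModelA),
        Mapped("mapped", TypedOp, {"output_type": ModelB}),
        Op("plain"),
    ],
    registered.append,
)
print([cls.__name__ for cls in registered])  # ['ModelA', 'ModelB']
```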
diff --git a/task-sdk/src/airflow/sdk/serde/__init__.py b/task-sdk/src/airflow/sdk/serde/__init__.py
index 0b9a383fe41..028d1d391cf 100644
--- a/task-sdk/src/airflow/sdk/serde/__init__.py
+++ b/task-sdk/src/airflow/sdk/serde/__init__.py
@@ -22,10 +22,11 @@ import functools
import logging
import re
import sys
+from collections.abc import Iterator
from fnmatch import fnmatch
from importlib import import_module
from re import Pattern
-from typing import TYPE_CHECKING, Any, TypeVar, cast, overload
+from typing import TYPE_CHECKING, Any, TypeVar, cast, get_args, get_origin, overload
import attr
@@ -56,6 +57,12 @@ PYDANTIC_MODEL_QUALNAME = "pydantic.main.BaseModel"
DEFAULT_VERSION = 0
+# Signals that this Airflow version registers operator-declared deserialization
+# classes from a worker-side walk over the loaded DAG (see the task runner), so
+# operators do not need to register them as an ``__init__`` side effect. Providers
+# probe this flag to drop their back-compat ``__init__`` registration on cores
+# that are new enough.
+SUPPORTS_OPERATOR_DESERIALIZATION_WALKER = True
+
T = TypeVar("T", bool, float, int, dict, list, str, tuple, set)
U = bool | float | int | dict | list | str | tuple | set
S = list | tuple | set
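This is a sketch of how a provider might probe the `SUPPORTS_OPERATOR_DESERIALIZATION_WALKER` flag. The probe decides whether the provider still needs its back-compat `__init__` registration. Some assumptions apply:

- The helper names `core_walks_dag` and `flag_from_module` are hypothetical.
- The old and new cores are simulated with `types.ModuleType`.
- The real providers' probing code may differ.

Defaulting to `False` when the attribute or module is missing keeps older cores on the `__init__` path.

```python
import importlib
import types

FLAG = "SUPPORTS_OPERATOR_DESERIALIZATION_WALKER"


def core_walks_dag(serde_module: types.ModuleType) -> bool:
    """Hypothetical helper: does this core register declared classes itself?"""
    # A missing attribute means an older core: keep the __init__ registration.
    return bool(getattr(serde_module, FLAG, False))


def flag_from_module(module_name: str) -> bool:
    """Hypothetical helper: probe by import path, tolerating a missing module."""
    try:
        module = importlib.import_module(module_name)
    except ImportError:
        return False
    return core_walks_dag(module)


# Simulated cores, for illustration only.
old_core = types.ModuleType("old_serde")
new_core = types.ModuleType("new_serde")
setattr(new_core, FLAG, True)

print(core_walks_dag(old_core))  # False
print(core_walks_dag(new_core))  # True
```

In a real provider, the probe would target `airflow.sdk.serde`, for example `flag_from_module("airflow.sdk.serde")`.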
@@ -117,6 +124,35 @@ def allow_class(cls: type) -> None:
_extra_allowed.add(qn)
+def iter_pydantic_models(annotation: Any) -> Iterator[type]:
+ """
+ Yield every Pydantic model class reachable from a type annotation.
+
+ Handles a bare model class, ``Optional`` / ``Union`` of models, and
+ parameterized containers such as ``list[MyModel]`` or ``dict[str, MyModel]``
+ -- the shapes accepted as an operator ``output_type``. The agent (or
+ operator) may emit an instance of any model reachable from the annotation,
+ so each must be registered for XCom deserialization, not just the
+ top-level type.
+ """
+ seen: set[Any] = set()
+ stack: list[Any] = [annotation]
+ while stack:
+ tp = stack.pop()
+ # On some Python versions ``list[A]`` answers ``True`` to
+ # ``isinstance(tp, type)`` yet carries a non-None ``get_origin``; check the
+ # origin first and push its args onto the stack so the container itself is
+ # not mistaken for a leaf type.
+ origin = get_origin(tp)
+ if origin is not None:
+ stack.extend(get_args(tp))
+ continue
+ if isinstance(tp, type):
+ if tp in seen:
+ continue
+ seen.add(tp)
+ if is_pydantic_model(tp):
+ yield tp
+
+
def decode(d: dict[str, Any]) -> tuple[str, int, Any]:
classname = d[CLASSNAME]
version = d[VERSION]
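The stack-based type-tree walk in `iter_pydantic_models` can be exercised without Pydantic by passing the model predicate in. This sketch is an assumption for illustration, not serde's API:

- `iter_models` takes an `is_model` callable.
- Dataclasses stand in for Pydantic models.
- In serde, the corresponding check is `is_pydantic_model`.

Checking `get_origin` before `isinstance(tp, type)` is what keeps `list[A]` from being treated as a leaf.

```python
import dataclasses
from collections.abc import Callable, Iterator
from typing import Any, Optional, Union, get_args, get_origin


def iter_models(annotation: Any, is_model: Callable[[type], bool]) -> Iterator[type]:
    """Yield each class reachable from ``annotation`` that satisfies ``is_model``."""
    seen: set[type] = set()
    stack: list[Any] = [annotation]
    while stack:
        tp = stack.pop()
        # Parameterized forms (Union, Optional, list[...], dict[...]) are never
        # leaves: push their arguments and keep walking.
        if get_origin(tp) is not None:
            stack.extend(get_args(tp))
            continue
        # Plain classes are leaves: deduplicate, then keep only models.
        if isinstance(tp, type) and tp not in seen:
            seen.add(tp)
            if is_model(tp):
                yield tp


@dataclasses.dataclass
class A:
    x: int


@dataclasses.dataclass
class B:
    y: str


names = sorted(m.__name__ for m in iter_models(list[Optional[Union[A, B]]], dataclasses.is_dataclass))
print(names)  # ['A', 'B']
```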
diff --git a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py
index 507e63cfff0..097c3c3d0c3 100644
--- a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py
+++ b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py
@@ -32,11 +32,13 @@ from unittest.mock import call, patch
import pandas as pd
import pytest
+import structlog
from opentelemetry import trace
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import SimpleSpanProcessor
from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter
from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator
+from pydantic import BaseModel
from task_sdk import FAKE_BUNDLE
from uuid6 import uuid7
@@ -162,6 +164,7 @@ from airflow.sdk.execution_time.task_runner import (
_execute_task,
_make_task_span,
_push_xcom_if_needed,
+ _register_deserialization_allowed_classes,
_serialize_outlet_events,
_xcom_push,
detail_span,
@@ -248,9 +251,14 @@ def test_parse_dag_bag(mock_dagbag, test_dags_dir: Path, make_ti_context):
mock_dagbag.return_value = mock_bag_instance
mock_dag = mock.Mock(spec=DAG)
mock_task = mock.Mock(spec=BaseOperator)
+ # The worker walks dag.tasks to register declared deserialization classes;
+ # give the mock an iterable tasks list and an empty declaration so the walk
+ # is a no-op for this BundleDagBag-construction test.
+ mock_task.deserialization_allowed_class_fields = ()
mock_bag_instance.dags = {"super_basic": mock_dag}
mock_dag.task_dict = {"a": mock_task}
+ mock_dag.tasks = [mock_task]
what = StartupDetails(
ti=TaskInstanceDTO(
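`test_parse_dag_bag` has to set `mock_dag.tasks` and `deserialization_allowed_class_fields` explicitly. Attributes of a `mock.Mock(spec=...)` are themselves plain `Mock` objects, and plain `Mock` does not implement `__iter__`. Without those settings, the walk's `for op in dag.tasks` (and the loop over the declared fields) would raise `TypeError`. Below is a self-contained sketch using hypothetical `FakeDAG` and `FakeOperator` spec classes.

```python
from unittest import mock


class FakeDAG:
    """Hypothetical spec class exposing only the attribute the walk reads."""

    tasks: list = []


class FakeOperator:
    """Hypothetical spec class with the opt-in class attribute."""

    deserialization_allowed_class_fields: tuple = ()


def is_iterable(obj) -> bool:
    try:
        iter(obj)
    except TypeError:
        return False
    return True


dag = mock.Mock(spec=FakeDAG)
task = mock.Mock(spec=FakeOperator)

# Unconfigured spec'd attributes are child Mocks, which cannot be iterated.
tasks_iterable_before = is_iterable(dag.tasks)
fields_iterable_before = is_iterable(task.deserialization_allowed_class_fields)

# The fix used in the test: give both attributes real, iterable values.
task.deserialization_allowed_class_fields = ()
dag.tasks = [task]
```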
@@ -5662,3 +5670,70 @@ class TestTaskInstanceStateOperations:
for call in mock_supervisor_comms.send.call_args_list
]
assert ClearTaskState not in sent_types
+
+
+class _WalkerModelA(BaseModel):
+ a: int
+
+
+class _WalkerModelB(BaseModel):
+ b: int
+
+
+class _WalkerOperator(BaseOperator):
+ """Operator that declares a Pydantic ``output_type`` for deserialization."""
+
+ deserialization_allowed_class_fields = ("output_type",)
+
+ def __init__(self, *, output_type=str, value=None, **kwargs):
+ super().__init__(**kwargs)
+ self.output_type = output_type
+ self.value = value
+
+ def execute(self, context):
+ return None
+
+
+class TestRegisterDeserializationAllowedClasses:
+ """The worker-side walk registers operator-declared XCom classes for the whole DAG.
+
+ ``allow_class`` is patched to a spy so these assert the walk *extracts* the right
+ classes from both real and mapped operators; ``allow_class``'s own import
+ validation is covered by the serde tests.
+ """
+
+ def test_registers_real_and_mapped_operators(self):
+
+ with DAG("walker_dag") as dag:
+ # Non-mapped producer: output_type is a plain attribute.
+ _WalkerOperator(task_id="real", output_type=_WalkerModelA)
+ # Mapped producer: output_type lives in partial_kwargs and __init__ never
+ # runs at parse -- the case the old __init__ registration could not reach.
+ _WalkerOperator.partial(task_id="mapped", output_type=_WalkerModelB).expand(value=[1, 2])
+
+ registered: list[type] = []
+ with patch("airflow.sdk.execution_time.task_runner.allow_class", side_effect=registered.append):
+ _register_deserialization_allowed_classes(dag, structlog.get_logger())
+
+ assert _WalkerModelA in registered, "real operator output_type not registered"
+ assert _WalkerModelB in registered, "mapped operator output_type not registered"
+
+ def test_default_operator_registers_nothing(self):
+
+ with DAG("walker_dag_plain") as dag:
+ BaseOperator(task_id="plain")
+
+ registered: list[type] = []
+ with patch("airflow.sdk.execution_time.task_runner.allow_class", side_effect=registered.append):
+ _register_deserialization_allowed_classes(dag, structlog.get_logger())
+ assert registered == []
+
+ def test_bad_declaration_is_skipped_not_fatal(self):
+ """A class that fails ``allow_class`` validation is logged and skipped, not raised."""
+
+ with DAG("walker_dag_bad") as dag:
+ _WalkerOperator(task_id="real", output_type=_WalkerModelA)
+
+ with patch("airflow.sdk.execution_time.task_runner.allow_class", side_effect=ValueError("nope")):
+ # Must not raise -- the walk swallows per-class registration errors.
+ _register_deserialization_allowed_classes(dag, structlog.get_logger())
diff --git a/task-sdk/tests/task_sdk/serde/test_serde.py b/task-sdk/tests/task_sdk/serde/test_serde.py
index 890ed436d39..bf41f4e3ed4 100644
--- a/task-sdk/tests/task_sdk/serde/test_serde.py
+++ b/task-sdk/tests/task_sdk/serde/test_serde.py
@@ -44,6 +44,7 @@ from airflow.sdk.serde import (
_match_regexp,
allow_class,
deserialize,
+ iter_pydantic_models,
serialize,
)
@@ -455,6 +456,25 @@ class TestSerDe:
with pytest.raises(ValueError, match="cannot be re-imported|does not resolve"):
allow_class(Mismatched)
+ def test_iter_pydantic_models_shapes(self):
+ """iter_pydantic_models finds models in bare, optional/union, and container annotations."""
+
+ class M1(BaseModel):
+ a: int
+
+ class M2(BaseModel):
+ b: int
+
+ assert set(iter_pydantic_models(M1)) == {M1}
+ assert set(iter_pydantic_models(M1 | None)) == {M1}
+ assert set(iter_pydantic_models(M1 | M2)) == {M1, M2}
+ assert set(iter_pydantic_models(list[M1])) == {M1}
+ assert set(iter_pydantic_models(dict[str, M2])) == {M2}
+ assert set(iter_pydantic_models(list[M1 | M2 | None])) == {M1, M2}
+ # Non-model annotations yield nothing.
+ assert set(iter_pydantic_models(str)) == set()
+ assert set(iter_pydantic_models(list[int])) == set()
+
def test_incompatible_version(self):
data = dict(
{