This is an automated email from the ASF dual-hosted git repository.

kaxil pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 6a95da5e98d Register operator-declared XCom classes from a worker-side DAG walk (#67875)
6a95da5e98d is described below

commit 6a95da5e98d755860aec60bf946017bbc5e7442e
Author: Kaxil Naik <[email protected]>
AuthorDate: Tue Jun 2 09:58:04 2026 +0100

    Register operator-declared XCom classes from a worker-side DAG walk (#67875)
    
    The common-ai typed-XCom feature registered an operator's ``output_type``
    Pydantic class for XCom deserialization as a side effect of operator
    ``__init__``. That misses two real cases:
    
    1. Mapped producers. ``@task.llm(output_type=X).expand(...)`` registers X only
       when the mapped task unmaps and runs (which goes through ``__init__``), in
       the producer's own process. A downstream consumer runs in a different process
       that loads the DAG with the producer still an unexpanded ``MappedOperator``
       (no ``__init__``), so X is never registered there and deserialization raises
       ``ImportError``. The shipped ``example_llm_analysis_pipeline`` is this shape.
    
    2. Workers that reconstruct operators without ``__init__`` -- e.g. a parsed-DAG
       cache that loads operators -- never run the registration, so
       the same ``ImportError`` occurs.
    
    Move registration to the worker: ``task_runner.parse()`` walks the loaded DAG's
    tasks and registers each operator's declared deserialization classes before any
    task runs. It reads the classes off real operators and off not-yet-expanded
    mapped operators (``partial_kwargs``), and works regardless of how the DAG was
    loaded. Operators opt in by declaring ``deserialization_allowed_class_fields``
    (``("output_type",)`` for the common-ai operators); the generic Pydantic
    type-tree walk (``Union``/``Optional``/``list``) lives in ``serde`` as
    ``iter_pydantic_models``.
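
A minimal sketch of the walk described above, not the code added by this commit. It assumes a stand-in ``Model`` base class in place of Pydantic's ``BaseModel`` so that it runs with the standard library only. ``FakeOperator``, ``FakeMapped``, ``iter_models`` and ``collect_declared_classes`` are hypothetical names. Only ``deserialization_allowed_class_fields``, ``partial_kwargs`` and the ``Union``/``Optional``/``list`` traversal come from the message, and the ``operator_class`` attribute on the mapped stand-in is an assumption of this sketch.

```python
from __future__ import annotations

from collections.abc import Iterable, Iterator
from typing import Any, Optional, get_args, get_origin


class Model:
    """Stand-in for pydantic.BaseModel (assumption: keeps the sketch stdlib-only)."""


class Summary(Model):
    pass


class Entities(Model):
    pass


def iter_models(tp: Any) -> Iterator[type]:
    """Yield each Model subclass reachable from a type expression."""
    seen: set[type] = set()
    stack: list[Any] = [tp]
    while stack:
        t = stack.pop()
        # Check the origin first so ``list[A]`` or ``Optional[A]`` is unpacked
        # into its arguments instead of being treated as a leaf type.
        if get_origin(t) is not None:
            stack.extend(get_args(t))
            continue
        if isinstance(t, type) and t not in seen:
            seen.add(t)
            if issubclass(t, Model):
                yield t


class FakeOperator:
    """Hypothetical operator that opts in by naming the fields to inspect."""

    deserialization_allowed_class_fields = ("output_type",)

    def __init__(self, output_type: Any) -> None:
        self.output_type = output_type


class FakeMapped:
    """Hypothetical unexpanded mapped operator: no __init__ ran, kwargs are stored."""

    def __init__(self, operator_class: type, partial_kwargs: dict[str, Any]) -> None:
        self.operator_class = operator_class
        self.partial_kwargs = partial_kwargs


def collect_declared_classes(tasks: Iterable[Any]) -> set[type]:
    """Gather declared classes from real and not-yet-expanded tasks alike."""
    found: set[type] = set()
    for task in tasks:
        op_cls = getattr(task, "operator_class", type(task))
        fields = getattr(op_cls, "deserialization_allowed_class_fields", ())
        kwargs = getattr(task, "partial_kwargs", None)
        for name in fields:
            # Mapped stand-ins keep the value in partial_kwargs; real ones as an attribute.
            value = kwargs.get(name) if kwargs is not None else getattr(task, name, None)
            found.update(iter_models(value))
    return found


tasks = [
    FakeOperator(Optional[Summary]),
    FakeMapped(FakeOperator, {"output_type": list[Entities]}),
]
declared = collect_declared_classes(tasks)
```

In this sketch the mapped stand-in contributes ``Entities`` even though its constructor never ran, which is the case the ``__init__`` side effect missed.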
    
    The operator ``__init__`` registration side effect is removed. Operators probe
    ``serde.SUPPORTS_OPERATOR_DESERIALIZATION_WALKER`` and fall back to dumping the
    model to a ``dict`` on cores that lack the walk, so the value stays
    deserializable without an allow-list edit. ``allow_class`` (and thus the gate)
    is unchanged: only an operator-declared class, sourced from the trusted parsed
    DAG, is registered, and registration stays exact-match.
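
The probe-and-fallback behaviour can be sketched roughly as follows. The import path and flag name are the ones that appear in the diff below. ``Summary`` is a stand-in that only assumes a ``model_dump`` method, and ``finalize_output`` is a hypothetical helper for illustration, not a function in the provider.

```python
try:
    # Flag name and module path as they appear in this commit's diff.
    from airflow.sdk.serde import SUPPORTS_OPERATOR_DESERIALIZATION_WALKER as _CORE_WALKER
except ImportError:
    # Cores without the worker-side walk (or no Airflow installed at all).
    _CORE_WALKER = False


class Summary:
    """Stand-in for a Pydantic model; only ``model_dump`` is assumed."""

    def __init__(self, text):
        self.text = text

    def model_dump(self):
        return {"text": self.text}


def finalize_output(output, serialize_output=False, core_walker=_CORE_WALKER):
    """Return the instance when the core can register its class, else a dict."""
    if (serialize_output or not core_walker) and hasattr(output, "model_dump"):
        return output.model_dump()
    return output
```

With ``core_walker=False`` the helper returns a plain ``dict``, so the value needs no allow-list entry. With ``core_walker=True`` and no opt-in it returns the instance unchanged.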
---
 providers/common/ai/docs/changelog.rst             | 35 +++++-----
 providers/common/ai/docs/operators/agent.rst       | 17 ++---
 providers/common/ai/docs/operators/llm.rst         | 23 ++++---
 .../common/ai/docs/operators/llm_file_analysis.rst | 13 ++--
 .../airflow/providers/common/ai/operators/agent.py | 25 ++++----
 .../airflow/providers/common/ai/operators/llm.py   | 35 +++++-----
 .../providers/common/ai/utils/output_type.py       | 32 +--------
 .../tests/unit/common/ai/decorators/test_agent.py  | 14 ++--
 .../tests/unit/common/ai/operators/test_agent.py   | 39 +++--------
 .../ai/tests/unit/common/ai/operators/test_llm.py  | 47 +++++---------
 .../common/ai/operators/test_llm_file_analysis.py  | 18 +++---
 .../tests/unit/common/ai/utils/test_output_type.py | 42 +-----------
 task-sdk/src/airflow/sdk/bases/operator.py         |  7 ++
 .../src/airflow/sdk/execution_time/task_runner.py  | 43 +++++++++++++
 task-sdk/src/airflow/sdk/serde/__init__.py         | 38 ++++++++++-
 .../task_sdk/execution_time/test_task_runner.py    | 75 ++++++++++++++++++++++
 task-sdk/tests/task_sdk/serde/test_serde.py        | 20 ++++++
 17 files changed, 302 insertions(+), 221 deletions(-)

diff --git a/providers/common/ai/docs/changelog.rst b/providers/common/ai/docs/changelog.rst
index 4badccb0161..f935bc64395 100644
--- a/providers/common/ai/docs/changelog.rst
+++ b/providers/common/ai/docs/changelog.rst
@@ -29,24 +29,23 @@ Breaking change: operators with ``output_type=<BaseModel subclass>``
 (``LLMOperator``, ``LLMAgentOperator``, ``LLMFileAnalysisOperator``, and
 their ``@task.llm`` / ``@task.agent`` / ``@task.llm_file_analysis`` decorators)
 now return the Pydantic model instance through XCom instead of dumping it to
-a ``dict`` when the running Airflow version provides
-``airflow.sdk.serde.allow_class``. Downstream tasks should type-hint the model
-class (``def downstream(result: MyModel)``) and use attribute access
-(``result.field``) instead of subscript access. The output class must be
-defined at **module scope** and bound to an attribute matching its
-``__name__``; operators raise ``ValueError`` at construction time when
-``output_type`` (or any ``BaseModel`` reachable from a ``Union``/``Optional``/
-``list`` of types) is nested, dynamically built, or non-importable by ``qualname``.
-
-Same-DAG downstream tasks deserialize the model without any configuration
-change because each worker re-runs the operator constructor when it parses the
-DAG. The UI XCom viewer renders the value via the ``stringify`` path and works
-without configuration (it shows ``module.MyModel@version=1(field=value,...)``
-rather than a pretty form, but no allow-list edit is required). Cross-DAG
-``xcom_pull`` consumers still need the class qualified name added to
-``[core] allowed_deserialization_classes`` -- the consumer DAG's worker only
-parses its own DAG file. On older Airflow releases that lack ``allow_class``
-the operators continue to dump to ``dict``.
+a ``dict``, on Airflow versions whose worker registers operator-declared output
+classes for deserialization. Downstream tasks should type-hint the model class
+(``def downstream(result: MyModel)``) and use attribute access (``result.field``)
+instead of subscript access. The output class must be defined at **module scope**
+and bound to an attribute matching its ``__name__``; classes that are nested,
+dynamically built, or otherwise non-importable by ``qualname`` cannot be
+re-imported and will fail to deserialize at the consumer.
+
+The worker walks the loaded DAG and registers each declared class before any
+task runs, so same-DAG downstream tasks (including mapped ``.expand(...)``
+producers) deserialize the model without any configuration change. The UI XCom
+viewer renders the value via the ``stringify`` path and works without
+configuration (it shows ``module.MyModel@version=1(field=value,...)`` rather than
+a pretty form). Cross-DAG ``xcom_pull`` consumers still need the class qualified
+name added to ``[core] allowed_deserialization_classes`` -- the consumer DAG's
+worker only loads its own DAG. On Airflow versions whose worker does not register
+declared classes, the operators dump to ``dict`` instead.
 
 0.3.0
 .....
diff --git a/providers/common/ai/docs/operators/agent.rst b/providers/common/ai/docs/operators/agent.rst
index 51e575d139e..8e7c8ad5f39 100644
--- a/providers/common/ai/docs/operators/agent.rst
+++ b/providers/common/ai/docs/operators/agent.rst
@@ -123,14 +123,15 @@ back. The model instance is pushed to XCom unchanged so downstream tasks can
 type-hint the class directly (``def downstream(result: MyModel)``) and use
 attribute access (``result.field``).
 
-The operator auto-registers ``output_type`` (and any ``BaseModel`` reachable
-from ``Union``/``Optional``/``list`` shapes) for XCom deserialization in every
-process that parses the DAG. The Pydantic class must be defined at **module
-scope** and bound to an attribute matching its ``__name__``. Same-DAG
-downstream tasks need no configuration. The UI's XCom viewer renders the value
-via the ``stringify`` path (no configuration needed; see the ``LLMOperator``
-guide for the exact representation). Cross-DAG ``xcom_pull`` consumers still
-need the class ``qualname`` added to ``[core] allowed_deserialization_classes``.
+The declared ``output_type`` (and any ``BaseModel`` reachable from
+``Union``/``Optional``/``list`` shapes) is registered for XCom deserialization by
+the worker when it loads the DAG, before any task runs. The Pydantic class must
+be defined at **module scope** and bound to an attribute matching its
+``__name__``. Same-DAG downstream tasks need no configuration. The UI's XCom
+viewer renders the value via the ``stringify`` path (no configuration needed;
+see the ``LLMOperator`` guide for the exact representation). Cross-DAG
+``xcom_pull`` consumers still need the class ``qualname`` added to
+``[core] allowed_deserialization_classes``.
 
 .. exampleinclude:: /../../ai/src/airflow/providers/common/ai/example_dags/example_agent.py
     :language: python
diff --git a/providers/common/ai/docs/operators/llm.rst b/providers/common/ai/docs/operators/llm.rst
index c542d0e7abc..1f5a375ca90 100644
--- a/providers/common/ai/docs/operators/llm.rst
+++ b/providers/common/ai/docs/operators/llm.rst
@@ -49,13 +49,15 @@ to return structured data, and the model instance is pushed to XCom unchanged
 so downstream tasks can type-hint the class directly
 (``def downstream(result: MyModel)``) and use attribute access (``result.field``).
 
-The operator auto-registers ``output_type`` (and any ``BaseModel`` reachable from
-``Union``/``Optional``/``list`` shapes) for XCom deserialization in every
-process that parses the DAG. The Pydantic class must be defined at **module
-scope** and bound to an attribute matching its ``__name__`` -- classes nested
-inside a function or ``@dag``-decorated body, parameterized generics, and
-dynamically-built classes whose ``__name__`` does not match the attribute they
-are bound to are rejected at construction time with a ``ValueError``.
+The declared ``output_type`` (and any ``BaseModel`` reachable from
+``Union``/``Optional``/``list`` shapes) is registered for XCom deserialization by
+the worker when it loads the DAG, before any task runs -- so no edit to
+``[core] allowed_deserialization_classes`` is needed. The Pydantic class must be
+defined at **module scope** and bound to an attribute matching its ``__name__``;
+classes nested inside a function or ``@dag``-decorated body, parameterized
+generics, and dynamically-built classes whose ``__name__`` does not match the
+attribute they are bound to cannot be re-imported, so they are skipped with a
+warning at worker startup and the value fails to deserialize at the consumer.
 
 .. exampleinclude:: /../../ai/src/airflow/providers/common/ai/example_dags/example_llm.py
     :language: python
@@ -67,9 +69,10 @@ are bound to are rejected at construction time with a ``ValueError``.
     :start-after: [START howto_operator_llm_structured]
     :end-before: [END howto_operator_llm_structured]
 
-Auto-registration covers downstream tasks in the **same DAG** -- their workers
-parse the DAG file when starting up, which re-runs the operator constructor and
-re-populates the per-process allow-list.
+Registration covers downstream tasks in the **same DAG**: every worker walks the
+loaded DAG's tasks at startup and registers each declared class, so it also works
+for mapped producers (``.expand(...)``) and for workers that load DAGs from a
+cache that bypasses operator construction.
 
 The Airflow UI's XCom viewer renders Pydantic instances via the
 ``stringify`` path, which produces a representation like
diff --git a/providers/common/ai/docs/operators/llm_file_analysis.rst b/providers/common/ai/docs/operators/llm_file_analysis.rst
index 9e207a5c963..5e38851ef86 100644
--- a/providers/common/ai/docs/operators/llm_file_analysis.rst
+++ b/providers/common/ai/docs/operators/llm_file_analysis.rst
@@ -78,12 +78,13 @@ Structured Output
 Set ``output_type`` to a Pydantic ``BaseModel`` when you want a typed response
 back from the LLM instead of a plain string. The model instance is pushed to
 XCom unchanged so downstream tasks can type-hint the class directly. The
-operator auto-registers ``output_type`` (and any ``BaseModel`` reachable from
-``Union``/``Optional``/``list`` shapes) for deserialization in every process
-that parses the DAG. Define the class at **module scope** and bind it to an
-attribute matching its ``__name__``: nested-in-function classes and
-dynamically-built classes are rejected at construction time. Same-DAG
-downstream tasks need no configuration; the UI XCom viewer renders the value
+declared ``output_type`` (and any ``BaseModel`` reachable from
+``Union``/``Optional``/``list`` shapes) is registered for deserialization by the
+worker when it loads the DAG. Define the class at **module scope** and bind it to
+an attribute matching its ``__name__``: nested-in-function and dynamically-built
+classes cannot be re-imported, so they are skipped at worker startup and fail to
+deserialize at the consumer. Same-DAG downstream tasks need no configuration; the
+UI XCom viewer renders the value
 via the ``stringify`` path (no configuration needed). Cross-DAG ``xcom_pull``
 consumers still need the class ``qualname`` added to
 ``[core] allowed_deserialization_classes`` (see the ``LLMOperator`` guide for
diff --git a/providers/common/ai/src/airflow/providers/common/ai/operators/agent.py b/providers/common/ai/src/airflow/providers/common/ai/operators/agent.py
index 3b5f516ac2e..b41a0c54d8b 100644
--- a/providers/common/ai/src/airflow/providers/common/ai/operators/agent.py
+++ b/providers/common/ai/src/airflow/providers/common/ai/operators/agent.py
@@ -22,17 +22,14 @@ import json
 from collections.abc import Sequence
 from datetime import timedelta
 from functools import cached_property
-from typing import TYPE_CHECKING, Any
+from typing import TYPE_CHECKING, Any, ClassVar
 
 from pydantic import BaseModel
 
 from airflow.providers.common.ai.hooks.pydantic_ai import PydanticAIHook
 from airflow.providers.common.ai.mixins.hitl_review import HITLReviewMixin
 from airflow.providers.common.ai.utils.logging import log_run_summary, wrap_toolsets_for_logging
-from airflow.providers.common.ai.utils.output_type import (
-    iter_base_model_classes,
-    rehydrate_pydantic_output,
-)
+from airflow.providers.common.ai.utils.output_type import rehydrate_pydantic_output
 from airflow.providers.common.compat.sdk import (
     AirflowOptionalProviderFeatureException,
     BaseOperator,
@@ -42,9 +39,12 @@ from airflow.providers.common.compat.sdk import (
 from airflow.providers.common.compat.version_compat import AIRFLOW_V_3_1_PLUS
 
 try:
-    from airflow.sdk.serde import allow_class
-except ImportError:  # pragma: no cover - Airflow versions before allow_class shipped
-    allow_class = None  # type: ignore[assignment]
+    # See LLMOperator: new enough cores register declared ``output_type`` classes
+    # from a worker-side DAG walk, so the model instance flows through XCom; older
+    # cores dump to a dict instead.
+    from airflow.sdk.serde import SUPPORTS_OPERATOR_DESERIALIZATION_WALKER as _CORE_WALKER
+except ImportError:  # pragma: no cover - cores before the worker-side registration walk
+    _CORE_WALKER = False
 
 if TYPE_CHECKING:
     from pydantic_ai import Agent
@@ -150,6 +150,8 @@ class AgentOperator(BaseOperator, HITLReviewMixin):
         when a downstream consumer needs the dict shape.
     """
 
+    deserialization_allowed_class_fields: ClassVar[tuple[str, ...]] = ("output_type",)
+
     template_fields: Sequence[str] = (
         "prompt",
         "llm_conn_id",
@@ -189,10 +191,9 @@ class AgentOperator(BaseOperator, HITLReviewMixin):
         self.system_prompt = system_prompt
         self.output_type = output_type
         self.serialize_output = serialize_output
-        self._serialize_model_output = serialize_output or allow_class is None
-        if not serialize_output and allow_class is not None:
-            for model_cls in iter_base_model_classes(output_type):
-                allow_class(model_cls)
+        # See LLMOperator: instance flows when the core registers ``output_type``
+        # via its worker-side DAG walk; otherwise (or on opt-in) dump to a dict.
+        self._serialize_model_output = serialize_output or not _CORE_WALKER
         self.toolsets = toolsets
         self.enable_tool_logging = enable_tool_logging
         self.agent_params = agent_params or {}
diff --git a/providers/common/ai/src/airflow/providers/common/ai/operators/llm.py b/providers/common/ai/src/airflow/providers/common/ai/operators/llm.py
index 9d104db1443..c9a22632f2e 100644
--- a/providers/common/ai/src/airflow/providers/common/ai/operators/llm.py
+++ b/providers/common/ai/src/airflow/providers/common/ai/operators/llm.py
@@ -21,23 +21,24 @@ from __future__ import annotations
 from collections.abc import Sequence
 from datetime import timedelta
 from functools import cached_property
-from typing import TYPE_CHECKING, Any
+from typing import TYPE_CHECKING, Any, ClassVar
 
 from pydantic import BaseModel
 
 from airflow.providers.common.ai.hooks.pydantic_ai import PydanticAIHook
 from airflow.providers.common.ai.mixins.approval import LLMApprovalMixin
 from airflow.providers.common.ai.utils.logging import log_run_summary
-from airflow.providers.common.ai.utils.output_type import (
-    iter_base_model_classes,
-    rehydrate_pydantic_output,
-)
+from airflow.providers.common.ai.utils.output_type import rehydrate_pydantic_output
 from airflow.providers.common.compat.sdk import BaseOperator
 
 try:
-    from airflow.sdk.serde import allow_class
-except ImportError:  # pragma: no cover - Airflow versions before allow_class shipped
-    allow_class = None  # type: ignore[assignment]
+    # New enough cores register an operator's declared ``output_type`` classes for
+    # XCom deserialization from a worker-side walk over the loaded DAG. On those
+    # cores the model instance flows through XCom unchanged. Older cores lack that
+    # walk, so the operator dumps to a dict instead (still deserializable anywhere).
+    from airflow.sdk.serde import SUPPORTS_OPERATOR_DESERIALIZATION_WALKER as _CORE_WALKER
+except ImportError:  # pragma: no cover - cores before the worker-side registration walk
+    _CORE_WALKER = False
 
 if TYPE_CHECKING:
     from pydantic_ai import Agent
@@ -95,6 +96,8 @@ class LLMOperator(BaseOperator, LLMApprovalMixin):
         external system that expects JSON-style payloads).
     """
 
+    deserialization_allowed_class_fields: ClassVar[tuple[str, ...]] = ("output_type",)
+
     template_fields: Sequence[str] = (
         "prompt",
         "llm_conn_id",
@@ -126,12 +129,10 @@ class LLMOperator(BaseOperator, LLMApprovalMixin):
         self.system_prompt = system_prompt
         self.output_type = output_type
         self.serialize_output = serialize_output
-        # Skip registration when the user opted into the dict form -- the wire
-        # carries a plain dict in that case and never hits the allow-list gate.
-        self._serialize_model_output = serialize_output or allow_class is None
-        if not serialize_output and allow_class is not None:
-            for model_cls in iter_base_model_classes(output_type):
-                allow_class(model_cls)
+        # Return the Pydantic instance when the core can register ``output_type``
+        # for deserialization (its worker-side DAG walk); otherwise, or when the
+        # user opts in, dump to a dict so the value is deserializable anywhere.
+        self._serialize_model_output = serialize_output or not _CORE_WALKER
         self.agent_params = agent_params or {}
         self.usage_limits = usage_limits
         self.require_approval = require_approval
@@ -173,9 +174,9 @@ class LLMOperator(BaseOperator, LLMApprovalMixin):
             self.defer_for_approval(context, output)  # type: ignore[misc]
 
         if self._serialize_model_output and isinstance(output, BaseModel):
-            # ``serialize_output=True`` was set explicitly, or this is an
-            # older Airflow version without ``airflow.sdk.serde.allow_class``.
-            # Either way, dump to dict so XCom carries a plain JSON payload.
+            # ``serialize_output=True``, or a core without the worker-side
+            # deserialization-class walk: dump to a dict so XCom carries a plain
+            # JSON payload that deserializes without an allow-list entry.
             output = output.model_dump()
 
         return output
diff --git a/providers/common/ai/src/airflow/providers/common/ai/utils/output_type.py b/providers/common/ai/src/airflow/providers/common/ai/utils/output_type.py
index 4d46b35609b..2ae9b6a5bc3 100644
--- a/providers/common/ai/src/airflow/providers/common/ai/utils/output_type.py
+++ b/providers/common/ai/src/airflow/providers/common/ai/utils/output_type.py
@@ -18,41 +18,11 @@
 
 from __future__ import annotations
 
-from collections.abc import Iterator
-from typing import Any, get_args, get_origin
+from typing import Any
 
 from pydantic import BaseModel, ValidationError
 
 
-def iter_base_model_classes(output_type: Any) -> Iterator[type[BaseModel]]:
-    """
-    Yield every Pydantic ``BaseModel`` subclass reachable from ``output_type``.
-
-    pydantic-ai accepts ``output_type`` as a single class, a ``Union`` /
-    ``Optional`` of classes, a list of classes (multi-output), or a parameterized
-    generic such as ``list[MyModel]``. The agent may return an instance of any
-    ``BaseModel`` reachable from the type expression, so each must be registered
-    for XCom deserialization, not just the top-level ``output_type``.
-    """
-    seen: set[type] = set()
-    stack: list[Any] = [output_type]
-    while stack:
-        t = stack.pop()
-        # ``list[A]`` returns ``True`` for ``isinstance(t, type)`` on Python 3.10+
-        # but has a non-None ``get_origin``; check origin first so we recurse
-        # into its args instead of treating ``list[A]`` as a leaf type.
-        origin = get_origin(t)
-        if origin is not None:
-            stack.extend(get_args(t))
-            continue
-        if isinstance(t, type):
-            if t in seen:
-                continue
-            seen.add(t)
-            if issubclass(t, BaseModel):
-                yield t
-
-
 def rehydrate_pydantic_output(
     output_type: Any,
     raw: str,
diff --git a/providers/common/ai/tests/unit/common/ai/decorators/test_agent.py b/providers/common/ai/tests/unit/common/ai/decorators/test_agent.py
index eb1e27ba87b..25a176e3297 100644
--- a/providers/common/ai/tests/unit/common/ai/decorators/test_agent.py
+++ b/providers/common/ai/tests/unit/common/ai/decorators/test_agent.py
@@ -26,15 +26,13 @@ from airflow.providers.common.ai.decorators.agent import _AgentDecoratedOperator
 from airflow.providers.common.ai.toolsets.logging import LoggingToolset
 
 try:
-    from airflow.sdk.serde import allow_class
-
-    _allow_class: object | None = allow_class
+    from airflow.sdk.serde import SUPPORTS_OPERATOR_DESERIALIZATION_WALKER as _CORE_WALKER
 except ImportError:
-    _allow_class = None
+    _CORE_WALKER = False
 
-requires_allow_class = pytest.mark.skipif(
-    _allow_class is None,
-    reason="Requires airflow.sdk.serde.allow_class (Airflow with typed-XCom support).",
+requires_typed_xcom = pytest.mark.skipif(
+    not _CORE_WALKER,
+    reason="Requires a core with the worker-side deserialization-class walk.",
 )
 
 
@@ -175,7 +173,7 @@ class TestAgentDecoratedOperator:
         assert isinstance(passed_toolsets[0], LoggingToolset)
         assert passed_toolsets[0].wrapped is mock_toolset
 
-    @requires_allow_class
+    @requires_typed_xcom
     @patch("airflow.providers.common.ai.operators.agent.PydanticAIHook", autospec=True)
     def test_execute_structured_output(self, mock_hook_cls):
         """BaseModel output flows through XCom as the Pydantic instance."""
diff --git a/providers/common/ai/tests/unit/common/ai/operators/test_agent.py b/providers/common/ai/tests/unit/common/ai/operators/test_agent.py
index b934fb77edc..4a0b08cf148 100644
--- a/providers/common/ai/tests/unit/common/ai/operators/test_agent.py
+++ b/providers/common/ai/tests/unit/common/ai/operators/test_agent.py
@@ -29,15 +29,13 @@ from airflow.providers.common.ai.toolsets.logging import LoggingToolset
 from tests_common.test_utils.version_compat import AIRFLOW_V_3_1_PLUS
 
 try:
-    from airflow.sdk.serde import allow_class
-
-    _allow_class: object | None = allow_class
+    from airflow.sdk.serde import SUPPORTS_OPERATOR_DESERIALIZATION_WALKER as _CORE_WALKER
 except ImportError:
-    _allow_class = None
+    _CORE_WALKER = False
 
-requires_allow_class = pytest.mark.skipif(
-    _allow_class is None,
-    reason="Requires airflow.sdk.serde.allow_class (Airflow with typed-XCom support).",
+requires_typed_xcom = pytest.mark.skipif(
+    not _CORE_WALKER,
+    reason="Requires a core with the worker-side deserialization-class walk.",
 )
 
 
@@ -210,7 +208,7 @@ class TestAgentOperatorExecute:
         assert create_call[1]["retries"] == 3
         assert create_call[1]["model_settings"] == {"temperature": 0}
 
-    @requires_allow_class
+    @requires_typed_xcom
     @patch("airflow.providers.common.ai.operators.agent.PydanticAIHook", autospec=True)
     def test_execute_structured_output(self, mock_hook_cls):
         """Structured output keeps the Pydantic instance so downstream tasks can type-hint it."""
@@ -230,26 +228,9 @@ class TestAgentOperatorExecute:
         assert result.text == "Great"
         assert result.score == 0.95
 
-    @requires_allow_class
-    def test_init_rejects_nested_output_type(self):
-        """A BaseModel defined inside a function carries ``<locals>`` and can't survive XCom."""
-
-        def _build():
-            class Nested(BaseModel):
-                v: int
-
-            return AgentOperator(task_id="t", prompt="p", llm_conn_id="c", output_type=Nested)
-
-        with pytest.raises(ValueError, match="defined inside a function"):
-            _build()
-
-    @requires_allow_class
-    def test_init_registers_output_type_in_extra_allowed(self):
-        from airflow.sdk.module_loading import qualname
-        from airflow.sdk.serde import _extra_allowed
-
-        AgentOperator(task_id="t", prompt="p", llm_conn_id="c", output_type=Summary)
-        assert qualname(Summary) in _extra_allowed
+    def test_declares_output_type_for_deserialization(self):
+        """Declares ``output_type`` so the worker-side DAG walk registers it for deserialization."""
+        assert "output_type" in AgentOperator.deserialization_allowed_class_fields
 
     @patch("airflow.providers.common.ai.operators.agent.PydanticAIHook", autospec=True)
     def test_execute_with_model_id(self, mock_hook_cls):
@@ -294,7 +275,7 @@ class TestAgentOperatorExecute:
         assert result == "Approved output"
         mock_run_hitl.assert_called_once_with(op, context, "Initial output", message_history=msg_history)
 
-    @requires_allow_class
+    @requires_typed_xcom
     @pytest.mark.skipif(
         not AIRFLOW_V_3_1_PLUS, reason="Human in the loop is only compatible with Airflow >= 3.1.0"
     )
diff --git a/providers/common/ai/tests/unit/common/ai/operators/test_llm.py b/providers/common/ai/tests/unit/common/ai/operators/test_llm.py
index cfa4cf3e191..2a707752fdf 100644
--- a/providers/common/ai/tests/unit/common/ai/operators/test_llm.py
+++ b/providers/common/ai/tests/unit/common/ai/operators/test_llm.py
@@ -32,15 +32,16 @@ from airflow.providers.common.ai.operators.llm import LLMOperator
 from tests_common.test_utils.version_compat import AIRFLOW_V_3_1_PLUS
 
 try:
-    from airflow.sdk.serde import allow_class
-
-    _allow_class: object | None = allow_class
+    from airflow.sdk.serde import SUPPORTS_OPERATOR_DESERIALIZATION_WALKER as _CORE_WALKER
 except ImportError:
-    _allow_class = None
-
-requires_allow_class = pytest.mark.skipif(
-    _allow_class is None,
-    reason="Requires airflow.sdk.serde.allow_class (Airflow with typed-XCom support).",
+    _CORE_WALKER = False
+
+# Returning the Pydantic instance through XCom (rather than a dict) only happens
+# on cores that register declared ``output_type`` classes from the worker-side
+# DAG walk. On older cores the operator dumps to a dict, so these tests skip.
+requires_typed_xcom = pytest.mark.skipif(
+    not _CORE_WALKER,
+    reason="Requires a core with the worker-side deserialization-class walk.",
 )
 
 
@@ -104,7 +105,7 @@ class TestLLMOperator:
 
         mock_agent.run_sync.assert_called_once_with("Summarize", usage_limits=limits)
 
-    @requires_allow_class
+    @requires_typed_xcom
     @patch("airflow.providers.common.ai.operators.llm.PydanticAIHook", autospec=True)
     def test_execute_structured_output_with_all_params(self, mock_hook_cls):
         """Structured output returns the Pydantic instance unchanged so downstream tasks keep the type."""
@@ -133,27 +134,13 @@ class TestLLMOperator:
             model_settings={"temperature": 0.9},
         )
 
-    @requires_allow_class
-    def test_init_rejects_nested_output_type(self):
-        """output_type defined inside a function carries ``<locals>`` and can't survive XCom."""
-
-        def _build_op():
-            class Nested(BaseModel):
-                v: int
-
-            return LLMOperator(task_id="t", prompt="p", llm_conn_id="c", output_type=Nested)
-
-        with pytest.raises(ValueError, match="defined inside a function"):
-            _build_op()
-
-    @requires_allow_class
-    def test_init_registers_output_type_in_extra_allowed(self):
-        """A module-scope BaseModel output_type is auto-registered for XCom deserialization."""
-        from airflow.sdk.module_loading import qualname
-        from airflow.sdk.serde import _extra_allowed
+    def test_declares_output_type_for_deserialization(self):
+        """Declares ``output_type`` so the worker-side DAG walk registers it for deserialization.
 
-        LLMOperator(task_id="t", prompt="p", llm_conn_id="c", output_type=Entities)
-        assert qualname(Entities) in _extra_allowed
+        Registration happens in the core walk over the loaded DAG (covered by the
+        task-runner tests), not as an ``__init__`` side effect.
+        """
+        assert "output_type" in LLMOperator.deserialization_allowed_class_fields
 
     @patch("airflow.providers.common.ai.operators.llm.PydanticAIHook", autospec=True)
     def test_execute_serialize_output_returns_dict(self, mock_hook_cls):
@@ -354,7 +341,7 @@ class TestLLMOperatorApproval:
 
         assert result == "edited"
 
-    @requires_allow_class
+    @requires_typed_xcom
     def test_execute_complete_rehydrates_pydantic_for_structured_output(self):
         """When output_type is a BaseModel, execute_complete returns the model, not the JSON string."""
         op = LLMOperator(task_id="t", prompt="p", llm_conn_id="c", output_type=Summary)
diff --git a/providers/common/ai/tests/unit/common/ai/operators/test_llm_file_analysis.py b/providers/common/ai/tests/unit/common/ai/operators/test_llm_file_analysis.py
index 7b32e27451d..6c970e43263 100644
--- a/providers/common/ai/tests/unit/common/ai/operators/test_llm_file_analysis.py
+++ b/providers/common/ai/tests/unit/common/ai/operators/test_llm_file_analysis.py
@@ -29,15 +29,13 @@ from airflow.providers.common.ai.utils.file_analysis import FileAnalysisRequest
 from tests_common.test_utils.version_compat import AIRFLOW_V_3_1_PLUS
 
 try:
-    from airflow.sdk.serde import allow_class
-
-    _allow_class: object | None = allow_class
+    from airflow.sdk.serde import SUPPORTS_OPERATOR_DESERIALIZATION_WALKER as _CORE_WALKER
 except ImportError:
-    _allow_class = None
+    _CORE_WALKER = False
 
-requires_allow_class = pytest.mark.skipif(
-    _allow_class is None,
-    reason="Requires airflow.sdk.serde.allow_class (Airflow with typed-XCom support).",
+requires_typed_xcom = pytest.mark.skipif(
+    not _CORE_WALKER,
+    reason="Requires a core with the worker-side deserialization-class walk.",
 )
 
 
@@ -119,7 +117,7 @@ class TestLLMFileAnalysisOperator:
         )
         mock_agent.run_sync.assert_called_once_with("prepared prompt", usage_limits=None)
 
-    @requires_allow_class
+    @requires_typed_xcom
     @patch("airflow.providers.common.ai.operators.llm.PydanticAIHook", autospec=True)
     @patch(
         "airflow.providers.common.ai.operators.llm_file_analysis.build_file_analysis_request", autospec=True
@@ -272,7 +270,7 @@ class TestLLMFileAnalysisOperatorApproval:
         assert exc_info.value.kwargs["generated_output"] == '{"findings":["error spike"]}'
         mock_upsert.assert_called_once()
 
-    @requires_allow_class
+    @requires_typed_xcom
     def test_execute_complete_with_approval_restores_structured_output(self):
         op = LLMFileAnalysisOperator(
             task_id="approval_complete_test",
@@ -289,7 +287,7 @@ class TestLLMFileAnalysisOperatorApproval:
         assert isinstance(result, Summary)
         assert result.findings == ["error spike"]
 
-    @requires_allow_class
+    @requires_typed_xcom
     def test_execute_complete_with_approval_restores_modified_structured_output(self):
         op = LLMFileAnalysisOperator(
             task_id="approval_complete_modified_test",
diff --git a/providers/common/ai/tests/unit/common/ai/utils/test_output_type.py b/providers/common/ai/tests/unit/common/ai/utils/test_output_type.py
index 45971f46f60..4c3dcae5ee2 100644
--- a/providers/common/ai/tests/unit/common/ai/utils/test_output_type.py
+++ b/providers/common/ai/tests/unit/common/ai/utils/test_output_type.py
@@ -18,53 +18,13 @@ from __future__ import annotations
 
 from pydantic import BaseModel
 
-from airflow.providers.common.ai.utils.output_type import (
-    iter_base_model_classes,
-    rehydrate_pydantic_output,
-)
+from airflow.providers.common.ai.utils.output_type import rehydrate_pydantic_output
 
 
 class A(BaseModel):
     x: int
 
 
-class B(BaseModel):
-    y: str
-
-
-class C(BaseModel):
-    z: float
-
-
-class TestIterBaseModelClasses:
-    def test_single_class(self):
-        assert set(iter_base_model_classes(A)) == {A}
-
-    def test_str_skipped(self):
-        assert set(iter_base_model_classes(str)) == set()
-
-    def test_optional(self):
-        assert set(iter_base_model_classes(A | None)) == {A}
-
-    def test_union(self):
-        assert set(iter_base_model_classes(A | B)) == {A, B}
-
-    def test_list_of_models(self):
-        assert set(iter_base_model_classes(list[A])) == {A}
-
-    def test_dict_with_model_values(self):
-        assert set(iter_base_model_classes(dict[str, A])) == {A}
-
-    def test_nested_union_list_optional(self):
-        assert set(iter_base_model_classes(list[A | B | None])) == {A, B}
-
-    def test_mixed_with_primitives(self):
-        assert set(iter_base_model_classes(A | str | int | B)) == {A, B}
-
-    def test_three_models(self):
-        assert set(iter_base_model_classes(A | B | C)) == {A, B, C}
-
-
 class TestRehydratePydanticOutput:
     def test_returns_model_instance(self):
         result = rehydrate_pydantic_output(A, '{"x": 7}', serialize_output=False)
diff --git a/task-sdk/src/airflow/sdk/bases/operator.py b/task-sdk/src/airflow/sdk/bases/operator.py
index 8d6de54eb6d..97da8869686 100644
--- a/task-sdk/src/airflow/sdk/bases/operator.py
+++ b/task-sdk/src/airflow/sdk/bases/operator.py
@@ -929,6 +929,13 @@ class BaseOperator(AbstractOperator, metaclass=BaseOperatorMeta):
 
     template_fields_renderers: ClassVar[dict[str, str]] = {}
 
+    # Names of constructor fields whose values may contain Pydantic model classes
+    # this operator can emit to XCom (e.g. ``("output_type",)``). The worker
+    # registers those classes in the deserialization allow-list before running
+    # tasks, so downstream consumers can deserialize the instances without an
+    # ``[core] allowed_deserialization_classes`` edit. Empty by default.
+    deserialization_allowed_class_fields: ClassVar[tuple[str, ...]] = ()
+
     operator_extra_links: Collection[BaseOperatorLink] = ()
 
     # Defines the color in the UI
diff --git a/task-sdk/src/airflow/sdk/execution_time/task_runner.py b/task-sdk/src/airflow/sdk/execution_time/task_runner.py
index 0ff9be4e8c9..26f05d3a9d3 100644
--- a/task-sdk/src/airflow/sdk/execution_time/task_runner.py
+++ b/task-sdk/src/airflow/sdk/execution_time/task_runner.py
@@ -137,6 +137,7 @@ from airflow.sdk.execution_time.sentry import Sentry
 from airflow.sdk.execution_time.xcom import XCom
 from airflow.sdk.listener import get_listener_manager
 from airflow.sdk.observability.metrics import stats_utils
+from airflow.sdk.serde import allow_class, iter_pydantic_models
 from airflow.sdk.state import TaskScope
 from airflow.sdk.timezone import coerce_datetime
 
@@ -837,6 +838,42 @@ def _maybe_reschedule_startup_failure(
     )
 
 
+def _register_deserialization_allowed_classes(dag, log: Logger) -> None:
+    """
+    Register every operator-declared XCom model class in the deserialization allow-list.
+
+    Runs once per task-run startup, walking the whole DAG, so a consumer task can
+    deserialize a producer's structured output even though only the producer
+    constructs the value. Reads the declared classes off real operators
+    (``getattr``) and off not-yet-expanded mapped operators (``partial_kwargs``),
+    and works regardless of how the DAG was loaded -- a fresh parse, or a cache
+    that reconstructs operators without running ``__init__``.
+
+    Failures to register a single class are logged and skipped rather than failing
+    task startup; a genuinely undeserializable declaration surfaces later as the
+    normal allow-list ImportError at consume time.
+    """
+    for op in dag.tasks:
+        if isinstance(op, MappedOperator):
+            fields = getattr(op.operator_class, "deserialization_allowed_class_fields", ())
+            values = [op.partial_kwargs.get(field) for field in fields]
+        else:
+            fields = getattr(op, "deserialization_allowed_class_fields", ())
+            values = [getattr(op, field, None) for field in fields]
+        for value in values:
+            if value is None:
+                continue
+            for model_cls in iter_pydantic_models(value):
+                try:
+                    allow_class(model_cls)
+                except ValueError as exc:
+                    log.warning(
+                        "Skipping XCom deserialization registration for a model class",
+                        task_id=op.task_id,
+                        error=str(exc),
+                    )
+
+
 @detail_span("parse")
 def parse(what: StartupDetails, log: Logger) -> RuntimeTaskInstance:
     # TODO: Task-SDK:
@@ -895,6 +932,12 @@ def parse(what: StartupDetails, log: Logger) -> RuntimeTaskInstance:
             f"task is of the wrong type, got {type(task)}, wanted {BaseOperator} or {MappedOperator}"
         )
 
+    # Register operator-declared XCom model classes (e.g. ``output_type``) for
+    # the whole DAG so this task can deserialize structured output produced by
+    # any other task -- including mapped producers and DAGs loaded from a cache
+    # that bypasses operator ``__init__``.
+    _register_deserialization_allowed_classes(dag, log)
+
     # Surface the post-RUNNING startup breakdown so support engineers and DAG authors can
     # attribute apparent slow startup to bundle prep (Airflow-side, e.g. git fetch) vs.
     # DAG file parse (user code). Emitted before return so it lands in the task log.
diff --git a/task-sdk/src/airflow/sdk/serde/__init__.py b/task-sdk/src/airflow/sdk/serde/__init__.py
index 0b9a383fe41..028d1d391cf 100644
--- a/task-sdk/src/airflow/sdk/serde/__init__.py
+++ b/task-sdk/src/airflow/sdk/serde/__init__.py
@@ -22,10 +22,11 @@ import functools
 import logging
 import re
 import sys
+from collections.abc import Iterator
 from fnmatch import fnmatch
 from importlib import import_module
 from re import Pattern
-from typing import TYPE_CHECKING, Any, TypeVar, cast, overload
+from typing import TYPE_CHECKING, Any, TypeVar, cast, get_args, get_origin, overload
 
 import attr
 
@@ -56,6 +57,12 @@ PYDANTIC_MODEL_QUALNAME = "pydantic.main.BaseModel"
 
 DEFAULT_VERSION = 0
 
+# Signals that this Airflow registers operator-declared deserialization classes
+# from a worker-side walk over the loaded DAG (see the task runner), so operators
+# do not need to register them as an ``__init__`` side effect. Providers probe
+# this to drop their back-compat ``__init__`` registration on new enough cores.
+SUPPORTS_OPERATOR_DESERIALIZATION_WALKER = True
+
 T = TypeVar("T", bool, float, int, dict, list, str, tuple, set)
 U = bool | float | int | dict | list | str | tuple | set
 S = list | tuple | set
@@ -117,6 +124,35 @@ def allow_class(cls: type) -> None:
     _extra_allowed.add(qn)
 
 
+def iter_pydantic_models(annotation: Any) -> Iterator[type]:
+    """
+    Yield every Pydantic model class reachable from a type annotation.
+
+    Handles a bare model class, ``Optional`` / ``Union`` of models, and
+    parameterized containers such as ``list[MyModel]`` -- the shapes accepted as
+    an operator ``output_type``. The agent (or operator) may emit an instance of
+    any model reachable from the annotation, so each must be registered for XCom
+    deserialization, not just the top-level type.
+    """
+    seen: set[Any] = set()
+    stack: list[Any] = [annotation]
+    while stack:
+        tp = stack.pop()
+        # ``list[A]`` answers ``True`` to ``isinstance(tp, type)`` on 3.10+ yet
+        # carries a non-None ``get_origin``; recurse into its args first so the
+        # container itself is not mistaken for a leaf type.
+        origin = get_origin(tp)
+        if origin is not None:
+            stack.extend(get_args(tp))
+            continue
+        if isinstance(tp, type):
+            if tp in seen:
+                continue
+            seen.add(tp)
+            if is_pydantic_model(tp):
+                yield tp
+
+
 def decode(d: dict[str, Any]) -> tuple[str, int, Any]:
     classname = d[CLASSNAME]
     version = d[VERSION]
diff --git a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py
index 507e63cfff0..097c3c3d0c3 100644
--- a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py
+++ b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py
@@ -32,11 +32,13 @@ from unittest.mock import call, patch
 
 import pandas as pd
 import pytest
+import structlog
 from opentelemetry import trace
 from opentelemetry.sdk.trace import TracerProvider
 from opentelemetry.sdk.trace.export import SimpleSpanProcessor
 from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter
 from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator
+from pydantic import BaseModel
 from task_sdk import FAKE_BUNDLE
 from uuid6 import uuid7
 
@@ -162,6 +164,7 @@ from airflow.sdk.execution_time.task_runner import (
     _execute_task,
     _make_task_span,
     _push_xcom_if_needed,
+    _register_deserialization_allowed_classes,
     _serialize_outlet_events,
     _xcom_push,
     detail_span,
@@ -248,9 +251,14 @@ def test_parse_dag_bag(mock_dagbag, test_dags_dir: Path, make_ti_context):
     mock_dagbag.return_value = mock_bag_instance
     mock_dag = mock.Mock(spec=DAG)
     mock_task = mock.Mock(spec=BaseOperator)
+    # The worker walks dag.tasks to register declared deserialization classes;
+    # give the mock an iterable tasks list and an empty declaration so the walk
+    # is a no-op for this BundleDagBag-construction test.
+    mock_task.deserialization_allowed_class_fields = ()
 
     mock_bag_instance.dags = {"super_basic": mock_dag}
     mock_dag.task_dict = {"a": mock_task}
+    mock_dag.tasks = [mock_task]
 
     what = StartupDetails(
         ti=TaskInstanceDTO(
@@ -5662,3 +5670,70 @@ class TestTaskInstanceStateOperations:
             for call in mock_supervisor_comms.send.call_args_list
         ]
         assert ClearTaskState not in sent_types
+
+
+class _WalkerModelA(BaseModel):
+    a: int
+
+
+class _WalkerModelB(BaseModel):
+    b: int
+
+
+class _WalkerOperator(BaseOperator):
+    """Operator that declares a Pydantic ``output_type`` for deserialization."""
+
+    deserialization_allowed_class_fields = ("output_type",)
+
+    def __init__(self, *, output_type=str, value=None, **kwargs):
+        super().__init__(**kwargs)
+        self.output_type = output_type
+        self.value = value
+
+    def execute(self, context):
+        return None
+
+
+class TestRegisterDeserializationAllowedClasses:
+    """The worker-side walk registers operator-declared XCom classes for the whole DAG.
+
+    ``allow_class`` is patched to a spy so these assert the walk *extracts* the right
+    classes from both real and mapped operators; ``allow_class``'s own import
+    validation is covered by the serde tests.
+    """
+
+    def test_registers_real_and_mapped_operators(self):
+
+        with DAG("walker_dag") as dag:
+            # Non-mapped producer: output_type is a plain attribute.
+            _WalkerOperator(task_id="real", output_type=_WalkerModelA)
+            # Mapped producer: output_type lives in partial_kwargs and __init__ never
+            # runs at parse -- the case the old __init__ registration could not reach.
+            _WalkerOperator.partial(task_id="mapped", output_type=_WalkerModelB).expand(value=[1, 2])
+
+        registered: list[type] = []
+        with patch("airflow.sdk.execution_time.task_runner.allow_class", side_effect=registered.append):
+            _register_deserialization_allowed_classes(dag, structlog.get_logger())
+
+        assert _WalkerModelA in registered, "real operator output_type not registered"
+        assert _WalkerModelB in registered, "mapped operator output_type not registered"
+
+    def test_default_operator_registers_nothing(self):
+
+        with DAG("walker_dag_plain") as dag:
+            BaseOperator(task_id="plain")
+
+        registered: list[type] = []
+        with patch("airflow.sdk.execution_time.task_runner.allow_class", side_effect=registered.append):
+            _register_deserialization_allowed_classes(dag, structlog.get_logger())
+        assert registered == []
+
+    def test_bad_declaration_is_skipped_not_fatal(self):
+        """A class that fails ``allow_class`` validation is logged and skipped, not raised."""
+
+        with DAG("walker_dag_bad") as dag:
+            _WalkerOperator(task_id="real", output_type=_WalkerModelA)
+
+        with patch("airflow.sdk.execution_time.task_runner.allow_class", side_effect=ValueError("nope")):
+            # Must not raise -- the walk swallows per-class registration errors.
+            _register_deserialization_allowed_classes(dag, structlog.get_logger())
diff --git a/task-sdk/tests/task_sdk/serde/test_serde.py b/task-sdk/tests/task_sdk/serde/test_serde.py
index 890ed436d39..bf41f4e3ed4 100644
--- a/task-sdk/tests/task_sdk/serde/test_serde.py
+++ b/task-sdk/tests/task_sdk/serde/test_serde.py
@@ -44,6 +44,7 @@ from airflow.sdk.serde import (
     _match_regexp,
     allow_class,
     deserialize,
+    iter_pydantic_models,
     serialize,
 )
 
@@ -455,6 +456,25 @@ class TestSerDe:
         with pytest.raises(ValueError, match="cannot be re-imported|does not resolve"):
             allow_class(Mismatched)
 
+    def test_iter_pydantic_models_shapes(self):
+        """iter_pydantic_models finds models in bare, optional/union, and container annotations."""
+
+        class M1(BaseModel):
+            a: int
+
+        class M2(BaseModel):
+            b: int
+
+        assert set(iter_pydantic_models(M1)) == {M1}
+        assert set(iter_pydantic_models(M1 | None)) == {M1}
+        assert set(iter_pydantic_models(M1 | M2)) == {M1, M2}
+        assert set(iter_pydantic_models(list[M1])) == {M1}
+        assert set(iter_pydantic_models(dict[str, M2])) == {M2}
+        assert set(iter_pydantic_models(list[M1 | M2 | None])) == {M1, M2}
+        # Non-model annotations yield nothing.
+        assert set(iter_pydantic_models(str)) == set()
+        assert set(iter_pydantic_models(list[int])) == set()
+
     def test_incompatible_version(self):
         data = dict(
             {
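
Note for readers of this diff: the type-tree walk that ``iter_pydantic_models`` performs can be tried without Airflow or Pydantic installed. The following is a minimal sketch, not the committed implementation. ``FakeModel``, ``iter_models`` and the ``base`` parameter are hypothetical names; ``FakeModel`` stands in for ``pydantic.BaseModel``, and the ``issubclass`` test stands in for the ``is_pydantic_model`` check used in ``serde``.

```python
from collections.abc import Iterator
from typing import Any, Optional, Union, get_args, get_origin


class FakeModel:
    """Hypothetical stand-in for pydantic.BaseModel (keeps the sketch dependency-free)."""


class A(FakeModel):
    pass


class B(FakeModel):
    pass


def iter_models(annotation: Any, base: type = FakeModel) -> Iterator[type]:
    """Yield every subclass of ``base`` reachable from a type annotation."""
    seen: set[type] = set()
    stack: list[Any] = [annotation]
    while stack:
        tp = stack.pop()
        # Parameterized forms (list[A], Optional[A], Union[A, B], dict[str, A])
        # have a non-None origin: descend into their arguments instead of
        # treating the container itself as a leaf.
        if get_origin(tp) is not None:
            stack.extend(get_args(tp))
            continue
        # Leaves: keep only real classes, each visited once.
        if isinstance(tp, type) and tp not in seen:
            seen.add(tp)
            if tp is not base and issubclass(tp, base):
                yield tp


# Both models nested inside the container/union are found.
print(sorted(m.__name__ for m in iter_models(list[Union[A, B, None]])))  # prints ['A', 'B']
```

Checking ``get_origin`` before ``isinstance(tp, type)`` mirrors the ordering in the commit, so a parameterized container is always unpacked rather than mistaken for a plain class.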

