This is an automated email from the ASF dual-hosted git repository.

weilee pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 85b2666eab Add start execution from triggerer support to dynamic task 
mapping (#39912)
85b2666eab is described below

commit 85b2666eabc655a99a31609a7a27a3c577c1eefb
Author: Wei Lee <[email protected]>
AuthorDate: Mon Jul 22 16:29:37 2024 +0800

    Add start execution from triggerer support to dynamic task mapping (#39912)
    
    * feat(dagrun): add start_from_trigger support to mapped operator
    * feat(mapped_operator): add partial support to start_trigger_args
    * feat(mappedoperator): do not include xcom when expanding start trigger 
args and flag
---
 airflow/decorators/base.py                         |  6 +-
 airflow/models/abstractoperator.py                 | 22 +++++++
 airflow/models/baseoperator.py                     | 22 +++++++
 airflow/models/dagrun.py                           | 20 ++++--
 airflow/models/expandinput.py                      | 31 +++++----
 airflow/models/mappedoperator.py                   | 75 ++++++++++++++++++++--
 airflow/models/param.py                            |  2 +-
 airflow/models/taskinstance.py                     | 19 ++++--
 airflow/models/xcom_arg.py                         | 16 ++---
 airflow/template/templater.py                      |  4 +-
 airflow/utils/mixins.py                            |  2 +-
 .../authoring-and-scheduling/deferring.rst         | 46 ++++++++++++-
 tests/models/test_dagrun.py                        | 32 +++++++++
 13 files changed, 254 insertions(+), 43 deletions(-)

diff --git a/airflow/decorators/base.py b/airflow/decorators/base.py
index 5a20fa55dd..d743acbe50 100644
--- a/airflow/decorators/base.py
+++ b/airflow/decorators/base.py
@@ -550,11 +550,13 @@ class DecoratedMappedOperator(MappedOperator):
         super(DecoratedMappedOperator, 
DecoratedMappedOperator).__attrs_post_init__(self)
         XComArg.apply_upstream_relationship(self, 
self.op_kwargs_expand_input.value)
 
-    def _expand_mapped_kwargs(self, context: Context, session: Session) -> 
tuple[Mapping[str, Any], set[int]]:
+    def _expand_mapped_kwargs(
+        self, context: Context, session: Session, *, include_xcom: bool
+    ) -> tuple[Mapping[str, Any], set[int]]:
         # We only use op_kwargs_expand_input so this must always be empty.
         if self.expand_input is not EXPAND_INPUT_EMPTY:
             raise AssertionError(f"unexpected expand_input: 
{self.expand_input}")
-        op_kwargs, resolved_oids = super()._expand_mapped_kwargs(context, 
session)
+        op_kwargs, resolved_oids = super()._expand_mapped_kwargs(context, 
session, include_xcom=include_xcom)
         return {"op_kwargs": op_kwargs}, resolved_oids
 
     def _get_unmap_kwargs(self, mapped_kwargs: Mapping[str, Any], *, strict: 
bool) -> dict[str, Any]:
diff --git a/airflow/models/abstractoperator.py 
b/airflow/models/abstractoperator.py
index 89e8b6cc72..9cf1830bb4 100644
--- a/airflow/models/abstractoperator.py
+++ b/airflow/models/abstractoperator.py
@@ -55,6 +55,7 @@ if TYPE_CHECKING:
     from airflow.models.operator import Operator
     from airflow.models.taskinstance import TaskInstance
     from airflow.task.priority_strategy import PriorityWeightStrategy
+    from airflow.triggers.base import StartTriggerArgs
     from airflow.utils.task_group import TaskGroup
 
 DEFAULT_OWNER: str = conf.get_mandatory_value("operators", "default_owner")
@@ -427,6 +428,27 @@ class AbstractOperator(Templater, DAGNode):
         """
         raise NotImplementedError()
 
+    def expand_start_from_trigger(self, *, context: Context, session: Session) 
-> bool:
+        """
+        Get the start_from_trigger value of the current abstract operator.
+
+        MappedOperator uses this to unmap start_from_trigger to decide whether 
to start the task
+        execution directly from triggerer.
+
+        :meta private:
+        """
+        raise NotImplementedError()
+
+    def expand_start_trigger_args(self, *, context: Context, session: Session) 
-> StartTriggerArgs | None:
+        """
+        Get the start_trigger_args value of the current abstract operator.
+
+        MappedOperator uses this to unmap start_trigger_args to decide how to 
start a task from triggerer.
+
+        :meta private:
+        """
+        raise NotImplementedError()
+
     @property
     def priority_weight_total(self) -> int:
         """
diff --git a/airflow/models/baseoperator.py b/airflow/models/baseoperator.py
index 30ab591867..8525b78f60 100644
--- a/airflow/models/baseoperator.py
+++ b/airflow/models/baseoperator.py
@@ -1795,6 +1795,28 @@ class BaseOperator(AbstractOperator, 
metaclass=BaseOperatorMeta):
         """
         return self
 
+    def expand_start_from_trigger(self, *, context: Context, session: Session) 
-> bool:
+        """
+        Get the start_from_trigger value of the current abstract operator.
+
+        Since a BaseOperator is not mapped to begin with, this simply returns
+        the original value of start_from_trigger.
+
+        :meta private:
+        """
+        return self.start_from_trigger
+
+    def expand_start_trigger_args(self, *, context: Context, session: Session) 
-> StartTriggerArgs | None:
+        """
+        Get the start_trigger_args value of the current abstract operator.
+
+        Since a BaseOperator is not mapped to begin with, this simply returns
+        the original value of start_trigger_args.
+
+        :meta private:
+        """
+        return self.start_trigger_args
+
 
 # TODO: Deprecate for Airflow 3.0
 Chainable = Union[DependencyMixin, Sequence[DependencyMixin]]
diff --git a/airflow/models/dagrun.py b/airflow/models/dagrun.py
index d4ef937e9d..d47be6e74b 100644
--- a/airflow/models/dagrun.py
+++ b/airflow/models/dagrun.py
@@ -1577,11 +1577,21 @@ class DagRun(Base, LoggingMixin):
                 and not ti.task.outlets
             ):
                 dummy_ti_ids.append((ti.task_id, ti.map_index))
-            elif ti.task.start_from_trigger is True and 
ti.task.start_trigger_args is not None:
-                ti.start_date = timezone.utcnow()
-                if ti.state != TaskInstanceState.UP_FOR_RESCHEDULE:
-                    ti.try_number += 1
-                ti.defer_task(exception=None, session=session)
+            # check "start_trigger_args" to see whether the operator supports 
start execution from triggerer
+            # if so, we'll then check "start_from_trigger" to see whether this 
feature is turned on and defer
+            # this task.
+            # if not, we'll add this "ti" into "schedulable_ti_ids" and later 
execute it to run in the worker
+            elif ti.task.start_trigger_args is not None:
+                context = ti.get_template_context()
+                start_from_trigger = 
ti.task.expand_start_from_trigger(context=context, session=session)
+
+                if start_from_trigger:
+                    ti.start_date = timezone.utcnow()
+                    if ti.state != TaskInstanceState.UP_FOR_RESCHEDULE:
+                        ti.try_number += 1
+                    ti.defer_task(exception=None, session=session)
+                else:
+                    schedulable_ti_ids.append((ti.task_id, ti.map_index))
             else:
                 schedulable_ti_ids.append((ti.task_id, ti.map_index))
 
diff --git a/airflow/models/expandinput.py b/airflow/models/expandinput.py
index 4673eb4960..417c3bd0c1 100644
--- a/airflow/models/expandinput.py
+++ b/airflow/models/expandinput.py
@@ -69,8 +69,8 @@ class MappedArgument(ResolveMixin):
         yield from self._input.iter_references()
 
     @provide_session
-    def resolve(self, context: Context, *, session: Session = NEW_SESSION) -> 
Any:
-        data, _ = self._input.resolve(context, session=session)
+    def resolve(self, context: Context, *, include_xcom: bool, session: 
Session = NEW_SESSION) -> Any:
+        data, _ = self._input.resolve(context, session=session, 
include_xcom=include_xcom)
         return data[self._key]
 
 
@@ -165,9 +165,11 @@ class DictOfListsExpandInput(NamedTuple):
         lengths = self._get_map_lengths(run_id, session=session)
         return functools.reduce(operator.mul, (lengths[name] for name in 
self.value), 1)
 
-    def _expand_mapped_field(self, key: str, value: Any, context: Context, *, 
session: Session) -> Any:
-        if _needs_run_time_resolution(value):
-            value = value.resolve(context, session=session)
+    def _expand_mapped_field(
+        self, key: str, value: Any, context: Context, *, session: Session, 
include_xcom: bool
+    ) -> Any:
+        if include_xcom and _needs_run_time_resolution(value):
+            value = value.resolve(context, session=session, 
include_xcom=include_xcom)
         map_index = context["ti"].map_index
         if map_index < 0:
             raise RuntimeError("can't resolve task-mapping argument without 
expanding")
@@ -203,8 +205,13 @@ class DictOfListsExpandInput(NamedTuple):
             if isinstance(x, XComArg):
                 yield from x.iter_references()
 
-    def resolve(self, context: Context, session: Session) -> 
tuple[Mapping[str, Any], set[int]]:
-        data = {k: self._expand_mapped_field(k, v, context, session=session) 
for k, v in self.value.items()}
+    def resolve(
+        self, context: Context, session: Session, *, include_xcom: bool
+    ) -> tuple[Mapping[str, Any], set[int]]:
+        data = {
+            k: self._expand_mapped_field(k, v, context, session=session, 
include_xcom=include_xcom)
+            for k, v in self.value.items()
+        }
         literal_keys = {k for k, _ in self._iter_parse_time_resolved_kwargs()}
         resolved_oids = {id(v) for k, v in data.items() if k not in 
literal_keys}
         return data, resolved_oids
@@ -248,7 +255,9 @@ class ListOfDictsExpandInput(NamedTuple):
                 if isinstance(x, XComArg):
                     yield from x.iter_references()
 
-    def resolve(self, context: Context, session: Session) -> 
tuple[Mapping[str, Any], set[int]]:
+    def resolve(
+        self, context: Context, session: Session, *, include_xcom: bool
+    ) -> tuple[Mapping[str, Any], set[int]]:
         map_index = context["ti"].map_index
         if map_index < 0:
             raise RuntimeError("can't resolve task-mapping argument without 
expanding")
@@ -257,9 +266,9 @@ class ListOfDictsExpandInput(NamedTuple):
         if isinstance(self.value, collections.abc.Sized):
             mapping = self.value[map_index]
             if not isinstance(mapping, collections.abc.Mapping):
-                mapping = mapping.resolve(context, session)
-        else:
-            mappings = self.value.resolve(context, session)
+                mapping = mapping.resolve(context, session, 
include_xcom=include_xcom)
+        elif include_xcom:
+            mappings = self.value.resolve(context, session, 
include_xcom=include_xcom)
             if not isinstance(mappings, collections.abc.Sequence):
                 raise ValueError(f"expand_kwargs() expects a list[dict], not 
{_describe_type(mappings)}")
             mapping = mappings[map_index]
diff --git a/airflow/models/mappedoperator.py b/airflow/models/mappedoperator.py
index 45a6ae1ac1..2377fdab00 100644
--- a/airflow/models/mappedoperator.py
+++ b/airflow/models/mappedoperator.py
@@ -51,6 +51,7 @@ from airflow.models.pool import Pool
 from airflow.serialization.enums import DagAttributeTypes
 from airflow.task.priority_strategy import PriorityWeightStrategy, 
validate_and_load_priority_weight_strategy
 from airflow.ti_deps.deps.mapped_task_expanded import MappedTaskIsExpanded
+from airflow.triggers.base import StartTriggerArgs
 from airflow.typing_compat import Literal
 from airflow.utils.context import context_update_for_unmapped
 from airflow.utils.helpers import is_container, prevent_duplicates
@@ -81,7 +82,6 @@ if TYPE_CHECKING:
     from airflow.models.param import ParamsDict
     from airflow.models.xcom_arg import XComArg
     from airflow.ti_deps.deps.base_ti_dep import BaseTIDep
-    from airflow.triggers.base import StartTriggerArgs
     from airflow.utils.context import Context
     from airflow.utils.operator_resources import Resources
     from airflow.utils.task_group import TaskGroup
@@ -688,14 +688,16 @@ class MappedOperator(AbstractOperator):
         """Implement DAGNode."""
         return DagAttributeTypes.OP, self.task_id
 
-    def _expand_mapped_kwargs(self, context: Context, session: Session) -> 
tuple[Mapping[str, Any], set[int]]:
+    def _expand_mapped_kwargs(
+        self, context: Context, session: Session, *, include_xcom: bool
+    ) -> tuple[Mapping[str, Any], set[int]]:
         """
         Get the kwargs to create the unmapped operator.
 
         This exists because taskflow operators expand against op_kwargs, not 
the
         entire operator kwargs dict.
         """
-        return self._get_specified_expand_input().resolve(context, session)
+        return self._get_specified_expand_input().resolve(context, session, 
include_xcom=include_xcom)
 
     def _get_unmap_kwargs(self, mapped_kwargs: Mapping[str, Any], *, strict: 
bool) -> dict[str, Any]:
         """
@@ -729,6 +731,69 @@ class MappedOperator(AbstractOperator):
             "params": params,
         }
 
+    def expand_start_from_trigger(self, *, context: Context, session: Session) 
-> bool:
+        """
+        Get the start_from_trigger value of the current abstract operator.
+
+        MappedOperator uses this to unmap start_from_trigger to decide whether 
to start the task
+        execution directly from triggerer.
+
+        :meta private:
+        """
+        # start_from_trigger only makes sense when start_trigger_args exists.
+        if not self.start_trigger_args:
+            return False
+
+        mapped_kwargs, _ = self._expand_mapped_kwargs(context, session, 
include_xcom=False)
+        if self._disallow_kwargs_override:
+            prevent_duplicates(
+                self.partial_kwargs,
+                mapped_kwargs,
+                fail_reason="unmappable or already specified",
+            )
+
+        # Ordering is significant; mapped kwargs should override partial ones.
+        return mapped_kwargs.get(
+            "start_from_trigger", 
self.partial_kwargs.get("start_from_trigger", self.start_from_trigger)
+        )
+
+    def expand_start_trigger_args(self, *, context: Context, session: Session) 
-> StartTriggerArgs | None:
+        """
+        Get the kwargs to create the unmapped start_trigger_args.
+
+        This method is for allowing mapped operator to start execution from 
triggerer.
+        """
+        if not self.start_trigger_args:
+            return None
+
+        mapped_kwargs, _ = self._expand_mapped_kwargs(context, session, 
include_xcom=False)
+        if self._disallow_kwargs_override:
+            prevent_duplicates(
+                self.partial_kwargs,
+                mapped_kwargs,
+                fail_reason="unmappable or already specified",
+            )
+
+        # Ordering is significant; mapped kwargs should override partial ones.
+        trigger_kwargs = mapped_kwargs.get(
+            "trigger_kwargs",
+            self.partial_kwargs.get("trigger_kwargs", 
self.start_trigger_args.trigger_kwargs),
+        )
+        next_kwargs = mapped_kwargs.get(
+            "next_kwargs",
+            self.partial_kwargs.get("next_kwargs", 
self.start_trigger_args.next_kwargs),
+        )
+        timeout = mapped_kwargs.get(
+            "trigger_timeout", self.partial_kwargs.get("trigger_timeout", 
self.start_trigger_args.timeout)
+        )
+        return StartTriggerArgs(
+            trigger_cls=self.start_trigger_args.trigger_cls,
+            trigger_kwargs=trigger_kwargs,
+            next_method=self.start_trigger_args.next_method,
+            next_kwargs=next_kwargs,
+            timeout=timeout,
+        )
+
     def unmap(self, resolve: None | Mapping[str, Any] | tuple[Context, 
Session]) -> BaseOperator:
         """
         Get the "normal" Operator after applying the current mapping.
@@ -749,7 +814,7 @@ class MappedOperator(AbstractOperator):
             if isinstance(resolve, collections.abc.Mapping):
                 kwargs = resolve
             elif resolve is not None:
-                kwargs, _ = self._expand_mapped_kwargs(*resolve)
+                kwargs, _ = self._expand_mapped_kwargs(*resolve, 
include_xcom=True)
             else:
                 raise RuntimeError("cannot unmap a non-serialized operator 
without context")
             kwargs = self._get_unmap_kwargs(kwargs, 
strict=self._disallow_kwargs_override)
@@ -844,7 +909,7 @@ class MappedOperator(AbstractOperator):
         # set_current_task_session context manager to store the session in the 
current task.
         session = get_current_task_instance_session()
 
-        mapped_kwargs, seen_oids = self._expand_mapped_kwargs(context, session)
+        mapped_kwargs, seen_oids = self._expand_mapped_kwargs(context, 
session, include_xcom=True)
         unmapped_task = self.unmap(mapped_kwargs)
         context_update_for_unmapped(context, unmapped_task)
 
diff --git a/airflow/models/param.py b/airflow/models/param.py
index 3f56e50cd1..a5b9504a06 100644
--- a/airflow/models/param.py
+++ b/airflow/models/param.py
@@ -329,7 +329,7 @@ class DagParam(ResolveMixin):
     def iter_references(self) -> Iterable[tuple[Operator, str]]:
         return ()
 
-    def resolve(self, context: Context) -> Any:
+    def resolve(self, context: Context, *, include_xcom: bool) -> Any:
         """Pull DagParam value from DagRun context. This method is run during 
``op.execute()``."""
         with contextlib.suppress(KeyError):
             return context["dag_run"].conf[self._name]
diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py
index 393da63c74..086e89b0eb 100644
--- a/airflow/models/taskinstance.py
+++ b/airflow/models/taskinstance.py
@@ -83,6 +83,7 @@ from airflow.exceptions import (
     AirflowTaskTimeout,
     DagRunNotFound,
     RemovedInAirflow3Warning,
+    TaskDeferralError,
     TaskDeferred,
     UnmappableXComLengthPushed,
     UnmappableXComTypePushed,
@@ -1617,15 +1618,23 @@ def _defer_task(
         next_kwargs = exception.kwargs
         timeout = exception.timeout
     elif ti.task is not None and ti.task.start_trigger_args is not None:
+        context = ti.get_template_context()
+        start_trigger_args = 
ti.task.expand_start_trigger_args(context=context, session=session)
+        if start_trigger_args is None:
+            raise TaskDeferralError(
+                "A none 'None' start_trigger_args has been change to 'None' 
during expandion"
+            )
+
+        trigger_kwargs = start_trigger_args.trigger_kwargs or {}
+        next_kwargs = start_trigger_args.next_kwargs
+        next_method = start_trigger_args.next_method
+        timeout = start_trigger_args.timeout
         trigger_row = Trigger(
             classpath=ti.task.start_trigger_args.trigger_cls,
-            kwargs=ti.task.start_trigger_args.trigger_kwargs or {},
+            kwargs=trigger_kwargs,
         )
-        next_kwargs = ti.task.start_trigger_args.next_kwargs
-        next_method = ti.task.start_trigger_args.next_method
-        timeout = ti.task.start_trigger_args.timeout
     else:
-        raise AirflowException("exception and ti.task.start_trigger_args 
cannot both be None")
+        raise TaskDeferralError("exception and ti.task.start_trigger_args 
cannot both be None")
 
     # First, make the trigger entry
     session.add(trigger_row)
diff --git a/airflow/models/xcom_arg.py b/airflow/models/xcom_arg.py
index bf7d8323e8..108634bb7e 100644
--- a/airflow/models/xcom_arg.py
+++ b/airflow/models/xcom_arg.py
@@ -208,7 +208,7 @@ class XComArg(ResolveMixin, DependencyMixin):
         raise NotImplementedError()
 
     @provide_session
-    def resolve(self, context: Context, session: Session = NEW_SESSION) -> Any:
+    def resolve(self, context: Context, session: Session = NEW_SESSION, *, 
include_xcom: bool) -> Any:
         """
         Pull XCom value.
 
@@ -437,7 +437,7 @@ class PlainXComArg(XComArg):
         )
 
     @provide_session
-    def resolve(self, context: Context, session: Session = NEW_SESSION) -> Any:
+    def resolve(self, context: Context, session: Session = NEW_SESSION, *, 
include_xcom: bool) -> Any:
         ti = context["ti"]
         if TYPE_CHECKING:
             assert isinstance(ti, TaskInstance)
@@ -551,8 +551,8 @@ class MapXComArg(XComArg):
         return self.arg.get_task_map_length(run_id, session=session)
 
     @provide_session
-    def resolve(self, context: Context, session: Session = NEW_SESSION) -> Any:
-        value = self.arg.resolve(context, session=session)
+    def resolve(self, context: Context, session: Session = NEW_SESSION, *, 
include_xcom: bool) -> Any:
+        value = self.arg.resolve(context, session=session, 
include_xcom=include_xcom)
         if not isinstance(value, (Sequence, dict)):
             raise ValueError(f"XCom map expects sequence or dict, not 
{type(value).__name__}")
         return _MapResult(value, self.callables)
@@ -632,8 +632,8 @@ class ZipXComArg(XComArg):
         return max(ready_lengths)
 
     @provide_session
-    def resolve(self, context: Context, session: Session = NEW_SESSION) -> Any:
-        values = [arg.resolve(context, session=session) for arg in self.args]
+    def resolve(self, context: Context, session: Session = NEW_SESSION, *, 
include_xcom: bool) -> Any:
+        values = [arg.resolve(context, session=session, 
include_xcom=include_xcom) for arg in self.args]
         for value in values:
             if not isinstance(value, (Sequence, dict)):
                 raise ValueError(f"XCom zip expects sequence or dict, not 
{type(value).__name__}")
@@ -707,8 +707,8 @@ class ConcatXComArg(XComArg):
         return sum(ready_lengths)
 
     @provide_session
-    def resolve(self, context: Context, session: Session = NEW_SESSION) -> Any:
-        values = [arg.resolve(context, session=session) for arg in self.args]
+    def resolve(self, context: Context, session: Session = NEW_SESSION, *, 
include_xcom: bool) -> Any:
+        values = [arg.resolve(context, session=session, 
include_xcom=include_xcom) for arg in self.args]
         for value in values:
             if not isinstance(value, (Sequence, dict)):
                 raise ValueError(f"XCom concat expects sequence or dict, not 
{type(value).__name__}")
diff --git a/airflow/template/templater.py b/airflow/template/templater.py
index bfbb6f5722..3dc10d36ff 100644
--- a/airflow/template/templater.py
+++ b/airflow/template/templater.py
@@ -46,7 +46,7 @@ class LiteralValue(ResolveMixin):
     def iter_references(self) -> Iterable[tuple[Operator, str]]:
         return ()
 
-    def resolve(self, context: Context) -> Any:
+    def resolve(self, context: Context, *, include_xcom: bool) -> Any:
         return self.value
 
 
@@ -172,7 +172,7 @@ class Templater(LoggingMixin):
         if isinstance(value, ObjectStoragePath):
             return self._render_object_storage_path(value, context, jinja_env)
         if isinstance(value, ResolveMixin):
-            return value.resolve(context)
+            return value.resolve(context, include_xcom=True)
 
         # Fast path for common built-in collections.
         if value.__class__ is tuple:
diff --git a/airflow/utils/mixins.py b/airflow/utils/mixins.py
index eb0a1e81e1..324299b3a6 100644
--- a/airflow/utils/mixins.py
+++ b/airflow/utils/mixins.py
@@ -64,7 +64,7 @@ class ResolveMixin:
         """
         raise NotImplementedError
 
-    def resolve(self, context: Context) -> typing.Any:
+    def resolve(self, context: Context, *, include_xcom: bool) -> typing.Any:
         """
         Resolve this value for runtime.
 
diff --git a/docs/apache-airflow/authoring-and-scheduling/deferring.rst 
b/docs/apache-airflow/authoring-and-scheduling/deferring.rst
index a9b26703bc..d30a665c83 100644
--- a/docs/apache-airflow/authoring-and-scheduling/deferring.rst
+++ b/docs/apache-airflow/authoring-and-scheduling/deferring.rst
@@ -182,9 +182,6 @@ This is particularly useful when deferring is the only 
thing the ``execute`` met
 
 ``start_from_trigger`` and ``trigger_kwargs`` can also be modified at the 
instance level for more flexible configuration.
 
-.. warning::
-    Dynamic task mapping is not supported when ``trigger_kwargs`` is modified 
at instance level.
-
 .. code-block:: python
 
     from datetime import timedelta
@@ -213,6 +210,49 @@ This is particularly useful when deferring is the only 
thing the ``execute`` met
             # We have no more work to do here. Mark as complete.
             return
 
+To enable Dynamic Task Mapping support, define ``start_from_trigger`` and/or 
``trigger_kwargs`` as parameters of ``__init__``. You don't need to define both 
of them to use this feature, but each parameter must use exactly that name, 
because the mapped values are looked up by name (``next_kwargs`` and 
``trigger_timeout`` are looked up the same way). For example, if you define an 
argument as ``t_kwargs`` and assign its value to 
``self.start_trigger_args.trigger_kwargs``, it will not work. Note that this 
works differently from mapping an operator without ``start_from_trigger`` 
support [...]
+
+.. code-block:: python
+
+    from datetime import timedelta
+    from typing import Any
+
+    from airflow.sensors.base import BaseSensorOperator
+    from airflow.triggers.temporal import TimeDeltaTrigger
+    from airflow.utils.context import Context
+
+
+    class WaitTwoHourSensor(BaseSensorOperator):
+        start_trigger_args = StartTriggerArgs(
+            trigger_cls="airflow.triggers.temporal.TimeDeltaTrigger",
+            trigger_kwargs={},
+            next_method="execute_complete",
+            timeout=None,
+        )
+
+        def __init__(
+            self,
+            *args: list[Any],
+            trigger_kwargs: dict[str, Any] | None,
+            start_from_trigger: bool,
+            **kwargs: dict[str, Any],
+        ) -> None:
+            super().__init__(*args, **kwargs)
+            self.start_trigger_args.trigger_kwargs = trigger_kwargs
+            self.start_from_trigger = start_from_trigger
+
+        def execute_complete(self, context: Context, event: dict[str, Any] | 
None = None) -> None:
+            # We have no more work to do here. Mark as complete.
+            return
+
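+These parameters can be mapped using the ``partial`` and ``expand`` methods. 
Note that XCom values are not resolved when these parameters are expanded at 
scheduling time, so map them over literal values rather than over the output 
of another task.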
+
+.. code-block:: python
+
+    WaitTwoHourSensor.partial(task_id="transform", 
start_from_trigger=True).expand(
+        trigger_kwargs=[{"moment": timedelta(hours=2)}, {"moment": 
timedelta(hours=2)}]
+    )
+
 Writing Triggers
 ~~~~~~~~~~~~~~~~
 
diff --git a/tests/models/test_dagrun.py b/tests/models/test_dagrun.py
index 021d09e889..7a791aa365 100644
--- a/tests/models/test_dagrun.py
+++ b/tests/models/test_dagrun.py
@@ -2043,6 +2043,38 @@ def 
test_schedule_tis_empty_operator_try_number(dag_maker, session: Session):
     assert empty_ti.try_number == 1
 
 
+def test_schedule_tis_start_trigger_through_expand(dag_maker, session):
+    """
+    Test that an operator with start_trigger_args set can be directly deferred 
during scheduling.
+    """
+
+    class TestOperator(BaseOperator):
+        start_trigger_args = StartTriggerArgs(
+            trigger_cls="airflow.triggers.testing.SuccessTrigger",
+            trigger_kwargs={},
+            next_method="execute_complete",
+            timeout=None,
+        )
+        start_from_trigger = False
+
+        def __init__(self, *args, start_from_trigger: bool = False, **kwargs):
+            super().__init__(*args, **kwargs)
+            self.start_from_trigger = start_from_trigger
+
+        def execute_complete(self):
+            pass
+
+    with dag_maker(session=session):
+        
TestOperator.partial(task_id="test_task").expand(start_from_trigger=[True, 
False])
+
+    dr: DagRun = dag_maker.create_dagrun()
+
+    dr.schedule_tis(dr.task_instances, session=session)
+    tis = [(ti.state, ti.map_index) for ti in dr.task_instances]
+    assert tis[0] == (TaskInstanceState.DEFERRED, 0)
+    assert tis[1] == (None, 1)
+
+
 def test_mapped_expand_kwargs(dag_maker):
     with dag_maker():
 

Reply via email to