This is an automated email from the ASF dual-hosted git repository.
weilee pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 85b2666eab Add start execution from triggerer support to dynamic task mapping (#39912)
85b2666eab is described below
commit 85b2666eabc655a99a31609a7a27a3c577c1eefb
Author: Wei Lee <[email protected]>
AuthorDate: Mon Jul 22 16:29:37 2024 +0800
Add start execution from triggerer support to dynamic task mapping (#39912)
* feat(dagrun): add start_from_trigger support to mapped operator
* feat(mapped_operator): add partial support to start_trigger_args
* feat(mappedoperator): do not include xcom when expanding start trigger args and flag
---
airflow/decorators/base.py | 6 +-
airflow/models/abstractoperator.py | 22 +++++++
airflow/models/baseoperator.py | 22 +++++++
airflow/models/dagrun.py | 20 ++++--
airflow/models/expandinput.py | 31 +++++----
airflow/models/mappedoperator.py | 75 ++++++++++++++++++++--
airflow/models/param.py | 2 +-
airflow/models/taskinstance.py | 19 ++++--
airflow/models/xcom_arg.py | 16 ++---
airflow/template/templater.py | 4 +-
airflow/utils/mixins.py | 2 +-
.../authoring-and-scheduling/deferring.rst | 46 ++++++++++++-
tests/models/test_dagrun.py | 32 +++++++++
13 files changed, 254 insertions(+), 43 deletions(-)
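
At a glance, the feature added by this commit is used from a DAG roughly as follows. This is a minimal sketch, not code from the commit: ``WaitTwoHourSensor`` is the example operator defined in the deferring.rst change below, and the DAG id and kwarg values here are illustrative only.

    from datetime import timedelta

    from airflow.models.dag import DAG

    # WaitTwoHourSensor: the example sensor from the docs change below.
    with DAG(dag_id="start_from_trigger_mapping_demo", schedule=None):
        # Each expanded task instance starts directly in the triggerer,
        # skipping worker startup entirely.
        WaitTwoHourSensor.partial(task_id="wait", start_from_trigger=True).expand(
            trigger_kwargs=[{"moment": timedelta(hours=2)}, {"moment": timedelta(hours=4)}]
        )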
diff --git a/airflow/decorators/base.py b/airflow/decorators/base.py
index 5a20fa55dd..d743acbe50 100644
--- a/airflow/decorators/base.py
+++ b/airflow/decorators/base.py
@@ -550,11 +550,13 @@ class DecoratedMappedOperator(MappedOperator):
super(DecoratedMappedOperator, DecoratedMappedOperator).__attrs_post_init__(self)
XComArg.apply_upstream_relationship(self, self.op_kwargs_expand_input.value)
- def _expand_mapped_kwargs(self, context: Context, session: Session) -> tuple[Mapping[str, Any], set[int]]:
+ def _expand_mapped_kwargs(
+ self, context: Context, session: Session, *, include_xcom: bool
+ ) -> tuple[Mapping[str, Any], set[int]]:
# We only use op_kwargs_expand_input so this must always be empty.
if self.expand_input is not EXPAND_INPUT_EMPTY:
raise AssertionError(f"unexpected expand_input: {self.expand_input}")
- op_kwargs, resolved_oids = super()._expand_mapped_kwargs(context, session)
+ op_kwargs, resolved_oids = super()._expand_mapped_kwargs(context, session, include_xcom=include_xcom)
return {"op_kwargs": op_kwargs}, resolved_oids
def _get_unmap_kwargs(self, mapped_kwargs: Mapping[str, Any], *, strict: bool) -> dict[str, Any]:
diff --git a/airflow/models/abstractoperator.py b/airflow/models/abstractoperator.py
index 89e8b6cc72..9cf1830bb4 100644
--- a/airflow/models/abstractoperator.py
+++ b/airflow/models/abstractoperator.py
@@ -55,6 +55,7 @@ if TYPE_CHECKING:
from airflow.models.operator import Operator
from airflow.models.taskinstance import TaskInstance
from airflow.task.priority_strategy import PriorityWeightStrategy
+ from airflow.triggers.base import StartTriggerArgs
from airflow.utils.task_group import TaskGroup
DEFAULT_OWNER: str = conf.get_mandatory_value("operators", "default_owner")
@@ -427,6 +428,27 @@ class AbstractOperator(Templater, DAGNode):
"""
raise NotImplementedError()
+ def expand_start_from_trigger(self, *, context: Context, session: Session) -> bool:
+ """
+ Get the start_from_trigger value of the current abstract operator.
+
+ MappedOperator uses this to unmap start_from_trigger to decide whether to start the task
+ execution directly from the triggerer.
+
+ :meta private:
+ """
+ raise NotImplementedError()
+
+ def expand_start_trigger_args(self, *, context: Context, session: Session) -> StartTriggerArgs | None:
+ """
+ Get the start_trigger_args value of the current abstract operator.
+
+ MappedOperator uses this to unmap start_trigger_args to decide how to start a task from the triggerer.
+
+ :meta private:
+ """
+ raise NotImplementedError()
+
@property
def priority_weight_total(self) -> int:
"""
diff --git a/airflow/models/baseoperator.py b/airflow/models/baseoperator.py
index 30ab591867..8525b78f60 100644
--- a/airflow/models/baseoperator.py
+++ b/airflow/models/baseoperator.py
@@ -1795,6 +1795,28 @@ class BaseOperator(AbstractOperator, metaclass=BaseOperatorMeta):
"""
return self
+ def expand_start_from_trigger(self, *, context: Context, session: Session) -> bool:
+ """
+ Get the start_from_trigger value of the current abstract operator.
+
+ Since a BaseOperator is not mapped to begin with, this simply returns
+ the original value of start_from_trigger.
+
+ :meta private:
+ """
+ return self.start_from_trigger
+
+ def expand_start_trigger_args(self, *, context: Context, session: Session) -> StartTriggerArgs | None:
+ """
+ Get the start_trigger_args value of the current abstract operator.
+
+ Since a BaseOperator is not mapped to begin with, this simply returns
+ the original value of start_trigger_args.
+
+ :meta private:
+ """
+ return self.start_trigger_args
+
# TODO: Deprecate for Airflow 3.0
Chainable = Union[DependencyMixin, Sequence[DependencyMixin]]
diff --git a/airflow/models/dagrun.py b/airflow/models/dagrun.py
index d4ef937e9d..d47be6e74b 100644
--- a/airflow/models/dagrun.py
+++ b/airflow/models/dagrun.py
@@ -1577,11 +1577,21 @@ class DagRun(Base, LoggingMixin):
and not ti.task.outlets
):
dummy_ti_ids.append((ti.task_id, ti.map_index))
- elif ti.task.start_from_trigger is True and ti.task.start_trigger_args is not None:
- ti.start_date = timezone.utcnow()
- if ti.state != TaskInstanceState.UP_FOR_RESCHEDULE:
- ti.try_number += 1
- ti.defer_task(exception=None, session=session)
+ # Check "start_trigger_args" to see whether the operator supports starting execution from the triggerer.
+ # If so, check "start_from_trigger" to see whether the feature is turned on and, if it is, defer this task.
+ # If not, add this "ti" to "schedulable_ti_ids" so it is later queued to run on a worker.
+ elif ti.task.start_trigger_args is not None:
+ context = ti.get_template_context()
+ start_from_trigger = ti.task.expand_start_from_trigger(context=context, session=session)
+
+ if start_from_trigger:
+ ti.start_date = timezone.utcnow()
+ if ti.state != TaskInstanceState.UP_FOR_RESCHEDULE:
+ ti.try_number += 1
+ ti.defer_task(exception=None, session=session)
+ else:
+ schedulable_ti_ids.append((ti.task_id, ti.map_index))
else:
schedulable_ti_ids.append((ti.task_id, ti.map_index))
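
The new branch in ``DagRun.schedule_tis`` reduces to the following decision, shown here as a standalone sketch for readability; names mirror the diff above, and ``ti`` is a task instance being scheduled.

    def route_ti(ti, session, schedulable_ti_ids):
        # The operator supports starting from the triggerer only if it
        # declares start_trigger_args; otherwise fall through.
        if ti.task.start_trigger_args is not None:
            context = ti.get_template_context()
            if ti.task.expand_start_from_trigger(context=context, session=session):
                ti.defer_task(exception=None, session=session)  # no worker slot used
                return
        # Normal path: queue the task instance to run on a worker.
        schedulable_ti_ids.append((ti.task_id, ti.map_index))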
diff --git a/airflow/models/expandinput.py b/airflow/models/expandinput.py
index 4673eb4960..417c3bd0c1 100644
--- a/airflow/models/expandinput.py
+++ b/airflow/models/expandinput.py
@@ -69,8 +69,8 @@ class MappedArgument(ResolveMixin):
yield from self._input.iter_references()
@provide_session
- def resolve(self, context: Context, *, session: Session = NEW_SESSION) -> Any:
- data, _ = self._input.resolve(context, session=session)
+ def resolve(self, context: Context, *, include_xcom: bool, session: Session = NEW_SESSION) -> Any:
+ data, _ = self._input.resolve(context, session=session, include_xcom=include_xcom)
return data[self._key]
@@ -165,9 +165,11 @@ class DictOfListsExpandInput(NamedTuple):
lengths = self._get_map_lengths(run_id, session=session)
return functools.reduce(operator.mul, (lengths[name] for name in self.value), 1)
- def _expand_mapped_field(self, key: str, value: Any, context: Context, *, session: Session) -> Any:
- if _needs_run_time_resolution(value):
- value = value.resolve(context, session=session)
+ def _expand_mapped_field(
+ self, key: str, value: Any, context: Context, *, session: Session, include_xcom: bool
+ ) -> Any:
+ if include_xcom and _needs_run_time_resolution(value):
+ value = value.resolve(context, session=session, include_xcom=include_xcom)
map_index = context["ti"].map_index
if map_index < 0:
raise RuntimeError("can't resolve task-mapping argument without expanding")
@@ -203,8 +205,13 @@ class DictOfListsExpandInput(NamedTuple):
if isinstance(x, XComArg):
yield from x.iter_references()
- def resolve(self, context: Context, session: Session) -> tuple[Mapping[str, Any], set[int]]:
- data = {k: self._expand_mapped_field(k, v, context, session=session) for k, v in self.value.items()}
+ def resolve(
+ self, context: Context, session: Session, *, include_xcom: bool
+ ) -> tuple[Mapping[str, Any], set[int]]:
+ data = {
+ k: self._expand_mapped_field(k, v, context, session=session, include_xcom=include_xcom)
+ for k, v in self.value.items()
+ }
literal_keys = {k for k, _ in self._iter_parse_time_resolved_kwargs()}
resolved_oids = {id(v) for k, v in data.items() if k not in literal_keys}
return data, resolved_oids
@@ -248,7 +255,9 @@ class ListOfDictsExpandInput(NamedTuple):
if isinstance(x, XComArg):
yield from x.iter_references()
- def resolve(self, context: Context, session: Session) -> tuple[Mapping[str, Any], set[int]]:
+ def resolve(
+ self, context: Context, session: Session, *, include_xcom: bool
+ ) -> tuple[Mapping[str, Any], set[int]]:
map_index = context["ti"].map_index
if map_index < 0:
raise RuntimeError("can't resolve task-mapping argument without expanding")
@@ -257,9 +266,9 @@ class ListOfDictsExpandInput(NamedTuple):
if isinstance(self.value, collections.abc.Sized):
mapping = self.value[map_index]
if not isinstance(mapping, collections.abc.Mapping):
- mapping = mapping.resolve(context, session)
- else:
- mappings = self.value.resolve(context, session)
+ mapping = mapping.resolve(context, session, include_xcom=include_xcom)
+ elif include_xcom:
+ mappings = self.value.resolve(context, session, include_xcom=include_xcom)
if not isinstance(mappings, collections.abc.Sequence):
raise ValueError(f"expand_kwargs() expects a list[dict], not {_describe_type(mappings)}")
mapping = mappings[map_index]
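
The ``include_xcom`` flag threaded through these ``resolve`` signatures exists because scheduling-time expansion happens before any XCom values are available. A sketch of the gate, using the module's ``_needs_run_time_resolution`` helper as in the diff above:

    # include_xcom=True: normal runtime expansion; XComArg values are pulled.
    # include_xcom=False: scheduling-time expansion; values that need runtime
    # resolution are left as-is, so no XCom pull is ever attempted.
    def expand_field(value, context, session, *, include_xcom: bool):
        if include_xcom and _needs_run_time_resolution(value):
            return value.resolve(context, session=session, include_xcom=include_xcom)
        return value  # literal (parse-time) values pass through unchanged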
diff --git a/airflow/models/mappedoperator.py b/airflow/models/mappedoperator.py
index 45a6ae1ac1..2377fdab00 100644
--- a/airflow/models/mappedoperator.py
+++ b/airflow/models/mappedoperator.py
@@ -51,6 +51,7 @@ from airflow.models.pool import Pool
from airflow.serialization.enums import DagAttributeTypes
from airflow.task.priority_strategy import PriorityWeightStrategy, validate_and_load_priority_weight_strategy
from airflow.ti_deps.deps.mapped_task_expanded import MappedTaskIsExpanded
+from airflow.triggers.base import StartTriggerArgs
from airflow.typing_compat import Literal
from airflow.utils.context import context_update_for_unmapped
from airflow.utils.helpers import is_container, prevent_duplicates
@@ -81,7 +82,6 @@ if TYPE_CHECKING:
from airflow.models.param import ParamsDict
from airflow.models.xcom_arg import XComArg
from airflow.ti_deps.deps.base_ti_dep import BaseTIDep
- from airflow.triggers.base import StartTriggerArgs
from airflow.utils.context import Context
from airflow.utils.operator_resources import Resources
from airflow.utils.task_group import TaskGroup
@@ -688,14 +688,16 @@ class MappedOperator(AbstractOperator):
"""Implement DAGNode."""
return DagAttributeTypes.OP, self.task_id
- def _expand_mapped_kwargs(self, context: Context, session: Session) -> tuple[Mapping[str, Any], set[int]]:
+ def _expand_mapped_kwargs(
+ self, context: Context, session: Session, *, include_xcom: bool
+ ) -> tuple[Mapping[str, Any], set[int]]:
"""
Get the kwargs to create the unmapped operator.
This exists because taskflow operators expand against op_kwargs, not the
entire operator kwargs dict.
"""
- return self._get_specified_expand_input().resolve(context, session)
+ return self._get_specified_expand_input().resolve(context, session, include_xcom=include_xcom)
def _get_unmap_kwargs(self, mapped_kwargs: Mapping[str, Any], *, strict: bool) -> dict[str, Any]:
"""
@@ -729,6 +731,69 @@ class MappedOperator(AbstractOperator):
"params": params,
}
+ def expand_start_from_trigger(self, *, context: Context, session: Session) -> bool:
+ """
+ Get the start_from_trigger value of the current abstract operator.
+
+ MappedOperator uses this to unmap start_from_trigger to decide whether to start the task
+ execution directly from the triggerer.
+
+ :meta private:
+ """
+ # start_from_trigger only makes sense when start_trigger_args exists.
+ if not self.start_trigger_args:
+ return False
+
+ mapped_kwargs, _ = self._expand_mapped_kwargs(context, session, include_xcom=False)
+ if self._disallow_kwargs_override:
+ prevent_duplicates(
+ self.partial_kwargs,
+ mapped_kwargs,
+ fail_reason="unmappable or already specified",
+ )
+
+ # Ordering is significant; mapped kwargs should override partial ones.
+ return mapped_kwargs.get(
+ "start_from_trigger", self.partial_kwargs.get("start_from_trigger", self.start_from_trigger)
+ )
+
+ def expand_start_trigger_args(self, *, context: Context, session: Session) -> StartTriggerArgs | None:
+ """
+ Get the kwargs to create the unmapped start_trigger_args.
+
+ This method allows a mapped operator to start execution from the triggerer.
+ """
+ if not self.start_trigger_args:
+ return None
+
+ mapped_kwargs, _ = self._expand_mapped_kwargs(context, session, include_xcom=False)
+ if self._disallow_kwargs_override:
+ prevent_duplicates(
+ self.partial_kwargs,
+ mapped_kwargs,
+ fail_reason="unmappable or already specified",
+ )
+
+ # Ordering is significant; mapped kwargs should override partial ones.
+ trigger_kwargs = mapped_kwargs.get(
+ "trigger_kwargs",
+ self.partial_kwargs.get("trigger_kwargs", self.start_trigger_args.trigger_kwargs),
+ )
+ next_kwargs = mapped_kwargs.get(
+ "next_kwargs",
+ self.partial_kwargs.get("next_kwargs", self.start_trigger_args.next_kwargs),
+ )
+ timeout = mapped_kwargs.get(
+ "trigger_timeout", self.partial_kwargs.get("trigger_timeout", self.start_trigger_args.timeout)
+ )
+ return StartTriggerArgs(
+ trigger_cls=self.start_trigger_args.trigger_cls,
+ trigger_kwargs=trigger_kwargs,
+ next_method=self.start_trigger_args.next_method,
+ next_kwargs=next_kwargs,
+ timeout=timeout,
+ )
+
def unmap(self, resolve: None | Mapping[str, Any] | tuple[Context, Session]) -> BaseOperator:
"""
Get the "normal" Operator after applying the current mapping.
@@ -749,7 +814,7 @@ class MappedOperator(AbstractOperator):
if isinstance(resolve, collections.abc.Mapping):
kwargs = resolve
elif resolve is not None:
- kwargs, _ = self._expand_mapped_kwargs(*resolve)
+ kwargs, _ = self._expand_mapped_kwargs(*resolve, include_xcom=True)
else:
raise RuntimeError("cannot unmap a non-serialized operator without context")
kwargs = self._get_unmap_kwargs(kwargs, strict=self._disallow_kwargs_override)
@@ -844,7 +909,7 @@ class MappedOperator(AbstractOperator):
# set_current_task_session context manager to store the session in the current task.
session = get_current_task_instance_session()
- mapped_kwargs, seen_oids = self._expand_mapped_kwargs(context, session)
+ mapped_kwargs, seen_oids = self._expand_mapped_kwargs(context, session, include_xcom=True)
unmapped_task = self.unmap(mapped_kwargs)
context_update_for_unmapped(context, unmapped_task)
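
Both new ``expand_*`` methods resolve each argument with the same three-level precedence: kwargs from ``expand()`` override kwargs from ``partial()``, which override the class-level default. A sketch of that lookup (the helper name is illustrative, not from the commit):

    def resolve_arg(name, mapped_kwargs, partial_kwargs, class_default):
        # expand() kwargs win over partial() kwargs, which win over the
        # value baked into the class-level StartTriggerArgs.
        return mapped_kwargs.get(name, partial_kwargs.get(name, class_default))

For example, ``timeout`` above is resolved as ``resolve_arg("trigger_timeout", mapped_kwargs, self.partial_kwargs, self.start_trigger_args.timeout)``.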
diff --git a/airflow/models/param.py b/airflow/models/param.py
index 3f56e50cd1..a5b9504a06 100644
--- a/airflow/models/param.py
+++ b/airflow/models/param.py
@@ -329,7 +329,7 @@ class DagParam(ResolveMixin):
def iter_references(self) -> Iterable[tuple[Operator, str]]:
return ()
- def resolve(self, context: Context) -> Any:
+ def resolve(self, context: Context, *, include_xcom: bool) -> Any:
"""Pull DagParam value from DagRun context. This method is run during ``op.execute()``."""
with contextlib.suppress(KeyError):
return context["dag_run"].conf[self._name]
diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py
index 393da63c74..086e89b0eb 100644
--- a/airflow/models/taskinstance.py
+++ b/airflow/models/taskinstance.py
@@ -83,6 +83,7 @@ from airflow.exceptions import (
AirflowTaskTimeout,
DagRunNotFound,
RemovedInAirflow3Warning,
+ TaskDeferralError,
TaskDeferred,
UnmappableXComLengthPushed,
UnmappableXComTypePushed,
@@ -1617,15 +1618,23 @@ def _defer_task(
next_kwargs = exception.kwargs
timeout = exception.timeout
elif ti.task is not None and ti.task.start_trigger_args is not None:
+ context = ti.get_template_context()
+ start_trigger_args = ti.task.expand_start_trigger_args(context=context, session=session)
+ if start_trigger_args is None:
+ raise TaskDeferralError(
+ "A non-'None' start_trigger_args was changed to 'None' during expansion"
+ )
+
+ trigger_kwargs = start_trigger_args.trigger_kwargs or {}
+ next_kwargs = start_trigger_args.next_kwargs
+ next_method = start_trigger_args.next_method
+ timeout = start_trigger_args.timeout
trigger_row = Trigger(
classpath=ti.task.start_trigger_args.trigger_cls,
- kwargs=ti.task.start_trigger_args.trigger_kwargs or {},
+ kwargs=trigger_kwargs,
)
- next_kwargs = ti.task.start_trigger_args.next_kwargs
- next_method = ti.task.start_trigger_args.next_method
- timeout = ti.task.start_trigger_args.timeout
else:
- raise AirflowException("exception and ti.task.start_trigger_args cannot both be None")
+ raise TaskDeferralError("exception and ti.task.start_trigger_args cannot both be None")
# First, make the trigger entry
session.add(trigger_row)
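
With this change, ``_defer_task`` expands the possibly mapped ``start_trigger_args`` before building the ``Trigger`` row, rather than reading the class-level values directly. An excerpt-style sketch of the resulting flow, with ``ti``, ``Trigger``, and ``TaskDeferralError`` as in the diff above:

    args = ti.task.expand_start_trigger_args(context=context, session=session)
    if args is None:
        raise TaskDeferralError("start_trigger_args became None during expansion")
    trigger_row = Trigger(
        classpath=ti.task.start_trigger_args.trigger_cls,  # the class path itself is never mapped
        kwargs=args.trigger_kwargs or {},
    )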
diff --git a/airflow/models/xcom_arg.py b/airflow/models/xcom_arg.py
index bf7d8323e8..108634bb7e 100644
--- a/airflow/models/xcom_arg.py
+++ b/airflow/models/xcom_arg.py
@@ -208,7 +208,7 @@ class XComArg(ResolveMixin, DependencyMixin):
raise NotImplementedError()
@provide_session
- def resolve(self, context: Context, session: Session = NEW_SESSION) -> Any:
+ def resolve(self, context: Context, session: Session = NEW_SESSION, *, include_xcom: bool) -> Any:
"""
Pull XCom value.
@@ -437,7 +437,7 @@ class PlainXComArg(XComArg):
)
@provide_session
- def resolve(self, context: Context, session: Session = NEW_SESSION) -> Any:
+ def resolve(self, context: Context, session: Session = NEW_SESSION, *, include_xcom: bool) -> Any:
ti = context["ti"]
if TYPE_CHECKING:
assert isinstance(ti, TaskInstance)
@@ -551,8 +551,8 @@ class MapXComArg(XComArg):
return self.arg.get_task_map_length(run_id, session=session)
@provide_session
- def resolve(self, context: Context, session: Session = NEW_SESSION) -> Any:
- value = self.arg.resolve(context, session=session)
+ def resolve(self, context: Context, session: Session = NEW_SESSION, *, include_xcom: bool) -> Any:
+ value = self.arg.resolve(context, session=session, include_xcom=include_xcom)
if not isinstance(value, (Sequence, dict)):
raise ValueError(f"XCom map expects sequence or dict, not {type(value).__name__}")
return _MapResult(value, self.callables)
@@ -632,8 +632,8 @@ class ZipXComArg(XComArg):
return max(ready_lengths)
@provide_session
- def resolve(self, context: Context, session: Session = NEW_SESSION) -> Any:
- values = [arg.resolve(context, session=session) for arg in self.args]
+ def resolve(self, context: Context, session: Session = NEW_SESSION, *, include_xcom: bool) -> Any:
+ values = [arg.resolve(context, session=session, include_xcom=include_xcom) for arg in self.args]
for value in values:
if not isinstance(value, (Sequence, dict)):
raise ValueError(f"XCom zip expects sequence or dict, not {type(value).__name__}")
@@ -707,8 +707,8 @@ class ConcatXComArg(XComArg):
return sum(ready_lengths)
@provide_session
- def resolve(self, context: Context, session: Session = NEW_SESSION) -> Any:
- values = [arg.resolve(context, session=session) for arg in self.args]
+ def resolve(self, context: Context, session: Session = NEW_SESSION, *, include_xcom: bool) -> Any:
+ values = [arg.resolve(context, session=session, include_xcom=include_xcom) for arg in self.args]
for value in values:
if not isinstance(value, (Sequence, dict)):
raise ValueError(f"XCom concat expects sequence or dict, not {type(value).__name__}")
diff --git a/airflow/template/templater.py b/airflow/template/templater.py
index bfbb6f5722..3dc10d36ff 100644
--- a/airflow/template/templater.py
+++ b/airflow/template/templater.py
@@ -46,7 +46,7 @@ class LiteralValue(ResolveMixin):
def iter_references(self) -> Iterable[tuple[Operator, str]]:
return ()
- def resolve(self, context: Context) -> Any:
+ def resolve(self, context: Context, *, include_xcom: bool) -> Any:
return self.value
@@ -172,7 +172,7 @@ class Templater(LoggingMixin):
if isinstance(value, ObjectStoragePath):
return self._render_object_storage_path(value, context, jinja_env)
if isinstance(value, ResolveMixin):
- return value.resolve(context)
+ return value.resolve(context, include_xcom=True)
# Fast path for common built-in collections.
if value.__class__ is tuple:
diff --git a/airflow/utils/mixins.py b/airflow/utils/mixins.py
index eb0a1e81e1..324299b3a6 100644
--- a/airflow/utils/mixins.py
+++ b/airflow/utils/mixins.py
@@ -64,7 +64,7 @@ class ResolveMixin:
"""
raise NotImplementedError
- def resolve(self, context: Context) -> typing.Any:
+ def resolve(self, context: Context, *, include_xcom: bool) -> typing.Any:
"""
Resolve this value for runtime.
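
The ``ResolveMixin`` change above widens the ``resolve`` contract: every implementer now accepts ``include_xcom``, even when it has no XCom to pull, as ``LiteralValue`` and ``DagParam`` show. A hypothetical implementer, for illustration only:

    import typing


    class ConstantValue:
        # A value known at parse time; it accepts include_xcom per the new
        # contract and ignores it, since no XCom access is needed.
        def __init__(self, value: typing.Any) -> None:
            self.value = value

        def resolve(self, context, *, include_xcom: bool) -> typing.Any:
            return self.value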
diff --git a/docs/apache-airflow/authoring-and-scheduling/deferring.rst b/docs/apache-airflow/authoring-and-scheduling/deferring.rst
index a9b26703bc..d30a665c83 100644
--- a/docs/apache-airflow/authoring-and-scheduling/deferring.rst
+++ b/docs/apache-airflow/authoring-and-scheduling/deferring.rst
@@ -182,9 +182,6 @@ This is particularly useful when deferring is the only thing the ``execute`` met
``start_from_trigger`` and ``trigger_kwargs`` can also be modified at the instance level for more flexible configuration.
-.. warning::
- Dynamic task mapping is not supported when ``trigger_kwargs`` is modified at instance level.
-
.. code-block:: python
from datetime import timedelta
@@ -213,6 +210,49 @@ This is particularly useful when deferring is the only thing the ``execute`` met
# We have no more work to do here. Mark as complete.
return
+To enable Dynamic Task Mapping support, define ``start_from_trigger`` and ``trigger_kwargs`` as parameters of
+``__init__``. You don't need to define both of them to use this feature, but you do need to use these exact
+parameter names. For example, if you define an argument as ``t_kwargs`` and assign that value to
+``self.start_trigger_args.trigger_kwargs``, it will not work. Note that this works differently from mapping an
+operator without ``start_from_trigger`` support [...]
+
+.. code-block:: python
+
+ from datetime import timedelta
+ from typing import Any
+
+ from airflow.sensors.base import BaseSensorOperator
+ from airflow.triggers.base import StartTriggerArgs
+ from airflow.triggers.temporal import TimeDeltaTrigger
+ from airflow.utils.context import Context
+
+
+ class WaitTwoHourSensor(BaseSensorOperator):
+ start_trigger_args = StartTriggerArgs(
+ trigger_cls="airflow.triggers.temporal.TimeDeltaTrigger",
+ trigger_kwargs={},
+ next_method="execute_complete",
+ timeout=None,
+ )
+
+ def __init__(
+ self,
+ *args: list[Any],
+ trigger_kwargs: dict[str, Any] | None,
+ start_from_trigger: bool,
+ **kwargs: dict[str, Any],
+ ) -> None:
+ super().__init__(*args, **kwargs)
+ self.start_trigger_args.trigger_kwargs = trigger_kwargs
+ self.start_from_trigger = start_from_trigger
+
+ def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> None:
+ # We have no more work to do here. Mark as complete.
+ return
+
+These parameters can be mapped using the ``expand`` and ``partial`` methods. Note that XCom values won't be resolved at this stage.
+
+.. code-block:: python
+
+ WaitTwoHourSensor.partial(task_id="transform", start_from_trigger=True).expand(
+ trigger_kwargs=[{"moment": timedelta(hours=2)}, {"moment": timedelta(hours=2)}]
+ )
+
Writing Triggers
~~~~~~~~~~~~~~~~
diff --git a/tests/models/test_dagrun.py b/tests/models/test_dagrun.py
index 021d09e889..7a791aa365 100644
--- a/tests/models/test_dagrun.py
+++ b/tests/models/test_dagrun.py
@@ -2043,6 +2043,38 @@ def test_schedule_tis_empty_operator_try_number(dag_maker, session: Session):
assert empty_ti.try_number == 1
+def test_schedule_tis_start_trigger_through_expand(dag_maker, session):
+ """
+ Test that an operator with start_trigger_args set can be directly deferred during scheduling.
+ """
+
+ class TestOperator(BaseOperator):
+ start_trigger_args = StartTriggerArgs(
+ trigger_cls="airflow.triggers.testing.SuccessTrigger",
+ trigger_kwargs={},
+ next_method="execute_complete",
+ timeout=None,
+ )
+ start_from_trigger = False
+
+ def __init__(self, *args, start_from_trigger: bool = False, **kwargs):
+ super().__init__(*args, **kwargs)
+ self.start_from_trigger = start_from_trigger
+
+ def execute_complete(self):
+ pass
+
+ with dag_maker(session=session):
+ TestOperator.partial(task_id="test_task").expand(start_from_trigger=[True, False])
+
+ dr: DagRun = dag_maker.create_dagrun()
+
+ dr.schedule_tis(dr.task_instances, session=session)
+ tis = [(ti.state, ti.map_index) for ti in dr.task_instances]
+ assert tis[0] == (TaskInstanceState.DEFERRED, 0)
+ assert tis[1] == (None, 1)
+
+
def test_mapped_expand_kwargs(dag_maker):
with dag_maker():