This is an automated email from the ASF dual-hosted git repository.
weilee pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new d509abfa21 Introduce StartTriggerArgs and prevent start trigger initialization in scheduler (#39585)
d509abfa21 is described below
commit d509abfa217565d2d5249c639ea7459c44292368
Author: Wei Lee <[email protected]>
AuthorDate: Tue Jun 11 20:10:01 2024 +0900
Introduce StartTriggerArgs and prevent start trigger initialization in scheduler (#39585)
* fix(baseoperator): change `start_trigger` into `start_trigger_args`
* feat(baseoperator): add `start_from_trigger` as the flag to decide whether to start task execution from triggerer
* fix(dagrun): set start_date before deferring task from scheduler
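For context, a minimal sketch of the pattern this change enables, condensed from the updated deferring.rst in the diff below (the trigger classpath and kwargs are the ones the docs example uses):

    from datetime import timedelta
    from typing import Any

    from airflow.sensors.base import BaseSensorOperator
    from airflow.triggers.base import StartTriggerArgs
    from airflow.utils.context import Context


    class WaitOneHourSensor(BaseSensorOperator):
        # The trigger is described only by classpath and kwargs, so the
        # scheduler never instantiates it; the triggerer does.
        start_trigger_args = StartTriggerArgs(
            trigger_cls="airflow.triggers.temporal.TimeDeltaTrigger",
            trigger_kwargs={"moment": timedelta(hours=1)},
            next_method="execute_complete",
            timeout=None,
        )
        start_from_trigger = True

        def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> None:
            # Resumed by the triggerer; nothing left to do here.
            return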
---
airflow/decorators/base.py | 4 +-
airflow/models/abstractoperator.py | 5 +-
airflow/models/baseoperator.py | 10 +--
airflow/models/dagrun.py | 16 +---
airflow/models/mappedoperator.py | 17 ++--
airflow/models/taskinstance.py | 39 ++++++---
airflow/serialization/serialized_objects.py | 31 +++----
airflow/triggers/base.py | 20 +++++
.../authoring-and-scheduling/deferring.rst | 36 +++++---
tests/models/test_dagrun.py | 17 ++--
tests/serialization/test_dag_serialization.py | 99 +++++++++++-----------
tests/serialization/test_pydantic_models.py | 4 +-
12 files changed, 168 insertions(+), 130 deletions(-)
diff --git a/airflow/decorators/base.py b/airflow/decorators/base.py
index 2ae85a9c43..74b44ffe23 100644
--- a/airflow/decorators/base.py
+++ b/airflow/decorators/base.py
@@ -509,8 +509,8 @@ class _TaskDecorator(ExpandableFactory, Generic[FParams, FReturn, OperatorSubcla
# task's expand() contribute to the op_kwargs operator argument, not
# the operator arguments themselves, and should expand against it.
expand_input_attr="op_kwargs_expand_input",
- start_trigger=self.operator_class.start_trigger,
- next_method=self.operator_class.next_method,
+ start_trigger_args=self.operator_class.start_trigger_args,
+ start_from_trigger=self.operator_class.start_from_trigger,
)
return XComArg(operator=operator)
diff --git a/airflow/models/abstractoperator.py b/airflow/models/abstractoperator.py
index b7160430e0..1bb83a2dc0 100644
--- a/airflow/models/abstractoperator.py
+++ b/airflow/models/abstractoperator.py
@@ -122,8 +122,9 @@ class AbstractOperator(Templater, DAGNode):
"node_id", # Duplicates task_id
"task_group", # Doesn't have a useful repr, no point showing in UI
"inherits_from_empty_operator", # impl detail
- "start_trigger",
- "next_method",
+ # Decide whether to start task execution from triggerer
+ "start_trigger_args",
+ "start_from_trigger",
# For compatibility with TG, for operators these are just the current task, no point showing
"roots",
"leaves",
diff --git a/airflow/models/baseoperator.py b/airflow/models/baseoperator.py
index 98532d90b0..bbd629cfc1 100644
--- a/airflow/models/baseoperator.py
+++ b/airflow/models/baseoperator.py
@@ -116,7 +116,7 @@ if TYPE_CHECKING:
from airflow.models.operator import Operator
from airflow.models.xcom_arg import XComArg
from airflow.ti_deps.deps.base_ti_dep import BaseTIDep
- from airflow.triggers.base import BaseTrigger
+ from airflow.triggers.base import BaseTrigger, StartTriggerArgs
from airflow.utils.task_group import TaskGroup
from airflow.utils.types import ArgNotSet
@@ -819,8 +819,8 @@ class BaseOperator(AbstractOperator, metaclass=BaseOperatorMeta):
# Set to True for an operator instantiated by a mapped operator.
__from_mapped = False
- start_trigger: BaseTrigger | None = None
- next_method: str | None = None
+ start_trigger_args: StartTriggerArgs | None = None
+ start_from_trigger: bool = False
def __init__(
self,
@@ -1679,9 +1679,9 @@ class BaseOperator(AbstractOperator, metaclass=BaseOperatorMeta):
"is_teardown",
"on_failure_fail_dagrun",
"map_index_template",
- "start_trigger",
- "next_method",
+ "start_trigger_args",
"_needs_expansion",
+ "start_from_trigger",
}
)
DagContext.pop_context_managed_dag()
diff --git a/airflow/models/dagrun.py b/airflow/models/dagrun.py
index 008608b96d..4773a89d1d 100644
--- a/airflow/models/dagrun.py
+++ b/airflow/models/dagrun.py
@@ -51,7 +51,7 @@ from airflow import settings
from airflow.api_internal.internal_api_call import internal_api_call
from airflow.callbacks.callback_requests import DagCallbackRequest
from airflow.configuration import conf as airflow_conf
-from airflow.exceptions import AirflowException, RemovedInAirflow3Warning, TaskDeferred, TaskNotFound
+from airflow.exceptions import AirflowException, RemovedInAirflow3Warning, TaskNotFound
from airflow.listeners.listener import get_listener_manager
from airflow.models import Log
from airflow.models.abstractoperator import NotMapped
@@ -1538,19 +1538,11 @@ class DagRun(Base, LoggingMixin):
and not ti.task.outlets
):
dummy_ti_ids.append((ti.task_id, ti.map_index))
- elif (
- ti.task.start_trigger is not None
- and ti.task.next_method is not None
- and not ti.task.on_execute_callback
- and not ti.task.on_success_callback
- and not ti.task.outlets
- ):
+ elif ti.task.start_from_trigger is True and ti.task.start_trigger_args is not None:
+ ti.start_date = timezone.utcnow()
if ti.state != TaskInstanceState.UP_FOR_RESCHEDULE:
ti.try_number += 1
- ti.defer_task(
- exception=TaskDeferred(trigger=ti.task.start_trigger, method_name=ti.task.next_method),
- session=session,
- )
+ ti.defer_task(exception=None, session=session)
else:
schedulable_ti_ids.append((ti.task_id, ti.map_index))
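Taken together, the new scheduler-side branch reads as the sketch below (a simplified view of this DagRun.schedule_tis hunk, not the verbatim code; ti, session, and schedulable_ti_ids are names from the enclosing method):

    from airflow.utils import timezone
    from airflow.utils.state import TaskInstanceState

    if ti.task.start_from_trigger is True and ti.task.start_trigger_args is not None:
        ti.start_date = timezone.utcnow()  # stamped before deferring (the dagrun fix)
        if ti.state != TaskInstanceState.UP_FOR_RESCHEDULE:
            ti.try_number += 1
        # exception=None makes defer_task fall back to ti.task.start_trigger_args.
        ti.defer_task(exception=None, session=session)
    else:
        schedulable_ti_ids.append((ti.task_id, ti.map_index))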
diff --git a/airflow/models/mappedoperator.py b/airflow/models/mappedoperator.py
index 27d0510c30..abbed3cfa9 100644
--- a/airflow/models/mappedoperator.py
+++ b/airflow/models/mappedoperator.py
@@ -81,7 +81,7 @@ if TYPE_CHECKING:
from airflow.models.param import ParamsDict
from airflow.models.xcom_arg import XComArg
from airflow.ti_deps.deps.base_ti_dep import BaseTIDep
- from airflow.triggers.base import BaseTrigger
+ from airflow.triggers.base import StartTriggerArgs
from airflow.utils.context import Context
from airflow.utils.operator_resources import Resources
from airflow.utils.task_group import TaskGroup
@@ -237,8 +237,8 @@ class OperatorPartial:
# For classic operators, this points to expand_input because kwargs
# to BaseOperator.expand() contribute to operator arguments.
expand_input_attr="expand_input",
- start_trigger=self.operator_class.start_trigger,
- next_method=self.operator_class.next_method,
+ start_trigger_args=self.operator_class.start_trigger_args,
+ start_from_trigger=self.operator_class.start_from_trigger,
)
return op
@@ -281,8 +281,8 @@ class MappedOperator(AbstractOperator):
_task_module: str
_task_type: str
_operator_name: str
- start_trigger: BaseTrigger | None
- next_method: str | None
+ start_trigger_args: StartTriggerArgs | None
+ start_from_trigger: bool
_needs_expansion: bool = True
dag: DAG | None
@@ -309,12 +309,7 @@ class MappedOperator(AbstractOperator):
supports_lineage: bool = False
HIDE_ATTRS_FROM_UI: ClassVar[frozenset[str]] = AbstractOperator.HIDE_ATTRS_FROM_UI | frozenset(
- (
- "parse_time_mapped_ti_count",
- "operator_class",
- "start_trigger",
- "next_method",
- )
+ ("parse_time_mapped_ti_count", "operator_class", "start_trigger_args",
"start_from_trigger")
)
def __hash__(self):
diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py
index c1ace17cd5..373ad108c2 100644
--- a/airflow/models/taskinstance.py
+++ b/airflow/models/taskinstance.py
@@ -1575,12 +1575,29 @@ def _coalesce_to_orm_ti(*, ti: TaskInstancePydantic | TaskInstance, session: Ses
@internal_api_call
@provide_session
def _defer_task(
- ti: TaskInstance | TaskInstancePydantic, exception: TaskDeferred, session: Session = NEW_SESSION
+ ti: TaskInstance | TaskInstancePydantic,
+ exception: TaskDeferred | None = None,
+ session: Session = NEW_SESSION,
) -> TaskInstancePydantic | TaskInstance:
from airflow.models.trigger import Trigger
+ if exception is not None:
+ trigger_row = Trigger.from_object(exception.trigger)
+ trigger_kwargs = exception.kwargs
+ next_method = exception.method_name
+ timeout = exception.timeout
+ elif ti.task is not None and ti.task.start_trigger_args is not None:
+ trigger_row = Trigger(
+ classpath=ti.task.start_trigger_args.trigger_cls,
+ kwargs=ti.task.start_trigger_args.trigger_kwargs or {},
+ )
+ trigger_kwargs = ti.task.start_trigger_args.trigger_kwargs
+ next_method = ti.task.start_trigger_args.next_method
+ timeout = ti.task.start_trigger_args.timeout
+ else:
+ raise AirflowException("exception and ti.task.start_trigger_args cannot both be None")
+
# First, make the trigger entry
- trigger_row = Trigger.from_object(exception.trigger)
session.add(trigger_row)
session.flush()
@@ -1594,12 +1611,12 @@ def _defer_task(
# depending on self.next_method semantics
ti.state = TaskInstanceState.DEFERRED
ti.trigger_id = trigger_row.id
- ti.next_method = exception.method_name
- ti.next_kwargs = exception.kwargs or {}
+ ti.next_method = next_method
+ ti.next_kwargs = trigger_kwargs or {}
# Calculate timeout too if it was passed
- if exception.timeout is not None:
- ti.trigger_timeout = timezone.utcnow() + exception.timeout
+ if timeout is not None:
+ ti.trigger_timeout = timezone.utcnow() + timeout
else:
ti.trigger_timeout = None
@@ -1615,8 +1632,10 @@ def _defer_task(
ti.trigger_timeout = ti.start_date + execution_timeout
if ti.test_mode:
_add_log(event=ti.state, task_instance=ti, session=session)
- session.merge(ti)
- session.commit()
+
+ if exception is not None:
+ session.merge(ti)
+ session.commit()
return ti
@@ -3000,8 +3019,8 @@ class TaskInstance(Base, LoggingMixin):
return _execute_task(self, context, task_orig)
@provide_session
- def defer_task(self, exception: TaskDeferred, session: Session) -> None:
- """Mark the task as deferred and sets up the trigger that is needed to
resume it.
+ def defer_task(self, exception: TaskDeferred | None, session: Session =
NEW_SESSION) -> None:
+ """Mark the task as deferred and sets up the trigger that is needed to
resume it when TaskDeferred is raised.
:meta: private
"""
diff --git a/airflow/serialization/serialized_objects.py b/airflow/serialization/serialized_objects.py
index 6e7f50a87c..eb9c15f43c 100644
--- a/airflow/serialization/serialized_objects.py
+++ b/airflow/serialization/serialized_objects.py
@@ -66,11 +66,10 @@ from airflow.task.priority_strategy import (
airflow_priority_weight_strategies,
airflow_priority_weight_strategies_classes,
)
-from airflow.triggers.base import BaseTrigger
+from airflow.triggers.base import BaseTrigger, StartTriggerArgs
from airflow.utils.code_utils import get_python_source
from airflow.utils.context import Context, OutletEventAccessor, OutletEventAccessors
from airflow.utils.docs import get_docs_url
-from airflow.utils.helpers import exactly_one
from airflow.utils.module_loading import import_string, qualname
from airflow.utils.operator_resources import Resources
from airflow.utils.task_group import MappedTaskGroup, TaskGroup
@@ -1018,11 +1017,10 @@ class SerializedBaseOperator(BaseOperator, BaseSerialization):
# Used to determine if an Operator is inherited from EmptyOperator
serialize_op["_is_empty"] = op.inherits_from_empty_operator
- if exactly_one(op.start_trigger is not None, op.next_method is not None):
- raise AirflowException("start_trigger and next_method should both be set.")
-
- serialize_op["start_trigger"] = op.start_trigger.serialize() if op.start_trigger else None
- serialize_op["next_method"] = op.next_method
+ serialize_op["start_trigger_args"] = (
+ op.start_trigger_args.serialize() if op.start_trigger_args else None
+ )
+ serialize_op["start_from_trigger"] = op.start_from_trigger
if op.operator_extra_links:
serialize_op["_operator_extra_links"] =
cls._serialize_operator_extra_links(
@@ -1206,16 +1204,11 @@ class SerializedBaseOperator(BaseOperator, BaseSerialization):
# Used to determine if an Operator is inherited from EmptyOperator
setattr(op, "_is_empty", bool(encoded_op.get("_is_empty", False)))
- # Deserialize start_trigger
- serialized_start_trigger = encoded_op.get("start_trigger")
- if serialized_start_trigger:
- trigger_cls_name, trigger_kwargs = serialized_start_trigger
- trigger_cls = import_string(trigger_cls_name)
- start_trigger = trigger_cls(**trigger_kwargs)
- setattr(op, "start_trigger", start_trigger)
- else:
- setattr(op, "start_trigger", None)
- setattr(op, "next_method", encoded_op.get("next_method", None))
+ start_trigger_args = None
+ if encoded_op.get("start_trigger_args", None):
+ start_trigger_args = StartTriggerArgs(**encoded_op.get("start_trigger_args", None))
+ setattr(op, "start_trigger_args", start_trigger_args)
+ setattr(op, "start_from_trigger",
bool(encoded_op.get("start_from_trigger", False)))
@staticmethod
def set_task_dag_references(task: Operator, dag: DAG) -> None:
@@ -1278,8 +1271,8 @@ class SerializedBaseOperator(BaseOperator, BaseSerialization):
end_date=None,
disallow_kwargs_override=encoded_op["_disallow_kwargs_override"],
expand_input_attr=encoded_op["_expand_input_attr"],
- start_trigger=None,
- next_method=None,
+ start_trigger_args=encoded_op.get("start_trigger_args", None),
+ start_from_trigger=encoded_op.get("start_from_trigger", False),
)
else:
op = SerializedBaseOperator(task_id=encoded_op["task_id"])
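Deserialization mirrors serialization by splatting the stored dict back into the dataclass. A hedged round-trip sketch (the encoded dict matches the ground truth asserted in the tests below):

    from airflow.triggers.base import StartTriggerArgs

    encoded = {
        "trigger_cls": "airflow.triggers.testing.SuccessTrigger",
        "trigger_kwargs": {},
        "next_method": "execute_complete",
        "timeout": None,
    }
    args = StartTriggerArgs(**encoded)  # what the deserializer effectively does
    assert args.serialize() == encoded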
diff --git a/airflow/triggers/base.py b/airflow/triggers/base.py
index 0d239af0ca..ac5727a1ba 100644
--- a/airflow/triggers/base.py
+++ b/airflow/triggers/base.py
@@ -17,11 +17,31 @@
from __future__ import annotations
import abc
+from dataclasses import dataclass
+from datetime import timedelta
from typing import Any, AsyncIterator
from airflow.utils.log.logging_mixin import LoggingMixin
+@dataclass
+class StartTriggerArgs:
+ """Arguments required for start task execution from triggerer."""
+
+ trigger_cls: str
+ next_method: str
+ trigger_kwargs: dict[str, Any] | None = None
+ timeout: timedelta | None = None
+
+ def serialize(self):
+ return {
+ "trigger_cls": self.trigger_cls,
+ "trigger_kwargs": self.trigger_kwargs,
+ "next_method": self.next_method,
+ "timeout": self.timeout,
+ }
+
+
class BaseTrigger(abc.ABC, LoggingMixin):
"""
Base class for all triggers.
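Because StartTriggerArgs carries a classpath rather than an instance, mapping it onto a Trigger ORM row never imports the trigger class. A hedged sketch of that mapping, mirroring the scheduler path in _defer_task above:

    from airflow.models.trigger import Trigger
    from airflow.triggers.base import StartTriggerArgs

    args = StartTriggerArgs(
        trigger_cls="airflow.triggers.testing.SuccessTrigger",
        trigger_kwargs=None,
        next_method="execute_complete",
        timeout=None,
    )
    # classpath + kwargs only; the triggerer instantiates the class later.
    trigger_row = Trigger(classpath=args.trigger_cls, kwargs=args.trigger_kwargs or {})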
diff --git a/docs/apache-airflow/authoring-and-scheduling/deferring.rst b/docs/apache-airflow/authoring-and-scheduling/deferring.rst
index 084a08f0ac..a65e932a90 100644
--- a/docs/apache-airflow/authoring-and-scheduling/deferring.rst
+++ b/docs/apache-airflow/authoring-and-scheduling/deferring.rst
@@ -143,10 +143,14 @@ The ``self.defer`` call raises the ``TaskDeferred`` exception, so it can work an
Triggering Deferral from Start
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
-If you want to defer your task directly to the triggerer without going into the worker, you can add the class level attributes ``start_trigger`` and ``next_method`` to your deferrable operator.
+ .. versionadded:: 2.10.0
-* ``start_trigger``: An instance of a trigger you want to defer to. It will be serialized into the database.
+If you want to defer your task directly to the triggerer without going into the worker, you can set the class-level attribute ``start_from_trigger`` to ``True`` and add a class-level attribute ``start_trigger_args`` with a ``StartTriggerArgs`` object with the following four attributes to your deferrable operator:
+
+* ``trigger_cls``: An importable path to your trigger class.
+* ``trigger_kwargs``: Additional keyword arguments to pass to the trigger when it is instantiated.
* ``next_method``: The method name on your operator that you want Airflow to call when it resumes.
+* ``timeout``: (Optional) A timedelta that specifies a timeout after which this deferral will fail, and fail the task instance. Defaults to ``None``, which means no timeout.
This is particularly useful when deferring is the only thing the ``execute`` method does. Here's a basic refinement of the previous example.
@@ -156,23 +160,28 @@ This is particularly useful when deferring is the only thing the ``execute`` met
from datetime import timedelta
from typing import Any
+ from airflow.triggers.base import StartTriggerArgs
from airflow.sensors.base import BaseSensorOperator
- from airflow.triggers.temporal import TimeDeltaTrigger
from airflow.utils.context import Context
class WaitOneHourSensor(BaseSensorOperator):
- start_trigger = TimeDeltaTrigger(timedelta(hours=1))
- next_method = "execute_complete"
+ start_trigger_args = StartTriggerArgs(
+ trigger_cls="airflow.triggers.temporal.TimeDeltaTrigger",
+ trigger_kwargs={"moment": timedelta(hours=1)},
+ next_method="execute_complete",
+ timeout=None,
+ )
+ start_from_trigger = True
def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> None:
# We have no more work to do here. Mark as complete.
return
-``start_trigger`` and ``next_method`` can also be set at the instance level for more flexible configuration.
+``start_from_trigger`` and ``trigger_kwargs`` can also be modified at the instance level for more flexible configuration.
.. warning::
- Dynamic task mapping is not supported when ``start_trigger`` and ``next_method`` are assigned in instance level.
+ Dynamic task mapping is not supported when ``trigger_kwargs`` is modified at the instance level.
.. code-block:: python
@@ -184,11 +193,18 @@ This is particularly useful when deferring is the only thing the ``execute`` met
from airflow.utils.context import Context
- class WaitOneHourSensor(BaseSensorOperator):
+ class WaitTwoHourSensor(BaseSensorOperator):
+ start_trigger_args = StartTriggerArgs(
+ trigger_cls="airflow.triggers.temporal.TimeDeltaTrigger",
+ trigger_kwargs={},
+ next_method="execute_complete",
+ timeout=None,
+ )
+
def __init__(self, *args: list[Any], **kwargs: dict[str, Any]) -> None:
super().__init__(*args, **kwargs)
- self.start_trigger = TimeDeltaTrigger(timedelta(hours=1))
- self.next_method = "execute_complete"
+ self.start_trigger_args.trigger_kwargs = {"moment": timedelta(hours=1)}
+ self.start_from_trigger = True
def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> None:
# We have no more work to do here. Mark as complete.
diff --git a/tests/models/test_dagrun.py b/tests/models/test_dagrun.py
index fe1c2d58a2..93e0611243 100644
--- a/tests/models/test_dagrun.py
+++ b/tests/models/test_dagrun.py
@@ -40,7 +40,7 @@ from airflow.operators.empty import EmptyOperator
from airflow.operators.python import ShortCircuitOperator
from airflow.serialization.serialized_objects import SerializedDAG
from airflow.stats import Stats
-from airflow.triggers.testing import SuccessTrigger
+from airflow.triggers.base import StartTriggerArgs
from airflow.utils import timezone
from airflow.utils.state import DagRunState, State, TaskInstanceState
from airflow.utils.trigger_rule import TriggerRule
@@ -1989,16 +1989,21 @@ def test_schedule_tis_map_index(dag_maker, session):
def test_schedule_tis_start_trigger(dag_maker, session):
"""
- Test that an operator with _start_trigger and _next_method set can be directly
- deferred during scheduling.
+ Test that an operator with start_trigger_args set can be directly deferred during scheduling.
"""
- trigger = SuccessTrigger()
class TestOperator(BaseOperator):
+ start_trigger_args = StartTriggerArgs(
+ trigger_cls="airflow.triggers.testing.SuccessTrigger",
+ trigger_kwargs=None,
+ next_method="execute_complete",
+ timeout=None,
+ )
+ start_from_trigger = True
+
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
- self.start_trigger = trigger
- self.next_method = "execute_complete"
+ self.start_trigger_args.trigger_kwargs = {}
def execute_complete(self):
pass
diff --git a/tests/serialization/test_dag_serialization.py b/tests/serialization/test_dag_serialization.py
index 093b7fba76..16f6d3cb68 100644
--- a/tests/serialization/test_dag_serialization.py
+++ b/tests/serialization/test_dag_serialization.py
@@ -72,7 +72,7 @@ from airflow.serialization.serialized_objects import (
from airflow.task.priority_strategy import _DownstreamPriorityWeightStrategy
from airflow.ti_deps.deps.base_ti_dep import BaseTIDep
from airflow.timetables.simple import NullTimetable, OnceTimetable
-from airflow.triggers.testing import SuccessTrigger
+from airflow.triggers.base import StartTriggerArgs
from airflow.utils import timezone
from airflow.utils.operator_resources import Resources
from airflow.utils.task_group import TaskGroup
@@ -197,8 +197,8 @@ serialized_simple_dag_ground_truth = {
"_log_config_logger_name": "airflow.task.operators",
"_needs_expansion": False,
"weight_rule": "downstream",
- "next_method": None,
- "start_trigger": None,
+ "start_trigger_args": None,
+ "start_from_trigger": False,
},
},
{
@@ -227,8 +227,8 @@ serialized_simple_dag_ground_truth = {
"_log_config_logger_name": "airflow.task.operators",
"_needs_expansion": False,
"weight_rule": "downstream",
- "next_method": None,
- "start_trigger": None,
+ "start_trigger_args": None,
+ "start_from_trigger": False,
},
},
],
@@ -2167,25 +2167,38 @@ class TestStringifiedDAGs:
SerializedDAG.to_dict(dag)
@pytest.mark.db_test
- def test_start_trigger_and_next_method_in_serialized_dag(self):
+ def test_start_trigger_args_in_serialized_dag(self):
"""
- Test that when we provide start_trigger and next_method, the DAG can be correctly serialized.
+ Test that when we provide start_trigger_args, the DAG can be correctly serialized.
"""
- trigger = SuccessTrigger()
class TestOperator(BaseOperator):
+ start_trigger_args = StartTriggerArgs(
+ trigger_cls="airflow.triggers.testing.SuccessTrigger",
+ trigger_kwargs=None,
+ next_method="execute_complete",
+ timeout=None,
+ )
+ start_from_trigger = False
+
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
- self.start_trigger = trigger
- self.next_method = "execute_complete"
+ self.start_trigger_args.trigger_kwargs = {}
+ self.start_from_trigger = True
def execute_complete(self):
pass
class Test2Operator(BaseOperator):
+ start_trigger_args = StartTriggerArgs(
+ trigger_cls="airflow.triggers.testing.SuccessTrigger",
+ trigger_kwargs={},
+ next_method="execute_complete",
+ timeout=None,
+ )
+ start_from_trigger = True
+
def __init__(self, *args, **kwargs):
- self.start_trigger = trigger
- self.next_method = "execute_complete"
super().__init__(*args, **kwargs)
def execute_complete(self):
@@ -2200,29 +2213,13 @@ class TestStringifiedDAGs:
serialized_obj = SerializedDAG.to_dict(dag)
for task in serialized_obj["dag"]["tasks"]:
- assert task["__var"]["start_trigger"] == trigger.serialize()
- assert task["__var"]["next_method"] == "execute_complete"
-
- @pytest.mark.db_test
- def test_start_trigger_in_serialized_dag_but_no_next_method(self):
- """
- Test that when we provide start_trigger without next_method, an AriflowException should be raised.
- """
-
- trigger = SuccessTrigger()
-
- class TestOperator(BaseOperator):
- def __init__(self, *args, **kwargs):
- super().__init__(*args, **kwargs)
- self.start_trigger = trigger
-
- dag = DAG(dag_id="test_dag", start_date=datetime(2023, 11, 9))
-
- with dag:
- TestOperator(task_id="test_task")
-
- with pytest.raises(AirflowException, match="start_trigger and next_method should both be set."):
- SerializedDAG.to_dict(dag)
+ assert task["__var"]["start_trigger_args"] == {
+ "trigger_cls": "airflow.triggers.testing.SuccessTrigger",
+ "trigger_kwargs": {},
+ "next_method": "execute_complete",
+ "timeout": None,
+ }
+ assert task["__var"]["start_from_trigger"] is True
def test_kubernetes_optional():
@@ -2274,8 +2271,8 @@ def test_operator_expand_serde():
"_needs_expansion": True,
"_task_module": "airflow.operators.bash",
"_task_type": "BashOperator",
- "start_trigger": None,
- "next_method": None,
+ "start_trigger_args": None,
+ "start_from_trigger": False,
"downstream_task_ids": [],
"expand_input": {
"type": "dict-of-lists",
@@ -2308,8 +2305,8 @@ def test_operator_expand_serde():
assert op.operator_class == {
"_task_type": "BashOperator",
"_needs_expansion": True,
- "start_trigger": None,
- "next_method": None,
+ "start_trigger_args": None,
+ "start_from_trigger": False,
"downstream_task_ids": [],
"task_id": "a",
"template_ext": [".sh", ".bash"],
@@ -2355,8 +2352,8 @@ def test_operator_expand_xcomarg_serde():
"ui_fgcolor": "#000",
"_disallow_kwargs_override": False,
"_expand_input_attr": "expand_input",
- "next_method": None,
- "start_trigger": None,
+ "start_trigger_args": None,
+ "start_from_trigger": False,
}
op = BaseSerialization.deserialize(serialized)
@@ -2413,8 +2410,8 @@ def test_operator_expand_kwargs_literal_serde(strict):
"ui_fgcolor": "#000",
"_disallow_kwargs_override": strict,
"_expand_input_attr": "expand_input",
- "next_method": None,
- "start_trigger": None,
+ "start_trigger_args": None,
+ "start_from_trigger": False,
}
op = BaseSerialization.deserialize(serialized)
@@ -2462,8 +2459,8 @@ def test_operator_expand_kwargs_xcomarg_serde(strict):
"ui_fgcolor": "#000",
"_disallow_kwargs_override": strict,
"_expand_input_attr": "expand_input",
- "next_method": None,
- "start_trigger": None,
+ "start_trigger_args": None,
+ "start_from_trigger": False,
}
op = BaseSerialization.deserialize(serialized)
@@ -2581,8 +2578,8 @@ def test_taskflow_expand_serde():
"template_fields_renderers": {"templates_dict": "json", "op_args":
"py", "op_kwargs": "py"},
"_disallow_kwargs_override": False,
"_expand_input_attr": "op_kwargs_expand_input",
- "next_method": None,
- "start_trigger": None,
+ "start_trigger_args": None,
+ "start_from_trigger": False,
}
deserialized = BaseSerialization.deserialize(serialized)
@@ -2648,8 +2645,8 @@ def test_taskflow_expand_kwargs_serde(strict):
"_task_module": "airflow.decorators.python",
"_task_type": "_PythonDecoratedOperator",
"_operator_name": "@task",
- "next_method": None,
- "start_trigger": None,
+ "start_trigger_args": None,
+ "start_from_trigger": False,
"downstream_task_ids": [],
"partial_kwargs": {
"is_setup": False,
@@ -2801,8 +2798,8 @@ def test_mapped_task_with_operator_extra_links_property():
"_is_empty": False,
"_is_mapped": True,
"_needs_expansion": True,
- "next_method": None,
- "start_trigger": None,
+ "start_trigger_args": None,
+ "start_from_trigger": False,
}
deserialized_dag = SerializedDAG.deserialize_dag(serialized_dag[Encoding.VAR])
assert deserialized_dag.task_dict["task"].operator_extra_links == [AirflowLink2()]
diff --git a/tests/serialization/test_pydantic_models.py b/tests/serialization/test_pydantic_models.py
index 048faebf54..dae611e68b 100644
--- a/tests/serialization/test_pydantic_models.py
+++ b/tests/serialization/test_pydantic_models.py
@@ -78,8 +78,8 @@ def test_deserialize_ti_mapped_op_reserialized_with_refresh_from_task(session, d
"_needs_expansion": True,
"_task_type": "_PythonDecoratedOperator",
"downstream_task_ids": [],
- "next_method": None,
- "start_trigger": None,
+ "start_from_trigger": False,
+ "start_trigger_args": None,
"_operator_name": "@task",
"ui_fgcolor": "#000",
"ui_color": "#ffefeb",