This is an automated email from the ASF dual-hosted git repository.

weilee pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new d509abfa21 Introduce StartTriggerArgs and prevent start trigger initialization in scheduler (#39585)
d509abfa21 is described below

commit d509abfa217565d2d5249c639ea7459c44292368
Author: Wei Lee <[email protected]>
AuthorDate: Tue Jun 11 20:10:01 2024 +0900

    Introduce StartTriggerArgs and prevent start trigger initialization in scheduler (#39585)
    
    * fix(baseoperator): change `start_trigger` to `start_trigger_args`
    * feat(baseoperator): add `start_from_trigger` as a flag that decides whether task execution starts from the triggerer
    * fix(dagrun): set start_date before deferring a task from the scheduler
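
    A minimal usage sketch, adapted from the deferring.rst example added in
    this commit (the sensor class name is illustrative):

        from datetime import timedelta
        from typing import Any

        from airflow.sensors.base import BaseSensorOperator
        from airflow.triggers.base import StartTriggerArgs
        from airflow.utils.context import Context


        class WaitOneHourSensor(BaseSensorOperator):
            # With start_from_trigger=True the scheduler defers the task
            # straight to the triggerer instead of sending execute() to a
            # worker; StartTriggerArgs describes the trigger to run and the
            # method to call once it fires.
            start_trigger_args = StartTriggerArgs(
                trigger_cls="airflow.triggers.temporal.TimeDeltaTrigger",
                trigger_kwargs={"moment": timedelta(hours=1)},
                next_method="execute_complete",
                timeout=None,
            )
            start_from_trigger = True

            def execute_complete(
                self, context: Context, event: dict[str, Any] | None = None
            ) -> None:
                # The trigger has fired; there is no more work to do.
                return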
---
 airflow/decorators/base.py                         |  4 +-
 airflow/models/abstractoperator.py                 |  5 +-
 airflow/models/baseoperator.py                     | 10 +--
 airflow/models/dagrun.py                           | 16 +---
 airflow/models/mappedoperator.py                   | 17 ++--
 airflow/models/taskinstance.py                     | 39 ++++++---
 airflow/serialization/serialized_objects.py        | 31 +++----
 airflow/triggers/base.py                           | 20 +++++
 .../authoring-and-scheduling/deferring.rst         | 36 +++++---
 tests/models/test_dagrun.py                        | 17 ++--
 tests/serialization/test_dag_serialization.py      | 99 +++++++++++-----------
 tests/serialization/test_pydantic_models.py        |  4 +-
 12 files changed, 168 insertions(+), 130 deletions(-)

diff --git a/airflow/decorators/base.py b/airflow/decorators/base.py
index 2ae85a9c43..74b44ffe23 100644
--- a/airflow/decorators/base.py
+++ b/airflow/decorators/base.py
@@ -509,8 +509,8 @@ class _TaskDecorator(ExpandableFactory, Generic[FParams, FReturn, OperatorSubcla
             # task's expand() contribute to the op_kwargs operator argument, not
             # the operator arguments themselves, and should expand against it.
             expand_input_attr="op_kwargs_expand_input",
-            start_trigger=self.operator_class.start_trigger,
-            next_method=self.operator_class.next_method,
+            start_trigger_args=self.operator_class.start_trigger_args,
+            start_from_trigger=self.operator_class.start_from_trigger,
         )
         return XComArg(operator=operator)
 
diff --git a/airflow/models/abstractoperator.py b/airflow/models/abstractoperator.py
index b7160430e0..1bb83a2dc0 100644
--- a/airflow/models/abstractoperator.py
+++ b/airflow/models/abstractoperator.py
@@ -122,8 +122,9 @@ class AbstractOperator(Templater, DAGNode):
             "node_id",  # Duplicates task_id
             "task_group",  # Doesn't have a useful repr, no point showing in UI
             "inherits_from_empty_operator",  # impl detail
-            "start_trigger",
-            "next_method",
+            # Decide whether to start task execution from the triggerer
+            "start_trigger_args",
+            "start_from_trigger",
             # For compatibility with TG, for operators these are just the current task, no point showing
             "roots",
             "leaves",
diff --git a/airflow/models/baseoperator.py b/airflow/models/baseoperator.py
index 98532d90b0..bbd629cfc1 100644
--- a/airflow/models/baseoperator.py
+++ b/airflow/models/baseoperator.py
@@ -116,7 +116,7 @@ if TYPE_CHECKING:
     from airflow.models.operator import Operator
     from airflow.models.xcom_arg import XComArg
     from airflow.ti_deps.deps.base_ti_dep import BaseTIDep
-    from airflow.triggers.base import BaseTrigger
+    from airflow.triggers.base import BaseTrigger, StartTriggerArgs
     from airflow.utils.task_group import TaskGroup
     from airflow.utils.types import ArgNotSet
 
@@ -819,8 +819,8 @@ class BaseOperator(AbstractOperator, metaclass=BaseOperatorMeta):
     # Set to True for an operator instantiated by a mapped operator.
     __from_mapped = False
 
-    start_trigger: BaseTrigger | None = None
-    next_method: str | None = None
+    start_trigger_args: StartTriggerArgs | None = None
+    start_from_trigger: bool = False
 
     def __init__(
         self,
@@ -1679,9 +1679,9 @@ class BaseOperator(AbstractOperator, metaclass=BaseOperatorMeta):
                     "is_teardown",
                     "on_failure_fail_dagrun",
                     "map_index_template",
-                    "start_trigger",
-                    "next_method",
+                    "start_trigger_args",
                     "_needs_expansion",
+                    "start_from_trigger",
                 }
             )
             DagContext.pop_context_managed_dag()
diff --git a/airflow/models/dagrun.py b/airflow/models/dagrun.py
index 008608b96d..4773a89d1d 100644
--- a/airflow/models/dagrun.py
+++ b/airflow/models/dagrun.py
@@ -51,7 +51,7 @@ from airflow import settings
 from airflow.api_internal.internal_api_call import internal_api_call
 from airflow.callbacks.callback_requests import DagCallbackRequest
 from airflow.configuration import conf as airflow_conf
-from airflow.exceptions import AirflowException, RemovedInAirflow3Warning, TaskDeferred, TaskNotFound
+from airflow.exceptions import AirflowException, RemovedInAirflow3Warning, TaskNotFound
 from airflow.listeners.listener import get_listener_manager
 from airflow.models import Log
 from airflow.models.abstractoperator import NotMapped
@@ -1538,19 +1538,11 @@ class DagRun(Base, LoggingMixin):
                 and not ti.task.outlets
             ):
                 dummy_ti_ids.append((ti.task_id, ti.map_index))
-            elif (
-                ti.task.start_trigger is not None
-                and ti.task.next_method is not None
-                and not ti.task.on_execute_callback
-                and not ti.task.on_success_callback
-                and not ti.task.outlets
-            ):
+            elif ti.task.start_from_trigger is True and ti.task.start_trigger_args is not None:
+                ti.start_date = timezone.utcnow()
                 if ti.state != TaskInstanceState.UP_FOR_RESCHEDULE:
                     ti.try_number += 1
-                ti.defer_task(
-                    exception=TaskDeferred(trigger=ti.task.start_trigger, method_name=ti.task.next_method),
-                    session=session,
-                )
+                ti.defer_task(exception=None, session=session)
             else:
                 schedulable_ti_ids.append((ti.task_id, ti.map_index))
 
diff --git a/airflow/models/mappedoperator.py b/airflow/models/mappedoperator.py
index 27d0510c30..abbed3cfa9 100644
--- a/airflow/models/mappedoperator.py
+++ b/airflow/models/mappedoperator.py
@@ -81,7 +81,7 @@ if TYPE_CHECKING:
     from airflow.models.param import ParamsDict
     from airflow.models.xcom_arg import XComArg
     from airflow.ti_deps.deps.base_ti_dep import BaseTIDep
-    from airflow.triggers.base import BaseTrigger
+    from airflow.triggers.base import StartTriggerArgs
     from airflow.utils.context import Context
     from airflow.utils.operator_resources import Resources
     from airflow.utils.task_group import TaskGroup
@@ -237,8 +237,8 @@ class OperatorPartial:
             # For classic operators, this points to expand_input because kwargs
             # to BaseOperator.expand() contribute to operator arguments.
             expand_input_attr="expand_input",
-            start_trigger=self.operator_class.start_trigger,
-            next_method=self.operator_class.next_method,
+            start_trigger_args=self.operator_class.start_trigger_args,
+            start_from_trigger=self.operator_class.start_from_trigger,
         )
         return op
 
@@ -281,8 +281,8 @@ class MappedOperator(AbstractOperator):
     _task_module: str
     _task_type: str
     _operator_name: str
-    start_trigger: BaseTrigger | None
-    next_method: str | None
+    start_trigger_args: StartTriggerArgs | None
+    start_from_trigger: bool
     _needs_expansion: bool = True
 
     dag: DAG | None
@@ -309,12 +309,7 @@ class MappedOperator(AbstractOperator):
     supports_lineage: bool = False
 
     HIDE_ATTRS_FROM_UI: ClassVar[frozenset[str]] = AbstractOperator.HIDE_ATTRS_FROM_UI | frozenset(
-        (
-            "parse_time_mapped_ti_count",
-            "operator_class",
-            "start_trigger",
-            "next_method",
-        )
+        ("parse_time_mapped_ti_count", "operator_class", "start_trigger_args", 
"start_from_trigger")
     )
 
     def __hash__(self):
diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py
index c1ace17cd5..373ad108c2 100644
--- a/airflow/models/taskinstance.py
+++ b/airflow/models/taskinstance.py
@@ -1575,12 +1575,29 @@ def _coalesce_to_orm_ti(*, ti: TaskInstancePydantic | TaskInstance, session: Ses
 @internal_api_call
 @provide_session
 def _defer_task(
-    ti: TaskInstance | TaskInstancePydantic, exception: TaskDeferred, session: Session = NEW_SESSION
+    ti: TaskInstance | TaskInstancePydantic,
+    exception: TaskDeferred | None = None,
+    session: Session = NEW_SESSION,
 ) -> TaskInstancePydantic | TaskInstance:
     from airflow.models.trigger import Trigger
 
+    if exception is not None:
+        trigger_row = Trigger.from_object(exception.trigger)
+        trigger_kwargs = exception.kwargs
+        next_method = exception.method_name
+        timeout = exception.timeout
+    elif ti.task is not None and ti.task.start_trigger_args is not None:
+        trigger_row = Trigger(
+            classpath=ti.task.start_trigger_args.trigger_cls,
+            kwargs=ti.task.start_trigger_args.trigger_kwargs or {},
+        )
+        trigger_kwargs = ti.task.start_trigger_args.trigger_kwargs
+        next_method = ti.task.start_trigger_args.next_method
+        timeout = ti.task.start_trigger_args.timeout
+    else:
+        raise AirflowException("exception and ti.task.start_trigger_args cannot both be None")
+
     # First, make the trigger entry
-    trigger_row = Trigger.from_object(exception.trigger)
     session.add(trigger_row)
     session.flush()
 
@@ -1594,12 +1611,12 @@ def _defer_task(
     # depending on self.next_method semantics
     ti.state = TaskInstanceState.DEFERRED
     ti.trigger_id = trigger_row.id
-    ti.next_method = exception.method_name
-    ti.next_kwargs = exception.kwargs or {}
+    ti.next_method = next_method
+    ti.next_kwargs = trigger_kwargs or {}
 
     # Calculate timeout too if it was passed
-    if exception.timeout is not None:
-        ti.trigger_timeout = timezone.utcnow() + exception.timeout
+    if timeout is not None:
+        ti.trigger_timeout = timezone.utcnow() + timeout
     else:
         ti.trigger_timeout = None
 
@@ -1615,8 +1632,10 @@ def _defer_task(
             ti.trigger_timeout = ti.start_date + execution_timeout
     if ti.test_mode:
         _add_log(event=ti.state, task_instance=ti, session=session)
-    session.merge(ti)
-    session.commit()
+
+    if exception is not None:
+        session.merge(ti)
+        session.commit()
     return ti
 
 
@@ -3000,8 +3019,8 @@ class TaskInstance(Base, LoggingMixin):
         return _execute_task(self, context, task_orig)
 
     @provide_session
-    def defer_task(self, exception: TaskDeferred, session: Session) -> None:
-        """Mark the task as deferred and sets up the trigger that is needed to 
resume it.
+    def defer_task(self, exception: TaskDeferred | None, session: Session = 
NEW_SESSION) -> None:
+        """Mark the task as deferred and sets up the trigger that is needed to 
resume it when TaskDeferred is raised.
 
         :meta: private
         """
diff --git a/airflow/serialization/serialized_objects.py b/airflow/serialization/serialized_objects.py
index 6e7f50a87c..eb9c15f43c 100644
--- a/airflow/serialization/serialized_objects.py
+++ b/airflow/serialization/serialized_objects.py
@@ -66,11 +66,10 @@ from airflow.task.priority_strategy import (
     airflow_priority_weight_strategies,
     airflow_priority_weight_strategies_classes,
 )
-from airflow.triggers.base import BaseTrigger
+from airflow.triggers.base import BaseTrigger, StartTriggerArgs
 from airflow.utils.code_utils import get_python_source
 from airflow.utils.context import Context, OutletEventAccessor, OutletEventAccessors
 from airflow.utils.docs import get_docs_url
-from airflow.utils.helpers import exactly_one
 from airflow.utils.module_loading import import_string, qualname
 from airflow.utils.operator_resources import Resources
 from airflow.utils.task_group import MappedTaskGroup, TaskGroup
@@ -1018,11 +1017,10 @@ class SerializedBaseOperator(BaseOperator, BaseSerialization):
         # Used to determine if an Operator is inherited from EmptyOperator
         serialize_op["_is_empty"] = op.inherits_from_empty_operator
 
-        if exactly_one(op.start_trigger is not None, op.next_method is not None):
-            raise AirflowException("start_trigger and next_method should both be set.")
-
-        serialize_op["start_trigger"] = op.start_trigger.serialize() if 
op.start_trigger else None
-        serialize_op["next_method"] = op.next_method
+        serialize_op["start_trigger_args"] = (
+            op.start_trigger_args.serialize() if op.start_trigger_args else None
+        )
+        serialize_op["start_from_trigger"] = op.start_from_trigger
 
         if op.operator_extra_links:
             serialize_op["_operator_extra_links"] = 
cls._serialize_operator_extra_links(
@@ -1206,16 +1204,11 @@ class SerializedBaseOperator(BaseOperator, BaseSerialization):
         # Used to determine if an Operator is inherited from EmptyOperator
         setattr(op, "_is_empty", bool(encoded_op.get("_is_empty", False)))
 
-        # Deserialize start_trigger
-        serialized_start_trigger = encoded_op.get("start_trigger")
-        if serialized_start_trigger:
-            trigger_cls_name, trigger_kwargs = serialized_start_trigger
-            trigger_cls = import_string(trigger_cls_name)
-            start_trigger = trigger_cls(**trigger_kwargs)
-            setattr(op, "start_trigger", start_trigger)
-        else:
-            setattr(op, "start_trigger", None)
-        setattr(op, "next_method", encoded_op.get("next_method", None))
+        start_trigger_args = None
+        if encoded_op.get("start_trigger_args", None):
+            start_trigger_args = StartTriggerArgs(**encoded_op.get("start_trigger_args", None))
+        setattr(op, "start_trigger_args", start_trigger_args)
+        setattr(op, "start_from_trigger", 
bool(encoded_op.get("start_from_trigger", False)))
 
     @staticmethod
     def set_task_dag_references(task: Operator, dag: DAG) -> None:
@@ -1278,8 +1271,8 @@ class SerializedBaseOperator(BaseOperator, BaseSerialization):
                end_date=None,
                disallow_kwargs_override=encoded_op["_disallow_kwargs_override"],
                 expand_input_attr=encoded_op["_expand_input_attr"],
-                start_trigger=None,
-                next_method=None,
+                start_trigger_args=encoded_op.get("start_trigger_args", None),
+                start_from_trigger=encoded_op.get("start_from_trigger", False),
             )
         else:
             op = SerializedBaseOperator(task_id=encoded_op["task_id"])
diff --git a/airflow/triggers/base.py b/airflow/triggers/base.py
index 0d239af0ca..ac5727a1ba 100644
--- a/airflow/triggers/base.py
+++ b/airflow/triggers/base.py
@@ -17,11 +17,31 @@
 from __future__ import annotations
 
 import abc
+from dataclasses import dataclass
+from datetime import timedelta
 from typing import Any, AsyncIterator
 
 from airflow.utils.log.logging_mixin import LoggingMixin
 
 
+@dataclass
+class StartTriggerArgs:
+    """Arguments required for start task execution from triggerer."""
+
+    trigger_cls: str
+    next_method: str
+    trigger_kwargs: dict[str, Any] | None = None
+    timeout: timedelta | None = None
+
+    def serialize(self):
+        return {
+            "trigger_cls": self.trigger_cls,
+            "trigger_kwargs": self.trigger_kwargs,
+            "next_method": self.next_method,
+            "timeout": self.timeout,
+        }
+
+
 class BaseTrigger(abc.ABC, LoggingMixin):
     """
     Base class for all triggers.
diff --git a/docs/apache-airflow/authoring-and-scheduling/deferring.rst b/docs/apache-airflow/authoring-and-scheduling/deferring.rst
index 084a08f0ac..a65e932a90 100644
--- a/docs/apache-airflow/authoring-and-scheduling/deferring.rst
+++ b/docs/apache-airflow/authoring-and-scheduling/deferring.rst
@@ -143,10 +143,14 @@ The ``self.defer`` call raises the ``TaskDeferred`` exception, so it can work an
 Triggering Deferral from Start
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 
-If you want to defer your task directly to the triggerer without going into the worker, you can add the class level attributes ``start_trigger`` and ``next_method`` to your deferrable operator.
+.. versionadded:: 2.10.0
 
-* ``start_trigger``: An instance of a trigger you want to defer to. It will be serialized into the database.
+If you want to defer your task directly to the triggerer without going into the worker, you can set the class-level attribute ``start_from_trigger`` to ``True`` and add a class-level attribute ``start_trigger_args`` holding a ``StartTriggerArgs`` object with the following four attributes to your deferrable operator:
+
+* ``trigger_cls``: An importable path to your trigger class.
+* ``trigger_kwargs``: Additional keyword arguments to pass to the trigger class when it is initialized.
 * ``next_method``: The method name on your operator that you want Airflow to call when it resumes.
+* ``timeout``: (Optional) A timedelta that specifies a timeout after which the deferral fails and the task instance is marked as failed. Defaults to ``None``, which means no timeout.
 
 
 This is particularly useful when deferring is the only thing the ``execute`` method does. Here's a basic refinement of the previous example.
@@ -156,23 +160,28 @@ This is particularly useful when deferring is the only thing the ``execute`` met
     from datetime import timedelta
     from typing import Any
 
+    from airflow.triggers.base import StartTriggerArgs
     from airflow.sensors.base import BaseSensorOperator
-    from airflow.triggers.temporal import TimeDeltaTrigger
     from airflow.utils.context import Context
 
 
     class WaitOneHourSensor(BaseSensorOperator):
-        start_trigger = TimeDeltaTrigger(timedelta(hours=1))
-        next_method = "execute_complete"
+        start_trigger_args = StartTriggerArgs(
+            trigger_cls="airflow.triggers.temporal.TimeDeltaTrigger",
+            trigger_kwargs={"moment": timedelta(hours=1)},
+            next_method="execute_complete",
+            timeout=None,
+        )
+        start_from_trigger = True
 
         def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> None:
             # We have no more work to do here. Mark as complete.
             return
 
-``start_trigger`` and ``next_method`` can also be set at the instance level for more flexible configuration.
+``start_from_trigger`` and ``trigger_kwargs`` can also be modified at the instance level for more flexible configuration.
 
 .. warning::
-    Dynamic task mapping is not supported when ``start_trigger`` and ``next_method`` are assigned in instance level.
+    Dynamic task mapping is not supported when ``trigger_kwargs`` is modified at the instance level.
 
 .. code-block:: python
 
@@ -184,11 +193,18 @@ This is particularly useful when deferring is the only thing the ``execute`` met
     from airflow.utils.context import Context
 
 
-    class WaitOneHourSensor(BaseSensorOperator):
+    class WaitTwoHourSensor(BaseSensorOperator):
+        start_trigger_args = StartTriggerArgs(
+            trigger_cls="airflow.triggers.temporal.TimeDeltaTrigger",
+            trigger_kwargs={},
+            next_method="execute_complete",
+            timeout=None,
+        )
+
         def __init__(self, *args: list[Any], **kwargs: dict[str, Any]) -> None:
             super().__init__(*args, **kwargs)
-            self.start_trigger = TimeDeltaTrigger(timedelta(hours=1))
-            self.next_method = "execute_complete"
+            self.start_trigger_args.trigger_kwargs = {"moment": timedelta(hours=2)}
+            self.start_from_trigger = True
 
         def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> None:
             # We have no more work to do here. Mark as complete.
diff --git a/tests/models/test_dagrun.py b/tests/models/test_dagrun.py
index fe1c2d58a2..93e0611243 100644
--- a/tests/models/test_dagrun.py
+++ b/tests/models/test_dagrun.py
@@ -40,7 +40,7 @@ from airflow.operators.empty import EmptyOperator
 from airflow.operators.python import ShortCircuitOperator
 from airflow.serialization.serialized_objects import SerializedDAG
 from airflow.stats import Stats
-from airflow.triggers.testing import SuccessTrigger
+from airflow.triggers.base import StartTriggerArgs
 from airflow.utils import timezone
 from airflow.utils.state import DagRunState, State, TaskInstanceState
 from airflow.utils.trigger_rule import TriggerRule
@@ -1989,16 +1989,21 @@ def test_schedule_tis_map_index(dag_maker, session):
 
 def test_schedule_tis_start_trigger(dag_maker, session):
     """
-    Test that an operator with _start_trigger and _next_method set can be directly
-    deferred during scheduling.
+    Test that an operator with start_trigger_args set can be directly deferred during scheduling.
     """
-    trigger = SuccessTrigger()
 
     class TestOperator(BaseOperator):
+        start_trigger_args = StartTriggerArgs(
+            trigger_cls="airflow.triggers.testing.SuccessTrigger",
+            trigger_kwargs=None,
+            next_method="execute_complete",
+            timeout=None,
+        )
+        start_from_trigger = True
+
         def __init__(self, *args, **kwargs):
             super().__init__(*args, **kwargs)
-            self.start_trigger = trigger
-            self.next_method = "execute_complete"
+            self.start_trigger_args.trigger_kwargs = {}
 
         def execute_complete(self):
             pass
diff --git a/tests/serialization/test_dag_serialization.py b/tests/serialization/test_dag_serialization.py
index 093b7fba76..16f6d3cb68 100644
--- a/tests/serialization/test_dag_serialization.py
+++ b/tests/serialization/test_dag_serialization.py
@@ -72,7 +72,7 @@ from airflow.serialization.serialized_objects import (
 from airflow.task.priority_strategy import _DownstreamPriorityWeightStrategy
 from airflow.ti_deps.deps.base_ti_dep import BaseTIDep
 from airflow.timetables.simple import NullTimetable, OnceTimetable
-from airflow.triggers.testing import SuccessTrigger
+from airflow.triggers.base import StartTriggerArgs
 from airflow.utils import timezone
 from airflow.utils.operator_resources import Resources
 from airflow.utils.task_group import TaskGroup
@@ -197,8 +197,8 @@ serialized_simple_dag_ground_truth = {
                     "_log_config_logger_name": "airflow.task.operators",
                     "_needs_expansion": False,
                     "weight_rule": "downstream",
-                    "next_method": None,
-                    "start_trigger": None,
+                    "start_trigger_args": None,
+                    "start_from_trigger": False,
                 },
             },
             {
@@ -227,8 +227,8 @@ serialized_simple_dag_ground_truth = {
                     "_log_config_logger_name": "airflow.task.operators",
                     "_needs_expansion": False,
                     "weight_rule": "downstream",
-                    "next_method": None,
-                    "start_trigger": None,
+                    "start_trigger_args": None,
+                    "start_from_trigger": False,
                 },
             },
         ],
@@ -2167,25 +2167,38 @@ class TestStringifiedDAGs:
             SerializedDAG.to_dict(dag)
 
     @pytest.mark.db_test
-    def test_start_trigger_and_next_method_in_serialized_dag(self):
+    def test_start_trigger_args_in_serialized_dag(self):
         """
-        Test that when we provide start_trigger and next_method, the DAG can be correctly serialized.
+        Test that when we provide start_trigger_args, the DAG can be correctly serialized.
         """
-        trigger = SuccessTrigger()
 
         class TestOperator(BaseOperator):
+            start_trigger_args = StartTriggerArgs(
+                trigger_cls="airflow.triggers.testing.SuccessTrigger",
+                trigger_kwargs=None,
+                next_method="execute_complete",
+                timeout=None,
+            )
+            start_from_trigger = False
+
             def __init__(self, *args, **kwargs):
                 super().__init__(*args, **kwargs)
-                self.start_trigger = trigger
-                self.next_method = "execute_complete"
+                self.start_trigger_args.trigger_kwargs = {}
+                self.start_from_trigger = True
 
             def execute_complete(self):
                 pass
 
         class Test2Operator(BaseOperator):
+            start_trigger_args = StartTriggerArgs(
+                trigger_cls="airflow.triggers.testing.SuccessTrigger",
+                trigger_kwargs={},
+                next_method="execute_complete",
+                timeout=None,
+            )
+            start_from_trigger = True
+
             def __init__(self, *args, **kwargs):
-                self.start_trigger = trigger
-                self.next_method = "execute_complete"
                 super().__init__(*args, **kwargs)
 
             def execute_complete(self):
@@ -2200,29 +2213,13 @@ class TestStringifiedDAGs:
         serialized_obj = SerializedDAG.to_dict(dag)
 
         for task in serialized_obj["dag"]["tasks"]:
-            assert task["__var"]["start_trigger"] == trigger.serialize()
-            assert task["__var"]["next_method"] == "execute_complete"
-
-    @pytest.mark.db_test
-    def test_start_trigger_in_serialized_dag_but_no_next_method(self):
-        """
-        Test that when we provide start_trigger without next_method, an AriflowException should be raised.
-        """
-
-        trigger = SuccessTrigger()
-
-        class TestOperator(BaseOperator):
-            def __init__(self, *args, **kwargs):
-                super().__init__(*args, **kwargs)
-                self.start_trigger = trigger
-
-        dag = DAG(dag_id="test_dag", start_date=datetime(2023, 11, 9))
-
-        with dag:
-            TestOperator(task_id="test_task")
-
-        with pytest.raises(AirflowException, match="start_trigger and next_method should both be set."):
-            SerializedDAG.to_dict(dag)
+            assert task["__var"]["start_trigger_args"] == {
+                "trigger_cls": "airflow.triggers.testing.SuccessTrigger",
+                "trigger_kwargs": {},
+                "next_method": "execute_complete",
+                "timeout": None,
+            }
+            assert task["__var"]["start_from_trigger"] is True
 
 
 def test_kubernetes_optional():
@@ -2274,8 +2271,8 @@ def test_operator_expand_serde():
         "_needs_expansion": True,
         "_task_module": "airflow.operators.bash",
         "_task_type": "BashOperator",
-        "start_trigger": None,
-        "next_method": None,
+        "start_trigger_args": None,
+        "start_from_trigger": False,
         "downstream_task_ids": [],
         "expand_input": {
             "type": "dict-of-lists",
@@ -2308,8 +2305,8 @@ def test_operator_expand_serde():
     assert op.operator_class == {
         "_task_type": "BashOperator",
         "_needs_expansion": True,
-        "start_trigger": None,
-        "next_method": None,
+        "start_trigger_args": None,
+        "start_from_trigger": False,
         "downstream_task_ids": [],
         "task_id": "a",
         "template_ext": [".sh", ".bash"],
@@ -2355,8 +2352,8 @@ def test_operator_expand_xcomarg_serde():
         "ui_fgcolor": "#000",
         "_disallow_kwargs_override": False,
         "_expand_input_attr": "expand_input",
-        "next_method": None,
-        "start_trigger": None,
+        "start_trigger_args": None,
+        "start_from_trigger": False,
     }
 
     op = BaseSerialization.deserialize(serialized)
@@ -2413,8 +2410,8 @@ def test_operator_expand_kwargs_literal_serde(strict):
         "ui_fgcolor": "#000",
         "_disallow_kwargs_override": strict,
         "_expand_input_attr": "expand_input",
-        "next_method": None,
-        "start_trigger": None,
+        "start_trigger_args": None,
+        "start_from_trigger": False,
     }
 
     op = BaseSerialization.deserialize(serialized)
@@ -2462,8 +2459,8 @@ def test_operator_expand_kwargs_xcomarg_serde(strict):
         "ui_fgcolor": "#000",
         "_disallow_kwargs_override": strict,
         "_expand_input_attr": "expand_input",
-        "next_method": None,
-        "start_trigger": None,
+        "start_trigger_args": None,
+        "start_from_trigger": False,
     }
 
     op = BaseSerialization.deserialize(serialized)
@@ -2581,8 +2578,8 @@ def test_taskflow_expand_serde():
         "template_fields_renderers": {"templates_dict": "json", "op_args": 
"py", "op_kwargs": "py"},
         "_disallow_kwargs_override": False,
         "_expand_input_attr": "op_kwargs_expand_input",
-        "next_method": None,
-        "start_trigger": None,
+        "start_trigger_args": None,
+        "start_from_trigger": False,
     }
 
     deserialized = BaseSerialization.deserialize(serialized)
@@ -2648,8 +2645,8 @@ def test_taskflow_expand_kwargs_serde(strict):
         "_task_module": "airflow.decorators.python",
         "_task_type": "_PythonDecoratedOperator",
         "_operator_name": "@task",
-        "next_method": None,
-        "start_trigger": None,
+        "start_trigger_args": None,
+        "start_from_trigger": False,
         "downstream_task_ids": [],
         "partial_kwargs": {
             "is_setup": False,
@@ -2801,8 +2798,8 @@ def test_mapped_task_with_operator_extra_links_property():
         "_is_empty": False,
         "_is_mapped": True,
         "_needs_expansion": True,
-        "next_method": None,
-        "start_trigger": None,
+        "start_trigger_args": None,
+        "start_from_trigger": False,
     }
     deserialized_dag = SerializedDAG.deserialize_dag(serialized_dag[Encoding.VAR])
     assert deserialized_dag.task_dict["task"].operator_extra_links == [AirflowLink2()]
diff --git a/tests/serialization/test_pydantic_models.py b/tests/serialization/test_pydantic_models.py
index 048faebf54..dae611e68b 100644
--- a/tests/serialization/test_pydantic_models.py
+++ b/tests/serialization/test_pydantic_models.py
@@ -78,8 +78,8 @@ def test_deserialize_ti_mapped_op_reserialized_with_refresh_from_task(session, d
         "_needs_expansion": True,
         "_task_type": "_PythonDecoratedOperator",
         "downstream_task_ids": [],
-        "next_method": None,
-        "start_trigger": None,
+        "start_from_trigger": False,
+        "start_trigger_args": None,
         "_operator_name": "@task",
         "ui_fgcolor": "#000",
         "ui_color": "#ffefeb",
