This is an automated email from the ASF dual-hosted git repository.

phanikumv pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new fa927276b5 Enhance start_trigger_args serialization (#40993)
fa927276b5 is described below

commit fa927276b52a5e0edca7b1b0bbc62083871a601c
Author: Wei Lee <[email protected]>
AuthorDate: Fri Jul 26 14:39:40 2024 +0800

    Enhance start_trigger_args serialization (#40993)
---
 airflow/serialization/serialized_objects.py   | 46 ++++++++++++++++++++++++---
 airflow/triggers/base.py                      |  9 ------
 tests/serialization/test_dag_serialization.py | 38 ++++++++++++++--------
 3 files changed, 67 insertions(+), 26 deletions(-)

diff --git a/airflow/serialization/serialized_objects.py b/airflow/serialization/serialized_objects.py
index 4d6a5b7e57..1513a50d38 100644
--- a/airflow/serialization/serialized_objects.py
+++ b/airflow/serialization/serialized_objects.py
@@ -27,7 +27,7 @@ import warnings
 import weakref
 from inspect import signature
 from textwrap import dedent
-from typing import TYPE_CHECKING, Any, Collection, Iterable, Mapping, 
NamedTuple, Union
+from typing import TYPE_CHECKING, Any, Collection, Iterable, Mapping, 
NamedTuple, Union, cast
 
 import attrs
 import lazy_object_proxy
@@ -350,6 +350,42 @@ def decode_priority_weight_strategy(var: str) -> PriorityWeightStrategy:
     return priority_weight_strategy_class()
 
 
+def encode_start_trigger_args(var: StartTriggerArgs) -> dict[str, Any]:
+    """
+    Encode a StartTriggerArgs; ``timeout`` is stored in seconds and a zero timeout is not turned into None.
+
+    :meta private:
+    """
+    serialize_kwargs = (
+        lambda key: BaseSerialization.serialize(getattr(var, key)) if getattr(var, key) is not None else None
+    )
+    return {
+        "__type": "START_TRIGGER_ARGS",
+        "trigger_cls": var.trigger_cls,
+        "trigger_kwargs": serialize_kwargs("trigger_kwargs"),
+        "next_method": var.next_method,
+        "next_kwargs": serialize_kwargs("next_kwargs"),
+        "timeout": var.timeout.total_seconds() if var.timeout is not None else None,
+    }
+
+
+def decode_start_trigger_args(var: dict[str, Any]) -> StartTriggerArgs:
+    """
+    Decode a dict produced by ``encode_start_trigger_args``; a stored 0 timeout decodes to timedelta(0).
+
+    :meta private:
+    """
+    deserialize_kwargs = lambda key: BaseSerialization.deserialize(var[key]) if var[key] is not None else None
+
+    return StartTriggerArgs(
+        trigger_cls=var["trigger_cls"],
+        trigger_kwargs=deserialize_kwargs("trigger_kwargs"),
+        next_method=var["next_method"],
+        next_kwargs=deserialize_kwargs("next_kwargs"),
+        timeout=datetime.timedelta(seconds=var["timeout"]) if var["timeout"] is not None else None,
+    )
+
+
 class _XComRef(NamedTuple):
     """
     Store info needed to create XComArg.
@@ -1088,7 +1124,7 @@ class SerializedBaseOperator(BaseOperator, BaseSerialization):
         serialize_op["_is_empty"] = op.inherits_from_empty_operator
 
         serialize_op["start_trigger_args"] = (
-            op.start_trigger_args.serialize() if op.start_trigger_args else 
None
+            encode_start_trigger_args(op.start_trigger_args) if 
op.start_trigger_args else None
         )
         serialize_op["start_from_trigger"] = op.start_from_trigger
 
@@ -1276,8 +1312,10 @@ class SerializedBaseOperator(BaseOperator, BaseSerialization):
         setattr(op, "_is_empty", bool(encoded_op.get("_is_empty", False)))
 
         start_trigger_args = None
-        if encoded_op.get("start_trigger_args", None):
-            start_trigger_args = 
StartTriggerArgs(**encoded_op.get("start_trigger_args", None))
+        encoded_start_trigger_args = encoded_op.get("start_trigger_args", None)
+        if encoded_start_trigger_args:
+            encoded_start_trigger_args = cast(dict, encoded_start_trigger_args)
+            start_trigger_args = 
decode_start_trigger_args(encoded_start_trigger_args)
         setattr(op, "start_trigger_args", start_trigger_args)
         setattr(op, "start_from_trigger", 
bool(encoded_op.get("start_from_trigger", False)))
 
diff --git a/airflow/triggers/base.py b/airflow/triggers/base.py
index 190e2983ce..7b5338ad2f 100644
--- a/airflow/triggers/base.py
+++ b/airflow/triggers/base.py
@@ -47,15 +47,6 @@ class StartTriggerArgs:
     next_kwargs: dict[str, Any] | None = None
     timeout: timedelta | None = None
 
-    def serialize(self):
-        return {
-            "trigger_cls": self.trigger_cls,
-            "trigger_kwargs": self.trigger_kwargs,
-            "next_method": self.next_method,
-            "next_kwargs": self.next_kwargs,
-            "timeout": self.timeout,
-        }
-
 
 class BaseTrigger(abc.ABC, LoggingMixin):
     """
diff --git a/tests/serialization/test_dag_serialization.py b/tests/serialization/test_dag_serialization.py
index 8650378854..cefc5b7a2d 100644
--- a/tests/serialization/test_dag_serialization.py
+++ b/tests/serialization/test_dag_serialization.py
@@ -2284,8 +2284,8 @@ class TestStringifiedDAGs:
 
         class TestOperator(BaseOperator):
             start_trigger_args = StartTriggerArgs(
-                trigger_cls="airflow.triggers.testing.SuccessTrigger",
-                trigger_kwargs=None,
+                trigger_cls="airflow.triggers.temporal.TimeDeltaTrigger",
+                trigger_kwargs={"delta": timedelta(seconds=1)},
                 next_method="execute_complete",
                 next_kwargs=None,
                 timeout=None,
@@ -2294,7 +2294,7 @@ class TestStringifiedDAGs:
 
             def __init__(self, *args, **kwargs):
                 super().__init__(*args, **kwargs)
-                self.start_trigger_args.trigger_kwargs = {}
+                self.start_trigger_args.trigger_kwargs = {"delta": 
timedelta(seconds=2)}
                 self.start_from_trigger = True
 
             def execute_complete(self):
@@ -2323,16 +2323,28 @@ class TestStringifiedDAGs:
             Test2Operator(task_id="test_task_2")
 
         serialized_obj = SerializedDAG.to_dict(dag)
-
-        for task in serialized_obj["dag"]["tasks"]:
-            assert task["__var"]["start_trigger_args"] == {
-                "trigger_cls": "airflow.triggers.testing.SuccessTrigger",
-                "trigger_kwargs": {},
-                "next_method": "execute_complete",
-                "next_kwargs": None,
-                "timeout": None,
-            }
-            assert task["__var"]["start_from_trigger"] is True
+        tasks = serialized_obj["dag"]["tasks"]
+
+        assert tasks[0]["__var"]["start_trigger_args"] == {
+            "__type": "START_TRIGGER_ARGS",
+            "trigger_cls": "airflow.triggers.temporal.TimeDeltaTrigger",
+            # "trigger_kwargs": {"__type": "dict", "__var": {"delta": 
{"__type": "timedelta", "__var": 2.0}}},
+            "trigger_kwargs": {"__type": "dict", "__var": {"delta": {"__type": 
"timedelta", "__var": 2.0}}},
+            "next_method": "execute_complete",
+            "next_kwargs": None,
+            "timeout": None,
+        }
+        assert tasks[0]["__var"]["start_from_trigger"] is True
+
+        assert tasks[1]["__var"]["start_trigger_args"] == {
+            "__type": "START_TRIGGER_ARGS",
+            "trigger_cls": "airflow.triggers.testing.SuccessTrigger",
+            "trigger_kwargs": {"__type": "dict", "__var": {}},
+            "next_method": "execute_complete",
+            "next_kwargs": None,
+            "timeout": None,
+        }
+        assert tasks[1]["__var"]["start_from_trigger"] is True
 
 
 def test_kubernetes_optional():

Reply via email to