This is an automated email from the ASF dual-hosted git repository.
phanikumv pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new fa927276b5 Enhance start_trigger_args serialization (#40993)
fa927276b5 is described below
commit fa927276b52a5e0edca7b1b0bbc62083871a601c
Author: Wei Lee <[email protected]>
AuthorDate: Fri Jul 26 14:39:40 2024 +0800
Enhance start_trigger_args serialization (#40993)
---
airflow/serialization/serialized_objects.py | 46 ++++++++++++++++++++++++---
airflow/triggers/base.py | 9 ------
tests/serialization/test_dag_serialization.py | 38 ++++++++++++++--------
3 files changed, 67 insertions(+), 26 deletions(-)
diff --git a/airflow/serialization/serialized_objects.py b/airflow/serialization/serialized_objects.py
index 4d6a5b7e57..1513a50d38 100644
--- a/airflow/serialization/serialized_objects.py
+++ b/airflow/serialization/serialized_objects.py
@@ -27,7 +27,7 @@ import warnings
import weakref
from inspect import signature
from textwrap import dedent
-from typing import TYPE_CHECKING, Any, Collection, Iterable, Mapping, NamedTuple, Union
+from typing import TYPE_CHECKING, Any, Collection, Iterable, Mapping, NamedTuple, Union, cast
import attrs
import lazy_object_proxy
@@ -350,6 +350,42 @@ def decode_priority_weight_strategy(var: str) -> PriorityWeightStrategy:
return priority_weight_strategy_class()
+def encode_start_trigger_args(var: StartTriggerArgs) -> dict[str, Any]:
+ """
+ Encode a StartTriggerArgs.
+
+ :meta private:
+ """
+ serialize_kwargs = (
+ lambda key: BaseSerialization.serialize(getattr(var, key)) if getattr(var, key) is not None else None
+ )
+ return {
+ "__type": "START_TRIGGER_ARGS",
+ "trigger_cls": var.trigger_cls,
+ "trigger_kwargs": serialize_kwargs("trigger_kwargs"),
+ "next_method": var.next_method,
+ "next_kwargs": serialize_kwargs("next_kwargs"),
+ "timeout": var.timeout.total_seconds() if var.timeout else None,
+ }
+
+
+def decode_start_trigger_args(var: dict[str, Any]) -> StartTriggerArgs:
+ """
+ Decode a StartTriggerArgs.
+
+ :meta private:
+ """
+ deserialize_kwargs = lambda key: BaseSerialization.deserialize(var[key]) if var[key] is not None else None
+
+ return StartTriggerArgs(
+ trigger_cls=var["trigger_cls"],
+ trigger_kwargs=deserialize_kwargs("trigger_kwargs"),
+ next_method=var["next_method"],
+ next_kwargs=deserialize_kwargs("next_kwargs"),
+ timeout=datetime.timedelta(seconds=var["timeout"]) if var["timeout"] else None,
+ )
+
+
class _XComRef(NamedTuple):
"""
Store info needed to create XComArg.
@@ -1088,7 +1124,7 @@ class SerializedBaseOperator(BaseOperator, BaseSerialization):
serialize_op["_is_empty"] = op.inherits_from_empty_operator
serialize_op["start_trigger_args"] = (
- op.start_trigger_args.serialize() if op.start_trigger_args else None
+ encode_start_trigger_args(op.start_trigger_args) if op.start_trigger_args else None
)
serialize_op["start_from_trigger"] = op.start_from_trigger
@@ -1276,8 +1312,10 @@ class SerializedBaseOperator(BaseOperator, BaseSerialization):
setattr(op, "_is_empty", bool(encoded_op.get("_is_empty", False)))
start_trigger_args = None
- if encoded_op.get("start_trigger_args", None):
- start_trigger_args = StartTriggerArgs(**encoded_op.get("start_trigger_args", None))
+ encoded_start_trigger_args = encoded_op.get("start_trigger_args", None)
+ if encoded_start_trigger_args:
+ encoded_start_trigger_args = cast(dict, encoded_start_trigger_args)
+ start_trigger_args = decode_start_trigger_args(encoded_start_trigger_args)
setattr(op, "start_trigger_args", start_trigger_args)
setattr(op, "start_from_trigger", bool(encoded_op.get("start_from_trigger", False)))
diff --git a/airflow/triggers/base.py b/airflow/triggers/base.py
index 190e2983ce..7b5338ad2f 100644
--- a/airflow/triggers/base.py
+++ b/airflow/triggers/base.py
@@ -47,15 +47,6 @@ class StartTriggerArgs:
next_kwargs: dict[str, Any] | None = None
timeout: timedelta | None = None
- def serialize(self):
- return {
- "trigger_cls": self.trigger_cls,
- "trigger_kwargs": self.trigger_kwargs,
- "next_method": self.next_method,
- "next_kwargs": self.next_kwargs,
- "timeout": self.timeout,
- }
-
class BaseTrigger(abc.ABC, LoggingMixin):
"""
diff --git a/tests/serialization/test_dag_serialization.py b/tests/serialization/test_dag_serialization.py
index 8650378854..cefc5b7a2d 100644
--- a/tests/serialization/test_dag_serialization.py
+++ b/tests/serialization/test_dag_serialization.py
@@ -2284,8 +2284,8 @@ class TestStringifiedDAGs:
class TestOperator(BaseOperator):
start_trigger_args = StartTriggerArgs(
- trigger_cls="airflow.triggers.testing.SuccessTrigger",
- trigger_kwargs=None,
+ trigger_cls="airflow.triggers.temporal.TimeDeltaTrigger",
+ trigger_kwargs={"delta": timedelta(seconds=1)},
next_method="execute_complete",
next_kwargs=None,
timeout=None,
@@ -2294,7 +2294,7 @@ class TestStringifiedDAGs:
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
- self.start_trigger_args.trigger_kwargs = {}
+ self.start_trigger_args.trigger_kwargs = {"delta": timedelta(seconds=2)}
self.start_from_trigger = True
def execute_complete(self):
@@ -2323,16 +2323,28 @@ class TestStringifiedDAGs:
Test2Operator(task_id="test_task_2")
serialized_obj = SerializedDAG.to_dict(dag)
-
- for task in serialized_obj["dag"]["tasks"]:
- assert task["__var"]["start_trigger_args"] == {
- "trigger_cls": "airflow.triggers.testing.SuccessTrigger",
- "trigger_kwargs": {},
- "next_method": "execute_complete",
- "next_kwargs": None,
- "timeout": None,
- }
- assert task["__var"]["start_from_trigger"] is True
+ tasks = serialized_obj["dag"]["tasks"]
+
+ assert tasks[0]["__var"]["start_trigger_args"] == {
+ "__type": "START_TRIGGER_ARGS",
+ "trigger_cls": "airflow.triggers.temporal.TimeDeltaTrigger",
+ # "trigger_kwargs": {"__type": "dict", "__var": {"delta": {"__type": "timedelta", "__var": 2.0}}},
+ "trigger_kwargs": {"__type": "dict", "__var": {"delta": {"__type": "timedelta", "__var": 2.0}}},
+ "next_method": "execute_complete",
+ "next_kwargs": None,
+ "timeout": None,
+ }
+ assert tasks[0]["__var"]["start_from_trigger"] is True
+
+ assert tasks[1]["__var"]["start_trigger_args"] == {
+ "__type": "START_TRIGGER_ARGS",
+ "trigger_cls": "airflow.triggers.testing.SuccessTrigger",
+ "trigger_kwargs": {"__type": "dict", "__var": {}},
+ "next_method": "execute_complete",
+ "next_kwargs": None,
+ "timeout": None,
+ }
+ assert tasks[1]["__var"]["start_from_trigger"] is True
def test_kubernetes_optional():