This is an automated email from the ASF dual-hosted git repository.
kaxilnaik pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 3b72458240d Remove callbacks from DAG `default_args` when
serializing it (#57397)
3b72458240d is described below
commit 3b72458240d06e1d40b895b8ac705a5d607ca8ab
Author: Kaxil Naik <[email protected]>
AuthorDate: Tue Oct 28 11:23:19 2025 +0000
Remove callbacks from DAG `default_args` when serializing it (#57397)
---
.../airflow/serialization/serialized_objects.py | 17 ++++++
.../unit/serialization/test_dag_serialization.py | 62 ++++++++++++++++++++++
2 files changed, 79 insertions(+)
diff --git a/airflow-core/src/airflow/serialization/serialized_objects.py
b/airflow-core/src/airflow/serialization/serialized_objects.py
index 680816a6545..d43a2375ef9 100644
--- a/airflow-core/src/airflow/serialization/serialized_objects.py
+++ b/airflow-core/src/airflow/serialization/serialized_objects.py
@@ -2507,6 +2507,23 @@ class SerializedDAG(BaseSerialization):
serialized_dag["has_on_success_callback"] = True
if dag.has_on_failure_callback:
serialized_dag["has_on_failure_callback"] = True
+
+ # TODO: Move this logic to a better place -- ideally before
serializing contents of default_args.
+ # There is some duplication with this and
SerializedBaseOperator.partial_kwargs serialization.
+ # Ideally default_args goes through same logic as fields of
SerializedBaseOperator.
+ if serialized_dag.get("default_args", {}):
+ default_args_dict =
serialized_dag["default_args"][Encoding.VAR]
+ callbacks_to_remove = []
+ for k, v in list(default_args_dict.items()):
+ if k in [
+ f"on_{x}_callback" for x in ("execute", "failure",
"success", "retry", "skipped")
+ ]:
+ if bool(v):
+ default_args_dict[f"has_{k}"] = True
+ callbacks_to_remove.append(k)
+ for k in callbacks_to_remove:
+ del default_args_dict[k]
+
return serialized_dag
except SerializationError:
raise
diff --git a/airflow-core/tests/unit/serialization/test_dag_serialization.py
b/airflow-core/tests/unit/serialization/test_dag_serialization.py
index 03e205f2c98..0d070bfbdef 100644
--- a/airflow-core/tests/unit/serialization/test_dag_serialization.py
+++ b/airflow-core/tests/unit/serialization/test_dag_serialization.py
@@ -4291,3 +4291,65 @@ class TestMappedOperatorSerializationAndClientDefaults:
assert "owner" in deserialized_task.partial_kwargs
assert deserialized_task.partial_kwargs["retry_delay"] ==
timedelta(seconds=600)
assert deserialized_task.partial_kwargs["owner"] == "custom_owner"
+
+
+@pytest.mark.parametrize(
+ ["callbacks", "expected_has_flags", "absent_keys"],
+ [
+ pytest.param(
+ {
+ "on_failure_callback": lambda ctx: None,
+ "on_success_callback": lambda ctx: None,
+ "on_retry_callback": lambda ctx: None,
+ },
+ ["has_on_failure_callback", "has_on_success_callback",
"has_on_retry_callback"],
+ ["on_failure_callback", "on_success_callback",
"on_retry_callback"],
+ id="multiple_callbacks",
+ ),
+ pytest.param(
+ {"on_failure_callback": lambda ctx: None},
+ ["has_on_failure_callback"],
+ ["on_failure_callback", "has_on_success_callback",
"on_success_callback"],
+ id="single_callback",
+ ),
+ pytest.param(
+ {"on_failure_callback": lambda ctx: None, "on_execute_callback":
None},
+ ["has_on_failure_callback"],
+ ["on_failure_callback", "has_on_execute_callback",
"on_execute_callback"],
+ id="callback_with_none",
+ ),
+ pytest.param(
+ {},
+ [],
+ [
+ "has_on_execute_callback",
+ "has_on_failure_callback",
+ "has_on_success_callback",
+ "has_on_retry_callback",
+ "has_on_skipped_callback",
+ ],
+ id="no_callbacks",
+ ),
+ ],
+)
+def test_dag_default_args_callbacks_serialization(callbacks,
expected_has_flags, absent_keys):
+ """Test callbacks in DAG default_args are serialized as boolean flags."""
+ default_args = {"owner": "test_owner", "retries": 2, **callbacks}
+
+ with DAG(dag_id="test_default_args_callbacks", default_args=default_args)
as dag:
+ BashOperator(task_id="task1", bash_command="echo 1", dag=dag)
+
+ serialized_dag_dict = SerializedDAG.serialize_dag(dag)
+ default_args_dict = serialized_dag_dict["default_args"][Encoding.VAR]
+
+ for flag in expected_has_flags:
+ assert default_args_dict.get(flag) is True
+
+ for key in absent_keys:
+ assert key not in default_args_dict
+
+ assert default_args_dict["owner"] == "test_owner"
+ assert default_args_dict["retries"] == 2
+
+ deserialized_dag = SerializedDAG.deserialize_dag(serialized_dag_dict)
+ assert deserialized_dag.dag_id == "test_default_args_callbacks"