This is an automated email from the ASF dual-hosted git repository.

amoghdesai pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new c03fc794300 Respect default_args in DAG when it's set to a "falsy" value (#57853)
c03fc794300 is described below

commit c03fc7943009b02072ad275dd31e54f927bff889
Author: Amogh Desai <[email protected]>
AuthorDate: Mon Nov 17 19:12:26 2025 +0530

    Respect default_args in DAG when it's set to a "falsy" value (#57853)
---
 .../airflow/serialization/serialized_objects.py    |  49 ++++-
 .../unit/serialization/test_dag_serialization.py   | 238 ++++++++++++---------
 2 files changed, 181 insertions(+), 106 deletions(-)
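
For context, a minimal reproduction sketch of the scenario this commit fixes
(not part of the commit; import paths assume current main, and the client
default retries=3 is illustrative):

    from airflow.providers.standard.operators.bash import BashOperator
    from airflow.sdk import DAG
    from airflow.serialization.serialized_objects import SerializedDAG

    # default_args carries a "falsy" override that equals the schema default (0)
    # but differs from the client default (e.g. 3). Before this commit the
    # override was optimized out during serialization, so the round-tripped
    # task silently picked up the client default instead of the user's 0.
    with DAG(dag_id="falsy_default_args", default_args={"retries": 0}) as dag:
        BashOperator(task_id="task1", bash_command="echo 1")

    round_tripped = SerializedDAG.from_dict(SerializedDAG.to_dict(dag))
    assert round_tripped.get_task("task1").retries == 0  # holds after this fix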

diff --git a/airflow-core/src/airflow/serialization/serialized_objects.py b/airflow-core/src/airflow/serialization/serialized_objects.py
index 9a2a8efe96e..d2f13c1dc92 100644
--- a/airflow-core/src/airflow/serialization/serialized_objects.py
+++ b/airflow-core/src/airflow/serialization/serialized_objects.py
@@ -1241,6 +1241,8 @@ class SerializedBaseOperator(DAGNode, BaseSerialization):
 
     _json_schema: ClassVar[Validator] = lazy_object_proxy.Proxy(load_dag_schema)
 
+    _const_fields: ClassVar[set[str] | None] = None
+
     _can_skip_downstream: bool
     _is_empty: bool
     _needs_expansion: bool
@@ -1711,6 +1713,22 @@ class SerializedBaseOperator(DAGNode, BaseSerialization):
             # Bypass set_upstream etc here - it does more than we want
             dag.task_dict[task_id].upstream_task_ids.add(task.task_id)
 
+    @classmethod
+    def get_operator_const_fields(cls) -> set[str]:
+        """Get the set of operator fields that are marked as const in the JSON 
schema."""
+        if (schema_loader := cls._json_schema) is None:
+            return set()
+
+        schema_data = schema_loader.schema
+        operator_def = schema_data.get("definitions", {}).get("operator", {})
+        properties = operator_def.get("properties", {})
+
+        return {
+            field_name
+            for field_name, field_def in properties.items()
+            if isinstance(field_def, dict) and field_def.get("const")
+        }
+
     @classmethod
     @lru_cache(maxsize=1)  # Only one type: "operator"
     def get_operator_optional_fields_from_schema(cls) -> set[str]:
@@ -1866,10 +1884,39 @@ class SerializedBaseOperator(DAGNode, BaseSerialization):
         # Check if value matches client_defaults (hierarchical defaults optimization)
         if cls._matches_client_defaults(var, attrname):
             return True
-        schema_defaults = cls.get_schema_defaults("operator")
 
+        # Const fields should always be excluded when False, regardless of client_defaults
+        # Use a class-level cache for optimization
+        if cls._const_fields is None:
+            cls._const_fields = cls.get_operator_const_fields()
+        if attrname in cls._const_fields and var is False:
+            return True
+
+        schema_defaults = cls.get_schema_defaults("operator")
         if attrname in schema_defaults:
             if schema_defaults[attrname] == var:
+                # If it also matches client_defaults, exclude (optimization)
+                client_defaults = cls.generate_client_defaults()
+                if attrname in client_defaults:
+                    if client_defaults[attrname] == var:
+                        return True
+                    # If client_defaults differs, preserve explicit override from user
+                    # Example: default_args={"retries": 0}, schema default=0, client_defaults={"retries": 3}
+                    if client_defaults[attrname] != var:
+                        if op.has_dag():
+                            dag = op.dag
+                            if dag and attrname in dag.default_args and dag.default_args[attrname] == var:
+                                return False
+                        if (
+                            hasattr(op, "_BaseOperator__init_kwargs")
+                            and attrname in op._BaseOperator__init_kwargs
+                            and op._BaseOperator__init_kwargs[attrname] == var
+                        ):
+                            return False
+
+                # If client_defaults doesn't have this field (matches schema default),
+                # exclude for optimization even if in default_args
+                # Example: default_args={"depends_on_past": False}, schema default=False
                 return True
         optional_fields = cls.get_operator_optional_fields_from_schema()
         if var is None:
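
For reference, the exclusion order the hunk above implements, paraphrased as a
standalone function (a hypothetical helper for illustration only; the real
checks are methods on SerializedBaseOperator):

    def should_exclude(attrname, var, schema_defaults, client_defaults,
                       const_fields, explicitly_set):
        # Const fields are always dropped when False, regardless of client_defaults.
        if attrname in const_fields and var is False:
            return True
        if attrname in schema_defaults and schema_defaults[attrname] == var:
            # Matches both the schema default and the client default: safe to drop.
            if client_defaults.get(attrname) == var:
                return True
            # The user explicitly set a value that overrides a differing
            # client default (e.g. retries=0 vs client default 3): keep it.
            if attrname in client_defaults and explicitly_set:
                return False
            # Schema default with no competing client default: drop.
            return True
        return False

    # default_args={"retries": 0}, schema default 0, client default 3 -> kept:
    assert not should_exclude("retries", 0, {"retries": 0}, {"retries": 3},
                              set(), explicitly_set=True)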
diff --git a/airflow-core/tests/unit/serialization/test_dag_serialization.py b/airflow-core/tests/unit/serialization/test_dag_serialization.py
index c3e4e26ce27..3a8219d02e4 100644
--- a/airflow-core/tests/unit/serialization/test_dag_serialization.py
+++ b/airflow-core/tests/unit/serialization/test_dag_serialization.py
@@ -104,40 +104,41 @@ if TYPE_CHECKING:
     from airflow.sdk.definitions.context import Context
 
 
[email protected]
-def operator_defaults(overrides):
[email protected]
+def operator_defaults(monkeypatch):
     """
-    Temporarily patches OPERATOR_DEFAULTS, restoring original values after context exit.
+    Fixture that provides a context manager to temporarily patch OPERATOR_DEFAULTS.
 
-    Example:
-        with operator_defaults({"retries": 2, "retry_delay": 200.0}):
-            # Test code with modified operator defaults
+    Usage:
+        def test_something(operator_defaults):
+            with operator_defaults({"retries": 2, "retry_delay": 200.0}):
+                # Test code with modified operator defaults
     """
+    import airflow.sdk.definitions._internal.abstractoperator as abstract_op_module
     from airflow.sdk.bases.operator import OPERATOR_DEFAULTS
+    from airflow.serialization.serialized_objects import SerializedBaseOperator
 
-    original_values = {}
-    try:
-        # Store original values and apply overrides
+    @contextlib.contextmanager
+    def _operator_defaults(overrides):
+        # Patch OPERATOR_DEFAULTS
         for key, value in overrides.items():
-            original_values[key] = OPERATOR_DEFAULTS.get(key)
-            OPERATOR_DEFAULTS[key] = value
+            monkeypatch.setitem(OPERATOR_DEFAULTS, key, value)
+
+            # Patch module-level constants
+            const_name = f"DEFAULT_{key.upper()}"
+            if hasattr(abstract_op_module, const_name):
+                monkeypatch.setattr(abstract_op_module, const_name, value)
 
         # Clear the cache to ensure fresh generation
         SerializedBaseOperator.generate_client_defaults.cache_clear()
 
-        yield
-    finally:
-        # Cleanup: restore original values
-        for key, original_value in original_values.items():
-            if original_value is None and key in OPERATOR_DEFAULTS:
-                # Key didn't exist originally, remove it
-                del OPERATOR_DEFAULTS[key]
-            else:
-                # Restore original value
-                OPERATOR_DEFAULTS[key] = original_value
+        try:
+            yield
+        finally:
+            # Clear cache again to restore normal behavior
+            SerializedBaseOperator.generate_client_defaults.cache_clear()
 
-        # Clear cache again to restore normal behavior
-        SerializedBaseOperator.generate_client_defaults.cache_clear()
+    return _operator_defaults
 
 
 AIRFLOW_REPO_ROOT_PATH = Path(airflow.__file__).parents[3]
@@ -4107,77 +4108,104 @@ class TestDeserializationDefaultsResolution:
-        result = SerializedBaseOperator._apply_defaults_to_encoded_op(encoded_op, None)
         assert result == encoded_op
 
-    @operator_defaults({"retries": 2})
-    def test_multiple_tasks_share_client_defaults(self):
+    def test_multiple_tasks_share_client_defaults(self, operator_defaults):
         """Test that multiple tasks can share the same client_defaults when 
there are actually non-default values."""
-        with DAG(dag_id="test_dag") as dag:
-            BashOperator(task_id="task1", bash_command="echo 1")
-            BashOperator(task_id="task2", bash_command="echo 2")
+        with operator_defaults({"retries": 2}):
+            with DAG(dag_id="test_dag") as dag:
+                BashOperator(task_id="task1", bash_command="echo 1")
+                BashOperator(task_id="task2", bash_command="echo 2")
 
-        serialized = SerializedDAG.to_dict(dag)
+            serialized = SerializedDAG.to_dict(dag)
 
-        # Should have one client_defaults section for all tasks
-        assert "client_defaults" in serialized
-        assert "tasks" in serialized["client_defaults"]
+            # Should have one client_defaults section for all tasks
+            assert "client_defaults" in serialized
+            assert "tasks" in serialized["client_defaults"]
 
-        # All tasks should benefit from the same client_defaults
-        client_defaults = serialized["client_defaults"]["tasks"]
+            # All tasks should benefit from the same client_defaults
+            client_defaults = serialized["client_defaults"]["tasks"]
 
-        # Deserialize and check both tasks get the defaults
-        deserialized_dag = SerializedDAG.from_dict(serialized)
-        deserialized_task1 = deserialized_dag.get_task("task1")
-        deserialized_task2 = deserialized_dag.get_task("task2")
+            # Deserialize and check both tasks get the defaults
+            deserialized_dag = SerializedDAG.from_dict(serialized)
+            deserialized_task1 = deserialized_dag.get_task("task1")
+            deserialized_task2 = deserialized_dag.get_task("task2")
 
-        # Both tasks should have retries=2 from client_defaults
-        assert deserialized_task1.retries == 2
-        assert deserialized_task2.retries == 2
+            # Both tasks should have retries=2 from client_defaults
+            assert deserialized_task1.retries == 2
+            assert deserialized_task2.retries == 2
 
-        # Both tasks should have the same default values from client_defaults
-        for field in client_defaults:
-            if hasattr(deserialized_task1, field) and hasattr(deserialized_task2, field):
-                value1 = getattr(deserialized_task1, field)
-                value2 = getattr(deserialized_task2, field)
-                assert value1 == value2, f"Tasks have different values for {field}: {value1} vs {value2}"
+            # Both tasks should have the same default values from client_defaults
+            for field in client_defaults:
+                if hasattr(deserialized_task1, field) and hasattr(deserialized_task2, field):
+                    value1 = getattr(deserialized_task1, field)
+                    value2 = getattr(deserialized_task2, field)
+                    assert value1 == value2, f"Tasks have different values for {field}: {value1} vs {value2}"
+
+    def test_default_args_when_equal_to_schema_defaults(self, operator_defaults):
+        """Test that explicitly set values matching schema defaults are preserved when client_defaults differ."""
+        with operator_defaults({"retries": 3}):
+            with DAG(dag_id="test_explicit_schema_default", 
default_args={"retries": 0}) as dag:
+                BashOperator(task_id="task1", bash_command="echo 1")
+                BashOperator(task_id="task2", bash_command="echo 1", retries=2)
+
+            serialized = SerializedDAG.to_dict(dag)
+
+            # verify client_defaults has retries=3
+            assert "client_defaults" in serialized
+            assert "tasks" in serialized["client_defaults"]
+            client_defaults = serialized["client_defaults"]["tasks"]
+            assert client_defaults["retries"] == 3
+
+            task1_data = serialized["dag"]["tasks"][0]["__var"]
+            assert task1_data.get("retries", -1) == 0
+
+            task2_data = serialized["dag"]["tasks"][1]["__var"]
+            assert task2_data.get("retries", -1) == 2
+
+            deserialized_task1 = SerializedDAG.from_dict(serialized).get_task("task1")
+            assert deserialized_task1.retries == 0
+
+            deserialized_task2 = SerializedDAG.from_dict(serialized).get_task("task2")
+            assert deserialized_task2.retries == 2
 
 
 class TestMappedOperatorSerializationAndClientDefaults:
     """Test MappedOperator serialization with client defaults and callback 
properties."""
 
-    @operator_defaults({"retry_delay": 200.0})
-    def test_mapped_operator_client_defaults_application(self):
+    def test_mapped_operator_client_defaults_application(self, operator_defaults):
         """Test that client_defaults are correctly applied to MappedOperator during deserialization."""
-        with DAG(dag_id="test_mapped_dag") as dag:
-            # Create a mapped operator
-            BashOperator.partial(
-                task_id="mapped_task",
-                retries=5,  # Override default
-            ).expand(bash_command=["echo 1", "echo 2", "echo 3"])
-
-        # Serialize the DAG
-        serialized_dag = SerializedDAG.to_dict(dag)
+        with operator_defaults({"retry_delay": 200.0}):
+            with DAG(dag_id="test_mapped_dag") as dag:
+                # Create a mapped operator
+                BashOperator.partial(
+                    task_id="mapped_task",
+                    retries=5,  # Override default
+                ).expand(bash_command=["echo 1", "echo 2", "echo 3"])
+
+            # Serialize the DAG
+            serialized_dag = SerializedDAG.to_dict(dag)
 
-        # Should have client_defaults section
-        assert "client_defaults" in serialized_dag
-        assert "tasks" in serialized_dag["client_defaults"]
+            # Should have client_defaults section
+            assert "client_defaults" in serialized_dag
+            assert "tasks" in serialized_dag["client_defaults"]
 
-        # Deserialize and check that client_defaults are applied
-        deserialized_dag = SerializedDAG.from_dict(serialized_dag)
-        deserialized_task = deserialized_dag.get_task("mapped_task")
+            # Deserialize and check that client_defaults are applied
+            deserialized_dag = SerializedDAG.from_dict(serialized_dag)
+            deserialized_task = deserialized_dag.get_task("mapped_task")
 
-        # Verify it's still a MappedOperator
-        from airflow.models.mappedoperator import MappedOperator as SchedulerMappedOperator
+            # Verify it's still a MappedOperator
+            from airflow.models.mappedoperator import MappedOperator as SchedulerMappedOperator
 
-        assert isinstance(deserialized_task, SchedulerMappedOperator)
+            assert isinstance(deserialized_task, SchedulerMappedOperator)
 
-        # Check that client_defaults values are applied (e.g., retry_delay from client_defaults)
-        client_defaults = serialized_dag["client_defaults"]["tasks"]
-        if "retry_delay" in client_defaults:
-            # If retry_delay wasn't explicitly set, it should come from client_defaults
-            # Since we can't easily convert timedelta back, check the serialized format
-            assert hasattr(deserialized_task, "retry_delay")
+            # Check that client_defaults values are applied (e.g., retry_delay from client_defaults)
+            client_defaults = serialized_dag["client_defaults"]["tasks"]
+            if "retry_delay" in client_defaults:
+                # If retry_delay wasn't explicitly set, it should come from client_defaults
+                # Since we can't easily convert timedelta back, check the serialized format
+                assert hasattr(deserialized_task, "retry_delay")
 
-        # Explicit values should override client_defaults
-        assert deserialized_task.retries == 5  # Explicitly set value
+            # Explicit values should override client_defaults
+            assert deserialized_task.retries == 5  # Explicitly set value
 
     @pytest.mark.parametrize(
         ("task_config", "dag_id", "task_id", "non_default_fields"),
@@ -4208,45 +4236,45 @@ class TestMappedOperatorSerializationAndClientDefaults:
             ),
         ],
     )
-    @operator_defaults({"retry_delay": 200.0})
     def test_mapped_operator_client_defaults_optimization(
-        self, task_config, dag_id, task_id, non_default_fields
+        self, task_config, dag_id, task_id, non_default_fields, operator_defaults
     ):
         """Test that MappedOperator serialization optimizes using client 
defaults."""
-        with DAG(dag_id=dag_id) as dag:
-            # Create mapped operator with specified configuration
-            BashOperator.partial(
-                task_id=task_id,
-                **task_config,
-            ).expand(bash_command=["echo 1", "echo 2", "echo 3"])
+        with operator_defaults({"retry_delay": 200.0}):
+            with DAG(dag_id=dag_id) as dag:
+                # Create mapped operator with specified configuration
+                BashOperator.partial(
+                    task_id=task_id,
+                    **task_config,
+                ).expand(bash_command=["echo 1", "echo 2", "echo 3"])
 
-        serialized_dag = SerializedDAG.to_dict(dag)
-        mapped_task_serialized = serialized_dag["dag"]["tasks"][0]["__var"]
+            serialized_dag = SerializedDAG.to_dict(dag)
+            mapped_task_serialized = serialized_dag["dag"]["tasks"][0]["__var"]
 
-        assert mapped_task_serialized is not None
-        assert mapped_task_serialized.get("_is_mapped") is True
+            assert mapped_task_serialized is not None
+            assert mapped_task_serialized.get("_is_mapped") is True
 
-        # Check optimization behavior
-        client_defaults = serialized_dag["client_defaults"]["tasks"]
-        partial_kwargs = mapped_task_serialized["partial_kwargs"]
+            # Check optimization behavior
+            client_defaults = serialized_dag["client_defaults"]["tasks"]
+            partial_kwargs = mapped_task_serialized["partial_kwargs"]
 
-        # Check that all fields are optimized correctly
-        for field, default_value in client_defaults.items():
-            if field in non_default_fields:
-                # Non-default fields should be present in partial_kwargs
-                assert field in partial_kwargs, (
-                    f"Field '{field}' should be in partial_kwargs as it's 
non-default"
-                )
-                # And have different values than defaults
-                assert partial_kwargs[field] != default_value, (
-                    f"Field '{field}' should have non-default value"
-                )
-            else:
-                # Default fields should either not be present or have different values if present
-                if field in partial_kwargs:
+            # Check that all fields are optimized correctly
+            for field, default_value in client_defaults.items():
+                if field in non_default_fields:
+                    # Non-default fields should be present in partial_kwargs
+                    assert field in partial_kwargs, (
+                        f"Field '{field}' should be in partial_kwargs as it's 
non-default"
+                    )
+                    # And have different values than defaults
                     assert partial_kwargs[field] != default_value, (
-                        f"Field '{field}' with default value should be 
optimized out"
+                        f"Field '{field}' should have non-default value"
                     )
+                else:
+                    # Default fields should either not be present or have different values if present
+                    if field in partial_kwargs:
+                        assert partial_kwargs[field] != default_value, (
+                            f"Field '{field}' with default value should be 
optimized out"
+                        )
 
     def test_mapped_operator_expand_input_preservation(self):
         """Test that expand_input is correctly preserved during 
serialization."""

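The fixture-returns-a-context-manager pattern adopted above, reduced to a
standalone sketch (hypothetical names; SETTINGS stands in for
OPERATOR_DEFAULTS):

    import contextlib

    import pytest

    SETTINGS = {"retries": 0}

    @pytest.fixture
    def override_settings(monkeypatch):
        @contextlib.contextmanager
        def _override(overrides):
            # monkeypatch.setitem registers an undo that runs at test
            # teardown, so original values come back even if the test fails.
            for key, value in overrides.items():
                monkeypatch.setitem(SETTINGS, key, value)
            yield

        return _override

    def test_override_is_scoped(override_settings):
        with override_settings({"retries": 2}):
            assert SETTINGS["retries"] == 2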