This is an automated email from the ASF dual-hosted git repository.
amoghdesai pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new c03fc794300 Respect default_args in DAG when it's set to a "falsy" value (#57853)
c03fc794300 is described below
commit c03fc7943009b02072ad275dd31e54f927bff889
Author: Amogh Desai <[email protected]>
AuthorDate: Mon Nov 17 19:12:26 2025 +0530
Respect default_args in DAG when it's set to a "falsy" value (#57853)
---
.../airflow/serialization/serialized_objects.py | 49 ++++-
.../unit/serialization/test_dag_serialization.py | 238 ++++++++++++---------
2 files changed, 181 insertions(+), 106 deletions(-)
diff --git a/airflow-core/src/airflow/serialization/serialized_objects.py b/airflow-core/src/airflow/serialization/serialized_objects.py
index 9a2a8efe96e..d2f13c1dc92 100644
--- a/airflow-core/src/airflow/serialization/serialized_objects.py
+++ b/airflow-core/src/airflow/serialization/serialized_objects.py
@@ -1241,6 +1241,8 @@ class SerializedBaseOperator(DAGNode, BaseSerialization):
_json_schema: ClassVar[Validator] = lazy_object_proxy.Proxy(load_dag_schema)
+ _const_fields: ClassVar[set[str] | None] = None
+
_can_skip_downstream: bool
_is_empty: bool
_needs_expansion: bool
@@ -1711,6 +1713,22 @@ class SerializedBaseOperator(DAGNode, BaseSerialization):
# Bypass set_upstream etc here - it does more than we want
dag.task_dict[task_id].upstream_task_ids.add(task.task_id)
+ @classmethod
+ def get_operator_const_fields(cls) -> set[str]:
+ """Get the set of operator fields that are marked as const in the JSON
schema."""
+ if (schema_loader := cls._json_schema) is None:
+ return set()
+
+ schema_data = schema_loader.schema
+ operator_def = schema_data.get("definitions", {}).get("operator", {})
+ properties = operator_def.get("properties", {})
+
+ return {
+ field_name
+ for field_name, field_def in properties.items()
+ if isinstance(field_def, dict) and field_def.get("const")
+ }
+
@classmethod
@lru_cache(maxsize=1) # Only one type: "operator"
def get_operator_optional_fields_from_schema(cls) -> set[str]:
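For readers unfamiliar with the schema shape this method walks: a minimal sketch, assuming an operator definition whose const properties look like the snippet below (the property names are illustrative, not copied from the real schema):

    # Minimal sketch of the schema shape get_operator_const_fields walks.
    # The property names below are illustrative assumptions, not the real schema.
    schema_data = {
        "definitions": {
            "operator": {
                "properties": {
                    "_is_mapped": {"const": True},   # const field: candidate for exclusion
                    "retries": {"type": "integer"},  # ordinary field: ignored
                }
            }
        }
    }
    properties = schema_data["definitions"]["operator"]["properties"]
    const_fields = {
        name
        for name, field_def in properties.items()
        if isinstance(field_def, dict) and field_def.get("const")
    }
    assert const_fields == {"_is_mapped"}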
@@ -1866,10 +1884,39 @@ class SerializedBaseOperator(DAGNode, BaseSerialization):
# Check if value matches client_defaults (hierarchical defaults optimization)
if cls._matches_client_defaults(var, attrname):
return True
- schema_defaults = cls.get_schema_defaults("operator")
+ # Const fields should always be excluded when False, regardless of client_defaults
+ # Use a class-level cache for optimization
+ if cls._const_fields is None:
+ cls._const_fields = cls.get_operator_const_fields()
+ if attrname in cls._const_fields and var is False:
+ return True
+
+ schema_defaults = cls.get_schema_defaults("operator")
if attrname in schema_defaults:
if schema_defaults[attrname] == var:
+ # If it also matches client_defaults, exclude (optimization)
+ client_defaults = cls.generate_client_defaults()
+ if attrname in client_defaults:
+ if client_defaults[attrname] == var:
+ return True
+ # If client_defaults differs, preserve explicit override from user
+ # Example: default_args={"retries": 0}, schema default=0, client_defaults={"retries": 3}
+ if client_defaults[attrname] != var:
+ if op.has_dag():
+ dag = op.dag
+ if dag and attrname in dag.default_args and dag.default_args[attrname] == var:
+ return False
+ if (
+ hasattr(op, "_BaseOperator__init_kwargs")
+ and attrname in op._BaseOperator__init_kwargs
+ and op._BaseOperator__init_kwargs[attrname] == var
+ ):
+ return False
+
+ # If client_defaults doesn't have this field (matches schema default),
+ # exclude for optimization even if in default_args
+ # Example: default_args={"depends_on_past": False}, schema default=False
return True
optional_fields = cls.get_operator_optional_fields_from_schema()
if var is None:
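Taken together, the hunk above encodes a small decision table. A sketch restating it in isolation (the helper name and the concrete defaults, schema retries=0 and client retries=3, are assumptions for illustration, not the actual method):

    # Hypothetical restatement of the exclusion rules above, not the actual method.
    def should_exclude(attrname, var, schema_defaults, client_defaults, dag_default_args):
        if schema_defaults.get(attrname) != var:
            return False  # value differs from the schema default: always serialize
        if client_defaults.get(attrname) == var:
            return True  # matches both layers: safe to drop
        if attrname in client_defaults:
            # client_defaults would resurrect a different value on deserialization,
            # so an explicit user value (e.g. via default_args) must be kept
            return attrname not in dag_default_args
        return True  # client_defaults is silent, so the schema default round-trips

    # default_args={"retries": 0} with client_defaults={"retries": 3}: kept
    assert not should_exclude("retries", 0, {"retries": 0}, {"retries": 3}, {"retries": 0})
    # default_args={"depends_on_past": False}, schema default=False: dropped
    assert should_exclude("depends_on_past", False, {"depends_on_past": False}, {}, {"depends_on_past": False})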
diff --git a/airflow-core/tests/unit/serialization/test_dag_serialization.py b/airflow-core/tests/unit/serialization/test_dag_serialization.py
index c3e4e26ce27..3a8219d02e4 100644
--- a/airflow-core/tests/unit/serialization/test_dag_serialization.py
+++ b/airflow-core/tests/unit/serialization/test_dag_serialization.py
@@ -104,40 +104,41 @@ if TYPE_CHECKING:
from airflow.sdk.definitions.context import Context
[email protected]
-def operator_defaults(overrides):
[email protected]
+def operator_defaults(monkeypatch):
"""
- Temporarily patches OPERATOR_DEFAULTS, restoring original values after context exit.
+ Fixture that provides a context manager to temporarily patch OPERATOR_DEFAULTS.
- Example:
- with operator_defaults({"retries": 2, "retry_delay": 200.0}):
- # Test code with modified operator defaults
+ Usage:
+ def test_something(operator_defaults):
+ with operator_defaults({"retries": 2, "retry_delay": 200.0}):
+ # Test code with modified operator defaults
"""
+ import airflow.sdk.definitions._internal.abstractoperator as abstract_op_module
from airflow.sdk.bases.operator import OPERATOR_DEFAULTS
+ from airflow.serialization.serialized_objects import SerializedBaseOperator
- original_values = {}
- try:
- # Store original values and apply overrides
+ @contextlib.contextmanager
+ def _operator_defaults(overrides):
+ # Patch OPERATOR_DEFAULTS
for key, value in overrides.items():
- original_values[key] = OPERATOR_DEFAULTS.get(key)
- OPERATOR_DEFAULTS[key] = value
+ monkeypatch.setitem(OPERATOR_DEFAULTS, key, value)
+
+ # Patch module-level constants
+ const_name = f"DEFAULT_{key.upper()}"
+ if hasattr(abstract_op_module, const_name):
+ monkeypatch.setattr(abstract_op_module, const_name, value)
# Clear the cache to ensure fresh generation
SerializedBaseOperator.generate_client_defaults.cache_clear()
- yield
- finally:
- # Cleanup: restore original values
- for key, original_value in original_values.items():
- if original_value is None and key in OPERATOR_DEFAULTS:
- # Key didn't exist originally, remove it
- del OPERATOR_DEFAULTS[key]
- else:
- # Restore original value
- OPERATOR_DEFAULTS[key] = original_value
+ try:
+ yield
+ finally:
+ # Clear cache again to restore normal behavior
+ SerializedBaseOperator.generate_client_defaults.cache_clear()
- # Clear cache again to restore normal behavior
- SerializedBaseOperator.generate_client_defaults.cache_clear()
+ return _operator_defaults
AIRFLOW_REPO_ROOT_PATH = Path(airflow.__file__).parents[3]
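One usage note on the reworked fixture: monkeypatch undoes setitem/setattr automatically at test teardown, which is what lets the rewrite drop the old manual restore loop. A sketch of a test consuming it (the test name and override values are illustrative):

    # Illustrative consumer of the fixture; name and values are made up.
    def test_retries_override(operator_defaults):
        with operator_defaults({"retries": 2, "retry_delay": 200.0}):
            # OPERATOR_DEFAULTS carries the overrides inside this block
            ...
        # monkeypatch restores the original values when the test finishes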
@@ -4107,77 +4108,104 @@ class TestDeserializationDefaultsResolution:
result = SerializedBaseOperator._apply_defaults_to_encoded_op(encoded_op, None)
assert result == encoded_op
- @operator_defaults({"retries": 2})
- def test_multiple_tasks_share_client_defaults(self):
+ def test_multiple_tasks_share_client_defaults(self, operator_defaults):
"""Test that multiple tasks can share the same client_defaults when
there are actually non-default values."""
- with DAG(dag_id="test_dag") as dag:
- BashOperator(task_id="task1", bash_command="echo 1")
- BashOperator(task_id="task2", bash_command="echo 2")
+ with operator_defaults({"retries": 2}):
+ with DAG(dag_id="test_dag") as dag:
+ BashOperator(task_id="task1", bash_command="echo 1")
+ BashOperator(task_id="task2", bash_command="echo 2")
- serialized = SerializedDAG.to_dict(dag)
+ serialized = SerializedDAG.to_dict(dag)
- # Should have one client_defaults section for all tasks
- assert "client_defaults" in serialized
- assert "tasks" in serialized["client_defaults"]
+ # Should have one client_defaults section for all tasks
+ assert "client_defaults" in serialized
+ assert "tasks" in serialized["client_defaults"]
- # All tasks should benefit from the same client_defaults
- client_defaults = serialized["client_defaults"]["tasks"]
+ # All tasks should benefit from the same client_defaults
+ client_defaults = serialized["client_defaults"]["tasks"]
- # Deserialize and check both tasks get the defaults
- deserialized_dag = SerializedDAG.from_dict(serialized)
- deserialized_task1 = deserialized_dag.get_task("task1")
- deserialized_task2 = deserialized_dag.get_task("task2")
+ # Deserialize and check both tasks get the defaults
+ deserialized_dag = SerializedDAG.from_dict(serialized)
+ deserialized_task1 = deserialized_dag.get_task("task1")
+ deserialized_task2 = deserialized_dag.get_task("task2")
- # Both tasks should have retries=2 from client_defaults
- assert deserialized_task1.retries == 2
- assert deserialized_task2.retries == 2
+ # Both tasks should have retries=2 from client_defaults
+ assert deserialized_task1.retries == 2
+ assert deserialized_task2.retries == 2
- # Both tasks should have the same default values from client_defaults
- for field in client_defaults:
- if hasattr(deserialized_task1, field) and hasattr(deserialized_task2, field):
- value1 = getattr(deserialized_task1, field)
- value2 = getattr(deserialized_task2, field)
- assert value1 == value2, f"Tasks have different values for {field}: {value1} vs {value2}"
+ # Both tasks should have the same default values from client_defaults
+ for field in client_defaults:
+ if hasattr(deserialized_task1, field) and hasattr(deserialized_task2, field):
+ value1 = getattr(deserialized_task1, field)
+ value2 = getattr(deserialized_task2, field)
+ assert value1 == value2, f"Tasks have different values for {field}: {value1} vs {value2}"
+
+ def test_default_args_when_equal_to_schema_defaults(self, operator_defaults):
+ """Test that explicitly set values matching schema defaults are preserved when client_defaults differ."""
+ with operator_defaults({"retries": 3}):
+ with DAG(dag_id="test_explicit_schema_default", default_args={"retries": 0}) as dag:
+ BashOperator(task_id="task1", bash_command="echo 1")
+ BashOperator(task_id="task2", bash_command="echo 1", retries=2)
+
+ serialized = SerializedDAG.to_dict(dag)
+
+ # verify client_defaults has retries=3
+ assert "client_defaults" in serialized
+ assert "tasks" in serialized["client_defaults"]
+ client_defaults = serialized["client_defaults"]["tasks"]
+ assert client_defaults["retries"] == 3
+
+ task1_data = serialized["dag"]["tasks"][0]["__var"]
+ assert task1_data.get("retries", -1) == 0
+
+ task2_data = serialized["dag"]["tasks"][1]["__var"]
+ assert task2_data.get("retries", -1) == 2
+
+ deserialized_task1 = SerializedDAG.from_dict(serialized).get_task("task1")
+ assert deserialized_task1.retries == 0
+
+ deserialized_task2 = SerializedDAG.from_dict(serialized).get_task("task2")
+ assert deserialized_task2.retries == 2
class TestMappedOperatorSerializationAndClientDefaults:
"""Test MappedOperator serialization with client defaults and callback
properties."""
- @operator_defaults({"retry_delay": 200.0})
- def test_mapped_operator_client_defaults_application(self):
+ def test_mapped_operator_client_defaults_application(self, operator_defaults):
"""Test that client_defaults are correctly applied to MappedOperator
during deserialization."""
- with DAG(dag_id="test_mapped_dag") as dag:
- # Create a mapped operator
- BashOperator.partial(
- task_id="mapped_task",
- retries=5, # Override default
- ).expand(bash_command=["echo 1", "echo 2", "echo 3"])
-
- # Serialize the DAG
- serialized_dag = SerializedDAG.to_dict(dag)
+ with operator_defaults({"retry_delay": 200.0}):
+ with DAG(dag_id="test_mapped_dag") as dag:
+ # Create a mapped operator
+ BashOperator.partial(
+ task_id="mapped_task",
+ retries=5, # Override default
+ ).expand(bash_command=["echo 1", "echo 2", "echo 3"])
+
+ # Serialize the DAG
+ serialized_dag = SerializedDAG.to_dict(dag)
- # Should have client_defaults section
- assert "client_defaults" in serialized_dag
- assert "tasks" in serialized_dag["client_defaults"]
+ # Should have client_defaults section
+ assert "client_defaults" in serialized_dag
+ assert "tasks" in serialized_dag["client_defaults"]
- # Deserialize and check that client_defaults are applied
- deserialized_dag = SerializedDAG.from_dict(serialized_dag)
- deserialized_task = deserialized_dag.get_task("mapped_task")
+ # Deserialize and check that client_defaults are applied
+ deserialized_dag = SerializedDAG.from_dict(serialized_dag)
+ deserialized_task = deserialized_dag.get_task("mapped_task")
- # Verify it's still a MappedOperator
- from airflow.models.mappedoperator import MappedOperator as SchedulerMappedOperator
+ # Verify it's still a MappedOperator
+ from airflow.models.mappedoperator import MappedOperator as SchedulerMappedOperator
- assert isinstance(deserialized_task, SchedulerMappedOperator)
+ assert isinstance(deserialized_task, SchedulerMappedOperator)
- # Check that client_defaults values are applied (e.g., retry_delay from client_defaults)
- client_defaults = serialized_dag["client_defaults"]["tasks"]
- if "retry_delay" in client_defaults:
- # If retry_delay wasn't explicitly set, it should come from client_defaults
- # Since we can't easily convert timedelta back, check the serialized format
- assert hasattr(deserialized_task, "retry_delay")
+ # Check that client_defaults values are applied (e.g., retry_delay from client_defaults)
+ client_defaults = serialized_dag["client_defaults"]["tasks"]
+ if "retry_delay" in client_defaults:
+ # If retry_delay wasn't explicitly set, it should come from client_defaults
+ # Since we can't easily convert timedelta back, check the serialized format
+ assert hasattr(deserialized_task, "retry_delay")
- # Explicit values should override client_defaults
- assert deserialized_task.retries == 5 # Explicitly set value
+ # Explicit values should override client_defaults
+ assert deserialized_task.retries == 5 # Explicitly set value
@pytest.mark.parametrize(
("task_config", "dag_id", "task_id", "non_default_fields"),
@@ -4208,45 +4236,45 @@ class TestMappedOperatorSerializationAndClientDefaults:
),
],
)
- @operator_defaults({"retry_delay": 200.0})
def test_mapped_operator_client_defaults_optimization(
- self, task_config, dag_id, task_id, non_default_fields
+ self, task_config, dag_id, task_id, non_default_fields, operator_defaults
):
"""Test that MappedOperator serialization optimizes using client
defaults."""
- with DAG(dag_id=dag_id) as dag:
- # Create mapped operator with specified configuration
- BashOperator.partial(
- task_id=task_id,
- **task_config,
- ).expand(bash_command=["echo 1", "echo 2", "echo 3"])
+ with operator_defaults({"retry_delay": 200.0}):
+ with DAG(dag_id=dag_id) as dag:
+ # Create mapped operator with specified configuration
+ BashOperator.partial(
+ task_id=task_id,
+ **task_config,
+ ).expand(bash_command=["echo 1", "echo 2", "echo 3"])
- serialized_dag = SerializedDAG.to_dict(dag)
- mapped_task_serialized = serialized_dag["dag"]["tasks"][0]["__var"]
+ serialized_dag = SerializedDAG.to_dict(dag)
+ mapped_task_serialized = serialized_dag["dag"]["tasks"][0]["__var"]
- assert mapped_task_serialized is not None
- assert mapped_task_serialized.get("_is_mapped") is True
+ assert mapped_task_serialized is not None
+ assert mapped_task_serialized.get("_is_mapped") is True
- # Check optimization behavior
- client_defaults = serialized_dag["client_defaults"]["tasks"]
- partial_kwargs = mapped_task_serialized["partial_kwargs"]
+ # Check optimization behavior
+ client_defaults = serialized_dag["client_defaults"]["tasks"]
+ partial_kwargs = mapped_task_serialized["partial_kwargs"]
- # Check that all fields are optimized correctly
- for field, default_value in client_defaults.items():
- if field in non_default_fields:
- # Non-default fields should be present in partial_kwargs
- assert field in partial_kwargs, (
- f"Field '{field}' should be in partial_kwargs as it's
non-default"
- )
- # And have different values than defaults
- assert partial_kwargs[field] != default_value, (
- f"Field '{field}' should have non-default value"
- )
- else:
- # Default fields should either not be present or have different values if present
- if field in partial_kwargs:
+ # Check that all fields are optimized correctly
+ for field, default_value in client_defaults.items():
+ if field in non_default_fields:
+ # Non-default fields should be present in partial_kwargs
+ assert field in partial_kwargs, (
+ f"Field '{field}' should be in partial_kwargs as it's
non-default"
+ )
+ # And have different values than defaults
assert partial_kwargs[field] != default_value, (
- f"Field '{field}' with default value should be
optimized out"
+ f"Field '{field}' should have non-default value"
)
+ else:
+ # Default fields should either not be present or have different values if present
+ if field in partial_kwargs:
+ assert partial_kwargs[field] != default_value, (
+ f"Field '{field}' with default value should be
optimized out"
+ )
def test_mapped_operator_expand_input_preservation(self):
"""Test that expand_input is correctly preserved during
serialization."""