This is an automated email from the ASF dual-hosted git repository. vatsrahul1001 pushed a commit to branch backport-322-63871 in repository https://gitbox.apache.org/repos/asf/airflow.git
commit 99d5786af1e5196b1cc37f055c971f64acdc1dc8 Author: Jeongwoo Do <[email protected]> AuthorDate: Fri May 15 03:17:24 2026 +0900 fix serialize_template_field handling callable value in dict (#63871) * Fix non-deterministic serialization of non-jsonable objects in template fields * fix logic * fix logic * fix logic * fix logic (cherry picked from commit 90051561e721d24834f79d5e1335708c671a9be9) --- airflow-core/src/airflow/serialization/helpers.py | 123 ++-- .../tests/unit/dags/test_dag_decorator_version.py | 63 +++ .../tests/unit/models/test_renderedtifields.py | 4 +- .../unit/serialization/test_dag_serialization.py | 39 ++ .../tests/unit/serialization/test_helpers.py | 630 +++++++++++++++++++++ .../src/airflow/sdk/execution_time/task_runner.py | 134 ++--- .../task_sdk/execution_time/test_task_runner.py | 13 +- 7 files changed, 879 insertions(+), 127 deletions(-) diff --git a/airflow-core/src/airflow/serialization/helpers.py b/airflow-core/src/airflow/serialization/helpers.py index e2c8069a116..83b57d1c7cc 100644 --- a/airflow-core/src/airflow/serialization/helpers.py +++ b/airflow-core/src/airflow/serialization/helpers.py @@ -19,87 +19,96 @@ from __future__ import annotations import contextlib +import inspect from typing import TYPE_CHECKING, Any from airflow._shared.module_loading import qualname from airflow._shared.secrets_masker import redact from airflow._shared.template_rendering import truncate_rendered_value from airflow.configuration import conf -from airflow.settings import json if TYPE_CHECKING: from airflow.partition_mappers.base import PartitionMapper from airflow.timetables.base import Timetable as CoreTimetable -def serialize_template_field(template_field: Any, name: str) -> str | dict | list | int | float: +def serialize_template_field(template_field: Any, name: str) -> str | dict | list | int | float | bool | None: """ Return a serializable representation of the templated field. 
- If ``templated_field`` is provided via a callable then - return the following serialized value: ``<callable full_qualified_name>`` + The walk has two responsibilities: - If ``templated_field`` contains a class or instance that requires recursive - templating, store them as strings. Otherwise simply return the field as-is. + 1. **Make the template_field JSON-encodable** — every container is rebuilt + with primitive leaves (str/int/float/bool/None), tuples and sets are + flattened to lists, and unsupported objects fall through to ``str()`` + so ``json.dumps`` never raises on the result. + 2. **Keep the output deterministic across parses** — callables are replaced + with their qualified name (never the default ``<function ... at 0x...>`` + repr), dicts are key-sorted, and (frozen)sets are sorted by element so + the same input always produces the same string. """ - def is_jsonable(x): - try: - json.dumps(x) - except (TypeError, OverflowError): - return False - else: - return True - - def translate_tuples_to_lists(obj: Any): - """Recursively convert tuples to lists.""" - if isinstance(obj, tuple): - return [translate_tuples_to_lists(item) for item in obj] - if isinstance(obj, list): - return [translate_tuples_to_lists(item) for item in obj] - if isinstance(obj, dict): - return {key: translate_tuples_to_lists(value) for key, value in obj.items()} - return obj + def normalize_dict_key(key) -> str: + """Normalize a dict key to a serialized string type.""" + # Serialized template_field keys must all be strings, not a mix of types, so that + # downstream json.dumps(..., sort_keys=True) does not raise on mixed-type keys. 
+ return str(serialize_object(key)) + + def serialize_object(obj): + """Recursively rewrite ``obj`` into a JSON-encodable, hash-stable structure.""" + if obj is None or isinstance(obj, (str, int, float, bool)): + return obj - def sort_dict_recursively(obj: Any) -> Any: - """Recursively sort dictionaries to ensure consistent ordering.""" if isinstance(obj, dict): - return {k: sort_dict_recursively(v) for k, v in sorted(obj.items())} - if isinstance(obj, list): - return [sort_dict_recursively(item) for item in obj] - if isinstance(obj, tuple): - return tuple(sort_dict_recursively(item) for item in obj) - return obj + # Serialize keys/values first so each key is a string and the output is hash-stable, + # then sort by the serialized key to prevent hash inconsistencies when dict ordering varies. + serialized_pairs = [(normalize_dict_key(k), serialize_object(v)) for k, v in obj.items()] + return dict(sorted(serialized_pairs, key=lambda kv: kv[0])) + + if isinstance(obj, (list, tuple)): + return [serialize_object(item) for item in obj] + + if isinstance(obj, (set, frozenset)): + # JSON has no set type → flatten to a list with deterministic ordering + # so hash randomization on element types cannot shift cross-process iteration order. + serialized_set = [serialize_object(e) for e in obj] + return sorted(serialized_set, key=lambda x: (type(x).__name__, str(x))) + + # Use inspect.getattr_static to bypass any custom __getattr__ / metaclass magic + if callable(inspect.getattr_static(obj, "serialize", None)): + return serialize_object(obj.serialize()) + + # Kubernetes client objects (V1Pod, V1Container, ...) expose their content via to_dict(). + # Scope the branch to the kubernetes namespace so unrelated user classes that happen to + # define a to_dict() method fall through to str() instead of being treated as K8s payloads. 
+ if getattr(type(obj), "__module__", "").startswith( + ("kubernetes.", "kubernetes_asyncio.") + ) and callable(inspect.getattr_static(obj, "to_dict", None)): + return serialize_object(obj.to_dict()) + + if callable(obj): + # Use qualified name; default repr embeds memory addresses, which would change the DAG hash on every parse + return f"<callable {qualname(obj, True)}>" + + # A custom __str__ or __repr__ is treated as an intentional textual representation + # supplied by the author and used as-is. + if type(obj).__str__ is not object.__str__ or type(obj).__repr__ is not object.__repr__: + return str(obj) + + # Otherwise fall back to a qualname marker. The default object repr is + # `<ClassName object at 0x...>`, which embeds a memory address that flips per process + # and would break DAG hash stability — use the class qualname instead. + return f"<{qualname(type(obj), True)} object>" max_length = conf.getint("core", "max_templated_field_length") - if not is_jsonable(template_field): - try: - serialized = template_field.serialize() - except AttributeError: - if callable(template_field): - full_qualified_name = qualname(template_field, True) - serialized = f"<callable {full_qualified_name}>" - else: - serialized = str(template_field) - if len(serialized) > max_length: - rendered = redact(serialized, name) - return truncate_rendered_value(str(rendered), max_length) - return serialized - if not template_field and not isinstance(template_field, tuple): - # Avoid unnecessary serialization steps for empty fields unless they are tuples - # and need to be converted to lists - return template_field - template_field = translate_tuples_to_lists(template_field) - # Sort dictionaries recursively to ensure consistent string representation - # This prevents hash inconsistencies when dict ordering varies - if isinstance(template_field, dict): - template_field = sort_dict_recursively(template_field) - serialized = str(template_field) - if len(serialized) > max_length: - rendered 
= redact(serialized, name) + serialized = serialize_object(template_field) + + if len(str(serialized)) > max_length: + rendered = redact(str(serialized), name) return truncate_rendered_value(str(rendered), max_length) - return template_field + + return serialized class TimetableNotRegistered(ValueError): diff --git a/airflow-core/tests/unit/dags/test_dag_decorator_version.py b/airflow-core/tests/unit/dags/test_dag_decorator_version.py new file mode 100644 index 00000000000..35fd0c98bb9 --- /dev/null +++ b/airflow-core/tests/unit/dags/test_dag_decorator_version.py @@ -0,0 +1,63 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+from __future__ import annotations + +from datetime import datetime + +from airflow.sdk import dag, task, task_group + + +@dag( + dag_id="TEST_DTM", + dag_display_name="TEST DTM", + schedule=None, + default_args={"owner": "airflow", "email": ""}, + start_date=datetime(2024, 1, 25), +) +def dtm_test( + exponent: int = 2, +): + + @task + def get_data(): + return [20, 100, 200, 222, 242, 272] + + @task + def to_exp(number: int, exponent: int) -> float: + return number**exponent + + @task + def trunc(number: float, digits: int) -> float: + return round(number / 22, digits) + + @task + def save(number: list[float]): + for n in number: + print(f"Got number: {n}") + + @task_group # type: ignore[type-var] + def transform(number: int, exponent: int) -> float: + a = to_exp(number, exponent) + b = trunc(a, 2) + return b + + data = get_data() + result = transform.partial(exponent=exponent).expand(number=data) + save(result) # type: ignore[arg-type] + + +instance = dtm_test() diff --git a/airflow-core/tests/unit/models/test_renderedtifields.py b/airflow-core/tests/unit/models/test_renderedtifields.py index d42ed06b033..37e6088494d 100644 --- a/airflow-core/tests/unit/models/test_renderedtifields.py +++ b/airflow-core/tests/unit/models/test_renderedtifields.py @@ -116,11 +116,11 @@ class TestRenderedTaskInstanceFields: pytest.param([], [], id="list"), pytest.param({}, {}, id="empty_dict"), pytest.param((), [], id="empty_tuple"), - pytest.param(set(), "set()", id="empty_set"), + pytest.param(set(), [], id="empty_set"), pytest.param("test-string", "test-string", id="string"), pytest.param({"foo": "bar"}, {"foo": "bar"}, id="dict"), pytest.param(("foo", "bar"), ["foo", "bar"], id="tuple"), - pytest.param({"foo"}, "{'foo'}", id="set"), + pytest.param({"foo"}, ["foo"], id="set"), (date(2018, 12, 6), "2018-12-06"), pytest.param(datetime(2018, 12, 6, 10, 55), "2018-12-06 10:55:00+00:00", id="datetime"), pytest.param( diff --git 
a/airflow-core/tests/unit/serialization/test_dag_serialization.py b/airflow-core/tests/unit/serialization/test_dag_serialization.py index 375b13dea35..f17163d9222 100644 --- a/airflow-core/tests/unit/serialization/test_dag_serialization.py +++ b/airflow-core/tests/unit/serialization/test_dag_serialization.py @@ -82,6 +82,7 @@ from airflow.serialization.json_schema import load_dag_schema_dict from airflow.serialization.serialized_objects import ( BaseSerialization, DagSerialization, + LazyDeserializedDAG, OperatorSerialization, _XComRef, ) @@ -114,6 +115,7 @@ from tests_common.test_utils.timetables import ( cron_timetable, delta_timetable, ) +from unit.models import TEST_DAGS_FOLDER if TYPE_CHECKING: from airflow.sdk.definitions.context import Context @@ -702,6 +704,43 @@ class TestStringifiedDAGs: for dag_id in stringified_dags: self.validate_deserialized_dag(stringified_dags[dag_id], dags[dag_id]) + @pytest.mark.db_test + @conf_vars({("core", "load_examples"): "false"}) + def test_reserialize_should_make_equal_hash_with_dag_processor(self): + dagbag1 = DagBag(TEST_DAGS_FOLDER / "test_dag_decorator_version.py") + hash_result1 = LazyDeserializedDAG.from_dag(next(iter(dagbag1.dags.values()))).hash + + dagbag2 = DagBag(TEST_DAGS_FOLDER / "test_dag_decorator_version.py") + hash_result2 = LazyDeserializedDAG.from_dag(next(iter(dagbag2.dags.values()))).hash + + assert hash_result1 == hash_result2 + + @pytest.mark.db_test + @conf_vars({("core", "load_examples"): "false"}) + def test_hash_succeeds_for_dag_with_mixed_primitive_key_template_field(self): + """SerializedDagModel.hash() must not raise on a template field whose dict has mixed-type primitive keys. + + Building the Dag twice via ``create_dag()`` produces independent Dag and + operator instances, so the hashes must also be equal across calls — + otherwise the serialization path is leaking non-deterministic state + (memory addresses, dict ordering, etc.) into the hash. 
+ """ + from airflow.providers.standard.operators.python import PythonOperator + + def create_dag(): + with DAG(dag_id="dag_mixed_keys", schedule=None, start_date=datetime(2024, 1, 1)) as dag: + PythonOperator( + task_id="op", + python_callable=empty_function, + op_kwargs={"data": {1: "a", "b": "c", None: "z", 2: "d"}, empty_function: "t"}, + ) + return dag + + first_hash = LazyDeserializedDAG.from_dag(create_dag()).hash + second_hash = LazyDeserializedDAG.from_dag(create_dag()).hash + + assert first_hash == second_hash + @skip_if_force_lowest_dependencies_marker @pytest.mark.db_test def test_roundtrip_provider_example_dags(self): diff --git a/airflow-core/tests/unit/serialization/test_helpers.py b/airflow-core/tests/unit/serialization/test_helpers.py index 0ded4f64a86..a4cfe623bd4 100644 --- a/airflow-core/tests/unit/serialization/test_helpers.py +++ b/airflow-core/tests/unit/serialization/test_helpers.py @@ -16,6 +16,14 @@ # under the License. from __future__ import annotations +import json + +import pytest + +from airflow.sdk.definitions._internal.types import SET_DURING_EXECUTION +from airflow.serialization.definitions.notset import NOTSET +from airflow.serialization.helpers import serialize_template_field + def test_serialize_template_field_with_very_small_max_length(monkeypatch): """Test that truncation message is prioritized even for very small max_length.""" @@ -29,3 +37,625 @@ def test_serialize_template_field_with_very_small_max_length(monkeypatch): # This ensures users always see why content is truncated assert result assert "Truncated. You can change this behaviour" in result + + +def test_serialize_template_field_truncation_kicks_in(monkeypatch): + """Long serialized output must be truncated with the standard message.""" + monkeypatch.setenv("AIRFLOW__CORE__MAX_TEMPLATED_FIELD_LENGTH", "20") + + long_value = {"k": "x" * 500} + result = serialize_template_field(long_value, "field") + + assert "Truncated. 
You can change this behaviour" in result + + +def test_serialize_template_field_with_notset(): + """NOTSET must serialize deterministically via serialize(), not str() fallback.""" + result = serialize_template_field(NOTSET, "logical_date") + assert result == "NOTSET" + + +def test_serialize_template_field_with_set_during_execution(): + """SetDuringExecution must use its own serialize() override.""" + result = serialize_template_field(SET_DURING_EXECUTION, "logical_date") + assert result == "DYNAMIC (set during execution)" + + +def test_argnotset_repr_and_str(): + """repr/str should return the stable serialized sentinel string.""" + assert repr(NOTSET) == "NOTSET" + assert str(NOTSET) == "NOTSET" + assert repr(SET_DURING_EXECUTION) == "DYNAMIC (set during execution)" + assert str(SET_DURING_EXECUTION) == "DYNAMIC (set during execution)" + + +def test_serialize_template_field_with_dict_value_callable(): + + def fn_returns_callable(): + def get_arg(): + pass + + return get_arg + + template_name = "op_kwargs" + + def make_value(): + return {"values": [3, 1, 2], "sort_key": lambda x: x} + + result1 = serialize_template_field(make_value(), template_name) + result2 = serialize_template_field(make_value(), template_name) + + assert result1 == result2 + + def make_value_nested(): + return { + "values": [3, 1, 2], + "sort_key_nested": {"b": lambda x: fn_returns_callable(), "a": "test"}, + } + + result1_nested = serialize_template_field(make_value_nested(), template_name) + result2_nested = serialize_template_field(make_value_nested(), template_name) + + assert result1_nested == result2_nested + + +def test_serialize_template_field_with_mixed_key_dict_and_callable(): + """Mixed-key dicts containing callables must serialize deterministically without TypeError.""" + template_name = "op_kwargs" + + def make_value(): + return {1: "a", "b": lambda x: x, 2: "c"} + + result1 = serialize_template_field(make_value(), template_name) + result2 = serialize_template_field(make_value(), 
template_name) + + assert result1 == result2 + assert any(isinstance(v, str) and "<callable " in v for v in result1.values()) + + +def test_serialize_template_field_with_mixed_key_jsonable_dict(): + """Jsonable mixed-key dicts must not raise when sorted for deterministic output.""" + template_name = "op_kwargs" + + def make_value(): + return {1: "a", "b": "c", 2: "d", 3: True} + + result1 = serialize_template_field(make_value(), template_name) + result2 = serialize_template_field(make_value(), template_name) + + assert result1 == result2 + + +@pytest.mark.parametrize( + "value", + [None, "hello", 0, 42, -1, 3.14, True, False], + ids=["none", "str", "zero", "int", "neg_int", "float", "true", "false"], +) +def test_serialize_template_field_primitives_pass_through(value): + """Primitives (None, str, int, float, bool) must be returned unchanged and keep their type.""" + result = serialize_template_field(value, "field") + assert result == value + assert type(result) is type(value) + + +def test_serialize_template_field_tuple_becomes_list(): + """Top-level and nested tuples must flatten to lists for JSON compatibility.""" + result = serialize_template_field((1, 2, (3, 4)), "field") + assert result == [1, 2, [3, 4]] + + +def test_serialize_template_field_tuple_key_normalized(): + """Tuple keys must be normalized to a string so the dict stays JSON-encodable.""" + result1 = serialize_template_field({(1, 2): "v", (3, 4): "w"}, "op_kwargs") + result2 = serialize_template_field({(3, 4): "w", (1, 2): "v"}, "op_kwargs") + + assert result1 == result2 + assert all(isinstance(k, str) for k in result1) + json.dumps(result1) # must not raise + + +def test_serialize_template_field_frozenset_key_normalized(): + """Frozenset keys must be normalized to a string.""" + result = serialize_template_field({frozenset([1, 2]): "v"}, "op_kwargs") + assert isinstance(next(iter(result)), str) + json.dumps(result) + + +def test_serialize_template_field_callable_key_uses_qualname(): + """Callable keys must
serialize via qualname so memory addresses don't leak into the hash.""" + + def my_fn(): + pass + + result = serialize_template_field({my_fn: "v"}, "op_kwargs") + key = next(iter(result)) + assert key.startswith("<callable ") + assert "my_fn" in key + assert "at 0x" not in key + + +def test_serialize_template_field_mixed_exotic_keys_deterministic(): + """A dict with str, int, tuple, and callable keys must serialize the same way every call.""" + + def my_fn(): + pass + + def make_value(): + return {"a": 1, 2: "b", (3, 4): "c", my_fn: "d"} + + r1 = serialize_template_field(make_value(), "op_kwargs") + r2 = serialize_template_field(make_value(), "op_kwargs") + assert r1 == r2 + json.dumps(r1) + + +def test_serialize_template_field_object_with_serialize_method(): + """An object exposing serialize() must use it (recursively) instead of str().""" + + class Custom: + def serialize(self): + return {"kind": "custom", "values": (1, 2, 3)} + + result = serialize_template_field(Custom(), "field") + assert result == {"kind": "custom", "values": [1, 2, 3]} + + +def test_serialize_template_field_object_with_getattr_no_serialize(): + """Objects with custom __getattr__ but no real serialize attribute must fall through to str().""" + + class Tricky: + def __getattr__(self, item): + # Mimic SQLAlchemy / proxy objects that return *something* for any attribute access + return lambda *a, **kw: "should-not-be-called" + + def __str__(self): + return "tricky-object" + + result = serialize_template_field(Tricky(), "field") + assert result == "tricky-object" + + +def test_serialize_template_field_non_kubernetes_to_dict_falls_through_to_str(): + """User classes that happen to define to_dict() must not be treated as K8s payloads.""" + + class CustomWithToDict: + def to_dict(self): + return {"should": "not be used"} + + def __str__(self): + return "custom-via-str" + + result = serialize_template_field(CustomWithToDict(), "field") + assert result == "custom-via-str" + + +def 
test_serialize_template_field_kubernetes_object_uses_to_dict(): + """Objects whose class is defined under the kubernetes.* namespace are normalized via to_dict().""" + + class FakeK8sObject: + def to_dict(self): + return {"kind": "Pod", "metadata": {"name": "test"}} + + FakeK8sObject.__module__ = "kubernetes.client.models.v1_pod" + + result = serialize_template_field(FakeK8sObject(), "field") + assert result == {"kind": "Pod", "metadata": {"name": "test"}} + + +def test_serialize_template_field_bytes_become_str(): + """Bytes are not JSON-encodable; they must be coerced via str().""" + result = serialize_template_field(b"binary", "field") + assert isinstance(result, str) + + +def test_serialize_template_field_no_memory_address_in_output(): + """Output must never contain `<function ... at 0x...>` repr leaks (which would break DAG hashing).""" + + def my_fn(): + pass + + value = { + "a": my_fn, + "b": [my_fn, {"c": my_fn}], + my_fn: "as-key", + ("tup",): my_fn, + } + result = serialize_template_field(value, "op_kwargs") + assert "at 0x" not in str(result) + + +def test_serialize_template_field_plain_object_has_no_memory_address(): + """Objects relying on the default object.__str__ would leak `<ClassName object at 0x...>`.""" + + class Opaque: + pass + + result = serialize_template_field(Opaque(), "field") + assert isinstance(result, str) + assert "at 0x" not in result + assert "Opaque" in result + + +def test_serialize_template_field_plain_object_repr_preserved_when_custom(): + """A user-defined __repr__ is a meaningful representation and must be kept as-is.""" + + class WithRepr: + def __repr__(self): + return "stable-repr" + + result = serialize_template_field(WithRepr(), "field") + assert result == "stable-repr" + + +def test_serialize_template_field_set_of_plain_objects_is_deterministic(): + """Repeated serialization of a set of plain objects must produce identical output across calls.""" + + class Opaque: + pass + + first = serialize_template_field({Opaque(), 
Opaque()}, "field") + second = serialize_template_field({Opaque(), Opaque()}, "field") + assert first == second + assert "at 0x" not in str(first) + + +def test_serialize_template_field_output_is_jsonable(): + """Whatever shape we pass in, the result must be directly JSON-encodable.""" + + def my_fn(): + pass + + value = { + "callable_value": my_fn, + "nested": {"list": [1, (2, 3), my_fn], "deep": {("k",): my_fn}}, + frozenset([1, 2]): [my_fn], + my_fn: {"x": 1}, + } + result = serialize_template_field(value, "op_kwargs") + json.dumps(result) + + +def test_serialize_template_field_deeply_nested_determinism(): + """Determinism across new instances of the same nested structure (key ordering must not matter).""" + + def my_fn(): + pass + + def make_a(): + return { + "z": [3, 2, 1], + "a": {"nested": my_fn, "items": (1, 2)}, + 10: ("x", "y"), + } + + def make_b(): + # Same content, different insertion order + return { + 10: ("x", "y"), + "a": {"items": (1, 2), "nested": my_fn}, + "z": [3, 2, 1], + } + + assert serialize_template_field(make_a(), "f") == serialize_template_field(make_b(), "f") + + +def test_serialize_template_field_bool_not_collapsed_to_int(): + """bool must be preserved as bool (Python treats True == 1, but JSON distinguishes them).""" + result = serialize_template_field({"flag": True, "count": 1}, "op_kwargs") + assert result["flag"] is True + assert result["count"] == 1 + assert type(result["flag"]) is bool + + +def test_serialize_template_field_none_preserved(): + """None must round-trip as None, not the string 'None'.""" + result = serialize_template_field({"x": None, "y": [None, 1]}, "op_kwargs") + assert result == {"x": None, "y": [None, 1]} + + +def test_serialize_template_field_list_with_callables_and_objects(): + """Lists must recursively serialize callables and objects without leaking repr.""" + + def my_fn(): + pass + + class Custom: + def serialize(self): + return "custom-serialized" + + result = serialize_template_field([1, my_fn, Custom(), 
(2, my_fn)], "field") + assert result[0] == 1 + assert result[1].startswith("<callable ") + assert "my_fn" in result[1] + assert result[2] == "custom-serialized" + assert result[3][0] == 2 + assert result[3][1].startswith("<callable ") + + +def test_serialize_template_field_key_with_serialize_returning_nested_callable(): + """A key whose .serialize() returns a structure containing callables must not leak memory addresses.""" + + def my_fn(): + pass + + class Custom: + def serialize(self): + return {"k": my_fn} # nested callable inside serialize() output + + result = serialize_template_field({Custom(): "v"}, "op_kwargs") + assert "at 0x" not in str(result) + json.dumps(result) + + +def test_serialize_template_field_key_with_serialize_returning_primitive(): + """A key whose .serialize() returns a primitive must use that primitive directly (no str() wrap).""" + + class Custom: + def serialize(self): + return "stable-id-v1" + + result = serialize_template_field({Custom(): "v"}, "op_kwargs") + assert result == {"stable-id-v1": "v"} + + +def test_serialize_template_field_key_with_serialize_returning_list_with_callable(): + """Sibling case to the dict-with-callable test: list output with nested callables must also be cleaned before str().""" + + def my_fn(): + pass + + class Custom: + def serialize(self): + return [1, my_fn, (2, my_fn)] + + result1 = serialize_template_field({Custom(): "v"}, "op_kwargs") + result2 = serialize_template_field({Custom(): "v"}, "op_kwargs") + + key = next(iter(result1)) + assert "at 0x" not in key + assert "<callable " in key + assert result1 == result2 + json.dumps(result1) + + +def test_serialize_template_field_key_falls_back_to_str_when_no_serialize(): + """A non-primitive, non-callable key without .serialize() must use str() of the original object""" + + class NoSerialize: + def __str__(self): + return "no-serialize-stringified" + + result = serialize_template_field({NoSerialize(): "v"}, "op_kwargs") + assert result == 
{"no-serialize-stringified": "v"} + + +def test_serialize_template_field_set_value_with_callable_no_memory_address_leak(): + """A set containing a callable must replace the callable via qualname, not leak `at 0x...`.""" + + def my_fn(): + pass + + result = serialize_template_field({my_fn}, "op_kwargs") + + assert "at 0x" not in str(result) + assert "<callable " in str(result) + + +def test_serialize_template_field_frozenset_value_with_callable_no_memory_address_leak(): + """Same regression as set, but with frozenset as a value.""" + + def my_fn(): + pass + + result = serialize_template_field({"items": frozenset([my_fn])}, "op_kwargs") + + assert "at 0x" not in str(result) + assert "<callable " in str(result) + + +def test_serialize_template_field_frozenset_key_with_callable_member_no_memory_address_leak(): + """A frozenset key containing a callable must serialize without leaking memory addresses.""" + + def my_fn(): + pass + + # frozenset of hashables (functions are hashable) is a valid dict key + result = serialize_template_field({frozenset([my_fn]): "v"}, "op_kwargs") + + key = next(iter(result)) + assert "at 0x" not in key + assert "<callable " in key + + +def test_serialize_template_field_set_value_flattens_to_list(): + """Set must serialize to a JSON-compatible list, not a Python set repr string.""" + + result = serialize_template_field({"items": {1, 2, 3}}, "op_kwargs") + + assert isinstance(result["items"], list) + assert sorted(result["items"]) == [1, 2, 3] + json.dumps(result) + + +def test_serialize_template_field_set_of_strings_deterministic_ordering(): + """Set of strings must serialize with deterministic ordering — not affected by PYTHONHASHSEED. + + Sets are walked then sorted by (type_name, str(element)), so the output ordering + depends on the elements rather than on hash randomization across processes. 
+ """ + # Same content, two independent set instances + a = serialize_template_field({"items": {"banana", "apple", "cherry"}}, "op_kwargs") + b = serialize_template_field({"items": {"cherry", "banana", "apple"}}, "op_kwargs") + + assert a == b + assert isinstance(a["items"], list) + assert a["items"] == sorted(a["items"]) + + +def test_serialize_template_field_nested_set_with_callable(): + """Set nested deep inside a dict/list must still recursively clean callables.""" + + def my_fn(): + pass + + value = {"outer": [{"inner": {my_fn, "literal"}}]} + result = serialize_template_field(value, "op_kwargs") + + assert "at 0x" not in str(result) + json.dumps(result) + + +def test_serialize_template_field_callable_keys_sort_by_qualname_not_address(): + """Two distinct named callables as dict keys must sort by qualname, not memory address. + + Without this guarantee, two semantically-identical inputs that happen to allocate + the callables in a different order produce different serialized output, and re-parsing + the same Dag in another process can produce a different hash. + """ + + def fn_a(): + pass + + def fn_b(): + pass + + # Two dicts with the same content but different insertion orders must produce + # the same output once sorting is keyed on qualname. + r1 = serialize_template_field({fn_a: 1, fn_b: 2}, "op_kwargs") + r2 = serialize_template_field({fn_b: 2, fn_a: 1}, "op_kwargs") + + assert r1 == r2 + # The serialized iteration order must follow qualname (fn_a before fn_b), + # not memory address. + keys = list(r1.keys()) + assert len(keys) == 2 + assert "fn_a" in keys[0] + assert "fn_b" in keys[1] + + +def test_serialize_template_field_lambda_keys_collapse_deterministically(): + """Multiple lambdas as keys collapse to one entry deterministically across parses. + + Each call to ``make_value()`` produces *new* lambda objects with new memory + addresses. The serialized result must not depend on those addresses. 
+ """ + + def make_value(): + # Two lambdas; both qualnames are ``<lambda>``, so they collapse to the same + # serialized key. The assertion below targets stability across calls, + # not key preservation between the two lambdas. + return {(lambda x: x): "a", (lambda y: y): "b"} + + r1 = serialize_template_field(make_value(), "op_kwargs") + r2 = serialize_template_field(make_value(), "op_kwargs") + + assert r1 == r2 + assert "at 0x" not in str(r1) + + +def test_serialize_template_field_dict_with_serializable_keys_sort_by_serialized_form(): + """Custom objects whose .serialize() returns a stable string must be sorted by that string, not by repr.""" + + class StableId: + def __init__(self, name): + self.name = name + + def serialize(self): + return self.name + + # Insert in reverse alphabetical order — sorting by serialized form must reverse it. + r1 = serialize_template_field({StableId("zeta"): 1, StableId("alpha"): 2}, "op_kwargs") + r2 = serialize_template_field({StableId("alpha"): 2, StableId("zeta"): 1}, "op_kwargs") + + assert r1 == r2 + assert list(r1.keys()) == ["alpha", "zeta"] + + +@pytest.mark.parametrize( + ("value", "expected_keys"), + [ + ({1: "a", 2: "b"}, {"1", "2"}), + ({True: "a", False: "b"}, {"True", "False"}), + ({None: "a"}, {"None"}), + ({1.5: "a", 2.5: "b"}, {"1.5", "2.5"}), + ({1: "a", "b": "c"}, {"1", "b"}), + ], + ids=["int_keys", "bool_keys", "none_key", "float_keys", "mixed_int_str"], +) +def test_serialize_template_field_primitive_keys_coerced_to_string(value, expected_keys): + """All dict keys must be coerced to str so json.dumps(sort_keys=True) downstream cannot raise.""" + result = serialize_template_field(value, "op_kwargs") + assert set(result.keys()) == expected_keys + assert all(isinstance(k, str) for k in result) + + +def test_serialize_template_field_mixed_primitive_keys_jsonable_sort_keys(): + """Output of mixed-type primitive keys must survive ``json.dumps(..., sort_keys=True)``.""" + value = {1: "a", "b": "c", 2: "d", 3: True, None:
"z", False: "y"} + result = serialize_template_field(value, "op_kwargs") + + json.dumps(result, sort_keys=True) + + +def test_serialize_template_field_mixed_primitive_keys_deterministic_across_calls(): + """Same input parsed twice must yield identical output once keys are stringified.""" + + def fn_a(): + pass + + def fn_b(): + pass + + def make_value(): + return {1: "a", "b": "c", 2: "d", None: "z", "test": fn_b, fn_a: 3.5} + + assert serialize_template_field(make_value(), "op_kwargs") == serialize_template_field( + make_value(), "op_kwargs" + ) + + +def test_serialize_template_field_nested_mixed_primitive_keys_jsonable(): + """Nested mixed-type primitive keys (dict inside dict) must also be coerced and jsonable.""" + value = {"outer": {1: "a", "b": "c", None: "z"}} + result = serialize_template_field(value, "op_kwargs") + + assert all(isinstance(k, str) for k in result["outer"]) + json.dumps(result, sort_keys=True) + + +def test_serialize_template_field_deeply_nested_dict_keys_recursively_normalized(): + """Every nested dict must apply key normalization and sorting recursively. + + Mixed-type primitive keys and callable keys appear at multiple depths; the + helper must stringify and sort them at each level so the full output is + deterministic across calls and safe for ``json.dumps(sort_keys=True)``. 
+ """ + + def fn_inner(): + pass + + def make_value(): + return { + "level1": { + 1: "a", + fn_inner: { + None: "deep", + "nested_str": "value", + 2.5: {fn_inner: "deepest"}, + }, + "b": {3: "three", 4: "four"}, + }, + } + + r1 = serialize_template_field(make_value(), "op_kwargs") + r2 = serialize_template_field(make_value(), "op_kwargs") + + assert r1 == r2 + assert all(isinstance(k, str) for k in r1["level1"]) + callable_key = next(k for k in r1["level1"] if "fn_inner" in k) + inner = r1["level1"][callable_key] + assert all(isinstance(k, str) for k in inner) + float_key = next(k for k in inner if k == "2.5") + assert all(isinstance(k, str) for k in inner[float_key]) + assert "at 0x" not in str(r1) + json.dumps(r1, sort_keys=True) diff --git a/task-sdk/src/airflow/sdk/execution_time/task_runner.py b/task-sdk/src/airflow/sdk/execution_time/task_runner.py index c363c40aa60..60b6c5d8132 100644 --- a/task-sdk/src/airflow/sdk/execution_time/task_runner.py +++ b/task-sdk/src/airflow/sdk/execution_time/task_runner.py @@ -966,83 +966,91 @@ def startup(msg: StartupDetails) -> tuple[RuntimeTaskInstance, Context, Logger]: return ti, ti.get_template_context(), log -def _serialize_template_field(template_field: Any, name: str) -> str | dict | list | int | float: +def _serialize_template_field( + template_field: Any, name: str +) -> str | dict | list | int | float | bool | None: """ Return a serializable representation of the templated field. - If ``templated_field`` contains a class or instance that requires recursive - templating, store them as strings. Otherwise simply return the field as-is. + The walk has two responsibilities: - Used sdk secrets masker to redact secrets in the serialized output. + 1. **Make the template_field JSON-encodable** — every container is rebuilt + with primitive leaves (str/int/float/bool/None), tuples and sets are + flattened to lists, and unsupported objects fall through to ``str()`` + so ``json.dumps`` never raises on the result. + 2. 
**Keep the output deterministic across parses** — callables are replaced + with their qualified name (never the default ``<function ... at 0x...>`` + repr), dicts are key-sorted, and (frozen)sets are sorted by element so + the same input always produces the same string. + + Uses the SDK secrets masker to redact secrets in the serialized output. """ - import json + import inspect + from airflow.sdk._shared.module_loading import qualname from airflow.sdk._shared.secrets_masker import redact - def is_jsonable(x): - try: - json.dumps(x) - except (TypeError, OverflowError): - return False - else: - return True - - def translate_tuples_to_lists(obj: Any): - """Recursively convert tuples to lists.""" - if isinstance(obj, tuple): - return [translate_tuples_to_lists(item) for item in obj] - if isinstance(obj, list): - return [translate_tuples_to_lists(item) for item in obj] - if isinstance(obj, dict): - return {key: translate_tuples_to_lists(value) for key, value in obj.items()} - return obj + def normalize_dict_key(key) -> str: + """Normalize a dict key to a serialized string type.""" + # Serialized template_field keys must all be strings, not a mix of types, so that + # downstream json.dumps(..., sort_keys=True) does not raise on mixed-type keys. 
+ return str(serialize_object(key)) + + def serialize_object(obj): + """Recursively rewrite ``obj`` into a JSON-encodable, hash-stable structure.""" + if obj is None or isinstance(obj, (str, int, float, bool)): + return obj - def sort_dict_recursively(obj: Any) -> Any: - """Recursively sort dictionaries to ensure consistent ordering.""" if isinstance(obj, dict): - return {k: sort_dict_recursively(v) for k, v in sorted(obj.items())} - if isinstance(obj, list): - return [sort_dict_recursively(item) for item in obj] - if isinstance(obj, tuple): - return tuple(sort_dict_recursively(item) for item in obj) - return obj - - def _fallback_serialization(obj): - """Serialize objects with to_dict() method (eg: k8s objects) for json.dumps() default parameter.""" - if hasattr(obj, "to_dict"): - return obj.to_dict() - raise TypeError(f"cannot serialize {obj}") + # Serialize keys/values first so each key is a string and the output is hash-stable, + # then sort by the serialized key to prevent hash inconsistencies when dict ordering varies. + serialized_pairs = [(normalize_dict_key(k), serialize_object(v)) for k, v in obj.items()] + return dict(sorted(serialized_pairs, key=lambda kv: kv[0])) + + if isinstance(obj, (list, tuple)): + return [serialize_object(item) for item in obj] + + if isinstance(obj, (set, frozenset)): + # JSON has no set type → flatten to a list with deterministic ordering + # so hash randomization on element types cannot shift cross-process iteration order. + serialized_set = [serialize_object(e) for e in obj] + return sorted(serialized_set, key=lambda x: (type(x).__name__, str(x))) + + # Use inspect.getattr_static to bypass any custom __getattr__ / metaclass magic + if callable(inspect.getattr_static(obj, "serialize", None)): + return serialize_object(obj.serialize()) + + # Kubernetes client objects (V1Pod, V1Container, ...) expose their content via to_dict(). 
+ # Scope the branch to the kubernetes namespace so unrelated user classes that happen to + # define a to_dict() method fall through to str() instead of being treated as K8s payloads. + if getattr(type(obj), "__module__", "").startswith( + ("kubernetes.", "kubernetes_asyncio.") + ) and callable(inspect.getattr_static(obj, "to_dict", None)): + return serialize_object(obj.to_dict()) + + if callable(obj): + # Use qualified name; default repr embeds memory addresses, which would change the DAG hash on every parse + return f"<callable {qualname(obj, True)}>" + + # A custom __str__ or __repr__ is treated as an intentional textual representation + # supplied by the author and used as-is. + if type(obj).__str__ is not object.__str__ or type(obj).__repr__ is not object.__repr__: + return str(obj) + + # Otherwise fall back to a qualname marker. The default object repr is + # `<ClassName object at 0x...>`, which embeds a memory address that flips per process + # and would break DAG hash stability — use the class qualname instead. 
+ return f"<{qualname(type(obj), True)} object>" max_length = conf.getint("core", "max_templated_field_length") - if not is_jsonable(template_field): - try: - serialized = template_field.serialize() - except AttributeError: - # check if these objects can be converted to JSON serializable types - try: - serialized = json.dumps(template_field, default=_fallback_serialization) - except (TypeError, ValueError): - # fall back to string representation if not - serialized = str(template_field) - if len(serialized) > max_length: - rendered = redact(serialized, name) - return truncate_rendered_value(str(rendered), max_length) - return serialized - if not template_field and not isinstance(template_field, tuple): - # Avoid unnecessary serialization steps for empty fields unless they are tuples - # and need to be converted to lists - return template_field - template_field = translate_tuples_to_lists(template_field) - # Sort dictionaries recursively to ensure consistent string representation - # This prevents hash inconsistencies when dict ordering varies - if isinstance(template_field, dict): - template_field = sort_dict_recursively(template_field) - serialized = str(template_field) - if len(serialized) > max_length: - rendered = redact(serialized, name) + serialized = serialize_object(template_field) + + if len(str(serialized)) > max_length: + rendered = redact(str(serialized), name) return truncate_rendered_value(str(rendered), max_length) - return template_field + + return serialized def _serialize_rendered_fields(task: AbstractOperator) -> dict[str, JsonValue]: diff --git a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py index 36084c21db6..4c60830aa4a 100644 --- a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py +++ b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py @@ -1099,7 +1099,7 @@ def test_basic_templated_dag(mocked_parse, make_ti_context, mock_supervisor_comm ), 
pytest.param( {"my_tup": (1, 2), "my_set": {1, 2, 3}}, - {"my_tup": [1, 2], "my_set": "{1, 2, 3}"}, + {"my_tup": [1, 2], "my_set": [1, 2, 3]}, id="tuples_and_sets", ), pytest.param( @@ -2997,10 +2997,13 @@ class TestRuntimeTaskInstance: rendered_fields = mock_supervisor_comms.send.mock_calls[0].kwargs["msg"].rendered_fields assert rendered_fields is not None - assert ( - rendered_fields["env_vars"] - == '[{"name": "var1", "value": "This is a test phrase.", "value_from": null}, {"name": "var2", "value": "***", "value_from": null}, {"name": "var3", "value": "***", "value_from": null}]' - ) + # K8s V1EnvVar objects expose .to_dict(); the recursive walk normalizes the list of objects + # into a list of plain dicts so the result is directly JSON-encodable and redact can mask secrets in nested values. + assert rendered_fields["env_vars"] == [ + {"name": "var1", "value": "This is a test phrase.", "value_from": None}, + {"name": "var2", "value": "***", "value_from": None}, + {"name": "var3", "value": "***", "value_from": None}, + ] def test_nested_template_field_renderer_respects_redaction( self, create_runtime_ti, mock_supervisor_comms
