This is an automated email from the ASF dual-hosted git repository.

bolke pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 6e7340f9c7 Improve serialization of tuples and sets (#29019)
6e7340f9c7 is described below

commit 6e7340f9c71534df5be70bd7812d4de3015e78b7
Author: Bolke de Bruin <[email protected]>
AuthorDate: Mon Mar 20 19:19:55 2023 +0100

    Improve serialization of tuples and sets (#29019)
    
    Sets cannot be encoded into JSON directly and tuples lose information when 
serialized into JSON directly.
    We now serialize both into a dict and encode them properly.
    
    ---------
    
    Co-authored-by: Tzu-ping Chung <[email protected]>
---
 airflow/datasets/__init__.py                 |  2 +-
 airflow/serialization/serde.py               | 87 ++++++++++++++++++++--------
 airflow/serialization/serializers/builtin.py | 59 +++++++++++++++++++
 airflow/utils/json.py                        |  4 ++
 tests/decorators/test_python.py              |  4 +-
 tests/serialization/test_serde.py            | 69 ++++++++++++++++++++--
 tests/utils/test_json.py                     | 17 ++++++
 7 files changed, 210 insertions(+), 32 deletions(-)

diff --git a/airflow/datasets/__init__.py b/airflow/datasets/__init__.py
index aa50fd16ec..0dc635a00b 100644
--- a/airflow/datasets/__init__.py
+++ b/airflow/datasets/__init__.py
@@ -29,7 +29,7 @@ class Dataset:
     uri: str = attr.field(validator=[attr.validators.min_len(1), 
attr.validators.max_len(3000)])
     extra: dict[str, Any] | None = None
 
-    version: ClassVar[int] = 1
+    __version__: ClassVar[int] = 1
 
     @uri.validator
     def _check_uri(self, attr, uri: str):
diff --git a/airflow/serialization/serde.py b/airflow/serialization/serde.py
index dd4913e567..55d63ca170 100644
--- a/airflow/serialization/serde.py
+++ b/airflow/serialization/serde.py
@@ -56,6 +56,7 @@ S = Union[list, tuple, set]
 
 _serializers: dict[str, ModuleType] = {}
 _deserializers: dict[str, ModuleType] = {}
+_stringifiers: dict[str, ModuleType] = {}
 _extra_allowed: set[str] = set()
 
 _primitives = (int, bool, float, str)
@@ -67,8 +68,16 @@ def encode(cls: str, version: int, data: T) -> dict[str, str 
| int | T]:
     return {CLASSNAME: cls, VERSION: version, DATA: data}
 
 
-def decode(d: dict[str, str | int | T]) -> tuple:
-    return d[CLASSNAME], d[VERSION], d.get(DATA, None)
+def decode(d: dict[str, Any]) -> tuple[str, int, Any]:
+    classname = d[CLASSNAME]
+    version = d[VERSION]
+
+    if not isinstance(classname, str) or not isinstance(version, int):
+        raise ValueError(f"cannot decode {d!r}")
+
+    data = d.get(DATA)
+
+    return classname, version, data
 
 
 def serialize(o: object, depth: int = 0) -> U | None:
@@ -85,6 +94,11 @@ def serialize(o: object, depth: int = 0) -> U | None:
     2. A registered serializer in the namespace of 
``airflow.serialization.serializers``
     3. Annotations from attr or dataclass.
 
+    Limitations: attr and dataclass objects can lose type information for 
nested objects
+    as they do not store this when calling ``asdict``. This means that at 
deserialization values
+    will be deserialized as a dict as opposed to reinstating the object. 
Provide
+    your own serializer to work around this.
+
     :param o: The object to serialize.
     :param depth: Private tracker for nested serialization.
     :raise TypeError: A serializer cannot be found.
@@ -105,14 +119,8 @@ def serialize(o: object, depth: int = 0) -> U | None:
 
         return o
 
-    # tuples and plain dicts are iterated over recursively
-    if isinstance(o, _builtin_collections):
-        s = [serialize(d, depth + 1) for d in o]
-        if isinstance(o, tuple):
-            return tuple(s)
-        if isinstance(o, set):
-            return set(s)
-        return s
+    if isinstance(o, list):
+        return [serialize(d, depth + 1) for d in o]
 
     if isinstance(o, dict):
         if CLASSNAME in o or SCHEMA_ID in o:
@@ -148,6 +156,7 @@ def serialize(o: object, depth: int = 0) -> U | None:
 
     # dataclasses
     if dataclasses.is_dataclass(cls):
+        # fixme: unfortunately using asdict with nested dataclasses loses 
information
         data = dataclasses.asdict(o)
         dct[DATA] = serialize(data, depth + 1)
         return dct
@@ -181,8 +190,16 @@ def deserialize(o: T | None, full=True, type_hint: Any = 
None) -> object:
     if isinstance(o, _primitives):
         return o
 
+    # tuples, sets are included here for backwards compatibility
     if isinstance(o, _builtin_collections):
-        return [deserialize(d) for d in o]
+        col = [deserialize(d) for d in o]
+        if isinstance(o, tuple):
+            return tuple(col)
+
+        if isinstance(o, set):
+            return set(col)
+
+        return col
 
     if not isinstance(o, dict):
         raise TypeError()
@@ -196,8 +213,8 @@ def deserialize(o: T | None, full=True, type_hint: Any = 
None) -> object:
     # custom deserialization starts here
     cls: Any
     version = 0
-    value: Any
-    classname: str
+    value: Any = None
+    classname = ""
 
     if type_hint:
         cls = type_hint
@@ -207,19 +224,22 @@ def deserialize(o: T | None, full=True, type_hint: Any = 
None) -> object:
 
     if CLASSNAME in o and VERSION in o:
         classname, version, value = decode(o)
-        if not _match(classname) and classname not in _extra_allowed:
-            raise ImportError(
-                f"{classname} was not found in allow list for deserialization 
imports. "
-                f"To allow it, add it to allowed_deserialization_classes in 
the configuration"
-            )
 
-        if full:
-            cls = import_string(classname)
+    if not classname:
+        raise TypeError("classname cannot be empty")
 
     # only return string representation
     if not full:
         return _stringify(classname, version, value)
 
+    if not _match(classname) and classname not in _extra_allowed:
+        raise ImportError(
+            f"{classname} was not found in allow list for deserialization 
imports. "
+            f"To allow it, add it to allowed_deserialization_classes in the 
configuration"
+        )
+
+    cls = import_string(classname)
+
     # registered deserializer
     if classname in _deserializers:
         return _deserializers[classname].deserialize(classname, version, 
deserialize(value))
@@ -258,35 +278,45 @@ def _match(classname: str) -> bool:
 
 
 def _stringify(classname: str, version: int, value: T | None) -> str:
+    """Convert a previously serialized object into a somewhat human-readable 
format.
+
+    This function is not designed to be exact, and will not extensively 
traverse
+    the whole tree of an object.
+    """
+    if classname in _stringifiers:
+        return _stringifiers[classname].stringify(classname, version, value)
+
     s = f"{classname}@version={version}("
     if isinstance(value, _primitives):
         s += f"{value})"
     elif isinstance(value, _builtin_collections):
-        s += ",".join(str(serialize(value)))
+        # deserialized values can be != str
+        s += ",".join(str(deserialize(value, full=False)))
     elif isinstance(value, dict):
         for k, v in value.items():
-            s += f"{k}={serialize(v)},"
+            s += f"{k}={deserialize(v, full=False)},"
         s = s[:-1] + ")"
 
     return s
 
 
 def _register():
-    """Register builtin serializers and deserializers for types that don't 
have any themselves"""
+    """Register builtin serializers and deserializers for types that don't 
have any themselves."""
     _serializers.clear()
     _deserializers.clear()
+    _stringifiers.clear()
 
     with Stats.timer("serde.load_serializers") as timer:
         for _, name, _ in iter_namespace(airflow.serialization.serializers):
             name = import_module(name)
-            for s in getattr(name, "serializers", list()):
+            for s in getattr(name, "serializers", ()):
                 if not isinstance(s, str):
                     s = qualname(s)
                 if s in _serializers and _serializers[s] != name:
                     raise AttributeError(f"duplicate {s} for serialization in 
{name} and {_serializers[s]}")
                 log.debug("registering %s for serialization", s)
                 _serializers[s] = name
-            for d in getattr(name, "deserializers", list()):
+            for d in getattr(name, "deserializers", ()):
                 if not isinstance(d, str):
                     d = qualname(d)
                 if d in _deserializers and _deserializers[d] != name:
@@ -294,6 +324,13 @@ def _register():
                 log.debug("registering %s for deserialization", d)
                 _deserializers[d] = name
                 _extra_allowed.add(d)
+            for c in getattr(name, "stringifiers", ()):
+                if not isinstance(c, str):
+                    c = qualname(c)
+                if c in _deserializers and _deserializers[c] != name:
+                    raise AttributeError(f"duplicate {c} for stringifiers in 
{name} and {_stringifiers[c]}")
+                log.debug("registering %s for stringifying", c)
+                _stringifiers[c] = name
 
     log.info("loading serializers took %.3f seconds", timer.duration)
 
diff --git a/airflow/serialization/serializers/builtin.py 
b/airflow/serialization/serializers/builtin.py
new file mode 100644
index 0000000000..52ecc282ab
--- /dev/null
+++ b/airflow/serialization/serializers/builtin.py
@@ -0,0 +1,59 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, cast
+
+from airflow.utils.module_loading import qualname
+
+if TYPE_CHECKING:
+    from airflow.serialization.serde import U
+
+__version__ = 1
+
+serializers = ["builtins.frozenset", "builtins.set", "builtins.tuple"]
+deserializers = serializers
+stringifiers = serializers
+
+
+def serialize(o: object) -> tuple[U, str, int, bool]:
+    return list(cast(list, o)), qualname(o), __version__, True
+
+
+def deserialize(classname: str, version: int, data: list) -> tuple | set | 
frozenset:
+    if version > __version__:
+        raise TypeError("serialized version is newer than class version")
+
+    if classname == qualname(tuple):
+        return tuple(data)
+
+    if classname == qualname(set):
+        return set(data)
+
+    if classname == qualname(frozenset):
+        return frozenset(data)
+
+    raise TypeError(f"do not know how to deserialize {classname}")
+
+
+def stringify(classname: str, version: int, data: list) -> str:
+    if classname not in stringifiers:
+        raise TypeError(f"do not know how to stringify {classname}")
+
+    s = ",".join(str(d) for d in data)
+    return f"({s})"
diff --git a/airflow/utils/json.py b/airflow/utils/json.py
index 0a497aab89..f1e65307e5 100644
--- a/airflow/utils/json.py
+++ b/airflow/utils/json.py
@@ -88,6 +88,10 @@ class XComEncoder(json.JSONEncoder):
         if isinstance(o, dict) and (CLASSNAME in o or SCHEMA_ID in o):
             raise AttributeError(f"reserved key {CLASSNAME} found in dict to 
serialize")
 
+        # tuples are not preserved by std python serializer
+        if isinstance(o, tuple):
+            o = self.default(o)
+
         return super().encode(o)
 
 
diff --git a/tests/decorators/test_python.py b/tests/decorators/test_python.py
index 20dc883106..354ac5fc2e 100644
--- a/tests/decorators/test_python.py
+++ b/tests/decorators/test_python.py
@@ -179,7 +179,7 @@ class TestAirflowTaskDecorator(BasePythonTest):
         ti = dr.get_task_instances()[0]
 
         assert res.operator.multiple_outputs is False
-        assert ti.xcom_pull() == [8, 4]
+        assert ti.xcom_pull() == (8, 4)
         assert ti.xcom_pull(key="return_value_0") is None
         assert ti.xcom_pull(key="return_value_1") is None
 
@@ -197,7 +197,7 @@ class TestAirflowTaskDecorator(BasePythonTest):
         ti = dr.get_task_instances()[0]
 
         assert not ident.operator.multiple_outputs
-        assert ti.xcom_pull() == [35, 36]
+        assert ti.xcom_pull() == (35, 36)
         assert ti.xcom_pull(key="return_value_0") is None
         assert ti.xcom_pull(key="return_value_1") is None
 
diff --git a/tests/serialization/test_serde.py 
b/tests/serialization/test_serde.py
index 03aea4f7b8..51bfb97f82 100644
--- a/tests/serialization/test_serde.py
+++ b/tests/serialization/test_serde.py
@@ -80,6 +80,15 @@ class W:
     x: int
 
 
+@dataclass
+class V:
+    __version__: ClassVar[int] = 1
+    w: W
+    s: list
+    t: tuple
+    c: int
+
+
 @pytest.mark.usefixtures("recalculate_patterns")
 class TestSerDe:
     def test_ser_primitives(self):
@@ -104,17 +113,34 @@ class TestSerDe:
         e = serialize(i)
         assert i == e
 
-    def test_ser_iterables(self):
+    def test_ser_collections(self):
         i = [1, 2]
-        e = serialize(i)
+        e = deserialize(serialize(i))
         assert i == e
 
         i = ("a", "b", "a", "c")
-        e = serialize(i)
+        e = deserialize(serialize(i))
         assert i == e
 
         i = {2, 3}
-        e = serialize(i)
+        e = deserialize(serialize(i))
+        assert i == e
+
+        i = frozenset({6, 7})
+        e = deserialize(serialize(i))
+        assert i == e
+
+    def test_der_collections_compat(self):
+        i = [1, 2]
+        e = deserialize(i)
+        assert i == e
+
+        i = ("a", "b", "a", "c")
+        e = deserialize(i)
+        assert i == e
+
+        i = {2, 3}
+        e = deserialize(i)
         assert i == e
 
     def test_ser_plain_dict(self):
@@ -248,3 +274,38 @@ class TestSerDe:
                     import_string(s)
                 except ImportError:
                     raise AttributeError(f"{s} cannot be imported (located in 
{name})")
+
+    def test_stringify(self):
+        i = V(W(10), ["l1", "l2"], (1, 2), 10)
+        e = serialize(i)
+        s = deserialize(e, full=False)
+
+        assert f"{qualname(V)}@version={V.__version__}" in s
+        # asdict from dataclasses removes class information
+        assert "w={'x': 10}" in s
+        assert "s=['l1', 'l2']" in s
+        assert "t=(1,2)" in s
+        assert "c=10" in s
+        e["__data__"]["t"] = (1, 2)
+
+        s = deserialize(e, full=False)
+
+    @pytest.mark.parametrize(
+        "obj, expected",
+        [
+            (
+                Z(10),
+                {
+                    "__classname__": "tests.serialization.test_serde.Z",
+                    "__version__": 1,
+                    "__data__": {"x": 10},
+                },
+            ),
+            (
+                W(2),
+                {"__classname__": "tests.serialization.test_serde.W", 
"__version__": 2, "__data__": {"x": 2}},
+            ),
+        ],
+    )
+    def test_serialized_data(self, obj, expected):
+        assert expected == serialize(obj)
diff --git a/tests/utils/test_json.py b/tests/utils/test_json.py
index bcf11dff10..38eae2780a 100644
--- a/tests/utils/test_json.py
+++ b/tests/utils/test_json.py
@@ -78,3 +78,20 @@ class TestXComEncoder:
         s = json.dumps(u, cls=utils_json.XComEncoder)
         o = json.loads(s, cls=utils_json.XComDecoder, 
object_hook=utils_json.XComDecoder.orm_object_hook)
         assert o == 
f"{U.__module__}.{U.__qualname__}@version={U.__version__}(x={x})"
+
+    def test_collections(self):
+        i = [1, 2]
+        e = json.loads(json.dumps(i, cls=utils_json.XComEncoder), 
cls=utils_json.XComDecoder)
+        assert i == e
+
+        i = ("a", "b", "a", "c")
+        e = json.loads(json.dumps(i, cls=utils_json.XComEncoder), 
cls=utils_json.XComDecoder)
+        assert i == e
+
+        i = {2, 3}
+        e = json.loads(json.dumps(i, cls=utils_json.XComEncoder), 
cls=utils_json.XComDecoder)
+        assert i == e
+
+        i = frozenset({6, 7})
+        e = json.loads(json.dumps(i, cls=utils_json.XComEncoder), 
cls=utils_json.XComDecoder)
+        assert i == e

Reply via email to