This is an automated email from the ASF dual-hosted git repository.
bolke pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 6e7340f9c7 Improve serialization of tuples and sets (#29019)
6e7340f9c7 is described below
commit 6e7340f9c71534df5be70bd7812d4de3015e78b7
Author: Bolke de Bruin <[email protected]>
AuthorDate: Mon Mar 20 19:19:55 2023 +0100
Improve serialization of tuples and sets (#29019)
Sets cannot be encoded into JSON directly, and tuples lose information when
serialized into JSON directly (they come back as lists).
We now serialize both into a dict carrying the class name, version, and data, so they are encoded and decoded properly.
---------
Co-authored-by: Tzu-ping Chung <[email protected]>
---
airflow/datasets/__init__.py | 2 +-
airflow/serialization/serde.py | 87 ++++++++++++++++++++--------
airflow/serialization/serializers/builtin.py | 59 +++++++++++++++++++
airflow/utils/json.py | 4 ++
tests/decorators/test_python.py | 4 +-
tests/serialization/test_serde.py | 69 ++++++++++++++++++++--
tests/utils/test_json.py | 17 ++++++
7 files changed, 210 insertions(+), 32 deletions(-)
diff --git a/airflow/datasets/__init__.py b/airflow/datasets/__init__.py
index aa50fd16ec..0dc635a00b 100644
--- a/airflow/datasets/__init__.py
+++ b/airflow/datasets/__init__.py
@@ -29,7 +29,7 @@ class Dataset:
uri: str = attr.field(validator=[attr.validators.min_len(1),
attr.validators.max_len(3000)])
extra: dict[str, Any] | None = None
- version: ClassVar[int] = 1
+ __version__: ClassVar[int] = 1
@uri.validator
def _check_uri(self, attr, uri: str):
diff --git a/airflow/serialization/serde.py b/airflow/serialization/serde.py
index dd4913e567..55d63ca170 100644
--- a/airflow/serialization/serde.py
+++ b/airflow/serialization/serde.py
@@ -56,6 +56,7 @@ S = Union[list, tuple, set]
_serializers: dict[str, ModuleType] = {}
_deserializers: dict[str, ModuleType] = {}
+_stringifiers: dict[str, ModuleType] = {}
_extra_allowed: set[str] = set()
_primitives = (int, bool, float, str)
@@ -67,8 +68,16 @@ def encode(cls: str, version: int, data: T) -> dict[str, str
| int | T]:
return {CLASSNAME: cls, VERSION: version, DATA: data}
-def decode(d: dict[str, str | int | T]) -> tuple:
- return d[CLASSNAME], d[VERSION], d.get(DATA, None)
+def decode(d: dict[str, Any]) -> tuple[str, int, Any]:
+ classname = d[CLASSNAME]
+ version = d[VERSION]
+
+ if not isinstance(classname, str) or not isinstance(version, int):
+ raise ValueError(f"cannot decode {d!r}")
+
+ data = d.get(DATA)
+
+ return classname, version, data
def serialize(o: object, depth: int = 0) -> U | None:
@@ -85,6 +94,11 @@ def serialize(o: object, depth: int = 0) -> U | None:
2. A registered serializer in the namespace of
``airflow.serialization.serializers``
3. Annotations from attr or dataclass.
+ Limitations: attr and dataclass objects can lose type information for
nested objects
+ as they do not store this when calling ``asdict``. This means that at
deserialization values
+ will be deserialized as a dict as opposed to reinstating the object.
Provide
+ your own serializer to work around this.
+
:param o: The object to serialize.
:param depth: Private tracker for nested serialization.
:raise TypeError: A serializer cannot be found.
@@ -105,14 +119,8 @@ def serialize(o: object, depth: int = 0) -> U | None:
return o
- # tuples and plain dicts are iterated over recursively
- if isinstance(o, _builtin_collections):
- s = [serialize(d, depth + 1) for d in o]
- if isinstance(o, tuple):
- return tuple(s)
- if isinstance(o, set):
- return set(s)
- return s
+ if isinstance(o, list):
+ return [serialize(d, depth + 1) for d in o]
if isinstance(o, dict):
if CLASSNAME in o or SCHEMA_ID in o:
@@ -148,6 +156,7 @@ def serialize(o: object, depth: int = 0) -> U | None:
# dataclasses
if dataclasses.is_dataclass(cls):
+ # fixme: unfortunately using asdict with nested dataclasses it looses
information
data = dataclasses.asdict(o)
dct[DATA] = serialize(data, depth + 1)
return dct
@@ -181,8 +190,16 @@ def deserialize(o: T | None, full=True, type_hint: Any =
None) -> object:
if isinstance(o, _primitives):
return o
+ # tuples, sets are included here for backwards compatibility
if isinstance(o, _builtin_collections):
- return [deserialize(d) for d in o]
+ col = [deserialize(d) for d in o]
+ if isinstance(o, tuple):
+ return tuple(col)
+
+ if isinstance(o, set):
+ return set(col)
+
+ return col
if not isinstance(o, dict):
raise TypeError()
@@ -196,8 +213,8 @@ def deserialize(o: T | None, full=True, type_hint: Any =
None) -> object:
# custom deserialization starts here
cls: Any
version = 0
- value: Any
- classname: str
+ value: Any = None
+ classname = ""
if type_hint:
cls = type_hint
@@ -207,19 +224,22 @@ def deserialize(o: T | None, full=True, type_hint: Any =
None) -> object:
if CLASSNAME in o and VERSION in o:
classname, version, value = decode(o)
- if not _match(classname) and classname not in _extra_allowed:
- raise ImportError(
- f"{classname} was not found in allow list for deserialization
imports. "
- f"To allow it, add it to allowed_deserialization_classes in
the configuration"
- )
- if full:
- cls = import_string(classname)
+ if not classname:
+ raise TypeError("classname cannot be empty")
# only return string representation
if not full:
return _stringify(classname, version, value)
+ if not _match(classname) and classname not in _extra_allowed:
+ raise ImportError(
+ f"{classname} was not found in allow list for deserialization
imports. "
+ f"To allow it, add it to allowed_deserialization_classes in the
configuration"
+ )
+
+ cls = import_string(classname)
+
# registered deserializer
if classname in _deserializers:
return _deserializers[classname].deserialize(classname, version,
deserialize(value))
@@ -258,35 +278,45 @@ def _match(classname: str) -> bool:
def _stringify(classname: str, version: int, value: T | None) -> str:
+ """Convert a previously serialized object in a somewhat human-readable
format.
+
+ This function is not designed to be exact, and will not extensively
traverse
+ the whole tree of an object.
+ """
+ if classname in _stringifiers:
+ return _stringifiers[classname].stringify(classname, version, value)
+
s = f"{classname}@version={version}("
if isinstance(value, _primitives):
s += f"{value})"
elif isinstance(value, _builtin_collections):
- s += ",".join(str(serialize(value)))
+ # deserialized values can be != str
+ s += ",".join(str(deserialize(value, full=False)))
elif isinstance(value, dict):
for k, v in value.items():
- s += f"{k}={serialize(v)},"
+ s += f"{k}={deserialize(v, full=False)},"
s = s[:-1] + ")"
return s
def _register():
- """Register builtin serializers and deserializers for types that don't
have any themselves"""
+ """Register builtin serializers and deserializers for types that don't
have any themselves."""
_serializers.clear()
_deserializers.clear()
+ _stringifiers.clear()
with Stats.timer("serde.load_serializers") as timer:
for _, name, _ in iter_namespace(airflow.serialization.serializers):
name = import_module(name)
- for s in getattr(name, "serializers", list()):
+ for s in getattr(name, "serializers", ()):
if not isinstance(s, str):
s = qualname(s)
if s in _serializers and _serializers[s] != name:
raise AttributeError(f"duplicate {s} for serialization in
{name} and {_serializers[s]}")
log.debug("registering %s for serialization", s)
_serializers[s] = name
- for d in getattr(name, "deserializers", list()):
+ for d in getattr(name, "deserializers", ()):
if not isinstance(d, str):
d = qualname(d)
if d in _deserializers and _deserializers[d] != name:
@@ -294,6 +324,13 @@ def _register():
log.debug("registering %s for deserialization", d)
_deserializers[d] = name
_extra_allowed.add(d)
+ for c in getattr(name, "stringifiers", ()):
+ if not isinstance(c, str):
+ c = qualname(c)
+ if c in _deserializers and _deserializers[c] != name:
+ raise AttributeError(f"duplicate {c} for stringifiers in
{name} and {_stringifiers[c]}")
+ log.debug("registering %s for stringifying", c)
+ _stringifiers[c] = name
log.info("loading serializers took %.3f seconds", timer.duration)
diff --git a/airflow/serialization/serializers/builtin.py
b/airflow/serialization/serializers/builtin.py
new file mode 100644
index 0000000000..52ecc282ab
--- /dev/null
+++ b/airflow/serialization/serializers/builtin.py
@@ -0,0 +1,59 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, cast
+
+from airflow.utils.module_loading import qualname
+
+if TYPE_CHECKING:
+ from airflow.serialization.serde import U
+
+__version__ = 1
+
+serializers = ["builtins.frozenset", "builtins.set", "builtins.tuple"]
+deserializers = serializers
+stringifiers = serializers
+
+
+def serialize(o: object) -> tuple[U, str, int, bool]:
+ return list(cast(list, o)), qualname(o), __version__, True
+
+
+def deserialize(classname: str, version: int, data: list) -> tuple | set |
frozenset:
+ if version > __version__:
+ raise TypeError("serialized version is newer than class version")
+
+ if classname == qualname(tuple):
+ return tuple(data)
+
+ if classname == qualname(set):
+ return set(data)
+
+ if classname == qualname(frozenset):
+ return frozenset(data)
+
+ raise TypeError(f"do not know how to deserialize {classname}")
+
+
+def stringify(classname: str, version: int, data: list) -> str:
+ if classname not in stringifiers:
+ raise TypeError(f"do not know how to stringify {classname}")
+
+ s = ",".join(str(d) for d in data)
+ return f"({s})"
diff --git a/airflow/utils/json.py b/airflow/utils/json.py
index 0a497aab89..f1e65307e5 100644
--- a/airflow/utils/json.py
+++ b/airflow/utils/json.py
@@ -88,6 +88,10 @@ class XComEncoder(json.JSONEncoder):
if isinstance(o, dict) and (CLASSNAME in o or SCHEMA_ID in o):
raise AttributeError(f"reserved key {CLASSNAME} found in dict to
serialize")
+ # tuples are not preserved by std python serializer
+ if isinstance(o, tuple):
+ o = self.default(o)
+
return super().encode(o)
diff --git a/tests/decorators/test_python.py b/tests/decorators/test_python.py
index 20dc883106..354ac5fc2e 100644
--- a/tests/decorators/test_python.py
+++ b/tests/decorators/test_python.py
@@ -179,7 +179,7 @@ class TestAirflowTaskDecorator(BasePythonTest):
ti = dr.get_task_instances()[0]
assert res.operator.multiple_outputs is False
- assert ti.xcom_pull() == [8, 4]
+ assert ti.xcom_pull() == (8, 4)
assert ti.xcom_pull(key="return_value_0") is None
assert ti.xcom_pull(key="return_value_1") is None
@@ -197,7 +197,7 @@ class TestAirflowTaskDecorator(BasePythonTest):
ti = dr.get_task_instances()[0]
assert not ident.operator.multiple_outputs
- assert ti.xcom_pull() == [35, 36]
+ assert ti.xcom_pull() == (35, 36)
assert ti.xcom_pull(key="return_value_0") is None
assert ti.xcom_pull(key="return_value_1") is None
diff --git a/tests/serialization/test_serde.py
b/tests/serialization/test_serde.py
index 03aea4f7b8..51bfb97f82 100644
--- a/tests/serialization/test_serde.py
+++ b/tests/serialization/test_serde.py
@@ -80,6 +80,15 @@ class W:
x: int
+@dataclass
+class V:
+ __version__: ClassVar[int] = 1
+ w: W
+ s: list
+ t: tuple
+ c: int
+
+
@pytest.mark.usefixtures("recalculate_patterns")
class TestSerDe:
def test_ser_primitives(self):
@@ -104,17 +113,34 @@ class TestSerDe:
e = serialize(i)
assert i == e
- def test_ser_iterables(self):
+ def test_ser_collections(self):
i = [1, 2]
- e = serialize(i)
+ e = deserialize(serialize(i))
assert i == e
i = ("a", "b", "a", "c")
- e = serialize(i)
+ e = deserialize(serialize(i))
assert i == e
i = {2, 3}
- e = serialize(i)
+ e = deserialize(serialize(i))
+ assert i == e
+
+ i = frozenset({6, 7})
+ e = deserialize(serialize(i))
+ assert i == e
+
+ def test_der_collections_compat(self):
+ i = [1, 2]
+ e = deserialize(i)
+ assert i == e
+
+ i = ("a", "b", "a", "c")
+ e = deserialize(i)
+ assert i == e
+
+ i = {2, 3}
+ e = deserialize(i)
assert i == e
def test_ser_plain_dict(self):
@@ -248,3 +274,38 @@ class TestSerDe:
import_string(s)
except ImportError:
raise AttributeError(f"{s} cannot be imported (located in
{name})")
+
+ def test_stringify(self):
+ i = V(W(10), ["l1", "l2"], (1, 2), 10)
+ e = serialize(i)
+ s = deserialize(e, full=False)
+
+ assert f"{qualname(V)}@version={V.__version__}" in s
+ # asdict from dataclasses removes class information
+ assert "w={'x': 10}" in s
+ assert "s=['l1', 'l2']" in s
+ assert "t=(1,2)" in s
+ assert "c=10" in s
+ e["__data__"]["t"] = (1, 2)
+
+ s = deserialize(e, full=False)
+
+ @pytest.mark.parametrize(
+ "obj, expected",
+ [
+ (
+ Z(10),
+ {
+ "__classname__": "tests.serialization.test_serde.Z",
+ "__version__": 1,
+ "__data__": {"x": 10},
+ },
+ ),
+ (
+ W(2),
+ {"__classname__": "tests.serialization.test_serde.W",
"__version__": 2, "__data__": {"x": 2}},
+ ),
+ ],
+ )
+ def test_serialized_data(self, obj, expected):
+ assert expected == serialize(obj)
diff --git a/tests/utils/test_json.py b/tests/utils/test_json.py
index bcf11dff10..38eae2780a 100644
--- a/tests/utils/test_json.py
+++ b/tests/utils/test_json.py
@@ -78,3 +78,20 @@ class TestXComEncoder:
s = json.dumps(u, cls=utils_json.XComEncoder)
o = json.loads(s, cls=utils_json.XComDecoder,
object_hook=utils_json.XComDecoder.orm_object_hook)
assert o ==
f"{U.__module__}.{U.__qualname__}@version={U.__version__}(x={x})"
+
+ def test_collections(self):
+ i = [1, 2]
+ e = json.loads(json.dumps(i, cls=utils_json.XComEncoder),
cls=utils_json.XComDecoder)
+ assert i == e
+
+ i = ("a", "b", "a", "c")
+ e = json.loads(json.dumps(i, cls=utils_json.XComEncoder),
cls=utils_json.XComDecoder)
+ assert i == e
+
+ i = {2, 3}
+ e = json.loads(json.dumps(i, cls=utils_json.XComEncoder),
cls=utils_json.XComDecoder)
+ assert i == e
+
+ i = frozenset({6, 7})
+ e = json.loads(json.dumps(i, cls=utils_json.XComEncoder),
cls=utils_json.XComDecoder)
+ assert i == e