bolkedebruin commented on code in PR #27540:
URL: https://github.com/apache/airflow/pull/27540#discussion_r1025093529


##########
airflow/utils/json.py:
##########
@@ -123,3 +138,113 @@ def dumps(self, obj, **kwargs):
 
     def loads(self, s: str | bytes, **kwargs):
         return json.loads(s, **kwargs)
+
+
+# for now separate as AirflowJsonEncoder is non-standard
+class XComEncoder(json.JSONEncoder):
+    """This encoder serializes any object that has attr, dataclass or a custom serializer."""
+
+    def default(self, o: object) -> dict:
+        from airflow.serialization.serialized_objects import BaseSerialization
+
+        dct = {
+            CLASSNAME: o.__module__ + "." + o.__class__.__qualname__,
+            VERSION: getattr(o.__class__, "version", DEFAULT_VERSION),
+        }
+
+        if hasattr(o, "serialize"):
+            dct[DATA] = getattr(o, "serialize")()
+            return dct
+        elif dataclasses.is_dataclass(o.__class__):
+            data = dataclasses.asdict(o)
+            dct[DATA] = BaseSerialization.serialize(data)
+            return dct
+        elif attr.has(o.__class__):
+            # Only include attributes which we can pass back to the classes constructor
+            data = attr.asdict(o, recurse=True, filter=lambda a, v: a.init)  # type: ignore[arg-type]
+            dct[DATA] = BaseSerialization.serialize(data)
+            return dct
+        else:
+            return super().default(o)
+
+    def encode(self, o: Any) -> str:
+        if isinstance(o, dict) and CLASSNAME in o:
+            raise AttributeError(f"reserved key {CLASSNAME} found in dict to serialize")
+
+        return super().encode(o)
+
+
+class XComDecoder(json.JSONDecoder):
+    """
+    This decoder deserializes dicts to objects if they contain
+    the `__classname__` key otherwise it will return the dict
+    as is.
+    """
+
+    def __init__(self, *args, **kwargs) -> None:
+        if not kwargs.get("object_hook"):
+            kwargs["object_hook"] = self.object_hook
+
+        super().__init__(*args, **kwargs)
+
+    @staticmethod
+    def object_hook(dct: dict) -> object:
+        dct = XComDecoder._convert(dct)
+
+        if CLASSNAME in dct and VERSION in dct:
+            from airflow.serialization.serialized_objects import BaseSerialization
+
+            cls = import_string(dct[CLASSNAME])
+
+            version = getattr(cls, "version", 0)
+            if hasattr(cls, "deserialize"):
+                return getattr(cls, "deserialize")(dct[DATA], version)

Review Comment:
   Good catch.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to