ashb commented on code in PR #27540:
URL: https://github.com/apache/airflow/pull/27540#discussion_r1020084477
##########
airflow/utils/json.py:
##########
@@ -123,3 +136,89 @@ def dumps(self, obj, **kwargs):
def loads(self, s: str | bytes, **kwargs):
return json.loads(s, **kwargs)
+
+
+# for now separate as AirflowJsonEncoder is non-standard
+class XComEncoder(json.JSONEncoder):
+ """This encoder allows serializes any object that has attr."""
+
+ def default(self, o: object) -> dict:
+ from airflow.serialization.serialized_objects import BaseSerialization
+
+ if hasattr(o, "serialize"):
+ classname = o.__module__ + "." + o.__class__.__name__
+ version = getattr(o.__class__, "version", DEFAULT_VERSION)
+ return {
+ CLASSNAME: classname,
+ VERSION: version,
+ DATA: getattr(o.__class__, "serialize")(o)
+ }
+ elif attr.has(o.__class__):
+ classname = o.__module__ + "." + o.__class__.__name__
+
+ version = getattr(o, "version", DEFAULT_VERSION)
+ # Only include attributes which we can pass back to the classes
constructor
+ data = attr.asdict(o, recurse=True, filter=lambda a, v: a.init) #
type: ignore[arg-type]
+ return {
+ CLASSNAME: classname,
+ VERSION: version,
+ DATA: BaseSerialization.serialize(data),
+ }
+ else:
+ return super().default(o)
+
+ def encode(self, o: Any) -> str:
+ if isinstance(o, dict) and CLASSNAME in o:
+ raise AttributeError(f"reserved key {CLASSNAME} found in dict to
serialize")
+
+ return super().encode(o)
+
+
+class XComDecoder(json.JSONDecoder):
+ """
+ This decoder deserializes dicts to objects if they contain
+ the `__classname__` key otherwise it will return the dict
+ as is.
+ """
+
+ def __init__(self, *args, **kwargs) -> None:
+ super().__init__(object_hook=self.object_hook, *args, **kwargs)
+
+ def object_hook(self, dct) -> object:
+ dct = self._convert(dct)
+
+ if CLASSNAME in dct and VERSION in dct:
+ from airflow.serialization.serialized_objects import
BaseSerialization
+
+ cls = import_string(dct[CLASSNAME])
+
+ version = getattr(cls, "version", 0)
+ if VERSION in dct and int(dct[VERSION]) < version:
+ raise TypeError(
+ "serialized version of %s is newer than module version (%s
> %s)",
+ dct[CLASSNAME],
+ dct[VERSION],
+ version,
+ )
+
+ if hasattr(cls, "deserialize"):
+ return getattr(cls, "deserialize")(dct[DATA])
Review Comment:
We should probably pass the serialized version (`dct[VERSION]`) to `deserialize` here too?
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]