ashb commented on code in PR #27540:
URL: https://github.com/apache/airflow/pull/27540#discussion_r1025069885
##########
airflow/decorators/base.py:
##########
@@ -207,17 +209,29 @@ def __init__(
         super().__init__(task_id=task_id, **kwargs_to_upstream, **kwargs)

     def execute(self, context: Context):
+        # todo make this more generic (move to prepare_lineage) so it deals with non taskflow operators
+        # as well
+        for arg in chain(self.op_args, self.op_kwargs.values()):
+            if isinstance(arg, Dataset):
+                self.inlets.append(arg)
         return_value = super().execute(context)
         return self._handle_output(return_value=return_value, context=context, xcom_push=self.xcom_push)

     def _handle_output(self, return_value: Any, context: Context, xcom_push: Callable):
         """
         Handles logic for whether a decorator needs to push a single return value or multiple return values.
+        It sets outlets if any datasets are found in the returned value(s)
Review Comment:
Doc style nit. The first paragraph of a docstring is used as the summary, so there
should be a blank line before this new sentence. Doesn't matter here as this is a
private method, so it wouldn't appear in docs anyway, but it's a good habit to get into.
```suggestion

        It sets outlets if any datasets are found in the returned value(s)
```
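For reference, a standalone sketch of the convention being suggested (the class name and the plain typing imports are illustrative simplifications, not the PR's code; only the blank line after the summary is the change being asked for):
```python
from typing import Any, Callable


class ExampleOperator:  # illustrative stand-in for the decorated-operator class
    def _handle_output(self, return_value: Any, context: dict, xcom_push: Callable):
        """
        Handles logic for whether a decorator needs to push a single return value or multiple return values.

        It sets outlets if any datasets are found in the returned value(s)
        """
```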
##########
airflow/models/xcom.py:
##########
@@ -620,32 +621,41 @@ def serialize_value(
         if conf.getboolean("core", "enable_xcom_pickling"):
             return pickle.dumps(value)
         try:
-            return json.dumps(value).encode("UTF-8")
-        except (ValueError, TypeError):
+            return json.dumps(value, cls=XComEncoder).encode("UTF-8")
+        except (ValueError, TypeError) as ex:
             log.error(
-                "Could not serialize the XCom value into JSON."
+                f"{ex}."
Review Comment:
```suggestion
"%s."
```
##########
airflow/models/xcom.py:
##########
@@ -620,32 +621,41 @@ def serialize_value(
         if conf.getboolean("core", "enable_xcom_pickling"):
             return pickle.dumps(value)
         try:
-            return json.dumps(value).encode("UTF-8")
-        except (ValueError, TypeError):
+            return json.dumps(value, cls=XComEncoder).encode("UTF-8")
+        except (ValueError, TypeError) as ex:
             log.error(
-                "Could not serialize the XCom value into JSON."
+                f"{ex}."
                 " If you are using pickle instead of JSON for XCom,"
                 " then you need to enable pickle support for XCom"
-                " in your airflow config."
+                " in your airflow config or make sure to decorate your"
+                " object with attr."
Review Comment:
```suggestion
" object with attr.",
ex
```
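Putting the two suggestions on this call together, the intent is to keep %-style lazy formatting and hand the exception to the logger instead of interpolating it with an f-string. A rough sketch of the resulting call (the raised exception is a stand-in, not the exact diff):
```python
import logging

log = logging.getLogger(__name__)

try:
    raise TypeError("object of type Foo is not JSON serializable")  # stand-in failure
except (ValueError, TypeError) as ex:
    # %-style placeholders defer string formatting until the record is actually emitted.
    log.error(
        "%s."
        " If you are using pickle instead of JSON for XCom,"
        " then you need to enable pickle support for XCom"
        " in your airflow config or make sure to decorate your"
        " object with attr.",
        ex,
    )
```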
##########
airflow/utils/json.py:
##########
@@ -123,3 +138,113 @@ def dumps(self, obj, **kwargs):
     def loads(self, s: str | bytes, **kwargs):
         return json.loads(s, **kwargs)
+
+
+# for now separate as AirflowJsonEncoder is non-standard
+class XComEncoder(json.JSONEncoder):
+    """This encoder serializes any object that has attr, dataclass or a custom serializer."""
+
+    def default(self, o: object) -> dict:
+        from airflow.serialization.serialized_objects import BaseSerialization
+
+        dct = {
+            CLASSNAME: o.__module__ + "." + o.__class__.__qualname__,
+            VERSION: getattr(o.__class__, "version", DEFAULT_VERSION),
+        }
+
+        if hasattr(o, "serialize"):
+            dct[DATA] = getattr(o, "serialize")()
+            return dct
+        elif dataclasses.is_dataclass(o.__class__):
+            data = dataclasses.asdict(o)
+            dct[DATA] = BaseSerialization.serialize(data)
+            return dct
+        elif attr.has(o.__class__):
+            # Only include attributes which we can pass back to the classes constructor
+            data = attr.asdict(o, recurse=True, filter=lambda a, v: a.init)  # type: ignore[arg-type]
+            dct[DATA] = BaseSerialization.serialize(data)
+            return dct
+        else:
+            return super().default(o)
+
+    def encode(self, o: Any) -> str:
+        if isinstance(o, dict) and CLASSNAME in o:
+            raise AttributeError(f"reserved key {CLASSNAME} found in dict to serialize")
+
+        return super().encode(o)
+
+
+class XComDecoder(json.JSONDecoder):
+    """
+    This decoder deserializes dicts to objects if they contain
+    the `__classname__` key otherwise it will return the dict
+    as is.
+    """
+
+    def __init__(self, *args, **kwargs) -> None:
+        if not kwargs.get("object_hook"):
+            kwargs["object_hook"] = self.object_hook
+
+        super().__init__(*args, **kwargs)
+
+    @staticmethod
+    def object_hook(dct: dict) -> object:
+        dct = XComDecoder._convert(dct)
+
+        if CLASSNAME in dct and VERSION in dct:
+            from airflow.serialization.serialized_objects import BaseSerialization
+
+            cls = import_string(dct[CLASSNAME])
+
+            version = getattr(cls, "version", 0)
+            if hasattr(cls, "deserialize"):
+                return getattr(cls, "deserialize")(dct[DATA], version)
+
+            if VERSION in dct and int(dct[VERSION]) > version:
Review Comment:
We've already checked that VERSION is in the dict in the enclosing `if` above, and
shouldn't version always be an int? (I guess being "safe" about the type isn't a bad
thing though)
```suggestion
            if dct[VERSION] > version:
```
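A self-contained illustration of that comparison once the redundant membership test and the `int()` cast are dropped (toy class and values, not the PR's code; what the real branch does when the stored version is newer isn't visible in this hunk):
```python
CLASSNAME = "__classname__"
VERSION = "__version__"


class Example:
    version = 1  # class-level serialization version


dct = {CLASSNAME: "Example", VERSION: 2}  # a "row" written by a newer serializer

class_version = getattr(Example, "version", 0)
if dct[VERSION] > class_version:
    # Placeholder reaction; the actual handling lives below the hunk shown above.
    print(f"stored version {dct[VERSION]} is newer than class version {class_version}")
```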
##########
airflow/utils/json.py:
##########
@@ -123,3 +138,113 @@ def dumps(self, obj, **kwargs):
     def loads(self, s: str | bytes, **kwargs):
         return json.loads(s, **kwargs)
+
+
+# for now separate as AirflowJsonEncoder is non-standard
+class XComEncoder(json.JSONEncoder):
+    """This encoder serializes any object that has attr, dataclass or a custom serializer."""
+
+    def default(self, o: object) -> dict:
+        from airflow.serialization.serialized_objects import BaseSerialization
+
+        dct = {
+            CLASSNAME: o.__module__ + "." + o.__class__.__qualname__,
+            VERSION: getattr(o.__class__, "version", DEFAULT_VERSION),
+        }
+
+        if hasattr(o, "serialize"):
+            dct[DATA] = getattr(o, "serialize")()
+            return dct
+        elif dataclasses.is_dataclass(o.__class__):
+            data = dataclasses.asdict(o)
+            dct[DATA] = BaseSerialization.serialize(data)
+            return dct
+        elif attr.has(o.__class__):
+            # Only include attributes which we can pass back to the classes constructor
+            data = attr.asdict(o, recurse=True, filter=lambda a, v: a.init)  # type: ignore[arg-type]
+            dct[DATA] = BaseSerialization.serialize(data)
+            return dct
+        else:
+            return super().default(o)
+
+    def encode(self, o: Any) -> str:
+        if isinstance(o, dict) and CLASSNAME in o:
+            raise AttributeError(f"reserved key {CLASSNAME} found in dict to serialize")
+
+        return super().encode(o)
+
+
+class XComDecoder(json.JSONDecoder):
+    """
+    This decoder deserializes dicts to objects if they contain
+    the `__classname__` key otherwise it will return the dict
+    as is.
+    """
+
+    def __init__(self, *args, **kwargs) -> None:
+        if not kwargs.get("object_hook"):
+            kwargs["object_hook"] = self.object_hook
+
+        super().__init__(*args, **kwargs)
+
+    @staticmethod
+    def object_hook(dct: dict) -> object:
+        dct = XComDecoder._convert(dct)
+
+        if CLASSNAME in dct and VERSION in dct:
+            from airflow.serialization.serialized_objects import BaseSerialization
+
+            cls = import_string(dct[CLASSNAME])
+
+            version = getattr(cls, "version", 0)
+            if hasattr(cls, "deserialize"):
+                return getattr(cls, "deserialize")(dct[DATA], version)
Review Comment:
This is the wrong version -- we want to pass the version stored with the "row"
(`dct[VERSION]`), not the class's current latest version.
```suggestion
return getattr(cls, "deserialize")(dct[DATA], dct[VERSION])
```
(But we probably need to do some re-factoring/re-ordering rather than just
apply this change)
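One possible shape for that re-ordering, sketched under the assumption that the stored version should be validated before any custom `deserialize` hook runs (the error message and the elided branches are placeholders, not the PR's code):
```python
@staticmethod
def object_hook(dct: dict) -> object:
    dct = XComDecoder._convert(dct)

    if CLASSNAME in dct and VERSION in dct:
        cls = import_string(dct[CLASSNAME])
        class_version = getattr(cls, "version", 0)

        # Refuse rows written by a newer serializer before doing anything else.
        if dct[VERSION] > class_version:
            raise TypeError(
                f"serialized version {dct[VERSION]} of {dct[CLASSNAME]} is newer than class version {class_version}"
            )

        if hasattr(cls, "deserialize"):
            # Pass the version the row was stored with, not the class's latest.
            return getattr(cls, "deserialize")(dct[DATA], dct[VERSION])

        ...  # dataclass / attrs handling via BaseSerialization, as in the rest of the hunk

    return dct
```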
##########
airflow/utils/json.py:
##########
@@ -40,6 +45,16 @@
log = logging.getLogger(__name__)
+CLASSNAME = "__classname__"
+VERSION = "__version__"
+DATA = "__data__"
+
+OLD_TYPE = "__type"
+OLD_SOURCE = "__source"
+OLD_DATA = "__var"
+
+DEFAULT_VERSION = "0"
Review Comment:
```suggestion
DEFAULT_VERSION = 0
```
I think
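For what it's worth, a quick illustration of why the int default composes better with the comparisons elsewhere in this file (toy values only):
```python
DEFAULT_VERSION = 0  # as an int it compares cleanly against class-level `version` attributes

stored_version = DEFAULT_VERSION   # what would end up in dct["__version__"]
class_version = 1                  # e.g. getattr(cls, "version", 0)

# With the string default "0", this comparison would raise TypeError on Python 3.
print(stored_version > class_version)  # False
```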
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]