This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 58e26d9df4 Fix XCom deserialization when it contains nonprimitive
values (#30819)
58e26d9df4 is described below
commit 58e26d9df42f10e4e2b46cd26c6832547945789b
Author: Hussein Awala <[email protected]>
AuthorDate: Sun Apr 23 12:38:38 2023 +0200
Fix XCom deserialization when it contains nonprimitive values (#30819)
* Add testcase to show issue with deserialization
* fix XCom deserialization
---------
Co-authored-by: utkarsh sharma <[email protected]>
---
airflow/serialization/serde.py | 4 +++-
tests/utils/test_json.py | 23 +++++++++++++++++++++++
2 files changed, 26 insertions(+), 1 deletion(-)
diff --git a/airflow/serialization/serde.py b/airflow/serialization/serde.py
index 5b19593984..a62ebc2bef 100644
--- a/airflow/serialization/serde.py
+++ b/airflow/serialization/serde.py
@@ -202,7 +202,9 @@ def deserialize(o: T | None, full=True, type_hint: Any =
None) -> object:
return col
if not isinstance(o, dict):
- raise TypeError()
+ # if o is not a dict, then it's already deserialized
+ # in this case we should return it as is
+ return o
o = _convert(o)
diff --git a/tests/utils/test_json.py b/tests/utils/test_json.py
index 38eae2780a..49d089c1bd 100644
--- a/tests/utils/test_json.py
+++ b/tests/utils/test_json.py
@@ -23,6 +23,7 @@ from datetime import date, datetime
from typing import ClassVar
import numpy as np
+import pandas
import pendulum
import pytest
@@ -72,6 +73,28 @@ class TestXComEncoder:
obj = json.loads(s, cls=utils_json.XComDecoder)
assert dataset.uri == obj.uri
+ def test_encode_xcom_with_nested_dict_pandas(self):
+ def _compare(data, obj):
+ assert len(data) == len(obj)
+ for key in data:
+ if isinstance(data[key], dict):
+ return _compare(data[key], obj[key])
+ if isinstance(data[key], pandas.DataFrame):
+ assert data[key].equals(obj[key])
+ else:
+ assert data[key] == obj[key]
+
+ data = (
+ {"foo": 1, "bar": 2, "baz": pandas.DataFrame(data={"col1": [1, 2],
"col2": [3, 4]})},
+ {"d1": {"d2": pandas.DataFrame(data={"col1": [1, 2], "col2": [3,
4]})}},
+ {"d1": {"d2": {"d3": pandas.DataFrame(data={"col1": [1, 2],
"col2": [3, 4]})}}},
+ )
+ s = json.dumps(data, cls=utils_json.XComEncoder)
+ obj = json.loads(s, cls=utils_json.XComDecoder)
+ assert len(data) == len(obj)
+ for i in range(len(data)):
+ _compare(data[i], obj[i])
+
def test_orm_deserialize(self):
x = 14
u = U(x=x)