This is an automated email from the ASF dual-hosted git repository.
rahulvats pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new efdca941103 Use BaseXCom serialize_value when objectstorage_threshold is less than given input (#49173)
efdca941103 is described below
commit efdca941103542835a944223cc52216c99b5db16
Author: GPK <[email protected]>
AuthorDate: Mon Apr 14 08:33:46 2025 +0100
Use BaseXCom serialize_value when objectstorage_threshold is less than given input (#49173)
Use BaseXCom serialize_value when objectstorage_threshold is less than given input (#49173)
---
.../airflow/providers/common/io/xcom/backend.py | 18 ++++++++++----
.../io/tests/unit/common/io/xcom/test_backend.py | 28 ++++++++++++++++++++++
2 files changed, 41 insertions(+), 5 deletions(-)
diff --git a/providers/common/io/src/airflow/providers/common/io/xcom/backend.py b/providers/common/io/src/airflow/providers/common/io/xcom/backend.py
index 5fc86b8f182..c5144222e5c 100644
--- a/providers/common/io/src/airflow/providers/common/io/xcom/backend.py
+++ b/providers/common/io/src/airflow/providers/common/io/xcom/backend.py
@@ -118,8 +118,7 @@ class XComObjectStorageBackend(BaseXCom):
run_id: str | None = None,
map_index: int | None = None,
) -> bytes | str:
- # we will always serialize ourselves and not by BaseXCom as the deserialize method
- # from BaseXCom accepts only XCom objects and not the value directly
+ # We will use this serialized value to write to the object store.
s_val = json.dumps(value, cls=XComEncoder)
s_val_encoded = s_val.encode("utf-8")
@@ -131,7 +130,7 @@ class XComObjectStorageBackend(BaseXCom):
threshold = _get_threshold()
if threshold < 0 or len(s_val_encoded) < threshold: # Either no threshold or value is small enough.
if AIRFLOW_V_3_0_PLUS:
- return s_val
+ return BaseXCom.serialize_value(value)
else:
# TODO: Remove this branch once we drop support for Airflow 2
# This is for Airflow 2.10 where the value is expected to be bytes
@@ -160,9 +159,18 @@ class XComObjectStorageBackend(BaseXCom):
Compression is inferred from the file extension.
"""
- data = BaseXCom.deserialize_value(result)
+ base_xcom_deser_result = BaseXCom.deserialize_value(result)
+ data = base_xcom_deser_result
+
+ if not AIRFLOW_V_3_0_PLUS:
+ try:
+ # When XComObjectStorageBackend is used, xcom value will be serialized using json.dumps
+ # likely, we need to deserialize it using json.loads
+ data = json.loads(base_xcom_deser_result, cls=XComDecoder)
+ except (TypeError, ValueError):
+ pass
try:
- path = XComObjectStorageBackend._get_full_path(data)
+ path = XComObjectStorageBackend._get_full_path(base_xcom_deser_result)
except (TypeError, ValueError): # Likely value stored directly in the database.
return data
try:
diff --git a/providers/common/io/tests/unit/common/io/xcom/test_backend.py b/providers/common/io/tests/unit/common/io/xcom/test_backend.py
index 802106024b8..99fb46a66c7 100644
--- a/providers/common/io/tests/unit/common/io/xcom/test_backend.py
+++ b/providers/common/io/tests/unit/common/io/xcom/test_backend.py
@@ -17,6 +17,8 @@
# under the License.
from __future__ import annotations
+from unittest.mock import MagicMock
+
import pytest
import airflow.models.xcom
@@ -102,6 +104,8 @@ class TestXComObjectStorageBackend:
)
if AIRFLOW_V_3_0_PLUS:
+ # When using XComObjectStorageBackend, the value is stored in the db is serialized with json dumps
+ # so we need to mimic that same behavior below.
mock_supervisor_comms.get_message.return_value = XComResult(
key="return_value", value={"key": "value"}
)
@@ -362,3 +366,27 @@ class TestXComObjectStorageBackend:
)
assert value == {"key": "superlargevalue" * 100}
+
+ @pytest.mark.parametrize(
+ "value, expected_value",
+ [
+ pytest.param(1, 1, id="int"),
+ pytest.param(1.0, 1.0, id="float"),
+ pytest.param("string", "string", id="str"),
+ pytest.param(True, True, id="bool"),
+ pytest.param({"key": "value"}, {"key": "value"}, id="dict"),
+ pytest.param({"key": {"key": "value"}}, {"key": {"key": "value"}}, id="nested_dict"),
+ pytest.param([1, 2, 3], [1, 2, 3], id="list"),
+ pytest.param((1, 2, 3), (1, 2, 3), id="tuple"),
+ pytest.param(None, None, id="none"),
+ ],
+ )
+ def test_serialization_deserialization_basic(self, value, expected_value):
+ XCom = resolve_xcom_backend()
+ airflow.models.xcom.XCom = XCom
+
+ serialized_data = XCom.serialize_value(value)
+ mock_xcom_ser = MagicMock(value=serialized_data)
+ deserialized_data = XCom.deserialize_value(mock_xcom_ser)
+
+ assert deserialized_data == expected_value