This is an automated email from the ASF dual-hosted git repository.
jedcunningham pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 62f948cd30 Improve XComObjectStorageBackend implementation (#38608)
62f948cd30 is described below
commit 62f948cd309f4adeb6b15a2b634a66bfc87159cc
Author: Tzu-ping Chung <[email protected]>
AuthorDate: Wed Apr 3 06:00:35 2024 +0800
Improve XComObjectStorageBackend implementation (#38608)
Repeated configuration access is moved to cached functions so the string
literals don't need to be written repeatedly. The path configuration is
made mandatory since it effectively already is; using this backend without
a path configured is most likely an unintended user error.
Various functions are rewritten to use early returns and more localized
try-except blocks to improve code quality.
---
airflow/providers/common/io/xcom/backend.py | 99 +++++++++++++-------------
tests/providers/common/io/xcom/test_backend.py | 20 ++++--
2 files changed, 64 insertions(+), 55 deletions(-)
diff --git a/airflow/providers/common/io/xcom/backend.py
b/airflow/providers/common/io/xcom/backend.py
index b2416862ee..163e15e00d 100644
--- a/airflow/providers/common/io/xcom/backend.py
+++ b/airflow/providers/common/io/xcom/backend.py
@@ -16,6 +16,7 @@
# under the License.
from __future__ import annotations
+import contextlib
import json
import uuid
from typing import TYPE_CHECKING, Any, TypeVar
@@ -23,6 +24,7 @@ from urllib.parse import urlsplit
import fsspec.utils
+from airflow.compat.functools import cache
from airflow.configuration import conf
from airflow.io.path import ObjectStoragePath
from airflow.models.xcom import BaseXCom
@@ -65,6 +67,21 @@ def _get_compression_suffix(compression: str) -> str:
raise ValueError(f"Compression {compression} is not supported. Make sure
it is installed.")
+@cache
+def _get_base_path() -> ObjectStoragePath:
+ return ObjectStoragePath(conf.get_mandatory_value(SECTION,
"xcom_objectstorage_path"))
+
+
+@cache
+def _get_compression() -> str | None:
+ return conf.get(SECTION, "xcom_objectstorage_compression", fallback=None)
or None
+
+
+@cache
+def _get_threshold() -> int:
+ return conf.getint(SECTION, "xcom_objectstorage_threshold", fallback=-1)
+
+
class XComObjectStorageBackend(BaseXCom):
"""XCom backend that stores data in an object store or database depending
on the size of the data.
@@ -75,30 +92,24 @@ class XComObjectStorageBackend(BaseXCom):
"""
@staticmethod
- def _get_key(data: str) -> str:
- """Get the key from the url and normalizes it to be relative to the
configured path.
+ def _get_full_path(data: str) -> ObjectStoragePath:
+ """Get the path from stored value.
:raises ValueError: if the key is not relative to the configured path
:raises TypeError: if the url is not a valid url or cannot be split
"""
- path = conf.get(SECTION, "xcom_objectstorage_path", fallback="")
- p = ObjectStoragePath(path)
+ p = _get_base_path()
# normalize the path
- path = str(p)
-
try:
url = urlsplit(data)
except AttributeError:
- raise TypeError(f"Not a valid url: {data}")
+ raise TypeError(f"Not a valid url: {data}") from None
if url.scheme:
- k = ObjectStoragePath(data)
-
- if _is_relative_to(k, p) is False:
+ if not _is_relative_to(ObjectStoragePath(data), p):
raise ValueError(f"Invalid key: {data}")
- else:
- return data.replace(path, "", 1).lstrip("/")
+ return p / data.replace(str(p), "", 1).lstrip("/")
raise ValueError(f"Not a valid url: {data}")
@@ -115,61 +126,47 @@ class XComObjectStorageBackend(BaseXCom):
# we will always serialize ourselves and not by BaseXCom as the
deserialize method
# from BaseXCom accepts only XCom objects and not the value directly
s_val = json.dumps(value, cls=XComEncoder).encode("utf-8")
- path = conf.get(SECTION, "xcom_objectstorage_path", fallback="")
- compression = conf.get(SECTION, "xcom_objectstorage_compression",
fallback=None)
- if compression:
- suffix = "." + _get_compression_suffix(compression)
+ if compression := _get_compression():
+ suffix = f".{_get_compression_suffix(compression)}"
else:
suffix = ""
- compression = None
- threshold = conf.getint(SECTION, "xcom_objectstorage_threshold",
fallback=-1)
-
- if path and -1 < threshold < len(s_val):
- # safeguard against collisions
- while True:
- p = ObjectStoragePath(path) /
f"{dag_id}/{run_id}/{task_id}/{str(uuid.uuid4())}{suffix}"
- if not p.exists():
- break
+ threshold = _get_threshold()
+ if threshold < 0 or len(s_val) < threshold: # Either no threshold or
value is small enough.
+ return s_val
- if not p.parent.exists():
- p.parent.mkdir(parents=True, exist_ok=True)
+ base_path = _get_base_path()
+ while True: # Safeguard against collisions.
+ p = base_path.joinpath(dag_id, run_id, task_id,
f"{uuid.uuid4()}{suffix}")
+ if not p.exists():
+ break
+ p.parent.mkdir(parents=True, exist_ok=True)
- with p.open(mode="wb", compression=compression) as f:
- f.write(s_val)
-
- return BaseXCom.serialize_value(str(p))
- else:
- return s_val
+ with p.open(mode="wb", compression=compression) as f:
+ f.write(s_val)
+ return BaseXCom.serialize_value(str(p))
@staticmethod
- def deserialize_value(
- result: XCom,
- ) -> Any:
+ def deserialize_value(result: XCom) -> Any:
"""Deserializes the value from the database or object storage.
Compression is inferred from the file extension.
"""
data = BaseXCom.deserialize_value(result)
- path = conf.get(SECTION, "xcom_objectstorage_path", fallback="")
-
try:
- p = ObjectStoragePath(path) /
XComObjectStorageBackend._get_key(data)
- return json.load(p.open(mode="rb", compression="infer"),
cls=XComDecoder)
- except TypeError:
+ path = XComObjectStorageBackend._get_full_path(data)
+ except (TypeError, ValueError): # Likely value stored directly in the
database.
return data
- except ValueError:
+ try:
+ with path.open(mode="rb", compression="infer") as f:
+ return json.load(f, cls=XComDecoder)
+ except (TypeError, ValueError):
return data
@staticmethod
def purge(xcom: XCom, session: Session) -> None:
- path = conf.get(SECTION, "xcom_objectstorage_path", fallback="")
- if isinstance(xcom.value, str):
- try:
- p = ObjectStoragePath(path) /
XComObjectStorageBackend._get_key(xcom.value)
- p.unlink(missing_ok=True)
- except TypeError:
- pass
- except ValueError:
- pass
+ if not isinstance(xcom.value, str):
+ return
+ with contextlib.suppress(TypeError, ValueError):
+
XComObjectStorageBackend._get_full_path(xcom.value).unlink(missing_ok=True)
diff --git a/tests/providers/common/io/xcom/test_backend.py
b/tests/providers/common/io/xcom/test_backend.py
index 008394f365..2da2d6fecd 100644
--- a/tests/providers/common/io/xcom/test_backend.py
+++ b/tests/providers/common/io/xcom/test_backend.py
@@ -20,7 +20,6 @@ from __future__ import annotations
import pytest
import airflow.models.xcom
-from airflow.io.path import ObjectStoragePath
from airflow.models.xcom import BaseXCom, resolve_xcom_backend
from airflow.operators.empty import EmptyOperator
from airflow.providers.common.io.xcom.backend import XComObjectStorageBackend
@@ -42,6 +41,19 @@ def reset_db():
db.clear_db_xcom()
[email protected](autouse=True)
+def reset_cache():
+ from airflow.providers.common.io.xcom import backend
+
+ backend._get_base_path.cache_clear()
+ backend._get_compression.cache_clear()
+ backend._get_threshold.cache_clear()
+ yield
+ backend._get_base_path.cache_clear()
+ backend._get_compression.cache_clear()
+ backend._get_threshold.cache_clear()
+
+
@pytest.fixture
def task_instance(create_task_instance_of_operator):
return create_task_instance_of_operator(
@@ -121,7 +133,7 @@ class TestXComObjectStorageBackend:
)
data = BaseXCom.deserialize_value(res)
- p = ObjectStoragePath(self.path) /
XComObjectStorageBackend._get_key(data)
+ p = XComObjectStorageBackend._get_full_path(data)
assert p.exists() is True
value = XCom.get_value(
@@ -166,7 +178,7 @@ class TestXComObjectStorageBackend:
)
data = BaseXCom.deserialize_value(res)
- p = ObjectStoragePath(self.path) /
XComObjectStorageBackend._get_key(data)
+ p = XComObjectStorageBackend._get_full_path(data)
assert p.exists() is True
XCom.clear(
@@ -205,7 +217,7 @@ class TestXComObjectStorageBackend:
)
data = BaseXCom.deserialize_value(res)
- p = ObjectStoragePath(self.path) /
XComObjectStorageBackend._get_key(data)
+ p = XComObjectStorageBackend._get_full_path(data)
assert p.exists() is True
assert p.suffix == ".gz"