This is an automated email from the ASF dual-hosted git repository.

jedcunningham pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 62f948cd30 Improve XComObjectStorageBackend implementation (#38608)
62f948cd30 is described below

commit 62f948cd309f4adeb6b15a2b634a66bfc87159cc
Author: Tzu-ping Chung <[email protected]>
AuthorDate: Wed Apr 3 06:00:35 2024 +0800

    Improve XComObjectStorageBackend implementation (#38608)
    
    Repeated configuration access is moved to cached functions so the string
    literals don't need to be written repeatedly. The path configuration is
    made mandatory since it effectively already is; using this backend without
    a configured path is most likely an unintended user error.
    
    Various functions are rewritten to take advantage of early returns and
    more localized try-except blocks, improving code quality.
---
 airflow/providers/common/io/xcom/backend.py    | 99 +++++++++++++-------------
 tests/providers/common/io/xcom/test_backend.py | 20 ++++--
 2 files changed, 64 insertions(+), 55 deletions(-)

diff --git a/airflow/providers/common/io/xcom/backend.py 
b/airflow/providers/common/io/xcom/backend.py
index b2416862ee..163e15e00d 100644
--- a/airflow/providers/common/io/xcom/backend.py
+++ b/airflow/providers/common/io/xcom/backend.py
@@ -16,6 +16,7 @@
 # under the License.
 from __future__ import annotations
 
+import contextlib
 import json
 import uuid
 from typing import TYPE_CHECKING, Any, TypeVar
@@ -23,6 +24,7 @@ from urllib.parse import urlsplit
 
 import fsspec.utils
 
+from airflow.compat.functools import cache
 from airflow.configuration import conf
 from airflow.io.path import ObjectStoragePath
 from airflow.models.xcom import BaseXCom
@@ -65,6 +67,21 @@ def _get_compression_suffix(compression: str) -> str:
     raise ValueError(f"Compression {compression} is not supported. Make sure 
it is installed.")
 
 
+@cache
+def _get_base_path() -> ObjectStoragePath:
+    return ObjectStoragePath(conf.get_mandatory_value(SECTION, 
"xcom_objectstorage_path"))
+
+
+@cache
+def _get_compression() -> str | None:
+    return conf.get(SECTION, "xcom_objectstorage_compression", fallback=None) 
or None
+
+
+@cache
+def _get_threshold() -> int:
+    return conf.getint(SECTION, "xcom_objectstorage_threshold", fallback=-1)
+
+
 class XComObjectStorageBackend(BaseXCom):
     """XCom backend that stores data in an object store or database depending 
on the size of the data.
 
@@ -75,30 +92,24 @@ class XComObjectStorageBackend(BaseXCom):
     """
 
     @staticmethod
-    def _get_key(data: str) -> str:
-        """Get the key from the url and normalizes it to be relative to the 
configured path.
+    def _get_full_path(data: str) -> ObjectStoragePath:
+        """Get the path from stored value.
 
         :raises ValueError: if the key is not relative to the configured path
         :raises TypeError: if the url is not a valid url or cannot be split
         """
-        path = conf.get(SECTION, "xcom_objectstorage_path", fallback="")
-        p = ObjectStoragePath(path)
+        p = _get_base_path()
 
         # normalize the path
-        path = str(p)
-
         try:
             url = urlsplit(data)
         except AttributeError:
-            raise TypeError(f"Not a valid url: {data}")
+            raise TypeError(f"Not a valid url: {data}") from None
 
         if url.scheme:
-            k = ObjectStoragePath(data)
-
-            if _is_relative_to(k, p) is False:
+            if not _is_relative_to(ObjectStoragePath(data), p):
                 raise ValueError(f"Invalid key: {data}")
-            else:
-                return data.replace(path, "", 1).lstrip("/")
+            return p / data.replace(str(p), "", 1).lstrip("/")
 
         raise ValueError(f"Not a valid url: {data}")
 
@@ -115,61 +126,47 @@ class XComObjectStorageBackend(BaseXCom):
         # we will always serialize ourselves and not by BaseXCom as the 
deserialize method
         # from BaseXCom accepts only XCom objects and not the value directly
         s_val = json.dumps(value, cls=XComEncoder).encode("utf-8")
-        path = conf.get(SECTION, "xcom_objectstorage_path", fallback="")
-        compression = conf.get(SECTION, "xcom_objectstorage_compression", 
fallback=None)
 
-        if compression:
-            suffix = "." + _get_compression_suffix(compression)
+        if compression := _get_compression():
+            suffix = f".{_get_compression_suffix(compression)}"
         else:
             suffix = ""
-            compression = None
 
-        threshold = conf.getint(SECTION, "xcom_objectstorage_threshold", 
fallback=-1)
-
-        if path and -1 < threshold < len(s_val):
-            # safeguard against collisions
-            while True:
-                p = ObjectStoragePath(path) / 
f"{dag_id}/{run_id}/{task_id}/{str(uuid.uuid4())}{suffix}"
-                if not p.exists():
-                    break
+        threshold = _get_threshold()
+        if threshold < 0 or len(s_val) < threshold:  # Either no threshold or 
value is small enough.
+            return s_val
 
-            if not p.parent.exists():
-                p.parent.mkdir(parents=True, exist_ok=True)
+        base_path = _get_base_path()
+        while True:  # Safeguard against collisions.
+            p = base_path.joinpath(dag_id, run_id, task_id, 
f"{uuid.uuid4()}{suffix}")
+            if not p.exists():
+                break
+        p.parent.mkdir(parents=True, exist_ok=True)
 
-            with p.open(mode="wb", compression=compression) as f:
-                f.write(s_val)
-
-            return BaseXCom.serialize_value(str(p))
-        else:
-            return s_val
+        with p.open(mode="wb", compression=compression) as f:
+            f.write(s_val)
+        return BaseXCom.serialize_value(str(p))
 
     @staticmethod
-    def deserialize_value(
-        result: XCom,
-    ) -> Any:
+    def deserialize_value(result: XCom) -> Any:
         """Deserializes the value from the database or object storage.
 
         Compression is inferred from the file extension.
         """
         data = BaseXCom.deserialize_value(result)
-        path = conf.get(SECTION, "xcom_objectstorage_path", fallback="")
-
         try:
-            p = ObjectStoragePath(path) / 
XComObjectStorageBackend._get_key(data)
-            return json.load(p.open(mode="rb", compression="infer"), 
cls=XComDecoder)
-        except TypeError:
+            path = XComObjectStorageBackend._get_full_path(data)
+        except (TypeError, ValueError):  # Likely value stored directly in the 
database.
             return data
-        except ValueError:
+        try:
+            with path.open(mode="rb", compression="infer") as f:
+                return json.load(f, cls=XComDecoder)
+        except (TypeError, ValueError):
             return data
 
     @staticmethod
     def purge(xcom: XCom, session: Session) -> None:
-        path = conf.get(SECTION, "xcom_objectstorage_path", fallback="")
-        if isinstance(xcom.value, str):
-            try:
-                p = ObjectStoragePath(path) / 
XComObjectStorageBackend._get_key(xcom.value)
-                p.unlink(missing_ok=True)
-            except TypeError:
-                pass
-            except ValueError:
-                pass
+        if not isinstance(xcom.value, str):
+            return
+        with contextlib.suppress(TypeError, ValueError):
+            
XComObjectStorageBackend._get_full_path(xcom.value).unlink(missing_ok=True)
diff --git a/tests/providers/common/io/xcom/test_backend.py 
b/tests/providers/common/io/xcom/test_backend.py
index 008394f365..2da2d6fecd 100644
--- a/tests/providers/common/io/xcom/test_backend.py
+++ b/tests/providers/common/io/xcom/test_backend.py
@@ -20,7 +20,6 @@ from __future__ import annotations
 import pytest
 
 import airflow.models.xcom
-from airflow.io.path import ObjectStoragePath
 from airflow.models.xcom import BaseXCom, resolve_xcom_backend
 from airflow.operators.empty import EmptyOperator
 from airflow.providers.common.io.xcom.backend import XComObjectStorageBackend
@@ -42,6 +41,19 @@ def reset_db():
     db.clear_db_xcom()
 
 
+@pytest.fixture(autouse=True)
+def reset_cache():
+    from airflow.providers.common.io.xcom import backend
+
+    backend._get_base_path.cache_clear()
+    backend._get_compression.cache_clear()
+    backend._get_threshold.cache_clear()
+    yield
+    backend._get_base_path.cache_clear()
+    backend._get_compression.cache_clear()
+    backend._get_threshold.cache_clear()
+
+
 @pytest.fixture
 def task_instance(create_task_instance_of_operator):
     return create_task_instance_of_operator(
@@ -121,7 +133,7 @@ class TestXComObjectStorageBackend:
         )
 
         data = BaseXCom.deserialize_value(res)
-        p = ObjectStoragePath(self.path) / 
XComObjectStorageBackend._get_key(data)
+        p = XComObjectStorageBackend._get_full_path(data)
         assert p.exists() is True
 
         value = XCom.get_value(
@@ -166,7 +178,7 @@ class TestXComObjectStorageBackend:
         )
 
         data = BaseXCom.deserialize_value(res)
-        p = ObjectStoragePath(self.path) / 
XComObjectStorageBackend._get_key(data)
+        p = XComObjectStorageBackend._get_full_path(data)
         assert p.exists() is True
 
         XCom.clear(
@@ -205,7 +217,7 @@ class TestXComObjectStorageBackend:
         )
 
         data = BaseXCom.deserialize_value(res)
-        p = ObjectStoragePath(self.path) / 
XComObjectStorageBackend._get_key(data)
+        p = XComObjectStorageBackend._get_full_path(data)
         assert p.exists() is True
         assert p.suffix == ".gz"
 

Reply via email to