This is an automated email from the ASF dual-hosted git repository.

bolke pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 573d650708 AIP-58: Add object storage backend for xcom (#37058)
573d650708 is described below

commit 573d650708334c5e4ea4f1d72d01b976edefc6cf
Author: Bolke de Bruin <bo...@xs4all.nl>
AuthorDate: Sat Feb 3 09:30:48 2024 +0100

    AIP-58: Add object storage backend for xcom (#37058)
    
    * Add object storage backend for xcom
    
    This adds the possibility to store XComs on a configurable backend
    supported by object storage.
    
    ---------
    
    Co-authored-by: Ephraim Anierobi <splendidzig...@gmail.com>
---
 airflow/models/xcom.py                             |  18 +-
 airflow/provider.yaml.schema.json                  |   7 +
 airflow/providers/common/io/provider.yaml          |  34 +++
 .../common/io/{provider.yaml => xcom/__init__.py}  |  39 +--
 airflow/providers/common/io/xcom/backend.py        | 168 +++++++++++++
 .../configurations-ref.rst                         |  18 ++
 docs/apache-airflow-providers-common-io/index.rst  |   2 +
 .../xcom_backend.rst                               |  42 ++++
 docs/apache-airflow/core-concepts/xcoms.rst        |  29 +++
 docs/spelling_wordlist.txt                         |   1 +
 tests/models/test_xcom.py                          |   4 +-
 .../providers/common/io/xcom/__init__.py           |  29 ---
 tests/providers/common/io/xcom/test_backend.py     | 263 +++++++++++++++++++++
 13 files changed, 596 insertions(+), 58 deletions(-)

diff --git a/airflow/models/xcom.py b/airflow/models/xcom.py
index ee74dd89f7..f36f3ba3f7 100644
--- a/airflow/models/xcom.py
+++ b/airflow/models/xcom.py
@@ -29,6 +29,7 @@ from functools import cached_property, wraps
 from typing import TYPE_CHECKING, Any, Generator, Iterable, cast, overload
 
 import attr
+from deprecated import deprecated
 from sqlalchemy import (
     Column,
     ForeignKeyConstraint,
@@ -368,6 +369,7 @@ class BaseXCom(TaskInstanceDependencies, LoggingMixin):
     @staticmethod
     @provide_session
     @internal_api_call
+    @deprecated
     def get_one(
         execution_date: datetime.datetime | None = None,
         key: str | None = None,
@@ -418,7 +420,7 @@ class BaseXCom(TaskInstanceDependencies, LoggingMixin):
 
         result = query.with_entities(BaseXCom.value).first()
         if result:
-            return BaseXCom.deserialize_value(result)
+            return XCom.deserialize_value(result)
         return None
 
     @overload
@@ -556,9 +558,15 @@ class BaseXCom(TaskInstanceDependencies, LoggingMixin):
         for xcom in xcoms:
             if not isinstance(xcom, XCom):
                 raise TypeError(f"Expected XCom; received 
{xcom.__class__.__name__}")
+            XCom.purge(xcom, session)
             session.delete(xcom)
         session.commit()
 
+    @staticmethod
+    def purge(xcom: XCom, session: Session) -> None:
+        """Purge an XCom entry from underlying storage implementations."""
+        pass
+
     @overload
     @staticmethod
     @internal_api_call
@@ -641,7 +649,13 @@ class BaseXCom(TaskInstanceDependencies, LoggingMixin):
         query = session.query(BaseXCom).filter_by(dag_id=dag_id, 
task_id=task_id, run_id=run_id)
         if map_index is not None:
             query = query.filter_by(map_index=map_index)
-        query.delete()
+
+        for xcom in query:
+            # print(f"Clearing XCOM {xcom} with value {xcom.value}")
+            XCom.purge(xcom, session)
+            session.delete(xcom)
+
+        session.commit()
 
     @staticmethod
     def serialize_value(
diff --git a/airflow/provider.yaml.schema.json 
b/airflow/provider.yaml.schema.json
index 430cf2e632..ff14ea59a7 100644
--- a/airflow/provider.yaml.schema.json
+++ b/airflow/provider.yaml.schema.json
@@ -189,6 +189,13 @@
                 "type": "string"
             }
         },
+        "xcom": {
+            "type": "array",
+            "description": "XCom module names",
+            "items": {
+                "type": "string"
+            }
+        },
         "transfers": {
             "type": "array",
             "items": {
diff --git a/airflow/providers/common/io/provider.yaml 
b/airflow/providers/common/io/provider.yaml
index 7c32a42d0c..3c919320fc 100644
--- a/airflow/providers/common/io/provider.yaml
+++ b/airflow/providers/common/io/provider.yaml
@@ -43,3 +43,37 @@ operators:
   - integration-name: Common IO
     python-modules:
       - airflow.providers.common.io.operators.file_transfer
+
+xcom:
+  - airflow.providers.common.io.xcom.backend
+
+config:
+  common.io:
+    description: Common IO configuration section
+    options:
+      xcom_objectstorage_path:
+        description: |
+          Path, in URL format, to a location on object storage where XComs can be stored.
+        version_added: 1.3.0
+        type: string
+        example: "s3://conn_id@bucket/path"
+        default: ""
+      xcom_objectstorage_threshold:
+        description: |
+          Threshold in bytes for storing XComs in object storage. -1 means always store in the
+          database. 0 means always store in object storage. Any positive number means
+          it will be stored in object storage if the size of the value is greater than the threshold.
+          Object storage is only used when ``xcom_objectstorage_path`` is set.
+        version_added: 1.3.0
+        type: integer
+        example: "1000000"
+        default: "-1"
+      xcom_objectstorage_compression:
+        description: |
+          Compression algorithm to use when storing XComs in object storage. Supported algorithms
+          include snappy, zip, gzip, bz2, and lzma. If not specified, no compression will be used.
+          Note that the compression algorithm must be available in the Python installation (e.g.
+          python-snappy for snappy). zip, gzip and bz2 are available by default.
+        version_added: 1.3.0
+        type: string
+        example: "gzip"
+        default: ""
diff --git a/airflow/providers/common/io/provider.yaml 
b/airflow/providers/common/io/xcom/__init__.py
similarity index 56%
copy from airflow/providers/common/io/provider.yaml
copy to airflow/providers/common/io/xcom/__init__.py
index 7c32a42d0c..7977c868bb 100644
--- a/airflow/providers/common/io/provider.yaml
+++ b/airflow/providers/common/io/xcom/__init__.py
@@ -14,32 +14,19 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
+from __future__ import annotations
 
----
-package-name: apache-airflow-providers-common-io
-name: Common IO
-description: |
-  ``Common IO Provider``
+import packaging.version
 
-state: ready
-source-date-epoch: 1704610529
-versions:
-  - 1.2.0
-  - 1.1.0
-  - 1.0.1
-  - 1.0.0
+try:
+    from airflow import __version__ as airflow_version
+except ImportError:
+    from airflow.version import version as airflow_version
 
-dependencies:
-  - apache-airflow>=2.8.0
-
-integrations:
-  - integration-name: Common IO
-    external-doc-url: 
https://filesystem-spec.readthedocs.io/en/latest/index.html
-    how-to-guide:
-      - /docs/apache-airflow-providers-common-io/operators.rst
-    tags: [software]
-
-operators:
-  - integration-name: Common IO
-    python-modules:
-      - airflow.providers.common.io.operators.file_transfer
# Refuse to import on Airflow releases older than 2.9.0. Only the *base* version of the
# running Airflow is compared, not its full version string.
if not (
    packaging.version.parse(packaging.version.parse(airflow_version).base_version)
    >= packaging.version.parse("2.9.0")
):
    raise RuntimeError(
        "The package xcom backend feature of `apache-airflow-providers-common-io` needs "
        "Apache Airflow 2.9.0+"
    )
diff --git a/airflow/providers/common/io/xcom/backend.py 
b/airflow/providers/common/io/xcom/backend.py
new file mode 100644
index 0000000000..0bb5dd286f
--- /dev/null
+++ b/airflow/providers/common/io/xcom/backend.py
@@ -0,0 +1,168 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import json
+import uuid
+from typing import TYPE_CHECKING, Any, TypeVar
+from urllib.parse import urlsplit
+
+import fsspec.utils
+
+from airflow.configuration import conf
+from airflow.io.path import ObjectStoragePath
+from airflow.models.xcom import BaseXCom
+from airflow.utils.json import XComDecoder, XComEncoder
+
+if TYPE_CHECKING:
+    from sqlalchemy.orm import Session
+
+    from airflow.models import XCom
+
# Type variable for the value passed to XComObjectStoreBackend.serialize_value.
T = TypeVar("T")

# Airflow configuration section from which this backend reads its settings.
SECTION = "common.io"
+
+
+def _is_relative_to(o: ObjectStoragePath, other: ObjectStoragePath) -> bool:
+    """This is a port of the pathlib.Path.is_relative_to method. It is not 
available in python 3.8."""
+    if hasattr(o, "is_relative_to"):
+        return o.is_relative_to(other)
+
+    try:
+        o.relative_to(other)
+        return True
+    except ValueError:
+        return False
+
+
def _get_compression_suffix(compression: str) -> str:
    """Look up the file-name suffix fsspec registers for a compression name.

    :param compression: compression name as registered with fsspec (for example ``gzip``)
    :return: the matching suffix, without a leading dot (for ``gzip`` this is ``gz``)
    :raises ValueError: if fsspec knows no suffix for ``compression``, e.g. because the
        library implementing it is not installed
    """
    matching_suffix = next(
        (ext for ext, name in fsspec.utils.compressions.items() if name == compression),
        None,
    )
    if matching_suffix is None:
        raise ValueError(f"Compression {compression} is not supported. Make sure it is installed.")
    return matching_suffix
+
+
def _get_xcom_conf(option: str) -> str:
    """Return the ``[common.io] xcom_objectstorage_<option>`` setting, or ``""`` if unset.

    ``xcom_objectstorage_*`` is the spelling declared in the provider's ``provider.yaml`` and
    used throughout the documentation. The shorter ``xcom_objectstore_*`` spelling is still
    honoured as a fallback so configurations that already use it keep working.
    """
    return conf.get(SECTION, f"xcom_objectstorage_{option}", fallback="") or conf.get(
        SECTION, f"xcom_objectstore_{option}", fallback=""
    )


def _get_xcom_threshold() -> int:
    """Return the offload threshold in bytes; ``-1`` (the default) keeps every value in the DB.

    Like :func:`_get_xcom_conf`, the documented ``xcom_objectstorage_threshold`` wins and the
    ``xcom_objectstore_threshold`` spelling is consulted only when the former is left at -1.
    """
    threshold = conf.getint(SECTION, "xcom_objectstorage_threshold", fallback=-1)
    if threshold == -1:
        threshold = conf.getint(SECTION, "xcom_objectstore_threshold", fallback=-1)
    return threshold


class XComObjectStoreBackend(BaseXCom):
    """XCom backend that stores data in an object store or database depending on the size of the data.

    If the value is larger than the configured threshold, it will be stored in an object store.
    Otherwise, it will be stored in the database. If it is stored in an object store, the path
    to the object in the store will be returned and saved in the database (by BaseXCom). Otherwise,
    the value itself will be returned and thus saved in the database.
    """

    @staticmethod
    def _get_key(data: str) -> str:
        """Return the key of ``data`` relative to the configured object storage path.

        :param data: value read back from the database; it is only a reference if it is a URL
            located below the configured ``xcom_objectstorage_path``
        :return: ``data`` with the configured path prefix and leading slashes removed
        :raises ValueError: if no path is configured, ``data`` has no URL scheme, or ``data``
            does not point below the configured path
        :raises TypeError: if ``data`` cannot be split as a URL (e.g. it is not a string)
        """
        path = _get_xcom_conf("path")
        if not path:
            # serialize_value never offloads without a configured path, so no stored value can
            # be a reference; do not interpret user strings as object-storage URLs.
            raise ValueError("XCom object storage path is not configured")

        try:
            url = urlsplit(data)
        except AttributeError:
            raise TypeError(f"Not a valid url: {data}") from None

        if not url.scheme:
            raise ValueError(f"Not a valid url: {data}")

        # Cheap prefix test first: building an ObjectStoragePath for an arbitrary string that
        # merely looks like a URL may instantiate a filesystem for an unrelated scheme.
        if not data.startswith(path) or not _is_relative_to(
            ObjectStoragePath(data), ObjectStoragePath(path)
        ):
            raise ValueError(f"Invalid key: {data}")

        return data.replace(path, "", 1).lstrip("/")

    @staticmethod
    def serialize_value(
        value: T,
        *,
        key: str | None = None,
        task_id: str | None = None,
        dag_id: str | None = None,
        run_id: str | None = None,
        map_index: int | None = None,
    ) -> bytes | str:
        """Serialize ``value`` to JSON, offloading it to object storage when it is large.

        The value goes to object storage only when a path is configured and its JSON encoding is
        strictly longer than a threshold greater than -1.

        :return: the JSON-encoded value itself when it stays in the database, otherwise the
            serialized URL of the object it was written to
        :raises ValueError: if the configured compression is not known to fsspec
        """
        # we will always serialize ourselves and not by BaseXCom as the deserialize method
        # from BaseXCom accepts only XCom objects and not the value directly
        s_val = json.dumps(value, cls=XComEncoder).encode("utf-8")
        path = _get_xcom_conf("path")
        # An unset option comes back as "" (the default declared in provider.yaml); pass None
        # instead so fsspec does not treat "" as the name of a compression algorithm.
        compression = _get_xcom_conf("compression") or None

        suffix = f".{_get_compression_suffix(compression)}" if compression else ""
        threshold = _get_xcom_threshold()

        if not path or not -1 < threshold < len(s_val):
            return s_val

        # safeguard against collisions
        while True:
            p = ObjectStoragePath(path) / f"{dag_id}/{run_id}/{task_id}/{uuid.uuid4()}{suffix}"
            if not p.exists():
                break

        if not p.parent.exists():
            p.parent.mkdir(parents=True, exist_ok=True)

        with p.open("wb", compression=compression) as f:
            f.write(s_val)

        return BaseXCom.serialize_value(str(p))

    @staticmethod
    def deserialize_value(
        result: XCom,
    ) -> Any:
        """Deserialize the value from the database or, if it is a reference, from object storage.

        Compression is inferred from the file extension. Errors while reading a referenced object
        propagate instead of silently returning the reference URL.
        """
        data = BaseXCom.deserialize_value(result)

        try:
            key = XComObjectStoreBackend._get_key(data)
        except (TypeError, ValueError):
            # Not a reference into the configured object storage: the value itself was stored.
            return data

        p = ObjectStoragePath(_get_xcom_conf("path")) / key
        with p.open("rb", compression="infer") as f:
            return json.load(f, cls=XComDecoder)

    @staticmethod
    def purge(xcom: XCom, session: Session) -> None:
        """Delete the object-storage object referenced by ``xcom``, if there is one.

        Values kept directly in the database are left alone; a referenced object that is
        already gone is ignored.
        """
        if not isinstance(xcom.value, str):
            return

        try:
            key = XComObjectStoreBackend._get_key(xcom.value)
        except (TypeError, ValueError):
            # A plain string value stored in the database; nothing to remove.
            return

        p = ObjectStoragePath(_get_xcom_conf("path")) / key
        p.unlink(missing_ok=True)
diff --git a/docs/apache-airflow-providers-common-io/configurations-ref.rst 
b/docs/apache-airflow-providers-common-io/configurations-ref.rst
new file mode 100644
index 0000000000..5885c9d91b
--- /dev/null
+++ b/docs/apache-airflow-providers-common-io/configurations-ref.rst
@@ -0,0 +1,18 @@
+ .. Licensed to the Apache Software Foundation (ASF) under one
+    or more contributor license agreements.  See the NOTICE file
+    distributed with this work for additional information
+    regarding copyright ownership.  The ASF licenses this file
+    to you under the Apache License, Version 2.0 (the
+    "License"); you may not use this file except in compliance
+    with the License.  You may obtain a copy of the License at
+
+ ..   http://www.apache.org/licenses/LICENSE-2.0
+
+ .. Unless required by applicable law or agreed to in writing,
+    software distributed under the License is distributed on an
+    "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+    KIND, either express or implied.  See the License for the
+    specific language governing permissions and limitations
+    under the License.
+
+.. include:: ../exts/includes/providers-configurations-ref.rst
diff --git a/docs/apache-airflow-providers-common-io/index.rst 
b/docs/apache-airflow-providers-common-io/index.rst
index 6cd056dfdc..488582fcd0 100644
--- a/docs/apache-airflow-providers-common-io/index.rst
+++ b/docs/apache-airflow-providers-common-io/index.rst
@@ -28,6 +28,7 @@
     Home <self>
     Changelog <changelog>
     Security <security>
+    Configuration <configurations-ref>
 
 .. toctree::
     :hidden:
@@ -36,6 +37,7 @@
 
     Transferring a file <transfer>
     Operators <operators>
+    Object Storage XCom Backend <xcom_backend>
 
 .. toctree::
     :hidden:
diff --git a/docs/apache-airflow-providers-common-io/xcom_backend.rst 
b/docs/apache-airflow-providers-common-io/xcom_backend.rst
new file mode 100644
index 0000000000..9216fff711
--- /dev/null
+++ b/docs/apache-airflow-providers-common-io/xcom_backend.rst
@@ -0,0 +1,42 @@
+.. Licensed to the Apache Software Foundation (ASF) under one
+   or more contributor license agreements.  See the NOTICE file
+   distributed with this work for additional information
+   regarding copyright ownership.  The ASF licenses this file
+   to you under the Apache License, Version 2.0 (the
+   "License"); you may not use this file except in compliance
+   with the License.  You may obtain a copy of the License at
+
+..   http://www.apache.org/licenses/LICENSE-2.0
+
+.. Unless required by applicable law or agreed to in writing,
+   software distributed under the License is distributed on an
+   "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+   KIND, either express or implied.  See the License for the
+   specific language governing permissions and limitations
+   under the License.
+
+Object Storage XCom Backend
+===========================
+
+The default XCom backend is the :class:`~airflow.models.xcom.BaseXCom` class, 
which stores XComs in the Airflow database. This is fine for small values, but 
can be problematic for large values, or for large numbers of XComs.
+
+To enable storing XComs in an object store, set the ``xcom_backend`` configuration option to ``airflow.providers.common.io.xcom.backend.XComObjectStoreBackend``. You will also need to set ``xcom_objectstorage_path`` to the desired location. The connection
+id is taken from the user part of the URL you provide, e.g. ``xcom_objectstorage_path = s3://conn_id@mybucket/key``. Furthermore, ``xcom_objectstorage_threshold`` must be set
+to a value larger than -1. Any value whose serialized size in bytes does not exceed the threshold is stored in the database, and anything larger is
+put in object storage. This allows a hybrid setup. If an XCom is stored in object storage, a reference to it is
+saved in the database. Finally, you can set ``xcom_objectstorage_compression`` to an fsspec-supported compression method such as ``zip`` or ``snappy`` to
+compress the data before storing it in object storage.
+
+So for example the following configuration will store anything above 1MB in S3 
and will compress it using gzip::
+
+      [core]
+      xcom_backend = 
airflow.providers.common.io.xcom.backend.XComObjectStoreBackend
+
+      [common.io]
+      xcom_objectstorage_path = s3://conn_id@mybucket/key
+      xcom_objectstorage_threshold = 1048576
+      xcom_objectstorage_compression = gzip
+
+.. note::
+
+  Compression requires support for the chosen algorithm to be installed in your Python environment. For example, to use ``snappy`` compression, you need to install ``python-snappy``. Zip, gzip and bz2 work out of the box.
diff --git a/docs/apache-airflow/core-concepts/xcoms.rst 
b/docs/apache-airflow/core-concepts/xcoms.rst
index b4685b4cb0..b2dc2b665c 100644
--- a/docs/apache-airflow/core-concepts/xcoms.rst
+++ b/docs/apache-airflow/core-concepts/xcoms.rst
@@ -56,6 +56,35 @@ XComs are a relative of :doc:`variables`, with the main 
difference being that XC
 
   If the first task run is not succeeded then on every retry task XComs will 
be cleared to make the task run idempotent.
 
+
+Object Storage XCom Backend
+---------------------------
+
+The default XCom backend is the :class:`~airflow.models.xcom.BaseXCom` class, 
which stores XComs in the Airflow database. This is fine for small values, but 
can be problematic for large values, or for large numbers of XComs.
+
+To enable storing XComs in an object store, set the ``xcom_backend`` configuration option to ``airflow.providers.common.io.xcom.backend.XComObjectStoreBackend``. You will also need to set ``xcom_objectstorage_path`` to the desired location. The connection
+id is taken from the user part of the URL you provide, e.g. ``xcom_objectstorage_path = s3://conn_id@mybucket/key``. Furthermore, ``xcom_objectstorage_threshold`` must be set
+to a value larger than -1. Any value whose serialized size in bytes does not exceed the threshold is stored in the database, and anything larger is
+put in object storage. This allows a hybrid setup. If an XCom is stored in object storage, a reference to it is
+saved in the database. Finally, you can set ``xcom_objectstorage_compression`` to an fsspec-supported compression method such as ``zip`` or ``snappy`` to
+compress the data before storing it in object storage.
+
+So for example the following configuration will store anything above 1MB in S3 
and will compress it using gzip::
+
+      [core]
+      xcom_backend = 
airflow.providers.common.io.xcom.backend.XComObjectStoreBackend
+
+      [common.io]
+      xcom_objectstorage_path = s3://conn_id@mybucket/key
+      xcom_objectstorage_threshold = 1048576
+      xcom_objectstorage_compression = gzip
+
+
+.. note::
+
+  Compression requires support for the chosen algorithm to be installed in your Python environment. For example, to use ``snappy`` compression, you need to install ``python-snappy``. Zip, gzip and bz2 work out of the box.
+
+
 Custom XCom Backends
 --------------------
 
diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt
index 1dc8d5693b..08d2817eae 100644
--- a/docs/spelling_wordlist.txt
+++ b/docs/spelling_wordlist.txt
@@ -1080,6 +1080,7 @@ Oauth
 oauth
 Oauthlib
 objectORfile
+objectstorage
 observability
 od
 odbc
diff --git a/tests/models/test_xcom.py b/tests/models/test_xcom.py
index 8ab7d4eb56..14a9da7f6a 100644
--- a/tests/models/test_xcom.py
+++ b/tests/models/test_xcom.py
@@ -569,7 +569,8 @@ class TestXComClear:
         push_simple_json_xcom(ti=task_instance, key="xcom_1", value={"key": 
"value"})
 
     @pytest.mark.usefixtures("setup_for_xcom_clear")
-    def test_xcom_clear(self, session, task_instance):
+    @mock.patch("airflow.models.xcom.XCom.purge")
+    def test_xcom_clear(self, mock_purge, session, task_instance):
         assert session.query(XCom).count() == 1
         XCom.clear(
             dag_id=task_instance.dag_id,
@@ -578,6 +579,7 @@ class TestXComClear:
             session=session,
         )
         assert session.query(XCom).count() == 0
+        assert mock_purge.call_count == 1
 
     @pytest.mark.usefixtures("setup_for_xcom_clear")
     def test_xcom_clear_with_execution_date(self, session, task_instance):
diff --git a/airflow/providers/common/io/provider.yaml 
b/tests/providers/common/io/xcom/__init__.py
similarity index 56%
copy from airflow/providers/common/io/provider.yaml
copy to tests/providers/common/io/xcom/__init__.py
index 7c32a42d0c..13a83393a9 100644
--- a/airflow/providers/common/io/provider.yaml
+++ b/tests/providers/common/io/xcom/__init__.py
@@ -14,32 +14,3 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-
----
-package-name: apache-airflow-providers-common-io
-name: Common IO
-description: |
-  ``Common IO Provider``
-
-state: ready
-source-date-epoch: 1704610529
-versions:
-  - 1.2.0
-  - 1.1.0
-  - 1.0.1
-  - 1.0.0
-
-dependencies:
-  - apache-airflow>=2.8.0
-
-integrations:
-  - integration-name: Common IO
-    external-doc-url: 
https://filesystem-spec.readthedocs.io/en/latest/index.html
-    how-to-guide:
-      - /docs/apache-airflow-providers-common-io/operators.rst
-    tags: [software]
-
-operators:
-  - integration-name: Common IO
-    python-modules:
-      - airflow.providers.common.io.operators.file_transfer
diff --git a/tests/providers/common/io/xcom/test_backend.py 
b/tests/providers/common/io/xcom/test_backend.py
new file mode 100644
index 0000000000..fce5ed985e
--- /dev/null
+++ b/tests/providers/common/io/xcom/test_backend.py
@@ -0,0 +1,263 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from configparser import DuplicateSectionError
+from typing import TYPE_CHECKING
+
+import pytest
+
+import airflow.models.xcom
+from airflow import settings
+from airflow.configuration import conf
+from airflow.io.path import ObjectStoragePath
+from airflow.models.dagrun import DagRun
+from airflow.models.taskinstance import TaskInstance
+from airflow.models.xcom import BaseXCom, resolve_xcom_backend
+from airflow.operators.empty import EmptyOperator
+from airflow.providers.common.io.xcom.backend import XComObjectStoreBackend
+from airflow.utils import timezone
+from airflow.utils.session import create_session
+from airflow.utils.types import DagRunType
+from airflow.utils.xcom import XCOM_RETURN_KEY
+from tests.test_utils.config import conf_vars
+
+if TYPE_CHECKING:
+    from sqlalchemy.orm import Session
+
+
@pytest.fixture(autouse=True)
def reset_db():
    """Before every test, delete all DagRun rows and all rows of the current XCom model."""
    with create_session() as session:
        # ``airflow.models.xcom.XCom`` is looked up at call time because tests swap it out.
        for model in (DagRun, airflow.models.xcom.XCom):
            session.query(model).delete()
+
+
@pytest.fixture()
def task_instance_factory(request, session: Session):
    """Provide a factory that persists a scheduled DagRun together with one TaskInstance.

    Every DagRun the factory creates is deleted again when the requesting test finishes.
    """

    def func(*, dag_id, task_id, execution_date):
        run_id = DagRun.generate_run_id(DagRunType.SCHEDULED, execution_date)
        dag_run = DagRun(
            dag_id=dag_id,
            run_type=DagRunType.SCHEDULED,
            run_id=run_id,
            execution_date=execution_date,
        )
        session.add(dag_run)

        new_ti = TaskInstance(EmptyOperator(task_id=task_id), run_id=run_id)
        new_ti.dag_id = dag_id
        session.add(new_ti)
        session.commit()

        def cleanup_database():
            # Task instances are expected to be removed along with their DagRun by cascade.
            session.query(DagRun).filter_by(id=dag_run.id).delete()
            session.commit()

        request.addfinalizer(cleanup_database)
        return new_ti

    return func
+
+
@pytest.fixture()
def task_instance(task_instance_factory):
    """A persisted TaskInstance of task ``task_1`` in DAG ``dag``, scheduled at 2021-12-03 04:56."""
    scheduled_at = timezone.datetime(2021, 12, 3, 4, 56)
    return task_instance_factory(dag_id="dag", task_id="task_1", execution_date=scheduled_at)
+
+
class TestXcomObjectStoreBackend:
    path = "file:/tmp/xcom"

    def setup_method(self):
        # The tests replace ``airflow.models.xcom.XCom`` with the resolved backend so that
        # BaseXCom code paths such as ``clear`` dispatch to it. Remember the original class so
        # teardown can restore it instead of leaking the replacement into unrelated tests.
        self._original_xcom = airflow.models.xcom.XCom
        try:
            conf.add_section("common.io")
        except DuplicateSectionError:
            pass
        conf.set("core", "xcom_backend", "airflow.providers.common.io.xcom.backend.XComObjectStoreBackend")
        conf.set("common.io", "xcom_objectstore_path", self.path)
        conf.set("common.io", "xcom_objectstore_threshold", "50")
        settings.configure_vars()

    def teardown_method(self):
        airflow.models.xcom.XCom = self._original_xcom
        conf.remove_option("core", "xcom_backend")
        conf.remove_option("common.io", "xcom_objectstore_path")
        conf.remove_option("common.io", "xcom_objectstore_threshold")
        settings.configure_vars()
        p = ObjectStoragePath(self.path)
        if p.exists():
            p.rmdir(recursive=True)

    @staticmethod
    def _install_backend():
        """Resolve the configured XCom backend and install it as ``airflow.models.xcom.XCom``."""
        XCom = resolve_xcom_backend()
        airflow.models.xcom.XCom = XCom
        return XCom

    @staticmethod
    def _push(XCom, task_instance, session, value):
        """Store ``value`` as the return-value XCom of ``task_instance``."""
        XCom.set(
            key=XCOM_RETURN_KEY,
            value=value,
            dag_id=task_instance.dag_id,
            task_id=task_instance.task_id,
            run_id=task_instance.run_id,
            session=session,
        )

    @staticmethod
    def _pull(XCom, task_instance, session):
        """Read the return-value XCom of ``task_instance`` back through ``get_value``."""
        return XCom.get_value(
            key=XCOM_RETURN_KEY,
            ti_key=task_instance.key,
            session=session,
        )

    @staticmethod
    def _query(XCom, task_instance, session):
        """Query the return-value XCom rows of ``task_instance``."""
        return XCom.get_many(
            key=XCOM_RETURN_KEY,
            dag_ids=task_instance.dag_id,
            task_ids=task_instance.task_id,
            run_id=task_instance.run_id,
            session=session,
        )

    def _stored_object(self, XCom, task_instance, session) -> ObjectStoragePath:
        """Return the object-storage path referenced by the stored XCom row of ``task_instance``."""
        row = self._query(XCom, task_instance, session).with_entities(BaseXCom.value).first()
        data = BaseXCom.deserialize_value(row)
        return ObjectStoragePath(self.path) / XComObjectStoreBackend._get_key(data)

    @pytest.mark.db_test
    def test_value_db(self, task_instance, session):
        XCom = self._install_backend()
        self._push(XCom, task_instance, session, {"key": "value"})

        assert self._pull(XCom, task_instance, session) == {"key": "value"}
        assert self._query(XCom, task_instance, session).first().value == {"key": "value"}

    @pytest.mark.db_test
    def test_value_storage(self, task_instance, session):
        XCom = self._install_backend()
        self._push(XCom, task_instance, session, {"key": "bigvaluebigvaluebigvalue" * 100})

        assert self._stored_object(XCom, task_instance, session).exists() is True
        assert self._pull(XCom, task_instance, session) == {"key": "bigvaluebigvaluebigvalue" * 100}
        assert self.path in self._query(XCom, task_instance, session).first().value

    @pytest.mark.db_test
    def test_clear(self, task_instance, session):
        XCom = self._install_backend()
        self._push(XCom, task_instance, session, {"key": "superlargevalue" * 100})

        p = self._stored_object(XCom, task_instance, session)
        assert p.exists() is True

        XCom.clear(
            dag_id=task_instance.dag_id,
            task_id=task_instance.task_id,
            run_id=task_instance.run_id,
            session=session,
        )

        assert p.exists() is False

    @pytest.mark.db_test
    @conf_vars({("common.io", "xcom_objectstore_compression"): "gzip"})
    def test_compression(self, task_instance, session):
        XCom = self._install_backend()
        self._push(XCom, task_instance, session, {"key": "superlargevalue" * 100})

        p = self._stored_object(XCom, task_instance, session)
        assert p.exists() is True
        assert p.suffix == ".gz"

        assert self._pull(XCom, task_instance, session) == {"key": "superlargevalue" * 100}

Reply via email to