This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 901f00d1d0 feat(docker): Replace `use_dill` with `serializer` (#41356)
901f00d1d0 is described below

commit 901f00d1d0b8ea9cf5ba531f94fe7b9bd8637a58
Author: phi-friday <[email protected]>
AuthorDate: Wed Aug 14 19:42:47 2024 +0900

    feat(docker): Replace `use_dill` with `serializer` (#41356)
---
 airflow/decorators/__init__.pyi                  |  14 ++-
 airflow/providers/docker/decorators/docker.py    |  88 +++++++++++++++--
 tests/providers/docker/decorators/test_docker.py | 117 ++++++++++++++++++++++-
 3 files changed, 204 insertions(+), 15 deletions(-)

diff --git a/airflow/decorators/__init__.pyi b/airflow/decorators/__init__.pyi
index 089e453d02..e6d2479183 100644
--- a/airflow/decorators/__init__.pyi
+++ b/airflow/decorators/__init__.pyi
@@ -379,8 +379,9 @@ class TaskDecoratorCollection:
         self,
         *,
         multiple_outputs: bool | None = None,
-        use_dill: bool = False,  # Added by _DockerDecoratedOperator.
         python_command: str = "python3",
+        serializer: Literal["pickle", "cloudpickle", "dill"] | None = None,
+        use_dill: bool = False,  # Added by _DockerDecoratedOperator.
         # 'command', 'retrieve_output', and 'retrieve_output_path' are filled 
by
         # _DockerDecoratedOperator.
         image: str,
@@ -432,8 +433,17 @@ class TaskDecoratorCollection:
 
         :param multiple_outputs: If set, function return value will be 
unrolled to multiple XCom values.
             Dict will unroll to XCom values with keys as XCom keys. Defaults 
to False.
-        :param use_dill: Whether to use dill or pickle for serialization
         :param python_command: Python command for executing functions, 
Default: python3
+        :param serializer: Which serializer to use to serialize the args and 
result. It can be one of the following:
+
+            - ``"pickle"``: (default) Use pickle for serialization. Included 
in the Python Standard Library.
+            - ``"cloudpickle"``: Use cloudpickle to serialize more complex 
types,
+              this requires you to include cloudpickle in your requirements.
+            - ``"dill"``: Use dill to serialize more complex types,
+              this requires you to include dill in your requirements.
+        :param use_dill: Deprecated, use ``serializer`` instead. Whether to 
use dill to serialize
+            the args and result (pickle is default). This allows more complex 
types
+            but requires you to include dill in your requirements.
         :param image: Docker image from which to create the container.
             If image tag is omitted, "latest" will be used.
         :param api_version: Remote API version. Set to ``auto`` to 
automatically
diff --git a/airflow/providers/docker/decorators/docker.py 
b/airflow/providers/docker/decorators/docker.py
index d851c98aca..9812e5fc57 100644
--- a/airflow/providers/docker/decorators/docker.py
+++ b/airflow/providers/docker/decorators/docker.py
@@ -18,13 +18,12 @@ from __future__ import annotations
 
 import base64
 import os
-import pickle
+import warnings
 from tempfile import TemporaryDirectory
-from typing import TYPE_CHECKING, Callable, Sequence
-
-import dill
+from typing import TYPE_CHECKING, Any, Callable, Literal, Sequence
 
 from airflow.decorators.base import DecoratedOperator, task_decorator_factory
+from airflow.exceptions import AirflowException, 
AirflowProviderDeprecationWarning
 from airflow.providers.docker.operators.docker import DockerOperator
 from airflow.utils.python_virtualenv import write_python_script
 
@@ -32,6 +31,47 @@ if TYPE_CHECKING:
     from airflow.decorators.base import TaskDecorator
     from airflow.utils.context import Context
 
+    Serializer = Literal["pickle", "dill", "cloudpickle"]
+
+try:
+    from airflow.operators.python import _SERIALIZERS
+except ImportError:
+    import logging
+
+    import lazy_object_proxy
+
+    log = logging.getLogger(__name__)
+
+    def _load_pickle():
+        import pickle
+
+        return pickle
+
+    def _load_dill():
+        try:
+            import dill
+        except ModuleNotFoundError:
+            log.error("Unable to import `dill` module. Please make sure that 
it is installed.")
+            raise
+        return dill
+
+    def _load_cloudpickle():
+        try:
+            import cloudpickle
+        except ModuleNotFoundError:
+            log.error(
+                "Unable to import `cloudpickle` module. "
+                "Please install it with: pip install 
'apache-airflow[cloudpickle]'"
+            )
+            raise
+        return cloudpickle
+
+    _SERIALIZERS: dict[Serializer, Any] = {  # type: ignore[no-redef]
+        "pickle": lazy_object_proxy.Proxy(_load_pickle),
+        "dill": lazy_object_proxy.Proxy(_load_dill),
+        "cloudpickle": lazy_object_proxy.Proxy(_load_cloudpickle),
+    }
+
 
 def _generate_decode_command(env_var, file, python_command):
     # We don't need `f.close()` as the interpreter is about to exit anyway
@@ -53,7 +93,6 @@ class _DockerDecoratedOperator(DecoratedOperator, 
DockerOperator):
 
     :param python_callable: A reference to an object that is callable
     :param python: Python binary name to use
-    :param use_dill: Whether dill should be used to serialize the callable
     :param expect_airflow: whether to expect airflow to be installed in the 
docker environment. if this
           one is specified, the script to run callable will attempt to load 
Airflow macros.
     :param op_kwargs: a dictionary of keyword arguments that will get unpacked
@@ -63,6 +102,16 @@ class _DockerDecoratedOperator(DecoratedOperator, 
DockerOperator):
     :param multiple_outputs: if set, function return value will be
         unrolled to multiple XCom values. Dict will unroll to xcom values with 
keys as keys.
         Defaults to False.
+    :param serializer: Which serializer to use to serialize the args and 
result. It can be one of the following:
+
+        - ``"pickle"``: (default) Use pickle for serialization. Included in 
the Python Standard Library.
+        - ``"cloudpickle"``: Use cloudpickle to serialize more complex types,
+          this requires you to include cloudpickle in your requirements.
+        - ``"dill"``: Use dill to serialize more complex types,
+          this requires you to include dill in your requirements.
+    :param use_dill: Deprecated, use ``serializer`` instead. Whether to use 
dill to serialize
+        the args and result (pickle is default). This allows more complex types
+        but requires you to include dill in your requirements.
     """
 
     custom_operator_name = "@task.docker"
@@ -74,12 +123,35 @@ class _DockerDecoratedOperator(DecoratedOperator, 
DockerOperator):
         use_dill=False,
         python_command="python3",
         expect_airflow: bool = True,
+        serializer: Serializer | None = None,
         **kwargs,
     ) -> None:
+        if use_dill:
+            warnings.warn(
+                "`use_dill` is deprecated and will be removed in a future 
version. "
+                "Please provide serializer='dill' instead.",
+                AirflowProviderDeprecationWarning,
+                stacklevel=3,
+            )
+            if serializer:
+                raise AirflowException(
+                    "Both 'use_dill' and 'serializer' parameters are set. 
Please set only one of them"
+                )
+            serializer = "dill"
+        serializer = serializer or "pickle"
+        if serializer not in _SERIALIZERS:
+            msg = (
+                f"Unsupported serializer {serializer!r}. "
+                f"Expected one of {', '.join(map(repr, _SERIALIZERS))}"
+            )
+            raise AirflowException(msg)
+
         command = "placeholder command"
         self.python_command = python_command
         self.expect_airflow = expect_airflow
-        self.use_dill = use_dill
+        self.use_dill = serializer == "dill"
+        self.serializer: Serializer = serializer
+
         super().__init__(
             command=command, retrieve_output=True, 
retrieve_output_path="/tmp/script.out", **kwargs
         )
@@ -128,9 +200,7 @@ class _DockerDecoratedOperator(DecoratedOperator, 
DockerOperator):
 
     @property
     def pickling_library(self):
-        if self.use_dill:
-            return dill
-        return pickle
+        return _SERIALIZERS[self.serializer]
 
 
 def docker_task(
diff --git a/tests/providers/docker/decorators/test_docker.py 
b/tests/providers/docker/decorators/test_docker.py
index 93db9f211b..42b1a514a3 100644
--- a/tests/providers/docker/decorators/test_docker.py
+++ b/tests/providers/docker/decorators/test_docker.py
@@ -17,12 +17,13 @@
 from __future__ import annotations
 
 import logging
+from importlib.util import find_spec
 from io import StringIO as StringBuffer
 
 import pytest
 
 from airflow.decorators import setup, task, teardown
-from airflow.exceptions import AirflowException
+from airflow.exceptions import AirflowException, 
AirflowProviderDeprecationWarning
 from airflow.models import TaskInstance
 from airflow.models.dag import DAG
 from airflow.utils import timezone
@@ -32,6 +33,10 @@ pytestmark = pytest.mark.db_test
 
 
 DEFAULT_DATE = timezone.datetime(2021, 9, 1)
+DILL_INSTALLED = find_spec("dill") is not None
+DILL_MARKER = pytest.mark.skipif(not DILL_INSTALLED, reason="`dill` is not 
installed")
+CLOUDPICKLE_INSTALLED = find_spec("cloudpickle") is not None
+CLOUDPICKLE_MARKER = pytest.mark.skipif(not CLOUDPICKLE_INSTALLED, 
reason="`cloudpickle` is not installed")
 
 
 class TestDockerDecorator:
@@ -207,13 +212,21 @@ class TestDockerDecorator:
         assert teardown_task.is_teardown
         assert teardown_task.on_failure_fail_dagrun is on_failure_fail_dagrun
 
-    @pytest.mark.parametrize("use_dill", [True, False])
-    def test_deepcopy_with_python_operator(self, dag_maker, use_dill):
+    @pytest.mark.parametrize(
+        "serializer",
+        [
+            pytest.param("pickle", id="pickle"),
+            pytest.param("dill", marks=DILL_MARKER, id="dill"),
+            pytest.param("cloudpickle", marks=CLOUDPICKLE_MARKER, 
id="cloudpickle"),
+            pytest.param(None, id="default"),
+        ],
+    )
+    def test_deepcopy_with_python_operator(self, dag_maker, serializer):
         import copy
 
         from airflow.providers.docker.decorators.docker import 
_DockerDecoratedOperator
 
-        @task.docker(image="python:3.9-slim", auto_remove="force", 
use_dill=use_dill)
+        @task.docker(image="python:3.9-slim", auto_remove="force", 
serializer=serializer)
         def f():
             import logging
 
@@ -247,6 +260,7 @@ class TestDockerDecorator:
         assert isinstance(clone_of_docker_operator, _DockerDecoratedOperator)
         assert some_task.command == clone_of_docker_operator.command
         assert some_task.expect_airflow == 
clone_of_docker_operator.expect_airflow
+        assert some_task.serializer == clone_of_docker_operator.serializer
         assert some_task.use_dill == clone_of_docker_operator.use_dill
         assert some_task.pickling_library is 
clone_of_docker_operator.pickling_library
 
@@ -317,3 +331,98 @@ class TestDockerDecorator:
         assert 'with open(sys.argv[4], "w") as file:' not in log_content
         last_line_of_docker_operator_log = log_content.splitlines()[-1]
         assert "ValueError: This task is expected to fail" in 
last_line_of_docker_operator_log
+
+    @pytest.mark.parametrize(
+        "serializer",
+        [
+            pytest.param("pickle", id="pickle"),
+            pytest.param("dill", marks=DILL_MARKER, id="dill"),
+            pytest.param("cloudpickle", marks=CLOUDPICKLE_MARKER, 
id="cloudpickle"),
+        ],
+    )
+    def test_ambiguous_serializer(self, dag_maker, serializer):
+        @task.docker(image="python:3.9-slim", auto_remove="force", 
use_dill=True, serializer=serializer)
+        def f():
+            pass
+
+        with dag_maker():
+            with pytest.warns(
+                AirflowProviderDeprecationWarning, match="`use_dill` is 
deprecated and will be removed"
+            ):
+                with pytest.raises(
+                    AirflowException, match="Both 'use_dill' and 'serializer' 
parameters are set"
+                ):
+                    f()
+
+    def test_invalid_serializer(self, dag_maker):
+        @task.docker(image="python:3.9-slim", auto_remove="force", 
serializer="airflow")
+        def f():
+            """Ensure dill is correctly installed."""
+            import dill  # noqa: F401
+
+        with dag_maker():
+            with pytest.raises(AirflowException, match="Unsupported serializer 
'airflow'"):
+                f()
+
+    @pytest.mark.parametrize(
+        "serializer",
+        [
+            pytest.param(
+                "dill",
+                marks=pytest.mark.skipif(
+                    DILL_INSTALLED, reason="For this test case `dill` 
shouldn't be installed"
+                ),
+                id="dill",
+            ),
+            pytest.param(
+                "cloudpickle",
+                marks=pytest.mark.skipif(
+                    CLOUDPICKLE_INSTALLED, reason="For this test case 
`cloudpickle` shouldn't be installed"
+                ),
+                id="cloudpickle",
+            ),
+        ],
+    )
+    def test_advanced_serializer_not_installed(self, dag_maker, serializer, 
caplog):
+        """Test case for check raising an error if dill/cloudpickle is not 
installed."""
+
+        @task.docker(image="python:3.9-slim", auto_remove="force", 
serializer=serializer)
+        def f(): ...
+
+        with dag_maker():
+            with pytest.raises(ModuleNotFoundError):
+                f()
+        assert f"Unable to import `{serializer}` module." in caplog.text
+
+    @CLOUDPICKLE_MARKER
+    def test_add_cloudpickle(self, dag_maker):
+        @task.docker(image="python:3.9-slim", auto_remove="force", 
serializer="cloudpickle")
+        def f():
+            """Ensure cloudpickle is correctly installed."""
+            import cloudpickle  # noqa: F401
+
+        with dag_maker():
+            f()
+
+    @DILL_MARKER
+    def test_add_dill(self, dag_maker):
+        @task.docker(image="python:3.9-slim", auto_remove="force", 
serializer="dill")
+        def f():
+            """Ensure dill is correctly installed."""
+            import dill  # noqa: F401
+
+        with dag_maker():
+            f()
+
+    @DILL_MARKER
+    def test_add_dill_use_dill(self, dag_maker):
+        @task.docker(image="python:3.9-slim", auto_remove="force", 
use_dill=True)
+        def f():
+            """Ensure dill is correctly installed."""
+            import dill  # noqa: F401
+
+        with dag_maker():
+            with pytest.warns(
+                AirflowProviderDeprecationWarning, match="`use_dill` is 
deprecated and will be removed"
+            ):
+                f()

Reply via email to