This is an automated email from the ASF dual-hosted git repository.

husseinawala pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 489ca14946 Fix handling of default value and serialization of Param class (#33141)
489ca14946 is described below

commit 489ca1494691f6a0b717f9312e0ab8a2a7f76d96
Author: Jens Scheffler <[email protected]>
AuthorDate: Wed Aug 16 08:04:52 2023 +0200

    Fix handling of default value and serialization of Param class (#33141)
---
 airflow/models/param.py                       |  9 ++++++--
 airflow/serialization/enums.py                |  1 +
 airflow/serialization/serialized_objects.py   |  5 +++++
 tests/models/test_param.py                    | 31 +++++++++++++++++++++++++++
 tests/serialization/test_dag_serialization.py |  7 +++---
 5 files changed, 48 insertions(+), 5 deletions(-)

diff --git a/airflow/models/param.py b/airflow/models/param.py
index 82a49f715d..bea4333cf5 100644
--- a/airflow/models/param.py
+++ b/airflow/models/param.py
@@ -138,13 +138,18 @@ class Param:
 
     def dump(self) -> dict:
         """Dump the Param as a dictionary."""
-        out_dict = {self.CLASS_IDENTIFIER: f"{self.__module__}.{self.__class__.__name__}"}
+        out_dict: dict[str, str | None] = {
+            self.CLASS_IDENTIFIER: f"{self.__module__}.{self.__class__.__name__}"
+        }
         out_dict.update(self.__dict__)
+        # Ensure that not set is translated to None
+        if self.value is NOTSET:
+            out_dict["value"] = None
         return out_dict
 
     @property
     def has_value(self) -> bool:
-        return self.value is not NOTSET
+        return self.value is not NOTSET and self.value is not None
 
     def serialize(self) -> dict:
         return {"value": self.value, "description": self.description, "schema": self.schema}
diff --git a/airflow/serialization/enums.py b/airflow/serialization/enums.py
index c83d9f53ef..0b1c0ca009 100644
--- a/airflow/serialization/enums.py
+++ b/airflow/serialization/enums.py
@@ -55,3 +55,4 @@ class DagAttributeTypes(str, Enum):
     TASK_INSTANCE = "task_instance"
     DAG_RUN = "dag_run"
     DATA_SET = "data_set"
+    ARG_NOT_SET = "arg_not_set"
diff --git a/airflow/serialization/serialized_objects.py b/airflow/serialization/serialized_objects.py
index 17a084f20e..67d08b7a94 100644
--- a/airflow/serialization/serialized_objects.py
+++ b/airflow/serialization/serialized_objects.py
@@ -65,6 +65,7 @@ from airflow.utils.docs import get_docs_url
 from airflow.utils.module_loading import import_string, qualname
 from airflow.utils.operator_resources import Resources
 from airflow.utils.task_group import MappedTaskGroup, TaskGroup
+from airflow.utils.types import NOTSET, ArgNotSet
 
 if TYPE_CHECKING:
     from pydantic import BaseModel
@@ -499,6 +500,8 @@ class BaseSerialization:
                 return cls._encode(_pydantic_model_dump(DatasetPydantic, var), type_=DAT.DATA_SET)
             else:
                 return cls.default_serialization(strict, var)
+        elif isinstance(var, ArgNotSet):
+            return cls._encode(None, type_=DAT.ARG_NOT_SET)
         else:
             return cls.default_serialization(strict, var)
 
@@ -572,6 +575,8 @@ class BaseSerialization:
                 return DagRunPydantic.parse_obj(var)
             elif type_ == DAT.DATA_SET:
                 return DatasetPydantic.parse_obj(var)
+        elif type_ == DAT.ARG_NOT_SET:
+            return NOTSET
         else:
             raise TypeError(f"Invalid type {type_!s} in deserialization.")
 
diff --git a/tests/models/test_param.py b/tests/models/test_param.py
index 4053cf6571..b73cfea15f 100644
--- a/tests/models/test_param.py
+++ b/tests/models/test_param.py
@@ -23,6 +23,7 @@ import pytest
 from airflow.decorators import task
 from airflow.exceptions import ParamValidationError, RemovedInAirflow3Warning
 from airflow.models.param import Param, ParamsDict
+from airflow.serialization.serialized_objects import BaseSerialization
 from airflow.utils import timezone
 from airflow.utils.types import DagRunType
 from tests.test_utils.db import clear_db_dags, clear_db_runs, clear_db_xcom
@@ -41,14 +42,20 @@ class TestParam:
         with pytest.raises(ParamValidationError, match="No value passed and Param has no default value"):
             p.resolve()
         assert p.resolve(None) is None
+        assert p.dump()["value"] is None
+        assert not p.has_value
 
         p = Param(None)
         assert p.resolve() is None
         assert p.resolve(None) is None
+        assert p.dump()["value"] is None
+        assert not p.has_value
 
         p = Param(None, type="null")
         assert p.resolve() is None
         assert p.resolve(None) is None
+        assert p.dump()["value"] is None
+        assert not p.has_value
         with pytest.raises(ParamValidationError):
             p.resolve("test")
 
@@ -222,6 +229,30 @@ class TestParam:
         assert dump["description"] == "world"
         assert dump["schema"] == {"type": "string", "minLength": 2}
 
+    @pytest.mark.parametrize(
+        "param",
+        [
+            Param("my value", description="hello", schema={"type": "string"}),
+            Param("my value", description="hello"),
+            Param(None, description=None),
+            Param([True], type="array", items={"type": "boolean"}),
+            Param(),
+        ],
+    )
+    def test_param_serialization(self, param: Param):
+        """
+        Test to make sure that native Param objects can be correctly serialized
+        """
+
+        serializer = BaseSerialization()
+        serialized_param = serializer.serialize(param)
+        restored_param: Param = serializer.deserialize(serialized_param)
+
+        assert restored_param.value == param.value
+        assert isinstance(restored_param, Param)
+        assert restored_param.description == param.description
+        assert restored_param.schema == param.schema
+
 
 class TestParamsDict:
     def test_params_dict(self):
diff --git a/tests/serialization/test_dag_serialization.py b/tests/serialization/test_dag_serialization.py
index 26474a8929..5338579e01 100644
--- a/tests/serialization/test_dag_serialization.py
+++ b/tests/serialization/test_dag_serialization.py
@@ -899,20 +899,21 @@ class TestStringifiedDAGs:
             Param("my value", description="hello"),
             Param(None, description=None),
             Param([True], type="array", items={"type": "boolean"}),
+            Param(),
         ],
     )
-    def test_full_param_roundtrip(self, param):
+    def test_full_param_roundtrip(self, param: Param):
         """
         Test to make sure that only native Param objects are being passed as dag or task params
         """
 
-        dag = DAG(dag_id="simple_dag", params={"my_param": param})
+        dag = DAG(dag_id="simple_dag", schedule=None, params={"my_param": param})
         serialized_json = SerializedDAG.to_json(dag)
         serialized = json.loads(serialized_json)
         SerializedDAG.validate_schema(serialized)
         dag = SerializedDAG.from_dict(serialized)
 
-        assert dag.params["my_param"] == param.value
+        assert dag.params.get_param("my_param").value == param.value
         observed_param = dag.params.get_param("my_param")
         assert isinstance(observed_param, Param)
         assert observed_param.description == param.description

Reply via email to