This is an automated email from the ASF dual-hosted git repository.
husseinawala pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 489ca14946 Fix handling of default value and serialization of Param class (#33141)
489ca14946 is described below
commit 489ca1494691f6a0b717f9312e0ab8a2a7f76d96
Author: Jens Scheffler <[email protected]>
AuthorDate: Wed Aug 16 08:04:52 2023 +0200
Fix handling of default value and serialization of Param class (#33141)
---
airflow/models/param.py | 9 ++++++--
airflow/serialization/enums.py | 1 +
airflow/serialization/serialized_objects.py | 5 +++++
tests/models/test_param.py | 31 +++++++++++++++++++++++++++
tests/serialization/test_dag_serialization.py | 7 +++---
5 files changed, 48 insertions(+), 5 deletions(-)
diff --git a/airflow/models/param.py b/airflow/models/param.py
index 82a49f715d..bea4333cf5 100644
--- a/airflow/models/param.py
+++ b/airflow/models/param.py
@@ -138,13 +138,18 @@ class Param:
def dump(self) -> dict:
"""Dump the Param as a dictionary."""
- out_dict = {self.CLASS_IDENTIFIER: f"{self.__module__}.{self.__class__.__name__}"}
+ out_dict: dict[str, str | None] = {
+ self.CLASS_IDENTIFIER: f"{self.__module__}.{self.__class__.__name__}"
+ }
out_dict.update(self.__dict__)
+ # Ensure that not set is translated to None
+ if self.value is NOTSET:
+ out_dict["value"] = None
return out_dict
@property
def has_value(self) -> bool:
- return self.value is not NOTSET
+ return self.value is not NOTSET and self.value is not None
def serialize(self) -> dict:
return {"value": self.value, "description": self.description, "schema": self.schema}
diff --git a/airflow/serialization/enums.py b/airflow/serialization/enums.py
index c83d9f53ef..0b1c0ca009 100644
--- a/airflow/serialization/enums.py
+++ b/airflow/serialization/enums.py
@@ -55,3 +55,4 @@ class DagAttributeTypes(str, Enum):
TASK_INSTANCE = "task_instance"
DAG_RUN = "dag_run"
DATA_SET = "data_set"
+ ARG_NOT_SET = "arg_not_set"
diff --git a/airflow/serialization/serialized_objects.py b/airflow/serialization/serialized_objects.py
index 17a084f20e..67d08b7a94 100644
--- a/airflow/serialization/serialized_objects.py
+++ b/airflow/serialization/serialized_objects.py
@@ -65,6 +65,7 @@ from airflow.utils.docs import get_docs_url
from airflow.utils.module_loading import import_string, qualname
from airflow.utils.operator_resources import Resources
from airflow.utils.task_group import MappedTaskGroup, TaskGroup
+from airflow.utils.types import NOTSET, ArgNotSet
if TYPE_CHECKING:
from pydantic import BaseModel
@@ -499,6 +500,8 @@ class BaseSerialization:
return cls._encode(_pydantic_model_dump(DatasetPydantic, var), type_=DAT.DATA_SET)
else:
return cls.default_serialization(strict, var)
+ elif isinstance(var, ArgNotSet):
+ return cls._encode(None, type_=DAT.ARG_NOT_SET)
else:
return cls.default_serialization(strict, var)
@@ -572,6 +575,8 @@ class BaseSerialization:
return DagRunPydantic.parse_obj(var)
elif type_ == DAT.DATA_SET:
return DatasetPydantic.parse_obj(var)
+ elif type_ == DAT.ARG_NOT_SET:
+ return NOTSET
else:
raise TypeError(f"Invalid type {type_!s} in deserialization.")
diff --git a/tests/models/test_param.py b/tests/models/test_param.py
index 4053cf6571..b73cfea15f 100644
--- a/tests/models/test_param.py
+++ b/tests/models/test_param.py
@@ -23,6 +23,7 @@ import pytest
from airflow.decorators import task
from airflow.exceptions import ParamValidationError, RemovedInAirflow3Warning
from airflow.models.param import Param, ParamsDict
+from airflow.serialization.serialized_objects import BaseSerialization
from airflow.utils import timezone
from airflow.utils.types import DagRunType
from tests.test_utils.db import clear_db_dags, clear_db_runs, clear_db_xcom
@@ -41,14 +42,20 @@ class TestParam:
with pytest.raises(ParamValidationError, match="No value passed and Param has no default value"):
p.resolve()
assert p.resolve(None) is None
+ assert p.dump()["value"] is None
+ assert not p.has_value
p = Param(None)
assert p.resolve() is None
assert p.resolve(None) is None
+ assert p.dump()["value"] is None
+ assert not p.has_value
p = Param(None, type="null")
assert p.resolve() is None
assert p.resolve(None) is None
+ assert p.dump()["value"] is None
+ assert not p.has_value
with pytest.raises(ParamValidationError):
p.resolve("test")
@@ -222,6 +229,30 @@ class TestParam:
assert dump["description"] == "world"
assert dump["schema"] == {"type": "string", "minLength": 2}
+ @pytest.mark.parametrize(
+ "param",
+ [
+ Param("my value", description="hello", schema={"type": "string"}),
+ Param("my value", description="hello"),
+ Param(None, description=None),
+ Param([True], type="array", items={"type": "boolean"}),
+ Param(),
+ ],
+ )
+ def test_param_serialization(self, param: Param):
+ """
+ Test to make sure that native Param objects can be correctly serialized
+ """
+
+ serializer = BaseSerialization()
+ serialized_param = serializer.serialize(param)
+ restored_param: Param = serializer.deserialize(serialized_param)
+
+ assert restored_param.value == param.value
+ assert isinstance(restored_param, Param)
+ assert restored_param.description == param.description
+ assert restored_param.schema == param.schema
+
class TestParamsDict:
def test_params_dict(self):
diff --git a/tests/serialization/test_dag_serialization.py b/tests/serialization/test_dag_serialization.py
index 26474a8929..5338579e01 100644
--- a/tests/serialization/test_dag_serialization.py
+++ b/tests/serialization/test_dag_serialization.py
@@ -899,20 +899,21 @@ class TestStringifiedDAGs:
Param("my value", description="hello"),
Param(None, description=None),
Param([True], type="array", items={"type": "boolean"}),
+ Param(),
],
)
- def test_full_param_roundtrip(self, param):
+ def test_full_param_roundtrip(self, param: Param):
"""
Test to make sure that only native Param objects are being passed as
dag or task params
"""
- dag = DAG(dag_id="simple_dag", params={"my_param": param})
+ dag = DAG(dag_id="simple_dag", schedule=None, params={"my_param": param})
serialized_json = SerializedDAG.to_json(dag)
serialized = json.loads(serialized_json)
SerializedDAG.validate_schema(serialized)
dag = SerializedDAG.from_dict(serialized)
- assert dag.params["my_param"] == param.value
+ assert dag.params.get_param("my_param").value == param.value
observed_param = dag.params.get_param("my_param")
assert isinstance(observed_param, Param)
assert observed_param.description == param.description