This is an automated email from the ASF dual-hosted git repository.

jedcunningham pushed a commit to branch v2-2-test
in repository https://gitbox.apache.org/repos/asf/airflow.git

commit 4f269594366f501974c6c37bde5d19f615761859
Author: Jed Cunningham <[email protected]>
AuthorDate: Tue Oct 19 14:30:15 2021 -0600

    Allow Param to support a default value of ``None`` (#19034)
    
    (cherry picked from commit ee4fb073b6d3c64437fc27b673ed9b6568953f46)
---
 UPDATING.md                                        | 14 ++++++++
 airflow/models/dag.py                              |  4 +--
 airflow/models/param.py                            | 39 +++++++++++++++++-----
 airflow/serialization/serialized_objects.py        | 18 +++-------
 tests/api_connexion/endpoints/test_dag_endpoint.py |  6 ++--
 .../api_connexion/endpoints/test_task_endpoint.py  |  6 ++--
 tests/api_connexion/schemas/test_dag_schema.py     |  2 +-
 tests/api_connexion/schemas/test_task_schema.py    |  2 +-
 tests/models/test_dag.py                           | 16 +++++++++
 tests/models/test_param.py                         | 31 +++++++++++++++--
 tests/serialization/test_dag_serialization.py      | 17 ++++++++++
 11 files changed, 119 insertions(+), 36 deletions(-)

diff --git a/UPDATING.md b/UPDATING.md
index 8ffd87d..9c0c4fd 100644
--- a/UPDATING.md
+++ b/UPDATING.md
@@ -77,6 +77,20 @@ https://developers.google.com/style/inclusive-documentation
 
 -->
 
+### ``Param``'s default value for ``default`` removed
+
+``Param``, introduced in Airflow 2.2.0, accidentally set the default value to 
``None``. This default has been removed. If you want ``None`` as your default, 
explicitly set it as such. For example:
+
+```python
+Param(None, type=["null", "string"])
+```
+
+Now if you resolve a ``Param`` without a default and don't pass a value, you 
will get a ``TypeError``. For example:
+
+```python
+Param().resolve()  # raises TypeError
+```
+
 ## Airflow 2.2.0
 
 ### `worker_log_server_port` configuration has been moved to the ``logging`` 
section.
diff --git a/airflow/models/dag.py b/airflow/models/dag.py
index 129a30a..b372cfe 100644
--- a/airflow/models/dag.py
+++ b/airflow/models/dag.py
@@ -2631,8 +2631,8 @@ class DAG(LoggingMixin):
             return
 
         for k, v in self.params.items():
-            # As type can be an array, we would check if `null` is a allowed 
type or not
-            if v.default is None and ("type" not in v.schema or "null" not in 
v.schema["type"]):
+            # As type can be an array, we would check if `null` is an allowed 
type or not
+            if not v.has_value and ("type" not in v.schema or "null" not in 
v.schema["type"]):
                 raise AirflowException(
                     "DAG Schedule must be None, if there are any required 
params without default values"
                 )
diff --git a/airflow/models/param.py b/airflow/models/param.py
index 2869481..1ae01dc 100644
--- a/airflow/models/param.py
+++ b/airflow/models/param.py
@@ -24,6 +24,16 @@ from jsonschema.exceptions import ValidationError
 from airflow.exceptions import AirflowException
 
 
+class NoValueSentinel:
+    """Sentinel class used to distinguish between None and no passed value"""
+
+    def __str__(self):
+        return "NoValueSentinel"
+
+    def __repr__(self):
+        return "NoValueSentinel"
+
+
 class Param:
     """
     Class to hold the default value of a Param and rule set to do the 
validations. Without the rule set
@@ -38,22 +48,25 @@ class Param:
     :type schema: dict
     """
 
-    def __init__(self, default: Any = None, description: str = None, **kwargs):
-        self.default = default
+    __NO_VALUE_SENTINEL = NoValueSentinel()
+
+    def __init__(self, default: Any = __NO_VALUE_SENTINEL, description: str = 
None, **kwargs):
+        self.value = default
         self.description = description
         self.schema = kwargs.pop('schema') if 'schema' in kwargs else kwargs
 
-        # If default is not None, then validate it once, may raise ValueError
-        if default:
+        # If we have a value, validate it once. May raise ValueError.
+        if self.has_value:
             try:
-                jsonschema.validate(self.default, self.schema, 
format_checker=FormatChecker())
+                jsonschema.validate(self.value, self.schema, 
format_checker=FormatChecker())
             except ValidationError as err:
                 raise ValueError(err)
 
-    def resolve(self, value: Optional[Any] = None, suppress_exception: bool = 
False) -> Any:
+    def resolve(self, value: Optional[Any] = __NO_VALUE_SENTINEL, 
suppress_exception: bool = False) -> Any:
         """
         Runs the validations and returns the Param's final value.
-        May raise ValueError on failed validations.
+        May raise ValueError on failed validations, or TypeError
+        if no value is passed and no value already exists.
 
         :param value: The value to be updated for the Param
         :type value: Optional[Any]
@@ -61,14 +74,18 @@ class Param:
             If true and validations fails, the return value would be None.
         :type suppress_exception: bool
         """
+        final_val = value if value != self.__NO_VALUE_SENTINEL else self.value
+        if isinstance(final_val, NoValueSentinel):
+            if suppress_exception:
+                return None
+            raise TypeError("No value passed and Param has no default value")
         try:
-            final_val = value or self.default
             jsonschema.validate(final_val, self.schema, 
format_checker=FormatChecker())
-            self.default = final_val
         except ValidationError as err:
             if suppress_exception:
                 return None
             raise ValueError(err) from None
+        self.value = final_val
         return final_val
 
     def dump(self) -> dict:
@@ -77,6 +94,10 @@ class Param:
         out_dict.update(self.__dict__)
         return out_dict
 
+    @property
+    def has_value(self) -> bool:
+        return not isinstance(self.value, NoValueSentinel)
+
 
 class ParamsDict(dict):
     """
diff --git a/airflow/serialization/serialized_objects.py 
b/airflow/serialization/serialized_objects.py
index 3d30ca4..6ec9770 100644
--- a/airflow/serialization/serialized_objects.py
+++ b/airflow/serialization/serialized_objects.py
@@ -416,7 +416,9 @@ class BaseSerialization:
         for k, v in params.items():
             # TODO: As of now, we would allow serialization of params which 
are of type Param only
             if f'{v.__module__}.{v.__class__.__name__}' == 
'airflow.models.param.Param':
-                serialized_params[k] = v.dump()
+                kwargs = v.dump()
+                kwargs['default'] = kwargs.pop('value')
+                serialized_params[k] = kwargs
             else:
                 raise ValueError('Params to a DAG or a Task can be only of 
type airflow.models.param.Param')
         return serialized_params
@@ -543,7 +545,7 @@ class SerializedBaseOperator(BaseOperator, 
BaseSerialization):
                     serialize_op[template_field] = 
serialize_template_field(value)
 
         if op.params:
-            serialize_op['params'] = cls._serialize_operator_params(op.params)
+            serialize_op['params'] = cls._serialize_params_dict(op.params)
 
         return serialize_op
 
@@ -747,18 +749,6 @@ class SerializedBaseOperator(BaseOperator, 
BaseSerialization):
 
         return serialize_operator_extra_links
 
-    @classmethod
-    def _serialize_operator_params(cls, op_params: ParamsDict):
-        """Serialize Params dict of a operator"""
-        serialized_params = {}
-        for k, v in op_params.items():
-            # TODO: As of now, we would allow serialization of params which 
are of type Param only
-            if f'{v.__module__}.{v.__class__.__name__}' == 
'airflow.models.param.Param':
-                serialized_params[k] = v.dump()
-            else:
-                raise ValueError('Params to a Task can be only of type 
airflow.models.param.Param')
-        return serialized_params
-
 
 class SerializedDAG(DAG, BaseSerialization):
     """
diff --git a/tests/api_connexion/endpoints/test_dag_endpoint.py 
b/tests/api_connexion/endpoints/test_dag_endpoint.py
index e9602fa..73dac72 100644
--- a/tests/api_connexion/endpoints/test_dag_endpoint.py
+++ b/tests/api_connexion/endpoints/test_dag_endpoint.py
@@ -240,7 +240,7 @@ class TestGetDagDetails(TestDagEndpoint):
             "params": {
                 "foo": {
                     '__class': 'airflow.models.param.Param',
-                    'default': 1,
+                    'value': 1,
                     'description': None,
                     'schema': {},
                 }
@@ -353,7 +353,7 @@ class TestGetDagDetails(TestDagEndpoint):
             "params": {
                 "foo": {
                     '__class': 'airflow.models.param.Param',
-                    'default': 1,
+                    'value': 1,
                     'description': None,
                     'schema': {},
                 }
@@ -400,7 +400,7 @@ class TestGetDagDetails(TestDagEndpoint):
             "params": {
                 "foo": {
                     '__class': 'airflow.models.param.Param',
-                    'default': 1,
+                    'value': 1,
                     'description': None,
                     'schema': {},
                 }
diff --git a/tests/api_connexion/endpoints/test_task_endpoint.py 
b/tests/api_connexion/endpoints/test_task_endpoint.py
index 242df7f..2630d49 100644
--- a/tests/api_connexion/endpoints/test_task_endpoint.py
+++ b/tests/api_connexion/endpoints/test_task_endpoint.py
@@ -100,7 +100,7 @@ class TestGetTask(TestTaskEndpoint):
             'params': {
                 'foo': {
                     '__class': 'airflow.models.param.Param',
-                    'default': 'bar',
+                    'value': 'bar',
                     'description': None,
                     'schema': {},
                 }
@@ -150,7 +150,7 @@ class TestGetTask(TestTaskEndpoint):
             'params': {
                 'foo': {
                     '__class': 'airflow.models.param.Param',
-                    'default': 'bar',
+                    'value': 'bar',
                     'description': None,
                     'schema': {},
                 }
@@ -215,7 +215,7 @@ class TestGetTasks(TestTaskEndpoint):
                     'params': {
                         'foo': {
                             '__class': 'airflow.models.param.Param',
-                            'default': 'bar',
+                            'value': 'bar',
                             'description': None,
                             'schema': {},
                         }
diff --git a/tests/api_connexion/schemas/test_dag_schema.py 
b/tests/api_connexion/schemas/test_dag_schema.py
index 3f6d3b0..f12c06d 100644
--- a/tests/api_connexion/schemas/test_dag_schema.py
+++ b/tests/api_connexion/schemas/test_dag_schema.py
@@ -133,7 +133,7 @@ class TestDAGDetailSchema:
             'params': {
                 'foo': {
                     '__class': 'airflow.models.param.Param',
-                    'default': 1,
+                    'value': 1,
                     'description': None,
                     'schema': {},
                 }
diff --git a/tests/api_connexion/schemas/test_task_schema.py 
b/tests/api_connexion/schemas/test_task_schema.py
index f4a4867..65006ed 100644
--- a/tests/api_connexion/schemas/test_task_schema.py
+++ b/tests/api_connexion/schemas/test_task_schema.py
@@ -81,7 +81,7 @@ class TestTaskCollectionSchema:
                     'params': {
                         'foo': {
                             '__class': 'airflow.models.param.Param',
-                            'default': 'bar',
+                            'value': 'bar',
                             'description': None,
                             'schema': {},
                         }
diff --git a/tests/models/test_dag.py b/tests/models/test_dag.py
index e77f271..dfa1a0a 100644
--- a/tests/models/test_dag.py
+++ b/tests/models/test_dag.py
@@ -1799,14 +1799,30 @@ class TestDag(unittest.TestCase):
 
     def test_validate_params_on_trigger_dag(self):
         dag = models.DAG('dummy-dag', schedule_interval=None, 
params={'param1': Param(type="string")})
+        with pytest.raises(TypeError, match="No value passed and Param has no 
default value"):
+            dag.create_dagrun(
+                run_id="test_dagrun_missing_param",
+                state=State.RUNNING,
+                execution_date=TEST_DATE,
+            )
 
+        dag = models.DAG('dummy-dag', schedule_interval=None, 
params={'param1': Param(type="string")})
         with pytest.raises(ValueError, match="Invalid input for param param1: 
None is not of type 'string'"):
             dag.create_dagrun(
                 run_id="test_dagrun_missing_param",
                 state=State.RUNNING,
                 execution_date=TEST_DATE,
+                conf={"param1": None},
             )
 
+        dag = models.DAG('dummy-dag', schedule_interval=None, 
params={'param1': Param(type="string")})
+        dag.create_dagrun(
+            run_id="test_dagrun_missing_param",
+            state=State.RUNNING,
+            execution_date=TEST_DATE,
+            conf={"param1": "hello"},
+        )
+
 
 class TestDagModel:
     def test_dags_needing_dagruns_not_too_early(self):
diff --git a/tests/models/test_param.py b/tests/models/test_param.py
index ea25f21..a9f260b 100644
--- a/tests/models/test_param.py
+++ b/tests/models/test_param.py
@@ -31,15 +31,25 @@ class TestParam(unittest.TestCase):
         p = Param('test')
         assert p.resolve() == 'test'
 
-        p.default = 10
+        p.value = 10
         assert p.resolve() == 10
 
     def test_null_param(self):
         p = Param()
+        with pytest.raises(TypeError, match='No value passed and Param has no 
default value'):
+            p.resolve()
+        assert p.resolve(None) is None
+
+        p = Param(None)
         assert p.resolve() is None
+        assert p.resolve(None) is None
 
         p = Param(type="null")
+        p = Param(None, type='null')
         assert p.resolve() is None
+        assert p.resolve(None) is None
+        with pytest.raises(ValueError):
+            p.resolve('test')
 
     def test_string_param(self):
         p = Param('test', type='string')
@@ -51,8 +61,10 @@ class TestParam(unittest.TestCase):
         p = Param('10.0.0.0', type='string', format='ipv4')
         assert p.resolve() == '10.0.0.0'
 
+        p = Param(type='string')
         with pytest.raises(ValueError):
-            p = Param(type='string')
+            p.resolve(None)
+        with pytest.raises(TypeError, match='No value passed and Param has no 
default value'):
             p.resolve()
 
     def test_int_param(self):
@@ -96,7 +108,7 @@ class TestParam(unittest.TestCase):
         p = Param('abc', type='string', minLength=2, maxLength=4)
         assert p.resolve() == 'abc'
 
-        p.default = 'long_string'
+        p.value = 'long_string'
         assert p.resolve(suppress_exception=True) is None
 
     def test_explicit_schema(self):
@@ -115,6 +127,19 @@ class TestParam(unittest.TestCase):
         with pytest.raises(ValueError):
             p = S3Param("file://not_valid/s3_path")
 
+    def test_value_saved(self):
+        p = Param("hello", type="string")
+        assert p.resolve("world") == "world"
+        assert p.resolve() == "world"
+
+    def test_dump(self):
+        p = Param('hello', description='world', type='string', minLength=2)
+        dump = p.dump()
+        assert dump['__class'] == 'airflow.models.param.Param'
+        assert dump['value'] == 'hello'
+        assert dump['description'] == 'world'
+        assert dump['schema'] == {'type': 'string', 'minLength': 2}
+
 
 class TestParamsDict(unittest.TestCase):
     def test_params_dict(self):
diff --git a/tests/serialization/test_dag_serialization.py 
b/tests/serialization/test_dag_serialization.py
index 22eea07..afba96a 100644
--- a/tests/serialization/test_dag_serialization.py
+++ b/tests/serialization/test_dag_serialization.py
@@ -1450,6 +1450,23 @@ class TestStringifiedDAGs:
         assert isinstance(dict.__getitem__(dag.params, "none"), Param)
         assert dag.params["str"] == "str"
 
+    def test_params_serialize_default(self):
+        serialized = {
+            "__version": 1,
+            "dag": {
+                "_dag_id": "simple_dag",
+                "fileloc": __file__,
+                "tasks": [],
+                "timezone": "UTC",
+                "params": {"str": {"__class": "airflow.models.param.Param", 
"default": "str"}},
+            },
+        }
+        SerializedDAG.validate_schema(serialized)
+        dag = SerializedDAG.from_dict(serialized)
+
+        assert isinstance(dict.__getitem__(dag.params, "str"), Param)
+        assert dag.params["str"] == "str"
+
 
 def test_kubernetes_optional():
     """Serialisation / deserialisation continues to work without kubernetes 
installed"""

Reply via email to