This is an automated email from the ASF dual-hosted git repository. jedcunningham pushed a commit to branch v2-2-test in repository https://gitbox.apache.org/repos/asf/airflow.git
commit 4f269594366f501974c6c37bde5d19f615761859 Author: Jed Cunningham <[email protected]> AuthorDate: Tue Oct 19 14:30:15 2021 -0600 Allow Param to support a default value of ``None`` (#19034) (cherry picked from commit ee4fb073b6d3c64437fc27b673ed9b6568953f46) --- UPDATING.md | 14 ++++++++ airflow/models/dag.py | 4 +-- airflow/models/param.py | 39 +++++++++++++++++----- airflow/serialization/serialized_objects.py | 18 +++------- tests/api_connexion/endpoints/test_dag_endpoint.py | 6 ++-- .../api_connexion/endpoints/test_task_endpoint.py | 6 ++-- tests/api_connexion/schemas/test_dag_schema.py | 2 +- tests/api_connexion/schemas/test_task_schema.py | 2 +- tests/models/test_dag.py | 16 +++++++++ tests/models/test_param.py | 31 +++++++++++++++-- tests/serialization/test_dag_serialization.py | 17 ++++++++++ 11 files changed, 119 insertions(+), 36 deletions(-) diff --git a/UPDATING.md b/UPDATING.md index 8ffd87d..9c0c4fd 100644 --- a/UPDATING.md +++ b/UPDATING.md @@ -77,6 +77,20 @@ https://developers.google.com/style/inclusive-documentation --> +### ``Param``'s default value for ``default`` removed + +``Param``, introduced in Airflow 2.2.0, accidentally gave its ``default`` argument a default value of ``None``. This default has been removed. If you want ``None`` as your default, explicitly set it as such. For example: + +```python +Param(None, type=["null", "string"]) +``` + +Now if you resolve a ``Param`` without a default and don't pass a value, you will get a ``TypeError``. For example: + +```python +Param().resolve() # raises TypeError +``` + ## Airflow 2.2.0 ### `worker_log_server_port` configuration has been moved to the ``logging`` section. 
diff --git a/airflow/models/dag.py b/airflow/models/dag.py index 129a30a..b372cfe 100644 --- a/airflow/models/dag.py +++ b/airflow/models/dag.py @@ -2631,8 +2631,8 @@ class DAG(LoggingMixin): return for k, v in self.params.items(): - # As type can be an array, we would check if `null` is a allowed type or not - if v.default is None and ("type" not in v.schema or "null" not in v.schema["type"]): + # As type can be an array, we would check if `null` is an allowed type or not + if not v.has_value and ("type" not in v.schema or "null" not in v.schema["type"]): raise AirflowException( "DAG Schedule must be None, if there are any required params without default values" ) diff --git a/airflow/models/param.py b/airflow/models/param.py index 2869481..1ae01dc 100644 --- a/airflow/models/param.py +++ b/airflow/models/param.py @@ -24,6 +24,16 @@ from jsonschema.exceptions import ValidationError from airflow.exceptions import AirflowException +class NoValueSentinel: + """Sentinel class used to distinguish between None and no passed value""" + + def __str__(self): + return "NoValueSentinel" + + def __repr__(self): + return "NoValueSentinel" + + class Param: """ Class to hold the default value of a Param and rule set to do the validations. Without the rule set @@ -38,22 +48,25 @@ class Param: :type schema: dict """ - def __init__(self, default: Any = None, description: str = None, **kwargs): - self.default = default + __NO_VALUE_SENTINEL = NoValueSentinel() + + def __init__(self, default: Any = __NO_VALUE_SENTINEL, description: str = None, **kwargs): + self.value = default self.description = description self.schema = kwargs.pop('schema') if 'schema' in kwargs else kwargs - # If default is not None, then validate it once, may raise ValueError - if default: + # If we have a value, validate it once. May raise ValueError. 
+ if self.has_value: try: - jsonschema.validate(self.default, self.schema, format_checker=FormatChecker()) + jsonschema.validate(self.value, self.schema, format_checker=FormatChecker()) except ValidationError as err: raise ValueError(err) - def resolve(self, value: Optional[Any] = None, suppress_exception: bool = False) -> Any: + def resolve(self, value: Optional[Any] = __NO_VALUE_SENTINEL, suppress_exception: bool = False) -> Any: """ Runs the validations and returns the Param's final value. - May raise ValueError on failed validations. + May raise ValueError on failed validations, or TypeError + if no value is passed and no value already exists. :param value: The value to be updated for the Param :type value: Optional[Any] @@ -61,14 +74,18 @@ class Param: If true and validations fails, the return value would be None. :type suppress_exception: bool """ + final_val = value if value != self.__NO_VALUE_SENTINEL else self.value + if isinstance(final_val, NoValueSentinel): + if suppress_exception: + return None + raise TypeError("No value passed and Param has no default value") try: - final_val = value or self.default jsonschema.validate(final_val, self.schema, format_checker=FormatChecker()) - self.default = final_val except ValidationError as err: if suppress_exception: return None raise ValueError(err) from None + self.value = final_val return final_val def dump(self) -> dict: @@ -77,6 +94,10 @@ class Param: out_dict.update(self.__dict__) return out_dict + @property + def has_value(self) -> bool: + return not isinstance(self.value, NoValueSentinel) + class ParamsDict(dict): """ diff --git a/airflow/serialization/serialized_objects.py b/airflow/serialization/serialized_objects.py index 3d30ca4..6ec9770 100644 --- a/airflow/serialization/serialized_objects.py +++ b/airflow/serialization/serialized_objects.py @@ -416,7 +416,9 @@ class BaseSerialization: for k, v in params.items(): # TODO: As of now, we would allow serialization of params which are of type Param only if 
f'{v.__module__}.{v.__class__.__name__}' == 'airflow.models.param.Param': - serialized_params[k] = v.dump() + kwargs = v.dump() + kwargs['default'] = kwargs.pop('value') + serialized_params[k] = kwargs else: raise ValueError('Params to a DAG or a Task can be only of type airflow.models.param.Param') return serialized_params @@ -543,7 +545,7 @@ class SerializedBaseOperator(BaseOperator, BaseSerialization): serialize_op[template_field] = serialize_template_field(value) if op.params: - serialize_op['params'] = cls._serialize_operator_params(op.params) + serialize_op['params'] = cls._serialize_params_dict(op.params) return serialize_op @@ -747,18 +749,6 @@ class SerializedBaseOperator(BaseOperator, BaseSerialization): return serialize_operator_extra_links - @classmethod - def _serialize_operator_params(cls, op_params: ParamsDict): - """Serialize Params dict of a operator""" - serialized_params = {} - for k, v in op_params.items(): - # TODO: As of now, we would allow serialization of params which are of type Param only - if f'{v.__module__}.{v.__class__.__name__}' == 'airflow.models.param.Param': - serialized_params[k] = v.dump() - else: - raise ValueError('Params to a Task can be only of type airflow.models.param.Param') - return serialized_params - class SerializedDAG(DAG, BaseSerialization): """ diff --git a/tests/api_connexion/endpoints/test_dag_endpoint.py b/tests/api_connexion/endpoints/test_dag_endpoint.py index e9602fa..73dac72 100644 --- a/tests/api_connexion/endpoints/test_dag_endpoint.py +++ b/tests/api_connexion/endpoints/test_dag_endpoint.py @@ -240,7 +240,7 @@ class TestGetDagDetails(TestDagEndpoint): "params": { "foo": { '__class': 'airflow.models.param.Param', - 'default': 1, + 'value': 1, 'description': None, 'schema': {}, } @@ -353,7 +353,7 @@ class TestGetDagDetails(TestDagEndpoint): "params": { "foo": { '__class': 'airflow.models.param.Param', - 'default': 1, + 'value': 1, 'description': None, 'schema': {}, } @@ -400,7 +400,7 @@ class 
TestGetDagDetails(TestDagEndpoint): "params": { "foo": { '__class': 'airflow.models.param.Param', - 'default': 1, + 'value': 1, 'description': None, 'schema': {}, } diff --git a/tests/api_connexion/endpoints/test_task_endpoint.py b/tests/api_connexion/endpoints/test_task_endpoint.py index 242df7f..2630d49 100644 --- a/tests/api_connexion/endpoints/test_task_endpoint.py +++ b/tests/api_connexion/endpoints/test_task_endpoint.py @@ -100,7 +100,7 @@ class TestGetTask(TestTaskEndpoint): 'params': { 'foo': { '__class': 'airflow.models.param.Param', - 'default': 'bar', + 'value': 'bar', 'description': None, 'schema': {}, } @@ -150,7 +150,7 @@ class TestGetTask(TestTaskEndpoint): 'params': { 'foo': { '__class': 'airflow.models.param.Param', - 'default': 'bar', + 'value': 'bar', 'description': None, 'schema': {}, } @@ -215,7 +215,7 @@ class TestGetTasks(TestTaskEndpoint): 'params': { 'foo': { '__class': 'airflow.models.param.Param', - 'default': 'bar', + 'value': 'bar', 'description': None, 'schema': {}, } diff --git a/tests/api_connexion/schemas/test_dag_schema.py b/tests/api_connexion/schemas/test_dag_schema.py index 3f6d3b0..f12c06d 100644 --- a/tests/api_connexion/schemas/test_dag_schema.py +++ b/tests/api_connexion/schemas/test_dag_schema.py @@ -133,7 +133,7 @@ class TestDAGDetailSchema: 'params': { 'foo': { '__class': 'airflow.models.param.Param', - 'default': 1, + 'value': 1, 'description': None, 'schema': {}, } diff --git a/tests/api_connexion/schemas/test_task_schema.py b/tests/api_connexion/schemas/test_task_schema.py index f4a4867..65006ed 100644 --- a/tests/api_connexion/schemas/test_task_schema.py +++ b/tests/api_connexion/schemas/test_task_schema.py @@ -81,7 +81,7 @@ class TestTaskCollectionSchema: 'params': { 'foo': { '__class': 'airflow.models.param.Param', - 'default': 'bar', + 'value': 'bar', 'description': None, 'schema': {}, } diff --git a/tests/models/test_dag.py b/tests/models/test_dag.py index e77f271..dfa1a0a 100644 --- a/tests/models/test_dag.py +++ 
b/tests/models/test_dag.py @@ -1799,14 +1799,30 @@ class TestDag(unittest.TestCase): def test_validate_params_on_trigger_dag(self): dag = models.DAG('dummy-dag', schedule_interval=None, params={'param1': Param(type="string")}) + with pytest.raises(TypeError, match="No value passed and Param has no default value"): + dag.create_dagrun( + run_id="test_dagrun_missing_param", + state=State.RUNNING, + execution_date=TEST_DATE, + ) + dag = models.DAG('dummy-dag', schedule_interval=None, params={'param1': Param(type="string")}) with pytest.raises(ValueError, match="Invalid input for param param1: None is not of type 'string'"): dag.create_dagrun( run_id="test_dagrun_missing_param", state=State.RUNNING, execution_date=TEST_DATE, + conf={"param1": None}, ) + dag = models.DAG('dummy-dag', schedule_interval=None, params={'param1': Param(type="string")}) + dag.create_dagrun( + run_id="test_dagrun_missing_param", + state=State.RUNNING, + execution_date=TEST_DATE, + conf={"param1": "hello"}, + ) + class TestDagModel: def test_dags_needing_dagruns_not_too_early(self): diff --git a/tests/models/test_param.py b/tests/models/test_param.py index ea25f21..a9f260b 100644 --- a/tests/models/test_param.py +++ b/tests/models/test_param.py @@ -31,15 +31,25 @@ class TestParam(unittest.TestCase): p = Param('test') assert p.resolve() == 'test' - p.default = 10 + p.value = 10 assert p.resolve() == 10 def test_null_param(self): p = Param() + with pytest.raises(TypeError, match='No value passed and Param has no default value'): + p.resolve() + assert p.resolve(None) is None + + p = Param(None) assert p.resolve() is None + assert p.resolve(None) is None p = Param(type="null") + p = Param(None, type='null') assert p.resolve() is None + assert p.resolve(None) is None + with pytest.raises(ValueError): + p.resolve('test') def test_string_param(self): p = Param('test', type='string') @@ -51,8 +61,10 @@ class TestParam(unittest.TestCase): p = Param('10.0.0.0', type='string', format='ipv4') assert 
p.resolve() == '10.0.0.0' + p = Param(type='string') with pytest.raises(ValueError): - p = Param(type='string') + p.resolve(None) + with pytest.raises(TypeError, match='No value passed and Param has no default value'): p.resolve() def test_int_param(self): @@ -96,7 +108,7 @@ class TestParam(unittest.TestCase): p = Param('abc', type='string', minLength=2, maxLength=4) assert p.resolve() == 'abc' - p.default = 'long_string' + p.value = 'long_string' assert p.resolve(suppress_exception=True) is None def test_explicit_schema(self): @@ -115,6 +127,19 @@ class TestParam(unittest.TestCase): with pytest.raises(ValueError): p = S3Param("file://not_valid/s3_path") + def test_value_saved(self): + p = Param("hello", type="string") + assert p.resolve("world") == "world" + assert p.resolve() == "world" + + def test_dump(self): + p = Param('hello', description='world', type='string', minLength=2) + dump = p.dump() + assert dump['__class'] == 'airflow.models.param.Param' + assert dump['value'] == 'hello' + assert dump['description'] == 'world' + assert dump['schema'] == {'type': 'string', 'minLength': 2} + class TestParamsDict(unittest.TestCase): def test_params_dict(self): diff --git a/tests/serialization/test_dag_serialization.py b/tests/serialization/test_dag_serialization.py index 22eea07..afba96a 100644 --- a/tests/serialization/test_dag_serialization.py +++ b/tests/serialization/test_dag_serialization.py @@ -1450,6 +1450,23 @@ class TestStringifiedDAGs: assert isinstance(dict.__getitem__(dag.params, "none"), Param) assert dag.params["str"] == "str" + def test_params_serialize_default(self): + serialized = { + "__version": 1, + "dag": { + "_dag_id": "simple_dag", + "fileloc": __file__, + "tasks": [], + "timezone": "UTC", + "params": {"str": {"__class": "airflow.models.param.Param", "default": "str"}}, + }, + } + SerializedDAG.validate_schema(serialized) + dag = SerializedDAG.from_dict(serialized) + + assert isinstance(dict.__getitem__(dag.params, "str"), Param) + assert 
dag.params["str"] == "str" + def test_kubernetes_optional(): """Serialisation / deserialisation continues to work without kubernetes installed"""
