This is an automated email from the ASF dual-hosted git repository. jedcunningham pushed a commit to branch v2-2-test in repository https://gitbox.apache.org/repos/asf/airflow.git
commit 157a864d67627aacd960848791af913720ad45eb Author: Daniel Standish <[email protected]> AuthorDate: Fri Nov 5 10:19:50 2021 -0700 Fix serialization of Params with set data type (#19267) This is a solution for https://github.com/apache/airflow/issues/19096 Previously, the serialization of params did not run the param value through the `_serialize` function, resulting in non-json-serializable dictionaries. This manifested when a user, for example, tried to use params with a default value of type `set`. Here we change the logic to run the param value through the serialization process. And I add a test for the `set` case. closes https://github.com/apache/airflow/issues/19096 (cherry picked from commit 8512e0507263495ddd326e27699c45cafd31a5e1) --- airflow/models/param.py | 4 +- airflow/serialization/schema.json | 22 +++++++- airflow/serialization/serialized_objects.py | 50 ++++++++++++++----- tests/serialization/test_dag_serialization.py | 72 ++++++++++++++++++++++++--- 4 files changed, 125 insertions(+), 23 deletions(-) diff --git a/airflow/models/param.py b/airflow/models/param.py index 1ae01dc..53ac79a 100644 --- a/airflow/models/param.py +++ b/airflow/models/param.py @@ -14,7 +14,6 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
- from typing import Any, Dict, Optional import jsonschema @@ -49,6 +48,7 @@ class Param: """ __NO_VALUE_SENTINEL = NoValueSentinel() + CLASS_IDENTIFIER = '__class' def __init__(self, default: Any = __NO_VALUE_SENTINEL, description: str = None, **kwargs): self.value = default @@ -90,7 +90,7 @@ class Param: def dump(self) -> dict: """Dump the Param as a dictionary""" - out_dict = {'__class': f'{self.__module__}.{self.__class__.__name__}'} + out_dict = {self.CLASS_IDENTIFIER: f'{self.__module__}.{self.__class__.__name__}'} out_dict.update(self.__dict__) return out_dict diff --git a/airflow/serialization/schema.json b/airflow/serialization/schema.json index b4a64b4..6d25c1e 100644 --- a/airflow/serialization/schema.json +++ b/airflow/serialization/schema.json @@ -78,7 +78,7 @@ "dag": { "type": "object", "properties": { - "params": { "$ref": "#/definitions/dict" }, + "params": { "$ref": "#/definitions/params_dict" }, "_dag_id": { "type": "string" }, "tasks": { "$ref": "#/definitions/tasks" }, "timezone": { "$ref": "#/definitions/timezone" }, @@ -135,6 +135,24 @@ "type": "array", "additionalProperties": { "$ref": "#/definitions/operator" } }, + "params_dict": { + "type": "object", + "additionalProperties": {"$ref": "#/definitions/param" } + }, + "param": { + "$comment": "A param for a dag / operator", + "type": "object", + "required": [ + "__class", + "default" + ], + "properties": { + "__class": { "type": "string" }, + "default": {}, + "description": {"anyOf": [{"type":"string"}, {"type":"null"}]}, + "schema": { "$ref": "#/definitions/dict" } + } + }, "operator": { "$comment": "A task/operator in a DAG", "type": "object", @@ -166,7 +184,7 @@ "retry_delay": { "$ref": "#/definitions/timedelta" }, "retry_exponential_backoff": { "type": "boolean" }, "max_retry_delay": { "$ref": "#/definitions/timedelta" }, - "params": { "$ref": "#/definitions/dict" }, + "params": { "$ref": "#/definitions/params_dict" }, "priority_weight": { "type": "number" }, "weight_rule": { "type": 
"string" }, "executor_config": { "$ref": "#/definitions/dict" }, diff --git a/airflow/serialization/serialized_objects.py b/airflow/serialization/serialized_objects.py index 6ec9770..c451695 100644 --- a/airflow/serialization/serialized_objects.py +++ b/airflow/serialization/serialized_objects.py @@ -55,7 +55,6 @@ try: except ImportError: HAS_KUBERNETES = False - if TYPE_CHECKING: from airflow.ti_deps.deps.base_ti_dep import BaseTIDep @@ -325,7 +324,7 @@ class BaseSerialization: elif isinstance(var, TaskGroup): return SerializedTaskGroup.serialize_task_group(var) elif isinstance(var, Param): - return cls._encode(var.dump(), type_=DAT.PARAM) + return cls._encode(cls._serialize_param(var), type_=DAT.PARAM) else: log.debug('Cast type %s to str in serialization.', type(var)) return str(var) @@ -368,9 +367,7 @@ class BaseSerialization: elif type_ == DAT.TUPLE: return tuple(cls._deserialize(v) for v in var) elif type_ == DAT.PARAM: - param_class = import_string(var['_type']) - del var['_type'] - return param_class(**var) + return cls._deserialize_param(var) else: raise TypeError(f'Invalid type {type_!s} in deserialization.') @@ -410,29 +407,58 @@ class BaseSerialization: return False @classmethod + def _serialize_param(cls, param: Param): + return dict( + __class=f"{param.__module__}.{param.__class__.__name__}", + default=cls._serialize(param.value), + description=cls._serialize(param.description), + schema=cls._serialize(param.schema), + ) + + @classmethod + def _deserialize_param(cls, param_dict: Dict): + """ + In 2.2.0, Param attrs were assumed to be json-serializable and were not run through + this class's ``_serialize`` method. So before running through ``_deserialize``, + we first verify that it's necessary to do. 
+ """ + class_name = param_dict['__class'] + class_ = import_string(class_name) # type: Type[Param] + attrs = ('default', 'description', 'schema') + kwargs = {} + for attr in attrs: + if attr not in param_dict: + continue + val = param_dict[attr] + is_serialized = isinstance(val, dict) and '__type' in val + if is_serialized: + deserialized_val = cls._deserialize(param_dict[attr]) + kwargs[attr] = deserialized_val + else: + kwargs[attr] = val + return class_(**kwargs) + + @classmethod def _serialize_params_dict(cls, params: ParamsDict): """Serialize Params dict for a DAG/Task""" serialized_params = {} for k, v in params.items(): # TODO: As of now, we would allow serialization of params which are of type Param only if f'{v.__module__}.{v.__class__.__name__}' == 'airflow.models.param.Param': - kwargs = v.dump() - kwargs['default'] = kwargs.pop('value') - serialized_params[k] = kwargs + serialized_params[k] = cls._serialize_param(v) else: raise ValueError('Params to a DAG or a Task can be only of type airflow.models.param.Param') return serialized_params @classmethod def _deserialize_params_dict(cls, encoded_params: Dict) -> ParamsDict: - """Deserialize a DAGs Params dict""" + """Deserialize a DAG's Params dict""" op_params = {} for k, v in encoded_params.items(): if isinstance(v, dict) and "__class" in v: - param_class = import_string(v['__class']) - op_params[k] = param_class(**v) + op_params[k] = cls._deserialize_param(v) else: - # Old style params, upgrade it + # Old style params, convert it op_params[k] = Param(v) return ParamsDict(op_params) diff --git a/tests/serialization/test_dag_serialization.py b/tests/serialization/test_dag_serialization.py index afba96a..6fec7f9 100644 --- a/tests/serialization/test_dag_serialization.py +++ b/tests/serialization/test_dag_serialization.py @@ -21,6 +21,7 @@ import copy import importlib import importlib.util +import json import multiprocessing import os from datetime import datetime, timedelta @@ -724,6 +725,7 @@ class 
TestStringifiedDAGs: [ (None, {}), ({"param_1": "value_1"}, {"param_1": "value_1"}), + ({"param_1": {1, 2, 3}}, {"param_1": {1, 2, 3}}), ], ) def test_dag_params_roundtrip(self, val, expected_val): @@ -733,7 +735,10 @@ class TestStringifiedDAGs: dag = DAG(dag_id='simple_dag', params=val) BaseOperator(task_id='simple_task', dag=dag, start_date=datetime(2019, 8, 1)) - serialized_dag = SerializedDAG.to_dict(dag) + serialized_dag_json = SerializedDAG.to_json(dag) + + serialized_dag = json.loads(serialized_dag_json) + assert "params" in serialized_dag["dag"] deserialized_dag = SerializedDAG.from_dict(serialized_dag) @@ -764,14 +769,37 @@ class TestStringifiedDAGs: params={'path': S3Param('s3://my_bucket/my_path')}, ) - with pytest.raises(SerializationError): - SerializedDAG.to_dict(dag) + @pytest.mark.parametrize( + 'param', + [ + Param('my value', description='hello', schema={'type': 'string'}), + Param('my value', description='hello'), + Param(None, description=None), + ], + ) + def test_full_param_roundtrip(self, param): + """ + Test to make sure that only native Param objects are being passed as dag or task params + """ + + dag = DAG(dag_id='simple_dag', params={'my_param': param}) + serialized_json = SerializedDAG.to_json(dag) + serialized = json.loads(serialized_json) + SerializedDAG.validate_schema(serialized) + dag = SerializedDAG.from_dict(serialized) + + assert dag.params["my_param"] == param.value + observed_param = dict.get(dag.params, 'my_param') + assert isinstance(observed_param, Param) + assert observed_param.description == param.description + assert observed_param.schema == param.schema @pytest.mark.parametrize( "val, expected_val", [ (None, {}), ({"param_1": "value_1"}, {"param_1": "value_1"}), + ({"param_1": {1, 2, 3}}, {"param_1": {1, 2, 3}}), ], ) def test_task_params_roundtrip(self, val, expected_val): @@ -1433,29 +1461,32 @@ class TestStringifiedDAGs: assert serialized_obj == expected_output def test_params_upgrade(self): + """when pre-2.2.0 param 
(i.e. primitive) is deserialized we convert to Param""" serialized = { "__version": 1, "dag": { "_dag_id": "simple_dag", - "fileloc": __file__, + "fileloc": '/path/to/file.py', "tasks": [], "timezone": "UTC", "params": {"none": None, "str": "str", "dict": {"a": "b"}}, }, } - SerializedDAG.validate_schema(serialized) dag = SerializedDAG.from_dict(serialized) assert dag.params["none"] is None assert isinstance(dict.__getitem__(dag.params, "none"), Param) assert dag.params["str"] == "str" - def test_params_serialize_default(self): + def test_params_serialize_default_2_2_0(self): + """In 2.2.0, param ``default`` was assumed to be a json-serializable object and was not run through + the standard serializer function. In 2.2.2 we serialize param ``default``. We keep this + test only to ensure that params stored in 2.2.0 can still be parsed correctly.""" serialized = { "__version": 1, "dag": { "_dag_id": "simple_dag", - "fileloc": __file__, + "fileloc": '/path/to/file.py', "tasks": [], "timezone": "UTC", "params": {"str": {"__class": "airflow.models.param.Param", "default": "str"}}, @@ -1467,6 +1498,33 @@ class TestStringifiedDAGs: assert isinstance(dict.__getitem__(dag.params, "str"), Param) assert dag.params["str"] == "str" + def test_params_serialize_default(self): + serialized = { + "__version": 1, + "dag": { + "_dag_id": "simple_dag", + "fileloc": '/path/to/file.py', + "tasks": [], + "timezone": "UTC", + "params": { + "my_param": { + "default": "a string value", + "description": "hello", + "schema": {"__var": {"type": "string"}, "__type": "dict"}, + "__class": "airflow.models.param.Param", + } + }, + }, + } + SerializedDAG.validate_schema(serialized) + dag = SerializedDAG.from_dict(serialized) + + assert dag.params["my_param"] == "a string value" + param = dict.get(dag.params, 'my_param') + assert isinstance(param, Param) + assert param.description == 'hello' + assert param.schema == {'type': 'string'} + def test_kubernetes_optional(): """Serialisation / 
deserialisation continues to work without kubernetes installed"""
