This is an automated email from the ASF dual-hosted git repository.
jedcunningham pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 6032159 Upgrade old DAG/task param format when deserializing from the DB (#18986)
6032159 is described below
commit 6032159aadc65b39d385dfa4d24e7f77899f6d46
Author: Ash Berlin-Taylor <[email protected]>
AuthorDate: Fri Oct 15 17:27:45 2021 +0100
Upgrade old DAG/task param format when deserializing from the DB (#18986)
---
airflow/serialization/serialized_objects.py | 66 ++++++++++++---------------
tests/serialization/test_dag_serialization.py | 18 ++++++++
2 files changed, 47 insertions(+), 37 deletions(-)
diff --git a/airflow/serialization/serialized_objects.py b/airflow/serialization/serialized_objects.py
index 5200319..3d30ca4 100644
--- a/airflow/serialization/serialized_objects.py
+++ b/airflow/serialization/serialized_objects.py
@@ -409,6 +409,32 @@ class BaseSerialization:
return True
return False
+    @classmethod
+    def _serialize_params_dict(cls, params: ParamsDict):
+        """Serialize a DAG/Task ParamsDict; every value must be exactly ``airflow.models.param.Param``."""
+        serialized_params = {}
+        for k, v in params.items():
+            # TODO: For now only params of exact type Param are serializable (subclasses are rejected too)
+            if f'{v.__module__}.{v.__class__.__name__}' == 'airflow.models.param.Param':
+                serialized_params[k] = v.dump()
+            else:
+                raise ValueError('Params to a DAG or a Task can be only of type airflow.models.param.Param')
+        return serialized_params
+
+    @classmethod
+    def _deserialize_params_dict(cls, encoded_params: Dict) -> ParamsDict:
+        """Deserialize a DAG's or task's Params dict, wrapping old-style raw values in Param."""
+        op_params = {}
+        for k, v in encoded_params.items():
+            if isinstance(v, dict) and "__class" in v:
+                param_class = import_string(v['__class'])
+                # Drop the "__class" marker without mutating v, so it is not forwarded as a kwarg
+                op_params[k] = param_class(**{attr: val for attr, val in v.items() if attr != '__class'})
+            else:
+                # Old style params (plain values, no "__class" marker), upgrade it
+                op_params[k] = Param(v)
+        return ParamsDict(op_params)
+
class DependencyDetector:
"""Detects dependencies between DAGs."""
@@ -584,7 +610,7 @@ class SerializedBaseOperator(BaseOperator, BaseSerialization):
elif k == "deps":
v = cls._deserialize_deps(v)
elif k == "params":
- v = cls._deserialize_operator_params(v)
+ v = cls._deserialize_params_dict(v)
elif k in cls._decorated_fields or k not in
op.get_serialized_fields():
v = cls._deserialize(v)
# else use v as it is
@@ -722,17 +748,6 @@ class SerializedBaseOperator(BaseOperator, BaseSerialization):
return serialize_operator_extra_links
@classmethod
- def _deserialize_operator_params(cls, encoded_op_params: Dict) ->
Dict[str, Param]:
- """Deserialize Params dict of a operator"""
- op_params = {}
- for k, v in encoded_op_params.items():
- param_class = import_string(v['__class'])
- del v['__class']
- op_params[k] = param_class(**v)
-
- return ParamsDict(op_params)
-
- @classmethod
def _serialize_operator_params(cls, op_params: ParamsDict):
"""Serialize Params dict of a operator"""
serialized_params = {}
@@ -802,7 +817,7 @@ class SerializedDAG(DAG, BaseSerialization):
# Edge info in the JSON exactly matches our internal structure
serialize_dag["edge_info"] = dag.edge_info
- serialize_dag["params"] = cls._serialize_dag_params(dag.params)
+ serialize_dag["params"] = cls._serialize_params_dict(dag.params)
# has_on_*_callback are only stored if the value is True, as the
default is False
if dag.has_on_success_callback:
@@ -843,7 +858,7 @@ class SerializedDAG(DAG, BaseSerialization):
elif k in cls._decorated_fields:
v = cls._deserialize(v)
elif k == "params":
- v = cls._deserialize_dag_params(v)
+ v = cls._deserialize_params_dict(v)
# else use v as it is
setattr(dag, k, v)
@@ -915,29 +930,6 @@ class SerializedDAG(DAG, BaseSerialization):
raise ValueError(f"Unsure how to deserialize version {ver!r}")
return cls.deserialize_dag(serialized_obj['dag'])
- @classmethod
- def _serialize_dag_params(cls, dag_params: ParamsDict):
- """Serialize Params dict for a DAG"""
- serialized_params = {}
- for k, v in dag_params.items():
- # TODO: As of now, we would allow serialization of params which
are of type Param only
- if f'{v.__module__}.{v.__class__.__name__}' ==
'airflow.models.param.Param':
- serialized_params[k] = v.dump()
- else:
- raise ValueError('Params to a DAG can be only of type
airflow.models.param.Param')
- return serialized_params
-
- @classmethod
- def _deserialize_dag_params(cls, encoded_dag_params: Dict) -> ParamsDict:
- """Deserialize a DAGs Params dict"""
- op_params = {}
- for k, v in encoded_dag_params.items():
- param_class = import_string(v['__class'])
- del v['__class']
- op_params[k] = param_class(**v)
-
- return ParamsDict(op_params)
-
class SerializedTaskGroup(TaskGroup, BaseSerialization):
"""A JSON serializable representation of TaskGroup."""
diff --git a/tests/serialization/test_dag_serialization.py b/tests/serialization/test_dag_serialization.py
index 5db4fc8..22eea07 100644
--- a/tests/serialization/test_dag_serialization.py
+++ b/tests/serialization/test_dag_serialization.py
@@ -1432,6 +1432,24 @@ class TestStringifiedDAGs:
serialized_obj = serialized_obj["__var"]
assert serialized_obj == expected_output
+    def test_params_upgrade(self):
+        serialized = {
+            "__version": 1,
+            "dag": {
+                "_dag_id": "simple_dag",
+                "fileloc": __file__,
+                "tasks": [],
+                "timezone": "UTC",
+                "params": {"none": None, "str": "str", "dict": {"a": "b"}},
+            },
+        }
+        SerializedDAG.validate_schema(serialized)
+        dag = SerializedDAG.from_dict(serialized)
+        assert dag.params["none"] is None
+        assert isinstance(dict.__getitem__(dag.params, "none"), Param)
+        assert dag.params["str"] == "str"
+        assert dag.params["dict"] == {"a": "b"}
+
def test_kubernetes_optional():
"""Serialisation / deserialisation continues to work without kubernetes
installed"""