This is an automated email from the ASF dual-hosted git repository.

jedcunningham pushed a commit to branch v2-2-test
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/v2-2-test by this push:
     new fe30eb7  Upgrade old DAG/task param format when deserializing from the DB (#18986)
fe30eb7 is described below

commit fe30eb71a3cfdcee727b7598bfa77f9f0befeb6e
Author: Ash Berlin-Taylor <[email protected]>
AuthorDate: Fri Oct 15 17:27:45 2021 +0100

    Upgrade old DAG/task param format when deserializing from the DB (#18986)
    
    (cherry picked from commit 6032159aadc65b39d385dfa4d24e7f77899f6d46)
---
 airflow/serialization/serialized_objects.py   | 66 ++++++++++++---------------
 tests/serialization/test_dag_serialization.py | 18 ++++++++
 2 files changed, 47 insertions(+), 37 deletions(-)

diff --git a/airflow/serialization/serialized_objects.py b/airflow/serialization/serialized_objects.py
index 5200319..3d30ca4 100644
--- a/airflow/serialization/serialized_objects.py
+++ b/airflow/serialization/serialized_objects.py
@@ -409,6 +409,32 @@ class BaseSerialization:
             return True
         return False
 
+    @classmethod
+    def _serialize_params_dict(cls, params: ParamsDict) -> Dict:
+        """Serialize a DAG's or Task's Params dict; every value must be exactly a Param"""
+        serialized_params = {}
+        for k, v in params.items():
+            # TODO: As of now, only params of exactly type Param (not subclasses) can be serialized
+            if f'{type(v).__module__}.{type(v).__name__}' == 'airflow.models.param.Param':
+                serialized_params[k] = v.dump()
+            else:
+                raise ValueError('Params to a DAG or a Task can be only of type airflow.models.param.Param')
+        return serialized_params
+
+    @classmethod
+    def _deserialize_params_dict(cls, encoded_params: Dict) -> ParamsDict:
+        """Deserialize a DAG's or Task's Params dict, upgrading old-style plain values to Param"""
+        op_params = {}
+        for k, v in encoded_params.items():
+            if isinstance(v, dict) and "__class" in v:
+                kwargs = dict(v)  # copy so the caller's encoded dict is not mutated
+                param_class = import_string(kwargs.pop('__class'))
+                op_params[k] = param_class(**kwargs)
+            else:
+                # Old style params (a plain value rather than a dumped Param), upgrade it
+                op_params[k] = Param(v)
+        return ParamsDict(op_params)
+
 
 class DependencyDetector:
     """Detects dependencies between DAGs."""
@@ -584,7 +610,7 @@ class SerializedBaseOperator(BaseOperator, BaseSerialization):
             elif k == "deps":
                 v = cls._deserialize_deps(v)
             elif k == "params":
-                v = cls._deserialize_operator_params(v)
+                v = cls._deserialize_params_dict(v)
             elif k in cls._decorated_fields or k not in op.get_serialized_fields():
                 v = cls._deserialize(v)
             # else use v as it is
@@ -722,17 +748,6 @@ class SerializedBaseOperator(BaseOperator, BaseSerialization):
         return serialize_operator_extra_links
 
     @classmethod
-    def _deserialize_operator_params(cls, encoded_op_params: Dict) -> Dict[str, Param]:
-        """Deserialize Params dict of a operator"""
-        op_params = {}
-        for k, v in encoded_op_params.items():
-            param_class = import_string(v['__class'])
-            del v['__class']
-            op_params[k] = param_class(**v)
-
-        return ParamsDict(op_params)
-
-    @classmethod
     def _serialize_operator_params(cls, op_params: ParamsDict):
         """Serialize Params dict of a operator"""
         serialized_params = {}
@@ -802,7 +817,7 @@ class SerializedDAG(DAG, BaseSerialization):
 
             # Edge info in the JSON exactly matches our internal structure
             serialize_dag["edge_info"] = dag.edge_info
-            serialize_dag["params"] = cls._serialize_dag_params(dag.params)
+            serialize_dag["params"] = cls._serialize_params_dict(dag.params)
 
             # has_on_*_callback are only stored if the value is True, as the default is False
             if dag.has_on_success_callback:
@@ -843,7 +858,7 @@ class SerializedDAG(DAG, BaseSerialization):
             elif k in cls._decorated_fields:
                 v = cls._deserialize(v)
             elif k == "params":
-                v = cls._deserialize_dag_params(v)
+                v = cls._deserialize_params_dict(v)
             # else use v as it is
 
             setattr(dag, k, v)
@@ -915,29 +930,6 @@ class SerializedDAG(DAG, BaseSerialization):
             raise ValueError(f"Unsure how to deserialize version {ver!r}")
         return cls.deserialize_dag(serialized_obj['dag'])
 
-    @classmethod
-    def _serialize_dag_params(cls, dag_params: ParamsDict):
-        """Serialize Params dict for a DAG"""
-        serialized_params = {}
-        for k, v in dag_params.items():
-            # TODO: As of now, we would allow serialization of params which are of type Param only
-            if f'{v.__module__}.{v.__class__.__name__}' == 'airflow.models.param.Param':
-                serialized_params[k] = v.dump()
-            else:
-                raise ValueError('Params to a DAG can be only of type airflow.models.param.Param')
-        return serialized_params
-
-    @classmethod
-    def _deserialize_dag_params(cls, encoded_dag_params: Dict) -> ParamsDict:
-        """Deserialize a DAGs Params dict"""
-        op_params = {}
-        for k, v in encoded_dag_params.items():
-            param_class = import_string(v['__class'])
-            del v['__class']
-            op_params[k] = param_class(**v)
-
-        return ParamsDict(op_params)
-
 
 class SerializedTaskGroup(TaskGroup, BaseSerialization):
     """A JSON serializable representation of TaskGroup."""
diff --git a/tests/serialization/test_dag_serialization.py b/tests/serialization/test_dag_serialization.py
index 5db4fc8..22eea07 100644
--- a/tests/serialization/test_dag_serialization.py
+++ b/tests/serialization/test_dag_serialization.py
@@ -1432,6 +1432,24 @@ class TestStringifiedDAGs:
             serialized_obj = serialized_obj["__var"]
         assert serialized_obj == expected_output
 
+    def test_params_upgrade(self):
+        serialized = {
+            "__version": 1,
+            "dag": {
+                "_dag_id": "simple_dag",
+                "fileloc": __file__,
+                "tasks": [],
+                "timezone": "UTC",
+                "params": {"none": None, "str": "str", "dict": {"a": "b"}},
+            },
+        }
+        SerializedDAG.validate_schema(serialized)
+        dag = SerializedDAG.from_dict(serialized)
+        assert dag.params["none"] is None
+        assert isinstance(dict.__getitem__(dag.params, "none"), Param)  # raw stored value
+        assert dag.params["str"] == "str"
+        assert dag.params["dict"] == {"a": "b"}  # plain dict without "__class" is old-style
+
 
 def test_kubernetes_optional():
     """Serialisation / deserialisation continues to work without kubernetes 
installed"""

Reply via email to