This is an automated email from the ASF dual-hosted git repository.
ash pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new a5d5bd0232 Fix deserializing Params where the default is an array
(#27944)
a5d5bd0232 is described below
commit a5d5bd0232b98c6b39e587dd144086f4b7d8664d
Author: Ash Berlin-Taylor <[email protected]>
AuthorDate: Sat Nov 26 21:41:01 2022 +0000
Fix deserializing Params where the default is an array (#27944)
A previous change added deserialization of Param values inside a list, but
the tests didn't cover an array of plain values (`[True]`, for instance).
For such a list the old check evaluated `Encoding.TYPE in True`, which
raises a TypeError.
This caused the webserver to return HTTP 500 errors (bad, but it only
affected a single DAG), and it _also_ caused the scheduler to crash when it
tried to process this DAG (bad, bordering on terrible! Nothing should ever
bring down the whole scheduler).
---
airflow/serialization/serialized_objects.py | 13 +++++++++----
tests/serialization/test_dag_serialization.py | 1 +
2 files changed, 10 insertions(+), 4 deletions(-)
diff --git a/airflow/serialization/serialized_objects.py
b/airflow/serialization/serialized_objects.py
index d41d70a4e7..6cb33cd203 100644
--- a/airflow/serialization/serialized_objects.py
+++ b/airflow/serialization/serialized_objects.py
@@ -560,14 +560,19 @@ class BaseSerialization:
class_: type[Param] = import_string(class_name)
attrs = ("default", "description", "schema")
kwargs = {}
+
+ def is_serialized(val):
+ if isinstance(val, dict):
+ return Encoding.TYPE in val
+ if isinstance(val, list):
+ return all(isinstance(item, dict) and Encoding.TYPE in item
for item in val)
+ return False
+
for attr in attrs:
if attr not in param_dict:
continue
val = param_dict[attr]
- is_serialized = (isinstance(val, dict) and Encoding.TYPE in val)
or (
- isinstance(val, list) and all(Encoding.TYPE in param for param
in val)
- )
- if is_serialized:
+ if is_serialized(val):
deserialized_val = cls.deserialize(param_dict[attr])
kwargs[attr] = deserialized_val
else:
diff --git a/tests/serialization/test_dag_serialization.py
b/tests/serialization/test_dag_serialization.py
index 994ba12b32..44411f5c07 100644
--- a/tests/serialization/test_dag_serialization.py
+++ b/tests/serialization/test_dag_serialization.py
@@ -836,6 +836,7 @@ class TestStringifiedDAGs:
Param("my value", description="hello", schema={"type": "string"}),
Param("my value", description="hello"),
Param(None, description=None),
+ Param([True], type="array", items={"type": "boolean"}),
],
)
def test_full_param_roundtrip(self, param):