This is an automated email from the ASF dual-hosted git repository. potiuk pushed a commit to branch v1-10-test in repository https://gitbox.apache.org/repos/asf/airflow.git
commit db3f27ef0b95b2d30970c32fa9ad257b5dff0990 Author: Ash Berlin-Taylor <[email protected]> AuthorDate: Sun May 10 11:41:47 2020 +0100 Correctly restore upstream_task_ids when deserializing Operators (#8775) This test exposed a bug in one of the example DAGs that wasn't caught by #6549. That will be fixed in a separate issue, but it caused the round-trip tests to fail here Fixes #8720 (cherry picked from commit 280f1f0c4cc49aba1b2f8b456326795733769d18) --- airflow/serialization/serialized_objects.py | 2 +- tests/serialization/test_dag_serialization.py | 3 +++ 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/airflow/serialization/serialized_objects.py b/airflow/serialization/serialized_objects.py index 3e564ec..8d261aa 100644 --- a/airflow/serialization/serialized_objects.py +++ b/airflow/serialization/serialized_objects.py @@ -591,7 +591,7 @@ class SerializedDAG(DAG, BaseSerialization): for task_id in serializable_task.downstream_task_ids: # Bypass set_upstream etc here - it does more than we want # noinspection PyProtectedMember - dag.task_dict[task_id]._upstream_task_ids.add(task_id) # pylint: disable=protected-access + dag.task_dict[task_id]._upstream_task_ids.add(serializable_task.task_id) # noqa: E501 # pylint: disable=protected-access return dag diff --git a/tests/serialization/test_dag_serialization.py b/tests/serialization/test_dag_serialization.py index e28e2b2..6b714a8 100644 --- a/tests/serialization/test_dag_serialization.py +++ b/tests/serialization/test_dag_serialization.py @@ -358,6 +358,9 @@ class TestStringifiedDAGs(unittest.TestCase): assert serialized_task.task_type == task.task_type assert set(serialized_task.template_fields) == set(task.template_fields) + assert serialized_task.upstream_task_ids == task.upstream_task_ids + assert serialized_task.downstream_task_ids == task.downstream_task_ids + for field in fields_to_check: assert getattr(serialized_task, field) == getattr(task, field), \ '{}.{}.{} does not 
match'.format(task.dag.dag_id, task.task_id, field)
