This is an automated email from the ASF dual-hosted git repository.

jhtimmins pushed a commit to branch v2-1-test
in repository https://gitbox.apache.org/repos/asf/airflow.git
commit 943d3c514ead4f983fdbfef56120e16c5ce42e03
Author: luoyuliuyin <[email protected]>
AuthorDate: Tue Jul 6 20:44:32 2021 +0800

    BugFix: Correctly handle custom `deps` and `task_group` during DAG Serialization (#16734)

    We check whether a DAG has changed via its dag_hash, so `deps` and
    `task_group` must be serialized deterministically to keep the generated
    dag_hash stable.

    closes https://github.com/apache/airflow/issues/16690

    (cherry picked from commit 0632ecf6f56214c78deea2a4b54ea0daebb4e95d)
---
 airflow/serialization/serialized_objects.py   | 16 +++--
 tests/serialization/test_dag_serialization.py | 99 +++++++++++++++++++++++++++
 2 files changed, 110 insertions(+), 5 deletions(-)

diff --git a/airflow/serialization/serialized_objects.py b/airflow/serialization/serialized_objects.py
index d2f456d..bdbaea8 100644
--- a/airflow/serialization/serialized_objects.py
+++ b/airflow/serialization/serialized_objects.py
@@ -424,7 +424,10 @@ class SerializedBaseOperator(BaseOperator, BaseSerialization):
             )
             deps.append(f'{module_name}.{klass.__name__}')

-        serialize_op['deps'] = deps
+        # op.deps is a set, whose iteration order is unstable, so sort it here;
+        # otherwise json.dumps(self.data, sort_keys=True) could generate a
+        # different dag_hash for an unchanged DAG.
+        serialize_op['deps'] = sorted(deps)

         # Store all template_fields as they are if there are JSON Serializable
         # If not, store them as strings
@@ -796,6 +799,9 @@ class SerializedTaskGroup(TaskGroup, BaseSerialization):
         if not task_group:
             return None

+        # The task_group *_ids attributes are sets, so converting them to lists
+        # yields an unpredictable order; sort them so that
+        # json.dumps(self.data, sort_keys=True) generates a stable dag_hash.
         serialize_group = {
             "_group_id": task_group._group_id,
             "prefix_group_id": task_group.prefix_group_id,
@@ -808,10 +814,10 @@ class SerializedTaskGroup(TaskGroup, BaseSerialization):
                 else (DAT.TASK_GROUP, SerializedTaskGroup.serialize_task_group(child))
                 for label, child in task_group.children.items()
             },
-            "upstream_group_ids": cls._serialize(list(task_group.upstream_group_ids)),
-            "downstream_group_ids": cls._serialize(list(task_group.downstream_group_ids)),
-            "upstream_task_ids": cls._serialize(list(task_group.upstream_task_ids)),
-            "downstream_task_ids": cls._serialize(list(task_group.downstream_task_ids)),
+            "upstream_group_ids": cls._serialize(sorted(task_group.upstream_group_ids)),
+            "downstream_group_ids": cls._serialize(sorted(task_group.downstream_group_ids)),
+            "upstream_task_ids": cls._serialize(sorted(task_group.upstream_task_ids)),
+            "downstream_task_ids": cls._serialize(sorted(task_group.downstream_task_ids)),
         }

         return serialize_group
diff --git a/tests/serialization/test_dag_serialization.py b/tests/serialization/test_dag_serialization.py
index 29e6ac0..7b0b476 100644
--- a/tests/serialization/test_dag_serialization.py
+++ b/tests/serialization/test_dag_serialization.py
@@ -965,6 +965,105 @@ class TestStringifiedDAGs(unittest.TestCase):

         check_task_group(serialized_dag.task_group)

+    def test_deps_sorted(self):
+        """
+        Tests serialize_operator, making sure the deps are sorted.
+        """
+        from airflow.operators.dummy import DummyOperator
+        from airflow.sensors.external_task import ExternalTaskSensor
+
+        execution_date = datetime(2020, 1, 1)
+        with DAG(dag_id="test_deps_sorted", start_date=execution_date) as dag:
+            task1 = ExternalTaskSensor(
+                task_id="task1",
+                external_dag_id="external_dag_id",
+                mode="reschedule",
+            )
+            task2 = DummyOperator(task_id="task2")
+            task1 >> task2
+
+        serialize_op = SerializedBaseOperator.serialize_operator(dag.task_dict["task1"])
+
+        deps = serialize_op["deps"]
+        assert deps == [
+            'airflow.ti_deps.deps.not_in_retry_period_dep.NotInRetryPeriodDep',
+            'airflow.ti_deps.deps.not_previously_skipped_dep.NotPreviouslySkippedDep',
+            'airflow.ti_deps.deps.prev_dagrun_dep.PrevDagrunDep',
+            'airflow.ti_deps.deps.ready_to_reschedule.ReadyToRescheduleDep',
+            'airflow.ti_deps.deps.trigger_rule_dep.TriggerRuleDep',
+        ]
+
+    def test_task_group_sorted(self):
+        """
+        Tests serialize_task_group, making sure the lists are sorted.
+        """
+        from airflow.operators.dummy import DummyOperator
+        from airflow.serialization.serialized_objects import SerializedTaskGroup
+        from airflow.utils.task_group import TaskGroup
+
+        """
+                              start
+                             ╱     ╲
+                            ╱       ╲
+             task_group_up1          task_group_up2
+               (task_up1)              (task_up2)
+                            ╲       ╱
+                        task_group_middle
+                          (task_middle)
+                            ╱       ╲
+           task_group_down1          task_group_down2
+             (task_down1)              (task_down2)
+                            ╲       ╱
+                             ╲     ╱
+                               end
+        """
+        execution_date = datetime(2020, 1, 1)
+        with DAG(dag_id="test_task_group_sorted", start_date=execution_date) as dag:
+            start = DummyOperator(task_id="start")
+
+            with TaskGroup("task_group_up1") as task_group_up1:
+                _ = DummyOperator(task_id="task_up1")
+
+            with TaskGroup("task_group_up2") as task_group_up2:
+                _ = DummyOperator(task_id="task_up2")
+
+            with TaskGroup("task_group_middle") as task_group_middle:
+                _ = DummyOperator(task_id="task_middle")
+
+            with TaskGroup("task_group_down1") as task_group_down1:
+                _ = DummyOperator(task_id="task_down1")
+
+            with TaskGroup("task_group_down2") as task_group_down2:
+                _ = DummyOperator(task_id="task_down2")
+
+            end = DummyOperator(task_id='end')
+
+            start >> task_group_up1
+            start >> task_group_up2
+            task_group_up1 >> task_group_middle
+            task_group_up2 >> task_group_middle
+            task_group_middle >> task_group_down1
+            task_group_middle >> task_group_down2
+            task_group_down1 >> end
+            task_group_down2 >> end
+
+        task_group_middle_dict = SerializedTaskGroup.serialize_task_group(
+            dag.task_group.children["task_group_middle"]
+        )
+        upstream_group_ids = task_group_middle_dict["upstream_group_ids"]
+        assert upstream_group_ids == ['task_group_up1', 'task_group_up2']
+
+        upstream_task_ids = task_group_middle_dict["upstream_task_ids"]
+        assert upstream_task_ids == ['task_group_up1.task_up1', 'task_group_up2.task_up2']
+
+        downstream_group_ids = task_group_middle_dict["downstream_group_ids"]
+        assert downstream_group_ids == ['task_group_down1', 'task_group_down2']
+
+        task_group_down1_dict = SerializedTaskGroup.serialize_task_group(
+            dag.task_group.children["task_group_down1"]
+        )
+        downstream_task_ids = task_group_down1_dict["downstream_task_ids"]
+        assert downstream_task_ids == ['end']
+
     def test_edge_info_serialization(self):
         """
         Tests edge_info serialization/deserialization.
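For readers skimming the patch, the failure mode reduces to a few lines of plain
Python. The sketch below is illustrative and not part of the commit: the
md5-of-JSON helper mirrors how dag_hash is derived from
json.dumps(self.data, sort_keys=True) (the exact digest function is an
assumption here), and the deps values are copied from the test above.

    # Illustrative sketch (not part of the commit) of the instability fixed above.
    import hashlib
    import json

    deps = {
        'airflow.ti_deps.deps.not_in_retry_period_dep.NotInRetryPeriodDep',
        'airflow.ti_deps.deps.prev_dagrun_dep.PrevDagrunDep',
        'airflow.ti_deps.deps.trigger_rule_dep.TriggerRuleDep',
    }

    def dag_hash(data):
        # Digest of the sort_keys JSON dump, mirroring how serialized DAGs are
        # compared; the choice of md5 is an assumption for this sketch.
        return hashlib.md5(json.dumps(data, sort_keys=True).encode('utf-8')).hexdigest()

    # Before the fix: list(deps) order depends on the per-process string hash
    # seed (PYTHONHASHSEED), so two runs over an identical DAG can disagree.
    before = dag_hash({'deps': list(deps)})

    # After the fix: sorted() yields a canonical order, so the digest is stable.
    after = dag_hash({'deps': sorted(deps)})

    print(before, after)

Because sort_keys=True only orders dictionary keys, not list elements, the
serializer itself must make list order canonical, which is exactly what
sorted() does in both hunks above.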
