This is an automated email from the ASF dual-hosted git repository.

ephraimanierobi pushed a commit to branch v2-8-test
in repository https://gitbox.apache.org/repos/asf/airflow.git
commit 1995d45f0f9557a5962c3d86e793241af4b79c2c Author: raphaelauv <[email protected]> AuthorDate: Sun Feb 11 15:42:33 2024 +0100 treeview - deterministic and new getter (#37162) * treeview - determinist and new getter * review 1 --------- Co-authored-by: raphaelauv <[email protected]> (cherry picked from commit 0c02ead4d8a527cbf0a916b6344f255c520e637f) --- airflow/models/dag.py | 25 ++++++++++++++++++------- tests/models/test_dag.py | 17 ++++++++++++++--- 2 files changed, 32 insertions(+), 10 deletions(-) diff --git a/airflow/models/dag.py b/airflow/models/dag.py index 8d0aef0d5d..f90980ecbe 100644 --- a/airflow/models/dag.py +++ b/airflow/models/dag.py @@ -41,6 +41,7 @@ from typing import ( Callable, Collection, Container, + Generator, Iterable, Iterator, List, @@ -2621,15 +2622,25 @@ class DAG(LoggingMixin): def tree_view(self) -> None: """Print an ASCII tree representation of the DAG.""" + for tmp in self._generate_tree_view(): + print(tmp) - def get_downstream(task, level=0): - print((" " * level * 4) + str(task)) + def _generate_tree_view(self) -> Generator[str, None, None]: + def get_downstream(task, level=0) -> Generator[str, None, None]: + yield (" " * level * 4) + str(task) level += 1 - for t in task.downstream_list: - get_downstream(t, level) - - for t in self.roots: - get_downstream(t) + for tmp_task in sorted(task.downstream_list, key=lambda x: x.task_id): + yield from get_downstream(tmp_task, level) + + for t in sorted(self.roots, key=lambda x: x.task_id): + yield from get_downstream(t) + + def get_tree_view(self) -> str: + """Return an ASCII tree representation of the DAG.""" + rst = "" + for tmp in self._generate_tree_view(): + rst += tmp + "\n" + return rst @property def task(self) -> TaskDecoratorCollection: diff --git a/tests/models/test_dag.py b/tests/models/test_dag.py index f367b00abe..152bed8e94 100644 --- a/tests/models/test_dag.py +++ b/tests/models/test_dag.py @@ -1369,19 +1369,30 @@ class TestDag: def test_tree_view(self): """Verify 
correctness of dag.tree_view().""" with DAG("test_dag", start_date=DEFAULT_DATE) as dag: - op1 = EmptyOperator(task_id="t1") + op1_a = EmptyOperator(task_id="t1_a") + op1_b = EmptyOperator(task_id="t1_b") op2 = EmptyOperator(task_id="t2") op3 = EmptyOperator(task_id="t3") - op1 >> op2 >> op3 + op1_b >> op2 + op1_a >> op2 >> op3 with redirect_stdout(StringIO()) as stdout: dag.tree_view() stdout = stdout.getvalue() stdout_lines = stdout.splitlines() - assert "t1" in stdout_lines[0] + assert "t1_a" in stdout_lines[0] assert "t2" in stdout_lines[1] assert "t3" in stdout_lines[2] + assert "t1_b" in stdout_lines[3] + assert dag.get_tree_view() == ( + "<Task(EmptyOperator): t1_a>\n" + " <Task(EmptyOperator): t2>\n" + " <Task(EmptyOperator): t3>\n" + "<Task(EmptyOperator): t1_b>\n" + " <Task(EmptyOperator): t2>\n" + " <Task(EmptyOperator): t3>\n" + ) def test_duplicate_task_ids_not_allowed_with_dag_context_manager(self): """Verify tasks with Duplicate task_id raises error"""
