This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 0c02ead4d8 treeview - deterministic and new getter (#37162)
0c02ead4d8 is described below
commit 0c02ead4d8a527cbf0a916b6344f255c520e637f
Author: raphaelauv <[email protected]>
AuthorDate: Sun Feb 11 15:42:33 2024 +0100
treeview - deterministic and new getter (#37162)
* treeview - deterministic and new getter
* review 1
---------
Co-authored-by: raphaelauv <[email protected]>
---
airflow/models/dag.py | 25 ++++++++++++++++++-------
tests/models/test_dag.py | 17 ++++++++++++++---
2 files changed, 32 insertions(+), 10 deletions(-)
diff --git a/airflow/models/dag.py b/airflow/models/dag.py
index 988dfb25e4..d2366c0e9e 100644
--- a/airflow/models/dag.py
+++ b/airflow/models/dag.py
@@ -41,6 +41,7 @@ from typing import (
     Callable,
     Collection,
     Container,
+    Generator,
     Iterable,
     Iterator,
     List,
@@ -2627,15 +2628,25 @@ class DAG(LoggingMixin):
 
     def tree_view(self) -> None:
         """Print an ASCII tree representation of the DAG."""
+        for tmp in self._generate_tree_view():
+            print(tmp)
 
-        def get_downstream(task, level=0):
-            print((" " * level * 4) + str(task))
+    def _generate_tree_view(self) -> Generator[str, None, None]:
+        def get_downstream(task, level=0) -> Generator[str, None, None]:
+            yield (" " * level * 4) + str(task)
             level += 1
-            for t in task.downstream_list:
-                get_downstream(t, level)
-
-        for t in self.roots:
-            get_downstream(t)
+            for tmp_task in sorted(task.downstream_list, key=lambda x: x.task_id):
+                yield from get_downstream(tmp_task, level)
+
+        for t in sorted(self.roots, key=lambda x: x.task_id):
+            yield from get_downstream(t)
+
+    def get_tree_view(self) -> str:
+        """Return an ASCII tree representation of the DAG."""
+        rst = ""
+        for tmp in self._generate_tree_view():
+            rst += tmp + "\n"
+        return rst
 
     @property
     def task(self) -> TaskDecoratorCollection:
diff --git a/tests/models/test_dag.py b/tests/models/test_dag.py
index 1f70ba051a..05681cfe88 100644
--- a/tests/models/test_dag.py
+++ b/tests/models/test_dag.py
@@ -1422,19 +1422,30 @@ class TestDag:
     def test_tree_view(self):
         """Verify correctness of dag.tree_view()."""
         with DAG("test_dag", start_date=DEFAULT_DATE) as dag:
-            op1 = EmptyOperator(task_id="t1")
+            op1_a = EmptyOperator(task_id="t1_a")
+            op1_b = EmptyOperator(task_id="t1_b")
             op2 = EmptyOperator(task_id="t2")
             op3 = EmptyOperator(task_id="t3")
-            op1 >> op2 >> op3
+            op1_b >> op2
+            op1_a >> op2 >> op3
 
             with redirect_stdout(StringIO()) as stdout:
                 dag.tree_view()
                 stdout = stdout.getvalue()
 
             stdout_lines = stdout.splitlines()
-            assert "t1" in stdout_lines[0]
+            assert "t1_a" in stdout_lines[0]
             assert "t2" in stdout_lines[1]
             assert "t3" in stdout_lines[2]
+            assert "t1_b" in stdout_lines[3]
+            assert dag.get_tree_view() == (
+                "<Task(EmptyOperator): t1_a>\n"
+                "    <Task(EmptyOperator): t2>\n"
+                "        <Task(EmptyOperator): t3>\n"
+                "<Task(EmptyOperator): t1_b>\n"
+                "    <Task(EmptyOperator): t2>\n"
+                "        <Task(EmptyOperator): t3>\n"
+            )
 
     def test_duplicate_task_ids_not_allowed_with_dag_context_manager(self):
         """Verify tasks with Duplicate task_id raises error"""