This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 1dc582dba3 fix openlineage parsing dag tree with MappedOperator (#40621)
1dc582dba3 is described below
commit 1dc582dba32156bd48da41c0cc5d1b2ab699993b
Author: Kacper Muda <[email protected]>
AuthorDate: Fri Jul 5 15:49:00 2024 +0200
fix openlineage parsing dag tree with MappedOperator (#40621)
Signed-off-by: Kacper Muda <[email protected]>
---
airflow/providers/openlineage/utils/utils.py | 9 +++---
tests/providers/openlineage/utils/test_utils.py | 37 +++++++++++++++++++++----
2 files changed, 36 insertions(+), 10 deletions(-)
diff --git a/airflow/providers/openlineage/utils/utils.py b/airflow/providers/openlineage/utils/utils.py
index dd9be69d50..b75e5f101b 100644
--- a/airflow/providers/openlineage/utils/utils.py
+++ b/airflow/providers/openlineage/utils/utils.py
@@ -371,12 +371,13 @@ def _get_parsed_dag_tree(dag: DAG) -> dict:
# Determine the level by counting the leading spaces, assuming 4 spaces per level
# as defined in airflow.models.dag.DAG._generate_tree_view()
level = (len(line) - len(stripped_line)) // 4
- # airflow.models.baseoperator.BaseOperator.__repr__ is used in DAG tree
- # <Task({op_class}): {task_id}>
- match = re.match(r"^<Task\((.+)\): (.*?)>$", stripped_line)
+ # airflow.models.baseoperator.BaseOperator.__repr__ or
+ # airflow.models.mappedoperator.MappedOperator.__repr__ is used in DAG tree
+ # <Task({op_class}): {task_id}> or <Mapped({op_class}): {task_id}>
+ match = re.match(r"^<(?:Task|Mapped)\(.+\): (.+)>$", stripped_line)
if not match:
return {}
- current_task_id = match[2]
+ current_task_id = match[1]
if level == 0: # It's a root task
task_dict[current_task_id] = {}
diff --git a/tests/providers/openlineage/utils/test_utils.py b/tests/providers/openlineage/utils/test_utils.py
index d58be508d4..381743141a 100644
--- a/tests/providers/openlineage/utils/test_utils.py
+++ b/tests/providers/openlineage/utils/test_utils.py
@@ -21,6 +21,8 @@ import datetime
from unittest.mock import MagicMock
from airflow import DAG
+from airflow.decorators import task
+from airflow.models.baseoperator import BaseOperator
from airflow.models.mappedoperator import MappedOperator
from airflow.operators.bash import BashOperator
from airflow.operators.empty import EmptyOperator
@@ -311,20 +313,42 @@ def test_dag_tree_level_indent():
def test_get_dag_tree():
+ class TestMappedOperator(BaseOperator):
+ def __init__(self, value, **kwargs):
+ super().__init__(**kwargs)
+ self.value = value
+
+ def execute(self, context):
+ return self.value + 1
+
+ @task
+ def generate_list() -> list:
+ return [1, 2, 3]
+
+ @task
+ def process_item(item: int) -> int:
+ return item * 2
+
+ @task
+ def sum_values(values: list[int]) -> int:
+ return sum(values)
+
with DAG(dag_id="dag", start_date=datetime.datetime(2024, 6, 1)) as dag:
- task = CustomOperatorForTest(task_id="task", bash_command="exit 0;")
+ task_ = BashOperator(task_id="task", bash_command="exit 0;")
task_0 = BashOperator(task_id="task_0", bash_command="exit 0;")
task_1 = BashOperator(task_id="task_1", bash_command="exit 1;")
task_2 = PythonOperator(task_id="task_2", python_callable=lambda: 1)
task_3 = BashOperator(task_id="task_3", bash_command="exit 0;")
- task_4 = EmptyOperator(
- task_id="task_4",
- )
+ task_4 = EmptyOperator(task_id="task_4")
task_5 = BashOperator(task_id="task_5", bash_command="exit 0;")
task_6 = EmptyOperator(task_id="task_6.test5")
task_7 = BashOperator(task_id="task_7", bash_command="exit 0;")
task_8 = PythonOperator(task_id="task_8", python_callable=lambda: 1)  # noqa: F841
- task_9 = BashOperator(task_id="task_9", bash_command="exit 0;")
+ task_9 = TestMappedOperator.partial(task_id="task_9").expand(value=[1, 2])
+
+ list_result = generate_list()
+ processed_results = process_item.expand(item=list_result)
+ result_sum = sum_values(processed_results) # noqa: F841
with TaskGroup("section_1", prefix_group_id=True) as tg:
task_10 = PythonOperator(task_id="task_3", python_callable=lambda: 1)
@@ -333,12 +357,13 @@ def test_get_dag_tree():
with TaskGroup("section_3", parent_group=tg2):
task_12 = PythonOperator(task_id="task_12", python_callable=lambda: 1)
- task >> [task_2, task_7]
+ task_ >> [task_2, task_7]
task_0 >> [task_2, task_1] >> task_3 >> [task_4, task_5] >> task_6
task_1 >> task_9 >> task_3 >> task_4 >> task_5 >> task_6
task_3 >> task_10 >> task_12
expected = {
+ "generate_list": {"process_item": {"sum_values": {}}},
"section_1.section_2.task_11": {},
"task": {
"task_2": {