This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 1dc582dba3 fix openlineage parsing dag tree with MappedOperator (#40621)
1dc582dba3 is described below

commit 1dc582dba32156bd48da41c0cc5d1b2ab699993b
Author: Kacper Muda <[email protected]>
AuthorDate: Fri Jul 5 15:49:00 2024 +0200

    fix openlineage parsing dag tree with MappedOperator (#40621)
    
    Signed-off-by: Kacper Muda <[email protected]>
---
 airflow/providers/openlineage/utils/utils.py    |  9 +++---
 tests/providers/openlineage/utils/test_utils.py | 37 +++++++++++++++++++++----
 2 files changed, 36 insertions(+), 10 deletions(-)

diff --git a/airflow/providers/openlineage/utils/utils.py b/airflow/providers/openlineage/utils/utils.py
index dd9be69d50..b75e5f101b 100644
--- a/airflow/providers/openlineage/utils/utils.py
+++ b/airflow/providers/openlineage/utils/utils.py
@@ -371,12 +371,13 @@ def _get_parsed_dag_tree(dag: DAG) -> dict:
         # Determine the level by counting the leading spaces, assuming 4 spaces per level
         # as defined in airflow.models.dag.DAG._generate_tree_view()
         level = (len(line) - len(stripped_line)) // 4
-        # airflow.models.baseoperator.BaseOperator.__repr__ is used in DAG tree
-        # <Task({op_class}): {task_id}>
-        match = re.match(r"^<Task\((.+)\): (.*?)>$", stripped_line)
+        # airflow.models.baseoperator.BaseOperator.__repr__ or
+        # airflow.models.mappedoperator.MappedOperator.__repr__ is used in DAG tree
+        # <Task({op_class}): {task_id}> or <Mapped({op_class}): {task_id}>
+        match = re.match(r"^<(?:Task|Mapped)\(.+\): (.+)>$", stripped_line)
         if not match:
             return {}
-        current_task_id = match[2]
+        current_task_id = match[1]
 
         if level == 0:  # It's a root task
             task_dict[current_task_id] = {}
diff --git a/tests/providers/openlineage/utils/test_utils.py b/tests/providers/openlineage/utils/test_utils.py
index d58be508d4..381743141a 100644
--- a/tests/providers/openlineage/utils/test_utils.py
+++ b/tests/providers/openlineage/utils/test_utils.py
@@ -21,6 +21,8 @@ import datetime
 from unittest.mock import MagicMock
 
 from airflow import DAG
+from airflow.decorators import task
+from airflow.models.baseoperator import BaseOperator
 from airflow.models.mappedoperator import MappedOperator
 from airflow.operators.bash import BashOperator
 from airflow.operators.empty import EmptyOperator
@@ -311,20 +313,42 @@ def test_dag_tree_level_indent():
 
 
 def test_get_dag_tree():
+    class TestMappedOperator(BaseOperator):
+        def __init__(self, value, **kwargs):
+            super().__init__(**kwargs)
+            self.value = value
+
+        def execute(self, context):
+            return self.value + 1
+
+    @task
+    def generate_list() -> list:
+        return [1, 2, 3]
+
+    @task
+    def process_item(item: int) -> int:
+        return item * 2
+
+    @task
+    def sum_values(values: list[int]) -> int:
+        return sum(values)
+
     with DAG(dag_id="dag", start_date=datetime.datetime(2024, 6, 1)) as dag:
-        task = CustomOperatorForTest(task_id="task", bash_command="exit 0;")
+        task_ = BashOperator(task_id="task", bash_command="exit 0;")
         task_0 = BashOperator(task_id="task_0", bash_command="exit 0;")
         task_1 = BashOperator(task_id="task_1", bash_command="exit 1;")
         task_2 = PythonOperator(task_id="task_2", python_callable=lambda: 1)
         task_3 = BashOperator(task_id="task_3", bash_command="exit 0;")
-        task_4 = EmptyOperator(
-            task_id="task_4",
-        )
+        task_4 = EmptyOperator(task_id="task_4")
         task_5 = BashOperator(task_id="task_5", bash_command="exit 0;")
         task_6 = EmptyOperator(task_id="task_6.test5")
         task_7 = BashOperator(task_id="task_7", bash_command="exit 0;")
         task_8 = PythonOperator(task_id="task_8", python_callable=lambda: 1)  # noqa: F841
-        task_9 = BashOperator(task_id="task_9", bash_command="exit 0;")
+        task_9 = TestMappedOperator.partial(task_id="task_9").expand(value=[1, 2])
+
+        list_result = generate_list()
+        processed_results = process_item.expand(item=list_result)
+        result_sum = sum_values(processed_results)  # noqa: F841
 
         with TaskGroup("section_1", prefix_group_id=True) as tg:
             task_10 = PythonOperator(task_id="task_3", python_callable=lambda: 1)
@@ -333,12 +357,13 @@ def test_get_dag_tree():
                 with TaskGroup("section_3", parent_group=tg2):
                     task_12 = PythonOperator(task_id="task_12", python_callable=lambda: 1)
 
-        task >> [task_2, task_7]
+        task_ >> [task_2, task_7]
         task_0 >> [task_2, task_1] >> task_3 >> [task_4, task_5] >> task_6
         task_1 >> task_9 >> task_3 >> task_4 >> task_5 >> task_6
         task_3 >> task_10 >> task_12
 
         expected = {
+            "generate_list": {"process_item": {"sum_values": {}}},
             "section_1.section_2.task_11": {},
             "task": {
                 "task_2": {

Reply via email to