This is an automated email from the ASF dual-hosted git repository.

mobuchowski pushed a commit to branch openlineage_dont_run_tis_executor
in repository https://gitbox.apache.org/repos/asf/airflow.git

commit d55bc6d3de7773e3e5613cbca5ee0850792bad76
Author: Maciej Obuchowski <[email protected]>
AuthorDate: Sun Aug 13 18:10:09 2023 +0200

    openlineage: don't run task instance listener in executor
    
    Signed-off-by: Maciej Obuchowski <[email protected]>
---
 airflow/providers/openlineage/plugins/listener.py | 20 ++++---
 tests/dags/test_dag_xcom_openlineage.py           | 41 ++++++++++++++
 tests/listeners/test_listeners.py                 | 67 ++++++++++++++++++++++-
 tests/listeners/xcom_listener.py                  | 46 ++++++++++++++++
 4 files changed, 164 insertions(+), 10 deletions(-)

diff --git a/airflow/providers/openlineage/plugins/listener.py b/airflow/providers/openlineage/plugins/listener.py
index d85a559f56..4a6b75f677 100644
--- a/airflow/providers/openlineage/plugins/listener.py
+++ b/airflow/providers/openlineage/plugins/listener.py
@@ -17,7 +17,7 @@
 from __future__ import annotations
 
 import logging
-from concurrent.futures import Executor, ThreadPoolExecutor
+from concurrent.futures import ThreadPoolExecutor
 from datetime import datetime
 from typing import TYPE_CHECKING
 
@@ -42,8 +42,8 @@ class OpenLineageListener:
     """OpenLineage listener sends events on task instance and dag run starts, 
completes and failures."""
 
     def __init__(self):
+        self._executor = None
         self.log = logging.getLogger(__name__)
-        self.executor: Executor = None  # type: ignore
         self.extractor_manager = ExtractorManager()
         self.adapter = OpenLineageAdapter()
 
@@ -102,7 +102,7 @@ class OpenLineageListener:
                 },
             )
 
-        self.executor.submit(on_running)
+        on_running()
 
     @hookimpl
     def on_task_instance_success(self, previous_state, task_instance: TaskInstance, session):
@@ -130,7 +130,7 @@ class OpenLineageListener:
                 task=task_metadata,
             )
 
-        self.executor.submit(on_success)
+        on_success()
 
     @hookimpl
     def on_task_instance_failed(self, previous_state, task_instance: TaskInstance, session):
@@ -158,12 +158,17 @@ class OpenLineageListener:
                 task=task_metadata,
             )
 
-        self.executor.submit(on_failure)
+        on_failure()
+
+    @property
+    def executor(self):
+        if not self._executor:
+            self._executor = ThreadPoolExecutor(max_workers=8, thread_name_prefix="openlineage_")
+        return self._executor
 
     @hookimpl
     def on_starting(self, component):
         self.log.debug("on_starting: %s", component.__class__.__name__)
-        self.executor = ThreadPoolExecutor(max_workers=8, thread_name_prefix="openlineage_")
 
     @hookimpl
     def before_stopping(self, component):
@@ -174,9 +179,6 @@ class OpenLineageListener:
 
     @hookimpl
     def on_dag_run_running(self, dag_run: DagRun, msg: str):
-        if not self.executor:
-            self.log.error("Executor have not started before `on_dag_run_running`")
-            return
         data_interval_start = dag_run.data_interval_start.isoformat() if dag_run.data_interval_start else None
         data_interval_end = dag_run.data_interval_end.isoformat() if dag_run.data_interval_end else None
         self.executor.submit(
diff --git a/tests/dags/test_dag_xcom_openlineage.py b/tests/dags/test_dag_xcom_openlineage.py
new file mode 100644
index 0000000000..6236c8b4ec
--- /dev/null
+++ b/tests/dags/test_dag_xcom_openlineage.py
@@ -0,0 +1,41 @@
+##
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import datetime
+
+from airflow.models import DAG
+from airflow.operators.python import PythonOperator
+
+dag = DAG(
+    dag_id="test_dag_xcom_openlineage",
+    default_args={"owner": "airflow", "retries": 3, "start_date": datetime.datetime(2022, 1, 1)},
+    schedule="0 0 * * *",
+    dagrun_timeout=datetime.timedelta(minutes=60),
+)
+
+
+def push_and_pull(ti, **kwargs):
+    ti.xcom_push(key="pushed_key", value="asdf")
+    ti.xcom_pull(key="pushed_key")
+
+
+task = PythonOperator(task_id="push_and_pull", python_callable=push_and_pull, dag=dag)
+
+if __name__ == "__main__":
+    dag.cli()
diff --git a/tests/listeners/test_listeners.py b/tests/listeners/test_listeners.py
index 6369bd60da..d4dba22a78 100644
--- a/tests/listeners/test_listeners.py
+++ b/tests/listeners/test_listeners.py
@@ -16,22 +16,32 @@
 # under the License.
 from __future__ import annotations
 
+import logging
+import os
+import time
+
 import pytest as pytest
 
 from airflow import AirflowException
 from airflow.jobs.job import Job, run_job
+from airflow.jobs.local_task_job_runner import LocalTaskJobRunner
 from airflow.listeners.listener import get_listener_manager
+from airflow.models import DagBag, TaskInstance
 from airflow.operators.bash import BashOperator
+from airflow.task.task_runner.standard_task_runner import StandardTaskRunner
 from airflow.utils import timezone
 from airflow.utils.session import provide_session
-from airflow.utils.state import DagRunState, TaskInstanceState
+from airflow.utils.state import DagRunState, State, TaskInstanceState
+from airflow.utils.timeout import timeout
 from tests.listeners import (
     class_listener,
     full_listener,
     lifecycle_listener,
     partial_listener,
     throwing_listener,
+    xcom_listener,
 )
+from tests.models import DEFAULT_DATE
 from tests.utils.test_helpers import MockJobRunner
 
 LISTENERS = [
@@ -46,6 +56,8 @@ DAG_ID = "test_listener_dag"
 TASK_ID = "test_listener_task"
 EXECUTION_DATE = timezone.utcnow()
 
+TEST_DAG_FOLDER = os.environ["AIRFLOW__CORE__DAGS_FOLDER"]
+
 
 @pytest.fixture(autouse=True)
 def clean_listener_manager():
@@ -163,3 +175,56 @@ def test_class_based_listener(create_task_instance, session=None):
 
     assert len(listener.state) == 2
     assert listener.state == [TaskInstanceState.RUNNING, TaskInstanceState.SUCCESS]
+
+
+def test_ol_does_not_block_xcoms():
+    """
+    Test that task instance listener hooks can push and pull XComs when the
+    task is run by the task runner, and that both hooks run to completion
+    """
+
+    path_listener_writer = "/tmp/test_ol_does_not_block_xcoms"
+    try:
+        os.unlink(path_listener_writer)
+    except OSError:
+        pass
+
+    listener = xcom_listener.XComListener(path_listener_writer, "push_and_pull")
+    get_listener_manager().add_listener(listener)
+    log = logging.getLogger("airflow")
+
+    dagbag = DagBag(
+        dag_folder=TEST_DAG_FOLDER,
+        include_examples=False,
+    )
+    dag = dagbag.dags.get("test_dag_xcom_openlineage")
+    task = dag.get_task("push_and_pull")
+    dag.create_dagrun(
+        run_id="test",
+        data_interval=(DEFAULT_DATE, DEFAULT_DATE),
+        state=State.RUNNING,
+        start_date=DEFAULT_DATE,
+    )
+
+    ti = TaskInstance(task=task, run_id="test")
+    job = Job(dag_id=ti.dag_id)
+    job_runner = LocalTaskJobRunner(job=job, task_instance=ti, ignore_ti_state=True)
+    task_runner = StandardTaskRunner(job_runner)
+    task_runner.start()
+
+    # Wait until process makes itself the leader of its own process group
+    with timeout(seconds=1):
+        while True:
+            runner_pgid = os.getpgid(task_runner.process.pid)
+            if runner_pgid == task_runner.process.pid:
+                break
+            time.sleep(0.01)
+
+    # Wait till process finishes
+    assert task_runner.return_code(timeout=10) is not None
+    log.error(task_runner.return_code())
+
+    with open(path_listener_writer) as f:
+        assert f.readline() == "on_task_instance_running\n"
+        assert f.readline() == "on_task_instance_success\n"
+        assert f.readline() == "listener\n"
diff --git a/tests/listeners/xcom_listener.py b/tests/listeners/xcom_listener.py
new file mode 100644
index 0000000000..a7ffc19178
--- /dev/null
+++ b/tests/listeners/xcom_listener.py
@@ -0,0 +1,46 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from airflow.listeners import hookimpl
+
+
+class XComListener:
+    def __init__(self, path: str, task_id: str):
+        self.path = path
+        self.task_id = task_id
+
+    def write(self, line: str):
+        with open(self.path, "a") as f:
+            f.write(line + "\n")
+
+    @hookimpl
+    def on_task_instance_running(self, previous_state, task_instance, session):
+        task_instance.xcom_push(key="listener", value="listener")
+        task_instance.xcom_pull(task_ids=task_instance.task_id, key="listener")
+        self.write("on_task_instance_running")
+
+    @hookimpl
+    def on_task_instance_success(self, previous_state, task_instance, session):
+        read = task_instance.xcom_pull(task_ids=self.task_id, key="listener")
+        self.write("on_task_instance_success")
+        self.write(read)
+
+
+def clear():
+    pass

Reply via email to