This is an automated email from the ASF dual-hosted git repository.

mobuchowski pushed a commit to branch tasksdk-call-listeners
in repository https://gitbox.apache.org/repos/asf/airflow.git

commit 59f1deba9179ad1343be2eff56596e2e025ed82b
Author: Maciej Obuchowski <[email protected]>
AuthorDate: Wed Jan 8 13:31:46 2025 +0100

    make OpenLineage provider support Airflow 3's listener interface
    
    Signed-off-by: Maciej Obuchowski <[email protected]>
---
 .../administration-and-deployment/listeners.rst    |   20 +-
 .../providers/openlineage/plugins/listener.py      |  219 ++-
 .../openlineage/utils/selective_enable.py          |    8 +-
 .../airflow/providers/openlineage/utils/utils.py   |   42 +-
 .../tests/openlineage/extractors/test_manager.py   |  144 +-
 .../tests/openlineage/plugins/test_listener.py     | 1691 +++++++++++++-------
 6 files changed, 1460 insertions(+), 664 deletions(-)

diff --git a/docs/apache-airflow/administration-and-deployment/listeners.rst 
b/docs/apache-airflow/administration-and-deployment/listeners.rst
index 8ca3ed93fc0..97467361ac3 100644
--- a/docs/apache-airflow/administration-and-deployment/listeners.rst
+++ b/docs/apache-airflow/administration-and-deployment/listeners.rst
@@ -165,9 +165,7 @@ For example if you want to implement a listener that uses 
the ``error`` field in
             ...
 
             @hookimpl
-            def on_task_instance_failed(
-                self, previous_state, task_instance, error: None | str | 
BaseException, session
-            ):
+            def on_task_instance_failed(self, previous_state, task_instance, 
error: None | str | BaseException):
                 # Handle error case here
                 pass
 
@@ -177,15 +175,19 @@ For example if you want to implement a listener that uses 
the ``error`` field in
             ...
 
             @hookimpl
-            def on_task_instance_failed(self, previous_state, task_instance, 
session):
+            def on_task_instance_failed(self, previous_state, task_instance):
                 # Handle no error case here
                 pass
 
 List of changes in the listener interfaces since 2.8.0 when they were 
introduced:
 
 
-+-----------------+-----------------------------+---------------------------------------+
-| Airflow Version | Affected method             | Change                       
         |
-+=================+=============================+=======================================+
-| 2.10.0          | ``on_task_instance_failed`` | An error field added to the 
interface |
-+-----------------+-----------------------------+---------------------------------------+
++-----------------+--------------------------------------------+-------------------------------------------------------------------------+
+| Airflow Version | Affected method                            | Change        
                                                          |
++=================+============================================+=========================================================================+
+| 2.10.0          | ``on_task_instance_failed``                | An error 
field added to the interface                                   |
++-----------------+--------------------------------------------+-------------------------------------------------------------------------+
+| 3.0.0           | ``on_task_instance_running``,              | ``session`` 
argument removed from task instance listeners,              |
+|                 | ``on_task_instance_success``,              | 
``task_instance`` object is now an instance of ``RuntimeTaskInstance``  |
+|                 | ``on_task_instance_failed``                |               
                                                          |
++-----------------+--------------------------------------------+-------------------------------------------------------------------------+
diff --git a/providers/src/airflow/providers/openlineage/plugins/listener.py 
b/providers/src/airflow/providers/openlineage/plugins/listener.py
index c1da206c987..99ced767345 100644
--- a/providers/src/airflow/providers/openlineage/plugins/listener.py
+++ b/providers/src/airflow/providers/openlineage/plugins/listener.py
@@ -19,6 +19,7 @@ from __future__ import annotations
 import logging
 import os
 from concurrent.futures import ProcessPoolExecutor
+from datetime import datetime
 from typing import TYPE_CHECKING
 
 import psutil
@@ -33,6 +34,7 @@ from airflow.providers.openlineage.extractors import 
ExtractorManager
 from airflow.providers.openlineage.plugins.adapter import OpenLineageAdapter, 
RunState
 from airflow.providers.openlineage.utils.utils import (
     AIRFLOW_V_2_10_PLUS,
+    AIRFLOW_V_3_0_PLUS,
     get_airflow_dag_run_facet,
     get_airflow_debug_facet,
     get_airflow_job_facet,
@@ -42,7 +44,6 @@ from airflow.providers.openlineage.utils.utils import (
     get_user_provided_run_facets,
     is_operator_disabled,
     is_selective_lineage_enabled,
-    is_ti_rescheduled_already,
     print_warning,
 )
 from airflow.settings import configure_orm
@@ -52,9 +53,9 @@ from airflow.utils.state import TaskInstanceState
 from airflow.utils.timeout import timeout
 
 if TYPE_CHECKING:
-    from sqlalchemy.orm import Session
-
     from airflow.models import TaskInstance
+    from airflow.sdk.execution_time.task_runner import RuntimeTaskInstance
+    from airflow.settings import Session
 
 _openlineage_listener: OpenLineageListener | None = None
 
@@ -87,28 +88,49 @@ class OpenLineageListener:
         self.extractor_manager = ExtractorManager()
         self.adapter = OpenLineageAdapter()
 
-    @hookimpl
-    def on_task_instance_running(
-        self,
-        previous_state: TaskInstanceState,
-        task_instance: TaskInstance,
-        session: Session,  # This will always be QUEUED
-    ) -> None:
-        if not getattr(task_instance, "task", None) is not None:
-            self.log.warning(
-                "No task set for TI object task_id: %s - dag_id: %s - run_id 
%s",
-                task_instance.task_id,
-                task_instance.dag_id,
-                task_instance.run_id,
-            )
-            return
+    if AIRFLOW_V_3_0_PLUS:
+
+        @hookimpl
+        def on_task_instance_running(
+            self,
+            previous_state: TaskInstanceState,
+            task_instance: RuntimeTaskInstance,
+        ):
+            self.log.debug("OpenLineage listener got notification about task 
instance start")
+            context = task_instance.get_template_context()
+
+            task = context["task"]
+            if TYPE_CHECKING:
+                assert task
+            dagrun = context["dag_run"]
+            dag = context["dag"]
+            self._on_task_instance_running(task_instance, dag, dagrun, task)
+    else:
 
-        self.log.debug("OpenLineage listener got notification about task 
instance start")
-        dagrun = task_instance.dag_run
-        task = task_instance.task
-        if TYPE_CHECKING:
-            assert task
-        dag = task.dag
+        @hookimpl
+        def on_task_instance_running(
+            self,
+            previous_state: TaskInstanceState,
+            task_instance: TaskInstance,
+            session: Session,  # type: ignore[valid-type]
+        ) -> None:
+            if getattr(task_instance, "task", None) is None:
+                self.log.warning(
+                    "No task set for TI object task_id: %s - dag_id: %s - 
run_id %s",
+                    task_instance.task_id,
+                    task_instance.dag_id,
+                    task_instance.run_id,
+                )
+                return
+
+            self.log.debug("OpenLineage listener got notification about task 
instance start")
+            task = task_instance.task
+            if TYPE_CHECKING:
+                assert task
+
+            self._on_task_instance_running(task_instance, task.dag, 
task_instance.dag_run, task)
+
+    def _on_task_instance_running(self, task_instance: RuntimeTaskInstance | 
TaskInstance, dag, dagrun, task):
         if is_operator_disabled(task):
             self.log.debug(
                 "Skipping OpenLineage event emission for operator `%s` "
@@ -127,35 +149,38 @@ class OpenLineageListener:
             return
 
         # Needs to be calculated outside of inner method so that it gets 
cached for usage in fork processes
+        data_interval_start = dagrun.data_interval_start
+        if isinstance(data_interval_start, datetime):
+            data_interval_start = data_interval_start.isoformat()
+        data_interval_end = dagrun.data_interval_end
+        if isinstance(data_interval_end, datetime):
+            data_interval_end = data_interval_end.isoformat()
+
+        clear_number = 0
+        if hasattr(dagrun, "clear_number"):
+            clear_number = dagrun.clear_number
+
         debug_facet = get_airflow_debug_facet()
 
         @print_warning(self.log)
         def on_running():
-            # that's a workaround to detect task running from deferred state
-            # we return here because Airflow 2.3 needs task from deferred state
-            if task_instance.next_method is not None:
-                return
-
-            if is_ti_rescheduled_already(task_instance):
+            context = task_instance.get_template_context()
+            if context.get("task_reschedule_count", 0) > 0:
                 self.log.debug("Skipping this instance of rescheduled task - 
START event was emitted already")
                 return
 
             parent_run_id = self.adapter.build_dag_run_id(
                 dag_id=dag.dag_id,
                 logical_date=dagrun.logical_date,
-                clear_number=dagrun.clear_number,
+                clear_number=clear_number,
             )
-
-            if hasattr(task_instance, "logical_date"):
-                logical_date = task_instance.logical_date
-            else:
-                logical_date = task_instance.execution_date
+            start_date = task_instance.start_date if task_instance.start_date 
else timezone.utcnow()
 
             task_uuid = self.adapter.build_task_instance_run_id(
                 dag_id=dag.dag_id,
                 task_id=task.task_id,
                 try_number=task_instance.try_number,
-                logical_date=logical_date,
+                logical_date=dagrun.logical_date,
                 map_index=task_instance.map_index,
             )
             event_type = RunState.RUNNING.value.lower()
@@ -164,11 +189,6 @@ class OpenLineageListener:
             with Stats.timer(f"ol.extract.{event_type}.{operator_name}"):
                 task_metadata = 
self.extractor_manager.extract_metadata(dagrun, task)
 
-            start_date = task_instance.start_date if task_instance.start_date 
else timezone.utcnow()
-            data_interval_start = (
-                dagrun.data_interval_start.isoformat() if 
dagrun.data_interval_start else None
-            )
-            data_interval_end = dagrun.data_interval_end.isoformat() if 
dagrun.data_interval_end else None
             redacted_event = self.adapter.start_task(
                 run_id=task_uuid,
                 job_name=get_job_name(task),
@@ -195,17 +215,39 @@ class OpenLineageListener:
 
         self._execute(on_running, "on_running", use_fork=True)
 
-    @hookimpl
-    def on_task_instance_success(
-        self, previous_state: TaskInstanceState, task_instance: TaskInstance, 
session: Session
-    ) -> None:
-        self.log.debug("OpenLineage listener got notification about task 
instance success")
+    if AIRFLOW_V_3_0_PLUS:
+
+        @hookimpl
+        def on_task_instance_success(
+            self, previous_state: TaskInstanceState, task_instance: 
RuntimeTaskInstance
+        ) -> None:
+            self.log.debug("OpenLineage listener got notification about task 
instance success")
+
+            context = task_instance.get_template_context()
+            task = context["task"]
+            if TYPE_CHECKING:
+                assert task
+            dagrun = context["dag_run"]
+            dag = context["dag"]
+            self._on_task_instance_success(task_instance, dag, dagrun, task)
+
+    else:
+
+        @hookimpl
+        def on_task_instance_success(
+            self,
+            previous_state: TaskInstanceState,
+            task_instance: TaskInstance,
+            session: Session,  # type: ignore[valid-type]
+        ) -> None:
+            self.log.debug("OpenLineage listener got notification about task 
instance success")
+            task = task_instance.task
+            if TYPE_CHECKING:
+                assert task
+            self._on_task_instance_success(task_instance, task.dag, 
task_instance.dag_run, task)
 
-        dagrun = task_instance.dag_run
-        task = task_instance.task
-        if TYPE_CHECKING:
-            assert task
-        dag = task.dag
+    def _on_task_instance_success(
+        self, task_instance: RuntimeTaskInstance | TaskInstance, dag, dagrun, task
+    ):
+        end_date = timezone.utcnow()
 
         if is_operator_disabled(task):
             self.log.debug(
@@ -232,15 +274,11 @@ class OpenLineageListener:
                 clear_number=dagrun.clear_number,
             )
 
-            if hasattr(task_instance, "logical_date"):
-                logical_date = task_instance.logical_date
-            else:
-                logical_date = task_instance.execution_date
             task_uuid = self.adapter.build_task_instance_run_id(
                 dag_id=dag.dag_id,
                 task_id=task.task_id,
                 try_number=_get_try_number_success(task_instance),
-                logical_date=logical_date,
+                logical_date=dagrun.logical_date,
                 map_index=task_instance.map_index,
             )
             event_type = RunState.COMPLETE.value.lower()
@@ -251,8 +289,6 @@ class OpenLineageListener:
                     dagrun, task, complete=True, task_instance=task_instance
                 )
 
-            end_date = task_instance.end_date if task_instance.end_date else 
timezone.utcnow()
-
             redacted_event = self.adapter.complete_task(
                 run_id=task_uuid,
                 job_name=get_job_name(task),
@@ -273,7 +309,7 @@ class OpenLineageListener:
 
         self._execute(on_success, "on_success", use_fork=True)
 
-    if AIRFLOW_V_2_10_PLUS:
+    if AIRFLOW_V_3_0_PLUS:
 
         @hookimpl
         def on_task_instance_failed(
@@ -281,36 +317,54 @@ class OpenLineageListener:
             previous_state: TaskInstanceState,
             task_instance: TaskInstance,
             error: None | str | BaseException,
-            session: Session,
         ) -> None:
-            self._on_task_instance_failed(
-                previous_state=previous_state, task_instance=task_instance, 
error=error, session=session
-            )
+            self.log.debug("OpenLineage listener got notification about task 
instance failure")
+            context = task_instance.get_template_context()
+            task = context["task"]
+            if TYPE_CHECKING:
+                assert task
+            dagrun = context["dag_run"]
+            dag = context["dag"]
+            self._on_task_instance_failed(task_instance, dag, dagrun, task, 
error)
 
+    elif AIRFLOW_V_2_10_PLUS:
+
+        @hookimpl
+        def on_task_instance_failed(
+            self,
+            previous_state: TaskInstanceState,
+            task_instance: TaskInstance,
+            error: None | str | BaseException,
+            session: Session,  # type: ignore[valid-type]
+        ) -> None:
+            self.log.debug("OpenLineage listener got notification about task 
instance failure")
+            task = task_instance.task
+            if TYPE_CHECKING:
+                assert task
+            self._on_task_instance_failed(task_instance, task.dag, 
task_instance.dag_run, task, error)
     else:
 
         @hookimpl
         def on_task_instance_failed(
-            self, previous_state: TaskInstanceState, task_instance: 
TaskInstance, session: Session
+            self,
+            previous_state: TaskInstanceState,
+            task_instance: TaskInstance,
+            session: Session,  # type: ignore[valid-type]
         ) -> None:
-            self._on_task_instance_failed(
-                previous_state=previous_state, task_instance=task_instance, 
error=None, session=session
-            )
+            task = task_instance.task
+            if TYPE_CHECKING:
+                assert task
+            self._on_task_instance_failed(task_instance, task.dag, 
task_instance.dag_run, task)
 
     def _on_task_instance_failed(
         self,
-        previous_state: TaskInstanceState,
-        task_instance: TaskInstance,
-        session: Session,
+        task_instance: TaskInstance | RuntimeTaskInstance,
+        dag,
+        dagrun,
+        task,
         error: None | str | BaseException = None,
     ) -> None:
-        self.log.debug("OpenLineage listener got notification about task 
instance failure")
-
-        dagrun = task_instance.dag_run
-        task = task_instance.task
-        if TYPE_CHECKING:
-            assert task
-        dag = task.dag
+        end_date = timezone.utcnow()
 
         if is_operator_disabled(task):
             self.log.debug(
@@ -337,16 +391,11 @@ class OpenLineageListener:
                 clear_number=dagrun.clear_number,
             )
 
-            if hasattr(task_instance, "logical_date"):
-                logical_date = task_instance.logical_date
-            else:
-                logical_date = task_instance.execution_date
-
             task_uuid = self.adapter.build_task_instance_run_id(
                 dag_id=dag.dag_id,
                 task_id=task.task_id,
                 try_number=task_instance.try_number,
-                logical_date=logical_date,
+                logical_date=dagrun.logical_date,
                 map_index=task_instance.map_index,
             )
             event_type = RunState.FAIL.value.lower()
@@ -357,8 +406,6 @@ class OpenLineageListener:
                     dagrun, task, complete=True, task_instance=task_instance
                 )
 
-            end_date = task_instance.end_date if task_instance.end_date else 
timezone.utcnow()
-
             redacted_event = self.adapter.fail_task(
                 run_id=task_uuid,
                 job_name=get_job_name(task),
diff --git 
a/providers/src/airflow/providers/openlineage/utils/selective_enable.py 
b/providers/src/airflow/providers/openlineage/utils/selective_enable.py
index a3c16a1e18d..b0cd8a4455c 100644
--- a/providers/src/airflow/providers/openlineage/utils/selective_enable.py
+++ b/providers/src/airflow/providers/openlineage/utils/selective_enable.py
@@ -18,7 +18,7 @@
 from __future__ import annotations
 
 import logging
-from typing import TypeVar
+from typing import TYPE_CHECKING, TypeVar
 
 from airflow.models import DAG, Operator, Param
 from airflow.models.xcom_arg import XComArg
@@ -28,6 +28,10 @@ ENABLE_OL_PARAM = Param(True, const=True)
 DISABLE_OL_PARAM = Param(False, const=False)
 T = TypeVar("T", bound="DAG | Operator")
 
+if TYPE_CHECKING:
+    from airflow.sdk.definitions.baseoperator import BaseOperator as 
SdkBaseOperator
+
+
 log = logging.getLogger(__name__)
 
 
@@ -65,7 +69,7 @@ def disable_lineage(obj: T) -> T:
     return obj
 
 
-def is_task_lineage_enabled(task: Operator) -> bool:
+def is_task_lineage_enabled(task: Operator | SdkBaseOperator) -> bool:
     """Check if selective enable OpenLineage parameter is set to True on task 
level."""
     if task.params.get(ENABLE_OL_PARAM_NAME) is False:
         log.debug(
diff --git a/providers/src/airflow/providers/openlineage/utils/utils.py 
b/providers/src/airflow/providers/openlineage/utils/utils.py
index 4408a833fba..2cb998d47aa 100644
--- a/providers/src/airflow/providers/openlineage/utils/utils.py
+++ b/providers/src/airflow/providers/openlineage/utils/utils.py
@@ -27,12 +27,11 @@ from typing import TYPE_CHECKING, Any, Callable
 
 import attrs
 from openlineage.client.utils import RedactMixin
-from sqlalchemy import exists
 
 from airflow import __version__ as AIRFLOW_VERSION
 
 # TODO: move this maybe to Airflow's logic?
-from airflow.models import DAG, BaseOperator, DagRun, MappedOperator, 
TaskReschedule
+from airflow.models import DAG, BaseOperator, DagRun, MappedOperator
 from airflow.providers.openlineage import (
     __version__ as OPENLINEAGE_PROVIDER_VERSION,
     conf,
@@ -52,7 +51,6 @@ from airflow.providers.openlineage.utils.selective_enable 
import (
     is_task_lineage_enabled,
 )
 from airflow.providers.openlineage.version_compat import AIRFLOW_V_2_10_PLUS, 
AIRFLOW_V_3_0_PLUS
-from airflow.sensors.base import BaseSensorOperator
 from airflow.serialization.serialized_objects import SerializedBaseOperator
 from airflow.utils.context import AirflowContextDeprecationWarning
 from airflow.utils.log.secrets_masker import (
@@ -62,7 +60,11 @@ from airflow.utils.log.secrets_masker import (
     should_hide_value_for_key,
 )
 from airflow.utils.module_loading import import_string
-from airflow.utils.session import NEW_SESSION, provide_session
+
+try:
+    from airflow.sdk.definitions.baseoperator import BaseOperator as 
SdkBaseOperator
+except ImportError:
+    SdkBaseOperator = BaseOperator  # type: ignore[misc]
 
 if TYPE_CHECKING:
     from openlineage.client.event_v2 import Dataset as OpenLineageDataset
@@ -90,7 +92,7 @@ def try_import_from_string(string: str) -> Any:
         return import_string(string)
 
 
-def get_operator_class(task: BaseOperator) -> type:
+def get_operator_class(task: BaseOperator | SdkBaseOperator) -> type:
     if task.__class__.__name__ in ("DecoratedMappedOperator", 
"MappedOperator"):
         return task.operator_class
     return task.__class__
@@ -153,7 +155,7 @@ def get_user_provided_run_facets(ti: TaskInstance, 
ti_state: TaskInstanceState)
     return custom_facets
 
 
-def get_fully_qualified_class_name(operator: BaseOperator | MappedOperator) -> 
str:
+def get_fully_qualified_class_name(operator: BaseOperator | MappedOperator | 
SdkBaseOperator) -> str:
     if isinstance(operator, (MappedOperator, SerializedBaseOperator)):
         # as in 
airflow.api_connexion.schemas.common_schema.ClassReferenceSchema
         return operator._task_module + "." + operator._task_type  # type: 
ignore
@@ -161,44 +163,22 @@ def get_fully_qualified_class_name(operator: BaseOperator 
| MappedOperator) -> s
     return op_class.__module__ + "." + op_class.__name__
 
 
-def is_operator_disabled(operator: BaseOperator | MappedOperator) -> bool:
+def is_operator_disabled(operator: BaseOperator | MappedOperator | 
SdkBaseOperator) -> bool:
     return get_fully_qualified_class_name(operator) in 
conf.disabled_operators()
 
 
-def is_selective_lineage_enabled(obj: DAG | BaseOperator | MappedOperator) -> 
bool:
+def is_selective_lineage_enabled(obj: DAG | BaseOperator | MappedOperator | 
SdkBaseOperator) -> bool:
     """If selective enable is active check if DAG or Task is enabled to emit 
events."""
     if not conf.selective_enable():
         return True
     if isinstance(obj, DAG):
         return is_dag_lineage_enabled(obj)
-    elif isinstance(obj, (BaseOperator, MappedOperator)):
+    elif isinstance(obj, (BaseOperator, MappedOperator, SdkBaseOperator)):
         return is_task_lineage_enabled(obj)
     else:
         raise TypeError("is_selective_lineage_enabled can only be used on DAG 
or Operator objects")
 
 
-@provide_session
-def is_ti_rescheduled_already(ti: TaskInstance, session=NEW_SESSION):
-    if not isinstance(ti.task, BaseSensorOperator):
-        return False
-
-    if not ti.task.reschedule:
-        return False
-
-    return (
-        session.query(
-            exists().where(
-                TaskReschedule.dag_id == ti.dag_id,
-                TaskReschedule.task_id == ti.task_id,
-                TaskReschedule.run_id == ti.run_id,
-                TaskReschedule.map_index == ti.map_index,
-                TaskReschedule.try_number == ti.try_number,
-            )
-        ).scalar()
-        is True
-    )
-
-
 class InfoJsonEncodable(dict):
     """
     Airflow objects might not be json-encodable overall.
diff --git a/providers/tests/openlineage/extractors/test_manager.py 
b/providers/tests/openlineage/extractors/test_manager.py
index df64b7d1e75..c3a56f16872 100644
--- a/providers/tests/openlineage/extractors/test_manager.py
+++ b/providers/tests/openlineage/extractors/test_manager.py
@@ -19,6 +19,7 @@ from __future__ import annotations
 
 import tempfile
 from typing import TYPE_CHECKING, Any
+from unittest import mock
 from unittest.mock import MagicMock
 
 import pytest
@@ -28,6 +29,7 @@ from openlineage.client.facet_v2 import (
     ownership_dataset,
     schema_dataset,
 )
+from uuid6 import uuid7
 
 from airflow.io.path import ObjectStoragePath
 from airflow.lineage.entities import Column, File, Table, User
@@ -36,18 +38,22 @@ from airflow.models.taskinstance import TaskInstance
 from airflow.providers.openlineage.extractors import OperatorLineage
 from airflow.providers.openlineage.extractors.manager import ExtractorManager
 from airflow.providers.openlineage.utils.utils import Asset
+from airflow.utils import timezone
 from airflow.utils.state import State
 
 from tests_common.test_utils.compat import PythonOperator
-from tests_common.test_utils.version_compat import AIRFLOW_V_2_10_PLUS
+from tests_common.test_utils.version_compat import AIRFLOW_V_2_10_PLUS, 
AIRFLOW_V_3_0_PLUS
 
 if TYPE_CHECKING:
+    from datetime import datetime
     try:
         from airflow.sdk.definitions.context import Context
     except ImportError:
         # TODO: Remove once provider drops support for Airflow 2
         from airflow.utils.context import Context
 
+    from task_sdk.tests.conftest import MakeTIContextCallable
+
 if AIRFLOW_V_2_10_PLUS:
 
     @pytest.fixture
@@ -65,6 +71,19 @@ if AIRFLOW_V_2_10_PLUS:
         hook._hook_lineage_collector = None
 
 
+if AIRFLOW_V_3_0_PLUS:
+    from airflow.sdk.api.datamodels._generated import TaskInstance as 
SDKTaskInstance
+    from airflow.sdk.execution_time import task_runner
+    from airflow.sdk.execution_time.comms import StartupDetails
+    from airflow.sdk.execution_time.task_runner import RuntimeTaskInstance, 
parse
+else:
+    SDKTaskInstance = ...  # type: ignore
+    task_runner = ...  # type: ignore
+    StartupDetails = ...  # type: ignore
+    RuntimeTaskInstance = ...  # type: ignore
+    parse = ...  # type: ignore
+
+
 @pytest.mark.parametrize(
     ("uri", "dataset"),
     (
@@ -301,7 +320,10 @@ def 
test_extractor_manager_does_not_use_hook_level_lineage_when_operator(
 
 
 @pytest.mark.db_test
[email protected](not AIRFLOW_V_2_10_PLUS, reason="Hook lineage works in 
Airflow >= 2.10.0")
[email protected](
+    not AIRFLOW_V_2_10_PLUS or AIRFLOW_V_3_0_PLUS,
+    reason="Test for hook level lineage in Airflow >= 2.10.0 < 3.0",
+)
 def test_extractor_manager_gets_data_from_pythonoperator(session, dag_maker, 
hook_lineage_collector):
     path = None
     with tempfile.NamedTemporaryFile() as f:
@@ -328,3 +350,121 @@ def 
test_extractor_manager_gets_data_from_pythonoperator(session, dag_maker, hoo
 
     assert len(datasets.outputs) == 1
     assert datasets.outputs[0].asset == Asset(uri=path)
+
+
[email protected]
+def mock_supervisor_comms():
+    with mock.patch(
+        "airflow.sdk.execution_time.task_runner.SUPERVISOR_COMMS", create=True
+    ) as supervisor_comms:
+        yield supervisor_comms
+
+
[email protected]
+def mocked_parse(spy_agency):
+    """
+    Fixture to set up an inline DAG and use it in a stubbed `parse` function. 
Use this fixture if you
+    want to isolate and test `parse` or `run` logic without having to define a 
DAG file.
+
+    This fixture returns a helper function `set_dag` that:
+    1. Creates an inline DAG with the given `dag_id` and `task` (limited to 
one task)
+    2. Constructs a `RuntimeTaskInstance` based on the provided 
`StartupDetails` and task.
+    3. Stubs the `parse` function using `spy_agency`, to return the mocked 
`RuntimeTaskInstance`.
+
+    After adding the fixture in your test function signature, you can use it 
like this ::
+
+            mocked_parse(
+                StartupDetails(
+                    ti=TaskInstance(id=uuid7(), task_id="hello", 
dag_id="super_basic_run", run_id="c", try_number=1),
+                    file="",
+                    requests_fd=0,
+                ),
+                "example_dag_id",
+                CustomOperator(task_id="hello"),
+            )
+    """
+
+    def set_dag(what: StartupDetails, dag_id: str, task: BaseOperator) -> 
RuntimeTaskInstance:
+        from task_sdk.tests.execution_time.test_task_runner import 
get_inline_dag
+
+        dag = get_inline_dag(dag_id, task)
+        t = dag.task_dict[task.task_id]
+        ti = RuntimeTaskInstance.model_construct(
+            **what.ti.model_dump(exclude_unset=True), task=t, 
_ti_context_from_server=what.ti_context
+        )
+        spy_agency.spy_on(parse, call_fake=lambda _: ti)
+        return ti
+
+    return set_dag
+
+
[email protected]
+def make_ti_context() -> MakeTIContextCallable:
+    """Factory for creating TIRunContext objects."""
+    from airflow.sdk.api.datamodels._generated import DagRun, TIRunContext
+
+    def _make_context(
+        dag_id: str = "test_dag",
+        run_id: str = "test_run",
+        logical_date: str | datetime = "2024-12-01T01:00:00Z",
+        data_interval_start: str | datetime = "2024-12-01T00:00:00Z",
+        data_interval_end: str | datetime = "2024-12-01T01:00:00Z",
+        clear_number: int = 0,
+        start_date: str | datetime = "2024-12-01T01:00:00Z",
+        run_type: str = "manual",
+        task_reschedule_count: int = 0,
+    ) -> TIRunContext:
+        return TIRunContext(
+            dag_run=DagRun(
+                dag_id=dag_id,
+                run_id=run_id,
+                logical_date=logical_date,  # type: ignore
+                data_interval_start=data_interval_start,  # type: ignore
+                data_interval_end=data_interval_end,  # type: ignore
+                clear_number=clear_number,  # type: ignore
+                start_date=start_date,  # type: ignore
+                run_type=run_type,  # type: ignore
+            ),
+            task_reschedule_count=task_reschedule_count,
+        )
+
+    return _make_context
+
+
[email protected]_test
[email protected](not AIRFLOW_V_3_0_PLUS, reason="Task SDK related test")
+def test_extractor_manager_gets_data_from_pythonoperator_tasksdk(
+    session, hook_lineage_collector, mocked_parse, make_ti_context, 
mock_supervisor_comms
+):
+    path = None
+    with tempfile.NamedTemporaryFile() as f:
+        path = f.name
+
+        def use_read():
+            storage_path = ObjectStoragePath(path)
+            with storage_path.open("w") as out:
+                out.write("test")
+
+    task = PythonOperator(task_id="test_task_extractor_pythonoperator", 
python_callable=use_read)
+
+    what = StartupDetails(
+        ti=SDKTaskInstance(
+            id=uuid7(),
+            task_id="test_task_extractor_pythonoperator",
+            dag_id="test_hookcollector_dag",
+            run_id="c",
+            try_number=1,
+            start_date=timezone.utcnow(),
+        ),
+        file="",
+        requests_fd=0,
+        ti_context=make_ti_context(),
+    )
+    ti = mocked_parse(what, "test_hookcollector_dag", task)
+
+    task_runner.run(ti, MagicMock())
+
+    datasets = hook_lineage_collector.collected_assets
+
+    assert len(datasets.outputs) == 1
+    assert datasets.outputs[0].asset == Asset(uri=path)
diff --git a/providers/tests/openlineage/plugins/test_listener.py 
b/providers/tests/openlineage/plugins/test_listener.py
index 837873f439d..10896286a7d 100644
--- a/providers/tests/openlineage/plugins/test_listener.py
+++ b/providers/tests/openlineage/plugins/test_listener.py
@@ -18,9 +18,10 @@ from __future__ import annotations
 
 import datetime as dt
 import uuid
+from collections import defaultdict
 from concurrent.futures import Future
 from contextlib import suppress
-from typing import Callable
+from typing import TYPE_CHECKING, Callable
 from unittest import mock
 from unittest.mock import ANY, MagicMock, patch
 
@@ -29,14 +30,16 @@ import pytest
 from openlineage.client import OpenLineageClient
 from openlineage.client.transport import ConsoleTransport
 from openlineage.client.transport.console import ConsoleConfig
+from uuid6 import uuid7
 
 from airflow.models import DAG, DagRun, TaskInstance
 from airflow.models.baseoperator import BaseOperator
+from airflow.operators.empty import EmptyOperator
 from airflow.providers.openlineage.plugins.adapter import OpenLineageAdapter
 from airflow.providers.openlineage.plugins.facets import AirflowDebugRunFacet
 from airflow.providers.openlineage.plugins.listener import OpenLineageListener
 from airflow.providers.openlineage.utils.selective_enable import 
disable_lineage, enable_lineage
-from airflow.utils import types
+from airflow.utils import timezone, types
 from airflow.utils.state import DagRunState, State
 
 from tests_common.test_utils.compat import PythonOperator
@@ -55,6 +58,10 @@ TRY_NUMBER_SUCCESS = 0 if AIRFLOW_V_2_10_PLUS else 2
 TRY_NUMBER_AFTER_EXECUTION = 0 if AIRFLOW_V_2_10_PLUS else 2
 
 
+if TYPE_CHECKING:
+    from airflow.sdk.execution_time.task_runner import RuntimeTaskInstance
+
+
 class TemplateOperator(BaseOperator):
     template_fields = ["df"]
 
@@ -74,95 +81,104 @@ def regular_call(self, callable, callable_name, use_fork):
     callable()
 
 
-@patch("airflow.models.TaskInstance.xcom_push")
-@patch("airflow.models.BaseOperator.render_template")
-def test_listener_does_not_change_task_instance(render_mock, xcom_push_mock):
-    render_mock.return_value = render_df()
+class MockExecutor:
+    def __init__(self, *args, **kwargs):
+        self.submitted = False
+        self.succeeded = False
+        self.result = None
 
-    date = dt.datetime(2022, 1, 1)
-    dag = DAG(
-        "test",
-        schedule=None,
-        start_date=date,
-        user_defined_macros={"render_df": render_df},
-        params={"df": {"col": [1, 2]}},
-    )
-    t = TemplateOperator(task_id="template_op", dag=dag, do_xcom_push=True, 
df=dag.param("df"))
-    run_id = str(uuid.uuid1())
-    if AIRFLOW_V_3_0_PLUS:
-        dagrun_kwargs = {
-            "dag_version": None,
-            "logical_date": date,
-            "triggered_by": types.DagRunTriggeredByType.TEST,
-        }
-    else:
-        dagrun_kwargs = {"execution_date": date}
-    dag.create_dagrun(
-        run_id=run_id,
-        data_interval=(date, date),
-        run_type=types.DagRunType.MANUAL,
-        state=DagRunState.QUEUED,
-        **dagrun_kwargs,
-    )
-    ti = TaskInstance(t, run_id=run_id)
-    ti.check_and_change_state_before_execution()  # make listener hook on 
running event
-    ti._run_raw_task()
+    def submit(self, fn, /, *args, **kwargs):
+        self.submitted = True
+        try:
+            fn(*args, **kwargs)
+            self.succeeded = True
+        except Exception:
+            pass
+        return MagicMock()
 
-    # check if task returns the same DataFrame
-    pd.testing.assert_frame_equal(xcom_push_mock.call_args.kwargs["value"], 
render_df())
+    def shutdown(self, *args, **kwargs):
+        print("Shutting down")
 
-    # check if render_template method always get the same unrendered field
-    assert not isinstance(render_mock.call_args.args[0], pd.DataFrame)
 
[email protected](AIRFLOW_V_3_0_PLUS, reason="Airflow 2 tests")
+class TestOpenLineageListenerAirflow2:
+    @patch("airflow.models.TaskInstance.xcom_push")
+    @patch("airflow.models.BaseOperator.render_template")
+    def test_listener_does_not_change_task_instance(self, render_mock, 
xcom_push_mock):
+        render_mock.return_value = render_df()
 
-def _setup_mock_listener(mock_listener: mock.Mock, captured_try_numbers: 
dict[str, int]) -> None:
-    """Sets up the mock listener with side effects to capture try numbers for 
different task instance events.
+        dag = DAG(
+            "test",
+            schedule=None,
+            start_date=dt.datetime(2022, 1, 1),
+            user_defined_macros={"render_df": render_df},
+            params={"df": {"col": [1, 2]}},
+        )
+        t = TemplateOperator(task_id="template_op", dag=dag, 
do_xcom_push=True, df=dag.param("df"))
+        run_id = str(uuid.uuid1())
+        dag.create_dagrun(
+            state=State.NONE,
+            run_id=run_id,
+        )
+        ti = TaskInstance(t, run_id=run_id)
+        ti.check_and_change_state_before_execution()  # make listener hook on 
running event
+        ti._run_raw_task()
 
-    :param mock_listener: The mock object for the listener manager.
-    :param captured_try_numbers: A dictionary to store captured try numbers 
keyed by event names.
+        # check if task returns the same DataFrame
+        
pd.testing.assert_frame_equal(xcom_push_mock.call_args.kwargs["value"], 
render_df())
 
-    This function iterates through specified event names and sets a side 
effect on the corresponding
-    method of the listener manager's hook. The side effect is a nested 
function that captures the try number
-    of the task instance when the method is called.
+        # check if render_template method always gets the same unrendered field
+        assert not isinstance(render_mock.call_args.args[0], pd.DataFrame)
 
-    :Example:
+    def _setup_mock_listener(self, mock_listener: mock.Mock, 
captured_try_numbers: dict[str, int]) -> None:
+        """Sets up the mock listener with side effects to capture try numbers 
for different task instance events.
 
-        captured_try_numbers = {}
-        mock_listener = Mock()
-        _setup_mock_listener(mock_listener, captured_try_numbers)
-        # After running a task, captured_try_numbers will have the try number 
captured at the moment of
-        execution for specified methods. F.e. {"running": 1, "success": 2} for 
on_task_instance_running and
-        on_task_instance_success methods.
-    """
+        :param mock_listener: The mock object for the listener manager.
+        :param captured_try_numbers: A dictionary to store captured try 
numbers keyed by event names.
+
+        This function iterates through specified event names and sets a side 
effect on the corresponding
+        method of the listener manager's hook. The side effect is a nested 
function that captures the try number
+        of the task instance when the method is called.
+
+        :Example:
 
-    def capture_try_number(method_name):
-        def inner(*args, **kwargs):
-            captured_try_numbers[method_name] = 
kwargs["task_instance"].try_number
+            captured_try_numbers = {}
+            mock_listener = Mock()
+            _setup_mock_listener(mock_listener, captured_try_numbers)
+            # After running a task, captured_try_numbers will have the try 
number captured at the moment of
+            execution for specified methods. E.g. {"running": 1, "success": 2} 
for on_task_instance_running and
+            on_task_instance_success methods.
+        """
 
-        return inner
+        def capture_try_number(method_name):
+            def inner(*args, **kwargs):
+                captured_try_numbers[method_name] = 
kwargs["task_instance"].try_number
 
-    for event in ["running", "success", "failed"]:
-        getattr(
-            mock_listener.return_value.hook, f"on_task_instance_{event}"
-        ).side_effect = capture_try_number(event)
+            return inner
 
+        for event in ["running", "success", "failed"]:
+            getattr(
+                mock_listener.return_value.hook, f"on_task_instance_{event}"
+            ).side_effect = capture_try_number(event)
 
-def _create_test_dag_and_task(python_callable: Callable, scenario_name: str) 
-> tuple[DagRun, TaskInstance]:
-    """Creates a test DAG and a task for a custom test scenario.
+    def _create_test_dag_and_task(
+        self, python_callable: Callable, scenario_name: str
+    ) -> tuple[DagRun, TaskInstance]:
+        """Creates a test DAG and a task for a custom test scenario.
 
-    :param python_callable: The Python callable to be executed by the 
PythonOperator.
-    :param scenario_name: The name of the test scenario, used to uniquely name 
the DAG and task.
+        :param python_callable: The Python callable to be executed by the 
PythonOperator.
+        :param scenario_name: The name of the test scenario, used to uniquely 
name the DAG and task.
 
-    :return: TaskInstance: The created TaskInstance object.
+        :return: TaskInstance: The created TaskInstance object.
 
     This function creates a DAG and a PythonOperator task with the provided
     python_callable. It generates a unique run ID and creates a DAG run. This
     setup is useful for testing different scenarios in Airflow tasks.
 
-    :Example:
+        :Example:
 
-        def sample_callable(**kwargs):
-            print("Hello World")
+            def sample_callable(**kwargs):
+                print("Hello World")
 
         task_instance = _create_test_dag_and_task(sample_callable, 
"sample_scenario")
         # Use task_instance to simulate running a task in a test.
@@ -175,14 +191,11 @@ def _create_test_dag_and_task(python_callable: Callable, 
scenario_name: str) ->
     )
     t = PythonOperator(task_id=f"test_task_{scenario_name}", dag=dag, 
python_callable=python_callable)
     run_id = str(uuid.uuid1())
-    if AIRFLOW_V_3_0_PLUS:
-        dagrun_kwargs: dict = {
-            "dag_version": None,
-            "logical_date": date,
-            "triggered_by": types.DagRunTriggeredByType.TEST,
-        }
-    else:
-        dagrun_kwargs = {"execution_date": date}
+    dagrun_kwargs: dict = {
+        "dag_version": None,
+        "logical_date": date,
+        "triggered_by": types.DagRunTriggeredByType.TEST,
+    }
     dagrun = dag.create_dagrun(
         run_id=run_id,
         data_interval=(date, date),
@@ -193,501 +206,1120 @@ def _create_test_dag_and_task(python_callable: 
Callable, scenario_name: str) ->
     task_instance = TaskInstance(t, run_id=run_id)
     return dagrun, task_instance
 
+    def _create_listener_and_task_instance(self) -> tuple[OpenLineageListener, 
TaskInstance]:
+        """Creates and configures an OpenLineageListener instance and a mock 
TaskInstance for testing.
 
-def _create_listener_and_task_instance() -> tuple[OpenLineageListener, 
TaskInstance]:
-    """Creates and configures an OpenLineageListener instance and a mock 
TaskInstance for testing.
+        :return: A tuple containing the configured OpenLineageListener and 
TaskInstance.
 
-    :return: A tuple containing the configured OpenLineageListener and 
TaskInstance.
+        This function instantiates an OpenLineageListener, sets up its 
required properties with mock objects, and
+        creates a mock TaskInstance with predefined attributes. This setup is 
commonly used for testing the
+        interaction between an OpenLineageListener and a TaskInstance in 
Airflow.
 
-    This function instantiates an OpenLineageListener, sets up its required 
properties with mock objects, and
-    creates a mock TaskInstance with predefined attributes. This setup is 
commonly used for testing the
-    interaction between an OpenLineageListener and a TaskInstance in Airflow.
+        :Example:
 
-    :Example:
+            listener, task_instance = _create_listener_and_task_instance()
+            # Now you can use listener and task_instance in your tests to 
simulate their interaction.
+        """
 
-        listener, task_instance = _create_listener_and_task_instance()
-        # Now you can use listener and task_instance in your tests to simulate 
their interaction.
-    """
+        def mock_dag_id(dag_id, logical_date, clear_number):
+            return f"{logical_date.isoformat()}.{dag_id}.{clear_number}"
+
+        def mock_task_id(dag_id, task_id, try_number, logical_date, map_index):
+            return 
f"{logical_date.isoformat()}.{dag_id}.{task_id}.{try_number}.{map_index}"
 
-    def mock_dag_id(dag_id, logical_date, clear_number):
-        return f"{logical_date.isoformat()}.{dag_id}.{clear_number}"
-
-    def mock_task_id(dag_id, task_id, try_number, logical_date, map_index):
-        return 
f"{logical_date.isoformat()}.{dag_id}.{task_id}.{try_number}.{map_index}"
-
-    listener = OpenLineageListener()
-    listener.extractor_manager = mock.Mock()
-
-    metadata = mock.Mock()
-    metadata.run_facets = {"run_facet": 1}
-    listener.extractor_manager.extract_metadata.return_value = metadata
-
-    adapter = mock.Mock()
-    adapter.build_dag_run_id.side_effect = mock_dag_id
-    adapter.build_task_instance_run_id.side_effect = mock_task_id
-    adapter.start_task = mock.Mock()
-    adapter.fail_task = mock.Mock()
-    adapter.complete_task = mock.Mock()
-    listener.adapter = adapter
-
-    task_instance = TaskInstance(task=mock.Mock())
-    task_instance.dag_run = DagRun()
-    task_instance.dag_run.run_id = "dag_run_run_id"
-    task_instance.dag_run.data_interval_start = None
-    task_instance.dag_run.data_interval_end = None
-    task_instance.dag_run.clear_number = 0
-    if AIRFLOW_V_3_0_PLUS:
-        task_instance.dag_run.logical_date = dt.datetime(2020, 1, 1, 1, 1, 1)
-    else:
+        listener = OpenLineageListener()
+        listener.extractor_manager = mock.Mock()
+
+        metadata = mock.Mock()
+        metadata.run_facets = {"run_facet": 1}
+        listener.extractor_manager.extract_metadata.return_value = metadata
+
+        adapter = mock.Mock()
+        adapter.build_dag_run_id.side_effect = mock_dag_id
+        adapter.build_task_instance_run_id.side_effect = mock_task_id
+        adapter.start_task = mock.Mock()
+        adapter.fail_task = mock.Mock()
+        adapter.complete_task = mock.Mock()
+        listener.adapter = adapter
+
+        task_instance = TaskInstance(task=mock.Mock())
+        task_instance.dag_run = DagRun()
+        task_instance.dag_run.run_id = "dag_run_run_id"
+        task_instance.dag_run.data_interval_start = None
+        task_instance.dag_run.data_interval_end = None
+        task_instance.dag_run.clear_number = 0
         task_instance.dag_run.execution_date = dt.datetime(2020, 1, 1, 1, 1, 1)
-    task_instance.task = mock.Mock()
-    task_instance.task.task_id = "task_id"
-    task_instance.task.dag = mock.Mock()
-    task_instance.task.dag.dag_id = "dag_id"
-    task_instance.task.dag.description = "Test DAG Description"
-    task_instance.task.dag.owner = "Test Owner"
-    task_instance.task.inlets = []
-    task_instance.task.outlets = []
-    task_instance.dag_id = "dag_id"
-    task_instance.run_id = "dag_run_run_id"
-    task_instance.try_number = 1
-    task_instance.state = State.RUNNING
-    task_instance.start_date = dt.datetime(2023, 1, 1, 13, 1, 1)
-    task_instance.end_date = dt.datetime(2023, 1, 3, 13, 1, 1)
-    task_instance.logical_date = dt.datetime(2020, 1, 1, 1, 1, 1)
-    task_instance.map_index = -1
-    task_instance.next_method = None  # Ensure this is None to reach start_task
-
-    return listener, task_instance
-
-
[email protected]("airflow.providers.openlineage.conf.debug_mode", return_value=True)
[email protected]("airflow.providers.openlineage.plugins.listener.is_operator_disabled")
[email protected]("airflow.providers.openlineage.plugins.listener.get_airflow_run_facet")
[email protected]("airflow.providers.openlineage.plugins.listener.get_airflow_mapped_task_facet")
[email protected]("airflow.providers.openlineage.plugins.listener.get_user_provided_run_facets")
[email protected]("airflow.providers.openlineage.plugins.listener.get_job_name")
[email protected]("airflow.providers.openlineage.plugins.listener.OpenLineageListener._execute",
 new=regular_call)
-def test_adapter_start_task_is_called_with_proper_arguments(
-    mock_get_job_name,
-    mock_get_airflow_mapped_task_facet,
-    mock_get_user_provided_run_facets,
-    mock_get_airflow_run_facet,
-    mock_disabled,
-    mock_debug_mode,
-):
-    """Tests that the 'start_task' method of the OpenLineageAdapter is invoked 
with the correct arguments.
-
-    The test checks that the job name, job description, event time, and other 
related data are
-    correctly passed to the adapter. It also verifies that custom facets and 
Airflow run facets are
-    correctly retrieved and included in the call. This ensures that all 
relevant data, including custom
-    and Airflow-specific metadata, is accurately conveyed to the adapter 
during the initialization of a task,
-    reflecting the comprehensive tracking of task execution contexts."""
-
-    listener, task_instance = _create_listener_and_task_instance()
-    mock_get_job_name.return_value = "job_name"
-    mock_get_airflow_mapped_task_facet.return_value = {"mapped_facet": 1}
-    mock_get_user_provided_run_facets.return_value = {"custom_user_facet": 2}
-    mock_get_airflow_run_facet.return_value = {"airflow_run_facet": 3}
-    mock_disabled.return_value = False
-
-    listener.on_task_instance_running(None, task_instance, None)
-    listener.adapter.start_task.assert_called_once_with(
-        run_id="2020-01-01T01:01:01.dag_id.task_id.1.-1",
-        job_name="job_name",
-        job_description="Test DAG Description",
-        event_time="2023-01-01T13:01:01",
-        parent_job_name="dag_id",
-        parent_run_id="2020-01-01T01:01:01.dag_id.0",
-        code_location=None,
-        nominal_start_time=None,
-        nominal_end_time=None,
-        owners=["Test Owner"],
-        task=listener.extractor_manager.extract_metadata(),
-        run_facets={
-            "mapped_facet": 1,
-            "custom_user_facet": 2,
-            "airflow_run_facet": 3,
-            "debug": AirflowDebugRunFacet(packages=ANY),
-        },
+        task_instance.task = mock.Mock()
+        task_instance.task.task_id = "task_id"
+        task_instance.task.dag = mock.Mock()
+        task_instance.task.dag.dag_id = "dag_id"
+        task_instance.task.dag.description = "Test DAG Description"
+        task_instance.task.dag.owner = "Test Owner"
+        task_instance.task.inlets = []
+        task_instance.task.outlets = []
+        task_instance.dag_id = "dag_id"
+        task_instance.run_id = "dag_run_run_id"
+        task_instance.try_number = 1
+        task_instance.state = State.RUNNING
+        task_instance.start_date = dt.datetime(2023, 1, 1, 13, 1, 1)
+        task_instance.end_date = dt.datetime(2023, 1, 3, 13, 1, 1)
+        task_instance.logical_date = dt.datetime(2020, 1, 1, 1, 1, 1)
+        task_instance.map_index = -1
+        task_instance.next_method = None  # Ensure this is None to reach 
start_task
+        task_instance.get_template_context = mock.MagicMock()  # type: 
ignore[method-assign]
+        task_instance.get_template_context.return_value = 
defaultdict(mock.MagicMock)
+        task_instance.get_template_context()["task_reschedule_count"] = 0
+
+        return listener, task_instance
+
+    @mock.patch("airflow.providers.openlineage.conf.debug_mode", 
return_value=True)
+    
@mock.patch("airflow.providers.openlineage.plugins.listener.is_operator_disabled")
+    
@mock.patch("airflow.providers.openlineage.plugins.listener.get_airflow_run_facet")
+    
@mock.patch("airflow.providers.openlineage.plugins.listener.get_airflow_mapped_task_facet")
+    
@mock.patch("airflow.providers.openlineage.plugins.listener.get_user_provided_run_facets")
+    @mock.patch("airflow.providers.openlineage.plugins.listener.get_job_name")
+    @mock.patch(
+        
"airflow.providers.openlineage.plugins.listener.OpenLineageListener._execute", 
new=regular_call
     )
+    def test_adapter_start_task_is_called_with_proper_arguments(
+        self,
+        mock_get_job_name,
+        mock_get_airflow_mapped_task_facet,
+        mock_get_user_provided_run_facets,
+        mock_get_airflow_run_facet,
+        mock_disabled,
+        mock_debug_mode,
+    ):
+        """Tests that the 'start_task' method of the OpenLineageAdapter is 
invoked with the correct arguments.
+
+        The test checks that the job name, job description, event time, and 
other related data are
+        correctly passed to the adapter. It also verifies that custom facets 
and Airflow run facets are
+        correctly retrieved and included in the call. This ensures that all 
relevant data, including custom
+        and Airflow-specific metadata, is accurately conveyed to the adapter 
during the initialization of a task,
+        reflecting the comprehensive tracking of task execution contexts."""
+
+        listener, task_instance = self._create_listener_and_task_instance()
+        mock_get_job_name.return_value = "job_name"
+        mock_get_airflow_mapped_task_facet.return_value = {"mapped_facet": 1}
+        mock_get_user_provided_run_facets.return_value = {"custom_user_facet": 
2}
+        mock_get_airflow_run_facet.return_value = {"airflow_run_facet": 3}
+        mock_disabled.return_value = False
+
+        listener.on_task_instance_running(None, task_instance, None)
+        listener.adapter.start_task.assert_called_once_with(
+            run_id="2020-01-01T01:01:01.dag_id.task_id.1.-1",
+            job_name="job_name",
+            job_description="Test DAG Description",
+            event_time="2023-01-01T13:01:01",
+            parent_job_name="dag_id",
+            parent_run_id="2020-01-01T01:01:01.dag_id.0",
+            code_location=None,
+            nominal_start_time=None,
+            nominal_end_time=None,
+            owners=["Test Owner"],
+            task=listener.extractor_manager.extract_metadata(),
+            run_facets={
+                "mapped_facet": 1,
+                "custom_user_facet": 2,
+                "airflow_run_facet": 3,
+                "debug": AirflowDebugRunFacet(packages=ANY),
+            },
+        )
 
+    @mock.patch("airflow.providers.openlineage.conf.debug_mode", 
return_value=True)
+    
@mock.patch("airflow.providers.openlineage.plugins.listener.is_operator_disabled")
+    
@mock.patch("airflow.providers.openlineage.plugins.listener.get_airflow_run_facet")
+    
@mock.patch("airflow.providers.openlineage.plugins.listener.get_user_provided_run_facets")
+    @mock.patch("airflow.providers.openlineage.plugins.listener.get_job_name")
+    @mock.patch(
+        
"airflow.providers.openlineage.plugins.listener.OpenLineageListener._execute", 
new=regular_call
+    )
+    @mock.patch("airflow.utils.timezone.utcnow", 
return_value=dt.datetime(2023, 1, 3, 13, 1, 1))
+    def test_adapter_fail_task_is_called_with_proper_arguments(
+        self,
+        mock_utcnow,
+        mock_get_job_name,
+        mock_get_user_provided_run_facets,
+        mock_get_airflow_run_facet,
+        mock_disabled,
+        mock_debug_mode,
+    ):
+        """Tests that the 'fail_task' method of the OpenLineageAdapter is 
invoked with the correct arguments.
+
+        This test ensures that the job name is accurately retrieved and 
included, along with the generated
+        run_id and task metadata. By mocking the job name retrieval and the 
run_id generation,
+        the test verifies the integrity and consistency of the data passed to 
the adapter during task
+        failure events, thus confirming that the adapter's failure handling is 
functioning as expected.
+        """
+
+        listener, task_instance = self._create_listener_and_task_instance()
+        task_instance.logical_date = dt.datetime(2020, 1, 1, 1, 1, 1)
+        mock_get_job_name.return_value = "job_name"
+        mock_get_user_provided_run_facets.return_value = {"custom_user_facet": 
2}
+        mock_get_airflow_run_facet.return_value = {"airflow": {"task": "..."}}
+        mock_disabled.return_value = False
+
+        err = ValueError("test")
+        on_task_failed_listener_kwargs = {"error": err} if AIRFLOW_V_2_10_PLUS 
else {}
+        expected_err_kwargs = {"error": err if AIRFLOW_V_2_10_PLUS else None}
+
+        listener.on_task_instance_failed(
+            previous_state=None, task_instance=task_instance, 
**on_task_failed_listener_kwargs, session=None
+        )
+        listener.adapter.fail_task.assert_called_once_with(
+            end_time="2023-01-03T13:01:01",
+            job_name="job_name",
+            parent_job_name="dag_id",
+            parent_run_id="2020-01-01T01:01:01.dag_id.0",
+            run_id="2020-01-01T01:01:01.dag_id.task_id.1.-1",
+            task=listener.extractor_manager.extract_metadata(),
+            run_facets={
+                "custom_user_facet": 2,
+                "airflow": {"task": "..."},
+                "debug": AirflowDebugRunFacet(packages=ANY),
+            },
+            **expected_err_kwargs,
+        )
 
[email protected]("airflow.providers.openlineage.conf.debug_mode", return_value=True)
[email protected]("airflow.providers.openlineage.plugins.listener.is_operator_disabled")
[email protected]("airflow.providers.openlineage.plugins.listener.get_airflow_run_facet")
[email protected]("airflow.providers.openlineage.plugins.listener.get_user_provided_run_facets")
[email protected]("airflow.providers.openlineage.plugins.listener.get_job_name")
[email protected]("airflow.providers.openlineage.plugins.listener.OpenLineageListener._execute",
 new=regular_call)
-def test_adapter_fail_task_is_called_with_proper_arguments(
-    mock_get_job_name,
-    mock_get_user_provided_run_facets,
-    mock_get_airflow_run_facet,
-    mock_disabled,
-    mock_debug_mode,
-):
-    """Tests that the 'fail_task' method of the OpenLineageAdapter is invoked 
with the correct arguments.
-
-    This test ensures that the job name is accurately retrieved and included, 
along with the generated
-    run_id and task metadata. By mocking the job name retrieval and the run_id 
generation,
-    the test verifies the integrity and consistency of the data passed to the 
adapter during task
-    failure events, thus confirming that the adapter's failure handling is 
functioning as expected.
-    """
-
-    listener, task_instance = _create_listener_and_task_instance()
-    task_instance.logical_date = dt.datetime(2020, 1, 1, 1, 1, 1)
-    mock_get_job_name.return_value = "job_name"
-    mock_get_user_provided_run_facets.return_value = {"custom_user_facet": 2}
-    mock_get_airflow_run_facet.return_value = {"airflow": {"task": "..."}}
-    mock_disabled.return_value = False
+    @mock.patch("airflow.providers.openlineage.conf.debug_mode", 
return_value=True)
+    @mock.patch("airflow.providers.openlineage.plugins.listener.is_operator_disabled")
+    @mock.patch("airflow.providers.openlineage.plugins.listener.get_airflow_run_facet")
+    @mock.patch("airflow.providers.openlineage.plugins.listener.get_user_provided_run_facets")
+    @mock.patch("airflow.providers.openlineage.plugins.listener.get_job_name")
+    @mock.patch(
+        "airflow.providers.openlineage.plugins.listener.OpenLineageListener._execute", new=regular_call
+    )
+    @mock.patch("airflow.utils.timezone.utcnow", return_value=dt.datetime(2023, 1, 3, 13, 1, 1))
+    def test_adapter_complete_task_is_called_with_proper_arguments(
+        self,
+        mock_utcnow,
+        mock_get_job_name,
+        mock_get_user_provided_run_facets,
+        mock_get_airflow_run_facet,
+        mock_disabled,
+        mock_debug_mode,
+    ):
+        """Tests that the 'complete_task' method of the OpenLineageAdapter is called with the correct arguments.
+
+        It checks that the job name is correctly retrieved and passed,
+        along with the run_id and task metadata. The test also simulates changes in the try_number
+        attribute of the task instance, as it would occur in Airflow, to ensure that the run_id is updated
+        accordingly. This helps confirm the consistency and correctness of the data passed to the adapter
+        during the task's lifecycle events.
+        """
+
+        listener, task_instance = self._create_listener_and_task_instance()
+        mock_get_job_name.return_value = "job_name"
+        mock_get_user_provided_run_facets.return_value = {"custom_user_facet": 2}
+        mock_get_airflow_run_facet.return_value = {"airflow": {"task": "..."}}
+        mock_disabled.return_value = False
+
+        listener.on_task_instance_success(None, task_instance, None)
+        # This run_id will be different as we did NOT simulate increase of the try_number attribute,
+        # which happens in Airflow < 2.10.
+        calls = listener.adapter.complete_task.call_args_list
+        assert len(calls) == 1
+        assert calls[0][1] == dict(
+            end_time="2023-01-03T13:01:01",
+            job_name="job_name",
+            parent_job_name="dag_id",
+            parent_run_id="2020-01-01T01:01:01.dag_id.0",
+            run_id=f"2020-01-01T01:01:01.dag_id.task_id.{EXPECTED_TRY_NUMBER_1}.-1",
+            task=listener.extractor_manager.extract_metadata(),
+            run_facets={
+                "custom_user_facet": 2,
+                "airflow": {"task": "..."},
+                "debug": AirflowDebugRunFacet(packages=ANY),
+            },
+        )
 
-    err = ValueError("test")
-    on_task_failed_listener_kwargs = {"error": err} if AIRFLOW_V_2_10_PLUS 
else {}
-    expected_err_kwargs = {"error": err if AIRFLOW_V_2_10_PLUS else None}
+    @mock.patch(
+        "airflow.providers.openlineage.plugins.listener.OpenLineageListener._execute", new=regular_call
+    )
+    def test_on_task_instance_running_correctly_calls_openlineage_adapter_run_id_method(self):
+        """Tests the OpenLineageListener's response when a task instance is in the running state.
+
+        This test ensures that when an Airflow task instance transitions to the running state,
+        the OpenLineageAdapter's `build_task_instance_run_id` method is called exactly once with the correct
+        parameters derived from the task instance.
+        """
+        listener, task_instance = self._create_listener_and_task_instance()
+        listener.on_task_instance_running(None, task_instance, None)
+        listener.adapter.build_task_instance_run_id.assert_called_once_with(
+            dag_id="dag_id",
+            task_id="task_id",
+            logical_date=dt.datetime(2020, 1, 1, 1, 1, 1),
+            try_number=1,
+            map_index=-1,
+        )
 
-    listener.on_task_instance_failed(
-        previous_state=None, task_instance=task_instance, session=None, 
**on_task_failed_listener_kwargs
+    @mock.patch(
+        "airflow.providers.openlineage.plugins.listener.OpenLineageListener._execute", new=regular_call
     )
-    listener.adapter.fail_task.assert_called_once_with(
-        end_time="2023-01-03T13:01:01",
-        job_name="job_name",
-        parent_job_name="dag_id",
-        parent_run_id="2020-01-01T01:01:01.dag_id.0",
-        run_id="2020-01-01T01:01:01.dag_id.task_id.1.-1",
-        task=listener.extractor_manager.extract_metadata(),
-        run_facets={
-            "custom_user_facet": 2,
-            "airflow": {"task": "..."},
-            "debug": AirflowDebugRunFacet(packages=ANY),
-        },
-        **expected_err_kwargs,
+    def test_on_task_instance_failed_correctly_calls_openlineage_adapter_run_id_method(self):
+        """Tests the OpenLineageListener's response when a task instance is in the failed state.
+
+        This test ensures that when an Airflow task instance transitions to the failed state,
+        the OpenLineageAdapter's `build_task_instance_run_id` method is called exactly once with the correct
+        parameters derived from the task instance.
+        """
+        listener, task_instance = self._create_listener_and_task_instance()
+        on_task_failed_kwargs = {"error": ValueError("test")} if AIRFLOW_V_2_10_PLUS else {}
+
+        listener.on_task_instance_failed(
+            previous_state=None, task_instance=task_instance, **on_task_failed_kwargs, session=None
+        )
+        listener.adapter.build_task_instance_run_id.assert_called_once_with(
+            dag_id="dag_id",
+            task_id="task_id",
+            logical_date=dt.datetime(2020, 1, 1, 1, 1, 1),
+            try_number=1,
+            map_index=-1,
+        )
+
+    @mock.patch(
+        "airflow.providers.openlineage.plugins.listener.OpenLineageListener._execute", new=regular_call
     )
+    def test_on_task_instance_success_correctly_calls_openlineage_adapter_run_id_method(self):
+        """Tests the OpenLineageListener's response when a task instance is in the success state.
+
+        This test ensures that when an Airflow task instance transitions to the success state,
+        the OpenLineageAdapter's `build_task_instance_run_id` method is called exactly once with the correct
+        parameters derived from the task instance.
+        """
+        listener, task_instance = self._create_listener_and_task_instance()
+        listener.on_task_instance_success(None, task_instance, None)
+        listener.adapter.build_task_instance_run_id.assert_called_once_with(
+            dag_id="dag_id",
+            task_id="task_id",
+            logical_date=dt.datetime(2020, 1, 1, 1, 1, 1),
+            try_number=EXPECTED_TRY_NUMBER_1,
+            map_index=-1,
+        )
 
+    @mock.patch("airflow.models.taskinstance.get_listener_manager")
+    def test_listener_on_task_instance_failed_is_called_before_try_number_increment(self, mock_listener):
+        """Validates the listener's on-failure method is called before try_number increment happens.
 
[email protected]("airflow.providers.openlineage.conf.debug_mode", return_value=True)
[email protected]("airflow.providers.openlineage.plugins.listener.is_operator_disabled")
[email protected]("airflow.providers.openlineage.plugins.listener.get_airflow_run_facet")
[email protected]("airflow.providers.openlineage.plugins.listener.get_user_provided_run_facets")
[email protected]("airflow.providers.openlineage.plugins.listener.get_job_name")
[email protected]("airflow.providers.openlineage.plugins.listener.OpenLineageListener._execute",
 new=regular_call)
-def test_adapter_complete_task_is_called_with_proper_arguments(
-    mock_get_job_name,
-    mock_get_user_provided_run_facets,
-    mock_get_airflow_run_facet,
-    mock_disabled,
-    mock_debug_mode,
-):
-    """Tests that the 'complete_task' method of the OpenLineageAdapter is 
called with the correct arguments.
-
-    It checks that the job name is correctly retrieved and passed,
-    along with the run_id and task metadata. The test also simulates changes 
in the try_number
-    attribute of the task instance, as it would occur in Airflow, to ensure 
that the run_id is updated
-    accordingly. This helps confirm the consistency and correctness of the 
data passed to the adapter
-    during the task's lifecycle events.
-    """
+        This test ensures that when a task instance fails, Airflow's listener 
method for
+        task failure (`on_task_instance_failed`) is invoked before the 
increment of the
+        `try_number` attribute happens. A custom exception simulates task 
failure, and the test
+        captures the `try_number` at the moment of this method call.
+        """
+        captured_try_numbers = {}
+        self._setup_mock_listener(mock_listener, captured_try_numbers)
 
-    listener, task_instance = _create_listener_and_task_instance()
-    mock_get_job_name.return_value = "job_name"
-    mock_get_user_provided_run_facets.return_value = {"custom_user_facet": 2}
-    mock_get_airflow_run_facet.return_value = {"airflow": {"task": "..."}}
-    mock_disabled.return_value = False
-
-    listener.on_task_instance_success(None, task_instance, None)
-    # This run_id will be different as we did NOT simulate increase of the 
try_number attribute,
-    # which happens in Airflow < 2.10.
-    calls = listener.adapter.complete_task.call_args_list
-    assert len(calls) == 1
-    assert calls[0][1] == dict(
-        end_time="2023-01-03T13:01:01",
-        job_name="job_name",
-        parent_job_name="dag_id",
-        parent_run_id="2020-01-01T01:01:01.dag_id.0",
-        
run_id=f"2020-01-01T01:01:01.dag_id.task_id.{EXPECTED_TRY_NUMBER_1}.-1",
-        task=listener.extractor_manager.extract_metadata(),
-        run_facets={
-            "custom_user_facet": 2,
-            "airflow": {"task": "..."},
-            "debug": AirflowDebugRunFacet(packages=ANY),
-        },
-    )
+        # Just to make sure no error interferes with the test, and we do not 
suppress it by accident
+        class CustomError(Exception):
+            pass
 
+        def fail_callable(**kwargs):
+            raise CustomError("Simulated task failure")
 
[email protected]("airflow.providers.openlineage.plugins.listener.OpenLineageListener._execute",
 new=regular_call)
-def 
test_on_task_instance_running_correctly_calls_openlineage_adapter_run_id_method():
-    """Tests the OpenLineageListener's response when a task instance is in the 
running state.
+        _, task_instance = self._create_test_dag_and_task(fail_callable, 
"failure")
+        # try_number before execution
+        assert task_instance.try_number == TRY_NUMBER_BEFORE_EXECUTION
+        with suppress(CustomError):
+            task_instance.run()
 
-    This test ensures that when an Airflow task instance transitions to the 
running state,
-    the OpenLineageAdapter's `build_task_instance_run_id` method is called 
exactly once with the correct
-    parameters derived from the task instance.
-    """
-    listener, task_instance = _create_listener_and_task_instance()
-    listener.on_task_instance_running(None, task_instance, None)
-    listener.adapter.build_task_instance_run_id.assert_called_once_with(
-        dag_id="dag_id",
-        task_id="task_id",
-        logical_date=dt.datetime(2020, 1, 1, 1, 1, 1),
-        try_number=1,
-        map_index=-1,
-    )
+        # try_number at the moment of function being called
+        assert captured_try_numbers["running"] == TRY_NUMBER_RUNNING
+        assert captured_try_numbers["failed"] == TRY_NUMBER_FAILED
 
+        # try_number after task has been executed
+        assert task_instance.try_number == TRY_NUMBER_AFTER_EXECUTION
 
[email protected]("airflow.providers.openlineage.plugins.listener.OpenLineageListener._execute",
 new=regular_call)
-def 
test_on_task_instance_failed_correctly_calls_openlineage_adapter_run_id_method():
-    """Tests the OpenLineageListener's response when a task instance is in the 
failed state.
+    @mock.patch("airflow.models.taskinstance.get_listener_manager")
+    def test_listener_on_task_instance_success_is_called_after_try_number_increment(self, mock_listener):
+        """Validates the listener's on-success method is called before try_number increment happens.
 
-    This test ensures that when an Airflow task instance transitions to the 
failed state,
-    the OpenLineageAdapter's `build_task_instance_run_id` method is called 
exactly once with the correct
-    parameters derived from the task instance.
-    """
-    listener, task_instance = _create_listener_and_task_instance()
-    on_task_failed_kwargs = {"error": ValueError("test")} if 
AIRFLOW_V_2_10_PLUS else {}
+        This test ensures that when a task instance successfully completes, the
+        `on_task_instance_success` method of Airflow's listener is called with 
an
+        incremented `try_number` compared to the `try_number` before execution.
+        The test simulates a successful task execution and captures the 
`try_number` at the method call.
+        """
+        captured_try_numbers = {}
+        self._setup_mock_listener(mock_listener, captured_try_numbers)
+
+        def success_callable(**kwargs):
+            return None
+
+        _, task_instance = self._create_test_dag_and_task(success_callable, 
"success")
+        # try_number before execution
+        assert task_instance.try_number == TRY_NUMBER_BEFORE_EXECUTION
+        task_instance.run()
+
+        # try_number at the moment of function being called
+        assert captured_try_numbers["running"] == TRY_NUMBER_RUNNING
+        assert captured_try_numbers["success"] == TRY_NUMBER_SUCCESS
+
+        # try_number after task has been executed
+        assert task_instance.try_number == TRY_NUMBER_AFTER_EXECUTION
+
+    @mock.patch("airflow.providers.openlineage.plugins.listener.is_operator_disabled")
+    @mock.patch("airflow.providers.openlineage.plugins.listener.get_airflow_run_facet")
+    @mock.patch("airflow.providers.openlineage.plugins.listener.get_user_provided_run_facets")
+    @mock.patch("airflow.providers.openlineage.plugins.listener.get_job_name")
+    def test_listener_on_task_instance_running_do_not_call_adapter_when_disabled_operator(
+        self, mock_get_job_name, mock_get_user_provided_run_facets, mock_get_airflow_run_facet, mock_disabled
+    ):
+        listener, task_instance = self._create_listener_and_task_instance()
+        mock_get_job_name.return_value = "job_name"
+        mock_get_user_provided_run_facets.return_value = {"custom_facet": 2}
+        mock_get_airflow_run_facet.return_value = {"airflow_run_facet": 3}
+        mock_disabled.return_value = True
+
+        listener.on_task_instance_running(None, task_instance, None)
+        mock_disabled.assert_called_once_with(task_instance.task)
+        listener.adapter.build_dag_run_id.assert_not_called()
+        listener.adapter.build_task_instance_run_id.assert_not_called()
+        listener.extractor_manager.extract_metadata.assert_not_called()
+        listener.adapter.start_task.assert_not_called()
+
+    @mock.patch("airflow.providers.openlineage.plugins.listener.is_operator_disabled")
+    @mock.patch("airflow.providers.openlineage.plugins.listener.get_user_provided_run_facets")
+    @mock.patch("airflow.providers.openlineage.plugins.listener.get_job_name")
+    def test_listener_on_task_instance_failed_do_not_call_adapter_when_disabled_operator(
+        self, mock_get_job_name, mock_get_user_provided_run_facets, mock_disabled
+    ):
+        listener, task_instance = self._create_listener_and_task_instance()
+        mock_get_user_provided_run_facets.return_value = {"custom_facet": 2}
+        mock_disabled.return_value = True
+
+        on_task_failed_kwargs = {"error": ValueError("test")} if AIRFLOW_V_2_10_PLUS else {}
 
-    listener.on_task_instance_failed(
-        previous_state=None, task_instance=task_instance, session=None, 
**on_task_failed_kwargs
+        listener.on_task_instance_failed(
+            previous_state=None, task_instance=task_instance, **on_task_failed_kwargs, session=None
+        )
+        mock_disabled.assert_called_once_with(task_instance.task)
+        listener.adapter.build_dag_run_id.assert_not_called()
+        listener.adapter.build_task_instance_run_id.assert_not_called()
+        listener.extractor_manager.extract_metadata.assert_not_called()
+        listener.adapter.fail_task.assert_not_called()
+
+    @mock.patch("airflow.providers.openlineage.plugins.listener.is_operator_disabled")
+    @mock.patch("airflow.providers.openlineage.plugins.listener.get_user_provided_run_facets")
+    @mock.patch("airflow.providers.openlineage.plugins.listener.get_job_name")
+    def test_listener_on_task_instance_success_do_not_call_adapter_when_disabled_operator(
+        self, mock_get_job_name, mock_get_user_provided_run_facets, mock_disabled
+    ):
+        listener, task_instance = self._create_listener_and_task_instance()
+        mock_get_user_provided_run_facets.return_value = {"custom_facet": 2}
+        mock_disabled.return_value = True
+
+        listener.on_task_instance_success(None, task_instance, None)
+        mock_disabled.assert_called_once_with(task_instance.task)
+        listener.adapter.build_dag_run_id.assert_not_called()
+        listener.adapter.build_task_instance_run_id.assert_not_called()
+        listener.extractor_manager.extract_metadata.assert_not_called()
+        listener.adapter.complete_task.assert_not_called()
+
+    @pytest.mark.parametrize(
+        "max_workers,expected",
+        [
+            (None, 1),
+            ("8", 8),
+        ],
     )
-    listener.adapter.build_task_instance_run_id.assert_called_once_with(
-        dag_id="dag_id",
-        task_id="task_id",
-        logical_date=dt.datetime(2020, 1, 1, 1, 1, 1),
-        try_number=1,
-        map_index=-1,
+    @mock.patch("airflow.providers.openlineage.plugins.listener.ProcessPoolExecutor", autospec=True)
+    def test_listener_on_dag_run_state_changes_configure_process_pool_size(
+        self, mock_executor, max_workers, expected
+    ):
+        """mock ProcessPoolExecutor and check if conf.dag_state_change_process_pool_size is applied to max_workers"""
+        listener = OpenLineageListener()
+        # mock ProcessPoolExecutor class
+        with conf_vars({("openlineage", "dag_state_change_process_pool_size"): max_workers}):
+            listener.on_dag_run_running(mock.MagicMock(), None)
+        mock_executor.assert_called_once_with(max_workers=expected, initializer=mock.ANY)
+        mock_executor.return_value.submit.assert_called_once()
+
+    @pytest.mark.parametrize(
+        ("method", "dag_run_state"),
+        [
+            ("on_dag_run_running", DagRunState.RUNNING),
+            ("on_dag_run_success", DagRunState.SUCCESS),
+            ("on_dag_run_failed", DagRunState.FAILED),
+        ],
     )
+    @patch("airflow.providers.openlineage.plugins.adapter.OpenLineageAdapter.emit")
+    def test_listener_on_dag_run_state_changes(self, mock_emit, method, dag_run_state, create_task_instance):
+        mock_executor = MockExecutor()
+        ti = create_task_instance(dag_id="dag", task_id="op")
+        # Change the state explicitly to set end_date following the logic in 
the method
+        ti.dag_run.set_state(dag_run_state)
+        with mock.patch(
+            "airflow.providers.openlineage.plugins.listener.ProcessPoolExecutor", return_value=mock_executor
+        ):
+            listener = OpenLineageListener()
+            getattr(listener, method)(ti.dag_run, None)
+            assert mock_executor.submitted is True
+            assert mock_executor.succeeded is True
+            mock_emit.assert_called_once()
 
+    def test_listener_logs_failed_serialization(self):
+        listener = OpenLineageListener()
+        callback_future = Future()
 
[email protected]("airflow.providers.openlineage.plugins.listener.OpenLineageListener._execute",
 new=regular_call)
-def 
test_on_task_instance_success_correctly_calls_openlineage_adapter_run_id_method():
-    """Tests the OpenLineageListener's response when a task instance is in the 
success state.
+        def set_result(*args, **kwargs):
+            callback_future.set_result(True)
 
-    This test ensures that when an Airflow task instance transitions to the 
success state,
-    the OpenLineageAdapter's `build_task_instance_run_id` method is called 
exactly once with the correct
-    parameters derived from the task instance.
-    """
-    listener, task_instance = _create_listener_and_task_instance()
-    listener.on_task_instance_success(None, task_instance, None)
-    listener.adapter.build_task_instance_run_id.assert_called_once_with(
-        dag_id="dag_id",
-        task_id="task_id",
-        logical_date=dt.datetime(2020, 1, 1, 1, 1, 1),
-        try_number=EXPECTED_TRY_NUMBER_1,
-        map_index=-1,
-    )
+        listener.log = MagicMock()
+        listener.log.warning = MagicMock(side_effect=set_result)
+        listener.adapter = OpenLineageAdapter(
+            
client=OpenLineageClient(transport=ConsoleTransport(config=ConsoleConfig()))
+        )
+        event_time = dt.datetime.now()
+        fut = listener.submit_callable(
+            listener.adapter.dag_failed,
+            dag_id="",
+            run_id="",
+            end_date=event_time,
+            logical_date=callback_future,
+            clear_number=0,
+            dag_run_state=DagRunState.FAILED,
+            task_ids=["task_id"],
+            msg="",
+        )
+        assert fut.exception(10)
+        callback_future.result(10)
+        assert callback_future.done()
+        listener.log.debug.assert_not_called()
+        listener.log.warning.assert_called_once()
 
 
[email protected]("airflow.models.taskinstance.get_listener_manager")
-def 
test_listener_on_task_instance_failed_is_called_before_try_number_increment(mock_listener):
-    """Validates the listener's on-failure method is called before try_number 
increment happens.
[email protected]
+def mock_supervisor_comms():
+    with mock.patch(
+        "airflow.sdk.execution_time.task_runner.SUPERVISOR_COMMS", create=True
+    ) as supervisor_comms:
+        yield supervisor_comms
+
+
[email protected](not AIRFLOW_V_3_0_PLUS, reason="Airflow 3 tests")
+class TestOpenLineageListenerAirflow3:
+    @pytest.mark.skip("Rendering fields is not migrated yet in Airflow 3")
+    @patch("airflow.models.BaseOperator.render_template")
+    def test_listener_does_not_change_task_instance(self, render_mock, mock_supervisor_comms, spy_agency):
+        from airflow.sdk.execution_time.task_runner import (
+            RuntimeTaskInstance,
+            TaskInstance as SdkTaskInstance,
+            run,
+        )
 
-    This test ensures that when a task instance fails, Airflow's listener 
method for
-    task failure (`on_task_instance_failed`) is invoked before the increment 
of the
-    `try_number` attribute happens. A custom exception simulates task failure, 
and the test
-    captures the `try_number` at the moment of this method call.
-    """
-    captured_try_numbers = {}
-    _setup_mock_listener(mock_listener, captured_try_numbers)
+        render_mock.return_value = render_df()
 
-    # Just to make sure no error interferes with the test, and we do not 
suppress it by accident
-    class CustomError(Exception):
-        pass
+        date = dt.datetime(2022, 1, 1)
+        dag = DAG(
+            "test",
+            schedule=None,
+            start_date=dt.datetime(2022, 1, 1),
+            user_defined_macros={"render_df": render_df},
+            params={"df": {"col": [1, 2]}},
+        )
+        task = TemplateOperator(task_id="template_op", dag=dag, do_xcom_push=True, df=dag.param("df"))
+        run_id = str(uuid.uuid1())
 
-    def fail_callable(**kwargs):
-        raise CustomError("Simulated task failure")
+        dagrun_kwargs = {
+            "dag_version": None,
+            "logical_date": date,
+            "triggered_by": types.DagRunTriggeredByType.TEST,
+        }
 
-    _, task_instance = _create_test_dag_and_task(fail_callable, "failure")
-    # try_number before execution
-    assert task_instance.try_number == TRY_NUMBER_BEFORE_EXECUTION
-    with suppress(CustomError):
-        task_instance.run()
+        dag.create_dagrun(
+            run_id=run_id,
+            data_interval=(date, date),
+            run_type=types.DagRunType.MANUAL,
+            state=DagRunState.QUEUED,
+            **dagrun_kwargs,
+        )
+        ti = SdkTaskInstance(
+            id=uuid7(),
+            task_id="template_op",
+            dag_id=dag.dag_id,
+            run_id=run_id,
+            try_number=1,
+            start_date=timezone.utcnow(),
+            map_index=-1,
+        )
 
-    # try_number at the moment of function being called
-    assert captured_try_numbers["running"] == TRY_NUMBER_RUNNING
-    assert captured_try_numbers["failed"] == TRY_NUMBER_FAILED
+        runtime_ti = RuntimeTaskInstance.model_construct(**ti.model_dump(exclude_unset=True), task=task)
 
-    # try_number after task has been executed
-    assert task_instance.try_number == TRY_NUMBER_AFTER_EXECUTION
+        spy_agency.spy_on(runtime_ti.xcom_push, call_original=False)
+        run(runtime_ti, None)
 
+        # check if task returns the same DataFrame
+        pd.testing.assert_frame_equal(runtime_ti.xcom_push.last_call.args[1], render_df())
 
[email protected]("airflow.models.taskinstance.get_listener_manager")
-def 
test_listener_on_task_instance_success_is_called_after_try_number_increment(mock_listener):
-    """Validates the listener's on-success method is called before try_number 
increment happens.
+        # check if render_template method always get the same unrendered field
+        assert not isinstance(runtime_ti.xcom_push.last_call.args[1], pd.DataFrame)
 
-    This test ensures that when a task instance successfully completes, the
-    `on_task_instance_success` method of Airflow's listener is called with an
-    incremented `try_number` compared to the `try_number` before execution.
-    The test simulates a successful task execution and captures the 
`try_number` at the method call.
-    """
-    captured_try_numbers = {}
-    _setup_mock_listener(mock_listener, captured_try_numbers)
-
-    def success_callable(**kwargs):
-        return None
-
-    _, task_instance = _create_test_dag_and_task(success_callable, "success")
-    # try_number before execution
-    assert task_instance.try_number == TRY_NUMBER_BEFORE_EXECUTION
-    task_instance.run()
-
-    # try_number at the moment of function being called
-    assert captured_try_numbers["running"] == TRY_NUMBER_RUNNING
-    assert captured_try_numbers["success"] == TRY_NUMBER_SUCCESS
-
-    # try_number after task has been executed
-    assert task_instance.try_number == TRY_NUMBER_AFTER_EXECUTION
-
-
[email protected]("airflow.providers.openlineage.plugins.listener.is_operator_disabled")
[email protected]("airflow.providers.openlineage.plugins.listener.get_airflow_run_facet")
[email protected]("airflow.providers.openlineage.plugins.listener.get_user_provided_run_facets")
[email protected]("airflow.providers.openlineage.plugins.listener.get_job_name")
-def 
test_listener_on_task_instance_running_do_not_call_adapter_when_disabled_operator(
-    mock_get_job_name, mock_get_user_provided_run_facets, 
mock_get_airflow_run_facet, mock_disabled
-):
-    listener, task_instance = _create_listener_and_task_instance()
-    mock_get_job_name.return_value = "job_name"
-    mock_get_user_provided_run_facets.return_value = {"custom_facet": 2}
-    mock_get_airflow_run_facet.return_value = {"airflow_run_facet": 3}
-    mock_disabled.return_value = True
-
-    listener.on_task_instance_running(None, task_instance, None)
-    mock_disabled.assert_called_once_with(task_instance.task)
-    listener.adapter.build_dag_run_id.assert_not_called()
-    listener.adapter.build_task_instance_run_id.assert_not_called()
-    listener.extractor_manager.extract_metadata.assert_not_called()
-    listener.adapter.start_task.assert_not_called()
-
-
[email protected]("airflow.providers.openlineage.plugins.listener.is_operator_disabled")
[email protected]("airflow.providers.openlineage.plugins.listener.get_user_provided_run_facets")
[email protected]("airflow.providers.openlineage.plugins.listener.get_job_name")
-def 
test_listener_on_task_instance_failed_do_not_call_adapter_when_disabled_operator(
-    mock_get_job_name, mock_get_user_provided_run_facets, mock_disabled
-):
-    listener, task_instance = _create_listener_and_task_instance()
-    mock_get_user_provided_run_facets.return_value = {"custom_facet": 2}
-    mock_disabled.return_value = True
-
-    on_task_failed_kwargs = {"error": ValueError("test")} if 
AIRFLOW_V_2_10_PLUS else {}
-
-    listener.on_task_instance_failed(
-        previous_state=None, task_instance=task_instance, session=None, 
**on_task_failed_kwargs
+    def _setup_mock_listener(self, mock_listener: mock.Mock, captured_try_numbers: dict[str, int]) -> None:
+        """Sets up the mock listener with side effects to capture try numbers 
for different task instance events.
+
+        :param mock_listener: The mock object for the listener manager.
+        :param captured_try_numbers: A dictionary to store captured try 
numbers keyed by event names.
+
+        This function iterates through specified event names and sets a side 
effect on the corresponding
+        method of the listener manager's hook. The side effect is a nested 
function that captures the try number
+        of the task instance when the method is called.
+
+        :Example:
+
+            captured_try_numbers = {}
+            mock_listener = Mock()
+            _setup_mock_listener(mock_listener, captured_try_numbers)
+            # After running a task, captured_try_numbers will have the try 
number captured at the moment of
execution for specified methods, e.g. {"running": 1, "success": 2} 
for on_task_instance_running and
+            on_task_instance_success methods.
+        """
+
+        def capture_try_number(method_name):
+            def inner(*args, **kwargs):
+                captured_try_numbers[method_name] = 
kwargs["task_instance"].try_number
+
+            return inner
+
+        for event in ["running", "success", "failed"]:
+            getattr(
+                mock_listener.return_value.hook, f"on_task_instance_{event}"
+            ).side_effect = capture_try_number(event)
+
+    def _create_test_dag_and_task(
+        self, python_callable: Callable, scenario_name: str
+    ) -> tuple[DagRun, TaskInstance]:
+        """Creates a test DAG and a task for a custom test scenario.
+
+        :param python_callable: The Python callable to be executed by the 
PythonOperator.
+        :param scenario_name: The name of the test scenario, used to uniquely 
name the DAG and task.
+
+        :return: A tuple of the created DagRun and TaskInstance objects.
+
+        This function creates a DAG and a PythonOperator task with the 
provided python_callable. It generates a unique
+        run ID and creates a DAG run. This setup is useful for testing 
different scenarios in Airflow tasks.
+
+        :Example:
+
+            def sample_callable(**kwargs):
+                print("Hello World")
+
+            _, task_instance = _create_test_dag_and_task(sample_callable, 
"sample_scenario")
+            # Use task_instance to simulate running a task in a test.
+        """
+        dag = DAG(
+            f"test_{scenario_name}",
+            schedule=None,
+            start_date=dt.datetime(2022, 1, 1),
+        )
+        t = PythonOperator(task_id=f"test_task_{scenario_name}", dag=dag, 
python_callable=python_callable)
+        run_id = str(uuid.uuid1())
+        triggered_by_kwargs = {"triggered_by": DagRunTriggeredByType.TEST}
+        dagrun = dag.create_dagrun(
+            state=State.NONE,  # type: ignore
+            run_id=run_id,
+            **triggered_by_kwargs,  # type: ignore
+        )
+        task_instance = TaskInstance(t, run_id=run_id)
+        return dagrun, task_instance
+
+    def _create_listener_and_task_instance(self) -> tuple[OpenLineageListener, 
RuntimeTaskInstance]:
+        """Creates and configures an OpenLineageListener instance and a mock 
TaskInstance for testing.
+
+        :return: A tuple containing the configured OpenLineageListener and 
TaskInstance.
+
+        This function instantiates an OpenLineageListener, sets up its 
required properties with mock objects, and
+        creates a mock TaskInstance with predefined attributes. This setup is 
commonly used for testing the
+        interaction between an OpenLineageListener and a TaskInstance in 
Airflow.
+
+        :Example:
+
+            listener, task_instance = _create_listener_and_task_instance()
+            # Now you can use listener and task_instance in your tests to 
simulate their interaction.
+        """
+
+        from airflow.sdk.api.datamodels._generated import (
+            DagRun as SdkDagRun,
+            DagRunType,
+            TaskInstance as SdkTaskInstance,
+            TIRunContext,
+        )
+        from airflow.sdk.definitions.dag import DAG
+        from airflow.sdk.execution_time.task_runner import RuntimeTaskInstance
+
+        def mock_dag_id(dag_id, logical_date, clear_number):
+            return f"{logical_date.isoformat()}.{dag_id}.{clear_number}"
+
+        def mock_task_id(dag_id, task_id, try_number, logical_date, map_index):
+            return 
f"{logical_date.isoformat()}.{dag_id}.{task_id}.{try_number}.{map_index}"
+
+        listener = OpenLineageListener()
+        listener.extractor_manager = mock.Mock()
+
+        metadata = mock.Mock()
+        metadata.run_facets = {"run_facet": 1}
+        listener.extractor_manager.extract_metadata.return_value = metadata
+
+        adapter = mock.Mock()
+        adapter.build_dag_run_id.side_effect = mock_dag_id
+        adapter.build_task_instance_run_id.side_effect = mock_task_id
+        adapter.start_task = mock.Mock()
+        adapter.fail_task = mock.Mock()
+        adapter.complete_task = mock.Mock()
+        listener.adapter = adapter
+
+        dag = DAG(
+            dag_id="dag_id",
+            description="Test DAG Description",
+        )
+        task = EmptyOperator(task_id="task_id", dag=dag, owner="Test Owner")
+
+        task_instance = SdkTaskInstance(
+            id=uuid7(),
+            task_id="task_id",
+            dag_id="dag_id",
+            run_id="dag_run_run_id",
+            try_number=1,
+            start_date=dt.datetime(2023, 1, 1, 13, 1, 1),
+            map_index=-1,
+        )
+        runtime_ti = RuntimeTaskInstance.model_construct(
+            **task_instance.model_dump(exclude_unset=True),
+            task=task,
+            _ti_context_from_server=TIRunContext(
+                dag_run=SdkDagRun(
+                    dag_id="dag_id",
+                    run_id="dag_run_run_id",
+                    logical_date=dt.datetime(2020, 1, 1, 1, 1, 1),
+                    data_interval_start=None,
+                    data_interval_end=None,
+                    start_date=dt.datetime(2023, 1, 1, 13, 1, 1),
+                    end_date=dt.datetime(2023, 1, 3, 13, 1, 1),
+                    clear_number=0,
+                    run_type=DagRunType.MANUAL,
+                    conf=None,
+                ),
+                task_reschedule_count=0,
+            ),
+        )
+
+        return listener, runtime_ti
+
+    @mock.patch("airflow.providers.openlineage.conf.debug_mode", 
return_value=True)
+    
@mock.patch("airflow.providers.openlineage.plugins.listener.is_operator_disabled")
+    
@mock.patch("airflow.providers.openlineage.plugins.listener.get_airflow_run_facet")
+    
@mock.patch("airflow.providers.openlineage.plugins.listener.get_airflow_mapped_task_facet")
+    
@mock.patch("airflow.providers.openlineage.plugins.listener.get_user_provided_run_facets")
+    @mock.patch("airflow.providers.openlineage.plugins.listener.get_job_name")
+    @mock.patch(
+        
"airflow.providers.openlineage.plugins.listener.OpenLineageListener._execute", 
new=regular_call
     )
-    mock_disabled.assert_called_once_with(task_instance.task)
-    listener.adapter.build_dag_run_id.assert_not_called()
-    listener.adapter.build_task_instance_run_id.assert_not_called()
-    listener.extractor_manager.extract_metadata.assert_not_called()
-    listener.adapter.fail_task.assert_not_called()
-
-
[email protected]("airflow.providers.openlineage.plugins.listener.is_operator_disabled")
[email protected]("airflow.providers.openlineage.plugins.listener.get_user_provided_run_facets")
[email protected]("airflow.providers.openlineage.plugins.listener.get_job_name")
-def 
test_listener_on_task_instance_success_do_not_call_adapter_when_disabled_operator(
-    mock_get_job_name, mock_get_user_provided_run_facets, mock_disabled
-):
-    listener, task_instance = _create_listener_and_task_instance()
-    mock_get_user_provided_run_facets.return_value = {"custom_facet": 2}
-    mock_disabled.return_value = True
-
-    listener.on_task_instance_success(None, task_instance, None)
-    mock_disabled.assert_called_once_with(task_instance.task)
-    listener.adapter.build_dag_run_id.assert_not_called()
-    listener.adapter.build_task_instance_run_id.assert_not_called()
-    listener.extractor_manager.extract_metadata.assert_not_called()
-    listener.adapter.complete_task.assert_not_called()
-
-
[email protected](
-    "max_workers,expected",
-    [
-        (None, 1),
-        ("8", 8),
-    ],
-)
[email protected]("airflow.providers.openlineage.plugins.listener.ProcessPoolExecutor",
 autospec=True)
-def 
test_listener_on_dag_run_state_changes_configure_process_pool_size(mock_executor,
 max_workers, expected):
-    """mock ProcessPoolExecutor and check if 
conf.dag_state_change_process_pool_size is applied to max_workers"""
-    listener = OpenLineageListener()
-    # mock ProcessPoolExecutor class
-    with conf_vars({("openlineage", "dag_state_change_process_pool_size"): 
max_workers}):
-        listener.on_dag_run_running(mock.MagicMock(), None)
-    mock_executor.assert_called_once_with(max_workers=expected, 
initializer=mock.ANY)
-    mock_executor.return_value.submit.assert_called_once()
+    def test_adapter_start_task_is_called_with_proper_arguments(
+        self,
+        mock_get_job_name,
+        mock_get_airflow_mapped_task_facet,
+        mock_get_user_provided_run_facets,
+        mock_get_airflow_run_facet,
+        mock_disabled,
+        mock_debug_mode,
+    ):
+        """Tests that the 'start_task' method of the OpenLineageAdapter is 
invoked with the correct arguments.
+
+        The test checks that the job name, job description, event time, and 
other related data are
+        correctly passed to the adapter. It also verifies that custom facets 
and Airflow run facets are
+        correctly retrieved and included in the call. This ensures that all 
relevant data, including custom
+        and Airflow-specific metadata, is accurately conveyed to the adapter 
during the initialization of a task,
+        reflecting the comprehensive tracking of task execution contexts."""
+
+        listener, task_instance = self._create_listener_and_task_instance()
+        mock_get_job_name.return_value = "job_name"
+        mock_get_airflow_mapped_task_facet.return_value = {"mapped_facet": 1}
+        mock_get_user_provided_run_facets.return_value = {"custom_user_facet": 
2}
+        mock_get_airflow_run_facet.return_value = {"airflow_run_facet": 3}
+        mock_disabled.return_value = False
+
+        listener.on_task_instance_running(None, task_instance)
+        listener.adapter.start_task.assert_called_once_with(
+            run_id="2020-01-01T01:01:01.dag_id.task_id.1.-1",
+            job_name="job_name",
+            job_description="Test DAG Description",
+            event_time="2023-01-01T13:01:01",
+            parent_job_name="dag_id",
+            parent_run_id="2020-01-01T01:01:01.dag_id.0",
+            code_location=None,
+            nominal_start_time=None,
+            nominal_end_time=None,
+            owners=["Test Owner"],
+            task=listener.extractor_manager.extract_metadata(),
+            run_facets={
+                "mapped_facet": 1,
+                "custom_user_facet": 2,
+                "airflow_run_facet": 3,
+                "debug": AirflowDebugRunFacet(packages=ANY),
+            },
+        )
 
+    @mock.patch("airflow.providers.openlineage.conf.debug_mode", 
return_value=True)
+    
@mock.patch("airflow.providers.openlineage.plugins.listener.is_operator_disabled")
+    
@mock.patch("airflow.providers.openlineage.plugins.listener.get_airflow_run_facet")
+    
@mock.patch("airflow.providers.openlineage.plugins.listener.get_user_provided_run_facets")
+    @mock.patch("airflow.providers.openlineage.plugins.listener.get_job_name")
+    @mock.patch(
+        
"airflow.providers.openlineage.plugins.listener.OpenLineageListener._execute", 
new=regular_call
+    )
+    @mock.patch("airflow.utils.timezone.utcnow", 
return_value=dt.datetime(2023, 1, 3, 13, 1, 1))
+    def test_adapter_fail_task_is_called_with_proper_arguments(
+        self,
+        mock_utcnow,
+        mock_get_job_name,
+        mock_get_user_provided_run_facets,
+        mock_get_airflow_run_facet,
+        mock_disabled,
+        mock_debug_mode,
+    ):
+        """Tests that the 'fail_task' method of the OpenLineageAdapter is 
invoked with the correct arguments.
+
+        This test ensures that the job name is accurately retrieved and 
included, along with the generated
+        run_id and task metadata. By mocking the job name retrieval and the 
run_id generation,
+        the test verifies the integrity and consistency of the data passed to 
the adapter during task
+        failure events, thus confirming that the adapter's failure handling is 
functioning as expected.
+        """
+
+        listener, task_instance = self._create_listener_and_task_instance()
+        task_instance.get_template_context()["dag_run"].logical_date = 
dt.datetime(2020, 1, 1, 1, 1, 1)
+        mock_get_job_name.return_value = "job_name"
+        mock_get_user_provided_run_facets.return_value = {"custom_user_facet": 
2}
+        mock_get_airflow_run_facet.return_value = {"airflow": {"task": "..."}}
+        mock_disabled.return_value = False
+
+        err = ValueError("test")
+        on_task_failed_listener_kwargs = {"error": err} if AIRFLOW_V_2_10_PLUS 
else {}
+        expected_err_kwargs = {"error": err if AIRFLOW_V_2_10_PLUS else None}
+
+        listener.on_task_instance_failed(
+            previous_state=None, task_instance=task_instance, 
**on_task_failed_listener_kwargs
+        )
+        listener.adapter.fail_task.assert_called_once_with(
+            end_time="2023-01-03T13:01:01",
+            job_name="job_name",
+            parent_job_name="dag_id",
+            parent_run_id="2020-01-01T01:01:01.dag_id.0",
+            run_id="2020-01-01T01:01:01.dag_id.task_id.1.-1",
+            task=listener.extractor_manager.extract_metadata(),
+            run_facets={
+                "custom_user_facet": 2,
+                "airflow": {"task": "..."},
+                "debug": AirflowDebugRunFacet(packages=ANY),
+            },
+            **expected_err_kwargs,
+        )
 
-class MockExecutor:
-    def __init__(self, *args, **kwargs):
-        self.submitted = False
-        self.succeeded = False
-        self.result = None
+    @mock.patch("airflow.providers.openlineage.conf.debug_mode", 
return_value=True)
+    
@mock.patch("airflow.providers.openlineage.plugins.listener.is_operator_disabled")
+    
@mock.patch("airflow.providers.openlineage.plugins.listener.get_airflow_run_facet")
+    
@mock.patch("airflow.providers.openlineage.plugins.listener.get_user_provided_run_facets")
+    @mock.patch("airflow.providers.openlineage.plugins.listener.get_job_name")
+    @mock.patch(
+        
"airflow.providers.openlineage.plugins.listener.OpenLineageListener._execute", 
new=regular_call
+    )
+    @mock.patch("airflow.utils.timezone.utcnow", 
return_value=dt.datetime(2023, 1, 3, 13, 1, 1))
+    def test_adapter_complete_task_is_called_with_proper_arguments(
+        self,
+        mock_utcnow,
+        mock_get_job_name,
+        mock_get_user_provided_run_facets,
+        mock_get_airflow_run_facet,
+        mock_disabled,
+        mock_debug_mode,
+    ):
+        """Tests that the 'complete_task' method of the OpenLineageAdapter is 
called with the correct arguments.
+
+        It checks that the job name is correctly retrieved and passed,
+        along with the run_id and task metadata. The test also simulates 
changes in the try_number
+        attribute of the task instance, as it would occur in Airflow, to 
ensure that the run_id is updated
+        accordingly. This helps confirm the consistency and correctness of the 
data passed to the adapter
+        during the task's lifecycle events.
+        """
+
+        listener, task_instance = self._create_listener_and_task_instance()
+        mock_get_job_name.return_value = "job_name"
+        mock_get_user_provided_run_facets.return_value = {"custom_user_facet": 
2}
+        mock_get_airflow_run_facet.return_value = {"airflow": {"task": "..."}}
+        mock_disabled.return_value = False
+
+        listener.on_task_instance_success(None, task_instance)
+        # This run_id will be different as we did NOT simulate increase of the 
try_number attribute,
+        # which happens in Airflow < 2.10.
+        calls = listener.adapter.complete_task.call_args_list
+        assert len(calls) == 1
+        assert calls[0][1] == dict(
+            end_time="2023-01-03T13:01:01",
+            job_name="job_name",
+            parent_job_name="dag_id",
+            parent_run_id="2020-01-01T01:01:01.dag_id.0",
+            
run_id=f"2020-01-01T01:01:01.dag_id.task_id.{EXPECTED_TRY_NUMBER_1}.-1",
+            task=listener.extractor_manager.extract_metadata(),
+            run_facets={
+                "custom_user_facet": 2,
+                "airflow": {"task": "..."},
+                "debug": AirflowDebugRunFacet(packages=ANY),
+            },
+        )
 
-    def submit(self, fn, /, *args, **kwargs):
-        self.submitted = True
-        try:
-            fn(*args, **kwargs)
-            self.succeeded = True
-        except Exception:
+    @mock.patch(
+        
"airflow.providers.openlineage.plugins.listener.OpenLineageListener._execute", 
new=regular_call
+    )
+    def 
test_on_task_instance_running_correctly_calls_openlineage_adapter_run_id_method(self):
+        """Tests the OpenLineageListener's response when a task instance is in 
the running state.
+
+        This test ensures that when an Airflow task instance transitions to 
the running state,
+        the OpenLineageAdapter's `build_task_instance_run_id` method is called 
exactly once with the correct
+        parameters derived from the task instance.
+        """
+        listener, task_instance = self._create_listener_and_task_instance()
+        listener.on_task_instance_running(None, task_instance)
+        listener.adapter.build_task_instance_run_id.assert_called_once_with(
+            dag_id="dag_id",
+            task_id="task_id",
+            logical_date=dt.datetime(2020, 1, 1, 1, 1, 1),
+            try_number=1,
+            map_index=-1,
+        )
+
+    @mock.patch(
+        
"airflow.providers.openlineage.plugins.listener.OpenLineageListener._execute", 
new=regular_call
+    )
+    def 
test_on_task_instance_failed_correctly_calls_openlineage_adapter_run_id_method(self):
+        """Tests the OpenLineageListener's response when a task instance is in 
the failed state.
+
+        This test ensures that when an Airflow task instance transitions to 
the failed state,
+        the OpenLineageAdapter's `build_task_instance_run_id` method is called 
exactly once with the correct
+        parameters derived from the task instance.
+        """
+        listener, task_instance = self._create_listener_and_task_instance()
+        on_task_failed_kwargs = {"error": ValueError("test")} if 
AIRFLOW_V_2_10_PLUS else {}
+
+        listener.on_task_instance_failed(
+            previous_state=None, task_instance=task_instance, 
**on_task_failed_kwargs
+        )
+        listener.adapter.build_task_instance_run_id.assert_called_once_with(
+            dag_id="dag_id",
+            task_id="task_id",
+            logical_date=dt.datetime(2020, 1, 1, 1, 1, 1),
+            try_number=1,
+            map_index=-1,
+        )
+
+    @mock.patch(
+        
"airflow.providers.openlineage.plugins.listener.OpenLineageListener._execute", 
new=regular_call
+    )
+    def 
test_on_task_instance_success_correctly_calls_openlineage_adapter_run_id_method(self):
+        """Tests the OpenLineageListener's response when a task instance is in 
the success state.
+
+        This test ensures that when an Airflow task instance transitions to 
the success state,
+        the OpenLineageAdapter's `build_task_instance_run_id` method is called 
exactly once with the correct
+        parameters derived from the task instance.
+        """
+        listener, task_instance = self._create_listener_and_task_instance()
+        listener.on_task_instance_success(None, task_instance)
+        listener.adapter.build_task_instance_run_id.assert_called_once_with(
+            dag_id="dag_id",
+            task_id="task_id",
+            logical_date=dt.datetime(2020, 1, 1, 1, 1, 1),
+            try_number=EXPECTED_TRY_NUMBER_1,
+            map_index=-1,
+        )
+
+    @mock.patch("airflow.models.taskinstance.get_listener_manager")
+    def 
test_listener_on_task_instance_failed_is_called_before_try_number_increment(self,
 mock_listener):
+        """Validates the listener's on-failure method is called before 
try_number increment happens.
+
+        This test ensures that when a task instance fails, Airflow's listener 
method for
+        task failure (`on_task_instance_failed`) is invoked before the 
increment of the
+        `try_number` attribute happens. A custom exception simulates task 
failure, and the test
+        captures the `try_number` at the moment of this method call.
+        """
+        captured_try_numbers = {}
+        self._setup_mock_listener(mock_listener, captured_try_numbers)
+
+        # Just to make sure no error interferes with the test, and we do not 
suppress it by accident
+        class CustomError(Exception):
             pass
-        return MagicMock()
 
-    def shutdown(self, *args, **kwargs):
-        print("Shutting down")
+        def fail_callable(**kwargs):
+            raise CustomError("Simulated task failure")
 
+        _, task_instance = self._create_test_dag_and_task(fail_callable, 
"failure")
+        # try_number before execution
+        assert task_instance.try_number == TRY_NUMBER_BEFORE_EXECUTION
+        with suppress(CustomError):
+            task_instance.run()
 
[email protected](
-    ("method", "dag_run_state"),
-    [
-        ("on_dag_run_running", DagRunState.RUNNING),
-        ("on_dag_run_success", DagRunState.SUCCESS),
-        ("on_dag_run_failed", DagRunState.FAILED),
-    ],
-)
-@patch("airflow.providers.openlineage.plugins.adapter.OpenLineageAdapter.emit")
-def test_listener_on_dag_run_state_changes(mock_emit, method, dag_run_state, 
create_task_instance):
-    mock_executor = MockExecutor()
-    ti = create_task_instance(dag_id="dag", task_id="op")
-    # Change the state explicitly to set end_date following the logic in the 
method
-    ti.dag_run.set_state(dag_run_state)
-    with mock.patch(
-        "airflow.providers.openlineage.plugins.listener.ProcessPoolExecutor", 
return_value=mock_executor
+        # try_number at the moment of function being called
+        assert captured_try_numbers["running"] == TRY_NUMBER_RUNNING
+        assert captured_try_numbers["failed"] == TRY_NUMBER_FAILED
+
+        # try_number after task has been executed
+        assert task_instance.try_number == TRY_NUMBER_AFTER_EXECUTION
+
+    @mock.patch("airflow.models.taskinstance.get_listener_manager")
+    def 
test_listener_on_task_instance_success_is_called_after_try_number_increment(self,
 mock_listener):
+        """Validates the listener's on-success method is called after 
try_number increment happens.
+
+        This test ensures that when a task instance successfully completes, the
+        `on_task_instance_success` method of Airflow's listener is called with 
an
+        incremented `try_number` compared to the `try_number` before execution.
+        The test simulates a successful task execution and captures the 
`try_number` at the method call.
+        """
+        captured_try_numbers = {}
+        self._setup_mock_listener(mock_listener, captured_try_numbers)
+
+        def success_callable(**kwargs):
+            return None
+
+        _, task_instance = self._create_test_dag_and_task(success_callable, 
"success")
+        # try_number before execution
+        assert task_instance.try_number == TRY_NUMBER_BEFORE_EXECUTION
+        task_instance.run()
+
+        # try_number at the moment of function being called
+        assert captured_try_numbers["running"] == TRY_NUMBER_RUNNING
+        assert captured_try_numbers["success"] == TRY_NUMBER_SUCCESS
+
+        # try_number after task has been executed
+        assert task_instance.try_number == TRY_NUMBER_AFTER_EXECUTION
+
+    
@mock.patch("airflow.providers.openlineage.plugins.listener.is_operator_disabled")
+    
@mock.patch("airflow.providers.openlineage.plugins.listener.get_airflow_run_facet")
+    
@mock.patch("airflow.providers.openlineage.plugins.listener.get_user_provided_run_facets")
+    @mock.patch("airflow.providers.openlineage.plugins.listener.get_job_name")
+    def 
test_listener_on_task_instance_running_do_not_call_adapter_when_disabled_operator(
+        self, mock_get_job_name, mock_get_user_provided_run_facets, 
mock_get_airflow_run_facet, mock_disabled
     ):
-        listener = OpenLineageListener()
-        getattr(listener, method)(ti.dag_run, None)
-        assert mock_executor.submitted is True
-        assert mock_executor.succeeded is True
-        mock_emit.assert_called_once()
+        listener, task_instance = self._create_listener_and_task_instance()
+        mock_get_job_name.return_value = "job_name"
+        mock_get_user_provided_run_facets.return_value = {"custom_facet": 2}
+        mock_get_airflow_run_facet.return_value = {"airflow_run_facet": 3}
+        mock_disabled.return_value = True
+
+        listener.on_task_instance_running(None, task_instance)
+        mock_disabled.assert_called_once_with(task_instance.task)
+        listener.adapter.build_dag_run_id.assert_not_called()
+        listener.adapter.build_task_instance_run_id.assert_not_called()
+        listener.extractor_manager.extract_metadata.assert_not_called()
+        listener.adapter.start_task.assert_not_called()
+
+    
@mock.patch("airflow.providers.openlineage.plugins.listener.is_operator_disabled")
+    
@mock.patch("airflow.providers.openlineage.plugins.listener.get_user_provided_run_facets")
+    @mock.patch("airflow.providers.openlineage.plugins.listener.get_job_name")
+    def 
test_listener_on_task_instance_failed_do_not_call_adapter_when_disabled_operator(
+        self, mock_get_job_name, mock_get_user_provided_run_facets, 
mock_disabled
+    ):
+        listener, task_instance = self._create_listener_and_task_instance()
+        mock_get_user_provided_run_facets.return_value = {"custom_facet": 2}
+        mock_disabled.return_value = True
 
+        on_task_failed_kwargs = {"error": ValueError("test")} if 
AIRFLOW_V_2_10_PLUS else {}
 
-def test_listener_logs_failed_serialization():
-    listener = OpenLineageListener()
-    callback_future = Future()
+        listener.on_task_instance_failed(
+            previous_state=None, task_instance=task_instance, 
**on_task_failed_kwargs
+        )
+        mock_disabled.assert_called_once_with(task_instance.task)
+        listener.adapter.build_dag_run_id.assert_not_called()
+        listener.adapter.build_task_instance_run_id.assert_not_called()
+        listener.extractor_manager.extract_metadata.assert_not_called()
+        listener.adapter.fail_task.assert_not_called()
+
+    
@mock.patch("airflow.providers.openlineage.plugins.listener.is_operator_disabled")
+    
@mock.patch("airflow.providers.openlineage.plugins.listener.get_user_provided_run_facets")
+    @mock.patch("airflow.providers.openlineage.plugins.listener.get_job_name")
+    def 
test_listener_on_task_instance_success_do_not_call_adapter_when_disabled_operator(
+        self, mock_get_job_name, mock_get_user_provided_run_facets, 
mock_disabled
+    ):
+        listener, task_instance = self._create_listener_and_task_instance()
+        mock_get_user_provided_run_facets.return_value = {"custom_facet": 2}
+        mock_disabled.return_value = True
 
-    def set_result(*args, **kwargs):
-        callback_future.set_result(True)
+        listener.on_task_instance_success(None, task_instance)
+        mock_disabled.assert_called_once_with(task_instance.task)
+        listener.adapter.build_dag_run_id.assert_not_called()
+        listener.adapter.build_task_instance_run_id.assert_not_called()
+        listener.extractor_manager.extract_metadata.assert_not_called()
+        listener.adapter.complete_task.assert_not_called()
 
-    listener.log = MagicMock()
-    listener.log.warning = MagicMock(side_effect=set_result)
-    listener.adapter = OpenLineageAdapter(
-        
client=OpenLineageClient(transport=ConsoleTransport(config=ConsoleConfig()))
+    @pytest.mark.parametrize(
+        "max_workers,expected",
+        [
+            (None, 1),
+            ("8", 8),
+        ],
     )
-    event_time = dt.datetime.now()
-    fut = listener.submit_callable(
-        listener.adapter.dag_failed,
-        dag_id="",
-        run_id="",
-        end_date=event_time,
-        logical_date=callback_future,
-        clear_number=0,
-        dag_run_state=DagRunState.FAILED,
-        task_ids=["task_id"],
-        msg="",
+    
@mock.patch("airflow.providers.openlineage.plugins.listener.ProcessPoolExecutor",
 autospec=True)
+    def test_listener_on_dag_run_state_changes_configure_process_pool_size(
+        self, mock_executor, max_workers, expected
+    ):
+        """mock ProcessPoolExecutor and check if 
conf.dag_state_change_process_pool_size is applied to max_workers"""
+        listener = OpenLineageListener()
+        # mock ProcessPoolExecutor class
+        with conf_vars({("openlineage", "dag_state_change_process_pool_size"): 
max_workers}):
+            listener.on_dag_run_running(mock.MagicMock(), None)
+        mock_executor.assert_called_once_with(max_workers=expected, 
initializer=mock.ANY)
+        mock_executor.return_value.submit.assert_called_once()
+
+    @pytest.mark.parametrize(
+        ("method", "dag_run_state"),
+        [
+            ("on_dag_run_running", DagRunState.RUNNING),
+            ("on_dag_run_success", DagRunState.SUCCESS),
+            ("on_dag_run_failed", DagRunState.FAILED),
+        ],
     )
-    assert fut.exception(10)
-    callback_future.result(10)
-    assert callback_future.done()
-    listener.log.debug.assert_not_called()
-    listener.log.warning.assert_called_once()
+    
@patch("airflow.providers.openlineage.plugins.adapter.OpenLineageAdapter.emit")
+    def test_listener_on_dag_run_state_changes(self, mock_emit, method, 
dag_run_state, create_task_instance):
+        mock_executor = MockExecutor()
+        ti = create_task_instance(dag_id="dag", task_id="op")
+        # Change the state explicitly to set end_date following the logic in 
the method
+        ti.dag_run.set_state(dag_run_state)
+        with mock.patch(
+            
"airflow.providers.openlineage.plugins.listener.ProcessPoolExecutor", 
return_value=mock_executor
+        ):
+            listener = OpenLineageListener()
+            getattr(listener, method)(ti.dag_run, None)
+            assert mock_executor.submitted is True
+            assert mock_executor.succeeded is True
+            mock_emit.assert_called_once()
+
+    def test_listener_logs_failed_serialization(self):
+        listener = OpenLineageListener()
+        callback_future = Future()
+
+        def set_result(*args, **kwargs):
+            callback_future.set_result(True)
+
+        listener.log = MagicMock()
+        listener.log.warning = MagicMock(side_effect=set_result)
+        listener.adapter = OpenLineageAdapter(
+            
client=OpenLineageClient(transport=ConsoleTransport(config=ConsoleConfig()))
+        )
+        event_time = dt.datetime.now()
+        fut = listener.submit_callable(
+            listener.adapter.dag_failed,
+            dag_id="",
+            run_id="",
+            end_date=event_time,
+            logical_date=callback_future,
+            clear_number=0,
+            dag_run_state=DagRunState.FAILED,
+            task_ids=["task_id"],
+            msg="",
+        )
+        assert fut.exception(10)
+        callback_future.result(10)
+        assert callback_future.done()
+        listener.log.debug.assert_not_called()
+        listener.log.warning.assert_called_once()
 
 
-class TestOpenLineageSelectiveEnable:
[email protected](AIRFLOW_V_3_0_PLUS, reason="Airflow 2 tests")
+class TestOpenLineageSelectiveEnableAirflow2:
     def setup_method(self):
         date = dt.datetime(2022, 1, 1)
         self.dag = DAG(
@@ -706,14 +1338,11 @@ class TestOpenLineageSelectiveEnable:
             task_id="test_task_selective_enable_2", dag=self.dag, 
python_callable=simple_callable
         )
         run_id = str(uuid.uuid1())
-        if AIRFLOW_V_3_0_PLUS:
-            dagrun_kwargs = {
-                "dag_version": None,
-                "logical_date": date,
-                "triggered_by": types.DagRunTriggeredByType.TEST,
-            }
-        else:
-            dagrun_kwargs = {"execution_date": date}
+        dagrun_kwargs = {
+            "dag_version": None,
+            "logical_date": date,
+            "triggered_by": types.DagRunTriggeredByType.TEST,
+        }
         self.dagrun = self.dag.create_dagrun(
             run_id=run_id,
             data_interval=(date, date),
@@ -791,10 +1420,7 @@ class TestOpenLineageSelectiveEnable:
             listener.on_task_instance_running(None, self.task_instance_1, None)
             listener.on_task_instance_success(None, self.task_instance_1, None)
             listener.on_task_instance_failed(
-                previous_state=None,
-                task_instance=self.task_instance_1,
-                session=None,
-                **on_task_failed_kwargs,
+                previous_state=None, task_instance=self.task_instance_1, 
**on_task_failed_kwargs, session=None
             )
 
             assert expected_task_call_count == 
listener.extractor_manager.extract_metadata.call_count
@@ -803,10 +1429,7 @@ class TestOpenLineageSelectiveEnable:
             listener.on_task_instance_running(None, self.task_instance_2, None)
             listener.on_task_instance_success(None, self.task_instance_2, None)
             listener.on_task_instance_failed(
-                previous_state=None,
-                task_instance=self.task_instance_2,
-                session=None,
-                **on_task_failed_kwargs,
+                previous_state=None, task_instance=self.task_instance_2, 
**on_task_failed_kwargs, session=None
             )
 
             # with selective-enable disabled both task_1 and task_2 should 
trigger metadata extraction
@@ -850,10 +1473,10 @@ class TestOpenLineageSelectiveEnable:
             listener.on_dag_run_success(self.dagrun, msg="test success")
 
             # run TaskInstance-related hooks for lineage enabled task
-            listener.on_task_instance_running(None, self.task_instance_1, None)
-            listener.on_task_instance_success(None, self.task_instance_1, None)
+            listener.on_task_instance_running(None, self.task_instance_1, 
session=None)
+            listener.on_task_instance_success(None, self.task_instance_1, 
session=None)
             listener.on_task_instance_failed(
-                previous_state=None, task_instance=self.task_instance_1, 
session=None, **on_task_failed_kwargs
+                previous_state=None, task_instance=self.task_instance_1, 
**on_task_failed_kwargs, session=None
             )
 
         assert expected_call_count == listener._executor.submit.call_count

Reply via email to