This is an automated email from the ASF dual-hosted git repository. mobuchowski pushed a commit to branch tasksdk-call-listeners in repository https://gitbox.apache.org/repos/asf/airflow.git
commit 59f1deba9179ad1343be2eff56596e2e025ed82b Author: Maciej Obuchowski <[email protected]> AuthorDate: Wed Jan 8 13:31:46 2025 +0100 make OpenLineage provider support Airflow 3's listener interface Signed-off-by: Maciej Obuchowski <[email protected]> --- .../administration-and-deployment/listeners.rst | 20 +- .../providers/openlineage/plugins/listener.py | 219 ++- .../openlineage/utils/selective_enable.py | 8 +- .../airflow/providers/openlineage/utils/utils.py | 42 +- .../tests/openlineage/extractors/test_manager.py | 144 +- .../tests/openlineage/plugins/test_listener.py | 1691 +++++++++++++------- 6 files changed, 1460 insertions(+), 664 deletions(-) diff --git a/docs/apache-airflow/administration-and-deployment/listeners.rst b/docs/apache-airflow/administration-and-deployment/listeners.rst index 8ca3ed93fc0..97467361ac3 100644 --- a/docs/apache-airflow/administration-and-deployment/listeners.rst +++ b/docs/apache-airflow/administration-and-deployment/listeners.rst @@ -165,9 +165,7 @@ For example if you want to implement a listener that uses the ``error`` field in ... @hookimpl - def on_task_instance_failed( - self, previous_state, task_instance, error: None | str | BaseException, session - ): + def on_task_instance_failed(self, previous_state, task_instance, error: None | str | BaseException): # Handle error case here pass @@ -177,15 +175,19 @@ For example if you want to implement a listener that uses the ``error`` field in ... 
@hookimpl - def on_task_instance_failed(self, previous_state, task_instance, session): + def on_task_instance_failed(self, previous_state, task_instance): # Handle no error case here pass List of changes in the listener interfaces since 2.8.0 when they were introduced: -+-----------------+-----------------------------+---------------------------------------+ -| Airflow Version | Affected method | Change | -+=================+=============================+=======================================+ -| 2.10.0 | ``on_task_instance_failed`` | An error field added to the interface | -+-----------------+-----------------------------+---------------------------------------+ ++-----------------+--------------------------------------------+-------------------------------------------------------------------------+ +| Airflow Version | Affected method | Change | ++=================+============================================+=========================================================================+ +| 2.10.0 | ``on_task_instance_failed`` | An error field added to the interface | ++-----------------+--------------------------------------------+-------------------------------------------------------------------------+ +| 3.0.0 | ``on_task_instance_running``, | ``session`` argument removed from task instance listeners, | +| | ``on_task_instance_success``, | ``task_instance`` object is now an instance of ``RuntimeTaskInstance`` | +| | ``on_task_instance_failed`` | | ++-----------------+--------------------------------------------+-------------------------------------------------------------------------+ diff --git a/providers/src/airflow/providers/openlineage/plugins/listener.py b/providers/src/airflow/providers/openlineage/plugins/listener.py index c1da206c987..99ced767345 100644 --- a/providers/src/airflow/providers/openlineage/plugins/listener.py +++ b/providers/src/airflow/providers/openlineage/plugins/listener.py @@ -19,6 +19,7 @@ from __future__ import annotations import 
logging import os from concurrent.futures import ProcessPoolExecutor +from datetime import datetime from typing import TYPE_CHECKING import psutil @@ -33,6 +34,7 @@ from airflow.providers.openlineage.extractors import ExtractorManager from airflow.providers.openlineage.plugins.adapter import OpenLineageAdapter, RunState from airflow.providers.openlineage.utils.utils import ( AIRFLOW_V_2_10_PLUS, + AIRFLOW_V_3_0_PLUS, get_airflow_dag_run_facet, get_airflow_debug_facet, get_airflow_job_facet, @@ -42,7 +44,6 @@ from airflow.providers.openlineage.utils.utils import ( get_user_provided_run_facets, is_operator_disabled, is_selective_lineage_enabled, - is_ti_rescheduled_already, print_warning, ) from airflow.settings import configure_orm @@ -52,9 +53,9 @@ from airflow.utils.state import TaskInstanceState from airflow.utils.timeout import timeout if TYPE_CHECKING: - from sqlalchemy.orm import Session - from airflow.models import TaskInstance + from airflow.sdk.execution_time.task_runner import RuntimeTaskInstance + from airflow.settings import Session _openlineage_listener: OpenLineageListener | None = None @@ -87,28 +88,49 @@ class OpenLineageListener: self.extractor_manager = ExtractorManager() self.adapter = OpenLineageAdapter() - @hookimpl - def on_task_instance_running( - self, - previous_state: TaskInstanceState, - task_instance: TaskInstance, - session: Session, # This will always be QUEUED - ) -> None: - if not getattr(task_instance, "task", None) is not None: - self.log.warning( - "No task set for TI object task_id: %s - dag_id: %s - run_id %s", - task_instance.task_id, - task_instance.dag_id, - task_instance.run_id, - ) - return + if AIRFLOW_V_3_0_PLUS: + + @hookimpl + def on_task_instance_running( + self, + previous_state: TaskInstanceState, + task_instance: RuntimeTaskInstance, + ): + self.log.debug("OpenLineage listener got notification about task instance start") + context = task_instance.get_template_context() + + task = context["task"] + if TYPE_CHECKING: + 
assert task + dagrun = context["dag_run"] + dag = context["dag"] + self._on_task_instance_running(task_instance, dag, dagrun, task) + else: - self.log.debug("OpenLineage listener got notification about task instance start") - dagrun = task_instance.dag_run - task = task_instance.task - if TYPE_CHECKING: - assert task - dag = task.dag + @hookimpl + def on_task_instance_running( + self, + previous_state: TaskInstanceState, + task_instance: TaskInstance, + session: Session, # type: ignore[valid-type] + ) -> None: + if not getattr(task_instance, "task", None) is not None: + self.log.warning( + "No task set for TI object task_id: %s - dag_id: %s - run_id %s", + task_instance.task_id, + task_instance.dag_id, + task_instance.run_id, + ) + return + + self.log.debug("OpenLineage listener got notification about task instance start") + task = task_instance.task + if TYPE_CHECKING: + assert task + + self._on_task_instance_running(task_instance, task.dag, task_instance.dag_run, task) + + def _on_task_instance_running(self, task_instance: RuntimeTaskInstance | TaskInstance, dag, dagrun, task): if is_operator_disabled(task): self.log.debug( "Skipping OpenLineage event emission for operator `%s` " @@ -127,35 +149,38 @@ class OpenLineageListener: return # Needs to be calculated outside of inner method so that it gets cached for usage in fork processes + data_interval_start = dagrun.data_interval_start + if isinstance(data_interval_start, datetime): + data_interval_start = data_interval_start.isoformat() + data_interval_end = dagrun.data_interval_end + if isinstance(data_interval_end, datetime): + data_interval_end = data_interval_end.isoformat() + + clear_number = 0 + if hasattr(dagrun, "clear_number"): + clear_number = dagrun.clear_number + debug_facet = get_airflow_debug_facet() @print_warning(self.log) def on_running(): - # that's a workaround to detect task running from deferred state - # we return here because Airflow 2.3 needs task from deferred state - if 
task_instance.next_method is not None: - return - - if is_ti_rescheduled_already(task_instance): + context = task_instance.get_template_context() + if hasattr(context, "task_reschedule_count") and context["task_reschedule_count"] > 0: self.log.debug("Skipping this instance of rescheduled task - START event was emitted already") return parent_run_id = self.adapter.build_dag_run_id( dag_id=dag.dag_id, logical_date=dagrun.logical_date, - clear_number=dagrun.clear_number, + clear_number=clear_number, ) - - if hasattr(task_instance, "logical_date"): - logical_date = task_instance.logical_date - else: - logical_date = task_instance.execution_date + start_date = task_instance.start_date if task_instance.start_date else timezone.utcnow() task_uuid = self.adapter.build_task_instance_run_id( dag_id=dag.dag_id, task_id=task.task_id, try_number=task_instance.try_number, - logical_date=logical_date, + logical_date=dagrun.logical_date, map_index=task_instance.map_index, ) event_type = RunState.RUNNING.value.lower() @@ -164,11 +189,6 @@ class OpenLineageListener: with Stats.timer(f"ol.extract.{event_type}.{operator_name}"): task_metadata = self.extractor_manager.extract_metadata(dagrun, task) - start_date = task_instance.start_date if task_instance.start_date else timezone.utcnow() - data_interval_start = ( - dagrun.data_interval_start.isoformat() if dagrun.data_interval_start else None - ) - data_interval_end = dagrun.data_interval_end.isoformat() if dagrun.data_interval_end else None redacted_event = self.adapter.start_task( run_id=task_uuid, job_name=get_job_name(task), @@ -195,17 +215,39 @@ class OpenLineageListener: self._execute(on_running, "on_running", use_fork=True) - @hookimpl - def on_task_instance_success( - self, previous_state: TaskInstanceState, task_instance: TaskInstance, session: Session - ) -> None: - self.log.debug("OpenLineage listener got notification about task instance success") + if AIRFLOW_V_3_0_PLUS: + + @hookimpl + def on_task_instance_success( + self, 
previous_state: TaskInstanceState, task_instance: RuntimeTaskInstance + ) -> None: + self.log.debug("OpenLineage listener got notification about task instance success") + + context = task_instance.get_template_context() + task = context["task"] + if TYPE_CHECKING: + assert task + dagrun = context["dag_run"] + dag = context["dag"] + self._on_task_instance_success(task_instance, dag, dagrun, task) + + else: + + @hookimpl + def on_task_instance_success( + self, + previous_state: TaskInstanceState, + task_instance: TaskInstance, + session: Session, # type: ignore[valid-type] + ) -> None: + self.log.debug("OpenLineage listener got notification about task instance success") + task = task_instance.task + if TYPE_CHECKING: + assert task + self._on_task_instance_success(task_instance, task.dag, task_instance.dag_run, task) - dagrun = task_instance.dag_run - task = task_instance.task - if TYPE_CHECKING: - assert task - dag = task.dag + def _on_task_instance_success(self, task_instance: RuntimeTaskInstance, dag, dagrun, task): + end_date = timezone.utcnow() if is_operator_disabled(task): self.log.debug( @@ -232,15 +274,11 @@ class OpenLineageListener: clear_number=dagrun.clear_number, ) - if hasattr(task_instance, "logical_date"): - logical_date = task_instance.logical_date - else: - logical_date = task_instance.execution_date task_uuid = self.adapter.build_task_instance_run_id( dag_id=dag.dag_id, task_id=task.task_id, try_number=_get_try_number_success(task_instance), - logical_date=logical_date, + logical_date=dagrun.logical_date, map_index=task_instance.map_index, ) event_type = RunState.COMPLETE.value.lower() @@ -251,8 +289,6 @@ class OpenLineageListener: dagrun, task, complete=True, task_instance=task_instance ) - end_date = task_instance.end_date if task_instance.end_date else timezone.utcnow() - redacted_event = self.adapter.complete_task( run_id=task_uuid, job_name=get_job_name(task), @@ -273,7 +309,7 @@ class OpenLineageListener: self._execute(on_success, 
"on_success", use_fork=True) - if AIRFLOW_V_2_10_PLUS: + if AIRFLOW_V_3_0_PLUS: @hookimpl def on_task_instance_failed( @@ -281,36 +317,54 @@ class OpenLineageListener: previous_state: TaskInstanceState, task_instance: TaskInstance, error: None | str | BaseException, - session: Session, ) -> None: - self._on_task_instance_failed( - previous_state=previous_state, task_instance=task_instance, error=error, session=session - ) + self.log.debug("OpenLineage listener got notification about task instance failure") + context = task_instance.get_template_context() + task = context["task"] + if TYPE_CHECKING: + assert task + dagrun = context["dag_run"] + dag = context["dag"] + self._on_task_instance_failed(task_instance, dag, dagrun, task, error) + elif AIRFLOW_V_2_10_PLUS: + + @hookimpl + def on_task_instance_failed( + self, + previous_state: TaskInstanceState, + task_instance: TaskInstance, + error: None | str | BaseException, + session: Session, # type: ignore[valid-type] + ) -> None: + self.log.debug("OpenLineage listener got notification about task instance failure") + task = task_instance.task + if TYPE_CHECKING: + assert task + self._on_task_instance_failed(task_instance, task.dag, task_instance.dag_run, task, error) else: @hookimpl def on_task_instance_failed( - self, previous_state: TaskInstanceState, task_instance: TaskInstance, session: Session + self, + previous_state: TaskInstanceState, + task_instance: TaskInstance, + session: Session, # type: ignore[valid-type] ) -> None: - self._on_task_instance_failed( - previous_state=previous_state, task_instance=task_instance, error=None, session=session - ) + task = task_instance.task + if TYPE_CHECKING: + assert task + self._on_task_instance_failed(task_instance, task.dag, task_instance.dag_run, task) def _on_task_instance_failed( self, - previous_state: TaskInstanceState, - task_instance: TaskInstance, - session: Session, + task_instance: TaskInstance | RuntimeTaskInstance, + dag, + dagrun, + task, error: None | str | 
BaseException = None, ) -> None: - self.log.debug("OpenLineage listener got notification about task instance failure") - - dagrun = task_instance.dag_run - task = task_instance.task - if TYPE_CHECKING: - assert task - dag = task.dag + end_date = timezone.utcnow() if is_operator_disabled(task): self.log.debug( @@ -337,16 +391,11 @@ class OpenLineageListener: clear_number=dagrun.clear_number, ) - if hasattr(task_instance, "logical_date"): - logical_date = task_instance.logical_date - else: - logical_date = task_instance.execution_date - task_uuid = self.adapter.build_task_instance_run_id( dag_id=dag.dag_id, task_id=task.task_id, try_number=task_instance.try_number, - logical_date=logical_date, + logical_date=dagrun.logical_date, map_index=task_instance.map_index, ) event_type = RunState.FAIL.value.lower() @@ -357,8 +406,6 @@ class OpenLineageListener: dagrun, task, complete=True, task_instance=task_instance ) - end_date = task_instance.end_date if task_instance.end_date else timezone.utcnow() - redacted_event = self.adapter.fail_task( run_id=task_uuid, job_name=get_job_name(task), diff --git a/providers/src/airflow/providers/openlineage/utils/selective_enable.py b/providers/src/airflow/providers/openlineage/utils/selective_enable.py index a3c16a1e18d..b0cd8a4455c 100644 --- a/providers/src/airflow/providers/openlineage/utils/selective_enable.py +++ b/providers/src/airflow/providers/openlineage/utils/selective_enable.py @@ -18,7 +18,7 @@ from __future__ import annotations import logging -from typing import TypeVar +from typing import TYPE_CHECKING, TypeVar from airflow.models import DAG, Operator, Param from airflow.models.xcom_arg import XComArg @@ -28,6 +28,10 @@ ENABLE_OL_PARAM = Param(True, const=True) DISABLE_OL_PARAM = Param(False, const=False) T = TypeVar("T", bound="DAG | Operator") +if TYPE_CHECKING: + from airflow.sdk.definitions.baseoperator import BaseOperator as SdkBaseOperator + + log = logging.getLogger(__name__) @@ -65,7 +69,7 @@ def 
disable_lineage(obj: T) -> T: return obj -def is_task_lineage_enabled(task: Operator) -> bool: +def is_task_lineage_enabled(task: Operator | SdkBaseOperator) -> bool: """Check if selective enable OpenLineage parameter is set to True on task level.""" if task.params.get(ENABLE_OL_PARAM_NAME) is False: log.debug( diff --git a/providers/src/airflow/providers/openlineage/utils/utils.py b/providers/src/airflow/providers/openlineage/utils/utils.py index 4408a833fba..2cb998d47aa 100644 --- a/providers/src/airflow/providers/openlineage/utils/utils.py +++ b/providers/src/airflow/providers/openlineage/utils/utils.py @@ -27,12 +27,11 @@ from typing import TYPE_CHECKING, Any, Callable import attrs from openlineage.client.utils import RedactMixin -from sqlalchemy import exists from airflow import __version__ as AIRFLOW_VERSION # TODO: move this maybe to Airflow's logic? -from airflow.models import DAG, BaseOperator, DagRun, MappedOperator, TaskReschedule +from airflow.models import DAG, BaseOperator, DagRun, MappedOperator from airflow.providers.openlineage import ( __version__ as OPENLINEAGE_PROVIDER_VERSION, conf, @@ -52,7 +51,6 @@ from airflow.providers.openlineage.utils.selective_enable import ( is_task_lineage_enabled, ) from airflow.providers.openlineage.version_compat import AIRFLOW_V_2_10_PLUS, AIRFLOW_V_3_0_PLUS -from airflow.sensors.base import BaseSensorOperator from airflow.serialization.serialized_objects import SerializedBaseOperator from airflow.utils.context import AirflowContextDeprecationWarning from airflow.utils.log.secrets_masker import ( @@ -62,7 +60,11 @@ from airflow.utils.log.secrets_masker import ( should_hide_value_for_key, ) from airflow.utils.module_loading import import_string -from airflow.utils.session import NEW_SESSION, provide_session + +try: + from airflow.sdk.definitions.baseoperator import BaseOperator as SdkBaseOperator +except ImportError: + SdkBaseOperator = BaseOperator # type: ignore[misc] if TYPE_CHECKING: from 
openlineage.client.event_v2 import Dataset as OpenLineageDataset @@ -90,7 +92,7 @@ def try_import_from_string(string: str) -> Any: return import_string(string) -def get_operator_class(task: BaseOperator) -> type: +def get_operator_class(task: BaseOperator | SdkBaseOperator) -> type: if task.__class__.__name__ in ("DecoratedMappedOperator", "MappedOperator"): return task.operator_class return task.__class__ @@ -153,7 +155,7 @@ def get_user_provided_run_facets(ti: TaskInstance, ti_state: TaskInstanceState) return custom_facets -def get_fully_qualified_class_name(operator: BaseOperator | MappedOperator) -> str: +def get_fully_qualified_class_name(operator: BaseOperator | MappedOperator | SdkBaseOperator) -> str: if isinstance(operator, (MappedOperator, SerializedBaseOperator)): # as in airflow.api_connexion.schemas.common_schema.ClassReferenceSchema return operator._task_module + "." + operator._task_type # type: ignore @@ -161,44 +163,22 @@ def get_fully_qualified_class_name(operator: BaseOperator | MappedOperator) -> s return op_class.__module__ + "." 
+ op_class.__name__ -def is_operator_disabled(operator: BaseOperator | MappedOperator) -> bool: +def is_operator_disabled(operator: BaseOperator | MappedOperator | SdkBaseOperator) -> bool: return get_fully_qualified_class_name(operator) in conf.disabled_operators() -def is_selective_lineage_enabled(obj: DAG | BaseOperator | MappedOperator) -> bool: +def is_selective_lineage_enabled(obj: DAG | BaseOperator | MappedOperator | SdkBaseOperator) -> bool: """If selective enable is active check if DAG or Task is enabled to emit events.""" if not conf.selective_enable(): return True if isinstance(obj, DAG): return is_dag_lineage_enabled(obj) - elif isinstance(obj, (BaseOperator, MappedOperator)): + elif isinstance(obj, (BaseOperator, MappedOperator, SdkBaseOperator)): return is_task_lineage_enabled(obj) else: raise TypeError("is_selective_lineage_enabled can only be used on DAG or Operator objects") -@provide_session -def is_ti_rescheduled_already(ti: TaskInstance, session=NEW_SESSION): - if not isinstance(ti.task, BaseSensorOperator): - return False - - if not ti.task.reschedule: - return False - - return ( - session.query( - exists().where( - TaskReschedule.dag_id == ti.dag_id, - TaskReschedule.task_id == ti.task_id, - TaskReschedule.run_id == ti.run_id, - TaskReschedule.map_index == ti.map_index, - TaskReschedule.try_number == ti.try_number, - ) - ).scalar() - is True - ) - - class InfoJsonEncodable(dict): """ Airflow objects might not be json-encodable overall. 
diff --git a/providers/tests/openlineage/extractors/test_manager.py b/providers/tests/openlineage/extractors/test_manager.py index df64b7d1e75..c3a56f16872 100644 --- a/providers/tests/openlineage/extractors/test_manager.py +++ b/providers/tests/openlineage/extractors/test_manager.py @@ -19,6 +19,7 @@ from __future__ import annotations import tempfile from typing import TYPE_CHECKING, Any +from unittest import mock from unittest.mock import MagicMock import pytest @@ -28,6 +29,7 @@ from openlineage.client.facet_v2 import ( ownership_dataset, schema_dataset, ) +from uuid6 import uuid7 from airflow.io.path import ObjectStoragePath from airflow.lineage.entities import Column, File, Table, User @@ -36,18 +38,22 @@ from airflow.models.taskinstance import TaskInstance from airflow.providers.openlineage.extractors import OperatorLineage from airflow.providers.openlineage.extractors.manager import ExtractorManager from airflow.providers.openlineage.utils.utils import Asset +from airflow.utils import timezone from airflow.utils.state import State from tests_common.test_utils.compat import PythonOperator -from tests_common.test_utils.version_compat import AIRFLOW_V_2_10_PLUS +from tests_common.test_utils.version_compat import AIRFLOW_V_2_10_PLUS, AIRFLOW_V_3_0_PLUS if TYPE_CHECKING: + from datetime import datetime try: from airflow.sdk.definitions.context import Context except ImportError: # TODO: Remove once provider drops support for Airflow 2 from airflow.utils.context import Context + from task_sdk.tests.conftest import MakeTIContextCallable + if AIRFLOW_V_2_10_PLUS: @pytest.fixture @@ -65,6 +71,19 @@ if AIRFLOW_V_2_10_PLUS: hook._hook_lineage_collector = None +if AIRFLOW_V_3_0_PLUS: + from airflow.sdk.api.datamodels._generated import TaskInstance as SDKTaskInstance + from airflow.sdk.execution_time import task_runner + from airflow.sdk.execution_time.comms import StartupDetails + from airflow.sdk.execution_time.task_runner import RuntimeTaskInstance, parse +else: + 
SDKTaskInstance = ... # type: ignore + task_runner = ... # type: ignore + StartupDetails = ... # type: ignore + RuntimeTaskInstance = ... # type: ignore + parse = ... # type: ignore + + @pytest.mark.parametrize( ("uri", "dataset"), ( @@ -301,7 +320,10 @@ def test_extractor_manager_does_not_use_hook_level_lineage_when_operator( @pytest.mark.db_test [email protected](not AIRFLOW_V_2_10_PLUS, reason="Hook lineage works in Airflow >= 2.10.0") [email protected]( + not AIRFLOW_V_2_10_PLUS or AIRFLOW_V_3_0_PLUS, + reason="Test for hook level lineage in Airflow >= 2.10.0 < 3.0", +) def test_extractor_manager_gets_data_from_pythonoperator(session, dag_maker, hook_lineage_collector): path = None with tempfile.NamedTemporaryFile() as f: @@ -328,3 +350,121 @@ def test_extractor_manager_gets_data_from_pythonoperator(session, dag_maker, hoo assert len(datasets.outputs) == 1 assert datasets.outputs[0].asset == Asset(uri=path) + + [email protected] +def mock_supervisor_comms(): + with mock.patch( + "airflow.sdk.execution_time.task_runner.SUPERVISOR_COMMS", create=True + ) as supervisor_comms: + yield supervisor_comms + + [email protected] +def mocked_parse(spy_agency): + """ + Fixture to set up an inline DAG and use it in a stubbed `parse` function. Use this fixture if you + want to isolate and test `parse` or `run` logic without having to define a DAG file. + + This fixture returns a helper function `set_dag` that: + 1. Creates an in line DAG with the given `dag_id` and `task` (limited to one task) + 2. Constructs a `RuntimeTaskInstance` based on the provided `StartupDetails` and task. + 3. Stubs the `parse` function using `spy_agency`, to return the mocked `RuntimeTaskInstance`. 
+ + After adding the fixture in your test function signature, you can use it like this :: + + mocked_parse( + StartupDetails( + ti=TaskInstance(id=uuid7(), task_id="hello", dag_id="super_basic_run", run_id="c", try_number=1), + file="", + requests_fd=0, + ), + "example_dag_id", + CustomOperator(task_id="hello"), + ) + """ + + def set_dag(what: StartupDetails, dag_id: str, task: BaseOperator) -> RuntimeTaskInstance: + from task_sdk.tests.execution_time.test_task_runner import get_inline_dag + + dag = get_inline_dag(dag_id, task) + t = dag.task_dict[task.task_id] + ti = RuntimeTaskInstance.model_construct( + **what.ti.model_dump(exclude_unset=True), task=t, _ti_context_from_server=what.ti_context + ) + spy_agency.spy_on(parse, call_fake=lambda _: ti) + return ti + + return set_dag + + [email protected] +def make_ti_context() -> MakeTIContextCallable: + """Factory for creating TIRunContext objects.""" + from airflow.sdk.api.datamodels._generated import DagRun, TIRunContext + + def _make_context( + dag_id: str = "test_dag", + run_id: str = "test_run", + logical_date: str | datetime = "2024-12-01T01:00:00Z", + data_interval_start: str | datetime = "2024-12-01T00:00:00Z", + data_interval_end: str | datetime = "2024-12-01T01:00:00Z", + clear_number: int = 0, + start_date: str | datetime = "2024-12-01T01:00:00Z", + run_type: str = "manual", + task_reschedule_count: int = 0, + ) -> TIRunContext: + return TIRunContext( + dag_run=DagRun( + dag_id=dag_id, + run_id=run_id, + logical_date=logical_date, # type: ignore + data_interval_start=data_interval_start, # type: ignore + data_interval_end=data_interval_end, # type: ignore + clear_number=clear_number, # type: ignore + start_date=start_date, # type: ignore + run_type=run_type, # type: ignore + ), + task_reschedule_count=task_reschedule_count, + ) + + return _make_context + + [email protected]_test [email protected](not AIRFLOW_V_3_0_PLUS, reason="Task SDK related test") +def 
test_extractor_manager_gets_data_from_pythonoperator_tasksdk( + session, hook_lineage_collector, mocked_parse, make_ti_context, mock_supervisor_comms +): + path = None + with tempfile.NamedTemporaryFile() as f: + path = f.name + + def use_read(): + storage_path = ObjectStoragePath(path) + with storage_path.open("w") as out: + out.write("test") + + task = PythonOperator(task_id="test_task_extractor_pythonoperator", python_callable=use_read) + + what = StartupDetails( + ti=SDKTaskInstance( + id=uuid7(), + task_id="test_task_extractor_pythonoperator", + dag_id="test_hookcollector_dag", + run_id="c", + try_number=1, + start_date=timezone.utcnow(), + ), + file="", + requests_fd=0, + ti_context=make_ti_context(), + ) + ti = mocked_parse(what, "test_hookcollector_dag", task) + + task_runner.run(ti, MagicMock()) + + datasets = hook_lineage_collector.collected_assets + + assert len(datasets.outputs) == 1 + assert datasets.outputs[0].asset == Asset(uri=path) diff --git a/providers/tests/openlineage/plugins/test_listener.py b/providers/tests/openlineage/plugins/test_listener.py index 837873f439d..10896286a7d 100644 --- a/providers/tests/openlineage/plugins/test_listener.py +++ b/providers/tests/openlineage/plugins/test_listener.py @@ -18,9 +18,10 @@ from __future__ import annotations import datetime as dt import uuid +from collections import defaultdict from concurrent.futures import Future from contextlib import suppress -from typing import Callable +from typing import TYPE_CHECKING, Callable from unittest import mock from unittest.mock import ANY, MagicMock, patch @@ -29,14 +30,16 @@ import pytest from openlineage.client import OpenLineageClient from openlineage.client.transport import ConsoleTransport from openlineage.client.transport.console import ConsoleConfig +from uuid6 import uuid7 from airflow.models import DAG, DagRun, TaskInstance from airflow.models.baseoperator import BaseOperator +from airflow.operators.empty import EmptyOperator from 
airflow.providers.openlineage.plugins.adapter import OpenLineageAdapter from airflow.providers.openlineage.plugins.facets import AirflowDebugRunFacet from airflow.providers.openlineage.plugins.listener import OpenLineageListener from airflow.providers.openlineage.utils.selective_enable import disable_lineage, enable_lineage -from airflow.utils import types +from airflow.utils import timezone, types from airflow.utils.state import DagRunState, State from tests_common.test_utils.compat import PythonOperator @@ -55,6 +58,10 @@ TRY_NUMBER_SUCCESS = 0 if AIRFLOW_V_2_10_PLUS else 2 TRY_NUMBER_AFTER_EXECUTION = 0 if AIRFLOW_V_2_10_PLUS else 2 +if TYPE_CHECKING: + from airflow.sdk.execution_time.task_runner import RuntimeTaskInstance + + class TemplateOperator(BaseOperator): template_fields = ["df"] @@ -74,95 +81,104 @@ def regular_call(self, callable, callable_name, use_fork): callable() -@patch("airflow.models.TaskInstance.xcom_push") -@patch("airflow.models.BaseOperator.render_template") -def test_listener_does_not_change_task_instance(render_mock, xcom_push_mock): - render_mock.return_value = render_df() +class MockExecutor: + def __init__(self, *args, **kwargs): + self.submitted = False + self.succeeded = False + self.result = None - date = dt.datetime(2022, 1, 1) - dag = DAG( - "test", - schedule=None, - start_date=date, - user_defined_macros={"render_df": render_df}, - params={"df": {"col": [1, 2]}}, - ) - t = TemplateOperator(task_id="template_op", dag=dag, do_xcom_push=True, df=dag.param("df")) - run_id = str(uuid.uuid1()) - if AIRFLOW_V_3_0_PLUS: - dagrun_kwargs = { - "dag_version": None, - "logical_date": date, - "triggered_by": types.DagRunTriggeredByType.TEST, - } - else: - dagrun_kwargs = {"execution_date": date} - dag.create_dagrun( - run_id=run_id, - data_interval=(date, date), - run_type=types.DagRunType.MANUAL, - state=DagRunState.QUEUED, - **dagrun_kwargs, - ) - ti = TaskInstance(t, run_id=run_id) - ti.check_and_change_state_before_execution() # make 
listener hook on running event - ti._run_raw_task() + def submit(self, fn, /, *args, **kwargs): + self.submitted = True + try: + fn(*args, **kwargs) + self.succeeded = True + except Exception: + pass + return MagicMock() - # check if task returns the same DataFrame - pd.testing.assert_frame_equal(xcom_push_mock.call_args.kwargs["value"], render_df()) + def shutdown(self, *args, **kwargs): + print("Shutting down") - # check if render_template method always get the same unrendered field - assert not isinstance(render_mock.call_args.args[0], pd.DataFrame) [email protected](AIRFLOW_V_3_0_PLUS, reason="Airflow 2 tests") +class TestOpenLineageListenerAirflow2: + @patch("airflow.models.TaskInstance.xcom_push") + @patch("airflow.models.BaseOperator.render_template") + def test_listener_does_not_change_task_instance(self, render_mock, xcom_push_mock): + render_mock.return_value = render_df() -def _setup_mock_listener(mock_listener: mock.Mock, captured_try_numbers: dict[str, int]) -> None: - """Sets up the mock listener with side effects to capture try numbers for different task instance events. + dag = DAG( + "test", + schedule=None, + start_date=dt.datetime(2022, 1, 1), + user_defined_macros={"render_df": render_df}, + params={"df": {"col": [1, 2]}}, + ) + t = TemplateOperator(task_id="template_op", dag=dag, do_xcom_push=True, df=dag.param("df")) + run_id = str(uuid.uuid1()) + dag.create_dagrun( + state=State.NONE, + run_id=run_id, + ) + ti = TaskInstance(t, run_id=run_id) + ti.check_and_change_state_before_execution() # make listener hook on running event + ti._run_raw_task() - :param mock_listener: The mock object for the listener manager. - :param captured_try_numbers: A dictionary to store captured try numbers keyed by event names. 
+ # check if task returns the same DataFrame + pd.testing.assert_frame_equal(xcom_push_mock.call_args.kwargs["value"], render_df()) - This function iterates through specified event names and sets a side effect on the corresponding - method of the listener manager's hook. The side effect is a nested function that captures the try number - of the task instance when the method is called. + # check if render_template method always get the same unrendered field + assert not isinstance(render_mock.call_args.args[0], pd.DataFrame) - :Example: + def _setup_mock_listener(self, mock_listener: mock.Mock, captured_try_numbers: dict[str, int]) -> None: + """Sets up the mock listener with side effects to capture try numbers for different task instance events. - captured_try_numbers = {} - mock_listener = Mock() - _setup_mock_listener(mock_listener, captured_try_numbers) - # After running a task, captured_try_numbers will have the try number captured at the moment of - execution for specified methods. F.e. {"running": 1, "success": 2} for on_task_instance_running and - on_task_instance_success methods. - """ + :param mock_listener: The mock object for the listener manager. + :param captured_try_numbers: A dictionary to store captured try numbers keyed by event names. + + This function iterates through specified event names and sets a side effect on the corresponding + method of the listener manager's hook. The side effect is a nested function that captures the try number + of the task instance when the method is called. + + :Example: - def capture_try_number(method_name): - def inner(*args, **kwargs): - captured_try_numbers[method_name] = kwargs["task_instance"].try_number + captured_try_numbers = {} + mock_listener = Mock() + _setup_mock_listener(mock_listener, captured_try_numbers) + # After running a task, captured_try_numbers will have the try number captured at the moment of + execution for specified methods. F.e. 
{"running": 1, "success": 2} for on_task_instance_running and + on_task_instance_success methods. + """ - return inner + def capture_try_number(method_name): + def inner(*args, **kwargs): + captured_try_numbers[method_name] = kwargs["task_instance"].try_number - for event in ["running", "success", "failed"]: - getattr( - mock_listener.return_value.hook, f"on_task_instance_{event}" - ).side_effect = capture_try_number(event) + return inner + for event in ["running", "success", "failed"]: + getattr( + mock_listener.return_value.hook, f"on_task_instance_{event}" + ).side_effect = capture_try_number(event) -def _create_test_dag_and_task(python_callable: Callable, scenario_name: str) -> tuple[DagRun, TaskInstance]: - """Creates a test DAG and a task for a custom test scenario. + def _create_test_dag_and_task( + self, python_callable: Callable, scenario_name: str + ) -> tuple[DagRun, TaskInstance]: + """Creates a test DAG and a task for a custom test scenario. - :param python_callable: The Python callable to be executed by the PythonOperator. - :param scenario_name: The name of the test scenario, used to uniquely name the DAG and task. + :param python_callable: The Python callable to be executed by the PythonOperator. + :param scenario_name: The name of the test scenario, used to uniquely name the DAG and task. - :return: TaskInstance: The created TaskInstance object. + :return: TaskInstance: The created TaskInstance object. This function creates a DAG and a PythonOperator task with the provided python_callable. It generates a unique run ID and creates a DAG run. This setup is useful for testing different scenarios in Airflow tasks. - :Example: + :Example: - def sample_callable(**kwargs): - print("Hello World") + def sample_callable(**kwargs): + print("Hello World") task_instance = _create_test_dag_and_task(sample_callable, "sample_scenario") # Use task_instance to simulate running a task in a test. 
@@ -175,14 +191,11 @@ def _create_test_dag_and_task(python_callable: Callable, scenario_name: str) -> ) t = PythonOperator(task_id=f"test_task_{scenario_name}", dag=dag, python_callable=python_callable) run_id = str(uuid.uuid1()) - if AIRFLOW_V_3_0_PLUS: - dagrun_kwargs: dict = { - "dag_version": None, - "logical_date": date, - "triggered_by": types.DagRunTriggeredByType.TEST, - } - else: - dagrun_kwargs = {"execution_date": date} + dagrun_kwargs: dict = { + "dag_version": None, + "logical_date": date, + "triggered_by": types.DagRunTriggeredByType.TEST, + } dagrun = dag.create_dagrun( run_id=run_id, data_interval=(date, date), @@ -193,501 +206,1120 @@ def _create_test_dag_and_task(python_callable: Callable, scenario_name: str) -> task_instance = TaskInstance(t, run_id=run_id) return dagrun, task_instance + def _create_listener_and_task_instance(self) -> tuple[OpenLineageListener, TaskInstance]: + """Creates and configures an OpenLineageListener instance and a mock TaskInstance for testing. -def _create_listener_and_task_instance() -> tuple[OpenLineageListener, TaskInstance]: - """Creates and configures an OpenLineageListener instance and a mock TaskInstance for testing. + :return: A tuple containing the configured OpenLineageListener and TaskInstance. - :return: A tuple containing the configured OpenLineageListener and TaskInstance. + This function instantiates an OpenLineageListener, sets up its required properties with mock objects, and + creates a mock TaskInstance with predefined attributes. This setup is commonly used for testing the + interaction between an OpenLineageListener and a TaskInstance in Airflow. - This function instantiates an OpenLineageListener, sets up its required properties with mock objects, and - creates a mock TaskInstance with predefined attributes. This setup is commonly used for testing the - interaction between an OpenLineageListener and a TaskInstance in Airflow. 
+ :Example: - :Example: + listener, task_instance = _create_listener_and_task_instance() + # Now you can use listener and task_instance in your tests to simulate their interaction. + """ - listener, task_instance = _create_listener_and_task_instance() - # Now you can use listener and task_instance in your tests to simulate their interaction. - """ + def mock_dag_id(dag_id, logical_date, clear_number): + return f"{logical_date.isoformat()}.{dag_id}.{clear_number}" + + def mock_task_id(dag_id, task_id, try_number, logical_date, map_index): + return f"{logical_date.isoformat()}.{dag_id}.{task_id}.{try_number}.{map_index}" - def mock_dag_id(dag_id, logical_date, clear_number): - return f"{logical_date.isoformat()}.{dag_id}.{clear_number}" - - def mock_task_id(dag_id, task_id, try_number, logical_date, map_index): - return f"{logical_date.isoformat()}.{dag_id}.{task_id}.{try_number}.{map_index}" - - listener = OpenLineageListener() - listener.extractor_manager = mock.Mock() - - metadata = mock.Mock() - metadata.run_facets = {"run_facet": 1} - listener.extractor_manager.extract_metadata.return_value = metadata - - adapter = mock.Mock() - adapter.build_dag_run_id.side_effect = mock_dag_id - adapter.build_task_instance_run_id.side_effect = mock_task_id - adapter.start_task = mock.Mock() - adapter.fail_task = mock.Mock() - adapter.complete_task = mock.Mock() - listener.adapter = adapter - - task_instance = TaskInstance(task=mock.Mock()) - task_instance.dag_run = DagRun() - task_instance.dag_run.run_id = "dag_run_run_id" - task_instance.dag_run.data_interval_start = None - task_instance.dag_run.data_interval_end = None - task_instance.dag_run.clear_number = 0 - if AIRFLOW_V_3_0_PLUS: - task_instance.dag_run.logical_date = dt.datetime(2020, 1, 1, 1, 1, 1) - else: + listener = OpenLineageListener() + listener.extractor_manager = mock.Mock() + + metadata = mock.Mock() + metadata.run_facets = {"run_facet": 1} + listener.extractor_manager.extract_metadata.return_value = metadata 
+ + adapter = mock.Mock() + adapter.build_dag_run_id.side_effect = mock_dag_id + adapter.build_task_instance_run_id.side_effect = mock_task_id + adapter.start_task = mock.Mock() + adapter.fail_task = mock.Mock() + adapter.complete_task = mock.Mock() + listener.adapter = adapter + + task_instance = TaskInstance(task=mock.Mock()) + task_instance.dag_run = DagRun() + task_instance.dag_run.run_id = "dag_run_run_id" + task_instance.dag_run.data_interval_start = None + task_instance.dag_run.data_interval_end = None + task_instance.dag_run.clear_number = 0 task_instance.dag_run.execution_date = dt.datetime(2020, 1, 1, 1, 1, 1) - task_instance.task = mock.Mock() - task_instance.task.task_id = "task_id" - task_instance.task.dag = mock.Mock() - task_instance.task.dag.dag_id = "dag_id" - task_instance.task.dag.description = "Test DAG Description" - task_instance.task.dag.owner = "Test Owner" - task_instance.task.inlets = [] - task_instance.task.outlets = [] - task_instance.dag_id = "dag_id" - task_instance.run_id = "dag_run_run_id" - task_instance.try_number = 1 - task_instance.state = State.RUNNING - task_instance.start_date = dt.datetime(2023, 1, 1, 13, 1, 1) - task_instance.end_date = dt.datetime(2023, 1, 3, 13, 1, 1) - task_instance.logical_date = dt.datetime(2020, 1, 1, 1, 1, 1) - task_instance.map_index = -1 - task_instance.next_method = None # Ensure this is None to reach start_task - - return listener, task_instance - - [email protected]("airflow.providers.openlineage.conf.debug_mode", return_value=True) [email protected]("airflow.providers.openlineage.plugins.listener.is_operator_disabled") [email protected]("airflow.providers.openlineage.plugins.listener.get_airflow_run_facet") [email protected]("airflow.providers.openlineage.plugins.listener.get_airflow_mapped_task_facet") [email protected]("airflow.providers.openlineage.plugins.listener.get_user_provided_run_facets") [email protected]("airflow.providers.openlineage.plugins.listener.get_job_name") [email 
protected]("airflow.providers.openlineage.plugins.listener.OpenLineageListener._execute", new=regular_call) -def test_adapter_start_task_is_called_with_proper_arguments( - mock_get_job_name, - mock_get_airflow_mapped_task_facet, - mock_get_user_provided_run_facets, - mock_get_airflow_run_facet, - mock_disabled, - mock_debug_mode, -): - """Tests that the 'start_task' method of the OpenLineageAdapter is invoked with the correct arguments. - - The test checks that the job name, job description, event time, and other related data are - correctly passed to the adapter. It also verifies that custom facets and Airflow run facets are - correctly retrieved and included in the call. This ensures that all relevant data, including custom - and Airflow-specific metadata, is accurately conveyed to the adapter during the initialization of a task, - reflecting the comprehensive tracking of task execution contexts.""" - - listener, task_instance = _create_listener_and_task_instance() - mock_get_job_name.return_value = "job_name" - mock_get_airflow_mapped_task_facet.return_value = {"mapped_facet": 1} - mock_get_user_provided_run_facets.return_value = {"custom_user_facet": 2} - mock_get_airflow_run_facet.return_value = {"airflow_run_facet": 3} - mock_disabled.return_value = False - - listener.on_task_instance_running(None, task_instance, None) - listener.adapter.start_task.assert_called_once_with( - run_id="2020-01-01T01:01:01.dag_id.task_id.1.-1", - job_name="job_name", - job_description="Test DAG Description", - event_time="2023-01-01T13:01:01", - parent_job_name="dag_id", - parent_run_id="2020-01-01T01:01:01.dag_id.0", - code_location=None, - nominal_start_time=None, - nominal_end_time=None, - owners=["Test Owner"], - task=listener.extractor_manager.extract_metadata(), - run_facets={ - "mapped_facet": 1, - "custom_user_facet": 2, - "airflow_run_facet": 3, - "debug": AirflowDebugRunFacet(packages=ANY), - }, + task_instance.task = mock.Mock() + task_instance.task.task_id = "task_id" 
+ task_instance.task.dag = mock.Mock() + task_instance.task.dag.dag_id = "dag_id" + task_instance.task.dag.description = "Test DAG Description" + task_instance.task.dag.owner = "Test Owner" + task_instance.task.inlets = [] + task_instance.task.outlets = [] + task_instance.dag_id = "dag_id" + task_instance.run_id = "dag_run_run_id" + task_instance.try_number = 1 + task_instance.state = State.RUNNING + task_instance.start_date = dt.datetime(2023, 1, 1, 13, 1, 1) + task_instance.end_date = dt.datetime(2023, 1, 3, 13, 1, 1) + task_instance.logical_date = dt.datetime(2020, 1, 1, 1, 1, 1) + task_instance.map_index = -1 + task_instance.next_method = None # Ensure this is None to reach start_task + task_instance.get_template_context = mock.MagicMock() # type: ignore[method-assign] + task_instance.get_template_context.return_value = defaultdict(mock.MagicMock) + task_instance.get_template_context()["task_reschedule_count"] = 0 + + return listener, task_instance + + @mock.patch("airflow.providers.openlineage.conf.debug_mode", return_value=True) + @mock.patch("airflow.providers.openlineage.plugins.listener.is_operator_disabled") + @mock.patch("airflow.providers.openlineage.plugins.listener.get_airflow_run_facet") + @mock.patch("airflow.providers.openlineage.plugins.listener.get_airflow_mapped_task_facet") + @mock.patch("airflow.providers.openlineage.plugins.listener.get_user_provided_run_facets") + @mock.patch("airflow.providers.openlineage.plugins.listener.get_job_name") + @mock.patch( + "airflow.providers.openlineage.plugins.listener.OpenLineageListener._execute", new=regular_call ) + def test_adapter_start_task_is_called_with_proper_arguments( + self, + mock_get_job_name, + mock_get_airflow_mapped_task_facet, + mock_get_user_provided_run_facets, + mock_get_airflow_run_facet, + mock_disabled, + mock_debug_mode, + ): + """Tests that the 'start_task' method of the OpenLineageAdapter is invoked with the correct arguments. 
+ + The test checks that the job name, job description, event time, and other related data are + correctly passed to the adapter. It also verifies that custom facets and Airflow run facets are + correctly retrieved and included in the call. This ensures that all relevant data, including custom + and Airflow-specific metadata, is accurately conveyed to the adapter during the initialization of a task, + reflecting the comprehensive tracking of task execution contexts.""" + + listener, task_instance = self._create_listener_and_task_instance() + mock_get_job_name.return_value = "job_name" + mock_get_airflow_mapped_task_facet.return_value = {"mapped_facet": 1} + mock_get_user_provided_run_facets.return_value = {"custom_user_facet": 2} + mock_get_airflow_run_facet.return_value = {"airflow_run_facet": 3} + mock_disabled.return_value = False + + listener.on_task_instance_running(None, task_instance, None) + listener.adapter.start_task.assert_called_once_with( + run_id="2020-01-01T01:01:01.dag_id.task_id.1.-1", + job_name="job_name", + job_description="Test DAG Description", + event_time="2023-01-01T13:01:01", + parent_job_name="dag_id", + parent_run_id="2020-01-01T01:01:01.dag_id.0", + code_location=None, + nominal_start_time=None, + nominal_end_time=None, + owners=["Test Owner"], + task=listener.extractor_manager.extract_metadata(), + run_facets={ + "mapped_facet": 1, + "custom_user_facet": 2, + "airflow_run_facet": 3, + "debug": AirflowDebugRunFacet(packages=ANY), + }, + ) + @mock.patch("airflow.providers.openlineage.conf.debug_mode", return_value=True) + @mock.patch("airflow.providers.openlineage.plugins.listener.is_operator_disabled") + @mock.patch("airflow.providers.openlineage.plugins.listener.get_airflow_run_facet") + @mock.patch("airflow.providers.openlineage.plugins.listener.get_user_provided_run_facets") + @mock.patch("airflow.providers.openlineage.plugins.listener.get_job_name") + @mock.patch( + 
"airflow.providers.openlineage.plugins.listener.OpenLineageListener._execute", new=regular_call + ) + @mock.patch("airflow.utils.timezone.utcnow", return_value=dt.datetime(2023, 1, 3, 13, 1, 1)) + def test_adapter_fail_task_is_called_with_proper_arguments( + self, + mock_utcnow, + mock_get_job_name, + mock_get_user_provided_run_facets, + mock_get_airflow_run_facet, + mock_disabled, + mock_debug_mode, + ): + """Tests that the 'fail_task' method of the OpenLineageAdapter is invoked with the correct arguments. + + This test ensures that the job name is accurately retrieved and included, along with the generated + run_id and task metadata. By mocking the job name retrieval and the run_id generation, + the test verifies the integrity and consistency of the data passed to the adapter during task + failure events, thus confirming that the adapter's failure handling is functioning as expected. + """ + + listener, task_instance = self._create_listener_and_task_instance() + task_instance.logical_date = dt.datetime(2020, 1, 1, 1, 1, 1) + mock_get_job_name.return_value = "job_name" + mock_get_user_provided_run_facets.return_value = {"custom_user_facet": 2} + mock_get_airflow_run_facet.return_value = {"airflow": {"task": "..."}} + mock_disabled.return_value = False + + err = ValueError("test") + on_task_failed_listener_kwargs = {"error": err} if AIRFLOW_V_2_10_PLUS else {} + expected_err_kwargs = {"error": err if AIRFLOW_V_2_10_PLUS else None} + + listener.on_task_instance_failed( + previous_state=None, task_instance=task_instance, **on_task_failed_listener_kwargs, session=None + ) + listener.adapter.fail_task.assert_called_once_with( + end_time="2023-01-03T13:01:01", + job_name="job_name", + parent_job_name="dag_id", + parent_run_id="2020-01-01T01:01:01.dag_id.0", + run_id="2020-01-01T01:01:01.dag_id.task_id.1.-1", + task=listener.extractor_manager.extract_metadata(), + run_facets={ + "custom_user_facet": 2, + "airflow": {"task": "..."}, + "debug": 
AirflowDebugRunFacet(packages=ANY), + }, + **expected_err_kwargs, + ) -@mock.patch("airflow.providers.openlineage.conf.debug_mode", return_value=True) -@mock.patch("airflow.providers.openlineage.plugins.listener.is_operator_disabled") -@mock.patch("airflow.providers.openlineage.plugins.listener.get_airflow_run_facet") -@mock.patch("airflow.providers.openlineage.plugins.listener.get_user_provided_run_facets") -@mock.patch("airflow.providers.openlineage.plugins.listener.get_job_name") -@mock.patch("airflow.providers.openlineage.plugins.listener.OpenLineageListener._execute", new=regular_call) -def test_adapter_fail_task_is_called_with_proper_arguments( - mock_get_job_name, - mock_get_user_provided_run_facets, - mock_get_airflow_run_facet, - mock_disabled, - mock_debug_mode, -): - """Tests that the 'fail_task' method of the OpenLineageAdapter is invoked with the correct arguments. - - This test ensures that the job name is accurately retrieved and included, along with the generated - run_id and task metadata. By mocking the job name retrieval and the run_id generation, - the test verifies the integrity and consistency of the data passed to the adapter during task - failure events, thus confirming that the adapter's failure handling is functioning as expected.
- """ - - listener, task_instance = _create_listener_and_task_instance() - task_instance.logical_date = dt.datetime(2020, 1, 1, 1, 1, 1) - mock_get_job_name.return_value = "job_name" - mock_get_user_provided_run_facets.return_value = {"custom_user_facet": 2} - mock_get_airflow_run_facet.return_value = {"airflow": {"task": "..."}} - mock_disabled.return_value = False + @mock.patch("airflow.providers.openlineage.conf.debug_mode", return_value=True) + @mock.patch("airflow.providers.openlineage.plugins.listener.is_operator_disabled") + @mock.patch("airflow.providers.openlineage.plugins.listener.get_airflow_run_facet") + @mock.patch("airflow.providers.openlineage.plugins.listener.get_user_provided_run_facets") + @mock.patch("airflow.providers.openlineage.plugins.listener.get_job_name") + @mock.patch( + "airflow.providers.openlineage.plugins.listener.OpenLineageListener._execute", new=regular_call + ) + @mock.patch("airflow.utils.timezone.utcnow", return_value=dt.datetime(2023, 1, 3, 13, 1, 1)) + def test_adapter_complete_task_is_called_with_proper_arguments( + self, + mock_utcnow, + mock_get_job_name, + mock_get_user_provided_run_facets, + mock_get_airflow_run_facet, + mock_disabled, + mock_debug_mode, + ): + """Tests that the 'complete_task' method of the OpenLineageAdapter is called with the correct arguments. + + It checks that the job name is correctly retrieved and passed, + along with the run_id and task metadata. The test also simulates changes in the try_number + attribute of the task instance, as it would occur in Airflow, to ensure that the run_id is updated + accordingly. This helps confirm the consistency and correctness of the data passed to the adapter + during the task's lifecycle events. 
+ """ + + listener, task_instance = self._create_listener_and_task_instance() + mock_get_job_name.return_value = "job_name" + mock_get_user_provided_run_facets.return_value = {"custom_user_facet": 2} + mock_get_airflow_run_facet.return_value = {"airflow": {"task": "..."}} + mock_disabled.return_value = False + + listener.on_task_instance_success(None, task_instance, None) + # This run_id will be different as we did NOT simulate increase of the try_number attribute, + # which happens in Airflow < 2.10. + calls = listener.adapter.complete_task.call_args_list + assert len(calls) == 1 + assert calls[0][1] == dict( + end_time="2023-01-03T13:01:01", + job_name="job_name", + parent_job_name="dag_id", + parent_run_id="2020-01-01T01:01:01.dag_id.0", + run_id=f"2020-01-01T01:01:01.dag_id.task_id.{EXPECTED_TRY_NUMBER_1}.-1", + task=listener.extractor_manager.extract_metadata(), + run_facets={ + "custom_user_facet": 2, + "airflow": {"task": "..."}, + "debug": AirflowDebugRunFacet(packages=ANY), + }, + ) - err = ValueError("test") - on_task_failed_listener_kwargs = {"error": err} if AIRFLOW_V_2_10_PLUS else {} - expected_err_kwargs = {"error": err if AIRFLOW_V_2_10_PLUS else None} + @mock.patch( + "airflow.providers.openlineage.plugins.listener.OpenLineageListener._execute", new=regular_call + ) + def test_on_task_instance_running_correctly_calls_openlineage_adapter_run_id_method(self): + """Tests the OpenLineageListener's response when a task instance is in the running state. + + This test ensures that when an Airflow task instance transitions to the running state, + the OpenLineageAdapter's `build_task_instance_run_id` method is called exactly once with the correct + parameters derived from the task instance. 
+ """ + listener, task_instance = self._create_listener_and_task_instance() + listener.on_task_instance_running(None, task_instance, None) + listener.adapter.build_task_instance_run_id.assert_called_once_with( + dag_id="dag_id", + task_id="task_id", + logical_date=dt.datetime(2020, 1, 1, 1, 1, 1), + try_number=1, + map_index=-1, + ) - listener.on_task_instance_failed( - previous_state=None, task_instance=task_instance, session=None, **on_task_failed_listener_kwargs + @mock.patch( + "airflow.providers.openlineage.plugins.listener.OpenLineageListener._execute", new=regular_call ) - listener.adapter.fail_task.assert_called_once_with( - end_time="2023-01-03T13:01:01", - job_name="job_name", - parent_job_name="dag_id", - parent_run_id="2020-01-01T01:01:01.dag_id.0", - run_id="2020-01-01T01:01:01.dag_id.task_id.1.-1", - task=listener.extractor_manager.extract_metadata(), - run_facets={ - "custom_user_facet": 2, - "airflow": {"task": "..."}, - "debug": AirflowDebugRunFacet(packages=ANY), - }, - **expected_err_kwargs, + def test_on_task_instance_failed_correctly_calls_openlineage_adapter_run_id_method(self): + """Tests the OpenLineageListener's response when a task instance is in the failed state. + + This test ensures that when an Airflow task instance transitions to the failed state, + the OpenLineageAdapter's `build_task_instance_run_id` method is called exactly once with the correct + parameters derived from the task instance. 
+ """ + listener, task_instance = self._create_listener_and_task_instance() + on_task_failed_kwargs = {"error": ValueError("test")} if AIRFLOW_V_2_10_PLUS else {} + + listener.on_task_instance_failed( + previous_state=None, task_instance=task_instance, **on_task_failed_kwargs, session=None + ) + listener.adapter.build_task_instance_run_id.assert_called_once_with( + dag_id="dag_id", + task_id="task_id", + logical_date=dt.datetime(2020, 1, 1, 1, 1, 1), + try_number=1, + map_index=-1, + ) + + @mock.patch( + "airflow.providers.openlineage.plugins.listener.OpenLineageListener._execute", new=regular_call ) + def test_on_task_instance_success_correctly_calls_openlineage_adapter_run_id_method(self): + """Tests the OpenLineageListener's response when a task instance is in the success state. + + This test ensures that when an Airflow task instance transitions to the success state, + the OpenLineageAdapter's `build_task_instance_run_id` method is called exactly once with the correct + parameters derived from the task instance. + """ + listener, task_instance = self._create_listener_and_task_instance() + listener.on_task_instance_success(None, task_instance, None) + listener.adapter.build_task_instance_run_id.assert_called_once_with( + dag_id="dag_id", + task_id="task_id", + logical_date=dt.datetime(2020, 1, 1, 1, 1, 1), + try_number=EXPECTED_TRY_NUMBER_1, + map_index=-1, + ) + @mock.patch("airflow.models.taskinstance.get_listener_manager") + def test_listener_on_task_instance_failed_is_called_before_try_number_increment(self, mock_listener): + """Validates the listener's on-failure method is called before try_number increment happens. 
-@mock.patch("airflow.providers.openlineage.conf.debug_mode", return_value=True) -@mock.patch("airflow.providers.openlineage.plugins.listener.is_operator_disabled") -@mock.patch("airflow.providers.openlineage.plugins.listener.get_airflow_run_facet") -@mock.patch("airflow.providers.openlineage.plugins.listener.get_user_provided_run_facets") -@mock.patch("airflow.providers.openlineage.plugins.listener.get_job_name") -@mock.patch("airflow.providers.openlineage.plugins.listener.OpenLineageListener._execute", new=regular_call) -def test_adapter_complete_task_is_called_with_proper_arguments( - mock_get_job_name, - mock_get_user_provided_run_facets, - mock_get_airflow_run_facet, - mock_disabled, - mock_debug_mode, -): - """Tests that the 'complete_task' method of the OpenLineageAdapter is called with the correct arguments. - - It checks that the job name is correctly retrieved and passed, - along with the run_id and task metadata. The test also simulates changes in the try_number - attribute of the task instance, as it would occur in Airflow, to ensure that the run_id is updated - accordingly. This helps confirm the consistency and correctness of the data passed to the adapter - during the task's lifecycle events. - """ + This test ensures that when a task instance fails, Airflow's listener method for + task failure (`on_task_instance_failed`) is invoked before the increment of the + `try_number` attribute happens. A custom exception simulates task failure, and the test + captures the `try_number` at the moment of this method call.
+ """ + captured_try_numbers = {} + self._setup_mock_listener(mock_listener, captured_try_numbers) - listener, task_instance = _create_listener_and_task_instance() - mock_get_job_name.return_value = "job_name" - mock_get_user_provided_run_facets.return_value = {"custom_user_facet": 2} - mock_get_airflow_run_facet.return_value = {"airflow": {"task": "..."}} - mock_disabled.return_value = False - - listener.on_task_instance_success(None, task_instance, None) - # This run_id will be different as we did NOT simulate increase of the try_number attribute, - # which happens in Airflow < 2.10. - calls = listener.adapter.complete_task.call_args_list - assert len(calls) == 1 - assert calls[0][1] == dict( - end_time="2023-01-03T13:01:01", - job_name="job_name", - parent_job_name="dag_id", - parent_run_id="2020-01-01T01:01:01.dag_id.0", - run_id=f"2020-01-01T01:01:01.dag_id.task_id.{EXPECTED_TRY_NUMBER_1}.-1", - task=listener.extractor_manager.extract_metadata(), - run_facets={ - "custom_user_facet": 2, - "airflow": {"task": "..."}, - "debug": AirflowDebugRunFacet(packages=ANY), - }, - ) + # Just to make sure no error interferes with the test, and we do not suppress it by accident + class CustomError(Exception): + pass + def fail_callable(**kwargs): + raise CustomError("Simulated task failure") -@mock.patch("airflow.providers.openlineage.plugins.listener.OpenLineageListener._execute", new=regular_call) -def test_on_task_instance_running_correctly_calls_openlineage_adapter_run_id_method(): - """Tests the OpenLineageListener's response when a task instance is in the running state.
+ _, task_instance = self._create_test_dag_and_task(fail_callable, "failure") + # try_number before execution + assert task_instance.try_number == TRY_NUMBER_BEFORE_EXECUTION + with suppress(CustomError): + task_instance.run() - This test ensures that when an Airflow task instance transitions to the running state, - the OpenLineageAdapter's `build_task_instance_run_id` method is called exactly once with the correct - parameters derived from the task instance. - """ - listener, task_instance = _create_listener_and_task_instance() - listener.on_task_instance_running(None, task_instance, None) - listener.adapter.build_task_instance_run_id.assert_called_once_with( - dag_id="dag_id", - task_id="task_id", - logical_date=dt.datetime(2020, 1, 1, 1, 1, 1), - try_number=1, - map_index=-1, - ) + # try_number at the moment of function being called + assert captured_try_numbers["running"] == TRY_NUMBER_RUNNING + assert captured_try_numbers["failed"] == TRY_NUMBER_FAILED + # try_number after task has been executed + assert task_instance.try_number == TRY_NUMBER_AFTER_EXECUTION -@mock.patch("airflow.providers.openlineage.plugins.listener.OpenLineageListener._execute", new=regular_call) -def test_on_task_instance_failed_correctly_calls_openlineage_adapter_run_id_method(): - """Tests the OpenLineageListener's response when a task instance is in the failed state. + @mock.patch("airflow.models.taskinstance.get_listener_manager") + def test_listener_on_task_instance_success_is_called_after_try_number_increment(self, mock_listener): + """Validates the listener's on-success method is called before try_number increment happens. - This test ensures that when an Airflow task instance transitions to the failed state, - the OpenLineageAdapter's `build_task_instance_run_id` method is called exactly once with the correct - parameters derived from the task instance.
- """ - listener, task_instance = _create_listener_and_task_instance() - on_task_failed_kwargs = {"error": ValueError("test")} if AIRFLOW_V_2_10_PLUS else {} + This test ensures that when a task instance successfully completes, the + `on_task_instance_success` method of Airflow's listener is called with an + incremented `try_number` compared to the `try_number` before execution. + The test simulates a successful task execution and captures the `try_number` at the method call. + """ + captured_try_numbers = {} + self._setup_mock_listener(mock_listener, captured_try_numbers) + + def success_callable(**kwargs): + return None + + _, task_instance = self._create_test_dag_and_task(success_callable, "success") + # try_number before execution + assert task_instance.try_number == TRY_NUMBER_BEFORE_EXECUTION + task_instance.run() + + # try_number at the moment of function being called + assert captured_try_numbers["running"] == TRY_NUMBER_RUNNING + assert captured_try_numbers["success"] == TRY_NUMBER_SUCCESS + + # try_number after task has been executed + assert task_instance.try_number == TRY_NUMBER_AFTER_EXECUTION + + @mock.patch("airflow.providers.openlineage.plugins.listener.is_operator_disabled") + @mock.patch("airflow.providers.openlineage.plugins.listener.get_airflow_run_facet") + @mock.patch("airflow.providers.openlineage.plugins.listener.get_user_provided_run_facets") + @mock.patch("airflow.providers.openlineage.plugins.listener.get_job_name") + def test_listener_on_task_instance_running_do_not_call_adapter_when_disabled_operator( + self, mock_get_job_name, mock_get_user_provided_run_facets, mock_get_airflow_run_facet, mock_disabled + ): + listener, task_instance = self._create_listener_and_task_instance() + mock_get_job_name.return_value = "job_name" + mock_get_user_provided_run_facets.return_value = {"custom_facet": 2} + mock_get_airflow_run_facet.return_value = {"airflow_run_facet": 3} + mock_disabled.return_value = True + + 
listener.on_task_instance_running(None, task_instance, None) + mock_disabled.assert_called_once_with(task_instance.task) + listener.adapter.build_dag_run_id.assert_not_called() + listener.adapter.build_task_instance_run_id.assert_not_called() + listener.extractor_manager.extract_metadata.assert_not_called() + listener.adapter.start_task.assert_not_called() + + @mock.patch("airflow.providers.openlineage.plugins.listener.is_operator_disabled") + @mock.patch("airflow.providers.openlineage.plugins.listener.get_user_provided_run_facets") + @mock.patch("airflow.providers.openlineage.plugins.listener.get_job_name") + def test_listener_on_task_instance_failed_do_not_call_adapter_when_disabled_operator( + self, mock_get_job_name, mock_get_user_provided_run_facets, mock_disabled + ): + listener, task_instance = self._create_listener_and_task_instance() + mock_get_user_provided_run_facets.return_value = {"custom_facet": 2} + mock_disabled.return_value = True + + on_task_failed_kwargs = {"error": ValueError("test")} if AIRFLOW_V_2_10_PLUS else {} - listener.on_task_instance_failed( - previous_state=None, task_instance=task_instance, session=None, **on_task_failed_kwargs + listener.on_task_instance_failed( + previous_state=None, task_instance=task_instance, **on_task_failed_kwargs, session=None + ) + mock_disabled.assert_called_once_with(task_instance.task) + listener.adapter.build_dag_run_id.assert_not_called() + listener.adapter.build_task_instance_run_id.assert_not_called() + listener.extractor_manager.extract_metadata.assert_not_called() + listener.adapter.fail_task.assert_not_called() + + @mock.patch("airflow.providers.openlineage.plugins.listener.is_operator_disabled") + @mock.patch("airflow.providers.openlineage.plugins.listener.get_user_provided_run_facets") + @mock.patch("airflow.providers.openlineage.plugins.listener.get_job_name") + def test_listener_on_task_instance_success_do_not_call_adapter_when_disabled_operator( + self, mock_get_job_name, 
mock_get_user_provided_run_facets, mock_disabled + ): + listener, task_instance = self._create_listener_and_task_instance() + mock_get_user_provided_run_facets.return_value = {"custom_facet": 2} + mock_disabled.return_value = True + + listener.on_task_instance_success(None, task_instance, None) + mock_disabled.assert_called_once_with(task_instance.task) + listener.adapter.build_dag_run_id.assert_not_called() + listener.adapter.build_task_instance_run_id.assert_not_called() + listener.extractor_manager.extract_metadata.assert_not_called() + listener.adapter.complete_task.assert_not_called() + + @pytest.mark.parametrize( + "max_workers,expected", + [ + (None, 1), + ("8", 8), + ], ) - listener.adapter.build_task_instance_run_id.assert_called_once_with( - dag_id="dag_id", - task_id="task_id", - logical_date=dt.datetime(2020, 1, 1, 1, 1, 1), - try_number=1, - map_index=-1, + @mock.patch("airflow.providers.openlineage.plugins.listener.ProcessPoolExecutor", autospec=True) + def test_listener_on_dag_run_state_changes_configure_process_pool_size( + self, mock_executor, max_workers, expected + ): + """mock ProcessPoolExecutor and check if conf.dag_state_change_process_pool_size is applied to max_workers""" + listener = OpenLineageListener() + # mock ProcessPoolExecutor class + with conf_vars({("openlineage", "dag_state_change_process_pool_size"): max_workers}): + listener.on_dag_run_running(mock.MagicMock(), None) + mock_executor.assert_called_once_with(max_workers=expected, initializer=mock.ANY) + mock_executor.return_value.submit.assert_called_once() + + @pytest.mark.parametrize( + ("method", "dag_run_state"), + [ + ("on_dag_run_running", DagRunState.RUNNING), + ("on_dag_run_success", DagRunState.SUCCESS), + ("on_dag_run_failed", DagRunState.FAILED), + ], ) + @patch("airflow.providers.openlineage.plugins.adapter.OpenLineageAdapter.emit") + def test_listener_on_dag_run_state_changes(self, mock_emit, method, dag_run_state, create_task_instance): + mock_executor = 
MockExecutor() + ti = create_task_instance(dag_id="dag", task_id="op") + # Change the state explicitly to set end_date following the logic in the method + ti.dag_run.set_state(dag_run_state) + with mock.patch( + "airflow.providers.openlineage.plugins.listener.ProcessPoolExecutor", return_value=mock_executor + ): + listener = OpenLineageListener() + getattr(listener, method)(ti.dag_run, None) + assert mock_executor.submitted is True + assert mock_executor.succeeded is True + mock_emit.assert_called_once() + def test_listener_logs_failed_serialization(self): + listener = OpenLineageListener() + callback_future = Future() [email protected]("airflow.providers.openlineage.plugins.listener.OpenLineageListener._execute", new=regular_call) -def test_on_task_instance_success_correctly_calls_openlineage_adapter_run_id_method(): - """Tests the OpenLineageListener's response when a task instance is in the success state. + def set_result(*args, **kwargs): + callback_future.set_result(True) - This test ensures that when an Airflow task instance transitions to the success state, - the OpenLineageAdapter's `build_task_instance_run_id` method is called exactly once with the correct - parameters derived from the task instance. 
- """ - listener, task_instance = _create_listener_and_task_instance() - listener.on_task_instance_success(None, task_instance, None) - listener.adapter.build_task_instance_run_id.assert_called_once_with( - dag_id="dag_id", - task_id="task_id", - logical_date=dt.datetime(2020, 1, 1, 1, 1, 1), - try_number=EXPECTED_TRY_NUMBER_1, - map_index=-1, - ) + listener.log = MagicMock() + listener.log.warning = MagicMock(side_effect=set_result) + listener.adapter = OpenLineageAdapter( + client=OpenLineageClient(transport=ConsoleTransport(config=ConsoleConfig())) + ) + event_time = dt.datetime.now() + fut = listener.submit_callable( + listener.adapter.dag_failed, + dag_id="", + run_id="", + end_date=event_time, + logical_date=callback_future, + clear_number=0, + dag_run_state=DagRunState.FAILED, + task_ids=["task_id"], + msg="", + ) + assert fut.exception(10) + callback_future.result(10) + assert callback_future.done() + listener.log.debug.assert_not_called() + listener.log.warning.assert_called_once() [email protected]("airflow.models.taskinstance.get_listener_manager") -def test_listener_on_task_instance_failed_is_called_before_try_number_increment(mock_listener): - """Validates the listener's on-failure method is called before try_number increment happens. 
[email protected] +def mock_supervisor_comms(): + with mock.patch( + "airflow.sdk.execution_time.task_runner.SUPERVISOR_COMMS", create=True + ) as supervisor_comms: + yield supervisor_comms + + [email protected](not AIRFLOW_V_3_0_PLUS, reason="Airflow 3 tests") +class TestOpenLineageListenerAirflow3: + @pytest.mark.skip("Rendering fields is not migrated yet in Airflow 3") + @patch("airflow.models.BaseOperator.render_template") + def test_listener_does_not_change_task_instance(self, render_mock, mock_supervisor_comms, spy_agency): + from airflow.sdk.execution_time.task_runner import ( + RuntimeTaskInstance, + TaskInstance as SdkTaskInstance, + run, + ) - This test ensures that when a task instance fails, Airflow's listener method for - task failure (`on_task_instance_failed`) is invoked before the increment of the - `try_number` attribute happens. A custom exception simulates task failure, and the test - captures the `try_number` at the moment of this method call. - """ - captured_try_numbers = {} - _setup_mock_listener(mock_listener, captured_try_numbers) + render_mock.return_value = render_df() - # Just to make sure no error interferes with the test, and we do not suppress it by accident - class CustomError(Exception): - pass + date = dt.datetime(2022, 1, 1) + dag = DAG( + "test", + schedule=None, + start_date=dt.datetime(2022, 1, 1), + user_defined_macros={"render_df": render_df}, + params={"df": {"col": [1, 2]}}, + ) + task = TemplateOperator(task_id="template_op", dag=dag, do_xcom_push=True, df=dag.param("df")) + run_id = str(uuid.uuid1()) - def fail_callable(**kwargs): - raise CustomError("Simulated task failure") + dagrun_kwargs = { + "dag_version": None, + "logical_date": date, + "triggered_by": types.DagRunTriggeredByType.TEST, + } - _, task_instance = _create_test_dag_and_task(fail_callable, "failure") - # try_number before execution - assert task_instance.try_number == TRY_NUMBER_BEFORE_EXECUTION - with suppress(CustomError): - task_instance.run() + 
dag.create_dagrun( + run_id=run_id, + data_interval=(date, date), + run_type=types.DagRunType.MANUAL, + state=DagRunState.QUEUED, + **dagrun_kwargs, + ) + ti = SdkTaskInstance( + id=uuid7(), + task_id="template_op", + dag_id=dag.dag_id, + run_id=run_id, + try_number=1, + start_date=timezone.utcnow(), + map_index=-1, + ) - # try_number at the moment of function being called - assert captured_try_numbers["running"] == TRY_NUMBER_RUNNING - assert captured_try_numbers["failed"] == TRY_NUMBER_FAILED + runtime_ti = RuntimeTaskInstance.model_construct(**ti.model_dump(exclude_unset=True), task=task) - # try_number after task has been executed - assert task_instance.try_number == TRY_NUMBER_AFTER_EXECUTION + spy_agency.spy_on(runtime_ti.xcom_push, call_original=False) + run(runtime_ti, None) + # check if task returns the same DataFrame + pd.testing.assert_frame_equal(runtime_ti.xcom_push.last_call.args[1], render_df()) [email protected]("airflow.models.taskinstance.get_listener_manager") -def test_listener_on_task_instance_success_is_called_after_try_number_increment(mock_listener): - """Validates the listener's on-success method is called before try_number increment happens. + # check if render_template method always get the same unrendered field + assert not isinstance(runtime_ti.xcom_push.last_call.args[1], pd.DataFrame) - This test ensures that when a task instance successfully completes, the - `on_task_instance_success` method of Airflow's listener is called with an - incremented `try_number` compared to the `try_number` before execution. - The test simulates a successful task execution and captures the `try_number` at the method call. 
- """ - captured_try_numbers = {} - _setup_mock_listener(mock_listener, captured_try_numbers) - - def success_callable(**kwargs): - return None - - _, task_instance = _create_test_dag_and_task(success_callable, "success") - # try_number before execution - assert task_instance.try_number == TRY_NUMBER_BEFORE_EXECUTION - task_instance.run() - - # try_number at the moment of function being called - assert captured_try_numbers["running"] == TRY_NUMBER_RUNNING - assert captured_try_numbers["success"] == TRY_NUMBER_SUCCESS - - # try_number after task has been executed - assert task_instance.try_number == TRY_NUMBER_AFTER_EXECUTION - - [email protected]("airflow.providers.openlineage.plugins.listener.is_operator_disabled") [email protected]("airflow.providers.openlineage.plugins.listener.get_airflow_run_facet") [email protected]("airflow.providers.openlineage.plugins.listener.get_user_provided_run_facets") [email protected]("airflow.providers.openlineage.plugins.listener.get_job_name") -def test_listener_on_task_instance_running_do_not_call_adapter_when_disabled_operator( - mock_get_job_name, mock_get_user_provided_run_facets, mock_get_airflow_run_facet, mock_disabled -): - listener, task_instance = _create_listener_and_task_instance() - mock_get_job_name.return_value = "job_name" - mock_get_user_provided_run_facets.return_value = {"custom_facet": 2} - mock_get_airflow_run_facet.return_value = {"airflow_run_facet": 3} - mock_disabled.return_value = True - - listener.on_task_instance_running(None, task_instance, None) - mock_disabled.assert_called_once_with(task_instance.task) - listener.adapter.build_dag_run_id.assert_not_called() - listener.adapter.build_task_instance_run_id.assert_not_called() - listener.extractor_manager.extract_metadata.assert_not_called() - listener.adapter.start_task.assert_not_called() - - [email protected]("airflow.providers.openlineage.plugins.listener.is_operator_disabled") [email 
protected]("airflow.providers.openlineage.plugins.listener.get_user_provided_run_facets") [email protected]("airflow.providers.openlineage.plugins.listener.get_job_name") -def test_listener_on_task_instance_failed_do_not_call_adapter_when_disabled_operator( - mock_get_job_name, mock_get_user_provided_run_facets, mock_disabled -): - listener, task_instance = _create_listener_and_task_instance() - mock_get_user_provided_run_facets.return_value = {"custom_facet": 2} - mock_disabled.return_value = True - - on_task_failed_kwargs = {"error": ValueError("test")} if AIRFLOW_V_2_10_PLUS else {} - - listener.on_task_instance_failed( - previous_state=None, task_instance=task_instance, session=None, **on_task_failed_kwargs + def _setup_mock_listener(self, mock_listener: mock.Mock, captured_try_numbers: dict[str, int]) -> None: + """Sets up the mock listener with side effects to capture try numbers for different task instance events. + + :param mock_listener: The mock object for the listener manager. + :param captured_try_numbers: A dictionary to store captured try numbers keyed by event names. + + This function iterates through specified event names and sets a side effect on the corresponding + method of the listener manager's hook. The side effect is a nested function that captures the try number + of the task instance when the method is called. + + :Example: + + captured_try_numbers = {} + mock_listener = Mock() + _setup_mock_listener(mock_listener, captured_try_numbers) + # After running a task, captured_try_numbers will have the try number captured at the moment of + execution for specified methods. F.e. {"running": 1, "success": 2} for on_task_instance_running and + on_task_instance_success methods. 
+ """ + + def capture_try_number(method_name): + def inner(*args, **kwargs): + captured_try_numbers[method_name] = kwargs["task_instance"].try_number + + return inner + + for event in ["running", "success", "failed"]: + getattr( + mock_listener.return_value.hook, f"on_task_instance_{event}" + ).side_effect = capture_try_number(event) + + def _create_test_dag_and_task( + self, python_callable: Callable, scenario_name: str + ) -> tuple[DagRun, TaskInstance]: + """Creates a test DAG and a task for a custom test scenario. + + :param python_callable: The Python callable to be executed by the PythonOperator. + :param scenario_name: The name of the test scenario, used to uniquely name the DAG and task. + + :return: TaskInstance: The created TaskInstance object. + + This function creates a DAG and a PythonOperator task with the provided python_callable. It generates a unique + run ID and creates a DAG run. This setup is useful for testing different scenarios in Airflow tasks. + + :Example: + + def sample_callable(**kwargs): + print("Hello World") + + task_instance = _create_test_dag_and_task(sample_callable, "sample_scenario") + # Use task_instance to simulate running a task in a test. + """ + dag = DAG( + f"test_{scenario_name}", + schedule=None, + start_date=dt.datetime(2022, 1, 1), + ) + t = PythonOperator(task_id=f"test_task_{scenario_name}", dag=dag, python_callable=python_callable) + run_id = str(uuid.uuid1()) + triggered_by_kwargs = {"triggered_by": DagRunTriggeredByType.TEST} + dagrun = dag.create_dagrun( + state=State.NONE, # type: ignore + run_id=run_id, + **triggered_by_kwargs, # type: ignore + ) + task_instance = TaskInstance(t, run_id=run_id) + return dagrun, task_instance + + def _create_listener_and_task_instance(self) -> tuple[OpenLineageListener, RuntimeTaskInstance]: + """Creates and configures an OpenLineageListener instance and a mock TaskInstance for testing. + + :return: A tuple containing the configured OpenLineageListener and TaskInstance. 
+ + This function instantiates an OpenLineageListener, sets up its required properties with mock objects, and + creates a mock TaskInstance with predefined attributes. This setup is commonly used for testing the + interaction between an OpenLineageListener and a TaskInstance in Airflow. + + :Example: + + listener, task_instance = _create_listener_and_task_instance() + # Now you can use listener and task_instance in your tests to simulate their interaction. + """ + + from airflow.sdk.api.datamodels._generated import ( + DagRun as SdkDagRun, + DagRunType, + TaskInstance as SdkTaskInstance, + TIRunContext, + ) + from airflow.sdk.definitions.dag import DAG + from airflow.sdk.execution_time.task_runner import RuntimeTaskInstance + + def mock_dag_id(dag_id, logical_date, clear_number): + return f"{logical_date.isoformat()}.{dag_id}.{clear_number}" + + def mock_task_id(dag_id, task_id, try_number, logical_date, map_index): + return f"{logical_date.isoformat()}.{dag_id}.{task_id}.{try_number}.{map_index}" + + listener = OpenLineageListener() + listener.extractor_manager = mock.Mock() + + metadata = mock.Mock() + metadata.run_facets = {"run_facet": 1} + listener.extractor_manager.extract_metadata.return_value = metadata + + adapter = mock.Mock() + adapter.build_dag_run_id.side_effect = mock_dag_id + adapter.build_task_instance_run_id.side_effect = mock_task_id + adapter.start_task = mock.Mock() + adapter.fail_task = mock.Mock() + adapter.complete_task = mock.Mock() + listener.adapter = adapter + + dag = DAG( + dag_id="dag_id", + description="Test DAG Description", + ) + task = EmptyOperator(task_id="task_id", dag=dag, owner="Test Owner") + + task_instance = SdkTaskInstance( + id=uuid7(), + task_id="task_id", + dag_id="dag_id", + run_id="dag_run_run_id", + try_number=1, + start_date=dt.datetime(2023, 1, 1, 13, 1, 1), + map_index=-1, + ) + runtime_ti = RuntimeTaskInstance.model_construct( + **task_instance.model_dump(exclude_unset=True), + task=task, + 
_ti_context_from_server=TIRunContext( + dag_run=SdkDagRun( + dag_id="dag_id", + run_id="dag_run_run_id", + logical_date=dt.datetime(2020, 1, 1, 1, 1, 1), + data_interval_start=None, + data_interval_end=None, + start_date=dt.datetime(2023, 1, 1, 13, 1, 1), + end_date=dt.datetime(2023, 1, 3, 13, 1, 1), + clear_number=0, + run_type=DagRunType.MANUAL, + conf=None, + ), + task_reschedule_count=0, + ), + ) + + return listener, runtime_ti + + @mock.patch("airflow.providers.openlineage.conf.debug_mode", return_value=True) + @mock.patch("airflow.providers.openlineage.plugins.listener.is_operator_disabled") + @mock.patch("airflow.providers.openlineage.plugins.listener.get_airflow_run_facet") + @mock.patch("airflow.providers.openlineage.plugins.listener.get_airflow_mapped_task_facet") + @mock.patch("airflow.providers.openlineage.plugins.listener.get_user_provided_run_facets") + @mock.patch("airflow.providers.openlineage.plugins.listener.get_job_name") + @mock.patch( + "airflow.providers.openlineage.plugins.listener.OpenLineageListener._execute", new=regular_call ) - mock_disabled.assert_called_once_with(task_instance.task) - listener.adapter.build_dag_run_id.assert_not_called() - listener.adapter.build_task_instance_run_id.assert_not_called() - listener.extractor_manager.extract_metadata.assert_not_called() - listener.adapter.fail_task.assert_not_called() - - [email protected]("airflow.providers.openlineage.plugins.listener.is_operator_disabled") [email protected]("airflow.providers.openlineage.plugins.listener.get_user_provided_run_facets") [email protected]("airflow.providers.openlineage.plugins.listener.get_job_name") -def test_listener_on_task_instance_success_do_not_call_adapter_when_disabled_operator( - mock_get_job_name, mock_get_user_provided_run_facets, mock_disabled -): - listener, task_instance = _create_listener_and_task_instance() - mock_get_user_provided_run_facets.return_value = {"custom_facet": 2} - mock_disabled.return_value = True - - 
listener.on_task_instance_success(None, task_instance, None) - mock_disabled.assert_called_once_with(task_instance.task) - listener.adapter.build_dag_run_id.assert_not_called() - listener.adapter.build_task_instance_run_id.assert_not_called() - listener.extractor_manager.extract_metadata.assert_not_called() - listener.adapter.complete_task.assert_not_called() - - [email protected]( - "max_workers,expected", - [ - (None, 1), - ("8", 8), - ], -) [email protected]("airflow.providers.openlineage.plugins.listener.ProcessPoolExecutor", autospec=True) -def test_listener_on_dag_run_state_changes_configure_process_pool_size(mock_executor, max_workers, expected): - """mock ProcessPoolExecutor and check if conf.dag_state_change_process_pool_size is applied to max_workers""" - listener = OpenLineageListener() - # mock ProcessPoolExecutor class - with conf_vars({("openlineage", "dag_state_change_process_pool_size"): max_workers}): - listener.on_dag_run_running(mock.MagicMock(), None) - mock_executor.assert_called_once_with(max_workers=expected, initializer=mock.ANY) - mock_executor.return_value.submit.assert_called_once() + def test_adapter_start_task_is_called_with_proper_arguments( + self, + mock_get_job_name, + mock_get_airflow_mapped_task_facet, + mock_get_user_provided_run_facets, + mock_get_airflow_run_facet, + mock_disabled, + mock_debug_mode, + ): + """Tests that the 'start_task' method of the OpenLineageAdapter is invoked with the correct arguments. + + The test checks that the job name, job description, event time, and other related data are + correctly passed to the adapter. It also verifies that custom facets and Airflow run facets are + correctly retrieved and included in the call. 
This ensures that all relevant data, including custom + and Airflow-specific metadata, is accurately conveyed to the adapter during the initialization of a task, + reflecting the comprehensive tracking of task execution contexts.""" + + listener, task_instance = self._create_listener_and_task_instance() + mock_get_job_name.return_value = "job_name" + mock_get_airflow_mapped_task_facet.return_value = {"mapped_facet": 1} + mock_get_user_provided_run_facets.return_value = {"custom_user_facet": 2} + mock_get_airflow_run_facet.return_value = {"airflow_run_facet": 3} + mock_disabled.return_value = False + + listener.on_task_instance_running(None, task_instance) + listener.adapter.start_task.assert_called_once_with( + run_id="2020-01-01T01:01:01.dag_id.task_id.1.-1", + job_name="job_name", + job_description="Test DAG Description", + event_time="2023-01-01T13:01:01", + parent_job_name="dag_id", + parent_run_id="2020-01-01T01:01:01.dag_id.0", + code_location=None, + nominal_start_time=None, + nominal_end_time=None, + owners=["Test Owner"], + task=listener.extractor_manager.extract_metadata(), + run_facets={ + "mapped_facet": 1, + "custom_user_facet": 2, + "airflow_run_facet": 3, + "debug": AirflowDebugRunFacet(packages=ANY), + }, + ) + @mock.patch("airflow.providers.openlineage.conf.debug_mode", return_value=True) + @mock.patch("airflow.providers.openlineage.plugins.listener.is_operator_disabled") + @mock.patch("airflow.providers.openlineage.plugins.listener.get_airflow_run_facet") + @mock.patch("airflow.providers.openlineage.plugins.listener.get_user_provided_run_facets") + @mock.patch("airflow.providers.openlineage.plugins.listener.get_job_name") + @mock.patch( + "airflow.providers.openlineage.plugins.listener.OpenLineageListener._execute", new=regular_call + ) + @mock.patch("airflow.utils.timezone.utcnow", return_value=dt.datetime(2023, 1, 3, 13, 1, 1)) + def test_adapter_fail_task_is_called_with_proper_arguments( + self, + mock_utcnow, + mock_get_job_name, + 
mock_get_user_provided_run_facets, + mock_get_airflow_run_facet, + mock_disabled, + mock_debug_mode, + ): + """Tests that the 'fail_task' method of the OpenLineageAdapter is invoked with the correct arguments. + + This test ensures that the job name is accurately retrieved and included, along with the generated + run_id and task metadata. By mocking the job name retrieval and the run_id generation, + the test verifies the integrity and consistency of the data passed to the adapter during task + failure events, thus confirming that the adapter's failure handling is functioning as expected. + """ + + listener, task_instance = self._create_listener_and_task_instance() + task_instance.get_template_context()["dag_run"].logical_date = dt.datetime(2020, 1, 1, 1, 1, 1) + mock_get_job_name.return_value = "job_name" + mock_get_user_provided_run_facets.return_value = {"custom_user_facet": 2} + mock_get_airflow_run_facet.return_value = {"airflow": {"task": "..."}} + mock_disabled.return_value = False + + err = ValueError("test") + on_task_failed_listener_kwargs = {"error": err} if AIRFLOW_V_2_10_PLUS else {} + expected_err_kwargs = {"error": err if AIRFLOW_V_2_10_PLUS else None} + + listener.on_task_instance_failed( + previous_state=None, task_instance=task_instance, **on_task_failed_listener_kwargs + ) + listener.adapter.fail_task.assert_called_once_with( + end_time="2023-01-03T13:01:01", + job_name="job_name", + parent_job_name="dag_id", + parent_run_id="2020-01-01T01:01:01.dag_id.0", + run_id="2020-01-01T01:01:01.dag_id.task_id.1.-1", + task=listener.extractor_manager.extract_metadata(), + run_facets={ + "custom_user_facet": 2, + "airflow": {"task": "..."}, + "debug": AirflowDebugRunFacet(packages=ANY), + }, + **expected_err_kwargs, + ) -class MockExecutor: - def __init__(self, *args, **kwargs): - self.submitted = False - self.succeeded = False - self.result = None + @mock.patch("airflow.providers.openlineage.conf.debug_mode", return_value=True) + 
@mock.patch("airflow.providers.openlineage.plugins.listener.is_operator_disabled") + @mock.patch("airflow.providers.openlineage.plugins.listener.get_airflow_run_facet") + @mock.patch("airflow.providers.openlineage.plugins.listener.get_user_provided_run_facets") + @mock.patch("airflow.providers.openlineage.plugins.listener.get_job_name") + @mock.patch( + "airflow.providers.openlineage.plugins.listener.OpenLineageListener._execute", new=regular_call + ) + @mock.patch("airflow.utils.timezone.utcnow", return_value=dt.datetime(2023, 1, 3, 13, 1, 1)) + def test_adapter_complete_task_is_called_with_proper_arguments( + self, + mock_utcnow, + mock_get_job_name, + mock_get_user_provided_run_facets, + mock_get_airflow_run_facet, + mock_disabled, + mock_debug_mode, + ): + """Tests that the 'complete_task' method of the OpenLineageAdapter is called with the correct arguments. + + It checks that the job name is correctly retrieved and passed, + along with the run_id and task metadata. The test also simulates changes in the try_number + attribute of the task instance, as it would occur in Airflow, to ensure that the run_id is updated + accordingly. This helps confirm the consistency and correctness of the data passed to the adapter + during the task's lifecycle events. + """ + + listener, task_instance = self._create_listener_and_task_instance() + mock_get_job_name.return_value = "job_name" + mock_get_user_provided_run_facets.return_value = {"custom_user_facet": 2} + mock_get_airflow_run_facet.return_value = {"airflow": {"task": "..."}} + mock_disabled.return_value = False + + listener.on_task_instance_success(None, task_instance) + # This run_id will be different as we did NOT simulate increase of the try_number attribute, + # which happens in Airflow < 2.10. 
+ calls = listener.adapter.complete_task.call_args_list + assert len(calls) == 1 + assert calls[0][1] == dict( + end_time="2023-01-03T13:01:01", + job_name="job_name", + parent_job_name="dag_id", + parent_run_id="2020-01-01T01:01:01.dag_id.0", + run_id=f"2020-01-01T01:01:01.dag_id.task_id.{EXPECTED_TRY_NUMBER_1}.-1", + task=listener.extractor_manager.extract_metadata(), + run_facets={ + "custom_user_facet": 2, + "airflow": {"task": "..."}, + "debug": AirflowDebugRunFacet(packages=ANY), + }, + ) - def submit(self, fn, /, *args, **kwargs): - self.submitted = True - try: - fn(*args, **kwargs) - self.succeeded = True - except Exception: + @mock.patch( + "airflow.providers.openlineage.plugins.listener.OpenLineageListener._execute", new=regular_call + ) + def test_on_task_instance_running_correctly_calls_openlineage_adapter_run_id_method(self): + """Tests the OpenLineageListener's response when a task instance is in the running state. + + This test ensures that when an Airflow task instance transitions to the running state, + the OpenLineageAdapter's `build_task_instance_run_id` method is called exactly once with the correct + parameters derived from the task instance. + """ + listener, task_instance = self._create_listener_and_task_instance() + listener.on_task_instance_running(None, task_instance) + listener.adapter.build_task_instance_run_id.assert_called_once_with( + dag_id="dag_id", + task_id="task_id", + logical_date=dt.datetime(2020, 1, 1, 1, 1, 1), + try_number=1, + map_index=-1, + ) + + @mock.patch( + "airflow.providers.openlineage.plugins.listener.OpenLineageListener._execute", new=regular_call + ) + def test_on_task_instance_failed_correctly_calls_openlineage_adapter_run_id_method(self): + """Tests the OpenLineageListener's response when a task instance is in the failed state. 
+ + This test ensures that when an Airflow task instance transitions to the failed state, + the OpenLineageAdapter's `build_task_instance_run_id` method is called exactly once with the correct + parameters derived from the task instance. + """ + listener, task_instance = self._create_listener_and_task_instance() + on_task_failed_kwargs = {"error": ValueError("test")} if AIRFLOW_V_2_10_PLUS else {} + + listener.on_task_instance_failed( + previous_state=None, task_instance=task_instance, **on_task_failed_kwargs + ) + listener.adapter.build_task_instance_run_id.assert_called_once_with( + dag_id="dag_id", + task_id="task_id", + logical_date=dt.datetime(2020, 1, 1, 1, 1, 1), + try_number=1, + map_index=-1, + ) + + @mock.patch( + "airflow.providers.openlineage.plugins.listener.OpenLineageListener._execute", new=regular_call + ) + def test_on_task_instance_success_correctly_calls_openlineage_adapter_run_id_method(self): + """Tests the OpenLineageListener's response when a task instance is in the success state. + + This test ensures that when an Airflow task instance transitions to the success state, + the OpenLineageAdapter's `build_task_instance_run_id` method is called exactly once with the correct + parameters derived from the task instance. + """ + listener, task_instance = self._create_listener_and_task_instance() + listener.on_task_instance_success(None, task_instance) + listener.adapter.build_task_instance_run_id.assert_called_once_with( + dag_id="dag_id", + task_id="task_id", + logical_date=dt.datetime(2020, 1, 1, 1, 1, 1), + try_number=EXPECTED_TRY_NUMBER_1, + map_index=-1, + ) + + @mock.patch("airflow.models.taskinstance.get_listener_manager") + def test_listener_on_task_instance_failed_is_called_before_try_number_increment(self, mock_listener): + """Validates the listener's on-failure method is called before try_number increment happens. 
+ + This test ensures that when a task instance fails, Airflow's listener method for + task failure (`on_task_instance_failed`) is invoked before the increment of the + `try_number` attribute happens. A custom exception simulates task failure, and the test + captures the `try_number` at the moment of this method call. + """ + captured_try_numbers = {} + self._setup_mock_listener(mock_listener, captured_try_numbers) + + # Just to make sure no error interferes with the test, and we do not suppress it by accident + class CustomError(Exception): pass - return MagicMock() - def shutdown(self, *args, **kwargs): - print("Shutting down") + def fail_callable(**kwargs): + raise CustomError("Simulated task failure") + _, task_instance = self._create_test_dag_and_task(fail_callable, "failure") + # try_number before execution + assert task_instance.try_number == TRY_NUMBER_BEFORE_EXECUTION + with suppress(CustomError): + task_instance.run() [email protected]( - ("method", "dag_run_state"), - [ - ("on_dag_run_running", DagRunState.RUNNING), - ("on_dag_run_success", DagRunState.SUCCESS), - ("on_dag_run_failed", DagRunState.FAILED), - ], -) -@patch("airflow.providers.openlineage.plugins.adapter.OpenLineageAdapter.emit") -def test_listener_on_dag_run_state_changes(mock_emit, method, dag_run_state, create_task_instance): - mock_executor = MockExecutor() - ti = create_task_instance(dag_id="dag", task_id="op") - # Change the state explicitly to set end_date following the logic in the method - ti.dag_run.set_state(dag_run_state) - with mock.patch( - "airflow.providers.openlineage.plugins.listener.ProcessPoolExecutor", return_value=mock_executor + # try_number at the moment of function being called + assert captured_try_numbers["running"] == TRY_NUMBER_RUNNING + assert captured_try_numbers["failed"] == TRY_NUMBER_FAILED + + # try_number after task has been executed + assert task_instance.try_number == TRY_NUMBER_AFTER_EXECUTION + + 
@mock.patch("airflow.models.taskinstance.get_listener_manager") + def test_listener_on_task_instance_success_is_called_after_try_number_increment(self, mock_listener): + """Validates the listener's on-success method is called before try_number increment happens. + + This test ensures that when a task instance successfully completes, the + `on_task_instance_success` method of Airflow's listener is called with an + incremented `try_number` compared to the `try_number` before execution. + The test simulates a successful task execution and captures the `try_number` at the method call. + """ + captured_try_numbers = {} + self._setup_mock_listener(mock_listener, captured_try_numbers) + + def success_callable(**kwargs): + return None + + _, task_instance = self._create_test_dag_and_task(success_callable, "success") + # try_number before execution + assert task_instance.try_number == TRY_NUMBER_BEFORE_EXECUTION + task_instance.run() + + # try_number at the moment of function being called + assert captured_try_numbers["running"] == TRY_NUMBER_RUNNING + assert captured_try_numbers["success"] == TRY_NUMBER_SUCCESS + + # try_number after task has been executed + assert task_instance.try_number == TRY_NUMBER_AFTER_EXECUTION + + @mock.patch("airflow.providers.openlineage.plugins.listener.is_operator_disabled") + @mock.patch("airflow.providers.openlineage.plugins.listener.get_airflow_run_facet") + @mock.patch("airflow.providers.openlineage.plugins.listener.get_user_provided_run_facets") + @mock.patch("airflow.providers.openlineage.plugins.listener.get_job_name") + def test_listener_on_task_instance_running_do_not_call_adapter_when_disabled_operator( + self, mock_get_job_name, mock_get_user_provided_run_facets, mock_get_airflow_run_facet, mock_disabled ): - listener = OpenLineageListener() - getattr(listener, method)(ti.dag_run, None) - assert mock_executor.submitted is True - assert mock_executor.succeeded is True - mock_emit.assert_called_once() + listener, task_instance = 
self._create_listener_and_task_instance() + mock_get_job_name.return_value = "job_name" + mock_get_user_provided_run_facets.return_value = {"custom_facet": 2} + mock_get_airflow_run_facet.return_value = {"airflow_run_facet": 3} + mock_disabled.return_value = True + + listener.on_task_instance_running(None, task_instance) + mock_disabled.assert_called_once_with(task_instance.task) + listener.adapter.build_dag_run_id.assert_not_called() + listener.adapter.build_task_instance_run_id.assert_not_called() + listener.extractor_manager.extract_metadata.assert_not_called() + listener.adapter.start_task.assert_not_called() + + @mock.patch("airflow.providers.openlineage.plugins.listener.is_operator_disabled") + @mock.patch("airflow.providers.openlineage.plugins.listener.get_user_provided_run_facets") + @mock.patch("airflow.providers.openlineage.plugins.listener.get_job_name") + def test_listener_on_task_instance_failed_do_not_call_adapter_when_disabled_operator( + self, mock_get_job_name, mock_get_user_provided_run_facets, mock_disabled + ): + listener, task_instance = self._create_listener_and_task_instance() + mock_get_user_provided_run_facets.return_value = {"custom_facet": 2} + mock_disabled.return_value = True + on_task_failed_kwargs = {"error": ValueError("test")} if AIRFLOW_V_2_10_PLUS else {} -def test_listener_logs_failed_serialization(): - listener = OpenLineageListener() - callback_future = Future() + listener.on_task_instance_failed( + previous_state=None, task_instance=task_instance, **on_task_failed_kwargs + ) + mock_disabled.assert_called_once_with(task_instance.task) + listener.adapter.build_dag_run_id.assert_not_called() + listener.adapter.build_task_instance_run_id.assert_not_called() + listener.extractor_manager.extract_metadata.assert_not_called() + listener.adapter.fail_task.assert_not_called() + + @mock.patch("airflow.providers.openlineage.plugins.listener.is_operator_disabled") + 
@mock.patch("airflow.providers.openlineage.plugins.listener.get_user_provided_run_facets") + @mock.patch("airflow.providers.openlineage.plugins.listener.get_job_name") + def test_listener_on_task_instance_success_do_not_call_adapter_when_disabled_operator( + self, mock_get_job_name, mock_get_user_provided_run_facets, mock_disabled + ): + listener, task_instance = self._create_listener_and_task_instance() + mock_get_user_provided_run_facets.return_value = {"custom_facet": 2} + mock_disabled.return_value = True - def set_result(*args, **kwargs): - callback_future.set_result(True) + listener.on_task_instance_success(None, task_instance) + mock_disabled.assert_called_once_with(task_instance.task) + listener.adapter.build_dag_run_id.assert_not_called() + listener.adapter.build_task_instance_run_id.assert_not_called() + listener.extractor_manager.extract_metadata.assert_not_called() + listener.adapter.complete_task.assert_not_called() - listener.log = MagicMock() - listener.log.warning = MagicMock(side_effect=set_result) - listener.adapter = OpenLineageAdapter( - client=OpenLineageClient(transport=ConsoleTransport(config=ConsoleConfig())) + @pytest.mark.parametrize( + "max_workers,expected", + [ + (None, 1), + ("8", 8), + ], ) - event_time = dt.datetime.now() - fut = listener.submit_callable( - listener.adapter.dag_failed, - dag_id="", - run_id="", - end_date=event_time, - logical_date=callback_future, - clear_number=0, - dag_run_state=DagRunState.FAILED, - task_ids=["task_id"], - msg="", + @mock.patch("airflow.providers.openlineage.plugins.listener.ProcessPoolExecutor", autospec=True) + def test_listener_on_dag_run_state_changes_configure_process_pool_size( + self, mock_executor, max_workers, expected + ): + """mock ProcessPoolExecutor and check if conf.dag_state_change_process_pool_size is applied to max_workers""" + listener = OpenLineageListener() + # mock ProcessPoolExecutor class + with conf_vars({("openlineage", "dag_state_change_process_pool_size"): 
max_workers}): + listener.on_dag_run_running(mock.MagicMock(), None) + mock_executor.assert_called_once_with(max_workers=expected, initializer=mock.ANY) + mock_executor.return_value.submit.assert_called_once() + + @pytest.mark.parametrize( + ("method", "dag_run_state"), + [ + ("on_dag_run_running", DagRunState.RUNNING), + ("on_dag_run_success", DagRunState.SUCCESS), + ("on_dag_run_failed", DagRunState.FAILED), + ], ) - assert fut.exception(10) - callback_future.result(10) - assert callback_future.done() - listener.log.debug.assert_not_called() - listener.log.warning.assert_called_once() + @patch("airflow.providers.openlineage.plugins.adapter.OpenLineageAdapter.emit") + def test_listener_on_dag_run_state_changes(self, mock_emit, method, dag_run_state, create_task_instance): + mock_executor = MockExecutor() + ti = create_task_instance(dag_id="dag", task_id="op") + # Change the state explicitly to set end_date following the logic in the method + ti.dag_run.set_state(dag_run_state) + with mock.patch( + "airflow.providers.openlineage.plugins.listener.ProcessPoolExecutor", return_value=mock_executor + ): + listener = OpenLineageListener() + getattr(listener, method)(ti.dag_run, None) + assert mock_executor.submitted is True + assert mock_executor.succeeded is True + mock_emit.assert_called_once() + + def test_listener_logs_failed_serialization(self): + listener = OpenLineageListener() + callback_future = Future() + + def set_result(*args, **kwargs): + callback_future.set_result(True) + + listener.log = MagicMock() + listener.log.warning = MagicMock(side_effect=set_result) + listener.adapter = OpenLineageAdapter( + client=OpenLineageClient(transport=ConsoleTransport(config=ConsoleConfig())) + ) + event_time = dt.datetime.now() + fut = listener.submit_callable( + listener.adapter.dag_failed, + dag_id="", + run_id="", + end_date=event_time, + logical_date=callback_future, + clear_number=0, + dag_run_state=DagRunState.FAILED, + task_ids=["task_id"], + msg="", + ) + assert 
fut.exception(10) + callback_future.result(10) + assert callback_future.done() + listener.log.debug.assert_not_called() + listener.log.warning.assert_called_once() -class TestOpenLineageSelectiveEnable: [email protected](AIRFLOW_V_3_0_PLUS, reason="Airflow 2 tests") +class TestOpenLineageSelectiveEnableAirflow2: def setup_method(self): date = dt.datetime(2022, 1, 1) self.dag = DAG( @@ -706,14 +1338,11 @@ class TestOpenLineageSelectiveEnable: task_id="test_task_selective_enable_2", dag=self.dag, python_callable=simple_callable ) run_id = str(uuid.uuid1()) - if AIRFLOW_V_3_0_PLUS: - dagrun_kwargs = { - "dag_version": None, - "logical_date": date, - "triggered_by": types.DagRunTriggeredByType.TEST, - } - else: - dagrun_kwargs = {"execution_date": date} + dagrun_kwargs = { + "dag_version": None, + "logical_date": date, + "triggered_by": types.DagRunTriggeredByType.TEST, + } self.dagrun = self.dag.create_dagrun( run_id=run_id, data_interval=(date, date), @@ -791,10 +1420,7 @@ class TestOpenLineageSelectiveEnable: listener.on_task_instance_running(None, self.task_instance_1, None) listener.on_task_instance_success(None, self.task_instance_1, None) listener.on_task_instance_failed( - previous_state=None, - task_instance=self.task_instance_1, - session=None, - **on_task_failed_kwargs, + previous_state=None, task_instance=self.task_instance_1, **on_task_failed_kwargs, session=None ) assert expected_task_call_count == listener.extractor_manager.extract_metadata.call_count @@ -803,10 +1429,7 @@ class TestOpenLineageSelectiveEnable: listener.on_task_instance_running(None, self.task_instance_2, None) listener.on_task_instance_success(None, self.task_instance_2, None) listener.on_task_instance_failed( - previous_state=None, - task_instance=self.task_instance_2, - session=None, - **on_task_failed_kwargs, + previous_state=None, task_instance=self.task_instance_2, **on_task_failed_kwargs, session=None ) # with selective-enable disabled both task_1 and task_2 should trigger metadata 
extraction @@ -850,10 +1473,10 @@ class TestOpenLineageSelectiveEnable: listener.on_dag_run_success(self.dagrun, msg="test success") # run TaskInstance-related hooks for lineage enabled task - listener.on_task_instance_running(None, self.task_instance_1, None) - listener.on_task_instance_success(None, self.task_instance_1, None) + listener.on_task_instance_running(None, self.task_instance_1, session=None) + listener.on_task_instance_success(None, self.task_instance_1, session=None) listener.on_task_instance_failed( - previous_state=None, task_instance=self.task_instance_1, session=None, **on_task_failed_kwargs + previous_state=None, task_instance=self.task_instance_1, **on_task_failed_kwargs, session=None ) assert expected_call_count == listener._executor.submit.call_count
