This is an automated email from the ASF dual-hosted git repository. mobuchowski pushed a commit to branch openlineage-interface in repository https://gitbox.apache.org/repos/asf/airflow.git
commit aa4ecade3131325d8a647f334d6038beeb2109a4 Author: Maciej Obuchowski <[email protected]> AuthorDate: Mon Jul 31 14:38:10 2023 +0200 add OpenLineage methods as mixin interface to BaseOperator Signed-off-by: Maciej Obuchowski <[email protected]> --- airflow/models/abstractoperator.py | 3 +- airflow/providers/openlineage/extractors/base.py | 11 +++--- .../providers/openlineage/extractors/manager.py | 10 ++---- airflow/providers/openlineage/plugins/listener.py | 10 ++++-- .../providers/openlineage/plugins/openlineage.py | 7 ++++ airflow/utils/openlineage_mixin.py | 39 ++++++++++++++++++++++ .../extractors/test_default_extractor.py | 18 +++------- .../openlineage/plugins/test_openlineage.py | 12 ++++++- 8 files changed, 77 insertions(+), 33 deletions(-) diff --git a/airflow/models/abstractoperator.py b/airflow/models/abstractoperator.py index 11e9184735..daf1d45a25 100644 --- a/airflow/models/abstractoperator.py +++ b/airflow/models/abstractoperator.py @@ -33,6 +33,7 @@ from airflow.template.templater import Templater from airflow.utils.context import Context from airflow.utils.db import exists_query from airflow.utils.log.secrets_masker import redact +from airflow.utils.openlineage_mixin import OpenLineageMixin from airflow.utils.session import NEW_SESSION, provide_session from airflow.utils.setup_teardown import SetupTeardownContext from airflow.utils.sqlalchemy import skip_locked, with_row_locks @@ -82,7 +83,7 @@ class NotMapped(Exception): """Raise if a task is neither mapped nor has any parent mapped groups.""" -class AbstractOperator(Templater, DAGNode): +class AbstractOperator(Templater, DAGNode, OpenLineageMixin): """Common implementation for operators, including unmapped and mapped. 
This base class is more about sharing implementations, not defining a common diff --git a/airflow/providers/openlineage/extractors/base.py b/airflow/providers/openlineage/extractors/base.py index 95d8fa6f28..5a61e97048 100644 --- a/airflow/providers/openlineage/extractors/base.py +++ b/airflow/providers/openlineage/extractors/base.py @@ -91,13 +91,10 @@ class DefaultExtractor(BaseExtractor): def extract_on_complete(self, task_instance) -> OperatorLineage | None: if task_instance.state == TaskInstanceState.FAILED: - on_failed = getattr(self.operator, "get_openlineage_facets_on_failure", None) - if on_failed and callable(on_failed): - return self._get_openlineage_facets(on_failed, task_instance) - on_complete = getattr(self.operator, "get_openlineage_facets_on_complete", None) - if on_complete and callable(on_complete): - return self._get_openlineage_facets(on_complete, task_instance) - return self.extract() + return self._get_openlineage_facets( + self.operator.get_openlineage_facets_on_failure, task_instance + ) + return self._get_openlineage_facets(self.operator.get_openlineage_facets_on_complete, task_instance) def _get_openlineage_facets(self, get_facets_method, *args) -> OperatorLineage | None: try: diff --git a/airflow/providers/openlineage/extractors/manager.py b/airflow/providers/openlineage/extractors/manager.py index 02a4124840..281bc09f72 100644 --- a/airflow/providers/openlineage/extractors/manager.py +++ b/airflow/providers/openlineage/extractors/manager.py @@ -32,6 +32,7 @@ from airflow.providers.openlineage.plugins.facets import ( from airflow.providers.openlineage.utils.utils import get_filtered_unknown_operator_keys from airflow.utils.log.logging_mixin import LoggingMixin from airflow.utils.module_loading import import_string +from airflow.utils.openlineage_mixin import OpenLineageMixin if TYPE_CHECKING: from airflow.models import Operator @@ -135,14 +136,7 @@ class ExtractorManager(LoggingMixin): if task.task_type in self.extractors: return 
self.extractors[task.task_type] - def method_exists(method_name): - method = getattr(task, method_name, None) - if method: - return callable(method) - - if method_exists("get_openlineage_facets_on_start") or method_exists( - "get_openlineage_facets_on_complete" - ): + if isinstance(task, OpenLineageMixin): return self.default_extractor return None diff --git a/airflow/providers/openlineage/plugins/listener.py b/airflow/providers/openlineage/plugins/listener.py index 99394863f5..516c8456bb 100644 --- a/airflow/providers/openlineage/plugins/listener.py +++ b/airflow/providers/openlineage/plugins/listener.py @@ -42,9 +42,15 @@ class OpenLineageListener: def __init__(self): self.log = logging.getLogger(__name__) - self.executor: Executor = None # type: ignore self.extractor_manager = ExtractorManager() self.adapter = OpenLineageAdapter() + self._executor: Executor | None = None + + @property + def executor(self) -> Executor: + if self._executor is None: + self._executor = ThreadPoolExecutor(max_workers=8, thread_name_prefix="openlineage_") + return self._executor @hookimpl def on_task_instance_running( @@ -151,7 +157,7 @@ class OpenLineageListener: @hookimpl def on_starting(self, component): self.log.debug("on_starting: %s", component.__class__.__name__) - self.executor = ThreadPoolExecutor(max_workers=8, thread_name_prefix="openlineage_") + self.executor @hookimpl def before_stopping(self, component): diff --git a/airflow/providers/openlineage/plugins/openlineage.py b/airflow/providers/openlineage/plugins/openlineage.py index 2ec0801147..178bb59a8d 100644 --- a/airflow/providers/openlineage/plugins/openlineage.py +++ b/airflow/providers/openlineage/plugins/openlineage.py @@ -16,14 +16,19 @@ # under the License. 
from __future__ import annotations +import logging import os from airflow.configuration import conf from airflow.plugins_manager import AirflowPlugin from airflow.providers.openlineage.plugins.macros import lineage_parent_id, lineage_run_id +log = logging.getLogger("airflow") + def _is_disabled() -> bool: + log.error(conf.getboolean("openlineage", "disabled")) + log.error(os.getenv("OPENLINEAGE_DISABLED", "false").lower()) return ( conf.getboolean("openlineage", "disabled") or os.getenv("OPENLINEAGE_DISABLED", "false").lower() == "true" @@ -41,6 +46,8 @@ class OpenLineageProviderPlugin(AirflowPlugin): name = "OpenLineageProviderPlugin" macros = [lineage_run_id, lineage_parent_id] if not _is_disabled(): + log.error("EEE?") + log.error(_is_disabled()) from airflow.providers.openlineage.plugins.listener import OpenLineageListener listeners = [OpenLineageListener()] diff --git a/airflow/utils/openlineage_mixin.py b/airflow/utils/openlineage_mixin.py new file mode 100644 index 0000000000..d3773e9e37 --- /dev/null +++ b/airflow/utils/openlineage_mixin.py @@ -0,0 +1,39 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+from __future__ import annotations + +import typing + +if typing.TYPE_CHECKING: + from airflow.models import TaskInstance + from airflow.providers.openlineage.extractors import OperatorLineage + + +class OpenLineageMixin: + """ + This interface marks implementation of OpenLineage methods, + allowing us to check for its existence rather than existence of particular methods on BaseOperator. + """ + + def get_openlineage_facets_on_start(self) -> OperatorLineage: + raise NotImplementedError() + + def get_openlineage_facets_on_complete(self, task_instance: TaskInstance) -> OperatorLineage: + return self.get_openlineage_facets_on_start() + + def get_openlineage_facets_on_fail(self, task_instance: TaskInstance) -> OperatorLineage: + return self.get_openlineage_facets_on_complete(task_instance) diff --git a/tests/providers/openlineage/extractors/test_default_extractor.py b/tests/providers/openlineage/extractors/test_default_extractor.py index 41d477aa9c..e3c9221093 100644 --- a/tests/providers/openlineage/extractors/test_default_extractor.py +++ b/tests/providers/openlineage/extractors/test_default_extractor.py @@ -137,7 +137,7 @@ class OperatorWrongOperatorLineageClass(BaseOperator): ) -class BrokenOperator(BaseOperator): +class BrokenOperator: get_openlineage_facets = [] def execute(self, context) -> Any: @@ -145,7 +145,7 @@ class BrokenOperator(BaseOperator): def test_default_extraction(): - extractor = ExtractorManager().get_extractor_class(ExampleOperator) + extractor = ExtractorManager().get_extractor_class(ExampleOperator(task_id="test")) assert extractor is DefaultExtractor metadata = extractor(ExampleOperator(task_id="test")).extract() @@ -172,7 +172,7 @@ def test_default_extraction(): def test_extraction_without_on_complete(): - extractor = ExtractorManager().get_extractor_class(OperatorWithoutComplete) + extractor = ExtractorManager().get_extractor_class(OperatorWithoutComplete(task_id="test")) assert extractor is DefaultExtractor metadata = 
extractor(OperatorWithoutComplete(task_id="test")).extract() @@ -196,7 +196,7 @@ def test_extraction_without_on_complete(): def test_extraction_without_on_start(): - extractor = ExtractorManager().get_extractor_class(OperatorWithoutStart) + extractor = ExtractorManager().get_extractor_class(OperatorWithoutStart(task_id="test")) assert extractor is DefaultExtractor metadata = extractor(OperatorWithoutStart(task_id="test")).extract() @@ -217,16 +217,6 @@ def test_extraction_without_on_start(): ) -def test_does_not_use_default_extractor_when_not_a_method(): - extractor_class = ExtractorManager().get_extractor_class(BrokenOperator(task_id="a")) - assert extractor_class is None - - -def test_does_not_use_default_extractor_when_no_get_openlineage_facets(): - extractor_class = ExtractorManager().get_extractor_class(BaseOperator(task_id="b")) - assert extractor_class is None - - def test_does_not_use_default_extractor_when_explicite_extractor(): extractor_class = ExtractorManager().get_extractor_class( PythonOperator(task_id="c", python_callable=lambda: 7) diff --git a/tests/providers/openlineage/plugins/test_openlineage.py b/tests/providers/openlineage/plugins/test_openlineage.py index 4fcb0f287d..3975260c08 100644 --- a/tests/providers/openlineage/plugins/test_openlineage.py +++ b/tests/providers/openlineage/plugins/test_openlineage.py @@ -17,6 +17,7 @@ from __future__ import annotations import contextlib +import logging import os import sys from unittest.mock import patch @@ -25,6 +26,8 @@ import pytest from tests.test_utils.config import conf_vars +log = logging.getLogger("airflow") + class TestOpenLineageProviderPlugin: def setup_method(self): @@ -55,7 +58,14 @@ class TestOpenLineageProviderPlugin: with contextlib.ExitStack() as stack: for mock in mocks: stack.enter_context(mock) - from airflow.providers.openlineage.plugins.openlineage import OpenLineageProviderPlugin + from airflow.providers.openlineage.plugins.openlineage import ( + OpenLineageProviderPlugin, + 
_is_disabled, + ) plugin = OpenLineageProviderPlugin() + + log.error(_is_disabled()) + log.error(plugin.listeners) + log.error(expected) assert len(plugin.listeners) == expected
