This is an automated email from the ASF dual-hosted git repository.

mobuchowski pushed a commit to branch openlineage-interface
in repository https://gitbox.apache.org/repos/asf/airflow.git

commit aa4ecade3131325d8a647f334d6038beeb2109a4
Author: Maciej Obuchowski <[email protected]>
AuthorDate: Mon Jul 31 14:38:10 2023 +0200

    add OpenLineage methods as mixin interface to BaseOperator
    
    Signed-off-by: Maciej Obuchowski <[email protected]>
---
 airflow/models/abstractoperator.py                 |  3 +-
 airflow/providers/openlineage/extractors/base.py   | 11 +++---
 .../providers/openlineage/extractors/manager.py    | 10 ++----
 airflow/providers/openlineage/plugins/listener.py  | 10 ++++--
 .../providers/openlineage/plugins/openlineage.py   |  7 ++++
 airflow/utils/openlineage_mixin.py                 | 39 ++++++++++++++++++++++
 .../extractors/test_default_extractor.py           | 18 +++-------
 .../openlineage/plugins/test_openlineage.py        | 12 ++++++-
 8 files changed, 77 insertions(+), 33 deletions(-)

diff --git a/airflow/models/abstractoperator.py 
b/airflow/models/abstractoperator.py
index 11e9184735..daf1d45a25 100644
--- a/airflow/models/abstractoperator.py
+++ b/airflow/models/abstractoperator.py
@@ -33,6 +33,7 @@ from airflow.template.templater import Templater
 from airflow.utils.context import Context
 from airflow.utils.db import exists_query
 from airflow.utils.log.secrets_masker import redact
+from airflow.utils.openlineage_mixin import OpenLineageMixin
 from airflow.utils.session import NEW_SESSION, provide_session
 from airflow.utils.setup_teardown import SetupTeardownContext
 from airflow.utils.sqlalchemy import skip_locked, with_row_locks
@@ -82,7 +83,7 @@ class NotMapped(Exception):
     """Raise if a task is neither mapped nor has any parent mapped groups."""
 
 
-class AbstractOperator(Templater, DAGNode):
+class AbstractOperator(Templater, DAGNode, OpenLineageMixin):
     """Common implementation for operators, including unmapped and mapped.
 
     This base class is more about sharing implementations, not defining a 
common
diff --git a/airflow/providers/openlineage/extractors/base.py 
b/airflow/providers/openlineage/extractors/base.py
index 95d8fa6f28..5a61e97048 100644
--- a/airflow/providers/openlineage/extractors/base.py
+++ b/airflow/providers/openlineage/extractors/base.py
@@ -91,13 +91,10 @@ class DefaultExtractor(BaseExtractor):
 
     def extract_on_complete(self, task_instance) -> OperatorLineage | None:
         if task_instance.state == TaskInstanceState.FAILED:
-            on_failed = getattr(self.operator, 
"get_openlineage_facets_on_failure", None)
-            if on_failed and callable(on_failed):
-                return self._get_openlineage_facets(on_failed, task_instance)
-        on_complete = getattr(self.operator, 
"get_openlineage_facets_on_complete", None)
-        if on_complete and callable(on_complete):
-            return self._get_openlineage_facets(on_complete, task_instance)
-        return self.extract()
+            return self._get_openlineage_facets(
+                self.operator.get_openlineage_facets_on_failure, task_instance
+            )
+        return 
self._get_openlineage_facets(self.operator.get_openlineage_facets_on_complete, 
task_instance)
 
     def _get_openlineage_facets(self, get_facets_method, *args) -> 
OperatorLineage | None:
         try:
diff --git a/airflow/providers/openlineage/extractors/manager.py 
b/airflow/providers/openlineage/extractors/manager.py
index 02a4124840..281bc09f72 100644
--- a/airflow/providers/openlineage/extractors/manager.py
+++ b/airflow/providers/openlineage/extractors/manager.py
@@ -32,6 +32,7 @@ from airflow.providers.openlineage.plugins.facets import (
 from airflow.providers.openlineage.utils.utils import 
get_filtered_unknown_operator_keys
 from airflow.utils.log.logging_mixin import LoggingMixin
 from airflow.utils.module_loading import import_string
+from airflow.utils.openlineage_mixin import OpenLineageMixin
 
 if TYPE_CHECKING:
     from airflow.models import Operator
@@ -135,14 +136,7 @@ class ExtractorManager(LoggingMixin):
         if task.task_type in self.extractors:
             return self.extractors[task.task_type]
 
-        def method_exists(method_name):
-            method = getattr(task, method_name, None)
-            if method:
-                return callable(method)
-
-        if method_exists("get_openlineage_facets_on_start") or method_exists(
-            "get_openlineage_facets_on_complete"
-        ):
+        if isinstance(task, OpenLineageMixin):
             return self.default_extractor
         return None
 
diff --git a/airflow/providers/openlineage/plugins/listener.py 
b/airflow/providers/openlineage/plugins/listener.py
index 99394863f5..516c8456bb 100644
--- a/airflow/providers/openlineage/plugins/listener.py
+++ b/airflow/providers/openlineage/plugins/listener.py
@@ -42,9 +42,15 @@ class OpenLineageListener:
 
     def __init__(self):
         self.log = logging.getLogger(__name__)
-        self.executor: Executor = None  # type: ignore
         self.extractor_manager = ExtractorManager()
         self.adapter = OpenLineageAdapter()
+        self._executor: Executor | None = None
+
+    @property
+    def executor(self) -> Executor:
+        if self._executor is None:
+            self._executor = ThreadPoolExecutor(max_workers=8, 
thread_name_prefix="openlineage_")
+        return self._executor
 
     @hookimpl
     def on_task_instance_running(
@@ -151,7 +157,7 @@ class OpenLineageListener:
     @hookimpl
     def on_starting(self, component):
         self.log.debug("on_starting: %s", component.__class__.__name__)
-        self.executor = ThreadPoolExecutor(max_workers=8, 
thread_name_prefix="openlineage_")
+        self.executor
 
     @hookimpl
     def before_stopping(self, component):
diff --git a/airflow/providers/openlineage/plugins/openlineage.py 
b/airflow/providers/openlineage/plugins/openlineage.py
index 2ec0801147..178bb59a8d 100644
--- a/airflow/providers/openlineage/plugins/openlineage.py
+++ b/airflow/providers/openlineage/plugins/openlineage.py
@@ -16,14 +16,19 @@
 # under the License.
 from __future__ import annotations
 
+import logging
 import os
 
 from airflow.configuration import conf
 from airflow.plugins_manager import AirflowPlugin
 from airflow.providers.openlineage.plugins.macros import lineage_parent_id, 
lineage_run_id
 
+log = logging.getLogger("airflow")
+
 
 def _is_disabled() -> bool:
+    log.error(conf.getboolean("openlineage", "disabled"))
+    log.error(os.getenv("OPENLINEAGE_DISABLED", "false").lower())
     return (
         conf.getboolean("openlineage", "disabled")
         or os.getenv("OPENLINEAGE_DISABLED", "false").lower() == "true"
@@ -41,6 +46,8 @@ class OpenLineageProviderPlugin(AirflowPlugin):
     name = "OpenLineageProviderPlugin"
     macros = [lineage_run_id, lineage_parent_id]
     if not _is_disabled():
+        log.error("EEE?")
+        log.error(_is_disabled())
         from airflow.providers.openlineage.plugins.listener import 
OpenLineageListener
 
         listeners = [OpenLineageListener()]
diff --git a/airflow/utils/openlineage_mixin.py 
b/airflow/utils/openlineage_mixin.py
new file mode 100644
index 0000000000..d3773e9e37
--- /dev/null
+++ b/airflow/utils/openlineage_mixin.py
@@ -0,0 +1,39 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import typing
+
+if typing.TYPE_CHECKING:
+    from airflow.models import TaskInstance
+    from airflow.providers.openlineage.extractors import OperatorLineage
+
+
+class OpenLineageMixin:
+    """
+    This interface marks implementation of OpenLineage methods,
+    allowing us to check for its existence rather than existence of 
particular methods on BaseOperator.
+    """
+
+    def get_openlineage_facets_on_start(self) -> OperatorLineage:
+        raise NotImplementedError()
+
+    def get_openlineage_facets_on_complete(self, task_instance: TaskInstance) 
-> OperatorLineage:
+        return self.get_openlineage_facets_on_start()
+
+    def get_openlineage_facets_on_fail(self, task_instance: TaskInstance) -> 
OperatorLineage:
+        return self.get_openlineage_facets_on_complete(task_instance)
diff --git a/tests/providers/openlineage/extractors/test_default_extractor.py 
b/tests/providers/openlineage/extractors/test_default_extractor.py
index 41d477aa9c..e3c9221093 100644
--- a/tests/providers/openlineage/extractors/test_default_extractor.py
+++ b/tests/providers/openlineage/extractors/test_default_extractor.py
@@ -137,7 +137,7 @@ class OperatorWrongOperatorLineageClass(BaseOperator):
         )
 
 
-class BrokenOperator(BaseOperator):
+class BrokenOperator:
     get_openlineage_facets = []
 
     def execute(self, context) -> Any:
@@ -145,7 +145,7 @@ class BrokenOperator(BaseOperator):
 
 
 def test_default_extraction():
-    extractor = ExtractorManager().get_extractor_class(ExampleOperator)
+    extractor = 
ExtractorManager().get_extractor_class(ExampleOperator(task_id="test"))
     assert extractor is DefaultExtractor
 
     metadata = extractor(ExampleOperator(task_id="test")).extract()
@@ -172,7 +172,7 @@ def test_default_extraction():
 
 
 def test_extraction_without_on_complete():
-    extractor = ExtractorManager().get_extractor_class(OperatorWithoutComplete)
+    extractor = 
ExtractorManager().get_extractor_class(OperatorWithoutComplete(task_id="test"))
     assert extractor is DefaultExtractor
 
     metadata = extractor(OperatorWithoutComplete(task_id="test")).extract()
@@ -196,7 +196,7 @@ def test_extraction_without_on_complete():
 
 
 def test_extraction_without_on_start():
-    extractor = ExtractorManager().get_extractor_class(OperatorWithoutStart)
+    extractor = 
ExtractorManager().get_extractor_class(OperatorWithoutStart(task_id="test"))
     assert extractor is DefaultExtractor
 
     metadata = extractor(OperatorWithoutStart(task_id="test")).extract()
@@ -217,16 +217,6 @@ def test_extraction_without_on_start():
     )
 
 
-def test_does_not_use_default_extractor_when_not_a_method():
-    extractor_class = 
ExtractorManager().get_extractor_class(BrokenOperator(task_id="a"))
-    assert extractor_class is None
-
-
-def test_does_not_use_default_extractor_when_no_get_openlineage_facets():
-    extractor_class = 
ExtractorManager().get_extractor_class(BaseOperator(task_id="b"))
-    assert extractor_class is None
-
-
 def test_does_not_use_default_extractor_when_explicite_extractor():
     extractor_class = ExtractorManager().get_extractor_class(
         PythonOperator(task_id="c", python_callable=lambda: 7)
diff --git a/tests/providers/openlineage/plugins/test_openlineage.py 
b/tests/providers/openlineage/plugins/test_openlineage.py
index 4fcb0f287d..3975260c08 100644
--- a/tests/providers/openlineage/plugins/test_openlineage.py
+++ b/tests/providers/openlineage/plugins/test_openlineage.py
@@ -17,6 +17,7 @@
 from __future__ import annotations
 
 import contextlib
+import logging
 import os
 import sys
 from unittest.mock import patch
@@ -25,6 +26,8 @@ import pytest
 
 from tests.test_utils.config import conf_vars
 
+log = logging.getLogger("airflow")
+
 
 class TestOpenLineageProviderPlugin:
     def setup_method(self):
@@ -55,7 +58,14 @@ class TestOpenLineageProviderPlugin:
         with contextlib.ExitStack() as stack:
             for mock in mocks:
                 stack.enter_context(mock)
-            from airflow.providers.openlineage.plugins.openlineage import 
OpenLineageProviderPlugin
+            from airflow.providers.openlineage.plugins.openlineage import (
+                OpenLineageProviderPlugin,
+                _is_disabled,
+            )
 
             plugin = OpenLineageProviderPlugin()
+
+            log.error(_is_disabled())
+            log.error(plugin.listeners)
+            log.error(expected)
             assert len(plugin.listeners) == expected

Reply via email to