This is an automated email from the ASF dual-hosted git repository.

mobuchowski pushed a commit to branch openlineage-interface
in repository https://gitbox.apache.org/repos/asf/airflow.git

commit 220970eefbdb10a2b1bd34c8f6e11173084971df
Author: Maciej Obuchowski <[email protected]>
AuthorDate: Mon Jul 31 14:38:10 2023 +0200

    add OpenLineage methods as mixin interface to BaseOperator
    
    Signed-off-by: Maciej Obuchowski <[email protected]>
---
 airflow/providers/common/sql/operators/sql.py      |  3 +-
 airflow/providers/ftp/operators/ftp.py             |  3 +-
 .../providers/google/cloud/transfers/gcs_to_gcs.py |  3 +-
 airflow/providers/openlineage/extractors/base.py   | 29 ++++++++----
 airflow/providers/openlineage/extractors/bash.py   |  3 +-
 .../providers/openlineage/extractors/manager.py    | 20 +++++---
 airflow/providers/openlineage/extractors/python.py |  3 +-
 airflow/providers/openlineage/plugins/listener.py  | 10 +++-
 airflow/providers/sftp/operators/sftp.py           |  3 +-
 airflow/utils/openlineage_mixin.py                 | 39 +++++++++++++++
 .../extractors/test_default_extractor.py           | 55 ++++++++++++++--------
 11 files changed, 127 insertions(+), 44 deletions(-)

diff --git a/airflow/providers/common/sql/operators/sql.py b/airflow/providers/common/sql/operators/sql.py
index bf2b38f055..a3dc30fe55 100644
--- a/airflow/providers/common/sql/operators/sql.py
+++ b/airflow/providers/common/sql/operators/sql.py
@@ -27,6 +27,7 @@ from airflow.hooks.base import BaseHook
 from airflow.models import BaseOperator, SkipMixin
 from airflow.providers.common.sql.hooks.sql import DbApiHook, 
fetch_all_handler, return_single_query_results
 from airflow.utils.helpers import merge_dicts
+from airflow.utils.openlineage_mixin import OpenLineageMixin
 
 if TYPE_CHECKING:
     from airflow.providers.openlineage.extractors import OperatorLineage
@@ -188,7 +189,7 @@ class BaseSQLOperator(BaseOperator):
         raise AirflowFailException(exception_string)
 
 
-class SQLExecuteQueryOperator(BaseSQLOperator):
+class SQLExecuteQueryOperator(BaseSQLOperator, OpenLineageMixin):
     """
     Executes SQL code in a specific database.
 
diff --git a/airflow/providers/ftp/operators/ftp.py b/airflow/providers/ftp/operators/ftp.py
index 45bccbea4c..216bdbfef2 100644
--- a/airflow/providers/ftp/operators/ftp.py
+++ b/airflow/providers/ftp/operators/ftp.py
@@ -27,6 +27,7 @@ from typing import Any, Sequence
 
 from airflow.models import BaseOperator
 from airflow.providers.ftp.hooks.ftp import FTPHook, FTPSHook
+from airflow.utils.openlineage_mixin import OpenLineageMixin
 
 
 class FTPOperation:
@@ -36,7 +37,7 @@ class FTPOperation:
     GET = "get"
 
 
-class FTPFileTransmitOperator(BaseOperator):
+class FTPFileTransmitOperator(BaseOperator, OpenLineageMixin):
     """
     FTPFileTransmitOperator for transferring files from remote host to local 
or vice a versa.
 
diff --git a/airflow/providers/google/cloud/transfers/gcs_to_gcs.py b/airflow/providers/google/cloud/transfers/gcs_to_gcs.py
index d4c3bd679f..d0e8870a59 100644
--- a/airflow/providers/google/cloud/transfers/gcs_to_gcs.py
+++ b/airflow/providers/google/cloud/transfers/gcs_to_gcs.py
@@ -24,6 +24,7 @@ from typing import TYPE_CHECKING, Sequence
 from airflow.exceptions import AirflowException, 
AirflowProviderDeprecationWarning
 from airflow.models import BaseOperator
 from airflow.providers.google.cloud.hooks.gcs import GCSHook
+from airflow.utils.openlineage_mixin import OpenLineageMixin
 
 WILDCARD = "*"
 
@@ -31,7 +32,7 @@ if TYPE_CHECKING:
     from airflow.utils.context import Context
 
 
-class GCSToGCSOperator(BaseOperator):
+class GCSToGCSOperator(BaseOperator, OpenLineageMixin):
     """
     Copies objects from a bucket to another, with renaming if requested.
 
diff --git a/airflow/providers/openlineage/extractors/base.py b/airflow/providers/openlineage/extractors/base.py
index 95d8fa6f28..43a0b584f6 100644
--- a/airflow/providers/openlineage/extractors/base.py
+++ b/airflow/providers/openlineage/extractors/base.py
@@ -18,6 +18,7 @@
 from __future__ import annotations
 
 from abc import ABC, abstractmethod
+from contextlib import suppress
 
 from attrs import Factory, define
 from openlineage.client.facet import BaseFacet
@@ -84,20 +85,28 @@ class DefaultExtractor(BaseExtractor):
 
     def extract(self) -> OperatorLineage | None:
         # OpenLineage methods are optional - if there's no method, return None
-        try:
+        with suppress(AttributeError):
             return 
self._get_openlineage_facets(self.operator.get_openlineage_facets_on_start)  # 
type: ignore
-        except AttributeError:
-            return None
+        return None
 
     def extract_on_complete(self, task_instance) -> OperatorLineage | None:
+        """
+        On completion, prefer the on_failure (for failed tasks) and on_complete methods.
+        If they are not implemented - which happens in older, pre-OpenLineageMixin
+        classes - fall back to the on_start method.
+        """
         if task_instance.state == TaskInstanceState.FAILED:
-            on_failed = getattr(self.operator, 
"get_openlineage_facets_on_failure", None)
-            if on_failed and callable(on_failed):
-                return self._get_openlineage_facets(on_failed, task_instance)
-        on_complete = getattr(self.operator, 
"get_openlineage_facets_on_complete", None)
-        if on_complete and callable(on_complete):
-            return self._get_openlineage_facets(on_complete, task_instance)
-        return self.extract()
+            with suppress(AttributeError):
+                return self._get_openlineage_facets(
+                    self.operator.get_openlineage_facets_on_failure, 
task_instance
+                )
+        with suppress(AttributeError):
+            return self._get_openlineage_facets(
+                self.operator.get_openlineage_facets_on_complete, task_instance
+            )
+        with suppress(AttributeError):
+            return 
self._get_openlineage_facets(self.operator.get_openlineage_facets_on_start, 
task_instance)
+        return None
 
     def _get_openlineage_facets(self, get_facets_method, *args) -> 
OperatorLineage | None:
         try:
diff --git a/airflow/providers/openlineage/extractors/bash.py b/airflow/providers/openlineage/extractors/bash.py
index cfb022032e..c4db7790d7 100644
--- a/airflow/providers/openlineage/extractors/bash.py
+++ b/airflow/providers/openlineage/extractors/bash.py
@@ -25,13 +25,14 @@ from airflow.providers.openlineage.plugins.facets import (
     UnknownOperatorInstance,
 )
 from airflow.providers.openlineage.utils.utils import 
get_filtered_unknown_operator_keys, is_source_enabled
+from airflow.utils.openlineage_mixin import OpenLineageMixin
 
 """
 :meta private:
 """
 
 
-class BashExtractor(BaseExtractor):
+class BashExtractor(BaseExtractor, OpenLineageMixin):
     """
     Extract executed bash command and put it into SourceCodeJobFacet.
 
diff --git a/airflow/providers/openlineage/extractors/manager.py b/airflow/providers/openlineage/extractors/manager.py
index 02a4124840..8f293ae545 100644
--- a/airflow/providers/openlineage/extractors/manager.py
+++ b/airflow/providers/openlineage/extractors/manager.py
@@ -32,6 +32,7 @@ from airflow.providers.openlineage.plugins.facets import (
 from airflow.providers.openlineage.utils.utils import 
get_filtered_unknown_operator_keys
 from airflow.utils.log.logging_mixin import LoggingMixin
 from airflow.utils.module_loading import import_string
+from airflow.utils.openlineage_mixin import OpenLineageMixin
 
 if TYPE_CHECKING:
     from airflow.models import Operator
@@ -135,13 +136,9 @@ class ExtractorManager(LoggingMixin):
         if task.task_type in self.extractors:
             return self.extractors[task.task_type]
 
-        def method_exists(method_name):
-            method = getattr(task, method_name, None)
-            if method:
-                return callable(method)
-
-        if method_exists("get_openlineage_facets_on_start") or method_exists(
-            "get_openlineage_facets_on_complete"
+        # Also accept older OpenLineage implementations that define the methods without inheriting OpenLineageMixin
+        if isinstance(task, OpenLineageMixin) or 
self._is_backwards_compatible_openlineage_implementation(
+            task
         ):
             return self.default_extractor
         return None
@@ -199,3 +196,12 @@ class ExtractorManager(LoggingMixin):
         except AttributeError:
             self.log.error("Extractor returns non-valid metadata: %s", 
task_metadata)
             return None
+
+    @staticmethod
+    def _is_backwards_compatible_openlineage_implementation(task) -> bool:
+        # Defining any one of these methods counts as a valid OpenLineage implementation.
+        return (
+            hasattr(task, "get_openlineage_facets_on_start")
+            or hasattr(task, "get_openlineage_facets_on_complete")
+            or hasattr(task, "get_openlineage_facets_on_failure")
+        )
diff --git a/airflow/providers/openlineage/extractors/python.py b/airflow/providers/openlineage/extractors/python.py
index 017e6488cc..f0edad03d2 100644
--- a/airflow/providers/openlineage/extractors/python.py
+++ b/airflow/providers/openlineage/extractors/python.py
@@ -28,13 +28,14 @@ from airflow.providers.openlineage.plugins.facets import (
     UnknownOperatorInstance,
 )
 from airflow.providers.openlineage.utils.utils import 
get_filtered_unknown_operator_keys, is_source_enabled
+from airflow.utils.openlineage_mixin import OpenLineageMixin
 
 """
 :meta private:
 """
 
 
-class PythonExtractor(BaseExtractor):
+class PythonExtractor(BaseExtractor, OpenLineageMixin):
     """
     Extract executed source code and put it into SourceCodeJobFacet.
 
diff --git a/airflow/providers/openlineage/plugins/listener.py b/airflow/providers/openlineage/plugins/listener.py
index 99394863f5..516c8456bb 100644
--- a/airflow/providers/openlineage/plugins/listener.py
+++ b/airflow/providers/openlineage/plugins/listener.py
@@ -42,9 +42,15 @@ class OpenLineageListener:
 
     def __init__(self):
         self.log = logging.getLogger(__name__)
-        self.executor: Executor = None  # type: ignore
         self.extractor_manager = ExtractorManager()
         self.adapter = OpenLineageAdapter()
+        self._executor: Executor | None = None
+
+    @property
+    def executor(self) -> Executor:
+        if self._executor is None:
+            self._executor = ThreadPoolExecutor(max_workers=8, 
thread_name_prefix="openlineage_")
+        return self._executor
 
     @hookimpl
     def on_task_instance_running(
@@ -151,7 +157,7 @@ class OpenLineageListener:
     @hookimpl
     def on_starting(self, component):
         self.log.debug("on_starting: %s", component.__class__.__name__)
-        self.executor = ThreadPoolExecutor(max_workers=8, 
thread_name_prefix="openlineage_")
+        self.executor
 
     @hookimpl
     def before_stopping(self, component):
diff --git a/airflow/providers/sftp/operators/sftp.py b/airflow/providers/sftp/operators/sftp.py
index 8da4b3f332..bba92e8e9f 100644
--- a/airflow/providers/sftp/operators/sftp.py
+++ b/airflow/providers/sftp/operators/sftp.py
@@ -30,6 +30,7 @@ from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarni
 from airflow.models import BaseOperator
 from airflow.providers.sftp.hooks.sftp import SFTPHook
 from airflow.providers.ssh.hooks.ssh import SSHHook
+from airflow.utils.openlineage_mixin import OpenLineageMixin
 
 
 class SFTPOperation:
@@ -39,7 +40,7 @@ class SFTPOperation:
     GET = "get"
 
 
-class SFTPOperator(BaseOperator):
+class SFTPOperator(BaseOperator, OpenLineageMixin):
     """
     SFTPOperator for transferring files from remote host to local or vice a 
versa.
 
diff --git a/airflow/utils/openlineage_mixin.py b/airflow/utils/openlineage_mixin.py
new file mode 100644
index 0000000000..8060b0f7c6
--- /dev/null
+++ b/airflow/utils/openlineage_mixin.py
@@ -0,0 +1,39 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import typing
+
+if typing.TYPE_CHECKING:
+    from airflow.models import TaskInstance
+    from airflow.providers.openlineage.extractors import OperatorLineage
+
+
+class OpenLineageMixin:
+    """
+    Marks an operator as implementing the OpenLineage methods, so callers can check for
+    this mixin instead of probing BaseOperator subclasses for individual methods.
+    """
+
+    def get_openlineage_facets_on_start(self) -> OperatorLineage | None:
+        raise NotImplementedError()
+
+    def get_openlineage_facets_on_complete(self, task_instance: TaskInstance) 
-> OperatorLineage | None:
+        return self.get_openlineage_facets_on_start()
+
+    def get_openlineage_facets_on_fail(self, task_instance: TaskInstance) -> 
OperatorLineage | None:
+        return self.get_openlineage_facets_on_complete(task_instance)
diff --git a/tests/providers/openlineage/extractors/test_default_extractor.py b/tests/providers/openlineage/extractors/test_default_extractor.py
index 41d477aa9c..68be86e29e 100644
--- a/tests/providers/openlineage/extractors/test_default_extractor.py
+++ b/tests/providers/openlineage/extractors/test_default_extractor.py
@@ -31,6 +31,7 @@ from airflow.providers.openlineage.extractors.base import (
 )
 from airflow.providers.openlineage.extractors.manager import ExtractorManager
 from airflow.providers.openlineage.extractors.python import PythonExtractor
+from airflow.utils.openlineage_mixin import OpenLineageMixin
 
 INPUTS = [Dataset(namespace="database://host:port", name="inputtable")]
 OUTPUTS = [Dataset(namespace="database://host:port", name="inputtable")]
@@ -48,7 +49,7 @@ class CompleteRunFacet(BaseFacet):
 FINISHED_FACETS = {"complete": CompleteRunFacet(True)}
 
 
-class ExampleOperator(BaseOperator):
+class ExampleOperator(BaseOperator, OpenLineageMixin):
     def execute(self, context) -> Any:
         pass
 
@@ -69,7 +70,7 @@ class ExampleOperator(BaseOperator):
         )
 
 
-class OperatorWithoutComplete(BaseOperator):
+class OperatorWithoutComplete(BaseOperator, OpenLineageMixin):
     def execute(self, context) -> Any:
         pass
 
@@ -82,7 +83,7 @@ class OperatorWithoutComplete(BaseOperator):
         )
 
 
-class OperatorWithoutStart(BaseOperator):
+class OperatorWithoutStart(BaseOperator, OpenLineageMixin):
     def execute(self, context) -> Any:
         pass
 
@@ -95,7 +96,7 @@ class OperatorWithoutStart(BaseOperator):
         )
 
 
-class OperatorDifferentOperatorLineageClass(BaseOperator):
+class OperatorDifferentOperatorLineageClass(BaseOperator, OpenLineageMixin):
     def execute(self, context) -> Any:
         pass
 
@@ -119,7 +120,7 @@ class OperatorDifferentOperatorLineageClass(BaseOperator):
         )
 
 
-class OperatorWrongOperatorLineageClass(BaseOperator):
+class OperatorWrongOperatorLineageClass(BaseOperator, OpenLineageMixin):
     def execute(self, context) -> Any:
         pass
 
@@ -137,15 +138,28 @@ class OperatorWrongOperatorLineageClass(BaseOperator):
         )
 
 
-class BrokenOperator(BaseOperator):
+class BrokenOperator:
     get_openlineage_facets = []
 
     def execute(self, context) -> Any:
         pass
 
 
+class OperatorWithoutMixinButProperClass(BaseOperator):
+    def execute(self, context) -> Any:
+        pass
+
+    def get_openlineage_facets_on_start(self) -> OperatorLineage:
+        return OperatorLineage(
+            inputs=INPUTS,
+            outputs=OUTPUTS,
+            run_facets=RUN_FACETS,
+            job_facets=JOB_FACETS,
+        )
+
+
 def test_default_extraction():
-    extractor = ExtractorManager().get_extractor_class(ExampleOperator)
+    extractor = 
ExtractorManager().get_extractor_class(ExampleOperator(task_id="test"))
     assert extractor is DefaultExtractor
 
     metadata = extractor(ExampleOperator(task_id="test")).extract()
@@ -172,7 +186,7 @@ def test_default_extraction():
 
 
 def test_extraction_without_on_complete():
-    extractor = ExtractorManager().get_extractor_class(OperatorWithoutComplete)
+    extractor = 
ExtractorManager().get_extractor_class(OperatorWithoutComplete(task_id="test"))
     assert extractor is DefaultExtractor
 
     metadata = extractor(OperatorWithoutComplete(task_id="test")).extract()
@@ -196,7 +210,7 @@ def test_extraction_without_on_complete():
 
 
 def test_extraction_without_on_start():
-    extractor = ExtractorManager().get_extractor_class(OperatorWithoutStart)
+    extractor = 
ExtractorManager().get_extractor_class(OperatorWithoutStart(task_id="test"))
     assert extractor is DefaultExtractor
 
     metadata = extractor(OperatorWithoutStart(task_id="test")).extract()
@@ -217,16 +231,6 @@ def test_extraction_without_on_start():
     )
 
 
-def test_does_not_use_default_extractor_when_not_a_method():
-    extractor_class = 
ExtractorManager().get_extractor_class(BrokenOperator(task_id="a"))
-    assert extractor_class is None
-
-
-def test_does_not_use_default_extractor_when_no_get_openlineage_facets():
-    extractor_class = 
ExtractorManager().get_extractor_class(BaseOperator(task_id="b"))
-    assert extractor_class is None
-
-
 def test_does_not_use_default_extractor_when_explicite_extractor():
     extractor_class = ExtractorManager().get_extractor_class(
         PythonOperator(task_id="c", python_callable=lambda: 7)
@@ -254,3 +258,16 @@ def test_default_extractor_uses_wrong_operatorlineage_class():
     assert (
         ExtractorManager().extract_metadata(mock.MagicMock(), operator, 
complete=False) == OperatorLineage()
     )
+
+
+def test_default_extractor_works_without_mixin():
+    operator = OperatorWithoutMixinButProperClass(task_id="task_id")
+    extractor_class = ExtractorManager().get_extractor_class(operator)
+    assert extractor_class is DefaultExtractor
+    extractor = extractor_class(operator)
+    assert extractor.extract() == OperatorLineage(
+        inputs=INPUTS,
+        outputs=OUTPUTS,
+        run_facets=RUN_FACETS,
+        job_facets=JOB_FACETS,
+    )

Reply via email to