This is an automated email from the ASF dual-hosted git repository. mobuchowski pushed a commit to branch openlineage-interface in repository https://gitbox.apache.org/repos/asf/airflow.git
commit 220970eefbdb10a2b1bd34c8f6e11173084971df Author: Maciej Obuchowski <[email protected]> AuthorDate: Mon Jul 31 14:38:10 2023 +0200 add OpenLineage methods as mixin interface to BaseOperator Signed-off-by: Maciej Obuchowski <[email protected]> --- airflow/providers/common/sql/operators/sql.py | 3 +- airflow/providers/ftp/operators/ftp.py | 3 +- .../providers/google/cloud/transfers/gcs_to_gcs.py | 3 +- airflow/providers/openlineage/extractors/base.py | 29 ++++++++---- airflow/providers/openlineage/extractors/bash.py | 3 +- .../providers/openlineage/extractors/manager.py | 20 +++++--- airflow/providers/openlineage/extractors/python.py | 3 +- airflow/providers/openlineage/plugins/listener.py | 10 +++- airflow/providers/sftp/operators/sftp.py | 3 +- airflow/utils/openlineage_mixin.py | 39 +++++++++++++++ .../extractors/test_default_extractor.py | 55 ++++++++++++++-------- 11 files changed, 127 insertions(+), 44 deletions(-) diff --git a/airflow/providers/common/sql/operators/sql.py b/airflow/providers/common/sql/operators/sql.py index bf2b38f055..a3dc30fe55 100644 --- a/airflow/providers/common/sql/operators/sql.py +++ b/airflow/providers/common/sql/operators/sql.py @@ -27,6 +27,7 @@ from airflow.hooks.base import BaseHook from airflow.models import BaseOperator, SkipMixin from airflow.providers.common.sql.hooks.sql import DbApiHook, fetch_all_handler, return_single_query_results from airflow.utils.helpers import merge_dicts +from airflow.utils.openlineage_mixin import OpenLineageMixin if TYPE_CHECKING: from airflow.providers.openlineage.extractors import OperatorLineage @@ -188,7 +189,7 @@ class BaseSQLOperator(BaseOperator): raise AirflowFailException(exception_string) -class SQLExecuteQueryOperator(BaseSQLOperator): +class SQLExecuteQueryOperator(BaseSQLOperator, OpenLineageMixin): """ Executes SQL code in a specific database. 
diff --git a/airflow/providers/ftp/operators/ftp.py b/airflow/providers/ftp/operators/ftp.py index 45bccbea4c..216bdbfef2 100644 --- a/airflow/providers/ftp/operators/ftp.py +++ b/airflow/providers/ftp/operators/ftp.py @@ -27,6 +27,7 @@ from typing import Any, Sequence from airflow.models import BaseOperator from airflow.providers.ftp.hooks.ftp import FTPHook, FTPSHook +from airflow.utils.openlineage_mixin import OpenLineageMixin class FTPOperation: @@ -36,7 +37,7 @@ class FTPOperation: GET = "get" -class FTPFileTransmitOperator(BaseOperator): +class FTPFileTransmitOperator(BaseOperator, OpenLineageMixin): """ FTPFileTransmitOperator for transferring files from remote host to local or vice a versa. diff --git a/airflow/providers/google/cloud/transfers/gcs_to_gcs.py b/airflow/providers/google/cloud/transfers/gcs_to_gcs.py index d4c3bd679f..d0e8870a59 100644 --- a/airflow/providers/google/cloud/transfers/gcs_to_gcs.py +++ b/airflow/providers/google/cloud/transfers/gcs_to_gcs.py @@ -24,6 +24,7 @@ from typing import TYPE_CHECKING, Sequence from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning from airflow.models import BaseOperator from airflow.providers.google.cloud.hooks.gcs import GCSHook +from airflow.utils.openlineage_mixin import OpenLineageMixin WILDCARD = "*" @@ -31,7 +32,7 @@ if TYPE_CHECKING: from airflow.utils.context import Context -class GCSToGCSOperator(BaseOperator): +class GCSToGCSOperator(BaseOperator, OpenLineageMixin): """ Copies objects from a bucket to another, with renaming if requested. 
diff --git a/airflow/providers/openlineage/extractors/base.py b/airflow/providers/openlineage/extractors/base.py index 95d8fa6f28..43a0b584f6 100644 --- a/airflow/providers/openlineage/extractors/base.py +++ b/airflow/providers/openlineage/extractors/base.py @@ -18,6 +18,7 @@ from __future__ import annotations from abc import ABC, abstractmethod +from contextlib import suppress from attrs import Factory, define from openlineage.client.facet import BaseFacet @@ -84,20 +85,28 @@ class DefaultExtractor(BaseExtractor): def extract(self) -> OperatorLineage | None: # OpenLineage methods are optional - if there's no method, return None - try: + with suppress(AttributeError): return self._get_openlineage_facets(self.operator.get_openlineage_facets_on_start) # type: ignore - except AttributeError: - return None + return None def extract_on_complete(self, task_instance) -> OperatorLineage | None: + """ + For complete method, we want to handle on_failure and on_complete methods as priority. + If they are not implemented - which happens in older, pre-OpenLineageMixin + classes, we're falling back to on_start method. 
+ """ if task_instance.state == TaskInstanceState.FAILED: - on_failed = getattr(self.operator, "get_openlineage_facets_on_failure", None) - if on_failed and callable(on_failed): - return self._get_openlineage_facets(on_failed, task_instance) - on_complete = getattr(self.operator, "get_openlineage_facets_on_complete", None) - if on_complete and callable(on_complete): - return self._get_openlineage_facets(on_complete, task_instance) - return self.extract() + with suppress(AttributeError): + return self._get_openlineage_facets( + self.operator.get_openlineage_facets_on_failure, task_instance + ) + with suppress(AttributeError): + return self._get_openlineage_facets( + self.operator.get_openlineage_facets_on_complete, task_instance + ) + with suppress(AttributeError): + return self._get_openlineage_facets(self.operator.get_openlineage_facets_on_start, task_instance) + return None def _get_openlineage_facets(self, get_facets_method, *args) -> OperatorLineage | None: try: diff --git a/airflow/providers/openlineage/extractors/bash.py b/airflow/providers/openlineage/extractors/bash.py index cfb022032e..c4db7790d7 100644 --- a/airflow/providers/openlineage/extractors/bash.py +++ b/airflow/providers/openlineage/extractors/bash.py @@ -25,13 +25,14 @@ from airflow.providers.openlineage.plugins.facets import ( UnknownOperatorInstance, ) from airflow.providers.openlineage.utils.utils import get_filtered_unknown_operator_keys, is_source_enabled +from airflow.utils.openlineage_mixin import OpenLineageMixin """ :meta private: """ -class BashExtractor(BaseExtractor): +class BashExtractor(BaseExtractor, OpenLineageMixin): """ Extract executed bash command and put it into SourceCodeJobFacet. 
diff --git a/airflow/providers/openlineage/extractors/manager.py b/airflow/providers/openlineage/extractors/manager.py index 02a4124840..8f293ae545 100644 --- a/airflow/providers/openlineage/extractors/manager.py +++ b/airflow/providers/openlineage/extractors/manager.py @@ -32,6 +32,7 @@ from airflow.providers.openlineage.plugins.facets import ( from airflow.providers.openlineage.utils.utils import get_filtered_unknown_operator_keys from airflow.utils.log.logging_mixin import LoggingMixin from airflow.utils.module_loading import import_string +from airflow.utils.openlineage_mixin import OpenLineageMixin if TYPE_CHECKING: from airflow.models import Operator @@ -135,13 +136,9 @@ class ExtractorManager(LoggingMixin): if task.task_type in self.extractors: return self.extractors[task.task_type] - def method_exists(method_name): - method = getattr(task, method_name, None) - if method: - return callable(method) - - if method_exists("get_openlineage_facets_on_start") or method_exists( - "get_openlineage_facets_on_complete" + # We need to handle older OpenLineage implementations that do not implement OpenLineageMixin + if isinstance(task, OpenLineageMixin) or self._is_backwards_compatible_openlineage_implementation( + task ): return self.default_extractor return None @@ -199,3 +196,12 @@ class ExtractorManager(LoggingMixin): except AttributeError: self.log.error("Extractor returns non-valid metadata: %s", task_metadata) return None + + @staticmethod + def _is_backwards_compatible_openlineage_implementation(task) -> bool: + # Existence of those methods is a valid OL implementation. 
+ return ( + callable(getattr(task, "get_openlineage_facets_on_start", None)) + or callable(getattr(task, "get_openlineage_facets_on_complete", None)) + or callable(getattr(task, "get_openlineage_facets_on_failure", None)) + ) diff --git a/airflow/providers/openlineage/extractors/python.py b/airflow/providers/openlineage/extractors/python.py index 017e6488cc..f0edad03d2 100644 --- a/airflow/providers/openlineage/extractors/python.py +++ b/airflow/providers/openlineage/extractors/python.py @@ -28,13 +28,14 @@ from airflow.providers.openlineage.plugins.facets import ( UnknownOperatorInstance, ) from airflow.providers.openlineage.utils.utils import get_filtered_unknown_operator_keys, is_source_enabled +from airflow.utils.openlineage_mixin import OpenLineageMixin """ :meta private: """ -class PythonExtractor(BaseExtractor): +class PythonExtractor(BaseExtractor, OpenLineageMixin): """ Extract executed source code and put it into SourceCodeJobFacet. diff --git a/airflow/providers/openlineage/plugins/listener.py b/airflow/providers/openlineage/plugins/listener.py index 99394863f5..516c8456bb 100644 --- a/airflow/providers/openlineage/plugins/listener.py +++ b/airflow/providers/openlineage/plugins/listener.py @@ -42,9 +42,15 @@ class OpenLineageListener: def __init__(self): self.log = logging.getLogger(__name__) - self.executor: Executor = None # type: ignore self.extractor_manager = ExtractorManager() self.adapter = OpenLineageAdapter() + self._executor: Executor | None = None + + @property + def executor(self) -> Executor: + if self._executor is None: + self._executor = ThreadPoolExecutor(max_workers=8, thread_name_prefix="openlineage_") + return self._executor @hookimpl def on_task_instance_running( @@ -151,7 +157,7 @@ class OpenLineageListener: @hookimpl def on_starting(self, component): self.log.debug("on_starting: %s", component.__class__.__name__) - self.executor = ThreadPoolExecutor(max_workers=8, thread_name_prefix="openlineage_") + self.executor @hookimpl def before_stopping(self, component): diff --git
a/airflow/providers/sftp/operators/sftp.py b/airflow/providers/sftp/operators/sftp.py index 8da4b3f332..bba92e8e9f 100644 --- a/airflow/providers/sftp/operators/sftp.py +++ b/airflow/providers/sftp/operators/sftp.py @@ -30,6 +30,7 @@ from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarni from airflow.models import BaseOperator from airflow.providers.sftp.hooks.sftp import SFTPHook from airflow.providers.ssh.hooks.ssh import SSHHook +from airflow.utils.openlineage_mixin import OpenLineageMixin class SFTPOperation: @@ -39,7 +40,7 @@ class SFTPOperation: GET = "get" -class SFTPOperator(BaseOperator): +class SFTPOperator(BaseOperator, OpenLineageMixin): """ SFTPOperator for transferring files from remote host to local or vice a versa. diff --git a/airflow/utils/openlineage_mixin.py b/airflow/utils/openlineage_mixin.py new file mode 100644 index 0000000000..8060b0f7c6 --- /dev/null +++ b/airflow/utils/openlineage_mixin.py @@ -0,0 +1,39 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+from __future__ import annotations + +import typing + +if typing.TYPE_CHECKING: + from airflow.models import TaskInstance + from airflow.providers.openlineage.extractors import OperatorLineage + + +class OpenLineageMixin: + """ + This interface marks implementation of OpenLineage methods, + allowing us to check for its existence rather than existence of particular methods on BaseOperator. + """ + + def get_openlineage_facets_on_start(self) -> OperatorLineage | None: + raise NotImplementedError() + + def get_openlineage_facets_on_complete(self, task_instance: TaskInstance) -> OperatorLineage | None: + return self.get_openlineage_facets_on_start() + + def get_openlineage_facets_on_failure(self, task_instance: TaskInstance) -> OperatorLineage | None: + return self.get_openlineage_facets_on_complete(task_instance) diff --git a/tests/providers/openlineage/extractors/test_default_extractor.py b/tests/providers/openlineage/extractors/test_default_extractor.py index 41d477aa9c..68be86e29e 100644 --- a/tests/providers/openlineage/extractors/test_default_extractor.py +++ b/tests/providers/openlineage/extractors/test_default_extractor.py @@ -31,6 +31,7 @@ from airflow.providers.openlineage.extractors.base import ( ) from airflow.providers.openlineage.extractors.manager import ExtractorManager from airflow.providers.openlineage.extractors.python import PythonExtractor +from airflow.utils.openlineage_mixin import OpenLineageMixin INPUTS = [Dataset(namespace="database://host:port", name="inputtable")] OUTPUTS = [Dataset(namespace="database://host:port", name="inputtable")] @@ -48,7 +49,7 @@ class CompleteRunFacet(BaseFacet): FINISHED_FACETS = {"complete": CompleteRunFacet(True)} -class ExampleOperator(BaseOperator): +class ExampleOperator(BaseOperator, OpenLineageMixin): def execute(self, context) -> Any: pass @@ -69,7 +70,7 @@ class ExampleOperator(BaseOperator): ) -class OperatorWithoutComplete(BaseOperator): +class OperatorWithoutComplete(BaseOperator, OpenLineageMixin): def
execute(self, context) -> Any: pass @@ -82,7 +83,7 @@ class OperatorWithoutComplete(BaseOperator): ) -class OperatorWithoutStart(BaseOperator): +class OperatorWithoutStart(BaseOperator, OpenLineageMixin): def execute(self, context) -> Any: pass @@ -95,7 +96,7 @@ class OperatorWithoutStart(BaseOperator): ) -class OperatorDifferentOperatorLineageClass(BaseOperator): +class OperatorDifferentOperatorLineageClass(BaseOperator, OpenLineageMixin): def execute(self, context) -> Any: pass @@ -119,7 +120,7 @@ class OperatorDifferentOperatorLineageClass(BaseOperator): ) -class OperatorWrongOperatorLineageClass(BaseOperator): +class OperatorWrongOperatorLineageClass(BaseOperator, OpenLineageMixin): def execute(self, context) -> Any: pass @@ -137,15 +138,28 @@ class OperatorWrongOperatorLineageClass(BaseOperator): ) -class BrokenOperator(BaseOperator): +class BrokenOperator: get_openlineage_facets = [] def execute(self, context) -> Any: pass +class OperatorWithoutMixinButProperClass(BaseOperator): + def execute(self, context) -> Any: + pass + + def get_openlineage_facets_on_start(self) -> OperatorLineage: + return OperatorLineage( + inputs=INPUTS, + outputs=OUTPUTS, + run_facets=RUN_FACETS, + job_facets=JOB_FACETS, + ) + + def test_default_extraction(): - extractor = ExtractorManager().get_extractor_class(ExampleOperator) + extractor = ExtractorManager().get_extractor_class(ExampleOperator(task_id="test")) assert extractor is DefaultExtractor metadata = extractor(ExampleOperator(task_id="test")).extract() @@ -172,7 +186,7 @@ def test_default_extraction(): def test_extraction_without_on_complete(): - extractor = ExtractorManager().get_extractor_class(OperatorWithoutComplete) + extractor = ExtractorManager().get_extractor_class(OperatorWithoutComplete(task_id="test")) assert extractor is DefaultExtractor metadata = extractor(OperatorWithoutComplete(task_id="test")).extract() @@ -196,7 +210,7 @@ def test_extraction_without_on_complete(): def test_extraction_without_on_start(): - 
extractor = ExtractorManager().get_extractor_class(OperatorWithoutStart) + extractor = ExtractorManager().get_extractor_class(OperatorWithoutStart(task_id="test")) assert extractor is DefaultExtractor metadata = extractor(OperatorWithoutStart(task_id="test")).extract() @@ -217,16 +231,6 @@ def test_extraction_without_on_start(): ) -def test_does_not_use_default_extractor_when_not_a_method(): - extractor_class = ExtractorManager().get_extractor_class(BrokenOperator(task_id="a")) - assert extractor_class is None - - -def test_does_not_use_default_extractor_when_no_get_openlineage_facets(): - extractor_class = ExtractorManager().get_extractor_class(BaseOperator(task_id="b")) - assert extractor_class is None - - def test_does_not_use_default_extractor_when_explicite_extractor(): extractor_class = ExtractorManager().get_extractor_class( PythonOperator(task_id="c", python_callable=lambda: 7) @@ -254,3 +258,16 @@ def test_default_extractor_uses_wrong_operatorlineage_class(): assert ( ExtractorManager().extract_metadata(mock.MagicMock(), operator, complete=False) == OperatorLineage() ) + + +def test_default_extractor_works_without_mixin(): + operator = OperatorWithoutMixinButProperClass(task_id="task_id") + extractor_class = ExtractorManager().get_extractor_class(operator) + assert extractor_class is DefaultExtractor + extractor = extractor_class(operator) + assert extractor.extract() == OperatorLineage( + inputs=INPUTS, + outputs=OUTPUTS, + run_facets=RUN_FACETS, + job_facets=JOB_FACETS, + )
