This is an automated email from the ASF dual-hosted git repository.
mobuchowski pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 807bdca29c6 fix: Adjust OpenLineage DefaultExtractor for RuntimeTaskInstance in Airflow 3 (#47673)
807bdca29c6 is described below
commit 807bdca29c634a04be85637902db680f567f8e73
Author: Kacper Muda <[email protected]>
AuthorDate: Mon Mar 17 22:32:51 2025 +0100
fix: Adjust OpenLineage DefaultExtractor for RuntimeTaskInstance in Airflow 3 (#47673)
---
.../providers/openlineage/extractors/base.py | 74 ++++++------
.../providers/openlineage/extractors/manager.py | 33 ++++--
.../providers/openlineage/plugins/adapter.py | 2 +-
.../providers/openlineage/plugins/listener.py | 14 ++-
.../tests/unit/openlineage/extractors/test_base.py | 128 ++++++++++++++-------
.../unit/openlineage/extractors/test_manager.py | 8 +-
6 files changed, 161 insertions(+), 98 deletions(-)
diff --git
a/providers/openlineage/src/airflow/providers/openlineage/extractors/base.py
b/providers/openlineage/src/airflow/providers/openlineage/extractors/base.py
index 2b85825d8c6..b5f8a93f20d 100644
--- a/providers/openlineage/src/airflow/providers/openlineage/extractors/base.py
+++ b/providers/openlineage/src/airflow/providers/openlineage/extractors/base.py
@@ -29,14 +29,16 @@ with warnings.catch_warnings():
from openlineage.client.facet import BaseFacet as BaseFacet_V1
from openlineage.client.facet_v2 import JobFacet, RunFacet
-from airflow.providers.openlineage.utils.utils import AIRFLOW_V_2_10_PLUS
from airflow.utils.log.logging_mixin import LoggingMixin
-from airflow.utils.state import TaskInstanceState
# this is not to break static checks compatibility with v1 OpenLineage facet
classes
DatasetSubclass = TypeVar("DatasetSubclass", bound=OLDataset)
BaseFacetSubclass = TypeVar("BaseFacetSubclass", bound=Union[BaseFacet_V1,
RunFacet, JobFacet])
+OL_METHOD_NAME_START = "get_openlineage_facets_on_start"
+OL_METHOD_NAME_COMPLETE = "get_openlineage_facets_on_complete"
+OL_METHOD_NAME_FAIL = "get_openlineage_facets_on_failure"
+
@define
class OperatorLineage(Generic[DatasetSubclass, BaseFacetSubclass]):
@@ -81,6 +83,9 @@ class BaseExtractor(ABC, LoggingMixin):
def extract_on_complete(self, task_instance) -> OperatorLineage | None:
return self.extract()
+ def extract_on_failure(self, task_instance) -> OperatorLineage | None:
+ return self.extract()
+
class DefaultExtractor(BaseExtractor):
"""Extractor that uses `get_openlineage_facets_on_start/complete/failure`
methods."""
@@ -96,46 +101,41 @@ class DefaultExtractor(BaseExtractor):
return []
def _execute_extraction(self) -> OperatorLineage | None:
- # OpenLineage methods are optional - if there's no method, return None
- try:
- self.log.debug(
- "Trying to execute `get_openlineage_facets_on_start` for %s.",
self.operator.task_type
- )
- return
self._get_openlineage_facets(self.operator.get_openlineage_facets_on_start) #
type: ignore
- except ImportError:
- self.log.error(
- "OpenLineage provider method failed to import OpenLineage
integration. "
- "This should not happen. Please report this bug to developers."
- )
- return None
- except AttributeError:
+ method = getattr(self.operator, OL_METHOD_NAME_START, None)
+ if callable(method):
self.log.debug(
- "Operator %s does not have the get_openlineage_facets_on_start
method.",
- self.operator.task_type,
+ "Trying to execute '%s' method of '%s'.",
OL_METHOD_NAME_START, self.operator.task_type
)
- return OperatorLineage()
+ return self._get_openlineage_facets(method)
+ self.log.debug(
+ "Operator '%s' does not have '%s' method.",
self.operator.task_type, OL_METHOD_NAME_START
+ )
+ return OperatorLineage()
def extract_on_complete(self, task_instance) -> OperatorLineage | None:
- failed_states = [TaskInstanceState.FAILED,
TaskInstanceState.UP_FOR_RETRY]
- if not AIRFLOW_V_2_10_PLUS: # todo: remove when min airflow version
>= 2.10.0
- # Before fix (#41053) implemented in Airflow 2.10 TaskInstance's
state was still RUNNING when
- # being passed to listener's on_failure method. Since
`extract_on_complete()` is only called
- # after task completion, RUNNING state means that we are dealing
with FAILED task in < 2.10
- failed_states = [TaskInstanceState.RUNNING]
-
- if task_instance.state in failed_states:
- on_failed = getattr(self.operator,
"get_openlineage_facets_on_failure", None)
- if on_failed and callable(on_failed):
- self.log.debug(
- "Executing `get_openlineage_facets_on_failure` for %s.",
self.operator.task_type
- )
- return self._get_openlineage_facets(on_failed, task_instance)
- on_complete = getattr(self.operator,
"get_openlineage_facets_on_complete", None)
- if on_complete and callable(on_complete):
- self.log.debug("Executing `get_openlineage_facets_on_complete` for
%s.", self.operator.task_type)
- return self._get_openlineage_facets(on_complete, task_instance)
+ method = getattr(self.operator, OL_METHOD_NAME_COMPLETE, None)
+ if callable(method):
+ self.log.debug(
+ "Trying to execute '%s' method of '%s'.",
OL_METHOD_NAME_COMPLETE, self.operator.task_type
+ )
+ return self._get_openlineage_facets(method, task_instance)
+ self.log.debug(
+ "Operator '%s' does not have '%s' method.",
self.operator.task_type, OL_METHOD_NAME_COMPLETE
+ )
return self.extract()
+ def extract_on_failure(self, task_instance) -> OperatorLineage | None:
+ method = getattr(self.operator, OL_METHOD_NAME_FAIL, None)
+ if callable(method):
+ self.log.debug(
+ "Trying to execute '%s' method of '%s'.", OL_METHOD_NAME_FAIL,
self.operator.task_type
+ )
+ return self._get_openlineage_facets(method, task_instance)
+ self.log.debug(
+ "Operator '%s' does not have '%s' method.",
self.operator.task_type, OL_METHOD_NAME_FAIL
+ )
+ return self.extract_on_complete(task_instance)
+
def _get_openlineage_facets(self, get_facets_method, *args) ->
OperatorLineage | None:
try:
facets: OperatorLineage = get_facets_method(*args)
@@ -153,5 +153,5 @@ class DefaultExtractor(BaseExtractor):
"This should not happen."
)
except Exception:
- self.log.warning("OpenLineage provider method failed to extract
data from provider. ")
+ self.log.warning("OpenLineage provider method failed to extract
data from provider.")
return None
diff --git
a/providers/openlineage/src/airflow/providers/openlineage/extractors/manager.py
b/providers/openlineage/src/airflow/providers/openlineage/extractors/manager.py
index f07014885ea..964616382f0 100644
---
a/providers/openlineage/src/airflow/providers/openlineage/extractors/manager.py
+++
b/providers/openlineage/src/airflow/providers/openlineage/extractors/manager.py
@@ -24,7 +24,11 @@ from airflow.providers.common.compat.openlineage.utils.utils
import (
)
from airflow.providers.openlineage import conf
from airflow.providers.openlineage.extractors import BaseExtractor,
OperatorLineage
-from airflow.providers.openlineage.extractors.base import DefaultExtractor
+from airflow.providers.openlineage.extractors.base import (
+ OL_METHOD_NAME_COMPLETE,
+ OL_METHOD_NAME_START,
+ DefaultExtractor,
+)
from airflow.providers.openlineage.extractors.bash import BashExtractor
from airflow.providers.openlineage.extractors.python import PythonExtractor
from airflow.providers.openlineage.utils.utils import (
@@ -32,6 +36,7 @@ from airflow.providers.openlineage.utils.utils import (
try_import_from_string,
)
from airflow.utils.log.logging_mixin import LoggingMixin
+from airflow.utils.state import TaskInstanceState
if TYPE_CHECKING:
from openlineage.client.event_v2 import Dataset
@@ -87,7 +92,9 @@ class ExtractorManager(LoggingMixin):
def add_extractor(self, operator_class: str, extractor:
type[BaseExtractor]):
self.extractors[operator_class] = extractor
- def extract_metadata(self, dagrun, task, complete: bool = False,
task_instance=None) -> OperatorLineage:
+ def extract_metadata(
+ self, dagrun, task, task_instance_state: TaskInstanceState,
task_instance=None
+ ) -> OperatorLineage:
extractor = self._get_extractor(task)
task_info = (
f"task_type={task.task_type} "
@@ -104,10 +111,15 @@ class ExtractorManager(LoggingMixin):
extractor.__class__.__name__,
str(task_info),
)
- if complete:
- task_metadata =
extractor.extract_on_complete(task_instance)
- else:
+ if task_instance_state == TaskInstanceState.RUNNING:
task_metadata = extractor.extract()
+ elif task_instance_state == TaskInstanceState.FAILED:
+ if callable(getattr(extractor, "extract_on_failure",
None)):
+ task_metadata =
extractor.extract_on_failure(task_instance)
+ else:
+ task_metadata =
extractor.extract_on_complete(task_instance)
+ else:
+ task_metadata =
extractor.extract_on_complete(task_instance)
self.log.debug(
"Found task metadata for operation %s: %s",
@@ -155,13 +167,9 @@ class ExtractorManager(LoggingMixin):
return self.extractors[task.task_type]
def method_exists(method_name):
- method = getattr(task, method_name, None)
- if method:
- return callable(method)
+ return callable(getattr(task, method_name, None))
- if method_exists("get_openlineage_facets_on_start") or method_exists(
- "get_openlineage_facets_on_complete"
- ):
+ if method_exists(OL_METHOD_NAME_START) or
method_exists(OL_METHOD_NAME_COMPLETE):
return self.default_extractor
return None
@@ -191,7 +199,8 @@ class ExtractorManager(LoggingMixin):
if d:
task_metadata.outputs.append(d)
- def get_hook_lineage(self) -> tuple[list[Dataset], list[Dataset]] | None:
+ @staticmethod
+ def get_hook_lineage() -> tuple[list[Dataset], list[Dataset]] | None:
try:
from airflow.providers.common.compat.lineage.hook import (
get_hook_lineage_collector,
diff --git
a/providers/openlineage/src/airflow/providers/openlineage/plugins/adapter.py
b/providers/openlineage/src/airflow/providers/openlineage/plugins/adapter.py
index 8350b2a0d51..a7c7d5e1f9f 100644
--- a/providers/openlineage/src/airflow/providers/openlineage/plugins/adapter.py
+++ b/providers/openlineage/src/airflow/providers/openlineage/plugins/adapter.py
@@ -85,7 +85,7 @@ class OpenLineageAdapter(LoggingMixin):
if config:
self.log.debug(
"OpenLineage configuration found. Transport type: `%s`",
- config.get("type", "no type provided"),
+ config.get("transport", {}).get("type", "no type
provided"),
)
self._client = OpenLineageClient(config=config) # type:
ignore[call-arg]
else:
diff --git
a/providers/openlineage/src/airflow/providers/openlineage/plugins/listener.py
b/providers/openlineage/src/airflow/providers/openlineage/plugins/listener.py
index 3af06538ce7..b46c65de64a 100644
---
a/providers/openlineage/src/airflow/providers/openlineage/plugins/listener.py
+++
b/providers/openlineage/src/airflow/providers/openlineage/plugins/listener.py
@@ -200,7 +200,9 @@ class OpenLineageListener:
operator_name = task.task_type.lower()
with Stats.timer(f"ol.extract.{event_type}.{operator_name}"):
- task_metadata =
self.extractor_manager.extract_metadata(dagrun, task)
+ task_metadata = self.extractor_manager.extract_metadata(
+ dagrun=dagrun, task=task,
task_instance_state=TaskInstanceState.RUNNING
+ )
redacted_event = self.adapter.start_task(
run_id=task_uuid,
@@ -303,7 +305,10 @@ class OpenLineageListener:
with Stats.timer(f"ol.extract.{event_type}.{operator_name}"):
task_metadata = self.extractor_manager.extract_metadata(
- dagrun, task, complete=True, task_instance=task_instance
+ dagrun=dagrun,
+ task=task,
+ task_instance_state=TaskInstanceState.SUCCESS,
+ task_instance=task_instance,
)
redacted_event = self.adapter.complete_task(
@@ -424,7 +429,10 @@ class OpenLineageListener:
with Stats.timer(f"ol.extract.{event_type}.{operator_name}"):
task_metadata = self.extractor_manager.extract_metadata(
- dagrun, task, complete=True, task_instance=task_instance
+ dagrun=dagrun,
+ task=task,
+ task_instance_state=TaskInstanceState.FAILED,
+ task_instance=task_instance,
)
redacted_event = self.adapter.fail_task(
diff --git
a/providers/openlineage/tests/unit/openlineage/extractors/test_base.py
b/providers/openlineage/tests/unit/openlineage/extractors/test_base.py
index c85f75b3751..a120537a414 100644
--- a/providers/openlineage/tests/unit/openlineage/extractors/test_base.py
+++ b/providers/openlineage/tests/unit/openlineage/extractors/test_base.py
@@ -56,16 +56,43 @@ class CompleteRunFacet(JobFacet):
finished: bool = field(default=False)
+@define
+class FailRunFacet(JobFacet):
+ failed: bool = field(default=False)
+
+
FINISHED_FACETS: dict[str, JobFacet] = {"complete": CompleteRunFacet(True)}
+FAILED_FACETS: dict[str, JobFacet] = {"failure": FailRunFacet(True)}
class ExampleExtractor(BaseExtractor):
@classmethod
def get_operator_classnames(cls):
- return ["ExampleOperator"]
+ return ["OperatorWithoutFailure"]
+
+
+class OperatorWithoutFailure(BaseOperator):
+ def execute(self, context) -> Any:
+ pass
+
+ def get_openlineage_facets_on_start(self) -> OperatorLineage:
+ return OperatorLineage(
+ inputs=INPUTS,
+ outputs=OUTPUTS,
+ run_facets=RUN_FACETS,
+ job_facets=JOB_FACETS,
+ )
+
+ def get_openlineage_facets_on_complete(self, task_instance) ->
OperatorLineage:
+ return OperatorLineage(
+ inputs=INPUTS,
+ outputs=OUTPUTS,
+ run_facets=RUN_FACETS,
+ job_facets=FINISHED_FACETS,
+ )
-class ExampleOperator(BaseOperator):
+class OperatorWithAllOlMethods(BaseOperator):
def execute(self, context) -> Any:
pass
@@ -85,6 +112,14 @@ class ExampleOperator(BaseOperator):
job_facets=FINISHED_FACETS,
)
+ def get_openlineage_facets_on_failure(self, task_instance) ->
OperatorLineage:
+ return OperatorLineage(
+ inputs=INPUTS,
+ outputs=OUTPUTS,
+ run_facets=RUN_FACETS,
+ job_facets=FAILED_FACETS,
+ )
+
class OperatorWithoutComplete(BaseOperator):
def execute(self, context) -> Any:
@@ -162,14 +197,14 @@ class BrokenOperator(BaseOperator):
def test_default_extraction():
- extractor = ExtractorManager().get_extractor_class(ExampleOperator)
+ extractor = ExtractorManager().get_extractor_class(OperatorWithoutFailure)
assert extractor is DefaultExtractor
- metadata = extractor(ExampleOperator(task_id="test")).extract()
+ metadata = extractor(OperatorWithoutFailure(task_id="test")).extract()
task_instance = mock.MagicMock()
- metadata_on_complete =
extractor(ExampleOperator(task_id="test")).extract_on_complete(
+ metadata_on_complete =
extractor(OperatorWithoutFailure(task_id="test")).extract_on_complete(
task_instance=task_instance
)
@@ -235,50 +270,59 @@ def test_extraction_without_on_start():
@pytest.mark.parametrize(
- "task_state, is_airflow_2_10_or_higher, should_call_on_failure",
+ "operator_class, task_state, expected_job_facets",
(
- # Airflow >= 2.10
- (TaskInstanceState.FAILED, True, True),
- (TaskInstanceState.UP_FOR_RETRY, True, True),
- (TaskInstanceState.RUNNING, True, False),
- (TaskInstanceState.SUCCESS, True, False),
- # Airflow < 2.10
- (TaskInstanceState.RUNNING, False, True),
- (TaskInstanceState.SUCCESS, False, False),
- (TaskInstanceState.FAILED, False, False), # should never happen,
fixed in #41053
- (TaskInstanceState.UP_FOR_RETRY, False, False), # should never
happen, fixed in #41053
+ (OperatorWithAllOlMethods, TaskInstanceState.FAILED, FAILED_FACETS),
+ (OperatorWithAllOlMethods, TaskInstanceState.RUNNING, JOB_FACETS),
+ (OperatorWithAllOlMethods, TaskInstanceState.SUCCESS, FINISHED_FACETS),
+ (OperatorWithAllOlMethods, TaskInstanceState.UP_FOR_RETRY,
FINISHED_FACETS), # Should never happen
+ (OperatorWithAllOlMethods, None, FINISHED_FACETS), # Should never
happen
+ (OperatorWithoutFailure, TaskInstanceState.FAILED, FINISHED_FACETS),
+ (OperatorWithoutFailure, TaskInstanceState.RUNNING, JOB_FACETS),
+ (OperatorWithoutFailure, TaskInstanceState.SUCCESS, FINISHED_FACETS),
+ (OperatorWithoutFailure, TaskInstanceState.UP_FOR_RETRY,
FINISHED_FACETS), # Should never happen
+ (OperatorWithoutFailure, None, FINISHED_FACETS), # Should never happen
+ (OperatorWithoutStart, TaskInstanceState.FAILED, FINISHED_FACETS),
+ (OperatorWithoutStart, TaskInstanceState.RUNNING, {}),
+ (OperatorWithoutStart, TaskInstanceState.SUCCESS, FINISHED_FACETS),
+ (OperatorWithoutStart, TaskInstanceState.UP_FOR_RETRY,
FINISHED_FACETS), # Should never happen
+ (OperatorWithoutStart, None, FINISHED_FACETS), # Should never happen
+ (OperatorWithoutComplete, TaskInstanceState.FAILED, JOB_FACETS),
+ (OperatorWithoutComplete, TaskInstanceState.RUNNING, JOB_FACETS),
+ (OperatorWithoutComplete, TaskInstanceState.SUCCESS, JOB_FACETS),
+ (OperatorWithoutComplete, TaskInstanceState.UP_FOR_RETRY, JOB_FACETS),
# Should never happen
+ (OperatorWithoutComplete, None, JOB_FACETS), # Should never happen
),
)
-def test_extract_on_failure(task_state, is_airflow_2_10_or_higher,
should_call_on_failure):
- task_instance = mock.Mock(state=task_state)
- operator = mock.Mock()
- operator.get_openlineage_facets_on_failure = mock.Mock(
- return_value=OperatorLineage(run_facets={"failed": True})
+def test_extractor_manager_calls_appropriate_extractor_method(
+ operator_class, task_state, expected_job_facets
+):
+ extractor_manager = ExtractorManager()
+
+ ti = mock.MagicMock()
+
+ metadata = extractor_manager.extract_metadata(
+ dagrun=mock.MagicMock(run_id="dagrun_run_id"),
+ task=operator_class(task_id="task_id"),
+ task_instance_state=task_state,
+ task_instance=ti,
)
- operator.get_openlineage_facets_on_complete = mock.Mock(return_value=None)
-
- extractor = DefaultExtractor(operator=operator)
- with mock.patch(
- "airflow.providers.openlineage.extractors.base.AIRFLOW_V_2_10_PLUS",
is_airflow_2_10_or_higher
- ):
- result = extractor.extract_on_complete(task_instance)
-
- if should_call_on_failure:
-
operator.get_openlineage_facets_on_failure.assert_called_once_with(task_instance)
- operator.get_openlineage_facets_on_complete.assert_not_called()
- assert isinstance(result, OperatorLineage)
- assert result.run_facets == {"failed": True}
- else:
- operator.get_openlineage_facets_on_failure.assert_not_called()
-
operator.get_openlineage_facets_on_complete.assert_called_once_with(task_instance)
- assert result is None
+ assert metadata.job_facets == expected_job_facets
+ if not expected_job_facets: # Empty OperatorLineage() is expected
+ assert not metadata.inputs
+ assert not metadata.outputs
+ assert not metadata.run_facets
+ else:
+ assert metadata.inputs == INPUTS
+ assert metadata.outputs == OUTPUTS
+ assert metadata.run_facets == RUN_FACETS
@mock.patch("airflow.providers.openlineage.conf.custom_extractors")
def test_extractors_env_var(custom_extractors):
custom_extractors.return_value =
{"unit.openlineage.extractors.test_base.ExampleExtractor"}
- extractor =
ExtractorManager().get_extractor_class(ExampleOperator(task_id="example"))
+ extractor =
ExtractorManager().get_extractor_class(OperatorWithoutFailure(task_id="example"))
assert extractor is ExampleExtractor
@@ -292,7 +336,7 @@ def
test_does_not_use_default_extractor_when_no_get_openlineage_facets():
assert extractor_class is None
-def test_does_not_use_default_extractor_when_explicite_extractor():
+def test_does_not_use_default_extractor_when_explicit_extractor():
extractor_class = ExtractorManager().get_extractor_class(
PythonOperator(task_id="c", python_callable=lambda: 7)
)
@@ -316,6 +360,4 @@ def
test_default_extractor_uses_wrong_operatorlineage_class():
operator = OperatorWrongOperatorLineageClass(task_id="task_id")
# If extractor returns lineage class that can't be changed into
OperatorLineage, just return
# empty OperatorLineage
- assert (
- ExtractorManager().extract_metadata(mock.MagicMock(), operator,
complete=False) == OperatorLineage()
- )
+ assert ExtractorManager().extract_metadata(mock.MagicMock(), operator,
None) == OperatorLineage()
diff --git
a/providers/openlineage/tests/unit/openlineage/extractors/test_manager.py
b/providers/openlineage/tests/unit/openlineage/extractors/test_manager.py
index 256bc133f1a..04739f22633 100644
--- a/providers/openlineage/tests/unit/openlineage/extractors/test_manager.py
+++ b/providers/openlineage/tests/unit/openlineage/extractors/test_manager.py
@@ -293,7 +293,9 @@ def
test_extractor_manager_uses_hook_level_lineage(hook_lineage_collector):
hook_lineage_collector.add_input_asset(None, uri="s3://bucket/input_key")
hook_lineage_collector.add_output_asset(None, uri="s3://bucket/output_key")
extractor_manager = ExtractorManager()
- metadata = extractor_manager.extract_metadata(dagrun=dagrun, task=task,
complete=True, task_instance=ti)
+ metadata = extractor_manager.extract_metadata(
+ dagrun=dagrun, task=task, task_instance_state=None, task_instance=ti
+ )
assert metadata.inputs == [OpenLineageDataset(namespace="s3://bucket",
name="input_key")]
assert metadata.outputs == [OpenLineageDataset(namespace="s3://bucket",
name="output_key")]
@@ -318,7 +320,9 @@ def
test_extractor_manager_does_not_use_hook_level_lineage_when_operator(
hook_lineage_collector.add_input_asset(None, uri="s3://bucket/input_key")
extractor_manager = ExtractorManager()
- metadata = extractor_manager.extract_metadata(dagrun=dagrun, task=task,
complete=True, task_instance=ti)
+ metadata = extractor_manager.extract_metadata(
+ dagrun=dagrun, task=task, task_instance_state=None, task_instance=ti
+ )
# s3://bucket/input_key not here - use data from operator
assert metadata.inputs == [OpenLineageDataset(namespace="s3://bucket",
name="proper_input_key")]