This is an automated email from the ASF dual-hosted git repository.

mobuchowski pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 807bdca29c6 fix: Adjust OpenLineage DefaultExtractor for RuntimeTaskInstance in Airflow 3 (#47673)
807bdca29c6 is described below

commit 807bdca29c634a04be85637902db680f567f8e73
Author: Kacper Muda <[email protected]>
AuthorDate: Mon Mar 17 22:32:51 2025 +0100

    fix: Adjust OpenLineage DefaultExtractor for RuntimeTaskInstance in Airflow 3 (#47673)
---
 .../providers/openlineage/extractors/base.py       |  74 ++++++------
 .../providers/openlineage/extractors/manager.py    |  33 ++++--
 .../providers/openlineage/plugins/adapter.py       |   2 +-
 .../providers/openlineage/plugins/listener.py      |  14 ++-
 .../tests/unit/openlineage/extractors/test_base.py | 128 ++++++++++++++-------
 .../unit/openlineage/extractors/test_manager.py    |   8 +-
 6 files changed, 161 insertions(+), 98 deletions(-)

diff --git 
a/providers/openlineage/src/airflow/providers/openlineage/extractors/base.py 
b/providers/openlineage/src/airflow/providers/openlineage/extractors/base.py
index 2b85825d8c6..b5f8a93f20d 100644
--- a/providers/openlineage/src/airflow/providers/openlineage/extractors/base.py
+++ b/providers/openlineage/src/airflow/providers/openlineage/extractors/base.py
@@ -29,14 +29,16 @@ with warnings.catch_warnings():
     from openlineage.client.facet import BaseFacet as BaseFacet_V1
 from openlineage.client.facet_v2 import JobFacet, RunFacet
 
-from airflow.providers.openlineage.utils.utils import AIRFLOW_V_2_10_PLUS
 from airflow.utils.log.logging_mixin import LoggingMixin
-from airflow.utils.state import TaskInstanceState
 
 # this is not to break static checks compatibility with v1 OpenLineage facet 
classes
 DatasetSubclass = TypeVar("DatasetSubclass", bound=OLDataset)
 BaseFacetSubclass = TypeVar("BaseFacetSubclass", bound=Union[BaseFacet_V1, 
RunFacet, JobFacet])
 
+OL_METHOD_NAME_START = "get_openlineage_facets_on_start"
+OL_METHOD_NAME_COMPLETE = "get_openlineage_facets_on_complete"
+OL_METHOD_NAME_FAIL = "get_openlineage_facets_on_failure"
+
 
 @define
 class OperatorLineage(Generic[DatasetSubclass, BaseFacetSubclass]):
@@ -81,6 +83,9 @@ class BaseExtractor(ABC, LoggingMixin):
     def extract_on_complete(self, task_instance) -> OperatorLineage | None:
         return self.extract()
 
+    def extract_on_failure(self, task_instance) -> OperatorLineage | None:
+        return self.extract()
+
 
 class DefaultExtractor(BaseExtractor):
     """Extractor that uses `get_openlineage_facets_on_start/complete/failure` 
methods."""
@@ -96,46 +101,41 @@ class DefaultExtractor(BaseExtractor):
         return []
 
     def _execute_extraction(self) -> OperatorLineage | None:
-        # OpenLineage methods are optional - if there's no method, return None
-        try:
-            self.log.debug(
-                "Trying to execute `get_openlineage_facets_on_start` for %s.", 
self.operator.task_type
-            )
-            return 
self._get_openlineage_facets(self.operator.get_openlineage_facets_on_start)  # 
type: ignore
-        except ImportError:
-            self.log.error(
-                "OpenLineage provider method failed to import OpenLineage 
integration. "
-                "This should not happen. Please report this bug to developers."
-            )
-            return None
-        except AttributeError:
+        method = getattr(self.operator, OL_METHOD_NAME_START, None)
+        if callable(method):
             self.log.debug(
-                "Operator %s does not have the get_openlineage_facets_on_start 
method.",
-                self.operator.task_type,
+                "Trying to execute '%s' method of '%s'.", 
OL_METHOD_NAME_START, self.operator.task_type
             )
-            return OperatorLineage()
+            return self._get_openlineage_facets(method)
+        self.log.debug(
+            "Operator '%s' does not have '%s' method.", 
self.operator.task_type, OL_METHOD_NAME_START
+        )
+        return OperatorLineage()
 
     def extract_on_complete(self, task_instance) -> OperatorLineage | None:
-        failed_states = [TaskInstanceState.FAILED, 
TaskInstanceState.UP_FOR_RETRY]
-        if not AIRFLOW_V_2_10_PLUS:  # todo: remove when min airflow version 
>= 2.10.0
-            # Before fix (#41053) implemented in Airflow 2.10 TaskInstance's 
state was still RUNNING when
-            # being passed to listener's on_failure method. Since 
`extract_on_complete()` is only called
-            # after task completion, RUNNING state means that we are dealing 
with FAILED task in < 2.10
-            failed_states = [TaskInstanceState.RUNNING]
-
-        if task_instance.state in failed_states:
-            on_failed = getattr(self.operator, 
"get_openlineage_facets_on_failure", None)
-            if on_failed and callable(on_failed):
-                self.log.debug(
-                    "Executing `get_openlineage_facets_on_failure` for %s.", 
self.operator.task_type
-                )
-                return self._get_openlineage_facets(on_failed, task_instance)
-        on_complete = getattr(self.operator, 
"get_openlineage_facets_on_complete", None)
-        if on_complete and callable(on_complete):
-            self.log.debug("Executing `get_openlineage_facets_on_complete` for 
%s.", self.operator.task_type)
-            return self._get_openlineage_facets(on_complete, task_instance)
+        method = getattr(self.operator, OL_METHOD_NAME_COMPLETE, None)
+        if callable(method):
+            self.log.debug(
+                "Trying to execute '%s' method of '%s'.", 
OL_METHOD_NAME_COMPLETE, self.operator.task_type
+            )
+            return self._get_openlineage_facets(method, task_instance)
+        self.log.debug(
+            "Operator '%s' does not have '%s' method.", 
self.operator.task_type, OL_METHOD_NAME_COMPLETE
+        )
         return self.extract()
 
+    def extract_on_failure(self, task_instance) -> OperatorLineage | None:
+        method = getattr(self.operator, OL_METHOD_NAME_FAIL, None)
+        if callable(method):
+            self.log.debug(
+                "Trying to execute '%s' method of '%s'.", OL_METHOD_NAME_FAIL, 
self.operator.task_type
+            )
+            return self._get_openlineage_facets(method, task_instance)
+        self.log.debug(
+            "Operator '%s' does not have '%s' method.", 
self.operator.task_type, OL_METHOD_NAME_FAIL
+        )
+        return self.extract_on_complete(task_instance)
+
     def _get_openlineage_facets(self, get_facets_method, *args) -> 
OperatorLineage | None:
         try:
             facets: OperatorLineage = get_facets_method(*args)
@@ -153,5 +153,5 @@ class DefaultExtractor(BaseExtractor):
                 "This should not happen."
             )
         except Exception:
-            self.log.warning("OpenLineage provider method failed to extract 
data from provider. ")
+            self.log.warning("OpenLineage provider method failed to extract 
data from provider.")
         return None
diff --git 
a/providers/openlineage/src/airflow/providers/openlineage/extractors/manager.py 
b/providers/openlineage/src/airflow/providers/openlineage/extractors/manager.py
index f07014885ea..964616382f0 100644
--- 
a/providers/openlineage/src/airflow/providers/openlineage/extractors/manager.py
+++ 
b/providers/openlineage/src/airflow/providers/openlineage/extractors/manager.py
@@ -24,7 +24,11 @@ from airflow.providers.common.compat.openlineage.utils.utils 
import (
 )
 from airflow.providers.openlineage import conf
 from airflow.providers.openlineage.extractors import BaseExtractor, 
OperatorLineage
-from airflow.providers.openlineage.extractors.base import DefaultExtractor
+from airflow.providers.openlineage.extractors.base import (
+    OL_METHOD_NAME_COMPLETE,
+    OL_METHOD_NAME_START,
+    DefaultExtractor,
+)
 from airflow.providers.openlineage.extractors.bash import BashExtractor
 from airflow.providers.openlineage.extractors.python import PythonExtractor
 from airflow.providers.openlineage.utils.utils import (
@@ -32,6 +36,7 @@ from airflow.providers.openlineage.utils.utils import (
     try_import_from_string,
 )
 from airflow.utils.log.logging_mixin import LoggingMixin
+from airflow.utils.state import TaskInstanceState
 
 if TYPE_CHECKING:
     from openlineage.client.event_v2 import Dataset
@@ -87,7 +92,9 @@ class ExtractorManager(LoggingMixin):
     def add_extractor(self, operator_class: str, extractor: 
type[BaseExtractor]):
         self.extractors[operator_class] = extractor
 
-    def extract_metadata(self, dagrun, task, complete: bool = False, 
task_instance=None) -> OperatorLineage:
+    def extract_metadata(
+        self, dagrun, task, task_instance_state: TaskInstanceState, 
task_instance=None
+    ) -> OperatorLineage:
         extractor = self._get_extractor(task)
         task_info = (
             f"task_type={task.task_type} "
@@ -104,10 +111,15 @@ class ExtractorManager(LoggingMixin):
                     extractor.__class__.__name__,
                     str(task_info),
                 )
-                if complete:
-                    task_metadata = 
extractor.extract_on_complete(task_instance)
-                else:
+                if task_instance_state == TaskInstanceState.RUNNING:
                     task_metadata = extractor.extract()
+                elif task_instance_state == TaskInstanceState.FAILED:
+                    if callable(getattr(extractor, "extract_on_failure", 
None)):
+                        task_metadata = 
extractor.extract_on_failure(task_instance)
+                    else:
+                        task_metadata = 
extractor.extract_on_complete(task_instance)
+                else:
+                    task_metadata = 
extractor.extract_on_complete(task_instance)
 
                 self.log.debug(
                     "Found task metadata for operation %s: %s",
@@ -155,13 +167,9 @@ class ExtractorManager(LoggingMixin):
             return self.extractors[task.task_type]
 
         def method_exists(method_name):
-            method = getattr(task, method_name, None)
-            if method:
-                return callable(method)
+            return callable(getattr(task, method_name, None))
 
-        if method_exists("get_openlineage_facets_on_start") or method_exists(
-            "get_openlineage_facets_on_complete"
-        ):
+        if method_exists(OL_METHOD_NAME_START) or 
method_exists(OL_METHOD_NAME_COMPLETE):
             return self.default_extractor
         return None
 
@@ -191,7 +199,8 @@ class ExtractorManager(LoggingMixin):
             if d:
                 task_metadata.outputs.append(d)
 
-    def get_hook_lineage(self) -> tuple[list[Dataset], list[Dataset]] | None:
+    @staticmethod
+    def get_hook_lineage() -> tuple[list[Dataset], list[Dataset]] | None:
         try:
             from airflow.providers.common.compat.lineage.hook import (
                 get_hook_lineage_collector,
diff --git 
a/providers/openlineage/src/airflow/providers/openlineage/plugins/adapter.py 
b/providers/openlineage/src/airflow/providers/openlineage/plugins/adapter.py
index 8350b2a0d51..a7c7d5e1f9f 100644
--- a/providers/openlineage/src/airflow/providers/openlineage/plugins/adapter.py
+++ b/providers/openlineage/src/airflow/providers/openlineage/plugins/adapter.py
@@ -85,7 +85,7 @@ class OpenLineageAdapter(LoggingMixin):
             if config:
                 self.log.debug(
                     "OpenLineage configuration found. Transport type: `%s`",
-                    config.get("type", "no type provided"),
+                    config.get("transport", {}).get("type", "no type 
provided"),
                 )
                 self._client = OpenLineageClient(config=config)  # type: 
ignore[call-arg]
             else:
diff --git 
a/providers/openlineage/src/airflow/providers/openlineage/plugins/listener.py 
b/providers/openlineage/src/airflow/providers/openlineage/plugins/listener.py
index 3af06538ce7..b46c65de64a 100644
--- 
a/providers/openlineage/src/airflow/providers/openlineage/plugins/listener.py
+++ 
b/providers/openlineage/src/airflow/providers/openlineage/plugins/listener.py
@@ -200,7 +200,9 @@ class OpenLineageListener:
             operator_name = task.task_type.lower()
 
             with Stats.timer(f"ol.extract.{event_type}.{operator_name}"):
-                task_metadata = 
self.extractor_manager.extract_metadata(dagrun, task)
+                task_metadata = self.extractor_manager.extract_metadata(
+                    dagrun=dagrun, task=task, 
task_instance_state=TaskInstanceState.RUNNING
+                )
 
             redacted_event = self.adapter.start_task(
                 run_id=task_uuid,
@@ -303,7 +305,10 @@ class OpenLineageListener:
 
             with Stats.timer(f"ol.extract.{event_type}.{operator_name}"):
                 task_metadata = self.extractor_manager.extract_metadata(
-                    dagrun, task, complete=True, task_instance=task_instance
+                    dagrun=dagrun,
+                    task=task,
+                    task_instance_state=TaskInstanceState.SUCCESS,
+                    task_instance=task_instance,
                 )
 
             redacted_event = self.adapter.complete_task(
@@ -424,7 +429,10 @@ class OpenLineageListener:
 
             with Stats.timer(f"ol.extract.{event_type}.{operator_name}"):
                 task_metadata = self.extractor_manager.extract_metadata(
-                    dagrun, task, complete=True, task_instance=task_instance
+                    dagrun=dagrun,
+                    task=task,
+                    task_instance_state=TaskInstanceState.FAILED,
+                    task_instance=task_instance,
                 )
 
             redacted_event = self.adapter.fail_task(
diff --git 
a/providers/openlineage/tests/unit/openlineage/extractors/test_base.py 
b/providers/openlineage/tests/unit/openlineage/extractors/test_base.py
index c85f75b3751..a120537a414 100644
--- a/providers/openlineage/tests/unit/openlineage/extractors/test_base.py
+++ b/providers/openlineage/tests/unit/openlineage/extractors/test_base.py
@@ -56,16 +56,43 @@ class CompleteRunFacet(JobFacet):
     finished: bool = field(default=False)
 
 
+@define
+class FailRunFacet(JobFacet):
+    failed: bool = field(default=False)
+
+
 FINISHED_FACETS: dict[str, JobFacet] = {"complete": CompleteRunFacet(True)}
+FAILED_FACETS: dict[str, JobFacet] = {"failure": FailRunFacet(True)}
 
 
 class ExampleExtractor(BaseExtractor):
     @classmethod
     def get_operator_classnames(cls):
-        return ["ExampleOperator"]
+        return ["OperatorWithoutFailure"]
+
+
+class OperatorWithoutFailure(BaseOperator):
+    def execute(self, context) -> Any:
+        pass
+
+    def get_openlineage_facets_on_start(self) -> OperatorLineage:
+        return OperatorLineage(
+            inputs=INPUTS,
+            outputs=OUTPUTS,
+            run_facets=RUN_FACETS,
+            job_facets=JOB_FACETS,
+        )
+
+    def get_openlineage_facets_on_complete(self, task_instance) -> 
OperatorLineage:
+        return OperatorLineage(
+            inputs=INPUTS,
+            outputs=OUTPUTS,
+            run_facets=RUN_FACETS,
+            job_facets=FINISHED_FACETS,
+        )
 
 
-class ExampleOperator(BaseOperator):
+class OperatorWithAllOlMethods(BaseOperator):
     def execute(self, context) -> Any:
         pass
 
@@ -85,6 +112,14 @@ class ExampleOperator(BaseOperator):
             job_facets=FINISHED_FACETS,
         )
 
+    def get_openlineage_facets_on_failure(self, task_instance) -> 
OperatorLineage:
+        return OperatorLineage(
+            inputs=INPUTS,
+            outputs=OUTPUTS,
+            run_facets=RUN_FACETS,
+            job_facets=FAILED_FACETS,
+        )
+
 
 class OperatorWithoutComplete(BaseOperator):
     def execute(self, context) -> Any:
@@ -162,14 +197,14 @@ class BrokenOperator(BaseOperator):
 
 
 def test_default_extraction():
-    extractor = ExtractorManager().get_extractor_class(ExampleOperator)
+    extractor = ExtractorManager().get_extractor_class(OperatorWithoutFailure)
     assert extractor is DefaultExtractor
 
-    metadata = extractor(ExampleOperator(task_id="test")).extract()
+    metadata = extractor(OperatorWithoutFailure(task_id="test")).extract()
 
     task_instance = mock.MagicMock()
 
-    metadata_on_complete = 
extractor(ExampleOperator(task_id="test")).extract_on_complete(
+    metadata_on_complete = 
extractor(OperatorWithoutFailure(task_id="test")).extract_on_complete(
         task_instance=task_instance
     )
 
@@ -235,50 +270,59 @@ def test_extraction_without_on_start():
 
 
 @pytest.mark.parametrize(
-    "task_state, is_airflow_2_10_or_higher, should_call_on_failure",
+    "operator_class, task_state, expected_job_facets",
     (
-        # Airflow >= 2.10
-        (TaskInstanceState.FAILED, True, True),
-        (TaskInstanceState.UP_FOR_RETRY, True, True),
-        (TaskInstanceState.RUNNING, True, False),
-        (TaskInstanceState.SUCCESS, True, False),
-        # Airflow < 2.10
-        (TaskInstanceState.RUNNING, False, True),
-        (TaskInstanceState.SUCCESS, False, False),
-        (TaskInstanceState.FAILED, False, False),  # should never happen, 
fixed in #41053
-        (TaskInstanceState.UP_FOR_RETRY, False, False),  # should never 
happen, fixed in #41053
+        (OperatorWithAllOlMethods, TaskInstanceState.FAILED, FAILED_FACETS),
+        (OperatorWithAllOlMethods, TaskInstanceState.RUNNING, JOB_FACETS),
+        (OperatorWithAllOlMethods, TaskInstanceState.SUCCESS, FINISHED_FACETS),
+        (OperatorWithAllOlMethods, TaskInstanceState.UP_FOR_RETRY, 
FINISHED_FACETS),  # Should never happen
+        (OperatorWithAllOlMethods, None, FINISHED_FACETS),  # Should never 
happen
+        (OperatorWithoutFailure, TaskInstanceState.FAILED, FINISHED_FACETS),
+        (OperatorWithoutFailure, TaskInstanceState.RUNNING, JOB_FACETS),
+        (OperatorWithoutFailure, TaskInstanceState.SUCCESS, FINISHED_FACETS),
+        (OperatorWithoutFailure, TaskInstanceState.UP_FOR_RETRY, 
FINISHED_FACETS),  # Should never happen
+        (OperatorWithoutFailure, None, FINISHED_FACETS),  # Should never happen
+        (OperatorWithoutStart, TaskInstanceState.FAILED, FINISHED_FACETS),
+        (OperatorWithoutStart, TaskInstanceState.RUNNING, {}),
+        (OperatorWithoutStart, TaskInstanceState.SUCCESS, FINISHED_FACETS),
+        (OperatorWithoutStart, TaskInstanceState.UP_FOR_RETRY, 
FINISHED_FACETS),  # Should never happen
+        (OperatorWithoutStart, None, FINISHED_FACETS),  # Should never happen
+        (OperatorWithoutComplete, TaskInstanceState.FAILED, JOB_FACETS),
+        (OperatorWithoutComplete, TaskInstanceState.RUNNING, JOB_FACETS),
+        (OperatorWithoutComplete, TaskInstanceState.SUCCESS, JOB_FACETS),
+        (OperatorWithoutComplete, TaskInstanceState.UP_FOR_RETRY, JOB_FACETS), 
 # Should never happen
+        (OperatorWithoutComplete, None, JOB_FACETS),  # Should never happen
     ),
 )
-def test_extract_on_failure(task_state, is_airflow_2_10_or_higher, 
should_call_on_failure):
-    task_instance = mock.Mock(state=task_state)
-    operator = mock.Mock()
-    operator.get_openlineage_facets_on_failure = mock.Mock(
-        return_value=OperatorLineage(run_facets={"failed": True})
+def test_extractor_manager_calls_appropriate_extractor_method(
+    operator_class, task_state, expected_job_facets
+):
+    extractor_manager = ExtractorManager()
+
+    ti = mock.MagicMock()
+
+    metadata = extractor_manager.extract_metadata(
+        dagrun=mock.MagicMock(run_id="dagrun_run_id"),
+        task=operator_class(task_id="task_id"),
+        task_instance_state=task_state,
+        task_instance=ti,
     )
-    operator.get_openlineage_facets_on_complete = mock.Mock(return_value=None)
-
-    extractor = DefaultExtractor(operator=operator)
 
-    with mock.patch(
-        "airflow.providers.openlineage.extractors.base.AIRFLOW_V_2_10_PLUS", 
is_airflow_2_10_or_higher
-    ):
-        result = extractor.extract_on_complete(task_instance)
-
-        if should_call_on_failure:
-            
operator.get_openlineage_facets_on_failure.assert_called_once_with(task_instance)
-            operator.get_openlineage_facets_on_complete.assert_not_called()
-            assert isinstance(result, OperatorLineage)
-            assert result.run_facets == {"failed": True}
-        else:
-            operator.get_openlineage_facets_on_failure.assert_not_called()
-            
operator.get_openlineage_facets_on_complete.assert_called_once_with(task_instance)
-            assert result is None
+    assert metadata.job_facets == expected_job_facets
+    if not expected_job_facets:  # Empty OperatorLineage() is expected
+        assert not metadata.inputs
+        assert not metadata.outputs
+        assert not metadata.run_facets
+    else:
+        assert metadata.inputs == INPUTS
+        assert metadata.outputs == OUTPUTS
+        assert metadata.run_facets == RUN_FACETS
 
 
 @mock.patch("airflow.providers.openlineage.conf.custom_extractors")
 def test_extractors_env_var(custom_extractors):
     custom_extractors.return_value = 
{"unit.openlineage.extractors.test_base.ExampleExtractor"}
-    extractor = 
ExtractorManager().get_extractor_class(ExampleOperator(task_id="example"))
+    extractor = 
ExtractorManager().get_extractor_class(OperatorWithoutFailure(task_id="example"))
     assert extractor is ExampleExtractor
 
 
@@ -292,7 +336,7 @@ def 
test_does_not_use_default_extractor_when_no_get_openlineage_facets():
     assert extractor_class is None
 
 
-def test_does_not_use_default_extractor_when_explicite_extractor():
+def test_does_not_use_default_extractor_when_explicit_extractor():
     extractor_class = ExtractorManager().get_extractor_class(
         PythonOperator(task_id="c", python_callable=lambda: 7)
     )
@@ -316,6 +360,4 @@ def 
test_default_extractor_uses_wrong_operatorlineage_class():
     operator = OperatorWrongOperatorLineageClass(task_id="task_id")
     # If extractor returns lineage class that can't be changed into 
OperatorLineage, just return
     # empty OperatorLineage
-    assert (
-        ExtractorManager().extract_metadata(mock.MagicMock(), operator, 
complete=False) == OperatorLineage()
-    )
+    assert ExtractorManager().extract_metadata(mock.MagicMock(), operator, 
None) == OperatorLineage()
diff --git 
a/providers/openlineage/tests/unit/openlineage/extractors/test_manager.py 
b/providers/openlineage/tests/unit/openlineage/extractors/test_manager.py
index 256bc133f1a..04739f22633 100644
--- a/providers/openlineage/tests/unit/openlineage/extractors/test_manager.py
+++ b/providers/openlineage/tests/unit/openlineage/extractors/test_manager.py
@@ -293,7 +293,9 @@ def 
test_extractor_manager_uses_hook_level_lineage(hook_lineage_collector):
     hook_lineage_collector.add_input_asset(None, uri="s3://bucket/input_key")
     hook_lineage_collector.add_output_asset(None, uri="s3://bucket/output_key")
     extractor_manager = ExtractorManager()
-    metadata = extractor_manager.extract_metadata(dagrun=dagrun, task=task, 
complete=True, task_instance=ti)
+    metadata = extractor_manager.extract_metadata(
+        dagrun=dagrun, task=task, task_instance_state=None, task_instance=ti
+    )
 
     assert metadata.inputs == [OpenLineageDataset(namespace="s3://bucket", 
name="input_key")]
     assert metadata.outputs == [OpenLineageDataset(namespace="s3://bucket", 
name="output_key")]
@@ -318,7 +320,9 @@ def 
test_extractor_manager_does_not_use_hook_level_lineage_when_operator(
     hook_lineage_collector.add_input_asset(None, uri="s3://bucket/input_key")
 
     extractor_manager = ExtractorManager()
-    metadata = extractor_manager.extract_metadata(dagrun=dagrun, task=task, 
complete=True, task_instance=ti)
+    metadata = extractor_manager.extract_metadata(
+        dagrun=dagrun, task=task, task_instance_state=None, task_instance=ti
+    )
 
     # s3://bucket/input_key not here - use data from operator
     assert metadata.inputs == [OpenLineageDataset(namespace="s3://bucket", 
name="proper_input_key")]

Reply via email to