This is an automated email from the ASF dual-hosted git repository.

mobuchowski pushed a commit to branch aip-62/openlineage-impl
in repository https://gitbox.apache.org/repos/asf/airflow.git

commit 8d434eccc7301580d866170403ca803fd033ffc0
Author: Maciej Obuchowski <[email protected]>
AuthorDate: Wed Aug 14 17:19:36 2024 +0200

    feat: openlineage listener captures hook-level lineage
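    
    When an operator has no OpenLineage extractor and defines no
    get_openlineage_facets_* methods, the ExtractorManager now falls back
    (on Airflow >= 2.10) to datasets recorded by the hook-level lineage
    collector and wraps them in an OperatorLineage. A minimal sketch of how
    a hook-side call feeds this path - using only the collector API already
    exercised by this patch; the S3 URI is illustrative:
    
        from airflow.lineage.hook import get_hook_lineage_collector
    
        # Record an input dataset while a hook runs; the OpenLineage
        # listener later converts it with translate_airflow_dataset().
        get_hook_lineage_collector().add_input_dataset(
            None, uri="s3://bucket/input_key"
        )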
    
    Signed-off-by: Maciej Obuchowski <[email protected]>
---
 airflow/io/path.py                                 |   1 -
 .../providers/openlineage/extractors/manager.py    |  30 +++++
 .../providers/openlineage/plugins/openlineage.py   |   5 +
 airflow/providers/openlineage/provider.yaml        |   1 +
 generated/provider_dependencies.json               |   1 +
 .../openlineage/extractors/test_manager.py         | 147 ++++++++++++++++++---
 6 files changed, 163 insertions(+), 22 deletions(-)

diff --git a/airflow/io/path.py b/airflow/io/path.py
index 7c8d1f9f19..8afcab19ea 100644
--- a/airflow/io/path.py
+++ b/airflow/io/path.py
@@ -54,7 +54,6 @@ class TrackingFileWrapper(LoggingMixin):
         if callable(attr):
             # If the attribute is a method, wrap it in another method to intercept the call
             def wrapper(*args, **kwargs):
-                self.log.error("Calling method: %s", name)
                 if name == "read":
                     get_hook_lineage_collector().add_input_dataset(context=self._path, uri=str(self._path))
                 elif name == "write":
diff --git a/airflow/providers/openlineage/extractors/manager.py b/airflow/providers/openlineage/extractors/manager.py
index 5b9ad6ac1b..0656248867 100644
--- a/airflow/providers/openlineage/extractors/manager.py
+++ b/airflow/providers/openlineage/extractors/manager.py
@@ -24,7 +24,9 @@ from airflow.providers.openlineage.extractors.base import DefaultExtractor
 from airflow.providers.openlineage.extractors.bash import BashExtractor
 from airflow.providers.openlineage.extractors.python import PythonExtractor
 from airflow.providers.openlineage.utils.utils import (
+    IS_AIRFLOW_2_10_OR_HIGHER,
     get_unknown_source_attribute_run_facet,
+    translate_airflow_dataset,
     try_import_from_string,
 )
 from airflow.utils.log.logging_mixin import LoggingMixin
@@ -113,6 +115,10 @@ class ExtractorManager(LoggingMixin):
                 self.log.warning(
                     "Failed to extract metadata using found extractor %s - %s 
%s", extractor, e, task_info
                 )
+        elif IS_AIRFLOW_2_10_OR_HIGHER and (hook_lineage := 
self.get_hook_lineage()) is not None:
+            inputs, outputs = hook_lineage
+            task_metadata = OperatorLineage(inputs=inputs, outputs=outputs)
+            return task_metadata
         else:
             self.log.debug("Unable to find an extractor %s", task_info)
 
@@ -168,6 +174,30 @@ class ExtractorManager(LoggingMixin):
             if d:
                 task_metadata.outputs.append(d)
 
+    def get_hook_lineage(self) -> tuple[list[Dataset], list[Dataset]] | None:
+        try:
+            from airflow.lineage.hook import get_hook_lineage_collector
+        except ImportError:
+            return None
+
+        if not get_hook_lineage_collector().has_collected:
+            return None
+
+        return (
+            [
+                dataset
+                for dataset_info in get_hook_lineage_collector().collected_datasets.inputs
+                if (dataset := translate_airflow_dataset(dataset_info.dataset, dataset_info.context))
+                is not None
+            ],
+            [
+                dataset
+                for dataset_info in get_hook_lineage_collector().collected_datasets.outputs
+                if (dataset := translate_airflow_dataset(dataset_info.dataset, dataset_info.context))
+                is not None
+            ],
+        )
+
     @staticmethod
     def convert_to_ol_dataset_from_object_storage_uri(uri: str) -> Dataset | None:
         from urllib.parse import urlparse
diff --git a/airflow/providers/openlineage/plugins/openlineage.py b/airflow/providers/openlineage/plugins/openlineage.py
index 5b4f2c3dbf..ce70f3a52a 100644
--- a/airflow/providers/openlineage/plugins/openlineage.py
+++ b/airflow/providers/openlineage/plugins/openlineage.py
@@ -25,6 +25,7 @@ from airflow.providers.openlineage.plugins.macros import (
     lineage_parent_id,
     lineage_run_id,
 )
+from airflow.providers.openlineage.utils.utils import IS_AIRFLOW_2_10_OR_HIGHER
 
 
 class OpenLineageProviderPlugin(AirflowPlugin):
@@ -39,6 +40,10 @@ class OpenLineageProviderPlugin(AirflowPlugin):
     if not conf.is_disabled():
         macros = [lineage_job_namespace, lineage_job_name, lineage_run_id, lineage_parent_id]
         listeners = [get_openlineage_listener()]
+        if IS_AIRFLOW_2_10_OR_HIGHER:
+            from airflow.lineage.hook import HookLineageReader
+
+            hook_lineage_readers = [HookLineageReader]
     else:
         macros = []
         listeners = []
diff --git a/airflow/providers/openlineage/provider.yaml b/airflow/providers/openlineage/provider.yaml
index f38b49035d..c09c5dedda 100644
--- a/airflow/providers/openlineage/provider.yaml
+++ b/airflow/providers/openlineage/provider.yaml
@@ -47,6 +47,7 @@ versions:
 dependencies:
   - apache-airflow>=2.8.0
   - apache-airflow-providers-common-sql>=1.6.0
+  - apache-airflow-providers-common-compat>=1.1.0
   - attrs>=22.2
   - openlineage-integration-common>=1.16.0
   - openlineage-python>=1.16.0
diff --git a/generated/provider_dependencies.json b/generated/provider_dependencies.json
index 2282045e12..609f1afd1e 100644
--- a/generated/provider_dependencies.json
+++ b/generated/provider_dependencies.json
@@ -935,6 +935,7 @@
   },
   "openlineage": {
     "deps": [
+      "apache-airflow-providers-common-compat>=1.1.0",
       "apache-airflow-providers-common-sql>=1.6.0",
       "apache-airflow>=2.8.0",
       "attrs>=22.2",
diff --git a/tests/providers/openlineage/extractors/test_manager.py b/tests/providers/openlineage/extractors/test_manager.py
index ccfd04d5e2..35edc599bc 100644
--- a/tests/providers/openlineage/extractors/test_manager.py
+++ b/tests/providers/openlineage/extractors/test_manager.py
@@ -17,26 +17,47 @@
 # under the License.
 from __future__ import annotations
 
+import tempfile
+from typing import TYPE_CHECKING, Any
+from unittest.mock import MagicMock
+
 import pytest
-from openlineage.client.event_v2 import Dataset
+from openlineage.client.event_v2 import Dataset as OpenLineageDataset
 from openlineage.client.facet_v2 import documentation_dataset, ownership_dataset, schema_dataset
 
+from airflow.datasets import Dataset
+from airflow.io.path import ObjectStoragePath
 from airflow.lineage.entities import Column, File, Table, User
+from airflow.models.baseoperator import BaseOperator
+from airflow.models.taskinstance import TaskInstance
+from airflow.operators.python import PythonOperator
+from airflow.providers.openlineage.extractors import OperatorLineage
 from airflow.providers.openlineage.extractors.manager import ExtractorManager
+from airflow.utils.state import State
+from tests.test_utils.compat import AIRFLOW_V_2_10_PLUS
+
+if TYPE_CHECKING:
+    from airflow.utils.context import Context
 
 
 @pytest.mark.parametrize(
     ("uri", "dataset"),
     (
-        ("s3://bucket1/dir1/file1", Dataset(namespace="s3://bucket1", 
name="dir1/file1")),
-        ("gs://bucket2/dir2/file2", Dataset(namespace="gs://bucket2", 
name="dir2/file2")),
-        ("gcs://bucket3/dir3/file3", Dataset(namespace="gs://bucket3", 
name="dir3/file3")),
-        ("hdfs://namenodehost:8020/file1", 
Dataset(namespace="hdfs://namenodehost:8020", name="file1")),
-        ("hdfs://namenodehost/file2", Dataset(namespace="hdfs://namenodehost", 
name="file2")),
-        ("file://localhost/etc/fstab", Dataset(namespace="file://localhost", 
name="etc/fstab")),
-        ("file:///etc/fstab", Dataset(namespace="file://", name="etc/fstab")),
-        ("https://test.com";, Dataset(namespace="https", name="test.com")),
-        ("https://test.com?param1=test1&param2=test2";, 
Dataset(namespace="https", name="test.com")),
+        ("s3://bucket1/dir1/file1", 
OpenLineageDataset(namespace="s3://bucket1", name="dir1/file1")),
+        ("gs://bucket2/dir2/file2", 
OpenLineageDataset(namespace="gs://bucket2", name="dir2/file2")),
+        ("gcs://bucket3/dir3/file3", 
OpenLineageDataset(namespace="gs://bucket3", name="dir3/file3")),
+        (
+            "hdfs://namenodehost:8020/file1",
+            OpenLineageDataset(namespace="hdfs://namenodehost:8020", 
name="file1"),
+        ),
+        ("hdfs://namenodehost/file2", 
OpenLineageDataset(namespace="hdfs://namenodehost", name="file2")),
+        ("file://localhost/etc/fstab", 
OpenLineageDataset(namespace="file://localhost", name="etc/fstab")),
+        ("file:///etc/fstab", OpenLineageDataset(namespace="file://", 
name="etc/fstab")),
+        ("https://test.com";, OpenLineageDataset(namespace="https", 
name="test.com")),
+        (
+            "https://test.com?param1=test1&param2=test2";,
+            OpenLineageDataset(namespace="https", name="test.com"),
+        ),
         ("file:test.csv", None),
         ("not_an_url", None),
     ),
@@ -50,21 +71,36 @@ def test_convert_to_ol_dataset_from_object_storage_uri(uri, dataset):
     ("obj", "dataset"),
     (
         (
-            Dataset(namespace="n1", name="f1"),
-            Dataset(namespace="n1", name="f1"),
+            OpenLineageDataset(namespace="n1", name="f1"),
+            OpenLineageDataset(namespace="n1", name="f1"),
+        ),
+        (
+            File(url="s3://bucket1/dir1/file1"),
+            OpenLineageDataset(namespace="s3://bucket1", name="dir1/file1"),
+        ),
+        (
+            File(url="gs://bucket2/dir2/file2"),
+            OpenLineageDataset(namespace="gs://bucket2", name="dir2/file2"),
+        ),
+        (
+            File(url="gcs://bucket3/dir3/file3"),
+            OpenLineageDataset(namespace="gs://bucket3", name="dir3/file3"),
         ),
-        (File(url="s3://bucket1/dir1/file1"), 
Dataset(namespace="s3://bucket1", name="dir1/file1")),
-        (File(url="gs://bucket2/dir2/file2"), 
Dataset(namespace="gs://bucket2", name="dir2/file2")),
-        (File(url="gcs://bucket3/dir3/file3"), 
Dataset(namespace="gs://bucket3", name="dir3/file3")),
         (
             File(url="hdfs://namenodehost:8020/file1"),
-            Dataset(namespace="hdfs://namenodehost:8020", name="file1"),
+            OpenLineageDataset(namespace="hdfs://namenodehost:8020", name="file1"),
+        ),
+        (
+            File(url="hdfs://namenodehost/file2"),
+            OpenLineageDataset(namespace="hdfs://namenodehost", name="file2"),
+        ),
+        (
+            File(url="file://localhost/etc/fstab"),
+            OpenLineageDataset(namespace="file://localhost", name="etc/fstab"),
         ),
-        (File(url="hdfs://namenodehost/file2"), 
Dataset(namespace="hdfs://namenodehost", name="file2")),
-        (File(url="file://localhost/etc/fstab"), 
Dataset(namespace="file://localhost", name="etc/fstab")),
-        (File(url="file:///etc/fstab"), Dataset(namespace="file://", 
name="etc/fstab")),
-        (File(url="https://test.com";), Dataset(namespace="https", 
name="test.com")),
-        (Table(cluster="c1", database="d1", name="t1"), 
Dataset(namespace="c1", name="d1.t1")),
+        (File(url="file:///etc/fstab"), 
OpenLineageDataset(namespace="file://", name="etc/fstab")),
+        (File(url="https://test.com";), OpenLineageDataset(namespace="https", 
name="test.com")),
+        (Table(cluster="c1", database="d1", name="t1"), 
OpenLineageDataset(namespace="c1", name="d1.t1")),
         ("gs://bucket2/dir2/file2", None),
         ("not_an_url", None),
     ),
@@ -167,3 +203,72 @@ def test_convert_to_ol_dataset_table():
     assert result.namespace == "c1"
     assert result.name == "d1.t1"
     assert result.facets == expected_facets
+
+
[email protected](not AIRFLOW_V_2_10_PLUS, reason="Hook lineage works in Airflow >= 2.10.0")
+def test_extractor_manager_uses_hook_level_lineage(hook_lineage_collector):
+    dagrun = MagicMock()
+    task = MagicMock()
+    del task.get_openlineage_facets_on_start
+    del task.get_openlineage_facets_on_complete
+    ti = MagicMock()
+
+    hook_lineage_collector.add_input_dataset(None, uri="s3://bucket/input_key")
+    hook_lineage_collector.add_output_dataset(None, uri="s3://bucket/output_key")
+    extractor_manager = ExtractorManager()
+    metadata = extractor_manager.extract_metadata(dagrun=dagrun, task=task, complete=True, task_instance=ti)
+
+    assert metadata.inputs == [OpenLineageDataset(namespace="s3://bucket", name="input_key")]
+    assert metadata.outputs == [OpenLineageDataset(namespace="s3://bucket", name="output_key")]
+
+
[email protected](not AIRFLOW_V_2_10_PLUS, reason="Hook lineage works in Airflow >= 2.10.0")
+def test_extractor_manager_does_not_use_hook_level_lineage_when_operator(hook_lineage_collector):
+    class FakeSupportedOperator(BaseOperator):
+        def execute(self, context: Context) -> Any:
+            pass
+
+        def get_openlineage_facets_on_start(self):
+            return OperatorLineage(
+                inputs=[OpenLineageDataset(namespace="s3://bucket", name="proper_input_key")]
+            )
+
+    dagrun = MagicMock()
+    task = FakeSupportedOperator(task_id="test_task_extractor")
+    ti = MagicMock()
+    hook_lineage_collector.add_input_dataset(None, uri="s3://bucket/input_key")
+
+    extractor_manager = ExtractorManager()
+    metadata = extractor_manager.extract_metadata(dagrun=dagrun, task=task, complete=True, task_instance=ti)
+
+    # s3://bucket/input_key is not present - lineage reported by the operator takes precedence
+    assert metadata.inputs == [OpenLineageDataset(namespace="s3://bucket", name="proper_input_key")]
+    assert metadata.outputs == []
+
+
[email protected](not AIRFLOW_V_2_10_PLUS, reason="Hook lineage works in Airflow >= 2.10.0")
+def test_extractor_manager_gets_data_from_pythonoperator(session, dag_maker, hook_lineage_collector):
+    path = None
+    with tempfile.NamedTemporaryFile() as f:
+        path = f.name
+        with dag_maker():
+
+            def use_read():
+                storage_path = ObjectStoragePath(path)
+                with storage_path.open("w") as out:
+                    out.write("test")
+
+            task = PythonOperator(task_id="test_task_extractor_pythonoperator", python_callable=use_read)
+
+    dr = dag_maker.create_dagrun()
+    ti = TaskInstance(task=task, run_id=dr.run_id)
+    ti.state = State.QUEUED
+    session.merge(ti)
+    session.commit()
+
+    ti.run()
+
+    datasets = hook_lineage_collector.collected_datasets
+
+    assert len(datasets.outputs) == 1
+    assert datasets.outputs[0].dataset == Dataset(uri=path)
