This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 1de5a965921 feat: add OpenLineage support for transfer operators between gcs and local (#44417)
1de5a965921 is described below

commit 1de5a965921e75162fa23f2fcd8514beea428429
Author: Kacper Muda <[email protected]>
AuthorDate: Wed Nov 27 14:31:06 2024 +0100

    feat: add OpenLineage support for transfer operators between gcs and local (#44417)
    
    Signed-off-by: Kacper Muda <[email protected]>
---
 .../src/airflow/providers/common/io/assets/file.py |  4 +-
 .../google/cloud/transfers/gcs_to_local.py         |  9 +++
 .../google/cloud/transfers/local_to_gcs.py         | 47 ++++++++++--
 providers/tests/common/io/assets/test_file.py      |  4 +-
 .../google/cloud/transfers/test_gcs_to_local.py    | 17 +++++
 .../google/cloud/transfers/test_local_to_gcs.py    | 85 +++++++++++++++++++---
 6 files changed, 148 insertions(+), 18 deletions(-)

diff --git a/providers/src/airflow/providers/common/io/assets/file.py 
b/providers/src/airflow/providers/common/io/assets/file.py
index 6277e48c0a8..28d990d5630 100644
--- a/providers/src/airflow/providers/common/io/assets/file.py
+++ b/providers/src/airflow/providers/common/io/assets/file.py
@@ -56,4 +56,6 @@ def convert_asset_to_openlineage(asset: Asset, 
lineage_context) -> OpenLineageDa
     from airflow.providers.common.compat.openlineage.facet import Dataset as 
OpenLineageDataset
 
     parsed = urllib.parse.urlsplit(asset.uri)
-    return OpenLineageDataset(namespace=f"file://{parsed.netloc}", 
name=parsed.path)
+    return OpenLineageDataset(
+        namespace=f"file://{parsed.netloc}" if parsed.netloc else "file", 
name=parsed.path
+    )
diff --git 
a/providers/src/airflow/providers/google/cloud/transfers/gcs_to_local.py 
b/providers/src/airflow/providers/google/cloud/transfers/gcs_to_local.py
index bdcee1006ff..70cdf0cdb9b 100644
--- a/providers/src/airflow/providers/google/cloud/transfers/gcs_to_local.py
+++ b/providers/src/airflow/providers/google/cloud/transfers/gcs_to_local.py
@@ -113,3 +113,12 @@ class GCSToLocalFilesystemOperator(BaseOperator):
                 raise AirflowException("The size of the downloaded file is too 
large to push to XCom!")
         else:
             hook.download(bucket_name=self.bucket, 
object_name=self.object_name, filename=self.filename)
+
+    def get_openlineage_facets_on_start(self):
+        from airflow.providers.common.compat.openlineage.facet import Dataset
+        from airflow.providers.openlineage.extractors import OperatorLineage
+
+        return OperatorLineage(
+            inputs=[Dataset(namespace=f"gs://{self.bucket}", 
name=self.object_name)],
+            outputs=[Dataset(namespace="file", name=self.filename)] if 
self.filename else [],
+        )
diff --git 
a/providers/src/airflow/providers/google/cloud/transfers/local_to_gcs.py 
b/providers/src/airflow/providers/google/cloud/transfers/local_to_gcs.py
index eeed05c1d00..b1a143b242f 100644
--- a/providers/src/airflow/providers/google/cloud/transfers/local_to_gcs.py
+++ b/providers/src/airflow/providers/google/cloud/transfers/local_to_gcs.py
@@ -69,12 +69,12 @@ class LocalFilesystemToGCSOperator(BaseOperator):
     def __init__(
         self,
         *,
-        src,
-        dst,
-        bucket,
-        gcp_conn_id="google_cloud_default",
-        mime_type="application/octet-stream",
-        gzip=False,
+        src: str | list[str],
+        dst: str,
+        bucket: str,
+        gcp_conn_id: str = "google_cloud_default",
+        mime_type: str = "application/octet-stream",
+        gzip: bool = False,
         chunk_size: int | None = None,
         impersonation_chain: str | Sequence[str] | None = None,
         **kwargs,
@@ -120,3 +120,38 @@ class LocalFilesystemToGCSOperator(BaseOperator):
                 gzip=self.gzip,
                 chunk_size=self.chunk_size,
             )
+
+    def get_openlineage_facets_on_start(self):
+        from airflow.providers.common.compat.openlineage.facet import (
+            Dataset,
+            Identifier,
+            SymlinksDatasetFacet,
+        )
+        from airflow.providers.google.cloud.openlineage.utils import WILDCARD, 
extract_ds_name_from_gcs_path
+        from airflow.providers.openlineage.extractors import OperatorLineage
+
+        source_facets = {}
+        if isinstance(self.src, str):  # Single path provided, possibly 
relative or with wildcard
+            original_src = f"{self.src}"
+            absolute_src = os.path.abspath(self.src)
+            resolved_src = extract_ds_name_from_gcs_path(absolute_src)
+            if original_src.startswith("/") and not 
resolved_src.startswith("/"):
+                resolved_src = "/" + resolved_src
+            source_objects = [resolved_src]
+
+            if WILDCARD in original_src or absolute_src != resolved_src:
+                # We attach a symlink with unmodified path.
+                source_facets = {
+                    "symlink": SymlinksDatasetFacet(
+                        identifiers=[Identifier(namespace="file", 
name=original_src, type="file")]
+                    ),
+                }
+        else:
+            source_objects = self.src
+
+        dest_object = self.dst if os.path.basename(self.dst) else 
extract_ds_name_from_gcs_path(self.dst)
+
+        return OperatorLineage(
+            inputs=[Dataset(namespace="file", name=src, facets=source_facets) 
for src in source_objects],
+            outputs=[Dataset(namespace=f"gs://{self.bucket}", 
name=dest_object)],
+        )
diff --git a/providers/tests/common/io/assets/test_file.py 
b/providers/tests/common/io/assets/test_file.py
index 21357f933fd..d2dc48d845e 100644
--- a/providers/tests/common/io/assets/test_file.py
+++ b/providers/tests/common/io/assets/test_file.py
@@ -54,12 +54,12 @@ def test_file_asset():
 @pytest.mark.parametrize(
     ("uri", "ol_dataset"),
     (
-        ("file:///valid/path", OpenLineageDataset(namespace="file://", 
name="/valid/path")),
+        ("file:///valid/path", OpenLineageDataset(namespace="file", 
name="/valid/path")),
         (
             "file://127.0.0.1:8080/dir/file.csv",
             OpenLineageDataset(namespace="file://127.0.0.1:8080", 
name="/dir/file.csv"),
         ),
-        ("file:///C://dir/file", OpenLineageDataset(namespace="file://", 
name="/C://dir/file")),
+        ("file:///C://dir/file", OpenLineageDataset(namespace="file", 
name="/C://dir/file")),
     ),
 )
 def test_convert_asset_to_openlineage(uri, ol_dataset):
diff --git a/providers/tests/google/cloud/transfers/test_gcs_to_local.py 
b/providers/tests/google/cloud/transfers/test_gcs_to_local.py
index d88261699ad..c4ff984bd9a 100644
--- a/providers/tests/google/cloud/transfers/test_gcs_to_local.py
+++ b/providers/tests/google/cloud/transfers/test_gcs_to_local.py
@@ -113,3 +113,20 @@ class TestGoogleCloudStorageDownloadOperator:
             bucket_name=TEST_BUCKET, object_name=TEST_OBJECT
         )
         context["ti"].xcom_push.assert_called_once_with(key=XCOM_KEY, 
value=FILE_CONTENT_STR)
+
+    def test_get_openlineage_facets_on_start_(self):
+        operator = GCSToLocalFilesystemOperator(
+            task_id=TASK_ID,
+            bucket=TEST_BUCKET,
+            object_name=TEST_OBJECT,
+            filename=LOCAL_FILE_PATH,
+        )
+        result = operator.get_openlineage_facets_on_start()
+        assert not result.job_facets
+        assert not result.run_facets
+        assert len(result.outputs) == 1
+        assert len(result.inputs) == 1
+        assert result.outputs[0].namespace == "file"
+        assert result.outputs[0].name == LOCAL_FILE_PATH
+        assert result.inputs[0].namespace == f"gs://{TEST_BUCKET}"
+        assert result.inputs[0].name == TEST_OBJECT
diff --git a/providers/tests/google/cloud/transfers/test_local_to_gcs.py 
b/providers/tests/google/cloud/transfers/test_local_to_gcs.py
index bfa331372f6..0ebf2f59503 100644
--- a/providers/tests/google/cloud/transfers/test_local_to_gcs.py
+++ b/providers/tests/google/cloud/transfers/test_local_to_gcs.py
@@ -25,6 +25,10 @@ from unittest import mock
 import pytest
 
 from airflow.models.dag import DAG
+from airflow.providers.common.compat.openlineage.facet import (
+    Identifier,
+    SymlinksDatasetFacet,
+)
 from airflow.providers.google.cloud.transfers.local_to_gcs import 
LocalFilesystemToGCSOperator
 
 pytestmark = pytest.mark.db_test
@@ -72,7 +76,7 @@ class TestFileToGcsOperator:
     def test_execute(self, mock_hook):
         mock_instance = mock_hook.return_value
         operator = LocalFilesystemToGCSOperator(
-            task_id="gcs_to_file_sensor",
+            task_id="file_to_gcs_operator",
             dag=self.dag,
             src=self.testfile1,
             dst="test/test1.csv",
@@ -91,7 +95,7 @@ class TestFileToGcsOperator:
     @pytest.mark.db_test
     def test_execute_with_empty_src(self):
         operator = LocalFilesystemToGCSOperator(
-            task_id="local_to_sensor",
+            task_id="file_to_gcs_operator",
             dag=self.dag,
             src="no_file.txt",
             dst="test/no_file.txt",
@@ -104,7 +108,7 @@ class TestFileToGcsOperator:
     def test_execute_multiple(self, mock_hook):
         mock_instance = mock_hook.return_value
         operator = LocalFilesystemToGCSOperator(
-            task_id="gcs_to_file_sensor", dag=self.dag, src=self.testfiles, 
dst="test/", **self._config
+            task_id="file_to_gcs_operator", dag=self.dag, src=self.testfiles, 
dst="test/", **self._config
         )
         operator.execute(None)
         files_objects = zip(
@@ -127,7 +131,7 @@ class TestFileToGcsOperator:
     def test_execute_wildcard(self, mock_hook):
         mock_instance = mock_hook.return_value
         operator = LocalFilesystemToGCSOperator(
-            task_id="gcs_to_file_sensor", dag=self.dag, src="/tmp/fake*.csv", 
dst="test/", **self._config
+            task_id="file_to_gcs_operator", dag=self.dag, 
src="/tmp/fake*.csv", dst="test/", **self._config
         )
         operator.execute(None)
         object_names = ["test/" + os.path.basename(fp) for fp in 
glob("/tmp/fake*.csv")]
@@ -145,17 +149,80 @@ class TestFileToGcsOperator:
         ]
         mock_instance.upload.assert_has_calls(calls)
 
+    @pytest.mark.parametrize(
+        ("src", "dst"),
+        [
+            ("/tmp/fake*.csv", "test/test1.csv"),
+            ("/tmp/fake*.csv", "test"),
+            ("/tmp/fake*.csv", "test/dir"),
+        ],
+    )
     
@mock.patch("airflow.providers.google.cloud.transfers.local_to_gcs.GCSHook", 
autospec=True)
-    def test_execute_negative(self, mock_hook):
+    def test_execute_negative(self, mock_hook, src, dst):
         mock_instance = mock_hook.return_value
         operator = LocalFilesystemToGCSOperator(
-            task_id="gcs_to_file_sensor",
+            task_id="file_to_gcs_operator",
             dag=self.dag,
-            src="/tmp/fake*.csv",
-            dst="test/test1.csv",
+            src=src,
+            dst=dst,
             **self._config,
         )
-        print(glob("/tmp/fake*.csv"))
         with pytest.raises(ValueError):
             operator.execute(None)
         mock_instance.assert_not_called()
+
+    @pytest.mark.parametrize(
+        ("src", "dst", "expected_input", "expected_output", "symlink"),
+        [
+            ("/tmp/fake*.csv", "test/", "/tmp", "test", True),
+            ("/tmp/../tmp/fake*.csv", "test/", "/tmp", "test", True),
+            ("/tmp/fake1.csv", "test/test1.csv", "/tmp/fake1.csv", 
"test/test1.csv", False),
+            ("/tmp/fake1.csv", "test/pre", "/tmp/fake1.csv", "test/pre", 
False),
+        ],
+    )
+    def test_get_openlineage_facets_on_start_with_string_src(
+        self, src, dst, expected_input, expected_output, symlink
+    ):
+        operator = LocalFilesystemToGCSOperator(
+            task_id="gcs_to_file_sensor",
+            dag=self.dag,
+            src=src,
+            dst=dst,
+            **self._config,
+        )
+        result = operator.get_openlineage_facets_on_start()
+        assert not result.job_facets
+        assert not result.run_facets
+        assert len(result.outputs) == 1
+        assert len(result.inputs) == 1
+        assert result.outputs[0].name == expected_output
+        assert result.inputs[0].name == expected_input
+        if symlink:
+            assert result.inputs[0].facets["symlink"] == SymlinksDatasetFacet(
+                identifiers=[Identifier(namespace="file", name=src, 
type="file")]
+            )
+
+    @pytest.mark.parametrize(
+        ("src", "dst", "expected_inputs", "expected_output"),
+        [
+            (["/tmp/fake1.csv", "/tmp/fake2.csv"], "test/", ["/tmp/fake1.csv", 
"/tmp/fake2.csv"], "test"),
+            (["/tmp/fake1.csv", "/tmp/fake2.csv"], "", ["/tmp/fake1.csv", 
"/tmp/fake2.csv"], "/"),
+        ],
+    )
+    def test_get_openlineage_facets_on_start_with_list_src(self, src, dst, 
expected_inputs, expected_output):
+        operator = LocalFilesystemToGCSOperator(
+            task_id="gcs_to_file_sensor",
+            dag=self.dag,
+            src=src,
+            dst=dst,
+            **self._config,
+        )
+        result = operator.get_openlineage_facets_on_start()
+        assert not result.job_facets
+        assert not result.run_facets
+        assert len(result.outputs) == 1
+        assert len(result.inputs) == len(expected_inputs)
+        assert result.outputs[0].name == expected_output
+        assert result.outputs[0].namespace == "gs://dummy"
+        assert all(inp.name in expected_inputs for inp in result.inputs)
+        assert all(inp.namespace == "file" for inp in result.inputs)

Reply via email to