This is an automated email from the ASF dual-hosted git repository.

mobuchowski pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 066708352e fix: OpenLineage datasets in GCSTimeSpanFileTransformOperator (#39064)
066708352e is described below

commit 066708352e6a6a06f213b65324e982f582019b8e
Author: Kacper Muda <[email protected]>
AuthorDate: Wed Apr 17 21:14:26 2024 +0200

    fix: OpenLineage datasets in GCSTimeSpanFileTransformOperator (#39064)
    
    Signed-off-by: Kacper Muda <[email protected]>
---
 airflow/providers/google/cloud/operators/gcs.py    |  69 ++++++++------
 tests/providers/google/cloud/operators/test_gcs.py | 105 +++++++++++++++------
 2 files changed, 115 insertions(+), 59 deletions(-)

diff --git a/airflow/providers/google/cloud/operators/gcs.py b/airflow/providers/google/cloud/operators/gcs.py
index c311c8b4ed..6c72378a43 100644
--- a/airflow/providers/google/cloud/operators/gcs.py
+++ b/airflow/providers/google/cloud/operators/gcs.py
@@ -774,8 +774,8 @@ class GCSTimeSpanFileTransformOperator(GoogleCloudBaseOperator):
         self.upload_continue_on_fail = upload_continue_on_fail
         self.upload_num_attempts = upload_num_attempts
 
-        self._source_object_names: list[str] = []
-        self._destination_object_names: list[str] = []
+        self._source_prefix_interp: str | None = None
+        self._destination_prefix_interp: str | None = None
 
     def execute(self, context: Context) -> list[str]:
         # Define intervals and prefixes.
@@ -803,11 +803,11 @@ class GCSTimeSpanFileTransformOperator(GoogleCloudBaseOperator):
         timespan_start = timespan_start.in_timezone(timezone.utc)
         timespan_end = timespan_end.in_timezone(timezone.utc)
 
-        source_prefix_interp = GCSTimeSpanFileTransformOperator.interpolate_prefix(
+        self._source_prefix_interp = GCSTimeSpanFileTransformOperator.interpolate_prefix(
             self.source_prefix,
             timespan_start,
         )
-        destination_prefix_interp = GCSTimeSpanFileTransformOperator.interpolate_prefix(
+        self._destination_prefix_interp = GCSTimeSpanFileTransformOperator.interpolate_prefix(
             self.destination_prefix,
             timespan_start,
         )
@@ -828,9 +828,9 @@ class GCSTimeSpanFileTransformOperator(GoogleCloudBaseOperator):
         )
 
         # Fetch list of files.
-        self._source_object_names = source_hook.list_by_timespan(
+        blobs_to_transform = source_hook.list_by_timespan(
             bucket_name=self.source_bucket,
-            prefix=source_prefix_interp,
+            prefix=self._source_prefix_interp,
             timespan_start=timespan_start,
             timespan_end=timespan_end,
         )
@@ -840,7 +840,7 @@ class GCSTimeSpanFileTransformOperator(GoogleCloudBaseOperator):
             temp_output_dir_path = Path(temp_output_dir)
 
             # TODO: download in parallel.
-            for blob_to_transform in self._source_object_names:
+            for blob_to_transform in blobs_to_transform:
                 destination_file = temp_input_dir_path / blob_to_transform
                 destination_file.parent.mkdir(parents=True, exist_ok=True)
                 try:
@@ -877,6 +877,8 @@ class GCSTimeSpanFileTransformOperator(GoogleCloudBaseOperator):
 
             self.log.info("Transformation succeeded. Output temporarily located at %s", temp_output_dir_path)
 
+            files_uploaded = []
+
             # TODO: upload in parallel.
             for upload_file in temp_output_dir_path.glob("**/*"):
                 if upload_file.is_dir():
@@ -884,8 +886,8 @@ class GCSTimeSpanFileTransformOperator(GoogleCloudBaseOperator):
 
                 upload_file_name = str(upload_file.relative_to(temp_output_dir_path))
 
-                if self.destination_prefix is not None:
-                    upload_file_name = f"{destination_prefix_interp}/{upload_file_name}"
+                if self._destination_prefix_interp is not None:
+                    upload_file_name = f"{self._destination_prefix_interp.rstrip('/')}/{upload_file_name}"
 
                 self.log.info("Uploading file %s to %s", upload_file, upload_file_name)
 
@@ -897,35 +899,46 @@ class GCSTimeSpanFileTransformOperator(GoogleCloudBaseOperator):
                         chunk_size=self.chunk_size,
                         num_max_attempts=self.upload_num_attempts,
                     )
-                    self._destination_object_names.append(str(upload_file_name))
+                    files_uploaded.append(str(upload_file_name))
                 except GoogleCloudError:
                     if not self.upload_continue_on_fail:
                         raise
 
-            return self._destination_object_names
+            return files_uploaded
 
     def get_openlineage_facets_on_complete(self, task_instance):
-        """Implement on_complete as execute() resolves object names."""
+        """Implement on_complete as execute() resolves object prefixes."""
         from openlineage.client.run import Dataset
 
         from airflow.providers.openlineage.extractors import OperatorLineage
 
-        input_datasets = [
-            Dataset(
-                namespace=f"gs://{self.source_bucket}",
-                name=object_name,
-            )
-            for object_name in self._source_object_names
-        ]
-        output_datasets = [
-            Dataset(
-                namespace=f"gs://{self.destination_bucket}",
-                name=object_name,
-            )
-            for object_name in self._destination_object_names
-        ]
-
-        return OperatorLineage(inputs=input_datasets, outputs=output_datasets)
+        def _parse_prefix(pref):
+            # Use parent if not a file (dot not in name) and not a dir (ends with slash)
+            if "." not in pref.split("/")[-1] and not pref.endswith("/"):
+                pref = Path(pref).parent.as_posix()
+            return "/" if pref in (".", "/", "") else pref.rstrip("/")
+
+        input_prefix, output_prefix = "/", "/"
+        if self._source_prefix_interp is not None:
+            input_prefix = _parse_prefix(self._source_prefix_interp)
+
+        if self._destination_prefix_interp is not None:
+            output_prefix = _parse_prefix(self._destination_prefix_interp)
+
+        return OperatorLineage(
+            inputs=[
+                Dataset(
+                    namespace=f"gs://{self.source_bucket}",
+                    name=input_prefix,
+                )
+            ],
+            outputs=[
+                Dataset(
+                    namespace=f"gs://{self.destination_bucket}",
+                    name=output_prefix,
+                )
+            ],
+        )
 
 
 class GCSDeleteBucketOperator(GoogleCloudBaseOperator):
diff --git a/tests/providers/google/cloud/operators/test_gcs.py b/tests/providers/google/cloud/operators/test_gcs.py
index 2eb96682bd..6236aa5f23 100644
--- a/tests/providers/google/cloud/operators/test_gcs.py
+++ b/tests/providers/google/cloud/operators/test_gcs.py
@@ -21,6 +21,7 @@ from datetime import datetime, timedelta, timezone
 from pathlib import Path
 from unittest import mock
 
+import pytest
 from openlineage.client.facet import (
     LifecycleStateChange,
     LifecycleStateChangeDatasetFacet,
@@ -483,15 +484,78 @@ class TestGCSTimeSpanFileTransformOperator:
             ]
         )
 
+    @pytest.mark.parametrize(
+        ("source_prefix", "dest_prefix", "inputs", "outputs"),
+        (
+            (
+                None,
+                None,
+                [Dataset(f"gs://{TEST_BUCKET}", "/")],
+                [Dataset(f"gs://{TEST_BUCKET}_dest", "/")],
+            ),
+            (
+                None,
+                "dest_pre/",
+                [Dataset(f"gs://{TEST_BUCKET}", "/")],
+                [Dataset(f"gs://{TEST_BUCKET}_dest", "dest_pre")],
+            ),
+            (
+                "source_pre/",
+                None,
+                [Dataset(f"gs://{TEST_BUCKET}", "source_pre")],
+                [Dataset(f"gs://{TEST_BUCKET}_dest", "/")],
+            ),
+            (
+                "source_pre/",
+                "dest_pre/",
+                [Dataset(f"gs://{TEST_BUCKET}", "source_pre")],
+                [Dataset(f"gs://{TEST_BUCKET}_dest", "dest_pre")],
+            ),
+            (
+                "source_pre",
+                "dest_pre",
+                [Dataset(f"gs://{TEST_BUCKET}", "/")],
+                [Dataset(f"gs://{TEST_BUCKET}_dest", "/")],
+            ),
+            (
+                "dir1/source_pre",
+                "dir2/dest_pre",
+                [Dataset(f"gs://{TEST_BUCKET}", "dir1")],
+                [Dataset(f"gs://{TEST_BUCKET}_dest", "dir2")],
+            ),
+            (
+                "",
+                "/",
+                [Dataset(f"gs://{TEST_BUCKET}", "/")],
+                [Dataset(f"gs://{TEST_BUCKET}_dest", "/")],
+            ),
+            (
+                "source/a.txt",
+                "target/",
+                [Dataset(f"gs://{TEST_BUCKET}", "source/a.txt")],
+                [Dataset(f"gs://{TEST_BUCKET}_dest", "target")],
+            ),
+        ),
+        ids=(
+            "no prefixes",
+            "dest prefix only",
+            "source prefix only",
+            "both with ending slash",
+            "both without ending slash",
+            "both as directory with prefix",
+            "both empty or root",
+            "source prefix is file path",
+        ),
+    )
     @mock.patch("airflow.providers.google.cloud.operators.gcs.TemporaryDirectory")
     @mock.patch("airflow.providers.google.cloud.operators.gcs.subprocess")
     @mock.patch("airflow.providers.google.cloud.operators.gcs.GCSHook")
-    def test_get_openlineage_facets_on_complete(self, mock_hook, mock_subprocess, mock_tempdir):
+    def test_get_openlineage_facets_on_complete(
+        self, mock_hook, mock_subprocess, mock_tempdir, source_prefix, dest_prefix, inputs, outputs
+    ):
         source_bucket = TEST_BUCKET
-        source_prefix = "source_prefix"
 
         destination_bucket = TEST_BUCKET + "_dest"
-        destination_prefix = "destination_prefix"
         destination = "destination"
 
         file1 = "file1"
@@ -508,8 +572,8 @@ class TestGCSTimeSpanFileTransformOperator:
 
         mock_tempdir.return_value.__enter__.side_effect = ["source", destination]
         mock_hook.return_value.list_by_timespan.return_value = [
-            f"{source_prefix}/{file1}",
-            f"{source_prefix}/{file2}",
+            f"{source_prefix or ''}{file1}",
+            f"{source_prefix or ''}{file2}",
         ]
 
         mock_proc = mock.MagicMock()
@@ -529,7 +593,7 @@ class TestGCSTimeSpanFileTransformOperator:
             source_prefix=source_prefix,
             source_gcp_conn_id="",
             destination_bucket=destination_bucket,
-            destination_prefix=destination_prefix,
+            destination_prefix=dest_prefix,
             destination_gcp_conn_id="",
             transform_script="script.py",
         )
@@ -541,32 +605,11 @@ class TestGCSTimeSpanFileTransformOperator:
             ]
             op.execute(context=context)
 
-        expected_inputs = [
-            Dataset(
-                namespace=f"gs://{source_bucket}",
-                name=f"{source_prefix}/{file1}",
-            ),
-            Dataset(
-                namespace=f"gs://{source_bucket}",
-                name=f"{source_prefix}/{file2}",
-            ),
-        ]
-        expected_outputs = [
-            Dataset(
-                namespace=f"gs://{destination_bucket}",
-                name=f"{destination_prefix}/{file1}",
-            ),
-            Dataset(
-                namespace=f"gs://{destination_bucket}",
-                name=f"{destination_prefix}/{file2}",
-            ),
-        ]
-
         lineage = op.get_openlineage_facets_on_complete(None)
-        assert len(lineage.inputs) == 2
-        assert len(lineage.outputs) == 2
-        assert lineage.inputs == expected_inputs
-        assert lineage.outputs == expected_outputs
+        assert len(lineage.inputs) == len(inputs)
+        assert len(lineage.outputs) == len(outputs)
+        assert sorted(lineage.inputs) == sorted(inputs)
+        assert sorted(lineage.outputs) == sorted(outputs)
 
 
 class TestGCSDeleteBucketOperator:

Reply via email to