This is an automated email from the ASF dual-hosted git repository.
mobuchowski pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
     new 066708352e fix: OpenLineage datasets in GCSTimeSpanFileTransformOperator (#39064)
066708352e is described below
commit 066708352e6a6a06f213b65324e982f582019b8e
Author: Kacper Muda <[email protected]>
AuthorDate: Wed Apr 17 21:14:26 2024 +0200
fix: OpenLineage datasets in GCSTimeSpanFileTransformOperator (#39064)
Signed-off-by: Kacper Muda <[email protected]>
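
The operator now stores the interpolated source and destination prefixes on the
task instance and reports a single OpenLineage dataset per bucket, named after
the normalized prefix, instead of one dataset per transformed object. As a
rough standalone illustration (not part of the patch itself), the normalization
rule mirrors the _parse_prefix helper added below; the expected values are
taken from the parametrized test cases in this commit:

    from pathlib import Path

    def parse_prefix(pref: str) -> str:
        # Fall back to the parent directory when the last segment looks like
        # neither a file (no dot) nor a directory (no trailing slash).
        if "." not in pref.split("/")[-1] and not pref.endswith("/"):
            pref = Path(pref).parent.as_posix()
        # Collapse root-like results to "/"; otherwise drop the trailing slash.
        return "/" if pref in (".", "/", "") else pref.rstrip("/")

    assert parse_prefix("dir1/source_pre") == "dir1"        # bare prefix -> parent dir
    assert parse_prefix("source_pre/") == "source_pre"      # explicit dir is kept
    assert parse_prefix("source/a.txt") == "source/a.txt"   # file path is kept
    assert parse_prefix("") == "/"                          # empty -> bucket root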
---
airflow/providers/google/cloud/operators/gcs.py | 69 ++++++++------
tests/providers/google/cloud/operators/test_gcs.py | 105 +++++++++++++++------
2 files changed, 115 insertions(+), 59 deletions(-)
diff --git a/airflow/providers/google/cloud/operators/gcs.py b/airflow/providers/google/cloud/operators/gcs.py
index c311c8b4ed..6c72378a43 100644
--- a/airflow/providers/google/cloud/operators/gcs.py
+++ b/airflow/providers/google/cloud/operators/gcs.py
@@ -774,8 +774,8 @@ class GCSTimeSpanFileTransformOperator(GoogleCloudBaseOperator):
         self.upload_continue_on_fail = upload_continue_on_fail
         self.upload_num_attempts = upload_num_attempts
 
-        self._source_object_names: list[str] = []
-        self._destination_object_names: list[str] = []
+        self._source_prefix_interp: str | None = None
+        self._destination_prefix_interp: str | None = None
 
     def execute(self, context: Context) -> list[str]:
         # Define intervals and prefixes.
@@ -803,11 +803,11 @@ class GCSTimeSpanFileTransformOperator(GoogleCloudBaseOperator):
         timespan_start = timespan_start.in_timezone(timezone.utc)
         timespan_end = timespan_end.in_timezone(timezone.utc)
 
-        source_prefix_interp = GCSTimeSpanFileTransformOperator.interpolate_prefix(
+        self._source_prefix_interp = GCSTimeSpanFileTransformOperator.interpolate_prefix(
             self.source_prefix,
             timespan_start,
         )
-        destination_prefix_interp = GCSTimeSpanFileTransformOperator.interpolate_prefix(
+        self._destination_prefix_interp = GCSTimeSpanFileTransformOperator.interpolate_prefix(
             self.destination_prefix,
             timespan_start,
         )
@@ -828,9 +828,9 @@ class GCSTimeSpanFileTransformOperator(GoogleCloudBaseOperator):
             )
 
         # Fetch list of files.
-        self._source_object_names = source_hook.list_by_timespan(
+        blobs_to_transform = source_hook.list_by_timespan(
             bucket_name=self.source_bucket,
-            prefix=source_prefix_interp,
+            prefix=self._source_prefix_interp,
             timespan_start=timespan_start,
             timespan_end=timespan_end,
         )
@@ -840,7 +840,7 @@ class GCSTimeSpanFileTransformOperator(GoogleCloudBaseOperator):
             temp_output_dir_path = Path(temp_output_dir)
 
             # TODO: download in parallel.
-            for blob_to_transform in self._source_object_names:
+            for blob_to_transform in blobs_to_transform:
                 destination_file = temp_input_dir_path / blob_to_transform
                 destination_file.parent.mkdir(parents=True, exist_ok=True)
                 try:
@@ -877,6 +877,8 @@ class GCSTimeSpanFileTransformOperator(GoogleCloudBaseOperator):
 
             self.log.info("Transformation succeeded. Output temporarily located at %s", temp_output_dir_path)
 
+            files_uploaded = []
+
             # TODO: upload in parallel.
             for upload_file in temp_output_dir_path.glob("**/*"):
                 if upload_file.is_dir():
@@ -884,8 +886,8 @@ class GCSTimeSpanFileTransformOperator(GoogleCloudBaseOperator):
 
                 upload_file_name = str(upload_file.relative_to(temp_output_dir_path))
 
-                if self.destination_prefix is not None:
-                    upload_file_name = f"{destination_prefix_interp}/{upload_file_name}"
+                if self._destination_prefix_interp is not None:
+                    upload_file_name = f"{self._destination_prefix_interp.rstrip('/')}/{upload_file_name}"
 
                 self.log.info("Uploading file %s to %s", upload_file, upload_file_name)
@@ -897,35 +899,46 @@ class GCSTimeSpanFileTransformOperator(GoogleCloudBaseOperator):
                         chunk_size=self.chunk_size,
                         num_max_attempts=self.upload_num_attempts,
                     )
-                    self._destination_object_names.append(str(upload_file_name))
+                    files_uploaded.append(str(upload_file_name))
                 except GoogleCloudError:
                     if not self.upload_continue_on_fail:
                         raise
 
-        return self._destination_object_names
+        return files_uploaded
 
     def get_openlineage_facets_on_complete(self, task_instance):
-        """Implement on_complete as execute() resolves object names."""
+        """Implement on_complete as execute() resolves object prefixes."""
         from openlineage.client.run import Dataset
 
         from airflow.providers.openlineage.extractors import OperatorLineage
 
-        input_datasets = [
-            Dataset(
-                namespace=f"gs://{self.source_bucket}",
-                name=object_name,
-            )
-            for object_name in self._source_object_names
-        ]
-        output_datasets = [
-            Dataset(
-                namespace=f"gs://{self.destination_bucket}",
-                name=object_name,
-            )
-            for object_name in self._destination_object_names
-        ]
-
-        return OperatorLineage(inputs=input_datasets, outputs=output_datasets)
+        def _parse_prefix(pref):
+            # Use parent if not a file (dot not in name) and not a dir (ends with slash)
+            if "." not in pref.split("/")[-1] and not pref.endswith("/"):
+                pref = Path(pref).parent.as_posix()
+            return "/" if pref in (".", "/", "") else pref.rstrip("/")
+
+        input_prefix, output_prefix = "/", "/"
+        if self._source_prefix_interp is not None:
+            input_prefix = _parse_prefix(self._source_prefix_interp)
+
+        if self._destination_prefix_interp is not None:
+            output_prefix = _parse_prefix(self._destination_prefix_interp)
+
+        return OperatorLineage(
+            inputs=[
+                Dataset(
+                    namespace=f"gs://{self.source_bucket}",
+                    name=input_prefix,
+                )
+            ],
+            outputs=[
+                Dataset(
+                    namespace=f"gs://{self.destination_bucket}",
+                    name=output_prefix,
+                )
+            ],
+        )
 
 
 class GCSDeleteBucketOperator(GoogleCloudBaseOperator):
diff --git a/tests/providers/google/cloud/operators/test_gcs.py b/tests/providers/google/cloud/operators/test_gcs.py
index 2eb96682bd..6236aa5f23 100644
--- a/tests/providers/google/cloud/operators/test_gcs.py
+++ b/tests/providers/google/cloud/operators/test_gcs.py
@@ -21,6 +21,7 @@ from datetime import datetime, timedelta, timezone
 from pathlib import Path
 from unittest import mock
 
+import pytest
 from openlineage.client.facet import (
     LifecycleStateChange,
     LifecycleStateChangeDatasetFacet,
@@ -483,15 +484,78 @@ class TestGCSTimeSpanFileTransformOperator:
             ]
         )
 
+    @pytest.mark.parametrize(
+        ("source_prefix", "dest_prefix", "inputs", "outputs"),
+        (
+            (
+                None,
+                None,
+                [Dataset(f"gs://{TEST_BUCKET}", "/")],
+                [Dataset(f"gs://{TEST_BUCKET}_dest", "/")],
+            ),
+            (
+                None,
+                "dest_pre/",
+                [Dataset(f"gs://{TEST_BUCKET}", "/")],
+                [Dataset(f"gs://{TEST_BUCKET}_dest", "dest_pre")],
+            ),
+            (
+                "source_pre/",
+                None,
+                [Dataset(f"gs://{TEST_BUCKET}", "source_pre")],
+                [Dataset(f"gs://{TEST_BUCKET}_dest", "/")],
+            ),
+            (
+                "source_pre/",
+                "dest_pre/",
+                [Dataset(f"gs://{TEST_BUCKET}", "source_pre")],
+                [Dataset(f"gs://{TEST_BUCKET}_dest", "dest_pre")],
+            ),
+            (
+                "source_pre",
+                "dest_pre",
+                [Dataset(f"gs://{TEST_BUCKET}", "/")],
+                [Dataset(f"gs://{TEST_BUCKET}_dest", "/")],
+            ),
+            (
+                "dir1/source_pre",
+                "dir2/dest_pre",
+                [Dataset(f"gs://{TEST_BUCKET}", "dir1")],
+                [Dataset(f"gs://{TEST_BUCKET}_dest", "dir2")],
+            ),
+            (
+                "",
+                "/",
+                [Dataset(f"gs://{TEST_BUCKET}", "/")],
+                [Dataset(f"gs://{TEST_BUCKET}_dest", "/")],
+            ),
+            (
+                "source/a.txt",
+                "target/",
+                [Dataset(f"gs://{TEST_BUCKET}", "source/a.txt")],
+                [Dataset(f"gs://{TEST_BUCKET}_dest", "target")],
+            ),
+        ),
+        ids=(
+            "no prefixes",
+            "dest prefix only",
+            "source prefix only",
+            "both with ending slash",
+            "both without ending slash",
+            "both as directory with prefix",
+            "both empty or root",
+            "source prefix is file path",
+        ),
+    )
     @mock.patch("airflow.providers.google.cloud.operators.gcs.TemporaryDirectory")
     @mock.patch("airflow.providers.google.cloud.operators.gcs.subprocess")
     @mock.patch("airflow.providers.google.cloud.operators.gcs.GCSHook")
-    def test_get_openlineage_facets_on_complete(self, mock_hook, mock_subprocess, mock_tempdir):
+    def test_get_openlineage_facets_on_complete(
+        self, mock_hook, mock_subprocess, mock_tempdir, source_prefix, dest_prefix, inputs, outputs
+    ):
         source_bucket = TEST_BUCKET
-        source_prefix = "source_prefix"
         destination_bucket = TEST_BUCKET + "_dest"
-        destination_prefix = "destination_prefix"
         destination = "destination"
 
         file1 = "file1"
@@ -508,8 +572,8 @@ class TestGCSTimeSpanFileTransformOperator:
 
         mock_tempdir.return_value.__enter__.side_effect = ["source", destination]
 
         mock_hook.return_value.list_by_timespan.return_value = [
-            f"{source_prefix}/{file1}",
-            f"{source_prefix}/{file2}",
+            f"{source_prefix or ''}{file1}",
+            f"{source_prefix or ''}{file2}",
         ]
 
         mock_proc = mock.MagicMock()
@@ -529,7 +593,7 @@ class TestGCSTimeSpanFileTransformOperator:
             source_prefix=source_prefix,
             source_gcp_conn_id="",
             destination_bucket=destination_bucket,
-            destination_prefix=destination_prefix,
+            destination_prefix=dest_prefix,
             destination_gcp_conn_id="",
             transform_script="script.py",
         )
@@ -541,32 +605,11 @@ class TestGCSTimeSpanFileTransformOperator:
         ]
 
         op.execute(context=context)
 
-        expected_inputs = [
-            Dataset(
-                namespace=f"gs://{source_bucket}",
-                name=f"{source_prefix}/{file1}",
-            ),
-            Dataset(
-                namespace=f"gs://{source_bucket}",
-                name=f"{source_prefix}/{file2}",
-            ),
-        ]
-        expected_outputs = [
-            Dataset(
-                namespace=f"gs://{destination_bucket}",
-                name=f"{destination_prefix}/{file1}",
-            ),
-            Dataset(
-                namespace=f"gs://{destination_bucket}",
-                name=f"{destination_prefix}/{file2}",
-            ),
-        ]
-
         lineage = op.get_openlineage_facets_on_complete(None)
-        assert len(lineage.inputs) == 2
-        assert len(lineage.outputs) == 2
-        assert lineage.inputs == expected_inputs
-        assert lineage.outputs == expected_outputs
+        assert len(lineage.inputs) == len(inputs)
+        assert len(lineage.outputs) == len(outputs)
+        assert sorted(lineage.inputs) == sorted(inputs)
+        assert sorted(lineage.outputs) == sorted(outputs)
 
 
 class TestGCSDeleteBucketOperator: