This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 42f513e891e feat: Add OpenLineage support for transfer operators between GCS and SFTP (#45485)
42f513e891e is described below
commit 42f513e891e1794b47a69daf5a4fdafd89f2fdfd
Author: Kacper Muda <[email protected]>
AuthorDate: Wed Jan 8 23:55:52 2025 +0100
feat: Add OpenLineage support for transfer operators between GCS and SFTP (#45485)
Signed-off-by: Kacper Muda <[email protected]>
---
.../google/cloud/transfers/gcs_to_sftp.py | 40 +++++++++--
.../google/cloud/transfers/sftp_to_gcs.py | 39 +++++++++--
.../google/cloud/transfers/test_gcs_to_sftp.py | 80 ++++++++++++++++++++++
.../google/cloud/transfers/test_sftp_to_gcs.py | 47 +++++++++++++
4 files changed, 197 insertions(+), 9 deletions(-)
diff --git
a/providers/src/airflow/providers/google/cloud/transfers/gcs_to_sftp.py
b/providers/src/airflow/providers/google/cloud/transfers/gcs_to_sftp.py
index 4210cdc041f..f529cef3613 100644
--- a/providers/src/airflow/providers/google/cloud/transfers/gcs_to_sftp.py
+++ b/providers/src/airflow/providers/google/cloud/transfers/gcs_to_sftp.py
@@ -21,6 +21,7 @@ from __future__ import annotations
import os
from collections.abc import Sequence
+from functools import cached_property
from tempfile import NamedTemporaryFile
from typing import TYPE_CHECKING
@@ -129,14 +130,16 @@ class GCSToSFTPOperator(BaseOperator):
self.impersonation_chain = impersonation_chain
self.sftp_dirs = None
+    @cached_property
+    def sftp_hook(self) -> SFTPHook:
+        """SFTP hook for ``sftp_conn_id``, cached on the instance so execute() and lineage share it."""
+        hook = SFTPHook(self.sftp_conn_id)
+        return hook
+
def execute(self, context: Context):
gcs_hook = GCSHook(
gcp_conn_id=self.gcp_conn_id,
impersonation_chain=self.impersonation_chain,
)
- sftp_hook = SFTPHook(self.sftp_conn_id)
-
if WILDCARD in self.source_object:
total_wildcards = self.source_object.count(WILDCARD)
if total_wildcards > 1:
@@ -155,12 +158,12 @@ class GCSToSFTPOperator(BaseOperator):
for source_object in objects:
destination_path =
self._resolve_destination_path(source_object, prefix=prefix_dirname)
- self._copy_single_object(gcs_hook, sftp_hook, source_object,
destination_path)
+ self._copy_single_object(gcs_hook, self.sftp_hook,
source_object, destination_path)
self.log.info("Done. Uploaded '%d' files to %s", len(objects),
self.destination_path)
else:
destination_path =
self._resolve_destination_path(self.source_object)
- self._copy_single_object(gcs_hook, sftp_hook, self.source_object,
destination_path)
+ self._copy_single_object(gcs_hook, self.sftp_hook,
self.source_object, destination_path)
self.log.info("Done. Uploaded '%s' file to %s",
self.source_object, destination_path)
def _resolve_destination_path(self, source_object: str, prefix: str | None
= None) -> str:
@@ -200,3 +203,32 @@ class GCSToSFTPOperator(BaseOperator):
if self.move_object:
self.log.info("Executing delete of gs://%s/%s",
self.source_bucket, source_object)
gcs_hook.delete(self.source_bucket, source_object)
+
+    def get_openlineage_facets_on_start(self):
+        """
+        Describe this transfer as OpenLineage input/output datasets.
+
+        The input is named from ``source_object`` within ``gs://{source_bucket}``; the output is the
+        corresponding path on the SFTP server, in the ``file://{remote_host}:{port}`` namespace taken
+        from ``self.sftp_hook``. Only operator arguments and the hook's host/port are used.
+        """
+        from airflow.providers.common.compat.openlineage.facet import Dataset
+        from airflow.providers.google.cloud.openlineage.utils import extract_ds_name_from_gcs_path
+        from airflow.providers.openlineage.extractors import OperatorLineage
+
+        input_name = extract_ds_name_from_gcs_path(f"{self.source_object}")
+
+        output_name = f"{self.destination_path}"
+        if self.keep_directory_structure:
+            # The source name is recreated below destination_path; a root-level source ("/")
+            # contributes nothing, so the output is destination_path itself.
+            output_name = os.path.join(output_name, "" if input_name == "/" else input_name)
+        elif WILDCARD not in self.source_object:
+            # A single object is placed directly under destination_path by its base name.
+            output_name = os.path.join(output_name, os.path.basename(self.source_object))
+        # Otherwise (wildcard with a flat layout) destination_path itself is the output.
+
+        # Drop trailing slashes, but leave the root path "/" intact.
+        if output_name != "/":
+            output_name = output_name.rstrip("/")
+
+        sftp_namespace = f"file://{self.sftp_hook.remote_host}:{self.sftp_hook.port}"
+        return OperatorLineage(
+            inputs=[Dataset(namespace=f"gs://{self.source_bucket}", name=input_name)],
+            outputs=[Dataset(namespace=sftp_namespace, name=output_name)],
+        )
diff --git
a/providers/src/airflow/providers/google/cloud/transfers/sftp_to_gcs.py
b/providers/src/airflow/providers/google/cloud/transfers/sftp_to_gcs.py
index 77bd28116cb..e08f2dd944e 100644
--- a/providers/src/airflow/providers/google/cloud/transfers/sftp_to_gcs.py
+++ b/providers/src/airflow/providers/google/cloud/transfers/sftp_to_gcs.py
@@ -21,6 +21,7 @@ from __future__ import annotations
import os
from collections.abc import Sequence
+from functools import cached_property
from tempfile import NamedTemporaryFile
from typing import TYPE_CHECKING
@@ -109,6 +110,10 @@ class SFTPToGCSOperator(BaseOperator):
self.impersonation_chain = impersonation_chain
self.sftp_prefetch = sftp_prefetch
+    @cached_property
+    def sftp_hook(self) -> SFTPHook:
+        """SFTP hook for ``sftp_conn_id``, cached on the instance so execute() and lineage share it."""
+        hook = SFTPHook(self.sftp_conn_id)
+        return hook
+
def execute(self, context: Context):
self.destination_path =
self._set_destination_path(self.destination_path)
self.destination_bucket =
self._set_bucket_name(self.destination_bucket)
@@ -117,8 +122,6 @@ class SFTPToGCSOperator(BaseOperator):
impersonation_chain=self.impersonation_chain,
)
- sftp_hook = SFTPHook(self.sftp_conn_id)
-
if WILDCARD in self.source_path:
total_wildcards = self.source_path.count(WILDCARD)
if total_wildcards > 1:
@@ -130,7 +133,7 @@ class SFTPToGCSOperator(BaseOperator):
prefix, delimiter = self.source_path.split(WILDCARD, 1)
base_path = os.path.dirname(prefix)
- files, _, _ = sftp_hook.get_tree_map(base_path, prefix=prefix,
delimiter=delimiter)
+ files, _, _ = self.sftp_hook.get_tree_map(base_path,
prefix=prefix, delimiter=delimiter)
for file in files:
destination_path = file.replace(base_path,
self.destination_path, 1)
@@ -140,13 +143,13 @@ class SFTPToGCSOperator(BaseOperator):
# retain the "/" prefix, if it has.
if not self.destination_path:
destination_path = destination_path.lstrip("/")
- self._copy_single_object(gcs_hook, sftp_hook, file,
destination_path)
+ self._copy_single_object(gcs_hook, self.sftp_hook, file,
destination_path)
else:
destination_object = (
self.destination_path if self.destination_path else
self.source_path.rsplit("/", 1)[1]
)
- self._copy_single_object(gcs_hook, sftp_hook, self.source_path,
destination_object)
+ self._copy_single_object(gcs_hook, self.sftp_hook,
self.source_path, destination_object)
def _copy_single_object(
self,
@@ -188,3 +191,29 @@ class SFTPToGCSOperator(BaseOperator):
def _set_bucket_name(name: str) -> str:
bucket = name if not name.startswith("gs://") else name[5:]
return bucket.strip("/")
+
+    def get_openlineage_facets_on_start(self):
+        """
+        Describe this transfer as OpenLineage input/output datasets.
+
+        The input lives in the ``file://{remote_host}:{port}`` namespace of ``self.sftp_hook`` and the
+        output in ``gs://{destination_bucket}``. Names are derived from ``source_path`` and
+        ``destination_path`` only; no file on either side is listed or read.
+        """
+        from airflow.providers.common.compat.openlineage.facet import Dataset
+        from airflow.providers.google.cloud.openlineage.utils import extract_ds_name_from_gcs_path
+        from airflow.providers.openlineage.extractors import OperatorLineage
+
+        # Only the part of source_path before the wildcard identifies a location.
+        source_name = extract_ds_name_from_gcs_path(self.source_path.split(WILDCARD, 1)[0])
+        # The helper drops leading slashes; restore them so absolute SFTP paths stay absolute.
+        if self.source_path.startswith("/") and source_name != "/":
+            source_name = "/" + source_name
+
+        if WILDCARD not in self.source_path and not self.destination_path:
+            # Same rule as execute(): a single file without destination_path is stored under its
+            # base name. ``[-1]`` also covers a source_path with no "/", where ``[1]`` would raise
+            # IndexError and abort lineage extraction.
+            dest_name = self.source_path.rsplit("/", 1)[-1]
+        else:
+            # destination_path can be None; pass "" instead of letting an f-string turn it into
+            # the literal path "None".
+            dest_name = extract_ds_name_from_gcs_path(self.destination_path or "")
+
+        sftp_namespace = f"file://{self.sftp_hook.remote_host}:{self.sftp_hook.port}"
+        gcs_namespace = f"gs://{self._set_bucket_name(self.destination_bucket)}"
+        return OperatorLineage(
+            inputs=[Dataset(namespace=sftp_namespace, name=source_name)],
+            outputs=[Dataset(namespace=gcs_namespace, name=dest_name)],
+        )
diff --git a/providers/tests/google/cloud/transfers/test_gcs_to_sftp.py
b/providers/tests/google/cloud/transfers/test_gcs_to_sftp.py
index f2ede74a8a7..e6aedd5e598 100644
--- a/providers/tests/google/cloud/transfers/test_gcs_to_sftp.py
+++ b/providers/tests/google/cloud/transfers/test_gcs_to_sftp.py
@@ -321,3 +321,83 @@ class TestGoogleCloudStorageToSFTPOperator:
)
with pytest.raises(AirflowException):
operator.execute(None)
+
+    @pytest.mark.parametrize(
+        "source_object, destination_path, keep_directory_structure, expected_source, expected_destination",
+        [
+            (
+                "folder/test_object.txt",
+                "dest/dir",
+                True,
+                "folder/test_object.txt",
+                "dest/dir/folder/test_object.txt",
+            ),
+            (
+                "folder/test_object.txt",
+                "dest/dir/",
+                True,
+                "folder/test_object.txt",
+                "dest/dir/folder/test_object.txt",
+            ),
+            (
+                "folder/test_object.txt",
+                "dest/dir",
+                False,
+                "folder/test_object.txt",
+                "dest/dir/test_object.txt",
+            ),
+            ("folder/test_object.txt", "/", False, "folder/test_object.txt", "/test_object.txt"),
+            ("folder/test_object.txt", "/", True, "folder/test_object.txt", "/folder/test_object.txt"),
+            # destination_path is always treated as a directory, even when it looks like a file.
+            (
+                "folder/test_object.txt",
+                "dest/dir/dest_object.txt",
+                True,
+                "folder/test_object.txt",
+                "dest/dir/dest_object.txt/folder/test_object.txt",
+            ),
+            (
+                "folder/test_object.txt",
+                "dest/dir/dest_object.txt",
+                False,
+                "folder/test_object.txt",
+                "dest/dir/dest_object.txt/test_object.txt",
+            ),
+            ("folder/test_object*.txt", "dest/dir", True, "folder", "dest/dir/folder"),
+            ("folder/test_object*", "dest/dir", False, "folder", "dest/dir"),
+            ("*", "/", True, "/", "/"),
+            ("*", "/dest/dir", True, "/", "/dest/dir"),
+            ("*", "/dest/dir", False, "/", "/dest/dir"),
+        ],
+    )
+    @mock.patch("airflow.providers.google.cloud.transfers.gcs_to_sftp.SFTPHook")
+    def test_get_openlineage_facets(
+        self,
+        mock_sftp_hook,
+        source_object,
+        destination_path,
+        keep_directory_structure,
+        expected_source,
+        expected_destination,
+    ):
+        hook = mock_sftp_hook.return_value
+        hook.remote_host = "11.222.33.44"
+        hook.port = 22
+        operator = GCSToSFTPOperator(
+            task_id=TASK_ID,
+            source_bucket=TEST_BUCKET,
+            source_object=source_object,
+            destination_path=destination_path,
+            keep_directory_structure=keep_directory_structure,
+            move_object=False,
+            gcp_conn_id=GCP_CONN_ID,
+            sftp_conn_id=SFTP_CONN_ID,
+        )
+
+        lineage = operator.get_openlineage_facets_on_start()
+
+        assert not lineage.run_facets
+        assert not lineage.job_facets
+        assert len(lineage.inputs) == 1
+        assert len(lineage.outputs) == 1
+        source_ds, dest_ds = lineage.inputs[0], lineage.outputs[0]
+        assert (source_ds.namespace, source_ds.name) == (f"gs://{TEST_BUCKET}", expected_source)
+        assert (dest_ds.namespace, dest_ds.name) == ("file://11.222.33.44:22", expected_destination)
diff --git a/providers/tests/google/cloud/transfers/test_sftp_to_gcs.py
b/providers/tests/google/cloud/transfers/test_sftp_to_gcs.py
index 7755ef0f025..a3bf24f9ed1 100644
--- a/providers/tests/google/cloud/transfers/test_sftp_to_gcs.py
+++ b/providers/tests/google/cloud/transfers/test_sftp_to_gcs.py
@@ -301,3 +301,50 @@ class TestSFTPToGCSOperator:
),
]
)
+
+    @pytest.mark.parametrize(
+        "source_path, destination_path, expected_source, expected_destination",
+        [
+            ("folder/test_object.txt", "dest/dir", "folder/test_object.txt", "dest"),
+            ("folder/test_object.txt", "dest/dir/", "folder/test_object.txt", "dest/dir"),
+            ("folder/test_object.txt", "/", "folder/test_object.txt", "/"),
+            (
+                "folder/test_object.txt",
+                "dest/dir/dest_object.txt",
+                "folder/test_object.txt",
+                "dest/dir/dest_object.txt",
+            ),
+            ("folder/test_object*.txt", "dest/dir", "folder", "dest"),
+            ("folder/test_object/*", "/", "folder/test_object", "/"),
+            ("folder/test_object*", "/", "folder", "/"),
+            ("folder/test_object/*", None, "folder/test_object", "/"),
+            ("*", "/", "/", "/"),
+            ("/*", "/", "/", "/"),
+            ("/*", "dest/dir", "/", "dest"),
+        ],
+    )
+    @mock.patch("airflow.providers.google.cloud.transfers.sftp_to_gcs.SFTPHook")
+    def test_get_openlineage_facets(
+        self, mock_sftp_hook, source_path, destination_path, expected_source, expected_destination
+    ):
+        hook = mock_sftp_hook.return_value
+        hook.remote_host = "11.222.33.44"
+        hook.port = 22
+        operator = SFTPToGCSOperator(
+            task_id=TASK_ID,
+            source_path=source_path,
+            destination_path=destination_path,
+            destination_bucket=TEST_BUCKET,
+            move_object=False,
+            gcp_conn_id=GCP_CONN_ID,
+            sftp_conn_id=SFTP_CONN_ID,
+        )
+
+        lineage = operator.get_openlineage_facets_on_start()
+
+        assert not lineage.run_facets
+        assert not lineage.job_facets
+        assert len(lineage.inputs) == 1
+        assert len(lineage.outputs) == 1
+        source_ds, dest_ds = lineage.inputs[0], lineage.outputs[0]
+        assert (source_ds.namespace, source_ds.name) == ("file://11.222.33.44:22", expected_source)
+        assert (dest_ds.namespace, dest_ds.name) == (f"gs://{TEST_BUCKET}", expected_destination)