This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 42f513e891e feat: Add OpenLineage support for transfer operators between GCS and SFTP (#45485)
42f513e891e is described below

commit 42f513e891e1794b47a69daf5a4fdafd89f2fdfd
Author: Kacper Muda <[email protected]>
AuthorDate: Wed Jan 8 23:55:52 2025 +0100

    feat: Add OpenLineage support for transfer operators between GCS and SFTP (#45485)
    
    Signed-off-by: Kacper Muda <[email protected]>
---
 .../google/cloud/transfers/gcs_to_sftp.py          | 40 +++++++++--
 .../google/cloud/transfers/sftp_to_gcs.py          | 39 +++++++++--
 .../google/cloud/transfers/test_gcs_to_sftp.py     | 80 ++++++++++++++++++++++
 .../google/cloud/transfers/test_sftp_to_gcs.py     | 47 +++++++++++++
 4 files changed, 197 insertions(+), 9 deletions(-)

diff --git a/providers/src/airflow/providers/google/cloud/transfers/gcs_to_sftp.py b/providers/src/airflow/providers/google/cloud/transfers/gcs_to_sftp.py
index 4210cdc041f..f529cef3613 100644
--- a/providers/src/airflow/providers/google/cloud/transfers/gcs_to_sftp.py
+++ b/providers/src/airflow/providers/google/cloud/transfers/gcs_to_sftp.py
@@ -21,6 +21,7 @@ from __future__ import annotations
 
 import os
 from collections.abc import Sequence
+from functools import cached_property
 from tempfile import NamedTemporaryFile
 from typing import TYPE_CHECKING
 
@@ -129,14 +130,16 @@ class GCSToSFTPOperator(BaseOperator):
         self.impersonation_chain = impersonation_chain
         self.sftp_dirs = None
 
+    @cached_property
+    def sftp_hook(self) -> SFTPHook:
+        """
+        Return the SFTP hook for ``sftp_conn_id``.
+
+        Built on first access and cached on the operator instance, so repeated
+        accesses (from ``execute`` and from the OpenLineage method) reuse one hook.
+        """
+        conn_id = self.sftp_conn_id
+        return SFTPHook(conn_id)
+
     def execute(self, context: Context):
         gcs_hook = GCSHook(
             gcp_conn_id=self.gcp_conn_id,
             impersonation_chain=self.impersonation_chain,
         )
 
-        sftp_hook = SFTPHook(self.sftp_conn_id)
-
         if WILDCARD in self.source_object:
             total_wildcards = self.source_object.count(WILDCARD)
             if total_wildcards > 1:
@@ -155,12 +158,12 @@ class GCSToSFTPOperator(BaseOperator):
 
             for source_object in objects:
                 destination_path = 
self._resolve_destination_path(source_object, prefix=prefix_dirname)
-                self._copy_single_object(gcs_hook, sftp_hook, source_object, 
destination_path)
+                self._copy_single_object(gcs_hook, self.sftp_hook, 
source_object, destination_path)
 
             self.log.info("Done. Uploaded '%d' files to %s", len(objects), 
self.destination_path)
         else:
             destination_path = 
self._resolve_destination_path(self.source_object)
-            self._copy_single_object(gcs_hook, sftp_hook, self.source_object, 
destination_path)
+            self._copy_single_object(gcs_hook, self.sftp_hook, 
self.source_object, destination_path)
             self.log.info("Done. Uploaded '%s' file to %s", 
self.source_object, destination_path)
 
     def _resolve_destination_path(self, source_object: str, prefix: str | None 
= None) -> str:
@@ -200,3 +203,32 @@ class GCSToSFTPOperator(BaseOperator):
         if self.move_object:
             self.log.info("Executing delete of gs://%s/%s", 
self.source_bucket, source_object)
             gcs_hook.delete(self.source_bucket, source_object)
+
+    def get_openlineage_facets_on_start(self):
+        """
+        Return OpenLineage input/output datasets for this transfer.
+
+        The input is the dataset in ``gs://<source_bucket>`` named from ``source_object``;
+        the output is the dataset in ``file://<remote_host>:<port>`` of the SFTP hook,
+        named from ``destination_path``.
+        """
+        from airflow.providers.common.compat.openlineage.facet import Dataset
+        from airflow.providers.google.cloud.openlineage.utils import extract_ds_name_from_gcs_path
+        from airflow.providers.openlineage.extractors import OperatorLineage
+
+        source_name = extract_ds_name_from_gcs_path(f"{self.source_object}")
+        base_dir = f"{self.destination_path}"
+
+        if self.keep_directory_structure:
+            # A root source name ("/") must not be joined as-is: os.path.join treats an
+            # absolute component as a restart and would drop base_dir entirely.
+            dest_name = os.path.join(base_dir, "" if source_name == "/" else source_name)
+        elif WILDCARD in self.source_object:
+            # Wildcard source without keep_directory_structure: report the destination itself.
+            dest_name = base_dir
+        else:
+            dest_name = os.path.join(base_dir, os.path.basename(self.source_object))
+
+        # Drop trailing slashes, but leave a bare root path untouched.
+        if dest_name != "/":
+            dest_name = dest_name.rstrip("/")
+
+        input_dataset = Dataset(namespace=f"gs://{self.source_bucket}", name=source_name)
+        output_dataset = Dataset(
+            namespace=f"file://{self.sftp_hook.remote_host}:{self.sftp_hook.port}",
+            name=dest_name,
+        )
+        return OperatorLineage(inputs=[input_dataset], outputs=[output_dataset])
diff --git a/providers/src/airflow/providers/google/cloud/transfers/sftp_to_gcs.py b/providers/src/airflow/providers/google/cloud/transfers/sftp_to_gcs.py
index 77bd28116cb..e08f2dd944e 100644
--- a/providers/src/airflow/providers/google/cloud/transfers/sftp_to_gcs.py
+++ b/providers/src/airflow/providers/google/cloud/transfers/sftp_to_gcs.py
@@ -21,6 +21,7 @@ from __future__ import annotations
 
 import os
 from collections.abc import Sequence
+from functools import cached_property
 from tempfile import NamedTemporaryFile
 from typing import TYPE_CHECKING
 
@@ -109,6 +110,10 @@ class SFTPToGCSOperator(BaseOperator):
         self.impersonation_chain = impersonation_chain
         self.sftp_prefetch = sftp_prefetch
 
+    @cached_property
+    def sftp_hook(self) -> SFTPHook:
+        """
+        Return the SFTP hook for ``sftp_conn_id``.
+
+        Built on first access and cached on the operator instance, so repeated
+        accesses (from ``execute`` and from the OpenLineage method) reuse one hook.
+        """
+        conn_id = self.sftp_conn_id
+        return SFTPHook(conn_id)
+
     def execute(self, context: Context):
         self.destination_path = 
self._set_destination_path(self.destination_path)
         self.destination_bucket = 
self._set_bucket_name(self.destination_bucket)
@@ -117,8 +122,6 @@ class SFTPToGCSOperator(BaseOperator):
             impersonation_chain=self.impersonation_chain,
         )
 
-        sftp_hook = SFTPHook(self.sftp_conn_id)
-
         if WILDCARD in self.source_path:
             total_wildcards = self.source_path.count(WILDCARD)
             if total_wildcards > 1:
@@ -130,7 +133,7 @@ class SFTPToGCSOperator(BaseOperator):
             prefix, delimiter = self.source_path.split(WILDCARD, 1)
             base_path = os.path.dirname(prefix)
 
-            files, _, _ = sftp_hook.get_tree_map(base_path, prefix=prefix, 
delimiter=delimiter)
+            files, _, _ = self.sftp_hook.get_tree_map(base_path, 
prefix=prefix, delimiter=delimiter)
 
             for file in files:
                 destination_path = file.replace(base_path, 
self.destination_path, 1)
@@ -140,13 +143,13 @@ class SFTPToGCSOperator(BaseOperator):
                 # retain the "/" prefix, if it has.
                 if not self.destination_path:
                     destination_path = destination_path.lstrip("/")
-                self._copy_single_object(gcs_hook, sftp_hook, file, 
destination_path)
+                self._copy_single_object(gcs_hook, self.sftp_hook, file, 
destination_path)
 
         else:
             destination_object = (
                 self.destination_path if self.destination_path else 
self.source_path.rsplit("/", 1)[1]
             )
-            self._copy_single_object(gcs_hook, sftp_hook, self.source_path, 
destination_object)
+            self._copy_single_object(gcs_hook, self.sftp_hook, 
self.source_path, destination_object)
 
     def _copy_single_object(
         self,
@@ -188,3 +191,29 @@ class SFTPToGCSOperator(BaseOperator):
     def _set_bucket_name(name: str) -> str:
         bucket = name if not name.startswith("gs://") else name[5:]
         return bucket.strip("/")
+
+    def get_openlineage_facets_on_start(self):
+        """
+        Return OpenLineage input/output datasets for this transfer.
+
+        The input is the dataset in ``file://<remote_host>:<port>`` of the SFTP hook, named
+        from ``source_path`` (only the part before a wildcard is used); the output is the
+        dataset in ``gs://<destination_bucket>``.
+        """
+        from airflow.providers.common.compat.openlineage.facet import Dataset
+        from airflow.providers.google.cloud.openlineage.utils import extract_ds_name_from_gcs_path
+        from airflow.providers.openlineage.extractors import OperatorLineage
+
+        source_name = extract_ds_name_from_gcs_path(self.source_path.split(WILDCARD, 1)[0])
+        # Re-add the leading "/" of an absolute SFTP source path to the extracted name.
+        if self.source_path.startswith("/") and source_name != "/":
+            source_name = "/" + source_name
+
+        if WILDCARD not in self.source_path and not self.destination_path:
+            # Same default object name as ``execute`` (last path segment). ``[-1]`` instead of
+            # ``[1]`` so a bare file name without "/" does not raise IndexError here.
+            dest_name = self.source_path.rsplit("/", 1)[-1]
+        else:
+            # ``destination_path`` may be None; pass "" rather than the literal string "None".
+            dest_name = extract_ds_name_from_gcs_path(self.destination_path or "")
+
+        return OperatorLineage(
+            inputs=[
+                Dataset(
+                    namespace=f"file://{self.sftp_hook.remote_host}:{self.sftp_hook.port}",
+                    name=source_name,
+                )
+            ],
+            outputs=[
+                Dataset(namespace="gs://" + self._set_bucket_name(self.destination_bucket), name=dest_name)
+            ],
+        )
diff --git a/providers/tests/google/cloud/transfers/test_gcs_to_sftp.py b/providers/tests/google/cloud/transfers/test_gcs_to_sftp.py
index f2ede74a8a7..e6aedd5e598 100644
--- a/providers/tests/google/cloud/transfers/test_gcs_to_sftp.py
+++ b/providers/tests/google/cloud/transfers/test_gcs_to_sftp.py
@@ -321,3 +321,83 @@ class TestGoogleCloudStorageToSFTPOperator:
         )
         with pytest.raises(AirflowException):
             operator.execute(None)
+
+    @pytest.mark.parametrize(
+        "source_object, destination_path, keep_directory_structure, expected_source, expected_destination",
+        [
+            (
+                "folder/test_object.txt",
+                "dest/dir",
+                True,
+                "folder/test_object.txt",
+                "dest/dir/folder/test_object.txt",
+            ),
+            (
+                "folder/test_object.txt",
+                "dest/dir/",
+                True,
+                "folder/test_object.txt",
+                "dest/dir/folder/test_object.txt",
+            ),
+            (
+                "folder/test_object.txt",
+                "dest/dir",
+                False,
+                "folder/test_object.txt",
+                "dest/dir/test_object.txt",
+            ),
+            ("folder/test_object.txt", "/", False, "folder/test_object.txt", "/test_object.txt"),
+            ("folder/test_object.txt", "/", True, "folder/test_object.txt", "/folder/test_object.txt"),
+            (
+                "folder/test_object.txt",
+                "dest/dir/dest_object.txt",
+                True,
+                "folder/test_object.txt",
+                "dest/dir/dest_object.txt/folder/test_object.txt",  # Dest path is always treated as "dir"
+            ),
+            (
+                "folder/test_object.txt",
+                "dest/dir/dest_object.txt",
+                False,
+                "folder/test_object.txt",
+                "dest/dir/dest_object.txt/test_object.txt",  # Dest path is always treated as "dir"
+            ),
+            # Wildcard sources: the expected source is the directory part before the "*".
+            ("folder/test_object*.txt", "dest/dir", True, "folder", "dest/dir/folder"),
+            ("folder/test_object*", "dest/dir", False, "folder", "dest/dir"),
+            # Whole-bucket wildcard: the expected source dataset is "/".
+            ("*", "/", True, "/", "/"),
+            ("*", "/dest/dir", True, "/", "/dest/dir"),
+            ("*", "/dest/dir", False, "/", "/dest/dir"),
+        ],
+    )
+    @mock.patch("airflow.providers.google.cloud.transfers.gcs_to_sftp.SFTPHook")
+    def test_get_openlineage_facets(
+        self,
+        sftp_hook_mock,
+        source_object,
+        destination_path,
+        keep_directory_structure,
+        expected_source,
+        expected_destination,
+    ):
+        """
+        GCS -> SFTP lineage yields one ``gs://`` input and one ``file://<host>:<port>`` output,
+        with no run or job facets.
+
+        ``SFTPHook`` is patched, so the output namespace is built from the mocked host and port.
+        """
+        sftp_hook_mock.return_value.remote_host = "11.222.33.44"
+        sftp_hook_mock.return_value.port = 22
+        operator = GCSToSFTPOperator(
+            task_id=TASK_ID,
+            source_bucket=TEST_BUCKET,
+            source_object=source_object,
+            destination_path=destination_path,
+            keep_directory_structure=keep_directory_structure,
+            move_object=False,
+            gcp_conn_id=GCP_CONN_ID,
+            sftp_conn_id=SFTP_CONN_ID,
+        )
+
+        result = operator.get_openlineage_facets_on_start()
+        assert not result.run_facets
+        assert not result.job_facets
+        assert len(result.inputs) == 1
+        assert len(result.outputs) == 1
+        assert result.inputs[0].namespace == f"gs://{TEST_BUCKET}"
+        assert result.inputs[0].name == expected_source
+        assert result.outputs[0].namespace == "file://11.222.33.44:22"
+        assert result.outputs[0].name == expected_destination
diff --git a/providers/tests/google/cloud/transfers/test_sftp_to_gcs.py b/providers/tests/google/cloud/transfers/test_sftp_to_gcs.py
index 7755ef0f025..a3bf24f9ed1 100644
--- a/providers/tests/google/cloud/transfers/test_sftp_to_gcs.py
+++ b/providers/tests/google/cloud/transfers/test_sftp_to_gcs.py
@@ -301,3 +301,50 @@ class TestSFTPToGCSOperator:
                 ),
             ]
         )
+
+    @pytest.mark.parametrize(
+        "source_object, destination_path, expected_source, expected_destination",
+        [
+            # "dest/dir" (no trailing slash, no extension) is expected to map to its parent "dest".
+            ("folder/test_object.txt", "dest/dir", "folder/test_object.txt", "dest"),
+            ("folder/test_object.txt", "dest/dir/", "folder/test_object.txt", "dest/dir"),
+            ("folder/test_object.txt", "/", "folder/test_object.txt", "/"),
+            (
+                "folder/test_object.txt",
+                "dest/dir/dest_object.txt",
+                "folder/test_object.txt",
+                "dest/dir/dest_object.txt",
+            ),
+            # Wildcard sources: the expected source is the path part before the "*".
+            ("folder/test_object*.txt", "dest/dir", "folder", "dest"),
+            ("folder/test_object/*", "/", "folder/test_object", "/"),
+            ("folder/test_object*", "/", "folder", "/"),
+            # A None destination_path is expected to yield the bucket root "/".
+            ("folder/test_object/*", None, "folder/test_object", "/"),
+            ("*", "/", "/", "/"),
+            ("/*", "/", "/", "/"),
+            ("/*", "dest/dir", "/", "dest"),
+        ],
+    )
+    @mock.patch("airflow.providers.google.cloud.transfers.sftp_to_gcs.SFTPHook")
+    def test_get_openlineage_facets(
+        self, sftp_hook_mock, source_object, destination_path, expected_source, expected_destination
+    ):
+        """
+        SFTP -> GCS lineage yields one ``file://<host>:<port>`` input and one ``gs://`` output,
+        with no run or job facets.
+
+        ``SFTPHook`` is patched, so the input namespace is built from the mocked host and port.
+        """
+        sftp_hook_mock.return_value.remote_host = "11.222.33.44"
+        sftp_hook_mock.return_value.port = 22
+        operator = SFTPToGCSOperator(
+            task_id=TASK_ID,
+            source_path=source_object,
+            destination_path=destination_path,
+            destination_bucket=TEST_BUCKET,
+            move_object=False,
+            gcp_conn_id=GCP_CONN_ID,
+            sftp_conn_id=SFTP_CONN_ID,
+        )
+
+        result = operator.get_openlineage_facets_on_start()
+        assert not result.run_facets
+        assert not result.job_facets
+        assert len(result.inputs) == 1
+        assert len(result.outputs) == 1
+        assert result.inputs[0].namespace == "file://11.222.33.44:22"
+        assert result.inputs[0].name == expected_source
+        assert result.outputs[0].namespace == f"gs://{TEST_BUCKET}"
+        assert result.outputs[0].name == expected_destination

Reply via email to