This is an automated email from the ASF dual-hosted git repository.

mobuchowski pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 6b880844ad openlineage, sftp: add OpenLineage support for sftp 
provider (#31360)
6b880844ad is described below

commit 6b880844ade6036954d5343f6a74a241b3865153
Author: Maciej Obuchowski <[email protected]>
AuthorDate: Tue Jul 25 11:17:04 2023 +0200

    openlineage, sftp: add OpenLineage support for sftp provider (#31360)
    
    Signed-off-by: Maciej Obuchowski <[email protected]>
---
 airflow/providers/sftp/operators/sftp.py    |  82 +++++++++++++++++++++
 generated/provider_dependencies.json        |   1 +
 tests/providers/sftp/operators/test_sftp.py | 110 +++++++++++++++++++++++++++-
 3 files changed, 192 insertions(+), 1 deletion(-)

diff --git a/airflow/providers/sftp/operators/sftp.py 
b/airflow/providers/sftp/operators/sftp.py
index 0920387faa..8da4b3f332 100644
--- a/airflow/providers/sftp/operators/sftp.py
+++ b/airflow/providers/sftp/operators/sftp.py
@@ -19,10 +19,13 @@
 from __future__ import annotations
 
 import os
+import socket
 import warnings
 from pathlib import Path
 from typing import Any, Sequence
 
+import paramiko
+
 from airflow.exceptions import AirflowException, 
AirflowProviderDeprecationWarning
 from airflow.models import BaseOperator
 from airflow.providers.sftp.hooks.sftp import SFTPHook
@@ -188,3 +191,82 @@ class SFTPOperator(BaseOperator):
             raise AirflowException(f"Error while transferring {file_msg}, 
error: {str(e)}")
 
         return self.local_filepath
+
+    def get_openlineage_facets_on_start(self):
+        """
+        Return the OpenLineage datasets for the files transferred by this task.
+
+        Every file becomes ``Dataset(namespace="file://<host>:<port>", name=<path>)``:
+            local files:  host is the local machine, port is the default SSH port
+            remote files: host is the SFTP server, port is taken from the hook
+
+        Host names are resolved to IP addresses when possible; if resolution fails
+        the unresolved name is used and a warning is logged.
+
+        For a GET operation the remote files are the inputs and the local files are
+        the outputs; for any other operation the direction is reversed.
+
+        :return: an ``OperatorLineage`` holding the input and output datasets.
+        """
+        from openlineage.client.run import Dataset
+
+        from airflow.providers.openlineage.extractors import OperatorLineage
+
+        scheme = "file"
+        local_host = socket.gethostname()
+        try:
+            local_host = socket.gethostbyname(local_host)
+        except (OSError, UnicodeError) as e:
+            # Best effort only: keep the plain host name when it cannot be resolved.
+            self.log.warning(
+                "Failed to resolve local hostname. "
+                "Using the hostname got by socket.gethostname() without resolution. %s",
+                e,
+                exc_info=True,
+            )
+
+        hook = self.sftp_hook or self.ssh_hook or SFTPHook(ssh_conn_id=self.ssh_conn_id)
+
+        if self.remote_host is not None:
+            remote_host = self.remote_host
+        else:
+            remote_host = hook.get_connection(hook.ssh_conn_id).host
+
+        try:
+            remote_host = socket.gethostbyname(remote_host)
+        except (OSError, UnicodeError) as e:
+            # UnicodeError covers host names that cannot be IDNA-encoded.
+            self.log.warning(
+                "Failed to resolve remote hostname. Using the provided hostname without resolution. %s",
+                e,
+                exc_info=True,
+            )
+
+        if hasattr(hook, "port"):
+            remote_port = hook.port
+        elif hasattr(hook, "ssh_hook"):
+            remote_port = hook.ssh_hook.port
+        else:
+            # Neither attribute is available: let _get_namespace apply the default
+            # SSH port instead of failing on an unbound variable.
+            remote_port = None
+
+        # Since v4.1.0, SFTPOperator accepts both a string (single file) and a list of
+        # strings (multiple files) as local_filepath and remote_filepath, and internally
+        # keeps them as list in both cases. But before 4.1.0, only single string is
+        # allowed. So we consider both cases here for backward compatibility.
+        if isinstance(self.local_filepath, str):
+            local_filepath = [self.local_filepath]
+        else:
+            local_filepath = self.local_filepath
+        if isinstance(self.remote_filepath, str):
+            remote_filepath = [self.remote_filepath]
+        else:
+            remote_filepath = self.remote_filepath
+
+        local_datasets = [
+            Dataset(namespace=self._get_namespace(scheme, local_host, None, path), name=path)
+            for path in local_filepath
+        ]
+        remote_datasets = [
+            Dataset(namespace=self._get_namespace(scheme, remote_host, remote_port, path), name=path)
+            for path in remote_filepath
+        ]
+
+        if self.operation.lower() == SFTPOperation.GET:
+            inputs = remote_datasets
+            outputs = local_datasets
+        else:
+            inputs = local_datasets
+            outputs = remote_datasets
+
+        return OperatorLineage(
+            inputs=inputs,
+            outputs=outputs,
+        )
+
+    def _get_namespace(self, scheme, host, port, path) -> str:
+        """
+        Build the dataset namespace ``<scheme>://<host>:<port>``.
+
+        A falsy ``port`` is replaced by ``paramiko.config.SSH_PORT``. ``path`` is
+        accepted for the caller's convenience but is not part of the namespace.
+        """
+        effective_port = port if port else paramiko.config.SSH_PORT
+        return f"{scheme}://{host}:{effective_port}"
diff --git a/generated/provider_dependencies.json 
b/generated/provider_dependencies.json
index 091a45cf36..f4eeb8a273 100644
--- a/generated/provider_dependencies.json
+++ b/generated/provider_dependencies.json
@@ -785,6 +785,7 @@
       "apache-airflow>=2.4.0"
     ],
     "cross-providers-deps": [
+      "openlineage",
       "ssh"
     ],
     "excluded-python-versions": []
diff --git a/tests/providers/sftp/operators/test_sftp.py 
b/tests/providers/sftp/operators/test_sftp.py
index 92137e2a5c..8adb93f7db 100644
--- a/tests/providers/sftp/operators/test_sftp.py
+++ b/tests/providers/sftp/operators/test_sftp.py
@@ -18,13 +18,16 @@
 from __future__ import annotations
 
 import os
+import socket
 from base64 import b64encode
 from unittest import mock
 
+import paramiko
 import pytest
+from openlineage.client.run import Dataset
 
 from airflow.exceptions import AirflowException
-from airflow.models import DAG
+from airflow.models import DAG, Connection
 from airflow.providers.sftp.hooks.sftp import SFTPHook
 from airflow.providers.sftp.operators.sftp import SFTPOperation, SFTPOperator
 from airflow.providers.ssh.hooks.ssh import SSHHook
@@ -36,6 +39,18 @@ from tests.test_utils.config import conf_vars
 DEFAULT_DATE = datetime(2017, 1, 1)
 TEST_CONN_ID = "conn_id_for_testing"
 
+# Single-file paths used as local_filepath / remote_filepath in the lineage tests.
+LOCAL_FILEPATH = "/path/local"
+REMOTE_FILEPATH = "/path/remote"
+# Expected dataset for the local file.
+# NOTE(review): the local address is resolved while this module is imported, so the
+# import itself raises if the local host name cannot be resolved - confirm this is acceptable.
+LOCAL_DATASET = [
+    Dataset(namespace=f"file://{socket.gethostbyname(socket.gethostname())}:22", name=LOCAL_FILEPATH)
+]
+# Expected dataset for the remote file. The namespace keeps the literal host name;
+# presumably "remotehost" is not resolvable where the tests run - TODO confirm.
+REMOTE_DATASET = [Dataset(namespace="file://remotehost:22", name=REMOTE_FILEPATH)]
+
+# Parametrization cases: (operation, (expected inputs, expected outputs)).
+TEST_GET_PUT_PARAMS = [
+    (SFTPOperation.GET, (REMOTE_DATASET, LOCAL_DATASET)),
+    (SFTPOperation.PUT, (LOCAL_DATASET, REMOTE_DATASET)),
+]
+
 
 class TestSFTPOperator:
     def setup_method(self):
@@ -478,3 +493,96 @@ class TestSFTPOperator:
         return_value = sftp_op.execute(None)
         assert isinstance(return_value, str)
         assert return_value == local_filepath
+
+    @pytest.mark.parametrize("operation, expected", TEST_GET_PUT_PARAMS)
+    @mock.patch("airflow.providers.ssh.hooks.ssh.SSHHook.get_conn", spec=paramiko.SSHClient)
+    @mock.patch("airflow.providers.ssh.hooks.ssh.SSHHook.get_connection", spec=Connection)
+    def test_extract_ssh_conn_id(self, get_connection, get_conn, operation, expected):
+        """Check the lineage datasets when the operator is configured with ``ssh_conn_id`` only."""
+        expected_inputs, expected_outputs = expected
+        get_connection.return_value = Connection(
+            conn_id="sftp_conn_id", conn_type="sftp", host="remotehost", port=22
+        )
+
+        operator = SFTPOperator(
+            task_id="sftp_task",
+            ssh_conn_id="sftp_conn_id",
+            dag=DAG("sftp_dag"),
+            start_date=timezone.utcnow(),
+            local_filepath=LOCAL_FILEPATH,
+            remote_filepath=REMOTE_FILEPATH,
+            operation=operation,
+        )
+        lineage = operator.get_openlineage_facets_on_start()
+
+        assert lineage.inputs == expected_inputs
+        assert lineage.outputs == expected_outputs
+
+    @pytest.mark.parametrize("operation, expected", TEST_GET_PUT_PARAMS)
+    @mock.patch("airflow.providers.ssh.hooks.ssh.SSHHook.get_conn", spec=paramiko.SSHClient)
+    @mock.patch("airflow.providers.ssh.hooks.ssh.SSHHook.get_connection", spec=Connection)
+    def test_extract_sftp_hook(self, get_connection, get_conn, operation, expected):
+        """Check the lineage datasets when the operator is given a ready-made ``SFTPHook``."""
+        expected_inputs, expected_outputs = expected
+        get_connection.return_value = Connection(
+            conn_id="sftp_conn_id", conn_type="sftp", host="remotehost", port=22
+        )
+
+        operator = SFTPOperator(
+            task_id="sftp_task",
+            sftp_hook=SFTPHook(ssh_conn_id="sftp_conn_id"),
+            dag=DAG("sftp_dag"),
+            start_date=timezone.utcnow(),
+            local_filepath=LOCAL_FILEPATH,
+            remote_filepath=REMOTE_FILEPATH,
+            operation=operation,
+        )
+        lineage = operator.get_openlineage_facets_on_start()
+
+        assert lineage.inputs == expected_inputs
+        assert lineage.outputs == expected_outputs
+
+    @pytest.mark.parametrize("operation, expected", TEST_GET_PUT_PARAMS)
+    @mock.patch("airflow.providers.ssh.hooks.ssh.SSHHook.get_conn", spec=paramiko.SSHClient)
+    @mock.patch("airflow.providers.ssh.hooks.ssh.SSHHook.get_connection", spec=Connection)
+    def test_extract_ssh_hook(self, get_connection, get_conn, operation, expected):
+        """Check the lineage datasets when the operator is given a ready-made ``SSHHook``."""
+        expected_inputs, expected_outputs = expected
+        get_connection.return_value = Connection(
+            conn_id="sftp_conn_id", conn_type="sftp", host="remotehost", port=22
+        )
+
+        operator = SFTPOperator(
+            task_id="sftp_task",
+            ssh_hook=SSHHook(ssh_conn_id="sftp_conn_id"),
+            dag=DAG("sftp_dag"),
+            start_date=timezone.utcnow(),
+            local_filepath=LOCAL_FILEPATH,
+            remote_filepath=REMOTE_FILEPATH,
+            operation=operation,
+        )
+        lineage = operator.get_openlineage_facets_on_start()
+
+        assert lineage.inputs == expected_inputs
+        assert lineage.outputs == expected_outputs

Reply via email to