This is an automated email from the ASF dual-hosted git repository.
mobuchowski pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 6b880844ad openlineage, sftp: add OpenLineage support for sftp provider (#31360)
6b880844ad is described below
commit 6b880844ade6036954d5343f6a74a241b3865153
Author: Maciej Obuchowski <[email protected]>
AuthorDate: Tue Jul 25 11:17:04 2023 +0200
openlineage, sftp: add OpenLineage support for sftp provider (#31360)
Signed-off-by: Maciej Obuchowski <[email protected]>
---
airflow/providers/sftp/operators/sftp.py | 82 +++++++++++++++++++++
generated/provider_dependencies.json | 1 +
tests/providers/sftp/operators/test_sftp.py | 110 +++++++++++++++++++++++++++-
3 files changed, 192 insertions(+), 1 deletion(-)
diff --git a/airflow/providers/sftp/operators/sftp.py b/airflow/providers/sftp/operators/sftp.py
index 0920387faa..8da4b3f332 100644
--- a/airflow/providers/sftp/operators/sftp.py
+++ b/airflow/providers/sftp/operators/sftp.py
@@ -19,10 +19,13 @@
from __future__ import annotations
import os
+import socket
import warnings
from pathlib import Path
from typing import Any, Sequence
+import paramiko
+
from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
from airflow.models import BaseOperator
from airflow.providers.sftp.hooks.sftp import SFTPHook
@@ -188,3 +191,82 @@ class SFTPOperator(BaseOperator):
raise AirflowException(f"Error while transferring {file_msg}, error: {str(e)}")
return self.local_filepath
+
+ def get_openlineage_facets_on_start(self):
+ """
+ This returns OpenLineage datasets in format:
+ input: file://<local_host>/path
+ output: file://<remote_host>:<remote_port>/path.
+ """
+ from openlineage.client.run import Dataset
+
+ from airflow.providers.openlineage.extractors import OperatorLineage
+
+ scheme = "file"
+ local_host = socket.gethostname()
+ try:
+ local_host = socket.gethostbyname(local_host)
+ except Exception as e:
+ self.log.warning(
+ f"Failed to resolve local hostname. Using the hostname got by socket.gethostbyname() without resolution. {e}", # noqa: E501
+ exc_info=True,
+ )
+
+ hook = self.sftp_hook or self.ssh_hook or SFTPHook(ssh_conn_id=self.ssh_conn_id)
+
+ if self.remote_host is not None:
+ remote_host = self.remote_host
+ else:
+ remote_host = hook.get_connection(hook.ssh_conn_id).host
+
+ try:
+ remote_host = socket.gethostbyname(remote_host)
+ except OSError as e:
+ self.log.warning(
+ f"Failed to resolve remote hostname. Using the provided hostname without resolution. {e}", # noqa: E501
+ exc_info=True,
+ )
+
+ if hasattr(hook, "port"):
+ remote_port = hook.port
+ elif hasattr(hook, "ssh_hook"):
+ remote_port = hook.ssh_hook.port
+
+ # Since v4.1.0, SFTPOperator accepts both a string (single file) and a list of
+ # strings (multiple files) as local_filepath and remote_filepath, and internally
+ # keeps them as list in both cases. But before 4.1.0, only single string is
+ # allowed. So we consider both cases here for backward compatibility.
+ if isinstance(self.local_filepath, str):
+ local_filepath = [self.local_filepath]
+ else:
+ local_filepath = self.local_filepath
+ if isinstance(self.remote_filepath, str):
+ remote_filepath = [self.remote_filepath]
+ else:
+ remote_filepath = self.remote_filepath
+
+ local_datasets = [
+ Dataset(namespace=self._get_namespace(scheme, local_host, None, path), name=path)
+ for path in local_filepath
+ ]
+ remote_datasets = [
+ Dataset(namespace=self._get_namespace(scheme, remote_host, remote_port, path), name=path)
+ for path in remote_filepath
+ ]
+
+ if self.operation.lower() == SFTPOperation.GET:
+ inputs = remote_datasets
+ outputs = local_datasets
+ else:
+ inputs = local_datasets
+ outputs = remote_datasets
+
+ return OperatorLineage(
+ inputs=inputs,
+ outputs=outputs,
+ )
+
+ def _get_namespace(self, scheme, host, port, path) -> str:
+ port = port or paramiko.config.SSH_PORT
+ authority = f"{host}:{port}"
+ return f"{scheme}://{authority}"
diff --git a/generated/provider_dependencies.json b/generated/provider_dependencies.json
index 091a45cf36..f4eeb8a273 100644
--- a/generated/provider_dependencies.json
+++ b/generated/provider_dependencies.json
@@ -785,6 +785,7 @@
"apache-airflow>=2.4.0"
],
"cross-providers-deps": [
+ "openlineage",
"ssh"
],
"excluded-python-versions": []
diff --git a/tests/providers/sftp/operators/test_sftp.py b/tests/providers/sftp/operators/test_sftp.py
index 92137e2a5c..8adb93f7db 100644
--- a/tests/providers/sftp/operators/test_sftp.py
+++ b/tests/providers/sftp/operators/test_sftp.py
@@ -18,13 +18,16 @@
from __future__ import annotations
import os
+import socket
from base64 import b64encode
from unittest import mock
+import paramiko
import pytest
+from openlineage.client.run import Dataset
from airflow.exceptions import AirflowException
-from airflow.models import DAG
+from airflow.models import DAG, Connection
from airflow.providers.sftp.hooks.sftp import SFTPHook
from airflow.providers.sftp.operators.sftp import SFTPOperation, SFTPOperator
from airflow.providers.ssh.hooks.ssh import SSHHook
@@ -36,6 +39,18 @@ from tests.test_utils.config import conf_vars
DEFAULT_DATE = datetime(2017, 1, 1)
TEST_CONN_ID = "conn_id_for_testing"
+LOCAL_FILEPATH = "/path/local"
+REMOTE_FILEPATH = "/path/remote"
+LOCAL_DATASET = [
+ Dataset(namespace=f"file://{socket.gethostbyname(socket.gethostname())}:22", name=LOCAL_FILEPATH)
+]
+REMOTE_DATASET = [Dataset(namespace="file://remotehost:22", name=REMOTE_FILEPATH)]
+
+TEST_GET_PUT_PARAMS = [
+ (SFTPOperation.GET, (REMOTE_DATASET, LOCAL_DATASET)),
+ (SFTPOperation.PUT, (LOCAL_DATASET, REMOTE_DATASET)),
+]
+
class TestSFTPOperator:
def setup_method(self):
@@ -478,3 +493,96 @@ class TestSFTPOperator:
return_value = sftp_op.execute(None)
assert isinstance(return_value, str)
assert return_value == local_filepath
+
+ @pytest.mark.parametrize(
+ "operation, expected",
+ TEST_GET_PUT_PARAMS,
+ )
+ @mock.patch("airflow.providers.ssh.hooks.ssh.SSHHook.get_conn", spec=paramiko.SSHClient)
+ @mock.patch("airflow.providers.ssh.hooks.ssh.SSHHook.get_connection", spec=Connection)
+ def test_extract_ssh_conn_id(self, get_connection, get_conn, operation, expected):
+ get_connection.return_value = Connection(
+ conn_id="sftp_conn_id",
+ conn_type="sftp",
+ host="remotehost",
+ port=22,
+ )
+
+ dag_id = "sftp_dag"
+ task_id = "sftp_task"
+
+ task = SFTPOperator(
+ task_id=task_id,
+ ssh_conn_id="sftp_conn_id",
+ dag=DAG(dag_id),
+ start_date=timezone.utcnow(),
+ local_filepath="/path/local",
+ remote_filepath="/path/remote",
+ operation=operation,
+ )
+ lineage = task.get_openlineage_facets_on_start()
+
+ assert lineage.inputs == expected[0]
+ assert lineage.outputs == expected[1]
+
+ @pytest.mark.parametrize(
+ "operation, expected",
+ TEST_GET_PUT_PARAMS,
+ )
+ @mock.patch("airflow.providers.ssh.hooks.ssh.SSHHook.get_conn", spec=paramiko.SSHClient)
+ @mock.patch("airflow.providers.ssh.hooks.ssh.SSHHook.get_connection", spec=Connection)
+ def test_extract_sftp_hook(self, get_connection, get_conn, operation, expected):
+ get_connection.return_value = Connection(
+ conn_id="sftp_conn_id",
+ conn_type="sftp",
+ host="remotehost",
+ port=22,
+ )
+
+ dag_id = "sftp_dag"
+ task_id = "sftp_task"
+
+ task = SFTPOperator(
+ task_id=task_id,
+ sftp_hook=SFTPHook(ssh_conn_id="sftp_conn_id"),
+ dag=DAG(dag_id),
+ start_date=timezone.utcnow(),
+ local_filepath="/path/local",
+ remote_filepath="/path/remote",
+ operation=operation,
+ )
+ lineage = task.get_openlineage_facets_on_start()
+
+ assert lineage.inputs == expected[0]
+ assert lineage.outputs == expected[1]
+
+ @pytest.mark.parametrize(
+ "operation, expected",
+ TEST_GET_PUT_PARAMS,
+ )
+ @mock.patch("airflow.providers.ssh.hooks.ssh.SSHHook.get_conn", spec=paramiko.SSHClient)
+ @mock.patch("airflow.providers.ssh.hooks.ssh.SSHHook.get_connection", spec=Connection)
+ def test_extract_ssh_hook(self, get_connection, get_conn, operation, expected):
+ get_connection.return_value = Connection(
+ conn_id="sftp_conn_id",
+ conn_type="sftp",
+ host="remotehost",
+ port=22,
+ )
+
+ dag_id = "sftp_dag"
+ task_id = "sftp_task"
+
+ task = SFTPOperator(
+ task_id=task_id,
+ ssh_hook=SSHHook(ssh_conn_id="sftp_conn_id"),
+ dag=DAG(dag_id),
+ start_date=timezone.utcnow(),
+ local_filepath="/path/local",
+ remote_filepath="/path/remote",
+ operation=operation,
+ )
+ lineage = task.get_openlineage_facets_on_start()
+
+ assert lineage.inputs == expected[0]
+ assert lineage.outputs == expected[1]