This is an automated email from the ASF dual-hosted git repository.
taragolis pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new f24a03709e Add deferrable param in SFTPSensor (#37117)
f24a03709e is described below
commit f24a03709eecbda87ed794cee567806e51c3a21f
Author: Utkarsh Sharma <[email protected]>
AuthorDate: Sun Feb 4 05:27:54 2024 +0530
Add deferrable param in SFTPSensor (#37117)
* Add deferrable param in SFTPSensor
* Fix static checks
* Fixed duplicate task_id issue
* Add missing tests
* Fix failing test cases
* Fix provider.yaml static check
* Update airflow/providers/sftp/sensors/sftp.py
Co-authored-by: Pankaj Koti <[email protected]>
* Update airflow/providers/sftp/sensors/sftp.py
* Removed get_files_by_pattern
* Change the name for rst marker - howto_sensor_sftp_deferrable
* Handle missing file case for triggers
---------
Co-authored-by: Pankaj Koti <[email protected]>
---
airflow/providers/sftp/hooks/sftp.py | 170 ++++++++++-
airflow/providers/sftp/provider.yaml | 6 +
airflow/providers/sftp/sensors/sftp.py | 55 +++-
airflow/providers/sftp/triggers/__init__.py | 16 +
airflow/providers/sftp/triggers/sftp.py | 137 +++++++++
.../sensors/sftp_sensor.rst | 8 +
generated/provider_dependencies.json | 1 +
pyproject.toml | 1 +
tests/providers/sftp/hooks/test_sftp.py | 327 ++++++++++++++++++++-
tests/providers/sftp/triggers/__init__.py | 16 +
tests/providers/sftp/triggers/test_sftp.py | 203 +++++++++++++
tests/system/providers/sftp/example_sftp_sensor.py | 12 +-
12 files changed, 947 insertions(+), 5 deletions(-)
diff --git a/airflow/providers/sftp/hooks/sftp.py
b/airflow/providers/sftp/hooks/sftp.py
index 3b67d164a2..eb13ddf555 100644
--- a/airflow/providers/sftp/hooks/sftp.py
+++ b/airflow/providers/sftp/hooks/sftp.py
@@ -23,14 +23,20 @@ import os
import stat
import warnings
from fnmatch import fnmatch
-from typing import TYPE_CHECKING, Any, Callable
+from typing import TYPE_CHECKING, Any, Callable, Sequence
+
+import asyncssh
+from asgiref.sync import sync_to_async
from airflow.exceptions import AirflowException,
AirflowProviderDeprecationWarning
+from airflow.hooks.base import BaseHook
from airflow.providers.ssh.hooks.ssh import SSHHook
if TYPE_CHECKING:
import paramiko
+ from airflow.models.connection import Connection
+
class SFTPHook(SSHHook):
"""Interact with SFTP.
@@ -400,3 +406,165 @@ class SFTPHook(SSHHook):
matched_files.append(file)
return matched_files
+
+
+class SFTPHookAsync(BaseHook):
+ """
+ Interact with an SFTP server via asyncssh package.
+
+ :param sftp_conn_id: SFTP connection ID to be used for connecting to SFTP
server
+ :param host: hostname of the SFTP server
+ :param port: port of the SFTP server
+ :param username: username used when authenticating to the SFTP server
+ :param password: password used when authenticating to the SFTP server.
+ Can be left blank if using a key file
+ :param known_hosts: path to the known_hosts file on the local file system.
Defaults to ``~/.ssh/known_hosts``.
+ :param key_file: path to the client key file used for authentication to
SFTP server
+ :param passphrase: passphrase used with the key_file for authentication to
SFTP server
+ """
+
+ conn_name_attr = "ssh_conn_id"
+ default_conn_name = "sftp_default"
+ conn_type = "sftp"
+ hook_name = "SFTP"
+ default_known_hosts = "~/.ssh/known_hosts"
+
+ def __init__( # nosec: B107
+ self,
+ sftp_conn_id: str = default_conn_name,
+ host: str = "",
+ port: int = 22,
+ username: str = "",
+ password: str = "",
+ known_hosts: str = default_known_hosts,
+ key_file: str = "",
+ passphrase: str = "",
+ private_key: str = "",
+ ) -> None:
+ self.sftp_conn_id = sftp_conn_id
+ self.host = host
+ self.port = port
+ self.username = username
+ self.password = password
+ self.known_hosts: bytes | str = os.path.expanduser(known_hosts)
+ self.key_file = key_file
+ self.passphrase = passphrase
+ self.private_key = private_key
+
+ def _parse_extras(self, conn: Connection) -> None:
+ """Parse extra fields from the connection into instance fields."""
+ extra_options = conn.extra_dejson
+ if "key_file" in extra_options and self.key_file == "":
+ self.key_file = extra_options["key_file"]
+ if "known_hosts" in extra_options and self.known_hosts !=
self.default_known_hosts:
+ self.known_hosts = extra_options["known_hosts"]
+ if ("passphrase" or "private_key_passphrase") in extra_options:
+ self.passphrase = extra_options["passphrase"]
+ if "private_key" in extra_options:
+ self.private_key = extra_options["private_key"]
+
+ host_key = extra_options.get("host_key")
+ no_host_key_check = extra_options.get("no_host_key_check")
+
+ if no_host_key_check is not None:
+ no_host_key_check = str(no_host_key_check).lower() == "true"
+ if host_key is not None and no_host_key_check:
+ raise ValueError("Host key check was skipped, but `host_key`
value was given")
+ if no_host_key_check:
+ self.log.warning(
+ "No Host Key Verification. This won't protect against
Man-In-The-Middle attacks"
+ )
+ self.known_hosts = "none"
+
+ if host_key is not None:
+ self.known_hosts = f"{conn.host} {host_key}".encode()
+
+ async def _get_conn(self) -> asyncssh.SSHClientConnection:
+ """
+ Asynchronously connect to the SFTP server as an SSH client.
+
+ The following parameters are provided either in the extra json object
in
+ the SFTP connection definition
+
+ - key_file
+ - known_hosts
+ - passphrase
+ """
+ conn = await sync_to_async(self.get_connection)(self.sftp_conn_id)
+ if conn.extra is not None:
+ self._parse_extras(conn)
+
+ conn_config = {
+ "host": conn.host,
+ "port": conn.port,
+ "username": conn.login,
+ "password": conn.password,
+ }
+ if self.key_file:
+ conn_config.update(client_keys=self.key_file)
+ if self.known_hosts:
+ if self.known_hosts.lower() == "none":
+ conn_config.update(known_hosts=None)
+ else:
+ conn_config.update(known_hosts=self.known_hosts)
+ if self.private_key:
+ _private_key = asyncssh.import_private_key(self.private_key,
self.passphrase)
+ conn_config.update(client_keys=[_private_key])
+ if self.passphrase:
+ conn_config.update(passphrase=self.passphrase)
+ ssh_client_conn = await asyncssh.connect(**conn_config)
+ return ssh_client_conn
+
+ async def list_directory(self, path: str = "") -> list[str] | None:
+ """Returns a list of files on the SFTP server at the provided path."""
+ ssh_conn = await self._get_conn()
+ sftp_client = await ssh_conn.start_sftp_client()
+ try:
+ files = await sftp_client.listdir(path)
+ return sorted(files)
+ except asyncssh.SFTPNoSuchFile:
+ return None
+
+ async def read_directory(self, path: str = "") ->
Sequence[asyncssh.sftp.SFTPName] | None:
+ """Returns a list of files along with their attributes on the SFTP
server at the provided path."""
+ ssh_conn = await self._get_conn()
+ sftp_client = await ssh_conn.start_sftp_client()
+ try:
+ files = await sftp_client.readdir(path)
+ return files
+ except asyncssh.SFTPNoSuchFile:
+ return None
+
+ async def get_files_and_attrs_by_pattern(
+ self, path: str = "", fnmatch_pattern: str = ""
+ ) -> Sequence[asyncssh.sftp.SFTPName]:
+ """
+ Get the files along with their attributes matching the pattern (e.g.
``*.pdf``) at the provided path.
+
+ if one exists. Otherwise, raises an AirflowException to be handled
upstream for deferring
+ """
+ files_list = await self.read_directory(path)
+ if files_list is None:
+ raise FileNotFoundError(f"No files at path {path!r} found...")
+ matched_files = [file for file in files_list if
fnmatch(str(file.filename), fnmatch_pattern)]
+ return matched_files
+
+ async def get_mod_time(self, path: str) -> str:
+ """
+ Makes SFTP async connection.
+
+ Looks for last modified time in the specific file path and returns
last modification time for
+ the file path.
+
+ :param path: full path to the remote file
+ """
+ ssh_conn = await self._get_conn()
+ sftp_client = await ssh_conn.start_sftp_client()
+ try:
+ ftp_mdtm = await sftp_client.stat(path)
+ modified_time = ftp_mdtm.mtime
+ mod_time =
datetime.datetime.fromtimestamp(modified_time).strftime("%Y%m%d%H%M%S") #
type: ignore[arg-type]
+ self.log.info("Found File %s last modified: %s", str(path),
str(mod_time))
+ return mod_time
+ except asyncssh.SFTPNoSuchFile:
+ raise AirflowException("No files matching")
diff --git a/airflow/providers/sftp/provider.yaml
b/airflow/providers/sftp/provider.yaml
index a91bec8791..928cb3702c 100644
--- a/airflow/providers/sftp/provider.yaml
+++ b/airflow/providers/sftp/provider.yaml
@@ -61,6 +61,7 @@ dependencies:
- apache-airflow>=2.6.0
- apache-airflow-providers-ssh>=2.1.0
- paramiko>=2.8.0
+ - asyncssh>=2.12.0
integrations:
- integration-name: SSH File Transfer Protocol (SFTP)
@@ -92,3 +93,8 @@ connection-types:
task-decorators:
- class-name: airflow.providers.sftp.decorators.sensors.sftp.sftp_sensor_task
name: sftp_sensor
+
+triggers:
+ - integration-name: SSH File Transfer Protocol (SFTP)
+ python-modules:
+ - airflow.providers.sftp.triggers.sftp
diff --git a/airflow/providers/sftp/sensors/sftp.py
b/airflow/providers/sftp/sensors/sftp.py
index 4db4248895..4a02669682 100644
--- a/airflow/providers/sftp/sensors/sftp.py
+++ b/airflow/providers/sftp/sensors/sftp.py
@@ -19,13 +19,15 @@
from __future__ import annotations
import os
-from datetime import datetime
+from datetime import datetime, timedelta
from typing import TYPE_CHECKING, Any, Callable, Sequence
from paramiko.sftp import SFTP_NO_SUCH_FILE
-from airflow.exceptions import AirflowSkipException
+from airflow.configuration import conf
+from airflow.exceptions import AirflowException, AirflowSkipException
from airflow.providers.sftp.hooks.sftp import SFTPHook
+from airflow.providers.sftp.triggers.sftp import SFTPTrigger
from airflow.sensors.base import BaseSensorOperator, PokeReturnValue
from airflow.utils.timezone import convert_to_utc
@@ -41,6 +43,7 @@ class SFTPSensor(BaseSensorOperator):
:param file_pattern: The pattern that will be used to match the file
(fnmatch format)
:param sftp_conn_id: The connection to run the sensor against
:param newer_than: DateTime for which the file or file path should be
newer than, comparison is inclusive
+ :param deferrable: If waiting for completion, whether to defer the task
until done, default is ``False``.
"""
template_fields: Sequence[str] = (
@@ -58,6 +61,7 @@ class SFTPSensor(BaseSensorOperator):
python_callable: Callable | None = None,
op_args: list | None = None,
op_kwargs: dict[str, Any] | None = None,
+ deferrable: bool = conf.getboolean("operators", "default_deferrable",
fallback=False),
**kwargs,
) -> None:
super().__init__(**kwargs)
@@ -69,6 +73,7 @@ class SFTPSensor(BaseSensorOperator):
self.python_callable: Callable | None = python_callable
self.op_args = op_args or []
self.op_kwargs = op_kwargs or {}
+ self.deferrable = deferrable
def poke(self, context: Context) -> PokeReturnValue | bool:
self.hook = SFTPHook(self.sftp_conn_id)
@@ -119,3 +124,49 @@ class SFTPSensor(BaseSensorOperator):
xcom_value={"files_found": files_found,
"decorator_return_value": callable_return},
)
return True
+
+ def execute(self, context: Context) -> Any:
+ # Unlike other async sensors, we do not follow the pattern of calling
the synchronous self.poke()
+ # method before deferring here. This is due to the current limitations
we have in the synchronous
+ # SFTPHook methods. They are as follows:
+ #
+ # For file_pattern sensing, the hook implements list_directory()
method which returns a list of
+ # filenames only without the attributes like modified time which is
required for the file_pattern
+ # sensing when newer_than is supplied. This leads to intermittent
failures potentially due to
+ # throttling by the SFTP server as the hook makes multiple calls to
the server to get the
+ # attributes for each of the files in the directory.This limitation is
resolved here by instead
+ # calling the read_directory() method which returns a list of files
along with their attributes
+ # in a single call. We can add back the call to self.poke() before
deferring once the above
+ # limitations are resolved in the sync sensor.
+ if self.deferrable:
+ self.defer(
+ timeout=timedelta(seconds=self.timeout),
+ trigger=SFTPTrigger(
+ path=self.path,
+ file_pattern=self.file_pattern,
+ sftp_conn_id=self.sftp_conn_id,
+ poke_interval=self.poke_interval,
+ newer_than=self.newer_than,
+ ),
+ method_name="execute_complete",
+ )
+ else:
+ return super().execute(context=context)
+
+ def execute_complete(self, context: dict[str, Any], event: Any = None) ->
None:
+ """
+ Callback for when the trigger fires - returns immediately.
+
+ Relies on trigger to throw an exception, otherwise it assumes
execution was
+ successful.
+ """
+ if event is not None:
+ if "status" in event and event["status"] == "error":
+ raise AirflowException(event["message"])
+
+ if "status" in event and event["status"] == "success":
+ self.log.info("%s completed successfully.", self.task_id)
+ self.log.info(event["message"])
+ return None
+
+ raise AirflowException("No event received in trigger callback")
diff --git a/airflow/providers/sftp/triggers/__init__.py
b/airflow/providers/sftp/triggers/__init__.py
new file mode 100644
index 0000000000..13a83393a9
--- /dev/null
+++ b/airflow/providers/sftp/triggers/__init__.py
@@ -0,0 +1,16 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
diff --git a/airflow/providers/sftp/triggers/sftp.py
b/airflow/providers/sftp/triggers/sftp.py
new file mode 100644
index 0000000000..9220766262
--- /dev/null
+++ b/airflow/providers/sftp/triggers/sftp.py
@@ -0,0 +1,137 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from __future__ import annotations
+
+import asyncio
+from datetime import datetime
+from typing import Any, AsyncIterator
+
+from dateutil.parser import parse as parse_date
+
+from airflow.exceptions import AirflowException
+from airflow.providers.sftp.hooks.sftp import SFTPHookAsync
+from airflow.triggers.base import BaseTrigger, TriggerEvent
+from airflow.utils.timezone import convert_to_utc
+
+
+class SFTPTrigger(BaseTrigger):
+ """
+ SFTPTrigger that fires in below listed scenarios.
+
+ 1. The path on the SFTP server does not exist
+ 2. The pattern do not match
+
+ :param path: The path on the SFTP server to search for a file matching the
file pattern.
+ Authentication method used in the SFTP connection must have
access to this path
+ :param file_pattern: Pattern to be used for matching against the list of
files at the path above.
+ Uses the fnmatch module from std library to perform the
matching.
+
+ :param sftp_conn_id: SFTP connection ID to be used for connecting to SFTP
server
+ :param poke_interval: How often, in seconds, to check for the existence of
the file on the SFTP server
+ """
+
+ def __init__(
+ self,
+ path: str,
+ file_pattern: str = "",
+ sftp_conn_id: str = "sftp_default",
+ newer_than: datetime | str | None = None,
+ poke_interval: float = 5,
+ ) -> None:
+ super().__init__()
+ self.path = path
+ self.file_pattern = file_pattern
+ self.sftp_conn_id = sftp_conn_id
+ self.newer_than = newer_than
+ self.poke_interval = poke_interval
+
+ def serialize(self) -> tuple[str, dict[str, Any]]:
+ """Serializes SFTPTrigger arguments and classpath."""
+ return (
+ "airflow.providers.sftp.triggers.sftp.SFTPTrigger",
+ {
+ "path": self.path,
+ "file_pattern": self.file_pattern,
+ "sftp_conn_id": self.sftp_conn_id,
+ "newer_than": self.newer_than,
+ "poke_interval": self.poke_interval,
+ },
+ )
+
+ async def run(self) -> AsyncIterator[TriggerEvent]:
+ """
+ Makes a series of asynchronous calls to sftp servers via async sftp
hook. It yields a Trigger.
+
+ - If file matching file pattern exists at the specified path return it,
+ - If file pattern was not provided, it looks directly into the
specific path which was provided.
+ - If newer then datetime was provided it looks for the file path last
modified time and
+ check whether the last modified time is greater, if true return file
if false it polls again.
+ """
+ hook = self._get_async_hook()
+ exc = None
+ if isinstance(self.newer_than, str):
+ self.newer_than = parse_date(self.newer_than)
+ _newer_than = convert_to_utc(self.newer_than) if self.newer_than else
None
+ while True:
+ try:
+ if self.file_pattern:
+ files_returned_by_hook = await
hook.get_files_and_attrs_by_pattern(
+ path=self.path, fnmatch_pattern=self.file_pattern
+ )
+ files_sensed = []
+ for file in files_returned_by_hook:
+ if _newer_than:
+ if file.attrs.mtime is None:
+ continue
+ mod_time =
datetime.fromtimestamp(float(file.attrs.mtime)).strftime(
+ "%Y%m%d%H%M%S"
+ )
+ mod_time_utc =
convert_to_utc(datetime.strptime(mod_time, "%Y%m%d%H%M%S"))
+ if _newer_than <= mod_time_utc:
+ files_sensed.append(file.filename)
+ else:
+ files_sensed.append(file.filename)
+ if files_sensed:
+ yield TriggerEvent(
+ {
+ "status": "success",
+ "message": f"Sensed {len(files_sensed)} files:
{files_sensed}",
+ }
+ )
+ else:
+ mod_time = await hook.get_mod_time(self.path)
+ if _newer_than:
+ mod_time_utc =
convert_to_utc(datetime.strptime(mod_time, "%Y%m%d%H%M%S"))
+ if _newer_than <= mod_time_utc:
+ yield TriggerEvent({"status": "success",
"message": f"Sensed file: {self.path}"})
+ else:
+ yield TriggerEvent({"status": "success", "message":
f"Sensed file: {self.path}"})
+ await asyncio.sleep(self.poke_interval)
+ except AirflowException:
+ await asyncio.sleep(self.poke_interval)
+ except FileNotFoundError:
+ await asyncio.sleep(self.poke_interval)
+ except Exception as e:
+ exc = e
+ # Break loop to avoid infinite retries on terminal failure
+ break
+
+ yield TriggerEvent({"status": "error", "message": exc})
+
+ def _get_async_hook(self) -> SFTPHookAsync:
+ return SFTPHookAsync(sftp_conn_id=self.sftp_conn_id)
diff --git a/docs/apache-airflow-providers-sftp/sensors/sftp_sensor.rst
b/docs/apache-airflow-providers-sftp/sensors/sftp_sensor.rst
index 70294974f5..b99bc36672 100644
--- a/docs/apache-airflow-providers-sftp/sensors/sftp_sensor.rst
+++ b/docs/apache-airflow-providers-sftp/sensors/sftp_sensor.rst
@@ -44,3 +44,11 @@ Whatever returned by the python callable is put into XCom.
:dedent: 4
:start-after: [START howto_operator_sftp_sensor_decorator]
:end-before: [END howto_operator_sftp_sensor_decorator]
+
+Checks for the existence of a file on an SFTP server in deferrable mode (``deferrable=True``):
+
+.. exampleinclude:: /../../tests/system/providers/sftp/example_sftp_sensor.py
+ :language: python
+ :dedent: 4
+ :start-after: [START howto_sensor_sftp_deferrable]
+ :end-before: [END howto_sensor_sftp_deferrable]
diff --git a/generated/provider_dependencies.json
b/generated/provider_dependencies.json
index 63af5f0982..d69a2ab0c8 100644
--- a/generated/provider_dependencies.json
+++ b/generated/provider_dependencies.json
@@ -993,6 +993,7 @@
"deps": [
"apache-airflow-providers-ssh>=2.1.0",
"apache-airflow>=2.6.0",
+ "asyncssh>=2.12.0",
"paramiko>=2.8.0"
],
"devel-deps": [],
diff --git a/pyproject.toml b/pyproject.toml
index abc1ea0013..a8384e675f 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -891,6 +891,7 @@ sendgrid = [ # source:
airflow/providers/sendgrid/provider.yaml
]
sftp = [ # source: airflow/providers/sftp/provider.yaml
"apache-airflow[ssh]",
+ "asyncssh>=2.12.0",
"paramiko>=2.8.0",
]
singularity = [ # source: airflow/providers/singularity/provider.yaml
diff --git a/tests/providers/sftp/hooks/test_sftp.py
b/tests/providers/sftp/hooks/test_sftp.py
index 847e64a138..f914e1da56 100644
--- a/tests/providers/sftp/hooks/test_sftp.py
+++ b/tests/providers/sftp/hooks/test_sftp.py
@@ -17,18 +17,22 @@
# under the License.
from __future__ import annotations
+import datetime
import json
import os
import shutil
from io import StringIO
from unittest import mock
+from unittest.mock import AsyncMock, patch
import paramiko
import pytest
+from asyncssh import SFTPAttrs, SFTPNoSuchFile
+from asyncssh.sftp import SFTPName
from airflow.exceptions import AirflowException,
AirflowProviderDeprecationWarning
from airflow.models import Connection
-from airflow.providers.sftp.hooks.sftp import SFTPHook
+from airflow.providers.sftp.hooks.sftp import SFTPHook, SFTPHookAsync
from airflow.providers.ssh.hooks.ssh import SSHHook
from airflow.utils.session import provide_session
@@ -458,3 +462,324 @@ class TestSFTPHook:
def test_get_matched_files_with_different_pattern(self):
output = self.hook.get_files_by_pattern(self.temp_dir, "*_file_*.txt")
assert output == [ANOTHER_FILE_FOR_TESTS]
+
+
+class MockSFTPClient:
+ def __init__(self):
+ pass
+
+ async def listdir(self, path: str):
+ if path == "/path/does_not/exist/":
+ raise SFTPNoSuchFile("File does not exist")
+ else:
+ return ["..", ".", "file"]
+
+ async def readdir(self, path: str):
+ if path == "/path/does_not/exist/":
+ raise SFTPNoSuchFile("File does not exist")
+ else:
+ return [SFTPName(".."), SFTPName("."), SFTPName("file")]
+
+ async def stat(self, path: str):
+ if path == "/path/does_not/exist/":
+ raise SFTPNoSuchFile("No files matching")
+ else:
+ sftp_obj = SFTPAttrs()
+ sftp_obj.mtime = 1667302566
+ return sftp_obj
+
+
+class MockSSHClient:
+ def __init__(self):
+ pass
+
+ async def start_sftp_client(self):
+ return MockSFTPClient()
+
+
+class MockAirflowConnection:
+ def __init__(self, known_hosts="~/.ssh/known_hosts"):
+ self.host = "localhost"
+ self.port = 22
+ self.login = "username"
+ self.password = "password"
+ self.extra = """
+ {
+ "key_file": "~/keys/my_key",
+ "known_hosts": "unused",
+ "passphrase": "mypassphrase"
+ }
+ """
+ self.extra_dejson = {
+ "key_file": "~/keys/my_key",
+ "known_hosts": known_hosts,
+ "passphrase": "mypassphrase",
+ }
+
+
+class MockAirflowConnectionWithHostKey:
+ def __init__(self, host_key: str | None = None, no_host_key_check: bool =
False, port: int = 22):
+ self.host = "localhost"
+ self.port = port
+ self.login = "username"
+ self.password = "password"
+ self.extra = f'{{ "no_host_key_check": {no_host_key_check},
"host_key": {host_key} }}'
+ self.extra_dejson = { # type: ignore
+ "no_host_key_check": no_host_key_check,
+ "host_key": host_key,
+ "key_file": "~/keys/my_key",
+ "private_key": "~/keys/my_key",
+ }
+
+
+class MockAirflowConnectionWithPrivate:
+ def __init__(self):
+ self.host = "localhost"
+ self.port = 22
+ self.login = "username"
+ self.password = "password"
+ self.extra = """
+ {
+ "private_key": "~/keys/my_key",
+ "known_hosts": "unused",
+ "passphrase": "mypassphrase"
+ }
+ """
+ self.extra_dejson = {
+ "private_key": "~/keys/my_key",
+ "known_hosts": None,
+ "passphrase": "mypassphrase",
+ }
+
+
+class TestSFTPHookAsync:
+    """Unit tests for SFTPHookAsync with asyncssh and the Airflow connection lookup mocked out."""
+
+    # Stacked @patch decorators inject their mocks bottom-up: the patch closest to the function
+    # supplies the first mock argument after ``self``.
+    @patch("asyncssh.connect", new_callable=AsyncMock)
+    @patch("airflow.providers.sftp.hooks.sftp.SFTPHookAsync.get_connection")
+    @pytest.mark.asyncio
+    async def test_extra_dejson_fields_for_connection_building_known_hosts_none(
+        self, mock_get_connection, mock_connect, caplog
+    ):
+        """
+        Assert that connection details passed through the extra field in the Airflow connection
+        are properly passed when creating SFTP connection, and that a ``known_hosts`` extra of
+        "None" is passed on as ``known_hosts=None``.
+        """
+        # NOTE(review): the ``caplog`` fixture is requested but not used by this test.
+
+        mock_get_connection.return_value = MockAirflowConnection(known_hosts="None")
+
+        hook = SFTPHookAsync()
+        await hook._get_conn()
+
+        expected_connection_details = {
+            "host": "localhost",
+            "port": 22,
+            "username": "username",
+            "password": "password",
+            "client_keys": "~/keys/my_key",
+            "known_hosts": None,
+            "passphrase": "mypassphrase",
+        }
+
+        mock_connect.assert_called_with(**expected_connection_details)
+
+    @pytest.mark.parametrize(
+        "mock_port, mock_host_key",
+        [
+            (22, "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIFe8P8lk5HFfL/rMlcCMHQhw1cg+uZtlK5rXQk2C4pOY"),
+            (2222, "AAAAC3NzaC1lZDI1NTE5AAAAIFe8P8lk5HFfL/rMlcCMHQhw1cg+uZtlK5rXQk2C4pOY"),
+            (
+                2222,
+                "ecdsa-sha2-nistp256 AAAAE2VjZHNhLXNoYTItbmlzdHAyNTYAAAAIbmlzdHAyNTYAAABBBDDsXFe87LsBA1Hfi+mtw"
+                "/EoQkv8bXVtfOwdMP1ETpHVsYpm5QG/7tsLlKdE8h6EoV/OFw7XQtoibNZp/l5ABjE=",
+            ),
+        ],
+    )
+    @patch("asyncssh.connect", new_callable=AsyncMock)
+    @patch("asyncssh.import_private_key")
+    @patch("airflow.providers.sftp.hooks.sftp.SFTPHookAsync.get_connection")
+    @pytest.mark.asyncio
+    async def test_extra_dejson_fields_for_connection_with_host_key(
+        self,
+        mock_get_connection,
+        mock_import_private_key,
+        mock_connect,
+        mock_port,
+        mock_host_key,
+    ):
+        """
+        Assert that a ``host_key`` extra is turned into the ``known_hosts`` entry
+        ``b"<host> <host_key>"`` used for validating the given host key.
+        """
+        mock_get_connection.return_value = MockAirflowConnectionWithHostKey(
+            host_key=mock_host_key, no_host_key_check=False, port=mock_port
+        )
+
+        hook = SFTPHookAsync()
+        await hook._get_conn()
+
+        assert hook.known_hosts == f"localhost {mock_host_key}".encode()
+
+    @patch("asyncssh.connect", new_callable=AsyncMock)
+    @patch("airflow.providers.sftp.hooks.sftp.SFTPHookAsync.get_connection")
+    @pytest.mark.asyncio
+    async def test_extra_dejson_fields_for_connection_raises_valuerror(
+        self, mock_get_connection, mock_connect
+    ):
+        """
+        Assert that when both host_key and no_host_key_check are set, a ValueError is raised:
+        no_host_key_check must be unset when a host_key is given and needs to be validated.
+        """
+        host_key = "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIFe8P8lk5HFfL/rMlcCMHQhw1cg+uZtlK5rXQk2C4pOY"
+        mock_get_connection.return_value = MockAirflowConnectionWithHostKey(
+            host_key=host_key, no_host_key_check=True
+        )
+
+        hook = SFTPHookAsync()
+        with pytest.raises(ValueError) as exc:
+            await hook._get_conn()
+
+        assert str(exc.value) == "Host key check was skipped, but `host_key` value was given"
+
+    # NOTE(review): the paramiko.SSHClient.connect patch appears unused; SFTPHookAsync connects
+    # through asyncssh.connect, which is patched separately.
+    @patch("paramiko.SSHClient.connect")
+    @patch("asyncssh.import_private_key")
+    @patch("asyncssh.connect", new_callable=AsyncMock)
+    @patch("airflow.providers.sftp.hooks.sftp.SFTPHookAsync.get_connection")
+    @pytest.mark.asyncio
+    async def test_no_host_key_check_set_logs_warning(
+        self, mock_get_connection, mock_connect, mock_import_pkey, mock_ssh_connect, caplog
+    ):
+        """Assert that when no_host_key_check is set, a warning is logged for MITM attacks possibility."""
+        mock_get_connection.return_value = MockAirflowConnectionWithHostKey(no_host_key_check=True)
+
+        hook = SFTPHookAsync()
+        await hook._get_conn()
+        assert "No Host Key Verification. This won't protect against Man-In-The-Middle attacks" in caplog.text
+
+    @patch("asyncssh.connect", new_callable=AsyncMock)
+    @patch("airflow.providers.sftp.hooks.sftp.SFTPHookAsync.get_connection")
+    @pytest.mark.asyncio
+    async def test_extra_dejson_fields_for_connection_building(self, mock_get_connection, mock_connect):
+        """
+        Assert that connection details passed through the extra field in the Airflow connection
+        are properly passed when creating SFTP connection
+        """
+
+        mock_get_connection.return_value = MockAirflowConnection()
+
+        hook = SFTPHookAsync()
+        await hook._get_conn()
+
+        expected_connection_details = {
+            "host": "localhost",
+            "port": 22,
+            "username": "username",
+            "password": "password",
+            "client_keys": "~/keys/my_key",
+            "known_hosts": "~/.ssh/known_hosts",
+            "passphrase": "mypassphrase",
+        }
+
+        mock_connect.assert_called_with(**expected_connection_details)
+
+    @pytest.mark.asyncio
+    @patch("asyncssh.connect", new_callable=AsyncMock)
+    @patch("asyncssh.import_private_key")
+    @patch("airflow.providers.sftp.hooks.sftp.SFTPHookAsync.get_connection")
+    async def test_connection_private(self, mock_get_connection, mock_import_private_key, mock_connect):
+        """
+        Assert that connection details with private key passed through the extra field in the Airflow connection
+        are properly passed when creating SFTP connection; the imported key is sent as ``client_keys``
+        and, because the ``known_hosts`` extra is None, no ``known_hosts`` argument is passed.
+        """
+
+        mock_get_connection.return_value = MockAirflowConnectionWithPrivate()
+        mock_import_private_key.return_value = "test"
+
+        hook = SFTPHookAsync()
+        await hook._get_conn()
+
+        expected_connection_details = {
+            "host": "localhost",
+            "port": 22,
+            "username": "username",
+            "password": "password",
+            "client_keys": ["test"],
+            "passphrase": "mypassphrase",
+        }
+
+        mock_connect.assert_called_with(**expected_connection_details)
+
+    @patch("airflow.providers.sftp.hooks.sftp.SFTPHookAsync._get_conn")
+    @pytest.mark.asyncio
+    async def test_list_directory_path_does_not_exist(self, mock_hook_get_conn):
+        """
+        Assert that list_directory returns None when the path does not exist on the SFTP server
+        """
+        mock_hook_get_conn.return_value = MockSSHClient()
+        hook = SFTPHookAsync()
+
+        expected_files = None
+        files = await hook.list_directory(path="/path/does_not/exist/")
+        assert files == expected_files
+
+    @patch("airflow.providers.sftp.hooks.sftp.SFTPHookAsync._get_conn")
+    @pytest.mark.asyncio
+    async def test_read_directory_path_does_not_exist(self, mock_hook_get_conn):
+        """
+        Assert that read_directory returns None when the path does not exist on the SFTP server
+        """
+        mock_hook_get_conn.return_value = MockSSHClient()
+        hook = SFTPHookAsync()
+
+        expected_files = None
+        files = await hook.read_directory(path="/path/does_not/exist/")
+        assert files == expected_files
+
+    @patch("airflow.providers.sftp.hooks.sftp.SFTPHookAsync._get_conn")
+    @pytest.mark.asyncio
+    async def test_list_directory_path_has_files(self, mock_hook_get_conn):
+        """
+        Assert that file list is returned when path exists on SFTP server
+        """
+        mock_hook_get_conn.return_value = MockSSHClient()
+        hook = SFTPHookAsync()
+
+        expected_files = ["..", ".", "file"]
+        files = await hook.list_directory(path="/path/exists/")
+        assert sorted(files) == sorted(expected_files)
+
+    @patch("airflow.providers.sftp.hooks.sftp.SFTPHookAsync._get_conn")
+    @pytest.mark.asyncio
+    async def test_get_file_by_pattern_with_match(self, mock_hook_get_conn):
+        """
+        Assert that only the entry whose filename matches the pattern is returned from the SFTP server
+        """
+        mock_hook_get_conn.return_value = MockSSHClient()
+        hook = SFTPHookAsync()
+
+        files = await hook.get_files_and_attrs_by_pattern(path="/path/exists/", fnmatch_pattern="file")
+
+        assert len(files) == 1
+        assert files[0].filename == "file"
+
+    @pytest.mark.asyncio
+    @patch("airflow.providers.sftp.hooks.sftp.SFTPHookAsync._get_conn")
+    async def test_get_mod_time(self, mock_hook_get_conn):
+        """
+        Assert that get_mod_time returns the file's mtime formatted as %Y%m%d%H%M%S
+        """
+        mock_hook_get_conn.return_value.start_sftp_client.return_value = MockSFTPClient()
+        hook = SFTPHookAsync()
+        mod_time = await hook.get_mod_time("/path/exists/file")
+        expected_value = datetime.datetime.fromtimestamp(1667302566).strftime("%Y%m%d%H%M%S")
+        assert mod_time == expected_value
+
+    @pytest.mark.asyncio
+    @patch("airflow.providers.sftp.hooks.sftp.SFTPHookAsync._get_conn")
+    async def test_get_mod_time_exception(self, mock_hook_get_conn):
+        """
+        Assert that get_mod_time raises AirflowException when the file does not exist
+        """
+        mock_hook_get_conn.return_value.start_sftp_client.return_value = MockSFTPClient()
+        hook = SFTPHookAsync()
+        with pytest.raises(AirflowException) as exc:
+            await hook.get_mod_time("/path/does_not/exist/")
+        assert str(exc.value) == "No files matching"
diff --git a/tests/providers/sftp/triggers/__init__.py
b/tests/providers/sftp/triggers/__init__.py
new file mode 100644
index 0000000000..13a83393a9
--- /dev/null
+++ b/tests/providers/sftp/triggers/__init__.py
@@ -0,0 +1,16 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
diff --git a/tests/providers/sftp/triggers/test_sftp.py b/tests/providers/sftp/triggers/test_sftp.py
new file mode 100644
index 0000000000..f1bfa52533
--- /dev/null
+++ b/tests/providers/sftp/triggers/test_sftp.py
@@ -0,0 +1,203 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import asyncio
+import datetime
+import time
+from unittest import mock
+
+import pytest
+from asyncssh.sftp import SFTPAttrs, SFTPName
+
+from airflow.exceptions import AirflowException
+from airflow.providers.sftp.triggers.sftp import SFTPTrigger
+from airflow.triggers.base import TriggerEvent
+
+
+class TestSFTPTrigger:
+    def test_sftp_trigger_serialization(self):
+        """
+        Asserts that the SFTPTrigger correctly serializes its arguments and classpath.
+        """
+        trigger = SFTPTrigger(path="test/path/", sftp_conn_id="sftp_default", file_pattern="my_test_file")
+        classpath, kwargs = trigger.serialize()
+        assert classpath == "airflow.providers.sftp.triggers.sftp.SFTPTrigger"
+        assert kwargs == {
+            "path": "test/path/",
+            "file_pattern": "my_test_file",
+            "sftp_conn_id": "sftp_default",
+            "newer_than": None,
+            "poke_interval": 5.0,
+        }
+
+    @pytest.mark.asyncio
+    @pytest.mark.parametrize(
+        "newer_than",
+        ["19700101053001", None],
+    )
+    @mock.patch("airflow.providers.sftp.hooks.sftp.SFTPHookAsync.get_files_and_attrs_by_pattern")
+    async def test_sftp_trigger_run_trigger_success_state(self, mock_get_files_by_pattern, newer_than):
+        """
+        Assert that a TriggerEvent with a success status is yielded if a file
+        matching the pattern is returned by the hook.
+
+        With ``newer_than`` set, only ``some_file`` is expected in the event;
+        ``some_other_file`` is built without attrs (presumably no mtime), and the
+        expected message shows it is not reported in that case.
+        """
+        mock_get_files_by_pattern.return_value = [
+            SFTPName("some_file", attrs=SFTPAttrs(mtime=1684244898)),
+            SFTPName("some_other_file"),
+        ]
+
+        trigger = SFTPTrigger(
+            path="test/path/", sftp_conn_id="sftp_default", file_pattern="my_test_file", newer_than=newer_than
+        )
+
+        if newer_than:
+            expected_event = {"status": "success", "message": "Sensed 1 files: ['some_file']"}
+        else:
+            expected_event = {
+                "status": "success",
+                "message": "Sensed 2 files: ['some_file', 'some_other_file']",
+            }
+
+        generator = trigger.run()
+        actual_event = await generator.asend(None)
+
+        assert TriggerEvent(expected_event) == actual_event
+
+    @pytest.mark.asyncio
+    @mock.patch("airflow.providers.sftp.hooks.sftp.SFTPHookAsync.get_mod_time")
+    async def test_sftp_success_without_file_pattern(self, mock_mod_time):
+        """
+        Test SFTPTrigger run method by mocking the file path and without file pattern,
+        assert that a TriggerEvent with a success status is yielded.
+        """
+        mock_mod_time.return_value = "19700101053001"
+
+        trigger = SFTPTrigger(path="test/path/test.txt", sftp_conn_id="sftp_default", file_pattern="")
+
+        expected_event = {"status": "success", "message": "Sensed file: test/path/test.txt"}
+
+        generator = trigger.run()
+        actual_event = await generator.asend(None)
+
+        assert TriggerEvent(expected_event) == actual_event
+
+    @pytest.mark.asyncio
+    @mock.patch("airflow.providers.sftp.hooks.sftp.SFTPHookAsync.get_mod_time")
+    async def test_sftp_success_with_newer_then(self, mock_mod_time):
+        """
+        Test SFTPTrigger run method by mocking the file path, without file pattern, and with a
+        ``newer_than`` datetime of yesterday; assert that a TriggerEvent with a success status is yielded.
+        """
+        mock_mod_time.return_value = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
+        yesterday = datetime.datetime.now() - datetime.timedelta(days=1)
+        trigger = SFTPTrigger(
+            path="test/path/test.txt", sftp_conn_id="sftp_default", file_pattern="", newer_than=yesterday
+        )
+
+        expected_event = {"status": "success", "message": "Sensed file: test/path/test.txt"}
+
+        generator = trigger.run()
+        actual_event = await generator.asend(None)
+
+        assert TriggerEvent(expected_event) == actual_event
+
+    @pytest.mark.asyncio
+    @mock.patch("airflow.providers.sftp.hooks.sftp.SFTPHookAsync.get_files_and_attrs_by_pattern")
+    async def test_sftp_trigger_run_trigger_defer_state(
+        self,
+        mock_get_files_by_pattern,
+    ):
+        """
+        Assert that the task does not complete when the only matching file (mtime in 1970)
+        is older than ``newer_than`` (yesterday), indicating that the task needs to stay deferred.
+        """
+        mock_get_files_by_pattern.return_value = [SFTPName("my_test_file.txt", attrs=SFTPAttrs(mtime=49129))]
+        yesterday = datetime.datetime.now() - datetime.timedelta(days=1)
+        trigger = SFTPTrigger(
+            path="test/path/", sftp_conn_id="sftp_default", file_pattern="my_test_file", newer_than=yesterday
+        )
+
+        task = asyncio.create_task(trigger.run().__anext__())
+        await asyncio.sleep(0.5)
+
+        # TriggerEvent was not returned
+        assert task.done() is False
+        # Cancel the still-polling trigger task so it does not outlive the test;
+        # stopping the running event loop would not cancel it.
+        task.cancel()
+
+    @pytest.mark.asyncio
+    @mock.patch("airflow.providers.sftp.hooks.sftp.SFTPHookAsync.get_mod_time")
+    async def test_sftp_with_newer_then_date_greater(self, mock_mod_time):
+        """
+        Test the Trigger run method by passing a full file path, no file pattern, and a
+        ``newer_than`` datetime one hour in the future. The file's last modified time is mocked
+        as the current time, so the trigger must keep polling rather than yield a TriggerEvent.
+        """
+        today_time = time.time()
+        # datetime (not date) so the mocked mtime keeps the current time of day instead of midnight.
+        mock_mod_time.return_value = datetime.datetime.fromtimestamp(today_time).strftime("%Y%m%d%H%M%S")
+        newer_than_time = datetime.datetime.now() + datetime.timedelta(hours=1)
+        trigger = SFTPTrigger(
+            path="test/path/test.txt",
+            sftp_conn_id="sftp_default",
+            file_pattern="",
+            newer_than=newer_than_time,
+        )
+
+        task = asyncio.create_task(trigger.run().__anext__())
+        await asyncio.sleep(0.5)
+
+        # TriggerEvent was not returned
+        assert task.done() is False
+        task.cancel()
+
+    @pytest.mark.asyncio
+    @mock.patch("airflow.providers.sftp.hooks.sftp.SFTPHookAsync.get_files_and_attrs_by_pattern")
+    async def test_sftp_trigger_run_trigger_failure_state(self, mock_get_files_by_pattern):
+        """
+        Mock the hook to raise an exception other than AirflowException and assert that running
+        the trigger raises an error carrying that message instead of yielding a TriggerEvent.
+        """
+        mock_get_files_by_pattern.side_effect = Exception("An unexpected exception")
+
+        trigger = SFTPTrigger(path="test/path/", sftp_conn_id="sftp_default", file_pattern="my_test_file")
+
+        generator = trigger.run()
+        # The message is checked via ``match``; nothing placed after the raising call
+        # inside this block would ever run.
+        with pytest.raises(Exception, match="An unexpected exception"):
+            await generator.asend(None)
+
+    @pytest.mark.asyncio
+    @mock.patch("airflow.providers.sftp.hooks.sftp.SFTPHookAsync.get_files_and_attrs_by_pattern")
+    async def test_sftp_trigger_run_airflow_exception(self, mock_get_files_by_pattern):
+        """
+        Assert that the task does not complete if the hook raises an AirflowException,
+        indicating that the task needs to stay deferred.
+        """
+        mock_get_files_by_pattern.side_effect = AirflowException("No files at path /test/path/ found...")
+
+        trigger = SFTPTrigger(path="/test/path/", sftp_conn_id="sftp_default", file_pattern="my_test_file")
+
+        task = asyncio.create_task(trigger.run().__anext__())
+        await asyncio.sleep(0.5)
+
+        # TriggerEvent was not returned
+        assert task.done() is False
+        task.cancel()
diff --git a/tests/system/providers/sftp/example_sftp_sensor.py b/tests/system/providers/sftp/example_sftp_sensor.py
index 934556637b..df705b4b87 100644
--- a/tests/system/providers/sftp/example_sftp_sensor.py
+++ b/tests/system/providers/sftp/example_sftp_sensor.py
@@ -76,8 +76,18 @@ with DAG(
sftp_with_operator = SFTPSensor(task_id="sftp_operator",
path=FULL_FILE_PATH, poke_interval=10)
# [END howto_operator_sftp_sensor]
+ # [START howto_sensor_sftp_deferrable]
+ sftp_sensor_with_async = SFTPSensor(
+ task_id="sftp_operator_async", path=FULL_FILE_PATH, poke_interval=10,
deferrable=True
+ )
+ # [END howto_sensor_sftp_deferrable]
+
remove_file_task_start >> sleep_task >> create_decoy_file_task
- remove_file_task_start >> [sftp_with_operator, sftp_with_sensor] >>
remove_file_task_end
+ (
+ remove_file_task_start
+ >> [sftp_with_operator, sftp_sensor_with_async, sftp_with_sensor]
+ >> remove_file_task_end
+ )
from tests.system.utils.watcher import watcher