This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new f3aacebe50 Convert sftp hook to use paramiko instead of pysftp (#24512)
f3aacebe50 is described below
commit f3aacebe502c4ea5dc2b7d29373539296fa037eb
Author: Paul Williams <[email protected]>
AuthorDate: Sun Jun 19 18:40:25 2022 -0400
Convert sftp hook to use paramiko instead of pysftp (#24512)
---
airflow/providers/sftp/hooks/sftp.py | 254 ++++++++++++---------
airflow/providers/sftp/operators/sftp.py | 133 ++++++-----
airflow/providers/ssh/hooks/ssh.py | 13 +-
docker_tests/test_prod_image.py | 2 +-
.../connections/sftp.rst | 27 ++-
.../connections/ssh.rst | 6 +-
docs/spelling_wordlist.txt | 2 +
setup.py | 2 -
tests/providers/sftp/hooks/test_sftp.py | 61 +++--
tests/providers/sftp/operators/test_sftp.py | 55 ++++-
tests/providers/ssh/hooks/test_ssh.py | 25 ++
11 files changed, 366 insertions(+), 214 deletions(-)
diff --git a/airflow/providers/sftp/hooks/sftp.py
b/airflow/providers/sftp/hooks/sftp.py
index d436d091b5..b6037c9a55 100644
--- a/airflow/providers/sftp/hooks/sftp.py
+++ b/airflow/providers/sftp/hooks/sftp.py
@@ -17,15 +17,15 @@
# under the License.
"""This module contains SFTP hook."""
import datetime
+import os
import stat
import warnings
from fnmatch import fnmatch
-from typing import Any, Dict, List, Optional, Tuple
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
-import pysftp
-import tenacity
-from paramiko import SSHException
+import paramiko
+from airflow.exceptions import AirflowException
from airflow.providers.ssh.hooks.ssh import SSHHook
@@ -49,11 +49,9 @@ class SFTPHook(SSHHook):
Errors that may occur throughout but should be handled downstream.
For consistency reasons with SSHHook, the preferred parameter is
"ssh_conn_id".
- Please note that it is still possible to use the parameter "ftp_conn_id"
- to initialize the hook, but it will be removed in future Airflow versions.
:param ssh_conn_id: The :ref:`sftp connection id<howto/connection:sftp>`
- :param ftp_conn_id (Outdated): The :ref:`sftp connection
id<howto/connection:sftp>`
+ :param ssh_hook: Optional SSH hook (included to support passing of an SSH
hook to the SFTP operator)
"""
conn_name_attr = 'ssh_conn_id'
@@ -73,9 +71,29 @@ class SFTPHook(SSHHook):
def __init__(
self,
ssh_conn_id: Optional[str] = 'sftp_default',
+ ssh_hook: Optional[SSHHook] = None,
*args,
**kwargs,
) -> None:
+ self.conn: Optional[paramiko.SFTPClient] = None
+
+ # TODO: remove support for ssh_hook when it is removed from
SFTPOperator
+ self.ssh_hook = ssh_hook
+
+ if self.ssh_hook is not None:
+ warnings.warn(
+ 'Parameter `ssh_hook` is deprecated and will be removed in a
future version.',
+ DeprecationWarning,
+ stacklevel=2,
+ )
+ if not isinstance(self.ssh_hook, SSHHook):
+ raise AirflowException(
+ f'ssh_hook must be an instance of SSHHook, but got
{type(self.ssh_hook)}'
+ )
+ self.log.info('ssh_hook is provided. It will be used to generate
SFTP connection.')
+ self.ssh_conn_id = self.ssh_hook.ssh_conn_id
+ return
+
ftp_conn_id = kwargs.pop('ftp_conn_id', None)
if ftp_conn_id:
warnings.warn(
@@ -84,100 +102,33 @@ class SFTPHook(SSHHook):
stacklevel=2,
)
ssh_conn_id = ftp_conn_id
+
kwargs['ssh_conn_id'] = ssh_conn_id
+ self.ssh_conn_id = ssh_conn_id
+
super().__init__(*args, **kwargs)
- self.conn = None
- self.private_key_pass = None
- self.ciphers = None
-
- # Fail for unverified hosts, unless this is explicitly allowed
- self.no_host_key_check = False
-
- if self.ssh_conn_id is not None:
- conn = self.get_connection(self.ssh_conn_id)
- if conn.extra is not None:
- extra_options = conn.extra_dejson
-
- # For backward compatibility
- # TODO: remove in the next major provider release.
-
- if 'private_key_pass' in extra_options:
- warnings.warn(
- 'Extra option `private_key_pass` is deprecated.'
- 'Please use `private_key_passphrase` instead.'
- '`private_key_passphrase` will precede if both options
are specified.'
- 'The old option `private_key_pass` will be removed in
a future release.',
- DeprecationWarning,
- stacklevel=2,
- )
- self.private_key_pass = extra_options.get(
- 'private_key_passphrase',
extra_options.get('private_key_pass')
- )
+ def get_conn(self) -> paramiko.SFTPClient: # type: ignore[override]
+ """
+ Opens an SFTP connection to the remote host
- if 'ignore_hostkey_verification' in extra_options:
- warnings.warn(
- 'Extra option `ignore_hostkey_verification` is
deprecated.'
- 'Please use `no_host_key_check` instead.'
- 'This option will be removed in a future release.',
- DeprecationWarning,
- stacklevel=2,
- )
- self.no_host_key_check = (
-
str(extra_options['ignore_hostkey_verification']).lower() == 'true'
- )
-
- if 'no_host_key_check' in extra_options:
- self.no_host_key_check =
str(extra_options['no_host_key_check']).lower() == 'true'
-
- if 'ciphers' in extra_options:
- self.ciphers = extra_options['ciphers']
-
- @tenacity.retry(
- stop=tenacity.stop_after_delay(10),
- wait=tenacity.wait_exponential(multiplier=1, max=10),
- retry=tenacity.retry_if_exception_type(SSHException),
- reraise=True,
- )
- def get_conn(self) -> pysftp.Connection:
- """Returns an SFTP connection object"""
+ :rtype: paramiko.SFTPClient
+ """
if self.conn is None:
- cnopts = pysftp.CnOpts()
- if self.no_host_key_check:
- cnopts.hostkeys = None
+ # TODO: remove support for ssh_hook when it is removed from
SFTPOperator
+ if self.ssh_hook is not None:
+ self.conn = self.ssh_hook.get_conn().open_sftp()
else:
- if self.host_key is not None:
- cnopts.hostkeys.add(self.remote_host,
self.host_key.get_name(), self.host_key)
- else:
- pass # will fallback to system host keys if none
explicitly specified in conn extra
-
- cnopts.compression = self.compress
- cnopts.ciphers = self.ciphers
- conn_params = {
- 'host': self.remote_host,
- 'port': self.port,
- 'username': self.username,
- 'cnopts': cnopts,
- }
- if self.password and self.password.strip():
- conn_params['password'] = self.password
- if self.pkey:
- conn_params['private_key'] = self.pkey
- elif self.key_file:
- conn_params['private_key'] = self.key_file
- if self.private_key_pass:
- conn_params['private_key_pass'] = self.private_key_pass
-
- self.conn = pysftp.Connection(**conn_params)
+ self.conn = super().get_conn().open_sftp()
return self.conn
def close_conn(self) -> None:
- """Closes the connection"""
+ """Closes the SFTP connection"""
if self.conn is not None:
self.conn.close()
self.conn = None
- def describe_directory(self, path: str) -> Dict[str, Dict[str, str]]:
+ def describe_directory(self, path: str) -> Dict[str, Dict[str, Union[str,
int, None]]]:
"""
Returns a dictionary of {filename: {attributes}} for all files
on the remote system (where the MLSD command is supported).
@@ -185,13 +136,13 @@ class SFTPHook(SSHHook):
:param path: full path to the remote directory
"""
conn = self.get_conn()
- flist = conn.listdir_attr(path)
+ flist = sorted(conn.listdir_attr(path), key=lambda x: x.filename)
files = {}
for f in flist:
- modify =
datetime.datetime.fromtimestamp(f.st_mtime).strftime('%Y%m%d%H%M%S')
+ modify =
datetime.datetime.fromtimestamp(f.st_mtime).strftime('%Y%m%d%H%M%S') # type:
ignore
files[f.filename] = {
'size': f.st_size,
- 'type': 'dir' if stat.S_ISDIR(f.st_mode) else 'file',
+ 'type': 'dir' if stat.S_ISDIR(f.st_mode) else 'file', # type:
ignore
'modify': modify,
}
return files
@@ -203,9 +154,45 @@ class SFTPHook(SSHHook):
:param path: full path to the remote directory to list
"""
conn = self.get_conn()
- files = conn.listdir(path)
+ files = sorted(conn.listdir(path))
return files
+ def mkdir(self, path: str, mode: int = 777) -> None:
+ """
+ Creates a directory on the remote system.
+
+ :param path: full path to the remote directory to create
+ :param mode: permissions to set the directory with
+ """
+ conn = self.get_conn()
+ conn.mkdir(path, mode=int(str(mode), 8))
+
+ def isdir(self, path: str) -> bool:
+ """
+ Checks if the path provided is a directory or not.
+
+ :param path: full path to the remote directory to check
+ """
+ conn = self.get_conn()
+ try:
+ result = stat.S_ISDIR(conn.stat(path).st_mode) # type: ignore
+ except OSError:
+ result = False
+ return result
+
+ def isfile(self, path: str) -> bool:
+ """
+ Checks if the path provided is a file or not.
+
+ :param path: full path to the remote file to check
+ """
+ conn = self.get_conn()
+ try:
+ result = stat.S_ISREG(conn.stat(path).st_mode) # type: ignore
+ except OSError:
+ result = False
+ return result
+
def create_directory(self, path: str, mode: int = 777) -> None:
"""
Creates a directory on the remote system.
@@ -214,7 +201,18 @@ class SFTPHook(SSHHook):
:param mode: int representation of octal mode for directory
"""
conn = self.get_conn()
- conn.makedirs(path, mode)
+ if self.isdir(path):
+ self.log.info(f"{path} already exists")
+ return
+ elif self.isfile(path):
+ raise AirflowException(f"{path} already exists and is a file")
+ else:
+ dirname, basename = os.path.split(path)
+ if dirname and not self.isdir(dirname):
+ self.create_directory(dirname, mode)
+ if basename:
+ self.log.info(f"Creating {path}")
+ conn.mkdir(path, mode=mode)
def delete_directory(self, path: str) -> None:
"""
@@ -237,7 +235,7 @@ class SFTPHook(SSHHook):
conn = self.get_conn()
conn.get(remote_full_path, local_full_path)
- def store_file(self, remote_full_path: str, local_full_path: str) -> None:
+ def store_file(self, remote_full_path: str, local_full_path: str, confirm:
bool = True) -> None:
"""
Transfers a local file to the remote location.
If local_full_path_or_buffer is a string path, the file will be read
@@ -247,7 +245,7 @@ class SFTPHook(SSHHook):
:param local_full_path: full path to the local file
"""
conn = self.get_conn()
- conn.put(local_full_path, remote_full_path)
+ conn.put(local_full_path, remote_full_path, confirm=confirm)
def delete_file(self, path: str) -> None:
"""
@@ -266,7 +264,7 @@ class SFTPHook(SSHHook):
"""
conn = self.get_conn()
ftp_mdtm = conn.stat(path).st_mtime
- return
datetime.datetime.fromtimestamp(ftp_mdtm).strftime('%Y%m%d%H%M%S')
+ return
datetime.datetime.fromtimestamp(ftp_mdtm).strftime('%Y%m%d%H%M%S') # type:
ignore
def path_exists(self, path: str) -> bool:
"""
@@ -275,7 +273,11 @@ class SFTPHook(SSHHook):
:param path: full path to the remote file or directory
"""
conn = self.get_conn()
- return conn.exists(path)
+ try:
+ conn.stat(path)
+ except OSError:
+ return False
+ return True
@staticmethod
def _is_path_match(path: str, prefix: Optional[str] = None, delimiter:
Optional[str] = None) -> bool:
@@ -293,6 +295,51 @@ class SFTPHook(SSHHook):
return False
return True
+ def walktree(
+ self,
+ path: str,
+ fcallback: Callable[[str], Optional[Any]],
+ dcallback: Callable[[str], Optional[Any]],
+ ucallback: Callable[[str], Optional[Any]],
+ recurse: bool = True,
+ ) -> None:
+ """
+ Recursively descend, depth first, the directory tree rooted at
+ path, calling discrete callback functions for each regular file,
+ directory and unknown file type.
+
+ :param str path:
+ root of remote directory to descend, use '.' to start at
+ :attr:`.pwd`
+ :param callable fcallback:
+ callback function to invoke for a regular file.
+ (form: ``func(str)``)
+ :param callable dcallback:
+ callback function to invoke for a directory. (form: ``func(str)``)
+ :param callable ucallback:
+ callback function to invoke for an unknown file type.
+ (form: ``func(str)``)
+ :param bool recurse: *Default: True* - should it recurse
+
+ :returns: None
+ """
+ conn = self.get_conn()
+ for entry in self.list_directory(path):
+ pathname = os.path.join(path, entry)
+ mode = conn.stat(pathname).st_mode
+ if stat.S_ISDIR(mode): # type: ignore
+ # It's a directory, call the dcallback function
+ dcallback(pathname)
+ if recurse:
+ # now, recurse into it
+ self.walktree(pathname, fcallback, dcallback, ucallback)
+ elif stat.S_ISREG(mode): # type: ignore
+ # It's a file, call the fcallback function
+ fcallback(pathname)
+ else:
+ # Unknown file type
+ ucallback(pathname)
+
def get_tree_map(
self, path: str, prefix: Optional[str] = None, delimiter:
Optional[str] = None
) -> Tuple[List[str], List[str], List[str]]:
@@ -306,14 +353,15 @@ class SFTPHook(SSHHook):
:return: tuple with list of files, dirs and unknown items
:rtype: Tuple[List[str], List[str], List[str]]
"""
- conn = self.get_conn()
- files, dirs, unknowns = [], [], [] # type: List[str], List[str],
List[str]
+ files: List[str] = []
+ dirs: List[str] = []
+ unknowns: List[str] = []
- def append_matching_path_callback(list_):
+ def append_matching_path_callback(list_: List[str]) -> Callable:
return lambda item: list_.append(item) if
self._is_path_match(item, prefix, delimiter) else None
- conn.walktree(
- remotepath=path,
+ self.walktree(
+ path=path,
fcallback=append_matching_path_callback(files),
dcallback=append_matching_path_callback(dirs),
ucallback=append_matching_path_callback(unknowns),
@@ -326,7 +374,7 @@ class SFTPHook(SSHHook):
"""Test the SFTP connection by calling path with directory"""
try:
conn = self.get_conn()
- conn.pwd
+ conn.normalize('.')
return True, "Connection successfully tested"
except Exception as e:
return False, str(e)
diff --git a/airflow/providers/sftp/operators/sftp.py
b/airflow/providers/sftp/operators/sftp.py
index c78c7f4c04..4806e982e7 100644
--- a/airflow/providers/sftp/operators/sftp.py
+++ b/airflow/providers/sftp/operators/sftp.py
@@ -17,16 +17,18 @@
# under the License.
"""This module contains SFTP operator."""
import os
+import warnings
from pathlib import Path
-from typing import Any, Sequence
+from typing import Any, Optional, Sequence
from airflow.exceptions import AirflowException
from airflow.models import BaseOperator
+from airflow.providers.sftp.hooks.sftp import SFTPHook
from airflow.providers.ssh.hooks.ssh import SSHHook
class SFTPOperation:
- """Operation that can be used with SFTP/"""
+ """Operation that can be used with SFTP"""
PUT = 'put'
GET = 'get'
@@ -35,17 +37,19 @@ class SFTPOperation:
class SFTPOperator(BaseOperator):
"""
SFTPOperator for transferring files from remote host to local or vice
versa.
- This operator uses ssh_hook to open sftp transport channel that serve as
basis
+ This operator uses sftp_hook to open sftp transport channel that serve as
basis
for file transfer.
- :param ssh_hook: predefined ssh_hook to use for remote execution.
- Either `ssh_hook` or `ssh_conn_id` needs to be provided.
:param ssh_conn_id: :ref:`ssh connection id<howto/connection:ssh>`
from airflow Connections. `ssh_conn_id` will be ignored if `ssh_hook`
- is provided.
+ or `sftp_hook` is provided.
+ :param sftp_hook: predefined SFTPHook to use
+ Either `sftp_hook` or `ssh_conn_id` needs to be provided.
+ :param ssh_hook: Deprecated - predefined SSHHook to use for remote
execution
+ Use `sftp_hook` instead.
:param remote_host: remote host to connect (templated)
Nullable. If provided, it will replace the `remote_host` which was
- defined in `ssh_hook` or predefined in the connection of `ssh_conn_id`.
+ defined in `sftp_hook`/`ssh_hook` or predefined in the connection of
`ssh_conn_id`.
:param local_filepath: local file path to get or put. (templated)
:param remote_filepath: remote file path to get or put. (templated)
:param operation: specify operation 'get' or 'put', defaults to put
@@ -75,18 +79,20 @@ class SFTPOperator(BaseOperator):
def __init__(
self,
*,
- ssh_hook=None,
- ssh_conn_id=None,
- remote_host=None,
- local_filepath=None,
- remote_filepath=None,
- operation=SFTPOperation.PUT,
- confirm=True,
- create_intermediate_dirs=False,
+ ssh_hook: Optional[SSHHook] = None,
+ sftp_hook: Optional[SFTPHook] = None,
+ ssh_conn_id: Optional[str] = None,
+ remote_host: Optional[str] = None,
+ local_filepath: str,
+ remote_filepath: str,
+ operation: str = SFTPOperation.PUT,
+ confirm: bool = True,
+ create_intermediate_dirs: bool = False,
**kwargs,
) -> None:
super().__init__(**kwargs)
self.ssh_hook = ssh_hook
+ self.sftp_hook = sftp_hook
self.ssh_conn_id = ssh_conn_id
self.remote_host = remote_host
self.local_filepath = local_filepath
@@ -94,79 +100,72 @@ class SFTPOperator(BaseOperator):
self.operation = operation
self.confirm = confirm
self.create_intermediate_dirs = create_intermediate_dirs
+
if not (self.operation.lower() == SFTPOperation.GET or
self.operation.lower() == SFTPOperation.PUT):
raise TypeError(
f"Unsupported operation value {self.operation}, "
f"expected {SFTPOperation.GET} or {SFTPOperation.PUT}."
)
- def execute(self, context: Any) -> str:
+ # TODO: remove support for ssh_hook in next major provider version in
hook and operator
+ if self.ssh_hook is not None and self.sftp_hook is not None:
+ raise AirflowException(
+ 'Both `ssh_hook` and `sftp_hook` are defined. Please use only
one of them.'
+ )
+
+ if self.ssh_hook is not None:
+ if not isinstance(self.ssh_hook, SSHHook):
+ self.log.info('ssh_hook is invalid. Trying ssh_conn_id to
create SFTPHook.')
+ self.sftp_hook = SFTPHook(ssh_conn_id=self.ssh_conn_id)
+ if self.sftp_hook is None:
+ warnings.warn(
+ 'Parameter `ssh_hook` is deprecated'
+ 'Please use `sftp_hook` instead.'
+ 'The old parameter `ssh_hook` will be removed in a future
version.',
+ DeprecationWarning,
+ stacklevel=2,
+ )
+ self.sftp_hook = SFTPHook(ssh_hook=self.ssh_hook)
+
+ def execute(self, context: Any) -> Optional[str]:
file_msg = None
try:
if self.ssh_conn_id:
- if self.ssh_hook and isinstance(self.ssh_hook, SSHHook):
- self.log.info("ssh_conn_id is ignored when ssh_hook is
provided.")
+ if self.sftp_hook and isinstance(self.sftp_hook, SFTPHook):
+ self.log.info("ssh_conn_id is ignored when
sftp_hook/ssh_hook is provided.")
else:
self.log.info(
- "ssh_hook is not provided or invalid. Trying
ssh_conn_id to create SSHHook."
+ 'sftp_hook/ssh_hook not provided or invalid. Trying
ssh_conn_id to create SFTPHook.'
)
- self.ssh_hook = SSHHook(ssh_conn_id=self.ssh_conn_id)
+ self.sftp_hook = SFTPHook(ssh_conn_id=self.ssh_conn_id)
- if not self.ssh_hook:
- raise AirflowException("Cannot operate without ssh_hook or
ssh_conn_id.")
+ if not self.sftp_hook:
+ raise AirflowException("Cannot operate without sftp_hook or
ssh_conn_id.")
if self.remote_host is not None:
self.log.info(
"remote_host is provided explicitly. "
"It will replace the remote_host which was defined "
- "in ssh_hook or predefined in connection of ssh_conn_id."
+ "in sftp_hook or predefined in connection of ssh_conn_id."
)
- self.ssh_hook.remote_host = self.remote_host
-
- with self.ssh_hook.get_conn() as ssh_client:
- sftp_client = ssh_client.open_sftp()
- if self.operation.lower() == SFTPOperation.GET:
- local_folder = os.path.dirname(self.local_filepath)
- if self.create_intermediate_dirs:
- Path(local_folder).mkdir(parents=True, exist_ok=True)
- file_msg = f"from {self.remote_filepath} to
{self.local_filepath}"
- self.log.info("Starting to transfer %s", file_msg)
- sftp_client.get(self.remote_filepath, self.local_filepath)
- else:
- remote_folder = os.path.dirname(self.remote_filepath)
- if self.create_intermediate_dirs:
- _make_intermediate_dirs(
- sftp_client=sftp_client,
- remote_directory=remote_folder,
- )
- file_msg = f"from {self.local_filepath} to
{self.remote_filepath}"
- self.log.info("Starting to transfer file %s", file_msg)
- sftp_client.put(self.local_filepath, self.remote_filepath,
confirm=self.confirm)
+ self.sftp_hook.remote_host = self.remote_host
+
+ if self.operation.lower() == SFTPOperation.GET:
+ local_folder = os.path.dirname(self.local_filepath)
+ if self.create_intermediate_dirs:
+ Path(local_folder).mkdir(parents=True, exist_ok=True)
+ file_msg = f"from {self.remote_filepath} to
{self.local_filepath}"
+ self.log.info("Starting to transfer %s", file_msg)
+ self.sftp_hook.retrieve_file(self.remote_filepath,
self.local_filepath)
+ else:
+ remote_folder = os.path.dirname(self.remote_filepath)
+ if self.create_intermediate_dirs:
+ self.sftp_hook.create_directory(remote_folder)
+ file_msg = f"from {self.local_filepath} to
{self.remote_filepath}"
+ self.log.info("Starting to transfer file %s", file_msg)
+ self.sftp_hook.store_file(self.remote_filepath,
self.local_filepath, confirm=self.confirm)
except Exception as e:
raise AirflowException(f"Error while transferring {file_msg},
error: {str(e)}")
return self.local_filepath
-
-
-def _make_intermediate_dirs(sftp_client, remote_directory) -> None:
- """
- Create all the intermediate directories in a remote host
-
- :param sftp_client: A Paramiko SFTP client.
- :param remote_directory: Absolute Path of the directory containing the file
- :return:
- """
- if remote_directory == '/':
- sftp_client.chdir('/')
- return
- if remote_directory == '':
- return
- try:
- sftp_client.chdir(remote_directory)
- except OSError:
- dirname, basename = os.path.split(remote_directory.rstrip('/'))
- _make_intermediate_dirs(sftp_client, dirname)
- sftp_client.mkdir(basename)
- sftp_client.chdir(basename)
- return
diff --git a/airflow/providers/ssh/hooks/ssh.py
b/airflow/providers/ssh/hooks/ssh.py
index 53c74f44c5..51240820ec 100644
--- a/airflow/providers/ssh/hooks/ssh.py
+++ b/airflow/providers/ssh/hooks/ssh.py
@@ -22,7 +22,7 @@ import warnings
from base64 import decodebytes
from io import StringIO
from select import select
-from typing import Any, Dict, Optional, Sequence, Tuple, Type, Union
+from typing import Any, Dict, List, Optional, Sequence, Tuple, Type, Union
import paramiko
from paramiko.config import SSH_PORT
@@ -71,6 +71,7 @@ class SSHHook(BaseHook):
:param disabled_algorithms: dictionary mapping algorithm type to an
iterable of algorithm identifiers, which will be disabled for the
lifetime of the transport
+ :param ciphers: list of ciphers to use in order of preference
"""
# List of classes to try loading private keys as, ordered (roughly) by
most common to least common
@@ -116,6 +117,7 @@ class SSHHook(BaseHook):
keepalive_interval: int = 30,
banner_timeout: float = 30.0,
disabled_algorithms: Optional[dict] = None,
+ ciphers: Optional[List[str]] = None,
) -> None:
super().__init__()
self.ssh_conn_id = ssh_conn_id
@@ -130,6 +132,7 @@ class SSHHook(BaseHook):
self.keepalive_interval = keepalive_interval
self.banner_timeout = banner_timeout
self.disabled_algorithms = disabled_algorithms
+ self.ciphers = ciphers
self.host_proxy_cmd = None
# Default values, overridable from Connection
@@ -205,6 +208,9 @@ class SSHHook(BaseHook):
if "disabled_algorithms" in extra_options:
self.disabled_algorithms =
extra_options.get("disabled_algorithms")
+ if "ciphers" in extra_options:
+ self.ciphers = extra_options.get("ciphers")
+
if host_key is not None:
if host_key.startswith("ssh-"):
key_type, host_key = host_key.split(None)[:2]
@@ -342,6 +348,11 @@ class SSHHook(BaseHook):
# type "Optional[Transport]" and item "None" has no attribute
"set_keepalive".
client.get_transport().set_keepalive(self.keepalive_interval) #
type: ignore[union-attr]
+ if self.ciphers:
+ # MyPy check ignored because "paramiko" isn't well-typed. The
`client.get_transport()` returns
+ # type "Optional[Transport]" and item "None" has no method
`get_security_options`".
+ client.get_transport().get_security_options().ciphers =
self.ciphers # type: ignore[union-attr]
+
self.client = client
return client
diff --git a/docker_tests/test_prod_image.py b/docker_tests/test_prod_image.py
index aae374c4d0..e2ae74593a 100644
--- a/docker_tests/test_prod_image.py
+++ b/docker_tests/test_prod_image.py
@@ -162,7 +162,7 @@ class TestPythonPackages:
"pyodbc": ["pyodbc"],
"redis": ["redis"],
"sendgrid": ["sendgrid"],
- "sftp/ssh": ["paramiko", "pysftp", "sshtunnel"],
+ "sftp/ssh": ["paramiko", "sshtunnel"],
"slack": ["slack_sdk"],
"statsd": ["statsd"],
"virtualenv": ["virtualenv"],
diff --git a/docs/apache-airflow-providers-sftp/connections/sftp.rst
b/docs/apache-airflow-providers-sftp/connections/sftp.rst
index 58133bb25f..95c88ebb64 100644
--- a/docs/apache-airflow-providers-sftp/connections/sftp.rst
+++ b/docs/apache-airflow-providers-sftp/connections/sftp.rst
@@ -29,10 +29,8 @@ Authenticating to SFTP
There are two ways to connect to SFTP using Airflow.
-1. Use `host key
- <https://pysftp.readthedocs.io/en/release_0.2.9/pysftp.html#pysftp.CnOpts>`_
- i.e. host key entered in extras value ``host_key``.
-2. Use ``private_key`` or ``key_file``, along with the optional
``private_key_pass``
+1. Use ``login`` and ``password``.
+2. Use ``private_key`` or ``key_file``, along with the optional
``private_key_passphrase``
Only one authorization method can be used at a time. If you need to manage
multiple credentials or keys then you should
configure multiple connections.
@@ -61,17 +59,18 @@ Extra (optional)
Specify the extra parameters (as json dictionary) that can be used in sftp
connection.
The following parameters are all optional:
- * ``private_key_pass``: Specify the password to use, if private_key is
encrypted.
- * ``no_host_key_check``: Set to false to restrict connecting to hosts with
either no entries in ~/.ssh/known_hosts
- (Hosts file) or not present in the host_key extra. This provides maximum
protection against trojan horse attacks,
- but can be troublesome when the /etc/ssh/ssh_known_hosts file is poorly
maintained or connections to new hosts are
- frequently made. This option forces the user to manually add all new
hosts. Default is true, ssh will automatically
- add new host keys to the user known hosts files.
- * ``host_key``: The base64 encoded ssh-rsa public key of the host, as you
would find in the known_hosts file.
- Specifying this, along with no_host_key_check=False allows you to only
make the connection if the public key of
- the endpoint matches this value.
- * ``private_key`` Specify the content of the private key, the path to the
private key file(str) or paramiko.AgentKey
* ``key_file`` - Full Path of the private SSH Key file that will be used
to connect to the remote_host.
+ * ``private_key`` - Content of the private key used to connect to the
remote_host.
+ * ``private_key_passphrase`` - Content of the private key passphrase used
to decrypt the private key.
+ * ``conn_timeout`` - An optional timeout (in seconds) for the TCP connect.
Default is ``10``.
+ * ``timeout`` - Deprecated - use conn_timeout instead.
+ * ``compress`` - ``true`` to ask the remote client/server to compress
traffic; ``false`` to refuse compression. Default is ``true``.
+ * ``no_host_key_check`` - Set to ``false`` to restrict connecting to hosts
with no entries in ``~/.ssh/known_hosts`` (Hosts file). This provides maximum
protection against trojan horse attacks, but can be troublesome when the
``/etc/ssh/ssh_known_hosts`` file is poorly maintained or connections to new
hosts are frequently made. This option forces the user to manually add all new
hosts. Default is ``true``, ssh will automatically add new host keys to the
user known hosts files.
+ * ``allow_host_key_change`` - Set to ``true`` if you want to allow
connecting to hosts whose host key has changed or when you get 'REMOTE HOST
IDENTIFICATION HAS CHANGED' error. This won't protect against
Man-In-The-Middle attacks. Another possible solution is to remove the host entry
from ``~/.ssh/known_hosts`` file. Default is ``false``.
+ * ``look_for_keys`` - Set to ``false`` if you want to disable searching
for discoverable private key files in ``~/.ssh/``
+ * ``host_key`` - The base64 encoded ssh-rsa public key of the host or
"ssh-<key type> <key data>" (as you would find in the ``known_hosts`` file).
Specifying this allows making the connection if and only if the public key of
the endpoint matches this value.
+ * ``disabled_algorithms`` - A dictionary mapping algorithm type to an
iterable of algorithm identifiers, which will be disabled for the lifetime of
the transport.
+ * ``ciphers`` - A list of ciphers to use in order of preference.
Example “extras” field using ``host_key``:
diff --git a/docs/apache-airflow-providers-ssh/connections/ssh.rst
b/docs/apache-airflow-providers-ssh/connections/ssh.rst
index b91c0854a1..5d46010c77 100644
--- a/docs/apache-airflow-providers-ssh/connections/ssh.rst
+++ b/docs/apache-airflow-providers-ssh/connections/ssh.rst
@@ -55,6 +55,7 @@ Extra (optional)
* ``look_for_keys`` - Set to ``false`` if you want to disable searching
for discoverable private key files in ``~/.ssh/``
* ``host_key`` - The base64 encoded ssh-rsa public key of the host or
"ssh-<key type> <key data>" (as you would find in the ``known_hosts`` file).
Specifying this allows making the connection if and only if the public key of
the endpoint matches this value.
* ``disabled_algorithms`` - A dictionary mapping algorithm type to an
iterable of algorithm identifiers, which will be disabled for the lifetime of
the transport.
+ * ``ciphers`` - A list of ciphers to use in order of preference.
Example "extras" field:
@@ -66,8 +67,9 @@ Extra (optional)
"compress": "false",
"look_for_keys": "false",
"allow_host_key_change": "false",
- "host_key": "AAAHD...YDWwq=="
- "disabled_algorithms": {"pubkeys": ["rsa-sha2-256", "rsa-sha2-512"]}
+ "host_key": "AAAHD...YDWwq==",
+ "disabled_algorithms": {"pubkeys": ["rsa-sha2-256", "rsa-sha2-512"]},
+ "ciphers": ["aes128-ctr", "aes192-ctr", "aes256-ctr"]
}
When specifying the connection as URI (in :envvar:`AIRFLOW_CONN_{CONN_ID}`
variable) you should specify it
diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt
index 1e3142b654..b43440f9ee 100644
--- a/docs/spelling_wordlist.txt
+++ b/docs/spelling_wordlist.txt
@@ -346,6 +346,7 @@ ResourceRequirements
Roadmap
Robinhood
RoleBinding
+SFTPClient
SIGTERM
SSHClient
SSHTunnelForwarder
@@ -1250,6 +1251,7 @@ readme
readthedocs
realtime
rebase
+recurse
recurses
redbubble
redis
diff --git a/setup.py b/setup.py
index 9d3243f2b9..ebcf944a69 100644
--- a/setup.py
+++ b/setup.py
@@ -537,7 +537,6 @@ spark = [
]
ssh = [
'paramiko>=2.6.0',
- 'pysftp>=0.2.9',
'sshtunnel>=0.3.2',
]
statsd = [
@@ -640,7 +639,6 @@ devel_only = [
'pre-commit',
'pypsrp',
'pygithub',
- 'pysftp',
# Pytest 7 has been released in February 2022 and we should attempt to
upgrade and remove the limit
# It contains a number of potential breaking changes but none of them
looks breaking our use
# https://docs.pytest.org/en/latest/changelog.html#pytest-7-0-0-2022-02-03
diff --git a/tests/providers/sftp/hooks/test_sftp.py
b/tests/providers/sftp/hooks/test_sftp.py
index 95bb971bdf..7ebc57e7a8 100644
--- a/tests/providers/sftp/hooks/test_sftp.py
+++ b/tests/providers/sftp/hooks/test_sftp.py
@@ -23,11 +23,13 @@ from io import StringIO
from unittest import mock
import paramiko
-import pysftp
+import pytest
from parameterized import parameterized
+from airflow.exceptions import AirflowException
from airflow.models import Connection
from airflow.providers.sftp.hooks.sftp import SFTPHook
+from airflow.providers.ssh.hooks.ssh import SSHHook
from airflow.utils.session import provide_session
@@ -45,6 +47,7 @@ SUB_DIR = "sub_dir"
TMP_FILE_FOR_TESTS = 'test_file.txt'
ANOTHER_FILE_FOR_TESTS = 'test_file_1.txt'
LOG_FILE_FOR_TESTS = 'test_log.log'
+FIFO_FOR_TESTS = 'test_fifo'
SFTP_CONNECTION_USER = "root"
@@ -59,6 +62,7 @@ class TestSFTPHook(unittest.TestCase):
connection = session.query(Connection).filter(Connection.conn_id ==
"sftp_default").first()
old_login = connection.login
connection.login = login
+ connection.extra = '' # clear out extra so it doesn't look for a key
file
session.commit()
return old_login
@@ -76,10 +80,11 @@ class TestSFTPHook(unittest.TestCase):
file.write('Test file')
with open(os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS, SUB_DIR,
TMP_FILE_FOR_TESTS), 'a') as file:
file.write('Test file')
+ os.mkfifo(os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS, FIFO_FOR_TESTS))
def test_get_conn(self):
output = self.hook.get_conn()
- assert isinstance(output, pysftp.Connection)
+ assert isinstance(output, paramiko.SFTPClient)
def test_close_conn(self):
self.hook.conn = self.hook.get_conn()
@@ -93,13 +98,24 @@ class TestSFTPHook(unittest.TestCase):
def test_list_directory(self):
output = self.hook.list_directory(path=os.path.join(TMP_PATH,
TMP_DIR_FOR_TESTS))
- assert output == [SUB_DIR]
+ assert output == [SUB_DIR, FIFO_FOR_TESTS]
+
+ def test_mkdir(self):
+ new_dir_name = 'mk_dir'
+ self.hook.mkdir(os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS,
new_dir_name))
+ output = self.hook.describe_directory(os.path.join(TMP_PATH,
TMP_DIR_FOR_TESTS))
+ assert new_dir_name in output
def test_create_and_delete_directory(self):
new_dir_name = 'new_dir'
self.hook.create_directory(os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS,
new_dir_name))
output = self.hook.describe_directory(os.path.join(TMP_PATH,
TMP_DIR_FOR_TESTS))
assert new_dir_name in output
+ # test directory already exists for code coverage, should not raise an
exception
+ self.hook.create_directory(os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS,
new_dir_name))
+ # test path already exists and is a file, should raise an exception
+ with pytest.raises(AirflowException, match="already exists and is a
file"):
+ self.hook.create_directory(os.path.join(TMP_PATH,
TMP_DIR_FOR_TESTS, SUB_DIR, TMP_FILE_FOR_TESTS))
self.hook.delete_directory(os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS,
new_dir_name))
output = self.hook.describe_directory(os.path.join(TMP_PATH,
TMP_DIR_FOR_TESTS))
assert new_dir_name not in output
@@ -125,7 +141,7 @@ class TestSFTPHook(unittest.TestCase):
local_full_path=os.path.join(TMP_PATH, TMP_FILE_FOR_TESTS),
)
output = self.hook.list_directory(path=os.path.join(TMP_PATH,
TMP_DIR_FOR_TESTS))
- assert output == [SUB_DIR, TMP_FILE_FOR_TESTS]
+ assert output == [SUB_DIR, FIFO_FOR_TESTS, TMP_FILE_FOR_TESTS]
retrieved_file_name = 'retrieved.txt'
self.hook.retrieve_file(
remote_full_path=os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS,
TMP_FILE_FOR_TESTS),
@@ -135,7 +151,7 @@ class TestSFTPHook(unittest.TestCase):
os.remove(os.path.join(TMP_PATH, retrieved_file_name))
self.hook.delete_file(path=os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS,
TMP_FILE_FOR_TESTS))
output = self.hook.list_directory(path=os.path.join(TMP_PATH,
TMP_DIR_FOR_TESTS))
- assert output == [SUB_DIR]
+ assert output == [SUB_DIR, FIFO_FOR_TESTS]
def test_get_mod_time(self):
self.hook.store_file(
@@ -150,15 +166,15 @@ class TestSFTPHook(unittest.TestCase):
connection = Connection(login='login', host='host')
get_connection.return_value = connection
hook = SFTPHook()
- assert hook.no_host_key_check is False
+ assert hook.no_host_key_check is True
@mock.patch('airflow.providers.sftp.hooks.sftp.SFTPHook.get_connection')
def test_no_host_key_check_enabled(self, get_connection):
- connection = Connection(login='login', host='host',
extra='{"no_host_key_check": true}')
+ connection = Connection(login='login', host='host',
extra='{"no_host_key_check": false}')
get_connection.return_value = connection
hook = SFTPHook()
- assert hook.no_host_key_check is True
+ assert hook.no_host_key_check is False
@mock.patch('airflow.providers.sftp.hooks.sftp.SFTPHook.get_connection')
def test_no_host_key_check_disabled(self, get_connection):
@@ -192,14 +208,6 @@ class TestSFTPHook(unittest.TestCase):
hook = SFTPHook()
assert hook.no_host_key_check is True
- @mock.patch('airflow.providers.sftp.hooks.sftp.SFTPHook.get_connection')
- def test_no_host_key_check_no_ignore(self, get_connection):
- connection = Connection(login='login', host='host',
extra='{"ignore_hostkey_verification": false}')
-
- get_connection.return_value = connection
- hook = SFTPHook()
- assert hook.no_host_key_check is False
-
@mock.patch('airflow.providers.sftp.hooks.sftp.SFTPHook.get_connection')
def test_host_key_default(self, get_connection):
connection = Connection(login='login', host='host')
@@ -309,7 +317,7 @@ class TestSFTPHook(unittest.TestCase):
assert files == [os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS, SUB_DIR,
TMP_FILE_FOR_TESTS)]
assert dirs == [os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS, SUB_DIR)]
- assert unknowns == []
+ assert unknowns == [os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS,
FIFO_FOR_TESTS)]
@mock.patch('airflow.providers.sftp.hooks.sftp.SFTPHook.get_connection')
def test_connection_failure(self, mock_get_connection):
@@ -319,7 +327,9 @@ class TestSFTPHook(unittest.TestCase):
)
mock_get_connection.return_value = connection
with mock.patch.object(SFTPHook, 'get_conn') as get_conn:
- type(get_conn.return_value).pwd =
mock.PropertyMock(side_effect=Exception('Connection Error'))
+ type(get_conn.return_value).normalize = mock.PropertyMock(
+ side_effect=Exception('Connection Error')
+ )
hook = SFTPHook()
status, msg = hook.test_connection()
@@ -360,6 +370,21 @@ class TestSFTPHook(unittest.TestCase):
# Default is 'sftp_default
assert SFTPHook().ssh_conn_id == 'sftp_default'
+ @mock.patch('airflow.providers.sftp.hooks.sftp.SFTPHook.get_connection')
+ def test_invalid_ssh_hook(self, mock_get_connection):
+ with pytest.raises(AirflowException, match="ssh_hook must be an
instance of SSHHook"):
+ connection = Connection(conn_id='sftp_default', login='root',
host='localhost')
+ mock_get_connection.return_value = connection
+ SFTPHook(ssh_hook='invalid_hook') # type: ignore
+
+ @mock.patch('airflow.providers.ssh.hooks.ssh.SSHHook.get_connection')
+ def test_valid_ssh_hook(self, mock_get_connection):
+ connection = Connection(conn_id='sftp_test', login='root',
host='localhost')
+ mock_get_connection.return_value = connection
+ hook = SFTPHook(ssh_hook=SSHHook(ssh_conn_id='sftp_test'))
+ assert hook.ssh_conn_id == 'sftp_test'
+ assert isinstance(hook.get_conn(), paramiko.SFTPClient)
+
def test_get_suffix_pattern_match(self):
output = self.hook.get_file_by_pattern(TMP_PATH, "*.txt")
self.assertTrue(output, TMP_FILE_FOR_TESTS)
diff --git a/tests/providers/sftp/operators/test_sftp.py
b/tests/providers/sftp/operators/test_sftp.py
index 6aa9bb0169..544b7cc5ea 100644
--- a/tests/providers/sftp/operators/test_sftp.py
+++ b/tests/providers/sftp/operators/test_sftp.py
@@ -23,7 +23,9 @@ import pytest
from airflow.exceptions import AirflowException
from airflow.models import DAG
+from airflow.providers.sftp.hooks.sftp import SFTPHook
from airflow.providers.sftp.operators.sftp import SFTPOperation, SFTPOperator
+from airflow.providers.ssh.hooks.ssh import SSHHook
from airflow.providers.ssh.operators.ssh import SSHOperator
from airflow.utils import timezone
from airflow.utils.timezone import datetime
@@ -35,7 +37,6 @@ TEST_CONN_ID = "conn_id_for_testing"
class TestSFTPOperator:
def setup_method(self):
- from airflow.providers.ssh.hooks.ssh import SSHHook
hook = SSHHook(ssh_conn_id='ssh_default')
hook.no_host_key_check = True
@@ -321,7 +322,7 @@ class TestSFTPOperator:
def test_arg_checking(self):
dag = DAG(dag_id="unit_tests_sftp_op_arg_checking",
default_args={"start_date": DEFAULT_DATE})
# Exception should be raised if neither ssh_hook nor ssh_conn_id is
provided
- with pytest.raises(AirflowException, match="Cannot operate without
ssh_hook or ssh_conn_id."):
+ with pytest.raises(AirflowException, match="Cannot operate without
sftp_hook or ssh_conn_id."):
task_0 = SFTPOperator(
task_id="test_sftp_0",
local_filepath=self.test_local_filepath,
@@ -334,7 +335,7 @@ class TestSFTPOperator:
# if ssh_hook is invalid/not provided, use ssh_conn_id to create
SSHHook
task_1 = SFTPOperator(
task_id="test_sftp_1",
- ssh_hook="string_rather_than_SSHHook", # invalid ssh_hook
+ ssh_hook="string_rather_than_SSHHook", # type: ignore
ssh_conn_id=TEST_CONN_ID,
local_filepath=self.test_local_filepath,
remote_filepath=self.test_remote_filepath,
@@ -345,7 +346,7 @@ class TestSFTPOperator:
task_1.execute(None)
except Exception:
pass
- assert task_1.ssh_hook.ssh_conn_id == TEST_CONN_ID
+ assert task_1.sftp_hook.ssh_conn_id == TEST_CONN_ID
task_2 = SFTPOperator(
task_id="test_sftp_2",
@@ -359,7 +360,7 @@ class TestSFTPOperator:
task_2.execute(None)
except Exception:
pass
- assert task_2.ssh_hook.ssh_conn_id == TEST_CONN_ID
+ assert task_2.sftp_hook.ssh_conn_id == TEST_CONN_ID
# if both valid ssh_hook and ssh_conn_id are provided, ignore
ssh_conn_id
task_3 = SFTPOperator(
@@ -375,4 +376,46 @@ class TestSFTPOperator:
task_3.execute(None)
except Exception:
pass
- assert task_3.ssh_hook.ssh_conn_id == self.hook.ssh_conn_id
+ assert task_3.sftp_hook.ssh_conn_id == self.hook.ssh_conn_id
+
+ # Exception should be raised if operation is invalid
+ with pytest.raises(TypeError, match="Unsupported operation value
invalid_operation, "):
+ task_4 = SFTPOperator(
+ task_id="test_sftp_4",
+ local_filepath=self.test_local_filepath,
+ remote_filepath=self.test_remote_filepath,
+ operation='invalid_operation',
+ dag=dag,
+ )
+ task_4.execute(None)
+
+ # Exception should be raised if both ssh_hook and sftp_hook are
provided
+ with pytest.raises(
+ AirflowException,
+ match="Both `ssh_hook` and `sftp_hook` are defined. Please use
only one of them.",
+ ):
+ task_5 = SFTPOperator(
+ task_id="test_sftp_5",
+ ssh_hook=self.hook,
+ sftp_hook=SFTPHook(),
+ local_filepath=self.test_local_filepath,
+ remote_filepath=self.test_remote_filepath,
+ operation=SFTPOperation.PUT,
+ dag=dag,
+ )
+ task_5.execute(None)
+
+ task_6 = SFTPOperator(
+ task_id="test_sftp_6",
+ ssh_conn_id=TEST_CONN_ID,
+ remote_host='remotehost',
+ local_filepath=self.test_local_filepath,
+ remote_filepath=self.test_remote_filepath,
+ operation=SFTPOperation.PUT,
+ dag=dag,
+ )
+ try:
+ task_6.execute(None)
+ except Exception:
+ pass
+ assert task_6.sftp_hook.remote_host == 'remotehost'
diff --git a/tests/providers/ssh/hooks/test_ssh.py
b/tests/providers/ssh/hooks/test_ssh.py
index 195823857e..e11482336b 100644
--- a/tests/providers/ssh/hooks/test_ssh.py
+++ b/tests/providers/ssh/hooks/test_ssh.py
@@ -78,6 +78,8 @@ TEST_ENCRYPTED_PRIVATE_KEY =
generate_key_string(pkey=TEST_PKEY, passphrase=PASS
TEST_DISABLED_ALGORITHMS = {"pubkeys": ["rsa-sha2-256", "rsa-sha2-512"]}
+TEST_CIPHERS = ["aes128-ctr", "aes192-ctr", "aes256-ctr"]
+
class TestSSHHook(unittest.TestCase):
CONN_SSH_WITH_NO_EXTRA = 'ssh_with_no_extra'
@@ -99,6 +101,7 @@ class TestSSHHook(unittest.TestCase):
'ssh_with_host_key_and_allow_host_key_changes_true'
)
CONN_SSH_WITH_EXTRA_DISABLED_ALGORITHMS =
'ssh_with_extra_disabled_algorithms'
+ CONN_SSH_WITH_EXTRA_CIPHERS = 'ssh_with_extra_ciphers'
@classmethod
def tearDownClass(cls) -> None:
@@ -119,6 +122,7 @@ class TestSSHHook(unittest.TestCase):
cls.CONN_SSH_WITH_NO_HOST_KEY_AND_NO_HOST_KEY_CHECK_FALSE,
cls.CONN_SSH_WITH_NO_HOST_KEY_AND_NO_HOST_KEY_CHECK_TRUE,
cls.CONN_SSH_WITH_EXTRA_DISABLED_ALGORITHMS,
+ cls.CONN_SSH_WITH_EXTRA_CIPHERS,
]
connections =
session.query(Connection).filter(Connection.conn_id.in_(conns_to_reset))
connections.delete(synchronize_session=False)
@@ -275,6 +279,14 @@ class TestSSHHook(unittest.TestCase):
extra=json.dumps({"disabled_algorithms":
TEST_DISABLED_ALGORITHMS}),
)
)
+ db.merge_conn(
+ Connection(
+ conn_id=cls.CONN_SSH_WITH_EXTRA_CIPHERS,
+ host='localhost',
+ conn_type='ssh',
+ extra=json.dumps({"ciphers": TEST_CIPHERS}),
+ )
+ )
@mock.patch('airflow.providers.ssh.hooks.ssh.paramiko.SSHClient')
def test_ssh_connection_with_password(self, ssh_mock):
@@ -781,6 +793,19 @@ class TestSSHHook(unittest.TestCase):
disabled_algorithms=TEST_DISABLED_ALGORITHMS,
)
+ @mock.patch('airflow.providers.ssh.hooks.ssh.paramiko.SSHClient')
+ def test_ssh_with_extra_ciphers(self, ssh_mock):
+ hook = SSHHook(
+ ssh_conn_id=self.CONN_SSH_WITH_EXTRA_CIPHERS,
+ remote_host='remote_host',
+ port='port',
+ username='username',
+ )
+
+ with hook.get_conn():
+ transport = ssh_mock.return_value.get_transport.return_value
+ assert transport.get_security_options.return_value.ciphers ==
TEST_CIPHERS
+
def test_openssh_private_key(self):
# Paramiko behaves differently with OpenSSH generated keys to paramiko
# generated keys, so we need a test one.