Repository: incubator-airflow Updated Branches: refs/heads/master 751e936ac -> fe0edeaab
[AIRFLOW-756][AIRFLOW-751] Replace ssh hook, operator & sftp operator with paramiko based Closes #1999 from jhsenjaliya/AIRFLOW-756 Project: http://git-wip-us.apache.org/repos/asf/incubator-airflow/repo Commit: http://git-wip-us.apache.org/repos/asf/incubator-airflow/commit/fe0edeaa Tree: http://git-wip-us.apache.org/repos/asf/incubator-airflow/tree/fe0edeaa Diff: http://git-wip-us.apache.org/repos/asf/incubator-airflow/diff/fe0edeaa Branch: refs/heads/master Commit: fe0edeaab5a23e6b0bcd67c22ed32e8303273840 Parents: 751e936 Author: Jay <[email protected]> Authored: Thu Jul 20 22:07:30 2017 +0200 Committer: Bolke de Bruin <[email protected]> Committed: Thu Jul 20 22:07:45 2017 +0200 ---------------------------------------------------------------------- UPDATING.md | 7 + airflow/contrib/hooks/ssh_hook.py | 268 +++++++++++-------- airflow/contrib/operators/__init__.py | 2 +- airflow/contrib/operators/sftp_operator.py | 99 +++++++ .../contrib/operators/ssh_execute_operator.py | 159 ----------- airflow/contrib/operators/ssh_operator.py | 106 ++++++++ docs/code.rst | 2 +- scripts/ci/requirements.txt | 1 + setup.py | 7 +- tests/contrib/hooks/test_ssh_hook.py | 70 +++++ tests/contrib/operators/test_sftp_operator.py | 158 +++++++++++ .../operators/test_ssh_execute_operator.py | 95 ------- tests/contrib/operators/test_ssh_operator.py | 89 ++++++ tests/core.py | 52 ---- 14 files changed, 694 insertions(+), 421 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/fe0edeaa/UPDATING.md ---------------------------------------------------------------------- diff --git a/UPDATING.md b/UPDATING.md index c38c96d..aefed3e 100644 --- a/UPDATING.md +++ b/UPDATING.md @@ -117,6 +117,13 @@ If you experience problems connecting with your operator make sure you set the c Also the old P12 key file type is not supported anymore and only the new JSON key files are supported as a service account. 
+### SSH Hook updates, along with new SSH Operator & SFTP Operator + SSH Hook now uses Paramiko library to create ssh client connection, instead of sub-process based ssh command execution previously (<1.9.0), so this is backward incompatible. + - update SSHHook constructor + - use SSHOperator class in place of SSHExecuteOperator which is removed now. Refer to test_ssh_operator.py for usage info. + - SFTPOperator is added to perform secure file transfer from serverA to serverB. Refer to test_sftp_operator.py for usage info. + - No updates are required if you are using ftpHook, it will continue to work as is. + ### Deprecated Features These features are marked for deprecation. They may still work (and raise a `DeprecationWarning`), but are no longer supported and will be removed entirely in Airflow 2.0 http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/fe0edeaa/airflow/contrib/hooks/ssh_hook.py ---------------------------------------------------------------------- diff --git a/airflow/contrib/hooks/ssh_hook.py b/airflow/contrib/hooks/ssh_hook.py index e63a65d..f1e25a6 100755 --- a/airflow/contrib/hooks/ssh_hook.py +++ b/airflow/contrib/hooks/ssh_hook.py @@ -14,125 +14,144 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -# -# This is a port of Luigi's ssh implementation. All credits go there. -import subprocess -from contextlib import contextmanager - -from airflow.hooks.base_hook import BaseHook -from airflow.exceptions import AirflowException +import getpass import logging +import os +import paramiko -class SSHHook(BaseHook): - """ - Light-weight remote execution library and utilities. - - Using this hook (which is just a convenience wrapper for subprocess), - is created to let you stream data from a remotely stored file. 
+from contextlib import contextmanager +from airflow.exceptions import AirflowException +from airflow.hooks.base_hook import BaseHook - As a bonus, :class:`SSHHook` also provides a really cool feature that let's you - set up ssh tunnels super easily using a python context manager (there is an example - in the integration part of unittests). - :param key_file: Typically the SSHHook uses the keys that are used by the user - airflow is running under. This sets the behavior to use another file instead. +class SSHHook(BaseHook): + """ + Hook for ssh remote execution using Paramiko. + ref: https://github.com/paramiko/paramiko + This hook also lets you create ssh tunnel and serve as basis for SFTP file transfer + + :param ssh_conn_id: connection id from airflow Connections from where all the required + parameters can be fetched like username, password or key_file. + Thought the priority is given to the param passed during init + :type ssh_conn_id: str + :param remote_host: remote host to connect + :type remote_host: str + :param username: username to connect to the remote_host + :type username: str + :param password: password of the username to connect to the remote_host + :type password: str + :param key_file: key file to use to connect to the remote_host. :type key_file: str - :param connect_timeout: sets the connection timeout for this connection. - :type connect_timeout: int - :param no_host_key_check: whether to check to host key. If True host keys will not - be checked, but are also not stored in the current users's known_hosts file. - :type no_host_key_check: bool - :param tty: allocate a tty. - :type tty: bool - :param sshpass: Use to non-interactively perform password authentication by using - sshpass. - :type sshpass: bool + :param timeout: timeout for the attempt to connect to the remote_host. 
+ :type timeout: int """ - def __init__(self, conn_id='ssh_default'): - conn = self.get_connection(conn_id) - self.key_file = conn.extra_dejson.get('key_file', None) - self.connect_timeout = conn.extra_dejson.get('connect_timeout', None) - self.tcp_keepalive = conn.extra_dejson.get('tcp_keepalive', False) - self.server_alive_interval = conn.extra_dejson.get('server_alive_interval', 60) - self.no_host_key_check = conn.extra_dejson.get('no_host_key_check', False) - self.tty = conn.extra_dejson.get('tty', False) - self.sshpass = conn.extra_dejson.get('sshpass', False) - self.conn = conn - - def get_conn(self): - pass - - def _host_ref(self): - if self.conn.login: - return "{0}@{1}".format(self.conn.login, self.conn.host) - else: - return self.conn.host - - def _prepare_command(self, cmd): - connection_cmd = ["ssh", self._host_ref(), "-o", "ControlMaster=no"] - if self.sshpass: - connection_cmd = ["sshpass", "-e"] + connection_cmd - else: - connection_cmd += ["-o", "BatchMode=yes"] # no password prompts - - if self.conn.port: - connection_cmd += ["-p", str(self.conn.port)] - - if self.connect_timeout: - connection_cmd += ["-o", "ConnectionTimeout={}".format(self.connect_timeout)] - - if self.tcp_keepalive: - connection_cmd += ["-o", "TCPKeepAlive=yes"] - connection_cmd += ["-o", "ServerAliveInterval={}".format(self.server_alive_interval)] - if self.no_host_key_check: - connection_cmd += ["-o", "UserKnownHostsFile=/dev/null", - "-o", "StrictHostKeyChecking=no"] + def __init__(self, + ssh_conn_id=None, + remote_host=None, + username=None, + password=None, + key_file=None, + timeout=10 + ): + super(SSHHook, self).__init__(ssh_conn_id) + self.ssh_conn_id = ssh_conn_id + self.remote_host = remote_host + self.username = username + self.password = password + self.key_file = key_file + self.timeout = timeout + # Default values, overridable from Connection + self.compress = True + self.no_host_key_check = True + self.client = None - if self.key_file: - connection_cmd += ["-i", 
self.key_file] - - if self.tty: - connection_cmd += ["-t"] - - connection_cmd += cmd - logging.debug("SSH cmd: {} ".format(connection_cmd)) - - return connection_cmd - - def Popen(self, cmd, **kwargs): - """ - Remote Popen - - :param cmd: command to remotely execute - :param kwargs: extra arguments to Popen (see subprocess.Popen) - :return: handle to subprocess - """ - prefixed_cmd = self._prepare_command(cmd) - return subprocess.Popen(prefixed_cmd, **kwargs) - - def check_output(self, cmd): - """ - Executes a remote command and returns the stdout a remote process. - Simplified version of Popen when you only want the output as a string and detect any errors. - - :param cmd: command to remotely execute - :return: stdout - """ - p = self.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE) - output, stderr = p.communicate() - - if p.returncode != 0: - # I like this better: RemoteCalledProcessError(p.returncode, cmd, self.host, output=output) - raise AirflowException("Cannot execute {} on {}. Error code is: {}. 
Output: {}, Stderr: {}".format( - cmd, self.conn.host, p.returncode, output, stderr)) - - return output + def get_conn(self): + if not self.client: + logging.debug('creating ssh client for conn_id: {0}'.format(self.ssh_conn_id)) + if self.ssh_conn_id is not None: + conn = self.get_connection(self.ssh_conn_id) + if self.username is None: + self.username = conn.login + if self.password is None: + self.password = conn.password + if self.remote_host is None: + self.remote_host = conn.host + if conn.extra is not None: + extra_options = conn.extra_dejson + self.key_file = extra_options.get("key_file") + + if "timeout" in extra_options: + self.timeout = int(extra_options["timeout"], 10) + + if "compress" in extra_options \ + and extra_options["compress"].lower() == 'false': + self.compress = False + if "no_host_key_check" in extra_options \ + and extra_options["no_host_key_check"].lower() == 'false': + self.no_host_key_check = False + + if not self.remote_host: + raise AirflowException("Missing required param: remote_host") + + # Auto detecting username values from system + if not self.username: + logging.debug("username to ssh to host: {0} is not specified, using " + "system's default provided by getpass.getuser()" + .format(self.remote_host, self.ssh_conn_id)) + self.username = getpass.getuser() + + host_proxy = None + user_ssh_config_filename = os.path.expanduser('~/.ssh/config') + if os.path.isfile(user_ssh_config_filename): + ssh_conf = paramiko.SSHConfig() + ssh_conf.parse(open(user_ssh_config_filename)) + host_info = ssh_conf.lookup(self.remote_host) + if host_info and host_info.get('proxycommand'): + host_proxy = paramiko.ProxyCommand(host_info.get('proxycommand')) + + if not (self.password or self.key_file): + if host_info and host_info.get('identityfile'): + self.key_file = host_info.get('identityfile')[0] + + try: + client = paramiko.SSHClient() + client.load_system_host_keys() + if self.no_host_key_check: + # Default is RejectPolicy + 
client.set_missing_host_key_policy(paramiko.AutoAddPolicy()) + + if self.password and self.password.strip(): + client.connect(hostname=self.remote_host, + username=self.username, + password=self.password, + timeout=self.timeout, + compress=self.compress, + sock=host_proxy) + else: + client.connect(hostname=self.remote_host, + username=self.username, + key_filename=self.key_file, + timeout=self.timeout, + compress=self.compress, + sock=host_proxy) + + self.client = client + except paramiko.AuthenticationException as auth_error: + logging.error("Auth failed while connecting to host: {0}, error: {1}" + .format(self.remote_host, auth_error)) + except paramiko.SSHException as ssh_error: + logging.error("Failed connecting to host: {0}, error: {1}" + .format(self.remote_host, ssh_error)) + except Exception as error: + logging.error("Error connecting to host: {0}, error: {1}" + .format(self.remote_host, error)) + return self.client @contextmanager - def tunnel(self, local_port, remote_port=None, remote_host="localhost"): + def create_tunnel(self, local_port, remote_port=None, remote_host="localhost"): """ Creates a tunnel between two hosts. Like ssh -L <LOCAL_PORT>:host:<REMOTE_PORT>. 
Remember to close() the returned "tunnel" object in order to clean up @@ -146,13 +165,40 @@ class SSHHook(BaseHook): :type remote_host: str :return: """ + + import subprocess + # this will ensure the connection to the ssh.remote_host from where the tunnel + # is getting created + self.get_conn() + tunnel_host = "{0}:{1}:{2}".format(local_port, remote_host, remote_port) - proc = self.Popen(["-L", tunnel_host, "echo -n ready && cat"], - stdin=subprocess.PIPE, stdout=subprocess.PIPE) + ssh_cmd = ["ssh", "{0}@{1}".format(self.username, self.remote_host), + "-o", "ControlMaster=no", + "-o", "UserKnownHostsFile=/dev/null", + "-o", "StrictHostKeyChecking=no"] + + ssh_tunnel_cmd = ["-L", tunnel_host, + "echo -n ready && cat" + ] + + ssh_cmd += ssh_tunnel_cmd + logging.debug("creating tunnel with cmd: {0}".format(ssh_cmd)) + + proc = subprocess.Popen(ssh_cmd, + stdin=subprocess.PIPE, + stdout=subprocess.PIPE) ready = proc.stdout.read(5) - assert ready == b"ready", "Did not get 'ready' from remote" + assert ready == b"ready", \ + "Did not get 'ready' from remote, got '{0}' instead".format(ready) yield proc.communicate() - assert proc.returncode == 0, "Tunnel process did unclean exit (returncode {}".format(proc.returncode) + assert proc.returncode == 0, \ + "Tunnel process did unclean exit (returncode {}".format(proc.returncode) + + def __enter__(self): + return self + def __exit__(self, exc_type, exc_val, exc_tb): + if self.client is not None: + self.client.close() http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/fe0edeaa/airflow/contrib/operators/__init__.py ---------------------------------------------------------------------- diff --git a/airflow/contrib/operators/__init__.py b/airflow/contrib/operators/__init__.py index 4ea6c17..b731373 100644 --- a/airflow/contrib/operators/__init__.py +++ b/airflow/contrib/operators/__init__.py @@ -32,7 +32,7 @@ import sys # # ------------------------------------------------------------------------ _operators = { - 
'ssh_execute_operator': ['SSHExecuteOperator'], + 'ssh_operator': ['SSHOperator'], 'vertica_operator': ['VerticaOperator'], 'vertica_to_hive': ['VerticaToHiveTransfer'], 'qubole_operator': ['QuboleOperator'], http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/fe0edeaa/airflow/contrib/operators/sftp_operator.py ---------------------------------------------------------------------- diff --git a/airflow/contrib/operators/sftp_operator.py b/airflow/contrib/operators/sftp_operator.py new file mode 100644 index 0000000..b9f07d5 --- /dev/null +++ b/airflow/contrib/operators/sftp_operator.py @@ -0,0 +1,99 @@ +# -*- coding: utf-8 -*- +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging + +from airflow.contrib.hooks.ssh_hook import SSHHook +from airflow.exceptions import AirflowException +from airflow.models import BaseOperator +from airflow.utils.decorators import apply_defaults + + +class SFTPOperation(object): + PUT = 'put' + GET = 'get' + + +class SFTPOperator(BaseOperator): + """ + SFTPOperator for transferring files from remote host to local or vice a versa. + This operator uses ssh_hook to open sftp trasport channel that serve as basis + for file transfer. 
+ + :param ssh_hook: predefined ssh_hook to use for remote execution + :type ssh_hook: :class:`SSHHook` + :param ssh_conn_id: connection id from airflow Connections + :type ssh_conn_id: str + :param remote_host: remote host to connect + :type remote_host: str + :param local_filepath: local file path to get or put + :type local_filepath: str + :param remote_filepath: remote file path to get or put + :type remote_filepath: str + :param operation: specify operation 'get' or 'put', defaults to get + :type get: bool + """ + template_fields = ('local_filepath', 'remote_filepath') + + @apply_defaults + def __init__(self, + ssh_hook=None, + ssh_conn_id=None, + remote_host=None, + local_filepath=None, + remote_filepath=None, + operation=SFTPOperation.PUT, + *args, + **kwargs): + super(SFTPOperator, self).__init__(*args, **kwargs) + self.ssh_hook = ssh_hook + self.ssh_conn_id = ssh_conn_id + self.remote_host = remote_host + self.local_filepath = local_filepath + self.remote_filepath = remote_filepath + self.operation = operation + if not (self.operation.lower() == SFTPOperation.GET or self.operation.lower() == SFTPOperation.PUT): + raise TypeError("unsupported operation value {0}, expected {1} or {2}" + .format(self.operation, SFTPOperation.GET, SFTPOperation.PUT)) + + def execute(self, context): + file_msg = None + try: + if self.ssh_conn_id and not self.ssh_hook: + self.ssh_hook = SSHHook(ssh_conn_id=self.ssh_conn_id) + + if not self.ssh_hook: + raise AirflowException("can not operate without ssh_hook or ssh_conn_id") + + if self.remote_host is not None: + self.ssh_hook.remote_host = self.remote_host + + ssh_client = self.ssh_hook.get_conn() + sftp_client = ssh_client.open_sftp() + if self.operation.lower() == SFTPOperation.GET: + file_msg = "from {0} to {1}".format(self.remote_filepath, + self.local_filepath) + logging.debug("Starting to transfer {0}".format(file_msg)) + sftp_client.get(self.remote_filepath, self.local_filepath) + else: + file_msg = "from {0} to 
{1}".format(self.local_filepath, + self.remote_filepath) + logging.debug("Starting to transfer file {0}".format(file_msg)) + sftp_client.put(self.local_filepath, self.remote_filepath) + + except Exception as e: + raise AirflowException("Error while transferring {0}, error: {1}" + .format(file_msg, str(e))) + + return None http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/fe0edeaa/airflow/contrib/operators/ssh_execute_operator.py ---------------------------------------------------------------------- diff --git a/airflow/contrib/operators/ssh_execute_operator.py b/airflow/contrib/operators/ssh_execute_operator.py deleted file mode 100644 index 3bd8f09..0000000 --- a/airflow/contrib/operators/ssh_execute_operator.py +++ /dev/null @@ -1,159 +0,0 @@ -# -*- coding: utf-8 -*- -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from builtins import bytes -import logging -import subprocess -from subprocess import STDOUT - -from airflow.models import BaseOperator -from airflow.utils.decorators import apply_defaults -from airflow.exceptions import AirflowException - - -class SSHTempFileContent(object): - """This class prvides a functionality that creates tempfile - with given content at remote host. - Use like:: - - with SSHTempFileContent(ssh_hook, content) as tempfile: - ... - - In this case, a temporary file ``tempfile`` - with content ``content`` is created where ``ssh_hook`` designate. 
- - Note that this isn't safe because other processes - at remote host can read and write that tempfile. - - :param ssh_hook: A SSHHook that indicates a remote host - where you want to create tempfile - :param content: Initial content of creating temporary file - :type content: string - :param prefix: The prefix string you want to use for the temporary file - :type prefix: string - """ - - def __init__(self, ssh_hook, content, prefix="tmp"): - self._ssh_hook = ssh_hook - self._content = content - self._prefix = prefix - - def __enter__(self): - ssh_hook = self._ssh_hook - string = self._content - prefix = self._prefix - - pmktemp = ssh_hook.Popen(["-q", - "mktemp", "-t", prefix + "_XXXXXX"], - stdout=subprocess.PIPE, - stderr=STDOUT) - tempfile = pmktemp.communicate()[0].rstrip() - pmktemp.wait() - if pmktemp.returncode: - raise AirflowException("Failed to create remote temp file") - - ptee = ssh_hook.Popen(["-q", "tee", tempfile], - stdin=subprocess.PIPE, - # discard stdout - stderr=STDOUT) - ptee.stdin.write(bytes(string, 'utf_8')) - ptee.stdin.close() - ptee.wait() - if ptee.returncode: - raise AirflowException("Failed to write to remote temp file") - - self._tempfile = tempfile - return tempfile - - def __exit__(self, type, value, traceback): - sp = self._ssh_hook.Popen(["-q", "rm", "-f", "--", self._tempfile]) - sp.communicate() - sp.wait() - if sp.returncode: - raise AirflowException("Failed to remove to remote temp file") - return False - - -class SSHExecuteOperator(BaseOperator): - """ - Execute a Bash script, command or set of commands at remote host. - - :param ssh_hook: A SSHHook that indicates the remote host - you want to run the script - :type ssh_hook: string - :param bash_command: The command, set of commands or reference to a - bash script (must be '.sh') to be executed. 
- :type bash_command: string - :param env: If env is not None, it must be a mapping that defines the - environment variables for the new process; these are used instead - of inheriting the current process environment, which is the default - behavior. - :type env: dict - """ - - template_fields = ("bash_command", "env",) - template_ext = (".sh", ".bash",) - - @apply_defaults - def __init__(self, - ssh_hook, - bash_command, - xcom_push=False, - env=None, - *args, **kwargs): - super(SSHExecuteOperator, self).__init__(*args, **kwargs) - self.bash_command = bash_command - self.env = env - self.hook = ssh_hook - self.xcom_push = xcom_push - - def execute(self, context): - bash_command = self.bash_command - hook = self.hook - host = hook._host_ref() - - with SSHTempFileContent(self.hook, - self.bash_command, - self.task_id) as remote_file_path: - logging.info("Temporary script " - "location : {0}:{1}".format(host, remote_file_path)) - logging.info("Running command: " + bash_command) - if self.env is not None: - logging.info("env: " + str(self.env)) - - sp = hook.Popen( - ['-q', 'bash', remote_file_path], - stdout=subprocess.PIPE, stderr=STDOUT, - env=self.env) - - self.sp = sp - - logging.info("Output:") - line = '' - for line in iter(sp.stdout.readline, b''): - line = line.decode('utf_8').strip() - logging.info(line) - sp.wait() - logging.info("Command exited with " - "return code {0}".format(sp.returncode)) - if sp.returncode: - raise AirflowException("Bash command failed") - if self.xcom_push: - return line - - def on_kill(self): - # TODO: Cleanup remote tempfile - # TODO: kill `mktemp` or `tee` too when they are alive. 
- logging.info('Sending SIGTERM signal to bash subprocess') - self.sp.terminate() http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/fe0edeaa/airflow/contrib/operators/ssh_operator.py ---------------------------------------------------------------------- diff --git a/airflow/contrib/operators/ssh_operator.py b/airflow/contrib/operators/ssh_operator.py new file mode 100644 index 0000000..ff874da --- /dev/null +++ b/airflow/contrib/operators/ssh_operator.py @@ -0,0 +1,106 @@ +# -*- coding: utf-8 -*- +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging + +from airflow.contrib.hooks.ssh_hook import SSHHook +from airflow.exceptions import AirflowException +from airflow.models import BaseOperator +from airflow.utils.decorators import apply_defaults + + +class SSHOperator(BaseOperator): + + """ + SSHOperator to execute commands on given remote host using the ssh_hook. + + :param ssh_hook: predefined ssh_hook to use for remote execution + :type ssh_hook: :class:`SSHHook` + :param ssh_conn_id: connection id from airflow Connections + :type ssh_conn_id: str + :param remote_host: remote host to connect + :type remote_host: str + :param command: command to execute on remote host + :type command: str + :param timeout: timeout for executing the command. 
+ :type timeout: int + :param do_xcom_push: return the stdout which also get set in xcom by airflow platform + :type do_xcom_push: bool + """ + + template_fields = ('command',) + + @apply_defaults + def __init__(self, + ssh_hook=None, + ssh_conn_id=None, + remote_host=None, + command=None, + timeout=10, + do_xcom_push=False, + *args, + **kwargs): + super(SSHOperator, self).__init__(*args, **kwargs) + self.ssh_hook = ssh_hook + self.ssh_conn_id = ssh_conn_id + self.remote_host = remote_host + self.command = command + self.timeout = timeout + self.do_xcom_push = do_xcom_push + + def execute(self, context): + try: + if self.ssh_conn_id and not self.ssh_hook: + self.ssh_hook = SSHHook(ssh_conn_id=self.ssh_conn_id) + + if not self.ssh_hook: + raise AirflowException("can not operate without ssh_hook or ssh_conn_id") + + if self.remote_host is not None: + self.ssh_hook.remote_host = self.remote_host + + ssh_client = self.ssh_hook.get_conn() + + if not self.command: + raise AirflowException("no command specified so nothing to execute here.") + + # Auto apply tty when its required in case of sudo + get_pty = False + if self.command.startswith('sudo'): + get_pty = True + + # set timeout taken as params + stdin, stdout, stderr = ssh_client.exec_command(command=self.command, + get_pty=get_pty, + timeout=self.timeout + ) + exit_status = stdout.channel.recv_exit_status() + if exit_status is 0: + # only returning on output if do_xcom_push is set + # otherwise its not suppose to be disclosed + if self.do_xcom_push: + return stdout.read() + else: + error_msg = stderr.read() + raise AirflowException("error running cmd: {0}, error: {1}" + .format(self.command, error_msg)) + + except Exception as e: + raise AirflowException("SSH operator error: {0}".format(str(e))) + + return True + + def tunnel(self): + ssh_client = self.ssh_hook.get_conn() + ssh_client.get_transport() http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/fe0edeaa/docs/code.rst 
---------------------------------------------------------------------- diff --git a/docs/code.rst b/docs/code.rst index b17c3fe..a1980f2 100644 --- a/docs/code.rst +++ b/docs/code.rst @@ -91,7 +91,7 @@ Community-contributed Operators .. automodule:: airflow.contrib.operators :show-inheritance: :members: - SSHExecuteOperator, + SSHOperator, VerticaOperator, VerticaToHiveTransfer http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/fe0edeaa/scripts/ci/requirements.txt ---------------------------------------------------------------------- diff --git a/scripts/ci/requirements.txt b/scripts/ci/requirements.txt index 0e5dbaf..670335c 100644 --- a/scripts/ci/requirements.txt +++ b/scripts/ci/requirements.txt @@ -63,6 +63,7 @@ oauth2client>=2.0.2,<2.1.0 pandas pandas-gbq parameterized +paramiko>=2.1.1 psutil>=4.2.0, <5.0.0 psycopg2 pydruid http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/fe0edeaa/setup.py ---------------------------------------------------------------------- diff --git a/setup.py b/setup.py index e4689d2..dedcf76 100644 --- a/setup.py +++ b/setup.py @@ -160,6 +160,7 @@ mysql = ['mysqlclient>=1.3.6'] rabbitmq = ['librabbitmq>=1.6.1'] oracle = ['cx_Oracle>=5.1.2'] postgres = ['psycopg2>=2.7.1'] +ssh = ['paramiko>=2.1.1'] salesforce = ['simple-salesforce>=0.72'] s3 = [ 'boto>=2.36.0', @@ -196,11 +197,12 @@ devel = [ 'nose-ignore-docstring==0.2', 'nose-timer', 'parameterized', - 'rednose' + 'rednose', + 'paramiko' ] devel_minreq = devel + mysql + doc + password + s3 + cgroups devel_hadoop = devel_minreq + hive + hdfs + webhdfs + kerberos -devel_all = devel + all_dbs + doc + samba + s3 + slack + crypto + oracle + docker +devel_all = devel + all_dbs + doc + samba + s3 + slack + crypto + oracle + docker + ssh def do_setup(): @@ -284,6 +286,7 @@ def do_setup(): 'salesforce': salesforce, 'samba': samba, 'slack': slack, + 'ssh': ssh, 'statsd': statsd, 'vertica': vertica, 'webhdfs': webhdfs, 
http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/fe0edeaa/tests/contrib/hooks/test_ssh_hook.py ---------------------------------------------------------------------- diff --git a/tests/contrib/hooks/test_ssh_hook.py b/tests/contrib/hooks/test_ssh_hook.py new file mode 100644 index 0000000..a556332 --- /dev/null +++ b/tests/contrib/hooks/test_ssh_hook.py @@ -0,0 +1,70 @@ +# -*- coding: utf-8 -*- +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +from airflow import configuration + + +HELLO_SERVER_CMD = """ +import socket, sys +listener = socket.socket() +listener.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) +listener.bind(('localhost', 2134)) +listener.listen(1) +sys.stdout.write('ready') +sys.stdout.flush() +conn = listener.accept()[0] +conn.sendall(b'hello') +""" + + +class SSHHookTest(unittest.TestCase): + def setUp(self): + configuration.load_test_config() + from airflow.contrib.hooks.ssh_hook import SSHHook + self.hook = SSHHook(ssh_conn_id='ssh_default') + self.hook.no_host_key_check = True + + def test_ssh_connection(self): + ssh_hook = self.hook.get_conn() + self.assertIsNotNone(ssh_hook) + + def test_tunnel(self): + print("Setting up remote listener") + import subprocess + import socket + + self.server_handle = subprocess.Popen(["python", "-c", HELLO_SERVER_CMD], + stdout=subprocess.PIPE) + print("Setting up tunnel") + with self.hook.create_tunnel(2135, 2134): + print("Tunnel up") + server_output = 
self.server_handle.stdout.read(5) + self.assertEqual(server_output, b"ready") + print("Connecting to server via tunnel") + s = socket.socket() + s.connect(("localhost", 2135)) + print("Receiving...", ) + response = s.recv(5) + self.assertEqual(response, b"hello") + print("Closing connection") + s.close() + print("Waiting for listener...") + output, _ = self.server_handle.communicate() + self.assertEqual(self.server_handle.returncode, 0) + print("Closing tunnel") + + +if __name__ == '__main__': + unittest.main() http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/fe0edeaa/tests/contrib/operators/test_sftp_operator.py ---------------------------------------------------------------------- diff --git a/tests/contrib/operators/test_sftp_operator.py b/tests/contrib/operators/test_sftp_operator.py new file mode 100644 index 0000000..3d31414 --- /dev/null +++ b/tests/contrib/operators/test_sftp_operator.py @@ -0,0 +1,158 @@ +# -*- coding: utf-8 -*- +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import os +import unittest +from datetime import datetime + +from airflow import configuration +from airflow import models +from airflow.contrib.operators.sftp_operator import SFTPOperator, SFTPOperation +from airflow.contrib.operators.ssh_operator import SSHOperator +from airflow.models import DAG, TaskInstance +from airflow.settings import Session + +TEST_DAG_ID = 'unit_tests' +DEFAULT_DATE = datetime(2017, 1, 1) + + +def reset(dag_id=TEST_DAG_ID): + session = Session() + tis = session.query(models.TaskInstance).filter_by(dag_id=dag_id) + tis.delete() + session.commit() + session.close() + +reset() + + +class SFTPOperatorTest(unittest.TestCase): + def setUp(self): + configuration.load_test_config() + from airflow.contrib.hooks.ssh_hook import SSHHook + hook = SSHHook(ssh_conn_id='ssh_default') + hook.no_host_key_check = True + args = { + 'owner': 'airflow', + 'start_date': DEFAULT_DATE, + 'provide_context': True + } + dag = DAG(TEST_DAG_ID + 'test_schedule_dag_once', default_args=args) + dag.schedule_interval = '@once' + self.hook = hook + self.dag = dag + self.test_dir = "/tmp" + self.test_local_filename = 'test_local_file' + self.test_remote_filename = 'test_remote_file' + self.test_local_filepath = '{0}/{1}'.format(self.test_dir, + self.test_local_filename) + self.test_remote_filepath = '{0}/{1}'.format(self.test_dir, + self.test_remote_filename) + + def test_file_transfer_put(self): + test_local_file_content = \ + b"This is local file content \n which is multiline " \ + b"continuing....with other character\nanother line here \n this is last line" + # create a test file locally + with open(self.test_local_filepath, 'wb') as f: + f.write(test_local_file_content) + + # put test file to remote + put_test_task = SFTPOperator( + task_id="test_sftp", + ssh_hook=self.hook, + local_filepath=self.test_local_filepath, + remote_filepath=self.test_remote_filepath, + operation=SFTPOperation.PUT, + dag=self.dag + ) + self.assertIsNotNone(put_test_task) + ti2 = 
TaskInstance(task=put_test_task, execution_date=datetime.now()) + ti2.run() + + # check the remote file content + check_file_task = SSHOperator( + task_id="test_check_file", + ssh_hook=self.hook, + command="cat {0}".format(self.test_remote_filepath), + do_xcom_push=True, + dag=self.dag + ) + self.assertIsNotNone(check_file_task) + ti3 = TaskInstance(task=check_file_task, execution_date=datetime.now()) + ti3.run() + self.assertEqual( + ti3.xcom_pull(task_ids='test_check_file', key='return_value').strip(), + test_local_file_content) + + def test_file_transfer_get(self): + test_remote_file_content = \ + "This is remote file content \n which is also multiline " \ + "another line here \n this is last line. EOF" + + # create a test file remotely + create_file_task = SSHOperator( + task_id="test_create_file", + ssh_hook=self.hook, + command="echo '{0}' > {1}".format(test_remote_file_content, + self.test_remote_filepath), + do_xcom_push=True, + dag=self.dag + ) + self.assertIsNotNone(create_file_task) + ti1 = TaskInstance(task=create_file_task, execution_date=datetime.now()) + ti1.run() + + # get remote file to local + get_test_task = SFTPOperator( + task_id="test_sftp", + ssh_hook=self.hook, + local_filepath=self.test_local_filepath, + remote_filepath=self.test_remote_filepath, + operation=SFTPOperation.GET, + dag=self.dag + ) + self.assertIsNotNone(get_test_task) + ti2 = TaskInstance(task=get_test_task, execution_date=datetime.now()) + ti2.run() + + # test the received content + content_received = None + with open(self.test_local_filepath, 'r') as f: + content_received = f.read() + self.assertEqual(content_received.strip(), test_remote_file_content) + + def delete_local_resource(self): + if os.path.exists(self.test_local_filepath): + os.remove(self.test_local_filepath) + + def delete_remote_resource(self): + # check the remote file content + remove_file_task = SSHOperator( + task_id="test_check_file", + ssh_hook=self.hook, + command="rm 
{0}".format(self.test_remote_filepath), + do_xcom_push=True, + dag=self.dag + ) + self.assertIsNotNone(remove_file_task) + ti3 = TaskInstance(task=remove_file_task, execution_date=datetime.now()) + ti3.run() + + def tearDown(self): + self.delete_local_resource() and self.delete_remote_resource() + + +if __name__ == '__main__': + unittest.main() http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/fe0edeaa/tests/contrib/operators/test_ssh_execute_operator.py ---------------------------------------------------------------------- diff --git a/tests/contrib/operators/test_ssh_execute_operator.py b/tests/contrib/operators/test_ssh_execute_operator.py deleted file mode 100644 index 0c2b9f2..0000000 --- a/tests/contrib/operators/test_ssh_execute_operator.py +++ /dev/null @@ -1,95 +0,0 @@ -# -*- coding: utf-8 -*- -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import unittest -import os -import sys -from datetime import datetime -from io import StringIO - -import mock - -from airflow import configuration -from airflow.settings import Session -from airflow import models, DAG -from airflow.contrib.operators.ssh_execute_operator import SSHExecuteOperator - - -TEST_DAG_ID = 'unit_tests' -DEFAULT_DATE = datetime(2015, 1, 1) -configuration.load_test_config() - - -def reset(dag_id=TEST_DAG_ID): - session = Session() - tis = session.query(models.TaskInstance).filter_by(dag_id=dag_id) - tis.delete() - session.commit() - session.close() - -reset() - - -class SSHExecuteOperatorTest(unittest.TestCase): - - def setUp(self): - - if sys.version_info[0] == 3: - raise unittest.SkipTest('SSHExecuteOperatorTest won\'t work with ' - 'python3. No need to test anything here') - - configuration.load_test_config() - from airflow.contrib.hooks.ssh_hook import SSHHook - hook = mock.MagicMock(spec=SSHHook) - hook.no_host_key_check = True - hook.Popen.return_value.stdout = StringIO(u'stdout') - hook.Popen.return_value.returncode = False - args = { - 'owner': 'airflow', - 'start_date': DEFAULT_DATE, - 'provide_context': True - } - dag = DAG(TEST_DAG_ID+'test_schedule_dag_once', default_args=args) - dag.schedule_interval = '@once' - self.hook = hook - self.dag = dag - - @mock.patch('airflow.contrib.operators.ssh_execute_operator.SSHTempFileContent') - def test_simple(self, temp_file): - temp_file.return_value.__enter__ = lambda x: 'filepath' - task = SSHExecuteOperator( - task_id="test", - bash_command="echo airflow", - ssh_hook=self.hook, - dag=self.dag, - ) - task.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True) - - @mock.patch('airflow.contrib.operators.ssh_execute_operator.SSHTempFileContent') - def test_with_env(self, temp_file): - temp_file.return_value.__enter__ = lambda x: 'filepath' - test_env = os.environ.copy() - test_env['AIRFLOW_test'] = "test" - task = SSHExecuteOperator( - task_id="test", - 
bash_command="echo $AIRFLOW_HOME", - ssh_hook=self.hook, - env=test_env['AIRFLOW_test'], - dag=self.dag, - ) - task.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True) - - -if __name__ == '__main__': - unittest.main() http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/fe0edeaa/tests/contrib/operators/test_ssh_operator.py ---------------------------------------------------------------------- diff --git a/tests/contrib/operators/test_ssh_operator.py b/tests/contrib/operators/test_ssh_operator.py new file mode 100644 index 0000000..21433d3 --- /dev/null +++ b/tests/contrib/operators/test_ssh_operator.py @@ -0,0 +1,89 @@ +# -*- coding: utf-8 -*- +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import unittest +from datetime import datetime + +from airflow import configuration +from airflow import models +from airflow.contrib.operators.ssh_operator import SSHOperator +from airflow.models import DAG, TaskInstance +from airflow.settings import Session + +TEST_DAG_ID = 'unit_tests' +DEFAULT_DATE = datetime(2017, 1, 1) + + +def reset(dag_id=TEST_DAG_ID): + session = Session() + tis = session.query(models.TaskInstance).filter_by(dag_id=dag_id) + tis.delete() + session.commit() + session.close() + +reset() + + +class SSHOperatorTest(unittest.TestCase): + def setUp(self): + configuration.load_test_config() + from airflow.contrib.hooks.ssh_hook import SSHHook + hook = SSHHook(ssh_conn_id='ssh_default') + hook.no_host_key_check = True + args = { + 'owner': 'airflow', + 'start_date': DEFAULT_DATE, + 'provide_context': True + } + dag = DAG(TEST_DAG_ID + 'test_schedule_dag_once', default_args=args) + dag.schedule_interval = '@once' + self.hook = hook + self.dag = dag + + def test_command_execution(self): + task = SSHOperator( + task_id="test", + ssh_hook=self.hook, + command="echo -n airflow", + do_xcom_push=True, + dag=self.dag, + ) + + self.assertIsNotNone(task) + + ti = TaskInstance( + task=task, execution_date=datetime.now()) + ti.run() + self.assertIsNotNone(ti.duration) + self.assertEqual(ti.xcom_pull(task_ids='test', key='return_value'), b'airflow') + + def test_command_execution_with_env(self): + task = SSHOperator( + task_id="test", + ssh_hook=self.hook, + command="echo -n airflow", + do_xcom_push=True, + dag=self.dag, + ) + + self.assertIsNotNone(task) + + ti = TaskInstance( + task=task, execution_date=datetime.now()) + ti.run() + self.assertIsNotNone(ti.duration) + self.assertEqual(ti.xcom_pull(task_ids='test', key='return_value'), b'airflow') + +if __name__ == '__main__': + unittest.main() http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/fe0edeaa/tests/core.py ---------------------------------------------------------------------- diff 
--git a/tests/core.py b/tests/core.py index 259b61d..923e0c3 100644 --- a/tests/core.py +++ b/tests/core.py @@ -2381,58 +2381,6 @@ class S3HookTest(unittest.TestCase): "Incorrect parsing of the s3 url") -HELLO_SERVER_CMD = """ -import socket, sys -listener = socket.socket() -listener.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) -listener.bind(('localhost', 2134)) -listener.listen(1) -sys.stdout.write('ready') -sys.stdout.flush() -conn = listener.accept()[0] -conn.sendall(b'hello') -""" - - -class SSHHookTest(unittest.TestCase): - def setUp(self): - configuration.load_test_config() - from airflow.contrib.hooks.ssh_hook import SSHHook - self.hook = SSHHook() - self.hook.no_host_key_check = True - - def test_remote_cmd(self): - output = self.hook.check_output(["echo", "-n", "airflow"]) - self.assertEqual(output, b"airflow") - - def test_tunnel(self): - print("Setting up remote listener") - import subprocess - import socket - - self.handle = self.hook.Popen([ - "python", "-c", '"{0}"'.format(HELLO_SERVER_CMD) - ], stdout=subprocess.PIPE) - - print("Setting up tunnel") - with self.hook.tunnel(2135, 2134): - print("Tunnel up") - server_output = self.handle.stdout.read(5) - self.assertEqual(server_output, b"ready") - print("Connecting to server via tunnel") - s = socket.socket() - s.connect(("localhost", 2135)) - print("Receiving...", ) - response = s.recv(5) - self.assertEqual(response, b"hello") - print("Closing connection") - s.close() - print("Waiting for listener...") - output, _ = self.handle.communicate() - self.assertEqual(self.handle.returncode, 0) - print("Closing tunnel") - - send_email_test = mock.Mock()
