Repository: incubator-airflow
Updated Branches:
  refs/heads/master 216beacd5 -> 53b89b983


[AIRFLOW-1978] Add support for additional WinRM parameters

Implemented all of the WinRM options from pywinrm.
Implemented support for streaming stdout/stderr.

Closes #3512 from jshvrsn/winrm


Project: http://git-wip-us.apache.org/repos/asf/incubator-airflow/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-airflow/commit/53b89b98
Tree: http://git-wip-us.apache.org/repos/asf/incubator-airflow/tree/53b89b98
Diff: http://git-wip-us.apache.org/repos/asf/incubator-airflow/diff/53b89b98

Branch: refs/heads/master
Commit: 53b89b98371c7bb993b242c341d3941e9ce09f9a
Parents: 216beac
Author: Joshua Iverson <[email protected]>
Authored: Mon Jun 25 19:42:38 2018 +0100
Committer: Kaxil Naik <[email protected]>
Committed: Mon Jun 25 19:42:38 2018 +0100

----------------------------------------------------------------------
 airflow/contrib/hooks/winrm_hook.py         | 284 ++++++++++++++++-------
 airflow/contrib/operators/winrm_operator.py | 107 ++++++---
 2 files changed, 272 insertions(+), 119 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/53b89b98/airflow/contrib/hooks/winrm_hook.py
----------------------------------------------------------------------
diff --git a/airflow/contrib/hooks/winrm_hook.py b/airflow/contrib/hooks/winrm_hook.py
index 0be904b..1dd02f9 100644
--- a/airflow/contrib/hooks/winrm_hook.py
+++ b/airflow/contrib/hooks/winrm_hook.py
@@ -17,121 +17,235 @@
 # specific language governing permissions and limitations
 # under the License.
 #
+
 import getpass
+
 from winrm.protocol import Protocol
+
 from airflow.exceptions import AirflowException
 from airflow.hooks.base_hook import BaseHook
 from airflow.utils.log.logging_mixin import LoggingMixin
 
 
 class WinRMHook(BaseHook, LoggingMixin):
-
     """
     Hook for winrm remote execution using pywinrm.
 
+    :seealso: https://github.com/diyan/pywinrm/blob/master/winrm/protocol.py
+
     :param ssh_conn_id: connection id from airflow Connections from where all
-        the required parameters can be fetched like username, password or 
key_file.
+        the required parameters can be fetched like username and password.
         Thought the priority is given to the param passed during init
-    :type ssh_conn_id: str
-    :param remote_host: remote host to connect
-    :type remote_host: str
+    :type ssh_conn_id: string
+    :param endpoint: When set to `None`, endpoint will be constructed like 
this:
+        'http://{remote_host}:{remote_port}/wsman'
+    :type endpoint: string
+    :param remote_host: Remote host to connect to.
+        Ignored if `endpoint` is not `None`.
+    :type remote_host: string
+    :param remote_port: Remote port to connect to.
+        Ignored if `endpoint` is not `None`.
+    :type remote_port: int
+    :param transport: transport type, one of 'plaintext' (default), 
'kerberos', 'ssl',
+        'ntlm', 'credssp'
+    :type transport: string
     :param username: username to connect to the remote_host
-    :type username: str
+    :type username: string
     :param password: password of the username to connect to the remote_host
-    :type password: str
-    :param key_file: key file to use to connect to the remote_host.
-    :type key_file: str
-    :param timeout: timeout for the attempt to connect to the remote_host.
-    :type timeout: int
-    :param keepalive_interval: send a keepalive packet to remote host
-        every keepalive_interval seconds
-    :type keepalive_interval: int
+    :type password: string
+    :param service: the service name, default is HTTP
+    :type service: string
+    :param keytab: the path to a keytab file if you are using one
+    :type keytab: string
+    :param ca_trust_path: Certification Authority trust path
+    :type ca_trust_path: string
+    :param cert_pem: client authentication certificate file path in PEM format
+    :type cert_pem: string
+    :param cert_key_pem: client authentication certificate key file path in 
PEM format
+    :type cert_key_pem: string
+    :param server_cert_validation: whether server certificate should be 
validated on
+        Python versions that suppport it; one of 'validate' (default), 'ignore'
+    :type server_cert_validation: string
+    :param kerberos_delegation: if True, TGT is sent to target server to
+        allow multiple hops
+    :type kerberos_delegation: bool
+    :param read_timeout_sec: maximum seconds to wait before an HTTP 
connect/read times out
+        (default 30). This value should be slightly higher than 
operation_timeout_sec,
+        as the server can block *at least* that long.
+    :type read_timeout_sec: int
+    :param operation_timeout_sec: maximum allowed time in seconds for any 
single wsman
+        HTTP operation (default 20). Note that operation timeouts while 
receiving output
+        (the only wsman operation that should take any significant time, and 
where these
+        timeouts are expected) will be silently retried indefinitely.
+    :type operation_timeout_sec: int
+    :param kerberos_hostname_override: the hostname to use for the kerberos 
exchange
+        (defaults to the hostname in the endpoint URL)
+    :type kerberos_hostname_override: string
+    :param message_encryption_enabled: Will encrypt the WinRM messages if set 
to True and
+        the transport auth supports message encryption (Default True).
+    :type message_encryption_enabled: bool
+    :param credssp_disable_tlsv1_2: Whether to disable TLSv1.2 support and 
work with older
+        protocols like TLSv1.0, default is False
+    :type credssp_disable_tlsv1_2: bool
+    :param send_cbt: Will send the channel bindings over a HTTPS channel 
(Default: True)
+    :type send_cbt: bool
     """
-
     def __init__(self,
                  ssh_conn_id=None,
+                 endpoint=None,
                  remote_host=None,
+                 remote_port=5985,
+                 transport='plaintext',
                  username=None,
                  password=None,
-                 key_file=None,
-                 timeout=10,
-                 keepalive_interval=30
+                 service='HTTP',
+                 keytab=None,
+                 ca_trust_path=None,
+                 cert_pem=None,
+                 cert_key_pem=None,
+                 server_cert_validation='validate',
+                 kerberos_delegation=False,
+                 read_timeout_sec=30,
+                 operation_timeout_sec=20,
+                 kerberos_hostname_override=None,
+                 message_encryption='auto',
+                 credssp_disable_tlsv1_2=False,
+                 send_cbt=True,
                  ):
         super(WinRMHook, self).__init__(ssh_conn_id)
-        # TODO make new win rm connection class
         self.ssh_conn_id = ssh_conn_id
+        self.endpoint = endpoint
         self.remote_host = remote_host
+        self.remote_port = remote_port
+        self.transport = transport
         self.username = username
         self.password = password
-        self.key_file = key_file
-        self.timeout = timeout
-        self.keepalive_interval = keepalive_interval
-        # Default values, overridable from Connection
-        self.compress = True
-        self.no_host_key_check = True
+        self.service = service
+        self.keytab = keytab
+        self.ca_trust_path = ca_trust_path
+        self.cert_pem = cert_pem
+        self.cert_key_pem = cert_key_pem
+        self.server_cert_validation = server_cert_validation
+        self.kerberos_delegation = kerberos_delegation
+        self.read_timeout_sec = read_timeout_sec
+        self.operation_timeout_sec = operation_timeout_sec
+        self.kerberos_hostname_override = kerberos_hostname_override
+        self.message_encryption = message_encryption
+        self.credssp_disable_tlsv1_2 = credssp_disable_tlsv1_2
+        self.send_cbt = send_cbt
+
         self.client = None
         self.winrm_protocol = None
 
     def get_conn(self):
-        if not self.client:
-            self.log.debug('Creating WinRM client for conn_id: %s', 
self.ssh_conn_id)
-            if self.ssh_conn_id is not None:
-                conn = self.get_connection(self.ssh_conn_id)
-                if self.username is None:
-                    self.username = conn.login
-                if self.password is None:
-                    self.password = conn.password
-                if self.remote_host is None:
-                    self.remote_host = conn.host
-                if conn.extra is not None:
-                    extra_options = conn.extra_dejson
-                    self.key_file = extra_options.get("key_file")
-
-                    if "timeout" in extra_options:
-                        self.timeout = int(extra_options["timeout"], 10)
-
-                    if "compress" in extra_options \
-                            and extra_options["compress"].lower() == 'false':
-                        self.compress = False
-                    if "no_host_key_check" in extra_options \
-                            and extra_options["no_host_key_check"].lower() == 
'false':
-                        self.no_host_key_check = False
-
-            if not self.remote_host:
-                raise AirflowException("Missing required param: remote_host")
-
-            # Auto detecting username values from system
-            if not self.username:
-                self.log.debug(
-                    "username to ssh to host: %s is not specified for 
connection id"
-                    " %s. Using system's default provided by 
getpass.getuser()",
-                    self.remote_host, self.ssh_conn_id
-                )
-                self.username = getpass.getuser()
-
-            try:
-
-                if self.password and self.password.strip():
-                    self.winrm_protocol = Protocol(
-                        # TODO pass in port from ssh conn
-                        endpoint='http://' + self.remote_host + ':5985/wsman',
-                        # TODO get cert transport working
-                        # transport='certificate',
-                        transport='plaintext',
-                        # cert_pem=r'publickey.pem',
-                        # cert_key_pem=r'dev.pem',
-                        read_timeout_sec=70,
-                        operation_timeout_sec=60,
-                        username=self.username,
-                        password=self.password,
-                        server_cert_validation='ignore')
-
-                self.log.info("Opening WinRM shell")
-                self.client = self.winrm_protocol.open_shell()
-
-            except Exception as error:
-                self.log.error(
-                    "Error connecting to host: %s, error: %s",
-                    self.remote_host, error
+        if self.client:
+            return self.client
+
+        self.log.debug('Creating WinRM client for conn_id: %s', 
self.ssh_conn_id)
+        if self.ssh_conn_id is not None:
+            conn = self.get_connection(self.ssh_conn_id)
+
+            if self.username is None:
+                self.username = conn.login
+            if self.password is None:
+                self.password = conn.password
+            if self.remote_host is None:
+                self.remote_host = conn.host
+
+            if conn.extra is not None:
+                extra_options = conn.extra_dejson
+
+                if "endpoint" in extra_options:
+                    self.endpoint = str(extra_options["endpoint"])
+                if "remote_port" in extra_options:
+                    self.remote_port = int(extra_options["remote_port"])
+                if "transport" in extra_options:
+                    self.transport = str(extra_options["transport"])
+                if "service" in extra_options:
+                    self.service = str(extra_options["service"])
+                if "keytab" in extra_options:
+                    self.keytab = str(extra_options["keytab"])
+                if "ca_trust_path" in extra_options:
+                    self.ca_trust_path = str(extra_options["ca_trust_path"])
+                if "cert_pem" in extra_options:
+                    self.cert_pem = str(extra_options["cert_pem"])
+                if "cert_key_pem" in extra_options:
+                    self.cert_key_pem = str(extra_options["cert_key_pem"])
+                if "server_cert_validation" in extra_options:
+                    self.server_cert_validation = \
+                        str(extra_options["server_cert_validation"])
+                if "kerberos_delegation" in extra_options:
+                    self.kerberos_delegation = \
+                        str(extra_options["kerberos_delegation"]).lower() == 
'true'
+                if "read_timeout_sec" in extra_options:
+                    self.read_timeout_sec = 
int(extra_options["read_timeout_sec"])
+                if "operation_timeout_sec" in extra_options:
+                    self.operation_timeout_sec = \
+                        int(extra_options["operation_timeout_sec"])
+                if "kerberos_hostname_override" in extra_options:
+                    self.kerberos_hostname_override = \
+                        str(extra_options["kerberos_hostname_override"])
+                if "message_encryption" in extra_options:
+                    self.message_encryption = 
str(extra_options["message_encryption"])
+                if "credssp_disable_tlsv1_2" in extra_options:
+                    self.credssp_disable_tlsv1_2 = \
+                        str(extra_options["credssp_disable_tlsv1_2"]).lower() 
== 'true'
+                if "send_cbt" in extra_options:
+                    self.send_cbt = str(extra_options["send_cbt"]).lower() == 
'true'
+
+        if not self.remote_host:
+            raise AirflowException("Missing required param: remote_host")
+
+        # Auto detecting username values from system
+        if not self.username:
+            self.log.debug(
+                "username to WinRM to host: %s is not specified for connection 
id"
+                " %s. Using system's default provided by getpass.getuser()",
+                self.remote_host, self.ssh_conn_id
+            )
+            self.username = getpass.getuser()
+
+        # If endpoint is not set, then build a standard wsman endpoint from 
host and port.
+        if not self.endpoint:
+            self.endpoint = 'http://{0}:{1}/wsman'.format(
+                self.remote_host,
+                self.remote_port
+            )
+
+        try:
+            if self.password and self.password.strip():
+                self.winrm_protocol = Protocol(
+                    endpoint=self.endpoint,
+                    transport=self.transport,
+                    username=self.username,
+                    password=self.password,
+                    service=self.service,
+                    keytab=self.keytab,
+                    ca_trust_path=self.ca_trust_path,
+                    cert_pem=self.cert_pem,
+                    cert_key_pem=self.cert_key_pem,
+                    server_cert_validation=self.server_cert_validation,
+                    kerberos_delegation=self.kerberos_delegation,
+                    read_timeout_sec=self.read_timeout_sec,
+                    operation_timeout_sec=self.operation_timeout_sec,
+                    kerberos_hostname_override=self.kerberos_hostname_override,
+                    message_encryption=self.message_encryption,
+                    credssp_disable_tlsv1_2=self.credssp_disable_tlsv1_2,
+                    send_cbt=self.send_cbt
                 )
+
+            self.log.info(
+                "Establishing WinRM connection to host: %s",
+                self.remote_host
+            )
+            self.client = self.winrm_protocol.open_shell()
+
+        except Exception as error:
+            error_msg = "Error connecting to host: {0}, error: {1}".format(
+                self.remote_host,
+                error
+            )
+            self.log.error(error_msg)
+            raise AirflowException(error_msg)
+
         return self.client

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/53b89b98/airflow/contrib/operators/winrm_operator.py
----------------------------------------------------------------------
diff --git a/airflow/contrib/operators/winrm_operator.py b/airflow/contrib/operators/winrm_operator.py
index fcd2328..c81acac 100644
--- a/airflow/contrib/operators/winrm_operator.py
+++ b/airflow/contrib/operators/winrm_operator.py
@@ -17,14 +17,24 @@
 # specific language governing permissions and limitations
 # under the License.
 
+from base64 import b64encode
+import logging
+
+from winrm.exceptions import WinRMOperationTimeoutError
+
+from airflow import configuration
 from airflow.contrib.hooks.winrm_hook import WinRMHook
 from airflow.exceptions import AirflowException
 from airflow.models import BaseOperator
 from airflow.utils.decorators import apply_defaults
 
+# Hide the following error message in urllib3 when making WinRM connections:
+# requests.packages.urllib3.exceptions.HeaderParsingError: 
[StartBoundaryNotFoundDefect(),
+#   MultipartInvariantViolationDefect()], unparsed data: ''
+logging.getLogger('requests.packages.urllib3.connectionpool').setLevel(logging.CRITICAL)
 
-class WinRMOperator(BaseOperator):
 
+class WinRMOperator(BaseOperator):
     """
     WinRMOperator to execute commands on given remote host using the 
winrm_hook.
 
@@ -41,7 +51,6 @@ class WinRMOperator(BaseOperator):
     :param do_xcom_push: return the stdout which also get set in xcom by 
airflow platform
     :type do_xcom_push: bool
     """
-
     template_fields = ('command',)
 
     @apply_defaults
@@ -63,48 +72,78 @@ class WinRMOperator(BaseOperator):
         self.do_xcom_push = do_xcom_push
 
     def execute(self, context):
-        try:
-            if self.ssh_conn_id and not self.winrm_hook:
-                self.log.info("hook not found, creating")
-                self.winrm_hook = WinRMHook(ssh_conn_id=self.ssh_conn_id)
+        if self.ssh_conn_id and not self.winrm_hook:
+            self.log.info("Hook not found, creating...")
+            self.winrm_hook = WinRMHook(ssh_conn_id=self.ssh_conn_id)
 
-            if not self.winrm_hook:
-                raise AirflowException("can not operate without ssh_hook or 
ssh_conn_id")
+        if not self.winrm_hook:
+            raise AirflowException("Cannot operate without winrm_hook or 
ssh_conn_id.")
 
-            if self.remote_host is not None:
-                self.winrm_hook.remote_host = self.remote_host
+        if self.remote_host is not None:
+            self.winrm_hook.remote_host = self.remote_host
 
-            winrm_client = self.winrm_hook.get_conn()
-            self.log.info("Established WinRM connection")
+        if not self.command:
+            raise AirflowException("No command specified so nothing to execute 
here.")
 
-            if not self.command:
-                raise AirflowException("no command specified so nothing to 
execute here.")
+        winrm_client = self.winrm_hook.get_conn()
 
-            self.log.info(
-                "Starting command: '{command}' on remote host: {remotehost}".
-                format(command=self.command, 
remotehost=self.winrm_hook.remote_host)
+        try:
+            self.log.info("Running command: 
'{command}'...".format(command=self.command))
+            command_id = self.winrm_hook.winrm_protocol.run_command(
+                winrm_client,
+                self.command
             )
-            command_id = self.winrm_hook.winrm_protocol. \
-                run_command(winrm_client, self.command)
-            std_out, std_err, status_code = self.winrm_hook.winrm_protocol. \
-                get_command_output(winrm_client, command_id)
-
-            self.log.info("std out: " + std_out.decode())
-            self.log.info("std err: " + std_err.decode())
-            self.log.info("exit code: " + str(status_code))
-            self.log.info("Cleaning up WinRM command")
+
+            # See: 
https://github.com/diyan/pywinrm/blob/master/winrm/protocol.py
+            stdout_buffer = []
+            stderr_buffer = []
+            command_done = False
+            while not command_done:
+                try:
+                    stdout, stderr, return_code, command_done = \
+                        self.winrm_hook.winrm_protocol._raw_get_command_output(
+                            winrm_client,
+                            command_id
+                        )
+
+                    # Only buffer stdout if we need to so that we minimize 
memory usage.
+                    if self.do_xcom_push:
+                        stdout_buffer.append(stdout)
+                    stderr_buffer.append(stderr)
+
+                    for line in stdout.decode('utf-8').splitlines():
+                        self.log.info(line)
+                    for line in stderr.decode('utf-8').splitlines():
+                        self.log.warning(line)
+                except WinRMOperationTimeoutError as e:
+                    # this is an expected error when waiting for a
+                    # long-running process, just silently retry
+                    pass
+
             self.winrm_hook.winrm_protocol.cleanup_command(winrm_client, 
command_id)
-            self.log.info("Cleaning up WinRM protocol shell")
             self.winrm_hook.winrm_protocol.close_shell(winrm_client)
-            if status_code is 0:
-                return std_out.decode()
-
-            else:
-                error_msg = std_err.decode()
-                raise AirflowException("error running cmd: {0}, error: {1}"
-                                       .format(self.command, error_msg))
 
         except Exception as e:
             raise AirflowException("WinRM operator error: {0}".format(str(e)))
 
+        if return_code is 0:
+            # returning output if do_xcom_push is set
+            if self.do_xcom_push:
+                enable_pickling = configuration.conf.getboolean(
+                    'core', 'enable_xcom_pickling'
+                )
+                if enable_pickling:
+                    return stdout_buffer
+                else:
+                    return b64encode(b''.join(stdout_buffer)).decode('utf-8')
+        else:
+            error_msg = "Error running cmd: {0}, return code: {1}, error: 
{2}".format(
+                self.command,
+                return_code,
+                b''.join(stderr_buffer).decode('utf-8')
+            )
+            raise AirflowException(error_msg)
+
+        self.log.info("Finished!")
+
         return True

Reply via email to