feng-tao closed pull request #3793: [AIRFLOW-2948] Arg check & better doc - SSHOperator & SFTPOperator
URL: https://github.com/apache/incubator-airflow/pull/3793
 
 
   

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:

diff --git a/airflow/contrib/operators/sftp_operator.py b/airflow/contrib/operators/sftp_operator.py
index 3c736c8b95..a3b5c1f244 100644
--- a/airflow/contrib/operators/sftp_operator.py
+++ b/airflow/contrib/operators/sftp_operator.py
@@ -33,11 +33,15 @@ class SFTPOperator(BaseOperator):
     This operator uses ssh_hook to open sftp transport channel that serves as basis
     for file transfer.
 
-    :param ssh_hook: predefined ssh_hook to use for remote execution
+    :param ssh_hook: predefined ssh_hook to use for remote execution.
+        Either `ssh_hook` or `ssh_conn_id` needs to be provided.
     :type ssh_hook: :class:`SSHHook`
-    :param ssh_conn_id: connection id from airflow Connections
+    :param ssh_conn_id: connection id from airflow Connections.
+        `ssh_conn_id` will be ignored if `ssh_hook` is provided.
     :type ssh_conn_id: str
     :param remote_host: remote host to connect (templated)
+        Nullable. If provided, it will replace the `remote_host` which was
+        defined in `ssh_hook` or predefined in the connection of `ssh_conn_id`.
     :type remote_host: str
     :param local_filepath: local file path to get or put. (templated)
     :type local_filepath: str
@@ -77,13 +81,21 @@ def __init__(self,
     def execute(self, context):
         file_msg = None
         try:
-            if self.ssh_conn_id and not self.ssh_hook:
-                self.ssh_hook = SSHHook(ssh_conn_id=self.ssh_conn_id)
+            if self.ssh_conn_id:
+                if self.ssh_hook and isinstance(self.ssh_hook, SSHHook):
+                    self.log.info("ssh_conn_id is ignored when ssh_hook is provided.")
+                else:
+                    self.log.info("ssh_hook is not provided or invalid. " +
+                                  "Trying ssh_conn_id to create SSHHook.")
+                    self.ssh_hook = SSHHook(ssh_conn_id=self.ssh_conn_id)
 
             if not self.ssh_hook:
-                raise AirflowException("can not operate without ssh_hook or ssh_conn_id")
+                raise AirflowException("Cannot operate without ssh_hook or ssh_conn_id.")
 
             if self.remote_host is not None:
+                self.log.info("remote_host is provided explicitly. " +
+                              "It will replace the remote_host which was defined " +
+                              "in ssh_hook or predefined in connection of ssh_conn_id.")
                 self.ssh_hook.remote_host = self.remote_host
 
             with self.ssh_hook.get_conn() as ssh_client:
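
For context, a minimal usage sketch of the precedence described in the updated SFTPOperator docstring. The DAG id, connection id, and file paths below are illustrative assumptions rather than part of this change, and an initialised Airflow environment with an `ssh_default` connection is assumed.

    # Illustrative sketch only; the DAG id, connection ids, and paths are assumptions.
    from datetime import datetime

    from airflow import DAG
    from airflow.contrib.hooks.ssh_hook import SSHHook
    from airflow.contrib.operators.sftp_operator import SFTPOperation, SFTPOperator

    dag = DAG("sftp_precedence_example", start_date=datetime(2018, 1, 1))

    # Only ssh_conn_id is given: the operator builds an SSHHook from that connection.
    put_via_conn_id = SFTPOperator(
        task_id="put_via_conn_id",
        ssh_conn_id="ssh_default",
        local_filepath="/tmp/file_to_send.txt",
        remote_filepath="/tmp/file_to_send.txt",
        operation=SFTPOperation.PUT,
        dag=dag,
    )

    # Both are given: the predefined hook wins and ssh_conn_id is ignored (and logged).
    hook = SSHHook(ssh_conn_id="ssh_default")
    put_via_hook = SFTPOperator(
        task_id="put_via_hook",
        ssh_hook=hook,
        ssh_conn_id="some_other_conn_id",  # ignored, per the docstring above
        local_filepath="/tmp/file_to_send.txt",
        remote_filepath="/tmp/file_to_send.txt",
        operation=SFTPOperation.PUT,
        dag=dag,
    )
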
diff --git a/airflow/contrib/operators/ssh_operator.py b/airflow/contrib/operators/ssh_operator.py
index c0e8953d2c..2bf342935d 100644
--- a/airflow/contrib/operators/ssh_operator.py
+++ b/airflow/contrib/operators/ssh_operator.py
@@ -31,11 +31,15 @@ class SSHOperator(BaseOperator):
     """
     SSHOperator to execute commands on given remote host using the ssh_hook.
 
-    :param ssh_hook: predefined ssh_hook to use for remote execution
+    :param ssh_hook: predefined ssh_hook to use for remote execution.
+        Either `ssh_hook` or `ssh_conn_id` needs to be provided.
     :type ssh_hook: :class:`SSHHook`
-    :param ssh_conn_id: connection id from airflow Connections
+    :param ssh_conn_id: connection id from airflow Connections.
+        `ssh_conn_id` will be ignored if `ssh_hook` is provided.
     :type ssh_conn_id: str
     :param remote_host: remote host to connect (templated)
+        Nullable. If provided, it will replace the `remote_host` which was
+        defined in `ssh_hook` or predefined in the connection of `ssh_conn_id`.
     :type remote_host: str
     :param command: command to execute on remote host. (templated)
     :type command: str
@@ -68,14 +72,22 @@ def __init__(self,
 
     def execute(self, context):
         try:
-            if self.ssh_conn_id and not self.ssh_hook:
-                self.ssh_hook = SSHHook(ssh_conn_id=self.ssh_conn_id,
-                                        timeout=self.timeout)
+            if self.ssh_conn_id:
+                if self.ssh_hook and isinstance(self.ssh_hook, SSHHook):
+                    self.log.info("ssh_conn_id is ignored when ssh_hook is provided.")
+                else:
+                    self.log.info("ssh_hook is not provided or invalid. " +
+                                  "Trying ssh_conn_id to create SSHHook.")
+                    self.ssh_hook = SSHHook(ssh_conn_id=self.ssh_conn_id,
+                                            timeout=self.timeout)
 
             if not self.ssh_hook:
                 raise AirflowException("Cannot operate without ssh_hook or 
ssh_conn_id.")
 
             if self.remote_host is not None:
+                self.log.info("remote_host is provided explicitly. " +
+                              "It will replace the remote_host which was defined " +
+                              "in ssh_hook or predefined in connection of ssh_conn_id.")
                 self.ssh_hook.remote_host = self.remote_host
 
             if not self.command:
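
Similarly, a minimal sketch of how the SSHOperator arguments interact after this change. The DAG id, connection id, host name, and command are illustrative assumptions; an initialised Airflow environment with an `ssh_default` connection is assumed.

    # Illustrative sketch only; names, the host, and the command are assumptions.
    from datetime import datetime

    from airflow import DAG
    from airflow.contrib.operators.ssh_operator import SSHOperator

    dag = DAG("ssh_precedence_example", start_date=datetime(2018, 1, 1))

    # ssh_conn_id alone is enough; the operator creates the SSHHook itself at execute time.
    run_uptime = SSHOperator(
        task_id="run_uptime",
        ssh_conn_id="ssh_default",
        command="uptime",
        timeout=10,
        dag=dag,
    )

    # An explicit remote_host replaces whatever host the hook or connection defines.
    run_uptime_elsewhere = SSHOperator(
        task_id="run_uptime_elsewhere",
        ssh_conn_id="ssh_default",
        remote_host="another.host.example.com",  # overrides the connection's host
        command="uptime",
        timeout=10,
        dag=dag,
    )
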
diff --git a/tests/contrib/operators/test_sftp_operator.py b/tests/contrib/operators/test_sftp_operator.py
index 01446a6fdd..5770c1b940 100644
--- a/tests/contrib/operators/test_sftp_operator.py
+++ b/tests/contrib/operators/test_sftp_operator.py
@@ -20,6 +20,7 @@
 import os
 import unittest
 from base64 import b64encode
+import six
 
 from airflow import configuration
 from airflow import models
@@ -219,6 +220,71 @@ def test_json_file_transfer_get(self):
         self.assertEqual(content_received.strip(),
             test_remote_file_content.encode('utf-8').decode('utf-8'))
 
+    def test_arg_checking(self):
+        from airflow.exceptions import AirflowException
+        conn_id = "conn_id_for_testing"
+        os.environ['AIRFLOW_CONN_' + conn_id.upper()] = "ssh://test_id@localhost"
+
+        # Exception should be raised if neither ssh_hook nor ssh_conn_id is provided
+        if six.PY2:
+            self.assertRaisesRegex = self.assertRaisesRegexp
+        with self.assertRaisesRegex(AirflowException,
+                                    "Cannot operate without ssh_hook or ssh_conn_id."):
+            task_0 = SFTPOperator(
+                task_id="test_sftp",
+                local_filepath=self.test_local_filepath,
+                remote_filepath=self.test_remote_filepath,
+                operation=SFTPOperation.PUT,
+                dag=self.dag
+            )
+            task_0.execute(None)
+
+        # if ssh_hook is invalid/not provided, use ssh_conn_id to create SSHHook
+        task_1 = SFTPOperator(
+            task_id="test_sftp",
+            ssh_hook="string_rather_than_SSHHook",  # invalid ssh_hook
+            ssh_conn_id=conn_id,
+            local_filepath=self.test_local_filepath,
+            remote_filepath=self.test_remote_filepath,
+            operation=SFTPOperation.PUT,
+            dag=self.dag
+        )
+        try:
+            task_1.execute(None)
+        except Exception:
+            pass
+        self.assertEqual(task_1.ssh_hook.ssh_conn_id, conn_id)
+
+        task_2 = SFTPOperator(
+            task_id="test_sftp",
+            ssh_conn_id=conn_id,  # no ssh_hook provided
+            local_filepath=self.test_local_filepath,
+            remote_filepath=self.test_remote_filepath,
+            operation=SFTPOperation.PUT,
+            dag=self.dag
+        )
+        try:
+            task_2.execute(None)
+        except Exception:
+            pass
+        self.assertEqual(task_2.ssh_hook.ssh_conn_id, conn_id)
+
+        # if both valid ssh_hook and ssh_conn_id are provided, ignore ssh_conn_id
+        task_3 = SFTPOperator(
+            task_id="test_sftp",
+            ssh_hook=self.hook,
+            ssh_conn_id=conn_id,
+            local_filepath=self.test_local_filepath,
+            remote_filepath=self.test_remote_filepath,
+            operation=SFTPOperation.PUT,
+            dag=self.dag
+        )
+        try:
+            task_3.execute(None)
+        except Exception:
+            pass
+        self.assertEqual(task_3.ssh_hook.ssh_conn_id, self.hook.ssh_conn_id)
+
     def delete_local_resource(self):
         if os.path.exists(self.test_local_filepath):
             os.remove(self.test_local_filepath)
@@ -226,11 +292,11 @@ def delete_local_resource(self):
     def delete_remote_resource(self):
         # check the remote file content
         remove_file_task = SSHOperator(
-                task_id="test_check_file",
-                ssh_hook=self.hook,
-                command="rm {0}".format(self.test_remote_filepath),
-                do_xcom_push=True,
-                dag=self.dag
+            task_id="test_check_file",
+            ssh_hook=self.hook,
+            command="rm {0}".format(self.test_remote_filepath),
+            do_xcom_push=True,
+            dag=self.dag
         )
         self.assertIsNotNone(remove_file_task)
         ti3 = TaskInstance(task=remove_file_task, execution_date=timezone.utcnow())
diff --git a/tests/contrib/operators/test_ssh_operator.py b/tests/contrib/operators/test_ssh_operator.py
index 7ddd24b2ac..1a2c788596 100644
--- a/tests/contrib/operators/test_ssh_operator.py
+++ b/tests/contrib/operators/test_ssh_operator.py
@@ -19,6 +19,7 @@
 
 import unittest
 from base64 import b64encode
+import six
 
 from airflow import configuration
 from airflow import models
@@ -148,6 +149,65 @@ def test_no_output_command(self):
         self.assertIsNotNone(ti.duration)
         self.assertEqual(ti.xcom_pull(task_ids='test', key='return_value'), b'')
 
+    def test_arg_checking(self):
+        import os
+        from airflow.exceptions import AirflowException
+        conn_id = "conn_id_for_testing"
+        TIMEOUT = 5
+        os.environ['AIRFLOW_CONN_' + conn_id.upper()] = "ssh://test_id@localhost"
+
+        # Exception should be raised if neither ssh_hook nor ssh_conn_id is provided
+        if six.PY2:
+            self.assertRaisesRegex = self.assertRaisesRegexp
+        with self.assertRaisesRegex(AirflowException,
+                                    "Cannot operate without ssh_hook or ssh_conn_id."):
+            task_0 = SSHOperator(task_id="test", command="echo -n airflow",
+                                 timeout=TIMEOUT, dag=self.dag)
+            task_0.execute(None)
+
+        # if ssh_hook is invalid/not provided, use ssh_conn_id to create SSHHook
+        task_1 = SSHOperator(
+            task_id="test_1",
+            ssh_hook="string_rather_than_SSHHook",  # invalid ssh_hook
+            ssh_conn_id=conn_id,
+            command="echo -n airflow",
+            timeout=TIMEOUT,
+            dag=self.dag
+        )
+        try:
+            task_1.execute(None)
+        except Exception:
+            pass
+        self.assertEqual(task_1.ssh_hook.ssh_conn_id, conn_id)
+
+        task_2 = SSHOperator(
+            task_id="test_2",
+            ssh_conn_id=conn_id,  # no ssh_hook provided
+            command="echo -n airflow",
+            timeout=TIMEOUT,
+            dag=self.dag
+        )
+        try:
+            task_2.execute(None)
+        except Exception:
+            pass
+        self.assertEqual(task_2.ssh_hook.ssh_conn_id, conn_id)
+
+        # if both valid ssh_hook and ssh_conn_id are provided, ignore ssh_conn_id
+        task_3 = SSHOperator(
+            task_id="test_3",
+            ssh_hook=self.hook,
+            ssh_conn_id=conn_id,
+            command="echo -n airflow",
+            timeout=TIMEOUT,
+            dag=self.dag
+        )
+        try:
+            task_3.execute(None)
+        except Exception:
+            pass
+        self.assertEqual(task_3.ssh_hook.ssh_conn_id, self.hook.ssh_conn_id)
+
 
 if __name__ == '__main__':
     unittest.main()
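
The new tests above inject the SSH connection through an environment variable rather than the metadata database. A standalone sketch of that pattern follows; the connection id and URI are illustrative, and an initialised Airflow environment is assumed.

    # Standalone sketch of the env-var connection pattern used by the tests above.
    import os

    from airflow.contrib.hooks.ssh_hook import SSHHook

    conn_id = "conn_id_for_testing"  # illustrative connection id
    # Airflow resolves connections from AIRFLOW_CONN_<CONN_ID_IN_UPPERCASE>
    # environment variables; the URI encodes connection type, login, and host.
    os.environ["AIRFLOW_CONN_" + conn_id.upper()] = "ssh://test_id@localhost"

    # This hook is what SSHOperator/SFTPOperator build internally when only
    # ssh_conn_id is supplied.
    hook = SSHHook(ssh_conn_id=conn_id)
    print(hook.ssh_conn_id)  # -> conn_id_for_testing
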


 
