jedcunningham commented on a change in pull request #21326:
URL: https://github.com/apache/airflow/pull/21326#discussion_r799635055



##########
File path: airflow/providers/ssh/operators/ssh.py
##########
@@ -18,15 +18,16 @@
 
 import warnings
 from base64 import b64encode
-from select import select
-from typing import Optional, Sequence, Tuple, Union
-
-from paramiko.client import SSHClient
+from typing import TYPE_CHECKING, Optional, Sequence, Union
 
 from airflow.configuration import conf
 from airflow.exceptions import AirflowException
 from airflow.models import BaseOperator
-from airflow.providers.ssh.hooks.ssh import SSHHook
+
+if TYPE_CHECKING:
+    from paramiko.client import SSHClient
+
+    from airflow.providers.ssh.hooks.ssh import SSHHook

Review comment:
       For my own edification, why only import this when type checking?

##########
File path: tests/providers/ssh/operators/test_ssh.py
##########
@@ -42,117 +42,69 @@ class SSHClientSideEffect:
     def __init__(self, hook):
         self.hook = hook
 
-    def __call__(self):
-        self.return_value = self.hook.get_conn()
-        return self.return_value
-
 
 class TestSSHOperator:
     def setup_method(self):
-        from airflow.providers.ssh.hooks.ssh import SSHHook
 
-        hook = SSHHook(ssh_conn_id='ssh_default', banner_timeout=100)
+        hook = SSHHook(ssh_conn_id='ssh_default')
         hook.no_host_key_check = True
-        self.dag = DAG('ssh_test', default_args={'start_date': DEFAULT_DATE})
+
+        ssh_client = mock.create_autospec(SSHClient)
+        # `with ssh_client` should return itself.
+        ssh_client.__enter__.return_value = ssh_client
+        hook.get_conn = mock.MagicMock(return_value=ssh_client)
         self.hook = hook
 
-    def test_hook_created_correctly_with_timeout(self):
-        timeout = 20
-        ssh_id = "ssh_default"
-        with self.dag:
-            task = SSHOperator(
-                task_id="test",
-                command=COMMAND,
-                timeout=timeout,
-                ssh_conn_id="ssh_default",
-                banner_timeout=100,
-            )
-        task.execute(None)
-        assert timeout == task.ssh_hook.conn_timeout
-        assert ssh_id == task.ssh_hook.ssh_conn_id
+    # Make sure nothing in this test actually connects to SSH -- that's for hook tests.
+    @pytest.fixture(autouse=True)
+    def _patch_exec_ssh_client(self):
+        with mock.patch.object(self.hook, 'exec_ssh_client_command') as exec_ssh_client_command:
+            self.exec_ssh_client_command = exec_ssh_client_command
+            exec_ssh_client_command.return_value = (0, b'airflow', '')
+            yield exec_ssh_client_command
 
     def test_hook_created_correctly(self):
         conn_timeout = 20
         cmd_timeout = 45
-        ssh_id = 'ssh_default'
-        with self.dag:
-            task = SSHOperator(
-                task_id="test",
-                command=COMMAND,
-                conn_timeout=conn_timeout,
-                cmd_timeout=cmd_timeout,
-                ssh_conn_id="ssh_default",
-                banner_timeout=100,
-            )
-        task.execute(None)
-        assert conn_timeout == task.ssh_hook.conn_timeout
-        assert ssh_id == task.ssh_hook.ssh_conn_id
-
-    @conf_vars({('core', 'enable_xcom_pickling'): 'False'})
-    def test_json_command_execution(self, create_task_instance_of_operator):
-        ti = create_task_instance_of_operator(
-            SSHOperator,
-            dag_id="unit_tests_ssh_test_op_json_command_execution",
-            task_id="test",
-            ssh_hook=self.hook,
-            command=COMMAND,
-            do_xcom_push=True,
-            banner_timeout=100,
-        )
-        ti.run()
-        assert ti.duration is not None
-        assert ti.xcom_pull(task_ids='test', key='return_value') == b64encode(b'airflow').decode('utf-8')
-
-    @conf_vars({('core', 'enable_xcom_pickling'): 'True'})
-    def test_pickle_command_execution(self, create_task_instance_of_operator):
-        ti = create_task_instance_of_operator(
-            SSHOperator,
-            dag_id="unit_tests_ssh_test_op_pickle_command_execution",
+        task = SSHOperator(
             task_id="test",
-            ssh_hook=self.hook,
             command=COMMAND,
-            do_xcom_push=True,
-            banner_timeout=100,
+            conn_timeout=conn_timeout,
+            cmd_timeout=cmd_timeout,
+            ssh_conn_id="ssh_default",
         )
-        ti.run()
-        assert ti.duration is not None
-        assert ti.xcom_pull(task_ids='test', key='return_value') == b'airflow'
+        ssh_hook = task.get_hook()
+        assert conn_timeout == ssh_hook.conn_timeout
+        assert "ssh_default" == ssh_hook.ssh_conn_id
 
-    @conf_vars({('core', 'enable_xcom_pickling'): 'True'})
-    def test_command_execution_with_env(self, create_task_instance_of_operator):
-        ti = create_task_instance_of_operator(
-            SSHOperator,
-            dag_id="unit_tests_ssh_test_op_command_execution_with_env",
+    @conf_vars({('core', 'enable_xcom_pickling'): 'False'})

Review comment:
       This is false by default, no? Plus we also set it in the test body?
   ```suggestion
   ```




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to