This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 2197e4b Correctly handle get_pty attribute if command passed as XComArg or template (#19323)
2197e4b is described below
commit 2197e4b59a7cf859eff5969b5f27b5e4f1084d3b
Author: Josh Fell <[email protected]>
AuthorDate: Fri Oct 29 17:16:31 2021 -0400
Correctly handle get_pty attribute if command passed as XComArg or template (#19323)
---
airflow/providers/ssh/operators/ssh.py | 6 +++++-
tests/providers/ssh/operators/test_ssh.py | 2 +-
2 files changed, 6 insertions(+), 2 deletions(-)
diff --git a/airflow/providers/ssh/operators/ssh.py
b/airflow/providers/ssh/operators/ssh.py
index 300e155..c095322 100644
--- a/airflow/providers/ssh/operators/ssh.py
+++ b/airflow/providers/ssh/operators/ssh.py
@@ -98,7 +98,7 @@ class SSHOperator(BaseOperator):
if self.cmd_timeout is None:
self.cmd_timeout = self.timeout if self.timeout else CMD_TIMEOUT
self.environment = environment
- self.get_pty = (self.command.startswith('sudo') or get_pty) if
self.command else get_pty
+ self.get_pty = get_pty
if self.timeout:
warnings.warn(
@@ -209,6 +209,10 @@ class SSHOperator(BaseOperator):
result = None
if self.command is None:
raise AirflowException("SSH operator error: SSH command not
specified. Aborting.")
+
+ # Forcing get_pty to True if the command begins with "sudo".
+ self.get_pty = self.command.startswith('sudo') or self.get_pty
+
try:
with self.get_ssh_client() as ssh_client:
result = self.run_ssh_client_command(ssh_client, self.command)
diff --git a/tests/providers/ssh/operators/test_ssh.py
b/tests/providers/ssh/operators/test_ssh.py
index 551b64b..b715dcf 100644
--- a/tests/providers/ssh/operators/test_ssh.py
+++ b/tests/providers/ssh/operators/test_ssh.py
@@ -227,7 +227,7 @@ class TestSSHOperator:
assert str(ctx.value) == "SSH operator error: SSH command not
specified. Aborting."
else:
task.execute(None)
- assert task.get_pty == get_pty_out
+ assert task.get_pty == get_pty_out
def test_ssh_client_managed_correctly(self):
# Ensure ssh_client gets created once