This is an automated email from the ASF dual-hosted git repository.
kaxilnaik pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 8d28fe848da Add ``SSHRemoteJobOperator`` for resilient remote job execution (#60297)
8d28fe848da is described below
commit 8d28fe848da385634eeffc9865e85d15bd1882e8
Author: Kaxil Naik <[email protected]>
AuthorDate: Sat Jan 10 01:56:33 2026 +0000
Add ``SSHRemoteJobOperator`` for resilient remote job execution (#60297)
A deferrable operator for executing commands on remote hosts via SSH.
Jobs run detached and survive network interruptions, with logs streamed
incrementally back to Airflow. Supports both POSIX and Windows hosts.
---
docs/spelling_wordlist.txt | 1 +
providers/ssh/docs/index.rst | 2 +
providers/ssh/docs/operators/ssh_remote_job.rst | 264 ++++++++++++
providers/ssh/provider.yaml | 8 +
providers/ssh/pyproject.toml | 1 +
.../src/airflow/providers/ssh/get_provider_info.py | 12 +-
.../ssh/src/airflow/providers/ssh/hooks/ssh.py | 142 +++++++
.../providers/ssh/operators/ssh_remote_job.py | 455 +++++++++++++++++++++
.../src/airflow/providers/ssh/triggers/__init__.py | 17 +
.../providers/ssh/triggers/ssh_remote_job.py | 271 ++++++++++++
.../src/airflow/providers/ssh/utils/__init__.py | 17 +
.../src/airflow/providers/ssh/utils/remote_job.py | 448 ++++++++++++++++++++
.../ssh/tests/unit/ssh/hooks/test_ssh_async.py | 172 ++++++++
.../unit/ssh/operators/test_ssh_remote_job.py | 331 +++++++++++++++
providers/ssh/tests/unit/ssh/triggers/__init__.py | 17 +
.../tests/unit/ssh/triggers/test_ssh_remote_job.py | 197 +++++++++
providers/ssh/tests/unit/ssh/utils/__init__.py | 17 +
.../ssh/tests/unit/ssh/utils/test_remote_job.py | 254 ++++++++++++
18 files changed, 2625 insertions(+), 1 deletion(-)
diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt
index ec1de59d310..36e369f2648 100644
--- a/docs/spelling_wordlist.txt
+++ b/docs/spelling_wordlist.txt
@@ -1228,6 +1228,7 @@ Nodegroups
nodegroups
nodeName
nodeSelector
+nohup
Nones
nonnegative
nonterminal
diff --git a/providers/ssh/docs/index.rst b/providers/ssh/docs/index.rst
index 15cc24b29fe..de567fcc63f 100644
--- a/providers/ssh/docs/index.rst
+++ b/providers/ssh/docs/index.rst
@@ -35,6 +35,7 @@
:caption: Guides
Connection types <connections/ssh>
+ Operators <operators/ssh_remote_job>
.. toctree::
:hidden:
@@ -93,6 +94,7 @@ PIP package Version required
========================================== ==================
``apache-airflow`` ``>=2.11.0``
``apache-airflow-providers-common-compat`` ``>=1.10.1``
+``asyncssh`` ``>=2.12.0``
``paramiko`` ``>=2.9.0,<4.0.0``
``sshtunnel`` ``>=0.3.2``
========================================== ==================
diff --git a/providers/ssh/docs/operators/ssh_remote_job.rst b/providers/ssh/docs/operators/ssh_remote_job.rst
new file mode 100644
index 00000000000..c78a0792740
--- /dev/null
+++ b/providers/ssh/docs/operators/ssh_remote_job.rst
@@ -0,0 +1,264 @@
+ .. Licensed to the Apache Software Foundation (ASF) under one
+ or more contributor license agreements. See the NOTICE file
+ distributed with this work for additional information
+ regarding copyright ownership. The ASF licenses this file
+ to you under the Apache License, Version 2.0 (the
+ "License"); you may not use this file except in compliance
+ with the License. You may obtain a copy of the License at
+
+ .. http://www.apache.org/licenses/LICENSE-2.0
+
+ .. Unless required by applicable law or agreed to in writing,
+ software distributed under the License is distributed on an
+ "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ KIND, either express or implied. See the License for the
+ specific language governing permissions and limitations
+ under the License.
+
+
+
+.. _howto/operator:SSHRemoteJobOperator:
+
+SSHRemoteJobOperator
+====================
+
+Use the :class:`~airflow.providers.ssh.operators.ssh_remote_job.SSHRemoteJobOperator` to execute
+commands on a remote server via SSH as a detached job. This operator is **deferrable**, meaning it
+offloads long-running job monitoring to the triggerer, freeing up worker slots for other tasks.
+
+This operator is designed to be more resilient than :class:`~airflow.providers.ssh.operators.ssh.SSHOperator`
+for long-running jobs, especially in environments where network interruptions or worker restarts may occur.
+
+Key Features
+------------
+
+* **Deferrable**: Offloads job monitoring to the triggerer process
+* **Detached Execution**: Starts remote jobs that continue running even if the SSH connection drops
+* **Incremental Log Streaming**: Tails logs from the remote host and displays them in Airflow
+* **Cross-Platform**: Supports both POSIX (Linux/macOS) and Windows remote hosts
+* **Resilient**: Jobs survive network interruptions and worker restarts
+* **File-based Completion**: Uses an exit code file for reliable completion detection
+
+When to Use This Operator
+--------------------------
+
+Use ``SSHRemoteJobOperator`` when:
+
+* Running long-running jobs (minutes to hours) on remote hosts
+* Network stability is a concern
+* You need to see incremental logs as the job progresses
+* The remote job should survive temporary disconnections
+* You want to free up worker slots during job execution
+
+Use the traditional :class:`~airflow.providers.ssh.operators.ssh.SSHOperator` when:
+
+* Running short commands (seconds)
+* You need bidirectional communication during execution
+* The command requires an interactive TTY
+
+How It Works
+------------
+
+1. **Job Submission**: The operator connects via SSH and submits a wrapper script that:
+
+   * Creates a unique job directory on the remote host
+   * Starts your command as a detached process (``nohup`` on POSIX, ``Start-Process`` on Windows)
+   * Redirects output to a log file
+   * Writes exit code to a file when complete
+
+2. **Deferral**: The operator immediately defers to :class:`~airflow.providers.ssh.triggers.ssh_remote_job.SSHRemoteJobTrigger`
+
+3. **Monitoring**: The trigger periodically:
+
+ * Checks if the exit code file exists (job complete)
+ * Reads new log content incrementally
+ * Yields events with log chunks back to the operator
+
+4. **Completion**: When the job finishes:
+
+ * Final logs are displayed
+ * Exit code is checked (0 = success, non-zero = failure)
+ * Optional cleanup of remote job directory
+
+Using the Operator
+------------------
+
+Basic Example
+^^^^^^^^^^^^^
+
+.. code-block:: python
+
+    from airflow.providers.ssh.operators.ssh_remote_job import SSHRemoteJobOperator
+
+ run_script = SSHRemoteJobOperator(
+ task_id="run_remote_script",
+ ssh_conn_id="my_ssh_connection",
+ command="/path/to/script.sh",
+ poll_interval=5, # Check status every 5 seconds
+ cleanup="on_success", # Clean up remote files on success
+ )
+
+With Environment Variables
+^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+.. code-block:: python
+
+ run_with_env = SSHRemoteJobOperator(
+ task_id="run_with_environment",
+ ssh_conn_id="my_ssh_connection",
+ command="python process_data.py",
+ environment={
+ "DATA_PATH": "/data/input.csv",
+ "OUTPUT_PATH": "/data/output.csv",
+ "LOG_LEVEL": "INFO",
+ },
+ )
+
+Windows Remote Host
+^^^^^^^^^^^^^^^^^^^
+
+.. code-block:: python
+
+ run_on_windows = SSHRemoteJobOperator(
+ task_id="run_on_windows",
+ ssh_conn_id="windows_ssh_connection",
+ command="C:\\Scripts\\process.ps1",
+ remote_os="windows", # Explicitly specify Windows
+ poll_interval=10,
+ )
+
+With Timeout and Skip on Exit Code
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+.. code-block:: python
+
+ run_with_options = SSHRemoteJobOperator(
+ task_id="run_with_options",
+ ssh_conn_id="my_ssh_connection",
+ command="./long_running_job.sh",
+ timeout=3600, # Fail if not complete in 1 hour
+ skip_on_exit_code=99, # Skip task if job exits with code 99
+ cleanup="always", # Always clean up, even on failure
+ )
+
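+Accessing Job Metadata via XCom
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+When XCom pushing is enabled (the default), the operator stores the remote job metadata
+(``job_id``, ``job_dir``, ``log_file``, ``exit_code_file``, ``pid_file``, ``remote_os``)
+under the XCom key ``ssh_remote_job``. A downstream task can pull this dictionary, for
+example to report where the remote logs live. The snippet below is only a sketch; the
+``report_job_location`` task is hypothetical and assumes the ``run_remote_script`` task
+from the basic example above:
+
+.. code-block:: python
+
+    from airflow.decorators import task
+
+
+    @task
+    def report_job_location(ti=None):
+        # Pull the metadata pushed by SSHRemoteJobOperator before it deferred
+        job_info = ti.xcom_pull(task_ids="run_remote_script", key="ssh_remote_job")
+        print(f"Logs for job {job_info['job_id']} are in {job_info['log_file']}")
+
+
+    run_script >> report_job_location()
+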
+Parameters
+----------
+
+* ``ssh_conn_id`` (str, required): The Airflow connection ID for the SSH connection
+* ``command`` (str, required): The command or script path to execute on the remote host
+* ``remote_host`` (str, optional): Override the host from the connection
+* ``environment`` (dict, optional): Environment variables to set for the command
+* ``remote_base_dir`` (str, optional): Base directory for job files. Defaults to:
+
+  * POSIX: ``/tmp/airflow-ssh-jobs``
+  * Windows: ``$env:TEMP\\airflow-ssh-jobs``
+
+* ``poll_interval`` (int, optional): How often to check job status in seconds. Default: 5
+* ``log_chunk_size`` (int, optional): Maximum bytes to read from the log per poll. Default: 65536
+* ``timeout`` (int, optional): Hard timeout for the entire task in seconds. Default: None (no timeout)
+* ``cleanup`` (str, optional): When to clean up the remote job directory:
+
+  * ``"never"`` (default): Never clean up
+  * ``"on_success"``: Clean up only if the job succeeds
+  * ``"always"``: Always clean up regardless of status
+
+* ``remote_os`` (str, optional): Remote OS type (``"auto"``, ``"posix"``, ``"windows"``). Default: ``"auto"``
+* ``skip_on_exit_code`` (int or list, optional): Exit code(s) that should cause the task to skip instead of fail
+
+Remote OS Detection
+-------------------
+
+The operator can automatically detect the remote OS (``remote_os="auto"``), but explicit specification
+is more reliable:
+
+* Use ``remote_os="posix"`` for Linux and macOS hosts
+* Use ``remote_os="windows"`` for Windows hosts with OpenSSH Server
+
+For Windows, ensure:
+
+* OpenSSH Server is installed and running
+* PowerShell is available (default on modern Windows)
+* SSH connection allows command execution
+
+Job Directory Structure
+-----------------------
+
+Each job creates a directory on the remote host with these files:
+
+.. code-block:: text
+
+ /tmp/airflow-ssh-jobs/af_mydag_mytask_run123_try1_abc12345/
+ ├── stdout.log # Combined stdout/stderr
+ ├── exit_code # Final exit code (0 or non-zero)
+ ├── pid # Process ID (for on_kill)
+ └── status # Optional status file (for user scripts)
+
+Your command can access these via environment variables:
+
+* ``LOG_FILE``: Path to the log file
+* ``STATUS_FILE``: Path to the status file
+
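+For example, a remote job written in Python could use these variables to publish progress.
+This is only a sketch of a user-side script; the operator does not itself interpret the
+status file contents:
+
+.. code-block:: python
+
+    import os
+    import time
+
+    status_file = os.environ["STATUS_FILE"]  # set by the wrapper script
+
+    for step in range(1, 4):
+        # Anything printed to stdout lands in LOG_FILE and is streamed back to Airflow
+        print(f"starting step {step} of 3", flush=True)
+        with open(status_file, "w") as f:
+            f.write(f"step {step} of 3")
+        time.sleep(10)
+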
+Connection Requirements
+-----------------------
+
+The SSH connection must support:
+
+* Non-interactive authentication (password or key-based)
+* Command execution without PTY
+* File I/O on the remote host
+
+See :ref:`howto/connection:ssh` for connection configuration.
+
+Limitations and Considerations
+-------------------------------
+
+**Network Interruptions**: While the operator is resilient to disconnections during monitoring,
+the initial job submission must succeed. If submission fails, the task will fail immediately.
+
+**Remote Process Management**: Jobs are detached using ``nohup`` (POSIX) or ``Start-Process`` (Windows).
+If the remote host reboots during job execution, the job will be lost.
+
+**Log Size**: Large log outputs may impact performance. The ``log_chunk_size`` parameter controls
+how much data is read per poll. For very large logs (GBs), consider having your script write
+logs to a separate file and only log summaries to stdout.
+
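+One way to follow this advice is to route verbose output to a dedicated file and keep stdout
+to short progress messages; the sketch below uses a hypothetical log path:
+
+.. code-block:: python
+
+    import logging
+
+    # Detailed diagnostics go to a file on the remote host (hypothetical path), so only
+    # the short summary lines printed below are streamed back through Airflow.
+    logging.basicConfig(filename="/data/logs/full_run.log", level=logging.DEBUG)
+
+    for batch in range(100):
+        logging.debug("processing batch %s ...", batch)
+        if batch % 10 == 0:
+            print(f"processed {batch} batches", flush=True)
+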
+**Exit Code Detection**: The operator uses file-based exit code detection for reliability.
+If your script uses ``exec`` to replace the shell process, ensure the exit code is still
+written to the file.
+
+**Concurrent Jobs**: Each task instance creates a unique job directory. Multiple concurrent
+tasks can run on the same remote host without conflicts.
+
+**Cleanup**: Use ``cleanup="on_success"`` or ``cleanup="always"`` to avoid accumulating
+job directories on the remote host. For debugging, use ``cleanup="never"`` and manually
+inspect the job directory.
+
+Comparison with SSHOperator
+----------------------------
+
++---------------------------+------------------+---------------------+
+| Feature | SSHOperator | SSHRemoteJobOperator|
++===========================+==================+=====================+
+| Execution Model | Synchronous | Asynchronous |
++---------------------------+------------------+---------------------+
+| Worker Slot Usage | Entire duration | Only during submit |
++---------------------------+------------------+---------------------+
+| Network Resilience | Low | High |
++---------------------------+------------------+---------------------+
+| Long-running Jobs | Not recommended | Designed for |
++---------------------------+------------------+---------------------+
+| Incremental Logs | No | Yes |
++---------------------------+------------------+---------------------+
+| Windows Support | Limited | Full (via OpenSSH) |
++---------------------------+------------------+---------------------+
+| Setup Complexity | Simple | Moderate |
++---------------------------+------------------+---------------------+
+
+Related Documentation
+---------------------
+
+* :class:`~airflow.providers.ssh.operators.ssh.SSHOperator` - Traditional synchronous SSH operator
+* :class:`~airflow.providers.ssh.triggers.ssh_remote_job.SSHRemoteJobTrigger` - Trigger used by this operator
+* :class:`~airflow.providers.ssh.hooks.ssh.SSHHook` - SSH hook for synchronous operations
+* :class:`~airflow.providers.ssh.hooks.ssh.SSHHookAsync` - Async SSH hook for triggers
+* :ref:`howto/connection:ssh` - SSH connection configuration
+* :ref:`howto/connection:ssh` - SSH connection configuration
diff --git a/providers/ssh/provider.yaml b/providers/ssh/provider.yaml
index d546e79a82a..0ed4adcadea 100644
--- a/providers/ssh/provider.yaml
+++ b/providers/ssh/provider.yaml
@@ -82,17 +82,25 @@ integrations:
external-doc-url: https://tools.ietf.org/html/rfc4251
logo: /docs/integration-logos/SSH.png
tags: [protocol]
+ how-to-guide:
+ - /docs/apache-airflow-providers-ssh/operators/ssh_remote_job.rst
operators:
- integration-name: Secure Shell (SSH)
python-modules:
- airflow.providers.ssh.operators.ssh
+ - airflow.providers.ssh.operators.ssh_remote_job
hooks:
- integration-name: Secure Shell (SSH)
python-modules:
- airflow.providers.ssh.hooks.ssh
+triggers:
+ - integration-name: Secure Shell (SSH)
+ python-modules:
+ - airflow.providers.ssh.triggers.ssh_remote_job
+
connection-types:
- hook-class-name: airflow.providers.ssh.hooks.ssh.SSHHook
connection-type: ssh
diff --git a/providers/ssh/pyproject.toml b/providers/ssh/pyproject.toml
index 35faea62138..8db0cfbabed 100644
--- a/providers/ssh/pyproject.toml
+++ b/providers/ssh/pyproject.toml
@@ -60,6 +60,7 @@ requires-python = ">=3.10"
dependencies = [
"apache-airflow>=2.11.0",
"apache-airflow-providers-common-compat>=1.10.1", # use next version
+ "asyncssh>=2.12.0",
    # TODO: Bump to >= 4.0.0 once https://github.com/apache/airflow/issues/54079 is handled
"paramiko>=2.9.0,<4.0.0",
"sshtunnel>=0.3.2",
diff --git a/providers/ssh/src/airflow/providers/ssh/get_provider_info.py b/providers/ssh/src/airflow/providers/ssh/get_provider_info.py
index 17b2028f8e5..4ed67152f73 100644
--- a/providers/ssh/src/airflow/providers/ssh/get_provider_info.py
+++ b/providers/ssh/src/airflow/providers/ssh/get_provider_info.py
@@ -32,17 +32,27 @@ def get_provider_info():
"external-doc-url": "https://tools.ietf.org/html/rfc4251",
"logo": "/docs/integration-logos/SSH.png",
"tags": ["protocol"],
+            "how-to-guide": ["/docs/apache-airflow-providers-ssh/operators/ssh_remote_job.rst"],
}
],
"operators": [
{
"integration-name": "Secure Shell (SSH)",
- "python-modules": ["airflow.providers.ssh.operators.ssh"],
+ "python-modules": [
+ "airflow.providers.ssh.operators.ssh",
+ "airflow.providers.ssh.operators.ssh_remote_job",
+ ],
}
],
"hooks": [
{"integration-name": "Secure Shell (SSH)", "python-modules":
["airflow.providers.ssh.hooks.ssh"]}
],
+ "triggers": [
+ {
+ "integration-name": "Secure Shell (SSH)",
+            "python-modules": ["airflow.providers.ssh.triggers.ssh_remote_job"],
+ }
+ ],
"connection-types": [
{"hook-class-name": "airflow.providers.ssh.hooks.ssh.SSHHook",
"connection-type": "ssh"}
],
diff --git a/providers/ssh/src/airflow/providers/ssh/hooks/ssh.py b/providers/ssh/src/airflow/providers/ssh/hooks/ssh.py
index c5f6deb3a8f..493f3f92369 100644
--- a/providers/ssh/src/airflow/providers/ssh/hooks/ssh.py
+++ b/providers/ssh/src/airflow/providers/ssh/hooks/ssh.py
@@ -530,3 +530,145 @@ class SSHHook(BaseHook):
return True, "Connection successfully tested"
except Exception as e:
return False, str(e)
+
+
+class SSHHookAsync(BaseHook):
+ """
+ Asynchronous SSH hook using asyncssh for use in triggers.
+
+ This hook provides async SSH connectivity for deferrable operators
+ and their triggers.
+
+ :param ssh_conn_id: SSH connection ID from Airflow Connections
+ :param host: hostname of the SSH server (overrides connection)
+ :param port: port of the SSH server (overrides connection)
+ :param username: username for authentication (overrides connection)
+ :param password: password for authentication (overrides connection)
+    :param known_hosts: path to known_hosts file. Defaults to ``~/.ssh/known_hosts``.
+ :param key_file: path to private key file for authentication
+ :param passphrase: passphrase for the private key
+ :param private_key: private key content as string
+ """
+
+ conn_name_attr = "ssh_conn_id"
+ default_conn_name = "ssh_default"
+ conn_type = "ssh"
+ hook_name = "SSH"
+ default_known_hosts = "~/.ssh/known_hosts"
+
+ def __init__(
+ self,
+ ssh_conn_id: str = default_conn_name,
+ host: str | None = None,
+ port: int | None = None,
+ username: str | None = None,
+ password: str | None = None,
+ known_hosts: str = default_known_hosts,
+ key_file: str = "",
+ passphrase: str = "",
+ private_key: str = "",
+ ) -> None:
+ super().__init__()
+ self.ssh_conn_id = ssh_conn_id
+ self.host = host
+ self.port = port
+ self.username = username
+ self.password = password
+ self.known_hosts: bytes | str = os.path.expanduser(known_hosts)
+ self.key_file = key_file
+ self.passphrase = passphrase
+ self.private_key = private_key
+
+ def _parse_extras(self, conn: Any) -> None:
+ """Parse extra fields from the connection into instance fields."""
+ extra_options = conn.extra_dejson
+ if "key_file" in extra_options and self.key_file == "":
+ self.key_file = extra_options["key_file"]
+ if "known_hosts" in extra_options:
+ expanded_default = os.path.expanduser(self.default_known_hosts)
+ if self.known_hosts == expanded_default:
+ self.known_hosts = extra_options["known_hosts"]
+        if "passphrase" in extra_options or "private_key_passphrase" in extra_options:
+            self.passphrase = extra_options.get("passphrase") or extra_options.get(
+                "private_key_passphrase", ""
+            )
+ if "private_key" in extra_options:
+ self.private_key = extra_options["private_key"]
+
+ host_key = extra_options.get("host_key")
+ nhkc_raw = extra_options.get("no_host_key_check")
+        no_host_key_check = str(nhkc_raw).lower() == "true" if nhkc_raw is not None else False
+
+        if host_key is not None and no_host_key_check:
+            raise ValueError("Host key check was skipped, but `host_key` value was given")
+
+        if no_host_key_check:
+            self.log.warning("No Host Key Verification. This won't protect against Man-In-The-Middle attacks")
+ self.known_hosts = "none"
+ elif host_key is not None:
+ self.known_hosts = f"{conn.host} {host_key}".encode()
+
+ async def _get_conn(self):
+ """
+ Asynchronously connect to the SSH server.
+
+        Returns an asyncssh SSHClientConnection that can be used to run commands.
+ """
+ import asyncssh
+ from asgiref.sync import sync_to_async
+
+ conn = await sync_to_async(self.get_connection)(self.ssh_conn_id)
+ if conn.extra is not None:
+ self._parse_extras(conn)
+
+ def _get_value(self_val, conn_val, default=None):
+ if self_val is not None:
+ return self_val
+ if conn_val is not None:
+ return conn_val
+ return default
+
+ conn_config: dict = {
+ "host": _get_value(self.host, conn.host),
+ "port": _get_value(self.port, conn.port, SSH_PORT),
+ "username": _get_value(self.username, conn.login),
+ "password": _get_value(self.password, conn.password),
+ }
+ if self.key_file:
+ conn_config["client_keys"] = self.key_file
+ if self.known_hosts:
+            if isinstance(self.known_hosts, str) and self.known_hosts.lower() == "none":
+ conn_config["known_hosts"] = None
+ else:
+ conn_config["known_hosts"] = self.known_hosts
+ if self.private_key:
+            _private_key = asyncssh.import_private_key(self.private_key, self.passphrase)
+ conn_config["client_keys"] = [_private_key]
+ if self.passphrase:
+ conn_config["passphrase"] = self.passphrase
+
+ ssh_client_conn = await asyncssh.connect(**conn_config)
+ return ssh_client_conn
+
+    async def run_command(self, command: str, timeout: float | None = None) -> tuple[int, str, str]:
+ """
+ Execute a command on the remote host asynchronously.
+
+ :param command: The command to execute
+ :param timeout: Optional timeout in seconds
+ :return: Tuple of (exit_code, stdout, stderr)
+ """
+ async with await self._get_conn() as ssh_conn:
+ result = await ssh_conn.run(command, timeout=timeout, check=False)
+            return result.exit_status or 0, result.stdout or "", result.stderr or ""
+
+    async def run_command_output(self, command: str, timeout: float | None = None) -> str:
+ """
+ Execute a command and return stdout.
+
+ :param command: The command to execute
+ :param timeout: Optional timeout in seconds
+ :return: stdout as string
+ """
+ _, stdout, _ = await self.run_command(command, timeout=timeout)
+ return stdout
diff --git a/providers/ssh/src/airflow/providers/ssh/operators/ssh_remote_job.py b/providers/ssh/src/airflow/providers/ssh/operators/ssh_remote_job.py
new file mode 100644
index 00000000000..783edb39b8f
--- /dev/null
+++ b/providers/ssh/src/airflow/providers/ssh/operators/ssh_remote_job.py
@@ -0,0 +1,455 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""SSH Remote Job Operator for deferrable remote command execution."""
+
+from __future__ import annotations
+
+import warnings
+from collections.abc import Container, Sequence
+from datetime import timedelta
+from functools import cached_property
+from typing import TYPE_CHECKING, Any, Literal
+
+from airflow.providers.common.compat.sdk import AirflowException, AirflowSkipException, BaseOperator
+from airflow.providers.ssh.hooks.ssh import SSHHook
+from airflow.providers.ssh.triggers.ssh_remote_job import SSHRemoteJobTrigger
+from airflow.providers.ssh.utils.remote_job import (
+ RemoteJobPaths,
+ build_posix_cleanup_command,
+ build_posix_kill_command,
+ build_posix_os_detection_command,
+ build_posix_wrapper_command,
+ build_windows_cleanup_command,
+ build_windows_kill_command,
+ build_windows_os_detection_command,
+ build_windows_wrapper_command,
+ generate_job_id,
+)
+
+if TYPE_CHECKING:
+ from airflow.providers.common.compat.sdk import Context
+
+
+class SSHRemoteJobOperator(BaseOperator):
+ r"""
+ Execute a command on a remote host via SSH with deferrable monitoring.
+
+ This operator submits a job to run detached on the remote host, then
+ uses a trigger to asynchronously monitor the job status and stream logs.
+ This approach is resilient to network interruptions as the remote job
+ continues running independently of the SSH connection.
+
+ The remote job is wrapped to:
+    - Run detached from the SSH session (via nohup on POSIX, Start-Process on Windows)
+ - Redirect stdout/stderr to a log file
+ - Write the exit code to a file on completion
+
+ :param ssh_conn_id: SSH connection ID from Airflow Connections
+ :param command: Command to execute on the remote host (templated)
+ :param remote_host: Override the host from the connection (templated)
+    :param environment: Environment variables to set for the command (templated)
+    :param remote_base_dir: Base directory for job artifacts (templated).
+        Defaults to /tmp/airflow-ssh-jobs on POSIX, $env:TEMP\\airflow-ssh-jobs on Windows
+ :param poll_interval: Seconds between status polls (default: 5)
+ :param log_chunk_size: Max bytes to read per poll (default: 65536)
+ :param timeout: Hard timeout in seconds for the entire operation
+ :param cleanup: When to clean up remote job directory:
+ 'never', 'on_success', or 'always' (default: 'never')
+    :param remote_os: Remote operating system: 'auto', 'posix', or 'windows' (default: 'auto')
+    :param skip_on_exit_code: Exit codes that should skip the task instead of failing
+ :param conn_timeout: SSH connection timeout in seconds
+ :param banner_timeout: Timeout waiting for SSH banner in seconds
+ """
+
+    template_fields: Sequence[str] = ("command", "environment", "remote_host", "remote_base_dir")
+ template_ext: Sequence[str] = (
+ ".sh",
+ ".bash",
+ ".ps1",
+ )
+ template_fields_renderers = {
+ "command": "bash",
+ "environment": "python",
+ }
+ ui_color = "#e4f0e8"
+
+ def __init__(
+ self,
+ *,
+ ssh_conn_id: str,
+ command: str,
+ remote_host: str | None = None,
+ environment: dict[str, str] | None = None,
+ remote_base_dir: str | None = None,
+ poll_interval: int = 5,
+ log_chunk_size: int = 65536,
+ timeout: int | None = None,
+ cleanup: Literal["never", "on_success", "always"] = "never",
+ remote_os: Literal["auto", "posix", "windows"] = "auto",
+ skip_on_exit_code: int | Container[int] | None = None,
+ conn_timeout: int | None = None,
+ banner_timeout: float = 30.0,
+ **kwargs,
+ ) -> None:
+ super().__init__(**kwargs)
+ self.ssh_conn_id = ssh_conn_id
+ self.command = command
+ self.remote_host = remote_host
+ self.environment = environment
+
+ if remote_base_dir is not None:
+ self._validate_base_dir(remote_base_dir)
+ self.remote_base_dir = remote_base_dir
+
+ self.poll_interval = poll_interval
+ self.log_chunk_size = log_chunk_size
+ self.timeout = timeout
+ self.cleanup = cleanup
+ self.remote_os = remote_os
+ self.conn_timeout = conn_timeout
+ self.banner_timeout = banner_timeout
+ self.skip_on_exit_code = (
+ skip_on_exit_code
+ if isinstance(skip_on_exit_code, Container)
+ else [skip_on_exit_code]
+ if skip_on_exit_code is not None
+ else []
+ )
+
+ self._job_id: str | None = None
+ self._paths: RemoteJobPaths | None = None
+ self._detected_os: Literal["posix", "windows"] | None = None
+
+ @staticmethod
+ def _validate_base_dir(path: str) -> None:
+ """
+ Validate the remote base directory path for security.
+
+ :param path: Path to validate
+ :raises ValueError: If path contains dangerous patterns
+ """
+ if not path:
+ raise ValueError("remote_base_dir cannot be empty")
+
+ if ".." in path:
+            raise ValueError(f"remote_base_dir cannot contain '..' (path traversal not allowed). Got: {path}")
+
+ if "\x00" in path:
+ raise ValueError("remote_base_dir cannot contain null bytes")
+
+        dangerous_patterns = ["/etc", "/bin", "/sbin", "/boot", "C:\\Windows", "C:\\Program Files"]
+ for pattern in dangerous_patterns:
+ if pattern in path:
+ warnings.warn(
+                    f"remote_base_dir '{path}' contains potentially sensitive path '{pattern}'. "
+ "Ensure you have appropriate permissions.",
+ UserWarning,
+ stacklevel=3,
+ )
+
+ @cached_property
+ def ssh_hook(self) -> SSHHook:
+ """Create the SSH hook for command submission."""
+ return SSHHook(
+ ssh_conn_id=self.ssh_conn_id,
+ remote_host=self.remote_host or "",
+ conn_timeout=self.conn_timeout,
+ banner_timeout=self.banner_timeout,
+ )
+
+ def _detect_remote_os(self) -> Literal["posix", "windows"]:
+ """
+ Detect the remote operating system.
+
+ Uses a two-stage detection:
+        1. Try POSIX detection via `uname` (works on Linux, macOS, BSD, Solaris, AIX, etc.)
+ 2. Try Windows detection via PowerShell
+ 3. Raise error if both fail
+ """
+ if self.remote_os != "auto":
+ return self.remote_os
+
+ self.log.info("Auto-detecting remote operating system...")
+ with self.ssh_hook.get_conn() as ssh_client:
+ try:
+ exit_status, stdout, _ = self.ssh_hook.exec_ssh_client_command(
+ ssh_client,
+ build_posix_os_detection_command(),
+ get_pty=False,
+ environment=None,
+ timeout=10,
+ )
+ if exit_status == 0 and stdout:
+                    output = stdout.decode("utf-8", errors="replace").strip().lower()
+ posix_systems = [
+ "linux",
+ "darwin",
+ "freebsd",
+ "openbsd",
+ "netbsd",
+ "sunos",
+ "aix",
+ "hp-ux",
+ ]
+ if any(system in output for system in posix_systems):
+ self.log.info("Detected POSIX system: %s", output)
+ return "posix"
+ except Exception as e:
+ self.log.debug("POSIX detection failed: %s", e)
+
+ try:
+ exit_status, stdout, _ = self.ssh_hook.exec_ssh_client_command(
+ ssh_client,
+ build_windows_os_detection_command(),
+ get_pty=False,
+ environment=None,
+ timeout=10,
+ )
+ if exit_status == 0 and stdout:
+ output = stdout.decode("utf-8", errors="replace").strip()
+ if "WINDOWS" in output.upper():
+ self.log.info("Detected Windows system")
+ return "windows"
+ except Exception as e:
+ self.log.debug("Windows detection failed: %s", e)
+
+ raise AirflowException(
+            "Could not auto-detect remote OS. Please explicitly set remote_os='posix' or 'windows'"
+ )
+
+ def execute(self, context: Context) -> None:
+ """
+ Submit the remote job and defer to the trigger for monitoring.
+
+ :param context: Airflow task context
+ """
+ if not self.command:
+            raise AirflowException("SSH operator error: command not specified.")
+
+ self._detected_os = self._detect_remote_os()
+ self.log.info("Remote OS: %s", self._detected_os)
+
+ ti = context["ti"]
+ self._job_id = generate_job_id(
+ dag_id=ti.dag_id,
+ task_id=ti.task_id,
+ run_id=ti.run_id,
+ try_number=ti.try_number,
+ )
+ self.log.info("Generated job ID: %s", self._job_id)
+
+ self._paths = RemoteJobPaths(
+ job_id=self._job_id,
+ remote_os=self._detected_os,
+ base_dir=self.remote_base_dir,
+ )
+
+ if self._detected_os == "posix":
+ wrapper_cmd = build_posix_wrapper_command(
+ command=self.command,
+ paths=self._paths,
+ environment=self.environment,
+ )
+ else:
+ wrapper_cmd = build_windows_wrapper_command(
+ command=self.command,
+ paths=self._paths,
+ environment=self.environment,
+ )
+
+ self.log.info("Submitting remote job to %s", self.ssh_hook.remote_host)
+ with self.ssh_hook.get_conn() as ssh_client:
+            exit_status, stdout, stderr = self.ssh_hook.exec_ssh_client_command(
+ ssh_client,
+ wrapper_cmd,
+ get_pty=False,
+ environment=None,
+ timeout=60,
+ )
+
+ if exit_status != 0:
+                stderr_str = stderr.decode("utf-8", errors="replace") if stderr else ""
+                raise AirflowException(
+                    f"Failed to submit remote job. Exit code: {exit_status}. Error: {stderr_str}"
+                )
+
+            returned_job_id = stdout.decode("utf-8", errors="replace").strip() if stdout else ""
+            if returned_job_id != self._job_id:
+                self.log.warning("Job ID mismatch. Expected: %s, Got: %s", self._job_id, returned_job_id)
+
+        self.log.info("Remote job submitted successfully. Job ID: %s", self._job_id)
+ self.log.info("Job directory: %s", self._paths.job_dir)
+
+ if self.do_xcom_push:
+ ti.xcom_push(
+ key="ssh_remote_job",
+ value={
+ "job_id": self._job_id,
+ "job_dir": self._paths.job_dir,
+ "log_file": self._paths.log_file,
+ "exit_code_file": self._paths.exit_code_file,
+ "pid_file": self._paths.pid_file,
+ "remote_os": self._detected_os,
+ },
+ )
+
+ self.defer(
+ trigger=SSHRemoteJobTrigger(
+ ssh_conn_id=self.ssh_conn_id,
+ remote_host=self.remote_host,
+ job_id=self._job_id,
+ job_dir=self._paths.job_dir,
+ log_file=self._paths.log_file,
+ exit_code_file=self._paths.exit_code_file,
+ remote_os=self._detected_os,
+ poll_interval=self.poll_interval,
+ log_chunk_size=self.log_chunk_size,
+ log_offset=0,
+ ),
+ method_name="execute_complete",
+ timeout=timedelta(seconds=self.timeout) if self.timeout else None,
+ )
+
+    def execute_complete(self, context: Context, event: dict[str, Any]) -> None:
+ """
+ Handle trigger events and re-defer if job is still running.
+
+ :param context: Airflow task context
+ :param event: Event data from the trigger
+ """
+ if not event:
+ raise AirflowException("Received null event from trigger")
+
+        required_keys = ["job_id", "job_dir", "log_file", "exit_code_file", "remote_os", "done"]
+ missing_keys = [key for key in required_keys if key not in event]
+ if missing_keys:
+ raise AirflowException(
+                f"Invalid trigger event: missing required keys {missing_keys}. Event: {event}"
+ )
+
+ log_chunk = event.get("log_chunk", "")
+ if log_chunk:
+ for line in log_chunk.splitlines():
+ self.log.info("[remote] %s", line)
+
+ if not event.get("done", False):
+ self.log.debug("Job still running, continuing to monitor...")
+ self.defer(
+ trigger=SSHRemoteJobTrigger(
+ ssh_conn_id=self.ssh_conn_id,
+ remote_host=self.remote_host,
+ job_id=event["job_id"],
+ job_dir=event["job_dir"],
+ log_file=event["log_file"],
+ exit_code_file=event["exit_code_file"],
+ remote_os=event["remote_os"],
+ poll_interval=self.poll_interval,
+ log_chunk_size=self.log_chunk_size,
+ log_offset=event.get("log_offset", 0),
+ ),
+ method_name="execute_complete",
+                timeout=timedelta(seconds=self.timeout) if self.timeout else None,
+ )
+ return
+
+ exit_code = event.get("exit_code")
+ job_dir = event.get("job_dir", "")
+ remote_os = event.get("remote_os", "posix")
+
+ self.log.info("Remote job completed with exit code: %s", exit_code)
+
+        should_cleanup = self.cleanup == "always" or (self.cleanup == "on_success" and exit_code == 0)
+ if should_cleanup and job_dir:
+ self._cleanup_remote_job(job_dir, remote_os)
+
+ if exit_code is None:
+            raise AirflowException(f"Remote job failed: {event.get('message', 'Unknown error')}")
+
+        if exit_code in self.skip_on_exit_code:
+            raise AirflowSkipException(f"Remote job returned skip exit code: {exit_code}")
+
+        if exit_code != 0:
+            raise AirflowException(f"Remote job failed with exit code: {exit_code}")
+
+ self.log.info("Remote job completed successfully")
+
+ def _cleanup_remote_job(self, job_dir: str, remote_os: str) -> None:
+ """Clean up the remote job directory."""
+ self.log.info("Cleaning up remote job directory: %s", job_dir)
+ try:
+ if remote_os == "posix":
+ cleanup_cmd = build_posix_cleanup_command(job_dir)
+ else:
+ cleanup_cmd = build_windows_cleanup_command(job_dir)
+
+ with self.ssh_hook.get_conn() as ssh_client:
+ self.ssh_hook.exec_ssh_client_command(
+ ssh_client,
+ cleanup_cmd,
+ get_pty=False,
+ environment=None,
+ timeout=30,
+ )
+ self.log.info("Remote cleanup completed")
+ except Exception as e:
+ self.log.warning("Failed to clean up remote job directory: %s", e)
+
+ def on_kill(self) -> None:
+ """
+ Attempt to kill the remote process when the task is killed.
+
+        Since the operator is recreated after deferral, instance variables may not
+ be set. We retrieve job information from XCom if needed.
+ """
+ job_id = self._job_id
+ pid_file = self._paths.pid_file if self._paths else None
+ remote_os = self._detected_os
+
+ if not job_id or not pid_file or not remote_os:
+ try:
+ if hasattr(self, "task_instance") and self.task_instance:
+                    job_info = self.task_instance.xcom_pull(key="ssh_remote_job")
+ if job_info:
+ job_id = job_info.get("job_id")
+ pid_file = job_info.get("pid_file")
+ remote_os = job_info.get("remote_os")
+ except Exception as e:
+ self.log.debug("Could not retrieve job info from XCom: %s", e)
+
+ if not job_id or not pid_file or not remote_os:
+ self.log.info("No active job information available for kill")
+ return
+
+ self.log.info("Attempting to kill remote job: %s", job_id)
+ try:
+ if remote_os == "posix":
+ kill_cmd = build_posix_kill_command(pid_file)
+ else:
+ kill_cmd = build_windows_kill_command(pid_file)
+
+ with self.ssh_hook.get_conn() as ssh_client:
+ self.ssh_hook.exec_ssh_client_command(
+ ssh_client,
+ kill_cmd,
+ get_pty=False,
+ environment=None,
+ timeout=30,
+ )
+ self.log.info("Kill command sent to remote process")
+ except Exception as e:
+ self.log.warning("Failed to kill remote process: %s", e)
diff --git a/providers/ssh/src/airflow/providers/ssh/triggers/__init__.py b/providers/ssh/src/airflow/providers/ssh/triggers/__init__.py
new file mode 100644
index 00000000000..217e5db9607
--- /dev/null
+++ b/providers/ssh/src/airflow/providers/ssh/triggers/__init__.py
@@ -0,0 +1,17 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
diff --git a/providers/ssh/src/airflow/providers/ssh/triggers/ssh_remote_job.py b/providers/ssh/src/airflow/providers/ssh/triggers/ssh_remote_job.py
new file mode 100644
index 00000000000..0d4072c1ca4
--- /dev/null
+++ b/providers/ssh/src/airflow/providers/ssh/triggers/ssh_remote_job.py
@@ -0,0 +1,271 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""SSH Remote Job Trigger for deferrable execution."""
+
+from __future__ import annotations
+
+import asyncio
+from collections.abc import AsyncIterator
+from typing import Any, Literal
+
+import tenacity
+
+from airflow.providers.ssh.hooks.ssh import SSHHookAsync
+from airflow.providers.ssh.utils.remote_job import (
+ build_posix_completion_check_command,
+ build_posix_file_size_command,
+ build_posix_log_tail_command,
+ build_windows_completion_check_command,
+ build_windows_file_size_command,
+ build_windows_log_tail_command,
+)
+from airflow.triggers.base import BaseTrigger, TriggerEvent
+
+
+class SSHRemoteJobTrigger(BaseTrigger):
+ """
+ Trigger that monitors a remote SSH job and streams logs.
+
+ This trigger polls the remote host to check job completion status
+ and reads log output incrementally.
+
+ :param ssh_conn_id: SSH connection ID from Airflow Connections
+ :param remote_host: Optional override for the remote host
+ :param job_id: Unique identifier for the remote job
+ :param job_dir: Remote directory containing job artifacts
+ :param log_file: Path to the log file on the remote host
+ :param exit_code_file: Path to the exit code file on the remote host
+    :param remote_os: Operating system of the remote host ('posix' or 'windows')
+ :param poll_interval: Seconds between polling attempts
+ :param log_chunk_size: Maximum bytes to read per poll
+ :param log_offset: Current byte offset in the log file
+ """
+
+ def __init__(
+ self,
+ ssh_conn_id: str,
+ remote_host: str | None,
+ job_id: str,
+ job_dir: str,
+ log_file: str,
+ exit_code_file: str,
+ remote_os: Literal["posix", "windows"],
+ poll_interval: int = 5,
+ log_chunk_size: int = 65536,
+ log_offset: int = 0,
+ command_timeout: float = 30.0,
+ ) -> None:
+ super().__init__()
+ self.ssh_conn_id = ssh_conn_id
+ self.remote_host = remote_host
+ self.job_id = job_id
+ self.job_dir = job_dir
+ self.log_file = log_file
+ self.exit_code_file = exit_code_file
+ self.remote_os = remote_os
+ self.poll_interval = poll_interval
+ self.log_chunk_size = log_chunk_size
+ self.log_offset = log_offset
+ self.command_timeout = command_timeout
+
+ def serialize(self) -> tuple[str, dict[str, Any]]:
+ """Serialize the trigger for storage."""
+ return (
+            "airflow.providers.ssh.triggers.ssh_remote_job.SSHRemoteJobTrigger",
+ {
+ "ssh_conn_id": self.ssh_conn_id,
+ "remote_host": self.remote_host,
+ "job_id": self.job_id,
+ "job_dir": self.job_dir,
+ "log_file": self.log_file,
+ "exit_code_file": self.exit_code_file,
+ "remote_os": self.remote_os,
+ "poll_interval": self.poll_interval,
+ "log_chunk_size": self.log_chunk_size,
+ "log_offset": self.log_offset,
+ "command_timeout": self.command_timeout,
+ },
+ )
+
+ def _get_hook(self) -> SSHHookAsync:
+ """Create the async SSH hook."""
+ return SSHHookAsync(
+ ssh_conn_id=self.ssh_conn_id,
+ host=self.remote_host,
+ )
+
+ @tenacity.retry(
+ stop=tenacity.stop_after_attempt(3),
+ wait=tenacity.wait_exponential(multiplier=1, min=1, max=10),
+        retry=tenacity.retry_if_exception_type((OSError, TimeoutError, ConnectionError)),
+ reraise=True,
+ )
+ async def _check_completion(self, hook: SSHHookAsync) -> int | None:
+ """
+ Check if the remote job has completed.
+
+        Retries transient network errors up to 3 times with exponential backoff.
+
+ :return: Exit code if completed, None if still running
+ """
+ if self.remote_os == "posix":
+ cmd = build_posix_completion_check_command(self.exit_code_file)
+ else:
+ cmd = build_windows_completion_check_command(self.exit_code_file)
+
+ try:
+            _, stdout, _ = await hook.run_command(cmd, timeout=self.command_timeout)
+            stdout = stdout.strip()
+            if stdout and stdout.isdigit():
+                return int(stdout)
+        except (OSError, TimeoutError, ConnectionError) as e:
+            self.log.warning("Transient error checking completion (will retry): %s", e)
+ raise
+ except Exception as e:
+ self.log.warning("Error checking completion status: %s", e)
+ return None
+
+ @tenacity.retry(
+ stop=tenacity.stop_after_attempt(3),
+ wait=tenacity.wait_exponential(multiplier=1, min=1, max=10),
+        retry=tenacity.retry_if_exception_type((OSError, TimeoutError, ConnectionError)),
+ reraise=True,
+ )
+ async def _get_log_size(self, hook: SSHHookAsync) -> int:
+ """
+ Get the current size of the log file in bytes.
+
+        Retries transient network errors up to 3 times with exponential backoff.
+ """
+ if self.remote_os == "posix":
+ cmd = build_posix_file_size_command(self.log_file)
+ else:
+ cmd = build_windows_file_size_command(self.log_file)
+
+ try:
+            _, stdout, _ = await hook.run_command(cmd, timeout=self.command_timeout)
+            stdout = stdout.strip()
+            if stdout and stdout.isdigit():
+                return int(stdout)
+        except (OSError, TimeoutError, ConnectionError) as e:
+            self.log.warning("Transient error getting log size (will retry): %s", e)
+ raise
+ except Exception as e:
+ self.log.warning("Error getting log file size: %s", e)
+ return 0
+
+ @tenacity.retry(
+ stop=tenacity.stop_after_attempt(3),
+ wait=tenacity.wait_exponential(multiplier=1, min=1, max=10),
+        retry=tenacity.retry_if_exception_type((OSError, TimeoutError, ConnectionError)),
+ reraise=True,
+ )
+ async def _read_log_chunk(self, hook: SSHHookAsync) -> tuple[str, int]:
+ """
+ Read a chunk of logs from the current offset.
+
+        Retries transient network errors up to 3 times with exponential backoff.
+
+ :return: Tuple of (log_chunk, new_offset)
+ """
+ file_size = await self._get_log_size(hook)
+ if file_size <= self.log_offset:
+ return "", self.log_offset
+
+ bytes_available = file_size - self.log_offset
+ bytes_to_read = min(bytes_available, self.log_chunk_size)
+
+ if self.remote_os == "posix":
+            cmd = build_posix_log_tail_command(self.log_file, self.log_offset, bytes_to_read)
+        else:
+            cmd = build_windows_log_tail_command(self.log_file, self.log_offset, bytes_to_read)
+
+ try:
+            exit_code, stdout, _ = await hook.run_command(cmd, timeout=self.command_timeout)
+
+            # Advance offset by bytes requested, not decoded string length
+            new_offset = self.log_offset + bytes_to_read if stdout else self.log_offset
+
+            return stdout, new_offset
+        except (OSError, TimeoutError, ConnectionError) as e:
+            self.log.warning("Transient error reading logs (will retry): %s", e)
+ raise
+ except Exception as e:
+ self.log.warning("Error reading log chunk: %s", e)
+ return "", self.log_offset
+
+ async def run(self) -> AsyncIterator[TriggerEvent]:
+ """
+ Poll the remote job status and yield events with log chunks.
+
+ This method runs in a loop, checking the job status and reading
+ logs at each poll interval. It yields a TriggerEvent each time
+ with the current status and any new log output.
+ """
+ hook = self._get_hook()
+
+ while True:
+ try:
+ exit_code = await self._check_completion(hook)
+ log_chunk, new_offset = await self._read_log_chunk(hook)
+
+ base_event = {
+ "job_id": self.job_id,
+ "job_dir": self.job_dir,
+ "log_file": self.log_file,
+ "exit_code_file": self.exit_code_file,
+ "remote_os": self.remote_os,
+ }
+
+ if exit_code is not None:
+ yield TriggerEvent(
+ {
+ **base_event,
+                            "status": "success" if exit_code == 0 else "failed",
+                            "done": True,
+                            "exit_code": exit_code,
+                            "log_chunk": log_chunk,
+                            "log_offset": new_offset,
+                            "message": f"Job completed with exit code {exit_code}",
+ }
+ )
+ return
+
+ self.log_offset = new_offset
+ if log_chunk:
+ self.log.info("%s", log_chunk.rstrip())
+ await asyncio.sleep(self.poll_interval)
+
+ except Exception as e:
+ self.log.exception("Error in SSH remote job trigger")
+ yield TriggerEvent(
+ {
+ "job_id": self.job_id,
+ "job_dir": self.job_dir,
+ "log_file": self.log_file,
+ "exit_code_file": self.exit_code_file,
+ "remote_os": self.remote_os,
+ "status": "error",
+ "done": True,
+ "exit_code": None,
+ "log_chunk": "",
+ "log_offset": self.log_offset,
+ "message": f"Trigger error: {e}",
+ }
+ )
+ return
diff --git a/providers/ssh/src/airflow/providers/ssh/utils/__init__.py b/providers/ssh/src/airflow/providers/ssh/utils/__init__.py
new file mode 100644
index 00000000000..217e5db9607
--- /dev/null
+++ b/providers/ssh/src/airflow/providers/ssh/utils/__init__.py
@@ -0,0 +1,17 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
diff --git a/providers/ssh/src/airflow/providers/ssh/utils/remote_job.py b/providers/ssh/src/airflow/providers/ssh/utils/remote_job.py
new file mode 100644
index 00000000000..d761896651d
--- /dev/null
+++ b/providers/ssh/src/airflow/providers/ssh/utils/remote_job.py
@@ -0,0 +1,448 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Utilities for SSH remote job execution."""
+
+from __future__ import annotations
+
+import base64
+import re
+import secrets
+import string
+from dataclasses import dataclass
+from typing import Literal
+
+POSIX_DEFAULT_BASE_DIR = "/tmp/airflow-ssh-jobs"
+WINDOWS_DEFAULT_BASE_DIR = "$env:TEMP\\airflow-ssh-jobs"
+
+
+def _validate_job_dir(job_dir: str, remote_os: Literal["posix", "windows"]) -> None:
+ """
+ Validate that job_dir is under the expected base directory.
+
+ :param job_dir: The job directory path to validate
+ :param remote_os: Operating system type
+ :raises ValueError: If job_dir doesn't start with the expected base path
+ """
+ if remote_os == "posix":
+ expected_prefix = POSIX_DEFAULT_BASE_DIR + "/"
+ else:
+ expected_prefix = WINDOWS_DEFAULT_BASE_DIR + "\\"
+
+ if not job_dir.startswith(expected_prefix):
+ raise ValueError(
+            f"Invalid job directory '{job_dir}'. Expected path under '{expected_prefix[:-1]}' for safety."
+ )
+
+
+def _validate_env_var_name(name: str) -> None:
+ """
+ Validate environment variable name for security.
+
+ :param name: Environment variable name
+ :raises ValueError: If name contains dangerous characters
+ """
+ if not name:
+ raise ValueError("Environment variable name cannot be empty")
+
+ if not re.match(r"^[a-zA-Z_][a-zA-Z0-9_]*$", name):
+ raise ValueError(
+ f"Invalid environment variable name '{name}'. "
+ "Only alphanumeric characters and underscores are allowed, "
+ "and the name must start with a letter or underscore."
+ )
+
+
+def generate_job_id(
+ dag_id: str,
+ task_id: str,
+ run_id: str,
+ try_number: int,
+ suffix_length: int = 8,
+) -> str:
+ """
+ Generate a unique job ID for remote execution.
+
+    Creates a deterministic identifier from the task context with a random suffix
+ to ensure uniqueness across retries and potential race conditions.
+
+ :param dag_id: The DAG identifier
+ :param task_id: The task identifier
+ :param run_id: The run identifier
+ :param try_number: The attempt number
+ :param suffix_length: Length of random suffix (default 8)
+ :return: Sanitized job ID string
+ """
+
+ def sanitize(value: str) -> str:
+ return re.sub(r"[^a-zA-Z0-9]", "_", value)[:50]
+
+ sanitized_dag = sanitize(dag_id)
+ sanitized_task = sanitize(task_id)
+ sanitized_run = sanitize(run_id)
+
+ alphabet = string.ascii_lowercase + string.digits
+ suffix = "".join(secrets.choice(alphabet) for _ in range(suffix_length))
+
+    return f"af_{sanitized_dag}_{sanitized_task}_{sanitized_run}_try{try_number}_{suffix}"
+
+
+@dataclass
+class RemoteJobPaths:
+ """Paths for remote job artifacts on the target system."""
+
+ job_id: str
+ remote_os: Literal["posix", "windows"]
+ base_dir: str | None = None
+
+ def __post_init__(self):
+ if self.base_dir is None:
+ if self.remote_os == "posix":
+ self.base_dir = POSIX_DEFAULT_BASE_DIR
+ else:
+ self.base_dir = WINDOWS_DEFAULT_BASE_DIR
+
+ @property
+ def sep(self) -> str:
+ """Path separator for the remote OS."""
+ return "\\" if self.remote_os == "windows" else "/"
+
+ @property
+ def job_dir(self) -> str:
+ """Directory containing all job artifacts."""
+ return f"{self.base_dir}{self.sep}{self.job_id}"
+
+ @property
+ def log_file(self) -> str:
+ """Path to stdout/stderr log file."""
+ return f"{self.job_dir}{self.sep}stdout.log"
+
+ @property
+ def exit_code_file(self) -> str:
+ """Path to exit code file (written on completion)."""
+ return f"{self.job_dir}{self.sep}exit_code"
+
+ @property
+ def exit_code_tmp_file(self) -> str:
+ """Temporary exit code file (for atomic write)."""
+ return f"{self.job_dir}{self.sep}exit_code.tmp"
+
+ @property
+ def pid_file(self) -> str:
+ """Path to PID file for the background process."""
+ return f"{self.job_dir}{self.sep}pid"
+
+ @property
+ def status_file(self) -> str:
+ """Path to optional status file for progress updates."""
+ return f"{self.job_dir}{self.sep}status"
+
+
+def build_posix_wrapper_command(
+ command: str,
+ paths: RemoteJobPaths,
+ environment: dict[str, str] | None = None,
+) -> str:
+ """
+ Build a POSIX shell wrapper that runs the command detached via nohup.
+
+ The wrapper:
+ - Creates the job directory
+ - Starts the command in the background with nohup
+ - Redirects stdout/stderr to the log file
+ - Writes the exit code atomically on completion
+ - Writes the PID for potential cancellation
+
+ :param command: The command to execute
+ :param paths: RemoteJobPaths instance with all paths
+ :param environment: Optional environment variables to set
+ :return: Shell command string to submit via SSH
+ """
+ env_exports = ""
+ if environment:
+ for key, value in environment.items():
+ _validate_env_var_name(key)
+ escaped_value = value.replace("'", "'\"'\"'")
+ env_exports += f"export {key}='{escaped_value}'\n"
+
+ escaped_command = command.replace("'", "'\"'\"'")
+
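+    # Note on quoting: the body handed to ``nohup bash -c '...'`` below is single-quoted; each
+    # "'"$var"'" sequence closes that quote so the non-detached part of the wrapper (which assigns
+    # $log_file, $exit_code_tmp, etc.) expands the paths before the detached child starts.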
+ wrapper = f"""set -euo pipefail
+job_dir='{paths.job_dir}'
+log_file='{paths.log_file}'
+exit_code_file='{paths.exit_code_file}'
+exit_code_tmp='{paths.exit_code_tmp_file}'
+pid_file='{paths.pid_file}'
+status_file='{paths.status_file}'
+
+mkdir -p "$job_dir"
+: > "$log_file"
+
+nohup bash -c '
+set +e
+export LOG_FILE="'"$log_file"'"
+export STATUS_FILE="'"$status_file"'"
+{env_exports}{escaped_command} >>"'"$log_file"'" 2>&1
+ec=$?
+echo -n "$ec" > "'"$exit_code_tmp"'"
+mv "'"$exit_code_tmp"'" "'"$exit_code_file"'"
+exit 0
+' >/dev/null 2>&1 &
+
+echo -n $! > "$pid_file"
+echo "{paths.job_id}"
+"""
+ return wrapper
+
+
+def build_windows_wrapper_command(
+ command: str,
+ paths: RemoteJobPaths,
+ environment: dict[str, str] | None = None,
+) -> str:
+ """
+    Build a PowerShell wrapper that runs the command detached via Start-Process.
+
+ The wrapper:
+ - Creates the job directory
+ - Starts the command in a new detached PowerShell process
+ - Redirects stdout/stderr to the log file
+ - Writes the exit code atomically on completion
+ - Writes the PID for potential cancellation
+
+ :param command: The command to execute (PowerShell script path or command)
+ :param paths: RemoteJobPaths instance with all paths
+ :param environment: Optional environment variables to set
+ :return: PowerShell command string to submit via SSH
+ """
+ env_setup = ""
+ if environment:
+ for key, value in environment.items():
+ _validate_env_var_name(key)
+ escaped_value = value.replace("'", "''")
+ env_setup += f"$env:{key} = '{escaped_value}'; "
+
+ def ps_escape(s: str) -> str:
+ return s.replace("'", "''")
+
+ job_dir = ps_escape(paths.job_dir)
+ log_file = ps_escape(paths.log_file)
+ exit_code_file = ps_escape(paths.exit_code_file)
+ exit_code_tmp = ps_escape(paths.exit_code_tmp_file)
+ pid_file = ps_escape(paths.pid_file)
+ status_file = ps_escape(paths.status_file)
+ escaped_command = ps_escape(command)
+ job_id = ps_escape(paths.job_id)
+
+ child_script = f"""$ErrorActionPreference = 'Continue'
+$env:LOG_FILE = '{log_file}'
+$env:STATUS_FILE = '{status_file}'
+{env_setup}
+{escaped_command}
+$ec = $LASTEXITCODE
+if ($null -eq $ec) {{ $ec = 0 }}
+Set-Content -NoNewline -Path '{exit_code_tmp}' -Value $ec
+Move-Item -Force -Path '{exit_code_tmp}' -Destination '{exit_code_file}'
+"""
+ child_script_bytes = child_script.encode("utf-16-le")
+ encoded_script = base64.b64encode(child_script_bytes).decode("ascii")
+
+ wrapper = f"""$jobDir = '{job_dir}'
+New-Item -ItemType Directory -Force -Path $jobDir | Out-Null
+$log = '{log_file}'
+'' | Set-Content -Path $log
+
+$p = Start-Process -FilePath 'powershell.exe' -ArgumentList @('-NoProfile', '-NonInteractive', '-EncodedCommand', '{encoded_script}') -RedirectStandardOutput $log -RedirectStandardError $log -PassThru -WindowStyle Hidden
+Set-Content -NoNewline -Path '{pid_file}' -Value $p.Id
+Write-Output '{job_id}'
+"""
+ wrapper_bytes = wrapper.encode("utf-16-le")
+ encoded_wrapper = base64.b64encode(wrapper_bytes).decode("ascii")
+
+    return f"powershell.exe -NoProfile -NonInteractive -EncodedCommand {encoded_wrapper}"
+
+
+def build_posix_log_tail_command(log_file: str, offset: int, max_bytes: int) -> str:
+ """
+ Build a POSIX command to read log bytes from offset.
+
+ :param log_file: Path to the log file
+ :param offset: Byte offset to start reading from
+ :param max_bytes: Maximum bytes to read
+ :return: Shell command that outputs the log chunk
+ """
+ # tail -c +N is 1-indexed, so offset 0 means start at byte 1
+ tail_offset = offset + 1
+    return f"tail -c +{tail_offset} '{log_file}' 2>/dev/null | head -c {max_bytes} || true"
+
+
+def build_windows_log_tail_command(log_file: str, offset: int, max_bytes: int) -> str:
+ """
+ Build a PowerShell command to read log bytes from offset.
+
+ :param log_file: Path to the log file
+ :param offset: Byte offset to start reading from
+ :param max_bytes: Maximum bytes to read
+ :return: PowerShell command that outputs the log chunk
+ """
+ escaped_path = log_file.replace("'", "''")
+ script = f"""$path = '{escaped_path}'
+if (Test-Path $path) {{
+ try {{
+ $fs = [System.IO.File]::Open($path, 'Open', 'Read', 'ReadWrite')
+ $fs.Seek({offset}, [System.IO.SeekOrigin]::Begin) | Out-Null
+ $buf = New-Object byte[] {max_bytes}
+ $n = $fs.Read($buf, 0, $buf.Length)
+ $fs.Close()
+ if ($n -gt 0) {{
+ [System.Text.Encoding]::UTF8.GetString($buf, 0, $n)
+ }}
+ }} catch {{}}
+}}"""
+ script_bytes = script.encode("utf-16-le")
+ encoded_script = base64.b64encode(script_bytes).decode("ascii")
+ return f"powershell.exe -NoProfile -NonInteractive -EncodedCommand
{encoded_script}"
+
+
+def build_posix_file_size_command(file_path: str) -> str:
+ """
+ Build a POSIX command to get file size in bytes.
+
+ :param file_path: Path to the file
+ :return: Shell command that outputs the file size
+ """
+ return f"stat -c%s '{file_path}' 2>/dev/null || stat -f%z '{file_path}'
2>/dev/null || echo 0"
+
+
+def build_windows_file_size_command(file_path: str) -> str:
+ """
+ Build a PowerShell command to get file size in bytes.
+
+ :param file_path: Path to the file
+ :return: PowerShell command that outputs the file size
+ """
+ escaped_path = file_path.replace("'", "''")
+ script = f"""$path = '{escaped_path}'
+if (Test-Path $path) {{
+ (Get-Item $path).Length
+}} else {{
+ 0
+}}"""
+ script_bytes = script.encode("utf-16-le")
+ encoded_script = base64.b64encode(script_bytes).decode("ascii")
+ return f"powershell.exe -NoProfile -NonInteractive -EncodedCommand
{encoded_script}"
+
+
+def build_posix_completion_check_command(exit_code_file: str) -> str:
+ """
+ Build a POSIX command to check if job completed and get exit code.
+
+ :param exit_code_file: Path to the exit code file
+ :return: Shell command that outputs exit code if done, empty otherwise
+ """
+ return f"test -s '{exit_code_file}' && cat '{exit_code_file}' || true"
+
+
+def build_windows_completion_check_command(exit_code_file: str) -> str:
+ """
+ Build a PowerShell command to check if job completed and get exit code.
+
+ :param exit_code_file: Path to the exit code file
+ :return: PowerShell command that outputs exit code if done, empty otherwise
+ """
+ escaped_path = exit_code_file.replace("'", "''")
+ script = f"""$path = '{escaped_path}'
+if (Test-Path $path) {{
+ $txt = Get-Content -Raw -Path $path
+ if ($txt -match '^[0-9]+$') {{ $txt.Trim() }}
+}}"""
+ script_bytes = script.encode("utf-16-le")
+ encoded_script = base64.b64encode(script_bytes).decode("ascii")
+ return f"powershell.exe -NoProfile -NonInteractive -EncodedCommand
{encoded_script}"
+
+
+def build_posix_kill_command(pid_file: str) -> str:
+ """
+ Build a POSIX command to kill the remote process.
+
+ :param pid_file: Path to the PID file
+ :return: Shell command to kill the process
+ """
+ return f"test -f '{pid_file}' && kill $(cat '{pid_file}') 2>/dev/null ||
true"
+
+
+def build_windows_kill_command(pid_file: str) -> str:
+ """
+ Build a PowerShell command to kill the remote process.
+
+ :param pid_file: Path to the PID file
+ :return: PowerShell command to kill the process
+ """
+ escaped_path = pid_file.replace("'", "''")
+ script = f"""$path = '{escaped_path}'
+if (Test-Path $path) {{
+    $procId = Get-Content $path
+    Stop-Process -Id $procId -Force -ErrorAction SilentlyContinue
+}}"""
+ script_bytes = script.encode("utf-16-le")
+ encoded_script = base64.b64encode(script_bytes).decode("ascii")
+ return f"powershell.exe -NoProfile -NonInteractive -EncodedCommand
{encoded_script}"
+
+
+def build_posix_cleanup_command(job_dir: str) -> str:
+ """
+ Build a POSIX command to clean up the job directory.
+
+ :param job_dir: Path to the job directory
+ :return: Shell command to remove the directory
+ :raises ValueError: If job_dir is not under the expected base directory
+ """
+ _validate_job_dir(job_dir, "posix")
+ return f"rm -rf '{job_dir}'"
+
+
+def build_windows_cleanup_command(job_dir: str) -> str:
+ """
+ Build a PowerShell command to clean up the job directory.
+
+ :param job_dir: Path to the job directory
+ :return: PowerShell command to remove the directory
+ :raises ValueError: If job_dir is not under the expected base directory
+ """
+ _validate_job_dir(job_dir, "windows")
+ escaped_path = job_dir.replace("'", "''")
+ script = f"Remove-Item -Recurse -Force -Path '{escaped_path}' -ErrorAction
SilentlyContinue"
+ script_bytes = script.encode("utf-16-le")
+ encoded_script = base64.b64encode(script_bytes).decode("ascii")
+ return f"powershell.exe -NoProfile -NonInteractive -EncodedCommand
{encoded_script}"
+
+
+def build_posix_os_detection_command() -> str:
+ """
+ Build a command to detect if the remote system is POSIX-compliant.
+
+ Returns the OS name (Linux, Darwin, FreeBSD, etc.) or UNKNOWN.
+ """
+ return "uname -s 2>/dev/null || echo UNKNOWN"
+
+
+def build_windows_os_detection_command() -> str:
+ """Build a command to detect if the remote system is Windows."""
+ script = '$PSVersionTable.PSVersion.Major; if ($?) { "WINDOWS" }'
+ script_bytes = script.encode("utf-16-le")
+ encoded_script = base64.b64encode(script_bytes).decode("ascii")
+ return f"powershell.exe -NoProfile -NonInteractive -EncodedCommand
{encoded_script}"
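
For illustration, a minimal sketch of how the helpers above compose into a job submission on a POSIX host, assuming a placeholder connection id ``my_ssh_conn`` and script path; the ``exec_ssh_client_command`` keyword arguments follow the existing ``SSHHook`` API and may need adjusting:

    from airflow.providers.ssh.hooks.ssh import SSHHook
    from airflow.providers.ssh.utils.remote_job import (
        RemoteJobPaths,
        build_posix_wrapper_command,
        generate_job_id,
    )

    hook = SSHHook(ssh_conn_id="my_ssh_conn")
    job_id = generate_job_id("my_dag", "my_task", "manual__2024-01-01", 1)
    paths = RemoteJobPaths(job_id=job_id, remote_os="posix")
    wrapper = build_posix_wrapper_command("/opt/scripts/long_job.sh", paths)

    with hook.get_conn() as client:
        # The wrapper detaches the job with nohup and echoes the job id on stdout.
        exit_status, stdout, stderr = hook.exec_ssh_client_command(
            client, wrapper, get_pty=False, environment=None
        )

The trigger then polls the exit-code file and tails the log file using the completion-check, file-size and log-tail commands defined above.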
diff --git a/providers/ssh/tests/unit/ssh/hooks/test_ssh_async.py
b/providers/ssh/tests/unit/ssh/hooks/test_ssh_async.py
new file mode 100644
index 00000000000..ee92c809be8
--- /dev/null
+++ b/providers/ssh/tests/unit/ssh/hooks/test_ssh_async.py
@@ -0,0 +1,172 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from unittest import mock
+
+import pytest
+
+from airflow.providers.ssh.hooks.ssh import SSHHookAsync
+
+
+class TestSSHHookAsync:
+ def test_init_with_conn_id(self):
+ """Test initialization with connection ID."""
+ hook = SSHHookAsync(ssh_conn_id="test_ssh_conn")
+ assert hook.ssh_conn_id == "test_ssh_conn"
+ assert hook.host is None
+ assert hook.port is None
+
+ def test_init_with_overrides(self):
+ """Test initialization with parameter overrides."""
+ hook = SSHHookAsync(
+ ssh_conn_id="test_ssh_conn",
+ host="custom.host.com",
+ port=2222,
+ username="testuser",
+ password="testpass",
+ )
+ assert hook.host == "custom.host.com"
+ assert hook.port == 2222
+ assert hook.username == "testuser"
+ assert hook.password == "testpass"
+
+ def test_init_default_known_hosts(self):
+ """Test default known_hosts path."""
+ hook = SSHHookAsync(ssh_conn_id="test_conn")
+ assert "known_hosts" in str(hook.known_hosts)
+
+ def test_parse_extras_key_file(self):
+ """Test parsing key_file from connection extras."""
+ hook = SSHHookAsync(ssh_conn_id="test_conn")
+ mock_conn = mock.MagicMock()
+ mock_conn.extra_dejson = {"key_file": "/path/to/key"}
+ mock_conn.host = "test.host"
+
+ hook._parse_extras(mock_conn)
+ assert hook.key_file == "/path/to/key"
+
+ def test_parse_extras_no_host_key_check(self):
+ """Test parsing no_host_key_check from connection extras."""
+ hook = SSHHookAsync(ssh_conn_id="test_conn")
+ mock_conn = mock.MagicMock()
+ mock_conn.extra_dejson = {"no_host_key_check": "true"}
+ mock_conn.host = "test.host"
+
+ hook._parse_extras(mock_conn)
+ assert hook.known_hosts == "none"
+
+ def test_parse_extras_host_key(self):
+ """Test parsing host_key from connection extras."""
+ hook = SSHHookAsync(ssh_conn_id="test_conn")
+ mock_conn = mock.MagicMock()
+ mock_conn.extra_dejson = {"host_key": "ssh-rsa AAAAB3...",
"no_host_key_check": "false"}
+ mock_conn.host = "test.host"
+
+ hook._parse_extras(mock_conn)
+ assert hook.known_hosts == b"test.host ssh-rsa AAAAB3..."
+
+ def test_parse_extras_host_key_with_no_check_raises(self):
+ """Test that host_key with no_host_key_check raises error."""
+ hook = SSHHookAsync(ssh_conn_id="test_conn")
+ mock_conn = mock.MagicMock()
+ mock_conn.extra_dejson = {
+ "host_key": "ssh-rsa AAAAB3...",
+ "no_host_key_check": "true",
+ }
+ mock_conn.host = "test.host"
+
+ with pytest.raises(ValueError, match="Host key check was skipped"):
+ hook._parse_extras(mock_conn)
+
+ def test_parse_extras_private_key(self):
+ """Test parsing private_key from connection extras."""
+ hook = SSHHookAsync(ssh_conn_id="test_conn")
+ mock_conn = mock.MagicMock()
+ test_key = "test-private-key-content"
+ mock_conn.extra_dejson = {"private_key": test_key}
+ mock_conn.host = "test.host"
+
+ hook._parse_extras(mock_conn)
+ assert hook.private_key == test_key
+
+ @pytest.mark.asyncio
+ async def test_get_conn_builds_config(self):
+ """Test that _get_conn builds correct connection config."""
+ hook = SSHHookAsync(
+ ssh_conn_id="test_conn",
+ host="test.host.com",
+ port=22,
+ username="testuser",
+ )
+
+ mock_conn_obj = mock.MagicMock()
+ mock_conn_obj.extra_dejson = {"no_host_key_check": "true"}
+ mock_conn_obj.host = None
+ mock_conn_obj.port = None
+ mock_conn_obj.login = None
+ mock_conn_obj.password = None
+ mock_conn_obj.extra = "{}"
+
+ mock_ssh_client = mock.AsyncMock()
+
+ with mock.patch("asgiref.sync.sync_to_async") as mock_sync:
+ mock_sync.return_value = mock.AsyncMock(return_value=mock_conn_obj)
+
+ with mock.patch("asyncssh.connect", new_callable=mock.AsyncMock)
as mock_connect:
+ mock_connect.return_value = mock_ssh_client
+ result = await hook._get_conn()
+
+ mock_connect.assert_called_once()
+ call_kwargs = mock_connect.call_args[1]
+ assert call_kwargs["host"] == "test.host.com"
+ assert call_kwargs["port"] == 22
+ assert call_kwargs["username"] == "testuser"
+ assert result == mock_ssh_client
+
+ @pytest.mark.asyncio
+ async def test_run_command(self):
+ """Test running a command."""
+ hook = SSHHookAsync(ssh_conn_id="test_conn")
+
+ mock_result = mock.MagicMock()
+ mock_result.exit_status = 0
+ mock_result.stdout = "output"
+ mock_result.stderr = ""
+
+ mock_conn = mock.AsyncMock()
+ mock_conn.run = mock.AsyncMock(return_value=mock_result)
+ mock_conn.__aenter__ = mock.AsyncMock(return_value=mock_conn)
+ mock_conn.__aexit__ = mock.AsyncMock(return_value=None)
+
+ with mock.patch.object(hook, "_get_conn", return_value=mock_conn):
+ exit_code, stdout, stderr = await hook.run_command("echo test")
+
+ assert exit_code == 0
+ assert stdout == "output"
+ assert stderr == ""
+        mock_conn.run.assert_called_once_with("echo test", timeout=None, check=False)
+
+ @pytest.mark.asyncio
+ async def test_run_command_output(self):
+ """Test running a command and getting output."""
+ hook = SSHHookAsync(ssh_conn_id="test_conn")
+
+ with mock.patch.object(hook, "run_command", return_value=(0, "test
output", "")):
+ output = await hook.run_command_output("echo test")
+ assert output == "test output"
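
A minimal usage sketch for ``SSHHookAsync``, assuming a working ``my_ssh_conn`` connection; ``run_command`` returns an ``(exit_code, stdout, stderr)`` tuple, while ``run_command_output`` returns stdout only:

    import asyncio

    from airflow.providers.ssh.hooks.ssh import SSHHookAsync


    async def check_remote() -> None:
        hook = SSHHookAsync(ssh_conn_id="my_ssh_conn")  # placeholder connection id
        exit_code, stdout, stderr = await hook.run_command("uname -s")
        print(exit_code, stdout, stderr)
        print(await hook.run_command_output("echo hello"))  # stdout only


    asyncio.run(check_remote())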
diff --git a/providers/ssh/tests/unit/ssh/operators/test_ssh_remote_job.py
b/providers/ssh/tests/unit/ssh/operators/test_ssh_remote_job.py
new file mode 100644
index 00000000000..871f3981393
--- /dev/null
+++ b/providers/ssh/tests/unit/ssh/operators/test_ssh_remote_job.py
@@ -0,0 +1,331 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from unittest import mock
+
+import pytest
+
+from airflow.exceptions import TaskDeferred
+from airflow.providers.common.compat.sdk import AirflowException, AirflowSkipException
+from airflow.providers.ssh.hooks.ssh import SSHHook
+from airflow.providers.ssh.operators.ssh_remote_job import SSHRemoteJobOperator
+from airflow.providers.ssh.triggers.ssh_remote_job import SSHRemoteJobTrigger
+
+
+class TestSSHRemoteJobOperator:
+ @pytest.fixture(autouse=True)
+ def mock_ssh_hook(self):
+ """Mock the SSHHook to avoid connection lookup."""
+ with mock.patch.object(
+ SSHRemoteJobOperator, "ssh_hook", new_callable=mock.PropertyMock
+ ) as mock_hook_prop:
+ mock_hook = mock.create_autospec(SSHHook, instance=True)
+ mock_hook.remote_host = "test.host.com"
+ mock_ssh_client = mock.MagicMock()
+            mock_hook.get_conn.return_value.__enter__.return_value = mock_ssh_client
+ mock_hook.get_conn.return_value.__exit__.return_value = None
+ mock_hook.exec_ssh_client_command.return_value = (0, b"", b"")
+ mock_hook_prop.return_value = mock_hook
+ self.mock_hook = mock_hook
+ self.mock_hook_prop = mock_hook_prop
+ yield
+
+ def test_init_default_values(self):
+ """Test operator initialization with default values."""
+ op = SSHRemoteJobOperator(
+ task_id="test_task",
+ ssh_conn_id="test_conn",
+ command="/path/to/script.sh",
+ )
+ assert op.ssh_conn_id == "test_conn"
+ assert op.command == "/path/to/script.sh"
+ assert op.poll_interval == 5
+ assert op.log_chunk_size == 65536
+ assert op.cleanup == "never"
+ assert op.remote_os == "auto"
+
+ def test_init_custom_values(self):
+ """Test operator initialization with custom values."""
+ op = SSHRemoteJobOperator(
+ task_id="test_task",
+ ssh_conn_id="test_conn",
+ command="/path/to/script.sh",
+ remote_host="custom.host.com",
+ poll_interval=10,
+ log_chunk_size=32768,
+ timeout=3600,
+ cleanup="on_success",
+ remote_os="posix",
+ skip_on_exit_code=[42, 43],
+ )
+ assert op.remote_host == "custom.host.com"
+ assert op.poll_interval == 10
+ assert op.log_chunk_size == 32768
+ assert op.timeout == 3600
+ assert op.cleanup == "on_success"
+ assert op.remote_os == "posix"
+ assert 42 in op.skip_on_exit_code
+ assert 43 in op.skip_on_exit_code
+
+ def test_template_fields(self):
+ """Test that template fields are defined correctly."""
+ op = SSHRemoteJobOperator(
+ task_id="test_task",
+ ssh_conn_id="test_conn",
+ command="/path/to/script.sh",
+ )
+ assert "command" in op.template_fields
+ assert "environment" in op.template_fields
+ assert "remote_host" in op.template_fields
+ assert "remote_base_dir" in op.template_fields
+
+ def test_execute_defers_to_trigger(self):
+ """Test that execute submits job and defers to trigger."""
+ self.mock_hook.exec_ssh_client_command.return_value = (
+ 0,
+ b"af_test_dag_test_task_run1_try1_abc123",
+ b"",
+ )
+
+ op = SSHRemoteJobOperator(
+ task_id="test_task",
+ ssh_conn_id="test_conn",
+ command="/path/to/script.sh",
+ remote_os="posix",
+ )
+
+ mock_ti = mock.MagicMock()
+ mock_ti.dag_id = "test_dag"
+ mock_ti.task_id = "test_task"
+ mock_ti.run_id = "run1"
+ mock_ti.try_number = 1
+ context = {"ti": mock_ti}
+
+ with pytest.raises(TaskDeferred) as exc_info:
+ op.execute(context)
+
+ assert isinstance(exc_info.value.trigger, SSHRemoteJobTrigger)
+ assert exc_info.value.method_name == "execute_complete"
+
+ def test_execute_raises_if_no_command(self):
+ """Test that execute raises if command is not specified."""
+ op = SSHRemoteJobOperator(
+ task_id="test_task",
+ ssh_conn_id="test_conn",
+ command="",
+ )
+ # Set command to empty after init
+ op.command = ""
+
+ with pytest.raises(AirflowException, match="command not specified"):
+ op.execute({})
+
+ @mock.patch.object(SSHRemoteJobOperator, "defer")
+ def test_execute_complete_re_defers_if_not_done(self, mock_defer):
+ """Test that execute_complete re-defers if job is not done."""
+ op = SSHRemoteJobOperator(
+ task_id="test_task",
+ ssh_conn_id="test_conn",
+ command="/path/to/script.sh",
+ )
+
+ event = {
+ "done": False,
+ "status": "running",
+ "job_id": "test_job_123",
+ "job_dir": "/tmp/airflow-ssh-jobs/test_job_123",
+ "log_file": "/tmp/airflow-ssh-jobs/test_job_123/stdout.log",
+ "exit_code_file": "/tmp/airflow-ssh-jobs/test_job_123/exit_code",
+ "remote_os": "posix",
+ "log_chunk": "Some output\n",
+ "log_offset": 100,
+ "exit_code": None,
+ }
+
+ op.execute_complete({}, event)
+
+ mock_defer.assert_called_once()
+ call_kwargs = mock_defer.call_args[1]
+ assert isinstance(call_kwargs["trigger"], SSHRemoteJobTrigger)
+ assert call_kwargs["trigger"].log_offset == 100
+
+ def test_execute_complete_success(self):
+ """Test execute_complete with successful completion."""
+ op = SSHRemoteJobOperator(
+ task_id="test_task",
+ ssh_conn_id="test_conn",
+ command="/path/to/script.sh",
+ )
+
+ event = {
+ "done": True,
+ "status": "success",
+ "exit_code": 0,
+ "job_id": "test_job_123",
+ "job_dir": "/tmp/airflow-ssh-jobs/test_job_123",
+ "log_file": "/tmp/airflow-ssh-jobs/test_job_123/stdout.log",
+ "exit_code_file": "/tmp/airflow-ssh-jobs/test_job_123/exit_code",
+ "log_chunk": "Final output\n",
+ "log_offset": 200,
+ "remote_os": "posix",
+ }
+
+ # Should complete without exception
+ op.execute_complete({}, event)
+
+ def test_execute_complete_failure(self):
+ """Test execute_complete with non-zero exit code."""
+ op = SSHRemoteJobOperator(
+ task_id="test_task",
+ ssh_conn_id="test_conn",
+ command="/path/to/script.sh",
+ )
+
+ event = {
+ "done": True,
+ "status": "failed",
+ "exit_code": 1,
+ "job_id": "test_job_123",
+ "job_dir": "/tmp/airflow-ssh-jobs/test_job_123",
+ "log_file": "/tmp/airflow-ssh-jobs/test_job_123/stdout.log",
+ "exit_code_file": "/tmp/airflow-ssh-jobs/test_job_123/exit_code",
+ "log_chunk": "Error output\n",
+ "log_offset": 200,
+ "remote_os": "posix",
+ }
+
+ with pytest.raises(AirflowException, match="exit code: 1"):
+ op.execute_complete({}, event)
+
+ def test_execute_complete_skip_on_exit_code(self):
+ """Test execute_complete skips on configured exit code."""
+ op = SSHRemoteJobOperator(
+ task_id="test_task",
+ ssh_conn_id="test_conn",
+ command="/path/to/script.sh",
+ skip_on_exit_code=42,
+ )
+
+ event = {
+ "done": True,
+ "status": "failed",
+ "exit_code": 42,
+ "job_id": "test_job_123",
+ "job_dir": "/tmp/airflow-ssh-jobs/test_job_123",
+ "log_file": "/tmp/airflow-ssh-jobs/test_job_123/stdout.log",
+ "exit_code_file": "/tmp/airflow-ssh-jobs/test_job_123/exit_code",
+ "log_chunk": "",
+ "log_offset": 0,
+ "remote_os": "posix",
+ }
+
+ with pytest.raises(AirflowSkipException):
+ op.execute_complete({}, event)
+
+ def test_execute_complete_with_cleanup(self):
+ """Test execute_complete performs cleanup when configured."""
+ op = SSHRemoteJobOperator(
+ task_id="test_task",
+ ssh_conn_id="test_conn",
+ command="/path/to/script.sh",
+ cleanup="on_success",
+ )
+
+ event = {
+ "done": True,
+ "status": "success",
+ "exit_code": 0,
+ "job_id": "test_job_123",
+ "job_dir": "/tmp/airflow-ssh-jobs/test_job_123",
+ "log_file": "/tmp/airflow-ssh-jobs/test_job_123/stdout.log",
+ "exit_code_file": "/tmp/airflow-ssh-jobs/test_job_123/exit_code",
+ "log_chunk": "",
+ "log_offset": 0,
+ "remote_os": "posix",
+ }
+
+ op.execute_complete({}, event)
+
+ # Verify cleanup command was executed
+ self.mock_hook.exec_ssh_client_command.assert_called_once()
+ call_args = self.mock_hook.exec_ssh_client_command.call_args
+ assert "rm -rf" in call_args[0][1]
+
+ def test_on_kill(self):
+ """Test on_kill attempts to kill remote process."""
+ op = SSHRemoteJobOperator(
+ task_id="test_task",
+ ssh_conn_id="test_conn",
+ command="/path/to/script.sh",
+ )
+
+ # Simulate that execute was called
+ op._job_id = "test_job_123"
+ op._detected_os = "posix"
+ from airflow.providers.ssh.utils.remote_job import RemoteJobPaths
+
+ op._paths = RemoteJobPaths(job_id="test_job_123", remote_os="posix")
+
+ op.on_kill()
+
+ # Verify kill command was executed
+ self.mock_hook.exec_ssh_client_command.assert_called_once()
+ call_args = self.mock_hook.exec_ssh_client_command.call_args
+ assert "kill" in call_args[0][1]
+
+ def test_on_kill_after_rehydration(self):
+ """Test on_kill retrieves job info from XCom after operator
rehydration."""
+ op = SSHRemoteJobOperator(
+ task_id="test_task",
+ ssh_conn_id="test_conn",
+ command="/path/to/script.sh",
+ )
+
+ # Instance variables are None (simulating rehydration)
+ # Don't set _job_id, _paths, _detected_os
+
+ # Mock task_instance with XCom data
+ mock_ti = mock.MagicMock()
+ mock_ti.xcom_pull.return_value = {
+ "job_id": "test_job_123",
+ "pid_file": "/tmp/airflow-ssh-jobs/test_job_123/pid",
+ "remote_os": "posix",
+ }
+ op.task_instance = mock_ti
+
+ op.on_kill()
+
+ # Verify XCom was called to get job info
+ mock_ti.xcom_pull.assert_called_once_with(key="ssh_remote_job")
+
+ # Verify kill command was executed
+ self.mock_hook.exec_ssh_client_command.assert_called_once()
+ call_args = self.mock_hook.exec_ssh_client_command.call_args
+ assert "kill" in call_args[0][1]
+
+ def test_on_kill_no_active_job(self):
+ """Test on_kill does nothing if no active job."""
+ op = SSHRemoteJobOperator(
+ task_id="test_task",
+ ssh_conn_id="test_conn",
+ command="/path/to/script.sh",
+ )
+
+ # Should not raise even without active job
+ op.on_kill()
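
The parameters exercised above map directly onto DAG usage; a sketch with placeholder connection id and script path (see the operator guide added in this commit for the canonical examples):

    from airflow.providers.ssh.operators.ssh_remote_job import SSHRemoteJobOperator

    run_remote = SSHRemoteJobOperator(
        task_id="run_remote_script",
        ssh_conn_id="my_ssh_conn",           # placeholder connection id
        command="/opt/scripts/long_job.sh",  # placeholder script path
        remote_os="posix",                   # default is "auto"
        poll_interval=30,                    # seconds between completion checks
        cleanup="on_success",                # default is "never"
        skip_on_exit_code=[99],              # exit codes that mark the task skipped
    )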
diff --git a/providers/ssh/tests/unit/ssh/triggers/__init__.py
b/providers/ssh/tests/unit/ssh/triggers/__init__.py
new file mode 100644
index 00000000000..217e5db9607
--- /dev/null
+++ b/providers/ssh/tests/unit/ssh/triggers/__init__.py
@@ -0,0 +1,17 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
diff --git a/providers/ssh/tests/unit/ssh/triggers/test_ssh_remote_job.py
b/providers/ssh/tests/unit/ssh/triggers/test_ssh_remote_job.py
new file mode 100644
index 00000000000..67d9672a8cb
--- /dev/null
+++ b/providers/ssh/tests/unit/ssh/triggers/test_ssh_remote_job.py
@@ -0,0 +1,197 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from unittest import mock
+
+import pytest
+
+from airflow.providers.ssh.triggers.ssh_remote_job import SSHRemoteJobTrigger
+
+
+class TestSSHRemoteJobTrigger:
+ def test_serialization(self):
+ """Test that the trigger can be serialized correctly."""
+ trigger = SSHRemoteJobTrigger(
+ ssh_conn_id="test_conn",
+ remote_host="test.example.com",
+ job_id="test_job_123",
+ job_dir="/tmp/airflow-ssh-jobs/test_job_123",
+ log_file="/tmp/airflow-ssh-jobs/test_job_123/stdout.log",
+ exit_code_file="/tmp/airflow-ssh-jobs/test_job_123/exit_code",
+ remote_os="posix",
+ poll_interval=10,
+ log_chunk_size=32768,
+ log_offset=1000,
+ )
+
+ classpath, kwargs = trigger.serialize()
+
+        assert classpath == "airflow.providers.ssh.triggers.ssh_remote_job.SSHRemoteJobTrigger"
+ assert kwargs["ssh_conn_id"] == "test_conn"
+ assert kwargs["remote_host"] == "test.example.com"
+ assert kwargs["job_id"] == "test_job_123"
+ assert kwargs["job_dir"] == "/tmp/airflow-ssh-jobs/test_job_123"
+ assert kwargs["log_file"] ==
"/tmp/airflow-ssh-jobs/test_job_123/stdout.log"
+ assert kwargs["exit_code_file"] ==
"/tmp/airflow-ssh-jobs/test_job_123/exit_code"
+ assert kwargs["remote_os"] == "posix"
+ assert kwargs["poll_interval"] == 10
+ assert kwargs["log_chunk_size"] == 32768
+ assert kwargs["log_offset"] == 1000
+
+ def test_default_values(self):
+ """Test default parameter values."""
+ trigger = SSHRemoteJobTrigger(
+ ssh_conn_id="test_conn",
+ remote_host=None,
+ job_id="test_job",
+ job_dir="/tmp/job",
+ log_file="/tmp/job/stdout.log",
+ exit_code_file="/tmp/job/exit_code",
+ remote_os="posix",
+ )
+
+ assert trigger.poll_interval == 5
+ assert trigger.log_chunk_size == 65536
+ assert trigger.log_offset == 0
+
+ @pytest.mark.asyncio
+ async def test_run_job_completed_success(self):
+ """Test trigger when job completes successfully."""
+ trigger = SSHRemoteJobTrigger(
+ ssh_conn_id="test_conn",
+ remote_host=None,
+ job_id="test_job",
+ job_dir="/tmp/job",
+ log_file="/tmp/job/stdout.log",
+ exit_code_file="/tmp/job/exit_code",
+ remote_os="posix",
+ )
+
+ with mock.patch.object(trigger, "_check_completion", return_value=0):
+ with mock.patch.object(trigger, "_read_log_chunk",
return_value=("Final output\n", 100)):
+ events = []
+ async for event in trigger.run():
+ events.append(event)
+
+ assert len(events) == 1
+ assert events[0].payload["status"] == "success"
+ assert events[0].payload["done"] is True
+ assert events[0].payload["exit_code"] == 0
+ assert events[0].payload["log_chunk"] == "Final output\n"
+
+ @pytest.mark.asyncio
+ async def test_run_job_completed_failure(self):
+ """Test trigger when job completes with failure."""
+ trigger = SSHRemoteJobTrigger(
+ ssh_conn_id="test_conn",
+ remote_host=None,
+ job_id="test_job",
+ job_dir="/tmp/job",
+ log_file="/tmp/job/stdout.log",
+ exit_code_file="/tmp/job/exit_code",
+ remote_os="posix",
+ )
+
+ with mock.patch.object(trigger, "_check_completion", return_value=1):
+ with mock.patch.object(trigger, "_read_log_chunk",
return_value=("Error output\n", 50)):
+ events = []
+ async for event in trigger.run():
+ events.append(event)
+
+ assert len(events) == 1
+ assert events[0].payload["status"] == "failed"
+ assert events[0].payload["done"] is True
+ assert events[0].payload["exit_code"] == 1
+
+ @pytest.mark.asyncio
+ async def test_run_job_polls_until_completion(self):
+ """Test trigger polls without yielding until job completes."""
+ trigger = SSHRemoteJobTrigger(
+ ssh_conn_id="test_conn",
+ remote_host=None,
+ job_id="test_job",
+ job_dir="/tmp/job",
+ log_file="/tmp/job/stdout.log",
+ exit_code_file="/tmp/job/exit_code",
+ remote_os="posix",
+ poll_interval=0.01,
+ )
+
+ poll_count = 0
+
+ async def mock_check_completion(_):
+ nonlocal poll_count
+ poll_count += 1
+ # Return None (still running) for first 2 polls, then exit code 0
+ if poll_count < 3:
+ return None
+ return 0
+
+ with mock.patch.object(trigger, "_check_completion",
side_effect=mock_check_completion):
+ with mock.patch.object(trigger, "_read_log_chunk",
return_value=("output\n", 50)):
+ events = []
+ async for event in trigger.run():
+ events.append(event)
+
+ # Only one event should be yielded (the completion event)
+ assert len(events) == 1
+ assert events[0].payload["status"] == "success"
+ assert events[0].payload["done"] is True
+ assert events[0].payload["exit_code"] == 0
+ # Should have polled 3 times
+ assert poll_count == 3
+
+ @pytest.mark.asyncio
+ async def test_run_handles_exception(self):
+ """Test trigger handles exceptions gracefully."""
+ trigger = SSHRemoteJobTrigger(
+ ssh_conn_id="test_conn",
+ remote_host=None,
+ job_id="test_job",
+ job_dir="/tmp/job",
+ log_file="/tmp/job/stdout.log",
+ exit_code_file="/tmp/job/exit_code",
+ remote_os="posix",
+ )
+
+ with mock.patch.object(trigger, "_check_completion",
side_effect=Exception("Connection failed")):
+ events = []
+ async for event in trigger.run():
+ events.append(event)
+
+ assert len(events) == 1
+ assert events[0].payload["status"] == "error"
+ assert events[0].payload["done"] is True
+ assert "Connection failed" in events[0].payload["message"]
+
+ def test_get_hook(self):
+ """Test hook creation."""
+ trigger = SSHRemoteJobTrigger(
+ ssh_conn_id="test_conn",
+ remote_host="custom.host.com",
+ job_id="test_job",
+ job_dir="/tmp/job",
+ log_file="/tmp/job/stdout.log",
+ exit_code_file="/tmp/job/exit_code",
+ remote_os="posix",
+ )
+
+ hook = trigger._get_hook()
+ assert hook.ssh_conn_id == "test_conn"
+ assert hook.host == "custom.host.com"
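
For reference, the shape of the ``TriggerEvent`` payload that the tests above construct and that the operator's ``execute_complete`` consumes (error events additionally carry a ``message`` key); the values shown are illustrative:

    event = {
        "done": True,               # False while the job is still running
        "status": "success",        # "success", "failed", "running" or "error"
        "exit_code": 0,             # None until the job has finished
        "job_id": "test_job_123",
        "job_dir": "/tmp/airflow-ssh-jobs/test_job_123",
        "log_file": "/tmp/airflow-ssh-jobs/test_job_123/stdout.log",
        "exit_code_file": "/tmp/airflow-ssh-jobs/test_job_123/exit_code",
        "log_chunk": "Final output\n",  # log bytes read since the last offset
        "log_offset": 200,              # byte offset to resume log reads from
        "remote_os": "posix",
    }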
diff --git a/providers/ssh/tests/unit/ssh/utils/__init__.py
b/providers/ssh/tests/unit/ssh/utils/__init__.py
new file mode 100644
index 00000000000..217e5db9607
--- /dev/null
+++ b/providers/ssh/tests/unit/ssh/utils/__init__.py
@@ -0,0 +1,17 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
diff --git a/providers/ssh/tests/unit/ssh/utils/test_remote_job.py
b/providers/ssh/tests/unit/ssh/utils/test_remote_job.py
new file mode 100644
index 00000000000..1ababb62e08
--- /dev/null
+++ b/providers/ssh/tests/unit/ssh/utils/test_remote_job.py
@@ -0,0 +1,254 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import base64
+
+import pytest
+
+from airflow.providers.ssh.utils.remote_job import (
+ RemoteJobPaths,
+ build_posix_cleanup_command,
+ build_posix_completion_check_command,
+ build_posix_file_size_command,
+ build_posix_kill_command,
+ build_posix_log_tail_command,
+ build_posix_wrapper_command,
+ build_windows_cleanup_command,
+ build_windows_completion_check_command,
+ build_windows_file_size_command,
+ build_windows_kill_command,
+ build_windows_log_tail_command,
+ build_windows_wrapper_command,
+ generate_job_id,
+)
+
+
+class TestGenerateJobId:
+ def test_generates_unique_ids(self):
+ """Test that job IDs are unique."""
+ id1 = generate_job_id("dag1", "task1", "run1", 1)
+ id2 = generate_job_id("dag1", "task1", "run1", 1)
+ assert id1 != id2
+
+ def test_includes_context_info(self):
+ """Test that job ID includes context information."""
+ job_id = generate_job_id("my_dag", "my_task", "manual__2024", 2)
+ assert "af_" in job_id
+ assert "my_dag" in job_id
+ assert "my_task" in job_id
+ assert "try2" in job_id
+
+ def test_sanitizes_special_characters(self):
+ """Test that special characters are sanitized."""
+ job_id = generate_job_id("dag-with-dashes", "task.with.dots",
"run:with:colons", 1)
+ assert "-" not in job_id.split("_try")[0]
+ assert "." not in job_id.split("_try")[0]
+ assert ":" not in job_id.split("_try")[0]
+
+ def test_suffix_length(self):
+ """Test that suffix length is configurable."""
+ job_id = generate_job_id("dag", "task", "run", 1, suffix_length=12)
+ parts = job_id.split("_")
+ assert len(parts[-1]) == 12
+
+
+class TestRemoteJobPaths:
+ def test_posix_default_paths(self):
+ """Test POSIX default paths."""
+ paths = RemoteJobPaths(job_id="test_job", remote_os="posix")
+ assert paths.base_dir == "/tmp/airflow-ssh-jobs"
+ assert paths.job_dir == "/tmp/airflow-ssh-jobs/test_job"
+ assert paths.log_file == "/tmp/airflow-ssh-jobs/test_job/stdout.log"
+        assert paths.exit_code_file == "/tmp/airflow-ssh-jobs/test_job/exit_code"
+ assert paths.pid_file == "/tmp/airflow-ssh-jobs/test_job/pid"
+ assert paths.sep == "/"
+
+ def test_windows_default_paths(self):
+ """Test Windows default paths use $env:TEMP for portability."""
+ paths = RemoteJobPaths(job_id="test_job", remote_os="windows")
+ assert paths.base_dir == "$env:TEMP\\airflow-ssh-jobs"
+ assert paths.job_dir == "$env:TEMP\\airflow-ssh-jobs\\test_job"
+        assert paths.log_file == "$env:TEMP\\airflow-ssh-jobs\\test_job\\stdout.log"
+        assert paths.exit_code_file == "$env:TEMP\\airflow-ssh-jobs\\test_job\\exit_code"
+ assert paths.pid_file == "$env:TEMP\\airflow-ssh-jobs\\test_job\\pid"
+ assert paths.sep == "\\"
+
+ def test_custom_base_dir(self):
+ """Test custom base directory."""
+ paths = RemoteJobPaths(job_id="test_job", remote_os="posix",
base_dir="/custom/path")
+ assert paths.base_dir == "/custom/path"
+ assert paths.job_dir == "/custom/path/test_job"
+
+
+class TestBuildPosixWrapperCommand:
+ def test_basic_command(self):
+ """Test basic wrapper command generation."""
+ paths = RemoteJobPaths(job_id="test_job", remote_os="posix")
+ wrapper = build_posix_wrapper_command("/path/to/script.sh", paths)
+
+ assert "mkdir -p" in wrapper
+ assert "nohup bash -c" in wrapper
+ assert "/path/to/script.sh" in wrapper
+ assert "exit_code" in wrapper
+ assert "pid" in wrapper
+
+ def test_with_environment(self):
+ """Test wrapper with environment variables."""
+ paths = RemoteJobPaths(job_id="test_job", remote_os="posix")
+ wrapper = build_posix_wrapper_command(
+ "/path/to/script.sh",
+ paths,
+ environment={"MY_VAR": "my_value", "OTHER": "test"},
+ )
+
+ assert "export MY_VAR='my_value'" in wrapper
+ assert "export OTHER='test'" in wrapper
+
+ def test_escapes_quotes(self):
+ """Test that single quotes in command are escaped."""
+ paths = RemoteJobPaths(job_id="test_job", remote_os="posix")
+ wrapper = build_posix_wrapper_command("echo 'hello world'", paths)
+ assert wrapper is not None
+
+
+class TestBuildWindowsWrapperCommand:
+ def test_basic_command(self):
+ """Test basic Windows wrapper command generation."""
+ paths = RemoteJobPaths(job_id="test_job", remote_os="windows")
+ wrapper = build_windows_wrapper_command("C:\\scripts\\test.ps1", paths)
+
+ assert "powershell.exe" in wrapper
+ assert "-EncodedCommand" in wrapper
+ # Decode and verify script content
+ encoded_script = wrapper.split("-EncodedCommand ")[1]
+ decoded_script = base64.b64decode(encoded_script).decode("utf-16-le")
+ assert "New-Item -ItemType Directory" in decoded_script
+ assert "Start-Process" in decoded_script
+
+ def test_with_environment(self):
+ """Test Windows wrapper with environment variables."""
+ paths = RemoteJobPaths(job_id="test_job", remote_os="windows")
+ wrapper = build_windows_wrapper_command(
+ "C:\\scripts\\test.ps1",
+ paths,
+ environment={"MY_VAR": "my_value"},
+ )
+ assert wrapper is not None
+ assert "-EncodedCommand" in wrapper
+
+
+class TestLogTailCommands:
+ def test_posix_log_tail(self):
+ """Test POSIX log tail command uses efficient tail+head pipeline."""
+ cmd = build_posix_log_tail_command("/tmp/log.txt", 100, 1024)
+ assert "tail -c +101" in cmd # offset 100 -> byte 101 (1-indexed)
+ assert "head -c 1024" in cmd
+ assert "/tmp/log.txt" in cmd
+
+ def test_windows_log_tail(self):
+ """Test Windows log tail command."""
+ cmd = build_windows_log_tail_command("C:\\temp\\log.txt", 100, 1024)
+ assert "powershell.exe" in cmd
+ assert "-EncodedCommand" in cmd
+ # Decode and verify the script content
+ encoded_script = cmd.split("-EncodedCommand ")[1]
+ decoded_script = base64.b64decode(encoded_script).decode("utf-16-le")
+ assert "Seek(100" in decoded_script
+ assert "1024" in decoded_script
+
+
+class TestFileSizeCommands:
+ def test_posix_file_size(self):
+ """Test POSIX file size command."""
+ cmd = build_posix_file_size_command("/tmp/file.txt")
+ assert "stat" in cmd
+ assert "/tmp/file.txt" in cmd
+
+ def test_windows_file_size(self):
+ """Test Windows file size command."""
+ cmd = build_windows_file_size_command("C:\\temp\\file.txt")
+ assert "powershell.exe" in cmd
+ assert "-EncodedCommand" in cmd
+ encoded_script = cmd.split("-EncodedCommand ")[1]
+ decoded_script = base64.b64decode(encoded_script).decode("utf-16-le")
+ assert "Get-Item" in decoded_script
+ assert "Length" in decoded_script
+
+
+class TestCompletionCheckCommands:
+ def test_posix_completion_check(self):
+ """Test POSIX completion check command."""
+ cmd = build_posix_completion_check_command("/tmp/exit_code")
+ assert "test -s" in cmd
+ assert "cat" in cmd
+
+ def test_windows_completion_check(self):
+ """Test Windows completion check command."""
+ cmd = build_windows_completion_check_command("C:\\temp\\exit_code")
+ assert "powershell.exe" in cmd
+ assert "-EncodedCommand" in cmd
+ encoded_script = cmd.split("-EncodedCommand ")[1]
+ decoded_script = base64.b64decode(encoded_script).decode("utf-16-le")
+ assert "Test-Path" in decoded_script
+ assert "Get-Content" in decoded_script
+
+
+class TestKillCommands:
+ def test_posix_kill(self):
+ """Test POSIX kill command."""
+ cmd = build_posix_kill_command("/tmp/pid")
+ assert "kill" in cmd
+ assert "cat" in cmd
+
+ def test_windows_kill(self):
+ """Test Windows kill command."""
+ cmd = build_windows_kill_command("C:\\temp\\pid")
+ assert "powershell.exe" in cmd
+ assert "-EncodedCommand" in cmd
+ encoded_script = cmd.split("-EncodedCommand ")[1]
+ decoded_script = base64.b64decode(encoded_script).decode("utf-16-le")
+ assert "Stop-Process" in decoded_script
+
+
+class TestCleanupCommands:
+ def test_posix_cleanup(self):
+ """Test POSIX cleanup command."""
+ cmd = build_posix_cleanup_command("/tmp/airflow-ssh-jobs/job_123")
+ assert "rm -rf" in cmd
+ assert "/tmp/airflow-ssh-jobs/job_123" in cmd
+
+ def test_windows_cleanup(self):
+ """Test Windows cleanup command."""
+        cmd = build_windows_cleanup_command("$env:TEMP\\airflow-ssh-jobs\\job_123")
+ assert "powershell.exe" in cmd
+ assert "-EncodedCommand" in cmd
+ encoded_script = cmd.split("-EncodedCommand ")[1]
+ decoded_script = base64.b64decode(encoded_script).decode("utf-16-le")
+ assert "Remove-Item" in decoded_script
+ assert "-Recurse" in decoded_script
+
+ def test_posix_cleanup_rejects_invalid_path(self):
+ """Test POSIX cleanup rejects paths outside expected base directory."""
+ with pytest.raises(ValueError, match="Invalid job directory"):
+ build_posix_cleanup_command("/tmp/other_dir")
+
+ def test_windows_cleanup_rejects_invalid_path(self):
+ """Test Windows cleanup rejects paths outside expected base
directory."""
+ with pytest.raises(ValueError, match="Invalid job directory"):
+ build_windows_cleanup_command("C:\\temp\\other_dir")