This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 120dbed3462 Add `host_key_policy` option to `ComputeEngineSSHHook` (#66746)
120dbed3462 is described below

commit 120dbed3462cedcb980aac022c587ba434249eb1
Author: Jarek Potiuk <[email protected]>
AuthorDate: Thu May 14 00:03:21 2026 +0200

    Add `host_key_policy` option to `ComputeEngineSSHHook` (#66746)
    
    * Add host_key_policy option to ComputeEngineSSHHook
    
    Exposes paramiko's `MissingHostKeyPolicy` choice as a constructor
    argument so callers can opt into strict host-key verification on the
    SSH transport. The argument accepts the string aliases `"auto_add"`,
    `"reject"` and `"warning"` (which map to the matching `paramiko`
    policy classes) and also passes through any custom
    `paramiko.MissingHostKeyPolicy` instance — so a caller that wants to
    pin the remote host's key from GCE guest attributes / instance
    metadata can plug in a policy that loads it on the fly.
    
    The default is `"auto_add"`, preserving the historical behaviour of
    this hook; no migration is required for existing callers. The
    previous inline comment claiming the missing host-key check was
    unrelated to the local private key is removed — it conflated two
    different concerns and is replaced with a pointer to the new
    constructor argument.
    
    Generated-by: Claude Opus 4.7 (1M context) following the guidelines at https://github.com/apache/airflow/blob/main/contributing-docs/05_pull_requests.rst#gen-ai-assisted-contributions
    
    * Address review: docstring, ValueError chaining, top-level paramiko import
    
    @shahar1 review on PR #66746:
    
    * Add :param host_key_policy: to the ComputeEngineSSHHook docstring so
      users see the option without reading the source.
    * Fix the _resolve_host_key_policy() docstring to say ValueError (matches
      what is actually raised) instead of AirflowException.
    * Re-raise the unknown-policy ValueError with from None so the KeyError
      implementation detail doesn't leak into the chained traceback.
    * Move 'import paramiko' to module top in test_compute_ssh.py; the
      function-local imports were unnecessary now that the test class is
      permanent.
---
 .../providers/google/cloud/hooks/compute_ssh.py    | 56 ++++++++++++++++++++--
 .../unit/google/cloud/hooks/test_compute_ssh.py    | 37 ++++++++++++++
 2 files changed, 90 insertions(+), 3 deletions(-)

diff --git a/providers/google/src/airflow/providers/google/cloud/hooks/compute_ssh.py b/providers/google/src/airflow/providers/google/cloud/hooks/compute_ssh.py
index c39da86beed..33509f3abdf 100644
--- a/providers/google/src/airflow/providers/google/cloud/hooks/compute_ssh.py
+++ b/providers/google/src/airflow/providers/google/cloud/hooks/compute_ssh.py
@@ -112,6 +112,13 @@ class ComputeEngineSSHHook(SSHHook):
     :param impersonation_chain: Optional. The service account email to impersonate using short-term
         credentials. The provided service account must grant the originating account
         the Service Account Token Creator IAM role and have the sufficient rights to perform the request
+    :param host_key_policy: Policy used by the underlying ``paramiko.SSHClient`` for unknown host keys.
+        Accepts the string aliases ``"auto_add"`` (default, historical behaviour — adds unknown keys
+        without prompting), ``"reject"`` (refuse to connect to hosts whose key is not in ``known_hosts``)
+        and ``"warning"`` (log a warning but still connect). Alternatively, pass an instance of
+        ``paramiko.MissingHostKeyPolicy`` (or a subclass) to plug in a custom policy — for example one
+        that loads pinned host keys from GCE guest attributes. Any other value raises ``ValueError``
+        when the connection is opened.
     """
 
     conn_name_attr = "gcp_conn_id"
@@ -141,6 +148,7 @@ class ComputeEngineSSHHook(SSHHook):
         cmd_timeout: int | ArgNotSet = NOTSET,
         max_retries: int = 10,
         impersonation_chain: str | None = None,
+        host_key_policy: str | paramiko.MissingHostKeyPolicy = "auto_add",
         **kwargs,
     ) -> None:
         # Ignore original constructor
@@ -158,8 +166,47 @@ class ComputeEngineSSHHook(SSHHook):
         self.cmd_timeout = cmd_timeout
         self.max_retries = max_retries
         self.impersonation_chain = impersonation_chain
+        self.host_key_policy = host_key_policy
         self._conn: Any | None = None
 
+    def _resolve_host_key_policy(self) -> paramiko.MissingHostKeyPolicy:
+        """
+        Resolve ``self.host_key_policy`` to a concrete paramiko policy instance.
+
+        Accepts:
+
+        - the string aliases ``"auto_add"``, ``"reject"`` or ``"warning"`` —
+          mapped to the matching ``paramiko`` policy class;
+        - an instance of ``paramiko.MissingHostKeyPolicy`` — used as-is, so
+          callers can plug in a custom policy (e.g. one that loads pinned
+          host keys from GCE guest attributes).
+
+        Any other value raises :class:`ValueError`.
+
+        The default value (``"auto_add"``) preserves the historical behaviour
+        of this hook. Callers that want the remote SSH server authenticated
+        before the session opens should pass ``"reject"`` together with a
+        populated ``known_hosts`` file, or supply a custom policy that
+        looks the remote host's key up from an out-of-band source.
+        """
+        if not isinstance(self.host_key_policy, str):
+            # Trust the caller: an explicit paramiko.MissingHostKeyPolicy
+            # instance, or a subclass instance with custom behaviour.
+            return self.host_key_policy
+        builtins = {
+            "auto_add": paramiko.AutoAddPolicy,
+            "reject": paramiko.RejectPolicy,
+            "warning": paramiko.WarningPolicy,
+        }
+        try:
+            return builtins[self.host_key_policy]()
+        except KeyError:
+            raise ValueError(
+                f"Unknown host_key_policy {self.host_key_policy!r}. "
+                "Expected one of 'auto_add', 'reject', 'warning', "
+                "or an instance of paramiko.MissingHostKeyPolicy."
+            ) from None
+
     @cached_property
     def _oslogin_hook(self) -> OSLoginHook:
         return OSLoginHook(gcp_conn_id=self.gcp_conn_id)
@@ -310,9 +357,12 @@ class ComputeEngineSSHHook(SSHHook):
         for time_to_wait in range(max_time_to_wait + 1):
             try:
                 client = _GCloudAuthorizedSSHClient(self._compute_hook)
-                # Default is RejectPolicy
-                # No known host checking since we are not storing privatekey
-                client.set_missing_host_key_policy(paramiko.AutoAddPolicy())  
# nosec B507
+                # Apply the policy configured via the `host_key_policy` constructor
+                # argument; default is `"auto_add"` (paramiko.AutoAddPolicy), which
+                # preserves the historical behaviour of this hook. Callers that need
+                # the remote host authenticated should pass `"reject"` with a
+                # populated known_hosts file or a custom MissingHostKeyPolicy.
+                
client.set_missing_host_key_policy(self._resolve_host_key_policy())  # nosec 
B507
                 client.connect(
                     hostname=hostname,
                     username=user,
diff --git a/providers/google/tests/unit/google/cloud/hooks/test_compute_ssh.py b/providers/google/tests/unit/google/cloud/hooks/test_compute_ssh.py
index 513f024c1a2..7aa65f3e24c 100644
--- a/providers/google/tests/unit/google/cloud/hooks/test_compute_ssh.py
+++ b/providers/google/tests/unit/google/cloud/hooks/test_compute_ssh.py
@@ -21,6 +21,7 @@ import logging
 from unittest import mock
 
 import httplib2
+import paramiko
 import pytest
 from googleapiclient.errors import HttpError
 from paramiko.ssh_exception import SSHException
@@ -560,3 +561,39 @@ class TestComputeEngineHookWithPassedProjectId:
             
mock_set_instance_metadata.call_args.kwargs["metadata"]["items"].sort(key=lambda
 x: x["key"])
             expected_metadata["items"].sort(key=lambda x: x["key"])
             assert mock_set_instance_metadata.call_args.kwargs["metadata"] == 
expected_metadata
+
+
+class TestHostKeyPolicyResolution:
+    """Tests for the ``host_key_policy`` constructor argument."""
+
+    def test_default_is_auto_add(self):
+        hook = ComputeEngineSSHHook()
+
+        assert hook.host_key_policy == "auto_add"
+        assert isinstance(hook._resolve_host_key_policy(), 
paramiko.AutoAddPolicy)
+
+    def test_string_aliases(self):
+        assert isinstance(
+            
ComputeEngineSSHHook(host_key_policy="auto_add")._resolve_host_key_policy(),
+            paramiko.AutoAddPolicy,
+        )
+        assert isinstance(
+            
ComputeEngineSSHHook(host_key_policy="reject")._resolve_host_key_policy(),
+            paramiko.RejectPolicy,
+        )
+        assert isinstance(
+            
ComputeEngineSSHHook(host_key_policy="warning")._resolve_host_key_policy(),
+            paramiko.WarningPolicy,
+        )
+
+    def test_custom_policy_instance_is_returned_unchanged(self):
+        custom_policy = paramiko.RejectPolicy()
+        hook = ComputeEngineSSHHook(host_key_policy=custom_policy)
+
+        assert hook._resolve_host_key_policy() is custom_policy
+
+    def test_unknown_string_raises_value_error(self):
+        hook = ComputeEngineSSHHook(host_key_policy="strict")
+
+        with pytest.raises(ValueError, match=r"Unknown host_key_policy 
'strict'"):
+            hook._resolve_host_key_policy()

Reply via email to