mik-laj commented on a change in pull request #9879:
URL: https://github.com/apache/airflow/pull/9879#discussion_r460596997



##########
File path: airflow/providers/google/cloud/hooks/compute_ssh.py
##########
@@ -0,0 +1,378 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import subprocess
+import time
+from io import StringIO
+from typing import Any, List, Optional, Union
+
+import paramiko
+from google.cloud.oslogin_v1 import OsLoginServiceClient
+from googleapiclient.discovery import build
+from paramiko import AuthenticationException, BadHostKeyException, SSHException
+
+from airflow import AirflowException
+from airflow.providers.google.common.hooks.base_google import GoogleBaseHook
+
+
+class ComputeEngineSshHook(GoogleBaseHook):
+    """
+    Hook to connect to a remote instance in compute engine
+
+    :param api_version: The Compute Engine discovery API version
+    :type api_version: str
+    :param gcp_conn_id: The connection id to use when fetching connection 
information
+    :type gcp_conn_id: str
+    :param delegate_to: The account to impersonate, if any.
+        For this to work, the service account making the request must have
+        domain-wide delegation enabled.
+    :type delegate_to: str
+    """
+    _conn: Optional[Any] = None
+
+    def __init__(
+        self,
+        api_version: str = 'v1',
+        gcp_conn_id: str = 'google_cloud_default',
+        delegate_to: Optional[str] = None
+    ) -> None:
+        super().__init__(gcp_conn_id, delegate_to)
+        self.api_version = api_version
+        self._oslogin_conn: Optional[OsLoginServiceClient] = None
+
+    def get_conn(self) -> Any:
+        """
+        Retrieves connection to Google Compute Engine.
+
+        :return: Google Compute Engine services object
+        :rtype: dict
+        """
+        if not self._conn:
+            http_authorized = self._authorize()
+            self._conn = build('compute', self.api_version,
+                               http=http_authorized, cache_discovery=False)
+        return self._conn
+
+    def get_oslogin_conn(self) -> OsLoginServiceClient:
+        """
+        Connects to Oslogin API
+        """
+        if not self._oslogin_conn:
+            self._oslogin_conn = OsLoginServiceClient(
+                credentials=self._get_credentials(), 
client_info=self.client_info
+            )
+
+        return self._oslogin_conn
+
+    def get_sshclient(self,
+                      pkey: str,
+                      hostname: str,
+                      username: str
+                      ) -> paramiko.SSHClient:
+        """
+        Get SSHClient
+
+        :param pkey: The private key
+        :type pkey: str
+        :param hostname: The hostname of the target instance
+        :type hostname: str
+        :param username: The username from the login profile of the user 
account
+        :type username: str
+        """
+
+        client = paramiko.SSHClient()
+        # Default is RejectPolicy
+        # No knownhost checking since we are not storing pkey
+        client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
+
+        client.connect(hostname=hostname,
+                       username=username,
+                       pkey=pkey,
+                       look_for_keys=False)
+        return client
+
+    def get_service_account(self) -> str:
+        """
+        Get the service account id
+        """
+        account: str = self._get_credentials().service_account_email
+        os_login_client = self.get_oslogin_conn()
+        return os_login_client.user_path(user=account)
+
+    def get_username(self, oslogin=True) -> str:
+        """
+        Get the account username
+
+        :param oslogin: Whether to use OsLogin
+        :type oslogin: bool
+        """
+        account = self._get_credentials().service_account_email
+        if oslogin:
+            profile = 
self.get_oslogin_conn().get_login_profile(name=self.get_service_account())
+            account = profile.posix_accounts[0]
+            return account.username
+        return account.split("@")[0]
+
+    @GoogleBaseHook.fallback_to_default_project_id
+    def get_connection_type(self, project_id: Optional[str], instance: str, 
zone: str,
+                            hostname: Optional[str] = None,
+                            use_iap_tunnel: bool = False,
+                            use_internal_ip: bool = False
+                            ) -> str:
+        """
+        Selects how to connect to the target instance
+
+        :param project_id: The project ID of the instance
+        :type project_id: Optional[str]
+        :param instance: The name of the target instance to connect to
+        :type instance: str
+        :param zone: The zone of the target instance
+        :param zone: str
+        :param hostname: The hostname of the target instance. If provided, 
this would
+            be used, regardless of the setting for use_iap_tunnel and 
use_internal_ip
+        :type hostname: str
+        :param use_iap_tunnel: Whether to connect through IAP tunnel
+        :type use_iap_tunnel: bool
+        :param use_internal_ip: Whether to connect using internal IP
+        :type use_internal_ip: bool
+        """
+        if not hostname:
+            instance_info = self.instance_info(project_id=project_id,
+                                               zone=zone,
+                                               instance=instance)
+            if use_iap_tunnel or use_internal_ip:
+                return instance_info["networkInterfaces"][0] \
+                    .get("networkIP")
+            else:
+                access_config = instance_info \
+                    .get("networkInterfaces")[0] \
+                    .get("accessConfigs")
+                if access_config:
+                    return access_config[0].get("natIP")
+                else:
+                    raise AirflowException("The target instance does not have 
external IP,"
+                                           "consider specifying 
use_internal_ip or use_iap_tunnel")
+        return hostname
+
+    def execute(self, cmd: Union[str, List[str]]) -> None:
+        """
+        Executes command in local machine
+
+        :param cmd: The command to execute
+        :type cmd: Union[str, List[str]]
+        """
+        self.log.info('Executing command: %s', str(cmd))
+        process = subprocess.Popen(cmd)
+        output = process.communicate()[0]
+        returncode = process.returncode
+        if returncode:
+            # Error
+            self.log.error('Command returned error status %s', returncode)
+        if output:
+            self.log.info(output)
+
+    @GoogleBaseHook.fallback_to_default_project_id
+    def instance_info(self, project_id, zone, instance):
+        """
+        Used to get instance information
+
+        :param project_id: The project ID where the instance is located
+        :type project_id: str
+        :param zone: The zone of the instance
+        :type zone: str
+        :param instance: The name of the instance
+        :type instance: str
+        """
+        instance_info = self.get_conn().instances().get(  # pylint: 
disable=no-member
+            project=project_id,
+            instance=instance,
+            zone=zone).execute()
+        return instance_info
+
+    def _update_metadata(self, keys, metadata):
+        """
+        Updates ssh keys on metadata dictionary
+
+        :param keys: The public key to set on the metadata
+        :type keys: str
+        :param metadata: A dictionary containing instance metadata
+        :type metadata: dict
+        """
+        items = metadata.get("items")
+        for item in items:
+            if item.get("key") == "ssh-keys":
+                keys += item["value"]
+                item['value'] = keys
+                break
+        else:
+            new_dict = dict(key='ssh-keys', value=keys)
+            metadata['items'] = [new_dict]
+        return metadata
+
+    @GoogleBaseHook.fallback_to_default_project_id
+    def set_sshkey_metadata(self, project_id, zone, instance, username, 
pubkey):
+        """
+        Sets ssh key on instance metadata
+
+        :param project_id: The project ID of the instance
+        :type project_id: Optional[str]
+        :param zone: The zone of the target instance
+        :param zone: str
+        :param instance: The name of the target instance to connect to
+        :type instance: str
+        :param username: The logged in account username
+        :type username: str
+        :param pubkey: The public key to set on the instance
+        :type pubkey: str
+        """
+        instance_info = self.instance_info(project_id=project_id,
+                                           zone=zone,
+                                           instance=instance)
+        keys = username + ":" + pubkey + "\n"
+        metadata = instance_info['metadata']
+        metadata = self._update_metadata(keys, metadata)
+        self.get_conn().instances().setMetadata(  # pylint: disable=no-member
+            project=project_id,
+            zone=zone,
+            instance=instance,
+            body=metadata
+        ).execute()
+        time.sleep(10)  # TODO improve on this. The status should be Done 
before proceeding
+
+    def set_oslogin_key(self, pubkey, expire_time):
+        """
+        Sets public key on instance os login
+
+        :param pubkey: The public key to set on the OsLogin
+        :type pubkey: str
+        :param expire_time: The amount of time before the key expires in 
seconds
+        :type expire_time: int
+        """
+        expiration = int((time.time() + expire_time) * 1000000)
+        body = {
+            "key": pubkey,
+            "expiration_time_usec": expiration
+        }
+        os_login_client = self.get_oslogin_conn()
+        os_login_client.import_ssh_public_key(
+            parent=self.get_service_account(),
+            ssh_public_key=body
+        )
+
+    @GoogleBaseHook.fallback_to_default_project_id
+    def _create_ssh_key(self,
+                        instance: str,
+                        zone: str,
+                        project_id: Optional[str] = None,
+                        expire_time: int = 300,
+                        use_oslogin: bool = True) -> Any:
+        """
+        Generate a temporary SSH key and apply it to the specified account
+
+        :param project_id: The project ID of the instance
+        :type project_id: Optional[str]
+        :param zone: The zone of the target instance
+        :param zone: str
+        :param instance: The name of the target instance to connect to
+        :type instance: str
+        :param expire_time: The maximum amount of time before the private key 
is expired
+        :type expire_time: int
+        :param use_oslogin: Whether to use OsLogin to manage key
+        :type use_oslogin: int
+        """
+        username = self.get_username(oslogin=use_oslogin)
+        try:
+            self.log.info("Generating ssh keys...")
+            pkey_file = StringIO()
+            pkey_obj = paramiko.RSAKey.generate(2048)
+            pkey_obj.write_private_key(pkey_file)
+            pubkey = f"{pkey_obj.get_name()} {pkey_obj.get_base64()} 
{username}"
+        except (IOError, SSHException) as err:
+            raise AirflowException(f"Error encountered creating ssh keys, 
{err}")
+
+        if use_oslogin:
+            # Expiration time is in microseconds.
+            self.set_oslogin_key(pubkey, expire_time)
+        else:
+            # using instance metadata to manage keys
+            self.set_sshkey_metadata(project_id=project_id,
+                                     zone=zone,
+                                     instance=instance,
+                                     username=username,
+                                     pubkey=pubkey)
+        return pkey_obj
+
+    @GoogleBaseHook.fallback_to_default_project_id
+    def run(self,
+            cmd: str, instance: str, zone: str,
+            project_id: Optional[str] = None,
+            expire_time: int = 300,
+            hostname: Optional[str] = None,
+            use_internal_ip: bool = False,
+            use_iap_tunnel: bool = False,
+            use_oslogin: bool = True
+            ) -> Any:
+        """
+         Execute a command in a remote instance
+
+        :param cmd: The command to execute on the instance
+        :type cmd: str
+        :param instance: The name of the instance the command will be executed 
on
+        :type instance: str
+        :param zone: The zone where the instance is located
+        :type zone: str
+        :param project_id: The project ID of the instance
+        :type project_id: Optional[str]
+        :param expire_time: The maximum amount of time before the private key 
is expired
+        :type expire_time: int
+        :param hostname: The hostname of the instance.
+        :type hostname: str
+        :param use_iap_tunnel: Whether to connect through IAP tunnel
+        :type use_iap_tunnel: bool
+        :param use_internal_ip: Whether to connect using internal IP
+        :type use_internal_ip: bool
+        :param use_oslogin: Whether to manage keys using OsLogin API. If false,
+            keys are managed using instance metadata
+        """
+        hostname = self.get_connection_type(project_id=project_id,
+                                            instance=instance,
+                                            zone=zone,
+                                            hostname=hostname,
+                                            use_iap_tunnel=use_iap_tunnel,
+                                            use_internal_ip=use_internal_ip
+                                            )
+
+        username = self.get_username(oslogin=use_oslogin)
+
+        self.log.info("Creating remote connection to host")
+        pkey = self._create_ssh_key(
+            instance=instance,
+            project_id=project_id,
+            zone=zone,
+            expire_time=expire_time,
+            use_oslogin=use_oslogin
+        )
+        sshclient = self.get_sshclient(pkey=pkey,

Review comment:
       There is one problem when internal IP addresses are used: two servers can
have the same IP address if they are on two different networks. We should make
sure that we have connected to the expected server, for example by checking the
instance ID reported by the metadata server on the remote host:
   ```
   metadata_id_url = 'http://metadata.google.internal/computeMetadata/v1/instance/id'
   remote_command = f'[ `curl "{metadata_id_url}" -H "Metadata-Flavor: Google" -q` = {instance_id} ] || exit 42'
   ```




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to