Ori.livneh has uploaded a new change for review.

  https://gerrit.wikimedia.org/r/234454

Change subject: Small improvements to ssh-agent-proxy
......................................................................

Small improvements to ssh-agent-proxy

Change-Id: I97dd5d803996e930bf8ef118b77ce8b4942f5b6f
---
M modules/keyholder/files/ssh-agent-proxy
1 file changed, 99 insertions(+), 97 deletions(-)


  git pull ssh://gerrit.wikimedia.org:29418/operations/puppet refs/changes/54/234454/1

diff --git a/modules/keyholder/files/ssh-agent-proxy b/modules/keyholder/files/ssh-agent-proxy
index 375e11b..df7b447 100644
--- a/modules/keyholder/files/ssh-agent-proxy
+++ b/modules/keyholder/files/ssh-agent-proxy
@@ -40,6 +40,7 @@
 
 """
 import argparse
+import glob
 import grp
 import hashlib
 import os
@@ -62,141 +63,142 @@
     )
 
 
-SSH_AGENT_FAILURE = 5
+# Defined in <socket.h>.
 SO_PEERCRED = 17
 
-s_ns_header = struct.Struct('!L')
+# These constants are part of OpenSSH's ssh-agent protocol spec.
+# See <http://api.libssh.org/rfc/PROTOCOL.agent>.
+SSH2_AGENTC_REQUEST_IDENTITIES = 11
+SSH2_AGENTC_SIGN_REQUEST = 13
+SSH_AGENTC_REQUEST_RSA_IDENTITIES = 1
+SSH_AGENT_FAILURE = 5
+
 s_message_header = struct.Struct('!LB')
 s_ucred = struct.Struct('2Ii')
 
-syslog.openlog(logoption=syslog.LOG_PID, facility=syslog.LOG_AUTH)
+
+def unpack_variable_length_string(buffer, offset=0):
+    """Read a variable-length string from a buffer. The first 4 bytes are the
+    big-endian unsigned long representing the length of the string."""
+    size = struct.unpack_from('!L', buffer, offset)
+    string, = struct.unpack_from('xxxx%ds' % size, buffer, offset)
+    return string
+
+
+def get_key_permissions(path):
+    """Load the YAML configuration files found directly in `path`."""
+    key_permissions = {}
+    for fname in glob.glob(os.path.join(path, '*.y*ml')):
+        with open(fname) as yml:
+            for group, keys in yaml.safe_load(yml).items():
+                for key in keys:
+                    key = key.replace(':', '')
+                    key_permissions.setdefault(key, set()).add(group)
+    return key_permissions
 
 
 class SshAgentProxyServer(socketserver.ThreadingUnixStreamServer):
+    """A threaded server that listens on a UNIX domain socket and handles
+    requests by filtering them and proxying them to a backend SSH agent."""
+
     def __init__(self, server_address, agent_address, key_permissions):
+        super().__init__(server_address, SshAgentProxyHandler)
         self.agent_address = agent_address
-        super(SshAgentProxyServer, self).__init__(
-                server_address, SshAgentProxyHandler)
         self.key_permissions = key_permissions
 
 
 class SshAgentProxyHandler(socketserver.BaseRequestHandler):
-    # See <http://api.libssh.org/rfc/PROTOCOL.agent>
-    permitted_requests = {
-        0x1: 'SSH_AGENTC_REQUEST_RSA_IDENTITIES',
-        0xb: 'SSH2_AGENTC_REQUEST_IDENTITIES',
-        0xd: 'SSH2_AGENTC_SIGN_REQUEST',
-    }
+    """This class is responsible for handling an individual connection
+    to an SshAgentProxyServer."""
 
     def get_peer_credentials(self, sock):
-        credentials = sock.getsockopt(
-            socket.SOL_SOCKET, SO_PEERCRED, s_ucred.size)
-        pid, uid, gid = s_ucred.unpack(credentials)
-        user_name = pwd.getpwuid(uid).pw_name
-        groups = {g.gr_name for g in grp.getgrall() if user_name in g.gr_mem}
-        groups.add(grp.getgrgid(gid).gr_name)
-        return user_name, groups
+        """Return the user name and group names of a UNIX socket's peer."""
+        ucred = sock.getsockopt(socket.SOL_SOCKET, SO_PEERCRED, s_ucred.size)
+        _, uid, gid = s_ucred.unpack(ucred)
+        user = pwd.getpwuid(uid).pw_name
+        groups = {grp.getgrgid(gid).gr_name}
+        groups.update(g.gr_name for g in grp.getgrall() if user in g.gr_mem)
+        return user, groups
 
     def setup(self):
-        self.proxy = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
-        self.proxy.setblocking(False)
-        self.proxy.connect(self.server.agent_address)
-        self.sockets = (self.request, self.proxy)
+        """Set up a connection to the backend SSH agent."""
+        self.backend = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
+        self.backend.setblocking(False)
+        self.backend.connect(self.server.agent_address)
 
     def recv_message(self, sock):
+        """Read a message from a socket."""
         header = sock.recv(s_message_header.size, socket.MSG_WAITALL)
-        if len(header) < s_message_header.size:
+        try:
+            size, code = s_message_header.unpack(header)
+        except struct.error:
             return None, b''
-        size, code = s_message_header.unpack(header)
         message = sock.recv(size - 1, socket.MSG_WAITALL)
         return code, message
 
     def send_message(self, sock, code, message=b''):
+        """Send a message on a socket."""
         header = s_message_header.pack(len(message) + 1, code)
         sock.sendall(header + message)
 
-    def get_fingerprint_from_ns(self, msg):
-        """Get key from netstring sign request
+    def handle_backend(self):
+        """Read data from the backend SSH agent and send to client."""
+        code, message = self.recv_message(self.backend)
+        self.send_message(self.request, code, message)
 
-        retrieves key from SSH2_AGENTC_SIGN_REQUEST (first 4 bytes are
-        the big-endian unsigned long representing length of the key,
-        remaining bytes are signing message). Returns md5 base64 key
-        representation and compares result against key_permissions hash.
-        """
-        key_start = s_ns_header.size
-        key_length, = s_ns_header.unpack_from(msg)
-        key_blob = msg[key_start:key_start + key_length]
-        return hashlib.md5(key_blob).hexdigest()
+    def handle_client_request(self):
+        """Read data from client and send to backend SSH agent."""
+        code, message = self.recv_message(self.request)
 
-    def check_key(self, groups, msg):
-        """Checks key fingerprint and group membership against key_permissions
-        """
-        fingerprint = self.get_fingerprint_from_ns(msg)
-        for group in groups:
-            if fingerprint in self.server.key_permissions.get(group, ()):
-                return True
-        return False
+        if code == SSH2_AGENTC_REQUEST_IDENTITIES:
+            return self.send_message(self.backend, code, message)
 
-    def authorized_request(self, code, groups, message):
-        """Checks if the code of the request is a permitted request type. If
-        the request type is permitted and the request type is a sign request
-        check group/key permissions.
-        """
-        if code not in self.permitted_requests:
-            return False
+        if code == SSH_AGENTC_REQUEST_RSA_IDENTITIES:
+            return self.send_message(self.backend, code, message)
 
-        if self.permitted_requests[code] == 'SSH2_AGENTC_SIGN_REQUEST':
-            return self.check_key(groups, message)
+        if code == SSH2_AGENTC_SIGN_REQUEST:
+            key = unpack_variable_length_string(message)
+            digest = hashlib.md5(key).hexdigest()
+            user, groups = self.get_peer_credentials(self.request)
+            if groups & self.server.key_permissions.get(digest, set()):
+                syslog.syslog(syslog.LOG_INFO, 'Allowing signing request '
+                              'from user %s using key %s.' % (user, digest))
+                return self.send_message(self.backend, code, message)
+            syslog.syslog(syslog.LOG_NOTICE, 'Denying signing request '
+                          'from user %s using key %s.' % (user, digest))
 
-        return True
+        return self.send_message(self.request, SSH_AGENT_FAILURE)
 
     def handle(self):
+        """Handle a new client connection by shuttling data between the client
+        and the backend."""
+        syslog.syslog('New connection from %s.' % self.client_address)
         while 1:
-            readable, *_ = select.select(self.sockets, (), (), 1)
-            if self.proxy in readable:
-                code, message = self.recv_message(self.proxy)
-                self.send_message(self.request, code, message)
-            if self.request in readable:
-                code, message = self.recv_message(self.request)
-                if code is None:
-                    return
-                user, groups = self.get_peer_credentials(self.request)
-                req = self.permitted_requests.get(code, 'UNKNOWN (%s)' % code)
-                syslog.syslog('Received %s from %s:%s' % (req, user, groups))
-
-                if self.authorized_request(code, groups, message):
-                    self.send_message(self.proxy, code, message)
-                else:
-                    self.send_message(self.request, SSH_AGENT_FAILURE)
+            rlist, *_ = select.select((self.backend, self.request), (), (), 1)
+            if self.backend in rlist:
+                self.handle_backend()
+            if self.request in rlist:
+                self.handle_client_request()
 
 
-def raise_err(err):
-    raise err
-
-
-def get_permissions(path):
-    """Recursively walk `path`, loading YAML configuration files."""
-    permissions = {}
-    for _, _, files in os.walk(path, onerror=raise_err):
-        for f in files:
-            with open(os.path.join(path, f)) as yml:
-                for group, keys in yaml.load(yml).items():
-                    sanitized_keys = (key.replace(':', '') for key in keys)
-                    permissions.setdefault(group, []).extend(sanitized_keys)
-    return permissions
-
-
-ap = argparse.ArgumentParser(description='Filtering proxy for ssh-agent(1)')
-ap.add_argument('--bind', default='/run/keyholder/proxy.sock',
-                help='Bind the proxy to the domain socket at this address')
-ap.add_argument('--connect', default='/run/keyholder/agent.sock',
-                help='Proxy connects to the ssh-agent socket at this address')
-ap.add_argument('--auth-dir', default='/etc/keyholder-auth.d',
-                help='directory with YAML files containing group names'
-                     'mapped to arrays of SSH public key fingerprints')
+ap = argparse.ArgumentParser(description='filtering proxy for ssh-agent')
+ap.add_argument(
+    '--bind',
+    default='/run/keyholder/proxy.sock',
+    help='Bind the proxy to the domain socket at this address'
+)
+ap.add_argument(
+    '--connect',
+    default='/run/keyholder/agent.sock',
+    help='Proxy connects to the ssh-agent socket at this address'
+)
+ap.add_argument(
+    '--auth-dir',
+    default='/etc/keyholder-auth.d',
+    help='directory with YAML configuration files'
+)
 args = ap.parse_args()
 
-syslog.syslog('Proxying %s -> %s' % (args.bind, args.connect))
-proxy = SshAgentProxyServer(args.bind, args.connect,
-                            get_permissions(args.auth_dir))
-proxy.serve_forever()
+key_perms = get_key_permissions(args.auth_dir)
+syslog.openlog(logoption=syslog.LOG_PID, facility=syslog.LOG_AUTH)
+SshAgentProxyServer(args.bind, args.connect, key_perms).serve_forever()

-- 
To view, visit https://gerrit.wikimedia.org/r/234454
To unsubscribe, visit https://gerrit.wikimedia.org/r/settings

Gerrit-MessageType: newchange
Gerrit-Change-Id: I97dd5d803996e930bf8ef118b77ce8b4942f5b6f
Gerrit-PatchSet: 1
Gerrit-Project: operations/puppet
Gerrit-Branch: production
Gerrit-Owner: Ori.livneh <[email protected]>

_______________________________________________
MediaWiki-commits mailing list
[email protected]
https://lists.wikimedia.org/mailman/listinfo/mediawiki-commits

Reply via email to