Ori.livneh has uploaded a new change for review.
https://gerrit.wikimedia.org/r/234454
Change subject: Small improvements to ssh-agent-proxy
......................................................................
Small improvements to ssh-agent-proxy
Change-Id: I97dd5d803996e930bf8ef118b77ce8b4942f5b6f
---
M modules/keyholder/files/ssh-agent-proxy
1 file changed, 99 insertions(+), 97 deletions(-)
git pull ssh://gerrit.wikimedia.org:29418/operations/puppet
refs/changes/54/234454/1
diff --git a/modules/keyholder/files/ssh-agent-proxy
b/modules/keyholder/files/ssh-agent-proxy
index 375e11b..df7b447 100644
--- a/modules/keyholder/files/ssh-agent-proxy
+++ b/modules/keyholder/files/ssh-agent-proxy
@@ -40,6 +40,7 @@
"""
import argparse
+import glob
import grp
import hashlib
import os
@@ -62,141 +63,142 @@
)
-SSH_AGENT_FAILURE = 5
+# Defined in <socket.h>.
SO_PEERCRED = 17
-s_ns_header = struct.Struct('!L')
+# These constants are part of OpenSSH's ssh-agent protocol spec.
+# See <http://api.libssh.org/rfc/PROTOCOL.agent>.
+SSH2_AGENTC_REQUEST_IDENTITIES = 11
+SSH2_AGENTC_SIGN_REQUEST = 13
+SSH_AGENTC_REQUEST_RSA_IDENTITIES = 1
+SSH_AGENT_FAILURE = 5
+
s_message_header = struct.Struct('!LB')
s_ucred = struct.Struct('2Ii')
-syslog.openlog(logoption=syslog.LOG_PID, facility=syslog.LOG_AUTH)
+
+def unpack_variable_length_string(buffer, offset=0):
+ """Read a variable-length string from a buffer. The first 4 bytes are the
+ big-endian unsigned long representing the length of the string."""
+ size = struct.unpack_from('!L', buffer, offset)
+ string, = struct.unpack_from('xxxx%ds' % size, buffer, offset)
+ return string
+
+
+def get_key_permissions(path):
+ """Load the YAML configuration files found directly in `path`."""
+ key_permissions = {}
+ for fname in glob.glob(os.path.join(path, '*.y*ml')):
+ with open(fname) as yml:
+ for group, keys in yaml.safe_load(yml).items():
+ for key in keys:
+ key = key.replace(':', '')
+ key_permissions.setdefault(key, set()).add(group)
+ return key_permissions
class SshAgentProxyServer(socketserver.ThreadingUnixStreamServer):
+ """A threaded server that listens on a UNIX domain socket and handles
+ requests by filtering them and proxying them to a backend SSH agent."""
+
def __init__(self, server_address, agent_address, key_permissions):
+ super().__init__(server_address, SshAgentProxyHandler)
self.agent_address = agent_address
- super(SshAgentProxyServer, self).__init__(
- server_address, SshAgentProxyHandler)
self.key_permissions = key_permissions
class SshAgentProxyHandler(socketserver.BaseRequestHandler):
- # See <http://api.libssh.org/rfc/PROTOCOL.agent>
- permitted_requests = {
- 0x1: 'SSH_AGENTC_REQUEST_RSA_IDENTITIES',
- 0xb: 'SSH2_AGENTC_REQUEST_IDENTITIES',
- 0xd: 'SSH2_AGENTC_SIGN_REQUEST',
- }
+ """This class is responsible for handling an individual connection
+ to an SshAgentProxyServer."""
def get_peer_credentials(self, sock):
- credentials = sock.getsockopt(
- socket.SOL_SOCKET, SO_PEERCRED, s_ucred.size)
- pid, uid, gid = s_ucred.unpack(credentials)
- user_name = pwd.getpwuid(uid).pw_name
- groups = {g.gr_name for g in grp.getgrall() if user_name in g.gr_mem}
- groups.add(grp.getgrgid(gid).gr_name)
- return user_name, groups
+ """Return the user and group name of the peer of a UNIX socket."""
+ ucred = sock.getsockopt(socket.SOL_SOCKET, SO_PEERCRED, s_ucred.size)
+ _, uid, gid = s_ucred.unpack(ucred)
+ user = pwd.getpwuid(uid).pw_name
+ groups = {grp.getgrgid(gid).gr_name}
+ groups.update(g.gr_name for g in grp.getgrall() if user in g.gr_mem)
+ return user, groups
def setup(self):
- self.proxy = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
- self.proxy.setblocking(False)
- self.proxy.connect(self.server.agent_address)
- self.sockets = (self.request, self.proxy)
+ """Set up a connection to the backend SSH agent."""
+ self.backend = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
+ self.backend.setblocking(False)
+ self.backend.connect(self.server.agent_address)
def recv_message(self, sock):
+ """Read a message from a socket."""
header = sock.recv(s_message_header.size, socket.MSG_WAITALL)
- if len(header) < s_message_header.size:
+ try:
+ size, code = s_message_header.unpack(header)
+ except struct.error:
return None, b''
- size, code = s_message_header.unpack(header)
message = sock.recv(size - 1, socket.MSG_WAITALL)
return code, message
def send_message(self, sock, code, message=b''):
+ """Send a message on a socket."""
header = s_message_header.pack(len(message) + 1, code)
sock.sendall(header + message)
- def get_fingerprint_from_ns(self, msg):
- """Get key from netstring sign request
+ def handle_backend(self):
+ """Read data from the backend SSH agent and send to client."""
+ code, message = self.recv_message(self.backend)
+ self.send_message(self.request, code, message)
- retrieves key from SSH2_AGENTC_SIGN_REQUEST (first 4 bytes are
- the big-endian unsigned long representing length of the key,
- remaining bytes are signing message). Returns md5 base64 key
- representation and compares result against key_permissions hash.
- """
- key_start = s_ns_header.size
- key_length, = s_ns_header.unpack_from(msg)
- key_blob = msg[key_start:key_start + key_length]
- return hashlib.md5(key_blob).hexdigest()
+ def handle_client_request(self):
+ """Read data from client and send to backend SSH agent."""
+ code, message = self.recv_message(self.request)
- def check_key(self, groups, msg):
- """Checks key fingerprint and group membership against key_permissions
- """
- fingerprint = self.get_fingerprint_from_ns(msg)
- for group in groups:
- if fingerprint in self.server.key_permissions.get(group, ()):
- return True
- return False
+ if code == SSH2_AGENTC_REQUEST_IDENTITIES:
+ return self.send_message(self.backend, code, message)
- def authorized_request(self, code, groups, message):
- """Checks if the code of the request is a permitted request type. If
- the request type is permitted and the request type is a sign request
- check group/key permissions.
- """
- if code not in self.permitted_requests:
- return False
+ if code == SSH_AGENTC_REQUEST_RSA_IDENTITIES:
+ return self.send_message(self.backend, code, message)
- if self.permitted_requests[code] == 'SSH2_AGENTC_SIGN_REQUEST':
- return self.check_key(groups, message)
+ if code == SSH2_AGENTC_SIGN_REQUEST:
+ key = unpack_variable_length_string(message)
+ digest = hashlib.md5(key).hexdigest()
+ user, groups = self.get_peer_credentials(self.request)
+ if groups & self.server.key_permissions.get(digest, set()):
+ syslog.syslog(syslog.LOG_INFO, 'Allowing signing request '
+ 'from user %s using key %s.' % (user, digest))
+ return self.send_message(self.backend, code, message)
+ syslog.syslog(syslog.LOG_NOTICE, 'Denying signing request '
+ 'from user %s using key %s.' % (user, digest))
- return True
+ return self.send_message(self.request, SSH_AGENT_FAILURE)
def handle(self):
+ """Handle a new client connection by shuttling data between the client
+ and the backend."""
+ syslog.syslog('New connection from %s.' % self.client_address)
while 1:
- readable, *_ = select.select(self.sockets, (), (), 1)
- if self.proxy in readable:
- code, message = self.recv_message(self.proxy)
- self.send_message(self.request, code, message)
- if self.request in readable:
- code, message = self.recv_message(self.request)
- if code is None:
- return
- user, groups = self.get_peer_credentials(self.request)
- req = self.permitted_requests.get(code, 'UNKNOWN (%s)' % code)
- syslog.syslog('Received %s from %s:%s' % (req, user, groups))
-
- if self.authorized_request(code, groups, message):
- self.send_message(self.proxy, code, message)
- else:
- self.send_message(self.request, SSH_AGENT_FAILURE)
+ rlist, *_ = select.select((self.backend, self.request), (), (), 1)
+ if self.backend in rlist:
+ self.handle_backend()
+ if self.request in rlist:
+ self.handle_client_request()
-def raise_err(err):
- raise err
-
-
-def get_permissions(path):
- """Recursively walk `path`, loading YAML configuration files."""
- permissions = {}
- for _, _, files in os.walk(path, onerror=raise_err):
- for f in files:
- with open(os.path.join(path, f)) as yml:
- for group, keys in yaml.load(yml).items():
- sanitized_keys = (key.replace(':', '') for key in keys)
- permissions.setdefault(group, []).extend(sanitized_keys)
- return permissions
-
-
-ap = argparse.ArgumentParser(description='Filtering proxy for ssh-agent(1)')
-ap.add_argument('--bind', default='/run/keyholder/proxy.sock',
- help='Bind the proxy to the domain socket at this address')
-ap.add_argument('--connect', default='/run/keyholder/agent.sock',
- help='Proxy connects to the ssh-agent socket at this address')
-ap.add_argument('--auth-dir', default='/etc/keyholder-auth.d',
- help='directory with YAML files containing group names'
- 'mapped to arrays of SSH public key fingerprints')
+ap = argparse.ArgumentParser(description='filtering proxy for ssh-agent')
+ap.add_argument(
+ '--bind',
+ default='/run/keyholder/proxy.sock',
+ help='Bind the proxy to the domain socket at this address'
+)
+ap.add_argument(
+ '--connect',
+ default='/run/keyholder/agent.sock',
+ help='Proxy connects to the ssh-agent socket at this address'
+)
+ap.add_argument(
+ '--auth-dir',
+ default='/etc/keyholder-auth.d',
+ help='directory with YAML configuration files'
+)
args = ap.parse_args()
-syslog.syslog('Proxying %s -> %s' % (args.bind, args.connect))
-proxy = SshAgentProxyServer(args.bind, args.connect,
- get_permissions(args.auth_dir))
-proxy.serve_forever()
+key_perms = get_key_permissions(args.auth_dir)
+syslog.openlog(logoption=syslog.LOG_PID, facility=syslog.LOG_AUTH)
+SshAgentProxyServer(args.bind, args.connect, key_perms).serve_forever()
--
To view, visit https://gerrit.wikimedia.org/r/234454
To unsubscribe, visit https://gerrit.wikimedia.org/r/settings
Gerrit-MessageType: newchange
Gerrit-Change-Id: I97dd5d803996e930bf8ef118b77ce8b4942f5b6f
Gerrit-PatchSet: 1
Gerrit-Project: operations/puppet
Gerrit-Branch: production
Gerrit-Owner: Ori.livneh <[email protected]>
_______________________________________________
MediaWiki-commits mailing list
[email protected]
https://lists.wikimedia.org/mailman/listinfo/mediawiki-commits