URL: https://github.com/freeipa/freeipa/pull/476
Author: HonzaCholasta
 Title: #476: vault: cache the transport certificate on client
Action: synchronized

To pull the PR as Git branch:
git remote add ghfreeipa https://github.com/freeipa/freeipa
git fetch ghfreeipa pull/476/head:pr476
git checkout pr476
From b35f363fd98dd6959d1dd2f9240dcdf308606ff9 Mon Sep 17 00:00:00 2001
From: Jan Cholasta <jchol...@redhat.com>
Date: Fri, 17 Feb 2017 11:25:17 +0100
Subject: [PATCH] vault: cache the transport certificate on client

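Cache the KRA transport certificate on disk (in
~/.cache/ipa/kra-transport-certs, or under $XDG_CACHE_HOME if set) as well
as in memory. If a vault_archive/vault_retrieve call fails with the cached
certificate, the certificate is re-fetched with vaultconfig_show and the
call is retried once.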

https://fedorahosted.org/freeipa/ticket/6652
---
 ipaclient/plugins/vault.py           | 221 ++++++++++++++++++++++++++---------
 ipaclient/remote_plugins/__init__.py |   3 +-
 ipaclient/remote_plugins/schema.py   |  12 +-
 ipalib/constants.py                  |  14 +++
 4 files changed, 186 insertions(+), 64 deletions(-)

diff --git a/ipaclient/plugins/vault.py b/ipaclient/plugins/vault.py
index 70756df..f24ec1e 100644
--- a/ipaclient/plugins/vault.py
+++ b/ipaclient/plugins/vault.py
@@ -20,30 +20,40 @@
 from __future__ import print_function
 
 import base64
+import collections
+import errno
 import getpass
 import io
 import json
 import os
 import sys
+import tempfile
 
 from cryptography.fernet import Fernet, InvalidToken
 from cryptography.hazmat.backends import default_backend
-from cryptography.hazmat.primitives import hashes
+from cryptography.hazmat.primitives import hashes, serialization
 from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC
 from cryptography.hazmat.primitives.asymmetric import padding
 from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
 from cryptography.hazmat.primitives.padding import PKCS7
 from cryptography.hazmat.primitives.serialization import (
     load_pem_public_key, load_pem_private_key)
-from cryptography.x509 import load_der_x509_certificate
+
+import six
 
 from ipaclient.frontend import MethodOverride
+from ipalib import x509
+from ipalib.constants import USER_CACHE_PATH
 from ipalib.frontend import Local, Method, Object
 from ipalib.util import classproperty
 from ipalib import api, errors
 from ipalib import Bytes, Flag, Str
 from ipalib.plugable import Registry
 from ipalib import _
+from ipapython.dnsutil import DNSName
+from ipapython.ipa_log_manager import log_mgr
+
+logger = log_mgr.get_logger(__name__)
 
 
 def validated_read(argname, filename, mode='r', encoding=None):
@@ -550,6 +560,121 @@ def forward(self, *args, **options):
         return response


+# NOTE(review): collections.MutableMapping is a deprecated alias of
+# collections.abc.MutableMapping (removed in Python 3.10); it is kept here
+# for consistency with ipaclient.remote_plugins.ServerInfo.
+class _TransportCertCache(collections.MutableMapping):
+    """Per-domain cache of KRA transport certificates.
+
+    Certificates are kept in memory and also stored on disk as
+    ``<USER_CACHE_PATH>/ipa/kra-transport-certs/<domain>.pem`` so that
+    they outlive the current process.  Disk errors are logged and
+    otherwise ignored: the on-disk copy is a best-effort optimization.
+
+    ``len()`` and iteration only reflect the certificates currently held
+    in memory, not the files on disk.
+    """
+
+    def __init__(self):
+        self._dirname = os.path.join(
+                USER_CACHE_PATH, 'ipa', 'kra-transport-certs')
+        # domain -> certificate object (loaded via ipalib.x509)
+        self._transport_certs = {}
+
+    def _get_filename(self, domain):
+        """Return the cache file path for *domain* (ASCII/IDNA form)."""
+        basename = DNSName(domain).ToASCII() + '.pem'
+        return os.path.join(self._dirname, basename)
+
+    def __getitem__(self, domain):
+        """Return the cached certificate for *domain*.
+
+        Falls back to the on-disk copy (and keeps it in memory) when the
+        certificate has not been loaded yet.
+
+        :raises KeyError: no usable certificate is cached for *domain*
+        """
+        try:
+            transport_cert = self._transport_certs[domain]
+        except KeyError:
+            transport_cert = None
+
+            filename = self._get_filename(domain)
+            try:
+                try:
+                    transport_cert = x509.load_certificate_from_file(filename)
+                except EnvironmentError as e:
+                    # a missing file is an ordinary cache miss
+                    if e.errno != errno.ENOENT:
+                        raise
+            except Exception:
+                # unreadable or corrupt cache file: log, treat as a miss
+                logger.warning("Failed to load %s", filename, exc_info=True)
+
+            if transport_cert is None:
+                raise KeyError(domain)
+
+            self._transport_certs[domain] = transport_cert
+
+        return transport_cert
+
+    def __setitem__(self, domain, transport_cert):
+        """Cache *transport_cert* for *domain* in memory and on disk.
+
+        Failure to write the file is logged but not raised; the in-memory
+        entry is stored regardless.
+        """
+        filename = self._get_filename(domain)
+        transport_cert_der = (
+            transport_cert.public_bytes(serialization.Encoding.DER))
+        try:
+            try:
+                os.makedirs(self._dirname)
+            except EnvironmentError as e:
+                if e.errno != errno.EEXIST:
+                    raise
+            # write to a temporary file in the same directory and rename
+            # it over the target, so the cache file is replaced atomically
+            fd, tmpfilename = tempfile.mkstemp(dir=self._dirname)
+            os.close(fd)
+            try:
+                x509.write_certificate(transport_cert_der, tmpfilename)
+                os.rename(tmpfilename, filename)
+            except Exception:
+                # don't leave stray temporary files in the cache directory
+                os.unlink(tmpfilename)
+                raise
+        except Exception:
+            logger.warning("Failed to save %s", filename, exc_info=True)
+
+        self._transport_certs[domain] = transport_cert
+
+    def __delitem__(self, domain):
+        """Remove the certificate for *domain* from disk and memory.
+
+        :raises KeyError: *domain* is not loaded in memory (note that
+            ``MutableMapping.pop()`` loads it first via ``__getitem__``)
+        """
+        filename = self._get_filename(domain)
+        try:
+            os.unlink(filename)
+        except EnvironmentError as e:
+            if e.errno != errno.ENOENT:
+                logger.warning("Failed to remove %s", filename, exc_info=True)
+
+        del self._transport_certs[domain]
+
+    def __len__(self):
+        return len(self._transport_certs)
+
+    def __iter__(self):
+        return iter(self._transport_certs)
+
+
+# process-wide cache shared by vaultconfig_show and ModVaultData
+_transport_cert_cache = _TransportCertCache()
+
+
 @register(override=True, no_fail=True)
 class vaultconfig_show(MethodOverride):
     def forward(self, *args, **options):
@@ -562,6 +645,11 @@ def forward(self, *args, **options):

         response = super(vaultconfig_show, self).forward(*args, **options)

+        # cache transport certificate (memory + disk) for ModVaultData
+        transport_cert = x509.load_certificate(
+                response['result']['transport_cert'], x509.DER)
+        _transport_cert_cache[self.api.env.domain] = transport_cert
+
         if file:
             with open(file, 'w') as f:
                 f.write(response['result']['transport_cert'])
@@ -569,6 +657,98 @@ def forward(self, *args, **options):
         return response


+class _TransportCertInvalid(Exception):
+    """Internal signal: a server call made with a given transport
+    certificate failed in a way that may be caused by the certificate.
+
+    Carries the ``sys.exc_info()`` of the original server error so that it
+    can be re-raised unchanged if the call cannot be retried.
+    """
+    def __init__(self, exc_info):
+        self.exc_info = exc_info
+
+
+class ModVaultData(Local):
+    """Base class for client commands that send a wrapped session key.
+
+    Subclasses call :meth:`internal` to invoke their ``<name>_internal``
+    server counterpart; the session key is wrapped with the KRA transport
+    certificate, which is taken from ``_transport_cert_cache`` when
+    possible and re-fetched via ``vaultconfig_show`` otherwise.
+    """
+
+    def _generate_session_key(self):
+        """Return a TripleDES algorithm instance with a random key."""
+        key_length = max(algorithms.TripleDES.key_sizes)
+        algo = algorithms.TripleDES(os.urandom(key_length // 8))
+        return algo
+
+    def _do_internal(self, algo, transport_cert, *args, **options):
+        """Wrap the session key and call ``<name>_internal`` once.
+
+        :raises errors.NotFound: propagated unchanged
+        :raises _TransportCertInvalid: on InternalError, ExecutionError or
+            GenericError from the server, which may indicate a stale
+            transport certificate
+        """
+        public_key = transport_cert.public_key()
+
+        # wrap session key with transport certificate
+        wrapped_session_key = public_key.encrypt(
+            algo.key,
+            padding.PKCS1v15()
+        )
+        options['session_key'] = wrapped_session_key
+
+        name = self.name + '_internal'
+        try:
+            return self.api.Command[name](*args, **options)
+        except errors.NotFound:
+            # re-raise as-is so it does not trigger a certificate retry
+            raise
+        except (errors.InternalError,
+                errors.ExecutionError,
+                errors.GenericError):
+            raise _TransportCertInvalid(sys.exc_info())
+
+    def internal(self, algo, *args, **options):
+        """
+        Calls the internal counterpart of the command.
+
+        The cached transport certificate is tried first; if the server
+        rejects it, it is uncached, re-fetched with ``vaultconfig_show``
+        and the call is retried once.  If the retry fails as well, the
+        re-fetched certificate is uncached too and the server error of the
+        retry is re-raised unchanged.
+        """
+        domain = self.api.env.domain
+
+        # try call with cached transport cert, uncache it if unsuccessful
+        transport_cert = _transport_cert_cache.get(domain)
+        if transport_cert is not None:
+            try:
+                return self._do_internal(algo, transport_cert,
+                                         *args, **options)
+            except _TransportCertInvalid as e:
+                # drop the saved traceback to avoid a reference cycle
+                e.exc_info = None
+                _transport_cert_cache.pop(domain, None)
+
+        # retrieve transport certificate; vaultconfig_show also caches it
+        self.api.Command.vaultconfig_show()
+        transport_cert = _transport_cert_cache[domain]
+
+        # call with the retrieved transport cert, uncache it if unsuccessful
+        try:
+            return self._do_internal(algo, transport_cert, *args, **options)
+        except _TransportCertInvalid as e:
+            _transport_cert_cache.pop(domain, None)
+            try:
+                six.reraise(*e.exc_info)
+            finally:
+                e.exc_info = None
+
+
 @register(no_fail=True)
 class _fake_vault_archive_internal(Method):
     name = 'vault_archive_internal'
@@ -576,7 +730,7 @@ class _fake_vault_archive_internal(Method):


 @register()
-class vault_archive(Local):
+class vault_archive(ModVaultData):
     __doc__ = _('Archive data into a vault.')

     takes_options = (
@@ -640,28 +794,15 @@ def get_output_params(self):
     def _iter_output(self):
         return self.api.Command.vault_archive_internal.output()

-    def _wrap_data(self, transport_cert_der, json_vault_data):
+    def _wrap_data(self, algo, json_vault_data):
-        """Encrypt data with wrapped session key and transport cert
+        """Encrypt vault data with the session key *algo*.

-        :param bytes transport_cert_der: transport cert in DER encoding
+        :param algo: session key cipher (algorithms.TripleDES instance)
         :param bytes json_vault_data: dumped vault data
-        :return:
+        :return: tuple ``(nonce, wrapped_vault_data)``
         """
-        transport_cert = load_der_x509_certificate(
-            transport_cert_der, default_backend())
-        public_key = transport_cert.public_key()
-
-        # generate session key
-        key_length = max(algorithms.TripleDES.key_sizes)
-        algo = algorithms.TripleDES(os.urandom(key_length // 8))
         nonce = os.urandom(algo.block_size // 8)

-        # wrap session key with transport certificate
-        wrapped_session_key = public_key.encrypt(
-            algo.key,
-            padding.PKCS1v15()
-        )
-
         # wrap vault_data with session key
         padder = PKCS7(algo.block_size).padder()
         padded_data = padder.update(json_vault_data)
@@ -671,7 +812,7 @@ def _wrap_data(self, transport_cert_der, json_vault_data):
         encryptor = cipher.encryptor()
         wrapped_vault_data = encryptor.update(padded_data) + encryptor.finalize()

-        return wrapped_session_key, nonce, wrapped_vault_data
+        return nonce, wrapped_vault_data

     def forward(self, *args, **options):
         data = options.get('data')
@@ -806,20 +947,15 @@ def forward(self, *args, **options):

         json_vault_data = json.dumps(vault_data).encode('utf-8')

-        # retrieve transport certificate
-        config = self.api.Command.vaultconfig_show()['result']
-        transport_cert_der = config['transport_cert']
-        # created wrapped session key and wrap vault data
-        wrapped_session_key, nonce, wrapped_vault_data = self._wrap_data(
-            transport_cert_der, json_vault_data
-
-        )
+        # generate session key
+        algo = self._generate_session_key()
+        # encrypt vault data with the session key
+        nonce, wrapped_vault_data = self._wrap_data(algo, json_vault_data)
         options.update(
-            session_key=wrapped_session_key,
             nonce=nonce,
             vault_data=wrapped_vault_data
         )
-        return self.api.Command.vault_archive_internal(*args, **options)
+        return self.internal(algo, *args, **options)


 @register(no_fail=True)
@@ -829,7 +965,7 @@ class _fake_vault_retrieve_internal(Method):


 @register()
-class vault_retrieve(Local):
+class vault_retrieve(ModVaultData):
     __doc__ = _('Retrieve a data from a vault.')

     takes_options = (
@@ -899,20 +1035,6 @@ def get_output_params(self):
     def _iter_output(self):
         return self.api.Command.vault_retrieve_internal.output()

-    def _wrap_session_key(self, transport_cert_der):
-        transport_cert = load_der_x509_certificate(
-            transport_cert_der, default_backend())
-        public_key = transport_cert.public_key()
-        # generate session key
-        key_length = max(algorithms.TripleDES.key_sizes)
-        algo = algorithms.TripleDES(os.urandom(key_length // 8))
-        # wrap session key with transport certificate
-        wrapped_session_key = public_key.encrypt(
-            algo.key,
-            padding.PKCS1v15()
-        )
-        return algo, wrapped_session_key
-
     def _unwrap_response(self, algo, nonce, vault_data):
         cipher = Cipher(algo, modes.CBC(nonce), backend=default_backend())
         # decrypt
@@ -957,15 +1079,10 @@ def forward(self, *args, **options):
         vault = self.api.Command.vault_show(*args, **options)['result']
         vault_type = vault['ipavaulttype'][0]

-        # retrieve transport certificate
-        config = self.api.Command.vaultconfig_show()['result']
-        # create algo and wrap session key with transport cert
-        algo, wrapped_session_key = self._wrap_session_key(
-            config['transport_cert']
-        )
+        # generate session key; ModVaultData.internal wraps and sends it
+        algo = self._generate_session_key()
         # send retrieval request to server
-        options['session_key'] = wrapped_session_key
-        response = self.api.Command.vault_retrieve_internal(*args, **options)
+        response = self.internal(algo, *args, **options)
         # unwrap data with session key
         vault_data = self._unwrap_response(
             algo,
diff --git a/ipaclient/remote_plugins/__init__.py b/ipaclient/remote_plugins/__init__.py
index da7004d..f95b9b7 100644
--- a/ipaclient/remote_plugins/__init__.py
+++ b/ipaclient/remote_plugins/__init__.py
@@ -12,6 +12,7 @@
 from . import compat
 from . import schema
 from ipaclient.plugins.rpcclient import rpcclient
+from ipalib.constants import USER_CACHE_PATH
 from ipapython.dnsutil import DNSName
 from ipapython.ipa_log_manager import log_mgr
 
@@ -19,7 +20,7 @@
 
 
 class ServerInfo(collections.MutableMapping):
-    _DIR = os.path.join(schema.USER_CACHE_PATH, 'ipa', 'servers')
+    _DIR = os.path.join(USER_CACHE_PATH, 'ipa', 'servers')
 
     def __init__(self, api):
         hostname = DNSName(api.env.server).ToASCII()
diff --git a/ipaclient/remote_plugins/schema.py b/ipaclient/remote_plugins/schema.py
index 0cdce9d..3ecd608 100644
--- a/ipaclient/remote_plugins/schema.py
+++ b/ipaclient/remote_plugins/schema.py
@@ -15,6 +15,7 @@
 
 from ipaclient.frontend import ClientCommand, ClientMethod
 from ipalib import errors, parameters, plugable
+from ipalib.constants import USER_CACHE_PATH
 from ipalib.errors import SchemaUpToDate
 from ipalib.frontend import Object
 from ipalib.output import Output
@@ -29,17 +30,6 @@
 if six.PY3:
     unicode = str
 
-USER_CACHE_PATH = (
-    os.environ.get('XDG_CACHE_HOME') or
-    os.path.join(
-        os.environ.get(
-            'HOME',
-            os.path.expanduser('~')
-        ),
-        '.cache'
-    )
-)
-
 _TYPES = {
     'DN': DN,
     'DNSName': DNSName,
diff --git a/ipalib/constants.py b/ipalib/constants.py
index 8789a95..61065e4 100644
--- a/ipalib/constants.py
+++ b/ipalib/constants.py
@@ -21,6 +21,8 @@
 """
 All constants centralised in one file.
 """
+
+import os
 import socket
 from ipapython.dn import DN
 from ipapython.version import VERSION, API_VERSION
@@ -296,3 +298,15 @@
     "tls1.2"
 ]
 TLS_VERSION_MINIMAL = "tls1.0"
+
+# Per-user cache root: $XDG_CACHE_HOME if set and non-empty, else ~/.cache
+USER_CACHE_PATH = (
+    os.environ.get('XDG_CACHE_HOME') or
+    os.path.join(
+        os.environ.get(
+            'HOME',
+            os.path.expanduser('~')
+        ),
+        '.cache'
+    )
+)
-- 
Manage your subscription for the Freeipa-devel mailing list:
https://www.redhat.com/mailman/listinfo/freeipa-devel
Contribute to FreeIPA: http://www.freeipa.org/page/Contribute/Code

Reply via email to