URL: https://github.com/freeipa/freeipa/pull/476
Author: HonzaCholasta
 Title: #476: vault: cache the transport certificate on client
Action: synchronized

To pull the PR as Git branch:
git remote add ghfreeipa https://github.com/freeipa/freeipa
git fetch ghfreeipa pull/476/head:pr476
git checkout pr476
From fabb614a57d2d2eec40e846d73a0b1ba18aaf836 Mon Sep 17 00:00:00 2001
From: Jan Cholasta <jchol...@redhat.com>
Date: Fri, 17 Feb 2017 11:25:17 +0100
Subject: [PATCH] vault: cache the transport certificate on client

https://fedorahosted.org/freeipa/ticket/6652
---
 ipaclient/plugins/vault.py           | 149 +++++++++++++++++++++++++++--------
 ipaclient/remote_plugins/__init__.py |   3 +-
 ipaclient/remote_plugins/schema.py   |  12 +--
 ipalib/constants.py                  |  14 ++++
 4 files changed, 131 insertions(+), 47 deletions(-)

diff --git a/ipaclient/plugins/vault.py b/ipaclient/plugins/vault.py
index 9efb1f1..809f7b0 100644
--- a/ipaclient/plugins/vault.py
+++ b/ipaclient/plugins/vault.py
@@ -20,29 +20,41 @@
 from __future__ import print_function
 
 import base64
+import errno
 import getpass
 import io
 import json
 import os
 import sys
+import tempfile
 
 from cryptography.fernet import Fernet, InvalidToken
 from cryptography.hazmat.backends import default_backend
-from cryptography.hazmat.primitives import hashes
+from cryptography.hazmat.primitives import hashes, serialization
 from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC
 from cryptography.hazmat.primitives.asymmetric import padding
 from cryptography.hazmat.primitives.serialization import load_pem_public_key,\
     load_pem_private_key
 
 import nss.nss as nss
+import six
 
 from ipaclient.frontend import MethodOverride
+from ipalib import x509
+from ipalib.constants import USER_CACHE_PATH
 from ipalib.frontend import Local, Method, Object
 from ipalib.util import classproperty
 from ipalib import api, errors
 from ipalib import Bytes, Flag, Str
 from ipalib.plugable import Registry
 from ipalib import _
+from ipapython.dnsutil import DNSName
+from ipapython.ipa_log_manager import log_mgr
+
+logger = log_mgr.get_logger(__name__)
+
+TRANSPORT_CERT_CACHE_PATH = (  # one <domain>.pem file per IPA domain
+    os.path.join(USER_CACHE_PATH, 'ipa', 'kra-transport-certs'))
 
 
 def validated_read(argname, filename, mode='r', encoding=None):
@@ -568,6 +580,103 @@ def forward(self, *args, **options):
         return response
 
 
+class _TransportCertInvalid(Exception):  # raised by _do_internal()
+    def __init__(self, exc_info):
+        self.exc_info = exc_info  # sys.exc_info() of the failed call
+
+
+class ModVaultData(Local):
+    def _do_internal(self, mechanism, session_key, transport_cert_der,
+                     *args, **options):
+        nss_transport_cert = nss.Certificate(transport_cert_der)
+
+        # wrap session key with transport certificate
+        # pylint: disable=no-member
+        public_key = nss_transport_cert.subject_public_key_info.public_key
+        # pylint: enable=no-member
+        wrapped_session_key = nss.pub_wrap_sym_key(mechanism,
+                                                   public_key,
+                                                   session_key)
+
+        options['session_key'] = wrapped_session_key.data
+
+        name = self.name + '_internal'  # e.g. vault_archive_internal
+        try:
+            return self.api.Command[name](*args, **options)
+        except errors.NotFound:
+            raise  # pass NotFound through; never treated as a cert failure
+        except (errors.InternalError,
+                errors.ExecutionError,
+                errors.GenericError):
+            raise _TransportCertInvalid(sys.exc_info())  # see internal()
+
+    def internal(self, mechanism, session_key, *args, **options):
+        """
+        Call the *_internal command with a cached or freshly fetched cert.
+        """
+        dirname = TRANSPORT_CERT_CACHE_PATH
+        basename = DNSName(self.api.env.domain).ToASCII() + '.pem'
+        filename = os.path.join(dirname, basename)
+
+        # try call with the cached transport cert, if there is one
+        transport_cert = None
+        try:
+            try:
+                transport_cert = x509.load_certificate_from_file(filename)
+            except EnvironmentError as e:
+                if e.errno != errno.ENOENT:
+                    raise
+        except Exception:
+            logger.warning("Failed to load %s", filename,
+                           exc_info=True)
+
+        if transport_cert is not None:
+            transport_cert_der = (
+                transport_cert.public_bytes(serialization.Encoding.DER))
+
+            try:
+                return self._do_internal(mechanism,
+                                         session_key,
+                                         transport_cert_der,
+                                         *args, **options)
+            except _TransportCertInvalid:  # cached cert rejected: drop it
+                try:
+                    os.remove(filename)
+                except EnvironmentError:
+                    logger.warning("Failed to remove %s", filename,
+                                   exc_info=True)
+
+        # retrieve transport certificate
+        config = self.api.Command.vaultconfig_show()['result']
+        transport_cert_der = config['transport_cert']
+
+        # call with the retrieved transport cert, cache it if successful
+        transport_cert_valid = True  # cleared if the cert is rejected
+        try:
+            return self._do_internal(mechanism,
+                                     session_key,
+                                     transport_cert_der,
+                                     *args, **options)
+        except _TransportCertInvalid as e:
+            transport_cert_valid = False
+            six.reraise(*e.exc_info)
+        finally:
+            if transport_cert_valid:
+                try:
+                    try:
+                        os.makedirs(dirname)
+                    except EnvironmentError as e:
+                        if e.errno != errno.EEXIST:
+                            raise
+                    fd, tmpfilename = tempfile.mkstemp(dir=dirname)
+                    os.close(fd)  # NOTE(review): tmpfile leaks on error
+                    x509.write_certificate(transport_cert_der, tmpfilename)
+                    os.rename(tmpfilename, filename)
+                except Exception:
+                    logger.warning("Failed to save %s", filename,
+                                   exc_info=True)
+
+
 @register(no_fail=True)
 class _fake_vault_archive_internal(Method):
     name = 'vault_archive_internal'
@@ -575,7 +684,7 @@ class _fake_vault_archive_internal(Method):
 
 
 @register()
-class vault_archive(Local):
+class vault_archive(ModVaultData):
     __doc__ = _('Archive data into a vault.')
 
     takes_options = (
@@ -765,27 +874,12 @@ def forward(self, *args, **options):
         # initialize NSS database
         nss.nss_init(api.env.nss_dir)
 
-        # retrieve transport certificate
-        config = self.api.Command.vaultconfig_show()['result']
-        transport_cert_der = config['transport_cert']
-        nss_transport_cert = nss.Certificate(transport_cert_der)
-
         # generate session key
         mechanism = nss.CKM_DES3_CBC_PAD
         slot = nss.get_best_slot(mechanism)
         key_length = slot.get_best_key_length(mechanism)
         session_key = slot.key_gen(mechanism, None, key_length)
 
-        # wrap session key with transport certificate
-        # pylint: disable=no-member
-        public_key = nss_transport_cert.subject_public_key_info.public_key
-        # pylint: enable=no-member
-        wrapped_session_key = nss.pub_wrap_sym_key(mechanism,
-                                                   public_key,
-                                                   session_key)
-
-        options['session_key'] = wrapped_session_key.data
-
         nonce_length = nss.get_iv_length(mechanism)
         nonce = nss.generate_random(nonce_length)
         options['nonce'] = nonce
@@ -813,7 +907,7 @@ def forward(self, *args, **options):
 
         options['vault_data'] = wrapped_vault_data
 
-        return self.api.Command.vault_archive_internal(*args, **options)
+        return self.internal(mechanism, session_key, *args, **options)
 
 
 @register(no_fail=True)
@@ -823,7 +917,7 @@ class _fake_vault_retrieve_internal(Method):
 
 
 @register()
-class vault_retrieve(Local):
+class vault_retrieve(ModVaultData):
     __doc__ = _('Retrieve a data from a vault.')
 
     takes_options = (
@@ -928,29 +1022,14 @@ def forward(self, *args, **options):
         # initialize NSS database
         nss.nss_init(api.env.nss_dir)
 
-        # retrieve transport certificate
-        config = self.api.Command.vaultconfig_show()['result']
-        transport_cert_der = config['transport_cert']
-        nss_transport_cert = nss.Certificate(transport_cert_der)
-
         # generate session key
         mechanism = nss.CKM_DES3_CBC_PAD
         slot = nss.get_best_slot(mechanism)
         key_length = slot.get_best_key_length(mechanism)
         session_key = slot.key_gen(mechanism, None, key_length)
 
-        # wrap session key with transport certificate
-        # pylint: disable=no-member
-        public_key = nss_transport_cert.subject_public_key_info.public_key
-        # pylint: enable=no-member
-        wrapped_session_key = nss.pub_wrap_sym_key(mechanism,
-                                                   public_key,
-                                                   session_key)
-
         # send retrieval request to server
-        options['session_key'] = wrapped_session_key.data
-
-        response = self.api.Command.vault_retrieve_internal(*args, **options)
+        response = self.internal(mechanism, session_key, *args, **options)
 
         result = response['result']
         nonce = result['nonce']
diff --git a/ipaclient/remote_plugins/__init__.py b/ipaclient/remote_plugins/__init__.py
index da7004d..f95b9b7 100644
--- a/ipaclient/remote_plugins/__init__.py
+++ b/ipaclient/remote_plugins/__init__.py
@@ -12,6 +12,7 @@
 from . import compat
 from . import schema
 from ipaclient.plugins.rpcclient import rpcclient
+from ipalib.constants import USER_CACHE_PATH
 from ipapython.dnsutil import DNSName
 from ipapython.ipa_log_manager import log_mgr
 
@@ -19,7 +20,7 @@
 
 
 class ServerInfo(collections.MutableMapping):
-    _DIR = os.path.join(schema.USER_CACHE_PATH, 'ipa', 'servers')
+    _DIR = os.path.join(USER_CACHE_PATH, 'ipa', 'servers')
 
     def __init__(self, api):
         hostname = DNSName(api.env.server).ToASCII()
diff --git a/ipaclient/remote_plugins/schema.py b/ipaclient/remote_plugins/schema.py
index 15c03f4..34c155f 100644
--- a/ipaclient/remote_plugins/schema.py
+++ b/ipaclient/remote_plugins/schema.py
@@ -17,6 +17,7 @@
 
 from ipaclient.frontend import ClientCommand, ClientMethod
 from ipalib import errors, parameters, plugable
+from ipalib.constants import USER_CACHE_PATH
 from ipalib.errors import SchemaUpToDate
 from ipalib.frontend import Object
 from ipalib.output import Output
@@ -31,17 +32,6 @@
 if six.PY3:
     unicode = str
 
-USER_CACHE_PATH = (
-    os.environ.get('XDG_CACHE_HOME') or
-    os.path.join(
-        os.environ.get(
-            'HOME',
-            os.path.expanduser('~')
-        ),
-        '.cache'
-    )
-)
-
 _TYPES = {
     'DN': DN,
     'DNSName': DNSName,
diff --git a/ipalib/constants.py b/ipalib/constants.py
index e64324f..4205f5b 100644
--- a/ipalib/constants.py
+++ b/ipalib/constants.py
@@ -21,6 +21,8 @@
 """
 All constants centralised in one file.
 """
+
+import os
 import socket
 from ipapython.dn import DN
 from ipapython.version import VERSION, API_VERSION
@@ -293,3 +295,15 @@
     "tls1.2"
 ]
 TLS_VERSION_MINIMAL = "tls1.0"
+
+# Per-user cache dir, resolved at import: $XDG_CACHE_HOME or ~/.cache
+USER_CACHE_PATH = (
+    os.environ.get('XDG_CACHE_HOME') or
+    os.path.join(
+        os.environ.get(
+            'HOME',
+            os.path.expanduser('~')
+        ),
+        '.cache'
+    )
+)
-- 
Manage your subscription for the Freeipa-devel mailing list:
https://www.redhat.com/mailman/listinfo/freeipa-devel
Contribute to FreeIPA: http://www.freeipa.org/page/Contribute/Code

Reply via email to