URL: https://github.com/freeipa/freeipa/pull/476
Author: HonzaCholasta
 Title: #476: vault: cache the transport certificate on client
Action: opened

PR body:
"""
https://fedorahosted.org/freeipa/ticket/6652
"""

To pull the PR as Git branch:
git remote add ghfreeipa https://github.com/freeipa/freeipa
git fetch ghfreeipa pull/476/head:pr476
git checkout pr476
From 0c51e0a6ab0cb4e08a8720f6eb45ba8b244682af Mon Sep 17 00:00:00 2001
From: Jan Cholasta <jchol...@redhat.com>
Date: Fri, 17 Feb 2017 11:25:17 +0100
Subject: [PATCH] vault: cache the transport certificate on client

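Cache the KRA transport certificate on the client, one PEM file per IPA
domain in $XDG_CACHE_HOME/ipa/kra-transport-certs ($HOME/.cache by
default), so that the vault archive and retrieve commands no longer call
vaultconfig-show on every invocation. If a request made with the cached
certificate fails with a server error other than NotFound, the cache
entry is removed and the request is retried once with a freshly fetched
certificate.

USER_CACHE_PATH is moved from ipaclient.remote_plugins.schema to
ipalib.constants so that it can be shared with the vault plugin.

https://fedorahosted.org/freeipa/ticket/6652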
---
 ipaclient/plugins/vault.py           | 173 ++++++++++++++++++++++++++++-------
 ipaclient/remote_plugins/__init__.py |   3 +-
 ipaclient/remote_plugins/schema.py   |  12 +--
 ipalib/constants.py                  |  14 +++
 4 files changed, 158 insertions(+), 44 deletions(-)

diff --git a/ipaclient/plugins/vault.py b/ipaclient/plugins/vault.py
index 9efb1f1..52ffbd8 100644
--- a/ipaclient/plugins/vault.py
+++ b/ipaclient/plugins/vault.py
@@ -20,15 +20,17 @@
 from __future__ import print_function
 
 import base64
+import errno
 import getpass
 import io
 import json
 import os
 import sys
+import tempfile
 
 from cryptography.fernet import Fernet, InvalidToken
 from cryptography.hazmat.backends import default_backend
-from cryptography.hazmat.primitives import hashes
+from cryptography.hazmat.primitives import hashes, serialization
 from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC
 from cryptography.hazmat.primitives.asymmetric import padding
 from cryptography.hazmat.primitives.serialization import load_pem_public_key,\
@@ -37,12 +39,21 @@
 import nss.nss as nss
 
 from ipaclient.frontend import MethodOverride
+from ipalib import x509
+from ipalib.constants import USER_CACHE_PATH
 from ipalib.frontend import Local, Method, Object
 from ipalib.util import classproperty
 from ipalib import api, errors
 from ipalib import Bytes, Flag, Str
 from ipalib.plugable import Registry
 from ipalib import _
+from ipapython.dnsutil import DNSName
+from ipapython.ipa_log_manager import log_mgr
+
+logger = log_mgr.get_logger(__name__)
+# Directory of cached KRA transport certificates, one <domain>.pem each
+TRANSPORT_CERT_CACHE_PATH = (
+    os.path.join(USER_CACHE_PATH, 'ipa', 'kra-transport-certs'))
 
 
 def validated_read(argname, filename, mode='r', encoding=None):
@@ -169,6 +180,58 @@ def decrypt(data, symmetric_key=None, private_key=None):
                 message=_('Invalid credentials'))
 
 
+def get_cached_transport_cert(domain):
+    basename = DNSName(domain).ToASCII() + '.pem'
+    filename = os.path.join(TRANSPORT_CERT_CACHE_PATH, basename)
+
+    try:
+        try:
+            return x509.load_certificate_from_file(filename)
+        except EnvironmentError as e:
+            if e.errno != errno.ENOENT:
+                raise
+    except Exception as e:
+        logger.warning("Failed to load %s: %s", filename, e)
+
+    raise KeyError(domain)
+
+
+def set_cached_transport_cert(domain, cert):
+    basename = DNSName(domain).ToASCII() + '.pem'
+    filename = os.path.join(TRANSPORT_CERT_CACHE_PATH, basename)
+
+    try:
+        data = cert.public_bytes(serialization.Encoding.PEM)
+        try:
+            os.makedirs(TRANSPORT_CERT_CACHE_PATH)
+        except EnvironmentError as e:
+            if e.errno != errno.EEXIST:
+                raise
+        with tempfile.NamedTemporaryFile(dir=TRANSPORT_CERT_CACHE_PATH,
+                                         delete=False) as f:
+            f.write(data)
+            os.rename(f.name, filename)
+    except Exception as e:
+        logger.warning("Failed to save %s: %s", filename, e)
+
+
+def del_cached_transport_cert(domain):
+    basename = DNSName(domain).ToASCII() + '.pem'
+    filename = os.path.join(TRANSPORT_CERT_CACHE_PATH, basename)
+
+    try:
+        try:
+            os.remove(filename)
+            return
+        except EnvironmentError as e:
+            if e.errno != errno.ENOENT:
+                raise
+    except Exception as e:
+        logger.warning("Failed to remove %s: %s", filename, e)
+
+    raise KeyError(domain)
+
+
 @register(no_fail=True)
 class _fake_vault(Object):
     name = 'vault'
@@ -765,27 +828,12 @@ def forward(self, *args, **options):
         # initialize NSS database
         nss.nss_init(api.env.nss_dir)
 
-        # retrieve transport certificate
-        config = self.api.Command.vaultconfig_show()['result']
-        transport_cert_der = config['transport_cert']
-        nss_transport_cert = nss.Certificate(transport_cert_der)
-
         # generate session key
         mechanism = nss.CKM_DES3_CBC_PAD
         slot = nss.get_best_slot(mechanism)
         key_length = slot.get_best_key_length(mechanism)
         session_key = slot.key_gen(mechanism, None, key_length)
 
-        # wrap session key with transport certificate
-        # pylint: disable=no-member
-        public_key = nss_transport_cert.subject_public_key_info.public_key
-        # pylint: enable=no-member
-        wrapped_session_key = nss.pub_wrap_sym_key(mechanism,
-                                                   public_key,
-                                                   session_key)
-
-        options['session_key'] = wrapped_session_key.data
-
         nonce_length = nss.get_iv_length(mechanism)
         nonce = nss.generate_random(nonce_length)
         options['nonce'] = nonce
@@ -813,7 +861,46 @@ def forward(self, *args, **options):
 
         options['vault_data'] = wrapped_vault_data
 
-        return self.api.Command.vault_archive_internal(*args, **options)
+        for retry in (True, False):
+            # retrieve transport certificate
+            try:
+                transport_cert = get_cached_transport_cert(api.env.domain)
+            except KeyError:
+                config = self.api.Command.vaultconfig_show()['result']
+                transport_cert_der = config['transport_cert']
+                transport_cert = (
+                    x509.load_certificate(transport_cert_der, x509.DER))
+                set_cached_transport_cert(api.env.domain, transport_cert)
+            else:
+                transport_cert_der = (
+                    transport_cert.public_bytes(serialization.Encoding.DER))
+            nss_transport_cert = nss.Certificate(transport_cert_der)
+
+            # wrap session key with transport certificate
+            # pylint: disable=no-member
+            public_key = nss_transport_cert.subject_public_key_info.public_key
+            # pylint: enable=no-member
+            wrapped_session_key = nss.pub_wrap_sym_key(mechanism,
+                                                       public_key,
+                                                       session_key)
+
+            options['session_key'] = wrapped_session_key.data
+
+            try:
+                result = self.api.Command.vault_archive_internal(*args,
+                                                                 **options)
+            except errors.NotFound:
+                raise
+            except (errors.InternalError,
+                    errors.ExecutionError,
+                    errors.GenericError):
+                del_cached_transport_cert(api.env.domain)
+                if not retry:
+                    raise
+            else:
+                break
+
+        return result
 
 
 @register(no_fail=True)
@@ -928,29 +1015,51 @@ def forward(self, *args, **options):
         # initialize NSS database
         nss.nss_init(api.env.nss_dir)
 
-        # retrieve transport certificate
-        config = self.api.Command.vaultconfig_show()['result']
-        transport_cert_der = config['transport_cert']
-        nss_transport_cert = nss.Certificate(transport_cert_der)
-
         # generate session key
         mechanism = nss.CKM_DES3_CBC_PAD
         slot = nss.get_best_slot(mechanism)
         key_length = slot.get_best_key_length(mechanism)
         session_key = slot.key_gen(mechanism, None, key_length)
 
-        # wrap session key with transport certificate
-        # pylint: disable=no-member
-        public_key = nss_transport_cert.subject_public_key_info.public_key
-        # pylint: enable=no-member
-        wrapped_session_key = nss.pub_wrap_sym_key(mechanism,
-                                                   public_key,
-                                                   session_key)
+        for retry in (True, False):
+            # retrieve transport certificate
+            try:
+                transport_cert = get_cached_transport_cert(api.env.domain)
+            except KeyError:
+                config = self.api.Command.vaultconfig_show()['result']
+                transport_cert_der = config['transport_cert']
+                transport_cert = (
+                    x509.load_certificate(transport_cert_der, x509.DER))
+                set_cached_transport_cert(api.env.domain, transport_cert)
+            else:
+                transport_cert_der = (
+                    transport_cert.public_bytes(serialization.Encoding.DER))
+            nss_transport_cert = nss.Certificate(transport_cert_der)
+
+            # wrap session key with transport certificate
+            # pylint: disable=no-member
+            public_key = nss_transport_cert.subject_public_key_info.public_key
+            # pylint: enable=no-member
+            wrapped_session_key = nss.pub_wrap_sym_key(mechanism,
+                                                       public_key,
+                                                       session_key)
 
-        # send retrieval request to server
-        options['session_key'] = wrapped_session_key.data
+            # send retrieval request to server
+            options['session_key'] = wrapped_session_key.data
 
-        response = self.api.Command.vault_retrieve_internal(*args, **options)
+            try:
+                response = self.api.Command.vault_retrieve_internal(*args,
+                                                                    **options)
+            except errors.NotFound:
+                raise
+            except (errors.InternalError,
+                    errors.ExecutionError,
+                    errors.GenericError):
+                del_cached_transport_cert(api.env.domain)
+                if not retry:
+                    raise
+            else:
+                break
 
         result = response['result']
         nonce = result['nonce']
diff --git a/ipaclient/remote_plugins/__init__.py b/ipaclient/remote_plugins/__init__.py
index da7004d..f95b9b7 100644
--- a/ipaclient/remote_plugins/__init__.py
+++ b/ipaclient/remote_plugins/__init__.py
@@ -12,6 +12,7 @@
 from . import compat
 from . import schema
 from ipaclient.plugins.rpcclient import rpcclient
+from ipalib.constants import USER_CACHE_PATH
 from ipapython.dnsutil import DNSName
 from ipapython.ipa_log_manager import log_mgr
 
@@ -19,7 +20,7 @@
 
 
 class ServerInfo(collections.MutableMapping):
-    _DIR = os.path.join(schema.USER_CACHE_PATH, 'ipa', 'servers')
+    _DIR = os.path.join(USER_CACHE_PATH, 'ipa', 'servers')
 
     def __init__(self, api):
         hostname = DNSName(api.env.server).ToASCII()
diff --git a/ipaclient/remote_plugins/schema.py b/ipaclient/remote_plugins/schema.py
index 15c03f4..34c155f 100644
--- a/ipaclient/remote_plugins/schema.py
+++ b/ipaclient/remote_plugins/schema.py
@@ -17,6 +17,7 @@
 
 from ipaclient.frontend import ClientCommand, ClientMethod
 from ipalib import errors, parameters, plugable
+from ipalib.constants import USER_CACHE_PATH
 from ipalib.errors import SchemaUpToDate
 from ipalib.frontend import Object
 from ipalib.output import Output
@@ -31,17 +32,6 @@
 if six.PY3:
     unicode = str
 
-USER_CACHE_PATH = (
-    os.environ.get('XDG_CACHE_HOME') or
-    os.path.join(
-        os.environ.get(
-            'HOME',
-            os.path.expanduser('~')
-        ),
-        '.cache'
-    )
-)
-
 _TYPES = {
     'DN': DN,
     'DNSName': DNSName,
diff --git a/ipalib/constants.py b/ipalib/constants.py
index e64324f..4205f5b 100644
--- a/ipalib/constants.py
+++ b/ipalib/constants.py
@@ -21,6 +21,8 @@
 """
 All constants centralised in one file.
 """
+
+import os
 import socket
 from ipapython.dn import DN
 from ipapython.version import VERSION, API_VERSION
@@ -293,3 +295,15 @@
     "tls1.2"
 ]
 TLS_VERSION_MINIMAL = "tls1.0"
+
+# User cache directory: $XDG_CACHE_HOME, else $HOME/.cache
+USER_CACHE_PATH = (
+    os.environ.get('XDG_CACHE_HOME') or
+    os.path.join(
+        os.environ.get(
+            'HOME',
+            os.path.expanduser('~')
+        ),
+        '.cache'
+    )
+)
-- 
Manage your subscription for the Freeipa-devel mailing list:
https://www.redhat.com/mailman/listinfo/freeipa-devel
Contribute to FreeIPA: http://www.freeipa.org/page/Contribute/Code

Reply via email to