Fix CertificateDb locking scheme

Currently, CertificateDb code locks every file it is going to access, even though this is not really needed: a more general lock already protects those files.

This patch:
   - Replace the old FileLocker class with a pair of classes, Lock and Locker.
   - Replace most of the locks in CertificateDb with only two locks: one for
     locking the main database and one for the file containing the current
     serial number.

This is a Measurement Factory project
Fix CertificateDb locking scheme

Currently, CertificateDb code locks every file it is going to access, even
though this is not really needed: a more general lock already protects those files.

This patch:
   - Replace the old FileLocker class with a pair of classes, Lock and Locker.
   - Replace most of the locks in CertificateDb with only two locks: one for
     locking the main database and one for the file containing the current
     serial number.

This is a Measurement Factory project
=== modified file 'src/ssl/certificate_db.cc'
--- src/ssl/certificate_db.cc	2011-09-15 16:34:52 +0000
+++ src/ssl/certificate_db.cc	2011-09-22 09:25:48 +0000
@@ -1,69 +1,139 @@
 /*
  * $Id$
  */

 #include "config.h"
 #include "ssl/certificate_db.h"
+#if HAVE_ERRNO_H
+#include <errno.h>
+#endif
 #if HAVE_FSTREAM
 #include <fstream>
 #endif
 #if HAVE_STDEXCEPT
 #include <stdexcept>
 #endif
 #if HAVE_SYS_STAT_H
 #include <sys/stat.h>
 #endif
 #if HAVE_SYS_FILE_H
 #include <sys/file.h>
 #endif
 #if HAVE_FCNTL_H
 #include <fcntl.h>
 #endif

-Ssl::FileLocker::FileLocker(std::string const & filename)
-        :    fd(-1)
+#define HERE "(ssl_crtd) " << __FILE__ << ':' << __LINE__ << ": "
+
+Ssl::Lock::Lock(std::string const &aFilename) :
+    filename(aFilename),
+#if _SQUID_MSWIN_
+    hFile(INVALID_HANDLE_VALUE)
+#else
+    fd(-1)
+#endif
+{
+}
+
+bool Ssl::Lock::locked() const
 {
 #if _SQUID_MSWIN_
+    return hFile != INVALID_HANDLE_VALUE;
+#else
+    return fd != -1;
+#endif
+}
+
+void Ssl::Lock::lock()
+{
+    // re-locking would overwrite, and thus leak, the handle we already hold
+    if (locked())
+        throw std::runtime_error("Lock is already locked for " + filename);
+
+#if _SQUID_MSWIN_
+    // Allow shared read/write access: other processes must be able to open
+    // the file in order to wait for the lock, and this process still opens
+    // the protected file through other handles while holding the lock.
-    hFile = CreateFile(TEXT(filename.c_str()), GENERIC_READ, 0, NULL, OPEN_EXISTING, FILE_ATTRIBUTE_NORMAL, NULL);
-    if (hFile != INVALID_HANDLE_VALUE)
-        LockFile(hFile, 0, 0, 1, 0);
+    // CreateFileA() because TEXT() only applies to string literals
+    hFile = CreateFileA(filename.c_str(), GENERIC_READ, FILE_SHARE_READ | FILE_SHARE_WRITE, NULL, OPEN_EXISTING, FILE_ATTRIBUTE_NORMAL, NULL);
+    if (hFile == INVALID_HANDLE_VALUE)
 #else
     fd = open(filename.c_str(), 0);
-    if (fd != -1)
-        flock(fd, LOCK_EX);
+    if (fd == -1)
 #endif
+        throw std::runtime_error("Failed to open file " + filename);
+
+#if _SQUID_MSWIN_
+    // unlike LockFile(), LockFileEx() without LOCKFILE_FAIL_IMMEDIATELY waits
+    OVERLAPPED overlapped = {0};
+    if (!LockFileEx(hFile, LOCKFILE_EXCLUSIVE_LOCK, 0, 1, 0, &overlapped)) {
+        CloseHandle(hFile);
+        hFile = INVALID_HANDLE_VALUE;
+#else
+    if (flock(fd, LOCK_EX) != 0) {
+        // do not keep an unlocked descriptor that locked() would report as held
+        close(fd);
+        fd = -1;
+#endif
+        throw std::runtime_error("Failed to get a lock of " + filename);
+    }
 }

-Ssl::FileLocker::~FileLocker()
-{
+void Ssl::Lock::unlock()
+{
 #if _SQUID_MSWIN_
     if (hFile != INVALID_HANDLE_VALUE) {
         UnlockFile(hFile, 0, 0, 1, 0);
         CloseHandle(hFile);
+        hFile = INVALID_HANDLE_VALUE;
     }
 #else
     if (fd != -1) {
         flock(fd, LOCK_UN);
         close(fd);
+        fd = -1;
     }
 #endif
+    else
+        throw std::runtime_error("Lock is already unlocked for " + filename);
+}
+
+Ssl::Lock::~Lock()
+{
+    if (locked())
+        unlock();
+}
+
+Ssl::Locker::Locker(Lock &aLock, const char *aFileName, int aLineNo):
+    weLocked(false), lock(aLock), fileName(aFileName), lineNo(aLineNo)
+{
+    if (!lock.locked()) {
+        lock.lock();
+        weLocked = true;
+    }
+}
+
+Ssl::Locker::~Locker()
+{
+    if (weLocked)
+        lock.unlock();
 }

 Ssl::CertificateDb::Row::Row()
         :   width(cnlNumber)
 {
     row = new char *[width + 1];
     for (size_t i = 0; i < width + 1; i++)
         row[i] = NULL;
 }

 Ssl::CertificateDb::Row::~Row()
 {
     if (row) {
         for (size_t i = 0; i < width + 1; i++) {
             delete[](row[i]);
         }
         delete[](row);
     }
 }

@@ -113,60 +169,60 @@
 int Ssl::CertificateDb::index_name_cmp(const char **a, const char **b)
 {
     return(strcmp(a[Ssl::CertificateDb::cnlName], b[CertificateDb::cnlName]));
 }

 const std::string Ssl::CertificateDb::serial_file("serial");
 const std::string Ssl::CertificateDb::db_file("index.txt");
 const std::string Ssl::CertificateDb::cert_dir("certs");
 const std::string Ssl::CertificateDb::size_file("size");
 const size_t Ssl::CertificateDb::min_db_size(4096);

 Ssl::CertificateDb::CertificateDb(std::string const & aDb_path, size_t aMax_db_size, size_t aFs_block_size)
         :  db_path(aDb_path),
         serial_full(aDb_path + "/" + serial_file),
         db_full(aDb_path + "/" + db_file),
         cert_full(aDb_path + "/" + cert_dir),
         size_full(aDb_path + "/" + size_file),
         db(NULL),
         max_db_size(aMax_db_size),
         fs_block_size(aFs_block_size),
+        dbLock(db_full), // the database index file doubles as its lock file
+        dbSerialLock(serial_full), // the serial number file doubles as its lock file
         enabled_disk_store(true)
 {
     if (db_path.empty() && !max_db_size)
         enabled_disk_store = false;
     else if ((db_path.empty() && max_db_size) || (!db_path.empty() && !max_db_size))
         throw std::runtime_error("ssl_crtd is missing the required parameter. There should be -s and -M parameters together.");
-    else
-        load();
 }

 bool Ssl::CertificateDb::find(std::string const & host_name, Ssl::X509_Pointer & cert, Ssl::EVP_PKEY_Pointer & pkey)
 {
-    FileLocker db_locker(db_full);
+    const Locker locker(dbLock, Here); // held while load() and pure_find() run
     load();
     return pure_find(host_name, cert, pkey);
 }

 bool Ssl::CertificateDb::addCertAndPrivateKey(Ssl::X509_Pointer & cert, Ssl::EVP_PKEY_Pointer & pkey)
 {
-    FileLocker db_locker(db_full);
+    const Locker locker(dbLock, Here); // held for the whole update, including the file writes below
     load();
     if (!db || !cert || !pkey || min_db_size > max_db_size)
         return false;
     Row row;
     ASN1_INTEGER * ai = X509_get_serialNumber(cert.get());
     std::string serial_string;
     Ssl::BIGNUM_Pointer serial(ASN1_INTEGER_to_BN(ai, NULL));
     {
         TidyPointer<char, tidyFree> hex_bn(BN_bn2hex(serial.get()));
         serial_string = std::string(hex_bn.get());
     }
     row.setValue(cnlSerial, serial_string.c_str());
     char ** rrow = TXT_DB_get_by_index(db.get(), cnlSerial, row.getRow());
     if (rrow != NULL)
         return false;

     {
         TidyPointer<char, tidyFree> subject(X509_NAME_oneline(X509_get_subject_name(cert.get()), NULL, 0));
         if (pure_find(subject.get(), cert, pkey))
             return true;
@@ -178,52 +234,51 @@
     }

     while (max_db_size < size()) {
         deleteOldestCertificate();
     }

     row.setValue(cnlType, "V");
     ASN1_UTCTIME * tm = X509_get_notAfter(cert.get());
     row.setValue(cnlExp_date, std::string(reinterpret_cast<char *>(tm->data), tm->length).c_str());
     row.setValue(cnlFile, "unknown");
     {
         TidyPointer<char, tidyFree> subject(X509_NAME_oneline(X509_get_subject_name(cert.get()), NULL, 0));
         row.setValue(cnlName, subject.get());
     }

     if (!TXT_DB_insert(db.get(), row.getRow()))
         return false;

     row.reset();
     std::string filename(cert_full + "/" + serial_string + ".pem");
-    FileLocker cert_locker(filename);
     if (!writeCertAndPrivateKeyToFile(cert, pkey, filename.c_str()))
         return false;
     addSize(filename);

     save();
     return true;
 }

 BIGNUM * Ssl::CertificateDb::getCurrentSerialNumber()
 {
-    FileLocker serial_locker(serial_full);
+    const Locker locker(dbSerialLock, Here); // the serial file has its own lock, separate from dbLock
     // load serial number from file.
     Ssl::BIO_Pointer file(BIO_new(BIO_s_file()));
     if (!file)
         return NULL;

     if (BIO_rw_filename(file.get(), const_cast<char *>(serial_full.c_str())) <= 0)
         return NULL;

     Ssl::ASN1_INT_Pointer serial_ai(ASN1_INTEGER_new());
     if (!serial_ai)
         return NULL;

     char buffer[1024];
     if (!a2i_ASN1_INTEGER(file.get(), serial_ai.get(), buffer, sizeof(buffer)))
         return NULL;

     Ssl::BIGNUM_Pointer serial(ASN1_INTEGER_to_BN(serial_ai.get(), NULL));

     if (!serial)
         return NULL;
@@ -280,94 +335,91 @@
         throw std::runtime_error("SSL error");

     if (BIO_write_filename(file.get(), const_cast<char *>(serial_full.c_str())) <= 0)
         throw std::runtime_error("Cannot open " + cert_full + " to open");

     i2a_ASN1_INTEGER(file.get(), i.get());

     std::ofstream size(size_full.c_str());
     if (size)
         size << 0;
     else
         throw std::runtime_error("Cannot open " + size_full + " to open");
     std::ofstream db(db_full.c_str());
     if (!db)
         throw std::runtime_error("Cannot open " + db_full + " to open");
 }

 void Ssl::CertificateDb::check(std::string const & db_path, size_t max_db_size)
 {
     CertificateDb db(db_path, max_db_size, 0);
+    db.load(); // the constructor no longer calls load(), so do it explicitly
 }

 std::string Ssl::CertificateDb::getSNString() const
 {
-    FileLocker serial_locker(serial_full);
+    const Locker locker(dbSerialLock, Here); // dbSerialLock is mutable, so this const method can lock it
     std::ifstream file(serial_full.c_str());
     if (!file)
         return "";
     std::string serial;
     file >> serial;
     return serial;
 }

 bool Ssl::CertificateDb::pure_find(std::string const & host_name, Ssl::X509_Pointer & cert, Ssl::EVP_PKEY_Pointer & pkey)
 {
     if (!db)
         return false;

     Row row;
     row.setValue(cnlName, host_name.c_str());

     char **rrow = TXT_DB_get_by_index(db.get(), cnlName, row.getRow());
     if (rrow == NULL)
         return false;

     if (!sslDateIsInTheFuture(rrow[cnlExp_date])) {
         deleteByHostname(rrow[cnlName]);
         return false;
     }

     // read cert and pkey from file.
     std::string filename(cert_full + "/" + rrow[cnlSerial] + ".pem");
-    FileLocker cert_locker(filename);
     readCertAndPrivateKeyFromFiles(cert, pkey, filename.c_str(), NULL);
     if (!cert || !pkey)
         return false;
     return true;
 }

 size_t Ssl::CertificateDb::size() const
 {
-    FileLocker size_locker(size_full);
     return readSize();
 }

 void Ssl::CertificateDb::addSize(std::string const & filename)
 {
-    FileLocker size_locker(size_full);
     writeSize(readSize() + getFileSize(filename));
 }

 void Ssl::CertificateDb::subSize(std::string const & filename)
 {
-    FileLocker size_locker(size_full);
     writeSize(readSize() - getFileSize(filename));
 }

 size_t Ssl::CertificateDb::readSize() const
 {
     size_t db_size;
     std::ifstream size_file(size_full.c_str());
     if (!size_file && enabled_disk_store)
         throw std::runtime_error("cannot read \"" + size_full + "\" file");
     size_file >> db_size;
     return db_size;
 }

 void Ssl::CertificateDb::writeSize(size_t db_size)
 {
     std::ofstream size_file(size_full.c_str());
     if (!size_file && enabled_disk_store)
         throw std::runtime_error("cannot write \"" + size_full + "\" file");
     size_file << db_size;
 }
@@ -415,41 +467,41 @@

 void Ssl::CertificateDb::save()
 {
     if (!db)
         throw std::runtime_error("The certificates database is not loaded");;

     // To save the db to file,  create a new BIO with BIO file methods.
     Ssl::BIO_Pointer out(BIO_new(BIO_s_file()));
     if (!out || !BIO_write_filename(out.get(), const_cast<char *>(db_full.c_str())))
         throw std::runtime_error("Failed to initialize " + db_full + " file for writing");;

     if (TXT_DB_write(out.get(), db.get()) < 0)
         throw std::runtime_error("Failed to write " + db_full + " file");
 }

 // Normally defined in defines.h file
 #define countof(arr) (sizeof(arr)/sizeof(*arr))
 void Ssl::CertificateDb::deleteRow(const char **row, int rowIndex)
 {
     const std::string filename(cert_full + "/" + row[cnlSerial] + ".pem");
-    const FileLocker cert_locker(filename);
+    // no per-file lock: callers presumably hold dbLock, as find() and addCertAndPrivateKey() do (TODO confirm for every caller)
 #if OPENSSL_VERSION_NUMBER >= 0x1000004fL
     sk_OPENSSL_PSTRING_delete(db.get()->data, rowIndex);
 #else
     sk_delete(db.get()->data, rowIndex);
 #endif
     
     const Columns db_indexes[]={cnlSerial, cnlName};
     for (unsigned int i = 0; i < countof(db_indexes); i++) {
 #if OPENSSL_VERSION_NUMBER >= 0x1000004fL
         if (LHASH_OF(OPENSSL_STRING) *fieldIndex =  db.get()->index[db_indexes[i]])
             lh_OPENSSL_STRING_delete(fieldIndex, (char **)row);
 #else
         if (LHASH *fieldIndex = db.get()->index[db_indexes[i]])
             lh_delete(fieldIndex, row);
 #endif
     }
     
     subSize(filename);
     int ret = remove(filename.c_str());
     if (ret < 0)

=== modified file 'src/ssl/certificate_db.h'
--- src/ssl/certificate_db.h	2011-09-15 16:34:52 +0000
+++ src/ssl/certificate_db.h	2011-09-22 09:27:14 +0000
@@ -1,54 +1,82 @@
 /*
  * $Id$
  */

 #ifndef SQUID_SSL_CERTIFICATE_DB_H
 #define SQUID_SSL_CERTIFICATE_DB_H

 #include "ssl/gadgets.h"
 #include "ssl/support.h"
 #if HAVE_STRING
 #include <string>
 #endif
 #if HAVE_OPENSSL_OPENSSLV_H
 #include <openssl/opensslv.h>
 #endif

 namespace Ssl
 {
-/// Cross platform file locker.
-class FileLocker
-{
+/// maintains an exclusive blocking file-based lock
+class Lock {
 public:
-    /// Lock file
-    FileLocker(std::string const & aFilename);
-    /// Unlock file
-    ~FileLocker();
+    explicit Lock(std::string const &filename); ///<  creates an unlocked lock
+    ~Lock(); ///<  releases the lock if it is locked
+    void lock(); ///<  locks the lock, may block
+    void unlock(); ///<  unlocks locked lock or throws
+    bool locked() const; ///<  whether our lock is locked
+    const char *name() const { return filename.c_str(); } ///< the name of the file used for locking
 private:
+    std::string filename; ///< the file used for locking
 #if _SQUID_MSWIN_
     HANDLE hFile; ///< Windows file handle.
 #else
     int fd; ///< Linux file descriptor.
 #endif
+
+    // not implemented: a copy would release and close the same handle twice
+    Lock(const Lock &);
+    Lock &operator=(const Lock &);
 };

+/// an exception-safe way to obtain and release a lock
+class Locker
+{
+public:
+    /// locks the lock if the lock was unlocked
+    Locker(Lock &lock, const char  *aFileName, int lineNo);
+    /// unlocks the lock if it was locked by us
+    ~Locker();
+private:
+    bool weLocked; ///<  whether we locked the lock
+    Lock &lock; ///<  the lock we are operating on
+    const std::string fileName; ///<  where the lock was needed
+    const int lineNo; ///<  where the lock was needed
+
+    // not implemented: a copy would unlock the lock a second time
+    Locker(const Locker &);
+    Locker &operator=(const Locker &);
+};
+
+/// convenience macro to pass source code location to Locker and others
+#define Here __FILE__, __LINE__
+
 /**
  * Database class for storing SSL certificates and their private keys.
  * A database consist by:
  *     - A disk file to store current serial number
  *     - A disk file to store the current database size
  *     - A disk file which is a normal TXT_DB openSSL database
  *     - A directory under which the certificates and their private keys stored.
  *  The database before used must initialized with CertificateDb::create static method.
  */
 class CertificateDb
 {
 public:
     /// Names of db columns.
     enum Columns {
         cnlType = 0,
         cnlExp_date,
         cnlRev_date,
         cnlSerial,
         cnlFile,
         cnlName,
@@ -133,26 +153,28 @@
     static IMPLEMENT_LHASH_HASH_FN(index_name_hash,const char **)
     static IMPLEMENT_LHASH_COMP_FN(index_name_cmp,const char **)
 #endif

     static const std::string serial_file; ///< Base name of the file to store serial number.
     static const std::string db_file; ///< Base name of the database index file.
     static const std::string cert_dir; ///< Base name of the directory to store the certs.
     static const std::string size_file; ///< Base name of the file to store db size.
     /// Min size of disk db. If real size < min_db_size the  db will be disabled.
     static const size_t min_db_size;

     const std::string db_path; ///< The database directory.
     const std::string serial_full; ///< Full path of the file to store serial number.
     const std::string db_full; ///< Full path of the database index file.
     const std::string cert_full; ///< Full path of the directory to store the certs.
     const std::string size_full; ///< Full path of the file to store the db size.

     TXT_DB_Pointer db; ///< Database with certificates info.
     const size_t max_db_size; ///< Max size of db.
     const size_t fs_block_size; ///< File system block size.
+    mutable Lock dbLock;  ///< guards the database index file and the size/certificate files accessed while it is held; mutable so const methods may lock
+    mutable Lock dbSerialLock; ///< guards the serial number file; mutable so const methods such as getSNString() may lock

     bool enabled_disk_store; ///< The storage on the disk is enabled.
 };

 } // namespace Ssl
 #endif // SQUID_SSL_CERTIFICATE_DB_H

Reply via email to