When the ssl_crtd helper needs to add a fresh certificate to the database but finds an expired certificate already stored, ssl_crtd deletes the expired certificate file from disk before adding the fresh one. However, the addition still fails because the expired certificate was not removed from the database indexes.

This fix:
  - Adds code to update database indexes upon deletion of a row.
  - Polishes the certificate deletion code to avoid duplication.

This is a Measurement Factory project.
Bug fix: "(ssl_crtd): Cannot add certificate to db" when updating expired cert

When the ssl_crtd helper needs to add a fresh certificate to the database but
finds an expired certificate already stored, ssl_crtd deletes the expired
certificate file from disk before adding the fresh one. However, the addition
still fails because the expired certificate was not removed from the database
indexes.

This fix:
  - Adds code to update database indexes upon deletion of a row.
  - Polishes the certificate deletion code to avoid duplication.

TODO: Report failure details to Squid and make certificate-specific failures
not fatal for the ssl_crtd helper.

This is a Measurement Factory project.
=== modified file 'src/ssl/certificate_db.cc'
--- src/ssl/certificate_db.cc	2011-09-06 07:07:52 +0000
+++ src/ssl/certificate_db.cc	2011-09-15 13:02:05 +0000
@@ -410,118 +410,127 @@
     if (corrupt)
         throw std::runtime_error("The SSL certificate database " + db_path + " is corrupted. Please rebuild");
 
     db.reset(temp_db.release());
 }
 
 void Ssl::CertificateDb::save()
 {
     if (!db)
         throw std::runtime_error("The certificates database is not loaded");;
 
     // To save the db to file,  create a new BIO with BIO file methods.
     Ssl::BIO_Pointer out(BIO_new(BIO_s_file()));
     if (!out || !BIO_write_filename(out.get(), const_cast<char *>(db_full.c_str())))
         throw std::runtime_error("Failed to initialize " + db_full + " file for writing");;
 
     if (TXT_DB_write(out.get(), db.get()) < 0)
         throw std::runtime_error("Failed to write " + db_full + " file");
 }
 
+/// Removes the given row (stored at rowIndex in the TXT_DB data stack) from
+/// the in-memory database and from every index built over it, subtracts the
+/// size of the row's certificate file from the db size, and deletes that file.
+/// NOTE(review): the row memory itself is not freed here; confirm who owns it
+/// once it has left db->data.
+/// \throws std::runtime_error if the certificate file cannot be removed
+void Ssl::CertificateDb::deleteRow(const char **row, int rowIndex)
+{
+    const std::string filename(cert_full + "/" + row[cnlSerial] + ".pem");
+    const FileLocker cert_locker(filename);
+#if OPENSSL_VERSION_NUMBER >= 0x1000004fL
+    sk_OPENSSL_PSTRING_delete(db.get()->data, rowIndex);
+#else
+    sk_delete(db.get()->data, rowIndex);
+#endif
+
+    // Purge the row from every existing index, not just a hard-coded subset:
+    // a stale index entry makes a later addition of the same key fail.
+    for (int field = 0; field < db.get()->num_fields; ++field) {
+#if OPENSSL_VERSION_NUMBER >= 0x1000004fL
+        if (LHASH_OF(OPENSSL_STRING) *fieldIndex = db.get()->index[field])
+            lh_OPENSSL_STRING_delete(fieldIndex, const_cast<char **>(row));
+#else
+        if (LHASH *fieldIndex = db.get()->index[field])
+            lh_delete(fieldIndex, row);
+#endif
+    }
+
+    // subSize() needs the file size, so it must run before the file is gone.
+    subSize(filename);
+    if (remove(filename.c_str()) != 0) // ISO C: nonzero (not necessarily <0) on failure
+        throw std::runtime_error("Failed to remove certificate file " + filename + " from db");
+}
+
 bool Ssl::CertificateDb::deleteInvalidCertificate()
 {
     if (!db)
         return false;
 
     bool removed_one = false;
 #if OPENSSL_VERSION_NUMBER >= 0x1000004fL
     for (int i = 0; i < sk_OPENSSL_PSTRING_num(db.get()->data); i++) {
         const char ** current_row = ((const char **)sk_OPENSSL_PSTRING_value(db.get()->data, i));
 #else
     for (int i = 0; i < sk_num(db.get()->data); i++) {
         const char ** current_row = ((const char **)sk_value(db.get()->data, i));
 #endif
 
         if (!sslDateIsInTheFuture(current_row[cnlExp_date])) {
-            std::string filename(cert_full + "/" + current_row[cnlSerial] + ".pem");
-            FileLocker cert_locker(filename);
-#if OPENSSL_VERSION_NUMBER >= 0x1000004fL
-            sk_OPENSSL_PSTRING_delete(db.get()->data, i);
-#else
-            sk_delete(db.get()->data, i);
-#endif
-            subSize(filename);
-            remove(filename.c_str());
+            deleteRow(current_row, i);
             removed_one = true;
             break;
         }
     }
 
     if (!removed_one)
         return false;
     return true;
 }
 
 bool Ssl::CertificateDb::deleteOldestCertificate()
 {
     if (!db)
         return false;
 
 #if OPENSSL_VERSION_NUMBER >= 0x1000004fL
     if (sk_OPENSSL_PSTRING_num(db.get()->data) == 0)
 #else
     if (sk_num(db.get()->data) == 0)
 #endif
         return false;
 
 #if OPENSSL_VERSION_NUMBER >= 0x1000004fL
     const char **row = (const char **)sk_OPENSSL_PSTRING_value(db.get()->data, 0);
 #else
     const char **row = (const char **)sk_value(db.get()->data, 0);
 #endif
-    std::string filename(cert_full + "/" + row[cnlSerial] + ".pem");
-    FileLocker cert_locker(filename);
 
-#if OPENSSL_VERSION_NUMBER >= 0x1000004fL
-    sk_OPENSSL_PSTRING_delete(db.get()->data, 0);
-#else
-    sk_delete(db.get()->data, 0);
-#endif
-
-    subSize(filename);
-    remove(filename.c_str());
+    deleteRow(row, 0);
 
     return true;
 }
 
 bool Ssl::CertificateDb::deleteByHostname(std::string const & host)
 {
     if (!db)
         return false;
 
 #if OPENSSL_VERSION_NUMBER >= 0x1000004fL
     for (int i = 0; i < sk_OPENSSL_PSTRING_num(db.get()->data); i++) {
         const char ** current_row = ((const char **)sk_OPENSSL_PSTRING_value(db.get()->data, i));
 #else
     for (int i = 0; i < sk_num(db.get()->data); i++) {
         const char ** current_row = ((const char **)sk_value(db.get()->data, i));
 #endif
         if (host == current_row[cnlName]) {
-            std::string filename(cert_full + "/" + current_row[cnlSerial] + ".pem");
-            FileLocker cert_locker(filename);
-#if OPENSSL_VERSION_NUMBER >= 0x1000004fL
-            sk_OPENSSL_PSTRING_delete(db.get()->data, i);
-#else
-            sk_delete(db.get()->data, i);
-#endif
-            subSize(filename);
-            remove(filename.c_str());
+            deleteRow(current_row, i);
             return true;
         }
     }
     return false;
 }
 
 bool Ssl::CertificateDb::IsEnabledDiskStore() const
 {
     return enabled_disk_store;
 }

=== modified file 'src/ssl/certificate_db.h'
--- src/ssl/certificate_db.h	2011-07-27 00:12:59 +0000
+++ src/ssl/certificate_db.h	2011-09-15 12:59:35 +0000
@@ -81,40 +81,41 @@
     /// Create and initialize a database  under the  db_path
     static void create(std::string const & db_path, int serial);
     /// Check the database stored under the db_path.
     static void check(std::string const & db_path, size_t max_db_size);
     std::string getSNString() const; ///< Get serial number as string.
     bool IsEnabledDiskStore() const; ///< Check enabled of dist store.
 private:
     void load(); ///< Load db from disk.
     void save(); ///< Save db to disk.
     size_t size() const; ///< Get db size on disk in bytes.
     /// Increase db size by the given file size and update size_file
     void addSize(std::string const & filename);
     /// Decrease db size by the given file size and update size_file
     void subSize(std::string const & filename);
     size_t readSize() const; ///< Read size from file size_file
     void writeSize(size_t db_size); ///< Write size to file size_file.
     size_t getFileSize(std::string const & filename); ///< get file size on disk.
     /// Only find certificate in current db and return it.
     bool pure_find(std::string const & host_name, Ssl::X509_Pointer & cert, Ssl::EVP_PKEY_Pointer & pkey);
 
+    void deleteRow(const char **row, int rowIndex); ///< Remove row (found at rowIndex) from the TXT_DB data and indexes, then delete its cert file
     bool deleteInvalidCertificate(); ///< Delete invalid certificate.
     bool deleteOldestCertificate(); ///< Delete oldest certificate.
     bool deleteByHostname(std::string const & host); ///< Delete using host name.
 
     /// Callback hash function for serials. Used to create TXT_DB index of serials.
     static unsigned long index_serial_hash(const char **a);
     /// Callback compare function for serials. Used to create TXT_DB index of serials.
     static int index_serial_cmp(const char **a, const char **b);
     /// Callback hash function for names. Used to create TXT_DB index of names..
     static unsigned long index_name_hash(const char **a);
     /// Callback compare function for  names. Used to create TXT_DB index of names..
     static int index_name_cmp(const char **a, const char **b);
 
     /// Definitions required by openSSL, to use the index_* functions defined above
     ///with TXT_DB_create_index.
 #if OPENSSL_VERSION_NUMBER > 0x10000000L
     static unsigned long index_serial_LHASH_HASH(const void *a) {
         return index_serial_hash((const char **)a);
     }
     static int index_serial_LHASH_COMP(const void *arg1, const void *arg2) {

Reply via email to