Fix CertificateDB locking scheme
Currently, CertificateDB code locks every file it accesses, even when
this is not really needed because a more general lock already
protects that access.
This patch:
- Replaces the old FileLocker class with a pair of classes, Lock and Locker.
- Replaces most of the locks in CertificateDB with only two locks: one
for the main database and one for the file containing the
current serial number.
This is a Measurement Factory project
Fix CertificateDB locking scheme
Currently, CertificateDB code locks every file it accesses, even when this is
not really needed because a more general lock already protects that access.
This patch:
- Replaces the old FileLocker class with a pair of classes, Lock and Locker.
- Replaces most of the locks in CertificateDB with only two locks: one
for the main database and one for the file containing the
current serial number.
This is a Measurement Factory project
=== modified file 'src/ssl/certificate_db.cc'
--- src/ssl/certificate_db.cc 2011-09-15 16:34:52 +0000
+++ src/ssl/certificate_db.cc 2011-09-22 09:25:48 +0000
@@ -1,69 +1,125 @@
/*
* $Id$
*/
#include "config.h"
#include "ssl/certificate_db.h"
+#if HAVE_ERRNO_H
+#include <errno.h>
+#endif
#if HAVE_FSTREAM
#include <fstream>
#endif
#if HAVE_STDEXCEPT
#include <stdexcept>
#endif
#if HAVE_SYS_STAT_H
#include <sys/stat.h>
#endif
#if HAVE_SYS_FILE_H
#include <sys/file.h>
#endif
#if HAVE_FCNTL_H
#include <fcntl.h>
#endif
-Ssl::FileLocker::FileLocker(std::string const & filename)
- : fd(-1)
+#define HERE "(ssl_crtd) " << __FILE__ << ':' << __LINE__ << ": "
+
+Ssl::Lock::Lock(std::string const &aFilename) :
+ filename(aFilename),
+#if _SQUID_MSWIN_
+ hFile(INVALID_HANDLE_VALUE)
+#else
+ fd(-1)
+#endif
+{
+}
+
+bool Ssl::Lock::locked() const
{
#if _SQUID_MSWIN_
+ return hFile != INVALID_HANDLE_VALUE;
+#else
+ return fd != -1;
+#endif
+}
+
+void Ssl::Lock::lock()
+{
+    // Re-locking would leak the current descriptor/handle and, with flock(),
+    // block forever on our own lock (a second open() is an independent open
+    // file description); treat it as a caller bug instead.
+    if (locked())
+        throw std::runtime_error("Lock is already locked for " + filename);
+
#if _SQUID_MSWIN_
hFile = CreateFile(TEXT(filename.c_str()), GENERIC_READ, 0, NULL, OPEN_EXISTING, FILE_ATTRIBUTE_NORMAL, NULL);
- if (hFile != INVALID_HANDLE_VALUE)
- LockFile(hFile, 0, 0, 1, 0);
+    if (hFile == INVALID_HANDLE_VALUE)
#else
fd = open(filename.c_str(), 0);
- if (fd != -1)
- flock(fd, LOCK_EX);
+    if (fd == -1)
#endif
+        throw std::runtime_error("Failed to open file " + filename);
+
+    // If locking fails, release the descriptor/handle we just opened so that
+    // locked() stays false and ~Lock() does not "unlock" a lock we never got.
+#if _SQUID_MSWIN_
+    if (!LockFile(hFile, 0, 0, 1, 0)) {
+        CloseHandle(hFile);
+        hFile = INVALID_HANDLE_VALUE;
+#else
+    if (flock(fd, LOCK_EX) != 0) {
+        close(fd);
+        fd = -1;
+#endif
+        throw std::runtime_error("Failed to get a lock of " + filename);
+    }
}
-Ssl::FileLocker::~FileLocker()
-{
+void Ssl::Lock::unlock()
+{
#if _SQUID_MSWIN_
if (hFile != INVALID_HANDLE_VALUE) {
UnlockFile(hFile, 0, 0, 1, 0);
CloseHandle(hFile);
+ hFile = INVALID_HANDLE_VALUE;
}
#else
if (fd != -1) {
flock(fd, LOCK_UN);
close(fd);
+ fd = -1;
}
#endif
+ else
+ throw std::runtime_error("Lock is already unlocked for " + filename);
+}
+
+Ssl::Lock::~Lock()
+{
+ if (locked())
+ unlock();
+}
+
+Ssl::Locker::Locker(Lock &aLock, const char *aFileName, int aLineNo):
+ weLocked(false), lock(aLock), fileName(aFileName), lineNo(aLineNo)
+{
+ if (!lock.locked()) {
+ lock.lock();
+ weLocked = true;
+ }
+}
+
+Ssl::Locker::~Locker()
+{
+    // Lock::unlock() throws when the lock is no longer held (e.g., it was
+    // released behind our back). A destructor must not throw, especially
+    // while the stack is already being unwound by another exception.
+    if (weLocked && lock.locked())
+        lock.unlock();
}
Ssl::CertificateDb::Row::Row()
: width(cnlNumber)
{
row = new char *[width + 1];
for (size_t i = 0; i < width + 1; i++)
row[i] = NULL;
}
Ssl::CertificateDb::Row::~Row()
{
if (row) {
for (size_t i = 0; i < width + 1; i++) {
delete[](row[i]);
}
delete[](row);
}
}
@@ -113,60 +169,60 @@
int Ssl::CertificateDb::index_name_cmp(const char **a, const char **b)
{
return(strcmp(a[Ssl::CertificateDb::cnlName], b[CertificateDb::cnlName]));
}
const std::string Ssl::CertificateDb::serial_file("serial");
const std::string Ssl::CertificateDb::db_file("index.txt");
const std::string Ssl::CertificateDb::cert_dir("certs");
const std::string Ssl::CertificateDb::size_file("size");
const size_t Ssl::CertificateDb::min_db_size(4096);
Ssl::CertificateDb::CertificateDb(std::string const & aDb_path, size_t aMax_db_size, size_t aFs_block_size)
: db_path(aDb_path),
serial_full(aDb_path + "/" + serial_file),
db_full(aDb_path + "/" + db_file),
cert_full(aDb_path + "/" + cert_dir),
size_full(aDb_path + "/" + size_file),
db(NULL),
max_db_size(aMax_db_size),
fs_block_size(aFs_block_size),
+ dbLock(db_full),
+ dbSerialLock(serial_full),
enabled_disk_store(true)
{
if (db_path.empty() && !max_db_size)
enabled_disk_store = false;
else if ((db_path.empty() && max_db_size) || (!db_path.empty() && !max_db_size))
throw std::runtime_error("ssl_crtd is missing the required parameter. There should be -s and -M parameters together.");
- else
- load();
}
bool Ssl::CertificateDb::find(std::string const & host_name, Ssl::X509_Pointer & cert, Ssl::EVP_PKEY_Pointer & pkey)
{
- FileLocker db_locker(db_full);
+ const Locker locker(dbLock, Here);
load();
return pure_find(host_name, cert, pkey);
}
bool Ssl::CertificateDb::addCertAndPrivateKey(Ssl::X509_Pointer & cert, Ssl::EVP_PKEY_Pointer & pkey)
{
- FileLocker db_locker(db_full);
+ const Locker locker(dbLock, Here);
load();
if (!db || !cert || !pkey || min_db_size > max_db_size)
return false;
Row row;
ASN1_INTEGER * ai = X509_get_serialNumber(cert.get());
std::string serial_string;
Ssl::BIGNUM_Pointer serial(ASN1_INTEGER_to_BN(ai, NULL));
{
TidyPointer<char, tidyFree> hex_bn(BN_bn2hex(serial.get()));
serial_string = std::string(hex_bn.get());
}
row.setValue(cnlSerial, serial_string.c_str());
char ** rrow = TXT_DB_get_by_index(db.get(), cnlSerial, row.getRow());
if (rrow != NULL)
return false;
{
TidyPointer<char, tidyFree> subject(X509_NAME_oneline(X509_get_subject_name(cert.get()), NULL, 0));
if (pure_find(subject.get(), cert, pkey))
return true;
@@ -178,52 +234,51 @@
}
while (max_db_size < size()) {
deleteOldestCertificate();
}
row.setValue(cnlType, "V");
ASN1_UTCTIME * tm = X509_get_notAfter(cert.get());
row.setValue(cnlExp_date, std::string(reinterpret_cast<char *>(tm->data), tm->length).c_str());
row.setValue(cnlFile, "unknown");
{
TidyPointer<char, tidyFree> subject(X509_NAME_oneline(X509_get_subject_name(cert.get()), NULL, 0));
row.setValue(cnlName, subject.get());
}
if (!TXT_DB_insert(db.get(), row.getRow()))
return false;
row.reset();
std::string filename(cert_full + "/" + serial_string + ".pem");
- FileLocker cert_locker(filename);
if (!writeCertAndPrivateKeyToFile(cert, pkey, filename.c_str()))
return false;
addSize(filename);
save();
return true;
}
BIGNUM * Ssl::CertificateDb::getCurrentSerialNumber()
{
- FileLocker serial_locker(serial_full);
+ const Locker locker(dbSerialLock, Here);
// load serial number from file.
Ssl::BIO_Pointer file(BIO_new(BIO_s_file()));
if (!file)
return NULL;
if (BIO_rw_filename(file.get(), const_cast<char *>(serial_full.c_str())) <= 0)
return NULL;
Ssl::ASN1_INT_Pointer serial_ai(ASN1_INTEGER_new());
if (!serial_ai)
return NULL;
char buffer[1024];
if (!a2i_ASN1_INTEGER(file.get(), serial_ai.get(), buffer, sizeof(buffer)))
return NULL;
Ssl::BIGNUM_Pointer serial(ASN1_INTEGER_to_BN(serial_ai.get(), NULL));
if (!serial)
return NULL;
@@ -280,94 +335,91 @@
throw std::runtime_error("SSL error");
if (BIO_write_filename(file.get(), const_cast<char *>(serial_full.c_str())) <= 0)
throw std::runtime_error("Cannot open " + cert_full + " to open");
i2a_ASN1_INTEGER(file.get(), i.get());
std::ofstream size(size_full.c_str());
if (size)
size << 0;
else
throw std::runtime_error("Cannot open " + size_full + " to open");
std::ofstream db(db_full.c_str());
if (!db)
throw std::runtime_error("Cannot open " + db_full + " to open");
}
void Ssl::CertificateDb::check(std::string const & db_path, size_t max_db_size)
{
CertificateDb db(db_path, max_db_size, 0);
+ db.load();
}
std::string Ssl::CertificateDb::getSNString() const
{
- FileLocker serial_locker(serial_full);
+ const Locker locker(dbSerialLock, Here);
std::ifstream file(serial_full.c_str());
if (!file)
return "";
std::string serial;
file >> serial;
return serial;
}
bool Ssl::CertificateDb::pure_find(std::string const & host_name, Ssl::X509_Pointer & cert, Ssl::EVP_PKEY_Pointer & pkey)
{
if (!db)
return false;
Row row;
row.setValue(cnlName, host_name.c_str());
char **rrow = TXT_DB_get_by_index(db.get(), cnlName, row.getRow());
if (rrow == NULL)
return false;
if (!sslDateIsInTheFuture(rrow[cnlExp_date])) {
deleteByHostname(rrow[cnlName]);
return false;
}
// read cert and pkey from file.
std::string filename(cert_full + "/" + rrow[cnlSerial] + ".pem");
- FileLocker cert_locker(filename);
readCertAndPrivateKeyFromFiles(cert, pkey, filename.c_str(), NULL);
if (!cert || !pkey)
return false;
return true;
}
size_t Ssl::CertificateDb::size() const
{
- FileLocker size_locker(size_full);
return readSize();
}
void Ssl::CertificateDb::addSize(std::string const & filename)
{
- FileLocker size_locker(size_full);
writeSize(readSize() + getFileSize(filename));
}
void Ssl::CertificateDb::subSize(std::string const & filename)
{
- FileLocker size_locker(size_full);
writeSize(readSize() - getFileSize(filename));
}
size_t Ssl::CertificateDb::readSize() const
{
size_t db_size;
std::ifstream size_file(size_full.c_str());
if (!size_file && enabled_disk_store)
throw std::runtime_error("cannot read \"" + size_full + "\" file");
size_file >> db_size;
return db_size;
}
void Ssl::CertificateDb::writeSize(size_t db_size)
{
std::ofstream size_file(size_full.c_str());
if (!size_file && enabled_disk_store)
throw std::runtime_error("cannot write \"" + size_full + "\" file");
size_file << db_size;
}
@@ -415,41 +467,40 @@
void Ssl::CertificateDb::save()
{
if (!db)
throw std::runtime_error("The certificates database is not loaded");;
// To save the db to file, create a new BIO with BIO file methods.
Ssl::BIO_Pointer out(BIO_new(BIO_s_file()));
if (!out || !BIO_write_filename(out.get(), const_cast<char *>(db_full.c_str())))
throw std::runtime_error("Failed to initialize " + db_full + " file for writing");;
if (TXT_DB_write(out.get(), db.get()) < 0)
throw std::runtime_error("Failed to write " + db_full + " file");
}
// Normally defined in defines.h file
#define countof(arr) (sizeof(arr)/sizeof(*arr))
void Ssl::CertificateDb::deleteRow(const char **row, int rowIndex)
{
const std::string filename(cert_full + "/" + row[cnlSerial] + ".pem");
- const FileLocker cert_locker(filename);
#if OPENSSL_VERSION_NUMBER >= 0x1000004fL
sk_OPENSSL_PSTRING_delete(db.get()->data, rowIndex);
#else
sk_delete(db.get()->data, rowIndex);
#endif
const Columns db_indexes[]={cnlSerial, cnlName};
for (unsigned int i = 0; i < countof(db_indexes); i++) {
#if OPENSSL_VERSION_NUMBER >= 0x1000004fL
if (LHASH_OF(OPENSSL_STRING) *fieldIndex = db.get()->index[db_indexes[i]])
lh_OPENSSL_STRING_delete(fieldIndex, (char **)row);
#else
if (LHASH *fieldIndex = db.get()->index[db_indexes[i]])
lh_delete(fieldIndex, row);
#endif
}
subSize(filename);
int ret = remove(filename.c_str());
if (ret < 0)
=== modified file 'src/ssl/certificate_db.h'
--- src/ssl/certificate_db.h 2011-09-15 16:34:52 +0000
+++ src/ssl/certificate_db.h 2011-09-22 09:27:14 +0000
@@ -1,54 +1,74 @@
/*
* $Id$
*/
#ifndef SQUID_SSL_CERTIFICATE_DB_H
#define SQUID_SSL_CERTIFICATE_DB_H
#include "ssl/gadgets.h"
#include "ssl/support.h"
#if HAVE_STRING
#include <string>
#endif
#if HAVE_OPENSSL_OPENSSLV_H
#include <openssl/opensslv.h>
#endif
namespace Ssl
{
-/// Cross platform file locker.
-class FileLocker
-{
+/// maintains an exclusive blocking file-based lock
+class Lock {
public:
- /// Lock file
- FileLocker(std::string const & aFilename);
- /// Unlock file
- ~FileLocker();
+    explicit Lock(std::string const &filename); ///< creates an unlocked lock
+    ~Lock(); ///< releases the lock if it is locked
+    void lock(); ///< locks the lock, may block; throws if already locked
+    void unlock(); ///< unlocks locked lock or throws
+    bool locked() const; ///< whether our lock is locked
+    const char *name() const { return filename.c_str(); }
private:
+    /// not copyable: copies would share, and then twice unlock/close, one descriptor
+    Lock(const Lock &); // not implemented
+    Lock &operator =(const Lock &); // not implemented
+
+    std::string filename;
#if _SQUID_MSWIN_
HANDLE hFile; ///< Windows file handle.
#else
int fd; ///< Linux file descriptor.
#endif
};
+/// an exception-safe way to obtain and release a lock
+class Locker
+{
+public:
+    /// locks the lock if the lock was unlocked
+    Locker(Lock &aLock, const char *aFileName, int lineNo);
+    /// unlocks the lock if it was locked by us
+    ~Locker();
+private:
+    /// not copyable: a copy would unlock the same lock a second time
+    Locker(const Locker &); // not implemented
+    Locker &operator =(const Locker &); // not implemented
+
+    bool weLocked; ///< whether we locked the lock
+    Lock &lock; ///< the lock we are operating on
+    const std::string fileName; ///< where the lock was needed
+    const int lineNo; ///< where the lock was needed
+};
+
+/// convenience macro to pass source code location to Locker and others
+#define Here __FILE__, __LINE__
+
/**
* Database class for storing SSL certificates and their private keys.
* A database consist by:
* - A disk file to store current serial number
* - A disk file to store the current database size
* - A disk file which is a normal TXT_DB openSSL database
* - A directory under which the certificates and their private keys stored.
* The database before used must initialized with CertificateDb::create static method.
*/
class CertificateDb
{
public:
/// Names of db columns.
enum Columns {
cnlType = 0,
cnlExp_date,
cnlRev_date,
cnlSerial,
cnlFile,
cnlName,
@@ -133,26 +153,28 @@
static IMPLEMENT_LHASH_HASH_FN(index_name_hash,const char **)
static IMPLEMENT_LHASH_COMP_FN(index_name_cmp,const char **)
#endif
static const std::string serial_file; ///< Base name of the file to store serial number.
static const std::string db_file; ///< Base name of the database index file.
static const std::string cert_dir; ///< Base name of the directory to store the certs.
static const std::string size_file; ///< Base name of the file to store db size.
/// Min size of disk db. If real size < min_db_size the db will be disabled.
static const size_t min_db_size;
const std::string db_path; ///< The database directory.
const std::string serial_full; ///< Full path of the file to store serial number.
const std::string db_full; ///< Full path of the database index file.
const std::string cert_full; ///< Full path of the directory to store the certs.
const std::string size_full; ///< Full path of the file to store the db size.
TXT_DB_Pointer db; ///< Database with certificates info.
const size_t max_db_size; ///< Max size of db.
const size_t fs_block_size; ///< File system block size.
+ mutable Lock dbLock; ///< protects the database file
+ mutable Lock dbSerialLock; ///< protects the serial number file
bool enabled_disk_store; ///< The storage on the disk is enabled.
};
} // namespace Ssl
#endif // SQUID_SSL_CERTIFICATE_DB_H