This is an automated email from the ASF dual-hosted git repository.

zwoop pushed a commit to branch 9.1.x
in repository https://gitbox.apache.org/repos/asf/trafficserver.git

commit de8c12567dc3efbc451a2f3eb58f25eba9ebf47c
Author: Fei Deng <[email protected]>
AuthorDate: Wed Feb 3 01:50:19 2021 -0600

    use std::unordered_map to store sessions (#7405)
    
    improve the hash function for session IDs
    
    use std::shared_mutex so that session lookups take only a shared lock
    
    disable the session cache for TLSv1.3 sessions since it is not needed
    
    use std::map to order sessions by insertion time
---
 iocore/net/SSLSessionCache.cc | 175 +++++++++++++++++++-----------------------
 iocore/net/SSLSessionCache.h  |  67 ++++++++++++----
 iocore/net/SSLUtils.cc        |  12 +++
 3 files changed, 141 insertions(+), 113 deletions(-)

diff --git a/iocore/net/SSLSessionCache.cc b/iocore/net/SSLSessionCache.cc
index 17fb174..60ec19e 100644
--- a/iocore/net/SSLSessionCache.cc
+++ b/iocore/net/SSLSessionCache.cc
@@ -129,51 +129,62 @@ SSLSessionBucket::insertSession(const SSLSessionID &id, SSL_SESSION *sess, SSL *
     Debug("ssl.session_cache", "Inserting session '%s' to bucket %p.", buf, this);
   }
 
-  MUTEX_TRY_LOCK(lock, mutex, this_ethread());
-  if (!lock.is_locked()) {
+  Ptr<IOBufferData> buf;
+  Ptr<IOBufferData> buf_exdata;
+  size_t len_exdata = sizeof(ssl_session_cache_exdata);
+  buf               = new_IOBufferData(buffer_size_to_index(len, MAX_BUFFER_SIZE_INDEX), MEMALIGNED);
+  ink_release_assert(static_cast<size_t>(buf->block_size()) >= len);
+  unsigned char *loc = reinterpret_cast<unsigned char *>(buf->data());
+  i2d_SSL_SESSION(sess, &loc);
+  // buf_exdata only has to hold an ssl_session_cache_exdata, so size it by len_exdata (not the ASN.1 length).
+  buf_exdata = new_IOBufferData(buffer_size_to_index(len_exdata, MAX_BUFFER_SIZE_INDEX), MEMALIGNED);
+  ink_release_assert(static_cast<size_t>(buf_exdata->block_size()) >= len_exdata);
+  ssl_session_cache_exdata *exdata = reinterpret_cast<ssl_session_cache_exdata *>(buf_exdata->data());
+  // This could be moved to a function in charge of populating exdata
+  exdata->curve  = (ssl == nullptr) ? 0 : SSLGetCurveNID(ssl);
+  ink_hrtime now = Thread::get_hrtime_updated();
+
+  // Serialize and build the session before locking to keep the critical section short.
+  ats_scoped_obj<SSLSession> ssl_session(new SSLSession(id, buf, len, buf_exdata, now));
+
+  std::unique_lock lock(mutex, std::try_to_lock);
+  if (!lock.owns_lock()) {
     if (ssl_rsb) {
       SSL_INCREMENT_DYN_STAT(ssl_session_cache_lock_contention);
     }
     if (SSLConfigParams::session_cache_skip_on_lock_contention) {
       return;
     }
-    lock.acquire(this_ethread());
+    lock.lock();
   }
 
   PRINT_BUCKET("insertSession before")
-  if (queue.size >= static_cast<int>(SSLConfigParams::session_cache_max_bucket_size)) {
+
+  // Don't insert if it is already there. Check this before evicting so a live
+  // session is not thrown away when nothing new will be stored.
+  if (bucket_data.find(id) != bucket_data.end()) {
+    return;
+  }
+
+  if (bucket_data.size() >= SSLConfigParams::session_cache_max_bucket_size) {
     if (ssl_rsb) {
       SSL_INCREMENT_DYN_STAT(ssl_session_cache_eviction);
     }
-    removeOldestSession();
+    removeOldestSession(lock);
   }
 
-  // Don't insert if it is already there
-  SSLSession *node = queue.tail;
-  while (node) {
-    if (node->session_id == id) {
-      return;
-    }
-    node = node->link.prev;
-  }
-
-  Ptr<IOBufferData> buf;
-  Ptr<IOBufferData> buf_exdata;
-  size_t len_exdata = sizeof(ssl_session_cache_exdata);
-  buf               = new_IOBufferData(buffer_size_to_index(len, MAX_BUFFER_SIZE_INDEX), MEMALIGNED);
-  ink_release_assert(static_cast<size_t>(buf->block_size()) >= len);
-  unsigned char *loc = reinterpret_cast<unsigned char *>(buf->data());
-  i2d_SSL_SESSION(sess, &loc);
-  buf_exdata = new_IOBufferData(buffer_size_to_index(len, MAX_BUFFER_SIZE_INDEX), MEMALIGNED);
-  ink_release_assert(static_cast<size_t>(buf_exdata->block_size()) >= len_exdata);
-  ssl_session_cache_exdata *exdata = reinterpret_cast<ssl_session_cache_exdata *>(buf_exdata->data());
-  // This could be moved to a function in charge of populating exdata
-  exdata->curve = (ssl == nullptr) ? 0 : SSLGetCurveNID(ssl);
-
-  ats_scoped_obj<SSLSession> ssl_session(new SSLSession(id, buf, len, buf_exdata));
-
+  // bucket_data_ts is keyed by timestamp, so two sessions stamped in the same
+  // tick would collide: the older one would stay in bucket_data without a
+  // timestamp entry and could never be evicted. Bump the key until it is unique.
+  while (bucket_data_ts.find(now) != bucket_data_ts.end()) {
+    ++now;
+  }
+
   /* do the actual insert */
-  queue.enqueue(ssl_session.release());
+  auto node        = ssl_session.release();
+  node->time_stamp = now;
+  bucket_data.emplace(id, node);
+  bucket_data_ts.emplace(now, node);
 
   PRINT_BUCKET("insertSession after")
 }
@@ -182,33 +193,26 @@ int
 SSLSessionBucket::getSessionBuffer(const SSLSessionID &id, char *buffer, int &len)
 {
   int true_len = 0;
-  MUTEX_TRY_LOCK(lock, mutex, this_ethread());
-  if (!lock.is_locked()) {
+  std::shared_lock lock(mutex, std::try_to_lock);
+  if (!lock.owns_lock()) {
     if (ssl_rsb) {
       SSL_INCREMENT_DYN_STAT(ssl_session_cache_lock_contention);
     }
     if (SSLConfigParams::session_cache_skip_on_lock_contention) {
       return true_len;
     }
-
-    lock.acquire(this_ethread());
+    lock.lock();
   }
 
-  // We work backwards because that's the most likely place we'll find our session...
-  SSLSession *node = queue.tail;
-  while (node) {
-    if (node->session_id == id) {
-      true_len = node->len_asn1_data;
-      if (buffer) {
-        const unsigned char *loc = reinterpret_cast<const unsigned char *>(node->asn1_data->data());
-        if (true_len < len) {
-          len = true_len;
-        }
-        memcpy(buffer, loc, len);
-        return true_len;
-      }
+  auto node = bucket_data.find(id);
+  if (buffer && node != bucket_data.end()) {
+    true_len                 = node->second->len_asn1_data;
+    const unsigned char *loc = reinterpret_cast<const unsigned char *>(node->second->asn1_data->data());
+    if (true_len < len) {
+      len = true_len;
     }
-    node = node->link.prev;
+    memcpy(buffer, loc, len);
+    return true_len;
   }
   return 0;
 }
@@ -224,38 +228,31 @@ SSLSessionBucket::getSession(const SSLSessionID &id, SSL_SESSION **sess, ssl_ses
 
   Debug("ssl.session_cache", "Looking for session with id '%s' in bucket %p", buf, this);
 
-  MUTEX_TRY_LOCK(lock, mutex, this_ethread());
-  if (!lock.is_locked()) {
+  std::shared_lock lock(mutex, std::try_to_lock);
+  if (!lock.owns_lock()) {
     if (ssl_rsb) {
       SSL_INCREMENT_DYN_STAT(ssl_session_cache_lock_contention);
     }
     if (SSLConfigParams::session_cache_skip_on_lock_contention) {
       return false;
     }
-
-    lock.acquire(this_ethread());
+    lock.lock();
   }
 
   PRINT_BUCKET("getSession")
 
-  // We work backwards because that's the most likely place we'll find our session...
-  SSLSession *node = queue.tail;
-  while (node) {
-    if (node->session_id == id) {
-      const unsigned char *loc = reinterpret_cast<const unsigned char *>(node->asn1_data->data());
-      *sess                    = d2i_SSL_SESSION(nullptr, &loc, node->len_asn1_data);
-      if (data != nullptr) {
-        ssl_session_cache_exdata *exdata = reinterpret_cast<ssl_session_cache_exdata *>(node->extra_data->data());
-        *data                            = exdata;
-      }
-
-      return true;
-    }
-    node = node->link.prev;
+  auto node = bucket_data.find(id);
+  if (node == bucket_data.end()) {
+    Debug("ssl.session_cache", "Session with id '%s' not found in bucket %p.", buf, this);
+    return false;
   }
-
-  Debug("ssl.session_cache", "Session with id '%s' not found in bucket %p.", buf, this);
-  return false;
+  const unsigned char *loc = reinterpret_cast<const unsigned char *>(node->second->asn1_data->data());
+  *sess                    = d2i_SSL_SESSION(nullptr, &loc, node->second->len_asn1_data);
+  if (data != nullptr) {
+    ssl_session_cache_exdata *exdata = reinterpret_cast<ssl_session_cache_exdata *>(node->second->extra_data->data());
+    *data                            = exdata;
+  }
+  return true;
 }
 
 void inline SSLSessionBucket::print(const char *ref_str) const
@@ -266,53 +263,77 @@ void inline SSLSessionBucket::print(const char *ref_str) const
   }
 
   fprintf(stderr, "-------------- BUCKET %p (%s) ----------------\n", this, ref_str);
-  fprintf(stderr, "Current Size: %d, Max Size: %zd\n", queue.size, SSLConfigParams::session_cache_max_bucket_size);
-  fprintf(stderr, "Queue: \n");
+  fprintf(stderr, "Current Size: %zu, Max Size: %zd\n", bucket_data.size(), SSLConfigParams::session_cache_max_bucket_size);
+  fprintf(stderr, "Bucket: \n");
 
-  SSLSession *node = queue.head;
-  while (node) {
-    char s_buf[2 * node->session_id.len + 1];
-    node->session_id.toString(s_buf, sizeof(s_buf));
+  for (auto &x : bucket_data) {
+    char s_buf[2 * x.second->session_id.len + 1];
+    x.second->session_id.toString(s_buf, sizeof(s_buf));
     fprintf(stderr, "  %s\n", s_buf);
-    node = node->link.next;
   }
 }
 
-void inline SSLSessionBucket::removeOldestSession()
+void inline SSLSessionBucket::removeOldestSession(const std::unique_lock<std::shared_mutex> &lock)
 {
-  // Caller must hold the bucket lock.
-  ink_assert(this_ethread() == mutex->thread_holding);
+  // Caller must hold the bucket shared_mutex with unique_lock.
+  ink_assert(lock.owns_lock());
 
   PRINT_BUCKET("removeOldestSession before")
-  while (queue.head && queue.size >= static_cast<int>(SSLConfigParams::session_cache_max_bucket_size)) {
-    SSLSession *old_head = queue.pop();
-    if (is_debug_tag_set("ssl.session_cache")) {
-      char buf[old_head->session_id.len * 2 + 1];
-      old_head->session_id.toString(buf, sizeof(buf));
-      Debug("ssl.session_cache", "Removing session '%s' from bucket %p because the bucket has size %d and max %zd", buf, this,
-            (queue.size + 1), SSLConfigParams::session_cache_max_bucket_size);
-    }
-    delete old_head;
-  }
+
+  // begin() may only be dereferenced on a non-empty map; the bucket can be
+  // empty here, e.g. when session_cache_max_bucket_size is 0.
+  if (!bucket_data_ts.empty()) {
+    auto oldest        = bucket_data_ts.begin();
+    SSLSession *victim = oldest->second;
+    if (is_debug_tag_set("ssl.session_cache")) {
+      char buf[victim->session_id.len * 2 + 1];
+      victim->session_id.toString(buf, sizeof(buf));
+      Debug("ssl.session_cache", "Removing session '%s' from bucket %p because the bucket has size %zu and max %zd", buf, this,
+            bucket_data.size(), SSLConfigParams::session_cache_max_bucket_size);
+    }
+    bucket_data.erase(victim->session_id);
+    bucket_data_ts.erase(oldest);
+    // The maps store raw pointers; the evicted session must be freed here or it leaks.
+    delete victim;
+  }
+
   PRINT_BUCKET("removeOldestSession after")
 }
 
 void
 SSLSessionBucket::removeSession(const SSLSessionID &id)
 {
-  SCOPED_MUTEX_LOCK(lock, mutex, this_ethread()); // We can't bail on contention here because this session MUST be removed.
-  SSLSession *node = queue.head;
-  while (node) {
-    if (node->session_id == id) {
-      queue.remove(node);
-      delete node;
-      return;
-    }
-    node = node->link.next;
+  // We can't bail on contention here because this session MUST be removed.
+  std::unique_lock lock(mutex);
+
+  auto node = bucket_data.find(id);
+
+  PRINT_BUCKET("removeSession before")
+
+  if (node != bucket_data.end()) {
+    SSLSession *session = node->second;
+    // Only drop the timestamp entry if it really refers to this session.
+    auto ts = bucket_data_ts.find(session->time_stamp);
+    if (ts != bucket_data_ts.end() && ts->second == session) {
+      bucket_data_ts.erase(ts);
+    }
+    bucket_data.erase(node);
+    delete session;
   }
+
+  PRINT_BUCKET("removeSession after")
+
+  return;
 }
 
 /* Session Bucket */
-SSLSessionBucket::SSLSessionBucket() : mutex(new_ProxyMutex()) {}
+SSLSessionBucket::SSLSessionBucket() {}
 
-SSLSessionBucket::~SSLSessionBucket() {}
+SSLSessionBucket::~SSLSessionBucket()
+{
+  // Every cached session is in bucket_data (bucket_data_ts only aliases the same
+  // pointers), so free them once from there.
+  for (auto &entry : bucket_data) {
+    delete entry.second;
+  }
+}
diff --git a/iocore/net/SSLSessionCache.h b/iocore/net/SSLSessionCache.h
index 44ba12d..05a5930 100644
--- a/iocore/net/SSLSessionCache.h
+++ b/iocore/net/SSLSessionCache.h
@@ -29,6 +29,10 @@
 #include "P_SSLUtils.h"
 #include "ts/apidefs.h"
 #include <openssl/ssl.h>
+#include <map>
+#include <mutex>
+#include <shared_mutex>
+#include <unordered_map>
 
 #define SSL_MAX_SESSION_SIZE 256
 
@@ -36,12 +40,21 @@ struct ssl_session_cache_exdata {
   ssl_curve_id curve = 0;
 };
 
+inline void
+hash_combine(uint64_t &seed, uint64_t hash)
+{
+  // using boost's version of hash combine, substituting magic number with a 64bit version
+  // https://www.boost.org/doc/libs/1_43_0/doc/html/hash/reference.html#boost.hash_combine
+  seed ^= hash + 0x9E3779B97F4A7C15 + (seed << 6) + (seed >> 2);
+}
+
 struct SSLSessionID : public TSSslSessionID {
   SSLSessionID(const unsigned char *s, size_t l)
   {
     len = l;
     ink_release_assert(l <= sizeof(bytes));
     memcpy(bytes, s, l);
+    hash();
   }
 
   SSLSessionID(const SSLSessionID &other)
@@ -50,6 +63,7 @@ struct SSLSessionID : public TSSslSessionID {
       memcpy(bytes, other.bytes, other.len);
 
     len = other.len;
+    hash();
   }
 
   bool
@@ -101,15 +115,36 @@ struct SSLSessionID : public TSSslSessionID {
   uint64_t
   hash() const
   {
-    // because the session ids should be uniformly random let's just use the last 64 bits as the hash.
-    // The first bytes could be interpreted as a name, and so not random.
-    if (len >= sizeof(uint64_t)) {
-      return *reinterpret_cast<uint64_t *>(const_cast<char *>(bytes + len - sizeof(uint64_t)));
-    } else if (len) {
-      return static_cast<uint64_t>(bytes[0]);
-    } else {
-      return 0;
+    if (hash_value == 0) {
+      // Session ids should be uniformly random, so their bits can serve directly as a hash.
+      // Fold in every byte, one 64-bit word at a time (memcpy avoids unaligned reads and
+      // sign extension of char), then the zero-padded tail, so the whole id contributes.
+      uint64_t seed = 0;
+      size_t i      = 0;
+      for (; i + sizeof(uint64_t) <= len; i += sizeof(uint64_t)) {
+        uint64_t word;
+        memcpy(&word, bytes + i, sizeof(word));
+        hash_combine(seed, word);
+      }
+      if (i < len) {
+        uint64_t tail = 0;
+        memcpy(&tail, bytes + i, len - i);
+        hash_combine(seed, tail);
+      }
+      hash_value = seed;
     }
+    return hash_value;
+  }
+
+private:
+  mutable uint64_t hash_value = 0;
+};
+
+struct SSLSessionIDHash {
+  uint64_t
+  operator()(const SSLSessionID &id) const
+  {
+    return id.hash();
   }
 };
 
@@ -120,9 +155,11 @@ public:
   Ptr<IOBufferData> asn1_data; /* this is the ASN1 representation of the SSL_CTX */
   size_t len_asn1_data;
   Ptr<IOBufferData> extra_data;
+  ink_hrtime time_stamp;
 
-  SSLSession(const SSLSessionID &id, const Ptr<IOBufferData> &ssl_asn1_data, size_t len_asn1, Ptr<IOBufferData> &exdata)
-    : session_id(id), asn1_data(ssl_asn1_data), len_asn1_data(len_asn1), extra_data(exdata)
+  SSLSession(const SSLSessionID &id, const Ptr<IOBufferData> &ssl_asn1_data, size_t len_asn1, Ptr<IOBufferData> &exdata,
+             ink_hrtime ts)
+    : session_id(id), asn1_data(ssl_asn1_data), len_asn1_data(len_asn1), extra_data(exdata), time_stamp(ts)
   {
   }
 
@@ -134,18 +171,21 @@ class SSLSessionBucket
 public:
   SSLSessionBucket();
   ~SSLSessionBucket();
-  void insertSession(const SSLSessionID &, SSL_SESSION *ctx, SSL *ssl);
-  bool getSession(const SSLSessionID &, SSL_SESSION **ctx, ssl_session_cache_exdata **data);
-  int getSessionBuffer(const SSLSessionID &, char *buffer, int &len);
-  void removeSession(const SSLSessionID &);
+  void insertSession(const SSLSessionID &sid, SSL_SESSION *sess, SSL *ssl);
+  bool getSession(const SSLSessionID &sid, SSL_SESSION **sess, ssl_session_cache_exdata **data);
+  int getSessionBuffer(const SSLSessionID &sid, char *buffer, int &len);
+  void removeSession(const SSLSessionID &sid);
 
 private:
   /* these method must be used while hold the lock */
   void print(const char *) const;
-  void removeOldestSession();
+  void removeOldestSession(const std::unique_lock<std::shared_mutex> &lock);
 
-  Ptr<ProxyMutex> mutex;
-  CountQueue<SSLSession> queue;
+  mutable std::shared_mutex mutex;
+  // Session id -> cached session.
+  std::unordered_map<SSLSessionID, SSLSession *, SSLSessionIDHash> bucket_data;
+  // Insertion timestamp -> the same SSLSession pointers held in bucket_data; begin() is the oldest.
+  std::map<ink_hrtime, SSLSession *> bucket_data_ts;
 };
 
 class SSLSessionCache
diff --git a/iocore/net/SSLUtils.cc b/iocore/net/SSLUtils.cc
index d34d6b8..1c2398a 100644
--- a/iocore/net/SSLUtils.cc
+++ b/iocore/net/SSLUtils.cc
@@ -191,6 +191,13 @@ ssl_get_cached_session(SSL *ssl, const unsigned char *id, int len, int *copy)
 static int
 ssl_new_cached_session(SSL *ssl, SSL_SESSION *sess)
 {
+#ifdef TLS1_3_VERSION
+  // TLSv1.3 sessions are not stored in this cache; returning 0 tells OpenSSL no reference to sess was kept.
+  if (SSL_SESSION_get_protocol_version(sess) == TLS1_3_VERSION) {
+    return 0;
+  }
+#endif
+
   unsigned int len        = 0;
   const unsigned char *id = SSL_SESSION_get_id(sess, &len);
 
@@ -219,6 +226,13 @@ ssl_new_cached_session(SSL *ssl, SSL_SESSION *sess)
 static void
 ssl_rm_cached_session(SSL_CTX *ctx, SSL_SESSION *sess)
 {
+#ifdef TLS1_3_VERSION
+  // ssl_new_cached_session() does not cache TLSv1.3 sessions, so skip the removal lookup as well.
+  if (SSL_SESSION_get_protocol_version(sess) == TLS1_3_VERSION) {
+    return;
+  }
+#endif
+
   unsigned int len        = 0;
   const unsigned char *id = SSL_SESSION_get_id(sess, &len);
   SSLSessionID sid(id, len);

Reply via email to