This is an automated email from the ASF dual-hosted git repository.

maskit pushed a commit to branch quic-latest
in repository https://gitbox.apache.org/repos/asf/trafficserver.git

commit 6751e97d47d4bef227c35ba206ae4466292b1cdc
Author: Masakazu Kitajo <[email protected]>
AuthorDate: Thu Jun 21 11:07:08 2018 +0900

    Protect sending packet numbers
---
 iocore/net/P_QUICPacketHandler.h          |  10 +--
 iocore/net/QUICNetVConnection.cc          |   6 +-
 iocore/net/QUICPacketHandler.cc           |  20 +++---
 iocore/net/quic/QUICPacket.cc             | 100 ++++++++++++++++++++++++++++++
 iocore/net/quic/QUICPacket.h              |   5 ++
 iocore/net/quic/QUICPacketReceiveQueue.cc |  51 +--------------
 iocore/net/quic/QUICPacketReceiveQueue.h  |   2 -
 7 files changed, 128 insertions(+), 66 deletions(-)

diff --git a/iocore/net/P_QUICPacketHandler.h b/iocore/net/P_QUICPacketHandler.h
index 3e982bc..77fe364 100644
--- a/iocore/net/P_QUICPacketHandler.h
+++ b/iocore/net/P_QUICPacketHandler.h
@@ -32,6 +32,7 @@
 class QUICClosedConCollector;
 class QUICNetVConnection;
 class QUICPacket;
+class QUICPacketNumberProtector;
 
 class QUICPacketHandler
 {
@@ -39,11 +40,12 @@ public:
   QUICPacketHandler();
   ~QUICPacketHandler();
 
-  virtual void send_packet(const QUICPacket &packet, QUICNetVConnection *vc) = 0;
+  virtual void send_packet(const QUICPacket &packet, QUICNetVConnection *vc, QUICPacketNumberProtector &pn_protector) = 0;
   virtual void close_conenction(QUICNetVConnection *conn);
 
 protected:
-  static void _send_packet(Continuation *c, const QUICPacket &packet, UDPConnection *udp_con, IpEndpoint &addr, uint32_t pmtu);
+  static void _send_packet(Continuation *c, const QUICPacket &packet, UDPConnection *udp_con, IpEndpoint &addr, uint32_t pmtu,
+                           QUICPacketNumberProtector *pn_protector);
 
   Event *_collector_event                       = nullptr;
   QUICClosedConCollector *_closed_con_collector = nullptr;
@@ -68,7 +70,7 @@ public:
   void init_accept(EThread *t) override;
 
   // QUICPacketHandler
-  virtual void send_packet(const QUICPacket &packet, QUICNetVConnection *vc) override;
+  virtual void send_packet(const QUICPacket &packet, QUICNetVConnection *vc, QUICPacketNumberProtector &pn_protector) override;
 
 private:
   void _recv_packet(int event, UDPPacket *udp_packet) override;
@@ -90,7 +92,7 @@ public:
   int event_handler(int event, Event *data);
 
   // QUICPacketHandler
-  virtual void send_packet(const QUICPacket &packet, QUICNetVConnection *vc) override;
+  virtual void send_packet(const QUICPacket &packet, QUICNetVConnection *vc, QUICPacketNumberProtector &pn_protector) override;
 
 private:
   void _recv_packet(int event, UDPPacket *udp_packet) override;
diff --git a/iocore/net/QUICNetVConnection.cc b/iocore/net/QUICNetVConnection.cc
index e472756..a9c9d20 100644
--- a/iocore/net/QUICNetVConnection.cc
+++ b/iocore/net/QUICNetVConnection.cc
@@ -1077,7 +1077,7 @@ QUICNetVConnection::_state_common_send_packet()
       break;
     }
 
-    this->_packet_handler->send_packet(*packet, this);
+    this->_packet_handler->send_packet(*packet, this, this->_pn_protector);
     if (packet->type() == QUICPacketType::HANDSHAKE) {
       ++this->_handshake_packets_sent;
     }
@@ -1116,7 +1116,7 @@ QUICNetVConnection::_state_handshake_send_retry_packet()
   }
 
   QUICPacketUPtr packet = this->_build_packet(std::move(buf), len, retransmittable, QUICPacketType::RETRY);
-  this->_packet_handler->send_packet(*packet, this);
+  this->_packet_handler->send_packet(*packet, this, this->_pn_protector);
   this->_loss_detector->on_packet_sent(std::move(packet));
 
   QUIC_INCREMENT_DYN_STAT_EX(QUICStats::total_packets_sent_stat, 1);
@@ -1137,7 +1137,7 @@ QUICNetVConnection::_state_closing_send_packet()
   // that an endpoint maintains for a closing connection, endpoints MAY
   // send the exact same packet.
   if (this->_the_final_packet) {
-    this->_packet_handler->send_packet(*this->_the_final_packet, this);
+    this->_packet_handler->send_packet(*this->_the_final_packet, this, this->_pn_protector);
   }
   return QUICErrorUPtr(new QUICNoError());
 }
diff --git a/iocore/net/QUICPacketHandler.cc b/iocore/net/QUICPacketHandler.cc
index ab2a573..7e6a58e 100644
--- a/iocore/net/QUICPacketHandler.cc
+++ b/iocore/net/QUICPacketHandler.cc
@@ -70,7 +70,8 @@ QUICPacketHandler::close_conenction(QUICNetVConnection *conn)
 }
 
 void
-QUICPacketHandler::_send_packet(Continuation *c, const QUICPacket &packet, UDPConnection *udp_con, IpEndpoint &addr, uint32_t pmtu)
+QUICPacketHandler::_send_packet(Continuation *c, const QUICPacket &packet, UDPConnection *udp_con, IpEndpoint &addr, uint32_t pmtu,
+                                QUICPacketNumberProtector *pn_protector)
 {
   size_t udp_len;
   Ptr<IOBufferBlock> udp_payload(new_IOBufferBlock());
@@ -78,6 +79,10 @@ QUICPacketHandler::_send_packet(Continuation *c, const QUICPacket &packet, UDPCo
   packet.store(reinterpret_cast<uint8_t *>(udp_payload->end()), &udp_len);
   udp_payload->fill(udp_len);
 
+  if (pn_protector) {
+    QUICPacket::protect_packet_number(reinterpret_cast<uint8_t *>(udp_payload->start()), udp_len, pn_protector);
+  }
+
   UDPPacket *udp_packet = new_UDPPacket(addr, 0, udp_payload);
 
   if (is_debug_tag_set(tag)) {
@@ -216,7 +221,7 @@ QUICPacketHandlerIn::_recv_packet(int event, UDPPacket *udp_packet)
       QUICDebugDS(scid, dcid, "Unsupported version: 0x%x", v);
 
       QUICPacketUPtr vn = QUICPacketFactory::create_version_negotiation_packet(scid, dcid);
-      this->_send_packet(this, *vn, udp_packet->getConnection(), udp_packet->from, 1200);
+      this->_send_packet(this, *vn, udp_packet->getConnection(), udp_packet->from, 1200, nullptr);
       udp_packet->free();
       return;
     }
@@ -247,7 +252,7 @@ QUICPacketHandlerIn::_recv_packet(int event, UDPPacket *udp_packet)
       token.generate(dcid, params->server_id());
     }
     auto packet = QUICPacketFactory::create_stateless_reset_packet(dcid, token);
-    this->_send_packet(this, *packet, udp_packet->getConnection(), udp_packet->from, 1200);
+    this->_send_packet(this, *packet, udp_packet->getConnection(), udp_packet->from, 1200, nullptr);
     udp_packet->free();
     return;
   }
@@ -297,9 +302,9 @@ QUICPacketHandlerIn::_recv_packet(int event, UDPPacket *udp_packet)
 
 // TODO: Should be called via eventProcessor?
 void
-QUICPacketHandlerIn::send_packet(const QUICPacket &packet, QUICNetVConnection *vc)
+QUICPacketHandlerIn::send_packet(const QUICPacket &packet, QUICNetVConnection *vc, QUICPacketNumberProtector &pn_protector)
 {
-  this->_send_packet(this, packet, vc->get_udp_con(), vc->con.addr, vc->pmtu());
+  this->_send_packet(this, packet, vc->get_udp_con(), vc->con.addr, vc->pmtu(), &pn_protector);
 }
 
 //
@@ -342,9 +347,10 @@ QUICPacketHandlerOut::event_handler(int event, Event *data)
 }
 
 void
-QUICPacketHandlerOut::send_packet(const QUICPacket &packet, QUICNetVConnection *vc)
+QUICPacketHandlerOut::send_packet(const QUICPacket &packet, QUICNetVConnection *vc, QUICPacketNumberProtector &pn_protector)
 {
-  this->_send_packet(this, packet, vc->get_udp_con(), vc->con.addr, vc->pmtu());
+  // TODO Pass QUICPacketNumberProtector
+  this->_send_packet(this, packet, vc->get_udp_con(), vc->con.addr, vc->pmtu(), nullptr);
 }
 
 void
diff --git a/iocore/net/quic/QUICPacket.cc b/iocore/net/quic/QUICPacket.cc
index d9cb9d0..a28a3f5 100644
--- a/iocore/net/quic/QUICPacket.cc
+++ b/iocore/net/quic/QUICPacket.cc
@@ -258,6 +258,20 @@ QUICPacketLongHeader::length(size_t &length, uint8_t *field_len, const uint8_t *
   return true;
 }
 
+bool
+QUICPacketLongHeader::packet_number_offset(size_t &pn_offset, const uint8_t *packet, size_t packet_len)
+{
+  uint8_t dcil, scil;
+  size_t length;
+  uint8_t length_field_len;
+  if (!QUICPacketLongHeader::dcil(dcil, packet, packet_len) || !QUICPacketLongHeader::scil(scil, packet, packet_len) ||
+      !QUICPacketLongHeader::length(length, &length_field_len, packet, packet_len)) {
+    return false;
+  }
+  pn_offset = 6 + dcil + scil + length_field_len;
+  return true;
+}
+
 QUICConnectionId
 QUICPacketLongHeader::destination_cid() const
 {
@@ -530,6 +544,14 @@ QUICPacketShortHeader::key_phase(QUICKeyPhase &phase, const uint8_t *packet, siz
   return true;
 }
 
+bool
+QUICPacketShortHeader::packet_number_offset(size_t &pn_offset, const uint8_t *packet, size_t packet_len)
+{
+  int connection_id_len = QUICConfigParams::scid_len();
+  pn_offset             = 1 + connection_id_len;
+  return true;
+}
+
 /**
  * Header Length (doesn't include payload length)
  */
@@ -733,6 +755,84 @@ QUICPacket::decode_packet_number(QUICPacketNumber &dst, QUICPacketNumber src, si
   return true;
 }
 
+bool
+QUICPacket::protect_packet_number(uint8_t *packet, size_t packet_len, QUICPacketNumberProtector *pn_protector)
+{
+  size_t pn_offset             = 0;
+  uint8_t pn_len               = 4;
+  size_t sample_offset         = 0;
+  uint8_t sample_len           = 0;
+  constexpr int aead_expansion = 16; // Currently, AEAD expansion (which is probably AEAD tag) length is always 16
+  QUICKeyPhase phase;
+
+  if (QUICInvariants::is_long_header(packet)) {
+    QUICPacketType type;
+    QUICPacketLongHeader::type(type, packet, packet_len);
+    switch (type) {
+    case QUICPacketType::ZERO_RTT_PROTECTED:
+      phase = QUICKeyPhase::ZERORTT;
+      break;
+    default:
+      phase = QUICKeyPhase::CLEARTEXT;
+      break;
+    }
+    QUICPacketLongHeader::packet_number_offset(pn_offset, packet, packet_len);
+  } else {
+    QUICPacketShortHeader::key_phase(phase, packet, packet_len);
+    QUICPacketShortHeader::packet_number_offset(pn_offset, packet, packet_len);
+  }
+  sample_offset = std::min(pn_offset + 4, packet_len - aead_expansion);
+  sample_len    = 16; // On draft-12, the length is always 16 (See 5.6.1 and 5.6.2)
+
+  uint8_t protected_pn[4]  = {0};
+  uint8_t protected_pn_len = 0;
+  pn_len                   = QUICTypeUtil::read_QUICPacketNumberLen(packet + pn_offset);
+  if (!pn_protector->protect(protected_pn, protected_pn_len, packet + pn_offset, pn_len, packet + sample_offset, phase)) {
+    return false;
+  }
+  memcpy(packet + pn_offset, protected_pn, pn_len);
+  return true;
+}
+
+bool
+QUICPacket::unprotect_packet_number(uint8_t *packet, size_t packet_len, QUICPacketNumberProtector *pn_protector)
+{
+  size_t pn_offset             = 0;
+  uint8_t pn_len               = 4;
+  size_t sample_offset         = 0;
+  uint8_t sample_len           = 0;
+  constexpr int aead_expansion = 16; // Currently, AEAD expansion (which is probably AEAD tag) length is always 16
+  QUICKeyPhase phase;
+
+  if (QUICInvariants::is_long_header(packet)) {
+    QUICPacketType type;
+    QUICPacketLongHeader::type(type, packet, packet_len);
+    switch (type) {
+    case QUICPacketType::ZERO_RTT_PROTECTED:
+      phase = QUICKeyPhase::ZERORTT;
+      break;
+    default:
+      phase = QUICKeyPhase::CLEARTEXT;
+      break;
+    }
+    QUICPacketLongHeader::packet_number_offset(pn_offset, packet, packet_len);
+  } else {
+    QUICPacketShortHeader::key_phase(phase, packet, packet_len);
+    QUICPacketShortHeader::packet_number_offset(pn_offset, packet, packet_len);
+  }
+  sample_offset = std::min(pn_offset + 4, packet_len - aead_expansion);
+  sample_len    = 16; // On draft-12, the length is always 16 (See 5.6.1 and 5.6.2)
+
+  uint8_t unprotected_pn[4]  = {0};
+  uint8_t unprotected_pn_len = 0;
+  if (!pn_protector->unprotect(unprotected_pn, unprotected_pn_len, packet + pn_offset, pn_len, packet + sample_offset, phase)) {
+    return false;
+  }
+  unprotected_pn_len = QUICTypeUtil::read_QUICPacketNumberLen(unprotected_pn);
+  memcpy(packet + pn_offset, unprotected_pn, unprotected_pn_len);
+  return true;
+}
+
 //
 // QUICPacketFactory
 //
diff --git a/iocore/net/quic/QUICPacket.h b/iocore/net/quic/QUICPacket.h
index 3ce9bc8..4067fda 100644
--- a/iocore/net/quic/QUICPacket.h
+++ b/iocore/net/quic/QUICPacket.h
@@ -201,6 +201,7 @@ public:
    */
   static bool scil(uint8_t &scil, const uint8_t *packet, size_t packet_len);
   static bool length(size_t &length, uint8_t *field_len, const uint8_t *packet, size_t packet_len);
+  static bool packet_number_offset(size_t &pn_offset, const uint8_t *packet, size_t packet_len);
 
 private:
   QUICPacketNumber _packet_number;
@@ -237,6 +238,7 @@ public:
   void store(uint8_t *buf, size_t *len) const;
 
   static bool key_phase(QUICKeyPhase &key_phase, const uint8_t *packet, size_t packet_len);
+  static bool packet_number_offset(size_t &pn_offset, const uint8_t *packet, size_t packet_len);
 
 private:
   int _packet_number_len;
@@ -328,6 +330,9 @@ public:
   static bool encode_packet_number(QUICPacketNumber &dst, QUICPacketNumber src, size_t len);
   static bool decode_packet_number(QUICPacketNumber &dst, QUICPacketNumber src, size_t len, QUICPacketNumber largest_acked);
 
+  static bool protect_packet_number(uint8_t *packet, size_t packet_len, QUICPacketNumberProtector *pn_protector);
+  static bool unprotect_packet_number(uint8_t *packet, size_t packet_len, QUICPacketNumberProtector *pn_protector);
+
   LINK(QUICPacket, link);
 
 private:
diff --git a/iocore/net/quic/QUICPacketReceiveQueue.cc b/iocore/net/quic/QUICPacketReceiveQueue.cc
index bfd6c5f..fcb3184 100644
--- a/iocore/net/quic/QUICPacketReceiveQueue.cc
+++ b/iocore/net/quic/QUICPacketReceiveQueue.cc
@@ -138,7 +138,7 @@ QUICPacketReceiveQueue::dequeue(QUICPacketCreationResult &result)
     this->_offset      = 0;
   }
 
-  if (this->_unprotect_packet_number(pkt.get(), pkt_len)) {
+  if (QUICPacket::unprotect_packet_number(pkt.get(), pkt_len, &this->_pn_protector)) {
     quic_packet = this->_packet_factory.create(this->_from, std::move(pkt), pkt_len, this->_largest_received_packet_number, result);
   } else {
     result = QUICPacketCreationResult::FAILED;
@@ -179,52 +179,3 @@ QUICPacketReceiveQueue::reset()
 {
   this->_largest_received_packet_number = 0;
 }
-
-bool
-QUICPacketReceiveQueue::_unprotect_packet_number(uint8_t *packet, size_t packet_len)
-{
-  size_t pn_offset             = 0;
-  uint8_t pn_len               = 4;
-  size_t sample_offset         = 0;
-  uint8_t sample_len           = 0;
-  constexpr int aead_expansion = 16; // Currently, AEAD expansion (which is probably AEAD tag) length is always 16
-  int connection_id_len        = QUICConfigParams::scid_len();
-  QUICKeyPhase phase;
-
-  if (QUICInvariants::is_long_header(packet)) {
-    QUICPacketType type;
-    QUICPacketLongHeader::type(type, packet, packet_len);
-    switch (type) {
-    case QUICPacketType::ZERO_RTT_PROTECTED:
-      phase = QUICKeyPhase::ZERORTT;
-      break;
-    default:
-      phase = QUICKeyPhase::CLEARTEXT;
-      break;
-    }
-
-    uint8_t dcil, scil;
-    size_t payload_length;
-    uint8_t payload_length_field_len;
-    if (!QUICPacketLongHeader::dcil(dcil, packet, packet_len) || !QUICPacketLongHeader::scil(scil, packet, packet_len) ||
-        !QUICPacketLongHeader::payload_length(payload_length, &payload_length_field_len, packet, packet_len)) {
-      return false;
-    }
-    pn_offset = 6 + dcil + scil + payload_length_field_len;
-  } else {
-    QUICPacketShortHeader::key_phase(phase, packet, packet_len);
-    pn_offset = 1 + connection_id_len;
-  }
-  sample_offset = std::min(pn_offset + 4, packet_len - aead_expansion);
-  sample_len    = 16; // On draft-12, the length is always 16 (See 5.6.1 and 5.6.2)
-
-  uint8_t unprotected_pn[4]  = {0};
-  uint8_t unprotected_pn_len = 0;
-  if (!this->_pn_protector.unprotect(unprotected_pn, unprotected_pn_len, packet + pn_offset, pn_len, packet + sample_offset,
-                                     phase)) {
-    return false;
-  }
-  unprotected_pn_len = QUICTypeUtil::read_QUICPacketNumberLen(unprotected_pn);
-  memcpy(packet + pn_offset, unprotected_pn, unprotected_pn_len);
-  return true;
-}
diff --git a/iocore/net/quic/QUICPacketReceiveQueue.h b/iocore/net/quic/QUICPacketReceiveQueue.h
index 230fc8a..9f9a3b8 100644
--- a/iocore/net/quic/QUICPacketReceiveQueue.h
+++ b/iocore/net/quic/QUICPacketReceiveQueue.h
@@ -48,6 +48,4 @@ private:
   size_t _payload_len     = 0;
   size_t _offset          = 0;
   IpEndpoint _from;
-
-  bool _unprotect_packet_number(uint8_t *packet, size_t packet_len);
 };

Reply via email to