Save large connection IDs on the heap

The IETF will soon increase the maximum size of connection IDs; this change prepares for that by storing connection IDs that do not fit in the inline buffer on the heap.

gfe-relnote: change representation of connection IDs, protected by quic_use_allocated_connection_ids
PiperOrigin-RevId: 252524833
Change-Id: Ia98aee46adebbe6f5f3df2a4c0062a54a22e2d0d
diff --git a/quic/core/quic_connection_id.cc b/quic/core/quic_connection_id.cc
index 4022283..e1487e3 100644
--- a/quic/core/quic_connection_id.cc
+++ b/quic/core/quic_connection_id.cc
@@ -19,27 +19,77 @@
 
 namespace quic {
 
-QuicConnectionId::QuicConnectionId() : length_(0) {}
+QuicConnectionId::QuicConnectionId() : QuicConnectionId(nullptr, 0) {}
 
 QuicConnectionId::QuicConnectionId(const char* data, uint8_t length) {
+  static_assert(
+      kQuicMaxConnectionIdLength <= std::numeric_limits<uint8_t>::max(),
+      "kQuicMaxConnectionIdLength too high");
   if (length > kQuicMaxConnectionIdLength) {
-    QUIC_BUG << "Attempted to create connection ID of length " << length;
+    // Cast so the length is logged as a number rather than as a character.
+    QUIC_BUG << "Attempted to create connection ID of length "
+             << static_cast<int>(length);
     length = kQuicMaxConnectionIdLength;
   }
   length_ = length;
-  if (length_ > 0) {
+  if (length_ == 0) {
+    return;
+  }
+  if (!GetQuicRestartFlag(quic_use_allocated_connection_ids)) {
     memcpy(data_, data, length_);
+    return;
+  }
+  if (length_ <= sizeof(data_short_)) {
+    memcpy(data_short_, data, length_);
+    return;
+  }
+  data_long_ = reinterpret_cast<char*>(malloc(length_));
+  CHECK_NE(nullptr, data_long_);
+  memcpy(data_long_, data, length_);
+}
+
+QuicConnectionId::~QuicConnectionId() {
+  if (!GetQuicRestartFlag(quic_use_allocated_connection_ids)) {
+    return;
+  }
+  if (length_ > sizeof(data_short_)) {
+    free(data_long_);
+    data_long_ = nullptr;
   }
 }
 
-QuicConnectionId::~QuicConnectionId() {}
+QuicConnectionId::QuicConnectionId(const QuicConnectionId& other)
+    : QuicConnectionId(other.data(), other.length()) {}
+
+QuicConnectionId& QuicConnectionId::operator=(const QuicConnectionId& other) {
+  // Self-assignment would otherwise pass identical source and destination
+  // buffers to memcpy(), which is undefined behavior.
+  if (this == &other) {
+    return *this;
+  }
+  set_length(other.length());
+  memcpy(mutable_data(), other.data(), length_);
+  return *this;
+}
 
 const char* QuicConnectionId::data() const {
-  return data_;
+  if (!GetQuicRestartFlag(quic_use_allocated_connection_ids)) {
+    return data_;
+  }
+  if (length_ <= sizeof(data_short_)) {
+    return data_short_;
+  }
+  return data_long_;
 }
 
 char* QuicConnectionId::mutable_data() {
-  return data_;
+  if (!GetQuicRestartFlag(quic_use_allocated_connection_ids)) {
+    return data_;
+  }
+  if (length_ <= sizeof(data_short_)) {
+    return data_short_;
+  }
+  return data_long_;
 }
 
 uint8_t QuicConnectionId::length() const {
@@ -47,6 +90,39 @@
 }
 
 void QuicConnectionId::set_length(uint8_t length) {
+  // Clamp like the constructor does: Hash() copies |length_| bytes into a
+  // fixed-size local buffer, and |data_| is kQuicMaxConnectionIdLength bytes.
+  if (length > kQuicMaxConnectionIdLength) {
+    QUIC_BUG << "Attempted to set connection ID length to "
+             << static_cast<int>(length);
+    length = kQuicMaxConnectionIdLength;
+  }
+  if (GetQuicRestartFlag(quic_use_allocated_connection_ids)) {
+    // |data_short_| and |data_long_| share storage, so bytes are staged in
+    // |temporary_data| before either member is overwritten.
+    char temporary_data[sizeof(data_short_)];
+    if (length > sizeof(data_short_)) {
+      if (length_ <= sizeof(data_short_)) {
+        // Copy data from data_short_ to data_long_.
+        memcpy(temporary_data, data_short_, length_);
+        data_long_ = reinterpret_cast<char*>(malloc(length));
+        CHECK_NE(nullptr, data_long_);
+        memcpy(data_long_, temporary_data, length_);
+      } else {
+        // Resize data_long_.
+        char* realloc_result =
+            reinterpret_cast<char*>(realloc(data_long_, length));
+        CHECK_NE(nullptr, realloc_result);
+        data_long_ = realloc_result;
+      }
+    } else if (length_ > sizeof(data_short_)) {
+      // Copy data from data_long_ to data_short_.
+      memcpy(temporary_data, data_long_, length);
+      free(data_long_);
+      data_long_ = nullptr;
+      memcpy(data_short_, temporary_data, length);
+    }
+  }
   length_ = length;
 }
 
@@ -56,8 +123,9 @@
 
 size_t QuicConnectionId::Hash() const {
   uint64_t data_bytes[3] = {0, 0, 0};
-  static_assert(sizeof(data_bytes) >= sizeof(data_), "sizeof(data_) changed");
-  memcpy(data_bytes, data_, length_);
+  static_assert(sizeof(data_bytes) >= kQuicMaxConnectionIdLength,
+                "kQuicMaxConnectionIdLength changed");
+  memcpy(data_bytes, data(), length_);
   // This Hash function is designed to return the same value as the host byte
   // order representation when the connection ID length is 64 bits.
   return QuicEndian::NetToHost64(kQuicDefaultConnectionIdLength ^ length_ ^
@@ -68,7 +136,7 @@
   if (IsEmpty()) {
     return std::string("0");
   }
-  return QuicTextUtils::HexEncode(data_, length_);
+  return QuicTextUtils::HexEncode(data(), length_);
 }
 
 std::ostream& operator<<(std::ostream& os, const QuicConnectionId& v) {
@@ -77,7 +145,7 @@
 }
 
 bool QuicConnectionId::operator==(const QuicConnectionId& v) const {
-  return length_ == v.length_ && memcmp(data_, v.data_, length_) == 0;
+  return length_ == v.length_ && memcmp(data(), v.data(), length_) == 0;
 }
 
 bool QuicConnectionId::operator!=(const QuicConnectionId& v) const {
@@ -91,7 +159,7 @@
   if (length_ > v.length_) {
     return false;
   }
-  return memcmp(data_, v.data_, length_) < 0;
+  return memcmp(data(), v.data(), length_) < 0;
 }
 
 QuicConnectionId EmptyQuicConnectionId() {
diff --git a/quic/core/quic_connection_id.h b/quic/core/quic_connection_id.h
index d51366f..4b76f31 100644
--- a/quic/core/quic_connection_id.h
+++ b/quic/core/quic_connection_id.h
@@ -6,6 +6,7 @@
 #define QUICHE_QUIC_CORE_QUIC_CONNECTION_ID_H_
 
 #include <string>
+#include <vector>
 
 #include "net/third_party/quiche/src/quic/platform/api/quic_export.h"
 #include "net/third_party/quiche/src/quic/platform/api/quic_uint128.h"
@@ -43,12 +44,21 @@
   // Creates a connection ID from network order bytes.
   QuicConnectionId(const char* data, uint8_t length);
 
+  // Creates a connection ID from another connection ID.
+  QuicConnectionId(const QuicConnectionId& other);
+
+  // Assignment operator.
+  QuicConnectionId& operator=(const QuicConnectionId& other);
+
   ~QuicConnectionId();
 
   // Returns the length of the connection ID, in bytes.
   uint8_t length() const;
 
   // Sets the length of the connection ID, in bytes.
+  // WARNING: Calling set_length() can change the in-memory location of the
+  // connection ID. Callers must therefore ensure they call data() or
+  // mutable_data() after they call set_length().
   void set_length(uint8_t length);
 
   // Returns a pointer to the connection ID bytes, in network byte order.
@@ -79,10 +89,21 @@
   bool operator<(const QuicConnectionId& v) const;
 
  private:
-  // The connection ID is represented in network byte order
-  // in the first |length_| bytes of |data_|.
-  char data_[kQuicMaxConnectionIdLength];
-  uint8_t length_;
+  uint8_t length_;  // length of the connection ID, in bytes.
+  // The connection ID is represented in network byte order.
+  union {
+    // When quic_use_allocated_connection_ids is false, the connection ID is
+    // stored in the first |length_| bytes of |data_|.
+    char data_[kQuicMaxConnectionIdLength];
+    // When quic_use_allocated_connection_ids is true, if the connection ID
+    // fits in |data_short_|, it is stored in the first |length_| bytes of
+    // |data_short_|. Otherwise it is stored in |data_long_| which is
+    // guaranteed to have a size equal to |length_|. 11 bytes were chosen so
+    // the common 8-byte connection ID is stored inline. NOTE: |length_| is
+    // outside this union, so it does not share the union's padding.
+    char data_short_[11];
+    char* data_long_;
+  };
 };
 
 // Creates a connection ID of length zero, unless the restart flag
diff --git a/quic/core/quic_connection_id_test.cc b/quic/core/quic_connection_id_test.cc
index bbe04f9..1d290fc 100644
--- a/quic/core/quic_connection_id_test.cc
+++ b/quic/core/quic_connection_id_test.cc
@@ -91,6 +91,53 @@
   EXPECT_NE(connection_id64_2.Hash(), connection_id64_3.Hash());
 }
 
+TEST_F(QuicConnectionIdTest, AssignAndCopy) {
+  QuicConnectionId connection_id = test::TestConnectionId(1);
+  QuicConnectionId connection_id2 = test::TestConnectionId(2);
+  connection_id = connection_id2;
+  EXPECT_EQ(connection_id, test::TestConnectionId(2));
+  EXPECT_NE(connection_id, test::TestConnectionId(1));
+  connection_id = QuicConnectionId(test::TestConnectionId(1));
+  EXPECT_EQ(connection_id, test::TestConnectionId(1));
+  EXPECT_NE(connection_id, test::TestConnectionId(2));
+}
+
+TEST_F(QuicConnectionIdTest, ChangeLength) {
+  QuicConnectionId connection_id64_1 = test::TestConnectionId(1);
+  QuicConnectionId connection_id64_2 = test::TestConnectionId(2);
+  QuicConnectionId connection_id136_2 = test::TestConnectionId(2);
+  connection_id136_2.set_length(17);
+  memset(connection_id136_2.mutable_data() + 8, 0, 9);
+  char connection_id136_2_bytes[17] = {0, 0, 0, 0, 0, 0, 0, 2, 0,
+                                       0, 0, 0, 0, 0, 0, 0, 0};
+  QuicConnectionId connection_id136_2b(connection_id136_2_bytes,
+                                       sizeof(connection_id136_2_bytes));
+  EXPECT_EQ(connection_id136_2, connection_id136_2b);
+  QuicConnectionId connection_id = connection_id64_1;
+  connection_id.set_length(17);
+  EXPECT_NE(connection_id64_1, connection_id);
+  // Check resizing big to small.
+  connection_id.set_length(8);
+  EXPECT_EQ(connection_id64_1, connection_id);
+  // Check resizing small to big.
+  connection_id.set_length(17);
+  memset(connection_id.mutable_data(), 0, connection_id.length());
+  memcpy(connection_id.mutable_data(), connection_id64_2.data(),
+         connection_id64_2.length());
+  EXPECT_EQ(connection_id136_2, connection_id);
+  EXPECT_EQ(connection_id136_2b, connection_id);
+  QuicConnectionId connection_id120(connection_id136_2_bytes, 15);
+  connection_id.set_length(15);
+  EXPECT_EQ(connection_id120, connection_id);
+  // Check resizing big to big.
+  QuicConnectionId connection_id2 = connection_id120;
+  connection_id2.set_length(17);
+  connection_id2.mutable_data()[15] = 0;
+  connection_id2.mutable_data()[16] = 0;
+  EXPECT_EQ(connection_id136_2, connection_id2);
+  EXPECT_EQ(connection_id136_2b, connection_id2);
+}
+
 }  // namespace
 
 }  // namespace quic
diff --git a/quic/core/quic_data_reader.cc b/quic/core/quic_data_reader.cc
index ef09483..951a144 100644
--- a/quic/core/quic_data_reader.cc
+++ b/quic/core/quic_data_reader.cc
@@ -146,10 +146,21 @@
     return true;
   }
 
-  const bool ok = ReadBytes(connection_id->mutable_data(), length);
-  if (ok) {
-    connection_id->set_length(length);
+  if (!GetQuicRestartFlag(quic_use_allocated_connection_ids)) {
+    const bool ok = ReadBytes(connection_id->mutable_data(), length);
+    if (ok) {
+      connection_id->set_length(length);
+    }
+    return ok;
   }
+
+  if (BytesRemaining() < length) {
+    return false;
+  }
+
+  connection_id->set_length(length);
+  const bool ok = ReadBytes(connection_id->mutable_data(), length);
+  DCHECK(ok);
   return ok;
 }
 
diff --git a/quic/core/quic_versions.cc b/quic/core/quic_versions.cc
index d76b6be..11f8ea1 100644
--- a/quic/core/quic_versions.cc
+++ b/quic/core/quic_versions.cc
@@ -441,6 +441,7 @@
                      true);
   SetQuicRestartFlag(quic_do_not_override_connection_id, true);
   SetQuicRestartFlag(quic_no_framer_object_in_dispatcher, true);
+  SetQuicRestartFlag(quic_use_allocated_connection_ids, true);
 }
 
 void QuicEnableVersion(ParsedQuicVersion parsed_version) {