Allow QUIC server to replace connection IDs

This CL changes the QuicDispatcher to have it replace the connection ID provided by the client if its length differs from what the dispatcher was configured with. It also changes QuicConnection on the client side to accept connection ID changes coming from the server, and replace its own connection ID to match what the server expects on outgoing packets. This checks VariableLengthConnectionIdAllowedForVersion() so it only impacts v99.

gfe-relnote: v99-only change, not flag protected
PiperOrigin-RevId: 239328650
Change-Id: I21ee0c0ca74c7624823c38a72f323ae6491e21e6
diff --git a/quic/core/http/end_to_end_test.cc b/quic/core/http/end_to_end_test.cc
index dd40098..c3f50ef 100644
--- a/quic/core/http/end_to_end_test.cc
+++ b/quic/core/http/end_to_end_test.cc
@@ -741,10 +741,17 @@
     ASSERT_TRUE(Initialize());
     return;
   }
-  QuicConnectionId connection_id = TestConnectionIdNineBytesLong(1);
+  QuicConnectionId connection_id =
+      TestConnectionIdNineBytesLong(UINT64_C(0xBADbadBADbad));
   override_connection_id_ = &connection_id;
-  ASSERT_FALSE(Initialize());
-  EXPECT_EQ(QUIC_HANDSHAKE_FAILED, client_->connection_error());
+  ASSERT_TRUE(Initialize());
+  EXPECT_EQ(kFooResponseBody, client_->SendSynchronousRequest("/foo"));
+  EXPECT_EQ("200", client_->response_headers()->find(":status")->second);
+  EXPECT_EQ(kQuicDefaultConnectionIdLength, client_->client()
+                                                ->client_session()
+                                                ->connection()
+                                                ->connection_id()
+                                                .length());
 }
 
 TEST_P(EndToEndTest, MixGoodAndBadConnectionIdLengths) {
@@ -754,14 +761,14 @@
     return;
   }
 
-  // Start client_ which will fail due to bad connection ID length.
-  QuicConnectionId connection_id = TestConnectionIdNineBytesLong(1);
+  // Start client_ which will use a bad connection ID length.
+  QuicConnectionId connection_id =
+      TestConnectionIdNineBytesLong(UINT64_C(0xBADbadBADbad));
   override_connection_id_ = &connection_id;
-  ASSERT_FALSE(Initialize());
-  EXPECT_EQ(QUIC_HANDSHAKE_FAILED, client_->connection_error());
+  ASSERT_TRUE(Initialize());
   override_connection_id_ = nullptr;
 
-  // Start client2 which will succeed.
+  // Start client2 which will use a good connection ID length.
   std::unique_ptr<QuicTestClient> client2(CreateQuicClient(nullptr));
   SpdyHeaderBlock headers;
   headers[":method"] = "POST";
@@ -771,9 +778,23 @@
   headers["content-length"] = "3";
   client2->SendMessage(headers, "", /*fin=*/false);
   client2->SendData("eep", true);
+
+  EXPECT_EQ(kFooResponseBody, client_->SendSynchronousRequest("/foo"));
+  EXPECT_EQ("200", client_->response_headers()->find(":status")->second);
+  EXPECT_EQ(kQuicDefaultConnectionIdLength, client_->client()
+                                                ->client_session()
+                                                ->connection()
+                                                ->connection_id()
+                                                .length());
+
   client2->WaitForResponse();
   EXPECT_EQ(kFooResponseBody, client2->response_body());
   EXPECT_EQ("200", client2->response_headers()->find(":status")->second);
+  EXPECT_EQ(kQuicDefaultConnectionIdLength, client2->client()
+                                                ->client_session()
+                                                ->connection()
+                                                ->connection_id()
+                                                .length());
 }
 
 TEST_P(EndToEndTest, SimpleRequestResponseWithLargeReject) {
diff --git a/quic/core/quic_connection.cc b/quic/core/quic_connection.cc
index 17752a6..dcea61d 100644
--- a/quic/core/quic_connection.cc
+++ b/quic/core/quic_connection.cc
@@ -21,6 +21,7 @@
 #include "net/third_party/quiche/src/quic/core/proto/cached_network_parameters.pb.h"
 #include "net/third_party/quiche/src/quic/core/quic_bandwidth.h"
 #include "net/third_party/quiche/src/quic/core/quic_config.h"
+#include "net/third_party/quiche/src/quic/core/quic_connection_id.h"
 #include "net/third_party/quiche/src/quic/core/quic_packet_generator.h"
 #include "net/third_party/quiche/src/quic/core/quic_pending_retransmission.h"
 #include "net/third_party/quiche/src/quic/core/quic_types.h"
@@ -206,6 +207,17 @@
   QuicConnection* connection_;
 };
 
+// Whether this incoming packet is allowed to replace our connection ID.
+bool PacketCanReplaceConnectionId(const QuicPacketHeader& header,
+                                  Perspective perspective) {
+  return perspective == Perspective::IS_CLIENT &&
+         header.form == IETF_QUIC_LONG_HEADER_PACKET &&
+         QuicUtils::VariableLengthConnectionIdAllowedForVersion(
+             header.version.transport_version) &&
+         (header.long_packet_type == INITIAL ||
+          header.long_packet_type == RETRY);
+}
+
 }  // namespace
 
 #define ENDPOINT \
@@ -719,9 +731,34 @@
   RetransmitUnackedPackets(ALL_UNACKED_RETRANSMISSION);
 }
 
+bool QuicConnection::HasIncomingConnectionId(QuicConnectionId connection_id) {
+  for (QuicConnectionId const& incoming_connection_id :
+       incoming_connection_ids_) {
+    if (incoming_connection_id == connection_id) {
+      return true;
+    }
+  }
+  return false;
+}
+
+void QuicConnection::AddIncomingConnectionId(QuicConnectionId connection_id) {
+  if (HasIncomingConnectionId(connection_id)) {
+    return;
+  }
+  incoming_connection_ids_.push_back(connection_id);
+}
+
 bool QuicConnection::OnUnauthenticatedPublicHeader(
     const QuicPacketHeader& header) {
-  if (header.destination_connection_id == connection_id_) {
+  if (header.destination_connection_id == connection_id_ ||
+      HasIncomingConnectionId(header.destination_connection_id)) {
+    return true;
+  }
+
+  if (PacketCanReplaceConnectionId(header, perspective_)) {
+    QUIC_DLOG(INFO) << ENDPOINT << "Accepting packet with new connection ID "
+                    << header.destination_connection_id << " instead of "
+                    << connection_id_;
     return true;
   }
 
@@ -748,7 +785,9 @@
   // Check that any public reset packet with a different connection ID that was
   // routed to this QuicConnection has been redirected before control reaches
   // here.
-  DCHECK_EQ(connection_id_, header.destination_connection_id);
+  DCHECK(header.destination_connection_id == connection_id_ ||
+         HasIncomingConnectionId(header.destination_connection_id) ||
+         PacketCanReplaceConnectionId(header, perspective_));
 
   if (!packet_generator_.IsPendingPacketEmpty()) {
     // Incoming packets may change a queued ACK frame.
@@ -1947,6 +1986,14 @@
     self_address_ = last_packet_destination_address_;
   }
 
+  if (PacketCanReplaceConnectionId(header, perspective_) &&
+      connection_id_ != header.source_connection_id) {
+    QUIC_DLOG(INFO) << ENDPOINT << "Replacing connection ID " << connection_id_
+                    << " with " << header.source_connection_id;
+    connection_id_ = header.source_connection_id;
+    packet_generator_.SetConnectionId(connection_id_);
+  }
+
   if (!ValidateReceivedPacketNumber(header.packet_number)) {
     return false;
   }
diff --git a/quic/core/quic_connection.h b/quic/core/quic_connection.h
index a2202ee..715a449 100644
--- a/quic/core/quic_connection.h
+++ b/quic/core/quic_connection.h
@@ -30,6 +30,7 @@
 #include "net/third_party/quiche/src/quic/core/quic_alarm.h"
 #include "net/third_party/quiche/src/quic/core/quic_alarm_factory.h"
 #include "net/third_party/quiche/src/quic/core/quic_blocked_writer_interface.h"
+#include "net/third_party/quiche/src/quic/core/quic_connection_id.h"
 #include "net/third_party/quiche/src/quic/core/quic_connection_stats.h"
 #include "net/third_party/quiche/src/quic/core/quic_framer.h"
 #include "net/third_party/quiche/src/quic/core/quic_one_block_arena.h"
@@ -864,6 +865,10 @@
     return sent_packet_manager_.handshake_confirmed();
   }
 
+  // Adds the connection ID to a set of connection IDs that are accepted as
+  // destination on incoming packets.
+  void AddIncomingConnectionId(QuicConnectionId connection_id);
+
  protected:
   // Calls cancel() on all the alarms owned by this connection.
   void CancelAllAlarms();
@@ -1111,6 +1116,9 @@
   // Returns the largest sent packet number that has been ACKed by peer.
   QuicPacketNumber GetLargestAckedPacket() const;
 
+  // Whether incoming_connection_ids_ contains connection_id.
+  bool HasIncomingConnectionId(QuicConnectionId connection_id);
+
   QuicFramer framer_;
 
   // Contents received in the current packet, especially used to identify
@@ -1136,7 +1144,7 @@
   const QuicClock* clock_;
   QuicRandom* random_generator_;
 
-  const QuicConnectionId connection_id_;
+  QuicConnectionId connection_id_;
   // Address on the last successfully processed packet received from the
   // direct peer.
   QuicSocketAddress self_address_;
@@ -1463,6 +1471,11 @@
   // saved and responded to.
   QuicDeque<QuicPathFrameBuffer> received_path_challenge_payloads_;
 
+  // Set of connection IDs that should be accepted as destination on
+  // received packets. This is conceptually a set but is implemented as a
+  // vector to improve performance since it is expected to be very small.
+  std::vector<QuicConnectionId> incoming_connection_ids_;
+
   // Latched value of quic_fix_termination_packets.
   const bool fix_termination_packets_;
 
diff --git a/quic/core/quic_dispatcher.cc b/quic/core/quic_dispatcher.cc
index 5fddab6..7ef2869 100644
--- a/quic/core/quic_dispatcher.cc
+++ b/quic/core/quic_dispatcher.cc
@@ -329,6 +329,30 @@
   //            next packet does not use them incorrectly.
 }
 
+QuicConnectionId QuicDispatcher::MaybeReplaceConnectionId(
+    QuicConnectionId connection_id,
+    ParsedQuicVersion version) {
+  const uint8_t expected_connection_id_length =
+      framer_.GetExpectedConnectionIdLength();
+  if (connection_id.length() == expected_connection_id_length) {
+    return connection_id;
+  }
+  DCHECK(QuicUtils::VariableLengthConnectionIdAllowedForVersion(
+      version.transport_version));
+  auto it = connection_id_map_.find(connection_id);
+  if (it != connection_id_map_.end()) {
+    return it->second;
+  }
+  QuicConnectionId new_connection_id =
+      session_helper_->GenerateConnectionIdForReject(version.transport_version,
+                                                     connection_id);
+  DCHECK_EQ(expected_connection_id_length, new_connection_id.length());
+  connection_id_map_.insert(std::make_pair(connection_id, new_connection_id));
+  QUIC_DLOG(INFO) << "Replacing incoming connection ID " << connection_id
+                  << " with " << new_connection_id;
+  return new_connection_id;
+}
+
 bool QuicDispatcher::OnUnauthenticatedPublicHeader(
     const QuicPacketHeader& header) {
   current_connection_id_ = header.destination_connection_id;
@@ -345,28 +369,10 @@
   if (header.destination_connection_id_included != CONNECTION_ID_PRESENT) {
     return false;
   }
-
-  // We currently do not support having the server change its connection ID
-  // length during the handshake. Until then, fast-fail connections.
-  // TODO(dschinazi): actually support changing connection IDs from the server.
-  if (header.destination_connection_id.length() !=
-      framer_.GetExpectedConnectionIdLength()) {
-    DCHECK(QuicUtils::VariableLengthConnectionIdAllowedForVersion(
-        header.version.transport_version));
-    QUIC_DLOG(INFO)
-        << "Packet with unexpected connection ID lengths: destination "
-        << header.destination_connection_id << " source "
-        << header.source_connection_id << " expected "
-        << static_cast<int>(framer_.GetExpectedConnectionIdLength());
-    ProcessUnauthenticatedHeaderFate(kFateTimeWait,
-                                     header.destination_connection_id,
-                                     header.form, header.version);
-    return false;
-  }
+  QuicConnectionId connection_id = header.destination_connection_id;
 
   // Packets with connection IDs for active connections are processed
   // immediately.
-  QuicConnectionId connection_id = header.destination_connection_id;
   auto it = session_map_.find(connection_id);
   if (it != session_map_.end()) {
     DCHECK(!buffered_packets_.HasBufferedPackets(connection_id));
@@ -998,9 +1004,15 @@
     if (packets.empty()) {
       return;
     }
+    QuicConnectionId original_connection_id = connection_id;
+    connection_id =
+        MaybeReplaceConnectionId(connection_id, packet_list.version);
     QuicSession* session =
         CreateQuicSession(connection_id, packets.front().peer_address,
                           packet_list.alpn, packet_list.version);
+    if (original_connection_id != connection_id) {
+      session->connection()->AddIncomingConnectionId(original_connection_id);
+    }
     QUIC_DLOG(INFO) << "Created new session for " << connection_id;
     session_map_.insert(std::make_pair(connection_id, QuicWrapUnique(session)));
     DeliverPacketsToSession(packets, session);
@@ -1093,10 +1105,17 @@
     }
     return;
   }
+
+  QuicConnectionId original_connection_id = current_connection_id_;
+  current_connection_id_ =
+      MaybeReplaceConnectionId(current_connection_id_, framer_.version());
   // Creates a new session and process all buffered packets for this connection.
   QuicSession* session =
       CreateQuicSession(current_connection_id_, current_peer_address_,
                         current_alpn_, framer_.version());
+  if (original_connection_id != current_connection_id_) {
+    session->connection()->AddIncomingConnectionId(original_connection_id);
+  }
   QUIC_DLOG(INFO) << "Created new session for " << current_connection_id_;
   session_map_.insert(
       std::make_pair(current_connection_id_, QuicWrapUnique(session)));
diff --git a/quic/core/quic_dispatcher.h b/quic/core/quic_dispatcher.h
index 63bba64..cdd1196 100644
--- a/quic/core/quic_dispatcher.h
+++ b/quic/core/quic_dispatcher.h
@@ -110,6 +110,14 @@
   // Deletes all sessions on the closed session list and clears the list.
   virtual void DeleteSessions();
 
+  using ConnectionIdMap = QuicUnorderedMap<QuicConnectionId,
+                                           QuicConnectionId,
+                                           QuicConnectionIdHash>;
+
+  const ConnectionIdMap& connection_id_map() const {
+    return connection_id_map_;
+  }
+
   // The largest packet number we expect to receive with a connection
   // ID for a connection that is not established yet.  The current design will
   // send a handshake and then up to 50 or so data packets, and then it may
@@ -405,6 +413,12 @@
       QuicTransportVersion first_version,
       PacketHeaderFormat form);
 
+  // If the connection ID length is different from what the dispatcher expects,
+  // replace the connection ID with a random one of the right length,
+  // and save it to make sure the mapping is persistent.
+  QuicConnectionId MaybeReplaceConnectionId(QuicConnectionId connection_id,
+                                            ParsedQuicVersion version);
+
   void set_new_sessions_allowed_per_event_loop(
       int16_t new_sessions_allowed_per_event_loop) {
     new_sessions_allowed_per_event_loop_ = new_sessions_allowed_per_event_loop;
@@ -422,6 +436,9 @@
 
   SessionMap session_map_;
 
+  // Map of connection IDs with bad lengths to their replacements.
+  ConnectionIdMap connection_id_map_;
+
   // Entity that manages connection_ids in time wait state.
   std::unique_ptr<QuicTimeWaitListManager> time_wait_list_manager_;
 
diff --git a/quic/core/quic_dispatcher_test.cc b/quic/core/quic_dispatcher_test.cc
index 9e86875..3f4ef0e 100644
--- a/quic/core/quic_dispatcher_test.cc
+++ b/quic/core/quic_dispatcher_test.cc
@@ -116,14 +116,14 @@
  public:
   TestDispatcher(const QuicConfig* config,
                  const QuicCryptoServerConfig* crypto_config,
-                 QuicVersionManager* version_manager)
+                 QuicVersionManager* version_manager,
+                 QuicRandom* random)
       : QuicDispatcher(config,
                        crypto_config,
                        version_manager,
                        QuicMakeUnique<MockQuicConnectionHelper>(),
                        std::unique_ptr<QuicCryptoServerStream::Helper>(
-                           new QuicSimpleCryptoServerStreamHelper(
-                               QuicRandom::GetInstance())),
+                           new QuicSimpleCryptoServerStreamHelper(random)),
                        QuicMakeUnique<MockAlarmFactory>(),
                        kQuicDefaultConnectionIdLength) {}
 
@@ -208,9 +208,11 @@
                        KeyExchangeSource::Default(),
                        TlsServerHandshaker::CreateSslCtx()),
         server_address_(QuicIpAddress::Any4(), 5),
-        dispatcher_(new NiceMock<TestDispatcher>(&config_,
-                                                 &crypto_config_,
-                                                 &version_manager_)),
+        dispatcher_(
+            new NiceMock<TestDispatcher>(&config_,
+                                         &crypto_config_,
+                                         &version_manager_,
+                                         mock_helper_.GetRandomGenerator())),
         time_wait_list_manager_(nullptr),
         session1_(nullptr),
         session2_(nullptr),
@@ -636,39 +638,52 @@
   ProcessPacket(client_address, connection_id, false, SerializeCHLO());
 }
 
-// Makes sure nine-byte connection IDs end up in the time wait list.
-TEST_F(QuicDispatcherTest, BadConnectionIdLengthPacketToTimeWaitListManager) {
+// Makes sure nine-byte connection IDs are replaced by 8-byte ones.
+TEST_F(QuicDispatcherTest, BadConnectionIdLengthReplaced) {
   if (!QuicUtils::VariableLengthConnectionIdAllowedForVersion(
           CurrentSupportedVersions()[0].transport_version)) {
+    // When variable length connection IDs are not supported, the connection
+    // fails. See StrayPacketTruncatedConnectionId.
     return;
   }
-  CreateTimeWaitListManager();
-
   QuicSocketAddress client_address(QuicIpAddress::Loopback4(), 1);
 
-  // Dispatcher forwards all packets for this connection_id to the time wait
-  // list manager.
-  EXPECT_CALL(*dispatcher_, CreateQuicSession(_, _, QuicStringPiece("hq"), _))
-      .Times(0);
-  EXPECT_CALL(*time_wait_list_manager_,
-              ProcessPacket(_, _, TestConnectionIdNineBytesLong(2), _, _))
-      .Times(1);
-  EXPECT_CALL(*time_wait_list_manager_,
-              AddConnectionIdToTimeWait(_, _, _, _, _))
-      .Times(1);
-  ProcessPacket(client_address, TestConnectionIdNineBytesLong(2), true,
-                SerializeCHLO());
+  QuicConnectionId bad_connection_id = TestConnectionIdNineBytesLong(2);
+  QuicConnectionId fixed_connection_id =
+      QuicUtils::CreateRandomConnectionId(mock_helper_.GetRandomGenerator());
+
+  EXPECT_CALL(*dispatcher_,
+              CreateQuicSession(fixed_connection_id, client_address,
+                                QuicStringPiece("hq"), _))
+      .WillOnce(testing::Return(CreateSession(
+          dispatcher_.get(), config_, fixed_connection_id, client_address,
+          &mock_helper_, &mock_alarm_factory_, &crypto_config_,
+          QuicDispatcherPeer::GetCache(dispatcher_.get()), &session1_)));
+  EXPECT_CALL(*reinterpret_cast<MockQuicConnection*>(session1_->connection()),
+              ProcessUdpPacket(_, _, _))
+      .WillOnce(WithArg<2>(
+          Invoke([this, bad_connection_id](const QuicEncryptedPacket& packet) {
+            ValidatePacket(bad_connection_id, packet);
+          })));
+  EXPECT_CALL(*dispatcher_,
+              ShouldCreateOrBufferPacketForConnection(bad_connection_id, _));
+  ProcessPacket(client_address, bad_connection_id, true, SerializeCHLO());
+  EXPECT_EQ(client_address, dispatcher_->current_peer_address());
+  EXPECT_EQ(server_address_, dispatcher_->current_self_address());
 }
 
-// Makes sure TestConnectionId(1) creates a new connection but
-// TestConnectionIdNineBytesLong(2) ends up in the time wait list.
+// Makes sure TestConnectionId(1) creates a new connection and
+// TestConnectionIdNineBytesLong(2) gets replaced.
 TEST_F(QuicDispatcherTest, MixGoodAndBadConnectionIdLengthPackets) {
   if (!QuicUtils::VariableLengthConnectionIdAllowedForVersion(
           CurrentSupportedVersions()[0].transport_version)) {
     return;
   }
-  CreateTimeWaitListManager();
+
   QuicSocketAddress client_address(QuicIpAddress::Loopback4(), 1);
+  QuicConnectionId bad_connection_id = TestConnectionIdNineBytesLong(2);
+  QuicConnectionId fixed_connection_id =
+      QuicUtils::CreateRandomConnectionId(mock_helper_.GetRandomGenerator());
 
   EXPECT_CALL(*dispatcher_,
               CreateQuicSession(TestConnectionId(1), client_address,
@@ -688,18 +703,22 @@
   EXPECT_EQ(client_address, dispatcher_->current_peer_address());
   EXPECT_EQ(server_address_, dispatcher_->current_self_address());
 
-  EXPECT_CALL(*dispatcher_, CreateQuicSession(TestConnectionIdNineBytesLong(2),
-                                              _, QuicStringPiece("hq"), _))
-      .Times(0);
-  EXPECT_CALL(*time_wait_list_manager_,
-              ProcessPacket(_, _, TestConnectionIdNineBytesLong(2), _, _))
-      .Times(1);
-  EXPECT_CALL(*time_wait_list_manager_,
-              AddConnectionIdToTimeWait(_, _, _, _, _))
-      .Times(1);
-
-  ProcessPacket(client_address, TestConnectionIdNineBytesLong(2), true,
-                SerializeCHLO());
+  EXPECT_CALL(*dispatcher_,
+              CreateQuicSession(fixed_connection_id, client_address,
+                                QuicStringPiece("hq"), _))
+      .WillOnce(testing::Return(CreateSession(
+          dispatcher_.get(), config_, fixed_connection_id, client_address,
+          &mock_helper_, &mock_alarm_factory_, &crypto_config_,
+          QuicDispatcherPeer::GetCache(dispatcher_.get()), &session2_)));
+  EXPECT_CALL(*reinterpret_cast<MockQuicConnection*>(session2_->connection()),
+              ProcessUdpPacket(_, _, _))
+      .WillOnce(WithArg<2>(
+          Invoke([this, bad_connection_id](const QuicEncryptedPacket& packet) {
+            ValidatePacket(bad_connection_id, packet);
+          })));
+  EXPECT_CALL(*dispatcher_,
+              ShouldCreateOrBufferPacketForConnection(bad_connection_id, _));
+  ProcessPacket(client_address, bad_connection_id, true, SerializeCHLO());
 
   EXPECT_CALL(*reinterpret_cast<MockQuicConnection*>(session1_->connection()),
               ProcessUdpPacket(_, _, _))
@@ -1304,15 +1323,20 @@
 // Packets with truncated connection IDs should be dropped.
 TEST_F(QuicDispatcherTestStrayPacketConnectionId,
        StrayPacketTruncatedConnectionId) {
+  if (QuicUtils::VariableLengthConnectionIdAllowedForVersion(
+          CurrentSupportedVersions()[0].transport_version)) {
+    // When variable length connection IDs are supported, the server
+    // transparently replaces empty connection IDs with one it chooses.
+    // See BadConnectionIdLengthReplaced.
+    return;
+  }
   CreateTimeWaitListManager();
 
   QuicSocketAddress client_address(QuicIpAddress::Loopback4(), 1);
   QuicConnectionId connection_id = TestConnectionId(1);
   EXPECT_CALL(*dispatcher_, CreateQuicSession(_, _, QuicStringPiece("hq"), _))
       .Times(0);
-  if (CurrentSupportedVersions()[0].transport_version > QUIC_VERSION_43 &&
-      !QuicUtils::VariableLengthConnectionIdAllowedForVersion(
-          CurrentSupportedVersions()[0].transport_version)) {
+  if (CurrentSupportedVersions()[0].transport_version > QUIC_VERSION_43) {
     // This IETF packet has invalid connection ID length.
     EXPECT_CALL(*time_wait_list_manager_, ProcessPacket(_, _, _, _, _))
         .Times(0);
diff --git a/quic/core/quic_packet_creator.cc b/quic/core/quic_packet_creator.cc
index 9bd3b11..d9b481c 100644
--- a/quic/core/quic_packet_creator.cc
+++ b/quic/core/quic_packet_creator.cc
@@ -970,6 +970,10 @@
   connection_id_included_ = connection_id_included;
 }
 
+void QuicPacketCreator::SetConnectionId(QuicConnectionId connection_id) {
+  connection_id_ = connection_id;
+}
+
 void QuicPacketCreator::SetTransmissionType(TransmissionType type) {
   DCHECK(can_set_transmission_type_);
 
diff --git a/quic/core/quic_packet_creator.h b/quic/core/quic_packet_creator.h
index 76e7726..150cb9c 100644
--- a/quic/core/quic_packet_creator.h
+++ b/quic/core/quic_packet_creator.h
@@ -215,6 +215,9 @@
   // Sets whether the connection ID should be sent over the wire.
   void SetConnectionIdIncluded(QuicConnectionIdIncluded connection_id_included);
 
+  // Update the connection ID used in outgoing packets.
+  void SetConnectionId(QuicConnectionId connection_id);
+
   // Sets the encryption level that will be applied to new packets.
   void set_encryption_level(EncryptionLevel level) {
     packet_.encryption_level = level;
diff --git a/quic/core/quic_packet_generator.cc b/quic/core/quic_packet_generator.cc
index fd18064..c2b3d48 100644
--- a/quic/core/quic_packet_generator.cc
+++ b/quic/core/quic_packet_generator.cc
@@ -535,4 +535,8 @@
   return packet_creator_.GetLargestMessagePayload();
 }
 
+void QuicPacketGenerator::SetConnectionId(QuicConnectionId connection_id) {
+  packet_creator_.SetConnectionId(connection_id);
+}
+
 }  // namespace quic
diff --git a/quic/core/quic_packet_generator.h b/quic/core/quic_packet_generator.h
index 32e95c3..4c780f5 100644
--- a/quic/core/quic_packet_generator.h
+++ b/quic/core/quic_packet_generator.h
@@ -229,6 +229,9 @@
   // Returns the largest payload that will fit into a single MESSAGE frame.
   QuicPacketLength GetLargestMessagePayload() const;
 
+  // Update the connection ID used in outgoing packets.
+  void SetConnectionId(QuicConnectionId connection_id);
+
   void set_debug_delegate(QuicPacketCreator::DebugDelegate* debug_delegate) {
     packet_creator_.set_debug_delegate(debug_delegate);
   }
diff --git a/quic/core/quic_utils.cc b/quic/core/quic_utils.cc
index 6aaa67b..d03cc0e 100644
--- a/quic/core/quic_utils.cc
+++ b/quic/core/quic_utils.cc
@@ -8,6 +8,7 @@
 #include <cstdint>
 #include <string>
 
+#include "net/third_party/quiche/src/quic/core/quic_connection_id.h"
 #include "net/third_party/quiche/src/quic/core/quic_constants.h"
 #include "net/third_party/quiche/src/quic/core/quic_types.h"
 #include "net/third_party/quiche/src/quic/platform/api/quic_aligned.h"
@@ -456,15 +457,37 @@
 
 // static
 QuicConnectionId QuicUtils::CreateRandomConnectionId() {
-  return CreateRandomConnectionId(QuicRandom::GetInstance());
+  return CreateRandomConnectionId(kQuicDefaultConnectionIdLength,
+                                  QuicRandom::GetInstance());
 }
 
 // static
 QuicConnectionId QuicUtils::CreateRandomConnectionId(QuicRandom* random) {
-  char connection_id_bytes[kQuicDefaultConnectionIdLength];
-  random->RandBytes(connection_id_bytes, QUIC_ARRAYSIZE(connection_id_bytes));
+  return CreateRandomConnectionId(kQuicDefaultConnectionIdLength, random);
+}
+// static
+QuicConnectionId QuicUtils::CreateRandomConnectionId(
+    uint8_t connection_id_length) {
+  return CreateRandomConnectionId(connection_id_length,
+                                  QuicRandom::GetInstance());
+}
+
+// static
+QuicConnectionId QuicUtils::CreateRandomConnectionId(
+    uint8_t connection_id_length,
+    QuicRandom* random) {
+  if (connection_id_length == 0) {
+    return EmptyQuicConnectionId();
+  }
+  if (connection_id_length > kQuicMaxConnectionIdLength) {
+    QUIC_BUG << "Tried to CreateRandomConnectionId of invalid length "
+             << static_cast<int>(connection_id_length);
+    connection_id_length = kQuicMaxConnectionIdLength;
+  }
+  char connection_id_bytes[kQuicMaxConnectionIdLength];
+  random->RandBytes(connection_id_bytes, connection_id_length);
   return QuicConnectionId(static_cast<char*>(connection_id_bytes),
-                          QUIC_ARRAYSIZE(connection_id_bytes));
+                          connection_id_length);
 }
 
 // static
diff --git a/quic/core/quic_utils.h b/quic/core/quic_utils.h
index 9a052bd..709ea64 100644
--- a/quic/core/quic_utils.h
+++ b/quic/core/quic_utils.h
@@ -154,6 +154,15 @@
   // Generates a random 64bit connection ID using the provided QuicRandom.
   static QuicConnectionId CreateRandomConnectionId(QuicRandom* random);
 
+  // Generates a random connection ID of the given length.
+  static QuicConnectionId CreateRandomConnectionId(
+      uint8_t connection_id_length);
+
+  // Generates a random connection ID of the given length using the provided
+  // QuicRandom.
+  static QuicConnectionId CreateRandomConnectionId(uint8_t connection_id_length,
+                                                   QuicRandom* random);
+
   // Returns true if the QUIC version allows variable length connection IDs.
   static bool VariableLengthConnectionIdAllowedForVersion(
       QuicTransportVersion version);
diff --git a/quic/core/quic_utils_test.cc b/quic/core/quic_utils_test.cc
index 6ee6a2a..d0ce7e8 100644
--- a/quic/core/quic_utils_test.cc
+++ b/quic/core/quic_utils_test.cc
@@ -177,6 +177,28 @@
   EXPECT_NE(connection_id, EmptyQuicConnectionId());
   EXPECT_NE(connection_id, TestConnectionId());
   EXPECT_NE(connection_id, TestConnectionId(1));
+  EXPECT_NE(connection_id, TestConnectionIdNineBytesLong(1));
+  EXPECT_EQ(QuicUtils::CreateRandomConnectionId().length(),
+            kQuicDefaultConnectionIdLength);
+}
+
+TEST_F(QuicUtilsTest, RandomConnectionIdVariableLength) {
+  MockRandom random(1337);
+  const uint8_t connection_id_length = 9;
+  QuicConnectionId connection_id =
+      QuicUtils::CreateRandomConnectionId(connection_id_length, &random);
+  EXPECT_EQ(connection_id.length(), connection_id_length);
+  char connection_id_bytes[connection_id_length];
+  random.RandBytes(connection_id_bytes, QUIC_ARRAYSIZE(connection_id_bytes));
+  EXPECT_EQ(connection_id,
+            QuicConnectionId(static_cast<char*>(connection_id_bytes),
+                             QUIC_ARRAYSIZE(connection_id_bytes)));
+  EXPECT_NE(connection_id, EmptyQuicConnectionId());
+  EXPECT_NE(connection_id, TestConnectionId());
+  EXPECT_NE(connection_id, TestConnectionId(1));
+  EXPECT_NE(connection_id, TestConnectionIdNineBytesLong(1));
+  EXPECT_EQ(QuicUtils::CreateRandomConnectionId(connection_id_length).length(),
+            connection_id_length);
 }
 
 TEST_F(QuicUtilsTest, VariableLengthConnectionId) {