Fix replacing connection IDs when initial crypters are in use

When the client uses an initial destination connection ID length different than 8 bytes, we replace that connection ID with a random 8-byte one. However, the initial crypters are still based on that original connection ID. This change makes sure that we initialize the TLS crypters with the right connection ID. This also removes the code that would set the crypters in tls_server_handshaker because that is no longer required now that we've completely removed in-connection version negotiation.

This change also makes sure we correctly check the QuicDispatcher::session_map_ for the replaced connection ID which is required to properly route subsequent long header packets after the first one.

This change is safe without flag protection because it only impacts versions that allow connection IDs of length different than 8, and all those versions are disabled by flags.

gfe-relnote: fix connection ID replacement, protected by quic_enable_v47/48/99
PiperOrigin-RevId: 257697143
Change-Id: Ifc3779e292104656abd72fae79fba8b7604cabe2
diff --git a/quic/core/http/end_to_end_test.cc b/quic/core/http/end_to_end_test.cc
index e05852b..c58d82f 100644
--- a/quic/core/http/end_to_end_test.cc
+++ b/quic/core/http/end_to_end_test.cc
@@ -646,7 +646,21 @@
                 GetParam().negotiated_version.transport_version));
 }
 
-TEST_P(EndToEndTest, BadConnectionIdLength) {
+TEST_P(EndToEndTestWithTls, ZeroConnectionID) {
+  QuicConnectionId connection_id = QuicUtils::CreateZeroConnectionId(
+      GetParam().negotiated_version.transport_version);
+  override_server_connection_id_ = &connection_id;
+  expected_server_connection_id_length_ = connection_id.length();
+  ASSERT_TRUE(Initialize());
+
+  EXPECT_EQ(kFooResponseBody, client_->SendSynchronousRequest("/foo"));
+  EXPECT_EQ("200", client_->response_headers()->find(":status")->second);
+  EXPECT_EQ(client_->client()->client_session()->connection()->connection_id(),
+            QuicUtils::CreateZeroConnectionId(
+                GetParam().negotiated_version.transport_version));
+}
+
+TEST_P(EndToEndTestWithTls, BadConnectionIdLength) {
   if (!QuicUtils::VariableLengthConnectionIdAllowedForVersion(
           GetParam().negotiated_version.transport_version)) {
     ASSERT_TRUE(Initialize());
@@ -665,7 +679,31 @@
                                                 .length());
 }
 
-TEST_P(EndToEndTest, ClientConnectionId) {
+// Tests a very long (16-byte) initial destination connection ID to make
+// sure the dispatcher properly replaces it with an 8-byte one.
+TEST_P(EndToEndTestWithTls, LongBadConnectionIdLength) {
+  if (!QuicUtils::VariableLengthConnectionIdAllowedForVersion(
+          GetParam().negotiated_version.transport_version)) {
+    ASSERT_TRUE(Initialize());
+    return;
+  }
+  const char connection_id_bytes[16] = {0xb0, 0xb1, 0xb2, 0xb3, 0xb4, 0xb5,
+                                        0xb6, 0xb7, 0xb8, 0xb9, 0xba, 0xbb,
+                                        0xbc, 0xbd, 0xbe, 0xbf};
+  QuicConnectionId connection_id =
+      QuicConnectionId(connection_id_bytes, sizeof(connection_id_bytes));
+  override_server_connection_id_ = &connection_id;
+  ASSERT_TRUE(Initialize());
+  EXPECT_EQ(kFooResponseBody, client_->SendSynchronousRequest("/foo"));
+  EXPECT_EQ("200", client_->response_headers()->find(":status")->second);
+  EXPECT_EQ(kQuicDefaultConnectionIdLength, client_->client()
+                                                ->client_session()
+                                                ->connection()
+                                                ->connection_id()
+                                                .length());
+}
+
+TEST_P(EndToEndTestWithTls, ClientConnectionId) {
   if (!GetParam().negotiated_version.SupportsClientConnectionIds()) {
     ASSERT_TRUE(Initialize());
     return;
@@ -682,7 +720,7 @@
                                       ->client_connection_id());
 }
 
-TEST_P(EndToEndTest, ForcedVersionNegotiationAndClientConnectionId) {
+TEST_P(EndToEndTestWithTls, ForcedVersionNegotiationAndClientConnectionId) {
   if (!GetParam().negotiated_version.SupportsClientConnectionIds()) {
     ASSERT_TRUE(Initialize());
     return;
@@ -702,7 +740,7 @@
                                       ->client_connection_id());
 }
 
-TEST_P(EndToEndTest, ForcedVersionNegotiationAndBadConnectionIdLength) {
+TEST_P(EndToEndTestWithTls, ForcedVersionNegotiationAndBadConnectionIdLength) {
   if (!QuicUtils::VariableLengthConnectionIdAllowedForVersion(
           GetParam().negotiated_version.transport_version)) {
     ASSERT_TRUE(Initialize());
@@ -724,6 +762,44 @@
                                                 .length());
 }
 
+// Forced Version Negotiation with a client connection ID and a long
+// connection ID.
+TEST_P(EndToEndTestWithTls, ForcedVersNegoAndClientCIDAndLongCID) {
+  if (!GetParam().negotiated_version.SupportsClientConnectionIds() ||
+      !QuicUtils::VariableLengthConnectionIdAllowedForVersion(
+          GetParam().negotiated_version.transport_version)) {
+    ASSERT_TRUE(Initialize());
+    return;
+  }
+  client_supported_versions_.insert(client_supported_versions_.begin(),
+                                    QuicVersionReservedForNegotiation());
+  const char connection_id_bytes[16] = {0xb0, 0xb1, 0xb2, 0xb3, 0xb4, 0xb5,
+                                        0xb6, 0xb7, 0xb8, 0xb9, 0xba, 0xbb,
+                                        0xbc, 0xbd, 0xbe, 0xbf};
+  QuicConnectionId connection_id =
+      QuicConnectionId(connection_id_bytes, sizeof(connection_id_bytes));
+  override_server_connection_id_ = &connection_id;
+  const char client_connection_id_bytes[18] = {
+      0xc0, 0xc1, 0xc2, 0xc3, 0xc4, 0xc5, 0xc6, 0xc7, 0xc8,
+      0xc9, 0xca, 0xcb, 0xcc, 0xcd, 0xce, 0xcf, 0xc0, 0xc1};
+  QuicConnectionId client_connection_id = QuicConnectionId(
+      client_connection_id_bytes, sizeof(client_connection_id_bytes));
+  override_client_connection_id_ = &client_connection_id;
+  ASSERT_TRUE(Initialize());
+  ASSERT_TRUE(ServerSendsVersionNegotiation());
+  EXPECT_EQ(kFooResponseBody, client_->SendSynchronousRequest("/foo"));
+  EXPECT_EQ("200", client_->response_headers()->find(":status")->second);
+  EXPECT_EQ(kQuicDefaultConnectionIdLength, client_->client()
+                                                ->client_session()
+                                                ->connection()
+                                                ->connection_id()
+                                                .length());
+  EXPECT_EQ(client_connection_id, client_->client()
+                                      ->client_session()
+                                      ->connection()
+                                      ->client_connection_id());
+}
+
 TEST_P(EndToEndTest, MixGoodAndBadConnectionIdLengths) {
   if (!QuicUtils::VariableLengthConnectionIdAllowedForVersion(
           GetParam().negotiated_version.transport_version)) {
diff --git a/quic/core/quic_connection.cc b/quic/core/quic_connection.cc
index 099217a..b693db7 100644
--- a/quic/core/quic_connection.cc
+++ b/quic/core/quic_connection.cc
@@ -351,17 +351,17 @@
   MaybeEnableMultiplePacketNumberSpacesSupport();
   DCHECK(perspective_ == Perspective::IS_CLIENT ||
          supported_versions.size() == 1);
-  InstallInitialCrypters();
+  InstallInitialCrypters(server_connection_id_);
 }
 
-void QuicConnection::InstallInitialCrypters() {
+void QuicConnection::InstallInitialCrypters(QuicConnectionId connection_id) {
   if (version().handshake_protocol != PROTOCOL_TLS1_3) {
     // Initial crypters are currently only supported with TLS.
     return;
   }
   CrypterPair crypters;
   CryptoUtils::CreateTlsInitialCrypters(perspective_, transport_version(),
-                                        server_connection_id_, &crypters);
+                                        connection_id, &crypters);
   SetEncrypter(ENCRYPTION_INITIAL, std::move(crypters.encrypter));
   InstallDecrypter(ENCRYPTION_INITIAL, std::move(crypters.decrypter));
 }
@@ -618,7 +618,7 @@
   packet_generator_.SetRetryToken(retry_token);
 
   // Reinstall initial crypters because the connection ID changed.
-  InstallInitialCrypters();
+  InstallInitialCrypters(server_connection_id_);
 }
 
 bool QuicConnection::HasIncomingConnectionId(QuicConnectionId connection_id) {
diff --git a/quic/core/quic_connection.h b/quic/core/quic_connection.h
index 503581b..5fdffee 100644
--- a/quic/core/quic_connection.h
+++ b/quic/core/quic_connection.h
@@ -868,6 +868,11 @@
   // For logging purpose.
   const QuicAckFrame& ack_frame() const;
 
+  // Install encrypter and decrypter for ENCRYPTION_INITIAL using
+  // |connection_id| as the first client-sent destination connection ID,
+  // or the one sent after an IETF Retry.
+  void InstallInitialCrypters(QuicConnectionId connection_id);
+
  protected:
   // Calls cancel() on all the alarms owned by this connection.
   void CancelAllAlarms();
@@ -1115,9 +1120,6 @@
   // Whether incoming_connection_ids_ contains connection_id.
   bool HasIncomingConnectionId(QuicConnectionId connection_id);
 
-  // Install encrypter and decrypter for ENCRYPTION_INITIAL.
-  void InstallInitialCrypters();
-
   QuicFramer framer_;
 
   // Contents received in the current packet, especially used to identify
diff --git a/quic/core/quic_dispatcher.cc b/quic/core/quic_dispatcher.cc
index f7973e0..0a83ca8 100644
--- a/quic/core/quic_dispatcher.cc
+++ b/quic/core/quic_dispatcher.cc
@@ -276,6 +276,9 @@
       session_helper_->GenerateConnectionIdForReject(version.transport_version,
                                                      server_connection_id);
   DCHECK_EQ(expected_server_connection_id_length_, new_connection_id.length());
+  // TODO(dschinazi) Prevent connection_id_map_ from growing indefinitely
+  // before we ship a version that supports variable length connection IDs
+  // to production.
   connection_id_map_.insert(
       std::make_pair(server_connection_id, new_connection_id));
   QUIC_DLOG(INFO) << "Replacing incoming connection ID " << server_connection_id
@@ -340,6 +343,21 @@
     it->second->ProcessUdpPacket(packet_info.self_address,
                                  packet_info.peer_address, packet_info.packet);
     return true;
+  } else {
+    // We did not find the connection ID, check if we've replaced it.
+    QuicConnectionId replaced_connection_id = MaybeReplaceServerConnectionId(
+        server_connection_id, packet_info.version);
+    if (replaced_connection_id != server_connection_id) {
+      // Search for the replacement.
+      auto it2 = session_map_.find(replaced_connection_id);
+      if (it2 != session_map_.end()) {
+        DCHECK(!buffered_packets_.HasBufferedPackets(replaced_connection_id));
+        it2->second->ProcessUdpPacket(packet_info.self_address,
+                                      packet_info.peer_address,
+                                      packet_info.packet);
+        return true;
+      }
+    }
   }
 
   if (buffered_packets_.HasChloForConnection(server_connection_id)) {
@@ -723,8 +741,13 @@
                           packet_list.alpn, packet_list.version);
     if (original_connection_id != server_connection_id) {
       session->connection()->AddIncomingConnectionId(original_connection_id);
+      session->connection()->InstallInitialCrypters(original_connection_id);
     }
     QUIC_DLOG(INFO) << "Created new session for " << server_connection_id;
+
+    DCHECK(session_map_.find(server_connection_id) == session_map_.end())
+        << "Tried to add session map existing entry " << server_connection_id;
+
     session_map_.insert(
         std::make_pair(server_connection_id, QuicWrapUnique(session)));
     DeliverPacketsToSession(packets, session);
@@ -825,9 +848,16 @@
                         packet_info->peer_address, alpn, packet_info->version);
   if (original_connection_id != packet_info->destination_connection_id) {
     session->connection()->AddIncomingConnectionId(original_connection_id);
+    session->connection()->InstallInitialCrypters(original_connection_id);
   }
   QUIC_DLOG(INFO) << "Created new session for "
                   << packet_info->destination_connection_id;
+
+  DCHECK(session_map_.find(packet_info->destination_connection_id) ==
+         session_map_.end())
+      << "Tried to add session map existing entry "
+      << packet_info->destination_connection_id;
+
   session_map_.insert(std::make_pair(packet_info->destination_connection_id,
                                      QuicWrapUnique(session)));
   std::list<BufferedPacket> packets =
diff --git a/quic/core/tls_client_handshaker.cc b/quic/core/tls_client_handshaker.cc
index 20d9ff6..d6fc038 100644
--- a/quic/core/tls_client_handshaker.cc
+++ b/quic/core/tls_client_handshaker.cc
@@ -68,14 +68,6 @@
 }
 
 bool TlsClientHandshaker::CryptoConnect() {
-  CrypterPair crypters;
-  CryptoUtils::CreateTlsInitialCrypters(
-      Perspective::IS_CLIENT, session()->connection()->transport_version(),
-      session()->connection_id(), &crypters);
-  session()->connection()->SetEncrypter(ENCRYPTION_INITIAL,
-                                        std::move(crypters.encrypter));
-  session()->connection()->InstallDecrypter(ENCRYPTION_INITIAL,
-                                            std::move(crypters.decrypter));
   state_ = STATE_HANDSHAKE_RUNNING;
 
   // Configure the SSL to be a client.
diff --git a/quic/core/tls_server_handshaker.cc b/quic/core/tls_server_handshaker.cc
index 02a6120..8172104 100644
--- a/quic/core/tls_server_handshaker.cc
+++ b/quic/core/tls_server_handshaker.cc
@@ -55,14 +55,6 @@
       tls_connection_(ssl_ctx, this) {
   DCHECK_EQ(PROTOCOL_TLS1_3,
             session->connection()->version().handshake_protocol);
-  CrypterPair crypters;
-  CryptoUtils::CreateTlsInitialCrypters(
-      Perspective::IS_SERVER, session->connection()->transport_version(),
-      session->connection_id(), &crypters);
-  session->connection()->SetEncrypter(ENCRYPTION_INITIAL,
-                                      std::move(crypters.encrypter));
-  session->connection()->InstallDecrypter(ENCRYPTION_INITIAL,
-                                          std::move(crypters.decrypter));
 
   // Configure the SSL to be a server.
   SSL_set_accept_state(ssl());