Store original server connection IDs in the dispatcher's session map. Release them when the 0-RTT key discard timer fires. Include them in the list of active connection IDs to make sure they are destroyed if the timer never fires.

Protected by quic_restart_flag_quic_map_original_connection_ids.

PiperOrigin-RevId: 460950140
diff --git a/quiche/quic/core/http/end_to_end_test.cc b/quiche/quic/core/http/end_to_end_test.cc
index 8b7add5..270ab51 100644
--- a/quiche/quic/core/http/end_to_end_test.cc
+++ b/quiche/quic/core/http/end_to_end_test.cc
@@ -75,6 +75,7 @@
 #include "quiche/quic/tools/quic_server.h"
 #include "quiche/quic/tools/quic_simple_client_stream.h"
 #include "quiche/quic/tools/quic_simple_server_stream.h"
+#include "quiche/common/platform/api/quiche_test.h"
 
 using spdy::kV3LowestPriority;
 using spdy::SpdyFramer;
@@ -7079,6 +7080,45 @@
   CheckResponseHeaders("400");
 }
 
+TEST_P(EndToEndTest, OriginalConnectionIdClearedFromMap) {
+  SetQuicRestartFlag(quic_map_original_connection_ids, true);
+  connect_to_server_on_initialize_ = false;
+  ASSERT_TRUE(Initialize());
+  if (override_client_connection_id_length_ != kLongConnectionIdLength) {
+    // There might not be an original connection ID.
+    CreateClientWithWriter();
+    return;
+  }
+
+  server_thread_->Pause();
+  QuicDispatcher* dispatcher =
+      QuicServerPeer::GetDispatcher(server_thread_->server());
+  EXPECT_EQ(QuicDispatcherPeer::GetFirstSessionIfAny(dispatcher), nullptr);
+  server_thread_->Resume();
+
+  CreateClientWithWriter();  // Also connects.
+  EXPECT_NE(client_, nullptr);
+
+  server_thread_->Pause();
+  EXPECT_NE(QuicDispatcherPeer::GetFirstSessionIfAny(dispatcher), nullptr);
+  EXPECT_EQ(dispatcher->NumSessions(), 1);
+  auto ids = GetServerConnection()->GetActiveServerConnectionIds();
+  ASSERT_EQ(ids.size(), 2);
+  for (QuicConnectionId id : ids) {
+    EXPECT_NE(QuicDispatcherPeer::FindSession(dispatcher, id), nullptr);
+  }
+  QuicConnectionId original = ids[1];
+  server_thread_->Resume();
+
+  client_->SendSynchronousRequest("/foo");
+  client_->Disconnect();
+
+  server_thread_->Pause();
+  EXPECT_EQ(QuicDispatcherPeer::GetFirstSessionIfAny(dispatcher), nullptr);
+  EXPECT_EQ(QuicDispatcherPeer::FindSession(dispatcher, original), nullptr);
+  server_thread_->Resume();
+}
+
 }  // namespace
 }  // namespace test
 }  // namespace quic
diff --git a/quiche/quic/core/quic_connection.cc b/quiche/quic/core/quic_connection.cc
index aeca5a1..80315b3 100644
--- a/quiche/quic/core/quic_connection.cc
+++ b/quiche/quic/core/quic_connection.cc
@@ -176,6 +176,9 @@
     QUICHE_DCHECK(connection_->connected());
     QUIC_DLOG(INFO) << "0-RTT discard alarm fired";
     connection_->RemoveDecrypter(ENCRYPTION_ZERO_RTT);
+    if (GetQuicRestartFlag(quic_map_original_connection_ids)) {
+      connection_->RetireOriginalDestinationConnectionId();
+    }
   }
 };
 
@@ -968,13 +971,21 @@
       default_path_.server_connection_id;
 }
 
-QuicConnectionId QuicConnection::GetOriginalDestinationConnectionId() {
+QuicConnectionId QuicConnection::GetOriginalDestinationConnectionId() const {
   if (original_destination_connection_id_.has_value()) {
     return original_destination_connection_id_.value();
   }
   return default_path_.server_connection_id;
 }
 
+void QuicConnection::RetireOriginalDestinationConnectionId() {
+  if (original_destination_connection_id_.has_value()) {
+    visitor_->OnServerConnectionIdRetired(*original_destination_connection_id_);
+    QUIC_RESTART_FLAG_COUNT_N(quic_map_original_connection_ids, 3, 4);
+    original_destination_connection_id_.reset();
+  }
+}
+
 bool QuicConnection::ValidateServerConnectionId(
     const QuicPacketHeader& header) const {
   if (perspective_ == Perspective::IS_CLIENT &&
@@ -6909,23 +6920,32 @@
           quic_consider_original_connection_id_as_active_pre_handshake)) {
     QUIC_RELOADABLE_FLAG_COUNT(
         quic_consider_original_connection_id_as_active_pre_handshake);
-    if (!IsHandshakeComplete() &&
-        original_destination_connection_id_.has_value()) {
-      // Consider original_destination_connection_id_ as active before handshake
-      // completes.
-      if (std::find(result.begin(), result.end(),
-                    original_destination_connection_id_.value()) !=
-          result.end()) {
-        QUIC_BUG(quic_unexpected_original_destination_connection_id)
-            << "original_destination_connection_id: "
-            << original_destination_connection_id_.value()
-            << " is unexpectedly in active "
-               "list";
-      } else {
-        result.insert(result.end(),
-                      original_destination_connection_id_.value());
-      }
-      QUIC_CODE_COUNT(quic_active_original_connection_id_pre_handshake);
+  }
+  if (!original_destination_connection_id_.has_value()) {
+    return result;
+  }
+  bool add_original_connection_id = false;
+  if (GetQuicRestartFlag(quic_map_original_connection_ids)) {
+    QUIC_RESTART_FLAG_COUNT_N(quic_map_original_connection_ids, 4, 4);
+    add_original_connection_id = true;
+  } else if (
+      !IsHandshakeComplete() &&
+      GetQuicReloadableFlag(
+          quic_consider_original_connection_id_as_active_pre_handshake)) {
+    QUIC_CODE_COUNT(quic_active_original_connection_id_pre_handshake);
+    add_original_connection_id = true;
+  }
+  if (add_original_connection_id) {
+    if (std::find(result.begin(), result.end(),
+                  original_destination_connection_id_.value()) !=
+        result.end()) {
+      QUIC_BUG(quic_unexpected_original_destination_connection_id)
+          << "original_destination_connection_id: "
+          << original_destination_connection_id_.value()
+          << " is unexpectedly in active "
+             "list";
+    } else {
+      result.insert(result.end(), original_destination_connection_id_.value());
     }
   }
   return result;
diff --git a/quiche/quic/core/quic_connection.h b/quiche/quic/core/quic_connection.h
index e2c8f49..cd187e3 100644
--- a/quiche/quic/core/quic_connection.h
+++ b/quiche/quic/core/quic_connection.h
@@ -1081,7 +1081,11 @@
       const QuicConnectionId& original_destination_connection_id);
 
   // Returns the original destination connection ID used for this connection.
-  QuicConnectionId GetOriginalDestinationConnectionId();
+  QuicConnectionId GetOriginalDestinationConnectionId() const;
+
+  // Tells the visitor the serverside connection is no longer expecting packets
+  // with the client-generated destination connection ID.
+  void RetireOriginalDestinationConnectionId();
 
   // Called when ACK alarm goes off. Sends ACKs of those packet number spaces
   // which have expired ACK timeout. Only used when this connection supports
diff --git a/quiche/quic/core/quic_connection_test.cc b/quiche/quic/core/quic_connection_test.cc
index 13cf40f..b07eb49 100644
--- a/quiche/quic/core/quic_connection_test.cc
+++ b/quiche/quic/core/quic_connection_test.cc
@@ -15587,6 +15587,29 @@
   }
 }
 
+TEST_P(QuicConnectionTest, OriginalConnectionId) {
+  set_perspective(Perspective::IS_SERVER);
+  EXPECT_FALSE(connection_.GetDiscardZeroRttDecryptionKeysAlarm()->IsSet());
+  EXPECT_EQ(connection_.GetOriginalDestinationConnectionId(),
+            connection_.connection_id());
+  QuicConnectionId original({0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08});
+  connection_.SetOriginalDestinationConnectionId(original);
+  EXPECT_EQ(original, connection_.GetOriginalDestinationConnectionId());
+  // Send a 1-RTT packet to start the DiscardZeroRttDecryptionKeys timer.
+  EXPECT_CALL(visitor_, OnStreamFrame(_)).Times(1);
+  ProcessDataPacketAtLevel(1, false, ENCRYPTION_FORWARD_SECURE);
+  if (GetQuicRestartFlag(quic_map_original_connection_ids) &&
+      connection_.version().UsesTls()) {
+    EXPECT_TRUE(connection_.GetDiscardZeroRttDecryptionKeysAlarm()->IsSet());
+    EXPECT_CALL(visitor_, OnServerConnectionIdRetired(original));
+    connection_.GetDiscardZeroRttDecryptionKeysAlarm()->Fire();
+    EXPECT_EQ(connection_.GetOriginalDestinationConnectionId(),
+              connection_.connection_id());
+  } else {
+    EXPECT_EQ(connection_.GetOriginalDestinationConnectionId(), original);
+  }
+}
+
 }  // namespace
 }  // namespace test
 }  // namespace quic
diff --git a/quiche/quic/core/quic_dispatcher.cc b/quiche/quic/core/quic_dispatcher.cc
index 925ef14..e19d598 100644
--- a/quiche/quic/core/quic_dispatcher.cc
+++ b/quiche/quic/core/quic_dispatcher.cc
@@ -614,13 +614,18 @@
                                  packet_info.peer_address, packet_info.packet);
     return true;
   }
-  if (packet_info.version.IsKnown()) {
+  if (packet_info.version.IsKnown() &&
+      !GetQuicRestartFlag(quic_map_original_connection_ids)) {
     // We did not find the connection ID, check if we've replaced it.
     // This is only performed for supported versions because packets with
     // unsupported versions can flow through this function in order to send
     // a version negotiation packet, but we know that their connection ID
     // did not get replaced since that is performed on connection creation,
     // and that only happens for known verions.
+    // There is no need to perform this check if
+    // |reference_counted_session_map_| is storing original connection IDs
+    // separately. It can be counterproductive to do this check if that
+    // consumes a nonce or generates a random connection ID.
     QuicConnectionId replaced_connection_id = MaybeReplaceServerConnectionId(
         server_connection_id, packet_info.version);
     if (replaced_connection_id != server_connection_id) {
@@ -1277,6 +1282,8 @@
     server_connection_id = MaybeReplaceServerConnectionId(server_connection_id,
                                                           packet_list.version);
     std::string alpn = SelectAlpn(parsed_chlo.alpns);
+    // TODO(martinduke): Consider changing CreateQuicSession to return a
+    // shared_ptr<QuicSession>.
     std::unique_ptr<QuicSession> session = CreateQuicSession(
         server_connection_id, packets.front().self_address,
         packets.front().peer_address, alpn, packet_list.version, parsed_chlo);
@@ -1289,6 +1296,7 @@
     auto insertion_result = reference_counted_session_map_.insert(
         std::make_pair(server_connection_id,
                        std::shared_ptr<QuicSession>(std::move(session))));
+    std::shared_ptr<QuicSession> session_ptr = insertion_result.first->second;
     if (!insertion_result.second) {
       QUIC_BUG(quic_bug_12724_5)
           << "Tried to add a session to session_map with existing connection "
@@ -1296,8 +1304,20 @@
           << server_connection_id;
     } else {
       ++num_sessions_in_session_map_;
+      if (GetQuicRestartFlag(quic_map_original_connection_ids) &&
+          original_connection_id != server_connection_id) {
+        QUIC_RESTART_FLAG_COUNT_N(quic_map_original_connection_ids, 1, 4);
+        auto insertion_result2 = reference_counted_session_map_.insert(
+            std::make_pair(original_connection_id, session_ptr));
+        QUIC_BUG_IF(quic_460317833_01, !insertion_result2.second)
+            << "Original connection ID already in session_map: "
+            << original_connection_id;
+        // If insertion of the original connection ID fails, it might cause
+        // loss of 0-RTT and other first flight packets, but the connection
+        // will usually progress.
+      }
     }
-    DeliverPacketsToSession(packets, insertion_result.first->second.get());
+    DeliverPacketsToSession(packets, session_ptr.get());
   }
 }
 
@@ -1395,10 +1415,10 @@
   QUIC_DLOG(INFO) << "Created new session for "
                   << packet_info->destination_connection_id;
 
-  QuicSession* session_ptr;
-  auto insertion_result = reference_counted_session_map_.insert(std::make_pair(
-      packet_info->destination_connection_id,
-      std::shared_ptr<QuicSession>(std::move(session.release()))));
+  auto insertion_result = reference_counted_session_map_.insert(
+      std::make_pair(packet_info->destination_connection_id,
+                     std::shared_ptr<QuicSession>(std::move(session))));
+  std::shared_ptr<QuicSession> session_ptr = insertion_result.first->second;
   if (!insertion_result.second) {
     QUIC_BUG(quic_bug_10287_9)
         << "Tried to add a session to session_map with existing "
@@ -1406,8 +1426,19 @@
         << packet_info->destination_connection_id;
   } else {
     ++num_sessions_in_session_map_;
+    if (GetQuicRestartFlag(quic_map_original_connection_ids) &&
+        replaced_connection_id) {
+      QUIC_RESTART_FLAG_COUNT_N(quic_map_original_connection_ids, 2, 4);
+      auto insertion_result2 = reference_counted_session_map_.insert(
+          std::make_pair(original_connection_id, session_ptr));
+      QUIC_BUG_IF(quic_460317833_02, !insertion_result2.second)
+          << "Original connection ID already in session_map: "
+          << original_connection_id;
+      // If insertion of the original connection ID fails, it might cause
+      // loss of 0-RTT and other first flight packets, but the connection
+      // will usually progress.
+    }
   }
-  session_ptr = insertion_result.first->second.get();
   std::list<BufferedPacket> packets =
       buffered_packets_.DeliverPackets(original_connection_id).buffered_packets;
   if (replaced_connection_id && !packets.empty()) {
@@ -1420,7 +1451,7 @@
   // Deliver queued-up packets in the same order as they arrived.
   // Do this even when flag is off because there might be still some packets
   // buffered in the store before flag is turned off.
-  DeliverPacketsToSession(packets, session_ptr);
+  DeliverPacketsToSession(packets, session_ptr.get());
   --new_sessions_allowed_per_event_loop_;
 }
 
diff --git a/quiche/quic/core/quic_dispatcher_test.cc b/quiche/quic/core/quic_dispatcher_test.cc
index b8b8b44..3cdd5ab 100644
--- a/quiche/quic/core/quic_dispatcher_test.cc
+++ b/quiche/quic/core/quic_dispatcher_test.cc
@@ -187,7 +187,19 @@
   }
 
   std::vector<QuicConnectionId> GetActiveServerConnectionIds() const override {
-    return active_connection_ids_;
+    if (!GetQuicRestartFlag(quic_map_original_connection_ids)) {
+      return active_connection_ids_;
+    }
+    std::vector<QuicConnectionId> result;
+    for (const auto& cid : active_connection_ids_) {
+      result.push_back(cid);
+    }
+    auto original_connection_id = GetOriginalDestinationConnectionId();
+    if (std::find(result.begin(), result.end(), original_connection_id) ==
+        result.end()) {
+      result.push_back(original_connection_id);
+    }
+    return result;
   }
 
   void UnregisterOnConnectionClosed() {
diff --git a/quiche/quic/core/quic_flags_list.h b/quiche/quic/core/quic_flags_list.h
index d9fe60d..d423535 100644
--- a/quiche/quic/core/quic_flags_list.h
+++ b/quiche/quic/core/quic_flags_list.h
@@ -101,6 +101,8 @@
 QUIC_FLAG(quic_reloadable_flag_quic_connection_migration_use_new_cid_v2, true)
 // If true, uses conservative cwnd gain and pacing gain when cwnd gets bootstrapped.
 QUIC_FLAG(quic_reloadable_flag_quic_conservative_cwnd_and_pacing_gains, false)
+// Store original QUIC connection IDs in the dispatcher's map
+QUIC_FLAG(quic_restart_flag_quic_map_original_connection_ids, false)
 // When the flag is true, exit STARTUP after the same number of loss events as PROBE_UP.
 QUIC_FLAG(quic_reloadable_flag_quic_bbr2_startup_probe_up_loss_events, true)
 // When true, defaults to BBR congestion control instead of Cubic.