Only change the connection ID once from received initial packets

This matches the requirements from the IETF specification.
This is a client-only change and is therefore not flag-protected.

PiperOrigin-RevId: 360719109
Change-Id: If578250dce6a59d6e6156f5b48de34fb5ef6e32c
diff --git a/quic/core/quic_connection.cc b/quic/core/quic_connection.cc
index bbdb64c..173bc99 100644
--- a/quic/core/quic_connection.cc
+++ b/quic/core/quic_connection.cc
@@ -2757,6 +2757,14 @@
 
   if (PacketCanReplaceConnectionId(header, perspective_) &&
       server_connection_id_ != header.source_connection_id) {
+    QUICHE_DCHECK_EQ(header.long_packet_type, INITIAL);
+    if (server_connection_id_replaced_by_initial_) {
+      QUIC_DLOG(ERROR) << ENDPOINT << "Refusing to replace connection ID "
+                       << server_connection_id_ << " with "
+                       << header.source_connection_id;
+      return false;
+    }
+    server_connection_id_replaced_by_initial_ = true;
     QUIC_DLOG(INFO) << ENDPOINT << "Replacing connection ID "
                     << server_connection_id_ << " with "
                     << header.source_connection_id;
diff --git a/quic/core/quic_connection.h b/quic/core/quic_connection.h
index 33cf33b..302fb35 100644
--- a/quic/core/quic_connection.h
+++ b/quic/core/quic_connection.h
@@ -1712,6 +1712,10 @@
   // On the server, the connection ID is set when receiving the first packet.
   // This variable ensures we only set it this way once.
   bool client_connection_id_is_set_;
+
+  // Whether we've already replaced our server connection ID due to receiving an
+  // INITIAL packet with a different source connection ID. Only used on client.
+  bool server_connection_id_replaced_by_initial_ = false;
   // Address on the last successfully processed packet received from the
   // direct peer.
 
diff --git a/quic/core/quic_connection_test.cc b/quic/core/quic_connection_test.cc
index 4407411..32889da 100644
--- a/quic/core/quic_connection_test.cc
+++ b/quic/core/quic_connection_test.cc
@@ -1438,6 +1438,8 @@
                                bool missing_retry_id_in_config,
                                bool wrong_retry_id_in_config);
 
+  void TestReplaceConnectionIdFromInitial();
+
   QuicConnectionId connection_id_;
   QuicFramer framer_;
 
@@ -9266,7 +9268,7 @@
   frames.push_back(QuicFrame(ping_frame));
   frames.push_back(QuicFrame(padding_frame));
   std::unique_ptr<QuicPacket> packet =
-      BuildUnsizedDataPacket(&framer_, header, frames);
+      BuildUnsizedDataPacket(&peer_framer_, header, frames);
   char buffer[kMaxOutgoingPacketSize];
   size_t encrypted_length = peer_framer_.EncryptPayload(
       ENCRYPTION_FORWARD_SECURE, QuicPacketNumber(1), *packet, buffer,
@@ -9294,7 +9296,7 @@
   frames.push_back(QuicFrame(ping_frame));
   frames.push_back(QuicFrame(padding_frame));
   std::unique_ptr<QuicPacket> packet =
-      BuildUnsizedDataPacket(&framer_, header, frames);
+      BuildUnsizedDataPacket(&peer_framer_, header, frames);
   char buffer[kMaxOutgoingPacketSize];
   size_t encrypted_length = peer_framer_.EncryptPayload(
       ENCRYPTION_FORWARD_SECURE, QuicPacketNumber(1), *packet, buffer,
@@ -9322,7 +9324,7 @@
   frames.push_back(QuicFrame(ping_frame));
   frames.push_back(QuicFrame(padding_frame));
   std::unique_ptr<QuicPacket> packet =
-      BuildUnsizedDataPacket(&framer_, header, frames);
+      BuildUnsizedDataPacket(&peer_framer_, header, frames);
   char buffer[kMaxOutgoingPacketSize];
   size_t encrypted_length =
       peer_framer_.EncryptPayload(ENCRYPTION_INITIAL, QuicPacketNumber(1),
@@ -9334,6 +9336,79 @@
   EXPECT_EQ(0u, connection_.GetStats().packets_dropped);
   EXPECT_EQ(TestConnectionId(0x33), connection_.client_connection_id());
 }
+void QuicConnectionTest::TestReplaceConnectionIdFromInitial() {
+  if (!framer_.version().AllowsVariableLengthConnectionIds()) {
+    return;
+  }
+  // We start with a known connection ID.
+  EXPECT_TRUE(connection_.connected());
+  EXPECT_EQ(0u, connection_.GetStats().packets_dropped);
+  EXPECT_NE(TestConnectionId(0x33), connection_.connection_id());
+  // Receiving an initial can replace the connection ID once.
+  {
+    QuicPacketHeader header = ConstructPacketHeader(1, ENCRYPTION_INITIAL);
+    header.source_connection_id = TestConnectionId(0x33);
+    header.source_connection_id_included = CONNECTION_ID_PRESENT;
+    QuicFrames frames;
+    QuicPingFrame ping_frame;
+    QuicPaddingFrame padding_frame;
+    frames.push_back(QuicFrame(ping_frame));
+    frames.push_back(QuicFrame(padding_frame));
+    std::unique_ptr<QuicPacket> packet =
+        BuildUnsizedDataPacket(&peer_framer_, header, frames);
+    char buffer[kMaxOutgoingPacketSize];
+    size_t encrypted_length =
+        peer_framer_.EncryptPayload(ENCRYPTION_INITIAL, QuicPacketNumber(1),
+                                    *packet, buffer, kMaxOutgoingPacketSize);
+    QuicReceivedPacket received_packet(buffer, encrypted_length, clock_.Now(),
+                                       false);
+    ProcessReceivedPacket(kSelfAddress, kPeerAddress, received_packet);
+  }
+  EXPECT_TRUE(connection_.connected());
+  EXPECT_EQ(0u, connection_.GetStats().packets_dropped);
+  EXPECT_EQ(TestConnectionId(0x33), connection_.connection_id());
+  // Trying to replace the connection ID a second time drops the packet.
+  {
+    QuicPacketHeader header = ConstructPacketHeader(2, ENCRYPTION_INITIAL);
+    header.source_connection_id = TestConnectionId(0x66);
+    header.source_connection_id_included = CONNECTION_ID_PRESENT;
+    QuicFrames frames;
+    QuicPingFrame ping_frame;
+    QuicPaddingFrame padding_frame;
+    frames.push_back(QuicFrame(ping_frame));
+    frames.push_back(QuicFrame(padding_frame));
+    std::unique_ptr<QuicPacket> packet =
+        BuildUnsizedDataPacket(&peer_framer_, header, frames);
+    char buffer[kMaxOutgoingPacketSize];
+    size_t encrypted_length =
+        peer_framer_.EncryptPayload(ENCRYPTION_INITIAL, QuicPacketNumber(2),
+                                    *packet, buffer, kMaxOutgoingPacketSize);
+    QuicReceivedPacket received_packet(buffer, encrypted_length, clock_.Now(),
+                                       false);
+    ProcessReceivedPacket(kSelfAddress, kPeerAddress, received_packet);
+  }
+  EXPECT_TRUE(connection_.connected());
+  EXPECT_EQ(1u, connection_.GetStats().packets_dropped);
+  EXPECT_EQ(TestConnectionId(0x33), connection_.connection_id());
+}
+
+TEST_P(QuicConnectionTest, ReplaceServerConnectionIdFromInitial) {
+  TestReplaceConnectionIdFromInitial();
+}
+
+TEST_P(QuicConnectionTest, ReplaceServerConnectionIdFromRetryAndInitial) {
+  // First make the connection process a RETRY and replace the server connection
+  // ID a first time.
+  TestClientRetryHandling(/*invalid_retry_tag=*/false,
+                          /*missing_original_id_in_config=*/false,
+                          /*wrong_original_id_in_config=*/false,
+                          /*missing_retry_id_in_config=*/false,
+                          /*wrong_retry_id_in_config=*/false);
+  // Reset the test framer to use the right connection ID.
+  peer_framer_.SetInitialObfuscators(connection_.connection_id());
+  // Now process an INITIAL and replace the server connection ID a second time.
+  TestReplaceConnectionIdFromInitial();
+}
 
 // Regression test for b/134416344.
 TEST_P(QuicConnectionTest, CheckConnectedBeforeFlush) {