Make sure we have the right key before sending data.

Protected by FLAGS_quic_reloadable_flag_quic_check_keys_before_writing.

PiperOrigin-RevId: 333559010
Change-Id: I4ac0ed51c5401ef8b872c06c32148fa554eead54
diff --git a/quic/core/http/quic_server_session_base_test.cc b/quic/core/http/quic_server_session_base_test.cc
index ff3c62e..61303de 100644
--- a/quic/core/http/quic_server_session_base_test.cc
+++ b/quic/core/http/quic_server_session_base_test.cc
@@ -9,6 +9,7 @@
 #include <string>
 #include <utility>
 
+#include "net/third_party/quiche/src/quic/core/crypto/null_encrypter.h"
 #include "net/third_party/quiche/src/quic/core/crypto/quic_crypto_server_config.h"
 #include "net/third_party/quiche/src/quic/core/crypto/quic_random.h"
 #include "net/third_party/quiche/src/quic/core/proto/cached_network_parameters_proto.h"
@@ -347,7 +348,9 @@
   // streams.  For versions other than version 99, the server accepts slightly
   // more than the negotiated stream limit to deal with rare cases where a
   // client FIN/RST is lost.
-
+  connection_->SetEncrypter(
+      ENCRYPTION_FORWARD_SECURE,
+      std::make_unique<NullEncrypter>(session_->perspective()));
   connection_->SetDefaultEncryptionLevel(ENCRYPTION_FORWARD_SECURE);
   session_->OnConfigNegotiated();
   if (!VersionHasIetfQuicFrames(transport_version())) {
@@ -400,7 +403,9 @@
   // Test that the server closes the connection if a client makes too many data
   // streams available.  The server accepts slightly more than the negotiated
   // stream limit to deal with rare cases where a client FIN/RST is lost.
-
+  connection_->SetEncrypter(
+      ENCRYPTION_FORWARD_SECURE,
+      std::make_unique<NullEncrypter>(session_->perspective()));
   connection_->SetDefaultEncryptionLevel(ENCRYPTION_FORWARD_SECURE);
   session_->OnConfigNegotiated();
   const size_t kAvailableStreamLimit =
@@ -507,6 +512,9 @@
   QuicTagVector copt;
   copt.push_back(kBWRE);
   QuicConfigPeer::SetReceivedConnectionOptions(session_->config(), copt);
+  connection_->SetEncrypter(
+      ENCRYPTION_FORWARD_SECURE,
+      std::make_unique<NullEncrypter>(session_->perspective()));
   connection_->SetDefaultEncryptionLevel(ENCRYPTION_FORWARD_SECURE);
   session_->OnConfigNegotiated();
   EXPECT_TRUE(
@@ -698,6 +706,9 @@
   QuicTagVector copt;
   copt.push_back(kBWMX);
   QuicConfigPeer::SetReceivedConnectionOptions(session_->config(), copt);
+  connection_->SetEncrypter(
+      ENCRYPTION_FORWARD_SECURE,
+      std::make_unique<NullEncrypter>(session_->perspective()));
   connection_->SetDefaultEncryptionLevel(ENCRYPTION_FORWARD_SECURE);
   session_->OnConfigNegotiated();
   EXPECT_TRUE(
@@ -707,6 +718,9 @@
 TEST_P(QuicServerSessionBaseTest, NoBandwidthResumptionByDefault) {
   EXPECT_FALSE(
       QuicServerSessionBasePeer::IsBandwidthResumptionEnabled(session_.get()));
+  connection_->SetEncrypter(
+      ENCRYPTION_FORWARD_SECURE,
+      std::make_unique<NullEncrypter>(session_->perspective()));
   connection_->SetDefaultEncryptionLevel(ENCRYPTION_FORWARD_SECURE);
   session_->OnConfigNegotiated();
   EXPECT_FALSE(
@@ -722,6 +736,9 @@
   QuicTagVector copt;
   copt.push_back(kQNSP);
   QuicConfigPeer::SetReceivedConnectionOptions(session_->config(), copt);
+  connection_->SetEncrypter(
+      ENCRYPTION_FORWARD_SECURE,
+      std::make_unique<NullEncrypter>(session_->perspective()));
   connection_->SetDefaultEncryptionLevel(ENCRYPTION_FORWARD_SECURE);
   session_->OnConfigNegotiated();
   EXPECT_FALSE(session_->server_push_enabled());
diff --git a/quic/core/quic_connection.cc b/quic/core/quic_connection.cc
index 876e997..15298fe 100644
--- a/quic/core/quic_connection.cc
+++ b/quic/core/quic_connection.cc
@@ -3371,6 +3371,9 @@
   }
   encryption_level_ = level;
   packet_creator_.set_encryption_level(level);
+  QUIC_BUG_IF(!framer_.HasEncrypterOfEncryptionLevel(level))
+      << ENDPOINT << "Trying to set encryption level to "
+      << EncryptionLevelToString(level) << " while the key is missing";
 
   if (!changing_level) {
     return;
@@ -4517,6 +4520,16 @@
     return;
   }
 
+  if (check_keys_before_writing_) {
+    QUIC_RELOADABLE_FLAG_COUNT_N(quic_check_keys_before_writing, 1, 2);
+    if (!framer_.HasAnEncrypterForSpace(space)) {
+      QUIC_BUG << ENDPOINT
+               << "Try to bundle crypto with ACK with missing key of space "
+               << PacketNumberSpaceToString(space);
+      return;
+    }
+  }
+
   sent_packet_manager_.RetransmitDataOfSpaceIfAny(space);
 }
 
@@ -4540,6 +4553,13 @@
     if (!ack_timeout.IsInitialized()) {
       continue;
     }
+    if (check_keys_before_writing_) {
+      QUIC_RELOADABLE_FLAG_COUNT_N(quic_check_keys_before_writing, 2, 2);
+      if (!framer_.HasAnEncrypterForSpace(static_cast<PacketNumberSpace>(i))) {
+        // The key has been dropped.
+        continue;
+      }
+    }
     if (ack_timeout > clock_->ApproximateNow() &&
         ack_timeout > earliest_ack_timeout) {
       // Always send the earliest ACK to make forward progress in case alarm
diff --git a/quic/core/quic_connection.h b/quic/core/quic_connection.h
index 3eec673..c9d903e 100644
--- a/quic/core/quic_connection.h
+++ b/quic/core/quic_connection.h
@@ -1040,6 +1040,8 @@
     can_receive_ack_frequency_frame_ = true;
   }
 
+  bool check_keys_before_writing() const { return check_keys_before_writing_; }
+
  protected:
   // Calls cancel() on all the alarms owned by this connection.
   void CancelAllAlarms();
@@ -1791,6 +1793,9 @@
 
   const bool fix_out_of_order_sending_ =
       GetQuicReloadableFlag(quic_fix_out_of_order_sending2);
+
+  const bool check_keys_before_writing_ =
+      GetQuicReloadableFlag(quic_check_keys_before_writing);
 };
 
 }  // namespace quic
diff --git a/quic/core/quic_connection_test.cc b/quic/core/quic_connection_test.cc
index 17a7daa..4500a5b 100644
--- a/quic/core/quic_connection_test.cc
+++ b/quic/core/quic_connection_test.cc
@@ -4634,9 +4634,9 @@
   // which should result in the original packet being processed.
   SetDecrypter(ENCRYPTION_ZERO_RTT,
                std::make_unique<StrictTaggingDecrypter>(tag));
-  connection_.SetDefaultEncryptionLevel(ENCRYPTION_ZERO_RTT);
   connection_.SetEncrypter(ENCRYPTION_ZERO_RTT,
                            std::make_unique<TaggingEncrypter>(tag));
+  connection_.SetDefaultEncryptionLevel(ENCRYPTION_ZERO_RTT);
   EXPECT_CALL(visitor_, OnStreamFrame(_)).Times(2);
   ProcessDataPacketAtLevel(2, !kHasStopWaiting, ENCRYPTION_ZERO_RTT);
 
@@ -4709,9 +4709,9 @@
   SetDecrypter(ENCRYPTION_ZERO_RTT,
                std::make_unique<StrictTaggingDecrypter>(tag));
   EXPECT_TRUE(connection_.GetProcessUndecryptablePacketsAlarm()->IsSet());
-  connection_.SetDefaultEncryptionLevel(ENCRYPTION_ZERO_RTT);
   connection_.SetEncrypter(ENCRYPTION_ZERO_RTT,
                            std::make_unique<TaggingEncrypter>(tag));
+  connection_.SetDefaultEncryptionLevel(ENCRYPTION_ZERO_RTT);
 
   EXPECT_CALL(visitor_, OnStreamFrame(_)).Times(100);
   connection_.GetProcessUndecryptablePacketsAlarm()->Fire();
@@ -10965,9 +10965,9 @@
   SetDecrypter(ENCRYPTION_HANDSHAKE,
                std::make_unique<StrictTaggingDecrypter>(0x01));
   EXPECT_TRUE(connection_.GetProcessUndecryptablePacketsAlarm()->IsSet());
-  connection_.SetDefaultEncryptionLevel(ENCRYPTION_HANDSHAKE);
   connection_.SetEncrypter(ENCRYPTION_HANDSHAKE,
                            std::make_unique<TaggingEncrypter>(0x01));
+  connection_.SetDefaultEncryptionLevel(ENCRYPTION_HANDSHAKE);
   // Verify all ENCRYPTION_HANDSHAKE packets get processed.
   EXPECT_CALL(visitor_, OnStreamFrame(_)).Times(6);
   connection_.GetProcessUndecryptablePacketsAlarm()->Fire();
diff --git a/quic/core/quic_crypto_stream_test.cc b/quic/core/quic_crypto_stream_test.cc
index b0f4de3..f7e2483 100644
--- a/quic/core/quic_crypto_stream_test.cc
+++ b/quic/core/quic_crypto_stream_test.cc
@@ -221,15 +221,18 @@
                        &MockQuicConnection::QuicConnection_SendCryptoData));
   stream_->WriteCryptoData(ENCRYPTION_INITIAL, data);
   // Send [1350, 2700) in ENCRYPTION_ZERO_RTT.
-  connection_->SetDefaultEncryptionLevel(ENCRYPTION_ZERO_RTT);
   std::unique_ptr<NullEncrypter> encrypter =
       std::make_unique<NullEncrypter>(Perspective::IS_CLIENT);
   connection_->SetEncrypter(ENCRYPTION_ZERO_RTT, std::move(encrypter));
+  connection_->SetDefaultEncryptionLevel(ENCRYPTION_ZERO_RTT);
   EXPECT_EQ(ENCRYPTION_ZERO_RTT, connection_->encryption_level());
   EXPECT_CALL(*connection_, SendCryptoData(ENCRYPTION_ZERO_RTT, 1350, 0))
       .WillOnce(Invoke(connection_,
                        &MockQuicConnection::QuicConnection_SendCryptoData));
   stream_->WriteCryptoData(ENCRYPTION_ZERO_RTT, data);
+  connection_->SetEncrypter(
+      ENCRYPTION_FORWARD_SECURE,
+      std::make_unique<NullEncrypter>(Perspective::IS_CLIENT));
   connection_->SetDefaultEncryptionLevel(ENCRYPTION_FORWARD_SECURE);
   EXPECT_EQ(ENCRYPTION_FORWARD_SECURE, connection_->encryption_level());
 
@@ -277,15 +280,18 @@
                        &MockQuicConnection::QuicConnection_SendCryptoData));
   stream_->WriteCryptoData(ENCRYPTION_INITIAL, data);
   // Send [1000, 2000) in ENCRYPTION_HANDSHAKE.
-  connection_->SetDefaultEncryptionLevel(ENCRYPTION_HANDSHAKE);
   std::unique_ptr<NullEncrypter> encrypter =
       std::make_unique<NullEncrypter>(Perspective::IS_CLIENT);
   connection_->SetEncrypter(ENCRYPTION_HANDSHAKE, std::move(encrypter));
+  connection_->SetDefaultEncryptionLevel(ENCRYPTION_HANDSHAKE);
   EXPECT_EQ(ENCRYPTION_HANDSHAKE, connection_->encryption_level());
   EXPECT_CALL(*connection_, SendCryptoData(ENCRYPTION_HANDSHAKE, 1000, 0))
       .WillOnce(Invoke(connection_,
                        &MockQuicConnection::QuicConnection_SendCryptoData));
   stream_->WriteCryptoData(ENCRYPTION_HANDSHAKE, data);
+  connection_->SetEncrypter(
+      ENCRYPTION_FORWARD_SECURE,
+      std::make_unique<NullEncrypter>(Perspective::IS_CLIENT));
   connection_->SetDefaultEncryptionLevel(ENCRYPTION_FORWARD_SECURE);
   EXPECT_EQ(ENCRYPTION_FORWARD_SECURE, connection_->encryption_level());
 
@@ -353,6 +359,9 @@
                        &MockQuicConnection::QuicConnection_SendCryptoData));
   stream_->WriteCryptoData(ENCRYPTION_INITIAL, data);
   // Send [1350, 2700) in ENCRYPTION_ZERO_RTT.
+  connection_->SetEncrypter(
+      ENCRYPTION_ZERO_RTT,
+      std::make_unique<NullEncrypter>(Perspective::IS_CLIENT));
   connection_->SetDefaultEncryptionLevel(ENCRYPTION_ZERO_RTT);
   std::unique_ptr<NullEncrypter> encrypter =
       std::make_unique<NullEncrypter>(Perspective::IS_CLIENT);
@@ -467,15 +476,18 @@
                        &MockQuicConnection::QuicConnection_SendCryptoData));
   stream_->WriteCryptoData(ENCRYPTION_INITIAL, data);
   // Send [1350, 2700) in ENCRYPTION_ZERO_RTT.
-  connection_->SetDefaultEncryptionLevel(ENCRYPTION_ZERO_RTT);
   std::unique_ptr<NullEncrypter> encrypter =
       std::make_unique<NullEncrypter>(Perspective::IS_CLIENT);
   connection_->SetEncrypter(ENCRYPTION_ZERO_RTT, std::move(encrypter));
+  connection_->SetDefaultEncryptionLevel(ENCRYPTION_ZERO_RTT);
   EXPECT_EQ(ENCRYPTION_ZERO_RTT, connection_->encryption_level());
   EXPECT_CALL(*connection_, SendCryptoData(ENCRYPTION_ZERO_RTT, 1350, 0))
       .WillOnce(Invoke(connection_,
                        &MockQuicConnection::QuicConnection_SendCryptoData));
   stream_->WriteCryptoData(ENCRYPTION_ZERO_RTT, data);
+  connection_->SetEncrypter(
+      ENCRYPTION_FORWARD_SECURE,
+      std::make_unique<NullEncrypter>(Perspective::IS_CLIENT));
   connection_->SetDefaultEncryptionLevel(ENCRYPTION_FORWARD_SECURE);
   EXPECT_EQ(ENCRYPTION_FORWARD_SECURE, connection_->encryption_level());
 
@@ -591,6 +603,9 @@
   // Send [1350, 2700) in ENCRYPTION_ZERO_RTT and verify no write is attempted
   // because there is buffered data.
   EXPECT_CALL(*connection_, SendCryptoData(_, _, _)).Times(0);
+  connection_->SetEncrypter(
+      ENCRYPTION_ZERO_RTT,
+      std::make_unique<NullEncrypter>(Perspective::IS_CLIENT));
   connection_->SetDefaultEncryptionLevel(ENCRYPTION_ZERO_RTT);
   stream_->WriteCryptoData(ENCRYPTION_ZERO_RTT, data);
   EXPECT_EQ(ENCRYPTION_ZERO_RTT, connection_->encryption_level());
diff --git a/quic/core/quic_error_codes.cc b/quic/core/quic_error_codes.cc
index 54e4569..8a07598 100644
--- a/quic/core/quic_error_codes.cc
+++ b/quic/core/quic_error_codes.cc
@@ -231,6 +231,7 @@
     RETURN_STRING_LITERAL(QUIC_ZERO_RTT_REJECTION_LIMIT_REDUCED);
     RETURN_STRING_LITERAL(QUIC_ZERO_RTT_RESUMPTION_LIMIT_REDUCED);
     RETURN_STRING_LITERAL(QUIC_SILENT_IDLE_TIMEOUT);
+    RETURN_STRING_LITERAL(QUIC_MISSING_WRITE_KEYS);
 
     RETURN_STRING_LITERAL(QUIC_LAST_ERROR);
     // Intentionally have no default case, so we'll break the build
@@ -627,6 +628,8 @@
       return {true, static_cast<uint64_t>(INTERNAL_ERROR)};
     case QUIC_ZERO_RTT_RESUMPTION_LIMIT_REDUCED:
       return {true, static_cast<uint64_t>(PROTOCOL_VIOLATION)};
+    case QUIC_MISSING_WRITE_KEYS:
+      return {true, static_cast<uint64_t>(INTERNAL_ERROR)};
     case QUIC_LAST_ERROR:
       return {false, static_cast<uint64_t>(QUIC_LAST_ERROR)};
   }
diff --git a/quic/core/quic_error_codes.h b/quic/core/quic_error_codes.h
index 0a415ef..180dff7 100644
--- a/quic/core/quic_error_codes.h
+++ b/quic/core/quic_error_codes.h
@@ -495,8 +495,11 @@
   // The connection silently timed out due to no network activity.
   QUIC_SILENT_IDLE_TIMEOUT = 168,
 
+  // Try to write data without the right write keys.
+  QUIC_MISSING_WRITE_KEYS = 170,
+
   // No error. Used as bound while iterating.
-  QUIC_LAST_ERROR = 170,
+  QUIC_LAST_ERROR = 171,
 };
 // QuicErrorCodes is encoded as four octets on-the-wire when doing Google QUIC,
 // or a varint62 when doing IETF QUIC. Ensure that its value does not exceed
diff --git a/quic/core/quic_framer.cc b/quic/core/quic_framer.cc
index 6519f57..8b9d828 100644
--- a/quic/core/quic_framer.cc
+++ b/quic/core/quic_framer.cc
@@ -2034,6 +2034,23 @@
   return decrypter_[level] != nullptr;
 }
 
+bool QuicFramer::HasAnEncrypterForSpace(PacketNumberSpace space) const {
+  switch (space) {
+    case INITIAL_DATA:
+      return HasEncrypterOfEncryptionLevel(ENCRYPTION_INITIAL);
+    case HANDSHAKE_DATA:
+      return HasEncrypterOfEncryptionLevel(ENCRYPTION_HANDSHAKE);
+    case APPLICATION_DATA:
+      return HasEncrypterOfEncryptionLevel(ENCRYPTION_ZERO_RTT) ||
+             HasEncrypterOfEncryptionLevel(ENCRYPTION_FORWARD_SECURE);
+    case NUM_PACKET_NUMBER_SPACES:
+      break;
+  }
+  QUIC_BUG << ENDPOINT
+           << "Try to send data of space: " << PacketNumberSpaceToString(space);
+  return false;
+}
+
 bool QuicFramer::AppendPacketHeader(const QuicPacketHeader& header,
                                     QuicDataWriter* writer,
                                     size_t* length_field_offset) {
diff --git a/quic/core/quic_framer.h b/quic/core/quic_framer.h
index 8e63fc2..4fcd671 100644
--- a/quic/core/quic_framer.h
+++ b/quic/core/quic_framer.h
@@ -581,6 +581,9 @@
   // Returns true if decrypter of |level| is available.
   bool HasDecrypterOfEncryptionLevel(EncryptionLevel level) const;
 
+  // Returns true if an encrypter of |space| is available.
+  bool HasAnEncrypterForSpace(PacketNumberSpace space) const;
+
   void set_validate_flags(bool value) { validate_flags_ = value; }
 
   Perspective perspective() const { return perspective_; }
diff --git a/quic/core/quic_packet_creator.cc b/quic/core/quic_packet_creator.cc
index a13f3c3..30b911f 100644
--- a/quic/core/quic_packet_creator.cc
+++ b/quic/core/quic_packet_creator.cc
@@ -804,6 +804,8 @@
                 << packet_.encryption_level;
 
   if (!framer_->HasEncrypterOfEncryptionLevel(packet_.encryption_level)) {
+    // TODO(fayang): Use QUIC_MISSING_WRITE_KEYS for serialization failures due
+    // to missing keys.
     QUIC_BUG << ENDPOINT << "Attempting to serialize " << header
              << QuicFramesToString(queued_frames_)
              << " at missing encryption_level " << packet_.encryption_level
diff --git a/quic/core/quic_session.cc b/quic/core/quic_session.cc
index e3c389b..cec2337 100644
--- a/quic/core/quic_session.cc
+++ b/quic/core/quic_session.cc
@@ -774,6 +774,17 @@
                                    QuicStreamOffset offset,
                                    TransmissionType type) {
   DCHECK(QuicVersionUsesCryptoFrames(transport_version()));
+  if (connection()->check_keys_before_writing() &&
+      !connection()->framer().HasEncrypterOfEncryptionLevel(level)) {
+    const std::string error_details = quiche::QuicheStrCat(
+        "Try to send crypto data with missing keys of encryption level: ",
+        EncryptionLevelToString(level));
+    QUIC_BUG << ENDPOINT << error_details;
+    connection()->CloseConnection(
+        QUIC_MISSING_WRITE_KEYS, error_details,
+        ConnectionCloseBehavior::SEND_CONNECTION_CLOSE_PACKET);
+    return 0;
+  }
   SetTransmissionType(type);
   const auto current_level = connection()->encryption_level();
   connection_->SetDefaultEncryptionLevel(level);
diff --git a/quic/core/quic_session_test.cc b/quic/core/quic_session_test.cc
index 5da254d..a5d9b5f 100644
--- a/quic/core/quic_session_test.cc
+++ b/quic/core/quic_session_test.cc
@@ -2978,6 +2978,9 @@
   EXPECT_TRUE(session_.WillingAndAbleToWrite());
 
   EXPECT_CALL(*connection_, SendCryptoData(_, _, _)).Times(0);
+  connection_->SetEncrypter(
+      ENCRYPTION_ZERO_RTT,
+      std::make_unique<NullEncrypter>(connection_->perspective()));
   crypto_stream->WriteCryptoData(ENCRYPTION_ZERO_RTT, data);
 
   EXPECT_CALL(*connection_, SendCryptoData(ENCRYPTION_INITIAL, 350, 1000))
diff --git a/quic/test_tools/simple_session_notifier_test.cc b/quic/test_tools/simple_session_notifier_test.cc
index b72a952..bc65f44 100644
--- a/quic/test_tools/simple_session_notifier_test.cc
+++ b/quic/test_tools/simple_session_notifier_test.cc
@@ -265,9 +265,9 @@
   producer.SaveCryptoData(ENCRYPTION_INITIAL, 500, crypto_data2);
   notifier_.WriteCryptoData(ENCRYPTION_INITIAL, 1024, 0);
   // Send crypto data [1024, 2048) in ENCRYPTION_ZERO_RTT.
-  connection_.SetDefaultEncryptionLevel(ENCRYPTION_ZERO_RTT);
   connection_.SetEncrypter(ENCRYPTION_ZERO_RTT, std::make_unique<NullEncrypter>(
                                                     Perspective::IS_CLIENT));
+  connection_.SetDefaultEncryptionLevel(ENCRYPTION_ZERO_RTT);
   EXPECT_CALL(connection_, SendCryptoData(ENCRYPTION_ZERO_RTT, 1024, 0))
       .WillOnce(Invoke(&connection_,
                        &MockQuicConnection::QuicConnection_SendCryptoData));