Add new methods to QuicFramer for controlling decrypters This CL is a roll forward of cl/242758726. I had to make test-only changes to fix the broken test //third_party/quic/core:tls_handshaker_test. gfe-relnote: Protected behind QUIC_VERSION_99 and quic_supports_tls_handshake PiperOrigin-RevId: 242988047 Change-Id: I6ba5970e58ae5f7d9ecc312affd5d86e1504ec7b
diff --git a/quic/core/http/quic_spdy_client_session_test.cc b/quic/core/http/quic_spdy_client_session_test.cc index 4e1dc9e..9a99ca7 100644 --- a/quic/core/http/quic_spdy_client_session_test.cc +++ b/quic/core/http/quic_spdy_client_session_test.cc
@@ -489,13 +489,22 @@ EXPECT_CALL(*connection_, OnError(_)).Times(1); // Verify that a decryptable packet with bad frames does close the connection. - QuicConnectionId connection_id = session_->connection()->connection_id(); + QuicConnectionId destination_connection_id = + session_->connection()->connection_id(); + QuicConnectionId source_connection_id = EmptyQuicConnectionId(); QuicFramerPeer::SetLastSerializedConnectionId( - QuicConnectionPeer::GetFramer(connection_), connection_id); + QuicConnectionPeer::GetFramer(connection_), destination_connection_id); ParsedQuicVersionVector versions = {GetParam()}; + bool version_flag = false; + QuicConnectionIdIncluded scid_included = CONNECTION_ID_ABSENT; + if (GetParam().transport_version > QUIC_VERSION_43) { + version_flag = true; + source_connection_id = destination_connection_id; + scid_included = CONNECTION_ID_PRESENT; + } std::unique_ptr<QuicEncryptedPacket> packet(ConstructMisFramedEncryptedPacket( - connection_id, EmptyQuicConnectionId(), false, false, 100, "data", - CONNECTION_ID_ABSENT, CONNECTION_ID_ABSENT, PACKET_4BYTE_PACKET_NUMBER, + destination_connection_id, source_connection_id, version_flag, false, 100, + "data", CONNECTION_ID_ABSENT, scid_included, PACKET_4BYTE_PACKET_NUMBER, &versions, Perspective::IS_SERVER)); std::unique_ptr<QuicReceivedPacket> received( ConstructReceivedPacket(*packet, QuicTime::Zero()));
diff --git a/quic/core/quic_connection.cc b/quic/core/quic_connection.cc index 8546076..307104d 100644 --- a/quic/core/quic_connection.cc +++ b/quic/core/quic_connection.cc
@@ -388,8 +388,7 @@ } QUIC_DLOG(INFO) << ENDPOINT << "Created connection with connection_id: " << connection_id - << " and version: " - << QuicVersionToString(transport_version()); + << " and version: " << ParsedQuicVersionToString(version()); QUIC_BUG_IF(!QuicUtils::IsConnectionIdValidForVersion(connection_id, transport_version())) @@ -2884,6 +2883,20 @@ } } +void QuicConnection::InstallDecrypter( + EncryptionLevel level, + std::unique_ptr<QuicDecrypter> decrypter) { + framer_.InstallDecrypter(level, std::move(decrypter)); + if (!undecryptable_packets_.empty() && + !process_undecryptable_packets_alarm_->IsSet()) { + process_undecryptable_packets_alarm_->Set(clock_->ApproximateNow()); + } +} + +void QuicConnection::RemoveDecrypter(EncryptionLevel level) { + framer_.RemoveDecrypter(level); +} + const QuicDecrypter* QuicConnection::decrypter() const { return framer_.decrypter(); } @@ -3907,6 +3920,13 @@ return packet_generator_.GetGuaranteedLargestMessagePayload(); } +uint32_t QuicConnection::cipher_id() const { + if (version().KnowsWhichDecrypterToUse()) { + return framer_.GetDecrypter(last_decrypted_packet_level_)->cipher_id(); + } + return framer_.decrypter()->cipher_id(); +} + bool QuicConnection::ShouldSetAckAlarm() const { DCHECK(ack_frame_updated()); if (ack_alarm_->IsSet()) {
diff --git a/quic/core/quic_connection.h b/quic/core/quic_connection.h index 602361e..d13458d 100644 --- a/quic/core/quic_connection.h +++ b/quic/core/quic_connection.h
@@ -672,6 +672,10 @@ std::unique_ptr<QuicDecrypter> decrypter, bool latch_once_used); + void InstallDecrypter(EncryptionLevel level, + std::unique_ptr<QuicDecrypter> decrypter); + void RemoveDecrypter(EncryptionLevel level); + const QuicDecrypter* decrypter() const; const QuicDecrypter* alternative_decrypter() const; @@ -765,8 +769,8 @@ // connection ID lengths do not change. QuicPacketLength GetGuaranteedLargestMessagePayload() const; - // Return the id of the cipher of the primary decrypter of the framer. - uint32_t cipher_id() const { return framer_.decrypter()->cipher_id(); } + // Returns the id of the cipher last used for decrypting packets. + uint32_t cipher_id() const; std::vector<std::unique_ptr<QuicEncryptedPacket>>* termination_packets() { return termination_packets_.get();
diff --git a/quic/core/quic_connection_test.cc b/quic/core/quic_connection_test.cc index cd53edf..4072ecd 100644 --- a/quic/core/quic_connection_test.cc +++ b/quic/core/quic_connection_test.cc
@@ -5,14 +5,15 @@ #include "net/third_party/quiche/src/quic/core/quic_connection.h" #include <errno.h> + #include <memory> #include <ostream> -#include <utility> - #include <string> +#include <utility> #include "net/third_party/quiche/src/quic/core/congestion_control/loss_detection_interface.h" #include "net/third_party/quiche/src/quic/core/congestion_control/send_algorithm_interface.h" +#include "net/third_party/quiche/src/quic/core/crypto/null_decrypter.h" #include "net/third_party/quiche/src/quic/core/crypto/null_encrypter.h" #include "net/third_party/quiche/src/quic/core/crypto/quic_decrypter.h" #include "net/third_party/quiche/src/quic/core/crypto/quic_encrypter.h" @@ -358,8 +359,21 @@ } if (use_tagging_decrypter_) { - framer_.framer()->SetDecrypter(ENCRYPTION_INITIAL, - QuicMakeUnique<TaggingDecrypter>()); + if (framer_.framer()->version().KnowsWhichDecrypterToUse()) { + framer_.framer()->InstallDecrypter(ENCRYPTION_INITIAL, + QuicMakeUnique<TaggingDecrypter>()); + framer_.framer()->InstallDecrypter(ENCRYPTION_ZERO_RTT, + QuicMakeUnique<TaggingDecrypter>()); + framer_.framer()->InstallDecrypter(ENCRYPTION_FORWARD_SECURE, + QuicMakeUnique<TaggingDecrypter>()); + } else { + framer_.framer()->SetDecrypter(ENCRYPTION_INITIAL, + QuicMakeUnique<TaggingDecrypter>()); + } + } else if (framer_.framer()->version().KnowsWhichDecrypterToUse()) { + framer_.framer()->InstallDecrypter( + ENCRYPTION_FORWARD_SECURE, + QuicMakeUnique<NullDecrypter>(Perspective::IS_SERVER)); } EXPECT_TRUE(framer_.ProcessPacket(packet)); if (block_on_next_write_) { @@ -967,6 +981,12 @@ .WillRepeatedly(Return(QuicTime::Zero())); EXPECT_CALL(*loss_algorithm_, DetectLosses(_, _, _, _, _, _)) .Times(AnyNumber()); + + if (connection_.version().KnowsWhichDecrypterToUse()) { + connection_.InstallDecrypter( + ENCRYPTION_FORWARD_SECURE, + QuicMakeUnique<NullDecrypter>(Perspective::IS_CLIENT)); + } } QuicConnectionTest(const QuicConnectionTest&) = delete; @@ -994,6 +1014,16 @@ void use_tagging_decrypter() { writer_->use_tagging_decrypter(); } + void SetDecrypter(EncryptionLevel level, + std::unique_ptr<QuicDecrypter> decrypter) { + if (connection_.version().KnowsWhichDecrypterToUse()) { + connection_.InstallDecrypter(level, std::move(decrypter)); + connection_.RemoveDecrypter(ENCRYPTION_INITIAL); + } else { + connection_.SetDecrypter(level, std::move(decrypter)); + } + } + void ProcessPacket(uint64_t number) { EXPECT_CALL(visitor_, OnStreamFrame(_)).Times(1); ProcessDataPacket(number); @@ -1048,8 +1078,11 @@ void ForceProcessFramePacket(QuicFrame frame) { QuicFrames frames; frames.push_back(QuicFrame(frame)); - QuicPacketCreatorPeer::SetSendVersionInPacket( - &peer_creator_, connection_.perspective() == Perspective::IS_SERVER); + bool send_version = connection_.perspective() == Perspective::IS_SERVER; + if (connection_.version().KnowsWhichDecrypterToUse()) { + send_version = true; + } + QuicPacketCreatorPeer::SetSendVersionInPacket(&peer_creator_, send_version); QuicPacketHeader header; QuicPacketCreatorPeer::FillPacketHeader(&peer_creator_, &header); char encrypted_buffer[kMaxOutgoingPacketSize]; @@ -1081,6 +1114,16 @@ peer_framer_.perspective() == Perspective::IS_SERVER) { header.destination_connection_id_included = CONNECTION_ID_ABSENT; } + if (level == ENCRYPTION_INITIAL && + peer_framer_.version().KnowsWhichDecrypterToUse()) { + header.version_flag = true; + header.retry_token_length_length = VARIABLE_LENGTH_INTEGER_LENGTH_1; + header.length_length = VARIABLE_LENGTH_INTEGER_LENGTH_2; + if (peer_framer_.perspective() == Perspective::IS_SERVER) { + header.source_connection_id = connection_id_; + header.source_connection_id_included = CONNECTION_ID_PRESENT; + } + } header.packet_number = QuicPacketNumber(number); QuicFrames frames; frames.push_back(frame); @@ -1094,9 +1137,16 @@ QuicPacketCreatorPeer::GetEncryptionLevel(&peer_creator_), QuicMakeUnique<TaggingEncrypter>(0x01)); // Set the corresponding decrypter. - connection_.SetDecrypter( - QuicPacketCreatorPeer::GetEncryptionLevel(&peer_creator_), - QuicMakeUnique<StrictTaggingDecrypter>(0x01)); + if (connection_.version().KnowsWhichDecrypterToUse()) { + connection_.InstallDecrypter( + QuicPacketCreatorPeer::GetEncryptionLevel(&peer_creator_), + QuicMakeUnique<StrictTaggingDecrypter>(0x01)); + connection_.RemoveDecrypter(ENCRYPTION_INITIAL); + } else { + connection_.SetDecrypter( + QuicPacketCreatorPeer::GetEncryptionLevel(&peer_creator_), + QuicMakeUnique<StrictTaggingDecrypter>(0x01)); + } } char buffer[kMaxOutgoingPacketSize]; @@ -2962,8 +3012,8 @@ EXPECT_CALL(visitor_, OnStreamFrame(_)).Times(1); peer_framer_.SetEncrypter(ENCRYPTION_FORWARD_SECURE, QuicMakeUnique<TaggingEncrypter>(0x01)); - connection_.SetDecrypter(ENCRYPTION_FORWARD_SECURE, - QuicMakeUnique<StrictTaggingDecrypter>(0x01)); + SetDecrypter(ENCRYPTION_FORWARD_SECURE, + QuicMakeUnique<StrictTaggingDecrypter>(0x01)); ProcessDataPacketAtLevel(2, false, ENCRYPTION_FORWARD_SECURE); EXPECT_EQ(0u, connection_.NumQueuedPackets()); @@ -3388,7 +3438,7 @@ 1, stream_id, QUIC_ERROR_PROCESSING_STREAM, 14))); } EXPECT_EQ(1u, writer_->frame_count()); - EXPECT_EQ(1u, writer_->rst_stream_frames().size()); + ASSERT_EQ(1u, writer_->rst_stream_frames().size()); EXPECT_EQ(stream_id, writer_->rst_stream_frames().front().stream_id); } @@ -4200,8 +4250,8 @@ // Transition to the new encryption state and process another encrypted packet // which should result in the original packet being processed. - connection_.SetDecrypter(ENCRYPTION_ZERO_RTT, - QuicMakeUnique<StrictTaggingDecrypter>(tag)); + SetDecrypter(ENCRYPTION_ZERO_RTT, + QuicMakeUnique<StrictTaggingDecrypter>(tag)); connection_.SetDefaultEncryptionLevel(ENCRYPTION_ZERO_RTT); connection_.SetEncrypter(ENCRYPTION_ZERO_RTT, QuicMakeUnique<TaggingEncrypter>(tag)); @@ -4274,8 +4324,8 @@ // Transition to the new encryption state and process another encrypted packet // which should result in the original packets being processed. EXPECT_FALSE(connection_.GetProcessUndecryptablePacketsAlarm()->IsSet()); - connection_.SetDecrypter(ENCRYPTION_ZERO_RTT, - QuicMakeUnique<StrictTaggingDecrypter>(tag)); + SetDecrypter(ENCRYPTION_ZERO_RTT, + QuicMakeUnique<StrictTaggingDecrypter>(tag)); EXPECT_TRUE(connection_.GetProcessUndecryptablePacketsAlarm()->IsSet()); connection_.SetDefaultEncryptionLevel(ENCRYPTION_ZERO_RTT); connection_.SetEncrypter(ENCRYPTION_ZERO_RTT, @@ -5608,8 +5658,8 @@ EXPECT_CALL(visitor_, OnSuccessfulVersionNegotiation(_)); EXPECT_FALSE(connection_.GetAckAlarm()->IsSet()); const uint8_t tag = 0x07; - connection_.SetDecrypter(ENCRYPTION_ZERO_RTT, - QuicMakeUnique<StrictTaggingDecrypter>(tag)); + SetDecrypter(ENCRYPTION_ZERO_RTT, + QuicMakeUnique<StrictTaggingDecrypter>(tag)); peer_framer_.SetEncrypter(ENCRYPTION_ZERO_RTT, QuicMakeUnique<TaggingEncrypter>(tag)); // Process a packet from the non-crypto stream. @@ -5650,8 +5700,8 @@ EXPECT_CALL(visitor_, OnSuccessfulVersionNegotiation(_)); EXPECT_FALSE(connection_.GetAckAlarm()->IsSet()); const uint8_t tag = 0x07; - connection_.SetDecrypter(ENCRYPTION_ZERO_RTT, - QuicMakeUnique<StrictTaggingDecrypter>(tag)); + SetDecrypter(ENCRYPTION_ZERO_RTT, + QuicMakeUnique<StrictTaggingDecrypter>(tag)); peer_framer_.SetEncrypter(ENCRYPTION_ZERO_RTT, QuicMakeUnique<TaggingEncrypter>(tag)); // Process a packet from the non-crypto stream. @@ -5731,8 +5781,8 @@ EXPECT_CALL(visitor_, OnSuccessfulVersionNegotiation(_)); EXPECT_FALSE(connection_.GetAckAlarm()->IsSet()); const uint8_t tag = 0x07; - connection_.SetDecrypter(ENCRYPTION_ZERO_RTT, - QuicMakeUnique<StrictTaggingDecrypter>(tag)); + SetDecrypter(ENCRYPTION_ZERO_RTT, + QuicMakeUnique<StrictTaggingDecrypter>(tag)); peer_framer_.SetEncrypter(ENCRYPTION_ZERO_RTT, QuicMakeUnique<TaggingEncrypter>(tag)); // Process a packet from the non-crypto stream. @@ -5793,8 +5843,8 @@ EXPECT_CALL(visitor_, OnSuccessfulVersionNegotiation(_)); EXPECT_FALSE(connection_.GetAckAlarm()->IsSet()); const uint8_t tag = 0x07; - connection_.SetDecrypter(ENCRYPTION_ZERO_RTT, - QuicMakeUnique<StrictTaggingDecrypter>(tag)); + SetDecrypter(ENCRYPTION_ZERO_RTT, + QuicMakeUnique<StrictTaggingDecrypter>(tag)); peer_framer_.SetEncrypter(ENCRYPTION_ZERO_RTT, QuicMakeUnique<TaggingEncrypter>(tag)); // Process a packet from the non-crypto stream. @@ -5931,8 +5981,8 @@ EXPECT_CALL(visitor_, OnSuccessfulVersionNegotiation(_)); EXPECT_FALSE(connection_.GetAckAlarm()->IsSet()); const uint8_t tag = 0x07; - connection_.SetDecrypter(ENCRYPTION_ZERO_RTT, - QuicMakeUnique<StrictTaggingDecrypter>(tag)); + SetDecrypter(ENCRYPTION_ZERO_RTT, + QuicMakeUnique<StrictTaggingDecrypter>(tag)); peer_framer_.SetEncrypter(ENCRYPTION_ZERO_RTT, QuicMakeUnique<TaggingEncrypter>(tag)); // Process a packet from the non-crypto stream. @@ -5987,8 +6037,8 @@ EXPECT_CALL(visitor_, OnSuccessfulVersionNegotiation(_)); EXPECT_FALSE(connection_.GetAckAlarm()->IsSet()); const uint8_t tag = 0x07; - connection_.SetDecrypter(ENCRYPTION_ZERO_RTT, - QuicMakeUnique<StrictTaggingDecrypter>(tag)); + SetDecrypter(ENCRYPTION_ZERO_RTT, + QuicMakeUnique<StrictTaggingDecrypter>(tag)); peer_framer_.SetEncrypter(ENCRYPTION_ZERO_RTT, QuicMakeUnique<TaggingEncrypter>(tag)); // Process a packet from the non-crypto stream. @@ -6048,8 +6098,8 @@ EXPECT_CALL(visitor_, OnSuccessfulVersionNegotiation(_)); EXPECT_FALSE(connection_.GetAckAlarm()->IsSet()); const uint8_t tag = 0x07; - connection_.SetDecrypter(ENCRYPTION_ZERO_RTT, - QuicMakeUnique<StrictTaggingDecrypter>(tag)); + SetDecrypter(ENCRYPTION_ZERO_RTT, + QuicMakeUnique<StrictTaggingDecrypter>(tag)); peer_framer_.SetEncrypter(ENCRYPTION_ZERO_RTT, QuicMakeUnique<TaggingEncrypter>(tag)); // Process a packet from the non-crypto stream. @@ -6115,8 +6165,8 @@ EXPECT_CALL(visitor_, OnSuccessfulVersionNegotiation(_)); EXPECT_FALSE(connection_.GetAckAlarm()->IsSet()); const uint8_t tag = 0x07; - connection_.SetDecrypter(ENCRYPTION_ZERO_RTT, - QuicMakeUnique<StrictTaggingDecrypter>(tag)); + SetDecrypter(ENCRYPTION_ZERO_RTT, + QuicMakeUnique<StrictTaggingDecrypter>(tag)); peer_framer_.SetEncrypter(ENCRYPTION_ZERO_RTT, QuicMakeUnique<TaggingEncrypter>(tag)); // Process a packet from the non-crypto stream. @@ -6202,8 +6252,8 @@ EXPECT_CALL(visitor_, OnSuccessfulVersionNegotiation(_)); EXPECT_FALSE(connection_.GetAckAlarm()->IsSet()); const uint8_t tag = 0x07; - connection_.SetDecrypter(ENCRYPTION_ZERO_RTT, - QuicMakeUnique<StrictTaggingDecrypter>(tag)); + SetDecrypter(ENCRYPTION_ZERO_RTT, + QuicMakeUnique<StrictTaggingDecrypter>(tag)); peer_framer_.SetEncrypter(ENCRYPTION_ZERO_RTT, QuicMakeUnique<TaggingEncrypter>(tag)); // Process a packet from the non-crypto stream. @@ -6273,8 +6323,8 @@ EXPECT_CALL(visitor_, OnSuccessfulVersionNegotiation(_)); EXPECT_FALSE(connection_.GetAckAlarm()->IsSet()); const uint8_t tag = 0x07; - connection_.SetDecrypter(ENCRYPTION_ZERO_RTT, - QuicMakeUnique<StrictTaggingDecrypter>(tag)); + SetDecrypter(ENCRYPTION_ZERO_RTT, + QuicMakeUnique<StrictTaggingDecrypter>(tag)); peer_framer_.SetEncrypter(ENCRYPTION_ZERO_RTT, QuicMakeUnique<TaggingEncrypter>(tag)); // Process a packet from the non-crypto stream. @@ -6444,8 +6494,8 @@ EXPECT_CALL(visitor_, OnStreamFrame(_)); peer_framer_.SetEncrypter(ENCRYPTION_FORWARD_SECURE, QuicMakeUnique<TaggingEncrypter>(0x01)); - connection_.SetDecrypter(ENCRYPTION_FORWARD_SECURE, - QuicMakeUnique<StrictTaggingDecrypter>(0x01)); + SetDecrypter(ENCRYPTION_FORWARD_SECURE, + QuicMakeUnique<StrictTaggingDecrypter>(0x01)); ProcessDataPacketAtLevel(1, false, ENCRYPTION_FORWARD_SECURE); connection_.SendStreamDataWithString( GetNthClientInitiatedStreamId(1, connection_.transport_version()), "foo", @@ -8730,8 +8780,8 @@ EXPECT_TRUE(connection_.GetAckAlarm()->IsSet()); peer_framer_.SetEncrypter(ENCRYPTION_ZERO_RTT, QuicMakeUnique<TaggingEncrypter>(0x02)); - connection_.SetDecrypter(ENCRYPTION_ZERO_RTT, - QuicMakeUnique<StrictTaggingDecrypter>(0x02)); + SetDecrypter(ENCRYPTION_ZERO_RTT, + QuicMakeUnique<StrictTaggingDecrypter>(0x02)); connection_.SetEncrypter(ENCRYPTION_INITIAL, QuicMakeUnique<TaggingEncrypter>(0x02)); // Receives packet 1000 in application data. @@ -8758,8 +8808,8 @@ peer_framer_.SetEncrypter(ENCRYPTION_FORWARD_SECURE, QuicMakeUnique<TaggingEncrypter>(0x02)); - connection_.SetDecrypter(ENCRYPTION_FORWARD_SECURE, - QuicMakeUnique<StrictTaggingDecrypter>(0x02)); + SetDecrypter(ENCRYPTION_FORWARD_SECURE, + QuicMakeUnique<StrictTaggingDecrypter>(0x02)); // Verify zero rtt and forward secure packets get acked in the same packet. EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)).Times(1); ProcessDataPacketAtLevel(1003, false, ENCRYPTION_FORWARD_SECURE); @@ -8778,8 +8828,8 @@ EXPECT_TRUE(connection_.GetAckAlarm()->IsSet()); peer_framer_.SetEncrypter(ENCRYPTION_ZERO_RTT, QuicMakeUnique<TaggingEncrypter>(0x02)); - connection_.SetDecrypter(ENCRYPTION_ZERO_RTT, - QuicMakeUnique<StrictTaggingDecrypter>(0x02)); + SetDecrypter(ENCRYPTION_ZERO_RTT, + QuicMakeUnique<StrictTaggingDecrypter>(0x02)); connection_.SetEncrypter(ENCRYPTION_INITIAL, QuicMakeUnique<TaggingEncrypter>(0x02)); // Receives packet 1000 in application data.
diff --git a/quic/core/quic_crypto_client_handshaker.cc b/quic/core/quic_crypto_client_handshaker.cc index 013ab3f..d6c9af4 100644 --- a/quic/core/quic_crypto_client_handshaker.cc +++ b/quic/core/quic_crypto_client_handshaker.cc
@@ -375,10 +375,16 @@ crypto_config_->pad_full_hello()); SendHandshakeMessage(out); // Be prepared to decrypt with the new server write key. - session()->connection()->SetAlternativeDecrypter( - ENCRYPTION_ZERO_RTT, - std::move(crypto_negotiated_params_->initial_crypters.decrypter), - true /* latch once used */); + if (session()->connection()->version().KnowsWhichDecrypterToUse()) { + session()->connection()->InstallDecrypter( + ENCRYPTION_ZERO_RTT, + std::move(crypto_negotiated_params_->initial_crypters.decrypter)); + } else { + session()->connection()->SetAlternativeDecrypter( + ENCRYPTION_ZERO_RTT, + std::move(crypto_negotiated_params_->initial_crypters.decrypter), + true /* latch once used */); + } // Send subsequent packets under encryption on the assumption that the // server will accept the handshake. session()->connection()->SetEncrypter( @@ -584,10 +590,8 @@ // to see whether the response was a reject, and if so, move on to // the reject-processing state. if ((in->tag() == kREJ) || (in->tag() == kSREJ)) { - // alternative_decrypter will be nullptr if the original alternative - // decrypter latched and became the primary decrypter. That happens - // if we received a message encrypted with the INITIAL key. - if (session()->connection()->alternative_decrypter() == nullptr) { + // A reject message must be sent in ENCRYPTION_INITIAL. + if (session()->connection()->last_decrypted_level() != ENCRYPTION_INITIAL) { // The rejection was sent encrypted! stream_->CloseConnectionWithDetails( QUIC_CRYPTO_ENCRYPTION_LEVEL_INCORRECT, "encrypted REJ message"); @@ -603,10 +607,7 @@ return; } - // alternative_decrypter will be nullptr if the original alternative - // decrypter latched and became the primary decrypter. That happens - // if we received a message encrypted with the INITIAL key. - if (session()->connection()->alternative_decrypter() != nullptr) { + if (session()->connection()->last_decrypted_level() == ENCRYPTION_INITIAL) { // The server hello was sent without encryption. stream_->CloseConnectionWithDetails(QUIC_CRYPTO_ENCRYPTION_LEVEL_INCORRECT, "unencrypted SHLO message"); @@ -638,9 +639,14 @@ // has been floated that the server shouldn't send packets encrypted // with the FORWARD_SECURE key until it receives a FORWARD_SECURE // packet from the client. - session()->connection()->SetAlternativeDecrypter( - ENCRYPTION_FORWARD_SECURE, std::move(crypters->decrypter), - false /* don't latch */); + if (session()->connection()->version().KnowsWhichDecrypterToUse()) { + session()->connection()->InstallDecrypter(ENCRYPTION_FORWARD_SECURE, + std::move(crypters->decrypter)); + } else { + session()->connection()->SetAlternativeDecrypter( + ENCRYPTION_FORWARD_SECURE, std::move(crypters->decrypter), + false /* don't latch */); + } session()->connection()->SetEncrypter(ENCRYPTION_FORWARD_SECURE, std::move(crypters->encrypter)); session()->connection()->SetDefaultEncryptionLevel(ENCRYPTION_FORWARD_SECURE);
diff --git a/quic/core/quic_crypto_server_handshaker.cc b/quic/core/quic_crypto_server_handshaker.cc index cd3cce9..c0e61ef 100644 --- a/quic/core/quic_crypto_server_handshaker.cc +++ b/quic/core/quic_crypto_server_handshaker.cc
@@ -230,9 +230,16 @@ session()->connection()->SetDefaultEncryptionLevel(ENCRYPTION_ZERO_RTT); // Set the decrypter immediately so that we no longer accept unencrypted // packets. - session()->connection()->SetDecrypter( - ENCRYPTION_ZERO_RTT, - std::move(crypto_negotiated_params_->initial_crypters.decrypter)); + if (session()->connection()->version().KnowsWhichDecrypterToUse()) { + session()->connection()->InstallDecrypter( + ENCRYPTION_ZERO_RTT, + std::move(crypto_negotiated_params_->initial_crypters.decrypter)); + session()->connection()->RemoveDecrypter(ENCRYPTION_INITIAL); + } else { + session()->connection()->SetDecrypter( + ENCRYPTION_ZERO_RTT, + std::move(crypto_negotiated_params_->initial_crypters.decrypter)); + } session()->connection()->SetDiversificationNonce(*diversification_nonce); session()->connection()->set_fully_pad_crypto_hadshake_packets( @@ -244,10 +251,17 @@ std::move(crypto_negotiated_params_->forward_secure_crypters.encrypter)); session()->connection()->SetDefaultEncryptionLevel(ENCRYPTION_FORWARD_SECURE); - session()->connection()->SetAlternativeDecrypter( - ENCRYPTION_FORWARD_SECURE, - std::move(crypto_negotiated_params_->forward_secure_crypters.decrypter), - false /* don't latch */); + if (session()->connection()->version().KnowsWhichDecrypterToUse()) { + session()->connection()->InstallDecrypter( + ENCRYPTION_FORWARD_SECURE, + std::move( + crypto_negotiated_params_->forward_secure_crypters.decrypter)); + } else { + session()->connection()->SetAlternativeDecrypter( + ENCRYPTION_FORWARD_SECURE, + std::move(crypto_negotiated_params_->forward_secure_crypters.decrypter), + false /* don't latch */); + } encryption_established_ = true; handshake_confirmed_ = true;
diff --git a/quic/core/quic_framer.cc b/quic/core/quic_framer.cc index f99b60f..0e30453 100644 --- a/quic/core/quic_framer.cc +++ b/quic/core/quic_framer.cc
@@ -346,6 +346,32 @@ return NUM_PACKET_NUMBER_SPACES; } +EncryptionLevel GetEncryptionLevel(const QuicPacketHeader& header) { + switch (header.form) { + case GOOGLE_QUIC_PACKET: + QUIC_BUG << "Cannot determine EncryptionLevel from Google QUIC header"; + break; + case IETF_QUIC_SHORT_HEADER_PACKET: + return ENCRYPTION_FORWARD_SECURE; + case IETF_QUIC_LONG_HEADER_PACKET: + switch (header.long_packet_type) { + case INITIAL: + return ENCRYPTION_INITIAL; + case HANDSHAKE: + return ENCRYPTION_HANDSHAKE; + case ZERO_RTT_PROTECTED: + return ENCRYPTION_ZERO_RTT; + case VERSION_NEGOTIATION: + case RETRY: + case INVALID_PACKET_TYPE: + QUIC_BUG << "No encryption used with type " + << QuicUtils::QuicLongHeaderTypetoString( + header.long_packet_type); + } + } + return NUM_ENCRYPTION_LEVELS; +} + QuicStringPiece TruncateErrorString(QuicStringPiece error) { if (error.length() <= kMaxErrorStringLength) { return error; @@ -3861,6 +3887,7 @@ std::unique_ptr<QuicDecrypter> decrypter) { DCHECK_EQ(alternative_decrypter_level_, NUM_ENCRYPTION_LEVELS); DCHECK_GE(level, decrypter_level_); + DCHECK(!version_.KnowsWhichDecrypterToUse()); decrypter_[decrypter_level_] = nullptr; decrypter_[level] = std::move(decrypter); decrypter_level_ = level; @@ -3871,6 +3898,7 @@ std::unique_ptr<QuicDecrypter> decrypter, bool latch_once_used) { DCHECK_NE(level, decrypter_level_); + DCHECK(!version_.KnowsWhichDecrypterToUse()); if (alternative_decrypter_level_ != NUM_ENCRYPTION_LEVELS) { decrypter_[alternative_decrypter_level_] = nullptr; } @@ -3879,6 +3907,22 @@ alternative_decrypter_latch_ = latch_once_used; } +void QuicFramer::InstallDecrypter(EncryptionLevel level, + std::unique_ptr<QuicDecrypter> decrypter) { + DCHECK(version_.KnowsWhichDecrypterToUse()); + decrypter_[level] = std::move(decrypter); +} + +void QuicFramer::RemoveDecrypter(EncryptionLevel level) { + DCHECK(version_.KnowsWhichDecrypterToUse()); + decrypter_[level] = nullptr; +} + +const QuicDecrypter* QuicFramer::GetDecrypter(EncryptionLevel level) const { + DCHECK(version_.KnowsWhichDecrypterToUse()); + return decrypter_[level].get(); +} + const QuicDecrypter* QuicFramer::decrypter() const { return decrypter_[decrypter_level_].get(); } @@ -3974,18 +4018,31 @@ size_t buffer_length, size_t* decrypted_length, EncryptionLevel* decrypted_level) { - DCHECK(decrypter_[decrypter_level_] != nullptr); + EncryptionLevel level = decrypter_level_; + QuicDecrypter* decrypter = decrypter_[level].get(); QuicDecrypter* alternative_decrypter = nullptr; - if (alternative_decrypter_level_ != NUM_ENCRYPTION_LEVELS) { + if (version().KnowsWhichDecrypterToUse()) { + level = GetEncryptionLevel(header); + decrypter = decrypter_[level].get(); + if (decrypter == nullptr) { + return false; + } + if (level == ENCRYPTION_ZERO_RTT && + perspective_ == Perspective::IS_CLIENT && header.nonce != nullptr) { + decrypter->SetDiversificationNonce(*header.nonce); + } + } else if (alternative_decrypter_level_ != NUM_ENCRYPTION_LEVELS) { alternative_decrypter = decrypter_[alternative_decrypter_level_].get(); } - bool success = decrypter_[decrypter_level_]->DecryptPacket( + DCHECK(decrypter != nullptr); + + bool success = decrypter->DecryptPacket( header.packet_number.ToUint64(), associated_data, encrypted, decrypted_buffer, decrypted_length, buffer_length); if (success) { - visitor_->OnDecryptedPacket(decrypter_level_); - *decrypted_level = decrypter_level_; + visitor_->OnDecryptedPacket(level); + *decrypted_level = level; } else if (alternative_decrypter != nullptr) { if (header.nonce != nullptr) { DCHECK_EQ(perspective_, Perspective::IS_CLIENT);
diff --git a/quic/core/quic_framer.h b/quic/core/quic_framer.h index fc189b2..8d2fd93 100644 --- a/quic/core/quic_framer.h +++ b/quic/core/quic_framer.h
@@ -475,6 +475,11 @@ std::unique_ptr<QuicDecrypter> decrypter, bool latch_once_used); + void InstallDecrypter(EncryptionLevel level, + std::unique_ptr<QuicDecrypter> decrypter); + void RemoveDecrypter(EncryptionLevel level); + + const QuicDecrypter* GetDecrypter(EncryptionLevel level) const; const QuicDecrypter* decrypter() const; const QuicDecrypter* alternative_decrypter() const;
diff --git a/quic/core/quic_framer_test.cc b/quic/core/quic_framer_test.cc index bc1dc5c..4cb837e 100644 --- a/quic/core/quic_framer_test.cc +++ b/quic/core/quic_framer_test.cc
@@ -456,8 +456,13 @@ kQuicDefaultConnectionIdLength) { SetQuicFlag(&FLAGS_quic_supports_tls_handshake, true); framer_.set_version(version_); - framer_.SetDecrypter(ENCRYPTION_INITIAL, - std::unique_ptr<QuicDecrypter>(decrypter_)); + if (framer_.version().KnowsWhichDecrypterToUse()) { + framer_.InstallDecrypter(ENCRYPTION_INITIAL, + std::unique_ptr<QuicDecrypter>(decrypter_)); + } else { + framer_.SetDecrypter(ENCRYPTION_INITIAL, + std::unique_ptr<QuicDecrypter>(decrypter_)); + } framer_.SetEncrypter(ENCRYPTION_INITIAL, std::unique_ptr<QuicEncrypter>(encrypter_)); @@ -465,6 +470,14 @@ framer_.InferPacketHeaderTypeFromVersion(); } + void SetDecrypterLevel(EncryptionLevel level) { + if (!framer_.version().KnowsWhichDecrypterToUse()) { + return; + } + decrypter_ = new TestDecrypter(); + framer_.InstallDecrypter(level, std::unique_ptr<QuicDecrypter>(decrypter_)); + } + // Helper function to get unsigned char representation of the handshake // protocol byte of the current QUIC version number. unsigned char GetQuicVersionProtocolByte() { @@ -803,6 +816,7 @@ } TEST_P(QuicFramerTest, LargePacket) { + SetDecrypterLevel(ENCRYPTION_FORWARD_SECURE); // clang-format off unsigned char packet[kMaxIncomingPacketSize + 1] = { // public flags (8 byte connection_id) @@ -960,6 +974,7 @@ } TEST_P(QuicFramerTest, PacketHeaderWith0ByteConnectionId) { + SetDecrypterLevel(ENCRYPTION_FORWARD_SECURE); QuicFramerPeer::SetLastSerializedConnectionId(&framer_, FramerTestConnectionId()); QuicFramerPeer::SetPerspective(&framer_, Perspective::IS_CLIENT); @@ -1015,6 +1030,7 @@ } TEST_P(QuicFramerTest, PacketHeaderWithVersionFlag) { + SetDecrypterLevel(ENCRYPTION_ZERO_RTT); // clang-format off PacketFragments packet = { // public flags (0 byte connection_id) @@ -1114,6 +1130,7 @@ } TEST_P(QuicFramerTest, PacketHeaderWith4BytePacketNumber) { + SetDecrypterLevel(ENCRYPTION_FORWARD_SECURE); QuicFramerPeer::SetLargestPacketNumber(&framer_, kPacketNumber - 2); // clang-format off @@ -1173,6 +1190,7 @@ } TEST_P(QuicFramerTest, PacketHeaderWith2BytePacketNumber) { + SetDecrypterLevel(ENCRYPTION_FORWARD_SECURE); QuicFramerPeer::SetLargestPacketNumber(&framer_, kPacketNumber - 2); // clang-format off @@ -1233,6 +1251,7 @@ } TEST_P(QuicFramerTest, PacketHeaderWith1BytePacketNumber) { + SetDecrypterLevel(ENCRYPTION_FORWARD_SECURE); QuicFramerPeer::SetLargestPacketNumber(&framer_, kPacketNumber - 2); // clang-format off @@ -1294,6 +1313,7 @@ } TEST_P(QuicFramerTest, PacketNumberDecreasesThenIncreases) { + SetDecrypterLevel(ENCRYPTION_FORWARD_SECURE); // Test the case when a packet is received from the past and future packet // numbers are still calculated relative to the largest received packet. QuicPacketHeader header; @@ -1360,6 +1380,7 @@ } TEST_P(QuicFramerTest, PacketWithDiversificationNonce) { + SetDecrypterLevel(ENCRYPTION_ZERO_RTT); // clang-format off unsigned char packet[] = { // public flags: includes nonce flag @@ -1529,6 +1550,7 @@ } TEST_P(QuicFramerTest, PaddingFrame) { + SetDecrypterLevel(ENCRYPTION_FORWARD_SECURE); // clang-format off unsigned char packet[] = { // public flags (8 byte connection_id) @@ -1673,6 +1695,7 @@ } TEST_P(QuicFramerTest, StreamFrame) { + SetDecrypterLevel(ENCRYPTION_FORWARD_SECURE); // clang-format off PacketFragments packet = { // public flags (8 byte connection_id) @@ -1829,6 +1852,7 @@ if (framer_.transport_version() != QUIC_VERSION_99) { return; } + SetDecrypterLevel(ENCRYPTION_FORWARD_SECURE); // clang-format off PacketFragments packet = { // type (short header, 4 byte packet number) @@ -1882,11 +1906,18 @@ return; } QuicFramerPeer::SetPerspective(&framer_, Perspective::IS_CLIENT); - framer_.SetDecrypter(ENCRYPTION_INITIAL, - QuicMakeUnique<NullDecrypter>(Perspective::IS_CLIENT)); decrypter_ = new test::TestDecrypter(); - framer_.SetAlternativeDecrypter( - ENCRYPTION_ZERO_RTT, std::unique_ptr<QuicDecrypter>(decrypter_), false); + if (framer_.version().KnowsWhichDecrypterToUse()) { + framer_.InstallDecrypter(ENCRYPTION_INITIAL, QuicMakeUnique<NullDecrypter>( + Perspective::IS_CLIENT)); + framer_.InstallDecrypter(ENCRYPTION_ZERO_RTT, + std::unique_ptr<QuicDecrypter>(decrypter_)); + } else { + framer_.SetDecrypter(ENCRYPTION_INITIAL, + QuicMakeUnique<NullDecrypter>(Perspective::IS_CLIENT)); + framer_.SetAlternativeDecrypter( + ENCRYPTION_ZERO_RTT, std::unique_ptr<QuicDecrypter>(decrypter_), false); + } // clang-format off unsigned char packet[] = { @@ -2031,6 +2062,7 @@ } TEST_P(QuicFramerTest, StreamFrame2ByteStreamId) { + SetDecrypterLevel(ENCRYPTION_FORWARD_SECURE); // clang-format off PacketFragments packet = { // public flags (8 byte connection_id) @@ -2182,6 +2214,7 @@ } TEST_P(QuicFramerTest, StreamFrame1ByteStreamId) { + SetDecrypterLevel(ENCRYPTION_FORWARD_SECURE); // clang-format off PacketFragments packet = { // public flags (8 byte connection_id) @@ -2333,6 +2366,7 @@ } TEST_P(QuicFramerTest, StreamFrameWithVersion) { + SetDecrypterLevel(ENCRYPTION_ZERO_RTT); // clang-format off PacketFragments packet = { // public flags (version, 8 byte connection_id) @@ -2522,6 +2556,7 @@ } TEST_P(QuicFramerTest, RejectPacket) { + SetDecrypterLevel(ENCRYPTION_FORWARD_SECURE); visitor_.accept_packet_ = false; // clang-format off @@ -2667,6 +2702,7 @@ } TEST_P(QuicFramerTest, AckFrameOneAckBlock) { + SetDecrypterLevel(ENCRYPTION_FORWARD_SECURE); // clang-format off PacketFragments packet = { // public flags (8 byte connection_id) @@ -2818,6 +2854,7 @@ // and handles the case where the first ack block is larger than the // largest_acked packet. TEST_P(QuicFramerTest, FirstAckFrameUnderflow) { + SetDecrypterLevel(ENCRYPTION_FORWARD_SECURE); // clang-format off PacketFragments packet = { // public flags (8 byte connection_id) @@ -2952,6 +2989,7 @@ // for now, only v99 return; } + SetDecrypterLevel(ENCRYPTION_FORWARD_SECURE); // clang-format off PacketFragments packet99 = { // type (short header, 4 byte packet number) @@ -3009,6 +3047,7 @@ // for now, only v99 return; } + SetDecrypterLevel(ENCRYPTION_FORWARD_SECURE); // clang-format off PacketFragments packet99 = { // type (short header, 4 byte packet number) @@ -3064,6 +3103,7 @@ // for now, only v99 return; } + SetDecrypterLevel(ENCRYPTION_FORWARD_SECURE); // clang-format off PacketFragments packet99 = { // type (short header, 4 byte packet number) @@ -3113,6 +3153,7 @@ // for now, only v99 return; } + SetDecrypterLevel(ENCRYPTION_FORWARD_SECURE); // clang-format off PacketFragments packet99 = { // type (short header, 4 byte packet number) @@ -3161,6 +3202,7 @@ // for now, only v99 return; } + SetDecrypterLevel(ENCRYPTION_FORWARD_SECURE); // clang-format off PacketFragments packet99 = { // type (short header, 4 byte packet number) @@ -3359,6 +3401,7 @@ } TEST_P(QuicFramerTest, AckFrameOneAckBlockMaxLength) { + SetDecrypterLevel(ENCRYPTION_FORWARD_SECURE); // clang-format off PacketFragments packet = { // public flags (8 byte connection_id) @@ -3501,6 +3544,7 @@ // Tests ability to handle multiple ackblocks after the first ack // block. Non-version-99 tests include multiple timestamps as well. TEST_P(QuicFramerTest, AckFrameTwoTimeStampsMultipleAckBlocks) { + SetDecrypterLevel(ENCRYPTION_FORWARD_SECURE); // clang-format off PacketFragments packet = { // public flags (8 byte connection_id) @@ -4134,6 +4178,7 @@ } TEST_P(QuicFramerTest, RstStreamFrame) { + SetDecrypterLevel(ENCRYPTION_FORWARD_SECURE); // clang-format off PacketFragments packet = { // public flags (8 byte connection_id) @@ -4259,6 +4304,7 @@ } TEST_P(QuicFramerTest, ConnectionCloseFrame) { + SetDecrypterLevel(ENCRYPTION_FORWARD_SECURE); // clang-format off PacketFragments packet = { // public flags (8 byte connection_id) @@ -4411,6 +4457,7 @@ // This frame does not exist in versions other than 99. return; } + SetDecrypterLevel(ENCRYPTION_FORWARD_SECURE); // clang-format off PacketFragments packet99 = { @@ -4689,6 +4736,7 @@ // This frame is available only in version 99. return; } + SetDecrypterLevel(ENCRYPTION_FORWARD_SECURE); // clang-format off PacketFragments packet99 = { // type (short header, 4 byte packet number) @@ -4732,6 +4780,7 @@ // This frame available only in version 99. return; } + SetDecrypterLevel(ENCRYPTION_FORWARD_SECURE); // clang-format off PacketFragments packet99 = { // type (short header, 4 byte packet number) @@ -4773,6 +4822,7 @@ } TEST_P(QuicFramerTest, BlockedFrame) { + SetDecrypterLevel(ENCRYPTION_FORWARD_SECURE); // clang-format off PacketFragments packet = { // public flags (8 byte connection_id) @@ -4882,6 +4932,7 @@ } TEST_P(QuicFramerTest, PingFrame) { + SetDecrypterLevel(ENCRYPTION_FORWARD_SECURE); // clang-format off unsigned char packet[] = { // public flags (8 byte connection_id) @@ -4965,6 +5016,7 @@ if (framer_.transport_version() <= QUIC_VERSION_44) { return; } + SetDecrypterLevel(ENCRYPTION_FORWARD_SECURE); // clang-format off PacketFragments packet45 = { // type (short header, 4 byte packet number) @@ -5267,11 +5319,18 @@ return; } QuicFramerPeer::SetPerspective(&framer_, Perspective::IS_CLIENT); - framer_.SetDecrypter(ENCRYPTION_INITIAL, - QuicMakeUnique<NullDecrypter>(Perspective::IS_CLIENT)); decrypter_ = new test::TestDecrypter(); - framer_.SetAlternativeDecrypter( - ENCRYPTION_ZERO_RTT, std::unique_ptr<QuicDecrypter>(decrypter_), false); + if (framer_.version().KnowsWhichDecrypterToUse()) { + framer_.InstallDecrypter(ENCRYPTION_INITIAL, QuicMakeUnique<NullDecrypter>( + Perspective::IS_CLIENT)); + framer_.InstallDecrypter(ENCRYPTION_ZERO_RTT, + std::unique_ptr<QuicDecrypter>(decrypter_)); + } else { + framer_.SetDecrypter(ENCRYPTION_INITIAL, + QuicMakeUnique<NullDecrypter>(Perspective::IS_CLIENT)); + framer_.SetAlternativeDecrypter( + ENCRYPTION_ZERO_RTT, std::unique_ptr<QuicDecrypter>(decrypter_), false); + } // This packet cannot be decrypted because diversification nonce is missing. QuicEncryptedPacket encrypted(AsChars(packet), QUIC_ARRAYSIZE(packet), false); EXPECT_TRUE(framer_.ProcessPacket(encrypted)); @@ -5297,11 +5356,18 @@ return; } QuicFramerPeer::SetPerspective(&framer_, Perspective::IS_CLIENT); - framer_.SetDecrypter(ENCRYPTION_INITIAL, - QuicMakeUnique<NullDecrypter>(Perspective::IS_CLIENT)); decrypter_ = new test::TestDecrypter(); - framer_.SetAlternativeDecrypter( - ENCRYPTION_ZERO_RTT, std::unique_ptr<QuicDecrypter>(decrypter_), false); + if (framer_.version().KnowsWhichDecrypterToUse()) { + framer_.InstallDecrypter(ENCRYPTION_INITIAL, QuicMakeUnique<NullDecrypter>( + Perspective::IS_CLIENT)); + framer_.InstallDecrypter(ENCRYPTION_ZERO_RTT, + std::unique_ptr<QuicDecrypter>(decrypter_)); + } else { + framer_.SetDecrypter(ENCRYPTION_INITIAL, + QuicMakeUnique<NullDecrypter>(Perspective::IS_CLIENT)); + framer_.SetAlternativeDecrypter( + ENCRYPTION_ZERO_RTT, std::unique_ptr<QuicDecrypter>(decrypter_), false); + } // This packet cannot be decrypted because diversification nonce is missing. QuicEncryptedPacket encrypted(AsChars(packet), QUIC_ARRAYSIZE(packet), false); EXPECT_FALSE(framer_.ProcessPacket(encrypted)); @@ -6207,6 +6273,7 @@ // CRYPTO frames aren't supported prior to v46. return; } + SetDecrypterLevel(ENCRYPTION_FORWARD_SECURE); // clang-format off PacketFragments packet = { @@ -9462,6 +9529,7 @@ } TEST_P(QuicFramerTest, StopPacketProcessing) { + SetDecrypterLevel(ENCRYPTION_FORWARD_SECURE); // clang-format off unsigned char packet[] = { // public flags (8 byte connection_id) @@ -9656,10 +9724,14 @@ TEST_P(QuicFramerTest, ConstructEncryptedPacket) { // Since we are using ConstructEncryptedPacket, we have to set the framer's // crypto to be Null. - framer_.SetDecrypter(ENCRYPTION_INITIAL, - QuicMakeUnique<NullDecrypter>(framer_.perspective())); - framer_.SetEncrypter(ENCRYPTION_INITIAL, - QuicMakeUnique<NullEncrypter>(framer_.perspective())); + if (framer_.version().KnowsWhichDecrypterToUse()) { + framer_.InstallDecrypter( + ENCRYPTION_FORWARD_SECURE, + QuicMakeUnique<NullDecrypter>(framer_.perspective())); + } else { + framer_.SetDecrypter(ENCRYPTION_INITIAL, + QuicMakeUnique<NullDecrypter>(framer_.perspective())); + } ParsedQuicVersionVector versions; versions.push_back(framer_.version()); std::unique_ptr<QuicEncryptedPacket> packet(ConstructEncryptedPacket( @@ -9694,10 +9766,16 @@ // Verify that the packet returned by ConstructMisFramedEncryptedPacket() // does cause the framer to return an error. TEST_P(QuicFramerTest, ConstructMisFramedEncryptedPacket) { + SetDecrypterLevel(ENCRYPTION_FORWARD_SECURE); // Since we are using ConstructEncryptedPacket, we have to set the framer's // crypto to be Null. - framer_.SetDecrypter(ENCRYPTION_INITIAL, - QuicMakeUnique<NullDecrypter>(framer_.perspective())); + if (framer_.version().KnowsWhichDecrypterToUse()) { + framer_.InstallDecrypter(ENCRYPTION_INITIAL, QuicMakeUnique<NullDecrypter>( + framer_.perspective())); + } else { + framer_.SetDecrypter(ENCRYPTION_INITIAL, + QuicMakeUnique<NullDecrypter>(framer_.perspective())); + } framer_.SetEncrypter(ENCRYPTION_INITIAL, QuicMakeUnique<NullEncrypter>(framer_.perspective())); ParsedQuicVersionVector versions; @@ -9870,6 +9948,7 @@ if (framer_.transport_version() != QUIC_VERSION_99) { return; } + SetDecrypterLevel(ENCRYPTION_FORWARD_SECURE); // clang-format off PacketFragments packet99 = { @@ -9952,6 +10031,7 @@ if (framer_.transport_version() != QUIC_VERSION_99) { return; } + SetDecrypterLevel(ENCRYPTION_FORWARD_SECURE); // clang-format off PacketFragments packet99 = { @@ -10039,6 +10119,7 @@ if (framer_.transport_version() != QUIC_VERSION_99) { return; } + SetDecrypterLevel(ENCRYPTION_FORWARD_SECURE); // clang-format off PacketFragments packet99 = { @@ -10086,6 +10167,7 @@ if (framer_.transport_version() != QUIC_VERSION_99) { return; } + SetDecrypterLevel(ENCRYPTION_FORWARD_SECURE); // clang-format off PacketFragments packet99 = { @@ -10136,6 +10218,7 @@ if (framer_.transport_version() != QUIC_VERSION_99) { return; } + SetDecrypterLevel(ENCRYPTION_FORWARD_SECURE); // clang-format off PacketFragments packet99 = { @@ -10184,6 +10267,7 @@ if (framer_.transport_version() != QUIC_VERSION_99) { return; } + SetDecrypterLevel(ENCRYPTION_FORWARD_SECURE); // clang-format off PacketFragments packet99 = { @@ -10237,6 +10321,7 @@ if (framer_.transport_version() != QUIC_VERSION_99) { return; } + SetDecrypterLevel(ENCRYPTION_FORWARD_SECURE); // clang-format off unsigned char packet99[] = { @@ -10280,6 +10365,7 @@ if (framer_.transport_version() != QUIC_VERSION_99) { return; } + SetDecrypterLevel(ENCRYPTION_FORWARD_SECURE); // clang-format off unsigned char packet99[] = { @@ -10329,6 +10415,7 @@ if (framer_.transport_version() != QUIC_VERSION_99) { return; } + SetDecrypterLevel(ENCRYPTION_FORWARD_SECURE); // clang-format off unsigned char packet99[] = { @@ -10373,6 +10460,7 @@ if (framer_.transport_version() != QUIC_VERSION_99) { return; } + SetDecrypterLevel(ENCRYPTION_FORWARD_SECURE); // clang-format off unsigned char packet99[] = { @@ -10420,6 +10508,7 @@ if (framer_.transport_version() != QUIC_VERSION_99) { return; } + SetDecrypterLevel(ENCRYPTION_FORWARD_SECURE); // clang-format off unsigned char packet99[] = { @@ -10449,6 +10538,7 @@ if (framer_.transport_version() != QUIC_VERSION_99) { return; } + SetDecrypterLevel(ENCRYPTION_FORWARD_SECURE); // clang-format off PacketFragments packet99 = { @@ -10501,6 +10591,7 @@ if (framer_.transport_version() != QUIC_VERSION_99) { return; } + SetDecrypterLevel(ENCRYPTION_FORWARD_SECURE); // clang-format off PacketFragments packet99 = { @@ -10548,6 +10639,7 @@ if (framer_.transport_version() != QUIC_VERSION_99) { return; } + SetDecrypterLevel(ENCRYPTION_FORWARD_SECURE); // clang-format off PacketFragments packet99 = { @@ -10596,6 +10688,7 @@ if (framer_.transport_version() != QUIC_VERSION_99) { return; } + SetDecrypterLevel(ENCRYPTION_FORWARD_SECURE); // clang-format off PacketFragments packet99 = { @@ -10647,6 +10740,7 @@ if (framer_.transport_version() != QUIC_VERSION_99) { return; } + SetDecrypterLevel(ENCRYPTION_FORWARD_SECURE); // clang-format off unsigned char packet99[] = { @@ -10681,6 +10775,7 @@ if (framer_.transport_version() != QUIC_VERSION_99) { return; } + SetDecrypterLevel(ENCRYPTION_FORWARD_SECURE); // clang-format off unsigned char packet99[] = { @@ -11080,6 +11175,7 @@ // This frame is only for version 99. return; } + SetDecrypterLevel(ENCRYPTION_FORWARD_SECURE); // clang-format off PacketFragments packet99 = { // type (short header, 4 byte packet number) @@ -11135,6 +11231,7 @@ // This frame is only for version 99. return; } + SetDecrypterLevel(ENCRYPTION_FORWARD_SECURE); // clang-format off PacketFragments packet99 = { // type (short header, 4 byte packet number) @@ -11192,6 +11289,7 @@ // The NEW_CONNECTION_ID frame is only for version 99. return; } + SetDecrypterLevel(ENCRYPTION_FORWARD_SECURE); // clang-format off PacketFragments packet99 = { // type (short header, 4 byte packet number) @@ -11284,6 +11382,7 @@ // This frame is only for version 99. return; } + SetDecrypterLevel(ENCRYPTION_FORWARD_SECURE); // clang-format off PacketFragments packet = { // type (short header, 4 byte packet number) @@ -11376,6 +11475,7 @@ if (framer_.transport_version() != QUIC_VERSION_99) { return; } + SetDecrypterLevel(ENCRYPTION_FORWARD_SECURE); // clang-format off PacketFragments packet99 = { @@ -11463,6 +11563,7 @@ if (framer_.transport_version() != QUIC_VERSION_99) { return; } + SetDecrypterLevel(ENCRYPTION_FORWARD_SECURE); // clang-format off PacketFragments packet99 = { @@ -11545,6 +11646,7 @@ if (framer_.transport_version() != QUIC_VERSION_99) { return; } + SetDecrypterLevel(ENCRYPTION_FORWARD_SECURE); // clang-format off PacketFragments packet99 = { @@ -11710,6 +11812,7 @@ if (framer_.transport_version() != QUIC_VERSION_99) { return; } + SetDecrypterLevel(ENCRYPTION_FORWARD_SECURE); // clang-format off PacketFragments packet = { // type (short header, 4 byte packet number) @@ -11741,6 +11844,7 @@ if (framer_.transport_version() != QUIC_VERSION_99) { return; } + SetDecrypterLevel(ENCRYPTION_FORWARD_SECURE); // clang-format off PacketFragments packet = { @@ -11773,6 +11877,7 @@ if (framer_.transport_version() != QUIC_VERSION_99) { return; } + SetDecrypterLevel(ENCRYPTION_FORWARD_SECURE); // clang-format off PacketFragments packet = { @@ -11805,6 +11910,7 @@ if (framer_.transport_version() != QUIC_VERSION_99) { return; } + SetDecrypterLevel(ENCRYPTION_FORWARD_SECURE); // clang-format off PacketFragments packet = { // type (short header, 4 byte packet number) @@ -11840,6 +11946,7 @@ if (framer_.transport_version() != QUIC_VERSION_99) { return; } + SetDecrypterLevel(ENCRYPTION_FORWARD_SECURE); // clang-format off PacketFragments packet = { @@ -11872,6 +11979,7 @@ if (framer_.transport_version() != QUIC_VERSION_99) { return; } + SetDecrypterLevel(ENCRYPTION_FORWARD_SECURE); // clang-format off PacketFragments packet = { @@ -11904,6 +12012,7 @@ if (framer_.transport_version() != QUIC_VERSION_99) { return; } + SetDecrypterLevel(ENCRYPTION_FORWARD_SECURE); // clang-format off PacketFragments packet = { // type (short header, 4 byte packet number) @@ -11939,6 +12048,7 @@ if (framer_.transport_version() != QUIC_VERSION_99) { return; } + SetDecrypterLevel(ENCRYPTION_FORWARD_SECURE); // clang-format off PacketFragments packets[] = { @@ -12339,6 +12449,7 @@ // This frame is only for version 99. return; } + SetDecrypterLevel(ENCRYPTION_FORWARD_SECURE); // clang-format off PacketFragments packet99 = { // type (short header, 4 byte packet number) @@ -12419,6 +12530,7 @@ } TEST_P(QuicFramerTest, AckFrameWithInvalidLargestObserved) { + SetDecrypterLevel(ENCRYPTION_FORWARD_SECURE); // clang-format off unsigned char packet[] = { // public flags (8 byte connection_id) @@ -12519,6 +12631,7 @@ } TEST_P(QuicFramerTest, FirstAckBlockJustUnderFlow) { + SetDecrypterLevel(ENCRYPTION_FORWARD_SECURE); // clang-format off unsigned char packet[] = { // public flags (8 byte connection_id) @@ -12621,6 +12734,7 @@ } TEST_P(QuicFramerTest, ThirdAckBlockJustUnderflow) { + SetDecrypterLevel(ENCRYPTION_FORWARD_SECURE); // clang-format off unsigned char packet[] = { // public flags (8 byte connection_id) @@ -12769,6 +12883,7 @@ if (!QuicVersionHasLongHeaderLengths(framer_.transport_version())) { return; } + SetDecrypterLevel(ENCRYPTION_ZERO_RTT); // clang-format off unsigned char packet[] = { // first coalesced packet @@ -12863,6 +12978,7 @@ if (!QuicVersionHasLongHeaderLengths(framer_.transport_version())) { return; } + SetDecrypterLevel(ENCRYPTION_ZERO_RTT); // clang-format off unsigned char packet[] = { // first coalesced packet @@ -12945,6 +13061,7 @@ if (!QuicVersionHasLongHeaderLengths(framer_.transport_version())) { return; } + SetDecrypterLevel(ENCRYPTION_ZERO_RTT); // clang-format off unsigned char packet[] = { // first coalesced packet @@ -13005,6 +13122,7 @@ if (framer_.transport_version() < QUIC_VERSION_46) { return; } + SetDecrypterLevel(ENCRYPTION_FORWARD_SECURE); char connection_id_bytes[9] = {0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10, 0x42}; QuicConnectionId connection_id(connection_id_bytes, @@ -13045,6 +13163,7 @@ if (framer_.transport_version() < QUIC_VERSION_46) { return; } + SetDecrypterLevel(ENCRYPTION_ZERO_RTT); framer_.SetShouldUpdateExpectedConnectionIdLength(true); // clang-format off @@ -13099,6 +13218,7 @@ EXPECT_EQ(visitor_.header_.get()->packet_number, QuicPacketNumber(UINT64_C(0x12345678))); + SetDecrypterLevel(ENCRYPTION_FORWARD_SECURE); // clang-format off unsigned char short_header_packet[] = { // type (short header, 4 byte packet number) @@ -13166,7 +13286,13 @@ }; // clang-format on - framer_.SetDecrypter(ENCRYPTION_ZERO_RTT, QuicMakeUnique<TestDecrypter>()); + if (framer_.version().KnowsWhichDecrypterToUse()) { + framer_.InstallDecrypter(ENCRYPTION_ZERO_RTT, + QuicMakeUnique<TestDecrypter>()); + framer_.RemoveDecrypter(ENCRYPTION_INITIAL); + } else { + framer_.SetDecrypter(ENCRYPTION_ZERO_RTT, QuicMakeUnique<TestDecrypter>()); + } if (!QuicVersionHasLongHeaderLengths(framer_.transport_version())) { EXPECT_TRUE(framer_.ProcessPacket( QuicEncryptedPacket(AsChars(long_header_packet), @@ -13202,8 +13328,14 @@ QuicEncryptedPacket short_header_encrypted( AsChars(short_header_packet), QUIC_ARRAYSIZE(short_header_packet), false); - framer_.SetDecrypter(ENCRYPTION_FORWARD_SECURE, - QuicMakeUnique<TestDecrypter>()); + if (framer_.version().KnowsWhichDecrypterToUse()) { + framer_.InstallDecrypter(ENCRYPTION_FORWARD_SECURE, + QuicMakeUnique<TestDecrypter>()); + framer_.RemoveDecrypter(ENCRYPTION_ZERO_RTT); + } else { + framer_.SetDecrypter(ENCRYPTION_FORWARD_SECURE, + QuicMakeUnique<TestDecrypter>()); + } EXPECT_TRUE(framer_.ProcessPacket(short_header_encrypted)); EXPECT_EQ(QUIC_NO_ERROR, framer_.error());
diff --git a/quic/core/quic_packet_creator_test.cc b/quic/core/quic_packet_creator_test.cc index b642c9f..2b71b06 100644 --- a/quic/core/quic_packet_creator_test.cc +++ b/quic/core/quic_packet_creator_test.cc
@@ -9,6 +9,7 @@ #include <ostream> #include <string> +#include "net/third_party/quiche/src/quic/core/crypto/null_decrypter.h" #include "net/third_party/quiche/src/quic/core/crypto/null_encrypter.h" #include "net/third_party/quiche/src/quic/core/crypto/quic_decrypter.h" #include "net/third_party/quiche/src/quic/core/crypto/quic_encrypter.h" @@ -164,6 +165,17 @@ client_framer_.set_visitor(&framer_visitor_); server_framer_.set_visitor(&framer_visitor_); client_framer_.set_data_producer(&producer_); + if (server_framer_.version().KnowsWhichDecrypterToUse()) { + server_framer_.InstallDecrypter( + ENCRYPTION_ZERO_RTT, + QuicMakeUnique<NullDecrypter>(Perspective::IS_SERVER)); + server_framer_.InstallDecrypter( + ENCRYPTION_HANDSHAKE, + QuicMakeUnique<NullDecrypter>(Perspective::IS_SERVER)); + server_framer_.InstallDecrypter( + ENCRYPTION_FORWARD_SECURE, + QuicMakeUnique<NullDecrypter>(Perspective::IS_SERVER)); + } } ~QuicPacketCreatorTest() override {
diff --git a/quic/core/quic_packet_generator_test.cc b/quic/core/quic_packet_generator_test.cc index 223383a..1e12424 100644 --- a/quic/core/quic_packet_generator_test.cc +++ b/quic/core/quic_packet_generator_test.cc
@@ -9,6 +9,7 @@ #include <string> #include "net/third_party/quiche/src/quic/core/crypto/crypto_protocol.h" +#include "net/third_party/quiche/src/quic/core/crypto/null_decrypter.h" #include "net/third_party/quiche/src/quic/core/crypto/null_encrypter.h" #include "net/third_party/quiche/src/quic/core/crypto/quic_decrypter.h" #include "net/third_party/quiche/src/quic/core/crypto/quic_encrypter.h" @@ -215,6 +216,11 @@ QuicMakeUnique<NullEncrypter>(Perspective::IS_CLIENT)); creator_->set_encryption_level(ENCRYPTION_FORWARD_SECURE); framer_.set_data_producer(&producer_); + if (simple_framer_.framer()->version().KnowsWhichDecrypterToUse()) { + simple_framer_.framer()->InstallDecrypter( + ENCRYPTION_FORWARD_SECURE, + QuicMakeUnique<NullDecrypter>(Perspective::IS_SERVER)); + } generator_.AttachPacketFlusher(); }
diff --git a/quic/core/quic_versions.cc b/quic/core/quic_versions.cc index fa9f640..ef2657c 100644 --- a/quic/core/quic_versions.cc +++ b/quic/core/quic_versions.cc
@@ -37,6 +37,11 @@ } } +bool ParsedQuicVersion::KnowsWhichDecrypterToUse() const { + return transport_version == QUIC_VERSION_99 || + handshake_protocol == PROTOCOL_TLS1_3; +} + std::ostream& operator<<(std::ostream& os, const ParsedQuicVersion& version) { os << ParsedQuicVersionToString(version); return os;
diff --git a/quic/core/quic_versions.h b/quic/core/quic_versions.h index cc780b4..160adeb 100644 --- a/quic/core/quic_versions.h +++ b/quic/core/quic_versions.h
@@ -145,6 +145,8 @@ return handshake_protocol != other.handshake_protocol || transport_version != other.transport_version; } + + bool KnowsWhichDecrypterToUse() const; }; QUIC_EXPORT_PRIVATE ParsedQuicVersion UnsupportedQuicVersion();
diff --git a/quic/core/tls_client_handshaker.cc b/quic/core/tls_client_handshaker.cc index 5081a48..932d537 100644 --- a/quic/core/tls_client_handshaker.cc +++ b/quic/core/tls_client_handshaker.cc
@@ -73,8 +73,8 @@ session()->connection_id(), &crypters); session()->connection()->SetEncrypter(ENCRYPTION_INITIAL, std::move(crypters.encrypter)); - session()->connection()->SetDecrypter(ENCRYPTION_INITIAL, - std::move(crypters.decrypter)); + session()->connection()->InstallDecrypter(ENCRYPTION_INITIAL, + std::move(crypters.decrypter)); state_ = STATE_HANDSHAKE_RUNNING; // Configure certificate verification. // TODO(nharper): This only verifies certs on initial connection, not on
diff --git a/quic/core/tls_handshaker.cc b/quic/core/tls_handshaker.cc index c6394b8..3a45f36 100644 --- a/quic/core/tls_handshaker.cc +++ b/quic/core/tls_handshaker.cc
@@ -205,22 +205,8 @@ const std::vector<uint8_t>& write_secret) { std::unique_ptr<QuicEncrypter> encrypter = CreateEncrypter(write_secret); session()->connection()->SetEncrypter(level, std::move(encrypter)); - if (level != ENCRYPTION_FORWARD_SECURE) { - std::unique_ptr<QuicDecrypter> decrypter = CreateDecrypter(read_secret); - session()->connection()->SetDecrypter(level, std::move(decrypter)); - } else { - // When forward-secure read keys are available, they get set as the - // alternative decrypter instead of the primary decrypter. One reason for - // this is that after the forward secure keys become available, the server - // still has crypto handshake messages to read at the handshake encryption - // level, meaning that both the ENCRYPTION_ZERO_RTT and - // ENCRYPTION_FORWARD_SECURE decrypters need to be available. (Tests also - // assume that an alternative decrypter gets set, so at some point we need - // to call SetAlternativeDecrypter.) - std::unique_ptr<QuicDecrypter> decrypter = CreateDecrypter(read_secret); - session()->connection()->SetAlternativeDecrypter( - level, std::move(decrypter), /*latch_once_used*/ true); - } + std::unique_ptr<QuicDecrypter> decrypter = CreateDecrypter(read_secret); + session()->connection()->InstallDecrypter(level, std::move(decrypter)); } void TlsHandshaker::WriteMessage(EncryptionLevel level, QuicStringPiece data) {
diff --git a/quic/core/tls_handshaker_test.cc b/quic/core/tls_handshaker_test.cc index 6aa83d2..a7b2aa8 100644 --- a/quic/core/tls_handshaker_test.cc +++ b/quic/core/tls_handshaker_test.cc
@@ -263,15 +263,31 @@ } } +ParsedQuicVersionVector AllTlsSupportedVersions() { + SetQuicReloadableFlag(quic_enable_version_99, true); + SetQuicFlag(&FLAGS_quic_supports_tls_handshake, true); + ParsedQuicVersionVector supported_versions; + for (QuicTransportVersion version : kSupportedTransportVersions) { + if (!QuicVersionUsesCryptoFrames(version)) { + // The TLS handshake is only deployable if CRYPTO frames are also used. + continue; + } + supported_versions.push_back(ParsedQuicVersion(PROTOCOL_TLS1_3, version)); + } + return supported_versions; +} + class TlsHandshakerTest : public QuicTest { public: TlsHandshakerTest() : client_conn_(new MockQuicConnection(&conn_helper_, &alarm_factory_, - Perspective::IS_CLIENT)), + Perspective::IS_CLIENT, + AllTlsSupportedVersions())), server_conn_(new MockQuicConnection(&conn_helper_, &alarm_factory_, - Perspective::IS_SERVER)), + Perspective::IS_SERVER, + AllTlsSupportedVersions())), client_session_(client_conn_, /*create_mock_crypto_stream=*/false), server_session_(server_conn_, /*create_mock_crypto_stream=*/false) { client_stream_ = new TestQuicCryptoClientStream(&client_session_);
diff --git a/quic/core/tls_server_handshaker.cc b/quic/core/tls_server_handshaker.cc index 96e1802..ec254b5 100644 --- a/quic/core/tls_server_handshaker.cc +++ b/quic/core/tls_server_handshaker.cc
@@ -63,14 +63,16 @@ : TlsHandshaker(stream, session, ssl_ctx), proof_source_(proof_source), crypto_negotiated_params_(new QuicCryptoNegotiatedParameters) { + DCHECK_EQ(PROTOCOL_TLS1_3, + session->connection()->version().handshake_protocol); CrypterPair crypters; CryptoUtils::CreateTlsInitialCrypters( Perspective::IS_SERVER, session->connection()->transport_version(), session->connection_id(), &crypters); session->connection()->SetEncrypter(ENCRYPTION_INITIAL, std::move(crypters.encrypter)); - session->connection()->SetDecrypter(ENCRYPTION_INITIAL, - std::move(crypters.decrypter)); + session->connection()->InstallDecrypter(ENCRYPTION_INITIAL, + std::move(crypters.decrypter)); // Configure the SSL to be a server. SSL_set_accept_state(ssl());
diff --git a/quic/test_tools/crypto_test_utils.cc b/quic/test_tools/crypto_test_utils.cc index af33d4e..63154b9 100644 --- a/quic/test_tools/crypto_test_utils.cc +++ b/quic/test_tools/crypto_test_utils.cc
@@ -749,6 +749,11 @@ void CompareCrypters(const QuicEncrypter* encrypter, const QuicDecrypter* decrypter, std::string label) { + if (encrypter == nullptr || decrypter == nullptr) { + ADD_FAILURE() << "Expected non-null crypters; have " << encrypter << " and " + << decrypter; + return; + } QuicStringPiece encrypter_key = encrypter->GetKey(); QuicStringPiece encrypter_iv = encrypter->GetNoncePrefix(); QuicStringPiece decrypter_key = decrypter->GetKey();
diff --git a/quic/test_tools/quic_test_utils.cc b/quic/test_tools/quic_test_utils.cc index cd54eb8..2de1491 100644 --- a/quic/test_tools/quic_test_utils.cc +++ b/quic/test_tools/quic_test_utils.cc
@@ -499,7 +499,7 @@ : QuicSession(connection, nullptr, DefaultQuicConfig(), - CurrentSupportedVersions()) { + connection->supported_versions()) { if (create_mock_crypto_stream) { crypto_stream_ = QuicMakeUnique<MockQuicCryptoStream>(this); } @@ -923,6 +923,11 @@ header.reset_flag = reset_flag; header.packet_number_length = packet_number_length; header.packet_number = QuicPacketNumber(packet_number); + if (QuicVersionHasLongHeaderLengths((*versions)[0].transport_version) && + version_flag) { + header.retry_token_length_length = VARIABLE_LENGTH_INTEGER_LENGTH_1; + header.length_length = VARIABLE_LENGTH_INTEGER_LENGTH_2; + } QuicFrame frame(QuicStreamFrame(1, false, 0, QuicStringPiece(data))); QuicFrames frames; frames.push_back(frame); @@ -941,8 +946,7 @@ GetIncludedDestinationConnectionIdLength(header), GetIncludedSourceConnectionIdLength(header), version_flag, false /* no diversification nonce */, packet_number_length, - VARIABLE_LENGTH_INTEGER_LENGTH_0, 0, VARIABLE_LENGTH_INTEGER_LENGTH_0)] = - 0x1F; + header.retry_token_length_length, 0, header.length_length)] = 0x1F; char* buffer = new char[kMaxOutgoingPacketSize]; size_t encrypted_length =
diff --git a/quic/test_tools/simulator/quic_endpoint.cc b/quic/test_tools/simulator/quic_endpoint.cc index 1072724..9e5c3fa 100644 --- a/quic/test_tools/simulator/quic_endpoint.cc +++ b/quic/test_tools/simulator/quic_endpoint.cc
@@ -86,8 +86,14 @@ connection_.set_visitor(this); connection_.SetEncrypter(ENCRYPTION_FORWARD_SECURE, QuicMakeUnique<NullEncrypter>(perspective)); - connection_.SetDecrypter(ENCRYPTION_FORWARD_SECURE, - QuicMakeUnique<NullDecrypter>(perspective)); + if (connection_.version().KnowsWhichDecrypterToUse()) { + connection_.InstallDecrypter(ENCRYPTION_FORWARD_SECURE, + QuicMakeUnique<NullDecrypter>(perspective)); + connection_.RemoveDecrypter(ENCRYPTION_INITIAL); + } else { + connection_.SetDecrypter(ENCRYPTION_FORWARD_SECURE, + QuicMakeUnique<NullDecrypter>(perspective)); + } connection_.SetDefaultEncryptionLevel(ENCRYPTION_FORWARD_SECURE); if (perspective == Perspective::IS_SERVER) { // Skip version negotiation.