Implement QUIC Header Protection gfe-relnote: Protected by QUIC_VERSION_99 PiperOrigin-RevId: 247137283 Change-Id: I1deb08d304b7739c3c8fa6b995e55fbd8652dc1e
diff --git a/quic/core/chlo_extractor.cc b/quic/core/chlo_extractor.cc index 602940c..2398693 100644 --- a/quic/core/chlo_extractor.cc +++ b/quic/core/chlo_extractor.cc
@@ -5,8 +5,10 @@ #include "net/third_party/quiche/src/quic/core/chlo_extractor.h" #include "net/third_party/quiche/src/quic/core/crypto/crypto_framer.h" +#include "net/third_party/quiche/src/quic/core/crypto/crypto_handshake.h" #include "net/third_party/quiche/src/quic/core/crypto/crypto_handshake_message.h" #include "net/third_party/quiche/src/quic/core/crypto/crypto_protocol.h" +#include "net/third_party/quiche/src/quic/core/crypto/crypto_utils.h" #include "net/third_party/quiche/src/quic/core/crypto/quic_decrypter.h" #include "net/third_party/quiche/src/quic/core/crypto/quic_encrypter.h" #include "net/third_party/quiche/src/quic/core/quic_framer.h" @@ -116,6 +118,23 @@ bool ChloFramerVisitor::OnUnauthenticatedPublicHeader( const QuicPacketHeader& header) { connection_id_ = header.destination_connection_id; + // QuicFramer creates a NullEncrypter and NullDecrypter at level + // ENCRYPTION_INITIAL, which are the correct ones to use with the QUIC Crypto + // handshake. When the TLS handshake is used, the IETF-style initial crypters + // are used instead, so those need to be created and installed. + if (header.version.handshake_protocol == PROTOCOL_TLS1_3) { + CrypterPair crypters; + CryptoUtils::CreateTlsInitialCrypters( + Perspective::IS_SERVER, header.version.transport_version, + header.destination_connection_id, &crypters); + framer_->SetEncrypter(ENCRYPTION_INITIAL, std::move(crypters.encrypter)); + if (framer_->version().KnowsWhichDecrypterToUse()) { + framer_->InstallDecrypter(ENCRYPTION_INITIAL, + std::move(crypters.decrypter)); + } else { + framer_->SetDecrypter(ENCRYPTION_INITIAL, std::move(crypters.decrypter)); + } + } return true; } bool ChloFramerVisitor::OnUnauthenticatedHeader(
diff --git a/quic/core/chlo_extractor_test.cc b/quic/core/chlo_extractor_test.cc index cea445d..e5e084c 100644 --- a/quic/core/chlo_extractor_test.cc +++ b/quic/core/chlo_extractor_test.cc
@@ -70,6 +70,14 @@ } QuicFramer framer(SupportedVersions(header_.version), QuicTime::Zero(), Perspective::IS_CLIENT, kQuicDefaultConnectionIdLength); + if (version.handshake_protocol == PROTOCOL_TLS1_3) { + CrypterPair crypters; + CryptoUtils::CreateTlsInitialCrypters(Perspective::IS_CLIENT, + version.transport_version, + TestConnectionId(), &crypters); + framer.SetEncrypter(ENCRYPTION_INITIAL, std::move(crypters.encrypter)); + framer.SetDecrypter(ENCRYPTION_INITIAL, std::move(crypters.decrypter)); + } if (!QuicVersionUsesCryptoFrames(version.transport_version) || munge_stream_id) { QuicStreamId stream_id =
diff --git a/quic/core/crypto/aes_base_encrypter.cc b/quic/core/crypto/aes_base_encrypter.cc index 49fc71d..cdb21b8 100644 --- a/quic/core/crypto/aes_base_encrypter.cc +++ b/quic/core/crypto/aes_base_encrypter.cc
@@ -11,7 +11,7 @@ bool AesBaseEncrypter::SetHeaderProtectionKey(QuicStringPiece key) { if (key.size() != GetKeySize()) { - QUIC_BUG << "Invalid key size for header protection"; + QUIC_BUG << "Invalid key size for header protection: " << key.size(); return false; } if (AES_set_encrypt_key(reinterpret_cast<const uint8_t*>(key.data()),
diff --git a/quic/core/crypto/crypto_utils.cc b/quic/core/crypto/crypto_utils.cc index 724f245..57edda0 100644 --- a/quic/core/crypto/crypto_utils.cc +++ b/quic/core/crypto/crypto_utils.cc
@@ -76,10 +76,14 @@ prf, pp_secret, "quic key", crypter->GetKeySize()); std::vector<uint8_t> iv = CryptoUtils::HkdfExpandLabel( prf, pp_secret, "quic iv", crypter->GetIVSize()); + std::vector<uint8_t> pn = CryptoUtils::HkdfExpandLabel( + prf, pp_secret, "quic hp", crypter->GetKeySize()); crypter->SetKey( QuicStringPiece(reinterpret_cast<char*>(key.data()), key.size())); crypter->SetIV( QuicStringPiece(reinterpret_cast<char*>(iv.data()), iv.size())); + crypter->SetHeaderProtectionKey( + QuicStringPiece(reinterpret_cast<char*>(pn.data()), pn.size())); } namespace { @@ -224,15 +228,23 @@ if (perspective == Perspective::IS_SERVER) { if (!crypters->encrypter->SetKey(hkdf.server_write_key()) || !crypters->encrypter->SetNoncePrefix(hkdf.server_write_iv()) || + !crypters->encrypter->SetHeaderProtectionKey( + hkdf.server_hp_key()) || !crypters->decrypter->SetKey(hkdf.client_write_key()) || - !crypters->decrypter->SetNoncePrefix(hkdf.client_write_iv())) { + !crypters->decrypter->SetNoncePrefix(hkdf.client_write_iv()) || + !crypters->decrypter->SetHeaderProtectionKey( + hkdf.client_hp_key())) { return false; } } else { if (!crypters->encrypter->SetKey(hkdf.client_write_key()) || !crypters->encrypter->SetNoncePrefix(hkdf.client_write_iv()) || + !crypters->encrypter->SetHeaderProtectionKey( + hkdf.client_hp_key()) || !crypters->decrypter->SetKey(hkdf.server_write_key()) || - !crypters->decrypter->SetNoncePrefix(hkdf.server_write_iv())) { + !crypters->decrypter->SetNoncePrefix(hkdf.server_write_iv()) || + !crypters->decrypter->SetHeaderProtectionKey( + hkdf.server_hp_key())) { return false; } } @@ -246,8 +258,10 @@ if (!crypters->encrypter->SetKey(hkdf.client_write_key()) || !crypters->encrypter->SetNoncePrefix(hkdf.client_write_iv()) || + !crypters->encrypter->SetHeaderProtectionKey(hkdf.client_hp_key()) || !crypters->decrypter->SetPreliminaryKey(hkdf.server_write_key()) || - !crypters->decrypter->SetNoncePrefix(hkdf.server_write_iv())) { + !crypters->decrypter->SetNoncePrefix(hkdf.server_write_iv()) || + !crypters->decrypter->SetHeaderProtectionKey(hkdf.server_hp_key())) { return false; } break; @@ -265,8 +279,10 @@ &nonce_prefix); if (!crypters->decrypter->SetKey(hkdf.client_write_key()) || !crypters->decrypter->SetNoncePrefix(hkdf.client_write_iv()) || + !crypters->decrypter->SetHeaderProtectionKey(hkdf.client_hp_key()) || !crypters->encrypter->SetKey(key) || - !crypters->encrypter->SetNoncePrefix(nonce_prefix)) { + !crypters->encrypter->SetNoncePrefix(nonce_prefix) || + !crypters->encrypter->SetHeaderProtectionKey(hkdf.server_hp_key())) { return false; } break;
diff --git a/quic/core/crypto/quic_hkdf.cc b/quic/core/crypto/quic_hkdf.cc index 3754cab..1bd9ad5 100644 --- a/quic/core/crypto/quic_hkdf.cc +++ b/quic/core/crypto/quic_hkdf.cc
@@ -39,8 +39,8 @@ size_t server_iv_bytes_to_generate, size_t subkey_secret_bytes_to_generate) { const size_t material_length = - client_key_bytes_to_generate + client_iv_bytes_to_generate + - server_key_bytes_to_generate + server_iv_bytes_to_generate + + 2 * client_key_bytes_to_generate + client_iv_bytes_to_generate + + 2 * server_key_bytes_to_generate + server_iv_bytes_to_generate + subkey_secret_bytes_to_generate; DCHECK_LT(material_length, kMaxKeyMaterialSize); @@ -85,6 +85,19 @@ if (subkey_secret_bytes_to_generate) { subkey_secret_ = QuicStringPiece(reinterpret_cast<char*>(&output_[j]), subkey_secret_bytes_to_generate); + j += subkey_secret_bytes_to_generate; + } + // Repeat client and server key bytes for header protection keys. + if (client_key_bytes_to_generate) { + client_hp_key_ = QuicStringPiece(reinterpret_cast<char*>(&output_[j]), + client_key_bytes_to_generate); + j += client_key_bytes_to_generate; + } + + if (server_key_bytes_to_generate) { + server_hp_key_ = QuicStringPiece(reinterpret_cast<char*>(&output_[j]), + server_key_bytes_to_generate); + j += server_key_bytes_to_generate; } }
diff --git a/quic/core/crypto/quic_hkdf.h b/quic/core/crypto/quic_hkdf.h index fb80f7b..c57b894 100644 --- a/quic/core/crypto/quic_hkdf.h +++ b/quic/core/crypto/quic_hkdf.h
@@ -54,6 +54,8 @@ QuicStringPiece server_write_key() const { return server_write_key_; } QuicStringPiece server_write_iv() const { return server_write_iv_; } QuicStringPiece subkey_secret() const { return subkey_secret_; } + QuicStringPiece client_hp_key() const { return client_hp_key_; } + QuicStringPiece server_hp_key() const { return server_hp_key_; } private: std::vector<uint8_t> output_; @@ -63,6 +65,8 @@ QuicStringPiece client_write_iv_; QuicStringPiece server_write_iv_; QuicStringPiece subkey_secret_; + QuicStringPiece client_hp_key_; + QuicStringPiece server_hp_key_; }; } // namespace quic
diff --git a/quic/core/http/quic_spdy_client_session_test.cc b/quic/core/http/quic_spdy_client_session_test.cc index 7f4d8dc..2787175 100644 --- a/quic/core/http/quic_spdy_client_session_test.cc +++ b/quic/core/http/quic_spdy_client_session_test.cc
@@ -8,7 +8,8 @@ #include <string> #include <vector> -#include "net/third_party/quiche/src/quic/core/crypto/aes_128_gcm_12_encrypter.h" +#include "net/third_party/quiche/src/quic/core/crypto/null_decrypter.h" +#include "net/third_party/quiche/src/quic/core/crypto/null_encrypter.h" #include "net/third_party/quiche/src/quic/core/http/quic_spdy_client_stream.h" #include "net/third_party/quiche/src/quic/core/http/spdy_utils.h" #include "net/third_party/quiche/src/quic/core/quic_utils.h" @@ -528,6 +529,15 @@ TEST_P(QuicSpdyClientSessionTest, InvalidFramedPacketReceived) { QuicSocketAddress server_address(TestPeerIPAddress(), kTestPort); QuicSocketAddress client_address(TestPeerIPAddress(), kTestPort); + if (GetParam().KnowsWhichDecrypterToUse()) { + connection_->InstallDecrypter( + ENCRYPTION_FORWARD_SECURE, + QuicMakeUnique<NullDecrypter>(Perspective::IS_CLIENT)); + } else { + connection_->SetDecrypter( + ENCRYPTION_FORWARD_SECURE, + QuicMakeUnique<NullDecrypter>(Perspective::IS_CLIENT)); + } EXPECT_CALL(*connection_, ProcessUdpPacket(server_address, client_address, _)) .WillRepeatedly(Invoke(static_cast<MockQuicConnection*>(connection_),
diff --git a/quic/core/quic_connection_test.cc b/quic/core/quic_connection_test.cc index dad38ab..80a7469 100644 --- a/quic/core/quic_connection_test.cc +++ b/quic/core/quic_connection_test.cc
@@ -65,8 +65,8 @@ namespace test { namespace { -const char data1[] = "foo"; -const char data2[] = "bar"; +const char data1[] = "foo data"; +const char data2[] = "bar data"; const bool kHasStopWaiting = true; @@ -1051,8 +1051,9 @@ frames.push_back(QuicFrame(frame)); QuicPacketCreatorPeer::SetSendVersionInPacket( &peer_creator_, connection_.perspective() == Perspective::IS_SERVER); - if (QuicPacketCreatorPeer::GetEncryptionLevel(&peer_creator_) > - ENCRYPTION_INITIAL) { + EncryptionLevel peer_encryption_level = + QuicPacketCreatorPeer::GetEncryptionLevel(&peer_creator_); + if (peer_encryption_level > ENCRYPTION_INITIAL) { // Set peer_framer_'s corresponding encrypter. peer_creator_.SetEncrypter( QuicPacketCreatorPeer::GetEncryptionLevel(&peer_creator_), @@ -2542,7 +2543,8 @@ OnPacketSent(_, _, _, _, HAS_RETRANSMITTABLE_DATA)); connection_.SendStreamDataWithString(3, "foo", 6, NO_FIN); // No ack sent. - EXPECT_EQ(1u, writer_->frame_count()); + size_t padding_frame_count = writer_->padding_frames().size(); + EXPECT_EQ(padding_frame_count + 1u, writer_->frame_count()); EXPECT_EQ(1u, writer_->stream_frames().size()); // No more packet loss for the rest of the test. @@ -2980,7 +2982,8 @@ EXPECT_FALSE(connection_.HasQueuedData()); // Parse the last packet and ensure it's the stream frame from stream 3. - EXPECT_EQ(1u, writer_->frame_count()); + size_t padding_frame_count = writer_->padding_frames().size(); + EXPECT_EQ(padding_frame_count + 1u, writer_->frame_count()); ASSERT_EQ(1u, writer_->stream_frames().size()); EXPECT_EQ(GetNthClientInitiatedStreamId(1, connection_.transport_version()), writer_->stream_frames()[0]->stream_id); @@ -3117,9 +3120,16 @@ EXPECT_EQ(0u, connection_.NumQueuedPackets()); EXPECT_FALSE(connection_.HasQueuedData()); + // Padding frames are added by v99 to ensure a minimum packet size. + size_t extra_padding_frames = 0; + if (GetParam().version.HasHeaderProtection()) { + extra_padding_frames = 1; + } + // Parse the last packet and ensure it's one stream frame from one stream. - EXPECT_EQ(1u, writer_->frame_count()); - EXPECT_EQ(1u, writer_->stream_frames().size()); + EXPECT_EQ(1u + extra_padding_frames, writer_->frame_count()); + EXPECT_EQ(extra_padding_frames, writer_->padding_frames().size()); + ASSERT_EQ(1u, writer_->stream_frames().size()); EXPECT_EQ(QuicUtils::GetHeadersStreamId(connection_.transport_version()), writer_->stream_frames()[0]->stream_id); EXPECT_TRUE(writer_->stream_frames()[0]->fin); @@ -3157,7 +3167,7 @@ // Parse the last packet and ensure it's one stream frame with a fin. EXPECT_EQ(1u, writer_->frame_count()); - EXPECT_EQ(1u, writer_->stream_frames().size()); + ASSERT_EQ(1u, writer_->stream_frames().size()); EXPECT_EQ(QuicUtils::GetHeadersStreamId(connection_.transport_version()), writer_->stream_frames()[0]->stream_id); EXPECT_TRUE(writer_->stream_frames()[0]->fin); @@ -3250,7 +3260,8 @@ connection_.SendControlFrame(QuicFrame(new QuicRstStreamFrame( 1, stream_id, QUIC_ERROR_PROCESSING_STREAM, 14))); } - EXPECT_EQ(1u, writer_->frame_count()); + size_t padding_frame_count = writer_->padding_frames().size(); + EXPECT_EQ(padding_frame_count + 1u, writer_->frame_count()); EXPECT_EQ(1u, writer_->rst_stream_frames().size()); } @@ -3277,7 +3288,8 @@ connection_.SendControlFrame(QuicFrame( new QuicRstStreamFrame(1, stream_id, QUIC_STREAM_NO_ERROR, 14))); } - EXPECT_EQ(1u, writer_->frame_count()); + size_t padding_frame_count = writer_->padding_frames().size(); + EXPECT_EQ(padding_frame_count + 1u, writer_->frame_count()); EXPECT_EQ(1u, writer_->rst_stream_frames().size()); } @@ -3343,7 +3355,8 @@ EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)).Times(1); clock_.AdvanceTime(DefaultRetransmissionTime()); connection_.GetRetransmissionAlarm()->Fire(); - EXPECT_EQ(1u, writer_->frame_count()); + size_t padding_frame_count = writer_->padding_frames().size(); + EXPECT_EQ(padding_frame_count + 1u, writer_->frame_count()); EXPECT_EQ(1u, writer_->rst_stream_frames().size()); EXPECT_EQ(stream_id, writer_->rst_stream_frames().front().stream_id); } @@ -3402,7 +3415,8 @@ EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)).Times(AtLeast(2)); clock_.AdvanceTime(DefaultRetransmissionTime()); connection_.GetRetransmissionAlarm()->Fire(); - EXPECT_EQ(1u, writer_->frame_count()); + size_t padding_frame_count = writer_->padding_frames().size(); + EXPECT_EQ(padding_frame_count + 1u, writer_->frame_count()); ASSERT_EQ(1u, writer_->rst_stream_frames().size()); EXPECT_EQ(stream_id, writer_->rst_stream_frames().front().stream_id); } @@ -3438,7 +3452,8 @@ connection_.SendControlFrame(QuicFrame(new QuicRstStreamFrame( 1, stream_id, QUIC_ERROR_PROCESSING_STREAM, 14))); } - EXPECT_EQ(1u, writer_->frame_count()); + size_t padding_frame_count = writer_->padding_frames().size(); + EXPECT_EQ(padding_frame_count + 1u, writer_->frame_count()); ASSERT_EQ(1u, writer_->rst_stream_frames().size()); EXPECT_EQ(stream_id, writer_->rst_stream_frames().front().stream_id); } @@ -3476,7 +3491,8 @@ // retransmission. connection_.SendControlFrame(QuicFrame( new QuicRstStreamFrame(1, stream_id, QUIC_STREAM_NO_ERROR, 14))); - EXPECT_EQ(1u, writer_->frame_count()); + size_t padding_frame_count = writer_->padding_frames().size(); + EXPECT_EQ(padding_frame_count + 1u, writer_->frame_count()); EXPECT_EQ(1u, writer_->rst_stream_frames().size()); } @@ -4655,7 +4671,8 @@ clock_.AdvanceTime(QuicTime::Delta::FromSeconds(15)); EXPECT_CALL(visitor_, SendPing()).WillOnce(Invoke([this]() { SendPing(); })); connection_.GetPingAlarm()->Fire(); - EXPECT_EQ(1u, writer_->frame_count()); + size_t padding_frame_count = writer_->padding_frames().size(); + EXPECT_EQ(padding_frame_count + 1u, writer_->frame_count()); ASSERT_EQ(1u, writer_->ping_frames().size()); writer_->Reset(); @@ -4710,7 +4727,8 @@ connection_.SendControlFrame(QuicFrame(QuicPingFrame(1))); })); connection_.GetPingAlarm()->Fire(); - EXPECT_EQ(1u, writer_->frame_count()); + size_t padding_frame_count = writer_->padding_frames().size(); + EXPECT_EQ(padding_frame_count + 1u, writer_->frame_count()); ASSERT_EQ(1u, writer_->ping_frames().size()); writer_->Reset(); @@ -5679,11 +5697,12 @@ clock_.AdvanceTime(DefaultDelayedAckTime()); connection_.GetAckAlarm()->Fire(); // Check that ack is sent and that delayed ack alarm is reset. + size_t padding_frame_count = writer_->padding_frames().size(); if (GetParam().no_stop_waiting) { - EXPECT_EQ(1u, writer_->frame_count()); + EXPECT_EQ(padding_frame_count + 1u, writer_->frame_count()); EXPECT_TRUE(writer_->stop_waiting_frames().empty()); } else { - EXPECT_EQ(2u, writer_->frame_count()); + EXPECT_EQ(padding_frame_count + 2u, writer_->frame_count()); EXPECT_FALSE(writer_->stop_waiting_frames().empty()); } EXPECT_FALSE(writer_->ack_frames().empty()); @@ -5721,11 +5740,12 @@ clock_.AdvanceTime(DefaultDelayedAckTime()); connection_.GetAckAlarm()->Fire(); // Check that ack is sent and that delayed ack alarm is reset. + size_t padding_frame_count = writer_->padding_frames().size(); if (GetParam().no_stop_waiting) { - EXPECT_EQ(1u, writer_->frame_count()); + EXPECT_EQ(padding_frame_count + 1u, writer_->frame_count()); EXPECT_TRUE(writer_->stop_waiting_frames().empty()); } else { - EXPECT_EQ(2u, writer_->frame_count()); + EXPECT_EQ(padding_frame_count + 2u, writer_->frame_count()); EXPECT_FALSE(writer_->stop_waiting_frames().empty()); } EXPECT_FALSE(writer_->ack_frames().empty()); @@ -5744,11 +5764,12 @@ clock_.AdvanceTime(DefaultDelayedAckTime()); connection_.GetAckAlarm()->Fire(); // Check that ack is sent and that delayed ack alarm is reset. + padding_frame_count = writer_->padding_frames().size(); if (GetParam().no_stop_waiting) { - EXPECT_EQ(1u, writer_->frame_count()); + EXPECT_EQ(padding_frame_count + 1u, writer_->frame_count()); EXPECT_TRUE(writer_->stop_waiting_frames().empty()); } else { - EXPECT_EQ(2u, writer_->frame_count()); + EXPECT_EQ(padding_frame_count + 2u, writer_->frame_count()); EXPECT_FALSE(writer_->stop_waiting_frames().empty()); } EXPECT_FALSE(writer_->ack_frames().empty()); @@ -5864,11 +5885,12 @@ clock_.AdvanceTime(DefaultDelayedAckTime()); connection_.GetAckAlarm()->Fire(); // Check that ack is sent and that delayed ack alarm is reset. + size_t padding_frame_count = writer_->padding_frames().size(); if (GetParam().no_stop_waiting) { - EXPECT_EQ(1u, writer_->frame_count()); + EXPECT_EQ(padding_frame_count + 1u, writer_->frame_count()); EXPECT_TRUE(writer_->stop_waiting_frames().empty()); } else { - EXPECT_EQ(2u, writer_->frame_count()); + EXPECT_EQ(padding_frame_count + 2u, writer_->frame_count()); EXPECT_FALSE(writer_->stop_waiting_frames().empty()); } EXPECT_FALSE(writer_->ack_frames().empty()); @@ -5887,11 +5909,12 @@ clock_.AdvanceTime(DefaultDelayedAckTime()); connection_.GetAckAlarm()->Fire(); // Check that ack is sent and that delayed ack alarm is reset. + padding_frame_count = writer_->padding_frames().size(); if (GetParam().no_stop_waiting) { - EXPECT_EQ(1u, writer_->frame_count()); + EXPECT_EQ(padding_frame_count + 1u, writer_->frame_count()); EXPECT_TRUE(writer_->stop_waiting_frames().empty()); } else { - EXPECT_EQ(2u, writer_->frame_count()); + EXPECT_EQ(padding_frame_count + 2u, writer_->frame_count()); EXPECT_FALSE(writer_->stop_waiting_frames().empty()); } EXPECT_FALSE(writer_->ack_frames().empty()); @@ -6425,11 +6448,12 @@ ProcessPacket(1); ProcessPacket(2); // Check that ack is sent and that delayed ack alarm is reset. + size_t padding_frame_count = writer_->padding_frames().size(); if (GetParam().no_stop_waiting) { - EXPECT_EQ(1u, writer_->frame_count()); + EXPECT_EQ(padding_frame_count + 1u, writer_->frame_count()); EXPECT_TRUE(writer_->stop_waiting_frames().empty()); } else { - EXPECT_EQ(2u, writer_->frame_count()); + EXPECT_EQ(padding_frame_count + 2u, writer_->frame_count()); EXPECT_FALSE(writer_->stop_waiting_frames().empty()); } EXPECT_FALSE(writer_->ack_frames().empty()); @@ -6450,14 +6474,16 @@ ProcessPacket(2); size_t frames_per_ack = GetParam().no_stop_waiting ? 1 : 2; if (!GetQuicRestartFlag(quic_enable_accept_random_ipn)) { - EXPECT_EQ(frames_per_ack, writer_->frame_count()); + size_t padding_frame_count = writer_->padding_frames().size(); + EXPECT_EQ(padding_frame_count + frames_per_ack, writer_->frame_count()); EXPECT_FALSE(writer_->ack_frames().empty()); writer_->Reset(); } EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)).Times(1); ProcessPacket(3); - EXPECT_EQ(frames_per_ack, writer_->frame_count()); + size_t padding_frame_count = writer_->padding_frames().size(); + EXPECT_EQ(padding_frame_count + frames_per_ack, writer_->frame_count()); EXPECT_FALSE(writer_->ack_frames().empty()); writer_->Reset(); @@ -6470,21 +6496,23 @@ if (GetQuicRestartFlag(quic_enable_accept_random_ipn)) { EXPECT_EQ(0u, writer_->frame_count()); } else { - EXPECT_EQ(frames_per_ack, writer_->frame_count()); + EXPECT_EQ(padding_frame_count + frames_per_ack, writer_->frame_count()); EXPECT_FALSE(writer_->ack_frames().empty()); writer_->Reset(); } EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)).Times(1); ProcessPacket(5); - EXPECT_EQ(frames_per_ack, writer_->frame_count()); + padding_frame_count = writer_->padding_frames().size(); + EXPECT_EQ(padding_frame_count + frames_per_ack, writer_->frame_count()); EXPECT_FALSE(writer_->ack_frames().empty()); writer_->Reset(); EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)).Times(0); // Now only set the timer on the 6th packet, instead of sending another ack. ProcessPacket(6); - EXPECT_EQ(0u, writer_->frame_count()); + padding_frame_count = writer_->padding_frames().size(); + EXPECT_EQ(padding_frame_count, writer_->frame_count()); EXPECT_TRUE(connection_.GetAckAlarm()->IsSet()); } @@ -6643,7 +6671,8 @@ .WillOnce(SetArgPointee<5>(lost_packets)); EXPECT_CALL(*send_algorithm_, OnCongestionEvent(true, _, _, _, _)); ProcessAckPacket(&ack); - EXPECT_EQ(1u, writer_->frame_count()); + size_t padding_frame_count = writer_->padding_frames().size(); + EXPECT_EQ(padding_frame_count + 1u, writer_->frame_count()); EXPECT_EQ(1u, writer_->stream_frames().size()); writer_->Reset(); @@ -7458,7 +7487,8 @@ EXPECT_CALL(*send_algorithm_, OnCongestionEvent(true, _, _, _, _)); EXPECT_CALL(*send_algorithm_, CanSend(_)).WillRepeatedly(Return(true)); ProcessAckPacket(&ack); - EXPECT_EQ(1u, writer_->frame_count()); + size_t padding_frame_count = writer_->padding_frames().size(); + EXPECT_EQ(padding_frame_count + 1u, writer_->frame_count()); EXPECT_EQ(1u, writer_->stream_frames().size()); EXPECT_TRUE(connection_.GetSendAlarm()->IsSet()); EXPECT_EQ(scheduled_pacing_time, connection_.GetSendAlarm()->deadline()); @@ -8339,10 +8369,11 @@ connection_.SendControlFrame(QuicFrame(QuicPingFrame(1))); })); connection_.GetPingAlarm()->Fire(); + size_t padding_frame_count = writer_->padding_frames().size(); if (GetParam().no_stop_waiting) { - EXPECT_EQ(2u, writer_->frame_count()); + EXPECT_EQ(padding_frame_count + 2u, writer_->frame_count()); } else { - EXPECT_EQ(3u, writer_->frame_count()); + EXPECT_EQ(padding_frame_count + 3u, writer_->frame_count()); } ASSERT_EQ(1u, writer_->ping_frames().size()); } @@ -8414,10 +8445,11 @@ connection_.SendControlFrame(QuicFrame(QuicPingFrame(1))); })); connection_.GetPingAlarm()->Fire(); + size_t padding_frame_count = writer_->padding_frames().size(); if (GetParam().no_stop_waiting) { - EXPECT_EQ(2u, writer_->frame_count()); + EXPECT_EQ(padding_frame_count + 2u, writer_->frame_count()); } else { - EXPECT_EQ(3u, writer_->frame_count()); + EXPECT_EQ(padding_frame_count + 3u, writer_->frame_count()); } ASSERT_EQ(1u, writer_->ping_frames().size()); } @@ -8566,6 +8598,7 @@ connection_.peer_address()); // Save the random contents of the challenge for later comparison to the // response. + ASSERT_GE(writer_->path_challenge_frames().size(), 1u); QuicPathFrameBuffer challenge_data = writer_->path_challenge_frames().front().data_buffer;
diff --git a/quic/core/quic_data_reader.cc b/quic/core/quic_data_reader.cc index b13b061..ef09483 100644 --- a/quic/core/quic_data_reader.cc +++ b/quic/core/quic_data_reader.cc
@@ -12,6 +12,9 @@ namespace quic { +QuicDataReader::QuicDataReader(QuicStringPiece data) + : QuicDataReader(data.data(), data.length(), NETWORK_BYTE_ORDER) {} + QuicDataReader::QuicDataReader(const char* data, const size_t len) : QuicDataReader(data, len, NETWORK_BYTE_ORDER) {} @@ -180,6 +183,15 @@ return true; } +bool QuicDataReader::Seek(size_t size) { + if (!CanRead(size)) { + OnFailure(); + return false; + } + pos_ += size; + return true; +} + bool QuicDataReader::IsDoneReading() const { return len_ == pos_; }
diff --git a/quic/core/quic_data_reader.h b/quic/core/quic_data_reader.h index 9e88a7e..a03b927 100644 --- a/quic/core/quic_data_reader.h +++ b/quic/core/quic_data_reader.h
@@ -33,6 +33,9 @@ public: // Constructs a reader using NETWORK_BYTE_ORDER endianness. // Caller must provide an underlying buffer to work on. + explicit QuicDataReader(QuicStringPiece data); + // Constructs a reader using NETWORK_BYTE_ORDER endianness. + // Caller must provide an underlying buffer to work on. QuicDataReader(const char* data, const size_t len); // Constructs a reader using the specified endianness. // Caller must provide an underlying buffer to work on. @@ -108,6 +111,11 @@ // Returns true on success, false otherwise. bool ReadBytes(void* result, size_t size); + // Skips over |size| bytes from the buffer and forwards the internal iterator. + // Returns true if there are at least |size| bytes remaining to read, false + // otherwise. + bool Seek(size_t size); + // Returns true if the entirety of the underlying buffer has been read via // Read*() calls. bool IsDoneReading() const;
diff --git a/quic/core/quic_data_writer.cc b/quic/core/quic_data_writer.cc index 7116853..42f1e4e 100644 --- a/quic/core/quic_data_writer.cc +++ b/quic/core/quic_data_writer.cc
@@ -196,6 +196,14 @@ return true; } +bool QuicDataWriter::Seek(size_t length) { + if (!BeginWrite(length)) { + return false; + } + length_ += length; + return true; +} + // Converts a uint64_t into an IETF/Quic formatted Variable Length // Integer. IETF Variable Length Integers have 62 significant bits, so // the value to write must be in the range of 0..(2^62)-1.
diff --git a/quic/core/quic_data_writer.h b/quic/core/quic_data_writer.h index bd1ded6..d2d2b6b 100644 --- a/quic/core/quic_data_writer.h +++ b/quic/core/quic_data_writer.h
@@ -117,6 +117,11 @@ // Write |length| random bytes generated by |random|. bool WriteRandomBytes(QuicRandom* random, size_t length); + // Advance the writer's position for writing by |length| bytes without writing + // anything. This method only makes sense to be used on a buffer that has + // already been written to (and is having certain parts rewritten). + bool Seek(size_t length); + size_t capacity() const { return capacity_; } size_t remaining() const { return capacity_ - length_; }
diff --git a/quic/core/quic_data_writer_test.cc b/quic/core/quic_data_writer_test.cc index 73bb156..07f0313 100644 --- a/quic/core/quic_data_writer_test.cc +++ b/quic/core/quic_data_writer_test.cc
@@ -1142,6 +1142,46 @@ EXPECT_EQ(123456u, read_stream_count); } +TEST_P(QuicDataWriterTest, Seek) { + char buffer[3] = {}; + QuicDataWriter writer(QUIC_ARRAYSIZE(buffer), buffer, GetParam().endianness); + EXPECT_TRUE(writer.WriteUInt8(42)); + EXPECT_TRUE(writer.Seek(1)); + EXPECT_TRUE(writer.WriteUInt8(3)); + + char expected[] = {42, 0, 3}; + for (size_t i = 0; i < QUIC_ARRAYSIZE(expected); ++i) { + EXPECT_EQ(buffer[i], expected[i]); + } +} + +TEST_P(QuicDataWriterTest, SeekTooFarFails) { + char buffer[20]; + + // Check that one can seek to the end of the writer, but not past. + { + QuicDataWriter writer(QUIC_ARRAYSIZE(buffer), buffer, + GetParam().endianness); + EXPECT_TRUE(writer.Seek(20)); + EXPECT_FALSE(writer.Seek(1)); + } + + // Seeking several bytes past the end fails. + { + QuicDataWriter writer(QUIC_ARRAYSIZE(buffer), buffer, + GetParam().endianness); + EXPECT_FALSE(writer.Seek(100)); + } + + // Seeking so far that arithmetic overflow could occur also fails. + { + QuicDataWriter writer(QUIC_ARRAYSIZE(buffer), buffer, + GetParam().endianness); + EXPECT_TRUE(writer.Seek(10)); + EXPECT_FALSE(writer.Seek(std::numeric_limits<size_t>::max())); + } +} + } // namespace } // namespace test } // namespace quic
diff --git a/quic/core/quic_dispatcher.cc b/quic/core/quic_dispatcher.cc index 750f9b4..46878d3 100644 --- a/quic/core/quic_dispatcher.cc +++ b/quic/core/quic_dispatcher.cc
@@ -472,10 +472,20 @@ } // Set the framer's version and continue processing. framer_.set_version(version); + + if (version.HasHeaderProtection()) { + ProcessHeader(header); + return false; + } return true; } bool QuicDispatcher::OnUnauthenticatedHeader(const QuicPacketHeader& header) { + ProcessHeader(header); + return false; +} + +void QuicDispatcher::ProcessHeader(const QuicPacketHeader& header) { QuicConnectionId connection_id = header.destination_connection_id; // Packet's connection ID is unknown. Apply the validity checks. QuicPacketFate fate = ValidityChecks(header); @@ -490,8 +500,6 @@ ProcessUnauthenticatedHeaderFate(fate, connection_id, header.form, header.version_flag, header.version); } - - return false; } void QuicDispatcher::ProcessUnauthenticatedHeaderFate( @@ -566,30 +574,32 @@ } // initial packet number of 0 is always invalid. - if (!header.packet_number.IsInitialized()) { - return kFateTimeWait; - } - if (GetQuicRestartFlag(quic_enable_accept_random_ipn)) { - QUIC_RESTART_FLAG_COUNT_N(quic_enable_accept_random_ipn, 1, 2); - // Accepting Initial Packet Numbers in 1...((2^31)-1) range... check - // maximum accordingly. - if (header.packet_number > MaxRandomInitialPacketNumber()) { + if (!framer_.version().HasHeaderProtection()) { + if (!header.packet_number.IsInitialized()) { return kFateTimeWait; } - } else { - // Count those that would have been accepted if FLAGS..random_ipn - // were true -- to detect/diagnose potential issues prior to - // enabling the flag. - if ((header.packet_number > - QuicPacketNumber(kMaxReasonableInitialPacketNumber)) && - (header.packet_number <= MaxRandomInitialPacketNumber())) { - QUIC_CODE_COUNT_N(had_possibly_random_ipn, 1, 2); - } - // Check that the sequence number is within the range that the client is - // expected to send before receiving a response from the server. - if (header.packet_number > - QuicPacketNumber(kMaxReasonableInitialPacketNumber)) { - return kFateTimeWait; + if (GetQuicRestartFlag(quic_enable_accept_random_ipn)) { + QUIC_RESTART_FLAG_COUNT_N(quic_enable_accept_random_ipn, 1, 2); + // Accepting Initial Packet Numbers in 1...((2^31)-1) range... check + // maximum accordingly. + if (header.packet_number > MaxRandomInitialPacketNumber()) { + return kFateTimeWait; + } + } else { + // Count those that would have been accepted if FLAGS..random_ipn + // were true -- to detect/diagnose potential issues prior to + // enabling the flag. + if ((header.packet_number > + QuicPacketNumber(kMaxReasonableInitialPacketNumber)) && + (header.packet_number <= MaxRandomInitialPacketNumber())) { + QUIC_CODE_COUNT_N(had_possibly_random_ipn, 1, 2); + } + // Check that the sequence number is within the range that the client is + // expected to send before receiving a response from the server. + if (header.packet_number > + QuicPacketNumber(kMaxReasonableInitialPacketNumber)) { + return kFateTimeWait; + } } } return kFateProcess;
diff --git a/quic/core/quic_dispatcher.h b/quic/core/quic_dispatcher.h index 9702a09..0c951ed 100644 --- a/quic/core/quic_dispatcher.h +++ b/quic/core/quic_dispatcher.h
@@ -132,10 +132,11 @@ // QuicFramerVisitorInterface implementation. Not expected to be called // outside of this class. void OnPacket() override; - // Called when the public header has been parsed. + // Called when the public header has been parsed. Returns false when just the + // public header is enough to dispatch the packet; true if the framer needs to + // continue parsing the packet. bool OnUnauthenticatedPublicHeader(const QuicPacketHeader& header) override; - // Called when the private header has been parsed of a data packet that is - // destined for the time wait manager. + // Called when the private header has been parsed. bool OnUnauthenticatedHeader(const QuicPacketHeader& header) override; void OnError(QuicFramer* framer) override; bool OnProtocolVersionMismatch(ParsedQuicVersion received_version, @@ -378,6 +379,10 @@ typedef QuicUnorderedSet<QuicConnectionId, QuicConnectionIdHash> QuicConnectionIdSet; + // Based on an unauthenticated packet header |header|, calls ValidityChecks + // and then either MaybeRejectStatelessly or ProcessUnauthenticatedHeaderFate. + void ProcessHeader(const QuicPacketHeader& header); + // Attempts to reject the connection statelessly, if stateless rejects are // possible and if the current packet contains a CHLO message. Determines a // fate which describes what subsequent processing should be performed on the
diff --git a/quic/core/quic_dispatcher_test.cc b/quic/core/quic_dispatcher_test.cc index 89befc5..878dd75 100644 --- a/quic/core/quic_dispatcher_test.cc +++ b/quic/core/quic_dispatcher_test.cc
@@ -815,6 +815,11 @@ } TEST_F(QuicDispatcherTest, TooBigSeqNoPacketToTimeWaitListManager) { + if (CurrentSupportedVersions().front().HasHeaderProtection()) { + // When header protection is in use, we don't put packets in the time wait + // list manager based on packet number. + return; + } CreateTimeWaitListManager(); SetQuicRestartFlag(quic_enable_accept_random_ipn, false); QuicSocketAddress client_address(QuicIpAddress::Loopback4(), 1);
diff --git a/quic/core/quic_framer.cc b/quic/core/quic_framer.cc index c4eefb1..7524b17 100644 --- a/quic/core/quic_framer.cc +++ b/quic/core/quic_framer.cc
@@ -10,8 +10,10 @@ #include <string> #include "net/third_party/quiche/src/quic/core/crypto/crypto_framer.h" +#include "net/third_party/quiche/src/quic/core/crypto/crypto_handshake.h" #include "net/third_party/quiche/src/quic/core/crypto/crypto_handshake_message.h" #include "net/third_party/quiche/src/quic/core/crypto/crypto_protocol.h" +#include "net/third_party/quiche/src/quic/core/crypto/crypto_utils.h" #include "net/third_party/quiche/src/quic/core/crypto/null_decrypter.h" #include "net/third_party/quiche/src/quic/core/crypto/null_encrypter.h" #include "net/third_party/quiche/src/quic/core/crypto/quic_decrypter.h" @@ -473,7 +475,8 @@ Perspective::IS_CLIENT), expected_connection_id_length_(expected_connection_id_length), should_update_expected_connection_id_length_(false), - supports_multiple_packet_number_spaces_(false) { + supports_multiple_packet_number_spaces_(false), + last_written_packet_number_length_(0) { DCHECK(!supported_versions.empty()); version_ = supported_versions_[0]; decrypter_[ENCRYPTION_INITIAL] = QuicMakeUnique<NullDecrypter>(perspective); @@ -1754,6 +1757,8 @@ return false; } + QuicStringPiece associated_data; + std::vector<char> ad_storage; if (header->form == IETF_QUIC_SHORT_HEADER_PACKET || header->long_packet_type != VERSION_NEGOTIATION) { DCHECK(header->form == IETF_QUIC_SHORT_HEADER_PACKET || @@ -1763,21 +1768,32 @@ // Process packet number. QuicPacketNumber base_packet_number; if (supports_multiple_packet_number_spaces_) { - base_packet_number = - largest_decrypted_packet_numbers_[GetPacketNumberSpace(*header)]; + PacketNumberSpace pn_space = GetPacketNumberSpace(*header); + if (pn_space == NUM_PACKET_NUMBER_SPACES) { + return RaiseError(QUIC_INVALID_PACKET_HEADER); + } + base_packet_number = largest_decrypted_packet_numbers_[pn_space]; } else { base_packet_number = largest_packet_number_; } uint64_t full_packet_number; - if (!ProcessAndCalculatePacketNumber( - encrypted_reader, header->packet_number_length, base_packet_number, - &full_packet_number)) { + bool hp_removal_failed = false; + if (version_.HasHeaderProtection()) { + if (!RemoveHeaderProtection(encrypted_reader, packet, header, + &full_packet_number, &ad_storage)) { + hp_removal_failed = true; + } + associated_data = QuicStringPiece(ad_storage.data(), ad_storage.size()); + } else if (!ProcessAndCalculatePacketNumber( + encrypted_reader, header->packet_number_length, + base_packet_number, &full_packet_number)) { set_detailed_error("Unable to read packet number."); RecordDroppedPacketReason(DroppedPacketReason::INVALID_PACKET_NUMBER); return RaiseError(QUIC_INVALID_PACKET_HEADER); } - if (!IsValidFullPacketNumber(full_packet_number, transport_version())) { + if (hp_removal_failed || + !IsValidFullPacketNumber(full_packet_number, transport_version())) { if (IsIetfStatelessResetPacket(*header)) { // This is a stateless reset packet. QuicIetfStatelessResetPacket packet( @@ -1785,6 +1801,10 @@ visitor_->OnAuthenticatedIetfStatelessResetPacket(packet); return true; } + if (hp_removal_failed) { + set_detailed_error("Unable to decrypt header protection."); + return RaiseError(QUIC_DECRYPTION_FAILURE); + } RecordDroppedPacketReason(DroppedPacketReason::INVALID_PACKET_NUMBER); set_detailed_error("packet numbers cannot be 0."); return RaiseError(QUIC_INVALID_PACKET_HEADER); @@ -1819,13 +1839,15 @@ } QuicStringPiece encrypted = encrypted_reader->ReadRemainingPayload(); - QuicStringPiece associated_data = GetAssociatedDataFromEncryptedPacket( - version_.transport_version, packet, - GetIncludedDestinationConnectionIdLength(*header), - GetIncludedSourceConnectionIdLength(*header), header->version_flag, - header->nonce != nullptr, header->packet_number_length, - header->retry_token_length_length, header->retry_token.length(), - header->length_length); + if (!version_.HasHeaderProtection()) { + associated_data = GetAssociatedDataFromEncryptedPacket( + version_.transport_version, packet, + GetIncludedDestinationConnectionIdLength(*header), + GetIncludedSourceConnectionIdLength(*header), header->version_flag, + header->nonce != nullptr, header->packet_number_length, + header->retry_token_length_length, header->retry_token.length(), + header->length_length); + } size_t decrypted_length = 0; EncryptionLevel decrypted_level; @@ -2202,6 +2224,7 @@ writer)) { return false; } + last_written_packet_number_length_ = header.packet_number_length; if (!header.version_flag) { return true; @@ -2429,8 +2452,12 @@ QuicPacketHeader* header) { QuicPacketNumber base_packet_number; if (supports_multiple_packet_number_spaces_) { - base_packet_number = - largest_decrypted_packet_numbers_[GetPacketNumberSpace(*header)]; + PacketNumberSpace pn_space = GetPacketNumberSpace(*header); + if (pn_space == NUM_PACKET_NUMBER_SPACES) { + set_detailed_error("Unable to determine packet number space."); + return RaiseError(QUIC_INVALID_PACKET_HEADER); + } + base_packet_number = largest_decrypted_packet_numbers_[pn_space]; } else { base_packet_number = largest_packet_number_; } @@ -2528,7 +2555,7 @@ set_detailed_error("Client-initiated RETRY is invalid."); return false; } - } else { + } else if (!header->version.HasHeaderProtection()) { header->packet_number_length = GetLongHeaderPacketNumberLength( header->version.transport_version, type); } @@ -2559,7 +2586,8 @@ set_detailed_error("Fixed bit is 0 in short header."); return false; } - if (!GetShortHeaderPacketNumberLength(transport_version(), type, + if (!header->version.HasHeaderProtection() && + !GetShortHeaderPacketNumberLength(transport_version(), type, infer_packet_header_type_from_version_, &header->packet_number_length)) { set_detailed_error("Illegal short header type value."); @@ -3980,10 +4008,228 @@ RaiseError(QUIC_ENCRYPTION_FAILURE); return 0; } + if (version_.HasHeaderProtection() && + !ApplyHeaderProtection(level, buffer, ad_len + output_length, ad_len)) { + QUIC_DLOG(ERROR) << "Applying header protection failed."; + RaiseError(QUIC_ENCRYPTION_FAILURE); + return 0; + } return ad_len + output_length; } +namespace { + +const size_t kHPSampleLen = 16; + +constexpr bool IsLongHeader(uint8_t type_byte) { + return (type_byte & FLAGS_LONG_HEADER) != 0; +} + +} // namespace + +bool QuicFramer::ApplyHeaderProtection(EncryptionLevel level, + char* buffer, + size_t buffer_len, + size_t ad_len) { + QuicDataReader buffer_reader(buffer, buffer_len); + QuicDataWriter buffer_writer(buffer_len, buffer); + // The sample starts 4 bytes after the start of the packet number. + if (ad_len < last_written_packet_number_length_) { + return false; + } + size_t pn_offset = ad_len - last_written_packet_number_length_; + // Sample the ciphertext and generate the mask to use for header protection. + size_t sample_offset = pn_offset + 4; + QuicDataReader sample_reader(buffer, buffer_len); + QuicStringPiece sample; + if (!sample_reader.Seek(sample_offset) || + !sample_reader.ReadStringPiece(&sample, kHPSampleLen)) { + QUIC_BUG << "Not enough bytes to sample: sample_offset " << sample_offset + << ", sample len: " << kHPSampleLen + << ", buffer len: " << buffer_len; + return false; + } + + std::string mask = encrypter_[level]->GenerateHeaderProtectionMask(sample); + if (mask.empty()) { + QUIC_BUG << "Unable to generate header protection mask."; + return false; + } + QuicDataReader mask_reader(mask.data(), mask.size()); + + // Apply the mask to the 4 or 5 least significant bits of the first byte. + uint8_t bitmask = 0x1f; + uint8_t type_byte; + if (!buffer_reader.ReadUInt8(&type_byte)) { + return false; + } + QuicLongHeaderType header_type; + if (IsLongHeader(type_byte)) { + bitmask = 0x0f; + if (!GetLongHeaderType(version_.transport_version, type_byte, + &header_type)) { + return false; + } + } + uint8_t mask_byte; + if (!mask_reader.ReadUInt8(&mask_byte) || + !buffer_writer.WriteUInt8(type_byte ^ (mask_byte & bitmask))) { + return false; + } + + // Adjust |pn_offset| to account for the diversification nonce. + if (IsLongHeader(type_byte) && header_type == ZERO_RTT_PROTECTED && + perspective_ == Perspective::IS_SERVER && + version_.handshake_protocol == PROTOCOL_QUIC_CRYPTO) { + if (pn_offset <= kDiversificationNonceSize) { + QUIC_BUG << "Expected diversification nonce, but not enough bytes"; + return false; + } + pn_offset -= kDiversificationNonceSize; + } + // Advance the reader and writer to the packet number. Both the reader and + // writer have each read/written one byte. + if (!buffer_writer.Seek(pn_offset - 1) || + !buffer_reader.Seek(pn_offset - 1)) { + return false; + } + // Apply the rest of the mask to the packet number. + for (size_t i = 0; i < last_written_packet_number_length_; ++i) { + uint8_t buffer_byte; + uint8_t mask_byte; + if (!mask_reader.ReadUInt8(&mask_byte) || + !buffer_reader.ReadUInt8(&buffer_byte) || + !buffer_writer.WriteUInt8(buffer_byte ^ mask_byte)) { + return false; + } + } + return true; +} + +bool QuicFramer::RemoveHeaderProtection(QuicDataReader* reader, + const QuicEncryptedPacket& packet, + QuicPacketHeader* header, + uint64_t* full_packet_number, + std::vector<char>* associated_data) { + EncryptionLevel expected_decryption_level = GetEncryptionLevel(*header); + QuicDecrypter* decrypter = decrypter_[expected_decryption_level].get(); + if (decrypter == nullptr) { + QUIC_DVLOG(1) + << "No decrypter available for removing header protection at level " + << expected_decryption_level; + return false; + } + + bool has_diversification_nonce = + header->form == IETF_QUIC_LONG_HEADER_PACKET && + header->long_packet_type == ZERO_RTT_PROTECTED && + perspective_ == Perspective::IS_CLIENT && + version_.handshake_protocol == PROTOCOL_QUIC_CRYPTO; + + // Read a sample from the ciphertext and compute the mask to use for header + // protection. + QuicStringPiece remaining_packet = reader->PeekRemainingPayload(); + QuicDataReader sample_reader(remaining_packet); + + // The sample starts 4 bytes after the start of the packet number. + QuicStringPiece pn; + if (!sample_reader.ReadStringPiece(&pn, 4)) { + QUIC_DVLOG(1) << "Not enough data to sample"; + return false; + } + if (has_diversification_nonce) { + // In Google QUIC, the diversification nonce comes between the packet number + // and the sample. + if (!sample_reader.Seek(kDiversificationNonceSize)) { + QUIC_DVLOG(1) << "No diversification nonce to skip over"; + return false; + } + } + std::string mask = decrypter->GenerateHeaderProtectionMask(&sample_reader); + QuicDataReader mask_reader(mask.data(), mask.size()); + if (mask.empty()) { + QUIC_DVLOG(1) << "Failed to compute mask"; + return false; + } + + // Unmask the rest of the type byte. + uint8_t bitmask = 0x1f; + if (IsLongHeader(header->type_byte)) { + bitmask = 0x0f; + } + uint8_t mask_byte; + if (!mask_reader.ReadUInt8(&mask_byte)) { + QUIC_DVLOG(1) << "No first byte to read from mask"; + return false; + } + header->type_byte ^= (mask_byte & bitmask); + + // Compute the packet number length. + header->packet_number_length = + static_cast<QuicPacketNumberLength>((header->type_byte & 0x03) + 1); + + char pn_buffer[IETF_MAX_PACKET_NUMBER_LENGTH] = {}; + QuicDataWriter pn_writer(QUIC_ARRAYSIZE(pn_buffer), pn_buffer); + + // Read the (protected) packet number from the reader and unmask the packet + // number. + for (size_t i = 0; i < header->packet_number_length; ++i) { + uint8_t protected_pn_byte, mask_byte; + if (!mask_reader.ReadUInt8(&mask_byte) || + !reader->ReadUInt8(&protected_pn_byte) || + !pn_writer.WriteUInt8(protected_pn_byte ^ mask_byte)) { + QUIC_DVLOG(1) << "Failed to unmask packet number"; + return false; + } + } + QuicDataReader packet_number_reader(pn_writer.data(), pn_writer.length()); + QuicPacketNumber base_packet_number; + if (supports_multiple_packet_number_spaces_) { + PacketNumberSpace pn_space = GetPacketNumberSpace(*header); + if (pn_space == NUM_PACKET_NUMBER_SPACES) { + return false; + } + base_packet_number = largest_decrypted_packet_numbers_[pn_space]; + } else { + base_packet_number = largest_packet_number_; + } + if (!ProcessAndCalculatePacketNumber( + &packet_number_reader, header->packet_number_length, + base_packet_number, full_packet_number)) { + return false; + } + + // Get the associated data, and apply the same unmasking operations to it. + QuicStringPiece ad = GetAssociatedDataFromEncryptedPacket( + version_.transport_version, packet, + GetIncludedDestinationConnectionIdLength(*header), + GetIncludedSourceConnectionIdLength(*header), header->version_flag, + has_diversification_nonce, header->packet_number_length, + header->retry_token_length_length, header->retry_token.length(), + header->length_length); + *associated_data = std::vector<char>(ad.begin(), ad.end()); + QuicDataWriter ad_writer(associated_data->size(), associated_data->data()); + + // Apply the unmasked type byte and packet number to |associated_data|. + if (!ad_writer.WriteUInt8(header->type_byte)) { + return false; + } + // Put the packet number at the end of the AD, or if there's a diversification + // nonce, before that (which is at the end of the AD). + size_t seek_len = ad_writer.remaining() - header->packet_number_length; + if (has_diversification_nonce) { + seek_len -= kDiversificationNonceSize; + } + if (!ad_writer.Seek(seek_len) || + !ad_writer.WriteBytes(pn_writer.data(), pn_writer.length())) { + QUIC_DVLOG(1) << "Failed to apply unmasking operations to AD"; + return false; + } + + return true; +} + size_t QuicFramer::EncryptPayload(EncryptionLevel level, QuicPacketNumber packet_number, const QuicPacket& packet, @@ -4012,6 +4258,12 @@ RaiseError(QUIC_ENCRYPTION_FAILURE); return 0; } + if (version_.HasHeaderProtection() && + !ApplyHeaderProtection(level, buffer, ad_len + output_length, ad_len)) { + QUIC_DLOG(ERROR) << "Applying header protection failed."; + RaiseError(QUIC_ENCRYPTION_FAILURE); + return 0; + } return ad_len + output_length; } @@ -5256,7 +5508,9 @@ QUIC_DLOG(INFO) << ENDPOINT << "Error: " << QuicErrorCodeToString(error) << " detail: " << detailed_error_; set_error(error); - visitor_->OnError(this); + if (visitor_) { + visitor_->OnError(this); + } return false; }
diff --git a/quic/core/quic_framer.h b/quic/core/quic_framer.h index 17a87fa..eb811bc 100644 --- a/quic/core/quic_framer.h +++ b/quic/core/quic_framer.h
@@ -591,6 +591,34 @@ size_t num_ack_blocks; }; + // Applies header protection to an IETF QUIC packet header in |buffer| using + // the encrypter for level |level|. The buffer has |buffer_len| bytes of data, + // with the first protected packet bytes starting at |ad_len|. + bool ApplyHeaderProtection(EncryptionLevel level, + char* buffer, + size_t buffer_len, + size_t ad_len); + + // Removes header protection from an IETF QUIC packet header. + // + // The packet number from the header is read from |reader|, where the packet + // number is the next contents in |reader|. |reader| is only advanced by the + // length of the packet number, but it is also used to peek the sample needed + // for removing header protection. + // + // Properties needed for removing header protection are read from |header|. + // The packet number length and type byte are written to |header|. + // + // The packet number, after removing header protection and decoding it, is + // written to |full_packet_number|. Finally, the header, with header + // protection removed, is written to |associated_data| to be used in packet + // decryption. |packet| is used in computing the asociated data. + bool RemoveHeaderProtection(QuicDataReader* reader, + const QuicEncryptedPacket& packet, + QuicPacketHeader* header, + uint64_t* full_packet_number, + std::vector<char>* associated_data); + bool ProcessDataPacket(QuicDataReader* reader, QuicPacketHeader* header, const QuicEncryptedPacket& packet, @@ -941,6 +969,10 @@ // Indicates whether this framer supports multiple packet number spaces. bool supports_multiple_packet_number_spaces_; + + // The length in bytes of the last packet number written to an IETF-framed + // packet. + size_t last_written_packet_number_length_; }; } // namespace quic
diff --git a/quic/core/quic_framer_test.cc b/quic/core/quic_framer_test.cc index 3082a94..700c4cf 100644 --- a/quic/core/quic_framer_test.cc +++ b/quic/core/quic_framer_test.cc
@@ -1013,12 +1013,25 @@ {"Unable to read packet number.", {0x12, 0x34, 0x56, 0x78}}, }; + + PacketFragments packet_hp = { + // type (short header, 4 byte packet number) + {"Unable to read type.", + {0x43}}, + // connection_id + // packet number + {"", + {0x12, 0x34, 0x56, 0x78}}, + }; // clang-format on PacketFragments& fragments = - framer_.transport_version() > QUIC_VERSION_44 - ? packet46 - : (framer_.transport_version() > QUIC_VERSION_43 ? packet44 : packet); + framer_.version().HasHeaderProtection() + ? packet_hp + : framer_.transport_version() > QUIC_VERSION_44 + ? packet46 + : (framer_.transport_version() > QUIC_VERSION_43 ? packet44 + : packet); std::unique_ptr<QuicEncryptedPacket> encrypted( AssemblePacketFromFragments(fragments)); EXPECT_FALSE(framer_.ProcessPacket(*encrypted)); @@ -1173,12 +1186,27 @@ {"Unable to read packet number.", {0x12, 0x34, 0x56, 0x78}}, }; + + PacketFragments packet_hp = { + // type (short header, 4 byte packet number) + {"Unable to read type.", + {0x43}}, + // connection_id + {"Unable to read Destination ConnectionId.", + {0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10}}, + // packet number + {"", + {0x12, 0x34, 0x56, 0x78}}, + }; // clang-format on PacketFragments& fragments = - framer_.transport_version() > QUIC_VERSION_44 - ? packet46 - : (framer_.transport_version() > QUIC_VERSION_43 ? packet44 : packet); + framer_.version().HasHeaderProtection() + ? packet_hp + : framer_.transport_version() > QUIC_VERSION_44 + ? packet46 + : (framer_.transport_version() > QUIC_VERSION_43 ? packet44 + : packet); std::unique_ptr<QuicEncryptedPacket> encrypted( AssemblePacketFromFragments(fragments)); EXPECT_FALSE(framer_.ProcessPacket(*encrypted)); @@ -1233,16 +1261,38 @@ {"Unable to read packet number.", {0x56, 0x78}}, }; + + PacketFragments packet_hp = { + // type (short header, 2 byte packet number) + {"Unable to read type.", + {0x41}}, + // connection_id + {"Unable to read Destination ConnectionId.", + {0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10}}, + // packet number + {"", + {0x56, 0x78}}, + // padding + {"", {0x00, 0x00}}, + }; // clang-format on PacketFragments& fragments = - framer_.transport_version() > QUIC_VERSION_44 - ? packet46 - : (framer_.transport_version() > QUIC_VERSION_43 ? packet44 : packet); + framer_.version().HasHeaderProtection() + ? packet_hp + : (framer_.transport_version() > QUIC_VERSION_44 + ? packet46 + : (framer_.transport_version() > QUIC_VERSION_43 ? packet44 + : packet)); std::unique_ptr<QuicEncryptedPacket> encrypted( AssemblePacketFromFragments(fragments)); - EXPECT_FALSE(framer_.ProcessPacket(*encrypted)); - EXPECT_EQ(QUIC_MISSING_PAYLOAD, framer_.error()); + if (framer_.version().HasHeaderProtection()) { + EXPECT_TRUE(framer_.ProcessPacket(*encrypted)); + EXPECT_EQ(QUIC_NO_ERROR, framer_.error()); + } else { + EXPECT_FALSE(framer_.ProcessPacket(*encrypted)); + EXPECT_EQ(QUIC_MISSING_PAYLOAD, framer_.error()); + } ASSERT_TRUE(visitor_.header_.get()); EXPECT_EQ(FramerTestConnectionId(), visitor_.header_->destination_connection_id); @@ -1295,16 +1345,38 @@ {0x78}}, }; + PacketFragments packet_hp = { + // type (8 byte connection_id and 1 byte packet number) + {"Unable to read type.", + {0x40}}, + // connection_id + {"Unable to read Destination ConnectionId.", + {0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10}}, + // packet number + {"", + {0x78}}, + // padding + {"", {0x00, 0x00, 0x00}}, + }; + // clang-format on PacketFragments& fragments = - framer_.transport_version() > QUIC_VERSION_44 - ? packet46 - : (framer_.transport_version() > QUIC_VERSION_43 ? packet44 : packet); + framer_.version().HasHeaderProtection() + ? packet_hp + : (framer_.transport_version() > QUIC_VERSION_44 + ? packet46 + : (framer_.transport_version() > QUIC_VERSION_43 ? packet44 + : packet)); std::unique_ptr<QuicEncryptedPacket> encrypted( AssemblePacketFromFragments(fragments)); - EXPECT_FALSE(framer_.ProcessPacket(*encrypted)); - EXPECT_EQ(QUIC_MISSING_PAYLOAD, framer_.error()); + if (framer_.version().HasHeaderProtection()) { + EXPECT_TRUE(framer_.ProcessPacket(*encrypted)); + EXPECT_EQ(QUIC_NO_ERROR, framer_.error()); + } else { + EXPECT_FALSE(framer_.ProcessPacket(*encrypted)); + EXPECT_EQ(QUIC_MISSING_PAYLOAD, framer_.error()); + } ASSERT_TRUE(visitor_.header_.get()); EXPECT_EQ(FramerTestConnectionId(), visitor_.header_->destination_connection_id); @@ -1536,20 +1608,41 @@ 0x00, 0x00, 0x00, 0x00, 0x00 }; + + unsigned char packet45[] = { + // type (long header, ZERO_RTT_PROTECTED, 4-byte packet number) + 0xD3, + // version tag + 'Q', '0', '0', '0', + // connection_id length + 0x50, + // connection_id + 0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10, + // packet number + 0x12, 0x34, 0x56, 0x78, + + // frame type (padding frame) + 0x00, + 0x00, 0x00, 0x00, 0x00 + }; // clang-format on - QuicEncryptedPacket encrypted( - AsChars(framer_.transport_version() > QUIC_VERSION_43 ? packet44 - : packet), - framer_.transport_version() > QUIC_VERSION_43 ? QUIC_ARRAYSIZE(packet44) - : QUIC_ARRAYSIZE(packet), - false); + unsigned char* p = packet; + size_t p_size = QUIC_ARRAYSIZE(packet); + if (framer_.transport_version() > QUIC_VERSION_44) { + p = packet45; + p_size = QUIC_ARRAYSIZE(packet45); + } else if (framer_.transport_version() > QUIC_VERSION_43) { + p = packet44; + p_size = QUIC_ARRAYSIZE(packet44); + } + QuicEncryptedPacket encrypted(AsChars(p), p_size, false); EXPECT_TRUE(framer_.ProcessPacket(encrypted)); EXPECT_EQ(QUIC_NO_ERROR, framer_.error()); ASSERT_TRUE(visitor_.header_.get()); EXPECT_EQ(0, visitor_.frame_count_); EXPECT_EQ(1, visitor_.version_mismatch_); - EXPECT_EQ(1u, visitor_.padding_frames_.size()); + ASSERT_EQ(1u, visitor_.padding_frames_.size()); EXPECT_EQ(5, visitor_.padding_frames_[0]->num_padding_bytes); } @@ -1997,7 +2090,10 @@ } QuicEncryptedPacket encrypted(AsChars(p), p_length, false); EXPECT_FALSE(framer_.ProcessPacket(encrypted)); - if (framer_.transport_version() >= QUIC_VERSION_44) { + if (framer_.version().HasHeaderProtection()) { + EXPECT_EQ(QUIC_DECRYPTION_FAILURE, framer_.error()); + EXPECT_EQ("Unable to decrypt header protection.", framer_.detailed_error()); + } else if (framer_.transport_version() >= QUIC_VERSION_44) { // Cannot read diversification nonce. EXPECT_EQ(QUIC_INVALID_PACKET_HEADER, framer_.error()); EXPECT_EQ("Unable to read nonce.", framer_.detailed_error()); @@ -9404,12 +9500,15 @@ 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', + 'q', 'r', 's', 't', }; // clang-format on unsigned char* p = packet; + size_t p_size = QUIC_ARRAYSIZE(packet); if (framer_.transport_version() == QUIC_VERSION_99) { p = packet99; + p_size = QUIC_ARRAYSIZE(packet99); } else if (framer_.transport_version() > QUIC_VERSION_44) { p = packet46; } else if (framer_.transport_version() > QUIC_VERSION_43) { @@ -9417,7 +9516,7 @@ } std::unique_ptr<QuicPacket> raw(new QuicPacket( - AsChars(p), QUIC_ARRAYSIZE(packet), false, PACKET_8BYTE_CONNECTION_ID, + AsChars(p), p_size, false, PACKET_8BYTE_CONNECTION_ID, PACKET_0BYTE_CONNECTION_ID, !kIncludeVersion, !kIncludeDiversificationNonce, PACKET_4BYTE_PACKET_NUMBER, VARIABLE_LENGTH_INTEGER_LENGTH_0, 0, VARIABLE_LENGTH_INTEGER_LENGTH_0)); @@ -9430,6 +9529,7 @@ } TEST_P(QuicFramerTest, EncryptPacketWithVersionFlag) { + QuicFramerPeer::SetPerspective(&framer_, Perspective::IS_CLIENT); QuicPacketNumber packet_number = kPacketNumber; // clang-format off unsigned char packet[] = { @@ -9504,26 +9604,28 @@ 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', + 'q', 'r', 's', 't', }; // clang-format on unsigned char* p = packet; + size_t p_size = QUIC_ARRAYSIZE(packet); if (framer_.transport_version() == QUIC_VERSION_99) { p = packet99; + p_size = QUIC_ARRAYSIZE(packet99); } else if (framer_.transport_version() > QUIC_VERSION_44) { p = packet46; + p_size = QUIC_ARRAYSIZE(packet46); } else if (framer_.transport_version() > QUIC_VERSION_43) { p = packet44; + p_size = QUIC_ARRAYSIZE(packet44); } std::unique_ptr<QuicPacket> raw(new QuicPacket( - AsChars(p), - framer_.transport_version() > QUIC_VERSION_43 ? QUIC_ARRAYSIZE(packet44) - : QUIC_ARRAYSIZE(packet), - false, PACKET_8BYTE_CONNECTION_ID, PACKET_0BYTE_CONNECTION_ID, - kIncludeVersion, !kIncludeDiversificationNonce, - PACKET_4BYTE_PACKET_NUMBER, VARIABLE_LENGTH_INTEGER_LENGTH_0, 0, - VARIABLE_LENGTH_INTEGER_LENGTH_0)); + AsChars(p), p_size, false, PACKET_8BYTE_CONNECTION_ID, + PACKET_0BYTE_CONNECTION_ID, kIncludeVersion, + !kIncludeDiversificationNonce, PACKET_4BYTE_PACKET_NUMBER, + VARIABLE_LENGTH_INTEGER_LENGTH_0, 0, VARIABLE_LENGTH_INTEGER_LENGTH_0)); char buffer[kMaxOutgoingPacketSize]; size_t encrypted_length = framer_.EncryptPayload( ENCRYPTION_INITIAL, packet_number, *raw, buffer, kMaxOutgoingPacketSize); @@ -13025,12 +13127,33 @@ {"Unable to read packet number.", {0x78}}, }; + + PacketFragments packet_with_padding = { + // type (8 byte connection_id and 1 byte packet number) + {"Unable to read type.", + {0x40}}, + // connection_id + {"Unable to read Destination ConnectionId.", + {0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10, 0x42}}, + // packet number + {"", + {0x78}}, + // padding + {"", {0x00, 0x00, 0x00}}, + }; // clang-format on + PacketFragments& fragments = + framer_.version().HasHeaderProtection() ? packet_with_padding : packet; std::unique_ptr<QuicEncryptedPacket> encrypted( - AssemblePacketFromFragments(packet)); - EXPECT_FALSE(framer_.ProcessPacket(*encrypted)); - EXPECT_EQ(QUIC_MISSING_PAYLOAD, framer_.error()); + AssemblePacketFromFragments(fragments)); + if (framer_.version().HasHeaderProtection()) { + EXPECT_TRUE(framer_.ProcessPacket(*encrypted)); + EXPECT_EQ(QUIC_NO_ERROR, framer_.error()); + } else { + EXPECT_FALSE(framer_.ProcessPacket(*encrypted)); + EXPECT_EQ(QUIC_MISSING_PAYLOAD, framer_.error()); + } ASSERT_TRUE(visitor_.header_.get()); EXPECT_EQ(connection_id, visitor_.header_->destination_connection_id); EXPECT_FALSE(visitor_.header_->reset_flag); @@ -13038,7 +13161,7 @@ EXPECT_EQ(PACKET_1BYTE_PACKET_NUMBER, visitor_.header_->packet_number_length); EXPECT_EQ(kPacketNumber, visitor_.header_->packet_number); - CheckFramingBoundaries(packet, QUIC_INVALID_PACKET_HEADER); + CheckFramingBoundaries(fragments, QUIC_INVALID_PACKET_HEADER); } TEST_P(QuicFramerTest, UpdateExpectedConnectionIdLength) { @@ -13204,7 +13327,7 @@ // packet number 0x79, // padding frame - 0x00, + 0x00, 0x00, 0x00, }; // clang-format on
diff --git a/quic/core/quic_packet_creator.cc b/quic/core/quic_packet_creator.cc index f3390f0..8199307 100644 --- a/quic/core/quic_packet_creator.cc +++ b/quic/core/quic_packet_creator.cc
@@ -113,6 +113,9 @@ max_packet_length_ = length; max_plaintext_size_ = framer_->GetMaxPlaintextSize(max_packet_length_); + QUIC_BUG_IF(max_plaintext_size_ - PacketHeaderSize() < + MinPlaintextPacketSize()) + << "Attempted to set max packet length too small"; } // Stops serializing version of the protocol in packets sent after this call. @@ -439,6 +442,7 @@ max_plaintext_size_ - writer.length() - min_frame_size; const size_t bytes_consumed = std::min<size_t>(available_size, remaining_data_size); + const size_t plaintext_bytes_written = min_frame_size + bytes_consumed; const bool set_fin = fin && (bytes_consumed == remaining_data_size); QuicStreamFrame frame(id, set_fin, stream_offset, bytes_consumed); @@ -459,6 +463,12 @@ QUIC_BUG << "AppendStreamFrame failed"; return; } + if (plaintext_bytes_written < MinPlaintextPacketSize() && + !writer.WritePaddingBytes(MinPlaintextPacketSize() - + plaintext_bytes_written)) { + QUIC_BUG << "Unable to add padding bytes"; + return; + } if (!framer_->WriteIetfLongHeaderLength(header, &writer, length_field_offset, packet_.encryption_level)) { @@ -537,11 +547,7 @@ if (!queued_frames_.empty()) { return packet_size_; } - packet_size_ = GetPacketHeaderSize( - framer_->transport_version(), GetDestinationConnectionIdLength(), - GetSourceConnectionIdLength(), IncludeVersionInHeader(), - IncludeNonceInPublicHeader(), GetPacketNumberLength(), - GetRetryTokenLengthLength(), GetRetryToken().length(), GetLengthLength()); + packet_size_ = PacketHeaderSize(); return packet_size_; } @@ -771,6 +777,14 @@ return packet_.packet_number_length; } +size_t QuicPacketCreator::PacketHeaderSize() const { + return GetPacketHeaderSize( + framer_->transport_version(), GetDestinationConnectionIdLength(), + GetSourceConnectionIdLength(), IncludeVersionInHeader(), + IncludeNonceInPublicHeader(), GetPacketNumberLength(), + GetRetryTokenLengthLength(), GetRetryToken().length(), GetLengthLength()); +} + QuicVariableLengthIntegerLength QuicPacketCreator::GetRetryTokenLengthLength() const { if (QuicVersionHasLongHeaderLengths(framer_->transport_version()) && @@ -909,11 +923,24 @@ needs_full_padding_ = true; } - if (!needs_full_padding_ && pending_padding_bytes_ == 0) { + // Header protection requires a minimum plaintext packet size. + size_t extra_padding_bytes = 0; + if (framer_->version().HasHeaderProtection()) { + size_t frame_bytes = PacketSize() - PacketHeaderSize(); + + if (frame_bytes + pending_padding_bytes_ < MinPlaintextPacketSize() && + !needs_full_padding_) { + extra_padding_bytes = MinPlaintextPacketSize() - frame_bytes; + } + } + + if (!needs_full_padding_ && pending_padding_bytes_ == 0 && + extra_padding_bytes == 0) { // Do not need padding. return; } + int padding_bytes = -1; if (needs_full_padding_) { // Full padding does not consume pending padding bytes. packet_.num_padding_bytes = -1; @@ -921,11 +948,12 @@ packet_.num_padding_bytes = std::min<int16_t>(pending_padding_bytes_, BytesFree()); pending_padding_bytes_ -= packet_.num_padding_bytes; + padding_bytes = + std::max<int16_t>(packet_.num_padding_bytes, extra_padding_bytes); } - bool success = - AddFrame(QuicFrame(QuicPaddingFrame(packet_.num_padding_bytes)), false, - packet_.transmission_type); + bool success = AddFrame(QuicFrame(QuicPaddingFrame(padding_bytes)), false, + packet_.transmission_type); DCHECK(success); } @@ -1032,5 +1060,31 @@ packet_.encryption_level < ENCRYPTION_FORWARD_SECURE; } +size_t QuicPacketCreator::MinPlaintextPacketSize() const { + if (!framer_->version().HasHeaderProtection()) { + return 0; + } + // Header protection samples 16 bytes of ciphertext starting 4 bytes after the + // packet number. In IETF QUIC, all AEAD algorithms have a 16-byte auth tag + // (i.e. the ciphertext is 16 bytes larger than the plaintext). Since packet + // numbers could be as small as 1 byte, but the sample starts 4 bytes after + // the packet number, at least 3 bytes of plaintext are needed to make sure + // that there is enough ciphertext to sample. + // + // Google QUIC crypto uses different AEAD algorithms - in particular the auth + // tags are only 12 bytes instead of 16 bytes. Since the auth tag is 4 bytes + // shorter, 4 more bytes of plaintext are needed to guarantee there is enough + // ciphertext to sample. + // + // This method could check for PROTOCOL_TLS1_3 vs PROTOCOL_QUIC_CRYPTO and + // return 3 when TLS 1.3 is in use (the use of IETF vs Google QUIC crypters is + // determined based on the handshake protocol used). However, even when TLS + // 1.3 is used, unittests still use NullEncrypter/NullDecrypter (and other + // test crypters) which also only use 12 byte tags. + // + // TODO(nharper): Set this based on the handshake protocol in use. + return 7; +} + #undef ENDPOINT // undef for jumbo builds } // namespace quic
diff --git a/quic/core/quic_packet_creator.h b/quic/core/quic_packet_creator.h index 7ec8739..367ca11 100644 --- a/quic/core/quic_packet_creator.h +++ b/quic/core/quic_packet_creator.h
@@ -281,6 +281,9 @@ return framer_->transport_version(); } + // Returns the minimum size that the plaintext of a packet must be. + size_t MinPlaintextPacketSize() const; + private: friend class test::QuicPacketCreatorPeer; @@ -339,6 +342,9 @@ // function instead. QuicPacketNumberLength GetPacketNumberLength() const; + // Returns the size in bytes of the packet header. + size_t PacketHeaderSize() const; + // Returns whether the destination connection ID is sent over the wire. QuicConnectionIdIncluded GetDestinationConnectionIdIncluded() const;
diff --git a/quic/core/quic_packet_creator_test.cc b/quic/core/quic_packet_creator_test.cc index eddddc6..bd212de 100644 --- a/quic/core/quic_packet_creator_test.cc +++ b/quic/core/quic_packet_creator_test.cc
@@ -652,7 +652,9 @@ const size_t overhead = GetPacketHeaderOverhead(client_framer_.transport_version()) + GetEncryptionOverhead(); - for (size_t i = overhead; i < overhead + 100; ++i) { + for (size_t i = overhead + creator_.MinPlaintextPacketSize(); + i < overhead + 100; ++i) { + SCOPED_TRACE(i); creator_.SetMaxPacketLength(i); const bool should_have_room = i > @@ -1176,6 +1178,44 @@ if (!GetParam().version_serialization) { creator_.StopSendingVersion(); } + std::string data("test data"); + if (!QuicVersionUsesCryptoFrames(client_framer_.transport_version())) { + QuicStreamFrame stream_frame( + QuicUtils::GetCryptoStreamId(client_framer_.transport_version()), + /*fin=*/false, 0u, QuicStringPiece()); + frames_.push_back(QuicFrame(stream_frame)); + } else { + producer_.SaveCryptoData(ENCRYPTION_INITIAL, 0, data); + frames_.push_back( + QuicFrame(new QuicCryptoFrame(ENCRYPTION_INITIAL, 0, data.length()))); + } + SerializedPacket serialized = SerializeAllFrames(frames_); + + QuicPacketHeader header; + { + InSequence s; + EXPECT_CALL(framer_visitor_, OnPacket()); + EXPECT_CALL(framer_visitor_, OnUnauthenticatedPublicHeader(_)); + EXPECT_CALL(framer_visitor_, OnUnauthenticatedHeader(_)); + EXPECT_CALL(framer_visitor_, OnDecryptedPacket(_)); + EXPECT_CALL(framer_visitor_, OnPacketHeader(_)) + .WillOnce(DoAll(SaveArg<0>(&header), Return(true))); + if (QuicVersionUsesCryptoFrames(client_framer_.transport_version())) { + EXPECT_CALL(framer_visitor_, OnCryptoFrame(_)); + } else { + EXPECT_CALL(framer_visitor_, OnStreamFrame(_)); + } + EXPECT_CALL(framer_visitor_, OnPacketComplete()); + } + ProcessPacket(serialized); + EXPECT_EQ(GetParam().version_serialization, header.version_flag); + DeleteFrames(&frames_); +} + +TEST_P(QuicPacketCreatorTest, SerializeFrameShortData) { + if (!GetParam().version_serialization) { + creator_.StopSendingVersion(); + } std::string data("a"); if (!QuicVersionUsesCryptoFrames(client_framer_.transport_version())) { QuicStreamFrame stream_frame( @@ -1203,6 +1243,9 @@ } else { EXPECT_CALL(framer_visitor_, OnStreamFrame(_)); } + if (client_framer_.version().HasHeaderProtection()) { + EXPECT_CALL(framer_visitor_, OnPaddingFrame(_)); + } EXPECT_CALL(framer_visitor_, OnPacketComplete()); } ProcessPacket(serialized); @@ -1803,6 +1846,9 @@ } else { EXPECT_CALL(framer_visitor_, OnStreamFrame(_)); } + if (client_framer_.version().HasHeaderProtection()) { + EXPECT_CALL(framer_visitor_, OnPaddingFrame(_)); + } EXPECT_CALL(framer_visitor_, OnPacketComplete()); } ProcessPacket(serialized);
diff --git a/quic/core/quic_packet_generator_test.cc b/quic/core/quic_packet_generator_test.cc index dfd9ae2..43c3a7f 100644 --- a/quic/core/quic_packet_generator_test.cc +++ b/quic/core/quic_packet_generator_test.cc
@@ -272,7 +272,11 @@ ASSERT_TRUE(packet.encrypted_buffer != nullptr); ASSERT_TRUE(simple_framer_.ProcessPacket( QuicEncryptedPacket(packet.encrypted_buffer, packet.encrypted_length))); - EXPECT_EQ(num_frames, simple_framer_.num_frames()); + size_t num_padding_frames = 0; + if (contents.num_padding_frames == 0) { + num_padding_frames = simple_framer_.padding_frames().size(); + } + EXPECT_EQ(num_frames + num_padding_frames, simple_framer_.num_frames()); EXPECT_EQ(contents.num_ack_frames, simple_framer_.ack_frames().size()); EXPECT_EQ(contents.num_connection_close_frames, simple_framer_.connection_close_frames().size()); @@ -286,8 +290,10 @@ simple_framer_.crypto_frames().size()); EXPECT_EQ(contents.num_stop_waiting_frames, simple_framer_.stop_waiting_frames().size()); - EXPECT_EQ(contents.num_padding_frames, - simple_framer_.padding_frames().size()); + if (contents.num_padding_frames != 0) { + EXPECT_EQ(contents.num_padding_frames, + simple_framer_.padding_frames().size()); + } // From the receiver's perspective, MTU discovery frames are ping frames. EXPECT_EQ(contents.num_ping_frames + contents.num_mtu_discovery_frames, @@ -581,11 +587,11 @@ EXPECT_CALL(delegate_, OnSerializedPacket(_)) .WillOnce(Invoke(this, &QuicPacketGeneratorTest::SavePacket)); - MakeIOVector("foo", &iov_); + MakeIOVector("foo bar", &iov_); QuicConsumedData consumed = generator_.ConsumeData( QuicUtils::GetCryptoStreamId(framer_.transport_version()), &iov_, 1u, iov_.iov_len, 0, NO_FIN); - EXPECT_EQ(3u, consumed.bytes_consumed); + EXPECT_EQ(7u, consumed.bytes_consumed); EXPECT_FALSE(generator_.HasQueuedFrames()); EXPECT_FALSE(generator_.HasRetransmittableFrames());
diff --git a/quic/core/quic_types.h b/quic/core/quic_types.h index cc4fe2a..199ec94 100644 --- a/quic/core/quic_types.h +++ b/quic/core/quic_types.h
@@ -299,6 +299,7 @@ PACKET_2BYTE_PACKET_NUMBER = 2, PACKET_3BYTE_PACKET_NUMBER = 3, // Used in version > QUIC_VERSION_44. PACKET_4BYTE_PACKET_NUMBER = 4, + IETF_MAX_PACKET_NUMBER_LENGTH = 4, // TODO(rch): Remove this when we remove QUIC_VERSION_39. PACKET_6BYTE_PACKET_NUMBER = 6, PACKET_8BYTE_PACKET_NUMBER = 8
diff --git a/quic/test_tools/quic_test_utils.cc b/quic/test_tools/quic_test_utils.cc index e51bbae..b4bd25d 100644 --- a/quic/test_tools/quic_test_utils.cc +++ b/quic/test_tools/quic_test_utils.cc
@@ -885,15 +885,40 @@ QuicFrames frames; QuicFramer framer(*versions, QuicTime::Zero(), perspective, kQuicDefaultConnectionIdLength); - if (!QuicVersionUsesCryptoFrames((*versions)[0].transport_version)) { - QuicFrame frame(QuicStreamFrame( - QuicUtils::GetCryptoStreamId((*versions)[0].transport_version), false, - 0, QuicStringPiece(data))); + ParsedQuicVersion version = (*versions)[0]; + EncryptionLevel level = + header.version_flag ? ENCRYPTION_INITIAL : ENCRYPTION_FORWARD_SECURE; + if (version.handshake_protocol == PROTOCOL_TLS1_3 && + level == ENCRYPTION_INITIAL) { + CrypterPair crypters; + CryptoUtils::CreateTlsInitialCrypters(Perspective::IS_CLIENT, + version.transport_version, + destination_connection_id, &crypters); + framer.SetEncrypter(ENCRYPTION_INITIAL, std::move(crypters.encrypter)); + if (version.KnowsWhichDecrypterToUse()) { + framer.InstallDecrypter(ENCRYPTION_INITIAL, + std::move(crypters.decrypter)); + } else { + framer.SetDecrypter(ENCRYPTION_INITIAL, std::move(crypters.decrypter)); + } + } + if (!QuicVersionUsesCryptoFrames(version.transport_version)) { + QuicFrame frame( + QuicStreamFrame(QuicUtils::GetCryptoStreamId(version.transport_version), + false, 0, QuicStringPiece(data))); frames.push_back(frame); } else { - QuicFrame frame(new QuicCryptoFrame(ENCRYPTION_INITIAL, 0, data)); + QuicFrame frame(new QuicCryptoFrame(level, 0, data)); frames.push_back(frame); } + // We need a minimum of 7 bytes of encrypted payload. (See + // QuicPacketCreator::kMinPlaintextPacketSize.) This will guarantee that we + // have at least that much. (It ignores the overhead of the stream/crypto + // framing, so it overpads slightly.) + if (data.length() < 7) { + size_t padding_length = 7 - data.length(); + frames.push_back(QuicFrame(QuicPaddingFrame(padding_length))); + } std::unique_ptr<QuicPacket> packet( BuildUnsizedDataPacket(&framer, header, frames)); @@ -946,9 +971,26 @@ QuicFrame frame(QuicStreamFrame(1, false, 0, QuicStringPiece(data))); QuicFrames frames; frames.push_back(frame); + ParsedQuicVersion version = + (versions != nullptr ? *versions : AllSupportedVersions())[0]; QuicFramer framer(versions != nullptr ? *versions : AllSupportedVersions(), QuicTime::Zero(), perspective, kQuicDefaultConnectionIdLength); + if (version.handshake_protocol == PROTOCOL_TLS1_3 && version_flag) { + CrypterPair crypters; + CryptoUtils::CreateTlsInitialCrypters(Perspective::IS_CLIENT, + version.transport_version, + destination_connection_id, &crypters); + framer.SetEncrypter(ENCRYPTION_INITIAL, std::move(crypters.encrypter)); + framer.SetDecrypter(ENCRYPTION_INITIAL, std::move(crypters.decrypter)); + } + // We need a minimum of 7 bytes of encrypted payload. This will guarantee that + // we have at least that much. (It ignores the overhead of the stream/crypto + // framing, so it overpads slightly.) + if (data.length() < 7) { + size_t padding_length = 7 - data.length(); + frames.push_back(QuicFrame(QuicPaddingFrame(padding_length))); + } std::unique_ptr<QuicPacket> packet( BuildUnsizedDataPacket(&framer, header, frames));