Implement QUIC Header Protection
gfe-relnote: Protected by QUIC_VERSION_99
PiperOrigin-RevId: 247137283
Change-Id: I1deb08d304b7739c3c8fa6b995e55fbd8652dc1e
diff --git a/quic/core/quic_framer.cc b/quic/core/quic_framer.cc
index c4eefb1..7524b17 100644
--- a/quic/core/quic_framer.cc
+++ b/quic/core/quic_framer.cc
@@ -10,8 +10,10 @@
#include <string>
#include "net/third_party/quiche/src/quic/core/crypto/crypto_framer.h"
+#include "net/third_party/quiche/src/quic/core/crypto/crypto_handshake.h"
#include "net/third_party/quiche/src/quic/core/crypto/crypto_handshake_message.h"
#include "net/third_party/quiche/src/quic/core/crypto/crypto_protocol.h"
+#include "net/third_party/quiche/src/quic/core/crypto/crypto_utils.h"
#include "net/third_party/quiche/src/quic/core/crypto/null_decrypter.h"
#include "net/third_party/quiche/src/quic/core/crypto/null_encrypter.h"
#include "net/third_party/quiche/src/quic/core/crypto/quic_decrypter.h"
@@ -473,7 +475,8 @@
Perspective::IS_CLIENT),
expected_connection_id_length_(expected_connection_id_length),
should_update_expected_connection_id_length_(false),
- supports_multiple_packet_number_spaces_(false) {
+ supports_multiple_packet_number_spaces_(false),
+ last_written_packet_number_length_(0) {  // Set when a packet header is written; read by ApplyHeaderProtection.
DCHECK(!supported_versions.empty());
version_ = supported_versions_[0];
decrypter_[ENCRYPTION_INITIAL] = QuicMakeUnique<NullDecrypter>(perspective);
@@ -1754,6 +1757,8 @@
return false;
}
+ QuicStringPiece associated_data;  // Set below, either from |ad_storage| (header protection) or from the packet.
+ std::vector<char> ad_storage;  // Owns the unmasked AD when header protection is removed; must outlive |associated_data|.
if (header->form == IETF_QUIC_SHORT_HEADER_PACKET ||
header->long_packet_type != VERSION_NEGOTIATION) {
DCHECK(header->form == IETF_QUIC_SHORT_HEADER_PACKET ||
@@ -1763,21 +1768,33 @@
// Process packet number.
QuicPacketNumber base_packet_number;
if (supports_multiple_packet_number_spaces_) {
- base_packet_number =
- largest_decrypted_packet_numbers_[GetPacketNumberSpace(*header)];
+ PacketNumberSpace pn_space = GetPacketNumberSpace(*header);
+ if (pn_space == NUM_PACKET_NUMBER_SPACES) {  // Guards the largest_decrypted_packet_numbers_ index below.
+ set_detailed_error("Unable to determine packet number space.");
+ return RaiseError(QUIC_INVALID_PACKET_HEADER);
+ }
+ base_packet_number = largest_decrypted_packet_numbers_[pn_space];
} else {
base_packet_number = largest_packet_number_;
}
uint64_t full_packet_number;
- if (!ProcessAndCalculatePacketNumber(
- encrypted_reader, header->packet_number_length, base_packet_number,
- &full_packet_number)) {
+ bool hp_removal_failed = false;  // Reported below, after the stateless reset check.
+ if (version_.HasHeaderProtection()) {  // RemoveHeaderProtection also builds the unmasked AD.
+ if (!RemoveHeaderProtection(encrypted_reader, packet, header,
+ &full_packet_number, &ad_storage)) {
+ hp_removal_failed = true;
+ }
+ associated_data = QuicStringPiece(ad_storage.data(), ad_storage.size());
+ } else if (!ProcessAndCalculatePacketNumber(
+ encrypted_reader, header->packet_number_length,
+ base_packet_number, &full_packet_number)) {
set_detailed_error("Unable to read packet number.");
RecordDroppedPacketReason(DroppedPacketReason::INVALID_PACKET_NUMBER);
return RaiseError(QUIC_INVALID_PACKET_HEADER);
}
- if (!IsValidFullPacketNumber(full_packet_number, transport_version())) {
+ if (hp_removal_failed ||
+ !IsValidFullPacketNumber(full_packet_number, transport_version())) {
if (IsIetfStatelessResetPacket(*header)) {
// This is a stateless reset packet.
QuicIetfStatelessResetPacket packet(
@@ -1785,6 +1801,10 @@
visitor_->OnAuthenticatedIetfStatelessResetPacket(packet);
return true;
}
+ if (hp_removal_failed) {  // Not a stateless reset, so report the header protection failure.
+ set_detailed_error("Unable to decrypt header protection.");
+ return RaiseError(QUIC_DECRYPTION_FAILURE);
+ }
RecordDroppedPacketReason(DroppedPacketReason::INVALID_PACKET_NUMBER);
set_detailed_error("packet numbers cannot be 0.");
return RaiseError(QUIC_INVALID_PACKET_HEADER);
@@ -1819,13 +1839,15 @@
}
QuicStringPiece encrypted = encrypted_reader->ReadRemainingPayload();
- QuicStringPiece associated_data = GetAssociatedDataFromEncryptedPacket(
- version_.transport_version, packet,
- GetIncludedDestinationConnectionIdLength(*header),
- GetIncludedSourceConnectionIdLength(*header), header->version_flag,
- header->nonce != nullptr, header->packet_number_length,
- header->retry_token_length_length, header->retry_token.length(),
- header->length_length);
+ if (!version_.HasHeaderProtection()) {  // With header protection, the unmasked AD was already set from |ad_storage|.
+ associated_data = GetAssociatedDataFromEncryptedPacket(
+ version_.transport_version, packet,
+ GetIncludedDestinationConnectionIdLength(*header),
+ GetIncludedSourceConnectionIdLength(*header), header->version_flag,
+ header->nonce != nullptr, header->packet_number_length,
+ header->retry_token_length_length, header->retry_token.length(),
+ header->length_length);
+ }
size_t decrypted_length = 0;
EncryptionLevel decrypted_level;
@@ -2202,6 +2224,7 @@
writer)) {
return false;
}
+ last_written_packet_number_length_ = header.packet_number_length;  // Read later by ApplyHeaderProtection to locate the packet number.
if (!header.version_flag) {
return true;
@@ -2429,8 +2452,12 @@
QuicPacketHeader* header) {
QuicPacketNumber base_packet_number;
if (supports_multiple_packet_number_spaces_) {
- base_packet_number =
- largest_decrypted_packet_numbers_[GetPacketNumberSpace(*header)];
+ PacketNumberSpace pn_space = GetPacketNumberSpace(*header);
+ if (pn_space == NUM_PACKET_NUMBER_SPACES) {  // Guards the largest_decrypted_packet_numbers_ index below.
+ set_detailed_error("Unable to determine packet number space.");
+ return RaiseError(QUIC_INVALID_PACKET_HEADER);
+ }
+ base_packet_number = largest_decrypted_packet_numbers_[pn_space];
} else {
base_packet_number = largest_packet_number_;
}
@@ -2528,7 +2555,7 @@
set_detailed_error("Client-initiated RETRY is invalid.");
return false;
}
- } else {
+ } else if (!header->version.HasHeaderProtection()) {  // With header protection, the length comes from the unmasked type byte.
header->packet_number_length = GetLongHeaderPacketNumberLength(
header->version.transport_version, type);
}
@@ -2559,7 +2586,8 @@
set_detailed_error("Fixed bit is 0 in short header.");
return false;
}
- if (!GetShortHeaderPacketNumberLength(transport_version(), type,
+ if (!header->version.HasHeaderProtection() &&  // With header protection, the length comes from the unmasked type byte.
+ !GetShortHeaderPacketNumberLength(transport_version(), type,
infer_packet_header_type_from_version_,
&header->packet_number_length)) {
set_detailed_error("Illegal short header type value.");
@@ -3980,10 +4008,228 @@
RaiseError(QUIC_ENCRYPTION_FAILURE);
return 0;
}
+ if (version_.HasHeaderProtection() &&  // Runs after encryption: the mask is derived from a ciphertext sample.
+ !ApplyHeaderProtection(level, buffer, ad_len + output_length, ad_len)) {
+ QUIC_DLOG(ERROR) << "Applying header protection failed.";
+ RaiseError(QUIC_ENCRYPTION_FAILURE);
+ return 0;
+ }
return ad_len + output_length;
}
+namespace {
+
+const size_t kHPSampleLen = 16;  // Bytes of ciphertext sampled to derive the header protection mask.
+
+constexpr bool IsLongHeader(uint8_t type_byte) {  // True if the FLAGS_LONG_HEADER bit is set in |type_byte|.
+ return (type_byte & FLAGS_LONG_HEADER) != 0;
+}
+
+} // namespace
+
+bool QuicFramer::ApplyHeaderProtection(EncryptionLevel level,  // Selects |encrypter_[level]|.
+ char* buffer,  // Header followed by ciphertext; masked in place.
+ size_t buffer_len,  // Total bytes in |buffer| (header + ciphertext).
+ size_t ad_len) {  // Length of the header (associated data) in |buffer|.
+ QuicDataReader buffer_reader(buffer, buffer_len);
+ QuicDataWriter buffer_writer(buffer_len, buffer);
+ // Assume the packet number ends at |ad_len|; a following nonce is handled below.
+ if (ad_len < last_written_packet_number_length_) {
+ return false;
+ }
+ size_t pn_offset = ad_len - last_written_packet_number_length_;
+ // The sample starts 4 bytes after |pn_offset|; it seeds the protection mask.
+ size_t sample_offset = pn_offset + 4;
+ QuicDataReader sample_reader(buffer, buffer_len);
+ QuicStringPiece sample;
+ if (!sample_reader.Seek(sample_offset) ||
+ !sample_reader.ReadStringPiece(&sample, kHPSampleLen)) {
+ QUIC_BUG << "Not enough bytes to sample: sample_offset " << sample_offset
+ << ", sample len: " << kHPSampleLen
+ << ", buffer len: " << buffer_len;
+ return false;
+ }
+
+ std::string mask = encrypter_[level]->GenerateHeaderProtectionMask(sample);
+ if (mask.empty()) {
+ QUIC_BUG << "Unable to generate header protection mask.";
+ return false;
+ }
+ QuicDataReader mask_reader(mask.data(), mask.size());
+
+ // Apply the mask to the 4 or 5 least significant bits of the first byte.
+ uint8_t bitmask = 0x1f;
+ uint8_t type_byte;
+ if (!buffer_reader.ReadUInt8(&type_byte)) {
+ return false;
+ }
+ QuicLongHeaderType header_type = INVALID_PACKET_TYPE;
+ if (IsLongHeader(type_byte)) {
+ bitmask = 0x0f;
+ if (!GetLongHeaderType(version_.transport_version, type_byte,
+ &header_type)) {
+ return false;
+ }
+ }
+ uint8_t mask_byte;
+ if (!mask_reader.ReadUInt8(&mask_byte) ||
+ !buffer_writer.WriteUInt8(type_byte ^ (mask_byte & bitmask))) {
+ return false;
+ }
+
+ // Adjust |pn_offset| to account for the diversification nonce.
+ if (IsLongHeader(type_byte) && header_type == ZERO_RTT_PROTECTED &&
+ perspective_ == Perspective::IS_SERVER &&
+ version_.handshake_protocol == PROTOCOL_QUIC_CRYPTO) {
+ if (pn_offset <= kDiversificationNonceSize) {
+ QUIC_BUG << "Expected diversification nonce, but not enough bytes";
+ return false;
+ }
+ pn_offset -= kDiversificationNonceSize;
+ }
+ // Advance the reader and writer to the packet number. Both have already
+ // consumed one byte, so reject pn_offset == 0 instead of underflowing.
+ if (pn_offset == 0 || !buffer_writer.Seek(pn_offset - 1) ||
+ !buffer_reader.Seek(pn_offset - 1)) {
+ return false;
+ }
+ // Apply the rest of the mask to the packet number.
+ for (size_t i = 0; i < last_written_packet_number_length_; ++i) {
+ uint8_t buffer_byte;
+ uint8_t pn_mask_byte;
+ if (!mask_reader.ReadUInt8(&pn_mask_byte) ||
+ !buffer_reader.ReadUInt8(&buffer_byte) ||
+ !buffer_writer.WriteUInt8(buffer_byte ^ pn_mask_byte)) {
+ return false;
+ }
+ }
+ return true;
+}
+
+bool QuicFramer::RemoveHeaderProtection(QuicDataReader* reader,  // Positioned at the protected packet number.
+ const QuicEncryptedPacket& packet,  // Whole packet; source of the associated data.
+ QuicPacketHeader* header,  // |type_byte| and |packet_number_length| are updated.
+ uint64_t* full_packet_number,  // Receives the reconstructed packet number.
+ std::vector<char>* associated_data) {  // Receives the unmasked associated data.
+ EncryptionLevel expected_decryption_level = GetEncryptionLevel(*header);
+ QuicDecrypter* decrypter = decrypter_[expected_decryption_level].get();
+ if (decrypter == nullptr) {
+ QUIC_DVLOG(1)
+ << "No decrypter available for removing header protection at level "
+ << expected_decryption_level;
+ return false;
+ }
+
+ bool has_diversification_nonce =
+ header->form == IETF_QUIC_LONG_HEADER_PACKET &&
+ header->long_packet_type == ZERO_RTT_PROTECTED &&
+ perspective_ == Perspective::IS_CLIENT &&
+ version_.handshake_protocol == PROTOCOL_QUIC_CRYPTO;
+
+ // Read a sample from the ciphertext and compute the mask to use for header
+ // protection.
+ QuicStringPiece remaining_packet = reader->PeekRemainingPayload();
+ QuicDataReader sample_reader(remaining_packet);
+
+ // The sample starts 4 bytes after the start of the packet number.
+ QuicStringPiece pn;
+ if (!sample_reader.ReadStringPiece(&pn, 4)) {
+ QUIC_DVLOG(1) << "Not enough data to sample";
+ return false;
+ }
+ if (has_diversification_nonce) {
+ // In Google QUIC, the diversification nonce comes between the packet number
+ // and the sample.
+ if (!sample_reader.Seek(kDiversificationNonceSize)) {
+ QUIC_DVLOG(1) << "No diversification nonce to skip over";
+ return false;
+ }
+ }
+ std::string mask = decrypter->GenerateHeaderProtectionMask(&sample_reader);
+ if (mask.empty()) {
+ QUIC_DVLOG(1) << "Failed to compute mask";
+ return false;
+ }
+ QuicDataReader mask_reader(mask.data(), mask.size());
+
+ // Unmask the rest of the type byte.
+ uint8_t bitmask = 0x1f;
+ if (IsLongHeader(header->type_byte)) {
+ bitmask = 0x0f;
+ }
+ uint8_t mask_byte;
+ if (!mask_reader.ReadUInt8(&mask_byte)) {
+ QUIC_DVLOG(1) << "No first byte to read from mask";
+ return false;
+ }
+ header->type_byte ^= (mask_byte & bitmask);
+
+ // Compute the packet number length.
+ header->packet_number_length =
+ static_cast<QuicPacketNumberLength>((header->type_byte & 0x03) + 1);
+
+ char pn_buffer[IETF_MAX_PACKET_NUMBER_LENGTH] = {};
+ QuicDataWriter pn_writer(QUIC_ARRAYSIZE(pn_buffer), pn_buffer);
+
+ // Read the (protected) packet number from the reader and unmask the packet
+ // number.
+ for (size_t i = 0; i < header->packet_number_length; ++i) {
+ uint8_t protected_pn_byte, pn_mask_byte;
+ if (!mask_reader.ReadUInt8(&pn_mask_byte) ||
+ !reader->ReadUInt8(&protected_pn_byte) ||
+ !pn_writer.WriteUInt8(protected_pn_byte ^ pn_mask_byte)) {
+ QUIC_DVLOG(1) << "Failed to unmask packet number";
+ return false;
+ }
+ }
+ QuicDataReader packet_number_reader(pn_writer.data(), pn_writer.length());
+ QuicPacketNumber base_packet_number;
+ if (supports_multiple_packet_number_spaces_) {
+ PacketNumberSpace pn_space = GetPacketNumberSpace(*header);
+ if (pn_space == NUM_PACKET_NUMBER_SPACES) {
+ return false;
+ }
+ base_packet_number = largest_decrypted_packet_numbers_[pn_space];
+ } else {
+ base_packet_number = largest_packet_number_;
+ }
+ if (!ProcessAndCalculatePacketNumber(
+ &packet_number_reader, header->packet_number_length,
+ base_packet_number, full_packet_number)) {
+ return false;
+ }
+
+ // Get the associated data, and apply the same unmasking operations to it.
+ QuicStringPiece ad = GetAssociatedDataFromEncryptedPacket(
+ version_.transport_version, packet,
+ GetIncludedDestinationConnectionIdLength(*header),
+ GetIncludedSourceConnectionIdLength(*header), header->version_flag,
+ has_diversification_nonce, header->packet_number_length,
+ header->retry_token_length_length, header->retry_token.length(),
+ header->length_length);
+ associated_data->assign(ad.begin(), ad.end());
+ QuicDataWriter ad_writer(associated_data->size(), associated_data->data());
+
+ // Apply the unmasked type byte and packet number to |associated_data|.
+ if (!ad_writer.WriteUInt8(header->type_byte)) {
+ return false;
+ }
+ // Put the packet number at the end of the AD, or if there's a diversification
+ // nonce, before that (which is at the end of the AD).
+ // |tail_len| is the packet number plus any trailing nonce; check before subtracting.
+ const size_t tail_len = header->packet_number_length +
+ (has_diversification_nonce ? kDiversificationNonceSize : 0);
+ if (ad_writer.remaining() < tail_len ||
+ !ad_writer.Seek(ad_writer.remaining() - tail_len) ||
+ !ad_writer.WriteBytes(pn_writer.data(), pn_writer.length())) {
+ QUIC_DVLOG(1) << "Failed to apply unmasking operations to AD";
+ return false;
+ }
+
+ return true;
+}
+
size_t QuicFramer::EncryptPayload(EncryptionLevel level,
QuicPacketNumber packet_number,
const QuicPacket& packet,
@@ -4012,6 +4258,12 @@
RaiseError(QUIC_ENCRYPTION_FAILURE);
return 0;
}
+ if (version_.HasHeaderProtection() &&  // Runs after encryption: the mask is derived from a ciphertext sample.
+ !ApplyHeaderProtection(level, buffer, ad_len + output_length, ad_len)) {
+ QUIC_DLOG(ERROR) << "Applying header protection failed.";
+ RaiseError(QUIC_ENCRYPTION_FAILURE);
+ return 0;
+ }
return ad_len + output_length;
}
@@ -5256,7 +5508,9 @@
QUIC_DLOG(INFO) << ENDPOINT << "Error: " << QuicErrorCodeToString(error)
<< " detail: " << detailed_error_;
set_error(error);
- visitor_->OnError(this);
+ if (visitor_) {  // Skip the notification when no visitor is attached.
+ visitor_->OnError(this);
+ }
return false;
}