Implement QUIC Header Protection

gfe-relnote: Protected by QUIC_VERSION_99
PiperOrigin-RevId: 247137283
Change-Id: I1deb08d304b7739c3c8fa6b995e55fbd8652dc1e
diff --git a/quic/core/quic_framer.cc b/quic/core/quic_framer.cc
index c4eefb1..7524b17 100644
--- a/quic/core/quic_framer.cc
+++ b/quic/core/quic_framer.cc
@@ -10,8 +10,10 @@
 #include <string>
 
 #include "net/third_party/quiche/src/quic/core/crypto/crypto_framer.h"
+#include "net/third_party/quiche/src/quic/core/crypto/crypto_handshake.h"
 #include "net/third_party/quiche/src/quic/core/crypto/crypto_handshake_message.h"
 #include "net/third_party/quiche/src/quic/core/crypto/crypto_protocol.h"
+#include "net/third_party/quiche/src/quic/core/crypto/crypto_utils.h"
 #include "net/third_party/quiche/src/quic/core/crypto/null_decrypter.h"
 #include "net/third_party/quiche/src/quic/core/crypto/null_encrypter.h"
 #include "net/third_party/quiche/src/quic/core/crypto/quic_decrypter.h"
@@ -473,7 +475,8 @@
                                              Perspective::IS_CLIENT),
       expected_connection_id_length_(expected_connection_id_length),
       should_update_expected_connection_id_length_(false),
-      supports_multiple_packet_number_spaces_(false) {
+      supports_multiple_packet_number_spaces_(false),
+      last_written_packet_number_length_(0) {
   DCHECK(!supported_versions.empty());
   version_ = supported_versions_[0];
   decrypter_[ENCRYPTION_INITIAL] = QuicMakeUnique<NullDecrypter>(perspective);
@@ -1754,6 +1757,8 @@
     return false;
   }
 
+  QuicStringPiece associated_data;
+  std::vector<char> ad_storage;
   if (header->form == IETF_QUIC_SHORT_HEADER_PACKET ||
       header->long_packet_type != VERSION_NEGOTIATION) {
     DCHECK(header->form == IETF_QUIC_SHORT_HEADER_PACKET ||
@@ -1763,21 +1768,32 @@
     // Process packet number.
     QuicPacketNumber base_packet_number;
     if (supports_multiple_packet_number_spaces_) {
-      base_packet_number =
-          largest_decrypted_packet_numbers_[GetPacketNumberSpace(*header)];
+      PacketNumberSpace pn_space = GetPacketNumberSpace(*header);
+      if (pn_space == NUM_PACKET_NUMBER_SPACES) {
+        return RaiseError(QUIC_INVALID_PACKET_HEADER);
+      }
+      base_packet_number = largest_decrypted_packet_numbers_[pn_space];
     } else {
       base_packet_number = largest_packet_number_;
     }
     uint64_t full_packet_number;
-    if (!ProcessAndCalculatePacketNumber(
-            encrypted_reader, header->packet_number_length, base_packet_number,
-            &full_packet_number)) {
+    bool hp_removal_failed = false;
+    if (version_.HasHeaderProtection()) {
+      if (!RemoveHeaderProtection(encrypted_reader, packet, header,
+                                  &full_packet_number, &ad_storage)) {
+        hp_removal_failed = true;
+      }
+      associated_data = QuicStringPiece(ad_storage.data(), ad_storage.size());
+    } else if (!ProcessAndCalculatePacketNumber(
+                   encrypted_reader, header->packet_number_length,
+                   base_packet_number, &full_packet_number)) {
       set_detailed_error("Unable to read packet number.");
       RecordDroppedPacketReason(DroppedPacketReason::INVALID_PACKET_NUMBER);
       return RaiseError(QUIC_INVALID_PACKET_HEADER);
     }
 
-    if (!IsValidFullPacketNumber(full_packet_number, transport_version())) {
+    if (hp_removal_failed ||
+        !IsValidFullPacketNumber(full_packet_number, transport_version())) {
       if (IsIetfStatelessResetPacket(*header)) {
         // This is a stateless reset packet.
         QuicIetfStatelessResetPacket packet(
@@ -1785,6 +1801,10 @@
         visitor_->OnAuthenticatedIetfStatelessResetPacket(packet);
         return true;
       }
+      if (hp_removal_failed) {
+        set_detailed_error("Unable to decrypt header protection.");
+        return RaiseError(QUIC_DECRYPTION_FAILURE);
+      }
       RecordDroppedPacketReason(DroppedPacketReason::INVALID_PACKET_NUMBER);
       set_detailed_error("packet numbers cannot be 0.");
       return RaiseError(QUIC_INVALID_PACKET_HEADER);
@@ -1819,13 +1839,15 @@
   }
 
   QuicStringPiece encrypted = encrypted_reader->ReadRemainingPayload();
-  QuicStringPiece associated_data = GetAssociatedDataFromEncryptedPacket(
-      version_.transport_version, packet,
-      GetIncludedDestinationConnectionIdLength(*header),
-      GetIncludedSourceConnectionIdLength(*header), header->version_flag,
-      header->nonce != nullptr, header->packet_number_length,
-      header->retry_token_length_length, header->retry_token.length(),
-      header->length_length);
+  if (!version_.HasHeaderProtection()) {
+    associated_data = GetAssociatedDataFromEncryptedPacket(
+        version_.transport_version, packet,
+        GetIncludedDestinationConnectionIdLength(*header),
+        GetIncludedSourceConnectionIdLength(*header), header->version_flag,
+        header->nonce != nullptr, header->packet_number_length,
+        header->retry_token_length_length, header->retry_token.length(),
+        header->length_length);
+  }
 
   size_t decrypted_length = 0;
   EncryptionLevel decrypted_level;
@@ -2202,6 +2224,7 @@
                           writer)) {
     return false;
   }
+  last_written_packet_number_length_ = header.packet_number_length;
 
   if (!header.version_flag) {
     return true;
@@ -2429,8 +2452,12 @@
                                               QuicPacketHeader* header) {
   QuicPacketNumber base_packet_number;
   if (supports_multiple_packet_number_spaces_) {
-    base_packet_number =
-        largest_decrypted_packet_numbers_[GetPacketNumberSpace(*header)];
+    PacketNumberSpace pn_space = GetPacketNumberSpace(*header);
+    if (pn_space == NUM_PACKET_NUMBER_SPACES) {
+      set_detailed_error("Unable to determine packet number space.");
+      return RaiseError(QUIC_INVALID_PACKET_HEADER);
+    }
+    base_packet_number = largest_decrypted_packet_numbers_[pn_space];
   } else {
     base_packet_number = largest_packet_number_;
   }
@@ -2528,7 +2555,7 @@
             set_detailed_error("Client-initiated RETRY is invalid.");
             return false;
           }
-        } else {
+        } else if (!header->version.HasHeaderProtection()) {
           header->packet_number_length = GetLongHeaderPacketNumberLength(
               header->version.transport_version, type);
         }
@@ -2559,7 +2586,8 @@
     set_detailed_error("Fixed bit is 0 in short header.");
     return false;
   }
-  if (!GetShortHeaderPacketNumberLength(transport_version(), type,
+  if (!header->version.HasHeaderProtection() &&
+      !GetShortHeaderPacketNumberLength(transport_version(), type,
                                         infer_packet_header_type_from_version_,
                                         &header->packet_number_length)) {
     set_detailed_error("Illegal short header type value.");
@@ -3980,10 +4008,228 @@
     RaiseError(QUIC_ENCRYPTION_FAILURE);
     return 0;
   }
+  if (version_.HasHeaderProtection() &&
+      !ApplyHeaderProtection(level, buffer, ad_len + output_length, ad_len)) {
+    QUIC_DLOG(ERROR) << "Applying header protection failed.";
+    RaiseError(QUIC_ENCRYPTION_FAILURE);
+    return 0;
+  }
 
   return ad_len + output_length;
 }
 
+namespace {
+
+const size_t kHPSampleLen = 16;  // Ciphertext sample length in bytes.
+
+constexpr bool IsLongHeader(uint8_t type_byte) {
+  return (type_byte & FLAGS_LONG_HEADER) != 0;
+}
+
+}  // namespace
+
+bool QuicFramer::ApplyHeaderProtection(EncryptionLevel level,
+                                       char* buffer,
+                                       size_t buffer_len,
+                                       size_t ad_len) {
+  QuicDataReader buffer_reader(buffer, buffer_len);
+  QuicDataWriter buffer_writer(buffer_len, buffer);
+  // |ad_len| must cover the packet number, or the subtraction would wrap.
+  if (ad_len < last_written_packet_number_length_) {
+    return false;
+  }
+  size_t pn_offset = ad_len - last_written_packet_number_length_;
+  // Sample the ciphertext 4 bytes past |pn_offset| to generate the mask.
+  size_t sample_offset = pn_offset + 4;
+  QuicDataReader sample_reader(buffer, buffer_len);
+  QuicStringPiece sample;
+  if (!sample_reader.Seek(sample_offset) ||
+      !sample_reader.ReadStringPiece(&sample, kHPSampleLen)) {
+    QUIC_BUG << "Not enough bytes to sample: sample_offset " << sample_offset
+             << ", sample len: " << kHPSampleLen
+             << ", buffer len: " << buffer_len;
+    return false;
+  }
+
+  std::string mask = encrypter_[level]->GenerateHeaderProtectionMask(sample);
+  if (mask.empty()) {
+    QUIC_BUG << "Unable to generate header protection mask.";
+    return false;
+  }
+  QuicDataReader mask_reader(mask.data(), mask.size());
+
+  // Mask the low 5 bits of the first byte (short header) or low 4 bits (long).
+  uint8_t bitmask = 0x1f;
+  uint8_t type_byte;
+  if (!buffer_reader.ReadUInt8(&type_byte)) {
+    return false;
+  }
+  QuicLongHeaderType header_type;
+  if (IsLongHeader(type_byte)) {
+    bitmask = 0x0f;
+    if (!GetLongHeaderType(version_.transport_version, type_byte,
+                           &header_type)) {
+      return false;
+    }
+  }
+  uint8_t mask_byte;
+  if (!mask_reader.ReadUInt8(&mask_byte) ||
+      !buffer_writer.WriteUInt8(type_byte ^ (mask_byte & bitmask))) {
+    return false;
+  }
+
+  // Here the nonce follows the packet number, so move |pn_offset| before it.
+  if (IsLongHeader(type_byte) && header_type == ZERO_RTT_PROTECTED &&
+      perspective_ == Perspective::IS_SERVER &&
+      version_.handshake_protocol == PROTOCOL_QUIC_CRYPTO) {
+    if (pn_offset <= kDiversificationNonceSize) {
+      QUIC_BUG << "Expected diversification nonce, but not enough bytes";
+      return false;
+    }
+    pn_offset -= kDiversificationNonceSize;
+  }
+  // Advance the reader and writer to the packet number. Each has already
+  // consumed one byte (the type byte), hence the seek by |pn_offset| - 1.
+  if (!buffer_writer.Seek(pn_offset - 1) ||
+      !buffer_reader.Seek(pn_offset - 1)) {
+    return false;
+  }
+  // XOR the following mask bytes into the packet number, one byte each.
+  for (size_t i = 0; i < last_written_packet_number_length_; ++i) {
+    uint8_t buffer_byte;
+    uint8_t mask_byte;
+    if (!mask_reader.ReadUInt8(&mask_byte) ||
+        !buffer_reader.ReadUInt8(&buffer_byte) ||
+        !buffer_writer.WriteUInt8(buffer_byte ^ mask_byte)) {
+      return false;
+    }
+  }
+  return true;
+}
+
+bool QuicFramer::RemoveHeaderProtection(QuicDataReader* reader,
+                                        const QuicEncryptedPacket& packet,
+                                        QuicPacketHeader* header,
+                                        uint64_t* full_packet_number,
+                                        std::vector<char>* associated_data) {
+  EncryptionLevel expected_decryption_level = GetEncryptionLevel(*header);
+  QuicDecrypter* decrypter = decrypter_[expected_decryption_level].get();
+  if (decrypter == nullptr) {
+    QUIC_DVLOG(1)
+        << "No decrypter available for removing header protection at level "
+        << expected_decryption_level;
+    return false;
+  }
+
+  bool has_diversification_nonce =
+      header->form == IETF_QUIC_LONG_HEADER_PACKET &&
+      header->long_packet_type == ZERO_RTT_PROTECTED &&
+      perspective_ == Perspective::IS_CLIENT &&
+      version_.handshake_protocol == PROTOCOL_QUIC_CRYPTO;
+
+  // Peek at the remaining payload (leaving |reader| untouched) to sample the
+  // ciphertext and compute the header protection mask.
+  QuicStringPiece remaining_packet = reader->PeekRemainingPayload();
+  QuicDataReader sample_reader(remaining_packet);
+
+  // The sample starts 4 bytes after the start of the packet number; skip them.
+  QuicStringPiece pn;
+  if (!sample_reader.ReadStringPiece(&pn, 4)) {
+    QUIC_DVLOG(1) << "Not enough data to sample";
+    return false;
+  }
+  if (has_diversification_nonce) {
+    // In Google QUIC, the diversification nonce comes between the packet number
+    // and the sample.
+    if (!sample_reader.Seek(kDiversificationNonceSize)) {
+      QUIC_DVLOG(1) << "No diversification nonce to skip over";
+      return false;
+    }
+  }
+  std::string mask = decrypter->GenerateHeaderProtectionMask(&sample_reader);
+  QuicDataReader mask_reader(mask.data(), mask.size());
+  if (mask.empty()) {
+    QUIC_DVLOG(1) << "Failed to compute mask";
+    return false;
+  }
+
+  // Unmask the low 5 bits (short header) or 4 bits (long) of the type byte.
+  uint8_t bitmask = 0x1f;
+  if (IsLongHeader(header->type_byte)) {
+    bitmask = 0x0f;
+  }
+  uint8_t mask_byte;
+  if (!mask_reader.ReadUInt8(&mask_byte)) {
+    QUIC_DVLOG(1) << "No first byte to read from mask";
+    return false;
+  }
+  header->type_byte ^= (mask_byte & bitmask);
+
+  // The low two bits of the unmasked type byte hold the length minus one.
+  header->packet_number_length =
+      static_cast<QuicPacketNumberLength>((header->type_byte & 0x03) + 1);
+
+  char pn_buffer[IETF_MAX_PACKET_NUMBER_LENGTH] = {};
+  QuicDataWriter pn_writer(QUIC_ARRAYSIZE(pn_buffer), pn_buffer);
+
+  // Read each protected packet number byte from |reader| and XOR it with the
+  // next mask byte, collecting the result in |pn_buffer|.
+  for (size_t i = 0; i < header->packet_number_length; ++i) {
+    uint8_t protected_pn_byte, mask_byte;
+    if (!mask_reader.ReadUInt8(&mask_byte) ||
+        !reader->ReadUInt8(&protected_pn_byte) ||
+        !pn_writer.WriteUInt8(protected_pn_byte ^ mask_byte)) {
+      QUIC_DVLOG(1) << "Failed to unmask packet number";
+      return false;
+    }
+  }
+  QuicDataReader packet_number_reader(pn_writer.data(), pn_writer.length());
+  QuicPacketNumber base_packet_number;
+  if (supports_multiple_packet_number_spaces_) {
+    PacketNumberSpace pn_space = GetPacketNumberSpace(*header);
+    if (pn_space == NUM_PACKET_NUMBER_SPACES) {
+      return false;
+    }
+    base_packet_number = largest_decrypted_packet_numbers_[pn_space];
+  } else {
+    base_packet_number = largest_packet_number_;
+  }
+  if (!ProcessAndCalculatePacketNumber(
+          &packet_number_reader, header->packet_number_length,
+          base_packet_number, full_packet_number)) {
+    return false;
+  }
+
+  // Copy the associated data so the unmasked header bytes can be written in.
+  QuicStringPiece ad = GetAssociatedDataFromEncryptedPacket(
+      version_.transport_version, packet,
+      GetIncludedDestinationConnectionIdLength(*header),
+      GetIncludedSourceConnectionIdLength(*header), header->version_flag,
+      has_diversification_nonce, header->packet_number_length,
+      header->retry_token_length_length, header->retry_token.length(),
+      header->length_length);
+  *associated_data = std::vector<char>(ad.begin(), ad.end());
+  QuicDataWriter ad_writer(associated_data->size(), associated_data->data());
+
+  // Overwrite the first byte of |associated_data| with the unmasked type byte.
+  if (!ad_writer.WriteUInt8(header->type_byte)) {
+    return false;
+  }
+  // Put the packet number at the end of the AD, or if there's a diversification
+  // nonce, before that (which is at the end of the AD).
+  size_t seek_len = ad_writer.remaining() - header->packet_number_length;
+  if (has_diversification_nonce) {
+    seek_len -= kDiversificationNonceSize;
+  }
+  if (!ad_writer.Seek(seek_len) ||
+      !ad_writer.WriteBytes(pn_writer.data(), pn_writer.length())) {
+    QUIC_DVLOG(1) << "Failed to apply unmasking operations to AD";
+    return false;
+  }
+
+  return true;
+}
+
 size_t QuicFramer::EncryptPayload(EncryptionLevel level,
                                   QuicPacketNumber packet_number,
                                   const QuicPacket& packet,
@@ -4012,6 +4258,12 @@
     RaiseError(QUIC_ENCRYPTION_FAILURE);
     return 0;
   }
+  if (version_.HasHeaderProtection() &&
+      !ApplyHeaderProtection(level, buffer, ad_len + output_length, ad_len)) {
+    QUIC_DLOG(ERROR) << "Applying header protection failed.";
+    RaiseError(QUIC_ENCRYPTION_FAILURE);
+    return 0;
+  }
 
   return ad_len + output_length;
 }
@@ -5256,7 +5508,9 @@
   QUIC_DLOG(INFO) << ENDPOINT << "Error: " << QuicErrorCodeToString(error)
                   << " detail: " << detailed_error_;
   set_error(error);
-  visitor_->OnError(this);
+  if (visitor_) {
+    visitor_->OnError(this);
+  }
   return false;
 }