blob: 035b9f9d030ed928691ec3a7dd7951b8970c2657 [file] [log] [blame]
// Copyright 2019 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "quic/masque/masque_compression_engine.h"
#include <cstdint>
#include "absl/strings/string_view.h"
#include "quic/core/quic_buffer_allocator.h"
#include "quic/core/quic_data_reader.h"
#include "quic/core/quic_data_writer.h"
#include "quic/core/quic_framer.h"
#include "quic/core/quic_session.h"
#include "quic/core/quic_types.h"
#include "quic/core/quic_versions.h"
#include "quic/platform/api/quic_containers.h"
#include "common/quiche_text_utils.h"
namespace quic {
namespace {
// Address-family tags written into a compression context registration (and
// checked when one is parsed) to say how many packed IP bytes follow.
enum MasqueAddressFamily : uint8_t {
  MasqueAddressFamilyIPv4 = 4,
  MasqueAddressFamilyIPv6 = 6,
};

// Reserved flow ID: a datagram that starts with it carries a complete
// compression context registration instead of referring to a known flow.
constexpr QuicDatagramStreamId kFlowId0 = 0;
}  // namespace
// Flow IDs are split by parity so the two endpoints never allocate the same
// one. FindOrCreateCompressionContext adds 2 before handing a value out, so
// the client allocates 2, 4, 6, ... and the server 3, 5, 7, ...
MasqueCompressionEngine::MasqueCompressionEngine(
    QuicSpdySession* masque_session)
    : masque_session_(masque_session),
      // Reads the constructor parameter rather than |masque_session_| so this
      // initializer does not depend on member declaration order.
      next_available_flow_id_(
          masque_session->perspective() == Perspective::IS_SERVER ? 1 : 0) {}
// Returns the flow ID to use when compressing a packet for |server_address|.
//
// An existing context is reused when its server address equals
// |server_address| and, for each connection ID flagged as present, the stored
// connection ID is equal to the given one. Otherwise a new context is stored
// under the next locally allocated flow ID.
// |*validated| is set to false up front and overwritten with the validated
// bit of the first matching context, if any.
QuicDatagramStreamId MasqueCompressionEngine::FindOrCreateCompressionContext(
    QuicConnectionId client_connection_id,
    QuicConnectionId server_connection_id,
    const QuicSocketAddress& server_address, bool client_connection_id_present,
    bool server_connection_id_present, bool* validated) {
  *validated = false;

  const auto matches = [&](const MasqueCompressionContext& candidate) {
    if (candidate.server_address != server_address) {
      return false;
    }
    if (client_connection_id_present &&
        candidate.client_connection_id != client_connection_id) {
      return false;
    }
    if (server_connection_id_present &&
        candidate.server_connection_id != server_connection_id) {
      return false;
    }
    return true;
  };

  for (const auto& entry : contexts_) {
    if (!matches(entry.second)) {
      continue;
    }
    const QuicDatagramStreamId existing_flow_id = entry.first;
    const MasqueCompressionContext& existing = entry.second;
    QUICHE_DCHECK_NE(existing_flow_id, kFlowId0);
    *validated = existing.validated;
    QUIC_DVLOG(1) << "Compressing using " << (*validated ? "" : "un")
                  << "validated flow_id " << existing_flow_id << " to "
                  << existing.server_address << " client "
                  << existing.client_connection_id << " server "
                  << existing.server_connection_id;
    if (existing_flow_id != kFlowId0) {
      return existing_flow_id;
    }
    // Only the first match is considered; a match stored under the reserved
    // ID falls through to allocating a fresh context below.
    break;
  }

  // No reusable context: take the next flow ID of this endpoint's parity.
  next_available_flow_id_ += 2;
  const QuicDatagramStreamId new_flow_id = next_available_flow_id_;
  QUIC_DVLOG(1) << "Compression assigning new flow_id " << new_flow_id
                << " to " << server_address << " client "
                << client_connection_id << " server " << server_connection_id;
  MasqueCompressionContext fresh_context;
  fresh_context.client_connection_id = client_connection_id;
  fresh_context.server_connection_id = server_connection_id;
  fresh_context.server_address = server_address;
  contexts_[new_flow_id] = fresh_context;
  return new_flow_id;
}
// Serializes one compressed packet into |writer|, consuming the original
// packet (after its first byte) from |reader|.
//
// Written layout:
//   unvalidated: 0, flow_id, client CID (length-prefixed),
//                server CID (length-prefixed), port (uint16),
//                address family tag (uint8), packed IP bytes
//   validated:   flow_id
//   then:        first_byte, [version (long header only)], rest of packet
//
// The connection IDs in the original header are read from |reader| and must
// equal |destination_connection_id| (and |source_connection_id| for long
// headers), but they are not copied to |writer|.
// Returns false, after logging, on any read/write failure or mismatch.
bool MasqueCompressionEngine::WriteCompressedPacketToSlice(
    QuicConnectionId client_connection_id,
    QuicConnectionId server_connection_id,
    const QuicSocketAddress& server_address,
    QuicConnectionId destination_connection_id,
    QuicConnectionId source_connection_id, QuicDatagramStreamId flow_id,
    bool validated, uint8_t first_byte, bool long_header,
    QuicDataReader* reader, QuicDataWriter* writer) {
  if (!validated) {
    QUIC_DVLOG(1) << "Compressing using unvalidated flow_id " << flow_id;
    // Flow ID 0 announces that the full context follows in-line.
    if (!writer->WriteVarInt62(kFlowId0)) {
      QUIC_BUG(quic_bug_10981_2) << "Failed to write kFlowId0";
      return false;
    }
    if (!writer->WriteVarInt62(flow_id)) {
      QUIC_BUG(quic_bug_10981_3) << "Failed to write flow_id";
      return false;
    }
    if (!writer->WriteLengthPrefixedConnectionId(client_connection_id)) {
      QUIC_BUG(quic_bug_10981_4) << "Failed to write client_connection_id";
      return false;
    }
    if (!writer->WriteLengthPrefixedConnectionId(server_connection_id)) {
      QUIC_BUG(quic_bug_10981_5) << "Failed to write server_connection_id";
      return false;
    }
    if (!writer->WriteUInt16(server_address.port())) {
      QUIC_BUG(quic_bug_10981_6) << "Failed to write port";
      return false;
    }
    const QuicIpAddress host = server_address.host();
    QUICHE_DCHECK(host.IsInitialized());
    const std::string packed_host = host.ToPackedString();
    QUICHE_DCHECK(!packed_host.empty());
    uint8_t family_tag;
    const IpAddressFamily family = host.address_family();
    if (family == IpAddressFamily::IP_V6) {
      family_tag = MasqueAddressFamilyIPv6;
      if (packed_host.length() != QuicIpAddress::kIPv6AddressSize) {
        QUIC_BUG(quic_bug_10981_7) << "Bad IPv6 length " << server_address;
        return false;
      }
    } else if (family == IpAddressFamily::IP_V4) {
      family_tag = MasqueAddressFamilyIPv4;
      if (packed_host.length() != QuicIpAddress::kIPv4AddressSize) {
        QUIC_BUG(quic_bug_10981_8) << "Bad IPv4 length " << server_address;
        return false;
      }
    } else {
      QUIC_BUG(quic_bug_10981_9)
          << "Unexpected server_address " << server_address;
      return false;
    }
    if (!writer->WriteUInt8(family_tag)) {
      QUIC_BUG(quic_bug_10981_10) << "Failed to write address_id";
      return false;
    }
    if (!writer->WriteStringPiece(packed_host)) {
      QUIC_BUG(quic_bug_10981_11) << "Failed to write IP address";
      return false;
    }
  } else {
    QUIC_DVLOG(1) << "Compressing using validated flow_id " << flow_id;
    if (!writer->WriteVarInt62(flow_id)) {
      QUIC_BUG(quic_bug_10981_1) << "Failed to write flow_id";
      return false;
    }
  }

  if (!writer->WriteUInt8(first_byte)) {
    QUIC_BUG(quic_bug_10981_12) << "Failed to write first_byte";
    return false;
  }

  if (!long_header) {
    // Short header: only the destination connection ID, with no length byte.
    QuicConnectionId short_header_dcid;
    if (!reader->ReadConnectionId(&short_header_dcid,
                                  destination_connection_id.length())) {
      QUIC_DLOG(ERROR)
          << "Failed to read short header packet's destination_connection_id";
      return false;
    }
    if (short_header_dcid != destination_connection_id) {
      QUIC_DLOG(ERROR) << "Short header packet's destination_connection_id "
                       << short_header_dcid << " does not match expected "
                       << destination_connection_id;
      return false;
    }
  } else {
    // Long header: the version is kept, both connection IDs are dropped.
    QuicVersionLabel version;
    if (!reader->ReadUInt32(&version)) {
      QUIC_DLOG(ERROR) << "Failed to read version";
      return false;
    }
    if (!writer->WriteUInt32(version)) {
      QUIC_BUG(quic_bug_10981_13) << "Failed to write version";
      return false;
    }
    QuicConnectionId long_header_dcid;
    QuicConnectionId long_header_scid;
    if (!reader->ReadLengthPrefixedConnectionId(&long_header_dcid) ||
        !reader->ReadLengthPrefixedConnectionId(&long_header_scid)) {
      QUIC_DLOG(ERROR) << "Failed to parse long header connection IDs";
      return false;
    }
    if (long_header_dcid != destination_connection_id) {
      QUIC_DLOG(ERROR) << "Long header packet's destination_connection_id "
                       << long_header_dcid << " does not match expected "
                       << destination_connection_id;
      return false;
    }
    if (long_header_scid != source_connection_id) {
      QUIC_DLOG(ERROR) << "Long header packet's source_connection_id "
                       << long_header_scid << " does not match expected "
                       << source_connection_id;
      return false;
    }
  }

  const absl::string_view rest_of_packet = reader->ReadRemainingPayload();
  if (!writer->WriteStringPiece(rest_of_packet)) {
    QUIC_BUG(quic_bug_10981_14) << "Failed to write packet_payload";
    return false;
  }
  return true;
}
// Compresses |packet|, a QUIC packet exchanged with |server_address|, and
// sends it as a message on the MASQUE session.
//
// The connection IDs in the packet header are stripped and replaced by a flow
// ID; for a context not yet validated, the full context is prepended instead
// (see WriteCompressedPacketToSlice). Malformed packets are dropped after
// logging; nothing is returned to the caller.
void MasqueCompressionEngine::CompressAndSendPacket(
    absl::string_view packet, QuicConnectionId client_connection_id,
    QuicConnectionId server_connection_id,
    const QuicSocketAddress& server_address) {
  QUIC_DVLOG(2) << "Compressing client " << client_connection_id << " server "
                << server_connection_id << "\n"
                << quiche::QuicheTextUtils::HexDump(packet);
  QUICHE_DCHECK(server_address.IsInitialized());
  if (packet.empty()) {
    QUIC_BUG(quic_bug_10981_15) << "Tried to send empty packet";
    return;
  }
  QuicDataReader reader(packet.data(), packet.length());
  uint8_t first_byte;
  if (!reader.ReadUInt8(&first_byte)) {
    QUIC_BUG(quic_bug_10981_16) << "Failed to read first_byte";
    return;
  }
  const bool long_header = (first_byte & FLAGS_LONG_HEADER) != 0;
  bool client_connection_id_present = true, server_connection_id_present = true;
  QuicConnectionId destination_connection_id, source_connection_id;
  if (masque_session_->perspective() == Perspective::IS_SERVER) {
    destination_connection_id = client_connection_id;
    source_connection_id = server_connection_id;
    if (!long_header) {
      server_connection_id_present = false;
    }
  } else {
    destination_connection_id = server_connection_id;
    source_connection_id = client_connection_id;
    if (!long_header) {
      client_connection_id_present = false;
    }
  }
  // Reject packets too short to contain the header that is about to be
  // stripped. Without this check the unsigned arithmetic computing
  // |slice_length| below can underflow and request an enormous buffer. A
  // compression context would also be created for a packet that can never be
  // sent.
  size_t min_packet_length =
      sizeof(first_byte) + destination_connection_id.length();
  if (long_header) {
    // Version, one length byte per connection ID, and the source CID.
    min_packet_length += sizeof(QuicVersionLabel) + sizeof(uint8_t) * 2 +
                         source_connection_id.length();
  }
  if (packet.length() < min_packet_length) {
    QUIC_DLOG(ERROR) << "Dropping packet of length " << packet.length()
                     << " shorter than minimum " << min_packet_length;
    return;
  }
  bool validated = false;
  QuicDatagramStreamId flow_id = FindOrCreateCompressionContext(
      client_connection_id, server_connection_id, server_address,
      client_connection_id_present, server_connection_id_present, &validated);
  // Exact output size: the original packet minus the stripped connection IDs
  // (and their length bytes for long headers), plus the flow ID prefix and,
  // for unvalidated flows, the in-line context registration.
  size_t slice_length = packet.length() - destination_connection_id.length();
  if (long_header) {
    slice_length -= sizeof(uint8_t) * 2 + source_connection_id.length();
  }
  if (validated) {
    slice_length += QuicDataWriter::GetVarInt62Len(flow_id);
  } else {
    slice_length += QuicDataWriter::GetVarInt62Len(kFlowId0) +
                    QuicDataWriter::GetVarInt62Len(flow_id) + sizeof(uint8_t) +
                    client_connection_id.length() + sizeof(uint8_t) +
                    server_connection_id.length() +
                    sizeof(server_address.port()) + sizeof(uint8_t) +
                    server_address.host().ToPackedString().length();
  }
  QuicBuffer buffer(
      masque_session_->connection()->helper()->GetStreamSendBufferAllocator(),
      slice_length);
  QuicDataWriter writer(buffer.size(), buffer.data());
  if (!WriteCompressedPacketToSlice(
          client_connection_id, server_connection_id, server_address,
          destination_connection_id, source_connection_id, flow_id, validated,
          first_byte, long_header, &reader, &writer)) {
    return;
  }
  MessageResult message_result =
      masque_session_->SendMessage(QuicMemSlice(std::move(buffer)));
  QUIC_DVLOG(1) << "Sent packet compressed with flow ID " << flow_id
                << " and got message result " << message_result;
}
// Parses an in-line compression context registration (the bytes following a
// leading flow ID 0) from |reader| into |context|.
//
// Wire layout: flow_id (varint62), client CID (length-prefixed), server CID
// (length-prefixed), port (uint16), address family tag (uint8), packed IP.
//
// An unknown flow_id is stored as a new, validated context. A known flow_id
// must match the stored context exactly; if it does and that context was not
// yet validated, it becomes validated.
// Returns false on malformed input, on the reserved flow ID, or on a mismatch.
bool MasqueCompressionEngine::ParseCompressionContext(
    QuicDataReader* reader, MasqueCompressionContext* context) {
  QuicDatagramStreamId new_flow_id;
  if (!reader->ReadVarInt62(&new_flow_id)) {
    QUIC_DLOG(ERROR) << "Could not read new_flow_id";
    return false;
  }
  // Flow ID 0 only introduces registrations and is never allocated locally.
  // Storing a context under it would let FindOrCreateCompressionContext match
  // it later. In debug builds that trips the QUICHE_DCHECK_NE there. In release
  // builds it keeps allocating new contexts while reporting them as validated.
  if (new_flow_id == kFlowId0) {
    QUIC_DLOG(ERROR) << "Received context registration for reserved flow_id "
                     << new_flow_id;
    return false;
  }
  QuicConnectionId new_client_connection_id;
  if (!reader->ReadLengthPrefixedConnectionId(&new_client_connection_id)) {
    QUIC_DLOG(ERROR) << "Could not read new_client_connection_id";
    return false;
  }
  QuicConnectionId new_server_connection_id;
  if (!reader->ReadLengthPrefixedConnectionId(&new_server_connection_id)) {
    QUIC_DLOG(ERROR) << "Could not read new_server_connection_id";
    return false;
  }
  uint16_t port;
  if (!reader->ReadUInt16(&port)) {
    QUIC_DLOG(ERROR) << "Could not read port";
    return false;
  }
  uint8_t address_id;
  if (!reader->ReadUInt8(&address_id)) {
    QUIC_DLOG(ERROR) << "Could not read address_id";
    return false;
  }
  size_t ip_bytes_length;
  if (address_id == MasqueAddressFamilyIPv6) {
    ip_bytes_length = QuicIpAddress::kIPv6AddressSize;
  } else if (address_id == MasqueAddressFamilyIPv4) {
    ip_bytes_length = QuicIpAddress::kIPv4AddressSize;
  } else {
    QUIC_DLOG(ERROR) << "Unknown address_id " << static_cast<int>(address_id);
    return false;
  }
  char ip_bytes[QuicIpAddress::kMaxAddressSize];
  if (!reader->ReadBytes(ip_bytes, ip_bytes_length)) {
    QUIC_DLOG(ERROR) << "Could not read IP address";
    return false;
  }
  QuicIpAddress ip_address;
  ip_address.FromPackedString(ip_bytes, ip_bytes_length);
  if (!ip_address.IsInitialized()) {
    QUIC_BUG(quic_bug_10981_17) << "Failed to parse IP address";
    return false;
  }
  QuicSocketAddress new_server_address = QuicSocketAddress(ip_address, port);
  auto context_pair = contexts_.find(new_flow_id);
  if (context_pair == contexts_.end()) {
    context->client_connection_id = new_client_connection_id;
    context->server_connection_id = new_server_connection_id;
    context->server_address = new_server_address;
    context->validated = true;
    contexts_[new_flow_id] = *context;
    QUIC_DVLOG(1) << "Registered new flow_id " << new_flow_id << " to "
                  << new_server_address << " client "
                  << new_client_connection_id << " server "
                  << new_server_connection_id;
  } else {
    *context = context_pair->second;
    if (context->client_connection_id != new_client_connection_id) {
      QUIC_LOG(ERROR)
          << "Received incorrect context registration for existing flow_id "
          << new_flow_id << " mismatched client "
          << context->client_connection_id << " " << new_client_connection_id;
      return false;
    }
    if (context->server_connection_id != new_server_connection_id) {
      QUIC_LOG(ERROR)
          << "Received incorrect context registration for existing flow_id "
          << new_flow_id << " mismatched server "
          << context->server_connection_id << " " << new_server_connection_id;
      return false;
    }
    if (context->server_address != new_server_address) {
      QUIC_LOG(ERROR)
          << "Received incorrect context registration for existing flow_id "
          << new_flow_id << " mismatched server " << context->server_address
          << " " << new_server_address;
      return false;
    }
    if (!context->validated) {
      context->validated = true;
      contexts_[new_flow_id] = *context;
      QUIC_DLOG(INFO) << "Successfully validated remotely-unvalidated flow_id "
                      << new_flow_id << " to " << new_server_address
                      << " client " << new_client_connection_id << " server "
                      << new_server_connection_id;
    } else {
      QUIC_DVLOG(1) << "Decompressing using incoming locally-validated "
                       "remotely-unvalidated flow_id "
                    << new_flow_id << " to " << new_server_address << " client "
                    << new_client_connection_id << " server "
                    << new_server_connection_id;
    }
  }
  return true;
}
// Rebuilds the original QUIC packet into |packet| from the compressed bytes
// left in |reader|, re-inserting the connection IDs stored in |context|.
// Sets |*version_present| to whether the first byte has the long header bit.
// Returns false, after logging, on malformed input or a write failure.
bool MasqueCompressionEngine::WriteDecompressedPacket(
    QuicDataReader* reader, const MasqueCompressionContext& context,
    std::vector<char>* packet, bool* version_present) {
  // On the server the rebuilt packet is addressed to the server connection
  // ID; on the client, to the client connection ID.
  const bool is_server =
      masque_session_->perspective() == Perspective::IS_SERVER;
  const QuicConnectionId& destination_connection_id =
      is_server ? context.server_connection_id : context.client_connection_id;
  const QuicConnectionId& source_connection_id =
      is_server ? context.client_connection_id : context.server_connection_id;

  // Everything still in |reader| (first byte included) is copied through, so
  // the output is that plus the connection ID(s) compression removed.
  size_t decompressed_length =
      reader->BytesRemaining() + destination_connection_id.length();
  uint8_t first_byte;
  if (!reader->ReadUInt8(&first_byte)) {
    QUIC_DLOG(ERROR) << "Failed to read first_byte";
    return false;
  }
  const bool long_header = (first_byte & FLAGS_LONG_HEADER) != 0;
  *version_present = long_header;
  if (long_header) {
    // One length byte per connection ID, plus the source connection ID.
    decompressed_length += sizeof(uint8_t) * 2 + source_connection_id.length();
  }
  packet->assign(decompressed_length, '\0');
  QuicDataWriter writer(packet->size(), packet->data());

  if (!writer.WriteUInt8(first_byte)) {
    QUIC_BUG(quic_bug_10981_18) << "Failed to write first_byte";
    return false;
  }
  if (!long_header) {
    if (!writer.WriteConnectionId(destination_connection_id)) {
      QUIC_BUG(quic_bug_10981_22)
          << "Failed to write short header destination_connection_id";
      return false;
    }
  } else {
    QuicVersionLabel version;
    if (!reader->ReadUInt32(&version)) {
      QUIC_DLOG(ERROR) << "Failed to read version";
      return false;
    }
    if (!writer.WriteUInt32(version)) {
      QUIC_BUG(quic_bug_10981_19) << "Failed to write version";
      return false;
    }
    if (!writer.WriteLengthPrefixedConnectionId(destination_connection_id)) {
      QUIC_BUG(quic_bug_10981_20)
          << "Failed to write long header destination_connection_id";
      return false;
    }
    if (!writer.WriteLengthPrefixedConnectionId(source_connection_id)) {
      QUIC_BUG(quic_bug_10981_21)
          << "Failed to write long header source_connection_id";
      return false;
    }
  }

  const absl::string_view remaining_payload = reader->ReadRemainingPayload();
  if (!writer.WriteStringPiece(remaining_payload)) {
    QUIC_BUG(quic_bug_10981_23) << "Failed to write payload";
    return false;
  }
  return true;
}
// Turns a received DATAGRAM payload back into a full QUIC packet.
//
// The datagram starts with a flow ID. Flow ID 0 introduces an in-line context
// registration, parsed by ParseCompressionContext. Any other value must name a
// known context, which is marked validated the first time the peer uses it.
// On success, the context's connection IDs and server address are written to
// the out-parameters along with the rebuilt |packet| and |version_present|.
bool MasqueCompressionEngine::DecompressDatagram(
    absl::string_view datagram, QuicConnectionId* client_connection_id,
    QuicConnectionId* server_connection_id, QuicSocketAddress* server_address,
    std::vector<char>* packet, bool* version_present) {
  QUIC_DVLOG(1) << "Decompressing DATAGRAM frame of length "
                << datagram.length();
  QuicDataReader reader(datagram);
  QuicDatagramStreamId flow_id;
  if (!reader.ReadVarInt62(&flow_id)) {
    QUIC_DLOG(ERROR) << "Could not read flow_id";
    return false;
  }

  MasqueCompressionContext context;
  if (flow_id != kFlowId0) {
    const auto known = contexts_.find(flow_id);
    if (known == contexts_.end()) {
      QUIC_DLOG(ERROR) << "Received unknown flow_id " << flow_id;
      return false;
    }
    context = known->second;
    if (context.validated) {
      QUIC_DVLOG(1) << "Decompressing using incoming locally-validated "
                       "remotely-validated flow_id "
                    << flow_id << " to " << context.server_address << " client "
                    << context.client_connection_id << " server "
                    << context.server_connection_id;
    } else {
      // The peer used this flow_id bare, so later packets sent on it can omit
      // the in-line context registration.
      context.validated = true;
      contexts_[flow_id] = context;
      QUIC_DLOG(INFO) << "Successfully validated remotely-validated flow_id "
                      << flow_id << " to " << context.server_address
                      << " client " << context.client_connection_id
                      << " server " << context.server_connection_id;
    }
  } else if (!ParseCompressionContext(&reader, &context)) {
    return false;
  }

  if (!WriteDecompressedPacket(&reader, context, packet, version_present)) {
    return false;
  }
  *client_connection_id = context.client_connection_id;
  *server_connection_id = context.server_connection_id;
  *server_address = context.server_address;
  QUIC_DVLOG(2) << "Decompressed client " << context.client_connection_id
                << " server " << context.server_connection_id << "\n"
                << quiche::QuicheTextUtils::HexDump(
                       absl::string_view(packet->data(), packet->size()));
  return true;
}
// Removes every compression context whose client connection ID equals
// |client_connection_id|.
void MasqueCompressionEngine::UnregisterClientConnectionId(
    QuicConnectionId client_connection_id) {
  // Collect the matching flow IDs first and erase afterwards, so |contexts_|
  // is never modified while it is being iterated.
  std::vector<QuicDatagramStreamId> doomed_flow_ids;
  for (const auto& entry : contexts_) {
    if (entry.second.client_connection_id != client_connection_id) {
      continue;
    }
    doomed_flow_ids.push_back(entry.first);
  }
  for (const QuicDatagramStreamId doomed : doomed_flow_ids) {
    contexts_.erase(doomed);
  }
}
} // namespace quic