// blob: 99def666dd96fa94fc9dc6a4f675c3c75697b118 [file] [log] [blame]
// Copyright 2019 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "net/third_party/quiche/src/quic/masque/masque_compression_engine.h"
#include <cstdint>
#include "net/third_party/quiche/src/quic/core/quic_data_reader.h"
#include "net/third_party/quiche/src/quic/core/quic_data_writer.h"
#include "net/third_party/quiche/src/quic/core/quic_framer.h"
#include "net/third_party/quiche/src/quic/core/quic_session.h"
#include "net/third_party/quiche/src/quic/core/quic_types.h"
#include "net/third_party/quiche/src/quic/core/quic_versions.h"
#include "net/third_party/quiche/src/quic/platform/api/quic_containers.h"
#include "net/third_party/quiche/src/quic/platform/api/quic_string_piece.h"
#include "net/third_party/quiche/src/quic/platform/api/quic_text_utils.h"
namespace quic {
namespace {

// Flow ID 0 never identifies a compression context. A datagram that starts
// with it carries a context registration (new flow ID, both connection IDs
// and the server address) ahead of the compressed packet.
constexpr QuicDatagramFlowId kFlowId0 = 0;

// One-byte tag written before the packed server IP address in a context
// registration; it tells the receiver how many address bytes follow.
// Kept as an unscoped enum because the values are compared with and assigned
// to raw uint8_t wire bytes.
enum MasqueAddressFamily : uint8_t {
  MasqueAddressFamilyIPv4 = 4,
  MasqueAddressFamilyIPv6 = 6,
};

}  // namespace
// Builds a compression engine bound to |masque_session|, which is used both
// to learn the local perspective and, later, to send compressed datagrams.
MasqueCompressionEngine::MasqueCompressionEngine(QuicSession* masque_session)
    : masque_session_(masque_session) {
  // The server starts allocating flow IDs at 1 and the client at 2. Since
  // GetNextFlowId() advances in steps of two, the two endpoints draw from
  // disjoint (odd vs. even) sets and flow ID 0 is never handed out.
  const bool is_server =
      masque_session_->perspective() == Perspective::IS_SERVER;
  next_flow_id_ = is_server ? 1 : 2;
}
// Compresses |packet| and sends it to the peer as a datagram message on
// |masque_session_|.
//
// Compression removes the connection IDs from the QUIC packet header and
// replaces them with a flow ID that maps to the tuple
// (|client_connection_id|, |server_connection_id|, |server_socket_address|).
// If no context exists for that tuple a new flow ID is allocated. Until the
// context has been validated, every datagram is prefixed with kFlowId0
// followed by a full context registration (flow ID, both connection IDs,
// port, address family and packed IP address).
//
// |packet| must be non-empty and its header connection IDs must match the
// supplied ones; otherwise the packet is dropped and an error is logged.
// Nothing is returned: the outcome of the send is only logged.
void MasqueCompressionEngine::CompressAndSendPacket(
    QuicStringPiece packet,
    QuicConnectionId client_connection_id,
    QuicConnectionId server_connection_id,
    const QuicSocketAddress& server_socket_address) {
  QUIC_DVLOG(2) << "Compressing client " << client_connection_id << " server "
                << server_connection_id << "\n"
                << QuicTextUtils::HexDump(packet);
  DCHECK(server_socket_address.IsInitialized());
  if (packet.empty()) {
    QUIC_BUG << "Tried to send empty packet";
    return;
  }
  QuicDataReader reader(packet.data(), packet.length());
  uint8_t first_byte;
  if (!reader.ReadUInt8(&first_byte)) {
    QUIC_BUG << "Failed to read first_byte";
    return;
  }
  const bool long_header = (first_byte & FLAGS_LONG_HEADER) != 0;

  // Short header packets only carry a destination connection ID, so the
  // local endpoint's own connection ID cannot be used to match a context.
  bool client_connection_id_present = true, server_connection_id_present = true;
  QuicConnectionId destination_connection_id, source_connection_id;
  if (masque_session_->perspective() == Perspective::IS_SERVER) {
    destination_connection_id = client_connection_id;
    source_connection_id = server_connection_id;
    if (!long_header) {
      server_connection_id_present = false;
    }
  } else {
    destination_connection_id = server_connection_id;
    source_connection_id = client_connection_id;
    if (!long_header) {
      client_connection_id_present = false;
    }
  }

  // Look for an existing context matching this tuple. Iterate by reference:
  // each entry holds two connection IDs and a socket address, which are not
  // worth copying for every comparison.
  QuicDatagramFlowId flow_id = 0;
  bool validated = false;
  for (const auto& kv : contexts_) {
    const MasqueCompressionContext& context = kv.second;
    if (context.server_socket_address != server_socket_address) {
      continue;
    }
    if (client_connection_id_present &&
        context.client_connection_id != client_connection_id) {
      continue;
    }
    if (server_connection_id_present &&
        context.server_connection_id != server_connection_id) {
      continue;
    }
    flow_id = kv.first;
    DCHECK_NE(flow_id, kFlowId0);
    validated = context.validated;
    QUIC_DVLOG(1) << "Compressing using " << (validated ? "" : "un")
                  << "validated flow_id " << flow_id << " to "
                  << context.server_socket_address << " client "
                  << context.client_connection_id << " server "
                  << context.server_connection_id;
    break;
  }
  if (flow_id == 0) {
    // No matching context: allocate a flow ID and register an unvalidated
    // context for it.
    flow_id = GetNextFlowId();
    QUIC_DVLOG(1) << "Compression assigning new flow_id " << flow_id << " to "
                  << server_socket_address << " client " << client_connection_id
                  << " server " << server_connection_id;
    MasqueCompressionContext context;
    context.client_connection_id = client_connection_id;
    context.server_connection_id = server_connection_id;
    context.server_socket_address = server_socket_address;
    contexts_[flow_id] = context;
  }

  // Number of header bytes that compression removes: the destination
  // connection ID, plus (for long headers) the two connection ID length
  // bytes and the source connection ID.
  size_t stripped_length = destination_connection_id.length();
  if (long_header) {
    stripped_length += sizeof(uint8_t) * 2 + source_connection_id.length();
  }
  // Without this check the unsigned subtraction below would wrap around for
  // a truncated packet and request an enormous buffer.
  if (packet.length() < stripped_length) {
    QUIC_DLOG(ERROR) << "Packet of length " << packet.length()
                     << " is too short to contain the expected connection IDs";
    return;
  }
  size_t slice_length = packet.length() - stripped_length;
  if (validated) {
    slice_length += QuicDataWriter::GetVarInt62Len(flow_id);
  } else {
    slice_length += QuicDataWriter::GetVarInt62Len(kFlowId0) +
                    QuicDataWriter::GetVarInt62Len(flow_id) + sizeof(uint8_t) +
                    client_connection_id.length() + sizeof(uint8_t) +
                    server_connection_id.length() +
                    sizeof(server_socket_address.port()) + sizeof(uint8_t) +
                    server_socket_address.host().ToPackedString().length();
  }
  QuicMemSlice slice(
      masque_session_->connection()->helper()->GetStreamSendBufferAllocator(),
      slice_length);
  QuicDataWriter writer(slice_length, const_cast<char*>(slice.data()));

  // Write the datagram prefix: either the validated flow ID on its own, or
  // kFlowId0 followed by the full context registration.
  if (validated) {
    QUIC_DVLOG(1) << "Compressing using validated flow_id " << flow_id;
    if (!writer.WriteVarInt62(flow_id)) {
      QUIC_BUG << "Failed to write flow_id";
      return;
    }
  } else {
    QUIC_DVLOG(1) << "Compressing using unvalidated flow_id " << flow_id;
    if (!writer.WriteVarInt62(kFlowId0)) {
      QUIC_BUG << "Failed to write kFlowId0";
      return;
    }
    if (!writer.WriteVarInt62(flow_id)) {
      QUIC_BUG << "Failed to write flow_id";
      return;
    }
    if (!writer.WriteLengthPrefixedConnectionId(client_connection_id)) {
      QUIC_BUG << "Failed to write client_connection_id";
      return;
    }
    if (!writer.WriteLengthPrefixedConnectionId(server_connection_id)) {
      QUIC_BUG << "Failed to write server_connection_id";
      return;
    }
    if (!writer.WriteUInt16(server_socket_address.port())) {
      QUIC_BUG << "Failed to write port";
      return;
    }
    QuicIpAddress peer_ip = server_socket_address.host();
    DCHECK(peer_ip.IsInitialized());
    std::string peer_ip_bytes = peer_ip.ToPackedString();
    DCHECK(!peer_ip_bytes.empty());
    uint8_t address_id;
    if (peer_ip.address_family() == IpAddressFamily::IP_V6) {
      address_id = MasqueAddressFamilyIPv6;
      if (peer_ip_bytes.length() != QuicIpAddress::kIPv6AddressSize) {
        QUIC_BUG << "Bad IPv6 length " << server_socket_address;
        return;
      }
    } else if (peer_ip.address_family() == IpAddressFamily::IP_V4) {
      address_id = MasqueAddressFamilyIPv4;
      if (peer_ip_bytes.length() != QuicIpAddress::kIPv4AddressSize) {
        QUIC_BUG << "Bad IPv4 length " << server_socket_address;
        return;
      }
    } else {
      QUIC_BUG << "Unexpected server_socket_address " << server_socket_address;
      return;
    }
    if (!writer.WriteUInt8(address_id)) {
      QUIC_BUG << "Failed to write address_id";
      return;
    }
    if (!writer.WriteStringPiece(peer_ip_bytes)) {
      QUIC_BUG << "Failed to write IP address";
      return;
    }
  }

  // Copy the packet header minus its connection IDs, checking along the way
  // that the IDs in the packet are the ones the context was built from.
  if (!writer.WriteUInt8(first_byte)) {
    QUIC_BUG << "Failed to write first_byte";
    return;
  }
  if (long_header) {
    QuicVersionLabel version_label;
    if (!reader.ReadUInt32(&version_label)) {
      QUIC_DLOG(ERROR) << "Failed to read version";
      return;
    }
    if (!writer.WriteUInt32(version_label)) {
      QUIC_BUG << "Failed to write version";
      return;
    }
    QuicConnectionId packet_destination_connection_id,
        packet_source_connection_id;
    if (!reader.ReadLengthPrefixedConnectionId(
            &packet_destination_connection_id) ||
        !reader.ReadLengthPrefixedConnectionId(&packet_source_connection_id)) {
      QUIC_DLOG(ERROR) << "Failed to parse long header connection IDs";
      return;
    }
    if (packet_destination_connection_id != destination_connection_id) {
      QUIC_DLOG(ERROR) << "Long header packet's destination_connection_id "
                       << packet_destination_connection_id
                       << " does not match expected "
                       << destination_connection_id;
      return;
    }
    if (packet_source_connection_id != source_connection_id) {
      QUIC_DLOG(ERROR) << "Long header packet's source_connection_id "
                       << packet_source_connection_id
                       << " does not match expected " << source_connection_id;
      return;
    }
  } else {
    QuicConnectionId packet_destination_connection_id;
    if (!reader.ReadConnectionId(&packet_destination_connection_id,
                                 destination_connection_id.length())) {
      QUIC_DLOG(ERROR)
          << "Failed to read short header packet's destination_connection_id";
      return;
    }
    if (packet_destination_connection_id != destination_connection_id) {
      QUIC_DLOG(ERROR) << "Short header packet's destination_connection_id "
                       << packet_destination_connection_id
                       << " does not match expected "
                       << destination_connection_id;
      return;
    }
  }

  // Everything after the header is copied through unchanged.
  QuicStringPiece packet_payload = reader.ReadRemainingPayload();
  if (!writer.WriteStringPiece(packet_payload)) {
    QUIC_BUG << "Failed to write packet_payload";
    return;
  }
  MessageResult message_result =
      masque_session_->SendMessage(QuicMemSliceSpan(&slice));
  QUIC_DVLOG(1) << "Message result " << message_result;
}
// Decompresses a datagram |message| received from the peer back into a full
// QUIC packet.
//
// The message starts with a flow ID. If it is kFlowId0 the message carries a
// context registration (new flow ID, both connection IDs, port, address
// family, IP address), which is either stored as a new context or checked
// against the existing context for that flow ID. Otherwise the flow ID must
// refer to an already known context. In both cases the context ends up
// marked as validated.
//
// On success returns true and fills in |client_connection_id|,
// |server_connection_id|, |server_socket_address|, |packet| (the rebuilt
// packet with its connection IDs restored) and |version_present| (true for
// long header packets). On failure returns false and leaves the output
// parameters untouched; note that a context may already have been registered
// or validated before a later parsing step fails.
bool MasqueCompressionEngine::DecompressMessage(
    QuicStringPiece message,
    QuicConnectionId* client_connection_id,
    QuicConnectionId* server_connection_id,
    QuicSocketAddress* server_socket_address,
    std::string* packet,
    bool* version_present) {
  QUIC_DVLOG(1) << "Decompressing message of length " << message.length();
  QuicDataReader reader(message);
  QuicDatagramFlowId flow_id;
  if (!reader.ReadVarInt62(&flow_id)) {
    QUIC_DLOG(ERROR) << "Could not read flow_id";
    return false;
  }
  MasqueCompressionContext context;
  if (flow_id == kFlowId0) {
    // Context registration: parse the description of the new flow.
    QuicDatagramFlowId new_flow_id;
    if (!reader.ReadVarInt62(&new_flow_id)) {
      QUIC_DLOG(ERROR) << "Could not read new_flow_id";
      return false;
    }
    // Flow ID 0 is the registration marker itself and can never name a
    // context. Storing one would let a peer plant an entry that
    // CompressAndSendPacket() could later match, breaking its invariant that
    // matched flow IDs are non-zero.
    if (new_flow_id == kFlowId0) {
      QUIC_DLOG(ERROR) << "Received context registration for reserved flow_id "
                       << new_flow_id;
      return false;
    }
    QuicConnectionId new_client_connection_id;
    if (!reader.ReadLengthPrefixedConnectionId(&new_client_connection_id)) {
      QUIC_DLOG(ERROR) << "Could not read new_client_connection_id";
      return false;
    }
    QuicConnectionId new_server_connection_id;
    if (!reader.ReadLengthPrefixedConnectionId(&new_server_connection_id)) {
      QUIC_DLOG(ERROR) << "Could not read new_server_connection_id";
      return false;
    }
    uint16_t port;
    if (!reader.ReadUInt16(&port)) {
      QUIC_DLOG(ERROR) << "Could not read port";
      return false;
    }
    uint8_t address_id;
    if (!reader.ReadUInt8(&address_id)) {
      QUIC_DLOG(ERROR) << "Could not read address_id";
      return false;
    }
    // The address family tag determines how many IP address bytes follow.
    size_t ip_bytes_length;
    if (address_id == MasqueAddressFamilyIPv6) {
      ip_bytes_length = QuicIpAddress::kIPv6AddressSize;
    } else if (address_id == MasqueAddressFamilyIPv4) {
      ip_bytes_length = QuicIpAddress::kIPv4AddressSize;
    } else {
      QUIC_DLOG(ERROR) << "Unknown address_id " << static_cast<int>(address_id);
      return false;
    }
    char ip_bytes[QuicIpAddress::kMaxAddressSize];
    if (!reader.ReadBytes(ip_bytes, ip_bytes_length)) {
      QUIC_DLOG(ERROR) << "Could not read IP address";
      return false;
    }
    QuicIpAddress ip_address;
    ip_address.FromPackedString(ip_bytes, ip_bytes_length);
    if (!ip_address.IsInitialized()) {
      QUIC_BUG << "Failed to parse IP address";
      return false;
    }
    QuicSocketAddress new_server_socket_address =
        QuicSocketAddress(ip_address, port);
    auto context_pair = contexts_.find(new_flow_id);
    if (context_pair == contexts_.end()) {
      // Unknown flow ID: accept the registration as a validated context.
      context.client_connection_id = new_client_connection_id;
      context.server_connection_id = new_server_connection_id;
      context.server_socket_address = new_server_socket_address;
      context.validated = true;
      contexts_[new_flow_id] = context;
      QUIC_DVLOG(1) << "Registered new flow_id " << new_flow_id << " to "
                    << new_server_socket_address << " client "
                    << new_client_connection_id << " server "
                    << new_server_connection_id;
    } else {
      // Known flow ID: the registration must agree with what is stored.
      context = context_pair->second;
      if (context.client_connection_id != new_client_connection_id) {
        QUIC_LOG(ERROR)
            << "Received incorrect context registration for existing flow_id "
            << new_flow_id << " mismatched client "
            << context.client_connection_id << " " << new_client_connection_id;
        return false;
      }
      if (context.server_connection_id != new_server_connection_id) {
        QUIC_LOG(ERROR)
            << "Received incorrect context registration for existing flow_id "
            << new_flow_id << " mismatched server "
            << context.server_connection_id << " " << new_server_connection_id;
        return false;
      }
      if (context.server_socket_address != new_server_socket_address) {
        QUIC_LOG(ERROR)
            << "Received incorrect context registration for existing flow_id "
            << new_flow_id << " mismatched server "
            << context.server_socket_address << " "
            << new_server_socket_address;
        return false;
      }
      if (!context.validated) {
        context.validated = true;
        contexts_[new_flow_id] = context;
        QUIC_DLOG(INFO)
            << "Successfully validated remotely-unvalidated flow_id "
            << new_flow_id << " to " << new_server_socket_address << " client "
            << new_client_connection_id << " server "
            << new_server_connection_id;
      } else {
        QUIC_DVLOG(1) << "Decompressing using incoming locally-validated "
                         "remotely-unvalidated flow_id "
                      << new_flow_id << " to " << new_server_socket_address
                      << " client " << new_client_connection_id << " server "
                      << new_server_connection_id;
      }
    }
  } else {
    // Regular datagram: the flow ID must refer to a known context.
    auto context_pair = contexts_.find(flow_id);
    if (context_pair == contexts_.end()) {
      QUIC_DLOG(ERROR) << "Received unknown flow_id " << flow_id;
      return false;
    }
    context = context_pair->second;
    if (!context.validated) {
      // The peer using this flow ID without a registration marks the
      // context as validated.
      context.validated = true;
      contexts_[flow_id] = context;
      QUIC_DLOG(INFO) << "Successfully validated remotely-validated flow_id "
                      << flow_id << " to " << context.server_socket_address
                      << " client " << context.client_connection_id
                      << " server " << context.server_connection_id;
    } else {
      QUIC_DVLOG(1) << "Decompressing using incoming locally-validated "
                       "remotely-validated flow_id "
                    << flow_id << " to " << context.server_socket_address
                    << " client " << context.client_connection_id << " server "
                    << context.server_connection_id;
    }
  }

  // Incoming packets are addressed to the local endpoint, so the roles are
  // mirrored relative to CompressAndSendPacket().
  QuicConnectionId destination_connection_id, source_connection_id;
  if (masque_session_->perspective() == Perspective::IS_SERVER) {
    destination_connection_id = context.server_connection_id;
    source_connection_id = context.client_connection_id;
  } else {
    destination_connection_id = context.client_connection_id;
    source_connection_id = context.server_connection_id;
  }

  // The rebuilt packet is what remains of the message plus the connection
  // IDs (and, for long headers, their two length bytes) being restored.
  size_t packet_length =
      reader.BytesRemaining() + destination_connection_id.length();
  uint8_t first_byte;
  if (!reader.ReadUInt8(&first_byte)) {
    QUIC_DLOG(ERROR) << "Failed to read first_byte";
    return false;
  }
  const bool long_header = (first_byte & FLAGS_LONG_HEADER) != 0;
  if (long_header) {
    packet_length += sizeof(uint8_t) * 2 + source_connection_id.length();
  }
  std::string new_packet(packet_length, 0);
  QuicDataWriter writer(new_packet.length(), new_packet.data());
  if (!writer.WriteUInt8(first_byte)) {
    QUIC_BUG << "Failed to write first_byte";
    return false;
  }
  if (long_header) {
    QuicVersionLabel version_label;
    if (!reader.ReadUInt32(&version_label)) {
      QUIC_DLOG(ERROR) << "Failed to read version";
      return false;
    }
    if (!writer.WriteUInt32(version_label)) {
      QUIC_BUG << "Failed to write version";
      return false;
    }
    if (!writer.WriteLengthPrefixedConnectionId(destination_connection_id)) {
      QUIC_BUG << "Failed to write long header destination_connection_id";
      return false;
    }
    if (!writer.WriteLengthPrefixedConnectionId(source_connection_id)) {
      QUIC_BUG << "Failed to write long header source_connection_id";
      return false;
    }
  } else {
    if (!writer.WriteConnectionId(destination_connection_id)) {
      QUIC_BUG << "Failed to write short header destination_connection_id";
      return false;
    }
  }
  QuicStringPiece payload = reader.ReadRemainingPayload();
  if (!writer.WriteStringPiece(payload)) {
    QUIC_BUG << "Failed to write payload";
    return false;
  }

  // Only publish results once the whole packet has been rebuilt.
  *server_socket_address = context.server_socket_address;
  *client_connection_id = context.client_connection_id;
  *server_connection_id = context.server_connection_id;
  *packet = new_packet;
  *version_present = long_header;
  QUIC_DVLOG(2) << "Decompressed client " << context.client_connection_id
                << " server " << context.server_connection_id << "\n"
                << QuicTextUtils::HexDump(new_packet);
  return true;
}
// Returns the next flow ID available to this endpoint and reserves it by
// advancing the internal counter. The counter moves in steps of two, so the
// parity chosen in the constructor is preserved for every allocated ID.
QuicDatagramFlowId MasqueCompressionEngine::GetNextFlowId() {
  QuicDatagramFlowId allocated_flow_id = next_flow_id_;
  next_flow_id_ = allocated_flow_id + 2;
  return allocated_flow_id;
}
// Removes every compression context whose client connection ID equals
// |client_connection_id|. Does nothing if no context matches.
void MasqueCompressionEngine::UnregisterClientConnectionId(
    QuicConnectionId client_connection_id) {
  // Collect the matching flow IDs first and erase afterwards, so the map is
  // never modified while it is being iterated. Entries are visited by
  // reference to avoid copying each context just to compare one field.
  std::vector<QuicDatagramFlowId> flow_ids_to_remove;
  for (const auto& kv : contexts_) {
    const MasqueCompressionContext& context = kv.second;
    if (context.client_connection_id == client_connection_id) {
      flow_ids_to_remove.push_back(kv.first);
    }
  }
  for (QuicDatagramFlowId flow_id : flow_ids_to_remove) {
    contexts_.erase(flow_id);
  }
}
} // namespace quic