Replace QUIC CID when the first INITIAL packet of a connection is enqueued into the QuicBufferedPacketStore. This address a review comment of cl/641294174. Protected by FLAGS_quic_restart_flag_quic_dispatcher_replace_cid_on_first_packet. PiperOrigin-RevId: 658154062
diff --git a/quiche/common/quiche_feature_flags_list.h b/quiche/common/quiche_feature_flags_list.h index bb0a6da..0e4689b 100755 --- a/quiche/common/quiche_feature_flags_list.h +++ b/quiche/common/quiche_feature_flags_list.h
@@ -51,6 +51,7 @@ QUICHE_FLAG(bool, quiche_reloadable_flag_quic_testonly_default_false, false, false, "A testonly reloadable flag that will always default to false.") QUICHE_FLAG(bool, quiche_reloadable_flag_quic_testonly_default_true, true, true, "A testonly reloadable flag that will always default to true.") QUICHE_FLAG(bool, quiche_reloadable_flag_quic_use_received_client_addresses_cache, true, true, "If true, use a LRU cache to record client addresses of packets received on server's original address.") +QUICHE_FLAG(bool, quiche_restart_flag_quic_dispatcher_replace_cid_on_first_packet, false, false, "If true, QuicDispatcher will always generate the replaced connection ID when the first INITIAL packet is received, even if it's buffered in the packet store.") QUICHE_FLAG(bool, quiche_restart_flag_quic_support_ect1, false, false, "When true, allows sending of QUIC packets marked ECT(1). A different flag (TBD) will actually utilize this capability to send ECT(1).") QUICHE_FLAG(bool, quiche_restart_flag_quic_support_release_time_for_gso, false, false, "If true, QuicGsoBatchWriter will support release time if it is available and the process has the permission to do so.") QUICHE_FLAG(bool, quiche_restart_flag_quic_testonly_default_false, false, false, "A testonly restart flag that will always default to false.")
diff --git a/quiche/quic/core/quic_buffered_packet_store.cc b/quiche/quic/core/quic_buffered_packet_store.cc index e2fb225..9d525b2 100644 --- a/quiche/quic/core/quic_buffered_packet_store.cc +++ b/quiche/quic/core/quic_buffered_packet_store.cc
@@ -26,9 +26,11 @@ #include "quiche/quic/core/quic_types.h" #include "quiche/quic/core/quic_versions.h" #include "quiche/quic/platform/api/quic_bug_tracker.h" +#include "quiche/quic/platform/api/quic_flag_utils.h" #include "quiche/quic/platform/api/quic_flags.h" #include "quiche/quic/platform/api/quic_socket_address.h" #include "quiche/common/platform/api/quiche_logging.h" +#include "quiche/common/print_elements.h" namespace quic { @@ -63,10 +65,12 @@ BufferedPacket::BufferedPacket(std::unique_ptr<QuicReceivedPacket> packet, QuicSocketAddress self_address, - QuicSocketAddress peer_address) + QuicSocketAddress peer_address, + bool is_ietf_initial_packet) : packet(std::move(packet)), self_address(self_address), - peer_address(peer_address) {} + peer_address(peer_address), + is_ietf_initial_packet(is_ietf_initial_packet) {} BufferedPacket::BufferedPacket(BufferedPacket&& other) = default; @@ -105,7 +109,7 @@ EnqueuePacketResult QuicBufferedPacketStore::EnqueuePacket( const ReceivedPacketInfo& packet_info, std::optional<ParsedClientHello> parsed_chlo, - ConnectionIdGeneratorInterface* connection_id_generator) { + ConnectionIdGeneratorInterface& connection_id_generator) { QuicConnectionId connection_id = packet_info.destination_connection_id; const QuicReceivedPacket& packet = packet_info.packet; const QuicSocketAddress& self_address = packet_info.self_address; @@ -113,6 +117,9 @@ const ParsedQuicVersion& version = packet_info.version; const bool ietf_quic = packet_info.form != GOOGLE_QUIC_PACKET; const bool is_chlo = parsed_chlo.has_value(); + const bool is_ietf_initial_packet = + (version.IsKnown() && packet_info.form == IETF_QUIC_LONG_HEADER_PACKET && + packet_info.long_packet_type == INITIAL); QUIC_BUG_IF(quic_bug_12410_1, !GetQuicFlag(quic_allow_chlo_buffering)) << "Shouldn't buffer packets if disabled via flag."; QUIC_BUG_IF(quic_bug_12410_2, @@ -121,52 +128,104 @@ QUIC_BUG_IF(quic_bug_12410_4, is_chlo && !version.IsKnown()) << "Should have version for CHLO packet."; - const bool is_first_packet = !undecryptable_packets_.contains(connection_id); - if (is_first_packet) { - if (ShouldNotBufferPacket(is_chlo)) { - // Drop the packet if the upper limit of undecryptable packets has been - // reached or the whole capacity of the store has been reached. - return TOO_MANY_CONNECTIONS; - } - undecryptable_packets_.emplace( - std::make_pair(connection_id, BufferedPacketList())); - undecryptable_packets_.back().second.ietf_quic = ietf_quic; - undecryptable_packets_.back().second.version = version; - } - QUICHE_CHECK(undecryptable_packets_.contains(connection_id)); - BufferedPacketList& queue = - undecryptable_packets_.find(connection_id)->second; + bool is_first_packet; + BufferedPacketListNode* node = nullptr; - if (!is_chlo) { - // If current packet is not CHLO, it might not be buffered because store - // only buffers certain number of undecryptable packets per connection. - size_t num_non_chlo_packets = connections_with_chlo_.contains(connection_id) - ? (queue.buffered_packets.size() - 1) - : queue.buffered_packets.size(); - if (num_non_chlo_packets >= kDefaultMaxUndecryptablePackets) { + if (replace_cid_on_first_packet_) { + QUIC_RESTART_FLAG_COUNT_N(quic_dispatcher_replace_cid_on_first_packet, 1, + 13); + auto iter = buffered_session_map_.find(connection_id); + is_first_packet = (iter == buffered_session_map_.end()); + if (is_first_packet) { + if (ShouldNotBufferPacket(is_chlo)) { + // Drop the packet if the upper limit of undecryptable packets has been + // reached or the whole capacity of the store has been reached. + return TOO_MANY_CONNECTIONS; + } + iter = buffered_session_map_.emplace_hint( + iter, connection_id, std::make_shared<BufferedPacketListNode>()); + iter->second->ietf_quic = ietf_quic; + iter->second->version = version; + iter->second->original_connection_id = connection_id; + iter->second->creation_time = clock_->ApproximateNow(); + buffered_sessions_.push_back(iter->second.get()); + ++num_buffered_sessions_; + } + node = iter->second.get(); + QUICHE_DCHECK(buffered_session_map_.contains(connection_id)); + } else { + is_first_packet = !undecryptable_packets_.contains(connection_id); + if (is_first_packet) { + if (ShouldNotBufferPacket(is_chlo)) { + // Drop the packet if the upper limit of undecryptable packets has been + // reached or the whole capacity of the store has been reached. + return TOO_MANY_CONNECTIONS; + } + undecryptable_packets_.emplace( + std::make_pair(connection_id, BufferedPacketList())); + undecryptable_packets_.back().second.ietf_quic = ietf_quic; + undecryptable_packets_.back().second.version = version; + } + QUICHE_DCHECK(undecryptable_packets_.contains(connection_id)); + } + + BufferedPacketList& queue = + replace_cid_on_first_packet_ + ? *node + : undecryptable_packets_.find(connection_id)->second; + + if (replace_cid_on_first_packet_) { + // TODO(wub): Rename kDefaultMaxUndecryptablePackets when deprecating + // --quic_dispatcher_replace_cid_on_first_packet. + if (!is_chlo && + queue.buffered_packets.size() >= kDefaultMaxUndecryptablePackets) { // If there are kMaxBufferedPacketsPerConnection packets buffered up for // this connection, drop the current packet. return TOO_MANY_PACKETS; } - } + } else { + if (!is_chlo) { + // If current packet is not CHLO, it might not be buffered because store + // only buffers certain number of undecryptable packets per connection. + size_t num_non_chlo_packets = + connections_with_chlo_.contains(connection_id) + ? (queue.buffered_packets.size() - 1) + : queue.buffered_packets.size(); + if (num_non_chlo_packets >= kDefaultMaxUndecryptablePackets) { + // If there are kMaxBufferedPacketsPerConnection packets buffered up for + // this connection, drop the current packet. + return TOO_MANY_PACKETS; + } + } - if (queue.buffered_packets.empty()) { - // If this is the first packet arrived on a new connection, initialize the - // creation time. - queue.creation_time = clock_->ApproximateNow(); + if (queue.buffered_packets.empty()) { + // If this is the first packet arrived on a new connection, initialize the + // creation time. + queue.creation_time = clock_->ApproximateNow(); + } } BufferedPacket new_entry(std::unique_ptr<QuicReceivedPacket>(packet.Clone()), - self_address, peer_address); + self_address, peer_address, is_ietf_initial_packet); if (is_chlo) { // Add CHLO to the beginning of buffered packets so that it can be delivered // first later. queue.buffered_packets.push_front(std::move(new_entry)); queue.parsed_chlo = std::move(parsed_chlo); - connections_with_chlo_[connection_id] = false; // Dummy value. // Set the version of buffered packets of this connection on CHLO. queue.version = version; - queue.connection_id_generator = connection_id_generator; + if (replace_cid_on_first_packet_) { + if (!buffered_sessions_with_chlo_.is_linked(node)) { + buffered_sessions_with_chlo_.push_back(node); + ++num_buffered_sessions_with_chlo_; + } else { + QUIC_BUG(quic_store_session_already_has_chlo) + << "Buffered session already has CHLO"; + } + } else { + queue.connection_id_generator = &connection_id_generator; + connections_with_chlo_[connection_id] = false; // Dummy value. + } } else { // Buffer non-CHLO packets in arrival order. queue.buffered_packets.push_back(std::move(new_entry)); @@ -184,86 +243,232 @@ } MaybeSetExpirationAlarm(); + + if (replace_cid_on_first_packet_ && is_ietf_initial_packet && + version.UsesTls() && !queue.HasAttemptedToReplaceConnectionId()) { + QUIC_RESTART_FLAG_COUNT_N(quic_dispatcher_replace_cid_on_first_packet, 2, + 13); + queue.SetAttemptedToReplaceConnectionId(&connection_id_generator); + std::optional<QuicConnectionId> replaced_connection_id = + connection_id_generator.MaybeReplaceConnectionId(connection_id, + packet_info.version); + // Normalize the output of MaybeReplaceConnectionId. + if (replaced_connection_id.has_value() && + (replaced_connection_id->IsEmpty() || + *replaced_connection_id == connection_id)) { + QUIC_CODE_COUNT(quic_store_replaced_cid_is_empty_or_same_as_original); + replaced_connection_id.reset(); + } + QUIC_DVLOG(1) << "MaybeReplaceConnectionId(" << connection_id << ") = " + << (replaced_connection_id.has_value() + ? replaced_connection_id->ToString() + : "nullopt"); + if (replaced_connection_id.has_value()) { + switch (visitor_->HandleConnectionIdCollision( + connection_id, *replaced_connection_id, self_address, peer_address, + version, + queue.parsed_chlo.has_value() ? &queue.parsed_chlo.value() + : nullptr)) { + case VisitorInterface::HandleCidCollisionResult::kOk: + queue.replaced_connection_id = *replaced_connection_id; + buffered_session_map_.insert( + {*replaced_connection_id, node->shared_from_this()}); + break; + case VisitorInterface::HandleCidCollisionResult::kCollision: + return CID_COLLISION; + } + } + } + return SUCCESS; } bool QuicBufferedPacketStore::HasBufferedPackets( QuicConnectionId connection_id) const { + if (replace_cid_on_first_packet_) { + return buffered_session_map_.contains(connection_id); + } return undecryptable_packets_.contains(connection_id); } bool QuicBufferedPacketStore::HasChlosBuffered() const { + if (replace_cid_on_first_packet_) { + return num_buffered_sessions_with_chlo_ != 0; + } return !connections_with_chlo_.empty(); } BufferedPacketList QuicBufferedPacketStore::DeliverPackets( QuicConnectionId connection_id) { - BufferedPacketList packets_to_deliver; - auto it = undecryptable_packets_.find(connection_id); - if (it != undecryptable_packets_.end()) { - packets_to_deliver = std::move(it->second); - undecryptable_packets_.erase(connection_id); - std::list<BufferedPacket> initial_packets; - std::list<BufferedPacket> other_packets; - for (auto& packet : packets_to_deliver.buffered_packets) { - QuicLongHeaderType long_packet_type = INVALID_PACKET_TYPE; - PacketHeaderFormat unused_format; - bool unused_version_flag; - bool unused_use_length_prefix; - QuicVersionLabel unused_version_label; - ParsedQuicVersion unused_parsed_version = UnsupportedQuicVersion(); - QuicConnectionId unused_destination_connection_id; - QuicConnectionId unused_source_connection_id; - std::optional<absl::string_view> unused_retry_token; - std::string unused_detailed_error; + if (!replace_cid_on_first_packet_) { + BufferedPacketList packets_to_deliver; + auto it = undecryptable_packets_.find(connection_id); + if (it != undecryptable_packets_.end()) { + packets_to_deliver = std::move(it->second); + undecryptable_packets_.erase(connection_id); + std::list<BufferedPacket> initial_packets; + std::list<BufferedPacket> other_packets; + for (auto& packet : packets_to_deliver.buffered_packets) { + QuicLongHeaderType long_packet_type = INVALID_PACKET_TYPE; + PacketHeaderFormat unused_format; + bool unused_version_flag; + bool unused_use_length_prefix; + QuicVersionLabel unused_version_label; + ParsedQuicVersion unused_parsed_version = UnsupportedQuicVersion(); + QuicConnectionId unused_destination_connection_id; + QuicConnectionId unused_source_connection_id; + std::optional<absl::string_view> unused_retry_token; + std::string unused_detailed_error; - // We don't need to pass |generator| because we already got the correct - // connection ID length when we buffered the packet and indexed by - // connection ID. - QuicErrorCode error_code = QuicFramer::ParsePublicHeaderDispatcher( - *packet.packet, connection_id.length(), &unused_format, - &long_packet_type, &unused_version_flag, &unused_use_length_prefix, - &unused_version_label, &unused_parsed_version, - &unused_destination_connection_id, &unused_source_connection_id, - &unused_retry_token, &unused_detailed_error); + // We don't need to pass |generator| because we already got the correct + // connection ID length when we buffered the packet and indexed by + // connection ID. + QuicErrorCode error_code = QuicFramer::ParsePublicHeaderDispatcher( + *packet.packet, connection_id.length(), &unused_format, + &long_packet_type, &unused_version_flag, &unused_use_length_prefix, + &unused_version_label, &unused_parsed_version, + &unused_destination_connection_id, &unused_source_connection_id, + &unused_retry_token, &unused_detailed_error); - if (error_code == QUIC_NO_ERROR && long_packet_type == INITIAL) { - initial_packets.push_back(std::move(packet)); - } else { - other_packets.push_back(std::move(packet)); + if (error_code == QUIC_NO_ERROR && long_packet_type == INITIAL) { + initial_packets.push_back(std::move(packet)); + } else { + other_packets.push_back(std::move(packet)); + } } - } - initial_packets.splice(initial_packets.end(), other_packets); - packets_to_deliver.buffered_packets = std::move(initial_packets); + initial_packets.splice(initial_packets.end(), other_packets); + packets_to_deliver.buffered_packets = std::move(initial_packets); + } + return packets_to_deliver; } - return packets_to_deliver; + + auto it = buffered_session_map_.find(connection_id); + if (it == buffered_session_map_.end()) { + return BufferedPacketList(); + } + QUIC_RESTART_FLAG_COUNT_N(quic_dispatcher_replace_cid_on_first_packet, 3, 13); + std::shared_ptr<BufferedPacketListNode> node = it->second->shared_from_this(); + RemoveFromStore(*node); + std::list<BufferedPacket> initial_packets; + std::list<BufferedPacket> other_packets; + for (auto& packet : node->buffered_packets) { + if (packet.is_ietf_initial_packet) { + initial_packets.push_back(std::move(packet)); + } else { + other_packets.push_back(std::move(packet)); + } + } + initial_packets.splice(initial_packets.end(), other_packets); + node->buffered_packets = std::move(initial_packets); + BufferedPacketList& packet_list = *node; + return std::move(packet_list); } void QuicBufferedPacketStore::DiscardPackets(QuicConnectionId connection_id) { - undecryptable_packets_.erase(connection_id); - connections_with_chlo_.erase(connection_id); + if (!replace_cid_on_first_packet_) { + undecryptable_packets_.erase(connection_id); + connections_with_chlo_.erase(connection_id); + return; + } + + QUIC_RESTART_FLAG_COUNT_N(quic_dispatcher_replace_cid_on_first_packet, 4, 13); + auto it = buffered_session_map_.find(connection_id); + if (it == buffered_session_map_.end()) { + return; + } + + RemoveFromStore(*it->second); +} + +void QuicBufferedPacketStore::RemoveFromStore(BufferedPacketListNode& node) { + QUICHE_DCHECK(replace_cid_on_first_packet_); + QUICHE_DCHECK_EQ(buffered_sessions_with_chlo_.size(), + num_buffered_sessions_with_chlo_); + QUICHE_DCHECK_EQ(buffered_sessions_.size(), num_buffered_sessions_); + + // Remove |node| from all lists. + QUIC_BUG_IF(quic_store_chlo_state_inconsistent, + node.parsed_chlo.has_value() != + buffered_sessions_with_chlo_.is_linked(&node)) + << "Inconsistent CHLO state for connection " + << node.original_connection_id + << ", parsed_chlo.has_value:" << node.parsed_chlo.has_value() + << ", is_linked:" << buffered_sessions_with_chlo_.is_linked(&node); + if (buffered_sessions_with_chlo_.is_linked(&node)) { + buffered_sessions_with_chlo_.erase(&node); + --num_buffered_sessions_with_chlo_; + } + + if (buffered_sessions_.is_linked(&node)) { + buffered_sessions_.erase(&node); + --num_buffered_sessions_; + } else { + QUIC_BUG(quic_store_missing_node_in_main_list) + << "Missing node in main buffered session list for connection " + << node.original_connection_id; + } + + if (node.HasReplacedConnectionId()) { + bool erased = buffered_session_map_.erase(*node.replaced_connection_id) > 0; + QUIC_BUG_IF(quic_store_missing_replaced_cid_in_map, !erased) + << "Node has replaced CID but it's not in the map. original_cid: " + << node.original_connection_id + << " replaced_cid: " << *node.replaced_connection_id; + } + + bool erased = buffered_session_map_.erase(node.original_connection_id) > 0; + QUIC_BUG_IF(quic_store_missing_original_cid_in_map, !erased) + << "Node missing in the map. original_cid: " + << node.original_connection_id; } void QuicBufferedPacketStore::DiscardAllPackets() { - undecryptable_packets_.clear(); - connections_with_chlo_.clear(); + if (!replace_cid_on_first_packet_) { + undecryptable_packets_.clear(); + connections_with_chlo_.clear(); + } else { + QUIC_RESTART_FLAG_COUNT_N(quic_dispatcher_replace_cid_on_first_packet, 5, + 13); + buffered_sessions_with_chlo_.clear(); + num_buffered_sessions_with_chlo_ = 0; + buffered_sessions_.clear(); + num_buffered_sessions_ = 0; + buffered_session_map_.clear(); + } expiration_alarm_->Cancel(); } void QuicBufferedPacketStore::OnExpirationTimeout() { QuicTime expiration_time = clock_->ApproximateNow() - connection_life_span_; - while (!undecryptable_packets_.empty()) { - auto& entry = undecryptable_packets_.front(); - if (entry.second.creation_time > expiration_time) { + if (!replace_cid_on_first_packet_) { + while (!undecryptable_packets_.empty()) { + auto& entry = undecryptable_packets_.front(); + if (entry.second.creation_time > expiration_time) { + break; + } + QuicConnectionId connection_id = entry.first; + visitor_->OnExpiredPackets(connection_id, std::move(entry.second)); + undecryptable_packets_.pop_front(); + connections_with_chlo_.erase(connection_id); + } + if (!undecryptable_packets_.empty()) { + MaybeSetExpirationAlarm(); + } + return; + } + QUIC_RESTART_FLAG_COUNT_N(quic_dispatcher_replace_cid_on_first_packet, 6, 13); + while (!buffered_sessions_.empty()) { + BufferedPacketListNode& node = buffered_sessions_.front(); + if (node.creation_time > expiration_time) { break; } - QuicConnectionId connection_id = entry.first; - visitor_->OnExpiredPackets(connection_id, std::move(entry.second)); - undecryptable_packets_.pop_front(); - connections_with_chlo_.erase(connection_id); + std::shared_ptr<BufferedPacketListNode> node_ref = node.shared_from_this(); + QuicConnectionId connection_id = node.original_connection_id; + RemoveFromStore(node); + visitor_->OnExpiredPackets(connection_id, std::move(node)); } - if (!undecryptable_packets_.empty()) { + if (!buffered_sessions_.empty()) { MaybeSetExpirationAlarm(); } } @@ -274,16 +479,27 @@ } } -bool QuicBufferedPacketStore::ShouldNotBufferPacket(bool is_chlo) { - bool is_store_full = - undecryptable_packets_.size() >= kDefaultMaxConnectionsInStore; +bool QuicBufferedPacketStore::ShouldNotBufferPacket(bool is_chlo) const { + size_t num_connections = replace_cid_on_first_packet_ + ? num_buffered_sessions_ + : undecryptable_packets_.size(); + + bool is_store_full = num_connections >= kDefaultMaxConnectionsInStore; if (is_chlo) { return is_store_full; } + size_t num_connections_with_chlo = replace_cid_on_first_packet_ + ? num_buffered_sessions_with_chlo_ + : connections_with_chlo_.size(); + + QUIC_BUG_IF(quic_store_too_many_connections_with_chlo, + num_connections < num_connections_with_chlo) + << "num_connections: " << num_connections + << ", num_connections_with_chlo: " << num_connections_with_chlo; size_t num_connections_without_chlo = - undecryptable_packets_.size() - connections_with_chlo_.size(); + num_connections - num_connections_with_chlo; bool reach_non_chlo_limit = num_connections_without_chlo >= kMaxConnectionsWithoutCHLO; @@ -292,24 +508,48 @@ BufferedPacketList QuicBufferedPacketStore::DeliverPacketsForNextConnection( QuicConnectionId* connection_id) { - if (connections_with_chlo_.empty()) { + if (!replace_cid_on_first_packet_) { + if (connections_with_chlo_.empty()) { + // Returns empty list if no CHLO has been buffered. + return BufferedPacketList(); + } + *connection_id = connections_with_chlo_.front().first; + connections_with_chlo_.pop_front(); + + BufferedPacketList packets = DeliverPackets(*connection_id); + QUICHE_DCHECK(!packets.buffered_packets.empty() && + packets.parsed_chlo.has_value()) + << "Try to deliver connectons without CHLO. # packets:" + << packets.buffered_packets.size() + << ", has_parsed_chlo:" << packets.parsed_chlo.has_value(); + return packets; + } + + QUIC_RESTART_FLAG_COUNT_N(quic_dispatcher_replace_cid_on_first_packet, 7, 13); + if (buffered_sessions_with_chlo_.empty()) { // Returns empty list if no CHLO has been buffered. return BufferedPacketList(); } - *connection_id = connections_with_chlo_.front().first; - connections_with_chlo_.pop_front(); - BufferedPacketList packets = DeliverPackets(*connection_id); - QUICHE_DCHECK(!packets.buffered_packets.empty() && - packets.parsed_chlo.has_value()) + *connection_id = buffered_sessions_with_chlo_.front().original_connection_id; + BufferedPacketList packet_list = DeliverPackets(*connection_id); + QUICHE_DCHECK(!packet_list.buffered_packets.empty() && + packet_list.parsed_chlo.has_value()) << "Try to deliver connectons without CHLO. # packets:" - << packets.buffered_packets.size() - << ", has_parsed_chlo:" << packets.parsed_chlo.has_value(); - return packets; + << packet_list.buffered_packets.size() + << ", has_parsed_chlo:" << packet_list.parsed_chlo.has_value(); + return packet_list; } bool QuicBufferedPacketStore::HasChloForConnection( QuicConnectionId connection_id) { + if (replace_cid_on_first_packet_) { + auto it = buffered_session_map_.find(connection_id); + if (it == buffered_session_map_.end()) { + return false; + } + return it->second->parsed_chlo.has_value(); + } return connections_with_chlo_.contains(connection_id); } @@ -325,18 +565,42 @@ QUICHE_DCHECK_NE(out_sni, nullptr); QUICHE_DCHECK_NE(tls_alert, nullptr); QUICHE_DCHECK_EQ(version.handshake_protocol, PROTOCOL_TLS1_3); - auto it = undecryptable_packets_.find(connection_id); - if (it == undecryptable_packets_.end()) { + + if (!replace_cid_on_first_packet_) { + auto it = undecryptable_packets_.find(connection_id); + if (it == undecryptable_packets_.end()) { + QUIC_BUG(quic_bug_10838_1) + << "Cannot ingest packet for unknown connection ID " << connection_id; + return false; + } + it->second.tls_chlo_extractor.IngestPacket(version, packet); + if (!it->second.tls_chlo_extractor.HasParsedFullChlo()) { + *tls_alert = it->second.tls_chlo_extractor.tls_alert(); + return false; + } + const TlsChloExtractor& tls_chlo_extractor = it->second.tls_chlo_extractor; + *out_supported_groups = tls_chlo_extractor.supported_groups(); + *out_alpns = tls_chlo_extractor.alpns(); + *out_sni = tls_chlo_extractor.server_name(); + *out_resumption_attempted = tls_chlo_extractor.resumption_attempted(); + *out_early_data_attempted = tls_chlo_extractor.early_data_attempted(); + return true; + } + + QUIC_RESTART_FLAG_COUNT_N(quic_dispatcher_replace_cid_on_first_packet, 8, 13); + auto it = buffered_session_map_.find(connection_id); + if (it == buffered_session_map_.end()) { QUIC_BUG(quic_bug_10838_1) << "Cannot ingest packet for unknown connection ID " << connection_id; return false; } - it->second.tls_chlo_extractor.IngestPacket(version, packet); - if (!it->second.tls_chlo_extractor.HasParsedFullChlo()) { - *tls_alert = it->second.tls_chlo_extractor.tls_alert(); + BufferedPacketListNode& node = *it->second; + node.tls_chlo_extractor.IngestPacket(version, packet); + if (!node.tls_chlo_extractor.HasParsedFullChlo()) { + *tls_alert = node.tls_chlo_extractor.tls_alert(); return false; } - const TlsChloExtractor& tls_chlo_extractor = it->second.tls_chlo_extractor; + const TlsChloExtractor& tls_chlo_extractor = node.tls_chlo_extractor; *out_supported_groups = tls_chlo_extractor.supported_groups(); *out_cert_compression_algos = tls_chlo_extractor.cert_compression_algos(); *out_alpns = tls_chlo_extractor.alpns();
diff --git a/quiche/quic/core/quic_buffered_packet_store.h b/quiche/quic/core/quic_buffered_packet_store.h index db9b74f..2555944 100644 --- a/quiche/quic/core/quic_buffered_packet_store.h +++ b/quiche/quic/core/quic_buffered_packet_store.h
@@ -5,6 +5,7 @@ #ifndef QUICHE_QUIC_CORE_QUIC_BUFFERED_PACKET_STORE_H_ #define QUICHE_QUIC_CORE_QUIC_BUFFERED_PACKET_STORE_H_ +#include <cstddef> #include <cstdint> #include <list> #include <memory> @@ -12,6 +13,7 @@ #include <string> #include <vector> +#include "absl/container/flat_hash_map.h" #include "quiche/quic/core/connection_id_generator.h" #include "quiche/quic/core/quic_alarm.h" #include "quiche/quic/core/quic_alarm_factory.h" @@ -30,6 +32,7 @@ #include "quiche/quic/platform/api/quic_socket_address.h" #include "quiche/common/platform/api/quiche_export.h" #include "quiche/common/quiche_buffer_allocator.h" +#include "quiche/common/quiche_intrusive_list.h" #include "quiche/common/quiche_linked_hash_map.h" namespace quic { @@ -51,14 +54,18 @@ public: enum EnqueuePacketResult { SUCCESS = 0, - TOO_MANY_PACKETS, // Too many packets stored up for a certain connection. - TOO_MANY_CONNECTIONS // Too many connections stored up in the store. + // Too many packets stored up for a certain connection. + TOO_MANY_PACKETS, + // Too many connections stored up in the store. + TOO_MANY_CONNECTIONS, + // Replaced CID collide with a buffered or active session. + CID_COLLISION, }; struct QUICHE_EXPORT BufferedPacket { BufferedPacket(std::unique_ptr<QuicReceivedPacket> packet, QuicSocketAddress self_address, - QuicSocketAddress peer_address); + QuicSocketAddress peer_address, bool is_ietf_initial_packet); BufferedPacket(BufferedPacket&& other); BufferedPacket& operator=(BufferedPacket&& other); @@ -68,6 +75,7 @@ std::unique_ptr<QuicReceivedPacket> packet; QuicSocketAddress self_address; QuicSocketAddress peer_address; + bool is_ietf_initial_packet; }; // A queue of BufferedPackets for a connection. @@ -79,6 +87,20 @@ ~BufferedPacketList(); + bool HasReplacedConnectionId() const { + return replaced_connection_id.has_value() && + !replaced_connection_id->IsEmpty(); + } + + bool HasAttemptedToReplaceConnectionId() const { + return connection_id_generator != nullptr; + } + + void SetAttemptedToReplaceConnectionId( + ConnectionIdGeneratorInterface* generator) { + connection_id_generator = generator; + } + std::list<BufferedPacket> buffered_packets; QuicTime creation_time; // |parsed_chlo| is set iff the entire CHLO has been received. @@ -90,16 +112,36 @@ ParsedQuicVersion version; TlsChloExtractor tls_chlo_extractor; // Only one reference to the generator is stored per connection, and this is - // stored when the CHLO is buffered. The connection needs a stable, - // consistent way to generate IDs. Fixing it on the CHLO is a - // straightforward way to enforce that. + // stored when the replaced CID is generated. ConnectionIdGeneratorInterface* connection_id_generator = nullptr; + // The original connection ID of the connection. + // Only used when replace_cid_on_first_packet_ is true. + QuicConnectionId original_connection_id; + // Set to the result of ConnectionIdGenerator::MaybeReplaceConnectionId, + // when the first IETF INITIAL packet is enqueued. + // Note that std::nullopt indicates one the following cases, you can use + // HasAttemptedToReplaceConnectionId() to distinguish them: + // 1. No attempt to replace CID has been made. + // 2. One attempt to replace CID has been made, but the CID generator does + // not want to replace it. + std::optional<QuicConnectionId> replaced_connection_id; }; using BufferedPacketMap = quiche::QuicheLinkedHashMap<QuicConnectionId, BufferedPacketList, QuicConnectionIdHash>; + // Tag type for the list of sessions with full CHLO buffered. + struct QUICHE_EXPORT BufferedSessionsWithChloList {}; + + // The internal data structure for a buffered session. + struct QUICHE_EXPORT BufferedPacketListNode + : public quiche::QuicheIntrusiveLink<BufferedPacketListNode>, + public quiche::QuicheIntrusiveLink<BufferedPacketListNode, + BufferedSessionsWithChloList>, + public std::enable_shared_from_this<BufferedPacketListNode>, + public BufferedPacketList {}; + class QUICHE_EXPORT VisitorInterface { public: virtual ~VisitorInterface() {} @@ -107,6 +149,25 @@ // Called for each expired connection when alarm fires. virtual void OnExpiredPackets(QuicConnectionId connection_id, BufferedPacketList early_arrived_packets) = 0; + + enum class HandleCidCollisionResult { + kOk, + kCollision, + }; + // Check and handle CID collision for |replaced_connection_id|. + // This method is called immediately after |replaced_connection_id| is + // generated by the connection ID generator, at which time the mapping from + // |replaced_connection_id| to the connection is not yet established, which + // means if the implementation calls + // store.HasBufferedPackets(replaced_connection_id); + // and it returns true, then |replaced_connection_id| has already been + // mapped to another connection, i.e. a CID collision. + virtual HandleCidCollisionResult HandleConnectionIdCollision( + const QuicConnectionId& original_connection_id, + const QuicConnectionId& replaced_connection_id, + const QuicSocketAddress& self_address, + const QuicSocketAddress& peer_address, ParsedQuicVersion version, + const ParsedClientHello* parsed_chlo) = 0; }; QuicBufferedPacketStore(VisitorInterface* visitor, const QuicClock* clock, @@ -126,9 +187,10 @@ EnqueuePacketResult EnqueuePacket( const ReceivedPacketInfo& packet_info, std::optional<ParsedClientHello> parsed_chlo, - ConnectionIdGeneratorInterface* connection_id_generator); + ConnectionIdGeneratorInterface& connection_id_generator); // Returns true if there are any packets buffered for |connection_id|. + // |connection_id| can be either original or replaced connection ID. bool HasBufferedPackets(QuicConnectionId connection_id) const; // Ingests this packet into the corresponding TlsChloExtractor. This should @@ -154,9 +216,11 @@ // Returns the list of buffered packets for |connection_id| and removes them // from the store. Returns an empty list if no early arrived packets for this // connection are present. + // |connection_id| can be either original or replaced connection ID. BufferedPacketList DeliverPackets(QuicConnectionId connection_id); // Discards packets buffered for |connection_id|, if any. + // |connection_id| can be either original or replaced connection ID. void DiscardPackets(QuicConnectionId connection_id); // Discards all the packets. @@ -177,12 +241,17 @@ BufferedPacketList DeliverPacketsForNextConnection( QuicConnectionId* connection_id); - // Is given connection already buffered in the store? + // Is the given connection in the store and contains the full CHLO? + // |connection_id| can be either original or replaced connection ID. bool HasChloForConnection(QuicConnectionId connection_id); - // Is there any CHLO buffered in the store? + // Is there any connection in the store that contains a full CHLO? bool HasChlosBuffered() const; + bool replace_cid_on_first_packet() const { + return replace_cid_on_first_packet_; + } + private: friend class test::QuicBufferedPacketStorePeer; @@ -191,11 +260,46 @@ // Return true if add an extra packet will go beyond allowed max connection // limit. The limit for non-CHLO packet and CHLO packet is different. - bool ShouldNotBufferPacket(bool is_chlo); + bool ShouldNotBufferPacket(bool is_chlo) const; + + // Remove |node| from the buffered store. If caller wants to access |node| + // after this call, it should use a shared_ptr<BufferedPacketListNode> to keep + // |node| alive: + // + // BufferedPacketListNode& node = ...; + // auto node_ref = node.shared_from_this(); + // RemoveFromStore(node); + // |node| can still be used here. + // + void RemoveFromStore(BufferedPacketListNode& node); + + const bool replace_cid_on_first_packet_ = + GetQuicRestartFlag(quic_dispatcher_replace_cid_on_first_packet); // A map to store packet queues with creation time for each connection. + // Only used when !replace_cid_on_first_packet_. BufferedPacketMap undecryptable_packets_; + // Map from connection ID to the list of buffered packets for that connection. + // The key can be either the original or the replaced connection ID. + // The value is never nullptr. + // Only used when replace_cid_on_first_packet_ is true. + absl::flat_hash_map<QuicConnectionId, std::shared_ptr<BufferedPacketListNode>, + QuicConnectionIdHash> + buffered_session_map_; + + // Main list of all buffered sessions, in insertion order. + // Only used when replace_cid_on_first_packet_ is true. + quiche::QuicheIntrusiveList<BufferedPacketListNode> buffered_sessions_; + size_t num_buffered_sessions_ = 0; + + // Secondary list of all buffered sessions with full CHLO. + // Only used when replace_cid_on_first_packet_ is true. + quiche::QuicheIntrusiveList<BufferedPacketListNode, + BufferedSessionsWithChloList> + buffered_sessions_with_chlo_; + size_t num_buffered_sessions_with_chlo_ = 0; + // The max time the packets of a connection can be buffer in the store. const QuicTime::Delta connection_life_span_; @@ -209,6 +313,7 @@ // Keeps track of connection with CHLO buffered up already and the order they // arrive. + // Only used when !replace_cid_on_first_packet_. quiche::QuicheLinkedHashMap<QuicConnectionId, bool, QuicConnectionIdHash> connections_with_chlo_; };
diff --git a/quiche/quic/core/quic_buffered_packet_store_test.cc b/quiche/quic/core/quic_buffered_packet_store_test.cc index 85ed960..8ae493d 100644 --- a/quiche/quic/core/quic_buffered_packet_store_test.cc +++ b/quiche/quic/core/quic_buffered_packet_store_test.cc
@@ -63,7 +63,7 @@ const QuicReceivedPacket& packet, QuicSocketAddress self_address, QuicSocketAddress peer_address, const ParsedQuicVersion& version, std::optional<ParsedClientHello> parsed_chlo, - ConnectionIdGeneratorInterface* connection_id_generator) { + ConnectionIdGeneratorInterface& connection_id_generator) { ReceivedPacketInfo packet_info(self_address, peer_address, packet); packet_info.destination_connection_id = connection_id; packet_info.form = form; @@ -85,6 +85,15 @@ last_expired_packet_queue_ = std::move(early_arrived_packets); } + HandleCidCollisionResult HandleConnectionIdCollision( + const QuicConnectionId& /*original_connection_id*/, + const QuicConnectionId& /*replaced_connection_id*/, + const QuicSocketAddress& /*self_address*/, + const QuicSocketAddress& /*peer_address*/, ParsedQuicVersion /*version*/, + const ParsedClientHello* /*parsed_chlo*/) override { + return HandleCidCollisionResult::kOk; + } + // The packets queue for most recently expirect connection. BufferedPacketList last_expired_packet_queue_; }; @@ -120,7 +129,8 @@ QuicConnectionId connection_id = TestConnectionId(1); EnqueuePacketToStore(store_, connection_id, GOOGLE_QUIC_PACKET, INVALID_PACKET_TYPE, packet_, self_address_, - peer_address_, invalid_version_, kNoParsedChlo, nullptr); + peer_address_, invalid_version_, kNoParsedChlo, + connection_id_generator_); EXPECT_TRUE(store_.HasBufferedPackets(connection_id)); auto packets = store_.DeliverPackets(connection_id); const std::list<BufferedPacket>& queue = packets.buffered_packets; @@ -143,11 +153,12 @@ QuicConnectionId connection_id = TestConnectionId(1); EnqueuePacketToStore(store_, connection_id, GOOGLE_QUIC_PACKET, INVALID_PACKET_TYPE, packet_, self_address_, - peer_address_, invalid_version_, kNoParsedChlo, nullptr); + peer_address_, invalid_version_, kNoParsedChlo, + connection_id_generator_); EnqueuePacketToStore(store_, connection_id, GOOGLE_QUIC_PACKET, INVALID_PACKET_TYPE, packet_, self_address_, addr_with_new_port, invalid_version_, kNoParsedChlo, - nullptr); + connection_id_generator_); std::list<BufferedPacket> queue = store_.DeliverPackets(connection_id).buffered_packets; ASSERT_EQ(2u, queue.size()); @@ -161,12 +172,14 @@ size_t num_connections = 10; for (uint64_t conn_id = 1; conn_id <= num_connections; ++conn_id) { QuicConnectionId connection_id = TestConnectionId(conn_id); - EnqueuePacketToStore( - store_, connection_id, GOOGLE_QUIC_PACKET, INVALID_PACKET_TYPE, packet_, - self_address_, peer_address_, invalid_version_, kNoParsedChlo, nullptr); - EnqueuePacketToStore( - store_, connection_id, GOOGLE_QUIC_PACKET, INVALID_PACKET_TYPE, packet_, - self_address_, peer_address_, invalid_version_, kNoParsedChlo, nullptr); + EnqueuePacketToStore(store_, connection_id, GOOGLE_QUIC_PACKET, + INVALID_PACKET_TYPE, packet_, self_address_, + peer_address_, invalid_version_, kNoParsedChlo, + connection_id_generator_); + EnqueuePacketToStore(store_, connection_id, GOOGLE_QUIC_PACKET, + INVALID_PACKET_TYPE, packet_, self_address_, + peer_address_, invalid_version_, kNoParsedChlo, + connection_id_generator_); } // Deliver packets in reversed order. @@ -178,35 +191,37 @@ } } +// Tests that for one connection, only limited number of packets can be +// buffered. TEST_F(QuicBufferedPacketStoreTest, FailToBufferTooManyPacketsOnExistingConnection) { - // Tests that for one connection, only limited number of packets can be - // buffered. - size_t num_packets = kDefaultMaxUndecryptablePackets + 1; + // Max number of packets that can be buffered per connection. + const size_t kMaxPacketsPerConnection = + store_.replace_cid_on_first_packet() + ? kDefaultMaxUndecryptablePackets + : kDefaultMaxUndecryptablePackets + 1; QuicConnectionId connection_id = TestConnectionId(1); - // Arrived CHLO packet shouldn't affect how many non-CHLO pacekts store can - // keep. EXPECT_EQ(QuicBufferedPacketStore::SUCCESS, - EnqueuePacketToStore(store_, connection_id, GOOGLE_QUIC_PACKET, - INVALID_PACKET_TYPE, packet_, self_address_, - peer_address_, valid_version_, - kDefaultParsedChlo, nullptr)); - for (size_t i = 1; i <= num_packets; ++i) { - // Only first |kDefaultMaxUndecryptablePackets packets| will be buffered. + EnqueuePacketToStore(store_, connection_id, + IETF_QUIC_LONG_HEADER_PACKET, INITIAL, packet_, + self_address_, peer_address_, valid_version_, + kDefaultParsedChlo, connection_id_generator_)); + for (size_t i = 1; i <= kMaxPacketsPerConnection; ++i) { + // All packets will be buffered except the last one. EnqueuePacketResult result = EnqueuePacketToStore( store_, connection_id, GOOGLE_QUIC_PACKET, INVALID_PACKET_TYPE, packet_, - self_address_, peer_address_, invalid_version_, kNoParsedChlo, nullptr); - if (i <= kDefaultMaxUndecryptablePackets) { + self_address_, peer_address_, invalid_version_, kNoParsedChlo, + connection_id_generator_); + if (i != kMaxPacketsPerConnection) { EXPECT_EQ(EnqueuePacketResult::SUCCESS, result); } else { EXPECT_EQ(EnqueuePacketResult::TOO_MANY_PACKETS, result); } } - // Only first |kDefaultMaxUndecryptablePackets| non-CHLO packets and CHLO are - // buffered. - EXPECT_EQ(kDefaultMaxUndecryptablePackets + 1, - store_.DeliverPackets(connection_id).buffered_packets.size()); + // Verify |kMaxPacketsPerConnection| packets are buffered. + EXPECT_EQ(store_.DeliverPackets(connection_id).buffered_packets.size(), + kMaxPacketsPerConnection); } TEST_F(QuicBufferedPacketStoreTest, ReachNonChloConnectionUpperLimit) { @@ -217,7 +232,8 @@ QuicConnectionId connection_id = TestConnectionId(conn_id); EnqueuePacketResult result = EnqueuePacketToStore( store_, connection_id, GOOGLE_QUIC_PACKET, INVALID_PACKET_TYPE, packet_, - self_address_, peer_address_, invalid_version_, kNoParsedChlo, nullptr); + self_address_, peer_address_, invalid_version_, kNoParsedChlo, + connection_id_generator_); if (conn_id <= kMaxConnectionsWithoutCHLO) { EXPECT_EQ(EnqueuePacketResult::SUCCESS, result); } else { @@ -244,11 +260,12 @@ size_t num_chlos = kDefaultMaxConnectionsInStore - kMaxConnectionsWithoutCHLO + 1; for (uint64_t conn_id = 1; conn_id <= num_chlos; ++conn_id) { - EXPECT_EQ(EnqueuePacketResult::SUCCESS, - EnqueuePacketToStore( - store_, TestConnectionId(conn_id), GOOGLE_QUIC_PACKET, - INVALID_PACKET_TYPE, packet_, self_address_, peer_address_, - valid_version_, kDefaultParsedChlo, nullptr)); + EXPECT_EQ( + EnqueuePacketResult::SUCCESS, + EnqueuePacketToStore(store_, TestConnectionId(conn_id), + GOOGLE_QUIC_PACKET, INVALID_PACKET_TYPE, packet_, + self_address_, peer_address_, valid_version_, + kDefaultParsedChlo, connection_id_generator_)); } // Send data packets on another |kMaxConnectionsWithoutCHLO| connections. @@ -259,7 +276,7 @@ EnqueuePacketResult result = EnqueuePacketToStore( store_, connection_id, GOOGLE_QUIC_PACKET, INVALID_PACKET_TYPE, packet_, self_address_, peer_address_, valid_version_, kDefaultParsedChlo, - nullptr); + connection_id_generator_); if (conn_id <= kDefaultMaxConnectionsInStore) { EXPECT_EQ(EnqueuePacketResult::SUCCESS, result); } else { @@ -273,27 +290,17 @@ EnqueuePacketToStore( store_, TestConnectionId(1), GOOGLE_QUIC_PACKET, INVALID_PACKET_TYPE, packet_, self_address_, peer_address_, - valid_version_, kDefaultParsedChlo, &connection_id_generator_)); + valid_version_, kDefaultParsedChlo, connection_id_generator_)); QuicConnectionId delivered_conn_id; BufferedPacketList packet_list = store_.DeliverPacketsForNextConnection(&delivered_conn_id); EXPECT_EQ(1u, packet_list.buffered_packets.size()); EXPECT_EQ(delivered_conn_id, TestConnectionId(1)); - EXPECT_EQ(&connection_id_generator_, packet_list.connection_id_generator); -} - -TEST_F(QuicBufferedPacketStoreTest, NullGeneratorOk) { - EXPECT_EQ(EnqueuePacketResult::SUCCESS, - EnqueuePacketToStore(store_, TestConnectionId(1), - GOOGLE_QUIC_PACKET, INVALID_PACKET_TYPE, - packet_, self_address_, peer_address_, - valid_version_, kDefaultParsedChlo, nullptr)); - QuicConnectionId delivered_conn_id; - BufferedPacketList packet_list = - store_.DeliverPacketsForNextConnection(&delivered_conn_id); - EXPECT_EQ(1u, packet_list.buffered_packets.size()); - EXPECT_EQ(delivered_conn_id, TestConnectionId(1)); - EXPECT_EQ(packet_list.connection_id_generator, nullptr); + if (GetQuicRestartFlag(quic_dispatcher_replace_cid_on_first_packet)) { + EXPECT_EQ(packet_list.connection_id_generator, nullptr); + } else { + EXPECT_EQ(packet_list.connection_id_generator, &connection_id_generator_); + } } TEST_F(QuicBufferedPacketStoreTest, GeneratorIgnoredForNonChlo) { @@ -302,18 +309,22 @@ EnqueuePacketToStore( store_, TestConnectionId(1), GOOGLE_QUIC_PACKET, INVALID_PACKET_TYPE, packet_, self_address_, peer_address_, - valid_version_, kDefaultParsedChlo, &connection_id_generator_)); + valid_version_, kDefaultParsedChlo, connection_id_generator_)); EXPECT_EQ(EnqueuePacketResult::SUCCESS, EnqueuePacketToStore(store_, TestConnectionId(1), GOOGLE_QUIC_PACKET, INVALID_PACKET_TYPE, packet_, self_address_, peer_address_, - valid_version_, kNoParsedChlo, &generator2)); + valid_version_, kNoParsedChlo, generator2)); QuicConnectionId delivered_conn_id; BufferedPacketList packet_list = store_.DeliverPacketsForNextConnection(&delivered_conn_id); EXPECT_EQ(2u, packet_list.buffered_packets.size()); EXPECT_EQ(delivered_conn_id, TestConnectionId(1)); - EXPECT_EQ(packet_list.connection_id_generator, &connection_id_generator_); + if (GetQuicRestartFlag(quic_dispatcher_replace_cid_on_first_packet)) { + EXPECT_EQ(packet_list.connection_id_generator, nullptr); + } else { + EXPECT_EQ(packet_list.connection_id_generator, &connection_id_generator_); + } } TEST_F(QuicBufferedPacketStoreTest, EnqueueChloOnTooManyDifferentConnections) { @@ -326,7 +337,7 @@ EnqueuePacketToStore(store_, connection_id, GOOGLE_QUIC_PACKET, INVALID_PACKET_TYPE, packet_, self_address_, peer_address_, invalid_version_, - kNoParsedChlo, &connection_id_generator_)); + kNoParsedChlo, connection_id_generator_)); } // Buffer CHLOs on other connections till store is full. @@ -336,7 +347,7 @@ EnqueuePacketResult rs = EnqueuePacketToStore( store_, connection_id, GOOGLE_QUIC_PACKET, INVALID_PACKET_TYPE, packet_, self_address_, peer_address_, valid_version_, kDefaultParsedChlo, - &connection_id_generator_); + connection_id_generator_); if (i <= kDefaultMaxConnectionsInStore) { EXPECT_EQ(EnqueuePacketResult::SUCCESS, rs); EXPECT_TRUE(store_.HasChloForConnection(connection_id)); @@ -354,7 +365,7 @@ EnqueuePacketToStore( store_, TestConnectionId(1), GOOGLE_QUIC_PACKET, INVALID_PACKET_TYPE, packet_, self_address_, peer_address_, - valid_version_, kDefaultParsedChlo, &connection_id_generator_)); + valid_version_, kDefaultParsedChlo, connection_id_generator_)); EXPECT_TRUE(store_.HasChloForConnection(TestConnectionId(1))); QuicConnectionId delivered_conn_id; @@ -372,7 +383,11 @@ EXPECT_EQ(2u, packet_list.buffered_packets.size()); EXPECT_EQ(TestConnectionId(1u), delivered_conn_id); } - EXPECT_EQ(packet_list.connection_id_generator, &connection_id_generator_); + if (GetQuicRestartFlag(quic_dispatcher_replace_cid_on_first_packet)) { + EXPECT_EQ(packet_list.connection_id_generator, nullptr); + } else { + EXPECT_EQ(packet_list.connection_id_generator, &connection_id_generator_); + } } EXPECT_FALSE(store_.HasChlosBuffered()); } @@ -384,18 +399,18 @@ EnqueuePacketToStore(store_, connection_id, GOOGLE_QUIC_PACKET, INVALID_PACKET_TYPE, packet_, self_address_, peer_address_, invalid_version_, kNoParsedChlo, - &connection_id_generator_); + connection_id_generator_); EXPECT_EQ(EnqueuePacketResult::SUCCESS, - EnqueuePacketToStore( - store_, connection_id, GOOGLE_QUIC_PACKET, INVALID_PACKET_TYPE, - packet_, self_address_, peer_address_, valid_version_, - kDefaultParsedChlo, &connection_id_generator_)); + EnqueuePacketToStore(store_, connection_id, GOOGLE_QUIC_PACKET, + INVALID_PACKET_TYPE, packet_, self_address_, + peer_address_, valid_version_, + kDefaultParsedChlo, connection_id_generator_)); QuicConnectionId connection_id2 = TestConnectionId(2); EXPECT_EQ(EnqueuePacketResult::SUCCESS, EnqueuePacketToStore(store_, connection_id2, GOOGLE_QUIC_PACKET, INVALID_PACKET_TYPE, packet_, self_address_, peer_address_, invalid_version_, kNoParsedChlo, - &connection_id_generator_)); + connection_id_generator_)); // CHLO on connection 3 arrives 1ms later. clock_.AdvanceTime(QuicTime::Delta::FromMilliseconds(1)); @@ -406,7 +421,7 @@ EnqueuePacketToStore(store_, connection_id3, GOOGLE_QUIC_PACKET, INVALID_PACKET_TYPE, packet_, self_address_, another_client_address, valid_version_, - kDefaultParsedChlo, &connection_id_generator_); + kDefaultParsedChlo, connection_id_generator_); // Advance clock to the time when connection 1 and 2 expires. clock_.AdvanceTime( @@ -431,7 +446,11 @@ // Connection 3 is the next to be delivered as connection 1 already expired. EXPECT_EQ(connection_id3, delivered_conn_id); - EXPECT_EQ(packet_list.connection_id_generator, &connection_id_generator_); + if (GetQuicRestartFlag(quic_dispatcher_replace_cid_on_first_packet)) { + EXPECT_EQ(packet_list.connection_id_generator, nullptr); + } else { + EXPECT_EQ(packet_list.connection_id_generator, &connection_id_generator_); + } ASSERT_EQ(1u, packet_list.buffered_packets.size()); // Packets in connection 3 should use another peer address. EXPECT_EQ(another_client_address, @@ -442,10 +461,12 @@ QuicConnectionId connection_id4 = TestConnectionId(4); EnqueuePacketToStore(store_, connection_id4, GOOGLE_QUIC_PACKET, INVALID_PACKET_TYPE, packet_, self_address_, - peer_address_, invalid_version_, kNoParsedChlo, nullptr); + peer_address_, invalid_version_, kNoParsedChlo, + connection_id_generator_); EnqueuePacketToStore(store_, connection_id4, GOOGLE_QUIC_PACKET, INVALID_PACKET_TYPE, packet_, self_address_, - peer_address_, invalid_version_, kNoParsedChlo, nullptr); + peer_address_, invalid_version_, kNoParsedChlo, + connection_id_generator_); clock_.AdvanceTime( QuicBufferedPacketStorePeer::expiration_alarm(&store_)->deadline() - clock_.ApproximateNow()); @@ -461,10 +482,12 @@ // Enqueue some packets EnqueuePacketToStore(store_, connection_id, GOOGLE_QUIC_PACKET, INVALID_PACKET_TYPE, packet_, self_address_, - peer_address_, invalid_version_, kNoParsedChlo, nullptr); + peer_address_, invalid_version_, kNoParsedChlo, + connection_id_generator_); EnqueuePacketToStore(store_, connection_id, GOOGLE_QUIC_PACKET, INVALID_PACKET_TYPE, packet_, self_address_, - peer_address_, invalid_version_, kNoParsedChlo, nullptr); + peer_address_, invalid_version_, kNoParsedChlo, + connection_id_generator_); EXPECT_TRUE(store_.HasBufferedPackets(connection_id)); EXPECT_FALSE(store_.HasChlosBuffered()); @@ -489,14 +512,16 @@ // Enqueue some packets, which include a CHLO EnqueuePacketToStore(store_, connection_id, GOOGLE_QUIC_PACKET, INVALID_PACKET_TYPE, packet_, self_address_, - peer_address_, invalid_version_, kNoParsedChlo, nullptr); + peer_address_, invalid_version_, kNoParsedChlo, + connection_id_generator_); EnqueuePacketToStore(store_, connection_id, GOOGLE_QUIC_PACKET, INVALID_PACKET_TYPE, packet_, self_address_, peer_address_, valid_version_, kDefaultParsedChlo, - nullptr); + connection_id_generator_); EnqueuePacketToStore(store_, connection_id, GOOGLE_QUIC_PACKET, INVALID_PACKET_TYPE, packet_, self_address_, - peer_address_, invalid_version_, kNoParsedChlo, nullptr); + peer_address_, invalid_version_, kNoParsedChlo, + connection_id_generator_); EXPECT_TRUE(store_.HasBufferedPackets(connection_id)); EXPECT_TRUE(store_.HasChlosBuffered()); @@ -522,17 +547,19 @@ // Enqueue some packets for two connection IDs EnqueuePacketToStore(store_, connection_id_1, GOOGLE_QUIC_PACKET, INVALID_PACKET_TYPE, packet_, self_address_, - peer_address_, invalid_version_, kNoParsedChlo, nullptr); + peer_address_, invalid_version_, kNoParsedChlo, + connection_id_generator_); EnqueuePacketToStore(store_, connection_id_1, GOOGLE_QUIC_PACKET, INVALID_PACKET_TYPE, packet_, self_address_, - peer_address_, invalid_version_, kNoParsedChlo, nullptr); + peer_address_, invalid_version_, kNoParsedChlo, + connection_id_generator_); ParsedClientHello parsed_chlo; parsed_chlo.alpns.push_back("h3"); parsed_chlo.sni = TestHostname(); - EnqueuePacketToStore(store_, connection_id_2, GOOGLE_QUIC_PACKET, - INVALID_PACKET_TYPE, packet_, self_address_, - peer_address_, valid_version_, parsed_chlo, nullptr); + EnqueuePacketToStore(store_, connection_id_2, IETF_QUIC_LONG_HEADER_PACKET, + INITIAL, packet_, self_address_, peer_address_, + valid_version_, parsed_chlo, connection_id_generator_); EXPECT_TRUE(store_.HasBufferedPackets(connection_id_1)); EXPECT_TRUE(store_.HasBufferedPackets(connection_id_2)); EXPECT_TRUE(store_.HasChlosBuffered()); @@ -554,8 +581,12 @@ EXPECT_EQ(TestHostname(), packets.parsed_chlo->sni); // Since connection_id_2's chlo arrives, verify version is set. EXPECT_EQ(valid_version_, packets.version); - EXPECT_TRUE(store_.HasChlosBuffered()); + if (store_.replace_cid_on_first_packet()) { + EXPECT_FALSE(store_.HasChlosBuffered()); + } else { + EXPECT_TRUE(store_.HasChlosBuffered()); + } // Discard the packets for connection 2 store_.DiscardPackets(connection_id_2); EXPECT_FALSE(store_.HasChlosBuffered()); @@ -586,7 +617,8 @@ EXPECT_FALSE(store_.HasBufferedPackets(connection_id)); EnqueuePacketToStore(store_, connection_id, GOOGLE_QUIC_PACKET, INVALID_PACKET_TYPE, packet_, self_address_, - peer_address_, valid_version_, kNoParsedChlo, nullptr); + peer_address_, valid_version_, kNoParsedChlo, + connection_id_generator_); EXPECT_TRUE(store_.HasBufferedPackets(connection_id)); // The packet in 'packet_' is not a TLS CHLO packet. @@ -608,10 +640,12 @@ EnqueuePacketToStore(store_, connection_id, GOOGLE_QUIC_PACKET, INVALID_PACKET_TYPE, *packets[0], self_address_, - peer_address_, valid_version_, kNoParsedChlo, nullptr); + peer_address_, valid_version_, kNoParsedChlo, + connection_id_generator_); EnqueuePacketToStore(store_, connection_id, GOOGLE_QUIC_PACKET, INVALID_PACKET_TYPE, *packets[1], self_address_, - peer_address_, valid_version_, kNoParsedChlo, nullptr); + peer_address_, valid_version_, kNoParsedChlo, + connection_id_generator_); EXPECT_TRUE(store_.HasBufferedPackets(connection_id)); EXPECT_FALSE(store_.IngestPacketForTlsChloExtraction( @@ -691,13 +725,15 @@ EnqueuePacketToStore(store_, connection_id, packet_format, long_packet_type, packet_, self_address_, peer_address_, valid_version_, - kNoParsedChlo, nullptr); + kNoParsedChlo, connection_id_generator_); EnqueuePacketToStore(store_, connection_id, IETF_QUIC_LONG_HEADER_PACKET, INITIAL, *initial_packets[0], self_address_, - peer_address_, valid_version_, kNoParsedChlo, nullptr); + peer_address_, valid_version_, kNoParsedChlo, + connection_id_generator_); EnqueuePacketToStore(store_, connection_id, IETF_QUIC_LONG_HEADER_PACKET, INITIAL, *initial_packets[1], self_address_, - peer_address_, valid_version_, kNoParsedChlo, nullptr); + peer_address_, valid_version_, kNoParsedChlo, + connection_id_generator_); BufferedPacketList delivered_packets = store_.DeliverPackets(connection_id); EXPECT_THAT(delivered_packets.buffered_packets, SizeIs(3)); @@ -728,7 +764,8 @@ false, ECN_ECT1); EnqueuePacketToStore(store_, connection_id, GOOGLE_QUIC_PACKET, INVALID_PACKET_TYPE, ect1_packet, self_address_, - peer_address_, valid_version_, kNoParsedChlo, nullptr); + peer_address_, valid_version_, kNoParsedChlo, + connection_id_generator_); BufferedPacketList delivered_packets = store_.DeliverPackets(connection_id); EXPECT_THAT(delivered_packets.buffered_packets, SizeIs(1)); for (const auto& packet : delivered_packets.buffered_packets) { @@ -736,6 +773,15 @@ } } +TEST_F(QuicBufferedPacketStoreTest, EmptyBufferedPacketList) { + BufferedPacketList packet_list; + EXPECT_TRUE(packet_list.buffered_packets.empty()); + EXPECT_FALSE(packet_list.parsed_chlo.has_value()); + EXPECT_FALSE(packet_list.version.IsKnown()); + EXPECT_TRUE(packet_list.original_connection_id.IsEmpty()); + EXPECT_FALSE(packet_list.replaced_connection_id.has_value()); +} + } // namespace } // namespace test } // namespace quic
diff --git a/quiche/quic/core/quic_dispatcher.cc b/quiche/quic/core/quic_dispatcher.cc index 49ace1c..4b832e8 100644 --- a/quiche/quic/core/quic_dispatcher.cc +++ b/quiche/quic/core/quic_dispatcher.cc
@@ -447,10 +447,15 @@ if (buffered_packets_.HasChloForConnection(server_connection_id)) { EnqueuePacketResult rs = buffered_packets_.EnqueuePacket( packet_info, - /*parsed_chlo=*/std::nullopt, /*connection_id_generator=*/nullptr); + /*parsed_chlo=*/std::nullopt, ConnectionIdGenerator()); switch (rs) { case EnqueuePacketResult::SUCCESS: break; + case EnqueuePacketResult::CID_COLLISION: + QUICHE_DCHECK(false) << "Connection " << server_connection_id + << " already has a CHLO buffered, but " + "EnqueuePacket returned CID_COLLISION."; + ABSL_FALLTHROUGH_INTENDED; case EnqueuePacketResult::TOO_MANY_PACKETS: ABSL_FALLTHROUGH_INTENDED; case EnqueuePacketResult::TOO_MANY_CONNECTIONS: @@ -658,10 +663,17 @@ // or it could be a fragment of a multi-packet CHLO. EnqueuePacketResult rs = buffered_packets_.EnqueuePacket( packet_info, - /*parsed_chlo=*/std::nullopt, /*connection_id_generator=*/nullptr); + /*parsed_chlo=*/std::nullopt, ConnectionIdGenerator()); switch (rs) { case EnqueuePacketResult::SUCCESS: break; + case EnqueuePacketResult::CID_COLLISION: + QUICHE_DCHECK(buffered_packets_.replace_cid_on_first_packet()); + QUIC_RESTART_FLAG_COUNT_N(quic_dispatcher_replace_cid_on_first_packet, + 9, 13); + buffered_packets_.DiscardPackets( + packet_info.destination_connection_id); + ABSL_FALLTHROUGH_INTENDED; case EnqueuePacketResult::TOO_MANY_PACKETS: ABSL_FALLTHROUGH_INTENDED; case EnqueuePacketResult::TOO_MANY_CONNECTIONS: @@ -693,10 +705,15 @@ // Buffer non-CHLO packets. EnqueuePacketResult rs = buffered_packets_.EnqueuePacket( packet_info, - /*parsed_chlo=*/std::nullopt, /*connection_id_generator=*/nullptr); + /*parsed_chlo=*/std::nullopt, ConnectionIdGenerator()); switch (rs) { case EnqueuePacketResult::SUCCESS: break; + case EnqueuePacketResult::CID_COLLISION: + // This should never happen; we only replace CID in the packet store + // for IETF packets. + QUIC_BUG(quic_store_cid_collision_from_gquic_packet); + ABSL_FALLTHROUGH_INTENDED; case EnqueuePacketResult::TOO_MANY_PACKETS: ABSL_FALLTHROUGH_INTENDED; case EnqueuePacketResult::TOO_MANY_CONNECTIONS: @@ -1036,6 +1053,9 @@ return false; } +// TODO(wub): After deprecating --quic_dispatcher_replace_cid_on_first_packet, +// remove |server_connection_id| because |early_arrived_packets| already +// contains the original and replaced connection ID. void QuicDispatcher::OnExpiredPackets( QuicConnectionId server_connection_id, BufferedPacketList early_arrived_packets) { @@ -1083,7 +1103,8 @@ continue; } auto session_ptr = CreateSessionFromChlo( - server_connection_id, *packet_list.parsed_chlo, packet_list.version, + server_connection_id, packet_list.replaced_connection_id, + *packet_list.parsed_chlo, packet_list.version, packets.front().self_address, packets.front().peer_address, packet_list.connection_id_generator); if (session_ptr != nullptr) { @@ -1120,10 +1141,17 @@ QUIC_BUG_IF(quic_bug_12724_7, buffered_packets_.HasChloForConnection( packet_info->destination_connection_id)); EnqueuePacketResult rs = buffered_packets_.EnqueuePacket( - *packet_info, std::move(parsed_chlo), &ConnectionIdGenerator()); + *packet_info, std::move(parsed_chlo), ConnectionIdGenerator()); switch (rs) { case EnqueuePacketResult::SUCCESS: break; + case EnqueuePacketResult::CID_COLLISION: + QUICHE_DCHECK(buffered_packets_.replace_cid_on_first_packet()); + QUIC_RESTART_FLAG_COUNT_N(quic_dispatcher_replace_cid_on_first_packet, + 10, 13); + buffered_packets_.DiscardPackets( + packet_info->destination_connection_id); + ABSL_FALLTHROUGH_INTENDED; case EnqueuePacketResult::TOO_MANY_PACKETS: ABSL_FALLTHROUGH_INTENDED; case EnqueuePacketResult::TOO_MANY_CONNECTIONS: @@ -1133,10 +1161,46 @@ return; } + if (buffered_packets_.replace_cid_on_first_packet()) { + QUIC_RESTART_FLAG_COUNT_N(quic_dispatcher_replace_cid_on_first_packet, 11, + 13); + BufferedPacketList packet_list = buffered_packets_.DeliverPackets( + packet_info->destination_connection_id); + // Get original_connection_id from buffered packets because + // destination_connection_id may be replaced connection_id if any packets + // have been sent by packet store. + QuicConnectionId original_connection_id = + packet_list.buffered_packets.empty() + ? packet_info->destination_connection_id + : packet_list.original_connection_id; + + auto session_ptr = CreateSessionFromChlo( + original_connection_id, packet_list.replaced_connection_id, parsed_chlo, + packet_info->version, packet_info->self_address, + packet_info->peer_address, packet_list.connection_id_generator); + if (session_ptr == nullptr) { + // The only reason that CreateSessionFromChlo returns nullptr is because + // of CID collision, which can only happen if CreateSessionFromChlo + // attempted to replace the CID, CreateSessionFromChlo only replaces the + // CID when connection_id_generator is nullptr. + QUICHE_DCHECK_EQ(packet_list.connection_id_generator, nullptr); + return; + } + // Process the current packet first, then deliver queued-up packets. + // Note that multi-packet CHLOs, if received in packet number order, will + // not be delivered in the same order. This needs to be fixed. + session_ptr->ProcessUdpPacket(packet_info->self_address, + packet_info->peer_address, + packet_info->packet); + DeliverPacketsToSession(packet_list.buffered_packets, session_ptr.get()); + --new_sessions_allowed_per_event_loop_; + return; + } + auto session_ptr = CreateSessionFromChlo( - packet_info->destination_connection_id, parsed_chlo, packet_info->version, - packet_info->self_address, packet_info->peer_address, - &ConnectionIdGenerator()); + packet_info->destination_connection_id, std::nullopt, parsed_chlo, + packet_info->version, packet_info->self_address, + packet_info->peer_address, &ConnectionIdGenerator()); if (session_ptr == nullptr) { return; } @@ -1206,53 +1270,93 @@ std::shared_ptr<QuicSession> QuicDispatcher::CreateSessionFromChlo( const QuicConnectionId original_connection_id, + const std::optional<QuicConnectionId>& replaced_connection_id, const ParsedClientHello& parsed_chlo, const ParsedQuicVersion version, const QuicSocketAddress self_address, const QuicSocketAddress peer_address, ConnectionIdGeneratorInterface* connection_id_generator) { + bool should_generate_cid = false; if (connection_id_generator == nullptr) { + should_generate_cid = true; connection_id_generator = &ConnectionIdGenerator(); } - std::optional<QuicConnectionId> server_connection_id = - connection_id_generator->MaybeReplaceConnectionId(original_connection_id, - version); - const bool replaced_connection_id = server_connection_id.has_value(); - if (!replaced_connection_id) { + std::optional<QuicConnectionId> server_connection_id; + if (buffered_packets_.replace_cid_on_first_packet()) { + if (should_generate_cid) { + QUIC_RESTART_FLAG_COUNT_N(quic_dispatcher_replace_cid_on_first_packet, 12, + 13); + server_connection_id = connection_id_generator->MaybeReplaceConnectionId( + original_connection_id, version); + // Normalize the output of MaybeReplaceConnectionId. + if (server_connection_id.has_value() && + (server_connection_id->IsEmpty() || + *server_connection_id == original_connection_id)) { + server_connection_id.reset(); + } + QUIC_DVLOG(1) << "MaybeReplaceConnectionId(" << original_connection_id + << ") = " + << (server_connection_id.has_value() + ? server_connection_id->ToString() + : "nullopt"); + + if (server_connection_id.has_value()) { + switch (HandleConnectionIdCollision( + original_connection_id, *server_connection_id, self_address, + peer_address, version, &parsed_chlo)) { + case VisitorInterface::HandleCidCollisionResult::kOk: + break; + case VisitorInterface::HandleCidCollisionResult::kCollision: + return nullptr; + } + } + } else { + QUIC_RESTART_FLAG_COUNT_N(quic_dispatcher_replace_cid_on_first_packet, 13, + 13); + server_connection_id = replaced_connection_id; + } + } else { + server_connection_id = connection_id_generator->MaybeReplaceConnectionId( + original_connection_id, version); + } + const bool connection_id_replaced = server_connection_id.has_value(); + if (!connection_id_replaced) { server_connection_id = original_connection_id; } - QUIC_CODE_COUNT(quic_connection_id_chosen); - if (reference_counted_session_map_.count(*server_connection_id) > 0) { - // The new connection ID is owned by another session. Avoid creating one - // altogether, as this connection attempt cannot possibly succeed. - QUIC_CODE_COUNT(quic_connection_id_collision); - QuicConnection* other_connection = - reference_counted_session_map_[*server_connection_id]->connection(); - if (other_connection != nullptr) { // Just make sure there is no crash. - QUIC_LOG_EVERY_N_SEC(ERROR, 10) - << "QUIC Connection ID collision. original_connection_id:" - << original_connection_id.ToString() - << " server_connection_id:" << server_connection_id->ToString() - << ", version:" << version << ", self_address:" << self_address - << ", peer_address:" << peer_address - << ", parsed_chlo:" << parsed_chlo - << ", other peer address: " << other_connection->peer_address() - << ", other CIDs: " - << quiche::PrintElements( - other_connection->GetActiveServerConnectionIds()) - << ", other stats: " << other_connection->GetStats(); + if (!buffered_packets_.replace_cid_on_first_packet()) { + QUIC_CODE_COUNT(quic_connection_id_chosen); + if (reference_counted_session_map_.count(*server_connection_id) > 0) { + // The new connection ID is owned by another session. Avoid creating one + // altogether, as this connection attempt cannot possibly succeed. + QUIC_CODE_COUNT(quic_connection_id_collision); + QuicConnection* other_connection = + reference_counted_session_map_[*server_connection_id]->connection(); + if (other_connection != nullptr) { // Just make sure there is no crash. + QUIC_LOG_EVERY_N_SEC(ERROR, 10) + << "QUIC Connection ID collision. original_connection_id:" + << original_connection_id.ToString() + << " server_connection_id:" << server_connection_id->ToString() + << ", version:" << version << ", self_address:" << self_address + << ", peer_address:" << peer_address + << ", parsed_chlo:" << parsed_chlo + << ", other peer address: " << other_connection->peer_address() + << ", other CIDs: " + << quiche::PrintElements( + other_connection->GetActiveServerConnectionIds()) + << ", other stats: " << other_connection->GetStats(); + } + if (connection_id_replaced) { + QUIC_CODE_COUNT(quic_replaced_connection_id_collision); + // The original connection ID does not correspond to an existing + // session. It is safe to send CONNECTION_CLOSE and add to TIME_WAIT. + StatelesslyTerminateConnection( + self_address, peer_address, original_connection_id, + IETF_QUIC_LONG_HEADER_PACKET, + /*version_flag=*/true, version.HasLengthPrefixedConnectionIds(), + version, QUIC_HANDSHAKE_FAILED, + "Connection ID collision, please retry", + QuicTimeWaitListManager::SEND_CONNECTION_CLOSE_PACKETS); + } + return nullptr; } - if (replaced_connection_id) { - QUIC_CODE_COUNT(quic_replaced_connection_id_collision); - // The original connection ID does not correspond to an existing - // session. It is safe to send CONNECTION_CLOSE and add to TIME_WAIT. - StatelesslyTerminateConnection( - self_address, peer_address, original_connection_id, - IETF_QUIC_LONG_HEADER_PACKET, - /*version_flag=*/true, version.HasLengthPrefixedConnectionIds(), - version, QUIC_HANDSHAKE_FAILED, - "Connection ID collision, please retry", - QuicTimeWaitListManager::SEND_CONNECTION_CLOSE_PACKETS); - } - return nullptr; } // Creates a new session and process all buffered packets for this connection. std::string alpn = SelectAlpn(parsed_chlo.alpns); @@ -1267,7 +1371,7 @@ return nullptr; } - if (replaced_connection_id) { + if (connection_id_replaced) { session->connection()->SetOriginalDestinationConnectionId( original_connection_id); } @@ -1283,7 +1387,7 @@ << *server_connection_id; } else { ++num_sessions_in_session_map_; - if (replaced_connection_id) { + if (connection_id_replaced) { auto insertion_result2 = reference_counted_session_map_.insert( std::make_pair(original_connection_id, session_ptr)); QUIC_BUG_IF(quic_460317833_02, !insertion_result2.second) @@ -1297,6 +1401,69 @@ return session_ptr; } +QuicDispatcher::HandleCidCollisionResult +QuicDispatcher::HandleConnectionIdCollision( + const QuicConnectionId& original_connection_id, + const QuicConnectionId& replaced_connection_id, + const QuicSocketAddress& self_address, + const QuicSocketAddress& peer_address, ParsedQuicVersion version, + const ParsedClientHello* parsed_chlo) { + QUICHE_DCHECK(buffered_packets_.replace_cid_on_first_packet()); + HandleCidCollisionResult result = HandleCidCollisionResult::kOk; + auto existing_session_iter = + reference_counted_session_map_.find(replaced_connection_id); + if (existing_session_iter != reference_counted_session_map_.end()) { + // Collide with an active session in dispatcher. + result = HandleCidCollisionResult::kCollision; + QUIC_CODE_COUNT(quic_connection_id_collision); + QuicConnection* other_connection = + existing_session_iter->second->connection(); + if (other_connection != nullptr) { // Just make sure there is no crash. + QUIC_LOG_EVERY_N_SEC(ERROR, 10) + << "QUIC Connection ID collision. original_connection_id:" + << original_connection_id + << ", replaced_connection_id:" << replaced_connection_id + << ", version:" << version << ", self_address:" << self_address + << ", peer_address:" << peer_address << ", parsed_chlo:" + << (parsed_chlo == nullptr ? "null" : parsed_chlo->ToString()) + << ", other peer address: " << other_connection->peer_address() + << ", other CIDs: " + << quiche::PrintElements( + other_connection->GetActiveServerConnectionIds()) + << ", other stats: " << other_connection->GetStats(); + } + } else if (buffered_packets_.HasBufferedPackets(replaced_connection_id)) { + // Collide with a buffered session in packet store. + result = HandleCidCollisionResult::kCollision; + QUIC_CODE_COUNT(quic_connection_id_collision_with_buffered_session); + } + + if (result == HandleCidCollisionResult::kOk) { + return result; + } + + const bool collide_with_active_session = + existing_session_iter != reference_counted_session_map_.end(); + QUIC_DLOG(INFO) << "QUIC Connection ID collision with " + << (collide_with_active_session ? "active session" + : "buffered session") + << " for original_connection_id:" << original_connection_id + << ", replaced_connection_id:" << replaced_connection_id; + + // The original connection ID does not correspond to an existing + // session. It is safe to send CONNECTION_CLOSE and add to TIME_WAIT. + StatelesslyTerminateConnection( + self_address, peer_address, original_connection_id, + IETF_QUIC_LONG_HEADER_PACKET, + /*version_flag=*/true, version.HasLengthPrefixedConnectionIds(), version, + QUIC_HANDSHAKE_FAILED, "Connection ID collision, please retry", + QuicTimeWaitListManager::SEND_CONNECTION_CLOSE_PACKETS); + + // Caller is responsible for erasing the connection from the buffered store, + // if needed. + return result; +} + void QuicDispatcher::MaybeResetPacketsWithNoVersion( const ReceivedPacketInfo& packet_info) { QUICHE_DCHECK(!packet_info.version_flag);
diff --git a/quiche/quic/core/quic_dispatcher.h b/quiche/quic/core/quic_dispatcher.h index 58e9d68..80118b7 100644 --- a/quiche/quic/core/quic_dispatcher.h +++ b/quiche/quic/core/quic_dispatcher.h
@@ -157,6 +157,12 @@ void OnExpiredPackets(QuicConnectionId server_connection_id, QuicBufferedPacketStore::BufferedPacketList early_arrived_packets) override; + HandleCidCollisionResult HandleConnectionIdCollision( + const QuicConnectionId& original_connection_id, + const QuicConnectionId& replaced_connection_id, + const QuicSocketAddress& self_address, + const QuicSocketAddress& peer_address, ParsedQuicVersion version, + const ParsedClientHello* parsed_chlo) override; void OnPathDegrading() override {} // Create connections for previously buffered CHLOs as many as allowed. @@ -383,8 +389,19 @@ bool IsServerConnectionIdTooShort(QuicConnectionId connection_id) const; // Core CHLO processing logic. + // + // |connection_id_generator| != nullptr indicates we have attempted to + // call connection_id_generator->MaybeReplaceConnectionId() and the result is + // in |replaced_connection_id|. + // + // |connection_id_generator| == nullptr indicates we have not attempted to + // generate a replacement connection ID, in that case + // - |replaced_connection_id| should be std::nullopt. + // - CreateSessionFromChlo will generate a replacement connection ID using + // ConnectionIdGenerator().MaybeReplaceConnectionId(). std::shared_ptr<QuicSession> CreateSessionFromChlo( QuicConnectionId original_connection_id, + const std::optional<QuicConnectionId>& replaced_connection_id, const ParsedClientHello& parsed_chlo, ParsedQuicVersion version, QuicSocketAddress self_address, QuicSocketAddress peer_address, ConnectionIdGeneratorInterface* connection_id_generator);
diff --git a/quiche/quic/core/quic_dispatcher_test.cc b/quiche/quic/core/quic_dispatcher_test.cc index 21cb7c3..3d081ea 100644 --- a/quiche/quic/core/quic_dispatcher_test.cc +++ b/quiche/quic/core/quic_dispatcher_test.cc
@@ -2752,10 +2752,23 @@ const size_t kNumCHLOs = kMaxNumSessionsToCreate + kDefaultMaxConnectionsInStore + 1; for (uint64_t conn_id = 1; conn_id <= kNumCHLOs; ++conn_id) { - if (conn_id <= kMaxNumSessionsToCreate) { + const bool should_drop = + (conn_id > kMaxNumSessionsToCreate + kDefaultMaxConnectionsInStore); + if (store->replace_cid_on_first_packet() && !should_drop) { + // MaybeReplaceConnectionId will be called once per connection, whether it + // is buffered or not. EXPECT_CALL(connection_id_generator_, MaybeReplaceConnectionId(TestConnectionId(conn_id), version_)) .WillOnce(Return(std::nullopt)); + } + + if (conn_id <= kMaxNumSessionsToCreate) { + if (!store->replace_cid_on_first_packet()) { + EXPECT_CALL( + connection_id_generator_, + MaybeReplaceConnectionId(TestConnectionId(conn_id), version_)) + .WillOnce(Return(std::nullopt)); + } EXPECT_CALL( *dispatcher_, CreateQuicSession(TestConnectionId(conn_id), _, client_addr_, @@ -2793,9 +2806,13 @@ for (uint64_t conn_id = kMaxNumSessionsToCreate + 1; conn_id <= kMaxNumSessionsToCreate + kDefaultMaxConnectionsInStore; ++conn_id) { - EXPECT_CALL(connection_id_generator_, - MaybeReplaceConnectionId(TestConnectionId(conn_id), version_)) - .WillOnce(Return(std::nullopt)); + // MaybeReplaceConnectionId should have been called once per buffered + // session. + if (!store->replace_cid_on_first_packet()) { + EXPECT_CALL(connection_id_generator_, + MaybeReplaceConnectionId(TestConnectionId(conn_id), version_)) + .WillOnce(Return(std::nullopt)); + } EXPECT_CALL( *dispatcher_, CreateQuicSession(TestConnectionId(conn_id), _, client_addr_, @@ -2862,6 +2879,13 @@ expect_generator_is_called_ = false; EXPECT_CALL(*dispatcher_, ConnectionIdGenerator()) .WillRepeatedly(ReturnRef(generator2)); + if (store->replace_cid_on_first_packet()) { + // generator2 should be used to replace the connection ID when the first + // IETF INITIAL is enqueued. + EXPECT_CALL(generator2, + MaybeReplaceConnectionId(TestConnectionId(conn_id), version_)) + .WillOnce(Return(std::nullopt)); + } ProcessFirstFlight(TestConnectionId(conn_id)); EXPECT_TRUE(store->HasChloForConnection(TestConnectionId(conn_id))); // Change the generator back so that the session can only access generator2 @@ -2869,11 +2893,13 @@ EXPECT_CALL(*dispatcher_, ConnectionIdGenerator()) .WillRepeatedly(ReturnRef(connection_id_generator_)); - // Consume the buffered CHLO. The buffered connection should be - // created using generator2. - EXPECT_CALL(generator2, - MaybeReplaceConnectionId(TestConnectionId(conn_id), version_)) - .WillOnce(Return(std::nullopt)); + if (!store->replace_cid_on_first_packet()) { + // Consume the buffered CHLO. The buffered connection should be + // created using generator2. + EXPECT_CALL(generator2, + MaybeReplaceConnectionId(TestConnectionId(conn_id), version_)) + .WillOnce(Return(std::nullopt)); + } EXPECT_CALL(*dispatcher_, CreateQuicSession(TestConnectionId(conn_id), _, client_addr_, Eq(ExpectedAlpn()), _, MatchParsedClientHello(), _)) @@ -2975,11 +3001,11 @@ ProcessFirstFlight(TestConnectionId(conn_id)); } - // Process another |kDefaultMaxUndecryptablePackets| + 1 data packets. The - // last one should be dropped. - for (uint64_t packet_number = 2; - packet_number <= kDefaultMaxUndecryptablePackets + 2; ++packet_number) { - ProcessPacket(client_addr_, last_connection_id, true, "data packet"); + // |last_connection_id| has 1 packet buffered now. Process another + // |kDefaultMaxUndecryptablePackets| + 1 data packets to reach max number of + // buffered packets per connection. + for (uint64_t i = 0; i <= kDefaultMaxUndecryptablePackets; ++i) { + ProcessPacket(client_addr_, last_connection_id, false, "data packet"); } // Reset counter and process buffered CHLO. @@ -2990,11 +3016,25 @@ dispatcher_.get(), config_, last_connection_id, client_addr_, &mock_helper_, &mock_alarm_factory_, &crypto_config_, QuicDispatcherPeer::GetCache(dispatcher_.get()), &session1_)))); - // Only CHLO and following |kDefaultMaxUndecryptablePackets| data packets - // should be process. + + const QuicBufferedPacketStore* store = + QuicDispatcherPeer::GetBufferedPackets(dispatcher_.get()); + const QuicBufferedPacketStore::BufferedPacketList* + last_connection_buffered_packets = + QuicBufferedPacketStorePeer::FindBufferedPackets(store, + last_connection_id); + ASSERT_NE(last_connection_buffered_packets, nullptr); + if (store->replace_cid_on_first_packet()) { + ASSERT_EQ(last_connection_buffered_packets->buffered_packets.size(), + kDefaultMaxUndecryptablePackets); + } else { + ASSERT_EQ(last_connection_buffered_packets->buffered_packets.size(), + kDefaultMaxUndecryptablePackets + 1); + } + // All buffered packets should be delivered to the session. EXPECT_CALL(*reinterpret_cast<MockQuicConnection*>(session1_->connection()), ProcessUdpPacket(_, _, _)) - .Times(kDefaultMaxUndecryptablePackets + 1) + .Times(last_connection_buffered_packets->buffered_packets.size()) .WillRepeatedly(WithArg<2>( Invoke([this, last_connection_id](const QuicEncryptedPacket& packet) { if (version_.UsesQuicCrypto()) { @@ -3036,7 +3076,7 @@ ValidatePacket(TestConnectionId(conn_id), packet); } }))); - } else { + } else if (!store->replace_cid_on_first_packet()) { expect_generator_is_called_ = false; } ProcessFirstFlight(TestConnectionId(conn_id)); @@ -3168,6 +3208,189 @@ EXPECT_TRUE(got_ce); } +class DualCIDBufferedPacketStoreTest : public BufferedPacketStoreTest { + protected: + void SetUp() override { + if (!GetQuicRestartFlag(quic_dispatcher_replace_cid_on_first_packet)) { + GTEST_SKIP(); + } + + BufferedPacketStoreTest::SetUp(); + QuicDispatcherPeer::set_new_sessions_allowed_per_event_loop( + dispatcher_.get(), 0); + + // Prevent |ProcessFirstFlight| from setting up expectations for + // MaybeReplaceConnectionId. + expect_generator_is_called_ = false; + EXPECT_CALL(connection_id_generator_, MaybeReplaceConnectionId(_, _)) + .WillRepeatedly(Invoke( + this, &DualCIDBufferedPacketStoreTest::ReplaceConnectionIdInTest)); + } + + std::optional<QuicConnectionId> ReplaceConnectionIdInTest( + const QuicConnectionId& original, const ParsedQuicVersion& version) { + auto it = replaced_cid_map_.find(original); + if (it == replaced_cid_map_.end()) { + ADD_FAILURE() << "Bad test setup: no replacement CID for " << original + << ", version " << version; + return std::nullopt; + } + return it->second; + } + + QuicBufferedPacketStore& store() { + return *QuicDispatcherPeer::GetBufferedPackets(dispatcher_.get()); + } + + using BufferedPacketList = QuicBufferedPacketStore::BufferedPacketList; + const BufferedPacketList* FindBufferedPackets( + QuicConnectionId connection_id) { + return QuicBufferedPacketStorePeer::FindBufferedPackets(&store(), + connection_id); + } + + absl::flat_hash_map<QuicConnectionId, std::optional<QuicConnectionId>> + replaced_cid_map_; + + private: + using BufferedPacketStoreTest::expect_generator_is_called_; +}; + +INSTANTIATE_TEST_SUITE_P(DualCIDBufferedPacketStoreTests, + DualCIDBufferedPacketStoreTest, + ::testing::ValuesIn(CurrentSupportedVersionsWithTls()), + ::testing::PrintToStringParamName()); + +TEST_P(DualCIDBufferedPacketStoreTest, CanLookUpByBothCIDs) { + replaced_cid_map_[TestConnectionId(1)] = TestConnectionId(2); + ProcessFirstFlight(TestConnectionId(1)); + + ASSERT_TRUE(store().HasBufferedPackets(TestConnectionId(1))); + ASSERT_TRUE(store().HasBufferedPackets(TestConnectionId(2))); + + const BufferedPacketList* packets1 = FindBufferedPackets(TestConnectionId(1)); + const BufferedPacketList* packets2 = FindBufferedPackets(TestConnectionId(2)); + EXPECT_EQ(packets1, packets2); + EXPECT_EQ(packets1->original_connection_id, TestConnectionId(1)); + EXPECT_EQ(packets1->replaced_connection_id, TestConnectionId(2)); +} + +TEST_P(DualCIDBufferedPacketStoreTest, DeliverPacketsByOriginalCID) { + replaced_cid_map_[TestConnectionId(1)] = TestConnectionId(2); + ProcessFirstFlight(TestConnectionId(1)); + + ASSERT_TRUE(store().HasBufferedPackets(TestConnectionId(1))); + ASSERT_TRUE(store().HasBufferedPackets(TestConnectionId(2))); + ASSERT_TRUE(store().HasChloForConnection(TestConnectionId(1))); + ASSERT_TRUE(store().HasChloForConnection(TestConnectionId(2))); + ASSERT_TRUE(store().HasChlosBuffered()); + + BufferedPacketList packets = store().DeliverPackets(TestConnectionId(1)); + EXPECT_EQ(packets.original_connection_id, TestConnectionId(1)); + EXPECT_EQ(packets.replaced_connection_id, TestConnectionId(2)); + + EXPECT_FALSE(store().HasBufferedPackets(TestConnectionId(1))); + EXPECT_FALSE(store().HasBufferedPackets(TestConnectionId(2))); + EXPECT_FALSE(store().HasChloForConnection(TestConnectionId(1))); + EXPECT_FALSE(store().HasChloForConnection(TestConnectionId(2))); + EXPECT_FALSE(store().HasChlosBuffered()); +} + +TEST_P(DualCIDBufferedPacketStoreTest, DeliverPacketsByReplacedCID) { + replaced_cid_map_[TestConnectionId(1)] = TestConnectionId(2); + replaced_cid_map_[TestConnectionId(3)] = TestConnectionId(4); + ProcessFirstFlight(TestConnectionId(1)); + ProcessFirstFlight(TestConnectionId(3)); + + ASSERT_TRUE(store().HasBufferedPackets(TestConnectionId(1))); + ASSERT_TRUE(store().HasBufferedPackets(TestConnectionId(3))); + ASSERT_TRUE(store().HasChloForConnection(TestConnectionId(1))); + ASSERT_TRUE(store().HasChloForConnection(TestConnectionId(3))); + ASSERT_TRUE(store().HasChlosBuffered()); + + BufferedPacketList packets2 = store().DeliverPackets(TestConnectionId(2)); + EXPECT_EQ(packets2.original_connection_id, TestConnectionId(1)); + EXPECT_EQ(packets2.replaced_connection_id, TestConnectionId(2)); + + EXPECT_FALSE(store().HasBufferedPackets(TestConnectionId(1))); + EXPECT_FALSE(store().HasBufferedPackets(TestConnectionId(2))); + EXPECT_TRUE(store().HasBufferedPackets(TestConnectionId(3))); + EXPECT_TRUE(store().HasBufferedPackets(TestConnectionId(4))); + EXPECT_FALSE(store().HasChloForConnection(TestConnectionId(1))); + EXPECT_FALSE(store().HasChloForConnection(TestConnectionId(2))); + EXPECT_TRUE(store().HasChloForConnection(TestConnectionId(3))); + EXPECT_TRUE(store().HasChloForConnection(TestConnectionId(4))); + EXPECT_TRUE(store().HasChlosBuffered()); + + BufferedPacketList packets4 = store().DeliverPackets(TestConnectionId(4)); + EXPECT_EQ(packets4.original_connection_id, TestConnectionId(3)); + EXPECT_EQ(packets4.replaced_connection_id, TestConnectionId(4)); + + EXPECT_FALSE(store().HasBufferedPackets(TestConnectionId(3))); + EXPECT_FALSE(store().HasBufferedPackets(TestConnectionId(4))); + EXPECT_FALSE(store().HasChloForConnection(TestConnectionId(3))); + EXPECT_FALSE(store().HasChloForConnection(TestConnectionId(4))); + EXPECT_FALSE(store().HasChlosBuffered()); +} + +TEST_P(DualCIDBufferedPacketStoreTest, DiscardPacketsByOriginalCID) { + replaced_cid_map_[TestConnectionId(1)] = TestConnectionId(2); + ProcessFirstFlight(TestConnectionId(1)); + + ASSERT_TRUE(store().HasBufferedPackets(TestConnectionId(1))); + + store().DiscardPackets(TestConnectionId(1)); + + EXPECT_FALSE(store().HasBufferedPackets(TestConnectionId(1))); + EXPECT_FALSE(store().HasBufferedPackets(TestConnectionId(2))); + EXPECT_FALSE(store().HasChloForConnection(TestConnectionId(1))); + EXPECT_FALSE(store().HasChloForConnection(TestConnectionId(2))); + EXPECT_FALSE(store().HasChlosBuffered()); +} + +TEST_P(DualCIDBufferedPacketStoreTest, DiscardPacketsByReplacedCID) { + replaced_cid_map_[TestConnectionId(1)] = TestConnectionId(2); + replaced_cid_map_[TestConnectionId(3)] = TestConnectionId(4); + ProcessFirstFlight(TestConnectionId(1)); + ProcessFirstFlight(TestConnectionId(3)); + + ASSERT_TRUE(store().HasBufferedPackets(TestConnectionId(2))); + ASSERT_TRUE(store().HasBufferedPackets(TestConnectionId(4))); + + store().DiscardPackets(TestConnectionId(2)); + + EXPECT_FALSE(store().HasBufferedPackets(TestConnectionId(1))); + EXPECT_FALSE(store().HasBufferedPackets(TestConnectionId(2))); + EXPECT_TRUE(store().HasBufferedPackets(TestConnectionId(3))); + EXPECT_TRUE(store().HasBufferedPackets(TestConnectionId(4))); + EXPECT_FALSE(store().HasChloForConnection(TestConnectionId(1))); + EXPECT_FALSE(store().HasChloForConnection(TestConnectionId(2))); + EXPECT_TRUE(store().HasChloForConnection(TestConnectionId(3))); + EXPECT_TRUE(store().HasChloForConnection(TestConnectionId(4))); + EXPECT_TRUE(store().HasChlosBuffered()); + + store().DiscardPackets(TestConnectionId(4)); + + EXPECT_FALSE(store().HasBufferedPackets(TestConnectionId(3))); + EXPECT_FALSE(store().HasBufferedPackets(TestConnectionId(4))); + EXPECT_FALSE(store().HasChloForConnection(TestConnectionId(3))); + EXPECT_FALSE(store().HasChloForConnection(TestConnectionId(4))); + EXPECT_FALSE(store().HasChlosBuffered()); +} + +TEST_P(DualCIDBufferedPacketStoreTest, CIDCollision) { + replaced_cid_map_[TestConnectionId(1)] = TestConnectionId(2); + replaced_cid_map_[TestConnectionId(3)] = TestConnectionId(2); + ProcessFirstFlight(TestConnectionId(1)); + ProcessFirstFlight(TestConnectionId(3)); + + ASSERT_TRUE(store().HasBufferedPackets(TestConnectionId(1))); + ASSERT_TRUE(store().HasBufferedPackets(TestConnectionId(2))); + + // QuicDispatcher should discard connection 3 after CID collision. + ASSERT_FALSE(store().HasBufferedPackets(TestConnectionId(3))); +} + } // namespace } // namespace test } // namespace quic
diff --git a/quiche/quic/core/quic_types.cc b/quiche/quic/core/quic_types.cc index 43b5ace..a180fb1 100644 --- a/quiche/quic/core/quic_types.cc +++ b/quiche/quic/core/quic_types.cc
@@ -420,6 +420,12 @@ return os; } +std::string ParsedClientHello::ToString() const { + std::ostringstream oss; + oss << *this; + return oss.str(); +} + bool operator==(const ParsedClientHello& a, const ParsedClientHello& b) { return a.sni == b.sni && a.uaid == b.uaid && a.supported_groups == b.supported_groups &&
diff --git a/quiche/quic/core/quic_types.h b/quiche/quic/core/quic_types.h index b85e079..f3c5ff2 100644 --- a/quiche/quic/core/quic_types.h +++ b/quiche/quic/core/quic_types.h
@@ -882,6 +882,8 @@ std::string retry_token; bool resumption_attempted = false; // TLS only. bool early_data_attempted = false; // TLS only. + + std::string ToString() const; }; QUICHE_EXPORT bool operator==(const ParsedClientHello& a,
diff --git a/quiche/quic/test_tools/quic_buffered_packet_store_peer.cc b/quiche/quic/test_tools/quic_buffered_packet_store_peer.cc index 4d11806..0e2aa15 100644 --- a/quiche/quic/test_tools/quic_buffered_packet_store_peer.cc +++ b/quiche/quic/test_tools/quic_buffered_packet_store_peer.cc
@@ -9,17 +9,33 @@ namespace quic { namespace test { -// static QuicAlarm* QuicBufferedPacketStorePeer::expiration_alarm( QuicBufferedPacketStore* store) { return store->expiration_alarm_.get(); } -// static void QuicBufferedPacketStorePeer::set_clock(QuicBufferedPacketStore* store, const QuicClock* clock) { store->clock_ = clock; } +const QuicBufferedPacketStore::BufferedPacketList* +QuicBufferedPacketStorePeer::FindBufferedPackets( + const QuicBufferedPacketStore* store, QuicConnectionId connection_id) { + if (store->replace_cid_on_first_packet()) { + auto it = store->buffered_session_map_.find(connection_id); + if (it == store->buffered_session_map_.end()) { + return nullptr; + } + return it->second.get(); + } + + auto it = store->undecryptable_packets_.find(connection_id); + if (it == store->undecryptable_packets_.end()) { + return nullptr; + } + return &it->second; +} + } // namespace test } // namespace quic
diff --git a/quiche/quic/test_tools/quic_buffered_packet_store_peer.h b/quiche/quic/test_tools/quic_buffered_packet_store_peer.h index 0610274..5bb9c17 100644 --- a/quiche/quic/test_tools/quic_buffered_packet_store_peer.h +++ b/quiche/quic/test_tools/quic_buffered_packet_store_peer.h
@@ -8,7 +8,9 @@ #include <memory> #include "quiche/quic/core/quic_alarm.h" +#include "quiche/quic/core/quic_buffered_packet_store.h" #include "quiche/quic/core/quic_clock.h" +#include "quiche/quic/core/quic_connection_id.h" namespace quic { @@ -23,6 +25,9 @@ static QuicAlarm* expiration_alarm(QuicBufferedPacketStore* store); static void set_clock(QuicBufferedPacketStore* store, const QuicClock* clock); + + static const QuicBufferedPacketStore::BufferedPacketList* FindBufferedPackets( + const QuicBufferedPacketStore* store, QuicConnectionId connection_id); }; } // namespace test