Patch in 249021555 (and 248702511) Support QUIC Client connection IDs
diff --git a/quic/core/http/end_to_end_test.cc b/quic/core/http/end_to_end_test.cc index 209bc81..c72622c 100644 --- a/quic/core/http/end_to_end_test.cc +++ b/quic/core/http/end_to_end_test.cc
@@ -210,8 +210,9 @@ chlo_multiplier_(0), stream_factory_(nullptr), support_server_push_(false), - override_connection_id_(nullptr), - expected_connection_id_length_(kQuicDefaultConnectionIdLength) { + override_server_connection_id_(nullptr), + override_client_connection_id_(nullptr), + expected_server_connection_id_length_(kQuicDefaultConnectionIdLength) { SetQuicFlag(FLAGS_quic_supports_tls_handshake, true); SetQuicRestartFlag(quic_no_server_conn_ver_negotiation2, true); SetQuicReloadableFlag(quic_no_client_conn_ver_negotiation, true); @@ -258,8 +259,11 @@ if (!pre_shared_key_client_.empty()) { client->client()->SetPreSharedKey(pre_shared_key_client_); } - if (override_connection_id_ != nullptr) { - client->UseConnectionId(*override_connection_id_); + if (override_server_connection_id_ != nullptr) { + client->UseConnectionId(*override_server_connection_id_); + } + if (override_client_connection_id_ != nullptr) { + client->UseClientConnectionId(*override_client_connection_id_); } client->Connect(); return client; @@ -374,7 +378,7 @@ auto* test_server = new QuicTestServer( crypto_test_utils::ProofSourceForTesting(), server_config_, server_supported_versions_, &memory_cache_backend_, - expected_connection_id_length_); + expected_server_connection_id_length_); server_thread_ = QuicMakeUnique<ServerThread>(test_server, server_address_); if (chlo_multiplier_ != 0) { server_thread_->server()->SetChloMultiplier(chlo_multiplier_); @@ -537,8 +541,9 @@ bool support_server_push_; std::string pre_shared_key_client_; std::string pre_shared_key_server_; - QuicConnectionId* override_connection_id_; - uint8_t expected_connection_id_length_; + QuicConnectionId* override_server_connection_id_; + QuicConnectionId* override_client_connection_id_; + uint8_t expected_server_connection_id_length_; }; // Run all end to end tests with all supported versions. @@ -602,8 +607,8 @@ TEST_P(EndToEndTest, SimpleRequestResponseZeroConnectionID) { QuicConnectionId connection_id = QuicUtils::CreateZeroConnectionId( GetParam().negotiated_version.transport_version); - override_connection_id_ = &connection_id; - expected_connection_id_length_ = connection_id.length(); + override_server_connection_id_ = &connection_id; + expected_server_connection_id_length_ = connection_id.length(); ASSERT_TRUE(Initialize()); EXPECT_EQ(kFooResponseBody, client_->SendSynchronousRequest("/foo")); @@ -627,7 +632,7 @@ } QuicConnectionId connection_id = TestConnectionIdNineBytesLong(UINT64_C(0xBADbadBADbad)); - override_connection_id_ = &connection_id; + override_server_connection_id_ = &connection_id; ASSERT_TRUE(Initialize()); EXPECT_EQ(kFooResponseBody, client_->SendSynchronousRequest("/foo")); EXPECT_EQ("200", client_->response_headers()->find(":status")->second); @@ -638,6 +643,43 @@ .length()); } +TEST_P(EndToEndTest, ClientConnectionId) { + if (!GetParam().negotiated_version.SupportsClientConnectionIds()) { + ASSERT_TRUE(Initialize()); + return; + } + QuicConnectionId client_connection_id = + TestConnectionId(UINT64_C(0xc1c2c3c4c5c6c7c8)); + override_client_connection_id_ = &client_connection_id; + ASSERT_TRUE(Initialize()); + EXPECT_EQ(kFooResponseBody, client_->SendSynchronousRequest("/foo")); + EXPECT_EQ("200", client_->response_headers()->find(":status")->second); + EXPECT_EQ(client_connection_id, client_->client() + ->client_session() + ->connection() + ->client_connection_id()); +} + +TEST_P(EndToEndTest, ForcedVersionNegotiationAndClientConnectionId) { + if (!GetParam().negotiated_version.SupportsClientConnectionIds()) { + ASSERT_TRUE(Initialize()); + return; + } + client_supported_versions_.insert(client_supported_versions_.begin(), + QuicVersionReservedForNegotiation()); + QuicConnectionId client_connection_id = + TestConnectionId(UINT64_C(0xc1c2c3c4c5c6c7c8)); + override_client_connection_id_ = &client_connection_id; + ASSERT_TRUE(Initialize()); + ASSERT_TRUE(ServerSendsVersionNegotiation()); + EXPECT_EQ(kFooResponseBody, client_->SendSynchronousRequest("/foo")); + EXPECT_EQ("200", client_->response_headers()->find(":status")->second); + EXPECT_EQ(client_connection_id, client_->client() + ->client_session() + ->connection() + ->client_connection_id()); +} + TEST_P(EndToEndTest, ForcedVersionNegotiationAndBadConnectionIdLength) { if (!GetQuicRestartFlag( quic_allow_variable_length_connection_id_for_negotiation)) { @@ -653,7 +695,7 @@ QuicVersionReservedForNegotiation()); QuicConnectionId connection_id = TestConnectionIdNineBytesLong(UINT64_C(0xBADbadBADbad)); - override_connection_id_ = &connection_id; + override_server_connection_id_ = &connection_id; ASSERT_TRUE(Initialize()); ASSERT_TRUE(ServerSendsVersionNegotiation()); EXPECT_EQ(kFooResponseBody, client_->SendSynchronousRequest("/foo")); @@ -675,9 +717,9 @@ // Start client_ which will use a bad connection ID length. QuicConnectionId connection_id = TestConnectionIdNineBytesLong(UINT64_C(0xBADbadBADbad)); - override_connection_id_ = &connection_id; + override_server_connection_id_ = &connection_id; ASSERT_TRUE(Initialize()); - override_connection_id_ = nullptr; + override_server_connection_id_ = nullptr; // Start client2 which will use a good connection ID length. std::unique_ptr<QuicTestClient> client2(CreateQuicClient(nullptr)); @@ -766,8 +808,8 @@ TEST_P(EndToEndTest, MultipleRequestResponseZeroConnectionID) { QuicConnectionId connection_id = QuicUtils::CreateZeroConnectionId( GetParam().negotiated_version.transport_version); - override_connection_id_ = &connection_id; - expected_connection_id_length_ = connection_id.length(); + override_server_connection_id_ = &connection_id; + expected_server_connection_id_length_ = connection_id.length(); ASSERT_TRUE(Initialize()); EXPECT_EQ(kFooResponseBody, client_->SendSynchronousRequest("/foo")); @@ -2318,7 +2360,7 @@ TestConnectionIdToUInt64(client_connection->connection_id()) + 1); std::unique_ptr<QuicEncryptedPacket> packet( QuicFramer::BuildVersionNegotiationPacket( - incorrect_connection_id, + incorrect_connection_id, EmptyQuicConnectionId(), client_connection->transport_version() > QUIC_VERSION_43, server_supported_versions_)); testing::NiceMock<MockQuicConnectionDebugVisitor> visitor;
diff --git a/quic/core/http/quic_spdy_client_session_test.cc b/quic/core/http/quic_spdy_client_session_test.cc index 77d105c..0d6b36c 100644 --- a/quic/core/http/quic_spdy_client_session_test.cc +++ b/quic/core/http/quic_spdy_client_session_test.cc
@@ -504,7 +504,7 @@ session_->ProcessUdpPacket(client_address, server_address, valid_packet); // Verify that a non-decryptable packet doesn't close the connection. - QuicFramerPeer::SetLastSerializedConnectionId( + QuicFramerPeer::SetLastSerializedServerConnectionId( QuicConnectionPeer::GetFramer(connection_), connection_id); ParsedQuicVersionVector versions = SupportedVersions(GetParam()); QuicConnectionId destination_connection_id = EmptyQuicConnectionId(); @@ -549,7 +549,7 @@ QuicConnectionId destination_connection_id = session_->connection()->connection_id(); QuicConnectionId source_connection_id = EmptyQuicConnectionId(); - QuicFramerPeer::SetLastSerializedConnectionId( + QuicFramerPeer::SetLastSerializedServerConnectionId( QuicConnectionPeer::GetFramer(connection_), destination_connection_id); ParsedQuicVersionVector versions = {GetParam()}; bool version_flag = false;
diff --git a/quic/core/quic_connection.cc b/quic/core/quic_connection.cc index e61bae4..169b17b 100644 --- a/quic/core/quic_connection.cc +++ b/quic/core/quic_connection.cc
@@ -219,7 +219,7 @@ (perspective_ == Perspective::IS_SERVER ? "Server: " : "Client: ") QuicConnection::QuicConnection( - QuicConnectionId connection_id, + QuicConnectionId server_connection_id, QuicSocketAddress initial_peer_address, QuicConnectionHelperInterface* helper, QuicAlarmFactory* alarm_factory, @@ -230,7 +230,7 @@ : framer_(supported_versions, helper->GetClock()->ApproximateNow(), perspective, - connection_id.length()), + server_connection_id.length()), current_packet_content_(NO_FRAMES_RECEIVED), is_current_packet_connectivity_probing_(false), current_effective_peer_migration_type_(NO_CHANGE), @@ -242,7 +242,7 @@ encryption_level_(ENCRYPTION_INITIAL), clock_(helper->GetClock()), random_generator_(helper->GetRandomGenerator()), - connection_id_(connection_id), + server_connection_id_(server_connection_id), peer_address_(initial_peer_address), direct_peer_address_(initial_peer_address), active_effective_peer_migration_type_(NO_CHANGE), @@ -302,7 +302,10 @@ &arena_)), visitor_(nullptr), debug_visitor_(nullptr), - packet_generator_(connection_id_, &framer_, random_generator_, this), + packet_generator_(server_connection_id_, + &framer_, + random_generator_, + this), idle_network_timeout_(QuicTime::Delta::Infinite()), handshake_timeout_(QuicTime::Delta::Infinite()), time_of_first_packet_sent_after_receiving_( @@ -373,14 +376,14 @@ if (use_uber_received_packet_manager_) { QUIC_RELOADABLE_FLAG_COUNT(quic_use_uber_received_packet_manager); } - QUIC_DLOG(INFO) << ENDPOINT - << "Created connection with connection_id: " << connection_id + QUIC_DLOG(INFO) << ENDPOINT << "Created connection with server connection ID " + << server_connection_id << " and version: " << ParsedQuicVersionToString(version()); - QUIC_BUG_IF(!QuicUtils::IsConnectionIdValidForVersion(connection_id, + QUIC_BUG_IF(!QuicUtils::IsConnectionIdValidForVersion(server_connection_id, transport_version())) - << "QuicConnection: attempted to use connection ID " << connection_id - << " which is invalid with version " + << "QuicConnection: attempted to use server connection ID " + << server_connection_id << " which is invalid with version " << QuicVersionToString(transport_version()); framer_.set_visitor(this); @@ -444,7 +447,7 @@ sent_packet_manager_.SetFromConfig(config); if (config.HasReceivedBytesForConnectionId() && can_truncate_connection_ids_) { - packet_generator_.SetConnectionIdLength( + packet_generator_.SetServerConnectionIdLength( config.ReceivedBytesForConnectionId()); } max_undecryptable_packets_ = config.max_undecryptable_packets(); @@ -592,7 +595,7 @@ // Check that any public reset packet with a different connection ID that was // routed to this QuicConnection has been redirected before control reaches // here. (Check for a bug regression.) - DCHECK_EQ(connection_id_, packet.connection_id); + DCHECK_EQ(server_connection_id_, packet.connection_id); DCHECK_EQ(perspective_, Perspective::IS_CLIENT); if (debug_visitor_ != nullptr) { debug_visitor_->OnPublicResetPacket(packet); @@ -683,7 +686,7 @@ // Check that any public reset packet with a different connection ID that was // routed to this QuicConnection has been redirected before control reaches // here. (Check for a bug regression.) - DCHECK_EQ(connection_id_, packet.connection_id); + DCHECK_EQ(server_connection_id_, packet.connection_id); if (perspective_ == Perspective::IS_SERVER) { const std::string error_details = "Server received version negotiation packet."; @@ -766,10 +769,10 @@ void QuicConnection::OnRetryPacket(QuicConnectionId original_connection_id, QuicConnectionId new_connection_id, QuicStringPiece retry_token) { - if (original_connection_id != connection_id_) { + if (original_connection_id != server_connection_id_) { QUIC_DLOG(ERROR) << "Ignoring RETRY with original connection ID " << original_connection_id << " not matching expected " - << connection_id_ << " token " + << server_connection_id_ << " token " << QuicTextUtils::HexEncode(retry_token); return; } @@ -780,17 +783,18 @@ } retry_has_been_parsed_ = true; QUIC_DLOG(INFO) << "Received RETRY, replacing connection ID " - << connection_id_ << " with " << new_connection_id + << server_connection_id_ << " with " << new_connection_id << ", received token " << QuicTextUtils::HexEncode(retry_token); - connection_id_ = new_connection_id; - packet_generator_.SetConnectionId(connection_id_); + server_connection_id_ = new_connection_id; + packet_generator_.SetServerConnectionId(server_connection_id_); packet_generator_.SetRetryToken(retry_token); // Reinstall initial crypters because the connection ID changed. CrypterPair crypters; - CryptoUtils::CreateTlsInitialCrypters( - Perspective::IS_CLIENT, transport_version(), connection_id_, &crypters); + CryptoUtils::CreateTlsInitialCrypters(Perspective::IS_CLIENT, + transport_version(), + server_connection_id_, &crypters); SetEncrypter(ENCRYPTION_INITIAL, std::move(crypters.encrypter)); InstallDecrypter(ENCRYPTION_INITIAL, std::move(crypters.decrypter)); } @@ -817,29 +821,53 @@ QuicConnectionId server_connection_id = GetServerConnectionIdAsRecipient(header, perspective_); - if (server_connection_id == connection_id_ || - HasIncomingConnectionId(server_connection_id)) { + if (server_connection_id != server_connection_id_ && + !HasIncomingConnectionId(server_connection_id)) { + if (PacketCanReplaceConnectionId(header, perspective_)) { + QUIC_DLOG(INFO) << ENDPOINT << "Accepting packet with new connection ID " + << server_connection_id << " instead of " + << server_connection_id_; + return true; + } + + ++stats_.packets_dropped; + QUIC_DLOG(INFO) + << ENDPOINT << "Ignoring packet from unexpected server connection ID " + << server_connection_id << " instead of " << server_connection_id_; + QUIC_BUG + << ENDPOINT << "Ignoring packet from unexpected server connection ID " + << server_connection_id << " instead of " << server_connection_id_ + << " header " << header; + if (debug_visitor_ != nullptr) { + debug_visitor_->OnIncorrectConnectionId(server_connection_id); + } + // If this is a server, the dispatcher routes each packet to the + // QuicConnection responsible for the packet's connection ID. So if control + // arrives here and this is a server, the dispatcher must be malfunctioning. + DCHECK_NE(Perspective::IS_SERVER, perspective_); + return false; + } + + if (!version().SupportsClientConnectionIds()) { return true; } - if (PacketCanReplaceConnectionId(header, perspective_)) { - QUIC_DLOG(INFO) << ENDPOINT << "Accepting packet with new connection ID " - << server_connection_id << " instead of " << connection_id_; - return true; + QuicConnectionId client_connection_id = + GetClientConnectionIdAsRecipient(header, perspective_); + + if (client_connection_id != client_connection_id_) { + ++stats_.packets_dropped; + QUIC_DLOG(INFO) + << ENDPOINT << "Ignoring packet from unexpected client connection ID " + << client_connection_id << " instead of " << client_connection_id_; + QUIC_BUG + << ENDPOINT << "Ignoring packet from unexpected client connection ID " + << client_connection_id << " instead of " << client_connection_id_ + << " header " << header; + return false; } - ++stats_.packets_dropped; - QUIC_DLOG(INFO) << ENDPOINT - << "Ignoring packet from unexpected ConnectionId: " - << server_connection_id << " instead of " << connection_id_; - if (debug_visitor_ != nullptr) { - debug_visitor_->OnIncorrectConnectionId(server_connection_id); - } - // If this is a server, the dispatcher routes each packet to the - // QuicConnection responsible for the packet's connection ID. So if control - // arrives here and this is a server, the dispatcher must be malfunctioning. - DCHECK_NE(Perspective::IS_SERVER, perspective_); - return false; + return true; } bool QuicConnection::OnUnauthenticatedHeader(const QuicPacketHeader& header) { @@ -851,7 +879,7 @@ // routed to this QuicConnection has been redirected before control reaches // here. DCHECK(GetServerConnectionIdAsRecipient(header, perspective_) == - connection_id_ || + server_connection_id_ || HasIncomingConnectionId( GetServerConnectionIdAsRecipient(header, perspective_)) || PacketCanReplaceConnectionId(header, perspective_)); @@ -1109,7 +1137,7 @@ << " packet_number:" << last_header_.packet_number << " largest seen with ack:" << GetLargestReceivedPacketWithAck() - << " connection_id: " << connection_id_; + << " server_connection_id: " << server_connection_id_; // A new ack has a diminished largest_observed value. // If this was an old packet, we wouldn't even have checked. CloseConnection(QUIC_INVALID_ACK_DATA, "Largest observed too low.", @@ -2132,11 +2160,12 @@ } if (PacketCanReplaceConnectionId(header, perspective_) && - connection_id_ != header.source_connection_id) { - QUIC_DLOG(INFO) << ENDPOINT << "Replacing connection ID " << connection_id_ - << " with " << header.source_connection_id; - connection_id_ = header.source_connection_id; - packet_generator_.SetConnectionId(connection_id_); + server_connection_id_ != header.source_connection_id) { + QUIC_DLOG(INFO) << ENDPOINT << "Replacing connection ID " + << server_connection_id_ << " with " + << header.source_connection_id; + server_connection_id_ = header.source_connection_id; + packet_generator_.SetServerConnectionId(server_connection_id_); } if (!ValidateReceivedPacketNumber(header.packet_number)) { @@ -3578,7 +3607,7 @@ QUIC_DLOG(INFO) << ENDPOINT << "Sending path probe packet for connection_id = " - << connection_id_; + << server_connection_id_; OwningSerializedPacketPointer probing_packet; if (transport_version() != QUIC_VERSION_99) { @@ -4215,5 +4244,22 @@ return received_packet_manager_.ack_frame(); } +void QuicConnection::set_client_connection_id( + QuicConnectionId client_connection_id) { + if (!version().SupportsClientConnectionIds()) { + QUIC_BUG_IF(!client_connection_id.IsEmpty()) + << ENDPOINT << "Attempted to use client connection ID " + << client_connection_id << " with unsupported version " << version(); + return; + } + client_connection_id_ = client_connection_id; + QUIC_DLOG(INFO) << ENDPOINT << "setting client connection ID to " + << client_connection_id_ + << " for connection with server connection ID " + << server_connection_id_; + packet_generator_.SetClientConnectionId(client_connection_id_); + framer_.SetExpectedClientConnectionIdLength(client_connection_id_.length()); +} + #undef ENDPOINT // undef for jumbo builds } // namespace quic
diff --git a/quic/core/quic_connection.h b/quic/core/quic_connection.h index 8c8ed2b..00902a9 100644 --- a/quic/core/quic_connection.h +++ b/quic/core/quic_connection.h
@@ -341,7 +341,7 @@ // |initial_peer_address| using |writer| to write packets. |owns_writer| // specifies whether the connection takes ownership of |writer|. |helper| must // outlive this connection. - QuicConnection(QuicConnectionId connection_id, + QuicConnection(QuicConnectionId server_connection_id, QuicSocketAddress initial_peer_address, QuicConnectionHelperInterface* helper, QuicAlarmFactory* alarm_factory, @@ -582,7 +582,11 @@ const QuicSocketAddress& effective_peer_address() const { return effective_peer_address_; } - QuicConnectionId connection_id() const { return connection_id_; } + QuicConnectionId connection_id() const { return server_connection_id_; } + QuicConnectionId client_connection_id() const { + return client_connection_id_; + } + void set_client_connection_id(QuicConnectionId client_connection_id); const QuicClock* clock() const { return clock_; } QuicRandom* random_generator() const { return random_generator_; } QuicByteCount max_packet_length() const; @@ -1168,7 +1172,8 @@ const QuicClock* clock_; QuicRandom* random_generator_; - QuicConnectionId connection_id_; + QuicConnectionId server_connection_id_; + QuicConnectionId client_connection_id_; // Address on the last successfully processed packet received from the // direct peer. QuicSocketAddress self_address_;
diff --git a/quic/core/quic_connection_test.cc b/quic/core/quic_connection_test.cc index 4ecb780..5c98970 100644 --- a/quic/core/quic_connection_test.cc +++ b/quic/core/quic_connection_test.cc
@@ -938,7 +938,7 @@ peer_creator_.SetEncrypter( level, QuicMakeUnique<NullEncrypter>(peer_framer_.perspective())); } - QuicFramerPeer::SetLastSerializedConnectionId( + QuicFramerPeer::SetLastSerializedServerConnectionId( QuicConnectionPeer::GetFramer(&connection_), connection_id_); if (version().transport_version > QUIC_VERSION_43) { EXPECT_TRUE(QuicConnectionPeer::GetNoStopWaitingFrames(&connection_)); @@ -6877,7 +6877,8 @@ // Send a version negotiation packet. std::unique_ptr<QuicEncryptedPacket> encrypted( peer_framer_.BuildVersionNegotiationPacket( - connection_id_, connection_.transport_version() > QUIC_VERSION_43, + connection_id_, EmptyQuicConnectionId(), + connection_.transport_version() > QUIC_VERSION_43, AllSupportedVersions())); std::unique_ptr<QuicReceivedPacket> received( ConstructReceivedPacket(*encrypted, QuicTime::Zero()));
diff --git a/quic/core/quic_dispatcher.cc b/quic/core/quic_dispatcher.cc index 9af5f84..b57f398 100644 --- a/quic/core/quic_dispatcher.cc +++ b/quic/core/quic_dispatcher.cc
@@ -115,17 +115,17 @@ // list manager. class StatelessConnectionTerminator { public: - StatelessConnectionTerminator(QuicConnectionId connection_id, + StatelessConnectionTerminator(QuicConnectionId server_connection_id, const ParsedQuicVersion version, QuicConnectionHelperInterface* helper, QuicTimeWaitListManager* time_wait_list_manager) - : connection_id_(connection_id), + : server_connection_id_(server_connection_id), framer_(ParsedQuicVersionVector{version}, /*unused*/ QuicTime::Zero(), Perspective::IS_SERVER, /*unused*/ kQuicDefaultConnectionIdLength), collector_(helper->GetStreamSendBufferAllocator()), - creator_(connection_id, &framer_, &collector_), + creator_(server_connection_id, &framer_, &collector_), time_wait_list_manager_(time_wait_list_manager) { framer_.set_data_producer(&collector_); } @@ -154,7 +154,7 @@ creator_.Flush(); DCHECK_EQ(1u, collector_.packets()->size()); time_wait_list_manager_->AddConnectionIdToTimeWait( - connection_id_, ietf_quic, + server_connection_id_, ietf_quic, QuicTimeWaitListManager::SEND_TERMINATION_PACKETS, quic::ENCRYPTION_INITIAL, collector_.packets()); } @@ -195,14 +195,15 @@ creator_.Flush(); } time_wait_list_manager_->AddConnectionIdToTimeWait( - connection_id_, ietf_quic, + server_connection_id_, ietf_quic, QuicTimeWaitListManager::SEND_TERMINATION_PACKETS, ENCRYPTION_INITIAL, collector_.packets()); - DCHECK(time_wait_list_manager_->IsConnectionIdInTimeWait(connection_id_)); + DCHECK(time_wait_list_manager_->IsConnectionIdInTimeWait( + server_connection_id_)); } private: - QuicConnectionId connection_id_; + QuicConnectionId server_connection_id_; QuicFramer framer_; // Set as the visitor of |creator_| to collect any generated packets. PacketCollector collector_; @@ -214,7 +215,7 @@ class ChloAlpnExtractor : public ChloExtractor::Delegate { public: void OnChlo(QuicTransportVersion version, - QuicConnectionId connection_id, + QuicConnectionId server_connection_id, const CryptoHandshakeMessage& chlo) override { QuicStringPiece alpn_value; if (chlo.GetStringPiece(kALPN, &alpn_value)) { @@ -247,16 +248,17 @@ // ChloExtractor::Delegate implementation. void OnChlo(QuicTransportVersion version, - QuicConnectionId connection_id, + QuicConnectionId server_connection_id, const CryptoHandshakeMessage& chlo) override { // Extract the ALPN - ChloAlpnExtractor::OnChlo(version, connection_id, chlo); + ChloAlpnExtractor::OnChlo(version, server_connection_id, chlo); if (helper_->CanAcceptClientHello(chlo, client_address_, peer_address_, self_address_, &error_details_)) { can_accept_ = true; rejector_->OnChlo( - version, connection_id, - helper_->GenerateConnectionIdForReject(version, connection_id), chlo); + version, server_connection_id, + helper_->GenerateConnectionIdForReject(version, server_connection_id), + chlo); } } @@ -285,7 +287,7 @@ std::unique_ptr<QuicConnectionHelperInterface> helper, std::unique_ptr<QuicCryptoServerStream::Helper> session_helper, std::unique_ptr<QuicAlarmFactory> alarm_factory, - uint8_t expected_connection_id_length) + uint8_t expected_server_connection_id_length) : config_(config), crypto_config_(crypto_config), compressed_certs_cache_( @@ -301,14 +303,15 @@ framer_(GetSupportedVersions(), /*unused*/ QuicTime::Zero(), Perspective::IS_SERVER, - expected_connection_id_length), + expected_server_connection_id_length), last_error_(QUIC_NO_ERROR), new_sessions_allowed_per_event_loop_(0u), accept_new_connections_(true), - allow_short_initial_connection_ids_(false), + allow_short_initial_server_connection_ids_(false), last_version_label_(0), - expected_connection_id_length_(expected_connection_id_length), - should_update_expected_connection_id_length_(false), + expected_server_connection_id_length_( + expected_server_connection_id_length), + should_update_expected_server_connection_id_length_(false), no_framer_(GetQuicRestartFlag(quic_no_framer_object_in_dispatcher)) { if (!no_framer_) { framer_.set_visitor(this); @@ -345,12 +348,11 @@ } QUIC_RESTART_FLAG_COUNT(quic_no_framer_object_in_dispatcher); QuicPacketHeader header; - uint8_t destination_connection_id_length; std::string detailed_error; const QuicErrorCode error = QuicFramer::ProcessPacketDispatcher( - packet, expected_connection_id_length_, &header.form, + packet, expected_server_connection_id_length_, &header.form, &header.version_flag, &last_version_label_, - &destination_connection_id_length, &header.destination_connection_id, + &header.destination_connection_id, &header.source_connection_id, &detailed_error); if (error != QUIC_NO_ERROR) { // Packet has framing error. @@ -359,16 +361,18 @@ return; } header.version = ParseQuicVersionLabel(last_version_label_); - if (destination_connection_id_length != expected_connection_id_length_ && - !should_update_expected_connection_id_length_ && + if (header.destination_connection_id.length() != + expected_server_connection_id_length_ && + !should_update_expected_server_connection_id_length_ && !QuicUtils::VariableLengthConnectionIdAllowedForVersion( header.version.transport_version)) { SetLastError(QUIC_INVALID_PACKET_HEADER); QUIC_DLOG(ERROR) << "Invalid Connection Id Length"; return; } - if (should_update_expected_connection_id_length_) { - expected_connection_id_length_ = destination_connection_id_length; + if (should_update_expected_server_connection_id_length_) { + expected_server_connection_id_length_ = + header.destination_connection_id.length(); } // TODO(fayang): Instead of passing in QuicPacketHeader, pass format, // version_flag, version and destination_connection_id. Combine @@ -382,34 +386,36 @@ // the next packet does not use them incorrectly. } -QuicConnectionId QuicDispatcher::MaybeReplaceConnectionId( - QuicConnectionId connection_id, +QuicConnectionId QuicDispatcher::MaybeReplaceServerConnectionId( + QuicConnectionId server_connection_id, ParsedQuicVersion version) { - const uint8_t expected_connection_id_length = - no_framer_ ? expected_connection_id_length_ - : framer_.GetExpectedConnectionIdLength(); - if (connection_id.length() == expected_connection_id_length) { - return connection_id; + const uint8_t expected_server_connection_id_length = + no_framer_ ? expected_server_connection_id_length_ + : framer_.GetExpectedServerConnectionIdLength(); + if (server_connection_id.length() == expected_server_connection_id_length) { + return server_connection_id; } DCHECK(QuicUtils::VariableLengthConnectionIdAllowedForVersion( version.transport_version)); - auto it = connection_id_map_.find(connection_id); + auto it = connection_id_map_.find(server_connection_id); if (it != connection_id_map_.end()) { return it->second; } QuicConnectionId new_connection_id = session_helper_->GenerateConnectionIdForReject(version.transport_version, - connection_id); - DCHECK_EQ(expected_connection_id_length, new_connection_id.length()); - connection_id_map_.insert(std::make_pair(connection_id, new_connection_id)); - QUIC_DLOG(INFO) << "Replacing incoming connection ID " << connection_id + server_connection_id); + DCHECK_EQ(expected_server_connection_id_length, new_connection_id.length()); + connection_id_map_.insert( + std::make_pair(server_connection_id, new_connection_id)); + QUIC_DLOG(INFO) << "Replacing incoming connection ID " << server_connection_id << " with " << new_connection_id; return new_connection_id; } bool QuicDispatcher::OnUnauthenticatedPublicHeader( const QuicPacketHeader& header) { - current_connection_id_ = header.destination_connection_id; + current_server_connection_id_ = header.destination_connection_id; + current_client_connection_id_ = header.source_connection_id; // Port zero is only allowed for unidirectional UDP, so is disallowed by QUIC. // Given that we can't even send a reply rejecting the packet, just drop the @@ -423,52 +429,54 @@ if (header.destination_connection_id_included != CONNECTION_ID_PRESENT) { return false; } - QuicConnectionId connection_id = header.destination_connection_id; + QuicConnectionId server_connection_id = header.destination_connection_id; + QuicConnectionId client_connection_id = header.source_connection_id; // The IETF spec requires the client to generate an initial server // connection ID that is at least 64 bits long. After that initial // connection ID, the dispatcher picks a new one of its expected length. // Therefore we should never receive a connection ID that is smaller // than 64 bits and smaller than what we expect. - const uint8_t expected_connection_id_length = - no_framer_ ? expected_connection_id_length_ - : framer_.GetExpectedConnectionIdLength(); - if (connection_id.length() < kQuicMinimumInitialConnectionIdLength && - connection_id.length() < expected_connection_id_length && - !allow_short_initial_connection_ids_) { + const uint8_t expected_server_connection_id_length = + no_framer_ ? expected_server_connection_id_length_ + : framer_.GetExpectedServerConnectionIdLength(); + if (server_connection_id.length() < kQuicMinimumInitialConnectionIdLength && + server_connection_id.length() < expected_server_connection_id_length && + !allow_short_initial_server_connection_ids_) { DCHECK(header.version_flag); DCHECK(QuicUtils::VariableLengthConnectionIdAllowedForVersion( header.version.transport_version)); QUIC_DLOG(INFO) << "Packet with short destination connection ID " - << connection_id << " expected " - << static_cast<int>(expected_connection_id_length); - ProcessUnauthenticatedHeaderFate(kFateTimeWait, connection_id, header.form, - header.version_flag, header.version); + << server_connection_id << " expected " + << static_cast<int>(expected_server_connection_id_length); + ProcessUnauthenticatedHeaderFate(kFateTimeWait, server_connection_id, + header.form, header.version_flag, + header.version); return false; } // Packets with connection IDs for active connections are processed // immediately. - auto it = session_map_.find(connection_id); + auto it = session_map_.find(server_connection_id); if (it != session_map_.end()) { - DCHECK(!buffered_packets_.HasBufferedPackets(connection_id)); + DCHECK(!buffered_packets_.HasBufferedPackets(server_connection_id)); it->second->ProcessUdpPacket(current_self_address_, current_peer_address_, *current_packet_); return false; } - if (buffered_packets_.HasChloForConnection(connection_id)) { - BufferEarlyPacket(connection_id, header.form != GOOGLE_QUIC_PACKET, + if (buffered_packets_.HasChloForConnection(server_connection_id)) { + BufferEarlyPacket(server_connection_id, header.form != GOOGLE_QUIC_PACKET, header.version); return false; } // Check if we are buffering packets for this connection ID - if (temporarily_buffered_connections_.find(connection_id) != + if (temporarily_buffered_connections_.find(server_connection_id) != temporarily_buffered_connections_.end()) { // This packet was received while the a CHLO for the same connection ID was // being processed. Buffer it. - BufferEarlyPacket(connection_id, header.form != GOOGLE_QUIC_PACKET, + BufferEarlyPacket(server_connection_id, header.form != GOOGLE_QUIC_PACKET, header.version); return false; } @@ -483,7 +491,7 @@ return false; } - if (time_wait_list_manager_->IsConnectionIdInTimeWait(connection_id)) { + if (time_wait_list_manager_->IsConnectionIdInTimeWait(server_connection_id)) { // This connection ID is already in time-wait state. time_wait_list_manager_->ProcessPacket( current_self_address_, current_peer_address_, @@ -513,9 +521,10 @@ // Since the version is not supported, send a version negotiation // packet and stop processing the current packet. time_wait_list_manager()->SendVersionNegotiationPacket( - connection_id, header.form != GOOGLE_QUIC_PACKET, - GetSupportedVersions(), current_self_address_, - current_peer_address_, GetPerPacketContext()); + server_connection_id, client_connection_id, + header.form != GOOGLE_QUIC_PACKET, GetSupportedVersions(), + current_self_address_, current_peer_address_, + GetPerPacketContext()); } return false; } @@ -539,25 +548,25 @@ } void QuicDispatcher::ProcessHeader(const QuicPacketHeader& header) { - QuicConnectionId connection_id = header.destination_connection_id; + QuicConnectionId server_connection_id = header.destination_connection_id; // Packet's connection ID is unknown. Apply the validity checks. QuicPacketFate fate = ValidityChecks(header); if (fate == kFateProcess) { // Execute stateless rejection logic to determine the packet fate, then // invoke ProcessUnauthenticatedHeaderFate. - MaybeRejectStatelessly(connection_id, header.form, header.version_flag, - header.version); + MaybeRejectStatelessly(server_connection_id, header.form, + header.version_flag, header.version); } else { // If the fate is already known, process it without executing stateless // rejection logic. - ProcessUnauthenticatedHeaderFate(fate, connection_id, header.form, + ProcessUnauthenticatedHeaderFate(fate, server_connection_id, header.form, header.version_flag, header.version); } } void QuicDispatcher::ProcessUnauthenticatedHeaderFate( QuicPacketFate fate, - QuicConnectionId connection_id, + QuicConnectionId server_connection_id, PacketHeaderFormat form, bool version_flag, ParsedQuicVersion version) { @@ -570,33 +579,36 @@ // MaybeRejectStatelessly or OnExpiredPackets might have already added the // connection to time wait, in which case it should not be added again. if (!GetQuicReloadableFlag(quic_use_cheap_stateless_rejects) || - !time_wait_list_manager_->IsConnectionIdInTimeWait(connection_id)) { + !time_wait_list_manager_->IsConnectionIdInTimeWait( + server_connection_id)) { // Add this connection_id to the time-wait state, to safely reject // future packets. - QUIC_DLOG(INFO) << "Adding connection ID " << connection_id + QUIC_DLOG(INFO) << "Adding connection ID " << server_connection_id << " to time-wait list."; QUIC_CODE_COUNT(quic_reject_fate_time_wait); StatelesslyTerminateConnection( - connection_id, form, version_flag, version, QUIC_HANDSHAKE_FAILED, - "Reject connection", + server_connection_id, form, version_flag, version, + QUIC_HANDSHAKE_FAILED, "Reject connection", quic::QuicTimeWaitListManager::SEND_STATELESS_RESET); } - DCHECK(time_wait_list_manager_->IsConnectionIdInTimeWait(connection_id)); + DCHECK(time_wait_list_manager_->IsConnectionIdInTimeWait( + server_connection_id)); time_wait_list_manager_->ProcessPacket( - current_self_address_, current_peer_address_, connection_id, form, - GetPerPacketContext()); + current_self_address_, current_peer_address_, server_connection_id, + form, GetPerPacketContext()); // Any packets which were buffered while the stateless rejector logic was // running should be discarded. Do not inform the time wait list manager, // which should already have a made a decision about sending a reject // based on the CHLO alone. - buffered_packets_.DiscardPackets(connection_id); + buffered_packets_.DiscardPackets(server_connection_id); break; case kFateBuffer: // This packet is a non-CHLO packet which has arrived before the // corresponding CHLO, *or* this packet was received while the // corresponding CHLO was being processed. Buffer it. - BufferEarlyPacket(connection_id, form != GOOGLE_QUIC_PACKET, version); + BufferEarlyPacket(server_connection_id, form != GOOGLE_QUIC_PACKET, + version); break; case kFateDrop: // Do nothing with the packet. @@ -775,13 +787,13 @@ DeleteSessions(); } -void QuicDispatcher::OnConnectionClosed(QuicConnectionId connection_id, +void QuicDispatcher::OnConnectionClosed(QuicConnectionId server_connection_id, QuicErrorCode error, const std::string& error_details, ConnectionCloseSource source) { - auto it = session_map_.find(connection_id); + auto it = session_map_.find(server_connection_id); if (it == session_map_.end()) { - QUIC_BUG << "ConnectionId " << connection_id + QUIC_BUG << "ConnectionId " << server_connection_id << " does not exist in the session map. Error: " << QuicErrorCodeToString(error); QUIC_BUG << QuicStackTrace(); @@ -789,7 +801,7 @@ } QUIC_DLOG_IF(INFO, error != QUIC_NO_ERROR) - << "Closing connection (" << connection_id + << "Closing connection (" << server_connection_id << ") due to error: " << QuicErrorCodeToString(error) << ", with details: " << error_details; @@ -828,13 +840,13 @@ void QuicDispatcher::OnStopSendingReceived(const QuicStopSendingFrame& frame) {} void QuicDispatcher::OnConnectionAddedToTimeWaitList( - QuicConnectionId connection_id) { - QUIC_DLOG(INFO) << "Connection " << connection_id + QuicConnectionId server_connection_id) { + QUIC_DLOG(INFO) << "Connection " << server_connection_id << " added to time wait list."; } void QuicDispatcher::StatelesslyTerminateConnection( - QuicConnectionId connection_id, + QuicConnectionId server_connection_id, PacketHeaderFormat format, bool version_flag, ParsedQuicVersion version, @@ -844,26 +856,27 @@ if (format != IETF_QUIC_LONG_HEADER_PACKET && (!GetQuicReloadableFlag(quic_terminate_gquic_connection_as_ietf) || !version_flag)) { - QUIC_DVLOG(1) << "Statelessly terminating " << connection_id + QUIC_DVLOG(1) << "Statelessly terminating " << server_connection_id << " based on a non-ietf-long packet, action:" << action << ", error_code:" << error_code << ", error_details:" << error_details; time_wait_list_manager_->AddConnectionIdToTimeWait( - connection_id, format != GOOGLE_QUIC_PACKET, action, ENCRYPTION_INITIAL, - nullptr); + server_connection_id, format != GOOGLE_QUIC_PACKET, action, + ENCRYPTION_INITIAL, nullptr); return; } // If the version is known and supported by framer, send a connection close. if (IsSupportedVersion(version)) { QUIC_DVLOG(1) - << "Statelessly terminating " << connection_id + << "Statelessly terminating " << server_connection_id << " based on an ietf-long packet, which has a supported version:" << version << ", error_code:" << error_code << ", error_details:" << error_details; - StatelessConnectionTerminator terminator( - connection_id, version, helper_.get(), time_wait_list_manager_.get()); + StatelessConnectionTerminator terminator(server_connection_id, version, + helper_.get(), + time_wait_list_manager_.get()); // This also adds the connection to time wait list. if (format == GOOGLE_QUIC_PACKET) { QUIC_RELOADABLE_FLAG_COUNT_N(quic_terminate_gquic_connection_as_ietf, 1, @@ -875,7 +888,7 @@ } QUIC_DVLOG(1) - << "Statelessly terminating " << connection_id + << "Statelessly terminating " << server_connection_id << " based on an ietf-long packet, which has an unsupported version:" << version << ", error_code:" << error_code << ", error_details:" << error_details; @@ -883,13 +896,14 @@ // with an empty version list, which can be understood by the client. std::vector<std::unique_ptr<QuicEncryptedPacket>> termination_packets; termination_packets.push_back(QuicFramer::BuildVersionNegotiationPacket( - connection_id, /*ietf_quic=*/format != GOOGLE_QUIC_PACKET, + server_connection_id, EmptyQuicConnectionId(), + /*ietf_quic=*/format != GOOGLE_QUIC_PACKET, ParsedQuicVersionVector{UnsupportedQuicVersion()})); if (format == GOOGLE_QUIC_PACKET) { QUIC_RELOADABLE_FLAG_COUNT_N(quic_terminate_gquic_connection_as_ietf, 2, 2); } time_wait_list_manager()->AddConnectionIdToTimeWait( - connection_id, /*ietf_quic=*/format != GOOGLE_QUIC_PACKET, + server_connection_id, /*ietf_quic=*/format != GOOGLE_QUIC_PACKET, QuicTimeWaitListManager::SEND_TERMINATION_PACKETS, ENCRYPTION_INITIAL, &termination_packets); } @@ -913,7 +927,7 @@ DCHECK(!no_framer_); QUIC_BUG_IF( !time_wait_list_manager_->IsConnectionIdInTimeWait( - current_connection_id_) && + current_server_connection_id_) && !ShouldCreateSessionForUnknownVersion(framer_.last_version_label())) << "Unexpected version mismatch: " << QuicVersionLabelToString(framer_.last_version_label()); @@ -1090,11 +1104,11 @@ } void QuicDispatcher::OnExpiredPackets( - QuicConnectionId connection_id, + QuicConnectionId server_connection_id, BufferedPacketList early_arrived_packets) { QUIC_CODE_COUNT(quic_reject_buffered_packets_expired); StatelesslyTerminateConnection( - connection_id, + server_connection_id, early_arrived_packets.ietf_quic ? IETF_QUIC_LONG_HEADER_PACKET : GOOGLE_QUIC_PACKET, /*version_flag=*/true, early_arrived_packets.version, @@ -1107,24 +1121,44 @@ new_sessions_allowed_per_event_loop_ = max_connections_to_create; for (; new_sessions_allowed_per_event_loop_ > 0; --new_sessions_allowed_per_event_loop_) { - QuicConnectionId connection_id; + QuicConnectionId server_connection_id; BufferedPacketList packet_list = - buffered_packets_.DeliverPacketsForNextConnection(&connection_id); + buffered_packets_.DeliverPacketsForNextConnection( + &server_connection_id); const std::list<BufferedPacket>& packets = packet_list.buffered_packets; if (packets.empty()) { return; } - QuicConnectionId original_connection_id = connection_id; - connection_id = - MaybeReplaceConnectionId(connection_id, packet_list.version); + QuicConnectionId original_connection_id = server_connection_id; + server_connection_id = MaybeReplaceServerConnectionId(server_connection_id, + packet_list.version); QuicSession* session = - CreateQuicSession(connection_id, packets.front().peer_address, + CreateQuicSession(server_connection_id, packets.front().peer_address, packet_list.alpn, packet_list.version); - if (original_connection_id != connection_id) { + if (original_connection_id != server_connection_id) { session->connection()->AddIncomingConnectionId(original_connection_id); } - QUIC_DLOG(INFO) << "Created new session for " << connection_id; - session_map_.insert(std::make_pair(connection_id, QuicWrapUnique(session))); + if (packet_list.version.SupportsClientConnectionIds()) { + // Parse out the first packet's source connection ID and set it as the + // connection's client connection ID. + QuicPacketHeader header; + QuicVersionLabel version_label; + std::string detailed_error; + const QuicErrorCode error = QuicFramer::ProcessPacketDispatcher( + *packets.front().packet, expected_server_connection_id_length_, + &header.form, &header.version_flag, &version_label, + &header.destination_connection_id, &header.source_connection_id, + &detailed_error); + if (error == QUIC_NO_ERROR) { + session->connection()->set_client_connection_id( + header.source_connection_id); + } else { + QUIC_DLOG(ERROR) << detailed_error; + } + } + QUIC_DLOG(INFO) << "Created new session for " << server_connection_id; + session_map_.insert( + std::make_pair(server_connection_id, QuicWrapUnique(session))); DeliverPacketsToSession(packets, session); } } @@ -1134,21 +1168,23 @@ } bool QuicDispatcher::ShouldCreateOrBufferPacketForConnection( - QuicConnectionId connection_id, + QuicConnectionId server_connection_id, bool ietf_quic) { - QUIC_VLOG(1) << "Received packet from new connection " << connection_id; + QUIC_VLOG(1) << "Received packet from new connection " + << server_connection_id; return true; } // Return true if there is any packet buffered in the store. -bool QuicDispatcher::HasBufferedPackets(QuicConnectionId connection_id) { - return buffered_packets_.HasBufferedPackets(connection_id); +bool QuicDispatcher::HasBufferedPackets(QuicConnectionId server_connection_id) { + return buffered_packets_.HasBufferedPackets(server_connection_id); } -void QuicDispatcher::OnBufferPacketFailure(EnqueuePacketResult result, - QuicConnectionId connection_id) { - QUIC_DLOG(INFO) << "Fail to buffer packet on connection " << connection_id - << " because of " << result; +void QuicDispatcher::OnBufferPacketFailure( + EnqueuePacketResult result, + QuicConnectionId server_connection_id) { + QUIC_DLOG(INFO) << "Fail to buffer packet on connection " + << server_connection_id << " because of " << result; } bool QuicDispatcher::ShouldAttemptCheapStatelessRejection() { @@ -1160,21 +1196,22 @@ alarm_factory_.get()); } -void QuicDispatcher::BufferEarlyPacket(QuicConnectionId connection_id, +void QuicDispatcher::BufferEarlyPacket(QuicConnectionId server_connection_id, bool ietf_quic, ParsedQuicVersion version) { - bool is_new_connection = !buffered_packets_.HasBufferedPackets(connection_id); - if (is_new_connection && - !ShouldCreateOrBufferPacketForConnection(connection_id, ietf_quic)) { + bool is_new_connection = + !buffered_packets_.HasBufferedPackets(server_connection_id); + if (is_new_connection && !ShouldCreateOrBufferPacketForConnection( + server_connection_id, ietf_quic)) { return; } EnqueuePacketResult rs = buffered_packets_.EnqueuePacket( - connection_id, ietf_quic, *current_packet_, current_self_address_, + server_connection_id, ietf_quic, *current_packet_, current_self_address_, current_peer_address_, /*is_chlo=*/false, /*alpn=*/"", version); if (rs != EnqueuePacketResult::SUCCESS) { - OnBufferPacketFailure(rs, connection_id); + OnBufferPacketFailure(rs, server_connection_id); } } @@ -1184,48 +1221,56 @@ // Don't any create new connection. QUIC_CODE_COUNT(quic_reject_stop_accepting_new_connections); StatelesslyTerminateConnection( - current_connection_id(), form, /*version_flag=*/true, version, + current_server_connection_id(), form, /*version_flag=*/true, version, QUIC_HANDSHAKE_FAILED, "Stop accepting new connections", quic::QuicTimeWaitListManager::SEND_STATELESS_RESET); // Time wait list will reject the packet correspondingly. time_wait_list_manager()->ProcessPacket( - current_self_address(), current_peer_address(), current_connection_id(), - form, GetPerPacketContext()); + current_self_address(), current_peer_address(), + current_server_connection_id(), form, GetPerPacketContext()); return; } - if (!buffered_packets_.HasBufferedPackets(current_connection_id_) && - !ShouldCreateOrBufferPacketForConnection(current_connection_id_, + if (!buffered_packets_.HasBufferedPackets(current_server_connection_id_) && + !ShouldCreateOrBufferPacketForConnection(current_server_connection_id_, form != GOOGLE_QUIC_PACKET)) { return; } if (FLAGS_quic_allow_chlo_buffering && new_sessions_allowed_per_event_loop_ <= 0) { // Can't create new session any more. Wait till next event loop. - QUIC_BUG_IF(buffered_packets_.HasChloForConnection(current_connection_id_)); + QUIC_BUG_IF( + buffered_packets_.HasChloForConnection(current_server_connection_id_)); EnqueuePacketResult rs = buffered_packets_.EnqueuePacket( - current_connection_id_, form != GOOGLE_QUIC_PACKET, *current_packet_, - current_self_address_, current_peer_address_, + current_server_connection_id_, form != GOOGLE_QUIC_PACKET, + *current_packet_, current_self_address_, current_peer_address_, /*is_chlo=*/true, current_alpn_, version); if (rs != EnqueuePacketResult::SUCCESS) { - OnBufferPacketFailure(rs, current_connection_id_); + OnBufferPacketFailure(rs, current_server_connection_id_); } return; } - QuicConnectionId original_connection_id = current_connection_id_; - current_connection_id_ = - MaybeReplaceConnectionId(current_connection_id_, version); + QuicConnectionId original_connection_id = current_server_connection_id_; + current_server_connection_id_ = + MaybeReplaceServerConnectionId(current_server_connection_id_, version); // Creates a new session and process all buffered packets for this connection. - QuicSession* session = CreateQuicSession( - current_connection_id_, current_peer_address_, current_alpn_, version); - if (original_connection_id != current_connection_id_) { + QuicSession* session = + CreateQuicSession(current_server_connection_id_, current_peer_address_, + current_alpn_, version); + if (original_connection_id != current_server_connection_id_) { session->connection()->AddIncomingConnectionId(original_connection_id); } - QUIC_DLOG(INFO) << "Created new session for " << current_connection_id_; + if (version.SupportsClientConnectionIds()) { + session->connection()->set_client_connection_id( + current_client_connection_id_); + } + QUIC_DLOG(INFO) << "Created new session for " + << current_server_connection_id_; session_map_.insert( - std::make_pair(current_connection_id_, QuicWrapUnique(session))); + std::make_pair(current_server_connection_id_, QuicWrapUnique(session))); std::list<BufferedPacket> packets = - buffered_packets_.DeliverPackets(current_connection_id_).buffered_packets; + buffered_packets_.DeliverPackets(current_server_connection_id_) + .buffered_packets; // Process CHLO at first. session->ProcessUdpPacket(current_self_address_, current_peer_address_, *current_packet_); @@ -1295,12 +1340,13 @@ bool current_version_flag_; }; -void QuicDispatcher::MaybeRejectStatelessly(QuicConnectionId connection_id, - PacketHeaderFormat form, - bool version_flag, - ParsedQuicVersion version) { +void QuicDispatcher::MaybeRejectStatelessly( + QuicConnectionId server_connection_id, + PacketHeaderFormat form, + bool version_flag, + ParsedQuicVersion version) { if (version.handshake_protocol == PROTOCOL_TLS1_3) { - ProcessUnauthenticatedHeaderFate(kFateProcess, connection_id, form, + ProcessUnauthenticatedHeaderFate(kFateProcess, server_connection_id, form, version_flag, version); return; // TODO(nharper): Support buffering non-ClientHello packets when using TLS. @@ -1315,14 +1361,15 @@ if (FLAGS_quic_allow_chlo_buffering && !ChloExtractor::Extract(*current_packet_, GetSupportedVersions(), config_->create_session_tag_indicators(), - &alpn_extractor, connection_id.length())) { + &alpn_extractor, + server_connection_id.length())) { // Buffer non-CHLO packets. - ProcessUnauthenticatedHeaderFate(kFateBuffer, connection_id, form, + ProcessUnauthenticatedHeaderFate(kFateBuffer, server_connection_id, form, version_flag, version); return; } current_alpn_ = alpn_extractor.ConsumeAlpn(); - ProcessUnauthenticatedHeaderFate(kFateProcess, connection_id, form, + ProcessUnauthenticatedHeaderFate(kFateProcess, server_connection_id, form, version_flag, version); return; } @@ -1337,8 +1384,8 @@ rejector.get()); if (!ChloExtractor::Extract(*current_packet_, GetSupportedVersions(), config_->create_session_tag_indicators(), - &validator, connection_id.length())) { - ProcessUnauthenticatedHeaderFate(kFateBuffer, connection_id, form, + &validator, server_connection_id.length())) { + ProcessUnauthenticatedHeaderFate(kFateBuffer, server_connection_id, form, version_flag, version); return; } @@ -1347,13 +1394,13 @@ if (!validator.can_accept()) { // This CHLO is prohibited by policy. QUIC_CODE_COUNT(quic_reject_cant_accept_chlo); - StatelessConnectionTerminator terminator(connection_id, version, helper(), - time_wait_list_manager_.get()); + StatelessConnectionTerminator terminator( + server_connection_id, version, helper(), time_wait_list_manager_.get()); terminator.CloseConnection(QUIC_HANDSHAKE_FAILED, validator.error_details(), form != GOOGLE_QUIC_PACKET); QuicSession::RecordConnectionCloseAtServer( QUIC_HANDSHAKE_FAILED, ConnectionCloseSource::FROM_SELF); - ProcessUnauthenticatedHeaderFate(kFateTimeWait, connection_id, form, + ProcessUnauthenticatedHeaderFate(kFateTimeWait, server_connection_id, form, version_flag, version); return; } @@ -1368,10 +1415,10 @@ // Insert into set of connection IDs to buffer const bool ok = - temporarily_buffered_connections_.insert(connection_id).second; + temporarily_buffered_connections_.insert(server_connection_id).second; QUIC_BUG_IF(!ok) << "Processing multiple stateless rejections for connection ID " - << connection_id; + << server_connection_id; // Continue stateless rejector processing std::unique_ptr<StatelessRejectorProcessDoneCallback> cb( @@ -1395,7 +1442,7 @@ current_peer_address_ = current_peer_address; current_self_address_ = current_self_address; current_packet_ = current_packet.get(); - current_connection_id_ = rejector->connection_id(); + current_server_connection_id_ = rejector->connection_id(); if (!no_framer_) { framer_.set_version(first_version); }
diff --git a/quic/core/quic_dispatcher.h b/quic/core/quic_dispatcher.h index 6adf5f3..5e755dd 100644 --- a/quic/core/quic_dispatcher.h +++ b/quic/core/quic_dispatcher.h
@@ -49,7 +49,7 @@ std::unique_ptr<QuicConnectionHelperInterface> helper, std::unique_ptr<QuicCryptoServerStream::Helper> session_helper, std::unique_ptr<QuicAlarmFactory> alarm_factory, - uint8_t expected_connection_id_length); + uint8_t expected_server_connection_id_length); QuicDispatcher(const QuicDispatcher&) = delete; QuicDispatcher& operator=(const QuicDispatcher&) = delete; @@ -76,7 +76,7 @@ // QuicSession::Visitor interface implementation (via inheritance of // QuicTimeWaitListManager::Visitor): // Ensure that the closed connection is cleaned up asynchronously. - void OnConnectionClosed(QuicConnectionId connection_id, + void OnConnectionClosed(QuicConnectionId server_connection_id, QuicErrorCode error, const std::string& error_details, ConnectionCloseSource source) override; @@ -99,7 +99,8 @@ // QuicTimeWaitListManager::Visitor interface implementation // Called whenever the time wait list manager adds a new connection to the // time-wait list. - void OnConnectionAddedToTimeWaitList(QuicConnectionId connection_id) override; + void OnConnectionAddedToTimeWaitList( + QuicConnectionId server_connection_id) override; using SessionMap = QuicUnorderedMap<QuicConnectionId, std::unique_ptr<QuicSession>, @@ -191,7 +192,7 @@ const QuicIetfStatelessResetPacket& packet) override; // QuicBufferedPacketStore::VisitorInterface implementation. - void OnExpiredPackets(QuicConnectionId connection_id, + void OnExpiredPackets(QuicConnectionId server_connection_id, QuicBufferedPacketStore::BufferedPacketList early_arrived_packets) override; @@ -202,7 +203,7 @@ virtual bool HasChlosBuffered() const; protected: - virtual QuicSession* CreateQuicSession(QuicConnectionId connection_id, + virtual QuicSession* CreateQuicSession(QuicConnectionId server_connection_id, const QuicSocketAddress& peer_address, QuicStringPiece alpn, const ParsedQuicVersion& version) = 0; @@ -238,9 +239,9 @@ // will be owned by the dispatcher as time_wait_list_manager_ virtual QuicTimeWaitListManager* CreateQuicTimeWaitListManager(); - // Called when |connection_id| doesn't have an open connection yet, to buffer - // |current_packet_| until it can be delivered to the connection. - void BufferEarlyPacket(QuicConnectionId connection_id, + // Called when |server_connection_id| doesn't have an open connection yet, + // to buffer |current_packet_| until it can be delivered to the connection. + void BufferEarlyPacket(QuicConnectionId server_connection_id, bool ietf_quic, ParsedQuicVersion version); @@ -269,8 +270,8 @@ const ParsedQuicVersionVector& GetSupportedVersions(); - QuicConnectionId current_connection_id() const { - return current_connection_id_; + QuicConnectionId current_server_connection_id() const { + return current_server_connection_id_; } const QuicSocketAddress& current_self_address() const { return current_self_address_; @@ -319,15 +320,15 @@ // for CHLO. Returns true if a new connection should be created or its packets // should be buffered, false otherwise. virtual bool ShouldCreateOrBufferPacketForConnection( - QuicConnectionId connection_id, + QuicConnectionId server_connection_id, bool ietf_quic); - bool HasBufferedPackets(QuicConnectionId connection_id); + bool HasBufferedPackets(QuicConnectionId server_connection_id); // Called when BufferEarlyPacket() fail to buffer the packet. virtual void OnBufferPacketFailure( QuicBufferedPacketStore::EnqueuePacketResult result, - QuicConnectionId connection_id); + QuicConnectionId server_connection_id); // Removes the session from the session map and write blocked list, and adds // the ConnectionId to the time-wait list. If |session_closed_statelessly| is @@ -344,7 +345,7 @@ // connection to time wait list or 2) directly add connection to time wait // list with |action|. void StatelesslyTerminateConnection( - QuicConnectionId connection_id, + QuicConnectionId server_connection_id, PacketHeaderFormat format, bool version_flag, ParsedQuicVersion version, @@ -363,21 +364,22 @@ // If true, our framer will change its expected connection ID length // to the received destination connection ID length of all IETF long headers. void SetShouldUpdateExpectedConnectionIdLength( - bool should_update_expected_connection_id_length) { + bool should_update_expected_server_connection_id_length) { if (!no_framer_) { framer_.SetShouldUpdateExpectedConnectionIdLength( - should_update_expected_connection_id_length); + should_update_expected_server_connection_id_length); return; } - should_update_expected_connection_id_length_ = - should_update_expected_connection_id_length; + should_update_expected_server_connection_id_length_ = + should_update_expected_server_connection_id_length; } // If true, the dispatcher will allow incoming initial packets that have - // connection IDs shorter than 64 bits. - void SetAllowShortInitialConnectionIds( - bool allow_short_initial_connection_ids) { - allow_short_initial_connection_ids_ = allow_short_initial_connection_ids; + // destination connection IDs shorter than 64 bits. + void SetAllowShortInitialServerConnectionIds( + bool allow_short_initial_server_connection_ids) { + allow_short_initial_server_connection_ids_ = + allow_short_initial_server_connection_ids; } private: @@ -395,7 +397,7 @@ // possible and if the current packet contains a CHLO message. Determines a // fate which describes what subsequent processing should be performed on the // packets, like ValidityChecks, and invokes ProcessUnauthenticatedHeaderFate. - void MaybeRejectStatelessly(QuicConnectionId connection_id, + void MaybeRejectStatelessly(QuicConnectionId server_connection_id, PacketHeaderFormat form, bool version_flag, ParsedQuicVersion version); @@ -408,7 +410,7 @@ // Perform the appropriate actions on the current packet based on |fate| - // either process, buffer, or drop it. void ProcessUnauthenticatedHeaderFate(QuicPacketFate fate, - QuicConnectionId connection_id, + QuicConnectionId server_connection_id, PacketHeaderFormat form, bool version_flag, ParsedQuicVersion version); @@ -440,8 +442,9 @@ // If the connection ID length is different from what the dispatcher expects, // replace the connection ID with a random one of the right length, // and save it to make sure the mapping is persistent. - QuicConnectionId MaybeReplaceConnectionId(QuicConnectionId connection_id, - ParsedQuicVersion version); + QuicConnectionId MaybeReplaceServerConnectionId( + QuicConnectionId server_connection_id, + ParsedQuicVersion version); // Returns true if |version| is a supported protocol version. bool IsSupportedVersion(const ParsedQuicVersion version); @@ -505,7 +508,8 @@ const QuicReceivedPacket* current_packet_; // If |current_packet_| is a CHLO packet, the extracted alpn. std::string current_alpn_; - QuicConnectionId current_connection_id_; + QuicConnectionId current_server_connection_id_; + QuicConnectionId current_client_connection_id_; // Used to get the supported versions based on flag. Does not own. QuicVersionManager* version_manager_; @@ -524,8 +528,9 @@ bool accept_new_connections_; // If false, the dispatcher follows the IETF spec and rejects packets with - // invalid connection IDs lengths below 64 bits. If true they are allowed. - bool allow_short_initial_connection_ids_; + // invalid destination connection IDs lengths below 64 bits. + // If true they are allowed. + bool allow_short_initial_server_connection_ids_; // The last QUIC version label received. Used when no_framer_ is true. // TODO(fayang): remove this member variable, instead, add an argument to @@ -537,14 +542,15 @@ // encode its length. This variable contains the length we expect to read. // This is also used to signal an error when a long header packet with // different destination connection ID length is received when - // should_update_expected_connection_id_length_ is false and packet's version - // does not allow variable length connection ID. Used when no_framer_ is true. - uint8_t expected_connection_id_length_; + // should_update_expected_server_connection_id_length_ is false and packet's + // version does not allow variable length connection ID. Used when no_framer_ + // is true. + uint8_t expected_server_connection_id_length_; - // If true, change expected_connection_id_length_ to be the received + // If true, change expected_server_connection_id_length_ to be the received // destination connection ID length of all IETF long headers. Used when // no_framer_ is true. - bool should_update_expected_connection_id_length_; + bool should_update_expected_server_connection_id_length_; // Latched value of quic_no_framer_object_in_dispatcher. const bool no_framer_;
diff --git a/quic/core/quic_dispatcher_test.cc b/quic/core/quic_dispatcher_test.cc index 46c5bc2..668541b 100644 --- a/quic/core/quic_dispatcher_test.cc +++ b/quic/core/quic_dispatcher_test.cc
@@ -156,7 +156,7 @@ using QuicDispatcher::current_client_address; using QuicDispatcher::current_peer_address; using QuicDispatcher::current_self_address; - using QuicDispatcher::SetAllowShortInitialConnectionIds; + using QuicDispatcher::SetAllowShortInitialServerConnectionIds; using QuicDispatcher::writer; }; @@ -493,7 +493,7 @@ EXPECT_CALL(*dispatcher_, CreateQuicSession(_, _, _, _)).Times(0); EXPECT_CALL(*time_wait_list_manager_, - SendVersionNegotiationPacket(_, _, _, _, _, _)) + SendVersionNegotiationPacket(_, _, _, _, _, _, _)) .Times(1); QuicTransportVersion version = static_cast<QuicTransportVersion>(QuicTransportVersionMin() - 1); @@ -512,7 +512,7 @@ EXPECT_CALL(*dispatcher_, CreateQuicSession(_, _, _, _)).Times(0); EXPECT_CALL(*time_wait_list_manager_, - SendVersionNegotiationPacket(_, _, _, _, _, _)) + SendVersionNegotiationPacket(_, _, _, _, _, _, _)) .Times(0); QuicTransportVersion version = static_cast<QuicTransportVersion>(QuicTransportVersionMin() - 1); @@ -538,7 +538,7 @@ EXPECT_CALL(*dispatcher_, CreateQuicSession(_, _, _, _)).Times(0); EXPECT_CALL(*time_wait_list_manager_, - SendVersionNegotiationPacket(_, _, _, _, _, _)) + SendVersionNegotiationPacket(_, _, _, _, _, _, _)) .Times(1); QuicTransportVersion version = static_cast<QuicTransportVersion>(QuicTransportVersionMin() - 1); @@ -686,7 +686,7 @@ QuicUtils::CreateRandomConnectionId(mock_helper_.GetRandomGenerator()); // Disable validation of invalid short connection IDs. - dispatcher_->SetAllowShortInitialConnectionIds(true); + dispatcher_->SetAllowShortInitialServerConnectionIds(true); // Note that StrayPacketTruncatedConnectionId covers the case where the // validation is still enabled.
diff --git a/quic/core/quic_framer.cc b/quic/core/quic_framer.cc index 27af834..c35c076 100644 --- a/quic/core/quic_framer.cc +++ b/quic/core/quic_framer.cc
@@ -459,10 +459,11 @@ QuicFramer::QuicFramer(const ParsedQuicVersionVector& supported_versions, QuicTime creation_time, Perspective perspective, - uint8_t expected_connection_id_length) + uint8_t expected_server_connection_id_length) : visitor_(nullptr), error_(QUIC_NO_ERROR), - last_serialized_connection_id_(EmptyQuicConnectionId()), + last_serialized_server_connection_id_(EmptyQuicConnectionId()), + last_serialized_client_connection_id_(EmptyQuicConnectionId()), last_version_label_(0), version_(PROTOCOL_UNSUPPORTED, QUIC_VERSION_UNSUPPORTED), supported_versions_(supported_versions), @@ -478,8 +479,10 @@ data_producer_(nullptr), infer_packet_header_type_from_version_(perspective == Perspective::IS_CLIENT), - expected_connection_id_length_(expected_connection_id_length), - should_update_expected_connection_id_length_(false), + expected_server_connection_id_length_( + expected_server_connection_id_length), + expected_client_connection_id_length_(0), + should_update_expected_server_connection_id_length_(false), supports_multiple_packet_number_spaces_(false), last_written_packet_number_length_(0) { DCHECK(!supported_versions.empty()); @@ -1387,14 +1390,17 @@ // static std::unique_ptr<QuicEncryptedPacket> QuicFramer::BuildVersionNegotiationPacket( - QuicConnectionId connection_id, + QuicConnectionId server_connection_id, + QuicConnectionId client_connection_id, bool ietf_quic, const ParsedQuicVersionVector& versions) { if (ietf_quic) { - return BuildIetfVersionNegotiationPacket(connection_id, versions); + return BuildIetfVersionNegotiationPacket(server_connection_id, + client_connection_id, versions); } + DCHECK(client_connection_id.IsEmpty()); DCHECK(!versions.empty()); - size_t len = kPublicFlagsSize + connection_id.length() + + size_t len = kPublicFlagsSize + server_connection_id.length() + versions.size() * kQuicVersionSize; std::unique_ptr<char[]> buffer(new char[len]); // Endianness is not a concern here, version negotiation packet does not have @@ -1409,7 +1415,7 @@ return nullptr; } - if (!writer.WriteConnectionId(connection_id)) { + if (!writer.WriteConnectionId(server_connection_id)) { return nullptr; } @@ -1427,13 +1433,14 @@ // static std::unique_ptr<QuicEncryptedPacket> QuicFramer::BuildIetfVersionNegotiationPacket( - QuicConnectionId connection_id, + QuicConnectionId server_connection_id, + QuicConnectionId client_connection_id, const ParsedQuicVersionVector& versions) { QUIC_DVLOG(1) << "Building IETF version negotiation packet: " << ParsedQuicVersionVectorToString(versions); DCHECK(!versions.empty()); size_t len = kPacketHeaderTypeSize + kConnectionIdLengthSize + - connection_id.length() + + client_connection_id.length() + server_connection_id.length() + (versions.size() + 1) * kQuicVersionSize; std::unique_ptr<char[]> buffer(new char[len]); QuicDataWriter writer(len, buffer.get()); @@ -1452,7 +1459,7 @@ return nullptr; } - if (!AppendIetfConnectionIds(true, EmptyQuicConnectionId(), connection_id, + if (!AppendIetfConnectionIds(true, client_connection_id, server_connection_id, &writer)) { return nullptr; } @@ -2064,14 +2071,14 @@ public_flags |= PACKET_PUBLIC_FLAGS_NONCE; } - QuicConnectionId connection_id = + QuicConnectionId server_connection_id = GetServerConnectionIdAsSender(header, perspective_); - QuicConnectionIdIncluded connection_id_included = + QuicConnectionIdIncluded server_connection_id_included = GetServerConnectionIdIncludedAsSender(header, perspective_); DCHECK_EQ(CONNECTION_ID_ABSENT, GetClientConnectionIdIncludedAsSender(header, perspective_)); - switch (connection_id_included) { + switch (server_connection_id_included) { case CONNECTION_ID_ABSENT: if (!writer->WriteUInt8(public_flags | PACKET_PUBLIC_FLAGS_0BYTE_CONNECTION_ID)) { @@ -2080,9 +2087,9 @@ break; case CONNECTION_ID_PRESENT: QUIC_BUG_IF(!QuicUtils::IsConnectionIdValidForVersion( - connection_id, transport_version())) + server_connection_id, transport_version())) << "AppendPacketHeader: attempted to use connection ID " - << connection_id << " which is invalid with version " + << server_connection_id << " which is invalid with version " << QuicVersionToString(transport_version()); public_flags |= PACKET_PUBLIC_FLAGS_8BYTE_CONNECTION_ID; @@ -2090,12 +2097,12 @@ public_flags |= PACKET_PUBLIC_FLAGS_8BYTE_CONNECTION_ID_OLD; } if (!writer->WriteUInt8(public_flags) || - !writer->WriteConnectionId(connection_id)) { + !writer->WriteConnectionId(server_connection_id)) { return false; } break; } - last_serialized_connection_id_ = connection_id; + last_serialized_server_connection_id_ = server_connection_id; if (header.version_flag) { DCHECK_EQ(Perspective::IS_CLIENT, perspective_); @@ -2194,7 +2201,10 @@ return false; } - last_serialized_connection_id_ = server_connection_id; + last_serialized_server_connection_id_ = server_connection_id; + if (version_.SupportsClientConnectionIds()) { + last_serialized_client_connection_id_ = GetClientConnectionIdAsSender(header, perspective_); + } if (QuicVersionHasLongHeaderLengths(transport_version()) && header.version_flag) { @@ -2336,7 +2346,7 @@ break; case PACKET_PUBLIC_FLAGS_0BYTE_CONNECTION_ID: *header_connection_id_included = CONNECTION_ID_ABSENT; - *header_connection_id = last_serialized_connection_id_; + *header_connection_id = last_serialized_server_connection_id_; break; } @@ -2521,11 +2531,15 @@ // connection ID, and those received by server must include 8-byte // destination connection ID. header->destination_connection_id_included = - perspective_ == Perspective::IS_CLIENT ? CONNECTION_ID_ABSENT - : CONNECTION_ID_PRESENT; + (perspective_ == Perspective::IS_SERVER || + version_.SupportsClientConnectionIds()) + ? CONNECTION_ID_PRESENT + : CONNECTION_ID_ABSENT; header->source_connection_id_included = - perspective_ == Perspective::IS_CLIENT ? CONNECTION_ID_PRESENT - : CONNECTION_ID_ABSENT; + (perspective_ == Perspective::IS_CLIENT || + version_.SupportsClientConnectionIds()) + ? CONNECTION_ID_PRESENT + : CONNECTION_ID_ABSENT; // Read version tag. QuicVersionLabel version_label; if (!ProcessVersionLabel(reader, &version_label)) { @@ -2580,8 +2594,10 @@ // Connection ID length depends on the perspective. Client does not expect // destination connection ID, and server expects destination connection ID. header->destination_connection_id_included = - perspective_ == Perspective::IS_CLIENT ? CONNECTION_ID_ABSENT - : CONNECTION_ID_PRESENT; + (perspective_ == Perspective::IS_SERVER || + version_.SupportsClientConnectionIds()) + ? CONNECTION_ID_PRESENT + : CONNECTION_ID_ABSENT; header->source_connection_id_included = CONNECTION_ID_ABSENT; if (infer_packet_header_type_from_version_ && transport_version() > QUIC_VERSION_44 && !(type & FLAGS_FIXED_BIT)) { @@ -2614,11 +2630,21 @@ bool QuicFramer::ProcessAndValidateIetfConnectionIdLength( QuicDataReader* reader, ParsedQuicVersion version, - bool should_update_expected_connection_id_length, - uint8_t* expected_connection_id_length, + Perspective perspective, + bool should_update_expected_server_connection_id_length, + uint8_t* expected_server_connection_id_length, uint8_t* destination_connection_id_length, uint8_t* source_connection_id_length, std::string* detailed_error) { + QUIC_LOG(ERROR) << "ds33 should_update_expected_server_connection_id_length " + << (should_update_expected_server_connection_id_length ? "Y" + : "N") + << " expected_server_connection_id_length " + << (int)*expected_server_connection_id_length + << " destination_connection_id_length " + << (int)*destination_connection_id_length + << " source_connection_id_length " + << (int)*source_connection_id_length; uint8_t connection_id_lengths_byte; if (!reader->ReadBytes(&connection_id_lengths_byte, 1)) { *detailed_error = "Unable to read ConnectionId length."; @@ -2629,18 +2655,21 @@ if (dcil != 0) { dcil += kConnectionIdLengthAdjustment; } - if (should_update_expected_connection_id_length && - *expected_connection_id_length != dcil) { - QUIC_DVLOG(1) << "Updating expected_connection_id_length: " - << static_cast<int>(*expected_connection_id_length) << " -> " - << static_cast<int>(dcil); - *expected_connection_id_length = dcil; - } uint8_t scil = connection_id_lengths_byte & kSourceConnectionIdLengthMask; if (scil != 0) { scil += kConnectionIdLengthAdjustment; } - if (!should_update_expected_connection_id_length && + if (should_update_expected_server_connection_id_length) { + uint8_t server_connection_id_length = + perspective == Perspective::IS_SERVER ? dcil : scil; + if (*expected_server_connection_id_length != server_connection_id_length) { + QUIC_DVLOG(1) << "Updating expected_server_connection_id_length: " + << static_cast<int>(*expected_server_connection_id_length) + << " -> " << static_cast<int>(server_connection_id_length); + *expected_server_connection_id_length = server_connection_id_length; + } + } + if (!should_update_expected_server_connection_id_length && (dcil != *destination_connection_id_length || scil != *source_connection_id_length) && !QuicUtils::VariableLengthConnectionIdAllowedForVersion( @@ -2665,18 +2694,35 @@ uint8_t destination_connection_id_length = header->destination_connection_id_included == CONNECTION_ID_PRESENT - ? expected_connection_id_length_ + ? (perspective_ == Perspective::IS_SERVER + ? expected_server_connection_id_length_ + : expected_client_connection_id_length_) : 0; uint8_t source_connection_id_length = header->source_connection_id_included == CONNECTION_ID_PRESENT - ? expected_connection_id_length_ + ? (perspective_ == Perspective::IS_CLIENT + ? expected_server_connection_id_length_ + : expected_client_connection_id_length_) : 0; + QUIC_LOG(ERROR) << ENDPOINT + << "ds33 should_update_expected_server_connection_id_length_ " + << (should_update_expected_server_connection_id_length_ ? "Y" + : "N") + << " expected_server_connection_id_length_ " + << (int)expected_server_connection_id_length_ + << " expected_client_connection_id_length_ " + << (int)expected_client_connection_id_length_ + << " destination_connection_id_length " + << (int)destination_connection_id_length + << " source_connection_id_length " + << (int)source_connection_id_length << " header " << *header; if (header->form == IETF_QUIC_LONG_HEADER_PACKET) { if (!ProcessAndValidateIetfConnectionIdLength( - reader, header->version, - should_update_expected_connection_id_length_, - &expected_connection_id_length_, &destination_connection_id_length, - &source_connection_id_length, &detailed_error_)) { + reader, header->version, perspective_, + should_update_expected_server_connection_id_length_, + &expected_server_connection_id_length_, + &destination_connection_id_length, &source_connection_id_length, + &detailed_error_)) { return false; } } @@ -2709,13 +2755,17 @@ header->destination_connection_id = header->source_connection_id; } else if (header->destination_connection_id_included == CONNECTION_ID_ABSENT) { - header->destination_connection_id = last_serialized_connection_id_; + header->destination_connection_id = last_serialized_server_connection_id_; } } else { QUIC_RESTART_FLAG_COUNT_N(quic_do_not_override_connection_id, 5, 5); if (header->source_connection_id_included == CONNECTION_ID_ABSENT) { DCHECK_EQ(EmptyQuicConnectionId(), header->source_connection_id); - header->source_connection_id = last_serialized_connection_id_; + if (perspective_ == Perspective::IS_CLIENT) { + header->source_connection_id = last_serialized_server_connection_id_; + } else { + header->source_connection_id = last_serialized_client_connection_id_; + } } } @@ -6042,28 +6092,30 @@ // static QuicErrorCode QuicFramer::ProcessPacketDispatcher( const QuicEncryptedPacket& packet, - uint8_t expected_connection_id_length, + uint8_t expected_server_connection_id_length, PacketHeaderFormat* format, bool* version_flag, QuicVersionLabel* version_label, - uint8_t* destination_connection_id_length, QuicConnectionId* destination_connection_id, + QuicConnectionId* source_connection_id, std::string* detailed_error) { QuicDataReader reader(packet.data(), packet.length()); + *source_connection_id = EmptyQuicConnectionId(); uint8_t first_byte; if (!reader.ReadBytes(&first_byte, 1)) { *detailed_error = "Unable to read first byte."; return QUIC_INVALID_PACKET_HEADER; } + uint8_t destination_connection_id_length = 0, source_connection_id_length = 0; if (!QuicUtils::IsIetfPacketHeader(first_byte)) { *format = GOOGLE_QUIC_PACKET; *version_flag = (first_byte & PACKET_PUBLIC_FLAGS_VERSION) != 0; - *destination_connection_id_length = + destination_connection_id_length = first_byte & PACKET_PUBLIC_FLAGS_8BYTE_CONNECTION_ID; - if (*destination_connection_id_length == 0 || + if (destination_connection_id_length == 0 || !reader.ReadConnectionId(destination_connection_id, - *destination_connection_id_length)) { + destination_connection_id_length)) { *detailed_error = "Unable to read ConnectionId."; return QUIC_INVALID_PACKET_HEADER; } @@ -6083,27 +6135,33 @@ *detailed_error = "Unable to read protocol version."; return QUIC_INVALID_PACKET_HEADER; } - // Set should_update_expected_connection_id_length to true to bypass + // Set should_update_expected_server_connection_id_length to true to bypass // connection ID lengths validation. - uint8_t unused_source_connection_id_length = 0; - uint8_t unused_expected_connection_id_length = 0; + uint8_t unused_expected_server_connection_id_length = 0; if (!ProcessAndValidateIetfConnectionIdLength( &reader, ParseQuicVersionLabel(*version_label), - /*should_update_expected_connection_id_length=*/true, - &unused_expected_connection_id_length, - destination_connection_id_length, - &unused_source_connection_id_length, detailed_error)) { + Perspective::IS_SERVER, + /*should_update_expected_server_connection_id_length=*/true, + &unused_expected_server_connection_id_length, + &destination_connection_id_length, &source_connection_id_length, + detailed_error)) { return QUIC_INVALID_PACKET_HEADER; } } else { - // For short header packets, expected_connection_id_length is used to + // For short header packets, expected_server_connection_id_length is used to // determine the destination_connection_id_length. - *destination_connection_id_length = expected_connection_id_length; + destination_connection_id_length = expected_server_connection_id_length; } // Read destination connection ID. if (!reader.ReadConnectionId(destination_connection_id, - *destination_connection_id_length)) { - *detailed_error = "Unable to read Destination ConnectionId."; + destination_connection_id_length)) { + *detailed_error = "Unable to read destination connection ID."; + return QUIC_INVALID_PACKET_HEADER; + } + // Read source connection ID. + if (!reader.ReadConnectionId(source_connection_id, + source_connection_id_length)) { + *detailed_error = "Unable to read source connection ID."; return QUIC_INVALID_PACKET_HEADER; } return QUIC_NO_ERROR; @@ -6234,13 +6292,14 @@ *detailed_error = "Packet is not a version negotiation packet"; return false; } - uint8_t expected_connection_id_length = 0, + uint8_t expected_server_connection_id_length = 0, destination_connection_id_length = 0, source_connection_id_length = 0; if (!ProcessAndValidateIetfConnectionIdLength( - &reader, UnsupportedQuicVersion(), - /*should_update_expected_connection_id_length=*/true, - &expected_connection_id_length, &destination_connection_id_length, - &source_connection_id_length, detailed_error)) { + &reader, UnsupportedQuicVersion(), Perspective::IS_CLIENT, + /*should_update_expected_server_connection_id_length=*/true, + &expected_server_connection_id_length, + &destination_connection_id_length, &source_connection_id_length, + detailed_error)) { return false; } if (destination_connection_id_length != 0) {
diff --git a/quic/core/quic_framer.h b/quic/core/quic_framer.h index ab0e8d7..931dd4a 100644 --- a/quic/core/quic_framer.h +++ b/quic/core/quic_framer.h
@@ -226,7 +226,7 @@ QuicFramer(const ParsedQuicVersionVector& supported_versions, QuicTime creation_time, Perspective perspective, - uint8_t expected_connection_id_length); + uint8_t expected_server_connection_id_length); QuicFramer(const QuicFramer&) = delete; QuicFramer& operator=(const QuicFramer&) = delete; @@ -374,18 +374,18 @@ QuicVariableLengthIntegerLength length_length); // Lightweight parsing of |packet| and populates |format|, |version_flag|, - // |version_label|, |destination_connection_id_length|, - // |destination_connection_id| and |detailed_error|. Please note, - // |expected_connection_id_length| is only used to determine IETF short header - // packet's destination connection ID length. + // |version_label|, |destination_connection_id|, |source_connection_id| and + // |detailed_error|. Please note, |expected_server_connection_id_length| is + // only used to determine IETF short header packet's destination connection ID + // length. static QuicErrorCode ProcessPacketDispatcher( const QuicEncryptedPacket& packet, - uint8_t expected_connection_id_length, + uint8_t expected_server_connection_id_length, PacketHeaderFormat* format, bool* version_flag, QuicVersionLabel* version_label, - uint8_t* destination_connection_id_length, QuicConnectionId* destination_connection_id, + QuicConnectionId* source_connection_id, std::string* detailed_error); // Serializes a packet containing |frames| into |buffer|. @@ -435,13 +435,15 @@ // Returns a new version negotiation packet. static std::unique_ptr<QuicEncryptedPacket> BuildVersionNegotiationPacket( - QuicConnectionId connection_id, + QuicConnectionId server_connection_id, + QuicConnectionId client_connection_id, bool ietf_quic, const ParsedQuicVersionVector& versions); // Returns a new IETF version negotiation packet. static std::unique_ptr<QuicEncryptedPacket> BuildIetfVersionNegotiationPacket( - QuicConnectionId connection_id, + QuicConnectionId server_connection_id, + QuicConnectionId client_connection_id, const ParsedQuicVersionVector& versions); // If header.version_flag is set, the version in the @@ -576,14 +578,23 @@ // If true, QuicFramer will change its expected connection ID length // to the received destination connection ID length of all IETF long headers. void SetShouldUpdateExpectedConnectionIdLength( - bool should_update_expected_connection_id_length) { - should_update_expected_connection_id_length_ = - should_update_expected_connection_id_length; + bool should_update_expected_server_connection_id_length) { + should_update_expected_server_connection_id_length_ = + should_update_expected_server_connection_id_length; } - // The connection ID length the framer expects on incoming IETF short headers. - uint8_t GetExpectedConnectionIdLength() { - return expected_connection_id_length_; + // The connection ID length the framer expects on incoming IETF short headers + // on the server. + uint8_t GetExpectedServerConnectionIdLength() { + return expected_server_connection_id_length_; + } + + // Change the expected destination connection ID length for short headers on + // the client. + void SetExpectedClientConnectionIdLength( + uint8_t expected_client_connection_id_length) { + expected_client_connection_id_length_ = + expected_client_connection_id_length; } void EnableMultiplePacketNumberSpacesSupport(); @@ -715,8 +726,9 @@ static bool ProcessAndValidateIetfConnectionIdLength( QuicDataReader* reader, ParsedQuicVersion version, - bool should_update_expected_connection_id_length, - uint8_t* expected_connection_id_length, + Perspective perspective, + bool should_update_expected_server_connection_id_length, + uint8_t* expected_server_connection_id_length, uint8_t* destination_connection_id_length, uint8_t* source_connection_id_length, std::string* detailed_error); @@ -958,8 +970,10 @@ // Largest successfully decrypted packet number per packet number space. Only // used when supports_multiple_packet_number_spaces_ is true. QuicPacketNumber largest_decrypted_packet_numbers_[NUM_PACKET_NUMBER_SPACES]; - // Updated by WritePacketHeader. - QuicConnectionId last_serialized_connection_id_; + // Last server connection ID seen on the wire. + QuicConnectionId last_serialized_server_connection_id_; + // Last client connection ID seen on the wire. + QuicConnectionId last_serialized_client_connection_id_; // The last QUIC version label received. // TODO(fayang): Remove this when deprecating // quic_no_framer_object_in_dispatcher. @@ -1014,18 +1028,18 @@ bool infer_packet_header_type_from_version_; // IETF short headers contain a destination connection ID but do not - // encode its length. This variable contains the length we expect to read. - // This is also used to validate the long header connection ID lengths in - // older versions of QUIC. - // TODO(fayang): Remove this when deprecating - // quic_no_framer_object_in_dispatcher. - uint8_t expected_connection_id_length_; + // encode its length. These variables contains the length we expect to read. + // This is also used to validate the long header destination connection ID + // lengths in older versions of QUIC. + uint8_t expected_server_connection_id_length_; + uint8_t expected_client_connection_id_length_; - // When this is true, QuicFramer will change expected_connection_id_length_ - // to the received destination connection ID length of all IETF long headers. + // When this is true, QuicFramer will change + // expected_server_connection_id_length_ to the received destination + // connection ID length of all IETF long headers. // TODO(fayang): Remove this when deprecating // quic_no_framer_object_in_dispatcher. - bool should_update_expected_connection_id_length_; + bool should_update_expected_server_connection_id_length_; // Indicates whether this framer supports multiple packet number spaces. bool supports_multiple_packet_number_spaces_;
diff --git a/quic/core/quic_framer_test.cc b/quic/core/quic_framer_test.cc index bf87ace..5f46e8e 100644 --- a/quic/core/quic_framer_test.cc +++ b/quic/core/quic_framer_test.cc
@@ -916,18 +916,17 @@ PacketHeaderFormat format; bool version_flag; - uint8_t destination_connection_id_length; - QuicConnectionId destination_connection_id; + QuicConnectionId destination_connection_id, source_connection_id; QuicVersionLabel version_label; std::string detailed_error; - EXPECT_EQ(QUIC_NO_ERROR, QuicFramer::ProcessPacketDispatcher( - *encrypted, kQuicDefaultConnectionIdLength, - &format, &version_flag, &version_label, - &destination_connection_id_length, - &destination_connection_id, &detailed_error)); + EXPECT_EQ(QUIC_NO_ERROR, + QuicFramer::ProcessPacketDispatcher( + *encrypted, kQuicDefaultConnectionIdLength, &format, + &version_flag, &version_label, &destination_connection_id, + &source_connection_id, &detailed_error)); EXPECT_EQ(GOOGLE_QUIC_PACKET, format); EXPECT_FALSE(version_flag); - EXPECT_EQ(kQuicDefaultConnectionIdLength, destination_connection_id_length); + EXPECT_EQ(kQuicDefaultConnectionIdLength, destination_connection_id.length()); EXPECT_EQ(FramerTestConnectionId(), destination_connection_id); } @@ -994,25 +993,24 @@ PacketHeaderFormat format; bool version_flag; - uint8_t destination_connection_id_length; - QuicConnectionId destination_connection_id; + QuicConnectionId destination_connection_id, source_connection_id; QuicVersionLabel version_label; std::string detailed_error; - EXPECT_EQ(QUIC_NO_ERROR, QuicFramer::ProcessPacketDispatcher( - *encrypted, kQuicDefaultConnectionIdLength, - &format, &version_flag, &version_label, - &destination_connection_id_length, - &destination_connection_id, &detailed_error)); + EXPECT_EQ(QUIC_NO_ERROR, + QuicFramer::ProcessPacketDispatcher( + *encrypted, kQuicDefaultConnectionIdLength, &format, + &version_flag, &version_label, &destination_connection_id, + &source_connection_id, &detailed_error)); EXPECT_EQ(IETF_QUIC_LONG_HEADER_PACKET, format); EXPECT_TRUE(version_flag); - EXPECT_EQ(kQuicDefaultConnectionIdLength, destination_connection_id_length); + EXPECT_EQ(kQuicDefaultConnectionIdLength, destination_connection_id.length()); EXPECT_EQ(FramerTestConnectionId(), destination_connection_id); } TEST_P(QuicFramerTest, PacketHeaderWith0ByteConnectionId) { SetDecrypterLevel(ENCRYPTION_FORWARD_SECURE); - QuicFramerPeer::SetLastSerializedConnectionId(&framer_, - FramerTestConnectionId()); + QuicFramerPeer::SetLastSerializedServerConnectionId(&framer_, + FramerTestConnectionId()); QuicFramerPeer::SetPerspective(&framer_, Perspective::IS_CLIENT); // clang-format off @@ -6626,7 +6624,8 @@ QuicConnectionId connection_id = FramerTestConnectionId(); std::unique_ptr<QuicEncryptedPacket> data( framer_.BuildVersionNegotiationPacket( - connection_id, framer_.transport_version() > QUIC_VERSION_43, + connection_id, EmptyQuicConnectionId(), + framer_.transport_version() > QUIC_VERSION_43, SupportedVersions(GetParam()))); test::CompareCharArraysWithHexError("constructed packet", data->data(), data->length(), AsChars(p), p_size); @@ -13182,8 +13181,8 @@ QuicConnectionId connection_id(connection_id_bytes, sizeof(connection_id_bytes)); QuicFramerPeer::SetLargestPacketNumber(&framer_, kPacketNumber - 2); - QuicFramerPeer::SetExpectedConnectionIDLength(&framer_, - connection_id.length()); + QuicFramerPeer::SetExpectedServerConnectionIDLength(&framer_, + connection_id.length()); // clang-format off PacketFragments packet = { @@ -13320,8 +13319,7 @@ PacketHeaderFormat format; bool version_flag; - uint8_t destination_connection_id_length; - QuicConnectionId destination_connection_id; + QuicConnectionId destination_connection_id, source_connection_id; QuicVersionLabel version_label; std::string detailed_error; EXPECT_EQ(QUIC_NO_ERROR, @@ -13329,21 +13327,21 @@ QuicEncryptedPacket(AsChars(long_header_packet), QUIC_ARRAYSIZE(long_header_packet)), kQuicDefaultConnectionIdLength, &format, &version_flag, - &version_label, &destination_connection_id_length, - &destination_connection_id, &detailed_error)); + &version_label, &destination_connection_id, + &source_connection_id, &detailed_error)); EXPECT_EQ(IETF_QUIC_LONG_HEADER_PACKET, format); EXPECT_TRUE(version_flag); - EXPECT_EQ(9, destination_connection_id_length); + EXPECT_EQ(9, destination_connection_id.length()); EXPECT_EQ(FramerTestConnectionIdNineBytes(), destination_connection_id); - EXPECT_EQ(QUIC_NO_ERROR, - QuicFramer::ProcessPacketDispatcher( - short_header_encrypted, 9, &format, &version_flag, - &version_label, &destination_connection_id_length, - &destination_connection_id, &detailed_error)); + EXPECT_EQ( + QUIC_NO_ERROR, + QuicFramer::ProcessPacketDispatcher( + short_header_encrypted, 9, &format, &version_flag, &version_label, + &destination_connection_id, &source_connection_id, &detailed_error)); EXPECT_EQ(IETF_QUIC_SHORT_HEADER_PACKET, format); EXPECT_FALSE(version_flag); - EXPECT_EQ(9, destination_connection_id_length); + EXPECT_EQ(9, destination_connection_id.length()); EXPECT_EQ(FramerTestConnectionIdNineBytes(), destination_connection_id); }
diff --git a/quic/core/quic_packet_creator.cc b/quic/core/quic_packet_creator.cc index c013ce8..47af951 100644 --- a/quic/core/quic_packet_creator.cc +++ b/quic/core/quic_packet_creator.cc
@@ -53,15 +53,15 @@ #define ENDPOINT \ (framer_->perspective() == Perspective::IS_SERVER ? "Server: " : "Client: ") -QuicPacketCreator::QuicPacketCreator(QuicConnectionId connection_id, +QuicPacketCreator::QuicPacketCreator(QuicConnectionId server_connection_id, QuicFramer* framer, DelegateInterface* delegate) - : QuicPacketCreator(connection_id, + : QuicPacketCreator(server_connection_id, framer, QuicRandom::GetInstance(), delegate) {} -QuicPacketCreator::QuicPacketCreator(QuicConnectionId connection_id, +QuicPacketCreator::QuicPacketCreator(QuicConnectionId server_connection_id, QuicFramer* framer, QuicRandom* random, DelegateInterface* delegate) @@ -72,9 +72,10 @@ send_version_in_packet_(framer->perspective() == Perspective::IS_CLIENT), have_diversification_nonce_(false), max_packet_length_(0), - connection_id_included_(CONNECTION_ID_PRESENT), + server_connection_id_included_(CONNECTION_ID_PRESENT), packet_size_(0), - connection_id_(connection_id), + server_connection_id_(server_connection_id), + client_connection_id_(EmptyQuicConnectionId()), packet_(QuicPacketNumber(), PACKET_1BYTE_PACKET_NUMBER, nullptr, @@ -626,8 +627,9 @@ const ParsedQuicVersionVector& supported_versions) { DCHECK_EQ(Perspective::IS_SERVER, framer_->perspective()); std::unique_ptr<QuicEncryptedPacket> encrypted = - QuicFramer::BuildVersionNegotiationPacket(connection_id_, ietf_quic, - supported_versions); + QuicFramer::BuildVersionNegotiationPacket(server_connection_id_, + client_connection_id_, + ietf_quic, supported_versions); DCHECK(encrypted); DCHECK_GE(max_packet_length_, encrypted->length()); return encrypted; @@ -737,24 +739,24 @@ QuicConnectionId QuicPacketCreator::GetDestinationConnectionId() const { if (!GetQuicRestartFlag(quic_do_not_override_connection_id)) { - return connection_id_; + return server_connection_id_; } QUIC_RESTART_FLAG_COUNT_N(quic_do_not_override_connection_id, 1, 5); if (framer_->perspective() == Perspective::IS_SERVER) { - return EmptyQuicConnectionId(); + return client_connection_id_; } - return connection_id_; + return server_connection_id_; } QuicConnectionId QuicPacketCreator::GetSourceConnectionId() const { if (!GetQuicRestartFlag(quic_do_not_override_connection_id)) { - return connection_id_; + return server_connection_id_; } QUIC_RESTART_FLAG_COUNT_N(quic_do_not_override_connection_id, 6, 6); if (framer_->perspective() == Perspective::IS_CLIENT) { - return EmptyQuicConnectionId(); + return client_connection_id_; } - return connection_id_; + return server_connection_id_; } QuicConnectionIdIncluded QuicPacketCreator::GetDestinationConnectionIdIncluded() @@ -763,30 +765,33 @@ GetQuicRestartFlag(quic_do_not_override_connection_id)) { // Packets sent by client always include destination connection ID, and // those sent by the server do not include destination connection ID. - return framer_->perspective() == Perspective::IS_CLIENT + return (framer_->perspective() == Perspective::IS_CLIENT || + framer_->version().SupportsClientConnectionIds()) ? CONNECTION_ID_PRESENT : CONNECTION_ID_ABSENT; } - return connection_id_included_; + return server_connection_id_included_; } QuicConnectionIdIncluded QuicPacketCreator::GetSourceConnectionIdIncluded() const { // Long header packets sent by server include source connection ID. - if (HasIetfLongHeader() && framer_->perspective() == Perspective::IS_SERVER) { + if (HasIetfLongHeader() && + (framer_->perspective() == Perspective::IS_SERVER || + framer_->version().SupportsClientConnectionIds())) { return CONNECTION_ID_PRESENT; } if (GetQuicRestartFlag(quic_do_not_override_connection_id) && framer_->perspective() == Perspective::IS_SERVER) { QUIC_RESTART_FLAG_COUNT_N(quic_do_not_override_connection_id, 2, 5); - return connection_id_included_; + return server_connection_id_included_; } return CONNECTION_ID_ABSENT; } QuicConnectionIdLength QuicPacketCreator::GetDestinationConnectionIdLength() const { - DCHECK(QuicUtils::IsConnectionIdValidForVersion(connection_id_, + DCHECK(QuicUtils::IsConnectionIdValidForVersion(server_connection_id_, transport_version())); return GetDestinationConnectionIdIncluded() == CONNECTION_ID_PRESENT ? static_cast<QuicConnectionIdLength>( @@ -795,7 +800,7 @@ } QuicConnectionIdLength QuicPacketCreator::GetSourceConnectionIdLength() const { - DCHECK(QuicUtils::IsConnectionIdValidForVersion(connection_id_, + DCHECK(QuicUtils::IsConnectionIdValidForVersion(server_connection_id_, transport_version())); return GetSourceConnectionIdIncluded() == CONNECTION_ID_PRESENT ? static_cast<QuicConnectionIdLength>( @@ -1019,17 +1024,25 @@ return packet_.encryption_level == ENCRYPTION_INITIAL; } -void QuicPacketCreator::SetConnectionIdIncluded( - QuicConnectionIdIncluded connection_id_included) { - DCHECK(connection_id_included == CONNECTION_ID_PRESENT || - connection_id_included == CONNECTION_ID_ABSENT); +void QuicPacketCreator::SetServerConnectionIdIncluded( + QuicConnectionIdIncluded server_connection_id_included) { + DCHECK(server_connection_id_included == CONNECTION_ID_PRESENT || + server_connection_id_included == CONNECTION_ID_ABSENT); DCHECK(framer_->perspective() == Perspective::IS_SERVER || - connection_id_included != CONNECTION_ID_ABSENT); - connection_id_included_ = connection_id_included; + server_connection_id_included != CONNECTION_ID_ABSENT); + server_connection_id_included_ = server_connection_id_included; } -void QuicPacketCreator::SetConnectionId(QuicConnectionId connection_id) { - connection_id_ = connection_id; +void QuicPacketCreator::SetServerConnectionId( + QuicConnectionId server_connection_id) { + server_connection_id_ = server_connection_id; +} + +void QuicPacketCreator::SetClientConnectionId( + QuicConnectionId client_connection_id) { + DCHECK(client_connection_id.IsEmpty() || + framer_->version().SupportsClientConnectionIds()); + client_connection_id_ = client_connection_id; } void QuicPacketCreator::SetTransmissionType(TransmissionType type) {
diff --git a/quic/core/quic_packet_creator.h b/quic/core/quic_packet_creator.h index 6e1f7a2..b666f71 100644 --- a/quic/core/quic_packet_creator.h +++ b/quic/core/quic_packet_creator.h
@@ -56,10 +56,10 @@ virtual void OnFrameAddedToPacket(const QuicFrame& frame) {} }; - QuicPacketCreator(QuicConnectionId connection_id, + QuicPacketCreator(QuicConnectionId server_connection_id, QuicFramer* framer, DelegateInterface* delegate); - QuicPacketCreator(QuicConnectionId connection_id, + QuicPacketCreator(QuicConnectionId server_connection_id, QuicFramer* framer, QuicRandom* random, DelegateInterface* delegate); @@ -222,11 +222,15 @@ // Returns length of source connection ID to send over the wire. QuicConnectionIdLength GetSourceConnectionIdLength() const; - // Sets whether the connection ID should be sent over the wire. - void SetConnectionIdIncluded(QuicConnectionIdIncluded connection_id_included); + // Sets whether the server connection ID should be sent over the wire. + void SetServerConnectionIdIncluded( + QuicConnectionIdIncluded server_connection_id_included); - // Update the connection ID used in outgoing packets. - void SetConnectionId(QuicConnectionId connection_id); + // Update the server connection ID used in outgoing packets. + void SetServerConnectionId(QuicConnectionId server_connection_id); + + // Update the client connection ID used in outgoing packets. + void SetClientConnectionId(QuicConnectionId client_connection_id); // Sets the encryption level that will be applied to new packets. void set_encryption_level(EncryptionLevel level) { @@ -393,8 +397,8 @@ // Maximum length including headers and encryption (UDP payload length.) QuicByteCount max_packet_length_; size_t max_plaintext_size_; - // Whether the connection_id is sent over the wire. - QuicConnectionIdIncluded connection_id_included_; + // Whether the server_connection_id is sent over the wire. + QuicConnectionIdIncluded server_connection_id_included_; // Frames to be added to the next SerializedPacket QuicFrames queued_frames_; @@ -403,7 +407,8 @@ // TODO(ianswett): Move packet_size_ into SerializedPacket once // QuicEncryptedPacket has been flattened into SerializedPacket. size_t packet_size_; - QuicConnectionId connection_id_; + QuicConnectionId server_connection_id_; + QuicConnectionId client_connection_id_; // Packet used to invoke OnSerializedPacket. SerializedPacket packet_;
diff --git a/quic/core/quic_packet_generator.cc b/quic/core/quic_packet_generator.cc index d811f9d..766f8ca 100644 --- a/quic/core/quic_packet_generator.cc +++ b/quic/core/quic_packet_generator.cc
@@ -17,12 +17,12 @@ namespace quic { -QuicPacketGenerator::QuicPacketGenerator(QuicConnectionId connection_id, +QuicPacketGenerator::QuicPacketGenerator(QuicConnectionId server_connection_id, QuicFramer* framer, QuicRandom* random_generator, DelegateInterface* delegate) : delegate_(delegate), - packet_creator_(connection_id, framer, random_generator, delegate), + packet_creator_(server_connection_id, framer, random_generator, delegate), next_transmission_type_(NOT_RETRANSMISSION), flusher_attached_(false), should_send_ack_(false), @@ -438,11 +438,11 @@ max_packets_in_flight); } -void QuicPacketGenerator::SetConnectionIdLength(uint32_t length) { +void QuicPacketGenerator::SetServerConnectionIdLength(uint32_t length) { if (length == 0) { - packet_creator_.SetConnectionIdIncluded(CONNECTION_ID_ABSENT); + packet_creator_.SetServerConnectionIdIncluded(CONNECTION_ID_ABSENT); } else { - packet_creator_.SetConnectionIdIncluded(CONNECTION_ID_PRESENT); + packet_creator_.SetServerConnectionIdIncluded(CONNECTION_ID_PRESENT); } } @@ -569,8 +569,14 @@ return packet_creator_.GetGuaranteedLargestMessagePayload(); } -void QuicPacketGenerator::SetConnectionId(QuicConnectionId connection_id) { - packet_creator_.SetConnectionId(connection_id); +void QuicPacketGenerator::SetServerConnectionId( + QuicConnectionId server_connection_id) { + packet_creator_.SetServerConnectionId(server_connection_id); +} + +void QuicPacketGenerator::SetClientConnectionId( + QuicConnectionId client_connection_id) { + packet_creator_.SetClientConnectionId(client_connection_id); } } // namespace quic
diff --git a/quic/core/quic_packet_generator.h b/quic/core/quic_packet_generator.h index 417fda6..ec25550 100644 --- a/quic/core/quic_packet_generator.h +++ b/quic/core/quic_packet_generator.h
@@ -76,7 +76,7 @@ QuicStopWaitingFrame* stop_waiting) = 0; }; - QuicPacketGenerator(QuicConnectionId connection_id, + QuicPacketGenerator(QuicConnectionId server_connection_id, QuicFramer* framer, QuicRandom* random_generator, DelegateInterface* delegate); @@ -185,8 +185,8 @@ void UpdatePacketNumberLength(QuicPacketNumber least_packet_awaited_by_peer, QuicPacketCount max_packets_in_flight); - // Set the minimum number of bytes for the connection id length; - void SetConnectionIdLength(uint32_t length); + // Set the minimum number of bytes for the server connection id length; + void SetServerConnectionIdLength(uint32_t length); // Sets the encrypter to use for the encryption level. void SetEncrypter(EncryptionLevel level, @@ -235,8 +235,11 @@ QuicPacketLength GetCurrentLargestMessagePayload() const; QuicPacketLength GetGuaranteedLargestMessagePayload() const; - // Update the connection ID used in outgoing packets. - void SetConnectionId(QuicConnectionId connection_id); + // Update the server connection ID used in outgoing packets. + void SetServerConnectionId(QuicConnectionId server_connection_id); + + // Update the client connection ID used in outgoing packets. + void SetClientConnectionId(QuicConnectionId client_connection_id); void set_debug_delegate(QuicPacketCreator::DebugDelegate* debug_delegate) { packet_creator_.set_debug_delegate(debug_delegate);
diff --git a/quic/core/quic_packet_generator_test.cc b/quic/core/quic_packet_generator_test.cc index 676d6e7..e1899ff 100644 --- a/quic/core/quic_packet_generator_test.cc +++ b/quic/core/quic_packet_generator_test.cc
@@ -1139,12 +1139,12 @@ TEST_F(QuicPacketGeneratorTest, TestConnectionIdLength) { QuicFramerPeer::SetPerspective(&framer_, Perspective::IS_SERVER); - generator_.SetConnectionIdLength(0); + generator_.SetServerConnectionIdLength(0); EXPECT_EQ(PACKET_0BYTE_CONNECTION_ID, creator_->GetDestinationConnectionIdLength()); for (size_t i = 1; i < 10; i++) { - generator_.SetConnectionIdLength(i); + generator_.SetServerConnectionIdLength(i); if (framer_.transport_version() > QUIC_VERSION_43) { EXPECT_EQ(PACKET_0BYTE_CONNECTION_ID, creator_->GetDestinationConnectionIdLength());
diff --git a/quic/core/quic_packets.cc b/quic/core/quic_packets.cc index 57bda54..738cc6e 100644 --- a/quic/core/quic_packets.cc +++ b/quic/core/quic_packets.cc
@@ -27,6 +27,16 @@ return header.source_connection_id; } +QuicConnectionId GetClientConnectionIdAsRecipient( + const QuicPacketHeader& header, + Perspective perspective) { + DCHECK(GetQuicRestartFlag(quic_do_not_override_connection_id)); + if (perspective == Perspective::IS_CLIENT) { + return header.destination_connection_id; + } + return header.source_connection_id; +} + QuicConnectionId GetServerConnectionIdAsSender(const QuicPacketHeader& header, Perspective perspective) { if (perspective == Perspective::IS_CLIENT || @@ -48,6 +58,16 @@ return header.source_connection_id_included; } +QuicConnectionId GetClientConnectionIdAsSender(const QuicPacketHeader& header, + Perspective perspective) { + if (perspective == Perspective::IS_CLIENT || + !GetQuicRestartFlag(quic_do_not_override_connection_id)) { + return header.source_connection_id; + } + QUIC_RESTART_FLAG_COUNT_N(quic_do_not_override_connection_id, 3, 5); + return header.destination_connection_id; +} + QuicConnectionIdIncluded GetClientConnectionIdIncludedAsSender( const QuicPacketHeader& header, Perspective perspective) {
diff --git a/quic/core/quic_packets.h b/quic/core/quic_packets.h index d3c2600..b1964e7 100644 --- a/quic/core/quic_packets.h +++ b/quic/core/quic_packets.h
@@ -41,6 +41,12 @@ // Returns the destination connection ID of |header| when |perspective| is // client, and the source connection ID when |perspective| is server. QUIC_EXPORT_PRIVATE QuicConnectionId +GetClientConnectionIdAsRecipient(const QuicPacketHeader& header, + Perspective perspective); + +// Returns the destination connection ID of |header| when |perspective| is +// client, and the source connection ID when |perspective| is server. +QUIC_EXPORT_PRIVATE QuicConnectionId GetServerConnectionIdAsSender(const QuicPacketHeader& header, Perspective perspective); @@ -51,6 +57,12 @@ GetServerConnectionIdIncludedAsSender(const QuicPacketHeader& header, Perspective perspective); +// Returns the destination connection ID of |header| when |perspective| is +// server, and the source connection ID when |perspective| is client. +QUIC_EXPORT_PRIVATE QuicConnectionId +GetClientConnectionIdAsSender(const QuicPacketHeader& header, + Perspective perspective); + // Returns the destination connection ID included of |header| when |perspective| // is server, and the source connection ID included when |perspective| is // client.
diff --git a/quic/core/quic_session.h b/quic/core/quic_session.h index d8f40c9..dc7153a 100644 --- a/quic/core/quic_session.h +++ b/quic/core/quic_session.h
@@ -52,7 +52,7 @@ virtual ~Visitor() {} // Called when the connection is closed after the streams have been closed. - virtual void OnConnectionClosed(QuicConnectionId connection_id, + virtual void OnConnectionClosed(QuicConnectionId server_connection_id, QuicErrorCode error, const std::string& error_details, ConnectionCloseSource source) = 0;
diff --git a/quic/core/quic_time_wait_list_manager.cc b/quic/core/quic_time_wait_list_manager.cc index fa62fc5..d646f41 100644 --- a/quic/core/quic_time_wait_list_manager.cc +++ b/quic/core/quic_time_wait_list_manager.cc
@@ -200,7 +200,8 @@ } void QuicTimeWaitListManager::SendVersionNegotiationPacket( - QuicConnectionId connection_id, + QuicConnectionId server_connection_id, + QuicConnectionId client_connection_id, bool ietf_quic, const ParsedQuicVersionVector& supported_versions, const QuicSocketAddress& self_address, @@ -209,7 +210,8 @@ SendOrQueuePacket(QuicMakeUnique<QueuedPacket>( self_address, peer_address, QuicFramer::BuildVersionNegotiationPacket( - connection_id, ietf_quic, supported_versions)), + server_connection_id, client_connection_id, + ietf_quic, supported_versions)), packet_context.get()); }
diff --git a/quic/core/quic_time_wait_list_manager.h b/quic/core/quic_time_wait_list_manager.h index dee5d11..26e894d 100644 --- a/quic/core/quic_time_wait_list_manager.h +++ b/quic/core/quic_time_wait_list_manager.h
@@ -121,7 +121,8 @@ // Sends a version negotiation packet for |connection_id| announcing support // for |supported_versions| to |peer_address| from |self_address|. virtual void SendVersionNegotiationPacket( - QuicConnectionId connection_id, + QuicConnectionId server_connection_id, + QuicConnectionId client_connection_id, bool ietf_quic, const ParsedQuicVersionVector& supported_versions, const QuicSocketAddress& self_address,
diff --git a/quic/core/quic_time_wait_list_manager_test.cc b/quic/core/quic_time_wait_list_manager_test.cc index a391390..c1dadba 100644 --- a/quic/core/quic_time_wait_list_manager_test.cc +++ b/quic/core/quic_time_wait_list_manager_test.cc
@@ -251,15 +251,16 @@ TEST_F(QuicTimeWaitListManagerTest, SendVersionNegotiationPacket) { std::unique_ptr<QuicEncryptedPacket> packet( - QuicFramer::BuildVersionNegotiationPacket(connection_id_, false, + QuicFramer::BuildVersionNegotiationPacket(connection_id_, + EmptyQuicConnectionId(), false, AllSupportedVersions())); EXPECT_CALL(writer_, WritePacket(_, packet->length(), self_address_.host(), peer_address_, _)) .WillOnce(Return(WriteResult(WRITE_STATUS_OK, 1))); time_wait_list_manager_.SendVersionNegotiationPacket( - connection_id_, false, AllSupportedVersions(), self_address_, - peer_address_, QuicMakeUnique<QuicPerPacketContext>()); + connection_id_, EmptyQuicConnectionId(), false, AllSupportedVersions(), + self_address_, peer_address_, QuicMakeUnique<QuicPerPacketContext>()); EXPECT_EQ(0u, time_wait_list_manager_.num_connections()); }
diff --git a/quic/core/quic_versions.cc b/quic/core/quic_versions.cc index a1f4d8e..ed493a5 100644 --- a/quic/core/quic_versions.cc +++ b/quic/core/quic_versions.cc
@@ -55,6 +55,16 @@ return transport_version > QUIC_VERSION_46; } +bool ParsedQuicVersion::SupportsClientConnectionIds() const { + if (!GetQuicRestartFlag(quic_do_not_override_connection_id)) { + // Do not enable this feature in a production version until this flag has + // been deprecated. + return false; + } + return transport_version >= QUIC_VERSION_99 || + transport_version == QUIC_VERSION_UNSUPPORTED; +} + std::ostream& operator<<(std::ostream& os, const ParsedQuicVersion& version) { os << ParsedQuicVersionToString(version); return os;
diff --git a/quic/core/quic_versions.h b/quic/core/quic_versions.h index 7853134..1f6a431 100644 --- a/quic/core/quic_versions.h +++ b/quic/core/quic_versions.h
@@ -162,6 +162,9 @@ // Returns whether this version supports IETF RETRY packets. bool SupportsRetry() const; + + // Returns whether this version supports client connection ID. + bool SupportsClientConnectionIds() const; }; QUIC_EXPORT_PRIVATE ParsedQuicVersion UnsupportedQuicVersion();
diff --git a/quic/quartc/quartc_dispatcher.cc b/quic/quartc/quartc_dispatcher.cc index fd4f289..091d56b 100644 --- a/quic/quartc/quartc_dispatcher.cc +++ b/quic/quartc/quartc_dispatcher.cc
@@ -39,7 +39,7 @@ // Allow incoming packets to set our expected connection ID length. SetShouldUpdateExpectedConnectionIdLength(true); // Allow incoming packets with connection ID lengths shorter than allowed. - SetAllowShortInitialConnectionIds(true); + SetAllowShortInitialServerConnectionIds(true); // QuicDispatcher takes ownership of the writer. QuicDispatcher::InitializeWithWriter(packet_writer.release()); // NB: This must happen *after* InitializeWithWriter. It can call us back
diff --git a/quic/quartc/quartc_dispatcher.h b/quic/quartc/quartc_dispatcher.h index 26b52ac..239b6ef 100644 --- a/quic/quartc/quartc_dispatcher.h +++ b/quic/quartc/quartc_dispatcher.h
@@ -41,7 +41,7 @@ Delegate* delegate); ~QuartcDispatcher() override; - QuartcSession* CreateQuicSession(QuicConnectionId connection_id, + QuartcSession* CreateQuicSession(QuicConnectionId server_connection_id, const QuicSocketAddress& client_address, QuicStringPiece alpn, const ParsedQuicVersion& version) override;
diff --git a/quic/test_tools/mock_quic_time_wait_list_manager.h b/quic/test_tools/mock_quic_time_wait_list_manager.h index d13d5ac..8a39442 100644 --- a/quic/test_tools/mock_quic_time_wait_list_manager.h +++ b/quic/test_tools/mock_quic_time_wait_list_manager.h
@@ -45,8 +45,9 @@ PacketHeaderFormat header_format, std::unique_ptr<QuicPerPacketContext> packet_context)); - MOCK_METHOD6(SendVersionNegotiationPacket, - void(QuicConnectionId connection_id, + MOCK_METHOD7(SendVersionNegotiationPacket, + void(QuicConnectionId server_connection_id, + QuicConnectionId client_connection_id, bool ietf_quic, const ParsedQuicVersionVector& supported_versions, const QuicSocketAddress& server_address,
diff --git a/quic/test_tools/quic_framer_peer.cc b/quic/test_tools/quic_framer_peer.cc index 23486ea..6eedd60 100644 --- a/quic/test_tools/quic_framer_peer.cc +++ b/quic/test_tools/quic_framer_peer.cc
@@ -22,10 +22,10 @@ } // static -void QuicFramerPeer::SetLastSerializedConnectionId( +void QuicFramerPeer::SetLastSerializedServerConnectionId( QuicFramer* framer, - QuicConnectionId connection_id) { - framer->last_serialized_connection_id_ = connection_id; + QuicConnectionId server_connection_id) { + framer->last_serialized_server_connection_id_ = server_connection_id; } // static @@ -337,11 +337,11 @@ } // static -void QuicFramerPeer::SetExpectedConnectionIDLength( +void QuicFramerPeer::SetExpectedServerConnectionIDLength( QuicFramer* framer, - uint8_t expected_connection_id_length) { - *const_cast<uint8_t*>(&framer->expected_connection_id_length_) = - expected_connection_id_length; + uint8_t expected_server_connection_id_length) { + *const_cast<uint8_t*>(&framer->expected_server_connection_id_length_) = + expected_server_connection_id_length; } // static
diff --git a/quic/test_tools/quic_framer_peer.h b/quic/test_tools/quic_framer_peer.h index 4a5efa6..0b17c19 100644 --- a/quic/test_tools/quic_framer_peer.h +++ b/quic/test_tools/quic_framer_peer.h
@@ -22,8 +22,9 @@ QuicPacketNumberLength packet_number_length, QuicPacketNumber last_packet_number, uint64_t packet_number); - static void SetLastSerializedConnectionId(QuicFramer* framer, - QuicConnectionId connection_id); + static void SetLastSerializedServerConnectionId( + QuicFramer* framer, + QuicConnectionId server_connection_id); static void SetLargestPacketNumber(QuicFramer* framer, QuicPacketNumber packet_number); static void SetPerspective(QuicFramer* framer, Perspective perspective); @@ -159,9 +160,9 @@ QuicPacketNumberLength packet_number_length); static void SetFirstSendingPacketNumber(QuicFramer* framer, uint64_t packet_number); - static void SetExpectedConnectionIDLength( + static void SetExpectedServerConnectionIDLength( QuicFramer* framer, - uint8_t expected_connection_id_length); + uint8_t expected_server_connection_id_length); static QuicPacketNumber GetLargestDecryptedPacketNumber( QuicFramer* framer, PacketNumberSpace packet_number_space);
diff --git a/quic/test_tools/quic_test_client.cc b/quic/test_tools/quic_test_client.cc index 497dd8a..a290e7d 100644 --- a/quic/test_tools/quic_test_client.cc +++ b/quic/test_tools/quic_test_client.cc
@@ -209,8 +209,10 @@ this), QuicWrapUnique( new RecordingProofVerifier(std::move(proof_verifier)))), - override_connection_id_(EmptyQuicConnectionId()), - connection_id_overridden_(false) {} + override_server_connection_id_(EmptyQuicConnectionId()), + server_connection_id_overridden_(false), + override_client_connection_id_(EmptyQuicConnectionId()), + client_connection_id_overridden_(false) {} MockableQuicClient::~MockableQuicClient() { if (connected()) { @@ -231,13 +233,26 @@ } QuicConnectionId MockableQuicClient::GenerateNewConnectionId() { - return connection_id_overridden_ ? override_connection_id_ - : QuicClient::GenerateNewConnectionId(); + return server_connection_id_overridden_ + ? override_server_connection_id_ + : QuicClient::GenerateNewConnectionId(); } -void MockableQuicClient::UseConnectionId(QuicConnectionId connection_id) { - connection_id_overridden_ = true; - override_connection_id_ = connection_id; +void MockableQuicClient::UseConnectionId( + QuicConnectionId server_connection_id) { + server_connection_id_overridden_ = true; + override_server_connection_id_ = server_connection_id; +} + +QuicConnectionId MockableQuicClient::GetClientConnectionId() { + return client_connection_id_overridden_ ? override_client_connection_id_ + : QuicClient::GetClientConnectionId(); +} + +void MockableQuicClient::UseClientConnectionId( + QuicConnectionId client_connection_id) { + client_connection_id_overridden_ = true; + override_client_connection_id_ = client_connection_id; } void MockableQuicClient::UseWriter(QuicPacketWriterWrapper* writer) { @@ -759,9 +774,15 @@ client_->UseWriter(writer); } -void QuicTestClient::UseConnectionId(QuicConnectionId connection_id) { +void QuicTestClient::UseConnectionId(QuicConnectionId server_connection_id) { DCHECK(!connected()); - client_->UseConnectionId(connection_id); + client_->UseConnectionId(server_connection_id); +} + +void QuicTestClient::UseClientConnectionId( + QuicConnectionId client_connection_id) { + DCHECK(!connected()); + client_->UseClientConnectionId(client_connection_id); } bool QuicTestClient::MigrateSocket(const QuicIpAddress& new_host) {
diff --git a/quic/test_tools/quic_test_client.h b/quic/test_tools/quic_test_client.h index 353c29e..5f1b6fa 100644 --- a/quic/test_tools/quic_test_client.h +++ b/quic/test_tools/quic_test_client.h
@@ -55,7 +55,9 @@ ~MockableQuicClient() override; QuicConnectionId GenerateNewConnectionId() override; - void UseConnectionId(QuicConnectionId connection_id); + void UseConnectionId(QuicConnectionId server_connection_id); + QuicConnectionId GetClientConnectionId() override; + void UseClientConnectionId(QuicConnectionId client_connection_id); void UseWriter(QuicPacketWriterWrapper* writer); void set_peer_address(const QuicSocketAddress& address); @@ -69,9 +71,12 @@ const MockableQuicClientEpollNetworkHelper* mockable_network_helper() const; private: - // ConnectionId to use, if connection_id_overridden_ - QuicConnectionId override_connection_id_; - bool connection_id_overridden_; + // Server connection ID to use, if server_connection_id_overridden_ + QuicConnectionId override_server_connection_id_; + bool server_connection_id_overridden_; + // Client connection ID to use, if client_connection_id_overridden_ + QuicConnectionId override_client_connection_id_; + bool client_connection_id_overridden_; CachedNetworkParameters cached_network_paramaters_; }; @@ -219,9 +224,12 @@ // Configures client_ to take ownership of and use the writer. // Must be called before initial connect. void UseWriter(QuicPacketWriterWrapper* writer); - // If the given ConnectionId is nonzero, configures client_ to use a specific - // ConnectionId instead of a random one. - void UseConnectionId(QuicConnectionId connection_id); + // Configures client_ to use a specific server connection ID instead of a + // random one. + void UseConnectionId(QuicConnectionId server_connection_id); + // Configures client_ to use a specific client connection ID instead of an + // empty one. + void UseClientConnectionId(QuicConnectionId client_connection_id); // Returns nullptr if the maximum number of streams have already been created. QuicSpdyClientStream* GetOrCreateStream();
diff --git a/quic/test_tools/quic_test_server.cc b/quic/test_tools/quic_test_server.cc index bda2148..68c500e 100644 --- a/quic/test_tools/quic_test_server.cc +++ b/quic/test_tools/quic_test_server.cc
@@ -77,7 +77,7 @@ std::unique_ptr<QuicCryptoServerStream::Helper> session_helper, std::unique_ptr<QuicAlarmFactory> alarm_factory, QuicSimpleServerBackend* quic_simple_server_backend, - uint8_t expected_connection_id_length) + uint8_t expected_server_connection_id_length) : QuicSimpleDispatcher(config, crypto_config, version_manager, @@ -85,7 +85,7 @@ std::move(session_helper), std::move(alarm_factory), quic_simple_server_backend, - expected_connection_id_length), + expected_server_connection_id_length), session_factory_(nullptr), stream_factory_(nullptr), crypto_stream_factory_(nullptr) {} @@ -170,13 +170,13 @@ const QuicConfig& config, const ParsedQuicVersionVector& supported_versions, QuicSimpleServerBackend* quic_simple_server_backend, - uint8_t expected_connection_id_length) + uint8_t expected_server_connection_id_length) : QuicServer(std::move(proof_source), config, QuicCryptoServerConfig::ConfigOptions(), supported_versions, quic_simple_server_backend, - expected_connection_id_length) {} + expected_server_connection_id_length) {} QuicDispatcher* QuicTestServer::CreateQuicDispatcher() { return new QuicTestDispatcher( @@ -186,7 +186,7 @@ std::unique_ptr<QuicCryptoServerStream::Helper>( new QuicSimpleCryptoServerStreamHelper(QuicRandom::GetInstance())), QuicMakeUnique<QuicEpollAlarmFactory>(epoll_server()), server_backend(), - expected_connection_id_length()); + expected_server_connection_id_length()); } void QuicTestServer::SetSessionFactory(SessionFactory* factory) {
diff --git a/quic/test_tools/quic_test_server.h b/quic/test_tools/quic_test_server.h index 52fb7c5..3661b7a 100644 --- a/quic/test_tools/quic_test_server.h +++ b/quic/test_tools/quic_test_server.h
@@ -69,7 +69,7 @@ const QuicConfig& config, const ParsedQuicVersionVector& supported_versions, QuicSimpleServerBackend* quic_simple_server_backend, - uint8_t expected_connection_id_length); + uint8_t expected_server_connection_id_length); // Create a custom dispatcher which creates custom sessions. QuicDispatcher* CreateQuicDispatcher() override;
diff --git a/quic/tools/quic_client_base.cc b/quic/tools/quic_client_base.cc index 57cbd9a..fd4776f 100644 --- a/quic/tools/quic_client_base.cc +++ b/quic/tools/quic_client_base.cc
@@ -135,6 +135,7 @@ can_reconnect_with_different_version ? ParsedQuicVersionVector{mutual_version} : supported_versions())); + session()->connection()->set_client_connection_id(GetClientConnectionId()); if (initial_max_packet_length_ != 0) { session()->connection()->SetMaxPacketLength(initial_max_packet_length_); } @@ -333,6 +334,10 @@ return QuicUtils::CreateRandomConnectionId(); } +QuicConnectionId QuicClientBase::GetClientConnectionId() { + return EmptyQuicConnectionId(); +} + bool QuicClientBase::CanReconnectWithDifferentVersion( ParsedQuicVersion* version) const { if (session_ == nullptr || session_->connection() == nullptr ||
diff --git a/quic/tools/quic_client_base.h b/quic/tools/quic_client_base.h index 0cc4d71..134404b 100644 --- a/quic/tools/quic_client_base.h +++ b/quic/tools/quic_client_base.h
@@ -268,6 +268,9 @@ // connection ID). virtual QuicConnectionId GenerateNewConnectionId(); + // Returns the client connection ID to use. + virtual QuicConnectionId GetClientConnectionId(); + QuicAlarmFactory* alarm_factory() { return alarm_factory_.get(); } // Subclasses may need to explicitly clear the session on destruction
diff --git a/quic/tools/quic_server.cc b/quic/tools/quic_server.cc index 4ac1dac..3696b7c 100644 --- a/quic/tools/quic_server.cc +++ b/quic/tools/quic_server.cc
@@ -64,7 +64,7 @@ const QuicCryptoServerConfig::ConfigOptions& crypto_config_options, const ParsedQuicVersionVector& supported_versions, QuicSimpleServerBackend* quic_simple_server_backend, - uint8_t expected_connection_id_length) + uint8_t expected_server_connection_id_length) : port_(0), fd_(-1), packets_dropped_(0), @@ -80,7 +80,8 @@ version_manager_(supported_versions), packet_reader_(new QuicPacketReader()), quic_simple_server_backend_(quic_simple_server_backend), - expected_connection_id_length_(expected_connection_id_length) { + expected_server_connection_id_length_( + expected_server_connection_id_length) { DCHECK(quic_simple_server_backend_); Initialize(); } @@ -159,7 +160,7 @@ new QuicSimpleCryptoServerStreamHelper(QuicRandom::GetInstance())), std::unique_ptr<QuicEpollAlarmFactory>( new QuicEpollAlarmFactory(&epoll_server_)), - quic_simple_server_backend_, expected_connection_id_length_); + quic_simple_server_backend_, expected_server_connection_id_length_); } void QuicServer::HandleEventsForever() {
diff --git a/quic/tools/quic_server.h b/quic/tools/quic_server.h index 6fb1646..1f9c225 100644 --- a/quic/tools/quic_server.h +++ b/quic/tools/quic_server.h
@@ -43,7 +43,7 @@ const QuicCryptoServerConfig::ConfigOptions& server_config_options, const ParsedQuicVersionVector& supported_versions, QuicSimpleServerBackend* quic_simple_server_backend, - uint8_t expected_connection_id_length); + uint8_t expected_server_connection_id_length); QuicServer(const QuicServer&) = delete; QuicServer& operator=(const QuicServer&) = delete; @@ -103,8 +103,8 @@ void set_silent_close(bool value) { silent_close_ = value; } - uint8_t expected_connection_id_length() { - return expected_connection_id_length_; + uint8_t expected_server_connection_id_length() { + return expected_server_connection_id_length_; } private: @@ -155,7 +155,7 @@ QuicSimpleServerBackend* quic_simple_server_backend_; // unowned. // Connection ID length expected to be read on incoming IETF short headers. - uint8_t expected_connection_id_length_; + uint8_t expected_server_connection_id_length_; }; } // namespace quic
diff --git a/quic/tools/quic_simple_dispatcher.cc b/quic/tools/quic_simple_dispatcher.cc index 4706b60..0f4d447 100644 --- a/quic/tools/quic_simple_dispatcher.cc +++ b/quic/tools/quic_simple_dispatcher.cc
@@ -16,14 +16,14 @@ std::unique_ptr<QuicCryptoServerStream::Helper> session_helper, std::unique_ptr<QuicAlarmFactory> alarm_factory, QuicSimpleServerBackend* quic_simple_server_backend, - uint8_t expected_connection_id_length) + uint8_t expected_server_connection_id_length) : QuicDispatcher(config, crypto_config, version_manager, std::move(helper), std::move(session_helper), std::move(alarm_factory), - expected_connection_id_length), + expected_server_connection_id_length), quic_simple_server_backend_(quic_simple_server_backend) {} QuicSimpleDispatcher::~QuicSimpleDispatcher() = default;
diff --git a/quic/tools/quic_simple_dispatcher.h b/quic/tools/quic_simple_dispatcher.h index 46d976c..8a6cf85 100644 --- a/quic/tools/quic_simple_dispatcher.h +++ b/quic/tools/quic_simple_dispatcher.h
@@ -21,7 +21,7 @@ std::unique_ptr<QuicCryptoServerStream::Helper> session_helper, std::unique_ptr<QuicAlarmFactory> alarm_factory, QuicSimpleServerBackend* quic_simple_server_backend, - uint8_t expected_connection_id_length); + uint8_t expected_server_connection_id_length); ~QuicSimpleDispatcher() override;