Refactor QuicDispatcher tests The dispatcher tests were incorrectly using QUIC_CRYPTO CHLOs with TLS versions. This CL updates them to use real packets from GetFirstFlightOfPackets(). gfe-relnote: n/a, test-only PiperOrigin-RevId: 307802112 Change-Id: I0c5af0a273f74036da72587da3838a489064acaf
diff --git a/quic/core/quic_dispatcher_test.cc b/quic/core/quic_dispatcher_test.cc index d52651a..cc98335 100644 --- a/quic/core/quic_dispatcher_test.cc +++ b/quic/core/quic_dispatcher_test.cc
@@ -27,6 +27,7 @@ #include "net/third_party/quiche/src/quic/platform/api/quic_test.h" #include "net/third_party/quiche/src/quic/test_tools/crypto_test_utils.h" #include "net/third_party/quiche/src/quic/test_tools/fake_proof_source.h" +#include "net/third_party/quiche/src/quic/test_tools/first_flight.h" #include "net/third_party/quiche/src/quic/test_tools/mock_quic_time_wait_list_manager.h" #include "net/third_party/quiche/src/quic/test_tools/quic_buffered_packet_store_peer.h" #include "net/third_party/quiche/src/quic/test_tools/quic_crypto_server_config_peer.h" @@ -294,17 +295,25 @@ client_connection_id_included, packet_number_length, &versions)); std::unique_ptr<QuicReceivedPacket> received_packet( ConstructReceivedPacket(*packet, mock_helper_.GetClock()->Now())); + ProcessReceivedPacket(std::move(received_packet), peer_address, version, + server_connection_id); + } - if (ChloExtractor::Extract(*packet, version, {}, nullptr, + void ProcessReceivedPacket( + std::unique_ptr<QuicReceivedPacket> received_packet, + const QuicSocketAddress& peer_address, + const ParsedQuicVersion& version, + const QuicConnectionId& server_connection_id) { + if (ChloExtractor::Extract(*received_packet, version, {}, nullptr, server_connection_id.length())) { // Add CHLO packet to the beginning to be verified first, because it is // also processed first by new session. data_connection_map_[server_connection_id].push_front( - std::string(packet->data(), packet->length())); + std::string(received_packet->data(), received_packet->length())); } else { // For non-CHLO, always append to last. data_connection_map_[server_connection_id].push_back( - std::string(packet->data(), packet->length())); + std::string(received_packet->data(), received_packet->length())); } dispatcher_->ProcessPacket(server_address_, peer_address, *received_packet); } @@ -353,16 +362,41 @@ std::string SerializeCHLO() { CryptoHandshakeMessage client_hello; client_hello.set_tag(kCHLO); - client_hello.SetStringPiece(kALPN, "hq"); + client_hello.SetStringPiece(kALPN, ExpectedAlpn()); return std::string(client_hello.GetSerialized().AsStringPiece()); } + void ProcessFirstFlight(const QuicSocketAddress& peer_address, + const QuicConnectionId& server_connection_id) { + ProcessFirstFlight(version_, peer_address, server_connection_id); + } + + void ProcessFirstFlight(const ParsedQuicVersion& version, + const QuicSocketAddress& peer_address, + const QuicConnectionId& server_connection_id) { + ProcessFirstFlight(version, peer_address, server_connection_id, + EmptyQuicConnectionId()); + } + + void ProcessFirstFlight(const ParsedQuicVersion& version, + const QuicSocketAddress& peer_address, + const QuicConnectionId& server_connection_id, + const QuicConnectionId& client_connection_id) { + std::vector<std::unique_ptr<QuicReceivedPacket>> packets = + GetFirstFlightOfPackets(version, server_connection_id, + client_connection_id); + for (auto&& packet : packets) { + ProcessReceivedPacket(std::move(packet), peer_address, version, + server_connection_id); + } + } + std::string ExpectedAlpnForVersion(ParsedQuicVersion version) { if (version.handshake_protocol == PROTOCOL_TLS1_3) { // TODO(b/149597791) Remove this once we can parse ALPN with TLS. return ""; } - return "hq"; + return AlpnForVersion(version); } std::string ExpectedAlpn() { return ExpectedAlpnForVersion(version_); } @@ -388,8 +422,7 @@ EXPECT_CALL(*dispatcher_, ShouldCreateOrBufferPacketForConnection( ReceivedPacketInfoConnectionIdEquals(connection_id))); - ProcessPacket(client_address, connection_id, true, version, SerializeCHLO(), - true, CONNECTION_ID_PRESENT, PACKET_4BYTE_PACKET_NUMBER, 1); + ProcessFirstFlight(version, client_address, connection_id); } void VerifyVersionNotSupported(ParsedQuicVersion version) { @@ -398,8 +431,7 @@ EXPECT_CALL(*dispatcher_, CreateQuicSession(connection_id, client_address, _, _)) .Times(0); - ProcessPacket(client_address, connection_id, true, version, SerializeCHLO(), - true, CONNECTION_ID_PRESENT, PACKET_4BYTE_PACKET_NUMBER, 1); + ProcessFirstFlight(version, client_address, connection_id); } ParsedQuicVersion version_; @@ -452,9 +484,8 @@ EXPECT_CALL(*dispatcher_, ShouldCreateOrBufferPacketForConnection( ReceivedPacketInfoConnectionIdEquals(TestConnectionId(1)))); - ProcessPacket(client_address, TestConnectionId(1), true, version_, - SerializeCHLO(), true, CONNECTION_ID_PRESENT, - PACKET_4BYTE_PACKET_NUMBER, 1); + + ProcessFirstFlight(client_address, TestConnectionId(1)); } TEST_P(QuicDispatcherTestAllVersions, ProcessPackets) { @@ -475,7 +506,7 @@ EXPECT_CALL(*dispatcher_, ShouldCreateOrBufferPacketForConnection( ReceivedPacketInfoConnectionIdEquals(TestConnectionId(1)))); - ProcessPacket(client_address, TestConnectionId(1), true, SerializeCHLO()); + ProcessFirstFlight(client_address, TestConnectionId(1)); EXPECT_CALL(*dispatcher_, CreateQuicSession(TestConnectionId(2), client_address, @@ -492,7 +523,7 @@ EXPECT_CALL(*dispatcher_, ShouldCreateOrBufferPacketForConnection( ReceivedPacketInfoConnectionIdEquals(TestConnectionId(2)))); - ProcessPacket(client_address, TestConnectionId(2), true, SerializeCHLO()); + ProcessFirstFlight(client_address, TestConnectionId(2)); EXPECT_CALL(*reinterpret_cast<MockQuicConnection*>(session1_->connection()), ProcessUdpPacket(_, _, _)) @@ -525,9 +556,7 @@ EXPECT_CALL(*dispatcher_, ShouldCreateOrBufferPacketForConnection( ReceivedPacketInfoConnectionIdEquals(TestConnectionId(1)))); - ProcessPacket(client_address, TestConnectionId(1), true, version_, - SerializeCHLO(), true, CONNECTION_ID_PRESENT, - PACKET_4BYTE_PACKET_NUMBER, 1); + ProcessFirstFlight(client_address, TestConnectionId(1)); // Packet number 256 with packet number length 1 would be considered as 0 in // dispatcher. ProcessPacket(client_address, TestConnectionId(1), false, version_, "", true, @@ -543,13 +572,8 @@ *time_wait_list_manager_, SendVersionNegotiationPacket(TestConnectionId(1), _, _, _, _, _, _, _)) .Times(1); - // Pad the CHLO message with enough data to make the packet large enough - // to trigger version negotiation. - std::string chlo = SerializeCHLO() + std::string(1200, 'a'); - DCHECK_LE(1200u, chlo.length()); - ProcessPacket(client_address, TestConnectionId(1), true, - QuicVersionReservedForNegotiation(), chlo, true, - CONNECTION_ID_PRESENT, PACKET_4BYTE_PACKET_NUMBER, 1); + ProcessFirstFlight(QuicVersionReservedForNegotiation(), client_address, + TestConnectionId(1)); } TEST_P(QuicDispatcherTestOneVersion, @@ -562,13 +586,8 @@ EXPECT_CALL(*time_wait_list_manager_, SendVersionNegotiationPacket(connection_id, _, _, _, _, _, _, _)) .Times(1); - // Pad the CHLO message with enough data to make the packet large enough - // to trigger version negotiation. - std::string chlo = SerializeCHLO() + std::string(1200, 'a'); - DCHECK_LE(1200u, chlo.length()); - ProcessPacket(client_address, connection_id, true, - QuicVersionReservedForNegotiation(), chlo, true, - CONNECTION_ID_PRESENT, PACKET_4BYTE_PACKET_NUMBER, 1); + ProcessFirstFlight(QuicVersionReservedForNegotiation(), client_address, + connection_id); } TEST_P(QuicDispatcherTestOneVersion, @@ -581,14 +600,8 @@ SendVersionNegotiationPacket( TestConnectionId(1), TestConnectionId(2), _, _, _, _, _, _)) .Times(1); - // Pad the CHLO message with enough data to make the packet large enough - // to trigger version negotiation. - std::string chlo = SerializeCHLO() + std::string(1200, 'a'); - DCHECK_LE(1200u, chlo.length()); - ProcessPacket(client_address, TestConnectionId(1), TestConnectionId(2), true, - QuicVersionReservedForNegotiation(), chlo, true, - CONNECTION_ID_PRESENT, CONNECTION_ID_PRESENT, - PACKET_4BYTE_PACKET_NUMBER, 1); + ProcessFirstFlight(QuicVersionReservedForNegotiation(), client_address, + TestConnectionId(1), TestConnectionId(2)); } TEST_P(QuicDispatcherTestOneVersion, NoVersionNegotiationWithSmallPacket) { @@ -652,7 +665,7 @@ EXPECT_CALL(*dispatcher_, ShouldCreateOrBufferPacketForConnection( ReceivedPacketInfoConnectionIdEquals(TestConnectionId(1)))); - ProcessPacket(client_address, TestConnectionId(1), true, SerializeCHLO()); + ProcessFirstFlight(client_address, TestConnectionId(1)); EXPECT_CALL(*reinterpret_cast<MockQuicConnection*>(session1_->connection()), CloseConnection(QUIC_PEER_GOING_AWAY, _, _)); @@ -681,7 +694,7 @@ EXPECT_CALL(*dispatcher_, ShouldCreateOrBufferPacketForConnection( ReceivedPacketInfoConnectionIdEquals(TestConnectionId(1)))); - ProcessPacket(client_address, connection_id, true, SerializeCHLO()); + ProcessFirstFlight(client_address, connection_id); // Now close the connection, which should add it to the time wait list. session1_->connection()->CloseConnection( @@ -717,7 +730,8 @@ .Times(0); EXPECT_CALL(*time_wait_list_manager_, SendPublicReset(_, _, _, _, _)) .Times(1); - ProcessPacket(client_address, connection_id, false, SerializeCHLO()); + ProcessPacket(client_address, connection_id, /*has_version_flag=*/false, + "data"); } TEST_P(QuicDispatcherTestAllVersions, @@ -772,7 +786,7 @@ EXPECT_CALL(*dispatcher_, ShouldCreateOrBufferPacketForConnection( ReceivedPacketInfoConnectionIdEquals(bad_connection_id))); - ProcessPacket(client_address, bad_connection_id, true, SerializeCHLO()); + ProcessFirstFlight(client_address, bad_connection_id); } // Makes sure zero-byte connection IDs are replaced by 8-byte ones. @@ -809,7 +823,7 @@ EXPECT_CALL(*dispatcher_, ShouldCreateOrBufferPacketForConnection( ReceivedPacketInfoConnectionIdEquals(bad_connection_id))); - ProcessPacket(client_address, bad_connection_id, true, SerializeCHLO()); + ProcessFirstFlight(client_address, bad_connection_id); } // Makes sure TestConnectionId(1) creates a new connection and @@ -839,7 +853,7 @@ EXPECT_CALL(*dispatcher_, ShouldCreateOrBufferPacketForConnection( ReceivedPacketInfoConnectionIdEquals(TestConnectionId(1)))); - ProcessPacket(client_address, TestConnectionId(1), true, SerializeCHLO()); + ProcessFirstFlight(client_address, TestConnectionId(1)); EXPECT_CALL(*dispatcher_, CreateQuicSession(fixed_connection_id, client_address, @@ -857,7 +871,7 @@ EXPECT_CALL(*dispatcher_, ShouldCreateOrBufferPacketForConnection( ReceivedPacketInfoConnectionIdEquals(bad_connection_id))); - ProcessPacket(client_address, bad_connection_id, true, SerializeCHLO()); + ProcessFirstFlight(client_address, bad_connection_id); EXPECT_CALL(*reinterpret_cast<MockQuicConnection*>(session1_->connection()), ProcessUdpPacket(_, _, _)) @@ -881,7 +895,8 @@ EXPECT_CALL(*time_wait_list_manager_, AddConnectionIdToTimeWait(_, _, _, _, _)) .Times(0); - ProcessPacket(client_address, TestConnectionId(1), true, SerializeCHLO()); + ProcessPacket(client_address, TestConnectionId(1), /*has_version_flag=*/true, + "data"); } TEST_P(QuicDispatcherTestAllVersions, @@ -900,10 +915,14 @@ EXPECT_CALL(*time_wait_list_manager_, AddConnectionIdToTimeWait(_, _, _, _, _)) .Times(0); - ProcessPacket(client_address, EmptyQuicConnectionId(), true, SerializeCHLO()); + ProcessFirstFlight(client_address, EmptyQuicConnectionId()); } TEST_P(QuicDispatcherTestAllVersions, OKSeqNoPacketProcessed) { + if (version_.handshake_protocol == PROTOCOL_TLS1_3) { + // QUIC+TLS allows clients to start with any packet number. + return; + } QuicSocketAddress client_address(QuicIpAddress::Loopback4(), 1); QuicConnectionId connection_id = TestConnectionId(1); @@ -1342,7 +1361,7 @@ .WillOnce(WithArg<2>(Invoke([this](const QuicEncryptedPacket& packet) { ValidatePacket(TestConnectionId(1), packet); }))); - ProcessPacket(client_address, TestConnectionId(1), true, SerializeCHLO()); + ProcessFirstFlight(client_address, TestConnectionId(1)); dispatcher_->StopAcceptingNewConnections(); EXPECT_FALSE(dispatcher_->accept_new_connections()); @@ -1352,7 +1371,7 @@ CreateQuicSession(TestConnectionId(2), client_address, Eq(ExpectedAlpn()), _)) .Times(0u); - ProcessPacket(client_address, TestConnectionId(2), true, SerializeCHLO()); + ProcessFirstFlight(client_address, TestConnectionId(2)); // Existing connections should be able to continue. EXPECT_CALL(*reinterpret_cast<MockQuicConnection*>(session1_->connection()), @@ -1373,7 +1392,7 @@ CreateQuicSession(TestConnectionId(2), client_address, Eq(ExpectedAlpn()), _)) .Times(0u); - ProcessPacket(client_address, TestConnectionId(2), true, SerializeCHLO()); + ProcessFirstFlight(client_address, TestConnectionId(2)); dispatcher_->StartAcceptingNewConnections(); EXPECT_TRUE(dispatcher_->accept_new_connections()); @@ -1390,7 +1409,7 @@ .WillOnce(WithArg<2>(Invoke([this](const QuicEncryptedPacket& packet) { ValidatePacket(TestConnectionId(1), packet); }))); - ProcessPacket(client_address, TestConnectionId(1), true, SerializeCHLO()); + ProcessFirstFlight(client_address, TestConnectionId(1)); } // Verify the stopgap test: Packets with truncated connection IDs should be @@ -1465,7 +1484,7 @@ EXPECT_CALL(*dispatcher_, ShouldCreateOrBufferPacketForConnection( ReceivedPacketInfoConnectionIdEquals(TestConnectionId(1)))); - ProcessPacket(client_address, TestConnectionId(1), true, SerializeCHLO()); + ProcessFirstFlight(client_address, TestConnectionId(1)); EXPECT_CALL(*dispatcher_, CreateQuicSession(_, client_address, Eq(ExpectedAlpn()), _)) @@ -1481,7 +1500,7 @@ EXPECT_CALL(*dispatcher_, ShouldCreateOrBufferPacketForConnection( ReceivedPacketInfoConnectionIdEquals(TestConnectionId(2)))); - ProcessPacket(client_address, TestConnectionId(2), true, SerializeCHLO()); + ProcessFirstFlight(client_address, TestConnectionId(2)); blocked_list_ = QuicDispatcherPeer::GetWriteBlockedList(dispatcher_.get()); }