Add OnTransportParameters{Sent,Received} to QuicConnectionDebugVisitor
diff --git a/quic/core/http/end_to_end_test.cc b/quic/core/http/end_to_end_test.cc index 399e19f..6b7bfd5 100644 --- a/quic/core/http/end_to_end_test.cc +++ b/quic/core/http/end_to_end_test.cc
@@ -70,6 +70,9 @@ using spdy::SpdyHeaderBlock; using spdy::SpdySerializedFrame; using spdy::SpdySettingsIR; +using ::testing::_; +using ::testing::Invoke; +using ::testing::NiceMock; namespace quic { namespace test { @@ -662,9 +665,9 @@ TEST_P(EndToEndTest, SimpleRequestResponseForcedVersionNegotiation) { client_supported_versions_.insert(client_supported_versions_.begin(), QuicVersionReservedForNegotiation()); - testing::NiceMock<MockQuicConnectionDebugVisitor> visitor; + NiceMock<MockQuicConnectionDebugVisitor> visitor; connection_debug_visitor_ = &visitor; - EXPECT_CALL(visitor, OnVersionNegotiationPacket(testing::_)).Times(1); + EXPECT_CALL(visitor, OnVersionNegotiationPacket(_)).Times(1); ASSERT_TRUE(Initialize()); ASSERT_TRUE(ServerSendsVersionNegotiation()); @@ -2542,7 +2545,7 @@ QuicFramer framer(server_supported_versions_, QuicTime::Zero(), Perspective::IS_SERVER, kQuicDefaultConnectionIdLength); std::unique_ptr<QuicEncryptedPacket> packet; - testing::NiceMock<MockQuicConnectionDebugVisitor> visitor; + NiceMock<MockQuicConnectionDebugVisitor> visitor; GetClientConnection()->set_debug_visitor(&visitor); if (VersionHasIetfInvariantHeader(client_connection->transport_version())) { packet = framer.BuildIetfStatelessResetPacket(incorrect_connection_id, @@ -2618,7 +2621,7 @@ VersionHasIetfInvariantHeader(client_connection->transport_version()), client_connection->version().HasLengthPrefixedConnectionIds(), server_supported_versions_)); - testing::NiceMock<MockQuicConnectionDebugVisitor> visitor; + NiceMock<MockQuicConnectionDebugVisitor> visitor; client_connection->set_debug_visitor(&visitor); EXPECT_CALL(visitor, OnIncorrectConnectionId(incorrect_connection_id)) .Times(1); @@ -4264,6 +4267,42 @@ ->CanCreatePushStreamWithId(kMaxQuicStreamId)); } +TEST_P(EndToEndTest, CustomTransportParameters) { + if (!version_.UsesTls()) { + // Custom transport parameters are only supported with TLS.
+ ASSERT_TRUE(Initialize()); + return; + } + constexpr auto kCustomParameter = + static_cast<TransportParameters::TransportParameterId>(0xff34); + client_config_.custom_transport_parameters_to_send()[kCustomParameter] = + "test"; + NiceMock<MockQuicConnectionDebugVisitor> visitor; + connection_debug_visitor_ = &visitor; + EXPECT_CALL(visitor, OnTransportParametersSent(_)) + .WillOnce(Invoke([kCustomParameter]( + const TransportParameters& transport_parameters) { + ASSERT_NE(transport_parameters.custom_parameters.find(kCustomParameter), + transport_parameters.custom_parameters.end()); + EXPECT_EQ(transport_parameters.custom_parameters.at(kCustomParameter), + "test"); + })); + EXPECT_CALL(visitor, OnTransportParametersReceived(_)).Times(1); + ASSERT_TRUE(Initialize()); + + EXPECT_TRUE(client_->client()->WaitForCryptoHandshakeConfirmed()); + + server_thread_->Pause();  // Copy the config without racing. + QuicConfig server_config = *GetServerSession()->config(); + server_thread_->Resume(); + ASSERT_NE(server_config.received_custom_transport_parameters().find( + kCustomParameter), + server_config.received_custom_transport_parameters().end()); + EXPECT_EQ( + server_config.received_custom_transport_parameters().at(kCustomParameter), + "test"); +} + TEST_P(EndToEndTest, DISABLED_CustomTransportParameters) { // TODO(b/155316241): Enable this test. constexpr auto kCustomParameter =
diff --git a/quic/core/quic_connection.cc b/quic/core/quic_connection.cc index 0747d4f..bb32b6e 100644 --- a/quic/core/quic_connection.cc +++ b/quic/core/quic_connection.cc
@@ -841,6 +841,20 @@ } } +void QuicConnection::OnTransportParametersSent( + const TransportParameters& transport_parameters) const { + if (debug_visitor_ != nullptr) {  // Debug visitor is optional. + debug_visitor_->OnTransportParametersSent(transport_parameters); + } +} + +void QuicConnection::OnTransportParametersReceived( + const TransportParameters& transport_parameters) const { + if (debug_visitor_ != nullptr) {  // Debug visitor is optional. + debug_visitor_->OnTransportParametersReceived(transport_parameters); + } +} + void QuicConnection::OnDecryptedPacket(EncryptionLevel level) { last_decrypted_packet_level_ = level; last_packet_decrypted_ = true;
diff --git a/quic/core/quic_connection.h b/quic/core/quic_connection.h index 6586fb2..9e39777 100644 --- a/quic/core/quic_connection.h +++ b/quic/core/quic_connection.h
@@ -26,6 +26,7 @@ #include "net/third_party/quiche/src/quic/core/crypto/quic_decrypter.h" #include "net/third_party/quiche/src/quic/core/crypto/quic_encrypter.h" +#include "net/third_party/quiche/src/quic/core/crypto/transport_parameters.h" #include "net/third_party/quiche/src/quic/core/frames/quic_max_streams_frame.h" #include "net/third_party/quiche/src/quic/core/proto/cached_network_parameters_proto.h" #include "net/third_party/quiche/src/quic/core/quic_alarm.h" @@ -341,6 +342,14 @@ // Called when |count| packet numbers have been skipped. virtual void OnNPacketNumbersSkipped(QuicPacketCount /*count*/) {} + + // Called for QUIC+TLS versions when we send transport parameters. + virtual void OnTransportParametersSent( + const TransportParameters& /*transport_parameters*/) {} + + // Called for QUIC+TLS versions when we receive transport parameters. + virtual void OnTransportParametersReceived( + const TransportParameters& /*transport_parameters*/) {} }; class QUIC_EXPORT_PRIVATE QuicConnectionHelperInterface { @@ -965,6 +974,14 @@ // Called when version is considered negotiated. void OnSuccessfulVersionNegotiation(); + // Reports sent |transport_parameters| to the debug visitor, if any. + void OnTransportParametersSent( + const TransportParameters& transport_parameters) const; + + // Reports received |transport_parameters| to the debug visitor, if any. + void OnTransportParametersReceived( + const TransportParameters& transport_parameters) const; + protected: // Calls cancel() on all the alarms owned by this connection. void CancelAllAlarms();
diff --git a/quic/core/tls_client_handshaker.cc b/quic/core/tls_client_handshaker.cc index 30a9ac4..aed628e 100644 --- a/quic/core/tls_client_handshaker.cc +++ b/quic/core/tls_client_handshaker.cc
@@ -216,6 +216,9 @@ params.google_quic_params->SetStringPiece(kUAID, user_agent_id_); } + // Notify QuicConnectionDebugVisitor before serializing the parameters. + session()->connection()->OnTransportParametersSent(params); + std::vector<uint8_t> param_bytes; return SerializeTransportParameters(session()->connection()->version(), params, &param_bytes) && @@ -244,6 +247,10 @@ return false; } + // Notify QuicConnectionDebugVisitor of the parameters as received. + session()->connection()->OnTransportParametersReceived( + *received_transport_params_); + // When interoperating with non-Google implementations that do not send // the version extension, set it to what we expect. if (received_transport_params_->version == 0) {
diff --git a/quic/core/tls_server_handshaker.cc b/quic/core/tls_server_handshaker.cc index 5bd5b3d..99333a6 100644 --- a/quic/core/tls_server_handshaker.cc +++ b/quic/core/tls_server_handshaker.cc
@@ -284,6 +284,9 @@ return false; } + // Notify QuicConnectionDebugVisitor of the parameters as received. + session()->connection()->OnTransportParametersReceived(client_params); + // When interoperating with non-Google implementations that do not send // the version extension, set it to what we expect. if (client_params.version == 0) { @@ -318,6 +321,10 @@ // TODO(nharper): Provide an actual value for the stateless reset token. server_params.stateless_reset_token.resize(16); + + // Notify QuicConnectionDebugVisitor before serializing the parameters. + session()->connection()->OnTransportParametersSent(server_params); + std::vector<uint8_t> server_params_bytes; if (!SerializeTransportParameters(session()->connection()->version(), server_params, &server_params_bytes) ||
diff --git a/quic/test_tools/quic_test_utils.h b/quic/test_tools/quic_test_utils.h index 7d1d9db..031833d 100644 --- a/quic/test_tools/quic_test_utils.h +++ b/quic/test_tools/quic_test_utils.h
@@ -15,6 +15,7 @@ #include "net/third_party/quiche/src/quic/core/congestion_control/loss_detection_interface.h" #include "net/third_party/quiche/src/quic/core/congestion_control/send_algorithm_interface.h" +#include "net/third_party/quiche/src/quic/core/crypto/transport_parameters.h" #include "net/third_party/quiche/src/quic/core/http/quic_client_push_promise_index.h" #include "net/third_party/quiche/src/quic/core/http/quic_server_session_base.h" #include "net/third_party/quiche/src/quic/core/http/quic_spdy_session.h" @@ -1415,6 +1416,16 @@ OnVersionNegotiationPacket, (const QuicVersionNegotiationPacket&), (override)); + // Transport parameter hooks, invoked for QUIC+TLS versions only. + MOCK_METHOD(void, OnTransportParametersSent, (const TransportParameters&), (override)); + + MOCK_METHOD(void, OnTransportParametersReceived, (const TransportParameters&), (override)); }; class MockReceivedPacketManager : public QuicReceivedPacketManager {