Add an end-to-end test for WebTransport subprotocol negotiation.
PiperOrigin-RevId: 748789470
diff --git a/quiche/quic/core/http/end_to_end_test.cc b/quiche/quic/core/http/end_to_end_test.cc
index 0d1cd23..0b4d55e 100644
--- a/quiche/quic/core/http/end_to_end_test.cc
+++ b/quiche/quic/core/http/end_to_end_test.cc
@@ -93,6 +93,7 @@
#include "quiche/common/platform/api/quiche_test.h"
#include "quiche/common/quiche_stream.h"
#include "quiche/common/test_tools/quiche_test_utils.h"
+#include "quiche/web_transport/web_transport_headers.h"
using quiche::HttpHeaderBlock;
using spdy::SpdyFramer;
@@ -854,7 +855,8 @@
WebTransportHttp3* CreateWebTransportSession(
const std::string& path, bool wait_for_server_response,
- QuicSpdyStream** connect_stream_out = nullptr) {
+ std::initializer_list<std::pair<absl::string_view, absl::string_view>>
+ extra_headers = {}) {
// Wait until we receive the settings from the server indicating
// WebTransport support.
client_->WaitUntil(
@@ -869,6 +871,9 @@
headers[":path"] = path;
headers[":method"] = "CONNECT";
headers[":protocol"] = "webtransport";
+ for (const auto& [key, value] : extra_headers) {
+ headers[key] = std::string(value);
+ }
client_->SendMessage(headers, "", /*fin=*/false);
QuicSpdyStream* stream = client_->latest_created_stream();
@@ -886,9 +891,6 @@
[stream]() { return stream->headers_decompressed(); });
EXPECT_TRUE(session->ready());
}
- if (connect_stream_out != nullptr) {
- *connect_stream_out = stream;
- }
return session;
}
@@ -7180,6 +7182,40 @@
server_thread_->Resume();
}
+TEST_P(EndToEndTest, WebTransportSessionProtocolNegotiation) {
+ enable_web_transport_ = true;
+ ASSERT_TRUE(Initialize());
+
+ if (!version_.UsesHttp3()) {
+ return;
+ }
+
+ WebTransportHttp3* session = CreateWebTransportSession(
+ "/selected-subprotocol", /*wait_for_server_response=*/true,
+ {{webtransport::kSubprotocolRequestHeader, "a, b, c, d"},
+ {"subprotocol-index", "1"}});
+ ASSERT_NE(session, nullptr);
+ NiceMock<MockWebTransportSessionVisitor>& visitor =
+ SetupWebTransportVisitor(session);
+ EXPECT_EQ(session->GetNegotiatedSubprotocol(), "b");
+
+ WebTransportStream* received_stream =
+ session->AcceptIncomingUnidirectionalStream();
+ if (received_stream == nullptr) {
+ // Retry if reordering happens.
+ bool stream_received = false;
+ EXPECT_CALL(visitor, OnIncomingUnidirectionalStreamAvailable())
+ .WillOnce(Assign(&stream_received, true));
+ client_->WaitUntil(2000, [&stream_received]() { return stream_received; });
+ received_stream = session->AcceptIncomingUnidirectionalStream();
+ }
+ ASSERT_TRUE(received_stream != nullptr);
+ std::string received_data;
+ WebTransportStream::ReadResult result = received_stream->Read(&received_data);
+ EXPECT_EQ(received_data, "b");
+ EXPECT_TRUE(result.fin);
+}
+
TEST_P(EndToEndTest, WebTransportSessionSetupWithEchoWithSuffix) {
enable_web_transport_ = true;
ASSERT_TRUE(Initialize());
diff --git a/quiche/quic/test_tools/quic_test_backend.cc b/quiche/quic/test_tools/quic_test_backend.cc
index 93e7e44..c40db49 100644
--- a/quiche/quic/test_tools/quic_test_backend.cc
+++ b/quiche/quic/test_tools/quic_test_backend.cc
@@ -10,6 +10,8 @@
#include <utility>
#include <vector>
+#include "absl/status/statusor.h"
+#include "absl/strings/numbers.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_split.h"
#include "absl/strings/string_view.h"
@@ -17,6 +19,9 @@
#include "quiche/quic/test_tools/web_transport_resets_backend.h"
#include "quiche/quic/tools/web_transport_test_visitors.h"
#include "quiche/common/platform/api/quiche_googleurl.h"
+#include "quiche/web_transport/complete_buffer_visitor.h"
+#include "quiche/web_transport/web_transport.h"
+#include "quiche/web_transport/web_transport_headers.h"
namespace quic {
namespace test {
@@ -69,6 +74,40 @@
WebTransportSession* session_; // Not owned.
};
+// SubprotocolStreamVisitor opens one stream that contains the selected
+// subprotocol.
+class SubprotocolStreamVisitor : public WebTransportVisitor {
+ public:
+ SubprotocolStreamVisitor(WebTransportSession* session) : session_(session) {}
+
+ void OnSessionReady() override {
+ OnCanCreateNewOutgoingUnidirectionalStream();
+ }
+ void OnSessionClosed(WebTransportSessionError /*error_code*/,
+ const std::string& /*error_message*/) override {}
+ void OnIncomingBidirectionalStreamAvailable() override {}
+ void OnIncomingUnidirectionalStreamAvailable() override {}
+ void OnDatagramReceived(absl::string_view /*datagram*/) override {}
+ void OnCanCreateNewOutgoingBidirectionalStream() override {}
+ void OnCanCreateNewOutgoingUnidirectionalStream() override {
+ if (sent_) {
+ return;
+ }
+ webtransport::Stream* stream = session_->OpenOutgoingUnidirectionalStream();
+ if (stream == nullptr) {
+ return;
+ }
+ stream->SetVisitor(std::make_unique<webtransport::CompleteBufferVisitor>(
+ stream, session_->GetNegotiatedSubprotocol().value_or("[none]")));
+ stream->visitor()->OnCanWrite();
+ sent_ = true;
+ }
+
+ private:
+ WebTransportSession* session_; // Not owned.
+ bool sent_ = false;
+};
+
} // namespace
QuicSimpleServerBackend::WebTransportResponse
@@ -119,6 +158,38 @@
response.visitor = std::make_unique<SessionCloseVisitor>(session);
return response;
}
+ if (path == "/selected-subprotocol") {
+ auto subprotocol_it =
+ request_headers.find(webtransport::kSubprotocolRequestHeader);
+ if (subprotocol_it == request_headers.end()) {
+ WebTransportResponse response;
+ response.response_headers[":status"] = "400";
+ return response;
+ }
+ absl::StatusOr<std::vector<std::string>> subprotocols =
+ webtransport::ParseSubprotocolRequestHeader(subprotocol_it->second);
+ if (!subprotocols.ok() || subprotocols->empty()) {
+ WebTransportResponse response;
+ response.response_headers[":status"] = "400";
+ return response;
+ }
+ size_t subprotocol_index = 0;
+ auto subprotocol_index_it = request_headers.find("subprotocol-index");
+ if (subprotocol_index_it != request_headers.end()) {
+ if (!absl::SimpleAtoi(subprotocol_index_it->second, &subprotocol_index) ||
+ subprotocol_index >= subprotocols->size()) {
+ WebTransportResponse response;
+ response.response_headers[":status"] = "400";
+ return response;
+ }
+ }
+ WebTransportResponse response;
+ response.response_headers[":status"] = "200";
+ response.response_headers[webtransport::kSubprotocolResponseHeader] =
+ (*subprotocols)[subprotocol_index];
+ response.visitor = std::make_unique<SubprotocolStreamVisitor>(session);
+ return response;
+ }
WebTransportResponse response;
response.response_headers[":status"] = "404";
diff --git a/quiche/web_transport/web_transport_headers.h b/quiche/web_transport/web_transport_headers.h
index c43fcf3..9ed1674 100644
--- a/quiche/web_transport/web_transport_headers.h
+++ b/quiche/web_transport/web_transport_headers.h
@@ -17,8 +17,8 @@
namespace webtransport {
inline constexpr absl::string_view kSubprotocolRequestHeader =
- "WT-Available-Protocols";
-inline constexpr absl::string_view kSubprotocolResponseHeader = "WT-Protocol";
+ "wt-available-protocols";
+inline constexpr absl::string_view kSubprotocolResponseHeader = "wt-protocol";
QUICHE_EXPORT absl::StatusOr<std::vector<std::string>>
ParseSubprotocolRequestHeader(absl::string_view value);