Add an end-to-end test for WebTransport subprotocol negotiation.

PiperOrigin-RevId: 748789470
diff --git a/quiche/quic/core/http/end_to_end_test.cc b/quiche/quic/core/http/end_to_end_test.cc
index 0d1cd23..0b4d55e 100644
--- a/quiche/quic/core/http/end_to_end_test.cc
+++ b/quiche/quic/core/http/end_to_end_test.cc
@@ -93,6 +93,7 @@
 #include "quiche/common/platform/api/quiche_test.h"
 #include "quiche/common/quiche_stream.h"
 #include "quiche/common/test_tools/quiche_test_utils.h"
+#include "quiche/web_transport/web_transport_headers.h"
 
 using quiche::HttpHeaderBlock;
 using spdy::SpdyFramer;
@@ -854,7 +855,8 @@
 
   WebTransportHttp3* CreateWebTransportSession(
       const std::string& path, bool wait_for_server_response,
-      QuicSpdyStream** connect_stream_out = nullptr) {
+      std::initializer_list<std::pair<absl::string_view, absl::string_view>>
+          extra_headers = {}) {
     // Wait until we receive the settings from the server indicating
     // WebTransport support.
     client_->WaitUntil(
@@ -869,6 +871,9 @@
     headers[":path"] = path;
     headers[":method"] = "CONNECT";
     headers[":protocol"] = "webtransport";
+    for (const auto& [key, value] : extra_headers) {
+      headers[key] = std::string(value);
+    }
 
     client_->SendMessage(headers, "", /*fin=*/false);
     QuicSpdyStream* stream = client_->latest_created_stream();
@@ -886,9 +891,6 @@
                          [stream]() { return stream->headers_decompressed(); });
       EXPECT_TRUE(session->ready());
     }
-    if (connect_stream_out != nullptr) {
-      *connect_stream_out = stream;
-    }
     return session;
   }
 
@@ -7180,6 +7182,40 @@
   server_thread_->Resume();
 }
 
+TEST_P(EndToEndTest, WebTransportSessionProtocolNegotiation) {
+  enable_web_transport_ = true;
+  ASSERT_TRUE(Initialize());
+
+  if (!version_.UsesHttp3()) {
+    return;
+  }
+
+  WebTransportHttp3* session = CreateWebTransportSession(
+      "/selected-subprotocol", /*wait_for_server_response=*/true,
+      {{webtransport::kSubprotocolRequestHeader, "a, b, c, d"},
+       {"subprotocol-index", "1"}});
+  ASSERT_NE(session, nullptr);
+  NiceMock<MockWebTransportSessionVisitor>& visitor =
+      SetupWebTransportVisitor(session);
+  EXPECT_EQ(session->GetNegotiatedSubprotocol(), "b");
+
+  WebTransportStream* received_stream =
+      session->AcceptIncomingUnidirectionalStream();
+  if (received_stream == nullptr) {
+    // Retry if reordering happens.
+    bool stream_received = false;
+    EXPECT_CALL(visitor, OnIncomingUnidirectionalStreamAvailable())
+        .WillOnce(Assign(&stream_received, true));
+    client_->WaitUntil(2000, [&stream_received]() { return stream_received; });
+    received_stream = session->AcceptIncomingUnidirectionalStream();
+  }
+  ASSERT_TRUE(received_stream != nullptr);
+  std::string received_data;
+  WebTransportStream::ReadResult result = received_stream->Read(&received_data);
+  EXPECT_EQ(received_data, "b");
+  EXPECT_TRUE(result.fin);
+}
+
 TEST_P(EndToEndTest, WebTransportSessionSetupWithEchoWithSuffix) {
   enable_web_transport_ = true;
   ASSERT_TRUE(Initialize());
diff --git a/quiche/quic/test_tools/quic_test_backend.cc b/quiche/quic/test_tools/quic_test_backend.cc
index 93e7e44..c40db49 100644
--- a/quiche/quic/test_tools/quic_test_backend.cc
+++ b/quiche/quic/test_tools/quic_test_backend.cc
@@ -10,6 +10,8 @@
 #include <utility>
 #include <vector>
 
+#include "absl/status/statusor.h"
+#include "absl/strings/numbers.h"
 #include "absl/strings/str_cat.h"
 #include "absl/strings/str_split.h"
 #include "absl/strings/string_view.h"
@@ -17,6 +19,9 @@
 #include "quiche/quic/test_tools/web_transport_resets_backend.h"
 #include "quiche/quic/tools/web_transport_test_visitors.h"
 #include "quiche/common/platform/api/quiche_googleurl.h"
+#include "quiche/web_transport/complete_buffer_visitor.h"
+#include "quiche/web_transport/web_transport.h"
+#include "quiche/web_transport/web_transport_headers.h"
 
 namespace quic {
 namespace test {
@@ -69,6 +74,40 @@
   WebTransportSession* session_;  // Not owned.
 };
 
+// SubprotocolStreamVisitor opens one stream that contains the selected
+// subprotocol.
+class SubprotocolStreamVisitor : public WebTransportVisitor {
+ public:
+  SubprotocolStreamVisitor(WebTransportSession* session) : session_(session) {}
+
+  void OnSessionReady() override {
+    OnCanCreateNewOutgoingUnidirectionalStream();
+  }
+  void OnSessionClosed(WebTransportSessionError /*error_code*/,
+                       const std::string& /*error_message*/) override {}
+  void OnIncomingBidirectionalStreamAvailable() override {}
+  void OnIncomingUnidirectionalStreamAvailable() override {}
+  void OnDatagramReceived(absl::string_view /*datagram*/) override {}
+  void OnCanCreateNewOutgoingBidirectionalStream() override {}
+  void OnCanCreateNewOutgoingUnidirectionalStream() override {
+    if (sent_) {
+      return;
+    }
+    webtransport::Stream* stream = session_->OpenOutgoingUnidirectionalStream();
+    if (stream == nullptr) {
+      return;
+    }
+    stream->SetVisitor(std::make_unique<webtransport::CompleteBufferVisitor>(
+        stream, session_->GetNegotiatedSubprotocol().value_or("[none]")));
+    stream->visitor()->OnCanWrite();
+    sent_ = true;
+  }
+
+ private:
+  WebTransportSession* session_;  // Not owned.
+  bool sent_ = false;
+};
+
 }  // namespace
 
 QuicSimpleServerBackend::WebTransportResponse
@@ -119,6 +158,38 @@
     response.visitor = std::make_unique<SessionCloseVisitor>(session);
     return response;
   }
+  if (path == "/selected-subprotocol") {
+    auto subprotocol_it =
+        request_headers.find(webtransport::kSubprotocolRequestHeader);
+    if (subprotocol_it == request_headers.end()) {
+      WebTransportResponse response;
+      response.response_headers[":status"] = "400";
+      return response;
+    }
+    absl::StatusOr<std::vector<std::string>> subprotocols =
+        webtransport::ParseSubprotocolRequestHeader(subprotocol_it->second);
+    if (!subprotocols.ok() || subprotocols->empty()) {
+      WebTransportResponse response;
+      response.response_headers[":status"] = "400";
+      return response;
+    }
+    size_t subprotocol_index = 0;
+    auto subprotocol_index_it = request_headers.find("subprotocol-index");
+    if (subprotocol_index_it != request_headers.end()) {
+      if (!absl::SimpleAtoi(subprotocol_index_it->second, &subprotocol_index) ||
+          subprotocol_index >= subprotocols->size()) {
+        WebTransportResponse response;
+        response.response_headers[":status"] = "400";
+        return response;
+      }
+    }
+    WebTransportResponse response;
+    response.response_headers[":status"] = "200";
+    response.response_headers[webtransport::kSubprotocolResponseHeader] =
+        (*subprotocols)[subprotocol_index];
+    response.visitor = std::make_unique<SubprotocolStreamVisitor>(session);
+    return response;
+  }
 
   WebTransportResponse response;
   response.response_headers[":status"] = "404";
diff --git a/quiche/web_transport/web_transport_headers.h b/quiche/web_transport/web_transport_headers.h
index c43fcf3..9ed1674 100644
--- a/quiche/web_transport/web_transport_headers.h
+++ b/quiche/web_transport/web_transport_headers.h
@@ -17,8 +17,8 @@
 namespace webtransport {
 
 inline constexpr absl::string_view kSubprotocolRequestHeader =
-    "WT-Available-Protocols";
-inline constexpr absl::string_view kSubprotocolResponseHeader = "WT-Protocol";
+    "wt-available-protocols";
+inline constexpr absl::string_view kSubprotocolResponseHeader = "wt-protocol";
 
 QUICHE_EXPORT absl::StatusOr<std::vector<std::string>>
 ParseSubprotocolRequestHeader(absl::string_view value);