Implement the lower-level part of WebTransport subprotocol negotiation.

This is relatively easy, since due to the way our implementation is structured, the headers are entirely up to the API user; we just need to store the actual selected value so that we can pass it to the application.

PiperOrigin-RevId: 746142355
diff --git a/quiche/quic/core/http/quic_spdy_stream.cc b/quiche/quic/core/http/quic_spdy_stream.cc
index 8b5894e..70a6b2e 100644
--- a/quiche/quic/core/http/quic_spdy_stream.cc
+++ b/quiche/quic/core/http/quic_spdy_stream.cc
@@ -9,8 +9,10 @@
 #include <optional>
 #include <string>
 #include <utility>
+#include <vector>
 
 #include "absl/base/macros.h"
+#include "absl/status/statusor.h"
 #include "absl/strings/numbers.h"
 #include "absl/strings/str_cat.h"
 #include "absl/strings/string_view.h"
@@ -39,10 +41,12 @@
 #include "quiche/quic/platform/api/quic_logging.h"
 #include "quiche/quic/platform/api/quic_testvalue.h"
 #include "quiche/common/capsule.h"
+#include "quiche/common/http/http_header_block.h"
 #include "quiche/common/platform/api/quiche_flag_utils.h"
 #include "quiche/common/platform/api/quiche_logging.h"
 #include "quiche/common/quiche_mem_slice_storage.h"
 #include "quiche/common/quiche_text_utils.h"
+#include "quiche/web_transport/web_transport_headers.h"
 
 using ::quiche::Capsule;
 using ::quiche::CapsuleType;
@@ -1426,6 +1430,9 @@
     return;
   }
   if (session()->perspective() != Perspective::IS_CLIENT) {
+    if (web_transport_ != nullptr) {
+      web_transport_->MaybeSetSubprotocolFromResponseHeaders(headers);
+    }
     return;
   }
   QUICHE_DCHECK(IsValidWebTransportSessionId(id(), version()));
@@ -1446,6 +1453,23 @@
 
   web_transport_ =
       std::make_unique<WebTransportHttp3>(spdy_session_, this, id());
+
+  // Store the offered subprotocols so that we can later validate the
+  // server-selected one against those.
+  const auto subprotocol_offer_it =
+      headers.find(webtransport::kSubprotocolRequestHeader);
+  if (subprotocol_offer_it != headers.end()) {
+    absl::StatusOr<std::vector<std::string>> subprotocols_offered =
+        webtransport::ParseSubprotocolRequestHeader(
+            subprotocol_offer_it->second);
+    if (subprotocols_offered.ok()) {
+      web_transport_->set_subprotocols_offered(
+          *std::move(subprotocols_offered));
+    } else {
+      QUIC_DLOG(WARNING) << "Attempting to send WebTransport subprotocols that "
+                            "cannot be parsed.";
+    }
+  }
 }
 
 void QuicSpdyStream::OnCanWriteNewData() {
diff --git a/quiche/quic/core/http/quic_spdy_stream_test.cc b/quiche/quic/core/http/quic_spdy_stream_test.cc
index be64b2c..1536a57 100644
--- a/quiche/quic/core/http/quic_spdy_stream_test.cc
+++ b/quiche/quic/core/http/quic_spdy_stream_test.cc
@@ -26,6 +26,7 @@
 #include "quiche/quic/core/http/web_transport_http3.h"
 #include "quiche/quic/core/qpack/value_splitting_header_list.h"
 #include "quiche/quic/core/quic_connection.h"
+#include "quiche/quic/core/quic_stream_priority.h"
 #include "quiche/quic/core/quic_stream_sequencer_buffer.h"
 #include "quiche/quic/core/quic_utils.h"
 #include "quiche/quic/core/quic_versions.h"
@@ -44,6 +45,7 @@
 #include "quiche/quic/test_tools/quic_stream_peer.h"
 #include "quiche/quic/test_tools/quic_test_utils.h"
 #include "quiche/common/capsule.h"
+#include "quiche/common/http/http_header_block.h"
 #include "quiche/common/quiche_ip_address.h"
 #include "quiche/common/quiche_mem_slice_storage.h"
 #include "quiche/common/simple_buffer_allocator.h"
@@ -3287,7 +3289,7 @@
   EXPECT_FALSE(stream_->write_side_closed());
 }
 
-TEST_P(QuicSpdyStreamTest, ProcessOutgoingWebTransportHeaders) {
+TEST_P(QuicSpdyStreamTest, ProcessWebTransportHeadersAsClient) {
   if (!UsesHttp3()) {
     return;
   }
@@ -3304,20 +3306,97 @@
   EXPECT_CALL(*session_, WritevData(stream_->id(), _, _, _, _, _))
       .Times(AnyNumber());
 
-  quiche::HttpHeaderBlock headers;
-  headers[":method"] = "CONNECT";
-  headers[":protocol"] = "webtransport";
-  stream_->WriteHeaders(std::move(headers), /*fin=*/false, nullptr);
+  quiche::HttpHeaderBlock request_headers;
+  request_headers[":method"] = "CONNECT";
+  request_headers[":protocol"] = "webtransport";
+  request_headers["wt-available-protocols"] = "moqt-00, moqt-01; foo=bar";
+  stream_->WriteHeaders(std::move(request_headers), /*fin=*/false, nullptr);
   ASSERT_TRUE(stream_->web_transport() != nullptr);
   EXPECT_EQ(stream_->id(), stream_->web_transport()->id());
+  EXPECT_THAT(stream_->web_transport()->subprotocols_offered(),
+              ElementsAre("moqt-00", "moqt-01"));
+
+  quiche::HttpHeaderBlock response_headers;
+  response_headers[":status"] = "200";
+  response_headers["wt-protocol"] = "moqt-01";
+  stream_->web_transport()->HeadersReceived(response_headers);
+  EXPECT_EQ(stream_->web_transport()->rejection_reason(),
+            WebTransportHttp3RejectionReason::kNone);
+  EXPECT_EQ(stream_->web_transport()->GetNegotiatedSubprotocol(), "moqt-01");
 }
 
-TEST_P(QuicSpdyStreamTest, ProcessIncomingWebTransportHeaders) {
+TEST_P(QuicSpdyStreamTest, WebTransportRejectSubprotocolsThatWereNotOffered) {
   if (!UsesHttp3()) {
     return;
   }
 
-  Initialize(kShouldProcessData);
+  InitializeWithPerspective(kShouldProcessData, Perspective::IS_CLIENT);
+  session_->set_local_http_datagram_support(HttpDatagramSupport::kRfc);
+  session_->EnableWebTransport();
+  session_->OnSetting(SETTINGS_ENABLE_CONNECT_PROTOCOL, 1);
+  QuicSpdySessionPeer::EnableWebTransport(session_.get());
+  QuicSpdySessionPeer::SetHttpDatagramSupport(session_.get(),
+                                              HttpDatagramSupport::kRfc);
+
+  EXPECT_CALL(*stream_, WriteHeadersMock(false));
+  EXPECT_CALL(*session_, WritevData(stream_->id(), _, _, _, _, _))
+      .Times(AnyNumber());
+
+  quiche::HttpHeaderBlock request_headers;
+  request_headers[":method"] = "CONNECT";
+  request_headers[":protocol"] = "webtransport";
+  request_headers["wt-available-protocols"] = "moqt-00, moqt-01; foo=bar";
+  stream_->WriteHeaders(std::move(request_headers), /*fin=*/false, nullptr);
+  ASSERT_TRUE(stream_->web_transport() != nullptr);
+
+  quiche::HttpHeaderBlock response_headers;
+  response_headers[":status"] = "200";
+  response_headers["wt-protocol"] = "moqt-02";
+  stream_->web_transport()->HeadersReceived(response_headers);
+  EXPECT_EQ(stream_->web_transport()->rejection_reason(),
+            WebTransportHttp3RejectionReason::kSubprotocolMismatch);
+  EXPECT_EQ(stream_->web_transport()->GetNegotiatedSubprotocol(), std::nullopt);
+}
+
+TEST_P(QuicSpdyStreamTest, WebTransportInvalidSubprotocolResponse) {
+  if (!UsesHttp3()) {
+    return;
+  }
+
+  InitializeWithPerspective(kShouldProcessData, Perspective::IS_CLIENT);
+  session_->set_local_http_datagram_support(HttpDatagramSupport::kRfc);
+  session_->EnableWebTransport();
+  session_->OnSetting(SETTINGS_ENABLE_CONNECT_PROTOCOL, 1);
+  QuicSpdySessionPeer::EnableWebTransport(session_.get());
+  QuicSpdySessionPeer::SetHttpDatagramSupport(session_.get(),
+                                              HttpDatagramSupport::kRfc);
+
+  EXPECT_CALL(*stream_, WriteHeadersMock(false));
+  EXPECT_CALL(*session_, WritevData(stream_->id(), _, _, _, _, _))
+      .Times(AnyNumber());
+
+  quiche::HttpHeaderBlock request_headers;
+  request_headers[":method"] = "CONNECT";
+  request_headers[":protocol"] = "webtransport";
+  request_headers["wt-available-protocols"] = "moqt-00, moqt-01; foo=bar";
+  stream_->WriteHeaders(std::move(request_headers), /*fin=*/false, nullptr);
+  ASSERT_TRUE(stream_->web_transport() != nullptr);
+
+  quiche::HttpHeaderBlock response_headers;
+  response_headers[":status"] = "200";
+  response_headers["wt-protocol"] = "12345.67";
+  stream_->web_transport()->HeadersReceived(response_headers);
+  EXPECT_EQ(stream_->web_transport()->rejection_reason(),
+            WebTransportHttp3RejectionReason::kSubprotocolParseError);
+  EXPECT_EQ(stream_->web_transport()->GetNegotiatedSubprotocol(), std::nullopt);
+}
+
+TEST_P(QuicSpdyStreamTest, ProcessWebTransportHeadersAsServer) {
+  if (!UsesHttp3()) {
+    return;
+  }
+
+  InitializeWithPerspective(kShouldProcessData, Perspective::IS_SERVER);
   session_->set_local_http_datagram_support(HttpDatagramSupport::kRfc);
   session_->EnableWebTransport();
   QuicSpdySessionPeer::EnableWebTransport(session_.get());
@@ -3326,6 +3405,7 @@
 
   headers_[":method"] = "CONNECT";
   headers_[":protocol"] = "webtransport";
+  headers_["wt-available-protocols"] = "moqt-00, moqt-01; foo=bar";
 
   stream_->OnStreamHeadersPriority(
       spdy::SpdyStreamPrecedence(kV3HighestPriority));
@@ -3335,6 +3415,15 @@
   EXPECT_FALSE(stream_->IsDoneReading());
   ASSERT_TRUE(stream_->web_transport() != nullptr);
   EXPECT_EQ(stream_->id(), stream_->web_transport()->id());
+
+  EXPECT_CALL(*stream_, WriteHeadersMock(false));
+  EXPECT_CALL(*session_, WritevData(stream_->id(), _, _, _, _, _))
+      .Times(AnyNumber());
+  quiche::HttpHeaderBlock response_headers;
+  response_headers[":status"] = "200";
+  response_headers["wt-protocol"] = "moqt-01";
+  stream_->WriteHeaders(std::move(response_headers), /*fin=*/false, nullptr);
+  EXPECT_EQ(stream_->web_transport()->GetNegotiatedSubprotocol(), "moqt-01");
 }
 
 TEST_P(QuicSpdyStreamTest, IncomingWebTransportStreamWhenUnsupported) {
diff --git a/quiche/quic/core/http/web_transport_http3.cc b/quiche/quic/core/http/web_transport_http3.cc
index d0687ee..86d6086 100644
--- a/quiche/quic/core/http/web_transport_http3.cc
+++ b/quiche/quic/core/http/web_transport_http3.cc
@@ -12,6 +12,8 @@
 #include <vector>
 
 
+#include "absl/algorithm/container.h"
+#include "absl/status/statusor.h"
 #include "absl/strings/string_view.h"
 #include "quiche/quic/core/http/quic_spdy_session.h"
 #include "quiche/quic/core/http/quic_spdy_stream.h"
@@ -24,8 +26,10 @@
 #include "quiche/quic/core/quic_versions.h"
 #include "quiche/quic/platform/api/quic_bug_tracker.h"
 #include "quiche/common/capsule.h"
+#include "quiche/common/http/http_header_block.h"
 #include "quiche/common/platform/api/quiche_logging.h"
 #include "quiche/web_transport/web_transport.h"
+#include "quiche/web_transport/web_transport_headers.h"
 
 #define ENDPOINT \
   (session_->perspective() == Perspective::IS_SERVER ? "Server: " : "Client: ")
@@ -184,6 +188,12 @@
       rejection_reason_ = WebTransportHttp3RejectionReason::kWrongStatusCode;
       return;
     }
+    WebTransportHttp3RejectionReason subprotocol_result =
+        MaybeSetSubprotocolFromResponseHeaders(headers);
+    if (subprotocol_result != WebTransportHttp3RejectionReason::kNone) {
+      rejection_reason_ = subprotocol_result;
+      return;
+    }
   }
 
   QUIC_DVLOG(1) << ENDPOINT << "WebTransport session " << id_ << " ready.";
@@ -473,4 +483,34 @@
          webtransport_error_code / 0x1e;
 }
 
+WebTransportHttp3RejectionReason
+WebTransportHttp3::MaybeSetSubprotocolFromResponseHeaders(
+    const quiche::HttpHeaderBlock& headers) {
+  auto subprotocol_it = headers.find(webtransport::kSubprotocolResponseHeader);
+  if (subprotocol_it == headers.end()) {
+    return WebTransportHttp3RejectionReason::kNone;
+  }
+
+  absl::StatusOr<std::string> subprotocol =
+      webtransport::ParseSubprotocolResponseHeader(subprotocol_it->second);
+  if (!subprotocol.ok()) {
+    QUIC_DVLOG(1) << ENDPOINT
+                  << "WebTransport server has malformed WT-Protocol "
+                     "header, rejecting.";
+    return WebTransportHttp3RejectionReason::kSubprotocolParseError;
+  }
+
+  if (session_->perspective() == Perspective::IS_CLIENT &&
+      !absl::c_linear_search(subprotocols_offered_, *subprotocol)) {
+    QUIC_DVLOG(1) << ENDPOINT
+                  << "WebTransport server has offered a subprotocol value \""
+                  << *subprotocol
+                  << "\", which was not one of the ones offered, rejecting.";
+    return WebTransportHttp3RejectionReason::kSubprotocolMismatch;
+  }
+
+  subprotocol_selected_ = *std::move(subprotocol);
+  return WebTransportHttp3RejectionReason::kNone;
+}
+
 }  // namespace quic
diff --git a/quiche/quic/core/http/web_transport_http3.h b/quiche/quic/core/http/web_transport_http3.h
index b6d5110..c9950b6 100644
--- a/quiche/quic/core/http/web_transport_http3.h
+++ b/quiche/quic/core/http/web_transport_http3.h
@@ -7,9 +7,13 @@
 
 #include <memory>
 #include <optional>
+#include <string>
+#include <utility>
+#include <vector>
 
 #include "absl/base/attributes.h"
 #include "absl/container/flat_hash_set.h"
+#include "absl/strings/string_view.h"
 #include "absl/time/time.h"
 #include "quiche/quic/core/http/quic_spdy_session.h"
 #include "quiche/quic/core/http/web_transport_stream_adapter.h"
@@ -34,6 +38,8 @@
   kWrongStatusCode,
   kMissingDraftVersion,
   kUnsupportedDraftVersion,
+  kSubprotocolMismatch,
+  kSubprotocolParseError,
 };
 
 // A session of WebTransport over HTTP/3.  The session is owned by
@@ -117,6 +123,18 @@
   void OnGoAwayReceived();
   void OnDrainSessionReceived();
 
+  const std::vector<std::string>& subprotocols_offered() const {
+    return subprotocols_offered_;
+  }
+  void set_subprotocols_offered(std::vector<std::string> subprotocols_offered) {
+    subprotocols_offered_ = std::move(subprotocols_offered);
+  }
+  std::optional<std::string> GetNegotiatedSubprotocol() const override {
+    return subprotocol_selected_;
+  }
+  WebTransportHttp3RejectionReason MaybeSetSubprotocolFromResponseHeaders(
+      const quiche::HttpHeaderBlock& headers);
+
  private:
   // Notifies the visitor that the connection has been closed.  Ensures that the
   // visitor is only ever called once.
@@ -136,6 +154,12 @@
   bool close_received_ = false;
   bool close_notified_ = false;
 
+  // On client side, stores the offered subprotocols.
+  std::vector<std::string> subprotocols_offered_;
+  // Stores the actually selected subprotocol, both on the client and on the
+  // server.
+  std::optional<std::string> subprotocol_selected_;
+
   quiche::SingleUseCallback<void()> drain_callback_ = nullptr;
 
   WebTransportHttp3RejectionReason rejection_reason_ =
diff --git a/quiche/quic/core/quic_generic_session.h b/quiche/quic/core/quic_generic_session.h
index 384f0a3..d8c220a 100644
--- a/quiche/quic/core/quic_generic_session.h
+++ b/quiche/quic/core/quic_generic_session.h
@@ -7,6 +7,9 @@
 
 #include <cstdint>
 #include <memory>
+#include <optional>
+#include <string>
+#include <vector>
 
 #include "absl/algorithm/container.h"
 #include "absl/strings/string_view.h"
@@ -106,6 +109,9 @@
   }
   void NotifySessionDraining() override {}
   void SetOnDraining(quiche::SingleUseCallback<void()>) override {}
+  std::optional<std::string> GetNegotiatedSubprotocol() const override {
+    return alpn_;
+  }
 
   void CloseSession(webtransport::SessionErrorCode error_code,
                     absl::string_view error_message) override {
diff --git a/quiche/web_transport/encapsulated/encapsulated_web_transport.cc b/quiche/web_transport/encapsulated/encapsulated_web_transport.cc
index 30ab1d4..7371945 100644
--- a/quiche/web_transport/encapsulated/encapsulated_web_transport.cc
+++ b/quiche/web_transport/encapsulated/encapsulated_web_transport.cc
@@ -794,4 +794,11 @@
   QUICHE_BUG_IF(EncapsulatedWebTransport_SetPriority_order, !status.ok())
       << status;
 }
+
+std::optional<std::string> EncapsulatedSession::GetNegotiatedSubprotocol()
+    const {
+  // TODO: implement.
+  return std::nullopt;
+}
+
 }  // namespace webtransport
diff --git a/quiche/web_transport/encapsulated/encapsulated_web_transport.h b/quiche/web_transport/encapsulated/encapsulated_web_transport.h
index 0fe2b97..e4e7f7a 100644
--- a/quiche/web_transport/encapsulated/encapsulated_web_transport.h
+++ b/quiche/web_transport/encapsulated/encapsulated_web_transport.h
@@ -102,6 +102,7 @@
   SessionStats GetSessionStats() override;
   void NotifySessionDraining() override;
   void SetOnDraining(quiche::SingleUseCallback<void()> callback) override;
+  std::optional<std::string> GetNegotiatedSubprotocol() const override;
 
   // quiche::WriteStreamVisitor implementation.
   void OnCanWrite() override;
diff --git a/quiche/web_transport/test_tools/mock_web_transport.h b/quiche/web_transport/test_tools/mock_web_transport.h
index 4cb3aaf..656c25c 100644
--- a/quiche/web_transport/test_tools/mock_web_transport.h
+++ b/quiche/web_transport/test_tools/mock_web_transport.h
@@ -10,6 +10,7 @@
 #include <cstddef>
 #include <cstdint>
 #include <memory>
+#include <optional>
 #include <string>
 
 #include "absl/status/status.h"
@@ -94,6 +95,8 @@
   MOCK_METHOD(void, NotifySessionDraining, (), (override));
   MOCK_METHOD(void, SetOnDraining, (quiche::SingleUseCallback<void()>),
               (override));
+  MOCK_METHOD(std::optional<std::string>, GetNegotiatedSubprotocol, (),
+              (const, override));
 };
 
 }  // namespace test
diff --git a/quiche/web_transport/web_transport.h b/quiche/web_transport/web_transport.h
index 2c558ad..5cd92a0 100644
--- a/quiche/web_transport/web_transport.h
+++ b/quiche/web_transport/web_transport.h
@@ -11,6 +11,7 @@
 #include <cstddef>
 #include <cstdint>
 #include <memory>
+#include <optional>
 #include <string>
 
 // The dependencies of this API should be kept minimal and independent of
@@ -260,6 +261,10 @@
   // capsule), or the underlying connection (HTTP GOAWAY) is being drained by
   // the peer.
   virtual void SetOnDraining(quiche::SingleUseCallback<void()> callback) = 0;
+
+  // Returns the negotiated subprotocol, or std::nullopt, if none was
+  // negotiated.
+  virtual std::optional<std::string> GetNegotiatedSubprotocol() const = 0;
 };
 
 }  // namespace webtransport