Implement sending and receiving DRAIN_WEBTRANSPORT_SESSION

PiperOrigin-RevId: 540692201
diff --git a/quiche/common/capsule.cc b/quiche/common/capsule.cc
index 4c5ad5a..66d4a66 100644
--- a/quiche/common/capsule.cc
+++ b/quiche/common/capsule.cc
@@ -36,6 +36,8 @@
       return "LEGACY_DATAGRAM_WITHOUT_CONTEXT";
     case CapsuleType::CLOSE_WEBTRANSPORT_SESSION:
       return "CLOSE_WEBTRANSPORT_SESSION";
+    case CapsuleType::DRAIN_WEBTRANSPORT_SESSION:
+      return "DRAIN_WEBTRANSPORT_SESSION";
     case CapsuleType::ADDRESS_REQUEST:
       return "ADDRESS_REQUEST";
     case CapsuleType::ADDRESS_ASSIGN:
@@ -129,6 +131,10 @@
                       ",error_message=\"", error_message, "\")");
 }
 
+std::string DrainWebTransportSessionCapsule::ToString() const {
+  return "DRAIN_WEBTRANSPORT_SESSION()";
+}
+
 std::string AddressRequestCapsule::ToString() const {
   std::string rv = "ADDRESS_REQUEST[";
   for (auto requested_address : requested_addresses) {
@@ -293,6 +299,8 @@
           WireUint32(capsule.close_web_transport_session_capsule().error_code),
           WireBytes(
               capsule.close_web_transport_session_capsule().error_message));
+    case CapsuleType::DRAIN_WEBTRANSPORT_SESSION:
+      return SerializeCapsuleFields(capsule.capsule_type(), allocator);
     case CapsuleType::ADDRESS_REQUEST:
       return SerializeCapsuleFields(
           capsule.capsule_type(), allocator,
@@ -414,6 +422,8 @@
       capsule.error_message = reader.ReadRemainingPayload();
       return Capsule(std::move(capsule));
     }
+    case CapsuleType::DRAIN_WEBTRANSPORT_SESSION:
+      return Capsule(DrainWebTransportSessionCapsule());
     case CapsuleType::ADDRESS_REQUEST: {
       AddressRequestCapsule capsule;
       while (!reader.IsDoneReading()) {
diff --git a/quiche/common/capsule.h b/quiche/common/capsule.h
index 08bde5b..7220ee4 100644
--- a/quiche/common/capsule.h
+++ b/quiche/common/capsule.h
@@ -13,6 +13,7 @@
 #include "absl/strings/string_view.h"
 #include "absl/types/optional.h"
 #include "absl/types/variant.h"
+#include "quiche/common/platform/api/quiche_export.h"
 #include "quiche/common/platform/api/quiche_logging.h"
 #include "quiche/common/quiche_buffer_allocator.h"
 #include "quiche/common/quiche_ip_address.h"
@@ -29,6 +30,7 @@
 
   // <https://datatracker.ietf.org/doc/draft-ietf-webtrans-http3/>
   CLOSE_WEBTRANSPORT_SESSION = 0x2843,
+  DRAIN_WEBTRANSPORT_SESSION = 0x78ae,
 
   // draft-ietf-masque-connect-ip-03.
   ADDRESS_ASSIGN = 0x1ECA6A00,
@@ -106,6 +108,13 @@
            error_message == other.error_message;
   }
 };
+struct QUICHE_EXPORT DrainWebTransportSessionCapsule {
+  std::string ToString() const;
+  CapsuleType capsule_type() const {
+    return CapsuleType::DRAIN_WEBTRANSPORT_SESSION;
+  }
+  bool operator==(const DrainWebTransportSessionCapsule&) const { return true; }
+};
 
 // MASQUE CONNECT-IP.
 struct QUICHE_EXPORT PrefixWithId {
@@ -320,7 +329,8 @@
  private:
   absl::variant<DatagramCapsule, LegacyDatagramCapsule,
                 LegacyDatagramWithoutContextCapsule,
-                CloseWebTransportSessionCapsule, AddressRequestCapsule,
+                CloseWebTransportSessionCapsule,
+                DrainWebTransportSessionCapsule, AddressRequestCapsule,
                 AddressAssignCapsule, RouteAdvertisementCapsule,
                 WebTransportStreamDataCapsule, WebTransportResetStreamCapsule,
                 WebTransportStopSendingCapsule, WebTransportMaxStreamsCapsule,
diff --git a/quiche/common/capsule_test.cc b/quiche/common/capsule_test.cc
index 5ed4d1a..ae55aaf 100644
--- a/quiche/common/capsule_test.cc
+++ b/quiche/common/capsule_test.cc
@@ -136,6 +136,20 @@
   TestSerialization(expected_capsule, capsule_fragment);
 }
 
+TEST_F(CapsuleTest, DrainWebTransportStreamCapsule) {
+  std::string capsule_fragment = absl::HexStringToBytes(
+      "800078ae"  // DRAIN_WEBTRANSPORT_STREAM capsule type
+      "00"        // capsule length
+  );
+  Capsule expected_capsule = Capsule(DrainWebTransportSessionCapsule());
+  {
+    EXPECT_CALL(visitor_, OnCapsule(expected_capsule));
+    ASSERT_TRUE(capsule_parser_.IngestCapsuleFragment(capsule_fragment));
+  }
+  ValidateParserIsEmpty();
+  TestSerialization(expected_capsule, capsule_fragment);
+}
+
 TEST_F(CapsuleTest, AddressAssignCapsule) {
   std::string capsule_fragment = absl::HexStringToBytes(
       "9ECA6A00"  // ADDRESS_ASSIGN capsule type
diff --git a/quiche/common/wire_serialization.h b/quiche/common/wire_serialization.h
index 89a54f2..7cc2596 100644
--- a/quiche/common/wire_serialization.h
+++ b/quiche/common/wire_serialization.h
@@ -347,6 +347,10 @@
   QUICHE_RETURN_IF_ERROR(SerializeIntoWriterCore(writer, argno, data1));
   return SerializeIntoWriterCore(writer, argno + 1, rest...);
 }
+
+inline absl::Status SerializeIntoWriterCore(QuicheDataWriter&, int) {
+  return absl::OkStatus();
+}
 }  // namespace wire_serialization_internal
 
 // SerializeIntoWriter(writer, d1, d2, ... dN) serializes all of supplied data
@@ -369,6 +373,7 @@
 size_t ComputeLengthOnWire(T1 data1, Ts... rest) {
   return data1.GetLengthOnWire() + ComputeLengthOnWire(rest...);
 }
+inline size_t ComputeLengthOnWire() { return 0; }
 
 // SerializeIntoBuffer(allocator, d1, d2, ... dN) computes the length required
 // to store the supplied data, allocates the buffer of appropriate size using
diff --git a/quiche/common/wire_serialization_test.cc b/quiche/common/wire_serialization_test.cc
index b1dea91..9abdda9 100644
--- a/quiche/common/wire_serialization_test.cc
+++ b/quiche/common/wire_serialization_test.cc
@@ -252,5 +252,7 @@
 #endif
 }
 
+TEST(SerializationTest, Empty) { ExpectEncodingHex("nothing", ""); }
+
 }  // namespace
 }  // namespace quiche::test
diff --git a/quiche/quic/core/http/end_to_end_test.cc b/quiche/quic/core/http/end_to_end_test.cc
index 6404559..52efb1a 100644
--- a/quiche/quic/core/http/end_to_end_test.cc
+++ b/quiche/quic/core/http/end_to_end_test.cc
@@ -6804,6 +6804,29 @@
   EXPECT_TRUE(spdy_stream == nullptr);
 }
 
+TEST_P(EndToEndTest, WebTransportSessionReceiveDrain) {
+  enable_web_transport_ = true;
+  ASSERT_TRUE(Initialize());
+
+  if (!version_.UsesHttp3()) {
+    return;
+  }
+
+  WebTransportHttp3* session = CreateWebTransportSession(
+      "/session-close", /*wait_for_server_response=*/true);
+  ASSERT_TRUE(session != nullptr);
+
+  WebTransportStream* stream = session->OpenOutgoingUnidirectionalStream();
+  ASSERT_TRUE(stream != nullptr);
+  QUICHE_EXPECT_OK(quiche::WriteIntoStream(*stream, "DRAIN"));
+  EXPECT_TRUE(stream->SendFin());
+
+  bool drain_received = false;
+  session->SetOnDraining([&drain_received] { drain_received = true; });
+  client_->WaitUntil(2000, [&]() { return drain_received; });
+  EXPECT_TRUE(drain_received);
+}
+
 TEST_P(EndToEndTest, WebTransportSessionStreamTermination) {
   enable_web_transport_ = true;
   ASSERT_TRUE(Initialize());
diff --git a/quiche/quic/core/http/quic_spdy_stream.cc b/quiche/quic/core/http/quic_spdy_stream.cc
index 9a479f3..24a0a16 100644
--- a/quiche/quic/core/http/quic_spdy_stream.cc
+++ b/quiche/quic/core/http/quic_spdy_stream.cc
@@ -1397,6 +1397,14 @@
           capsule.close_web_transport_session_capsule().error_code,
           capsule.close_web_transport_session_capsule().error_message);
       return true;
+    case CapsuleType::DRAIN_WEBTRANSPORT_SESSION:
+      if (web_transport_ == nullptr) {
+        QUIC_DLOG(ERROR) << ENDPOINT << "Received capsule " << capsule
+                         << " for a non-WebTransport stream.";
+        return false;
+      }
+      web_transport_->OnDrainSessionReceived();
+      return true;
     case CapsuleType::ADDRESS_ASSIGN:
       if (connect_ip_visitor_ == nullptr) {
         return true;
diff --git a/quiche/quic/core/http/web_transport_http3.cc b/quiche/quic/core/http/web_transport_http3.cc
index 6e55cdb..6db4eae 100644
--- a/quiche/quic/core/http/web_transport_http3.cc
+++ b/quiche/quic/core/http/web_transport_http3.cc
@@ -276,6 +276,14 @@
   connect_stream_->SetMaxDatagramTimeInQueue(QuicTimeDelta(max_time_in_queue));
 }
 
+void WebTransportHttp3::NotifySessionDraining() {
+  if (!drain_sent_) {
+    connect_stream_->WriteCapsule(
+        quiche::Capsule(quiche::DrainWebTransportSessionCapsule()));
+    drain_sent_ = true;
+  }
+}
+
 void WebTransportHttp3::OnHttp3Datagram(QuicStreamId stream_id,
                                         absl::string_view payload) {
   QUICHE_DCHECK_EQ(stream_id, connect_stream_->id());
@@ -297,6 +305,8 @@
   }
 }
 
+void WebTransportHttp3::OnDrainSessionReceived() { OnGoAwayReceived(); }
+
 WebTransportHttp3UnidirectionalStream::WebTransportHttp3UnidirectionalStream(
     PendingStream* pending, QuicSpdySession* session)
     : QuicStream(pending, session, /*is_static=*/false),
diff --git a/quiche/quic/core/http/web_transport_http3.h b/quiche/quic/core/http/web_transport_http3.h
index eb54c06..e74335e 100644
--- a/quiche/quic/core/http/web_transport_http3.h
+++ b/quiche/quic/core/http/web_transport_http3.h
@@ -90,6 +90,7 @@
   QuicByteCount GetMaxDatagramSize() const override;
   void SetDatagramMaxTimeInQueue(absl::Duration max_time_in_queue) override;
 
+  void NotifySessionDraining() override;
   void SetOnDraining(quiche::SingleUseCallback<void()> callback) override {
     drain_callback_ = std::move(callback);
   }
@@ -106,6 +107,7 @@
   }
 
   void OnGoAwayReceived();
+  void OnDrainSessionReceived();
 
  private:
   // Notifies the visitor that the connection has been closed.  Ensures that the
@@ -130,6 +132,7 @@
 
   WebTransportHttp3RejectionReason rejection_reason_ =
       WebTransportHttp3RejectionReason::kNone;
+  bool drain_sent_ = false;
   // Those are set to default values, which are used if the session is not
   // closed cleanly using an appropriate capsule.
   WebTransportSessionError error_code_ = 0;
diff --git a/quiche/quic/test_tools/quic_test_backend.cc b/quiche/quic/test_tools/quic_test_backend.cc
index be8068b..4383f23 100644
--- a/quiche/quic/test_tools/quic_test_backend.cc
+++ b/quiche/quic/test_tools/quic_test_backend.cc
@@ -23,7 +23,8 @@
 // sends a unidirectional stream of format "code message" to this endpoint, it
 // will close the session with the corresponding error code and error message.
 // For instance, sending "42 test error" will cause it to be closed with code 42
-// and message "test error".
+// and message "test error".  As a special case, sending "DRAIN" would result in
+// a DRAIN_WEBTRANSPORT_SESSION capsule being sent.
 class SessionCloseVisitor : public WebTransportVisitor {
  public:
   SessionCloseVisitor(WebTransportSession* session) : session_(session) {}
@@ -41,6 +42,10 @@
     stream->SetVisitor(
         std::make_unique<WebTransportUnidirectionalEchoReadVisitor>(
             stream, [this](const std::string& data) {
+              if (data == "DRAIN") {
+                session_->NotifySessionDraining();
+                return;
+              }
               std::pair<absl::string_view, absl::string_view> parsed =
                   absl::StrSplit(data, absl::MaxSplits(' ', 1));
               WebTransportSessionError error_code = 0;
diff --git a/quiche/web_transport/test_tools/mock_web_transport.h b/quiche/web_transport/test_tools/mock_web_transport.h
index 2614581..e102a08 100644
--- a/quiche/web_transport/test_tools/mock_web_transport.h
+++ b/quiche/web_transport/test_tools/mock_web_transport.h
@@ -75,6 +75,7 @@
   MOCK_METHOD(uint64_t, GetMaxDatagramSize, (), (const, override));
   MOCK_METHOD(void, SetDatagramMaxTimeInQueue,
               (absl::Duration max_time_in_queue), (override));
+  MOCK_METHOD(void, NotifySessionDraining, (), (override));
   MOCK_METHOD(void, SetOnDraining, (quiche::SingleUseCallback<void()>),
               (override));
 };
diff --git a/quiche/web_transport/web_transport.h b/quiche/web_transport/web_transport.h
index 20598f1..6cb7811 100644
--- a/quiche/web_transport/web_transport.h
+++ b/quiche/web_transport/web_transport.h
@@ -215,6 +215,9 @@
   // being silently dropped.
   virtual void SetDatagramMaxTimeInQueue(absl::Duration max_time_in_queue) = 0;
 
+  // Sends a DRAIN_WEBTRANSPORT_SESSION capsule or an equivalent signal to the
+  // peer indicating that the session is draining.
+  virtual void NotifySessionDraining() = 0;
   // Notifies that either the session itself (DRAIN_WEBTRANSPORT_SESSION
   // capsule), or the underlying connection (HTTP GOAWAY) is being drained by
   // the peer.