Handle closing WebTransport sessions.

This currently does not actually send an appropriate capsule, since the code for capsules is not around yet.

PiperOrigin-RevId: 396466894
diff --git a/quic/core/http/end_to_end_test.cc b/quic/core/http/end_to_end_test.cc
index c6c6780..31041fa 100644
--- a/quic/core/http/end_to_end_test.cc
+++ b/quic/core/http/end_to_end_test.cc
@@ -6225,6 +6225,39 @@
   EXPECT_GT(received, 0);
 }
 
+TEST_P(EndToEndTest, WebTransportSessionClose) {
+  enable_web_transport_ = true;
+  ASSERT_TRUE(Initialize());
+
+  if (!version_.UsesHttp3()) {
+    return;
+  }
+
+  WebTransportHttp3* session =
+      CreateWebTransportSession("/echo", /*wait_for_server_response=*/true);
+  ASSERT_TRUE(session != nullptr);
+  NiceMock<MockClientVisitor>& visitor = SetupWebTransportVisitor(session);
+
+  WebTransportStream* stream = session->OpenOutgoingBidirectionalStream();
+  ASSERT_TRUE(stream != nullptr);
+  QuicStreamId stream_id = stream->GetStreamId();
+  EXPECT_TRUE(stream->Write("test"));
+  // Keep stream open.
+
+  bool close_received = false;
+  // TODO(vasilvv): once we have capsule support, actually check the error code
+  // and the error message returned.
+  EXPECT_CALL(visitor, OnSessionClosed(_, _))
+      .WillOnce(Assign(&close_received, true));
+  session->CloseSession(42, "test error");
+  client_->WaitUntil(2000, [&]() { return close_received; });
+  EXPECT_TRUE(close_received);
+
+  QuicSpdyStream* spdy_stream =
+      GetClientSession()->GetOrCreateSpdyDataStream(stream_id);
+  EXPECT_TRUE(spdy_stream == nullptr);
+}
+
 }  // namespace
 }  // namespace test
 }  // namespace quic
diff --git a/quic/core/http/quic_spdy_stream.cc b/quic/core/http/quic_spdy_stream.cc
index 0f0c7dc..f2b8f82 100644
--- a/quic/core/http/quic_spdy_stream.cc
+++ b/quic/core/http/quic_spdy_stream.cc
@@ -832,7 +832,7 @@
   }
 
   if (web_transport_ != nullptr) {
-    web_transport_->CloseAllAssociatedStreams();
+    web_transport_->OnConnectStreamClosing();
   }
   if (web_transport_data_ != nullptr) {
     WebTransportHttp3* web_transport =
diff --git a/quic/core/http/web_transport_http3.cc b/quic/core/http/web_transport_http3.cc
index b15ff16..679cb2b 100644
--- a/quic/core/http/web_transport_http3.cc
+++ b/quic/core/http/web_transport_http3.cc
@@ -13,6 +13,7 @@
 #include "quic/core/http/quic_spdy_stream.h"
 #include "quic/core/quic_data_reader.h"
 #include "quic/core/quic_data_writer.h"
+#include "quic/core/quic_error_codes.h"
 #include "quic/core/quic_stream.h"
 #include "quic/core/quic_types.h"
 #include "quic/core/quic_utils.h"
@@ -28,6 +29,8 @@
 namespace {
 class QUIC_NO_EXPORT NoopWebTransportVisitor : public WebTransportVisitor {
   void OnSessionReady(const spdy::SpdyHeaderBlock&) override {}
+  void OnSessionClosed(WebTransportSessionError /*error_code*/,
+                       const std::string& /*error_message*/) override {}
   void OnIncomingBidirectionalStreamAvailable() override {}
   void OnIncomingUnidirectionalStreamAvailable() override {}
   void OnDatagramReceived(absl::string_view /*datagram*/) override {}
@@ -70,7 +73,7 @@
   }
 }
 
-void WebTransportHttp3::CloseAllAssociatedStreams() {
+void WebTransportHttp3::OnConnectStreamClosing() {
   // Copy the stream list before iterating over it, as calls to ResetStream()
   // can potentially mutate the |session_| list.
   std::vector<QuicStreamId> streams(streams_.begin(), streams_.end());
@@ -83,6 +86,16 @@
     connect_stream_->UnregisterHttp3DatagramContextId(context_id_);
   }
   connect_stream_->UnregisterHttp3DatagramRegistrationVisitor();
+
+  visitor_->OnSessionClosed(error_code_, error_message_);
+}
+
+void WebTransportHttp3::CloseSession(WebTransportSessionError /*error_code*/,
+                                     absl::string_view /*error_message*/) {
+  // TODO(vasilvv): this should write a capsule and send FIN instead, but since
+  // we currently don't handle capsules, this is the next most meaningful
+  // choice.
+  connect_stream_->Reset(QUIC_STREAM_CANCELLED);
 }
 
 void WebTransportHttp3::HeadersReceived(const spdy::SpdyHeaderBlock& headers) {
diff --git a/quic/core/http/web_transport_http3.h b/quic/core/http/web_transport_http3.h
index 6b2c3cd..7401f39 100644
--- a/quic/core/http/web_transport_http3.h
+++ b/quic/core/http/web_transport_http3.h
@@ -48,10 +48,13 @@
 
   void AssociateStream(QuicStreamId stream_id);
   void OnStreamClosed(QuicStreamId stream_id) { streams_.erase(stream_id); }
-  void CloseAllAssociatedStreams();
+  void OnConnectStreamClosing();
 
   size_t NumberOfAssociatedStreams() { return streams_.size(); }
 
+  void CloseSession(WebTransportSessionError error_code,
+                    absl::string_view error_message) override;
+
   // Return the earliest incoming stream that has been received by the session
   // but has not been accepted.  Returns nullptr if there are no incoming
   // streams.
@@ -96,6 +99,11 @@
   absl::flat_hash_set<QuicStreamId> streams_;
   quiche::QuicheCircularDeque<QuicStreamId> incoming_bidirectional_streams_;
   quiche::QuicheCircularDeque<QuicStreamId> incoming_unidirectional_streams_;
+
+  // Those are set to default values, which are used if the session is not
+  // closed cleanly using an appropriate capsule.
+  WebTransportSessionError error_code_ = 0;
+  std::string error_message_ = "";
 };
 
 class QUIC_EXPORT_PRIVATE WebTransportHttp3UnidirectionalStream
diff --git a/quic/core/quic_types.h b/quic/core/quic_types.h
index 20927b2..80ba33e 100644
--- a/quic/core/quic_types.h
+++ b/quic/core/quic_types.h
@@ -58,6 +58,8 @@
 using WebTransportSessionId = uint64_t;
 // WebTransport stream reset codes are 8-bit.
 using WebTransportStreamError = uint8_t;
+// WebTransport session error codes are 32-bit.
+using WebTransportSessionError = uint32_t;
 
 enum : size_t { kQuicPathFrameBufferSize = 8 };
 using QuicPathFrameBuffer = std::array<uint8_t, kQuicPathFrameBufferSize>;
diff --git a/quic/core/web_transport_interface.h b/quic/core/web_transport_interface.h
index 214d262..dbfce33 100644
--- a/quic/core/web_transport_interface.h
+++ b/quic/core/web_transport_interface.h
@@ -87,6 +87,10 @@
   // data.
   virtual void OnSessionReady(const spdy::SpdyHeaderBlock& headers) = 0;
 
+  // Notifies the visitor when the session has been closed.
+  virtual void OnSessionClosed(WebTransportSessionError error_code,
+                               const std::string& error_message) = 0;
+
   // Notifies the visitor when a new stream has been received.  The stream in
   // question can be retrieved using AcceptIncomingBidirectionalStream() or
   // AcceptIncomingUnidirectionalStream().
@@ -106,6 +110,11 @@
  public:
   virtual ~WebTransportSession() {}
 
+  // Closes the WebTransport session in question with the specified |error_code|
+  // and |error_message|.
+  virtual void CloseSession(WebTransportSessionError error_code,
+                            absl::string_view error_message) = 0;
+
   // Return the earliest incoming stream that has been received by the session
   // but has not been accepted.  Returns nullptr if there are no incoming
   // streams.
diff --git a/quic/quic_transport/quic_transport_client_session.h b/quic/quic_transport/quic_transport_client_session.h
index ad29a15..0a890f2 100644
--- a/quic/quic_transport/quic_transport_client_session.h
+++ b/quic/quic_transport/quic_transport_client_session.h
@@ -17,6 +17,7 @@
 #include "quic/core/quic_crypto_client_stream.h"
 #include "quic/core/quic_crypto_stream.h"
 #include "quic/core/quic_datagram_queue.h"
+#include "quic/core/quic_error_codes.h"
 #include "quic/core/quic_server_id.h"
 #include "quic/core/quic_session.h"
 #include "quic/core/quic_stream.h"
@@ -113,6 +114,13 @@
   void OnProofVerifyDetailsAvailable(
       const ProofVerifyDetails& verify_details) override;
 
+  void CloseSession(WebTransportSessionError /*error_code*/,
+                    absl::string_view error_message) override {
+    connection()->CloseConnection(
+        QUIC_NO_ERROR, std::string(error_message),
+        ConnectionCloseBehavior::SEND_CONNECTION_CLOSE_PACKET);
+  }
+
  protected:
   class QUIC_EXPORT_PRIVATE ClientIndication : public QuicStream {
    public:
diff --git a/quic/test_tools/quic_transport_test_tools.h b/quic/test_tools/quic_transport_test_tools.h
index ec4d42d..afa8b54 100644
--- a/quic/test_tools/quic_transport_test_tools.h
+++ b/quic/test_tools/quic_transport_test_tools.h
@@ -15,6 +15,8 @@
 class MockClientVisitor : public WebTransportVisitor {
  public:
   MOCK_METHOD(void, OnSessionReady, (const spdy::SpdyHeaderBlock&), (override));
+  MOCK_METHOD(void, OnSessionClosed,
+              (WebTransportSessionError, const std::string&), (override));
   MOCK_METHOD(void, OnIncomingBidirectionalStreamAvailable, (), (override));
   MOCK_METHOD(void, OnIncomingUnidirectionalStreamAvailable, (), (override));
   MOCK_METHOD(void, OnDatagramReceived, (absl::string_view), (override));
diff --git a/quic/tools/web_transport_test_visitors.h b/quic/tools/web_transport_test_visitors.h
index c8d8513..9d7a5fc 100644
--- a/quic/tools/web_transport_test_visitors.h
+++ b/quic/tools/web_transport_test_visitors.h
@@ -147,6 +147,9 @@
     }
   }
 
+  void OnSessionClosed(WebTransportSessionError /*error_code*/,
+                       const std::string& /*error_message*/) override {}
+
   void OnIncomingBidirectionalStreamAvailable() override {
     while (true) {
       WebTransportStream* stream =