Handle closing WebTransport sessions.
This currently does not actually send an appropriate capsule, since the code for capsules is not around yet.
PiperOrigin-RevId: 396466894
diff --git a/quic/core/http/end_to_end_test.cc b/quic/core/http/end_to_end_test.cc
index c6c6780..31041fa 100644
--- a/quic/core/http/end_to_end_test.cc
+++ b/quic/core/http/end_to_end_test.cc
@@ -6225,6 +6225,39 @@
EXPECT_GT(received, 0);
}
+TEST_P(EndToEndTest, WebTransportSessionClose) {
+ enable_web_transport_ = true;
+ ASSERT_TRUE(Initialize());
+
+ if (!version_.UsesHttp3()) {
+ return;
+ }
+
+ WebTransportHttp3* session =
+ CreateWebTransportSession("/echo", /*wait_for_server_response=*/true);
+ ASSERT_TRUE(session != nullptr);
+ NiceMock<MockClientVisitor>& visitor = SetupWebTransportVisitor(session);
+
+ WebTransportStream* stream = session->OpenOutgoingBidirectionalStream();
+ ASSERT_TRUE(stream != nullptr);
+ QuicStreamId stream_id = stream->GetStreamId();
+ EXPECT_TRUE(stream->Write("test"));
+ // Keep stream open.
+
+ bool close_received = false;
+ // TODO(vasilvv): once we have capsule support, actually check the error code
+ // and the error message returned.
+ EXPECT_CALL(visitor, OnSessionClosed(_, _))
+ .WillOnce(Assign(&close_received, true));
+ session->CloseSession(42, "test error");
+ client_->WaitUntil(2000, [&]() { return close_received; });
+ EXPECT_TRUE(close_received);
+
+ QuicSpdyStream* spdy_stream =
+ GetClientSession()->GetOrCreateSpdyDataStream(stream_id);
+ EXPECT_TRUE(spdy_stream == nullptr);
+}
+
} // namespace
} // namespace test
} // namespace quic
diff --git a/quic/core/http/quic_spdy_stream.cc b/quic/core/http/quic_spdy_stream.cc
index 0f0c7dc..f2b8f82 100644
--- a/quic/core/http/quic_spdy_stream.cc
+++ b/quic/core/http/quic_spdy_stream.cc
@@ -832,7 +832,7 @@
}
if (web_transport_ != nullptr) {
- web_transport_->CloseAllAssociatedStreams();
+ web_transport_->OnConnectStreamClosing();
}
if (web_transport_data_ != nullptr) {
WebTransportHttp3* web_transport =
diff --git a/quic/core/http/web_transport_http3.cc b/quic/core/http/web_transport_http3.cc
index b15ff16..679cb2b 100644
--- a/quic/core/http/web_transport_http3.cc
+++ b/quic/core/http/web_transport_http3.cc
@@ -13,6 +13,7 @@
#include "quic/core/http/quic_spdy_stream.h"
#include "quic/core/quic_data_reader.h"
#include "quic/core/quic_data_writer.h"
+#include "quic/core/quic_error_codes.h"
#include "quic/core/quic_stream.h"
#include "quic/core/quic_types.h"
#include "quic/core/quic_utils.h"
@@ -28,6 +29,8 @@
namespace {
class QUIC_NO_EXPORT NoopWebTransportVisitor : public WebTransportVisitor {
void OnSessionReady(const spdy::SpdyHeaderBlock&) override {}
+ void OnSessionClosed(WebTransportSessionError /*error_code*/,
+ const std::string& /*error_message*/) override {}
void OnIncomingBidirectionalStreamAvailable() override {}
void OnIncomingUnidirectionalStreamAvailable() override {}
void OnDatagramReceived(absl::string_view /*datagram*/) override {}
@@ -70,7 +73,7 @@
}
}
-void WebTransportHttp3::CloseAllAssociatedStreams() {
+void WebTransportHttp3::OnConnectStreamClosing() {
// Copy the stream list before iterating over it, as calls to ResetStream()
// can potentially mutate the |session_| list.
std::vector<QuicStreamId> streams(streams_.begin(), streams_.end());
@@ -83,6 +86,16 @@
connect_stream_->UnregisterHttp3DatagramContextId(context_id_);
}
connect_stream_->UnregisterHttp3DatagramRegistrationVisitor();
+
+ visitor_->OnSessionClosed(error_code_, error_message_);
+}
+
+void WebTransportHttp3::CloseSession(WebTransportSessionError /*error_code*/,
+ absl::string_view /*error_message*/) {
+ // TODO(vasilvv): this should write a capsule and send FIN instead, but since
+ // we currently don't handle capsules, this is the next most meaningful
+ // choice.
+ connect_stream_->Reset(QUIC_STREAM_CANCELLED);
}
void WebTransportHttp3::HeadersReceived(const spdy::SpdyHeaderBlock& headers) {
diff --git a/quic/core/http/web_transport_http3.h b/quic/core/http/web_transport_http3.h
index 6b2c3cd..7401f39 100644
--- a/quic/core/http/web_transport_http3.h
+++ b/quic/core/http/web_transport_http3.h
@@ -48,10 +48,13 @@
void AssociateStream(QuicStreamId stream_id);
void OnStreamClosed(QuicStreamId stream_id) { streams_.erase(stream_id); }
- void CloseAllAssociatedStreams();
+ void OnConnectStreamClosing();
size_t NumberOfAssociatedStreams() { return streams_.size(); }
+ void CloseSession(WebTransportSessionError error_code,
+ absl::string_view error_message) override;
+
// Return the earliest incoming stream that has been received by the session
// but has not been accepted. Returns nullptr if there are no incoming
// streams.
@@ -96,6 +99,11 @@
absl::flat_hash_set<QuicStreamId> streams_;
quiche::QuicheCircularDeque<QuicStreamId> incoming_bidirectional_streams_;
quiche::QuicheCircularDeque<QuicStreamId> incoming_unidirectional_streams_;
+
+ // Those are set to default values, which are used if the session is not
+ // closed cleanly using an appropriate capsule.
+ WebTransportSessionError error_code_ = 0;
+ std::string error_message_ = "";
};
class QUIC_EXPORT_PRIVATE WebTransportHttp3UnidirectionalStream
diff --git a/quic/core/quic_types.h b/quic/core/quic_types.h
index 20927b2..80ba33e 100644
--- a/quic/core/quic_types.h
+++ b/quic/core/quic_types.h
@@ -58,6 +58,8 @@
using WebTransportSessionId = uint64_t;
// WebTransport stream reset codes are 8-bit.
using WebTransportStreamError = uint8_t;
+// WebTransport session error codes are 32-bit.
+using WebTransportSessionError = uint32_t;
enum : size_t { kQuicPathFrameBufferSize = 8 };
using QuicPathFrameBuffer = std::array<uint8_t, kQuicPathFrameBufferSize>;
diff --git a/quic/core/web_transport_interface.h b/quic/core/web_transport_interface.h
index 214d262..dbfce33 100644
--- a/quic/core/web_transport_interface.h
+++ b/quic/core/web_transport_interface.h
@@ -87,6 +87,10 @@
// data.
virtual void OnSessionReady(const spdy::SpdyHeaderBlock& headers) = 0;
+ // Notifies the visitor when the session has been closed.
+ virtual void OnSessionClosed(WebTransportSessionError error_code,
+ const std::string& error_message) = 0;
+
// Notifies the visitor when a new stream has been received. The stream in
// question can be retrieved using AcceptIncomingBidirectionalStream() or
// AcceptIncomingUnidirectionalStream().
@@ -106,6 +110,11 @@
public:
virtual ~WebTransportSession() {}
+ // Closes the WebTransport session in question with the specified |error_code|
+ // and |error_message|.
+ virtual void CloseSession(WebTransportSessionError error_code,
+ absl::string_view error_message) = 0;
+
// Return the earliest incoming stream that has been received by the session
// but has not been accepted. Returns nullptr if there are no incoming
// streams.
diff --git a/quic/quic_transport/quic_transport_client_session.h b/quic/quic_transport/quic_transport_client_session.h
index ad29a15..0a890f2 100644
--- a/quic/quic_transport/quic_transport_client_session.h
+++ b/quic/quic_transport/quic_transport_client_session.h
@@ -17,6 +17,7 @@
#include "quic/core/quic_crypto_client_stream.h"
#include "quic/core/quic_crypto_stream.h"
#include "quic/core/quic_datagram_queue.h"
+#include "quic/core/quic_error_codes.h"
#include "quic/core/quic_server_id.h"
#include "quic/core/quic_session.h"
#include "quic/core/quic_stream.h"
@@ -113,6 +114,13 @@
void OnProofVerifyDetailsAvailable(
const ProofVerifyDetails& verify_details) override;
+ void CloseSession(WebTransportSessionError /*error_code*/,
+ absl::string_view error_message) override {
+ connection()->CloseConnection(
+ QUIC_NO_ERROR, std::string(error_message),
+ ConnectionCloseBehavior::SEND_CONNECTION_CLOSE_PACKET);
+ }
+
protected:
class QUIC_EXPORT_PRIVATE ClientIndication : public QuicStream {
public:
diff --git a/quic/test_tools/quic_transport_test_tools.h b/quic/test_tools/quic_transport_test_tools.h
index ec4d42d..afa8b54 100644
--- a/quic/test_tools/quic_transport_test_tools.h
+++ b/quic/test_tools/quic_transport_test_tools.h
@@ -15,6 +15,8 @@
class MockClientVisitor : public WebTransportVisitor {
public:
MOCK_METHOD(void, OnSessionReady, (const spdy::SpdyHeaderBlock&), (override));
+ MOCK_METHOD(void, OnSessionClosed,
+ (WebTransportSessionError, const std::string&), (override));
MOCK_METHOD(void, OnIncomingBidirectionalStreamAvailable, (), (override));
MOCK_METHOD(void, OnIncomingUnidirectionalStreamAvailable, (), (override));
MOCK_METHOD(void, OnDatagramReceived, (absl::string_view), (override));
diff --git a/quic/tools/web_transport_test_visitors.h b/quic/tools/web_transport_test_visitors.h
index c8d8513..9d7a5fc 100644
--- a/quic/tools/web_transport_test_visitors.h
+++ b/quic/tools/web_transport_test_visitors.h
@@ -147,6 +147,9 @@
}
}
+ void OnSessionClosed(WebTransportSessionError /*error_code*/,
+ const std::string& /*error_message*/) override {}
+
void OnIncomingBidirectionalStreamAvailable() override {
while (true) {
WebTransportStream* stream =