Implement enough of MoQT to perform a handshake between the client and the server.
PiperOrigin-RevId: 570482344
diff --git a/build/source_list.bzl b/build/source_list.bzl
index ebaee62..d4ee29d 100644
--- a/build/source_list.bzl
+++ b/build/source_list.bzl
@@ -1478,14 +1478,17 @@
"quic/moqt/moqt_framer.h",
"quic/moqt/moqt_messages.h",
"quic/moqt/moqt_parser.h",
+ "quic/moqt/moqt_session.h",
"quic/moqt/test_tools/moqt_test_message.h",
]
moqt_srcs = [
"quic/moqt/moqt_framer.cc",
"quic/moqt/moqt_framer_test.cc",
+ "quic/moqt/moqt_integration_test.cc",
"quic/moqt/moqt_messages.cc",
"quic/moqt/moqt_parser.cc",
"quic/moqt/moqt_parser_test.cc",
+ "quic/moqt/moqt_session.cc",
]
binary_http_hdrs = [
"binary_http/binary_http_message.h",
diff --git a/build/source_list.gni b/build/source_list.gni
index 2866a7a..733583e 100644
--- a/build/source_list.gni
+++ b/build/source_list.gni
@@ -1482,14 +1482,17 @@
"src/quiche/quic/moqt/moqt_framer.h",
"src/quiche/quic/moqt/moqt_messages.h",
"src/quiche/quic/moqt/moqt_parser.h",
+ "src/quiche/quic/moqt/moqt_session.h",
"src/quiche/quic/moqt/test_tools/moqt_test_message.h",
]
moqt_srcs = [
"src/quiche/quic/moqt/moqt_framer.cc",
"src/quiche/quic/moqt/moqt_framer_test.cc",
+ "src/quiche/quic/moqt/moqt_integration_test.cc",
"src/quiche/quic/moqt/moqt_messages.cc",
"src/quiche/quic/moqt/moqt_parser.cc",
"src/quiche/quic/moqt/moqt_parser_test.cc",
+ "src/quiche/quic/moqt/moqt_session.cc",
]
binary_http_hdrs = [
"src/quiche/binary_http/binary_http_message.h",
diff --git a/build/source_list.json b/build/source_list.json
index 6f6595a..71f291f 100644
--- a/build/source_list.json
+++ b/build/source_list.json
@@ -1481,14 +1481,17 @@
"quiche/quic/moqt/moqt_framer.h",
"quiche/quic/moqt/moqt_messages.h",
"quiche/quic/moqt/moqt_parser.h",
+ "quiche/quic/moqt/moqt_session.h",
"quiche/quic/moqt/test_tools/moqt_test_message.h"
],
"moqt_srcs": [
"quiche/quic/moqt/moqt_framer.cc",
"quiche/quic/moqt/moqt_framer_test.cc",
+ "quiche/quic/moqt/moqt_integration_test.cc",
"quiche/quic/moqt/moqt_messages.cc",
"quiche/quic/moqt/moqt_parser.cc",
- "quiche/quic/moqt/moqt_parser_test.cc"
+ "quiche/quic/moqt/moqt_parser_test.cc",
+ "quiche/quic/moqt/moqt_session.cc"
],
"binary_http_hdrs": [
"quiche/binary_http/binary_http_message.h"
diff --git a/quiche/quic/core/http/web_transport_stream_adapter.cc b/quiche/quic/core/http/web_transport_stream_adapter.cc
index 18928f3..ff2faba 100644
--- a/quiche/quic/core/http/web_transport_stream_adapter.cc
+++ b/quiche/quic/core/http/web_transport_stream_adapter.cc
@@ -147,6 +147,10 @@
}
bool WebTransportStreamAdapter::SkipBytes(size_t bytes) {
+ if (stream_->read_side_closed()) {
+ // Useful when the stream has been reset in between Peek() and Skip().
+ return true;
+ }
sequencer_->MarkConsumed(bytes);
if (!fin_read_ && sequencer_->IsClosed()) {
fin_read_ = true;
diff --git a/quiche/quic/core/quic_generic_session.cc b/quiche/quic/core/quic_generic_session.cc
index 64e469c..2b767ca 100644
--- a/quiche/quic/core/quic_generic_session.cc
+++ b/quiche/quic/core/quic_generic_session.cc
@@ -193,6 +193,14 @@
quiche::QuicheMemSlice(std::move(buffer))));
}
+void QuicGenericSessionBase::OnConnectionClosed(
+ const QuicConnectionCloseFrame& frame, ConnectionCloseSource source) {
+ QuicSession::OnConnectionClosed(frame, source);
+ visitor_->OnSessionClosed(static_cast<webtransport::SessionErrorCode>(
+ frame.transport_close_frame_type),
+ frame.error_details);
+}
+
QuicGenericClientSession::QuicGenericClientSession(
QuicConnection* connection, bool owns_connection, Visitor* owner,
const QuicConfig& config, std::string host, uint16_t port, std::string alpn,
diff --git a/quiche/quic/core/quic_generic_session.h b/quiche/quic/core/quic_generic_session.h
index 773f170..48d02f3 100644
--- a/quiche/quic/core/quic_generic_session.h
+++ b/quiche/quic/core/quic_generic_session.h
@@ -62,6 +62,8 @@
void OnAlpnSelected(absl::string_view alpn) override {
QUICHE_DCHECK_EQ(alpn, alpn_);
}
+ void OnConnectionClosed(const QuicConnectionCloseFrame& frame,
+ ConnectionCloseSource source) override;
bool ShouldKeepConnectionAlive() const override { return true; }
diff --git a/quiche/quic/moqt/moqt_framer.cc b/quiche/quic/moqt/moqt_framer.cc
index 5806b7d..294b15c 100644
--- a/quiche/quic/moqt/moqt_framer.cc
+++ b/quiche/quic/moqt/moqt_framer.cc
@@ -21,6 +21,10 @@
inline size_t NeededVarIntLen(uint64_t value) {
return static_cast<size_t>(quic::QuicDataWriter::GetVarInt62Len(value));
}
+inline size_t NeededVarIntLen(MoqtVersion value) {
+ return static_cast<size_t>(
+ quic::QuicDataWriter::GetVarInt62Len(static_cast<uint64_t>(value)));
+}
inline size_t ParameterLen(uint64_t type, uint64_t value_len) {
return NeededVarIntLen(type) + NeededVarIntLen(value_len) + value_len;
}
@@ -87,10 +91,12 @@
quiche::QuicheBuffer MoqtFramer::SerializeSetup(const MoqtSetup& message) {
size_t message_len;
if (perspective_ == quic::Perspective::IS_CLIENT) {
- message_len = NeededVarIntLen(message.number_of_supported_versions);
- for (uint64_t i : message.supported_versions) {
- message_len += NeededVarIntLen(i);
+ message_len = NeededVarIntLen(message.supported_versions.size());
+ for (MoqtVersion version : message.supported_versions) {
+ message_len += NeededVarIntLen(version);
}
+ // TODO: figure out if the role needs to be sent on the client side or on
+ // both sides.
if (message.role.has_value()) {
message_len +=
ParameterLen(static_cast<uint64_t>(MoqtSetupParameter::kRole), 1);
@@ -112,12 +118,12 @@
writer.WriteVarInt62(static_cast<uint64_t>(MoqtMessageType::kSetup));
writer.WriteVarInt62(message_len);
if (perspective_ == quic::Perspective::IS_SERVER) {
- writer.WriteVarInt62(message.supported_versions[0]);
+ writer.WriteVarInt62(static_cast<uint64_t>(message.supported_versions[0]));
return buffer;
}
- writer.WriteVarInt62(message.number_of_supported_versions);
- for (uint64_t i : message.supported_versions) {
- writer.WriteVarInt62(i);
+ writer.WriteVarInt62(message.supported_versions.size());
+ for (MoqtVersion version : message.supported_versions) {
+ writer.WriteVarInt62(static_cast<uint64_t>(version));
}
if (message.role.has_value()) {
WriteIntParameter(writer, static_cast<uint64_t>(MoqtSetupParameter::kRole),
diff --git a/quiche/quic/moqt/moqt_integration_test.cc b/quiche/quic/moqt/moqt_integration_test.cc
new file mode 100644
index 0000000..0fa4245
--- /dev/null
+++ b/quiche/quic/moqt/moqt_integration_test.cc
@@ -0,0 +1,174 @@
+// Copyright 2023 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include <memory>
+#include <string>
+
+#include "absl/strings/string_view.h"
+#include "quiche/quic/core/crypto/quic_compressed_certs_cache.h"
+#include "quiche/quic/core/crypto/quic_crypto_client_config.h"
+#include "quiche/quic/core/crypto/quic_crypto_server_config.h"
+#include "quiche/quic/core/crypto/quic_random.h"
+#include "quiche/quic/core/quic_config.h"
+#include "quiche/quic/core/quic_generic_session.h"
+#include "quiche/quic/core/quic_types.h"
+#include "quiche/quic/moqt/moqt_messages.h"
+#include "quiche/quic/moqt/moqt_session.h"
+#include "quiche/quic/test_tools/crypto_test_utils.h"
+#include "quiche/quic/test_tools/simulator/simulator.h"
+#include "quiche/quic/test_tools/simulator/test_harness.h"
+#include "quiche/common/platform/api/quiche_test.h"
+
+namespace moqt::test {
+namespace {
+
+using ::quic::simulator::Simulator;
+using ::testing::_;
+using ::testing::Assign;
+
+class ClientEndpoint : public quic::simulator::QuicEndpointWithConnection {
+ public:
+ ClientEndpoint(Simulator* simulator, const std::string& name,
+ const std::string& peer_name, MoqtVersion version)
+ : QuicEndpointWithConnection(simulator, name, peer_name,
+ quic::Perspective::IS_CLIENT,
+ quic::GetQuicVersionsForGenericSession()),
+ crypto_config_(
+ quic::test::crypto_test_utils::ProofVerifierForTesting()),
+ quic_session_(connection_.get(), false, nullptr, quic::QuicConfig(),
+ "test.example.com", 443, "moqt", &session_,
+ /*visitor_owned=*/false, nullptr, &crypto_config_),
+ session_(
+ &quic_session_,
+ MoqtSessionParameters{.version = version,
+ .perspective = quic::Perspective::IS_CLIENT,
+ .using_webtrans = false},
+ established_callback_.AsStdFunction(),
+ terminated_callback_.AsStdFunction()) {
+ quic_session_.Initialize();
+ }
+
+ MoqtSession* session() { return &session_; }
+ quic::QuicGenericClientSession* quic_session() { return &quic_session_; }
+ testing::MockFunction<void()>& established_callback() {
+ return established_callback_;
+ }
+ testing::MockFunction<void(absl::string_view)>& terminated_callback() {
+ return terminated_callback_;
+ }
+
+ private:
+ testing::MockFunction<void()> established_callback_;
+ testing::MockFunction<void(absl::string_view)> terminated_callback_;
+ quic::QuicCryptoClientConfig crypto_config_;
+ quic::QuicGenericClientSession quic_session_;
+ MoqtSession session_;
+};
+
+class ServerEndpoint : public quic::simulator::QuicEndpointWithConnection {
+ public:
+ ServerEndpoint(Simulator* simulator, const std::string& name,
+ const std::string& peer_name, MoqtVersion version)
+ : QuicEndpointWithConnection(simulator, name, peer_name,
+ quic::Perspective::IS_SERVER,
+ quic::GetQuicVersionsForGenericSession()),
+ compressed_certs_cache_(
+ quic::QuicCompressedCertsCache::kQuicCompressedCertsCacheSize),
+ crypto_config_(quic::QuicCryptoServerConfig::TESTING,
+ quic::QuicRandom::GetInstance(),
+ quic::test::crypto_test_utils::ProofSourceForTesting(),
+ quic::KeyExchangeSource::Default()),
+ quic_session_(connection_.get(), false, nullptr, quic::QuicConfig(),
+ "moqt", &session_,
+ /*visitor_owned=*/false, nullptr, &crypto_config_,
+ &compressed_certs_cache_),
+ session_(
+ &quic_session_,
+ MoqtSessionParameters{.version = version,
+ .perspective = quic::Perspective::IS_SERVER,
+ .using_webtrans = false},
+ established_callback_.AsStdFunction(),
+ terminated_callback_.AsStdFunction()) {
+ quic_session_.Initialize();
+ }
+
+ MoqtSession* session() { return &session_; }
+ testing::MockFunction<void()>& established_callback() {
+ return established_callback_;
+ }
+ testing::MockFunction<void(absl::string_view)>& terminated_callback() {
+ return terminated_callback_;
+ }
+
+ private:
+ testing::MockFunction<void()> established_callback_;
+ testing::MockFunction<void(absl::string_view)> terminated_callback_;
+ quic::QuicCompressedCertsCache compressed_certs_cache_;
+ quic::QuicCryptoServerConfig crypto_config_;
+ quic::QuicGenericServerSession quic_session_;
+ MoqtSession session_;
+};
+
+class MoqtIntegrationTest : public quiche::test::QuicheTest {
+ public:
+ void CreateDefaultEndpoints() {
+ client_ = std::make_unique<ClientEndpoint>(
+ &test_harness_.simulator(), "Client", "Server", MoqtVersion::kDraft01);
+ server_ = std::make_unique<ServerEndpoint>(
+ &test_harness_.simulator(), "Server", "Client", MoqtVersion::kDraft01);
+ test_harness_.set_client(client_.get());
+ test_harness_.set_server(server_.get());
+ }
+
+ void WireUpEndpoints() { test_harness_.WireUpEndpoints(); }
+
+ protected:
+ quic::simulator::TestHarness test_harness_;
+
+ std::unique_ptr<ClientEndpoint> client_;
+ std::unique_ptr<ServerEndpoint> server_;
+};
+
+TEST_F(MoqtIntegrationTest, Handshake) {
+ CreateDefaultEndpoints();
+ WireUpEndpoints();
+
+ client_->quic_session()->CryptoConnect();
+ bool client_established = false;
+ bool server_established = false;
+ EXPECT_CALL(client_->established_callback(), Call())
+ .WillOnce(Assign(&client_established, true));
+ EXPECT_CALL(server_->established_callback(), Call())
+ .WillOnce(Assign(&server_established, true));
+ bool success = test_harness_.RunUntilWithDefaultTimeout(
+ [&]() { return client_established && server_established; });
+ EXPECT_TRUE(success);
+}
+
+TEST_F(MoqtIntegrationTest, VersionMismatch) {
+ client_ = std::make_unique<ClientEndpoint>(
+ &test_harness_.simulator(), "Client", "Server",
+ MoqtVersion::kUnrecognizedVersionForTests);
+ server_ = std::make_unique<ServerEndpoint>(
+ &test_harness_.simulator(), "Server", "Client", MoqtVersion::kDraft01);
+ test_harness_.set_client(client_.get());
+ test_harness_.set_server(server_.get());
+ WireUpEndpoints();
+
+ client_->quic_session()->CryptoConnect();
+ bool client_terminated = false;
+ bool server_terminated = false;
+ EXPECT_CALL(client_->established_callback(), Call()).Times(0);
+ EXPECT_CALL(server_->established_callback(), Call()).Times(0);
+ EXPECT_CALL(client_->terminated_callback(), Call(_))
+ .WillOnce(Assign(&client_terminated, true));
+ EXPECT_CALL(server_->terminated_callback(), Call(_))
+ .WillOnce(Assign(&server_terminated, true));
+ bool success = test_harness_.RunUntilWithDefaultTimeout(
+ [&]() { return client_terminated && server_terminated; });
+ EXPECT_TRUE(success);
+}
+
+} // namespace
+} // namespace moqt::test
diff --git a/quiche/quic/moqt/moqt_messages.h b/quiche/quic/moqt/moqt_messages.h
index a88f284..b137e59 100644
--- a/quiche/quic/moqt/moqt_messages.h
+++ b/quiche/quic/moqt/moqt_messages.h
@@ -15,10 +15,24 @@
#include "absl/strings/string_view.h"
#include "absl/types/optional.h"
#include "quiche/quic/core/quic_time.h"
+#include "quiche/quic/core/quic_types.h"
#include "quiche/common/platform/api/quiche_export.h"
namespace moqt {
+enum class MoqtVersion : uint64_t {
+ kDraft01 = 0xff000001,
+ kUnrecognizedVersionForTests = 0xfe0000ff,
+};
+
+struct QUICHE_EXPORT MoqtSessionParameters {
+ // TODO: support multiple versions.
+ MoqtVersion version;
+ quic::Perspective perspective;
+ bool using_webtrans;
+ std::string path;
+};
+
// The maximum length of a message, excluding any OBJECT payload. This prevents
// DoS attack via forcing the parser to buffer a large message (OBJECT payloads
// are not buffered by the parser).
@@ -56,8 +70,7 @@
};
struct QUICHE_EXPORT MoqtSetup {
- uint64_t number_of_supported_versions;
- std::vector<uint64_t> supported_versions;
+ std::vector<MoqtVersion> supported_versions;
absl::optional<MoqtRole> role;
absl::optional<absl::string_view> path;
};
diff --git a/quiche/quic/moqt/moqt_parser.cc b/quiche/quic/moqt/moqt_parser.cc
index 3de44cf..7913d27 100644
--- a/quiche/quic/moqt/moqt_parser.cc
+++ b/quiche/quic/moqt/moqt_parser.cc
@@ -345,19 +345,20 @@
absl::optional<size_t> MoqtParser::ProcessSetup(absl::string_view data) {
MoqtSetup setup;
quic::QuicDataReader reader(data);
+ uint64_t number_of_supported_versions;
if (perspective_ == quic::Perspective::IS_SERVER) {
- if (!reader.ReadVarInt62(&setup.number_of_supported_versions)) {
+ if (!reader.ReadVarInt62(&number_of_supported_versions)) {
return absl::nullopt;
}
} else {
- setup.number_of_supported_versions = 1;
+ number_of_supported_versions = 1;
}
uint64_t value;
- for (uint64_t i = 0; i < setup.number_of_supported_versions; ++i) {
+ for (uint64_t i = 0; i < number_of_supported_versions; ++i) {
if (!reader.ReadVarInt62(&value)) {
return absl::nullopt;
}
- setup.supported_versions.push_back(value);
+ setup.supported_versions.push_back(static_cast<MoqtVersion>(value));
}
// Parse parameters
while (!reader.IsDoneReading()) {
diff --git a/quiche/quic/moqt/moqt_session.cc b/quiche/quic/moqt/moqt_session.cc
new file mode 100644
index 0000000..f8b89ea
--- /dev/null
+++ b/quiche/quic/moqt/moqt_session.cc
@@ -0,0 +1,160 @@
+// Copyright 2023 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "quiche/quic/moqt/moqt_session.h"
+
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "absl/strings/str_cat.h"
+#include "absl/strings/string_view.h"
+#include "quiche/quic/core/quic_types.h"
+#include "quiche/quic/moqt/moqt_messages.h"
+#include "quiche/common/platform/api/quiche_logging.h"
+#include "quiche/common/quiche_buffer_allocator.h"
+#include "quiche/common/quiche_stream.h"
+#include "quiche/web_transport/web_transport.h"
+
+#define ENDPOINT \
+ (perspective() == Perspective::IS_SERVER ? "MoQT Server: " : "MoQT Client: ")
+
+namespace moqt {
+
+using ::quic::Perspective;
+
+void MoqtSession::OnSessionReady() {
+ QUICHE_DLOG(INFO) << ENDPOINT << "Underlying session ready";
+ if (parameters_.perspective == Perspective::IS_SERVER) {
+ return;
+ }
+
+ webtransport::Stream* control_stream =
+ session_->OpenOutgoingBidirectionalStream();
+ if (control_stream == nullptr) {
+ Error("Unable to open a control stream");
+ return;
+ }
+ control_stream->SetVisitor(std::make_unique<Stream>(
+ this, control_stream, /*is_control_stream=*/true));
+ control_stream_ = control_stream->GetStreamId();
+ MoqtSetup setup = MoqtSetup{
+ .supported_versions = std::vector<MoqtVersion>{parameters_.version},
+ .role = MoqtRole::kBoth};
+ if (!parameters_.using_webtrans) {
+ setup.path = parameters_.path;
+ }
+ quiche::QuicheBuffer serialized_setup = framer_.SerializeSetup(setup);
+ bool success = control_stream->Write(serialized_setup.AsStringView());
+ if (!success) {
+ Error("Failed to write client SETUP message");
+ return;
+ }
+ QUICHE_DLOG(INFO) << ENDPOINT << "Send the SETUP message";
+}
+
+void MoqtSession::OnIncomingBidirectionalStreamAvailable() {
+ while (webtransport::Stream* stream =
+ session_->AcceptIncomingBidirectionalStream()) {
+ stream->SetVisitor(std::make_unique<Stream>(this, stream));
+ stream->visitor()->OnCanRead();
+ }
+}
+void MoqtSession::OnIncomingUnidirectionalStreamAvailable() {
+ while (webtransport::Stream* stream =
+ session_->AcceptIncomingUnidirectionalStream()) {
+ stream->SetVisitor(std::make_unique<Stream>(this, stream));
+ stream->visitor()->OnCanRead();
+ }
+}
+
+void MoqtSession::OnSessionClosed(webtransport::SessionErrorCode,
+ const std::string& error_message) {
+ if (!error_.empty()) {
+ // Avoid erroring out twice.
+ return;
+ }
+ QUICHE_DLOG(INFO) << ENDPOINT << "Underlying session closed with message: "
+ << error_message;
+ error_ = error_message;
+ std::move(session_terminated_callback_)(error_message);
+}
+
+void MoqtSession::Error(absl::string_view error) {
+ if (!error_.empty()) {
+ // Avoid erroring out twice.
+ return;
+ }
+ QUICHE_DLOG(INFO) << ENDPOINT
+ << "MOQT session closed with message: " << error;
+ error_ = std::string(error);
+ // TODO(vasilvv): figure out the error code.
+ session_->CloseSession(1, error);
+ std::move(session_terminated_callback_)(error);
+}
+
+void MoqtSession::Stream::OnCanRead() {
+ bool fin =
+ quiche::ProcessAllReadableRegions(*stream_, [&](absl::string_view chunk) {
+ parser_.ProcessData(chunk, /*end_of_stream=*/false);
+ });
+ if (fin) {
+ parser_.ProcessData("", /*end_of_stream=*/true);
+ }
+}
+void MoqtSession::Stream::OnCanWrite() {}
+void MoqtSession::Stream::OnResetStreamReceived(
+ webtransport::StreamErrorCode error) {
+ if (is_control_stream_.has_value() && *is_control_stream_) {
+ session_->Error(
+ absl::StrCat("Control stream reset with error code ", error));
+ }
+}
+void MoqtSession::Stream::OnStopSendingReceived(
+ webtransport::StreamErrorCode error) {
+ if (is_control_stream_.has_value() && *is_control_stream_) {
+ session_->Error(
+ absl::StrCat("Control stream reset with error code ", error));
+ }
+}
+
+void MoqtSession::Stream::OnSetupMessage(const MoqtSetup& message) {
+ if (is_control_stream_.has_value()) {
+ if (!*is_control_stream_) {
+ session_->Error("Received SETUP on non-control stream");
+ return;
+ }
+ } else {
+ is_control_stream_ = true;
+ }
+ if (absl::c_find(message.supported_versions, session_->parameters_.version) ==
+ message.supported_versions.end()) {
+ session_->Error(absl::StrCat("Version mismatch: expected 0x",
+ absl::Hex(session_->parameters_.version)));
+ return;
+ }
+ QUICHE_DLOG(INFO) << ENDPOINT << "Received the SETUP message";
+ if (session_->parameters_.perspective == Perspective::IS_SERVER) {
+ MoqtSetup response =
+ MoqtSetup{.supported_versions =
+ std::vector<MoqtVersion>{session_->parameters_.version},
+ .role = MoqtRole::kBoth};
+ bool success = stream_->Write(
+ session_->framer_.SerializeSetup(response).AsStringView());
+ if (!success) {
+ session_->Error("Failed to write server SETUP message");
+ return;
+ }
+ QUICHE_DLOG(INFO) << ENDPOINT << "Sent the SETUP message";
+ }
+ // TODO: handle role and path.
+ std::move(session_->session_established_callback_)();
+}
+
+void MoqtSession::Stream::OnParsingError(absl::string_view reason) {
+ session_->Error(absl::StrCat("Parse error: ", reason));
+}
+
+} // namespace moqt
diff --git a/quiche/quic/moqt/moqt_session.h b/quiche/quic/moqt/moqt_session.h
new file mode 100644
index 0000000..0faa8c0
--- /dev/null
+++ b/quiche/quic/moqt/moqt_session.h
@@ -0,0 +1,119 @@
+// Copyright 2023 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#ifndef QUICHE_QUIC_MOQT_MOQT_SESSION_H_
+#define QUICHE_QUIC_MOQT_MOQT_SESSION_H_
+
+#include <string>
+#include <utility>
+
+#include "absl/strings/string_view.h"
+#include "absl/types/optional.h"
+#include "quiche/quic/core/quic_types.h"
+#include "quiche/quic/moqt/moqt_framer.h"
+#include "quiche/quic/moqt/moqt_messages.h"
+#include "quiche/quic/moqt/moqt_parser.h"
+#include "quiche/common/platform/api/quiche_export.h"
+#include "quiche/common/quiche_callbacks.h"
+#include "quiche/common/simple_buffer_allocator.h"
+#include "quiche/web_transport/web_transport.h"
+
+namespace moqt {
+
+using MoqtSessionEstablishedCallback = quiche::SingleUseCallback<void()>;
+using MoqtSessionTerminatedCallback =
+ quiche::SingleUseCallback<void(absl::string_view error_message)>;
+
+class QUICHE_EXPORT MoqtSession : public webtransport::SessionVisitor {
+ public:
+ MoqtSession(webtransport::Session* session, MoqtSessionParameters parameters,
+ MoqtSessionEstablishedCallback session_established_callback,
+ MoqtSessionTerminatedCallback session_terminated_callback)
+ : session_(session),
+ parameters_(parameters),
+ session_established_callback_(std::move(session_established_callback)),
+ session_terminated_callback_(std::move(session_terminated_callback)),
+ framer_(quiche::SimpleBufferAllocator::Get(), parameters.perspective,
+ parameters.using_webtrans) {}
+
+ // webtransport::SessionVisitor implementation.
+ void OnSessionReady() override;
+ void OnSessionClosed(webtransport::SessionErrorCode,
+ const std::string&) override;
+ void OnIncomingBidirectionalStreamAvailable() override;
+ void OnIncomingUnidirectionalStreamAvailable() override;
+ void OnDatagramReceived(absl::string_view datagram) override {}
+ void OnCanCreateNewOutgoingBidirectionalStream() override {}
+ void OnCanCreateNewOutgoingUnidirectionalStream() override {}
+
+ void Error(absl::string_view error);
+
+ quic::Perspective perspective() const { return parameters_.perspective; }
+
+ private:
+ class QUICHE_EXPORT Stream : public webtransport::StreamVisitor,
+ public MoqtParserVisitor {
+ public:
+ Stream(MoqtSession* session, webtransport::Stream* stream)
+ : session_(session),
+ stream_(stream),
+ parser_(session->parameters_.perspective,
+ session->parameters_.using_webtrans, *this) {}
+ Stream(MoqtSession* session, webtransport::Stream* stream,
+ bool is_control_stream)
+ : session_(session),
+ stream_(stream),
+ parser_(session->parameters_.perspective,
+ session->parameters_.using_webtrans, *this),
+ is_control_stream_(is_control_stream) {}
+
+ // webtransport::StreamVisitor implementation.
+ void OnCanRead() override;
+ void OnCanWrite() override;
+ void OnResetStreamReceived(webtransport::StreamErrorCode error) override;
+ void OnStopSendingReceived(webtransport::StreamErrorCode error) override;
+ void OnWriteSideInDataRecvdState() override {}
+
+ // MoqtParserVisitor implementation.
+ void OnObjectMessage(const MoqtObject& message, absl::string_view payload,
+ bool end_of_message) override {}
+ void OnSetupMessage(const MoqtSetup& message) override;
+ void OnSubscribeRequestMessage(
+ const MoqtSubscribeRequest& message) override {}
+ void OnSubscribeOkMessage(const MoqtSubscribeOk& message) override {}
+ void OnSubscribeErrorMessage(const MoqtSubscribeError& message) override {}
+ void OnUnsubscribeMessage(const MoqtUnsubscribe& message) override {}
+ void OnAnnounceMessage(const MoqtAnnounce& message) override {}
+ void OnAnnounceOkMessage(const MoqtAnnounceOk& message) override {}
+ void OnAnnounceErrorMessage(const MoqtAnnounceError& message) override {}
+ void OnUnannounceMessage(const MoqtUnannounce& message) override {}
+ void OnGoAwayMessage() override {}
+ void OnParsingError(absl::string_view reason) override;
+
+ quic::Perspective perspective() const {
+ return session_->parameters_.perspective;
+ }
+
+ private:
+ MoqtSession* session_;
+ webtransport::Stream* stream_;
+ MoqtParser parser_;
+ // nullopt means "incoming stream, and we don't know if it's the control
+ // stream or a data stream yet".
+ absl::optional<bool> is_control_stream_;
+ };
+
+ webtransport::Session* session_;
+ MoqtSessionParameters parameters_;
+ MoqtSessionEstablishedCallback session_established_callback_;
+ MoqtSessionTerminatedCallback session_terminated_callback_;
+ MoqtFramer framer_;
+
+ absl::optional<webtransport::StreamId> control_stream_;
+ std::string error_;
+};
+
+} // namespace moqt
+
+#endif // QUICHE_QUIC_MOQT_MOQT_SESSION_H_
diff --git a/quiche/quic/moqt/test_tools/moqt_test_message.h b/quiche/quic/moqt/test_tools/moqt_test_message.h
index 876a619..7cf4803 100644
--- a/quiche/quic/moqt/test_tools/moqt_test_message.h
+++ b/quiche/quic/moqt/test_tools/moqt_test_message.h
@@ -234,12 +234,11 @@
bool EqualFieldValues(MessageStructuredData& values) const override {
auto cast = std::get<MoqtSetup>(values);
const MoqtSetup* compare = client_ ? &server_setup_ : &client_setup_;
- if (cast.number_of_supported_versions !=
- compare->number_of_supported_versions) {
+ if (cast.supported_versions.size() != compare->supported_versions.size()) {
QUIC_LOG(INFO) << "SETUP number of supported versions mismatch";
return false;
}
- for (uint64_t i = 0; i < cast.number_of_supported_versions; ++i) {
+ for (uint64_t i = 0; i < cast.supported_versions.size(); ++i) {
// Listed versions are 1 and 2, in that order.
if (cast.supported_versions[i] != compare->supported_versions[i]) {
QUIC_LOG(INFO) << "SETUP supported version mismatch";
@@ -284,14 +283,14 @@
0x01, // version
};
MoqtSetup client_setup_ = {
- /*number_of_supported_versions=*/2,
- /*supported_versions=*/std::vector<uint64_t>({1, 2}),
+ /*supported_versions=*/std::vector<MoqtVersion>(
+ {static_cast<MoqtVersion>(1), static_cast<MoqtVersion>(2)}),
/*role=*/MoqtRole::kBoth,
/*path=*/"foo",
};
MoqtSetup server_setup_ = {
- /*number_of_supported_versions=*/1,
- /*supported_versions=*/std::vector<uint64_t>({1}),
+ /*supported_versions=*/std::vector<MoqtVersion>(
+ {static_cast<MoqtVersion>(1)}),
/*role=*/absl::nullopt,
/*path=*/absl::nullopt,
};