Clean up MoQT error cases: - Added error codes from the spec. - Reject opening of second bidirectional stream. PiperOrigin-RevId: 601451814
diff --git a/quiche/quic/moqt/moqt_session.cc b/quiche/quic/moqt/moqt_session.cc index b17a40f..adb4689 100644 --- a/quiche/quic/moqt/moqt_session.cc +++ b/quiche/quic/moqt/moqt_session.cc
@@ -44,7 +44,7 @@ webtransport::Stream* control_stream = session_->OpenOutgoingBidirectionalStream(); if (control_stream == nullptr) { - Error("Unable to open a control stream"); + Error(kGenericError, "Unable to open a control stream"); return; } control_stream->SetVisitor(std::make_unique<Stream>( @@ -60,7 +60,7 @@ quiche::QuicheBuffer serialized_setup = framer_.SerializeClientSetup(setup); bool success = control_stream->Write(serialized_setup.AsStringView()); if (!success) { - Error("Failed to write client SETUP message"); + Error(kGenericError, "Failed to write client SETUP message"); return; } QUIC_DLOG(INFO) << ENDPOINT << "Send the SETUP message"; @@ -81,6 +81,10 @@ void MoqtSession::OnIncomingBidirectionalStreamAvailable() { while (webtransport::Stream* stream = session_->AcceptIncomingBidirectionalStream()) { + if (control_stream_.has_value()) { + Error(kProtocolViolation, "Bidirectional stream already open"); + return; + } stream->SetVisitor(std::make_unique<Stream>(this, stream)); stream->visitor()->OnCanRead(); } @@ -93,16 +97,15 @@ } } -void MoqtSession::Error(absl::string_view error) { +void MoqtSession::Error(MoqtError code, absl::string_view error) { if (!error_.empty()) { // Avoid erroring out twice. return; } - QUICHE_DLOG(INFO) << ENDPOINT - << "MOQT session closed with message: " << error; + QUICHE_DLOG(INFO) << ENDPOINT << "MOQT session closed with code: " + << static_cast<int>(code) << " and message: " << error; error_ = std::string(error); - // TODO(vasilvv): figure out the error code. - session_->CloseSession(1, error); + session_->CloseSession(code, error); std::move(session_terminated_callback_)(error); } @@ -126,7 +129,7 @@ bool success = session_->GetStreamById(*control_stream_) ->Write(framer_.SerializeAnnounce(message).AsStringView()); if (!success) { - Error("Failed to write ANNOUNCE message"); + Error(kGenericError, "Failed to write ANNOUNCE message"); return; } QUIC_DLOG(INFO) << ENDPOINT << "Sent ANNOUNCE message for " @@ -227,7 +230,7 @@ session_->GetStreamById(*control_stream_) ->Write(framer_.SerializeSubscribeRequest(message).AsStringView()); if (!success) { - Error("Failed to write SUBSCRIBE_REQUEST message"); + Error(kGenericError, "Failed to write SUBSCRIBE_REQUEST message"); return false; } QUIC_DLOG(INFO) << ENDPOINT << "Sent SUBSCRIBE_REQUEST message for " @@ -316,6 +319,7 @@ webtransport::StreamErrorCode error) { if (is_control_stream_.has_value() && *is_control_stream_) { session_->Error( + kProtocolViolation, absl::StrCat("Control stream reset with error code ", error)); } } @@ -323,6 +327,7 @@ webtransport::StreamErrorCode error) { if (is_control_stream_.has_value() && *is_control_stream_) { session_->Error( + kProtocolViolation, absl::StrCat("Control stream reset with error code ", error)); } } @@ -331,7 +336,8 @@ absl::string_view payload, bool end_of_message) { if (is_control_stream_ == true) { - session_->Error("Received OBJECT message on control stream"); + session_->Error(kProtocolViolation, + "Received OBJECT message on control stream"); return; } QUICHE_DLOG(INFO) << ENDPOINT << "Received OBJECT message on stream " @@ -365,7 +371,7 @@ } if (session_->num_buffered_objects_ >= kMaxBufferedObjects) { session_->num_buffered_objects_++; - session_->Error("Too many buffered objects"); + session_->Error(kGenericError, "Too many buffered objects"); return; } queue->push_back(BufferedObject(stream_->GetStreamId(), message, payload, @@ -387,19 +393,22 @@ void MoqtSession::Stream::OnClientSetupMessage(const MoqtClientSetup& message) { if (is_control_stream_.has_value()) { if (!*is_control_stream_) { - session_->Error("Received SETUP on non-control stream"); + session_->Error(kProtocolViolation, + "Received SETUP on non-control stream"); return; } } else { is_control_stream_ = true; } if (perspective() == Perspective::IS_CLIENT) { - session_->Error("Received CLIENT_SETUP from server"); + session_->Error(kProtocolViolation, "Received CLIENT_SETUP from server"); return; } if (absl::c_find(message.supported_versions, session_->parameters_.version) == message.supported_versions.end()) { - session_->Error(absl::StrCat("Version mismatch: expected 0x", + // TODO(martinduke): Is this the right error code? See issue #346. + session_->Error(kProtocolViolation, + absl::StrCat("Version mismatch: expected 0x", absl::Hex(session_->parameters_.version))); return; } @@ -411,7 +420,7 @@ bool success = stream_->Write( session_->framer_.SerializeServerSetup(response).AsStringView()); if (!success) { - session_->Error("Failed to write server SETUP message"); + session_->Error(kGenericError, "Failed to write server SETUP message"); return; } QUIC_DLOG(INFO) << ENDPOINT << "Sent the SETUP message"; @@ -423,18 +432,21 @@ void MoqtSession::Stream::OnServerSetupMessage(const MoqtServerSetup& message) { if (is_control_stream_.has_value()) { if (!*is_control_stream_) { - session_->Error("Received SETUP on non-control stream"); + session_->Error(kProtocolViolation, + "Received SETUP on non-control stream"); return; } } else { is_control_stream_ = true; } if (perspective() == Perspective::IS_SERVER) { - session_->Error("Received SERVER_SETUP from client"); + session_->Error(kProtocolViolation, "Received SERVER_SETUP from client"); return; } if (message.selected_version != session_->parameters_.version) { - session_->Error(absl::StrCat("Version mismatch: expected 0x", + // TODO(martinduke): Is this the right error code? See issue #346. + session_->Error(kProtocolViolation, + absl::StrCat("Version mismatch: expected 0x", absl::Hex(session_->parameters_.version))); return; } @@ -455,7 +467,7 @@ stream_->Write(session_->framer_.SerializeSubscribeError(subscribe_error) .AsStringView()); if (!success) { - session_->Error("Failed to write SUBSCRIBE_ERROR message"); + session_->Error(kGenericError, "Failed to write SUBSCRIBE_ERROR message"); } } @@ -503,7 +515,7 @@ bool success = stream_->Write( session_->framer_.SerializeSubscribeOk(subscribe_ok).AsStringView()); if (!success) { - session_->Error("Failed to write SUBSCRIBE_OK message"); + session_->Error(kGenericError, "Failed to write SUBSCRIBE_OK message"); return; } QUIC_DLOG(INFO) << ENDPOINT << "Created subscription for " @@ -521,13 +533,14 @@ return; } if (session_->tracks_by_alias_.contains(message.track_id)) { - session_->Error("Received duplicate track_alias"); + session_->Error(kDuplicateTrackAlias, "Received duplicate track_alias"); return; } auto it = session_->remote_tracks_.find(FullTrackName( std::string(message.track_namespace), std::string(message.track_name))); if (it == session_->remote_tracks_.end()) { - session_->Error("Received SUBSCRIBE_OK for nonexistent subscribe"); + session_->Error(kProtocolViolation, + "Received SUBSCRIBE_OK for nonexistent subscribe"); return; } // Note that if there are multiple SUBSCRIBE_OK for the same track, @@ -573,7 +586,8 @@ auto it = session_->remote_tracks_.find(FullTrackName( std::string(message.track_namespace), std::string(message.track_name))); if (it == session_->remote_tracks_.end()) { - session_->Error("Received SUBSCRIBE_ERROR for nonexistent subscribe"); + session_->Error(kProtocolViolation, + "Received SUBSCRIBE_ERROR for nonexistent subscribe"); return; } QUIC_DLOG(INFO) << ENDPOINT << "Received the SUBSCRIBE_ERROR for " @@ -594,7 +608,7 @@ bool success = stream_->Write(session_->framer_.SerializeAnnounceOk(ok).AsStringView()); if (!success) { - session_->Error("Failed to write ANNOUNCE_OK message"); + session_->Error(kGenericError, "Failed to write ANNOUNCE_OK message"); return; } } @@ -605,7 +619,8 @@ } auto it = session_->pending_outgoing_announces_.find(message.track_namespace); if (it == session_->pending_outgoing_announces_.end()) { - session_->Error("Received ANNOUNCE_OK for nonexistent announce"); + session_->Error(kProtocolViolation, + "Received ANNOUNCE_OK for nonexistent announce"); return; } std::move(it->second)(message.track_namespace, std::nullopt); @@ -619,7 +634,8 @@ } auto it = session_->pending_outgoing_announces_.find(message.track_namespace); if (it == session_->pending_outgoing_announces_.end()) { - session_->Error("Received ANNOUNCE_ERROR for nonexistent announce"); + session_->Error(kProtocolViolation, + "Received ANNOUNCE_ERROR for nonexistent announce"); return; } std::move(it->second)(message.track_namespace, message.reason_phrase); @@ -627,16 +643,18 @@ } void MoqtSession::Stream::OnParsingError(absl::string_view reason) { - session_->Error(absl::StrCat("Parse error: ", reason)); + session_->Error(kProtocolViolation, absl::StrCat("Parse error: ", reason)); } bool MoqtSession::Stream::CheckIfIsControlStream() { if (!is_control_stream_.has_value()) { - session_->Error("Received SUBSCRIBE_REQUEST as first message"); + session_->Error(kProtocolViolation, + "Received SUBSCRIBE_REQUEST as first message"); return false; } if (!*is_control_stream_) { - session_->Error("Received SUBSCRIBE_REQUEST on non-control stream"); + session_->Error(kProtocolViolation, + "Received SUBSCRIBE_REQUEST on non-control stream"); return false; } return true;
diff --git a/quiche/quic/moqt/moqt_session.h b/quiche/quic/moqt/moqt_session.h index 645372b..ec71181 100644 --- a/quiche/quic/moqt/moqt_session.h +++ b/quiche/quic/moqt/moqt_session.h
@@ -47,6 +47,16 @@ MoqtSessionDeletedCallback session_deleted_callback = +[] {}; }; +enum MoqtError : uint64_t { + kNoError = 0x0, + kGenericError = 0x1, + kUnauthorized = 0x2, + kProtocolViolation = 0x3, + kDuplicateTrackAlias = 0x4, + kParameterLengthMismatch = 0x5, + kGoawayTimeout = 0x10, +}; + class QUICHE_EXPORT MoqtSession : public webtransport::SessionVisitor { public: MoqtSession(webtransport::Session* session, MoqtSessionParameters parameters, @@ -73,7 +83,7 @@ void OnCanCreateNewOutgoingBidirectionalStream() override {} void OnCanCreateNewOutgoingUnidirectionalStream() override {} - void Error(absl::string_view error); + void Error(MoqtError code, absl::string_view error); quic::Perspective perspective() const { return parameters_.perspective; }
diff --git a/quiche/quic/moqt/moqt_session_test.cc b/quiche/quic/moqt/moqt_session_test.cc index 4e07a7e..34e7fc3 100644 --- a/quiche/quic/moqt/moqt_session_test.cc +++ b/quiche/quic/moqt/moqt_session_test.cc
@@ -235,12 +235,13 @@ TEST_F(MoqtSessionTest, Error) { bool reported_error = false; - EXPECT_CALL(mock_session_, CloseSession(1, "foo")).Times(1); + EXPECT_CALL(mock_session_, CloseSession(kParameterLengthMismatch, "foo")) + .Times(1); EXPECT_CALL(session_callbacks_.session_terminated_callback, Call(_)) .WillOnce([&](absl::string_view error_message) { reported_error = (error_message == "foo"); }); - session_.Error("foo"); + session_.Error(kParameterLengthMismatch, "foo"); EXPECT_TRUE(reported_error); } @@ -673,6 +674,88 @@ EXPECT_EQ(next_seq.object, 1); } +TEST_F(MoqtSessionTest, OneBidirectionalStreamClient) { + StrictMock<webtransport::test::MockStream> mock_stream; + EXPECT_CALL(mock_session_, OpenOutgoingBidirectionalStream()) + .WillOnce(Return(&mock_stream)); + std::unique_ptr<webtransport::StreamVisitor> visitor; + // Save a reference to MoqtSession::Stream + EXPECT_CALL(mock_stream, SetVisitor(_)) + .WillOnce([&](std::unique_ptr<webtransport::StreamVisitor> new_visitor) { + visitor = std::move(new_visitor); + }); + EXPECT_CALL(mock_stream, GetStreamId()) + .WillOnce(Return(webtransport::StreamId(4))); + bool correct_message = false; + EXPECT_CALL(mock_stream, Writev(_, _)) + .WillOnce([&](absl::Span<const absl::string_view> data, + const quiche::StreamWriteOptions& options) { + correct_message = true; + EXPECT_EQ(*ExtractMessageType(data[0]), MoqtMessageType::kClientSetup); + return absl::OkStatus(); + }); + session_.OnSessionReady(); + EXPECT_TRUE(correct_message); + + // Peer tries to open a bidi stream. + bool reported_error = false; + EXPECT_CALL(mock_session_, AcceptIncomingBidirectionalStream()) + .WillOnce(Return(&mock_stream)); + EXPECT_CALL(mock_session_, CloseSession(kProtocolViolation, + "Bidirectional stream already open")) + .Times(1); + EXPECT_CALL(session_callbacks_.session_terminated_callback, Call(_)) + .WillOnce([&](absl::string_view error_message) { + reported_error = (error_message == "Bidirectional stream already open"); + }); + session_.OnIncomingBidirectionalStreamAvailable(); + EXPECT_TRUE(reported_error); +} + +TEST_F(MoqtSessionTest, OneBidirectionalStreamServer) { + MoqtSessionParameters server_parameters = { + /*version=*/MoqtVersion::kDraft01, + /*perspective=*/quic::Perspective::IS_SERVER, + /*using_webtrans=*/true, + /*path=*/"", + /*deliver_partial_objects=*/false, + }; + MoqtSession server_session(&mock_session_, server_parameters, + session_callbacks_.AsSessionCallbacks()); + StrictMock<webtransport::test::MockStream> mock_stream; + std::unique_ptr<MoqtParserVisitor> stream_input = + MoqtSessionPeer::CreateControlStream(&server_session, &mock_stream); + MoqtClientSetup setup = { + /*supported_versions*/ {MoqtVersion::kDraft01}, + /*role=*/MoqtRole::kBoth, + /*path=*/std::nullopt, + }; + bool correct_message = false; + EXPECT_CALL(mock_stream, Writev(_, _)) + .WillOnce([&](absl::Span<const absl::string_view> data, + const quiche::StreamWriteOptions& options) { + correct_message = true; + EXPECT_EQ(*ExtractMessageType(data[0]), MoqtMessageType::kServerSetup); + return absl::OkStatus(); + }); + EXPECT_CALL(session_callbacks_.session_established_callback, Call()).Times(1); + stream_input->OnClientSetupMessage(setup); + + // Peer tries to open a bidi stream. + bool reported_error = false; + EXPECT_CALL(mock_session_, AcceptIncomingBidirectionalStream()) + .WillOnce(Return(&mock_stream)); + EXPECT_CALL(mock_session_, CloseSession(kProtocolViolation, + "Bidirectional stream already open")) + .Times(1); + EXPECT_CALL(session_callbacks_.session_terminated_callback, Call(_)) + .WillOnce([&](absl::string_view error_message) { + reported_error = (error_message == "Bidirectional stream already open"); + }); + server_session.OnIncomingBidirectionalStreamAvailable(); + EXPECT_TRUE(reported_error); +} + // TODO: Cover more error cases in the above } // namespace test
diff --git a/quiche/quic/moqt/tools/chat_client_bin.cc b/quiche/quic/moqt/tools/chat_client_bin.cc index 536709c..85772c4 100644 --- a/quiche/quic/moqt/tools/chat_client_bin.cc +++ b/quiche/quic/moqt/tools/chat_client_bin.cc
@@ -147,7 +147,7 @@ std::optional<absl::string_view> message) { if (message.has_value()) { std::cout << "ANNOUNCE rejected, " << *message << "\n"; - session_->Error("Local ANNOUNCE rejected"); + session_->Error(moqt::kGenericError, "Local ANNOUNCE rejected"); return; } std::cout << "ANNOUNCE for " << track_namespace << " accepted\n"; @@ -224,7 +224,8 @@ if (!got_version) { // Chat server currently does not send version if (line != "version=1") { - session_->Error("Catalog does not begin with version"); + session_->Error(moqt::kProtocolViolation, + "Catalog does not begin with version"); return; } got_version = true; @@ -277,7 +278,8 @@ subscribes_to_make_++; } else { if (it->second.from_group == group_sequence) { - session_->Error("User listed twice in Catalog"); + session_->Error(moqt::kProtocolViolation, + "User listed twice in Catalog"); return; } it->second.from_group = group_sequence;