blob: 06d2808b6b50ceface832bda2cf0d1b24035310f [file] [log] [blame]
// Copyright 2023 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "quiche/quic/moqt/moqt_session.h"
#include <array>
#include <cstdint>
#include <memory>
#include <optional>
#include <string>
#include <utility>
#include <vector>
#include "absl/algorithm/container.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "quiche/quic/core/quic_types.h"
#include "quiche/quic/moqt/moqt_messages.h"
#include "quiche/quic/moqt/moqt_parser.h"
#include "quiche/quic/moqt/moqt_subscribe_windows.h"
#include "quiche/quic/moqt/moqt_track.h"
#include "quiche/quic/platform/api/quic_bug_tracker.h"
#include "quiche/common/platform/api/quiche_bug_tracker.h"
#include "quiche/common/platform/api/quiche_logging.h"
#include "quiche/common/quiche_buffer_allocator.h"
#include "quiche/common/quiche_stream.h"
#include "quiche/web_transport/web_transport.h"
#define ENDPOINT \
(perspective() == Perspective::IS_SERVER ? "MoQT Server: " : "MoQT Client: ")
namespace moqt {
using ::quic::Perspective;
MoqtSession::Stream* MoqtSession::GetControlStream() {
if (!control_stream_.has_value()) {
return nullptr;
}
webtransport::Stream* raw_stream = session_->GetStreamById(*control_stream_);
if (raw_stream == nullptr) {
return nullptr;
}
return static_cast<Stream*>(raw_stream->visitor());
}
void MoqtSession::SendControlMessage(quiche::QuicheBuffer message) {
Stream* control_stream = GetControlStream();
if (control_stream == nullptr) {
QUICHE_LOG(DFATAL) << "Trying to send a message on the control stream "
"while it does not exist";
return;
}
control_stream->SendOrBufferMessage(std::move(message));
}
void MoqtSession::OnSessionReady() {
QUICHE_DLOG(INFO) << ENDPOINT << "Underlying session ready";
if (parameters_.perspective == Perspective::IS_SERVER) {
return;
}
webtransport::Stream* control_stream =
session_->OpenOutgoingBidirectionalStream();
if (control_stream == nullptr) {
Error(MoqtError::kInternalError, "Unable to open a control stream");
return;
}
control_stream->SetVisitor(std::make_unique<Stream>(
this, control_stream, /*is_control_stream=*/true));
control_stream_ = control_stream->GetStreamId();
MoqtClientSetup setup = MoqtClientSetup{
.supported_versions = std::vector<MoqtVersion>{parameters_.version},
.role = MoqtRole::kPubSub,
};
if (!parameters_.using_webtrans) {
setup.path = parameters_.path;
}
SendControlMessage(framer_.SerializeClientSetup(setup));
QUIC_DLOG(INFO) << ENDPOINT << "Send the SETUP message";
}
void MoqtSession::OnSessionClosed(webtransport::SessionErrorCode,
const std::string& error_message) {
if (!error_.empty()) {
// Avoid erroring out twice.
return;
}
QUICHE_DLOG(INFO) << ENDPOINT << "Underlying session closed with message: "
<< error_message;
error_ = error_message;
std::move(callbacks_.session_terminated_callback)(error_message);
}
void MoqtSession::OnIncomingBidirectionalStreamAvailable() {
while (webtransport::Stream* stream =
session_->AcceptIncomingBidirectionalStream()) {
if (control_stream_.has_value()) {
Error(MoqtError::kProtocolViolation, "Bidirectional stream already open");
return;
}
stream->SetVisitor(std::make_unique<Stream>(this, stream));
stream->visitor()->OnCanRead();
}
}
void MoqtSession::OnIncomingUnidirectionalStreamAvailable() {
while (webtransport::Stream* stream =
session_->AcceptIncomingUnidirectionalStream()) {
stream->SetVisitor(std::make_unique<Stream>(this, stream));
stream->visitor()->OnCanRead();
}
}
void MoqtSession::OnDatagramReceived(absl::string_view datagram) {
MoqtObject message;
absl::string_view payload = MoqtParser::ProcessDatagram(datagram, message);
if (payload.empty()) {
Error(MoqtError::kProtocolViolation, "Malformed datagram");
return;
}
QUICHE_DLOG(INFO) << ENDPOINT
<< "Received OBJECT message in datagram for subscribe_id "
<< message.subscribe_id << " for track alias "
<< message.track_alias << " with sequence "
<< message.group_id << ":" << message.object_id
<< " send_order " << message.object_send_order << " length "
<< payload.size();
auto [full_track_name, visitor] = TrackPropertiesFromAlias(message);
if (visitor != nullptr) {
visitor->OnObjectFragment(full_track_name, message.group_id,
message.object_id, message.object_send_order,
message.forwarding_preference, payload, true);
}
}
void MoqtSession::Error(MoqtError code, absl::string_view error) {
if (!error_.empty()) {
// Avoid erroring out twice.
return;
}
QUICHE_DLOG(INFO) << ENDPOINT << "MOQT session closed with code: "
<< static_cast<int>(code) << " and message: " << error;
error_ = std::string(error);
session_->CloseSession(static_cast<uint64_t>(code), error);
std::move(callbacks_.session_terminated_callback)(error);
}
void MoqtSession::AddLocalTrack(const FullTrackName& full_track_name,
MoqtForwardingPreference forwarding_preference,
LocalTrack::Visitor* visitor) {
local_tracks_.try_emplace(full_track_name, full_track_name,
forwarding_preference, visitor);
}
// TODO: Create state that allows ANNOUNCE_OK/ERROR on spurious namespaces to
// trigger session errors.
void MoqtSession::Announce(absl::string_view track_namespace,
MoqtOutgoingAnnounceCallback announce_callback) {
if (peer_role_ == MoqtRole::kPublisher) {
std::move(announce_callback)(
track_namespace,
MoqtAnnounceErrorReason{MoqtAnnounceErrorCode::kInternalError,
"ANNOUNCE cannot be sent to Publisher"});
return;
}
if (pending_outgoing_announces_.contains(track_namespace)) {
std::move(announce_callback)(
track_namespace,
MoqtAnnounceErrorReason{
MoqtAnnounceErrorCode::kInternalError,
"ANNOUNCE message already outstanding for namespace"});
return;
}
MoqtAnnounce message;
message.track_namespace = track_namespace;
SendControlMessage(framer_.SerializeAnnounce(message));
QUIC_DLOG(INFO) << ENDPOINT << "Sent ANNOUNCE message for "
<< message.track_namespace;
pending_outgoing_announces_[track_namespace] = std::move(announce_callback);
}
bool MoqtSession::HasSubscribers(const FullTrackName& full_track_name) const {
auto it = local_tracks_.find(full_track_name);
return (it != local_tracks_.end() && it->second.HasSubscriber());
}
bool MoqtSession::SubscribeAbsolute(absl::string_view track_namespace,
absl::string_view name,
uint64_t start_group, uint64_t start_object,
RemoteTrack::Visitor* visitor,
absl::string_view auth_info) {
MoqtSubscribe message;
message.track_namespace = track_namespace;
message.track_name = name;
message.start_group = start_group;
message.start_object = start_object;
message.end_group = std::nullopt;
message.end_object = std::nullopt;
if (!auth_info.empty()) {
message.authorization_info = std::move(auth_info);
}
return Subscribe(message, visitor);
}
bool MoqtSession::SubscribeAbsolute(absl::string_view track_namespace,
absl::string_view name,
uint64_t start_group, uint64_t start_object,
uint64_t end_group,
RemoteTrack::Visitor* visitor,
absl::string_view auth_info) {
if (end_group < start_group) {
QUIC_DLOG(ERROR) << "Subscription end is before beginning";
return false;
}
MoqtSubscribe message;
message.track_namespace = track_namespace;
message.track_name = name;
message.start_group = start_group;
message.start_object = start_object;
message.end_group = end_group;
message.end_object = std::nullopt;
if (!auth_info.empty()) {
message.authorization_info = std::move(auth_info);
}
return Subscribe(message, visitor);
}
bool MoqtSession::SubscribeAbsolute(absl::string_view track_namespace,
absl::string_view name,
uint64_t start_group, uint64_t start_object,
uint64_t end_group, uint64_t end_object,
RemoteTrack::Visitor* visitor,
absl::string_view auth_info) {
if (end_group < start_group) {
QUIC_DLOG(ERROR) << "Subscription end is before beginning";
return false;
}
if (end_group == start_group && end_object < start_object) {
QUIC_DLOG(ERROR) << "Subscription end is before beginning";
return false;
}
MoqtSubscribe message;
message.track_namespace = track_namespace;
message.track_name = name;
message.start_group = start_group;
message.start_object = start_object;
message.end_group = end_group;
message.end_object = end_object;
if (!auth_info.empty()) {
message.authorization_info = std::move(auth_info);
}
return Subscribe(message, visitor);
}
bool MoqtSession::SubscribeCurrentObject(absl::string_view track_namespace,
absl::string_view name,
RemoteTrack::Visitor* visitor,
absl::string_view auth_info) {
MoqtSubscribe message;
message.track_namespace = track_namespace;
message.track_name = name;
message.start_group = std::nullopt;
message.start_object = std::nullopt;
message.end_group = std::nullopt;
message.end_object = std::nullopt;
if (!auth_info.empty()) {
message.authorization_info = std::move(auth_info);
}
return Subscribe(message, visitor);
}
bool MoqtSession::SubscribeCurrentGroup(absl::string_view track_namespace,
absl::string_view name,
RemoteTrack::Visitor* visitor,
absl::string_view auth_info) {
MoqtSubscribe message;
message.track_namespace = track_namespace;
message.track_name = name;
// First object of current group.
message.start_group = std::nullopt;
message.start_object = 0;
message.end_group = std::nullopt;
message.end_object = std::nullopt;
if (!auth_info.empty()) {
message.authorization_info = std::move(auth_info);
}
return Subscribe(message, visitor);
}
bool MoqtSession::SubscribeIsDone(uint64_t subscribe_id, SubscribeDoneCode code,
absl::string_view reason_phrase) {
// Search all the tracks to find the subscribe ID.
auto name_it = local_track_by_subscribe_id_.find(subscribe_id);
if (name_it == local_track_by_subscribe_id_.end()) {
return false;
}
auto track_it = local_tracks_.find(name_it->second);
if (track_it == local_tracks_.end()) {
return false;
}
LocalTrack& track = track_it->second;
MoqtSubscribeDone subscribe_done;
subscribe_done.subscribe_id = subscribe_id;
subscribe_done.status_code = code;
subscribe_done.reason_phrase = reason_phrase;
SubscribeWindow* window = track.GetWindow(subscribe_id);
if (window == nullptr) {
return false;
}
subscribe_done.final_id = window->largest_delivered();
SendControlMessage(framer_.SerializeSubscribeDone(subscribe_done));
QUIC_DLOG(INFO) << ENDPOINT << "Sent SUBSCRIBE_DONE message for "
<< subscribe_id;
// Clean up the subscription
track.DeleteWindow(subscribe_id);
local_track_by_subscribe_id_.erase(name_it);
return true;
}
bool MoqtSession::Subscribe(MoqtSubscribe& message,
RemoteTrack::Visitor* visitor) {
if (peer_role_ == MoqtRole::kSubscriber) {
QUIC_DLOG(INFO) << ENDPOINT << "Tried to send SUBSCRIBE to subscriber peer";
return false;
}
// TODO(martinduke): support authorization info
message.subscribe_id = next_subscribe_id_++;
FullTrackName ftn(std::string(message.track_namespace),
std::string(message.track_name));
auto it = remote_track_aliases_.find(ftn);
if (it != remote_track_aliases_.end()) {
message.track_alias = it->second;
if (message.track_alias >= next_remote_track_alias_) {
next_remote_track_alias_ = message.track_alias + 1;
}
} else {
message.track_alias = next_remote_track_alias_++;
}
SendControlMessage(framer_.SerializeSubscribe(message));
QUIC_DLOG(INFO) << ENDPOINT << "Sent SUBSCRIBE message for "
<< message.track_namespace << ":" << message.track_name;
active_subscribes_.try_emplace(message.subscribe_id, message, visitor);
return true;
}
std::optional<webtransport::StreamId> MoqtSession::OpenUnidirectionalStream() {
if (!session_->CanOpenNextOutgoingUnidirectionalStream()) {
return std::nullopt;
}
webtransport::Stream* new_stream =
session_->OpenOutgoingUnidirectionalStream();
if (new_stream == nullptr) {
return std::nullopt;
}
new_stream->SetVisitor(std::make_unique<Stream>(this, new_stream, false));
return new_stream->GetStreamId();
}
std::pair<FullTrackName, RemoteTrack::Visitor*>
MoqtSession::TrackPropertiesFromAlias(const MoqtObject& message) {
auto it = remote_tracks_.find(message.track_alias);
RemoteTrack::Visitor* visitor = nullptr;
if (it == remote_tracks_.end()) {
// SUBSCRIBE_OK has not arrived yet, but deliver it.
auto subscribe_it = active_subscribes_.find(message.subscribe_id);
if (subscribe_it == active_subscribes_.end()) {
return std::pair<FullTrackName, RemoteTrack::Visitor*>(
{{"", ""}, nullptr});
}
ActiveSubscribe& subscribe = subscribe_it->second;
visitor = subscribe.visitor;
subscribe.received_object = true;
if (subscribe.forwarding_preference.has_value()) {
if (message.forwarding_preference != *subscribe.forwarding_preference) {
Error(MoqtError::kProtocolViolation,
"Forwarding preference changes mid-track");
return std::pair<FullTrackName, RemoteTrack::Visitor*>(
{{"", ""}, nullptr});
}
} else {
subscribe.forwarding_preference = message.forwarding_preference;
}
return std::pair<FullTrackName, RemoteTrack::Visitor*>(
{{subscribe.message.track_namespace, subscribe.message.track_name},
subscribe.visitor});
}
RemoteTrack& track = it->second;
if (!track.CheckForwardingPreference(message.forwarding_preference)) {
// Incorrect forwarding preference.
Error(MoqtError::kProtocolViolation,
"Forwarding preference changes mid-track");
return std::pair<FullTrackName, RemoteTrack::Visitor*>({{"", ""}, nullptr});
}
return std::pair<FullTrackName, RemoteTrack::Visitor*>(
{{track.full_track_name().track_namespace,
track.full_track_name().track_name},
track.visitor()});
}
bool MoqtSession::PublishObject(const FullTrackName& full_track_name,
uint64_t group_id, uint64_t object_id,
uint64_t object_send_order,
absl::string_view payload, bool end_of_stream) {
auto track_it = local_tracks_.find(full_track_name);
if (track_it == local_tracks_.end()) {
QUICHE_DLOG(ERROR) << ENDPOINT << "Sending OBJECT for nonexistent track";
return false;
}
LocalTrack& track = track_it->second;
MoqtForwardingPreference forwarding_preference =
track.forwarding_preference();
if ((forwarding_preference == MoqtForwardingPreference::kObject ||
forwarding_preference == MoqtForwardingPreference::kDatagram) &&
!end_of_stream) {
QUIC_BUG(MoqtSession_PublishObject_end_of_stream_required)
<< "Forwarding preferences of Object or Datagram require stream to be "
"immediately closed";
return false;
}
track.SentSequence(FullSequence(group_id, object_id));
std::vector<SubscribeWindow*> subscriptions =
track.ShouldSend({group_id, object_id});
if (subscriptions.empty()) {
return true;
}
MoqtObject object;
QUICHE_DCHECK(track.track_alias().has_value());
object.track_alias = *track.track_alias();
object.group_id = group_id;
object.object_id = object_id;
object.object_send_order = object_send_order;
object.forwarding_preference = forwarding_preference;
object.payload_length = payload.size();
int failures = 0;
quiche::StreamWriteOptions write_options;
write_options.set_send_fin(end_of_stream);
for (auto subscription : subscriptions) {
subscription->OnObjectDelivered(FullSequence(group_id, object_id));
if (forwarding_preference == MoqtForwardingPreference::kDatagram) {
object.subscribe_id = subscription->subscribe_id();
quiche::QuicheBuffer datagram =
framer_.SerializeObjectDatagram(object, payload);
// TODO(martinduke): It's OK to just silently fail, but better to notify
// the app on errors.
session_->SendOrQueueDatagram(datagram.AsStringView());
continue;
}
bool new_stream = false;
std::optional<webtransport::StreamId> stream_id =
subscription->GetStreamForSequence(FullSequence(group_id, object_id));
if (!stream_id.has_value()) {
new_stream = true;
stream_id = OpenUnidirectionalStream();
if (!stream_id.has_value()) {
QUICHE_DLOG(ERROR) << ENDPOINT
<< "Sending OBJECT to nonexistent stream";
++failures;
continue;
}
if (!end_of_stream) {
subscription->AddStream(group_id, object_id, *stream_id);
}
}
webtransport::Stream* stream = session_->GetStreamById(*stream_id);
if (stream == nullptr) {
QUICHE_DLOG(ERROR) << ENDPOINT << "Sending OBJECT to nonexistent stream "
<< *stream_id;
++failures;
continue;
}
object.subscribe_id = subscription->subscribe_id();
quiche::QuicheBuffer header =
framer_.SerializeObjectHeader(object, new_stream);
std::array<absl::string_view, 2> views = {header.AsStringView(), payload};
if (!stream->Writev(views, write_options).ok()) {
QUICHE_DLOG(ERROR) << ENDPOINT << "Failed to write OBJECT message";
++failures;
continue;
}
QUICHE_LOG(INFO) << ENDPOINT << "Sending object length " << payload.length()
<< " for " << full_track_name.track_namespace << ":"
<< full_track_name.track_name << " with sequence "
<< object.group_id << ":" << object.object_id
<< " on stream " << *stream_id;
if (end_of_stream && !new_stream) {
subscription->RemoveStream(group_id, object_id);
}
}
return (failures == 0);
}
void MoqtSession::CloseObjectStream(const FullTrackName& full_track_name,
uint64_t group_id) {
auto track_it = local_tracks_.find(full_track_name);
if (track_it == local_tracks_.end()) {
QUICHE_DLOG(ERROR) << ENDPOINT << "Sending OBJECT for nonexistent track";
return;
}
LocalTrack& track = track_it->second;
MoqtForwardingPreference forwarding_preference =
track.forwarding_preference();
if (forwarding_preference == MoqtForwardingPreference::kObject ||
forwarding_preference == MoqtForwardingPreference::kDatagram) {
QUIC_BUG(MoqtSession_CloseStreamObject_wrong_type)
<< "Forwarding preferences of Object or Datagram require stream to be "
"immediately closed, and thus are not valid CloseObjectStream() "
"targets";
return;
}
std::vector<SubscribeWindow*> subscriptions =
track.ShouldSend({group_id, /*object=*/0});
for (SubscribeWindow* subscription : subscriptions) {
std::optional<webtransport::StreamId> stream_id =
subscription->GetStreamForSequence(
FullSequence(group_id, /*object=*/0));
if (!stream_id.has_value()) {
continue;
}
webtransport::Stream* stream = session_->GetStreamById(*stream_id);
if (stream == nullptr) {
continue;
}
bool success = stream->SendFin();
QUICHE_BUG_IF(MoqtSession_CloseObjectStream_fin_failed, !success);
}
}
void MoqtSession::Stream::OnCanRead() {
bool fin =
quiche::ProcessAllReadableRegions(*stream_, [&](absl::string_view chunk) {
parser_.ProcessData(chunk, /*end_of_stream=*/false);
});
if (fin) {
parser_.ProcessData("", /*end_of_stream=*/true);
}
}
void MoqtSession::Stream::OnCanWrite() {}
void MoqtSession::Stream::OnResetStreamReceived(
webtransport::StreamErrorCode error) {
if (is_control_stream_.has_value() && *is_control_stream_) {
session_->Error(
MoqtError::kProtocolViolation,
absl::StrCat("Control stream reset with error code ", error));
}
}
void MoqtSession::Stream::OnStopSendingReceived(
webtransport::StreamErrorCode error) {
if (is_control_stream_.has_value() && *is_control_stream_) {
session_->Error(
MoqtError::kProtocolViolation,
absl::StrCat("Control stream reset with error code ", error));
}
}
void MoqtSession::Stream::OnObjectMessage(const MoqtObject& message,
absl::string_view payload,
bool end_of_message) {
if (is_control_stream_ == true) {
session_->Error(MoqtError::kProtocolViolation,
"Received OBJECT message on control stream");
return;
}
QUICHE_DLOG(INFO)
<< ENDPOINT << "Received OBJECT message on stream "
<< stream_->GetStreamId() << " for subscribe_id " << message.subscribe_id
<< " for track alias " << message.track_alias << " with sequence "
<< message.group_id << ":" << message.object_id << " send_order "
<< message.object_send_order << " forwarding_preference "
<< MoqtForwardingPreferenceToString(message.forwarding_preference)
<< " length " << payload.size() << " explicit length "
<< (message.payload_length.has_value() ? (int)*message.payload_length
: -1)
<< (end_of_message ? "F" : "");
if (!session_->parameters_.deliver_partial_objects) {
if (!end_of_message) { // Buffer partial object.
absl::StrAppend(&partial_object_, payload);
return;
}
if (!partial_object_.empty()) { // Completes the object
absl::StrAppend(&partial_object_, payload);
payload = absl::string_view(partial_object_);
}
}
auto [full_track_name, visitor] = session_->TrackPropertiesFromAlias(message);
if (visitor != nullptr) {
visitor->OnObjectFragment(full_track_name, message.group_id,
message.object_id, message.object_send_order,
message.forwarding_preference, payload,
end_of_message);
}
partial_object_.clear();
}
void MoqtSession::Stream::OnClientSetupMessage(const MoqtClientSetup& message) {
if (is_control_stream_.has_value()) {
if (!*is_control_stream_) {
session_->Error(MoqtError::kProtocolViolation,
"Received SETUP on non-control stream");
return;
}
} else {
is_control_stream_ = true;
}
session_->control_stream_ = stream_->GetStreamId();
if (perspective() == Perspective::IS_CLIENT) {
session_->Error(MoqtError::kProtocolViolation,
"Received CLIENT_SETUP from server");
return;
}
if (absl::c_find(message.supported_versions, session_->parameters_.version) ==
message.supported_versions.end()) {
// TODO(martinduke): Is this the right error code? See issue #346.
session_->Error(MoqtError::kProtocolViolation,
absl::StrCat("Version mismatch: expected 0x",
absl::Hex(session_->parameters_.version)));
return;
}
QUICHE_DLOG(INFO) << ENDPOINT << "Received the SETUP message";
if (session_->parameters_.perspective == Perspective::IS_SERVER) {
MoqtServerSetup response;
response.selected_version = session_->parameters_.version;
response.role = MoqtRole::kPubSub;
SendOrBufferMessage(session_->framer_.SerializeServerSetup(response));
QUIC_DLOG(INFO) << ENDPOINT << "Sent the SETUP message";
}
// TODO: handle role and path.
std::move(session_->callbacks_.session_established_callback)();
session_->peer_role_ = *message.role;
}
void MoqtSession::Stream::OnServerSetupMessage(const MoqtServerSetup& message) {
if (is_control_stream_.has_value()) {
if (!*is_control_stream_) {
session_->Error(MoqtError::kProtocolViolation,
"Received SETUP on non-control stream");
return;
}
} else {
is_control_stream_ = true;
}
if (perspective() == Perspective::IS_SERVER) {
session_->Error(MoqtError::kProtocolViolation,
"Received SERVER_SETUP from client");
return;
}
if (message.selected_version != session_->parameters_.version) {
// TODO(martinduke): Is this the right error code? See issue #346.
session_->Error(MoqtError::kProtocolViolation,
absl::StrCat("Version mismatch: expected 0x",
absl::Hex(session_->parameters_.version)));
return;
}
QUIC_DLOG(INFO) << ENDPOINT << "Received the SETUP message";
// TODO: handle role and path.
std::move(session_->callbacks_.session_established_callback)();
session_->peer_role_ = *message.role;
}
void MoqtSession::Stream::SendSubscribeError(const MoqtSubscribe& message,
SubscribeErrorCode error_code,
absl::string_view reason_phrase,
uint64_t track_alias) {
MoqtSubscribeError subscribe_error;
subscribe_error.subscribe_id = message.subscribe_id;
subscribe_error.error_code = error_code;
subscribe_error.reason_phrase = reason_phrase;
subscribe_error.track_alias = track_alias;
SendOrBufferMessage(
session_->framer_.SerializeSubscribeError(subscribe_error));
}
void MoqtSession::Stream::OnSubscribeMessage(const MoqtSubscribe& message) {
std::string reason_phrase = "";
if (!CheckIfIsControlStream()) {
return;
}
if (session_->peer_role_ == MoqtRole::kPublisher) {
QUIC_DLOG(INFO) << ENDPOINT << "Publisher peer sent SUBSCRIBE";
session_->Error(MoqtError::kProtocolViolation,
"Received SUBSCRIBE from publisher");
return;
}
QUIC_DLOG(INFO) << ENDPOINT << "Received a SUBSCRIBE for "
<< message.track_namespace << ":" << message.track_name;
auto it = session_->local_tracks_.find(FullTrackName(
std::string(message.track_namespace), std::string(message.track_name)));
if (it == session_->local_tracks_.end()) {
QUIC_DLOG(INFO) << ENDPOINT << "Rejected because "
<< message.track_namespace << ":" << message.track_name
<< " does not exist";
SendSubscribeError(message, SubscribeErrorCode::kInternalError,
"Track does not exist", message.track_alias);
return;
}
LocalTrack& track = it->second;
if ((track.track_alias().has_value() &&
message.track_alias != *track.track_alias()) ||
session_->used_track_aliases_.contains(message.track_alias)) {
// Propose a different track_alias.
SendSubscribeError(message, SubscribeErrorCode::kRetryTrackAlias,
"Track alias already exists",
session_->next_local_track_alias_++);
return;
} else { // Use client-provided alias.
track.set_track_alias(message.track_alias);
if (message.track_alias >= session_->next_local_track_alias_) {
session_->next_local_track_alias_ = message.track_alias + 1;
}
session_->used_track_aliases_.insert(message.track_alias);
}
FullSequence start;
std::optional<FullSequence> end;
if (message.start_group.has_value()) {
// The filter is AbsoluteStart or AbsoluteRange.
QUIC_BUG_IF(quic_bug_invalid_subscribe, !message.start_object.has_value())
<< "Start group without start object";
start = FullSequence(*message.start_group, *message.start_object);
} else {
// The filter is LatestObject or LatestGroup.
start = track.next_sequence();
if (message.start_object.has_value()) {
// The filter is LatestGroup.
QUIC_BUG_IF(quic_bug_invalid_subscribe, *message.start_object != 0)
<< "LatestGroup does not start with zero";
start.object = 0;
} else {
--start.object;
}
}
if (message.end_group.has_value()) {
end = FullSequence(*message.end_group, message.end_object.has_value()
? *message.end_object
: UINT64_MAX);
}
LocalTrack::Visitor::PublishPastObjectsCallback publish_past_objects;
SubscribeWindow window =
end.has_value()
? SubscribeWindow(message.subscribe_id, track.forwarding_preference(),
start.group, start.object, end->group, end->object)
: SubscribeWindow(message.subscribe_id, track.forwarding_preference(),
start.group, start.object);
if (start < track.next_sequence() && track.visitor() != nullptr) {
absl::StatusOr<LocalTrack::Visitor::PublishPastObjectsCallback>
past_objects_available = track.visitor()->OnSubscribeForPast(window);
if (!past_objects_available.ok()) {
SendSubscribeError(message, SubscribeErrorCode::kInternalError,
past_objects_available.status().message(),
message.track_alias);
return;
}
publish_past_objects = *std::move(past_objects_available);
}
MoqtSubscribeOk subscribe_ok;
subscribe_ok.subscribe_id = message.subscribe_id;
SendOrBufferMessage(session_->framer_.SerializeSubscribeOk(subscribe_ok));
QUIC_DLOG(INFO) << ENDPOINT << "Created subscription for "
<< message.track_namespace << ":" << message.track_name;
if (end.has_value()) {
track.AddWindow(message.subscribe_id, start.group, start.object, end->group,
end->object);
} else {
track.AddWindow(message.subscribe_id, start.group, start.object);
}
session_->local_track_by_subscribe_id_.emplace(message.subscribe_id,
track.full_track_name());
if (publish_past_objects) {
std::move(publish_past_objects)();
}
}
void MoqtSession::Stream::OnSubscribeOkMessage(const MoqtSubscribeOk& message) {
if (!CheckIfIsControlStream()) {
return;
}
auto it = session_->active_subscribes_.find(message.subscribe_id);
if (it == session_->active_subscribes_.end()) {
session_->Error(MoqtError::kProtocolViolation,
"Received SUBSCRIBE_OK for nonexistent subscribe");
return;
}
MoqtSubscribe& subscribe = it->second.message;
QUIC_DLOG(INFO) << ENDPOINT << "Received the SUBSCRIBE_OK for "
<< "subscribe_id = " << message.subscribe_id << " "
<< subscribe.track_namespace << ":" << subscribe.track_name;
// Copy the Remote Track from session_->active_subscribes_ to
// session_->remote_tracks_.
FullTrackName ftn(subscribe.track_namespace, subscribe.track_name);
RemoteTrack::Visitor* visitor = it->second.visitor;
auto [track_iter, new_entry] = session_->remote_tracks_.try_emplace(
subscribe.track_alias, ftn, subscribe.track_alias, visitor);
if (it->second.forwarding_preference.has_value()) {
if (!track_iter->second.CheckForwardingPreference(
*it->second.forwarding_preference)) {
session_->Error(MoqtError::kProtocolViolation,
"Forwarding preference different in early objects");
return;
}
}
// TODO: handle expires.
if (visitor != nullptr) {
visitor->OnReply(ftn, std::nullopt);
}
session_->active_subscribes_.erase(it);
}
void MoqtSession::Stream::OnSubscribeErrorMessage(
const MoqtSubscribeError& message) {
if (!CheckIfIsControlStream()) {
return;
}
auto it = session_->active_subscribes_.find(message.subscribe_id);
if (it == session_->active_subscribes_.end()) {
session_->Error(MoqtError::kProtocolViolation,
"Received SUBSCRIBE_ERROR for nonexistent subscribe");
return;
}
if (it->second.received_object) {
session_->Error(MoqtError::kProtocolViolation,
"Received SUBSCRIBE_ERROR after object");
return;
}
MoqtSubscribe& subscribe = it->second.message;
QUIC_DLOG(INFO) << ENDPOINT << "Received the SUBSCRIBE_ERROR for "
<< "subscribe_id = " << message.subscribe_id << " ("
<< subscribe.track_namespace << ":" << subscribe.track_name
<< ")" << ", error = " << static_cast<int>(message.error_code)
<< " (" << message.reason_phrase << ")";
RemoteTrack::Visitor* visitor = it->second.visitor;
FullTrackName ftn(subscribe.track_namespace, subscribe.track_name);
if (message.error_code == SubscribeErrorCode::kRetryTrackAlias) {
// Automatically resubscribe with new alias.
session_->remote_track_aliases_[ftn] = message.track_alias;
session_->Subscribe(subscribe, visitor);
} else if (visitor != nullptr) {
visitor->OnReply(ftn, message.reason_phrase);
}
session_->active_subscribes_.erase(it);
}
void MoqtSession::Stream::OnUnsubscribeMessage(const MoqtUnsubscribe& message) {
session_->SubscribeIsDone(message.subscribe_id,
SubscribeDoneCode::kUnsubscribed, "");
}
void MoqtSession::Stream::OnSubscribeUpdateMessage(
const MoqtSubscribeUpdate& message) {
// Search all the tracks to find the subscribe ID.
auto name_it =
session_->local_track_by_subscribe_id_.find(message.subscribe_id);
if (name_it == session_->local_track_by_subscribe_id_.end()) {
return;
}
auto track_it = session_->local_tracks_.find(name_it->second);
if (track_it == session_->local_tracks_.end()) {
return;
}
LocalTrack& track = track_it->second;
SubscribeWindow* window = track.GetWindow(message.subscribe_id);
if (window == nullptr) {
return;
}
FullSequence start(message.start_group, message.start_object);
std::optional<FullSequence> end;
if (message.end_group.has_value()) {
end = FullSequence(*message.end_group, message.end_object.has_value()
? *message.end_object
: UINT64_MAX);
}
// TODO(martinduke): Handle the case where the update range is invalid.
if (window->UpdateStartEnd(start, end)) {
std::optional<FullSequence> largest_delivered = window->largest_delivered();
if (largest_delivered.has_value() && end <= *largest_delivered) {
session_->SubscribeIsDone(message.subscribe_id,
SubscribeDoneCode::kSubscriptionEnded,
"SUBSCRIBE_UPDATE moved subscription end");
}
}
}
void MoqtSession::Stream::OnAnnounceMessage(const MoqtAnnounce& message) {
if (session_->peer_role_ == MoqtRole::kSubscriber) {
QUIC_DLOG(INFO) << ENDPOINT << "Subscriber peer sent SUBSCRIBE";
session_->Error(MoqtError::kProtocolViolation,
"Received ANNOUNCE from Subscriber");
return;
}
if (!CheckIfIsControlStream()) {
return;
}
std::optional<MoqtAnnounceErrorReason> error =
session_->callbacks_.incoming_announce_callback(message.track_namespace);
if (error.has_value()) {
MoqtAnnounceError reply;
reply.track_namespace = message.track_namespace;
reply.error_code = error->error_code;
reply.reason_phrase = error->reason_phrase;
SendOrBufferMessage(session_->framer_.SerializeAnnounceError(reply));
return;
}
MoqtAnnounceOk ok;
ok.track_namespace = message.track_namespace;
SendOrBufferMessage(session_->framer_.SerializeAnnounceOk(ok));
}
void MoqtSession::Stream::OnAnnounceOkMessage(const MoqtAnnounceOk& message) {
if (!CheckIfIsControlStream()) {
return;
}
auto it = session_->pending_outgoing_announces_.find(message.track_namespace);
if (it == session_->pending_outgoing_announces_.end()) {
session_->Error(MoqtError::kProtocolViolation,
"Received ANNOUNCE_OK for nonexistent announce");
return;
}
std::move(it->second)(message.track_namespace, std::nullopt);
session_->pending_outgoing_announces_.erase(it);
}
void MoqtSession::Stream::OnAnnounceErrorMessage(
const MoqtAnnounceError& message) {
if (!CheckIfIsControlStream()) {
return;
}
auto it = session_->pending_outgoing_announces_.find(message.track_namespace);
if (it == session_->pending_outgoing_announces_.end()) {
session_->Error(MoqtError::kProtocolViolation,
"Received ANNOUNCE_ERROR for nonexistent announce");
return;
}
std::move(it->second)(
message.track_namespace,
MoqtAnnounceErrorReason{message.error_code,
std::string(message.reason_phrase)});
session_->pending_outgoing_announces_.erase(it);
}
void MoqtSession::Stream::OnParsingError(MoqtError error_code,
absl::string_view reason) {
session_->Error(error_code, absl::StrCat("Parse error: ", reason));
}
bool MoqtSession::Stream::CheckIfIsControlStream() {
if (!is_control_stream_.has_value()) {
session_->Error(MoqtError::kProtocolViolation,
"Received SUBSCRIBE_REQUEST as first message");
return false;
}
if (!*is_control_stream_) {
session_->Error(MoqtError::kProtocolViolation,
"Received SUBSCRIBE_REQUEST on non-control stream");
return false;
}
return true;
}
void MoqtSession::Stream::SendOrBufferMessage(quiche::QuicheBuffer message,
bool fin) {
quiche::StreamWriteOptions options;
options.set_send_fin(fin);
options.set_buffer_unconditionally(true);
std::array<absl::string_view, 1> write_vector = {message.AsStringView()};
absl::Status success = stream_->Writev(absl::MakeSpan(write_vector), options);
if (!success.ok()) {
session_->Error(MoqtError::kInternalError,
"Failed to write a control message");
}
}
} // namespace moqt