Switch moqt_messages.h types to use std::string instead of absl::string_view
The immediate goal here is to fix a use-after-free caused by the fact that MoqtSession::active_subscriptions_ contains a bunch of MoqtSubscription objects that have views to potentially deleted track names (`blaze test -c dbg --config=asan //third_party/quic/moqt/...` will reproduce those).
The long-term goal is to make moqt_messages.h types safer to use, as they're really prone to UAFs right now.
PiperOrigin-RevId: 612578165
diff --git a/quiche/common/quiche_data_reader.cc b/quiche/common/quiche_data_reader.cc
index 84eca4d..15ad2f9 100644
--- a/quiche/common/quiche_data_reader.cc
+++ b/quiche/common/quiche_data_reader.cc
@@ -231,6 +231,13 @@
return ReadStringPiece(result, result_length);
}
+bool QuicheDataReader::ReadStringVarInt62(std::string& result) {
+ absl::string_view result_view;
+ bool success = ReadStringPieceVarInt62(&result_view);
+ result = std::string(result_view);
+ return success;
+}
+
absl::string_view QuicheDataReader::ReadRemainingPayload() {
absl::string_view payload = PeekRemainingPayload();
pos_ = len_;
diff --git a/quiche/common/quiche_data_reader.h b/quiche/common/quiche_data_reader.h
index 9f7dd56..d50c040 100644
--- a/quiche/common/quiche_data_reader.h
+++ b/quiche/common/quiche_data_reader.h
@@ -8,6 +8,7 @@
#include <cstddef>
#include <cstdint>
#include <limits>
+#include <string>
#include "absl/strings/string_view.h"
#include "quiche/common/platform/api/quiche_export.h"
@@ -111,6 +112,13 @@
// the number and subsequent string, true otherwise.
bool ReadStringPieceVarInt62(absl::string_view* result);
+ // Reads a string prefixed with a RFC 9000 varint length prefix, and copies it
+ // into the provided string.
+ //
+ // Returns false if there is not enough space in the buffer to read
+ // the number and subsequent string, true otherwise.
+ bool ReadStringVarInt62(std::string& result);
+
// Returns the remaining payload as a absl::string_view.
//
// NOTE: Does not copy but rather references strings in the underlying buffer.
diff --git a/quiche/quic/moqt/moqt_messages.h b/quiche/quic/moqt/moqt_messages.h
index 081abcd..81d2c4d 100644
--- a/quiche/quic/moqt/moqt_messages.h
+++ b/quiche/quic/moqt/moqt_messages.h
@@ -157,7 +157,7 @@
struct QUICHE_EXPORT MoqtClientSetup {
std::vector<MoqtVersion> supported_versions;
std::optional<MoqtRole> role;
- std::optional<absl::string_view> path;
+ std::optional<std::string> path;
};
struct QUICHE_EXPORT MoqtServerSetup {
@@ -230,14 +230,14 @@
struct QUICHE_EXPORT MoqtSubscribe {
uint64_t subscribe_id;
uint64_t track_alias;
- absl::string_view track_namespace;
- absl::string_view track_name;
+ std::string track_namespace;
+ std::string track_name;
// If the mode is kNone, the these are std::nullopt.
std::optional<MoqtSubscribeLocation> start_group;
std::optional<MoqtSubscribeLocation> start_object;
std::optional<MoqtSubscribeLocation> end_group;
std::optional<MoqtSubscribeLocation> end_object;
- std::optional<absl::string_view> authorization_info;
+ std::optional<std::string> authorization_info;
};
struct QUICHE_EXPORT MoqtSubscribeOk {
@@ -255,7 +255,7 @@
struct QUICHE_EXPORT MoqtSubscribeError {
uint64_t subscribe_id;
SubscribeErrorCode error_code;
- absl::string_view reason_phrase;
+ std::string reason_phrase;
uint64_t track_alias;
};
@@ -272,32 +272,32 @@
struct QUICHE_EXPORT MoqtSubscribeRst {
uint64_t subscribe_id;
uint64_t error_code;
- absl::string_view reason_phrase;
+ std::string reason_phrase;
uint64_t final_group;
uint64_t final_object;
};
struct QUICHE_EXPORT MoqtAnnounce {
- absl::string_view track_namespace;
- std::optional<absl::string_view> authorization_info;
+ std::string track_namespace;
+ std::optional<std::string> authorization_info;
};
struct QUICHE_EXPORT MoqtAnnounceOk {
- absl::string_view track_namespace;
+ std::string track_namespace;
};
struct QUICHE_EXPORT MoqtAnnounceError {
- absl::string_view track_namespace;
+ std::string track_namespace;
MoqtAnnounceErrorCode error_code;
- absl::string_view reason_phrase;
+ std::string reason_phrase;
};
struct QUICHE_EXPORT MoqtUnannounce {
- absl::string_view track_namespace;
+ std::string track_namespace;
};
struct QUICHE_EXPORT MoqtGoAway {
- absl::string_view new_session_uri;
+ std::string new_session_uri;
};
std::string MoqtMessageTypeToString(MoqtMessageType message_type);
diff --git a/quiche/quic/moqt/moqt_parser.cc b/quiche/quic/moqt/moqt_parser.cc
index 3c8a31c..fe49c7a 100644
--- a/quiche/quic/moqt/moqt_parser.cc
+++ b/quiche/quic/moqt/moqt_parser.cc
@@ -355,8 +355,8 @@
MoqtSubscribe subscribe_request;
if (!reader.ReadVarInt62(&subscribe_request.subscribe_id) ||
!reader.ReadVarInt62(&subscribe_request.track_alias) ||
- !reader.ReadStringPieceVarInt62(&subscribe_request.track_namespace) ||
- !reader.ReadStringPieceVarInt62(&subscribe_request.track_name) ||
+ !reader.ReadStringVarInt62(subscribe_request.track_namespace) ||
+ !reader.ReadStringVarInt62(subscribe_request.track_name) ||
!ReadLocation(reader, subscribe_request.start_group)) {
return 0;
}
@@ -429,7 +429,7 @@
uint64_t error_code;
if (!reader.ReadVarInt62(&subscribe_error.subscribe_id) ||
!reader.ReadVarInt62(&error_code) ||
- !reader.ReadStringPieceVarInt62(&subscribe_error.reason_phrase) ||
+ !reader.ReadStringVarInt62(subscribe_error.reason_phrase) ||
!reader.ReadVarInt62(&subscribe_error.track_alias)) {
return 0;
}
@@ -462,7 +462,7 @@
MoqtSubscribeRst subscribe_rst;
if (!reader.ReadVarInt62(&subscribe_rst.subscribe_id) ||
!reader.ReadVarInt62(&subscribe_rst.error_code) ||
- !reader.ReadStringPieceVarInt62(&subscribe_rst.reason_phrase) ||
+ !reader.ReadStringVarInt62(subscribe_rst.reason_phrase) ||
!reader.ReadVarInt62(&subscribe_rst.final_group) ||
!reader.ReadVarInt62(&subscribe_rst.final_object)) {
return 0;
@@ -473,7 +473,7 @@
size_t MoqtParser::ProcessAnnounce(quic::QuicDataReader& reader) {
MoqtAnnounce announce;
- if (!reader.ReadStringPieceVarInt62(&announce.track_namespace)) {
+ if (!reader.ReadStringVarInt62(announce.track_namespace)) {
return 0;
}
uint64_t num_params;
@@ -506,7 +506,7 @@
size_t MoqtParser::ProcessAnnounceOk(quic::QuicDataReader& reader) {
MoqtAnnounceOk announce_ok;
- if (!reader.ReadStringPieceVarInt62(&announce_ok.track_namespace)) {
+ if (!reader.ReadStringVarInt62(announce_ok.track_namespace)) {
return 0;
}
visitor_.OnAnnounceOkMessage(announce_ok);
@@ -515,7 +515,7 @@
size_t MoqtParser::ProcessAnnounceError(quic::QuicDataReader& reader) {
MoqtAnnounceError announce_error;
- if (!reader.ReadStringPieceVarInt62(&announce_error.track_namespace)) {
+ if (!reader.ReadStringVarInt62(announce_error.track_namespace)) {
return 0;
}
uint64_t error_code;
@@ -523,7 +523,7 @@
return 0;
}
announce_error.error_code = static_cast<MoqtAnnounceErrorCode>(error_code);
- if (!reader.ReadStringPieceVarInt62(&announce_error.reason_phrase)) {
+ if (!reader.ReadStringVarInt62(announce_error.reason_phrase)) {
return 0;
}
visitor_.OnAnnounceErrorMessage(announce_error);
@@ -532,7 +532,7 @@
size_t MoqtParser::ProcessUnannounce(quic::QuicDataReader& reader) {
MoqtUnannounce unannounce;
- if (!reader.ReadStringPieceVarInt62(&unannounce.track_namespace)) {
+ if (!reader.ReadStringVarInt62(unannounce.track_namespace)) {
return 0;
}
visitor_.OnUnannounceMessage(unannounce);
@@ -541,7 +541,7 @@
size_t MoqtParser::ProcessGoAway(quic::QuicDataReader& reader) {
MoqtGoAway goaway;
- if (!reader.ReadStringPieceVarInt62(&goaway.new_session_uri)) {
+ if (!reader.ReadStringVarInt62(goaway.new_session_uri)) {
return 0;
}
visitor_.OnGoAwayMessage(goaway);
diff --git a/quiche/quic/moqt/moqt_parser_test.cc b/quiche/quic/moqt/moqt_parser_test.cc
index 8116455..48fa22a 100644
--- a/quiche/quic/moqt/moqt_parser_test.cc
+++ b/quiche/quic/moqt/moqt_parser_test.cc
@@ -104,116 +104,51 @@
messages_received_++;
last_message_ = TestMessageBase::MessageStructuredData(object);
}
- void OnClientSetupMessage(const MoqtClientSetup& message) override {
+
+ template <typename Message>
+ void OnControlMessage(const Message& message) {
end_of_message_ = true;
- messages_received_++;
- MoqtClientSetup client_setup = message;
- if (client_setup.path.has_value()) {
- string0_ = std::string(*client_setup.path);
- client_setup.path = absl::string_view(string0_);
- }
- last_message_ = TestMessageBase::MessageStructuredData(client_setup);
+ ++messages_received_;
+ last_message_ = TestMessageBase::MessageStructuredData(message);
+ }
+ void OnClientSetupMessage(const MoqtClientSetup& message) override {
+ OnControlMessage(message);
}
void OnServerSetupMessage(const MoqtServerSetup& message) override {
- end_of_message_ = true;
- messages_received_++;
- MoqtServerSetup server_setup = message;
- last_message_ = TestMessageBase::MessageStructuredData(server_setup);
+ OnControlMessage(message);
}
void OnSubscribeMessage(const MoqtSubscribe& message) override {
- end_of_message_ = true;
- messages_received_++;
- MoqtSubscribe subscribe_request = message;
- string0_ = std::string(subscribe_request.track_namespace);
- subscribe_request.track_namespace = absl::string_view(string0_);
- string1_ = std::string(subscribe_request.track_name);
- subscribe_request.track_name = absl::string_view(string1_);
- if (subscribe_request.authorization_info.has_value()) {
- string2_ = std::string(*subscribe_request.authorization_info);
- subscribe_request.authorization_info = absl::string_view(string2_);
- }
- last_message_ = TestMessageBase::MessageStructuredData(subscribe_request);
+ OnControlMessage(message);
}
void OnSubscribeOkMessage(const MoqtSubscribeOk& message) override {
- end_of_message_ = true;
- messages_received_++;
- MoqtSubscribeOk subscribe_ok = message;
- last_message_ = TestMessageBase::MessageStructuredData(subscribe_ok);
+ OnControlMessage(message);
}
void OnSubscribeErrorMessage(const MoqtSubscribeError& message) override {
- end_of_message_ = true;
- messages_received_++;
- MoqtSubscribeError subscribe_error = message;
- string0_ = std::string(subscribe_error.reason_phrase);
- subscribe_error.reason_phrase = absl::string_view(string0_);
- last_message_ = TestMessageBase::MessageStructuredData(subscribe_error);
+ OnControlMessage(message);
}
void OnUnsubscribeMessage(const MoqtUnsubscribe& message) override {
- end_of_message_ = true;
- messages_received_++;
- MoqtUnsubscribe unsubscribe = message;
- last_message_ = TestMessageBase::MessageStructuredData(unsubscribe);
+ OnControlMessage(message);
}
void OnSubscribeFinMessage(const MoqtSubscribeFin& message) override {
- end_of_message_ = true;
- messages_received_++;
- MoqtSubscribeFin subscribe_fin = message;
- last_message_ = TestMessageBase::MessageStructuredData(subscribe_fin);
+ OnControlMessage(message);
}
void OnSubscribeRstMessage(const MoqtSubscribeRst& message) override {
- end_of_message_ = true;
- messages_received_++;
- MoqtSubscribeRst subscribe_rst = message;
- string0_ = std::string(subscribe_rst.reason_phrase);
- subscribe_rst.reason_phrase = absl::string_view(string0_);
- last_message_ = TestMessageBase::MessageStructuredData(subscribe_rst);
+ OnControlMessage(message);
}
void OnAnnounceMessage(const MoqtAnnounce& message) override {
- end_of_message_ = true;
- messages_received_++;
- MoqtAnnounce announce = message;
- string0_ = std::string(announce.track_namespace);
- announce.track_namespace = absl::string_view(string0_);
- if (announce.authorization_info.has_value()) {
- string1_ = std::string(*announce.authorization_info);
- announce.authorization_info = absl::string_view(string1_);
- }
- last_message_ = TestMessageBase::MessageStructuredData(announce);
+ OnControlMessage(message);
}
void OnAnnounceOkMessage(const MoqtAnnounceOk& message) override {
- end_of_message_ = true;
- messages_received_++;
- MoqtAnnounceOk announce_ok = message;
- string0_ = std::string(announce_ok.track_namespace);
- announce_ok.track_namespace = absl::string_view(string0_);
- last_message_ = TestMessageBase::MessageStructuredData(announce_ok);
+ OnControlMessage(message);
}
void OnAnnounceErrorMessage(const MoqtAnnounceError& message) override {
- end_of_message_ = true;
- messages_received_++;
- MoqtAnnounceError announce_error = message;
- string0_ = std::string(announce_error.track_namespace);
- announce_error.track_namespace = absl::string_view(string0_);
- string1_ = std::string(announce_error.reason_phrase);
- announce_error.reason_phrase = absl::string_view(string1_);
- last_message_ = TestMessageBase::MessageStructuredData(announce_error);
+ OnControlMessage(message);
}
void OnUnannounceMessage(const MoqtUnannounce& message) override {
- end_of_message_ = true;
- messages_received_++;
- MoqtUnannounce unannounce = message;
- string0_ = std::string(unannounce.track_namespace);
- unannounce.track_namespace = absl::string_view(string0_);
- last_message_ = TestMessageBase::MessageStructuredData(unannounce);
+ OnControlMessage(message);
}
void OnGoAwayMessage(const MoqtGoAway& message) override {
- got_goaway_ = true;
- end_of_message_ = true;
- messages_received_++;
- MoqtGoAway goaway = message;
- string0_ = std::string(goaway.new_session_uri);
- goaway.new_session_uri = absl::string_view(string0_);
- last_message_ = TestMessageBase::MessageStructuredData(goaway);
+ OnControlMessage(message);
}
void OnParsingError(MoqtError code, absl::string_view reason) override {
QUIC_LOG(INFO) << "Parsing error: " << reason;
@@ -223,14 +158,10 @@
std::optional<absl::string_view> object_payload_;
bool end_of_message_ = false;
- bool got_goaway_ = false;
std::optional<absl::string_view> parsing_error_;
MoqtError parsing_error_code_;
uint64_t messages_received_ = 0;
std::optional<TestMessageBase::MessageStructuredData> last_message_;
- // Stored strings for last_message_. The visitor API does not promise the
- // memory pointed to by string_views is persistent.
- std::string string0_, string1_, string2_;
};
class MoqtParserTest