MOQT MAX_SUBSCRIBE_ID implementation

Currently sets max_subscribe_id to UINT62_MAX if it's an earlier draft version. This can be removed soon, when we delete support for draft-05.

PiperOrigin-RevId: 678364064
diff --git a/quiche/quic/moqt/moqt_framer.cc b/quiche/quic/moqt/moqt_framer.cc
index 32d2a5f..f4ddd7a 100644
--- a/quiche/quic/moqt/moqt_framer.cc
+++ b/quiche/quic/moqt/moqt_framer.cc
@@ -263,6 +263,10 @@
     int_parameters.push_back(
         IntParameter(MoqtSetupParameter::kRole, *message.role));
   }
+  if (message.max_subscribe_id.has_value()) {
+    int_parameters.push_back(IntParameter(MoqtSetupParameter::kMaxSubscribeId,
+                                          *message.max_subscribe_id));
+  }
   if (message.supports_object_ack) {
     int_parameters.push_back(
         IntParameter(MoqtSetupParameter::kSupportObjectAcks, 1u));
@@ -287,6 +291,10 @@
     int_parameters.push_back(
         IntParameter(MoqtSetupParameter::kRole, *message.role));
   }
+  if (message.max_subscribe_id.has_value()) {
+    int_parameters.push_back(IntParameter(MoqtSetupParameter::kMaxSubscribeId,
+                                          *message.max_subscribe_id));
+  }
   if (message.supports_object_ack) {
     int_parameters.push_back(
         IntParameter(MoqtSetupParameter::kSupportObjectAcks, 1u));
@@ -499,6 +507,12 @@
                    WireStringWithVarInt62Length(message.new_session_uri));
 }
 
+quiche::QuicheBuffer MoqtFramer::SerializeMaxSubscribeId(
+    const MoqtMaxSubscribeId& message) {
+  return Serialize(WireVarInt62(MoqtMessageType::kMaxSubscribeId),
+                   WireVarInt62(message.max_subscribe_id));
+}
+
 quiche::QuicheBuffer MoqtFramer::SerializeObjectAck(
     const MoqtObjectAck& message) {
   return Serialize(WireVarInt62(MoqtMessageType::kObjectAck),
diff --git a/quiche/quic/moqt/moqt_framer.h b/quiche/quic/moqt/moqt_framer.h
index 188715a..40b4dbf 100644
--- a/quiche/quic/moqt/moqt_framer.h
+++ b/quiche/quic/moqt/moqt_framer.h
@@ -55,6 +55,8 @@
   quiche::QuicheBuffer SerializeUnannounce(const MoqtUnannounce& message);
   quiche::QuicheBuffer SerializeTrackStatus(const MoqtTrackStatus& message);
   quiche::QuicheBuffer SerializeGoAway(const MoqtGoAway& message);
+  quiche::QuicheBuffer SerializeMaxSubscribeId(
+      const MoqtMaxSubscribeId& message);
   quiche::QuicheBuffer SerializeObjectAck(const MoqtObjectAck& message);
 
  private:
diff --git a/quiche/quic/moqt/moqt_framer_test.cc b/quiche/quic/moqt/moqt_framer_test.cc
index cb89551..bd19955 100644
--- a/quiche/quic/moqt/moqt_framer_test.cc
+++ b/quiche/quic/moqt/moqt_framer_test.cc
@@ -48,6 +48,7 @@
       MoqtMessageType::kAnnounceError,
       MoqtMessageType::kUnannounce,
       MoqtMessageType::kGoAway,
+      MoqtMessageType::kMaxSubscribeId,
       MoqtMessageType::kObjectAck,
       MoqtMessageType::kClientSetup,
       MoqtMessageType::kServerSetup,
@@ -157,6 +158,10 @@
         auto data = std::get<MoqtGoAway>(structured_data);
         return framer_.SerializeGoAway(data);
       }
+      case moqt::MoqtMessageType::kMaxSubscribeId: {
+        auto data = std::get<MoqtMaxSubscribeId>(structured_data);
+        return framer_.SerializeMaxSubscribeId(data);
+      }
       case moqt::MoqtMessageType::kObjectAck: {
         auto data = std::get<MoqtObjectAck>(structured_data);
         return framer_.SerializeObjectAck(data);
diff --git a/quiche/quic/moqt/moqt_messages.cc b/quiche/quic/moqt/moqt_messages.cc
index 2ecdbf6..df357fe 100644
--- a/quiche/quic/moqt/moqt_messages.cc
+++ b/quiche/quic/moqt/moqt_messages.cc
@@ -88,6 +88,8 @@
       return "UNANNOUNCE";
     case MoqtMessageType::kGoAway:
       return "GOAWAY";
+    case MoqtMessageType::kMaxSubscribeId:
+      return "MAX_SUBSCRIBE_ID";
     case MoqtMessageType::kObjectAck:
       return "OBJECT_ACK";
   }
diff --git a/quiche/quic/moqt/moqt_messages.h b/quiche/quic/moqt/moqt_messages.h
index 0a6bc55..aaf151d 100644
--- a/quiche/quic/moqt/moqt_messages.h
+++ b/quiche/quic/moqt/moqt_messages.h
@@ -34,6 +34,7 @@
 };
 
 inline constexpr MoqtVersion kDefaultMoqtVersion = MoqtVersion::kDraft05;
+inline constexpr uint64_t kDefaultInitialMaxSubscribeId = 100;
 
 struct QUICHE_EXPORT MoqtSessionParameters {
   // TODO: support multiple versions.
@@ -50,6 +51,7 @@
   quic::Perspective perspective;
   bool using_webtrans;
   std::string path;
+  uint64_t max_subscribe_id = kDefaultInitialMaxSubscribeId;
   bool deliver_partial_objects = false;
   bool support_object_acks = false;
 };
@@ -84,6 +86,7 @@
   kTrackStatusRequest = 0x0d,
   kTrackStatus = 0x0e,
   kGoAway = 0x10,
+  kMaxSubscribeId = 0x15,
   kClientSetup = 0x40,
   kServerSetup = 0x41,
 
@@ -101,6 +104,7 @@
   kProtocolViolation = 0x3,
   kDuplicateTrackAlias = 0x4,
   kParameterLengthMismatch = 0x5,
+  kTooManySubscribes = 0x6,
   kGoawayTimeout = 0x10,
 };
 
@@ -121,6 +125,7 @@
 enum class QUICHE_EXPORT MoqtSetupParameter : uint64_t {
   kRole = 0x0,
   kPath = 0x1,
+  kMaxSubscribeId = 0x2,
 
   // QUICHE-specific extensions.
   // Indicates support for OACK messages.
@@ -220,12 +225,14 @@
   std::vector<MoqtVersion> supported_versions;
   std::optional<MoqtRole> role;
   std::optional<std::string> path;
+  std::optional<uint64_t> max_subscribe_id;
   bool supports_object_ack = false;
 };
 
 struct QUICHE_EXPORT MoqtServerSetup {
   MoqtVersion selected_version;
   std::optional<MoqtRole> role;
+  std::optional<uint64_t> max_subscribe_id;
   bool supports_object_ack = false;
 };
 
@@ -423,6 +430,10 @@
   std::string new_session_uri;
 };
 
+struct QUICHE_EXPORT MoqtMaxSubscribeId {
+  uint64_t max_subscribe_id;
+};
+
 // All of the four values in this message are encoded as varints.
 // `delta_from_deadline` is encoded as an absolute value, with the lowest bit
 // indicating the sign (0 if positive).
diff --git a/quiche/quic/moqt/moqt_parser.cc b/quiche/quic/moqt/moqt_parser.cc
index 923013d..87c45b9 100644
--- a/quiche/quic/moqt/moqt_parser.cc
+++ b/quiche/quic/moqt/moqt_parser.cc
@@ -224,6 +224,8 @@
       return ProcessTrackStatus(reader);
     case MoqtMessageType::kGoAway:
       return ProcessGoAway(reader);
+    case MoqtMessageType::kMaxSubscribeId:
+      return ProcessMaxSubscribeId(reader);
     case moqt::MoqtMessageType::kObjectAck:
       return ProcessObjectAck(reader);
   }
@@ -284,6 +286,18 @@
         }
         setup.path = value;
         break;
+      case MoqtSetupParameter::kMaxSubscribeId:
+        if (setup.max_subscribe_id.has_value()) {
+          ParseError("MAX_SUBSCRIBE_ID parameter appears twice in SETUP");
+          return 0;
+        }
+        uint64_t max_id;
+        if (!StringViewToVarInt(value, max_id)) {
+          ParseError("MAX_SUBSCRIBE_ID parameter is not a valid varint");
+          return 0;
+        }
+        setup.max_subscribe_id = max_id;
+        break;
       case MoqtSetupParameter::kSupportObjectAcks:
         uint64_t flag;
         if (!StringViewToVarInt(value, flag) || flag > 1) {
@@ -347,6 +361,18 @@
       case MoqtSetupParameter::kPath:
         ParseError("PATH parameter in SERVER_SETUP");
         return 0;
+      case MoqtSetupParameter::kMaxSubscribeId:
+        if (setup.max_subscribe_id.has_value()) {
+          ParseError("MAX_SUBSCRIBE_ID parameter appears twice in SETUP");
+          return 0;
+        }
+        uint64_t max_id;
+        if (!StringViewToVarInt(value, max_id)) {
+          ParseError("MAX_SUBSCRIBE_ID parameter is not a valid varint");
+          return 0;
+        }
+        setup.max_subscribe_id = max_id;
+        break;
       case MoqtSetupParameter::kSupportObjectAcks:
         uint64_t flag;
         if (!StringViewToVarInt(value, flag) || flag > 1) {
@@ -723,6 +749,15 @@
   return reader.PreviouslyReadPayload().length();
 }
 
+size_t MoqtControlParser::ProcessMaxSubscribeId(quic::QuicDataReader& reader) {
+  MoqtMaxSubscribeId max_subscribe_id;
+  if (!reader.ReadVarInt62(&max_subscribe_id.max_subscribe_id)) {
+    return 0;
+  }
+  visitor_.OnMaxSubscribeIdMessage(max_subscribe_id);
+  return reader.PreviouslyReadPayload().length();
+}
+
 size_t MoqtControlParser::ProcessObjectAck(quic::QuicDataReader& reader) {
   MoqtObjectAck object_ack;
   uint64_t raw_delta;
diff --git a/quiche/quic/moqt/moqt_parser.h b/quiche/quic/moqt/moqt_parser.h
index f129f72..f19a0a6 100644
--- a/quiche/quic/moqt/moqt_parser.h
+++ b/quiche/quic/moqt/moqt_parser.h
@@ -43,6 +43,7 @@
   virtual void OnUnannounceMessage(const MoqtUnannounce& message) = 0;
   virtual void OnTrackStatusMessage(const MoqtTrackStatus& message) = 0;
   virtual void OnGoAwayMessage(const MoqtGoAway& message) = 0;
+  virtual void OnMaxSubscribeIdMessage(const MoqtMaxSubscribeId& message) = 0;
   virtual void OnObjectAckMessage(const MoqtObjectAck& message) = 0;
 
   virtual void OnParsingError(MoqtError code, absl::string_view reason) = 0;
@@ -107,6 +108,7 @@
   size_t ProcessUnannounce(quic::QuicDataReader& reader);
   size_t ProcessTrackStatus(quic::QuicDataReader& reader);
   size_t ProcessGoAway(quic::QuicDataReader& reader);
+  size_t ProcessMaxSubscribeId(quic::QuicDataReader& reader);
   size_t ProcessObjectAck(quic::QuicDataReader& reader);
 
   // If |error| is not provided, assumes kProtocolViolation.
diff --git a/quiche/quic/moqt/moqt_parser_test.cc b/quiche/quic/moqt/moqt_parser_test.cc
index 53bbdf8..f69ef96 100644
--- a/quiche/quic/moqt/moqt_parser_test.cc
+++ b/quiche/quic/moqt/moqt_parser_test.cc
@@ -39,7 +39,7 @@
     MoqtMessageType::kAnnounceOk,     MoqtMessageType::kAnnounceError,
     MoqtMessageType::kUnannounce,     MoqtMessageType::kClientSetup,
     MoqtMessageType::kServerSetup,    MoqtMessageType::kGoAway,
-    MoqtMessageType::kObjectAck,
+    MoqtMessageType::kMaxSubscribeId, MoqtMessageType::kObjectAck,
 };
 constexpr std::array kDataStreamTypes{MoqtDataStreamType::kObjectStream,
                                       MoqtDataStreamType::kStreamHeaderTrack,
@@ -162,6 +162,9 @@
   void OnGoAwayMessage(const MoqtGoAway& message) override {
     OnControlMessage(message);
   }
+  void OnMaxSubscribeIdMessage(const MoqtMaxSubscribeId& message) override {
+    OnControlMessage(message);
+  }
   void OnObjectAckMessage(const MoqtObjectAck& message) override {
     OnControlMessage(message);
   }
@@ -541,6 +544,24 @@
   EXPECT_EQ(visitor_.parsing_error_code_, MoqtError::kProtocolViolation);
 }
 
+TEST_F(MoqtMessageSpecificTest, ClientSetupMaxSubscribeIdAppearsTwice) {
+  MoqtControlParser parser(kRawQuic, visitor_);
+  char setup[] = {
+      0x40, 0x40, 0x02, 0x01, 0x02,  // versions
+      0x04,                          // 4 params
+      0x00, 0x01, 0x03,              // role = PubSub
+      0x01, 0x03, 0x66, 0x6f, 0x6f,  // path = "foo"
+      0x02, 0x01, 0x32,              // max_subscribe_id = 50
+      0x02, 0x01, 0x32,              // max_subscribe_id = 50
+  };
+  parser.ProcessData(absl::string_view(setup, sizeof(setup)), false);
+  EXPECT_EQ(visitor_.messages_received_, 0);
+  EXPECT_TRUE(visitor_.parsing_error_.has_value());
+  EXPECT_EQ(*visitor_.parsing_error_,
+            "MAX_SUBSCRIBE_ID parameter appears twice in SETUP");
+  EXPECT_EQ(visitor_.parsing_error_code_, MoqtError::kProtocolViolation);
+}
+
 TEST_F(MoqtMessageSpecificTest, ServerSetupRoleIsMissing) {
   MoqtControlParser parser(kRawQuic, visitor_);
   char setup[] = {
@@ -635,6 +656,24 @@
   EXPECT_EQ(visitor_.parsing_error_code_, MoqtError::kProtocolViolation);
 }
 
+TEST_F(MoqtMessageSpecificTest, ServerSetupMaxSubscribeIdAppearsTwice) {
+  MoqtControlParser parser(kRawQuic, visitor_);
+  char setup[] = {
+      0x40, 0x40, 0x02, 0x01, 0x02,  // versions = 1, 2
+      0x04,                          // 4 params
+      0x00, 0x01, 0x03,              // role = PubSub
+      0x01, 0x03, 0x66, 0x6f, 0x6f,  // path = "foo"
+      0x02, 0x01, 0x32,              // max_subscribe_id = 50
+      0x02, 0x01, 0x32,              // max_subscribe_id = 50
+  };
+  parser.ProcessData(absl::string_view(setup, sizeof(setup)), false);
+  EXPECT_EQ(visitor_.messages_received_, 0);
+  EXPECT_TRUE(visitor_.parsing_error_.has_value());
+  EXPECT_EQ(*visitor_.parsing_error_,
+            "MAX_SUBSCRIBE_ID parameter appears twice in SETUP");
+  EXPECT_EQ(visitor_.parsing_error_code_, MoqtError::kProtocolViolation);
+}
+
 TEST_F(MoqtMessageSpecificTest, SubscribeAuthorizationInfoTwice) {
   MoqtControlParser parser(kWebTrans, visitor_);
   char subscribe[] = {
diff --git a/quiche/quic/moqt/moqt_session.cc b/quiche/quic/moqt/moqt_session.cc
index afcee57..d11f9d2 100644
--- a/quiche/quic/moqt/moqt_session.cc
+++ b/quiche/quic/moqt/moqt_session.cc
@@ -105,6 +105,7 @@
       callbacks_(std::move(callbacks)),
       framer_(quiche::SimpleBufferAllocator::Get(), parameters.using_webtrans),
       publisher_(DefaultPublisher::GetInstance()),
+      local_max_subscribe_id_(parameters.max_subscribe_id),
       liveness_token_(std::make_shared<Empty>()) {}
 
 MoqtSession::ControlStream* MoqtSession::GetControlStream() {
@@ -146,6 +147,7 @@
   MoqtClientSetup setup = MoqtClientSetup{
       .supported_versions = std::vector<MoqtVersion>{parameters_.version},
       .role = MoqtRole::kPubSub,
+      .max_subscribe_id = parameters_.max_subscribe_id,
       .supports_object_ack = parameters_.support_object_acks,
   };
   if (!parameters_.using_webtrans) {
@@ -389,6 +391,13 @@
     return false;
   }
   // TODO(martinduke): support authorization info
+  if (next_subscribe_id_ > peer_max_subscribe_id_) {
+    QUIC_DLOG(INFO) << ENDPOINT << "Tried to send SUBSCRIBE with ID "
+                    << message.subscribe_id
+                    << " which is greater than the maximum ID "
+                    << peer_max_subscribe_id_;
+    return false;
+  }
   message.subscribe_id = next_subscribe_id_++;
   FullTrackName ftn(std::string(message.track_namespace),
                     std::string(message.track_name));
@@ -492,6 +501,13 @@
   }
 }
 
+void MoqtSession::GrantMoreSubscribes(uint64_t num_subscribes) {
+  local_max_subscribe_id_ += num_subscribes;
+  MoqtMaxSubscribeId message;
+  message.max_subscribe_id = local_max_subscribe_id_;
+  SendControlMessage(framer_.SerializeMaxSubscribeId(message));
+}
+
 std::pair<FullTrackName, RemoteTrack::Visitor*>
 MoqtSession::TrackPropertiesFromAlias(const MoqtObject& message) {
   auto it = remote_tracks_.find(message.track_alias);
@@ -596,11 +612,18 @@
     MoqtServerSetup response;
     response.selected_version = session_->parameters_.version;
     response.role = MoqtRole::kPubSub;
+    response.max_subscribe_id = session_->parameters_.max_subscribe_id;
     response.supports_object_ack = session_->parameters_.support_object_acks;
     SendOrBufferMessage(session_->framer_.SerializeServerSetup(response));
     QUIC_DLOG(INFO) << ENDPOINT << "Sent the SETUP message";
   }
   // TODO: handle role and path.
+  if (message.max_subscribe_id.has_value()) {
+    session_->peer_max_subscribe_id_ = *message.max_subscribe_id;
+  } else if (session_->parameters_.version == MoqtVersion::kDraft05) {
+    // TODO (martinduke): Delete this when we roll the version number.
+    session_->peer_max_subscribe_id_ = UINT64_MAX >> 2;
+  }
   std::move(session_->callbacks_.session_established_callback)();
   session_->peer_role_ = *message.role;
 }
@@ -622,6 +645,12 @@
   session_->peer_supports_object_ack_ = message.supports_object_ack;
   QUIC_DLOG(INFO) << ENDPOINT << "Received the SETUP message";
   // TODO: handle role and path.
+  if (message.max_subscribe_id.has_value()) {
+    session_->peer_max_subscribe_id_ = *message.max_subscribe_id;
+  } else if (session_->parameters_.version == MoqtVersion::kDraft05) {
+    // TODO (martinduke): Delete this when we roll the version number.
+    session_->peer_max_subscribe_id_ = UINT64_MAX >> 2;
+  }
   std::move(session_->callbacks_.session_established_callback)();
   session_->peer_role_ = *message.role;
 }
@@ -646,6 +675,12 @@
                     "Received SUBSCRIBE from publisher");
     return;
   }
+  if (message.subscribe_id > session_->local_max_subscribe_id_) {
+    QUIC_DLOG(INFO) << ENDPOINT << "Received SUBSCRIBE with too large ID";
+    session_->Error(MoqtError::kTooManySubscribes,
+                    "Received SUBSCRIBE with too large ID");
+    return;
+  }
   QUIC_DLOG(INFO) << ENDPOINT << "Received a SUBSCRIBE for "
                   << message.track_namespace << ":" << message.track_name;
 
@@ -837,6 +872,25 @@
   // TODO: notify the application about this.
 }
 
+void MoqtSession::ControlStream::OnMaxSubscribeIdMessage(
+    const MoqtMaxSubscribeId& message) {
+  if (session_->peer_role_ == MoqtRole::kSubscriber) {
+    QUIC_DLOG(INFO) << ENDPOINT << "Subscriber peer sent MAX_SUBSCRIBE_ID";
+    session_->Error(MoqtError::kProtocolViolation,
+                    "Received MAX_SUBSCRIBE_ID from Subscriber");
+    return;
+  }
+  if (message.max_subscribe_id < session_->peer_max_subscribe_id_) {
+    QUIC_DLOG(INFO) << ENDPOINT
+                    << "Peer sent MAX_SUBSCRIBE_ID message with "
+                       "lower value than previous";
+    session_->Error(MoqtError::kProtocolViolation,
+                    "MAX_SUBSCRIBE_ID message has lower value than previous");
+    return;
+  }
+  session_->peer_max_subscribe_id_ = message.max_subscribe_id;
+}
+
 void MoqtSession::ControlStream::OnParsingError(MoqtError error_code,
                                                 absl::string_view reason) {
   session_->Error(error_code, absl::StrCat("Parse error: ", reason));
diff --git a/quiche/quic/moqt/moqt_session.h b/quiche/quic/moqt/moqt_session.h
index d437dc3..79104c6 100644
--- a/quiche/quic/moqt/moqt_session.h
+++ b/quiche/quic/moqt/moqt_session.h
@@ -172,6 +172,8 @@
       std::optional<webtransport::SendOrder> old_send_order,
       std::optional<webtransport::SendOrder> new_send_order);
 
+  void GrantMoreSubscribes(uint64_t num_subscribes);
+
  private:
   friend class test::MoqtSessionPeer;
 
@@ -207,6 +209,7 @@
     void OnUnannounceMessage(const MoqtUnannounce& /*message*/) override {}
     void OnTrackStatusMessage(const MoqtTrackStatus& message) override {}
     void OnGoAwayMessage(const MoqtGoAway& /*message*/) override {}
+    void OnMaxSubscribeIdMessage(const MoqtMaxSubscribeId& message) override;
     void OnObjectAckMessage(const MoqtObjectAck& message) override {
       auto subscription_it =
           session_->published_subscriptions_.find(message.subscribe_id);
@@ -521,6 +524,11 @@
   // parameter, and other checks have changed/been disabled.
   MoqtRole peer_role_ = MoqtRole::kPubSub;
 
+  // The maximum subscribe ID that the local endpoint can send.
+  uint64_t peer_max_subscribe_id_ = 0;
+  // The maximum subscribe ID sent to the peer.
+  uint64_t local_max_subscribe_id_ = 0;
+
   // Must be last.  Token used to make sure that the streams do not call into
   // the session when the session has already been destroyed.
   struct Empty {};
diff --git a/quiche/quic/moqt/moqt_session_test.cc b/quiche/quic/moqt/moqt_session_test.cc
index cd7d49c..25783a3 100644
--- a/quiche/quic/moqt/moqt_session_test.cc
+++ b/quiche/quic/moqt/moqt_session_test.cc
@@ -10,6 +10,7 @@
 #include <optional>
 #include <string>
 #include <utility>
+#include <vector>
 
 #include "absl/status/status.h"
 #include "absl/strings/match.h"
@@ -157,15 +158,25 @@
   static RemoteTrack& remote_track(MoqtSession* session, uint64_t track_alias) {
     return session->remote_tracks_.find(track_alias)->second;
   }
+
+  static void set_next_subscribe_id(MoqtSession* session, uint64_t id) {
+    session->next_subscribe_id_ = id;
+  }
+
+  static void set_peer_max_subscribe_id(MoqtSession* session, uint64_t id) {
+    session->peer_max_subscribe_id_ = id;
+  }
 };
 
 class MoqtSessionTest : public quic::test::QuicTest {
  public:
   MoqtSessionTest()
       : session_(&mock_session_,
-                 MoqtSessionParameters(quic::Perspective::IS_CLIENT),
+                 MoqtSessionParameters(quic::Perspective::IS_CLIENT, ""),
                  session_callbacks_.AsSessionCallbacks()) {
     session_.set_publisher(&publisher_);
+    MoqtSessionPeer::set_peer_max_subscribe_id(&session_,
+                                               kDefaultInitialMaxSubscribeId);
   }
   ~MoqtSessionTest() {
     EXPECT_CALL(session_callbacks_.session_deleted_callback, Call());
@@ -464,6 +475,53 @@
   EXPECT_TRUE(correct_message);
 }
 
+TEST_F(MoqtSessionTest, SubscribeIdTooHigh) {
+  // Peer subscribes to (0, 0)
+  MoqtSubscribe request = {
+      /*subscribe_id=*/kDefaultInitialMaxSubscribeId + 1,
+      /*track_alias=*/2,
+      /*track_namespace=*/"foo",
+      /*track_name=*/"bar",
+      /*subscriber_priority=*/0x80,
+      /*group_order=*/std::nullopt,
+      /*start_group=*/0,
+      /*start_object=*/0,
+      /*end_group=*/std::nullopt,
+      /*end_object=*/std::nullopt,
+      /*parameters=*/MoqtSubscribeParameters(),
+  };
+  webtransport::test::MockStream mock_stream;
+  std::unique_ptr<MoqtControlParserVisitor> stream_input =
+      MoqtSessionPeer::CreateControlStream(&session_, &mock_stream);
+  EXPECT_CALL(mock_session_,
+              CloseSession(static_cast<uint64_t>(MoqtError::kTooManySubscribes),
+                           "Received SUBSCRIBE with too large ID"))
+      .Times(1);
+  stream_input->OnSubscribeMessage(request);
+}
+
+TEST_F(MoqtSessionTest, TooManySubscribes) {
+  MoqtSessionPeer::set_next_subscribe_id(&session_,
+                                         kDefaultInitialMaxSubscribeId);
+  MockRemoteTrackVisitor remote_track_visitor;
+  webtransport::test::MockStream mock_stream;
+  std::unique_ptr<MoqtControlParserVisitor> stream_input =
+      MoqtSessionPeer::CreateControlStream(&session_, &mock_stream);
+  EXPECT_CALL(mock_session_, GetStreamById(_)).WillOnce(Return(&mock_stream));
+  bool correct_message = true;
+  EXPECT_CALL(mock_stream, Writev(_, _))
+      .WillOnce([&](absl::Span<const absl::string_view> data,
+                    const quiche::StreamWriteOptions& options) {
+        correct_message = true;
+        EXPECT_EQ(*ExtractMessageType(data[0]), MoqtMessageType::kSubscribe);
+        return absl::OkStatus();
+      });
+  EXPECT_TRUE(
+      session_.SubscribeCurrentGroup("foo", "bar", &remote_track_visitor));
+  EXPECT_FALSE(
+      session_.SubscribeCurrentGroup("foo", "bar", &remote_track_visitor));
+}
+
 TEST_F(MoqtSessionTest, SubscribeWithOk) {
   webtransport::test::MockStream mock_stream;
   std::unique_ptr<MoqtControlParserVisitor> stream_input =
@@ -496,6 +554,102 @@
   EXPECT_TRUE(correct_message);
 }
 
+TEST_F(MoqtSessionTest, MaxSubscribeIdChangesResponse) {
+  MoqtSessionPeer::set_next_subscribe_id(&session_,
+                                         kDefaultInitialMaxSubscribeId + 1);
+  MockRemoteTrackVisitor remote_track_visitor;
+  EXPECT_FALSE(
+      session_.SubscribeCurrentGroup("foo", "bar", &remote_track_visitor));
+  MoqtMaxSubscribeId max_subscribe_id = {
+      /*max_subscribe_id=*/kDefaultInitialMaxSubscribeId + 1,
+  };
+  webtransport::test::MockStream mock_stream;
+  std::unique_ptr<MoqtControlParserVisitor> stream_input =
+      MoqtSessionPeer::CreateControlStream(&session_, &mock_stream);
+  stream_input->OnMaxSubscribeIdMessage(max_subscribe_id);
+  EXPECT_CALL(mock_session_, GetStreamById(_)).WillOnce(Return(&mock_stream));
+  bool correct_message = true;
+  EXPECT_CALL(mock_stream, Writev(_, _))
+      .WillOnce([&](absl::Span<const absl::string_view> data,
+                    const quiche::StreamWriteOptions& options) {
+        correct_message = true;
+        EXPECT_EQ(*ExtractMessageType(data[0]), MoqtMessageType::kSubscribe);
+        return absl::OkStatus();
+      });
+  EXPECT_TRUE(
+      session_.SubscribeCurrentGroup("foo", "bar", &remote_track_visitor));
+  EXPECT_TRUE(correct_message);
+}
+
+TEST_F(MoqtSessionTest, LowerMaxSubscribeIdIsAnError) {
+  MoqtMaxSubscribeId max_subscribe_id = {
+      /*max_subscribe_id=*/kDefaultInitialMaxSubscribeId - 1,
+  };
+  webtransport::test::MockStream mock_stream;
+  std::unique_ptr<MoqtControlParserVisitor> stream_input =
+      MoqtSessionPeer::CreateControlStream(&session_, &mock_stream);
+  EXPECT_CALL(
+      mock_session_,
+      CloseSession(static_cast<uint64_t>(MoqtError::kProtocolViolation),
+                   "MAX_SUBSCRIBE_ID message has lower value than previous"))
+      .Times(1);
+  stream_input->OnMaxSubscribeIdMessage(max_subscribe_id);
+}
+
+TEST_F(MoqtSessionTest, GrantMoreSubscribes) {
+  webtransport::test::MockStream mock_stream;
+  std::unique_ptr<MoqtControlParserVisitor> stream_input =
+      MoqtSessionPeer::CreateControlStream(&session_, &mock_stream);
+  EXPECT_CALL(mock_session_, GetStreamById(_)).WillOnce(Return(&mock_stream));
+  bool correct_message = true;
+  EXPECT_CALL(mock_stream, Writev(_, _))
+      .WillOnce([&](absl::Span<const absl::string_view> data,
+                    const quiche::StreamWriteOptions& options) {
+        correct_message = true;
+        EXPECT_EQ(*ExtractMessageType(data[0]),
+                  MoqtMessageType::kMaxSubscribeId);
+        return absl::OkStatus();
+      });
+  session_.GrantMoreSubscribes(1);
+  EXPECT_TRUE(correct_message);
+  // Peer subscribes to (0, 0)
+  MoqtSubscribe request = {
+      /*subscribe_id=*/kDefaultInitialMaxSubscribeId + 1,
+      /*track_alias=*/2,
+      /*track_namespace=*/"foo",
+      /*track_name=*/"bar",
+      /*subscriber_priority=*/0x80,
+      /*group_order=*/std::nullopt,
+      /*start_group=*/0,
+      /*start_object=*/0,
+      /*end_group=*/std::nullopt,
+      /*end_object=*/std::nullopt,
+      /*parameters=*/MoqtSubscribeParameters(),
+  };
+  correct_message = false;
+  FullTrackName ftn("foo", "bar");
+  auto track = std::make_shared<MockTrackPublisher>(ftn);
+  EXPECT_CALL(*track, GetTrackStatus())
+      .WillRepeatedly(Return(MoqtTrackStatusCode::kInProgress));
+  EXPECT_CALL(*track, GetCachedObject(_)).WillRepeatedly([] {
+    return std::optional<PublishedObject>();
+  });
+  EXPECT_CALL(*track, GetCachedObjectsInRange(_, _))
+      .WillRepeatedly(Return(std::vector<FullSequence>()));
+  EXPECT_CALL(*track, GetLargestSequence())
+      .WillRepeatedly(Return(FullSequence(10, 20)));
+  publisher_.Add(track);
+  EXPECT_CALL(mock_stream, Writev(_, _))
+      .WillOnce([&](absl::Span<const absl::string_view> data,
+                    const quiche::StreamWriteOptions& options) {
+        correct_message = true;
+        EXPECT_EQ(*ExtractMessageType(data[0]), MoqtMessageType::kSubscribeOk);
+        return absl::OkStatus();
+      });
+  stream_input->OnSubscribeMessage(request);
+  EXPECT_TRUE(correct_message);
+}
+
 TEST_F(MoqtSessionTest, SubscribeWithError) {
   webtransport::test::MockStream mock_stream;
   std::unique_ptr<MoqtControlParserVisitor> stream_input =
diff --git a/quiche/quic/moqt/test_tools/moqt_test_message.h b/quiche/quic/moqt/test_tools/moqt_test_message.h
index 9633ef3..813d4e8 100644
--- a/quiche/quic/moqt/test_tools/moqt_test_message.h
+++ b/quiche/quic/moqt/test_tools/moqt_test_message.h
@@ -38,7 +38,7 @@
                     MoqtSubscribeDone, MoqtSubscribeUpdate, MoqtAnnounce,
                     MoqtAnnounceOk, MoqtAnnounceError, MoqtAnnounceCancel,
                     MoqtTrackStatusRequest, MoqtUnannounce, MoqtTrackStatus,
-                    MoqtGoAway, MoqtObjectAck>;
+                    MoqtGoAway, MoqtMaxSubscribeId, MoqtObjectAck>;
 
   // The total actual size of the message.
   size_t total_message_size() const { return wire_image_size_; }
@@ -282,7 +282,7 @@
     if (webtrans) {
       // Should not send PATH.
       client_setup_.path = std::nullopt;
-      raw_packet_[5] = 0x01;  // only one parameter
+      raw_packet_[5] = 0x02;  // only two parameters
       SetWireImage(raw_packet_, sizeof(raw_packet_) - 5);
     } else {
       SetWireImage(raw_packet_, sizeof(raw_packet_));
@@ -311,16 +311,20 @@
       QUIC_LOG(INFO) << "CLIENT_SETUP path mismatch";
       return false;
     }
+    if (cast.max_subscribe_id != client_setup_.max_subscribe_id) {
+      QUIC_LOG(INFO) << "CLIENT_SETUP max_subscribe_id mismatch";
+      return false;
+    }
     return true;
   }
 
   void ExpandVarints() override {
     if (client_setup_.path.has_value()) {
-      ExpandVarintsImpl("--vvvvvv-vv---");
+      ExpandVarintsImpl("--vvvvvv-vv-vv---");
       // first two bytes are already a 2B varint. Also, don't expand parameter
       // varints because that messes up the parameter length field.
     } else {
-      ExpandVarintsImpl("--vvvvvv-");
+      ExpandVarintsImpl("--vvvvvv-vv-");
     }
   }
 
@@ -329,18 +333,20 @@
   }
 
  private:
-  uint8_t raw_packet_[14] = {
-      0x40, 0x40,                   // type
-      0x02, 0x01, 0x02,             // versions
-      0x02,                         // 2 parameters
-      0x00, 0x01, 0x03,             // role = PubSub
-      0x01, 0x03, 0x66, 0x6f, 0x6f  // path = "foo"
+  uint8_t raw_packet_[17] = {
+      0x40, 0x40,                    // type
+      0x02, 0x01, 0x02,              // versions
+      0x03,                          // 3 parameters
+      0x00, 0x01, 0x03,              // role = PubSub
+      0x02, 0x01, 0x32,              // max_subscribe_id = 50
+      0x01, 0x03, 0x66, 0x6f, 0x6f,  // path = "foo"
   };
   MoqtClientSetup client_setup_ = {
       /*supported_versions=*/std::vector<MoqtVersion>(
           {static_cast<MoqtVersion>(1), static_cast<MoqtVersion>(2)}),
       /*role=*/MoqtRole::kPubSub,
       /*path=*/"foo",
+      /*max_subscribe_id=*/50,
   };
 };
 
@@ -360,11 +366,15 @@
       QUIC_LOG(INFO) << "SERVER_SETUP role mismatch";
       return false;
     }
+    if (cast.max_subscribe_id != server_setup_.max_subscribe_id) {
+      QUIC_LOG(INFO) << "SERVER_SETUP max_subscribe_id mismatch";
+      return false;
+    }
     return true;
   }
 
   void ExpandVarints() override {
-    ExpandVarintsImpl("--vvvv-");  // first two bytes are already a 2b varint
+    ExpandVarintsImpl("--vvvv-vv-");  // first two bytes are already a 2b varint
   }
 
   MessageStructuredData structured_data() const override {
@@ -372,14 +382,16 @@
   }
 
  private:
-  uint8_t raw_packet_[7] = {
+  uint8_t raw_packet_[10] = {
       0x40, 0x41,        // type
-      0x01, 0x01,        // version, one param
+      0x01, 0x02,        // version, two parameters
       0x00, 0x01, 0x03,  // role = PubSub
+      0x02, 0x01, 0x32,  // max_subscribe_id = 50
   };
   MoqtServerSetup server_setup_ = {
       /*selected_version=*/static_cast<MoqtVersion>(1),
       /*role=*/MoqtRole::kPubSub,
+      /*max_subscribe_id=*/50,
   };
 };
 
@@ -1031,6 +1043,38 @@
   };
 };
 
+class QUICHE_NO_EXPORT MaxSubscribeIdMessage : public TestMessageBase {
+ public:
+  MaxSubscribeIdMessage() : TestMessageBase() {
+    SetWireImage(raw_packet_, sizeof(raw_packet_));
+  }
+
+  bool EqualFieldValues(MessageStructuredData& values) const override {
+    auto cast = std::get<MoqtMaxSubscribeId>(values);
+    if (cast.max_subscribe_id != max_subscribe_id_.max_subscribe_id) {
+      QUIC_LOG(INFO) << "MAX_SUBSCRIBE_ID mismatch";
+      return false;
+    }
+    return true;
+  }
+
+  void ExpandVarints() override { ExpandVarintsImpl("vv"); }
+
+  MessageStructuredData structured_data() const override {
+    return TestMessageBase::MessageStructuredData(max_subscribe_id_);
+  }
+
+ private:
+  uint8_t raw_packet_[2] = {
+      0x15,
+      0x0b,
+  };
+
+  MoqtMaxSubscribeId max_subscribe_id_ = {
+      /*max_subscribe_id =*/11,
+  };
+};
+
 class QUICHE_NO_EXPORT ObjectAckMessage : public TestMessageBase {
  public:
   ObjectAckMessage() : TestMessageBase() {
@@ -1111,6 +1155,8 @@
       return std::make_unique<TrackStatusMessage>();
     case MoqtMessageType::kGoAway:
       return std::make_unique<GoAwayMessage>();
+    case MoqtMessageType::kMaxSubscribeId:
+      return std::make_unique<MaxSubscribeIdMessage>();
     case MoqtMessageType::kObjectAck:
       return std::make_unique<ObjectAckMessage>();
     case MoqtMessageType::kClientSetup: