Parse only the message type on UnknownBidiStream.

This avoids the rickety nature of parsing the whole message and possibly destroying the parser.

The ASAN tests that previously caused trouble pass.

PiperOrigin-RevId: 863412638
diff --git a/quiche/quic/moqt/moqt_bidi_stream.h b/quiche/quic/moqt/moqt_bidi_stream.h
index e1e5700..f58d65c 100644
--- a/quiche/quic/moqt/moqt_bidi_stream.h
+++ b/quiche/quic/moqt/moqt_bidi_stream.h
@@ -9,6 +9,7 @@
 #include <cstddef>
 #include <cstdint>
 #include <memory>
+#include <optional>
 #include <queue>
 #include <utility>
 
@@ -55,12 +56,8 @@
         stream_deleted_callback_(std::move(stream_deleted_callback)),
         session_error_callback_(std::move(session_error_callback)) {}
   ~MoqtBidiStreamBase() override { std::move(stream_deleted_callback_)(); }
-  // The caller is responsible for calling stream->SetVisitor(). Derived
-  // classes will wrap this with a call to stream->SetPriority().
   virtual void set_stream(webtransport::Stream* absl_nonnull stream) {
-    stream_ = stream;
-    parser_ = std::make_unique<MoqtControlParser>(framer_->using_webtrans(),
-                                                  stream_, *this);
+    set_stream(stream, std::nullopt);
   }
 
   // MoqtControlParserVisitor implementation. All control messages are protocol
@@ -217,6 +214,17 @@
   }
 
  protected:
+  // The caller is responsible for calling stream->SetVisitor(). Derived
+  // classes will wrap this with a call to stream->SetPriority().
+  void set_stream(webtransport::Stream* absl_nonnull stream,
+                  std::optional<MoqtMessageType> first_message_type) {
+    stream_ = stream;
+    parser_ = std::make_unique<MoqtControlParser>(framer_->using_webtrans(),
+                                                  stream_, *this);
+    if (first_message_type.has_value()) {
+      parser_->set_message_type(static_cast<uint64_t>(*first_message_type));
+    }
+  }
   const size_t kMaxPendingMessages = 100;
   void AddToQueue(quiche::QuicheBuffer message) {
     if (pending_messages_.size() == kMaxPendingMessages) {
diff --git a/quiche/quic/moqt/moqt_parser.cc b/quiche/quic/moqt/moqt_parser.cc
index 9bbe71e..209c73c 100644
--- a/quiche/quic/moqt/moqt_parser.cc
+++ b/quiche/quic/moqt/moqt_parser.cc
@@ -471,7 +471,19 @@
   return error;
 }
 
-void MoqtControlParser::ReadAndDispatchMessages(bool stop_after_one_message) {
+bool MoqtMessageTypeParser::ReadUntilMessageTypeKnown() {
+  if (message_type_.has_value()) {
+    return true;
+  }
+  bool fin_read = false;
+  message_type_ = ReadVarInt62FromStream(stream_, fin_read);
+  if (fin_read) {
+    return false;
+  }
+  return true;
+}
+
+void MoqtControlParser::ReadAndDispatchMessages() {
   if (no_more_data_) {
     ParseError("Data after end of stream");
     return;
@@ -480,12 +492,7 @@
     return;
   }
   processing_ = true;
-  bool clear_processing_on_return = true;
-  auto on_return = absl::MakeCleanup([&] {
-    if (clear_processing_on_return) {
-      processing_ = false;
-    }
-  });
+  auto on_return = absl::MakeCleanup([&] { processing_ = false; });
   while (!no_more_data_) {
     bool fin_read = false;
     // Read the message type.
@@ -547,23 +554,10 @@
       ParseError("FIN on control stream");
       return;
     }
-    // It's possible ProcessMessage destroys the parser if
-    // stop_after_one_message is true, so extract what is needed so it can be
-    // reset beforehand.
-    QUICHE_DCHECK(message_type_.has_value());
-    MoqtMessageType message_type = static_cast<MoqtMessageType>(*message_type_);
+    ProcessMessage(absl::string_view(message.data(), message.size()),
+                   static_cast<MoqtMessageType>(*message_type_));
     message_type_.reset();
     message_size_.reset();
-    if (stop_after_one_message) {
-      clear_processing_on_return = false;
-      processing_ = false;
-    }
-    ProcessMessage(absl::string_view(message.data(), message.size()),
-                   message_type);
-    if (stop_after_one_message) {
-      return;
-    }
-    processing_ = true;
   }
 }
 
diff --git a/quiche/quic/moqt/moqt_parser.h b/quiche/quic/moqt/moqt_parser.h
index 19f2d21..e84eccc 100644
--- a/quiche/quic/moqt/moqt_parser.h
+++ b/quiche/quic/moqt/moqt_parser.h
@@ -13,7 +13,6 @@
 #include <optional>
 #include <string>
 
-#include "absl/base/nullability.h"
 #include "absl/strings/string_view.h"
 #include "quiche/quic/core/quic_data_reader.h"
 #include "quiche/quic/moqt/moqt_error.h"
@@ -88,6 +87,20 @@
   virtual void OnParsingError(MoqtError code, absl::string_view reason) = 0;
 };
 
+class QUICHE_EXPORT MoqtMessageTypeParser {
+ public:
+  MoqtMessageTypeParser(quiche::ReadStream* stream) : stream_(*stream) {}
+  ~MoqtMessageTypeParser() = default;
+
+  // Returns false if there was a FIN.
+  bool ReadUntilMessageTypeKnown();
+  std::optional<uint64_t> message_type() const { return message_type_; }
+
+ private:
+  quiche::ReadStream& stream_;
+  std::optional<uint64_t> message_type_;
+};
+
 class QUICHE_EXPORT MoqtControlParser {
  public:
   MoqtControlParser(bool uses_web_transport, quiche::ReadStream* stream,
@@ -97,7 +110,8 @@
         uses_web_transport_(uses_web_transport) {}
   ~MoqtControlParser() = default;
 
-  void ReadAndDispatchMessages(bool stop_after_one_message = false);
+  void set_message_type(uint64_t message_type) { message_type_ = message_type; }
+  void ReadAndDispatchMessages();
 
  private:
   // The central switch statement to dispatch a message to the correct
diff --git a/quiche/quic/moqt/moqt_parser_test.cc b/quiche/quic/moqt/moqt_parser_test.cc
index 355b96c..f864cd9 100644
--- a/quiche/quic/moqt/moqt_parser_test.cc
+++ b/quiche/quic/moqt/moqt_parser_test.cc
@@ -1213,26 +1213,30 @@
   EXPECT_FALSE(visitor_.parsing_error_.has_value());
 }
 
-TEST_F(MoqtMessageSpecificTest, ReadOnlyOneMessage) {
-  char buffer[5000];
+TEST_F(MoqtMessageSpecificTest, ReadOnlyMessageType) {
   webtransport::test::InMemoryStream stream(/*stream_id=*/0);
-  MoqtControlParser parser(kRawQuic, &stream, visitor_);
-  size_t write = 0;
-  std::unique_ptr<TestMessageBase> first_message =
-      CreateTestMessage(MoqtMessageType::kRequestOk, kRawQuic);
-  std::unique_ptr<TestMessageBase> second_message =
-      CreateTestMessage(MoqtMessageType::kRequestError, kRawQuic);
-  memcpy(buffer, first_message->PacketSample().data(),
-         first_message->total_message_size());
-  write += first_message->total_message_size();
-  memcpy(buffer + write, second_message->PacketSample().data(),
-         second_message->total_message_size());
-  write += second_message->total_message_size();
-  stream.Receive(absl::string_view(buffer, write), false);
-  parser.ReadAndDispatchMessages(true);
-  EXPECT_EQ(visitor_.messages_received_, 1);
-  parser.ReadAndDispatchMessages(true);
-  EXPECT_EQ(visitor_.messages_received_, 2);
+  MoqtMessageTypeParser parser(&stream);
+  char buffer[] = {0x40, 0x03};
+  stream.Receive(absl::string_view(buffer, sizeof(buffer)), false);
+  EXPECT_TRUE(parser.ReadUntilMessageTypeKnown());
+  EXPECT_EQ(parser.message_type(), 0x03);
+}
+
+TEST_F(MoqtMessageSpecificTest, ReadOnlyMessageTypeIncomplete) {
+  webtransport::test::InMemoryStream stream(/*stream_id=*/0);
+  MoqtMessageTypeParser parser(&stream);
+  char buffer[] = {0x40};
+  stream.Receive(absl::string_view(buffer, sizeof(buffer)), false);
+  EXPECT_TRUE(parser.ReadUntilMessageTypeKnown());
+  EXPECT_FALSE(parser.message_type().has_value());
+}
+
+TEST_F(MoqtMessageSpecificTest, ReadOnlyMessageTypeEarlyFin) {
+  webtransport::test::InMemoryStream stream(/*stream_id=*/0);
+  MoqtMessageTypeParser parser(&stream);
+  char buffer[] = {0x03};
+  stream.Receive(absl::string_view(buffer, sizeof(buffer)), true);
+  EXPECT_FALSE(parser.ReadUntilMessageTypeKnown());
 }
 
 TEST_F(MoqtMessageSpecificTest, DatagramSuccessful) {
diff --git a/quiche/quic/moqt/moqt_session.cc b/quiche/quic/moqt/moqt_session.cc
index efad8b9..10798cb 100644
--- a/quiche/quic/moqt/moqt_session.cc
+++ b/quiche/quic/moqt/moqt_session.cc
@@ -893,7 +893,56 @@
   return true;
 }
 
-void MoqtSession::UnknownBidiStream::OnClientSetupMessage(
+void MoqtSession::UnknownBidiStream::OnCanRead() {
+  if (!parser_.ReadUntilMessageTypeKnown()) {
+    // Got an early FIN.
+    stream_->ResetWithUserCode(kResetCodeCanceled);
+    return;
+  }
+  if (!parser_.message_type().has_value()) {
+    return;
+  }
+  MoqtMessageType message_type =
+      static_cast<MoqtMessageType>(*parser_.message_type());
+  switch (message_type) {
+    case MoqtMessageType::kClientSetup: {
+      if (session_->control_stream_.GetIfAvailable() != nullptr) {
+        session_->Error(MoqtError::kProtocolViolation,
+                        "Multiple control streams");
+        return;
+      }
+      auto control_stream = std::make_unique<ControlStream>(session_);
+      // Store a reference to the stream context when the current context is
+      // destroyed below.
+      ControlStream* temp_stream = control_stream.get();
+      session_->control_stream_ = temp_stream->GetWeakPtr();
+      control_stream->set_stream(stream_);
+      // Deletes the UnknownBidiStream object; no class access after this
+      // point.
+      stream_->SetVisitor(std::move(control_stream));
+      temp_stream->OnCanRead();
+      break;
+    }
+    default:
+      session_->Error(MoqtError::kProtocolViolation,
+                      "Unexpected message type received to start bidi stream");
+      return;
+  }
+}
+
+void MoqtSession::ControlStream::set_stream(
+    webtransport::Stream* absl_nonnull stream) {
+  stream->SetPriority(
+      webtransport::StreamPriority{/*send_group_id=*/kMoqtSendGroupId,
+                                   /*send_order=*/kMoqtControlStreamSendOrder});
+  if (session_->perspective() == Perspective::IS_SERVER) {
+    MoqtBidiStreamBase::set_stream(stream, MoqtMessageType::kClientSetup);
+  } else {
+    MoqtBidiStreamBase::set_stream(stream, std::nullopt);
+  }
+}
+
+void MoqtSession::ControlStream::OnClientSetupMessage(
     const MoqtClientSetup& message) {
   if (session_->perspective() == Perspective::IS_CLIENT) {
     session_->Error(MoqtError::kProtocolViolation,
@@ -911,28 +960,7 @@
   SendOrBufferMessage(session_->framer_.SerializeServerSetup(response));
   QUICHE_DLOG(INFO) << "Sent SERVER_SETUP";
   // TODO: handle path.
-  if (session_->control_stream_.GetIfAvailable() != nullptr) {
-    session_->Error(MoqtError::kProtocolViolation, "Multiple control streams");
-    return;
-  }
-  auto control_stream = std::make_unique<ControlStream>(session_);
-  // Store a reference to the stream context when the current context is
-  // destroyed below.
-  ControlStream* temp_stream = control_stream.get();
-  session_->control_stream_ = temp_stream->GetWeakPtr();
-  control_stream->set_stream(stream());
   std::move(session_->callbacks_.session_established_callback)();
-  // Deletes the UnknownBidiStream object; no class access after this point.
-  stream()->SetVisitor(std::move(control_stream));
-  temp_stream->OnCanRead();
-}
-
-void MoqtSession::ControlStream::set_stream(
-    webtransport::Stream* absl_nonnull stream) {
-  stream->SetPriority(
-      webtransport::StreamPriority{/*send_group_id=*/kMoqtSendGroupId,
-                                   /*send_order=*/kMoqtControlStreamSendOrder});
-  MoqtBidiStreamBase::set_stream(stream);
 }
 
 void MoqtSession::ControlStream::OnServerSetupMessage(
diff --git a/quiche/quic/moqt/moqt_session.h b/quiche/quic/moqt/moqt_session.h
index 4c54623..867fae5 100644
--- a/quiche/quic/moqt/moqt_session.h
+++ b/quiche/quic/moqt/moqt_session.h
@@ -209,31 +209,26 @@
   };
 
   // A stream is open, but we don't know the type until we receive a message.
-  class QUICHE_EXPORT UnknownBidiStream : public MoqtBidiStreamBase {
+  class QUICHE_EXPORT UnknownBidiStream : public webtransport::StreamVisitor {
    public:
     // Constructor for a stream initiated by the remote peer. The caller is
     // responsible for calling stream->SetVisitor().
     UnknownBidiStream(MoqtSession* session,
                       webtransport::Stream* absl_nonnull stream)
-        : MoqtBidiStreamBase(
-              &session->framer_, []() { /* Do nothing when deleted. */ },
-              [session](MoqtError code, absl::string_view reason) {
-                session->Error(code, reason);
-              }),
-          session_(session) {
-      set_stream(stream);
-    }
+        : session_(session), stream_(stream), parser_(stream) {}
+    ~UnknownBidiStream() {}
 
     // webtransport::StreamVisitor overrides.
-    void OnCanRead() override {
-      parser()->ReadAndDispatchMessages(/*one_message=*/true);
-    }
-
-    // MoqtControlParserVisitor overrides.
-    void OnClientSetupMessage(const MoqtClientSetup& message) override;
+    void OnResetStreamReceived(webtransport::StreamErrorCode error) override {}
+    void OnStopSendingReceived(webtransport::StreamErrorCode error) override {}
+    void OnWriteSideInDataRecvdState() override {}
+    void OnCanRead() override;
+    void OnCanWrite() override {}
 
    private:
     MoqtSession* session_;
+    webtransport::Stream* stream_;
+    MoqtMessageTypeParser parser_;
   };
 
   class QUICHE_EXPORT ControlStream : public MoqtBidiStreamBase {
@@ -256,6 +251,7 @@
     void set_stream(webtransport::Stream* absl_nonnull stream) override;
 
     // MoqtControlParserVisitor implementation.
+    void OnClientSetupMessage(const MoqtClientSetup& message) override;
     void OnServerSetupMessage(const MoqtServerSetup& message) override;
     void OnRequestOkMessage(const MoqtRequestOk& message) override;
     void OnRequestErrorMessage(const MoqtRequestError& message) override;
diff --git a/quiche/quic/moqt/moqt_session_test.cc b/quiche/quic/moqt/moqt_session_test.cc
index 99caa52..2dae407 100644
--- a/quiche/quic/moqt/moqt_session_test.cc
+++ b/quiche/quic/moqt/moqt_session_test.cc
@@ -319,25 +319,26 @@
 }
 
 TEST_F(MoqtSessionTest, OnClientSetup) {
-  MoqtSession server_session(
-      &mock_session_, MoqtSessionParameters(quic::Perspective::IS_SERVER),
-      std::make_unique<quic::test::TestAlarmFactory>(),
-      session_callbacks_.AsSessionCallbacks());
-  std::unique_ptr<MoqtControlParserVisitor> unknown_stream =
-      MoqtSessionPeer::CreateUnknownBidiStream(&server_session, &mock_stream_);
+  MoqtSessionParameters session_parameters(quic::Perspective::IS_SERVER);
+  MoqtSession server_session(&mock_session_, session_parameters,
+                             std::make_unique<quic::test::TestAlarmFactory>(),
+                             session_callbacks_.AsSessionCallbacks());
+  // Load a CLIENT_SETUP message into an in-memory stream.
+  webtransport::test::InMemoryStream in_memory_stream(0);
+  MoqtFramer framer(session_parameters.using_webtrans);
   MoqtClientSetup setup;
-  MoqtSessionParameters parameters(quic::Perspective::IS_CLIENT);
-  parameters.ToSetupParameters(setup.parameters);
-  EXPECT_CALL(mock_stream_, CanWrite).WillOnce(Return(true));
-  EXPECT_CALL(mock_stream_,
-              Writev(ControlMessageOfType(MoqtMessageType::kServerSetup), _));
+  session_parameters.ToSetupParameters(setup.parameters);
+  quiche::QuicheBuffer buffer = framer.SerializeClientSetup(setup);
+  in_memory_stream.Receive(absl::string_view(buffer.data(), buffer.size()),
+                           /*fin=*/false);
+
+  EXPECT_CALL(mock_session_, AcceptIncomingBidirectionalStream())
+      .WillOnce(Return(&in_memory_stream))
+      .WillOnce(Return(nullptr));
   EXPECT_CALL(session_callbacks_.session_established_callback, Call());
-  std::unique_ptr<webtransport::StreamVisitor> visitor;
-  EXPECT_CALL(mock_stream_, SetVisitor)
-      .WillOnce([&](std::unique_ptr<webtransport::StreamVisitor> new_visitor) {
-        visitor = std::move(new_visitor);
-      });
-  unknown_stream->OnClientSetupMessage(setup);
+  server_session.OnIncomingBidirectionalStreamAvailable();
+  EXPECT_EQ(static_cast<uint8_t>(in_memory_stream.last_data_sent()[0]),
+            static_cast<uint8_t>(MoqtMessageType::kServerSetup));
   EXPECT_NE(MoqtSessionPeer::GetControlStream(&server_session), nullptr);
 }
 
diff --git a/quiche/quic/moqt/test_tools/moqt_session_peer.h b/quiche/quic/moqt/test_tools/moqt_session_peer.h
index fb8c32f..dfad53b 100644
--- a/quiche/quic/moqt/test_tools/moqt_session_peer.h
+++ b/quiche/quic/moqt/test_tools/moqt_session_peer.h
@@ -51,14 +51,6 @@
     return new_stream;
   }
 
-  static std::unique_ptr<MoqtControlParserVisitor> CreateUnknownBidiStream(
-      MoqtSession* session, webtransport::Stream* stream) {
-    auto new_stream =
-        std::make_unique<MoqtSession::UnknownBidiStream>(session, stream);
-    new_stream->set_stream(stream);
-    return new_stream;
-  }
-
   static std::unique_ptr<MoqtDataParserVisitor> CreateIncomingDataStream(
       MoqtSession* session, webtransport::Stream* stream,
       MoqtDataStreamType type) {
diff --git a/quiche/web_transport/test_tools/in_memory_stream.cc b/quiche/web_transport/test_tools/in_memory_stream.cc
index 3b28a60..8a144bb 100644
--- a/quiche/web_transport/test_tools/in_memory_stream.cc
+++ b/quiche/web_transport/test_tools/in_memory_stream.cc
@@ -8,10 +8,12 @@
 #include <string>
 #include <vector>
 
+#include "absl/status/status.h"
 #include "absl/strings/cord.h"
 #include "absl/strings/string_view.h"
 #include "absl/types/span.h"
 #include "quiche/common/platform/api/quiche_logging.h"
+#include "quiche/common/quiche_mem_slice.h"
 #include "quiche/common/quiche_stream.h"
 #include "quiche/common/vectorized_io_utils.h"
 
@@ -54,6 +56,16 @@
   return buffer_.empty() && fin_received_;
 }
 
+absl::Status InMemoryStream::Writev(absl::Span<quiche::QuicheMemSlice> data,
+                                    const quiche::StreamWriteOptions& options) {
+  if (!CanWrite()) {
+    return absl::PermissionDeniedError("Stream is not writable.");
+  }
+  last_data_sent_ = std::string(data[0].AsStringView());
+  fin_sent_ |= options.send_fin();
+  return absl::OkStatus();
+}
+
 void InMemoryStream::Receive(absl::string_view data, bool fin) {
   QUICHE_DCHECK(!abruptly_terminated_);
   buffer_.Append(data);
diff --git a/quiche/web_transport/test_tools/in_memory_stream.h b/quiche/web_transport/test_tools/in_memory_stream.h
index 50fa7a6..d4fbd35 100644
--- a/quiche/web_transport/test_tools/in_memory_stream.h
+++ b/quiche/web_transport/test_tools/in_memory_stream.h
@@ -36,11 +36,8 @@
 
   // quiche::WriteStream implementation.
   absl::Status Writev(absl::Span<quiche::QuicheMemSlice> data,
-                      const quiche::StreamWriteOptions& options) override {
-    QUICHE_NOTREACHED() << "Writev called on a read-only stream";
-    return absl::UnimplementedError("Writev called on a read-only stream");
-  }
-  bool CanWrite() const override { return false; }
+                      const quiche::StreamWriteOptions& options) override;
+  bool CanWrite() const override { return !fin_sent_; }
 
   void AbruptlyTerminate(absl::Status) override { Terminate(); }
 
@@ -76,6 +73,10 @@
     peek_one_byte_at_a_time_ = peek_one_byte_at_a_time;
   }
 
+  // Returns what was last written to the stream.
+  absl::string_view last_data_sent() const { return last_data_sent_; }
+  bool fin_sent() const { return fin_sent_; }
+
  private:
   void Terminate();
 
@@ -86,6 +87,8 @@
   bool fin_received_ = false;
   bool abruptly_terminated_ = false;
   bool peek_one_byte_at_a_time_ = false;
+  std::string last_data_sent_;
+  bool fin_sent_ = false;
 };
 
 }  // namespace webtransport::test
diff --git a/quiche/web_transport/test_tools/in_memory_stream_test.cc b/quiche/web_transport/test_tools/in_memory_stream_test.cc
index 0886b43..9c70b7a 100644
--- a/quiche/web_transport/test_tools/in_memory_stream_test.cc
+++ b/quiche/web_transport/test_tools/in_memory_stream_test.cc
@@ -4,13 +4,16 @@
 
 #include "quiche/web_transport/test_tools/in_memory_stream.h"
 
+#include <array>
 #include <string>
 
 #include "absl/strings/str_cat.h"
 #include "absl/strings/string_view.h"
 #include "absl/types/span.h"
 #include "quiche/common/platform/api/quiche_test.h"
+#include "quiche/common/quiche_mem_slice.h"
 #include "quiche/common/quiche_stream.h"
+#include "quiche/common/test_tools/quiche_test_utils.h"
 #include "quiche/web_transport/web_transport.h"
 
 namespace webtransport::test {
@@ -79,5 +82,23 @@
   EXPECT_TRUE(fin_reached);
 }
 
+TEST(InMemoryStreamTest, Write) {
+  InMemoryStream stream(0);
+  EXPECT_TRUE(stream.CanWrite());
+  std::array write_vector = {quiche::QuicheMemSlice::Copy("test")};
+  quiche::StreamWriteOptions options;
+  QUICHE_EXPECT_OK(stream.Writev(absl::MakeSpan(write_vector), options));
+  EXPECT_EQ(stream.last_data_sent(), "test");
+  EXPECT_FALSE(stream.fin_sent());
+  // Send FIN.
+  options.set_send_fin(true);
+  write_vector = {quiche::QuicheMemSlice::Copy("test2")};
+  QUICHE_EXPECT_OK(stream.Writev(absl::MakeSpan(write_vector), options));
+  EXPECT_EQ(stream.last_data_sent(), "test2");
+  EXPECT_TRUE(stream.fin_sent());
+  EXPECT_FALSE(stream.CanWrite());
+  EXPECT_FALSE(stream.Writev(absl::MakeSpan(write_vector), options).ok());
+}
+
 }  // namespace
 }  // namespace webtransport::test