Refactor HTTP/3 Datagram

This CL refactors our HTTP/3 Datagram APIs in order to prepare for the transition from draft-ietf-masque-h3-datagram-00 to draft-ietf-masque-h3-datagram-02. This CL switches the QuicSpdySession APIs to use the model and mindset from draft-02, but it still keeps draft-00 on the wire for now.

This CL has no server-side behavior changes. There is one behavior change on the client: the Datagram-Flow-Id now always uses a flow ID equal to the stream ID. That allows simplifying client code. The server code currently still uses whichever flow ID came in the Datagram-Flow-Id header. We'll keep it that way until we can remove support for draft-00.

The mismatch between draft-00 and draft-02 does mean that in some places we have variables called "stream_id" that contain the flow ID when draft-00 is in use, but it'll make it easier to reason about draft-02 which is the direction we're going in.

PiperOrigin-RevId: 383744375
diff --git a/quic/core/http/quic_spdy_session.cc b/quic/core/http/quic_spdy_session.cc
index 0ea2452..65287bd 100644
--- a/quic/core/http/quic_spdy_session.cc
+++ b/quic/core/http/quic_spdy_session.cc
@@ -461,14 +461,9 @@
 // Expected unidirectional static streams Requirement can be found at
 // https://tools.ietf.org/html/draft-ietf-quic-http-22#section-6.2.
 QuicSpdySession::QuicSpdySession(
-    QuicConnection* connection,
-    QuicSession::Visitor* visitor,
-    const QuicConfig& config,
-    const ParsedQuicVersionVector& supported_versions)
-    : QuicSession(connection,
-                  visitor,
-                  config,
-                  supported_versions,
+    QuicConnection* connection, QuicSession::Visitor* visitor,
+    const QuicConfig& config, const ParsedQuicVersionVector& supported_versions)
+    : QuicSession(connection, visitor, config, supported_versions,
                   /*num_expected_unidirectional_static_streams = */
                   VersionUsesHttp3(connection->transport_version())
                       ? static_cast<QuicStreamCount>(
@@ -495,10 +490,7 @@
       spdy_framer_(SpdyFramer::ENABLE_COMPRESSION),
       spdy_framer_visitor_(new SpdyFramerVisitor(this)),
       debug_visitor_(nullptr),
-      destruction_indicator_(123456789),
-      next_available_datagram_flow_id_(perspective() == Perspective::IS_SERVER
-                                           ? kFirstDatagramFlowIdServer
-                                           : kFirstDatagramFlowIdClient) {
+      destruction_indicator_(123456789) {
   h2_deframer_.set_visitor(spdy_framer_visitor_.get());
   h2_deframer_.set_debug_visitor(spdy_framer_visitor_.get());
   spdy_framer_.set_debug_visitor(spdy_framer_visitor_.get());
@@ -1644,25 +1636,33 @@
   }
 }
 
-QuicDatagramFlowId QuicSpdySession::GetNextDatagramFlowId() {
-  QuicDatagramFlowId result = next_available_datagram_flow_id_;
-  next_available_datagram_flow_id_ += kDatagramFlowIdIncrement;
-  return result;
-}
-
-MessageStatus QuicSpdySession::SendHttp3Datagram(QuicDatagramFlowId flow_id,
-                                                 absl::string_view payload) {
-  const size_t slice_length =
-      QuicDataWriter::GetVarInt62Len(flow_id) + payload.length();
+MessageStatus QuicSpdySession::SendHttp3Datagram(
+    QuicDatagramStreamId stream_id,
+    absl::optional<QuicDatagramContextId> context_id,
+    absl::string_view payload) {
+  size_t slice_length =
+      QuicDataWriter::GetVarInt62Len(stream_id) + payload.length();
+  if (context_id.has_value()) {
+    slice_length += QuicDataWriter::GetVarInt62Len(context_id.value());
+  }
   QuicBuffer buffer(connection()->helper()->GetStreamSendBufferAllocator(),
                     slice_length);
   QuicDataWriter writer(slice_length, buffer.data());
-  if (!writer.WriteVarInt62(flow_id)) {
-    QUIC_BUG(quic_bug_10360_10) << "Failed to write HTTP/3 datagram flow ID";
+  if (!writer.WriteVarInt62(stream_id)) {
+    QUIC_BUG(h3 datagram stream ID write fail)
+        << "Failed to write HTTP/3 datagram stream ID";
     return MESSAGE_STATUS_INTERNAL_ERROR;
   }
+  if (context_id.has_value()) {
+    if (!writer.WriteVarInt62(context_id.value())) {
+      QUIC_BUG(h3 datagram context ID write fail)
+          << "Failed to write HTTP/3 datagram context ID";
+      return MESSAGE_STATUS_INTERNAL_ERROR;
+    }
+  }
   if (!writer.WriteBytes(payload.data(), payload.length())) {
-    QUIC_BUG(quic_bug_10360_11) << "Failed to write HTTP/3 datagram payload";
+    QUIC_BUG(h3 datagram payload write fail)
+        << "Failed to write HTTP/3 datagram payload";
     return MESSAGE_STATUS_INTERNAL_ERROR;
   }
 
@@ -1670,29 +1670,23 @@
   return datagram_queue()->SendOrQueueDatagram(std::move(slice));
 }
 
-void QuicSpdySession::RegisterHttp3FlowId(
-    QuicDatagramFlowId flow_id,
-    QuicSpdySession::Http3DatagramVisitor* visitor) {
-  QUICHE_DCHECK_NE(visitor, nullptr);
-  auto insertion_result = h3_datagram_registrations_.insert({flow_id, visitor});
-  QUIC_BUG_IF(quic_bug_12477_7, !insertion_result.second)
-      << "Attempted to doubly register HTTP/3 flow ID " << flow_id;
-}
-
-void QuicSpdySession::UnregisterHttp3FlowId(QuicDatagramFlowId flow_id) {
-  size_t num_erased = h3_datagram_registrations_.erase(flow_id);
-  QUIC_BUG_IF(quic_bug_12477_8, num_erased != 1)
-      << "Attempted to unregister unknown HTTP/3 flow ID " << flow_id;
-}
-
-void QuicSpdySession::SetMaxTimeInQueueForFlowId(
-    QuicDatagramFlowId /*flow_id*/,
-    QuicTime::Delta max_time_in_queue) {
+void QuicSpdySession::SetMaxDatagramTimeInQueueForStreamId(
+    QuicStreamId /*stream_id*/, QuicTime::Delta max_time_in_queue) {
   // TODO(b/184598230): implement this in a way that works for multiple sessions
   // on a same connection.
   datagram_queue()->SetMaxTimeInQueue(max_time_in_queue);
 }
 
+void QuicSpdySession::RegisterHttp3DatagramFlowId(QuicDatagramStreamId flow_id,
+                                                  QuicStreamId stream_id) {
+  h3_datagram_flow_id_to_stream_id_map_[flow_id] = stream_id;
+}
+
+void QuicSpdySession::UnregisterHttp3DatagramFlowId(
+    QuicDatagramStreamId flow_id) {
+  h3_datagram_flow_id_to_stream_id_map_.erase(flow_id);
+}
+
 void QuicSpdySession::OnMessageReceived(absl::string_view message) {
   QuicSession::OnMessageReceived(message);
   if (!h3_datagram_supported_) {
@@ -1700,20 +1694,38 @@
     return;
   }
   QuicDataReader reader(message);
-  QuicDatagramFlowId flow_id;
-  if (!reader.ReadVarInt62(&flow_id)) {
-    QUIC_DLOG(ERROR) << "Failed to parse flow ID in received HTTP/3 datagram";
+  uint64_t stream_id64;
+  if (!reader.ReadVarInt62(&stream_id64)) {
+    QUIC_DLOG(ERROR) << "Failed to parse stream ID in received HTTP/3 datagram";
     return;
   }
-  auto it = h3_datagram_registrations_.find(flow_id);
-  if (it == h3_datagram_registrations_.end()) {
-    // TODO(dschinazi) buffer unknown HTTP/3 datagram flow IDs for a short
+  if (perspective() == Perspective::IS_SERVER) {
+    auto it = h3_datagram_flow_id_to_stream_id_map_.find(stream_id64);
+    if (it == h3_datagram_flow_id_to_stream_id_map_.end()) {
+      QUIC_DLOG(INFO) << "Received unknown HTTP/3 datagram flow ID "
+                      << stream_id64;
+      return;
+    }
+    stream_id64 = it->second;
+  }
+  if (stream_id64 > std::numeric_limits<QuicStreamId>::max()) {
+    // TODO(b/181256914) make this a connection close once we deprecate
+    // draft-ietf-masque-h3-datagram-00 in favor of later drafts.
+    QUIC_DLOG(ERROR) << "Received unexpectedly high HTTP/3 datagram stream ID "
+                     << stream_id64;
+    return;
+  }
+  QuicStreamId stream_id = static_cast<QuicStreamId>(stream_id64);
+  QuicSpdyStream* stream =
+      static_cast<QuicSpdyStream*>(GetActiveStream(stream_id));
+  if (stream == nullptr) {
+    QUIC_DLOG(INFO) << "Received HTTP/3 datagram for unknown stream ID "
+                    << stream_id;
+    // TODO(b/181256914) buffer unknown HTTP/3 datagram flow IDs for a short
     // period of time in case they were reordered.
-    QUIC_DLOG(ERROR) << "Received unknown HTTP/3 datagram flow ID " << flow_id;
     return;
   }
-  absl::string_view payload = reader.ReadRemainingPayload();
-  it->second->OnHttp3Datagram(flow_id, payload);
+  stream->OnDatagramReceived(&reader);
 }
 
 bool QuicSpdySession::SupportsWebTransport() {
diff --git a/quic/core/http/quic_spdy_session.h b/quic/core/http/quic_spdy_session.h
index 810297d..4df685d 100644
--- a/quic/core/http/quic_spdy_session.h
+++ b/quic/core/http/quic_spdy_session.h
@@ -391,41 +391,24 @@
   // extension.
   virtual void OnAcceptChFrameReceivedViaAlps(const AcceptChFrame& /*frame*/);
 
-  // Generates a new HTTP/3 datagram flow ID.
-  QuicDatagramFlowId GetNextDatagramFlowId();
-
   // Whether HTTP/3 datagrams are supported on this session, based on received
   // SETTINGS.
   bool h3_datagram_supported() const { return h3_datagram_supported_; }
 
-  // Sends an HTTP/3 datagram. The flow ID is not part of |payload|.
-  MessageStatus SendHttp3Datagram(QuicDatagramFlowId flow_id,
-                                  absl::string_view payload);
-
-  class QUIC_EXPORT_PRIVATE Http3DatagramVisitor {
-   public:
-    virtual ~Http3DatagramVisitor() {}
-
-    // Called when an HTTP/3 datagram is received. |payload| does not contain
-    // the flow ID.
-    virtual void OnHttp3Datagram(QuicDatagramFlowId flow_id,
-                                 absl::string_view payload) = 0;
-  };
-
-  // Registers |visitor| to receive HTTP/3 datagrams for flow ID |flow_id|. This
-  // must not be called on a previously register flow ID without first calling
-  // UnregisterHttp3FlowId. |visitor| must be valid until a corresponding call
-  // to UnregisterHttp3FlowId. The flow ID must be unregistered before the
-  // QuicSpdySession is destroyed.
-  void RegisterHttp3FlowId(QuicDatagramFlowId flow_id,
-                           Http3DatagramVisitor* visitor);
-
-  // Unregister a given HTTP/3 datagram flow ID.
-  void UnregisterHttp3FlowId(QuicDatagramFlowId flow_id);
-
-  // Sets max time in queue for a specified datagram flow ID.
-  void SetMaxTimeInQueueForFlowId(QuicDatagramFlowId flow_id,
-                                  QuicTime::Delta max_time_in_queue);
+  // This must not be used except by QuicSpdyStream::SendHttp3Datagram.
+  MessageStatus SendHttp3Datagram(
+      QuicDatagramStreamId stream_id,
+      absl::optional<QuicDatagramContextId> context_id,
+      absl::string_view payload);
+  // This must not be used except by QuicSpdyStream::SetMaxDatagramTimeInQueue.
+  void SetMaxDatagramTimeInQueueForStreamId(QuicStreamId stream_id,
+                                            QuicTime::Delta max_time_in_queue);
+  // This must not be used except by
+  // QuicSpdyStream::MaybeProcessReceivedWebTransportHeaders.
+  void RegisterHttp3DatagramFlowId(QuicDatagramStreamId flow_id,
+                                   QuicStreamId stream_id);
+  // This must not be used except by QuicSpdyStream::OnClose.
+  void UnregisterHttp3DatagramFlowId(QuicDatagramStreamId flow_id);
 
   // Override from QuicSession to support HTTP/3 datagrams.
   void OnMessageReceived(absl::string_view message) override;
@@ -688,18 +671,17 @@
   // frame has been sent yet.
   absl::optional<uint64_t> last_sent_http3_goaway_id_;
 
-  // Value of the smallest unused HTTP/3 datagram flow ID that this endpoint's
-  // datagram flow ID allocation service will use next.
-  QuicDatagramFlowId next_available_datagram_flow_id_;
-
   // Whether both this endpoint and our peer support HTTP/3 datagrams.
   bool h3_datagram_supported_ = false;
 
   // Whether the peer has indicated WebTransport support.
   bool peer_supports_webtransport_ = false;
 
-  absl::flat_hash_map<QuicDatagramFlowId, Http3DatagramVisitor*>
-      h3_datagram_registrations_;
+  // This maps from draft-ietf-masque-h3-datagram-00 flow IDs to stream IDs.
+  // TODO(b/181256914) remove this when we deprecate support for that draft in
+  // favor of more recent ones.
+  absl::flat_hash_map<uint64_t, QuicStreamId>
+      h3_datagram_flow_id_to_stream_id_map_;
 
   // Whether any settings have been received, either from the peer or from a
   // session ticket.
diff --git a/quic/core/http/quic_spdy_session_test.cc b/quic/core/http/quic_spdy_session_test.cc
index 9221937..7821be7 100644
--- a/quic/core/http/quic_spdy_session_test.cc
+++ b/quic/core/http/quic_spdy_session_test.cc
@@ -573,8 +573,7 @@
     headers.OnHeaderBlockStart();
     headers.OnHeader(":method", "CONNECT");
     headers.OnHeader(":protocol", "webtransport");
-    headers.OnHeader("datagram-flow-id",
-                     absl::StrCat(session_.GetNextDatagramFlowId()));
+    headers.OnHeader("datagram-flow-id", absl::StrCat(session_id));
     stream->OnStreamHeaderList(/*fin=*/true, 0, headers);
     WebTransportHttp3* web_transport =
         session_.GetWebTransportSession(session_id);
@@ -3470,26 +3469,6 @@
   EXPECT_EQ("multiple SETTINGS frames", error.value());
 }
 
-TEST_P(QuicSpdySessionTestClient, GetNextDatagramFlowId) {
-  if (!version().UsesHttp3()) {
-    return;
-  }
-  EXPECT_EQ(session_.GetNextDatagramFlowId(), 0u);
-  EXPECT_EQ(session_.GetNextDatagramFlowId(), 2u);
-  EXPECT_EQ(session_.GetNextDatagramFlowId(), 4u);
-  EXPECT_EQ(session_.GetNextDatagramFlowId(), 6u);
-}
-
-TEST_P(QuicSpdySessionTestServer, GetNextDatagramFlowId) {
-  if (!version().UsesHttp3()) {
-    return;
-  }
-  EXPECT_EQ(session_.GetNextDatagramFlowId(), 1u);
-  EXPECT_EQ(session_.GetNextDatagramFlowId(), 3u);
-  EXPECT_EQ(session_.GetNextDatagramFlowId(), 5u);
-  EXPECT_EQ(session_.GetNextDatagramFlowId(), 7u);
-}
-
 TEST_P(QuicSpdySessionTestClient, H3DatagramSetting) {
   if (!version().UsesHttp3()) {
     return;
@@ -3513,47 +3492,6 @@
   EXPECT_TRUE(session_.h3_datagram_supported());
 }
 
-TEST_P(QuicSpdySessionTestClient, H3DatagramRegistration) {
-  if (!version().UsesHttp3()) {
-    return;
-  }
-  CompleteHandshake();
-  session_.set_should_negotiate_h3_datagram(true);
-  QuicSpdySessionPeer::SetH3DatagramSupported(&session_, true);
-  SavingHttp3DatagramVisitor h3_datagram_visitor;
-  QuicDatagramFlowId flow_id = session_.GetNextDatagramFlowId();
-  ASSERT_EQ(QuicDataWriter::GetVarInt62Len(flow_id), 1);
-  uint8_t datagram[256];
-  datagram[0] = flow_id;
-  for (size_t i = 1; i < ABSL_ARRAYSIZE(datagram); i++) {
-    datagram[i] = i;
-  }
-  session_.RegisterHttp3FlowId(flow_id, &h3_datagram_visitor);
-  session_.OnMessageReceived(absl::string_view(
-      reinterpret_cast<const char*>(datagram), sizeof(datagram)));
-  EXPECT_THAT(
-      h3_datagram_visitor.received_h3_datagrams(),
-      ElementsAre(SavingHttp3DatagramVisitor::SavedHttp3Datagram{
-          flow_id, std::string(reinterpret_cast<const char*>(datagram + 1),
-                               sizeof(datagram) - 1)}));
-  session_.UnregisterHttp3FlowId(flow_id);
-}
-
-TEST_P(QuicSpdySessionTestClient, SendHttp3Datagram) {
-  if (!version().UsesHttp3()) {
-    return;
-  }
-  CompleteHandshake();
-  session_.set_should_negotiate_h3_datagram(true);
-  QuicSpdySessionPeer::SetH3DatagramSupported(&session_, true);
-  QuicDatagramFlowId flow_id = session_.GetNextDatagramFlowId();
-  std::string h3_datagram_payload = {1, 2, 3, 4, 5, 6};
-  EXPECT_CALL(*connection_, SendMessage(1, _, false))
-      .WillOnce(Return(MESSAGE_STATUS_SUCCESS));
-  EXPECT_EQ(session_.SendHttp3Datagram(flow_id, h3_datagram_payload),
-            MESSAGE_STATUS_SUCCESS);
-}
-
 TEST_P(QuicSpdySessionTestClient, WebTransportSetting) {
   if (!version().UsesHttp3()) {
     return;
diff --git a/quic/core/http/quic_spdy_stream.cc b/quic/core/http/quic_spdy_stream.cc
index e797dee..d2ec4a1 100644
--- a/quic/core/http/quic_spdy_stream.cc
+++ b/quic/core/http/quic_spdy_stream.cc
@@ -213,8 +213,7 @@
 }
 }  // namespace
 
-QuicSpdyStream::QuicSpdyStream(QuicStreamId id,
-                               QuicSpdySession* spdy_session,
+QuicSpdyStream::QuicSpdyStream(QuicStreamId id, QuicSpdySession* spdy_session,
                                StreamType type)
     : QuicStream(id, spdy_session, /*is_static=*/false, type),
       spdy_session_(spdy_session),
@@ -232,7 +231,11 @@
       sequencer_offset_(0),
       is_decoder_processing_input_(false),
       ack_listener_(nullptr),
-      last_sent_urgency_(kDefaultUrgency) {
+      last_sent_urgency_(kDefaultUrgency),
+      datagram_next_available_context_id_(spdy_session->perspective() ==
+                                                  Perspective::IS_SERVER
+                                              ? kFirstDatagramContextIdServer
+                                              : kFirstDatagramContextIdClient) {
   QUICHE_DCHECK_EQ(session()->connection(), spdy_session->connection());
   QUICHE_DCHECK_EQ(transport_version(), spdy_session->transport_version());
   QUICHE_DCHECK(!QuicUtils::IsCryptoStreamId(transport_version(), id));
@@ -886,6 +889,10 @@
     visitor->OnClose(this);
   }
 
+  if (datagram_flow_id_.has_value()) {
+    spdy_session_->UnregisterHttp3DatagramFlowId(datagram_flow_id_.value());
+  }
+
   if (web_transport_ != nullptr) {
     web_transport_->CloseAllAssociatedStreams();
   }
@@ -1271,7 +1278,7 @@
 
   std::string method;
   std::string protocol;
-  absl::optional<QuicDatagramFlowId> flow_id;
+  absl::optional<QuicDatagramStreamId> flow_id;
   for (const auto& header : header_list_) {
     const std::string& header_name = header.first;
     const std::string& header_value = header.second;
@@ -1291,7 +1298,7 @@
       if (flow_id.has_value() || header_value.empty()) {
         return;
       }
-      QuicDatagramFlowId flow_id_out;
+      QuicDatagramStreamId flow_id_out;
       if (!absl::SimpleAtoi(header_value, &flow_id_out)) {
         return;
       }
@@ -1304,8 +1311,18 @@
     return;
   }
 
+  RegisterHttp3DatagramFlowId(*flow_id);
+
   web_transport_ =
-      std::make_unique<WebTransportHttp3>(spdy_session_, this, id(), *flow_id);
+      std::make_unique<WebTransportHttp3>(spdy_session_, this, id());
+
+  // If we're in draft-ietf-masque-h3-datagram-00 mode, pretend we also received
+  // a REGISTER_DATAGRAM_NO_CONTEXT capsule with no extensions.
+  // TODO(b/181256914) remove this when we remove support for
+  // draft-ietf-masque-h3-datagram-00 in favor of later drafts.
+  RegisterHttp3DatagramContextId(/*context_id=*/absl::nullopt,
+                                 Http3DatagramContextExtensions(),
+                                 web_transport_.get());
 }
 
 void QuicSpdyStream::MaybeProcessSentWebTransportHeaders(
@@ -1327,11 +1344,14 @@
     return;
   }
 
-  QuicDatagramFlowId flow_id = spdy_session_->GetNextDatagramFlowId();
-  headers["datagram-flow-id"] = absl::StrCat(flow_id);
+  QuicDatagramStreamId stream_id = id();
+  headers["datagram-flow-id"] = absl::StrCat(stream_id);
 
   web_transport_ =
-      std::make_unique<WebTransportHttp3>(spdy_session_, this, id(), flow_id);
+      std::make_unique<WebTransportHttp3>(spdy_session_, this, id());
+  RegisterHttp3DatagramContextId(web_transport_->context_id(),
+                                 Http3DatagramContextExtensions(),
+                                 web_transport_.get());
 }
 
 void QuicSpdyStream::OnCanWriteNewData() {
@@ -1392,5 +1412,206 @@
     : session_id(session_id),
       adapter(stream->spdy_session_, stream, stream->sequencer()) {}
 
+MessageStatus QuicSpdyStream::SendHttp3Datagram(
+    absl::optional<QuicDatagramContextId> context_id,
+    absl::string_view payload) {
+  QuicDatagramStreamId stream_id =
+      datagram_flow_id_.has_value() ? datagram_flow_id_.value() : id();
+  return spdy_session_->SendHttp3Datagram(stream_id, context_id, payload);
+}
+
+void QuicSpdyStream::RegisterHttp3DatagramRegistrationVisitor(
+    Http3DatagramRegistrationVisitor* visitor) {
+  if (visitor == nullptr) {
+    QUIC_BUG(null datagram registration visitor)
+        << ENDPOINT << "Null datagram registration visitor for" << id();
+    return;
+  }
+  QUIC_DLOG(INFO) << ENDPOINT << "Registering datagram stream ID " << id();
+  datagram_registration_visitor_ = visitor;
+}
+
+void QuicSpdyStream::UnregisterHttp3DatagramRegistrationVisitor() {
+  QUIC_BUG_IF(h3 datagram unregister unknown stream ID,
+              datagram_registration_visitor_ == nullptr)
+      << ENDPOINT
+      << "Attempted to unregister unknown HTTP/3 datagram stream ID " << id();
+  QUIC_DLOG(INFO) << ENDPOINT << "Unregistering datagram stream ID " << id();
+  datagram_registration_visitor_ = nullptr;
+}
+
+void QuicSpdyStream::MoveHttp3DatagramRegistration(
+    Http3DatagramRegistrationVisitor* visitor) {
+  QUIC_BUG_IF(h3 datagram move unknown stream ID,
+              datagram_registration_visitor_ == nullptr)
+      << ENDPOINT << "Attempted to move unknown HTTP/3 datagram stream ID "
+      << id();
+  QUIC_DLOG(INFO) << ENDPOINT << "Moving datagram stream ID " << id();
+  datagram_registration_visitor_ = visitor;
+}
+
+void QuicSpdyStream::RegisterHttp3DatagramContextId(
+    absl::optional<QuicDatagramContextId> context_id,
+    const Http3DatagramContextExtensions& /*extensions*/,
+    Http3DatagramVisitor* visitor) {
+  if (visitor == nullptr) {
+    QUIC_BUG(null datagram visitor)
+        << ENDPOINT << "Null datagram visitor for stream ID " << id()
+        << " context ID " << (context_id.has_value() ? context_id.value() : 0);
+    return;
+  }
+  if (datagram_registration_visitor_ == nullptr) {
+    QUIC_BUG(context registration without registration visitor)
+        << ENDPOINT << "Cannot register context ID "
+        << (context_id.has_value() ? context_id.value() : 0)
+        << " without registration visitor for stream ID " << id();
+    return;
+  }
+  QUIC_DLOG(INFO) << ENDPOINT << "Registering datagram context ID "
+                  << (context_id.has_value() ? context_id.value() : 0)
+                  << " with stream ID " << id();
+  if (context_id.has_value()) {
+    if (datagram_no_context_visitor_ != nullptr) {
+      QUIC_BUG(h3 datagram context ID mix1)
+          << ENDPOINT
+          << "Attempted to mix registrations without and with context IDs "
+             "for stream ID "
+          << id();
+      return;
+    }
+    auto insertion_result =
+        datagram_context_visitors_.insert({context_id.value(), visitor});
+    QUIC_BUG_IF(h3 datagram double context registration,
+                !insertion_result.second)
+        << ENDPOINT << "Attempted to doubly register HTTP/3 stream ID " << id()
+        << " context ID " << context_id.value();
+    return;
+  }
+  // Registration without a context ID.
+  if (!datagram_context_visitors_.empty()) {
+    QUIC_BUG(h3 datagram context ID mix2)
+        << ENDPOINT
+        << "Attempted to mix registrations with and without context IDs "
+           "for stream ID "
+        << id();
+    return;
+  }
+  if (datagram_no_context_visitor_ != nullptr) {
+    QUIC_BUG(h3 datagram double no context registration)
+        << ENDPOINT << "Attempted to doubly register HTTP/3 stream ID " << id()
+        << " with no context ID";
+    return;
+  }
+  datagram_no_context_visitor_ = visitor;
+}
+
+void QuicSpdyStream::UnregisterHttp3DatagramContextId(
+    absl::optional<QuicDatagramContextId> context_id) {
+  if (datagram_registration_visitor_ == nullptr) {
+    QUIC_BUG(context unregistration without registration visitor)
+        << ENDPOINT << "Cannot unregister context ID "
+        << (context_id.has_value() ? context_id.value() : 0)
+        << " without registration visitor for stream ID " << id();
+    return;
+  }
+  QUIC_DLOG(INFO) << ENDPOINT << "Unregistering datagram context ID "
+                  << (context_id.has_value() ? context_id.value() : 0)
+                  << " with stream ID " << id();
+  if (context_id.has_value()) {
+    size_t num_erased = datagram_context_visitors_.erase(context_id.value());
+    QUIC_BUG_IF(h3 datagram unregister unknown context, num_erased != 1)
+        << "Attempted to unregister unknown HTTP/3 context ID "
+        << context_id.value() << " on stream ID " << id();
+    return;
+  }
+  // Unregistration without a context ID.
+  QUIC_BUG_IF(h3 datagram unknown context unregistration,
+              datagram_no_context_visitor_ == nullptr)
+      << "Attempted to unregister unknown no context on HTTP/3 stream ID "
+      << id();
+  datagram_no_context_visitor_ = nullptr;
+}
+
+void QuicSpdyStream::MoveHttp3DatagramContextIdRegistration(
+    absl::optional<QuicDatagramContextId> context_id,
+    Http3DatagramVisitor* visitor) {
+  if (datagram_registration_visitor_ == nullptr) {
+    QUIC_BUG(context move without registration visitor)
+        << ENDPOINT << "Cannot move context ID "
+        << (context_id.has_value() ? context_id.value() : 0)
+        << " without registration visitor for stream ID " << id();
+    return;
+  }
+  QUIC_DLOG(INFO) << ENDPOINT << "Moving datagram context ID "
+                  << (context_id.has_value() ? context_id.value() : 0)
+                  << " with stream ID " << id();
+  if (context_id.has_value()) {
+    QUIC_BUG_IF(h3 datagram move unknown context,
+                !datagram_context_visitors_.contains(context_id.value()))
+        << ENDPOINT << "Attempted to move unknown context ID "
+        << context_id.value() << " on stream ID " << id();
+    datagram_context_visitors_[context_id.value()] = visitor;
+    return;
+  }
+  // Move without a context ID.
+  QUIC_BUG_IF(h3 datagram unknown context move,
+              datagram_no_context_visitor_ == nullptr)
+      << "Attempted to move unknown no context on HTTP/3 stream ID " << id();
+  datagram_no_context_visitor_ = visitor;
+}
+
+void QuicSpdyStream::SetMaxDatagramTimeInQueue(
+    QuicTime::Delta max_time_in_queue) {
+  spdy_session_->SetMaxDatagramTimeInQueueForStreamId(id(), max_time_in_queue);
+}
+
+QuicDatagramContextId QuicSpdyStream::GetNextDatagramContextId() {
+  QuicDatagramContextId result = datagram_next_available_context_id_;
+  datagram_next_available_context_id_ += kDatagramContextIdIncrement;
+  return result;
+}
+
+void QuicSpdyStream::OnDatagramReceived(QuicDataReader* reader) {
+  absl::optional<QuicDatagramContextId> context_id;
+  const bool context_id_present = !datagram_context_visitors_.empty();
+  Http3DatagramVisitor* visitor;
+  if (context_id_present) {
+    QuicDatagramContextId parsed_context_id;
+    if (!reader->ReadVarInt62(&parsed_context_id)) {
+      QUIC_DLOG(ERROR) << "Failed to parse context ID in received HTTP/3 "
+                          "datagram on stream ID "
+                       << id();
+      return;
+    }
+    context_id = parsed_context_id;
+    auto it = datagram_context_visitors_.find(parsed_context_id);
+    if (it == datagram_context_visitors_.end()) {
+      // TODO(b/181256914) buffer unknown HTTP/3 datagrams for a short
+      // period of time in case they were reordered.
+      QUIC_DLOG(ERROR) << "Received unknown HTTP/3 datagram context ID "
+                       << parsed_context_id << " on stream ID " << id();
+      return;
+    }
+    visitor = it->second;
+  } else {
+    if (datagram_no_context_visitor_ == nullptr) {
+      // TODO(b/181256914) buffer unknown HTTP/3 datagrams for a short
+      // period of time in case they were reordered.
+      QUIC_DLOG(ERROR)
+          << "Received HTTP/3 datagram without any registrations on stream ID "
+          << id();
+      return;
+    }
+    visitor = datagram_no_context_visitor_;
+  }
+  absl::string_view payload = reader->ReadRemainingPayload();
+  visitor->OnHttp3Datagram(id(), context_id, payload);
+}
+
+void QuicSpdyStream::RegisterHttp3DatagramFlowId(QuicDatagramStreamId flow_id) {
+  datagram_flow_id_ = flow_id;
+  spdy_session_->RegisterHttp3DatagramFlowId(datagram_flow_id_.value(), id());
+}
+
 #undef ENDPOINT  // undef for jumbo builds
 }  // namespace quic
diff --git a/quic/core/http/quic_spdy_stream.h b/quic/core/http/quic_spdy_stream.h
index 74ae070..8a219d1 100644
--- a/quic/core/http/quic_spdy_stream.h
+++ b/quic/core/http/quic_spdy_stream.h
@@ -46,6 +46,8 @@
 class QuicSpdySession;
 class WebTransportHttp3;
 
+class QUIC_EXPORT_PRIVATE Http3DatagramContextExtensions {};
+
 // A QUIC stream that can send and receive HTTP2 (SPDY) headers.
 class QUIC_EXPORT_PRIVATE QuicSpdyStream
     : public QuicStream,
@@ -251,6 +253,94 @@
   // rejected due to buffer being full.  |write_size| must be non-zero.
   bool CanWriteNewBodyData(QuicByteCount write_size) const;
 
+  // Sends an HTTP/3 datagram. The stream and context IDs are not part of
+  // |payload|.
+  MessageStatus SendHttp3Datagram(
+      absl::optional<QuicDatagramContextId> context_id,
+      absl::string_view payload);
+
+  class QUIC_EXPORT_PRIVATE Http3DatagramVisitor {
+   public:
+    virtual ~Http3DatagramVisitor() {}
+
+    // Called when an HTTP/3 datagram is received. |payload| does not contain
+    // the stream or context IDs. Note that this contains the stream ID even if
+    // flow IDs from draft-ietf-masque-h3-datagram-00 are in use.
+    virtual void OnHttp3Datagram(
+        QuicStreamId stream_id,
+        absl::optional<QuicDatagramContextId> context_id,
+        absl::string_view payload) = 0;
+  };
+
+  class QUIC_EXPORT_PRIVATE Http3DatagramRegistrationVisitor {
+   public:
+    virtual ~Http3DatagramRegistrationVisitor() {}
+
+    // Called when a REGISTER_DATAGRAM_CONTEXT or REGISTER_DATAGRAM_NO_CONTEXT
+    // capsule is received. Note that this contains the stream ID even if flow
+    // IDs from draft-ietf-masque-h3-datagram-00 are in use.
+    virtual void OnContextReceived(
+        QuicStreamId stream_id,
+        absl::optional<QuicDatagramContextId> context_id,
+        const Http3DatagramContextExtensions& extensions) = 0;
+
+    // Called when a CLOSE_DATAGRAM_CONTEXT capsule is received. Note that this
+    // contains the stream ID even if flow IDs from
+    // draft-ietf-masque-h3-datagram-00 are in use.
+    virtual void OnContextClosed(
+        QuicStreamId stream_id,
+        absl::optional<QuicDatagramContextId> context_id,
+        const Http3DatagramContextExtensions& extensions) = 0;
+  };
+
+  // Registers |visitor| to receive HTTP/3 datagram context registrations. This
+  // must not be called without first calling
+  // UnregisterHttp3DatagramRegistrationVisitor. |visitor| must be valid until a
+  // corresponding call to UnregisterHttp3DatagramRegistrationVisitor.
+  void RegisterHttp3DatagramRegistrationVisitor(
+      Http3DatagramRegistrationVisitor* visitor);
+
+  // Unregisters for HTTP/3 datagram context registrations. Must not be called
+  // unless previously registered.
+  void UnregisterHttp3DatagramRegistrationVisitor();
+
+  // Moves an HTTP/3 datagram registration to a different visitor. Mainly meant
+  // to be used by the visitors' move operators.
+  void MoveHttp3DatagramRegistration(Http3DatagramRegistrationVisitor* visitor);
+
+  // Registers |visitor| to receive HTTP/3 datagrams for optional context ID
+  // |context_id|. This must not be called on a previously registered context ID
+  // without first calling UnregisterHttp3DatagramContextId. |visitor| must be
+  // valid until a corresponding call to UnregisterHttp3DatagramContextId. If
+  // this method is called multiple times, the context ID MUST either be always
+  // present, or always absent.
+  void RegisterHttp3DatagramContextId(
+      absl::optional<QuicDatagramContextId> context_id,
+      const Http3DatagramContextExtensions& extensions,
+      Http3DatagramVisitor* visitor);
+
+  // Unregisters an HTTP/3 datagram context ID. Must be called on a previously
+  // registered context.
+  void UnregisterHttp3DatagramContextId(
+      absl::optional<QuicDatagramContextId> context_id);
+
+  // Moves an HTTP/3 datagram context ID to a different visitor. Mainly meant
+  // to be used by the visitors' move operators.
+  void MoveHttp3DatagramContextIdRegistration(
+      absl::optional<QuicDatagramContextId> context_id,
+      Http3DatagramVisitor* visitor);
+
+  // Sets max datagram time in queue.
+  void SetMaxDatagramTimeInQueue(QuicTime::Delta max_time_in_queue);
+
+  // Generates a new HTTP/3 datagram context ID for this stream. A datagram
+  // registration visitor must be currently registered on this stream.
+  QuicDatagramContextId GetNextDatagramContextId();
+
+  void OnDatagramReceived(QuicDataReader* reader);
+
+  void RegisterHttp3DatagramFlowId(QuicDatagramStreamId flow_id);
+
  protected:
   // Called when the received headers are too large. By default this will
   // reset the stream.
@@ -394,6 +484,14 @@
   // If this stream is a WebTransport data stream, |web_transport_data_|
   // contains all of the associated metadata.
   std::unique_ptr<WebTransportDataStream> web_transport_data_;
+
+  // HTTP/3 Datagram support.
+  Http3DatagramRegistrationVisitor* datagram_registration_visitor_ = nullptr;
+  Http3DatagramVisitor* datagram_no_context_visitor_ = nullptr;
+  absl::optional<QuicDatagramStreamId> datagram_flow_id_;
+  QuicDatagramContextId datagram_next_available_context_id_;
+  absl::flat_hash_map<QuicDatagramContextId, Http3DatagramVisitor*>
+      datagram_context_visitors_;
 };
 
 }  // namespace quic
diff --git a/quic/core/http/quic_spdy_stream_test.cc b/quic/core/http/quic_spdy_stream_test.cc
index 7648277..038e0f6 100644
--- a/quic/core/http/quic_spdy_stream_test.cc
+++ b/quic/core/http/quic_spdy_stream_test.cc
@@ -3126,7 +3126,7 @@
   spdy::SpdyHeaderBlock headers;
   headers[":method"] = "CONNECT";
   headers[":protocol"] = "webtransport";
-  headers["datagram-flow-id"] = absl::StrCat(session_->GetNextDatagramFlowId());
+  headers["datagram-flow-id"] = absl::StrCat(stream_->id());
   stream_->WriteHeaders(std::move(headers), /*fin=*/false, nullptr);
   ASSERT_TRUE(stream_->web_transport() != nullptr);
   EXPECT_EQ(stream_->id(), stream_->web_transport()->id());
@@ -3144,8 +3144,7 @@
 
   headers_[":method"] = "CONNECT";
   headers_[":protocol"] = "webtransport";
-  headers_["datagram-flow-id"] =
-      absl::StrCat(session_->GetNextDatagramFlowId());
+  headers_["datagram-flow-id"] = absl::StrCat(stream_->id());
 
   stream_->OnStreamHeadersPriority(
       spdy::SpdyStreamPrecedence(kV3HighestPriority));
@@ -3157,6 +3156,155 @@
   EXPECT_EQ(stream_->id(), stream_->web_transport()->id());
 }
 
+TEST_P(QuicSpdyStreamTest,
+       ProcessIncomingWebTransportHeadersWithMismatchedFlowId) {
+  if (!UsesHttp3()) {
+    return;
+  }
+  // TODO(b/181256914) Remove this test when we deprecate
+  // draft-ietf-masque-h3-datagram-00 in favor of later drafts.
+
+  Initialize(kShouldProcessData);
+  session_->set_should_negotiate_h3_datagram(true);
+  session_->EnableWebTransport();
+  QuicSpdySessionPeer::EnableWebTransport(*session_);
+
+  headers_[":method"] = "CONNECT";
+  headers_[":protocol"] = "webtransport";
+  headers_["datagram-flow-id"] = "2";
+
+  stream_->OnStreamHeadersPriority(
+      spdy::SpdyStreamPrecedence(kV3HighestPriority));
+  ProcessHeaders(false, headers_);
+  EXPECT_EQ("", stream_->data());
+  EXPECT_FALSE(stream_->header_list().empty());
+  EXPECT_FALSE(stream_->IsDoneReading());
+  ASSERT_TRUE(stream_->web_transport() != nullptr);
+  EXPECT_EQ(stream_->id(), stream_->web_transport()->id());
+}
+
+TEST_P(QuicSpdyStreamTest, GetNextDatagramContextIdClient) {
+  if (!UsesHttp3()) {
+    return;
+  }
+  InitializeWithPerspective(kShouldProcessData, Perspective::IS_CLIENT);
+  ::testing::NiceMock<MockHttp3DatagramRegistrationVisitor> visitor;
+  stream_->RegisterHttp3DatagramRegistrationVisitor(&visitor);
+  EXPECT_EQ(stream_->GetNextDatagramContextId(), 0u);
+  EXPECT_EQ(stream_->GetNextDatagramContextId(), 2u);
+  EXPECT_EQ(stream_->GetNextDatagramContextId(), 4u);
+  EXPECT_EQ(stream_->GetNextDatagramContextId(), 6u);
+  stream_->UnregisterHttp3DatagramRegistrationVisitor();
+}
+
+TEST_P(QuicSpdyStreamTest, GetNextDatagramContextIdServer) {
+  if (!UsesHttp3()) {
+    return;
+  }
+  InitializeWithPerspective(kShouldProcessData, Perspective::IS_SERVER);
+  ::testing::NiceMock<MockHttp3DatagramRegistrationVisitor> visitor;
+  stream_->RegisterHttp3DatagramRegistrationVisitor(&visitor);
+  EXPECT_EQ(stream_->GetNextDatagramContextId(), 1u);
+  EXPECT_EQ(stream_->GetNextDatagramContextId(), 3u);
+  EXPECT_EQ(stream_->GetNextDatagramContextId(), 5u);
+  EXPECT_EQ(stream_->GetNextDatagramContextId(), 7u);
+  stream_->UnregisterHttp3DatagramRegistrationVisitor();
+}
+
+TEST_P(QuicSpdyStreamTest, H3DatagramRegistrationWithoutContext) {
+  if (!UsesHttp3()) {
+    return;
+  }
+  Initialize(kShouldProcessData);
+  session_->set_should_negotiate_h3_datagram(true);
+  QuicSpdySessionPeer::SetH3DatagramSupported(session_.get(), true);
+  session_->RegisterHttp3DatagramFlowId(stream_->id(), stream_->id());
+  ::testing::NiceMock<MockHttp3DatagramRegistrationVisitor>
+      h3_datagram_registration_visitor;
+  SavingHttp3DatagramVisitor h3_datagram_visitor;
+  absl::optional<QuicDatagramContextId> context_id;
+  Http3DatagramContextExtensions extensions;
+  ASSERT_EQ(QuicDataWriter::GetVarInt62Len(stream_->id()), 1);
+  std::array<char, 256> datagram;
+  datagram[0] = stream_->id();
+  for (size_t i = 1; i < datagram.size(); i++) {
+    datagram[i] = i;
+  }
+  stream_->RegisterHttp3DatagramRegistrationVisitor(
+      &h3_datagram_registration_visitor);
+  stream_->RegisterHttp3DatagramContextId(context_id, extensions,
+                                          &h3_datagram_visitor);
+  session_->OnMessageReceived(
+      absl::string_view(datagram.data(), datagram.size()));
+  EXPECT_THAT(h3_datagram_visitor.received_h3_datagrams(),
+              ElementsAre(SavingHttp3DatagramVisitor::SavedHttp3Datagram{
+                  stream_->id(), context_id,
+                  std::string(&datagram[1], datagram.size() - 1)}));
+  // Test move.
+  ::testing::NiceMock<MockHttp3DatagramRegistrationVisitor>
+      h3_datagram_registration_visitor2;
+  stream_->MoveHttp3DatagramRegistration(&h3_datagram_registration_visitor2);
+  SavingHttp3DatagramVisitor h3_datagram_visitor2;
+  stream_->MoveHttp3DatagramContextIdRegistration(context_id,
+                                                  &h3_datagram_visitor2);
+  EXPECT_TRUE(h3_datagram_visitor2.received_h3_datagrams().empty());
+  session_->OnMessageReceived(
+      absl::string_view(datagram.data(), datagram.size()));
+  EXPECT_THAT(h3_datagram_visitor2.received_h3_datagrams(),
+              ElementsAre(SavingHttp3DatagramVisitor::SavedHttp3Datagram{
+                  stream_->id(), context_id,
+                  std::string(&datagram[1], datagram.size() - 1)}));
+  // Cleanup.
+  stream_->UnregisterHttp3DatagramContextId(context_id);
+  stream_->UnregisterHttp3DatagramRegistrationVisitor();
+  session_->UnregisterHttp3DatagramFlowId(stream_->id());
+}
+
+TEST_P(QuicSpdyStreamTest, H3DatagramRegistrationWithContext) {
+  if (!UsesHttp3()) {
+    return;
+  }
+  Initialize(kShouldProcessData);
+  session_->set_should_negotiate_h3_datagram(true);
+  QuicSpdySessionPeer::SetH3DatagramSupported(session_.get(), true);
+  session_->RegisterHttp3DatagramFlowId(stream_->id(), stream_->id());
+  ::testing::NiceMock<MockHttp3DatagramRegistrationVisitor>
+      h3_datagram_registration_visitor;
+  SavingHttp3DatagramVisitor h3_datagram_visitor;
+  absl::optional<QuicDatagramContextId> context_id = 42;
+  Http3DatagramContextExtensions extensions;
+  stream_->RegisterHttp3DatagramRegistrationVisitor(
+      &h3_datagram_registration_visitor);
+  stream_->RegisterHttp3DatagramContextId(context_id, extensions,
+                                          &h3_datagram_visitor);
+  // Test move.
+  ::testing::NiceMock<MockHttp3DatagramRegistrationVisitor>
+      h3_datagram_registration_visitor2;
+  stream_->MoveHttp3DatagramRegistration(&h3_datagram_registration_visitor2);
+  SavingHttp3DatagramVisitor h3_datagram_visitor2;
+  stream_->MoveHttp3DatagramContextIdRegistration(context_id,
+                                                  &h3_datagram_visitor2);
+  // Cleanup.
+  stream_->UnregisterHttp3DatagramContextId(context_id);
+  stream_->UnregisterHttp3DatagramRegistrationVisitor();
+  session_->UnregisterHttp3DatagramFlowId(stream_->id());
+}
+
+TEST_P(QuicSpdyStreamTest, SendHttp3Datagram) {
+  if (!UsesHttp3()) {
+    return;
+  }
+  Initialize(kShouldProcessData);
+  session_->set_should_negotiate_h3_datagram(true);
+  QuicSpdySessionPeer::SetH3DatagramSupported(session_.get(), true);
+  absl::optional<QuicDatagramContextId> context_id;
+  std::string h3_datagram_payload = {1, 2, 3, 4, 5, 6};
+  EXPECT_CALL(*connection_, SendMessage(1, _, false))
+      .WillOnce(Return(MESSAGE_STATUS_SUCCESS));
+  EXPECT_EQ(stream_->SendHttp3Datagram(context_id, h3_datagram_payload),
+            MESSAGE_STATUS_SUCCESS);
+}
+
 }  // namespace
 }  // namespace test
 }  // namespace quic
diff --git a/quic/core/http/spdy_utils.cc b/quic/core/http/spdy_utils.cc
index fd6fdc6..27f5314 100644
--- a/quic/core/http/spdy_utils.cc
+++ b/quic/core/http/spdy_utils.cc
@@ -154,7 +154,7 @@
 }
 
 // static
-absl::optional<QuicDatagramFlowId> SpdyUtils::ParseDatagramFlowIdHeader(
+absl::optional<QuicDatagramStreamId> SpdyUtils::ParseDatagramFlowIdHeader(
     const spdy::SpdyHeaderBlock& headers) {
   auto flow_id_pair = headers.find("datagram-flow-id");
   if (flow_id_pair == headers.end()) {
@@ -162,7 +162,7 @@
   }
   std::vector<absl::string_view> flow_id_strings =
       absl::StrSplit(flow_id_pair->second, ',');
-  absl::optional<QuicDatagramFlowId> first_named_flow_id;
+  absl::optional<QuicDatagramStreamId> first_named_flow_id;
   for (absl::string_view flow_id_string : flow_id_strings) {
     std::vector<absl::string_view> flow_id_components =
         absl::StrSplit(flow_id_string, ';');
@@ -172,7 +172,7 @@
     absl::string_view flow_id_value_string = flow_id_components[0];
     quiche::QuicheTextUtils::RemoveLeadingAndTrailingWhitespace(
         &flow_id_value_string);
-    QuicDatagramFlowId flow_id;
+    QuicDatagramStreamId flow_id;
     if (!absl::SimpleAtoi(flow_id_value_string, &flow_id)) {
       continue;
     }
@@ -190,7 +190,7 @@
 
 // static
 void SpdyUtils::AddDatagramFlowIdHeader(spdy::SpdyHeaderBlock* headers,
-                                        QuicDatagramFlowId flow_id) {
+                                        QuicDatagramStreamId flow_id) {
   (*headers)["datagram-flow-id"] = absl::StrCat(flow_id);
 }
 
diff --git a/quic/core/http/spdy_utils.h b/quic/core/http/spdy_utils.h
index e3a8a97..01fae4f 100644
--- a/quic/core/http/spdy_utils.h
+++ b/quic/core/http/spdy_utils.h
@@ -55,12 +55,12 @@
 
   // Parses the "datagram-flow-id" header, returns the flow ID on success, or
   // returns absl::nullopt if the header was not present or failed to parse.
-  static absl::optional<QuicDatagramFlowId> ParseDatagramFlowIdHeader(
+  static absl::optional<QuicDatagramStreamId> ParseDatagramFlowIdHeader(
       const spdy::SpdyHeaderBlock& headers);
 
   // Adds the "datagram-flow-id" header.
   static void AddDatagramFlowIdHeader(spdy::SpdyHeaderBlock* headers,
-                                      QuicDatagramFlowId flow_id);
+                                      QuicDatagramStreamId flow_id);
 };
 
 }  // namespace quic
diff --git a/quic/core/http/spdy_utils_test.cc b/quic/core/http/spdy_utils_test.cc
index 7307665..918739d 100644
--- a/quic/core/http/spdy_utils_test.cc
+++ b/quic/core/http/spdy_utils_test.cc
@@ -34,7 +34,7 @@
 
 static void ValidateDatagramFlowId(
     const std::string& header_value,
-    absl::optional<QuicDatagramFlowId> expected_flow_id) {
+    absl::optional<QuicDatagramStreamId> expected_flow_id) {
   SpdyHeaderBlock headers;
   headers["datagram-flow-id"] = header_value;
   ASSERT_EQ(SpdyUtils::ParseDatagramFlowIdHeader(headers), expected_flow_id);
@@ -391,7 +391,7 @@
   SpdyHeaderBlock headers;
   EXPECT_EQ(SpdyUtils::ParseDatagramFlowIdHeader(headers), absl::nullopt);
   // Add header and verify it parses.
-  QuicDatagramFlowId flow_id = 123;
+  QuicDatagramStreamId flow_id = 123;
   SpdyUtils::AddDatagramFlowIdHeader(&headers, flow_id);
   EXPECT_EQ(SpdyUtils::ParseDatagramFlowIdHeader(headers), flow_id);
   // Test empty header.
diff --git a/quic/core/http/web_transport_http3.cc b/quic/core/http/web_transport_http3.cc
index 2a105c2..820950b 100644
--- a/quic/core/http/web_transport_http3.cc
+++ b/quic/core/http/web_transport_http3.cc
@@ -36,17 +36,19 @@
 
 WebTransportHttp3::WebTransportHttp3(QuicSpdySession* session,
                                      QuicSpdyStream* connect_stream,
-                                     WebTransportSessionId id,
-                                     QuicDatagramFlowId flow_id)
+                                     WebTransportSessionId id)
     : session_(session),
       connect_stream_(connect_stream),
       id_(id),
-      flow_id_(flow_id),
       visitor_(std::make_unique<NoopWebTransportVisitor>()) {
   QUICHE_DCHECK(session_->SupportsWebTransport());
   QUICHE_DCHECK(IsValidWebTransportSessionId(id, session_->version()));
   QUICHE_DCHECK_EQ(connect_stream_->id(), id);
-  session_->RegisterHttp3FlowId(flow_id, this);
+  connect_stream_->RegisterHttp3DatagramRegistrationVisitor(this);
+  if (session_->perspective() == Perspective::IS_CLIENT) {
+    context_is_known_ = true;
+    context_currently_registered_ = true;
+  }
 }
 
 void WebTransportHttp3::AssociateStream(QuicStreamId stream_id) {
@@ -74,7 +76,11 @@
   for (QuicStreamId id : streams) {
     session_->ResetStream(id, QUIC_STREAM_WEBTRANSPORT_SESSION_GONE);
   }
-  session_->UnregisterHttp3FlowId(flow_id_);
+  if (context_currently_registered_) {
+    context_currently_registered_ = false;
+    connect_stream_->UnregisterHttp3DatagramContextId(context_id_);
+  }
+  connect_stream_->UnregisterHttp3DatagramRegistrationVisitor();
 }
 
 void WebTransportHttp3::HeadersReceived(const spdy::SpdyHeaderBlock& headers) {
@@ -153,21 +159,81 @@
 }
 
 MessageStatus WebTransportHttp3::SendOrQueueDatagram(QuicMemSlice datagram) {
-  return session_->SendHttp3Datagram(
-      flow_id_, absl::string_view(datagram.data(), datagram.length()));
+  return connect_stream_->SendHttp3Datagram(
+      context_id_, absl::string_view(datagram.data(), datagram.length()));
 }
 
 void WebTransportHttp3::SetDatagramMaxTimeInQueue(
     QuicTime::Delta max_time_in_queue) {
-  session_->SetMaxTimeInQueueForFlowId(flow_id_, max_time_in_queue);
+  connect_stream_->SetMaxDatagramTimeInQueue(max_time_in_queue);
 }
 
-void WebTransportHttp3::OnHttp3Datagram(QuicDatagramFlowId flow_id,
-                                        absl::string_view payload) {
-  QUICHE_DCHECK_EQ(flow_id, flow_id_);
+void WebTransportHttp3::OnHttp3Datagram(
+    QuicStreamId stream_id, absl::optional<QuicDatagramContextId> context_id,
+    absl::string_view payload) {
+  QUICHE_DCHECK_EQ(stream_id, connect_stream_->id());
+  QUICHE_DCHECK(context_id == context_id_);
   visitor_->OnDatagramReceived(payload);
 }
 
+void WebTransportHttp3::OnContextReceived(
+    QuicStreamId stream_id, absl::optional<QuicDatagramContextId> context_id,
+    const Http3DatagramContextExtensions& /*extensions*/) {
+  if (stream_id != connect_stream_->id()) {
+    QUIC_BUG(WT3 bad datagram context registration)
+        << ENDPOINT << "Registered stream ID " << stream_id << ", expected "
+        << connect_stream_->id();
+    return;
+  }
+  if (!context_is_known_) {
+    context_is_known_ = true;
+    context_id_ = context_id;
+  }
+  if (context_id != context_id_) {
+    QUIC_DLOG(INFO) << ENDPOINT << "Ignoring unexpected context ID "
+                    << (context_id.has_value() ? context_id.value() : 0)
+                    << " instead of "
+                    << (context_id_.has_value() ? context_id_.value() : 0)
+                    << " on stream ID " << connect_stream_->id();
+    return;
+  }
+  if (session_->perspective() == Perspective::IS_SERVER) {
+    if (context_currently_registered_) {
+      QUIC_DLOG(ERROR) << ENDPOINT << "Received duplicate context ID "
+                       << (context_id_.has_value() ? context_id_.value() : 0)
+                       << " on stream ID " << connect_stream_->id();
+      session_->ResetStream(connect_stream_->id(), QUIC_STREAM_CANCELLED);
+      return;
+    }
+    context_currently_registered_ = true;
+    Http3DatagramContextExtensions reply_extensions;
+    connect_stream_->RegisterHttp3DatagramContextId(context_id_,
+                                                    reply_extensions, this);
+  }
+}
+
+void WebTransportHttp3::OnContextClosed(
+    QuicStreamId stream_id, absl::optional<QuicDatagramContextId> context_id,
+    const Http3DatagramContextExtensions& /*extensions*/) {
+  if (stream_id != connect_stream_->id()) {
+    QUIC_BUG(WT3 bad datagram context registration)
+        << ENDPOINT << "Closed context on stream ID " << stream_id
+        << ", expected " << connect_stream_->id();
+    return;
+  }
+  if (context_id != context_id_) {
+    QUIC_DLOG(INFO) << ENDPOINT << "Ignoring unexpected close of context ID "
+                    << (context_id.has_value() ? context_id.value() : 0)
+                    << " instead of "
+                    << (context_id_.has_value() ? context_id_.value() : 0)
+                    << " on stream ID " << connect_stream_->id();
+    return;
+  }
+  QUIC_DLOG(INFO) << ENDPOINT << "Received datagram context close on stream ID "
+                  << connect_stream_->id() << ", resetting stream";
+  session_->ResetStream(connect_stream_->id(), QUIC_STREAM_CANCELLED);
+}
+
 WebTransportHttp3UnidirectionalStream::WebTransportHttp3UnidirectionalStream(
     PendingStream* pending,
     QuicSpdySession* session)
diff --git a/quic/core/http/web_transport_http3.h b/quic/core/http/web_transport_http3.h
index 667256e..df5c4f6 100644
--- a/quic/core/http/web_transport_http3.h
+++ b/quic/core/http/web_transport_http3.h
@@ -28,12 +28,11 @@
 // <https://datatracker.ietf.org/doc/html/draft-ietf-webtrans-http3>
 class QUIC_EXPORT_PRIVATE WebTransportHttp3
     : public WebTransportSession,
-      public QuicSpdySession::Http3DatagramVisitor {
+      public QuicSpdyStream::Http3DatagramRegistrationVisitor,
+      public QuicSpdyStream::Http3DatagramVisitor {
  public:
-  WebTransportHttp3(QuicSpdySession* session,
-                    QuicSpdyStream* connect_stream,
-                    WebTransportSessionId id,
-                    QuicDatagramFlowId flow_id);
+  WebTransportHttp3(QuicSpdySession* session, QuicSpdyStream* connect_stream,
+                    WebTransportSessionId id);
 
   void HeadersReceived(const spdy::SpdyHeaderBlock& headers);
   void SetVisitor(std::unique_ptr<WebTransportVisitor> visitor) {
@@ -42,6 +41,9 @@
 
   WebTransportSessionId id() { return id_; }
   bool ready() { return ready_; }
+  absl::optional<QuicDatagramContextId> context_id() const {
+    return context_id_;
+  }
 
   void AssociateStream(QuicStreamId stream_id);
   void OnStreamClosed(QuicStreamId stream_id) { streams_.erase(stream_id); }
@@ -63,16 +65,32 @@
   MessageStatus SendOrQueueDatagram(QuicMemSlice datagram) override;
   void SetDatagramMaxTimeInQueue(QuicTime::Delta max_time_in_queue) override;
 
-  void OnHttp3Datagram(QuicDatagramFlowId flow_id,
+  // From QuicSpdyStream::Http3DatagramVisitor.
+  void OnHttp3Datagram(QuicStreamId stream_id,
+                       absl::optional<QuicDatagramContextId> context_id,
                        absl::string_view payload) override;
 
+  // From QuicSpdyStream::Http3DatagramRegistrationVisitor.
+  void OnContextReceived(
+      QuicStreamId stream_id, absl::optional<QuicDatagramContextId> context_id,
+      const Http3DatagramContextExtensions& extensions) override;
+  void OnContextClosed(
+      QuicStreamId stream_id, absl::optional<QuicDatagramContextId> context_id,
+      const Http3DatagramContextExtensions& extensions) override;
+
  private:
   QuicSpdySession* const session_;        // Unowned.
   QuicSpdyStream* const connect_stream_;  // Unowned.
   const WebTransportSessionId id_;
-  const QuicDatagramFlowId flow_id_;
+  absl::optional<QuicDatagramContextId> context_id_;
   // |ready_| is set to true when the peer has seen both sets of headers.
   bool ready_ = false;
+  // Whether we know which |context_id_| to use. On the client this is always
+  // true, and on the server it becomes true when we receive a context
+  // registeration capsule.
+  bool context_is_known_ = false;
+  // Whether |context_id_| is currently registered with |connect_stream_|.
+  bool context_currently_registered_ = false;
   std::unique_ptr<WebTransportVisitor> visitor_;
   absl::flat_hash_set<QuicStreamId> streams_;
   quiche::QuicheCircularDeque<QuicStreamId> incoming_bidirectional_streams_;
diff --git a/quic/core/quic_constants.h b/quic/core/quic_constants.h
index b750abb..162de89 100644
--- a/quic/core/quic_constants.h
+++ b/quic/core/quic_constants.h
@@ -295,10 +295,10 @@
 QUIC_EXPORT_PRIVATE extern const char* const kEPIDGoogleFrontEnd0;
 
 // HTTP/3 Datagrams.
-enum : QuicDatagramFlowId {
-  kFirstDatagramFlowIdClient = 0,
-  kFirstDatagramFlowIdServer = 1,
-  kDatagramFlowIdIncrement = 2,
+enum : QuicDatagramContextId {
+  kFirstDatagramContextIdClient = 0,
+  kFirstDatagramContextIdServer = 1,
+  kDatagramContextIdIncrement = 2,
 };
 
 }  // namespace quic
diff --git a/quic/core/quic_types.h b/quic/core/quic_types.h
index dc1fc69..bbe1a58 100644
--- a/quic/core/quic_types.h
+++ b/quic/core/quic_types.h
@@ -25,7 +25,13 @@
 using QuicPacketLength = uint16_t;
 using QuicControlFrameId = uint32_t;
 using QuicMessageId = uint32_t;
-using QuicDatagramFlowId = uint64_t;
+
+// TODO(b/181256914) replace QuicDatagramStreamId with QuicStreamId once we
+// remove support for draft-ietf-masque-h3-datagram-00 in favor of later drafts.
+using QuicDatagramStreamId = uint64_t;
+using QuicDatagramContextId = uint64_t;
+// Note that for draft-ietf-masque-h3-datagram-00, we represent the flow ID as a
+// QuicDatagramStreamId.
 
 // IMPORTANT: IETF QUIC defines stream IDs and stream counts as being unsigned
 // 62-bit numbers. However, we have decided to only support up to 2^32-1 streams
diff --git a/quic/masque/masque_client_session.cc b/quic/masque/masque_client_session.cc
index 31edff1..0d86569 100644
--- a/quic/masque/masque_client_session.cc
+++ b/quic/masque/masque_client_session.cc
@@ -92,11 +92,8 @@
     return nullptr;
   }
 
-  QuicDatagramFlowId flow_id = GetNextDatagramFlowId();
-
   QUIC_DLOG(INFO) << "Sending CONNECT-UDP request for " << target_server_address
-                  << " using flow ID " << flow_id << " on stream "
-                  << stream->id();
+                  << " on stream " << stream->id();
 
   // Send the request.
   spdy::Http2HeaderBlock headers;
@@ -104,7 +101,7 @@
   headers[":scheme"] = "masque";
   headers[":path"] = "/";
   headers[":authority"] = target_server_address.ToString();
-  SpdyUtils::AddDatagramFlowIdHeader(&headers, flow_id);
+  SpdyUtils::AddDatagramFlowIdHeader(&headers, stream->id());
   size_t bytes_sent =
       stream->SendRequest(std::move(headers), /*body=*/"", /*fin=*/false);
   if (bytes_sent == 0) {
@@ -112,9 +109,10 @@
     return nullptr;
   }
 
+  absl::optional<QuicDatagramContextId> context_id;
   connect_udp_client_states_.push_back(
-      ConnectUdpClientState(stream, encapsulated_client_session, this, flow_id,
-                            target_server_address));
+      ConnectUdpClientState(stream, encapsulated_client_session, this,
+                            context_id, target_server_address));
   return &connect_udp_client_states_.back();
 }
 
@@ -137,12 +135,15 @@
     return;
   }
 
-  QuicDatagramFlowId flow_id = connect_udp->flow_id();
-  MessageStatus message_status =
-      SendHttp3Datagram(connect_udp->flow_id(), packet);
+  MessageStatus message_status = SendHttp3Datagram(
+      connect_udp->stream()->id(), connect_udp->context_id(), packet);
 
   QUIC_DVLOG(1) << "Sent packet to " << target_server_address
-                << " compressed with flow ID " << flow_id
+                << " compressed with stream ID " << connect_udp->stream()->id()
+                << " context ID "
+                << (connect_udp->context_id().has_value()
+                        ? connect_udp->context_id().value()
+                        : 0)
                 << " and got message status "
                 << MessageStatusToString(message_status);
 }
@@ -178,7 +179,11 @@
   for (auto it = connect_udp_client_states_.begin();
        it != connect_udp_client_states_.end();) {
     if (it->encapsulated_client_session() == encapsulated_client_session) {
-      QUIC_DLOG(INFO) << "Removing state for flow_id " << it->flow_id();
+      QUIC_DLOG(INFO) << "Removing state for stream ID " << it->stream()->id()
+                      << " context ID "
+                      << (it->context_id().has_value()
+                              ? it->context_id().value()
+                              : 0);
       auto* stream = it->stream();
       it = connect_udp_client_states_.erase(it);
       if (!stream->write_side_closed()) {
@@ -217,8 +222,10 @@
        it != connect_udp_client_states_.end();) {
     if (it->stream()->id() == stream_id) {
       QUIC_DLOG(INFO) << "Stream " << stream_id
-                      << " was closed, removing state for flow_id "
-                      << it->flow_id();
+                      << " was closed, removing state for context ID "
+                      << (it->context_id().has_value()
+                              ? it->context_id().value()
+                              : 0);
       auto* encapsulated_client_session = it->encapsulated_client_session();
       it = connect_udp_client_states_.erase(it);
       encapsulated_client_session->CloseConnection(
@@ -250,20 +257,24 @@
     QuicSpdyClientStream* stream,
     EncapsulatedClientSession* encapsulated_client_session,
     MasqueClientSession* masque_session,
-    QuicDatagramFlowId flow_id,
+    absl::optional<QuicDatagramContextId> context_id,
     const QuicSocketAddress& target_server_address)
     : stream_(stream),
       encapsulated_client_session_(encapsulated_client_session),
       masque_session_(masque_session),
-      flow_id_(flow_id),
+      context_id_(context_id),
       target_server_address_(target_server_address) {
   QUICHE_DCHECK_NE(masque_session_, nullptr);
-  masque_session_->RegisterHttp3FlowId(this->flow_id(), this);
+  this->stream()->RegisterHttp3DatagramRegistrationVisitor(this);
+  Http3DatagramContextExtensions extensions;
+  this->stream()->RegisterHttp3DatagramContextId(this->context_id(), extensions,
+                                                 this);
 }
 
 MasqueClientSession::ConnectUdpClientState::~ConnectUdpClientState() {
-  if (flow_id_.has_value()) {
-    masque_session_->UnregisterHttp3FlowId(flow_id());
+  if (stream() != nullptr) {
+    stream()->UnregisterHttp3DatagramContextId(context_id());
+    stream()->UnregisterHttp3DatagramRegistrationVisitor();
   }
 }
 
@@ -278,23 +289,69 @@
   stream_ = other.stream_;
   encapsulated_client_session_ = other.encapsulated_client_session_;
   masque_session_ = other.masque_session_;
-  flow_id_ = other.flow_id_;
+  context_id_ = other.context_id_;
   target_server_address_ = other.target_server_address_;
-  other.flow_id_.reset();
-  if (flow_id_.has_value()) {
-    masque_session_->UnregisterHttp3FlowId(flow_id());
-    masque_session_->RegisterHttp3FlowId(flow_id(), this);
+  other.stream_ = nullptr;
+  if (stream() != nullptr) {
+    stream()->MoveHttp3DatagramRegistration(this);
+    stream()->MoveHttp3DatagramContextIdRegistration(context_id(), this);
   }
   return *this;
 }
 
 void MasqueClientSession::ConnectUdpClientState::OnHttp3Datagram(
-    QuicDatagramFlowId flow_id,
+    QuicStreamId stream_id, absl::optional<QuicDatagramContextId> context_id,
     absl::string_view payload) {
-  QUICHE_DCHECK_EQ(flow_id, this->flow_id());
+  QUICHE_DCHECK_EQ(stream_id, stream()->id());
+  QUICHE_DCHECK(context_id == context_id_);
   encapsulated_client_session_->ProcessPacket(payload, target_server_address_);
   QUIC_DVLOG(1) << "Sent " << payload.size()
-                << " bytes to connection for flow_id " << flow_id;
+                << " bytes to connection for stream ID " << stream_id
+                << " context ID "
+                << (context_id.has_value() ? context_id.value() : 0);
+}
+
+void MasqueClientSession::ConnectUdpClientState::OnContextReceived(
+    QuicStreamId stream_id, absl::optional<QuicDatagramContextId> context_id,
+    const Http3DatagramContextExtensions& /*extensions*/) {
+  if (stream_id != stream_->id()) {
+    QUIC_BUG(MASQUE client bad datagram context registration)
+        << "Registered stream ID " << stream_id << ", expected "
+        << stream_->id();
+    return;
+  }
+  if (context_id != context_id_) {
+    QUIC_DLOG(INFO) << "Ignoring unexpected context ID "
+                    << (context_id.has_value() ? context_id.value() : 0)
+                    << " instead of "
+                    << (context_id_.has_value() ? context_id_.value() : 0)
+                    << " on stream ID " << stream_->id();
+    return;
+  }
+  // Do nothing since the client registers first and we currently ignore
+  // extensions.
+}
+
+void MasqueClientSession::ConnectUdpClientState::OnContextClosed(
+    QuicStreamId stream_id, absl::optional<QuicDatagramContextId> context_id,
+    const Http3DatagramContextExtensions& /*extensions*/) {
+  if (stream_id != stream_->id()) {
+    QUIC_BUG(MASQUE client bad datagram context registration)
+        << "Closed context on stream ID " << stream_id << ", expected "
+        << stream_->id();
+    return;
+  }
+  if (context_id != context_id_) {
+    QUIC_DLOG(INFO) << "Ignoring unexpected close of context ID "
+                    << (context_id.has_value() ? context_id.value() : 0)
+                    << " instead of "
+                    << (context_id_.has_value() ? context_id_.value() : 0)
+                    << " on stream ID " << stream_->id();
+    return;
+  }
+  QUIC_DLOG(INFO) << "Received datagram context close on stream ID "
+                  << stream_->id() << ", closing stream";
+  masque_session_->ResetStream(stream_->id(), QUIC_STREAM_CANCELLED);
 }
 
 }  // namespace quic
diff --git a/quic/masque/masque_client_session.h b/quic/masque/masque_client_session.h
index 0b2a3dd..1c13087 100644
--- a/quic/masque/masque_client_session.h
+++ b/quic/masque/masque_client_session.h
@@ -107,7 +107,8 @@
  private:
   // State that the MasqueClientSession keeps for each CONNECT-UDP request.
   class QUIC_NO_EXPORT ConnectUdpClientState
-      : public QuicSpdySession::Http3DatagramVisitor {
+      : public QuicSpdyStream::Http3DatagramRegistrationVisitor,
+        public QuicSpdyStream::Http3DatagramVisitor {
    public:
     // |stream| and |encapsulated_client_session| must be valid for the lifetime
     // of the ConnectUdpClientState.
@@ -115,7 +116,7 @@
         QuicSpdyClientStream* stream,
         EncapsulatedClientSession* encapsulated_client_session,
         MasqueClientSession* masque_session,
-        QuicDatagramFlowId flow_id,
+        absl::optional<QuicDatagramContextId> context_id,
         const QuicSocketAddress& target_server_address);
 
     ~ConnectUdpClientState();
@@ -130,23 +131,33 @@
     EncapsulatedClientSession* encapsulated_client_session() const {
       return encapsulated_client_session_;
     }
-    QuicDatagramFlowId flow_id() const {
-      QUICHE_DCHECK(flow_id_.has_value());
-      return *flow_id_;
+    absl::optional<QuicDatagramContextId> context_id() const {
+      return context_id_;
     }
     const QuicSocketAddress& target_server_address() const {
       return target_server_address_;
     }
 
-    // From QuicSpdySession::Http3DatagramVisitor.
-    void OnHttp3Datagram(QuicDatagramFlowId flow_id,
+    // From QuicSpdyStream::Http3DatagramVisitor.
+    void OnHttp3Datagram(QuicStreamId stream_id,
+                         absl::optional<QuicDatagramContextId> context_id,
                          absl::string_view payload) override;
 
+    // From QuicSpdyStream::Http3DatagramRegistrationVisitor.
+    void OnContextReceived(
+        QuicStreamId stream_id,
+        absl::optional<QuicDatagramContextId> context_id,
+        const Http3DatagramContextExtensions& extensions) override;
+    void OnContextClosed(
+        QuicStreamId stream_id,
+        absl::optional<QuicDatagramContextId> context_id,
+        const Http3DatagramContextExtensions& extensions) override;
+
    private:
     QuicSpdyClientStream* stream_;                            // Unowned.
     EncapsulatedClientSession* encapsulated_client_session_;  // Unowned.
     MasqueClientSession* masque_session_;                     // Unowned.
-    absl::optional<QuicDatagramFlowId> flow_id_;
+    absl::optional<QuicDatagramContextId> context_id_;
     QuicSocketAddress target_server_address_;
   };
 
diff --git a/quic/masque/masque_compression_engine.cc b/quic/masque/masque_compression_engine.cc
index e383c79..e4471dd 100644
--- a/quic/masque/masque_compression_engine.cc
+++ b/quic/masque/masque_compression_engine.cc
@@ -21,7 +21,7 @@
 
 namespace {
 // |kFlowId0| is used to indicate creation of a new compression context.
-const QuicDatagramFlowId kFlowId0 = 0;
+const QuicDatagramStreamId kFlowId0 = 0;
 
 enum MasqueAddressFamily : uint8_t {
   MasqueAddressFamilyIPv4 = 4,
@@ -32,16 +32,16 @@
 
 MasqueCompressionEngine::MasqueCompressionEngine(
     QuicSpdySession* masque_session)
-    : masque_session_(masque_session) {}
+    : masque_session_(masque_session),
+      next_available_flow_id_(
+          masque_session_->perspective() == Perspective::IS_CLIENT ? 0 : 1) {}
 
-QuicDatagramFlowId MasqueCompressionEngine::FindOrCreateCompressionContext(
+QuicDatagramStreamId MasqueCompressionEngine::FindOrCreateCompressionContext(
     QuicConnectionId client_connection_id,
     QuicConnectionId server_connection_id,
-    const QuicSocketAddress& server_address,
-    bool client_connection_id_present,
-    bool server_connection_id_present,
-    bool* validated) {
-  QuicDatagramFlowId flow_id = kFlowId0;
+    const QuicSocketAddress& server_address, bool client_connection_id_present,
+    bool server_connection_id_present, bool* validated) {
+  QuicDatagramStreamId flow_id = kFlowId0;
   *validated = false;
   for (const auto& kv : contexts_) {
     const MasqueCompressionContext& context = kv.second;
@@ -74,11 +74,8 @@
   }
 
   // Create new compression context.
-  flow_id = masque_session_->GetNextDatagramFlowId();
-  if (flow_id == kFlowId0) {
-    // Do not use value zero which is reserved in this mode.
-    flow_id = masque_session_->GetNextDatagramFlowId();
-  }
+  next_available_flow_id_ += 2;
+  flow_id = next_available_flow_id_;
   QUIC_DVLOG(1) << "Compression assigning new flow_id " << flow_id << " to "
                 << server_address << " client " << client_connection_id
                 << " server " << server_connection_id;
@@ -96,13 +93,9 @@
     QuicConnectionId server_connection_id,
     const QuicSocketAddress& server_address,
     QuicConnectionId destination_connection_id,
-    QuicConnectionId source_connection_id,
-    QuicDatagramFlowId flow_id,
-    bool validated,
-    uint8_t first_byte,
-    bool long_header,
-    QuicDataReader* reader,
-    QuicDataWriter* writer) {
+    QuicConnectionId source_connection_id, QuicDatagramStreamId flow_id,
+    bool validated, uint8_t first_byte, bool long_header,
+    QuicDataReader* reader, QuicDataWriter* writer) {
   if (validated) {
     QUIC_DVLOG(1) << "Compressing using validated flow_id " << flow_id;
     if (!writer->WriteVarInt62(flow_id)) {
@@ -222,8 +215,7 @@
 }
 
 void MasqueCompressionEngine::CompressAndSendPacket(
-    absl::string_view packet,
-    QuicConnectionId client_connection_id,
+    absl::string_view packet, QuicConnectionId client_connection_id,
     QuicConnectionId server_connection_id,
     const QuicSocketAddress& server_address) {
   QUIC_DVLOG(2) << "Compressing client " << client_connection_id << " server "
@@ -258,7 +250,7 @@
   }
 
   bool validated = false;
-  QuicDatagramFlowId flow_id = FindOrCreateCompressionContext(
+  QuicDatagramStreamId flow_id = FindOrCreateCompressionContext(
       client_connection_id, server_connection_id, server_address,
       client_connection_id_present, server_connection_id_present, &validated);
 
@@ -297,9 +289,8 @@
 }
 
 bool MasqueCompressionEngine::ParseCompressionContext(
-    QuicDataReader* reader,
-    MasqueCompressionContext* context) {
-  QuicDatagramFlowId new_flow_id;
+    QuicDataReader* reader, MasqueCompressionContext* context) {
+  QuicDatagramStreamId new_flow_id;
   if (!reader->ReadVarInt62(&new_flow_id)) {
     QUIC_DLOG(ERROR) << "Could not read new_flow_id";
     return false;
@@ -398,10 +389,8 @@
 }
 
 bool MasqueCompressionEngine::WriteDecompressedPacket(
-    QuicDataReader* reader,
-    const MasqueCompressionContext& context,
-    std::vector<char>* packet,
-    bool* version_present) {
+    QuicDataReader* reader, const MasqueCompressionContext& context,
+    std::vector<char>* packet, bool* version_present) {
   QuicConnectionId destination_connection_id, source_connection_id;
   if (masque_session_->perspective() == Perspective::IS_SERVER) {
     destination_connection_id = context.server_connection_id;
@@ -464,16 +453,13 @@
 }
 
 bool MasqueCompressionEngine::DecompressDatagram(
-    absl::string_view datagram,
-    QuicConnectionId* client_connection_id,
-    QuicConnectionId* server_connection_id,
-    QuicSocketAddress* server_address,
-    std::vector<char>* packet,
-    bool* version_present) {
+    absl::string_view datagram, QuicConnectionId* client_connection_id,
+    QuicConnectionId* server_connection_id, QuicSocketAddress* server_address,
+    std::vector<char>* packet, bool* version_present) {
   QUIC_DVLOG(1) << "Decompressing DATAGRAM frame of length "
                 << datagram.length();
   QuicDataReader reader(datagram);
-  QuicDatagramFlowId flow_id;
+  QuicDatagramStreamId flow_id;
   if (!reader.ReadVarInt62(&flow_id)) {
     QUIC_DLOG(ERROR) << "Could not read flow_id";
     return false;
@@ -525,14 +511,14 @@
 
 void MasqueCompressionEngine::UnregisterClientConnectionId(
     QuicConnectionId client_connection_id) {
-  std::vector<QuicDatagramFlowId> flow_ids_to_remove;
+  std::vector<QuicDatagramStreamId> flow_ids_to_remove;
   for (const auto& kv : contexts_) {
     const MasqueCompressionContext& context = kv.second;
     if (context.client_connection_id == client_connection_id) {
       flow_ids_to_remove.push_back(kv.first);
     }
   }
-  for (QuicDatagramFlowId flow_id : flow_ids_to_remove) {
+  for (QuicDatagramStreamId flow_id : flow_ids_to_remove) {
     contexts_.erase(flow_id);
   }
 }
diff --git a/quic/masque/masque_compression_engine.h b/quic/masque/masque_compression_engine.h
index 9bea618..581b950 100644
--- a/quic/masque/masque_compression_engine.h
+++ b/quic/masque/masque_compression_engine.h
@@ -63,8 +63,7 @@
                           QuicConnectionId* client_connection_id,
                           QuicConnectionId* server_connection_id,
                           QuicSocketAddress* server_address,
-                          std::vector<char>* packet,
-                          bool* version_present);
+                          std::vector<char>* packet, bool* version_present);
 
   // Clears all entries referencing |client_connection_id| from the
   // compression table.
@@ -83,12 +82,11 @@
   // whether the corresponding connection ID is present in the current packet.
   // |validated| will contain whether the compression context that matches
   // these arguments is currently validated or not.
-  QuicDatagramFlowId FindOrCreateCompressionContext(
+  QuicDatagramStreamId FindOrCreateCompressionContext(
       QuicConnectionId client_connection_id,
       QuicConnectionId server_connection_id,
       const QuicSocketAddress& server_address,
-      bool client_connection_id_present,
-      bool server_connection_id_present,
+      bool client_connection_id_present, bool server_connection_id_present,
       bool* validated);
 
   // Writes compressed packet to |slice| during compression.
@@ -97,11 +95,9 @@
                                     const QuicSocketAddress& server_address,
                                     QuicConnectionId destination_connection_id,
                                     QuicConnectionId source_connection_id,
-                                    QuicDatagramFlowId flow_id,
-                                    bool validated,
-                                    uint8_t first_byte,
-                                    bool long_header,
-                                    QuicDataReader* reader,
+                                    QuicDatagramStreamId flow_id,
+                                    bool validated, uint8_t first_byte,
+                                    bool long_header, QuicDataReader* reader,
                                     QuicDataWriter* writer);
 
   // Parses compression context from flow ID 0 during decompression.
@@ -115,7 +111,8 @@
                                bool* version_present);
 
   QuicSpdySession* masque_session_;  // Unowned.
-  absl::flat_hash_map<QuicDatagramFlowId, MasqueCompressionContext> contexts_;
+  absl::flat_hash_map<QuicDatagramStreamId, MasqueCompressionContext> contexts_;
+  QuicDatagramStreamId next_available_flow_id_;
 };
 
 }  // namespace quic
diff --git a/quic/masque/masque_server_session.cc b/quic/masque/masque_server_session.cc
index 3c14fca..effc17f 100644
--- a/quic/masque/masque_server_session.cc
+++ b/quic/masque/masque_server_session.cc
@@ -58,8 +58,7 @@
 };
 
 std::unique_ptr<QuicBackendResponse> CreateBackendErrorResponse(
-    absl::string_view status,
-    absl::string_view error_details) {
+    absl::string_view status, absl::string_view error_details) {
   spdy::Http2HeaderBlock response_headers;
   response_headers[":status"] = status;
   response_headers["masque-debug-info"] = error_details;
@@ -72,24 +71,15 @@
 }  // namespace
 
 MasqueServerSession::MasqueServerSession(
-    MasqueMode masque_mode,
-    const QuicConfig& config,
+    MasqueMode masque_mode, const QuicConfig& config,
     const ParsedQuicVersionVector& supported_versions,
-    QuicConnection* connection,
-    QuicSession::Visitor* visitor,
-    Visitor* owner,
-    QuicEpollServer* epoll_server,
-    QuicCryptoServerStreamBase::Helper* helper,
+    QuicConnection* connection, QuicSession::Visitor* visitor, Visitor* owner,
+    QuicEpollServer* epoll_server, QuicCryptoServerStreamBase::Helper* helper,
     const QuicCryptoServerConfig* crypto_config,
     QuicCompressedCertsCache* compressed_certs_cache,
     MasqueServerBackend* masque_server_backend)
-    : QuicSimpleServerSession(config,
-                              supported_versions,
-                              connection,
-                              visitor,
-                              helper,
-                              crypto_config,
-                              compressed_certs_cache,
+    : QuicSimpleServerSession(config, supported_versions, connection, visitor,
+                              helper, crypto_config, compressed_certs_cache,
                               masque_server_backend),
       masque_server_backend_(masque_server_backend),
       owner_(owner),
@@ -153,8 +143,7 @@
 }
 
 void MasqueServerSession::OnConnectionClosed(
-    const QuicConnectionCloseFrame& frame,
-    ConnectionCloseSource source) {
+    const QuicConnectionCloseFrame& frame, ConnectionCloseSource source) {
   QuicSimpleServerSession::OnConnectionClosed(frame, source);
   QUIC_DLOG(INFO) << "Closing connection for " << connection_id();
   masque_server_backend_->RemoveBackendClient(connection_id());
@@ -165,7 +154,7 @@
 void MasqueServerSession::OnStreamClosed(QuicStreamId stream_id) {
   connect_udp_server_states_.remove_if(
       [stream_id](const ConnectUdpServerState& connect_udp) {
-        return connect_udp.stream_id() == stream_id;
+        return connect_udp.stream()->id() == stream_id;
       });
 
   QuicSimpleServerSession::OnStreamClosed(stream_id);
@@ -213,7 +202,7 @@
       QUIC_DLOG(ERROR) << "MASQUE request with bad method \"" << method << "\"";
       return CreateBackendErrorResponse("400", "Bad method");
     }
-    absl::optional<QuicDatagramFlowId> flow_id =
+    absl::optional<QuicDatagramStreamId> flow_id =
         SpdyUtils::ParseDatagramFlowIdHeader(request_headers);
     if (!flow_id.has_value()) {
       QUIC_DLOG(ERROR)
@@ -266,9 +255,25 @@
     }
     epoll_server_->RegisterFDForRead(fd_wrapper.fd(), this);
 
-    connect_udp_server_states_.emplace_back(ConnectUdpServerState(
-        *flow_id, request_handler->stream_id(), target_server_address,
-        fd_wrapper.extract_fd(), this));
+    absl::optional<QuicDatagramContextId> context_id;
+    QuicSpdyStream* stream = static_cast<QuicSpdyStream*>(
+        GetActiveStream(request_handler->stream_id()));
+    if (stream == nullptr) {
+      QUIC_BUG(bad masque server stream type)
+          << "Unexpected stream type for stream ID "
+          << request_handler->stream_id();
+      return CreateBackendErrorResponse("500", "Bad stream type");
+    }
+    stream->RegisterHttp3DatagramFlowId(*flow_id);
+    connect_udp_server_states_.push_back(
+        ConnectUdpServerState(stream, context_id, target_server_address,
+                              fd_wrapper.extract_fd(), this));
+
+    // TODO(b/181256914) remove this when we drop support for
+    // draft-ietf-masque-h3-datagram-00 in favor of later drafts.
+    Http3DatagramContextExtensions extensions;
+    stream->RegisterHttp3DatagramContextId(context_id, extensions,
+                                           &connect_udp_server_states_.back());
 
     spdy::Http2HeaderBlock response_headers;
     response_headers[":status"] = "200";
@@ -329,8 +334,7 @@
 }
 
 void MasqueServerSession::OnRegistration(QuicEpollServer* /*eps*/,
-                                         QuicUdpSocketFd fd,
-                                         int event_mask) {
+                                         QuicUdpSocketFd fd, int event_mask) {
   QUIC_DVLOG(1) << "OnRegistration " << fd << " event_mask " << event_mask;
 }
 
@@ -353,13 +357,12 @@
                                << event->in_events << " on unknown fd " << fd;
     return;
   }
-  QuicDatagramFlowId flow_id = it->flow_id();
   QuicSocketAddress expected_target_server_address =
       it->target_server_address();
   QUICHE_DCHECK(expected_target_server_address.IsInitialized());
   QUIC_DVLOG(1) << "Received readable event on fd " << fd << " (mask "
-                << event->in_events << ") flow_id " << flow_id << " server "
-                << expected_target_server_address;
+                << event->in_events << ") stream ID " << it->stream()->id()
+                << " server " << expected_target_server_address;
   QuicUdpSocketApi socket_api;
   BitMask64 packet_info_interested(QuicUdpPacketInfoBit::PEER_ADDRESS);
   char packet_buffer[kMaxIncomingPacketSize];
@@ -395,12 +398,14 @@
       return;
     }
     // The packet is valid, send it to the client in a DATAGRAM frame.
-    MessageStatus message_status = SendHttp3Datagram(
-        flow_id, absl::string_view(read_result.packet_buffer.buffer,
-                                   read_result.packet_buffer.buffer_len));
+    MessageStatus message_status = it->stream()->SendHttp3Datagram(
+        it->context_id(),
+        absl::string_view(read_result.packet_buffer.buffer,
+                          read_result.packet_buffer.buffer_len));
     QUIC_DVLOG(1) << "Sent UDP packet from " << expected_target_server_address
                   << " of length " << read_result.packet_buffer.buffer_len
-                  << " with flow ID " << flow_id << " and got message status "
+                  << " with stream ID " << it->stream()->id()
+                  << " and got message status "
                   << MessageStatusToString(message_status);
   }
 }
@@ -420,24 +425,25 @@
 }
 
 MasqueServerSession::ConnectUdpServerState::ConnectUdpServerState(
-    QuicDatagramFlowId flow_id,
-    QuicStreamId stream_id,
-    const QuicSocketAddress& target_server_address,
-    QuicUdpSocketFd fd,
+    QuicSpdyStream* stream, absl::optional<QuicDatagramContextId> context_id,
+    const QuicSocketAddress& target_server_address, QuicUdpSocketFd fd,
     MasqueServerSession* masque_session)
-    : flow_id_(flow_id),
-      stream_id_(stream_id),
+    : stream_(stream),
+      context_id_(context_id),
       target_server_address_(target_server_address),
       fd_(fd),
       masque_session_(masque_session) {
   QUICHE_DCHECK_NE(fd_, kQuicInvalidSocketFd);
   QUICHE_DCHECK_NE(masque_session_, nullptr);
-  masque_session_->RegisterHttp3FlowId(this->flow_id(), this);
+  this->stream()->RegisterHttp3DatagramRegistrationVisitor(this);
 }
 
 MasqueServerSession::ConnectUdpServerState::~ConnectUdpServerState() {
-  if (flow_id_.has_value()) {
-    masque_session_->UnregisterHttp3FlowId(flow_id());
+  if (stream() != nullptr) {
+    stream()->UnregisterHttp3DatagramRegistrationVisitor();
+    if (context_registered_) {
+      stream()->UnregisterHttp3DatagramContextId(context_id());
+    }
   }
   if (fd_ == kQuicInvalidSocketFd) {
     return;
@@ -463,24 +469,29 @@
     masque_session_->epoll_server()->UnregisterFD(fd_);
     socket_api.Destroy(fd_);
   }
-  flow_id_ = other.flow_id_;
-  stream_id_ = other.stream_id_;
+  stream_ = other.stream_;
+  other.stream_ = nullptr;
+  context_id_ = other.context_id_;
   target_server_address_ = other.target_server_address_;
   fd_ = other.fd_;
   masque_session_ = other.masque_session_;
   other.fd_ = kQuicInvalidSocketFd;
-  other.flow_id_.reset();
-  if (flow_id_.has_value()) {
-    masque_session_->UnregisterHttp3FlowId(flow_id());
-    masque_session_->RegisterHttp3FlowId(flow_id(), this);
+  context_registered_ = other.context_registered_;
+  other.context_registered_ = false;
+  if (stream() != nullptr) {
+    stream()->MoveHttp3DatagramRegistration(this);
+    if (context_registered_) {
+      stream()->MoveHttp3DatagramContextIdRegistration(context_id(), this);
+    }
   }
   return *this;
 }
 
 void MasqueServerSession::ConnectUdpServerState::OnHttp3Datagram(
-    QuicDatagramFlowId flow_id,
+    QuicStreamId stream_id, absl::optional<QuicDatagramContextId> context_id,
     absl::string_view payload) {
-  QUICHE_DCHECK_EQ(flow_id, this->flow_id());
+  QUICHE_DCHECK_EQ(stream_id, stream()->id());
+  QUICHE_DCHECK(context_id == context_id_);
   QuicUdpSocketApi socket_api;
   QuicUdpPacketInfo packet_info;
   packet_info.SetPeerAddress(target_server_address_);
@@ -490,4 +501,58 @@
                 << target_server_address_ << " with result " << write_result;
 }
 
+void MasqueServerSession::ConnectUdpServerState::OnContextReceived(
+    QuicStreamId stream_id, absl::optional<QuicDatagramContextId> context_id,
+    const Http3DatagramContextExtensions& /*extensions*/) {
+  if (stream_id != stream()->id()) {
+    QUIC_BUG(MASQUE server bad datagram context registration)
+        << "Registered stream ID " << stream_id << ", expected "
+        << stream()->id();
+    return;
+  }
+  if (!context_received_) {
+    context_received_ = true;
+    context_id_ = context_id;
+  }
+  if (context_id != context_id_) {
+    QUIC_DLOG(INFO) << "Ignoring unexpected context ID "
+                    << (context_id.has_value() ? context_id.value() : 0)
+                    << " instead of "
+                    << (context_id_.has_value() ? context_id_.value() : 0)
+                    << " on stream ID " << stream()->id();
+    return;
+  }
+  if (context_registered_) {
+    QUIC_BUG(MASQUE server double datagram context registration)
+        << "Try to re-register stream ID " << stream_id << " context ID "
+        << (context_id_.has_value() ? context_id_.value() : 0);
+    return;
+  }
+  context_registered_ = true;
+  Http3DatagramContextExtensions reply_extensions;
+  stream()->RegisterHttp3DatagramContextId(context_id_, reply_extensions, this);
+}
+
+void MasqueServerSession::ConnectUdpServerState::OnContextClosed(
+    QuicStreamId stream_id, absl::optional<QuicDatagramContextId> context_id,
+    const Http3DatagramContextExtensions& /*extensions*/) {
+  if (stream_id != stream()->id()) {
+    QUIC_BUG(MASQUE server bad datagram context registration)
+        << "Closed context on stream ID " << stream_id << ", expected "
+        << stream()->id();
+    return;
+  }
+  if (context_id != context_id_) {
+    QUIC_DLOG(INFO) << "Ignoring unexpected close of context ID "
+                    << (context_id.has_value() ? context_id.value() : 0)
+                    << " instead of "
+                    << (context_id_.has_value() ? context_id_.value() : 0)
+                    << " on stream ID " << stream()->id();
+    return;
+  }
+  QUIC_DLOG(INFO) << "Received datagram context close on stream ID "
+                  << stream()->id() << ", closing stream";
+  masque_session_->ResetStream(stream()->id(), QUIC_STREAM_CANCELLED);
+}
+
 }  // namespace quic
diff --git a/quic/masque/masque_server_session.h b/quic/masque/masque_server_session.h
index ebf0d21..1bec14c 100644
--- a/quic/masque/masque_server_session.h
+++ b/quic/masque/masque_server_session.h
@@ -38,14 +38,10 @@
   };
 
   explicit MasqueServerSession(
-      MasqueMode masque_mode,
-      const QuicConfig& config,
+      MasqueMode masque_mode, const QuicConfig& config,
       const ParsedQuicVersionVector& supported_versions,
-      QuicConnection* connection,
-      QuicSession::Visitor* visitor,
-      Visitor* owner,
-      QuicEpollServer* epoll_server,
-      QuicCryptoServerStreamBase::Helper* helper,
+      QuicConnection* connection, QuicSession::Visitor* visitor, Visitor* owner,
+      QuicEpollServer* epoll_server, QuicCryptoServerStreamBase::Helper* helper,
       const QuicCryptoServerConfig* crypto_config,
       QuicCompressedCertsCache* compressed_certs_cache,
       MasqueServerBackend* masque_server_backend);
@@ -71,8 +67,7 @@
       QuicSimpleServerBackend::RequestHandler* request_handler) override;
 
   // From QuicEpollCallbackInterface.
-  void OnRegistration(QuicEpollServer* eps,
-                      QuicUdpSocketFd fd,
+  void OnRegistration(QuicEpollServer* eps, QuicUdpSocketFd fd,
                       int event_mask) override;
   void OnModification(QuicUdpSocketFd fd, int event_mask) override;
   void OnEvent(QuicUdpSocketFd fd, QuicEpollEvent* event) override;
@@ -88,15 +83,15 @@
  private:
   // State that the MasqueServerSession keeps for each CONNECT-UDP request.
   class QUIC_NO_EXPORT ConnectUdpServerState
-      : public QuicSpdySession::Http3DatagramVisitor {
+      : public QuicSpdyStream::Http3DatagramRegistrationVisitor,
+        public QuicSpdyStream::Http3DatagramVisitor {
    public:
     // ConnectUdpServerState takes ownership of |fd|. It will unregister it
     // from |epoll_server| and close the file descriptor when destructed.
     explicit ConnectUdpServerState(
-        QuicDatagramFlowId flow_id,
-        QuicStreamId stream_id,
-        const QuicSocketAddress& target_server_address,
-        QuicUdpSocketFd fd,
+        QuicSpdyStream* stream,
+        absl::optional<QuicDatagramContextId> context_id,
+        const QuicSocketAddress& target_server_address, QuicUdpSocketFd fd,
         MasqueServerSession* masque_session);
 
     ~ConnectUdpServerState();
@@ -107,26 +102,38 @@
     ConnectUdpServerState& operator=(const ConnectUdpServerState&) = delete;
     ConnectUdpServerState& operator=(ConnectUdpServerState&&);
 
-    QuicDatagramFlowId flow_id() const {
-      QUICHE_DCHECK(flow_id_.has_value());
-      return *flow_id_;
+    QuicSpdyStream* stream() const { return stream_; }
+    absl::optional<QuicDatagramContextId> context_id() const {
+      return context_id_;
     }
-    QuicStreamId stream_id() const { return stream_id_; }
     const QuicSocketAddress& target_server_address() const {
       return target_server_address_;
     }
     QuicUdpSocketFd fd() const { return fd_; }
 
-    // From QuicSpdySession::Http3DatagramVisitor.
-    void OnHttp3Datagram(QuicDatagramFlowId flow_id,
+    // From QuicSpdyStream::Http3DatagramVisitor.
+    void OnHttp3Datagram(QuicStreamId stream_id,
+                         absl::optional<QuicDatagramContextId> context_id,
                          absl::string_view payload) override;
 
+    // From QuicSpdyStream::Http3DatagramRegistrationVisitor.
+    void OnContextReceived(
+        QuicStreamId stream_id,
+        absl::optional<QuicDatagramContextId> context_id,
+        const Http3DatagramContextExtensions& extensions) override;
+    void OnContextClosed(
+        QuicStreamId stream_id,
+        absl::optional<QuicDatagramContextId> context_id,
+        const Http3DatagramContextExtensions& extensions) override;
+
    private:
-    absl::optional<QuicDatagramFlowId> flow_id_;
-    QuicStreamId stream_id_;
+    QuicSpdyStream* stream_;
+    absl::optional<QuicDatagramContextId> context_id_;
     QuicSocketAddress target_server_address_;
-    QuicUdpSocketFd fd_;             // Owned.
+    QuicUdpSocketFd fd_;                   // Owned.
     MasqueServerSession* masque_session_;  // Unowned.
+    bool context_received_ = false;
+    bool context_registered_ = false;
   };
 
   bool ShouldNegotiateHttp3Datagram() override { return true; }
diff --git a/quic/test_tools/quic_test_utils.h b/quic/test_tools/quic_test_utils.h
index 1200f01..b2b7cb6 100644
--- a/quic/test_tools/quic_test_utils.h
+++ b/quic/test_tools/quic_test_utils.h
@@ -2323,31 +2323,49 @@
     uint8_t source_connection_id_length);
 
 // Implementation of Http3DatagramVisitor which saves all received datagrams.
-class SavingHttp3DatagramVisitor
-    : public QuicSpdySession::Http3DatagramVisitor {
+class SavingHttp3DatagramVisitor : public QuicSpdyStream::Http3DatagramVisitor {
  public:
   struct SavedHttp3Datagram {
-    QuicDatagramFlowId flow_id;
+    QuicStreamId stream_id;
+    absl::optional<QuicDatagramContextId> context_id;
     std::string payload;
     bool operator==(const SavedHttp3Datagram& o) const {
-      return flow_id == o.flow_id && payload == o.payload;
+      return stream_id == o.stream_id && context_id == o.context_id &&
+             payload == o.payload;
     }
   };
   const std::vector<SavedHttp3Datagram>& received_h3_datagrams() const {
     return received_h3_datagrams_;
   }
 
-  // Override from QuicSpdySession::Http3DatagramVisitor.
-  void OnHttp3Datagram(QuicDatagramFlowId flow_id,
+  // Override from QuicSpdyStream::Http3DatagramVisitor.
+  void OnHttp3Datagram(QuicStreamId stream_id,
+                       absl::optional<QuicDatagramContextId> context_id,
                        absl::string_view payload) override {
     received_h3_datagrams_.push_back(
-        SavedHttp3Datagram{flow_id, std::string(payload)});
+        SavedHttp3Datagram{stream_id, context_id, std::string(payload)});
   }
 
  private:
   std::vector<SavedHttp3Datagram> received_h3_datagrams_;
 };
 
+class MockHttp3DatagramRegistrationVisitor
+    : public QuicSpdyStream::Http3DatagramRegistrationVisitor {
+ public:
+  MOCK_METHOD(void, OnContextReceived,
+              (QuicStreamId stream_id,
+               absl::optional<QuicDatagramContextId> context_id,
+               const Http3DatagramContextExtensions& extensions),
+              (override));
+
+  MOCK_METHOD(void, OnContextClosed,
+              (QuicStreamId stream_id,
+               absl::optional<QuicDatagramContextId> context_id,
+               const Http3DatagramContextExtensions& extensions),
+              (override));
+};
+
 }  // namespace test
 }  // namespace quic