Move HTTP/3 Datagram parsing to QuicSpdySession

The flag is marked as enabling_blocked_by until we fully support draft-ietf-masque-h3-datagram.

Protected by FLAGS_quic_reloadable_flag_quic_h3_datagram.

PiperOrigin-RevId: 359877243
Change-Id: I3b76e3e703b14b5311c091a215237b34f2020bf1
diff --git a/quic/core/http/quic_spdy_session.cc b/quic/core/http/quic_spdy_session.cc
index 81c60b2..b6750fe 100644
--- a/quic/core/http/quic_spdy_session.cc
+++ b/quic/core/http/quic_spdy_session.cc
@@ -518,9 +518,11 @@
 
 QuicSpdySession::~QuicSpdySession() {
   QUIC_BUG_IF(destruction_indicator_ != 123456789)
-      << "QuicSpdyStream use after free. " << destruction_indicator_
+      << "QuicSpdySession use after free. " << destruction_indicator_
       << QuicStackTrace();
   destruction_indicator_ = 987654321;
+  QUIC_BUG_IF(!h3_datagram_registrations_.empty())
+      << "HTTP/3 datagram flow ID was not unregistered";
 }
 
 void QuicSpdySession::Initialize() {
@@ -1684,6 +1686,64 @@
   return result;
 }
 
+MessageStatus QuicSpdySession::SendHttp3Datagram(QuicDatagramFlowId flow_id,
+                                                 absl::string_view payload) {
+  size_t slice_length =
+      QuicDataWriter::GetVarInt62Len(flow_id) + payload.length();
+  QuicUniqueBufferPtr buffer = MakeUniqueBuffer(
+      connection()->helper()->GetStreamSendBufferAllocator(), slice_length);
+  QuicDataWriter writer(slice_length, buffer.get());
+  if (!writer.WriteVarInt62(flow_id)) {
+    QUIC_BUG << "Failed to write HTTP/3 datagram flow ID";
+    return MESSAGE_STATUS_INTERNAL_ERROR;
+  }
+  if (!writer.WriteBytes(payload.data(), payload.length())) {
+    QUIC_BUG << "Failed to write HTTP/3 datagram payload";
+    return MESSAGE_STATUS_INTERNAL_ERROR;
+  }
+
+  QuicMemSlice slice(std::move(buffer), slice_length);
+  return datagram_queue()->SendOrQueueDatagram(std::move(slice));
+}
+
+void QuicSpdySession::RegisterHttp3FlowId(
+    QuicDatagramFlowId flow_id,
+    QuicSpdySession::Http3DatagramVisitor* visitor) {
+  QUICHE_DCHECK_NE(visitor, nullptr);
+  auto insertion_result = h3_datagram_registrations_.insert({flow_id, visitor});
+  QUIC_BUG_IF(!insertion_result.second)
+      << "Attempted to doubly register HTTP/3 flow ID " << flow_id;
+}
+
+void QuicSpdySession::UnregisterHttp3FlowId(QuicDatagramFlowId flow_id) {
+  size_t num_erased = h3_datagram_registrations_.erase(flow_id);
+  QUIC_BUG_IF(num_erased != 1)
+      << "Attempted to unregister unknown HTTP/3 flow ID " << flow_id;
+}
+
+void QuicSpdySession::OnMessageReceived(absl::string_view message) {
+  QuicSession::OnMessageReceived(message);
+  if (!h3_datagram_supported_) {
+    QUIC_DLOG(ERROR) << "Ignoring unexpected received HTTP/3 datagram";
+    return;
+  }
+  QuicDataReader reader(message);
+  QuicDatagramFlowId flow_id;
+  if (!reader.ReadVarInt62(&flow_id)) {
+    QUIC_DLOG(ERROR) << "Failed to parse flow ID in received HTTP/3 datagram";
+    return;
+  }
+  auto it = h3_datagram_registrations_.find(flow_id);
+  if (it == h3_datagram_registrations_.end()) {
+    // TODO(dschinazi) buffer unknown HTTP/3 datagram flow IDs for a short
+    // period of time in case they were reordered.
+    QUIC_DLOG(ERROR) << "Received unknown HTTP/3 datagram flow ID " << flow_id;
+    return;
+  }
+  absl::string_view payload = reader.ReadRemainingPayload();
+  it->second->OnHttp3Datagram(flow_id, payload);
+}
+
 #undef ENDPOINT  // undef for jumbo builds
 
 }  // namespace quic
diff --git a/quic/core/http/quic_spdy_session.h b/quic/core/http/quic_spdy_session.h
index cff9e63..6a47b5d 100644
--- a/quic/core/http/quic_spdy_session.h
+++ b/quic/core/http/quic_spdy_session.h
@@ -420,6 +420,34 @@
   // SETTINGS.
   bool h3_datagram_supported() const { return h3_datagram_supported_; }
 
+  // Sends an HTTP/3 datagram. The flow ID is not part of |payload|.
+  MessageStatus SendHttp3Datagram(QuicDatagramFlowId flow_id,
+                                  absl::string_view payload);
+
+  class QUIC_EXPORT_PRIVATE Http3DatagramVisitor {
+   public:
+    virtual ~Http3DatagramVisitor() {}
+
+    // Called when an HTTP/3 datagram is received. |payload| does not contain
+    // the flow ID.
+    virtual void OnHttp3Datagram(QuicDatagramFlowId flow_id,
+                                 absl::string_view payload) = 0;
+  };
+
+  // Registers |visitor| to receive HTTP/3 datagrams for flow ID |flow_id|. This
+  // must not be called on a previously register flow ID without first calling
+  // UnregisterHttp3FlowId. |visitor| must be valid until a corresponding call
+  // to UnregisterHttp3FlowId. The flow ID must be unregistered before the
+  // QuicSpdySession is destroyed.
+  void RegisterHttp3FlowId(QuicDatagramFlowId flow_id,
+                           Http3DatagramVisitor* visitor);
+
+  // Unregister a given HTTP/3 datagram flow ID.
+  void UnregisterHttp3FlowId(QuicDatagramFlowId flow_id);
+
+  // Override from QuicSession to support HTTP/3 datagrams.
+  void OnMessageReceived(absl::string_view message) override;
+
  protected:
   // Override CreateIncomingStream(), CreateOutgoingBidirectionalStream() and
   // CreateOutgoingUnidirectionalStream() with QuicSpdyStream return type to
@@ -629,6 +657,9 @@
 
   // Whether both this endpoint and our peer support HTTP/3 datagrams.
   bool h3_datagram_supported_ = false;
+
+  absl::flat_hash_map<QuicDatagramFlowId, Http3DatagramVisitor*>
+      h3_datagram_registrations_;
 };
 
 }  // namespace quic
diff --git a/quic/core/http/quic_spdy_session_test.cc b/quic/core/http/quic_spdy_session_test.cc
index 3985769..d231f78 100644
--- a/quic/core/http/quic_spdy_session_test.cc
+++ b/quic/core/http/quic_spdy_session_test.cc
@@ -48,6 +48,7 @@
 #include "quic/test_tools/quic_test_utils.h"
 #include "common/platform/api/quiche_text_utils.h"
 #include "common/quiche_endian.h"
+#include "common/test_tools/quiche_test_utils.h"
 #include "spdy/core/spdy_framer.h"
 
 using spdy::kV3HighestPriority;
@@ -60,6 +61,7 @@
 using ::testing::_;
 using ::testing::AnyNumber;
 using ::testing::AtLeast;
+using ::testing::ElementsAre;
 using ::testing::InSequence;
 using ::testing::Invoke;
 using ::testing::Return;
@@ -3347,6 +3349,47 @@
   EXPECT_TRUE(session_.h3_datagram_supported());
 }
 
+TEST_P(QuicSpdySessionTestClient, H3DatagramRegistration) {
+  if (!version().UsesHttp3()) {
+    return;
+  }
+  CompleteHandshake();
+  SetQuicReloadableFlag(quic_h3_datagram, true);
+  QuicSpdySessionPeer::SetH3DatagramSupported(&session_, true);
+  SavingHttp3DatagramVisitor h3_datagram_visitor;
+  QuicDatagramFlowId flow_id = session_.GetNextDatagramFlowId();
+  ASSERT_EQ(QuicDataWriter::GetVarInt62Len(flow_id), 1);
+  uint8_t datagram[256];
+  datagram[0] = flow_id;
+  for (size_t i = 1; i < ABSL_ARRAYSIZE(datagram); i++) {
+    datagram[i] = i;
+  }
+  session_.RegisterHttp3FlowId(flow_id, &h3_datagram_visitor);
+  session_.OnMessageReceived(absl::string_view(
+      reinterpret_cast<const char*>(datagram), sizeof(datagram)));
+  EXPECT_THAT(
+      h3_datagram_visitor.received_h3_datagrams(),
+      ElementsAre(SavingHttp3DatagramVisitor::SavedHttp3Datagram{
+          flow_id, std::string(reinterpret_cast<const char*>(datagram + 1),
+                               sizeof(datagram) - 1)}));
+  session_.UnregisterHttp3FlowId(flow_id);
+}
+
+TEST_P(QuicSpdySessionTestClient, SendHttp3Datagram) {
+  if (!version().UsesHttp3()) {
+    return;
+  }
+  CompleteHandshake();
+  SetQuicReloadableFlag(quic_h3_datagram, true);
+  QuicSpdySessionPeer::SetH3DatagramSupported(&session_, true);
+  QuicDatagramFlowId flow_id = session_.GetNextDatagramFlowId();
+  std::string h3_datagram_payload = {1, 2, 3, 4, 5, 6};
+  EXPECT_CALL(*connection_, SendMessage(1, _, false))
+      .WillOnce(Return(MESSAGE_STATUS_SUCCESS));
+  EXPECT_EQ(session_.SendHttp3Datagram(flow_id, h3_datagram_payload),
+            MESSAGE_STATUS_SUCCESS);
+}
+
 }  // namespace
 }  // namespace test
 }  // namespace quic
diff --git a/quic/masque/masque_client_session.cc b/quic/masque/masque_client_session.cc
index 6c7859e..e158125 100644
--- a/quic/masque/masque_client_session.cc
+++ b/quic/masque/masque_client_session.cc
@@ -29,8 +29,8 @@
       compression_engine_(this) {}
 
 void MasqueClientSession::OnMessageReceived(absl::string_view message) {
-  QUIC_DVLOG(1) << "Received DATAGRAM frame of length " << message.length();
   if (masque_mode_ == MasqueMode::kLegacy) {
+    QUIC_DVLOG(1) << "Received DATAGRAM frame of length " << message.length();
     QuicConnectionId client_connection_id, server_connection_id;
     QuicSocketAddress target_server_address;
     std::vector<char> packet;
@@ -58,31 +58,8 @@
                   << client_connection_id;
     return;
   }
-  QuicDataReader reader(message);
-  QuicDatagramFlowId flow_id;
-  if (!reader.ReadVarInt62(&flow_id)) {
-    QUIC_DLOG(ERROR) << "Failed to parse flow_id";
-    return;
-  }
-  auto it =
-      absl::c_find_if(connect_udp_client_states_,
-                      [flow_id](const ConnectUdpClientState& connect_udp) {
-                        return connect_udp.flow_id() == flow_id;
-                      });
-  if (it == connect_udp_client_states_.end()) {
-    QUIC_DLOG(ERROR) << "Received unknown flow_id " << flow_id;
-    return;
-  }
-  EncapsulatedClientSession* encapsulated_client_session =
-      it->encapsulated_client_session();
-  QuicSocketAddress target_server_address = it->target_server_address();
-  QUICHE_DCHECK_NE(encapsulated_client_session, nullptr);
-  QUICHE_DCHECK(target_server_address.IsInitialized());
-  absl::string_view packet = reader.ReadRemainingPayload();
-  encapsulated_client_session->ProcessPacket(packet, target_server_address);
-
-  QUIC_DVLOG(1) << "Sent " << packet.size()
-                << " bytes to connection for flow_id " << flow_id;
+  QUICHE_DCHECK_EQ(masque_mode_, MasqueMode::kOpen);
+  QuicSpdySession::OnMessageReceived(message);
 }
 
 void MasqueClientSession::OnMessageAcked(QuicMessageId message_id,
@@ -130,8 +107,9 @@
     return nullptr;
   }
 
-  connect_udp_client_states_.push_back(ConnectUdpClientState(
-      stream, encapsulated_client_session, flow_id, target_server_address));
+  connect_udp_client_states_.push_back(
+      ConnectUdpClientState(stream, encapsulated_client_session, this, flow_id,
+                            target_server_address));
   return &connect_udp_client_states_.back();
 }
 
@@ -155,26 +133,12 @@
   }
 
   QuicDatagramFlowId flow_id = connect_udp->flow_id();
-  size_t slice_length =
-      QuicDataWriter::GetVarInt62Len(flow_id) + packet.length();
-  QuicUniqueBufferPtr buffer = MakeUniqueBuffer(
-      connection()->helper()->GetStreamSendBufferAllocator(), slice_length);
-  QuicDataWriter writer(slice_length, buffer.get());
-  if (!writer.WriteVarInt62(flow_id)) {
-    QUIC_BUG << "Failed to write flow_id";
-    return;
-  }
-  if (!writer.WriteBytes(packet.data(), packet.length())) {
-    QUIC_BUG << "Failed to write packet";
-    return;
-  }
-
-  QuicMemSlice slice(std::move(buffer), slice_length);
-  MessageResult message_result = SendMessage(QuicMemSliceSpan(&slice));
+  MessageStatus message_status =
+      SendHttp3Datagram(connect_udp->flow_id(), packet);
 
   QUIC_DVLOG(1) << "Sent packet to " << target_server_address
                 << " compressed with flow ID " << flow_id
-                << " and got message result " << message_result;
+                << " and got message status " << message_status;
 }
 
 void MasqueClientSession::RegisterConnectionId(
@@ -225,7 +189,7 @@
     ConnectionCloseSource source) {
   QuicSpdyClientSession::OnConnectionClosed(frame, source);
   // Close all encapsulated sessions.
-  for (auto client_state : connect_udp_client_states_) {
+  for (const auto& client_state : connect_udp_client_states_) {
     client_state.encapsulated_client_session()->CloseConnection(
         QUIC_CONNECTION_CANCELLED, "Underlying MASQUE connection was closed",
         ConnectionCloseBehavior::SILENT_CLOSE);
@@ -253,4 +217,55 @@
   QuicSpdyClientSession::OnStreamClosed(stream_id);
 }
 
+MasqueClientSession::ConnectUdpClientState::ConnectUdpClientState(
+    QuicSpdyClientStream* stream,
+    EncapsulatedClientSession* encapsulated_client_session,
+    MasqueClientSession* masque_session,
+    QuicDatagramFlowId flow_id,
+    const QuicSocketAddress& target_server_address)
+    : stream_(stream),
+      encapsulated_client_session_(encapsulated_client_session),
+      masque_session_(masque_session),
+      flow_id_(flow_id),
+      target_server_address_(target_server_address) {
+  QUICHE_DCHECK_NE(masque_session_, nullptr);
+  masque_session_->RegisterHttp3FlowId(this->flow_id(), this);
+}
+
+MasqueClientSession::ConnectUdpClientState::~ConnectUdpClientState() {
+  if (flow_id_.has_value()) {
+    masque_session_->UnregisterHttp3FlowId(flow_id());
+  }
+}
+
+MasqueClientSession::ConnectUdpClientState::ConnectUdpClientState(
+    MasqueClientSession::ConnectUdpClientState&& other) {
+  *this = std::move(other);
+}
+
+MasqueClientSession::ConnectUdpClientState&
+MasqueClientSession::ConnectUdpClientState::operator=(
+    MasqueClientSession::ConnectUdpClientState&& other) {
+  stream_ = other.stream_;
+  encapsulated_client_session_ = other.encapsulated_client_session_;
+  masque_session_ = other.masque_session_;
+  flow_id_ = other.flow_id_;
+  target_server_address_ = other.target_server_address_;
+  other.flow_id_.reset();
+  if (flow_id_.has_value()) {
+    masque_session_->UnregisterHttp3FlowId(flow_id());
+    masque_session_->RegisterHttp3FlowId(flow_id(), this);
+  }
+  return *this;
+}
+
+void MasqueClientSession::ConnectUdpClientState::OnHttp3Datagram(
+    QuicDatagramFlowId flow_id,
+    absl::string_view payload) {
+  QUICHE_DCHECK_EQ(flow_id, this->flow_id());
+  encapsulated_client_session_->ProcessPacket(payload, target_server_address_);
+  QUIC_DVLOG(1) << "Sent " << payload.size()
+                << " bytes to connection for flow_id " << flow_id;
+}
+
 }  // namespace quic
diff --git a/quic/masque/masque_client_session.h b/quic/masque/masque_client_session.h
index f849c61..e055a76 100644
--- a/quic/masque/masque_client_session.h
+++ b/quic/masque/masque_client_session.h
@@ -100,33 +100,47 @@
 
  private:
   // State that the MasqueClientSession keeps for each CONNECT-UDP request.
-  class QUIC_NO_EXPORT ConnectUdpClientState {
+  class QUIC_NO_EXPORT ConnectUdpClientState
+      : public QuicSpdySession::Http3DatagramVisitor {
    public:
     // |stream| and |encapsulated_client_session| must be valid for the lifetime
     // of the ConnectUdpClientState.
     explicit ConnectUdpClientState(
         QuicSpdyClientStream* stream,
         EncapsulatedClientSession* encapsulated_client_session,
+        MasqueClientSession* masque_session,
         QuicDatagramFlowId flow_id,
-        const QuicSocketAddress& target_server_address)
-        : stream_(stream),
-          encapsulated_client_session_(encapsulated_client_session),
-          flow_id_(flow_id),
-          target_server_address_(target_server_address) {}
+        const QuicSocketAddress& target_server_address);
+
+    ~ConnectUdpClientState();
+
+    // Disallow copy but allow move.
+    ConnectUdpClientState(const ConnectUdpClientState&) = delete;
+    ConnectUdpClientState(ConnectUdpClientState&&);
+    ConnectUdpClientState& operator=(const ConnectUdpClientState&) = delete;
+    ConnectUdpClientState& operator=(ConnectUdpClientState&&);
 
     QuicSpdyClientStream* stream() const { return stream_; }
     EncapsulatedClientSession* encapsulated_client_session() const {
       return encapsulated_client_session_;
     }
-    QuicDatagramFlowId flow_id() const { return flow_id_; }
+    QuicDatagramFlowId flow_id() const {
+      QUICHE_DCHECK(flow_id_.has_value());
+      return *flow_id_;
+    }
     const QuicSocketAddress& target_server_address() const {
       return target_server_address_;
     }
 
+    // From QuicSpdySession::Http3DatagramVisitor.
+    void OnHttp3Datagram(QuicDatagramFlowId flow_id,
+                         absl::string_view payload) override;
+
    private:
     QuicSpdyClientStream* stream_;                            // Unowned.
     EncapsulatedClientSession* encapsulated_client_session_;  // Unowned.
-    QuicDatagramFlowId flow_id_;
+    MasqueClientSession* masque_session_;                     // Unowned.
+    absl::optional<QuicDatagramFlowId> flow_id_;
     QuicSocketAddress target_server_address_;
   };
 
diff --git a/quic/masque/masque_server_session.cc b/quic/masque/masque_server_session.cc
index de231ba..1e420b9 100644
--- a/quic/masque/masque_server_session.cc
+++ b/quic/masque/masque_server_session.cc
@@ -7,6 +7,7 @@
 #include <netdb.h>
 
 #include "absl/strings/str_cat.h"
+#include "absl/strings/string_view.h"
 #include "quic/core/quic_data_reader.h"
 #include "quic/core/quic_udp_socket.h"
 #include "quic/tools/quic_url.h"
@@ -95,11 +96,12 @@
       compression_engine_(this),
       masque_mode_(masque_mode) {
   masque_server_backend_->RegisterBackendClient(connection_id(), this);
+  QUICHE_DCHECK_NE(epoll_server_, nullptr);
 }
 
 void MasqueServerSession::OnMessageReceived(absl::string_view message) {
-  QUIC_DVLOG(1) << "Received DATAGRAM frame of length " << message.length();
   if (masque_mode_ == MasqueMode::kLegacy) {
+    QUIC_DVLOG(1) << "Received DATAGRAM frame of length " << message.length();
     QuicConnectionId client_connection_id, server_connection_id;
     QuicSocketAddress target_server_address;
     std::vector<char> packet;
@@ -132,33 +134,7 @@
     return;
   }
   QUICHE_DCHECK_EQ(masque_mode_, MasqueMode::kOpen);
-  QuicDataReader reader(message);
-  QuicDatagramFlowId flow_id;
-  if (!reader.ReadVarInt62(&flow_id)) {
-    QUIC_DLOG(ERROR) << "Failed to read flow_id";
-    return;
-  }
-
-  auto it =
-      absl::c_find_if(connect_udp_server_states_,
-                      [flow_id](const ConnectUdpServerState& connect_udp) {
-                        return connect_udp.flow_id() == flow_id;
-                      });
-  if (it == connect_udp_server_states_.end()) {
-    QUIC_DLOG(ERROR) << "Received unknown flow_id " << flow_id;
-    return;
-  }
-  QuicSocketAddress target_server_address = it->target_server_address();
-  QUICHE_DCHECK(target_server_address.IsInitialized());
-  QuicUdpSocketFd fd = it->fd();
-  QUICHE_DCHECK_NE(fd, kQuicInvalidSocketFd);
-  absl::string_view packet = reader.ReadRemainingPayload();
-  QuicUdpSocketApi socket_api;
-  QuicUdpPacketInfo packet_info;
-  packet_info.SetPeerAddress(target_server_address);
-  WriteResult write_result =
-      socket_api.WritePacket(fd, packet.data(), packet.length(), packet_info);
-  QUIC_DVLOG(1) << "Wrote packet to server with result " << write_result;
+  QuicSpdySession::OnMessageReceived(message);
 }
 
 void MasqueServerSession::OnMessageAcked(QuicMessageId message_id,
@@ -275,7 +251,7 @@
 
     connect_udp_server_states_.emplace_back(ConnectUdpServerState(
         flow_id, request_handler->stream_id(), target_server_address,
-        fd_wrapper.extract_fd(), epoll_server_));
+        fd_wrapper.extract_fd(), this));
 
     spdy::Http2HeaderBlock response_headers;
     response_headers[":status"] = "200";
@@ -400,26 +376,12 @@
       return;
     }
     // The packet is valid, send it to the client in a DATAGRAM frame.
-    size_t slice_length = QuicDataWriter::GetVarInt62Len(flow_id) +
-                          read_result.packet_buffer.buffer_len;
-    QuicUniqueBufferPtr buffer = MakeUniqueBuffer(
-        connection()->helper()->GetStreamSendBufferAllocator(), slice_length);
-    QuicDataWriter writer(slice_length, buffer.get());
-    if (!writer.WriteVarInt62(flow_id)) {
-      QUIC_BUG << "Failed to write flow_id";
-      continue;
-    }
-    if (!writer.WriteBytes(read_result.packet_buffer.buffer,
-                           read_result.packet_buffer.buffer_len)) {
-      QUIC_BUG << "Failed to write packet";
-      continue;
-    }
-    QUICHE_DCHECK_EQ(writer.remaining(), 0u);
-    QuicMemSlice slice(std::move(buffer), slice_length);
-    MessageResult message_result = SendMessage(QuicMemSliceSpan(&slice));
+    MessageStatus message_status = SendHttp3Datagram(
+        flow_id, absl::string_view(read_result.packet_buffer.buffer,
+                                   read_result.packet_buffer.buffer_len));
     QUIC_DVLOG(1) << "Sent UDP packet from target server of length "
                   << read_result.packet_buffer.buffer_len << " with flow ID "
-                  << flow_id << " and got message result " << message_result;
+                  << flow_id << " and got message status " << message_status;
   }
 }
 
@@ -442,23 +404,27 @@
     QuicStreamId stream_id,
     const QuicSocketAddress& target_server_address,
     QuicUdpSocketFd fd,
-    QuicEpollServer* epoll_server)
+    MasqueServerSession* masque_session)
     : flow_id_(flow_id),
       stream_id_(stream_id),
       target_server_address_(target_server_address),
       fd_(fd),
-      epoll_server_(epoll_server) {
+      masque_session_(masque_session) {
   QUICHE_DCHECK_NE(fd_, kQuicInvalidSocketFd);
-  QUICHE_DCHECK_NE(epoll_server_, nullptr);
+  QUICHE_DCHECK_NE(masque_session_, nullptr);
+  masque_session_->RegisterHttp3FlowId(this->flow_id(), this);
 }
 
 MasqueServerSession::ConnectUdpServerState::~ConnectUdpServerState() {
+  if (flow_id_.has_value()) {
+    masque_session_->UnregisterHttp3FlowId(flow_id());
+  }
   if (fd_ == kQuicInvalidSocketFd) {
     return;
   }
   QuicUdpSocketApi socket_api;
   QUIC_DLOG(INFO) << "Closing fd " << fd_;
-  epoll_server_->UnregisterFD(fd_);
+  masque_session_->epoll_server()->UnregisterFD(fd_);
   socket_api.Destroy(fd_);
 }
 
@@ -474,16 +440,33 @@
   if (fd_ != kQuicInvalidSocketFd) {
     QuicUdpSocketApi socket_api;
     QUIC_DLOG(INFO) << "Closing fd " << fd_;
-    epoll_server_->UnregisterFD(fd_);
+    masque_session_->epoll_server()->UnregisterFD(fd_);
     socket_api.Destroy(fd_);
   }
   flow_id_ = other.flow_id_;
   stream_id_ = other.stream_id_;
   target_server_address_ = other.target_server_address_;
   fd_ = other.fd_;
-  epoll_server_ = other.epoll_server_;
+  masque_session_ = other.masque_session_;
   other.fd_ = kQuicInvalidSocketFd;
+  other.flow_id_.reset();
+  if (flow_id_.has_value()) {
+    masque_session_->UnregisterHttp3FlowId(flow_id());
+    masque_session_->RegisterHttp3FlowId(flow_id(), this);
+  }
   return *this;
 }
 
+void MasqueServerSession::ConnectUdpServerState::OnHttp3Datagram(
+    QuicDatagramFlowId flow_id,
+    absl::string_view payload) {
+  QUICHE_DCHECK_EQ(flow_id, this->flow_id());
+  QuicUdpSocketApi socket_api;
+  QuicUdpPacketInfo packet_info;
+  packet_info.SetPeerAddress(target_server_address_);
+  WriteResult write_result = socket_api.WritePacket(
+      fd_, payload.data(), payload.length(), packet_info);
+  QUIC_DVLOG(1) << "Wrote packet to server with result " << write_result;
+}
+
 }  // namespace quic
diff --git a/quic/masque/masque_server_session.h b/quic/masque/masque_server_session.h
index a73e641..14b8293 100644
--- a/quic/masque/masque_server_session.h
+++ b/quic/masque/masque_server_session.h
@@ -83,9 +83,12 @@
   // Handle packet for client, meant to be called by MasqueDispatcher.
   void HandlePacketFromServer(const ReceivedPacketInfo& packet_info);
 
+  QuicEpollServer* epoll_server() const { return epoll_server_; }
+
  private:
   // State that the MasqueServerSession keeps for each CONNECT-UDP request.
-  class QUIC_NO_EXPORT ConnectUdpServerState {
+  class QUIC_NO_EXPORT ConnectUdpServerState
+      : public QuicSpdySession::Http3DatagramVisitor {
    public:
     // ConnectUdpServerState takes ownership of |fd|. It will unregister it
     // from |epoll_server| and close the file descriptor when destructed.
@@ -94,7 +97,7 @@
         QuicStreamId stream_id,
         const QuicSocketAddress& target_server_address,
         QuicUdpSocketFd fd,
-        QuicEpollServer* epoll_server);
+        MasqueServerSession* masque_session);
 
     ~ConnectUdpServerState();
 
@@ -104,19 +107,26 @@
     ConnectUdpServerState& operator=(const ConnectUdpServerState&) = delete;
     ConnectUdpServerState& operator=(ConnectUdpServerState&&);
 
-    QuicDatagramFlowId flow_id() const { return flow_id_; }
+    QuicDatagramFlowId flow_id() const {
+      QUICHE_DCHECK(flow_id_.has_value());
+      return *flow_id_;
+    }
     QuicStreamId stream_id() const { return stream_id_; }
     const QuicSocketAddress& target_server_address() const {
       return target_server_address_;
     }
     QuicUdpSocketFd fd() const { return fd_; }
 
+    // From QuicSpdySession::Http3DatagramVisitor.
+    void OnHttp3Datagram(QuicDatagramFlowId flow_id,
+                         absl::string_view payload) override;
+
    private:
-    QuicDatagramFlowId flow_id_;
+    absl::optional<QuicDatagramFlowId> flow_id_;
     QuicStreamId stream_id_;
     QuicSocketAddress target_server_address_;
     QuicUdpSocketFd fd_;             // Owned.
-    QuicEpollServer* epoll_server_;  // Unowned.
+    MasqueServerSession* masque_session_;  // Unowned.
   };
 
   MasqueServerBackend* masque_server_backend_;  // Unowned.
diff --git a/quic/test_tools/quic_spdy_session_peer.cc b/quic/test_tools/quic_spdy_session_peer.cc
index 58eedcc..d41ddfd 100644
--- a/quic/test_tools/quic_spdy_session_peer.cc
+++ b/quic/test_tools/quic_spdy_session_peer.cc
@@ -109,5 +109,11 @@
   return session->qpack_encoder_receive_stream_;
 }
 
+// static
+void QuicSpdySessionPeer::SetH3DatagramSupported(QuicSpdySession* session,
+                                                 bool h3_datagram_supported) {
+  session->h3_datagram_supported_ = h3_datagram_supported;
+}
+
 }  // namespace test
 }  // namespace quic
diff --git a/quic/test_tools/quic_spdy_session_peer.h b/quic/test_tools/quic_spdy_session_peer.h
index c0413c6..4ad0367 100644
--- a/quic/test_tools/quic_spdy_session_peer.h
+++ b/quic/test_tools/quic_spdy_session_peer.h
@@ -57,6 +57,8 @@
       QuicSpdySession* session);
   static QpackReceiveStream* GetQpackEncoderReceiveStream(
       QuicSpdySession* session);
+  static void SetH3DatagramSupported(QuicSpdySession* session,
+                                     bool h3_datagram_supported);
 };
 
 }  // namespace test
diff --git a/quic/test_tools/quic_test_utils.h b/quic/test_tools/quic_test_utils.h
index fe255b4..6d5e52b 100644
--- a/quic/test_tools/quic_test_utils.h
+++ b/quic/test_tools/quic_test_utils.h
@@ -2284,6 +2284,32 @@
     const char* source_connection_id_bytes,
     uint8_t source_connection_id_length);
 
+// Implementation of Http3DatagramVisitor which saves all received datagrams.
+class SavingHttp3DatagramVisitor
+    : public QuicSpdySession::Http3DatagramVisitor {
+ public:
+  struct SavedHttp3Datagram {
+    QuicDatagramFlowId flow_id;
+    std::string payload;
+    bool operator==(const SavedHttp3Datagram& o) const {
+      return flow_id == o.flow_id && payload == o.payload;
+    }
+  };
+  const std::vector<SavedHttp3Datagram>& received_h3_datagrams() const {
+    return received_h3_datagrams_;
+  }
+
+  // Override from QuicSpdySession::Http3DatagramVisitor.
+  void OnHttp3Datagram(QuicDatagramFlowId flow_id,
+                       absl::string_view payload) override {
+    received_h3_datagrams_.push_back(
+        SavedHttp3Datagram{flow_id, std::string(payload)});
+  }
+
+ private:
+  std::vector<SavedHttp3Datagram> received_h3_datagrams_;
+};
+
 }  // namespace test
 }  // namespace quic