Implement MASQUE CONNECT-UDP

Previously, the MASQUE code was supporting the legacy MASQUE protocol. This CL adds support for CONNECT-UDP, the new MASQUE wire format. This is test code currently meant for IETF interop events. We're mainly landing it to prevent bitrot. This code isn't used in production.

PiperOrigin-RevId: 359575658
Change-Id: Ia883662fb6ffe7827c15b1460f2b316494b7ffbd
diff --git a/quic/masque/masque_client_bin.cc b/quic/masque/masque_client_bin.cc
index 8e09a12..afd991a 100644
--- a/quic/masque/masque_client_bin.cc
+++ b/quic/masque/masque_client_bin.cc
@@ -29,6 +29,12 @@
                               false,
                               "If true, don't verify the server certificate.");
 
+DEFINE_QUIC_COMMAND_LINE_FLAG(std::string,
+                              masque_mode,
+                              "",
+                              "Allows setting MASQUE mode, valid values are "
+                              "open and legacy. Defaults to open.");
+
 namespace quic {
 
 namespace {
@@ -54,7 +60,8 @@
     masque_url = QuicUrl(absl::StrCat("https://", urls[0]), "https");
   }
   if (masque_url.host().empty()) {
-    QUIC_LOG(ERROR) << "Failed to parse MASQUE server address " << urls[0];
+    std::cerr << "Failed to parse MASQUE server address \"" << urls[0] << "\""
+              << std::endl;
     return 1;
   }
   std::unique_ptr<ProofVerifier> proof_verifier;
@@ -63,15 +70,23 @@
   } else {
     proof_verifier = CreateDefaultProofVerifier(masque_url.host());
   }
-  std::unique_ptr<MasqueEpollClient> masque_client =
-      MasqueEpollClient::Create(masque_url.host(), masque_url.port(),
-                                &epoll_server, std::move(proof_verifier));
+  MasqueMode masque_mode = MasqueMode::kOpen;
+  std::string mode_string = GetQuicFlag(FLAGS_masque_mode);
+  if (mode_string == "legacy") {
+    masque_mode = MasqueMode::kLegacy;
+  } else if (!mode_string.empty() && mode_string != "open") {
+    std::cerr << "Invalid masque_mode \"" << mode_string << "\"" << std::endl;
+    return 1;
+  }
+  std::unique_ptr<MasqueEpollClient> masque_client = MasqueEpollClient::Create(
+      masque_url.host(), masque_url.port(), masque_mode, &epoll_server,
+      std::move(proof_verifier));
   if (masque_client == nullptr) {
     return 1;
   }
 
   std::cerr << "MASQUE is connected " << masque_client->connection_id()
-            << std::endl;
+            << " in " << masque_mode << " mode" << std::endl;
 
   for (size_t i = 1; i < urls.size(); ++i) {
     if (!tools::SendEncapsulatedMasqueRequest(
diff --git a/quic/masque/masque_client_session.cc b/quic/masque/masque_client_session.cc
index eeedccf..5663c4a 100644
--- a/quic/masque/masque_client_session.cc
+++ b/quic/masque/masque_client_session.cc
@@ -3,10 +3,14 @@
 // found in the LICENSE file.
 
 #include "quic/masque/masque_client_session.h"
+#include "absl/algorithm/container.h"
+#include "quic/core/quic_data_reader.h"
+#include "common/platform/api/quiche_text_utils.h"
 
 namespace quic {
 
 MasqueClientSession::MasqueClientSession(
+    MasqueMode masque_mode,
     const QuicConfig& config,
     const ParsedQuicVersionVector& supported_versions,
     QuicConnection* connection,
@@ -20,36 +24,65 @@
                             server_id,
                             crypto_config,
                             push_promise_index),
+      masque_mode_(masque_mode),
       owner_(owner),
       compression_engine_(this) {}
 
 void MasqueClientSession::OnMessageReceived(absl::string_view message) {
   QUIC_DVLOG(1) << "Received DATAGRAM frame of length " << message.length();
+  if (masque_mode_ == MasqueMode::kLegacy) {
+    QuicConnectionId client_connection_id, server_connection_id;
+    QuicSocketAddress target_server_address;
+    std::vector<char> packet;
+    bool version_present;
+    if (!compression_engine_.DecompressDatagram(
+            message, &client_connection_id, &server_connection_id,
+            &target_server_address, &packet, &version_present)) {
+      return;
+    }
 
-  QuicConnectionId client_connection_id, server_connection_id;
-  QuicSocketAddress server_address;
-  std::vector<char> packet;
-  bool version_present;
-  if (!compression_engine_.DecompressDatagram(
-          message, &client_connection_id, &server_connection_id,
-          &server_address, &packet, &version_present)) {
+    auto connection_id_registration =
+        client_connection_id_registrations_.find(client_connection_id);
+    if (connection_id_registration ==
+        client_connection_id_registrations_.end()) {
+      QUIC_DLOG(ERROR) << "MasqueClientSession failed to dispatch "
+                       << client_connection_id;
+      return;
+    }
+    EncapsulatedClientSession* encapsulated_client_session =
+        connection_id_registration->second;
+    encapsulated_client_session->ProcessPacket(
+        absl::string_view(packet.data(), packet.size()), target_server_address);
+
+    QUIC_DVLOG(1) << "Sent " << packet.size() << " bytes to connection for "
+                  << client_connection_id;
     return;
   }
-
-  auto connection_id_registration =
-      client_connection_id_registrations_.find(client_connection_id);
-  if (connection_id_registration == client_connection_id_registrations_.end()) {
-    QUIC_DLOG(ERROR) << "MasqueClientSession failed to dispatch "
-                     << client_connection_id;
+  QuicDataReader reader(message);
+  QuicDatagramFlowId flow_id;
+  if (!reader.ReadVarInt62(&flow_id)) {
+    QUIC_DLOG(ERROR) << "Failed to parse flow_id";
+    return;
+  }
+  auto it =
+      absl::c_find_if(connect_udp_client_states_,
+                      [flow_id](const ConnectUdpClientState& connect_udp) {
+                        return connect_udp.flow_id() == flow_id;
+                      });
+  if (it == connect_udp_client_states_.end()) {
+    QUIC_DLOG(ERROR) << "Received unknown flow_id " << flow_id;
     return;
   }
   EncapsulatedClientSession* encapsulated_client_session =
-      connection_id_registration->second;
-  encapsulated_client_session->ProcessPacket(
-      absl::string_view(packet.data(), packet.size()), server_address);
+      it->encapsulated_client_session();
+  QuicSocketAddress target_server_address = it->target_server_address();
+  QUICHE_DCHECK_NE(encapsulated_client_session, nullptr);
+  QUICHE_DCHECK(target_server_address.IsInitialized());
+  absl::string_view packet = reader.ReadRemainingPayload();
+  encapsulated_client_session->ProcessPacket(packet, target_server_address);
 
-  QUIC_DVLOG(1) << "Sent " << packet.size() << " bytes to connection for "
-                << client_connection_id;
+  QUIC_DVLOG(1) << "Sent " << packet.size()
+                << " bytes to connection for flow_id " << flow_id;
 }
 
 void MasqueClientSession::OnMessageAcked(QuicMessageId message_id,
@@ -61,12 +94,87 @@
   QUIC_DVLOG(1) << "We believe DATAGRAM frame " << message_id << " was lost";
 }
 
-void MasqueClientSession::SendPacket(QuicConnectionId client_connection_id,
-                                     QuicConnectionId server_connection_id,
-                                     absl::string_view packet,
-                                     const QuicSocketAddress& server_address) {
-  compression_engine_.CompressAndSendPacket(
-      packet, client_connection_id, server_connection_id, server_address);
+const MasqueClientSession::ConnectUdpClientState*
+MasqueClientSession::GetOrCreateConnectUdpClientState(
+    const QuicSocketAddress& target_server_address,
+    EncapsulatedClientSession* encapsulated_client_session) {
+  for (const ConnectUdpClientState& client_state : connect_udp_client_states_) {
+    if (client_state.target_server_address() == target_server_address &&
+        client_state.encapsulated_client_session() ==
+            encapsulated_client_session) {
+      // Found existing CONNECT-UDP request.
+      return &client_state;
+    }
+  }
+  // No CONNECT-UDP request found, create a new one.
+  QuicSpdyClientStream* stream = CreateOutgoingBidirectionalStream();
+  if (stream == nullptr) {
+    // Stream flow control limits prevented us from opening a new stream.
+    QUIC_DLOG(ERROR) << "Failed to open CONNECT-UDP stream";
+    return nullptr;
+  }
+
+  QuicDatagramFlowId flow_id = compression_engine_.GetNextFlowId();
+
+  // Send the request.
+  spdy::Http2HeaderBlock headers;
+  headers[":method"] = "CONNECT-UDP";
+  headers[":scheme"] = "masque";
+  headers[":path"] = "/";
+  headers[":authority"] = target_server_address.ToString();
+  headers["datagram-flow-id"] = absl::StrCat(flow_id);
+  size_t bytes_sent =
+      stream->SendRequest(std::move(headers), /*body=*/"", /*fin=*/false);
+  if (bytes_sent == 0) {
+    QUIC_DLOG(ERROR) << "Failed to send CONNECT-UDP request";
+    return nullptr;
+  }
+
+  connect_udp_client_states_.push_back(ConnectUdpClientState(
+      stream, encapsulated_client_session, flow_id, target_server_address));
+  return &connect_udp_client_states_.back();
+}
+
+void MasqueClientSession::SendPacket(
+    QuicConnectionId client_connection_id,
+    QuicConnectionId server_connection_id,
+    absl::string_view packet,
+    const QuicSocketAddress& target_server_address,
+    EncapsulatedClientSession* encapsulated_client_session) {
+  if (masque_mode_ == MasqueMode::kLegacy) {
+    compression_engine_.CompressAndSendPacket(packet, client_connection_id,
+                                              server_connection_id,
+                                              target_server_address);
+    return;
+  }
+  const ConnectUdpClientState* connect_udp = GetOrCreateConnectUdpClientState(
+      target_server_address, encapsulated_client_session);
+  if (connect_udp == nullptr) {
+    QUIC_DLOG(ERROR) << "Failed to create CONNECT-UDP request";
+    return;
+  }
+
+  QuicDatagramFlowId flow_id = connect_udp->flow_id();
+  size_t slice_length =
+      QuicDataWriter::GetVarInt62Len(flow_id) + packet.length();
+  QuicUniqueBufferPtr buffer = MakeUniqueBuffer(
+      connection()->helper()->GetStreamSendBufferAllocator(), slice_length);
+  QuicDataWriter writer(slice_length, buffer.get());
+  if (!writer.WriteVarInt62(flow_id)) {
+    QUIC_BUG << "Failed to write flow_id";
+    return;
+  }
+  if (!writer.WriteBytes(packet.data(), packet.length())) {
+    QUIC_BUG << "Failed to write packet";
+    return;
+  }
+
+  QuicMemSlice slice(std::move(buffer), slice_length);
+  MessageResult message_result = SendMessage(QuicMemSliceSpan(&slice));
+
+  QUIC_DVLOG(1) << "Sent packet to " << target_server_address
+                << " compressed with flow ID " << flow_id
+                << " and got message result " << message_result;
 }
 
 void MasqueClientSession::RegisterConnectionId(
@@ -84,14 +192,65 @@
 }
 
 void MasqueClientSession::UnregisterConnectionId(
-    QuicConnectionId client_connection_id) {
+    QuicConnectionId client_connection_id,
+    EncapsulatedClientSession* encapsulated_client_session) {
   QUIC_DLOG(INFO) << "Unregistering " << client_connection_id;
-  if (client_connection_id_registrations_.find(client_connection_id) !=
-      client_connection_id_registrations_.end()) {
-    client_connection_id_registrations_.erase(client_connection_id);
-    owner_->UnregisterClientConnectionId(client_connection_id);
-    compression_engine_.UnregisterClientConnectionId(client_connection_id);
+  if (masque_mode_ == MasqueMode::kLegacy) {
+    if (client_connection_id_registrations_.find(client_connection_id) !=
+        client_connection_id_registrations_.end()) {
+      client_connection_id_registrations_.erase(client_connection_id);
+      owner_->UnregisterClientConnectionId(client_connection_id);
+      compression_engine_.UnregisterClientConnectionId(client_connection_id);
+    }
+    return;
   }
+
+  for (auto it = connect_udp_client_states_.begin();
+       it != connect_udp_client_states_.end();) {
+    if (it->encapsulated_client_session() == encapsulated_client_session) {
+      QUIC_DLOG(INFO) << "Removing state for flow_id " << it->flow_id();
+      auto* stream = it->stream();
+      it = connect_udp_client_states_.erase(it);
+      if (!stream->write_side_closed()) {
+        stream->Reset(QUIC_STREAM_CANCELLED);
+      }
+    } else {
+      ++it;
+    }
+  }
+}
+
+void MasqueClientSession::OnConnectionClosed(
+    const QuicConnectionCloseFrame& frame,
+    ConnectionCloseSource source) {
+  QuicSpdyClientSession::OnConnectionClosed(frame, source);
+  // Close all encapsulated sessions.
+  for (auto client_state : connect_udp_client_states_) {
+    client_state.encapsulated_client_session()->CloseConnection(
+        QUIC_CONNECTION_CANCELLED, "Underlying MASQUE connection was closed",
+        ConnectionCloseBehavior::SILENT_CLOSE);
+  }
+}
+
+void MasqueClientSession::OnStreamClosed(QuicStreamId stream_id) {
+  for (auto it = connect_udp_client_states_.begin();
+       it != connect_udp_client_states_.end();) {
+    if (it->stream()->id() == stream_id) {
+      QUIC_DLOG(INFO) << "Stream " << stream_id
+                      << " was closed, removing state for flow_id "
+                      << it->flow_id();
+      auto* encapsulated_client_session = it->encapsulated_client_session();
+      it = connect_udp_client_states_.erase(it);
+      encapsulated_client_session->CloseConnection(
+          QUIC_CONNECTION_CANCELLED,
+          "Underlying MASQUE CONNECT-UDP stream was closed",
+          ConnectionCloseBehavior::SILENT_CLOSE);
+    } else {
+      ++it;
+    }
+  }
+
+  QuicSpdyClientSession::OnStreamClosed(stream_id);
 }
 
 }  // namespace quic
diff --git a/quic/masque/masque_client_session.h b/quic/masque/masque_client_session.h
index 0eab5b2..f849c61 100644
--- a/quic/masque/masque_client_session.h
+++ b/quic/masque/masque_client_session.h
@@ -9,6 +9,7 @@
 #include "absl/strings/string_view.h"
 #include "quic/core/http/quic_spdy_client_session.h"
 #include "quic/masque/masque_compression_engine.h"
+#include "quic/masque/masque_utils.h"
 #include "quic/platform/api/quic_export.h"
 #include "quic/platform/api/quic_socket_address.h"
 
@@ -39,14 +40,21 @@
 
     // Process packet that was just decapsulated.
     virtual void ProcessPacket(absl::string_view packet,
-                               QuicSocketAddress server_address) = 0;
+                               QuicSocketAddress target_server_address) = 0;
+
+    // Close the encapsulated connection.
+    virtual void CloseConnection(
+        QuicErrorCode error,
+        const std::string& details,
+        ConnectionCloseBehavior connection_close_behavior) = 0;
   };
 
   // Takes ownership of |connection|, but not of |crypto_config| or
   // |push_promise_index| or |owner|. All pointers must be non-null. Caller
   // must ensure that |push_promise_index| and |owner| stay valid for the
   // lifetime of the newly created MasqueClientSession.
-  MasqueClientSession(const QuicConfig& config,
+  MasqueClientSession(MasqueMode masque_mode,
+                      const QuicConfig& config,
                       const ParsedQuicVersionVector& supported_versions,
                       QuicConnection* connection,
                       const QuicServerId& server_id,
@@ -60,17 +68,19 @@
 
   // From QuicSession.
   void OnMessageReceived(absl::string_view message) override;
-
   void OnMessageAcked(QuicMessageId message_id,
                       QuicTime receive_timestamp) override;
-
   void OnMessageLost(QuicMessageId message_id) override;
+  void OnConnectionClosed(const QuicConnectionCloseFrame& frame,
+                          ConnectionCloseSource source) override;
+  void OnStreamClosed(QuicStreamId stream_id) override;
 
   // Send encapsulated packet.
   void SendPacket(QuicConnectionId client_connection_id,
                   QuicConnectionId server_connection_id,
                   absl::string_view packet,
-                  const QuicSocketAddress& server_address);
+                  const QuicSocketAddress& target_server_address,
+                  EncapsulatedClientSession* encapsulated_client_session);
 
   // Register encapsulated client. This allows clients that are encapsulated
   // within this MASQUE session to indicate they own a given client connection
@@ -84,9 +94,48 @@
 
   // Unregister encapsulated client. |client_connection_id| must match a
   // value previously passed to RegisterConnectionId.
-  void UnregisterConnectionId(QuicConnectionId client_connection_id);
+  void UnregisterConnectionId(
+      QuicConnectionId client_connection_id,
+      EncapsulatedClientSession* encapsulated_client_session);
 
  private:
+  // State that the MasqueClientSession keeps for each CONNECT-UDP request.
+  class QUIC_NO_EXPORT ConnectUdpClientState {
+   public:
+    // |stream| and |encapsulated_client_session| must be valid for the lifetime
+    // of the ConnectUdpClientState.
+    explicit ConnectUdpClientState(
+        QuicSpdyClientStream* stream,
+        EncapsulatedClientSession* encapsulated_client_session,
+        QuicDatagramFlowId flow_id,
+        const QuicSocketAddress& target_server_address)
+        : stream_(stream),
+          encapsulated_client_session_(encapsulated_client_session),
+          flow_id_(flow_id),
+          target_server_address_(target_server_address) {}
+
+    QuicSpdyClientStream* stream() const { return stream_; }
+    EncapsulatedClientSession* encapsulated_client_session() const {
+      return encapsulated_client_session_;
+    }
+    QuicDatagramFlowId flow_id() const { return flow_id_; }
+    const QuicSocketAddress& target_server_address() const {
+      return target_server_address_;
+    }
+
+   private:
+    QuicSpdyClientStream* stream_;                            // Unowned.
+    EncapsulatedClientSession* encapsulated_client_session_;  // Unowned.
+    QuicDatagramFlowId flow_id_;
+    QuicSocketAddress target_server_address_;
+  };
+
+  const ConnectUdpClientState* GetOrCreateConnectUdpClientState(
+      const QuicSocketAddress& target_server_address,
+      EncapsulatedClientSession* encapsulated_client_session);
+
+  MasqueMode masque_mode_;
+  std::list<ConnectUdpClientState> connect_udp_client_states_;
   absl::flat_hash_map<QuicConnectionId,
                       EncapsulatedClientSession*,
                       QuicConnectionIdHash>
diff --git a/quic/masque/masque_compression_engine.h b/quic/masque/masque_compression_engine.h
index 96c5948..16b4549 100644
--- a/quic/masque/masque_compression_engine.h
+++ b/quic/masque/masque_compression_engine.h
@@ -70,6 +70,9 @@
   // compression table.
   void UnregisterClientConnectionId(QuicConnectionId client_connection_id);
 
+  // Generates a new datagram flow ID.
+  QuicDatagramFlowId GetNextFlowId();
+
  private:
   struct QUIC_NO_EXPORT MasqueCompressionContext {
     QuicConnectionId client_connection_id;
@@ -78,9 +81,6 @@
     bool validated = false;
   };
 
-  // Generates a new datagram flow ID.
-  QuicDatagramFlowId GetNextFlowId();
-
   // Finds or creates a new compression context to use during compression.
   // |client_connection_id_present| and |server_connection_id_present| indicate
   // whether the corresponding connection ID is present in the current packet.
diff --git a/quic/masque/masque_dispatcher.cc b/quic/masque/masque_dispatcher.cc
index 6edca62..03eb8f8 100644
--- a/quic/masque/masque_dispatcher.cc
+++ b/quic/masque/masque_dispatcher.cc
@@ -8,9 +8,11 @@
 namespace quic {
 
 MasqueDispatcher::MasqueDispatcher(
+    MasqueMode masque_mode,
     const QuicConfig* config,
     const QuicCryptoServerConfig* crypto_config,
     QuicVersionManager* version_manager,
+    QuicEpollServer* epoll_server,
     std::unique_ptr<QuicConnectionHelperInterface> helper,
     std::unique_ptr<QuicCryptoServerStreamBase::Helper> session_helper,
     std::unique_ptr<QuicAlarmFactory> alarm_factory,
@@ -24,6 +26,8 @@
                            std::move(alarm_factory),
                            masque_server_backend,
                            expected_server_connection_id_length),
+      masque_mode_(masque_mode),
+      epoll_server_(epoll_server),
       masque_server_backend_(masque_server_backend) {}
 
 std::unique_ptr<QuicSession> MasqueDispatcher::CreateQuicSession(
@@ -40,9 +44,9 @@
                          ParsedQuicVersionVector{version});
 
   auto session = std::make_unique<MasqueServerSession>(
-      config(), GetSupportedVersions(), connection, this, this,
-      session_helper(), crypto_config(), compressed_certs_cache(),
-      masque_server_backend_);
+      masque_mode_, config(), GetSupportedVersions(), connection, this, this,
+      epoll_server_, session_helper(), crypto_config(),
+      compressed_certs_cache(), masque_server_backend_);
   session->Initialize();
   return session;
 }
diff --git a/quic/masque/masque_dispatcher.h b/quic/masque/masque_dispatcher.h
index 2e0186c..ce371b5 100644
--- a/quic/masque/masque_dispatcher.h
+++ b/quic/masque/masque_dispatcher.h
@@ -8,6 +8,8 @@
 #include "absl/container/flat_hash_map.h"
 #include "quic/masque/masque_server_backend.h"
 #include "quic/masque/masque_server_session.h"
+#include "quic/masque/masque_utils.h"
+#include "quic/platform/api/quic_epoll.h"
 #include "quic/platform/api/quic_export.h"
 #include "quic/tools/quic_simple_dispatcher.h"
 
@@ -19,9 +21,11 @@
                                         public MasqueServerSession::Visitor {
  public:
   explicit MasqueDispatcher(
+      MasqueMode masque_mode,
       const QuicConfig* config,
       const QuicCryptoServerConfig* crypto_config,
       QuicVersionManager* version_manager,
+      QuicEpollServer* epoll_server,
       std::unique_ptr<QuicConnectionHelperInterface> helper,
       std::unique_ptr<QuicCryptoServerStreamBase::Helper> session_helper,
       std::unique_ptr<QuicAlarmFactory> alarm_factory,
@@ -51,6 +55,8 @@
       QuicConnectionId client_connection_id) override;
 
  private:
+  MasqueMode masque_mode_;
+  QuicEpollServer* epoll_server_;               // Unowned.
   MasqueServerBackend* masque_server_backend_;  // Unowned.
   // Mapping from client connection IDs to server sessions, allows routing
   // incoming packets to the right MASQUE connection.
diff --git a/quic/masque/masque_encapsulated_client_session.cc b/quic/masque/masque_encapsulated_client_session.cc
index 983a094..3d557c3 100644
--- a/quic/masque/masque_encapsulated_client_session.cc
+++ b/quic/masque/masque_encapsulated_client_session.cc
@@ -31,11 +31,19 @@
                                  received_packet);
 }
 
+void MasqueEncapsulatedClientSession::CloseConnection(
+    QuicErrorCode error,
+    const std::string& details,
+    ConnectionCloseBehavior connection_close_behavior) {
+  connection()->CloseConnection(error, details, connection_close_behavior);
+}
+
 void MasqueEncapsulatedClientSession::OnConnectionClosed(
-    const QuicConnectionCloseFrame& /*frame*/,
-    ConnectionCloseSource /*source*/) {
+    const QuicConnectionCloseFrame& frame,
+    ConnectionCloseSource source) {
+  QuicSpdyClientSession::OnConnectionClosed(frame, source);
   masque_client_session_->UnregisterConnectionId(
-      connection()->client_connection_id());
+      connection()->client_connection_id(), this);
 }
 
 }  // namespace quic
diff --git a/quic/masque/masque_encapsulated_client_session.h b/quic/masque/masque_encapsulated_client_session.h
index 570ccdb..0952159 100644
--- a/quic/masque/masque_encapsulated_client_session.h
+++ b/quic/masque/masque_encapsulated_client_session.h
@@ -44,6 +44,10 @@
   // From MasqueClientSession::EncapsulatedClientSession.
   void ProcessPacket(absl::string_view packet,
                      QuicSocketAddress server_address) override;
+  void CloseConnection(
+      QuicErrorCode error,
+      const std::string& details,
+      ConnectionCloseBehavior connection_close_behavior) override;
 
   // From QuicSession.
   void OnConnectionClosed(const QuicConnectionCloseFrame& frame,
diff --git a/quic/masque/masque_encapsulated_epoll_client.cc b/quic/masque/masque_encapsulated_epoll_client.cc
index 8fb938d..8891e4a 100644
--- a/quic/masque/masque_encapsulated_epoll_client.cc
+++ b/quic/masque/masque_encapsulated_epoll_client.cc
@@ -30,8 +30,8 @@
     absl::string_view packet(buffer, buf_len);
     client_->masque_client()->masque_client_session()->SendPacket(
         client_->session()->connection()->client_connection_id(),
-        client_->session()->connection()->connection_id(), packet,
-        peer_address);
+        client_->session()->connection()->connection_id(), packet, peer_address,
+        client_->masque_encapsulated_client_session());
     return WriteResult(WRITE_STATUS_OK, buf_len);
   }
 
@@ -94,7 +94,7 @@
 
 MasqueEncapsulatedEpollClient::~MasqueEncapsulatedEpollClient() {
   masque_client_->masque_client_session()->UnregisterConnectionId(
-      client_connection_id_);
+      client_connection_id_, masque_encapsulated_client_session());
 }
 
 std::unique_ptr<QuicSession>
diff --git a/quic/masque/masque_epoll_client.cc b/quic/masque/masque_epoll_client.cc
index 1d77e55..bf09da4 100644
--- a/quic/masque/masque_epoll_client.cc
+++ b/quic/masque/masque_epoll_client.cc
@@ -12,6 +12,7 @@
 MasqueEpollClient::MasqueEpollClient(
     QuicSocketAddress server_address,
     const QuicServerId& server_id,
+    MasqueMode masque_mode,
     QuicEpollServer* epoll_server,
     std::unique_ptr<ProofVerifier> proof_verifier,
     const std::string& authority)
@@ -20,6 +21,7 @@
                  MasqueSupportedVersions(),
                  epoll_server,
                  std::move(proof_verifier)),
+      masque_mode_(masque_mode),
       authority_(authority) {}
 
 std::unique_ptr<QuicSession> MasqueEpollClient::CreateQuicClientSession(
@@ -28,8 +30,8 @@
   QUIC_DLOG(INFO) << "Creating MASQUE session for "
                   << connection->connection_id();
   return std::make_unique<MasqueClientSession>(
-      *config(), supported_versions, connection, server_id(), crypto_config(),
-      push_promise_index(), this);
+      masque_mode_, *config(), supported_versions, connection, server_id(),
+      crypto_config(), push_promise_index(), this);
 }
 
 MasqueClientSession* MasqueEpollClient::masque_client_session() {
@@ -44,6 +46,7 @@
 std::unique_ptr<MasqueEpollClient> MasqueEpollClient::Create(
     const std::string& host,
     int port,
+    MasqueMode masque_mode,
     QuicEpollServer* epoll_server,
     std::unique_ptr<ProofVerifier> proof_verifier) {
   // Build the masque_client, and try to connect.
@@ -57,7 +60,7 @@
   // std::make_unique<MasqueEpollClient>(...) because the constructor for
   // MasqueEpollClient is private and therefore not accessible from make_unique.
   auto masque_client = QuicWrapUnique(new MasqueEpollClient(
-      addr, server_id, epoll_server, std::move(proof_verifier),
+      addr, server_id, masque_mode, epoll_server, std::move(proof_verifier),
       absl::StrCat(host, ":", port)));
 
   if (masque_client == nullptr) {
@@ -78,33 +81,35 @@
     return nullptr;
   }
 
-  std::string body = "foo";
+  if (masque_client->masque_mode() == MasqueMode::kLegacy) {
+    // Construct the legacy mode init request.
+    spdy::Http2HeaderBlock header_block;
+    header_block[":method"] = "POST";
+    header_block[":scheme"] = "https";
+    header_block[":authority"] = masque_client->authority_;
+    header_block[":path"] = "/.well-known/masque/init";
+    std::string body = "foo";
 
-  // Construct a GET or POST request for supplied URL.
-  spdy::Http2HeaderBlock header_block;
-  header_block[":method"] = "POST";
-  header_block[":scheme"] = "https";
-  header_block[":authority"] = masque_client->authority_;
-  header_block[":path"] = "/.well-known/masque/init";
+    // Make sure to store the response, for later output.
+    masque_client->set_store_response(true);
 
-  // Make sure to store the response, for later output.
-  masque_client->set_store_response(true);
+    // Send the MASQUE init command.
+    masque_client->SendRequestAndWaitForResponse(header_block, body,
+                                                 /*fin=*/true);
 
-  // Send the MASQUE init command.
-  masque_client->SendRequestAndWaitForResponse(header_block, body,
-                                               /*fin=*/true);
+    if (!masque_client->connected()) {
+      QUIC_LOG(ERROR)
+          << "MASQUE init request caused connection failure. Error: "
+          << QuicErrorCodeToString(masque_client->session()->error());
+      return nullptr;
+    }
 
-  if (!masque_client->connected()) {
-    QUIC_LOG(ERROR) << "MASQUE init request caused connection failure. Error: "
-                    << QuicErrorCodeToString(masque_client->session()->error());
-    return nullptr;
-  }
-
-  const int response_code = masque_client->latest_response_code();
-  if (response_code != 200) {
-    QUIC_LOG(ERROR) << "MASQUE init request failed with HTTP response code "
-                    << response_code;
-    return nullptr;
+    const int response_code = masque_client->latest_response_code();
+    if (response_code != 200) {
+      QUIC_LOG(ERROR) << "MASQUE init request failed with HTTP response code "
+                      << response_code;
+      return nullptr;
+    }
   }
   return masque_client;
 }
diff --git a/quic/masque/masque_epoll_client.h b/quic/masque/masque_epoll_client.h
index 345c18c..1db161f 100644
--- a/quic/masque/masque_epoll_client.h
+++ b/quic/masque/masque_epoll_client.h
@@ -6,6 +6,7 @@
 #define QUICHE_QUIC_MASQUE_MASQUE_EPOLL_CLIENT_H_
 
 #include "quic/masque/masque_client_session.h"
+#include "quic/masque/masque_utils.h"
 #include "quic/platform/api/quic_export.h"
 #include "quic/tools/quic_client.h"
 
@@ -19,6 +20,7 @@
   static std::unique_ptr<MasqueEpollClient> Create(
       const std::string& host,
       int port,
+      MasqueMode masque_mode,
       QuicEpollServer* epoll_server,
       std::unique_ptr<ProofVerifier> proof_verifier);
 
@@ -37,10 +39,13 @@
   void UnregisterClientConnectionId(
       QuicConnectionId client_connection_id) override;
 
+  MasqueMode masque_mode() const { return masque_mode_; }
+
  private:
   // Constructor is private, use Create() instead.
   MasqueEpollClient(QuicSocketAddress server_address,
                     const QuicServerId& server_id,
+                    MasqueMode masque_mode,
                     QuicEpollServer* epoll_server,
                     std::unique_ptr<ProofVerifier> proof_verifier,
                     const std::string& authority);
@@ -49,6 +54,7 @@
   MasqueEpollClient(const MasqueEpollClient&) = delete;
   MasqueEpollClient& operator=(const MasqueEpollClient&) = delete;
 
+  MasqueMode masque_mode_;
   std::string authority_;
 };
 
diff --git a/quic/masque/masque_epoll_server.cc b/quic/masque/masque_epoll_server.cc
index 34769a3..e8a0850 100644
--- a/quic/masque/masque_epoll_server.cc
+++ b/quic/masque/masque_epoll_server.cc
@@ -11,15 +11,18 @@
 
 namespace quic {
 
-MasqueEpollServer::MasqueEpollServer(MasqueServerBackend* masque_server_backend)
+MasqueEpollServer::MasqueEpollServer(MasqueMode masque_mode,
+                                     MasqueServerBackend* masque_server_backend)
     : QuicServer(CreateDefaultProofSource(),
                  masque_server_backend,
                  MasqueSupportedVersions()),
+      masque_mode_(masque_mode),
       masque_server_backend_(masque_server_backend) {}
 
 QuicDispatcher* MasqueEpollServer::CreateQuicDispatcher() {
   return new MasqueDispatcher(
-      &config(), &crypto_config(), version_manager(),
+      masque_mode_, &config(), &crypto_config(), version_manager(),
+      epoll_server(),
       std::make_unique<QuicEpollConnectionHelper>(epoll_server(),
                                                   QuicAllocator::BUFFER_POOL),
       std::make_unique<QuicSimpleCryptoServerStreamHelper>(),
diff --git a/quic/masque/masque_epoll_server.h b/quic/masque/masque_epoll_server.h
index 57eab6a..88161ee 100644
--- a/quic/masque/masque_epoll_server.h
+++ b/quic/masque/masque_epoll_server.h
@@ -6,6 +6,7 @@
 #define QUICHE_QUIC_MASQUE_MASQUE_EPOLL_SERVER_H_
 
 #include "quic/masque/masque_server_backend.h"
+#include "quic/masque/masque_utils.h"
 #include "quic/platform/api/quic_export.h"
 #include "quic/tools/quic_server.h"
 
@@ -14,7 +15,8 @@
 // QUIC server that implements MASQUE.
 class QUIC_NO_EXPORT MasqueEpollServer : public QuicServer {
  public:
-  explicit MasqueEpollServer(MasqueServerBackend* masque_server_backend);
+  explicit MasqueEpollServer(MasqueMode masque_mode,
+                             MasqueServerBackend* masque_server_backend);
 
   // Disallow copy and assign.
   MasqueEpollServer(const MasqueEpollServer&) = delete;
@@ -24,6 +26,7 @@
   QuicDispatcher* CreateQuicDispatcher() override;
 
  private:
+  MasqueMode masque_mode_;
   MasqueServerBackend* masque_server_backend_;  // Unowned.
 };
 
diff --git a/quic/masque/masque_server_backend.cc b/quic/masque/masque_server_backend.cc
index 1994819..230f7e7 100644
--- a/quic/masque/masque_server_backend.cc
+++ b/quic/masque/masque_server_backend.cc
@@ -19,9 +19,10 @@
 
 }  // namespace
 
-MasqueServerBackend::MasqueServerBackend(const std::string& server_authority,
+MasqueServerBackend::MasqueServerBackend(MasqueMode masque_mode,
+                                         const std::string& server_authority,
                                          const std::string& cache_directory)
-    : server_authority_(server_authority) {
+    : masque_mode_(masque_mode), server_authority_(server_authority) {
   if (!cache_directory.empty()) {
     QuicMemoryCacheBackend::InitializeBackend(cache_directory);
   }
@@ -43,16 +44,24 @@
   absl::string_view path = path_pair->second;
   absl::string_view scheme = scheme_pair->second;
   absl::string_view method = method_pair->second;
-  if (scheme != "https" || method != "POST" || request_body.empty()) {
-    // MASQUE requests MUST be a non-empty https POST.
-    return false;
-  }
+  std::string masque_path = "";
+  if (masque_mode_ == MasqueMode::kLegacy) {
+    if (scheme != "https" || method != "POST" || request_body.empty()) {
+      // MASQUE requests MUST be a non-empty https POST.
+      return false;
+    }
 
-  if (path.rfind("/.well-known/masque/", 0) != 0) {
-    // This request is not a MASQUE path.
-    return false;
+    if (path.rfind("/.well-known/masque/", 0) != 0) {
+      // This request is not a MASQUE path.
+      return false;
+    }
+    masque_path = path.substr(sizeof("/.well-known/masque/") - 1);
+  } else {
+    if (method != "CONNECT-UDP") {
+      // Unexpected method.
+      return false;
+    }
   }
-  std::string masque_path(path.substr(sizeof("/.well-known/masque/") - 1));
 
   if (!server_authority_.empty()) {
     auto authority_pair = request_headers.find(":authority");
diff --git a/quic/masque/masque_server_backend.h b/quic/masque/masque_server_backend.h
index 1809359..86b1e00 100644
--- a/quic/masque/masque_server_backend.h
+++ b/quic/masque/masque_server_backend.h
@@ -6,6 +6,7 @@
 #define QUICHE_QUIC_MASQUE_MASQUE_SERVER_BACKEND_H_
 
 #include "absl/container/flat_hash_map.h"
+#include "quic/masque/masque_utils.h"
 #include "quic/platform/api/quic_export.h"
 #include "quic/tools/quic_memory_cache_backend.h"
 
@@ -27,7 +28,8 @@
     virtual ~BackendClient() = default;
   };
 
-  explicit MasqueServerBackend(const std::string& server_authority,
+  explicit MasqueServerBackend(MasqueMode masque_mode,
+                               const std::string& server_authority,
                                const std::string& cache_directory);
 
   // Disallow copy and assign.
@@ -57,6 +59,7 @@
       const std::string& request_body,
       QuicSimpleServerBackend::RequestHandler* request_handler);
 
+  MasqueMode masque_mode_;
   std::string server_authority_;
   absl::flat_hash_map<std::string, std::unique_ptr<QuicBackendResponse>>
       active_response_map_;
diff --git a/quic/masque/masque_server_bin.cc b/quic/masque/masque_server_bin.cc
index b25bf83..a6b6225 100644
--- a/quic/masque/masque_server_bin.cc
+++ b/quic/masque/masque_server_bin.cc
@@ -35,6 +35,12 @@
     "Specifies the authority over which the server will accept MASQUE "
     "requests. Defaults to empty which allows all authorities.");
 
+DEFINE_QUIC_COMMAND_LINE_FLAG(std::string,
+                              masque_mode,
+                              "",
+                              "Allows setting MASQUE mode, valid values are "
+                              "open and legacy. Defaults to open.");
+
 int main(int argc, char* argv[]) {
   const char* usage = "Usage: masque_server [options]";
   std::vector<std::string> non_option_args =
@@ -44,17 +50,28 @@
     return 0;
   }
 
-  auto backend = std::make_unique<quic::MasqueServerBackend>(
-      GetQuicFlag(FLAGS_server_authority), GetQuicFlag(FLAGS_cache_dir));
+  quic::MasqueMode masque_mode = quic::MasqueMode::kOpen;
+  std::string mode_string = GetQuicFlag(FLAGS_masque_mode);
+  if (mode_string == "legacy") {
+    masque_mode = quic::MasqueMode::kLegacy;
+  } else if (!mode_string.empty() && mode_string != "open") {
+    std::cerr << "Invalid masque_mode \"" << mode_string << "\"" << std::endl;
+    return 1;
+  }
 
-  auto server = std::make_unique<quic::MasqueEpollServer>(backend.get());
+  auto backend = std::make_unique<quic::MasqueServerBackend>(
+      masque_mode, GetQuicFlag(FLAGS_server_authority),
+      GetQuicFlag(FLAGS_cache_dir));
+
+  auto server =
+      std::make_unique<quic::MasqueEpollServer>(masque_mode, backend.get());
 
   if (!server->CreateUDPSocketAndListen(quic::QuicSocketAddress(
           quic::QuicIpAddress::Any6(), GetQuicFlag(FLAGS_port)))) {
     return 1;
   }
 
-  std::cerr << "Started MASQUE server" << std::endl;
+  std::cerr << "Started " << masque_mode << " MASQUE server" << std::endl;
   server->HandleEventsForever();
   return 0;
 }
diff --git a/quic/masque/masque_server_session.cc b/quic/masque/masque_server_session.cc
index 228444a..de231ba 100644
--- a/quic/masque/masque_server_session.cc
+++ b/quic/masque/masque_server_session.cc
@@ -4,14 +4,79 @@
 
 #include "quic/masque/masque_server_session.h"
 
+#include <netdb.h>
+
+#include "absl/strings/str_cat.h"
+#include "quic/core/quic_data_reader.h"
+#include "quic/core/quic_udp_socket.h"
+#include "quic/tools/quic_url.h"
+#include "common/platform/api/quiche_text_utils.h"
+
 namespace quic {
 
+namespace {
+// RAII wrapper for QuicUdpSocketFd.
+class FdWrapper {
+ public:
+  // Takes ownership of |fd| and closes the file descriptor on destruction.
+  explicit FdWrapper(int address_family) {
+    QuicUdpSocketApi socket_api;
+    fd_ =
+        socket_api.Create(address_family,
+                          /*receive_buffer_size =*/kDefaultSocketReceiveBuffer,
+                          /*send_buffer_size =*/kDefaultSocketReceiveBuffer);
+  }
+
+  ~FdWrapper() {
+    if (fd_ == kQuicInvalidSocketFd) {
+      return;
+    }
+    QuicUdpSocketApi socket_api;
+    socket_api.Destroy(fd_);
+  }
+
+  // Hands ownership of the file descriptor to the caller.
+  QuicUdpSocketFd extract_fd() {
+    QuicUdpSocketFd fd = fd_;
+    fd_ = kQuicInvalidSocketFd;
+    return fd;
+  }
+
+  // Keeps ownership of the file descriptor.
+  QuicUdpSocketFd fd() { return fd_; }
+
+  // Disallow copy and move.
+  FdWrapper(const FdWrapper&) = delete;
+  FdWrapper(FdWrapper&&) = delete;
+  FdWrapper& operator=(const FdWrapper&) = delete;
+  FdWrapper& operator=(FdWrapper&&) = delete;
+
+ private:
+  QuicUdpSocketFd fd_;
+};
+
+std::unique_ptr<QuicBackendResponse> CreateBackendErrorResponse(
+    absl::string_view status,
+    absl::string_view body) {
+  spdy::Http2HeaderBlock response_headers;
+  response_headers[":status"] = status;
+  auto response = std::make_unique<QuicBackendResponse>();
+  response->set_response_type(QuicBackendResponse::REGULAR_RESPONSE);
+  response->set_headers(std::move(response_headers));
+  response->set_body(body);
+  return response;
+}
+
+}  // namespace
+
 MasqueServerSession::MasqueServerSession(
+    MasqueMode masque_mode,
     const QuicConfig& config,
     const ParsedQuicVersionVector& supported_versions,
     QuicConnection* connection,
     QuicSession::Visitor* visitor,
     Visitor* owner,
+    QuicEpollServer* epoll_server,
     QuicCryptoServerStreamBase::Helper* helper,
     const QuicCryptoServerConfig* crypto_config,
     QuicCompressedCertsCache* compressed_certs_cache,
@@ -26,41 +91,74 @@
                               masque_server_backend),
       masque_server_backend_(masque_server_backend),
       owner_(owner),
-      compression_engine_(this) {
+      epoll_server_(epoll_server),
+      compression_engine_(this),
+      masque_mode_(masque_mode) {
   masque_server_backend_->RegisterBackendClient(connection_id(), this);
 }
 
 void MasqueServerSession::OnMessageReceived(absl::string_view message) {
   QUIC_DVLOG(1) << "Received DATAGRAM frame of length " << message.length();
+  if (masque_mode_ == MasqueMode::kLegacy) {
+    QuicConnectionId client_connection_id, server_connection_id;
+    QuicSocketAddress target_server_address;
+    std::vector<char> packet;
+    bool version_present;
+    if (!compression_engine_.DecompressDatagram(
+            message, &client_connection_id, &server_connection_id,
+            &target_server_address, &packet, &version_present)) {
+      return;
+    }
 
-  QuicConnectionId client_connection_id, server_connection_id;
-  QuicSocketAddress server_address;
-  std::vector<char> packet;
-  bool version_present;
-  if (!compression_engine_.DecompressDatagram(
-          message, &client_connection_id, &server_connection_id,
-          &server_address, &packet, &version_present)) {
+    QUIC_DVLOG(1) << "Received packet of length " << packet.size() << " for "
+                  << target_server_address << " client "
+                  << client_connection_id;
+
+    if (version_present) {
+      if (client_connection_id.length() != kQuicDefaultConnectionIdLength) {
+        QUIC_DLOG(ERROR)
+            << "Dropping long header with invalid client_connection_id "
+            << client_connection_id;
+        return;
+      }
+      owner_->RegisterClientConnectionId(client_connection_id, this);
+    }
+
+    WriteResult write_result = connection()->writer()->WritePacket(
+        packet.data(), packet.size(), connection()->self_address().host(),
+        target_server_address, nullptr);
+    QUIC_DVLOG(1) << "Got " << write_result << " for " << packet.size()
+                  << " bytes to " << target_server_address;
+    return;
+  }
+  QUICHE_DCHECK_EQ(masque_mode_, MasqueMode::kOpen);
+  QuicDataReader reader(message);
+  QuicDatagramFlowId flow_id;
+  if (!reader.ReadVarInt62(&flow_id)) {
+    QUIC_DLOG(ERROR) << "Failed to read flow_id";
     return;
   }
 
-  QUIC_DVLOG(1) << "Received packet of length " << packet.size() << " for "
-                << server_address << " client " << client_connection_id;
-
-  if (version_present) {
-    if (client_connection_id.length() != kQuicDefaultConnectionIdLength) {
-      QUIC_DLOG(ERROR)
-          << "Dropping long header with invalid client_connection_id "
-          << client_connection_id;
-      return;
-    }
-    owner_->RegisterClientConnectionId(client_connection_id, this);
+  auto it =
+      absl::c_find_if(connect_udp_server_states_,
+                      [flow_id](const ConnectUdpServerState& connect_udp) {
+                        return connect_udp.flow_id() == flow_id;
+                      });
+  if (it == connect_udp_server_states_.end()) {
+    QUIC_DLOG(ERROR) << "Received unknown flow_id " << flow_id;
+    return;
   }
-
-  WriteResult write_result = connection()->writer()->WritePacket(
-      packet.data(), packet.size(), connection()->self_address().host(),
-      server_address, nullptr);
-  QUIC_DVLOG(1) << "Got " << write_result << " for " << packet.size()
-                << " bytes to " << server_address;
+  QuicSocketAddress target_server_address = it->target_server_address();
+  QUICHE_DCHECK(target_server_address.IsInitialized());
+  QuicUdpSocketFd fd = it->fd();
+  QUICHE_DCHECK_NE(fd, kQuicInvalidSocketFd);
+  absl::string_view packet = reader.ReadRemainingPayload();
+  QuicUdpSocketApi socket_api;
+  QuicUdpPacketInfo packet_info;
+  packet_info.SetPeerAddress(target_server_address);
+  WriteResult write_result =
+      socket_api.WritePacket(fd, packet.data(), packet.length(), packet_info);
+  QUIC_DVLOG(1) << "Wrote packet to server with result " << write_result;
 }
 
 void MasqueServerSession::OnMessageAcked(QuicMessageId message_id,
@@ -73,17 +171,123 @@
 }
 
 void MasqueServerSession::OnConnectionClosed(
-    const QuicConnectionCloseFrame& /*frame*/,
-    ConnectionCloseSource /*source*/) {
+    const QuicConnectionCloseFrame& frame,
+    ConnectionCloseSource source) {
+  QuicSimpleServerSession::OnConnectionClosed(frame, source);
   QUIC_DLOG(INFO) << "Closing connection for " << connection_id();
   masque_server_backend_->RemoveBackendClient(connection_id());
+  // Clearing this state will close all sockets.
+  connect_udp_server_states_.clear();
+}
+
+void MasqueServerSession::OnStreamClosed(QuicStreamId stream_id) {
+  connect_udp_server_states_.remove_if(
+      [stream_id](const ConnectUdpServerState& connect_udp) {
+        return connect_udp.stream_id() == stream_id;
+      });
+
+  QuicSimpleServerSession::OnStreamClosed(stream_id);
 }
 
 std::unique_ptr<QuicBackendResponse> MasqueServerSession::HandleMasqueRequest(
     const std::string& masque_path,
-    const spdy::Http2HeaderBlock& /*request_headers*/,
+    const spdy::Http2HeaderBlock& request_headers,
     const std::string& request_body,
-    QuicSimpleServerBackend::RequestHandler* /*request_handler*/) {
+    QuicSimpleServerBackend::RequestHandler* request_handler) {
+  if (masque_mode_ != MasqueMode::kLegacy) {
+    auto path_pair = request_headers.find(":path");
+    auto scheme_pair = request_headers.find(":scheme");
+    auto method_pair = request_headers.find(":method");
+    auto flow_id_pair = request_headers.find("datagram-flow-id");
+    auto authority_pair = request_headers.find(":authority");
+    if (path_pair == request_headers.end() ||
+        scheme_pair == request_headers.end() ||
+        method_pair == request_headers.end() ||
+        flow_id_pair == request_headers.end() ||
+        authority_pair == request_headers.end()) {
+      QUIC_DLOG(ERROR) << "MASQUE request is missing required headers";
+      return CreateBackendErrorResponse("400", "Missing required headers");
+    }
+    absl::string_view path = path_pair->second;
+    absl::string_view scheme = scheme_pair->second;
+    absl::string_view method = method_pair->second;
+    absl::string_view flow_id_str = flow_id_pair->second;
+    absl::string_view authority = authority_pair->second;
+    if (path.empty()) {
+      QUIC_DLOG(ERROR) << "MASQUE request with empty path";
+      return CreateBackendErrorResponse("400", "Empty path");
+    }
+    if (scheme.empty()) {
+      return CreateBackendErrorResponse("400", "Empty scheme");
+      return nullptr;
+    }
+    if (method != "CONNECT-UDP") {
+      QUIC_DLOG(ERROR) << "MASQUE request with bad method \"" << method << "\"";
+      return CreateBackendErrorResponse("400", "Bad method");
+    }
+    QuicDatagramFlowId flow_id;
+    if (!absl::SimpleAtoi(flow_id_str, &flow_id)) {
+      QUIC_DLOG(ERROR) << "MASQUE request with bad flow_id \"" << flow_id_str
+                       << "\"";
+      return CreateBackendErrorResponse("400", "Bad flow ID");
+    }
+    QuicUrl url(absl::StrCat("https://", authority));
+    if (!url.IsValid() || url.PathParamsQuery() != "/") {
+      QUIC_DLOG(ERROR) << "MASQUE request with bad authority \"" << authority
+                       << "\"";
+      return CreateBackendErrorResponse("400", "Bad authority");
+    }
+
+    std::string port = absl::StrCat(url.port());
+    addrinfo hint = {};
+    hint.ai_protocol = IPPROTO_UDP;
+
+    addrinfo* info_list = nullptr;
+    int result =
+        getaddrinfo(url.host().c_str(), port.c_str(), &hint, &info_list);
+    if (result != 0) {
+      QUIC_DLOG(ERROR) << "Failed to resolve " << authority << ": "
+                       << gai_strerror(result);
+      return CreateBackendErrorResponse("500", "DNS resolution failed");
+    }
+
+    QUICHE_CHECK_NE(info_list, nullptr);
+    std::unique_ptr<addrinfo, void (*)(addrinfo*)> info_list_owned(
+        info_list, freeaddrinfo);
+    QuicSocketAddress target_server_address(info_list->ai_addr,
+                                            info_list->ai_addrlen);
+    QUIC_DLOG(INFO) << "Got CONNECT_UDP request flow_id=" << flow_id
+                    << " target_server_address=\"" << target_server_address
+                    << "\"";
+
+    FdWrapper fd_wrapper(target_server_address.host().AddressFamilyToInt());
+    if (fd_wrapper.fd() == kQuicInvalidSocketFd) {
+      QUIC_DLOG(ERROR) << "Socket creation failed";
+      return CreateBackendErrorResponse("500", "Socket creation failed");
+    }
+    QuicSocketAddress any_v6_address(QuicIpAddress::Any6(), 0);
+    QuicUdpSocketApi socket_api;
+    if (!socket_api.Bind(fd_wrapper.fd(), any_v6_address)) {
+      QUIC_DLOG(ERROR) << "Socket bind failed";
+      return CreateBackendErrorResponse("500", "Socket bind failed");
+    }
+    epoll_server_->RegisterFDForRead(fd_wrapper.fd(), this);
+
+    connect_udp_server_states_.emplace_back(ConnectUdpServerState(
+        flow_id, request_handler->stream_id(), target_server_address,
+        fd_wrapper.extract_fd(), epoll_server_));
+
+    spdy::Http2HeaderBlock response_headers;
+    response_headers[":status"] = "200";
+    response_headers["datagram-flow-id"] = absl::StrCat(flow_id);
+    auto response = std::make_unique<QuicBackendResponse>();
+    response->set_response_type(QuicBackendResponse::INCOMPLETE_RESPONSE);
+    response->set_headers(std::move(response_headers));
+    response->set_body("");
+
+    return response;
+  }
+
   QUIC_DLOG(INFO) << "MasqueServerSession handling MASQUE request";
 
   if (masque_path == "init") {
@@ -120,9 +324,166 @@
 void MasqueServerSession::HandlePacketFromServer(
     const ReceivedPacketInfo& packet_info) {
   QUIC_DVLOG(1) << "MasqueServerSession received " << packet_info;
-  compression_engine_.CompressAndSendPacket(
-      packet_info.packet.AsStringPiece(), packet_info.destination_connection_id,
-      packet_info.source_connection_id, packet_info.peer_address);
+  if (masque_mode_ == MasqueMode::kLegacy) {
+    compression_engine_.CompressAndSendPacket(
+        packet_info.packet.AsStringPiece(),
+        packet_info.destination_connection_id, packet_info.source_connection_id,
+        packet_info.peer_address);
+    return;
+  }
+  QUIC_LOG(ERROR) << "Ignoring packet from server in " << masque_mode_
+                  << " mode";
+}
+
+void MasqueServerSession::OnRegistration(QuicEpollServer* /*eps*/,
+                                         QuicUdpSocketFd fd,
+                                         int event_mask) {
+  QUIC_DVLOG(1) << "OnRegistration " << fd << " event_mask " << event_mask;
+}
+
+void MasqueServerSession::OnModification(QuicUdpSocketFd fd, int event_mask) {
+  QUIC_DVLOG(1) << "OnModification " << fd << " event_mask " << event_mask;
+}
+
+void MasqueServerSession::OnEvent(QuicUdpSocketFd fd, QuicEpollEvent* event) {
+  if ((event->in_events & EPOLLIN) == 0) {
+    QUIC_DVLOG(1) << "Ignoring OnEvent fd " << fd << " event mask "
+                  << event->in_events;
+    return;
+  }
+  auto it = absl::c_find_if(connect_udp_server_states_,
+                            [fd](const ConnectUdpServerState& connect_udp) {
+                              return connect_udp.fd() == fd;
+                            });
+  if (it == connect_udp_server_states_.end()) {
+    QUIC_BUG << "Got unexpected event mask " << event->in_events
+             << " on unknown fd " << fd;
+    return;
+  }
+  QuicDatagramFlowId flow_id = it->flow_id();
+  QuicSocketAddress expected_target_server_address =
+      it->target_server_address();
+  QUICHE_DCHECK(expected_target_server_address.IsInitialized());
+  QUIC_DVLOG(1) << "Received readable event on fd " << fd << " (mask "
+                << event->in_events << ") flow_id " << flow_id << " server "
+                << expected_target_server_address;
+  QuicUdpSocketApi socket_api;
+  BitMask64 packet_info_interested(QuicUdpPacketInfoBit::PEER_ADDRESS);
+  char packet_buffer[kMaxIncomingPacketSize];
+  char control_buffer[kDefaultUdpPacketControlBufferSize];
+  while (true) {
+    QuicUdpSocketApi::ReadPacketResult read_result;
+    read_result.packet_buffer = {packet_buffer, sizeof(packet_buffer)};
+    read_result.control_buffer = {control_buffer, sizeof(control_buffer)};
+    socket_api.ReadPacket(fd, packet_info_interested, &read_result);
+    if (!read_result.ok) {
+      // Most likely there is nothing left to read, break out of read loop.
+      break;
+    }
+    if (!read_result.packet_info.HasValue(QuicUdpPacketInfoBit::PEER_ADDRESS)) {
+      QUIC_BUG << "Missing peer address when reading from fd " << fd;
+      continue;
+    }
+    if (read_result.packet_info.peer_address() !=
+        expected_target_server_address) {
+      QUIC_DLOG(ERROR) << "Ignoring UDP packet on fd " << fd
+                       << " from unexpected server address "
+                       << read_result.packet_info.peer_address()
+                       << " (expected " << expected_target_server_address
+                       << ")";
+      continue;
+    }
+    if (!connection()->connected()) {
+      QUIC_BUG << "Unexpected incoming UDP packet on fd " << fd << " from "
+               << expected_target_server_address
+               << " because MASQUE connection is closed";
+      return;
+    }
+    // The packet is valid, send it to the client in a DATAGRAM frame.
+    size_t slice_length = QuicDataWriter::GetVarInt62Len(flow_id) +
+                          read_result.packet_buffer.buffer_len;
+    QuicUniqueBufferPtr buffer = MakeUniqueBuffer(
+        connection()->helper()->GetStreamSendBufferAllocator(), slice_length);
+    QuicDataWriter writer(slice_length, buffer.get());
+    if (!writer.WriteVarInt62(flow_id)) {
+      QUIC_BUG << "Failed to write flow_id";
+      continue;
+    }
+    if (!writer.WriteBytes(read_result.packet_buffer.buffer,
+                           read_result.packet_buffer.buffer_len)) {
+      QUIC_BUG << "Failed to write packet";
+      continue;
+    }
+    QUICHE_DCHECK_EQ(writer.remaining(), 0u);
+    QuicMemSlice slice(std::move(buffer), slice_length);
+    MessageResult message_result = SendMessage(QuicMemSliceSpan(&slice));
+    QUIC_DVLOG(1) << "Sent UDP packet from target server of length "
+                  << read_result.packet_buffer.buffer_len << " with flow ID "
+                  << flow_id << " and got message result " << message_result;
+  }
+}
+
+void MasqueServerSession::OnUnregistration(QuicUdpSocketFd fd, bool replaced) {
+  QUIC_DVLOG(1) << "OnUnregistration " << fd << " " << (replaced ? "" : "!")
+                << " replaced";
+}
+
+void MasqueServerSession::OnShutdown(QuicEpollServer* /*eps*/,
+                                     QuicUdpSocketFd fd) {
+  QUIC_DVLOG(1) << "OnShutdown " << fd;
+}
+
+std::string MasqueServerSession::Name() const {
+  return std::string("MasqueServerSession-") + connection_id().ToString();
+}
+
+MasqueServerSession::ConnectUdpServerState::ConnectUdpServerState(
+    QuicDatagramFlowId flow_id,
+    QuicStreamId stream_id,
+    const QuicSocketAddress& target_server_address,
+    QuicUdpSocketFd fd,
+    QuicEpollServer* epoll_server)
+    : flow_id_(flow_id),
+      stream_id_(stream_id),
+      target_server_address_(target_server_address),
+      fd_(fd),
+      epoll_server_(epoll_server) {
+  QUICHE_DCHECK_NE(fd_, kQuicInvalidSocketFd);
+  QUICHE_DCHECK_NE(epoll_server_, nullptr);
+}
+
+MasqueServerSession::ConnectUdpServerState::~ConnectUdpServerState() {
+  if (fd_ == kQuicInvalidSocketFd) {
+    return;
+  }
+  QuicUdpSocketApi socket_api;
+  QUIC_DLOG(INFO) << "Closing fd " << fd_;
+  epoll_server_->UnregisterFD(fd_);
+  socket_api.Destroy(fd_);
+}
+
+MasqueServerSession::ConnectUdpServerState::ConnectUdpServerState(
+    MasqueServerSession::ConnectUdpServerState&& other) {
+  fd_ = kQuicInvalidSocketFd;
+  *this = std::move(other);
+}
+
+MasqueServerSession::ConnectUdpServerState&
+MasqueServerSession::ConnectUdpServerState::operator=(
+    MasqueServerSession::ConnectUdpServerState&& other) {
+  if (fd_ != kQuicInvalidSocketFd) {
+    QuicUdpSocketApi socket_api;
+    QUIC_DLOG(INFO) << "Closing fd " << fd_;
+    epoll_server_->UnregisterFD(fd_);
+    socket_api.Destroy(fd_);
+  }
+  flow_id_ = other.flow_id_;
+  stream_id_ = other.stream_id_;
+  target_server_address_ = other.target_server_address_;
+  fd_ = other.fd_;
+  epoll_server_ = other.epoll_server_;
+  other.fd_ = kQuicInvalidSocketFd;
+  return *this;
 }
 
 }  // namespace quic
diff --git a/quic/masque/masque_server_session.h b/quic/masque/masque_server_session.h
index 9e43b63..a73e641 100644
--- a/quic/masque/masque_server_session.h
+++ b/quic/masque/masque_server_session.h
@@ -5,8 +5,12 @@
 #ifndef QUICHE_QUIC_MASQUE_MASQUE_SERVER_SESSION_H_
 #define QUICHE_QUIC_MASQUE_MASQUE_SERVER_SESSION_H_
 
+#include "quic/core/quic_types.h"
+#include "quic/core/quic_udp_socket.h"
 #include "quic/masque/masque_compression_engine.h"
 #include "quic/masque/masque_server_backend.h"
+#include "quic/masque/masque_utils.h"
+#include "quic/platform/api/quic_epoll.h"
 #include "quic/platform/api/quic_export.h"
 #include "quic/tools/quic_simple_server_session.h"
 
@@ -15,7 +19,8 @@
 // QUIC server session for connection to MASQUE proxy.
 class QUIC_NO_EXPORT MasqueServerSession
     : public QuicSimpleServerSession,
-      public MasqueServerBackend::BackendClient {
+      public MasqueServerBackend::BackendClient,
+      public QuicEpollCallbackInterface {
  public:
   // Interface meant to be implemented by owner of this MasqueServerSession
   // instance.
@@ -33,11 +38,13 @@
   };
 
   explicit MasqueServerSession(
+      MasqueMode masque_mode,
       const QuicConfig& config,
       const ParsedQuicVersionVector& supported_versions,
       QuicConnection* connection,
       QuicSession::Visitor* visitor,
       Visitor* owner,
+      QuicEpollServer* epoll_server,
       QuicCryptoServerStreamBase::Helper* helper,
       const QuicCryptoServerConfig* crypto_config,
       QuicCompressedCertsCache* compressed_certs_cache,
@@ -54,6 +61,7 @@
   void OnMessageLost(QuicMessageId message_id) override;
   void OnConnectionClosed(const QuicConnectionCloseFrame& frame,
                           ConnectionCloseSource source) override;
+  void OnStreamClosed(QuicStreamId stream_id) override;
 
   // From MasqueServerBackend::BackendClient.
   std::unique_ptr<QuicBackendResponse> HandleMasqueRequest(
@@ -62,13 +70,61 @@
       const std::string& request_body,
       QuicSimpleServerBackend::RequestHandler* request_handler) override;
 
+  // From QuicEpollCallbackInterface.
+  void OnRegistration(QuicEpollServer* eps,
+                      QuicUdpSocketFd fd,
+                      int event_mask) override;
+  void OnModification(QuicUdpSocketFd fd, int event_mask) override;
+  void OnEvent(QuicUdpSocketFd fd, QuicEpollEvent* event) override;
+  void OnUnregistration(QuicUdpSocketFd fd, bool replaced) override;
+  void OnShutdown(QuicEpollServer* eps, QuicUdpSocketFd fd) override;
+  std::string Name() const override;
+
   // Handle packet for client, meant to be called by MasqueDispatcher.
   void HandlePacketFromServer(const ReceivedPacketInfo& packet_info);
 
  private:
+  // State that the MasqueServerSession keeps for each CONNECT-UDP request.
+  class QUIC_NO_EXPORT ConnectUdpServerState {
+   public:
+    // ConnectUdpServerState takes ownership of |fd|. It will unregister it
+    // from |epoll_server| and close the file descriptor when destructed.
+    explicit ConnectUdpServerState(
+        QuicDatagramFlowId flow_id,
+        QuicStreamId stream_id,
+        const QuicSocketAddress& target_server_address,
+        QuicUdpSocketFd fd,
+        QuicEpollServer* epoll_server);
+
+    ~ConnectUdpServerState();
+
+    // Disallow copy but allow move.
+    ConnectUdpServerState(const ConnectUdpServerState&) = delete;
+    ConnectUdpServerState(ConnectUdpServerState&&);
+    ConnectUdpServerState& operator=(const ConnectUdpServerState&) = delete;
+    ConnectUdpServerState& operator=(ConnectUdpServerState&&);
+
+    QuicDatagramFlowId flow_id() const { return flow_id_; }
+    QuicStreamId stream_id() const { return stream_id_; }
+    const QuicSocketAddress& target_server_address() const {
+      return target_server_address_;
+    }
+    QuicUdpSocketFd fd() const { return fd_; }
+
+   private:
+    QuicDatagramFlowId flow_id_;
+    QuicStreamId stream_id_;
+    QuicSocketAddress target_server_address_;
+    QuicUdpSocketFd fd_;             // Owned.
+    QuicEpollServer* epoll_server_;  // Unowned.
+  };
+
   MasqueServerBackend* masque_server_backend_;  // Unowned.
   Visitor* owner_;                              // Unowned.
+  QuicEpollServer* epoll_server_;               // Unowned.
   MasqueCompressionEngine compression_engine_;
+  MasqueMode masque_mode_;
+  std::list<ConnectUdpServerState> connect_udp_server_states_;
   bool masque_initialized_ = false;
 };
 
diff --git a/quic/masque/masque_utils.cc b/quic/masque/masque_utils.cc
index f4f77f5..17dec0d 100644
--- a/quic/masque/masque_utils.cc
+++ b/quic/masque/masque_utils.cc
@@ -27,4 +27,21 @@
   return config;
 }
 
+std::string MasqueModeToString(MasqueMode masque_mode) {
+  switch (masque_mode) {
+    case MasqueMode::kInvalid:
+      return "Invalid";
+    case MasqueMode::kLegacy:
+      return "Legacy";
+    case MasqueMode::kOpen:
+      return "Open";
+  }
+  return absl::StrCat("Unknown(", static_cast<int>(masque_mode), ")");
+}
+
+std::ostream& operator<<(std::ostream& os, const MasqueMode& masque_mode) {
+  os << MasqueModeToString(masque_mode);
+  return os;
+}
+
 }  // namespace quic
diff --git a/quic/masque/masque_utils.h b/quic/masque/masque_utils.h
index f151f55..8113047 100644
--- a/quic/masque/masque_utils.h
+++ b/quic/masque/masque_utils.h
@@ -18,7 +18,24 @@
 QUIC_NO_EXPORT QuicConfig MasqueEncapsulatedConfig();
 
 // Maximum packet size for encapsulated connections.
-const QuicByteCount kMasqueMaxEncapsulatedPacketSize = 1300;
+enum : QuicByteCount { kMasqueMaxEncapsulatedPacketSize = 1300 };
+
+// Mode that MASQUE is operating in.
+enum class MasqueMode : uint8_t {
+  kInvalid = 0,  // Should never be used.
+  kLegacy = 1,   // Legacy mode uses the legacy MASQUE protocol as documented in
+  // <https://tools.ietf.org/html/draft-schinazi-masque-protocol>. That version
+  // of MASQUE uses a custom application-protocol over HTTP/3, and also allows
+  // unauthenticated clients.
+  kOpen = 2,  // Open mode uses the MASQUE HTTP CONNECT-UDP method as documented
+  // in <https://tools.ietf.org/html/draft-ietf-masque-connect-udp>. This mode
+  // allows unauthenticated clients (a more restricted mode will be added to
+  // this enum at a later date).
+};
+
+QUIC_NO_EXPORT std::string MasqueModeToString(MasqueMode masque_mode);
+QUIC_NO_EXPORT std::ostream& operator<<(std::ostream& os,
+                                        const MasqueMode& masque_mode);
 
 }  // namespace quic