Implement MASQUE CONNECT-UDP
Previously, the MASQUE code was supporting the legacy MASQUE protocol. This CL adds support for CONNECT-UDP, the new MASQUE wire format. This is test code currently meant for IETF interop events. We're mainly landing it to prevent bitrot. This code isn't used in production.
PiperOrigin-RevId: 359575658
Change-Id: Ia883662fb6ffe7827c15b1460f2b316494b7ffbd
diff --git a/quic/masque/masque_client_bin.cc b/quic/masque/masque_client_bin.cc
index 8e09a12..afd991a 100644
--- a/quic/masque/masque_client_bin.cc
+++ b/quic/masque/masque_client_bin.cc
@@ -29,6 +29,12 @@
false,
"If true, don't verify the server certificate.");
+DEFINE_QUIC_COMMAND_LINE_FLAG(std::string,
+ masque_mode,
+ "",
+ "Allows setting MASQUE mode, valid values are "
+ "open and legacy. Defaults to open.");
+
namespace quic {
namespace {
@@ -54,7 +60,8 @@
masque_url = QuicUrl(absl::StrCat("https://", urls[0]), "https");
}
if (masque_url.host().empty()) {
- QUIC_LOG(ERROR) << "Failed to parse MASQUE server address " << urls[0];
+ std::cerr << "Failed to parse MASQUE server address \"" << urls[0] << "\""
+ << std::endl;
return 1;
}
std::unique_ptr<ProofVerifier> proof_verifier;
@@ -63,15 +70,23 @@
} else {
proof_verifier = CreateDefaultProofVerifier(masque_url.host());
}
- std::unique_ptr<MasqueEpollClient> masque_client =
- MasqueEpollClient::Create(masque_url.host(), masque_url.port(),
- &epoll_server, std::move(proof_verifier));
+ MasqueMode masque_mode = MasqueMode::kOpen;
+ std::string mode_string = GetQuicFlag(FLAGS_masque_mode);
+ if (mode_string == "legacy") {
+ masque_mode = MasqueMode::kLegacy;
+ } else if (!mode_string.empty() && mode_string != "open") {
+ std::cerr << "Invalid masque_mode \"" << mode_string << "\"" << std::endl;
+ return 1;
+ }
+ std::unique_ptr<MasqueEpollClient> masque_client = MasqueEpollClient::Create(
+ masque_url.host(), masque_url.port(), masque_mode, &epoll_server,
+ std::move(proof_verifier));
if (masque_client == nullptr) {
return 1;
}
std::cerr << "MASQUE is connected " << masque_client->connection_id()
- << std::endl;
+ << " in " << masque_mode << " mode" << std::endl;
for (size_t i = 1; i < urls.size(); ++i) {
if (!tools::SendEncapsulatedMasqueRequest(
diff --git a/quic/masque/masque_client_session.cc b/quic/masque/masque_client_session.cc
index eeedccf..5663c4a 100644
--- a/quic/masque/masque_client_session.cc
+++ b/quic/masque/masque_client_session.cc
@@ -3,10 +3,14 @@
// found in the LICENSE file.
#include "quic/masque/masque_client_session.h"
+#include "absl/algorithm/container.h"
+#include "quic/core/quic_data_reader.h"
+#include "common/platform/api/quiche_text_utils.h"
namespace quic {
MasqueClientSession::MasqueClientSession(
+ MasqueMode masque_mode,
const QuicConfig& config,
const ParsedQuicVersionVector& supported_versions,
QuicConnection* connection,
@@ -20,36 +24,65 @@
server_id,
crypto_config,
push_promise_index),
+ masque_mode_(masque_mode),
owner_(owner),
compression_engine_(this) {}
void MasqueClientSession::OnMessageReceived(absl::string_view message) {
QUIC_DVLOG(1) << "Received DATAGRAM frame of length " << message.length();
+ if (masque_mode_ == MasqueMode::kLegacy) {
+ QuicConnectionId client_connection_id, server_connection_id;
+ QuicSocketAddress target_server_address;
+ std::vector<char> packet;
+ bool version_present;
+ if (!compression_engine_.DecompressDatagram(
+ message, &client_connection_id, &server_connection_id,
+ &target_server_address, &packet, &version_present)) {
+ return;
+ }
- QuicConnectionId client_connection_id, server_connection_id;
- QuicSocketAddress server_address;
- std::vector<char> packet;
- bool version_present;
- if (!compression_engine_.DecompressDatagram(
- message, &client_connection_id, &server_connection_id,
- &server_address, &packet, &version_present)) {
+ auto connection_id_registration =
+ client_connection_id_registrations_.find(client_connection_id);
+ if (connection_id_registration ==
+ client_connection_id_registrations_.end()) {
+ QUIC_DLOG(ERROR) << "MasqueClientSession failed to dispatch "
+ << client_connection_id;
+ return;
+ }
+ EncapsulatedClientSession* encapsulated_client_session =
+ connection_id_registration->second;
+ encapsulated_client_session->ProcessPacket(
+ absl::string_view(packet.data(), packet.size()), target_server_address);
+
+ QUIC_DVLOG(1) << "Sent " << packet.size() << " bytes to connection for "
+ << client_connection_id;
return;
}
-
- auto connection_id_registration =
- client_connection_id_registrations_.find(client_connection_id);
- if (connection_id_registration == client_connection_id_registrations_.end()) {
- QUIC_DLOG(ERROR) << "MasqueClientSession failed to dispatch "
- << client_connection_id;
+ QuicDataReader reader(message);
+ QuicDatagramFlowId flow_id;
+ if (!reader.ReadVarInt62(&flow_id)) {
+ QUIC_DLOG(ERROR) << "Failed to parse flow_id";
+ return;
+ }
+ auto it =
+ absl::c_find_if(connect_udp_client_states_,
+ [flow_id](const ConnectUdpClientState& connect_udp) {
+ return connect_udp.flow_id() == flow_id;
+ });
+ if (it == connect_udp_client_states_.end()) {
+ QUIC_DLOG(ERROR) << "Received unknown flow_id " << flow_id;
return;
}
EncapsulatedClientSession* encapsulated_client_session =
- connection_id_registration->second;
- encapsulated_client_session->ProcessPacket(
- absl::string_view(packet.data(), packet.size()), server_address);
+ it->encapsulated_client_session();
+ QuicSocketAddress target_server_address = it->target_server_address();
+ QUICHE_DCHECK_NE(encapsulated_client_session, nullptr);
+ QUICHE_DCHECK(target_server_address.IsInitialized());
+ absl::string_view packet = reader.ReadRemainingPayload();
+ encapsulated_client_session->ProcessPacket(packet, target_server_address);
- QUIC_DVLOG(1) << "Sent " << packet.size() << " bytes to connection for "
- << client_connection_id;
+ QUIC_DVLOG(1) << "Sent " << packet.size()
+ << " bytes to connection for flow_id " << flow_id;
}
void MasqueClientSession::OnMessageAcked(QuicMessageId message_id,
@@ -61,12 +94,87 @@
QUIC_DVLOG(1) << "We believe DATAGRAM frame " << message_id << " was lost";
}
-void MasqueClientSession::SendPacket(QuicConnectionId client_connection_id,
- QuicConnectionId server_connection_id,
- absl::string_view packet,
- const QuicSocketAddress& server_address) {
- compression_engine_.CompressAndSendPacket(
- packet, client_connection_id, server_connection_id, server_address);
+const MasqueClientSession::ConnectUdpClientState*
+MasqueClientSession::GetOrCreateConnectUdpClientState(
+ const QuicSocketAddress& target_server_address,
+ EncapsulatedClientSession* encapsulated_client_session) {
+ for (const ConnectUdpClientState& client_state : connect_udp_client_states_) {
+ if (client_state.target_server_address() == target_server_address &&
+ client_state.encapsulated_client_session() ==
+ encapsulated_client_session) {
+ // Found existing CONNECT-UDP request.
+ return &client_state;
+ }
+ }
+ // No CONNECT-UDP request found, create a new one.
+ QuicSpdyClientStream* stream = CreateOutgoingBidirectionalStream();
+ if (stream == nullptr) {
+ // Stream flow control limits prevented us from opening a new stream.
+ QUIC_DLOG(ERROR) << "Failed to open CONNECT-UDP stream";
+ return nullptr;
+ }
+
+ QuicDatagramFlowId flow_id = compression_engine_.GetNextFlowId();
+
+ // Send the request.
+ spdy::Http2HeaderBlock headers;
+ headers[":method"] = "CONNECT-UDP";
+ headers[":scheme"] = "masque";
+ headers[":path"] = "/";
+ headers[":authority"] = target_server_address.ToString();
+ headers["datagram-flow-id"] = absl::StrCat(flow_id);
+ size_t bytes_sent =
+ stream->SendRequest(std::move(headers), /*body=*/"", /*fin=*/false);
+ if (bytes_sent == 0) {
+ QUIC_DLOG(ERROR) << "Failed to send CONNECT-UDP request";
+ return nullptr;
+ }
+
+ connect_udp_client_states_.push_back(ConnectUdpClientState(
+ stream, encapsulated_client_session, flow_id, target_server_address));
+ return &connect_udp_client_states_.back();
+}
+
+void MasqueClientSession::SendPacket(
+ QuicConnectionId client_connection_id,
+ QuicConnectionId server_connection_id,
+ absl::string_view packet,
+ const QuicSocketAddress& target_server_address,
+ EncapsulatedClientSession* encapsulated_client_session) {
+ if (masque_mode_ == MasqueMode::kLegacy) {
+ compression_engine_.CompressAndSendPacket(packet, client_connection_id,
+ server_connection_id,
+ target_server_address);
+ return;
+ }
+ const ConnectUdpClientState* connect_udp = GetOrCreateConnectUdpClientState(
+ target_server_address, encapsulated_client_session);
+ if (connect_udp == nullptr) {
+ QUIC_DLOG(ERROR) << "Failed to create CONNECT-UDP request";
+ return;
+ }
+
+ QuicDatagramFlowId flow_id = connect_udp->flow_id();
+ size_t slice_length =
+ QuicDataWriter::GetVarInt62Len(flow_id) + packet.length();
+ QuicUniqueBufferPtr buffer = MakeUniqueBuffer(
+ connection()->helper()->GetStreamSendBufferAllocator(), slice_length);
+ QuicDataWriter writer(slice_length, buffer.get());
+ if (!writer.WriteVarInt62(flow_id)) {
+ QUIC_BUG << "Failed to write flow_id";
+ return;
+ }
+ if (!writer.WriteBytes(packet.data(), packet.length())) {
+ QUIC_BUG << "Failed to write packet";
+ return;
+ }
+
+ QuicMemSlice slice(std::move(buffer), slice_length);
+ MessageResult message_result = SendMessage(QuicMemSliceSpan(&slice));
+
+ QUIC_DVLOG(1) << "Sent packet to " << target_server_address
+ << " compressed with flow ID " << flow_id
+ << " and got message result " << message_result;
}
void MasqueClientSession::RegisterConnectionId(
@@ -84,14 +192,65 @@
}
void MasqueClientSession::UnregisterConnectionId(
- QuicConnectionId client_connection_id) {
+ QuicConnectionId client_connection_id,
+ EncapsulatedClientSession* encapsulated_client_session) {
QUIC_DLOG(INFO) << "Unregistering " << client_connection_id;
- if (client_connection_id_registrations_.find(client_connection_id) !=
- client_connection_id_registrations_.end()) {
- client_connection_id_registrations_.erase(client_connection_id);
- owner_->UnregisterClientConnectionId(client_connection_id);
- compression_engine_.UnregisterClientConnectionId(client_connection_id);
+ if (masque_mode_ == MasqueMode::kLegacy) {
+ if (client_connection_id_registrations_.find(client_connection_id) !=
+ client_connection_id_registrations_.end()) {
+ client_connection_id_registrations_.erase(client_connection_id);
+ owner_->UnregisterClientConnectionId(client_connection_id);
+ compression_engine_.UnregisterClientConnectionId(client_connection_id);
+ }
+ return;
}
+
+ for (auto it = connect_udp_client_states_.begin();
+ it != connect_udp_client_states_.end();) {
+ if (it->encapsulated_client_session() == encapsulated_client_session) {
+ QUIC_DLOG(INFO) << "Removing state for flow_id " << it->flow_id();
+ auto* stream = it->stream();
+ it = connect_udp_client_states_.erase(it);
+ if (!stream->write_side_closed()) {
+ stream->Reset(QUIC_STREAM_CANCELLED);
+ }
+ } else {
+ ++it;
+ }
+ }
+}
+
+void MasqueClientSession::OnConnectionClosed(
+ const QuicConnectionCloseFrame& frame,
+ ConnectionCloseSource source) {
+ QuicSpdyClientSession::OnConnectionClosed(frame, source);
+ // Close all encapsulated sessions.
+ for (auto client_state : connect_udp_client_states_) {
+ client_state.encapsulated_client_session()->CloseConnection(
+ QUIC_CONNECTION_CANCELLED, "Underlying MASQUE connection was closed",
+ ConnectionCloseBehavior::SILENT_CLOSE);
+ }
+}
+
+void MasqueClientSession::OnStreamClosed(QuicStreamId stream_id) {
+ for (auto it = connect_udp_client_states_.begin();
+ it != connect_udp_client_states_.end();) {
+ if (it->stream()->id() == stream_id) {
+ QUIC_DLOG(INFO) << "Stream " << stream_id
+ << " was closed, removing state for flow_id "
+ << it->flow_id();
+ auto* encapsulated_client_session = it->encapsulated_client_session();
+ it = connect_udp_client_states_.erase(it);
+ encapsulated_client_session->CloseConnection(
+ QUIC_CONNECTION_CANCELLED,
+ "Underlying MASQUE CONNECT-UDP stream was closed",
+ ConnectionCloseBehavior::SILENT_CLOSE);
+ } else {
+ ++it;
+ }
+ }
+
+ QuicSpdyClientSession::OnStreamClosed(stream_id);
}
} // namespace quic
diff --git a/quic/masque/masque_client_session.h b/quic/masque/masque_client_session.h
index 0eab5b2..f849c61 100644
--- a/quic/masque/masque_client_session.h
+++ b/quic/masque/masque_client_session.h
@@ -9,6 +9,7 @@
#include "absl/strings/string_view.h"
#include "quic/core/http/quic_spdy_client_session.h"
#include "quic/masque/masque_compression_engine.h"
+#include "quic/masque/masque_utils.h"
#include "quic/platform/api/quic_export.h"
#include "quic/platform/api/quic_socket_address.h"
@@ -39,14 +40,21 @@
// Process packet that was just decapsulated.
virtual void ProcessPacket(absl::string_view packet,
- QuicSocketAddress server_address) = 0;
+ QuicSocketAddress target_server_address) = 0;
+
+ // Close the encapsulated connection.
+ virtual void CloseConnection(
+ QuicErrorCode error,
+ const std::string& details,
+ ConnectionCloseBehavior connection_close_behavior) = 0;
};
// Takes ownership of |connection|, but not of |crypto_config| or
// |push_promise_index| or |owner|. All pointers must be non-null. Caller
// must ensure that |push_promise_index| and |owner| stay valid for the
// lifetime of the newly created MasqueClientSession.
- MasqueClientSession(const QuicConfig& config,
+ MasqueClientSession(MasqueMode masque_mode,
+ const QuicConfig& config,
const ParsedQuicVersionVector& supported_versions,
QuicConnection* connection,
const QuicServerId& server_id,
@@ -60,17 +68,19 @@
// From QuicSession.
void OnMessageReceived(absl::string_view message) override;
-
void OnMessageAcked(QuicMessageId message_id,
QuicTime receive_timestamp) override;
-
void OnMessageLost(QuicMessageId message_id) override;
+ void OnConnectionClosed(const QuicConnectionCloseFrame& frame,
+ ConnectionCloseSource source) override;
+ void OnStreamClosed(QuicStreamId stream_id) override;
// Send encapsulated packet.
void SendPacket(QuicConnectionId client_connection_id,
QuicConnectionId server_connection_id,
absl::string_view packet,
- const QuicSocketAddress& server_address);
+ const QuicSocketAddress& target_server_address,
+ EncapsulatedClientSession* encapsulated_client_session);
// Register encapsulated client. This allows clients that are encapsulated
// within this MASQUE session to indicate they own a given client connection
@@ -84,9 +94,48 @@
// Unregister encapsulated client. |client_connection_id| must match a
// value previously passed to RegisterConnectionId.
- void UnregisterConnectionId(QuicConnectionId client_connection_id);
+ void UnregisterConnectionId(
+ QuicConnectionId client_connection_id,
+ EncapsulatedClientSession* encapsulated_client_session);
private:
+ // State that the MasqueClientSession keeps for each CONNECT-UDP request.
+ class QUIC_NO_EXPORT ConnectUdpClientState {
+ public:
+ // |stream| and |encapsulated_client_session| must be valid for the lifetime
+ // of the ConnectUdpClientState.
+ explicit ConnectUdpClientState(
+ QuicSpdyClientStream* stream,
+ EncapsulatedClientSession* encapsulated_client_session,
+ QuicDatagramFlowId flow_id,
+ const QuicSocketAddress& target_server_address)
+ : stream_(stream),
+ encapsulated_client_session_(encapsulated_client_session),
+ flow_id_(flow_id),
+ target_server_address_(target_server_address) {}
+
+ QuicSpdyClientStream* stream() const { return stream_; }
+ EncapsulatedClientSession* encapsulated_client_session() const {
+ return encapsulated_client_session_;
+ }
+ QuicDatagramFlowId flow_id() const { return flow_id_; }
+ const QuicSocketAddress& target_server_address() const {
+ return target_server_address_;
+ }
+
+ private:
+ QuicSpdyClientStream* stream_; // Unowned.
+ EncapsulatedClientSession* encapsulated_client_session_; // Unowned.
+ QuicDatagramFlowId flow_id_;
+ QuicSocketAddress target_server_address_;
+ };
+
+ const ConnectUdpClientState* GetOrCreateConnectUdpClientState(
+ const QuicSocketAddress& target_server_address,
+ EncapsulatedClientSession* encapsulated_client_session);
+
+ MasqueMode masque_mode_;
+ std::list<ConnectUdpClientState> connect_udp_client_states_;
absl::flat_hash_map<QuicConnectionId,
EncapsulatedClientSession*,
QuicConnectionIdHash>
diff --git a/quic/masque/masque_compression_engine.h b/quic/masque/masque_compression_engine.h
index 96c5948..16b4549 100644
--- a/quic/masque/masque_compression_engine.h
+++ b/quic/masque/masque_compression_engine.h
@@ -70,6 +70,9 @@
// compression table.
void UnregisterClientConnectionId(QuicConnectionId client_connection_id);
+ // Generates a new datagram flow ID.
+ QuicDatagramFlowId GetNextFlowId();
+
private:
struct QUIC_NO_EXPORT MasqueCompressionContext {
QuicConnectionId client_connection_id;
@@ -78,9 +81,6 @@
bool validated = false;
};
- // Generates a new datagram flow ID.
- QuicDatagramFlowId GetNextFlowId();
-
// Finds or creates a new compression context to use during compression.
// |client_connection_id_present| and |server_connection_id_present| indicate
// whether the corresponding connection ID is present in the current packet.
diff --git a/quic/masque/masque_dispatcher.cc b/quic/masque/masque_dispatcher.cc
index 6edca62..03eb8f8 100644
--- a/quic/masque/masque_dispatcher.cc
+++ b/quic/masque/masque_dispatcher.cc
@@ -8,9 +8,11 @@
namespace quic {
MasqueDispatcher::MasqueDispatcher(
+ MasqueMode masque_mode,
const QuicConfig* config,
const QuicCryptoServerConfig* crypto_config,
QuicVersionManager* version_manager,
+ QuicEpollServer* epoll_server,
std::unique_ptr<QuicConnectionHelperInterface> helper,
std::unique_ptr<QuicCryptoServerStreamBase::Helper> session_helper,
std::unique_ptr<QuicAlarmFactory> alarm_factory,
@@ -24,6 +26,8 @@
std::move(alarm_factory),
masque_server_backend,
expected_server_connection_id_length),
+ masque_mode_(masque_mode),
+ epoll_server_(epoll_server),
masque_server_backend_(masque_server_backend) {}
std::unique_ptr<QuicSession> MasqueDispatcher::CreateQuicSession(
@@ -40,9 +44,9 @@
ParsedQuicVersionVector{version});
auto session = std::make_unique<MasqueServerSession>(
- config(), GetSupportedVersions(), connection, this, this,
- session_helper(), crypto_config(), compressed_certs_cache(),
- masque_server_backend_);
+ masque_mode_, config(), GetSupportedVersions(), connection, this, this,
+ epoll_server_, session_helper(), crypto_config(),
+ compressed_certs_cache(), masque_server_backend_);
session->Initialize();
return session;
}
diff --git a/quic/masque/masque_dispatcher.h b/quic/masque/masque_dispatcher.h
index 2e0186c..ce371b5 100644
--- a/quic/masque/masque_dispatcher.h
+++ b/quic/masque/masque_dispatcher.h
@@ -8,6 +8,8 @@
#include "absl/container/flat_hash_map.h"
#include "quic/masque/masque_server_backend.h"
#include "quic/masque/masque_server_session.h"
+#include "quic/masque/masque_utils.h"
+#include "quic/platform/api/quic_epoll.h"
#include "quic/platform/api/quic_export.h"
#include "quic/tools/quic_simple_dispatcher.h"
@@ -19,9 +21,11 @@
public MasqueServerSession::Visitor {
public:
explicit MasqueDispatcher(
+ MasqueMode masque_mode,
const QuicConfig* config,
const QuicCryptoServerConfig* crypto_config,
QuicVersionManager* version_manager,
+ QuicEpollServer* epoll_server,
std::unique_ptr<QuicConnectionHelperInterface> helper,
std::unique_ptr<QuicCryptoServerStreamBase::Helper> session_helper,
std::unique_ptr<QuicAlarmFactory> alarm_factory,
@@ -51,6 +55,8 @@
QuicConnectionId client_connection_id) override;
private:
+ MasqueMode masque_mode_;
+ QuicEpollServer* epoll_server_; // Unowned.
MasqueServerBackend* masque_server_backend_; // Unowned.
// Mapping from client connection IDs to server sessions, allows routing
// incoming packets to the right MASQUE connection.
diff --git a/quic/masque/masque_encapsulated_client_session.cc b/quic/masque/masque_encapsulated_client_session.cc
index 983a094..3d557c3 100644
--- a/quic/masque/masque_encapsulated_client_session.cc
+++ b/quic/masque/masque_encapsulated_client_session.cc
@@ -31,11 +31,19 @@
received_packet);
}
+void MasqueEncapsulatedClientSession::CloseConnection(
+ QuicErrorCode error,
+ const std::string& details,
+ ConnectionCloseBehavior connection_close_behavior) {
+ connection()->CloseConnection(error, details, connection_close_behavior);
+}
+
void MasqueEncapsulatedClientSession::OnConnectionClosed(
- const QuicConnectionCloseFrame& /*frame*/,
- ConnectionCloseSource /*source*/) {
+ const QuicConnectionCloseFrame& frame,
+ ConnectionCloseSource source) {
+ QuicSpdyClientSession::OnConnectionClosed(frame, source);
masque_client_session_->UnregisterConnectionId(
- connection()->client_connection_id());
+ connection()->client_connection_id(), this);
}
} // namespace quic
diff --git a/quic/masque/masque_encapsulated_client_session.h b/quic/masque/masque_encapsulated_client_session.h
index 570ccdb..0952159 100644
--- a/quic/masque/masque_encapsulated_client_session.h
+++ b/quic/masque/masque_encapsulated_client_session.h
@@ -44,6 +44,10 @@
// From MasqueClientSession::EncapsulatedClientSession.
void ProcessPacket(absl::string_view packet,
QuicSocketAddress server_address) override;
+ void CloseConnection(
+ QuicErrorCode error,
+ const std::string& details,
+ ConnectionCloseBehavior connection_close_behavior) override;
// From QuicSession.
void OnConnectionClosed(const QuicConnectionCloseFrame& frame,
diff --git a/quic/masque/masque_encapsulated_epoll_client.cc b/quic/masque/masque_encapsulated_epoll_client.cc
index 8fb938d..8891e4a 100644
--- a/quic/masque/masque_encapsulated_epoll_client.cc
+++ b/quic/masque/masque_encapsulated_epoll_client.cc
@@ -30,8 +30,8 @@
absl::string_view packet(buffer, buf_len);
client_->masque_client()->masque_client_session()->SendPacket(
client_->session()->connection()->client_connection_id(),
- client_->session()->connection()->connection_id(), packet,
- peer_address);
+ client_->session()->connection()->connection_id(), packet, peer_address,
+ client_->masque_encapsulated_client_session());
return WriteResult(WRITE_STATUS_OK, buf_len);
}
@@ -94,7 +94,7 @@
MasqueEncapsulatedEpollClient::~MasqueEncapsulatedEpollClient() {
masque_client_->masque_client_session()->UnregisterConnectionId(
- client_connection_id_);
+ client_connection_id_, masque_encapsulated_client_session());
}
std::unique_ptr<QuicSession>
diff --git a/quic/masque/masque_epoll_client.cc b/quic/masque/masque_epoll_client.cc
index 1d77e55..bf09da4 100644
--- a/quic/masque/masque_epoll_client.cc
+++ b/quic/masque/masque_epoll_client.cc
@@ -12,6 +12,7 @@
MasqueEpollClient::MasqueEpollClient(
QuicSocketAddress server_address,
const QuicServerId& server_id,
+ MasqueMode masque_mode,
QuicEpollServer* epoll_server,
std::unique_ptr<ProofVerifier> proof_verifier,
const std::string& authority)
@@ -20,6 +21,7 @@
MasqueSupportedVersions(),
epoll_server,
std::move(proof_verifier)),
+ masque_mode_(masque_mode),
authority_(authority) {}
std::unique_ptr<QuicSession> MasqueEpollClient::CreateQuicClientSession(
@@ -28,8 +30,8 @@
QUIC_DLOG(INFO) << "Creating MASQUE session for "
<< connection->connection_id();
return std::make_unique<MasqueClientSession>(
- *config(), supported_versions, connection, server_id(), crypto_config(),
- push_promise_index(), this);
+ masque_mode_, *config(), supported_versions, connection, server_id(),
+ crypto_config(), push_promise_index(), this);
}
MasqueClientSession* MasqueEpollClient::masque_client_session() {
@@ -44,6 +46,7 @@
std::unique_ptr<MasqueEpollClient> MasqueEpollClient::Create(
const std::string& host,
int port,
+ MasqueMode masque_mode,
QuicEpollServer* epoll_server,
std::unique_ptr<ProofVerifier> proof_verifier) {
// Build the masque_client, and try to connect.
@@ -57,7 +60,7 @@
// std::make_unique<MasqueEpollClient>(...) because the constructor for
// MasqueEpollClient is private and therefore not accessible from make_unique.
auto masque_client = QuicWrapUnique(new MasqueEpollClient(
- addr, server_id, epoll_server, std::move(proof_verifier),
+ addr, server_id, masque_mode, epoll_server, std::move(proof_verifier),
absl::StrCat(host, ":", port)));
if (masque_client == nullptr) {
@@ -78,33 +81,35 @@
return nullptr;
}
- std::string body = "foo";
+ if (masque_client->masque_mode() == MasqueMode::kLegacy) {
+ // Construct the legacy mode init request.
+ spdy::Http2HeaderBlock header_block;
+ header_block[":method"] = "POST";
+ header_block[":scheme"] = "https";
+ header_block[":authority"] = masque_client->authority_;
+ header_block[":path"] = "/.well-known/masque/init";
+ std::string body = "foo";
- // Construct a GET or POST request for supplied URL.
- spdy::Http2HeaderBlock header_block;
- header_block[":method"] = "POST";
- header_block[":scheme"] = "https";
- header_block[":authority"] = masque_client->authority_;
- header_block[":path"] = "/.well-known/masque/init";
+ // Make sure to store the response, for later output.
+ masque_client->set_store_response(true);
- // Make sure to store the response, for later output.
- masque_client->set_store_response(true);
+ // Send the MASQUE init command.
+ masque_client->SendRequestAndWaitForResponse(header_block, body,
+ /*fin=*/true);
- // Send the MASQUE init command.
- masque_client->SendRequestAndWaitForResponse(header_block, body,
- /*fin=*/true);
+ if (!masque_client->connected()) {
+ QUIC_LOG(ERROR)
+ << "MASQUE init request caused connection failure. Error: "
+ << QuicErrorCodeToString(masque_client->session()->error());
+ return nullptr;
+ }
- if (!masque_client->connected()) {
- QUIC_LOG(ERROR) << "MASQUE init request caused connection failure. Error: "
- << QuicErrorCodeToString(masque_client->session()->error());
- return nullptr;
- }
-
- const int response_code = masque_client->latest_response_code();
- if (response_code != 200) {
- QUIC_LOG(ERROR) << "MASQUE init request failed with HTTP response code "
- << response_code;
- return nullptr;
+ const int response_code = masque_client->latest_response_code();
+ if (response_code != 200) {
+ QUIC_LOG(ERROR) << "MASQUE init request failed with HTTP response code "
+ << response_code;
+ return nullptr;
+ }
}
return masque_client;
}
diff --git a/quic/masque/masque_epoll_client.h b/quic/masque/masque_epoll_client.h
index 345c18c..1db161f 100644
--- a/quic/masque/masque_epoll_client.h
+++ b/quic/masque/masque_epoll_client.h
@@ -6,6 +6,7 @@
#define QUICHE_QUIC_MASQUE_MASQUE_EPOLL_CLIENT_H_
#include "quic/masque/masque_client_session.h"
+#include "quic/masque/masque_utils.h"
#include "quic/platform/api/quic_export.h"
#include "quic/tools/quic_client.h"
@@ -19,6 +20,7 @@
static std::unique_ptr<MasqueEpollClient> Create(
const std::string& host,
int port,
+ MasqueMode masque_mode,
QuicEpollServer* epoll_server,
std::unique_ptr<ProofVerifier> proof_verifier);
@@ -37,10 +39,13 @@
void UnregisterClientConnectionId(
QuicConnectionId client_connection_id) override;
+ MasqueMode masque_mode() const { return masque_mode_; }
+
private:
// Constructor is private, use Create() instead.
MasqueEpollClient(QuicSocketAddress server_address,
const QuicServerId& server_id,
+ MasqueMode masque_mode,
QuicEpollServer* epoll_server,
std::unique_ptr<ProofVerifier> proof_verifier,
const std::string& authority);
@@ -49,6 +54,7 @@
MasqueEpollClient(const MasqueEpollClient&) = delete;
MasqueEpollClient& operator=(const MasqueEpollClient&) = delete;
+ MasqueMode masque_mode_;
std::string authority_;
};
diff --git a/quic/masque/masque_epoll_server.cc b/quic/masque/masque_epoll_server.cc
index 34769a3..e8a0850 100644
--- a/quic/masque/masque_epoll_server.cc
+++ b/quic/masque/masque_epoll_server.cc
@@ -11,15 +11,18 @@
namespace quic {
-MasqueEpollServer::MasqueEpollServer(MasqueServerBackend* masque_server_backend)
+MasqueEpollServer::MasqueEpollServer(MasqueMode masque_mode,
+ MasqueServerBackend* masque_server_backend)
: QuicServer(CreateDefaultProofSource(),
masque_server_backend,
MasqueSupportedVersions()),
+ masque_mode_(masque_mode),
masque_server_backend_(masque_server_backend) {}
QuicDispatcher* MasqueEpollServer::CreateQuicDispatcher() {
return new MasqueDispatcher(
- &config(), &crypto_config(), version_manager(),
+ masque_mode_, &config(), &crypto_config(), version_manager(),
+ epoll_server(),
std::make_unique<QuicEpollConnectionHelper>(epoll_server(),
QuicAllocator::BUFFER_POOL),
std::make_unique<QuicSimpleCryptoServerStreamHelper>(),
diff --git a/quic/masque/masque_epoll_server.h b/quic/masque/masque_epoll_server.h
index 57eab6a..88161ee 100644
--- a/quic/masque/masque_epoll_server.h
+++ b/quic/masque/masque_epoll_server.h
@@ -6,6 +6,7 @@
#define QUICHE_QUIC_MASQUE_MASQUE_EPOLL_SERVER_H_
#include "quic/masque/masque_server_backend.h"
+#include "quic/masque/masque_utils.h"
#include "quic/platform/api/quic_export.h"
#include "quic/tools/quic_server.h"
@@ -14,7 +15,8 @@
// QUIC server that implements MASQUE.
class QUIC_NO_EXPORT MasqueEpollServer : public QuicServer {
public:
- explicit MasqueEpollServer(MasqueServerBackend* masque_server_backend);
+ explicit MasqueEpollServer(MasqueMode masque_mode,
+ MasqueServerBackend* masque_server_backend);
// Disallow copy and assign.
MasqueEpollServer(const MasqueEpollServer&) = delete;
@@ -24,6 +26,7 @@
QuicDispatcher* CreateQuicDispatcher() override;
private:
+ MasqueMode masque_mode_;
MasqueServerBackend* masque_server_backend_; // Unowned.
};
diff --git a/quic/masque/masque_server_backend.cc b/quic/masque/masque_server_backend.cc
index 1994819..230f7e7 100644
--- a/quic/masque/masque_server_backend.cc
+++ b/quic/masque/masque_server_backend.cc
@@ -19,9 +19,10 @@
} // namespace
-MasqueServerBackend::MasqueServerBackend(const std::string& server_authority,
+MasqueServerBackend::MasqueServerBackend(MasqueMode masque_mode,
+ const std::string& server_authority,
const std::string& cache_directory)
- : server_authority_(server_authority) {
+ : masque_mode_(masque_mode), server_authority_(server_authority) {
if (!cache_directory.empty()) {
QuicMemoryCacheBackend::InitializeBackend(cache_directory);
}
@@ -43,16 +44,24 @@
absl::string_view path = path_pair->second;
absl::string_view scheme = scheme_pair->second;
absl::string_view method = method_pair->second;
- if (scheme != "https" || method != "POST" || request_body.empty()) {
- // MASQUE requests MUST be a non-empty https POST.
- return false;
- }
+ std::string masque_path = "";
+ if (masque_mode_ == MasqueMode::kLegacy) {
+ if (scheme != "https" || method != "POST" || request_body.empty()) {
+ // MASQUE requests MUST be a non-empty https POST.
+ return false;
+ }
- if (path.rfind("/.well-known/masque/", 0) != 0) {
- // This request is not a MASQUE path.
- return false;
+ if (path.rfind("/.well-known/masque/", 0) != 0) {
+ // This request is not a MASQUE path.
+ return false;
+ }
+ masque_path = path.substr(sizeof("/.well-known/masque/") - 1);
+ } else {
+ if (method != "CONNECT-UDP") {
+ // Unexpected method.
+ return false;
+ }
}
- std::string masque_path(path.substr(sizeof("/.well-known/masque/") - 1));
if (!server_authority_.empty()) {
auto authority_pair = request_headers.find(":authority");
diff --git a/quic/masque/masque_server_backend.h b/quic/masque/masque_server_backend.h
index 1809359..86b1e00 100644
--- a/quic/masque/masque_server_backend.h
+++ b/quic/masque/masque_server_backend.h
@@ -6,6 +6,7 @@
#define QUICHE_QUIC_MASQUE_MASQUE_SERVER_BACKEND_H_
#include "absl/container/flat_hash_map.h"
+#include "quic/masque/masque_utils.h"
#include "quic/platform/api/quic_export.h"
#include "quic/tools/quic_memory_cache_backend.h"
@@ -27,7 +28,8 @@
virtual ~BackendClient() = default;
};
- explicit MasqueServerBackend(const std::string& server_authority,
+ explicit MasqueServerBackend(MasqueMode masque_mode,
+ const std::string& server_authority,
const std::string& cache_directory);
// Disallow copy and assign.
@@ -57,6 +59,7 @@
const std::string& request_body,
QuicSimpleServerBackend::RequestHandler* request_handler);
+ MasqueMode masque_mode_;
std::string server_authority_;
absl::flat_hash_map<std::string, std::unique_ptr<QuicBackendResponse>>
active_response_map_;
diff --git a/quic/masque/masque_server_bin.cc b/quic/masque/masque_server_bin.cc
index b25bf83..a6b6225 100644
--- a/quic/masque/masque_server_bin.cc
+++ b/quic/masque/masque_server_bin.cc
@@ -35,6 +35,12 @@
"Specifies the authority over which the server will accept MASQUE "
"requests. Defaults to empty which allows all authorities.");
+DEFINE_QUIC_COMMAND_LINE_FLAG(std::string,
+ masque_mode,
+ "",
+ "Allows setting MASQUE mode, valid values are "
+ "open and legacy. Defaults to open.");
+
int main(int argc, char* argv[]) {
const char* usage = "Usage: masque_server [options]";
std::vector<std::string> non_option_args =
@@ -44,17 +50,28 @@
return 0;
}
- auto backend = std::make_unique<quic::MasqueServerBackend>(
- GetQuicFlag(FLAGS_server_authority), GetQuicFlag(FLAGS_cache_dir));
+ quic::MasqueMode masque_mode = quic::MasqueMode::kOpen;
+ std::string mode_string = GetQuicFlag(FLAGS_masque_mode);
+ if (mode_string == "legacy") {
+ masque_mode = quic::MasqueMode::kLegacy;
+ } else if (!mode_string.empty() && mode_string != "open") {
+ std::cerr << "Invalid masque_mode \"" << mode_string << "\"" << std::endl;
+ return 1;
+ }
- auto server = std::make_unique<quic::MasqueEpollServer>(backend.get());
+ auto backend = std::make_unique<quic::MasqueServerBackend>(
+ masque_mode, GetQuicFlag(FLAGS_server_authority),
+ GetQuicFlag(FLAGS_cache_dir));
+
+ auto server =
+ std::make_unique<quic::MasqueEpollServer>(masque_mode, backend.get());
if (!server->CreateUDPSocketAndListen(quic::QuicSocketAddress(
quic::QuicIpAddress::Any6(), GetQuicFlag(FLAGS_port)))) {
return 1;
}
- std::cerr << "Started MASQUE server" << std::endl;
+ std::cerr << "Started " << masque_mode << " MASQUE server" << std::endl;
server->HandleEventsForever();
return 0;
}
diff --git a/quic/masque/masque_server_session.cc b/quic/masque/masque_server_session.cc
index 228444a..de231ba 100644
--- a/quic/masque/masque_server_session.cc
+++ b/quic/masque/masque_server_session.cc
@@ -4,14 +4,79 @@
#include "quic/masque/masque_server_session.h"
+#include <netdb.h>
+
+#include "absl/strings/str_cat.h"
+#include "quic/core/quic_data_reader.h"
+#include "quic/core/quic_udp_socket.h"
+#include "quic/tools/quic_url.h"
+#include "common/platform/api/quiche_text_utils.h"
+
namespace quic {
+namespace {
+// RAII wrapper for QuicUdpSocketFd.
+class FdWrapper {
+ public:
+ // Takes ownership of |fd| and closes the file descriptor on destruction.
+ explicit FdWrapper(int address_family) {
+ QuicUdpSocketApi socket_api;
+ fd_ =
+ socket_api.Create(address_family,
+ /*receive_buffer_size =*/kDefaultSocketReceiveBuffer,
+ /*send_buffer_size =*/kDefaultSocketReceiveBuffer);
+ }
+
+ ~FdWrapper() {
+ if (fd_ == kQuicInvalidSocketFd) {
+ return;
+ }
+ QuicUdpSocketApi socket_api;
+ socket_api.Destroy(fd_);
+ }
+
+ // Hands ownership of the file descriptor to the caller.
+ QuicUdpSocketFd extract_fd() {
+ QuicUdpSocketFd fd = fd_;
+ fd_ = kQuicInvalidSocketFd;
+ return fd;
+ }
+
+ // Keeps ownership of the file descriptor.
+ QuicUdpSocketFd fd() { return fd_; }
+
+ // Disallow copy and move.
+ FdWrapper(const FdWrapper&) = delete;
+ FdWrapper(FdWrapper&&) = delete;
+ FdWrapper& operator=(const FdWrapper&) = delete;
+ FdWrapper& operator=(FdWrapper&&) = delete;
+
+ private:
+ QuicUdpSocketFd fd_;
+};
+
+std::unique_ptr<QuicBackendResponse> CreateBackendErrorResponse(
+ absl::string_view status,
+ absl::string_view body) {
+ spdy::Http2HeaderBlock response_headers;
+ response_headers[":status"] = status;
+ auto response = std::make_unique<QuicBackendResponse>();
+ response->set_response_type(QuicBackendResponse::REGULAR_RESPONSE);
+ response->set_headers(std::move(response_headers));
+ response->set_body(body);
+ return response;
+}
+
+} // namespace
+
MasqueServerSession::MasqueServerSession(
+ MasqueMode masque_mode,
const QuicConfig& config,
const ParsedQuicVersionVector& supported_versions,
QuicConnection* connection,
QuicSession::Visitor* visitor,
Visitor* owner,
+ QuicEpollServer* epoll_server,
QuicCryptoServerStreamBase::Helper* helper,
const QuicCryptoServerConfig* crypto_config,
QuicCompressedCertsCache* compressed_certs_cache,
@@ -26,41 +91,74 @@
masque_server_backend),
masque_server_backend_(masque_server_backend),
owner_(owner),
- compression_engine_(this) {
+ epoll_server_(epoll_server),
+ compression_engine_(this),
+ masque_mode_(masque_mode) {
masque_server_backend_->RegisterBackendClient(connection_id(), this);
}
void MasqueServerSession::OnMessageReceived(absl::string_view message) {
QUIC_DVLOG(1) << "Received DATAGRAM frame of length " << message.length();
+ if (masque_mode_ == MasqueMode::kLegacy) {
+ QuicConnectionId client_connection_id, server_connection_id;
+ QuicSocketAddress target_server_address;
+ std::vector<char> packet;
+ bool version_present;
+ if (!compression_engine_.DecompressDatagram(
+ message, &client_connection_id, &server_connection_id,
+ &target_server_address, &packet, &version_present)) {
+ return;
+ }
- QuicConnectionId client_connection_id, server_connection_id;
- QuicSocketAddress server_address;
- std::vector<char> packet;
- bool version_present;
- if (!compression_engine_.DecompressDatagram(
- message, &client_connection_id, &server_connection_id,
- &server_address, &packet, &version_present)) {
+ QUIC_DVLOG(1) << "Received packet of length " << packet.size() << " for "
+ << target_server_address << " client "
+ << client_connection_id;
+
+ if (version_present) {
+ if (client_connection_id.length() != kQuicDefaultConnectionIdLength) {
+ QUIC_DLOG(ERROR)
+ << "Dropping long header with invalid client_connection_id "
+ << client_connection_id;
+ return;
+ }
+ owner_->RegisterClientConnectionId(client_connection_id, this);
+ }
+
+ WriteResult write_result = connection()->writer()->WritePacket(
+ packet.data(), packet.size(), connection()->self_address().host(),
+ target_server_address, nullptr);
+ QUIC_DVLOG(1) << "Got " << write_result << " for " << packet.size()
+ << " bytes to " << target_server_address;
+ return;
+ }
+ QUICHE_DCHECK_EQ(masque_mode_, MasqueMode::kOpen);
+ QuicDataReader reader(message);
+ QuicDatagramFlowId flow_id;
+ if (!reader.ReadVarInt62(&flow_id)) {
+ QUIC_DLOG(ERROR) << "Failed to read flow_id";
return;
}
- QUIC_DVLOG(1) << "Received packet of length " << packet.size() << " for "
- << server_address << " client " << client_connection_id;
-
- if (version_present) {
- if (client_connection_id.length() != kQuicDefaultConnectionIdLength) {
- QUIC_DLOG(ERROR)
- << "Dropping long header with invalid client_connection_id "
- << client_connection_id;
- return;
- }
- owner_->RegisterClientConnectionId(client_connection_id, this);
+ auto it =
+ absl::c_find_if(connect_udp_server_states_,
+ [flow_id](const ConnectUdpServerState& connect_udp) {
+ return connect_udp.flow_id() == flow_id;
+ });
+ if (it == connect_udp_server_states_.end()) {
+ QUIC_DLOG(ERROR) << "Received unknown flow_id " << flow_id;
+ return;
}
-
- WriteResult write_result = connection()->writer()->WritePacket(
- packet.data(), packet.size(), connection()->self_address().host(),
- server_address, nullptr);
- QUIC_DVLOG(1) << "Got " << write_result << " for " << packet.size()
- << " bytes to " << server_address;
+ QuicSocketAddress target_server_address = it->target_server_address();
+ QUICHE_DCHECK(target_server_address.IsInitialized());
+ QuicUdpSocketFd fd = it->fd();
+ QUICHE_DCHECK_NE(fd, kQuicInvalidSocketFd);
+ absl::string_view packet = reader.ReadRemainingPayload();
+ QuicUdpSocketApi socket_api;
+ QuicUdpPacketInfo packet_info;
+ packet_info.SetPeerAddress(target_server_address);
+ WriteResult write_result =
+ socket_api.WritePacket(fd, packet.data(), packet.length(), packet_info);
+ QUIC_DVLOG(1) << "Wrote packet to server with result " << write_result;
}
void MasqueServerSession::OnMessageAcked(QuicMessageId message_id,
@@ -73,17 +171,123 @@
}
void MasqueServerSession::OnConnectionClosed(
- const QuicConnectionCloseFrame& /*frame*/,
- ConnectionCloseSource /*source*/) {
+ const QuicConnectionCloseFrame& frame,
+ ConnectionCloseSource source) {
+ QuicSimpleServerSession::OnConnectionClosed(frame, source);
QUIC_DLOG(INFO) << "Closing connection for " << connection_id();
masque_server_backend_->RemoveBackendClient(connection_id());
+ // Clearing this state will close all sockets.
+ connect_udp_server_states_.clear();
+}
+
+void MasqueServerSession::OnStreamClosed(QuicStreamId stream_id) {
+ connect_udp_server_states_.remove_if(
+ [stream_id](const ConnectUdpServerState& connect_udp) {
+ return connect_udp.stream_id() == stream_id;
+ });
+
+ QuicSimpleServerSession::OnStreamClosed(stream_id);
}
std::unique_ptr<QuicBackendResponse> MasqueServerSession::HandleMasqueRequest(
const std::string& masque_path,
- const spdy::Http2HeaderBlock& /*request_headers*/,
+ const spdy::Http2HeaderBlock& request_headers,
const std::string& request_body,
- QuicSimpleServerBackend::RequestHandler* /*request_handler*/) {
+ QuicSimpleServerBackend::RequestHandler* request_handler) {
+ if (masque_mode_ != MasqueMode::kLegacy) {
+ auto path_pair = request_headers.find(":path");
+ auto scheme_pair = request_headers.find(":scheme");
+ auto method_pair = request_headers.find(":method");
+ auto flow_id_pair = request_headers.find("datagram-flow-id");
+ auto authority_pair = request_headers.find(":authority");
+ if (path_pair == request_headers.end() ||
+ scheme_pair == request_headers.end() ||
+ method_pair == request_headers.end() ||
+ flow_id_pair == request_headers.end() ||
+ authority_pair == request_headers.end()) {
+ QUIC_DLOG(ERROR) << "MASQUE request is missing required headers";
+ return CreateBackendErrorResponse("400", "Missing required headers");
+ }
+ absl::string_view path = path_pair->second;
+ absl::string_view scheme = scheme_pair->second;
+ absl::string_view method = method_pair->second;
+ absl::string_view flow_id_str = flow_id_pair->second;
+ absl::string_view authority = authority_pair->second;
+ if (path.empty()) {
+ QUIC_DLOG(ERROR) << "MASQUE request with empty path";
+ return CreateBackendErrorResponse("400", "Empty path");
+ }
+ if (scheme.empty()) {
+ return CreateBackendErrorResponse("400", "Empty scheme");
+ return nullptr;
+ }
+ if (method != "CONNECT-UDP") {
+ QUIC_DLOG(ERROR) << "MASQUE request with bad method \"" << method << "\"";
+ return CreateBackendErrorResponse("400", "Bad method");
+ }
+ QuicDatagramFlowId flow_id;
+ if (!absl::SimpleAtoi(flow_id_str, &flow_id)) {
+ QUIC_DLOG(ERROR) << "MASQUE request with bad flow_id \"" << flow_id_str
+ << "\"";
+ return CreateBackendErrorResponse("400", "Bad flow ID");
+ }
+ QuicUrl url(absl::StrCat("https://", authority));
+ if (!url.IsValid() || url.PathParamsQuery() != "/") {
+ QUIC_DLOG(ERROR) << "MASQUE request with bad authority \"" << authority
+ << "\"";
+ return CreateBackendErrorResponse("400", "Bad authority");
+ }
+
+ std::string port = absl::StrCat(url.port());
+ addrinfo hint = {};
+ hint.ai_protocol = IPPROTO_UDP;
+
+ addrinfo* info_list = nullptr;
+ int result =
+ getaddrinfo(url.host().c_str(), port.c_str(), &hint, &info_list);
+ if (result != 0) {
+ QUIC_DLOG(ERROR) << "Failed to resolve " << authority << ": "
+ << gai_strerror(result);
+ return CreateBackendErrorResponse("500", "DNS resolution failed");
+ }
+
+ QUICHE_CHECK_NE(info_list, nullptr);
+ std::unique_ptr<addrinfo, void (*)(addrinfo*)> info_list_owned(
+ info_list, freeaddrinfo);
+ QuicSocketAddress target_server_address(info_list->ai_addr,
+ info_list->ai_addrlen);
+ QUIC_DLOG(INFO) << "Got CONNECT_UDP request flow_id=" << flow_id
+ << " target_server_address=\"" << target_server_address
+ << "\"";
+
+ FdWrapper fd_wrapper(target_server_address.host().AddressFamilyToInt());
+ if (fd_wrapper.fd() == kQuicInvalidSocketFd) {
+ QUIC_DLOG(ERROR) << "Socket creation failed";
+ return CreateBackendErrorResponse("500", "Socket creation failed");
+ }
+ QuicSocketAddress any_v6_address(QuicIpAddress::Any6(), 0);
+ QuicUdpSocketApi socket_api;
+ if (!socket_api.Bind(fd_wrapper.fd(), any_v6_address)) {
+ QUIC_DLOG(ERROR) << "Socket bind failed";
+ return CreateBackendErrorResponse("500", "Socket bind failed");
+ }
+ epoll_server_->RegisterFDForRead(fd_wrapper.fd(), this);
+
+ connect_udp_server_states_.emplace_back(ConnectUdpServerState(
+ flow_id, request_handler->stream_id(), target_server_address,
+ fd_wrapper.extract_fd(), epoll_server_));
+
+ spdy::Http2HeaderBlock response_headers;
+ response_headers[":status"] = "200";
+ response_headers["datagram-flow-id"] = absl::StrCat(flow_id);
+ auto response = std::make_unique<QuicBackendResponse>();
+ response->set_response_type(QuicBackendResponse::INCOMPLETE_RESPONSE);
+ response->set_headers(std::move(response_headers));
+ response->set_body("");
+
+ return response;
+ }
+
QUIC_DLOG(INFO) << "MasqueServerSession handling MASQUE request";
if (masque_path == "init") {
@@ -120,9 +324,166 @@
void MasqueServerSession::HandlePacketFromServer(
const ReceivedPacketInfo& packet_info) {
QUIC_DVLOG(1) << "MasqueServerSession received " << packet_info;
- compression_engine_.CompressAndSendPacket(
- packet_info.packet.AsStringPiece(), packet_info.destination_connection_id,
- packet_info.source_connection_id, packet_info.peer_address);
+ if (masque_mode_ == MasqueMode::kLegacy) {
+ compression_engine_.CompressAndSendPacket(
+ packet_info.packet.AsStringPiece(),
+ packet_info.destination_connection_id, packet_info.source_connection_id,
+ packet_info.peer_address);
+ return;
+ }
+ QUIC_LOG(ERROR) << "Ignoring packet from server in " << masque_mode_
+ << " mode";
+}
+
+void MasqueServerSession::OnRegistration(QuicEpollServer* /*eps*/,
+ QuicUdpSocketFd fd,
+ int event_mask) {
+ QUIC_DVLOG(1) << "OnRegistration " << fd << " event_mask " << event_mask;
+}
+
+void MasqueServerSession::OnModification(QuicUdpSocketFd fd, int event_mask) {
+ QUIC_DVLOG(1) << "OnModification " << fd << " event_mask " << event_mask;
+}
+
+void MasqueServerSession::OnEvent(QuicUdpSocketFd fd, QuicEpollEvent* event) {
+ if ((event->in_events & EPOLLIN) == 0) {
+ QUIC_DVLOG(1) << "Ignoring OnEvent fd " << fd << " event mask "
+ << event->in_events;
+ return;
+ }
+ auto it = absl::c_find_if(connect_udp_server_states_,
+ [fd](const ConnectUdpServerState& connect_udp) {
+ return connect_udp.fd() == fd;
+ });
+ if (it == connect_udp_server_states_.end()) {
+ QUIC_BUG << "Got unexpected event mask " << event->in_events
+ << " on unknown fd " << fd;
+ return;
+ }
+ QuicDatagramFlowId flow_id = it->flow_id();
+ QuicSocketAddress expected_target_server_address =
+ it->target_server_address();
+ QUICHE_DCHECK(expected_target_server_address.IsInitialized());
+ QUIC_DVLOG(1) << "Received readable event on fd " << fd << " (mask "
+ << event->in_events << ") flow_id " << flow_id << " server "
+ << expected_target_server_address;
+ QuicUdpSocketApi socket_api;
+ BitMask64 packet_info_interested(QuicUdpPacketInfoBit::PEER_ADDRESS);
+ char packet_buffer[kMaxIncomingPacketSize];
+ char control_buffer[kDefaultUdpPacketControlBufferSize];
+ while (true) {
+ QuicUdpSocketApi::ReadPacketResult read_result;
+ read_result.packet_buffer = {packet_buffer, sizeof(packet_buffer)};
+ read_result.control_buffer = {control_buffer, sizeof(control_buffer)};
+ socket_api.ReadPacket(fd, packet_info_interested, &read_result);
+ if (!read_result.ok) {
+ // Most likely there is nothing left to read, break out of read loop.
+ break;
+ }
+ if (!read_result.packet_info.HasValue(QuicUdpPacketInfoBit::PEER_ADDRESS)) {
+ QUIC_BUG << "Missing peer address when reading from fd " << fd;
+ continue;
+ }
+ if (read_result.packet_info.peer_address() !=
+ expected_target_server_address) {
+ QUIC_DLOG(ERROR) << "Ignoring UDP packet on fd " << fd
+ << " from unexpected server address "
+ << read_result.packet_info.peer_address()
+ << " (expected " << expected_target_server_address
+ << ")";
+ continue;
+ }
+ if (!connection()->connected()) {
+ QUIC_BUG << "Unexpected incoming UDP packet on fd " << fd << " from "
+ << expected_target_server_address
+ << " because MASQUE connection is closed";
+ return;
+ }
+ // The packet is valid, send it to the client in a DATAGRAM frame.
+ size_t slice_length = QuicDataWriter::GetVarInt62Len(flow_id) +
+ read_result.packet_buffer.buffer_len;
+ QuicUniqueBufferPtr buffer = MakeUniqueBuffer(
+ connection()->helper()->GetStreamSendBufferAllocator(), slice_length);
+ QuicDataWriter writer(slice_length, buffer.get());
+ if (!writer.WriteVarInt62(flow_id)) {
+ QUIC_BUG << "Failed to write flow_id";
+ continue;
+ }
+ if (!writer.WriteBytes(read_result.packet_buffer.buffer,
+ read_result.packet_buffer.buffer_len)) {
+ QUIC_BUG << "Failed to write packet";
+ continue;
+ }
+ QUICHE_DCHECK_EQ(writer.remaining(), 0u);
+ QuicMemSlice slice(std::move(buffer), slice_length);
+ MessageResult message_result = SendMessage(QuicMemSliceSpan(&slice));
+ QUIC_DVLOG(1) << "Sent UDP packet from target server of length "
+ << read_result.packet_buffer.buffer_len << " with flow ID "
+ << flow_id << " and got message result " << message_result;
+ }
+}
+
+void MasqueServerSession::OnUnregistration(QuicUdpSocketFd fd, bool replaced) {
+ QUIC_DVLOG(1) << "OnUnregistration " << fd << " " << (replaced ? "" : "!")
+ << " replaced";
+}
+
+void MasqueServerSession::OnShutdown(QuicEpollServer* /*eps*/,
+ QuicUdpSocketFd fd) {
+ QUIC_DVLOG(1) << "OnShutdown " << fd;
+}
+
+std::string MasqueServerSession::Name() const {
+ return std::string("MasqueServerSession-") + connection_id().ToString();
+}
+
+MasqueServerSession::ConnectUdpServerState::ConnectUdpServerState(
+ QuicDatagramFlowId flow_id,
+ QuicStreamId stream_id,
+ const QuicSocketAddress& target_server_address,
+ QuicUdpSocketFd fd,
+ QuicEpollServer* epoll_server)
+ : flow_id_(flow_id),
+ stream_id_(stream_id),
+ target_server_address_(target_server_address),
+ fd_(fd),
+ epoll_server_(epoll_server) {
+ QUICHE_DCHECK_NE(fd_, kQuicInvalidSocketFd);
+ QUICHE_DCHECK_NE(epoll_server_, nullptr);
+}
+
+MasqueServerSession::ConnectUdpServerState::~ConnectUdpServerState() {
+ if (fd_ == kQuicInvalidSocketFd) {
+ return;
+ }
+ QuicUdpSocketApi socket_api;
+ QUIC_DLOG(INFO) << "Closing fd " << fd_;
+ epoll_server_->UnregisterFD(fd_);
+ socket_api.Destroy(fd_);
+}
+
+MasqueServerSession::ConnectUdpServerState::ConnectUdpServerState(
+ MasqueServerSession::ConnectUdpServerState&& other) {
+ fd_ = kQuicInvalidSocketFd;
+ *this = std::move(other);
+}
+
+MasqueServerSession::ConnectUdpServerState&
+MasqueServerSession::ConnectUdpServerState::operator=(
+ MasqueServerSession::ConnectUdpServerState&& other) {
+ if (fd_ != kQuicInvalidSocketFd) {
+ QuicUdpSocketApi socket_api;
+ QUIC_DLOG(INFO) << "Closing fd " << fd_;
+ epoll_server_->UnregisterFD(fd_);
+ socket_api.Destroy(fd_);
+ }
+ flow_id_ = other.flow_id_;
+ stream_id_ = other.stream_id_;
+ target_server_address_ = other.target_server_address_;
+ fd_ = other.fd_;
+ epoll_server_ = other.epoll_server_;
+ other.fd_ = kQuicInvalidSocketFd;
+ return *this;
}
} // namespace quic
diff --git a/quic/masque/masque_server_session.h b/quic/masque/masque_server_session.h
index 9e43b63..a73e641 100644
--- a/quic/masque/masque_server_session.h
+++ b/quic/masque/masque_server_session.h
@@ -5,8 +5,12 @@
#ifndef QUICHE_QUIC_MASQUE_MASQUE_SERVER_SESSION_H_
#define QUICHE_QUIC_MASQUE_MASQUE_SERVER_SESSION_H_
+#include "quic/core/quic_types.h"
+#include "quic/core/quic_udp_socket.h"
#include "quic/masque/masque_compression_engine.h"
#include "quic/masque/masque_server_backend.h"
+#include "quic/masque/masque_utils.h"
+#include "quic/platform/api/quic_epoll.h"
#include "quic/platform/api/quic_export.h"
#include "quic/tools/quic_simple_server_session.h"
@@ -15,7 +19,8 @@
// QUIC server session for connection to MASQUE proxy.
class QUIC_NO_EXPORT MasqueServerSession
: public QuicSimpleServerSession,
- public MasqueServerBackend::BackendClient {
+ public MasqueServerBackend::BackendClient,
+ public QuicEpollCallbackInterface {
public:
// Interface meant to be implemented by owner of this MasqueServerSession
// instance.
@@ -33,11 +38,13 @@
};
explicit MasqueServerSession(
+ MasqueMode masque_mode,
const QuicConfig& config,
const ParsedQuicVersionVector& supported_versions,
QuicConnection* connection,
QuicSession::Visitor* visitor,
Visitor* owner,
+ QuicEpollServer* epoll_server,
QuicCryptoServerStreamBase::Helper* helper,
const QuicCryptoServerConfig* crypto_config,
QuicCompressedCertsCache* compressed_certs_cache,
@@ -54,6 +61,7 @@
void OnMessageLost(QuicMessageId message_id) override;
void OnConnectionClosed(const QuicConnectionCloseFrame& frame,
ConnectionCloseSource source) override;
+ void OnStreamClosed(QuicStreamId stream_id) override;
// From MasqueServerBackend::BackendClient.
std::unique_ptr<QuicBackendResponse> HandleMasqueRequest(
@@ -62,13 +70,61 @@
const std::string& request_body,
QuicSimpleServerBackend::RequestHandler* request_handler) override;
+ // From QuicEpollCallbackInterface.
+ void OnRegistration(QuicEpollServer* eps,
+ QuicUdpSocketFd fd,
+ int event_mask) override;
+ void OnModification(QuicUdpSocketFd fd, int event_mask) override;
+ void OnEvent(QuicUdpSocketFd fd, QuicEpollEvent* event) override;
+ void OnUnregistration(QuicUdpSocketFd fd, bool replaced) override;
+ void OnShutdown(QuicEpollServer* eps, QuicUdpSocketFd fd) override;
+ std::string Name() const override;
+
// Handle packet for client, meant to be called by MasqueDispatcher.
void HandlePacketFromServer(const ReceivedPacketInfo& packet_info);
private:
+ // State that the MasqueServerSession keeps for each CONNECT-UDP request.
+ class QUIC_NO_EXPORT ConnectUdpServerState {
+ public:
+ // ConnectUdpServerState takes ownership of |fd|. It will unregister it
+ // from |epoll_server| and close the file descriptor when destructed.
+ explicit ConnectUdpServerState(
+ QuicDatagramFlowId flow_id,
+ QuicStreamId stream_id,
+ const QuicSocketAddress& target_server_address,
+ QuicUdpSocketFd fd,
+ QuicEpollServer* epoll_server);
+
+ ~ConnectUdpServerState();
+
+ // Disallow copy but allow move.
+ ConnectUdpServerState(const ConnectUdpServerState&) = delete;
+ ConnectUdpServerState(ConnectUdpServerState&&);
+ ConnectUdpServerState& operator=(const ConnectUdpServerState&) = delete;
+ ConnectUdpServerState& operator=(ConnectUdpServerState&&);
+
+ QuicDatagramFlowId flow_id() const { return flow_id_; }
+ QuicStreamId stream_id() const { return stream_id_; }
+ const QuicSocketAddress& target_server_address() const {
+ return target_server_address_;
+ }
+ QuicUdpSocketFd fd() const { return fd_; }
+
+ private:
+ QuicDatagramFlowId flow_id_;
+ QuicStreamId stream_id_;
+ QuicSocketAddress target_server_address_;
+ QuicUdpSocketFd fd_; // Owned.
+ QuicEpollServer* epoll_server_; // Unowned.
+ };
+
MasqueServerBackend* masque_server_backend_; // Unowned.
Visitor* owner_; // Unowned.
+ QuicEpollServer* epoll_server_; // Unowned.
MasqueCompressionEngine compression_engine_;
+ MasqueMode masque_mode_;
+ std::list<ConnectUdpServerState> connect_udp_server_states_;
bool masque_initialized_ = false;
};
diff --git a/quic/masque/masque_utils.cc b/quic/masque/masque_utils.cc
index f4f77f5..17dec0d 100644
--- a/quic/masque/masque_utils.cc
+++ b/quic/masque/masque_utils.cc
@@ -27,4 +27,21 @@
return config;
}
+std::string MasqueModeToString(MasqueMode masque_mode) {
+ switch (masque_mode) {
+ case MasqueMode::kInvalid:
+ return "Invalid";
+ case MasqueMode::kLegacy:
+ return "Legacy";
+ case MasqueMode::kOpen:
+ return "Open";
+ }
+ return absl::StrCat("Unknown(", static_cast<int>(masque_mode), ")");
+}
+
+std::ostream& operator<<(std::ostream& os, const MasqueMode& masque_mode) {
+ os << MasqueModeToString(masque_mode);
+ return os;
+}
+
} // namespace quic
diff --git a/quic/masque/masque_utils.h b/quic/masque/masque_utils.h
index f151f55..8113047 100644
--- a/quic/masque/masque_utils.h
+++ b/quic/masque/masque_utils.h
@@ -18,7 +18,24 @@
QUIC_NO_EXPORT QuicConfig MasqueEncapsulatedConfig();
// Maximum packet size for encapsulated connections.
-const QuicByteCount kMasqueMaxEncapsulatedPacketSize = 1300;
+enum : QuicByteCount { kMasqueMaxEncapsulatedPacketSize = 1300 };
+
+// Mode that MASQUE is operating in.
+enum class MasqueMode : uint8_t {
+ kInvalid = 0, // Should never be used.
+ kLegacy = 1, // Legacy mode uses the legacy MASQUE protocol as documented in
+ // <https://tools.ietf.org/html/draft-schinazi-masque-protocol>. That version
+ // of MASQUE uses a custom application-protocol over HTTP/3, and also allows
+ // unauthenticated clients.
+ kOpen = 2, // Open mode uses the MASQUE HTTP CONNECT-UDP method as documented
+ // in <https://tools.ietf.org/html/draft-ietf-masque-connect-udp>. This mode
+ // allows unauthenticated clients (a more restricted mode will be added to
+ // this enum at a later date).
+};
+
+QUIC_NO_EXPORT std::string MasqueModeToString(MasqueMode masque_mode);
+QUIC_NO_EXPORT std::ostream& operator<<(std::ostream& os,
+ const MasqueMode& masque_mode);
} // namespace quic