Generalize EventLoopTcpClientSocket between TCP and connected-UDP
Gives us the socket support needed for CONNECT-UDP. As far as interface and implementation goes, connected UDP sockets are pretty much the same as TCP sockets other than one small parameter sent to ::socket(). But the expected behavior is obviously very different, so this CL is primarily added testing for UDP socket expectations.
PiperOrigin-RevId: 476185935
diff --git a/build/source_list.bzl b/build/source_list.bzl
index 266e20f..b2f73be 100644
--- a/build/source_list.bzl
+++ b/build/source_list.bzl
@@ -949,8 +949,8 @@
"quic/core/batch_writer/quic_batch_writer_test.h",
"quic/core/batch_writer/quic_gso_batch_writer.h",
"quic/core/batch_writer/quic_sendmmsg_batch_writer.h",
+ "quic/core/io/event_loop_connecting_client_socket.h",
"quic/core/io/event_loop_socket_factory.h",
- "quic/core/io/event_loop_tcp_client_socket.h",
"quic/core/io/quic_default_event_loop.h",
"quic/core/io/quic_event_loop.h",
"quic/core/io/quic_poll_event_loop.h",
@@ -980,8 +980,8 @@
"quic/core/batch_writer/quic_batch_writer_buffer.cc",
"quic/core/batch_writer/quic_gso_batch_writer.cc",
"quic/core/batch_writer/quic_sendmmsg_batch_writer.cc",
+ "quic/core/io/event_loop_connecting_client_socket.cc",
"quic/core/io/event_loop_socket_factory.cc",
- "quic/core/io/event_loop_tcp_client_socket.cc",
"quic/core/io/quic_default_event_loop.cc",
"quic/core/io/quic_poll_event_loop.cc",
"quic/core/io/socket_posix.cc",
@@ -1292,7 +1292,7 @@
"quic/core/http/quic_spdy_client_session_test.cc",
"quic/core/http/quic_spdy_client_stream_test.cc",
"quic/core/http/quic_spdy_server_stream_base_test.cc",
- "quic/core/io/event_loop_tcp_client_socket_test.cc",
+ "quic/core/io/event_loop_connecting_client_socket_test.cc",
"quic/core/io/quic_all_event_loops_test.cc",
"quic/core/io/quic_poll_event_loop_test.cc",
"quic/core/io/socket_test.cc",
diff --git a/build/source_list.gni b/build/source_list.gni
index 7b04267..18fcd44 100644
--- a/build/source_list.gni
+++ b/build/source_list.gni
@@ -949,8 +949,8 @@
"src/quiche/quic/core/batch_writer/quic_batch_writer_test.h",
"src/quiche/quic/core/batch_writer/quic_gso_batch_writer.h",
"src/quiche/quic/core/batch_writer/quic_sendmmsg_batch_writer.h",
+ "src/quiche/quic/core/io/event_loop_connecting_client_socket.h",
"src/quiche/quic/core/io/event_loop_socket_factory.h",
- "src/quiche/quic/core/io/event_loop_tcp_client_socket.h",
"src/quiche/quic/core/io/quic_default_event_loop.h",
"src/quiche/quic/core/io/quic_event_loop.h",
"src/quiche/quic/core/io/quic_poll_event_loop.h",
@@ -980,8 +980,8 @@
"src/quiche/quic/core/batch_writer/quic_batch_writer_buffer.cc",
"src/quiche/quic/core/batch_writer/quic_gso_batch_writer.cc",
"src/quiche/quic/core/batch_writer/quic_sendmmsg_batch_writer.cc",
+ "src/quiche/quic/core/io/event_loop_connecting_client_socket.cc",
"src/quiche/quic/core/io/event_loop_socket_factory.cc",
- "src/quiche/quic/core/io/event_loop_tcp_client_socket.cc",
"src/quiche/quic/core/io/quic_default_event_loop.cc",
"src/quiche/quic/core/io/quic_poll_event_loop.cc",
"src/quiche/quic/core/io/socket_posix.cc",
@@ -1292,7 +1292,7 @@
"src/quiche/quic/core/http/quic_spdy_client_session_test.cc",
"src/quiche/quic/core/http/quic_spdy_client_stream_test.cc",
"src/quiche/quic/core/http/quic_spdy_server_stream_base_test.cc",
- "src/quiche/quic/core/io/event_loop_tcp_client_socket_test.cc",
+ "src/quiche/quic/core/io/event_loop_connecting_client_socket_test.cc",
"src/quiche/quic/core/io/quic_all_event_loops_test.cc",
"src/quiche/quic/core/io/quic_poll_event_loop_test.cc",
"src/quiche/quic/core/io/socket_test.cc",
diff --git a/build/source_list.json b/build/source_list.json
index 19cac11..df96337 100644
--- a/build/source_list.json
+++ b/build/source_list.json
@@ -948,8 +948,8 @@
"quiche/quic/core/batch_writer/quic_batch_writer_test.h",
"quiche/quic/core/batch_writer/quic_gso_batch_writer.h",
"quiche/quic/core/batch_writer/quic_sendmmsg_batch_writer.h",
+ "quiche/quic/core/io/event_loop_connecting_client_socket.h",
"quiche/quic/core/io/event_loop_socket_factory.h",
- "quiche/quic/core/io/event_loop_tcp_client_socket.h",
"quiche/quic/core/io/quic_default_event_loop.h",
"quiche/quic/core/io/quic_event_loop.h",
"quiche/quic/core/io/quic_poll_event_loop.h",
@@ -979,8 +979,8 @@
"quiche/quic/core/batch_writer/quic_batch_writer_buffer.cc",
"quiche/quic/core/batch_writer/quic_gso_batch_writer.cc",
"quiche/quic/core/batch_writer/quic_sendmmsg_batch_writer.cc",
+ "quiche/quic/core/io/event_loop_connecting_client_socket.cc",
"quiche/quic/core/io/event_loop_socket_factory.cc",
- "quiche/quic/core/io/event_loop_tcp_client_socket.cc",
"quiche/quic/core/io/quic_default_event_loop.cc",
"quiche/quic/core/io/quic_poll_event_loop.cc",
"quiche/quic/core/io/socket_posix.cc",
@@ -1291,7 +1291,7 @@
"quiche/quic/core/http/quic_spdy_client_session_test.cc",
"quiche/quic/core/http/quic_spdy_client_stream_test.cc",
"quiche/quic/core/http/quic_spdy_server_stream_base_test.cc",
- "quiche/quic/core/io/event_loop_tcp_client_socket_test.cc",
+ "quiche/quic/core/io/event_loop_connecting_client_socket_test.cc",
"quiche/quic/core/io/quic_all_event_loops_test.cc",
"quiche/quic/core/io/quic_poll_event_loop_test.cc",
"quiche/quic/core/io/socket_test.cc",
diff --git a/quiche/quic/core/connecting_client_socket.h b/quiche/quic/core/connecting_client_socket.h
index 01c27dd..d5f8ee4 100644
--- a/quiche/quic/core/connecting_client_socket.h
+++ b/quiche/quic/core/connecting_client_socket.h
@@ -9,9 +9,10 @@
#include "absl/status/status.h"
#include "absl/status/statusor.h"
+#include "quiche/quic/core/quic_types.h"
+#include "quiche/quic/platform/api/quic_socket_address.h"
#include "quiche/common/platform/api/quiche_export.h"
#include "quiche/common/platform/api/quiche_mem_slice.h"
-#include "quiche/quic/core/quic_types.h"
namespace quic {
@@ -73,6 +74,9 @@
// platform behavior.
virtual void Disconnect() = 0;
+ // Gets the address assigned to a connected socket.
+ virtual absl::StatusOr<QuicSocketAddress> GetLocalAddress() = 0;
+
// Blocking read. Receives and returns a buffer of up to `max_size` bytes from
// socket. Returns status on error.
virtual absl::StatusOr<quiche::QuicheMemSlice> ReceiveBlocking(
diff --git a/quiche/quic/core/io/event_loop_tcp_client_socket.cc b/quiche/quic/core/io/event_loop_connecting_client_socket.cc
similarity index 89%
rename from quiche/quic/core/io/event_loop_tcp_client_socket.cc
rename to quiche/quic/core/io/event_loop_connecting_client_socket.cc
index 84263bc..aefa353 100644
--- a/quiche/quic/core/io/event_loop_tcp_client_socket.cc
+++ b/quiche/quic/core/io/event_loop_connecting_client_socket.cc
@@ -2,7 +2,7 @@
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
-#include "quiche/quic/core/io/event_loop_tcp_client_socket.h"
+#include "quiche/quic/core/io/event_loop_connecting_client_socket.h"
#include <limits>
#include <string>
@@ -21,12 +21,14 @@
namespace quic {
-EventLoopTcpClientSocket::EventLoopTcpClientSocket(
+EventLoopConnectingClientSocket::EventLoopConnectingClientSocket(
+ socket_api::SocketProtocol protocol,
const quic::QuicSocketAddress& peer_address,
QuicByteCount receive_buffer_size, QuicByteCount send_buffer_size,
QuicEventLoop* event_loop, quiche::QuicheBufferAllocator* buffer_allocator,
AsyncVisitor* async_visitor)
- : peer_address_(peer_address),
+ : protocol_(protocol),
+ peer_address_(peer_address),
receive_buffer_size_(receive_buffer_size),
send_buffer_size_(send_buffer_size),
event_loop_(event_loop),
@@ -36,15 +38,15 @@
QUICHE_DCHECK(buffer_allocator_);
}
-EventLoopTcpClientSocket::~EventLoopTcpClientSocket() {
+EventLoopConnectingClientSocket::~EventLoopConnectingClientSocket() {
// Connected socket must be closed via Disconnect() before destruction. Cannot
// safely recover if state indicates caller may be expecting async callbacks.
QUICHE_DCHECK(connect_status_ != ConnectStatus::kConnecting);
QUICHE_DCHECK(!receive_max_size_.has_value());
QUICHE_DCHECK(absl::holds_alternative<absl::monostate>(send_data_));
if (descriptor_ != kInvalidSocketFd) {
- QUICHE_BUG(quic_event_loop_tcp_socket_invalid_destruction)
- << "Must call Disconnect() on connected TCP socket before destruction.";
+ QUICHE_BUG(quic_event_loop_connecting_socket_invalid_destruction)
+ << "Must call Disconnect() on connected socket before destruction.";
Close();
}
@@ -52,7 +54,7 @@
QUICHE_DCHECK(send_remaining_.empty());
}
-absl::Status EventLoopTcpClientSocket::ConnectBlocking() {
+absl::Status EventLoopConnectingClientSocket::ConnectBlocking() {
QUICHE_DCHECK_EQ(descriptor_, kInvalidSocketFd);
QUICHE_DCHECK(connect_status_ == ConnectStatus::kNotConnected);
QUICHE_DCHECK(!receive_max_size_.has_value());
@@ -101,7 +103,7 @@
return status;
}
-void EventLoopTcpClientSocket::ConnectAsync() {
+void EventLoopConnectingClientSocket::ConnectAsync() {
QUICHE_DCHECK(async_visitor_);
QUICHE_DCHECK_EQ(descriptor_, kInvalidSocketFd);
QUICHE_DCHECK(connect_status_ == ConnectStatus::kNotConnected);
@@ -117,7 +119,7 @@
FinishOrRearmAsyncConnect(DoInitialConnect());
}
-void EventLoopTcpClientSocket::Disconnect() {
+void EventLoopConnectingClientSocket::Disconnect() {
QUICHE_DCHECK_NE(descriptor_, kInvalidSocketFd);
QUICHE_DCHECK(connect_status_ != ConnectStatus::kNotConnected);
@@ -148,8 +150,16 @@
}
}
+absl::StatusOr<QuicSocketAddress>
+EventLoopConnectingClientSocket::GetLocalAddress() {
+ QUICHE_DCHECK_NE(descriptor_, kInvalidSocketFd);
+ QUICHE_DCHECK(connect_status_ == ConnectStatus::kConnected);
+
+ return socket_api::GetSocketAddress(descriptor_);
+}
+
absl::StatusOr<quiche::QuicheMemSlice>
-EventLoopTcpClientSocket::ReceiveBlocking(QuicByteCount max_size) {
+EventLoopConnectingClientSocket::ReceiveBlocking(QuicByteCount max_size) {
QUICHE_DCHECK_GT(max_size, 0u);
QUICHE_DCHECK_NE(descriptor_, kInvalidSocketFd);
QUICHE_DCHECK(connect_status_ == ConnectStatus::kConnected);
@@ -189,7 +199,7 @@
return buffer;
}
-void EventLoopTcpClientSocket::ReceiveAsync(QuicByteCount max_size) {
+void EventLoopConnectingClientSocket::ReceiveAsync(QuicByteCount max_size) {
QUICHE_DCHECK(async_visitor_);
QUICHE_DCHECK_GT(max_size, 0u);
QUICHE_DCHECK_NE(descriptor_, kInvalidSocketFd);
@@ -201,7 +211,7 @@
FinishOrRearmAsyncReceive(ReceiveInternal());
}
-absl::Status EventLoopTcpClientSocket::SendBlocking(std::string data) {
+absl::Status EventLoopConnectingClientSocket::SendBlocking(std::string data) {
QUICHE_DCHECK(!data.empty());
QUICHE_DCHECK(absl::holds_alternative<absl::monostate>(send_data_));
@@ -209,7 +219,7 @@
return SendBlockingInternal();
}
-absl::Status EventLoopTcpClientSocket::SendBlocking(
+absl::Status EventLoopConnectingClientSocket::SendBlocking(
quiche::QuicheMemSlice data) {
QUICHE_DCHECK(!data.empty());
QUICHE_DCHECK(absl::holds_alternative<absl::monostate>(send_data_));
@@ -218,7 +228,7 @@
return SendBlockingInternal();
}
-void EventLoopTcpClientSocket::SendAsync(std::string data) {
+void EventLoopConnectingClientSocket::SendAsync(std::string data) {
QUICHE_DCHECK(!data.empty());
QUICHE_DCHECK(absl::holds_alternative<absl::monostate>(send_data_));
@@ -228,7 +238,7 @@
FinishOrRearmAsyncSend(SendInternal());
}
-void EventLoopTcpClientSocket::SendAsync(quiche::QuicheMemSlice data) {
+void EventLoopConnectingClientSocket::SendAsync(quiche::QuicheMemSlice data) {
QUICHE_DCHECK(!data.empty());
QUICHE_DCHECK(absl::holds_alternative<absl::monostate>(send_data_));
@@ -239,9 +249,8 @@
FinishOrRearmAsyncSend(SendInternal());
}
-void EventLoopTcpClientSocket::OnSocketEvent(QuicEventLoop* event_loop,
- SocketFd fd,
- QuicSocketEventMask events) {
+void EventLoopConnectingClientSocket::OnSocketEvent(
+ QuicEventLoop* event_loop, SocketFd fd, QuicSocketEventMask events) {
QUICHE_DCHECK_EQ(event_loop, event_loop_);
QUICHE_DCHECK_EQ(fd, descriptor_);
@@ -261,16 +270,16 @@
}
}
-absl::Status EventLoopTcpClientSocket::Open() {
+absl::Status EventLoopConnectingClientSocket::Open() {
QUICHE_DCHECK_EQ(descriptor_, kInvalidSocketFd);
QUICHE_DCHECK(connect_status_ == ConnectStatus::kNotConnected);
QUICHE_DCHECK(!receive_max_size_.has_value());
QUICHE_DCHECK(absl::holds_alternative<absl::monostate>(send_data_));
QUICHE_DCHECK(send_remaining_.empty());
- absl::StatusOr<SocketFd> descriptor = socket_api::CreateSocket(
- peer_address_.host().address_family(), socket_api::SocketProtocol::kTcp,
- /*blocking=*/false);
+ absl::StatusOr<SocketFd> descriptor =
+ socket_api::CreateSocket(peer_address_.host().address_family(), protocol_,
+ /*blocking=*/false);
if (!descriptor.ok()) {
QUICHE_DVLOG(1) << "Failed to open socket for connection to address: "
<< peer_address_.ToString()
@@ -327,7 +336,7 @@
return absl::OkStatus();
}
-void EventLoopTcpClientSocket::Close() {
+void EventLoopConnectingClientSocket::Close() {
QUICHE_DCHECK_NE(descriptor_, kInvalidSocketFd);
bool unregistered = event_loop_->UnregisterSocket(descriptor_);
@@ -343,7 +352,7 @@
descriptor_ = kInvalidSocketFd;
}
-absl::Status EventLoopTcpClientSocket::DoInitialConnect() {
+absl::Status EventLoopConnectingClientSocket::DoInitialConnect() {
QUICHE_DCHECK_NE(descriptor_, kInvalidSocketFd);
QUICHE_DCHECK(connect_status_ == ConnectStatus::kNotConnected);
QUICHE_DCHECK(!receive_max_size_.has_value());
@@ -366,7 +375,7 @@
return connect_result;
}
-absl::Status EventLoopTcpClientSocket::GetConnectResult() {
+absl::Status EventLoopConnectingClientSocket::GetConnectResult() {
QUICHE_DCHECK_NE(descriptor_, kInvalidSocketFd);
QUICHE_DCHECK(connect_status_ == ConnectStatus::kConnecting);
QUICHE_DCHECK(!receive_max_size_.has_value());
@@ -414,7 +423,8 @@
return error;
}
-void EventLoopTcpClientSocket::FinishOrRearmAsyncConnect(absl::Status status) {
+void EventLoopConnectingClientSocket::FinishOrRearmAsyncConnect(
+ absl::Status status) {
if (absl::IsUnavailable(status)) {
if (!event_loop_->SupportsEdgeTriggered()) {
bool result = event_loop_->RearmSocket(
@@ -429,7 +439,7 @@
}
absl::StatusOr<quiche::QuicheMemSlice>
-EventLoopTcpClientSocket::ReceiveInternal() {
+EventLoopConnectingClientSocket::ReceiveInternal() {
QUICHE_DCHECK_NE(descriptor_, kInvalidSocketFd);
QUICHE_DCHECK(connect_status_ == ConnectStatus::kConnected);
QUICHE_CHECK(receive_max_size_.has_value());
@@ -473,7 +483,7 @@
}
}
-void EventLoopTcpClientSocket::FinishOrRearmAsyncReceive(
+void EventLoopConnectingClientSocket::FinishOrRearmAsyncReceive(
absl::StatusOr<quiche::QuicheMemSlice> buffer) {
QUICHE_DCHECK(async_visitor_);
QUICHE_DCHECK(connect_status_ == ConnectStatus::kConnected);
@@ -491,7 +501,7 @@
}
}
-absl::StatusOr<bool> EventLoopTcpClientSocket::OneBytePeek() {
+absl::StatusOr<bool> EventLoopConnectingClientSocket::OneBytePeek() {
QUICHE_DCHECK_NE(descriptor_, kInvalidSocketFd);
char peek_buffer;
@@ -504,7 +514,7 @@
}
}
-absl::Status EventLoopTcpClientSocket::SendBlockingInternal() {
+absl::Status EventLoopConnectingClientSocket::SendBlockingInternal() {
QUICHE_DCHECK_NE(descriptor_, kInvalidSocketFd);
QUICHE_DCHECK(connect_status_ == ConnectStatus::kConnected);
QUICHE_DCHECK(!absl::holds_alternative<absl::monostate>(send_data_));
@@ -552,7 +562,7 @@
return status;
}
-absl::Status EventLoopTcpClientSocket::SendInternal() {
+absl::Status EventLoopConnectingClientSocket::SendInternal() {
QUICHE_DCHECK_NE(descriptor_, kInvalidSocketFd);
QUICHE_DCHECK(connect_status_ == ConnectStatus::kConnected);
QUICHE_DCHECK(!absl::holds_alternative<absl::monostate>(send_data_));
@@ -588,7 +598,8 @@
return absl::OkStatus();
}
-void EventLoopTcpClientSocket::FinishOrRearmAsyncSend(absl::Status status) {
+void EventLoopConnectingClientSocket::FinishOrRearmAsyncSend(
+ absl::Status status) {
QUICHE_DCHECK(async_visitor_);
QUICHE_DCHECK(connect_status_ == ConnectStatus::kConnected);
diff --git a/quiche/quic/core/io/event_loop_tcp_client_socket.h b/quiche/quic/core/io/event_loop_connecting_client_socket.h
similarity index 78%
rename from quiche/quic/core/io/event_loop_tcp_client_socket.h
rename to quiche/quic/core/io/event_loop_connecting_client_socket.h
index 213aac2..c85911c 100644
--- a/quiche/quic/core/io/event_loop_tcp_client_socket.h
+++ b/quiche/quic/core/io/event_loop_connecting_client_socket.h
@@ -2,8 +2,8 @@
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
-#ifndef QUICHE_QUIC_CORE_IO_EVENT_LOOP_TCP_CLIENT_SOCKET_H_
-#define QUICHE_QUIC_CORE_IO_EVENT_LOOP_TCP_CLIENT_SOCKET_H_
+#ifndef QUICHE_QUIC_CORE_IO_EVENT_LOOP_CONNECTING_CLIENT_SOCKET_H_
+#define QUICHE_QUIC_CORE_IO_EVENT_LOOP_CONNECTING_CLIENT_SOCKET_H_
#include <string>
@@ -13,6 +13,7 @@
#include "absl/types/variant.h"
#include "quiche/quic/core/connecting_client_socket.h"
#include "quiche/quic/core/io/quic_event_loop.h"
+#include "quiche/quic/core/io/socket.h"
#include "quiche/quic/core/quic_types.h"
#include "quiche/quic/platform/api/quic_socket_address.h"
#include "quiche/common/platform/api/quiche_export.h"
@@ -20,8 +21,9 @@
namespace quic {
-// A TCP client socket implemented using an underlying QuicEventLoop.
-class QUICHE_EXPORT_PRIVATE EventLoopTcpClientSocket
+// A connection-based client socket implemented using an underlying
+// QuicEventLoop.
+class QUICHE_EXPORT_PRIVATE EventLoopConnectingClientSocket
: public ConnectingClientSocket,
public QuicSocketEventListener {
public:
@@ -29,19 +31,21 @@
// `send_buffer_size` is zero. `async_visitor` may be null if no async
// operations will be requested. `event_loop`, `buffer_allocator`, and
// `async_visitor` (if non-null) must outlive the created socket.
- EventLoopTcpClientSocket(const quic::QuicSocketAddress& peer_address,
- QuicByteCount receive_buffer_size,
- QuicByteCount send_buffer_size,
- QuicEventLoop* event_loop,
- quiche::QuicheBufferAllocator* buffer_allocator,
- AsyncVisitor* async_visitor);
+ EventLoopConnectingClientSocket(
+ socket_api::SocketProtocol protocol,
+ const quic::QuicSocketAddress& peer_address,
+ QuicByteCount receive_buffer_size, QuicByteCount send_buffer_size,
+ QuicEventLoop* event_loop,
+ quiche::QuicheBufferAllocator* buffer_allocator,
+ AsyncVisitor* async_visitor);
- ~EventLoopTcpClientSocket() override;
+ ~EventLoopConnectingClientSocket() override;
// ConnectingClientSocket:
absl::Status ConnectBlocking() override;
void ConnectAsync() override;
void Disconnect() override;
+ absl::StatusOr<QuicSocketAddress> GetLocalAddress() override;
absl::StatusOr<quiche::QuicheMemSlice> ReceiveBlocking(
QuicByteCount max_size) override;
void ReceiveAsync(QuicByteCount max_size) override;
@@ -75,6 +79,7 @@
absl::Status SendInternal();
void FinishOrRearmAsyncSend(absl::Status status);
+ const socket_api::SocketProtocol protocol_;
const QuicSocketAddress peer_address_;
const QuicByteCount receive_buffer_size_;
const QuicByteCount send_buffer_size_;
@@ -98,4 +103,4 @@
} // namespace quic
-#endif // QUICHE_QUIC_CORE_IO_EVENT_LOOP_TCP_CLIENT_SOCKET_H_
+#endif // QUICHE_QUIC_CORE_IO_EVENT_LOOP_CONNECTING_CLIENT_SOCKET_H_
diff --git a/quiche/quic/core/io/event_loop_connecting_client_socket_test.cc b/quiche/quic/core/io/event_loop_connecting_client_socket_test.cc
new file mode 100644
index 0000000..c99dc64
--- /dev/null
+++ b/quiche/quic/core/io/event_loop_connecting_client_socket_test.cc
@@ -0,0 +1,697 @@
+// Copyright 2022 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "quiche/quic/core/io/event_loop_connecting_client_socket.h"
+
+#include <functional>
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "absl/functional/bind_front.h"
+#include "absl/status/status.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "absl/types/optional.h"
+#include "absl/types/span.h"
+#include "quiche/quic/core/connecting_client_socket.h"
+#include "quiche/quic/core/io/event_loop_socket_factory.h"
+#include "quiche/quic/core/io/quic_default_event_loop.h"
+#include "quiche/quic/core/io/quic_event_loop.h"
+#include "quiche/quic/core/io/socket.h"
+#include "quiche/quic/core/quic_time.h"
+#include "quiche/quic/platform/api/quic_ip_address_family.h"
+#include "quiche/quic/platform/api/quic_socket_address.h"
+#include "quiche/quic/test_tools/mock_clock.h"
+#include "quiche/quic/test_tools/quic_test_utils.h"
+#include "quiche/common/platform/api/quiche_logging.h"
+#include "quiche/common/platform/api/quiche_mem_slice.h"
+#include "quiche/common/platform/api/quiche_mutex.h"
+#include "quiche/common/platform/api/quiche_test.h"
+#include "quiche/common/platform/api/quiche_test_loopback.h"
+#include "quiche/common/platform/api/quiche_thread.h"
+#include "quiche/common/simple_buffer_allocator.h"
+
+namespace quic::test {
+namespace {
+
+using ::testing::Combine;
+using ::testing::Values;
+using ::testing::ValuesIn;
+
+class TestServerSocketRunner : public quiche::QuicheThread {
+ public:
+ using SocketBehavior = std::function<void(
+ SocketFd connected_socket, socket_api::SocketProtocol protocol)>;
+
+ TestServerSocketRunner(SocketFd server_socket_descriptor,
+ SocketBehavior behavior)
+ : QuicheThread("TestServerSocketRunner"),
+ server_socket_descriptor_(server_socket_descriptor),
+ behavior_(std::move(behavior)) {
+ Start();
+ }
+
+ ~TestServerSocketRunner() override { WaitForCompletion(); }
+
+ void WaitForCompletion() { completion_notification_.WaitForNotification(); }
+
+ protected:
+ SocketFd server_socket_descriptor() const {
+ return server_socket_descriptor_;
+ }
+
+ const SocketBehavior& behavior() const { return behavior_; }
+
+ quiche::QuicheNotification& completion_notification() {
+ return completion_notification_;
+ }
+
+ private:
+ const SocketFd server_socket_descriptor_;
+ const SocketBehavior behavior_;
+
+ quiche::QuicheNotification completion_notification_;
+};
+
+class TestTcpServerSocketRunner : public TestServerSocketRunner {
+ public:
+ // On construction, spins a separate thread to accept a connection from
+ // `server_socket_descriptor`, runs `behavior` with that connection, and then
+ // closes the accepted connection socket.
+ TestTcpServerSocketRunner(SocketFd server_socket_descriptor,
+ SocketBehavior behavior)
+ : TestServerSocketRunner(server_socket_descriptor, behavior) {}
+ ~TestTcpServerSocketRunner() override = default;
+
+ protected:
+ void Run() override {
+ AcceptSocket();
+ behavior()(connection_socket_descriptor_, socket_api::SocketProtocol::kTcp);
+ CloseSocket();
+
+ completion_notification().Notify();
+ }
+
+ private:
+ void AcceptSocket() {
+ absl::StatusOr<socket_api::AcceptResult> connection_socket =
+ socket_api::Accept(server_socket_descriptor(), /*blocking=*/true);
+ QUICHE_CHECK(connection_socket.ok());
+ connection_socket_descriptor_ = connection_socket.value().fd;
+ }
+
+ void CloseSocket() {
+ QUICHE_CHECK(socket_api::Close(connection_socket_descriptor_).ok());
+ QUICHE_CHECK(socket_api::Close(server_socket_descriptor()).ok());
+ }
+
+ SocketFd connection_socket_descriptor_ = kInvalidSocketFd;
+};
+
+class TestUdpServerSocketRunner : public TestServerSocketRunner {
+ public:
+ // On construction, spins a separate thread to connect
+ // `server_socket_descriptor` to `client_socket_address`, runs `behavior` with
+ // that connection, and then disconnects the socket.
+ TestUdpServerSocketRunner(SocketFd server_socket_descriptor,
+ SocketBehavior behavior,
+ QuicSocketAddress client_socket_address)
+ : TestServerSocketRunner(server_socket_descriptor, behavior),
+ client_socket_address_(std::move(client_socket_address)) {}
+ ~TestUdpServerSocketRunner() override = default;
+
+ protected:
+ void Run() override {
+ ConnectSocket();
+ behavior()(server_socket_descriptor(), socket_api::SocketProtocol::kUdp);
+ DisconnectSocket();
+
+ completion_notification().Notify();
+ }
+
+ private:
+ void ConnectSocket() {
+ QUICHE_CHECK(
+ socket_api::Connect(server_socket_descriptor(), client_socket_address_)
+ .ok());
+ }
+
+ void DisconnectSocket() {
+ QUICHE_CHECK(socket_api::Close(server_socket_descriptor()).ok());
+ }
+
+ QuicSocketAddress client_socket_address_;
+};
+
+class EventLoopConnectingClientSocketTest
+ : public quiche::test::QuicheTestWithParam<
+ std::tuple<socket_api::SocketProtocol, QuicEventLoopFactory*>>,
+ public ConnectingClientSocket::AsyncVisitor {
+ public:
+ void SetUp() override {
+ QuicEventLoopFactory* event_loop_factory;
+ std::tie(protocol_, event_loop_factory) = GetParam();
+
+ event_loop_ = event_loop_factory->Create(&clock_);
+ socket_factory_ = std::make_unique<EventLoopSocketFactory>(
+ event_loop_.get(), quiche::SimpleBufferAllocator::Get());
+
+ QUICHE_CHECK(CreateListeningServerSocket());
+ }
+
+ void TearDown() override {
+ if (server_socket_descriptor_ != kInvalidSocketFd) {
+ QUICHE_CHECK(socket_api::Close(server_socket_descriptor_).ok());
+ }
+ }
+
+ void ConnectComplete(absl::Status status) override {
+ QUICHE_CHECK(!connect_result_.has_value());
+ connect_result_ = std::move(status);
+ }
+
+ void ReceiveComplete(absl::StatusOr<quiche::QuicheMemSlice> data) override {
+ QUICHE_CHECK(!receive_result_.has_value());
+ receive_result_ = std::move(data);
+ }
+
+ void SendComplete(absl::Status status) override {
+ QUICHE_CHECK(!send_result_.has_value());
+ send_result_ = std::move(status);
+ }
+
+ protected:
+ std::unique_ptr<ConnectingClientSocket> CreateSocket(
+ const quic::QuicSocketAddress& peer_address,
+ ConnectingClientSocket::AsyncVisitor* async_visitor) {
+ switch (protocol_) {
+ case socket_api::SocketProtocol::kUdp:
+ return socket_factory_->CreateConnectingUdpClientSocket(
+ peer_address, /*receive_buffer_size=*/0, /*send_buffer_size=*/0,
+ async_visitor);
+ case socket_api::SocketProtocol::kTcp:
+ return socket_factory_->CreateTcpClientSocket(
+ peer_address, /*receive_buffer_size=*/0, /*send_buffer_size=*/0,
+ async_visitor);
+ }
+ }
+
+ std::unique_ptr<ConnectingClientSocket> CreateSocketToEncourageDelayedSend(
+ const quic::QuicSocketAddress& peer_address,
+ ConnectingClientSocket::AsyncVisitor* async_visitor) {
+ switch (protocol_) {
+ case socket_api::SocketProtocol::kUdp:
+ // Nothing special for UDP since UDP does not gaurantee packets will be
+ // sent once send buffers are full.
+ return socket_factory_->CreateConnectingUdpClientSocket(
+ peer_address, /*receive_buffer_size=*/0, /*send_buffer_size=*/0,
+ async_visitor);
+ case socket_api::SocketProtocol::kTcp:
+ // For TCP, set a very small send buffer to encourage sends to be
+ // delayed.
+ return socket_factory_->CreateTcpClientSocket(
+ peer_address, /*receive_buffer_size=*/0, /*send_buffer_size=*/4,
+ async_visitor);
+ }
+ }
+
+ bool CreateListeningServerSocket() {
+ absl::StatusOr<SocketFd> socket = socket_api::CreateSocket(
+ quiche::TestLoopback().address_family(), protocol_,
+ /*blocking=*/true);
+ QUICHE_CHECK(socket.ok());
+
+ // For TCP, set an extremely small receive buffer size to increase the odds
+ // of buffers filling up when testing asynchronous writes.
+ if (protocol_ == socket_api::SocketProtocol::kTcp) {
+ static const QuicByteCount kReceiveBufferSize = 2;
+ absl::Status result =
+ socket_api::SetReceiveBufferSize(socket.value(), kReceiveBufferSize);
+ QUICHE_CHECK(result.ok());
+ }
+
+ QuicSocketAddress bind_address(quiche::TestLoopback(), /*port=*/0);
+ absl::Status result = socket_api::Bind(socket.value(), bind_address);
+ QUICHE_CHECK(result.ok());
+
+ absl::StatusOr<QuicSocketAddress> socket_address =
+ socket_api::GetSocketAddress(socket.value());
+ QUICHE_CHECK(socket_address.ok());
+
+ // TCP sockets need to listen for connections. UDP sockets are ready to
+ // receive.
+ if (protocol_ == socket_api::SocketProtocol::kTcp) {
+ result = socket_api::Listen(socket.value(), /*backlog=*/1);
+ QUICHE_CHECK(result.ok());
+ }
+
+ server_socket_descriptor_ = socket.value();
+ server_socket_address_ = std::move(socket_address).value();
+ return true;
+ }
+
+ std::unique_ptr<TestServerSocketRunner> CreateServerSocketRunner(
+ TestServerSocketRunner::SocketBehavior behavior,
+ ConnectingClientSocket* client_socket) {
+ std::unique_ptr<TestServerSocketRunner> runner;
+ switch (protocol_) {
+ case socket_api::SocketProtocol::kUdp: {
+ absl::StatusOr<QuicSocketAddress> client_socket_address =
+ client_socket->GetLocalAddress();
+ QUICHE_CHECK(client_socket_address.ok());
+ runner = std::make_unique<TestUdpServerSocketRunner>(
+ server_socket_descriptor_, std::move(behavior),
+ std::move(client_socket_address).value());
+ break;
+ }
+ case socket_api::SocketProtocol::kTcp:
+ runner = std::make_unique<TestTcpServerSocketRunner>(
+ server_socket_descriptor_, std::move(behavior));
+ break;
+ }
+
+ // Runner takes responsibility for closing server socket.
+ server_socket_descriptor_ = kInvalidSocketFd;
+
+ return runner;
+ }
+
+ socket_api::SocketProtocol protocol_;
+
+ SocketFd server_socket_descriptor_ = kInvalidSocketFd;
+ QuicSocketAddress server_socket_address_;
+
+ MockClock clock_;
+ std::unique_ptr<QuicEventLoop> event_loop_;
+ std::unique_ptr<EventLoopSocketFactory> socket_factory_;
+
+ absl::optional<absl::Status> connect_result_;
+ absl::optional<absl::StatusOr<quiche::QuicheMemSlice>> receive_result_;
+ absl::optional<absl::Status> send_result_;
+};
+
+std::string GetTestParamName(
+ ::testing::TestParamInfo<
+ std::tuple<socket_api::SocketProtocol, QuicEventLoopFactory*>>
+ info) {
+ auto [protocol, event_loop_factory] = info.param;
+
+ return EscapeTestParamName(absl::StrCat(socket_api::GetProtocolName(protocol),
+ "_", event_loop_factory->GetName()));
+}
+
+INSTANTIATE_TEST_SUITE_P(EventLoopConnectingClientSocketTests,
+ EventLoopConnectingClientSocketTest,
+ Combine(Values(socket_api::SocketProtocol::kUdp,
+ socket_api::SocketProtocol::kTcp),
+ ValuesIn(GetAllSupportedEventLoops())),
+ &GetTestParamName);
+
+TEST_P(EventLoopConnectingClientSocketTest, ConnectBlocking) {
+ std::unique_ptr<ConnectingClientSocket> socket =
+ CreateSocket(server_socket_address_,
+ /*async_visitor=*/nullptr);
+
+ // No socket runner to accept the connection for the server, but that is not
+ // expected to be necessary for the connection to complete from the client for
+ // TCP or UDP.
+ EXPECT_TRUE(socket->ConnectBlocking().ok());
+
+ socket->Disconnect();
+}
+
+TEST_P(EventLoopConnectingClientSocketTest, ConnectAsync) {
+ std::unique_ptr<ConnectingClientSocket> socket =
+ CreateSocket(server_socket_address_,
+ /*async_visitor=*/this);
+
+ socket->ConnectAsync();
+
+ // TCP connection typically completes asynchronously and UDP connection
+ // typically completes before ConnectAsync returns, but there is no simple way
+ // to ensure either behaves one way or the other. If connecting is
+ // asynchronous, expect completion once signalled by the event loop.
+ if (!connect_result_.has_value()) {
+ event_loop_->RunEventLoopOnce(QuicTime::Delta::FromSeconds(1));
+ ASSERT_TRUE(connect_result_.has_value());
+ }
+ EXPECT_TRUE(connect_result_.value().ok());
+
+ connect_result_.reset();
+ socket->Disconnect();
+ EXPECT_FALSE(connect_result_.has_value());
+}
+
+TEST_P(EventLoopConnectingClientSocketTest, ErrorBeforeConnectAsync) {
+ std::unique_ptr<ConnectingClientSocket> socket =
+ CreateSocket(server_socket_address_,
+ /*async_visitor=*/this);
+
+ // Close the server socket.
+ EXPECT_TRUE(socket_api::Close(server_socket_descriptor_).ok());
+ server_socket_descriptor_ = kInvalidSocketFd;
+
+ socket->ConnectAsync();
+ if (!connect_result_.has_value()) {
+ event_loop_->RunEventLoopOnce(QuicTime::Delta::FromSeconds(1));
+ ASSERT_TRUE(connect_result_.has_value());
+ }
+
+ switch (protocol_) {
+ case socket_api::SocketProtocol::kTcp:
+ // Expect an error because server socket was closed before connection.
+ EXPECT_FALSE(connect_result_.value().ok());
+ break;
+ case socket_api::SocketProtocol::kUdp:
+ // No error for UDP because UDP connection success does not rely on the
+ // server.
+ EXPECT_TRUE(connect_result_.value().ok());
+ socket->Disconnect();
+ break;
+ }
+}
+
+TEST_P(EventLoopConnectingClientSocketTest, ErrorDuringConnectAsync) {
+ std::unique_ptr<ConnectingClientSocket> socket =
+ CreateSocket(server_socket_address_,
+ /*async_visitor=*/this);
+
+ socket->ConnectAsync();
+
+ if (connect_result_.has_value()) {
+ // UDP typically completes connection immediately before this test has a
+ // chance to actually attempt the error. TCP typically completes
+ // asynchronously, but no simple way to ensure that always happens.
+ EXPECT_TRUE(connect_result_.value().ok());
+ socket->Disconnect();
+ return;
+ }
+
+ // Close the server socket.
+ EXPECT_TRUE(socket_api::Close(server_socket_descriptor_).ok());
+ server_socket_descriptor_ = kInvalidSocketFd;
+
+ EXPECT_FALSE(connect_result_.has_value());
+ event_loop_->RunEventLoopOnce(QuicTime::Delta::FromSeconds(1));
+ ASSERT_TRUE(connect_result_.has_value());
+
+ switch (protocol_) {
+ case socket_api::SocketProtocol::kTcp:
+ EXPECT_FALSE(connect_result_.value().ok());
+ break;
+ case socket_api::SocketProtocol::kUdp:
+ // No error for UDP because UDP connection success does not rely on the
+ // server.
+ EXPECT_TRUE(connect_result_.value().ok());
+ break;
+ }
+}
+
+TEST_P(EventLoopConnectingClientSocketTest, Disconnect) {
+ std::unique_ptr<ConnectingClientSocket> socket =
+ CreateSocket(server_socket_address_,
+ /*async_visitor=*/nullptr);
+
+ ASSERT_TRUE(socket->ConnectBlocking().ok());
+ socket->Disconnect();
+}
+
+TEST_P(EventLoopConnectingClientSocketTest, DisconnectCancelsConnectAsync) {
+ std::unique_ptr<ConnectingClientSocket> socket =
+ CreateSocket(server_socket_address_,
+ /*async_visitor=*/this);
+
+ socket->ConnectAsync();
+
+ bool expect_canceled = true;
+ if (connect_result_.has_value()) {
+ // UDP typically completes connection immediately before this test has a
+ // chance to actually attempt the disconnect. TCP typically completes
+ // asynchronously, but no simple way to ensure that always happens.
+ EXPECT_TRUE(connect_result_.value().ok());
+ expect_canceled = false;
+ }
+
+ socket->Disconnect();
+
+ if (expect_canceled) {
+ // Expect immediate cancelled error.
+ ASSERT_TRUE(connect_result_.has_value());
+ EXPECT_TRUE(absl::IsCancelled(connect_result_.value()));
+ }
+}
+
+TEST_P(EventLoopConnectingClientSocketTest, ConnectAndReconnect) {
+ std::unique_ptr<ConnectingClientSocket> socket =
+ CreateSocket(server_socket_address_,
+ /*async_visitor=*/nullptr);
+
+ ASSERT_TRUE(socket->ConnectBlocking().ok());
+ socket->Disconnect();
+
+ // Expect `socket` can reconnect now that it has been disconnected.
+ EXPECT_TRUE(socket->ConnectBlocking().ok());
+ socket->Disconnect();
+}
+
+TEST_P(EventLoopConnectingClientSocketTest, GetLocalAddress) {
+ std::unique_ptr<ConnectingClientSocket> socket =
+ CreateSocket(server_socket_address_,
+ /*async_visitor=*/nullptr);
+ ASSERT_TRUE(socket->ConnectBlocking().ok());
+
+ absl::StatusOr<QuicSocketAddress> address = socket->GetLocalAddress();
+ ASSERT_TRUE(address.ok());
+ EXPECT_TRUE(address.value().IsInitialized());
+
+ socket->Disconnect();
+}
+
+void SendDataOnSocket(absl::string_view data, SocketFd connected_socket,
+ socket_api::SocketProtocol protocol) {
+ QUICHE_CHECK(!data.empty());
+
+ // May attempt to send in pieces for TCP. For UDP, expect failure if `data`
+ // cannot be sent in a single packet.
+ do {
+ absl::StatusOr<absl::string_view> remainder =
+ socket_api::Send(connected_socket, data);
+ if (!remainder.ok()) {
+ return;
+ }
+ data = remainder.value();
+ } while (protocol == socket_api::SocketProtocol::kTcp && !data.empty());
+
+ QUICHE_CHECK(data.empty());
+}
+
+TEST_P(EventLoopConnectingClientSocketTest, ReceiveBlocking) {
+ std::unique_ptr<ConnectingClientSocket> socket =
+ CreateSocket(server_socket_address_,
+ /*async_visitor=*/nullptr);
+ ASSERT_TRUE(socket->ConnectBlocking().ok());
+
+ std::string expected = {1, 2, 3, 4, 5, 6, 7, 8};
+ std::unique_ptr<TestServerSocketRunner> runner = CreateServerSocketRunner(
+ absl::bind_front(&SendDataOnSocket, expected), socket.get());
+
+ std::string received;
+ absl::StatusOr<quiche::QuicheMemSlice> data;
+
+ // Expect exactly one packet for UDP, and at least two receives (data + FIN)
+ // for TCP.
+ do {
+ data = socket->ReceiveBlocking(100);
+ ASSERT_TRUE(data.ok());
+ received.append(data.value().data(), data.value().length());
+ } while (protocol_ == socket_api::SocketProtocol::kTcp &&
+ !data.value().empty());
+
+ EXPECT_EQ(received, expected);
+
+ socket->Disconnect();
+}
+
+TEST_P(EventLoopConnectingClientSocketTest, ReceiveAsync) {
+ std::unique_ptr<ConnectingClientSocket> socket =
+ CreateSocket(server_socket_address_,
+ /*async_visitor=*/this);
+ ASSERT_TRUE(socket->ConnectBlocking().ok());
+
+ // Start an async receive. Expect no immediate results because runner not
+ // yet setup to send.
+ socket->ReceiveAsync(100);
+ EXPECT_FALSE(receive_result_.has_value());
+
+ // Send data from server.
+ std::string expected = {1, 2, 3, 4, 5, 6, 7, 8};
+ std::unique_ptr<TestServerSocketRunner> runner = CreateServerSocketRunner(
+ absl::bind_front(&SendDataOnSocket, expected), socket.get());
+
+ EXPECT_FALSE(receive_result_.has_value());
+ for (int i = 0; i < 5 && !receive_result_.has_value(); ++i) {
+ event_loop_->RunEventLoopOnce(QuicTime::Delta::FromSeconds(1));
+ }
+
+ // Expect to receive at least some of the sent data.
+ ASSERT_TRUE(receive_result_.has_value());
+ ASSERT_TRUE(receive_result_.value().ok());
+ EXPECT_FALSE(receive_result_.value().value().empty());
+ std::string received(receive_result_.value().value().data(),
+ receive_result_.value().value().length());
+
+ // For TCP, expect at least one more receive for the FIN.
+ if (protocol_ == socket_api::SocketProtocol::kTcp) {
+ absl::StatusOr<quiche::QuicheMemSlice> data;
+ do {
+ data = socket->ReceiveBlocking(100);
+ ASSERT_TRUE(data.ok());
+ received.append(data.value().data(), data.value().length());
+ } while (!data.value().empty());
+ }
+
+ EXPECT_EQ(received, expected);
+
+ receive_result_.reset();
+ socket->Disconnect();
+ EXPECT_FALSE(receive_result_.has_value());
+}
+
+TEST_P(EventLoopConnectingClientSocketTest, DisconnectCancelsReceiveAsync) {
+ std::unique_ptr<ConnectingClientSocket> socket =
+ CreateSocket(server_socket_address_,
+ /*async_visitor=*/this);
+
+ ASSERT_TRUE(socket->ConnectBlocking().ok());
+
+ // Start an asynchronous read, expecting no completion because server never
+ // sends any data.
+ socket->ReceiveAsync(100);
+ EXPECT_FALSE(receive_result_.has_value());
+
+ // Disconnect and expect an immediate cancelled error.
+ socket->Disconnect();
+ ASSERT_TRUE(receive_result_.has_value());
+ ASSERT_FALSE(receive_result_.value().ok());
+ EXPECT_TRUE(absl::IsCancelled(receive_result_.value().status()));
+}
+
+// Receive from `connected_socket` until connection is closed, writing
+// received data to `out_received`.
+void ReceiveDataFromSocket(std::string* out_received, SocketFd connected_socket,
+ socket_api::SocketProtocol protocol) {
+ out_received->clear();
+
+ std::string buffer(100, 0);
+ absl::StatusOr<absl::Span<char>> received;
+
+ // Expect exactly one packet for UDP, and at least two receives (data + FIN)
+ // for TCP.
+ do {
+ received = socket_api::Receive(connected_socket, absl::MakeSpan(buffer));
+ QUICHE_CHECK(received.ok());
+ out_received->insert(out_received->end(), received.value().begin(),
+ received.value().end());
+ } while (protocol == socket_api::SocketProtocol::kTcp &&
+ !received.value().empty());
+ QUICHE_CHECK(!out_received->empty());
+}
+
+TEST_P(EventLoopConnectingClientSocketTest, SendBlocking) {
+ std::unique_ptr<ConnectingClientSocket> socket =
+ CreateSocket(server_socket_address_,
+ /*async_visitor=*/nullptr);
+ ASSERT_TRUE(socket->ConnectBlocking().ok());
+
+ std::string sent;
+ std::unique_ptr<TestServerSocketRunner> runner = CreateServerSocketRunner(
+ absl::bind_front(&ReceiveDataFromSocket, &sent), socket.get());
+
+ std::string expected = {1, 2, 3, 4, 5, 6, 7, 8};
+ EXPECT_TRUE(socket->SendBlocking(expected).ok());
+ socket->Disconnect();
+
+ runner->WaitForCompletion();
+ EXPECT_EQ(sent, expected);
+}
+
+TEST_P(EventLoopConnectingClientSocketTest, SendAsync) {
+ std::unique_ptr<ConnectingClientSocket> socket =
+ CreateSocketToEncourageDelayedSend(server_socket_address_,
+ /*async_visitor=*/this);
+ ASSERT_TRUE(socket->ConnectBlocking().ok());
+
+ std::string data = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10};
+ std::string expected;
+
+ std::unique_ptr<TestServerSocketRunner> runner;
+ std::string sent;
+ switch (protocol_) {
+ case socket_api::SocketProtocol::kTcp:
+ // Repeatedly write to socket until it does not complete synchronously.
+ do {
+ expected.insert(expected.end(), data.begin(), data.end());
+ send_result_.reset();
+ socket->SendAsync(data);
+ ASSERT_TRUE(!send_result_.has_value() || send_result_.value().ok());
+ } while (send_result_.has_value());
+
+ // Begin receiving from server and expect more data to send.
+ runner = CreateServerSocketRunner(
+ absl::bind_front(&ReceiveDataFromSocket, &sent), socket.get());
+ EXPECT_FALSE(send_result_.has_value());
+ for (int i = 0; i < 5 && !send_result_.has_value(); ++i) {
+ event_loop_->RunEventLoopOnce(QuicTime::Delta::FromSeconds(1));
+ }
+ break;
+
+ case socket_api::SocketProtocol::kUdp:
+ // Expect UDP send to always send immediately.
+ runner = CreateServerSocketRunner(
+ absl::bind_front(&ReceiveDataFromSocket, &sent), socket.get());
+ socket->SendAsync(data);
+ expected = data;
+ break;
+ }
+ ASSERT_TRUE(send_result_.has_value());
+ EXPECT_TRUE(send_result_.value().ok());
+
+ send_result_.reset();
+ socket->Disconnect();
+ EXPECT_FALSE(send_result_.has_value());
+
+ runner->WaitForCompletion();
+ EXPECT_EQ(sent, expected);
+}
+
+TEST_P(EventLoopConnectingClientSocketTest, DisconnectCancelsSendAsync) {
+ if (protocol_ == socket_api::SocketProtocol::kUdp) {
+ // UDP sends are always immediate, so cannot disconect mid-send.
+ return;
+ }
+
+ std::unique_ptr<ConnectingClientSocket> socket =
+ CreateSocketToEncourageDelayedSend(server_socket_address_,
+ /*async_visitor=*/this);
+ ASSERT_TRUE(socket->ConnectBlocking().ok());
+
+ std::string data = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10};
+
+ // Repeatedly write to socket until it does not complete synchronously.
+ do {
+ send_result_.reset();
+ socket->SendAsync(data);
+ ASSERT_TRUE(!send_result_.has_value() || send_result_.value().ok());
+ } while (send_result_.has_value());
+
+ // Disconnect and expect immediate cancelled error.
+ socket->Disconnect();
+ ASSERT_TRUE(send_result_.has_value());
+ EXPECT_TRUE(absl::IsCancelled(send_result_.value()));
+}
+
+} // namespace
+} // namespace quic::test
diff --git a/quiche/quic/core/io/event_loop_socket_factory.cc b/quiche/quic/core/io/event_loop_socket_factory.cc
index 643efab..b1aaec7 100644
--- a/quiche/quic/core/io/event_loop_socket_factory.cc
+++ b/quiche/quic/core/io/event_loop_socket_factory.cc
@@ -7,8 +7,9 @@
#include <memory>
#include "quiche/quic/core/connecting_client_socket.h"
-#include "quiche/quic/core/io/event_loop_tcp_client_socket.h"
+#include "quiche/quic/core/io/event_loop_connecting_client_socket.h"
#include "quiche/quic/core/io/quic_event_loop.h"
+#include "quiche/quic/core/io/socket.h"
#include "quiche/quic/core/quic_types.h"
#include "quiche/quic/platform/api/quic_socket_address.h"
#include "quiche/common/platform/api/quiche_logging.h"
@@ -28,9 +29,19 @@
const quic::QuicSocketAddress& peer_address,
QuicByteCount receive_buffer_size, QuicByteCount send_buffer_size,
ConnectingClientSocket::AsyncVisitor* async_visitor) {
- return std::make_unique<EventLoopTcpClientSocket>(
- peer_address, receive_buffer_size, send_buffer_size, event_loop_,
- buffer_allocator_, async_visitor);
+ return std::make_unique<EventLoopConnectingClientSocket>(
+ socket_api::SocketProtocol::kTcp, peer_address, receive_buffer_size,
+ send_buffer_size, event_loop_, buffer_allocator_, async_visitor);
+}
+
+std::unique_ptr<ConnectingClientSocket>
+EventLoopSocketFactory::CreateConnectingUdpClientSocket(
+ const quic::QuicSocketAddress& peer_address,
+ QuicByteCount receive_buffer_size, QuicByteCount send_buffer_size,
+ ConnectingClientSocket::AsyncVisitor* async_visitor) {
+ return std::make_unique<EventLoopConnectingClientSocket>(
+ socket_api::SocketProtocol::kUdp, peer_address, receive_buffer_size,
+ send_buffer_size, event_loop_, buffer_allocator_, async_visitor);
}
} // namespace quic
diff --git a/quiche/quic/core/io/event_loop_socket_factory.h b/quiche/quic/core/io/event_loop_socket_factory.h
index e3654f9..ee9a9f3 100644
--- a/quiche/quic/core/io/event_loop_socket_factory.h
+++ b/quiche/quic/core/io/event_loop_socket_factory.h
@@ -30,6 +30,10 @@
const quic::QuicSocketAddress& peer_address,
QuicByteCount receive_buffer_size, QuicByteCount send_buffer_size,
ConnectingClientSocket::AsyncVisitor* async_visitor) override;
+ std::unique_ptr<ConnectingClientSocket> CreateConnectingUdpClientSocket(
+ const quic::QuicSocketAddress& peer_address,
+ QuicByteCount receive_buffer_size, QuicByteCount send_buffer_size,
+ ConnectingClientSocket::AsyncVisitor* async_visitor) override;
private:
QuicEventLoop* const event_loop_; // unowned
diff --git a/quiche/quic/core/io/event_loop_tcp_client_socket_test.cc b/quiche/quic/core/io/event_loop_tcp_client_socket_test.cc
deleted file mode 100644
index 68343a5..0000000
--- a/quiche/quic/core/io/event_loop_tcp_client_socket_test.cc
+++ /dev/null
@@ -1,523 +0,0 @@
-// Copyright 2022 The Chromium Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style license that can be
-// found in the LICENSE file.
-
-#include "quiche/quic/core/io/event_loop_tcp_client_socket.h"
-
-#include <functional>
-#include <memory>
-#include <utility>
-#include <vector>
-
-#include "absl/functional/bind_front.h"
-#include "absl/status/status.h"
-#include "absl/status/statusor.h"
-#include "absl/strings/string_view.h"
-#include "absl/types/optional.h"
-#include "absl/types/span.h"
-#include "quiche/quic/core/connecting_client_socket.h"
-#include "quiche/quic/core/io/event_loop_socket_factory.h"
-#include "quiche/quic/core/io/quic_default_event_loop.h"
-#include "quiche/quic/core/io/quic_event_loop.h"
-#include "quiche/quic/core/io/socket.h"
-#include "quiche/quic/core/quic_time.h"
-#include "quiche/quic/platform/api/quic_ip_address_family.h"
-#include "quiche/quic/platform/api/quic_socket_address.h"
-#include "quiche/quic/test_tools/mock_clock.h"
-#include "quiche/quic/test_tools/quic_test_utils.h"
-#include "quiche/common/platform/api/quiche_logging.h"
-#include "quiche/common/platform/api/quiche_mem_slice.h"
-#include "quiche/common/platform/api/quiche_mutex.h"
-#include "quiche/common/platform/api/quiche_test.h"
-#include "quiche/common/platform/api/quiche_test_loopback.h"
-#include "quiche/common/platform/api/quiche_thread.h"
-#include "quiche/common/simple_buffer_allocator.h"
-
-namespace quic::test {
-namespace {
-
-bool CreateListeningServerSocket(SocketFd* out_socket_descriptor,
- QuicSocketAddress* out_socket_address) {
- QUICHE_CHECK(out_socket_descriptor);
- QUICHE_CHECK(out_socket_address);
-
- absl::StatusOr<SocketFd> socket = socket_api::CreateSocket(
- quiche::TestLoopback().address_family(), socket_api::SocketProtocol::kTcp,
- /*blocking=*/true);
- QUICHE_CHECK(socket.ok());
-
- // Set an extremely small receive buffer size to increase the odds of buffers
- // filling up when testing asynchronous writes.
- static const QuicByteCount kReceiveBufferSize = 2;
- absl::Status result =
- socket_api::SetReceiveBufferSize(socket.value(), kReceiveBufferSize);
- QUICHE_CHECK(result.ok());
-
- QuicSocketAddress bind_address(quiche::TestLoopback(), /*port=*/0);
- result = socket_api::Bind(socket.value(), bind_address);
- QUICHE_CHECK(result.ok());
-
- absl::StatusOr<QuicSocketAddress> socket_address =
- socket_api::GetSocketAddress(socket.value());
- QUICHE_CHECK(socket_address.ok());
-
- result = socket_api::Listen(socket.value(), /*backlog=*/1);
- QUICHE_CHECK(result.ok());
-
- *out_socket_descriptor = socket.value();
- *out_socket_address = std::move(socket_address).value();
- return true;
-}
-
-class TestTcpServerSocketRunner : public quiche::QuicheThread {
- public:
- using SocketBehavior = std::function<void(SocketFd connected_socket)>;
-
- // On construction, spins a separate thread to accept a connection from
- // `server_socket_descriptor`, runs `behavior` with that connection, and then
- // closes the accepted connection socket. If `allow_accept_failure` is true,
- // will silently stop if an error is encountered accepting the connection.
- TestTcpServerSocketRunner(SocketFd server_socket_descriptor,
- SocketBehavior behavior,
- bool allow_accept_failure = false)
- : QuicheThread("TestTcpServerSocketRunner"),
- server_socket_descriptor_(server_socket_descriptor),
- behavior_(std::move(behavior)),
- allow_accept_failure_(allow_accept_failure) {
- Start();
- }
-
- ~TestTcpServerSocketRunner() override { WaitForCompletion(); }
-
- void WaitForCompletion() { completion_notification_.WaitForNotification(); }
-
- protected:
- void Run() override {
- if (AcceptSocket()) {
- behavior_(connection_socket_descriptor_);
- CloseSocket();
- } else {
- QUICHE_CHECK(allow_accept_failure_);
- }
-
- completion_notification_.Notify();
- }
-
- private:
- bool AcceptSocket() {
- absl::StatusOr<socket_api::AcceptResult> connection_socket =
- socket_api::Accept(server_socket_descriptor_, /*blocking=*/true);
- if (connection_socket.ok()) {
- connection_socket_descriptor_ = connection_socket.value().fd;
- }
- return connection_socket.ok();
- }
-
- void CloseSocket() {
- QUICHE_CHECK(socket_api::Close(connection_socket_descriptor_).ok());
- }
-
- const SocketFd server_socket_descriptor_;
- const SocketBehavior behavior_;
- const bool allow_accept_failure_;
-
- SocketFd connection_socket_descriptor_;
-
- quiche::QuicheNotification completion_notification_;
-};
-
-class EventLoopTcpClientSocketTest
- : public quiche::test::QuicheTestWithParam<QuicEventLoopFactory*>,
- public ConnectingClientSocket::AsyncVisitor {
- public:
- void SetUp() override {
- QUICHE_CHECK(CreateListeningServerSocket(&server_socket_descriptor_,
- &server_socket_address_));
- }
-
- void TearDown() override {
- if (server_socket_descriptor_ != kInvalidSocketFd) {
- QUICHE_CHECK(socket_api::Close(server_socket_descriptor_).ok());
- }
- }
-
- void ConnectComplete(absl::Status status) override {
- QUICHE_CHECK(!connect_result_.has_value());
- connect_result_ = std::move(status);
- }
-
- void ReceiveComplete(absl::StatusOr<quiche::QuicheMemSlice> data) override {
- QUICHE_CHECK(!receive_result_.has_value());
- receive_result_ = std::move(data);
- }
-
- void SendComplete(absl::Status status) override {
- QUICHE_CHECK(!send_result_.has_value());
- send_result_ = std::move(status);
- }
-
- protected:
- SocketFd server_socket_descriptor_ = kInvalidSocketFd;
- QuicSocketAddress server_socket_address_;
-
- MockClock clock_;
- std::unique_ptr<QuicEventLoop> event_loop_ = GetParam()->Create(&clock_);
- EventLoopSocketFactory socket_factory_{event_loop_.get(),
- quiche::SimpleBufferAllocator::Get()};
-
- absl::optional<absl::Status> connect_result_;
- absl::optional<absl::StatusOr<quiche::QuicheMemSlice>> receive_result_;
- absl::optional<absl::Status> send_result_;
-};
-
-std::string GetTestParamName(
- ::testing::TestParamInfo<QuicEventLoopFactory*> info) {
- return EscapeTestParamName(info.param->GetName());
-}
-
-INSTANTIATE_TEST_SUITE_P(EventLoopTcpClientSocketTests,
- EventLoopTcpClientSocketTest,
- ::testing::ValuesIn(GetAllSupportedEventLoops()),
- &GetTestParamName);
-
-TEST_P(EventLoopTcpClientSocketTest, Connect) {
- std::unique_ptr<ConnectingClientSocket> socket =
- socket_factory_.CreateTcpClientSocket(server_socket_address_,
- /*receive_buffer_size=*/0,
- /*send_buffer_size=*/0,
- /*async_visitor=*/nullptr);
-
- // No socket runner to accept the connection for the server, but that is not
- // expected to be necessary for the connection to complete from the client.
- EXPECT_TRUE(socket->ConnectBlocking().ok());
-
- socket->Disconnect();
-}
-
-TEST_P(EventLoopTcpClientSocketTest, ConnectAsync) {
- std::unique_ptr<ConnectingClientSocket> socket =
- socket_factory_.CreateTcpClientSocket(server_socket_address_,
- /*receive_buffer_size=*/0,
- /*send_buffer_size=*/0,
- /*async_visitor=*/this);
-
- socket->ConnectAsync();
-
- // Synchronous completion not normally expected, but since there is no known
- // way to delay the server side of the connection (the OS does not wait for
- // an accept() call), cannot be gauranteed that the connection will always
- // complete asynchronously. If connecting asynchronously (normal behavior),
- // expect completion once signalled by the event loop.
- if (!connect_result_.has_value()) {
- event_loop_->RunEventLoopOnce(QuicTime::Delta::FromSeconds(1));
- ASSERT_TRUE(connect_result_.has_value());
- }
- EXPECT_TRUE(connect_result_.value().ok());
-
- connect_result_.reset();
- socket->Disconnect();
- EXPECT_FALSE(connect_result_.has_value());
-}
-
-TEST_P(EventLoopTcpClientSocketTest, ErrorBeforeConnectAsync) {
- std::unique_ptr<ConnectingClientSocket> socket =
- socket_factory_.CreateTcpClientSocket(server_socket_address_,
- /*receive_buffer_size=*/0,
- /*send_buffer_size=*/0,
- /*async_visitor=*/this);
-
- // Close the server socket.
- EXPECT_TRUE(socket_api::Close(server_socket_descriptor_).ok());
- server_socket_descriptor_ = kInvalidSocketFd;
-
- socket->ConnectAsync();
- if (!connect_result_.has_value()) {
- event_loop_->RunEventLoopOnce(QuicTime::Delta::FromSeconds(1));
- ASSERT_TRUE(connect_result_.has_value());
- }
-
- // Expect an error because server socket was closed before connection.
- EXPECT_FALSE(connect_result_.value().ok());
-}
-
-TEST_P(EventLoopTcpClientSocketTest, ErrorDuringConnectAsync) {
- std::unique_ptr<ConnectingClientSocket> socket =
- socket_factory_.CreateTcpClientSocket(server_socket_address_,
- /*receive_buffer_size=*/0,
- /*send_buffer_size=*/0,
- /*async_visitor=*/this);
-
- socket->ConnectAsync();
-
- if (connect_result_.has_value()) {
- // Not typical, but theoretically nothing to stop the connection from
- // completing before the server socket is closed to trigger the error.
- EXPECT_TRUE(connect_result_.value().ok());
- return;
- }
-
- // Close the server socket.
- EXPECT_TRUE(socket_api::Close(server_socket_descriptor_).ok());
- server_socket_descriptor_ = kInvalidSocketFd;
-
- // Expect an error once signalled.
- EXPECT_FALSE(connect_result_.has_value());
- event_loop_->RunEventLoopOnce(QuicTime::Delta::FromSeconds(1));
- ASSERT_TRUE(connect_result_.has_value());
- EXPECT_FALSE(connect_result_.value().ok());
-}
-
-TEST_P(EventLoopTcpClientSocketTest, Disconnect) {
- std::unique_ptr<ConnectingClientSocket> socket =
- socket_factory_.CreateTcpClientSocket(server_socket_address_,
- /*receive_buffer_size=*/0,
- /*send_buffer_size=*/0,
- /*async_visitor=*/nullptr);
-
- ASSERT_TRUE(socket->ConnectBlocking().ok());
- socket->Disconnect();
-}
-
-TEST_P(EventLoopTcpClientSocketTest, DisconnectCancelsConnectAsync) {
- std::unique_ptr<ConnectingClientSocket> socket =
- socket_factory_.CreateTcpClientSocket(server_socket_address_,
- /*receive_buffer_size=*/0,
- /*send_buffer_size=*/0,
- /*async_visitor=*/this);
-
- socket->ConnectAsync();
-
- if (connect_result_.has_value()) {
- // Not typical, but theoretically nothing to stop the connection from
- // completing before the server socket is closed to trigger the error.
- EXPECT_TRUE(connect_result_.value().ok());
- return;
- }
-
- socket->Disconnect();
-
- // Expect immediate cancelled error.
- ASSERT_TRUE(connect_result_.has_value());
- EXPECT_TRUE(absl::IsCancelled(connect_result_.value()));
-}
-
-TEST_P(EventLoopTcpClientSocketTest, ConnectAndReconnect) {
- std::unique_ptr<ConnectingClientSocket> socket =
- socket_factory_.CreateTcpClientSocket(server_socket_address_,
- /*receive_buffer_size=*/0,
- /*send_buffer_size=*/0,
- /*async_visitor=*/nullptr);
-
- ASSERT_TRUE(socket->ConnectBlocking().ok());
- socket->Disconnect();
-
- // Expect `socket` can reconnect now that it has been disconnected.
- EXPECT_TRUE(socket->ConnectBlocking().ok());
- socket->Disconnect();
-}
-
-void SendDataOnSocket(absl::string_view data, SocketFd connected_socket) {
- while (!data.empty()) {
- absl::StatusOr<absl::string_view> remainder =
- socket_api::Send(connected_socket, data);
- if (!remainder.ok()) {
- return;
- }
- data = remainder.value();
- }
-}
-
-TEST_P(EventLoopTcpClientSocketTest, Receive) {
- std::string expected = {1, 2, 3, 4, 5, 6, 7, 8};
- TestTcpServerSocketRunner runner(
- server_socket_descriptor_, absl::bind_front(&SendDataOnSocket, expected));
-
- std::unique_ptr<ConnectingClientSocket> socket =
- socket_factory_.CreateTcpClientSocket(server_socket_address_,
- /*receive_buffer_size=*/0,
- /*send_buffer_size=*/0,
- /*async_visitor=*/nullptr);
- ASSERT_TRUE(socket->ConnectBlocking().ok());
-
- std::string received;
- absl::StatusOr<quiche::QuicheMemSlice> data;
- do {
- data = socket->ReceiveBlocking(100);
- ASSERT_TRUE(data.ok());
- received.append(data.value().data(), data.value().length());
- } while (!data.value().empty());
- EXPECT_EQ(received, expected);
-
- socket->Disconnect();
-}
-
-TEST_P(EventLoopTcpClientSocketTest, ReceiveAsync) {
- std::unique_ptr<ConnectingClientSocket> socket =
- socket_factory_.CreateTcpClientSocket(server_socket_address_,
- /*receive_buffer_size=*/0,
- /*send_buffer_size=*/0,
- /*async_visitor=*/this);
- ASSERT_TRUE(socket->ConnectBlocking().ok());
-
- // Start an async receive. Expect no immediate results because runner not yet
- // setup to accept and send.
- socket->ReceiveAsync(100);
- EXPECT_FALSE(receive_result_.has_value());
-
- // Send data from server.
- std::string expected = {1, 2, 3, 4, 5, 6, 7, 8};
- TestTcpServerSocketRunner runner(
- server_socket_descriptor_, absl::bind_front(&SendDataOnSocket, expected));
- EXPECT_FALSE(receive_result_.has_value());
- for (int i = 0; i < 5 && !receive_result_.has_value(); ++i) {
- event_loop_->RunEventLoopOnce(QuicTime::Delta::FromSeconds(1));
- }
-
- // Expect to receive at least some of the sent data.
- ASSERT_TRUE(receive_result_.has_value());
- ASSERT_TRUE(receive_result_.value().ok());
- EXPECT_FALSE(receive_result_.value().value().empty());
- std::string received(receive_result_.value().value().data(),
- receive_result_.value().value().length());
-
- // Get any remaining data via blocking calls.
- absl::StatusOr<quiche::QuicheMemSlice> data;
- do {
- data = socket->ReceiveBlocking(100);
- ASSERT_TRUE(data.ok());
- received.append(data.value().data(), data.value().length());
- } while (!data.value().empty());
-
- EXPECT_EQ(received, expected);
-
- receive_result_.reset();
- socket->Disconnect();
- EXPECT_FALSE(receive_result_.has_value());
-}
-
-TEST_P(EventLoopTcpClientSocketTest, DisconnectCancelsReceiveAsync) {
- std::unique_ptr<ConnectingClientSocket> socket =
- socket_factory_.CreateTcpClientSocket(server_socket_address_,
- /*receive_buffer_size=*/0,
- /*send_buffer_size=*/0,
- /*async_visitor=*/this);
-
- ASSERT_TRUE(socket->ConnectBlocking().ok());
-
- // Start an asynchronous read, expecting no completion because server never
- // sends any data.
- socket->ReceiveAsync(100);
- EXPECT_FALSE(receive_result_.has_value());
-
- // Disconnect and expect an immediate cancelled error.
- socket->Disconnect();
- ASSERT_TRUE(receive_result_.has_value());
- ASSERT_FALSE(receive_result_.value().ok());
- EXPECT_TRUE(absl::IsCancelled(receive_result_.value().status()));
-}
-
-// Receive from `connected_socket` until connection is closed, writing received
-// data to `out_received`.
-void ReceiveDataFromSocket(std::string* out_received,
- SocketFd connected_socket) {
- out_received->clear();
-
- std::string buffer(100, 0);
- absl::StatusOr<absl::Span<char>> received;
- do {
- received = socket_api::Receive(connected_socket, absl::MakeSpan(buffer));
- QUICHE_CHECK(received.ok());
- out_received->insert(out_received->end(), received.value().begin(),
- received.value().end());
- } while (!received.value().empty());
-}
-
-TEST_P(EventLoopTcpClientSocketTest, Send) {
- std::string sent;
- TestTcpServerSocketRunner runner(
- server_socket_descriptor_,
- absl::bind_front(&ReceiveDataFromSocket, &sent));
-
- std::unique_ptr<ConnectingClientSocket> socket =
- socket_factory_.CreateTcpClientSocket(server_socket_address_,
- /*receive_buffer_size=*/0,
- /*send_buffer_size=*/0,
- /*async_visitor=*/nullptr);
- ASSERT_TRUE(socket->ConnectBlocking().ok());
-
- std::string expected = {1, 2, 3, 4, 5, 6, 7, 8};
- EXPECT_TRUE(socket->SendBlocking(expected).ok());
- socket->Disconnect();
-
- runner.WaitForCompletion();
- EXPECT_EQ(sent, expected);
-}
-
-TEST_P(EventLoopTcpClientSocketTest, SendAsync) {
- // Use a small send buffer to improve chances of a send needing to be
- // asynchronous.
- std::unique_ptr<ConnectingClientSocket> socket =
- socket_factory_.CreateTcpClientSocket(server_socket_address_,
- /*receive_buffer_size=*/0,
- /*send_buffer_size=*/4,
- /*async_visitor=*/this);
- ASSERT_TRUE(socket->ConnectBlocking().ok());
-
- std::string data = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10};
- std::string expected;
-
- // Repeatedly write to socket until it does not complete synchronously.
- do {
- expected.insert(expected.end(), data.begin(), data.end());
- send_result_.reset();
- socket->SendAsync(data);
- ASSERT_TRUE(!send_result_.has_value() || send_result_.value().ok());
- } while (send_result_.has_value());
-
- // Begin receiving from server and expect more data to send.
- std::string sent;
- TestTcpServerSocketRunner runner(
- server_socket_descriptor_,
- absl::bind_front(&ReceiveDataFromSocket, &sent));
- EXPECT_FALSE(send_result_.has_value());
- for (int i = 0; i < 5 && !send_result_.has_value(); ++i) {
- event_loop_->RunEventLoopOnce(QuicTime::Delta::FromSeconds(1));
- }
- ASSERT_TRUE(send_result_.has_value());
- EXPECT_TRUE(send_result_.value().ok());
-
- send_result_.reset();
- socket->Disconnect();
- EXPECT_FALSE(send_result_.has_value());
-
- runner.WaitForCompletion();
- EXPECT_EQ(sent, expected);
-}
-
-TEST_P(EventLoopTcpClientSocketTest, DisconnectCancelsSendAsync) {
- // Use a small send buffer to improve chances of a send needing to be
- // asynchronous.
- std::unique_ptr<ConnectingClientSocket> socket =
- socket_factory_.CreateTcpClientSocket(server_socket_address_,
- /*receive_buffer_size=*/0,
- /*send_buffer_size=*/4,
- /*async_visitor=*/this);
- ASSERT_TRUE(socket->ConnectBlocking().ok());
-
- std::string data = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10};
-
- // Repeatedly write to socket until it does not complete synchronously.
- do {
- send_result_.reset();
- socket->SendAsync(data);
- ASSERT_TRUE(!send_result_.has_value() || send_result_.value().ok());
- } while (send_result_.has_value());
-
- // Disconnect and expect immediate cancelled error.
- socket->Disconnect();
- ASSERT_TRUE(send_result_.has_value());
- EXPECT_TRUE(absl::IsCancelled(send_result_.value()));
-}
-
-} // namespace
-} // namespace quic::test
diff --git a/quiche/quic/core/io/socket.h b/quiche/quic/core/io/socket.h
index a591723..8d32bc3 100644
--- a/quiche/quic/core/io/socket.h
+++ b/quiche/quic/core/io/socket.h
@@ -40,6 +40,17 @@
kTcp,
};
+inline absl::string_view GetProtocolName(SocketProtocol protocol) {
+ switch (protocol) {
+ case SocketProtocol::kUdp:
+ return "UDP";
+ case SocketProtocol::kTcp:
+ return "TCP";
+ }
+
+ return "unknown";
+}
+
struct QUICHE_EXPORT_PRIVATE AcceptResult {
// Socket for interacting with the accepted connection.
SocketFd fd;
diff --git a/quiche/quic/core/socket_factory.h b/quiche/quic/core/socket_factory.h
index 4c473dc..d37499d 100644
--- a/quiche/quic/core/socket_factory.h
+++ b/quiche/quic/core/socket_factory.h
@@ -27,6 +27,19 @@
const quic::QuicSocketAddress& peer_address,
QuicByteCount receive_buffer_size, QuicByteCount send_buffer_size,
ConnectingClientSocket::AsyncVisitor* async_visitor) = 0;
+
+ // Will use platform default buffer size if `receive_buffer_size` or
+ // `send_buffer_size` is zero. If `async_visitor` is null, async operations
+ // must not be called on the created socket. If `async_visitor` is non-null,
+ // it must outlive the created socket.
+ //
+ // TODO(ericorth): Consider creating a sub-interface for connecting UDP
+ // sockets with additional functionality, e.g. sendto, if needed.
+ virtual std::unique_ptr<ConnectingClientSocket>
+ CreateConnectingUdpClientSocket(
+ const quic::QuicSocketAddress& peer_address,
+ QuicByteCount receive_buffer_size, QuicByteCount send_buffer_size,
+ ConnectingClientSocket::AsyncVisitor* async_visitor) = 0;
};
} // namespace quic
diff --git a/quiche/quic/tools/connect_tunnel_test.cc b/quiche/quic/tools/connect_tunnel_test.cc
index cb91e42..33753a1 100644
--- a/quiche/quic/tools/connect_tunnel_test.cc
+++ b/quiche/quic/tools/connect_tunnel_test.cc
@@ -69,6 +69,13 @@
QuicByteCount send_buffer_size,
ConnectingClientSocket::AsyncVisitor* async_visitor),
(override));
+ MOCK_METHOD(std::unique_ptr<ConnectingClientSocket>,
+ CreateConnectingUdpClientSocket,
+ (const quic::QuicSocketAddress& peer_address,
+ QuicByteCount receive_buffer_size,
+ QuicByteCount send_buffer_size,
+ ConnectingClientSocket::AsyncVisitor* async_visitor),
+ (override));
};
class MockSocket : public ConnectingClientSocket {
@@ -76,6 +83,8 @@
MOCK_METHOD(absl::Status, ConnectBlocking, (), (override));
MOCK_METHOD(void, ConnectAsync, (), (override));
MOCK_METHOD(void, Disconnect, (), (override));
+ MOCK_METHOD(absl::StatusOr<QuicSocketAddress>, GetLocalAddress, (),
+ (override));
MOCK_METHOD(absl::StatusOr<quiche::QuicheMemSlice>, ReceiveBlocking,
(QuicByteCount max_size), (override));
MOCK_METHOD(void, ReceiveAsync, (QuicByteCount max_size), (override));