Add an RAII wrapper for quic::SocketFd. PiperOrigin-RevId: 845394682
diff --git a/quiche/quic/core/io/quic_server_io_harness.cc b/quiche/quic/core/io/quic_server_io_harness.cc index de30e62..e63778c 100644 --- a/quiche/quic/core/io/quic_server_io_harness.cc +++ b/quiche/quic/core/io/quic_server_io_harness.cc
@@ -24,20 +24,19 @@ namespace quic { -absl::StatusOr<SocketFd> CreateAndBindServerSocket( +absl::StatusOr<OwnedSocketFd> CreateAndBindServerSocket( const QuicSocketAddress& bind_address) { - SocketFd fd = QuicUdpSocketApi().Create( + OwnedSocketFd fd(QuicUdpSocketApi().Create( bind_address.host().AddressFamilyToInt(), /*receive_buffer_size=*/kDefaultSocketReceiveBuffer, - /*send_buffer_size=*/kDefaultSocketReceiveBuffer); - if (fd == kQuicInvalidSocketFd) { + /*send_buffer_size=*/kDefaultSocketReceiveBuffer)); + if (!fd.valid()) { return absl::InternalError("Failed to create socket"); } - bool success = QuicUdpSocketApi().Bind(fd, bind_address); + bool success = QuicUdpSocketApi().Bind(*fd, bind_address); if (!success) { - (void)socket_api::Close(fd); - return socket_api::GetSocketError(fd); + return socket_api::GetSocketError(*fd); } return fd;
diff --git a/quiche/quic/core/io/quic_server_io_harness.h b/quiche/quic/core/io/quic_server_io_harness.h index 38d8c66..01a8c83 100644 --- a/quiche/quic/core/io/quic_server_io_harness.h +++ b/quiche/quic/core/io/quic_server_io_harness.h
@@ -21,7 +21,7 @@ namespace quic { // Creates a UDP socket and binds it to the specified address. -absl::StatusOr<SocketFd> CreateAndBindServerSocket( +absl::StatusOr<OwnedSocketFd> CreateAndBindServerSocket( const QuicSocketAddress& bind_address); // QuicServerIoHarness registers itself with the provided event loop, reads
diff --git a/quiche/quic/core/io/socket.cc b/quiche/quic/core/io/socket.cc index ca28df2..19f5646 100644 --- a/quiche/quic/core/io/socket.cc +++ b/quiche/quic/core/io/socket.cc
@@ -24,7 +24,8 @@ #include "quiche/quic/core/io/socket_posix.inc" #endif -namespace quic::socket_api { +namespace quic { +namespace socket_api { namespace { @@ -298,4 +299,20 @@ } } -} // namespace quic::socket_api +} // namespace socket_api + +void OwnedSocketFd::reset() { + if (valid()) { + absl::Status status = socket_api::Close(fd_); + QUICHE_DLOG_IF(WARNING, !status.ok()) << "Failed to close FD: " << status; + fd_ = kInvalidSocketFd; + } +} + +SocketFd OwnedSocketFd::release() { + SocketFd fd = fd_; + fd_ = kInvalidSocketFd; + return fd; +} + +} // namespace quic
diff --git a/quiche/quic/core/io/socket.h b/quiche/quic/core/io/socket.h index dd49bd1..8debff7 100644 --- a/quiche/quic/core/io/socket.h +++ b/quiche/quic/core/io/socket.h
@@ -31,6 +31,36 @@ inline constexpr int kSocketErrorMsgSize = EMSGSIZE; #endif +// An std::unique_ptr-like wrapper around SocketFd. +class OwnedSocketFd { + public: + OwnedSocketFd() : fd_(kInvalidSocketFd) {} + explicit OwnedSocketFd(SocketFd fd) : fd_(fd) {} + ~OwnedSocketFd() { reset(); } + + OwnedSocketFd(const OwnedSocketFd&) = delete; + OwnedSocketFd& operator=(const OwnedSocketFd&) = delete; + + OwnedSocketFd(OwnedSocketFd&& other) noexcept { + fd_ = other.fd_; + other.fd_ = kInvalidSocketFd; + } + OwnedSocketFd& operator=(OwnedSocketFd&& other) noexcept { + reset(); + std::swap(fd_, other.fd_); + return *this; + } + + bool valid() const { return fd_ != kInvalidSocketFd; } + SocketFd get() const { return fd_; } + SocketFd operator*() const { return fd_; } + void reset(); + SocketFd release(); + + private: + SocketFd fd_; +}; + // Low-level platform-agnostic socket operations. Closely follows the behavior // of basic POSIX socket APIs, diverging mostly only to convert to/from cleaner // and platform-agnostic types.
diff --git a/quiche/quic/core/io/socket_test.cc b/quiche/quic/core/io/socket_test.cc index 58d8136..98c0028 100644 --- a/quiche/quic/core/io/socket_test.cc +++ b/quiche/quic/core/io/socket_test.cc
@@ -5,6 +5,7 @@ #include "quiche/quic/core/io/socket.h" #include <string> +#include <utility> #include "absl/status/status.h" #include "absl/status/statusor.h" @@ -73,6 +74,10 @@ } } +bool IsValidSocket(SocketFd fd) { + return socket_api::GetSocketAddress(fd).ok(); +} + TEST(SocketTest, CreateAndCloseSocket) { QuicIpAddress localhost_address = quiche::TestLoopback(); absl::StatusOr<SocketFd> created_socket = socket_api::CreateSocket( @@ -370,5 +375,66 @@ QUICHE_EXPECT_OK(socket_api::Close(socket)); } +TEST(SocketTest, EmptyOwnedSocketFd) { + OwnedSocketFd owned_fd; + EXPECT_FALSE(owned_fd.valid()); + EXPECT_EQ(*owned_fd, kInvalidSocketFd); +} + +TEST(SocketTest, OwnedSocketFd) { + SocketFd fd = CreateTestSocket(socket_api::SocketProtocol::kUdp); + ASSERT_TRUE(IsValidSocket(fd)); + + OwnedSocketFd owned_fd(fd); + EXPECT_EQ(owned_fd.get(), fd); + EXPECT_TRUE(owned_fd.valid()); + EXPECT_TRUE(IsValidSocket(fd)); + + OwnedSocketFd owned_fd2 = std::move(owned_fd); + EXPECT_EQ(owned_fd.get(), // NOLINT(bugprone-use-after-move) + kInvalidSocketFd); + EXPECT_FALSE(owned_fd.valid()); + EXPECT_EQ(owned_fd2.get(), fd); + EXPECT_TRUE(owned_fd2.valid()); + EXPECT_TRUE(IsValidSocket(fd)); + + owned_fd2.reset(); + EXPECT_EQ(owned_fd2.get(), kInvalidSocketFd); + EXPECT_FALSE(owned_fd2.valid()); + EXPECT_FALSE(IsValidSocket(fd)); +} + +TEST(SocketTest, OwnedSocketFdMove) { + SocketFd fd = CreateTestSocket(socket_api::SocketProtocol::kUdp); + SocketFd fd2 = CreateTestSocket(socket_api::SocketProtocol::kUdp); + ASSERT_TRUE(IsValidSocket(fd)); + ASSERT_TRUE(IsValidSocket(fd2)); + + OwnedSocketFd owned_fd(fd); + OwnedSocketFd owned_fd2(fd2); + owned_fd = std::move(owned_fd2); + EXPECT_TRUE(owned_fd.valid()); + EXPECT_FALSE(owned_fd2.valid()); // NOLINT(bugprone-use-after-move) + EXPECT_FALSE(IsValidSocket(fd)); + EXPECT_TRUE(IsValidSocket(fd2)); +} + +TEST(SocketTest, OwnedSocketFdRelease) { + SocketFd fd = CreateTestSocket(socket_api::SocketProtocol::kUdp); + ASSERT_TRUE(IsValidSocket(fd)); + + { + OwnedSocketFd owned_fd(fd); + EXPECT_TRUE(owned_fd.valid()); + + EXPECT_EQ(owned_fd.release(), fd); + ASSERT_TRUE(IsValidSocket(fd)); + EXPECT_FALSE(owned_fd.valid()); + + owned_fd = OwnedSocketFd(fd); + } + EXPECT_FALSE(IsValidSocket(fd)); +} + } // namespace } // namespace quic::test
diff --git a/quiche/quic/test_tools/quic_server_peer.cc b/quiche/quic/test_tools/quic_server_peer.cc index 279a6a8..bf33eb3 100644 --- a/quiche/quic/test_tools/quic_server_peer.cc +++ b/quiche/quic/test_tools/quic_server_peer.cc
@@ -15,7 +15,7 @@ // static bool QuicServerPeer::SetSmallSocket(QuicServer* server) { int size = 1024 * 10; - return setsockopt(server->fd_, SOL_SOCKET, SO_RCVBUF, + return setsockopt(*server->fd_, SOL_SOCKET, SO_RCVBUF, reinterpret_cast<char*>(&size), sizeof(size)) != -1; }
diff --git a/quiche/quic/tools/quic_server.cc b/quiche/quic/tools/quic_server.cc index 1d82f62..44eed4c 100644 --- a/quiche/quic/tools/quic_server.cc +++ b/quiche/quic/tools/quic_server.cc
@@ -106,12 +106,6 @@ } QuicServer::~QuicServer() { - // Ensure the I/O harness is gone before closing the socket. - io_.reset(); - - (void)socket_api::Close(fd_); - fd_ = kInvalidSocketFd; - // Should be fine without because nothing should send requests to the backend // after `this` is destroyed, but for extra pointer safety, clear the socket // factory from the backend before the socket factory is destroyed. @@ -127,16 +121,16 @@ dispatcher_.reset(CreateQuicDispatcher()); - absl::StatusOr<SocketFd> fd = CreateAndBindServerSocket(address); + absl::StatusOr<OwnedSocketFd> fd = CreateAndBindServerSocket(address); if (!fd.ok()) { - QUIC_LOG(ERROR) << "Failed to create and bind socket: " << fd; + QUIC_LOG(ERROR) << "Failed to create and bind socket: " << fd.status(); return false; } - fd_ = *fd; - dispatcher_->InitializeWithWriter(CreateWriter(fd_)); + fd_ = *std::move(fd); + dispatcher_->InitializeWithWriter(CreateWriter(*fd_)); absl::StatusOr<std::unique_ptr<QuicServerIoHarness>> io = - QuicServerIoHarness::Create(event_loop_.get(), dispatcher_.get(), fd_); + QuicServerIoHarness::Create(event_loop_.get(), dispatcher_.get(), *fd_); if (!io.ok()) { QUICHE_LOG(ERROR) << "Failed to create I/O harness: " << io.status(); return false;
diff --git a/quiche/quic/tools/quic_server.h b/quiche/quic/tools/quic_server.h index 56fb461..5d600c3 100644 --- a/quiche/quic/tools/quic_server.h +++ b/quiche/quic/tools/quic_server.h
@@ -160,7 +160,7 @@ DeterministicConnectionIdGenerator connection_id_generator_; - SocketFd fd_ = kInvalidSocketFd; + OwnedSocketFd fd_; std::unique_ptr<QuicServerIoHarness> io_; };