Add an RAII wrapper for quic::SocketFd.

PiperOrigin-RevId: 845394682
diff --git a/quiche/quic/core/io/quic_server_io_harness.cc b/quiche/quic/core/io/quic_server_io_harness.cc
index de30e62..e63778c 100644
--- a/quiche/quic/core/io/quic_server_io_harness.cc
+++ b/quiche/quic/core/io/quic_server_io_harness.cc
@@ -24,20 +24,19 @@
 
 namespace quic {
 
-absl::StatusOr<SocketFd> CreateAndBindServerSocket(
+absl::StatusOr<OwnedSocketFd> CreateAndBindServerSocket(
     const QuicSocketAddress& bind_address) {
-  SocketFd fd = QuicUdpSocketApi().Create(
+  OwnedSocketFd fd(QuicUdpSocketApi().Create(
       bind_address.host().AddressFamilyToInt(),
       /*receive_buffer_size=*/kDefaultSocketReceiveBuffer,
-      /*send_buffer_size=*/kDefaultSocketReceiveBuffer);
-  if (fd == kQuicInvalidSocketFd) {
+      /*send_buffer_size=*/kDefaultSocketReceiveBuffer));
+  if (!fd.valid()) {
     return absl::InternalError("Failed to create socket");
   }
 
-  bool success = QuicUdpSocketApi().Bind(fd, bind_address);
+  bool success = QuicUdpSocketApi().Bind(*fd, bind_address);
   if (!success) {
-    (void)socket_api::Close(fd);
-    return socket_api::GetSocketError(fd);
+    return socket_api::GetSocketError(*fd);
   }
 
   return fd;
diff --git a/quiche/quic/core/io/quic_server_io_harness.h b/quiche/quic/core/io/quic_server_io_harness.h
index 38d8c66..01a8c83 100644
--- a/quiche/quic/core/io/quic_server_io_harness.h
+++ b/quiche/quic/core/io/quic_server_io_harness.h
@@ -21,7 +21,7 @@
 namespace quic {
 
 // Creates a UDP socket and binds it to the specified address.
-absl::StatusOr<SocketFd> CreateAndBindServerSocket(
+absl::StatusOr<OwnedSocketFd> CreateAndBindServerSocket(
     const QuicSocketAddress& bind_address);
 
 // QuicServerIoHarness registers itself with the provided event loop, reads
diff --git a/quiche/quic/core/io/socket.cc b/quiche/quic/core/io/socket.cc
index ca28df2..19f5646 100644
--- a/quiche/quic/core/io/socket.cc
+++ b/quiche/quic/core/io/socket.cc
@@ -24,7 +24,8 @@
 #include "quiche/quic/core/io/socket_posix.inc"
 #endif
 
-namespace quic::socket_api {
+namespace quic {
+namespace socket_api {
 
 namespace {
 
@@ -298,4 +299,20 @@
   }
 }
 
-}  // namespace quic::socket_api
+}  // namespace socket_api
+
+void OwnedSocketFd::reset() {
+  if (valid()) {
+    absl::Status status = socket_api::Close(fd_);
+    QUICHE_DLOG_IF(WARNING, !status.ok()) << "Failed to close FD: " << status;
+    fd_ = kInvalidSocketFd;
+  }
+}
+
+SocketFd OwnedSocketFd::release() {
+  SocketFd fd = fd_;
+  fd_ = kInvalidSocketFd;
+  return fd;
+}
+
+}  // namespace quic
diff --git a/quiche/quic/core/io/socket.h b/quiche/quic/core/io/socket.h
index dd49bd1..8debff7 100644
--- a/quiche/quic/core/io/socket.h
+++ b/quiche/quic/core/io/socket.h
@@ -31,6 +31,36 @@
 inline constexpr int kSocketErrorMsgSize = EMSGSIZE;
 #endif
 
+// An std::unique_ptr-like wrapper around SocketFd.
+class OwnedSocketFd {
+ public:
+  OwnedSocketFd() : fd_(kInvalidSocketFd) {}
+  explicit OwnedSocketFd(SocketFd fd) : fd_(fd) {}
+  ~OwnedSocketFd() { reset(); }
+
+  OwnedSocketFd(const OwnedSocketFd&) = delete;
+  OwnedSocketFd& operator=(const OwnedSocketFd&) = delete;
+
+  OwnedSocketFd(OwnedSocketFd&& other) noexcept {
+    fd_ = other.fd_;
+    other.fd_ = kInvalidSocketFd;
+  }
+  OwnedSocketFd& operator=(OwnedSocketFd&& other) noexcept {
+    reset();
+    std::swap(fd_, other.fd_);
+    return *this;
+  }
+
+  bool valid() const { return fd_ != kInvalidSocketFd; }
+  SocketFd get() const { return fd_; }
+  SocketFd operator*() const { return fd_; }
+  void reset();
+  SocketFd release();
+
+ private:
+  SocketFd fd_;
+};
+
 // Low-level platform-agnostic socket operations. Closely follows the behavior
 // of basic POSIX socket APIs, diverging mostly only to convert to/from cleaner
 // and platform-agnostic types.
diff --git a/quiche/quic/core/io/socket_test.cc b/quiche/quic/core/io/socket_test.cc
index 58d8136..98c0028 100644
--- a/quiche/quic/core/io/socket_test.cc
+++ b/quiche/quic/core/io/socket_test.cc
@@ -5,6 +5,7 @@
 #include "quiche/quic/core/io/socket.h"
 
 #include <string>
+#include <utility>
 
 #include "absl/status/status.h"
 #include "absl/status/statusor.h"
@@ -73,6 +74,10 @@
   }
 }
 
+bool IsValidSocket(SocketFd fd) {
+  return socket_api::GetSocketAddress(fd).ok();
+}
+
 TEST(SocketTest, CreateAndCloseSocket) {
   QuicIpAddress localhost_address = quiche::TestLoopback();
   absl::StatusOr<SocketFd> created_socket = socket_api::CreateSocket(
@@ -370,5 +375,66 @@
   QUICHE_EXPECT_OK(socket_api::Close(socket));
 }
 
+TEST(SocketTest, EmptyOwnedSocketFd) {
+  OwnedSocketFd owned_fd;
+  EXPECT_FALSE(owned_fd.valid());
+  EXPECT_EQ(*owned_fd, kInvalidSocketFd);
+}
+
+TEST(SocketTest, OwnedSocketFd) {
+  SocketFd fd = CreateTestSocket(socket_api::SocketProtocol::kUdp);
+  ASSERT_TRUE(IsValidSocket(fd));
+
+  OwnedSocketFd owned_fd(fd);
+  EXPECT_EQ(owned_fd.get(), fd);
+  EXPECT_TRUE(owned_fd.valid());
+  EXPECT_TRUE(IsValidSocket(fd));
+
+  OwnedSocketFd owned_fd2 = std::move(owned_fd);
+  EXPECT_EQ(owned_fd.get(),  // NOLINT(bugprone-use-after-move)
+            kInvalidSocketFd);
+  EXPECT_FALSE(owned_fd.valid());
+  EXPECT_EQ(owned_fd2.get(), fd);
+  EXPECT_TRUE(owned_fd2.valid());
+  EXPECT_TRUE(IsValidSocket(fd));
+
+  owned_fd2.reset();
+  EXPECT_EQ(owned_fd2.get(), kInvalidSocketFd);
+  EXPECT_FALSE(owned_fd2.valid());
+  EXPECT_FALSE(IsValidSocket(fd));
+}
+
+TEST(SocketTest, OwnedSocketFdMove) {
+  SocketFd fd = CreateTestSocket(socket_api::SocketProtocol::kUdp);
+  SocketFd fd2 = CreateTestSocket(socket_api::SocketProtocol::kUdp);
+  ASSERT_TRUE(IsValidSocket(fd));
+  ASSERT_TRUE(IsValidSocket(fd2));
+
+  OwnedSocketFd owned_fd(fd);
+  OwnedSocketFd owned_fd2(fd2);
+  owned_fd = std::move(owned_fd2);
+  EXPECT_TRUE(owned_fd.valid());
+  EXPECT_FALSE(owned_fd2.valid());  // NOLINT(bugprone-use-after-move)
+  EXPECT_FALSE(IsValidSocket(fd));
+  EXPECT_TRUE(IsValidSocket(fd2));
+}
+
+TEST(SocketTest, OwnedSocketFdRelease) {
+  SocketFd fd = CreateTestSocket(socket_api::SocketProtocol::kUdp);
+  ASSERT_TRUE(IsValidSocket(fd));
+
+  {
+    OwnedSocketFd owned_fd(fd);
+    EXPECT_TRUE(owned_fd.valid());
+
+    EXPECT_EQ(owned_fd.release(), fd);
+    ASSERT_TRUE(IsValidSocket(fd));
+    EXPECT_FALSE(owned_fd.valid());
+
+    owned_fd = OwnedSocketFd(fd);
+  }
+  EXPECT_FALSE(IsValidSocket(fd));
+}
+
 }  // namespace
 }  // namespace quic::test
diff --git a/quiche/quic/test_tools/quic_server_peer.cc b/quiche/quic/test_tools/quic_server_peer.cc
index 279a6a8..bf33eb3 100644
--- a/quiche/quic/test_tools/quic_server_peer.cc
+++ b/quiche/quic/test_tools/quic_server_peer.cc
@@ -15,7 +15,7 @@
 // static
 bool QuicServerPeer::SetSmallSocket(QuicServer* server) {
   int size = 1024 * 10;
-  return setsockopt(server->fd_, SOL_SOCKET, SO_RCVBUF,
+  return setsockopt(*server->fd_, SOL_SOCKET, SO_RCVBUF,
                     reinterpret_cast<char*>(&size), sizeof(size)) != -1;
 }
 
diff --git a/quiche/quic/tools/quic_server.cc b/quiche/quic/tools/quic_server.cc
index 1d82f62..44eed4c 100644
--- a/quiche/quic/tools/quic_server.cc
+++ b/quiche/quic/tools/quic_server.cc
@@ -106,12 +106,6 @@
 }
 
 QuicServer::~QuicServer() {
-  // Ensure the I/O harness is gone before closing the socket.
-  io_.reset();
-
-  (void)socket_api::Close(fd_);
-  fd_ = kInvalidSocketFd;
-
   // Should be fine without because nothing should send requests to the backend
   // after `this` is destroyed, but for extra pointer safety, clear the socket
   // factory from the backend before the socket factory is destroyed.
@@ -127,16 +121,16 @@
 
   dispatcher_.reset(CreateQuicDispatcher());
 
-  absl::StatusOr<SocketFd> fd = CreateAndBindServerSocket(address);
+  absl::StatusOr<OwnedSocketFd> fd = CreateAndBindServerSocket(address);
   if (!fd.ok()) {
-    QUIC_LOG(ERROR) << "Failed to create and bind socket: " << fd;
+    QUIC_LOG(ERROR) << "Failed to create and bind socket: " << fd.status();
     return false;
   }
-  fd_ = *fd;
-  dispatcher_->InitializeWithWriter(CreateWriter(fd_));
+  fd_ = *std::move(fd);
+  dispatcher_->InitializeWithWriter(CreateWriter(*fd_));
 
   absl::StatusOr<std::unique_ptr<QuicServerIoHarness>> io =
-      QuicServerIoHarness::Create(event_loop_.get(), dispatcher_.get(), fd_);
+      QuicServerIoHarness::Create(event_loop_.get(), dispatcher_.get(), *fd_);
   if (!io.ok()) {
     QUICHE_LOG(ERROR) << "Failed to create I/O harness: " << io.status();
     return false;
diff --git a/quiche/quic/tools/quic_server.h b/quiche/quic/tools/quic_server.h
index 56fb461..5d600c3 100644
--- a/quiche/quic/tools/quic_server.h
+++ b/quiche/quic/tools/quic_server.h
@@ -160,7 +160,7 @@
 
   DeterministicConnectionIdGenerator connection_id_generator_;
 
-  SocketFd fd_ = kInvalidSocketFd;
+  OwnedSocketFd fd_;
   std::unique_ptr<QuicServerIoHarness> io_;
 };