Add TCP connection to TUN experiment playground code PiperOrigin-RevId: 895477677
diff --git a/quiche/quic/qbone/bonnet/tun_device_integration_test.cc b/quiche/quic/qbone/bonnet/tun_device_integration_test.cc index c5c50d0..3a1a4ae 100644 --- a/quiche/quic/qbone/bonnet/tun_device_integration_test.cc +++ b/quiche/quic/qbone/bonnet/tun_device_integration_test.cc
@@ -3,15 +3,25 @@ // found in the LICENSE file. #include <linux/if_tun.h> +#include <netinet/icmp6.h> +#include <netinet/ip6.h> +#include <netinet/tcp.h> #include <cerrno> +#include <cstdint> +#include <cstring> #include <memory> #include <string> +#include <utility> #include <vector> #include "absl/status/status.h" #include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" +#include "absl/synchronization/notification.h" +#include "absl/time/clock.h" +#include "absl/time/time.h" #include "absl/types/span.h" #include "quiche/quic/core/crypto/quic_random.h" #include "quiche/quic/core/io/socket.h" @@ -19,12 +29,15 @@ #include "quiche/quic/platform/api/quic_ip_address_family.h" #include "quiche/quic/platform/api/quic_socket_address.h" #include "quiche/quic/platform/api/quic_test.h" +#include "quiche/quic/platform/api/quic_thread.h" #include "quiche/quic/qbone/bonnet/tun_device.h" #include "quiche/quic/qbone/bonnet/tun_device_controller.h" #include "quiche/quic/qbone/platform/ip_range.h" #include "quiche/quic/qbone/platform/kernel_interface.h" #include "quiche/quic/qbone/platform/netlink.h" #include "quiche/quic/test_tools/test_ip_packets.h" +#include "quiche/common/internet_checksum.h" +#include "quiche/common/platform/api/quiche_logging.h" namespace quic::test { namespace { @@ -56,7 +69,9 @@ std::unique_ptr<TunDeviceController> tun_device_controller_; }; -absl::Status SetNonBlocking(SocketFd fd) { +// Probably not necessary for TUN devices since TunTapDevice already opens the +// device in non-blocking mode, but good to make sure. 
+absl::Status SetNonBlocking(int fd) { int flags = ::fcntl(fd, F_GETFL, 0); if (flags < 0) { return absl::ErrnoToStatus(errno, "Failed to get flags"); @@ -71,7 +86,7 @@ ASSERT_TRUE(tun_device_->Init()); ASSERT_GT(tun_device_->GetWriteFileDescriptor(), -1); ASSERT_TRUE(tun_device_controller_->UpdateAddress( - {IpRange(local_address_, /*prefix_length=*/64)})); + IpRange(local_address_, /*prefix_length=*/64))); ASSERT_TRUE(tun_device_->Up()); int sndbuf = 500; @@ -115,7 +130,7 @@ ASSERT_TRUE(tun_device_->Init()); ASSERT_GT(tun_device_->GetWriteFileDescriptor(), -1); ASSERT_TRUE(tun_device_controller_->UpdateAddress( - {IpRange(local_address_, /*prefix_length=*/64)})); + IpRange(local_address_, /*prefix_length=*/64))); ASSERT_TRUE(tun_device_->Up()); QuicSocketAddress source_endpoint(remote_address_, /*port=*/53368); @@ -144,5 +159,342 @@ ASSERT_EQ(receive_data->size(), payload.size()); } +// Turns SO_LINGER on for `fd` with a linger time of zero seconds. Useful so a +// connected TCP socket can disappear immediately after the test is done rather +// than hanging around to wait for graceful TCP termination. Returns an +// errno-derived error status if setsockopt() fails, OK otherwise. +absl::Status DisableLinger(SocketFd fd) { + struct linger linger; + linger.l_onoff = 1; + linger.l_linger = 0; + if (::setsockopt(fd, SOL_SOCKET, SO_LINGER, &linger, sizeof(linger)) < 0) { + return absl::ErrnoToStatus(errno, "Failed to set SO_LINGER"); + } + return absl::OkStatus(); +} + +// Skips the IPv6 header and known extension headers (like hop-by-hop options).
+absl::Span<const uint8_t> SkipHeader(absl::Span<const uint8_t> packet, + uint8_t* out_next_header) { + QUICHE_CHECK_GE(packet.size(), sizeof(ip6_hdr)); + const ip6_hdr* ip_header = reinterpret_cast<const ip6_hdr*>(packet.data()); + QUICHE_CHECK_EQ(ip_header->ip6_vfc >> 4, 6); + *out_next_header = ip_header->ip6_nxt; + packet = packet.subspan(sizeof(ip6_hdr)); + while (*out_next_header == IPPROTO_HOPOPTS) { + QUICHE_CHECK_GE(packet.size(), sizeof(ip6_hbh)); + const ip6_hbh* hbh_header = reinterpret_cast<const ip6_hbh*>(packet.data()); + QUICHE_CHECK_GE(packet.size(), 8 + (8 * hbh_header->ip6h_len)); + *out_next_header = hbh_header->ip6h_nxt; + packet = packet.subspan(8 + (8 * hbh_header->ip6h_len)); + } + return packet; +} + +// Returns true if packet is an ignorable ICMPv6 packet that we simply don't +// care about. +bool IsGarbagePacket(absl::Span<const uint8_t> packet) { + uint8_t next_header; + absl::Span<const uint8_t> inner_packet = SkipHeader(packet, &next_header); + if (next_header != IPPROTO_ICMPV6) { + return false; + } + + QUICHE_CHECK_GE(inner_packet.size(), sizeof(icmp6_hdr)); + const icmp6_hdr* icmp6_header = + reinterpret_cast<const icmp6_hdr*>(inner_packet.data()); + switch (icmp6_header->icmp6_type) { + case ND_ROUTER_SOLICIT: + case 143: // Multicast Listener Discovery (MLDv2 Listener Report) + return true; + default: + QUICHE_CHECK(false) + << "Unexpected ICMPv6 type: " << icmp6_header->icmp6_type + << ". 
Should it be added to the garbage packet filter list?"; + return false; + } +} + +void SkipGarbagePackets(int file_descriptor, Kernel* kernel) { + std::vector<uint8_t> read_buffer(3200); + ssize_t bytes_read = -1; + do { + bytes_read = + kernel->read(file_descriptor, read_buffer.data(), read_buffer.size()); + } while (bytes_read > 0 && + IsGarbagePacket(absl::MakeSpan(read_buffer.data(), bytes_read))); + ASSERT_LT(bytes_read, 0); +} + +std::vector<uint8_t> ReadTcpPacket(int file_descriptor, Kernel* kernel, + absl::Duration timeout, + const QuicIpAddress& expected_source, + const QuicIpAddress& expected_destination) { + std::vector<uint8_t> read_buffer(3200); + ssize_t bytes_read = -1; + absl::Time deadline = absl::Now() + timeout; + while (true) { + bytes_read = + kernel->read(file_descriptor, read_buffer.data(), read_buffer.size()); + if (bytes_read > 0 && + IsGarbagePacket(absl::MakeSpan(read_buffer.data(), bytes_read))) { + continue; + } + + if (bytes_read > 0) { + break; + } + QUICHE_CHECK_EQ(errno, EWOULDBLOCK); + + if (absl::Now() > deadline) { + return {}; + } + + absl::SleepFor(absl::Milliseconds(1)); + } + + QUICHE_CHECK_GT(bytes_read, 0) + << "Failed to read packet: " << strerror(errno); + QUICHE_CHECK_GE(bytes_read, sizeof(ip6_hdr)); + const ip6_hdr* ip_header = + reinterpret_cast<const ip6_hdr*>(read_buffer.data()); + QUICHE_CHECK_EQ(ip_header->ip6_vfc >> 4, 6); + QUICHE_CHECK_EQ(QuicIpAddress(ip_header->ip6_src), expected_source); + QUICHE_CHECK_EQ(QuicIpAddress(ip_header->ip6_dst), expected_destination); + + uint8_t next_header; + absl::Span<const uint8_t> tcp_packet = + SkipHeader(absl::MakeSpan(read_buffer.data(), bytes_read), &next_header); + QUICHE_CHECK_EQ(next_header, IPPROTO_TCP); + QUICHE_CHECK_GE(tcp_packet.size(), sizeof(tcphdr)); + + return std::vector<uint8_t>(tcp_packet.begin(), tcp_packet.end()); +} + +void UpdateTcpChecksum(tcphdr* tcp_header, const QuicIpAddress& source_address, + const QuicIpAddress& destination_address, + 
absl::Span<const uint8_t> payload) { + quiche::InternetChecksum checksum; + checksum.Update(source_address.ToPackedString()); + checksum.Update(destination_address.ToPackedString()); + uint8_t protocol[] = {0x00, IPPROTO_TCP}; + checksum.Update(protocol, sizeof(protocol)); + uint16_t tcp_length = htons(sizeof(tcphdr) + payload.size()); + checksum.Update(reinterpret_cast<uint8_t*>(&tcp_length), sizeof(tcp_length)); + checksum.Update(reinterpret_cast<const uint8_t*>(tcp_header), sizeof(tcphdr)); + checksum.Update(payload.data(), payload.size()); + + tcp_header->check = checksum.Value(); +} + +class TcpReceiveThread : public QuicThread { + public: + TcpReceiveThread(OwnedSocketFd tcp_socket, absl::Notification* stop) + : QuicThread("TcpReceiveThread"), + tcp_socket_(std::move(tcp_socket)), + stop_(stop) {} + + void Run() override { + receive_buffer_.resize(3200); + + for (;;) { + if (stop_->HasBeenNotified()) { + break; + } + absl::StatusOr<absl::Span<char>> receive_data = socket_api::Receive( + tcp_socket_.get(), absl::MakeSpan(receive_buffer_)); + + if (receive_data.ok()) { + bytes_received_ += receive_data.value().size(); + } else { + QUICHE_CHECK_EQ(receive_data.status().code(), + absl::StatusCode::kUnavailable); + absl::SleepFor(absl::Milliseconds(1)); + } + } + } + + int bytes_received() const { return bytes_received_; } + + private: + OwnedSocketFd tcp_socket_; + std::vector<char> receive_buffer_; + int bytes_received_ = 0; + absl::Notification* stop_; +}; + +TEST_F(TunDeviceIntegrationTest, TcpConnection) { + ASSERT_TRUE(tun_device_->Init()); + ASSERT_GT(tun_device_->GetWriteFileDescriptor(), -1); + ASSERT_TRUE(tun_device_controller_->UpdateAddress( + IpRange(local_address_, /*prefix_length=*/64))); + ASSERT_TRUE(tun_device_->Up()); + ASSERT_TRUE(tun_device_controller_->UpdateRoutes( + IpRange(local_address_, /*prefix_length=*/64), + {IpRange(remote_address_, /*prefix_length=*/64)})); + + ASSERT_OK(SetNonBlocking(tun_device_->GetReadFileDescriptor())); + 
SkipGarbagePackets(tun_device_->GetReadFileDescriptor(), &kernel_); + + QuicSocketAddress client_endpoint(remote_address_, /*port=*/55171); + QuicSocketAddress server_endpoint(local_address_, /*port=*/60722); + + absl::StatusOr<SocketFd> tcp_socket = socket_api::CreateSocket( + IpAddressFamily::IP_V6, socket_api::SocketProtocol::kTcp, + /*blocking=*/false); + ASSERT_OK(tcp_socket); + OwnedSocketFd owned_tcp_socket(tcp_socket.value()); + + ASSERT_OK(socket_api::Bind(tcp_socket.value(), server_endpoint)); + ASSERT_OK(socket_api::Listen(tcp_socket.value(), 5)); + + tcphdr tcp_header; + ::memset(&tcp_header, 0, sizeof(tcp_header)); + tcp_header.source = htons(client_endpoint.port()); + tcp_header.dest = htons(server_endpoint.port()); + tcp_header.seq = htonl(142); + tcp_header.doff = 5; + tcp_header.syn = 1; + UpdateTcpChecksum(&tcp_header, client_endpoint.host(), server_endpoint.host(), + /*payload=*/{}); + std::string syn_packet = CreateIpPacket( + client_endpoint.host(), server_endpoint.host(), + absl::string_view(reinterpret_cast<const char*>(&tcp_header), + sizeof(tcp_header)), + IpPacketPayloadType::kTcp); + + ASSERT_EQ(kernel_.write(tun_device_->GetWriteFileDescriptor(), + syn_packet.data(), syn_packet.size()), + syn_packet.size()); + + std::vector<uint8_t> syn_ack_packet = ReadTcpPacket( + tun_device_->GetReadFileDescriptor(), &kernel_, absl::Seconds(10), + server_endpoint.host(), client_endpoint.host()); + ASSERT_GE(syn_ack_packet.size(), sizeof(tcphdr)); + const tcphdr* syn_ack_tcp_header = + reinterpret_cast<const tcphdr*>(syn_ack_packet.data()); + ASSERT_EQ(syn_ack_tcp_header->syn, 1); + ASSERT_EQ(syn_ack_tcp_header->ack, 1); + ASSERT_EQ(ntohl(syn_ack_tcp_header->ack_seq), 143); + + ::memset(&tcp_header, 0, sizeof(tcp_header)); + tcp_header.source = htons(client_endpoint.port()); + tcp_header.dest = htons(server_endpoint.port()); + tcp_header.seq = htonl(143); + tcp_header.ack_seq = htonl(ntohl(syn_ack_tcp_header->seq) + 1); + tcp_header.doff = 5; + 
tcp_header.ack = 1; + UpdateTcpChecksum(&tcp_header, client_endpoint.host(), server_endpoint.host(), + /*payload=*/{}); + std::string ack_packet = CreateIpPacket( + client_endpoint.host(), server_endpoint.host(), + absl::string_view(reinterpret_cast<const char*>(&tcp_header), + sizeof(tcp_header)), + IpPacketPayloadType::kTcp); + + ASSERT_EQ(kernel_.write(tun_device_->GetWriteFileDescriptor(), + ack_packet.data(), ack_packet.size()), + ack_packet.size()); + + absl::StatusOr<socket_api::AcceptResult> accept_result = + socket_api::Accept(tcp_socket.value(), /*blocking=*/false); + ASSERT_OK(accept_result); + OwnedSocketFd connected_socket(accept_result->fd); + ASSERT_OK(DisableLinger(connected_socket.get())); + ASSERT_EQ(accept_result->peer_address, client_endpoint); + + absl::Notification stop_receive_thread; + TcpReceiveThread tcp_receive_thread(std::move(connected_socket), + &stop_receive_thread); + tcp_receive_thread.Start(); + + ::memset(&tcp_header, 0, sizeof(tcp_header)); + tcp_header.source = htons(client_endpoint.port()); + tcp_header.dest = htons(server_endpoint.port()); + tcp_header.ack_seq = htonl(ntohl(syn_ack_tcp_header->seq) + 1); + tcp_header.doff = 5; + tcp_header.ack = 1; + std::string payload(100, 'a'); + int sequence_number = 143; + int highest_ack_seq = 0; + for (int i = 0; i < 1000000; ++i) { + tcp_header.seq = htonl(sequence_number); + tcp_header.check = 0; + UpdateTcpChecksum( + &tcp_header, client_endpoint.host(), server_endpoint.host(), + absl::MakeSpan(reinterpret_cast<const uint8_t*>(payload.data()), + payload.size())); + std::string combined_payload = absl::StrCat( + absl::string_view(reinterpret_cast<const char*>(&tcp_header), + sizeof(tcphdr)), + payload); + std::string packet = + CreateIpPacket(client_endpoint.host(), server_endpoint.host(), + combined_payload, IpPacketPayloadType::kTcp); + ASSERT_EQ(kernel_.write(tun_device_->GetWriteFileDescriptor(), + packet.data(), packet.size()), + packet.size()); + + sequence_number += 
payload.size(); + + bool zero_window = false; + for (;;) { + std::vector<uint8_t> response_packet = ReadTcpPacket( + tun_device_->GetReadFileDescriptor(), &kernel_, absl::ZeroDuration(), + server_endpoint.host(), client_endpoint.host()); + if (response_packet.empty()) { + break; + } + ASSERT_GE(response_packet.size(), sizeof(tcphdr)); + const tcphdr* response_tcp_header = + reinterpret_cast<const tcphdr*>(response_packet.data()); + ASSERT_EQ(response_tcp_header->ack, 1); + ASSERT_GE(ntohl(response_tcp_header->ack_seq), highest_ack_seq); + ASSERT_EQ(ntohl(response_tcp_header->seq), + ntohl(syn_ack_tcp_header->seq) + 1); + + if (ntohs(response_tcp_header->window) == 0) { + zero_window = true; + QUICHE_LOG(INFO) + << "Window is zero, stopping massive writes at iteration " << i; + break; + } + + if (ntohl(response_tcp_header->ack_seq) > highest_ack_seq) { + highest_ack_seq = ntohl(response_tcp_header->ack_seq); + if (highest_ack_seq > sequence_number) { + // Missed packets have been filled in, so we can jump back ahead. + sequence_number = highest_ack_seq; + } + } else { + // Revert and retry at missed sequence number. + sequence_number = ntohl(response_tcp_header->ack_seq); + } + } + + if (zero_window) { + break; + } + } + + for (;;) { + std::vector<uint8_t> response_packet = ReadTcpPacket( + tun_device_->GetReadFileDescriptor(), &kernel_, absl::Seconds(5), + server_endpoint.host(), client_endpoint.host()); + if (response_packet.empty()) { + break; + } + ASSERT_GE(response_packet.size(), sizeof(tcphdr)); + const tcphdr* response_tcp_header = + reinterpret_cast<const tcphdr*>(response_packet.data()); + ASSERT_EQ(response_tcp_header->ack, 1); + ASSERT_GE(ntohl(response_tcp_header->ack_seq), highest_ack_seq); + ASSERT_EQ(ntohl(response_tcp_header->seq), + ntohl(syn_ack_tcp_header->seq) + 1); + highest_ack_seq = ntohl(response_tcp_header->ack_seq); + } + + stop_receive_thread.Notify(); + tcp_receive_thread.Join(); +} + } // namespace } // namespace quic::test