Move TestPacketWriter from quic_connection_test.cc to quic_test_utils.
PiperOrigin-RevId: 334408019
Change-Id: I3fc479fea2af0e1e2e070bfd0c120ed732061586
diff --git a/quic/core/quic_connection_test.cc b/quic/core/quic_connection_test.cc
index 2c791f2..1b41e8e 100644
--- a/quic/core/quic_connection_test.cc
+++ b/quic/core/quic_connection_test.cc
@@ -24,7 +24,6 @@
#include "net/third_party/quiche/src/quic/core/quic_types.h"
#include "net/third_party/quiche/src/quic/core/quic_utils.h"
#include "net/third_party/quiche/src/quic/core/quic_versions.h"
-#include "net/third_party/quiche/src/quic/platform/api/quic_error_code_wrappers.h"
#include "net/third_party/quiche/src/quic/platform/api/quic_expect_bug.h"
#include "net/third_party/quiche/src/quic/platform/api/quic_flags.h"
#include "net/third_party/quiche/src/quic/platform/api/quic_logging.h"
@@ -40,7 +39,6 @@
#include "net/third_party/quiche/src/quic/test_tools/quic_sent_packet_manager_peer.h"
#include "net/third_party/quiche/src/quic/test_tools/quic_test_utils.h"
#include "net/third_party/quiche/src/quic/test_tools/simple_data_producer.h"
-#include "net/third_party/quiche/src/quic/test_tools/simple_quic_framer.h"
#include "net/third_party/quiche/src/quic/test_tools/simple_session_notifier.h"
#include "net/third_party/quiche/src/common/platform/api/quiche_arraysize.h"
#include "net/third_party/quiche/src/common/platform/api/quiche_str_cat.h"
@@ -110,162 +108,6 @@
}
}
-// TaggingEncrypter appends kTagSize bytes of |tag| to the end of each message.
-class TaggingEncrypter : public QuicEncrypter {
- public:
- explicit TaggingEncrypter(uint8_t tag) : tag_(tag) {}
- TaggingEncrypter(const TaggingEncrypter&) = delete;
- TaggingEncrypter& operator=(const TaggingEncrypter&) = delete;
-
- ~TaggingEncrypter() override {}
-
- // QuicEncrypter interface.
- bool SetKey(quiche::QuicheStringPiece /*key*/) override { return true; }
-
- bool SetNoncePrefix(quiche::QuicheStringPiece /*nonce_prefix*/) override {
- return true;
- }
-
- bool SetIV(quiche::QuicheStringPiece /*iv*/) override { return true; }
-
- bool SetHeaderProtectionKey(quiche::QuicheStringPiece /*key*/) override {
- return true;
- }
-
- bool EncryptPacket(uint64_t /*packet_number*/,
- quiche::QuicheStringPiece /*associated_data*/,
- quiche::QuicheStringPiece plaintext,
- char* output,
- size_t* output_length,
- size_t max_output_length) override {
- const size_t len = plaintext.size() + kTagSize;
- if (max_output_length < len) {
- return false;
- }
- // Memmove is safe for inplace encryption.
- memmove(output, plaintext.data(), plaintext.size());
- output += plaintext.size();
- memset(output, tag_, kTagSize);
- *output_length = len;
- return true;
- }
-
- std::string GenerateHeaderProtectionMask(
- quiche::QuicheStringPiece /*sample*/) override {
- return std::string(5, 0);
- }
-
- size_t GetKeySize() const override { return 0; }
- size_t GetNoncePrefixSize() const override { return 0; }
- size_t GetIVSize() const override { return 0; }
-
- size_t GetMaxPlaintextSize(size_t ciphertext_size) const override {
- return ciphertext_size - kTagSize;
- }
-
- size_t GetCiphertextSize(size_t plaintext_size) const override {
- return plaintext_size + kTagSize;
- }
-
- quiche::QuicheStringPiece GetKey() const override {
- return quiche::QuicheStringPiece();
- }
-
- quiche::QuicheStringPiece GetNoncePrefix() const override {
- return quiche::QuicheStringPiece();
- }
-
- private:
- enum {
- kTagSize = 12,
- };
-
- const uint8_t tag_;
-};
-
-// TaggingDecrypter ensures that the final kTagSize bytes of the message all
-// have the same value and then removes them.
-class TaggingDecrypter : public QuicDecrypter {
- public:
- ~TaggingDecrypter() override {}
-
- // QuicDecrypter interface
- bool SetKey(quiche::QuicheStringPiece /*key*/) override { return true; }
-
- bool SetNoncePrefix(quiche::QuicheStringPiece /*nonce_prefix*/) override {
- return true;
- }
-
- bool SetIV(quiche::QuicheStringPiece /*iv*/) override { return true; }
-
- bool SetHeaderProtectionKey(quiche::QuicheStringPiece /*key*/) override {
- return true;
- }
-
- bool SetPreliminaryKey(quiche::QuicheStringPiece /*key*/) override {
- QUIC_BUG << "should not be called";
- return false;
- }
-
- bool SetDiversificationNonce(const DiversificationNonce& /*key*/) override {
- return true;
- }
-
- bool DecryptPacket(uint64_t /*packet_number*/,
- quiche::QuicheStringPiece /*associated_data*/,
- quiche::QuicheStringPiece ciphertext,
- char* output,
- size_t* output_length,
- size_t /*max_output_length*/) override {
- if (ciphertext.size() < kTagSize) {
- return false;
- }
- if (!CheckTag(ciphertext, GetTag(ciphertext))) {
- return false;
- }
- *output_length = ciphertext.size() - kTagSize;
- memcpy(output, ciphertext.data(), *output_length);
- return true;
- }
-
- std::string GenerateHeaderProtectionMask(
- QuicDataReader* /*sample_reader*/) override {
- return std::string(5, 0);
- }
-
- size_t GetKeySize() const override { return 0; }
- size_t GetNoncePrefixSize() const override { return 0; }
- size_t GetIVSize() const override { return 0; }
- quiche::QuicheStringPiece GetKey() const override {
- return quiche::QuicheStringPiece();
- }
- quiche::QuicheStringPiece GetNoncePrefix() const override {
- return quiche::QuicheStringPiece();
- }
- // Use a distinct value starting with 0xFFFFFF, which is never used by TLS.
- uint32_t cipher_id() const override { return 0xFFFFFFF0; }
-
- protected:
- virtual uint8_t GetTag(quiche::QuicheStringPiece ciphertext) {
- return ciphertext.data()[ciphertext.size() - 1];
- }
-
- private:
- enum {
- kTagSize = 12,
- };
-
- bool CheckTag(quiche::QuicheStringPiece ciphertext, uint8_t tag) {
- for (size_t i = ciphertext.size() - kTagSize; i < ciphertext.size(); i++) {
- if (ciphertext.data()[i] != tag) {
- return false;
- }
- }
-
- return true;
- }
-};
-
// StringTaggingDecrypter ensures that the final kTagSize bytes of the message
// match the expected value.
class StrictTaggingDecrypter : public TaggingDecrypter {
@@ -336,352 +178,6 @@
}
};
-class TestPacketWriter : public QuicPacketWriter {
- struct PacketBuffer {
- QUIC_CACHELINE_ALIGNED char buffer[1500];
- bool in_use = false;
- };
-
- public:
- TestPacketWriter(ParsedQuicVersion version, MockClock* clock)
- : version_(version),
- framer_(SupportedVersions(version_), Perspective::IS_SERVER),
- clock_(clock) {
- QuicFramerPeer::SetLastSerializedServerConnectionId(framer_.framer(),
- TestConnectionId());
- framer_.framer()->SetInitialObfuscators(TestConnectionId());
-
- for (int i = 0; i < 128; ++i) {
- PacketBuffer* p = new PacketBuffer();
- packet_buffer_pool_.push_back(p);
- packet_buffer_pool_index_[p->buffer] = p;
- packet_buffer_free_list_.push_back(p);
- }
- }
- TestPacketWriter(const TestPacketWriter&) = delete;
- TestPacketWriter& operator=(const TestPacketWriter&) = delete;
-
- ~TestPacketWriter() override {
- EXPECT_EQ(packet_buffer_pool_.size(), packet_buffer_free_list_.size())
- << packet_buffer_pool_.size() - packet_buffer_free_list_.size()
- << " out of " << packet_buffer_pool_.size()
- << " packet buffers have been leaked.";
- for (auto p : packet_buffer_pool_) {
- delete p;
- }
- }
-
- // QuicPacketWriter interface
- WriteResult WritePacket(const char* buffer,
- size_t buf_len,
- const QuicIpAddress& /*self_address*/,
- const QuicSocketAddress& peer_address,
- PerPacketOptions* /*options*/) override {
- last_write_peer_address_ = peer_address;
- // If the buffer is allocated from the pool, return it back to the pool.
- // Note the buffer content doesn't change.
- if (packet_buffer_pool_index_.find(const_cast<char*>(buffer)) !=
- packet_buffer_pool_index_.end()) {
- FreePacketBuffer(buffer);
- }
-
- QuicEncryptedPacket packet(buffer, buf_len);
- ++packets_write_attempts_;
-
- if (packet.length() >= sizeof(final_bytes_of_last_packet_)) {
- final_bytes_of_previous_packet_ = final_bytes_of_last_packet_;
- memcpy(&final_bytes_of_last_packet_, packet.data() + packet.length() - 4,
- sizeof(final_bytes_of_last_packet_));
- }
-
- if (use_tagging_decrypter_) {
- if (framer_.framer()->version().KnowsWhichDecrypterToUse()) {
- framer_.framer()->InstallDecrypter(
- ENCRYPTION_INITIAL, std::make_unique<TaggingDecrypter>());
- framer_.framer()->InstallDecrypter(
- ENCRYPTION_HANDSHAKE, std::make_unique<TaggingDecrypter>());
- framer_.framer()->InstallDecrypter(
- ENCRYPTION_ZERO_RTT, std::make_unique<TaggingDecrypter>());
- framer_.framer()->InstallDecrypter(
- ENCRYPTION_FORWARD_SECURE, std::make_unique<TaggingDecrypter>());
- } else {
- framer_.framer()->SetDecrypter(ENCRYPTION_INITIAL,
- std::make_unique<TaggingDecrypter>());
- }
- } else if (framer_.framer()->version().KnowsWhichDecrypterToUse()) {
- framer_.framer()->InstallDecrypter(
- ENCRYPTION_FORWARD_SECURE,
- std::make_unique<NullDecrypter>(Perspective::IS_SERVER));
- }
- EXPECT_TRUE(framer_.ProcessPacket(packet))
- << framer_.framer()->detailed_error();
- if (block_on_next_write_) {
- write_blocked_ = true;
- block_on_next_write_ = false;
- }
- if (next_packet_too_large_) {
- next_packet_too_large_ = false;
- return WriteResult(WRITE_STATUS_ERROR, QUIC_EMSGSIZE);
- }
- if (always_get_packet_too_large_) {
- return WriteResult(WRITE_STATUS_ERROR, QUIC_EMSGSIZE);
- }
- if (IsWriteBlocked()) {
- return WriteResult(is_write_blocked_data_buffered_
- ? WRITE_STATUS_BLOCKED_DATA_BUFFERED
- : WRITE_STATUS_BLOCKED,
- 0);
- }
-
- if (ShouldWriteFail()) {
- return WriteResult(WRITE_STATUS_ERROR, 0);
- }
-
- last_packet_size_ = packet.length();
- last_packet_header_ = framer_.header();
- if (!framer_.connection_close_frames().empty()) {
- ++connection_close_packets_;
- }
- if (!write_pause_time_delta_.IsZero()) {
- clock_->AdvanceTime(write_pause_time_delta_);
- }
- if (is_batch_mode_) {
- bytes_buffered_ += last_packet_size_;
- return WriteResult(WRITE_STATUS_OK, 0);
- }
- return WriteResult(WRITE_STATUS_OK, last_packet_size_);
- }
-
- bool ShouldWriteFail() { return write_should_fail_; }
-
- bool IsWriteBlocked() const override { return write_blocked_; }
-
- void SetWriteBlocked() { write_blocked_ = true; }
-
- void SetWritable() override { write_blocked_ = false; }
-
- void SetShouldWriteFail() { write_should_fail_ = true; }
-
- QuicByteCount GetMaxPacketSize(
- const QuicSocketAddress& /*peer_address*/) const override {
- return max_packet_size_;
- }
-
- bool SupportsReleaseTime() const override { return supports_release_time_; }
-
- bool IsBatchMode() const override { return is_batch_mode_; }
-
- QuicPacketBuffer GetNextWriteLocation(
- const QuicIpAddress& /*self_address*/,
- const QuicSocketAddress& /*peer_address*/) override {
- return {AllocPacketBuffer(),
- [this](const char* p) { FreePacketBuffer(p); }};
- }
-
- WriteResult Flush() override {
- flush_attempts_++;
- if (block_on_next_flush_) {
- block_on_next_flush_ = false;
- SetWriteBlocked();
- return WriteResult(WRITE_STATUS_BLOCKED, /*errno*/ -1);
- }
- if (write_should_fail_) {
- return WriteResult(WRITE_STATUS_ERROR, /*errno*/ -1);
- }
- int bytes_flushed = bytes_buffered_;
- bytes_buffered_ = 0;
- return WriteResult(WRITE_STATUS_OK, bytes_flushed);
- }
-
- void BlockOnNextFlush() { block_on_next_flush_ = true; }
-
- void BlockOnNextWrite() { block_on_next_write_ = true; }
-
- void SimulateNextPacketTooLarge() { next_packet_too_large_ = true; }
-
- void AlwaysGetPacketTooLarge() { always_get_packet_too_large_ = true; }
-
- // Sets the amount of time that the writer should before the actual write.
- void SetWritePauseTimeDelta(QuicTime::Delta delta) {
- write_pause_time_delta_ = delta;
- }
-
- void SetBatchMode(bool new_value) { is_batch_mode_ = new_value; }
-
- const QuicPacketHeader& header() { return framer_.header(); }
-
- size_t frame_count() const { return framer_.num_frames(); }
-
- const std::vector<QuicAckFrame>& ack_frames() const {
- return framer_.ack_frames();
- }
-
- const std::vector<QuicStopWaitingFrame>& stop_waiting_frames() const {
- return framer_.stop_waiting_frames();
- }
-
- const std::vector<QuicConnectionCloseFrame>& connection_close_frames() const {
- return framer_.connection_close_frames();
- }
-
- const std::vector<QuicRstStreamFrame>& rst_stream_frames() const {
- return framer_.rst_stream_frames();
- }
-
- const std::vector<std::unique_ptr<QuicStreamFrame>>& stream_frames() const {
- return framer_.stream_frames();
- }
-
- const std::vector<std::unique_ptr<QuicCryptoFrame>>& crypto_frames() const {
- return framer_.crypto_frames();
- }
-
- const std::vector<QuicPingFrame>& ping_frames() const {
- return framer_.ping_frames();
- }
-
- const std::vector<QuicMessageFrame>& message_frames() const {
- return framer_.message_frames();
- }
-
- const std::vector<QuicWindowUpdateFrame>& window_update_frames() const {
- return framer_.window_update_frames();
- }
-
- const std::vector<QuicPaddingFrame>& padding_frames() const {
- return framer_.padding_frames();
- }
-
- const std::vector<QuicPathChallengeFrame>& path_challenge_frames() const {
- return framer_.path_challenge_frames();
- }
-
- const std::vector<QuicPathResponseFrame>& path_response_frames() const {
- return framer_.path_response_frames();
- }
-
- const QuicEncryptedPacket* coalesced_packet() const {
- return framer_.coalesced_packet();
- }
-
- size_t last_packet_size() { return last_packet_size_; }
-
- const QuicPacketHeader& last_packet_header() const {
- return last_packet_header_;
- }
-
- const QuicVersionNegotiationPacket* version_negotiation_packet() {
- return framer_.version_negotiation_packet();
- }
-
- void set_is_write_blocked_data_buffered(bool buffered) {
- is_write_blocked_data_buffered_ = buffered;
- }
-
- void set_perspective(Perspective perspective) {
- // We invert perspective here, because the framer needs to parse packets
- // we send.
- QuicFramerPeer::SetPerspective(framer_.framer(),
- QuicUtils::InvertPerspective(perspective));
- }
-
- // final_bytes_of_last_packet_ returns the last four bytes of the previous
- // packet as a little-endian, uint32_t. This is intended to be used with a
- // TaggingEncrypter so that tests can determine which encrypter was used for
- // a given packet.
- uint32_t final_bytes_of_last_packet() { return final_bytes_of_last_packet_; }
-
- // Returns the final bytes of the second to last packet.
- uint32_t final_bytes_of_previous_packet() {
- return final_bytes_of_previous_packet_;
- }
-
- void use_tagging_decrypter() { use_tagging_decrypter_ = true; }
-
- uint32_t packets_write_attempts() const { return packets_write_attempts_; }
-
- uint32_t flush_attempts() const { return flush_attempts_; }
-
- uint32_t connection_close_packets() const {
- return connection_close_packets_;
- }
-
- void Reset() { framer_.Reset(); }
-
- void SetSupportedVersions(const ParsedQuicVersionVector& versions) {
- framer_.SetSupportedVersions(versions);
- }
-
- void set_max_packet_size(QuicByteCount max_packet_size) {
- max_packet_size_ = max_packet_size;
- }
-
- void set_supports_release_time(bool supports_release_time) {
- supports_release_time_ = supports_release_time;
- }
-
- SimpleQuicFramer* framer() { return &framer_; }
-
- const QuicSocketAddress& last_write_peer_address() const {
- return last_write_peer_address_;
- }
-
- private:
- char* AllocPacketBuffer() {
- PacketBuffer* p = packet_buffer_free_list_.front();
- EXPECT_FALSE(p->in_use);
- p->in_use = true;
- packet_buffer_free_list_.pop_front();
- return p->buffer;
- }
-
- void FreePacketBuffer(const char* buffer) {
- auto iter = packet_buffer_pool_index_.find(const_cast<char*>(buffer));
- ASSERT_TRUE(iter != packet_buffer_pool_index_.end());
- PacketBuffer* p = iter->second;
- ASSERT_TRUE(p->in_use);
- p->in_use = false;
- packet_buffer_free_list_.push_back(p);
- }
-
- ParsedQuicVersion version_;
- SimpleQuicFramer framer_;
- size_t last_packet_size_ = 0;
- QuicPacketHeader last_packet_header_;
- bool write_blocked_ = false;
- bool write_should_fail_ = false;
- bool block_on_next_flush_ = false;
- bool block_on_next_write_ = false;
- bool next_packet_too_large_ = false;
- bool always_get_packet_too_large_ = false;
- bool is_write_blocked_data_buffered_ = false;
- bool is_batch_mode_ = false;
- // Number of times Flush() was called.
- uint32_t flush_attempts_ = 0;
- // (Batch mode only) Number of bytes buffered in writer. It is used as the
- // return value of a successful Flush().
- uint32_t bytes_buffered_ = 0;
- uint32_t final_bytes_of_last_packet_ = 0;
- uint32_t final_bytes_of_previous_packet_ = 0;
- bool use_tagging_decrypter_ = false;
- uint32_t packets_write_attempts_ = 0;
- uint32_t connection_close_packets_ = 0;
- MockClock* clock_ = nullptr;
- // If non-zero, the clock will pause during WritePacket for this amount of
- // time.
- QuicTime::Delta write_pause_time_delta_ = QuicTime::Delta::Zero();
- QuicByteCount max_packet_size_ = kMaxOutgoingPacketSize;
- bool supports_release_time_ = false;
- // Used to verify writer-allocated packet buffers are properly released.
- std::vector<PacketBuffer*> packet_buffer_pool_;
- // Buffer address => Address of the owning PacketBuffer.
- QuicHashMap<char*, PacketBuffer*> packet_buffer_pool_index_;
- // Indices in packet_buffer_pool_ that are not allocated.
- std::list<PacketBuffer*> packet_buffer_free_list_;
- // The peer address passed into WritePacket().
- QuicSocketAddress last_write_peer_address_;
-};
-
class TestConnection : public QuicConnection {
public:
TestConnection(QuicConnectionId connection_id,
diff --git a/quic/test_tools/quic_test_utils.cc b/quic/test_tools/quic_test_utils.cc
index 6e83351..5aa2a7a 100644
--- a/quic/test_tools/quic_test_utils.cc
+++ b/quic/test_tools/quic_test_utils.cc
@@ -14,6 +14,7 @@
#include "net/third_party/quiche/src/quic/core/crypto/crypto_framer.h"
#include "net/third_party/quiche/src/quic/core/crypto/crypto_handshake.h"
#include "net/third_party/quiche/src/quic/core/crypto/crypto_utils.h"
+#include "net/third_party/quiche/src/quic/core/crypto/null_decrypter.h"
#include "net/third_party/quiche/src/quic/core/crypto/null_encrypter.h"
#include "net/third_party/quiche/src/quic/core/crypto/quic_decrypter.h"
#include "net/third_party/quiche/src/quic/core/crypto/quic_encrypter.h"
@@ -27,6 +28,7 @@
#include "net/third_party/quiche/src/quic/core/quic_types.h"
#include "net/third_party/quiche/src/quic/core/quic_utils.h"
#include "net/third_party/quiche/src/quic/core/quic_versions.h"
+#include "net/third_party/quiche/src/quic/platform/api/quic_error_code_wrappers.h"
#include "net/third_party/quiche/src/quic/platform/api/quic_flags.h"
#include "net/third_party/quiche/src/quic/platform/api/quic_logging.h"
#include "net/third_party/quiche/src/quic/test_tools/crypto_test_utils.h"
@@ -1313,5 +1315,197 @@
return QuicMemSlice(std::move(buffer), data.size());
}
+bool TaggingEncrypter::EncryptPacket(
+ uint64_t /*packet_number*/,
+ quiche::QuicheStringPiece /*associated_data*/,
+ quiche::QuicheStringPiece plaintext,
+ char* output,
+ size_t* output_length,
+ size_t max_output_length) {
+ const size_t len = plaintext.size() + kTagSize;
+ if (max_output_length < len) {
+ return false;
+ }
+ // Memmove is safe for inplace encryption.
+ memmove(output, plaintext.data(), plaintext.size());
+ output += plaintext.size();
+ memset(output, tag_, kTagSize);
+ *output_length = len;
+ return true;
+}
+
+bool TaggingDecrypter::DecryptPacket(
+ uint64_t /*packet_number*/,
+ quiche::QuicheStringPiece /*associated_data*/,
+ quiche::QuicheStringPiece ciphertext,
+ char* output,
+ size_t* output_length,
+ size_t /*max_output_length*/) {
+ if (ciphertext.size() < kTagSize) {
+ return false;
+ }
+ if (!CheckTag(ciphertext, GetTag(ciphertext))) {
+ return false;
+ }
+ *output_length = ciphertext.size() - kTagSize;
+ memcpy(output, ciphertext.data(), *output_length);
+ return true;
+}
+
+bool TaggingDecrypter::CheckTag(quiche::QuicheStringPiece ciphertext,
+ uint8_t tag) {
+ for (size_t i = ciphertext.size() - kTagSize; i < ciphertext.size(); i++) {
+ if (ciphertext.data()[i] != tag) {
+ return false;
+ }
+ }
+
+ return true;
+}
+
+TestPacketWriter::TestPacketWriter(ParsedQuicVersion version, MockClock* clock)
+ : version_(version),
+ framer_(SupportedVersions(version_), Perspective::IS_SERVER),
+ clock_(clock) {
+ QuicFramerPeer::SetLastSerializedServerConnectionId(framer_.framer(),
+ TestConnectionId());
+ framer_.framer()->SetInitialObfuscators(TestConnectionId());
+
+ for (int i = 0; i < 128; ++i) {
+ PacketBuffer* p = new PacketBuffer();
+ packet_buffer_pool_.push_back(p);
+ packet_buffer_pool_index_[p->buffer] = p;
+ packet_buffer_free_list_.push_back(p);
+ }
+}
+
+TestPacketWriter::~TestPacketWriter() {
+ EXPECT_EQ(packet_buffer_pool_.size(), packet_buffer_free_list_.size())
+ << packet_buffer_pool_.size() - packet_buffer_free_list_.size()
+ << " out of " << packet_buffer_pool_.size()
+ << " packet buffers have been leaked.";
+ for (auto p : packet_buffer_pool_) {
+ delete p;
+ }
+}
+
+WriteResult TestPacketWriter::WritePacket(const char* buffer,
+ size_t buf_len,
+ const QuicIpAddress& /*self_address*/,
+ const QuicSocketAddress& peer_address,
+ PerPacketOptions* /*options*/) {
+ last_write_peer_address_ = peer_address;
+ // If the buffer is allocated from the pool, return it back to the pool.
+ // Note the buffer content doesn't change.
+ if (packet_buffer_pool_index_.find(const_cast<char*>(buffer)) !=
+ packet_buffer_pool_index_.end()) {
+ FreePacketBuffer(buffer);
+ }
+
+ QuicEncryptedPacket packet(buffer, buf_len);
+ ++packets_write_attempts_;
+
+ if (packet.length() >= sizeof(final_bytes_of_last_packet_)) {
+ final_bytes_of_previous_packet_ = final_bytes_of_last_packet_;
+ memcpy(&final_bytes_of_last_packet_, packet.data() + packet.length() - 4,
+ sizeof(final_bytes_of_last_packet_));
+ }
+
+ if (use_tagging_decrypter_) {
+ if (framer_.framer()->version().KnowsWhichDecrypterToUse()) {
+ framer_.framer()->InstallDecrypter(ENCRYPTION_INITIAL,
+ std::make_unique<TaggingDecrypter>());
+ framer_.framer()->InstallDecrypter(ENCRYPTION_HANDSHAKE,
+ std::make_unique<TaggingDecrypter>());
+ framer_.framer()->InstallDecrypter(ENCRYPTION_ZERO_RTT,
+ std::make_unique<TaggingDecrypter>());
+ framer_.framer()->InstallDecrypter(ENCRYPTION_FORWARD_SECURE,
+ std::make_unique<TaggingDecrypter>());
+ } else {
+ framer_.framer()->SetDecrypter(ENCRYPTION_INITIAL,
+ std::make_unique<TaggingDecrypter>());
+ }
+ } else if (framer_.framer()->version().KnowsWhichDecrypterToUse()) {
+ framer_.framer()->InstallDecrypter(
+ ENCRYPTION_FORWARD_SECURE,
+ std::make_unique<NullDecrypter>(Perspective::IS_SERVER));
+ }
+ EXPECT_TRUE(framer_.ProcessPacket(packet))
+ << framer_.framer()->detailed_error();
+ if (block_on_next_write_) {
+ write_blocked_ = true;
+ block_on_next_write_ = false;
+ }
+ if (next_packet_too_large_) {
+ next_packet_too_large_ = false;
+ return WriteResult(WRITE_STATUS_ERROR, QUIC_EMSGSIZE);
+ }
+ if (always_get_packet_too_large_) {
+ return WriteResult(WRITE_STATUS_ERROR, QUIC_EMSGSIZE);
+ }
+ if (IsWriteBlocked()) {
+ return WriteResult(is_write_blocked_data_buffered_
+ ? WRITE_STATUS_BLOCKED_DATA_BUFFERED
+ : WRITE_STATUS_BLOCKED,
+ 0);
+ }
+
+ if (ShouldWriteFail()) {
+ return WriteResult(WRITE_STATUS_ERROR, 0);
+ }
+
+ last_packet_size_ = packet.length();
+ last_packet_header_ = framer_.header();
+ if (!framer_.connection_close_frames().empty()) {
+ ++connection_close_packets_;
+ }
+ if (!write_pause_time_delta_.IsZero()) {
+ clock_->AdvanceTime(write_pause_time_delta_);
+ }
+ if (is_batch_mode_) {
+ bytes_buffered_ += last_packet_size_;
+ return WriteResult(WRITE_STATUS_OK, 0);
+ }
+ return WriteResult(WRITE_STATUS_OK, last_packet_size_);
+}
+
+QuicPacketBuffer TestPacketWriter::GetNextWriteLocation(
+ const QuicIpAddress& /*self_address*/,
+ const QuicSocketAddress& /*peer_address*/) {
+ return {AllocPacketBuffer(), [this](const char* p) { FreePacketBuffer(p); }};
+}
+
+WriteResult TestPacketWriter::Flush() {
+ flush_attempts_++;
+ if (block_on_next_flush_) {
+ block_on_next_flush_ = false;
+ SetWriteBlocked();
+ return WriteResult(WRITE_STATUS_BLOCKED, /*errno*/ -1);
+ }
+ if (write_should_fail_) {
+ return WriteResult(WRITE_STATUS_ERROR, /*errno*/ -1);
+ }
+ int bytes_flushed = bytes_buffered_;
+ bytes_buffered_ = 0;
+ return WriteResult(WRITE_STATUS_OK, bytes_flushed);
+}
+
+char* TestPacketWriter::AllocPacketBuffer() {
+ PacketBuffer* p = packet_buffer_free_list_.front();
+ EXPECT_FALSE(p->in_use);
+ p->in_use = true;
+ packet_buffer_free_list_.pop_front();
+ return p->buffer;
+}
+
+void TestPacketWriter::FreePacketBuffer(const char* buffer) {
+ auto iter = packet_buffer_pool_index_.find(const_cast<char*>(buffer));
+ ASSERT_TRUE(iter != packet_buffer_pool_index_.end());
+ PacketBuffer* p = iter->second;
+ ASSERT_TRUE(p->in_use);
+ p->in_use = false;
+ packet_buffer_free_list_.push_back(p);
+}
+
} // namespace test
} // namespace quic
diff --git a/quic/test_tools/quic_test_utils.h b/quic/test_tools/quic_test_utils.h
index cd87b5b..e207a56 100644
--- a/quic/test_tools/quic_test_utils.h
+++ b/quic/test_tools/quic_test_utils.h
@@ -27,11 +27,14 @@
#include "net/third_party/quiche/src/quic/core/quic_server_id.h"
#include "net/third_party/quiche/src/quic/core/quic_simple_buffer_allocator.h"
#include "net/third_party/quiche/src/quic/core/quic_types.h"
+#include "net/third_party/quiche/src/quic/core/quic_utils.h"
#include "net/third_party/quiche/src/quic/platform/api/quic_mem_slice_storage.h"
#include "net/third_party/quiche/src/quic/platform/api/quic_test.h"
#include "net/third_party/quiche/src/quic/test_tools/mock_clock.h"
#include "net/third_party/quiche/src/quic/test_tools/mock_quic_session_visitor.h"
#include "net/third_party/quiche/src/quic/test_tools/mock_random.h"
+#include "net/third_party/quiche/src/quic/test_tools/quic_framer_peer.h"
+#include "net/third_party/quiche/src/quic/test_tools/simple_quic_framer.h"
#include "net/third_party/quiche/src/common/platform/api/quiche_str_cat.h"
#include "net/third_party/quiche/src/common/platform/api/quiche_string_piece.h"
@@ -1741,6 +1744,354 @@
return arg == QUIC_STREAM_NO_ERROR;
}
+// TaggingEncrypter appends kTagSize bytes of |tag| to the end of each message.
+class TaggingEncrypter : public QuicEncrypter {
+ public:
+ explicit TaggingEncrypter(uint8_t tag) : tag_(tag) {}
+ TaggingEncrypter(const TaggingEncrypter&) = delete;
+ TaggingEncrypter& operator=(const TaggingEncrypter&) = delete;
+
+ ~TaggingEncrypter() override {}
+
+ // QuicEncrypter interface.
+ bool SetKey(quiche::QuicheStringPiece /*key*/) override { return true; }
+
+ bool SetNoncePrefix(quiche::QuicheStringPiece /*nonce_prefix*/) override {
+ return true;
+ }
+
+ bool SetIV(quiche::QuicheStringPiece /*iv*/) override { return true; }
+
+ bool SetHeaderProtectionKey(quiche::QuicheStringPiece /*key*/) override {
+ return true;
+ }
+
+ bool EncryptPacket(uint64_t packet_number,
+ quiche::QuicheStringPiece associated_data,
+ quiche::QuicheStringPiece plaintext,
+ char* output,
+ size_t* output_length,
+ size_t max_output_length) override;
+
+ std::string GenerateHeaderProtectionMask(
+ quiche::QuicheStringPiece /*sample*/) override {
+ return std::string(5, 0);
+ }
+
+ size_t GetKeySize() const override { return 0; }
+ size_t GetNoncePrefixSize() const override { return 0; }
+ size_t GetIVSize() const override { return 0; }
+
+ size_t GetMaxPlaintextSize(size_t ciphertext_size) const override {
+ return ciphertext_size - kTagSize;
+ }
+
+ size_t GetCiphertextSize(size_t plaintext_size) const override {
+ return plaintext_size + kTagSize;
+ }
+
+ quiche::QuicheStringPiece GetKey() const override {
+ return quiche::QuicheStringPiece();
+ }
+
+ quiche::QuicheStringPiece GetNoncePrefix() const override {
+ return quiche::QuicheStringPiece();
+ }
+
+ private:
+ enum {
+ kTagSize = 12,
+ };
+
+ const uint8_t tag_;
+};
+
+// TaggingDecrypter ensures that the final kTagSize bytes of the message all
+// have the same value and then removes them.
+class TaggingDecrypter : public QuicDecrypter {
+ public:
+ ~TaggingDecrypter() override {}
+
+ // QuicDecrypter interface
+ bool SetKey(quiche::QuicheStringPiece /*key*/) override { return true; }
+
+ bool SetNoncePrefix(quiche::QuicheStringPiece /*nonce_prefix*/) override {
+ return true;
+ }
+
+ bool SetIV(quiche::QuicheStringPiece /*iv*/) override { return true; }
+
+ bool SetHeaderProtectionKey(quiche::QuicheStringPiece /*key*/) override {
+ return true;
+ }
+
+ bool SetPreliminaryKey(quiche::QuicheStringPiece /*key*/) override {
+ QUIC_BUG << "should not be called";
+ return false;
+ }
+
+ bool SetDiversificationNonce(const DiversificationNonce& /*key*/) override {
+ return true;
+ }
+
+ bool DecryptPacket(uint64_t packet_number,
+ quiche::QuicheStringPiece associated_data,
+ quiche::QuicheStringPiece ciphertext,
+ char* output,
+ size_t* output_length,
+ size_t max_output_length) override;
+
+ std::string GenerateHeaderProtectionMask(
+ QuicDataReader* /*sample_reader*/) override {
+ return std::string(5, 0);
+ }
+
+ size_t GetKeySize() const override { return 0; }
+ size_t GetNoncePrefixSize() const override { return 0; }
+ size_t GetIVSize() const override { return 0; }
+ quiche::QuicheStringPiece GetKey() const override {
+ return quiche::QuicheStringPiece();
+ }
+ quiche::QuicheStringPiece GetNoncePrefix() const override {
+ return quiche::QuicheStringPiece();
+ }
+ // Use a distinct value starting with 0xFFFFFF, which is never used by TLS.
+ uint32_t cipher_id() const override { return 0xFFFFFFF0; }
+
+ protected:
+ virtual uint8_t GetTag(quiche::QuicheStringPiece ciphertext) {
+ return ciphertext.data()[ciphertext.size() - 1];
+ }
+
+ private:
+ enum {
+ kTagSize = 12,
+ };
+
+ bool CheckTag(quiche::QuicheStringPiece ciphertext, uint8_t tag);
+};
+
+class TestPacketWriter : public QuicPacketWriter {
+ struct PacketBuffer {
+ QUIC_CACHELINE_ALIGNED char buffer[1500];
+ bool in_use = false;
+ };
+
+ public:
+ TestPacketWriter(ParsedQuicVersion version, MockClock* clock);
+ TestPacketWriter(const TestPacketWriter&) = delete;
+ TestPacketWriter& operator=(const TestPacketWriter&) = delete;
+
+ ~TestPacketWriter() override;
+
+ // QuicPacketWriter interface
+ WriteResult WritePacket(const char* buffer,
+ size_t buf_len,
+ const QuicIpAddress& self_address,
+ const QuicSocketAddress& peer_address,
+ PerPacketOptions* options) override;
+
+ bool ShouldWriteFail() { return write_should_fail_; }
+
+ bool IsWriteBlocked() const override { return write_blocked_; }
+
+ void SetWriteBlocked() { write_blocked_ = true; }
+
+ void SetWritable() override { write_blocked_ = false; }
+
+ void SetShouldWriteFail() { write_should_fail_ = true; }
+
+ QuicByteCount GetMaxPacketSize(
+ const QuicSocketAddress& /*peer_address*/) const override {
+ return max_packet_size_;
+ }
+
+ bool SupportsReleaseTime() const override { return supports_release_time_; }
+
+ bool IsBatchMode() const override { return is_batch_mode_; }
+
+ QuicPacketBuffer GetNextWriteLocation(
+ const QuicIpAddress& /*self_address*/,
+ const QuicSocketAddress& /*peer_address*/) override;
+
+ WriteResult Flush() override;
+
+ void BlockOnNextFlush() { block_on_next_flush_ = true; }
+
+ void BlockOnNextWrite() { block_on_next_write_ = true; }
+
+ void SimulateNextPacketTooLarge() { next_packet_too_large_ = true; }
+
+ void AlwaysGetPacketTooLarge() { always_get_packet_too_large_ = true; }
+
+ // Sets the amount of time that the writer should before the actual write.
+ void SetWritePauseTimeDelta(QuicTime::Delta delta) {
+ write_pause_time_delta_ = delta;
+ }
+
+ void SetBatchMode(bool new_value) { is_batch_mode_ = new_value; }
+
+ const QuicPacketHeader& header() { return framer_.header(); }
+
+ size_t frame_count() const { return framer_.num_frames(); }
+
+ const std::vector<QuicAckFrame>& ack_frames() const {
+ return framer_.ack_frames();
+ }
+
+ const std::vector<QuicStopWaitingFrame>& stop_waiting_frames() const {
+ return framer_.stop_waiting_frames();
+ }
+
+ const std::vector<QuicConnectionCloseFrame>& connection_close_frames() const {
+ return framer_.connection_close_frames();
+ }
+
+ const std::vector<QuicRstStreamFrame>& rst_stream_frames() const {
+ return framer_.rst_stream_frames();
+ }
+
+ const std::vector<std::unique_ptr<QuicStreamFrame>>& stream_frames() const {
+ return framer_.stream_frames();
+ }
+
+ const std::vector<std::unique_ptr<QuicCryptoFrame>>& crypto_frames() const {
+ return framer_.crypto_frames();
+ }
+
+ const std::vector<QuicPingFrame>& ping_frames() const {
+ return framer_.ping_frames();
+ }
+
+ const std::vector<QuicMessageFrame>& message_frames() const {
+ return framer_.message_frames();
+ }
+
+ const std::vector<QuicWindowUpdateFrame>& window_update_frames() const {
+ return framer_.window_update_frames();
+ }
+
+ const std::vector<QuicPaddingFrame>& padding_frames() const {
+ return framer_.padding_frames();
+ }
+
+ const std::vector<QuicPathChallengeFrame>& path_challenge_frames() const {
+ return framer_.path_challenge_frames();
+ }
+
+ const std::vector<QuicPathResponseFrame>& path_response_frames() const {
+ return framer_.path_response_frames();
+ }
+
+ const QuicEncryptedPacket* coalesced_packet() const {
+ return framer_.coalesced_packet();
+ }
+
+ size_t last_packet_size() { return last_packet_size_; }
+
+ const QuicPacketHeader& last_packet_header() const {
+ return last_packet_header_;
+ }
+
+ const QuicVersionNegotiationPacket* version_negotiation_packet() {
+ return framer_.version_negotiation_packet();
+ }
+
+ void set_is_write_blocked_data_buffered(bool buffered) {
+ is_write_blocked_data_buffered_ = buffered;
+ }
+
+ void set_perspective(Perspective perspective) {
+ // We invert perspective here, because the framer needs to parse packets
+ // we send.
+ QuicFramerPeer::SetPerspective(framer_.framer(),
+ QuicUtils::InvertPerspective(perspective));
+ }
+
+ // final_bytes_of_last_packet_ returns the last four bytes of the previous
+ // packet as a little-endian, uint32_t. This is intended to be used with a
+ // TaggingEncrypter so that tests can determine which encrypter was used for
+ // a given packet.
+ uint32_t final_bytes_of_last_packet() { return final_bytes_of_last_packet_; }
+
+ // Returns the final bytes of the second to last packet.
+ uint32_t final_bytes_of_previous_packet() {
+ return final_bytes_of_previous_packet_;
+ }
+
+ void use_tagging_decrypter() { use_tagging_decrypter_ = true; }
+
+ uint32_t packets_write_attempts() const { return packets_write_attempts_; }
+
+ uint32_t flush_attempts() const { return flush_attempts_; }
+
+ uint32_t connection_close_packets() const {
+ return connection_close_packets_;
+ }
+
+ void Reset() { framer_.Reset(); }
+
+ void SetSupportedVersions(const ParsedQuicVersionVector& versions) {
+ framer_.SetSupportedVersions(versions);
+ }
+
+ void set_max_packet_size(QuicByteCount max_packet_size) {
+ max_packet_size_ = max_packet_size;
+ }
+
+ void set_supports_release_time(bool supports_release_time) {
+ supports_release_time_ = supports_release_time;
+ }
+
+ SimpleQuicFramer* framer() { return &framer_; }
+
+ const QuicSocketAddress& last_write_peer_address() const {
+ return last_write_peer_address_;
+ }
+
+ private:
+ char* AllocPacketBuffer();
+
+ void FreePacketBuffer(const char* buffer);
+
+ ParsedQuicVersion version_;
+ SimpleQuicFramer framer_;
+ size_t last_packet_size_ = 0;
+ QuicPacketHeader last_packet_header_;
+ bool write_blocked_ = false;
+ bool write_should_fail_ = false;
+ bool block_on_next_flush_ = false;
+ bool block_on_next_write_ = false;
+ bool next_packet_too_large_ = false;
+ bool always_get_packet_too_large_ = false;
+ bool is_write_blocked_data_buffered_ = false;
+ bool is_batch_mode_ = false;
+ // Number of times Flush() was called.
+ uint32_t flush_attempts_ = 0;
+ // (Batch mode only) Number of bytes buffered in writer. It is used as the
+ // return value of a successful Flush().
+ uint32_t bytes_buffered_ = 0;
+ uint32_t final_bytes_of_last_packet_ = 0;
+ uint32_t final_bytes_of_previous_packet_ = 0;
+ bool use_tagging_decrypter_ = false;
+ uint32_t packets_write_attempts_ = 0;
+ uint32_t connection_close_packets_ = 0;
+ MockClock* clock_ = nullptr;
+ // If non-zero, the clock will pause during WritePacket for this amount of
+ // time.
+ QuicTime::Delta write_pause_time_delta_ = QuicTime::Delta::Zero();
+ QuicByteCount max_packet_size_ = kMaxOutgoingPacketSize;
+ bool supports_release_time_ = false;
+ // Used to verify writer-allocated packet buffers are properly released.
+ std::vector<PacketBuffer*> packet_buffer_pool_;
+ // Buffer address => Address of the owning PacketBuffer.
+ QuicHashMap<char*, PacketBuffer*> packet_buffer_pool_index_;
+ // Indices in packet_buffer_pool_ that are not allocated.
+ std::list<PacketBuffer*> packet_buffer_free_list_;
+ // The peer address passed into WritePacket().
+ QuicSocketAddress last_write_peer_address_;
+};
+
} // namespace test
} // namespace quic