Move TestPacketWriter from quic_connection_test.cc to quic_test_utils. PiperOrigin-RevId: 334408019 Change-Id: I3fc479fea2af0e1e2e070bfd0c120ed732061586
diff --git a/quic/core/quic_connection_test.cc b/quic/core/quic_connection_test.cc index 2c791f2..1b41e8e 100644 --- a/quic/core/quic_connection_test.cc +++ b/quic/core/quic_connection_test.cc
@@ -24,7 +24,6 @@ #include "net/third_party/quiche/src/quic/core/quic_types.h" #include "net/third_party/quiche/src/quic/core/quic_utils.h" #include "net/third_party/quiche/src/quic/core/quic_versions.h" -#include "net/third_party/quiche/src/quic/platform/api/quic_error_code_wrappers.h" #include "net/third_party/quiche/src/quic/platform/api/quic_expect_bug.h" #include "net/third_party/quiche/src/quic/platform/api/quic_flags.h" #include "net/third_party/quiche/src/quic/platform/api/quic_logging.h" @@ -40,7 +39,6 @@ #include "net/third_party/quiche/src/quic/test_tools/quic_sent_packet_manager_peer.h" #include "net/third_party/quiche/src/quic/test_tools/quic_test_utils.h" #include "net/third_party/quiche/src/quic/test_tools/simple_data_producer.h" -#include "net/third_party/quiche/src/quic/test_tools/simple_quic_framer.h" #include "net/third_party/quiche/src/quic/test_tools/simple_session_notifier.h" #include "net/third_party/quiche/src/common/platform/api/quiche_arraysize.h" #include "net/third_party/quiche/src/common/platform/api/quiche_str_cat.h" @@ -110,162 +108,6 @@ } } -// TaggingEncrypter appends kTagSize bytes of |tag| to the end of each message. -class TaggingEncrypter : public QuicEncrypter { - public: - explicit TaggingEncrypter(uint8_t tag) : tag_(tag) {} - TaggingEncrypter(const TaggingEncrypter&) = delete; - TaggingEncrypter& operator=(const TaggingEncrypter&) = delete; - - ~TaggingEncrypter() override {} - - // QuicEncrypter interface. - bool SetKey(quiche::QuicheStringPiece /*key*/) override { return true; } - - bool SetNoncePrefix(quiche::QuicheStringPiece /*nonce_prefix*/) override { - return true; - } - - bool SetIV(quiche::QuicheStringPiece /*iv*/) override { return true; } - - bool SetHeaderProtectionKey(quiche::QuicheStringPiece /*key*/) override { - return true; - } - - bool EncryptPacket(uint64_t /*packet_number*/, - quiche::QuicheStringPiece /*associated_data*/, - quiche::QuicheStringPiece plaintext, - char* output, - size_t* output_length, - size_t max_output_length) override { - const size_t len = plaintext.size() + kTagSize; - if (max_output_length < len) { - return false; - } - // Memmove is safe for inplace encryption. - memmove(output, plaintext.data(), plaintext.size()); - output += plaintext.size(); - memset(output, tag_, kTagSize); - *output_length = len; - return true; - } - - std::string GenerateHeaderProtectionMask( - quiche::QuicheStringPiece /*sample*/) override { - return std::string(5, 0); - } - - size_t GetKeySize() const override { return 0; } - size_t GetNoncePrefixSize() const override { return 0; } - size_t GetIVSize() const override { return 0; } - - size_t GetMaxPlaintextSize(size_t ciphertext_size) const override { - return ciphertext_size - kTagSize; - } - - size_t GetCiphertextSize(size_t plaintext_size) const override { - return plaintext_size + kTagSize; - } - - quiche::QuicheStringPiece GetKey() const override { - return quiche::QuicheStringPiece(); - } - - quiche::QuicheStringPiece GetNoncePrefix() const override { - return quiche::QuicheStringPiece(); - } - - private: - enum { - kTagSize = 12, - }; - - const uint8_t tag_; -}; - -// TaggingDecrypter ensures that the final kTagSize bytes of the message all -// have the same value and then removes them. -class TaggingDecrypter : public QuicDecrypter { - public: - ~TaggingDecrypter() override {} - - // QuicDecrypter interface - bool SetKey(quiche::QuicheStringPiece /*key*/) override { return true; } - - bool SetNoncePrefix(quiche::QuicheStringPiece /*nonce_prefix*/) override { - return true; - } - - bool SetIV(quiche::QuicheStringPiece /*iv*/) override { return true; } - - bool SetHeaderProtectionKey(quiche::QuicheStringPiece /*key*/) override { - return true; - } - - bool SetPreliminaryKey(quiche::QuicheStringPiece /*key*/) override { - QUIC_BUG << "should not be called"; - return false; - } - - bool SetDiversificationNonce(const DiversificationNonce& /*key*/) override { - return true; - } - - bool DecryptPacket(uint64_t /*packet_number*/, - quiche::QuicheStringPiece /*associated_data*/, - quiche::QuicheStringPiece ciphertext, - char* output, - size_t* output_length, - size_t /*max_output_length*/) override { - if (ciphertext.size() < kTagSize) { - return false; - } - if (!CheckTag(ciphertext, GetTag(ciphertext))) { - return false; - } - *output_length = ciphertext.size() - kTagSize; - memcpy(output, ciphertext.data(), *output_length); - return true; - } - - std::string GenerateHeaderProtectionMask( - QuicDataReader* /*sample_reader*/) override { - return std::string(5, 0); - } - - size_t GetKeySize() const override { return 0; } - size_t GetNoncePrefixSize() const override { return 0; } - size_t GetIVSize() const override { return 0; } - quiche::QuicheStringPiece GetKey() const override { - return quiche::QuicheStringPiece(); - } - quiche::QuicheStringPiece GetNoncePrefix() const override { - return quiche::QuicheStringPiece(); - } - // Use a distinct value starting with 0xFFFFFF, which is never used by TLS. - uint32_t cipher_id() const override { return 0xFFFFFFF0; } - - protected: - virtual uint8_t GetTag(quiche::QuicheStringPiece ciphertext) { - return ciphertext.data()[ciphertext.size() - 1]; - } - - private: - enum { - kTagSize = 12, - }; - - bool CheckTag(quiche::QuicheStringPiece ciphertext, uint8_t tag) { - for (size_t i = ciphertext.size() - kTagSize; i < ciphertext.size(); i++) { - if (ciphertext.data()[i] != tag) { - return false; - } - } - - return true; - } -}; - // StringTaggingDecrypter ensures that the final kTagSize bytes of the message // match the expected value. class StrictTaggingDecrypter : public TaggingDecrypter { @@ -336,352 +178,6 @@ } }; -class TestPacketWriter : public QuicPacketWriter { - struct PacketBuffer { - QUIC_CACHELINE_ALIGNED char buffer[1500]; - bool in_use = false; - }; - - public: - TestPacketWriter(ParsedQuicVersion version, MockClock* clock) - : version_(version), - framer_(SupportedVersions(version_), Perspective::IS_SERVER), - clock_(clock) { - QuicFramerPeer::SetLastSerializedServerConnectionId(framer_.framer(), - TestConnectionId()); - framer_.framer()->SetInitialObfuscators(TestConnectionId()); - - for (int i = 0; i < 128; ++i) { - PacketBuffer* p = new PacketBuffer(); - packet_buffer_pool_.push_back(p); - packet_buffer_pool_index_[p->buffer] = p; - packet_buffer_free_list_.push_back(p); - } - } - TestPacketWriter(const TestPacketWriter&) = delete; - TestPacketWriter& operator=(const TestPacketWriter&) = delete; - - ~TestPacketWriter() override { - EXPECT_EQ(packet_buffer_pool_.size(), packet_buffer_free_list_.size()) - << packet_buffer_pool_.size() - packet_buffer_free_list_.size() - << " out of " << packet_buffer_pool_.size() - << " packet buffers have been leaked."; - for (auto p : packet_buffer_pool_) { - delete p; - } - } - - // QuicPacketWriter interface - WriteResult WritePacket(const char* buffer, - size_t buf_len, - const QuicIpAddress& /*self_address*/, - const QuicSocketAddress& peer_address, - PerPacketOptions* /*options*/) override { - last_write_peer_address_ = peer_address; - // If the buffer is allocated from the pool, return it back to the pool. - // Note the buffer content doesn't change. - if (packet_buffer_pool_index_.find(const_cast<char*>(buffer)) != - packet_buffer_pool_index_.end()) { - FreePacketBuffer(buffer); - } - - QuicEncryptedPacket packet(buffer, buf_len); - ++packets_write_attempts_; - - if (packet.length() >= sizeof(final_bytes_of_last_packet_)) { - final_bytes_of_previous_packet_ = final_bytes_of_last_packet_; - memcpy(&final_bytes_of_last_packet_, packet.data() + packet.length() - 4, - sizeof(final_bytes_of_last_packet_)); - } - - if (use_tagging_decrypter_) { - if (framer_.framer()->version().KnowsWhichDecrypterToUse()) { - framer_.framer()->InstallDecrypter( - ENCRYPTION_INITIAL, std::make_unique<TaggingDecrypter>()); - framer_.framer()->InstallDecrypter( - ENCRYPTION_HANDSHAKE, std::make_unique<TaggingDecrypter>()); - framer_.framer()->InstallDecrypter( - ENCRYPTION_ZERO_RTT, std::make_unique<TaggingDecrypter>()); - framer_.framer()->InstallDecrypter( - ENCRYPTION_FORWARD_SECURE, std::make_unique<TaggingDecrypter>()); - } else { - framer_.framer()->SetDecrypter(ENCRYPTION_INITIAL, - std::make_unique<TaggingDecrypter>()); - } - } else if (framer_.framer()->version().KnowsWhichDecrypterToUse()) { - framer_.framer()->InstallDecrypter( - ENCRYPTION_FORWARD_SECURE, - std::make_unique<NullDecrypter>(Perspective::IS_SERVER)); - } - EXPECT_TRUE(framer_.ProcessPacket(packet)) - << framer_.framer()->detailed_error(); - if (block_on_next_write_) { - write_blocked_ = true; - block_on_next_write_ = false; - } - if (next_packet_too_large_) { - next_packet_too_large_ = false; - return WriteResult(WRITE_STATUS_ERROR, QUIC_EMSGSIZE); - } - if (always_get_packet_too_large_) { - return WriteResult(WRITE_STATUS_ERROR, QUIC_EMSGSIZE); - } - if (IsWriteBlocked()) { - return WriteResult(is_write_blocked_data_buffered_ - ? WRITE_STATUS_BLOCKED_DATA_BUFFERED - : WRITE_STATUS_BLOCKED, - 0); - } - - if (ShouldWriteFail()) { - return WriteResult(WRITE_STATUS_ERROR, 0); - } - - last_packet_size_ = packet.length(); - last_packet_header_ = framer_.header(); - if (!framer_.connection_close_frames().empty()) { - ++connection_close_packets_; - } - if (!write_pause_time_delta_.IsZero()) { - clock_->AdvanceTime(write_pause_time_delta_); - } - if (is_batch_mode_) { - bytes_buffered_ += last_packet_size_; - return WriteResult(WRITE_STATUS_OK, 0); - } - return WriteResult(WRITE_STATUS_OK, last_packet_size_); - } - - bool ShouldWriteFail() { return write_should_fail_; } - - bool IsWriteBlocked() const override { return write_blocked_; } - - void SetWriteBlocked() { write_blocked_ = true; } - - void SetWritable() override { write_blocked_ = false; } - - void SetShouldWriteFail() { write_should_fail_ = true; } - - QuicByteCount GetMaxPacketSize( - const QuicSocketAddress& /*peer_address*/) const override { - return max_packet_size_; - } - - bool SupportsReleaseTime() const override { return supports_release_time_; } - - bool IsBatchMode() const override { return is_batch_mode_; } - - QuicPacketBuffer GetNextWriteLocation( - const QuicIpAddress& /*self_address*/, - const QuicSocketAddress& /*peer_address*/) override { - return {AllocPacketBuffer(), - [this](const char* p) { FreePacketBuffer(p); }}; - } - - WriteResult Flush() override { - flush_attempts_++; - if (block_on_next_flush_) { - block_on_next_flush_ = false; - SetWriteBlocked(); - return WriteResult(WRITE_STATUS_BLOCKED, /*errno*/ -1); - } - if (write_should_fail_) { - return WriteResult(WRITE_STATUS_ERROR, /*errno*/ -1); - } - int bytes_flushed = bytes_buffered_; - bytes_buffered_ = 0; - return WriteResult(WRITE_STATUS_OK, bytes_flushed); - } - - void BlockOnNextFlush() { block_on_next_flush_ = true; } - - void BlockOnNextWrite() { block_on_next_write_ = true; } - - void SimulateNextPacketTooLarge() { next_packet_too_large_ = true; } - - void AlwaysGetPacketTooLarge() { always_get_packet_too_large_ = true; } - - // Sets the amount of time that the writer should before the actual write. - void SetWritePauseTimeDelta(QuicTime::Delta delta) { - write_pause_time_delta_ = delta; - } - - void SetBatchMode(bool new_value) { is_batch_mode_ = new_value; } - - const QuicPacketHeader& header() { return framer_.header(); } - - size_t frame_count() const { return framer_.num_frames(); } - - const std::vector<QuicAckFrame>& ack_frames() const { - return framer_.ack_frames(); - } - - const std::vector<QuicStopWaitingFrame>& stop_waiting_frames() const { - return framer_.stop_waiting_frames(); - } - - const std::vector<QuicConnectionCloseFrame>& connection_close_frames() const { - return framer_.connection_close_frames(); - } - - const std::vector<QuicRstStreamFrame>& rst_stream_frames() const { - return framer_.rst_stream_frames(); - } - - const std::vector<std::unique_ptr<QuicStreamFrame>>& stream_frames() const { - return framer_.stream_frames(); - } - - const std::vector<std::unique_ptr<QuicCryptoFrame>>& crypto_frames() const { - return framer_.crypto_frames(); - } - - const std::vector<QuicPingFrame>& ping_frames() const { - return framer_.ping_frames(); - } - - const std::vector<QuicMessageFrame>& message_frames() const { - return framer_.message_frames(); - } - - const std::vector<QuicWindowUpdateFrame>& window_update_frames() const { - return framer_.window_update_frames(); - } - - const std::vector<QuicPaddingFrame>& padding_frames() const { - return framer_.padding_frames(); - } - - const std::vector<QuicPathChallengeFrame>& path_challenge_frames() const { - return framer_.path_challenge_frames(); - } - - const std::vector<QuicPathResponseFrame>& path_response_frames() const { - return framer_.path_response_frames(); - } - - const QuicEncryptedPacket* coalesced_packet() const { - return framer_.coalesced_packet(); - } - - size_t last_packet_size() { return last_packet_size_; } - - const QuicPacketHeader& last_packet_header() const { - return last_packet_header_; - } - - const QuicVersionNegotiationPacket* version_negotiation_packet() { - return framer_.version_negotiation_packet(); - } - - void set_is_write_blocked_data_buffered(bool buffered) { - is_write_blocked_data_buffered_ = buffered; - } - - void set_perspective(Perspective perspective) { - // We invert perspective here, because the framer needs to parse packets - // we send. - QuicFramerPeer::SetPerspective(framer_.framer(), - QuicUtils::InvertPerspective(perspective)); - } - - // final_bytes_of_last_packet_ returns the last four bytes of the previous - // packet as a little-endian, uint32_t. This is intended to be used with a - // TaggingEncrypter so that tests can determine which encrypter was used for - // a given packet. - uint32_t final_bytes_of_last_packet() { return final_bytes_of_last_packet_; } - - // Returns the final bytes of the second to last packet. - uint32_t final_bytes_of_previous_packet() { - return final_bytes_of_previous_packet_; - } - - void use_tagging_decrypter() { use_tagging_decrypter_ = true; } - - uint32_t packets_write_attempts() const { return packets_write_attempts_; } - - uint32_t flush_attempts() const { return flush_attempts_; } - - uint32_t connection_close_packets() const { - return connection_close_packets_; - } - - void Reset() { framer_.Reset(); } - - void SetSupportedVersions(const ParsedQuicVersionVector& versions) { - framer_.SetSupportedVersions(versions); - } - - void set_max_packet_size(QuicByteCount max_packet_size) { - max_packet_size_ = max_packet_size; - } - - void set_supports_release_time(bool supports_release_time) { - supports_release_time_ = supports_release_time; - } - - SimpleQuicFramer* framer() { return &framer_; } - - const QuicSocketAddress& last_write_peer_address() const { - return last_write_peer_address_; - } - - private: - char* AllocPacketBuffer() { - PacketBuffer* p = packet_buffer_free_list_.front(); - EXPECT_FALSE(p->in_use); - p->in_use = true; - packet_buffer_free_list_.pop_front(); - return p->buffer; - } - - void FreePacketBuffer(const char* buffer) { - auto iter = packet_buffer_pool_index_.find(const_cast<char*>(buffer)); - ASSERT_TRUE(iter != packet_buffer_pool_index_.end()); - PacketBuffer* p = iter->second; - ASSERT_TRUE(p->in_use); - p->in_use = false; - packet_buffer_free_list_.push_back(p); - } - - ParsedQuicVersion version_; - SimpleQuicFramer framer_; - size_t last_packet_size_ = 0; - QuicPacketHeader last_packet_header_; - bool write_blocked_ = false; - bool write_should_fail_ = false; - bool block_on_next_flush_ = false; - bool block_on_next_write_ = false; - bool next_packet_too_large_ = false; - bool always_get_packet_too_large_ = false; - bool is_write_blocked_data_buffered_ = false; - bool is_batch_mode_ = false; - // Number of times Flush() was called. - uint32_t flush_attempts_ = 0; - // (Batch mode only) Number of bytes buffered in writer. It is used as the - // return value of a successful Flush(). - uint32_t bytes_buffered_ = 0; - uint32_t final_bytes_of_last_packet_ = 0; - uint32_t final_bytes_of_previous_packet_ = 0; - bool use_tagging_decrypter_ = false; - uint32_t packets_write_attempts_ = 0; - uint32_t connection_close_packets_ = 0; - MockClock* clock_ = nullptr; - // If non-zero, the clock will pause during WritePacket for this amount of - // time. - QuicTime::Delta write_pause_time_delta_ = QuicTime::Delta::Zero(); - QuicByteCount max_packet_size_ = kMaxOutgoingPacketSize; - bool supports_release_time_ = false; - // Used to verify writer-allocated packet buffers are properly released. - std::vector<PacketBuffer*> packet_buffer_pool_; - // Buffer address => Address of the owning PacketBuffer. - QuicHashMap<char*, PacketBuffer*> packet_buffer_pool_index_; - // Indices in packet_buffer_pool_ that are not allocated. - std::list<PacketBuffer*> packet_buffer_free_list_; - // The peer address passed into WritePacket(). - QuicSocketAddress last_write_peer_address_; -}; - class TestConnection : public QuicConnection { public: TestConnection(QuicConnectionId connection_id,
diff --git a/quic/test_tools/quic_test_utils.cc b/quic/test_tools/quic_test_utils.cc index 6e83351..5aa2a7a 100644 --- a/quic/test_tools/quic_test_utils.cc +++ b/quic/test_tools/quic_test_utils.cc
@@ -14,6 +14,7 @@ #include "net/third_party/quiche/src/quic/core/crypto/crypto_framer.h" #include "net/third_party/quiche/src/quic/core/crypto/crypto_handshake.h" #include "net/third_party/quiche/src/quic/core/crypto/crypto_utils.h" +#include "net/third_party/quiche/src/quic/core/crypto/null_decrypter.h" #include "net/third_party/quiche/src/quic/core/crypto/null_encrypter.h" #include "net/third_party/quiche/src/quic/core/crypto/quic_decrypter.h" #include "net/third_party/quiche/src/quic/core/crypto/quic_encrypter.h" @@ -27,6 +28,7 @@ #include "net/third_party/quiche/src/quic/core/quic_types.h" #include "net/third_party/quiche/src/quic/core/quic_utils.h" #include "net/third_party/quiche/src/quic/core/quic_versions.h" +#include "net/third_party/quiche/src/quic/platform/api/quic_error_code_wrappers.h" #include "net/third_party/quiche/src/quic/platform/api/quic_flags.h" #include "net/third_party/quiche/src/quic/platform/api/quic_logging.h" #include "net/third_party/quiche/src/quic/test_tools/crypto_test_utils.h" @@ -1313,5 +1315,197 @@ return QuicMemSlice(std::move(buffer), data.size()); } +bool TaggingEncrypter::EncryptPacket( + uint64_t /*packet_number*/, + quiche::QuicheStringPiece /*associated_data*/, + quiche::QuicheStringPiece plaintext, + char* output, + size_t* output_length, + size_t max_output_length) { + const size_t len = plaintext.size() + kTagSize; + if (max_output_length < len) { + return false; + } + // Memmove is safe for inplace encryption. + memmove(output, plaintext.data(), plaintext.size()); + output += plaintext.size(); + memset(output, tag_, kTagSize); + *output_length = len; + return true; +} + +bool TaggingDecrypter::DecryptPacket( + uint64_t /*packet_number*/, + quiche::QuicheStringPiece /*associated_data*/, + quiche::QuicheStringPiece ciphertext, + char* output, + size_t* output_length, + size_t /*max_output_length*/) { + if (ciphertext.size() < kTagSize) { + return false; + } + if (!CheckTag(ciphertext, GetTag(ciphertext))) { + return false; + } + *output_length = ciphertext.size() - kTagSize; + memcpy(output, ciphertext.data(), *output_length); + return true; +} + +bool TaggingDecrypter::CheckTag(quiche::QuicheStringPiece ciphertext, + uint8_t tag) { + for (size_t i = ciphertext.size() - kTagSize; i < ciphertext.size(); i++) { + if (ciphertext.data()[i] != tag) { + return false; + } + } + + return true; +} + +TestPacketWriter::TestPacketWriter(ParsedQuicVersion version, MockClock* clock) + : version_(version), + framer_(SupportedVersions(version_), Perspective::IS_SERVER), + clock_(clock) { + QuicFramerPeer::SetLastSerializedServerConnectionId(framer_.framer(), + TestConnectionId()); + framer_.framer()->SetInitialObfuscators(TestConnectionId()); + + for (int i = 0; i < 128; ++i) { + PacketBuffer* p = new PacketBuffer(); + packet_buffer_pool_.push_back(p); + packet_buffer_pool_index_[p->buffer] = p; + packet_buffer_free_list_.push_back(p); + } +} + +TestPacketWriter::~TestPacketWriter() { + EXPECT_EQ(packet_buffer_pool_.size(), packet_buffer_free_list_.size()) + << packet_buffer_pool_.size() - packet_buffer_free_list_.size() + << " out of " << packet_buffer_pool_.size() + << " packet buffers have been leaked."; + for (auto p : packet_buffer_pool_) { + delete p; + } +} + +WriteResult TestPacketWriter::WritePacket(const char* buffer, + size_t buf_len, + const QuicIpAddress& /*self_address*/, + const QuicSocketAddress& peer_address, + PerPacketOptions* /*options*/) { + last_write_peer_address_ = peer_address; + // If the buffer is allocated from the pool, return it back to the pool. + // Note the buffer content doesn't change. + if (packet_buffer_pool_index_.find(const_cast<char*>(buffer)) != + packet_buffer_pool_index_.end()) { + FreePacketBuffer(buffer); + } + + QuicEncryptedPacket packet(buffer, buf_len); + ++packets_write_attempts_; + + if (packet.length() >= sizeof(final_bytes_of_last_packet_)) { + final_bytes_of_previous_packet_ = final_bytes_of_last_packet_; + memcpy(&final_bytes_of_last_packet_, packet.data() + packet.length() - 4, + sizeof(final_bytes_of_last_packet_)); + } + + if (use_tagging_decrypter_) { + if (framer_.framer()->version().KnowsWhichDecrypterToUse()) { + framer_.framer()->InstallDecrypter(ENCRYPTION_INITIAL, + std::make_unique<TaggingDecrypter>()); + framer_.framer()->InstallDecrypter(ENCRYPTION_HANDSHAKE, + std::make_unique<TaggingDecrypter>()); + framer_.framer()->InstallDecrypter(ENCRYPTION_ZERO_RTT, + std::make_unique<TaggingDecrypter>()); + framer_.framer()->InstallDecrypter(ENCRYPTION_FORWARD_SECURE, + std::make_unique<TaggingDecrypter>()); + } else { + framer_.framer()->SetDecrypter(ENCRYPTION_INITIAL, + std::make_unique<TaggingDecrypter>()); + } + } else if (framer_.framer()->version().KnowsWhichDecrypterToUse()) { + framer_.framer()->InstallDecrypter( + ENCRYPTION_FORWARD_SECURE, + std::make_unique<NullDecrypter>(Perspective::IS_SERVER)); + } + EXPECT_TRUE(framer_.ProcessPacket(packet)) + << framer_.framer()->detailed_error(); + if (block_on_next_write_) { + write_blocked_ = true; + block_on_next_write_ = false; + } + if (next_packet_too_large_) { + next_packet_too_large_ = false; + return WriteResult(WRITE_STATUS_ERROR, QUIC_EMSGSIZE); + } + if (always_get_packet_too_large_) { + return WriteResult(WRITE_STATUS_ERROR, QUIC_EMSGSIZE); + } + if (IsWriteBlocked()) { + return WriteResult(is_write_blocked_data_buffered_ + ? WRITE_STATUS_BLOCKED_DATA_BUFFERED + : WRITE_STATUS_BLOCKED, + 0); + } + + if (ShouldWriteFail()) { + return WriteResult(WRITE_STATUS_ERROR, 0); + } + + last_packet_size_ = packet.length(); + last_packet_header_ = framer_.header(); + if (!framer_.connection_close_frames().empty()) { + ++connection_close_packets_; + } + if (!write_pause_time_delta_.IsZero()) { + clock_->AdvanceTime(write_pause_time_delta_); + } + if (is_batch_mode_) { + bytes_buffered_ += last_packet_size_; + return WriteResult(WRITE_STATUS_OK, 0); + } + return WriteResult(WRITE_STATUS_OK, last_packet_size_); +} + +QuicPacketBuffer TestPacketWriter::GetNextWriteLocation( + const QuicIpAddress& /*self_address*/, + const QuicSocketAddress& /*peer_address*/) { + return {AllocPacketBuffer(), [this](const char* p) { FreePacketBuffer(p); }}; +} + +WriteResult TestPacketWriter::Flush() { + flush_attempts_++; + if (block_on_next_flush_) { + block_on_next_flush_ = false; + SetWriteBlocked(); + return WriteResult(WRITE_STATUS_BLOCKED, /*errno*/ -1); + } + if (write_should_fail_) { + return WriteResult(WRITE_STATUS_ERROR, /*errno*/ -1); + } + int bytes_flushed = bytes_buffered_; + bytes_buffered_ = 0; + return WriteResult(WRITE_STATUS_OK, bytes_flushed); +} + +char* TestPacketWriter::AllocPacketBuffer() { + PacketBuffer* p = packet_buffer_free_list_.front(); + EXPECT_FALSE(p->in_use); + p->in_use = true; + packet_buffer_free_list_.pop_front(); + return p->buffer; +} + +void TestPacketWriter::FreePacketBuffer(const char* buffer) { + auto iter = packet_buffer_pool_index_.find(const_cast<char*>(buffer)); + ASSERT_TRUE(iter != packet_buffer_pool_index_.end()); + PacketBuffer* p = iter->second; + ASSERT_TRUE(p->in_use); + p->in_use = false; + packet_buffer_free_list_.push_back(p); +} + } // namespace test } // namespace quic
diff --git a/quic/test_tools/quic_test_utils.h b/quic/test_tools/quic_test_utils.h index cd87b5b..e207a56 100644 --- a/quic/test_tools/quic_test_utils.h +++ b/quic/test_tools/quic_test_utils.h
@@ -27,11 +27,14 @@ #include "net/third_party/quiche/src/quic/core/quic_server_id.h" #include "net/third_party/quiche/src/quic/core/quic_simple_buffer_allocator.h" #include "net/third_party/quiche/src/quic/core/quic_types.h" +#include "net/third_party/quiche/src/quic/core/quic_utils.h" #include "net/third_party/quiche/src/quic/platform/api/quic_mem_slice_storage.h" #include "net/third_party/quiche/src/quic/platform/api/quic_test.h" #include "net/third_party/quiche/src/quic/test_tools/mock_clock.h" #include "net/third_party/quiche/src/quic/test_tools/mock_quic_session_visitor.h" #include "net/third_party/quiche/src/quic/test_tools/mock_random.h" +#include "net/third_party/quiche/src/quic/test_tools/quic_framer_peer.h" +#include "net/third_party/quiche/src/quic/test_tools/simple_quic_framer.h" #include "net/third_party/quiche/src/common/platform/api/quiche_str_cat.h" #include "net/third_party/quiche/src/common/platform/api/quiche_string_piece.h" @@ -1741,6 +1744,354 @@ return arg == QUIC_STREAM_NO_ERROR; } +// TaggingEncrypter appends kTagSize bytes of |tag| to the end of each message. +class TaggingEncrypter : public QuicEncrypter { + public: + explicit TaggingEncrypter(uint8_t tag) : tag_(tag) {} + TaggingEncrypter(const TaggingEncrypter&) = delete; + TaggingEncrypter& operator=(const TaggingEncrypter&) = delete; + + ~TaggingEncrypter() override {} + + // QuicEncrypter interface. + bool SetKey(quiche::QuicheStringPiece /*key*/) override { return true; } + + bool SetNoncePrefix(quiche::QuicheStringPiece /*nonce_prefix*/) override { + return true; + } + + bool SetIV(quiche::QuicheStringPiece /*iv*/) override { return true; } + + bool SetHeaderProtectionKey(quiche::QuicheStringPiece /*key*/) override { + return true; + } + + bool EncryptPacket(uint64_t packet_number, + quiche::QuicheStringPiece associated_data, + quiche::QuicheStringPiece plaintext, + char* output, + size_t* output_length, + size_t max_output_length) override; + + std::string GenerateHeaderProtectionMask( + quiche::QuicheStringPiece /*sample*/) override { + return std::string(5, 0); + } + + size_t GetKeySize() const override { return 0; } + size_t GetNoncePrefixSize() const override { return 0; } + size_t GetIVSize() const override { return 0; } + + size_t GetMaxPlaintextSize(size_t ciphertext_size) const override { + return ciphertext_size - kTagSize; + } + + size_t GetCiphertextSize(size_t plaintext_size) const override { + return plaintext_size + kTagSize; + } + + quiche::QuicheStringPiece GetKey() const override { + return quiche::QuicheStringPiece(); + } + + quiche::QuicheStringPiece GetNoncePrefix() const override { + return quiche::QuicheStringPiece(); + } + + private: + enum { + kTagSize = 12, + }; + + const uint8_t tag_; +}; + +// TaggingDecrypter ensures that the final kTagSize bytes of the message all +// have the same value and then removes them. +class TaggingDecrypter : public QuicDecrypter { + public: + ~TaggingDecrypter() override {} + + // QuicDecrypter interface + bool SetKey(quiche::QuicheStringPiece /*key*/) override { return true; } + + bool SetNoncePrefix(quiche::QuicheStringPiece /*nonce_prefix*/) override { + return true; + } + + bool SetIV(quiche::QuicheStringPiece /*iv*/) override { return true; } + + bool SetHeaderProtectionKey(quiche::QuicheStringPiece /*key*/) override { + return true; + } + + bool SetPreliminaryKey(quiche::QuicheStringPiece /*key*/) override { + QUIC_BUG << "should not be called"; + return false; + } + + bool SetDiversificationNonce(const DiversificationNonce& /*key*/) override { + return true; + } + + bool DecryptPacket(uint64_t packet_number, + quiche::QuicheStringPiece associated_data, + quiche::QuicheStringPiece ciphertext, + char* output, + size_t* output_length, + size_t max_output_length) override; + + std::string GenerateHeaderProtectionMask( + QuicDataReader* /*sample_reader*/) override { + return std::string(5, 0); + } + + size_t GetKeySize() const override { return 0; } + size_t GetNoncePrefixSize() const override { return 0; } + size_t GetIVSize() const override { return 0; } + quiche::QuicheStringPiece GetKey() const override { + return quiche::QuicheStringPiece(); + } + quiche::QuicheStringPiece GetNoncePrefix() const override { + return quiche::QuicheStringPiece(); + } + // Use a distinct value starting with 0xFFFFFF, which is never used by TLS. + uint32_t cipher_id() const override { return 0xFFFFFFF0; } + + protected: + virtual uint8_t GetTag(quiche::QuicheStringPiece ciphertext) { + return ciphertext.data()[ciphertext.size() - 1]; + } + + private: + enum { + kTagSize = 12, + }; + + bool CheckTag(quiche::QuicheStringPiece ciphertext, uint8_t tag); +}; + +class TestPacketWriter : public QuicPacketWriter { + struct PacketBuffer { + QUIC_CACHELINE_ALIGNED char buffer[1500]; + bool in_use = false; + }; + + public: + TestPacketWriter(ParsedQuicVersion version, MockClock* clock); + TestPacketWriter(const TestPacketWriter&) = delete; + TestPacketWriter& operator=(const TestPacketWriter&) = delete; + + ~TestPacketWriter() override; + + // QuicPacketWriter interface + WriteResult WritePacket(const char* buffer, + size_t buf_len, + const QuicIpAddress& self_address, + const QuicSocketAddress& peer_address, + PerPacketOptions* options) override; + + bool ShouldWriteFail() { return write_should_fail_; } + + bool IsWriteBlocked() const override { return write_blocked_; } + + void SetWriteBlocked() { write_blocked_ = true; } + + void SetWritable() override { write_blocked_ = false; } + + void SetShouldWriteFail() { write_should_fail_ = true; } + + QuicByteCount GetMaxPacketSize( + const QuicSocketAddress& /*peer_address*/) const override { + return max_packet_size_; + } + + bool SupportsReleaseTime() const override { return supports_release_time_; } + + bool IsBatchMode() const override { return is_batch_mode_; } + + QuicPacketBuffer GetNextWriteLocation( + const QuicIpAddress& /*self_address*/, + const QuicSocketAddress& /*peer_address*/) override; + + WriteResult Flush() override; + + void BlockOnNextFlush() { block_on_next_flush_ = true; } + + void BlockOnNextWrite() { block_on_next_write_ = true; } + + void SimulateNextPacketTooLarge() { next_packet_too_large_ = true; } + + void AlwaysGetPacketTooLarge() { always_get_packet_too_large_ = true; } + + // Sets the amount of time that the writer should before the actual write. + void SetWritePauseTimeDelta(QuicTime::Delta delta) { + write_pause_time_delta_ = delta; + } + + void SetBatchMode(bool new_value) { is_batch_mode_ = new_value; } + + const QuicPacketHeader& header() { return framer_.header(); } + + size_t frame_count() const { return framer_.num_frames(); } + + const std::vector<QuicAckFrame>& ack_frames() const { + return framer_.ack_frames(); + } + + const std::vector<QuicStopWaitingFrame>& stop_waiting_frames() const { + return framer_.stop_waiting_frames(); + } + + const std::vector<QuicConnectionCloseFrame>& connection_close_frames() const { + return framer_.connection_close_frames(); + } + + const std::vector<QuicRstStreamFrame>& rst_stream_frames() const { + return framer_.rst_stream_frames(); + } + + const std::vector<std::unique_ptr<QuicStreamFrame>>& stream_frames() const { + return framer_.stream_frames(); + } + + const std::vector<std::unique_ptr<QuicCryptoFrame>>& crypto_frames() const { + return framer_.crypto_frames(); + } + + const std::vector<QuicPingFrame>& ping_frames() const { + return framer_.ping_frames(); + } + + const std::vector<QuicMessageFrame>& message_frames() const { + return framer_.message_frames(); + } + + const std::vector<QuicWindowUpdateFrame>& window_update_frames() const { + return framer_.window_update_frames(); + } + + const std::vector<QuicPaddingFrame>& padding_frames() const { + return framer_.padding_frames(); + } + + const std::vector<QuicPathChallengeFrame>& path_challenge_frames() const { + return framer_.path_challenge_frames(); + } + + const std::vector<QuicPathResponseFrame>& path_response_frames() const { + return framer_.path_response_frames(); + } + + const QuicEncryptedPacket* coalesced_packet() const { + return framer_.coalesced_packet(); + } + + size_t last_packet_size() { return last_packet_size_; } + + const QuicPacketHeader& last_packet_header() const { + return last_packet_header_; + } + + const QuicVersionNegotiationPacket* version_negotiation_packet() { + return framer_.version_negotiation_packet(); + } + + void set_is_write_blocked_data_buffered(bool buffered) { + is_write_blocked_data_buffered_ = buffered; + } + + void set_perspective(Perspective perspective) { + // We invert perspective here, because the framer needs to parse packets + // we send. + QuicFramerPeer::SetPerspective(framer_.framer(), + QuicUtils::InvertPerspective(perspective)); + } + + // final_bytes_of_last_packet_ returns the last four bytes of the previous + // packet as a little-endian, uint32_t. This is intended to be used with a + // TaggingEncrypter so that tests can determine which encrypter was used for + // a given packet. + uint32_t final_bytes_of_last_packet() { return final_bytes_of_last_packet_; } + + // Returns the final bytes of the second to last packet. + uint32_t final_bytes_of_previous_packet() { + return final_bytes_of_previous_packet_; + } + + void use_tagging_decrypter() { use_tagging_decrypter_ = true; } + + uint32_t packets_write_attempts() const { return packets_write_attempts_; } + + uint32_t flush_attempts() const { return flush_attempts_; } + + uint32_t connection_close_packets() const { + return connection_close_packets_; + } + + void Reset() { framer_.Reset(); } + + void SetSupportedVersions(const ParsedQuicVersionVector& versions) { + framer_.SetSupportedVersions(versions); + } + + void set_max_packet_size(QuicByteCount max_packet_size) { + max_packet_size_ = max_packet_size; + } + + void set_supports_release_time(bool supports_release_time) { + supports_release_time_ = supports_release_time; + } + + SimpleQuicFramer* framer() { return &framer_; } + + const QuicSocketAddress& last_write_peer_address() const { + return last_write_peer_address_; + } + + private: + char* AllocPacketBuffer(); + + void FreePacketBuffer(const char* buffer); + + ParsedQuicVersion version_; + SimpleQuicFramer framer_; + size_t last_packet_size_ = 0; + QuicPacketHeader last_packet_header_; + bool write_blocked_ = false; + bool write_should_fail_ = false; + bool block_on_next_flush_ = false; + bool block_on_next_write_ = false; + bool next_packet_too_large_ = false; + bool always_get_packet_too_large_ = false; + bool is_write_blocked_data_buffered_ = false; + bool is_batch_mode_ = false; + // Number of times Flush() was called. + uint32_t flush_attempts_ = 0; + // (Batch mode only) Number of bytes buffered in writer. It is used as the + // return value of a successful Flush(). + uint32_t bytes_buffered_ = 0; + uint32_t final_bytes_of_last_packet_ = 0; + uint32_t final_bytes_of_previous_packet_ = 0; + bool use_tagging_decrypter_ = false; + uint32_t packets_write_attempts_ = 0; + uint32_t connection_close_packets_ = 0; + MockClock* clock_ = nullptr; + // If non-zero, the clock will pause during WritePacket for this amount of + // time. + QuicTime::Delta write_pause_time_delta_ = QuicTime::Delta::Zero(); + QuicByteCount max_packet_size_ = kMaxOutgoingPacketSize; + bool supports_release_time_ = false; + // Used to verify writer-allocated packet buffers are properly released. + std::vector<PacketBuffer*> packet_buffer_pool_; + // Buffer address => Address of the owning PacketBuffer. + QuicHashMap<char*, PacketBuffer*> packet_buffer_pool_index_; + // Indices in packet_buffer_pool_ that are not allocated. + std::list<PacketBuffer*> packet_buffer_free_list_; + // The peer address passed into WritePacket(). + QuicSocketAddress last_write_peer_address_; +}; + } // namespace test } // namespace quic