Project import generated by Copybara.
PiperOrigin-RevId: 237361882
Change-Id: I109a68f44db867b20f8c6a7732b0ce657133e52a
diff --git a/quic/core/tls_handshaker_test.cc b/quic/core/tls_handshaker_test.cc
new file mode 100644
index 0000000..2d678dd
--- /dev/null
+++ b/quic/core/tls_handshaker_test.cc
@@ -0,0 +1,417 @@
+// Copyright (c) 2017 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "net/third_party/quiche/src/quic/core/quic_utils.h"
+#include "net/third_party/quiche/src/quic/core/tls_client_handshaker.h"
+#include "net/third_party/quiche/src/quic/core/tls_server_handshaker.h"
+#include "net/third_party/quiche/src/quic/platform/api/quic_arraysize.h"
+#include "net/third_party/quiche/src/quic/platform/api/quic_ptr_util.h"
+#include "net/third_party/quiche/src/quic/platform/api/quic_string.h"
+#include "net/third_party/quiche/src/quic/platform/api/quic_test.h"
+#include "net/third_party/quiche/src/quic/test_tools/crypto_test_utils.h"
+#include "net/third_party/quiche/src/quic/test_tools/fake_proof_source.h"
+#include "net/third_party/quiche/src/quic/test_tools/mock_quic_session_visitor.h"
+#include "net/third_party/quiche/src/quic/test_tools/quic_test_utils.h"
+
+namespace quic {
+namespace test {
+namespace {
+
+using ::testing::_;
+
+class FakeProofVerifier : public ProofVerifier {  // Deferrable test verifier.
+ public:
+ FakeProofVerifier()
+ : verifier_(crypto_test_utils::ProofVerifierForTesting()) {}
+
+ QuicAsyncStatus VerifyProof(
+ const QuicString& hostname,
+ const uint16_t port,
+ const QuicString& server_config,
+ QuicTransportVersion quic_version,
+ QuicStringPiece chlo_hash,
+ const std::vector<QuicString>& certs,
+ const QuicString& cert_sct,
+ const QuicString& signature,
+ const ProofVerifyContext* context,
+ QuicString* error_details,
+ std::unique_ptr<ProofVerifyDetails>* details,
+ std::unique_ptr<ProofVerifierCallback> callback) override {
+ return verifier_->VerifyProof(  // Always forwarded; never queued.
+ hostname, port, server_config, quic_version, chlo_hash, certs, cert_sct,
+ signature, context, error_details, details, std::move(callback));
+ }
+
+ QuicAsyncStatus VerifyCertChain(
+ const QuicString& hostname,
+ const std::vector<QuicString>& certs,
+ const ProofVerifyContext* context,
+ QuicString* error_details,
+ std::unique_ptr<ProofVerifyDetails>* details,
+ std::unique_ptr<ProofVerifierCallback> callback) override {
+ if (!active_) {  // Not activated: answer through the delegate right away.
+ return verifier_->VerifyCertChain(hostname, certs, context, error_details,
+ details, std::move(callback));
+ }
+ pending_ops_.push_back(QuicMakeUnique<VerifyChainPendingOp>(
+ hostname, certs, context, error_details, details, std::move(callback),
+ verifier_.get()));
+ return QUIC_PENDING;  // Completed later by InvokePendingCallback().
+ }
+
+ std::unique_ptr<ProofVerifyContext> CreateDefaultContext() override {
+ return nullptr;  // This fake provides no default context.
+ }
+
+ void Activate() { active_ = true; }  // Start queueing VerifyCertChain calls.
+
+ size_t NumPendingCallbacks() const { return pending_ops_.size(); }
+
+ void InvokePendingCallback(size_t n) {
+ CHECK(NumPendingCallbacks() > n);  // n indexes pending_ops_ (0-based).
+ pending_ops_[n]->Run();  // Verifies via the delegate, then runs the callback.
+ auto it = pending_ops_.begin() + n;
+ pending_ops_.erase(it);  // Drop the finished operation.
+ }
+
+ private:
+ // ProofVerifierCallback that records a test failure (via FAIL()) if it is
+ // ever run; VerifyChainPendingOp::Run() hands it to the delegate.
+ class FailingProofVerifierCallback : public ProofVerifierCallback {
+ public:
+ void Run(bool ok,
+ const QuicString& error_details,
+ std::unique_ptr<ProofVerifyDetails>* details) override {
+ FAIL();
+ }
+ };
+
+ class VerifyChainPendingOp {  // One captured VerifyCertChain call.
+ public:
+ VerifyChainPendingOp(const QuicString& hostname,
+ const std::vector<QuicString>& certs,
+ const ProofVerifyContext* context,
+ QuicString* error_details,
+ std::unique_ptr<ProofVerifyDetails>* details,
+ std::unique_ptr<ProofVerifierCallback> callback,
+ ProofVerifier* delegate)
+ : hostname_(hostname),
+ certs_(certs),
+ context_(context),
+ error_details_(error_details),
+ details_(details),
+ callback_(std::move(callback)),
+ delegate_(delegate) {}
+
+ void Run() {
+ // Replays the captured call on the delegate. The delegate (created by
+ // crypto_test_utils::ProofVerifierForTesting) is expected to complete
+ // synchronously, so it is given a FailingProofVerifierCallback; the
+ // original callback is run here once a non-pending status is asserted.
+ QuicAsyncStatus status = delegate_->VerifyCertChain(
+ hostname_, certs_, context_, error_details_, details_,
+ QuicMakeUnique<FailingProofVerifierCallback>());
+ ASSERT_NE(status, QUIC_PENDING);  // On failure, callback_ is not run.
+ callback_->Run(status == QUIC_SUCCESS, *error_details_, details_);
+ }
+
+ private:
+ QuicString hostname_;
+ std::vector<QuicString> certs_;
+ const ProofVerifyContext* context_;  // Not owned.
+ QuicString* error_details_;  // Not owned; dereferenced in Run().
+ std::unique_ptr<ProofVerifyDetails>* details_;  // Not owned.
+ std::unique_ptr<ProofVerifierCallback> callback_;
+ ProofVerifier* delegate_;  // Not owned.
+ };
+
+ std::unique_ptr<ProofVerifier> verifier_;  // Verifier that does the real work.
+ bool active_ = false;  // When true, VerifyCertChain calls are queued.
+ std::vector<std::unique_ptr<VerifyChainPendingOp>> pending_ops_;
+};
+
+class TestQuicCryptoStream : public QuicCryptoStream {
+ public:
+  explicit TestQuicCryptoStream(QuicSession* session)
+      : QuicCryptoStream(session) {}
+
+  ~TestQuicCryptoStream() override = default;
+
+  // Subclasses supply the TLS handshaker that backs every accessor below.
+  virtual TlsHandshaker* handshaker() const = 0;
+
+  bool encryption_established() const override {
+    return handshaker()->encryption_established();
+  }
+
+  bool handshake_confirmed() const override {
+    return handshaker()->handshake_confirmed();
+  }
+
+  const QuicCryptoNegotiatedParameters& crypto_negotiated_params()
+      const override {
+    return handshaker()->crypto_negotiated_params();
+  }
+
+  CryptoMessageParser* crypto_message_parser() override {
+    return handshaker()->crypto_message_parser();
+  }
+
+  // Queues |data| instead of sending it; see SendCryptoMessagesToPeer().
+  void WriteCryptoData(EncryptionLevel level, QuicStringPiece data) override {
+    pending_writes_.emplace_back(QuicString(data), level);
+  }
+
+  const std::vector<std::pair<QuicString, EncryptionLevel>>& pending_writes() {
+    return pending_writes_;
+  }
+
+  // Delivers every queued write to |stream|'s CryptoMessageParser, in order,
+  // and then empties the queue. If the parser rejects a write,
+  // CloseConnectionWithDetails() is called on this stream and delivery stops.
+  //
+  // This bypasses QuicStream::OnStreamFrame on purpose:
+  // QuicCryptoStream::OnDataAvailable asks the QuicConnection for the
+  // EncryptionLevel, but these messages never reach a framer or connection,
+  // so the level recorded at write time is handed to ProcessInput directly.
+  void SendCryptoMessagesToPeer(QuicCryptoStream* stream) {
+    QUIC_LOG(INFO) << "Sending " << pending_writes_.size() << " frames";
+    for (size_t i = 0; i < pending_writes_.size(); ++i) {
+      const auto& write = pending_writes_[i];
+      CryptoMessageParser* parser = stream->crypto_message_parser();
+      if (parser->ProcessInput(write.first, write.second)) {
+        continue;
+      }
+      CloseConnectionWithDetails(parser->error(), parser->error_detail());
+      break;
+    }
+    pending_writes_.clear();
+  }
+
+ private:
+  // Data handed to WriteCryptoData(), paired with its encryption level.
+  std::vector<std::pair<QuicString, EncryptionLevel>> pending_writes_;
+};
+
+class TestQuicCryptoClientStream : public TestQuicCryptoStream {
+ public:
+  // Builds a client handshaker for test.example.com:443 that uses this
+  // stream's own FakeProofVerifier and client SSL_CTX.
+  explicit TestQuicCryptoClientStream(QuicSession* session)
+      : TestQuicCryptoStream(session),
+        proof_verifier_(QuicMakeUnique<FakeProofVerifier>()),
+        ssl_ctx_(TlsClientHandshaker::CreateSslCtx()),
+        handshaker_(QuicMakeUnique<TlsClientHandshaker>(
+            this, session, QuicServerId("test.example.com", 443, false),
+            proof_verifier_.get(), ssl_ctx_.get(),
+            crypto_test_utils::ProofVerifyContextForTesting(),
+            "quic-tester")) {}
+
+  ~TestQuicCryptoClientStream() override = default;
+
+  TlsHandshaker* handshaker() const override { return handshaker_.get(); }
+
+  bool CryptoConnect() { return handshaker_->CryptoConnect(); }
+
+  FakeProofVerifier* GetFakeProofVerifier() const {
+    return proof_verifier_.get();
+  }
+
+ private:
+  // Declared ahead of |handshaker_|, which is handed raw pointers to both.
+  std::unique_ptr<FakeProofVerifier> proof_verifier_;
+  bssl::UniquePtr<SSL_CTX> ssl_ctx_;
+  std::unique_ptr<TlsClientHandshaker> handshaker_;
+};
+
+class TestQuicCryptoServerStream : public TestQuicCryptoStream {
+ public:
+  // |proof_source| is borrowed (not owned) and shared with the handshaker.
+  TestQuicCryptoServerStream(QuicSession* session,
+                             FakeProofSource* proof_source)
+      : TestQuicCryptoStream(session),
+        proof_source_(proof_source),
+        ssl_ctx_(TlsServerHandshaker::CreateSslCtx()),
+        handshaker_(QuicMakeUnique<TlsServerHandshaker>(
+            this, session, ssl_ctx_.get(), proof_source_)) {}
+
+  ~TestQuicCryptoServerStream() override = default;
+
+  // Forwards to TlsServerHandshaker::CancelOutstandingCallbacks().
+  void CancelOutstandingCallbacks() {
+    handshaker_->CancelOutstandingCallbacks();
+  }
+
+  TlsHandshaker* handshaker() const override { return handshaker_.get(); }
+
+  FakeProofSource* GetFakeProofSource() const { return proof_source_; }
+
+ private:
+  FakeProofSource* proof_source_;  // Not owned.
+  bssl::UniquePtr<SSL_CTX> ssl_ctx_;
+  std::unique_ptr<TlsServerHandshaker> handshaker_;
+};
+
+void ExchangeHandshakeMessages(TestQuicCryptoStream* client,
+                               TestQuicCryptoStream* server) {
+  auto idle = [](auto* s) { return s->pending_writes().empty(); };
+  while (!(idle(client) && idle(server))) {  // Relay until both are drained.
+    client->SendCryptoMessagesToPeer(server);
+    server->SendCryptoMessagesToPeer(client);
+  }
+}
+
+class TlsHandshakerTest : public QuicTest {  // Client/server stream fixture.
+ public:
+ TlsHandshakerTest()
+ : client_conn_(new MockQuicConnection(&conn_helper_,
+ &alarm_factory_,
+ Perspective::IS_CLIENT)),
+ server_conn_(new MockQuicConnection(&conn_helper_,
+ &alarm_factory_,
+ Perspective::IS_SERVER)),
+ client_session_(client_conn_, /*create_mock_crypto_stream=*/false),
+ server_session_(server_conn_, /*create_mock_crypto_stream=*/false) {
+ client_stream_ = new TestQuicCryptoClientStream(&client_session_);
+ client_session_.SetCryptoStream(client_stream_);  // Install on the session.
+ server_stream_ =
+ new TestQuicCryptoServerStream(&server_session_, &proof_source_);
+ server_session_.SetCryptoStream(server_stream_);  // Install on the session.
+ client_session_.Initialize();
+ server_session_.Initialize();
+ EXPECT_FALSE(client_stream_->encryption_established());  // Not started yet.
+ EXPECT_FALSE(client_stream_->handshake_confirmed());
+ EXPECT_FALSE(server_stream_->encryption_established());
+ EXPECT_FALSE(server_stream_->handshake_confirmed());
+ }
+
+ MockQuicConnectionHelper conn_helper_;  // Shared by both connections.
+ MockAlarmFactory alarm_factory_;  // Shared by both connections.
+ MockQuicConnection* client_conn_;  // Presumably session-owned; confirm.
+ MockQuicConnection* server_conn_;  // Presumably session-owned; confirm.
+ MockQuicSession client_session_;
+ MockQuicSession server_session_;
+
+ FakeProofSource proof_source_;  // Handed to server_stream_ by pointer.
+ TestQuicCryptoClientStream* client_stream_;  // Assigned in the constructor.
+ TestQuicCryptoServerStream* server_stream_;  // Assigned in the constructor.
+};
+
+TEST_F(TlsHandshakerTest, CryptoHandshake) {  // Happy path, nothing deferred.
+ EXPECT_CALL(*client_conn_, CloseConnection(_, _, _)).Times(0);
+ EXPECT_CALL(*server_conn_, CloseConnection(_, _, _)).Times(0);
+ client_stream_->CryptoConnect();
+ ExchangeHandshakeMessages(client_stream_, server_stream_);
+ // Neither fake was activated, so one exchange is expected to finish both.
+ EXPECT_TRUE(client_stream_->handshake_confirmed());
+ EXPECT_TRUE(client_stream_->encryption_established());
+ EXPECT_TRUE(server_stream_->handshake_confirmed());
+ EXPECT_TRUE(server_stream_->encryption_established());
+}
+
+TEST_F(TlsHandshakerTest, HandshakeWithAsyncProofSource) {
+  EXPECT_CALL(*client_conn_, CloseConnection(_, _, _)).Times(0);
+  EXPECT_CALL(*server_conn_, CloseConnection(_, _, _)).Times(0);
+  // Switch the server's FakeProofSource into deferred mode so that its
+  // ComputeTlsSignature call is captured rather than answered inline.
+  FakeProofSource* const fake_source = server_stream_->GetFakeProofSource();
+  fake_source->Activate();
+
+  // Run the handshake as far as it can go without the signature.
+  client_stream_->CryptoConnect();
+  ExchangeHandshakeMessages(client_stream_, server_stream_);
+
+  ASSERT_EQ(fake_source->NumPendingCallbacks(), 1);
+  fake_source->InvokePendingCallback(0);
+  // With the captured operation released, the remaining flights can proceed.
+  ExchangeHandshakeMessages(client_stream_, server_stream_);
+
+  EXPECT_TRUE(client_stream_->handshake_confirmed());
+  EXPECT_TRUE(client_stream_->encryption_established());
+  EXPECT_TRUE(server_stream_->handshake_confirmed());
+  EXPECT_TRUE(server_stream_->encryption_established());
+}
+
+TEST_F(TlsHandshakerTest, CancelPendingProofSource) {
+  EXPECT_CALL(*client_conn_, CloseConnection(_, _, _)).Times(0);
+  EXPECT_CALL(*server_conn_, CloseConnection(_, _, _)).Times(0);
+  // Make FakeProofSource capture ComputeTlsSignature instead of answering it.
+  FakeProofSource* proof_source = server_stream_->GetFakeProofSource();
+  proof_source->Activate();
+
+  // Start handshake.
+  client_stream_->CryptoConnect();
+  ExchangeHandshakeMessages(client_stream_, server_stream_);
+
+  ASSERT_EQ(proof_source->NumPendingCallbacks(), 1);
+  // |server_stream_| is a non-owning pointer, so nulling it cancels nothing;
+  // cancel the outstanding callback explicitly before running the operation.
+  server_stream_->CancelOutstandingCallbacks();
+  proof_source->InvokePendingCallback(0);
+}
+
+TEST_F(TlsHandshakerTest, HandshakeWithAsyncProofVerifier) {
+  EXPECT_CALL(*client_conn_, CloseConnection(_, _, _)).Times(0);
+  EXPECT_CALL(*server_conn_, CloseConnection(_, _, _)).Times(0);
+  // Switch the client's FakeProofVerifier into deferred mode so that its
+  // VerifyCertChain call is queued rather than answered inline.
+  FakeProofVerifier* const fake_verifier = client_stream_->GetFakeProofVerifier();
+  fake_verifier->Activate();
+
+  // Run the handshake as far as it can go without the verification result.
+  client_stream_->CryptoConnect();
+  ExchangeHandshakeMessages(client_stream_, server_stream_);
+
+  ASSERT_EQ(fake_verifier->NumPendingCallbacks(), 1u);
+  fake_verifier->InvokePendingCallback(0);
+  // With the queued verification released, the remaining flights can proceed.
+  ExchangeHandshakeMessages(client_stream_, server_stream_);
+
+  EXPECT_TRUE(client_stream_->handshake_confirmed());
+  EXPECT_TRUE(client_stream_->encryption_established());
+  EXPECT_TRUE(server_stream_->handshake_confirmed());
+  EXPECT_TRUE(server_stream_->encryption_established());
+}
+
+TEST_F(TlsHandshakerTest, ClientConnectionClosedOnTlsError) {
+  // Start the client handshake, which emits its ClientHello.
+  client_stream_->CryptoConnect();
+  EXPECT_CALL(*client_conn_, CloseConnection(QUIC_HANDSHAKE_FAILED, _, _));
+
+  // A ServerHello header with a zero-length body; the client is expected to
+  // reject it.
+  const char empty_server_hello[] = {
+      // Handshake struct (RFC 8446 appendix B.3)
+      2,        // HandshakeType server_hello
+      0, 0, 0,  // uint24 length
+  };
+  const QuicStringPiece payload(empty_server_hello,
+                                sizeof(empty_server_hello));
+  server_stream_->WriteCryptoData(ENCRYPTION_NONE, payload);
+  server_stream_->SendCryptoMessagesToPeer(client_stream_);
+
+  EXPECT_FALSE(client_stream_->handshake_confirmed());
+}
+
+TEST_F(TlsHandshakerTest, ServerConnectionClosedOnTlsError) {
+  EXPECT_CALL(*server_conn_, CloseConnection(QUIC_HANDSHAKE_FAILED, _, _));
+
+  // A ClientHello header with a zero-length body; the server is expected to
+  // reject it.
+  const char empty_client_hello[] = {
+      // Handshake struct (RFC 8446 appendix B.3)
+      1,        // HandshakeType client_hello
+      0, 0, 0,  // uint24 length
+  };
+  const QuicStringPiece payload(empty_client_hello,
+                                sizeof(empty_client_hello));
+  client_stream_->WriteCryptoData(ENCRYPTION_NONE, payload);
+  client_stream_->SendCryptoMessagesToPeer(server_stream_);
+
+  EXPECT_FALSE(server_stream_->handshake_confirmed());
+}
+
+} // namespace
+} // namespace test
+} // namespace quic