Refactor TlsHandshaker classes
QuicCryptoClientConfig and QuicCryptoServerConfig each own an SSL_CTX,
which is currently created by TlsHandshaker. Those crypto config classes
can't take a dependency on TlsHandshaker (because TlsHandshaker depends on
classes that have a dependency in the other direction), so the SSL_CTX is
instead passed into the crypto config constructors. The SSL_CTX shouldn't
be exposed like this, as it's essentially an implementation detail of the
crypto handshake.
This CL splits TlsHandshaker in two. TlsConnection (and its subclasses)
live in quic/core/crypto and handle the callbacks from BoringSSL. In turn,
TlsConnection forwards those callbacks to a delegate, which implements
them. TlsHandshaker implements this delegate and owns the TlsConnection.
gfe-relnote: refactor TLS handshake classes in QUIC; not flag protected
PiperOrigin-RevId: 253140899
Change-Id: Ie907a7f61798c29a385be15ea0f53403b86508ab
diff --git a/quic/core/tls_handshaker.cc b/quic/core/tls_handshaker.cc
index aeadfd4..3c82a49 100644
--- a/quic/core/tls_handshaker.cc
+++ b/quic/core/tls_handshaker.cc
@@ -13,42 +13,12 @@
namespace quic {
-namespace {
-
-class SslIndexSingleton {
- public:
- static SslIndexSingleton* GetInstance() {
- static SslIndexSingleton* instance = new SslIndexSingleton();
- return instance;
- }
-
- int HandshakerIndex() const { return ssl_ex_data_index_handshaker_; }
-
- private:
- SslIndexSingleton() {
- CRYPTO_library_init();
- ssl_ex_data_index_handshaker_ =
- SSL_get_ex_new_index(0, nullptr, nullptr, nullptr, nullptr);
- CHECK_LE(0, ssl_ex_data_index_handshaker_);
- }
-
- SslIndexSingleton(const SslIndexSingleton&) = delete;
- SslIndexSingleton& operator=(const SslIndexSingleton&) = delete;
-
- int ssl_ex_data_index_handshaker_;
-};
-
-} // namespace
-
TlsHandshaker::TlsHandshaker(QuicCryptoStream* stream,
QuicSession* session,
SSL_CTX* ssl_ctx)
: stream_(stream), session_(session) {
QUIC_BUG_IF(!GetQuicFlag(FLAGS_quic_supports_tls_handshake))
<< "Attempted to create TLS handshaker when TLS is disabled";
- ssl_.reset(SSL_new(ssl_ctx));
- SSL_set_ex_data(ssl(), SslIndexSingleton::GetInstance()->HandshakerIndex(),
- this);
}
TlsHandshaker::~TlsHandshaker() {}
@@ -62,7 +32,7 @@
// just received input at. If they mismatch, should ProcessInput return true
// or false? If data is for a future encryption level, it should be queued for
// later?
- if (SSL_provide_quic_data(ssl(), BoringEncryptionLevel(level),
+ if (SSL_provide_quic_data(ssl(), TlsConnection::BoringEncryptionLevel(level),
reinterpret_cast<const uint8_t*>(input.data()),
input.size()) != 1) {
// SSL_provide_quic_data can fail for 3 reasons:
@@ -84,58 +54,6 @@
return true;
}
-// static
-bssl::UniquePtr<SSL_CTX> TlsHandshaker::CreateSslCtx() {
- CRYPTO_library_init();
- bssl::UniquePtr<SSL_CTX> ssl_ctx(SSL_CTX_new(TLS_with_buffers_method()));
- SSL_CTX_set_min_proto_version(ssl_ctx.get(), TLS1_3_VERSION);
- SSL_CTX_set_max_proto_version(ssl_ctx.get(), TLS1_3_VERSION);
- SSL_CTX_set_quic_method(ssl_ctx.get(), &kSslQuicMethod);
- return ssl_ctx;
-}
-
-// static
-TlsHandshaker* TlsHandshaker::HandshakerFromSsl(const SSL* ssl) {
- return reinterpret_cast<TlsHandshaker*>(SSL_get_ex_data(
- ssl, SslIndexSingleton::GetInstance()->HandshakerIndex()));
-}
-
-// static
-EncryptionLevel TlsHandshaker::QuicEncryptionLevel(
- enum ssl_encryption_level_t level) {
- switch (level) {
- case ssl_encryption_initial:
- return ENCRYPTION_INITIAL;
- case ssl_encryption_early_data:
- return ENCRYPTION_ZERO_RTT;
- case ssl_encryption_handshake:
- return ENCRYPTION_HANDSHAKE;
- case ssl_encryption_application:
- return ENCRYPTION_FORWARD_SECURE;
- default:
- QUIC_BUG << "Invalid ssl_encryption_level_t " << static_cast<int>(level);
- return ENCRYPTION_INITIAL;
- }
-}
-
-// static
-enum ssl_encryption_level_t TlsHandshaker::BoringEncryptionLevel(
- EncryptionLevel level) {
- switch (level) {
- case ENCRYPTION_INITIAL:
- return ssl_encryption_initial;
- case ENCRYPTION_HANDSHAKE:
- return ssl_encryption_handshake;
- case ENCRYPTION_ZERO_RTT:
- return ssl_encryption_early_data;
- case ENCRYPTION_FORWARD_SECURE:
- return ssl_encryption_application;
- default:
- QUIC_BUG << "Invalid encryption level " << static_cast<int>(level);
- return ssl_encryption_initial;
- }
-}
-
const EVP_MD* TlsHandshaker::Prf() {
return EVP_get_digestbynid(
SSL_CIPHER_get_prf_nid(SSL_get_pending_cipher(ssl())));
@@ -159,53 +77,6 @@
return decrypter;
}
-const SSL_QUIC_METHOD TlsHandshaker::kSslQuicMethod{
- TlsHandshaker::SetEncryptionSecretCallback,
- TlsHandshaker::WriteMessageCallback, TlsHandshaker::FlushFlightCallback,
- TlsHandshaker::SendAlertCallback};
-
-// static
-int TlsHandshaker::SetEncryptionSecretCallback(
- SSL* ssl,
- enum ssl_encryption_level_t level,
- const uint8_t* read_key,
- const uint8_t* write_key,
- size_t secret_len) {
- // TODO(nharper): replace these vectors and memcpys with spans (which
- // unfortunately doesn't yet exist in quic/platform/api).
- std::vector<uint8_t> read_secret(secret_len), write_secret(secret_len);
- memcpy(read_secret.data(), read_key, secret_len);
- memcpy(write_secret.data(), write_key, secret_len);
- HandshakerFromSsl(ssl)->SetEncryptionSecret(QuicEncryptionLevel(level),
- read_secret, write_secret);
- return 1;
-}
-
-// static
-int TlsHandshaker::WriteMessageCallback(SSL* ssl,
- enum ssl_encryption_level_t level,
- const uint8_t* data,
- size_t len) {
- HandshakerFromSsl(ssl)->WriteMessage(
- QuicEncryptionLevel(level),
- QuicStringPiece(reinterpret_cast<const char*>(data), len));
- return 1;
-}
-
-// static
-int TlsHandshaker::FlushFlightCallback(SSL* ssl) {
- HandshakerFromSsl(ssl)->FlushFlight();
- return 1;
-}
-
-// static
-int TlsHandshaker::SendAlertCallback(SSL* ssl,
- enum ssl_encryption_level_t level,
- uint8_t desc) {
- HandshakerFromSsl(ssl)->SendAlert(QuicEncryptionLevel(level), desc);
- return 1;
-}
-
void TlsHandshaker::SetEncryptionSecret(
EncryptionLevel level,
const std::vector<uint8_t>& read_secret,