Refactor QUIC TlsHandshaker classes
All machinery that drives SSL_do_handshake is moved into the TlsHandshaker base class.
Not protected by a flag.
PiperOrigin-RevId: 340346649
Change-Id: Ie7025c135db1b97ddedfacd7fe781d5baac98b7c
diff --git a/quic/core/tls_handshaker.h b/quic/core/tls_handshaker.h
index 077e373..9288592 100644
--- a/quic/core/tls_handshaker.h
+++ b/quic/core/tls_handshaker.h
@@ -50,12 +50,38 @@
ssl_early_data_reason_t EarlyDataReason() const;
std::unique_ptr<QuicDecrypter> AdvanceKeysAndCreateCurrentOneRttDecrypter();
std::unique_ptr<QuicEncrypter> CreateCurrentOneRttEncrypter();
+ virtual HandshakeState GetHandshakeState() const = 0;
protected:
- virtual void AdvanceHandshake() = 0;
+ // Called when a new message is received on the crypto stream and is available
+ // for the TLS stack to read.
+ void AdvanceHandshake();
- virtual void CloseConnection(QuicErrorCode error,
- const std::string& reason_phrase) = 0;
+ void CloseConnection(QuicErrorCode error, const std::string& reason_phrase);
+
+ void OnConnectionClosed(QuicErrorCode error, ConnectionCloseSource source);
+
+ bool is_connection_closed() const { return is_connection_closed_; }
+
+ // Called when |SSL_do_handshake| returns 1, indicating that the handshake has
+ // finished. Note that due to 0-RTT, the handshake may "finish" twice;
+ // |SSL_in_early_data| can be used to determine whether the handshake is truly
+ // done.
+ virtual void FinishHandshake() = 0;
+
+ // Called when a handshake message is received after the handshake is
+ // complete.
+ virtual void ProcessPostHandshakeMessage() = 0;
+
+ // Called when an unexpected error code is received from |SSL_get_error|. If a
+ // subclass can expect more than just a single error (as provided by
+ // |set_expected_ssl_error|), it can override this method to handle that case.
+ virtual bool ShouldCloseConnectionOnUnexpectedError(int ssl_error);
+
+ void set_expected_ssl_error(int ssl_error) {
+ expected_ssl_error_ = ssl_error;
+ }
+ int expected_ssl_error() const { return expected_ssl_error_; }
// Returns the PRF used by the cipher suite negotiated in the TLS handshake.
const EVP_MD* Prf(const SSL_CIPHER* cipher);
@@ -101,6 +127,9 @@
void SendAlert(EncryptionLevel level, uint8_t desc) override;
private:
+ int expected_ssl_error_ = SSL_ERROR_WANT_READ;
+ bool is_connection_closed_ = false;
+
QuicCryptoStream* stream_;
HandshakerDelegateInterface* handshaker_delegate_;