Refactor QUIC TlsHandshaker classes

All machinery to drive SSL_do_handshake is moved into TlsHandshaker.

Protected by not protected.

PiperOrigin-RevId: 340346649
Change-Id: Ie7025c135db1b97ddedfacd7fe781d5baac98b7c
diff --git a/quic/core/tls_client_handshaker.cc b/quic/core/tls_client_handshaker.cc
index c38724b..16864bf 100644
--- a/quic/core/tls_client_handshaker.cc
+++ b/quic/core/tls_client_handshaker.cc
@@ -36,7 +36,7 @@
 
   parent_->verify_details_ = std::move(*details);
   parent_->verify_result_ = ok ? ssl_verify_ok : ssl_verify_invalid;
-  parent_->state_ = STATE_HANDSHAKE_RUNNING;
+  parent_->set_expected_ssl_error(SSL_ERROR_WANT_READ);
   parent_->proof_verify_callback_ = nullptr;
   if (parent_->verify_details_) {
     parent_->proof_handler_->OnProofVerifyDetailsAvailable(
@@ -77,8 +77,6 @@
 }
 
 bool TlsClientHandshaker::CryptoConnect() {
-  state_ = STATE_HANDSHAKE_RUNNING;
-
   if (!pre_shared_key_.empty()) {
     // TODO(b/154162689) add PSK support to QUIC+TLS.
     std::string error_details =
@@ -273,7 +271,7 @@
   }
 
   session()->OnConfigNegotiated();
-  if (state_ == STATE_CONNECTION_CLOSED) {
+  if (is_connection_closed()) {
     *error_details =
         "Session closed the connection when parsing negotiated config.";
     return false;
@@ -286,12 +284,12 @@
 }
 
 bool TlsClientHandshaker::IsResumption() const {
-  QUIC_BUG_IF(!one_rtt_keys_available_);
+  QUIC_BUG_IF(!one_rtt_keys_available());
   return SSL_session_reused(ssl()) == 1;
 }
 
 bool TlsClientHandshaker::EarlyDataAccepted() const {
-  QUIC_BUG_IF(!one_rtt_keys_available_);
+  QUIC_BUG_IF(!one_rtt_keys_available());
   return SSL_early_data_accepted(ssl()) == 1;
 }
 
@@ -300,7 +298,7 @@
 }
 
 bool TlsClientHandshaker::ReceivedInchoateReject() const {
-  QUIC_BUG_IF(!one_rtt_keys_available_);
+  QUIC_BUG_IF(!one_rtt_keys_available());
   // REJ messages are a QUIC crypto feature, so TLS always returns false.
   return false;
 }
@@ -319,7 +317,7 @@
 }
 
 bool TlsClientHandshaker::one_rtt_keys_available() const {
-  return one_rtt_keys_available_;
+  return state_ >= HANDSHAKE_COMPLETE;
 }
 
 const QuicCryptoNegotiatedParameters&
@@ -332,16 +330,7 @@
 }
 
 HandshakeState TlsClientHandshaker::GetHandshakeState() const {
-  if (handshake_confirmed_) {
-    return HANDSHAKE_CONFIRMED;
-  }
-  if (one_rtt_keys_available_) {
-    return HANDSHAKE_COMPLETE;
-  }
-  if (state_ >= STATE_ENCRYPTION_HANDSHAKE_DATA_SENT) {
-    return HANDSHAKE_PROCESSED;
-  }
-  return HANDSHAKE_START;
+  return state_;
 }
 
 size_t TlsClientHandshaker::BufferSizeLimitForLevel(
@@ -376,13 +365,13 @@
   handshaker_delegate()->DiscardOldDecryptionKey(ENCRYPTION_INITIAL);
 }
 
-void TlsClientHandshaker::OnConnectionClosed(QuicErrorCode /*error*/,
-                                             ConnectionCloseSource /*source*/) {
-  state_ = STATE_CONNECTION_CLOSED;
+void TlsClientHandshaker::OnConnectionClosed(QuicErrorCode error,
+                                             ConnectionCloseSource source) {
+  TlsHandshaker::OnConnectionClosed(error, source);
 }
 
 void TlsClientHandshaker::OnHandshakeDoneReceived() {
-  if (!one_rtt_keys_available_) {
+  if (!one_rtt_keys_available()) {
     CloseConnection(QUIC_HANDSHAKE_FAILED,
                     "Unexpected handshake done received");
     return;
@@ -394,7 +383,7 @@
     EncryptionLevel level,
     const SSL_CIPHER* cipher,
     const std::vector<uint8_t>& write_secret) {
-  if (state_ == STATE_CONNECTION_CLOSED) {
+  if (is_connection_closed()) {
     return;
   }
   if (level == ENCRYPTION_FORWARD_SECURE || level == ENCRYPTION_ZERO_RTT) {
@@ -407,72 +396,15 @@
 }
 
 void TlsClientHandshaker::OnHandshakeConfirmed() {
-  DCHECK(one_rtt_keys_available_);
-  if (handshake_confirmed_) {
+  DCHECK(one_rtt_keys_available());
+  if (state_ >= HANDSHAKE_CONFIRMED) {
     return;
   }
-  handshake_confirmed_ = true;
+  state_ = HANDSHAKE_CONFIRMED;
   handshaker_delegate()->DiscardOldEncryptionKey(ENCRYPTION_HANDSHAKE);
   handshaker_delegate()->DiscardOldDecryptionKey(ENCRYPTION_HANDSHAKE);
 }
 
-void TlsClientHandshaker::AdvanceHandshake() {
-  if (state_ == STATE_CONNECTION_CLOSED) {
-    QUIC_LOG(INFO)
-        << "TlsClientHandshaker received message after connection closed";
-    return;
-  }
-  if (state_ == STATE_IDLE) {
-    CloseConnection(QUIC_HANDSHAKE_FAILED,
-                    "Client observed TLS handshake idle failure");
-    return;
-  }
-  if (state_ == STATE_HANDSHAKE_COMPLETE) {
-    int rv = SSL_process_quic_post_handshake(ssl());
-    if (rv != 1) {
-      CloseConnection(QUIC_HANDSHAKE_FAILED, "Unexpected post-handshake data");
-    }
-    return;
-  }
-
-  QUIC_LOG(INFO) << "TlsClientHandshaker: continuing handshake";
-  int rv = SSL_do_handshake(ssl());
-  if (rv == 1) {
-    FinishHandshake();
-    return;
-  }
-  int ssl_error = SSL_get_error(ssl(), rv);
-  bool should_close = true;
-  if (ssl_error == SSL_ERROR_EARLY_DATA_REJECTED) {
-    HandleZeroRttReject();
-    return;
-  }
-  switch (state_) {
-    case STATE_HANDSHAKE_RUNNING:
-      should_close = ssl_error != SSL_ERROR_WANT_READ;
-      break;
-    case STATE_CERT_VERIFY_PENDING:
-      should_close = ssl_error != SSL_ERROR_WANT_CERTIFICATE_VERIFY;
-      break;
-    default:
-      should_close = true;
-  }
-  if (should_close && state_ != STATE_CONNECTION_CLOSED) {
-    // TODO(nharper): Surface error details from the error queue when ssl_error
-    // is SSL_ERROR_SSL.
-    QUIC_LOG(WARNING) << "SSL_do_handshake failed; closing connection";
-    CloseConnection(QUIC_HANDSHAKE_FAILED,
-                    "Client observed TLS handshake failure");
-  }
-}
-
-void TlsClientHandshaker::CloseConnection(QuicErrorCode error,
-                                          const std::string& reason_phrase) {
-  DCHECK(!reason_phrase.empty());
-  state_ = STATE_CONNECTION_CLOSED;
-  stream()->OnUnrecoverableError(error, reason_phrase);
-}
-
 void TlsClientHandshaker::FinishHandshake() {
   // Fill crypto_negotiated_params_:
   const SSL_CIPHER* cipher = SSL_get_current_cipher(ssl());
@@ -495,7 +427,6 @@
     return;
   }
   QUIC_LOG(INFO) << "Client: handshake finished";
-  state_ = STATE_HANDSHAKE_COMPLETE;
 
   std::string error_details;
   if (!ProcessTransportParameters(&error_details)) {
@@ -531,10 +462,26 @@
   session()->OnAlpnSelected(received_alpn_string);
   QUIC_DLOG(INFO) << "Client: server selected ALPN: '" << received_alpn_string
                   << "'";
-  one_rtt_keys_available_ = true;
+  state_ = HANDSHAKE_COMPLETE;
   handshaker_delegate()->OnTlsHandshakeComplete();
 }
 
+void TlsClientHandshaker::ProcessPostHandshakeMessage() {
+  int rv = SSL_process_quic_post_handshake(ssl());
+  if (rv != 1) {
+    CloseConnection(QUIC_HANDSHAKE_FAILED, "Unexpected post-handshake data");
+  }
+}
+
+bool TlsClientHandshaker::ShouldCloseConnectionOnUnexpectedError(
+    int ssl_error) {
+  if (ssl_error != SSL_ERROR_EARLY_DATA_REJECTED) {
+    return true;
+  }
+  HandleZeroRttReject();
+  return false;
+}
+
 void TlsClientHandshaker::HandleZeroRttReject() {
   QUIC_LOG(INFO) << "0-RTT handshake attempted but was rejected by the server";
   DCHECK(session_cache_);
@@ -548,7 +495,7 @@
 
 enum ssl_verify_result_t TlsClientHandshaker::VerifyCert(uint8_t* out_alert) {
   if (verify_result_ != ssl_verify_retry ||
-      state_ == STATE_CERT_VERIFY_PENDING) {
+      expected_ssl_error() == SSL_ERROR_WANT_CERTIFICATE_VERIFY) {
     enum ssl_verify_result_t result = verify_result_;
     verify_result_ = ssl_verify_retry;
     return result;
@@ -591,7 +538,7 @@
       return ssl_verify_ok;
     case QUIC_PENDING:
       proof_verify_callback_ = proof_verify_callback;
-      state_ = STATE_CERT_VERIFY_PENDING;
+      set_expected_ssl_error(SSL_ERROR_WANT_CERTIFICATE_VERIFY);
       return ssl_verify_retry;
     case QUIC_FAILURE:
     default:
@@ -625,16 +572,15 @@
 
 void TlsClientHandshaker::WriteMessage(EncryptionLevel level,
                                        absl::string_view data) {
-  if (level == ENCRYPTION_HANDSHAKE &&
-      state_ < STATE_ENCRYPTION_HANDSHAKE_DATA_SENT) {
-    state_ = STATE_ENCRYPTION_HANDSHAKE_DATA_SENT;
+  if (level == ENCRYPTION_HANDSHAKE && state_ < HANDSHAKE_PROCESSED) {
+    state_ = HANDSHAKE_PROCESSED;
   }
   TlsHandshaker::WriteMessage(level, data);
 }
 
 void TlsClientHandshaker::SetServerApplicationStateForResumption(
     std::unique_ptr<ApplicationState> application_state) {
-  DCHECK_EQ(STATE_HANDSHAKE_COMPLETE, state_);
+  DCHECK(one_rtt_keys_available());
   received_application_state_ = std::move(application_state);
   // At least one tls session is cached before application state is received. So
   // insert now.
diff --git a/quic/core/tls_client_handshaker.h b/quic/core/tls_client_handshaker.h
index f5a349f..2f16620 100644
--- a/quic/core/tls_client_handshaker.h
+++ b/quic/core/tls_client_handshaker.h
@@ -86,9 +86,9 @@
     return &tls_connection_;
   }
 
-  void AdvanceHandshake() override;
-  void CloseConnection(QuicErrorCode error,
-                       const std::string& reason_phrase) override;
+  void FinishHandshake() override;
+  void ProcessPostHandshakeMessage() override;
+  bool ShouldCloseConnectionOnUnexpectedError(int ssl_error) override;
 
   // TlsClientConnection::Delegate implementation:
   enum ssl_verify_result_t VerifyCert(uint8_t* out_alert) override;
@@ -115,19 +115,9 @@
     TlsClientHandshaker* parent_;
   };
 
-  enum State {
-    STATE_IDLE,
-    STATE_HANDSHAKE_RUNNING,
-    STATE_CERT_VERIFY_PENDING,
-    STATE_ENCRYPTION_HANDSHAKE_DATA_SENT,
-    STATE_HANDSHAKE_COMPLETE,
-    STATE_CONNECTION_CLOSED,
-  } state_ = STATE_IDLE;
-
   bool SetAlpn();
   bool SetTransportParameters();
   bool ProcessTransportParameters(std::string* error_details);
-  void FinishHandshake();
   void HandleZeroRttReject();
 
   // Called when server completes handshake (i.e., either handshake done is
@@ -169,10 +159,9 @@
   enum ssl_verify_result_t verify_result_ = ssl_verify_retry;
   std::string cert_verify_error_details_;
 
+  HandshakeState state_ = HANDSHAKE_START;
   bool encryption_established_ = false;
   bool initial_keys_dropped_ = false;
-  bool one_rtt_keys_available_ = false;
-  bool handshake_confirmed_ = false;
   QuicReferenceCountedPointer<QuicCryptoNegotiatedParameters>
       crypto_negotiated_params_;
 
diff --git a/quic/core/tls_handshaker.cc b/quic/core/tls_handshaker.cc
index 7b3efd8..6c45f84 100644
--- a/quic/core/tls_handshaker.cc
+++ b/quic/core/tls_handshaker.cc
@@ -52,6 +52,50 @@
   return true;
 }
 
+void TlsHandshaker::AdvanceHandshake() {
+  if (is_connection_closed_) {
+    return;
+  }
+  if (GetHandshakeState() >= HANDSHAKE_COMPLETE) {
+    ProcessPostHandshakeMessage();
+    return;
+  }
+
+  QUIC_LOG(INFO) << "TlsHandshaker: continuing handshake";
+  int rv = SSL_do_handshake(ssl());
+  if (rv == 1) {
+    FinishHandshake();
+    return;
+  }
+  int ssl_error = SSL_get_error(ssl(), rv);
+  if (ssl_error == expected_ssl_error_) {
+    return;
+  }
+  if (ShouldCloseConnectionOnUnexpectedError(ssl_error) &&
+      !is_connection_closed_) {
+    QUIC_VLOG(1) << "SSL_do_handshake failed; SSL_get_error returns "
+                 << ssl_error;
+    ERR_print_errors_fp(stderr);
+    CloseConnection(QUIC_HANDSHAKE_FAILED, "TLS handshake failed");
+  }
+}
+
+void TlsHandshaker::CloseConnection(QuicErrorCode error,
+                                    const std::string& reason_phrase) {
+  DCHECK(!reason_phrase.empty());
+  stream()->OnUnrecoverableError(error, reason_phrase);
+  is_connection_closed_ = true;
+}
+
+void TlsHandshaker::OnConnectionClosed(QuicErrorCode /*error*/,
+                                       ConnectionCloseSource /*source*/) {
+  is_connection_closed_ = true;
+}
+
+bool TlsHandshaker::ShouldCloseConnectionOnUnexpectedError(int /*ssl_error*/) {
+  return true;
+}
+
 size_t TlsHandshaker::BufferSizeLimitForLevel(EncryptionLevel level) const {
   return SSL_quic_max_handshake_flight_len(
       ssl(), TlsConnection::BoringEncryptionLevel(level));
diff --git a/quic/core/tls_handshaker.h b/quic/core/tls_handshaker.h
index 077e373..9288592 100644
--- a/quic/core/tls_handshaker.h
+++ b/quic/core/tls_handshaker.h
@@ -50,12 +50,38 @@
   ssl_early_data_reason_t EarlyDataReason() const;
   std::unique_ptr<QuicDecrypter> AdvanceKeysAndCreateCurrentOneRttDecrypter();
   std::unique_ptr<QuicEncrypter> CreateCurrentOneRttEncrypter();
+  virtual HandshakeState GetHandshakeState() const = 0;
 
  protected:
-  virtual void AdvanceHandshake() = 0;
+  // Called when a new message is received on the crypto stream and is available
+  // for the TLS stack to read.
+  void AdvanceHandshake();
 
-  virtual void CloseConnection(QuicErrorCode error,
-                               const std::string& reason_phrase) = 0;
+  void CloseConnection(QuicErrorCode error, const std::string& reason_phrase);
+
+  void OnConnectionClosed(QuicErrorCode error, ConnectionCloseSource source);
+
+  bool is_connection_closed() const { return is_connection_closed_; }
+
+  // Called when |SSL_do_handshake| returns 1, indicating that the handshake has
+  // finished. Note that due to 0-RTT, the handshake may "finish" twice;
+  // |SSL_in_early_data| can be used to determine whether the handshake is truly
+  // done.
+  virtual void FinishHandshake() = 0;
+
+  // Called when a handshake message is received after the handshake is
+  // complete.
+  virtual void ProcessPostHandshakeMessage() = 0;
+
+  // Called when an unexpected error code is received from |SSL_get_error|. If a
+  // subclass can expect more than just a single error (as provided by
+  // |set_expected_ssl_error|), it can override this method to handle that case.
+  virtual bool ShouldCloseConnectionOnUnexpectedError(int ssl_error);
+
+  void set_expected_ssl_error(int ssl_error) {
+    expected_ssl_error_ = ssl_error;
+  }
+  int expected_ssl_error() const { return expected_ssl_error_; }
 
   // Returns the PRF used by the cipher suite negotiated in the TLS handshake.
   const EVP_MD* Prf(const SSL_CIPHER* cipher);
@@ -101,6 +127,9 @@
   void SendAlert(EncryptionLevel level, uint8_t desc) override;
 
  private:
+  int expected_ssl_error_ = SSL_ERROR_WANT_READ;
+  bool is_connection_closed_ = false;
+
   QuicCryptoStream* stream_;
   HandshakerDelegateInterface* handshaker_delegate_;
 
diff --git a/quic/core/tls_server_handshaker.cc b/quic/core/tls_server_handshaker.cc
index 31c4ce5..c14deb0 100644
--- a/quic/core/tls_server_handshaker.cc
+++ b/quic/core/tls_server_handshaker.cc
@@ -36,10 +36,10 @@
     handshaker_->cert_verify_sig_ = std::move(signature);
     handshaker_->proof_source_details_ = std::move(details);
   }
-  State last_state = handshaker_->state_;
-  handshaker_->state_ = STATE_SIGNATURE_COMPLETE;
+  int last_expected_ssl_error = handshaker_->expected_ssl_error();
+  handshaker_->set_expected_ssl_error(SSL_ERROR_WANT_READ);
   handshaker_->signature_callback_ = nullptr;
-  if (last_state == STATE_SIGNATURE_PENDING) {
+  if (last_expected_ssl_error == SSL_ERROR_WANT_PRIVATE_KEY_OPERATION) {
     handshaker_->AdvanceHandshakeFromCallback();
   }
 }
@@ -69,7 +69,7 @@
   // pending), TlsServerHandshaker is not actively processing handshake
   // messages. We need to have it resume processing handshake messages by
   // calling AdvanceHandshake.
-  if (handshaker_->state_ == STATE_TICKET_DECRYPTION_PENDING) {
+  if (handshaker_->expected_ssl_error() == SSL_ERROR_PENDING_TICKET) {
     handshaker_->AdvanceHandshakeFromCallback();
   }
   // The TicketDecrypter took ownership of this callback when Decrypt was
@@ -152,9 +152,8 @@
     CachedNetworkParameters /*cached_network_params*/) {}
 
 void TlsServerHandshaker::OnPacketDecrypted(EncryptionLevel level) {
-  if (level == ENCRYPTION_HANDSHAKE &&
-      state_ < STATE_ENCRYPTION_HANDSHAKE_DATA_PROCESSED) {
-    state_ = STATE_ENCRYPTION_HANDSHAKE_DATA_PROCESSED;
+  if (level == ENCRYPTION_HANDSHAKE && state_ < HANDSHAKE_PROCESSED) {
+    state_ = HANDSHAKE_PROCESSED;
     handshaker_delegate()->DiscardOldEncryptionKey(ENCRYPTION_INITIAL);
     handshaker_delegate()->DiscardOldDecryptionKey(ENCRYPTION_INITIAL);
   }
@@ -172,9 +171,9 @@
   return proof_source_details_.get();
 }
 
-void TlsServerHandshaker::OnConnectionClosed(QuicErrorCode /*error*/,
-                                             ConnectionCloseSource /*source*/) {
-  state_ = STATE_CONNECTION_CLOSED;
+void TlsServerHandshaker::OnConnectionClosed(QuicErrorCode error,
+                                             ConnectionCloseSource source) {
+  TlsHandshaker::OnConnectionClosed(error, source);
 }
 
 ssl_early_data_reason_t TlsServerHandshaker::EarlyDataReason() const {
@@ -186,7 +185,7 @@
 }
 
 bool TlsServerHandshaker::one_rtt_keys_available() const {
-  return one_rtt_keys_available_;
+  return state_ == HANDSHAKE_CONFIRMED;
 }
 
 const QuicCryptoNegotiatedParameters&
@@ -199,13 +198,7 @@
 }
 
 HandshakeState TlsServerHandshaker::GetHandshakeState() const {
-  if (one_rtt_keys_available_) {
-    return HANDSHAKE_CONFIRMED;
-  }
-  if (state_ >= STATE_ENCRYPTION_HANDSHAKE_DATA_PROCESSED) {
-    return HANDSHAKE_PROCESSED;
-  }
-  return HANDSHAKE_START;
+  return state_;
 }
 
 void TlsServerHandshaker::SetServerApplicationStateForResumption(
@@ -234,65 +227,17 @@
 
 void TlsServerHandshaker::OverrideQuicConfigDefaults(QuicConfig* /*config*/) {}
 
-void TlsServerHandshaker::AdvanceHandshake() {
-  if (state_ == STATE_CONNECTION_CLOSED) {
-    QUIC_LOG(INFO) << "TlsServerHandshaker received handshake message after "
-                      "connection was closed";
-    return;
-  }
-  if (state_ == STATE_HANDSHAKE_COMPLETE) {
-    // TODO(nharper): Handle post-handshake messages.
-    return;
-  }
-
-  int rv = SSL_do_handshake(ssl());
-  if (rv == 1) {
-    FinishHandshake();
-    return;
-  }
-
-  int ssl_error = SSL_get_error(ssl(), rv);
-  bool should_close = true;
-  switch (state_) {
-    case STATE_LISTENING:
-    case STATE_SIGNATURE_COMPLETE:
-      should_close = ssl_error != SSL_ERROR_WANT_READ;
-      break;
-    case STATE_SIGNATURE_PENDING:
-      should_close = ssl_error != SSL_ERROR_WANT_PRIVATE_KEY_OPERATION;
-      break;
-    case STATE_TICKET_DECRYPTION_PENDING:
-      should_close = ssl_error != SSL_ERROR_PENDING_TICKET;
-      break;
-    default:
-      should_close = true;
-  }
-  if (should_close && state_ != STATE_CONNECTION_CLOSED) {
-    QUIC_VLOG(1) << "SSL_do_handshake failed; SSL_get_error returns "
-                 << ssl_error << ", state_ = " << state_;
-    ERR_print_errors_fp(stderr);
-    CloseConnection(QUIC_HANDSHAKE_FAILED,
-                    "Server observed TLS handshake failure");
-  }
-}
-
 void TlsServerHandshaker::AdvanceHandshakeFromCallback() {
   AdvanceHandshake();
   if (GetQuicReloadableFlag(
           quic_process_undecryptable_packets_after_async_decrypt_callback) &&
-      state_ != STATE_CONNECTION_CLOSED) {
+      !is_connection_closed()) {
     QUIC_RELOADABLE_FLAG_COUNT(
         quic_process_undecryptable_packets_after_async_decrypt_callback);
     handshaker_delegate()->OnHandshakeCallbackDone();
   }
 }
 
-void TlsServerHandshaker::CloseConnection(QuicErrorCode error,
-                                          const std::string& reason_phrase) {
-  state_ = STATE_CONNECTION_CLOSED;
-  stream()->OnUnrecoverableError(error, reason_phrase);
-}
-
 bool TlsServerHandshaker::ProcessTransportParameters(
     std::string* error_details) {
   TransportParameters client_params;
@@ -395,7 +340,7 @@
     EncryptionLevel level,
     const SSL_CIPHER* cipher,
     const std::vector<uint8_t>& write_secret) {
-  if (state_ == STATE_CONNECTION_CLOSED) {
+  if (is_connection_closed()) {
     return;
   }
   if (level == ENCRYPTION_FORWARD_SECURE) {
@@ -436,8 +381,7 @@
   QUIC_DLOG(INFO) << "Server: handshake finished. Early data reason "
                   << reason_code << " ("
                   << CryptoUtils::EarlyDataReasonToString(reason_code) << ")";
-  state_ = STATE_HANDSHAKE_COMPLETE;
-  one_rtt_keys_available_ = true;
+  state_ = HANDSHAKE_CONFIRMED;
 
   handshaker_delegate()->OnTlsHandshakeComplete();
   handshaker_delegate()->DiscardOldEncryptionKey(ENCRYPTION_HANDSHAKE);
@@ -456,18 +400,18 @@
       session()->connection()->self_address(),
       session()->connection()->peer_address(), hostname_, sig_alg, in,
       std::unique_ptr<SignatureCallback>(signature_callback_));
-  if (state_ == STATE_SIGNATURE_COMPLETE) {
-    return PrivateKeyComplete(out, out_len, max_out);
+  if (signature_callback_) {
+    set_expected_ssl_error(SSL_ERROR_WANT_PRIVATE_KEY_OPERATION);
+    return ssl_private_key_retry;
   }
-  state_ = STATE_SIGNATURE_PENDING;
-  return ssl_private_key_retry;
+  return PrivateKeyComplete(out, out_len, max_out);
 }
 
 ssl_private_key_result_t TlsServerHandshaker::PrivateKeyComplete(
     uint8_t* out,
     size_t* out_len,
     size_t max_out) {
-  if (state_ == STATE_SIGNATURE_PENDING) {
+  if (expected_ssl_error() == SSL_ERROR_WANT_PRIVATE_KEY_OPERATION) {
     return ssl_private_key_retry;
   }
   if (cert_verify_sig_.size() > max_out || cert_verify_sig_.empty()) {
@@ -524,12 +468,12 @@
     // and when the callback is complete this function will be run again to
     // return the result.
     if (ticket_decryption_callback_) {
-      state_ = STATE_TICKET_DECRYPTION_PENDING;
+      set_expected_ssl_error(SSL_ERROR_PENDING_TICKET);
       return ssl_ticket_aead_retry;
     }
   }
   ticket_decryption_callback_ = nullptr;
-  state_ = STATE_LISTENING;
+  set_expected_ssl_error(SSL_ERROR_WANT_READ);
   if (decrypted_session_ticket_.empty()) {
     QUIC_DLOG(ERROR) << "Session ticket decryption failed; ignoring ticket";
     // Ticket decryption failed. Ignore the ticket.
diff --git a/quic/core/tls_server_handshaker.h b/quic/core/tls_server_handshaker.h
index b6ea353..fab4a87 100644
--- a/quic/core/tls_server_handshaker.h
+++ b/quic/core/tls_server_handshaker.h
@@ -86,15 +86,13 @@
   virtual void ProcessAdditionalTransportParameters(
       const TransportParameters& /*params*/) {}
 
-  // Called when a new message is received on the crypto stream and is available
-  // for the TLS stack to read.
-  void AdvanceHandshake() override;
-
   // Called when a potentially async operation is done and the done callback
   // needs to advance the handshake.
   void AdvanceHandshakeFromCallback();
-  void CloseConnection(QuicErrorCode error,
-                       const std::string& reason_phrase) override;
+
+  // TlsHandshaker implementation:
+  void FinishHandshake() override;
+  void ProcessPostHandshakeMessage() override {}
 
   // TlsServerConnection::Delegate implementation:
   int SelectCertificate(int* out_alert) override;
@@ -150,26 +148,9 @@
     TlsServerHandshaker* handshaker_;
   };
 
-  enum State {
-    STATE_LISTENING,
-    STATE_TICKET_DECRYPTION_PENDING,
-    STATE_SIGNATURE_PENDING,
-    STATE_SIGNATURE_COMPLETE,
-    STATE_ENCRYPTION_HANDSHAKE_DATA_PROCESSED,
-    STATE_HANDSHAKE_COMPLETE,
-    STATE_CONNECTION_CLOSED,
-  };
-
-  // Called when the TLS handshake is complete.
-  void FinishHandshake();
-
-  void CloseConnection(const std::string& reason_phrase);
-
   bool SetTransportParameters();
   bool ProcessTransportParameters(std::string* error_details);
 
-  State state_ = STATE_LISTENING;
-
   ProofSource* proof_source_;
   SignatureCallback* signature_callback_ = nullptr;
 
@@ -196,6 +177,7 @@
   // Pre-shared key used during the handshake.
   std::string pre_shared_key_;
 
+  HandshakeState state_ = HANDSHAKE_START;
   bool encryption_established_ = false;
   bool one_rtt_keys_available_ = false;
   bool valid_alpn_received_ = false;