In TlsServerHandshaker, switch from BoringSSL's tlsext_servername_callback to select_cert_cb for cert selection.

This paves the way to go/split-handshake-v2.

Protected by FLAGS_quic_reloadable_flag_quic_tls_use_early_select_cert.

PiperOrigin-RevId: 343531889
Change-Id: Ib0c677c18f50526fc86ed45b532f460cdedcd119
diff --git a/quic/core/crypto/tls_server_connection.cc b/quic/core/crypto/tls_server_connection.cc
index f47e5ad..0208a9c 100644
--- a/quic/core/crypto/tls_server_connection.cc
+++ b/quic/core/crypto/tls_server_connection.cc
@@ -40,6 +40,8 @@
        GetQuicRestartFlag(quic_session_tickets_always_enabled))) {
     SSL_CTX_set_early_data_enabled(ssl_ctx.get(), 1);
   }
+  SSL_CTX_set_select_certificate_cb(
+      ssl_ctx.get(), &TlsServerConnection::EarlySelectCertCallback);
   return ssl_ctx;
 }
 
@@ -62,6 +64,13 @@
 }
 
 // static
+ssl_select_cert_result_t TlsServerConnection::EarlySelectCertCallback(
+    const SSL_CLIENT_HELLO* client_hello) {
+  return ConnectionFromSsl(client_hello->ssl)
+      ->delegate_->EarlySelectCertCallback(client_hello);
+}
+
+// static
 int TlsServerConnection::SelectCertificateCallback(SSL* ssl,
                                                    int* out_alert,
                                                    void* /*arg*/) {
diff --git a/quic/core/crypto/tls_server_connection.h b/quic/core/crypto/tls_server_connection.h
index 954f830..50c6f7c 100644
--- a/quic/core/crypto/tls_server_connection.h
+++ b/quic/core/crypto/tls_server_connection.h
@@ -22,12 +22,19 @@
     virtual ~Delegate() {}
 
    protected:
+    // Called from BoringSSL right after SNI is extracted, which is very early
+    // in the handshake process.
+    virtual ssl_select_cert_result_t EarlySelectCertCallback(
+        const SSL_CLIENT_HELLO* client_hello) = 0;
+
     // Configures the certificate to use on |ssl_| based on the SNI sent by the
     // client. Returns an SSL_TLSEXT_ERR_* value (see
     // https://commondatastorage.googleapis.com/chromium-boringssl-docs/ssl.h.html#SSL_CTX_set_tlsext_servername_callback).
     //
     // If SelectCertificate returns SSL_TLSEXT_ERR_ALERT_FATAL, then it puts in
     // |*out_alert| the TLS alert value that the server will send.
+    //
+    // TODO(wub): Deprecate it after enabling --quic_tls_use_early_select_cert.
     virtual int SelectCertificate(int* out_alert) = 0;
 
     // Selects which ALPN to use based on the list sent by the client.
@@ -121,6 +128,9 @@
   // Specialization of TlsConnection::ConnectionFromSsl.
   static TlsServerConnection* ConnectionFromSsl(SSL* ssl);
 
+  static ssl_select_cert_result_t EarlySelectCertCallback(
+      const SSL_CLIENT_HELLO* client_hello);
+
   // These functions are registered as callbacks in BoringSSL and delegate their
   // implementation to the matching methods in Delegate above.
   static int SelectCertificateCallback(SSL* ssl, int* out_alert, void* arg);
diff --git a/quic/core/quic_flags_list.h b/quic/core/quic_flags_list.h
index 10c0cb7..2fd5426 100644
--- a/quic/core/quic_flags_list.h
+++ b/quic/core/quic_flags_list.h
@@ -79,6 +79,7 @@
 QUIC_FLAG(FLAGS_quic_reloadable_flag_quic_stop_sending_uses_ietf_error_code, true)
 QUIC_FLAG(FLAGS_quic_reloadable_flag_quic_testonly_default_false, false)
 QUIC_FLAG(FLAGS_quic_reloadable_flag_quic_testonly_default_true, true)
+QUIC_FLAG(FLAGS_quic_reloadable_flag_quic_tls_use_early_select_cert, false)
 QUIC_FLAG(FLAGS_quic_reloadable_flag_quic_unified_iw_options, false)
 QUIC_FLAG(FLAGS_quic_reloadable_flag_quic_use_circular_deque_for_unacked_packets, false)
 QUIC_FLAG(FLAGS_quic_reloadable_flag_quic_use_encryption_level_context, false)
diff --git a/quic/core/tls_server_handshaker.cc b/quic/core/tls_server_handshaker.cc
index 01b5df6..38a4638 100644
--- a/quic/core/tls_server_handshaker.cc
+++ b/quic/core/tls_server_handshaker.cc
@@ -243,12 +243,26 @@
 }
 
 bool TlsServerHandshaker::ProcessTransportParameters(
+    const SSL_CLIENT_HELLO* client_hello,
     std::string* error_details) {
   TransportParameters client_params;
   const uint8_t* client_params_bytes;
   size_t params_bytes_len;
-  SSL_get_peer_quic_transport_params(ssl(), &client_params_bytes,
-                                     &params_bytes_len);
+  if (use_early_select_cert_) {
+    // When using early select cert callback, SSL_get_peer_quic_transport_params
+    // can not be used to retrieve the client's transport parameters, but we can
+    // use SSL_early_callback_ctx_extension_get to do that.
+    if (!SSL_early_callback_ctx_extension_get(
+            client_hello, TLSEXT_TYPE_quic_transport_parameters,
+            &client_params_bytes, &params_bytes_len)) {
+      params_bytes_len = 0;
+    }
+  } else {
+    DCHECK_EQ(client_hello, nullptr);
+    SSL_get_peer_quic_transport_params(ssl(), &client_params_bytes,
+                                       &params_bytes_len);
+  }
+
   if (params_bytes_len == 0) {
     *error_details = "Client's transport parameters are missing";
     return false;
@@ -506,6 +520,62 @@
   return ssl_ticket_aead_success;
 }
 
+ssl_select_cert_result_t TlsServerHandshaker::EarlySelectCertCallback(
+    const SSL_CLIENT_HELLO* client_hello) {
+  if (!use_early_select_cert_) {
+    return ssl_select_cert_success;
+  }
+
+  if (!pre_shared_key_.empty()) {
+    // TODO(b/154162689) add PSK support to QUIC+TLS.
+    QUIC_BUG << "QUIC server pre-shared keys not yet supported with TLS";
+    return ssl_select_cert_error;
+  }
+
+  // This callback is called very early by Boring SSL, most of the SSL_get_foo
+  // function do not work at this point, but SSL_get_servername does.
+  const char* hostname = SSL_get_servername(ssl(), TLSEXT_NAMETYPE_host_name);
+  if (hostname) {
+    hostname_ = hostname;
+    crypto_negotiated_params_->sni =
+        QuicHostnameUtils::NormalizeHostname(hostname_);
+    if (!ValidateHostname(hostname_)) {
+      return ssl_select_cert_error;
+    }
+  } else {
+    QUIC_LOG(INFO) << "No hostname indicated in SNI";
+  }
+
+  QuicReferenceCountedPointer<ProofSource::Chain> chain =
+      proof_source_->GetCertChain(session()->connection()->self_address(),
+                                  session()->connection()->peer_address(),
+                                  hostname_);
+  if (!chain || chain->certs.empty()) {
+    QUIC_LOG(ERROR) << "No certs provided for host '" << hostname_ << "'";
+    return ssl_select_cert_error;
+  }
+
+  CryptoBuffers cert_buffers = chain->ToCryptoBuffers();
+  tls_connection_.SetCertChain(cert_buffers.value);
+
+  std::string error_details;
+  if (!ProcessTransportParameters(client_hello, &error_details)) {
+    CloseConnection(QUIC_HANDSHAKE_FAILED, error_details);
+    return ssl_select_cert_error;
+  }
+  OverrideQuicConfigDefaults(session()->config());
+  session()->OnConfigNegotiated();
+
+  if (!SetTransportParameters()) {
+    QUIC_LOG(ERROR) << "Failed to set transport parameters";
+    return ssl_select_cert_error;
+  }
+
+  QUIC_DLOG(INFO) << "Set " << chain->certs.size() << " certs for server "
+                  << "with hostname " << hostname_;
+  return ssl_select_cert_success;
+}
+
 bool TlsServerHandshaker::ValidateHostname(const std::string& hostname) const {
   if (!QuicHostnameUtils::IsValidSNI(hostname)) {
     // TODO(b/151676147): Include this error string in the CONNECTION_CLOSE
@@ -517,6 +587,10 @@
 }
 
 int TlsServerHandshaker::SelectCertificate(int* out_alert) {
+  if (use_early_select_cert_) {
+    return SSL_TLSEXT_ERR_OK;
+  }
+
   const char* hostname = SSL_get_servername(ssl(), TLSEXT_NAMETYPE_host_name);
   if (hostname) {
     hostname_ = hostname;
@@ -548,7 +622,7 @@
   tls_connection_.SetCertChain(cert_buffers.value);
 
   std::string error_details;
-  if (!ProcessTransportParameters(&error_details)) {
+  if (!ProcessTransportParameters(nullptr, &error_details)) {
     CloseConnection(QUIC_HANDSHAKE_FAILED, error_details);
     *out_alert = SSL_AD_INTERNAL_ERROR;
     return SSL_TLSEXT_ERR_ALERT_FATAL;
diff --git a/quic/core/tls_server_handshaker.h b/quic/core/tls_server_handshaker.h
index c6426ce..09d03eb 100644
--- a/quic/core/tls_server_handshaker.h
+++ b/quic/core/tls_server_handshaker.h
@@ -103,6 +103,9 @@
       const ProofVerifyDetails& verify_details) override;
 
   // TlsServerConnection::Delegate implementation:
+  // Used to select certificates and process transport parameters.
+  ssl_select_cert_result_t EarlySelectCertCallback(
+      const SSL_CLIENT_HELLO* client_hello) override;
   int SelectCertificate(int* out_alert) override;
   int SelectAlpn(const uint8_t** out,
                  uint8_t* out_len,
@@ -158,7 +161,8 @@
 
   virtual bool ValidateHostname(const std::string& hostname) const;
   bool SetTransportParameters();
-  bool ProcessTransportParameters(std::string* error_details);
+  bool ProcessTransportParameters(const SSL_CLIENT_HELLO* client_hello,
+                                  std::string* error_details);
 
   ProofSource* proof_source_;
   SignatureCallback* signature_callback_ = nullptr;
@@ -192,6 +196,8 @@
   QuicReferenceCountedPointer<QuicCryptoNegotiatedParameters>
       crypto_negotiated_params_;
   TlsServerConnection tls_connection_;
+  const bool use_early_select_cert_ =
+      GetQuicReloadableFlag(quic_tls_use_early_select_cert);
 };
 
 }  // namespace quic