Restore connection context in various callbacks in TlsServerHandshaker.

Protected by FLAGS_quic_reloadable_flag_quic_tls_restore_connection_context_in_callbacks.

PiperOrigin-RevId: 395476747
diff --git a/quic/core/quic_flags_list.h b/quic/core/quic_flags_list.h
index 423401a..631c38f 100644
--- a/quic/core/quic_flags_list.h
+++ b/quic/core/quic_flags_list.h
@@ -99,6 +99,8 @@
 QUIC_FLAG(FLAGS_quic_reloadable_flag_quic_reset_per_packet_state_for_undecryptable_packets, true)
 // If true, respect FLAGS_quic_time_wait_list_max_pending_packets as the upper bound of queued packets in time wait list.
 QUIC_FLAG(FLAGS_quic_reloadable_flag_quic_add_upperbound_for_queued_packets, true)
+// If true, restore connection context in various callbacks in TlsServerHandshaker.
+QUIC_FLAG(FLAGS_quic_reloadable_flag_quic_tls_restore_connection_context_in_callbacks, true)
 // If true, send PATH_RESPONSE upon receiving PATH_CHALLENGE regardless of perspective. --gfe2_reloadable_flag_quic_start_peer_migration_earlier has to be true before turn on this flag.
 QUIC_FLAG(FLAGS_quic_reloadable_flag_quic_send_path_response2, true)
 // If true, set burst token to 2 in cwnd bootstrapping experiment.
diff --git a/quic/core/tls_server_handshaker.cc b/quic/core/tls_server_handshaker.cc
index 8112b41..514c462 100644
--- a/quic/core/tls_server_handshaker.cc
+++ b/quic/core/tls_server_handshaker.cc
@@ -143,6 +143,21 @@
   handshaker_ = nullptr;
 
   handshaker->decrypted_session_ticket_ = std::move(plaintext);
+  const bool is_async =
+      (handshaker->expected_ssl_error() == SSL_ERROR_PENDING_TICKET);
+
+  absl::optional<QuicConnectionContextSwitcher> context_switcher;
+  if (handshaker->restore_connection_context_in_callbacks_) {
+    QUIC_RELOADABLE_FLAG_COUNT_N(
+        quic_tls_restore_connection_context_in_callbacks, 1, 3);
+    if (is_async) {
+      context_switcher.emplace(handshaker->connection_context());
+    }
+    QUIC_TRACESTRING(
+        absl::StrCat("TLS ticket decryption done. len(decrypted_ticket):",
+                     handshaker->decrypted_session_ticket_.size()));
+  }
+
   // DecryptCallback::Run could be called synchronously. When that happens, we
   // are currently in the middle of a call to AdvanceHandshake.
   // (AdvanceHandshake called SSL_do_handshake, which through some layers
@@ -154,7 +169,7 @@
   // is pending), TlsServerHandshaker is not actively processing handshake
   // messages. We need to have it resume processing handshake messages by
   // calling AdvanceHandshake.
-  if (handshaker->expected_ssl_error() == SSL_ERROR_PENDING_TICKET) {
+  if (is_async) {
     handshaker->AdvanceHandshakeFromCallback();
   }
 
@@ -668,6 +683,19 @@
   QUIC_DVLOG(1) << "OnComputeSignatureDone. ok:" << ok
                 << ", is_sync:" << is_sync
                 << ", len(signature):" << signature.size();
+  absl::optional<QuicConnectionContextSwitcher> context_switcher;
+  if (restore_connection_context_in_callbacks_) {
+    QUIC_RELOADABLE_FLAG_COUNT_N(
+        quic_tls_restore_connection_context_in_callbacks, 2, 3);
+
+    if (!is_sync) {
+      context_switcher.emplace(connection_context());
+    }
+
+    QUIC_TRACESTRING(absl::StrCat("TLS compute signature done. ok:", ok,
+                                  ", len(signature):", signature.size()));
+  }
+
   if (ok) {
     cert_verify_sig_ = std::move(signature);
     proof_source_details_ = std::move(details);
@@ -935,6 +963,20 @@
                 << ", len(handshake_hints):" << handshake_hints.size()
                 << ", len(ticket_encryption_key):"
                 << ticket_encryption_key.size();
+  absl::optional<QuicConnectionContextSwitcher> context_switcher;
+  if (restore_connection_context_in_callbacks_) {
+    QUIC_RELOADABLE_FLAG_COUNT_N(
+        quic_tls_restore_connection_context_in_callbacks, 3, 3);
+
+    if (!is_sync) {
+      context_switcher.emplace(connection_context());
+    }
+
+    QUIC_TRACESTRING(absl::StrCat(
+        "TLS select certificate done: ok:", ok,
+        ", len(handshake_hints):", handshake_hints.size(),
+        ", len(ticket_encryption_key):", ticket_encryption_key.size()));
+  }
   ticket_encryption_key_ = std::string(ticket_encryption_key);
   select_cert_status_ = QUIC_FAILURE;
   if (ok) {
diff --git a/quic/core/tls_server_handshaker.h b/quic/core/tls_server_handshaker.h
index 0f7289e..948a29a 100644
--- a/quic/core/tls_server_handshaker.h
+++ b/quic/core/tls_server_handshaker.h
@@ -313,6 +313,11 @@
   }
   QuicTime now() const { return session()->GetClock()->Now(); }
 
+  QuicConnectionContext* connection_context() {
+    QUICHE_DCHECK(restore_connection_context_in_callbacks_);
+    return session()->connection()->context();
+  }
+
   std::unique_ptr<ProofSourceHandle> proof_source_handle_;
   ProofSource* proof_source_;
 
@@ -357,6 +362,8 @@
       crypto_negotiated_params_;
   TlsServerConnection tls_connection_;
   const QuicCryptoServerConfig* crypto_config_;  // Unowned.
+  const bool restore_connection_context_in_callbacks_ =
+      GetQuicReloadableFlag(quic_tls_restore_connection_context_in_callbacks);
 };
 
 }  // namespace quic