In QUIC TlsServerHandshaker, ensure DecryptCallback runs at most once.

Protected by FLAGS_quic_reloadable_flag_quic_tls_fix_ticket_decrypt.

PiperOrigin-RevId: 389907850
diff --git a/quic/core/quic_flags_list.h b/quic/core/quic_flags_list.h
index 3530ba4..00007a5 100644
--- a/quic/core/quic_flags_list.h
+++ b/quic/core/quic_flags_list.h
@@ -71,6 +71,8 @@
 QUIC_FLAG(FLAGS_quic_reloadable_flag_quic_drop_unsent_path_response, true)
 // If true, enable server retransmittable on wire PING.
 QUIC_FLAG(FLAGS_quic_reloadable_flag_quic_enable_server_on_wire_ping, true)
+// If true, fix a bug in TlsServerHandshaker where the ticket decrypt callback is cleared without being cancelled first.
+QUIC_FLAG(FLAGS_quic_reloadable_flag_quic_tls_fix_ticket_decrypt, true)
 // If true, ignore peer_max_ack_delay during handshake.
 QUIC_FLAG(FLAGS_quic_reloadable_flag_quic_ignore_peer_max_ack_delay_during_handshake, true)
 // If true, include stream information in idle timeout connection close detail.
diff --git a/quic/core/tls_server_handshaker.cc b/quic/core/tls_server_handshaker.cc
index 800d6f8..dcd7ea3 100644
--- a/quic/core/tls_server_handshaker.cc
+++ b/quic/core/tls_server_handshaker.cc
@@ -138,6 +138,29 @@
     // The callback was cancelled before we could run.
     return;
   }
+  if (handshaker_->fix_ticket_decrypt_) {
+    TlsServerHandshaker* handshaker = handshaker_;
+    handshaker_ = nullptr;
+
+    handshaker->decrypted_session_ticket_ = std::move(plaintext);
+    // DecryptCallback::Run could be called synchronously. When that happens, we
+    // are currently in the middle of a call to AdvanceHandshake.
+    // (AdvanceHandshake called SSL_do_handshake, which through some layers
+    // called SessionTicketOpen, which called TicketCrypter::Decrypt, which
+    // synchronously called this function.) In that case, the handshake will
+    // continue to be processed when this function returns.
+    //
+    // When this callback is called asynchronously (i.e. the ticket decryption
+    // is pending), TlsServerHandshaker is not actively processing handshake
+    // messages. We need to have it resume processing handshake messages by
+    // calling AdvanceHandshake.
+    if (handshaker->expected_ssl_error() == SSL_ERROR_PENDING_TICKET) {
+      handshaker->AdvanceHandshakeFromCallback();
+    }
+
+    handshaker->ticket_decryption_callback_ = nullptr;
+    return;
+  }
   handshaker_->decrypted_session_ticket_ = std::move(plaintext);
   // DecryptCallback::Run could be called synchronously. When that happens, we
   // are currently in the middle of a call to AdvanceHandshake.
@@ -725,15 +748,17 @@
     ticket_decryption_callback_ = new DecryptCallback(this);
     proof_source_->GetTicketCrypter()->Decrypt(
         in, std::unique_ptr<DecryptCallback>(ticket_decryption_callback_));
+
     // Decrypt can run the callback synchronously. In that case, the callback
     // will clear the ticket_decryption_callback_ pointer, and instead of
-    // returning ssl_ticket_aead_retry, we should continue processing to return
-    // the decrypted ticket.
+    // returning ssl_ticket_aead_retry, we should continue processing to
+    // return the decrypted ticket.
     //
     // If the callback is not run synchronously, return ssl_ticket_aead_retry
     // and when the callback is complete this function will be run again to
     // return the result.
     if (ticket_decryption_callback_) {
+      QUICHE_DCHECK(!ticket_decryption_callback_->IsDone());
       set_expected_ssl_error(SSL_ERROR_PENDING_TICKET);
       if (async_op_timer_.has_value()) {
         QUIC_CODE_COUNT(
@@ -741,6 +766,18 @@
       }
       async_op_timer_ = QuicTimeAccumulator();
       async_op_timer_->Start(now());
+
+      if (!fix_ticket_decrypt_) {
+        return ssl_ticket_aead_retry;
+      }
+    }
+  }
+
+  if (fix_ticket_decrypt_) {
+    // If the async ticket decryption is pending, either started by this
+    // SessionTicketOpen call or one that happened earlier, return
+    // ssl_ticket_aead_retry.
+    if (ticket_decryption_callback_ && !ticket_decryption_callback_->IsDone()) {
       return ssl_ticket_aead_retry;
     }
   }
diff --git a/quic/core/tls_server_handshaker.h b/quic/core/tls_server_handshaker.h
index d1a476f..7ed1a8c 100644
--- a/quic/core/tls_server_handshaker.h
+++ b/quic/core/tls_server_handshaker.h
@@ -198,6 +198,11 @@
     // If called, Cancel causes the pending callback to be a no-op.
     void Cancel();
 
+    // Return true if either
+    // - Cancel() has been called.
+    // - Run() has been called, or is in the middle of it.
+    bool IsDone() const { return handshaker_ == nullptr; }
+
    private:
     TlsServerHandshaker* handshaker_;
   };
@@ -348,6 +353,8 @@
   HandshakeState state_ = HANDSHAKE_START;
   bool encryption_established_ = false;
   bool valid_alpn_received_ = false;
+  const bool fix_ticket_decrypt_ =
+      GetQuicReloadableFlag(quic_tls_fix_ticket_decrypt);
   QuicReferenceCountedPointer<QuicCryptoNegotiatedParameters>
       crypto_negotiated_params_;
   TlsServerConnection tls_connection_;
diff --git a/quic/core/tls_server_handshaker_test.cc b/quic/core/tls_server_handshaker_test.cc
index 6f8db12..2478d8f 100644
--- a/quic/core/tls_server_handshaker_test.cc
+++ b/quic/core/tls_server_handshaker_test.cc
@@ -101,6 +101,7 @@
     return fake_proof_source_handle_;
   }
 
+  using TlsServerHandshaker::AdvanceHandshake;
   using TlsServerHandshaker::expected_ssl_error;
 
  private:
@@ -708,6 +709,43 @@
   EXPECT_TRUE(server_stream()->ResumptionAttempted());
 }
 
+TEST_P(TlsServerHandshakerTest, AdvanceHandshakeDuringAsyncDecryptCallback) {
+  if (GetParam().disable_resumption) {
+    return;
+  }
+
+  // Do the first handshake
+  InitializeFakeClient();
+  CompleteCryptoHandshake();
+  ExpectHandshakeSuccessful();
+
+  ticket_crypter_->SetRunCallbacksAsync(true);
+  // Now do another handshake
+  InitializeServerWithFakeProofSourceHandle();
+  server_handshaker_->SetupProofSourceHandle(
+      /*select_cert_action=*/FakeProofSourceHandle::Action::DELEGATE_SYNC,
+      /*compute_signature_action=*/FakeProofSourceHandle::Action::
+          DELEGATE_SYNC);
+  InitializeFakeClient();
+
+  AdvanceHandshakeWithFakeClient();
+
+  // Ensure an async DecryptCallback is now pending.
+  ASSERT_EQ(ticket_crypter_->NumPendingCallbacks(), 1u);
+
+  {
+    QuicConnection::ScopedPacketFlusher flusher(server_connection_);
+    server_handshaker_->AdvanceHandshake();
+  }
+
+  // This will delete |server_handshaker_|.
+  server_session_ = nullptr;
+
+  if (GetQuicReloadableFlag(quic_tls_fix_ticket_decrypt)) {
+    ticket_crypter_->RunPendingCallback(0);  // Should not crash.
+  }
+}
+
 TEST_P(TlsServerHandshakerTest, ResumptionWithFailingDecryptCallback) {
   if (GetParam().disable_resumption) {
     return;