Group configs and factor out config lookup

This CL factors out common code for grabbing the primary and requested configs to be passed along through other handshake steps.  In addition to reducing code duplication and boilerplate, this will be used as a natural way to add a new fallback config (for when Leto is not working) without complexifying the code.

gfe-relnote: No-op refactoring, not flag-protected.
PiperOrigin-RevId: 239287118
Change-Id: Ia36f9b8dafe7a6e3ca9212b5d0779e296579a76d
diff --git a/quic/core/crypto/quic_crypto_server_config.cc b/quic/core/crypto/quic_crypto_server_config.cc
index 426e8da..4c77e21 100644
--- a/quic/core/crypto/quic_crypto_server_config.cc
+++ b/quic/core/crypto/quic_crypto_server_config.cc
@@ -518,39 +518,21 @@
 
   QuicStringPiece requested_scid;
   client_hello.GetStringPiece(kSCID, &requested_scid);
-
-  QuicReferenceCountedPointer<Config> requested_config;
-  QuicReferenceCountedPointer<Config> primary_config;
-  {
-    QuicReaderMutexLock locked(&configs_lock_);
-    if (!primary_config_.get()) {
-      result->error_code = QUIC_CRYPTO_INTERNAL_ERROR;
-      result->error_details = "No configurations loaded";
-    } else {
-      if (IsNextConfigReady(now)) {
-        configs_lock_.ReaderUnlock();
-        configs_lock_.WriterLock();
-        SelectNewPrimaryConfig(now);
-        DCHECK(primary_config_.get());
-        DCHECK_EQ(configs_.find(primary_config_->id)->second.get(),
-                  primary_config_.get());
-        configs_lock_.WriterUnlock();
-        configs_lock_.ReaderLock();
-      }
-    }
-
-    requested_config = GetConfigWithScid(requested_scid);
-    primary_config = primary_config_;
-    signed_config->config = primary_config_;
+  Configs configs;
+  if (!GetCurrentConfigs(now, requested_scid,
+                         /* old_primary_config = */ nullptr, &configs)) {
+    result->error_code = QUIC_CRYPTO_INTERNAL_ERROR;
+    result->error_details = "No configurations loaded";
   }
+  signed_config->config = configs.primary;
 
   if (result->error_code == QUIC_NO_ERROR) {
     // QUIC requires a new proof for each CHLO so clear any existing proof.
     signed_config->chain = nullptr;
     signed_config->proof.signature = "";
     signed_config->proof.leaf_cert_scts = "";
-    EvaluateClientHello(server_address, version, requested_config,
-                        primary_config, result, std::move(done_cb));
+    EvaluateClientHello(server_address, version, configs, result,
+                        std::move(done_cb));
   } else {
     done_cb->Run(result, /* details = */ nullptr);
   }
@@ -559,17 +541,10 @@
 class QuicCryptoServerConfig::ProcessClientHelloCallback
     : public ProofSource::Callback {
  public:
-  ProcessClientHelloCallback(
-      const QuicCryptoServerConfig* config,
-      std::unique_ptr<ProcessClientHelloContext> context,
-      const QuicReferenceCountedPointer<QuicCryptoServerConfig::Config>&
-          requested_config,
-      const QuicReferenceCountedPointer<QuicCryptoServerConfig::Config>&
-          primary_config)
-      : config_(config),
-        context_(std::move(context)),
-        requested_config_(requested_config),
-        primary_config_(primary_config) {}
+  ProcessClientHelloCallback(const QuicCryptoServerConfig* config,
+                             std::unique_ptr<ProcessClientHelloContext> context,
+                             const Configs& configs)
+      : config_(config), context_(std::move(context)), configs_(configs) {}
 
   void Run(bool ok,
            const QuicReferenceCountedPointer<ProofSource::Chain>& chain,
@@ -579,18 +554,14 @@
       context_->signed_config()->chain = chain;
       context_->signed_config()->proof = proof;
     }
-    config_->ProcessClientHelloAfterGetProof(
-        !ok, std::move(details), std::move(context_), requested_config_,
-        primary_config_);
+    config_->ProcessClientHelloAfterGetProof(!ok, std::move(details),
+                                             std::move(context_), configs_);
   }
 
  private:
   const QuicCryptoServerConfig* config_;
   std::unique_ptr<ProcessClientHelloContext> context_;
-  const QuicReferenceCountedPointer<QuicCryptoServerConfig::Config>
-      requested_config_;
-  const QuicReferenceCountedPointer<QuicCryptoServerConfig::Config>
-      primary_config_;
+  const Configs configs_;
 };
 
 class QuicCryptoServerConfig::ProcessClientHelloAfterGetProofCallback
@@ -603,19 +574,19 @@
       std::unique_ptr<CryptoHandshakeMessage> out,
       QuicStringPiece public_value,
       std::unique_ptr<ProcessClientHelloContext> context,
-      const QuicReferenceCountedPointer<Config>& requested_config)
+      const Configs& configs)
       : config_(config),
         proof_source_details_(std::move(proof_source_details)),
         key_exchange_type_(key_exchange_type),
         out_(std::move(out)),
         public_value_(public_value),
         context_(std::move(context)),
-        requested_config_(requested_config) {}
+        configs_(configs) {}
 
   void Run(bool ok) override {
     config_->ProcessClientHelloAfterCalculateSharedKeys(
         !ok, std::move(proof_source_details_), key_exchange_type_,
-        std::move(out_), public_value_, std::move(context_), requested_config_);
+        std::move(out_), public_value_, std::move(context_), configs_);
   }
 
  private:
@@ -625,7 +596,7 @@
   std::unique_ptr<CryptoHandshakeMessage> out_;
   std::string public_value_;
   std::unique_ptr<ProcessClientHelloContext> context_;
-  const QuicReferenceCountedPointer<Config> requested_config_;
+  const Configs configs_;
 };
 
 void QuicCryptoServerConfig::ProcessClientHello(
@@ -655,6 +626,7 @@
       params, signed_config, total_framing_overhead, chlo_packet_size,
       std::move(done_cb));
 
+  // Verify that various parts of the CHLO are valid
   std::string error_details;
   QuicErrorCode valid = CryptoUtils::ValidateClientHello(
       context->client_hello(), context->version(),
@@ -666,35 +638,9 @@
 
   QuicStringPiece requested_scid;
   context->client_hello().GetStringPiece(kSCID, &requested_scid);
-  const QuicWallTime now(clock->WallNow());
-
-  QuicReferenceCountedPointer<Config> requested_config;
-  QuicReferenceCountedPointer<Config> primary_config;
-  bool no_primary_config = false;
-  {
-    QuicReaderMutexLock locked(&configs_lock_);
-
-    if (!primary_config_) {
-      no_primary_config = true;
-    } else {
-      if (IsNextConfigReady(now)) {
-        configs_lock_.ReaderUnlock();
-        configs_lock_.WriterLock();
-        SelectNewPrimaryConfig(now);
-        DCHECK(primary_config_.get());
-        DCHECK_EQ(configs_.find(primary_config_->id)->second.get(),
-                  primary_config_.get());
-        configs_lock_.WriterUnlock();
-        configs_lock_.ReaderLock();
-      }
-
-      // Use the config that the client requested in order to do key-agreement.
-      // Otherwise give it a copy of |primary_config_| to use.
-      primary_config = signed_config->config;
-      requested_config = GetConfigWithScid(requested_scid);
-    }
-  }
-  if (no_primary_config) {
+  Configs configs;
+  if (!GetCurrentConfigs(context->clock()->WallNow(), requested_scid,
+                         signed_config->config, &configs)) {
     context->Fail(QUIC_CRYPTO_INTERNAL_ERROR, "No configurations loaded");
     return;
   }
@@ -719,25 +665,24 @@
     const QuicTransportVersion transport_version = context->transport_version();
 
     auto cb = QuicMakeUnique<ProcessClientHelloCallback>(
-        this, std::move(context), requested_config, primary_config);
+        this, std::move(context), configs);
 
     DCHECK(proof_source_.get());
-    proof_source_->GetProof(server_address, sni, primary_config->serialized,
+    proof_source_->GetProof(server_address, sni, configs.primary->serialized,
                             transport_version, chlo_hash, std::move(cb));
     return;
   }
 
   ProcessClientHelloAfterGetProof(
       /* found_error = */ false, /* proof_source_details = */ nullptr,
-      std::move(context), requested_config, primary_config);
+      std::move(context), configs);
 }
 
 void QuicCryptoServerConfig::ProcessClientHelloAfterGetProof(
     bool found_error,
     std::unique_ptr<ProofSource::Details> proof_source_details,
     std::unique_ptr<ProcessClientHelloContext> context,
-    const QuicReferenceCountedPointer<Config>& requested_config,
-    const QuicReferenceCountedPointer<Config>& primary_config) const {
+    const Configs& configs) const {
   QUIC_BUG_IF(!QuicUtils::IsConnectionIdValidForVersion(
       context->connection_id(), context->transport_version()))
       << "ProcessClientHelloAfterGetProof: attempted to use connection ID "
@@ -758,8 +703,8 @@
   }
 
   auto out = QuicMakeUnique<CryptoHandshakeMessage>();
-  if (!context->info().reject_reasons.empty() || !requested_config) {
-    BuildRejection(*context, *primary_config, context->info().reject_reasons,
+  if (!context->info().reject_reasons.empty() || !configs.requested) {
+    BuildRejection(*context, *configs.primary, context->info().reject_reasons,
                    out.get());
     if (rejection_observer_ != nullptr) {
       rejection_observer_->OnRejectionBuilt(context->info().reject_reasons,
@@ -789,22 +734,22 @@
   }
 
   size_t key_exchange_index;
-  if (!FindMutualQuicTag(requested_config->aead, their_aeads,
+  if (!FindMutualQuicTag(configs.requested->aead, their_aeads,
                          &context->params()->aead, nullptr) ||
-      !FindMutualQuicTag(requested_config->kexs, their_key_exchanges,
+      !FindMutualQuicTag(configs.requested->kexs, their_key_exchanges,
                          &context->params()->key_exchange,
                          &key_exchange_index)) {
     context->Fail(QUIC_CRYPTO_NO_SUPPORT, "Unsupported AEAD or KEXS");
     return;
   }
 
-  if (!requested_config->tb_key_params.empty()) {
+  if (!configs.requested->tb_key_params.empty()) {
     QuicTagVector their_tbkps;
     switch (context->client_hello().GetTaglist(kTBKP, &their_tbkps)) {
       case QUIC_CRYPTO_MESSAGE_PARAMETER_NOT_FOUND:
         break;
       case QUIC_NO_ERROR:
-        if (FindMutualQuicTag(requested_config->tb_key_params, their_tbkps,
+        if (FindMutualQuicTag(configs.requested->tb_key_params, their_tbkps,
                               &context->params()->token_binding_key_param,
                               nullptr)) {
           break;
@@ -825,12 +770,12 @@
   }
 
   const AsynchronousKeyExchange* key_exchange =
-      requested_config->key_exchanges[key_exchange_index].get();
+      configs.requested->key_exchanges[key_exchange_index].get();
   std::string* initial_premaster_secret =
       &context->params()->initial_premaster_secret;
   auto cb = QuicMakeUnique<ProcessClientHelloAfterGetProofCallback>(
       this, std::move(proof_source_details), key_exchange->type(),
-      std::move(out), public_value, std::move(context), requested_config);
+      std::move(out), public_value, std::move(context), configs);
   key_exchange->CalculateSharedKeyAsync(public_value, initial_premaster_secret,
                                         std::move(cb));
 }
@@ -842,7 +787,7 @@
     std::unique_ptr<CryptoHandshakeMessage> out,
     QuicStringPiece public_value,
     std::unique_ptr<ProcessClientHelloContext> context,
-    const QuicReferenceCountedPointer<Config>& requested_config) const {
+    const Configs& configs) const {
   QUIC_BUG_IF(!QuicUtils::IsConnectionIdValidForVersion(
       context->connection_id(), context->transport_version()))
       << "ProcessClientHelloAfterCalculateSharedKeys:"
@@ -866,12 +811,12 @@
       context->client_hello().GetSerialized();
   hkdf_suffix.reserve(context->connection_id().length() +
                       client_hello_serialized.length() +
-                      requested_config->serialized.size());
+                      configs.requested->serialized.size());
   hkdf_suffix.append(context->connection_id().data(),
                      context->connection_id().length());
   hkdf_suffix.append(client_hello_serialized.data(),
                      client_hello_serialized.length());
-  hkdf_suffix.append(requested_config->serialized);
+  hkdf_suffix.append(configs.requested->serialized);
   DCHECK(proof_source_.get());
   if (context->signed_config()->chain->certs.empty()) {
     context->Fail(QUIC_CRYPTO_INTERNAL_ERROR, "Failed to get certs");
@@ -880,7 +825,7 @@
   hkdf_suffix.append(context->signed_config()->chain->certs.at(0));
 
   QuicStringPiece cetv_ciphertext;
-  if (requested_config->channel_id_enabled &&
+  if (configs.requested->channel_id_enabled &&
       context->client_hello().GetStringPiece(kCETV, &cetv_ciphertext)) {
     CryptoHandshakeMessage client_hello_copy(context->client_hello());
     client_hello_copy.Erase(kCETV);
@@ -895,7 +840,7 @@
                       context->connection_id().length());
     hkdf_input.append(client_hello_copy_serialized.data(),
                       client_hello_copy_serialized.length());
-    hkdf_input.append(requested_config->serialized);
+    hkdf_input.append(configs.requested->serialized);
 
     CrypterPair crypters;
     if (!CryptoUtils::DeriveKeys(
@@ -1008,7 +953,7 @@
   out->SetVersionVector(kVER, context->supported_versions());
   out->SetStringPiece(
       kSourceAddressTokenTag,
-      NewSourceAddressToken(*requested_config,
+      NewSourceAddressToken(*configs.requested,
                             context->info().source_address_tokens,
                             context->client_address().host(), context->rand(),
                             context->info().now, nullptr));
@@ -1037,6 +982,38 @@
   return QuicReferenceCountedPointer<Config>();
 }
 
+bool QuicCryptoServerConfig::GetCurrentConfigs(
+    const QuicWallTime& now,
+    QuicStringPiece requested_scid,
+    QuicReferenceCountedPointer<Config> old_primary_config,
+    Configs* configs) const {
+  QuicReaderMutexLock locked(&configs_lock_);
+
+  if (!primary_config_) {
+    return false;
+  }
+
+  if (IsNextConfigReady(now)) {
+    configs_lock_.ReaderUnlock();
+    configs_lock_.WriterLock();
+    SelectNewPrimaryConfig(now);
+    DCHECK(primary_config_.get());
+    DCHECK_EQ(configs_.find(primary_config_->id)->second.get(),
+              primary_config_.get());
+    configs_lock_.WriterUnlock();
+    configs_lock_.ReaderLock();
+  }
+
+  if (old_primary_config != nullptr) {
+    configs->primary = old_primary_config;
+  } else {
+    configs->primary = primary_config_;
+  }
+  configs->requested = GetConfigWithScid(requested_scid);
+
+  return true;
+}
+
 // ConfigPrimaryTimeLessThan is a comparator that implements "less than" for
 // Config's based on their primary_time.
 // static
@@ -1141,12 +1118,10 @@
 void QuicCryptoServerConfig::EvaluateClientHello(
     const QuicSocketAddress& server_address,
     QuicTransportVersion version,
-    QuicReferenceCountedPointer<Config> requested_config,
-    QuicReferenceCountedPointer<Config> primary_config,
+    const Configs& configs,
     QuicReferenceCountedPointer<ValidateClientHelloResultCallback::Result>
         client_hello_state,
     std::unique_ptr<ValidateClientHelloResultCallback> done_cb) const {
-
   ValidateClientHelloHelper helper(client_hello_state, &done_cb);
 
   const CryptoHandshakeMessage& client_hello = client_hello_state->client_hello;
@@ -1172,7 +1147,7 @@
     QuicStringPiece srct;
     if (client_hello.GetStringPiece(kSourceAddressTokenTag, &srct)) {
       Config& config =
-          requested_config != nullptr ? *requested_config : *primary_config;
+          configs.requested != nullptr ? *configs.requested : *configs.primary;
       source_address_token_error =
           ParseSourceAddressToken(config, srct, &info->source_address_tokens);
 
@@ -1191,7 +1166,7 @@
     info->valid_source_address_token = true;
   }
 
-  if (!requested_config.get()) {
+  if (!configs.requested) {
     QuicStringPiece requested_scid;
     if (client_hello.GetStringPiece(kSCID, &requested_scid)) {
       info->reject_reasons.push_back(SERVER_CONFIG_UNKNOWN_CONFIG_FAILURE);
diff --git a/quic/core/crypto/quic_crypto_server_config.h b/quic/core/crypto/quic_crypto_server_config.h
index 8b020d0..0304bf3 100644
--- a/quic/core/crypto/quic_crypto_server_config.h
+++ b/quic/core/crypto/quic_crypto_server_config.h
@@ -513,6 +513,23 @@
       QuicStringPiece requested_scid) const
       SHARED_LOCKS_REQUIRED(configs_lock_);
 
+  // A snapshot of the configs associated with an in-progress handshake.
+  struct Configs {
+    QuicReferenceCountedPointer<Config> requested;
+    QuicReferenceCountedPointer<Config> primary;
+  };
+
+  // Get a snapshot of the current configs associated with a handshake.  If this
+  // method was called earlier in this handshake |old_primary_config| should be
+  // set to the primary config returned from that invocation, otherwise nullptr.
+  //
+  // Returns true if any configs are loaded.  If false is returned, |configs| is
+  // not modified.
+  bool GetCurrentConfigs(const QuicWallTime& now,
+                         QuicStringPiece requested_scid,
+                         QuicReferenceCountedPointer<Config> old_primary_config,
+                         Configs* configs) const;
+
   // ConfigPrimaryTimeLessThan returns true if a->primary_time <
   // b->primary_time.
   static bool ConfigPrimaryTimeLessThan(
@@ -530,8 +547,7 @@
   void EvaluateClientHello(
       const QuicSocketAddress& server_address,
       QuicTransportVersion version,
-      QuicReferenceCountedPointer<Config> requested_config,
-      QuicReferenceCountedPointer<Config> primary_config,
+      const Configs& configs,
       QuicReferenceCountedPointer<ValidateClientHelloResultCallback::Result>
           client_hello_state,
       std::unique_ptr<ValidateClientHelloResultCallback> done_cb) const;
@@ -660,8 +676,7 @@
       bool found_error,
       std::unique_ptr<ProofSource::Details> proof_source_details,
       std::unique_ptr<ProcessClientHelloContext> context,
-      const QuicReferenceCountedPointer<Config>& requested_config,
-      const QuicReferenceCountedPointer<Config>& primary_config) const;
+      const Configs& configs) const;
 
   // Callback class for bridging between ProcessClientHelloAfterGetProof and
   // ProcessClientHelloAfterCalculateSharedKeys.
@@ -676,7 +691,7 @@
       std::unique_ptr<CryptoHandshakeMessage> out,
       QuicStringPiece public_value,
       std::unique_ptr<ProcessClientHelloContext> context,
-      const QuicReferenceCountedPointer<Config>& requested_config) const;
+      const Configs& configs) const;
 
   // BuildRejection sets |out| to be a REJ message in reply to |client_hello|.
   void BuildRejection(const ProcessClientHelloContext& context,
@@ -833,16 +848,20 @@
   //   2) primary_config_ != nullptr -> primary_config_->is_primary
   //   3) ∀ c∈configs_, c->is_primary <-> c == primary_config_
   mutable QuicMutex configs_lock_;
+
   // configs_ contains all active server configs. It's expected that there are
   // about half-a-dozen configs active at any one time.
   ConfigMap configs_ GUARDED_BY(configs_lock_);
+
   // primary_config_ points to a Config (which is also in |configs_|) which is
   // the primary config - i.e. the one that we'll give out to new clients.
   mutable QuicReferenceCountedPointer<Config> primary_config_
       GUARDED_BY(configs_lock_);
+
   // next_config_promotion_time_ contains the nearest, future time when an
   // active config will be promoted to primary.
   mutable QuicWallTime next_config_promotion_time_ GUARDED_BY(configs_lock_);
+
   // Callback to invoke when the primary config changes.
   std::unique_ptr<PrimaryConfigChangedCallback> primary_config_changed_cb_
       GUARDED_BY(configs_lock_);