Group configs and factor out config lookup
This CL factors out the common code that looks up the primary and requested configs so that they can be passed along through the other handshake steps. In addition to reducing code duplication and boilerplate, this provides a natural place to add a new fallback config (for when Leto is not working) without complicating the code.
gfe-relnote: No-op refactoring, not flag-protected.
PiperOrigin-RevId: 239287118
Change-Id: Ia36f9b8dafe7a6e3ca9212b5d0779e296579a76d
diff --git a/quic/core/crypto/quic_crypto_server_config.cc b/quic/core/crypto/quic_crypto_server_config.cc
index 426e8da..4c77e21 100644
--- a/quic/core/crypto/quic_crypto_server_config.cc
+++ b/quic/core/crypto/quic_crypto_server_config.cc
@@ -518,39 +518,21 @@
QuicStringPiece requested_scid;
client_hello.GetStringPiece(kSCID, &requested_scid);
-
- QuicReferenceCountedPointer<Config> requested_config;
- QuicReferenceCountedPointer<Config> primary_config;
- {
- QuicReaderMutexLock locked(&configs_lock_);
- if (!primary_config_.get()) {
- result->error_code = QUIC_CRYPTO_INTERNAL_ERROR;
- result->error_details = "No configurations loaded";
- } else {
- if (IsNextConfigReady(now)) {
- configs_lock_.ReaderUnlock();
- configs_lock_.WriterLock();
- SelectNewPrimaryConfig(now);
- DCHECK(primary_config_.get());
- DCHECK_EQ(configs_.find(primary_config_->id)->second.get(),
- primary_config_.get());
- configs_lock_.WriterUnlock();
- configs_lock_.ReaderLock();
- }
- }
-
- requested_config = GetConfigWithScid(requested_scid);
- primary_config = primary_config_;
- signed_config->config = primary_config_;
+ Configs configs;
+ if (!GetCurrentConfigs(now, requested_scid,
+ /* old_primary_config = */ nullptr, &configs)) {
+ result->error_code = QUIC_CRYPTO_INTERNAL_ERROR;
+ result->error_details = "No configurations loaded";
}
+ signed_config->config = configs.primary;
if (result->error_code == QUIC_NO_ERROR) {
// QUIC requires a new proof for each CHLO so clear any existing proof.
signed_config->chain = nullptr;
signed_config->proof.signature = "";
signed_config->proof.leaf_cert_scts = "";
- EvaluateClientHello(server_address, version, requested_config,
- primary_config, result, std::move(done_cb));
+ EvaluateClientHello(server_address, version, configs, result,
+ std::move(done_cb));
} else {
done_cb->Run(result, /* details = */ nullptr);
}
@@ -559,17 +541,10 @@
class QuicCryptoServerConfig::ProcessClientHelloCallback
: public ProofSource::Callback {
public:
- ProcessClientHelloCallback(
- const QuicCryptoServerConfig* config,
- std::unique_ptr<ProcessClientHelloContext> context,
- const QuicReferenceCountedPointer<QuicCryptoServerConfig::Config>&
- requested_config,
- const QuicReferenceCountedPointer<QuicCryptoServerConfig::Config>&
- primary_config)
- : config_(config),
- context_(std::move(context)),
- requested_config_(requested_config),
- primary_config_(primary_config) {}
+ ProcessClientHelloCallback(const QuicCryptoServerConfig* config,
+ std::unique_ptr<ProcessClientHelloContext> context,
+ const Configs& configs)
+ : config_(config), context_(std::move(context)), configs_(configs) {}
void Run(bool ok,
const QuicReferenceCountedPointer<ProofSource::Chain>& chain,
@@ -579,18 +554,14 @@
context_->signed_config()->chain = chain;
context_->signed_config()->proof = proof;
}
- config_->ProcessClientHelloAfterGetProof(
- !ok, std::move(details), std::move(context_), requested_config_,
- primary_config_);
+ config_->ProcessClientHelloAfterGetProof(!ok, std::move(details),
+ std::move(context_), configs_);
}
private:
const QuicCryptoServerConfig* config_;
std::unique_ptr<ProcessClientHelloContext> context_;
- const QuicReferenceCountedPointer<QuicCryptoServerConfig::Config>
- requested_config_;
- const QuicReferenceCountedPointer<QuicCryptoServerConfig::Config>
- primary_config_;
+ const Configs configs_;
};
class QuicCryptoServerConfig::ProcessClientHelloAfterGetProofCallback
@@ -603,19 +574,19 @@
std::unique_ptr<CryptoHandshakeMessage> out,
QuicStringPiece public_value,
std::unique_ptr<ProcessClientHelloContext> context,
- const QuicReferenceCountedPointer<Config>& requested_config)
+ const Configs& configs)
: config_(config),
proof_source_details_(std::move(proof_source_details)),
key_exchange_type_(key_exchange_type),
out_(std::move(out)),
public_value_(public_value),
context_(std::move(context)),
- requested_config_(requested_config) {}
+ configs_(configs) {}
void Run(bool ok) override {
config_->ProcessClientHelloAfterCalculateSharedKeys(
!ok, std::move(proof_source_details_), key_exchange_type_,
- std::move(out_), public_value_, std::move(context_), requested_config_);
+ std::move(out_), public_value_, std::move(context_), configs_);
}
private:
@@ -625,7 +596,7 @@
std::unique_ptr<CryptoHandshakeMessage> out_;
std::string public_value_;
std::unique_ptr<ProcessClientHelloContext> context_;
- const QuicReferenceCountedPointer<Config> requested_config_;
+ const Configs configs_;
};
void QuicCryptoServerConfig::ProcessClientHello(
@@ -655,6 +626,7 @@
params, signed_config, total_framing_overhead, chlo_packet_size,
std::move(done_cb));
+ // Verify that various parts of the CHLO are valid
std::string error_details;
QuicErrorCode valid = CryptoUtils::ValidateClientHello(
context->client_hello(), context->version(),
@@ -666,35 +638,9 @@
QuicStringPiece requested_scid;
context->client_hello().GetStringPiece(kSCID, &requested_scid);
- const QuicWallTime now(clock->WallNow());
-
- QuicReferenceCountedPointer<Config> requested_config;
- QuicReferenceCountedPointer<Config> primary_config;
- bool no_primary_config = false;
- {
- QuicReaderMutexLock locked(&configs_lock_);
-
- if (!primary_config_) {
- no_primary_config = true;
- } else {
- if (IsNextConfigReady(now)) {
- configs_lock_.ReaderUnlock();
- configs_lock_.WriterLock();
- SelectNewPrimaryConfig(now);
- DCHECK(primary_config_.get());
- DCHECK_EQ(configs_.find(primary_config_->id)->second.get(),
- primary_config_.get());
- configs_lock_.WriterUnlock();
- configs_lock_.ReaderLock();
- }
-
- // Use the config that the client requested in order to do key-agreement.
- // Otherwise give it a copy of |primary_config_| to use.
- primary_config = signed_config->config;
- requested_config = GetConfigWithScid(requested_scid);
- }
- }
- if (no_primary_config) {
+ Configs configs;
+ if (!GetCurrentConfigs(context->clock()->WallNow(), requested_scid,
+ signed_config->config, &configs)) {
context->Fail(QUIC_CRYPTO_INTERNAL_ERROR, "No configurations loaded");
return;
}
@@ -719,25 +665,24 @@
const QuicTransportVersion transport_version = context->transport_version();
auto cb = QuicMakeUnique<ProcessClientHelloCallback>(
- this, std::move(context), requested_config, primary_config);
+ this, std::move(context), configs);
DCHECK(proof_source_.get());
- proof_source_->GetProof(server_address, sni, primary_config->serialized,
+ proof_source_->GetProof(server_address, sni, configs.primary->serialized,
transport_version, chlo_hash, std::move(cb));
return;
}
ProcessClientHelloAfterGetProof(
/* found_error = */ false, /* proof_source_details = */ nullptr,
- std::move(context), requested_config, primary_config);
+ std::move(context), configs);
}
void QuicCryptoServerConfig::ProcessClientHelloAfterGetProof(
bool found_error,
std::unique_ptr<ProofSource::Details> proof_source_details,
std::unique_ptr<ProcessClientHelloContext> context,
- const QuicReferenceCountedPointer<Config>& requested_config,
- const QuicReferenceCountedPointer<Config>& primary_config) const {
+ const Configs& configs) const {
QUIC_BUG_IF(!QuicUtils::IsConnectionIdValidForVersion(
context->connection_id(), context->transport_version()))
<< "ProcessClientHelloAfterGetProof: attempted to use connection ID "
@@ -758,8 +703,8 @@
}
auto out = QuicMakeUnique<CryptoHandshakeMessage>();
- if (!context->info().reject_reasons.empty() || !requested_config) {
- BuildRejection(*context, *primary_config, context->info().reject_reasons,
+ if (!context->info().reject_reasons.empty() || !configs.requested) {
+ BuildRejection(*context, *configs.primary, context->info().reject_reasons,
out.get());
if (rejection_observer_ != nullptr) {
rejection_observer_->OnRejectionBuilt(context->info().reject_reasons,
@@ -789,22 +734,22 @@
}
size_t key_exchange_index;
- if (!FindMutualQuicTag(requested_config->aead, their_aeads,
+ if (!FindMutualQuicTag(configs.requested->aead, their_aeads,
&context->params()->aead, nullptr) ||
- !FindMutualQuicTag(requested_config->kexs, their_key_exchanges,
+ !FindMutualQuicTag(configs.requested->kexs, their_key_exchanges,
&context->params()->key_exchange,
&key_exchange_index)) {
context->Fail(QUIC_CRYPTO_NO_SUPPORT, "Unsupported AEAD or KEXS");
return;
}
- if (!requested_config->tb_key_params.empty()) {
+ if (!configs.requested->tb_key_params.empty()) {
QuicTagVector their_tbkps;
switch (context->client_hello().GetTaglist(kTBKP, &their_tbkps)) {
case QUIC_CRYPTO_MESSAGE_PARAMETER_NOT_FOUND:
break;
case QUIC_NO_ERROR:
- if (FindMutualQuicTag(requested_config->tb_key_params, their_tbkps,
+ if (FindMutualQuicTag(configs.requested->tb_key_params, their_tbkps,
&context->params()->token_binding_key_param,
nullptr)) {
break;
@@ -825,12 +770,12 @@
}
const AsynchronousKeyExchange* key_exchange =
- requested_config->key_exchanges[key_exchange_index].get();
+ configs.requested->key_exchanges[key_exchange_index].get();
std::string* initial_premaster_secret =
&context->params()->initial_premaster_secret;
auto cb = QuicMakeUnique<ProcessClientHelloAfterGetProofCallback>(
this, std::move(proof_source_details), key_exchange->type(),
- std::move(out), public_value, std::move(context), requested_config);
+ std::move(out), public_value, std::move(context), configs);
key_exchange->CalculateSharedKeyAsync(public_value, initial_premaster_secret,
std::move(cb));
}
@@ -842,7 +787,7 @@
std::unique_ptr<CryptoHandshakeMessage> out,
QuicStringPiece public_value,
std::unique_ptr<ProcessClientHelloContext> context,
- const QuicReferenceCountedPointer<Config>& requested_config) const {
+ const Configs& configs) const {
QUIC_BUG_IF(!QuicUtils::IsConnectionIdValidForVersion(
context->connection_id(), context->transport_version()))
<< "ProcessClientHelloAfterCalculateSharedKeys:"
@@ -866,12 +811,12 @@
context->client_hello().GetSerialized();
hkdf_suffix.reserve(context->connection_id().length() +
client_hello_serialized.length() +
- requested_config->serialized.size());
+ configs.requested->serialized.size());
hkdf_suffix.append(context->connection_id().data(),
context->connection_id().length());
hkdf_suffix.append(client_hello_serialized.data(),
client_hello_serialized.length());
- hkdf_suffix.append(requested_config->serialized);
+ hkdf_suffix.append(configs.requested->serialized);
DCHECK(proof_source_.get());
if (context->signed_config()->chain->certs.empty()) {
context->Fail(QUIC_CRYPTO_INTERNAL_ERROR, "Failed to get certs");
@@ -880,7 +825,7 @@
hkdf_suffix.append(context->signed_config()->chain->certs.at(0));
QuicStringPiece cetv_ciphertext;
- if (requested_config->channel_id_enabled &&
+ if (configs.requested->channel_id_enabled &&
context->client_hello().GetStringPiece(kCETV, &cetv_ciphertext)) {
CryptoHandshakeMessage client_hello_copy(context->client_hello());
client_hello_copy.Erase(kCETV);
@@ -895,7 +840,7 @@
context->connection_id().length());
hkdf_input.append(client_hello_copy_serialized.data(),
client_hello_copy_serialized.length());
- hkdf_input.append(requested_config->serialized);
+ hkdf_input.append(configs.requested->serialized);
CrypterPair crypters;
if (!CryptoUtils::DeriveKeys(
@@ -1008,7 +953,7 @@
out->SetVersionVector(kVER, context->supported_versions());
out->SetStringPiece(
kSourceAddressTokenTag,
- NewSourceAddressToken(*requested_config,
+ NewSourceAddressToken(*configs.requested,
context->info().source_address_tokens,
context->client_address().host(), context->rand(),
context->info().now, nullptr));
@@ -1037,6 +982,38 @@
return QuicReferenceCountedPointer<Config>();
}
+bool QuicCryptoServerConfig::GetCurrentConfigs(  // Fills |configs|; returns false when primary_config_ is unset.
+ const QuicWallTime& now,  // Passed to IsNextConfigReady() and SelectNewPrimaryConfig().
+ QuicStringPiece requested_scid,  // Looked up via GetConfigWithScid().
+ QuicReferenceCountedPointer<Config> old_primary_config,  // If non-null, used as configs->primary instead of primary_config_.
+ Configs* configs) const {  // Out-param; written only on the path that returns true.
+ QuicReaderMutexLock locked(&configs_lock_);  // Presumably a scoped reader lock on configs_lock_ -- confirm.
+
+ if (!primary_config_) {
+ return false;  // Callers in this file report this as "No configurations loaded".
+ }
+
+ if (IsNextConfigReady(now)) {
+ configs_lock_.ReaderUnlock();  // Release the read lock before taking the write lock.
+ configs_lock_.WriterLock();  // NOTE(review): IsNextConfigReady() is not re-checked after re-locking -- confirm this is benign.
+ SelectNewPrimaryConfig(now);
+ DCHECK(primary_config_.get());
+ DCHECK_EQ(configs_.find(primary_config_->id)->second.get(),
+ primary_config_.get());
+ configs_lock_.WriterUnlock();
+ configs_lock_.ReaderLock();  // Re-take the read lock; presumably |locked| releases it on scope exit -- confirm.
+ }
+
+ if (old_primary_config != nullptr) {
+ configs->primary = old_primary_config;  // Caller-supplied primary wins over the current primary_config_.
+ } else {
+ configs->primary = primary_config_;
+ }
+ configs->requested = GetConfigWithScid(requested_scid);  // Callers in this file null-check this before use.
+
+ return true;
+}
+
// ConfigPrimaryTimeLessThan is a comparator that implements "less than" for
// Config's based on their primary_time.
// static
@@ -1141,12 +1118,10 @@
void QuicCryptoServerConfig::EvaluateClientHello(
const QuicSocketAddress& server_address,
QuicTransportVersion version,
- QuicReferenceCountedPointer<Config> requested_config,
- QuicReferenceCountedPointer<Config> primary_config,
+ const Configs& configs,
QuicReferenceCountedPointer<ValidateClientHelloResultCallback::Result>
client_hello_state,
std::unique_ptr<ValidateClientHelloResultCallback> done_cb) const {
-
ValidateClientHelloHelper helper(client_hello_state, &done_cb);
const CryptoHandshakeMessage& client_hello = client_hello_state->client_hello;
@@ -1172,7 +1147,7 @@
QuicStringPiece srct;
if (client_hello.GetStringPiece(kSourceAddressTokenTag, &srct)) {
Config& config =
- requested_config != nullptr ? *requested_config : *primary_config;
+ configs.requested != nullptr ? *configs.requested : *configs.primary;
source_address_token_error =
ParseSourceAddressToken(config, srct, &info->source_address_tokens);
@@ -1191,7 +1166,7 @@
info->valid_source_address_token = true;
}
- if (!requested_config.get()) {
+ if (!configs.requested) {
QuicStringPiece requested_scid;
if (client_hello.GetStringPiece(kSCID, &requested_scid)) {
info->reject_reasons.push_back(SERVER_CONFIG_UNKNOWN_CONFIG_FAILURE);