BlindSignAuth: Refactor Privacy Pass auth flow by splitting it into reusable functions. PiperOrigin-RevId: 776695229
diff --git a/quiche/blind_sign_auth/blind_sign_auth.cc b/quiche/blind_sign_auth/blind_sign_auth.cc index d7f8542..8c0415d 100644 --- a/quiche/blind_sign_auth/blind_sign_auth.cc +++ b/quiche/blind_sign_auth/blind_sign_auth.cc
@@ -44,6 +44,30 @@ constexpr absl::string_view kIssuerHostname = "https://ipprotection-ppissuer.googleapis.com"; +constexpr size_t kExpectedExtensionTypesSize = 5; +constexpr std::array<uint16_t, kExpectedExtensionTypesSize> + kExpectedExtensionTypes = {0x0001, 0x0002, 0xF001, 0xF002, 0xF003}; + +using privacy::ppn::AuthAndSignRequest; +using privacy::ppn::AuthAndSignResponse; +using privacy::ppn::GetInitialDataRequest; +using privacy::ppn::GetInitialDataResponse; +using privacy::ppn::PrivacyPassTokenData; +using anonymous_tokens::AnonymousTokensUseCase; +using anonymous_tokens::CreatePublicKeyRSA; +using anonymous_tokens::DecodeExtensions; +using anonymous_tokens::ExpirationTimestamp; +using anonymous_tokens::ExtendedTokenRequest; +using anonymous_tokens::Extensions; +using anonymous_tokens::GeoHint; +using anonymous_tokens::MarshalTokenChallenge; +using anonymous_tokens::ParseUseCase; +using anonymous_tokens:: + PrivacyPassRsaBssaPublicMetadataClient; +using anonymous_tokens::RSAPublicKey; +using anonymous_tokens::Token; +using anonymous_tokens::TokenChallenge; +using anonymous_tokens::ValidateExtensionsOrderAndValues; } // namespace @@ -52,7 +76,7 @@ BlindSignAuthServiceType service_type, SignedTokenCallback callback) { // Create GetInitialData RPC. - privacy::ppn::GetInitialDataRequest request; + GetInitialDataRequest request; request.set_use_attestation(false); request.set_service_type(BlindSignAuthServiceTypeToString(service_type)); request.set_location_granularity( @@ -75,38 +99,21 @@ ProxyLayer proxy_layer, BlindSignAuthServiceType service_type, SignedTokenCallback callback, absl::StatusOr<BlindSignMessageResponse> response) { - if (!response.ok()) { - QUICHE_LOG(WARNING) << "GetInitialDataRequest failed: " - << response.status(); - std::move(callback)(absl::InvalidArgumentError( - "GetInitialDataRequest failed: invalid response")); - return; - } - absl::StatusCode code = response->status_code(); - if (code != absl::StatusCode::kOk) { - std::string message = - absl::StrCat("GetInitialDataRequest failed with code: ", code); - QUICHE_LOG(WARNING) << message; - std::move(callback)(absl::InvalidArgumentError(message)); - return; - } - // Parse GetInitialDataResponse. - privacy::ppn::GetInitialDataResponse initial_data_response; - if (!initial_data_response.ParseFromString(response->body())) { - QUICHE_LOG(WARNING) << "Failed to parse GetInitialDataResponse"; - std::move(callback)( - absl::InternalError("Failed to parse GetInitialDataResponse")); + absl::StatusOr<GetInitialDataResponse> initial_data_response = + ParseGetInitialDataResponseMessage(response); + if (!initial_data_response.ok()) { + std::move(callback)(initial_data_response.status()); return; } // Create token signing requests. - bool use_privacy_pass_client = - initial_data_response.has_privacy_pass_data() && + const bool use_privacy_pass_client = + initial_data_response->has_privacy_pass_data() && auth_options_.enable_privacy_pass(); if (use_privacy_pass_client) { QUICHE_DVLOG(1) << "Using Privacy Pass client"; - GeneratePrivacyPassTokens(initial_data_response, std::move(oauth_token), + GeneratePrivacyPassTokens(*initial_data_response, std::move(oauth_token), num_tokens, proxy_layer, service_type, std::move(callback)); } else { @@ -121,73 +128,18 @@ std::optional<std::string> oauth_token, int num_tokens, ProxyLayer proxy_layer, BlindSignAuthServiceType service_type, SignedTokenCallback callback) { - // Set up values used in the token generation loop. - anonymous_tokens::RSAPublicKey public_key_proto; - if (!public_key_proto.ParseFromString( - initial_data_response.at_public_metadata_public_key() - .serialized_public_key())) { - std::move(callback)( - absl::InvalidArgumentError("Failed to parse Privacy Pass public key")); + absl::StatusOr<PrivacyPassContext> pp_context = + CreatePrivacyPassContext(initial_data_response); + if (!pp_context.ok()) { + std::move(callback)(pp_context.status()); return; } - absl::StatusOr<bssl::UniquePtr<RSA>> bssl_rsa_key = - anonymous_tokens::CreatePublicKeyRSA( - public_key_proto.n(), public_key_proto.e()); - if (!bssl_rsa_key.ok()) { - QUICHE_LOG(ERROR) << "Failed to create RSA public key: " - << bssl_rsa_key.status(); - std::move(callback)(absl::InternalError("Failed to create RSA public key")); - return; - } - absl::StatusOr<anonymous_tokens::Extensions> extensions = - anonymous_tokens::DecodeExtensions( - initial_data_response.privacy_pass_data() - .public_metadata_extensions()); - if (!extensions.ok()) { - QUICHE_LOG(WARNING) << "Failed to decode extensions: " - << extensions.status(); - std::move(callback)( - absl::InvalidArgumentError("Failed to decode extensions")); - return; - } - std::vector<uint16_t> kExpectedExtensionTypes = { - /*ExpirationTimestamp=*/0x0001, /*GeoHint=*/0x0002, - /*ServiceType=*/0xF001, /*DebugMode=*/0xF002, /*ProxyLayer=*/0xF003}; - // TODO(b/345801768): Improve the API of - // `anonymous_tokens::ValidateExtensionsOrderAndValues` to - // avoid any possible TOCTOU problems. - absl::Status result = - anonymous_tokens::ValidateExtensionsOrderAndValues( - *extensions, absl::MakeSpan(kExpectedExtensionTypes), absl::Now()); - if (!result.ok()) { - QUICHE_LOG(WARNING) << "Failed to validate extensions: " << result; - std::move(callback)( - absl::InvalidArgumentError("Failed to validate extensions")); - return; - } - absl::StatusOr<anonymous_tokens::ExpirationTimestamp> - expiration_timestamp = anonymous_tokens:: - ExpirationTimestamp::FromExtension(extensions->extensions.at(0)); - if (!expiration_timestamp.ok()) { - QUICHE_LOG(WARNING) << "Failed to parse expiration timestamp: " - << expiration_timestamp.status(); - std::move(callback)( - absl::InvalidArgumentError("Failed to parse expiration timestamp")); - return; - } - absl::Time public_metadata_expiry_time = - absl::FromUnixSeconds(expiration_timestamp->timestamp); - - absl::StatusOr<anonymous_tokens::GeoHint> geo_hint = - anonymous_tokens::GeoHint::FromExtension( - extensions->extensions.at(1)); - QUICHE_CHECK(geo_hint.ok()); // Create token challenge. - anonymous_tokens::TokenChallenge challenge; + TokenChallenge challenge; challenge.issuer_name = kIssuerHostname; absl::StatusOr<std::string> token_challenge = - anonymous_tokens::MarshalTokenChallenge(challenge); + MarshalTokenChallenge(challenge); if (!token_challenge.ok()) { QUICHE_LOG(WARNING) << "Failed to marshal token challenge: " << token_challenge.status(); @@ -196,78 +148,34 @@ return; } - QuicheRandom* random = QuicheRandom::GetInstance(); - // Create vector of Privacy Pass clients, one for each token. - std::vector<anonymous_tokens::ExtendedTokenRequest> - extended_token_requests; - std::vector<std::unique_ptr<anonymous_tokens:: - PrivacyPassRsaBssaPublicMetadataClient>> - privacy_pass_clients; - std::vector<std::string> privacy_pass_blinded_tokens; - - for (int i = 0; i < num_tokens; i++) { - // Create client. - auto client = anonymous_tokens:: - PrivacyPassRsaBssaPublicMetadataClient::Create(*bssl_rsa_key.value()); - if (!client.ok()) { - QUICHE_LOG(WARNING) << "Failed to create Privacy Pass client: " - << client.status(); - std::move(callback)( - absl::InternalError("Failed to create Privacy Pass client")); - return; - } - - // Create nonce. - std::string nonce_rand(32, '\0'); - random->RandBytes(nonce_rand.data(), nonce_rand.size()); - - // Create token request. - absl::StatusOr<anonymous_tokens::ExtendedTokenRequest> - extended_token_request = client.value()->CreateTokenRequest( - *token_challenge, nonce_rand, - initial_data_response.privacy_pass_data().token_key_id(), - *extensions); - if (!extended_token_request.ok()) { - QUICHE_LOG(WARNING) << "Failed to create ExtendedTokenRequest: " - << extended_token_request.status(); - std::move(callback)( - absl::InternalError("Failed to create ExtendedTokenRequest")); - return; - } - privacy_pass_clients.push_back(*std::move(client)); - extended_token_requests.push_back(*extended_token_request); - privacy_pass_blinded_tokens.push_back(absl::Base64Escape( - extended_token_request->request.blinded_token_request)); + absl::StatusOr<GeneratedTokenRequests> token_requests_data = + GenerateBlindedTokenRequests(num_tokens, *pp_context->rsa_public_key, + *token_challenge, pp_context->token_key_id, + pp_context->extensions); + if (!token_requests_data.ok()) { + std::move(callback)(token_requests_data.status()); + return; } - privacy::ppn::AuthAndSignRequest sign_request; + AuthAndSignRequest sign_request; sign_request.set_service_type(BlindSignAuthServiceTypeToString(service_type)); sign_request.set_key_type(privacy::ppn::AT_PUBLIC_METADATA_KEY_TYPE); sign_request.set_key_version( initial_data_response.at_public_metadata_public_key().key_version()); - sign_request.mutable_blinded_token()->Assign( - privacy_pass_blinded_tokens.begin(), privacy_pass_blinded_tokens.end()); + *sign_request.mutable_blinded_token() = { + token_requests_data->privacy_pass_blinded_tokens_b64.begin(), + token_requests_data->privacy_pass_blinded_tokens_b64.end()}; sign_request.mutable_public_metadata_extensions()->assign( initial_data_response.privacy_pass_data().public_metadata_extensions()); // TODO(b/295924807): deprecate this option after AT server defaults to it sign_request.set_do_not_use_rsa_public_exponent(true); sign_request.set_proxy_layer(QuicheProxyLayerToPpnProxyLayer(proxy_layer)); - absl::StatusOr<anonymous_tokens::AnonymousTokensUseCase> - use_case = anonymous_tokens::ParseUseCase( - initial_data_response.at_public_metadata_public_key().use_case()); - if (!use_case.ok()) { - QUICHE_LOG(WARNING) << "Failed to parse use case: " << use_case.status(); - std::move(callback)(absl::InvalidArgumentError("Failed to parse use case")); - return; - } - BlindSignMessageCallback auth_and_sign_callback = absl::bind_front(&BlindSignAuth::PrivacyPassAuthAndSignCallback, this, - std::move(initial_data_response.privacy_pass_data() - .public_metadata_extensions()), - public_metadata_expiry_time, *geo_hint, *use_case, - std::move(privacy_pass_clients), std::move(callback)); + *std::move(pp_context), + std::move(token_requests_data->privacy_pass_clients), + std::move(callback)); // TODO(b/304811277): remove other usages of string.data() fetcher_->DoRequest(BlindSignMessageRequestType::kAuthAndSign, oauth_token, sign_request.SerializeAsString(), @@ -275,9 +183,7 @@ } void BlindSignAuth::PrivacyPassAuthAndSignCallback( - std::string encoded_extensions, absl::Time public_key_expiry_time, - anonymous_tokens::GeoHint geo_hint, - anonymous_tokens::AnonymousTokensUseCase use_case, + const PrivacyPassContext& pp_context, std::vector<std::unique_ptr<anonymous_tokens:: PrivacyPassRsaBssaPublicMetadataClient>> privacy_pass_clients, @@ -299,7 +205,7 @@ } // Decode AuthAndSignResponse. - privacy::ppn::AuthAndSignResponse sign_response; + AuthAndSignResponse sign_response; if (!sign_response.ParseFromString(response->body())) { QUICHE_LOG(WARNING) << "Failed to parse AuthAndSignResponse"; std::move(callback)( @@ -328,7 +234,7 @@ return; } - absl::StatusOr<anonymous_tokens::Token> token = + absl::StatusOr<Token> token = privacy_pass_clients[i]->FinalizeToken(unescaped_blinded_sig); if (!token.ok()) { QUICHE_LOG(WARNING) << "Failed to finalize token: " << token.status(); @@ -336,8 +242,7 @@ return; } - absl::StatusOr<std::string> marshaled_token = - anonymous_tokens::MarshalToken(*token); + absl::StatusOr<std::string> marshaled_token = MarshalToken(*token); if (!marshaled_token.ok()) { QUICHE_LOG(WARNING) << "Failed to marshal token: " << marshaled_token.status(); @@ -345,35 +250,21 @@ return; } - privacy::ppn::PrivacyPassTokenData privacy_pass_token_data; + PrivacyPassTokenData privacy_pass_token_data; privacy_pass_token_data.mutable_token()->assign( ConvertBase64ToWebSafeBase64(absl::Base64Escape(*marshaled_token))); privacy_pass_token_data.mutable_encoded_extensions()->assign( - ConvertBase64ToWebSafeBase64(absl::Base64Escape(encoded_extensions))); - privacy_pass_token_data.set_use_case_override(use_case); - tokens_vec.push_back( - BlindSignToken{privacy_pass_token_data.SerializeAsString(), - public_key_expiry_time, geo_hint}); + ConvertBase64ToWebSafeBase64( + absl::Base64Escape(pp_context.public_metadata_extensions_str))); + privacy_pass_token_data.set_use_case_override(pp_context.use_case); + tokens_vec.push_back(BlindSignToken{ + privacy_pass_token_data.SerializeAsString(), + pp_context.public_metadata_expiry_time, pp_context.geo_hint}); } std::move(callback)(absl::Span<BlindSignToken>(tokens_vec)); } -privacy::ppn::ProxyLayer BlindSignAuth::QuicheProxyLayerToPpnProxyLayer( - quiche::ProxyLayer proxy_layer) { - switch (proxy_layer) { - case ProxyLayer::kProxyA: { - return privacy::ppn::ProxyLayer::PROXY_A; - } - case ProxyLayer::kProxyB: { - return privacy::ppn::ProxyLayer::PROXY_B; - } - case ProxyLayer::kTerminalLayer: { - return privacy::ppn::ProxyLayer::TERMINAL_LAYER; - } - } -} - void BlindSignAuth::GetAttestationTokens(int /*num_tokens*/, ProxyLayer /*layer*/, AttestationDataCallback callback) { @@ -391,6 +282,158 @@ absl::UnimplementedError("AttestAndSign is not implemented")); } +absl::StatusOr<privacy::ppn::GetInitialDataResponse> +BlindSignAuth::ParseGetInitialDataResponseMessage( + const absl::StatusOr<BlindSignMessageResponse>& response) { + if (!response.ok()) { + QUICHE_LOG(WARNING) << "GetInitialDataRequest failed: " + << response.status(); + return absl::InvalidArgumentError( + "GetInitialDataRequest failed: invalid response"); + } + if (absl::StatusCode code = response->status_code(); + code != absl::StatusCode::kOk) { + std::string message = + absl::StrCat("GetInitialDataRequest failed with code: ", code); + QUICHE_LOG(WARNING) << message; + return absl::InvalidArgumentError(message); + } + // Parse GetInitialDataResponse. + GetInitialDataResponse initial_data_response; + if (!initial_data_response.ParseFromString(response->body())) { + QUICHE_LOG(WARNING) << "Failed to parse GetInitialDataResponse"; + return absl::InternalError("Failed to parse GetInitialDataResponse"); + } + return initial_data_response; +} + +absl::StatusOr<BlindSignAuth::PrivacyPassContext> +BlindSignAuth::CreatePrivacyPassContext( + const privacy::ppn::GetInitialDataResponse& initial_data_response) { + RSAPublicKey public_key_proto; + if (!public_key_proto.ParseFromString( + initial_data_response.at_public_metadata_public_key() + .serialized_public_key())) { + return absl::InvalidArgumentError( + "Failed to parse Privacy Pass public key"); + } + absl::StatusOr<bssl::UniquePtr<RSA>> bssl_rsa_key = + CreatePublicKeyRSA(public_key_proto.n(), public_key_proto.e()); + if (!bssl_rsa_key.ok()) { + return absl::InternalError(absl::StrCat("Failed to create RSA public key: ", + bssl_rsa_key.status().ToString())); + } + + PrivacyPassContext context; + context.rsa_public_key = *std::move(bssl_rsa_key); + context.key_version = + initial_data_response.at_public_metadata_public_key().key_version(); + context.token_key_id = + initial_data_response.privacy_pass_data().token_key_id(); + context.public_metadata_extensions_str = + initial_data_response.privacy_pass_data().public_metadata_extensions(); + + absl::StatusOr<Extensions> extensions = + DecodeExtensions(context.public_metadata_extensions_str); + if (!extensions.ok()) { + return absl::InvalidArgumentError(absl::StrCat( + "Failed to decode extensions: ", extensions.status().ToString())); + } + + if (absl::Status validation_result = ValidateExtensionsOrderAndValues( + *extensions, absl::MakeSpan(kExpectedExtensionTypes), absl::Now()); + validation_result.ok()) { + context.extensions = *std::move(extensions); + } else { + return absl::InvalidArgumentError(absl::StrCat( + "Failed to validate extensions: ", validation_result.ToString())); + } + + if (absl::StatusOr<ExpirationTimestamp> expiration_timestamp = + ExpirationTimestamp::FromExtension( + context.extensions.extensions.at(0)); + expiration_timestamp.ok()) { + context.public_metadata_expiry_time = + absl::FromUnixSeconds(expiration_timestamp->timestamp); + } else { + return absl::InvalidArgumentError( + absl::StrCat("Failed to parse expiration timestamp: ", + expiration_timestamp.status().ToString())); + } + + if (absl::StatusOr<GeoHint> geo_hint = + GeoHint::FromExtension(context.extensions.extensions.at(1)); + geo_hint.ok()) { + context.geo_hint = *std::move(geo_hint); + } else { + return absl::InvalidArgumentError(absl::StrCat( + "Failed to parse geo hint: ", geo_hint.status().ToString())); + } + + if (absl::StatusOr<AnonymousTokensUseCase> use_case = ParseUseCase( + initial_data_response.at_public_metadata_public_key().use_case()); + use_case.ok()) { + context.use_case = *std::move(use_case); + } else { + return absl::InvalidArgumentError(absl::StrCat( + "Failed to parse use case: ", use_case.status().ToString())); + } + + return context; +} + +absl::StatusOr<BlindSignAuth::GeneratedTokenRequests> +BlindSignAuth::GenerateBlindedTokenRequests( + int num_tokens, const RSA& rsa_public_key, + absl::string_view token_challenge_str, absl::string_view token_key_id, + const anonymous_tokens::Extensions& extensions) { + GeneratedTokenRequests result; + result.privacy_pass_clients.reserve(num_tokens); + result.privacy_pass_blinded_tokens_b64.reserve(num_tokens); + QuicheRandom* random = QuicheRandom::GetInstance(); + + for (int i = 0; i < num_tokens; i++) { + absl::StatusOr<std::unique_ptr<PrivacyPassRsaBssaPublicMetadataClient>> + client = PrivacyPassRsaBssaPublicMetadataClient::Create(rsa_public_key); + if (!client.ok()) { + return absl::InternalError( + absl::StrCat("Failed to create Privacy Pass client: ", + client.status().ToString())); + } + + std::string nonce_rand(32, '\0'); + random->RandBytes(nonce_rand.data(), nonce_rand.size()); + + absl::StatusOr<ExtendedTokenRequest> extended_token_request = + (*client)->CreateTokenRequest(token_challenge_str, nonce_rand, + token_key_id, extensions); + if (!extended_token_request.ok()) { + return absl::InternalError( + absl::StrCat("Failed to create ExtendedTokenRequest: ", + extended_token_request.status().ToString())); + } + result.privacy_pass_clients.push_back(*std::move(client)); + result.privacy_pass_blinded_tokens_b64.push_back(absl::Base64Escape( + extended_token_request->request.blinded_token_request)); + } + return result; +} + +privacy::ppn::ProxyLayer BlindSignAuth::QuicheProxyLayerToPpnProxyLayer( + quiche::ProxyLayer proxy_layer) { + switch (proxy_layer) { + case ProxyLayer::kProxyA: { + return privacy::ppn::ProxyLayer::PROXY_A; + } + case ProxyLayer::kProxyB: { + return privacy::ppn::ProxyLayer::PROXY_B; + } + case ProxyLayer::kTerminalLayer: { + return privacy::ppn::ProxyLayer::TERMINAL_LAYER; + } + } +} + std::string BlindSignAuth::ConvertBase64ToWebSafeBase64( std::string base64_string) { absl::c_replace(base64_string, /*old_value=*/'+', /*new_value=*/'-');
diff --git a/quiche/blind_sign_auth/blind_sign_auth.h b/quiche/blind_sign_auth/blind_sign_auth.h index ccc3118..7ff742c 100644 --- a/quiche/blind_sign_auth/blind_sign_auth.h +++ b/quiche/blind_sign_auth/blind_sign_auth.h
@@ -62,25 +62,57 @@ SignedTokenCallback callback) override; private: + struct PrivacyPassContext { + bssl::UniquePtr<RSA> rsa_public_key; + anonymous_tokens::Extensions extensions; + absl::Time public_metadata_expiry_time; + anonymous_tokens::GeoHint geo_hint; + anonymous_tokens::AnonymousTokensUseCase use_case; + std::string token_key_id; + uint32_t key_version = 0; + std::string public_metadata_extensions_str; + }; + + struct GeneratedTokenRequests { + std::vector<std::unique_ptr<anonymous_tokens:: + PrivacyPassRsaBssaPublicMetadataClient>> + privacy_pass_clients; + std::vector<std::string> privacy_pass_blinded_tokens_b64; + }; + + // Helper functions for GetTokens flow without device attestation. void GetInitialDataCallback( std::optional<std::string> oauth_token, int num_tokens, ProxyLayer proxy_layer, BlindSignAuthServiceType service_type, SignedTokenCallback callback, absl::StatusOr<BlindSignMessageResponse> response); + void GeneratePrivacyPassTokens( privacy::ppn::GetInitialDataResponse initial_data_response, std::optional<std::string> oauth_token, int num_tokens, ProxyLayer proxy_layer, BlindSignAuthServiceType service_type, SignedTokenCallback callback); + void PrivacyPassAuthAndSignCallback( - std::string encoded_extensions, absl::Time public_key_expiry_time, - anonymous_tokens::GeoHint geo_hint, - anonymous_tokens::AnonymousTokensUseCase use_case, + const PrivacyPassContext& pp_context, std::vector<std::unique_ptr<anonymous_tokens:: PrivacyPassRsaBssaPublicMetadataClient>> privacy_pass_clients, SignedTokenCallback callback, absl::StatusOr<BlindSignMessageResponse> response); + + absl::StatusOr<privacy::ppn::GetInitialDataResponse> + ParseGetInitialDataResponseMessage( + const absl::StatusOr<BlindSignMessageResponse>& response_statusor); + + absl::StatusOr<PrivacyPassContext> CreatePrivacyPassContext( + const privacy::ppn::GetInitialDataResponse& initial_data_response); + + absl::StatusOr<GeneratedTokenRequests> GenerateBlindedTokenRequests( + int num_tokens, const RSA& rsa_public_key, + absl::string_view token_challenge_str, absl::string_view token_key_id, + const anonymous_tokens::Extensions& extensions); + privacy::ppn::ProxyLayer QuicheProxyLayerToPpnProxyLayer( quiche::ProxyLayer proxy_layer); // Replaces '+' and '/' with '-' and '_' in a Base64 string.