Replace std::function with quiche::SingleUseCallback (alias of absl::AnyInvocable) in BlindSignAuth. Also fix a lifetime issue from passing a reference to `AnonymousTokensRsaBssaClient*` between callbacks by passing the `unique_ptr`. Tested in Chromium `net_unittests` and all tests pass. PiperOrigin-RevId: 546013661
diff --git a/quiche/blind_sign_auth/blind_sign_auth.cc b/quiche/blind_sign_auth/blind_sign_auth.cc index c76ab2d..b13005f 100644 --- a/quiche/blind_sign_auth/blind_sign_auth.cc +++ b/quiche/blind_sign_auth/blind_sign_auth.cc
@@ -5,7 +5,7 @@ #include "quiche/blind_sign_auth/blind_sign_auth.h" #include <cstddef> -#include <functional> +#include <memory> #include <string> #include <utility> #include <vector> @@ -15,6 +15,7 @@ #include "quiche/blind_sign_auth/proto/key_services.pb.h" #include "quiche/blind_sign_auth/proto/public_metadata.pb.h" #include "quiche/blind_sign_auth/proto/spend_token_data.pb.h" +#include "absl/functional/bind_front.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/escaping.h" @@ -37,9 +38,8 @@ } // namespace -void BlindSignAuth::GetTokens( - absl::string_view oauth_token, int num_tokens, - std::function<void(absl::StatusOr<absl::Span<BlindSignToken>>)> callback) { +void BlindSignAuth::GetTokens(std::string oauth_token, int num_tokens, + SignedTokenCallback callback) { // Create GetInitialData RPC. privacy::ppn::GetInitialDataRequest request; request.set_use_attestation(false); @@ -50,22 +50,20 @@ // Call GetInitialData on the HttpFetcher. std::string path_and_query = "/v1/getInitialData"; std::string body = request.SerializeAsString(); - http_fetcher_->DoRequest( - path_and_query, oauth_token.data(), body, - [this, callback, oauth_token, - num_tokens](absl::StatusOr<BlindSignHttpResponse> response) { - GetInitialDataCallback(response, oauth_token, num_tokens, callback); - }); + BlindSignHttpCallback initial_data_callback = + absl::bind_front(&BlindSignAuth::GetInitialDataCallback, this, + oauth_token, num_tokens, std::move(callback)); + http_fetcher_->DoRequest(path_and_query, oauth_token, body, + std::move(initial_data_callback)); } void BlindSignAuth::GetInitialDataCallback( - absl::StatusOr<BlindSignHttpResponse> response, - absl::string_view oauth_token, int num_tokens, - std::function<void(absl::StatusOr<absl::Span<BlindSignToken>>)> callback) { + std::string oauth_token, int num_tokens, SignedTokenCallback callback, + absl::StatusOr<BlindSignHttpResponse> response) { if (!response.ok()) { QUICHE_LOG(WARNING) << "GetInitialDataRequest failed: " << response.status(); - callback(response.status()); + std::move(callback)(response.status()); return; } absl::StatusCode code = HttpCodeToStatusCode(response.value().status_code()); @@ -73,14 +71,15 @@ std::string message = absl::StrCat("GetInitialDataRequest failed with code: ", code); QUICHE_LOG(WARNING) << message; - callback(absl::Status(code, message)); + std::move(callback)(absl::Status(code, message)); return; } // Parse GetInitialDataResponse. privacy::ppn::GetInitialDataResponse initial_data_response; if (!initial_data_response.ParseFromString(response.value().body())) { QUICHE_LOG(WARNING) << "Failed to parse GetInitialDataResponse"; - callback(absl::InternalError("Failed to parse GetInitialDataResponse")); + std::move(callback)( + absl::InternalError("Failed to parse GetInitialDataResponse")); return; } absl::StatusOr<absl::Time> public_metadata_expiry_time = @@ -89,7 +88,7 @@ .public_metadata() .expiration()); if (!public_metadata_expiry_time.ok()) { - callback( + std::move(callback)( absl::InternalError("Failed to parse public metadata expiration time")); return; } @@ -101,7 +100,7 @@ if (!bssa_client.ok()) { QUICHE_LOG(WARNING) << "Failed to create AT BSSA client: " << bssa_client.status(); - callback(bssa_client.status()); + std::move(callback)(bssa_client.status()); return; } @@ -125,7 +124,7 @@ if (!fingerprint_status.ok()) { QUICHE_LOG(WARNING) << "Failed to fingerprint public metadata: " << fingerprint_status; - callback(fingerprint_status); + std::move(callback)(fingerprint_status); return; } uint64_t fingerprint_big_endian = QuicheEndian::HostToNet64(fingerprint); @@ -142,7 +141,7 @@ if (!at_sign_request.ok()) { QUICHE_LOG(WARNING) << "Failed to create AT Sign Request: " << at_sign_request.status(); - callback(at_sign_request.status()); + std::move(callback)(at_sign_request.status()); return; } @@ -162,38 +161,36 @@ privacy::ppn::PublicMetadataInfo public_metadata_info = initial_data_response.public_metadata_info(); - http_fetcher_->DoRequest( - "/v1/authWithHeaderCreds", oauth_token.data(), - sign_request.SerializeAsString(), - [this, at_sign_request, public_metadata_info, - expiry_time_ = public_metadata_expiry_time.value(), - bssa_client_ = bssa_client.value().get(), - callback](absl::StatusOr<BlindSignHttpResponse> response) { - AuthAndSignCallback(response, public_metadata_info, expiry_time_, - *at_sign_request, bssa_client_, callback); - }); + BlindSignHttpCallback auth_and_sign_callback = absl::bind_front( + &BlindSignAuth::AuthAndSignCallback, this, public_metadata_info, + public_metadata_expiry_time.value(), *at_sign_request, + *std::move(bssa_client), std::move(callback)); + http_fetcher_->DoRequest("/v1/authWithHeaderCreds", oauth_token.data(), + sign_request.SerializeAsString(), + std::move(auth_and_sign_callback)); } void BlindSignAuth::AuthAndSignCallback( - absl::StatusOr<BlindSignHttpResponse> response, privacy::ppn::PublicMetadataInfo public_metadata_info, absl::Time public_key_expiry_time, private_membership::anonymous_tokens::AnonymousTokensSignRequest at_sign_request, - private_membership::anonymous_tokens::AnonymousTokensRsaBssaClient* + std::unique_ptr< + private_membership::anonymous_tokens::AnonymousTokensRsaBssaClient> bssa_client, - std::function<void(absl::StatusOr<absl::Span<BlindSignToken>>)> callback) { + SignedTokenCallback callback, + absl::StatusOr<BlindSignHttpResponse> response) { // Validate response. if (!response.ok()) { QUICHE_LOG(WARNING) << "AuthAndSign failed: " << response.status(); - callback(response.status()); + std::move(callback)(response.status()); return; } absl::StatusCode code = HttpCodeToStatusCode(response.value().status_code()); if (code != absl::StatusCode::kOk) { std::string message = absl::StrCat("AuthAndSign failed with code: ", code); QUICHE_LOG(WARNING) << message; - callback(absl::Status(code, message)); + std::move(callback)(absl::Status(code, message)); return; } @@ -201,7 +198,8 @@ privacy::ppn::AuthAndSignResponse sign_response; if (!sign_response.ParseFromString(response.value().body())) { QUICHE_LOG(WARNING) << "Failed to parse AuthAndSignResponse"; - callback(absl::InternalError("Failed to parse AuthAndSignResponse")); + std::move(callback)( + absl::InternalError("Failed to parse AuthAndSignResponse")); return; } @@ -213,7 +211,7 @@ at_sign_request.blinded_tokens_size()) { QUICHE_LOG(WARNING) << "Response signature size does not equal request tokens size"; - callback(absl::InternalError( + std::move(callback)(absl::InternalError( "Response signature size does not equal request tokens size")); return; } @@ -224,7 +222,7 @@ if (!absl::Base64Unescape(sign_response.blinded_token_signature(i), &blinded_token)) { QUICHE_LOG(WARNING) << "Failed to unescape blinded token signature"; - callback( + std::move(callback)( absl::InternalError("Failed to unescape blinded token signature")); return; } @@ -246,14 +244,14 @@ if (!signed_tokens.ok()) { QUICHE_LOG(WARNING) << "AuthAndSign ProcessResponse failed: " << signed_tokens.status(); - callback(signed_tokens.status()); + std::move(callback)(signed_tokens.status()); return; } if (signed_tokens->size() != static_cast<size_t>(at_sign_response.anonymous_tokens_size())) { QUICHE_LOG(WARNING) << "ProcessResponse did not output the right number of signed tokens"; - callback(absl::InternalError( + std::move(callback)(absl::InternalError( "ProcessResponse did not output the right number of signed tokens")); return; } @@ -274,7 +272,7 @@ at_sign_response.anonymous_tokens(i).use_case()); if (!use_case.ok()) { QUICHE_LOG(WARNING) << "Failed to parse use case: " << use_case.status(); - callback(use_case.status()); + std::move(callback)(use_case.status()); return; } spend_token_data.set_use_case(*use_case); @@ -284,7 +282,7 @@ public_key_expiry_time}); } - callback(absl::Span<BlindSignToken>(tokens_vec)); + std::move(callback)(absl::Span<BlindSignToken>(tokens_vec)); } absl::Status BlindSignAuth::FingerprintPublicMetadata(
diff --git a/quiche/blind_sign_auth/blind_sign_auth.h b/quiche/blind_sign_auth/blind_sign_auth.h index 0c29136..cfedeca 100644 --- a/quiche/blind_sign_auth/blind_sign_auth.h +++ b/quiche/blind_sign_auth/blind_sign_auth.h
@@ -5,18 +5,13 @@ #ifndef QUICHE_BLIND_SIGN_AUTH_BLIND_SIGN_AUTH_H_ #define QUICHE_BLIND_SIGN_AUTH_BLIND_SIGN_AUTH_H_ -#include <functional> #include <memory> #include <string> -#include <utility> -#include <vector> #include "quiche/blind_sign_auth/proto/public_metadata.pb.h" #include "absl/status/status.h" #include "absl/status/statusor.h" -#include "absl/strings/string_view.h" #include "absl/time/time.h" -#include "absl/types/span.h" #include "quiche/blind_sign_auth/anonymous_tokens/cpp/client/anonymous_tokens_rsa_bssa_client.h" #include "quiche/blind_sign_auth/anonymous_tokens/proto/anonymous_tokens.pb.h" #include "quiche/blind_sign_auth/blind_sign_auth_interface.h" @@ -39,24 +34,23 @@ // The GetTokens callback will run on the same thread as the // BlindSignHttpInterface callbacks. // Callers can make multiple concurrent requests to GetTokens. - void GetTokens(absl::string_view oauth_token, int num_tokens, - std::function<void(absl::StatusOr<absl::Span<BlindSignToken>>)> - callback) override; + void GetTokens(std::string oauth_token, int num_tokens, + SignedTokenCallback callback) override; private: - void GetInitialDataCallback( - absl::StatusOr<BlindSignHttpResponse> response, - absl::string_view oauth_token, int num_tokens, - std::function<void(absl::StatusOr<absl::Span<BlindSignToken>>)> callback); + void GetInitialDataCallback(std::string oauth_token, int num_tokens, + SignedTokenCallback callback, + absl::StatusOr<BlindSignHttpResponse> response); void AuthAndSignCallback( - absl::StatusOr<BlindSignHttpResponse> response, privacy::ppn::PublicMetadataInfo public_metadata_info, absl::Time public_key_expiry_time, private_membership::anonymous_tokens::AnonymousTokensSignRequest at_sign_request, - private_membership::anonymous_tokens::AnonymousTokensRsaBssaClient* + std::unique_ptr< + private_membership::anonymous_tokens::AnonymousTokensRsaBssaClient> bssa_client, - std::function<void(absl::StatusOr<absl::Span<BlindSignToken>>)> callback); + SignedTokenCallback callback, + absl::StatusOr<BlindSignHttpResponse> response); absl::Status FingerprintPublicMetadata( const privacy::ppn::PublicMetadata& metadata, uint64_t* fingerprint); absl::StatusCode HttpCodeToStatusCode(int http_code);
diff --git a/quiche/blind_sign_auth/blind_sign_auth_interface.h b/quiche/blind_sign_auth/blind_sign_auth_interface.h index ac5b10a..e2f3bab 100644 --- a/quiche/blind_sign_auth/blind_sign_auth_interface.h +++ b/quiche/blind_sign_auth/blind_sign_auth_interface.h
@@ -5,14 +5,13 @@ #ifndef QUICHE_BLIND_SIGN_AUTH_BLIND_SIGN_AUTH_INTERFACE_H_ #define QUICHE_BLIND_SIGN_AUTH_BLIND_SIGN_AUTH_INTERFACE_H_ -#include <functional> #include <string> #include "absl/status/statusor.h" -#include "absl/strings/string_view.h" #include "absl/time/time.h" #include "absl/types/span.h" #include "quiche/common/platform/api/quiche_export.h" +#include "quiche/common/quiche_callbacks.h" namespace quiche { @@ -24,16 +23,17 @@ absl::Time expiration; }; +using SignedTokenCallback = + SingleUseCallback<void(absl::StatusOr<absl::Span<BlindSignToken>>)>; + // BlindSignAuth provides signed, unblinded tokens to callers. class QUICHE_EXPORT BlindSignAuthInterface { public: virtual ~BlindSignAuthInterface() = default; // Returns signed unblinded tokens in a callback. Tokens are single-use. - virtual void GetTokens( - absl::string_view oauth_token, int num_tokens, - std::function<void(absl::StatusOr<absl::Span<BlindSignToken>>)> - callback) = 0; + virtual void GetTokens(std::string oauth_token, int num_tokens, + SignedTokenCallback callback) = 0; }; } // namespace quiche
diff --git a/quiche/blind_sign_auth/blind_sign_auth_test.cc b/quiche/blind_sign_auth/blind_sign_auth_test.cc index ff1e969..6128750 100644 --- a/quiche/blind_sign_auth/blind_sign_auth_test.cc +++ b/quiche/blind_sign_auth/blind_sign_auth_test.cc
@@ -4,7 +4,6 @@ #include "quiche/blind_sign_auth/blind_sign_auth.h" -#include <functional> #include <memory> #include <string> #include <utility> @@ -37,7 +36,6 @@ using ::testing::Eq; using ::testing::InSequence; using ::testing::Invoke; -using ::testing::InvokeArgument; using ::testing::StartsWith; using ::testing::Unused; @@ -174,25 +172,25 @@ Eq(expected_get_initial_data_request_.SerializeAsString()), _)) .Times(1) - .WillOnce(InvokeArgument<3>(fake_public_key_response)); + .WillOnce([=](auto&&, auto&&, auto&&, auto get_initial_data_cb) { + std::move(get_initial_data_cb)(fake_public_key_response); + }); EXPECT_CALL(mock_http_interface_, DoRequest(Eq("/v1/authWithHeaderCreds"), Eq(oauth_token_), _, _)) .Times(1) - .WillOnce(Invoke( - [this](Unused, Unused, const std::string& body, - std::function<void(absl::StatusOr<BlindSignHttpResponse>)> - callback) { - CreateSignResponse(body); - BlindSignHttpResponse http_response( - 200, sign_response_.SerializeAsString()); - callback(http_response); - })); + .WillOnce(Invoke([this](Unused, Unused, const std::string& body, + BlindSignHttpCallback callback) { + CreateSignResponse(body); + BlindSignHttpResponse http_response( + 200, sign_response_.SerializeAsString()); + std::move(callback)(http_response); + })); } int num_tokens = 1; QuicheNotification done; - std::function<void(absl::StatusOr<absl::Span<BlindSignToken>>)> callback = + SignedTokenCallback callback = [this, &done, num_tokens](absl::StatusOr<absl::Span<BlindSignToken>> tokens) { QUICHE_EXPECT_OK(tokens); @@ -200,7 +198,7 @@ ValidateGetTokensOutput(*tokens); done.Notify(); }; - blind_sign_auth_->GetTokens(oauth_token_, num_tokens, callback); + blind_sign_auth_->GetTokens(oauth_token_, num_tokens, std::move(callback)); done.WaitForNotification(); } @@ -208,8 +206,10 @@ EXPECT_CALL(mock_http_interface_, DoRequest(Eq("/v1/getInitialData"), Eq(oauth_token_), _, _)) .Times(1) - .WillOnce( - InvokeArgument<3>(absl::InternalError("Failed to create socket"))); + .WillOnce([=](auto&&, auto&&, auto&&, auto get_initial_data_cb) { + std::move(get_initial_data_cb)( + absl::InternalError("Failed to create socket")); + }); EXPECT_CALL(mock_http_interface_, DoRequest(Eq("/v1/authWithHeaderCreds"), _, _, _)) @@ -217,12 +217,12 @@ int num_tokens = 1; QuicheNotification done; - std::function<void(absl::StatusOr<absl::Span<BlindSignToken>>)> callback = + SignedTokenCallback callback = [&done](absl::StatusOr<absl::Span<BlindSignToken>> tokens) { EXPECT_THAT(tokens.status().code(), absl::StatusCode::kInternal); done.Notify(); }; - blind_sign_auth_->GetTokens(oauth_token_, num_tokens, callback); + blind_sign_auth_->GetTokens(oauth_token_, num_tokens, std::move(callback)); done.WaitForNotification(); } @@ -238,7 +238,9 @@ DoRequest(Eq("/v1/getInitialData"), Eq(oauth_token_), Eq(expected_get_initial_data_request_.SerializeAsString()), _)) .Times(1) - .WillOnce(InvokeArgument<3>(fake_public_key_response)); + .WillOnce([=](auto&&, auto&&, auto&&, auto get_initial_data_cb) { + std::move(get_initial_data_cb)(fake_public_key_response); + }); EXPECT_CALL(mock_http_interface_, DoRequest(Eq("/v1/authWithHeaderCreds"), _, _, _)) @@ -246,12 +248,12 @@ int num_tokens = 1; QuicheNotification done; - std::function<void(absl::StatusOr<absl::Span<BlindSignToken>>)> callback = + SignedTokenCallback callback = [&done](absl::StatusOr<absl::Span<BlindSignToken>> tokens) { EXPECT_THAT(tokens.status().code(), absl::StatusCode::kInvalidArgument); done.Notify(); }; - blind_sign_auth_->GetTokens(oauth_token_, num_tokens, callback); + blind_sign_auth_->GetTokens(oauth_token_, num_tokens, std::move(callback)); done.WaitForNotification(); } @@ -267,32 +269,32 @@ Eq(expected_get_initial_data_request_.SerializeAsString()), _)) .Times(1) - .WillOnce(InvokeArgument<3>(fake_public_key_response)); + .WillOnce([=](auto&&, auto&&, auto&&, auto get_initial_data_cb) { + std::move(get_initial_data_cb)(fake_public_key_response); + }); EXPECT_CALL(mock_http_interface_, DoRequest(Eq("/v1/authWithHeaderCreds"), Eq(oauth_token_), _, _)) .Times(1) - .WillOnce(Invoke( - [this](Unused, Unused, const std::string& body, - std::function<void(absl::StatusOr<BlindSignHttpResponse>)> - callback) { - CreateSignResponse(body); - // Add an invalid signature that can't be Base64 decoded. - sign_response_.add_blinded_token_signature("invalid_signature%"); - BlindSignHttpResponse http_response( - 200, sign_response_.SerializeAsString()); - callback(http_response); - })); + .WillOnce(Invoke([this](Unused, Unused, const std::string& body, + BlindSignHttpCallback callback) { + CreateSignResponse(body); + // Add an invalid signature that can't be Base64 decoded. + sign_response_.add_blinded_token_signature("invalid_signature%"); + BlindSignHttpResponse http_response( + 200, sign_response_.SerializeAsString()); + std::move(callback)(http_response); + })); } int num_tokens = 1; QuicheNotification done; - std::function<void(absl::StatusOr<absl::Span<BlindSignToken>>)> callback = + SignedTokenCallback callback = [&done](absl::StatusOr<absl::Span<BlindSignToken>> tokens) { EXPECT_THAT(tokens.status().code(), absl::StatusCode::kInternal); done.Notify(); }; - blind_sign_auth_->GetTokens(oauth_token_, num_tokens, callback); + blind_sign_auth_->GetTokens(oauth_token_, num_tokens, std::move(callback)); done.WaitForNotification(); }
diff --git a/quiche/blind_sign_auth/blind_sign_http_interface.h b/quiche/blind_sign_auth/blind_sign_http_interface.h index d8111b4..09e92b2 100644 --- a/quiche/blind_sign_auth/blind_sign_http_interface.h +++ b/quiche/blind_sign_auth/blind_sign_http_interface.h
@@ -5,17 +5,18 @@ #ifndef QUICHE_BLIND_SIGN_AUTH_BLIND_SIGN_HTTP_INTERFACE_H_ #define QUICHE_BLIND_SIGN_AUTH_BLIND_SIGN_HTTP_INTERFACE_H_ -#include <functional> -#include <map> #include <string> -#include <vector> #include "absl/status/statusor.h" #include "quiche/blind_sign_auth/blind_sign_http_response.h" #include "quiche/common/platform/api/quiche_export.h" +#include "quiche/common/quiche_callbacks.h" namespace quiche { +using BlindSignHttpCallback = + quiche::SingleUseCallback<void(absl::StatusOr<BlindSignHttpResponse>)>; + // Interface for async HTTP POST requests in BlindSignAuth. // Implementers must send a request to a signer hostname, using the request's // arguments, and call the provided callback when a request is complete. @@ -31,10 +32,10 @@ // "application/x-protobuf". // DoRequest is async. When the request completes, the implementer must call // the provided callback. - virtual void DoRequest( - const std::string& path_and_query, - const std::string& authorization_header, const std::string& body, - std::function<void(absl::StatusOr<BlindSignHttpResponse>)> callback) = 0; + virtual void DoRequest(const std::string& path_and_query, + const std::string& authorization_header, + const std::string& body, + BlindSignHttpCallback callback) = 0; }; } // namespace quiche
diff --git a/quiche/blind_sign_auth/cached_blind_sign_auth.cc b/quiche/blind_sign_auth/cached_blind_sign_auth.cc index 7c9271d..d2d7e70 100644 --- a/quiche/blind_sign_auth/cached_blind_sign_auth.cc +++ b/quiche/blind_sign_auth/cached_blind_sign_auth.cc
@@ -7,6 +7,7 @@ #include <utility> #include <vector> +#include "absl/functional/bind_front.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_format.h" @@ -21,17 +22,16 @@ constexpr absl::Duration kFreshnessConstant = absl::Minutes(5); -void CachedBlindSignAuth::GetTokens( - absl::string_view oauth_token, int num_tokens, - std::function<void(absl::StatusOr<absl::Span<BlindSignToken>>)> callback) { +void CachedBlindSignAuth::GetTokens(std::string oauth_token, int num_tokens, + SignedTokenCallback callback) { if (num_tokens > max_tokens_per_request_) { - callback(absl::InvalidArgumentError( + std::move(callback)(absl::InvalidArgumentError( absl::StrFormat("Number of tokens requested exceeds maximum: %d", kBlindSignAuthRequestMaxTokens))); return; } if (num_tokens < 0) { - callback(absl::InvalidArgumentError(absl::StrFormat( + std::move(callback)(absl::InvalidArgumentError(absl::StrFormat( "Negative number of tokens requested: %d", num_tokens))); return; } @@ -48,28 +48,25 @@ } if (!output_tokens.empty() || num_tokens == 0) { - callback(absl::MakeSpan(output_tokens)); + std::move(callback)(absl::MakeSpan(output_tokens)); return; } // Make a GetTokensRequest if the cache can't handle the request size. - std::function<void(absl::StatusOr<absl::Span<BlindSignToken>>)> - caching_callback = - [this, num_tokens, - callback](absl::StatusOr<absl::Span<BlindSignToken>> tokens) { - HandleGetTokensResponse(tokens, num_tokens, callback); - }; + SignedTokenCallback caching_callback = + absl::bind_front(&CachedBlindSignAuth::HandleGetTokensResponse, this, + std::move(callback), num_tokens); blind_sign_auth_->GetTokens(oauth_token, kBlindSignAuthRequestMaxTokens, - caching_callback); + std::move(caching_callback)); } void CachedBlindSignAuth::HandleGetTokensResponse( - absl::StatusOr<absl::Span<BlindSignToken>> tokens, int num_tokens, - std::function<void(absl::StatusOr<absl::Span<BlindSignToken>>)> callback) { + SignedTokenCallback callback, int num_tokens, + absl::StatusOr<absl::Span<BlindSignToken>> tokens) { if (!tokens.ok()) { QUICHE_LOG(WARNING) << "BlindSignAuth::GetTokens failed: " << tokens.status(); - callback(tokens); + std::move(callback)(tokens); return; } if (tokens->size() < static_cast<size_t>(num_tokens) || @@ -96,10 +93,10 @@ } if (!output_tokens.empty()) { - callback(absl::MakeSpan(output_tokens)); + std::move(callback)(absl::MakeSpan(output_tokens)); return; } - callback(absl::ResourceExhaustedError(absl::StrFormat( + std::move(callback)(absl::ResourceExhaustedError(absl::StrFormat( "Requested %d tokens, cache only has %d after GetTokensRequest", num_tokens, cache_size))); }
diff --git a/quiche/blind_sign_auth/cached_blind_sign_auth.h b/quiche/blind_sign_auth/cached_blind_sign_auth.h index fcfde2c..30ca72d 100644 --- a/quiche/blind_sign_auth/cached_blind_sign_auth.h +++ b/quiche/blind_sign_auth/cached_blind_sign_auth.h
@@ -5,13 +5,10 @@ #ifndef QUICHE_BLIND_SIGN_AUTH_CACHED_BLIND_SIGN_AUTH_H_ #define QUICHE_BLIND_SIGN_AUTH_CACHED_BLIND_SIGN_AUTH_H_ -#include <cstddef> -#include <functional> #include <string> #include <vector> #include "absl/status/statusor.h" -#include "absl/strings/string_view.h" #include "absl/types/span.h" #include "quiche/blind_sign_auth/blind_sign_auth_interface.h" #include "quiche/common/platform/api/quiche_export.h" @@ -42,9 +39,8 @@ // or asynchronously on BlindSignAuth's BlindSignHttpInterface thread. // The GetTokens callback must not acquire any locks that the calling thread // owns, otherwise the callback will deadlock. - void GetTokens(absl::string_view oauth_token, int num_tokens, - std::function<void(absl::StatusOr<absl::Span<BlindSignToken>>)> - callback) override; + void GetTokens(std::string oauth_token, int num_tokens, + SignedTokenCallback callback) override; // Removes all tokens in the cache. void ClearCache() { @@ -54,8 +50,8 @@ private: void HandleGetTokensResponse( - absl::StatusOr<absl::Span<BlindSignToken>> tokens, int num_tokens, - std::function<void(absl::StatusOr<absl::Span<BlindSignToken>>)> callback); + SignedTokenCallback callback, int num_tokens, + absl::StatusOr<absl::Span<BlindSignToken>> tokens); std::vector<BlindSignToken> CreateOutputTokens(int num_tokens) QUICHE_EXCLUSIVE_LOCKS_REQUIRED(mutex_); void RemoveExpiredTokens() QUICHE_EXCLUSIVE_LOCKS_REQUIRED(mutex_);
diff --git a/quiche/blind_sign_auth/cached_blind_sign_auth_test.cc b/quiche/blind_sign_auth/cached_blind_sign_auth_test.cc index 7e39b98..b3ee6e0 100644 --- a/quiche/blind_sign_auth/cached_blind_sign_auth_test.cc +++ b/quiche/blind_sign_auth/cached_blind_sign_auth_test.cc
@@ -4,7 +4,6 @@ #include "quiche/blind_sign_auth/cached_blind_sign_auth.h" -#include <functional> #include <memory> #include <string> #include <utility> @@ -71,17 +70,14 @@ EXPECT_CALL(mock_blind_sign_auth_interface_, GetTokens(oauth_token_, kBlindSignAuthRequestMaxTokens, _)) .Times(1) - .WillOnce( - [this](Unused, int num_tokens, - std::function<void(absl::StatusOr<absl::Span<BlindSignToken>>)> - callback) { - fake_tokens_ = MakeFakeTokens(num_tokens); - callback(absl::MakeSpan(fake_tokens_)); - }); + .WillOnce([this](Unused, int num_tokens, SignedTokenCallback callback) { + fake_tokens_ = MakeFakeTokens(num_tokens); + std::move(callback)(absl::MakeSpan(fake_tokens_)); + }); int num_tokens = 5; QuicheNotification done; - std::function<void(absl::StatusOr<absl::Span<BlindSignToken>>)> callback = + SignedTokenCallback callback = [num_tokens, &done](absl::StatusOr<absl::Span<BlindSignToken>> tokens) { QUICHE_EXPECT_OK(tokens); EXPECT_EQ(num_tokens, tokens->size()); @@ -91,7 +87,8 @@ done.Notify(); }; - cached_blind_sign_auth_->GetTokens(oauth_token_, num_tokens, callback); + cached_blind_sign_auth_->GetTokens(oauth_token_, num_tokens, + std::move(callback)); done.WaitForNotification(); } @@ -100,18 +97,15 @@ GetTokens(oauth_token_, kBlindSignAuthRequestMaxTokens, _)) .Times(2) .WillRepeatedly( - [this](Unused, int num_tokens, - std::function<void(absl::StatusOr<absl::Span<BlindSignToken>>)> - callback) { + [this](Unused, int num_tokens, SignedTokenCallback callback) { fake_tokens_ = MakeFakeTokens(num_tokens); - callback(absl::MakeSpan(fake_tokens_)); + std::move(callback)(absl::MakeSpan(fake_tokens_)); }); int num_tokens = kBlindSignAuthRequestMaxTokens - 1; QuicheNotification first; - std::function<void(absl::StatusOr<absl::Span<BlindSignToken>>)> - first_callback = [num_tokens, &first]( - absl::StatusOr<absl::Span<BlindSignToken>> tokens) { + SignedTokenCallback first_callback = + [num_tokens, &first](absl::StatusOr<absl::Span<BlindSignToken>> tokens) { QUICHE_EXPECT_OK(tokens); EXPECT_EQ(num_tokens, tokens->size()); for (int i = 0; i < num_tokens; i++) { @@ -120,13 +114,13 @@ first.Notify(); }; - cached_blind_sign_auth_->GetTokens(oauth_token_, num_tokens, first_callback); + cached_blind_sign_auth_->GetTokens(oauth_token_, num_tokens, + std::move(first_callback)); first.WaitForNotification(); QuicheNotification second; - std::function<void(absl::StatusOr<absl::Span<BlindSignToken>>)> - second_callback = [num_tokens, &second]( - absl::StatusOr<absl::Span<BlindSignToken>> tokens) { + SignedTokenCallback second_callback = + [num_tokens, &second](absl::StatusOr<absl::Span<BlindSignToken>> tokens) { QUICHE_EXPECT_OK(tokens); EXPECT_EQ(num_tokens, tokens->size()); EXPECT_EQ(tokens->at(0).token, @@ -137,7 +131,8 @@ second.Notify(); }; - cached_blind_sign_auth_->GetTokens(oauth_token_, num_tokens, second_callback); + cached_blind_sign_auth_->GetTokens(oauth_token_, num_tokens, + std::move(second_callback)); second.WaitForNotification(); } @@ -145,19 +140,15 @@ EXPECT_CALL(mock_blind_sign_auth_interface_, GetTokens(oauth_token_, kBlindSignAuthRequestMaxTokens, _)) .Times(1) - .WillOnce( - [this](Unused, int num_tokens, - std::function<void(absl::StatusOr<absl::Span<BlindSignToken>>)> - callback) { - fake_tokens_ = MakeFakeTokens(num_tokens); - callback(absl::MakeSpan(fake_tokens_)); - }); + .WillOnce([this](Unused, int num_tokens, SignedTokenCallback callback) { + fake_tokens_ = MakeFakeTokens(num_tokens); + std::move(callback)(absl::MakeSpan(fake_tokens_)); + }); int num_tokens = kBlindSignAuthRequestMaxTokens / 2; QuicheNotification first; - std::function<void(absl::StatusOr<absl::Span<BlindSignToken>>)> - first_callback = [num_tokens, &first]( - absl::StatusOr<absl::Span<BlindSignToken>> tokens) { + SignedTokenCallback first_callback = + [num_tokens, &first](absl::StatusOr<absl::Span<BlindSignToken>> tokens) { QUICHE_EXPECT_OK(tokens); EXPECT_EQ(num_tokens, tokens->size()); for (int i = 0; i < num_tokens; i++) { @@ -166,13 +157,13 @@ first.Notify(); }; - cached_blind_sign_auth_->GetTokens(oauth_token_, num_tokens, first_callback); + cached_blind_sign_auth_->GetTokens(oauth_token_, num_tokens, + std::move(first_callback)); first.WaitForNotification(); QuicheNotification second; - std::function<void(absl::StatusOr<absl::Span<BlindSignToken>>)> - second_callback = [num_tokens, &second]( - absl::StatusOr<absl::Span<BlindSignToken>> tokens) { + SignedTokenCallback second_callback = + [num_tokens, &second](absl::StatusOr<absl::Span<BlindSignToken>> tokens) { QUICHE_EXPECT_OK(tokens); EXPECT_EQ(num_tokens, tokens->size()); for (int i = 0; i < num_tokens; i++) { @@ -182,7 +173,8 @@ second.Notify(); }; - cached_blind_sign_auth_->GetTokens(oauth_token_, num_tokens, second_callback); + cached_blind_sign_auth_->GetTokens(oauth_token_, num_tokens, + std::move(second_callback)); second.WaitForNotification(); } @@ -191,18 +183,15 @@ GetTokens(oauth_token_, kBlindSignAuthRequestMaxTokens, _)) .Times(2) .WillRepeatedly( - [this](Unused, int num_tokens, - std::function<void(absl::StatusOr<absl::Span<BlindSignToken>>)> - callback) { + [this](Unused, int num_tokens, SignedTokenCallback callback) { fake_tokens_ = MakeFakeTokens(num_tokens); - callback(absl::MakeSpan(fake_tokens_)); + std::move(callback)(absl::MakeSpan(fake_tokens_)); }); int num_tokens = kBlindSignAuthRequestMaxTokens / 2; QuicheNotification first; - std::function<void(absl::StatusOr<absl::Span<BlindSignToken>>)> - first_callback = [num_tokens, &first]( - absl::StatusOr<absl::Span<BlindSignToken>> tokens) { + SignedTokenCallback first_callback = + [num_tokens, &first](absl::StatusOr<absl::Span<BlindSignToken>> tokens) { QUICHE_EXPECT_OK(tokens); EXPECT_EQ(num_tokens, tokens->size()); for (int i = 0; i < num_tokens; i++) { @@ -211,13 +200,13 @@ first.Notify(); }; - cached_blind_sign_auth_->GetTokens(oauth_token_, num_tokens, first_callback); + cached_blind_sign_auth_->GetTokens(oauth_token_, num_tokens, + std::move(first_callback)); first.WaitForNotification(); QuicheNotification second; - std::function<void(absl::StatusOr<absl::Span<BlindSignToken>>)> - second_callback = [num_tokens, &second]( - absl::StatusOr<absl::Span<BlindSignToken>> tokens) { + SignedTokenCallback second_callback = + [num_tokens, &second](absl::StatusOr<absl::Span<BlindSignToken>> tokens) { QUICHE_EXPECT_OK(tokens); EXPECT_EQ(num_tokens, tokens->size()); for (int i = 0; i < num_tokens; i++) { @@ -227,14 +216,15 @@ second.Notify(); }; - cached_blind_sign_auth_->GetTokens(oauth_token_, num_tokens, second_callback); + cached_blind_sign_auth_->GetTokens(oauth_token_, num_tokens, + std::move(second_callback)); second.WaitForNotification(); QuicheNotification third; int third_request_tokens = 10; - std::function<void(absl::StatusOr<absl::Span<BlindSignToken>>)> - third_callback = [third_request_tokens, &third]( - absl::StatusOr<absl::Span<BlindSignToken>> tokens) { + SignedTokenCallback third_callback = + [third_request_tokens, + &third](absl::StatusOr<absl::Span<BlindSignToken>> tokens) { QUICHE_EXPECT_OK(tokens); EXPECT_EQ(third_request_tokens, tokens->size()); for (int i = 0; i < third_request_tokens; i++) { @@ -244,7 +234,7 @@ }; cached_blind_sign_auth_->GetTokens(oauth_token_, third_request_tokens, - third_callback); + std::move(third_callback)); third.WaitForNotification(); } @@ -254,7 +244,7 @@ .Times(0); int num_tokens = kBlindSignAuthRequestMaxTokens + 1; - std::function<void(absl::StatusOr<absl::Span<BlindSignToken>>)> callback = + SignedTokenCallback callback = [](absl::StatusOr<absl::Span<BlindSignToken>> tokens) { EXPECT_THAT(tokens.status().code(), absl::StatusCode::kInvalidArgument); EXPECT_THAT( @@ -263,7 +253,8 @@ kBlindSignAuthRequestMaxTokens)); }; - cached_blind_sign_auth_->GetTokens(oauth_token_, num_tokens, callback); + cached_blind_sign_auth_->GetTokens(oauth_token_, num_tokens, + std::move(callback)); } TEST_F(CachedBlindSignAuthTest, TestGetTokensRequestNegative) { @@ -272,7 +263,7 @@ .Times(0); int num_tokens = -1; - std::function<void(absl::StatusOr<absl::Span<BlindSignToken>>)> callback = + SignedTokenCallback callback = [num_tokens](absl::StatusOr<absl::Span<BlindSignToken>> tokens) { EXPECT_THAT(tokens.status().code(), absl::StatusCode::kInvalidArgument); EXPECT_THAT(tokens.status().message(), @@ -280,46 +271,46 @@ num_tokens)); }; - cached_blind_sign_auth_->GetTokens(oauth_token_, num_tokens, callback); + cached_blind_sign_auth_->GetTokens(oauth_token_, num_tokens, + std::move(callback)); } TEST_F(CachedBlindSignAuthTest, TestHandleGetTokensResponseErrorHandling) { EXPECT_CALL(mock_blind_sign_auth_interface_, GetTokens(oauth_token_, kBlindSignAuthRequestMaxTokens, _)) .Times(2) - .WillOnce(InvokeArgument<2>(absl::InternalError("AuthAndSign failed"))) - .WillOnce( - [this](Unused, int num_tokens, - std::function<void(absl::StatusOr<absl::Span<BlindSignToken>>)> - callback) { - fake_tokens_ = MakeFakeTokens(num_tokens); - fake_tokens_.pop_back(); - callback(absl::MakeSpan(fake_tokens_)); - }); + .WillOnce([](Unused, int num_tokens, SignedTokenCallback callback) { + std::move(callback)(absl::InternalError("AuthAndSign failed")); + }) + .WillOnce([this](Unused, int num_tokens, SignedTokenCallback callback) { + fake_tokens_ = MakeFakeTokens(num_tokens); + fake_tokens_.pop_back(); + std::move(callback)(absl::MakeSpan(fake_tokens_)); + }); int num_tokens = kBlindSignAuthRequestMaxTokens; QuicheNotification first; - std::function<void(absl::StatusOr<absl::Span<BlindSignToken>>)> - first_callback = - [&first](absl::StatusOr<absl::Span<BlindSignToken>> tokens) { - EXPECT_THAT(tokens.status().code(), absl::StatusCode::kInternal); - EXPECT_THAT(tokens.status().message(), "AuthAndSign failed"); - first.Notify(); - }; + SignedTokenCallback first_callback = + [&first](absl::StatusOr<absl::Span<BlindSignToken>> tokens) { + EXPECT_THAT(tokens.status().code(), absl::StatusCode::kInternal); + EXPECT_THAT(tokens.status().message(), "AuthAndSign failed"); + first.Notify(); + }; - cached_blind_sign_auth_->GetTokens(oauth_token_, num_tokens, first_callback); + cached_blind_sign_auth_->GetTokens(oauth_token_, num_tokens, + std::move(first_callback)); first.WaitForNotification(); QuicheNotification second; - std::function<void(absl::StatusOr<absl::Span<BlindSignToken>>)> - second_callback = - [&second](absl::StatusOr<absl::Span<BlindSignToken>> tokens) { - EXPECT_THAT(tokens.status().code(), - absl::StatusCode::kResourceExhausted); - second.Notify(); - }; + SignedTokenCallback second_callback = + [&second](absl::StatusOr<absl::Span<BlindSignToken>> tokens) { + EXPECT_THAT(tokens.status().code(), + absl::StatusCode::kResourceExhausted); + second.Notify(); + }; - cached_blind_sign_auth_->GetTokens(oauth_token_, num_tokens, second_callback); + cached_blind_sign_auth_->GetTokens(oauth_token_, num_tokens, + std::move(second_callback)); second.WaitForNotification(); } @@ -329,38 +320,36 @@ .Times(0); int num_tokens = 0; - std::function<void(absl::StatusOr<absl::Span<BlindSignToken>>)> callback = + SignedTokenCallback callback = [](absl::StatusOr<absl::Span<BlindSignToken>> tokens) { QUICHE_EXPECT_OK(tokens); EXPECT_EQ(tokens->size(), 0); }; - cached_blind_sign_auth_->GetTokens(oauth_token_, num_tokens, callback); + cached_blind_sign_auth_->GetTokens(oauth_token_, num_tokens, + std::move(callback)); } TEST_F(CachedBlindSignAuthTest, TestExpiredTokensArePruned) { EXPECT_CALL(mock_blind_sign_auth_interface_, GetTokens(oauth_token_, kBlindSignAuthRequestMaxTokens, _)) .Times(1) - .WillOnce( - [this](Unused, int num_tokens, - std::function<void(absl::StatusOr<absl::Span<BlindSignToken>>)> - callback) { - fake_tokens_ = MakeExpiredTokens(num_tokens); - callback(absl::MakeSpan(fake_tokens_)); - }); + .WillOnce([this](Unused, int num_tokens, SignedTokenCallback callback) { + fake_tokens_ = MakeExpiredTokens(num_tokens); + std::move(callback)(absl::MakeSpan(fake_tokens_)); + }); int num_tokens = kBlindSignAuthRequestMaxTokens; QuicheNotification first; - std::function<void(absl::StatusOr<absl::Span<BlindSignToken>>)> - first_callback = - [&first](absl::StatusOr<absl::Span<BlindSignToken>> tokens) { - EXPECT_THAT(tokens.status().code(), - absl::StatusCode::kResourceExhausted); - first.Notify(); - }; + SignedTokenCallback first_callback = + [&first](absl::StatusOr<absl::Span<BlindSignToken>> tokens) { + EXPECT_THAT(tokens.status().code(), + absl::StatusCode::kResourceExhausted); + first.Notify(); + }; - cached_blind_sign_auth_->GetTokens(oauth_token_, num_tokens, first_callback); + cached_blind_sign_auth_->GetTokens(oauth_token_, num_tokens, + std::move(first_callback)); first.WaitForNotification(); } @@ -369,38 +358,36 @@ GetTokens(oauth_token_, kBlindSignAuthRequestMaxTokens, _)) .Times(2) .WillRepeatedly( - [this](Unused, int num_tokens, - std::function<void(absl::StatusOr<absl::Span<BlindSignToken>>)> - callback) { + [this](Unused, int num_tokens, SignedTokenCallback callback) { fake_tokens_ = MakeExpiredTokens(num_tokens); - callback(absl::MakeSpan(fake_tokens_)); + std::move(callback)(absl::MakeSpan(fake_tokens_)); }); int num_tokens = kBlindSignAuthRequestMaxTokens / 2; QuicheNotification first; - std::function<void(absl::StatusOr<absl::Span<BlindSignToken>>)> - first_callback = - [&first](absl::StatusOr<absl::Span<BlindSignToken>> tokens) { - EXPECT_THAT(tokens.status().code(), - absl::StatusCode::kResourceExhausted); - first.Notify(); - }; + SignedTokenCallback first_callback = + [&first](absl::StatusOr<absl::Span<BlindSignToken>> tokens) { + EXPECT_THAT(tokens.status().code(), + absl::StatusCode::kResourceExhausted); + first.Notify(); + }; - cached_blind_sign_auth_->GetTokens(oauth_token_, num_tokens, first_callback); + cached_blind_sign_auth_->GetTokens(oauth_token_, num_tokens, + std::move(first_callback)); first.WaitForNotification(); cached_blind_sign_auth_->ClearCache(); QuicheNotification second; - std::function<void(absl::StatusOr<absl::Span<BlindSignToken>>)> - second_callback = - [&second](absl::StatusOr<absl::Span<BlindSignToken>> tokens) { - EXPECT_THAT(tokens.status().code(), - absl::StatusCode::kResourceExhausted); - second.Notify(); - }; + SignedTokenCallback second_callback = + [&second](absl::StatusOr<absl::Span<BlindSignToken>> tokens) { + EXPECT_THAT(tokens.status().code(), + absl::StatusCode::kResourceExhausted); + second.Notify(); + }; - cached_blind_sign_auth_->GetTokens(oauth_token_, num_tokens, second_callback); + cached_blind_sign_auth_->GetTokens(oauth_token_, num_tokens, + std::move(second_callback)); second.WaitForNotification(); }
diff --git a/quiche/blind_sign_auth/test_tools/mock_blind_sign_auth_interface.h b/quiche/blind_sign_auth/test_tools/mock_blind_sign_auth_interface.h index 378e8bb..19d740e 100644 --- a/quiche/blind_sign_auth/test_tools/mock_blind_sign_auth_interface.h +++ b/quiche/blind_sign_auth/test_tools/mock_blind_sign_auth_interface.h
@@ -5,12 +5,8 @@ #ifndef QUICHE_BLIND_SIGN_AUTH_TEST_TOOLS_MOCK_BLIND_SIGN_AUTH_INTERFACE_H_ #define QUICHE_BLIND_SIGN_AUTH_TEST_TOOLS_MOCK_BLIND_SIGN_AUTH_INTERFACE_H_ -#include <functional> #include <string> -#include "absl/status/statusor.h" -#include "absl/strings/string_view.h" -#include "absl/types/span.h" #include "quiche/blind_sign_auth/blind_sign_auth_interface.h" #include "quiche/common/platform/api/quiche_export.h" #include "quiche/common/platform/api/quiche_test.h" @@ -21,9 +17,8 @@ : public BlindSignAuthInterface { public: MOCK_METHOD(void, GetTokens, - (absl::string_view oauth_token, int num_tokens, - std::function<void(absl::StatusOr<absl::Span<BlindSignToken>>)> - callback), + (std::string oauth_token, int num_tokens, + SignedTokenCallback callback), (override)); };
diff --git a/quiche/blind_sign_auth/test_tools/mock_blind_sign_http_interface.h b/quiche/blind_sign_auth/test_tools/mock_blind_sign_http_interface.h index 15e970b..1b4f729 100644 --- a/quiche/blind_sign_auth/test_tools/mock_blind_sign_http_interface.h +++ b/quiche/blind_sign_auth/test_tools/mock_blind_sign_http_interface.h
@@ -5,12 +5,9 @@ #ifndef QUICHE_BLIND_SIGN_AUTH_TEST_TOOLS_MOCK_BLIND_SIGN_HTTP_INTERFACE_H_ #define QUICHE_BLIND_SIGN_AUTH_TEST_TOOLS_MOCK_BLIND_SIGN_HTTP_INTERFACE_H_ -#include <functional> #include <string> -#include "absl/status/statusor.h" #include "quiche/blind_sign_auth/blind_sign_http_interface.h" -#include "quiche/blind_sign_auth/blind_sign_http_response.h" #include "quiche/common/platform/api/quiche_export.h" #include "quiche/common/platform/api/quiche_test.h" @@ -19,12 +16,11 @@ class QUICHE_NO_EXPORT MockBlindSignHttpInterface : public BlindSignHttpInterface { public: - MOCK_METHOD( - void, DoRequest, - (const std::string& path_and_query, - const std::string& authorization_header, const std::string& body, - std::function<void(absl::StatusOr<BlindSignHttpResponse>)> callback), - (override)); + MOCK_METHOD(void, DoRequest, + (const std::string& path_and_query, + const std::string& authorization_header, const std::string& body, + BlindSignHttpCallback callback), + (override)); }; } // namespace quiche::test