Add token expiration time checks to CachedBlindSignAuth to ensure that it only returns fresh tokens. Token expiration time is based on the AT server's signing public key. PiperOrigin-RevId: 532797040
diff --git a/quiche/blind_sign_auth/blind_sign_auth.cc b/quiche/blind_sign_auth/blind_sign_auth.cc index 8715eac..b1d8ae5 100644 --- a/quiche/blind_sign_auth/blind_sign_auth.cc +++ b/quiche/blind_sign_auth/blind_sign_auth.cc
@@ -7,6 +7,7 @@ #include <cstddef> #include <functional> #include <string> +#include <utility> #include <vector> #include "quiche/blind_sign_auth/proto/auth_and_sign.pb.h" @@ -38,8 +39,7 @@ void BlindSignAuth::GetTokens( absl::string_view oauth_token, int num_tokens, - std::function<void(absl::StatusOr<absl::Span<const std::string>>)> - callback) { + std::function<void(absl::StatusOr<absl::Span<BlindSignToken>>)> callback) { // Create GetInitialData RPC. privacy::ppn::GetInitialDataRequest request; request.set_use_attestation(false); @@ -61,7 +61,7 @@ void BlindSignAuth::GetInitialDataCallback( absl::StatusOr<BlindSignHttpResponse> response, absl::string_view oauth_token, int num_tokens, - std::function<void(absl::StatusOr<absl::Span<std::string>>)> callback) { + std::function<void(absl::StatusOr<absl::Span<BlindSignToken>>)> callback) { if (!response.ok()) { QUICHE_LOG(WARNING) << "GetInitialDataRequest failed: " << response.status(); @@ -93,6 +93,14 @@ callback(bssa_client.status()); return; } + absl::StatusOr<absl::Time> public_key_expiry_time = + private_membership::anonymous_tokens::TimeFromProto( + initial_data_response.at_public_metadata_public_key() + .expiration_time()); + if (!public_key_expiry_time.ok()) { + callback(absl::InternalError("Failed to parse public key expiration time")); + return; + } // Create plaintext tokens. // Client blinds plaintext tokens (random 32-byte strings) in CreateRequest. @@ -155,21 +163,23 @@ "/v1/authWithHeaderCreds", oauth_token.data(), sign_request.SerializeAsString(), [this, at_sign_request, public_metadata_info, + expiry_time_ = public_key_expiry_time.value(), bssa_client_ = bssa_client.value().get(), callback](absl::StatusOr<BlindSignHttpResponse> response) { - AuthAndSignCallback(response, public_metadata_info, *at_sign_request, - bssa_client_, callback); + AuthAndSignCallback(response, public_metadata_info, expiry_time_, + *at_sign_request, bssa_client_, callback); }); } void BlindSignAuth::AuthAndSignCallback( absl::StatusOr<BlindSignHttpResponse> response, privacy::ppn::PublicMetadataInfo public_metadata_info, + absl::Time public_key_expiry_time, private_membership::anonymous_tokens::AnonymousTokensSignRequest at_sign_request, private_membership::anonymous_tokens::AnonymousTokensRsaBssaClient* bssa_client, - std::function<void(absl::StatusOr<absl::Span<std::string>>)> callback) { + std::function<void(absl::StatusOr<absl::Span<BlindSignToken>>)> callback) { // Validate response. if (!response.ok()) { QUICHE_LOG(WARNING) << "AuthAndSign failed: " << response.status(); @@ -245,7 +255,7 @@ } // Output SpendTokenData with data for the redeemer to make a SpendToken RPC. - std::vector<std::string> tokens_vec; + std::vector<BlindSignToken> tokens_vec; for (size_t i = 0; i < signed_tokens->size(); i++) { privacy::ppn::SpendTokenData spend_token_data; *spend_token_data.mutable_public_metadata() = @@ -266,10 +276,11 @@ spend_token_data.set_use_case(*use_case); spend_token_data.set_message_mask( signed_tokens->at(i).token().message_mask()); - tokens_vec.push_back(spend_token_data.SerializeAsString()); + tokens_vec.push_back(BlindSignToken{spend_token_data.SerializeAsString(), + public_key_expiry_time}); } - callback(absl::Span<std::string>(tokens_vec)); + callback(absl::Span<BlindSignToken>(tokens_vec)); } absl::Status BlindSignAuth::FingerprintPublicMetadata(
diff --git a/quiche/blind_sign_auth/blind_sign_auth.h b/quiche/blind_sign_auth/blind_sign_auth.h index 5a9384b..710659f 100644 --- a/quiche/blind_sign_auth/blind_sign_auth.h +++ b/quiche/blind_sign_auth/blind_sign_auth.h
@@ -8,11 +8,13 @@ #include <functional> #include <memory> #include <string> +#include <utility> #include <vector> #include "quiche/blind_sign_auth/proto/public_metadata.pb.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" +#include "absl/time/time.h" #include "absl/types/span.h" #include "quiche/blind_sign_auth/anonymous_tokens/cpp/client/anonymous_tokens_rsa_bssa_client.h" #include "quiche/blind_sign_auth/anonymous_tokens/proto/anonymous_tokens.pb.h" @@ -29,30 +31,31 @@ explicit BlindSignAuth(BlindSignHttpInterface* http_fetcher) : http_fetcher_(http_fetcher) {} - // Returns signed unblinded tokens in a callback. Tokens are single-use. + // Returns signed unblinded tokens and their expiration time in a callback. + // Tokens are single-use. // GetTokens starts asynchronous HTTP POST requests to a signer hostname // specified by the caller, with path and query params given in the request. // The GetTokens callback will run on the same thread as the // BlindSignHttpInterface callbacks. // Callers can make multiple concurrent requests to GetTokens. - void GetTokens( - absl::string_view oauth_token, int num_tokens, - std::function<void(absl::StatusOr<absl::Span<const std::string>>)> - callback) override; + void GetTokens(absl::string_view oauth_token, int num_tokens, + std::function<void(absl::StatusOr<absl::Span<BlindSignToken>>)> + callback) override; private: void GetInitialDataCallback( absl::StatusOr<BlindSignHttpResponse> response, absl::string_view oauth_token, int num_tokens, - std::function<void(absl::StatusOr<absl::Span<std::string>>)> callback); + std::function<void(absl::StatusOr<absl::Span<BlindSignToken>>)> callback); void AuthAndSignCallback( absl::StatusOr<BlindSignHttpResponse> response, privacy::ppn::PublicMetadataInfo public_metadata_info, + absl::Time public_key_expiry_time, private_membership::anonymous_tokens::AnonymousTokensSignRequest at_sign_request, private_membership::anonymous_tokens::AnonymousTokensRsaBssaClient* bssa_client, - std::function<void(absl::StatusOr<absl::Span<std::string>>)> callback); + std::function<void(absl::StatusOr<absl::Span<BlindSignToken>>)> callback); absl::Status FingerprintPublicMetadata( const privacy::ppn::PublicMetadata& metadata, uint64_t* fingerprint);
diff --git a/quiche/blind_sign_auth/blind_sign_auth_interface.h b/quiche/blind_sign_auth/blind_sign_auth_interface.h index f7e3905..ac5b10a 100644 --- a/quiche/blind_sign_auth/blind_sign_auth_interface.h +++ b/quiche/blind_sign_auth/blind_sign_auth_interface.h
@@ -10,11 +10,20 @@ #include "absl/status/statusor.h" #include "absl/strings/string_view.h" +#include "absl/time/time.h" #include "absl/types/span.h" #include "quiche/common/platform/api/quiche_export.h" namespace quiche { +// A BlindSignToken is used to authenticate a request to a privacy proxy. +// The token string contains a serialized SpendTokenData proto. +// The token cannot be successfully redeemed after the expiration time. +struct QUICHE_EXPORT BlindSignToken { + std::string token; + absl::Time expiration; +}; + // BlindSignAuth provides signed, unblinded tokens to callers. class QUICHE_EXPORT BlindSignAuthInterface { public: @@ -23,7 +32,7 @@ // Returns signed unblinded tokens in a callback. Tokens are single-use. virtual void GetTokens( absl::string_view oauth_token, int num_tokens, - std::function<void(absl::StatusOr<absl::Span<const std::string>>)> + std::function<void(absl::StatusOr<absl::Span<BlindSignToken>>)> callback) = 0; };
diff --git a/quiche/blind_sign_auth/blind_sign_auth_test.cc b/quiche/blind_sign_auth/blind_sign_auth_test.cc index bb9d8da..3a5d842 100644 --- a/quiche/blind_sign_auth/blind_sign_auth_test.cc +++ b/quiche/blind_sign_auth/blind_sign_auth_test.cc
@@ -129,10 +129,10 @@ sign_response_ = response; } - void ValidateGetTokensOutput(const absl::Span<const std::string>& tokens) { + void ValidateGetTokensOutput(const absl::Span<BlindSignToken>& tokens) { for (const auto& token : tokens) { privacy::ppn::SpendTokenData spend_token_data; - ASSERT_TRUE(spend_token_data.ParseFromString(token)); + ASSERT_TRUE(spend_token_data.ParseFromString(token.token)); // Validate token structure. EXPECT_EQ(spend_token_data.public_metadata().SerializeAsString(), public_metadata_info_.public_metadata().SerializeAsString()); @@ -191,9 +191,9 @@ int num_tokens = 1; QuicheNotification done; - std::function<void(absl::StatusOr<absl::Span<const std::string>>)> callback = + std::function<void(absl::StatusOr<absl::Span<BlindSignToken>>)> callback = [this, &done, - num_tokens](absl::StatusOr<absl::Span<const std::string>> tokens) { + num_tokens](absl::StatusOr<absl::Span<BlindSignToken>> tokens) { QUICHE_EXPECT_OK(tokens); EXPECT_EQ(tokens->size(), num_tokens); ValidateGetTokensOutput(*tokens); @@ -216,8 +216,8 @@ int num_tokens = 1; QuicheNotification done; - std::function<void(absl::StatusOr<absl::Span<const std::string>>)> callback = - [&done](absl::StatusOr<absl::Span<const std::string>> tokens) { + std::function<void(absl::StatusOr<absl::Span<BlindSignToken>>)> callback = + [&done](absl::StatusOr<absl::Span<BlindSignToken>> tokens) { EXPECT_THAT(tokens.status().code(), absl::StatusCode::kInternal); done.Notify(); }; @@ -245,8 +245,8 @@ int num_tokens = 1; QuicheNotification done; - std::function<void(absl::StatusOr<absl::Span<const std::string>>)> callback = - [&done](absl::StatusOr<absl::Span<const std::string>> tokens) { + std::function<void(absl::StatusOr<absl::Span<BlindSignToken>>)> callback = + [&done](absl::StatusOr<absl::Span<BlindSignToken>> tokens) { EXPECT_THAT(tokens.status().code(), absl::StatusCode::kInvalidArgument); done.Notify(); }; @@ -286,8 +286,8 @@ int num_tokens = 1; QuicheNotification done; - std::function<void(absl::StatusOr<absl::Span<const std::string>>)> callback = - [&done](absl::StatusOr<absl::Span<const std::string>> tokens) { + std::function<void(absl::StatusOr<absl::Span<BlindSignToken>>)> callback = + [&done](absl::StatusOr<absl::Span<BlindSignToken>> tokens) { EXPECT_THAT(tokens.status().code(), absl::StatusCode::kInternal); done.Notify(); };
diff --git a/quiche/blind_sign_auth/cached_blind_sign_auth.cc b/quiche/blind_sign_auth/cached_blind_sign_auth.cc index 34e5e73..7c9271d 100644 --- a/quiche/blind_sign_auth/cached_blind_sign_auth.cc +++ b/quiche/blind_sign_auth/cached_blind_sign_auth.cc
@@ -8,17 +8,22 @@ #include <vector> #include "absl/status/status.h" +#include "absl/status/statusor.h" #include "absl/strings/str_format.h" +#include "absl/time/clock.h" +#include "absl/time/time.h" #include "absl/types/span.h" +#include "quiche/blind_sign_auth/blind_sign_auth_interface.h" #include "quiche/common/platform/api/quiche_logging.h" #include "quiche/common/platform/api/quiche_mutex.h" namespace quiche { +constexpr absl::Duration kFreshnessConstant = absl::Minutes(5); + void CachedBlindSignAuth::GetTokens( absl::string_view oauth_token, int num_tokens, - std::function<void(absl::StatusOr<absl::Span<const std::string>>)> - callback) { + std::function<void(absl::StatusOr<absl::Span<BlindSignToken>>)> callback) { if (num_tokens > max_tokens_per_request_) { callback(absl::InvalidArgumentError( absl::StrFormat("Number of tokens requested exceeds maximum: %d", @@ -31,25 +36,27 @@ return; } - std::vector<std::string> output_tokens; + std::vector<BlindSignToken> output_tokens; { QuicheWriterMutexLock lock(&mutex_); + RemoveExpiredTokens(); // Try to fill the request from cache. if (static_cast<size_t>(num_tokens) <= cached_tokens_.size()) { output_tokens = CreateOutputTokens(num_tokens); } } + if (!output_tokens.empty() || num_tokens == 0) { - callback(output_tokens); + callback(absl::MakeSpan(output_tokens)); return; } // Make a GetTokensRequest if the cache can't handle the request size. - std::function<void(absl::StatusOr<absl::Span<const std::string>>)> + std::function<void(absl::StatusOr<absl::Span<BlindSignToken>>)> caching_callback = [this, num_tokens, - callback](absl::StatusOr<absl::Span<const std::string>> tokens) { + callback](absl::StatusOr<absl::Span<BlindSignToken>> tokens) { HandleGetTokensResponse(tokens, num_tokens, callback); }; blind_sign_auth_->GetTokens(oauth_token, kBlindSignAuthRequestMaxTokens, @@ -57,9 +64,8 @@ } void CachedBlindSignAuth::HandleGetTokensResponse( - absl::StatusOr<absl::Span<const std::string>> tokens, int num_tokens, - std::function<void(absl::StatusOr<absl::Span<const std::string>>)> - callback) { + absl::StatusOr<absl::Span<BlindSignToken>> tokens, int num_tokens, + std::function<void(absl::StatusOr<absl::Span<BlindSignToken>>)> callback) { if (!tokens.ok()) { QUICHE_LOG(WARNING) << "BlindSignAuth::GetTokens failed: " << tokens.status(); @@ -72,16 +78,16 @@ << tokens->size(); } - std::vector<std::string> output_tokens; + std::vector<BlindSignToken> output_tokens; size_t cache_size; { QuicheWriterMutexLock lock(&mutex_); // Add returned tokens to cache. - for (const std::string& token : *tokens) { + for (const BlindSignToken& token : *tokens) { cached_tokens_.push_back(token); } - + RemoveExpiredTokens(); // Return tokens or a ResourceExhaustedError. cache_size = cached_tokens_.size(); if (cache_size >= static_cast<size_t>(num_tokens)) { @@ -90,7 +96,7 @@ } if (!output_tokens.empty()) { - callback(output_tokens); + callback(absl::MakeSpan(output_tokens)); return; } callback(absl::ResourceExhaustedError(absl::StrFormat( @@ -98,9 +104,9 @@ num_tokens, cache_size))); } -std::vector<std::string> CachedBlindSignAuth::CreateOutputTokens( +std::vector<BlindSignToken> CachedBlindSignAuth::CreateOutputTokens( int num_tokens) { - std::vector<std::string> output_tokens; + std::vector<BlindSignToken> output_tokens; if (cached_tokens_.size() < static_cast<size_t>(num_tokens)) { QUICHE_LOG(FATAL) << "Check failed, not enough tokens in cache: " << cached_tokens_.size() << " < " << num_tokens; @@ -112,4 +118,16 @@ return output_tokens; } +void CachedBlindSignAuth::RemoveExpiredTokens() { + size_t original_size = cached_tokens_.size(); + absl::Time now_plus_five_mins = absl::Now() + kFreshnessConstant; + for (size_t i = 0; i < original_size; i++) { + BlindSignToken token = std::move(cached_tokens_.front()); + cached_tokens_.pop_front(); + if (token.expiration > now_plus_five_mins) { + cached_tokens_.push_back(std::move(token)); + } + } +} + } // namespace quiche
diff --git a/quiche/blind_sign_auth/cached_blind_sign_auth.h b/quiche/blind_sign_auth/cached_blind_sign_auth.h index ee405a1..f976fb1 100644 --- a/quiche/blind_sign_auth/cached_blind_sign_auth.h +++ b/quiche/blind_sign_auth/cached_blind_sign_auth.h
@@ -35,29 +35,29 @@ : blind_sign_auth_(blind_sign_auth), max_tokens_per_request_(max_tokens_per_request) {} - // Returns signed unblinded tokens in a callback. Tokens are single-use. + // Returns signed unblinded tokens and expiration time in a callback. + // Tokens are single-use. They will not be usable after the expiration time. // // The GetTokens callback may be called synchronously on the calling thread, // or asynchronously on BlindSignAuth's BlindSignHttpInterface thread. // The GetTokens callback must not acquire any locks that the calling thread // owns, otherwise the callback will deadlock. - void GetTokens( - absl::string_view oauth_token, int num_tokens, - std::function<void(absl::StatusOr<absl::Span<const std::string>>)> - callback) override; + void GetTokens(absl::string_view oauth_token, int num_tokens, + std::function<void(absl::StatusOr<absl::Span<BlindSignToken>>)> + callback) override; private: void HandleGetTokensResponse( - absl::StatusOr<absl::Span<const std::string>> tokens, int num_tokens, - std::function<void(absl::StatusOr<absl::Span<const std::string>>)> - callback); - std::vector<std::string> CreateOutputTokens(int num_tokens) + absl::StatusOr<absl::Span<BlindSignToken>> tokens, int num_tokens, + std::function<void(absl::StatusOr<absl::Span<BlindSignToken>>)> callback); + std::vector<BlindSignToken> CreateOutputTokens(int num_tokens) QUICHE_EXCLUSIVE_LOCKS_REQUIRED(mutex_); + void RemoveExpiredTokens() QUICHE_EXCLUSIVE_LOCKS_REQUIRED(mutex_); BlindSignAuthInterface* blind_sign_auth_; int max_tokens_per_request_; QuicheMutex mutex_; - QuicheCircularDeque<std::string> cached_tokens_ QUICHE_GUARDED_BY(mutex_); + QuicheCircularDeque<BlindSignToken> cached_tokens_ QUICHE_GUARDED_BY(mutex_); }; } // namespace quiche
diff --git a/quiche/blind_sign_auth/cached_blind_sign_auth_test.cc b/quiche/blind_sign_auth/cached_blind_sign_auth_test.cc index dfad523..9bc7ee3 100644 --- a/quiche/blind_sign_auth/cached_blind_sign_auth_test.cc +++ b/quiche/blind_sign_auth/cached_blind_sign_auth_test.cc
@@ -7,12 +7,15 @@ #include <functional> #include <memory> #include <string> +#include <utility> #include <vector> #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" +#include "absl/time/clock.h" +#include "absl/time/time.h" #include "absl/types/span.h" #include "quiche/blind_sign_auth/test_tools/mock_blind_sign_auth_interface.h" #include "quiche/common/platform/api/quiche_mutex.h" @@ -41,17 +44,28 @@ } public: - std::vector<std::string> MakeFakeTokens(int num_tokens) { - std::vector<std::string> fake_tokens; + std::vector<BlindSignToken> MakeFakeTokens(int num_tokens) { + std::vector<BlindSignToken> fake_tokens; for (int i = 0; i < kBlindSignAuthRequestMaxTokens; i++) { - fake_tokens.push_back(absl::StrCat("token:", i)); + fake_tokens.push_back(BlindSignToken{absl::StrCat("token:", i), + absl::Now() + absl::Hours(1)}); } return fake_tokens; } + + std::vector<BlindSignToken> MakeExpiredTokens(int num_tokens) { + std::vector<BlindSignToken> fake_tokens; + for (int i = 0; i < kBlindSignAuthRequestMaxTokens; i++) { + fake_tokens.push_back(BlindSignToken{absl::StrCat("token:", i), + absl::Now() - absl::Hours(1)}); + } + return fake_tokens; + } + MockBlindSignAuthInterface mock_blind_sign_auth_interface_; std::unique_ptr<CachedBlindSignAuth> cached_blind_sign_auth_; std::string oauth_token_ = "oauth_token"; - std::vector<std::string> fake_tokens_; + std::vector<BlindSignToken> fake_tokens_; }; TEST_F(CachedBlindSignAuthTest, TestGetTokensOneCallSuccessful) { @@ -60,7 +74,7 @@ .Times(1) .WillOnce(Invoke( [this](Unused, int num_tokens, - std::function<void(absl::StatusOr<absl::Span<std::string>>)> + std::function<void(absl::StatusOr<absl::Span<BlindSignToken>>)> callback) { fake_tokens_ = MakeFakeTokens(num_tokens); callback(absl::MakeSpan(fake_tokens_)); @@ -68,13 +82,12 @@ int num_tokens = 5; QuicheNotification done; - std::function<void(absl::StatusOr<absl::Span<const std::string>>)> callback = - [num_tokens, - &done](absl::StatusOr<absl::Span<const std::string>> tokens) { + std::function<void(absl::StatusOr<absl::Span<BlindSignToken>>)> callback = + [num_tokens, &done](absl::StatusOr<absl::Span<BlindSignToken>> tokens) { QUICHE_EXPECT_OK(tokens); EXPECT_EQ(num_tokens, tokens->size()); for (int i = 0; i < num_tokens; i++) { - EXPECT_EQ(tokens->at(i), absl::StrCat("token:", i)); + EXPECT_EQ(tokens->at(i).token, absl::StrCat("token:", i)); } done.Notify(); }; @@ -89,7 +102,7 @@ .Times(2) .WillRepeatedly(Invoke( [this](Unused, int num_tokens, - std::function<void(absl::StatusOr<absl::Span<std::string>>)> + std::function<void(absl::StatusOr<absl::Span<BlindSignToken>>)> callback) { fake_tokens_ = MakeFakeTokens(num_tokens); callback(absl::MakeSpan(fake_tokens_)); @@ -97,36 +110,33 @@ int num_tokens = kBlindSignAuthRequestMaxTokens - 1; QuicheNotification first; - std::function<void(absl::StatusOr<absl::Span<const std::string>>)> - first_callback = - [num_tokens, - &first](absl::StatusOr<absl::Span<const std::string>> tokens) { - QUICHE_EXPECT_OK(tokens); - EXPECT_EQ(num_tokens, tokens->size()); - for (int i = 0; i < num_tokens; i++) { - EXPECT_EQ(tokens->at(i), absl::StrCat("token:", i)); - } - first.Notify(); - }; + std::function<void(absl::StatusOr<absl::Span<BlindSignToken>>)> + first_callback = [num_tokens, &first]( + absl::StatusOr<absl::Span<BlindSignToken>> tokens) { + QUICHE_EXPECT_OK(tokens); + EXPECT_EQ(num_tokens, tokens->size()); + for (int i = 0; i < num_tokens; i++) { + EXPECT_EQ(tokens->at(i).token, absl::StrCat("token:", i)); + } + first.Notify(); + }; cached_blind_sign_auth_->GetTokens(oauth_token_, num_tokens, first_callback); first.WaitForNotification(); QuicheNotification second; - std::function<void(absl::StatusOr<absl::Span<const std::string>>)> - second_callback = - [num_tokens, - &second](absl::StatusOr<absl::Span<const std::string>> tokens) { - QUICHE_EXPECT_OK(tokens); - EXPECT_EQ(num_tokens, tokens->size()); - EXPECT_EQ( - tokens->at(0), - absl::StrCat("token:", kBlindSignAuthRequestMaxTokens - 1)); - for (int i = 1; i < num_tokens; i++) { - EXPECT_EQ(tokens->at(i), absl::StrCat("token:", i - 1)); - } - second.Notify(); - }; + std::function<void(absl::StatusOr<absl::Span<BlindSignToken>>)> + second_callback = [num_tokens, &second]( + absl::StatusOr<absl::Span<BlindSignToken>> tokens) { + QUICHE_EXPECT_OK(tokens); + EXPECT_EQ(num_tokens, tokens->size()); + EXPECT_EQ(tokens->at(0).token, + absl::StrCat("token:", kBlindSignAuthRequestMaxTokens - 1)); + for (int i = 1; i < num_tokens; i++) { + EXPECT_EQ(tokens->at(i).token, absl::StrCat("token:", i - 1)); + } + second.Notify(); + }; cached_blind_sign_auth_->GetTokens(oauth_token_, num_tokens, second_callback); second.WaitForNotification(); @@ -138,7 +148,7 @@ .Times(1) .WillOnce(Invoke( [this](Unused, int num_tokens, - std::function<void(absl::StatusOr<absl::Span<std::string>>)> + std::function<void(absl::StatusOr<absl::Span<BlindSignToken>>)> callback) { fake_tokens_ = MakeFakeTokens(num_tokens); callback(absl::MakeSpan(fake_tokens_)); @@ -146,33 +156,32 @@ int num_tokens = kBlindSignAuthRequestMaxTokens / 2; QuicheNotification first; - std::function<void(absl::StatusOr<absl::Span<const std::string>>)> - first_callback = - [num_tokens, - &first](absl::StatusOr<absl::Span<const std::string>> tokens) { - QUICHE_EXPECT_OK(tokens); - EXPECT_EQ(num_tokens, tokens->size()); - for (int i = 0; i < num_tokens; i++) { - EXPECT_EQ(tokens->at(i), absl::StrCat("token:", i)); - } - first.Notify(); - }; + std::function<void(absl::StatusOr<absl::Span<BlindSignToken>>)> + first_callback = [num_tokens, &first]( + absl::StatusOr<absl::Span<BlindSignToken>> tokens) { + QUICHE_EXPECT_OK(tokens); + EXPECT_EQ(num_tokens, tokens->size()); + for (int i = 0; i < num_tokens; i++) { + EXPECT_EQ(tokens->at(i).token, absl::StrCat("token:", i)); + } + first.Notify(); + }; cached_blind_sign_auth_->GetTokens(oauth_token_, num_tokens, first_callback); first.WaitForNotification(); QuicheNotification second; - std::function<void(absl::StatusOr<absl::Span<const std::string>>)> - second_callback = - [num_tokens, - &second](absl::StatusOr<absl::Span<const std::string>> tokens) { - QUICHE_EXPECT_OK(tokens); - EXPECT_EQ(num_tokens, tokens->size()); - for (int i = 0; i < num_tokens; i++) { - EXPECT_EQ(tokens->at(i), absl::StrCat("token:", i + num_tokens)); - } - second.Notify(); - }; + std::function<void(absl::StatusOr<absl::Span<BlindSignToken>>)> + second_callback = [num_tokens, &second]( + absl::StatusOr<absl::Span<BlindSignToken>> tokens) { + QUICHE_EXPECT_OK(tokens); + EXPECT_EQ(num_tokens, tokens->size()); + for (int i = 0; i < num_tokens; i++) { + EXPECT_EQ(tokens->at(i).token, + absl::StrCat("token:", i + num_tokens)); + } + second.Notify(); + }; cached_blind_sign_auth_->GetTokens(oauth_token_, num_tokens, second_callback); second.WaitForNotification(); @@ -184,7 +193,7 @@ .Times(2) .WillRepeatedly(Invoke( [this](Unused, int num_tokens, - std::function<void(absl::StatusOr<absl::Span<std::string>>)> + std::function<void(absl::StatusOr<absl::Span<BlindSignToken>>)> callback) { fake_tokens_ = MakeFakeTokens(num_tokens); callback(absl::MakeSpan(fake_tokens_)); @@ -192,50 +201,48 @@ int num_tokens = kBlindSignAuthRequestMaxTokens / 2; QuicheNotification first; - std::function<void(absl::StatusOr<absl::Span<const std::string>>)> - first_callback = - [num_tokens, - &first](absl::StatusOr<absl::Span<const std::string>> tokens) { - QUICHE_EXPECT_OK(tokens); - EXPECT_EQ(num_tokens, tokens->size()); - for (int i = 0; i < num_tokens; i++) { - EXPECT_EQ(tokens->at(i), absl::StrCat("token:", i)); - } - first.Notify(); - }; + std::function<void(absl::StatusOr<absl::Span<BlindSignToken>>)> + first_callback = [num_tokens, &first]( + absl::StatusOr<absl::Span<BlindSignToken>> tokens) { + QUICHE_EXPECT_OK(tokens); + EXPECT_EQ(num_tokens, tokens->size()); + for (int i = 0; i < num_tokens; i++) { + EXPECT_EQ(tokens->at(i).token, absl::StrCat("token:", i)); + } + first.Notify(); + }; cached_blind_sign_auth_->GetTokens(oauth_token_, num_tokens, first_callback); first.WaitForNotification(); QuicheNotification second; - std::function<void(absl::StatusOr<absl::Span<const std::string>>)> - second_callback = - [num_tokens, - &second](absl::StatusOr<absl::Span<const std::string>> tokens) { - QUICHE_EXPECT_OK(tokens); - EXPECT_EQ(num_tokens, tokens->size()); - for (int i = 0; i < num_tokens; i++) { - EXPECT_EQ(tokens->at(i), absl::StrCat("token:", i + num_tokens)); - } - second.Notify(); - }; + std::function<void(absl::StatusOr<absl::Span<BlindSignToken>>)> + second_callback = [num_tokens, &second]( + absl::StatusOr<absl::Span<BlindSignToken>> tokens) { + QUICHE_EXPECT_OK(tokens); + EXPECT_EQ(num_tokens, tokens->size()); + for (int i = 0; i < num_tokens; i++) { + EXPECT_EQ(tokens->at(i).token, + absl::StrCat("token:", i + num_tokens)); + } + second.Notify(); + }; cached_blind_sign_auth_->GetTokens(oauth_token_, num_tokens, second_callback); second.WaitForNotification(); QuicheNotification third; int third_request_tokens = 10; - std::function<void(absl::StatusOr<absl::Span<const std::string>>)> - third_callback = - [third_request_tokens, - &third](absl::StatusOr<absl::Span<const std::string>> tokens) { - QUICHE_EXPECT_OK(tokens); - EXPECT_EQ(third_request_tokens, tokens->size()); - for (int i = 0; i < third_request_tokens; i++) { - EXPECT_EQ(tokens->at(i), absl::StrCat("token:", i)); - } - third.Notify(); - }; + std::function<void(absl::StatusOr<absl::Span<BlindSignToken>>)> + third_callback = [third_request_tokens, &third]( + absl::StatusOr<absl::Span<BlindSignToken>> tokens) { + QUICHE_EXPECT_OK(tokens); + EXPECT_EQ(third_request_tokens, tokens->size()); + for (int i = 0; i < third_request_tokens; i++) { + EXPECT_EQ(tokens->at(i).token, absl::StrCat("token:", i)); + } + third.Notify(); + }; cached_blind_sign_auth_->GetTokens(oauth_token_, third_request_tokens, third_callback); @@ -248,8 +255,8 @@ .Times(0); int num_tokens = kBlindSignAuthRequestMaxTokens + 1; - std::function<void(absl::StatusOr<absl::Span<const std::string>>)> callback = - [](absl::StatusOr<absl::Span<const std::string>> tokens) { + std::function<void(absl::StatusOr<absl::Span<BlindSignToken>>)> callback = + [](absl::StatusOr<absl::Span<BlindSignToken>> tokens) { EXPECT_THAT(tokens.status().code(), absl::StatusCode::kInvalidArgument); EXPECT_THAT( tokens.status().message(), @@ -266,8 +273,8 @@ .Times(0); int num_tokens = -1; - std::function<void(absl::StatusOr<absl::Span<const std::string>>)> callback = - [num_tokens](absl::StatusOr<absl::Span<const std::string>> tokens) { + std::function<void(absl::StatusOr<absl::Span<BlindSignToken>>)> callback = + [num_tokens](absl::StatusOr<absl::Span<BlindSignToken>> tokens) { EXPECT_THAT(tokens.status().code(), absl::StatusCode::kInvalidArgument); EXPECT_THAT(tokens.status().message(), absl::StrFormat("Negative number of tokens requested: %d", @@ -284,7 +291,7 @@ .WillOnce(InvokeArgument<2>(absl::InternalError("AuthAndSign failed"))) .WillOnce(Invoke( [this](Unused, int num_tokens, - std::function<void(absl::StatusOr<absl::Span<std::string>>)> + std::function<void(absl::StatusOr<absl::Span<BlindSignToken>>)> callback) { fake_tokens_ = MakeFakeTokens(num_tokens); fake_tokens_.pop_back(); @@ -293,9 +300,9 @@ int num_tokens = kBlindSignAuthRequestMaxTokens; QuicheNotification first; - std::function<void(absl::StatusOr<absl::Span<const std::string>>)> + std::function<void(absl::StatusOr<absl::Span<BlindSignToken>>)> first_callback = - [&first](absl::StatusOr<absl::Span<const std::string>> tokens) { + [&first](absl::StatusOr<absl::Span<BlindSignToken>> tokens) { EXPECT_THAT(tokens.status().code(), absl::StatusCode::kInternal); EXPECT_THAT(tokens.status().message(), "AuthAndSign failed"); first.Notify(); @@ -305,9 +312,9 @@ first.WaitForNotification(); QuicheNotification second; - std::function<void(absl::StatusOr<absl::Span<const std::string>>)> + std::function<void(absl::StatusOr<absl::Span<BlindSignToken>>)> second_callback = - [&second](absl::StatusOr<absl::Span<const std::string>> tokens) { + [&second](absl::StatusOr<absl::Span<BlindSignToken>> tokens) { EXPECT_THAT(tokens.status().code(), absl::StatusCode::kResourceExhausted); second.Notify(); @@ -323,8 +330,8 @@ .Times(0); int num_tokens = 0; - std::function<void(absl::StatusOr<absl::Span<const std::string>>)> callback = - [](absl::StatusOr<absl::Span<const std::string>> tokens) { + std::function<void(absl::StatusOr<absl::Span<BlindSignToken>>)> callback = + [](absl::StatusOr<absl::Span<BlindSignToken>> tokens) { QUICHE_EXPECT_OK(tokens); EXPECT_EQ(tokens->size(), 0); }; @@ -332,6 +339,32 @@ cached_blind_sign_auth_->GetTokens(oauth_token_, num_tokens, callback); } +TEST_F(CachedBlindSignAuthTest, TestExpiredTokensArePruned) { + EXPECT_CALL(mock_blind_sign_auth_interface_, + GetTokens(oauth_token_, kBlindSignAuthRequestMaxTokens, _)) + .Times(1) + .WillOnce(Invoke( + [this](Unused, int num_tokens, + std::function<void(absl::StatusOr<absl::Span<BlindSignToken>>)> + callback) { + fake_tokens_ = MakeExpiredTokens(num_tokens); + callback(absl::MakeSpan(fake_tokens_)); + })); + + int num_tokens = kBlindSignAuthRequestMaxTokens; + QuicheNotification first; + std::function<void(absl::StatusOr<absl::Span<BlindSignToken>>)> + first_callback = + [&first](absl::StatusOr<absl::Span<BlindSignToken>> tokens) { + EXPECT_THAT(tokens.status().code(), + absl::StatusCode::kResourceExhausted); + first.Notify(); + }; + + cached_blind_sign_auth_->GetTokens(oauth_token_, num_tokens, first_callback); + first.WaitForNotification(); +} + } // namespace } // namespace test } // namespace quiche
diff --git a/quiche/blind_sign_auth/test_tools/mock_blind_sign_auth_interface.h b/quiche/blind_sign_auth/test_tools/mock_blind_sign_auth_interface.h index dcb4876..378e8bb 100644 --- a/quiche/blind_sign_auth/test_tools/mock_blind_sign_auth_interface.h +++ b/quiche/blind_sign_auth/test_tools/mock_blind_sign_auth_interface.h
@@ -20,12 +20,11 @@ class QUICHE_NO_EXPORT MockBlindSignAuthInterface : public BlindSignAuthInterface { public: - MOCK_METHOD( - void, GetTokens, - (absl::string_view oauth_token, int num_tokens, - std::function<void(absl::StatusOr<absl::Span<const std::string>>)> - callback), - (override)); + MOCK_METHOD(void, GetTokens, + (absl::string_view oauth_token, int num_tokens, + std::function<void(absl::StatusOr<absl::Span<BlindSignToken>>)> + callback), + (override)); }; } // namespace quiche::test