Add ClearCache function to CachedBlindSignAuth. This function will be called when the cached tokens are invalid and need to be deleted. PiperOrigin-RevId: 533576003
diff --git a/quiche/blind_sign_auth/cached_blind_sign_auth.h b/quiche/blind_sign_auth/cached_blind_sign_auth.h index f976fb1..fcfde2c 100644 --- a/quiche/blind_sign_auth/cached_blind_sign_auth.h +++ b/quiche/blind_sign_auth/cached_blind_sign_auth.h
@@ -46,6 +46,12 @@ std::function<void(absl::StatusOr<absl::Span<BlindSignToken>>)> callback) override; + // Removes all tokens in the cache. + void ClearCache() { + QuicheWriterMutexLock lock(&mutex_); + cached_tokens_.clear(); + } + private: void HandleGetTokensResponse( absl::StatusOr<absl::Span<BlindSignToken>> tokens, int num_tokens,
diff --git a/quiche/blind_sign_auth/cached_blind_sign_auth_test.cc b/quiche/blind_sign_auth/cached_blind_sign_auth_test.cc index 9bc7ee3..7e39b98 100644 --- a/quiche/blind_sign_auth/cached_blind_sign_auth_test.cc +++ b/quiche/blind_sign_auth/cached_blind_sign_auth_test.cc
@@ -27,7 +27,6 @@ namespace { using ::testing::_; -using ::testing::Invoke; using ::testing::InvokeArgument; using ::testing::Unused; @@ -72,13 +71,13 @@ EXPECT_CALL(mock_blind_sign_auth_interface_, GetTokens(oauth_token_, kBlindSignAuthRequestMaxTokens, _)) .Times(1) - .WillOnce(Invoke( + .WillOnce( [this](Unused, int num_tokens, std::function<void(absl::StatusOr<absl::Span<BlindSignToken>>)> callback) { fake_tokens_ = MakeFakeTokens(num_tokens); callback(absl::MakeSpan(fake_tokens_)); - })); + }); int num_tokens = 5; QuicheNotification done; @@ -100,13 +99,13 @@ EXPECT_CALL(mock_blind_sign_auth_interface_, GetTokens(oauth_token_, kBlindSignAuthRequestMaxTokens, _)) .Times(2) - .WillRepeatedly(Invoke( + .WillRepeatedly( [this](Unused, int num_tokens, std::function<void(absl::StatusOr<absl::Span<BlindSignToken>>)> callback) { fake_tokens_ = MakeFakeTokens(num_tokens); callback(absl::MakeSpan(fake_tokens_)); - })); + }); int num_tokens = kBlindSignAuthRequestMaxTokens - 1; QuicheNotification first; @@ -146,13 +145,13 @@ EXPECT_CALL(mock_blind_sign_auth_interface_, GetTokens(oauth_token_, kBlindSignAuthRequestMaxTokens, _)) .Times(1) - .WillOnce(Invoke( + .WillOnce( [this](Unused, int num_tokens, std::function<void(absl::StatusOr<absl::Span<BlindSignToken>>)> callback) { fake_tokens_ = MakeFakeTokens(num_tokens); callback(absl::MakeSpan(fake_tokens_)); - })); + }); int num_tokens = kBlindSignAuthRequestMaxTokens / 2; QuicheNotification first; @@ -191,13 +190,13 @@ EXPECT_CALL(mock_blind_sign_auth_interface_, GetTokens(oauth_token_, kBlindSignAuthRequestMaxTokens, _)) .Times(2) - .WillRepeatedly(Invoke( + .WillRepeatedly( [this](Unused, int num_tokens, std::function<void(absl::StatusOr<absl::Span<BlindSignToken>>)> callback) { fake_tokens_ = MakeFakeTokens(num_tokens); callback(absl::MakeSpan(fake_tokens_)); - })); + }); int num_tokens = kBlindSignAuthRequestMaxTokens / 2; QuicheNotification first; @@ -289,14 +288,14 @@ GetTokens(oauth_token_, kBlindSignAuthRequestMaxTokens, _)) .Times(2) .WillOnce(InvokeArgument<2>(absl::InternalError("AuthAndSign failed"))) - .WillOnce(Invoke( + .WillOnce( [this](Unused, int num_tokens, std::function<void(absl::StatusOr<absl::Span<BlindSignToken>>)> callback) { fake_tokens_ = MakeFakeTokens(num_tokens); fake_tokens_.pop_back(); callback(absl::MakeSpan(fake_tokens_)); - })); + }); int num_tokens = kBlindSignAuthRequestMaxTokens; QuicheNotification first; @@ -343,13 +342,13 @@ EXPECT_CALL(mock_blind_sign_auth_interface_, GetTokens(oauth_token_, kBlindSignAuthRequestMaxTokens, _)) .Times(1) - .WillOnce(Invoke( + .WillOnce( [this](Unused, int num_tokens, std::function<void(absl::StatusOr<absl::Span<BlindSignToken>>)> callback) { fake_tokens_ = MakeExpiredTokens(num_tokens); callback(absl::MakeSpan(fake_tokens_)); - })); + }); int num_tokens = kBlindSignAuthRequestMaxTokens; QuicheNotification first; @@ -365,6 +364,46 @@ first.WaitForNotification(); } +TEST_F(CachedBlindSignAuthTest, TestClearCacheRemovesTokens) { + EXPECT_CALL(mock_blind_sign_auth_interface_, + GetTokens(oauth_token_, kBlindSignAuthRequestMaxTokens, _)) + .Times(2) + .WillRepeatedly( + [this](Unused, int num_tokens, + std::function<void(absl::StatusOr<absl::Span<BlindSignToken>>)> + callback) { + fake_tokens_ = MakeExpiredTokens(num_tokens); + callback(absl::MakeSpan(fake_tokens_)); + }); + + int num_tokens = kBlindSignAuthRequestMaxTokens / 2; + QuicheNotification first; + std::function<void(absl::StatusOr<absl::Span<BlindSignToken>>)> + first_callback = + [&first](absl::StatusOr<absl::Span<BlindSignToken>> tokens) { + EXPECT_THAT(tokens.status().code(), + absl::StatusCode::kResourceExhausted); + first.Notify(); + }; + + cached_blind_sign_auth_->GetTokens(oauth_token_, num_tokens, first_callback); + first.WaitForNotification(); + + cached_blind_sign_auth_->ClearCache(); + + QuicheNotification second; + std::function<void(absl::StatusOr<absl::Span<BlindSignToken>>)> + second_callback = + [&second](absl::StatusOr<absl::Span<BlindSignToken>> tokens) { + EXPECT_THAT(tokens.status().code(), + absl::StatusCode::kResourceExhausted); + second.Notify(); + }; + + cached_blind_sign_auth_->GetTokens(oauth_token_, num_tokens, second_callback); + second.WaitForNotification(); +} + } // namespace } // namespace test } // namespace quiche