Add ClearCache function to CachedBlindSignAuth. This function will be called when the cached tokens are invalid and need to be deleted.
PiperOrigin-RevId: 533576003
diff --git a/quiche/blind_sign_auth/cached_blind_sign_auth.h b/quiche/blind_sign_auth/cached_blind_sign_auth.h
index f976fb1..fcfde2c 100644
--- a/quiche/blind_sign_auth/cached_blind_sign_auth.h
+++ b/quiche/blind_sign_auth/cached_blind_sign_auth.h
@@ -46,6 +46,12 @@
std::function<void(absl::StatusOr<absl::Span<BlindSignToken>>)>
callback) override;
+ // Removes all tokens in the cache.
+ void ClearCache() {
+ QuicheWriterMutexLock lock(&mutex_);
+ cached_tokens_.clear();
+ }
+
private:
void HandleGetTokensResponse(
absl::StatusOr<absl::Span<BlindSignToken>> tokens, int num_tokens,
diff --git a/quiche/blind_sign_auth/cached_blind_sign_auth_test.cc b/quiche/blind_sign_auth/cached_blind_sign_auth_test.cc
index 9bc7ee3..7e39b98 100644
--- a/quiche/blind_sign_auth/cached_blind_sign_auth_test.cc
+++ b/quiche/blind_sign_auth/cached_blind_sign_auth_test.cc
@@ -27,7 +27,6 @@
namespace {
using ::testing::_;
-using ::testing::Invoke;
using ::testing::InvokeArgument;
using ::testing::Unused;
@@ -72,13 +71,13 @@
EXPECT_CALL(mock_blind_sign_auth_interface_,
GetTokens(oauth_token_, kBlindSignAuthRequestMaxTokens, _))
.Times(1)
- .WillOnce(Invoke(
+ .WillOnce(
[this](Unused, int num_tokens,
std::function<void(absl::StatusOr<absl::Span<BlindSignToken>>)>
callback) {
fake_tokens_ = MakeFakeTokens(num_tokens);
callback(absl::MakeSpan(fake_tokens_));
- }));
+ });
int num_tokens = 5;
QuicheNotification done;
@@ -100,13 +99,13 @@
EXPECT_CALL(mock_blind_sign_auth_interface_,
GetTokens(oauth_token_, kBlindSignAuthRequestMaxTokens, _))
.Times(2)
- .WillRepeatedly(Invoke(
+ .WillRepeatedly(
[this](Unused, int num_tokens,
std::function<void(absl::StatusOr<absl::Span<BlindSignToken>>)>
callback) {
fake_tokens_ = MakeFakeTokens(num_tokens);
callback(absl::MakeSpan(fake_tokens_));
- }));
+ });
int num_tokens = kBlindSignAuthRequestMaxTokens - 1;
QuicheNotification first;
@@ -146,13 +145,13 @@
EXPECT_CALL(mock_blind_sign_auth_interface_,
GetTokens(oauth_token_, kBlindSignAuthRequestMaxTokens, _))
.Times(1)
- .WillOnce(Invoke(
+ .WillOnce(
[this](Unused, int num_tokens,
std::function<void(absl::StatusOr<absl::Span<BlindSignToken>>)>
callback) {
fake_tokens_ = MakeFakeTokens(num_tokens);
callback(absl::MakeSpan(fake_tokens_));
- }));
+ });
int num_tokens = kBlindSignAuthRequestMaxTokens / 2;
QuicheNotification first;
@@ -191,13 +190,13 @@
EXPECT_CALL(mock_blind_sign_auth_interface_,
GetTokens(oauth_token_, kBlindSignAuthRequestMaxTokens, _))
.Times(2)
- .WillRepeatedly(Invoke(
+ .WillRepeatedly(
[this](Unused, int num_tokens,
std::function<void(absl::StatusOr<absl::Span<BlindSignToken>>)>
callback) {
fake_tokens_ = MakeFakeTokens(num_tokens);
callback(absl::MakeSpan(fake_tokens_));
- }));
+ });
int num_tokens = kBlindSignAuthRequestMaxTokens / 2;
QuicheNotification first;
@@ -289,14 +288,14 @@
GetTokens(oauth_token_, kBlindSignAuthRequestMaxTokens, _))
.Times(2)
.WillOnce(InvokeArgument<2>(absl::InternalError("AuthAndSign failed")))
- .WillOnce(Invoke(
+ .WillOnce(
[this](Unused, int num_tokens,
std::function<void(absl::StatusOr<absl::Span<BlindSignToken>>)>
callback) {
fake_tokens_ = MakeFakeTokens(num_tokens);
fake_tokens_.pop_back();
callback(absl::MakeSpan(fake_tokens_));
- }));
+ });
int num_tokens = kBlindSignAuthRequestMaxTokens;
QuicheNotification first;
@@ -343,13 +342,13 @@
EXPECT_CALL(mock_blind_sign_auth_interface_,
GetTokens(oauth_token_, kBlindSignAuthRequestMaxTokens, _))
.Times(1)
- .WillOnce(Invoke(
+ .WillOnce(
[this](Unused, int num_tokens,
std::function<void(absl::StatusOr<absl::Span<BlindSignToken>>)>
callback) {
fake_tokens_ = MakeExpiredTokens(num_tokens);
callback(absl::MakeSpan(fake_tokens_));
- }));
+ });
int num_tokens = kBlindSignAuthRequestMaxTokens;
QuicheNotification first;
@@ -365,6 +364,46 @@
first.WaitForNotification();
}
+TEST_F(CachedBlindSignAuthTest, TestClearCacheRemovesTokens) {
+ EXPECT_CALL(mock_blind_sign_auth_interface_,
+ GetTokens(oauth_token_, kBlindSignAuthRequestMaxTokens, _))
+ .Times(2)
+ .WillRepeatedly(
+ [this](Unused, int num_tokens,
+ std::function<void(absl::StatusOr<absl::Span<BlindSignToken>>)>
+ callback) {
+ fake_tokens_ = MakeExpiredTokens(num_tokens);
+ callback(absl::MakeSpan(fake_tokens_));
+ });
+
+ int num_tokens = kBlindSignAuthRequestMaxTokens / 2;
+ QuicheNotification first;
+ std::function<void(absl::StatusOr<absl::Span<BlindSignToken>>)>
+ first_callback =
+ [&first](absl::StatusOr<absl::Span<BlindSignToken>> tokens) {
+ EXPECT_THAT(tokens.status().code(),
+ absl::StatusCode::kResourceExhausted);
+ first.Notify();
+ };
+
+ cached_blind_sign_auth_->GetTokens(oauth_token_, num_tokens, first_callback);
+ first.WaitForNotification();
+
+ cached_blind_sign_auth_->ClearCache();
+
+ QuicheNotification second;
+ std::function<void(absl::StatusOr<absl::Span<BlindSignToken>>)>
+ second_callback =
+ [&second](absl::StatusOr<absl::Span<BlindSignToken>> tokens) {
+ EXPECT_THAT(tokens.status().code(),
+ absl::StatusCode::kResourceExhausted);
+ second.Notify();
+ };
+
+ cached_blind_sign_auth_->GetTokens(oauth_token_, num_tokens, second_callback);
+ second.WaitForNotification();
+}
+
} // namespace
} // namespace test
} // namespace quiche