Add ProxyLayer parameter to BlindSignAuth::GetTokens. Chrome will use this to generate tokens for ProxyA and ProxyB.

PiperOrigin-RevId: 582422010
diff --git a/quiche/blind_sign_auth/blind_sign_auth.cc b/quiche/blind_sign_auth/blind_sign_auth.cc
index 62ca574..aa2aac0 100644
--- a/quiche/blind_sign_auth/blind_sign_auth.cc
+++ b/quiche/blind_sign_auth/blind_sign_auth.cc
@@ -25,7 +25,9 @@
 #include "anonymous_tokens/cpp/privacy_pass/rsa_bssa_public_metadata_client.h"
 #include "anonymous_tokens/cpp/privacy_pass/token_encodings.h"
 #include "anonymous_tokens/cpp/shared/proto_utils.h"
+#include "quiche/blind_sign_auth/blind_sign_auth_interface.h"
 #include "quiche/blind_sign_auth/blind_sign_auth_protos.h"
+#include "quiche/blind_sign_auth/blind_sign_http_interface.h"
 #include "quiche/blind_sign_auth/blind_sign_http_response.h"
 #include "quiche/common/platform/api/quiche_logging.h"
 #include "quiche/common/quiche_endian.h"
@@ -46,24 +48,34 @@
 
 void BlindSignAuth::GetTokens(std::string oauth_token, int num_tokens,
                               SignedTokenCallback callback) {
+  GetTokens(oauth_token, num_tokens, ProxyLayer::kProxyA, std::move(callback));
+}
+
+void BlindSignAuth::GetTokens(std::string oauth_token, int num_tokens,
+                              ProxyLayer proxy_layer,
+                              SignedTokenCallback callback) {
   // Create GetInitialData RPC.
   privacy::ppn::GetInitialDataRequest request;
   request.set_use_attestation(false);
   request.set_service_type("chromeipblinding");
   request.set_location_granularity(
       privacy::ppn::GetInitialDataRequest_LocationGranularity_CITY_GEOS);
+  // Validation version must be 2 to use ProxyLayer.
+  request.set_validation_version(2);
+  request.set_proxy_layer(QuicheProxyLayerToPpnProxyLayer(proxy_layer));
 
   // Call GetInitialData on the HttpFetcher.
   std::string body = request.SerializeAsString();
-  BlindSignHttpCallback initial_data_callback =
-      absl::bind_front(&BlindSignAuth::GetInitialDataCallback, this,
-                       oauth_token, num_tokens, std::move(callback));
+  BlindSignHttpCallback initial_data_callback = absl::bind_front(
+      &BlindSignAuth::GetInitialDataCallback, this, oauth_token, num_tokens,
+      proxy_layer, std::move(callback));
   http_fetcher_->DoRequest(BlindSignHttpRequestType::kGetInitialData,
                            oauth_token, body, std::move(initial_data_callback));
 }
 
 void BlindSignAuth::GetInitialDataCallback(
-    std::string oauth_token, int num_tokens, SignedTokenCallback callback,
+    std::string oauth_token, int num_tokens, ProxyLayer proxy_layer,
+    SignedTokenCallback callback,
     absl::StatusOr<BlindSignHttpResponse> response) {
   if (!response.ok()) {
     QUICHE_LOG(WARNING) << "GetInitialDataRequest failed: "
@@ -107,11 +119,11 @@
     QUICHE_DVLOG(1) << "Using Privacy Pass client";
     GeneratePrivacyPassTokens(
         initial_data_response, *public_metadata_expiry_time,
-        std::move(oauth_token), num_tokens, std::move(callback));
+        std::move(oauth_token), num_tokens, proxy_layer, std::move(callback));
   } else {
     QUICHE_DVLOG(1) << "Using public metadata client";
     GenerateRsaBssaTokens(initial_data_response, *public_metadata_expiry_time,
-                          std::move(oauth_token), num_tokens,
+                          std::move(oauth_token), num_tokens, proxy_layer,
                           std::move(callback));
   }
 }
@@ -119,7 +131,7 @@
 void BlindSignAuth::GeneratePrivacyPassTokens(
     privacy::ppn::GetInitialDataResponse initial_data_response,
     absl::Time public_metadata_expiry_time, std::string oauth_token,
-    int num_tokens, SignedTokenCallback callback) {
+    int num_tokens, ProxyLayer proxy_layer, SignedTokenCallback callback) {
   // Set up values used in the token generation loop.
   anonymous_tokens::RSAPublicKey public_key_proto;
   if (!public_key_proto.ParseFromString(
@@ -148,7 +160,7 @@
   }
   std::vector<uint16_t> kExpectedExtensionTypes = {
       /*ExpirationTimestamp=*/0x0001, /*GeoHint=*/0x0002,
-      /*ServiceType=*/0xF001, /*DebugMode=*/0xF002};
+      /*ServiceType=*/0xF001, /*DebugMode=*/0xF002, /*ProxyLayer=*/0xF003};
   absl::Status result =
       anonymous_tokens::ValidateExtensionsOrderAndValues(
           *extensions, absl::MakeSpan(kExpectedExtensionTypes), absl::Now());
@@ -224,6 +236,7 @@
       initial_data_response.privacy_pass_data().public_metadata_extensions());
   // TODO(b/295924807): deprecate this option after AT server defaults to it
   sign_request.set_do_not_use_rsa_public_exponent(true);
+  sign_request.set_proxy_layer(QuicheProxyLayerToPpnProxyLayer(proxy_layer));
 
   absl::StatusOr<anonymous_tokens::AnonymousTokensUseCase>
       use_case = anonymous_tokens::ParseUseCase(
@@ -249,7 +262,7 @@
 void BlindSignAuth::GenerateRsaBssaTokens(
     privacy::ppn::GetInitialDataResponse initial_data_response,
     absl::Time public_metadata_expiry_time, std::string oauth_token,
-    int num_tokens, SignedTokenCallback callback) {
+    int num_tokens, ProxyLayer proxy_layer, SignedTokenCallback callback) {
   // Create public metadata client.
   auto bssa_client =
       anonymous_tokens::AnonymousTokensRsaBssaClient::
@@ -316,6 +329,7 @@
   }
   // TODO(b/295924807): deprecate this option after AT server defaults to it
   sign_request.set_do_not_use_rsa_public_exponent(true);
+  sign_request.set_proxy_layer(QuicheProxyLayerToPpnProxyLayer(proxy_layer));
 
   privacy::ppn::PublicMetadataInfo public_metadata_info =
       initial_data_response.public_metadata_info();
@@ -592,4 +606,16 @@
   return absl::StatusCode::kUnknown;
 }
 
+privacy::ppn::ProxyLayer BlindSignAuth::QuicheProxyLayerToPpnProxyLayer(
+    quiche::ProxyLayer proxy_layer) {
+  switch (proxy_layer) {
+    case ProxyLayer::kProxyA: {
+      return privacy::ppn::ProxyLayer::PROXY_A;
+    }
+    case ProxyLayer::kProxyB: {
+      return privacy::ppn::ProxyLayer::PROXY_B;
+    }
+  }
+}
+
 }  // namespace quiche
diff --git a/quiche/blind_sign_auth/blind_sign_auth.h b/quiche/blind_sign_auth/blind_sign_auth.h
index 43fd604..567046c 100644
--- a/quiche/blind_sign_auth/blind_sign_auth.h
+++ b/quiche/blind_sign_auth/blind_sign_auth.h
@@ -36,20 +36,26 @@
   // BlindSignHttpInterface callbacks.
   // Callers can make multiple concurrent requests to GetTokens.
   void GetTokens(std::string oauth_token, int num_tokens,
-                 SignedTokenCallback callback) override;
+                 ProxyLayer proxy_layer, SignedTokenCallback callback) override;
+
+  // Returns signed unblinded tokens and their expiration time in a callback.
+  // This function sends ProxyLayer::kProxyA by default for compatibility.
+  void GetTokens(std::string oauth_token, int num_tokens,
+                 SignedTokenCallback callback);
 
  private:
   void GetInitialDataCallback(std::string oauth_token, int num_tokens,
+                              ProxyLayer proxy_layer,
                               SignedTokenCallback callback,
                               absl::StatusOr<BlindSignHttpResponse> response);
   void GeneratePrivacyPassTokens(
       privacy::ppn::GetInitialDataResponse initial_data_response,
       absl::Time public_metadata_expiry_time, std::string oauth_token,
-      int num_tokens, SignedTokenCallback callback);
+      int num_tokens, ProxyLayer proxy_layer, SignedTokenCallback callback);
   void GenerateRsaBssaTokens(
       privacy::ppn::GetInitialDataResponse initial_data_response,
       absl::Time public_metadata_expiry_time, std::string oauth_token,
-      int num_tokens, SignedTokenCallback callback);
+      int num_tokens, ProxyLayer proxy_layer, SignedTokenCallback callback);
   void AuthAndSignCallback(
       privacy::ppn::PublicMetadataInfo public_metadata_info,
       absl::Time public_key_expiry_time,
@@ -71,6 +77,8 @@
   absl::Status FingerprintPublicMetadata(
       const privacy::ppn::PublicMetadata& metadata, uint64_t* fingerprint);
   absl::StatusCode HttpCodeToStatusCode(int http_code);
+  privacy::ppn::ProxyLayer QuicheProxyLayerToPpnProxyLayer(
+      quiche::ProxyLayer proxy_layer);
 
   BlindSignHttpInterface* http_fetcher_ = nullptr;
   privacy::ppn::BlindSignAuthOptions auth_options_;
diff --git a/quiche/blind_sign_auth/blind_sign_auth_interface.h b/quiche/blind_sign_auth/blind_sign_auth_interface.h
index e2f3bab..81b2fd3 100644
--- a/quiche/blind_sign_auth/blind_sign_auth_interface.h
+++ b/quiche/blind_sign_auth/blind_sign_auth_interface.h
@@ -15,6 +15,12 @@
 
 namespace quiche {
 
+// ProxyLayer indicates which proxy layer that tokens will be used with.
+enum class ProxyLayer : int {
+  kProxyA,
+  kProxyB,
+};
+
 // A BlindSignToken is used to authenticate a request to a privacy proxy.
 // The token string contains a serialized SpendTokenData proto.
 // The token cannot be successfully redeemed after the expiration time.
@@ -33,6 +39,7 @@
 
   // Returns signed unblinded tokens in a callback. Tokens are single-use.
   virtual void GetTokens(std::string oauth_token, int num_tokens,
+                         ProxyLayer proxy_layer,
                          SignedTokenCallback callback) = 0;
 };
 
diff --git a/quiche/blind_sign_auth/blind_sign_auth_protos.h b/quiche/blind_sign_auth/blind_sign_auth_protos.h
index b77dbe6..9935fa6 100644
--- a/quiche/blind_sign_auth/blind_sign_auth_protos.h
+++ b/quiche/blind_sign_auth/blind_sign_auth_protos.h
@@ -6,6 +6,7 @@
 #include "quiche/blind_sign_auth/proto/blind_sign_auth_options.pb.h"  // IWYU pragma: export
 #include "quiche/blind_sign_auth/proto/get_initial_data.pb.h"  // IWYU pragma: export
 #include "quiche/blind_sign_auth/proto/key_services.pb.h"  // IWYU pragma: export
+#include "quiche/blind_sign_auth/proto/proxy_layer.pb.h"   // IWYU pragma: export
 #include "quiche/blind_sign_auth/proto/public_metadata.pb.h"  // IWYU pragma: export
 #include "quiche/blind_sign_auth/proto/spend_token_data.pb.h"  // IWYU pragma: export
 #include "anonymous_tokens/proto/anonymous_tokens.pb.h"  // IWYU pragma: export
diff --git a/quiche/blind_sign_auth/blind_sign_auth_test.cc b/quiche/blind_sign_auth/blind_sign_auth_test.cc
index e2727fe..51c3136 100644
--- a/quiche/blind_sign_auth/blind_sign_auth_test.cc
+++ b/quiche/blind_sign_auth/blind_sign_auth_test.cc
@@ -16,10 +16,10 @@
 #include "absl/time/time.h"
 #include "anonymous_tokens/cpp/crypto/crypto_utils.h"
 #include "anonymous_tokens/cpp/privacy_pass/token_encodings.h"
-#include "anonymous_tokens/cpp/testing/proto_utils.h"
 #include "anonymous_tokens/cpp/testing/utils.h"
 #include "openssl/base.h"
 #include "openssl/digest.h"
+#include "quiche/blind_sign_auth/blind_sign_auth_interface.h"
 #include "quiche/blind_sign_auth/blind_sign_auth_protos.h"
 #include "quiche/blind_sign_auth/blind_sign_http_interface.h"
 #include "quiche/blind_sign_auth/blind_sign_http_response.h"
@@ -79,6 +79,8 @@
     expected_get_initial_data_request_.set_service_type("chromeipblinding");
     expected_get_initial_data_request_.set_location_granularity(
         privacy::ppn::GetInitialDataRequest_LocationGranularity_CITY_GEOS);
+    expected_get_initial_data_request_.set_validation_version(2);
+    expected_get_initial_data_request_.set_proxy_layer(privacy::ppn::PROXY_A);
 
     // Create fake GetInitialDataResponse.
     privacy::ppn::GetInitialDataResponse fake_get_initial_data_response;
@@ -145,6 +147,14 @@
     QUICHE_EXPECT_OK(debug_mode_extension);
     extensions_.extensions.push_back(*debug_mode_extension);
 
+    anonymous_tokens::ProxyLayer proxy_layer;
+    proxy_layer.layer =
+        anonymous_tokens::ProxyLayer::kProxyA;
+    absl::StatusOr<anonymous_tokens::Extension>
+        proxy_layer_extension = proxy_layer.AsExtension();
+    QUICHE_EXPECT_OK(proxy_layer_extension);
+    extensions_.extensions.push_back(*proxy_layer_extension);
+
     absl::StatusOr<std::string> serialized_extensions =
         anonymous_tokens::EncodeExtensions(extensions_);
     QUICHE_EXPECT_OK(serialized_extensions);
@@ -308,7 +318,8 @@
         ValidateGetTokensOutput(*tokens);
         done.Notify();
       };
-  blind_sign_auth_->GetTokens(oauth_token_, num_tokens, std::move(callback));
+  blind_sign_auth_->GetTokens(oauth_token_, num_tokens, ProxyLayer::kProxyA,
+                              std::move(callback));
   done.WaitForNotification();
 }
 
@@ -333,7 +344,8 @@
         EXPECT_THAT(tokens.status().code(), absl::StatusCode::kInternal);
         done.Notify();
       };
-  blind_sign_auth_->GetTokens(oauth_token_, num_tokens, std::move(callback));
+  blind_sign_auth_->GetTokens(oauth_token_, num_tokens, ProxyLayer::kProxyA,
+                              std::move(callback));
   done.WaitForNotification();
 }
 
@@ -364,7 +376,8 @@
         EXPECT_THAT(tokens.status().code(), absl::StatusCode::kInvalidArgument);
         done.Notify();
       };
-  blind_sign_auth_->GetTokens(oauth_token_, num_tokens, std::move(callback));
+  blind_sign_auth_->GetTokens(oauth_token_, num_tokens, ProxyLayer::kProxyA,
+                              std::move(callback));
   done.WaitForNotification();
 }
 
@@ -406,7 +419,8 @@
         EXPECT_THAT(tokens.status().code(), absl::StatusCode::kInternal);
         done.Notify();
       };
-  blind_sign_auth_->GetTokens(oauth_token_, num_tokens, std::move(callback));
+  blind_sign_auth_->GetTokens(oauth_token_, num_tokens, ProxyLayer::kProxyA,
+                              std::move(callback));
   done.WaitForNotification();
 }
 
@@ -457,7 +471,8 @@
         ValidatePrivacyPassTokensOutput(*tokens);
         done.Notify();
       };
-  blind_sign_auth_->GetTokens(oauth_token_, num_tokens, std::move(callback));
+  blind_sign_auth_->GetTokens(oauth_token_, num_tokens, ProxyLayer::kProxyA,
+                              std::move(callback));
   done.WaitForNotification();
 }
 
@@ -493,7 +508,8 @@
         EXPECT_THAT(tokens.status().code(), absl::StatusCode::kInvalidArgument);
         done.Notify();
       };
-  blind_sign_auth_->GetTokens(oauth_token_, num_tokens, std::move(callback));
+  blind_sign_auth_->GetTokens(oauth_token_, num_tokens, ProxyLayer::kProxyA,
+                              std::move(callback));
   done.WaitForNotification();
 }
 
diff --git a/quiche/blind_sign_auth/cached_blind_sign_auth.cc b/quiche/blind_sign_auth/cached_blind_sign_auth.cc
index d2d7e70..638703e 100644
--- a/quiche/blind_sign_auth/cached_blind_sign_auth.cc
+++ b/quiche/blind_sign_auth/cached_blind_sign_auth.cc
@@ -23,6 +23,7 @@
 constexpr absl::Duration kFreshnessConstant = absl::Minutes(5);
 
 void CachedBlindSignAuth::GetTokens(std::string oauth_token, int num_tokens,
+                                    ProxyLayer proxy_layer,
                                     SignedTokenCallback callback) {
   if (num_tokens > max_tokens_per_request_) {
     std::move(callback)(absl::InvalidArgumentError(
@@ -57,7 +58,7 @@
       absl::bind_front(&CachedBlindSignAuth::HandleGetTokensResponse, this,
                        std::move(callback), num_tokens);
   blind_sign_auth_->GetTokens(oauth_token, kBlindSignAuthRequestMaxTokens,
-                              std::move(caching_callback));
+                              proxy_layer, std::move(caching_callback));
 }
 
 void CachedBlindSignAuth::HandleGetTokensResponse(
diff --git a/quiche/blind_sign_auth/cached_blind_sign_auth.h b/quiche/blind_sign_auth/cached_blind_sign_auth.h
index 30ca72d..b453e55 100644
--- a/quiche/blind_sign_auth/cached_blind_sign_auth.h
+++ b/quiche/blind_sign_auth/cached_blind_sign_auth.h
@@ -40,7 +40,7 @@
   // The GetTokens callback must not acquire any locks that the calling thread
   // owns, otherwise the callback will deadlock.
   void GetTokens(std::string oauth_token, int num_tokens,
-                 SignedTokenCallback callback) override;
+                 ProxyLayer proxy_layer, SignedTokenCallback callback) override;
 
   // Removes all tokens in the cache.
   void ClearCache() {
diff --git a/quiche/blind_sign_auth/cached_blind_sign_auth_test.cc b/quiche/blind_sign_auth/cached_blind_sign_auth_test.cc
index b3ee6e0..3f18938 100644
--- a/quiche/blind_sign_auth/cached_blind_sign_auth_test.cc
+++ b/quiche/blind_sign_auth/cached_blind_sign_auth_test.cc
@@ -16,6 +16,7 @@
 #include "absl/time/clock.h"
 #include "absl/time/time.h"
 #include "absl/types/span.h"
+#include "quiche/blind_sign_auth/blind_sign_auth_interface.h"
 #include "quiche/blind_sign_auth/test_tools/mock_blind_sign_auth_interface.h"
 #include "quiche/common/platform/api/quiche_mutex.h"
 #include "quiche/common/platform/api/quiche_test.h"
@@ -68,12 +69,13 @@
 
 TEST_F(CachedBlindSignAuthTest, TestGetTokensOneCallSuccessful) {
   EXPECT_CALL(mock_blind_sign_auth_interface_,
-              GetTokens(oauth_token_, kBlindSignAuthRequestMaxTokens, _))
+              GetTokens(oauth_token_, kBlindSignAuthRequestMaxTokens, _, _))
       .Times(1)
-      .WillOnce([this](Unused, int num_tokens, SignedTokenCallback callback) {
-        fake_tokens_ = MakeFakeTokens(num_tokens);
-        std::move(callback)(absl::MakeSpan(fake_tokens_));
-      });
+      .WillOnce(
+          [this](Unused, int num_tokens, Unused, SignedTokenCallback callback) {
+            fake_tokens_ = MakeFakeTokens(num_tokens);
+            std::move(callback)(absl::MakeSpan(fake_tokens_));
+          });
 
   int num_tokens = 5;
   QuicheNotification done;
@@ -88,16 +90,16 @@
       };
 
   cached_blind_sign_auth_->GetTokens(oauth_token_, num_tokens,
-                                     std::move(callback));
+                                     ProxyLayer::kProxyA, std::move(callback));
   done.WaitForNotification();
 }
 
 TEST_F(CachedBlindSignAuthTest, TestGetTokensMultipleRemoteCallsSuccessful) {
   EXPECT_CALL(mock_blind_sign_auth_interface_,
-              GetTokens(oauth_token_, kBlindSignAuthRequestMaxTokens, _))
+              GetTokens(oauth_token_, kBlindSignAuthRequestMaxTokens, _, _))
       .Times(2)
       .WillRepeatedly(
-          [this](Unused, int num_tokens, SignedTokenCallback callback) {
+          [this](Unused, int num_tokens, Unused, SignedTokenCallback callback) {
             fake_tokens_ = MakeFakeTokens(num_tokens);
             std::move(callback)(absl::MakeSpan(fake_tokens_));
           });
@@ -114,8 +116,8 @@
         first.Notify();
       };
 
-  cached_blind_sign_auth_->GetTokens(oauth_token_, num_tokens,
-                                     std::move(first_callback));
+  cached_blind_sign_auth_->GetTokens(
+      oauth_token_, num_tokens, ProxyLayer::kProxyA, std::move(first_callback));
   first.WaitForNotification();
 
   QuicheNotification second;
@@ -132,58 +134,17 @@
       };
 
   cached_blind_sign_auth_->GetTokens(oauth_token_, num_tokens,
+                                     ProxyLayer::kProxyA,
                                      std::move(second_callback));
   second.WaitForNotification();
 }
 
 TEST_F(CachedBlindSignAuthTest, TestGetTokensSecondRequestFilledFromCache) {
   EXPECT_CALL(mock_blind_sign_auth_interface_,
-              GetTokens(oauth_token_, kBlindSignAuthRequestMaxTokens, _))
+              GetTokens(oauth_token_, kBlindSignAuthRequestMaxTokens, _, _))
       .Times(1)
-      .WillOnce([this](Unused, int num_tokens, SignedTokenCallback callback) {
-        fake_tokens_ = MakeFakeTokens(num_tokens);
-        std::move(callback)(absl::MakeSpan(fake_tokens_));
-      });
-
-  int num_tokens = kBlindSignAuthRequestMaxTokens / 2;
-  QuicheNotification first;
-  SignedTokenCallback first_callback =
-      [num_tokens, &first](absl::StatusOr<absl::Span<BlindSignToken>> tokens) {
-        QUICHE_EXPECT_OK(tokens);
-        EXPECT_EQ(num_tokens, tokens->size());
-        for (int i = 0; i < num_tokens; i++) {
-          EXPECT_EQ(tokens->at(i).token, absl::StrCat("token:", i));
-        }
-        first.Notify();
-      };
-
-  cached_blind_sign_auth_->GetTokens(oauth_token_, num_tokens,
-                                     std::move(first_callback));
-  first.WaitForNotification();
-
-  QuicheNotification second;
-  SignedTokenCallback second_callback =
-      [num_tokens, &second](absl::StatusOr<absl::Span<BlindSignToken>> tokens) {
-        QUICHE_EXPECT_OK(tokens);
-        EXPECT_EQ(num_tokens, tokens->size());
-        for (int i = 0; i < num_tokens; i++) {
-          EXPECT_EQ(tokens->at(i).token,
-                    absl::StrCat("token:", i + num_tokens));
-        }
-        second.Notify();
-      };
-
-  cached_blind_sign_auth_->GetTokens(oauth_token_, num_tokens,
-                                     std::move(second_callback));
-  second.WaitForNotification();
-}
-
-TEST_F(CachedBlindSignAuthTest, TestGetTokensThirdRequestRefillsCache) {
-  EXPECT_CALL(mock_blind_sign_auth_interface_,
-              GetTokens(oauth_token_, kBlindSignAuthRequestMaxTokens, _))
-      .Times(2)
-      .WillRepeatedly(
-          [this](Unused, int num_tokens, SignedTokenCallback callback) {
+      .WillOnce(
+          [this](Unused, int num_tokens, Unused, SignedTokenCallback callback) {
             fake_tokens_ = MakeFakeTokens(num_tokens);
             std::move(callback)(absl::MakeSpan(fake_tokens_));
           });
@@ -200,8 +161,8 @@
         first.Notify();
       };
 
-  cached_blind_sign_auth_->GetTokens(oauth_token_, num_tokens,
-                                     std::move(first_callback));
+  cached_blind_sign_auth_->GetTokens(
+      oauth_token_, num_tokens, ProxyLayer::kProxyA, std::move(first_callback));
   first.WaitForNotification();
 
   QuicheNotification second;
@@ -217,6 +178,51 @@
       };
 
   cached_blind_sign_auth_->GetTokens(oauth_token_, num_tokens,
+                                     ProxyLayer::kProxyA,
+                                     std::move(second_callback));
+  second.WaitForNotification();
+}
+
+TEST_F(CachedBlindSignAuthTest, TestGetTokensThirdRequestRefillsCache) {
+  EXPECT_CALL(mock_blind_sign_auth_interface_,
+              GetTokens(oauth_token_, kBlindSignAuthRequestMaxTokens, _, _))
+      .Times(2)
+      .WillRepeatedly(
+          [this](Unused, int num_tokens, Unused, SignedTokenCallback callback) {
+            fake_tokens_ = MakeFakeTokens(num_tokens);
+            std::move(callback)(absl::MakeSpan(fake_tokens_));
+          });
+
+  int num_tokens = kBlindSignAuthRequestMaxTokens / 2;
+  QuicheNotification first;
+  SignedTokenCallback first_callback =
+      [num_tokens, &first](absl::StatusOr<absl::Span<BlindSignToken>> tokens) {
+        QUICHE_EXPECT_OK(tokens);
+        EXPECT_EQ(num_tokens, tokens->size());
+        for (int i = 0; i < num_tokens; i++) {
+          EXPECT_EQ(tokens->at(i).token, absl::StrCat("token:", i));
+        }
+        first.Notify();
+      };
+
+  cached_blind_sign_auth_->GetTokens(
+      oauth_token_, num_tokens, ProxyLayer::kProxyA, std::move(first_callback));
+  first.WaitForNotification();
+
+  QuicheNotification second;
+  SignedTokenCallback second_callback =
+      [num_tokens, &second](absl::StatusOr<absl::Span<BlindSignToken>> tokens) {
+        QUICHE_EXPECT_OK(tokens);
+        EXPECT_EQ(num_tokens, tokens->size());
+        for (int i = 0; i < num_tokens; i++) {
+          EXPECT_EQ(tokens->at(i).token,
+                    absl::StrCat("token:", i + num_tokens));
+        }
+        second.Notify();
+      };
+
+  cached_blind_sign_auth_->GetTokens(oauth_token_, num_tokens,
+                                     ProxyLayer::kProxyA,
                                      std::move(second_callback));
   second.WaitForNotification();
 
@@ -234,13 +240,14 @@
       };
 
   cached_blind_sign_auth_->GetTokens(oauth_token_, third_request_tokens,
+                                     ProxyLayer::kProxyA,
                                      std::move(third_callback));
   third.WaitForNotification();
 }
 
 TEST_F(CachedBlindSignAuthTest, TestGetTokensRequestTooLarge) {
   EXPECT_CALL(mock_blind_sign_auth_interface_,
-              GetTokens(oauth_token_, kBlindSignAuthRequestMaxTokens, _))
+              GetTokens(oauth_token_, kBlindSignAuthRequestMaxTokens, _, _))
       .Times(0);
 
   int num_tokens = kBlindSignAuthRequestMaxTokens + 1;
@@ -254,12 +261,12 @@
       };
 
   cached_blind_sign_auth_->GetTokens(oauth_token_, num_tokens,
-                                     std::move(callback));
+                                     ProxyLayer::kProxyA, std::move(callback));
 }
 
 TEST_F(CachedBlindSignAuthTest, TestGetTokensRequestNegative) {
   EXPECT_CALL(mock_blind_sign_auth_interface_,
-              GetTokens(oauth_token_, kBlindSignAuthRequestMaxTokens, _))
+              GetTokens(oauth_token_, kBlindSignAuthRequestMaxTokens, _, _))
       .Times(0);
 
   int num_tokens = -1;
@@ -272,21 +279,23 @@
       };
 
   cached_blind_sign_auth_->GetTokens(oauth_token_, num_tokens,
-                                     std::move(callback));
+                                     ProxyLayer::kProxyA, std::move(callback));
 }
 
 TEST_F(CachedBlindSignAuthTest, TestHandleGetTokensResponseErrorHandling) {
   EXPECT_CALL(mock_blind_sign_auth_interface_,
-              GetTokens(oauth_token_, kBlindSignAuthRequestMaxTokens, _))
+              GetTokens(oauth_token_, kBlindSignAuthRequestMaxTokens, _, _))
       .Times(2)
-      .WillOnce([](Unused, int num_tokens, SignedTokenCallback callback) {
-        std::move(callback)(absl::InternalError("AuthAndSign failed"));
-      })
-      .WillOnce([this](Unused, int num_tokens, SignedTokenCallback callback) {
-        fake_tokens_ = MakeFakeTokens(num_tokens);
-        fake_tokens_.pop_back();
-        std::move(callback)(absl::MakeSpan(fake_tokens_));
-      });
+      .WillOnce(
+          [](Unused, int num_tokens, Unused, SignedTokenCallback callback) {
+            std::move(callback)(absl::InternalError("AuthAndSign failed"));
+          })
+      .WillOnce(
+          [this](Unused, int num_tokens, Unused, SignedTokenCallback callback) {
+            fake_tokens_ = MakeFakeTokens(num_tokens);
+            fake_tokens_.pop_back();
+            std::move(callback)(absl::MakeSpan(fake_tokens_));
+          });
 
   int num_tokens = kBlindSignAuthRequestMaxTokens;
   QuicheNotification first;
@@ -297,8 +306,8 @@
         first.Notify();
       };
 
-  cached_blind_sign_auth_->GetTokens(oauth_token_, num_tokens,
-                                     std::move(first_callback));
+  cached_blind_sign_auth_->GetTokens(
+      oauth_token_, num_tokens, ProxyLayer::kProxyA, std::move(first_callback));
   first.WaitForNotification();
 
   QuicheNotification second;
@@ -310,13 +319,14 @@
       };
 
   cached_blind_sign_auth_->GetTokens(oauth_token_, num_tokens,
+                                     ProxyLayer::kProxyA,
                                      std::move(second_callback));
   second.WaitForNotification();
 }
 
 TEST_F(CachedBlindSignAuthTest, TestGetTokensZeroTokensRequested) {
   EXPECT_CALL(mock_blind_sign_auth_interface_,
-              GetTokens(oauth_token_, kBlindSignAuthRequestMaxTokens, _))
+              GetTokens(oauth_token_, kBlindSignAuthRequestMaxTokens, _, _))
       .Times(0);
 
   int num_tokens = 0;
@@ -327,17 +337,18 @@
       };
 
   cached_blind_sign_auth_->GetTokens(oauth_token_, num_tokens,
-                                     std::move(callback));
+                                     ProxyLayer::kProxyA, std::move(callback));
 }
 
 TEST_F(CachedBlindSignAuthTest, TestExpiredTokensArePruned) {
   EXPECT_CALL(mock_blind_sign_auth_interface_,
-              GetTokens(oauth_token_, kBlindSignAuthRequestMaxTokens, _))
+              GetTokens(oauth_token_, kBlindSignAuthRequestMaxTokens, _, _))
       .Times(1)
-      .WillOnce([this](Unused, int num_tokens, SignedTokenCallback callback) {
-        fake_tokens_ = MakeExpiredTokens(num_tokens);
-        std::move(callback)(absl::MakeSpan(fake_tokens_));
-      });
+      .WillOnce(
+          [this](Unused, int num_tokens, Unused, SignedTokenCallback callback) {
+            fake_tokens_ = MakeExpiredTokens(num_tokens);
+            std::move(callback)(absl::MakeSpan(fake_tokens_));
+          });
 
   int num_tokens = kBlindSignAuthRequestMaxTokens;
   QuicheNotification first;
@@ -348,17 +359,17 @@
         first.Notify();
       };
 
-  cached_blind_sign_auth_->GetTokens(oauth_token_, num_tokens,
-                                     std::move(first_callback));
+  cached_blind_sign_auth_->GetTokens(
+      oauth_token_, num_tokens, ProxyLayer::kProxyA, std::move(first_callback));
   first.WaitForNotification();
 }
 
 TEST_F(CachedBlindSignAuthTest, TestClearCacheRemovesTokens) {
   EXPECT_CALL(mock_blind_sign_auth_interface_,
-              GetTokens(oauth_token_, kBlindSignAuthRequestMaxTokens, _))
+              GetTokens(oauth_token_, kBlindSignAuthRequestMaxTokens, _, _))
       .Times(2)
       .WillRepeatedly(
-          [this](Unused, int num_tokens, SignedTokenCallback callback) {
+          [this](Unused, int num_tokens, Unused, SignedTokenCallback callback) {
             fake_tokens_ = MakeExpiredTokens(num_tokens);
             std::move(callback)(absl::MakeSpan(fake_tokens_));
           });
@@ -372,8 +383,8 @@
         first.Notify();
       };
 
-  cached_blind_sign_auth_->GetTokens(oauth_token_, num_tokens,
-                                     std::move(first_callback));
+  cached_blind_sign_auth_->GetTokens(
+      oauth_token_, num_tokens, ProxyLayer::kProxyA, std::move(first_callback));
   first.WaitForNotification();
 
   cached_blind_sign_auth_->ClearCache();
@@ -387,6 +398,7 @@
       };
 
   cached_blind_sign_auth_->GetTokens(oauth_token_, num_tokens,
+                                     ProxyLayer::kProxyA,
                                      std::move(second_callback));
   second.WaitForNotification();
 }
diff --git a/quiche/blind_sign_auth/test_tools/mock_blind_sign_auth_interface.h b/quiche/blind_sign_auth/test_tools/mock_blind_sign_auth_interface.h
index 19d740e..d3af877 100644
--- a/quiche/blind_sign_auth/test_tools/mock_blind_sign_auth_interface.h
+++ b/quiche/blind_sign_auth/test_tools/mock_blind_sign_auth_interface.h
@@ -17,7 +17,7 @@
     : public BlindSignAuthInterface {
  public:
   MOCK_METHOD(void, GetTokens,
-              (std::string oauth_token, int num_tokens,
+              (std::string oauth_token, int num_tokens, ProxyLayer proxy_layer,
                SignedTokenCallback callback),
               (override));
 };