Improve handling of crypters in CompareClientAndServerKeys

gfe-relnote: n/a (test-only)
PiperOrigin-RevId: 240383505
Change-Id: Iedebac7d79c041d8a7c628043ffb37b3037bf08b
diff --git a/quic/test_tools/crypto_test_utils.cc b/quic/test_tools/crypto_test_utils.cc
index 66692bd..af33d4e 100644
--- a/quic/test_tools/crypto_test_utils.cc
+++ b/quic/test_tools/crypto_test_utils.cc
@@ -729,121 +729,93 @@
   rej->SetVector(kRREJ, reject_reasons);
 }
 
+namespace {
+
+#define RETURN_STRING_LITERAL(x) \
+  case x:                        \
+    return #x
+
+std::string EncryptionLevelString(EncryptionLevel level) {
+  switch (level) {
+    RETURN_STRING_LITERAL(ENCRYPTION_INITIAL);
+    RETURN_STRING_LITERAL(ENCRYPTION_HANDSHAKE);
+    RETURN_STRING_LITERAL(ENCRYPTION_ZERO_RTT);
+    RETURN_STRING_LITERAL(ENCRYPTION_FORWARD_SECURE);
+    default:
+      return "";
+  }
+}
+
+void CompareCrypters(const QuicEncrypter* encrypter,
+                     const QuicDecrypter* decrypter,
+                     std::string label) {
+  QuicStringPiece encrypter_key = encrypter->GetKey();
+  QuicStringPiece encrypter_iv = encrypter->GetNoncePrefix();
+  QuicStringPiece decrypter_key = decrypter->GetKey();
+  QuicStringPiece decrypter_iv = decrypter->GetNoncePrefix();
+  CompareCharArraysWithHexError(label + " key", encrypter_key.data(),
+                                encrypter_key.length(), decrypter_key.data(),
+                                decrypter_key.length());
+  CompareCharArraysWithHexError(label + " iv", encrypter_iv.data(),
+                                encrypter_iv.length(), decrypter_iv.data(),
+                                decrypter_iv.length());
+}
+
+}  // namespace
+
 void CompareClientAndServerKeys(QuicCryptoClientStream* client,
                                 QuicCryptoServerStream* server) {
   QuicFramer* client_framer = QuicConnectionPeer::GetFramer(
       QuicStreamPeer::session(client)->connection());
   QuicFramer* server_framer = QuicConnectionPeer::GetFramer(
       QuicStreamPeer::session(server)->connection());
-  const QuicEncrypter* client_encrypter(
-      QuicFramerPeer::GetEncrypter(client_framer, ENCRYPTION_ZERO_RTT));
-  const QuicDecrypter* client_decrypter(
-      QuicStreamPeer::session(client)->connection()->decrypter());
-  const QuicEncrypter* client_forward_secure_encrypter(
-      QuicFramerPeer::GetEncrypter(client_framer, ENCRYPTION_FORWARD_SECURE));
-  const QuicDecrypter* client_forward_secure_decrypter(
-      QuicStreamPeer::session(client)->connection()->alternative_decrypter());
-  const QuicEncrypter* server_encrypter(
-      QuicFramerPeer::GetEncrypter(server_framer, ENCRYPTION_ZERO_RTT));
-  const QuicDecrypter* server_decrypter(
-      QuicStreamPeer::session(server)->connection()->decrypter());
-  const QuicEncrypter* server_forward_secure_encrypter(
-      QuicFramerPeer::GetEncrypter(server_framer, ENCRYPTION_FORWARD_SECURE));
-  const QuicDecrypter* server_forward_secure_decrypter(
-      QuicStreamPeer::session(server)->connection()->alternative_decrypter());
-
-  QuicStringPiece client_encrypter_key = client_encrypter->GetKey();
-  QuicStringPiece client_encrypter_iv = client_encrypter->GetNoncePrefix();
-  QuicStringPiece client_decrypter_key = client_decrypter->GetKey();
-  QuicStringPiece client_decrypter_iv = client_decrypter->GetNoncePrefix();
-  QuicStringPiece client_forward_secure_encrypter_key =
-      client_forward_secure_encrypter->GetKey();
-  QuicStringPiece client_forward_secure_encrypter_iv =
-      client_forward_secure_encrypter->GetNoncePrefix();
-  QuicStringPiece client_forward_secure_decrypter_key =
-      client_forward_secure_decrypter->GetKey();
-  QuicStringPiece client_forward_secure_decrypter_iv =
-      client_forward_secure_decrypter->GetNoncePrefix();
-  QuicStringPiece server_encrypter_key = server_encrypter->GetKey();
-  QuicStringPiece server_encrypter_iv = server_encrypter->GetNoncePrefix();
-  QuicStringPiece server_decrypter_key = server_decrypter->GetKey();
-  QuicStringPiece server_decrypter_iv = server_decrypter->GetNoncePrefix();
-  QuicStringPiece server_forward_secure_encrypter_key =
-      server_forward_secure_encrypter->GetKey();
-  QuicStringPiece server_forward_secure_encrypter_iv =
-      server_forward_secure_encrypter->GetNoncePrefix();
-  QuicStringPiece server_forward_secure_decrypter_key =
-      server_forward_secure_decrypter->GetKey();
-  QuicStringPiece server_forward_secure_decrypter_iv =
-      server_forward_secure_decrypter->GetNoncePrefix();
+  for (EncryptionLevel level :
+       {ENCRYPTION_HANDSHAKE, ENCRYPTION_ZERO_RTT, ENCRYPTION_FORWARD_SECURE}) {
+    SCOPED_TRACE(EncryptionLevelString(level));
+    const QuicEncrypter* client_encrypter(
+        QuicFramerPeer::GetEncrypter(client_framer, level));
+    const QuicDecrypter* server_decrypter(
+        QuicFramerPeer::GetDecrypter(server_framer, level));
+    if (level == ENCRYPTION_FORWARD_SECURE ||
+        !(client_encrypter == nullptr && server_decrypter == nullptr)) {
+      CompareCrypters(client_encrypter, server_decrypter,
+                      "client " + EncryptionLevelString(level) + " write");
+    }
+    const QuicEncrypter* server_encrypter(
+        QuicFramerPeer::GetEncrypter(server_framer, level));
+    const QuicDecrypter* client_decrypter(
+        QuicFramerPeer::GetDecrypter(client_framer, level));
+    if (level == ENCRYPTION_FORWARD_SECURE ||
+        !(server_encrypter == nullptr && client_decrypter == nullptr)) {
+      CompareCrypters(server_encrypter, client_decrypter,
+                      "server " + EncryptionLevelString(level) + " write");
+    }
+  }
 
   QuicStringPiece client_subkey_secret =
       client->crypto_negotiated_params().subkey_secret;
   QuicStringPiece server_subkey_secret =
       server->crypto_negotiated_params().subkey_secret;
+  CompareCharArraysWithHexError("subkey secret", client_subkey_secret.data(),
+                                client_subkey_secret.length(),
+                                server_subkey_secret.data(),
+                                server_subkey_secret.length());
 
   const char kSampleLabel[] = "label";
   const char kSampleContext[] = "context";
   const size_t kSampleOutputLength = 32;
   std::string client_key_extraction;
   std::string server_key_extraction;
-  std::string client_tb_ekm;
-  std::string server_tb_ekm;
   EXPECT_TRUE(client->ExportKeyingMaterial(kSampleLabel, kSampleContext,
                                            kSampleOutputLength,
                                            &client_key_extraction));
   EXPECT_TRUE(server->ExportKeyingMaterial(kSampleLabel, kSampleContext,
                                            kSampleOutputLength,
                                            &server_key_extraction));
-
-  CompareCharArraysWithHexError("client write key", client_encrypter_key.data(),
-                                client_encrypter_key.length(),
-                                server_decrypter_key.data(),
-                                server_decrypter_key.length());
-  CompareCharArraysWithHexError("client write IV", client_encrypter_iv.data(),
-                                client_encrypter_iv.length(),
-                                server_decrypter_iv.data(),
-                                server_decrypter_iv.length());
-  CompareCharArraysWithHexError("server write key", server_encrypter_key.data(),
-                                server_encrypter_key.length(),
-                                client_decrypter_key.data(),
-                                client_decrypter_key.length());
-  CompareCharArraysWithHexError("server write IV", server_encrypter_iv.data(),
-                                server_encrypter_iv.length(),
-                                client_decrypter_iv.data(),
-                                client_decrypter_iv.length());
-  CompareCharArraysWithHexError("client forward secure write key",
-                                client_forward_secure_encrypter_key.data(),
-                                client_forward_secure_encrypter_key.length(),
-                                server_forward_secure_decrypter_key.data(),
-                                server_forward_secure_decrypter_key.length());
-  CompareCharArraysWithHexError("client forward secure write IV",
-                                client_forward_secure_encrypter_iv.data(),
-                                client_forward_secure_encrypter_iv.length(),
-                                server_forward_secure_decrypter_iv.data(),
-                                server_forward_secure_decrypter_iv.length());
-  CompareCharArraysWithHexError("server forward secure write key",
-                                server_forward_secure_encrypter_key.data(),
-                                server_forward_secure_encrypter_key.length(),
-                                client_forward_secure_decrypter_key.data(),
-                                client_forward_secure_decrypter_key.length());
-  CompareCharArraysWithHexError("server forward secure write IV",
-                                server_forward_secure_encrypter_iv.data(),
-                                server_forward_secure_encrypter_iv.length(),
-                                client_forward_secure_decrypter_iv.data(),
-                                client_forward_secure_decrypter_iv.length());
-  CompareCharArraysWithHexError("subkey secret", client_subkey_secret.data(),
-                                client_subkey_secret.length(),
-                                server_subkey_secret.data(),
-                                server_subkey_secret.length());
   CompareCharArraysWithHexError(
       "sample key extraction", client_key_extraction.data(),
       client_key_extraction.length(), server_key_extraction.data(),
       server_key_extraction.length());
-
-  CompareCharArraysWithHexError("token binding key extraction",
-                                client_tb_ekm.data(), client_tb_ekm.length(),
-                                server_tb_ekm.data(), server_tb_ekm.length());
 }
 
 QuicTag ParseTag(const char* tagstr) {
diff --git a/quic/test_tools/quic_framer_peer.cc b/quic/test_tools/quic_framer_peer.cc
index 68c7d0e..5f8068e 100644
--- a/quic/test_tools/quic_framer_peer.cc
+++ b/quic/test_tools/quic_framer_peer.cc
@@ -329,6 +329,12 @@
 }
 
 // static
+QuicDecrypter* QuicFramerPeer::GetDecrypter(QuicFramer* framer,
+                                            EncryptionLevel level) {
+  return framer->decrypter_[level].get();
+}
+
+// static
 size_t QuicFramerPeer::ComputeFrameLength(
     QuicFramer* framer,
     const QuicFrame& frame,
diff --git a/quic/test_tools/quic_framer_peer.h b/quic/test_tools/quic_framer_peer.h
index 6e34be2..e5c5e00 100644
--- a/quic/test_tools/quic_framer_peer.h
+++ b/quic/test_tools/quic_framer_peer.h
@@ -33,6 +33,7 @@
   static void SwapCrypters(QuicFramer* framer1, QuicFramer* framer2);
 
   static QuicEncrypter* GetEncrypter(QuicFramer* framer, EncryptionLevel level);
+  static QuicDecrypter* GetDecrypter(QuicFramer* framer, EncryptionLevel level);
 
   // IETF defined frame append/process methods.
   static bool ProcessIetfStreamFrame(QuicFramer* framer,