Refactor TlsServerHandshaker to implement QuicCryptoServerStreamBase directly

QuicCryptoServerStreamBase, QuicCryptoServerStream, and
QuicCryptoServerStream::HandshakerInterface together form one more layer
of abstraction than is needed. Currently, QuicCryptoServerStream acts as
little more than an intermediary between QuicCryptoServerStreamBase and
its HandshakerInterface implementation. Instead, each HandshakerInterface
implementation could implement QuicCryptoServerStreamBase directly. This
CL makes that change for the TLS side: TlsServerHandshaker now derives
from QuicCryptoServerStreamBase, the crypto server stream factory returns
a TlsServerHandshaker for PROTOCOL_TLS1_3, and QuicCryptoServerStream
reports a QUIC_BUG if it is used with TLS. A future CL will collapse
QuicCryptoServerStream and QuicCryptoServerHandshaker into the same
class.

gfe-relnote: refactor of TLS QUIC code, protected by multiple QUIC version flags
PiperOrigin-RevId: 290819354
Change-Id: Ia48d001bf0d1a7fb43863d22137ad0d0b897ad7b
diff --git a/quic/core/http/quic_server_session_base_test.cc b/quic/core/http/quic_server_session_base_test.cc
index d8063d1..13e264d 100644
--- a/quic/core/http/quic_server_session_base_test.cc
+++ b/quic/core/http/quic_server_session_base_test.cc
@@ -15,6 +15,7 @@
 #include "net/third_party/quiche/src/quic/core/quic_connection.h"
 #include "net/third_party/quiche/src/quic/core/quic_crypto_server_stream.h"
 #include "net/third_party/quiche/src/quic/core/quic_utils.h"
+#include "net/third_party/quiche/src/quic/core/tls_server_handshaker.h"
 #include "net/third_party/quiche/src/quic/platform/api/quic_expect_bug.h"
 #include "net/third_party/quiche/src/quic/platform/api/quic_flags.h"
 #include "net/third_party/quiche/src/quic/platform/api/quic_ptr_util.h"
@@ -479,6 +480,20 @@
                void(const CachedNetworkParameters* cached_network_parameters));
 };
 
+class MockTlsServerHandshaker : public TlsServerHandshaker {
+ public:
+  explicit MockTlsServerHandshaker(QuicServerSessionBase* session,
+                                   SSL_CTX* ssl_ctx,
+                                   ProofSource* proof_source)
+      : TlsServerHandshaker(session, ssl_ctx, proof_source) {}
+  MockTlsServerHandshaker(const MockTlsServerHandshaker&) = delete;
+  MockTlsServerHandshaker& operator=(const MockTlsServerHandshaker&) = delete;
+  ~MockTlsServerHandshaker() override {}
+
+  MOCK_METHOD1(SendServerConfigUpdate,
+               void(const CachedNetworkParameters* cached_network_parameters));
+};
+
 TEST_P(QuicServerSessionBaseTest, BandwidthEstimates) {
   // Test that bandwidth estimate updates are sent to the client, only when
   // bandwidth resumption is enabled, the bandwidth estimate has changed
@@ -505,10 +520,22 @@
         /*is_static=*/true);
   }
   QuicServerSessionBasePeer::SetCryptoStream(session_.get(), nullptr);
-  MockQuicCryptoServerStream* crypto_stream =
-      new MockQuicCryptoServerStream(&crypto_config_, &compressed_certs_cache_,
-                                     session_.get(), &stream_helper_);
-  QuicServerSessionBasePeer::SetCryptoStream(session_.get(), crypto_stream);
+  MockQuicCryptoServerStream* quic_crypto_stream = nullptr;
+  MockTlsServerHandshaker* tls_server_stream = nullptr;
+  if (session_->connection()->version().handshake_protocol ==
+      PROTOCOL_QUIC_CRYPTO) {
+    quic_crypto_stream = new MockQuicCryptoServerStream(
+        &crypto_config_, &compressed_certs_cache_, session_.get(),
+        &stream_helper_);
+    QuicServerSessionBasePeer::SetCryptoStream(session_.get(),
+                                               quic_crypto_stream);
+  } else {
+    tls_server_stream =
+        new MockTlsServerHandshaker(session_.get(), crypto_config_.ssl_ctx(),
+                                    crypto_config_.proof_source());
+    QuicServerSessionBasePeer::SetCryptoStream(session_.get(),
+                                               tls_server_stream);
+  }
   if (!VersionUsesHttp3(transport_version())) {
     session_->RegisterStreamPriority(
         QuicUtils::GetHeadersStreamId(connection_->transport_version()),
@@ -592,9 +619,15 @@
       session_->connection()->clock()->WallNow().ToUNIXSeconds());
   expected_network_params.set_serving_region(serving_region);
 
-  EXPECT_CALL(*crypto_stream,
-              SendServerConfigUpdate(EqualsProto(expected_network_params)))
-      .Times(1);
+  if (quic_crypto_stream) {
+    EXPECT_CALL(*quic_crypto_stream,
+                SendServerConfigUpdate(EqualsProto(expected_network_params)))
+        .Times(1);
+  } else {
+    EXPECT_CALL(*tls_server_stream,
+                SendServerConfigUpdate(EqualsProto(expected_network_params)))
+        .Times(1);
+  }
   EXPECT_CALL(*connection_, OnSendConnectionState(_)).Times(1);
   session_->OnCongestionWindowChange(now);
 }
diff --git a/quic/core/quic_crypto_server_stream.cc b/quic/core/quic_crypto_server_stream.cc
index abc681f..f4a514d 100644
--- a/quic/core/quic_crypto_server_stream.cc
+++ b/quic/core/quic_crypto_server_stream.cc
@@ -32,8 +32,20 @@
     QuicCompressedCertsCache* compressed_certs_cache,
     QuicSession* session,
     QuicCryptoServerStream::Helper* helper) {
-  return std::unique_ptr<QuicCryptoServerStream>(new QuicCryptoServerStream(
-      crypto_config, compressed_certs_cache, session, helper));
+  switch (session->connection()->version().handshake_protocol) {
+    case PROTOCOL_QUIC_CRYPTO:
+      return std::unique_ptr<QuicCryptoServerStream>(new QuicCryptoServerStream(
+          crypto_config, compressed_certs_cache, session, helper));
+    case PROTOCOL_TLS1_3:
+      return std::unique_ptr<TlsServerHandshaker>(new TlsServerHandshaker(
+          session, crypto_config->ssl_ctx(), crypto_config->proof_source()));
+    case PROTOCOL_UNSUPPORTED:
+      break;
+  }
+  QUIC_BUG << "Unknown handshake protocol: "
+           << static_cast<int>(
+                  session->connection()->version().handshake_protocol);
+  return nullptr;
 }
 
 QuicCryptoServerStream::QuicCryptoServerStream(
@@ -68,9 +80,8 @@
             crypto_config_, this, compressed_certs_cache_, session, helper_);
         break;
       case PROTOCOL_TLS1_3:
-        handshaker_ = std::make_unique<TlsServerHandshaker>(
-            this, session, crypto_config_->ssl_ctx(),
-            crypto_config_->proof_source());
+        QUIC_BUG
+            << "Attempting to create QuicCryptoServerStream for TLS version";
         break;
       case PROTOCOL_UNSUPPORTED:
         QUIC_BUG << "Attempting to create QuicCryptoServerStream for unknown "
@@ -176,9 +187,7 @@
           crypto_config_, this, compressed_certs_cache_, session(), helper_);
       break;
     case PROTOCOL_TLS1_3:
-      handshaker_ = std::make_unique<TlsServerHandshaker>(
-          this, session(), crypto_config_->ssl_ctx(),
-          crypto_config_->proof_source());
+      QUIC_BUG << "Attempting to use QuicCryptoServerStream with TLS";
       break;
     case PROTOCOL_UNSUPPORTED:
       QUIC_BUG << "Attempting to create QuicCryptoServerStream for unknown "
diff --git a/quic/core/quic_crypto_server_stream_test.cc b/quic/core/quic_crypto_server_stream_test.cc
index 9b13173..9d78641 100644
--- a/quic/core/quic_crypto_server_stream_test.cc
+++ b/quic/core/quic_crypto_server_stream_test.cc
@@ -108,7 +108,7 @@
     }
   }
 
-  QuicCryptoServerStream* server_stream() {
+  QuicCryptoServerStreamBase* server_stream() {
     return server_session_->GetMutableCryptoStream();
   }
 
diff --git a/quic/core/tls_client_handshaker.cc b/quic/core/tls_client_handshaker.cc
index 2f8ba67..da549f3 100644
--- a/quic/core/tls_client_handshaker.cc
+++ b/quic/core/tls_client_handshaker.cc
@@ -51,6 +51,7 @@
     QuicCryptoClientConfig* crypto_config,
     QuicCryptoClientStream::ProofHandler* proof_handler)
     : TlsHandshaker(stream, session),
+      session_(session),
       server_id_(server_id),
       proof_verifier_(crypto_config->proof_verifier()),
       verify_context_(std::move(verify_context)),
diff --git a/quic/core/tls_client_handshaker.h b/quic/core/tls_client_handshaker.h
index a2fa96a..c8f58bd 100644
--- a/quic/core/tls_client_handshaker.h
+++ b/quic/core/tls_client_handshaker.h
@@ -108,6 +108,9 @@
 
   void InsertSession(bssl::UniquePtr<SSL_SESSION> session) override;
 
+  QuicSession* session() { return session_; }
+  QuicSession* session_;
+
   QuicServerId server_id_;
 
   // Objects used for verifying the server's certificate chain.
diff --git a/quic/core/tls_handshaker.cc b/quic/core/tls_handshaker.cc
index a8fcf3a..c5effc3 100644
--- a/quic/core/tls_handshaker.cc
+++ b/quic/core/tls_handshaker.cc
@@ -15,7 +15,7 @@
 namespace quic {
 
 TlsHandshaker::TlsHandshaker(QuicCryptoStream* stream, QuicSession* session)
-    : stream_(stream), session_(session), delegate_(session) {}
+    : stream_(stream), delegate_(session) {}
 
 TlsHandshaker::~TlsHandshaker() {}
 
diff --git a/quic/core/tls_handshaker.h b/quic/core/tls_handshaker.h
index 88e5f1d..6fa22d0 100644
--- a/quic/core/tls_handshaker.h
+++ b/quic/core/tls_handshaker.h
@@ -67,7 +67,6 @@
   SSL* ssl() const { return tls_connection()->ssl(); }
 
   QuicCryptoStream* stream() { return stream_; }
-  QuicSession* session() { return session_; }
   HandshakerDelegateInterface* delegate() { return delegate_; }
 
   // SetEncryptionSecret provides the encryption secret to use at a particular
@@ -97,7 +96,6 @@
 
  private:
   QuicCryptoStream* stream_;
-  QuicSession* session_;
   HandshakerDelegateInterface* delegate_;
 
   QuicErrorCode parser_error_ = QUIC_NO_ERROR;
diff --git a/quic/core/tls_handshaker_test.cc b/quic/core/tls_handshaker_test.cc
index b175d55..6812549 100644
--- a/quic/core/tls_handshaker_test.cc
+++ b/quic/core/tls_handshaker_test.cc
@@ -267,6 +267,24 @@
   std::unique_ptr<TlsClientHandshaker> handshaker_;
 };
 
+class TestTlsServerHandshaker : public TlsServerHandshaker {
+ public:
+  TestTlsServerHandshaker(QuicSession* session,
+                          SSL_CTX* ssl_ctx,
+                          ProofSource* proof_source,
+                          TestQuicCryptoStream* test_stream)
+      : TlsServerHandshaker(session, ssl_ctx, proof_source),
+        test_stream_(test_stream) {}
+
+  void WriteCryptoData(EncryptionLevel level,
+                       quiche::QuicheStringPiece data) override {
+    test_stream_->WriteCryptoData(level, data);
+  }
+
+ private:
+  TestQuicCryptoStream* test_stream_;
+};
+
 class TestQuicCryptoServerStream : public TestQuicCryptoStream {
  public:
   TestQuicCryptoServerStream(QuicSession* session,
@@ -274,10 +292,10 @@
       : TestQuicCryptoStream(session),
         proof_source_(proof_source),
         ssl_ctx_(TlsServerConnection::CreateSslCtx()),
-        handshaker_(new TlsServerHandshaker(this,
-                                            session,
-                                            ssl_ctx_.get(),
-                                            proof_source_)) {}
+        handshaker_(new TestTlsServerHandshaker(session,
+                                                ssl_ctx_.get(),
+                                                proof_source_,
+                                                this)) {}
 
   ~TestQuicCryptoServerStream() override = default;
 
@@ -300,7 +318,7 @@
 };
 
 void ExchangeHandshakeMessages(TestQuicCryptoStream* client,
-                               TestQuicCryptoStream* server) {
+                               TestQuicCryptoServerStream* server) {
   while (!client->pending_writes().empty() ||
          !server->pending_writes().empty()) {
     client->SendCryptoMessagesToPeer(server);
diff --git a/quic/core/tls_server_handshaker.cc b/quic/core/tls_server_handshaker.cc
index 0cb7a29..b22d964 100644
--- a/quic/core/tls_server_handshaker.cc
+++ b/quic/core/tls_server_handshaker.cc
@@ -41,11 +41,11 @@
   handshaker_ = nullptr;
 }
 
-TlsServerHandshaker::TlsServerHandshaker(QuicCryptoStream* stream,
-                                         QuicSession* session,
+TlsServerHandshaker::TlsServerHandshaker(QuicSession* session,
                                          SSL_CTX* ssl_ctx,
                                          ProofSource* proof_source)
-    : TlsHandshaker(stream, session),
+    : TlsHandshaker(this, session),
+      QuicCryptoServerStreamBase(session),
       proof_source_(proof_source),
       crypto_negotiated_params_(new QuicCryptoNegotiatedParameters),
       tls_connection_(ssl_ctx, this) {
diff --git a/quic/core/tls_server_handshaker.h b/quic/core/tls_server_handshaker.h
index 919af91..f2341cf 100644
--- a/quic/core/tls_server_handshaker.h
+++ b/quic/core/tls_server_handshaker.h
@@ -19,15 +19,14 @@
 
 namespace quic {
 
-// An implementation of QuicCryptoServerStream::HandshakerInterface which uses
+// An implementation of QuicCryptoServerStreamBase which uses
 // TLS 1.3 for the crypto handshake protocol.
 class QUIC_EXPORT_PRIVATE TlsServerHandshaker
     : public TlsHandshaker,
       public TlsServerConnection::Delegate,
-      public QuicCryptoServerStream::HandshakerInterface {
+      public QuicCryptoServerStreamBase {
  public:
-  TlsServerHandshaker(QuicCryptoStream* stream,
-                      QuicSession* session,
+  TlsServerHandshaker(QuicSession* session,
                       SSL_CTX* ssl_ctx,
                       ProofSource* proof_source);
   TlsServerHandshaker(const TlsServerHandshaker&) = delete;
@@ -35,7 +34,7 @@
 
   ~TlsServerHandshaker() override;
 
-  // From QuicCryptoServerStream::HandshakerInterface
+  // From QuicCryptoServerStreamBase
   void CancelOutstandingCallbacks() override;
   bool GetBase64SHA256ClientChannelID(std::string* output) const override;
   void SendServerConfigUpdate(
@@ -50,7 +49,7 @@
   void OnPacketDecrypted(EncryptionLevel level) override;
   bool ShouldSendExpectCTHeader() const override;
 
-  // From QuicCryptoServerStream::HandshakerInterface and TlsHandshaker
+  // From QuicCryptoServerStreamBase and TlsHandshaker
   bool encryption_established() const override;
   bool one_rtt_keys_available() const override;
   const QuicCryptoNegotiatedParameters& crypto_negotiated_params()