Limit the amount of incoming crypto data that will be buffered.

gfe-relnote: protected by disabled flag-protected QUIC_VERSION_48
PiperOrigin-RevId: 266019141
Change-Id: Ife996bdf80a28b3bcce4b02cda49bff0fd23a071
diff --git a/quic/core/crypto/tls_connection.h b/quic/core/crypto/tls_connection.h
index c15d920..4774ba6 100644
--- a/quic/core/crypto/tls_connection.h
+++ b/quic/core/crypto/tls_connection.h
@@ -69,7 +69,9 @@
   static enum ssl_encryption_level_t BoringEncryptionLevel(
       EncryptionLevel level);
 
-  SSL* ssl() { return ssl_.get(); }
+  // Returns the underlying SSL object. Declared const so that it can be
+  // called through a const TlsConnection.
+  SSL* ssl() const { return ssl_.get(); }
 
  protected:
   // TlsConnection does not take ownership of any of its arguments; they must
diff --git a/quic/core/quic_crypto_client_handshaker.cc b/quic/core/quic_crypto_client_handshaker.cc
index 158285d..ac15301 100644
--- a/quic/core/quic_crypto_client_handshaker.cc
+++ b/quic/core/quic_crypto_client_handshaker.cc
@@ -141,6 +141,13 @@
   return QuicCryptoHandshaker::crypto_message_parser();
 }
 
+// Defers to QuicCryptoHandshaker, which applies the same flag-controlled limit
+// (FLAGS_quic_max_buffered_crypto_bytes) at every encryption level.
+size_t QuicCryptoClientHandshaker::BufferSizeLimitForLevel(
+    EncryptionLevel level) const {
+  return QuicCryptoHandshaker::BufferSizeLimitForLevel(level);
+}
+
 void QuicCryptoClientHandshaker::HandleServerConfigUpdateMessage(
     const CryptoHandshakeMessage& server_config_update) {
   DCHECK(server_config_update.tag() == kSCUP);
diff --git a/quic/core/quic_crypto_client_handshaker.h b/quic/core/quic_crypto_client_handshaker.h
index 5b7ce35..d33ebfe 100644
--- a/quic/core/quic_crypto_client_handshaker.h
+++ b/quic/core/quic_crypto_client_handshaker.h
@@ -44,6 +44,7 @@
   const QuicCryptoNegotiatedParameters& crypto_negotiated_params()
       const override;
   CryptoMessageParser* crypto_message_parser() override;
+  size_t BufferSizeLimitForLevel(EncryptionLevel level) const override;
 
   // From QuicCryptoHandshaker
   void OnHandshakeMessage(const CryptoHandshakeMessage& message) override;
diff --git a/quic/core/quic_crypto_client_stream.cc b/quic/core/quic_crypto_client_stream.cc
index 93f2a61..8f89e9a 100644
--- a/quic/core/quic_crypto_client_stream.cc
+++ b/quic/core/quic_crypto_client_stream.cc
@@ -84,6 +84,13 @@
   return handshaker_->crypto_message_parser();
 }
 
+// Delegates to the handshaker because QUIC crypto and TLS handshakers compute
+// this limit differently.
+size_t QuicCryptoClientStream::BufferSizeLimitForLevel(
+    EncryptionLevel level) const {
+  return handshaker_->BufferSizeLimitForLevel(level);
+}
+
 std::string QuicCryptoClientStream::chlo_hash() const {
   return handshaker_->chlo_hash();
 }
diff --git a/quic/core/quic_crypto_client_stream.h b/quic/core/quic_crypto_client_stream.h
index b8dff7e..89f0d2e 100644
--- a/quic/core/quic_crypto_client_stream.h
+++ b/quic/core/quic_crypto_client_stream.h
@@ -99,6 +99,10 @@
 
     // Used by QuicCryptoStream to parse data received on this stream.
     virtual CryptoMessageParser* crypto_message_parser() = 0;
+
+    // Used by QuicCryptoStream to determine how many bytes of unprocessed
+    // crypto data may be buffered at encryption level |level|.
+    virtual size_t BufferSizeLimitForLevel(EncryptionLevel level) const = 0;
   };
 
   // ProofHandler is an interface that handles callbacks from the crypto
@@ -142,6 +146,7 @@
   const QuicCryptoNegotiatedParameters& crypto_negotiated_params()
       const override;
   CryptoMessageParser* crypto_message_parser() override;
+  size_t BufferSizeLimitForLevel(EncryptionLevel level) const override;
 
   std::string chlo_hash() const;
 
diff --git a/quic/core/quic_crypto_handshaker.cc b/quic/core/quic_crypto_handshaker.cc
index fa0f78a..d608ead 100644
--- a/quic/core/quic_crypto_handshaker.cc
+++ b/quic/core/quic_crypto_handshaker.cc
@@ -45,5 +45,11 @@
   return &crypto_framer_;
 }
 
+// Applies the same flag-controlled limit at every encryption level; the level
+// argument is unused.
+size_t QuicCryptoHandshaker::BufferSizeLimitForLevel(EncryptionLevel) const {
+  return GetQuicFlag(FLAGS_quic_max_buffered_crypto_bytes);
+}
+
 #undef ENDPOINT  // undef for jumbo builds
 }  // namespace quic
diff --git a/quic/core/quic_crypto_handshaker.h b/quic/core/quic_crypto_handshaker.h
index 231acfc..e5d8d51 100644
--- a/quic/core/quic_crypto_handshaker.h
+++ b/quic/core/quic_crypto_handshaker.h
@@ -27,6 +27,9 @@
   void OnHandshakeMessage(const CryptoHandshakeMessage& message) override;
 
   CryptoMessageParser* crypto_message_parser();
+  // Returns how many bytes of unprocessed crypto data may be buffered at
+  // |level|. The same limit is returned for every level.
+  size_t BufferSizeLimitForLevel(EncryptionLevel level) const;
 
  protected:
   QuicTag last_sent_handshake_message_tag() const {
diff --git a/quic/core/quic_crypto_server_handshaker.cc b/quic/core/quic_crypto_server_handshaker.cc
index 26134d7..ab381d7 100644
--- a/quic/core/quic_crypto_server_handshaker.cc
+++ b/quic/core/quic_crypto_server_handshaker.cc
@@ -373,6 +373,13 @@
   return QuicCryptoHandshaker::crypto_message_parser();
 }
 
+// Defers to QuicCryptoHandshaker, which applies the same flag-controlled limit
+// (FLAGS_quic_max_buffered_crypto_bytes) at every encryption level.
+size_t QuicCryptoServerHandshaker::BufferSizeLimitForLevel(
+    EncryptionLevel level) const {
+  return QuicCryptoHandshaker::BufferSizeLimitForLevel(level);
+}
+
 void QuicCryptoServerHandshaker::ProcessClientHello(
     QuicReferenceCountedPointer<ValidateClientHelloResultCallback::Result>
         result,
diff --git a/quic/core/quic_crypto_server_handshaker.h b/quic/core/quic_crypto_server_handshaker.h
index f664408..4b7b3ba 100644
--- a/quic/core/quic_crypto_server_handshaker.h
+++ b/quic/core/quic_crypto_server_handshaker.h
@@ -58,6 +58,7 @@
   const QuicCryptoNegotiatedParameters& crypto_negotiated_params()
       const override;
   CryptoMessageParser* crypto_message_parser() override;
+  size_t BufferSizeLimitForLevel(EncryptionLevel level) const override;
 
   // From QuicCryptoHandshaker
   void OnHandshakeMessage(const CryptoHandshakeMessage& message) override;
diff --git a/quic/core/quic_crypto_server_stream.cc b/quic/core/quic_crypto_server_stream.cc
index 09344cd..353de98 100644
--- a/quic/core/quic_crypto_server_stream.cc
+++ b/quic/core/quic_crypto_server_stream.cc
@@ -111,6 +111,13 @@
   return handshaker()->crypto_message_parser();
 }
 
+// Delegates to the handshaker because QUIC crypto and TLS handshakers compute
+// this limit differently.
+size_t QuicCryptoServerStream::BufferSizeLimitForLevel(
+    EncryptionLevel level) const {
+  return handshaker()->BufferSizeLimitForLevel(level);
+}
+
 void QuicCryptoServerStream::OnSuccessfulVersionNegotiation(
     const ParsedQuicVersion& version) {
   DCHECK_EQ(version, session()->connection()->version());
diff --git a/quic/core/quic_crypto_server_stream.h b/quic/core/quic_crypto_server_stream.h
index 60a162c..3a7d6e7 100644
--- a/quic/core/quic_crypto_server_stream.h
+++ b/quic/core/quic_crypto_server_stream.h
@@ -119,6 +119,10 @@
 
     // Used by QuicCryptoStream to parse data received on this stream.
     virtual CryptoMessageParser* crypto_message_parser() = 0;
+
+    // Used by QuicCryptoStream to determine how many bytes of unprocessed
+    // crypto data may be buffered at encryption level |level|.
+    virtual size_t BufferSizeLimitForLevel(EncryptionLevel level) const = 0;
   };
 
   class Helper {
@@ -172,6 +176,7 @@
   const QuicCryptoNegotiatedParameters& crypto_negotiated_params()
       const override;
   CryptoMessageParser* crypto_message_parser() override;
+  size_t BufferSizeLimitForLevel(EncryptionLevel level) const override;
   void OnSuccessfulVersionNegotiation(
       const ParsedQuicVersion& version) override;
 
diff --git a/quic/core/quic_crypto_stream.cc b/quic/core/quic_crypto_stream.cc
index 9d70139..e5b13f1 100644
--- a/quic/core/quic_crypto_stream.cc
+++ b/quic/core/quic_crypto_stream.cc
@@ -78,6 +78,13 @@
       << "Versions less than 47 shouldn't receive CRYPTO frames";
   EncryptionLevel level = session()->connection()->last_decrypted_level();
   substreams_[level].sequencer.OnCryptoFrame(frame);
+  // Enforce the limit of the same level whose substream just buffered the
+  // data; using frame.level here could compare against another level's limit.
+  if (substreams_[level].sequencer.NumBytesBuffered() >
+      BufferSizeLimitForLevel(level)) {
+    CloseConnectionWithDetails(QUIC_FLOW_CONTROL_RECEIVED_TOO_MUCH_DATA,
+                               "Too much crypto data received");
+  }
 }
 
 void QuicCryptoStream::OnStreamFrame(const QuicStreamFrame& frame) {
@@ -181,6 +188,11 @@
   send_buffer->OnStreamDataConsumed(bytes_consumed);
 }
 
+// Default limit: the same flag-controlled value at every encryption level.
+size_t QuicCryptoStream::BufferSizeLimitForLevel(EncryptionLevel) const {
+  return GetQuicFlag(FLAGS_quic_max_buffered_crypto_bytes);
+}
+
 void QuicCryptoStream::OnSuccessfulVersionNegotiation(
     const ParsedQuicVersion& /*version*/) {}
 
diff --git a/quic/core/quic_crypto_stream.h b/quic/core/quic_crypto_stream.h
index 01523ee..12a36f8 100644
--- a/quic/core/quic_crypto_stream.h
+++ b/quic/core/quic_crypto_stream.h
@@ -80,6 +80,10 @@
   // Provides the message parser to use when data is received on this stream.
   virtual CryptoMessageParser* crypto_message_parser() = 0;
 
+  // Returns the maximum number of bytes of unprocessed crypto data that may be
+  // buffered at encryption level |level|; exceeding it closes the connection.
+  virtual size_t BufferSizeLimitForLevel(EncryptionLevel level) const;
+
   // Called when the underlying QuicConnection has agreed upon a QUIC version to
   // use.
   virtual void OnSuccessfulVersionNegotiation(const ParsedQuicVersion& version);
diff --git a/quic/core/quic_crypto_stream_test.cc b/quic/core/quic_crypto_stream_test.cc
index 9d3bda5..af5a8cc 100644
--- a/quic/core/quic_crypto_stream_test.cc
+++ b/quic/core/quic_crypto_stream_test.cc
@@ -570,6 +570,24 @@
   EXPECT_FALSE(stream_->HasBufferedCryptoFrames());
 }
 
+TEST_F(QuicCryptoStreamTest, LimitBufferedCryptoData) {
+  if (!QuicVersionUsesCryptoFrames(connection_->transport_version())) {
+    return;
+  }
+
+  // Buffering twice the flag-controlled limit should close the connection.
+  EXPECT_CALL(*connection_,
+              CloseConnection(QUIC_FLOW_CONTROL_RECEIVED_TOO_MUCH_DATA, _, _));
+  std::string large_frame(2 * GetQuicFlag(FLAGS_quic_max_buffered_crypto_bytes),
+                          'a');
+
+  // Set offset to 1 so that we guarantee the data gets buffered instead of
+  // immediately processed.
+  QuicStreamOffset offset = 1;
+  stream_->OnCryptoFrame(
+      QuicCryptoFrame(ENCRYPTION_INITIAL, offset, large_frame));
+}
+
 }  // namespace
 }  // namespace test
 }  // namespace quic
diff --git a/quic/core/tls_client_handshaker.cc b/quic/core/tls_client_handshaker.cc
index f362f1a..8e2d2bb 100644
--- a/quic/core/tls_client_handshaker.cc
+++ b/quic/core/tls_client_handshaker.cc
@@ -208,6 +208,13 @@
   return TlsHandshaker::crypto_message_parser();
 }
 
+// Defers to TlsHandshaker, which asks BoringSSL for the maximum handshake
+// flight length at |level|.
+size_t TlsClientHandshaker::BufferSizeLimitForLevel(
+    EncryptionLevel level) const {
+  return TlsHandshaker::BufferSizeLimitForLevel(level);
+}
+
 void TlsClientHandshaker::AdvanceHandshake() {
   if (state_ == STATE_CONNECTION_CLOSED) {
     QUIC_LOG(INFO)
diff --git a/quic/core/tls_client_handshaker.h b/quic/core/tls_client_handshaker.h
index d2d0a4a..47faf81 100644
--- a/quic/core/tls_client_handshaker.h
+++ b/quic/core/tls_client_handshaker.h
@@ -52,9 +52,12 @@
   const QuicCryptoNegotiatedParameters& crypto_negotiated_params()
       const override;
   CryptoMessageParser* crypto_message_parser() override;
+  size_t BufferSizeLimitForLevel(EncryptionLevel level) const override;
 
  protected:
-  TlsConnection* tls_connection() override { return &tls_connection_; }
+  const TlsConnection* tls_connection() const override {
+    return &tls_connection_;
+  }
 
   void AdvanceHandshake() override;
   void CloseConnection(QuicErrorCode error,
diff --git a/quic/core/tls_handshaker.cc b/quic/core/tls_handshaker.cc
index db50f5a..e6e59fd 100644
--- a/quic/core/tls_handshaker.cc
+++ b/quic/core/tls_handshaker.cc
@@ -54,6 +54,13 @@
   return true;
 }
 
+// Asks BoringSSL for the maximum handshake flight length at |level|, which
+// may differ between encryption levels.
+size_t TlsHandshaker::BufferSizeLimitForLevel(EncryptionLevel level) const {
+  return SSL_quic_max_handshake_flight_len(
+      ssl(), TlsConnection::BoringEncryptionLevel(level));
+}
+
 const EVP_MD* TlsHandshaker::Prf() {
   return EVP_get_digestbynid(
       SSL_CIPHER_get_prf_nid(SSL_get_pending_cipher(ssl())));
diff --git a/quic/core/tls_handshaker.h b/quic/core/tls_handshaker.h
index b4f16e8..7d5b9bc 100644
--- a/quic/core/tls_handshaker.h
+++ b/quic/core/tls_handshaker.h
@@ -50,6 +50,9 @@
   virtual const QuicCryptoNegotiatedParameters& crypto_negotiated_params()
       const = 0;
   virtual CryptoMessageParser* crypto_message_parser() { return this; }
+  // Returns how many bytes of unprocessed crypto data may be buffered at
+  // |level|, as reported by BoringSSL.
+  size_t BufferSizeLimitForLevel(EncryptionLevel level) const;
 
  protected:
   virtual void AdvanceHandshake() = 0;
@@ -65,9 +68,9 @@
   std::unique_ptr<QuicDecrypter> CreateDecrypter(
       const std::vector<uint8_t>& pp_secret);
 
-  virtual TlsConnection* tls_connection() = 0;
+  virtual const TlsConnection* tls_connection() const = 0;
 
-  SSL* ssl() { return tls_connection()->ssl(); }
+  SSL* ssl() const { return tls_connection()->ssl(); }
 
   QuicCryptoStream* stream() { return stream_; }
   QuicSession* session() { return session_; }
diff --git a/quic/core/tls_server_handshaker.cc b/quic/core/tls_server_handshaker.cc
index eea6f55..a2ea397 100644
--- a/quic/core/tls_server_handshaker.cc
+++ b/quic/core/tls_server_handshaker.cc
@@ -136,6 +136,13 @@
   return TlsHandshaker::crypto_message_parser();
 }
 
+// Defers to TlsHandshaker, which asks BoringSSL for the maximum handshake
+// flight length at |level|.
+size_t TlsServerHandshaker::BufferSizeLimitForLevel(
+    EncryptionLevel level) const {
+  return TlsHandshaker::BufferSizeLimitForLevel(level);
+}
+
 void TlsServerHandshaker::AdvanceHandshake() {
   if (state_ == STATE_CONNECTION_CLOSED) {
     QUIC_LOG(INFO) << "TlsServerHandshaker received handshake message after "
diff --git a/quic/core/tls_server_handshaker.h b/quic/core/tls_server_handshaker.h
index 5ce699f..829aeaf 100644
--- a/quic/core/tls_server_handshaker.h
+++ b/quic/core/tls_server_handshaker.h
@@ -58,9 +58,12 @@
   const QuicCryptoNegotiatedParameters& crypto_negotiated_params()
       const override;
   CryptoMessageParser* crypto_message_parser() override;
+  size_t BufferSizeLimitForLevel(EncryptionLevel level) const override;
 
  protected:
-  TlsConnection* tls_connection() override { return &tls_connection_; }
+  const TlsConnection* tls_connection() const override {
+    return &tls_connection_;
+  }
 
   // Called when a new message is received on the crypto stream and is available
   // for the TLS stack to read.