Fix QuicSession::WillingAndAbleToWrite to check write keys.

Protected by FLAGS_quic_reloadable_flag_quic_fix_willing_and_able_to_write2.

PiperOrigin-RevId: 339286884
Change-Id: Ifb649dd85bade3a86864b146b20a61ffb56d3b77
diff --git a/quic/core/http/quic_spdy_session_test.cc b/quic/core/http/quic_spdy_session_test.cc
index 84ea617..d95fa6d 100644
--- a/quic/core/http/quic_spdy_session_test.cc
+++ b/quic/core/http/quic_spdy_session_test.cc
@@ -704,6 +704,7 @@
 }
 
 TEST_P(QuicSpdySessionTestServer, OnCanWrite) {
+  CompleteHandshake();
   session_.set_writev_consumes_all_data(true);
   TestStream* stream2 = session_.CreateOutgoingBidirectionalStream();
   TestStream* stream4 = session_.CreateOutgoingBidirectionalStream();
@@ -861,6 +862,7 @@
 }
 
 TEST_P(QuicSpdySessionTestServer, OnCanWriteCongestionControlBlocks) {
+  CompleteHandshake();
   session_.set_writev_consumes_all_data(true);
   InSequence s;
 
@@ -908,6 +910,7 @@
 }
 
 TEST_P(QuicSpdySessionTestServer, OnCanWriteWriterBlocks) {
+  CompleteHandshake();
   // Drive congestion control manually in order to ensure that
   // application-limited signaling is handled correctly.
   MockSendAlgorithm* send_algorithm = new StrictMock<MockSendAlgorithm>;
@@ -1009,6 +1012,7 @@
 
 TEST_P(QuicSpdySessionTestServer,
        OnCanWriteLimitsNumWritesIfFlowControlBlocked) {
+  CompleteHandshake();
   // Drive congestion control manually in order to ensure that
   // application-limited signaling is handled correctly.
   MockSendAlgorithm* send_algorithm = new StrictMock<MockSendAlgorithm>;
@@ -2031,6 +2035,7 @@
 }
 
 TEST_P(QuicSpdySessionTestServer, OnStreamFrameLost) {
+  CompleteHandshake();
   InSequence s;
 
   // Drive congestion control manually.
diff --git a/quic/core/quic_session.cc b/quic/core/quic_session.cc
index a8d9e6f..5079b46 100644
--- a/quic/core/quic_session.cc
+++ b/quic/core/quic_session.cc
@@ -649,9 +649,16 @@
   // 3) If the crypto or headers streams are blocked, or
   // 4) connection is not flow control blocked and there are write blocked
   // streams.
-  if (QuicVersionUsesCryptoFrames(transport_version()) &&
-      HasPendingHandshake()) {
-    return true;
+  if (QuicVersionUsesCryptoFrames(transport_version())) {
+    if (HasPendingHandshake()) {
+      return true;
+    }
+    if (GetQuicReloadableFlag(quic_fix_willing_and_able_to_write2)) {
+      QUIC_RELOADABLE_FLAG_COUNT(quic_fix_willing_and_able_to_write2);
+      if (!IsEncryptionEstablished()) {
+        return false;
+      }
+    }
   }
   if (control_frame_manager_.WillingToWrite() ||
       !streams_with_pending_retransmission_.empty()) {
diff --git a/quic/core/quic_session_test.cc b/quic/core/quic_session_test.cc
index 46ce528..738ee0b 100644
--- a/quic/core/quic_session_test.cc
+++ b/quic/core/quic_session_test.cc
@@ -104,6 +104,9 @@
           session()->config()->ProcessPeerHello(msg, CLIENT, &error_details);
     }
     EXPECT_THAT(error, IsQuicNoError());
+    session()->OnNewEncryptionKeyAvailable(
+        ENCRYPTION_FORWARD_SECURE,
+        std::make_unique<NullEncrypter>(session()->perspective()));
     session()->OnConfigNegotiated();
     if (session()->connection()->version().handshake_protocol ==
         PROTOCOL_TLS1_3) {
@@ -487,6 +490,16 @@
     closed_streams_.insert(id);
   }
 
+  void CompleteHandshake() {
+    CryptoHandshakeMessage msg;
+    if (connection_->version().HasHandshakeDone() &&
+        connection_->perspective() == Perspective::IS_SERVER) {
+      EXPECT_CALL(*connection_, SendControlFrame(_))
+          .WillOnce(Invoke(&ClearControlFrame));
+    }
+    session_.GetMutableCryptoStream()->OnHandshakeMessage(msg);
+  }
+
   QuicTransportVersion transport_version() const {
     return connection_->transport_version();
   }
@@ -630,12 +643,7 @@
 
 TEST_P(QuicSessionTestServer, OneRttKeysAvailable) {
   EXPECT_FALSE(session_.OneRttKeysAvailable());
-  CryptoHandshakeMessage message;
-  if (connection_->version().HasHandshakeDone()) {
-    EXPECT_CALL(*connection_, SendControlFrame(_));
-  }
-  connection_->SetDefaultEncryptionLevel(ENCRYPTION_FORWARD_SECURE);
-  session_.GetMutableCryptoStream()->OnHandshakeMessage(message);
+  CompleteHandshake();
   EXPECT_TRUE(session_.OneRttKeysAvailable());
 }
 
@@ -928,6 +936,7 @@
 }
 
 TEST_P(QuicSessionTestServer, OnCanWrite) {
+  CompleteHandshake();
   session_.set_writev_consumes_all_data(true);
   TestStream* stream2 = session_.CreateOutgoingBidirectionalStream();
   TestStream* stream4 = session_.CreateOutgoingBidirectionalStream();
@@ -1152,15 +1161,9 @@
 
 TEST_P(QuicSessionTestServer, OnCanWriteBundlesStreams) {
   // Encryption needs to be established before data can be sent.
-  if (connection_->version().HasHandshakeDone()) {
-    EXPECT_CALL(*connection_, SendControlFrame(_))
-        .WillRepeatedly(Invoke(&ClearControlFrame));
-  }
-  CryptoHandshakeMessage msg;
+  CompleteHandshake();
   MockPacketWriter* writer = static_cast<MockPacketWriter*>(
       QuicConnectionPeer::GetWriter(session_.connection()));
-  connection_->SetDefaultEncryptionLevel(ENCRYPTION_FORWARD_SECURE);
-  session_.GetMutableCryptoStream()->OnHandshakeMessage(msg);
 
   // Drive congestion control manually.
   MockSendAlgorithm* send_algorithm = new StrictMock<MockSendAlgorithm>;
@@ -1199,6 +1202,7 @@
 }
 
 TEST_P(QuicSessionTestServer, OnCanWriteCongestionControlBlocks) {
+  CompleteHandshake();
   session_.set_writev_consumes_all_data(true);
   InSequence s;
 
@@ -1246,6 +1250,7 @@
 }
 
 TEST_P(QuicSessionTestServer, OnCanWriteWriterBlocks) {
+  CompleteHandshake();
   // Drive congestion control manually in order to ensure that
   // application-limited signaling is handled correctly.
   MockSendAlgorithm* send_algorithm = new StrictMock<MockSendAlgorithm>;
@@ -2373,6 +2378,7 @@
 }
 
 TEST_P(QuicSessionTestServer, OnStreamFrameLost) {
+  CompleteHandshake();
   InSequence s;
 
   // Drive congestion control manually.