Fix QuicSession::WillingAndAbleToWrite to check write keys. Protected by FLAGS_quic_reloadable_flag_quic_fix_willing_and_able_to_write2. PiperOrigin-RevId: 339286884 Change-Id: Ifb649dd85bade3a86864b146b20a61ffb56d3b77
diff --git a/quic/core/http/quic_spdy_session_test.cc b/quic/core/http/quic_spdy_session_test.cc index 84ea617..d95fa6d 100644 --- a/quic/core/http/quic_spdy_session_test.cc +++ b/quic/core/http/quic_spdy_session_test.cc
@@ -704,6 +704,7 @@ } TEST_P(QuicSpdySessionTestServer, OnCanWrite) { + CompleteHandshake(); session_.set_writev_consumes_all_data(true); TestStream* stream2 = session_.CreateOutgoingBidirectionalStream(); TestStream* stream4 = session_.CreateOutgoingBidirectionalStream(); @@ -861,6 +862,7 @@ } TEST_P(QuicSpdySessionTestServer, OnCanWriteCongestionControlBlocks) { + CompleteHandshake(); session_.set_writev_consumes_all_data(true); InSequence s; @@ -908,6 +910,7 @@ } TEST_P(QuicSpdySessionTestServer, OnCanWriteWriterBlocks) { + CompleteHandshake(); // Drive congestion control manually in order to ensure that // application-limited signaling is handled correctly. MockSendAlgorithm* send_algorithm = new StrictMock<MockSendAlgorithm>; @@ -1009,6 +1012,7 @@ TEST_P(QuicSpdySessionTestServer, OnCanWriteLimitsNumWritesIfFlowControlBlocked) { + CompleteHandshake(); // Drive congestion control manually in order to ensure that // application-limited signaling is handled correctly. MockSendAlgorithm* send_algorithm = new StrictMock<MockSendAlgorithm>; @@ -2031,6 +2035,7 @@ } TEST_P(QuicSpdySessionTestServer, OnStreamFrameLost) { + CompleteHandshake(); InSequence s; // Drive congestion control manually.
diff --git a/quic/core/quic_session.cc b/quic/core/quic_session.cc index a8d9e6f..5079b46 100644 --- a/quic/core/quic_session.cc +++ b/quic/core/quic_session.cc
@@ -649,9 +649,16 @@ // 3) If the crypto or headers streams are blocked, or // 4) connection is not flow control blocked and there are write blocked // streams. - if (QuicVersionUsesCryptoFrames(transport_version()) && - HasPendingHandshake()) { - return true; + if (QuicVersionUsesCryptoFrames(transport_version())) { + if (HasPendingHandshake()) { + return true; + } + if (GetQuicReloadableFlag(quic_fix_willing_and_able_to_write2)) { + QUIC_RELOADABLE_FLAG_COUNT(quic_fix_willing_and_able_to_write2); + if (!IsEncryptionEstablished()) { + return false; + } + } } if (control_frame_manager_.WillingToWrite() || !streams_with_pending_retransmission_.empty()) {
diff --git a/quic/core/quic_session_test.cc b/quic/core/quic_session_test.cc index 46ce528..738ee0b 100644 --- a/quic/core/quic_session_test.cc +++ b/quic/core/quic_session_test.cc
@@ -104,6 +104,9 @@ session()->config()->ProcessPeerHello(msg, CLIENT, &error_details); } EXPECT_THAT(error, IsQuicNoError()); + session()->OnNewEncryptionKeyAvailable( + ENCRYPTION_FORWARD_SECURE, + std::make_unique<NullEncrypter>(session()->perspective())); session()->OnConfigNegotiated(); if (session()->connection()->version().handshake_protocol == PROTOCOL_TLS1_3) { @@ -487,6 +490,16 @@ closed_streams_.insert(id); } + void CompleteHandshake() { + CryptoHandshakeMessage msg; + if (connection_->version().HasHandshakeDone() && + connection_->perspective() == Perspective::IS_SERVER) { + EXPECT_CALL(*connection_, SendControlFrame(_)) + .WillOnce(Invoke(&ClearControlFrame)); + } + session_.GetMutableCryptoStream()->OnHandshakeMessage(msg); + } + QuicTransportVersion transport_version() const { return connection_->transport_version(); } @@ -630,12 +643,7 @@ TEST_P(QuicSessionTestServer, OneRttKeysAvailable) { EXPECT_FALSE(session_.OneRttKeysAvailable()); - CryptoHandshakeMessage message; - if (connection_->version().HasHandshakeDone()) { - EXPECT_CALL(*connection_, SendControlFrame(_)); - } - connection_->SetDefaultEncryptionLevel(ENCRYPTION_FORWARD_SECURE); - session_.GetMutableCryptoStream()->OnHandshakeMessage(message); + CompleteHandshake(); EXPECT_TRUE(session_.OneRttKeysAvailable()); } @@ -928,6 +936,7 @@ } TEST_P(QuicSessionTestServer, OnCanWrite) { + CompleteHandshake(); session_.set_writev_consumes_all_data(true); TestStream* stream2 = session_.CreateOutgoingBidirectionalStream(); TestStream* stream4 = session_.CreateOutgoingBidirectionalStream(); @@ -1152,15 +1161,9 @@ TEST_P(QuicSessionTestServer, OnCanWriteBundlesStreams) { // Encryption needs to be established before data can be sent. - if (connection_->version().HasHandshakeDone()) { - EXPECT_CALL(*connection_, SendControlFrame(_)) - .WillRepeatedly(Invoke(&ClearControlFrame)); - } - CryptoHandshakeMessage msg; + CompleteHandshake(); MockPacketWriter* writer = static_cast<MockPacketWriter*>( QuicConnectionPeer::GetWriter(session_.connection())); - connection_->SetDefaultEncryptionLevel(ENCRYPTION_FORWARD_SECURE); - session_.GetMutableCryptoStream()->OnHandshakeMessage(msg); // Drive congestion control manually. MockSendAlgorithm* send_algorithm = new StrictMock<MockSendAlgorithm>; @@ -1199,6 +1202,7 @@ } TEST_P(QuicSessionTestServer, OnCanWriteCongestionControlBlocks) { + CompleteHandshake(); session_.set_writev_consumes_all_data(true); InSequence s; @@ -1246,6 +1250,7 @@ } TEST_P(QuicSessionTestServer, OnCanWriteWriterBlocks) { + CompleteHandshake(); // Drive congestion control manually in order to ensure that // application-limited signaling is handled correctly. MockSendAlgorithm* send_algorithm = new StrictMock<MockSendAlgorithm>; @@ -2373,6 +2378,7 @@ } TEST_P(QuicSessionTestServer, OnStreamFrameLost) { + CompleteHandshake(); InSequence s; // Drive congestion control manually.