Fix QuicSession::WillingAndAbleToWrite to check write keys.
Protected by FLAGS_quic_reloadable_flag_quic_fix_willing_and_able_to_write2.
PiperOrigin-RevId: 339286884
Change-Id: Ifb649dd85bade3a86864b146b20a61ffb56d3b77
diff --git a/quic/core/http/quic_spdy_session_test.cc b/quic/core/http/quic_spdy_session_test.cc
index 84ea617..d95fa6d 100644
--- a/quic/core/http/quic_spdy_session_test.cc
+++ b/quic/core/http/quic_spdy_session_test.cc
@@ -704,6 +704,7 @@
}
TEST_P(QuicSpdySessionTestServer, OnCanWrite) {
+ CompleteHandshake();
session_.set_writev_consumes_all_data(true);
TestStream* stream2 = session_.CreateOutgoingBidirectionalStream();
TestStream* stream4 = session_.CreateOutgoingBidirectionalStream();
@@ -861,6 +862,7 @@
}
TEST_P(QuicSpdySessionTestServer, OnCanWriteCongestionControlBlocks) {
+ CompleteHandshake();
session_.set_writev_consumes_all_data(true);
InSequence s;
@@ -908,6 +910,7 @@
}
TEST_P(QuicSpdySessionTestServer, OnCanWriteWriterBlocks) {
+ CompleteHandshake();
// Drive congestion control manually in order to ensure that
// application-limited signaling is handled correctly.
MockSendAlgorithm* send_algorithm = new StrictMock<MockSendAlgorithm>;
@@ -1009,6 +1012,7 @@
TEST_P(QuicSpdySessionTestServer,
OnCanWriteLimitsNumWritesIfFlowControlBlocked) {
+ CompleteHandshake();
// Drive congestion control manually in order to ensure that
// application-limited signaling is handled correctly.
MockSendAlgorithm* send_algorithm = new StrictMock<MockSendAlgorithm>;
@@ -2031,6 +2035,7 @@
}
TEST_P(QuicSpdySessionTestServer, OnStreamFrameLost) {
+ CompleteHandshake();
InSequence s;
// Drive congestion control manually.
diff --git a/quic/core/quic_session.cc b/quic/core/quic_session.cc
index a8d9e6f..5079b46 100644
--- a/quic/core/quic_session.cc
+++ b/quic/core/quic_session.cc
@@ -649,9 +649,16 @@
// 3) If the crypto or headers streams are blocked, or
// 4) connection is not flow control blocked and there are write blocked
// streams.
- if (QuicVersionUsesCryptoFrames(transport_version()) &&
- HasPendingHandshake()) {
- return true;
+ if (QuicVersionUsesCryptoFrames(transport_version())) {
+ if (HasPendingHandshake()) {
+ return true;
+ }
+ if (GetQuicReloadableFlag(quic_fix_willing_and_able_to_write2)) {
+ QUIC_RELOADABLE_FLAG_COUNT(quic_fix_willing_and_able_to_write2);
+ if (!IsEncryptionEstablished()) {
+ return false;
+ }
+ }
}
if (control_frame_manager_.WillingToWrite() ||
!streams_with_pending_retransmission_.empty()) {
diff --git a/quic/core/quic_session_test.cc b/quic/core/quic_session_test.cc
index 46ce528..738ee0b 100644
--- a/quic/core/quic_session_test.cc
+++ b/quic/core/quic_session_test.cc
@@ -104,6 +104,9 @@
session()->config()->ProcessPeerHello(msg, CLIENT, &error_details);
}
EXPECT_THAT(error, IsQuicNoError());
+ session()->OnNewEncryptionKeyAvailable(
+ ENCRYPTION_FORWARD_SECURE,
+ std::make_unique<NullEncrypter>(session()->perspective()));
session()->OnConfigNegotiated();
if (session()->connection()->version().handshake_protocol ==
PROTOCOL_TLS1_3) {
@@ -487,6 +490,16 @@
closed_streams_.insert(id);
}
+ void CompleteHandshake() {
+ CryptoHandshakeMessage msg;
+ if (connection_->version().HasHandshakeDone() &&
+ connection_->perspective() == Perspective::IS_SERVER) {
+ EXPECT_CALL(*connection_, SendControlFrame(_))
+ .WillOnce(Invoke(&ClearControlFrame));
+ }
+ session_.GetMutableCryptoStream()->OnHandshakeMessage(msg);
+ }
+
QuicTransportVersion transport_version() const {
return connection_->transport_version();
}
@@ -630,12 +643,7 @@
TEST_P(QuicSessionTestServer, OneRttKeysAvailable) {
EXPECT_FALSE(session_.OneRttKeysAvailable());
- CryptoHandshakeMessage message;
- if (connection_->version().HasHandshakeDone()) {
- EXPECT_CALL(*connection_, SendControlFrame(_));
- }
- connection_->SetDefaultEncryptionLevel(ENCRYPTION_FORWARD_SECURE);
- session_.GetMutableCryptoStream()->OnHandshakeMessage(message);
+ CompleteHandshake();
EXPECT_TRUE(session_.OneRttKeysAvailable());
}
@@ -928,6 +936,7 @@
}
TEST_P(QuicSessionTestServer, OnCanWrite) {
+ CompleteHandshake();
session_.set_writev_consumes_all_data(true);
TestStream* stream2 = session_.CreateOutgoingBidirectionalStream();
TestStream* stream4 = session_.CreateOutgoingBidirectionalStream();
@@ -1152,15 +1161,9 @@
TEST_P(QuicSessionTestServer, OnCanWriteBundlesStreams) {
// Encryption needs to be established before data can be sent.
- if (connection_->version().HasHandshakeDone()) {
- EXPECT_CALL(*connection_, SendControlFrame(_))
- .WillRepeatedly(Invoke(&ClearControlFrame));
- }
- CryptoHandshakeMessage msg;
+ CompleteHandshake();
MockPacketWriter* writer = static_cast<MockPacketWriter*>(
QuicConnectionPeer::GetWriter(session_.connection()));
- connection_->SetDefaultEncryptionLevel(ENCRYPTION_FORWARD_SECURE);
- session_.GetMutableCryptoStream()->OnHandshakeMessage(msg);
// Drive congestion control manually.
MockSendAlgorithm* send_algorithm = new StrictMock<MockSendAlgorithm>;
@@ -1199,6 +1202,7 @@
}
TEST_P(QuicSessionTestServer, OnCanWriteCongestionControlBlocks) {
+ CompleteHandshake();
session_.set_writev_consumes_all_data(true);
InSequence s;
@@ -1246,6 +1250,7 @@
}
TEST_P(QuicSessionTestServer, OnCanWriteWriterBlocks) {
+ CompleteHandshake();
// Drive congestion control manually in order to ensure that
// application-limited signaling is handled correctly.
MockSendAlgorithm* send_algorithm = new StrictMock<MockSendAlgorithm>;
@@ -2373,6 +2378,7 @@
}
TEST_P(QuicSessionTestServer, OnStreamFrameLost) {
+ CompleteHandshake();
InSequence s;
// Drive congestion control manually.