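Prevent QuicSimpleServerStream from sending two responses

When the request headers fail validation, OnInitialHeadersComplete() sends an
error response but then still falls through to the CONNECT handling for
requests without FIN, which can send a second response on the same stream.
Track whether response headers have been written in a new response_sent_
member, skip the CONNECT handling once it is set, and DCHECK that response
headers are only written once.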
diff --git a/quic/tools/quic_simple_server_stream.cc b/quic/tools/quic_simple_server_stream.cc
index 4377b56..41fdd6a 100644
--- a/quic/tools/quic_simple_server_stream.cc
+++ b/quic/tools/quic_simple_server_stream.cc
@@ -64,7 +64,9 @@
     SendErrorResponse();
   }
   ConsumeHeaderList();
-  if (!fin) {
+  // An error response may already have been sent above (e.g. for invalid
+  // headers); never follow it with a second response.
+  if (!fin && !response_sent_) {
     // CONNECT and other CONNECT-like methods (such as CONNECT-UDP) require
     // sending the response right after parsing the headers even though the FIN
     // bit has not been received on the request stream.
@@ -290,6 +292,8 @@
         quiche::QuicheTextUtils::Uint64ToString(generate_bytes_length_);
 
+    QUICHE_DCHECK(!response_sent_);
     WriteHeaders(std::move(headers), false, nullptr);
+    response_sent_ = true;
 
     WriteGeneratedBytes();
 
@@ -349,6 +353,8 @@
   QUIC_DLOG(INFO) << "Stream " << id() << " writing headers (fin = false) : "
                   << response_headers.DebugString();
+  QUICHE_DCHECK(!response_sent_);
   WriteHeaders(std::move(response_headers), /*fin=*/false, nullptr);
+  response_sent_ = true;
 
   QUIC_DLOG(INFO) << "Stream " << id()
                   << " writing body (fin = false) with size: " << body.size();
@@ -373,6 +379,8 @@
   QUIC_DLOG(INFO) << "Stream " << id() << " writing headers (fin = " << send_fin
                   << ") : " << response_headers.DebugString();
+  QUICHE_DCHECK(!response_sent_);
   WriteHeaders(std::move(response_headers), send_fin, nullptr);
+  response_sent_ = true;
   if (send_fin) {
     // Nothing else to send.
     return;
diff --git a/quic/tools/quic_simple_server_stream.h b/quic/tools/quic_simple_server_stream.h
index bbed95a..2dd59f9 100644
--- a/quic/tools/quic_simple_server_stream.h
+++ b/quic/tools/quic_simple_server_stream.h
@@ -99,6 +99,9 @@
 
  private:
   uint64_t generate_bytes_length_;
+  // Set once response headers have been written on this stream; used to
+  // ensure that at most one response is sent.
+  bool response_sent_ = false;
 
   QuicSimpleServerBackend* quic_simple_server_backend_;  // Not owned.
 };
diff --git a/quic/tools/quic_simple_server_stream_test.cc b/quic/tools/quic_simple_server_stream_test.cc
index 541316a..13b0335 100644
--- a/quic/tools/quic_simple_server_stream_test.cc
+++ b/quic/tools/quic_simple_server_stream_test.cc
@@ -790,6 +790,41 @@
   EXPECT_FALSE(stream_->send_error_response_was_called());
 }
 
+TEST_P(QuicSimpleServerStreamTest, ConnectWithInvalidHeader) {
+  EXPECT_CALL(session_, WritevData(_, _, _, _, _, _))
+      .WillRepeatedly(
+          Invoke(&session_, &MockQuicSimpleServerSession::ConsumeData));
+
+  // A CONNECT-like request carrying a header name that is not lower-case,
+  // which QUIC requires; header validation must reject it.
+  QuicHeaderList request_headers;
+  request_headers.OnHeaderBlockStart();
+  request_headers.OnHeader(":authority", "www.google.com:4433");
+  request_headers.OnHeader(":method", "CONNECT-SILLY");
+  request_headers.OnHeader("InVaLiD-HeAdEr", "Well that's just wrong!");
+  request_headers.OnHeaderBlockEnd(128, 128);
+
+  // Only one header write -- the error response -- is expected. Any further
+  // WriteHeaders call (e.g. a CONNECT response sent after the error) fails
+  // this expectation.
+  EXPECT_CALL(*stream_, WriteHeadersMock(/*fin=*/false));
+  stream_->OnStreamHeaderList(/*fin=*/false, kFakeFrameLen, request_headers);
+
+  // Deliver request body without FIN; it must not trigger another response.
+  std::unique_ptr<char[]> frame_buffer;
+  const QuicByteCount frame_header_length =
+      HttpEncoder::SerializeDataFrameHeader(body_.length(), &frame_buffer);
+  const std::string frame_header(frame_buffer.get(), frame_header_length);
+  const std::string payload = UsesHttp3() ? frame_header + body_ : body_;
+  stream_->OnStreamFrame(
+      QuicStreamFrame(stream_->id(), /*fin=*/false, /*offset=*/0, payload));
+
+  EXPECT_EQ("CONNECT-SILLY", StreamHeadersValue(":method"));
+  EXPECT_EQ(body_, StreamBody());
+  EXPECT_FALSE(stream_->send_response_was_called());
+  EXPECT_TRUE(stream_->send_error_response_was_called());
+}
+
 }  // namespace
 }  // namespace test
 }  // namespace quic