Internal QUICHE change PiperOrigin-RevId: 270719218 Change-Id: If2d4f9a6246777ef7920c276e41d97f7e4295631
diff --git a/quic/core/http/http_decoder.cc b/quic/core/http/http_decoder.cc index e1fa36e..7fe994e 100644 --- a/quic/core/http/http_decoder.cc +++ b/quic/core/http/http_decoder.cc
@@ -153,10 +153,13 @@ continue_processing = visitor_->OnSettingsFrameStart(header_length); break; case static_cast<uint64_t>(HttpFrameType::PUSH_PROMISE): + // This edge case needs to be handled here, because ReadFramePayload() + // does not get called if |current_frame_length_| is zero. if (current_frame_length_ == 0) { RaiseError(QUIC_INVALID_FRAME_DATA, "Corrupt PUSH_PROMISE frame."); return false; } + continue_processing = visitor_->OnPushPromiseFrameStart(header_length); break; case static_cast<uint64_t>(HttpFrameType::GOAWAY): break; @@ -239,10 +242,8 @@ bool success = reader->ReadVarInt62(&push_id); DCHECK(success); remaining_frame_length_ -= current_push_id_length_; - if (!visitor_->OnPushPromiseFrameStart( - push_id, - current_length_field_length_ + current_type_field_length_, - current_push_id_length_)) { + if (!visitor_->OnPushPromiseFramePushId(push_id, + current_push_id_length_)) { continue_processing = false; current_push_id_length_ = 0; break; @@ -259,10 +260,8 @@ bool success = push_id_reader.ReadVarInt62(&push_id); DCHECK(success); - if (!visitor_->OnPushPromiseFrameStart( - push_id, - current_length_field_length_ + current_type_field_length_, - current_push_id_length_)) { + if (!visitor_->OnPushPromiseFramePushId(push_id, + current_push_id_length_)) { continue_processing = false; current_push_id_length_ = 0; break;
diff --git a/quic/core/http/http_decoder.h b/quic/core/http/http_decoder.h index 30ebadc..862c5de 100644 --- a/quic/core/http/http_decoder.h +++ b/quic/core/http/http_decoder.h
@@ -37,11 +37,10 @@ // All the following methods return true to continue decoding, // and false to pause it. // On*FrameStart() methods are called after the frame header is completely - // processed. At that point it is safe to consume - // |frame_length.header_length| bytes. + // processed. At that point it is safe to consume |header_length| bytes. // Called when a PRIORITY frame has been received. - // |frame_length| contains PRIORITY frame length and payload length. + // |header_length| is the length of the frame header (type and length fields). virtual bool OnPriorityFrameStart(QuicByteCount header_length) = 0; // Called when a PRIORITY frame has been successfully parsed. @@ -66,7 +65,7 @@ virtual bool OnDuplicatePushFrame(const DuplicatePushFrame& frame) = 0; // Called when a DATA frame has been received. - // |frame_length| contains DATA frame length and payload length. + // |header_length| is the length of the frame header (type and length fields). virtual bool OnDataFrameStart(QuicByteCount header_length) = 0; // Called when part of the payload of a DATA frame has been read. May be // called multiple times for a single frame. |payload| is guaranteed to be @@ -76,7 +75,7 @@ virtual bool OnDataFrameEnd() = 0; // Called when a HEADERS frame has been received. - // |frame_length| contains HEADERS frame length and payload length. + // |header_length| is the length of the frame header (type and length fields). virtual bool OnHeadersFrameStart(QuicByteCount header_length) = 0; // Called when part of the payload of a HEADERS frame has been read. May be // called multiple times for a single frame. |payload| is guaranteed to be @@ -86,20 +85,22 @@ // |frame_len| is the length of the HEADERS frame payload. virtual bool OnHeadersFrameEnd() = 0; - // Called when a PUSH_PROMISE frame has been received for |push_id|.
- virtual bool OnPushPromiseFrameStart(PushId push_id, - QuicByteCount header_length, - QuicByteCount push_id_length) = 0; - // Called when part of the payload of a PUSH_PROMISE frame has been read. - // May be called multiple times for a single frame. |payload| is guaranteed - // to be non-empty. + // Called when a PUSH_PROMISE frame has been received. + virtual bool OnPushPromiseFrameStart(QuicByteCount header_length) = 0; + // Called when the Push ID field of a PUSH_PROMISE frame has been parsed. + // Called exactly once for a valid PUSH_PROMISE frame. + virtual bool OnPushPromiseFramePushId(PushId push_id, + QuicByteCount push_id_length) = 0; + // Called when part of the header block of a PUSH_PROMISE frame has been + // read. May be called multiple times for a single frame. |payload| is + // guaranteed to be non-empty. virtual bool OnPushPromiseFramePayload(QuicStringPiece payload) = 0; // Called when a PUSH_PROMISE frame has been completely processed. virtual bool OnPushPromiseFrameEnd() = 0; // Called when a frame of unknown type |frame_type| has been received. // Frame type might be reserved, Visitor must make sure to ignore. - // |frame_length| contains frame length and payload length. + // |header_length| is the length of the frame header (type and length fields). virtual bool OnUnknownFrameStart(uint64_t frame_type, QuicByteCount header_length) = 0; // Called when part of the payload of the unknown frame has been read. May
diff --git a/quic/core/http/http_decoder_test.cc b/quic/core/http/http_decoder_test.cc index 3abc8be..4baf271 100644 --- a/quic/core/http/http_decoder_test.cc +++ b/quic/core/http/http_decoder_test.cc
@@ -52,10 +52,9 @@ MOCK_METHOD1(OnHeadersFramePayload, bool(QuicStringPiece payload)); MOCK_METHOD0(OnHeadersFrameEnd, bool()); - MOCK_METHOD3(OnPushPromiseFrameStart, - bool(PushId push_id, - QuicByteCount header_length, - QuicByteCount push_id_length)); + MOCK_METHOD1(OnPushPromiseFrameStart, bool(QuicByteCount header_length)); + MOCK_METHOD2(OnPushPromiseFramePushId, + bool(PushId push_id, QuicByteCount push_id_length)); MOCK_METHOD1(OnPushPromiseFramePayload, bool(QuicStringPiece payload)); MOCK_METHOD0(OnPushPromiseFrameEnd, bool()); @@ -81,7 +80,8 @@ ON_CALL(visitor_, OnHeadersFrameStart(_)).WillByDefault(Return(true)); ON_CALL(visitor_, OnHeadersFramePayload(_)).WillByDefault(Return(true)); ON_CALL(visitor_, OnHeadersFrameEnd()).WillByDefault(Return(true)); - ON_CALL(visitor_, OnPushPromiseFrameStart(_, _, _)) + ON_CALL(visitor_, OnPushPromiseFrameStart(_)).WillByDefault(Return(true)); + ON_CALL(visitor_, OnPushPromiseFramePushId(_, _)) .WillByDefault(Return(true)); ON_CALL(visitor_, OnPushPromiseFramePayload(_)).WillByDefault(Return(true)); ON_CALL(visitor_, OnPushPromiseFrameEnd()).WillByDefault(Return(true)); @@ -210,12 +210,16 @@ "Headers"); // headers // Visitor pauses processing. 
- EXPECT_CALL(visitor_, OnPushPromiseFrameStart(257, 2, 8)) + EXPECT_CALL(visitor_, OnPushPromiseFrameStart(2)).WillOnce(Return(false)); + EXPECT_CALL(visitor_, OnPushPromiseFramePushId(257, 8)) .WillOnce(Return(false)); QuicStringPiece remaining_input(input); QuicByteCount processed_bytes = ProcessInputWithGarbageAppended(remaining_input); - EXPECT_EQ(10u, processed_bytes); + EXPECT_EQ(2u, processed_bytes); + remaining_input = remaining_input.substr(processed_bytes); + processed_bytes = ProcessInputWithGarbageAppended(remaining_input); + EXPECT_EQ(8u, processed_bytes); remaining_input = remaining_input.substr(processed_bytes); EXPECT_CALL(visitor_, OnPushPromiseFramePayload(QuicStringPiece("Headers"))) @@ -229,7 +233,8 @@ EXPECT_EQ("", decoder_.error_detail()); // Process the full frame. - EXPECT_CALL(visitor_, OnPushPromiseFrameStart(257, 2, 8)); + EXPECT_CALL(visitor_, OnPushPromiseFrameStart(2)); + EXPECT_CALL(visitor_, OnPushPromiseFramePushId(257, 8)); EXPECT_CALL(visitor_, OnPushPromiseFramePayload(QuicStringPiece("Headers"))); EXPECT_CALL(visitor_, OnPushPromiseFrameEnd()); EXPECT_EQ(input.size(), ProcessInput(input)); @@ -237,7 +242,8 @@ EXPECT_EQ("", decoder_.error_detail()); // Process the frame incrementally. - EXPECT_CALL(visitor_, OnPushPromiseFrameStart(257, 2, 8)); + EXPECT_CALL(visitor_, OnPushPromiseFrameStart(2)); + EXPECT_CALL(visitor_, OnPushPromiseFramePushId(257, 8)); EXPECT_CALL(visitor_, OnPushPromiseFramePayload(QuicStringPiece("H"))); EXPECT_CALL(visitor_, OnPushPromiseFramePayload(QuicStringPiece("e"))); EXPECT_CALL(visitor_, OnPushPromiseFramePayload(QuicStringPiece("a"))); @@ -251,7 +257,8 @@ EXPECT_EQ("", decoder_.error_detail()); // Process push id incrementally and append headers with last byte of push id. 
- EXPECT_CALL(visitor_, OnPushPromiseFrameStart(257, 2, 8)); + EXPECT_CALL(visitor_, OnPushPromiseFrameStart(2)); + EXPECT_CALL(visitor_, OnPushPromiseFramePushId(257, 8)); EXPECT_CALL(visitor_, OnPushPromiseFramePayload(QuicStringPiece("Headers"))); EXPECT_CALL(visitor_, OnPushPromiseFrameEnd()); ProcessInputCharByChar(input.substr(0, 9)); @@ -260,6 +267,38 @@ EXPECT_EQ("", decoder_.error_detail()); } +TEST_F(HttpDecoderTest, CorruptPushPromiseFrame) { + InSequence s; + + std::string input = QuicTextUtils::HexDecode( + "05" // type (PUSH_PROMISE) + "01" // length + "40"); // first byte of two-byte varint push id + + { + HttpDecoder decoder(&visitor_); + EXPECT_CALL(visitor_, OnPushPromiseFrameStart(2)); + EXPECT_CALL(visitor_, OnError(&decoder)); + + decoder.ProcessInput(input.data(), input.size()); + + EXPECT_EQ(QUIC_INVALID_FRAME_DATA, decoder.error()); + EXPECT_EQ("PUSH_PROMISE frame malformed.", decoder.error_detail()); + } + { + HttpDecoder decoder(&visitor_); + EXPECT_CALL(visitor_, OnPushPromiseFrameStart(2)); + EXPECT_CALL(visitor_, OnError(&decoder)); + + for (auto c : input) { + decoder.ProcessInput(&c, 1); + } + + EXPECT_EQ(QUIC_INVALID_FRAME_DATA, decoder.error()); + EXPECT_EQ("PUSH_PROMISE frame malformed.", decoder.error_detail()); + } +} + TEST_F(HttpDecoderTest, MaxPushId) { InSequence s; std::string input = QuicTextUtils::HexDecode( @@ -767,8 +806,8 @@ "01"); // Push Id // Visitor pauses processing. - EXPECT_CALL(visitor_, OnPushPromiseFrameStart(1, 2, 1)) - .WillOnce(Return(false)); + EXPECT_CALL(visitor_, OnPushPromiseFrameStart(2)); + EXPECT_CALL(visitor_, OnPushPromiseFramePushId(1, 1)).WillOnce(Return(false)); EXPECT_EQ(input.size(), ProcessInputWithGarbageAppended(input)); EXPECT_CALL(visitor_, OnPushPromiseFrameEnd()).WillOnce(Return(false)); @@ -777,14 +816,16 @@ EXPECT_EQ("", decoder_.error_detail()); // Process the full frame. 
- EXPECT_CALL(visitor_, OnPushPromiseFrameStart(1, 2, 1)); + EXPECT_CALL(visitor_, OnPushPromiseFrameStart(2)); + EXPECT_CALL(visitor_, OnPushPromiseFramePushId(1, 1)); EXPECT_CALL(visitor_, OnPushPromiseFrameEnd()); EXPECT_EQ(input.size(), ProcessInput(input)); EXPECT_EQ(QUIC_NO_ERROR, decoder_.error()); EXPECT_EQ("", decoder_.error_detail()); // Process the frame incrementally. - EXPECT_CALL(visitor_, OnPushPromiseFrameStart(1, 2, 1)); + EXPECT_CALL(visitor_, OnPushPromiseFrameStart(2)); + EXPECT_CALL(visitor_, OnPushPromiseFramePushId(1, 1)); EXPECT_CALL(visitor_, OnPushPromiseFrameEnd()); ProcessInputCharByChar(input); EXPECT_EQ(QUIC_NO_ERROR, decoder_.error()); @@ -865,10 +906,6 @@ "\x05" // valid push id "foo", // superfluous data "Superfluous data in CANCEL_PUSH frame."}, - {"\x05" // type (PUSH_PROMISE) - "\x01" // length - "\x40", // first byte of two-byte varint push id - "PUSH_PROMISE frame malformed."}, {"\x0D" // type (MAX_PUSH_ID) "\x01" // length "\x40", // first byte of two-byte varint push id
diff --git a/quic/core/http/quic_receive_control_stream.cc b/quic/core/http/quic_receive_control_stream.cc index 86005e7..dd9f3ec 100644 --- a/quic/core/http/quic_receive_control_stream.cc +++ b/quic/core/http/quic_receive_control_stream.cc
@@ -97,7 +97,7 @@ return false; } - bool OnHeadersFrameStart(QuicByteCount /*frame_length*/) override { + bool OnHeadersFrameStart(QuicByteCount /*header_length*/) override { CloseConnectionOnWrongFrame("Headers"); return false; } @@ -112,9 +112,13 @@ return false; } - bool OnPushPromiseFrameStart(PushId /*push_id*/, - QuicByteCount /*frame_length*/, - QuicByteCount /*push_id_length*/) override { + bool OnPushPromiseFrameStart(QuicByteCount /*header_length*/) override { + CloseConnectionOnWrongFrame("Push Promise"); + return false; + } + + bool OnPushPromiseFramePushId(PushId /*push_id*/, + QuicByteCount /*push_id_length*/) override { CloseConnectionOnWrongFrame("Push Promise"); return false; } @@ -130,7 +134,7 @@ } bool OnUnknownFrameStart(uint64_t /* frame_type */, - QuicByteCount /* frame_length */) override { + QuicByteCount /* header_length */) override { // Ignore unknown frame types. return true; }
diff --git a/quic/core/http/quic_receive_control_stream_test.cc b/quic/core/http/quic_receive_control_stream_test.cc index 6914681..abb5d5e 100644 --- a/quic/core/http/quic_receive_control_stream_test.cc +++ b/quic/core/http/quic_receive_control_stream_test.cc
@@ -257,7 +257,10 @@ length); // TODO(lassey) Check for HTTP_WRONG_STREAM error code. EXPECT_CALL(*connection_, CloseConnection(QUIC_HTTP_DECODER_ERROR, _, _)) - .Times(AtLeast(1)); + .WillOnce( + Invoke(connection_, &MockQuicConnection::ReallyCloseConnection)); + EXPECT_CALL(*connection_, SendConnectionClosePacket(_, _)); + EXPECT_CALL(session_, OnConnectionClosed(_, _)); receive_control_stream_->OnStreamFrame(frame); }
diff --git a/quic/core/http/quic_spdy_stream.cc b/quic/core/http/quic_spdy_stream.cc index a089345..ea8e669 100644 --- a/quic/core/http/quic_spdy_stream.cc +++ b/quic/core/http/quic_spdy_stream.cc
@@ -124,15 +124,21 @@ return stream_->OnHeadersFrameEnd(); } - bool OnPushPromiseFrameStart(PushId push_id, - QuicByteCount header_length, - QuicByteCount push_id_length) override { + bool OnPushPromiseFrameStart(QuicByteCount header_length) override { if (!VersionUsesQpack(stream_->transport_version())) { CloseConnectionOnWrongFrame("Push Promise"); return false; } - return stream_->OnPushPromiseFrameStart(push_id, header_length, - push_id_length); + return stream_->OnPushPromiseFrameStart(header_length); + } + + bool OnPushPromiseFramePushId(PushId push_id, + QuicByteCount push_id_length) override { + if (!VersionUsesQpack(stream_->transport_version())) { + CloseConnectionOnWrongFrame("Push Promise"); + return false; + } + return stream_->OnPushPromiseFramePushId(push_id, push_id_length); } bool OnPushPromiseFramePayload(QuicStringPiece payload) override { @@ -929,16 +935,23 @@ return !sequencer()->IsClosed() && !reading_stopped(); } -bool QuicSpdyStream::OnPushPromiseFrameStart(PushId push_id, - QuicByteCount header_length, - QuicByteCount push_id_length) { +bool QuicSpdyStream::OnPushPromiseFrameStart(QuicByteCount header_length) { + DCHECK(VersionHasStreamType(transport_version())); + DCHECK(!qpack_decoded_headers_accumulator_); + + sequencer()->MarkConsumed(body_manager_.OnNonBody(header_length)); + + return true; +} + +bool QuicSpdyStream::OnPushPromiseFramePushId(PushId push_id, + QuicByteCount push_id_length) { DCHECK(VersionHasStreamType(transport_version())); DCHECK(!qpack_decoded_headers_accumulator_); // TODO(renjietang): Check max push id and handle errors. spdy_session_->OnPushPromise(id(), push_id); - sequencer()->MarkConsumed( - body_manager_.OnNonBody(header_length + push_id_length)); + sequencer()->MarkConsumed(body_manager_.OnNonBody(push_id_length)); qpack_decoded_headers_accumulator_ = std::make_unique<QpackDecodedHeadersAccumulator>(
diff --git a/quic/core/http/quic_spdy_stream.h b/quic/core/http/quic_spdy_stream.h index 4849705..a2ebc54 100644 --- a/quic/core/http/quic_spdy_stream.h +++ b/quic/core/http/quic_spdy_stream.h
@@ -259,9 +259,8 @@ bool OnHeadersFrameStart(QuicByteCount header_length); bool OnHeadersFramePayload(QuicStringPiece payload); bool OnHeadersFrameEnd(); - bool OnPushPromiseFrameStart(PushId push_id, - QuicByteCount header_length, - QuicByteCount push_id_length); + bool OnPushPromiseFrameStart(QuicByteCount header_length); + bool OnPushPromiseFramePushId(PushId push_id, QuicByteCount push_id_length); bool OnPushPromiseFramePayload(QuicStringPiece payload); bool OnPushPromiseFrameEnd(); bool OnUnknownFrameStart(uint64_t frame_type, QuicByteCount header_length);