Allow HttpDecoder visitor to return bool upon receiving HTTP/3 frames. The prevents decoder from reading after the stream has already freed its sequencer. gfe-relnote: Version 99 only. Not in prod. PiperOrigin-RevId: 251257525 Change-Id: Ia2bd1cc0073d45df60b91dc4fcfed327b31ec369
diff --git a/quic/core/http/http_decoder.cc b/quic/core/http/http_decoder.cc index c59e8b0..7f0cc68 100644 --- a/quic/core/http/http_decoder.cc +++ b/quic/core/http/http_decoder.cc
@@ -165,17 +165,26 @@ // Calling the following two visitor methods does not require parsing of any // frame payload. if (current_frame_type_ == 0x0) { - visitor_->OnDataFrameStart(Http3FrameLengths( - current_length_field_length_ + current_type_field_length_, - current_frame_length_)); + if (!visitor_->OnDataFrameStart(Http3FrameLengths( + current_length_field_length_ + current_type_field_length_, + current_frame_length_))) { + RaiseError(QUIC_INTERNAL_ERROR, "Visitor shut down."); + return; + } } else if (current_frame_type_ == 0x1) { - visitor_->OnHeadersFrameStart(Http3FrameLengths( - current_length_field_length_ + current_type_field_length_, - current_frame_length_)); + if (!visitor_->OnHeadersFrameStart(Http3FrameLengths( + current_length_field_length_ + current_type_field_length_, + current_frame_length_))) { + RaiseError(QUIC_INTERNAL_ERROR, "Visitor shut down."); + return; + } } else if (current_frame_type_ == 0x4) { - visitor_->OnSettingsFrameStart(Http3FrameLengths( - current_length_field_length_ + current_type_field_length_, - current_frame_length_)); + if (!visitor_->OnSettingsFrameStart(Http3FrameLengths( + current_length_field_length_ + current_type_field_length_, + current_frame_length_))) { + RaiseError(QUIC_INTERNAL_ERROR, "Visitor shut down."); + return; + } } remaining_frame_length_ = current_frame_length_; @@ -196,7 +205,10 @@ return; } DCHECK(!payload.empty()); - visitor_->OnDataFramePayload(payload); + if (!visitor_->OnDataFramePayload(payload)) { + RaiseError(QUIC_INTERNAL_ERROR, "Visitor shut down."); + return; + } remaining_frame_length_ -= payload.length(); break; } @@ -209,7 +221,10 @@ return; } DCHECK(!payload.empty()); - visitor_->OnHeadersFramePayload(payload); + if (!visitor_->OnHeadersFramePayload(payload)) { + RaiseError(QUIC_INTERNAL_ERROR, "Visitor shut down."); + return; + } remaining_frame_length_ -= payload.length(); break; } @@ -237,7 +252,10 @@ return; } remaining_frame_length_ -= bytes_remaining - reader->BytesRemaining(); - visitor_->OnPushPromiseFrameStart(push_id); + if (!visitor_->OnPushPromiseFrameStart(push_id)) { + RaiseError(QUIC_INTERNAL_ERROR, "Visitor shut down."); + return; + } } DCHECK_LT(remaining_frame_length_, current_frame_length_); QuicByteCount bytes_to_read = std::min<QuicByteCount>( @@ -251,7 +269,10 @@ return; } DCHECK(!payload.empty()); - visitor_->OnPushPromiseFramePayload(payload); + if (!visitor_->OnPushPromiseFramePayload(payload)) { + RaiseError(QUIC_INTERNAL_ERROR, "Visitor shut down."); + return; + } remaining_frame_length_ -= payload.length(); break; } @@ -303,11 +324,17 @@ DCHECK_EQ(0u, remaining_frame_length_); switch (current_frame_type_) { case 0x0: { // DATA - visitor_->OnDataFrameEnd(); + if (!visitor_->OnDataFrameEnd()) { + RaiseError(QUIC_INTERNAL_ERROR, "Visitor shut down."); + return; + } break; } case 0x1: { // HEADERS - visitor_->OnHeadersFrameEnd(); + if (!visitor_->OnHeadersFrameEnd()) { + RaiseError(QUIC_INTERNAL_ERROR, "Visitor shut down."); + return; + } break; } case 0x2: { // PRIORITY @@ -318,7 +345,10 @@ if (!ParsePriorityFrame(&reader, &frame)) { return; } - visitor_->OnPriorityFrame(frame); + if (!visitor_->OnPriorityFrame(frame)) { + RaiseError(QUIC_INTERNAL_ERROR, "Visitor shut down."); + return; + } break; } case 0x3: { // CANCEL_PUSH @@ -329,7 +359,10 @@ RaiseError(QUIC_INTERNAL_ERROR, "Unable to read push_id"); return; } - visitor_->OnCancelPushFrame(frame); + if (!visitor_->OnCancelPushFrame(frame)) { + RaiseError(QUIC_INTERNAL_ERROR, "Visitor shut down."); + return; + } break; } case 0x4: { // SETTINGS @@ -338,11 +371,17 @@ if (!ParseSettingsFrame(&reader, &frame)) { return; } - visitor_->OnSettingsFrame(frame); + if (!visitor_->OnSettingsFrame(frame)) { + RaiseError(QUIC_INTERNAL_ERROR, "Visitor shut down."); + return; + } break; } case 0x5: { // PUSH_PROMISE - visitor_->OnPushPromiseFrameEnd(); + if (!visitor_->OnPushPromiseFrameEnd()) { + RaiseError(QUIC_INTERNAL_ERROR, "Visitor shut down."); + return; + } break; } case 0x7: { // GOAWAY @@ -358,7 +397,10 @@ return; } frame.stream_id = static_cast<QuicStreamId>(stream_id); - visitor_->OnGoAwayFrame(frame); + if (!visitor_->OnGoAwayFrame(frame)) { + RaiseError(QUIC_INTERNAL_ERROR, "Visitor shut down."); + return; + } break; } @@ -369,7 +411,10 @@ RaiseError(QUIC_INTERNAL_ERROR, "Unable to read push_id"); return; } - visitor_->OnMaxPushIdFrame(frame); + if (!visitor_->OnMaxPushIdFrame(frame)) { + RaiseError(QUIC_INTERNAL_ERROR, "Visitor shut down."); + return; + } break; } @@ -380,7 +425,10 @@ RaiseError(QUIC_INTERNAL_ERROR, "Unable to read push_id"); return; } - visitor_->OnDuplicatePushFrame(frame); + if (!visitor_->OnDuplicatePushFrame(frame)) { + RaiseError(QUIC_INTERNAL_ERROR, "Visitor shut down."); + return; + } break; } }
diff --git a/quic/core/http/http_decoder.h b/quic/core/http/http_decoder.h index a392962..4ed4c66 100644 --- a/quic/core/http/http_decoder.h +++ b/quic/core/http/http_decoder.h
@@ -45,55 +45,71 @@ virtual void OnError(HttpDecoder* decoder) = 0; // Called when a PRIORITY frame has been successfully parsed. - virtual void OnPriorityFrame(const PriorityFrame& frame) = 0; + // Returns true to permit furthuring decoding, and false to prevent it. + virtual bool OnPriorityFrame(const PriorityFrame& frame) = 0; // Called when a CANCEL_PUSH frame has been successfully parsed. - virtual void OnCancelPushFrame(const CancelPushFrame& frame) = 0; + // Returns true to permit furthuring decoding, and false to prevent it. + virtual bool OnCancelPushFrame(const CancelPushFrame& frame) = 0; // Called when a MAX_PUSH_ID frame has been successfully parsed. - virtual void OnMaxPushIdFrame(const MaxPushIdFrame& frame) = 0; + // Returns true to permit furthuring decoding, and false to prevent it. + virtual bool OnMaxPushIdFrame(const MaxPushIdFrame& frame) = 0; // Called when a GOAWAY frame has been successfully parsed. - virtual void OnGoAwayFrame(const GoAwayFrame& frame) = 0; + // Returns true to permit furthuring decoding, and false to prevent it. + virtual bool OnGoAwayFrame(const GoAwayFrame& frame) = 0; // Called when a SETTINGS frame has been received. - virtual void OnSettingsFrameStart(Http3FrameLengths frame_length) = 0; + // Returns true to permit furthuring decoding, and false to prevent it. + virtual bool OnSettingsFrameStart(Http3FrameLengths frame_length) = 0; // Called when a SETTINGS frame has been successfully parsed. - virtual void OnSettingsFrame(const SettingsFrame& frame) = 0; + // Returns true to permit furthuring decoding, and false to prevent it. + virtual bool OnSettingsFrame(const SettingsFrame& frame) = 0; // Called when a DUPLICATE_PUSH frame has been successfully parsed. - virtual void OnDuplicatePushFrame(const DuplicatePushFrame& frame) = 0; + // Returns true to permit furthuring decoding, and false to prevent it. + virtual bool OnDuplicatePushFrame(const DuplicatePushFrame& frame) = 0; // Called when a DATA frame has been received. // |frame_length| contains DATA frame length and payload length. - virtual void OnDataFrameStart(Http3FrameLengths frame_length) = 0; + // Returns true to permit furthuring decoding, and false to prevent it. + virtual bool OnDataFrameStart(Http3FrameLengths frame_length) = 0; // Called when part of the payload of a DATA frame has been read. May be // called multiple times for a single frame. |payload| is guaranteed to be // non-empty. - virtual void OnDataFramePayload(QuicStringPiece payload) = 0; + // Returns true to permit furthuring decoding, and false to prevent it. + virtual bool OnDataFramePayload(QuicStringPiece payload) = 0; // Called when a DATA frame has been completely processed. - virtual void OnDataFrameEnd() = 0; + // Returns true to permit furthuring decoding, and false to prevent it. + virtual bool OnDataFrameEnd() = 0; // Called when a HEADERS frame has been received. // |frame_length| contains HEADERS frame length and payload length. - virtual void OnHeadersFrameStart(Http3FrameLengths frame_length) = 0; + // Returns true to permit furthuring decoding, and false to prevent it. + virtual bool OnHeadersFrameStart(Http3FrameLengths frame_length) = 0; // Called when part of the payload of a HEADERS frame has been read. May be // called multiple times for a single frame. |payload| is guaranteed to be // non-empty. - virtual void OnHeadersFramePayload(QuicStringPiece payload) = 0; + // Returns true to permit furthuring decoding, and false to prevent it. + virtual bool OnHeadersFramePayload(QuicStringPiece payload) = 0; // Called when a HEADERS frame has been completely processed. // |frame_len| is the length of the HEADERS frame payload. - virtual void OnHeadersFrameEnd() = 0; + // Returns true to permit furthuring decoding, and false to prevent it. + virtual bool OnHeadersFrameEnd() = 0; // Called when a PUSH_PROMISE frame has been received for |push_id|. - virtual void OnPushPromiseFrameStart(PushId push_id) = 0; + // Returns true to permit furthuring decoding, and false to prevent it. + virtual bool OnPushPromiseFrameStart(PushId push_id) = 0; // Called when part of the payload of a PUSH_PROMISE frame has been read. // May be called multiple times for a single frame. |payload| is guaranteed // to be non-empty. - virtual void OnPushPromiseFramePayload(QuicStringPiece payload) = 0; + // Returns true to permit furthuring decoding, and false to prevent it. + virtual bool OnPushPromiseFramePayload(QuicStringPiece payload) = 0; // Called when a PUSH_PROMISE frame has been completely processed. - virtual void OnPushPromiseFrameEnd() = 0; + // Returns true to permit furthuring decoding, and false to prevent it. + virtual bool OnPushPromiseFrameEnd() = 0; // TODO(rch): Consider adding methods like: // OnUnknownFrame{Start,Payload,End}()
diff --git a/quic/core/http/http_decoder_test.cc b/quic/core/http/http_decoder_test.cc index a4903b5..b99689f 100644 --- a/quic/core/http/http_decoder_test.cc +++ b/quic/core/http/http_decoder_test.cc
@@ -10,7 +10,9 @@ #include "net/third_party/quiche/src/quic/platform/api/quic_ptr_util.h" #include "net/third_party/quiche/src/quic/platform/api/quic_test.h" -using testing::InSequence; +using ::testing::_; +using ::testing::InSequence; +using ::testing::Return; namespace quic { @@ -21,30 +23,49 @@ // Called if an error is detected. MOCK_METHOD1(OnError, void(HttpDecoder* decoder)); - MOCK_METHOD1(OnPriorityFrame, void(const PriorityFrame& frame)); - MOCK_METHOD1(OnCancelPushFrame, void(const CancelPushFrame& frame)); - MOCK_METHOD1(OnMaxPushIdFrame, void(const MaxPushIdFrame& frame)); - MOCK_METHOD1(OnGoAwayFrame, void(const GoAwayFrame& frame)); - MOCK_METHOD1(OnSettingsFrameStart, void(Http3FrameLengths frame_lengths)); - MOCK_METHOD1(OnSettingsFrame, void(const SettingsFrame& frame)); - MOCK_METHOD1(OnDuplicatePushFrame, void(const DuplicatePushFrame& frame)); + MOCK_METHOD1(OnPriorityFrame, bool(const PriorityFrame& frame)); + MOCK_METHOD1(OnCancelPushFrame, bool(const CancelPushFrame& frame)); + MOCK_METHOD1(OnMaxPushIdFrame, bool(const MaxPushIdFrame& frame)); + MOCK_METHOD1(OnGoAwayFrame, bool(const GoAwayFrame& frame)); + MOCK_METHOD1(OnSettingsFrameStart, bool(Http3FrameLengths frame_lengths)); + MOCK_METHOD1(OnSettingsFrame, bool(const SettingsFrame& frame)); + MOCK_METHOD1(OnDuplicatePushFrame, bool(const DuplicatePushFrame& frame)); - MOCK_METHOD1(OnDataFrameStart, void(Http3FrameLengths frame_lengths)); - MOCK_METHOD1(OnDataFramePayload, void(QuicStringPiece payload)); - MOCK_METHOD0(OnDataFrameEnd, void()); + MOCK_METHOD1(OnDataFrameStart, bool(Http3FrameLengths frame_lengths)); + MOCK_METHOD1(OnDataFramePayload, bool(QuicStringPiece payload)); + MOCK_METHOD0(OnDataFrameEnd, bool()); - MOCK_METHOD1(OnHeadersFrameStart, void(Http3FrameLengths frame_lengths)); - MOCK_METHOD1(OnHeadersFramePayload, void(QuicStringPiece payload)); - MOCK_METHOD0(OnHeadersFrameEnd, void()); + MOCK_METHOD1(OnHeadersFrameStart, bool(Http3FrameLengths frame_lengths)); + MOCK_METHOD1(OnHeadersFramePayload, bool(QuicStringPiece payload)); + MOCK_METHOD0(OnHeadersFrameEnd, bool()); - MOCK_METHOD1(OnPushPromiseFrameStart, void(PushId push_id)); - MOCK_METHOD1(OnPushPromiseFramePayload, void(QuicStringPiece payload)); - MOCK_METHOD0(OnPushPromiseFrameEnd, void()); + MOCK_METHOD1(OnPushPromiseFrameStart, bool(PushId push_id)); + MOCK_METHOD1(OnPushPromiseFramePayload, bool(QuicStringPiece payload)); + MOCK_METHOD0(OnPushPromiseFrameEnd, bool()); }; class HttpDecoderTest : public QuicTest { public: - HttpDecoderTest() { decoder_.set_visitor(&visitor_); } + HttpDecoderTest() { + ON_CALL(visitor_, OnPriorityFrame(_)).WillByDefault(Return(true)); + ON_CALL(visitor_, OnCancelPushFrame(_)).WillByDefault(Return(true)); + ON_CALL(visitor_, OnMaxPushIdFrame(_)).WillByDefault(Return(true)); + ON_CALL(visitor_, OnGoAwayFrame(_)).WillByDefault(Return(true)); + ON_CALL(visitor_, OnSettingsFrameStart(_)).WillByDefault(Return(true)); + ON_CALL(visitor_, OnSettingsFrame(_)).WillByDefault(Return(true)); + ON_CALL(visitor_, OnDuplicatePushFrame(_)).WillByDefault(Return(true)); + ON_CALL(visitor_, OnDataFrameStart(_)).WillByDefault(Return(true)); + ON_CALL(visitor_, OnDataFramePayload(_)).WillByDefault(Return(true)); + ON_CALL(visitor_, OnDataFrameEnd()).WillByDefault(Return(true)); + ON_CALL(visitor_, OnHeadersFrameStart(_)).WillByDefault(Return(true)); + ON_CALL(visitor_, OnHeadersFramePayload(_)).WillByDefault(Return(true)); + ON_CALL(visitor_, OnHeadersFrameEnd()).WillByDefault(Return(true)); + ON_CALL(visitor_, OnPushPromiseFrameStart(_)).WillByDefault(Return(true)); + ON_CALL(visitor_, OnPushPromiseFramePayload(_)).WillByDefault(Return(true)); + ON_CALL(visitor_, OnPushPromiseFrameEnd()).WillByDefault(Return(true)); + decoder_.set_visitor(&visitor_); + } + HttpDecoder decoder_; testing::StrictMock<MockVisitor> visitor_; }; @@ -172,6 +193,13 @@ } EXPECT_EQ(QUIC_NO_ERROR, decoder_.error()); EXPECT_EQ("", decoder_.error_detail()); + + // Test on the situation when the visitor wants to stop processing. + EXPECT_CALL(visitor_, OnCancelPushFrame(CancelPushFrame({1}))) + .WillOnce(Return(false)); + EXPECT_EQ(0u, decoder_.ProcessInput(input, QUIC_ARRAYSIZE(input))); + EXPECT_EQ(QUIC_INTERNAL_ERROR, decoder_.error()); + EXPECT_EQ("Visitor shut down.", decoder_.error_detail()); } TEST_F(HttpDecoderTest, PushPromiseFrame) { @@ -209,6 +237,12 @@ } EXPECT_EQ(QUIC_NO_ERROR, decoder_.error()); EXPECT_EQ("", decoder_.error_detail()); + + // Test on the situation when the visitor wants to stop processing. + EXPECT_CALL(visitor_, OnPushPromiseFrameStart(1)).WillOnce(Return(false)); + EXPECT_EQ(0u, decoder_.ProcessInput(input, QUIC_ARRAYSIZE(input))); + EXPECT_EQ(QUIC_INTERNAL_ERROR, decoder_.error()); + EXPECT_EQ("Visitor shut down.", decoder_.error_detail()); } TEST_F(HttpDecoderTest, MaxPushId) { @@ -233,6 +267,13 @@ } EXPECT_EQ(QUIC_NO_ERROR, decoder_.error()); EXPECT_EQ("", decoder_.error_detail()); + + // Test on the situation when the visitor wants to stop processing. + EXPECT_CALL(visitor_, OnMaxPushIdFrame(MaxPushIdFrame({1}))) + .WillOnce(Return(false)); + EXPECT_EQ(0u, decoder_.ProcessInput(input, QUIC_ARRAYSIZE(input))); + EXPECT_EQ(QUIC_INTERNAL_ERROR, decoder_.error()); + EXPECT_EQ("Visitor shut down.", decoder_.error_detail()); } TEST_F(HttpDecoderTest, DuplicatePush) { @@ -256,6 +297,13 @@ } EXPECT_EQ(QUIC_NO_ERROR, decoder_.error()); EXPECT_EQ("", decoder_.error_detail()); + + // Test on the situation when the visitor wants to stop processing. + EXPECT_CALL(visitor_, OnDuplicatePushFrame(DuplicatePushFrame({1}))) + .WillOnce(Return(false)); + EXPECT_EQ(0u, decoder_.ProcessInput(input, QUIC_ARRAYSIZE(input))); + EXPECT_EQ(QUIC_INTERNAL_ERROR, decoder_.error()); + EXPECT_EQ("Visitor shut down.", decoder_.error_detail()); } TEST_F(HttpDecoderTest, PriorityFrame) { @@ -287,7 +335,6 @@ EXPECT_EQ(QUIC_NO_ERROR, decoder_.error()); EXPECT_EQ("", decoder_.error_detail()); - /* // Process the frame incremently. EXPECT_CALL(visitor_, OnPriorityFrame(frame)); for (char c : input) { @@ -295,7 +342,12 @@ } EXPECT_EQ(QUIC_NO_ERROR, decoder_.error()); EXPECT_EQ("", decoder_.error_detail()); - */ + + // Test on the situation when the visitor wants to stop processing. + EXPECT_CALL(visitor_, OnPriorityFrame(frame)).WillOnce(Return(false)); + EXPECT_EQ(0u, decoder_.ProcessInput(input, QUIC_ARRAYSIZE(input))); + EXPECT_EQ(QUIC_INTERNAL_ERROR, decoder_.error()); + EXPECT_EQ("Visitor shut down.", decoder_.error_detail()); } TEST_F(HttpDecoderTest, SettingsFrame) { @@ -341,6 +393,13 @@ } EXPECT_EQ(QUIC_NO_ERROR, decoder_.error()); EXPECT_EQ("", decoder_.error_detail()); + + // Test on the situation when the visitor wants to stop processing. + EXPECT_CALL(visitor_, OnSettingsFrameStart(Http3FrameLengths(2, 7))) + .WillOnce(Return(false)); + EXPECT_EQ(0u, decoder_.ProcessInput(input, QUIC_ARRAYSIZE(input))); + EXPECT_EQ(QUIC_INTERNAL_ERROR, decoder_.error()); + EXPECT_EQ("Visitor shut down.", decoder_.error_detail()); } TEST_F(HttpDecoderTest, DataFrame) { @@ -374,6 +433,13 @@ } EXPECT_EQ(QUIC_NO_ERROR, decoder_.error()); EXPECT_EQ("", decoder_.error_detail()); + + // Test on the situation when the visitor wants to stop processing. + EXPECT_CALL(visitor_, OnDataFrameStart(Http3FrameLengths(2, 5))) + .WillOnce(Return(false)); + EXPECT_EQ(0u, decoder_.ProcessInput(input, QUIC_ARRAYSIZE(input))); + EXPECT_EQ(QUIC_INTERNAL_ERROR, decoder_.error()); + EXPECT_EQ("Visitor shut down.", decoder_.error_detail()); } TEST_F(HttpDecoderTest, FrameHeaderPartialDelivery) {
diff --git a/quic/core/http/quic_receive_control_stream.cc b/quic/core/http/quic_receive_control_stream.cc index 87476e6..5941189 100644 --- a/quic/core/http/quic_receive_control_stream.cc +++ b/quic/core/http/quic_receive_control_stream.cc
@@ -30,64 +30,82 @@ ConnectionCloseBehavior::SEND_CONNECTION_CLOSE_PACKET); } - void OnPriorityFrame(const PriorityFrame& frame) override { + bool OnPriorityFrame(const PriorityFrame& frame) override { CloseConnectionOnWrongFrame("Priority"); + return false; } - void OnCancelPushFrame(const CancelPushFrame& frame) override { + bool OnCancelPushFrame(const CancelPushFrame& frame) override { CloseConnectionOnWrongFrame("Cancel Push"); + return false; } - void OnMaxPushIdFrame(const MaxPushIdFrame& frame) override { + bool OnMaxPushIdFrame(const MaxPushIdFrame& frame) override { CloseConnectionOnWrongFrame("Max Push Id"); + return false; } - void OnGoAwayFrame(const GoAwayFrame& frame) override { + bool OnGoAwayFrame(const GoAwayFrame& frame) override { CloseConnectionOnWrongFrame("Goaway"); + return false; } - void OnSettingsFrameStart(Http3FrameLengths frame_lengths) override { - stream_->OnSettingsFrameStart(frame_lengths); + bool OnSettingsFrameStart(Http3FrameLengths frame_lengths) override { + return stream_->OnSettingsFrameStart(frame_lengths); } - void OnSettingsFrame(const SettingsFrame& frame) override { - stream_->OnSettingsFrame(frame); + bool OnSettingsFrame(const SettingsFrame& frame) override { + return stream_->OnSettingsFrame(frame); } - void OnDuplicatePushFrame(const DuplicatePushFrame& frame) override { + bool OnDuplicatePushFrame(const DuplicatePushFrame& frame) override { CloseConnectionOnWrongFrame("Duplicate Push"); + return false; } - void OnDataFrameStart(Http3FrameLengths frame_lengths) override { + bool OnDataFrameStart(Http3FrameLengths frame_lengths) override { CloseConnectionOnWrongFrame("Data"); + return false; } - void OnDataFramePayload(QuicStringPiece payload) override { + bool OnDataFramePayload(QuicStringPiece payload) override { CloseConnectionOnWrongFrame("Data"); + return false; } - void OnDataFrameEnd() override { CloseConnectionOnWrongFrame("Data"); } + bool OnDataFrameEnd() override { + CloseConnectionOnWrongFrame("Data"); + return false; + } - void OnHeadersFrameStart(Http3FrameLengths frame_length) override { + bool OnHeadersFrameStart(Http3FrameLengths frame_length) override { CloseConnectionOnWrongFrame("Headers"); + return false; } - void OnHeadersFramePayload(QuicStringPiece payload) override { + bool OnHeadersFramePayload(QuicStringPiece payload) override { CloseConnectionOnWrongFrame("Headers"); + return false; } - void OnHeadersFrameEnd() override { CloseConnectionOnWrongFrame("Headers"); } - - void OnPushPromiseFrameStart(PushId push_id) override { - CloseConnectionOnWrongFrame("Push Promise"); + bool OnHeadersFrameEnd() override { + CloseConnectionOnWrongFrame("Headers"); + return false; } - void OnPushPromiseFramePayload(QuicStringPiece payload) override { + bool OnPushPromiseFrameStart(PushId push_id) override { CloseConnectionOnWrongFrame("Push Promise"); + return false; } - void OnPushPromiseFrameEnd() override { + bool OnPushPromiseFramePayload(QuicStringPiece payload) override { CloseConnectionOnWrongFrame("Push Promise"); + return false; + } + + bool OnPushPromiseFrameEnd() override { + CloseConnectionOnWrongFrame("Push Promise"); + return false; } private: @@ -137,20 +155,21 @@ } } -void QuicReceiveControlStream::OnSettingsFrameStart( +bool QuicReceiveControlStream::OnSettingsFrameStart( Http3FrameLengths frame_lengths) { if (received_settings_length_ != 0) { // TODO(renjietang): Change error code to HTTP_UNEXPECTED_FRAME. session()->connection()->CloseConnection( QUIC_INVALID_STREAM_ID, "Settings frames are received twice.", ConnectionCloseBehavior::SEND_CONNECTION_CLOSE_PACKET); - return; + return false; } received_settings_length_ += frame_lengths.header_length + frame_lengths.payload_length; + return true; } -void QuicReceiveControlStream::OnSettingsFrame(const SettingsFrame& settings) { +bool QuicReceiveControlStream::OnSettingsFrame(const SettingsFrame& settings) { QuicSpdySession* spdy_session = static_cast<QuicSpdySession*>(session()); for (auto& it : settings.values) { uint16_t setting_id = it.first; @@ -166,6 +185,7 @@ } } sequencer()->MarkConsumed(received_settings_length_); + return true; } } // namespace quic
diff --git a/quic/core/http/quic_receive_control_stream.h b/quic/core/http/quic_receive_control_stream.h index 818a1e8..a977f63 100644 --- a/quic/core/http/quic_receive_control_stream.h +++ b/quic/core/http/quic_receive_control_stream.h
@@ -36,8 +36,8 @@ protected: // Called from HttpDecoderVisitor. - void OnSettingsFrameStart(Http3FrameLengths frame_lengths); - void OnSettingsFrame(const SettingsFrame& settings); + bool OnSettingsFrameStart(Http3FrameLengths frame_lengths); + bool OnSettingsFrame(const SettingsFrame& settings); private: class HttpDecoderVisitor;
diff --git a/quic/core/http/quic_spdy_stream.cc b/quic/core/http/quic_spdy_stream.cc index bd15918..b025fd8 100644 --- a/quic/core/http/quic_spdy_stream.cc +++ b/quic/core/http/quic_spdy_stream.cc
@@ -45,84 +45,94 @@ ConnectionCloseBehavior::SEND_CONNECTION_CLOSE_PACKET); } - void OnPriorityFrame(const PriorityFrame& frame) override { + bool OnPriorityFrame(const PriorityFrame& frame) override { CloseConnectionOnWrongFrame("Priority"); + return false; } - void OnCancelPushFrame(const CancelPushFrame& frame) override { + bool OnCancelPushFrame(const CancelPushFrame& frame) override { CloseConnectionOnWrongFrame("Cancel Push"); + return false; } - void OnMaxPushIdFrame(const MaxPushIdFrame& frame) override { + bool OnMaxPushIdFrame(const MaxPushIdFrame& frame) override { CloseConnectionOnWrongFrame("Max Push Id"); + return false; } - void OnGoAwayFrame(const GoAwayFrame& frame) override { + bool OnGoAwayFrame(const GoAwayFrame& frame) override { CloseConnectionOnWrongFrame("Goaway"); + return false; } - void OnSettingsFrameStart(Http3FrameLengths frame_lengths) override { + bool OnSettingsFrameStart(Http3FrameLengths frame_lengths) override { CloseConnectionOnWrongFrame("Settings"); + return false; } - void OnSettingsFrame(const SettingsFrame& frame) override { + bool OnSettingsFrame(const SettingsFrame& frame) override { CloseConnectionOnWrongFrame("Settings"); + return false; } - void OnDuplicatePushFrame(const DuplicatePushFrame& frame) override { + bool OnDuplicatePushFrame(const DuplicatePushFrame& frame) override { CloseConnectionOnWrongFrame("Duplicate Push"); + return false; } - void OnDataFrameStart(Http3FrameLengths frame_lengths) override { - stream_->OnDataFrameStart(frame_lengths); + bool OnDataFrameStart(Http3FrameLengths frame_lengths) override { + return stream_->OnDataFrameStart(frame_lengths); } - void OnDataFramePayload(QuicStringPiece payload) override { + bool OnDataFramePayload(QuicStringPiece payload) override { DCHECK(!payload.empty()); - stream_->OnDataFramePayload(payload); + return stream_->OnDataFramePayload(payload); } - void OnDataFrameEnd() override { stream_->OnDataFrameEnd(); } + bool OnDataFrameEnd() override { return stream_->OnDataFrameEnd(); } - void OnHeadersFrameStart(Http3FrameLengths frame_length) override { + bool OnHeadersFrameStart(Http3FrameLengths frame_length) override { if (!VersionUsesQpack( stream_->session()->connection()->transport_version())) { CloseConnectionOnWrongFrame("Headers"); - return; + return false; } - stream_->OnHeadersFrameStart(frame_length); + return stream_->OnHeadersFrameStart(frame_length); } - void OnHeadersFramePayload(QuicStringPiece payload) override { + bool OnHeadersFramePayload(QuicStringPiece payload) override { DCHECK(!payload.empty()); if (!VersionUsesQpack( stream_->session()->connection()->transport_version())) { CloseConnectionOnWrongFrame("Headers"); - return; + return false; } - stream_->OnHeadersFramePayload(payload); + return stream_->OnHeadersFramePayload(payload); } - void OnHeadersFrameEnd() override { + bool OnHeadersFrameEnd() override { if (!VersionUsesQpack( stream_->session()->connection()->transport_version())) { CloseConnectionOnWrongFrame("Headers"); - return; + return false; } - stream_->OnHeadersFrameEnd(); + return stream_->OnHeadersFrameEnd(); } - void OnPushPromiseFrameStart(PushId push_id) override { + bool OnPushPromiseFrameStart(PushId push_id) override { CloseConnectionOnWrongFrame("Push Promise"); + return false; } - void OnPushPromiseFramePayload(QuicStringPiece payload) override { + bool OnPushPromiseFramePayload(QuicStringPiece payload) override { DCHECK(!payload.empty()); CloseConnectionOnWrongFrame("Push Promise"); + return false; } - void OnPushPromiseFrameEnd() override { + bool OnPushPromiseFrameEnd() override { CloseConnectionOnWrongFrame("Push Promise"); + return false; } private: @@ -470,7 +480,7 @@ SetPriority(priority); } -void QuicSpdyStream::OnStreamHeaderList(bool fin, +bool QuicSpdyStream::OnStreamHeaderList(bool fin, size_t frame_len, const QuicHeaderList& header_list) { // The headers list avoid infinite buffering by clearing the headers list @@ -482,7 +492,7 @@ if (header_list.empty()) { OnHeadersTooLarge(); if (IsDoneReading()) { - return; + return false; } } if (!headers_decompressed_) { @@ -490,6 +500,7 @@ } else { OnTrailingHeadersComplete(fin, frame_len, header_list); } + return !reading_stopped(); } void QuicSpdyStream::OnHeadersTooLarge() { @@ -701,25 +712,27 @@ spdy_session_ = nullptr; } -void QuicSpdyStream::OnDataFrameStart(Http3FrameLengths frame_lengths) { +bool QuicSpdyStream::OnDataFrameStart(Http3FrameLengths frame_lengths) { DCHECK( VersionHasDataFrameHeader(session()->connection()->transport_version())); - body_buffer_.OnDataHeader(frame_lengths); + return true; } -void QuicSpdyStream::OnDataFramePayload(QuicStringPiece payload) { +bool QuicSpdyStream::OnDataFramePayload(QuicStringPiece payload) { DCHECK( VersionHasDataFrameHeader(session()->connection()->transport_version())); body_buffer_.OnDataPayload(payload); + return true; } -void QuicSpdyStream::OnDataFrameEnd() { +bool QuicSpdyStream::OnDataFrameEnd() { DCHECK( VersionHasDataFrameHeader(session()->connection()->transport_version())); QUIC_DVLOG(1) << "Reaches the end of a data frame. Total bytes received are " << body_buffer_.total_body_bytes_received(); + return true; } bool QuicSpdyStream::OnStreamFrameAcked(QuicStreamOffset offset, @@ -769,7 +782,7 @@ return header_acked_length; } -void QuicSpdyStream::OnHeadersFrameStart(Http3FrameLengths frame_length) { +bool QuicSpdyStream::OnHeadersFrameStart(Http3FrameLengths frame_length) { DCHECK(VersionUsesQpack(spdy_session_->connection()->transport_version())); DCHECK(!qpack_decoded_headers_accumulator_); @@ -783,9 +796,10 @@ QuicMakeUnique<QpackDecodedHeadersAccumulator>( id(), spdy_session_->qpack_decoder(), spdy_session_->max_inbound_header_list_size()); + return true; } -void QuicSpdyStream::OnHeadersFramePayload(QuicStringPiece payload) { +bool QuicSpdyStream::OnHeadersFramePayload(QuicStringPiece payload) { DCHECK(VersionUsesQpack(spdy_session_->connection()->transport_version())); if (!qpack_decoded_headers_accumulator_->Decode(payload)) { @@ -794,11 +808,12 @@ QuicStrCat("Error decompressing header block on stream ", id(), ": ", qpack_decoded_headers_accumulator_->error_message()); CloseConnectionWithDetails(QUIC_DECOMPRESSION_FAILURE, error_message); - return; + return false; } + return true; } -void QuicSpdyStream::OnHeadersFrameEnd() { +bool QuicSpdyStream::OnHeadersFrameEnd() { DCHECK(VersionUsesQpack(spdy_session_->connection()->transport_version())); if (!qpack_decoded_headers_accumulator_->EndHeaderBlock()) { @@ -807,16 +822,18 @@ QuicStrCat("Error decompressing header block on stream ", id(), ": ", qpack_decoded_headers_accumulator_->error_message()); CloseConnectionWithDetails(QUIC_DECOMPRESSION_FAILURE, error_message); - return; + return false; } const QuicByteCount frame_length = headers_decompressed_ ? trailers_length_.payload_length : headers_length_.payload_length; - OnStreamHeaderList(/* fin = */ false, frame_length, - qpack_decoded_headers_accumulator_->quic_header_list()); + bool result = OnStreamHeaderList( + /* fin = */ false, frame_length, + qpack_decoded_headers_accumulator_->quic_header_list()); qpack_decoded_headers_accumulator_.reset(); + return result; } size_t QuicSpdyStream::WriteHeadersImpl(
diff --git a/quic/core/http/quic_spdy_stream.h b/quic/core/http/quic_spdy_stream.h index 2182d28..76c252d 100644 --- a/quic/core/http/quic_spdy_stream.h +++ b/quic/core/http/quic_spdy_stream.h
@@ -80,8 +80,9 @@ // Called by the session when decompressed headers have been completely // delivered to this stream. If |fin| is true, then this stream - // should be closed; no more data will be sent by the peer. - virtual void OnStreamHeaderList(bool fin, + // should be closed; no more data will be sent by the peer. Returns true if + // the headers are processed successfully without error. + virtual bool OnStreamHeaderList(bool fin, size_t frame_len, const QuicHeaderList& header_list); @@ -207,12 +208,12 @@ protected: // HTTP/3 - void OnDataFrameStart(Http3FrameLengths frame_lengths); - void OnDataFramePayload(QuicStringPiece payload); - void OnDataFrameEnd(); - void OnHeadersFrameStart(Http3FrameLengths frame_length); - void OnHeadersFramePayload(QuicStringPiece payload); - void OnHeadersFrameEnd(); + bool OnDataFrameStart(Http3FrameLengths frame_lengths); + bool OnDataFramePayload(QuicStringPiece payload); + bool OnDataFrameEnd(); + bool OnHeadersFrameStart(Http3FrameLengths frame_length); + bool OnHeadersFramePayload(QuicStringPiece payload); + bool OnHeadersFrameEnd(); // Called when the received headers are too large. By default this will // reset the stream.
diff --git a/quic/core/http/quic_spdy_stream_test.cc b/quic/core/http/quic_spdy_stream_test.cc index 7c423b4..7ecbbf5 100644 --- a/quic/core/http/quic_spdy_stream_test.cc +++ b/quic/core/http/quic_spdy_stream_test.cc
@@ -1863,6 +1863,63 @@ EXPECT_EQ("some data", data); } +TEST_P(QuicSpdyStreamTest, MalformedHeadersStopHttpDecoder) { + // The test stream will receive a stream frame containing malformed headers + // and normal body. Make sure the http decoder stops processing body after the + // connection shuts down. + testing::InSequence s; + if (!VersionUsesQpack(GetParam().transport_version)) { + return; + } + + Initialize(kShouldProcessData); + connection_->AdvanceTime(QuicTime::Delta::FromSeconds(1)); + + // Random bad headers. + std::string headers_frame_payload = + QuicTextUtils::HexDecode("00002a94e7036261"); + std::unique_ptr<char[]> headers_buffer; + QuicByteCount headers_frame_header_length = + encoder_.SerializeHeadersFrameHeader(headers_frame_payload.length(), + &headers_buffer); + QuicStringPiece headers_frame_header(headers_buffer.get(), + headers_frame_header_length); + + std::string data_frame_payload = "some data"; + std::unique_ptr<char[]> data_buffer; + QuicByteCount data_frame_header_length = encoder_.SerializeDataFrameHeader( + data_frame_payload.length(), &data_buffer); + QuicStringPiece data_frame_header(data_buffer.get(), + data_frame_header_length); + + std::string stream_frame_payload = + QuicStrCat(headers_frame_header, headers_frame_payload, data_frame_header, + data_frame_payload); + QuicStreamFrame frame(stream_->id(), false, 0, stream_frame_payload); + + EXPECT_CALL(*connection_, + CloseConnection(QUIC_DECOMPRESSION_FAILURE, + "Error decompressing header block on stream 4: " + "Incomplete header block.", + _)) + .WillOnce( + (Invoke([this](QuicErrorCode error, const std::string& error_details, + ConnectionCloseBehavior connection_close_behavior) { + connection_->ReallyCloseConnection(error, error_details, + connection_close_behavior); + }))); + EXPECT_CALL(*connection_, SendConnectionClosePacket(_, _)); + EXPECT_CALL(*session_, OnConnectionClosed(_, _, _)) + .WillOnce( + Invoke([this](QuicErrorCode error, const std::string& error_details, + ConnectionCloseSource source) { + session_->ReallyOnConnectionClosed(error, error_details, source); + })); + EXPECT_CALL(*session_, SendRstStream(_, _, _)); + EXPECT_CALL(*session_, SendRstStream(_, _, _)); + stream_->OnStreamFrame(frame); +} + } // namespace } // namespace test } // namespace quic