Internal QUICHE change

PiperOrigin-RevId: 270719218
Change-Id: If2d4f9a6246777ef7920c276e41d97f7e4295631
diff --git a/quic/core/http/http_decoder.cc b/quic/core/http/http_decoder.cc
index e1fa36e..7fe994e 100644
--- a/quic/core/http/http_decoder.cc
+++ b/quic/core/http/http_decoder.cc
@@ -153,10 +153,13 @@
       continue_processing = visitor_->OnSettingsFrameStart(header_length);
       break;
     case static_cast<uint64_t>(HttpFrameType::PUSH_PROMISE):
+      // This edge case needs to be handled here, because ReadFramePayload()
+      // does not get called if |current_frame_length_| is zero.
       if (current_frame_length_ == 0) {
         RaiseError(QUIC_INVALID_FRAME_DATA, "Corrupt PUSH_PROMISE frame.");
         return false;
       }
+      continue_processing = visitor_->OnPushPromiseFrameStart(header_length);
       break;
     case static_cast<uint64_t>(HttpFrameType::GOAWAY):
       break;
@@ -239,10 +242,8 @@
         bool success = reader->ReadVarInt62(&push_id);
         DCHECK(success);
         remaining_frame_length_ -= current_push_id_length_;
-        if (!visitor_->OnPushPromiseFrameStart(
-                push_id,
-                current_length_field_length_ + current_type_field_length_,
-                current_push_id_length_)) {
+        if (!visitor_->OnPushPromiseFramePushId(push_id,
+                                                current_push_id_length_)) {
           continue_processing = false;
           current_push_id_length_ = 0;
           break;
@@ -259,10 +260,8 @@
 
         bool success = push_id_reader.ReadVarInt62(&push_id);
         DCHECK(success);
-        if (!visitor_->OnPushPromiseFrameStart(
-                push_id,
-                current_length_field_length_ + current_type_field_length_,
-                current_push_id_length_)) {
+        if (!visitor_->OnPushPromiseFramePushId(push_id,
+                                                current_push_id_length_)) {
           continue_processing = false;
           current_push_id_length_ = 0;
           break;
diff --git a/quic/core/http/http_decoder.h b/quic/core/http/http_decoder.h
index 30ebadc..862c5de 100644
--- a/quic/core/http/http_decoder.h
+++ b/quic/core/http/http_decoder.h
@@ -37,11 +37,10 @@
     // All the following methods return true to continue decoding,
     // and false to pause it.
     // On*FrameStart() methods are called after the frame header is completely
-    // processed.  At that point it is safe to consume
-    // |frame_length.header_length| bytes.
+    // processed.  At that point it is safe to consume |header_length| bytes.
 
     // Called when a PRIORITY frame has been received.
-    // |frame_length| contains PRIORITY frame length and payload length.
+    // |header_length| contains PRIORITY frame length and payload length.
     virtual bool OnPriorityFrameStart(QuicByteCount header_length) = 0;
 
     // Called when a PRIORITY frame has been successfully parsed.
@@ -66,7 +65,7 @@
     virtual bool OnDuplicatePushFrame(const DuplicatePushFrame& frame) = 0;
 
     // Called when a DATA frame has been received.
-    // |frame_length| contains DATA frame length and payload length.
+    // |header_length| contains DATA frame length and payload length.
     virtual bool OnDataFrameStart(QuicByteCount header_length) = 0;
     // Called when part of the payload of a DATA frame has been read.  May be
     // called multiple times for a single frame.  |payload| is guaranteed to be
@@ -76,7 +75,7 @@
     virtual bool OnDataFrameEnd() = 0;
 
     // Called when a HEADERS frame has been received.
-    // |frame_length| contains HEADERS frame length and payload length.
+    // |header_length| contains HEADERS frame length and payload length.
     virtual bool OnHeadersFrameStart(QuicByteCount header_length) = 0;
     // Called when part of the payload of a HEADERS frame has been read.  May be
     // called multiple times for a single frame.  |payload| is guaranteed to be
@@ -86,20 +85,22 @@
     // |frame_len| is the length of the HEADERS frame payload.
     virtual bool OnHeadersFrameEnd() = 0;
 
-    // Called when a PUSH_PROMISE frame has been received for |push_id|.
-    virtual bool OnPushPromiseFrameStart(PushId push_id,
-                                         QuicByteCount header_length,
-                                         QuicByteCount push_id_length) = 0;
-    // Called when part of the payload of a PUSH_PROMISE frame has been read.
-    // May be called multiple times for a single frame.  |payload| is guaranteed
-    // to be non-empty.
+    // Called when a PUSH_PROMISE frame has been received.
+    virtual bool OnPushPromiseFrameStart(QuicByteCount header_length) = 0;
+    // Called when the Push ID field of a PUSH_PROMISE frame has been parsed.
+    // Called exactly once for a valid PUSH_PROMISE frame.
+    virtual bool OnPushPromiseFramePushId(PushId push_id,
+                                          QuicByteCount push_id_length) = 0;
+    // Called when part of the header block of a PUSH_PROMISE frame has been
+    // read. May be called multiple times for a single frame.  |payload| is
+    // guaranteed to be non-empty.
     virtual bool OnPushPromiseFramePayload(QuicStringPiece payload) = 0;
     // Called when a PUSH_PROMISE frame has been completely processed.
     virtual bool OnPushPromiseFrameEnd() = 0;
 
     // Called when a frame of unknown type |frame_type| has been received.
     // Frame type might be reserved, Visitor must make sure to ignore.
-    // |frame_length| contains frame length and payload length.
+    // |header_length| contains frame length and payload length.
     virtual bool OnUnknownFrameStart(uint64_t frame_type,
                                      QuicByteCount header_length) = 0;
     // Called when part of the payload of the unknown frame has been read.  May
diff --git a/quic/core/http/http_decoder_test.cc b/quic/core/http/http_decoder_test.cc
index 3abc8be..4baf271 100644
--- a/quic/core/http/http_decoder_test.cc
+++ b/quic/core/http/http_decoder_test.cc
@@ -52,10 +52,9 @@
   MOCK_METHOD1(OnHeadersFramePayload, bool(QuicStringPiece payload));
   MOCK_METHOD0(OnHeadersFrameEnd, bool());
 
-  MOCK_METHOD3(OnPushPromiseFrameStart,
-               bool(PushId push_id,
-                    QuicByteCount header_length,
-                    QuicByteCount push_id_length));
+  MOCK_METHOD1(OnPushPromiseFrameStart, bool(QuicByteCount header_length));
+  MOCK_METHOD2(OnPushPromiseFramePushId,
+               bool(PushId push_id, QuicByteCount push_id_length));
   MOCK_METHOD1(OnPushPromiseFramePayload, bool(QuicStringPiece payload));
   MOCK_METHOD0(OnPushPromiseFrameEnd, bool());
 
@@ -81,7 +80,8 @@
     ON_CALL(visitor_, OnHeadersFrameStart(_)).WillByDefault(Return(true));
     ON_CALL(visitor_, OnHeadersFramePayload(_)).WillByDefault(Return(true));
     ON_CALL(visitor_, OnHeadersFrameEnd()).WillByDefault(Return(true));
-    ON_CALL(visitor_, OnPushPromiseFrameStart(_, _, _))
+    ON_CALL(visitor_, OnPushPromiseFrameStart(_)).WillByDefault(Return(true));
+    ON_CALL(visitor_, OnPushPromiseFramePushId(_, _))
         .WillByDefault(Return(true));
     ON_CALL(visitor_, OnPushPromiseFramePayload(_)).WillByDefault(Return(true));
     ON_CALL(visitor_, OnPushPromiseFrameEnd()).WillByDefault(Return(true));
@@ -210,12 +210,16 @@
                  "Headers");                                    // headers
 
   // Visitor pauses processing.
-  EXPECT_CALL(visitor_, OnPushPromiseFrameStart(257, 2, 8))
+  EXPECT_CALL(visitor_, OnPushPromiseFrameStart(2)).WillOnce(Return(false));
+  EXPECT_CALL(visitor_, OnPushPromiseFramePushId(257, 8))
       .WillOnce(Return(false));
   QuicStringPiece remaining_input(input);
   QuicByteCount processed_bytes =
       ProcessInputWithGarbageAppended(remaining_input);
-  EXPECT_EQ(10u, processed_bytes);
+  EXPECT_EQ(2u, processed_bytes);
+  remaining_input = remaining_input.substr(processed_bytes);
+  processed_bytes = ProcessInputWithGarbageAppended(remaining_input);
+  EXPECT_EQ(8u, processed_bytes);
   remaining_input = remaining_input.substr(processed_bytes);
 
   EXPECT_CALL(visitor_, OnPushPromiseFramePayload(QuicStringPiece("Headers")))
@@ -229,7 +233,8 @@
   EXPECT_EQ("", decoder_.error_detail());
 
   // Process the full frame.
-  EXPECT_CALL(visitor_, OnPushPromiseFrameStart(257, 2, 8));
+  EXPECT_CALL(visitor_, OnPushPromiseFrameStart(2));
+  EXPECT_CALL(visitor_, OnPushPromiseFramePushId(257, 8));
   EXPECT_CALL(visitor_, OnPushPromiseFramePayload(QuicStringPiece("Headers")));
   EXPECT_CALL(visitor_, OnPushPromiseFrameEnd());
   EXPECT_EQ(input.size(), ProcessInput(input));
@@ -237,7 +242,8 @@
   EXPECT_EQ("", decoder_.error_detail());
 
   // Process the frame incrementally.
-  EXPECT_CALL(visitor_, OnPushPromiseFrameStart(257, 2, 8));
+  EXPECT_CALL(visitor_, OnPushPromiseFrameStart(2));
+  EXPECT_CALL(visitor_, OnPushPromiseFramePushId(257, 8));
   EXPECT_CALL(visitor_, OnPushPromiseFramePayload(QuicStringPiece("H")));
   EXPECT_CALL(visitor_, OnPushPromiseFramePayload(QuicStringPiece("e")));
   EXPECT_CALL(visitor_, OnPushPromiseFramePayload(QuicStringPiece("a")));
@@ -251,7 +257,8 @@
   EXPECT_EQ("", decoder_.error_detail());
 
   // Process push id incrementally and append headers with last byte of push id.
-  EXPECT_CALL(visitor_, OnPushPromiseFrameStart(257, 2, 8));
+  EXPECT_CALL(visitor_, OnPushPromiseFrameStart(2));
+  EXPECT_CALL(visitor_, OnPushPromiseFramePushId(257, 8));
   EXPECT_CALL(visitor_, OnPushPromiseFramePayload(QuicStringPiece("Headers")));
   EXPECT_CALL(visitor_, OnPushPromiseFrameEnd());
   ProcessInputCharByChar(input.substr(0, 9));
@@ -260,6 +267,38 @@
   EXPECT_EQ("", decoder_.error_detail());
 }
 
+TEST_F(HttpDecoderTest, CorruptPushPromiseFrame) {
+  InSequence s;
+
+  std::string input = QuicTextUtils::HexDecode(
+      "05"    // type (PUSH_PROMISE)
+      "01"    // length
+      "40");  // first byte of two-byte varint push id
+
+  {
+    HttpDecoder decoder(&visitor_);
+    EXPECT_CALL(visitor_, OnPushPromiseFrameStart(2));
+    EXPECT_CALL(visitor_, OnError(&decoder));
+
+    decoder.ProcessInput(input.data(), input.size());
+
+    EXPECT_EQ(QUIC_INVALID_FRAME_DATA, decoder.error());
+    EXPECT_EQ("PUSH_PROMISE frame malformed.", decoder.error_detail());
+  }
+  {
+    HttpDecoder decoder(&visitor_);
+    EXPECT_CALL(visitor_, OnPushPromiseFrameStart(2));
+    EXPECT_CALL(visitor_, OnError(&decoder));
+
+    for (auto c : input) {
+      decoder.ProcessInput(&c, 1);
+    }
+
+    EXPECT_EQ(QUIC_INVALID_FRAME_DATA, decoder.error());
+    EXPECT_EQ("PUSH_PROMISE frame malformed.", decoder.error_detail());
+  }
+}
+
 TEST_F(HttpDecoderTest, MaxPushId) {
   InSequence s;
   std::string input = QuicTextUtils::HexDecode(
@@ -767,8 +806,8 @@
       "01");  // Push Id
 
   // Visitor pauses processing.
-  EXPECT_CALL(visitor_, OnPushPromiseFrameStart(1, 2, 1))
-      .WillOnce(Return(false));
+  EXPECT_CALL(visitor_, OnPushPromiseFrameStart(2));
+  EXPECT_CALL(visitor_, OnPushPromiseFramePushId(1, 1)).WillOnce(Return(false));
   EXPECT_EQ(input.size(), ProcessInputWithGarbageAppended(input));
 
   EXPECT_CALL(visitor_, OnPushPromiseFrameEnd()).WillOnce(Return(false));
@@ -777,14 +816,16 @@
   EXPECT_EQ("", decoder_.error_detail());
 
   // Process the full frame.
-  EXPECT_CALL(visitor_, OnPushPromiseFrameStart(1, 2, 1));
+  EXPECT_CALL(visitor_, OnPushPromiseFrameStart(2));
+  EXPECT_CALL(visitor_, OnPushPromiseFramePushId(1, 1));
   EXPECT_CALL(visitor_, OnPushPromiseFrameEnd());
   EXPECT_EQ(input.size(), ProcessInput(input));
   EXPECT_EQ(QUIC_NO_ERROR, decoder_.error());
   EXPECT_EQ("", decoder_.error_detail());
 
   // Process the frame incrementally.
-  EXPECT_CALL(visitor_, OnPushPromiseFrameStart(1, 2, 1));
+  EXPECT_CALL(visitor_, OnPushPromiseFrameStart(2));
+  EXPECT_CALL(visitor_, OnPushPromiseFramePushId(1, 1));
   EXPECT_CALL(visitor_, OnPushPromiseFrameEnd());
   ProcessInputCharByChar(input);
   EXPECT_EQ(QUIC_NO_ERROR, decoder_.error());
@@ -865,10 +906,6 @@
                     "\x05"  // valid push id
                     "foo",  // superfluous data
                     "Superfluous data in CANCEL_PUSH frame."},
-                   {"\x05"   // type (PUSH_PROMISE)
-                    "\x01"   // length
-                    "\x40",  // first byte of two-byte varint push id
-                    "PUSH_PROMISE frame malformed."},
                    {"\x0D"   // type (MAX_PUSH_ID)
                     "\x01"   // length
                     "\x40",  // first byte of two-byte varint push id
diff --git a/quic/core/http/quic_receive_control_stream.cc b/quic/core/http/quic_receive_control_stream.cc
index 86005e7..dd9f3ec 100644
--- a/quic/core/http/quic_receive_control_stream.cc
+++ b/quic/core/http/quic_receive_control_stream.cc
@@ -97,7 +97,7 @@
     return false;
   }
 
-  bool OnHeadersFrameStart(QuicByteCount /*frame_length*/) override {
+  bool OnHeadersFrameStart(QuicByteCount /*header_length*/) override {
     CloseConnectionOnWrongFrame("Headers");
     return false;
   }
@@ -112,9 +112,13 @@
     return false;
   }
 
-  bool OnPushPromiseFrameStart(PushId /*push_id*/,
-                               QuicByteCount /*frame_length*/,
-                               QuicByteCount /*push_id_length*/) override {
+  bool OnPushPromiseFrameStart(QuicByteCount /*header_length*/) override {
+    CloseConnectionOnWrongFrame("Push Promise");
+    return false;
+  }
+
+  bool OnPushPromiseFramePushId(PushId /*push_id*/,
+                                QuicByteCount /*push_id_length*/) override {
     CloseConnectionOnWrongFrame("Push Promise");
     return false;
   }
@@ -130,7 +134,7 @@
   }
 
   bool OnUnknownFrameStart(uint64_t /* frame_type */,
-                           QuicByteCount /* frame_length */) override {
+                           QuicByteCount /* header_length */) override {
     // Ignore unknown frame types.
     return true;
   }
diff --git a/quic/core/http/quic_receive_control_stream_test.cc b/quic/core/http/quic_receive_control_stream_test.cc
index 6914681..abb5d5e 100644
--- a/quic/core/http/quic_receive_control_stream_test.cc
+++ b/quic/core/http/quic_receive_control_stream_test.cc
@@ -257,7 +257,10 @@
                         length);
   // TODO(lassey) Check for HTTP_WRONG_STREAM error code.
   EXPECT_CALL(*connection_, CloseConnection(QUIC_HTTP_DECODER_ERROR, _, _))
-      .Times(AtLeast(1));
+      .WillOnce(
+          Invoke(connection_, &MockQuicConnection::ReallyCloseConnection));
+  EXPECT_CALL(*connection_, SendConnectionClosePacket(_, _));
+  EXPECT_CALL(session_, OnConnectionClosed(_, _));
   receive_control_stream_->OnStreamFrame(frame);
 }
 
diff --git a/quic/core/http/quic_spdy_stream.cc b/quic/core/http/quic_spdy_stream.cc
index a089345..ea8e669 100644
--- a/quic/core/http/quic_spdy_stream.cc
+++ b/quic/core/http/quic_spdy_stream.cc
@@ -124,15 +124,21 @@
     return stream_->OnHeadersFrameEnd();
   }
 
-  bool OnPushPromiseFrameStart(PushId push_id,
-                               QuicByteCount header_length,
-                               QuicByteCount push_id_length) override {
+  bool OnPushPromiseFrameStart(QuicByteCount header_length) override {
     if (!VersionUsesQpack(stream_->transport_version())) {
       CloseConnectionOnWrongFrame("Push Promise");
       return false;
     }
-    return stream_->OnPushPromiseFrameStart(push_id, header_length,
-                                            push_id_length);
+    return stream_->OnPushPromiseFrameStart(header_length);
+  }
+
+  bool OnPushPromiseFramePushId(PushId push_id,
+                                QuicByteCount push_id_length) override {
+    if (!VersionUsesQpack(stream_->transport_version())) {
+      CloseConnectionOnWrongFrame("Push Promise");
+      return false;
+    }
+    return stream_->OnPushPromiseFramePushId(push_id, push_id_length);
   }
 
   bool OnPushPromiseFramePayload(QuicStringPiece payload) override {
@@ -929,16 +935,23 @@
   return !sequencer()->IsClosed() && !reading_stopped();
 }
 
-bool QuicSpdyStream::OnPushPromiseFrameStart(PushId push_id,
-                                             QuicByteCount header_length,
-                                             QuicByteCount push_id_length) {
+bool QuicSpdyStream::OnPushPromiseFrameStart(QuicByteCount header_length) {
+  DCHECK(VersionHasStreamType(transport_version()));
+  DCHECK(!qpack_decoded_headers_accumulator_);
+
+  sequencer()->MarkConsumed(body_manager_.OnNonBody(header_length));
+
+  return true;
+}
+
+bool QuicSpdyStream::OnPushPromiseFramePushId(PushId push_id,
+                                              QuicByteCount push_id_length) {
   DCHECK(VersionHasStreamType(transport_version()));
   DCHECK(!qpack_decoded_headers_accumulator_);
 
   // TODO(renjietang): Check max push id and handle errors.
   spdy_session_->OnPushPromise(id(), push_id);
-  sequencer()->MarkConsumed(
-      body_manager_.OnNonBody(header_length + push_id_length));
+  sequencer()->MarkConsumed(body_manager_.OnNonBody(push_id_length));
 
   qpack_decoded_headers_accumulator_ =
       std::make_unique<QpackDecodedHeadersAccumulator>(
diff --git a/quic/core/http/quic_spdy_stream.h b/quic/core/http/quic_spdy_stream.h
index 4849705..a2ebc54 100644
--- a/quic/core/http/quic_spdy_stream.h
+++ b/quic/core/http/quic_spdy_stream.h
@@ -259,9 +259,8 @@
   bool OnHeadersFrameStart(QuicByteCount header_length);
   bool OnHeadersFramePayload(QuicStringPiece payload);
   bool OnHeadersFrameEnd();
-  bool OnPushPromiseFrameStart(PushId push_id,
-                               QuicByteCount header_length,
-                               QuicByteCount push_id_length);
+  bool OnPushPromiseFrameStart(QuicByteCount header_length);
+  bool OnPushPromiseFramePushId(PushId push_id, QuicByteCount push_id_length);
   bool OnPushPromiseFramePayload(QuicStringPiece payload);
   bool OnPushPromiseFrameEnd();
   bool OnUnknownFrameStart(uint64_t frame_type, QuicByteCount header_length);