Supports the `:protocol` pseudo-header for connect requests if SETTINGS_ENABLE_CONNECT_PROTOCOL is sent.

PiperOrigin-RevId: 416316452
diff --git a/http2/adapter/header_validator.cc b/http2/adapter/header_validator.cc
index e37b737..5e8ee97 100644
--- a/http2/adapter/header_validator.cc
+++ b/http2/adapter/header_validator.cc
@@ -20,9 +20,14 @@
 
 const absl::string_view kHttp2StatusValueAllowedChars = "0123456789";
 
-// TODO(birenroy): Support websocket requests, which contain an extra
-// `:protocol` pseudo-header.
-bool ValidateRequestHeaders(const std::vector<std::string>& pseudo_headers) {
+bool ValidateRequestHeaders(const std::vector<std::string>& pseudo_headers,
+                            absl::string_view method, bool allow_connect) {
+  if (allow_connect && method == "CONNECT") {
+    static const std::vector<std::string>* kConnectHeaders =
+        new std::vector<std::string>(
+            {":authority", ":method", ":path", ":protocol", ":scheme"});
+    return pseudo_headers == *kConnectHeaders;
+  }
   static const std::vector<std::string>* kRequiredHeaders =
       new std::vector<std::string>(
           {":authority", ":method", ":path", ":scheme"});
@@ -48,6 +53,7 @@
 void HeaderValidator::StartHeaderBlock() {
   pseudo_headers_.clear();
   status_.clear();
+  method_.clear();
 }
 
 HeaderValidator::HeaderStatus HeaderValidator::ValidateSingleHeader(
@@ -82,6 +88,8 @@
         return HEADER_VALUE_INVALID_STATUS;
       }
       status_ = std::string(value);
+    } else if (key == ":method") {
+      method_ = std::string(value);
     }
     pseudo_headers_.push_back(std::string(key));
   }
@@ -94,7 +102,7 @@
   std::sort(pseudo_headers_.begin(), pseudo_headers_.end());
   switch (type) {
     case HeaderType::REQUEST:
-      return ValidateRequestHeaders(pseudo_headers_);
+      return ValidateRequestHeaders(pseudo_headers_, method_, allow_connect_);
     case HeaderType::REQUEST_TRAILER:
       return ValidateRequestTrailers(pseudo_headers_);
     case HeaderType::RESPONSE_100:
diff --git a/http2/adapter/header_validator.h b/http2/adapter/header_validator.h
index a27fdde..ecde204 100644
--- a/http2/adapter/header_validator.h
+++ b/http2/adapter/header_validator.h
@@ -22,6 +22,10 @@
  public:
   HeaderValidator() {}
 
+  // If called, this validator will allow the `:protocol` pseudo-header, as
+  // described in RFC 8441.
+  void AllowConnect() { allow_connect_ = true; }
+
   void StartHeaderBlock();
 
   enum HeaderStatus {
@@ -38,11 +42,14 @@
   // present for the given header type.
   bool FinishHeaderBlock(HeaderType type);
 
+  // For responses, returns the value of the ":status" header, if present.
   absl::string_view status_header() const { return status_; }
 
  private:
   std::vector<std::string> pseudo_headers_;
   std::string status_;
+  std::string method_;
+  bool allow_connect_ = false;
 };
 
 }  // namespace adapter
diff --git a/http2/adapter/header_validator_test.cc b/http2/adapter/header_validator_test.cc
index 01460c9..dcd92e1 100644
--- a/http2/adapter/header_validator_test.cc
+++ b/http2/adapter/header_validator_test.cc
@@ -154,8 +154,38 @@
   }
   EXPECT_EQ(HeaderValidator::HEADER_OK,
             v.ValidateSingleHeader(":protocol", "websocket"));
-  // For now, `:protocol` is treated as an extra pseudo-header.
+  // At this point, `:protocol` is treated as an extra pseudo-header.
   EXPECT_FALSE(v.FinishHeaderBlock(HeaderType::REQUEST));
+
+  // Future header blocks may send the `:protocol` pseudo-header for CONNECT
+  // requests.
+  v.AllowConnect();
+
+  v.StartHeaderBlock();
+  for (absl::string_view to_add : headers) {
+    EXPECT_EQ(HeaderValidator::HEADER_OK,
+              v.ValidateSingleHeader(to_add, "foo"));
+  }
+  EXPECT_EQ(HeaderValidator::HEADER_OK,
+            v.ValidateSingleHeader(":protocol", "websocket"));
+  // The method is "foo", not "CONNECT", so `:protocol` is still treated as an
+  // extra pseudo-header.
+  EXPECT_FALSE(v.FinishHeaderBlock(HeaderType::REQUEST));
+
+  v.StartHeaderBlock();
+  for (absl::string_view to_add : headers) {
+    if (to_add == ":method") {
+      EXPECT_EQ(HeaderValidator::HEADER_OK,
+                v.ValidateSingleHeader(to_add, "CONNECT"));
+    } else {
+      EXPECT_EQ(HeaderValidator::HEADER_OK,
+                v.ValidateSingleHeader(to_add, "foo"));
+    }
+  }
+  EXPECT_EQ(HeaderValidator::HEADER_OK,
+            v.ValidateSingleHeader(":protocol", "websocket"));
+  // After allowing the method, `:protocol` is acepted for CONNECT requests.
+  EXPECT_TRUE(v.FinishHeaderBlock(HeaderType::REQUEST));
 }
 
 TEST(HeaderValidatorTest, ResponsePseudoHeaders) {
diff --git a/http2/adapter/nghttp2_adapter_test.cc b/http2/adapter/nghttp2_adapter_test.cc
index 8b7d694..147d991 100644
--- a/http2/adapter/nghttp2_adapter_test.cc
+++ b/http2/adapter/nghttp2_adapter_test.cc
@@ -3531,6 +3531,154 @@
   EXPECT_THAT(visitor.data(), EqualsFrames({spdy::SpdyFrameType::SETTINGS}));
 }
 
+TEST(NgHttp2AdapterTest, ServerForbidsProtocolPseudoheaderBeforeAck) {
+  DataSavingVisitor visitor;
+  auto adapter = NgHttp2Adapter::CreateServerAdapter(visitor);
+
+  const std::string initial_frames =
+      TestFrameSequence().ClientPreface().Serialize();
+
+  testing::InSequence s;
+
+  // Client preface (empty SETTINGS)
+  EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0));
+  EXPECT_CALL(visitor, OnSettingsStart());
+  EXPECT_CALL(visitor, OnSettingsEnd());
+
+  const int64_t initial_result = adapter->ProcessBytes(initial_frames);
+  EXPECT_EQ(static_cast<size_t>(initial_result), initial_frames.size());
+
+  // The client attempts to send a CONNECT request with the `:protocol`
+  // pseudoheader before receiving the server's SETTINGS frame.
+  const std::string stream1_frames =
+      TestFrameSequence()
+          .Headers(1,
+                   {{":method", "CONNECT"},
+                    {":scheme", "https"},
+                    {":authority", "example.com"},
+                    {":path", "/this/is/request/one"},
+                    {":protocol", "websocket"}},
+                   /*fin=*/true)
+          .Serialize();
+
+  EXPECT_CALL(visitor, OnFrameHeader(1, _, HEADERS, 0x5));
+  EXPECT_CALL(visitor, OnBeginHeadersForStream(1));
+  EXPECT_CALL(visitor, OnHeaderForStream(1, _, _)).Times(4);
+  EXPECT_CALL(
+      visitor,
+      OnErrorDebug("Invalid HTTP header field was received: frame type: 1, "
+                   "stream: 1, name: [:protocol], value: [websocket]"));
+  EXPECT_CALL(
+      visitor,
+      OnInvalidFrame(1, Http2VisitorInterface::InvalidFrameError::kHttpHeader));
+
+  int64_t stream_result = adapter->ProcessBytes(stream1_frames);
+  EXPECT_EQ(static_cast<size_t>(stream_result), stream1_frames.size());
+
+  // Server sends a SETTINGS ack and initial SETTINGS (with
+  // ENABLE_CONNECT_PROTOCOL).
+  EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, 0, 0x1));
+  EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, 0, 0x1, 0));
+  EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, 6, 0x0));
+  EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, 6, 0x0, 0));
+
+  // The server sends a RST_STREAM for the offending stream.
+  EXPECT_CALL(visitor, OnBeforeFrameSent(RST_STREAM, 1, _, 0x0));
+  EXPECT_CALL(visitor,
+              OnFrameSent(RST_STREAM, 1, _, 0x0,
+                          static_cast<int>(Http2ErrorCode::PROTOCOL_ERROR)));
+  EXPECT_CALL(visitor, OnCloseStream(1, Http2ErrorCode::PROTOCOL_ERROR));
+
+  adapter->SubmitSettings({{ENABLE_CONNECT_PROTOCOL, 1}});
+  int send_result = adapter->Send();
+  EXPECT_EQ(0, send_result);
+  EXPECT_THAT(visitor.data(), EqualsFrames({spdy::SpdyFrameType::SETTINGS,
+                                            spdy::SpdyFrameType::SETTINGS,
+                                            spdy::SpdyFrameType::RST_STREAM}));
+  visitor.Clear();
+
+  // The client attempts to send a CONNECT request with the `:protocol`
+  // pseudoheader before acking the server's SETTINGS frame.
+  const std::string stream3_frames =
+      TestFrameSequence()
+          .Headers(3,
+                   {{":method", "CONNECT"},
+                    {":scheme", "https"},
+                    {":authority", "example.com"},
+                    {":path", "/this/is/request/two"},
+                    {":protocol", "websocket"}},
+                   /*fin=*/true)
+          .Serialize();
+
+  // Surprisingly, nghttp2 is okay with this.
+  EXPECT_CALL(visitor, OnFrameHeader(3, _, HEADERS, 0x5));
+  EXPECT_CALL(visitor, OnBeginHeadersForStream(3));
+  EXPECT_CALL(visitor, OnHeaderForStream(3, _, _)).Times(5);
+  EXPECT_CALL(visitor, OnEndHeadersForStream(3));
+  EXPECT_CALL(visitor, OnEndStream(3));
+
+  stream_result = adapter->ProcessBytes(stream3_frames);
+  EXPECT_EQ(static_cast<size_t>(stream_result), stream3_frames.size());
+
+  EXPECT_FALSE(adapter->want_write());
+}
+
+TEST(NgHttp2AdapterTest, ServerAllowsProtocolPseudoheaderAfterAck) {
+  DataSavingVisitor visitor;
+  auto adapter = NgHttp2Adapter::CreateServerAdapter(visitor);
+  adapter->SubmitSettings({{ENABLE_CONNECT_PROTOCOL, 1}});
+
+  const std::string initial_frames =
+      TestFrameSequence().ClientPreface().Serialize();
+
+  testing::InSequence s;
+
+  // Client preface (empty SETTINGS)
+  EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0));
+  EXPECT_CALL(visitor, OnSettingsStart());
+  EXPECT_CALL(visitor, OnSettingsEnd());
+
+  const int64_t initial_result = adapter->ProcessBytes(initial_frames);
+  EXPECT_EQ(static_cast<size_t>(initial_result), initial_frames.size());
+
+  // Server initial SETTINGS (with ENABLE_CONNECT_PROTOCOL) and SETTINGS ack.
+  EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, 6, 0x0));
+  EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, 6, 0x0, 0));
+  EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, 0, 0x1));
+  EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, 0, 0x1, 0));
+
+  int send_result = adapter->Send();
+  EXPECT_EQ(0, send_result);
+  visitor.Clear();
+
+  // The client attempts to send a CONNECT request with the `:protocol`
+  // pseudoheader after acking the server's SETTINGS frame.
+  const std::string stream_frames =
+      TestFrameSequence()
+          .SettingsAck()
+          .Headers(1,
+                   {{":method", "CONNECT"},
+                    {":scheme", "https"},
+                    {":authority", "example.com"},
+                    {":path", "/this/is/request/one"},
+                    {":protocol", "websocket"}},
+                   /*fin=*/true)
+          .Serialize();
+
+  EXPECT_CALL(visitor, OnFrameHeader(0, _, SETTINGS, 0x1));
+  EXPECT_CALL(visitor, OnSettingsAck());
+  EXPECT_CALL(visitor, OnFrameHeader(1, _, HEADERS, 0x5));
+  EXPECT_CALL(visitor, OnBeginHeadersForStream(1));
+  EXPECT_CALL(visitor, OnHeaderForStream(1, _, _)).Times(5);
+  EXPECT_CALL(visitor, OnEndHeadersForStream(1));
+  EXPECT_CALL(visitor, OnEndStream(1));
+
+  const int64_t stream_result = adapter->ProcessBytes(stream_frames);
+  EXPECT_EQ(static_cast<size_t>(stream_result), stream_frames.size());
+
+  EXPECT_FALSE(adapter->want_write());
+}
+
 }  // namespace
 }  // namespace test
 }  // namespace adapter
diff --git a/http2/adapter/oghttp2_adapter_test.cc b/http2/adapter/oghttp2_adapter_test.cc
index 3b22a92..b274f4a 100644
--- a/http2/adapter/oghttp2_adapter_test.cc
+++ b/http2/adapter/oghttp2_adapter_test.cc
@@ -3563,6 +3563,168 @@
   EXPECT_THAT(visitor.data(), EqualsFrames({spdy::SpdyFrameType::RST_STREAM}));
 }
 
+TEST(OgHttp2AdapterServerTest, ServerForbidsProtocolPseudoheaderBeforeAck) {
+  DataSavingVisitor visitor;
+  OgHttp2Adapter::Options options{.perspective = Perspective::kServer};
+  auto adapter = OgHttp2Adapter::Create(visitor, options);
+
+  const std::string initial_frames =
+      TestFrameSequence().ClientPreface().Serialize();
+
+  testing::InSequence s;
+
+  // Client preface (empty SETTINGS)
+  EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0));
+  EXPECT_CALL(visitor, OnSettingsStart());
+  EXPECT_CALL(visitor, OnSettingsEnd());
+
+  const int64_t initial_result = adapter->ProcessBytes(initial_frames);
+  EXPECT_EQ(static_cast<size_t>(initial_result), initial_frames.size());
+
+  // The client attempts to send a CONNECT request with the `:protocol`
+  // pseudoheader before receiving the server's SETTINGS frame.
+  const std::string stream1_frames =
+      TestFrameSequence()
+          .Headers(1,
+                   {{":method", "CONNECT"},
+                    {":scheme", "https"},
+                    {":authority", "example.com"},
+                    {":path", "/this/is/request/one"},
+                    {":protocol", "websocket"}},
+                   /*fin=*/true)
+          .Serialize();
+
+  EXPECT_CALL(visitor, OnFrameHeader(1, _, HEADERS, 0x5));
+  EXPECT_CALL(visitor, OnBeginHeadersForStream(1));
+  EXPECT_CALL(visitor, OnHeaderForStream(1, _, _)).Times(5);
+  EXPECT_CALL(visitor,
+              OnInvalidFrame(
+                  1, Http2VisitorInterface::InvalidFrameError::kHttpMessaging));
+  EXPECT_CALL(visitor, OnEndStream(1));
+
+  int64_t stream_result = adapter->ProcessBytes(stream1_frames);
+  EXPECT_EQ(static_cast<size_t>(stream_result), stream1_frames.size());
+
+  // Server initial SETTINGS and SETTINGS ack.
+  EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, 0, 0x0));
+  EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, 0, 0x0, 0));
+  EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, 0, 0x1));
+  EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, 0, 0x1, 0));
+
+  // The server sends a RST_STREAM for the offending stream.
+  EXPECT_CALL(visitor, OnBeforeFrameSent(RST_STREAM, 1, _, 0x0));
+  EXPECT_CALL(visitor,
+              OnFrameSent(RST_STREAM, 1, _, 0x0,
+                          static_cast<int>(Http2ErrorCode::PROTOCOL_ERROR)));
+  EXPECT_CALL(visitor, OnCloseStream(1, Http2ErrorCode::HTTP2_NO_ERROR));
+
+  // Server settings with ENABLE_CONNECT_PROTOCOL.
+  EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, 6, 0x0));
+  EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, 6, 0x0, 0));
+
+  adapter->SubmitSettings({{ENABLE_CONNECT_PROTOCOL, 1}});
+  int send_result = adapter->Send();
+  EXPECT_EQ(0, send_result);
+  EXPECT_THAT(visitor.data(), EqualsFrames({spdy::SpdyFrameType::SETTINGS,
+                                            spdy::SpdyFrameType::SETTINGS,
+                                            spdy::SpdyFrameType::RST_STREAM,
+                                            spdy::SpdyFrameType::SETTINGS}));
+  visitor.Clear();
+
+  // The client attempts to send a CONNECT request with the `:protocol`
+  // pseudoheader before acking the server's SETTINGS frame.
+  const std::string stream3_frames =
+      TestFrameSequence()
+          .Headers(3,
+                   {{":method", "CONNECT"},
+                    {":scheme", "https"},
+                    {":authority", "example.com"},
+                    {":path", "/this/is/request/two"},
+                    {":protocol", "websocket"}},
+                   /*fin=*/true)
+          .Serialize();
+
+  EXPECT_CALL(visitor, OnFrameHeader(3, _, HEADERS, 0x5));
+  EXPECT_CALL(visitor, OnBeginHeadersForStream(3));
+  EXPECT_CALL(visitor, OnHeaderForStream(3, _, _)).Times(5);
+  EXPECT_CALL(visitor,
+              OnInvalidFrame(
+                  3, Http2VisitorInterface::InvalidFrameError::kHttpMessaging));
+  EXPECT_CALL(visitor, OnEndStream(3));
+
+  stream_result = adapter->ProcessBytes(stream3_frames);
+  EXPECT_EQ(static_cast<size_t>(stream_result), stream3_frames.size());
+
+  // The server sends a RST_STREAM for the offending stream.
+  EXPECT_TRUE(adapter->want_write());
+  EXPECT_CALL(visitor, OnBeforeFrameSent(RST_STREAM, 3, _, 0x0));
+  EXPECT_CALL(visitor,
+              OnFrameSent(RST_STREAM, 3, _, 0x0,
+                          static_cast<int>(Http2ErrorCode::PROTOCOL_ERROR)));
+  EXPECT_CALL(visitor, OnCloseStream(3, Http2ErrorCode::HTTP2_NO_ERROR));
+
+  send_result = adapter->Send();
+  EXPECT_EQ(0, send_result);
+  EXPECT_THAT(visitor.data(), EqualsFrames({spdy::SpdyFrameType::RST_STREAM}));
+}
+
+TEST(OgHttp2AdapterServerTest, ServerAllowsProtocolPseudoheaderAfterAck) {
+  DataSavingVisitor visitor;
+  OgHttp2Adapter::Options options{.perspective = Perspective::kServer};
+  auto adapter = OgHttp2Adapter::Create(visitor, options);
+  adapter->SubmitSettings({{ENABLE_CONNECT_PROTOCOL, 1}});
+
+  const std::string initial_frames =
+      TestFrameSequence().ClientPreface().Serialize();
+
+  testing::InSequence s;
+
+  // Client preface (empty SETTINGS)
+  EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0));
+  EXPECT_CALL(visitor, OnSettingsStart());
+  EXPECT_CALL(visitor, OnSettingsEnd());
+
+  const int64_t initial_result = adapter->ProcessBytes(initial_frames);
+  EXPECT_EQ(static_cast<size_t>(initial_result), initial_frames.size());
+
+  // Server initial SETTINGS (with ENABLE_CONNECT_PROTOCOL) and SETTINGS ack.
+  EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, 6, 0x0));
+  EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, 6, 0x0, 0));
+  EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, 0, 0x1));
+  EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, 0, 0x1, 0));
+
+  int send_result = adapter->Send();
+  EXPECT_EQ(0, send_result);
+  visitor.Clear();
+
+  // The client attempts to send a CONNECT request with the `:protocol`
+  // pseudoheader after acking the server's SETTINGS frame.
+  const std::string stream_frames =
+      TestFrameSequence()
+          .SettingsAck()
+          .Headers(1,
+                   {{":method", "CONNECT"},
+                    {":scheme", "https"},
+                    {":authority", "example.com"},
+                    {":path", "/this/is/request/one"},
+                    {":protocol", "websocket"}},
+                   /*fin=*/true)
+          .Serialize();
+
+  EXPECT_CALL(visitor, OnFrameHeader(0, _, SETTINGS, 0x1));
+  EXPECT_CALL(visitor, OnSettingsAck());
+  EXPECT_CALL(visitor, OnFrameHeader(1, _, HEADERS, 0x5));
+  EXPECT_CALL(visitor, OnBeginHeadersForStream(1));
+  EXPECT_CALL(visitor, OnHeaderForStream(1, _, _)).Times(5);
+  EXPECT_CALL(visitor, OnEndHeadersForStream(1));
+  EXPECT_CALL(visitor, OnEndStream(1));
+
+  const int64_t stream_result = adapter->ProcessBytes(stream_frames);
+  EXPECT_EQ(static_cast<size_t>(stream_result), stream_frames.size());
+
+  EXPECT_FALSE(adapter->want_write());
+}
+
 }  // namespace
 }  // namespace test
 }  // namespace adapter
diff --git a/http2/adapter/oghttp2_session.cc b/http2/adapter/oghttp2_session.cc
index aaf97d7..1f6d2fd 100644
--- a/http2/adapter/oghttp2_session.cc
+++ b/http2/adapter/oghttp2_session.cc
@@ -1274,6 +1274,9 @@
           } else if (id_and_value.first == spdy::SETTINGS_HEADER_TABLE_SIZE) {
             decoder_.GetHpackDecoder()->ApplyHeaderTableSizeSetting(
                 id_and_value.second);
+          } else if (id_and_value.first == ENABLE_CONNECT_PROTOCOL &&
+                     id_and_value.second == 1u) {
+            headers_handler_.AllowConnect();
           }
         }
       });
diff --git a/http2/adapter/oghttp2_session.h b/http2/adapter/oghttp2_session.h
index e4a2374..74dedac 100644
--- a/http2/adapter/oghttp2_session.h
+++ b/http2/adapter/oghttp2_session.h
@@ -250,6 +250,7 @@
                     type_ == HeaderType::RESPONSE_100);
       return validator_.status_header();
     }
+    void AllowConnect() { validator_.AllowConnect(); }
 
    private:
     OgHttp2Session& session_;