Supports the `:protocol` pseudo-header for connect requests if SETTINGS_ENABLE_CONNECT_PROTOCOL is sent.
PiperOrigin-RevId: 416316452
diff --git a/http2/adapter/header_validator.cc b/http2/adapter/header_validator.cc
index e37b737..5e8ee97 100644
--- a/http2/adapter/header_validator.cc
+++ b/http2/adapter/header_validator.cc
@@ -20,9 +20,14 @@
const absl::string_view kHttp2StatusValueAllowedChars = "0123456789";
-// TODO(birenroy): Support websocket requests, which contain an extra
-// `:protocol` pseudo-header.
-bool ValidateRequestHeaders(const std::vector<std::string>& pseudo_headers) {
+bool ValidateRequestHeaders(const std::vector<std::string>& pseudo_headers,
+ absl::string_view method, bool allow_connect) {
+ if (allow_connect && method == "CONNECT") {
+ static const std::vector<std::string>* kConnectHeaders =
+ new std::vector<std::string>(
+ {":authority", ":method", ":path", ":protocol", ":scheme"});
+ return pseudo_headers == *kConnectHeaders;
+ }
static const std::vector<std::string>* kRequiredHeaders =
new std::vector<std::string>(
{":authority", ":method", ":path", ":scheme"});
@@ -48,6 +53,7 @@
void HeaderValidator::StartHeaderBlock() {
pseudo_headers_.clear();
status_.clear();
+ method_.clear();
}
HeaderValidator::HeaderStatus HeaderValidator::ValidateSingleHeader(
@@ -82,6 +88,8 @@
return HEADER_VALUE_INVALID_STATUS;
}
status_ = std::string(value);
+ } else if (key == ":method") {
+ method_ = std::string(value);
}
pseudo_headers_.push_back(std::string(key));
}
@@ -94,7 +102,7 @@
std::sort(pseudo_headers_.begin(), pseudo_headers_.end());
switch (type) {
case HeaderType::REQUEST:
- return ValidateRequestHeaders(pseudo_headers_);
+ return ValidateRequestHeaders(pseudo_headers_, method_, allow_connect_);
case HeaderType::REQUEST_TRAILER:
return ValidateRequestTrailers(pseudo_headers_);
case HeaderType::RESPONSE_100:
diff --git a/http2/adapter/header_validator.h b/http2/adapter/header_validator.h
index a27fdde..ecde204 100644
--- a/http2/adapter/header_validator.h
+++ b/http2/adapter/header_validator.h
@@ -22,6 +22,10 @@
public:
HeaderValidator() {}
+ // If called, this validator will allow the `:protocol` pseudo-header, as
+ // described in RFC 8441.
+ void AllowConnect() { allow_connect_ = true; }
+
void StartHeaderBlock();
enum HeaderStatus {
@@ -38,11 +42,14 @@
// present for the given header type.
bool FinishHeaderBlock(HeaderType type);
+ // For responses, returns the value of the ":status" header, if present.
absl::string_view status_header() const { return status_; }
private:
std::vector<std::string> pseudo_headers_;
std::string status_;
+ std::string method_;
+ bool allow_connect_ = false;
};
} // namespace adapter
diff --git a/http2/adapter/header_validator_test.cc b/http2/adapter/header_validator_test.cc
index 01460c9..dcd92e1 100644
--- a/http2/adapter/header_validator_test.cc
+++ b/http2/adapter/header_validator_test.cc
@@ -154,8 +154,38 @@
}
EXPECT_EQ(HeaderValidator::HEADER_OK,
v.ValidateSingleHeader(":protocol", "websocket"));
- // For now, `:protocol` is treated as an extra pseudo-header.
+ // At this point, `:protocol` is treated as an extra pseudo-header.
EXPECT_FALSE(v.FinishHeaderBlock(HeaderType::REQUEST));
+
+ // Future header blocks may send the `:protocol` pseudo-header for CONNECT
+ // requests.
+ v.AllowConnect();
+
+ v.StartHeaderBlock();
+ for (absl::string_view to_add : headers) {
+ EXPECT_EQ(HeaderValidator::HEADER_OK,
+ v.ValidateSingleHeader(to_add, "foo"));
+ }
+ EXPECT_EQ(HeaderValidator::HEADER_OK,
+ v.ValidateSingleHeader(":protocol", "websocket"));
+ // The method is "foo", not "CONNECT", so `:protocol` is still treated as an
+ // extra pseudo-header.
+ EXPECT_FALSE(v.FinishHeaderBlock(HeaderType::REQUEST));
+
+ v.StartHeaderBlock();
+ for (absl::string_view to_add : headers) {
+ if (to_add == ":method") {
+ EXPECT_EQ(HeaderValidator::HEADER_OK,
+ v.ValidateSingleHeader(to_add, "CONNECT"));
+ } else {
+ EXPECT_EQ(HeaderValidator::HEADER_OK,
+ v.ValidateSingleHeader(to_add, "foo"));
+ }
+ }
+ EXPECT_EQ(HeaderValidator::HEADER_OK,
+ v.ValidateSingleHeader(":protocol", "websocket"));
+ // After allowing the method, `:protocol` is acepted for CONNECT requests.
+ EXPECT_TRUE(v.FinishHeaderBlock(HeaderType::REQUEST));
}
TEST(HeaderValidatorTest, ResponsePseudoHeaders) {
diff --git a/http2/adapter/nghttp2_adapter_test.cc b/http2/adapter/nghttp2_adapter_test.cc
index 8b7d694..147d991 100644
--- a/http2/adapter/nghttp2_adapter_test.cc
+++ b/http2/adapter/nghttp2_adapter_test.cc
@@ -3531,6 +3531,154 @@
EXPECT_THAT(visitor.data(), EqualsFrames({spdy::SpdyFrameType::SETTINGS}));
}
+TEST(NgHttp2AdapterTest, ServerForbidsProtocolPseudoheaderBeforeAck) {
+ DataSavingVisitor visitor;
+ auto adapter = NgHttp2Adapter::CreateServerAdapter(visitor);
+
+ const std::string initial_frames =
+ TestFrameSequence().ClientPreface().Serialize();
+
+ testing::InSequence s;
+
+ // Client preface (empty SETTINGS)
+ EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0));
+ EXPECT_CALL(visitor, OnSettingsStart());
+ EXPECT_CALL(visitor, OnSettingsEnd());
+
+ const int64_t initial_result = adapter->ProcessBytes(initial_frames);
+ EXPECT_EQ(static_cast<size_t>(initial_result), initial_frames.size());
+
+ // The client attempts to send a CONNECT request with the `:protocol`
+ // pseudoheader before receiving the server's SETTINGS frame.
+ const std::string stream1_frames =
+ TestFrameSequence()
+ .Headers(1,
+ {{":method", "CONNECT"},
+ {":scheme", "https"},
+ {":authority", "example.com"},
+ {":path", "/this/is/request/one"},
+ {":protocol", "websocket"}},
+ /*fin=*/true)
+ .Serialize();
+
+ EXPECT_CALL(visitor, OnFrameHeader(1, _, HEADERS, 0x5));
+ EXPECT_CALL(visitor, OnBeginHeadersForStream(1));
+ EXPECT_CALL(visitor, OnHeaderForStream(1, _, _)).Times(4);
+ EXPECT_CALL(
+ visitor,
+ OnErrorDebug("Invalid HTTP header field was received: frame type: 1, "
+ "stream: 1, name: [:protocol], value: [websocket]"));
+ EXPECT_CALL(
+ visitor,
+ OnInvalidFrame(1, Http2VisitorInterface::InvalidFrameError::kHttpHeader));
+
+ int64_t stream_result = adapter->ProcessBytes(stream1_frames);
+ EXPECT_EQ(static_cast<size_t>(stream_result), stream1_frames.size());
+
+ // Server sends a SETTINGS ack and initial SETTINGS (with
+ // ENABLE_CONNECT_PROTOCOL).
+ EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, 0, 0x1));
+ EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, 0, 0x1, 0));
+ EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, 6, 0x0));
+ EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, 6, 0x0, 0));
+
+ // The server sends a RST_STREAM for the offending stream.
+ EXPECT_CALL(visitor, OnBeforeFrameSent(RST_STREAM, 1, _, 0x0));
+ EXPECT_CALL(visitor,
+ OnFrameSent(RST_STREAM, 1, _, 0x0,
+ static_cast<int>(Http2ErrorCode::PROTOCOL_ERROR)));
+ EXPECT_CALL(visitor, OnCloseStream(1, Http2ErrorCode::PROTOCOL_ERROR));
+
+ adapter->SubmitSettings({{ENABLE_CONNECT_PROTOCOL, 1}});
+ int send_result = adapter->Send();
+ EXPECT_EQ(0, send_result);
+ EXPECT_THAT(visitor.data(), EqualsFrames({spdy::SpdyFrameType::SETTINGS,
+ spdy::SpdyFrameType::SETTINGS,
+ spdy::SpdyFrameType::RST_STREAM}));
+ visitor.Clear();
+
+ // The client attempts to send a CONNECT request with the `:protocol`
+ // pseudoheader before acking the server's SETTINGS frame.
+ const std::string stream3_frames =
+ TestFrameSequence()
+ .Headers(3,
+ {{":method", "CONNECT"},
+ {":scheme", "https"},
+ {":authority", "example.com"},
+ {":path", "/this/is/request/two"},
+ {":protocol", "websocket"}},
+ /*fin=*/true)
+ .Serialize();
+
+ // Surprisingly, nghttp2 is okay with this.
+ EXPECT_CALL(visitor, OnFrameHeader(3, _, HEADERS, 0x5));
+ EXPECT_CALL(visitor, OnBeginHeadersForStream(3));
+ EXPECT_CALL(visitor, OnHeaderForStream(3, _, _)).Times(5);
+ EXPECT_CALL(visitor, OnEndHeadersForStream(3));
+ EXPECT_CALL(visitor, OnEndStream(3));
+
+ stream_result = adapter->ProcessBytes(stream3_frames);
+ EXPECT_EQ(static_cast<size_t>(stream_result), stream3_frames.size());
+
+ EXPECT_FALSE(adapter->want_write());
+}
+
+TEST(NgHttp2AdapterTest, ServerAllowsProtocolPseudoheaderAfterAck) {
+ DataSavingVisitor visitor;
+ auto adapter = NgHttp2Adapter::CreateServerAdapter(visitor);
+ adapter->SubmitSettings({{ENABLE_CONNECT_PROTOCOL, 1}});
+
+ const std::string initial_frames =
+ TestFrameSequence().ClientPreface().Serialize();
+
+ testing::InSequence s;
+
+ // Client preface (empty SETTINGS)
+ EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0));
+ EXPECT_CALL(visitor, OnSettingsStart());
+ EXPECT_CALL(visitor, OnSettingsEnd());
+
+ const int64_t initial_result = adapter->ProcessBytes(initial_frames);
+ EXPECT_EQ(static_cast<size_t>(initial_result), initial_frames.size());
+
+ // Server initial SETTINGS (with ENABLE_CONNECT_PROTOCOL) and SETTINGS ack.
+ EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, 6, 0x0));
+ EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, 6, 0x0, 0));
+ EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, 0, 0x1));
+ EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, 0, 0x1, 0));
+
+ int send_result = adapter->Send();
+ EXPECT_EQ(0, send_result);
+ visitor.Clear();
+
+ // The client attempts to send a CONNECT request with the `:protocol`
+ // pseudoheader after acking the server's SETTINGS frame.
+ const std::string stream_frames =
+ TestFrameSequence()
+ .SettingsAck()
+ .Headers(1,
+ {{":method", "CONNECT"},
+ {":scheme", "https"},
+ {":authority", "example.com"},
+ {":path", "/this/is/request/one"},
+ {":protocol", "websocket"}},
+ /*fin=*/true)
+ .Serialize();
+
+ EXPECT_CALL(visitor, OnFrameHeader(0, _, SETTINGS, 0x1));
+ EXPECT_CALL(visitor, OnSettingsAck());
+ EXPECT_CALL(visitor, OnFrameHeader(1, _, HEADERS, 0x5));
+ EXPECT_CALL(visitor, OnBeginHeadersForStream(1));
+ EXPECT_CALL(visitor, OnHeaderForStream(1, _, _)).Times(5);
+ EXPECT_CALL(visitor, OnEndHeadersForStream(1));
+ EXPECT_CALL(visitor, OnEndStream(1));
+
+ const int64_t stream_result = adapter->ProcessBytes(stream_frames);
+ EXPECT_EQ(static_cast<size_t>(stream_result), stream_frames.size());
+
+ EXPECT_FALSE(adapter->want_write());
+}
+
} // namespace
} // namespace test
} // namespace adapter
diff --git a/http2/adapter/oghttp2_adapter_test.cc b/http2/adapter/oghttp2_adapter_test.cc
index 3b22a92..b274f4a 100644
--- a/http2/adapter/oghttp2_adapter_test.cc
+++ b/http2/adapter/oghttp2_adapter_test.cc
@@ -3563,6 +3563,168 @@
EXPECT_THAT(visitor.data(), EqualsFrames({spdy::SpdyFrameType::RST_STREAM}));
}
+TEST(OgHttp2AdapterServerTest, ServerForbidsProtocolPseudoheaderBeforeAck) {
+ DataSavingVisitor visitor;
+ OgHttp2Adapter::Options options{.perspective = Perspective::kServer};
+ auto adapter = OgHttp2Adapter::Create(visitor, options);
+
+ const std::string initial_frames =
+ TestFrameSequence().ClientPreface().Serialize();
+
+ testing::InSequence s;
+
+ // Client preface (empty SETTINGS)
+ EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0));
+ EXPECT_CALL(visitor, OnSettingsStart());
+ EXPECT_CALL(visitor, OnSettingsEnd());
+
+ const int64_t initial_result = adapter->ProcessBytes(initial_frames);
+ EXPECT_EQ(static_cast<size_t>(initial_result), initial_frames.size());
+
+ // The client attempts to send a CONNECT request with the `:protocol`
+ // pseudoheader before receiving the server's SETTINGS frame.
+ const std::string stream1_frames =
+ TestFrameSequence()
+ .Headers(1,
+ {{":method", "CONNECT"},
+ {":scheme", "https"},
+ {":authority", "example.com"},
+ {":path", "/this/is/request/one"},
+ {":protocol", "websocket"}},
+ /*fin=*/true)
+ .Serialize();
+
+ EXPECT_CALL(visitor, OnFrameHeader(1, _, HEADERS, 0x5));
+ EXPECT_CALL(visitor, OnBeginHeadersForStream(1));
+ EXPECT_CALL(visitor, OnHeaderForStream(1, _, _)).Times(5);
+ EXPECT_CALL(visitor,
+ OnInvalidFrame(
+ 1, Http2VisitorInterface::InvalidFrameError::kHttpMessaging));
+ EXPECT_CALL(visitor, OnEndStream(1));
+
+ int64_t stream_result = adapter->ProcessBytes(stream1_frames);
+ EXPECT_EQ(static_cast<size_t>(stream_result), stream1_frames.size());
+
+ // Server initial SETTINGS and SETTINGS ack.
+ EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, 0, 0x0));
+ EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, 0, 0x0, 0));
+ EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, 0, 0x1));
+ EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, 0, 0x1, 0));
+
+ // The server sends a RST_STREAM for the offending stream.
+ EXPECT_CALL(visitor, OnBeforeFrameSent(RST_STREAM, 1, _, 0x0));
+ EXPECT_CALL(visitor,
+ OnFrameSent(RST_STREAM, 1, _, 0x0,
+ static_cast<int>(Http2ErrorCode::PROTOCOL_ERROR)));
+ EXPECT_CALL(visitor, OnCloseStream(1, Http2ErrorCode::HTTP2_NO_ERROR));
+
+ // Server settings with ENABLE_CONNECT_PROTOCOL.
+ EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, 6, 0x0));
+ EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, 6, 0x0, 0));
+
+ adapter->SubmitSettings({{ENABLE_CONNECT_PROTOCOL, 1}});
+ int send_result = adapter->Send();
+ EXPECT_EQ(0, send_result);
+ EXPECT_THAT(visitor.data(), EqualsFrames({spdy::SpdyFrameType::SETTINGS,
+ spdy::SpdyFrameType::SETTINGS,
+ spdy::SpdyFrameType::RST_STREAM,
+ spdy::SpdyFrameType::SETTINGS}));
+ visitor.Clear();
+
+ // The client attempts to send a CONNECT request with the `:protocol`
+ // pseudoheader before acking the server's SETTINGS frame.
+ const std::string stream3_frames =
+ TestFrameSequence()
+ .Headers(3,
+ {{":method", "CONNECT"},
+ {":scheme", "https"},
+ {":authority", "example.com"},
+ {":path", "/this/is/request/two"},
+ {":protocol", "websocket"}},
+ /*fin=*/true)
+ .Serialize();
+
+ EXPECT_CALL(visitor, OnFrameHeader(3, _, HEADERS, 0x5));
+ EXPECT_CALL(visitor, OnBeginHeadersForStream(3));
+ EXPECT_CALL(visitor, OnHeaderForStream(3, _, _)).Times(5);
+ EXPECT_CALL(visitor,
+ OnInvalidFrame(
+ 3, Http2VisitorInterface::InvalidFrameError::kHttpMessaging));
+ EXPECT_CALL(visitor, OnEndStream(3));
+
+ stream_result = adapter->ProcessBytes(stream3_frames);
+ EXPECT_EQ(static_cast<size_t>(stream_result), stream3_frames.size());
+
+ // The server sends a RST_STREAM for the offending stream.
+ EXPECT_TRUE(adapter->want_write());
+ EXPECT_CALL(visitor, OnBeforeFrameSent(RST_STREAM, 3, _, 0x0));
+ EXPECT_CALL(visitor,
+ OnFrameSent(RST_STREAM, 3, _, 0x0,
+ static_cast<int>(Http2ErrorCode::PROTOCOL_ERROR)));
+ EXPECT_CALL(visitor, OnCloseStream(3, Http2ErrorCode::HTTP2_NO_ERROR));
+
+ send_result = adapter->Send();
+ EXPECT_EQ(0, send_result);
+ EXPECT_THAT(visitor.data(), EqualsFrames({spdy::SpdyFrameType::RST_STREAM}));
+}
+
+TEST(OgHttp2AdapterServerTest, ServerAllowsProtocolPseudoheaderAfterAck) {
+ DataSavingVisitor visitor;
+ OgHttp2Adapter::Options options{.perspective = Perspective::kServer};
+ auto adapter = OgHttp2Adapter::Create(visitor, options);
+ adapter->SubmitSettings({{ENABLE_CONNECT_PROTOCOL, 1}});
+
+ const std::string initial_frames =
+ TestFrameSequence().ClientPreface().Serialize();
+
+ testing::InSequence s;
+
+ // Client preface (empty SETTINGS)
+ EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0));
+ EXPECT_CALL(visitor, OnSettingsStart());
+ EXPECT_CALL(visitor, OnSettingsEnd());
+
+ const int64_t initial_result = adapter->ProcessBytes(initial_frames);
+ EXPECT_EQ(static_cast<size_t>(initial_result), initial_frames.size());
+
+ // Server initial SETTINGS (with ENABLE_CONNECT_PROTOCOL) and SETTINGS ack.
+ EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, 6, 0x0));
+ EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, 6, 0x0, 0));
+ EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, 0, 0x1));
+ EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, 0, 0x1, 0));
+
+ int send_result = adapter->Send();
+ EXPECT_EQ(0, send_result);
+ visitor.Clear();
+
+ // The client attempts to send a CONNECT request with the `:protocol`
+ // pseudoheader after acking the server's SETTINGS frame.
+ const std::string stream_frames =
+ TestFrameSequence()
+ .SettingsAck()
+ .Headers(1,
+ {{":method", "CONNECT"},
+ {":scheme", "https"},
+ {":authority", "example.com"},
+ {":path", "/this/is/request/one"},
+ {":protocol", "websocket"}},
+ /*fin=*/true)
+ .Serialize();
+
+ EXPECT_CALL(visitor, OnFrameHeader(0, _, SETTINGS, 0x1));
+ EXPECT_CALL(visitor, OnSettingsAck());
+ EXPECT_CALL(visitor, OnFrameHeader(1, _, HEADERS, 0x5));
+ EXPECT_CALL(visitor, OnBeginHeadersForStream(1));
+ EXPECT_CALL(visitor, OnHeaderForStream(1, _, _)).Times(5);
+ EXPECT_CALL(visitor, OnEndHeadersForStream(1));
+ EXPECT_CALL(visitor, OnEndStream(1));
+
+ const int64_t stream_result = adapter->ProcessBytes(stream_frames);
+ EXPECT_EQ(static_cast<size_t>(stream_result), stream_frames.size());
+
+ EXPECT_FALSE(adapter->want_write());
+}
+
} // namespace
} // namespace test
} // namespace adapter
diff --git a/http2/adapter/oghttp2_session.cc b/http2/adapter/oghttp2_session.cc
index aaf97d7..1f6d2fd 100644
--- a/http2/adapter/oghttp2_session.cc
+++ b/http2/adapter/oghttp2_session.cc
@@ -1274,6 +1274,9 @@
} else if (id_and_value.first == spdy::SETTINGS_HEADER_TABLE_SIZE) {
decoder_.GetHpackDecoder()->ApplyHeaderTableSizeSetting(
id_and_value.second);
+ } else if (id_and_value.first == ENABLE_CONNECT_PROTOCOL &&
+ id_and_value.second == 1u) {
+ headers_handler_.AllowConnect();
}
}
});
diff --git a/http2/adapter/oghttp2_session.h b/http2/adapter/oghttp2_session.h
index e4a2374..74dedac 100644
--- a/http2/adapter/oghttp2_session.h
+++ b/http2/adapter/oghttp2_session.h
@@ -250,6 +250,7 @@
type_ == HeaderType::RESPONSE_100);
return validator_.status_header();
}
+ void AllowConnect() { validator_.AllowConnect(); }
private:
OgHttp2Session& session_;