Adds a return value to Http2VisitorInterface::OnCloseStream().
It turns out that Envoy can return an error in the `on_stream_close` user callback.
PiperOrigin-RevId: 439596004
diff --git a/http2/adapter/callback_visitor.cc b/http2/adapter/callback_visitor.cc
index 2a00eed..4b88a45 100644
--- a/http2/adapter/callback_visitor.cc
+++ b/http2/adapter/callback_visitor.cc
@@ -254,18 +254,20 @@
}
}
-void CallbackVisitor::OnCloseStream(Http2StreamId stream_id,
+bool CallbackVisitor::OnCloseStream(Http2StreamId stream_id,
Http2ErrorCode error_code) {
+ int result = 0;
if (callbacks_->on_stream_close_callback) {
QUICHE_VLOG(1) << "OnCloseStream(stream_id: " << stream_id
<< ", error_code: " << int(error_code) << ")";
- callbacks_->on_stream_close_callback(
+ result = callbacks_->on_stream_close_callback(
nullptr, stream_id, static_cast<uint32_t>(error_code), user_data_);
}
stream_map_.erase(stream_id);
if (stream_close_listener_) {
stream_close_listener_(stream_id);
}
+ return result == 0;
}
void CallbackVisitor::OnPriorityForStream(Http2StreamId /*stream_id*/,
diff --git a/http2/adapter/callback_visitor.h b/http2/adapter/callback_visitor.h
index a64c8e8..d68211d 100644
--- a/http2/adapter/callback_visitor.h
+++ b/http2/adapter/callback_visitor.h
@@ -47,7 +47,7 @@
absl::string_view data) override;
void OnEndStream(Http2StreamId stream_id) override;
void OnRstStream(Http2StreamId stream_id, Http2ErrorCode error_code) override;
- void OnCloseStream(Http2StreamId stream_id,
+ bool OnCloseStream(Http2StreamId stream_id,
Http2ErrorCode error_code) override;
void OnPriorityForStream(Http2StreamId stream_id,
Http2StreamId parent_stream_id,
diff --git a/http2/adapter/callback_visitor_test.cc b/http2/adapter/callback_visitor_test.cc
index fee6fc4..c6c19f8 100644
--- a/http2/adapter/callback_visitor_test.cc
+++ b/http2/adapter/callback_visitor_test.cc
@@ -225,6 +225,36 @@
visitor.OnEndHeadersForStream(1);
}
+TEST(ClientCallbackVisitorUnitTest, ResetAndGoaway) {
+ testing::StrictMock<MockNghttp2Callbacks> callbacks;
+ CallbackVisitor visitor(Perspective::kClient,
+ *MockNghttp2Callbacks::GetCallbacks(), &callbacks);
+
+ testing::InSequence seq;
+
+ // RST_STREAM on stream 1
+ EXPECT_CALL(callbacks, OnBeginFrame(HasFrameHeader(1, RST_STREAM, 0x0)));
+ EXPECT_TRUE(visitor.OnFrameHeader(1, 13, RST_STREAM, 0x0));
+
+ EXPECT_CALL(callbacks, OnFrameRecv(IsRstStream(1, NGHTTP2_INTERNAL_ERROR)));
+ visitor.OnRstStream(1, Http2ErrorCode::INTERNAL_ERROR);
+
+ EXPECT_CALL(callbacks, OnStreamClose(1, NGHTTP2_INTERNAL_ERROR));
+ EXPECT_TRUE(visitor.OnCloseStream(1, Http2ErrorCode::INTERNAL_ERROR));
+
+ EXPECT_CALL(callbacks, OnBeginFrame(HasFrameHeader(0, GOAWAY, 0x0)));
+ EXPECT_TRUE(visitor.OnFrameHeader(0, 13, GOAWAY, 0x0));
+
+ EXPECT_CALL(callbacks,
+ OnFrameRecv(IsGoAway(3, NGHTTP2_ENHANCE_YOUR_CALM, "calma te")));
+ EXPECT_TRUE(
+ visitor.OnGoAway(3, Http2ErrorCode::ENHANCE_YOUR_CALM, "calma te"));
+
+ EXPECT_CALL(callbacks, OnStreamClose(5, NGHTTP2_STREAM_CLOSED))
+ .WillOnce(testing::Return(NGHTTP2_ERR_CALLBACK_FAILURE));
+ EXPECT_FALSE(visitor.OnCloseStream(5, Http2ErrorCode::STREAM_CLOSED));
+}
+
TEST(ServerCallbackVisitorUnitTest, ConnectionFrames) {
testing::StrictMock<MockNghttp2Callbacks> callbacks;
CallbackVisitor visitor(Perspective::kServer,
diff --git a/http2/adapter/http2_visitor_interface.h b/http2/adapter/http2_visitor_interface.h
index db292d0..9136db2 100644
--- a/http2/adapter/http2_visitor_interface.h
+++ b/http2/adapter/http2_visitor_interface.h
@@ -83,7 +83,8 @@
};
virtual void OnConnectionError(ConnectionError error) = 0;
- // Called when the header for a frame is received.
+ // Called when the header for a frame is received. Returns false if a fatal
+ // error has occurred.
virtual bool OnFrameHeader(Http2StreamId /*stream_id*/, size_t /*length*/,
uint8_t /*type*/, uint8_t /*flags*/) {
return true;
@@ -138,22 +139,24 @@
// Called when the connection has received the complete header block for a
// logical HEADERS frame on a stream (which may contain CONTINUATION frames,
- // transparent to the user).
+ // transparent to the user). Returns false if a fatal error has occurred.
virtual bool OnEndHeadersForStream(Http2StreamId stream_id) = 0;
// Called when the connection receives the beginning of a DATA frame. The data
- // payload will be provided via subsequent calls to OnDataForStream().
+ // payload will be provided via subsequent calls to OnDataForStream(). Returns
+ // false if a fatal error has occurred.
virtual bool OnBeginDataForStream(Http2StreamId stream_id,
size_t payload_length) = 0;
// Called when the optional padding length field is parsed as part of a DATA
// frame payload. `padding_length` represents the total amount of padding for
- // this frame, including the length byte itself.
+ // this frame, including the length byte itself. Returns false if a fatal
+ // error has occurred.
virtual bool OnDataPaddingLength(Http2StreamId stream_id,
size_t padding_length) = 0;
// Called when the connection receives some |data| (as part of a DATA frame
- // payload) for a stream.
+ // payload) for a stream. Returns false if a fatal error has occurred.
virtual bool OnDataForStream(Http2StreamId stream_id,
absl::string_view data) = 0;
@@ -166,8 +169,9 @@
virtual void OnRstStream(Http2StreamId stream_id,
Http2ErrorCode error_code) = 0;
- // Called when a stream is closed.
- virtual void OnCloseStream(Http2StreamId stream_id,
+ // Called when a stream is closed. Returns false if a fatal error has
+ // occurred.
+ virtual bool OnCloseStream(Http2StreamId stream_id,
Http2ErrorCode error_code) = 0;
// Called when the connection receives a PRIORITY frame.
@@ -183,7 +187,8 @@
virtual void OnPushPromiseForStream(Http2StreamId stream_id,
Http2StreamId promised_stream_id) = 0;
- // Called when the connection receives a GOAWAY frame.
+ // Called when the connection receives a GOAWAY frame. Returns false if a
+ // fatal error has occurred.
virtual bool OnGoAway(Http2StreamId last_accepted_stream_id,
Http2ErrorCode error_code,
absl::string_view opaque_data) = 0;
diff --git a/http2/adapter/mock_http2_visitor.h b/http2/adapter/mock_http2_visitor.h
index 96f0978..098ca84 100644
--- a/http2/adapter/mock_http2_visitor.h
+++ b/http2/adapter/mock_http2_visitor.h
@@ -23,6 +23,7 @@
ON_CALL(*this, OnDataPaddingLength).WillByDefault(testing::Return(true));
ON_CALL(*this, OnBeginDataForStream).WillByDefault(testing::Return(true));
ON_CALL(*this, OnDataForStream).WillByDefault(testing::Return(true));
+ ON_CALL(*this, OnCloseStream).WillByDefault(testing::Return(true));
ON_CALL(*this, OnGoAway).WillByDefault(testing::Return(true));
ON_CALL(*this, OnInvalidFrame).WillByDefault(testing::Return(true));
ON_CALL(*this, OnMetadataForStream).WillByDefault(testing::Return(true));
@@ -67,10 +68,8 @@
(Http2StreamId stream_id, Http2ErrorCode error_code),
(override));
- MOCK_METHOD(void,
- OnCloseStream,
- (Http2StreamId stream_id, Http2ErrorCode error_code),
- (override));
+ MOCK_METHOD(bool, OnCloseStream,
+ (Http2StreamId stream_id, Http2ErrorCode error_code), (override));
MOCK_METHOD(void,
OnPriorityForStream,
diff --git a/http2/adapter/nghttp2_adapter_test.cc b/http2/adapter/nghttp2_adapter_test.cc
index 7a2bb11..47d085c 100644
--- a/http2/adapter/nghttp2_adapter_test.cc
+++ b/http2/adapter/nghttp2_adapter_test.cc
@@ -211,6 +211,7 @@
.WillOnce(
[&adapter](Http2StreamId stream_id, Http2ErrorCode /*error_code*/) {
adapter->RemoveStream(stream_id);
+ return true;
});
EXPECT_CALL(visitor, OnFrameHeader(0, 19, GOAWAY, 0));
EXPECT_CALL(visitor,
@@ -254,6 +255,7 @@
.WillOnce(
[&adapter](Http2StreamId stream_id, Http2ErrorCode /*error_code*/) {
adapter->RemoveStream(stream_id);
+ return true;
});
EXPECT_CALL(visitor, OnFrameHeader(5, 4, RST_STREAM, 0));
EXPECT_CALL(visitor, OnRstStream(5, Http2ErrorCode::REFUSED_STREAM));
@@ -261,6 +263,7 @@
.WillOnce(
[&adapter](Http2StreamId stream_id, Http2ErrorCode /*error_code*/) {
adapter->RemoveStream(stream_id);
+ return true;
});
adapter->ProcessBytes(TestFrameSequence()
.Data(1, "", true)
diff --git a/http2/adapter/nghttp2_callbacks.cc b/http2/adapter/nghttp2_callbacks.cc
index fee8ffb..b8489b7 100644
--- a/http2/adapter/nghttp2_callbacks.cc
+++ b/http2/adapter/nghttp2_callbacks.cc
@@ -252,8 +252,9 @@
void* user_data) {
QUICHE_CHECK_NE(user_data, nullptr);
auto* visitor = static_cast<Http2VisitorInterface*>(user_data);
- visitor->OnCloseStream(stream_id, ToHttp2ErrorCode(error_code));
- return 0;
+ const bool result =
+ visitor->OnCloseStream(stream_id, ToHttp2ErrorCode(error_code));
+ return result ? 0 : NGHTTP2_ERR_CALLBACK_FAILURE;
}
int OnExtensionChunkReceived(nghttp2_session* /*session*/,
diff --git a/http2/adapter/oghttp2_session.cc b/http2/adapter/oghttp2_session.cc
index d8431dd..ed4e7fe 100644
--- a/http2/adapter/oghttp2_session.cc
+++ b/http2/adapter/oghttp2_session.cc
@@ -1699,7 +1699,11 @@
void OgHttp2Session::CloseStream(Http2StreamId stream_id,
Http2ErrorCode error_code) {
- visitor_.OnCloseStream(stream_id, error_code);
+ const bool result = visitor_.OnCloseStream(stream_id, error_code);
+ if (!result) {
+ latched_error_ = true;
+ decoder_.StopProcessing();
+ }
stream_map_.erase(stream_id);
trailers_ready_.erase(stream_id);
metadata_ready_.erase(stream_id);
diff --git a/http2/adapter/recording_http2_visitor.cc b/http2/adapter/recording_http2_visitor.cc
index b9cbcfb..f746259 100644
--- a/http2/adapter/recording_http2_visitor.cc
+++ b/http2/adapter/recording_http2_visitor.cc
@@ -90,10 +90,11 @@
Http2ErrorCodeToString(error_code)));
}
-void RecordingHttp2Visitor::OnCloseStream(Http2StreamId stream_id,
+bool RecordingHttp2Visitor::OnCloseStream(Http2StreamId stream_id,
Http2ErrorCode error_code) {
events_.push_back(absl::StrFormat("OnCloseStream %d %s", stream_id,
Http2ErrorCodeToString(error_code)));
+ return true;
}
void RecordingHttp2Visitor::OnPriorityForStream(Http2StreamId stream_id,
diff --git a/http2/adapter/recording_http2_visitor.h b/http2/adapter/recording_http2_visitor.h
index 3ded946..144b5fb 100644
--- a/http2/adapter/recording_http2_visitor.h
+++ b/http2/adapter/recording_http2_visitor.h
@@ -41,7 +41,7 @@
absl::string_view data) override;
void OnEndStream(Http2StreamId stream_id) override;
void OnRstStream(Http2StreamId stream_id, Http2ErrorCode error_code) override;
- void OnCloseStream(Http2StreamId stream_id,
+ bool OnCloseStream(Http2StreamId stream_id,
Http2ErrorCode error_code) override;
void OnPriorityForStream(Http2StreamId stream_id,
Http2StreamId parent_stream_id,