Add functions to handle WebTransport-Init headers. PiperOrigin-RevId: 583035611
diff --git a/quiche/web_transport/web_transport_headers.cc b/quiche/web_transport/web_transport_headers.cc index 7112f25..9d4b4f8 100644 --- a/quiche/web_transport/web_transport_headers.cc +++ b/quiche/web_transport/web_transport_headers.cc
@@ -4,26 +4,57 @@ #include "quiche/web_transport/web_transport_headers.h" +#include <array> +#include <cstdint> #include <optional> #include <string> #include <utility> #include <vector> +#include "absl/base/attributes.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" +#include "quiche/common/quiche_status_utils.h" #include "quiche/common/structured_headers.h" namespace webtransport { +namespace { +using ::quiche::structured_headers::Dictionary; +using ::quiche::structured_headers::DictionaryMember; +using ::quiche::structured_headers::Item; using ::quiche::structured_headers::ItemTypeToString; using ::quiche::structured_headers::List; using ::quiche::structured_headers::ParameterizedItem; using ::quiche::structured_headers::ParameterizedMember; +absl::Status CheckItemType(const ParameterizedMember& member, + Item::ItemType expected_type) { + if (member.member_is_inner_list || member.member.size() != 1) { + return absl::InvalidArgumentError(absl::StrCat( + "Expected all members to be of type", ItemTypeToString(expected_type), + ", found a nested list instead")); + } + if (member.member[0].item.Type() != expected_type) { + return absl::InvalidArgumentError(absl::StrCat( + "Expected all members to be of type ", ItemTypeToString(expected_type), + ", found ", ItemTypeToString(member.member[0].item.Type()), + " instead")); + } + return absl::OkStatus(); +} + +ABSL_CONST_INIT std::array kInitHeaderFields{ + std::make_pair("u", &WebTransportInitHeader::initial_unidi_limit), + std::make_pair("bl", &WebTransportInitHeader::initial_incoming_bidi_limit), + std::make_pair("br", &WebTransportInitHeader::initial_outgoing_bidi_limit), +}; +} // namespace + absl::StatusOr<std::vector<std::string>> ParseSubprotocolRequestHeader( absl::string_view value) { std::optional<List> parsed = quiche::structured_headers::ParseList(value); @@ -35,17 +66,8 @@ std::vector<std::string> result; result.reserve(parsed->size()); for (ParameterizedMember& member : *parsed) { - if (member.member_is_inner_list || member.member.size() != 1) { - return absl::InvalidArgumentError( - "Expected all members to be tokens, found a nested list instead"); - } - ParameterizedItem& item = member.member[0]; - if (!item.item.is_token()) { - return absl::InvalidArgumentError( - absl::StrCat("Expected all members to be tokens, found ", - ItemTypeToString(item.item.Type()), " instead")); - } - result.push_back(std::move(item).item.TakeString()); + QUICHE_RETURN_IF_ERROR(CheckItemType(member, Item::kTokenType)); + result.push_back(std::move(member.member[0].item).TakeString()); } return result; } @@ -62,4 +84,49 @@ return absl::StrJoin(subprotocols, ", "); } +absl::StatusOr<WebTransportInitHeader> ParseInitHeader( + absl::string_view header) { + std::optional<Dictionary> parsed = + quiche::structured_headers::ParseDictionary(header); + if (!parsed.has_value()) { + return absl::InvalidArgumentError( + "Failed to parse WebTransport-Init header as an sf-dictionary"); + } + WebTransportInitHeader output; + for (const auto& [field_name_a, field_value] : *parsed) { + for (const auto& [field_name_b, field_accessor] : kInitHeaderFields) { + if (field_name_a != field_name_b) { + continue; + } + QUICHE_RETURN_IF_ERROR(CheckItemType(field_value, Item::kIntegerType)); + int64_t value = field_value.member[0].item.GetInteger(); + if (value < 0) { + return absl::InvalidArgumentError( + absl::StrCat("Received negative value for ", field_name_a)); + } + output.*field_accessor = value; + } + } + return output; +} + +absl::StatusOr<std::string> SerializeInitHeader( + const WebTransportInitHeader& header) { + std::vector<DictionaryMember> members; + members.reserve(kInitHeaderFields.size()); + for (const auto& [field_name, field_accessor] : kInitHeaderFields) { + Item item(static_cast<int64_t>(header.*field_accessor)); + members.push_back(std::make_pair( + field_name, ParameterizedMember({ParameterizedItem(item, {})}, false, + /*parameters=*/{}))); + } + std::optional<std::string> result = + quiche::structured_headers::SerializeDictionary( + Dictionary(std::move(members))); + if (!result.has_value()) { + return absl::InternalError("Failed to serialize the dictionary"); + } + return *std::move(result); +} + } // namespace webtransport
diff --git a/quiche/web_transport/web_transport_headers.h b/quiche/web_transport/web_transport_headers.h index b3b16b2..10f6d2f 100644 --- a/quiche/web_transport/web_transport_headers.h +++ b/quiche/web_transport/web_transport_headers.h
@@ -5,6 +5,7 @@ #ifndef QUICHE_WEB_TRANSPORT_WEB_TRANSPORT_HEADERS_H_ #define QUICHE_WEB_TRANSPORT_WEB_TRANSPORT_HEADERS_H_ +#include <cstdint> #include <string> #include <vector> @@ -25,6 +26,35 @@ QUICHE_EXPORT absl::StatusOr<std::string> SerializeSubprotocolRequestHeader( absl::Span<const std::string> subprotocols); +inline constexpr absl::string_view kInitHeader = "WebTransport-Init"; + +// A deserialized representation of WebTransport-Init header that is used to +// indicate the initial stream flow control windows in WebTransport over HTTP/2. +// Specification: +// https://www.ietf.org/archive/id/draft-ietf-webtrans-http2-07.html#name-flow-control-header-field +struct QUICHE_EXPORT WebTransportInitHeader { + // Initial flow control window for unidirectional streams opened by the + // header's recipient. + uint64_t initial_unidi_limit = 0; + // Initial flow control window for bidirectional streams opened by the + // header's recipient. + uint64_t initial_incoming_bidi_limit = 0; + // Initial flow control window for bidirectional streams opened by the + // header's sender. + uint64_t initial_outgoing_bidi_limit = 0; + + bool operator==(const WebTransportInitHeader& other) const { + return initial_unidi_limit == other.initial_unidi_limit && + initial_incoming_bidi_limit == other.initial_incoming_bidi_limit && + initial_outgoing_bidi_limit == other.initial_outgoing_bidi_limit; + } +}; + +QUICHE_EXPORT absl::StatusOr<WebTransportInitHeader> ParseInitHeader( + absl::string_view header); +QUICHE_EXPORT absl::StatusOr<std::string> SerializeInitHeader( + const WebTransportInitHeader& header); + } // namespace webtransport #endif // QUICHE_WEB_TRANSPORT_WEB_TRANSPORT_HEADERS_H_
diff --git a/quiche/web_transport/web_transport_headers_test.cc b/quiche/web_transport/web_transport_headers_test.cc index 38cc2c4..056303a 100644 --- a/quiche/web_transport/web_transport_headers_test.cc +++ b/quiche/web_transport/web_transport_headers_test.cc
@@ -57,5 +57,48 @@ StatusIs(absl::StatusCode::kInvalidArgument, "Invalid token: 0123")); } +TEST(WebTransportHeader, ParseInitHeader) { + WebTransportInitHeader expected_header; + expected_header.initial_unidi_limit = 100; + expected_header.initial_incoming_bidi_limit = 200; + expected_header.initial_outgoing_bidi_limit = 400; + EXPECT_THAT(ParseInitHeader("br=400, bl=200, u=100"), + IsOkAndHolds(expected_header)); + EXPECT_THAT(ParseInitHeader("br=300, bl=200, u=100, br=400"), + IsOkAndHolds(expected_header)); + EXPECT_THAT(ParseInitHeader("br=400, bl=200; foo=bar, u=100"), + IsOkAndHolds(expected_header)); + EXPECT_THAT(ParseInitHeader("br=400, bl=200, u=100.0"), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("found decimal instead"))); + EXPECT_THAT(ParseInitHeader("br=400, bl=200, u=?1"), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("found boolean instead"))); + EXPECT_THAT(ParseInitHeader("br=400, bl=200, u=(a b)"), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("found a nested list instead"))); + EXPECT_THAT(ParseInitHeader("br=400, bl=200, u=:abcd:"), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("found byte sequence instead"))); + EXPECT_THAT(ParseInitHeader("br=400, bl=200, u=-1"), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("negative value"))); + EXPECT_THAT(ParseInitHeader("br=400, bl=200, u=18446744073709551615"), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("Failed to parse"))); +} + +TEST(WebTransportHeaders, SerializeInitHeader) { + EXPECT_THAT(SerializeInitHeader(WebTransportInitHeader{}), + IsOkAndHolds("u=0, bl=0, br=0")); + + WebTransportInitHeader test_header; + test_header.initial_unidi_limit = 100; + test_header.initial_incoming_bidi_limit = 200; + test_header.initial_outgoing_bidi_limit = 400; + EXPECT_THAT(SerializeInitHeader(test_header), + IsOkAndHolds("u=100, bl=200, br=400")); +} + } // namespace } // namespace webtransport