Introduce IsConnectionIdLengthValidForVersion This CL is part of a larger change to allow the new connection ID invariants. It adds a new method QuicUtils::IsConnectionIdLengthValidForVersion() whose goal will be to replace uses of kQuicMaxConnectionIdLength in the codebase. This CL also plumbs the QUIC version to the TLS transport parameter parse/serialize code so it can call IsConnectionIdLengthValidForVersion. I suspect the transport parameter code will eventually need the version anyway as we create more QUIC versions that support TLS. gfe-relnote: refactor, protected by disabled quic_enable_v47 flag. PiperOrigin-RevId: 260938227 Change-Id: I590f7117de2b245044469e6dcdcca6f503c7a625
diff --git a/quic/core/crypto/transport_parameters.cc b/quic/core/crypto/transport_parameters.cc index 247f407..8f2349c 100644 --- a/quic/core/crypto/transport_parameters.cc +++ b/quic/core/crypto/transport_parameters.cc
@@ -13,6 +13,7 @@ #include "net/third_party/quiche/src/quic/core/quic_data_reader.h" #include "net/third_party/quiche/src/quic/core/quic_data_writer.h" #include "net/third_party/quiche/src/quic/core/quic_types.h" +#include "net/third_party/quiche/src/quic/core/quic_utils.h" #include "net/third_party/quiche/src/quic/core/quic_versions.h" #include "net/third_party/quiche/src/quic/platform/api/quic_bug_tracker.h" #include "net/third_party/quiche/src/quic/platform/api/quic_ptr_util.h" @@ -365,7 +366,8 @@ TransportParameters::~TransportParameters() = default; -bool SerializeTransportParameters(const TransportParameters& in, +bool SerializeTransportParameters(ParsedQuicVersion /*version*/, + const TransportParameters& in, std::vector<uint8_t>* out) { if (!in.AreValid()) { QUIC_DLOG(ERROR) << "Not serializing invalid transport parameters " << in; @@ -543,9 +545,10 @@ return true; } -bool ParseTransportParameters(const uint8_t* in, - size_t in_len, +bool ParseTransportParameters(ParsedQuicVersion version, Perspective perspective, + const uint8_t* in, + size_t in_len, TransportParameters* out) { out->perspective = perspective; CBS cbs; @@ -572,22 +575,25 @@ } bool parse_success = true; switch (param_id) { - case TransportParameters::kOriginalConnectionId: + case TransportParameters::kOriginalConnectionId: { if (!out->original_connection_id.IsEmpty()) { QUIC_DLOG(ERROR) << "Received a second original connection ID"; return false; } - if (CBS_len(&value) > static_cast<size_t>(kQuicMaxConnectionIdLength)) { + const size_t connection_id_length = CBS_len(&value); + if (!QuicUtils::IsConnectionIdLengthValidForVersion( + connection_id_length, version.transport_version)) { QUIC_DLOG(ERROR) << "Received original connection ID of " - << "invalid length " << CBS_len(&value); + << "invalid length " << connection_id_length; return false; } - if (CBS_len(&value) != 0) { - out->original_connection_id.set_length(CBS_len(&value)); + out->original_connection_id.set_length( + static_cast<uint8_t>(connection_id_length)); + if (out->original_connection_id.length() != 0) { memcpy(out->original_connection_id.mutable_data(), CBS_data(&value), - CBS_len(&value)); + out->original_connection_id.length()); } - break; + } break; case TransportParameters::kIdleTimeout: parse_success = out->idle_timeout_milliseconds.ReadFromCbs(&value); break; @@ -675,11 +681,15 @@ << "Failed to parse length of preferred address connection ID"; return false; } - if (CBS_len(&connection_id_cbs) > kQuicMaxConnectionIdLength) { - QUIC_DLOG(ERROR) << "Bad preferred address connection ID length"; + const size_t connection_id_length = CBS_len(&connection_id_cbs); + if (!QuicUtils::IsConnectionIdLengthValidForVersion( + connection_id_length, version.transport_version)) { + QUIC_DLOG(ERROR) << "Received preferred address connection ID of " + << "invalid length " << connection_id_length; return false; } - preferred_address.connection_id.set_length(CBS_len(&connection_id_cbs)); + preferred_address.connection_id.set_length( + static_cast<uint8_t>(connection_id_length)); if (preferred_address.connection_id.length() > 0 && !CBS_copy_bytes(&connection_id_cbs, reinterpret_cast<uint8_t*>(
diff --git a/quic/core/crypto/transport_parameters.h b/quic/core/crypto/transport_parameters.h index 368a7bf..95a01da 100644 --- a/quic/core/crypto/transport_parameters.h +++ b/quic/core/crypto/transport_parameters.h
@@ -185,6 +185,7 @@ // TLS extension. The serialized bytes are written to |*out|. Returns if the // parameters are valid and serialization succeeded. QUIC_EXPORT_PRIVATE bool SerializeTransportParameters( + ParsedQuicVersion version, const TransportParameters& in, std::vector<uint8_t>* out); @@ -193,9 +194,10 @@ // |perspective| indicates whether the input came from a client or a server. // This method returns true if the input was successfully parsed. // TODO(nharper): Write fuzz tests for this method. -QUIC_EXPORT_PRIVATE bool ParseTransportParameters(const uint8_t* in, - size_t in_len, +QUIC_EXPORT_PRIVATE bool ParseTransportParameters(ParsedQuicVersion version, Perspective perspective, + const uint8_t* in, + size_t in_len, TransportParameters* out); } // namespace quic
diff --git a/quic/core/crypto/transport_parameters_test.cc b/quic/core/crypto/transport_parameters_test.cc index 3f7e339..7379224 100644 --- a/quic/core/crypto/transport_parameters_test.cc +++ b/quic/core/crypto/transport_parameters_test.cc
@@ -7,6 +7,7 @@ #include <cstring> #include "third_party/boringssl/src/include/openssl/mem.h" +#include "net/third_party/quiche/src/quic/core/quic_versions.h" #include "net/third_party/quiche/src/quic/platform/api/quic_arraysize.h" #include "net/third_party/quiche/src/quic/platform/api/quic_ip_address.h" #include "net/third_party/quiche/src/quic/platform/api/quic_ptr_util.h" @@ -16,6 +17,7 @@ namespace quic { namespace test { namespace { +const ParsedQuicVersion kVersion(PROTOCOL_TLS1_3, QUIC_VERSION_99); const QuicVersionLabel kFakeVersionLabel = 0x01234567; const QuicVersionLabel kFakeVersionLabel2 = 0x89ABCDEF; const QuicConnectionId kFakeOriginalConnectionId = TestConnectionId(0x1337); @@ -101,11 +103,12 @@ kFakeActiveConnectionIdLimit); std::vector<uint8_t> serialized; - ASSERT_TRUE(SerializeTransportParameters(orig_params, &serialized)); + ASSERT_TRUE(SerializeTransportParameters(kVersion, orig_params, &serialized)); TransportParameters new_params; - ASSERT_TRUE(ParseTransportParameters(serialized.data(), serialized.size(), - Perspective::IS_CLIENT, &new_params)); + ASSERT_TRUE(ParseTransportParameters(kVersion, Perspective::IS_CLIENT, + serialized.data(), serialized.size(), + &new_params)); EXPECT_EQ(Perspective::IS_CLIENT, new_params.perspective); EXPECT_EQ(kFakeVersionLabel, new_params.version); @@ -160,11 +163,12 @@ kFakeActiveConnectionIdLimit); std::vector<uint8_t> serialized; - ASSERT_TRUE(SerializeTransportParameters(orig_params, &serialized)); + ASSERT_TRUE(SerializeTransportParameters(kVersion, orig_params, &serialized)); TransportParameters new_params; - ASSERT_TRUE(ParseTransportParameters(serialized.data(), serialized.size(), - Perspective::IS_SERVER, &new_params)); + ASSERT_TRUE(ParseTransportParameters(kVersion, Perspective::IS_SERVER, + serialized.data(), serialized.size(), + &new_params)); EXPECT_EQ(Perspective::IS_SERVER, new_params.perspective); EXPECT_EQ(kFakeVersionLabel, new_params.version); @@ -255,7 +259,7 @@ orig_params.max_packet_size.set_value(kFakeMaxPacketSize); std::vector<uint8_t> out; - EXPECT_FALSE(SerializeTransportParameters(orig_params, &out)); + EXPECT_FALSE(SerializeTransportParameters(kVersion, orig_params, &out)); } TEST_F(TransportParametersTest, ParseClientParams) { @@ -317,9 +321,9 @@ // clang-format on TransportParameters new_params; - ASSERT_TRUE(ParseTransportParameters(kClientParams, - QUIC_ARRAYSIZE(kClientParams), - Perspective::IS_CLIENT, &new_params)); + ASSERT_TRUE( + ParseTransportParameters(kVersion, Perspective::IS_CLIENT, kClientParams, + QUIC_ARRAYSIZE(kClientParams), &new_params)); EXPECT_EQ(Perspective::IS_CLIENT, new_params.perspective); EXPECT_EQ(kFakeVersionLabel, new_params.version); @@ -374,8 +378,8 @@ // clang-format on EXPECT_FALSE(ParseTransportParameters( - kClientParamsWithFullToken, QUIC_ARRAYSIZE(kClientParamsWithFullToken), - Perspective::IS_CLIENT, &out_params)); + kVersion, Perspective::IS_CLIENT, kClientParamsWithFullToken, + QUIC_ARRAYSIZE(kClientParamsWithFullToken), &out_params)); // clang-format off const uint8_t kClientParamsWithEmptyToken[] = { @@ -399,8 +403,8 @@ // clang-format on EXPECT_FALSE(ParseTransportParameters( - kClientParamsWithEmptyToken, QUIC_ARRAYSIZE(kClientParamsWithEmptyToken), - Perspective::IS_CLIENT, &out_params)); + kVersion, Perspective::IS_CLIENT, kClientParamsWithEmptyToken, + QUIC_ARRAYSIZE(kClientParamsWithEmptyToken), &out_params)); } TEST_F(TransportParametersTest, ParseClientParametersRepeated) { @@ -425,9 +429,9 @@ }; // clang-format on TransportParameters out_params; - EXPECT_FALSE(ParseTransportParameters(kClientParamsRepeated, - QUIC_ARRAYSIZE(kClientParamsRepeated), - Perspective::IS_CLIENT, &out_params)); + EXPECT_FALSE(ParseTransportParameters( + kVersion, Perspective::IS_CLIENT, kClientParamsRepeated, + QUIC_ARRAYSIZE(kClientParamsRepeated), &out_params)); } TEST_F(TransportParametersTest, ParseServerParams) { @@ -513,9 +517,9 @@ // clang-format on TransportParameters new_params; - ASSERT_TRUE(ParseTransportParameters(kServerParams, - QUIC_ARRAYSIZE(kServerParams), - Perspective::IS_SERVER, &new_params)); + ASSERT_TRUE( + ParseTransportParameters(kVersion, Perspective::IS_SERVER, kServerParams, + QUIC_ARRAYSIZE(kServerParams), &new_params)); EXPECT_EQ(Perspective::IS_SERVER, new_params.perspective); EXPECT_EQ(kFakeVersionLabel, new_params.version); @@ -579,9 +583,9 @@ // clang-format on TransportParameters out_params; - EXPECT_FALSE(ParseTransportParameters(kServerParamsRepeated, - QUIC_ARRAYSIZE(kServerParamsRepeated), - Perspective::IS_SERVER, &out_params)); + EXPECT_FALSE(ParseTransportParameters( + kVersion, Perspective::IS_SERVER, kServerParamsRepeated, + QUIC_ARRAYSIZE(kServerParamsRepeated), &out_params)); } TEST_F(TransportParametersTest, CryptoHandshakeMessageRoundtrip) { @@ -597,11 +601,12 @@ orig_params.google_quic_params->SetValue(1337, kTestValue); std::vector<uint8_t> serialized; - ASSERT_TRUE(SerializeTransportParameters(orig_params, &serialized)); + ASSERT_TRUE(SerializeTransportParameters(kVersion, orig_params, &serialized)); TransportParameters new_params; - ASSERT_TRUE(ParseTransportParameters(serialized.data(), serialized.size(), - Perspective::IS_CLIENT, &new_params)); + ASSERT_TRUE(ParseTransportParameters(kVersion, Perspective::IS_CLIENT, + serialized.data(), serialized.size(), + &new_params)); ASSERT_NE(new_params.google_quic_params.get(), nullptr); EXPECT_EQ(new_params.google_quic_params->tag(),
diff --git a/quic/core/quic_connection_id.cc b/quic/core/quic_connection_id.cc index c560bcc..3ad13a3 100644 --- a/quic/core/quic_connection_id.cc +++ b/quic/core/quic_connection_id.cc
@@ -125,6 +125,10 @@ } void QuicConnectionId::set_length(uint8_t length) { + if (length > kQuicMaxConnectionIdLength) { + QUIC_BUG << "Attempted to set connection ID length to " << length; + length = kQuicMaxConnectionIdLength; + } if (GetQuicRestartFlag(quic_use_allocated_connection_ids)) { QUIC_RESTART_FLAG_COUNT_N(quic_use_allocated_connection_ids, 5, 6); char temporary_data[sizeof(data_short_)];
diff --git a/quic/core/quic_utils.cc b/quic/core/quic_utils.cc index 160ac65..a9de14a 100644 --- a/quic/core/quic_utils.cc +++ b/quic/core/quic_utils.cc
@@ -540,12 +540,30 @@ } // static -bool QuicUtils::IsConnectionIdValidForVersion(QuicConnectionId connection_id, - QuicTransportVersion version) { - if (VariableLengthConnectionIdAllowedForVersion(version)) { - return true; +bool QuicUtils::IsConnectionIdLengthValidForVersion( + size_t connection_id_length, + QuicTransportVersion transport_version) { + // No version of QUIC can support lengths that do not fit in an uint8_t. + if (connection_id_length > + static_cast<size_t>(std::numeric_limits<uint8_t>::max())) { + return false; } - return connection_id.length() == kQuicDefaultConnectionIdLength; + const uint8_t connection_id_length8 = + static_cast<uint8_t>(connection_id_length); + // Versions that do not support variable lengths only support length 8. + if (!VariableLengthConnectionIdAllowedForVersion(transport_version)) { + return connection_id_length8 == kQuicDefaultConnectionIdLength; + } + // Currently all other versions require the length to be at most 18 bytes. + return connection_id_length8 <= kQuicMaxConnectionIdLength; +} + +// static +bool QuicUtils::IsConnectionIdValidForVersion( + QuicConnectionId connection_id, + QuicTransportVersion transport_version) { + return IsConnectionIdLengthValidForVersion(connection_id.length(), + transport_version); } QuicUint128 QuicUtils::GenerateStatelessResetToken(
diff --git a/quic/core/quic_utils.h b/quic/core/quic_utils.h index db29740..11154a5 100644 --- a/quic/core/quic_utils.h +++ b/quic/core/quic_utils.h
@@ -181,9 +181,15 @@ static bool VariableLengthConnectionIdAllowedForVersion( QuicTransportVersion version); + // Returns true if the connection ID length is valid for this QUIC version. + static bool IsConnectionIdLengthValidForVersion( + size_t connection_id_length, + QuicTransportVersion transport_version); + // Returns true if the connection ID is valid for this QUIC version. - static bool IsConnectionIdValidForVersion(QuicConnectionId connection_id, - QuicTransportVersion version); + static bool IsConnectionIdValidForVersion( + QuicConnectionId connection_id, + QuicTransportVersion transport_version); // Returns a connection ID suitable for QUIC use-cases that do not need the // connection ID for multiplexing. If the version allows variable lengths,
diff --git a/quic/core/tls_client_handshaker.cc b/quic/core/tls_client_handshaker.cc index d6fc038..735fd3b 100644 --- a/quic/core/tls_client_handshaker.cc +++ b/quic/core/tls_client_handshaker.cc
@@ -120,7 +120,8 @@ params.google_quic_params->SetStringPiece(kUAID, user_agent_id_); std::vector<uint8_t> param_bytes; - return SerializeTransportParameters(params, ¶m_bytes) && + return SerializeTransportParameters(session()->connection()->version(), + params, ¶m_bytes) && SSL_set_quic_transport_params(ssl(), param_bytes.data(), param_bytes.size()) == 1; } @@ -132,8 +133,9 @@ size_t param_bytes_len; SSL_get_peer_quic_transport_params(ssl(), ¶m_bytes, ¶m_bytes_len); if (param_bytes_len == 0 || - !ParseTransportParameters(param_bytes, param_bytes_len, - Perspective::IS_SERVER, ¶ms)) { + !ParseTransportParameters(session()->connection()->version(), + Perspective::IS_SERVER, param_bytes, + param_bytes_len, ¶ms)) { *error_details = "Unable to parse Transport Parameters"; return false; }
diff --git a/quic/core/tls_server_handshaker.cc b/quic/core/tls_server_handshaker.cc index 8172104..0e7d9e5 100644 --- a/quic/core/tls_server_handshaker.cc +++ b/quic/core/tls_server_handshaker.cc
@@ -188,8 +188,9 @@ SSL_get_peer_quic_transport_params(ssl(), &client_params_bytes, ¶ms_bytes_len); if (params_bytes_len == 0 || - !ParseTransportParameters(client_params_bytes, params_bytes_len, - Perspective::IS_CLIENT, &client_params)) { + !ParseTransportParameters(session()->connection()->version(), + Perspective::IS_CLIENT, client_params_bytes, + params_bytes_len, &client_params)) { *error_details = "Unable to parse Transport Parameters"; return false; } @@ -228,7 +229,8 @@ // TODO(nharper): Provide an actual value for the stateless reset token. server_params.stateless_reset_token.resize(16); std::vector<uint8_t> server_params_bytes; - if (!SerializeTransportParameters(server_params, &server_params_bytes) || + if (!SerializeTransportParameters(session()->connection()->version(), + server_params, &server_params_bytes) || SSL_set_quic_transport_params(ssl(), server_params_bytes.data(), server_params_bytes.size()) != 1) { return false;