Internal QUICHE change PiperOrigin-RevId: 259859589 Change-Id: I33f879f5422c0fad5d3570dbb48b82bac23cefff
diff --git a/quic/core/http/end_to_end_test.cc b/quic/core/http/end_to_end_test.cc index 578df12..0f6a7b1 100644 --- a/quic/core/http/end_to_end_test.cc +++ b/quic/core/http/end_to_end_test.cc
@@ -225,8 +225,6 @@ chlo_multiplier_(0), stream_factory_(nullptr), support_server_push_(false), - override_server_connection_id_(nullptr), - override_client_connection_id_(nullptr), expected_server_connection_id_length_(kQuicDefaultConnectionIdLength) { SetQuicFlag(FLAGS_quic_supports_tls_handshake, true); client_supported_versions_ = GetParam().client_supported_versions; @@ -272,12 +270,8 @@ if (!pre_shared_key_client_.empty()) { client->client()->SetPreSharedKey(pre_shared_key_client_); } - if (override_server_connection_id_ != nullptr) { - client->UseConnectionId(*override_server_connection_id_); - } - if (override_client_connection_id_ != nullptr) { - client->UseClientConnectionId(*override_client_connection_id_); - } + client->UseConnectionIdLength(override_server_connection_id_length_); + client->UseClientConnectionIdLength(override_client_connection_id_length_); client->client()->set_max_allowed_push_id(kMaxQuicStreamId); client->Connect(); return client; @@ -554,8 +548,8 @@ bool support_server_push_; std::string pre_shared_key_client_; std::string pre_shared_key_server_; - QuicConnectionId* override_server_connection_id_; - QuicConnectionId* override_client_connection_id_; + int override_server_connection_id_length_ = -1; + int override_client_connection_id_length_ = -1; uint8_t expected_server_connection_id_length_; }; @@ -634,10 +628,13 @@ } TEST_P(EndToEndTest, SimpleRequestResponseZeroConnectionID) { - QuicConnectionId connection_id = QuicUtils::CreateZeroConnectionId( - GetParam().negotiated_version.transport_version); - override_server_connection_id_ = &connection_id; - expected_server_connection_id_length_ = connection_id.length(); + if (!QuicUtils::VariableLengthConnectionIdAllowedForVersion( + GetParam().negotiated_version.transport_version)) { + ASSERT_TRUE(Initialize()); + return; + } + override_server_connection_id_length_ = 0; + expected_server_connection_id_length_ = 0; ASSERT_TRUE(Initialize()); EXPECT_EQ(kFooResponseBody, client_->SendSynchronousRequest("/foo")); @@ -654,10 +651,13 @@ } TEST_P(EndToEndTestWithTls, ZeroConnectionID) { - QuicConnectionId connection_id = QuicUtils::CreateZeroConnectionId( - GetParam().negotiated_version.transport_version); - override_server_connection_id_ = &connection_id; - expected_server_connection_id_length_ = connection_id.length(); + if (!QuicUtils::VariableLengthConnectionIdAllowedForVersion( + GetParam().negotiated_version.transport_version)) { + ASSERT_TRUE(Initialize()); + return; + } + override_server_connection_id_length_ = 0; + expected_server_connection_id_length_ = 0; ASSERT_TRUE(Initialize()); EXPECT_EQ(kFooResponseBody, client_->SendSynchronousRequest("/foo")); @@ -673,9 +673,7 @@ ASSERT_TRUE(Initialize()); return; } - QuicConnectionId connection_id = - TestConnectionIdNineBytesLong(UINT64_C(0xBADbadBADbad)); - override_server_connection_id_ = &connection_id; + override_server_connection_id_length_ = 9; ASSERT_TRUE(Initialize()); EXPECT_EQ(kFooResponseBody, client_->SendSynchronousRequest("/foo")); EXPECT_EQ("200", client_->response_headers()->find(":status")->second); @@ -694,12 +692,7 @@ ASSERT_TRUE(Initialize()); return; } - const char connection_id_bytes[16] = {0xb0, 0xb1, 0xb2, 0xb3, 0xb4, 0xb5, - 0xb6, 0xb7, 0xb8, 0xb9, 0xba, 0xbb, - 0xbc, 0xbd, 0xbe, 0xbf}; - QuicConnectionId connection_id = - QuicConnectionId(connection_id_bytes, sizeof(connection_id_bytes)); - override_server_connection_id_ = &connection_id; + override_server_connection_id_length_ = 16; ASSERT_TRUE(Initialize()); EXPECT_EQ(kFooResponseBody, client_->SendSynchronousRequest("/foo")); EXPECT_EQ("200", client_->response_headers()->find(":status")->second); @@ -715,16 +708,15 @@ ASSERT_TRUE(Initialize()); return; } - QuicConnectionId client_connection_id = - TestConnectionId(UINT64_C(0xc1c2c3c4c5c6c7c8)); - override_client_connection_id_ = &client_connection_id; + override_client_connection_id_length_ = kQuicDefaultConnectionIdLength; ASSERT_TRUE(Initialize()); EXPECT_EQ(kFooResponseBody, client_->SendSynchronousRequest("/foo")); EXPECT_EQ("200", client_->response_headers()->find(":status")->second); - EXPECT_EQ(client_connection_id, client_->client() - ->client_session() - ->connection() - ->client_connection_id()); + EXPECT_EQ(override_client_connection_id_length_, client_->client() + ->client_session() + ->connection() + ->client_connection_id() + .length()); } TEST_P(EndToEndTestWithTls, ForcedVersionNegotiationAndClientConnectionId) { @@ -734,17 +726,16 @@ } client_supported_versions_.insert(client_supported_versions_.begin(), QuicVersionReservedForNegotiation()); - QuicConnectionId client_connection_id = - TestConnectionId(UINT64_C(0xc1c2c3c4c5c6c7c8)); - override_client_connection_id_ = &client_connection_id; + override_client_connection_id_length_ = kQuicDefaultConnectionIdLength; ASSERT_TRUE(Initialize()); ASSERT_TRUE(ServerSendsVersionNegotiation()); EXPECT_EQ(kFooResponseBody, client_->SendSynchronousRequest("/foo")); EXPECT_EQ("200", client_->response_headers()->find(":status")->second); - EXPECT_EQ(client_connection_id, client_->client() - ->client_session() - ->connection() - ->client_connection_id()); + EXPECT_EQ(override_client_connection_id_length_, client_->client() + ->client_session() + ->connection() + ->client_connection_id() + .length()); } TEST_P(EndToEndTestWithTls, ForcedVersionNegotiationAndBadConnectionIdLength) { @@ -755,9 +746,7 @@ } client_supported_versions_.insert(client_supported_versions_.begin(), QuicVersionReservedForNegotiation()); - QuicConnectionId connection_id = - TestConnectionIdNineBytesLong(UINT64_C(0xBADbadBADbad)); - override_server_connection_id_ = &connection_id; + override_server_connection_id_length_ = 9; ASSERT_TRUE(Initialize()); ASSERT_TRUE(ServerSendsVersionNegotiation()); EXPECT_EQ(kFooResponseBody, client_->SendSynchronousRequest("/foo")); @@ -780,18 +769,8 @@ } client_supported_versions_.insert(client_supported_versions_.begin(), QuicVersionReservedForNegotiation()); - const char connection_id_bytes[16] = {0xb0, 0xb1, 0xb2, 0xb3, 0xb4, 0xb5, - 0xb6, 0xb7, 0xb8, 0xb9, 0xba, 0xbb, - 0xbc, 0xbd, 0xbe, 0xbf}; - QuicConnectionId connection_id = - QuicConnectionId(connection_id_bytes, sizeof(connection_id_bytes)); - override_server_connection_id_ = &connection_id; - const char client_connection_id_bytes[18] = { - 0xc0, 0xc1, 0xc2, 0xc3, 0xc4, 0xc5, 0xc6, 0xc7, 0xc8, - 0xc9, 0xca, 0xcb, 0xcc, 0xcd, 0xce, 0xcf, 0xc0, 0xc1}; - QuicConnectionId client_connection_id = QuicConnectionId( - client_connection_id_bytes, sizeof(client_connection_id_bytes)); - override_client_connection_id_ = &client_connection_id; + override_server_connection_id_length_ = 16; + override_client_connection_id_length_ = 18; ASSERT_TRUE(Initialize()); ASSERT_TRUE(ServerSendsVersionNegotiation()); EXPECT_EQ(kFooResponseBody, client_->SendSynchronousRequest("/foo")); @@ -801,10 +780,11 @@ ->connection() ->connection_id() .length()); - EXPECT_EQ(client_connection_id, client_->client() - ->client_session() - ->connection() - ->client_connection_id()); + EXPECT_EQ(override_client_connection_id_length_, client_->client() + ->client_session() + ->connection() + ->client_connection_id() + .length()); } TEST_P(EndToEndTest, MixGoodAndBadConnectionIdLengths) { @@ -815,11 +795,9 @@ } // Start client_ which will use a bad connection ID length. - QuicConnectionId connection_id = - TestConnectionIdNineBytesLong(UINT64_C(0xBADbadBADbad)); - override_server_connection_id_ = &connection_id; + override_server_connection_id_length_ = 9; ASSERT_TRUE(Initialize()); - override_server_connection_id_ = nullptr; + override_server_connection_id_length_ = -1; // Start client2 which will use a good connection ID length. std::unique_ptr<QuicTestClient> client2(CreateQuicClient(nullptr)); @@ -935,10 +913,13 @@ } TEST_P(EndToEndTest, MultipleRequestResponseZeroConnectionID) { - QuicConnectionId connection_id = QuicUtils::CreateZeroConnectionId( - GetParam().negotiated_version.transport_version); - override_server_connection_id_ = &connection_id; - expected_server_connection_id_length_ = connection_id.length(); + if (!QuicUtils::VariableLengthConnectionIdAllowedForVersion( + GetParam().negotiated_version.transport_version)) { + ASSERT_TRUE(Initialize()); + return; + } + override_server_connection_id_length_ = 0; + expected_server_connection_id_length_ = 0; ASSERT_TRUE(Initialize()); EXPECT_EQ(kFooResponseBody, client_->SendSynchronousRequest("/foo"));
diff --git a/quic/core/quic_connection.cc b/quic/core/quic_connection.cc index 22ee18a..a13c506 100644 --- a/quic/core/quic_connection.cc +++ b/quic/core/quic_connection.cc
@@ -570,9 +570,11 @@ } if (QuicContainsValue(packet.versions, version())) { - const std::string error_details = - "Server already supports client's version and should have accepted the " - "connection."; + const std::string error_details = QuicStrCat( + "Server already supports client's version ", + ParsedQuicVersionToString(version()), + " and should have accepted the connection instead of sending {", + ParsedQuicVersionVectorToString(packet.versions), "}."); QUIC_DLOG(WARNING) << error_details; CloseConnection(QUIC_INVALID_VERSION_NEGOTIATION_PACKET, error_details, ConnectionCloseBehavior::SILENT_CLOSE);
diff --git a/quic/test_tools/quic_test_client.cc b/quic/test_tools/quic_test_client.cc index e2e33c7..ef92cb0 100644 --- a/quic/test_tools/quic_test_client.cc +++ b/quic/test_tools/quic_test_client.cc
@@ -233,9 +233,14 @@ } QuicConnectionId MockableQuicClient::GenerateNewConnectionId() { - return server_connection_id_overridden_ - ? override_server_connection_id_ - : QuicClient::GenerateNewConnectionId(); + if (server_connection_id_overridden_) { + return override_server_connection_id_; + } + if (override_server_connection_id_length_ >= 0) { + return QuicUtils::CreateRandomConnectionId( + override_server_connection_id_length_); + } + return QuicClient::GenerateNewConnectionId(); } void MockableQuicClient::UseConnectionId( @@ -244,9 +249,20 @@ override_server_connection_id_ = server_connection_id; } +void MockableQuicClient::UseConnectionIdLength( + int server_connection_id_length) { + override_server_connection_id_length_ = server_connection_id_length; +} + QuicConnectionId MockableQuicClient::GetClientConnectionId() { - return client_connection_id_overridden_ ? override_client_connection_id_ - : QuicClient::GetClientConnectionId(); + if (client_connection_id_overridden_) { + return override_client_connection_id_; + } + if (override_client_connection_id_length_ >= 0) { + return QuicUtils::CreateRandomConnectionId( + override_client_connection_id_length_); + } + return QuicClient::GetClientConnectionId(); } void MockableQuicClient::UseClientConnectionId( @@ -255,6 +271,11 @@ override_client_connection_id_ = client_connection_id; } +void MockableQuicClient::UseClientConnectionIdLength( + int client_connection_id_length) { + override_client_connection_id_length_ = client_connection_id_length; +} + void MockableQuicClient::UseWriter(QuicPacketWriterWrapper* writer) { mockable_network_helper()->UseWriter(writer); } @@ -769,12 +790,23 @@ client_->UseConnectionId(server_connection_id); } +void QuicTestClient::UseConnectionIdLength(int server_connection_id_length) { + DCHECK(!connected()); + client_->UseConnectionIdLength(server_connection_id_length); +} + void QuicTestClient::UseClientConnectionId( QuicConnectionId client_connection_id) { DCHECK(!connected()); client_->UseClientConnectionId(client_connection_id); } +void QuicTestClient::UseClientConnectionIdLength( + int client_connection_id_length) { + DCHECK(!connected()); + client_->UseClientConnectionIdLength(client_connection_id_length); +} + bool QuicTestClient::MigrateSocket(const QuicIpAddress& new_host) { return client_->MigrateSocket(new_host); }
diff --git a/quic/test_tools/quic_test_client.h b/quic/test_tools/quic_test_client.h index 7b83562..cd5b09f 100644 --- a/quic/test_tools/quic_test_client.h +++ b/quic/test_tools/quic_test_client.h
@@ -56,8 +56,10 @@ QuicConnectionId GenerateNewConnectionId() override; void UseConnectionId(QuicConnectionId server_connection_id); + void UseConnectionIdLength(int server_connection_id_length); QuicConnectionId GetClientConnectionId() override; void UseClientConnectionId(QuicConnectionId client_connection_id); + void UseClientConnectionIdLength(int client_connection_id_length); void UseWriter(QuicPacketWriterWrapper* writer); void set_peer_address(const QuicSocketAddress& address); @@ -74,9 +76,11 @@ // Server connection ID to use, if server_connection_id_overridden_ QuicConnectionId override_server_connection_id_; bool server_connection_id_overridden_; + int override_server_connection_id_length_ = -1; // Client connection ID to use, if client_connection_id_overridden_ QuicConnectionId override_client_connection_id_; bool client_connection_id_overridden_; + int override_client_connection_id_length_ = -1; CachedNetworkParameters cached_network_paramaters_; }; @@ -227,9 +231,15 @@ // Configures client_ to use a specific server connection ID instead of a // random one. void UseConnectionId(QuicConnectionId server_connection_id); + // Configures client_ to use a specific server connection ID length instead + // of the default of kQuicDefaultConnectionIdLength. + void UseConnectionIdLength(int server_connection_id_length); // Configures client_ to use a specific client connection ID instead of an // empty one. void UseClientConnectionId(QuicConnectionId client_connection_id); + // Configures client_ to use a specific client connection ID length instead + // of the default of zero. + void UseClientConnectionIdLength(int client_connection_id_length); // Returns nullptr if the maximum number of streams have already been created. QuicSpdyClientStream* GetOrCreateStream();