Implement QUIC ALPN selection on the server side.

This also fixes the bugs in client side found by a full custom ALPN test, and removes a tautological check from server-side ALPN parser.

gfe-relnote: n/a (protected by disabled quic_tls flag)
PiperOrigin-RevId: 266319592
Change-Id: I9e06b383abe187286f31d3cbce8be99e9370c9f2
diff --git a/quic/core/quic_session.cc b/quic/core/quic_session.cc
index 98aa242..57de377 100644
--- a/quic/core/quic_session.cc
+++ b/quic/core/quic_session.cc
@@ -1942,5 +1942,17 @@
   return stream_id_manager_.max_open_incoming_streams();
 }
 
+std::vector<QuicStringPiece>::const_iterator QuicSession::SelectAlpn(
+    const std::vector<QuicStringPiece>& alpns) const {
+  const std::string alpn = AlpnForVersion(connection()->version());
+  return std::find(alpns.cbegin(), alpns.cend(), alpn);
+}
+
+void QuicSession::OnAlpnSelected(QuicStringPiece alpn) {
+  QUIC_DLOG(INFO) << (perspective() == Perspective::IS_SERVER ? "Server: "
+                                                              : "Client: ")
+                  << "ALPN selected: " << alpn;
+}
+
 #undef ENDPOINT  // undef for jumbo builds
 }  // namespace quic
diff --git a/quic/core/quic_session.h b/quic/core/quic_session.h
index d5ddc69..16c7efc 100644
--- a/quic/core/quic_session.h
+++ b/quic/core/quic_session.h
@@ -460,12 +460,21 @@
   }
 
   // Returns the ALPN values to negotiate on this session.
-  virtual std::vector<std::string> GetAlpnsToOffer() {
+  virtual std::vector<std::string> GetAlpnsToOffer() const {
     // TODO(vasilvv): this currently sets HTTP/3 by default.  Switch all
     // non-HTTP applications to appropriate ALPNs.
     return std::vector<std::string>({AlpnForVersion(connection()->version())});
   }
 
+  // Provided a list of ALPNs offered by the client, selects an ALPN from the
+  // list, or alpns.end() if none of the ALPNs are acceptable.
+  virtual std::vector<QuicStringPiece>::const_iterator SelectAlpn(
+      const std::vector<QuicStringPiece>& alpns) const;
+
+  // Called when the ALPN of the connection is established for a connection that
+  // uses TLS handshake.
+  virtual void OnAlpnSelected(QuicStringPiece alpn);
+
  protected:
   using StreamMap = QuicSmallMap<QuicStreamId, std::unique_ptr<QuicStream>, 10>;
 
diff --git a/quic/core/tls_client_handshaker.cc b/quic/core/tls_client_handshaker.cc
index a410f81..bb45602 100644
--- a/quic/core/tls_client_handshaker.cc
+++ b/quic/core/tls_client_handshaker.cc
@@ -304,17 +304,17 @@
 
   std::string received_alpn_string(reinterpret_cast<const char*>(alpn_data),
                                    alpn_length);
-  std::string sent_alpn_string =
-      AlpnForVersion(session()->connection()->version());
-  if (received_alpn_string != sent_alpn_string) {
+  std::vector<std::string> offered_alpns = session()->GetAlpnsToOffer();
+  if (std::find(offered_alpns.begin(), offered_alpns.end(),
+                received_alpn_string) == offered_alpns.end()) {
     QUIC_LOG(ERROR) << "Client: received mismatched ALPN '"
-                    << received_alpn_string << "', expected '"
-                    << sent_alpn_string << "'";
+                    << received_alpn_string;
     // TODO(b/130164908) this should send no_application_protocol
     // instead of QUIC_HANDSHAKE_FAILED.
     CloseConnection(QUIC_HANDSHAKE_FAILED, "Client received mismatched ALPN");
     return;
   }
+  session()->OnAlpnSelected(received_alpn_string);
   QUIC_DLOG(INFO) << "Client: server selected ALPN: '" << received_alpn_string
                   << "'";
 
diff --git a/quic/core/tls_handshaker_test.cc b/quic/core/tls_handshaker_test.cc
index ef7f31b..a298089 100644
--- a/quic/core/tls_handshaker_test.cc
+++ b/quic/core/tls_handshaker_test.cc
@@ -23,6 +23,7 @@
 namespace {
 
 using ::testing::_;
+using ::testing::ElementsAreArray;
 using ::testing::Return;
 
 class FakeProofVerifier : public ProofVerifier {
@@ -305,9 +306,15 @@
     EXPECT_FALSE(client_stream_->handshake_confirmed());
     EXPECT_FALSE(server_stream_->encryption_established());
     EXPECT_FALSE(server_stream_->handshake_confirmed());
+    const std::string default_alpn =
+        AlpnForVersion(client_session_.connection()->version());
     ON_CALL(client_session_, GetAlpnsToOffer())
-        .WillByDefault(Return(std::vector<std::string>(
-            {AlpnForVersion(client_session_.connection()->version())})));
+        .WillByDefault(Return(std::vector<std::string>({default_alpn})));
+    ON_CALL(server_session_, SelectAlpn(_))
+        .WillByDefault(
+            [default_alpn](const std::vector<QuicStringPiece>& alpns) {
+              return std::find(alpns.begin(), alpns.end(), default_alpn);
+            });
   }
 
   MockQuicConnectionHelper conn_helper_;
@@ -499,6 +506,59 @@
   EXPECT_QUIC_BUG(client_stream_->CryptoConnect(), "Failed to set ALPN");
 }
 
+TEST_F(TlsHandshakerTest, ServerRequiresCustomALPN) {
+  static const std::string kTestAlpn = "An ALPN That Client Did Not Offer";
+  EXPECT_CALL(server_session_, SelectAlpn(_))
+      .WillOnce([](const std::vector<QuicStringPiece>& alpns) {
+        return std::find(alpns.cbegin(), alpns.cend(), kTestAlpn);
+      });
+  EXPECT_CALL(*client_conn_, CloseConnection(QUIC_HANDSHAKE_FAILED,
+                                             "Server did not select ALPN", _));
+  EXPECT_CALL(*server_conn_,
+              CloseConnection(QUIC_HANDSHAKE_FAILED,
+                              "Server did not receive a known ALPN", _));
+  client_stream_->CryptoConnect();
+  ExchangeHandshakeMessages(client_stream_, server_stream_);
+
+  EXPECT_FALSE(client_stream_->handshake_confirmed());
+  EXPECT_FALSE(client_stream_->encryption_established());
+  EXPECT_FALSE(server_stream_->handshake_confirmed());
+  EXPECT_FALSE(server_stream_->encryption_established());
+}
+
+TEST_F(TlsHandshakerTest, CustomALPNNegotiation) {
+  EXPECT_CALL(*client_conn_, CloseConnection(_, _, _)).Times(0);
+  EXPECT_CALL(*server_conn_, CloseConnection(_, _, _)).Times(0);
+  EXPECT_CALL(client_session_,
+              OnCryptoHandshakeEvent(QuicSession::ENCRYPTION_ESTABLISHED));
+  EXPECT_CALL(client_session_,
+              OnCryptoHandshakeEvent(QuicSession::HANDSHAKE_CONFIRMED));
+  EXPECT_CALL(server_session_,
+              OnCryptoHandshakeEvent(QuicSession::HANDSHAKE_CONFIRMED));
+
+  static const std::string kTestAlpn = "A Custom ALPN Value";
+  static const std::vector<std::string> kTestAlpns(
+      {"foo", "bar", kTestAlpn, "something else"});
+  EXPECT_CALL(client_session_, GetAlpnsToOffer())
+      .WillRepeatedly(Return(kTestAlpns));
+  EXPECT_CALL(server_session_, SelectAlpn(_))
+      .WillOnce([](const std::vector<QuicStringPiece>& alpns) {
+        EXPECT_THAT(alpns, ElementsAreArray(kTestAlpns));
+        return std::find(alpns.cbegin(), alpns.cend(), kTestAlpn);
+      });
+  EXPECT_CALL(client_session_, OnAlpnSelected(QuicStringPiece(kTestAlpn)));
+  EXPECT_CALL(server_session_, OnAlpnSelected(QuicStringPiece(kTestAlpn)));
+  client_stream_->CryptoConnect();
+  ExchangeHandshakeMessages(client_stream_, server_stream_);
+
+  EXPECT_TRUE(client_stream_->handshake_confirmed());
+  EXPECT_TRUE(client_stream_->encryption_established());
+  EXPECT_TRUE(server_stream_->handshake_confirmed());
+  EXPECT_TRUE(server_stream_->encryption_established());
+  EXPECT_TRUE(client_conn_->IsHandshakeConfirmed());
+  EXPECT_FALSE(server_conn_->IsHandshakeConfirmed());
+}
+
 }  // namespace
 }  // namespace test
 }  // namespace quic
diff --git a/quic/core/tls_server_handshaker.cc b/quic/core/tls_server_handshaker.cc
index a2ea397..0b4215d 100644
--- a/quic/core/tls_server_handshaker.cc
+++ b/quic/core/tls_server_handshaker.cc
@@ -352,41 +352,38 @@
     return SSL_TLSEXT_ERR_NOACK;
   }
 
-  std::string expected_alpn_string =
-      AlpnForVersion(session()->connection()->version());
-
   CBS all_alpns;
   CBS_init(&all_alpns, in, in_len);
 
+  std::vector<QuicStringPiece> alpns;
   while (CBS_len(&all_alpns) > 0) {
     CBS alpn;
     if (!CBS_get_u8_length_prefixed(&all_alpns, &alpn)) {
       QUIC_DLOG(ERROR) << "Failed to parse ALPN length";
       return SSL_TLSEXT_ERR_NOACK;
     }
+
     const size_t alpn_length = CBS_len(&alpn);
-    if (alpn_length >
-        static_cast<size_t>(std::numeric_limits<uint8_t>::max())) {
-      QUIC_BUG << "Parsed impossible ALPN length " << alpn_length;
-      return SSL_TLSEXT_ERR_NOACK;
-    }
     if (alpn_length == 0) {
       QUIC_DLOG(ERROR) << "Received invalid zero-length ALPN";
       return SSL_TLSEXT_ERR_NOACK;
     }
-    std::string alpn_string(reinterpret_cast<const char*>(CBS_data(&alpn)),
-                            alpn_length);
-    if (alpn_string == expected_alpn_string) {
-      QUIC_DLOG(INFO) << "Server selecting ALPN '" << alpn_string << "'";
-      *out_len = static_cast<uint8_t>(alpn_length);
-      *out = CBS_data(&alpn);
-      valid_alpn_received_ = true;
-      return SSL_TLSEXT_ERR_OK;
-    }
+
+    alpns.emplace_back(reinterpret_cast<const char*>(CBS_data(&alpn)),
+                       alpn_length);
   }
 
-  QUIC_DLOG(ERROR) << "No known ALPN provided by client";
-  return SSL_TLSEXT_ERR_NOACK;
+  auto selected_alpn = session()->SelectAlpn(alpns);
+  if (selected_alpn == alpns.end()) {
+    QUIC_DLOG(ERROR) << "No known ALPN provided by client";
+    return SSL_TLSEXT_ERR_NOACK;
+  }
+
+  session()->OnAlpnSelected(*selected_alpn);
+  valid_alpn_received_ = true;
+  *out_len = selected_alpn->size();
+  *out = reinterpret_cast<const uint8_t*>(selected_alpn->data());
+  return SSL_TLSEXT_ERR_OK;
 }
 
 }  // namespace quic
diff --git a/quic/test_tools/quic_test_utils.h b/quic/test_tools/quic_test_utils.h
index bbde859..5750b9b 100644
--- a/quic/test_tools/quic_test_utils.h
+++ b/quic/test_tools/quic_test_utils.h
@@ -633,7 +633,11 @@
   MOCK_CONST_METHOD0(ShouldKeepConnectionAlive, bool());
   MOCK_METHOD2(SendStopSending, void(uint16_t code, QuicStreamId stream_id));
   MOCK_METHOD1(OnCryptoHandshakeEvent, void(QuicSession::CryptoHandshakeEvent));
-  MOCK_METHOD0(GetAlpnsToOffer, std::vector<std::string>());
+  MOCK_CONST_METHOD0(GetAlpnsToOffer, std::vector<std::string>());
+  MOCK_CONST_METHOD1(SelectAlpn,
+                     std::vector<QuicStringPiece>::const_iterator(
+                         const std::vector<QuicStringPiece>&));
+  MOCK_METHOD1(OnAlpnSelected, void(QuicStringPiece));
 
   using QuicSession::ActivateStream;