Make QUIC enforce ALPN when using the TLS handshake

gfe-relnote: enforce ALPN when using TLS, protected by disabled quic_tls flag
PiperOrigin-RevId: 261159061
Change-Id: I9ccdd221e92beae2b83677e692e3c6d084351731
diff --git a/quic/core/tls_server_handshaker.cc b/quic/core/tls_server_handshaker.cc
index 0e7d9e5..0f2c0d2 100644
--- a/quic/core/tls_server_handshaker.cc
+++ b/quic/core/tls_server_handshaker.cc
@@ -61,7 +61,7 @@
 
   if (!SetTransportParameters()) {
     CloseConnection(QUIC_HANDSHAKE_FAILED,
-                    "Failed to set Transport Parameters");
+                    "Server failed to set Transport Parameters");
   }
 }
 
@@ -170,7 +170,8 @@
     QUIC_LOG(WARNING) << "SSL_do_handshake failed; SSL_get_error returns "
                       << ssl_error << ", state_ = " << state_;
     ERR_print_errors_fp(stderr);
-    CloseConnection(QUIC_HANDSHAKE_FAILED, "TLS handshake failed");
+    CloseConnection(QUIC_HANDSHAKE_FAILED,
+                    "Server observed TLS handshake failure");
   }
 }
 
@@ -239,6 +240,16 @@
 }
 
 void TlsServerHandshaker::FinishHandshake() {
+  if (!valid_alpn_received_) {
+    QUIC_DLOG(ERROR)
+        << "Server: handshake finished without receiving a known ALPN";
+    // TODO(b/130164908) this should send no_application_protocol
+    // instead of QUIC_HANDSHAKE_FAILED.
+    CloseConnection(QUIC_HANDSHAKE_FAILED,
+                    "Server did not receive a known ALPN");
+    return;
+  }
+
   QUIC_LOG(INFO) << "Server: handshake finished";
   state_ = STATE_HANDSHAKE_COMPLETE;
 
@@ -328,26 +339,48 @@
                                     const uint8_t* in,
                                     unsigned in_len) {
   // |in| contains a sequence of 1-byte-length-prefixed values.
-  // We currently simply return the first provided ALPN value.
-  // TODO(b/130164908) Act on ALPN.
+  *out_len = 0;
+  *out = nullptr;
   if (in_len == 0) {
-    *out_len = 0;
-    *out = nullptr;
-    QUIC_DLOG(INFO) << "No ALPN provided";
-    return SSL_TLSEXT_ERR_OK;
+    QUIC_DLOG(ERROR) << "No ALPN provided by client";
+    return SSL_TLSEXT_ERR_NOACK;
   }
-  const uint8_t first_alpn_length = in[0];
-  if (static_cast<unsigned>(first_alpn_length) > in_len - 1) {
-    QUIC_LOG(ERROR) << "Failed to parse ALPN";
-    return SSL_TLSEXT_ERR_ALERT_FATAL;
+
+  std::string expected_alpn_string =
+      AlpnForVersion(session()->connection()->version());
+
+  CBS all_alpns;
+  CBS_init(&all_alpns, in, in_len);
+
+  while (CBS_len(&all_alpns) > 0) {
+    CBS alpn;
+    if (!CBS_get_u8_length_prefixed(&all_alpns, &alpn)) {
+      QUIC_DLOG(ERROR) << "Failed to parse ALPN length";
+      return SSL_TLSEXT_ERR_NOACK;
+    }
+    const size_t alpn_length = CBS_len(&alpn);
+    if (alpn_length >
+        static_cast<size_t>(std::numeric_limits<uint8_t>::max())) {
+      QUIC_BUG << "Parsed impossible ALPN length " << alpn_length;
+      return SSL_TLSEXT_ERR_NOACK;
+    }
+    if (alpn_length == 0) {
+      QUIC_DLOG(ERROR) << "Received invalid zero-length ALPN";
+      return SSL_TLSEXT_ERR_NOACK;
+    }
+    std::string alpn_string(reinterpret_cast<const char*>(CBS_data(&alpn)),
+                            alpn_length);
+    if (alpn_string == expected_alpn_string) {
+      QUIC_DLOG(INFO) << "Server selecting ALPN '" << alpn_string << "'";
+      *out_len = static_cast<uint8_t>(alpn_length);
+      *out = CBS_data(&alpn);
+      valid_alpn_received_ = true;
+      return SSL_TLSEXT_ERR_OK;
+    }
   }
-  *out_len = first_alpn_length;
-  *out = in + 1;
-  QUIC_DLOG(INFO) << "Server selecting ALPN '"
-                  << QuicStringPiece(reinterpret_cast<const char*>(*out),
-                                     *out_len)
-                  << "'";
-  return SSL_TLSEXT_ERR_OK;
+
+  QUIC_DLOG(ERROR) << "No known ALPN provided by client";
+  return SSL_TLSEXT_ERR_NOACK;
 }
 
 }  // namespace quic