Add ALPN to QUIC when using TLS

This CL makes our client send ALPN when using QUIC with TLS, makes the server echo the first ALPN value, and allows quic_client to override the ALPN for IETF interop events.

gfe-relnote: protected by disabled flag quic_supports_tls_handshake
PiperOrigin-RevId: 242682444
Change-Id: I7e60fb61c0afe02283e38598de29df9018b71ee8
diff --git a/quic/core/quic_versions.cc b/quic/core/quic_versions.cc
index 62f9ce9..ca6bdae 100644
--- a/quic/core/quic_versions.cc
+++ b/quic/core/quic_versions.cc
@@ -25,9 +25,9 @@
   return MakeQuicTag(d, c, b, a);
 }
 
-// Version label for ParsedQuicVersion(PROTOCOL_TLS1_3, QUIC_VERSION_99).
-// Defaults to "T099". Can be overridden for IETF interop events.
-QuicVersionLabel kQuicT099VersionLabel = 0;
+// IETF draft version for ParsedQuicVersion(PROTOCOL_TLS1_3, QUIC_VERSION_99).
+// Overrides the version label and ALPN string for IETF interop events.
+int32_t kQuicT099IetfDraftVersion = 0;
 
 }  // namespace
 
@@ -73,8 +73,8 @@
       return MakeVersionLabel(proto, '0', '4', '7');
     case QUIC_VERSION_99:
       if (parsed_version.handshake_protocol == PROTOCOL_TLS1_3 &&
-          kQuicT099VersionLabel != 0) {
-        return kQuicT099VersionLabel;
+          kQuicT099IetfDraftVersion != 0) {
+        return 0xff000000 + kQuicT099IetfDraftVersion;
       }
       return MakeVersionLabel(proto, '0', '9', '9');
     default:
@@ -116,7 +116,7 @@
 }
 
 ParsedQuicVersion ParseQuicVersionString(std::string version_string) {
-  if (version_string.length() == 0) {
+  if (version_string.empty()) {
     return UnsupportedQuicVersion();
   }
   int quic_version_number = 0;
@@ -140,7 +140,7 @@
       }
     }
   }
-  // Still recognize T099 even if kQuicT099VersionLabel has been changed.
+  // Still recognize T099 even if kQuicT099IetfDraftVersion has been changed.
   if (FLAGS_quic_supports_tls_handshake && version_string == "T099") {
     return ParsedQuicVersion(PROTOCOL_TLS1_3, QUIC_VERSION_99);
   }
@@ -378,6 +378,15 @@
   return ParsedQuicVersion(PROTOCOL_UNSUPPORTED, QUIC_VERSION_UNSUPPORTED);
 }
 
+std::string AlpnForVersion(ParsedQuicVersion parsed_version) {
+  if (parsed_version.handshake_protocol == PROTOCOL_TLS1_3 &&
+      parsed_version.transport_version == QUIC_VERSION_99 &&
+      kQuicT099IetfDraftVersion != 0) {
+    return "h3-" + QuicTextUtils::Uint64ToString(kQuicT099IetfDraftVersion);
+  }
+  return "h3-google-" + ParsedQuicVersionToString(parsed_version);
+}
+
 void QuicVersionInitializeSupportForIetfDraft(int32_t draft_version) {
   if (draft_version == 0) {
     return;
@@ -387,7 +396,7 @@
     return;
   }
 
-  kQuicT099VersionLabel = 0xff000000 + draft_version;
+  kQuicT099IetfDraftVersion = draft_version;
 
   // Enable necessary flags.
   SetQuicFlag(&FLAGS_quic_supports_tls_handshake, true);
diff --git a/quic/core/quic_versions.h b/quic/core/quic_versions.h
index c270f95..cc780b4 100644
--- a/quic/core/quic_versions.h
+++ b/quic/core/quic_versions.h
@@ -359,6 +359,10 @@
   return transport_version == QUIC_VERSION_99;
 }
 
+// Returns the ALPN string to use in TLS for this version of QUIC.
+QUIC_EXPORT_PRIVATE std::string AlpnForVersion(
+    ParsedQuicVersion parsed_version);
+
 // Initializes support for the provided IETF draft version by setting flags
 // and the version label.
 QUIC_EXPORT_PRIVATE void QuicVersionInitializeSupportForIetfDraft(
diff --git a/quic/core/quic_versions_test.cc b/quic/core/quic_versions_test.cc
index 6139ca8..09ed5e8 100644
--- a/quic/core/quic_versions_test.cc
+++ b/quic/core/quic_versions_test.cc
@@ -556,6 +556,21 @@
   EXPECT_EQ(QUIC_VERSION_99, 99);
 }
 
+TEST_F(QuicVersionsTest, AlpnForVersion) {
+  FLAGS_quic_supports_tls_handshake = true;
+  ParsedQuicVersion parsed_version_q047 =
+      ParsedQuicVersion(PROTOCOL_QUIC_CRYPTO, QUIC_VERSION_47);
+  ParsedQuicVersion parsed_version_t047 =
+      ParsedQuicVersion(PROTOCOL_TLS1_3, QUIC_VERSION_47);
+  ParsedQuicVersion parsed_version_t099 =
+      ParsedQuicVersion(PROTOCOL_TLS1_3, QUIC_VERSION_99);
+  FLAGS_quic_supports_tls_handshake = false;
+
+  EXPECT_EQ("h3-google-Q047", AlpnForVersion(parsed_version_q047));
+  EXPECT_EQ("h3-google-T047", AlpnForVersion(parsed_version_t047));
+  EXPECT_EQ("h3-google-T099", AlpnForVersion(parsed_version_t099));
+}
+
 TEST_F(QuicVersionsTest, InitializeSupportForIetfDraft) {
   FLAGS_quic_supports_tls_handshake = true;
   ParsedQuicVersion parsed_version_t099 =
@@ -563,16 +578,19 @@
   FLAGS_quic_supports_tls_handshake = false;
   EXPECT_EQ(MakeVersionLabel('T', '0', '9', '9'),
             CreateQuicVersionLabel(parsed_version_t099));
+  EXPECT_EQ("h3-google-T099", AlpnForVersion(parsed_version_t099));
 
   QuicVersionInitializeSupportForIetfDraft(0);
   EXPECT_EQ(MakeVersionLabel('T', '0', '9', '9'),
             CreateQuicVersionLabel(parsed_version_t099));
+  EXPECT_EQ("h3-google-T099", AlpnForVersion(parsed_version_t099));
   EXPECT_FALSE(FLAGS_quic_supports_tls_handshake);
 
   QuicVersionInitializeSupportForIetfDraft(18);
   EXPECT_TRUE(FLAGS_quic_supports_tls_handshake);
   EXPECT_EQ(MakeVersionLabel(0xff, 0, 0, 18),
             CreateQuicVersionLabel(parsed_version_t099));
+  EXPECT_EQ("h3-18", AlpnForVersion(parsed_version_t099));
 }
 
 TEST_F(QuicVersionsTest, QuicEnableVersion) {
diff --git a/quic/core/tls_client_handshaker.cc b/quic/core/tls_client_handshaker.cc
index adc2c81..5081a48 100644
--- a/quic/core/tls_client_handshaker.cc
+++ b/quic/core/tls_client_handshaker.cc
@@ -4,12 +4,14 @@
 
 #include "net/third_party/quiche/src/quic/core/tls_client_handshaker.h"
 
+#include <cstring>
 #include <string>
 
 #include "third_party/boringssl/src/include/openssl/ssl.h"
 #include "net/third_party/quiche/src/quic/core/crypto/quic_encrypter.h"
 #include "net/third_party/quiche/src/quic/core/crypto/transport_parameters.h"
 #include "net/third_party/quiche/src/quic/core/quic_session.h"
+#include "net/third_party/quiche/src/quic/platform/api/quic_text_utils.h"
 
 namespace quic {
 
@@ -87,6 +89,27 @@
     return false;
   }
 
+  std::string alpn_string =
+      AlpnForVersion(session()->supported_versions().front());
+  if (alpn_string.length() > std::numeric_limits<uint8_t>::max()) {
+    QUIC_BUG << "ALPN too long: '" << alpn_string << "'";
+    CloseConnection(QUIC_HANDSHAKE_FAILED, "ALPN too long");
+    return false;
+  }
+  const uint8_t alpn_length = alpn_string.length();
+  // SSL_set_alpn_protos expects a sequence of one-byte-length-prefixed strings
+  // so we copy alpn_string to a new buffer that has the length in alpn[0].
+  uint8_t alpn[std::numeric_limits<uint8_t>::max() + 1];
+  alpn[0] = alpn_length;
+  memcpy(reinterpret_cast<char*>(alpn + 1), alpn_string.data(), alpn_length);
+  if (SSL_set_alpn_protos(ssl(), alpn,
+                          static_cast<unsigned>(alpn_length) + 1) != 0) {
+    QUIC_BUG << "Failed to set ALPN: '" << alpn_string << "'";
+    CloseConnection(QUIC_HANDSHAKE_FAILED, "Failed to set ALPN");
+    return false;
+  }
+  QUIC_DLOG(INFO) << "Client using ALPN: '" << alpn_string << "'";
+
   // Set the Transport Parameters to send in the ClientHello
   if (!SetTransportParameters()) {
     CloseConnection(QUIC_HANDSHAKE_FAILED,
@@ -248,6 +271,28 @@
     return;
   }
 
+  const uint8_t* alpn_data = nullptr;
+  unsigned alpn_length = 0;
+  SSL_get0_alpn_selected(ssl(), &alpn_data, &alpn_length);
+  // TODO(b/130164908) Act on ALPN.
+  if (alpn_length != 0) {
+    std::string received_alpn_string(reinterpret_cast<const char*>(alpn_data),
+                                     alpn_length);
+    std::string sent_alpn_string =
+        AlpnForVersion(session()->supported_versions().front());
+    if (received_alpn_string != sent_alpn_string) {
+      QUIC_LOG(ERROR) << "Client: received mismatched ALPN '"
+                      << received_alpn_string << "', expected '"
+                      << sent_alpn_string << "'";
+      CloseConnection(QUIC_HANDSHAKE_FAILED, "Mismatched ALPN");
+      return;
+    }
+    QUIC_DLOG(INFO) << "Client: server selected ALPN: '" << received_alpn_string
+                    << "'";
+  } else {
+    QUIC_DLOG(INFO) << "Client: server did not select ALPN";
+  }
+
   session()->connection()->SetDefaultEncryptionLevel(ENCRYPTION_FORWARD_SECURE);
   session()->NeuterUnencryptedData();
   encryption_established_ = true;
diff --git a/quic/core/tls_handshaker.cc b/quic/core/tls_handshaker.cc
index 17e2674..c6394b8 100644
--- a/quic/core/tls_handshaker.cc
+++ b/quic/core/tls_handshaker.cc
@@ -236,6 +236,8 @@
   // (draft-ietf-quic-transport-14, section 11.3). However, according to
   // quic_error_codes.h, this QUIC implementation only sends 1-byte error codes
   // right now.
+  QUIC_DLOG(INFO) << "TLS failing handshake due to alert "
+                  << static_cast<int>(desc);
   CloseConnection(QUIC_HANDSHAKE_FAILED, "TLS handshake failure");
 }
 
diff --git a/quic/core/tls_server_handshaker.cc b/quic/core/tls_server_handshaker.cc
index 16c84b1..96e1802 100644
--- a/quic/core/tls_server_handshaker.cc
+++ b/quic/core/tls_server_handshaker.cc
@@ -51,6 +51,8 @@
   bssl::UniquePtr<SSL_CTX> ssl_ctx = TlsHandshaker::CreateSslCtx();
   SSL_CTX_set_tlsext_servername_callback(
       ssl_ctx.get(), TlsServerHandshaker::SelectCertificateCallback);
+  SSL_CTX_set_alpn_select_cb(ssl_ctx.get(),
+                             TlsServerHandshaker::SelectAlpnCallback, nullptr);
   return ssl_ctx;
 }
 
@@ -363,4 +365,41 @@
   return SSL_TLSEXT_ERR_OK;
 }
 
+// static
+int TlsServerHandshaker::SelectAlpnCallback(SSL* ssl,
+                                            const uint8_t** out,
+                                            uint8_t* out_len,
+                                            const uint8_t* in,
+                                            unsigned in_len,
+                                            void* arg) {
+  return HandshakerFromSsl(ssl)->SelectAlpn(out, out_len, in, in_len);
+}
+
+int TlsServerHandshaker::SelectAlpn(const uint8_t** out,
+                                    uint8_t* out_len,
+                                    const uint8_t* in,
+                                    unsigned in_len) {
+  // |in| contains a sequence of 1-byte-length-prefixed values.
+  // We currently simply return the first provided ALPN value.
+  // TODO(b/130164908) Act on ALPN.
+  if (in_len == 0) {
+    *out_len = 0;
+    *out = nullptr;
+    QUIC_DLOG(INFO) << "No ALPN provided";
+    return SSL_TLSEXT_ERR_OK;
+  }
+  const uint8_t first_alpn_length = in[0];
+  if (static_cast<unsigned>(first_alpn_length) > in_len - 1) {
+    QUIC_LOG(ERROR) << "Failed to parse ALPN";
+    return SSL_TLSEXT_ERR_ALERT_FATAL;
+  }
+  *out_len = first_alpn_length;
+  *out = in + 1;
+  QUIC_DLOG(INFO) << "Server selecting ALPN '"
+                  << QuicStringPiece(reinterpret_cast<const char*>(*out),
+                                     *out_len)
+                  << "'";
+  return SSL_TLSEXT_ERR_OK;
+}
+
 }  // namespace quic
diff --git a/quic/core/tls_server_handshaker.h b/quic/core/tls_server_handshaker.h
index 1e9b497..e27d896 100644
--- a/quic/core/tls_server_handshaker.h
+++ b/quic/core/tls_server_handshaker.h
@@ -61,6 +61,14 @@
   // |ssl|.
   static int SelectCertificateCallback(SSL* ssl, int* out_alert, void* arg);
 
+  // Calls SelectAlpn after looking up the TlsServerHandshaker from |ssl|.
+  static int SelectAlpnCallback(SSL* ssl,
+                                const uint8_t** out,
+                                uint8_t* out_len,
+                                const uint8_t* in,
+                                unsigned in_len,
+                                void* arg);
+
  private:
   class SignatureCallback : public ProofSource::SignatureCallback {
    public:
@@ -148,6 +156,12 @@
   // |*out_alert| the TLS alert value that the server will send.
   int SelectCertificate(int* out_alert);
 
+  // Selects which ALPN to use based on the list sent by the client.
+  int SelectAlpn(const uint8_t** out,
+                 uint8_t* out_len,
+                 const uint8_t* in,
+                 unsigned in_len);
+
   static TlsServerHandshaker* HandshakerFromSsl(SSL* ssl);
 
   State state_ = STATE_LISTENING;