Allow MoqtClient and MoqtServer to control session parameters.

This will enable partial object delivery on the relay, but is generally useful functionality.

PiperOrigin-RevId: 903589067
diff --git a/quiche/quic/moqt/test_tools/moqt_session_peer.h b/quiche/quic/moqt/test_tools/moqt_session_peer.h
index 6059293..805bba5 100644
--- a/quiche/quic/moqt/test_tools/moqt_session_peer.h
+++ b/quiche/quic/moqt/test_tools/moqt_session_peer.h
@@ -258,6 +258,10 @@
   static MoqtSession::ControlStream* GetControlStream(MoqtSession* session) {
     return session->control_stream_.GetIfAvailable();
   }
+
+  static const MoqtSessionParameters& GetParameters(MoqtSession* session) {
+    return session->parameters_;
+  }
 };
 
 }  // namespace moqt::test
diff --git a/quiche/quic/moqt/tools/moqt_client.cc b/quiche/quic/moqt/tools/moqt_client.cc
index 5140204..f2d185b 100644
--- a/quiche/quic/moqt/tools/moqt_client.cc
+++ b/quiche/quic/moqt/tools/moqt_client.cc
@@ -21,6 +21,8 @@
 #include "quiche/quic/moqt/moqt_messages.h"
 #include "quiche/quic/moqt/moqt_quic_config.h"
 #include "quiche/quic/moqt/moqt_session.h"
+#include "quiche/quic/moqt/moqt_session_callbacks.h"
+#include "quiche/quic/moqt/moqt_session_interface.h"
 #include "quiche/quic/platform/api/quic_socket_address.h"
 #include "quiche/quic/tools/quic_default_client.h"
 #include "quiche/quic/tools/quic_event_loop_tools.h"
@@ -34,11 +36,14 @@
 MoqtClient::MoqtClient(quic::QuicSocketAddress peer_address,
                        const quic::QuicServerId& server_id,
                        std::unique_ptr<quic::ProofVerifier> proof_verifier,
-                       quic::QuicEventLoop* event_loop)
+                       quic::QuicEventLoop* event_loop,
+                       MoqtSessionParameters parameters)
     : spdy_client_(peer_address, server_id, GetMoqtSupportedQuicVersions(),
-                   event_loop, std::move(proof_verifier)) {
+                   event_loop, std::move(proof_verifier)),
+      parameters_(parameters) {
   TuneQuicConfig(*spdy_client_.config());
   spdy_client_.set_enable_web_transport(true);
+  parameters_.perspective = quic::Perspective::IS_CLIENT;
 }
 
 void MoqtClient::Connect(std::string path, MoqtSessionCallbacks callbacks) {
@@ -101,8 +106,6 @@
     return absl::InternalError("Failed to initialize WebTransport session");
   }
 
-  MoqtSessionParameters parameters(quic::Perspective::IS_CLIENT);
-
   // Ensure that we never have a dangling pointer to the session.
   MoqtSessionDeletedCallback deleted_callback =
       std::move(callbacks.session_deleted_callback);
@@ -113,7 +116,7 @@
       };
 
   auto session = std::make_unique<MoqtSession>(
-      web_transport, parameters,
+      web_transport, parameters_,
       spdy_client_.default_network_helper()->event_loop()->CreateAlarmFactory(),
       std::move(callbacks));
   session_ = session.get();
diff --git a/quiche/quic/moqt/tools/moqt_client.h b/quiche/quic/moqt/tools/moqt_client.h
index 796a751..011ceb1 100644
--- a/quiche/quic/moqt/tools/moqt_client.h
+++ b/quiche/quic/moqt/tools/moqt_client.h
@@ -15,6 +15,7 @@
 #include "quiche/quic/core/quic_session.h"
 #include "quiche/quic/moqt/moqt_session.h"
 #include "quiche/quic/moqt/moqt_session_callbacks.h"
+#include "quiche/quic/moqt/moqt_session_interface.h"
 #include "quiche/quic/platform/api/quic_socket_address.h"
 #include "quiche/quic/tools/quic_default_client.h"
 #include "quiche/common/platform/api/quiche_export.h"
@@ -27,7 +28,8 @@
   MoqtClient(quic::QuicSocketAddress peer_address,
              const quic::QuicServerId& server_id,
              std::unique_ptr<quic::ProofVerifier> proof_verifier,
-             quic::QuicEventLoop* event_loop);
+             quic::QuicEventLoop* event_loop,
+             MoqtSessionParameters parameters = MoqtSessionParameters());
 
   // Establishes the connection to the specified endpoint. The errors are
   // returned via the session termination callback.
@@ -40,6 +42,7 @@
   absl::Status ConnectInner(std::string path, MoqtSessionCallbacks& callbacks);
 
   quic::QuicDefaultClient spdy_client_;
+  MoqtSessionParameters parameters_;
   MoqtSession* session_ = nullptr;
 };
 
diff --git a/quiche/quic/moqt/tools/moqt_end_to_end_test.cc b/quiche/quic/moqt/tools/moqt_end_to_end_test.cc
index 49c60e0..f415176 100644
--- a/quiche/quic/moqt/tools/moqt_end_to_end_test.cc
+++ b/quiche/quic/moqt/tools/moqt_end_to_end_test.cc
@@ -20,6 +20,7 @@
 #include "quiche/quic/core/io/quic_event_loop.h"
 #include "quiche/quic/core/quic_server_id.h"
 #include "quiche/quic/moqt/moqt_session.h"
+#include "quiche/quic/moqt/test_tools/moqt_session_peer.h"
 #include "quiche/quic/moqt/tools/moqt_client.h"
 #include "quiche/quic/moqt/tools/moqt_server.h"
 #include "quiche/quic/platform/api/quic_ip_address.h"
@@ -127,5 +128,50 @@
   EXPECT_TRUE(success);
 }
 
+TEST_F(MoqtEndToEndTest, CustomParametersHandshake) {
+  MoqtSessionParameters server_parameters;
+  server_parameters.deliver_partial_objects = true;
+  MoqtSession* server_session = nullptr;
+  auto server_backend = [&](absl::string_view /*path*/)
+      -> absl::StatusOr<MoqtConfigureSessionCallback> {
+    return [&](MoqtSession* session) {
+      server_session = session;
+      session->callbacks().session_terminated_callback = [](absl::string_view) {
+      };
+    };
+  };
+  MoqtServer custom_server(
+      quic::test::crypto_test_utils::ProofSourceForTesting(),
+      std::move(server_backend), server_parameters);
+  quic::QuicIpAddress host = quic::TestLoopback();
+  QUICHE_CHECK_OK(custom_server.CreateUDPSocketAndListen(
+      quic::QuicSocketAddress(host, /*port=*/0)));
+  quic::QuicSocketAddress custom_server_address(host, custom_server.port());
+
+  MoqtSessionParameters client_parameters;
+  client_parameters.max_request_id = 200;
+  MoqtClient client(custom_server_address,
+                    quic::QuicServerId("test.example.com", 443),
+                    quic::test::crypto_test_utils::ProofVerifierForTesting(),
+                    custom_server.event_loop(), client_parameters);
+
+  MoqtSessionCallbacks callbacks;
+  bool established = false;
+  callbacks.session_established_callback = [&] { established = true; };
+  callbacks.session_terminated_callback = UnexpectedClose;
+  client.Connect("/test", std::move(callbacks));
+
+  bool success = quic::ProcessEventsUntil(custom_server.event_loop(), [&] {
+    return established && server_session != nullptr;
+  });
+  EXPECT_TRUE(success);
+  ASSERT_NE(client.session(), nullptr);
+  EXPECT_EQ(MoqtSessionPeer::GetParameters(client.session()).max_request_id,
+            200);
+  ASSERT_NE(server_session, nullptr);
+  EXPECT_TRUE(
+      MoqtSessionPeer::GetParameters(server_session).deliver_partial_objects);
+}
+
 }  // namespace
 }  // namespace moqt::test
diff --git a/quiche/quic/moqt/tools/moqt_server.cc b/quiche/quic/moqt/tools/moqt_server.cc
index d4e7052..3f46ffc 100644
--- a/quiche/quic/moqt/tools/moqt_server.cc
+++ b/quiche/quic/moqt/tools/moqt_server.cc
@@ -27,9 +27,9 @@
 #include "quiche/quic/core/quic_time.h"
 #include "quiche/quic/core/quic_types.h"
 #include "quiche/quic/core/quic_versions.h"
-#include "quiche/quic/moqt/moqt_messages.h"
 #include "quiche/quic/moqt/moqt_quic_config.h"
 #include "quiche/quic/moqt/moqt_session.h"
+#include "quiche/quic/moqt/moqt_session_interface.h"
 #include "quiche/quic/platform/api/quic_socket_address.h"
 #include "quiche/quic/tools/quic_simple_crypto_server_stream_helper.h"
 #include "quiche/common/quiche_status_utils.h"
@@ -47,8 +47,10 @@
 }
 
 quic::WebTransportHandlerFactoryCallback CreateWebTransportCallback(
-    MoqtIncomingSessionCallback callback, quic::QuicEventLoop* event_loop) {
-  return [event_loop = event_loop, callback = std::move(callback)](
+    MoqtIncomingSessionCallback callback, quic::QuicEventLoop* event_loop,
+    const MoqtSessionParameters& session_parameters) {
+  return [event_loop = event_loop, callback = std::move(callback),
+          parameters = session_parameters](
              webtransport::Session* session,
              const quic::WebTransportIncomingRequestDetails& details)
              -> absl::StatusOr<quic::WebTransportConnectResponse> {
@@ -58,8 +60,6 @@
     if (!configurator.ok()) {
       return configurator.status();
     }
-
-    MoqtSessionParameters parameters(quic::Perspective::IS_SERVER);
     auto moqt_session = std::make_unique<MoqtSession>(
         session, parameters, event_loop->CreateAlarmFactory());
     std::move (*configurator)(moqt_session.get());
@@ -72,7 +72,8 @@
 }  // namespace
 
 MoqtServer::MoqtServer(std::unique_ptr<quic::ProofSource> proof_source,
-                       MoqtIncomingSessionCallback callback)
+                       MoqtIncomingSessionCallback callback,
+                       MoqtSessionParameters session_parameters)
     : config_(GenerateQuicConfig()),
       crypto_config_(GenerateRandomTokenSecret(),
                      quic::QuicRandom::GetInstance(), std::move(proof_source),
@@ -86,9 +87,11 @@
                   std::make_unique<quic::QuicSimpleCryptoServerStreamHelper>(),
                   event_loop_->CreateAlarmFactory(),
                   quic::kQuicDefaultConnectionIdLength,
-                  connection_id_generator_) {
-  dispatcher_.parameters().handler_factory =
-      CreateWebTransportCallback(std::move(callback), event_loop_.get());
+                  connection_id_generator_),
+      session_parameters_(session_parameters) {
+  session_parameters_.perspective = quic::Perspective::IS_SERVER;
+  dispatcher_.parameters().handler_factory = CreateWebTransportCallback(
+      std::move(callback), event_loop_.get(), session_parameters_);
   dispatcher_.parameters().subprotocol_callback =
       +[](absl::Span<const absl::string_view> subprotocols) {
         return absl::c_find(subprotocols, kDefaultMoqtVersion) -
diff --git a/quiche/quic/moqt/tools/moqt_server.h b/quiche/quic/moqt/tools/moqt_server.h
index a59e489..2db8e9a 100644
--- a/quiche/quic/moqt/tools/moqt_server.h
+++ b/quiche/quic/moqt/tools/moqt_server.h
@@ -21,6 +21,7 @@
 #include "quiche/quic/core/quic_config.h"
 #include "quiche/quic/core/quic_version_manager.h"
 #include "quiche/quic/moqt/moqt_session.h"
+#include "quiche/quic/moqt/moqt_session_interface.h"
 #include "quiche/quic/platform/api/quic_socket_address.h"
 #include "quiche/common/platform/api/quiche_export.h"
 #include "quiche/common/quiche_callbacks.h"
@@ -43,8 +44,10 @@
 // A simple MoQT server.
 class QUICHE_EXPORT MoqtServer {
  public:
-  explicit MoqtServer(std::unique_ptr<quic::ProofSource> proof_source,
-                      MoqtIncomingSessionCallback callback);
+  explicit MoqtServer(
+      std::unique_ptr<quic::ProofSource> proof_source,
+      MoqtIncomingSessionCallback callback,
+      MoqtSessionParameters session_parameters = MoqtSessionParameters());
 
   MoqtServer(const MoqtServer&) = delete;
   MoqtServer(MoqtServer&&) = delete;
@@ -65,6 +68,7 @@
   quic::DeterministicConnectionIdGenerator connection_id_generator_;
   std::unique_ptr<quic::QuicEventLoop> event_loop_;
   quic::WebTransportOnlyDispatcher dispatcher_;
+  MoqtSessionParameters session_parameters_;
 
   quic::OwnedSocketFd fd_;
   std::unique_ptr<quic::QuicServerIoHarness> io_;
diff --git a/quiche/quic/moqt/tools/moqt_server_test.cc b/quiche/quic/moqt/tools/moqt_server_test.cc
index 06fd560..079b554 100644
--- a/quiche/quic/moqt/tools/moqt_server_test.cc
+++ b/quiche/quic/moqt/tools/moqt_server_test.cc
@@ -6,19 +6,19 @@
 
 #include <utility>
 
-#include "absl/base/nullability.h"
 #include "absl/memory/memory.h"
 #include "absl/status/statusor.h"
 #include "absl/strings/string_view.h"
 #include "quiche/quic/core/http/web_transport_only_server_session.h"
 #include "quiche/quic/core/quic_alarm.h"
 #include "quiche/quic/core/quic_time.h"
+#include "quiche/quic/core/quic_types.h"
 #include "quiche/quic/moqt/moqt_session.h"
+#include "quiche/quic/moqt/moqt_session_interface.h"
 #include "quiche/quic/moqt/test_tools/moqt_session_peer.h"
 #include "quiche/quic/platform/api/quic_socket_address.h"
 #include "quiche/quic/platform/api/quic_test.h"
 #include "quiche/quic/test_tools/crypto_test_utils.h"
-#include "quiche/quic/tools/web_transport_only_backend.h"
 #include "quiche/common/http/http_header_block.h"
 #include "quiche/common/quiche_ip_address.h"
 #include "quiche/common/test_tools/quiche_test_utils.h"
@@ -79,4 +79,28 @@
   EXPECT_TRUE(alarm->IsSet());
 }
 
+TEST_F(MoqtServerTest, CustomSessionParameters) {
+  MoqtSessionParameters parameters;
+  parameters.deliver_partial_objects = true;
+  parameters.perspective = quic::Perspective::IS_CLIENT;
+  MoqtServer server(
+      quic::test::crypto_test_utils::ProofSourceForTesting(),
+      [&](absl::string_view /*path*/) {
+        return [&](MoqtSession* session) { session_ = session; };
+      },
+      parameters);
+  quiche::HttpHeaderBlock headers;
+  headers.AppendValueOrAddHeader(":path", "/foo");
+  absl::StatusOr<quic::WebTransportConnectResponse> response =
+      MoqtServerPeer::CallHandlerFactory(
+          server, &mock_session_,
+          quic::WebTransportIncomingRequestDetails{.headers =
+                                                       std::move(headers)});
+  QUICHE_EXPECT_OK(response.status());
+  ASSERT_NE(session_, nullptr);
+  EXPECT_TRUE(MoqtSessionPeer::GetParameters(session_).deliver_partial_objects);
+  EXPECT_EQ(MoqtSessionPeer::GetParameters(session_).perspective,
+            quic::Perspective::IS_SERVER);
+}
+
 }  // namespace moqt::test