https://github.com/ashgti updated https://github.com/llvm/llvm-project/pull/158357
>From 07a8a62569a41c881f721b2800086eb559da6fa8 Mon Sep 17 00:00:00 2001 From: John Harrison <harj...@google.com> Date: Fri, 12 Sep 2025 13:06:30 -0700 Subject: [PATCH] [lldb-mcp] Fix servers accepting more than one client. This fixes an issue where the MCP server would stop the main loop after the first client disconnects. This moves the MainLoop out of the Server instance and lifts the server up into the ProtocolServerMCP object instead. This allows us to register the client with the main loop used to accept and process requests. --- lldb/include/lldb/Host/JSONTransport.h | 19 +++-- lldb/include/lldb/Protocol/MCP/Server.h | 17 ++--- .../Protocol/MCP/ProtocolServerMCP.cpp | 24 +++++-- .../Plugins/Protocol/MCP/ProtocolServerMCP.h | 17 ++++- lldb/source/Protocol/MCP/Server.cpp | 43 ++++-------- lldb/unittests/Host/JSONTransportTest.cpp | 2 +- .../Protocol/ProtocolMCPServerTest.cpp | 70 ++++++++++++------- 7 files changed, 110 insertions(+), 82 deletions(-) diff --git a/lldb/include/lldb/Host/JSONTransport.h b/lldb/include/lldb/Host/JSONTransport.h index 210f33edace6e..c73021d204258 100644 --- a/lldb/include/lldb/Host/JSONTransport.h +++ b/lldb/include/lldb/Host/JSONTransport.h @@ -100,22 +100,21 @@ template <typename Req, typename Resp, typename Evt> class Transport { virtual llvm::Expected<MainLoop::ReadHandleUP> RegisterMessageHandler(MainLoop &loop, MessageHandler &handler) = 0; - // FIXME: Refactor mcp::Server to not directly access log on the transport. - // protected: +protected: template <typename... Ts> inline auto Logv(const char *Fmt, Ts &&...Vals) { Log(llvm::formatv(Fmt, std::forward<Ts>(Vals)...).str()); } virtual void Log(llvm::StringRef message) = 0; }; -/// A JSONTransport will encode and decode messages using JSON. +/// An IOTransport sends and receives messages using an IOObject. template <typename Req, typename Resp, typename Evt> -class JSONTransport : public Transport<Req, Resp, Evt> { +class IOTransport : public Transport<Req, Resp, Evt> { public: using Transport<Req, Resp, Evt>::Transport; using MessageHandler = typename Transport<Req, Resp, Evt>::MessageHandler; - JSONTransport(lldb::IOObjectSP in, lldb::IOObjectSP out) + IOTransport(lldb::IOObjectSP in, lldb::IOObjectSP out) : m_in(in), m_out(out) {} llvm::Error Send(const Evt &evt) override { return Write(evt); } @@ -127,7 +126,7 @@ class JSONTransport : public Transport<Req, Resp, Evt> { Status status; MainLoop::ReadHandleUP read_handle = loop.RegisterReadObject( m_in, - std::bind(&JSONTransport::OnRead, this, std::placeholders::_1, + std::bind(&IOTransport::OnRead, this, std::placeholders::_1, std::ref(handler)), status); if (status.Fail()) { @@ -203,9 +202,9 @@ class JSONTransport : public Transport<Req, Resp, Evt> { /// A transport class for JSON with a HTTP header. template <typename Req, typename Resp, typename Evt> -class HTTPDelimitedJSONTransport : public JSONTransport<Req, Resp, Evt> { +class HTTPDelimitedJSONTransport : public IOTransport<Req, Resp, Evt> { public: - using JSONTransport<Req, Resp, Evt>::JSONTransport; + using IOTransport<Req, Resp, Evt>::IOTransport; protected: /// Encodes messages based on @@ -270,9 +269,9 @@ class HTTPDelimitedJSONTransport : public JSONTransport<Req, Resp, Evt> { /// A transport class for JSON RPC. template <typename Req, typename Resp, typename Evt> -class JSONRPCTransport : public JSONTransport<Req, Resp, Evt> { +class JSONRPCTransport : public IOTransport<Req, Resp, Evt> { public: - using JSONTransport<Req, Resp, Evt>::JSONTransport; + using IOTransport<Req, Resp, Evt>::IOTransport; protected: std::string Encode(const llvm::json::Value &message) override { diff --git a/lldb/include/lldb/Protocol/MCP/Server.h b/lldb/include/lldb/Protocol/MCP/Server.h index b674d58159550..da8fe9c38dc7f 100644 --- a/lldb/include/lldb/Protocol/MCP/Server.h +++ b/lldb/include/lldb/Protocol/MCP/Server.h @@ -29,10 +29,11 @@ namespace lldb_protocol::mcp { class Server : public MCPTransport::MessageHandler { + using ClosedCallback = llvm::unique_function<void()>; + public: - Server(std::string name, std::string version, - std::unique_ptr<MCPTransport> transport_up, - lldb_private::MainLoop &loop); + Server(std::string name, std::string version, MCPTransport &client, + LogCallback log_callback = {}, ClosedCallback closed_callback = {}); ~Server() = default; using NotificationHandler = std::function<void(const Notification &)>; @@ -42,8 +43,6 @@ class Server : public MCPTransport::MessageHandler { void AddNotificationHandler(llvm::StringRef method, NotificationHandler handler); - llvm::Error Run(); - protected: ServerCapabilities GetCapabilities(); @@ -73,14 +72,16 @@ class Server : public MCPTransport::MessageHandler { void OnError(llvm::Error) override; void OnClosed() override; - void TerminateLoop(); +protected: + void Log(llvm::StringRef); private: const std::string m_name; const std::string m_version; - std::unique_ptr<MCPTransport> m_transport_up; - lldb_private::MainLoop &m_loop; + MCPTransport &m_client; + LogCallback m_log_callback; + ClosedCallback m_closed_callback; llvm::StringMap<std::unique_ptr<Tool>> m_tools; std::vector<std::unique_ptr<ResourceProvider>> m_resource_providers; diff --git a/lldb/source/Plugins/Protocol/MCP/ProtocolServerMCP.cpp b/lldb/source/Plugins/Protocol/MCP/ProtocolServerMCP.cpp index dc18c8e06803a..3af43a1ea443c 100644 --- a/lldb/source/Plugins/Protocol/MCP/ProtocolServerMCP.cpp +++ b/lldb/source/Plugins/Protocol/MCP/ProtocolServerMCP.cpp @@ -66,7 +66,7 @@ void ProtocolServerMCP::Extend(lldb_protocol::mcp::Server &server) const { void ProtocolServerMCP::AcceptCallback(std::unique_ptr<Socket> socket) { Log *log = GetLog(LLDBLog::Host); - std::string client_name = llvm::formatv("client_{0}", m_instances.size() + 1); + std::string client_name = llvm::formatv("client_{0}", ++m_client_count); LLDB_LOG(log, "New MCP client connected: {0}", client_name); lldb::IOObjectSP io_sp = std::move(socket); @@ -74,16 +74,26 @@ void ProtocolServerMCP::AcceptCallback(std::unique_ptr<Socket> socket) { io_sp, io_sp, [client_name](llvm::StringRef message) { LLDB_LOG(GetLog(LLDBLog::Host), "{0}: {1}", client_name, message); }); + MCPTransport *transport_ptr = transport_up.get(); auto instance_up = std::make_unique<lldb_protocol::mcp::Server>( - std::string(kName), std::string(kVersion), std::move(transport_up), - m_loop); + std::string(kName), std::string(kVersion), *transport_up, + /*log_callback=*/ + [client_name](llvm::StringRef message) { + LLDB_LOG(GetLog(LLDBLog::Host), "{0} Server: {1}", client_name, + message); + }, + /*closed_callback=*/ + [this, transport_ptr]() { m_instances.erase(transport_ptr); }); Extend(*instance_up); - llvm::Error error = instance_up->Run(); - if (error) { - LLDB_LOG_ERROR(log, std::move(error), "Failed to run MCP server: {0}"); + llvm::Expected<MainLoop::ReadHandleUP> handle = + transport_up->RegisterMessageHandler(m_loop, *instance_up); + if (!handle) { + LLDB_LOG_ERROR(log, handle.takeError(), "Failed to run MCP server: {0}"); return; } - m_instances.push_back(std::move(instance_up)); + m_instances[transport_ptr] = + std::make_tuple<ServerUP, ReadHandleUP, TransportUP>( + std::move(instance_up), std::move(*handle), std::move(transport_up)); } llvm::Error ProtocolServerMCP::Start(ProtocolServer::Connection connection) { diff --git a/lldb/source/Plugins/Protocol/MCP/ProtocolServerMCP.h b/lldb/source/Plugins/Protocol/MCP/ProtocolServerMCP.h index 0251664a2acc4..69ef490c98679 100644 --- a/lldb/source/Plugins/Protocol/MCP/ProtocolServerMCP.h +++ b/lldb/source/Plugins/Protocol/MCP/ProtocolServerMCP.h @@ -12,13 +12,21 @@ #include "lldb/Core/ProtocolServer.h" #include "lldb/Host/MainLoop.h" #include "lldb/Host/Socket.h" -#include "lldb/Protocol/MCP/Protocol.h" #include "lldb/Protocol/MCP/Server.h" +#include "lldb/Protocol/MCP/Transport.h" +#include <map> +#include <memory> #include <thread> +#include <tuple> +#include <vector> namespace lldb_private::mcp { class ProtocolServerMCP : public ProtocolServer { + using ReadHandleUP = MainLoopBase::ReadHandleUP; + using TransportUP = std::unique_ptr<lldb_protocol::mcp::MCPTransport>; + using ServerUP = std::unique_ptr<lldb_protocol::mcp::Server>; + public: ProtocolServerMCP(); virtual ~ProtocolServerMCP() override; @@ -52,11 +60,14 @@ class ProtocolServerMCP : public ProtocolServer { lldb_private::MainLoop m_loop; std::thread m_loop_thread; std::mutex m_mutex; + uint32_t m_client_count = 0; std::unique_ptr<Socket> m_listener; - std::vector<MainLoopBase::ReadHandleUP> m_listen_handlers; - std::vector<std::unique_ptr<lldb_protocol::mcp::Server>> m_instances; + std::vector<ReadHandleUP> m_listen_handlers; + std::map<lldb_protocol::mcp::MCPTransport *, + std::tuple<ServerUP, ReadHandleUP, TransportUP>> + m_instances; }; } // namespace lldb_private::mcp diff --git a/lldb/source/Protocol/MCP/Server.cpp b/lldb/source/Protocol/MCP/Server.cpp index f3489c620832f..45d0a5fad70aa 100644 --- a/lldb/source/Protocol/MCP/Server.cpp +++ b/lldb/source/Protocol/MCP/Server.cpp @@ -10,7 +10,6 @@ #include "lldb/Host/File.h" #include "lldb/Host/FileSystem.h" #include "lldb/Host/HostInfo.h" -#include "lldb/Host/JSONTransport.h" #include "lldb/Protocol/MCP/MCPError.h" #include "lldb/Protocol/MCP/Protocol.h" #include "llvm/ADT/SmallString.h" @@ -111,11 +110,11 @@ Expected<std::vector<ServerInfo>> ServerInfo::Load() { return infos; } -Server::Server(std::string name, std::string version, - std::unique_ptr<MCPTransport> transport_up, - lldb_private::MainLoop &loop) - : m_name(std::move(name)), m_version(std::move(version)), - m_transport_up(std::move(transport_up)), m_loop(loop) { +Server::Server(std::string name, std::string version, MCPTransport &client, + LogCallback log_callback, ClosedCallback closed_callback) + : m_name(std::move(name)), m_version(std::move(version)), m_client(client), + m_log_callback(std::move(log_callback)), + m_closed_callback(std::move(closed_callback)) { AddRequestHandlers(); } @@ -289,22 +288,15 @@ ServerCapabilities Server::GetCapabilities() { return capabilities; } -llvm::Error Server::Run() { - auto handle = m_transport_up->RegisterMessageHandler(m_loop, *this); - if (!handle) - return handle.takeError(); - - lldb_private::Status status = m_loop.Run(); - if (status.Fail()) - return status.takeError(); - - return llvm::Error::success(); +void Server::Log(llvm::StringRef message) { + if (m_log_callback) + m_log_callback(message); } void Server::Received(const Request &request) { auto SendResponse = [this](const Response &response) { - if (llvm::Error error = m_transport_up->Send(response)) - m_transport_up->Log(llvm::toString(std::move(error))); + if (llvm::Error error = m_client.Send(response)) + Log(llvm::toString(std::move(error))); }; llvm::Expected<Response> response = Handle(request); @@ -326,7 +318,7 @@ void Server::Received(const Request &request) { } void Server::Received(const Response &response) { - m_transport_up->Log("unexpected MCP message: response"); + Log("unexpected MCP message: response"); } void Server::Received(const Notification ¬ification) { @@ -334,16 +326,11 @@ void Server::Received(const Notification ¬ification) { } void Server::OnError(llvm::Error error) { - m_transport_up->Log(llvm::toString(std::move(error))); - TerminateLoop(); + Log(llvm::toString(std::move(error))); } void Server::OnClosed() { - m_transport_up->Log("EOF"); - TerminateLoop(); -} - -void Server::TerminateLoop() { - m_loop.AddPendingCallback( - [](lldb_private::MainLoopBase &loop) { loop.RequestTermination(); }); + Log("EOF"); + if (m_closed_callback) + m_closed_callback(); } diff --git a/lldb/unittests/Host/JSONTransportTest.cpp b/lldb/unittests/Host/JSONTransportTest.cpp index 445674f402252..3a36bf21f07ff 100644 --- a/lldb/unittests/Host/JSONTransportTest.cpp +++ b/lldb/unittests/Host/JSONTransportTest.cpp @@ -413,7 +413,7 @@ TEST_F(JSONRPCTransportTest, ReadAcrossMultipleChunks) { // Use a string longer than the chunk size to ensure we split the message // across the chunk boundary. std::string long_str = - std::string(JSONTransport<Req, Resp, Evt>::kReadBufferSize * 2, 'x'); + std::string(IOTransport<Req, Resp, Evt>::kReadBufferSize * 2, 'x'); Write(Req{long_str}); EXPECT_CALL(message_handler, Received(Req{long_str})); ASSERT_THAT_ERROR(Run(), Succeeded()); diff --git a/lldb/unittests/Protocol/ProtocolMCPServerTest.cpp b/lldb/unittests/Protocol/ProtocolMCPServerTest.cpp index f686255c6d41d..f3ca4cfc01788 100644 --- a/lldb/unittests/Protocol/ProtocolMCPServerTest.cpp +++ b/lldb/unittests/Protocol/ProtocolMCPServerTest.cpp @@ -122,53 +122,73 @@ class ProtocolServerMCPTest : public PipePairTest { public: SubsystemRAII<FileSystem, HostInfo, Socket> subsystems; - std::unique_ptr<lldb_protocol::mcp::Transport> transport_up; - std::unique_ptr<TestServer> server_up; MainLoop loop; + + std::unique_ptr<lldb_protocol::mcp::Transport> from_client; + std::unique_ptr<lldb_protocol::mcp::Transport> to_client; + MainLoopBase::ReadHandleUP handles[2]; + + std::unique_ptr<TestServer> server_up; MockMessageHandler<Request, Response, Notification> message_handler; llvm::Error Write(llvm::StringRef message) { llvm::Expected<json::Value> value = json::parse(message); if (!value) return value.takeError(); - return transport_up->Write(*value); + return from_client->Write(*value); } - llvm::Error Write(json::Value value) { return transport_up->Write(value); } + llvm::Error Write(json::Value value) { return from_client->Write(value); } /// Run the transport MainLoop and return any messages received. - llvm::Error - Run(std::chrono::milliseconds timeout = std::chrono::milliseconds(200)) { + llvm::Error Run() { loop.AddCallback([](MainLoopBase &loop) { loop.RequestTermination(); }, - timeout); - auto handle = transport_up->RegisterMessageHandler(loop, message_handler); - if (!handle) - return handle.takeError(); - - return server_up->Run(); + std::chrono::milliseconds(10)); + return loop.Run().takeError(); } void SetUp() override { PipePairTest::SetUp(); - transport_up = std::make_unique<lldb_protocol::mcp::Transport>( + from_client = std::make_unique<lldb_protocol::mcp::Transport>( std::make_shared<NativeFile>(input.GetReadFileDescriptor(), File::eOpenOptionReadOnly, NativeFile::Unowned), std::make_shared<NativeFile>(output.GetWriteFileDescriptor(), File::eOpenOptionWriteOnly, - NativeFile::Unowned)); - - server_up = std::make_unique<TestServer>( - "lldb-mcp", "0.1.0", - std::make_unique<lldb_protocol::mcp::Transport>( - std::make_shared<NativeFile>(output.GetReadFileDescriptor(), - File::eOpenOptionReadOnly, - NativeFile::Unowned), - std::make_shared<NativeFile>(input.GetWriteFileDescriptor(), - File::eOpenOptionWriteOnly, - NativeFile::Unowned)), - loop); + NativeFile::Unowned), + [](StringRef message) { + // Uncomment for debugging + // llvm::errs() << "from_client: " << message << '\n'; + }); + to_client = std::make_unique<lldb_protocol::mcp::Transport>( + std::make_shared<NativeFile>(output.GetReadFileDescriptor(), + File::eOpenOptionReadOnly, + NativeFile::Unowned), + std::make_shared<NativeFile>(input.GetWriteFileDescriptor(), + File::eOpenOptionWriteOnly, + NativeFile::Unowned), + [](StringRef message) { + // Uncomment for debugging + // llvm::errs() << "to_client: " << message << '\n'; + }); + + server_up = std::make_unique<TestServer>("lldb-mcp", "0.1.0", *to_client, + [](StringRef message) { + // Uncomment for debugging + // llvm::errs() << "server: " << + // message << '\n'; + }); + + auto maybe_from_client_handle = + from_client->RegisterMessageHandler(loop, message_handler); + EXPECT_THAT_EXPECTED(maybe_from_client_handle, Succeeded()); + handles[0] = std::move(*maybe_from_client_handle); + + auto maybe_to_client_handle = + to_client->RegisterMessageHandler(loop, *server_up); + EXPECT_THAT_EXPECTED(maybe_to_client_handle, Succeeded()); + handles[1] = std::move(*maybe_to_client_handle); } }; _______________________________________________ lldb-commits mailing list lldb-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/lldb-commits