https://github.com/ashgti updated https://github.com/llvm/llvm-project/pull/153121
>From d5f998c50d3188fc9eeb94ce80e4d4dfd15d6790 Mon Sep 17 00:00:00 2001 From: John Harrison <harj...@google.com> Date: Thu, 7 Aug 2025 08:56:11 -0700 Subject: [PATCH 1/2] [lldb] Refactoring JSONTransport into an abstract RPC Message Handler and transport layer. This abstracts the base Transport handler to have a MessageHandler component and allows us to generalize both JSON-RPC 2.0 for MCP (or an LSP) and DAP format. This should allow us to create clearly defined clients and servers for protocols, both for testing and for RPC between the lldb instances and an lldb-mcp multiplexer. This basic model is inspired by the clangd/Transport.h file and the mlir/lsp-server-support/Transport.h that are both used for LSP servers within the llvm project. --- lldb/include/lldb/Host/JSONTransport.h | 324 ++++++++++----- lldb/source/Host/common/JSONTransport.cpp | 116 +----- lldb/source/Protocol/MCP/Protocol.cpp | 1 + lldb/tools/lldb-dap/DAP.cpp | 177 ++++---- lldb/tools/lldb-dap/DAP.h | 25 +- lldb/tools/lldb-dap/Protocol/ProtocolBase.h | 4 + lldb/tools/lldb-dap/Transport.cpp | 5 +- lldb/tools/lldb-dap/Transport.h | 5 +- lldb/tools/lldb-dap/tool/lldb-dap.cpp | 21 +- lldb/unittests/DAP/DAPTest.cpp | 16 +- lldb/unittests/DAP/Handler/DisconnectTest.cpp | 20 +- lldb/unittests/DAP/TestBase.cpp | 48 +-- lldb/unittests/DAP/TestBase.h | 91 +++-- lldb/unittests/Host/JSONTransportTest.cpp | 382 +++++++++++------- .../ProtocolServer/ProtocolMCPServerTest.cpp | 174 ++++---- 15 files changed, 765 insertions(+), 644 deletions(-) diff --git a/lldb/include/lldb/Host/JSONTransport.h b/lldb/include/lldb/Host/JSONTransport.h index 72f4404c92887..18126f599c380 100644 --- a/lldb/include/lldb/Host/JSONTransport.h +++ b/lldb/include/lldb/Host/JSONTransport.h @@ -13,29 +13,25 @@ #ifndef LLDB_HOST_JSONTRANSPORT_H #define LLDB_HOST_JSONTRANSPORT_H +#include "lldb/Host/MainLoop.h" #include "lldb/Host/MainLoopBase.h" #include "lldb/Utility/IOObject.h" #include "lldb/Utility/Status.h" #include 
"lldb/lldb-forward.h" +#include "llvm/ADT/StringExtras.h" #include "llvm/ADT/StringRef.h" #include "llvm/Support/Error.h" +#include "llvm/Support/ErrorHandling.h" #include "llvm/Support/FormatVariadic.h" #include "llvm/Support/JSON.h" +#include "llvm/Support/raw_ostream.h" #include <string> #include <system_error> +#include <variant> #include <vector> namespace lldb_private { -class TransportEOFError : public llvm::ErrorInfo<TransportEOFError> { -public: - static char ID; - - TransportEOFError() = default; - void log(llvm::raw_ostream &OS) const override; - std::error_code convertToErrorCode() const override; -}; - class TransportUnhandledContentsError : public llvm::ErrorInfo<TransportUnhandledContentsError> { public: @@ -54,112 +50,220 @@ class TransportUnhandledContentsError std::string m_unhandled_contents; }; -class TransportInvalidError : public llvm::ErrorInfo<TransportInvalidError> { +/// A transport is responsible for maintaining the connection to a client +/// application, and reading/writing structured messages to it. +/// +/// Transports have limited thread safety requirements: +/// - Messages will not be sent concurrently. +/// - Messages MAY be sent while Run() is reading, or its callback is active. +template <typename Req, typename Resp, typename Evt> class Transport { public: - static char ID; - - TransportInvalidError() = default; + using Message = std::variant<Req, Resp, Evt>; + + virtual ~Transport() = default; + + // Called by transport to send outgoing messages. + virtual void Event(const Evt &) = 0; + virtual void Request(const Req &) = 0; + virtual void Response(const Resp &) = 0; + + /// Implemented to handle incoming messages. (See Run() below). + class MessageHandler { + public: + virtual ~MessageHandler() = default; + virtual void OnEvent(const Evt &) = 0; + virtual void OnRequest(const Req &) = 0; + virtual void OnResponse(const Resp &) = 0; + }; + + /// Called by server or client to receive messages from the connection. 
+ /// The transport should in turn invoke the handler to process messages. + /// The MainLoop is used to handle reading from the incoming connection and + /// will run until the loop is terminated. + virtual llvm::Error Run(MainLoop &, MessageHandler &) = 0; - void log(llvm::raw_ostream &OS) const override; - std::error_code convertToErrorCode() const override; +protected: + template <typename... Ts> inline auto Logv(const char *Fmt, Ts &&...Vals) { + Log(llvm::formatv(Fmt, std::forward<Ts>(Vals)...).str()); + } + virtual void Log(llvm::StringRef message) = 0; }; -/// A transport class that uses JSON for communication. -class JSONTransport { +/// A JSONTransport will encode and decode messages using JSON. +template <typename Req, typename Resp, typename Evt> +class JSONTransport : public Transport<Req, Resp, Evt> { public: - using ReadHandleUP = MainLoopBase::ReadHandleUP; - template <typename T> - using Callback = std::function<void(MainLoopBase &, const llvm::Expected<T>)>; - - JSONTransport(lldb::IOObjectSP input, lldb::IOObjectSP output); - virtual ~JSONTransport() = default; - - /// Transport is not copyable. - /// @{ - JSONTransport(const JSONTransport &rhs) = delete; - void operator=(const JSONTransport &rhs) = delete; - /// @} - - /// Writes a message to the output stream. - template <typename T> llvm::Error Write(const T &t) { - const std::string message = llvm::formatv("{0}", toJSON(t)).str(); - return WriteImpl(message); + using Transport<Req, Resp, Evt>::Transport; + + JSONTransport(lldb::IOObjectSP in, lldb::IOObjectSP out) + : m_in(in), m_out(out) {} + + void Event(const Evt &evt) override { Write(evt); } + void Request(const Req &req) override { Write(req); } + void Response(const Resp &resp) override { Write(resp); } + + /// Run registers the transport with the given MainLoop and handles any + /// incoming messages using the given MessageHandler. 
+ llvm::Error + Run(MainLoop &loop, + typename Transport<Req, Resp, Evt>::MessageHandler &handler) override { + llvm::Error error = llvm::Error::success(); + Status status; + auto read_handle = loop.RegisterReadObject( + m_in, + std::bind(&JSONTransport::OnRead, this, &error, std::placeholders::_1, + std::ref(handler)), + status); + if (status.Fail()) { + // This error is only set if the read object handler is invoked, mark it + // as consumed if registration of the handler failed. + llvm::consumeError(std::move(error)); + return status.takeError(); + } + + status = loop.Run(); + if (status.Fail()) + return status.takeError(); + return error; } - /// Registers the transport with the MainLoop. - template <typename T> - llvm::Expected<ReadHandleUP> RegisterReadObject(MainLoopBase &loop, - Callback<T> read_cb) { - Status error; - ReadHandleUP handle = loop.RegisterReadObject( - m_input, - [read_cb, this](MainLoopBase &loop) { - char buf[kReadBufferSize]; - size_t num_bytes = sizeof(buf); - if (llvm::Error error = m_input->Read(buf, num_bytes).takeError()) { - read_cb(loop, std::move(error)); - return; - } - if (num_bytes) - m_buffer.append(std::string(buf, num_bytes)); - - // If the buffer has contents, try parsing any pending messages. - if (!m_buffer.empty()) { - llvm::Expected<std::vector<std::string>> messages = Parse(); - if (llvm::Error error = messages.takeError()) { - read_cb(loop, std::move(error)); - return; - } - - for (const auto &message : *messages) - if constexpr (std::is_same<T, std::string>::value) - read_cb(loop, message); - else - read_cb(loop, llvm::json::parse<T>(message)); - } - - // On EOF, notify the callback after the remaining messages were - // handled. 
- if (num_bytes == 0) { - if (m_buffer.empty()) - read_cb(loop, llvm::make_error<TransportEOFError>()); - else - read_cb(loop, llvm::make_error<TransportUnhandledContentsError>( - std::string(m_buffer))); - } - }, - error); - if (error.Fail()) - return error.takeError(); - return handle; - } + /// Public for testing purposes, otherwise this should be an implementation + /// detail. + static constexpr size_t kReadBufferSize = 1024; protected: - template <typename... Ts> inline auto Logv(const char *Fmt, Ts &&...Vals) { - Log(llvm::formatv(Fmt, std::forward<Ts>(Vals)...).str()); + virtual llvm::Expected<std::vector<std::string>> Parse() = 0; + virtual std::string Encode(const llvm::json::Value &message) = 0; + void Write(const llvm::json::Value &message) { + this->Logv("<-- {0}", message); + std::string output = Encode(message); + size_t bytes_written = output.size(); + Status status = m_out->Write(output.data(), bytes_written); + if (status.Fail()) { + this->Logv("writing failed: s{0}", status.AsCString()); + } } - virtual void Log(llvm::StringRef message); - virtual llvm::Error WriteImpl(const std::string &message) = 0; - virtual llvm::Expected<std::vector<std::string>> Parse() = 0; + llvm::SmallString<kReadBufferSize> m_buffer; - static constexpr size_t kReadBufferSize = 1024; +private: + void OnRead(llvm::Error *err, MainLoopBase &loop, + typename Transport<Req, Resp, Evt>::MessageHandler &handler) { + llvm::ErrorAsOutParameter ErrAsOutParam(err); + char buf[kReadBufferSize]; + size_t num_bytes = sizeof(buf); + if (Status status = m_in->Read(buf, num_bytes); status.Fail()) { + *err = status.takeError(); + loop.RequestTermination(); + return; + } + + if (num_bytes) + m_buffer.append(llvm::StringRef(buf, num_bytes)); + + // If the buffer has contents, try parsing any pending messages. 
+ if (!m_buffer.empty()) { + llvm::Expected<std::vector<std::string>> raw_messages = Parse(); + if (llvm::Error error = raw_messages.takeError()) { + *err = std::move(error); + loop.RequestTermination(); + return; + } + + for (const auto &raw_message : *raw_messages) { + auto message = + llvm::json::parse<typename Transport<Req, Resp, Evt>::Message>( + raw_message); + if (!message) { + *err = message.takeError(); + loop.RequestTermination(); + return; + } + + if (Evt *evt = std::get_if<Evt>(&*message)) { + handler.OnEvent(*evt); + } else if (Req *req = std::get_if<Req>(&*message)) { + handler.OnRequest(*req); + } else if (Resp *resp = std::get_if<Resp>(&*message)) { + handler.OnResponse(*resp); + } else { + llvm_unreachable("unknown message type"); + } + } + } + + if (num_bytes == 0) { + // If we're at EOF and we have unhandled contents in the buffer, return an + // error for the partial message. + if (m_buffer.empty()) + *err = llvm::Error::success(); + else + *err = llvm::make_error<TransportUnhandledContentsError>( + std::string(m_buffer)); + loop.RequestTermination(); + } + } - lldb::IOObjectSP m_input; - lldb::IOObjectSP m_output; - llvm::SmallString<kReadBufferSize> m_buffer; + lldb::IOObjectSP m_in; + lldb::IOObjectSP m_out; }; /// A transport class for JSON with a HTTP header. 
-class HTTPDelimitedJSONTransport : public JSONTransport { +template <typename Req, typename Resp, typename Evt> +class HTTPDelimitedJSONTransport : public JSONTransport<Req, Resp, Evt> { public: - HTTPDelimitedJSONTransport(lldb::IOObjectSP input, lldb::IOObjectSP output) - : JSONTransport(input, output) {} - virtual ~HTTPDelimitedJSONTransport() = default; + using JSONTransport<Req, Resp, Evt>::JSONTransport; protected: - llvm::Error WriteImpl(const std::string &message) override; - llvm::Expected<std::vector<std::string>> Parse() override; + /// Encodes messages based on + /// https://microsoft.github.io/debug-adapter-protocol/overview#base-protocol + std::string Encode(const llvm::json::Value &message) override { + std::string output; + std::string raw_message = llvm::formatv("{0}", message).str(); + llvm::raw_string_ostream OS(output); + OS << kHeaderContentLength << kHeaderFieldSeparator << ' ' + << std::to_string(raw_message.size()) << kEndOfHeader << raw_message; + return output; + } + + /// Parses messages based on + /// https://microsoft.github.io/debug-adapter-protocol/overview#base-protocol + llvm::Expected<std::vector<std::string>> Parse() override { + std::vector<std::string> messages; + llvm::StringRef buffer = this->m_buffer; + while (buffer.contains(kEndOfHeader)) { + auto [headers, rest] = buffer.split(kEndOfHeader); + size_t content_length = 0; + // HTTP Headers are formatted like `<field-name> ':' [<field-value>]`. + for (const auto &header : llvm::split(headers, kHeaderSeparator)) { + auto [key, value] = header.split(kHeaderFieldSeparator); + // 'Content-Length' is the only meaningful key at the moment. Others are + // ignored. + if (!key.equals_insensitive(kHeaderContentLength)) + continue; + + value = value.trim(); + if (!llvm::to_integer(value, content_length, 10)) + return llvm::createStringError(std::errc::invalid_argument, + "invalid content length: %s", + value.str().c_str()); + } + + // Check if we have enough data. 
+ if (content_length > rest.size()) + break; + + llvm::StringRef body = rest.take_front(content_length); + buffer = rest.drop_front(content_length); + messages.emplace_back(body.str()); + this->Logv("--> {0}", body); + } + + // Store the remainder of the buffer for the next read callback. + this->m_buffer = buffer.str(); + + return std::move(messages); + } static constexpr llvm::StringLiteral kHeaderContentLength = "Content-Length"; static constexpr llvm::StringLiteral kHeaderFieldSeparator = ":"; @@ -168,15 +272,31 @@ class HTTPDelimitedJSONTransport : public JSONTransport { }; /// A transport class for JSON RPC. -class JSONRPCTransport : public JSONTransport { +template <typename Req, typename Resp, typename Evt> +class JSONRPCTransport : public JSONTransport<Req, Resp, Evt> { public: - JSONRPCTransport(lldb::IOObjectSP input, lldb::IOObjectSP output) - : JSONTransport(input, output) {} - virtual ~JSONRPCTransport() = default; + using JSONTransport<Req, Resp, Evt>::JSONTransport; protected: - llvm::Error WriteImpl(const std::string &message) override; - llvm::Expected<std::vector<std::string>> Parse() override; + std::string Encode(const llvm::json::Value &message) override { + return llvm::formatv("{0}{1}", message, kMessageSeparator).str(); + } + + llvm::Expected<std::vector<std::string>> Parse() override { + std::vector<std::string> messages; + llvm::StringRef buf = this->m_buffer; + while (buf.contains(kMessageSeparator)) { + auto [raw_json, rest] = buf.split(kMessageSeparator); + buf = rest; + messages.emplace_back(raw_json.str()); + this->Logv("--> {0}", raw_json); + } + + // Store the remainder of the buffer for the next read callback. 
+ this->m_buffer = buf.str(); + + return messages; + } static constexpr llvm::StringLiteral kMessageSeparator = "\n"; }; diff --git a/lldb/source/Host/common/JSONTransport.cpp b/lldb/source/Host/common/JSONTransport.cpp index 5f0fb3ce562c3..c4b42eafc85d3 100644 --- a/lldb/source/Host/common/JSONTransport.cpp +++ b/lldb/source/Host/common/JSONTransport.cpp @@ -7,136 +7,26 @@ //===----------------------------------------------------------------------===// #include "lldb/Host/JSONTransport.h" -#include "lldb/Utility/LLDBLog.h" #include "lldb/Utility/Log.h" #include "lldb/Utility/Status.h" -#include "lldb/lldb-forward.h" #include "llvm/ADT/StringExtras.h" -#include "llvm/ADT/StringRef.h" -#include "llvm/Support/Error.h" #include "llvm/Support/raw_ostream.h" #include <string> -#include <utility> using namespace llvm; using namespace lldb; using namespace lldb_private; -void TransportEOFError::log(llvm::raw_ostream &OS) const { - OS << "transport EOF"; -} - -std::error_code TransportEOFError::convertToErrorCode() const { - return std::make_error_code(std::errc::io_error); -} +char TransportUnhandledContentsError::ID; TransportUnhandledContentsError::TransportUnhandledContentsError( std::string unhandled_contents) : m_unhandled_contents(unhandled_contents) {} void TransportUnhandledContentsError::log(llvm::raw_ostream &OS) const { - OS << "transport EOF with unhandled contents " << m_unhandled_contents; + OS << "transport EOF with unhandled contents: '" << m_unhandled_contents + << "'"; } std::error_code TransportUnhandledContentsError::convertToErrorCode() const { return std::make_error_code(std::errc::bad_message); } - -void TransportInvalidError::log(llvm::raw_ostream &OS) const { - OS << "transport IO object invalid"; -} -std::error_code TransportInvalidError::convertToErrorCode() const { - return std::make_error_code(std::errc::not_connected); -} - -JSONTransport::JSONTransport(IOObjectSP input, IOObjectSP output) - : m_input(std::move(input)), 
m_output(std::move(output)) {} - -void JSONTransport::Log(llvm::StringRef message) { - LLDB_LOG(GetLog(LLDBLog::Host), "{0}", message); -} - -// Parses messages based on -// https://microsoft.github.io/debug-adapter-protocol/overview#base-protocol -Expected<std::vector<std::string>> HTTPDelimitedJSONTransport::Parse() { - std::vector<std::string> messages; - StringRef buffer = m_buffer; - while (buffer.contains(kEndOfHeader)) { - auto [headers, rest] = buffer.split(kEndOfHeader); - size_t content_length = 0; - // HTTP Headers are formatted like `<field-name> ':' [<field-value>]`. - for (const auto &header : llvm::split(headers, kHeaderSeparator)) { - auto [key, value] = header.split(kHeaderFieldSeparator); - // 'Content-Length' is the only meaningful key at the moment. Others are - // ignored. - if (!key.equals_insensitive(kHeaderContentLength)) - continue; - - value = value.trim(); - if (!llvm::to_integer(value, content_length, 10)) - return createStringError(std::errc::invalid_argument, - "invalid content length: %s", - value.str().c_str()); - } - - // Check if we have enough data. - if (content_length > rest.size()) - break; - - StringRef body = rest.take_front(content_length); - buffer = rest.drop_front(content_length); - messages.emplace_back(body.str()); - Logv("--> {0}", body); - } - - // Store the remainder of the buffer for the next read callback. 
- m_buffer = buffer.str(); - - return std::move(messages); -} - -Error HTTPDelimitedJSONTransport::WriteImpl(const std::string &message) { - if (!m_output || !m_output->IsValid()) - return llvm::make_error<TransportInvalidError>(); - - Logv("<-- {0}", message); - - std::string Output; - raw_string_ostream OS(Output); - OS << kHeaderContentLength << kHeaderFieldSeparator << ' ' << message.length() - << kHeaderSeparator << kHeaderSeparator << message; - size_t num_bytes = Output.size(); - return m_output->Write(Output.data(), num_bytes).takeError(); -} - -Expected<std::vector<std::string>> JSONRPCTransport::Parse() { - std::vector<std::string> messages; - StringRef buf = m_buffer; - while (buf.contains(kMessageSeparator)) { - auto [raw_json, rest] = buf.split(kMessageSeparator); - buf = rest; - messages.emplace_back(raw_json.str()); - Logv("--> {0}", raw_json); - } - - // Store the remainder of the buffer for the next read callback. - m_buffer = buf.str(); - - return messages; -} - -Error JSONRPCTransport::WriteImpl(const std::string &message) { - if (!m_output || !m_output->IsValid()) - return llvm::make_error<TransportInvalidError>(); - - Logv("<-- {0}", message); - - std::string Output; - llvm::raw_string_ostream OS(Output); - OS << message << kMessageSeparator; - size_t num_bytes = Output.size(); - return m_output->Write(Output.data(), num_bytes).takeError(); -} - -char TransportEOFError::ID; -char TransportUnhandledContentsError::ID; -char TransportInvalidError::ID; diff --git a/lldb/source/Protocol/MCP/Protocol.cpp b/lldb/source/Protocol/MCP/Protocol.cpp index d9b11bd766686..65ddfaee70160 100644 --- a/lldb/source/Protocol/MCP/Protocol.cpp +++ b/lldb/source/Protocol/MCP/Protocol.cpp @@ -7,6 +7,7 @@ //===----------------------------------------------------------------------===// #include "lldb/Protocol/MCP/Protocol.h" +#include "llvm/Support/ErrorHandling.h" #include "llvm/Support/JSON.h" using namespace llvm; diff --git a/lldb/tools/lldb-dap/DAP.cpp 
b/lldb/tools/lldb-dap/DAP.cpp index ce910b1f60b85..a9a0fe75a35b7 100644 --- a/lldb/tools/lldb-dap/DAP.cpp +++ b/lldb/tools/lldb-dap/DAP.cpp @@ -121,11 +121,12 @@ static std::string capitalize(llvm::StringRef str) { llvm::StringRef DAP::debug_adapter_path = ""; DAP::DAP(Log *log, const ReplMode default_repl_mode, - std::vector<std::string> pre_init_commands, Transport &transport) + std::vector<std::string> pre_init_commands, + llvm::StringRef client_name, DAPTransport &transport, MainLoop &loop) : log(log), transport(transport), broadcaster("lldb-dap"), progress_event_reporter( [&](const ProgressEvent &event) { SendJSON(event.ToJSON()); }), - repl_mode(default_repl_mode) { + repl_mode(default_repl_mode), m_client_name(client_name), m_loop(loop) { configuration.preInitCommands = std::move(pre_init_commands); RegisterRequests(); } @@ -258,36 +259,33 @@ void DAP::SendJSON(const llvm::json::Value &json) { llvm::json::Path::Root root; if (!fromJSON(json, message, root)) { DAP_LOG_ERROR(log, root.getError(), "({1}) encoding failed: {0}", - transport.GetClientName()); + m_client_name); return; } Send(message); } void DAP::Send(const Message &message) { - // FIXME: After all the requests have migrated from LegacyRequestHandler > - // RequestHandler<> this should be handled in RequestHandler<>::operator(). - if (auto *resp = std::get_if<Response>(&message); - resp && debugger.InterruptRequested()) { - // Clear the interrupt request. - debugger.CancelInterruptRequest(); - - // If the debugger was interrupted, convert this response into a 'cancelled' - // response because we might have a partial result. 
- Response cancelled{/*request_seq=*/resp->request_seq, - /*command=*/resp->command, - /*success=*/false, - /*message=*/eResponseMessageCancelled, - /*body=*/std::nullopt}; - if (llvm::Error err = transport.Write(cancelled)) - DAP_LOG_ERROR(log, std::move(err), "({1}) write failed: {0}", - transport.GetClientName()); - return; + if (const protocol::Event *event = std::get_if<protocol::Event>(&message)) { + transport.Event(*event); + } else if (const Request *req = std::get_if<Request>(&message)) { + transport.Request(*req); + } else if (const Response *resp = std::get_if<Response>(&message)) { + // FIXME: After all the requests have migrated from LegacyRequestHandler > + // RequestHandler<> this should be handled in RequestHandler<>::operator(). + if (debugger.InterruptRequested()) + // If the debugger was interrupted, convert this response into a + // 'cancelled' response because we might have a partial result. + transport.Response(Response{/*request_seq=*/resp->request_seq, + /*command=*/resp->command, + /*success=*/false, + /*message=*/eResponseMessageCancelled, + /*body=*/std::nullopt}); + else + transport.Response(*resp); + } else { + llvm_unreachable("Unexpected message type"); } - - if (llvm::Error err = transport.Write(message)) - DAP_LOG_ERROR(log, std::move(err), "({1}) write failed: {0}", - transport.GetClientName()); } // "OutputEvent": { @@ -755,7 +753,6 @@ void DAP::RunTerminateCommands() { } lldb::SBTarget DAP::CreateTarget(lldb::SBError &error) { - // Grab the name of the program we need to debug and create a target using // the given program as an argument. Executable file can be a source of target // architecture and platform, if they differ from the host. 
Setting exe path // in launch info is useless because Target.Launch() will not change @@ -795,7 +792,7 @@ void DAP::SetTarget(const lldb::SBTarget target) { bool DAP::HandleObject(const Message &M) { TelemetryDispatcher dispatcher(&debugger); - dispatcher.Set("client_name", transport.GetClientName().str()); + dispatcher.Set("client_name", m_client_name.str()); if (const auto *req = std::get_if<Request>(&M)) { { std::lock_guard<std::mutex> guard(m_active_request_mutex); @@ -821,8 +818,8 @@ bool DAP::HandleObject(const Message &M) { dispatcher.Set("error", llvm::Twine("unhandled-command:" + req->command).str()); - DAP_LOG(log, "({0}) error: unhandled command '{1}'", - transport.GetClientName(), req->command); + DAP_LOG(log, "({0}) error: unhandled command '{1}'", m_client_name, + req->command); return false; // Fail } @@ -920,8 +917,6 @@ llvm::Error DAP::Disconnect(bool terminateDebuggee) { SendTerminatedEvent(); disconnecting = true; - m_loop.AddPendingCallback( - [](MainLoopBase &loop) { loop.RequestTermination(); }); return ToError(error); } @@ -938,88 +933,74 @@ void DAP::ClearCancelRequest(const CancelArguments &args) { } template <typename T> -static std::optional<T> getArgumentsIfRequest(const Message &pm, +static std::optional<T> getArgumentsIfRequest(const Request &req, llvm::StringLiteral command) { - auto *const req = std::get_if<Request>(&pm); - if (!req || req->command != command) + if (req.command != command) return std::nullopt; T args; llvm::json::Path::Root root; - if (!fromJSON(req->arguments, args, root)) + if (!fromJSON(req.arguments, args, root)) return std::nullopt; return args; } -Status DAP::TransportHandler() { - llvm::set_thread_name(transport.GetClientName() + ".transport_handler"); +void DAP::OnEvent(const protocol::Event &event) { + // no-op, no supported events from the client to the server as of DAP v1.68. +} - auto cleanup = llvm::make_scope_exit([&]() { - // Ensure we're marked as disconnecting when the reader exits. 
+void DAP::OnRequest(const protocol::Request &request) { + if (request.command == "disconnect") disconnecting = true; - m_queue_cv.notify_all(); - }); - Status status; - auto handle = transport.RegisterReadObject<protocol::Message>( - m_loop, - [&](MainLoopBase &loop, llvm::Expected<protocol::Message> message) { - if (message.errorIsA<TransportEOFError>()) { - llvm::consumeError(message.takeError()); - loop.RequestTermination(); - return; - } + const std::optional<CancelArguments> cancel_args = + getArgumentsIfRequest<CancelArguments>(request, "cancel"); + if (cancel_args) { + { + std::lock_guard<std::mutex> guard(m_cancelled_requests_mutex); + if (cancel_args->requestId) + m_cancelled_requests.insert(*cancel_args->requestId); + } - if (llvm::Error err = message.takeError()) { - status = Status::FromError(std::move(err)); - loop.RequestTermination(); - return; - } + // If a cancel is requested for the active request, make a best + // effort attempt to interrupt. + std::lock_guard<std::mutex> guard(m_active_request_mutex); + if (m_active_request && cancel_args->requestId == m_active_request->seq) { + DAP_LOG(log, "({0}) interrupting inflight request (command={1} seq={2})", + m_client_name, m_active_request->command, m_active_request->seq); + debugger.RequestInterrupt(); + } + } - if (const protocol::Request *req = - std::get_if<protocol::Request>(&*message); - req && req->arguments == "disconnect") - disconnecting = true; - - const std::optional<CancelArguments> cancel_args = - getArgumentsIfRequest<CancelArguments>(*message, "cancel"); - if (cancel_args) { - { - std::lock_guard<std::mutex> guard(m_cancelled_requests_mutex); - if (cancel_args->requestId) - m_cancelled_requests.insert(*cancel_args->requestId); - } + std::lock_guard<std::mutex> guard(m_queue_mutex); + DAP_LOG(log, "({0}) queued (command={1} seq={2})", m_client_name, + request.command, request.seq); + m_queue.push_back(request); + m_queue_cv.notify_one(); +} - // If a cancel is requested for the active 
request, make a best - // effort attempt to interrupt. - std::lock_guard<std::mutex> guard(m_active_request_mutex); - if (m_active_request && - cancel_args->requestId == m_active_request->seq) { - DAP_LOG(log, - "({0}) interrupting inflight request (command={1} seq={2})", - transport.GetClientName(), m_active_request->command, - m_active_request->seq); - debugger.RequestInterrupt(); - } - } +void DAP::OnResponse(const protocol::Response &response) { + std::lock_guard<std::mutex> guard(m_queue_mutex); + DAP_LOG(log, "({0}) queued (command={1} seq={2})", m_client_name, + response.command, response.request_seq); + m_queue.push_back(response); + m_queue_cv.notify_one(); +} - std::lock_guard<std::mutex> guard(m_queue_mutex); - m_queue.push_back(std::move(*message)); - m_queue_cv.notify_one(); - }); - if (auto err = handle.takeError()) - return Status::FromError(std::move(err)); - if (llvm::Error err = m_loop.Run().takeError()) - return Status::FromError(std::move(err)); - return status; +void DAP::TransportHandler(llvm::Error *error) { + llvm::ErrorAsOutParameter ErrAsOutParam(*error); + auto cleanup = llvm::make_scope_exit([&]() { + // Ensure we're marked as disconnecting when the reader exits. + disconnecting = true; + m_queue_cv.notify_all(); + }); + *error = transport.Run(m_loop, *this); } llvm::Error DAP::Loop() { - // Can't use \a std::future<llvm::Error> because it doesn't compile on - // Windows. 
- std::future<Status> queue_reader = - std::async(std::launch::async, &DAP::TransportHandler, this); + llvm::Error error = llvm::Error::success(); + auto thread = std::thread(std::bind(&DAP::TransportHandler, this, &error)); auto cleanup = llvm::make_scope_exit([&]() { out.Stop(); @@ -1045,7 +1026,11 @@ llvm::Error DAP::Loop() { "unhandled packet"); } - return queue_reader.get().takeError(); + m_loop.AddPendingCallback( + [](MainLoopBase &loop) { loop.RequestTermination(); }); + thread.join(); + + return error; } lldb::SBError DAP::WaitForProcessToStop(std::chrono::seconds seconds) { @@ -1284,7 +1269,7 @@ void DAP::ProgressEventThread() { // them prevent multiple threads from writing simultaneously so no locking // is required. void DAP::EventThread() { - llvm::set_thread_name(transport.GetClientName() + ".event_handler"); + llvm::set_thread_name("lldb.DAP.client." + m_client_name + ".event_handler"); lldb::SBEvent event; lldb::SBListener listener = debugger.GetListener(); broadcaster.AddListener(listener, eBroadcastBitStopEventThread); @@ -1316,7 +1301,7 @@ void DAP::EventThread() { if (llvm::Error err = SendThreadStoppedEvent(*this)) DAP_LOG_ERROR(log, std::move(err), "({1}) reporting thread stopped: {0}", - transport.GetClientName()); + m_client_name); } break; case lldb::eStateRunning: diff --git a/lldb/tools/lldb-dap/DAP.h b/lldb/tools/lldb-dap/DAP.h index b0e9fa9c16b75..628f97257d5f0 100644 --- a/lldb/tools/lldb-dap/DAP.h +++ b/lldb/tools/lldb-dap/DAP.h @@ -78,12 +78,16 @@ enum DAPBroadcasterBits { enum class ReplMode { Variable = 0, Command, Auto }; -struct DAP { +using DAPTransport = + lldb_private::Transport<protocol::Request, protocol::Response, + protocol::Event>; + +struct DAP final : private DAPTransport::MessageHandler { /// Path to the lldb-dap binary itself. 
static llvm::StringRef debug_adapter_path; Log *log; - Transport &transport; + DAPTransport &transport; lldb::SBFile in; OutputRedirector out; OutputRedirector err; @@ -177,8 +181,11 @@ struct DAP { /// allocated. /// \param[in] transport /// Transport for this debug session. + /// \param[in] loop + /// Main loop associated with this instance. DAP(Log *log, const ReplMode default_repl_mode, - std::vector<std::string> pre_init_commands, Transport &transport); + std::vector<std::string> pre_init_commands, llvm::StringRef client_name, + DAPTransport &transport, lldb_private::MainLoop &loop); ~DAP(); @@ -317,7 +324,7 @@ struct DAP { lldb::SBTarget CreateTarget(lldb::SBError &error); /// Set given target object as a current target for lldb-dap and start - /// listeing for its breakpoint events. + /// listening for its breakpoint events. void SetTarget(const lldb::SBTarget target); bool HandleObject(const protocol::Message &M); @@ -420,13 +427,17 @@ struct DAP { const std::optional<std::vector<protocol::SourceBreakpoint>> &breakpoints); + void OnEvent(const protocol::Event &) override; + void OnRequest(const protocol::Request &) override; + void OnResponse(const protocol::Response &) override; + private: std::vector<protocol::Breakpoint> SetSourceBreakpoints( const protocol::Source &source, const std::optional<std::vector<protocol::SourceBreakpoint>> &breakpoints, SourceBreakpointMap &existing_breakpoints); - lldb_private::Status TransportHandler(); + void TransportHandler(llvm::Error *); /// Registration of request handler. /// @{ @@ -446,6 +457,8 @@ struct DAP { std::thread progress_event_thread; /// @} + const llvm::StringRef m_client_name; + /// List of addresses mapped by sourceReference. std::vector<lldb::addr_t> m_source_references; std::mutex m_source_references_mutex; @@ -456,7 +469,7 @@ struct DAP { std::condition_variable m_queue_cv; // Loop for managing reading from the client. 
- lldb_private::MainLoop m_loop; + lldb_private::MainLoop &m_loop; std::mutex m_cancelled_requests_mutex; llvm::SmallSet<int64_t, 4> m_cancelled_requests; diff --git a/lldb/tools/lldb-dap/Protocol/ProtocolBase.h b/lldb/tools/lldb-dap/Protocol/ProtocolBase.h index 81496380d412f..0a9ef538a7398 100644 --- a/lldb/tools/lldb-dap/Protocol/ProtocolBase.h +++ b/lldb/tools/lldb-dap/Protocol/ProtocolBase.h @@ -52,6 +52,7 @@ struct Request { }; llvm::json::Value toJSON(const Request &); bool fromJSON(const llvm::json::Value &, Request &, llvm::json::Path); +bool operator==(const Request &, const Request &); /// A debug adapter initiated event. struct Event { @@ -63,6 +64,7 @@ struct Event { }; llvm::json::Value toJSON(const Event &); bool fromJSON(const llvm::json::Value &, Event &, llvm::json::Path); +bool operator==(const Event &, const Event &); enum ResponseMessage : unsigned { /// The request was cancelled @@ -101,6 +103,7 @@ struct Response { }; bool fromJSON(const llvm::json::Value &, Response &, llvm::json::Path); llvm::json::Value toJSON(const Response &); +bool operator==(const Response &, const Response &); /// A structured message object. Used to return errors from requests. 
struct ErrorMessage { @@ -140,6 +143,7 @@ llvm::json::Value toJSON(const ErrorMessage &); using Message = std::variant<Request, Response, Event>; bool fromJSON(const llvm::json::Value &, Message &, llvm::json::Path); llvm::json::Value toJSON(const Message &); +bool operator==(const Message &, const Message &); inline llvm::raw_ostream &operator<<(llvm::raw_ostream &OS, const Message &V) { OS << toJSON(V); diff --git a/lldb/tools/lldb-dap/Transport.cpp b/lldb/tools/lldb-dap/Transport.cpp index d602920da34e3..8f71f88cae1f7 100644 --- a/lldb/tools/lldb-dap/Transport.cpp +++ b/lldb/tools/lldb-dap/Transport.cpp @@ -14,7 +14,8 @@ using namespace llvm; using namespace lldb; using namespace lldb_private; -using namespace lldb_dap; + +namespace lldb_dap { Transport::Transport(llvm::StringRef client_name, lldb_dap::Log *log, lldb::IOObjectSP input, lldb::IOObjectSP output) @@ -24,3 +25,5 @@ Transport::Transport(llvm::StringRef client_name, lldb_dap::Log *log, void Transport::Log(llvm::StringRef message) { DAP_LOG(m_log, "({0}) {1}", m_client_name, message); } + +} // namespace lldb_dap diff --git a/lldb/tools/lldb-dap/Transport.h b/lldb/tools/lldb-dap/Transport.h index 9a7d8f424d40e..efeb0b9cd6c55 100644 --- a/lldb/tools/lldb-dap/Transport.h +++ b/lldb/tools/lldb-dap/Transport.h @@ -15,6 +15,7 @@ #define LLDB_TOOLS_LLDB_DAP_TRANSPORT_H #include "DAPForward.h" +#include "Protocol/ProtocolBase.h" #include "lldb/Host/JSONTransport.h" #include "lldb/lldb-forward.h" #include "llvm/ADT/StringRef.h" @@ -23,7 +24,9 @@ namespace lldb_dap { /// A transport class that performs the Debug Adapter Protocol communication /// with the client. 
-class Transport : public lldb_private::HTTPDelimitedJSONTransport { +class Transport final + : public lldb_private::HTTPDelimitedJSONTransport< + protocol::Request, protocol::Response, protocol::Event> { public: Transport(llvm::StringRef client_name, lldb_dap::Log *log, lldb::IOObjectSP input, lldb::IOObjectSP output); diff --git a/lldb/tools/lldb-dap/tool/lldb-dap.cpp b/lldb/tools/lldb-dap/tool/lldb-dap.cpp index 8bba4162aa7bf..c728b0af94c7c 100644 --- a/lldb/tools/lldb-dap/tool/lldb-dap.cpp +++ b/lldb/tools/lldb-dap/tool/lldb-dap.cpp @@ -284,7 +284,7 @@ serveConnection(const Socket::SocketProtocol &protocol, const std::string &name, }); std::condition_variable dap_sessions_condition; std::mutex dap_sessions_mutex; - std::map<IOObject *, DAP *> dap_sessions; + std::map<MainLoop *, DAP *> dap_sessions; unsigned int clientCount = 0; auto handle = listener->Accept(g_loop, [=, &dap_sessions_condition, &dap_sessions_mutex, &dap_sessions, @@ -300,8 +300,10 @@ serveConnection(const Socket::SocketProtocol &protocol, const std::string &name, std::thread client([=, &dap_sessions_condition, &dap_sessions_mutex, &dap_sessions]() { llvm::set_thread_name(client_name + ".runloop"); + MainLoop loop; Transport transport(client_name, log, io, io); - DAP dap(log, default_repl_mode, pre_init_commands, transport); + DAP dap(log, default_repl_mode, pre_init_commands, client_name, transport, + loop); if (auto Err = dap.ConfigureIO()) { llvm::logAllUnhandledErrors(std::move(Err), llvm::errs(), @@ -311,7 +313,7 @@ serveConnection(const Socket::SocketProtocol &protocol, const std::string &name, { std::scoped_lock<std::mutex> lock(dap_sessions_mutex); - dap_sessions[io.get()] = &dap; + dap_sessions[&loop] = &dap; } if (auto Err = dap.Loop()) { @@ -322,7 +324,7 @@ serveConnection(const Socket::SocketProtocol &protocol, const std::string &name, DAP_LOG(log, "({0}) client disconnected", client_name); std::unique_lock<std::mutex> lock(dap_sessions_mutex); - dap_sessions.erase(io.get()); + 
dap_sessions.erase(&loop); std::notify_all_at_thread_exit(dap_sessions_condition, std::move(lock)); }); client.detach(); @@ -344,13 +346,14 @@ serveConnection(const Socket::SocketProtocol &protocol, const std::string &name, bool client_failed = false; { std::scoped_lock<std::mutex> lock(dap_sessions_mutex); - for (auto [sock, dap] : dap_sessions) { + for (auto [loop, dap] : dap_sessions) { if (llvm::Error error = dap->Disconnect()) { client_failed = true; - llvm::errs() << "DAP client " << dap->transport.GetClientName() - << " disconnected failed: " + llvm::errs() << "DAP client disconnected failed: " << llvm::toString(std::move(error)) << "\n"; } + loop->AddPendingCallback( + [](MainLoopBase &loop) { loop.RequestTermination(); }); } } @@ -550,8 +553,10 @@ int main(int argc, char *argv[]) { stdout_fd, File::eOpenOptionWriteOnly, NativeFile::Unowned); constexpr llvm::StringLiteral client_name = "stdio"; + MainLoop loop; Transport transport(client_name, log.get(), input, output); - DAP dap(log.get(), default_repl_mode, pre_init_commands, transport); + DAP dap(log.get(), default_repl_mode, pre_init_commands, client_name, + transport, loop); // stdout/stderr redirection to the IDE's console if (auto Err = dap.ConfigureIO(stdout, stderr)) { diff --git a/lldb/unittests/DAP/DAPTest.cpp b/lldb/unittests/DAP/DAPTest.cpp index 138910d917424..744e6e69a8d33 100644 --- a/lldb/unittests/DAP/DAPTest.cpp +++ b/lldb/unittests/DAP/DAPTest.cpp @@ -11,7 +11,6 @@ #include "TestBase.h" #include "llvm/Testing/Support/Error.h" #include "gtest/gtest.h" -#include <memory> #include <optional> using namespace llvm; @@ -27,12 +26,15 @@ TEST_F(DAPTest, SendProtocolMessages) { /*log=*/nullptr, /*default_repl_mode=*/ReplMode::Auto, /*pre_init_commands=*/{}, - /*transport=*/*to_dap, + /*client_name=*/"test_client", + /*transport=*/*transport, + loop, }; dap.Send(Event{/*event=*/"my-event", /*body=*/std::nullopt}); - RunOnce<protocol::Message>([&](llvm::Expected<protocol::Message> message) { - 
ASSERT_THAT_EXPECTED( - message, HasValue(testing::VariantWith<Event>(testing::FieldsAre( - /*event=*/"my-event", /*body=*/std::nullopt)))); - }); + loop.AddPendingCallback( + [](lldb_private::MainLoopBase &loop) { loop.RequestTermination(); }); + ASSERT_THAT_ERROR(dap.Loop(), llvm::Succeeded()); + ASSERT_THAT(from_dap, + ElementsAre(testing::VariantWith<Event>(testing::FieldsAre( + /*event=*/"my-event", /*body=*/std::nullopt)))); } diff --git a/lldb/unittests/DAP/Handler/DisconnectTest.cpp b/lldb/unittests/DAP/Handler/DisconnectTest.cpp index 0546aeb154d50..5b082151680dd 100644 --- a/lldb/unittests/DAP/Handler/DisconnectTest.cpp +++ b/lldb/unittests/DAP/Handler/DisconnectTest.cpp @@ -31,8 +31,8 @@ TEST_F(DisconnectRequestHandlerTest, DisconnectTriggersTerminated) { EXPECT_FALSE(dap->disconnecting); ASSERT_THAT_ERROR(handler.Run(std::nullopt), Succeeded()); EXPECT_TRUE(dap->disconnecting); - std::vector<Message> messages = DrainOutput(); - EXPECT_THAT(messages, + RunOnce(); + EXPECT_THAT(from_dap, testing::Contains(testing::VariantWith<Event>(testing::FieldsAre( /*event=*/"terminated", /*body=*/testing::_)))); } @@ -53,11 +53,13 @@ TEST_F(DisconnectRequestHandlerTest, DisconnectTriggersTerminateCommands) { EXPECT_EQ(dap->target.GetProcess().GetState(), lldb::eStateStopped); ASSERT_THAT_ERROR(handler.Run(std::nullopt), Succeeded()); EXPECT_TRUE(dap->disconnecting); - std::vector<Message> messages = DrainOutput(); - EXPECT_THAT(messages, testing::ElementsAre( - OutputMatcher("Running terminateCommands:\n"), - OutputMatcher("(lldb) script print(2)\n"), - OutputMatcher("2\n"), - testing::VariantWith<Event>(testing::FieldsAre( - /*event=*/"terminated", /*body=*/testing::_)))); + RunOnce(); + EXPECT_THAT(from_dap, + testing::Contains(OutputMatcher("Running terminateCommands:\n"))); + EXPECT_THAT(from_dap, + testing::Contains(OutputMatcher("(lldb) script print(2)\n"))); + EXPECT_THAT(from_dap, testing::Contains(OutputMatcher("2\n"))); + EXPECT_THAT(from_dap, + 
testing::Contains(testing::VariantWith<Event>(testing::FieldsAre( + /*event=*/"terminated", /*body=*/testing::_)))); } diff --git a/lldb/unittests/DAP/TestBase.cpp b/lldb/unittests/DAP/TestBase.cpp index 8f9b098c8b1e1..64097d177c4a9 100644 --- a/lldb/unittests/DAP/TestBase.cpp +++ b/lldb/unittests/DAP/TestBase.cpp @@ -7,14 +7,11 @@ //===----------------------------------------------------------------------===// #include "TestBase.h" -#include "Protocol/ProtocolBase.h" #include "TestingSupport/TestUtilities.h" #include "lldb/API/SBDefines.h" #include "lldb/API/SBStructuredData.h" -#include "lldb/Host/File.h" #include "lldb/Host/MainLoop.h" #include "lldb/Host/Pipe.h" -#include "lldb/lldb-forward.h" #include "llvm/ADT/StringRef.h" #include "llvm/Support/Error.h" #include "llvm/Testing/Support/Error.h" @@ -26,39 +23,17 @@ using namespace lldb; using namespace lldb_dap; using namespace lldb_dap::protocol; using namespace lldb_dap_tests; -using lldb_private::File; using lldb_private::MainLoop; -using lldb_private::MainLoopBase; -using lldb_private::NativeFile; using lldb_private::Pipe; -void TransportBase::SetUp() { - PipePairTest::SetUp(); - to_dap = std::make_unique<Transport>( - "to_dap", nullptr, - std::make_shared<NativeFile>(input.GetReadFileDescriptor(), - File::eOpenOptionReadOnly, - NativeFile::Unowned), - std::make_shared<NativeFile>(output.GetWriteFileDescriptor(), - File::eOpenOptionWriteOnly, - NativeFile::Unowned)); - from_dap = std::make_unique<Transport>( - "from_dap", nullptr, - std::make_shared<NativeFile>(output.GetReadFileDescriptor(), - File::eOpenOptionReadOnly, - NativeFile::Unowned), - std::make_shared<NativeFile>(input.GetWriteFileDescriptor(), - File::eOpenOptionWriteOnly, - NativeFile::Unowned)); -} - void DAPTestBase::SetUp() { TransportBase::SetUp(); dap = std::make_unique<DAP>( /*log=*/nullptr, /*default_repl_mode=*/ReplMode::Auto, /*pre_init_commands=*/std::vector<std::string>(), - /*transport=*/*to_dap); + /*client_name=*/"test_client", + 
/*transport=*/*transport, /*loop=*/loop); } void DAPTestBase::TearDown() { @@ -118,22 +93,3 @@ void DAPTestBase::LoadCore() { SBProcess process = dap->target.LoadCore(this->core->TmpName.data()); ASSERT_TRUE(process); } - -std::vector<Message> DAPTestBase::DrainOutput() { - std::vector<Message> msgs; - output.CloseWriteFileDescriptor(); - auto handle = from_dap->RegisterReadObject<protocol::Message>( - loop, [&](MainLoopBase &loop, Expected<protocol::Message> next) { - if (llvm::Error error = next.takeError()) { - loop.RequestTermination(); - consumeError(std::move(error)); - return; - } - - msgs.push_back(*next); - }); - - consumeError(handle.takeError()); - consumeError(loop.Run().takeError()); - return msgs; -} diff --git a/lldb/unittests/DAP/TestBase.h b/lldb/unittests/DAP/TestBase.h index afdfb540d39b8..4591c0fc72726 100644 --- a/lldb/unittests/DAP/TestBase.h +++ b/lldb/unittests/DAP/TestBase.h @@ -8,41 +8,81 @@ #include "DAP.h" #include "Protocol/ProtocolBase.h" -#include "TestingSupport/Host/PipeTestUtilities.h" -#include "Transport.h" #include "lldb/Host/MainLoop.h" +#include "lldb/Host/MainLoopBase.h" #include "llvm/ADT/StringRef.h" +#include "llvm/Testing/Support/Error.h" #include "gmock/gmock.h" #include "gtest/gtest.h" namespace lldb_dap_tests { +class TestTransport final + : public lldb_private::Transport<lldb_dap::protocol::Request, + lldb_dap::protocol::Response, + lldb_dap::protocol::Event> { +public: + using Message = lldb_private::Transport<lldb_dap::protocol::Request, + lldb_dap::protocol::Response, + lldb_dap::protocol::Event>::Message; + + TestTransport(lldb_private::MainLoop &loop, MessageHandler &handler) + : m_loop(loop), m_handler(handler) {} + + void Event(const lldb_dap::protocol::Event &e) override { + m_loop.AddPendingCallback([this, e](lldb_private::MainLoopBase &) { + this->m_handler.OnEvent(e); + }); + } + + void Request(const lldb_dap::protocol::Request &r) override { + m_loop.AddPendingCallback([this, r](lldb_private::MainLoopBase 
&) { + this->m_handler.OnRequest(r); + }); + } + + void Response(const lldb_dap::protocol::Response &r) override { + m_loop.AddPendingCallback([this, r](lldb_private::MainLoopBase &) { + this->m_handler.OnResponse(r); + }); + } + + llvm::Error Run(lldb_private::MainLoop &loop, MessageHandler &) override { + return loop.Run().takeError(); + } + + void Log(llvm::StringRef message) override { + log_messages.emplace_back(message); + } + + std::vector<std::string> log_messages; + +private: + lldb_private::MainLoop &m_loop; + MessageHandler &m_handler; +}; + /// A base class for tests that need transport configured for communicating DAP /// messages. -class TransportBase : public PipePairTest { +class TransportBase : public testing::Test, + public TestTransport::MessageHandler { protected: - std::unique_ptr<lldb_dap::Transport> to_dap; - std::unique_ptr<lldb_dap::Transport> from_dap; + std::vector<TestTransport::Message> from_dap; lldb_private::MainLoop loop; + std::unique_ptr<TestTransport> transport; - void SetUp() override; + void SetUp() override { + transport = std::make_unique<TestTransport>(loop, *this); + } - template <typename P> - void RunOnce(const std::function<void(llvm::Expected<P>)> &callback, - std::chrono::milliseconds timeout = std::chrono::seconds(1)) { - auto handle = from_dap->RegisterReadObject<P>( - loop, [&](lldb_private::MainLoopBase &loop, llvm::Expected<P> message) { - callback(std::move(message)); - loop.RequestTermination(); - }); - loop.AddCallback( - [](lldb_private::MainLoopBase &loop) { - loop.RequestTermination(); - FAIL() << "timeout waiting for read callback"; - }, - timeout); - ASSERT_THAT_EXPECTED(handle, llvm::Succeeded()); - ASSERT_THAT_ERROR(loop.Run().takeError(), llvm::Succeeded()); + void OnEvent(const lldb_dap::protocol::Event &e) override { + from_dap.emplace_back(e); + } + void OnRequest(const lldb_dap::protocol::Request &r) override { + from_dap.emplace_back(r); + } + void OnResponse(const lldb_dap::protocol::Response &r) 
override { + from_dap.emplace_back(r); } }; @@ -75,7 +115,12 @@ class DAPTestBase : public TransportBase { /// Closes the DAP output pipe and returns the remaining protocol messages in /// the buffer. - std::vector<lldb_dap::protocol::Message> DrainOutput(); + // std::vector<lldb_dap::protocol::Message> DrainOutput(); + void RunOnce() { + loop.AddPendingCallback( + [](lldb_private::MainLoopBase &loop) { loop.RequestTermination(); }); + ASSERT_THAT_ERROR(dap->Loop(), llvm::Succeeded()); + } }; } // namespace lldb_dap_tests diff --git a/lldb/unittests/Host/JSONTransportTest.cpp b/lldb/unittests/Host/JSONTransportTest.cpp index 4e94582d3bc6a..fdfde328c69b7 100644 --- a/lldb/unittests/Host/JSONTransportTest.cpp +++ b/lldb/unittests/Host/JSONTransportTest.cpp @@ -11,15 +11,18 @@ #include "lldb/Host/File.h" #include "lldb/Host/MainLoop.h" #include "lldb/Host/MainLoopBase.h" -#include "llvm/ADT/FunctionExtras.h" +#include "lldb/Utility/Log.h" #include "llvm/ADT/StringRef.h" #include "llvm/Support/Error.h" +#include "llvm/Support/ErrorHandling.h" +#include "llvm/Support/FormatVariadic.h" #include "llvm/Support/JSON.h" +#include "llvm/Support/raw_ostream.h" #include "llvm/Testing/Support/Error.h" +#include "gmock/gmock.h" #include "gtest/gtest.h" #include <chrono> #include <cstddef> -#include <future> #include <memory> #include <string> @@ -28,22 +31,119 @@ using namespace lldb_private; namespace { -struct JSONTestType { - std::string str; +namespace test_protocol { + +struct Req { + std::string name; }; +json::Value toJSON(const Req &T) { return json::Object{{"req", T.name}}; } +bool fromJSON(const json::Value &V, Req &T, json::Path P) { + json::ObjectMapper O(V, P); + return O && O.map("req", T.name); +} +bool operator==(const Req &a, const Req &b) { return a.name == b.name; } +inline llvm::raw_ostream &operator<<(llvm::raw_ostream &OS, const Req &V) { + OS << toJSON(V); + return OS; +} +void PrintTo(const Req &message, std::ostream *os) { + std::string O; + 
llvm::raw_string_ostream OS(O); + OS << message; + *os << O; +} -json::Value toJSON(const JSONTestType &T) { - return json::Object{{"str", T.str}}; +struct Resp { + std::string name; +}; +json::Value toJSON(const Resp &T) { return json::Object{{"resp", T.name}}; } +bool fromJSON(const json::Value &V, Resp &T, json::Path P) { + json::ObjectMapper O(V, P); + return O && O.map("resp", T.name); +} +bool operator==(const Resp &a, const Resp &b) { return a.name == b.name; } +inline llvm::raw_ostream &operator<<(llvm::raw_ostream &OS, const Resp &V) { + OS << toJSON(V); + return OS; +} +void PrintTo(const Resp &message, std::ostream *os) { + std::string O; + llvm::raw_string_ostream OS(O); + OS << message; + *os << O; } -bool fromJSON(const json::Value &V, JSONTestType &T, json::Path P) { +struct Evt { + std::string name; +}; +json::Value toJSON(const Evt &T) { return json::Object{{"evt", T.name}}; } +bool fromJSON(const json::Value &V, Evt &T, json::Path P) { json::ObjectMapper O(V, P); - return O && O.map("str", T.str); + return O && O.map("evt", T.name); +} +bool operator==(const Evt &a, const Evt &b) { return a.name == b.name; } +inline llvm::raw_ostream &operator<<(llvm::raw_ostream &OS, const Evt &V) { + OS << toJSON(V); + return OS; +} +void PrintTo(const Evt &message, std::ostream *os) { + std::string O; + llvm::raw_string_ostream OS(O); + OS << message; + *os << O; +} + +using Message = std::variant<Req, Resp, Evt>; +json::Value toJSON(const Message &T) { + if (const Req *req = std::get_if<Req>(&T)) + return toJSON(*req); + if (const Resp *resp = std::get_if<Resp>(&T)) + return toJSON(*resp); + if (const Evt *evt = std::get_if<Evt>(&T)) + return toJSON(*evt); + llvm_unreachable("unknown message type"); +} +bool fromJSON(const json::Value &V, Message &T, json::Path P) { + const json::Object *O = V.getAsObject(); + if (!O) { + P.report("expected object"); + return false; + } + if (O->get("req")) { + Req R; + if (!fromJSON(V, R, P)) + return false; + + T = 
std::move(R); + return true; + } + if (O->get("resp")) { + Resp R; + if (!fromJSON(V, R, P)) + return false; + + T = std::move(R); + return true; + } + if (O->get("evt")) { + Evt E; + if (!fromJSON(V, E, P)) + return false; + + T = std::move(E); + return true; + } + P.report("unknown message type"); + return false; } -template <typename T> class JSONTransportTest : public PipePairTest { +} // namespace test_protocol + +template <typename T, typename Req, typename Resp, typename Evt> +class JSONTransportTest : public PipePairTest { + protected: - std::unique_ptr<JSONTransport> transport; + std::unique_ptr<T> transport; MainLoop loop; void SetUp() override { @@ -57,53 +157,59 @@ template <typename T> class JSONTransportTest : public PipePairTest { NativeFile::Unowned)); } - template <typename P> - Expected<P> - RunOnce(std::chrono::milliseconds timeout = std::chrono::seconds(1)) { - std::promise<Expected<P>> promised_message; - std::future<Expected<P>> future_message = promised_message.get_future(); - RunUntil<P>( - [&promised_message](Expected<P> message) mutable -> bool { - promised_message.set_value(std::move(message)); - return /*keep_going*/ false; - }, - timeout); - return future_message.get(); + class MessageCollector final + : public Transport<Req, Resp, Evt>::MessageHandler { + public: + std::vector<typename T::Message> messages; + void OnEvent(const Evt &V) override { messages.emplace_back(V); } + void OnRequest(const Req &V) override { messages.emplace_back(V); } + void OnResponse(const Resp &V) override { messages.emplace_back(V); } + }; + + /// Run the transport MainLoop and return any messages received. 
+ Expected<std::vector<typename T::Message>> + Run(std::chrono::milliseconds timeout = std::chrono::milliseconds(5000)) { + MessageCollector collector; + loop.AddCallback([](MainLoopBase &loop) { loop.RequestTermination(); }, + timeout); + if (auto error = transport->Run(loop, collector)) + return error; + return std::move(collector.messages); } - /// RunUntil runs the event loop until the callback returns `false` or a - /// timeout has occurred. - template <typename P> - void RunUntil(std::function<bool(Expected<P>)> callback, - std::chrono::milliseconds timeout = std::chrono::seconds(1)) { - auto handle = transport->RegisterReadObject<P>( - loop, [&callback](MainLoopBase &loop, Expected<P> message) mutable { - bool keep_going = callback(std::move(message)); - if (!keep_going) - loop.RequestTermination(); - }); - loop.AddCallback( - [&callback](MainLoopBase &loop) mutable { - loop.RequestTermination(); - callback(createStringError("timeout")); - }, - timeout); - EXPECT_THAT_EXPECTED(handle, Succeeded()); - EXPECT_THAT_ERROR(loop.Run().takeError(), Succeeded()); - } - - template <typename... Ts> llvm::Expected<size_t> Write(Ts... args) { + template <typename... Ts> void Write(Ts... args) { std::string message; for (const auto &arg : {args...}) message += Encode(arg); - return input.Write(message.data(), message.size()); + EXPECT_THAT_EXPECTED(input.Write(message.data(), message.size()), + Succeeded()); + } + + template <typename... Ts> void WriteAndCloseInput(Ts... 
args) { + Write<Ts...>(std::forward<Ts>(args)...); + input.CloseWriteFileDescriptor(); } virtual std::string Encode(const json::Value &) = 0; }; +class TestHTTPDelimitedJSONTransport final + : public HTTPDelimitedJSONTransport<test_protocol::Req, test_protocol::Resp, + test_protocol::Evt> { +public: + using HTTPDelimitedJSONTransport::HTTPDelimitedJSONTransport; + + void Log(llvm::StringRef message) override { + log_messages.emplace_back(message); + } + + std::vector<std::string> log_messages; +}; + class HTTPDelimitedJSONTransportTest - : public JSONTransportTest<HTTPDelimitedJSONTransport> { + : public JSONTransportTest<TestHTTPDelimitedJSONTransport, + test_protocol::Req, test_protocol::Resp, + test_protocol::Evt> { public: using JSONTransportTest::JSONTransportTest; @@ -118,7 +224,22 @@ class HTTPDelimitedJSONTransportTest } }; -class JSONRPCTransportTest : public JSONTransportTest<JSONRPCTransport> { +class TestJSONRPCTransport final + : public JSONRPCTransport<test_protocol::Req, test_protocol::Resp, + test_protocol::Evt> { +public: + using JSONRPCTransport::JSONRPCTransport; + + void Log(llvm::StringRef message) override { + log_messages.emplace_back(message); + } + + std::vector<std::string> log_messages; +}; + +class JSONRPCTransportTest + : public JSONTransportTest<TestJSONRPCTransport, test_protocol::Req, + test_protocol::Resp, test_protocol::Evt> { public: using JSONTransportTest::JSONTransportTest; @@ -134,6 +255,7 @@ class JSONRPCTransportTest : public JSONTransportTest<JSONRPCTransport> { // Failing on Windows, see https://github.com/llvm/llvm-project/issues/153446. 
#ifndef _WIN32 +using namespace test_protocol; TEST_F(HTTPDelimitedJSONTransportTest, MalformedRequests) { std::string malformed_header = @@ -141,84 +263,65 @@ TEST_F(HTTPDelimitedJSONTransportTest, MalformedRequests) { ASSERT_THAT_EXPECTED( input.Write(malformed_header.data(), malformed_header.size()), Succeeded()); - ASSERT_THAT_EXPECTED(RunOnce<JSONTestType>(), - FailedWithMessage("invalid content length: -1")); + ASSERT_THAT_EXPECTED(Run(), FailedWithMessage("invalid content length: -1")); } TEST_F(HTTPDelimitedJSONTransportTest, Read) { - ASSERT_THAT_EXPECTED(Write(JSONTestType{"foo"}), Succeeded()); - ASSERT_THAT_EXPECTED(RunOnce<JSONTestType>(), - HasValue(testing::FieldsAre(/*str=*/"foo"))); + WriteAndCloseInput(Req{"foo"}); + ASSERT_THAT_EXPECTED(Run(), HasValue(testing::ElementsAre(Req{"foo"}))); } TEST_F(HTTPDelimitedJSONTransportTest, ReadMultipleMessagesInSingleWrite) { - ASSERT_THAT_EXPECTED(Write(JSONTestType{"one"}, JSONTestType{"two"}), - Succeeded()); - unsigned count = 0; - RunUntil<JSONTestType>([&](Expected<JSONTestType> message) -> bool { - if (count == 0) { - EXPECT_THAT_EXPECTED(message, - HasValue(testing::FieldsAre(/*str=*/"one"))); - } else if (count == 1) { - EXPECT_THAT_EXPECTED(message, - HasValue(testing::FieldsAre(/*str=*/"two"))); - } - - count++; - return count < 2; - }); + WriteAndCloseInput(Message{Req{"one"}}, Message{Resp{"two"}}, + Message{Evt{"three"}}); + EXPECT_THAT_EXPECTED(Run(), HasValue(testing::ElementsAre( + Req{"one"}, Resp{"two"}, Evt{"three"}))); } TEST_F(HTTPDelimitedJSONTransportTest, ReadAcrossMultipleChunks) { std::string long_str = std::string(2048, 'x'); - ASSERT_THAT_EXPECTED(Write(JSONTestType{long_str}), Succeeded()); - ASSERT_THAT_EXPECTED(RunOnce<JSONTestType>(), - HasValue(testing::FieldsAre(/*str=*/long_str))); + WriteAndCloseInput(Req{long_str}); + ASSERT_THAT_EXPECTED(Run(), HasValue(testing::ElementsAre(Req{long_str}))); } TEST_F(HTTPDelimitedJSONTransportTest, ReadPartialMessage) { - std::string 
message = Encode(JSONTestType{"foo"}); - std::string part1 = message.substr(0, 28); - std::string part2 = message.substr(28); + std::string message = Encode(Req{"foo"}); + auto split_at = message.size() / 2; + std::string part1 = message.substr(0, split_at); + std::string part2 = message.substr(split_at); ASSERT_THAT_EXPECTED(input.Write(part1.data(), part1.size()), Succeeded()); - - ASSERT_THAT_EXPECTED( - RunOnce<JSONTestType>(/*timeout=*/std::chrono::milliseconds(10)), - FailedWithMessage("timeout")); + ASSERT_THAT_EXPECTED(Run(/*timeout=*/std::chrono::milliseconds(10)), + HasValue(testing::IsEmpty())); ASSERT_THAT_EXPECTED(input.Write(part2.data(), part2.size()), Succeeded()); - - ASSERT_THAT_EXPECTED(RunOnce<JSONTestType>(), - HasValue(testing::FieldsAre(/*str=*/"foo"))); + input.CloseWriteFileDescriptor(); + ASSERT_THAT_EXPECTED(Run(), HasValue(testing::ElementsAre(Req{"foo"}))); } TEST_F(HTTPDelimitedJSONTransportTest, ReadWithZeroByteWrites) { - std::string message = Encode(JSONTestType{"foo"}); - std::string part1 = message.substr(0, 28); - std::string part2 = message.substr(28); + std::string message = Encode(Req{"foo"}); + auto split_at = message.size() / 2; + std::string part1 = message.substr(0, split_at); + std::string part2 = message.substr(split_at); ASSERT_THAT_EXPECTED(input.Write(part1.data(), part1.size()), Succeeded()); - ASSERT_THAT_EXPECTED( - RunOnce<JSONTestType>(/*timeout=*/std::chrono::milliseconds(10)), - FailedWithMessage("timeout")); + ASSERT_THAT_EXPECTED(Run(/*timeout=*/std::chrono::milliseconds(10)), + HasValue(testing::IsEmpty())); ASSERT_THAT_EXPECTED(input.Write(part1.data(), 0), Succeeded()); // zero-byte write. 
- - ASSERT_THAT_EXPECTED( - RunOnce<JSONTestType>(/*timeout=*/std::chrono::milliseconds(10)), - FailedWithMessage("timeout")); + ASSERT_THAT_EXPECTED(Run(/*timeout=*/std::chrono::milliseconds(10)), + HasValue(testing::IsEmpty())); ASSERT_THAT_EXPECTED(input.Write(part2.data(), part2.size()), Succeeded()); - - ASSERT_THAT_EXPECTED(RunOnce<JSONTestType>(), - HasValue(testing::FieldsAre(/*str=*/"foo"))); + input.CloseWriteFileDescriptor(); + ASSERT_THAT_EXPECTED(Run(), HasValue(testing::ElementsAre(Req{"foo"}))); } TEST_F(HTTPDelimitedJSONTransportTest, ReadWithEOF) { input.CloseWriteFileDescriptor(); - ASSERT_THAT_EXPECTED(RunOnce<JSONTestType>(), Failed<TransportEOFError>()); + ASSERT_THAT_EXPECTED(Run(), HasValue(testing::IsEmpty())); } TEST_F(HTTPDelimitedJSONTransportTest, ReaderWithUnhandledData) { @@ -231,32 +334,35 @@ TEST_F(HTTPDelimitedJSONTransportTest, ReaderWithUnhandledData) { ASSERT_THAT_EXPECTED(input.Write(message.data(), message.size() - 1), Succeeded()); input.CloseWriteFileDescriptor(); - ASSERT_THAT_EXPECTED(RunOnce<JSONTestType>(), - Failed<TransportUnhandledContentsError>()); + ASSERT_THAT_EXPECTED(Run(), Failed<TransportUnhandledContentsError>()); } TEST_F(HTTPDelimitedJSONTransportTest, NoDataTimeout) { - ASSERT_THAT_EXPECTED( - RunOnce<JSONTestType>(/*timeout=*/std::chrono::milliseconds(10)), - FailedWithMessage("timeout")); + ASSERT_THAT_EXPECTED(Run(/*timeout=*/std::chrono::milliseconds(10)), + HasValue(testing::IsEmpty())); } TEST_F(HTTPDelimitedJSONTransportTest, InvalidTransport) { - transport = std::make_unique<HTTPDelimitedJSONTransport>(nullptr, nullptr); - auto handle = transport->RegisterReadObject<JSONTestType>( - loop, [&](MainLoopBase &, llvm::Expected<JSONTestType>) {}); - ASSERT_THAT_EXPECTED(handle, FailedWithMessage("IO object is not valid.")); + transport = + std::make_unique<TestHTTPDelimitedJSONTransport>(nullptr, nullptr); + ASSERT_THAT_EXPECTED(Run(), FailedWithMessage("IO object is not valid.")); } 
TEST_F(HTTPDelimitedJSONTransportTest, Write) { - ASSERT_THAT_ERROR(transport->Write(JSONTestType{"foo"}), Succeeded()); + transport->Request(Req{"foo"}); + transport->Response(Resp{"bar"}); + transport->Event(Evt{"baz"}); output.CloseWriteFileDescriptor(); char buf[1024]; Expected<size_t> bytes_read = output.Read(buf, sizeof(buf), std::chrono::milliseconds(1)); ASSERT_THAT_EXPECTED(bytes_read, Succeeded()); ASSERT_EQ(StringRef(buf, *bytes_read), StringRef("Content-Length: 13\r\n\r\n" - R"json({"str":"foo"})json")); + R"({"req":"foo"})" + "Content-Length: 14\r\n\r\n" + R"({"resp":"bar"})" + "Content-Length: 13\r\n\r\n" + R"({"evt":"baz"})")); } TEST_F(JSONRPCTransportTest, MalformedRequests) { @@ -264,80 +370,80 @@ TEST_F(JSONRPCTransportTest, MalformedRequests) { ASSERT_THAT_EXPECTED( input.Write(malformed_header.data(), malformed_header.size()), Succeeded()); - ASSERT_THAT_EXPECTED(RunOnce<JSONTestType>(), llvm::Failed()); + ASSERT_THAT_EXPECTED( + Run(), FailedWithMessage("[1:2, byte=2]: Invalid JSON value (null?)")); } TEST_F(JSONRPCTransportTest, Read) { - ASSERT_THAT_EXPECTED(Write(JSONTestType{"foo"}), Succeeded()); - ASSERT_THAT_EXPECTED(RunOnce<JSONTestType>(), - HasValue(testing::FieldsAre(/*str=*/"foo"))); + WriteAndCloseInput(Message{Req{"foo"}}, Message{Resp{"bar"}}, + Message{Evt{"baz"}}); + ASSERT_THAT_EXPECTED(Run(), HasValue(testing::ElementsAre( + Req{"foo"}, Resp{"bar"}, Evt{"baz"}))); } TEST_F(JSONRPCTransportTest, ReadAcrossMultipleChunks) { - std::string long_str = std::string(2048, 'x'); - std::string message = Encode(JSONTestType{long_str}); - ASSERT_THAT_EXPECTED(input.Write(message.data(), message.size()), - Succeeded()); - ASSERT_THAT_EXPECTED(RunOnce<JSONTestType>(), - HasValue(testing::FieldsAre(/*str=*/long_str))); + // Use a string longer than the chunk size to ensure we split the message + // across the chunk boundary. 
+ std::string long_str = + std::string(JSONTransport<Req, Resp, Evt>::kReadBufferSize + 10, 'x'); + WriteAndCloseInput(Req{long_str}); + ASSERT_THAT_EXPECTED(Run(), HasValue(testing::ElementsAre(Req{long_str}))); } TEST_F(JSONRPCTransportTest, ReadPartialMessage) { - std::string message = R"({"str": "foo"})" + std::string message = R"({"req": "foo"})" "\n"; std::string part1 = message.substr(0, 7); std::string part2 = message.substr(7); ASSERT_THAT_EXPECTED(input.Write(part1.data(), part1.size()), Succeeded()); - - ASSERT_THAT_EXPECTED( - RunOnce<JSONTestType>(/*timeout=*/std::chrono::milliseconds(10)), - FailedWithMessage("timeout")); + ASSERT_THAT_EXPECTED(Run(std::chrono::milliseconds(10)), + HasValue(testing::IsEmpty())); ASSERT_THAT_EXPECTED(input.Write(part2.data(), part2.size()), Succeeded()); - - ASSERT_THAT_EXPECTED(RunOnce<JSONTestType>(), - HasValue(testing::FieldsAre(/*str=*/"foo"))); + input.CloseWriteFileDescriptor(); + ASSERT_THAT_EXPECTED(Run(), HasValue(testing::ElementsAre(Req{"foo"}))); } TEST_F(JSONRPCTransportTest, ReadWithEOF) { input.CloseWriteFileDescriptor(); - ASSERT_THAT_EXPECTED(RunOnce<JSONTestType>(), Failed<TransportEOFError>()); + ASSERT_THAT_EXPECTED(Run(), HasValue(testing::IsEmpty())); } TEST_F(JSONRPCTransportTest, ReaderWithUnhandledData) { - std::string message = R"json({"str": "foo"})json" - "\n"; + std::string message = R"json({"req": "foo")json"; // Write an incomplete message and close the handle. 
- ASSERT_THAT_EXPECTED(input.Write(message.data(), message.size() - 1), + ASSERT_THAT_EXPECTED(input.Write(message.data(), message.size()), Succeeded()); input.CloseWriteFileDescriptor(); - ASSERT_THAT_EXPECTED(RunOnce<JSONTestType>(), - Failed<TransportUnhandledContentsError>()); + EXPECT_THAT_EXPECTED(Run(), Failed<TransportUnhandledContentsError>()); } TEST_F(JSONRPCTransportTest, Write) { - ASSERT_THAT_ERROR(transport->Write(JSONTestType{"foo"}), Succeeded()); + transport->Request(Req{"foo"}); + transport->Response(Resp{"bar"}); + transport->Event(Evt{"baz"}); output.CloseWriteFileDescriptor(); char buf[1024]; Expected<size_t> bytes_read = output.Read(buf, sizeof(buf), std::chrono::milliseconds(1)); ASSERT_THAT_EXPECTED(bytes_read, Succeeded()); - ASSERT_EQ(StringRef(buf, *bytes_read), StringRef(R"json({"str":"foo"})json" + ASSERT_EQ(StringRef(buf, *bytes_read), StringRef(R"({"req":"foo"})" + "\n" + R"({"resp":"bar"})" + "\n" + R"({"evt":"baz"})" "\n")); } TEST_F(JSONRPCTransportTest, InvalidTransport) { - transport = std::make_unique<JSONRPCTransport>(nullptr, nullptr); - auto handle = transport->RegisterReadObject<JSONTestType>( - loop, [&](MainLoopBase &, llvm::Expected<JSONTestType>) {}); - ASSERT_THAT_EXPECTED(handle, FailedWithMessage("IO object is not valid.")); + transport = std::make_unique<TestJSONRPCTransport>(nullptr, nullptr); + ASSERT_THAT_EXPECTED(Run(), FailedWithMessage("IO object is not valid.")); } TEST_F(JSONRPCTransportTest, NoDataTimeout) { - ASSERT_THAT_EXPECTED( - RunOnce<JSONTestType>(/*timeout=*/std::chrono::milliseconds(10)), - FailedWithMessage("timeout")); + ASSERT_THAT_EXPECTED(Run(/*timeout=*/std::chrono::milliseconds(10)), + HasValue(testing::ElementsAre())); } #endif diff --git a/lldb/unittests/ProtocolServer/ProtocolMCPServerTest.cpp b/lldb/unittests/ProtocolServer/ProtocolMCPServerTest.cpp index 2ac40c41dd28e..588093edf321a 100644 --- a/lldb/unittests/ProtocolServer/ProtocolMCPServerTest.cpp +++ 
b/lldb/unittests/ProtocolServer/ProtocolMCPServerTest.cpp @@ -21,7 +21,9 @@ #include "lldb/Protocol/MCP/MCPError.h" #include "lldb/Protocol/MCP/Protocol.h" #include "llvm/Support/Error.h" +#include "llvm/Support/JSON.h" #include "llvm/Testing/Support/Error.h" +#include "gmock/gmock.h" #include "gtest/gtest.h" #include <chrono> #include <condition_variable> @@ -43,11 +45,18 @@ class TestProtocolServerMCP : public lldb_private::mcp::ProtocolServerMCP { using ProtocolServerMCP::ProtocolServerMCP; }; -class TestJSONTransport : public lldb_private::JSONRPCTransport { +using Message = typename Transport<Request, Response, Notification>::Message; + +class TestJSONTransport final + : public lldb_private::JSONRPCTransport<Request, Response, Notification> { public: using JSONRPCTransport::JSONRPCTransport; - using JSONRPCTransport::Parse; - using JSONRPCTransport::WriteImpl; + + void Log(llvm::StringRef message) override { + log_messages.emplace_back(message); + } + + std::vector<std::string> log_messages; }; /// Test tool that returns it argument as text. 
@@ -55,7 +64,7 @@ class TestTool : public Tool { public: using Tool::Tool; - virtual llvm::Expected<TextResult> Call(const ToolArguments &args) override { + llvm::Expected<TextResult> Call(const ToolArguments &args) override { std::string argument; if (const json::Object *args_obj = std::get<json::Value>(args).getAsObject()) { @@ -73,7 +82,7 @@ class TestTool : public Tool { class TestResourceProvider : public ResourceProvider { using ResourceProvider::ResourceProvider; - virtual std::vector<Resource> GetResources() const override { + std::vector<Resource> GetResources() const override { std::vector<Resource> resources; Resource resource; @@ -86,7 +95,7 @@ class TestResourceProvider : public ResourceProvider { return resources; } - virtual llvm::Expected<ResourceResult> + llvm::Expected<ResourceResult> ReadResource(llvm::StringRef uri) const override { if (uri != "lldb://foo/bar") return llvm::make_error<UnsupportedURI>(uri.str()); @@ -107,7 +116,7 @@ class ErrorTool : public Tool { public: using Tool::Tool; - virtual llvm::Expected<TextResult> Call(const ToolArguments &args) override { + llvm::Expected<TextResult> Call(const ToolArguments &args) override { return llvm::createStringError("error"); } }; @@ -117,7 +126,7 @@ class FailTool : public Tool { public: using Tool::Tool; - virtual llvm::Expected<TextResult> Call(const ToolArguments &args) override { + llvm::Expected<TextResult> Call(const ToolArguments &args) override { TextResult text_result; text_result.content.emplace_back(TextContent{{"failed"}}); text_result.isError = true; @@ -138,26 +147,33 @@ class ProtocolServerMCPTest : public ::testing::Test { static constexpr llvm::StringLiteral k_localhost = "localhost"; llvm::Error Write(llvm::StringRef message) { - return m_transport_up->WriteImpl(llvm::formatv("{0}\n", message).str()); + std::string output = llvm::formatv("{0}\n", message).str(); + size_t bytes_written = output.size(); + return m_io_sp->Write(output.data(), bytes_written).takeError(); } - 
template <typename P> - void - RunOnce(const std::function<void(llvm::Expected<P>)> &callback, - std::chrono::milliseconds timeout = std::chrono::milliseconds(100)) { - auto handle = m_transport_up->RegisterReadObject<P>( - loop, [&](lldb_private::MainLoopBase &loop, llvm::Expected<P> message) { - callback(std::move(message)); - loop.RequestTermination(); - }); - loop.AddCallback( - [&](lldb_private::MainLoopBase &loop) { - loop.RequestTermination(); - FAIL() << "timeout waiting for read callback"; - }, - timeout); - ASSERT_THAT_EXPECTED(handle, llvm::Succeeded()); - ASSERT_THAT_ERROR(loop.Run().takeError(), llvm::Succeeded()); + void CloseInput() { + EXPECT_THAT_ERROR(m_io_sp->Close().takeError(), Succeeded()); + } + + class MessageCollector final + : public Transport<Request, Response, Notification>::MessageHandler { + public: + std::vector<Message> messages; + void OnEvent(const Notification &V) override { messages.emplace_back(V); } + void OnRequest(const Request &V) override { messages.emplace_back(V); } + void OnResponse(const Response &V) override { messages.emplace_back(V); } + }; + + /// Run the transport MainLoop and return any messages received. 
+ Expected<std::vector<Message>> + Run(std::chrono::milliseconds timeout = std::chrono::milliseconds(100)) { + MessageCollector collector; + loop.AddCallback([](MainLoopBase &loop) { loop.RequestTermination(); }, + timeout); + if (auto error = m_transport_up->Run(loop, collector)) + return error; + return std::move(collector.messages); } void SetUp() override { @@ -206,37 +222,39 @@ TEST_F(ProtocolServerMCPTest, Initialization) { llvm::StringLiteral response = R"json( {"id":0,"jsonrpc":"2.0","result":{"capabilities":{"resources":{"listChanged":false,"subscribe":false},"tools":{"listChanged":true}},"protocolVersion":"2024-11-05","serverInfo":{"name":"lldb-mcp","version":"0.1.0"}}})json"; - ASSERT_THAT_ERROR(Write(request), llvm::Succeeded()); - RunOnce<std::string>([&](llvm::Expected<std::string> response_str) { - ASSERT_THAT_EXPECTED(response_str, llvm::Succeeded()); - llvm::Expected<json::Value> response_json = json::parse(*response_str); - ASSERT_THAT_EXPECTED(response_json, llvm::Succeeded()); - - llvm::Expected<json::Value> expected_json = json::parse(response); - ASSERT_THAT_EXPECTED(expected_json, llvm::Succeeded()); - - EXPECT_EQ(*response_json, *expected_json); - }); + ASSERT_THAT_ERROR(Write(request), Succeeded()); + llvm::Expected<json::Value> expected_json = json::parse(response); + ASSERT_THAT_EXPECTED(expected_json, llvm::Succeeded()); + EXPECT_THAT_EXPECTED(Run(), HasValue(testing::ElementsAre(*expected_json))); } TEST_F(ProtocolServerMCPTest, ToolsList) { llvm::StringLiteral request = R"json({"method":"tools/list","params":{},"jsonrpc":"2.0","id":1})json"; - llvm::StringLiteral response = - R"json({"id":1,"jsonrpc":"2.0","result":{"tools":[{"description":"test tool","inputSchema":{"type":"object"},"name":"test"},{"description":"Run an lldb command.","inputSchema":{"properties":{"arguments":{"type":"string"},"debugger_id":{"type":"number"}},"required":["debugger_id"],"type":"object"},"name":"lldb_command"}]}})json"; - - 
ASSERT_THAT_ERROR(Write(request), llvm::Succeeded()); - RunOnce<std::string>([&](llvm::Expected<std::string> response_str) { - ASSERT_THAT_EXPECTED(response_str, llvm::Succeeded()); - - llvm::Expected<json::Value> response_json = json::parse(*response_str); - ASSERT_THAT_EXPECTED(response_json, llvm::Succeeded()); - llvm::Expected<json::Value> expected_json = json::parse(response); - ASSERT_THAT_EXPECTED(expected_json, llvm::Succeeded()); + ToolDefinition test_tool; + test_tool.name = "test"; + test_tool.description = "test tool"; + test_tool.inputSchema = json::Object{{"type", "object"}}; + + ToolDefinition lldb_command_tool; + lldb_command_tool.description = "Run an lldb command."; + lldb_command_tool.name = "lldb_command"; + lldb_command_tool.inputSchema = json::Object{ + {"type", "object"}, + {"properties", + json::Object{{"arguments", json::Object{{"type", "string"}}}, + {"debugger_id", json::Object{{"type", "number"}}}}}, + {"required", json::Array{"debugger_id"}}}; + Response response; + response.id = 1; + response.result = json::Object{ + {"tools", + json::Array{std::move(test_tool), std::move(lldb_command_tool)}}, + }; - EXPECT_EQ(*response_json, *expected_json); - }); + ASSERT_THAT_ERROR(Write(request), llvm::Succeeded()); + EXPECT_THAT_EXPECTED(Run(), HasValue(testing::ElementsAre(response))); } TEST_F(ProtocolServerMCPTest, ResourcesList) { @@ -246,17 +264,9 @@ TEST_F(ProtocolServerMCPTest, ResourcesList) { R"json({"id":2,"jsonrpc":"2.0","result":{"resources":[{"description":"description","mimeType":"application/json","name":"name","uri":"lldb://foo/bar"}]}})json"; ASSERT_THAT_ERROR(Write(request), llvm::Succeeded()); - RunOnce<std::string>([&](llvm::Expected<std::string> response_str) { - ASSERT_THAT_EXPECTED(response_str, llvm::Succeeded()); - - llvm::Expected<json::Value> response_json = json::parse(*response_str); - ASSERT_THAT_EXPECTED(response_json, llvm::Succeeded()); - - llvm::Expected<json::Value> expected_json = json::parse(response); - 
ASSERT_THAT_EXPECTED(expected_json, llvm::Succeeded()); - - EXPECT_EQ(*response_json, *expected_json); - }); + llvm::Expected<json::Value> expected_json = json::parse(response); + ASSERT_THAT_EXPECTED(expected_json, llvm::Succeeded()); + EXPECT_THAT_EXPECTED(Run(), HasValue(testing::ElementsAre(*expected_json))); } TEST_F(ProtocolServerMCPTest, ToolsCall) { @@ -266,17 +276,9 @@ TEST_F(ProtocolServerMCPTest, ToolsCall) { R"json({"id":11,"jsonrpc":"2.0","result":{"content":[{"text":"foo","type":"text"}],"isError":false}})json"; ASSERT_THAT_ERROR(Write(request), llvm::Succeeded()); - RunOnce<std::string>([&](llvm::Expected<std::string> response_str) { - ASSERT_THAT_EXPECTED(response_str, llvm::Succeeded()); - - llvm::Expected<json::Value> response_json = json::parse(*response_str); - ASSERT_THAT_EXPECTED(response_json, llvm::Succeeded()); - - llvm::Expected<json::Value> expected_json = json::parse(response); - ASSERT_THAT_EXPECTED(expected_json, llvm::Succeeded()); - - EXPECT_EQ(*response_json, *expected_json); - }); + llvm::Expected<json::Value> expected_json = json::parse(response); + ASSERT_THAT_EXPECTED(expected_json, llvm::Succeeded()); + ASSERT_THAT_EXPECTED(Run(), HasValue(testing::ElementsAre(*expected_json))); } TEST_F(ProtocolServerMCPTest, ToolsCallError) { @@ -288,17 +290,9 @@ TEST_F(ProtocolServerMCPTest, ToolsCallError) { R"json({"error":{"code":-32603,"message":"error"},"id":11,"jsonrpc":"2.0"})json"; ASSERT_THAT_ERROR(Write(request), llvm::Succeeded()); - RunOnce<std::string>([&](llvm::Expected<std::string> response_str) { - ASSERT_THAT_EXPECTED(response_str, llvm::Succeeded()); - - llvm::Expected<json::Value> response_json = json::parse(*response_str); - ASSERT_THAT_EXPECTED(response_json, llvm::Succeeded()); - - llvm::Expected<json::Value> expected_json = json::parse(response); - ASSERT_THAT_EXPECTED(expected_json, llvm::Succeeded()); - - EXPECT_EQ(*response_json, *expected_json); - }); + llvm::Expected<json::Value> expected_json = 
json::parse(response); + ASSERT_THAT_EXPECTED(expected_json, llvm::Succeeded()); + ASSERT_THAT_EXPECTED(Run(), HasValue(testing::ElementsAre(*expected_json))); } TEST_F(ProtocolServerMCPTest, ToolsCallFail) { @@ -310,17 +304,9 @@ TEST_F(ProtocolServerMCPTest, ToolsCallFail) { R"json({"id":11,"jsonrpc":"2.0","result":{"content":[{"text":"failed","type":"text"}],"isError":true}})json"; ASSERT_THAT_ERROR(Write(request), llvm::Succeeded()); - RunOnce<std::string>([&](llvm::Expected<std::string> response_str) { - ASSERT_THAT_EXPECTED(response_str, llvm::Succeeded()); - - llvm::Expected<json::Value> response_json = json::parse(*response_str); - ASSERT_THAT_EXPECTED(response_json, llvm::Succeeded()); - - llvm::Expected<json::Value> expected_json = json::parse(response); - ASSERT_THAT_EXPECTED(expected_json, llvm::Succeeded()); - - EXPECT_EQ(*response_json, *expected_json); - }); + llvm::Expected<json::Value> expected_json = json::parse(response); + ASSERT_THAT_EXPECTED(expected_json, llvm::Succeeded()); + ASSERT_THAT_EXPECTED(Run(), HasValue(testing::ElementsAre(*expected_json))); } TEST_F(ProtocolServerMCPTest, NotificationInitialized) { >From e4abbc0d59972853f1eb65cc80b8fac9b3bb4b64 Mon Sep 17 00:00:00 2001 From: John Harrison <harj...@google.com> Date: Wed, 13 Aug 2025 16:04:23 -0700 Subject: [PATCH 2/2] Addressing reviewer comments. 
--- lldb/include/lldb/Host/JSONTransport.h | 27 ++++++++++------ lldb/tools/lldb-dap/DAP.cpp | 44 ++++++++++++++++---------- lldb/tools/lldb-dap/DAP.h | 2 +- lldb/tools/lldb-dap/tool/lldb-dap.cpp | 5 +-- 4 files changed, 48 insertions(+), 30 deletions(-) diff --git a/lldb/include/lldb/Host/JSONTransport.h b/lldb/include/lldb/Host/JSONTransport.h index 18126f599c380..f160599243de7 100644 --- a/lldb/include/lldb/Host/JSONTransport.h +++ b/lldb/include/lldb/Host/JSONTransport.h @@ -139,9 +139,8 @@ class JSONTransport : public Transport<Req, Resp, Evt> { std::string output = Encode(message); size_t bytes_written = output.size(); Status status = m_out->Write(output.data(), bytes_written); - if (status.Fail()) { - this->Logv("writing failed: s{0}", status.AsCString()); - } + if (status.Fail()) + this->Logv("writing failed: {0}", status.AsCString()); } llvm::SmallString<kReadBufferSize> m_buffer; @@ -170,8 +169,8 @@ class JSONTransport : public Transport<Req, Resp, Evt> { return; } - for (const auto &raw_message : *raw_messages) { - auto message = + for (const std::string &raw_message : *raw_messages) { + llvm::Expected<typename Transport<Req, Resp, Evt>::Message> message = llvm::json::parse<typename Transport<Req, Resp, Evt>::Message>( raw_message); if (!message) { @@ -182,13 +181,20 @@ class JSONTransport : public Transport<Req, Resp, Evt> { if (Evt *evt = std::get_if<Evt>(&*message)) { handler.OnEvent(*evt); - } else if (Req *req = std::get_if<Req>(&*message)) { + continue; + } + + if (Req *req = std::get_if<Req>(&*message)) { handler.OnRequest(*req); - } else if (Resp *resp = std::get_if<Resp>(&*message)) { + continue; + } + + if (Resp *resp = std::get_if<Resp>(&*message)) { handler.OnResponse(*resp); - } else { - llvm_unreachable("unknown message type"); + continue; } + + llvm_unreachable("unknown message type"); } } @@ -235,7 +241,8 @@ class HTTPDelimitedJSONTransport : public JSONTransport<Req, Resp, Evt> { auto [headers, rest] = buffer.split(kEndOfHeader); size_t 
content_length = 0; // HTTP Headers are formatted like `<field-name> ':' [<field-value>]`. - for (const auto &header : llvm::split(headers, kHeaderSeparator)) { + for (const llvm::StringRef &header : + llvm::split(headers, kHeaderSeparator)) { auto [key, value] = header.split(kHeaderFieldSeparator); // 'Content-Length' is the only meaningful key at the moment. Others are // ignored. diff --git a/lldb/tools/lldb-dap/DAP.cpp b/lldb/tools/lldb-dap/DAP.cpp index a9a0fe75a35b7..2eced4f78fbd3 100644 --- a/lldb/tools/lldb-dap/DAP.cpp +++ b/lldb/tools/lldb-dap/DAP.cpp @@ -268,24 +268,33 @@ void DAP::SendJSON(const llvm::json::Value &json) { void DAP::Send(const Message &message) { if (const protocol::Event *event = std::get_if<protocol::Event>(&message)) { transport.Event(*event); - } else if (const Request *req = std::get_if<Request>(&message)) { + return; + } + + if (const Request *req = std::get_if<Request>(&message)) { transport.Request(*req); - } else if (const Response *resp = std::get_if<Response>(&message)) { + return; + } + + if (const Response *resp = std::get_if<Response>(&message)) { // FIXME: After all the requests have migrated from LegacyRequestHandler > // RequestHandler<> this should be handled in RequestHandler<>::operator(). - if (debugger.InterruptRequested()) - // If the debugger was interrupted, convert this response into a - // 'cancelled' response because we might have a partial result. + + // If the debugger was interrupted, convert this response into a + // 'cancelled' response because we might have a partial result. 
+ if (debugger.InterruptRequested()) { transport.Response(Response{/*request_seq=*/resp->request_seq, /*command=*/resp->command, /*success=*/false, /*message=*/eResponseMessageCancelled, /*body=*/std::nullopt}); - else + } else { transport.Response(*resp); - } else { - llvm_unreachable("Unexpected message type"); + } + return; } + + llvm_unreachable("Unexpected message type"); } // "OutputEvent": { @@ -916,7 +925,8 @@ llvm::Error DAP::Disconnect(bool terminateDebuggee) { SendTerminatedEvent(); - disconnecting = true; + std::lock_guard<std::mutex> guard(m_queue_mutex); + m_disconnecting = true; return ToError(error); } @@ -952,7 +962,7 @@ void DAP::OnEvent(const protocol::Event &event) { void DAP::OnRequest(const protocol::Request &request) { if (request.command == "disconnect") - disconnecting = true; + m_disconnecting = true; const std::optional<CancelArguments> cancel_args = getArgumentsIfRequest<CancelArguments>(request, "cancel"); @@ -990,12 +1000,12 @@ void DAP::OnResponse(const protocol::Response &response) { void DAP::TransportHandler(llvm::Error *error) { llvm::ErrorAsOutParameter ErrAsOutParam(*error); - auto cleanup = llvm::make_scope_exit([&]() { - // Ensure we're marked as disconnecting when the reader exits. - disconnecting = true; - m_queue_cv.notify_all(); - }); *error = transport.Run(m_loop, *this); + + std::lock_guard<std::mutex> guard(m_queue_mutex); + // Ensure we're marked as disconnecting when the reader exits. 
+ m_disconnecting = true; + m_queue_cv.notify_all(); } llvm::Error DAP::Loop() { @@ -1010,9 +1020,9 @@ llvm::Error DAP::Loop() { while (true) { std::unique_lock<std::mutex> lock(m_queue_mutex); - m_queue_cv.wait(lock, [&] { return disconnecting || !m_queue.empty(); }); + m_queue_cv.wait(lock, [&] { return m_disconnecting || !m_queue.empty(); }); - if (disconnecting && m_queue.empty()) + if (m_disconnecting && m_queue.empty()) break; Message next = m_queue.front(); diff --git a/lldb/tools/lldb-dap/DAP.h b/lldb/tools/lldb-dap/DAP.h index 628f97257d5f0..71ec6a7faf7bc 100644 --- a/lldb/tools/lldb-dap/DAP.h +++ b/lldb/tools/lldb-dap/DAP.h @@ -118,7 +118,6 @@ struct DAP final : private DAPTransport::MessageHandler { /// The focused thread for this DAP session. lldb::tid_t focus_tid = LLDB_INVALID_THREAD_ID; - bool disconnecting = false; llvm::once_flag terminated_event_flag; bool stop_at_entry = false; bool is_attach = false; @@ -467,6 +466,7 @@ struct DAP final : private DAPTransport::MessageHandler { std::deque<protocol::Message> m_queue; std::mutex m_queue_mutex; std::condition_variable m_queue_cv; + bool m_disconnecting = false; // Loop for managing reading from the client. 
lldb_private::MainLoop &m_loop; diff --git a/lldb/tools/lldb-dap/tool/lldb-dap.cpp b/lldb/tools/lldb-dap/tool/lldb-dap.cpp index c728b0af94c7c..b74085f25f4e2 100644 --- a/lldb/tools/lldb-dap/tool/lldb-dap.cpp +++ b/lldb/tools/lldb-dap/tool/lldb-dap.cpp @@ -39,6 +39,7 @@ #include "llvm/Support/PrettyStackTrace.h" #include "llvm/Support/Signals.h" #include "llvm/Support/Threading.h" +#include "llvm/Support/WithColor.h" #include "llvm/Support/raw_ostream.h" #include <condition_variable> #include <cstdio> @@ -349,8 +350,8 @@ serveConnection(const Socket::SocketProtocol &protocol, const std::string &name, for (auto [loop, dap] : dap_sessions) { if (llvm::Error error = dap->Disconnect()) { client_failed = true; - llvm::errs() << "DAP client disconnected failed: " - << llvm::toString(std::move(error)) << "\n"; + llvm::WithColor::error() << "DAP client disconnected failed: " + << llvm::toString(std::move(error)) << "\n"; } loop->AddPendingCallback( [](MainLoopBase &loop) { loop.RequestTermination(); }); _______________________________________________ lldb-commits mailing list lldb-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/lldb-commits