mkatanbaf commented on code in PR #10967:
URL: https://github.com/apache/tvm/pull/10967#discussion_r864276388


##########
src/runtime/minrpc/minrpc_server.h:
##########
@@ -58,95 +63,133 @@ class PageAllocator;
 }
 
 /*!
- * \brief A minimum RPC server that only depends on the tvm C runtime..
- *
- *  All the dependencies are provided by the io arguments.
+ * \brief Responses to a minimum RPC command.
  *
  * \tparam TIOHandler IO provider to provide io handling.
- *         An IOHandler needs to provide the following functions:
- *         - PosixWrite, PosixRead, Close: posix style, read, write, close API.
- *         - MessageStart(num_bytes), MessageDone(): framing APIs.
- *         - Exit: exit with status code.
  */
-template <typename TIOHandler, template <typename> class Allocator = 
detail::PageAllocator>
-class MinRPCServer {
+template <typename TIOHandler>
+class MinRPCReturns : public MinRPCReturnInterface {
  public:
-  using PageAllocator = Allocator<TIOHandler>;
-
   /*!
    * \brief Constructor.
    * \param io The IO handler.
    */
-  explicit MinRPCServer(TIOHandler* io) : io_(io), arena_(PageAllocator(io)) {}
+  explicit MinRPCReturns(TIOHandler* io) : io_(io) {}
 
-  /*! \brief Process a single request.
-   *
-   * \return true when the server should continue processing requests. false 
when it should be
-   *  shutdown.
-   */
-  bool ProcessOnePacket() {
-    RPCCode code;
-    uint64_t packet_len;
+  void ReturnVoid() {
+    int32_t num_args = 1;
+    int32_t tcode = kTVMNullptr;
+    RPCCode code = RPCCode::kReturn;
 
-    arena_.RecycleAll();
-    allow_clean_shutdown_ = true;
+    uint64_t packet_nbytes = sizeof(code) + sizeof(num_args) + sizeof(tcode);
 
-    this->Read(&packet_len);
-    if (packet_len == 0) return true;
-    this->Read(&code);
+    io_->MessageStart(packet_nbytes);
+    Write(packet_nbytes);
+    Write(code);
+    Write(num_args);
+    Write(tcode);
+    io_->MessageDone();
+  }
 
-    allow_clean_shutdown_ = false;
+  void ReturnHandle(void* handle) {
+    int32_t num_args = 1;
+    int32_t tcode = kTVMOpaqueHandle;
+    RPCCode code = RPCCode::kReturn;
+    uint64_t encode_handle = reinterpret_cast<uint64_t>(handle);
+    uint64_t packet_nbytes =
+        sizeof(code) + sizeof(num_args) + sizeof(tcode) + 
sizeof(encode_handle);
 
-    if (code >= RPCCode::kSyscallCodeStart) {
-      this->HandleSyscallFunc(code);
-    } else {
-      switch (code) {
-        case RPCCode::kCallFunc: {
-          HandleNormalCallFunc();
-          break;
-        }
-        case RPCCode::kInitServer: {
-          HandleInitServer();
-          break;
-        }
-        case RPCCode::kCopyFromRemote: {
-          HandleCopyFromRemote();
-          break;
-        }
-        case RPCCode::kCopyToRemote: {
-          HandleCopyToRemote();
-          break;
-        }
-        case RPCCode::kShutdown: {
-          this->Shutdown();
-          return false;
-        }
-        default: {
-          this->ThrowError(RPCServerStatus::kUnknownRPCCode);
-          break;
-        }
+    io_->MessageStart(packet_nbytes);
+    Write(packet_nbytes);
+    Write(code);
+    Write(num_args);
+    Write(tcode);
+    Write(encode_handle);
+    io_->MessageDone();
+  }
+
+  void ReturnException(const char* msg) { RPCReference::ReturnException(msg, 
this); }
+
+  void ReturnPackedSeq(const TVMValue* arg_values, const int* type_codes, int 
num_args) {
+    RPCReference::ReturnPackedSeq(arg_values, type_codes, num_args, this);
+  }
+
+  void ReturnCopyFromRemote(uint8_t* data_ptr, uint64_t num_bytes) {
+    RPCCode code = RPCCode::kCopyAck;
+    uint64_t packet_nbytes = sizeof(code) + num_bytes;
+
+    io_->MessageStart(packet_nbytes);
+    Write(packet_nbytes);
+    Write(code);
+    WriteArray(data_ptr, num_bytes);
+    io_->MessageDone();
+  }
+
+  void ReturnLastTVMError() {
+    const char* err = TVMGetLastError();
+    ReturnException(err);
+  }
+
+  void MessageStart(uint64_t packet_nbytes) { 
io_->MessageStart(packet_nbytes); }
+
+  void MessageDone() { io_->MessageDone(); }
+
+  void ThrowError(RPCServerStatus code, RPCCode info = RPCCode::kNone) {
+    io_->Exit(static_cast<int>(code));
+  }
+
+  template <typename T>
+  void Write(const T& data) {
+    static_assert(std::is_trivial<T>::value && 
std::is_standard_layout<T>::value,
+                  "need to be trival");
+    return WriteRawBytes(&data, sizeof(T));
+  }
+
+  template <typename T>
+  void WriteArray(T* data, size_t count) {
+    static_assert(std::is_trivial<T>::value && 
std::is_standard_layout<T>::value,
+                  "need to be trival");
+    return WriteRawBytes(data, sizeof(T) * count);
+  }
+
+ private:
+  void WriteRawBytes(const void* data, size_t size) {
+    const uint8_t* buf = static_cast<const uint8_t*>(data);
+    size_t ndone = 0;
+    while (ndone < size) {
+      ssize_t ret = io_->PosixWrite(buf, size - ndone);
+      if (ret == 0 || ret == -1) {

Review Comment:
   done



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
users@infra.apache.org

Reply via email to