This is an automated email from the ASF dual-hosted git repository.
ruihangl pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 7275cf0eec [REFACTOR][FFI][RPC] Migrate RPC to use the latest FFI ABI (#17931)
7275cf0eec is described below
commit 7275cf0eec17de59d518ede59218140e0e48a17a
Author: Tianqi Chen <[email protected]>
AuthorDate: Fri May 9 10:03:01 2025 -0400
[REFACTOR][FFI][RPC] Migrate RPC to use the latest FFI ABI (#17931)
This PR migrates the RPC to use the latest FFI ABI directly instead of going
through the legacy translation layer.
---
.github/actions/setup/action.yml | 15 +-
conda/build-environment.yaml | 2 -
src/runtime/disco/message_queue.h | 13 +-
src/runtime/disco/protocol.h | 16 +-
src/runtime/disco/threaded_session.cc | 16 +-
src/runtime/metal/metal_module.mm | 10 +-
src/runtime/minrpc/minrpc_interfaces.h | 93 ---
src/runtime/minrpc/minrpc_logger.cc | 291 ---------
src/runtime/minrpc/minrpc_logger.h | 296 ---------
src/runtime/minrpc/minrpc_server.h | 719 ++-------------------
src/runtime/minrpc/minrpc_server_logging.h | 170 -----
.../posix_popen_server/posix_popen_server.cc | 3 -
src/runtime/minrpc/rpc_reference.h | 207 +++---
src/runtime/rpc/rpc_channel_logger.h | 186 ------
src/runtime/rpc/rpc_endpoint.cc | 75 +--
src/runtime/rpc/rpc_endpoint.h | 1 -
src/runtime/rpc/rpc_local_session.cc | 25 +-
src/runtime/rpc/rpc_module.cc | 22 +-
src/runtime/rpc/rpc_socket_impl.cc | 3 -
tests/python/runtime/test_runtime_rpc.py | 19 -
20 files changed, 201 insertions(+), 1981 deletions(-)
diff --git a/.github/actions/setup/action.yml b/.github/actions/setup/action.yml
index 6fd81c1d69..cd7fd9197f 100644
--- a/.github/actions/setup/action.yml
+++ b/.github/actions/setup/action.yml
@@ -3,7 +3,7 @@ runs:
steps:
- uses: actions/cache@v3
env:
- CACHE_NUMBER: 1
+ CACHE_NUMBER: 2
with:
path: ~/conda_pkgs_dir
key: ${{ runner.os }}-conda-${{ env.CACHE_NUMBER }}-${{
hashFiles('conda/build-environment.yaml') }}
@@ -15,8 +15,7 @@ runs:
channel-priority: strict
environment-file: conda/build-environment.yaml
auto-activate-base: false
- conda-solver: classic
- use-only-tar-bz2: true
+ miniforge-version: latest
python-version: 3.9
condarc-file: conda/condarc
- uses: conda-incubator/setup-miniconda@v3
@@ -26,14 +25,14 @@ runs:
channel-priority: strict
environment-file: conda/build-environment.yaml
auto-activate-base: false
- conda-solver: classic
+ miniforge-version: latest
use-only-tar-bz2: true
python-version: 3.9
condarc-file: conda/condarc
- name: Conda info
shell: pwsh
run: |
- conda info
- conda list
- conda info --envs
- conda list --name base
+ mamba info
+ mamba list
+ mamba info --envs
+ mamba list --name base
diff --git a/conda/build-environment.yaml b/conda/build-environment.yaml
index de4e6f4234..716b2198fa 100644
--- a/conda/build-environment.yaml
+++ b/conda/build-environment.yaml
@@ -20,12 +20,10 @@ name: tvm-build
# The conda channels to lookup the dependencies
channels:
- - anaconda
- conda-forge
# The packages to install to the environment
dependencies:
- - python=3.9
- conda < 24.9.0
- conda-build < 24.9.0
- git
diff --git a/src/runtime/disco/message_queue.h b/src/runtime/disco/message_queue.h
index 6b3600acbb..0fa793c3ab 100644
--- a/src/runtime/disco/message_queue.h
+++ b/src/runtime/disco/message_queue.h
@@ -37,12 +37,9 @@ class DiscoStreamMessageQueue : private dmlc::Stream,
~DiscoStreamMessageQueue() = default;
void Send(const ffi::PackedArgs& args) {
- // Run legacy ABI translation.
- std::vector<TVMValue> values(args.size());
- std::vector<int> type_codes(args.size());
- PackedArgsToLegacyTVMArgs(args.data(), args.size(), values.data(),
type_codes.data());
// TODO(tqchen): use native convention that do not need ABI translation.
- RPCReference::ReturnPackedSeq(values.data(), type_codes.data(),
args.size(), this);
+ RPCReference::ReturnPackedSeq(reinterpret_cast<const
TVMFFIAny*>(args.data()), args.size(),
+ this);
CommitSendAndNotifyEnqueue();
}
@@ -57,11 +54,7 @@ class DiscoStreamMessageQueue : private dmlc::Stream,
packed_args[0] = static_cast<int>(DiscoAction::kShutDown);
packed_args[1] = 0;
} else {
- TVMValue* values = nullptr;
- int* type_codes = nullptr;
- RPCReference::RecvPackedSeq(&values, &type_codes, &num_args, this);
- packed_args =
reinterpret_cast<AnyView*>(ArenaAlloc<TVMFFIAny>(num_args));
- LegacyTVMArgsToPackedArgs(values, type_codes, num_args, packed_args);
+ RPCReference::RecvPackedSeq(reinterpret_cast<TVMFFIAny**>(&packed_args),
&num_args, this);
}
return ffi::PackedArgs(packed_args, num_args);
}
diff --git a/src/runtime/disco/protocol.h b/src/runtime/disco/protocol.h
index c14f7cdaba..f0e6cd28a3 100644
--- a/src/runtime/disco/protocol.h
+++ b/src/runtime/disco/protocol.h
@@ -61,7 +61,7 @@ struct DiscoProtocol {
inline void WriteObject(Object* obj);
/*! \brief Read the object from stream. Used by RPCReference. */
- inline void ReadObject(int* tcode, TVMValue* value);
+ inline void ReadObject(TVMFFIAny* out);
/*! \brief Callback method used when starting a new message. Used by
RPCReference. */
void MessageStart(uint64_t packet_nbytes) {}
@@ -149,7 +149,7 @@ inline void
DiscoProtocol<SubClassType>::WriteObject(Object* obj) {
self->template Write<int64_t>(reg_id);
} else if (obj->IsInstance<ffi::StringObj>()) {
ffi::StringObj* str = static_cast<ffi::StringObj*>(obj);
- self->template Write<uint32_t>(TypeIndex::kRuntimeString);
+ self->template Write<uint32_t>(ffi::TypeIndex::kTVMFFIStr);
self->template Write<uint64_t>(str->size);
self->template WriteArray<char>(str->data, str->size);
} else if (obj->IsInstance<ffi::BytesObj>()) {
@@ -159,7 +159,7 @@ inline void
DiscoProtocol<SubClassType>::WriteObject(Object* obj) {
self->template WriteArray<char>(bytes->data, bytes->size);
} else if (obj->IsInstance<ffi::ShapeObj>()) {
ffi::ShapeObj* shape = static_cast<ffi::ShapeObj*>(obj);
- self->template Write<uint32_t>(TypeIndex::kRuntimeShape);
+ self->template Write<uint32_t>(ffi::TypeIndex::kTVMFFIShape);
self->template Write<uint64_t>(shape->size);
self->template WriteArray<ffi::ShapeObj::index_type>(shape->data,
shape->size);
} else if (obj->IsInstance<DiscoDebugObject>()) {
@@ -174,7 +174,7 @@ inline void
DiscoProtocol<SubClassType>::WriteObject(Object* obj) {
}
template <class SubClassType>
-inline void DiscoProtocol<SubClassType>::ReadObject(int* tcode, TVMValue*
value) {
+inline void DiscoProtocol<SubClassType>::ReadObject(TVMFFIAny* out) {
SubClassType* self = static_cast<SubClassType*>(this);
ObjectRef result{nullptr};
uint32_t type_index;
@@ -184,7 +184,7 @@ inline void DiscoProtocol<SubClassType>::ReadObject(int*
tcode, TVMValue* value)
self->template Read<int64_t>(&dref->reg_id);
dref->session = Session{nullptr};
result = ObjectRef(std::move(dref));
- } else if (type_index == TypeIndex::kRuntimeString) {
+ } else if (type_index == ffi::TypeIndex::kTVMFFIStr) {
uint64_t size = 0;
self->template Read<uint64_t>(&size);
std::string data(size, '\0');
@@ -196,7 +196,7 @@ inline void DiscoProtocol<SubClassType>::ReadObject(int*
tcode, TVMValue* value)
std::string data(size, '\0');
self->template ReadArray<char>(data.data(), size);
result = ffi::Bytes(std::move(data));
- } else if (type_index == TypeIndex::kRuntimeShape) {
+ } else if (type_index == ffi::TypeIndex::kTVMFFIShape) {
uint64_t ndim = 0;
self->template Read<uint64_t>(&ndim);
std::vector<ffi::ShapeObj::index_type> data(ndim);
@@ -212,9 +212,7 @@ inline void DiscoProtocol<SubClassType>::ReadObject(int*
tcode, TVMValue* value)
LOG(FATAL) << "ValueError: Object type is not supported in Disco calling
convention: "
<< Object::TypeIndex2Key(type_index) << " (type_index = " <<
type_index << ")";
}
- // translate AnyView to legacy TVMValue and type_code
- AnyView res_view = result;
- AnyViewToLegacyTVMArgValue(res_view.CopyToTVMFFIAny(), value, tcode);
+ *reinterpret_cast<AnyView*>(out) = result;
object_arena_.push_back(result);
}
diff --git a/src/runtime/disco/threaded_session.cc b/src/runtime/disco/threaded_session.cc
index f40fae007e..1f13184102 100644
--- a/src/runtime/disco/threaded_session.cc
+++ b/src/runtime/disco/threaded_session.cc
@@ -41,24 +41,16 @@ class DiscoThreadedMessageQueue : private dmlc::Stream,
private
DiscoProtocol<DiscoThreadedMessageQueue> {
public:
void Send(const ffi::PackedArgs& args) {
- // Run legacy ABI translation.
- std::vector<TVMValue> values(args.size());
- std::vector<int> type_codes(args.size());
- PackedArgsToLegacyTVMArgs(args.data(), args.size(), values.data(),
type_codes.data());
- // TODO(tqchen): use native convention that do not need ABI translation.
- RPCReference::ReturnPackedSeq(values.data(), type_codes.data(),
args.size(), this);
+ RPCReference::ReturnPackedSeq(reinterpret_cast<const
TVMFFIAny*>(args.data()), args.size(),
+ this);
CommitSendAndNotifyEnqueue();
}
ffi::PackedArgs Recv() {
DequeueNextPacket();
- TVMValue* values = nullptr;
- int* type_codes = nullptr;
+ AnyView* packed_args = nullptr;
int num_args = 0;
- RPCReference::RecvPackedSeq(&values, &type_codes, &num_args, this);
- // Run legacy ABI translation.
- AnyView* packed_args =
reinterpret_cast<AnyView*>(ArenaAlloc<TVMFFIAny>(num_args));
- LegacyTVMArgsToPackedArgs(values, type_codes, num_args, packed_args);
+ RPCReference::RecvPackedSeq(reinterpret_cast<TVMFFIAny**>(&packed_args),
&num_args, this);
return ffi::PackedArgs(packed_args, num_args);
}
diff --git a/src/runtime/metal/metal_module.mm b/src/runtime/metal/metal_module.mm
index cc25fd8b0d..36062cae39 100644
--- a/src/runtime/metal/metal_module.mm
+++ b/src/runtime/metal/metal_module.mm
@@ -260,23 +260,23 @@ class MetalWrappedFunc {
ffi::Function MetalModuleNode::GetFunction(const String& name,
const ObjectPtr<Object>&
sptr_to_self) {
- ffi::Function f;
+ ffi::Function ret;
AUTORELEASEPOOL {
ICHECK_EQ(sptr_to_self.get(), this);
ICHECK_NE(name, symbol::tvm_module_main) << "Device function do not have
main";
auto it = fmap_.find(name);
if (it == fmap_.end()) {
- f = ffi::Function();
- return;
+ ret = ffi::Function();
+ return ret;
}
const FunctionInfo& info = it->second;
MetalWrappedFunc f;
size_t num_buffer_args = NumBufferArgs(info.arg_types);
f.Init(this, sptr_to_self, name, num_buffer_args, info.arg_types.size() -
num_buffer_args,
info.launch_param_tags);
- pf = PackFuncNonBufferArg(f, info.arg_types);
+ ret = PackFuncNonBufferArg(f, info.arg_types);
};
- return pf;
+ return ret;
}
Module MetalModuleCreate(std::unordered_map<std::string, std::string> smap,
diff --git a/src/runtime/minrpc/minrpc_interfaces.h b/src/runtime/minrpc/minrpc_interfaces.h
deleted file mode 100644
index a45dee9f2c..0000000000
--- a/src/runtime/minrpc/minrpc_interfaces.h
+++ /dev/null
@@ -1,93 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-#ifndef TVM_RUNTIME_MINRPC_MINRPC_INTERFACES_H_
-#define TVM_RUNTIME_MINRPC_MINRPC_INTERFACES_H_
-
-#include <tvm/runtime/c_runtime_api.h>
-
-#include "rpc_reference.h"
-
-namespace tvm {
-namespace runtime {
-
-/*!
- * \brief Return interface used in ExecInterface to generate and send the
responses.
- */
-class MinRPCReturnInterface {
- public:
- virtual ~MinRPCReturnInterface() {}
- /*! * \brief sends a response to the client with kTVMNullptr in payload. */
- virtual void ReturnVoid() = 0;
-
- /*! * \brief sends a response to the client with one kTVMOpaqueHandle in
payload. */
- virtual void ReturnHandle(void* handle) = 0;
-
- /*! * \brief sends an exception response to the client with a kTVMStr in
payload. */
- virtual void ReturnException(const char* msg) = 0;
-
- /*! * \brief sends a packed argument sequnce to the client. */
- virtual void ReturnPackedSeq(const TVMValue* arg_values, const int*
type_codes, int num_args) = 0;
-
- /*! * \brief sends a copy of the requested remote data to the client. */
- virtual void ReturnCopyFromRemote(uint8_t* data_ptr, uint64_t num_bytes) = 0;
-
- /*! * \brief sends an exception response to the client with the last TVM
erros as the message. */
- virtual void ReturnLastTVMError() = 0;
-
- /*! * \brief internal error. */
- virtual void ThrowError(RPCServerStatus code, RPCCode info = RPCCode::kNone)
= 0;
-};
-
-/*!
- * \brief Execute interface used in MinRPCServer to process different received
commands
- */
-class MinRPCExecInterface {
- public:
- virtual ~MinRPCExecInterface() {}
-
- /*! * \brief Execute an Initilize server command. */
- virtual void InitServer(int num_args) = 0;
-
- /*! * \brief calls a function specified by the call_handle. */
- virtual void NormalCallFunc(uint64_t call_handle, TVMValue* values, int*
tcodes,
- int num_args) = 0;
-
- /*! * \brief Execute a copy from remote command by sending the data
described in arr to the client
- */
- virtual void CopyFromRemote(DLTensor* arr, uint64_t num_bytes, uint8_t*
data_ptr) = 0;
-
- /*! * \brief Execute a copy to remote command by receiving the data
described in arr from the
- * client */
- virtual int CopyToRemote(DLTensor* arr, uint64_t num_bytes, uint8_t*
data_ptr) = 0;
-
- /*! * \brief calls a system function specified by the code. */
- virtual void SysCallFunc(RPCCode code, TVMValue* values, int* tcodes, int
num_args) = 0;
-
- /*! * \brief internal error. */
- virtual void ThrowError(RPCServerStatus code, RPCCode info = RPCCode::kNone)
= 0;
-
- /*! * \brief return the ReturnInterface pointer that is used to generate and
send the responses.
- */
- virtual MinRPCReturnInterface* GetReturnInterface() = 0;
-};
-
-} // namespace runtime
-} // namespace tvm
-#endif // TVM_RUNTIME_MINRPC_MINRPC_INTERFACES_H_
diff --git a/src/runtime/minrpc/minrpc_logger.cc b/src/runtime/minrpc/minrpc_logger.cc
deleted file mode 100644
index 4f3b7e764c..0000000000
--- a/src/runtime/minrpc/minrpc_logger.cc
+++ /dev/null
@@ -1,291 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-#include "minrpc_logger.h"
-
-#include <string.h>
-#include <time.h>
-#include <tvm/runtime/c_runtime_api.h>
-#include <tvm/runtime/logging.h>
-
-#include <functional>
-#include <iostream>
-#include <sstream>
-#include <unordered_map>
-
-#include "minrpc_interfaces.h"
-#include "rpc_reference.h"
-
-namespace tvm {
-namespace runtime {
-
-void Logger::LogTVMValue(int tcode, TVMValue value) {
- switch (tcode) {
- case kDLInt: {
- LogValue<int64_t>("(int64)", value.v_int64);
- break;
- }
- case kDLUInt: {
- LogValue<uint64_t>("(uint64)", value.v_int64);
- break;
- }
- case kDLFloat: {
- LogValue<float>("(float)", value.v_float64);
- break;
- }
- case kTVMDataType: {
- LogDLData("DLDataType(code,bits,lane)", &value.v_type);
- break;
- }
- case kDLDevice: {
- LogDLDevice("DLDevice(type,id)", &value.v_device);
- break;
- }
- case kTVMPackedFuncHandle: {
- LogValue<void*>("(PackedFuncHandle)", value.v_handle);
- break;
- }
- case kTVMModuleHandle: {
- LogValue<void*>("(ModuleHandle)", value.v_handle);
- break;
- }
- case kTVMOpaqueHandle: {
- LogValue<void*>("(OpaqueHandle)", value.v_handle);
- break;
- }
- case kTVMDLTensorHandle: {
- LogValue<void*>("(TensorHandle)", value.v_handle);
- break;
- }
- case kTVMNDArrayHandle: {
- LogValue<void*>("kTVMNDArrayHandle", value.v_handle);
- break;
- }
- case kTVMNullptr: {
- Log("Nullptr");
- break;
- }
- case kTVMStr: {
- Log("\"");
- Log(value.v_str);
- Log("\"");
- break;
- }
- case kTVMBytes: {
- TVMByteArray* bytes = static_cast<TVMByteArray*>(value.v_handle);
- int len = bytes->size;
- LogValue<int64_t>("(Bytes) [size]: ", len);
- if (PRINT_BYTES) {
- Log(", [Values]:");
- Log(" { ");
- if (len > 0) {
- LogValue<uint64_t>("", (uint8_t)bytes->data[0]);
- }
- for (int j = 1; j < len; j++) LogValue<uint64_t>(" - ",
(uint8_t)bytes->data[j]);
- Log(" } ");
- }
- break;
- }
- default: {
- Log("ERROR-kUnknownTypeCode)");
- break;
- }
- }
- Log("; ");
-}
-
-void Logger::OutputLog() {
- LOG(INFO) << os_.str();
- os_.str(std::string());
-}
-
-void MinRPCReturnsWithLog::ReturnVoid() {
- next_->ReturnVoid();
- logger_->Log("-> ReturnVoid");
- logger_->OutputLog();
-}
-
-void MinRPCReturnsWithLog::ReturnHandle(void* handle) {
- next_->ReturnHandle(handle);
- if (code_ == RPCCode::kGetGlobalFunc) {
- RegisterHandleName(handle);
- }
- logger_->LogValue<void*>("-> ReturnHandle: ", handle);
- logger_->OutputLog();
-}
-
-void MinRPCReturnsWithLog::ReturnException(const char* msg) {
- next_->ReturnException(msg);
- logger_->Log("-> Exception: ");
- logger_->Log(msg);
- logger_->OutputLog();
-}
-
-void MinRPCReturnsWithLog::ReturnPackedSeq(const TVMValue* arg_values, const
int* type_codes,
- int num_args) {
- next_->ReturnPackedSeq(arg_values, type_codes, num_args);
- ProcessValues(arg_values, type_codes, num_args);
- logger_->OutputLog();
-}
-
-void MinRPCReturnsWithLog::ReturnCopyFromRemote(uint8_t* data_ptr, uint64_t
num_bytes) {
- next_->ReturnCopyFromRemote(data_ptr, num_bytes);
- logger_->LogValue<uint64_t>("-> CopyFromRemote: ", num_bytes);
- logger_->LogValue<void*>(", ", static_cast<void*>(data_ptr));
- logger_->OutputLog();
-}
-
-void MinRPCReturnsWithLog::ReturnLastTVMError() {
- const char* err = TVMGetLastError();
- ReturnException(err);
-}
-
-void MinRPCReturnsWithLog::ThrowError(RPCServerStatus code, RPCCode info) {
- next_->ThrowError(code, info);
- logger_->Log("-> ERROR: ");
- logger_->Log(RPCServerStatusToString(code));
- logger_->OutputLog();
-}
-
-void MinRPCReturnsWithLog::ProcessValues(const TVMValue* values, const int*
tcodes, int num_args) {
- if (tcodes != nullptr) {
- logger_->Log("-> [");
- for (int i = 0; i < num_args; ++i) {
- logger_->LogTVMValue(tcodes[i], values[i]);
-
- if (tcodes[i] == kTVMOpaqueHandle) {
- RegisterHandleName(values[i].v_handle);
- }
- }
- logger_->Log("]");
- }
-}
-
-void MinRPCReturnsWithLog::ResetHandleName(RPCCode code) {
- code_ = code;
- handle_name_.clear();
-}
-
-void MinRPCReturnsWithLog::UpdateHandleName(const char* name) {
- if (handle_name_.length() != 0) {
- handle_name_.append("::");
- }
- handle_name_.append(name);
-}
-
-void MinRPCReturnsWithLog::GetHandleName(void* handle) {
- if (handle_descriptions_.find(handle) != handle_descriptions_.end()) {
- handle_name_.append(handle_descriptions_[handle]);
- logger_->LogHandleName(handle_name_);
- }
-}
-
-void MinRPCReturnsWithLog::ReleaseHandleName(void* handle) {
- if (handle_descriptions_.find(handle) != handle_descriptions_.end()) {
- logger_->LogHandleName(handle_descriptions_[handle]);
- handle_descriptions_.erase(handle);
- }
-}
-
-void MinRPCReturnsWithLog::RegisterHandleName(void* handle) {
- handle_descriptions_[handle] = handle_name_;
-}
-
-void MinRPCExecuteWithLog::InitServer(int num_args) {
- SetRPCCode(RPCCode::kInitServer);
- logger_->Log("Init Server");
- next_->InitServer(num_args);
-}
-
-void MinRPCExecuteWithLog::NormalCallFunc(uint64_t call_handle, TVMValue*
values, int* tcodes,
- int num_args) {
- SetRPCCode(RPCCode::kCallFunc);
- logger_->LogValue<void*>("call_handle: ",
reinterpret_cast<void*>(call_handle));
- ret_handler_->GetHandleName(reinterpret_cast<void*>(call_handle));
- if (num_args > 0) {
- logger_->Log(", ");
- }
- ProcessValues(values, tcodes, num_args);
- next_->NormalCallFunc(call_handle, values, tcodes, num_args);
-}
-
-void MinRPCExecuteWithLog::CopyFromRemote(DLTensor* arr, uint64_t num_bytes,
uint8_t* temp_data) {
- SetRPCCode(RPCCode::kCopyFromRemote);
- logger_->LogValue<void*>("data_handle: ", static_cast<void*>(arr->data));
- logger_->LogDLDevice(", DLDevice(type,id):", &(arr->device));
- logger_->LogValue<int64_t>(", ndim: ", arr->ndim);
- logger_->LogDLData(", DLDataType(code,bits,lane): ", &(arr->dtype));
- logger_->LogValue<uint64_t>(", num_bytes:", num_bytes);
- next_->CopyFromRemote(arr, num_bytes, temp_data);
-}
-
-int MinRPCExecuteWithLog::CopyToRemote(DLTensor* arr, uint64_t num_bytes,
uint8_t* data_ptr) {
- SetRPCCode(RPCCode::kCopyToRemote);
- logger_->LogValue<void*>("data_handle: ", static_cast<void*>(arr->data));
- logger_->LogDLDevice(", DLDevice(type,id):", &(arr->device));
- logger_->LogValue<int64_t>(", ndim: ", arr->ndim);
- logger_->LogDLData(", DLDataType(code,bits,lane): ", &(arr->dtype));
- logger_->LogValue<uint64_t>(", byte_offset: ", arr->byte_offset);
- return next_->CopyToRemote(arr, num_bytes, data_ptr);
-}
-
-void MinRPCExecuteWithLog::SysCallFunc(RPCCode code, TVMValue* values, int*
tcodes, int num_args) {
- SetRPCCode(code);
- if ((code) == RPCCode::kFreeHandle) {
- if ((num_args == 2) && (tcodes[0] == kTVMOpaqueHandle) && (tcodes[1] ==
kDLInt)) {
- logger_->LogValue<void*>("handle: ",
static_cast<void*>(values[0].v_handle));
- if (values[1].v_int64 == kTVMModuleHandle || values[1].v_int64 ==
kTVMPackedFuncHandle) {
-
ret_handler_->ReleaseHandleName(static_cast<void*>(values[0].v_handle));
- }
- }
- } else {
- ProcessValues(values, tcodes, num_args);
- }
- next_->SysCallFunc(code, values, tcodes, num_args);
-}
-
-void MinRPCExecuteWithLog::ThrowError(RPCServerStatus code, RPCCode info) {
- logger_->Log("-> Error\n");
- next_->ThrowError(code, info);
-}
-
-void MinRPCExecuteWithLog::ProcessValues(TVMValue* values, int* tcodes, int
num_args) {
- if (tcodes != nullptr) {
- logger_->Log("[");
- for (int i = 0; i < num_args; ++i) {
- logger_->LogTVMValue(tcodes[i], values[i]);
-
- if (tcodes[i] == kTVMStr) {
- if (strlen(values[i].v_str) > 0) {
- ret_handler_->UpdateHandleName(values[i].v_str);
- }
- }
- }
- logger_->Log("]");
- }
-}
-
-void MinRPCExecuteWithLog::SetRPCCode(RPCCode code) {
- logger_->Log(RPCCodeToString(code));
- logger_->Log(", ");
- ret_handler_->ResetHandleName(code);
-}
-
-} // namespace runtime
-} // namespace tvm
diff --git a/src/runtime/minrpc/minrpc_logger.h b/src/runtime/minrpc/minrpc_logger.h
deleted file mode 100644
index 13d44c3cba..0000000000
--- a/src/runtime/minrpc/minrpc_logger.h
+++ /dev/null
@@ -1,296 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-#ifndef TVM_RUNTIME_MINRPC_MINRPC_LOGGER_H_
-#define TVM_RUNTIME_MINRPC_MINRPC_LOGGER_H_
-
-#include <tvm/runtime/c_runtime_api.h>
-
-#include <functional>
-#include <sstream>
-#include <string>
-#include <unordered_map>
-
-#include "minrpc_interfaces.h"
-#include "rpc_reference.h"
-
-namespace tvm {
-namespace runtime {
-
-#define PRINT_BYTES false
-
-/*!
- * \brief Generates a user readeable log on the console
- */
-class Logger {
- public:
- Logger() {}
-
- /*!
- * \brief this function logs a string
- *
- * \param s the string to be logged.
- */
- void Log(const char* s) { os_ << s; }
- void Log(std::string s) { os_ << s; }
-
- /*!
- * \brief this function logs a numerical value
- *
- * \param desc adds any necessary description before the value.
- * \param val is the value to be logged.
- */
- template <typename T>
- void LogValue(const char* desc, T val) {
- os_ << desc << val;
- }
-
- /*!
- * \brief this function logs the properties of a DLDevice
- *
- * \param desc adds any necessary description before the DLDevice.
- * \param dev is the pointer to the DLDevice to be logged.
- */
- void LogDLDevice(const char* desc, DLDevice* dev) {
- os_ << desc << "(" << dev->device_type << "," << dev->device_id << ")";
- }
-
- /*!
- * \brief this function logs the properties of a DLDataType
- *
- * \param desc adds any necessary description before the DLDataType.
- * \param data is the pointer to the DLDataType to be logged.
- */
- void LogDLData(const char* desc, DLDataType* data) {
- os_ << desc << "(" << (uint16_t)data->code << "," << (uint16_t)data->bits
<< "," << data->lanes
- << ")";
- }
-
- /*!
- * \brief this function logs a handle name.
- *
- * \param name is the name to be logged.
- */
- void LogHandleName(std::string name) {
- if (name.length() > 0) {
- os_ << " <" << name.c_str() << ">";
- }
- }
-
- /*!
- * \brief this function logs a TVMValue based on its type.
- *
- * \param tcode the type_code of the value stored in TVMValue.
- * \param value is the TVMValue to be logged.
- */
- void LogTVMValue(int tcode, TVMValue value);
-
- /*!
- * \brief this function output the log to the console.
- */
- void OutputLog();
-
- private:
- std::stringstream os_;
-};
-
-/*!
- * \brief A wrapper for a MinRPCReturns object, that also logs the responses.
- *
- * \param next underlying MinRPCReturns that generates the responses.
- */
-class MinRPCReturnsWithLog : public MinRPCReturnInterface {
- public:
- /*!
- * \brief Constructor.
- * \param io The IO handler.
- */
- MinRPCReturnsWithLog(MinRPCReturnInterface* next, Logger* logger)
- : next_(next), logger_(logger) {}
-
- ~MinRPCReturnsWithLog() {}
-
- void ReturnVoid();
-
- void ReturnHandle(void* handle);
-
- void ReturnException(const char* msg);
-
- void ReturnPackedSeq(const TVMValue* arg_values, const int* type_codes, int
num_args);
-
- void ReturnCopyFromRemote(uint8_t* data_ptr, uint64_t num_bytes);
-
- void ReturnLastTVMError();
-
- void ThrowError(RPCServerStatus code, RPCCode info = RPCCode::kNone);
-
- /*!
- * \brief this function logs a list of TVMValues, and registers handle_name
when needed.
- *
- * \param values is the list of TVMValues.
- * \param tcodes is the list type_code of the TVMValues.
- * \param num_args is the number of items in the list.
- */
- void ProcessValues(const TVMValue* values, const int* tcodes, int num_args);
-
- /*!
- * \brief this function is called when a new command is executed.
- * It clears the handle_name_ and records the command code.
- *
- * \param code the RPC command code.
- */
- void ResetHandleName(RPCCode code);
-
- /*!
- * \brief appends name to the handle_name_.
- *
- * \param name handle name.
- */
- void UpdateHandleName(const char* name);
-
- /*!
- * \brief get the stored handle description.
- *
- * \param handle the handle to get the description for.
- */
- void GetHandleName(void* handle);
-
- /*!
- * \brief remove the handle description from handle_descriptions_.
- *
- * \param handle the handle to remove the description for.
- */
- void ReleaseHandleName(void* handle);
-
- private:
- /*!
- * \brief add the handle description to handle_descriptions_.
- *
- * \param handle the handle to add the description for.
- */
- void RegisterHandleName(void* handle);
-
- MinRPCReturnInterface* next_;
- std::string handle_name_;
- std::unordered_map<void*, std::string> handle_descriptions_;
- RPCCode code_;
- Logger* logger_;
-};
-
-/*!
- * \brief A wrapper for a MinRPCExecute object, that also logs the responses.
- *
- * \param next: underlying MinRPCExecute that processes the packets.
- */
-class MinRPCExecuteWithLog : public MinRPCExecInterface {
- public:
- MinRPCExecuteWithLog(MinRPCExecInterface* next, Logger* logger) :
next_(next), logger_(logger) {
- ret_handler_ =
reinterpret_cast<MinRPCReturnsWithLog*>(next_->GetReturnInterface());
- }
-
- ~MinRPCExecuteWithLog() {}
-
- void InitServer(int num_args);
-
- void NormalCallFunc(uint64_t call_handle, TVMValue* values, int* tcodes, int
num_args);
-
- void CopyFromRemote(DLTensor* arr, uint64_t num_bytes, uint8_t* temp_data);
-
- int CopyToRemote(DLTensor* arr, uint64_t _num_bytes, uint8_t* _data_ptr);
-
- void SysCallFunc(RPCCode code, TVMValue* values, int* tcodes, int num_args);
-
- void ThrowError(RPCServerStatus code, RPCCode info = RPCCode::kNone);
-
- MinRPCReturnInterface* GetReturnInterface() { return
next_->GetReturnInterface(); }
-
- private:
- /*!
- * \brief this function logs a list of TVMValues, and updates handle_name
when needed.
- *
- * \param values is the list of TVMValues.
- * \param tcodes is the list type_code of the TVMValues.
- * \param num_args is the number of items in the list.
- */
- void ProcessValues(TVMValue* values, int* tcodes, int num_args);
-
- /*!
- * \brief this function is called when a new command is executed.
- *
- * \param code the RPC command code.
- */
- void SetRPCCode(RPCCode code);
-
- MinRPCExecInterface* next_;
- MinRPCReturnsWithLog* ret_handler_;
- Logger* logger_;
-};
-
-/*!
- * \brief A No-operation MinRPCReturns used within the MinRPCSniffer
- *
- * \tparam TIOHandler* IO provider to provide io handling.
- */
-template <typename TIOHandler>
-class MinRPCReturnsNoOp : public MinRPCReturnInterface {
- public:
- /*!
- * \brief Constructor.
- * \param io The IO handler.
- */
- explicit MinRPCReturnsNoOp(TIOHandler* io) : io_(io) {}
- ~MinRPCReturnsNoOp() {}
- void ReturnVoid() {}
- void ReturnHandle(void* handle) {}
- void ReturnException(const char* msg) {}
- void ReturnPackedSeq(const TVMValue* arg_values, const int* type_codes, int
num_args) {}
- void ReturnCopyFromRemote(uint8_t* data_ptr, uint64_t num_bytes) {}
- void ReturnLastTVMError() {}
- void ThrowError(RPCServerStatus code, RPCCode info) {}
-
- private:
- TIOHandler* io_;
-};
-
-/*!
- * \brief A No-operation MinRPCExecute used within the MinRPCSniffer
- *
- * \tparam ReturnInterface* ReturnInterface pointer to generate and send the
responses.
-
- */
-class MinRPCExecuteNoOp : public MinRPCExecInterface {
- public:
- explicit MinRPCExecuteNoOp(MinRPCReturnInterface* ret_handler) :
ret_handler_(ret_handler) {}
- ~MinRPCExecuteNoOp() {}
- void InitServer(int _num_args) {}
- void NormalCallFunc(uint64_t call_handle, TVMValue* values, int* tcodes, int
num_args) {}
- void CopyFromRemote(DLTensor* arr, uint64_t num_bytes, uint8_t* temp_data) {}
- int CopyToRemote(DLTensor* arr, uint64_t num_bytes, uint8_t* data_ptr) {
return 1; }
- void SysCallFunc(RPCCode code, TVMValue* values, int* tcodes, int num_args)
{}
- void ThrowError(RPCServerStatus code, RPCCode info) {}
- MinRPCReturnInterface* GetReturnInterface() { return ret_handler_; }
-
- private:
- MinRPCReturnInterface* ret_handler_;
-};
-
-} // namespace runtime
-} // namespace tvm
-
-#endif // TVM_RUNTIME_MINRPC_MINRPC_LOGGER_H_"
diff --git a/src/runtime/minrpc/minrpc_server.h b/src/runtime/minrpc/minrpc_server.h
index 2b14a8ae83..727e2d6505 100644
--- a/src/runtime/minrpc/minrpc_server.h
+++ b/src/runtime/minrpc/minrpc_server.h
@@ -28,511 +28,24 @@
#ifndef TVM_RUNTIME_MINRPC_MINRPC_SERVER_H_
#define TVM_RUNTIME_MINRPC_MINRPC_SERVER_H_
-#ifndef DMLC_LITTLE_ENDIAN
-#define DMLC_LITTLE_ENDIAN 1
-#endif
-
-#include <string.h>
+#include <tvm/ffi/c_api.h>
+#include <tvm/ffi/function.h>
#include <tvm/runtime/c_runtime_api.h>
+#include <tvm/runtime/logging.h>
+#include <cstring>
#include <memory>
#include <utility>
#include "../../support/generic_arena.h"
-#include "minrpc_interfaces.h"
#include "rpc_reference.h"
-#ifndef MINRPC_CHECK
-#define MINRPC_CHECK(cond) \
- if (!(cond)) this->ThrowError(RPCServerStatus::kCheckError);
-#endif
-
namespace tvm {
namespace runtime {
-
-namespace detail {
+namespace details {
template <typename TIOHandler>
class PageAllocator;
-}
-
-/*!
- * \brief Responses to a minimum RPC command.
- *
- * \tparam TIOHandler IO provider to provide io handling.
- */
-template <typename TIOHandler>
-class MinRPCReturns : public MinRPCReturnInterface {
- public:
- /*!
- * \brief Constructor.
- * \param io The IO handler.
- */
- explicit MinRPCReturns(TIOHandler* io) : io_(io) {}
-
- void ReturnVoid() {
- int32_t num_args = 1;
- int32_t tcode = kTVMNullptr;
- RPCCode code = RPCCode::kReturn;
-
- uint64_t packet_nbytes = sizeof(code) + sizeof(num_args) + sizeof(tcode);
-
- io_->MessageStart(packet_nbytes);
- Write(packet_nbytes);
- Write(code);
- Write(num_args);
- Write(tcode);
- io_->MessageDone();
- }
-
- void ReturnHandle(void* handle) {
- int32_t num_args = 1;
- int32_t tcode = kTVMOpaqueHandle;
- RPCCode code = RPCCode::kReturn;
- uint64_t encode_handle = reinterpret_cast<uint64_t>(handle);
- uint64_t packet_nbytes =
- sizeof(code) + sizeof(num_args) + sizeof(tcode) +
sizeof(encode_handle);
-
- io_->MessageStart(packet_nbytes);
- Write(packet_nbytes);
- Write(code);
- Write(num_args);
- Write(tcode);
- Write(encode_handle);
- io_->MessageDone();
- }
-
- void ReturnException(const char* msg) { RPCReference::ReturnException(msg,
this); }
-
- void ReturnPackedSeq(const TVMValue* arg_values, const int* type_codes, int
num_args) {
- RPCReference::ReturnPackedSeq(arg_values, type_codes, num_args, this);
- }
-
- void ReturnCopyFromRemote(uint8_t* data_ptr, uint64_t num_bytes) {
- RPCCode code = RPCCode::kCopyAck;
- uint64_t packet_nbytes = sizeof(code) + num_bytes;
-
- io_->MessageStart(packet_nbytes);
- Write(packet_nbytes);
- Write(code);
- WriteArray(data_ptr, num_bytes);
- io_->MessageDone();
- }
-
- void ReturnLastTVMError() {
- const char* err = TVMGetLastError();
- ReturnException(err);
- }
-
- void MessageStart(uint64_t packet_nbytes) {
io_->MessageStart(packet_nbytes); }
-
- void MessageDone() { io_->MessageDone(); }
-
- void ThrowError(RPCServerStatus code, RPCCode info = RPCCode::kNone) {
- io_->Exit(static_cast<int>(code));
- }
-
- void WriteObject(void* obj) {
this->ThrowError(RPCServerStatus::kUnknownTypeCode); }
- uint64_t GetObjectBytes(void* obj) {
- this->ThrowError(RPCServerStatus::kUnknownTypeCode);
- return 0;
- }
-
- template <typename T>
- void Write(const T& data) {
- static_assert(std::is_trivial<T>::value &&
std::is_standard_layout<T>::value,
- "need to be trival");
- return WriteRawBytes(&data, sizeof(T));
- }
-
- template <typename T>
- void WriteArray(T* data, size_t count) {
- static_assert(std::is_trivial<T>::value &&
std::is_standard_layout<T>::value,
- "need to be trival");
- return WriteRawBytes(data, sizeof(T) * count);
- }
-
- private:
- void WriteRawBytes(const void* data, size_t size) {
- const uint8_t* buf = static_cast<const uint8_t*>(data);
- size_t ndone = 0;
- while (ndone < size) {
- ssize_t ret = io_->PosixWrite(buf, size - ndone);
- if (ret <= 0) {
- this->ThrowError(RPCServerStatus::kWriteError);
- }
- buf += ret;
- ndone += ret;
- }
- }
-
- TIOHandler* io_;
-};
-
-/*!
- * \brief Executing a minimum RPC command.
- *
- * \tparam TIOHandler IO provider to provide io handling.
- * \tparam MinRPCReturnInterface* handles response generatation and
transmission.
- */
-template <typename TIOHandler>
-class MinRPCExecute : public MinRPCExecInterface {
- public:
- MinRPCExecute(TIOHandler* io, MinRPCReturnInterface* ret_handler)
- : io_(io), ret_handler_(ret_handler) {}
-
- void InitServer(int num_args) {
- MINRPC_CHECK(num_args == 0);
- ret_handler_->ReturnVoid();
- }
-
- void NormalCallFunc(uint64_t call_handle, TVMValue* values, int* tcodes, int
num_args) {
- TVMValue ret_value[3];
- int ret_tcode[3];
-
- int call_ecode = TVMFuncCall(reinterpret_cast<void*>(call_handle), values,
tcodes, num_args,
- &(ret_value[1]), &(ret_tcode[1]));
-
- if (call_ecode == 0) {
- // Return value encoding as in LocalSession
- int rv_tcode = ret_tcode[1];
- ret_tcode[0] = kDLInt;
- ret_value[0].v_int64 = rv_tcode;
- if (rv_tcode == kTVMNDArrayHandle) {
- ret_tcode[1] = kTVMDLTensorHandle;
- ret_value[2].v_handle = ret_value[1].v_handle;
- ret_tcode[2] = kTVMOpaqueHandle;
- ret_handler_->ReturnPackedSeq(ret_value, ret_tcode, 3);
- } else if (rv_tcode == kTVMBytes) {
- ret_tcode[1] = kTVMBytes;
- ret_handler_->ReturnPackedSeq(ret_value, ret_tcode, 2);
-
TVMByteArrayFree(reinterpret_cast<TVMByteArray*>(ret_value[1].v_handle)); //
NOLINT(*)
- } else if (rv_tcode == kTVMPackedFuncHandle || rv_tcode ==
kTVMModuleHandle ||
- rv_tcode == kTVMObjectHandle) {
- ret_tcode[1] = kTVMOpaqueHandle;
- ret_handler_->ReturnPackedSeq(ret_value, ret_tcode, 2);
- } else {
- ret_handler_->ReturnPackedSeq(ret_value, ret_tcode, 2);
- }
- } else {
- ret_handler_->ReturnLastTVMError();
- }
- }
-
- void CopyFromRemote(DLTensor* arr, uint64_t num_bytes, uint8_t* data_ptr) {
- int call_ecode = 0;
- if (arr->device.device_type != kDLCPU) {
- DLTensor temp;
- temp.data = static_cast<void*>(data_ptr);
- temp.device = DLDevice{kDLCPU, 0};
- temp.ndim = arr->ndim;
- temp.dtype = arr->dtype;
- temp.shape = arr->shape;
- temp.strides = nullptr;
- temp.byte_offset = 0;
- call_ecode = TVMDeviceCopyDataFromTo(arr, &temp, nullptr);
- // need sync to make sure that the copy is completed.
- if (call_ecode == 0) {
- call_ecode = TVMSynchronize(arr->device.device_type,
arr->device.device_id, nullptr);
- }
- }
-
- if (call_ecode == 0) {
- ret_handler_->ReturnCopyFromRemote(data_ptr, num_bytes);
- } else {
- ret_handler_->ReturnLastTVMError();
- }
- }
-
- int CopyToRemote(DLTensor* arr, uint64_t num_bytes, uint8_t* data_ptr) {
- int call_ecode = 0;
-
- int ret = ReadArray(data_ptr, num_bytes);
- if (ret <= 0) return ret;
-
- if (arr->device.device_type != kDLCPU) {
- DLTensor temp;
- temp.data = data_ptr;
- temp.device = DLDevice{kDLCPU, 0};
- temp.ndim = arr->ndim;
- temp.dtype = arr->dtype;
- temp.shape = arr->shape;
- temp.strides = nullptr;
- temp.byte_offset = 0;
- call_ecode = TVMDeviceCopyDataFromTo(&temp, arr, nullptr);
- // need sync to make sure that the copy is completed.
- if (call_ecode == 0) {
- call_ecode = TVMSynchronize(arr->device.device_type,
arr->device.device_id, nullptr);
- }
- }
-
- if (call_ecode == 0) {
- ret_handler_->ReturnVoid();
- } else {
- ret_handler_->ReturnLastTVMError();
- }
-
- return 1;
- }
-
- void SysCallFunc(RPCCode code, TVMValue* values, int* tcodes, int num_args) {
- switch (code) {
- case RPCCode::kFreeHandle: {
- SyscallFreeHandle(values, tcodes, num_args);
- break;
- }
- case RPCCode::kGetGlobalFunc: {
- SyscallGetGlobalFunc(values, tcodes, num_args);
- break;
- }
- case RPCCode::kDevSetDevice: {
- ret_handler_->ReturnException("SetDevice not supported");
- break;
- }
- case RPCCode::kDevGetAttr: {
- ret_handler_->ReturnException("GetAttr not supported");
- break;
- }
- case RPCCode::kDevAllocData: {
- SyscallDevAllocData(values, tcodes, num_args);
- break;
- }
- case RPCCode::kDevAllocDataWithScope: {
- SyscallDevAllocDataWithScope(values, tcodes, num_args);
- break;
- }
- case RPCCode::kDevFreeData: {
- SyscallDevFreeData(values, tcodes, num_args);
- break;
- }
- case RPCCode::kDevCreateStream: {
- SyscallDevCreateStream(values, tcodes, num_args);
- break;
- }
- case RPCCode::kDevFreeStream: {
- SyscallDevFreeStream(values, tcodes, num_args);
- break;
- }
- case RPCCode::kDevStreamSync: {
- SyscallDevStreamSync(values, tcodes, num_args);
- break;
- }
- case RPCCode::kDevSetStream: {
- SyscallDevSetStream(values, tcodes, num_args);
- break;
- }
- case RPCCode::kCopyAmongRemote: {
- SyscallCopyAmongRemote(values, tcodes, num_args);
- break;
- }
- default: {
- ret_handler_->ReturnException("Syscall not recognized");
- break;
- }
- }
- }
-
- void SyscallFreeHandle(TVMValue* values, int* tcodes, int num_args) {
- MINRPC_CHECK(num_args == 1);
- MINRPC_CHECK(tcodes[0] == kTVMOpaqueHandle);
- void* handle = values[0].v_handle;
- int call_ecode = TVMObjectFree(handle);
-
- if (call_ecode == 0) {
- ret_handler_->ReturnVoid();
- } else {
- ret_handler_->ReturnLastTVMError();
- }
- }
-
- void SyscallGetGlobalFunc(TVMValue* values, int* tcodes, int num_args) {
- MINRPC_CHECK(num_args == 1);
- MINRPC_CHECK(tcodes[0] == kTVMStr);
- void* handle;
- int call_ecode = TVMFuncGetGlobal(values[0].v_str, &handle);
-
- if (call_ecode == 0) {
- ret_handler_->ReturnHandle(handle);
- } else {
- ret_handler_->ReturnLastTVMError();
- }
- }
-
- void SyscallCopyAmongRemote(TVMValue* values, int* tcodes, int num_args) {
- MINRPC_CHECK(num_args == 3);
- // from dltensor
- MINRPC_CHECK(tcodes[0] == kTVMDLTensorHandle);
- // to dltensor
- MINRPC_CHECK(tcodes[1] == kTVMDLTensorHandle);
- // stream
- MINRPC_CHECK(tcodes[2] == kTVMOpaqueHandle);
-
- void* from = values[0].v_handle;
- void* to = values[1].v_handle;
- TVMStreamHandle stream = values[2].v_handle;
-
- int call_ecode = TVMDeviceCopyDataFromTo(reinterpret_cast<DLTensor*>(from),
- reinterpret_cast<DLTensor*>(to),
stream);
-
- if (call_ecode == 0) {
- ret_handler_->ReturnVoid();
- } else {
- ret_handler_->ReturnLastTVMError();
- }
- }
-
- void SyscallDevAllocData(TVMValue* values, int* tcodes, int num_args) {
- MINRPC_CHECK(num_args == 4);
- MINRPC_CHECK(tcodes[0] == kDLDevice);
- MINRPC_CHECK(tcodes[1] == kDLInt);
- MINRPC_CHECK(tcodes[2] == kDLInt);
- MINRPC_CHECK(tcodes[3] == kTVMDataType);
-
- DLDevice dev = values[0].v_device;
- int64_t nbytes = values[1].v_int64;
- int64_t alignment = values[2].v_int64;
- DLDataType type_hint = values[3].v_type;
-
- void* handle;
- int call_ecode = TVMDeviceAllocDataSpace(dev, nbytes, alignment,
type_hint, &handle);
-
- if (call_ecode == 0) {
- ret_handler_->ReturnHandle(handle);
- } else {
- ret_handler_->ReturnLastTVMError();
- }
- }
-
- void SyscallDevAllocDataWithScope(TVMValue* values, int* tcodes, int
num_args) {
- MINRPC_CHECK(num_args == 2);
- MINRPC_CHECK(tcodes[0] == kTVMDLTensorHandle);
- MINRPC_CHECK(tcodes[1] == kTVMNullptr || tcodes[1] == kTVMStr);
-
- DLTensor* arr = static_cast<DLTensor*>(values[0].v_handle);
- const char* mem_scope = (tcodes[1] == kTVMNullptr ? nullptr :
values[1].v_str);
- void* handle;
- int call_ecode = TVMDeviceAllocDataSpaceWithScope(arr->device, arr->ndim,
arr->shape,
- arr->dtype, mem_scope,
&handle);
- if (call_ecode == 0) {
- ret_handler_->ReturnHandle(handle);
- } else {
- ret_handler_->ReturnLastTVMError();
- }
- }
-
- void SyscallDevFreeData(TVMValue* values, int* tcodes, int num_args) {
- MINRPC_CHECK(num_args == 2);
- MINRPC_CHECK(tcodes[0] == kDLDevice);
- MINRPC_CHECK(tcodes[1] == kTVMOpaqueHandle);
-
- DLDevice dev = values[0].v_device;
- void* handle = values[1].v_handle;
-
- int call_ecode = TVMDeviceFreeDataSpace(dev, handle);
-
- if (call_ecode == 0) {
- ret_handler_->ReturnVoid();
- } else {
- ret_handler_->ReturnLastTVMError();
- }
- }
-
- void SyscallDevCreateStream(TVMValue* values, int* tcodes, int num_args) {
- MINRPC_CHECK(num_args == 1);
- MINRPC_CHECK(tcodes[0] == kDLDevice);
-
- DLDevice dev = values[0].v_device;
- void* handle;
-
- int call_ecode = TVMStreamCreate(dev.device_type, dev.device_id, &handle);
-
- if (call_ecode == 0) {
- ret_handler_->ReturnHandle(handle);
- } else {
- ret_handler_->ReturnLastTVMError();
- }
- }
-
- void SyscallDevFreeStream(TVMValue* values, int* tcodes, int num_args) {
- MINRPC_CHECK(num_args == 2);
- MINRPC_CHECK(tcodes[0] == kDLDevice);
- MINRPC_CHECK(tcodes[1] == kTVMOpaqueHandle);
-
- DLDevice dev = values[0].v_device;
- void* handle = values[1].v_handle;
-
- int call_ecode = TVMStreamFree(dev.device_type, dev.device_id, handle);
-
- if (call_ecode == 0) {
- ret_handler_->ReturnVoid();
- } else {
- ret_handler_->ReturnLastTVMError();
- }
- }
-
- void SyscallDevStreamSync(TVMValue* values, int* tcodes, int num_args) {
- MINRPC_CHECK(num_args == 2);
- MINRPC_CHECK(tcodes[0] == kDLDevice);
- MINRPC_CHECK(tcodes[1] == kTVMOpaqueHandle);
-
- DLDevice dev = values[0].v_device;
- void* handle = values[1].v_handle;
-
- int call_ecode = TVMSynchronize(dev.device_type, dev.device_id, handle);
-
- if (call_ecode == 0) {
- ret_handler_->ReturnVoid();
- } else {
- ret_handler_->ReturnLastTVMError();
- }
- }
-
- void SyscallDevSetStream(TVMValue* values, int* tcodes, int num_args) {
- MINRPC_CHECK(num_args == 2);
- MINRPC_CHECK(tcodes[0] == kDLDevice);
- MINRPC_CHECK(tcodes[1] == kTVMOpaqueHandle);
-
- DLDevice dev = values[0].v_device;
- void* handle = values[1].v_handle;
-
- int call_ecode = TVMSetStream(dev.device_type, dev.device_id, handle);
-
- if (call_ecode == 0) {
- ret_handler_->ReturnVoid();
- } else {
- ret_handler_->ReturnLastTVMError();
- }
- }
-
- void ThrowError(RPCServerStatus code, RPCCode info = RPCCode::kNone) {
- ret_handler_->ThrowError(code, info);
- }
-
- MinRPCReturnInterface* GetReturnInterface() { return ret_handler_; }
-
- private:
- template <typename T>
- int ReadArray(T* data, size_t count) {
- static_assert(std::is_trivial<T>::value &&
std::is_standard_layout<T>::value,
- "need to be trival");
- return ReadRawBytes(data, sizeof(T) * count);
- }
-
- int ReadRawBytes(void* data, size_t size) {
- uint8_t* buf = static_cast<uint8_t*>(data);
- size_t ndone = 0;
- while (ndone < size) {
- ssize_t ret = io_->PosixRead(buf, size - ndone);
- if (ret <= 0) return ret;
- ndone += ret;
- buf += ret;
- }
- return 1;
- }
-
- TIOHandler* io_;
- MinRPCReturnInterface* ret_handler_;
-};
-
+} // namespace details
/*!
* \brief A minimum RPC server that only depends on the tvm C runtime..
*
@@ -544,180 +57,61 @@ class MinRPCExecute : public MinRPCExecInterface {
* - MessageStart(num_bytes), MessageDone(): framing APIs.
* - Exit: exit with status code.
*/
-template <typename TIOHandler, template <typename> class Allocator =
detail::PageAllocator>
+template <typename TIOHandler, template <typename> class Allocator =
details::PageAllocator>
class MinRPCServer {
public:
using PageAllocator = Allocator<TIOHandler>;
- /*!
- * \brief Constructor.
- * \param io The IO handler.
- */
- MinRPCServer(TIOHandler* io, std::unique_ptr<MinRPCExecInterface>&&
exec_handler)
- : io_(io), arena_(PageAllocator(io_)),
exec_handler_(std::move(exec_handler)) {}
+ using FServerHandler = ffi::TypedFunction<int(TVMFFIByteArray*, int)>;
- explicit MinRPCServer(TIOHandler* io)
- : io_(io),
- arena_(PageAllocator(io)),
- ret_handler_(new MinRPCReturns<TIOHandler>(io_)),
- exec_handler_(std::unique_ptr<MinRPCExecInterface>(
- new MinRPCExecute<TIOHandler>(io_, ret_handler_))) {}
-
- ~MinRPCServer() {
- if (ret_handler_ != nullptr) {
- delete ret_handler_;
- }
+ explicit MinRPCServer(TIOHandler* io) : io_(io), arena_(PageAllocator(io_)) {
+ auto fsend = ffi::Function::FromTyped([this](TVMFFIByteArray* bytes) {
+ return io_->PosixWrite(reinterpret_cast<const uint8_t*>(bytes->data),
bytes->size);
+ });
+ auto fcreate =
tvm::ffi::Function::GetGlobalRequired("rpc.CreateEventDrivenServer");
+ ffi::Any value = fcreate(fsend, "MinRPCServer", "");
+ fserver_handler_ = value.cast<FServerHandler>();
}
- /*! \brief Process a single request.
+ /*!
+ * \brief Process a single request.
*
* \return true when the server should continue processing requests. false
when it should be
* shutdown.
*/
bool ProcessOnePacket() {
- RPCCode code;
uint64_t packet_len;
-
arena_.RecycleAll();
allow_clean_shutdown_ = true;
-
Read(&packet_len);
if (packet_len == 0) return true;
- Read(&code);
- allow_clean_shutdown_ = false;
-
- if (code >= RPCCode::kSyscallCodeStart) {
- HandleSyscallFunc(code);
- } else {
- switch (code) {
- case RPCCode::kCallFunc: {
- HandleNormalCallFunc();
- break;
- }
- case RPCCode::kInitServer: {
- HandleInitServer();
- break;
- }
- case RPCCode::kCopyFromRemote: {
- HandleCopyFromRemote();
- break;
- }
- case RPCCode::kCopyToRemote: {
- HandleCopyToRemote();
- break;
- }
- case RPCCode::kShutdown: {
- Shutdown();
- return false;
- }
- default: {
- this->ThrowError(RPCServerStatus::kUnknownRPCCode);
- break;
- }
+ char* read_buffer = this->ArenaAlloc<char>(sizeof(uint64_t) + packet_len);
+ // copy header into read buffer
+ std::memcpy(read_buffer, &packet_len, sizeof(uint64_t));
+ // read the rest of the packet
+ ReadRawBytes(read_buffer + sizeof(uint64_t), packet_len);
+ // setup write flags
+ int write_flags = 3;
+ TVMFFIByteArray read_bytes{read_buffer, sizeof(uint64_t) +
static_cast<size_t>(packet_len)};
+ int status = fserver_handler_(&read_bytes, write_flags);
+
+ while (status == 2) {
+ TVMFFIByteArray write_bytes{nullptr, 0};
+ // continue call handler until it have nothing to write
+ status = fserver_handler_(&write_bytes, write_flags);
+ if (status == 0) {
+ this->Shutdown();
+ return false;
}
}
-
return true;
}
- void HandleInitServer() {
- uint64_t len;
- Read(&len);
- char* proto_ver = ArenaAlloc<char>(len + 1);
- ReadArray(proto_ver, len);
- TVMValue* values;
- int* tcodes;
- int num_args;
- RecvPackedSeq(&values, &tcodes, &num_args);
- exec_handler_->InitServer(num_args);
- }
-
void Shutdown() {
arena_.FreeAll();
io_->Close();
}
- void HandleNormalCallFunc() {
- uint64_t call_handle;
- TVMValue* values;
- int* tcodes;
- int num_args;
-
- Read(&call_handle);
- RecvPackedSeq(&values, &tcodes, &num_args);
- exec_handler_->NormalCallFunc(call_handle, values, tcodes, num_args);
- }
-
- void HandleCopyFromRemote() {
- DLTensor* arr = ArenaAlloc<DLTensor>(1);
- uint64_t data_handle;
- Read(&data_handle);
- arr->data = reinterpret_cast<void*>(data_handle);
- Read(&(arr->device));
- Read(&(arr->ndim));
- Read(&(arr->dtype));
- arr->shape = ArenaAlloc<int64_t>(arr->ndim);
- ReadArray(arr->shape, arr->ndim);
- arr->strides = nullptr;
- Read(&(arr->byte_offset));
-
- uint64_t num_bytes;
- Read(&num_bytes);
-
- uint8_t* data_ptr;
- if (arr->device.device_type == kDLCPU) {
- data_ptr = reinterpret_cast<uint8_t*>(data_handle) + arr->byte_offset;
- } else {
- data_ptr = ArenaAlloc<uint8_t>(num_bytes);
- }
-
- exec_handler_->CopyFromRemote(arr, num_bytes, data_ptr);
- }
-
- void HandleCopyToRemote() {
- DLTensor* arr = ArenaAlloc<DLTensor>(1);
- uint64_t data_handle;
- Read(&data_handle);
- arr->data = reinterpret_cast<void*>(data_handle);
- Read(&(arr->device));
- Read(&(arr->ndim));
- Read(&(arr->dtype));
- arr->shape = ArenaAlloc<int64_t>(arr->ndim);
- ReadArray(arr->shape, arr->ndim);
- arr->strides = nullptr;
- Read(&(arr->byte_offset));
- uint64_t num_bytes;
- Read(&num_bytes);
- int ret;
- if (arr->device.device_type == kDLCPU) {
- uint8_t* dptr = reinterpret_cast<uint8_t*>(data_handle) +
arr->byte_offset;
- ret = exec_handler_->CopyToRemote(arr, num_bytes, dptr);
- } else {
- uint8_t* temp_data = ArenaAlloc<uint8_t>(num_bytes);
- ret = exec_handler_->CopyToRemote(arr, num_bytes, temp_data);
- }
- if (ret == 0) {
- if (allow_clean_shutdown_) {
- Shutdown();
- io_->Exit(0);
- } else {
- this->ThrowError(RPCServerStatus::kReadError);
- }
- }
- if (ret == -1) {
- this->ThrowError(RPCServerStatus::kReadError);
- }
- }
-
- void HandleSyscallFunc(RPCCode code) {
- TVMValue* values;
- int* tcodes;
- int num_args;
- RecvPackedSeq(&values, &tcodes, &num_args);
-
- exec_handler_->SysCallFunc(code, values, tcodes, num_args);
- }
-
void ThrowError(RPCServerStatus code, RPCCode info = RPCCode::kNone) {
io_->Exit(static_cast<int>(code));
}
@@ -736,32 +130,7 @@ class MinRPCServer {
ReadRawBytes(data, sizeof(T));
}
- template <typename T>
- void ReadArray(T* data, size_t count) {
- static_assert(std::is_trivial<T>::value &&
std::is_standard_layout<T>::value,
- "need to be trival");
- return ReadRawBytes(data, sizeof(T) * count);
- }
-
- void ReadObject(int* tcode, TVMValue* value) {
- // handles RPCObject in minRPC
- // NOTE: object needs to be supported by C runtime
- // because minrpc's restriction of C only
- // we only handle RPCObjectRef
- uint32_t type_index;
- Read(&type_index);
- MINRPC_CHECK(type_index == kRuntimeRPCObjectRefTypeIndex);
- uint64_t object_handle;
- Read(&object_handle);
- tcode[0] = kTVMObjectHandle;
- value[0].v_handle = reinterpret_cast<void*>(object_handle);
- }
-
private:
- void RecvPackedSeq(TVMValue** out_values, int** out_tcodes, int*
out_num_args) {
- RPCReference::RecvPackedSeq(out_values, out_tcodes, out_num_args, this);
- }
-
void ReadRawBytes(void* data, size_t size) {
uint8_t* buf = static_cast<uint8_t*>(data);
size_t ndone = 0;
@@ -783,18 +152,17 @@ class MinRPCServer {
}
}
+ /*! \brief server handler. */
+ FServerHandler fserver_handler_;
/*! \brief IO handler. */
TIOHandler* io_;
/*! \brief internal arena. */
support::GenericArena<PageAllocator> arena_;
- MinRPCReturns<TIOHandler>* ret_handler_ = nullptr;
- std::unique_ptr<MinRPCExecInterface> exec_handler_;
/*! \brief Whether we are in a state that allows clean shutdown. */
bool allow_clean_shutdown_{true};
- static_assert(DMLC_LITTLE_ENDIAN == 1, "MinRPC only works on little
endian.");
};
-namespace detail {
+namespace details {
// Internal allocator that redirects alloc to TVM's C API.
template <typename TIOHandler>
class PageAllocator {
@@ -805,10 +173,9 @@ class PageAllocator {
ArenaPageHeader* allocate(size_t min_size) {
size_t npages = ((min_size + kPageSize - 1) / kPageSize);
- void* data;
+ void* data = malloc(npages * kPageSize);
- if (TVMDeviceAllocDataSpace(DLDevice{kDLCPU, 0}, npages * kPageSize,
kPageAlign,
- DLDataType{kDLInt, 1, 1}, &data) != 0) {
+ if (data == nullptr) {
io_->Exit(static_cast<int>(RPCServerStatus::kAllocError));
}
@@ -818,11 +185,7 @@ class PageAllocator {
return header;
}
- void deallocate(ArenaPageHeader* page) {
- if (TVMDeviceFreeDataSpace(DLDevice{kDLCPU, 0}, page) != 0) {
- io_->Exit(static_cast<int>(RPCServerStatus::kAllocError));
- }
- }
+ void deallocate(ArenaPageHeader* page) { free(page); }
static const constexpr int kPageSize = 2 << 10;
static const constexpr int kPageAlign = 8;
@@ -830,7 +193,7 @@ class PageAllocator {
private:
TIOHandler* io_;
};
-} // namespace detail
+} // namespace details
} // namespace runtime
} // namespace tvm
diff --git a/src/runtime/minrpc/minrpc_server_logging.h
b/src/runtime/minrpc/minrpc_server_logging.h
deleted file mode 100644
index 89650efe9a..0000000000
--- a/src/runtime/minrpc/minrpc_server_logging.h
+++ /dev/null
@@ -1,170 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-#ifndef TVM_RUNTIME_MINRPC_MINRPC_SERVER_LOGGING_H_
-#define TVM_RUNTIME_MINRPC_MINRPC_SERVER_LOGGING_H_
-
-#include <memory>
-#include <utility>
-
-#include "minrpc_logger.h"
-#include "minrpc_server.h"
-
-namespace tvm {
-namespace runtime {
-
-/*!
- * \brief A minimum RPC server that logs the received commands.
- *
- * \tparam TIOHandler IO provider to provide io handling.
- */
-template <typename TIOHandler>
-class MinRPCServerWithLog {
- public:
- explicit MinRPCServerWithLog(TIOHandler* io)
- : ret_handler_(io),
- ret_handler_wlog_(&ret_handler_, &logger_),
- exec_handler_(io, &ret_handler_wlog_),
- exec_handler_ptr_(new MinRPCExecuteWithLog(&exec_handler_, &logger_)),
- next_(io, std::move(exec_handler_ptr_)) {}
-
- bool ProcessOnePacket() { return next_.ProcessOnePacket(); }
-
- private:
- Logger logger_;
- MinRPCReturns<TIOHandler> ret_handler_;
- MinRPCExecute<TIOHandler> exec_handler_;
- MinRPCReturnsWithLog ret_handler_wlog_;
- std::unique_ptr<MinRPCExecuteWithLog> exec_handler_ptr_;
- MinRPCServer<TIOHandler> next_;
-};
-
-/*!
- * \brief A minimum RPC server that only logs the outgoing commands and
received responses.
- * (Does not process the packets or respond to them.)
- *
- * \tparam TIOHandler IO provider to provide io handling.
- */
-template <typename TIOHandler, template <typename> class Allocator =
detail::PageAllocator>
-class MinRPCSniffer {
- public:
- using PageAllocator = Allocator<TIOHandler>;
- explicit MinRPCSniffer(TIOHandler* io)
- : io_(io),
- arena_(PageAllocator(io_)),
- ret_handler_(io_),
- ret_handler_wlog_(&ret_handler_, &logger_),
- exec_handler_(&ret_handler_wlog_),
- exec_handler_ptr_(new MinRPCExecuteWithLog(&exec_handler_, &logger_)),
- next_(io_, std::move(exec_handler_ptr_)) {}
-
- bool ProcessOnePacket() { return next_.ProcessOnePacket(); }
-
- void ProcessOneResponse() {
- RPCCode code;
- uint64_t packet_len = 0;
-
- if (!Read(&packet_len)) return;
- if (packet_len == 0) {
- OutputLog();
- return;
- }
- if (!Read(&code)) return;
- switch (code) {
- case RPCCode::kReturn: {
- int32_t num_args;
- int* type_codes;
- TVMValue* values;
- RPCReference::RecvPackedSeq(&values, &type_codes, &num_args, this);
- ret_handler_wlog_.ReturnPackedSeq(values, type_codes, num_args);
- break;
- }
- case RPCCode::kException: {
- ret_handler_wlog_.ReturnException("");
- break;
- }
- default: {
- OutputLog();
- break;
- }
- }
- }
-
- void OutputLog() { logger_.OutputLog(); }
-
- void ThrowError(RPCServerStatus code, RPCCode info = RPCCode::kNone) {
- logger_.Log("-> ");
- logger_.Log(RPCServerStatusToString(code));
- OutputLog();
- }
-
- template <typename T>
- T* ArenaAlloc(int count) {
- static_assert(std::is_trivial<T>::value &&
std::is_standard_layout<T>::value,
- "need to be trival");
- return arena_.template allocate_<T>(count);
- }
-
- template <typename T>
- bool Read(T* data) {
- static_assert(std::is_trivial<T>::value &&
std::is_standard_layout<T>::value,
- "need to be trival");
- return ReadRawBytes(data, sizeof(T));
- }
-
- template <typename T>
- bool ReadArray(T* data, size_t count) {
- static_assert(std::is_trivial<T>::value &&
std::is_standard_layout<T>::value,
- "need to be trival");
- return ReadRawBytes(data, sizeof(T) * count);
- }
-
- void ReadObject(int* tcode, TVMValue* value) {
- this->ThrowError(RPCServerStatus::kUnknownTypeCode);
- }
-
- private:
- bool ReadRawBytes(void* data, size_t size) {
- uint8_t* buf = reinterpret_cast<uint8_t*>(data);
- size_t ndone = 0;
- while (ndone < size) {
- ssize_t ret = io_->PosixRead(buf, size - ndone);
- if (ret <= 0) {
- this->ThrowError(RPCServerStatus::kReadError);
- return false;
- }
- ndone += ret;
- buf += ret;
- }
- return true;
- }
-
- Logger logger_;
- TIOHandler* io_;
- support::GenericArena<PageAllocator> arena_;
- MinRPCReturnsNoOp<TIOHandler> ret_handler_;
- MinRPCReturnsWithLog ret_handler_wlog_;
- MinRPCExecuteNoOp exec_handler_;
- std::unique_ptr<MinRPCExecuteWithLog> exec_handler_ptr_;
- MinRPCServer<TIOHandler> next_;
-};
-
-} // namespace runtime
-} // namespace tvm
-#endif // TVM_RUNTIME_MINRPC_MINRPC_SERVER_LOGGING_H_
diff --git a/src/runtime/minrpc/posix_popen_server/posix_popen_server.cc
b/src/runtime/minrpc/posix_popen_server/posix_popen_server.cc
index b513d4b7cc..014704e970 100644
--- a/src/runtime/minrpc/posix_popen_server/posix_popen_server.cc
+++ b/src/runtime/minrpc/posix_popen_server/posix_popen_server.cc
@@ -17,9 +17,6 @@
* under the License.
*/
-// Disable constructor to bring minimum dep on c++ABI.
-#define TVM_ARENA_HAS_DESTRUCTOR 0
-
#include <unistd.h>
#include <cstdlib>
diff --git a/src/runtime/minrpc/rpc_reference.h
b/src/runtime/minrpc/rpc_reference.h
index ff3c9f22fd..41bb40b3f2 100644
--- a/src/runtime/minrpc/rpc_reference.h
+++ b/src/runtime/minrpc/rpc_reference.h
@@ -35,14 +35,6 @@ namespace runtime {
/*! \brief The current RPC procotol version. */
constexpr const char* kRPCProtocolVer = "0.8.0";
-/*!
- * \brief type index of kRuntimeRPCObjectRefTypeIndex
- * \note this needs to be kept consistent with runtime/object.h
- * but we explicitly declare it here because minrpc needs to be minimum dep
- * only c C API
- */
-constexpr const int kRuntimeRPCObjectRefTypeIndex = 9;
-
// When tvm.rpc.server.GetCRTMaxPacketSize global function is not registered.
const uint64_t kRPCMaxTransferSizeBytesDefault = UINT64_MAX;
@@ -83,7 +75,7 @@ enum class RPCServerStatus : int {
kInvalidTypeCodeNDArray,
kInvalidDLTensorFieldStride,
kInvalidDLTensorFieldByteOffset,
- kUnknownTypeCode,
+ kUnknownTypeIndex,
kUnknownRPCCode,
kRPCCodeNotSupported,
kUnknownRPCSyscall,
@@ -159,8 +151,8 @@ inline const char* RPCServerStatusToString(RPCServerStatus
status) {
case RPCServerStatus::kInvalidDLTensorFieldByteOffset: {
return "kInvalidDLTensorFieldByteOffset";
}
- case RPCServerStatus::kUnknownTypeCode:
- return "kUnknownTypeCode";
+ case RPCServerStatus::kUnknownTypeIndex:
+ return "kUnknownTypeIndex";
case RPCServerStatus::kUnknownRPCCode:
return "kUnknownRPCCode";
case RPCServerStatus::kRPCCodeNotSupported:
@@ -242,10 +234,10 @@ struct RPCReference {
* \return The total number of bytes.
*/
template <typename TChannel>
- static uint64_t PackedSeqGetNumBytes(const TVMValue* arg_values, const int*
type_codes,
- int num_args, bool client_mode,
TChannel* channel) {
+ static uint64_t PackedSeqGetNumBytes(const TVMFFIAny* packed_args, int
num_args, bool client_mode,
+ TChannel* channel) {
PackedSeqNumBytesGetter<TChannel> getter(channel);
- SendPackedSeq(arg_values, type_codes, num_args, client_mode, &getter);
+ SendPackedSeq(packed_args, num_args, client_mode, &getter);
return getter.num_bytes();
}
@@ -303,93 +295,89 @@ struct RPCReference {
* Note that we cannot simply take these argument out(as the handle)
* refers to a value on the remote(instead of local).
*
- * \param arg_values The values to be sent over.
- * \param type_codes The type codes to be sent over.
+ * \param packed_args The values to be sent over.
* \param num_args Number of argument.
* \param client_mode Whether it is a client to server call.
* \param channel The communication channel handler.
* \tparam TChannel The type of the communication channel.
*/
template <typename TChannel>
- static void SendPackedSeq(const TVMValue* arg_values, const int* type_codes,
int num_args,
- bool client_mode, TChannel* channel) {
+ static void SendPackedSeq(const TVMFFIAny* packed_args, int num_args, bool
client_mode,
+ TChannel* channel) {
channel->Write(num_args);
- channel->WriteArray(type_codes, num_args);
// Argument packing.
for (int i = 0; i < num_args; ++i) {
- int tcode = type_codes[i];
- TVMValue value = arg_values[i];
- switch (tcode) {
- case kDLInt:
- case kDLUInt:
- case kDLFloat: {
- channel->template Write<int64_t>(value.v_int64);
+ int32_t type_index = packed_args[i].type_index;
+ channel->template Write<int32_t>(type_index);
+ switch (type_index) {
+ case ffi::TypeIndex::kTVMFFINone: {
+ break;
+ }
+ case ffi::TypeIndex::kTVMFFIBool:
+ case ffi::TypeIndex::kTVMFFIInt:
+ case ffi::TypeIndex::kTVMFFIFloat: {
+ channel->template Write<int64_t>(packed_args[i].v_int64);
break;
}
- case kTVMArgBool: {
- channel->template Write<int64_t>(value.v_int64);
+ case ffi::TypeIndex::kTVMFFIOpaquePtr: {
+ // always send handle in 64 bit.
+ uint64_t handle = reinterpret_cast<uint64_t>(packed_args[i].v_ptr);
+ channel->template Write<int64_t>(handle);
break;
}
- case kTVMDataType: {
- channel->Write(value.v_type);
+ case ffi::TypeIndex::kTVMFFIDataType: {
+ channel->Write(packed_args[i].v_dtype);
// padding
int32_t padding = 0;
channel->template Write<int32_t>(padding);
break;
}
- case kDLDevice: {
- channel->Write(value.v_device);
+ case ffi::TypeIndex::kTVMFFIDevice: {
+ channel->Write(packed_args[i].v_device);
break;
}
- case kTVMPackedFuncHandle:
- case kTVMModuleHandle: {
+ case ffi::TypeIndex::kTVMFFIFunction:
+ case ffi::TypeIndex::kTVMFFIModule: {
if (!client_mode) {
channel->ThrowError(RPCServerStatus::kInvalidTypeCodeObject);
}
// always send handle in 64 bit.
- uint64_t handle = reinterpret_cast<uint64_t>(value.v_handle);
- channel->Write(handle);
- break;
- }
- case kTVMOpaqueHandle: {
- // always send handle in 64 bit.
- uint64_t handle = reinterpret_cast<uint64_t>(value.v_handle);
+ uint64_t handle = reinterpret_cast<uint64_t>(packed_args[i].v_obj);
channel->Write(handle);
break;
}
- case kTVMNDArrayHandle: {
+
+ case ffi::TypeIndex::kTVMFFINDArray: {
channel->ThrowError(RPCServerStatus::kInvalidTypeCodeNDArray);
break;
}
- case kTVMDLTensorHandle: {
- DLTensor* arr = static_cast<DLTensor*>(value.v_handle);
+ case ffi::TypeIndex::kTVMFFIDLTensorPtr: {
+ DLTensor* arr = static_cast<DLTensor*>(packed_args[i].v_ptr);
SendDLTensor(channel, arr);
break;
}
- case kTVMNullptr:
- break;
- case kTVMStr: {
- const char* s = value.v_str;
+ case ffi::TypeIndex::kTVMFFIRawStr: {
+ const char* s = packed_args[i].v_c_str;
uint64_t len = StrLength(s);
channel->Write(len);
channel->WriteArray(s, len);
break;
}
- case kTVMBytes: {
- TVMByteArray* bytes =
static_cast<TVMByteArray*>(arg_values[i].v_handle);
+ case ffi::TypeIndex::kTVMFFIByteArrayPtr: {
+ TVMFFIByteArray* bytes =
static_cast<TVMFFIByteArray*>(packed_args[i].v_ptr);
uint64_t len = bytes->size;
channel->Write(len);
channel->WriteArray(bytes->data, len);
break;
}
- case kTVMObjectHandle: {
- channel->WriteObject(static_cast<ffi::Object*>(value.v_handle));
- break;
- }
default: {
- channel->ThrowError(RPCServerStatus::kUnknownTypeCode);
+ if (type_index >= ffi::TypeIndex::kTVMFFIStaticObjectBegin) {
+
channel->WriteObject(reinterpret_cast<ffi::Object*>(packed_args[i].v_obj));
+ } else {
+ channel->ThrowError(RPCServerStatus::kUnknownTypeIndex);
+ }
break;
}
}
@@ -399,102 +387,95 @@ struct RPCReference {
/*!
* \brief Receive packed seq from the channel.
*
- * \param out_arg_values The values to be received.
- * \param out_tcodes The type codes to be received.
+ * \param out_packed_args The values to be received.
* \param out_num_args Number of argument.
* \param channel The communication channel handler.
* \tparam TChannel The type of the communication channel.
* \note The temporary space are populated via an arena inside channel.
*/
template <typename TChannel>
- static void RecvPackedSeq(TVMValue** out_values, int** out_tcodes, int*
out_num_args,
- TChannel* channel) {
+ static void RecvPackedSeq(TVMFFIAny** out_packed_args, int32_t*
out_num_args, TChannel* channel) {
// receive number of args
- int num_args;
+ int32_t num_args;
channel->Read(&num_args);
*out_num_args = num_args;
-
if (num_args == 0) {
- *out_values = nullptr;
- *out_tcodes = nullptr;
+ *out_packed_args = nullptr;
return;
}
- TVMValue* values = channel->template ArenaAlloc<TVMValue>(num_args);
- int* tcodes = channel->template ArenaAlloc<int>(num_args);
- *out_values = values;
- *out_tcodes = tcodes;
-
- // receive type code.
- channel->ReadArray(tcodes, num_args);
+ TVMFFIAny* packed_args = channel->template ArenaAlloc<TVMFFIAny>(num_args);
+ *out_packed_args = packed_args;
// receive arguments
- for (int i = 0; i < num_args; ++i) {
- auto& value = values[i];
- switch (tcodes[i]) {
- case kDLInt:
- case kDLUInt:
- case kDLFloat: {
- channel->template Read<int64_t>(&(value.v_int64));
+ for (int32_t i = 0; i < num_args; ++i) {
+ int32_t type_index;
+ channel->Read(&type_index);
+ packed_args[i].type_index = type_index;
+ switch (type_index) {
+ case ffi::TypeIndex::kTVMFFINone: {
+ break;
+ }
+ case ffi::TypeIndex::kTVMFFIBool:
+ case ffi::TypeIndex::kTVMFFIInt:
+ case ffi::TypeIndex::kTVMFFIFloat: {
+ channel->template Read<int64_t>(&(packed_args[i].v_int64));
break;
}
- case kTVMArgBool: {
- channel->template Read<int64_t>(&(value.v_int64));
+ case ffi::TypeIndex::kTVMFFIOpaquePtr: {
+ uint64_t handle;
+ channel->Read(&handle);
+ packed_args[i].v_ptr = reinterpret_cast<void*>(handle);
break;
}
- case kTVMDataType: {
- channel->Read(&(value.v_type));
+ case ffi::TypeIndex::kTVMFFIDataType: {
+ channel->Read(&(packed_args[i].v_dtype));
int32_t padding = 0;
channel->template Read<int32_t>(&padding);
break;
}
- case kDLDevice: {
- channel->Read(&(value.v_device));
+ case ffi::TypeIndex::kTVMFFIDevice: {
+ channel->Read(&(packed_args[i].v_device));
break;
}
- case kTVMPackedFuncHandle:
- case kTVMModuleHandle:
- case kTVMOpaqueHandle: {
+ case ffi::TypeIndex::kTVMFFIFunction:
+ case ffi::TypeIndex::kTVMFFIModule: {
// always send handle in 64 bit.
uint64_t handle;
channel->Read(&handle);
- value.v_handle = reinterpret_cast<void*>(handle);
- break;
- }
- case kTVMNullptr: {
- value.v_handle = nullptr;
+ packed_args[i].v_obj = reinterpret_cast<TVMFFIObject*>(handle);
break;
}
- case kTVMStr: {
+ case ffi::TypeIndex::kTVMFFIRawStr: {
uint64_t len;
channel->Read(&len);
char* str = channel->template ArenaAlloc<char>(len + 1);
str[len] = '\0';
channel->ReadArray(str, len);
- value.v_str = str;
+ packed_args[i].v_c_str = str;
break;
}
- case kTVMBytes: {
+ case ffi::TypeIndex::kTVMFFIByteArrayPtr: {
uint64_t len;
channel->Read(&len);
- TVMByteArray* arr = channel->template ArenaAlloc<TVMByteArray>(1);
+ TVMFFIByteArray* arr = channel->template
ArenaAlloc<TVMFFIByteArray>(1);
char* data = channel->template ArenaAlloc<char>(len);
arr->size = len;
arr->data = data;
channel->ReadArray(data, len);
- value.v_handle = arr;
+ packed_args[i].v_ptr = arr;
break;
}
- case kTVMDLTensorHandle: {
- value.v_handle = ReceiveDLTensor(channel);
- break;
- }
- case kTVMObjectHandle: {
- channel->ReadObject(&tcodes[i], &value);
+ case ffi::TypeIndex::kTVMFFIDLTensorPtr: {
+ packed_args[i].v_ptr = ReceiveDLTensor(channel);
break;
}
default: {
- channel->ThrowError(RPCServerStatus::kUnknownTypeCode);
+ if (type_index >= ffi::TypeIndex::kTVMFFIStaticObjectBegin) {
+ channel->ReadObject(&(packed_args[i]));
+ } else {
+ channel->ThrowError(RPCServerStatus::kUnknownTypeIndex);
+ }
break;
}
}
@@ -512,16 +493,17 @@ struct RPCReference {
static void ReturnException(const char* msg, TChannel* channel) {
RPCCode code = RPCCode::kException;
int32_t num_args = 1;
- int32_t tcode = kTVMStr;
+ int32_t type_index = ffi::TypeIndex::kTVMFFIRawStr;
uint64_t len = StrLength(msg);
- uint64_t packet_nbytes = sizeof(code) + sizeof(num_args) + sizeof(tcode) +
sizeof(len) + len;
+ uint64_t packet_nbytes =
+ sizeof(code) + sizeof(num_args) + sizeof(type_index) + sizeof(len) +
len;
channel->MessageStart(packet_nbytes);
channel->Write(packet_nbytes);
channel->Write(code);
channel->Write(num_args);
- channel->Write(tcode);
+ channel->Write(type_index);
channel->Write(len);
channel->WriteArray(msg, len);
channel->MessageDone();
@@ -535,17 +517,16 @@ struct RPCReference {
* \tparam TChannel The type of the communication channel.
*/
template <typename TChannel>
- static void ReturnPackedSeq(const TVMValue* arg_values, const int*
type_codes, int num_args,
- TChannel* channel) {
+ static void ReturnPackedSeq(const TVMFFIAny* packed_args, int num_args,
TChannel* channel) {
RPCCode code = RPCCode::kReturn;
uint64_t packet_nbytes =
- sizeof(code) + PackedSeqGetNumBytes(arg_values, type_codes, num_args,
false, channel);
+ sizeof(code) + PackedSeqGetNumBytes(packed_args, num_args, false,
channel);
channel->MessageStart(packet_nbytes);
channel->Write(packet_nbytes);
channel->Write(code);
- SendPackedSeq(arg_values, type_codes, num_args, false, channel);
+ SendPackedSeq(packed_args, num_args, false, channel);
channel->MessageDone();
}
@@ -558,16 +539,16 @@ struct RPCReference {
template <typename TChannel>
static void ReturnVoid(TChannel* channel) {
int32_t num_args = 1;
- int32_t tcode = kTVMNullptr;
+ int32_t type_index = ffi::TypeIndex::kTVMFFINone;
RPCCode code = RPCCode::kReturn;
- uint64_t packet_nbytes = sizeof(code) + sizeof(num_args) + sizeof(tcode);
+ uint64_t packet_nbytes = sizeof(code) + sizeof(num_args) +
sizeof(type_index);
channel->MessageStart(packet_nbytes);
channel->Write(packet_nbytes);
channel->Write(code);
channel->Write(num_args);
- channel->Write(tcode);
+ channel->Write(type_index);
channel->MessageDone();
}
};
diff --git a/src/runtime/rpc/rpc_channel_logger.h
b/src/runtime/rpc/rpc_channel_logger.h
deleted file mode 100644
index 8fe68f6690..0000000000
--- a/src/runtime/rpc/rpc_channel_logger.h
+++ /dev/null
@@ -1,186 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-/*!
- * \file rpc_channel_logger.h
- * \brief A wrapper for RPCChannel with a NanoRPCListener for logging the
commands.
- */
-#ifndef TVM_RUNTIME_RPC_RPC_CHANNEL_LOGGER_H_
-#define TVM_RUNTIME_RPC_RPC_CHANNEL_LOGGER_H_
-
-#include <tvm/runtime/c_runtime_api.h>
-
-#include <memory>
-#include <utility>
-
-#include "../../support/ssize.h"
-#include "../minrpc/minrpc_server_logging.h"
-#include "rpc_channel.h"
-
-#define RX_BUFFER_SIZE 65536
-
-namespace tvm {
-namespace runtime {
-
-class Buffer {
- public:
- Buffer(uint8_t* data, size_t data_size_bytes)
- : data_{data}, capacity_{data_size_bytes}, num_valid_bytes_{0},
read_cursor_{0} {}
-
- size_t Write(const uint8_t* data, size_t data_size_bytes) {
- size_t num_bytes_available = capacity_ - num_valid_bytes_;
- size_t num_bytes_to_copy = data_size_bytes;
- if (num_bytes_available < num_bytes_to_copy) {
- num_bytes_to_copy = num_bytes_available;
- }
-
- memcpy(&data_[num_valid_bytes_], data, num_bytes_to_copy);
- num_valid_bytes_ += num_bytes_to_copy;
- return num_bytes_to_copy;
- }
-
- size_t Read(uint8_t* data, size_t data_size_bytes) {
- size_t num_bytes_to_copy = data_size_bytes;
- size_t num_bytes_available = num_valid_bytes_ - read_cursor_;
- if (num_bytes_available < num_bytes_to_copy) {
- num_bytes_to_copy = num_bytes_available;
- }
-
- memcpy(data, &data_[read_cursor_], num_bytes_to_copy);
- read_cursor_ += num_bytes_to_copy;
- return num_bytes_to_copy;
- }
-
- void Clear() {
- num_valid_bytes_ = 0;
- read_cursor_ = 0;
- }
-
- size_t Size() const { return num_valid_bytes_; }
-
- private:
- /*! \brief pointer to data buffer. */
- uint8_t* data_;
-
- /*! \brief The total number of bytes available in data_.*/
- size_t capacity_;
-
- /*! \brief number of valid bytes in the buffer. */
- size_t num_valid_bytes_;
-
- /*! \brief Read cursor position. */
- size_t read_cursor_;
-};
-
-/*!
- * \brief A simple IO handler for MinRPCSniffer.
- *
- * \tparam Buffer* buffer to store received data.
- */
-class SnifferIOHandler {
- public:
- explicit SnifferIOHandler(Buffer* receive_buffer) :
receive_buffer_(receive_buffer) {}
-
- void MessageStart(size_t message_size_bytes) {}
-
- ssize_t PosixWrite(const uint8_t* buf, size_t buf_size_bytes) { return 0; }
-
- void MessageDone() {}
-
- ssize_t PosixRead(uint8_t* buf, size_t buf_size_bytes) {
- return receive_buffer_->Read(buf, buf_size_bytes);
- }
-
- void Close() {}
-
- void Exit(int code) {}
-
- private:
- Buffer* receive_buffer_;
-};
-
-/*!
- * \brief A simple rpc session that logs the received commands.
- */
-class NanoRPCListener {
- public:
- NanoRPCListener()
- : receive_buffer_(receive_storage_, receive_storage_size_bytes_),
- io_(&receive_buffer_),
- rpc_server_(&io_) {}
-
- void Listen(const uint8_t* data, size_t size) { receive_buffer_.Write(data,
size); }
-
- void ProcessTxPacket() {
- rpc_server_.ProcessOnePacket();
- ClearBuffer();
- }
-
- void ProcessRxPacket() {
- rpc_server_.ProcessOneResponse();
- ClearBuffer();
- }
-
- private:
- void ClearBuffer() { receive_buffer_.Clear(); }
-
- private:
- size_t receive_storage_size_bytes_ = RX_BUFFER_SIZE;
- uint8_t receive_storage_[RX_BUFFER_SIZE];
- Buffer receive_buffer_;
- SnifferIOHandler io_;
- MinRPCSniffer<SnifferIOHandler> rpc_server_;
-
- void HandleCompleteMessage() { rpc_server_.ProcessOnePacket(); }
-
- static void HandleCompleteMessageCb(void* context) {
- static_cast<NanoRPCListener*>(context)->HandleCompleteMessage();
- }
-};
-
-/*!
- * \brief A wrapper for RPCChannel, that also logs the commands sent.
- *
- * \tparam std::unique_ptr<RPCChannel>&& underlying RPCChannel unique_ptr.
- */
-class RPCChannelLogging : public RPCChannel {
- public:
- explicit RPCChannelLogging(std::unique_ptr<RPCChannel>&& next) { next_ =
std::move(next); }
-
- size_t Send(const void* data, size_t size) {
- listener_.ProcessRxPacket();
- listener_.Listen((const uint8_t*)data, size);
- listener_.ProcessTxPacket();
- return next_->Send(data, size);
- }
-
- size_t Recv(void* data, size_t size) {
- size_t ret = next_->Recv(data, size);
- listener_.Listen((const uint8_t*)data, size);
- return ret;
- }
-
- private:
- std::unique_ptr<RPCChannel> next_;
- NanoRPCListener listener_;
-};
-
-} // namespace runtime
-} // namespace tvm
-#endif // TVM_RUNTIME_RPC_RPC_CHANNEL_LOGGER_H_
diff --git a/src/runtime/rpc/rpc_endpoint.cc b/src/runtime/rpc/rpc_endpoint.cc
index 23edfa9bb5..cc7b1db807 100644
--- a/src/runtime/rpc/rpc_endpoint.cc
+++ b/src/runtime/rpc/rpc_endpoint.cc
@@ -189,14 +189,14 @@ class RPCEndpoint::EventHandler : public dmlc::Stream {
LOG(FATAL) << "RPCServerError:" << RPCServerStatusToString(code);
}
- uint64_t PackedSeqGetNumBytes(const TVMValue* arg_values, const int*
type_codes, int num_args,
- bool client_mode) {
- return RPCReference::PackedSeqGetNumBytes(arg_values, type_codes,
num_args, client_mode, this);
+ uint64_t PackedSeqGetNumBytes(const ffi::AnyView* packed_args, int num_args,
bool client_mode) {
+ return RPCReference::PackedSeqGetNumBytes(reinterpret_cast<const
TVMFFIAny*>(packed_args),
+ num_args, client_mode, this);
}
- void SendPackedSeq(const TVMValue* arg_values, const int* type_codes, int
num_args,
- bool client_mode) {
- RPCReference::SendPackedSeq(arg_values, type_codes, num_args, client_mode,
this);
+ void SendPackedSeq(const ffi::AnyView* packed_args, int num_args, bool
client_mode) {
+ RPCReference::SendPackedSeq(reinterpret_cast<const
TVMFFIAny*>(packed_args), num_args,
+ client_mode, this);
}
// Endian aware IO handling
@@ -228,7 +228,7 @@ class RPCEndpoint::EventHandler : public dmlc::Stream {
// which is needed for wasm and other env that goes through C API
if (obj->IsInstance<RPCObjectRefObj>()) {
auto* ref = static_cast<RPCObjectRefObj*>(obj);
- this->template Write<uint32_t>(kRuntimeRPCObjectRefTypeIndex);
+ this->template Write<uint32_t>(runtime::TypeIndex::kRuntimeRPCObjectRef);
uint64_t handle = reinterpret_cast<uint64_t>(ref->object_handle());
this->template Write<int64_t>(handle);
} else {
@@ -246,7 +246,7 @@ class RPCEndpoint::EventHandler : public dmlc::Stream {
}
}
- void ReadObject(int* tcode, TVMValue* value) {
+ void ReadObject(TVMFFIAny* out) {
// NOTE: for now all remote object are encoded as RPCObjectRef
// follow the same disco protocol in case we would like to upgrade later
//
@@ -254,7 +254,7 @@ class RPCEndpoint::EventHandler : public dmlc::Stream {
// which is needed for wasm and other env that goes through C API
uint32_t type_index;
this->template Read<uint32_t>(&type_index);
- if (type_index == kRuntimeRPCObjectRefTypeIndex) {
+ if (type_index == runtime::TypeIndex::kRuntimeRPCObjectRef) {
uint64_t handle;
this->template Read<uint64_t>(&handle);
// Always wrap things back in RPCObjectRef
@@ -263,8 +263,7 @@ class RPCEndpoint::EventHandler : public dmlc::Stream {
RPCObjectRef
rpc_obj(make_object<RPCObjectRefObj>(reinterpret_cast<void*>(handle), nullptr));
// Legacy ABI translation
// TODO(tqchen): remove this once we have upgraded to new ABI
- AnyView rpc_obj_view = rpc_obj;
- AnyViewToLegacyTVMArgValue(rpc_obj_view.CopyToTVMFFIAny(), value, tcode);
+ *reinterpret_cast<AnyView*>(out) = rpc_obj;
object_arena_.push_back(rpc_obj);
} else {
LOG(FATAL) << "ValueError: Object type is not supported in Disco calling
convention: "
@@ -342,7 +341,7 @@ class RPCEndpoint::EventHandler : public dmlc::Stream {
return;
} else {
ICHECK_EQ(init_header_step_, 1);
- this->ReadArray(dmlc::BeginPtr(*remote_key_), remote_key_->length());
+ this->ReadArray(remote_key_->data(), remote_key_->length());
this->SwitchToState(kRecvPacketNumBytes);
}
}
@@ -351,7 +350,6 @@ class RPCEndpoint::EventHandler : public dmlc::Stream {
void HandleProcessPacket(RPCSession::FEncodeReturn setreturn) {
RPCCode code = RPCCode::kNone;
this->Read(&code);
-
if (code >= RPCCode::kSyscallCodeStart) {
this->HandleSyscall(code);
} else {
@@ -397,15 +395,9 @@ class RPCEndpoint::EventHandler : public dmlc::Stream {
* \note The ffi::PackedArgs is available until we switchstate.
*/
ffi::PackedArgs RecvPackedSeq() {
- TVMValue* values;
- int* tcodes;
+ ffi::AnyView* packed_args;
int num_args;
- RPCReference::RecvPackedSeq(&values, &tcodes, &num_args, this);
-
- // Legacy ABI translation
- // TODO(tqchen): remove this once we have upgraded to new ABI
- AnyView* packed_args =
reinterpret_cast<AnyView*>(this->ArenaAlloc<TVMFFIAny>(num_args));
- LegacyTVMArgsToPackedArgs(values, tcodes, num_args, packed_args);
+ RPCReference::RecvPackedSeq(reinterpret_cast<TVMFFIAny**>(&packed_args),
&num_args, this);
return ffi::PackedArgs(packed_args, num_args);
}
@@ -426,12 +418,8 @@ class RPCEndpoint::EventHandler : public dmlc::Stream {
* \param args The arguments.
*/
void ReturnPackedSeq(ffi::PackedArgs args) {
- // Legacy ABI translation
- // TODO(tqchen): remove this once we have upgraded to new ABI
- TVMValue* values = this->ArenaAlloc<TVMValue>(args.size());
- int* tcodes = this->ArenaAlloc<int>(args.size());
- PackedArgsToLegacyTVMArgs(args.data(), args.size(), values, tcodes);
- RPCReference::ReturnPackedSeq(values, tcodes, args.size(), this);
+ RPCReference::ReturnPackedSeq(reinterpret_cast<const
TVMFFIAny*>(args.data()), args.size(),
+ this);
}
/*!
@@ -745,20 +733,15 @@ void RPCEndpoint::Init() {
std::lock_guard<std::mutex> lock(mutex_);
RPCCode code = static_cast<RPCCode>(all_args[0].cast<int>());
ffi::PackedArgs args = all_args.Slice(1);
- // Legacy ABI translation
- // TODO(tqchen): remove this once we have upgraded to new ABI
- TVMValue* values = handler_->ArenaAlloc<TVMValue>(args.size());
- int* tcodes = handler_->ArenaAlloc<int>(args.size());
- PackedArgsToLegacyTVMArgs(args.data(), args.size(), values, tcodes);
// run transmission
uint64_t packet_nbytes =
- sizeof(code) + handler_->PackedSeqGetNumBytes(values, tcodes,
args.size(), true);
+ sizeof(code) + handler_->PackedSeqGetNumBytes(args.data(),
args.size(), true);
// All packet begins with packet nbytes
handler_->Write(packet_nbytes);
handler_->Write(code);
- handler_->SendPackedSeq(values, tcodes, args.size(), true);
+ handler_->SendPackedSeq(args.data(), args.size(), true);
code = HandleUntilReturnEvent(true, [rv](ffi::PackedArgs args) {
ICHECK_EQ(args.size(), 1);
@@ -838,8 +821,12 @@ int RPCEndpoint::ServerAsyncIOEventHandler(const
std::string& in_bytes, int even
writer_.bytes_available());
}
ICHECK(code != RPCCode::kReturn && code != RPCCode::kCopyAck);
+ // if the code is kShutdown, return 0 to indicate the server should exit
if (code == RPCCode::kShutdown) return 0;
+ // if the writer has bytes available, return 2 to indicate the server should
send data
+ // usually by calling the handler again
if (writer_.bytes_available() != 0) return 2;
+ // otherwise, return 1 to indicate the server should continue reading
return 1;
}
@@ -849,22 +836,16 @@ void RPCEndpoint::InitRemoteSession(ffi::PackedArgs args)
{
std::string protocol_ver = kRPCProtocolVer;
uint64_t length = protocol_ver.length();
- // Legacy ABI translation
- // TODO(tqchen): remove this once we have upgraded to new ABI
- TVMValue* values = handler_->ArenaAlloc<TVMValue>(args.size());
- int* tcodes = handler_->ArenaAlloc<int>(args.size());
- PackedArgsToLegacyTVMArgs(args.data(), args.size(), values, tcodes);
-
// run transmission
uint64_t packet_nbytes = sizeof(code) + sizeof(length) + length +
- handler_->PackedSeqGetNumBytes(values, tcodes,
args.size(), true);
+ handler_->PackedSeqGetNumBytes(args.data(),
args.size(), true);
// All packet begins with packet nbytes
handler_->Write(packet_nbytes);
handler_->Write(code);
handler_->Write(length);
handler_->WriteArray(protocol_ver.data(), length);
- handler_->SendPackedSeq(values, tcodes, args.size(), true);
+ handler_->SendPackedSeq(args.data(), args.size(), true);
code = HandleUntilReturnEvent(true, [](ffi::PackedArgs args) {});
ICHECK(code == RPCCode::kReturn) << "code=" << static_cast<int>(code);
@@ -879,20 +860,14 @@ void RPCEndpoint::CallFunc(RPCSession::PackedFuncHandle
h, ffi::PackedArgs args,
RPCCode code = RPCCode::kCallFunc;
uint64_t handle = reinterpret_cast<uint64_t>(h);
- // Legacy ABI translation
- // TODO(tqchen): remove this once we have upgraded to new ABI
- TVMValue* values = handler_->ArenaAlloc<TVMValue>(args.size());
- int* tcodes = handler_->ArenaAlloc<int>(args.size());
- PackedArgsToLegacyTVMArgs(args.data(), args.size(), values, tcodes);
-
// run transmission
uint64_t packet_nbytes = sizeof(code) + sizeof(handle) +
- handler_->PackedSeqGetNumBytes(values, tcodes,
args.size(), true);
+ handler_->PackedSeqGetNumBytes(args.data(),
args.size(), true);
handler_->Write(packet_nbytes);
handler_->Write(code);
handler_->Write(handle);
- handler_->SendPackedSeq(values, tcodes, args.size(), true);
+ handler_->SendPackedSeq(args.data(), args.size(), true);
code = HandleUntilReturnEvent(true, encode_return);
ICHECK(code == RPCCode::kReturn) << "code=" << RPCCodeToString(code);
diff --git a/src/runtime/rpc/rpc_endpoint.h b/src/runtime/rpc/rpc_endpoint.h
index a420e6d92f..5d94aed470 100644
--- a/src/runtime/rpc/rpc_endpoint.h
+++ b/src/runtime/rpc/rpc_endpoint.h
@@ -34,7 +34,6 @@
#include "../../support/ring_buffer.h"
#include "../minrpc/rpc_reference.h"
#include "rpc_channel.h"
-#include "rpc_channel_logger.h"
#include "rpc_session.h"
namespace tvm {
diff --git a/src/runtime/rpc/rpc_local_session.cc
b/src/runtime/rpc/rpc_local_session.cc
index 5761828876..38b52181d5 100644
--- a/src/runtime/rpc/rpc_local_session.cc
+++ b/src/runtime/rpc/rpc_local_session.cc
@@ -47,10 +47,9 @@ void LocalSession::EncodeReturn(ffi::Any rv, const
FEncodeReturn& encode_return)
AnyView packed_args[3];
// NOTE: this is the place that we need to handle special RPC-related
// ABI convention for return value passing that is built on top of Any FFI.
- // We need to encode object pointers as opaque raw pointers for passing
- // TODO(tqchen): move to RPC to new ABI
+ // first argument is always the type index.
+ packed_args[0] = rv.type_index();
if (rv == nullptr) {
- packed_args[0] = static_cast<int32_t>(kTVMNullptr);
packed_args[1] = rv;
encode_return(ffi::PackedArgs(packed_args, 2));
} else if (rv.as<NDArray>()) {
@@ -59,43 +58,25 @@ void LocalSession::EncodeReturn(ffi::Any rv, const
FEncodeReturn& encode_return)
// The second pack value is a customized deleter that deletes the NDArray.
TVMFFIAny ret_any =
ffi::details::AnyUnsafe::MoveAnyToTVMFFIAny(std::move(rv));
void* opaque_handle = ret_any.v_obj;
- packed_args[0] = static_cast<int32_t>(kTVMNDArrayHandle);
- packed_args[1] =
-
static_cast<DLTensor*>(ObjectHandleToTVMArrayHandle(static_cast<Object*>(opaque_handle)));
+ packed_args[1] = TVMFFINDArrayGetDLTensorPtr(opaque_handle);
packed_args[2] = opaque_handle;
encode_return(ffi::PackedArgs(packed_args, 3));
} else if (const auto* bytes = rv.as<ffi::BytesObj>()) {
// always pass bytes as byte array
- packed_args[0] = static_cast<int32_t>(kTVMBytes);
TVMFFIByteArray byte_arr;
byte_arr.data = bytes->data;
byte_arr.size = bytes->size;
packed_args[1] = &byte_arr;
encode_return(ffi::PackedArgs(packed_args, 2));
} else if (const auto* str = rv.as<ffi::StringObj>()) {
- // always pass bytes as raw string
- packed_args[0] = static_cast<int32_t>(kTVMStr);
packed_args[1] = str->data;
encode_return(ffi::PackedArgs(packed_args, 2));
} else if (rv.as<ffi::ObjectRef>()) {
TVMFFIAny ret_any =
ffi::details::AnyUnsafe::MoveAnyToTVMFFIAny(std::move(rv));
void* opaque_handle = ret_any.v_obj;
packed_args[1] = opaque_handle;
- if (ret_any.type_index == ffi::TypeIndex::kTVMFFIModule) {
- packed_args[0] = static_cast<int32_t>(kTVMModuleHandle);
- } else if (ret_any.type_index == ffi::TypeIndex::kTVMFFIFunction) {
- packed_args[0] = static_cast<int32_t>(kTVMPackedFuncHandle);
- } else {
- packed_args[0] = static_cast<int32_t>(kTVMObjectHandle);
- }
encode_return(ffi::PackedArgs(packed_args, 2));
} else {
- AnyView temp = rv;
- TVMValue val;
- int type_code;
- AnyViewToLegacyTVMArgValue(temp.CopyToTVMFFIAny(), &val, &type_code);
- // normal POD encoding through rv
- packed_args[0] = type_code;
packed_args[1] = rv;
encode_return(ffi::PackedArgs(packed_args, 2));
}
diff --git a/src/runtime/rpc/rpc_module.cc b/src/runtime/rpc/rpc_module.cc
index ad42390dc6..a0315604bf 100644
--- a/src/runtime/rpc/rpc_module.cc
+++ b/src/runtime/rpc/rpc_module.cc
@@ -120,9 +120,7 @@ class RPCWrappedFunc : public Object {
case ffi::TypeIndex::kTVMFFIFunction:
case ffi::TypeIndex::kTVMFFIModule: {
packed_args[i] = UnwrapRemoteValueToHandle(args[i]);
- // hack, need to force set the type index to the correct one
- // so legacy RPC ABI translation can work
- // TODO(tqchen): remove this once we migrate to use new ABI as
transport
+ // the unwrapped value is a raw remote handle; restore the original Function/Module type index so SendPackedSeq encodes it as a handle
TVMFFIAny temp = packed_args[i].CopyToTVMFFIAny();
temp.type_index = args[i].type_index();
packed_args[i] = AnyView::CopyFromTVMFFIAny(temp);
@@ -290,30 +288,34 @@ void* RPCWrappedFunc::UnwrapRemoteValueToHandle(const
AnyView& arg) const {
}
void RPCWrappedFunc::WrapRemoteReturnToValue(ffi::PackedArgs args, ffi::Any*
rv) const {
- int tcode = args[0].cast<int>();
- // TODO(tqchen): move to RPC to new ABI
- if (tcode == kTVMNullptr) {
+ int type_index = args[0].cast<int>();
+ if (type_index == ffi::TypeIndex::kTVMFFINone) {
*rv = nullptr;
return;
- } else if (tcode == kTVMPackedFuncHandle) {
+ } else if (type_index == ffi::TypeIndex::kTVMFFIFunction) {
ICHECK_EQ(args.size(), 2);
void* handle = args[1].cast<void*>();
auto wf = std::make_shared<RPCWrappedFunc>(handle, sess_);
*rv = ffi::Function(
[wf](ffi::PackedArgs args, ffi::Any* rv) { return wf->operator()(args,
rv); });
- } else if (tcode == kTVMModuleHandle) {
+ } else if (type_index == ffi::TypeIndex::kTVMFFIModule) {
ICHECK_EQ(args.size(), 2);
void* handle = args[1].cast<void*>();
auto n = make_object<RPCModuleNode>(handle, sess_);
*rv = Module(n);
- } else if (tcode == kTVMNDArrayHandle || tcode == kTVMDLTensorHandle) {
+ } else if (type_index == ffi::TypeIndex::kTVMFFINDArray ||
+ type_index == ffi::TypeIndex::kTVMFFIDLTensorPtr) {
ICHECK_EQ(args.size(), 3);
auto tensor = args[1].cast<DLTensor*>();
void* nd_handle = args[2].cast<void*>();
*rv = NDArrayFromRemoteOpaqueHandle(sess_, tensor->data, tensor,
AddRPCSessionMask(tensor->device,
sess_->table_index()),
nd_handle);
- } else if (tcode == kTVMObjectHandle) {
+ } else if (type_index == ffi::TypeIndex::kTVMFFIBytes ||
+ type_index == ffi::TypeIndex::kTVMFFIStr) {
+ ICHECK_EQ(args.size(), 2);
+ *rv = args[1];
+ } else if (type_index >= ffi::TypeIndex::kTVMFFIStaticObjectBegin) {
ICHECK_EQ(args.size(), 2);
void* handle = args[1].cast<void*>();
auto n = make_object<RPCObjectRefObj>(handle, sess_);
diff --git a/src/runtime/rpc/rpc_socket_impl.cc
b/src/runtime/rpc/rpc_socket_impl.cc
index 286d143bad..f51117211a 100644
--- a/src/runtime/rpc/rpc_socket_impl.cc
+++ b/src/runtime/rpc/rpc_socket_impl.cc
@@ -98,9 +98,6 @@ std::shared_ptr<RPCEndpoint> RPCConnect(std::string url, int
port, std::string k
}
std::unique_ptr<RPCChannel> channel = std::make_unique<SockChannel>(sock);
- if (enable_logging) {
- channel.reset(new RPCChannelLogging(std::move(channel)));
- }
auto endpt = RPCEndpoint::Create(std::move(channel), key, remote_key);
endpt->InitRemoteSession(init_seq);
diff --git a/tests/python/runtime/test_runtime_rpc.py
b/tests/python/runtime/test_runtime_rpc.py
index 604d8eb42c..9ba89fe7f5 100644
--- a/tests/python/runtime/test_runtime_rpc.py
+++ b/tests/python/runtime/test_runtime_rpc.py
@@ -112,25 +112,6 @@ def test_rpc_simple():
check_remote()
[email protected]_rpc
-def test_rpc_simple_wlog():
- server = rpc.Server(key="x1")
- client = rpc.connect("127.0.0.1", server.port, key="x1",
enable_logging=True)
-
- def check_remote():
- f1 = client.get_function("rpc.test.addone")
- assert f1(10) == 11
- f3 = client.get_function("rpc.test.except")
-
- with pytest.raises(tvm._ffi.base.TVMError):
- f3("abc")
-
- f2 = client.get_function("rpc.test.strcat")
- assert f2("abc", 11) == "abc:11"
-
- check_remote()
-
-
@tvm.testing.requires_rpc
def test_rpc_runtime_string():
server = rpc.Server(key="x1")