This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new 9a985714f3 [Unity] Migrate Relax Executable/VM to `TVM_MODULE_VTABLE`
Convention (#16149)
9a985714f3 is described below
commit 9a985714f396d40a08458426423872283b886af8
Author: Junru Shao <[email protected]>
AuthorDate: Tue Nov 21 06:19:48 2023 -0800
[Unity] Migrate Relax Executable/VM to `TVM_MODULE_VTABLE` Convention
(#16149)
Following up on #16148, this PR migrates the Relax Executable/VM to the
explicit vtable introduced as part of the TVM runtime Module calling
convention.
---
include/tvm/runtime/packed_func.h | 16 +-
include/tvm/runtime/relax_vm/executable.h | 24 ++-
include/tvm/runtime/relax_vm/vm.h | 2 -
src/runtime/relax_vm/executable.cc | 45 ++---
src/runtime/relax_vm/vm.cc | 299 ++++++++++++++++--------------
5 files changed, 197 insertions(+), 189 deletions(-)
diff --git a/include/tvm/runtime/packed_func.h
b/include/tvm/runtime/packed_func.h
index eebdb288d1..7266f8c4a5 100644
--- a/include/tvm/runtime/packed_func.h
+++ b/include/tvm/runtime/packed_func.h
@@ -1153,13 +1153,19 @@ struct PackedFuncValueConverter {
}
\
}
-#define TVM_MODULE_VTABLE_BEGIN(TypeKey)
\
- const char* type_key() const final { return TypeKey; }
\
- PackedFunc GetFunction(const String& _name, const ObjectPtr<Object>& _self)
final { \
+#define TVM_MODULE_VTABLE_BEGIN(TypeKey)
\
+ const char* type_key() const final { return TypeKey; }
\
+ PackedFunc GetFunction(const String& _name, const ObjectPtr<Object>& _self)
override { \
using SelfPtr = std::remove_cv_t<decltype(this)>;
#define TVM_MODULE_VTABLE_END() \
return PackedFunc(nullptr); \
}
+#define TVM_MODULE_VTABLE_END_WITH_DEFAULT(MemFunc) \
+ { \
+ auto f = (MemFunc); \
+ return (this->*f)(_name); \
+ } \
+ } // NOLINT(*)
#define TVM_MODULE_VTABLE_ENTRY(Name, MemFunc)
\
if (_name == Name) {
\
return PackedFunc([_self](TVMArgs args, TVMRetValue* rv) -> void {
\
@@ -2234,6 +2240,6 @@ inline TVMArgValue::operator DLDataType() const {
inline TVMArgValue::operator DataType() const { return DataType(operator
DLDataType()); }
-} // namespace runtime
-} // namespace tvm
+} // namespace runtime // NOLINT(*)
+} // namespace tvm // NOLINT(*)
#endif // TVM_RUNTIME_PACKED_FUNC_H_
diff --git a/include/tvm/runtime/relax_vm/executable.h
b/include/tvm/runtime/relax_vm/executable.h
index 68c3cc7e15..845842f22a 100644
--- a/include/tvm/runtime/relax_vm/executable.h
+++ b/include/tvm/runtime/relax_vm/executable.h
@@ -25,6 +25,7 @@
#include <tvm/runtime/container/closure.h>
#include <tvm/runtime/object.h>
+#include <tvm/runtime/packed_func.h>
#include <tvm/runtime/registry.h>
#include <string>
@@ -87,14 +88,6 @@ struct VMFuncInfo {
*/
class Executable : public runtime::ModuleNode {
public:
- /*!
- * \brief Get a PackedFunc from the executable module.
- * \param name the name of the function.
- * \param sptr_to_self The shared_ptr that points to this module node.
- * \return PackedFunc or nullptr when it is not available.
- */
- PackedFunc GetFunction(const String& name, const ObjectPtr<Object>&
sptr_to_self) final;
-
/*! \brief Get the property of the runtime module .*/
int GetPropertyMask() const final { return
ModulePropertyMask::kBinarySerializable; };
@@ -144,6 +137,12 @@ class Executable : public runtime::ModuleNode {
* \param format The target format of the saved file.
*/
void SaveToFile(const String& file_name, const String& format) final;
+ /*! \brief Create a Relax virtual machine and load `this` as the executable.
*/
+ Module VMLoadExecutable() const;
+ /*! \brief Create a Relax virtual machine with profiler and load `this` as
the executable. */
+ Module VMProfilerLoadExecutable() const;
+ /*! \brief Check if the Executable contains a specific function. */
+ bool HasFunction(const String& name) const;
/*!
* \brief Load Executable from the file.
* \param file_name The path of the file that load the executable from.
@@ -164,7 +163,14 @@ class Executable : public runtime::ModuleNode {
virtual ~Executable() {}
- const char* type_key() const final { return "relax.Executable"; }
+ TVM_MODULE_VTABLE_BEGIN("relax.Executable");
+ TVM_MODULE_VTABLE_ENTRY("stats", &Executable::Stats);
+ TVM_MODULE_VTABLE_ENTRY("as_text", &Executable::AsText);
+ TVM_MODULE_VTABLE_ENTRY("as_python", &Executable::AsPython);
+ TVM_MODULE_VTABLE_ENTRY("vm_load_executable", &Executable::VMLoadExecutable);
+ TVM_MODULE_VTABLE_ENTRY("vm_profiler_load_executable",
&Executable::VMProfilerLoadExecutable);
+ TVM_MODULE_VTABLE_ENTRY("has_function", &Executable::HasFunction);
+ TVM_MODULE_VTABLE_END();
private:
/*!
diff --git a/include/tvm/runtime/relax_vm/vm.h
b/include/tvm/runtime/relax_vm/vm.h
index 4e1f1361be..d2c96e9e97 100644
--- a/include/tvm/runtime/relax_vm/vm.h
+++ b/include/tvm/runtime/relax_vm/vm.h
@@ -176,8 +176,6 @@ class VirtualMachine : public runtime::ModuleNode {
~VirtualMachine() {}
- const char* type_key() const final { return "relax.VirtualMachine"; }
-
//--------------------------------------------------------------------------
// The following section contains states that other builtin can depend on
//--------------------------------------------------------------------------
diff --git a/src/runtime/relax_vm/executable.cc
b/src/runtime/relax_vm/executable.cc
index 98be501ec2..f45786c3da 100644
--- a/src/runtime/relax_vm/executable.cc
+++ b/src/runtime/relax_vm/executable.cc
@@ -52,37 +52,6 @@ enum ConstantType : int {
ICHECK(val) << "Invalid VM file format in the " << section << " section." \
<< "\n";
-PackedFunc Executable::GetFunction(const String& name, const
ObjectPtr<Object>& sptr_to_self) {
- if (name == "stats") {
- return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
*rv = this->Stats(); });
- } else if (name == "as_text") {
- return PackedFunc(
- [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv =
this->AsText(); });
- } else if (name == "as_python") {
- return PackedFunc(
- [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv =
this->AsPython(); });
- } else if (name == "vm_load_executable") {
- return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
- ObjectPtr<VirtualMachine> vm = VirtualMachine::Create();
- ICHECK(sptr_to_self.get() == this);
- vm->LoadExecutable(GetObjectPtr<Executable>(this));
- *rv = Module(vm);
- });
- } else if (name == "vm_profiler_load_executable") {
- return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
- ObjectPtr<VirtualMachine> vm = VirtualMachine::CreateProfiler();
- ICHECK(sptr_to_self.get() == this);
- vm->LoadExecutable(GetObjectPtr<Executable>(this));
- *rv = Module(vm);
- });
- } else if (name == "has_function") {
- return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
- *rv = static_cast<bool>(this->func_map.count(args[0]));
- });
- }
- return nullptr;
-}
-
std::string Executable::Stats() const {
std::ostringstream oss;
oss << "Relax VM executable statistics:" << std::endl;
@@ -435,6 +404,20 @@ std::string RegNameToStr(RegName reg) {
return "%" + std::to_string(reg);
}
+Module Executable::VMLoadExecutable() const {
+ ObjectPtr<VirtualMachine> vm = VirtualMachine::Create();
+ vm->LoadExecutable(GetObjectPtr<Executable>(const_cast<Executable*>(this)));
+ return Module(vm);
+}
+
+Module Executable::VMProfilerLoadExecutable() const {
+ ObjectPtr<VirtualMachine> vm = VirtualMachine::CreateProfiler();
+ vm->LoadExecutable(GetObjectPtr<Executable>(const_cast<Executable*>(this)));
+ return Module(vm);
+}
+
+bool Executable::HasFunction(const String& name) const { return
func_map.count(name); }
+
String Executable::AsText() const {
auto get_func_name = [&](Index index) -> std::string {
if (static_cast<size_t>(index) < func_table.size()) {
diff --git a/src/runtime/relax_vm/vm.cc b/src/runtime/relax_vm/vm.cc
index 35d4a61d22..b31268e697 100644
--- a/src/runtime/relax_vm/vm.cc
+++ b/src/runtime/relax_vm/vm.cc
@@ -20,6 +20,8 @@
/*!
* \file src/runtime/relax_vm/vm.cc
*/
+#include <dlpack/dlpack.h>
+#include <tvm/runtime/memory/memory_manager.h>
#include <tvm/runtime/nvtx.h>
#include <tvm/runtime/packed_func.h>
#include <tvm/runtime/profiling.h>
@@ -183,21 +185,46 @@ class VirtualMachineImpl : public VirtualMachine {
// Public facing functions overloading
//---------------------------------------------------
void LoadExecutable(ObjectPtr<Executable> exec) final;
-
void Init(const std::vector<Device>& devices,
const std::vector<AllocatorType>& alloc_types) final;
-
- PackedFunc GetFunction(const String& name, const ObjectPtr<Object>&
sptr_to_self) override;
-
VMClosure GetClosure(const String& func_name) final {
return this->GetClosureInternal(func_name, false).value();
}
-
void InvokeClosurePacked(const ObjectRef& closure_or_packedfunc, TVMArgs
args,
TVMRetValue* rv) final;
-
void SetInstrument(PackedFunc instrument) final { this->instrument_ =
instrument; }
+ //---------------------------------------------------
+ // Functions in the vtable of Module
+ //---------------------------------------------------
+ void _Init(TVMArgs args, TVMRetValue* rv);
+ void _SaveClosure(TVMArgs args, TVMRetValue* rv);
+ void _InvokeClosure(TVMArgs args, TVMRetValue* rv);
+ void _InvokeClosureStateful(std::string func_name);
+ void _SetInstrument(TVMArgs args, TVMRetValue* rv);
+ void _GetOutputArity(TVMArgs args, TVMRetValue* rv);
+ void _GetOutput(TVMArgs args, TVMRetValue* rv);
+ void _SetInputWithoutParamModule(TVMArgs args, TVMRetValue* rv);
+ void _SetInputWithParamModule(TVMArgs args, TVMRetValue* rv);
+ int _GetFunctionArity(std::string func_name);
+ std::string _GetFunctionParamName(std::string func_name, int index);
+ PackedFunc _LookupFunction(const String& name);
+
+ TVM_MODULE_VTABLE_BEGIN("relax.VirtualMachine");
+ TVM_MODULE_VTABLE_ENTRY_PACKED("vm_initialization",
&VirtualMachineImpl::_Init);
+ TVM_MODULE_VTABLE_ENTRY_PACKED("save_function",
&VirtualMachineImpl::_SaveClosure);
+ TVM_MODULE_VTABLE_ENTRY_PACKED("invoke_closure",
&VirtualMachineImpl::_InvokeClosure);
+ TVM_MODULE_VTABLE_ENTRY("invoke_stateful",
&VirtualMachineImpl::_InvokeClosureStateful);
+ TVM_MODULE_VTABLE_ENTRY_PACKED("set_instrument",
&VirtualMachineImpl::_SetInstrument);
+ TVM_MODULE_VTABLE_ENTRY_PACKED("get_output_arity",
&VirtualMachineImpl::_GetOutputArity);
+ TVM_MODULE_VTABLE_ENTRY_PACKED("get_output",
&VirtualMachineImpl::_GetOutput);
+ TVM_MODULE_VTABLE_ENTRY_PACKED("set_input",
&VirtualMachineImpl::_SetInputWithoutParamModule);
+ TVM_MODULE_VTABLE_ENTRY_PACKED("set_input_with_param_module",
+
&VirtualMachineImpl::_SetInputWithParamModule);
+ TVM_MODULE_VTABLE_ENTRY("get_function_arity",
&VirtualMachineImpl::_GetFunctionArity);
+ TVM_MODULE_VTABLE_ENTRY("get_function_param_name",
&VirtualMachineImpl::_GetFunctionParamName);
+ TVM_MODULE_VTABLE_END_WITH_DEFAULT(&VirtualMachineImpl::_LookupFunction);
+
//--------------------------------------------------
// Additional support arguments functions for VM
//--------------------------------------------------
@@ -214,14 +241,13 @@ class VirtualMachineImpl : public VirtualMachine {
* \param func_name The function name.
* \param args args[offset:] are arguments to the function. If the arguments
are not of the
* correct device for the function, they will be copied to the device.
- * \param offset Starting offset of the arguments in \p args.
* \param with_param_module If set to true, the last argument will be a
module and can be invoked
* to get the argument, this is mainly used for debugging purposes
and setting composite
* objects. \note This interface works when using VM over RPC by internally
converting NDArray in
* the arguments to DLTensor, which is supported in RPC where remote could
only have a minimal C
* runtime.
*/
- void SetInput(std::string func_name, TVMArgs args, int offset, bool
with_param_module = false);
+ void SetInput(std::string func_name, bool with_param_module, TVMArgs args);
/*!
* \brief Look up whether the VM has a function by the given name.
@@ -455,149 +481,21 @@ RegType VirtualMachineImpl::LookupVMOutput(const
std::string& func_name) {
return outputs_[func_name];
}
-PackedFunc VirtualMachineImpl::GetFunction(const String& name,
- const ObjectPtr<Object>&
sptr_to_self) {
- if (name == "vm_initialization") {
- // initialize the VirtualMachine, takes variable-length arguments
- // first argument is a runtime::Module, followed by one or more
device_type, device_id,
- // and the AllocatorType associated with the device.
- return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
- ICHECK_EQ(args.size() % 3, 0);
- std::vector<Device> devices;
- std::vector<AllocatorType> alloc_types;
- for (int i = 0; i < args.size(); i += 3) {
- Device dev;
- int device_type = args[i];
- dev.device_type = DLDeviceType(device_type);
- dev.device_id = args[i + 1];
- int type = args[i + 2];
- devices.push_back(dev);
- alloc_types.push_back(AllocatorType(type));
- }
- this->Init(devices, alloc_types);
- });
- } else if (name == "save_function") {
- return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
- ICHECK_GE(args.size(), 3);
- this->SaveClosure(args[0], args[1], args[2],
- TVMArgs(args.values + 3, args.type_codes + 3,
args.size() - 3));
- });
- } else if (name == "invoke_closure") {
- return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
- VMClosure clo = args[0];
- this->InvokeClosurePacked(clo, TVMArgs(args.values + 1, args.type_codes
+ 1, args.size() - 1),
- rv);
- });
- } else if (name == "set_instrument") {
- return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
- PackedFunc func;
- if (args[0].type_code() != kTVMPackedFuncHandle) {
- String func_name = args[0];
- const PackedFunc* factory = Registry::Get(func_name);
- ICHECK(factory != nullptr) << "Cannot find factory " << func_name;
- TVMRetValue rv;
- factory->CallPacked(TVMArgs(args.values + 1, args.type_codes + 1,
args.num_args - 1), &rv);
- func = rv;
- } else {
- func = args[0];
- }
- this->SetInstrument(func);
- });
- } else if (name == "invoke_stateful") {
- return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
- std::string func_name = args[0];
- const auto& m = this->exec_->func_map;
- if (m.find(func_name) == m.end()) {
- LOG(FATAL) << "ValueError: Unknown function: " << func_name;
- }
- Index gf_idx = m.at(func_name);
- if (!inputs_.count(func_name)) {
- LOG(FATAL) << "ValueError: No inputs set for stateful call of " <<
func_name
- << "; use `set_input` first.";
- return;
- }
- outputs_[func_name] = this->InvokeClosureInternal(func_pool_[gf_idx],
inputs_[func_name]);
- });
- } else if (name == "get_output_arity") {
- return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
- std::string func_name = args[0];
- RegType out = LookupVMOutput(func_name);
- // use remaining args as indices
- ObjectRef obj = IndexIntoNestedObject(out.AsObjectRef<ObjectRef>(),
args, 1);
- // after chasing through the indices, examine the final object
- if (const auto* arr = obj.as<ArrayNode>()) {
- *rv = static_cast<int>(arr->size());
- } else {
- *rv = -1;
- }
- });
- } else if (name == "get_output") {
- return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
- std::string func_name = args[0];
- RegType out = LookupVMOutput(func_name);
- // use remaining args as indices
- ObjectRef obj = IndexIntoNestedObject(out.AsObjectRef<ObjectRef>(),
args, 1);
- if (obj.as<ArrayNode>()) {
- LOG(FATAL) << "ValueError: `get_output` cannot return a tuple for RPC
compatibility. "
- "Please specify another index argument.";
- return;
- }
- *rv = obj;
- });
- } else if (name == "set_input") {
- return PackedFunc(
- [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
SetInput(args[0], args, 1); });
- } else if (name == "set_input_with_param_module") {
- return PackedFunc(
- [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
SetInput(args[0], args, 1, true); });
- } else if (name == "get_function_arity") {
- return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
- std::string func_name = args[0];
- const VMFuncInfo& vm_func = LookupVMFuncInfo(func_name);
- *rv = static_cast<int>(vm_func.param_names.size());
- });
- } else if (name == "get_function_param_name") {
- return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
- std::string func_name = args[0];
- int index = args[1];
- const VMFuncInfo& vm_func = LookupVMFuncInfo(func_name);
- if (static_cast<size_t>(index) >= vm_func.param_names.size()) {
- LOG(FATAL) << "ValueError: Invalid index for " << func_name << " (" <<
index << " out of "
- << vm_func.param_names.size() << ")";
- }
- *rv = vm_func.param_names[index];
- });
- } else {
- // default case, look up closure in VM.
- if (Optional<VMClosure> opt = this->GetClosureInternal(name, true)) {
- auto clo = opt.value();
- return PackedFunc([sptr_to_self, this, clo](TVMArgs args, TVMRetValue*
rv) {
- this->InvokeClosurePacked(clo, args, rv);
- });
- } else {
- return PackedFunc(nullptr);
- }
- }
-}
-
-void VirtualMachineImpl::SetInput(std::string func_name, TVMArgs args, int
offset,
- bool with_param_module) {
+void VirtualMachineImpl::SetInput(std::string func_name, bool
with_param_module, TVMArgs args) {
const auto& m = exec_->func_map;
if (m.find(func_name) != m.end()) {
Index gf_idx = m.at(func_name);
const VMFuncInfo& vm_func = exec_->func_table[gf_idx];
size_t params_num = vm_func.num_args;
- ICHECK_EQ(args.size() - offset, params_num)
+ ICHECK_EQ(args.size(), params_num)
<< "The number of provided parameters doesn't match the number of
arguments for";
std::vector<RegType> func_args(params_num);
-
- for (int i = offset; i < args.size(); ++i) {
- int index = i - offset;
+ for (int i = 0; i < args.size(); ++i) {
if (with_param_module && i == args.size() - 1) {
// call param func to get the arguments(usually corresponds to param
pack.)
- func_args[index] = (args[i].operator
Module()).GetFunction("get_params")();
+ func_args[i] = (args[i].operator Module()).GetFunction("get_params")();
} else {
- func_args[index] = ConvertArgToDevice(args[i], devices[0],
allocators[0]);
+ func_args[i] = ConvertArgToDevice(args[i], devices[0], allocators[0]);
}
}
inputs_[func_name] = func_args;
@@ -926,6 +824,123 @@ void VirtualMachineImpl::RunLoop() {
ObjectPtr<VirtualMachine> VirtualMachine::Create() { return
make_object<VirtualMachineImpl>(); }
+//--------------------------------------------------------------------
+// FFI related code
+//--------------------------------------------------------------------
+
+void VirtualMachineImpl::_Init(TVMArgs args, TVMRetValue* rv) {
+ ICHECK_EQ(args.size() % 3, 0);
+ std::vector<Device> devices;
+ std::vector<AllocatorType> alloc_types;
+ for (int i = 0; i < args.size(); i += 3) {
+ int device_type = args[i];
+ int device_id = args[i + 1];
+ int alloc_type = args[i + 2];
+ devices.push_back(Device{DLDeviceType(device_type), device_id});
+ alloc_types.push_back(AllocatorType(alloc_type));
+ }
+ this->Init(devices, alloc_types);
+}
+
+void VirtualMachineImpl::_SaveClosure(TVMArgs args, TVMRetValue* rv) {
+ ICHECK_GE(args.size(), 3);
+ std::string func_name = args[0];
+ this->SaveClosure(func_name, args[1], args[2],
+ TVMArgs(args.values + 3, args.type_codes + 3, args.size()
- 3));
+}
+
+void VirtualMachineImpl::_InvokeClosure(TVMArgs args, TVMRetValue* rv) {
+ this->InvokeClosurePacked(args[0], TVMArgs(args.values + 1, args.type_codes
+ 1, args.size() - 1),
+ rv);
+}
+
+void VirtualMachineImpl::_InvokeClosureStateful(std::string func_name) {
+ const std::unordered_map<std::string, Index>& m = this->exec_->func_map;
+ if (m.find(func_name) == m.end()) {
+ LOG(FATAL) << "ValueError: Unknown function: " << func_name;
+ }
+ if (!inputs_.count(func_name)) {
+ LOG(FATAL) << "ValueError: No inputs set for stateful call of " <<
func_name
+ << "; use `set_input` first.";
+ return;
+ }
+ outputs_[func_name] =
+ this->InvokeClosureInternal(func_pool_[m.at(func_name)],
inputs_[func_name]);
+}
+
+void VirtualMachineImpl::_SetInstrument(TVMArgs args, TVMRetValue* rv) {
+ if (args[0].type_code() == kTVMPackedFuncHandle) {
+ this->SetInstrument(args[0]);
+ } else {
+ String func_name = args[0];
+ const PackedFunc* factory = Registry::Get(func_name);
+ CHECK(factory) << "Cannot find factory " << func_name;
+ TVMRetValue rv;
+ factory->CallPacked(TVMArgs(args.values + 1, args.type_codes + 1,
args.num_args - 1), &rv);
+ this->SetInstrument(rv);
+ }
+}
+
+void VirtualMachineImpl::_GetOutputArity(TVMArgs args, TVMRetValue* rv) {
+ std::string func_name = args[0];
+ RegType out = LookupVMOutput(func_name);
+ ObjectRef obj = IndexIntoNestedObject(out.AsObjectRef<ObjectRef>(), args, 1);
+ if (const auto* arr = obj.as<ArrayNode>()) {
+ *rv = static_cast<int>(arr->size());
+ } else {
+ *rv = -1;
+ }
+}
+
+void VirtualMachineImpl::_GetOutput(TVMArgs args, TVMRetValue* rv) {
+ std::string func_name = args[0];
+ RegType out = LookupVMOutput(func_name);
+ ObjectRef obj = IndexIntoNestedObject(out.AsObjectRef<ObjectRef>(), args, 1);
+ if (obj.as<ArrayNode>()) {
+ LOG(FATAL) << "ValueError: `get_output` cannot return a tuple for RPC
compatibility. "
+ "Please specify another index argument.";
+ return;
+ }
+ *rv = obj;
+}
+
+void VirtualMachineImpl::_SetInputWithoutParamModule(TVMArgs args,
TVMRetValue* rv) {
+ std::string func_name = args[0];
+ this->SetInput(func_name, false,
+ TVMArgs(args.values + 1, args.type_codes + 1, args.num_args -
1));
+}
+
+void VirtualMachineImpl::_SetInputWithParamModule(TVMArgs args, TVMRetValue*
rv) {
+ std::string func_name = args[0];
+ this->SetInput(func_name, true, TVMArgs(args.values + 1, args.type_codes +
1, args.num_args - 1));
+}
+
+int VirtualMachineImpl::_GetFunctionArity(std::string func_name) {
+ const VMFuncInfo& vm_func = LookupVMFuncInfo(func_name);
+ return vm_func.param_names.size();
+}
+
+std::string VirtualMachineImpl::_GetFunctionParamName(std::string func_name,
int index) {
+ const VMFuncInfo& vm_func = LookupVMFuncInfo(func_name);
+ if (static_cast<size_t>(index) >= vm_func.param_names.size()) {
+ LOG(FATAL) << "ValueError: Invalid index for " << func_name << " (" <<
index << " out of "
+ << vm_func.param_names.size() << ")";
+ }
+ return vm_func.param_names[index];
+}
+
+PackedFunc VirtualMachineImpl::_LookupFunction(const String& name) {
+ if (Optional<VMClosure> opt = this->GetClosureInternal(name, true)) {
+ return PackedFunc(
+ [clo = opt.value(), _self = GetRef<Module>(this)](TVMArgs args,
TVMRetValue* rv) -> void {
+ auto* self =
const_cast<VirtualMachineImpl*>(_self.as<VirtualMachineImpl>());
+ ICHECK(self);
+ self->InvokeClosurePacked(clo, args, rv);
+ });
+ }
+ return PackedFunc(nullptr);
+}
+
//----------------------------------------------------------------
// Profiler can be optionally disabled via a macro to reduce dep.
//----------------------------------------------------------------
@@ -958,7 +973,7 @@ class VirtualMachineProfiler : public VirtualMachineImpl {
if (inputs.size() == 0) {
ICHECK(args.num_args > 1) << "No input is provided";
TVMArgs f_args(args.values + 1, args.type_codes + 1, args.num_args -
1);
- SetInput(f_name, args, 1);
+ SetInput(f_name, false, TVMArgs(args.values + 1, args.type_codes +
1, args.num_args - 1));
inputs = GetInputsFor(f_name);
clear_inputs = true;
} else {