This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new 9a985714f3 [Unity] Migrate Relax Executable/VM to `TVM_MODULE_VTABLE` Convention (#16149)
9a985714f3 is described below

commit 9a985714f396d40a08458426423872283b886af8
Author: Junru Shao <[email protected]>
AuthorDate: Tue Nov 21 06:19:48 2023 -0800

    [Unity] Migrate Relax Executable/VM to `TVM_MODULE_VTABLE` Convention (#16149)
    
    Following up on #16148, this PR migrates Relax Executable/VM to the
    explicit vtable introduced as part of the TVM runtime Module calling
    convention.
---
 include/tvm/runtime/packed_func.h         |  16 +-
 include/tvm/runtime/relax_vm/executable.h |  24 ++-
 include/tvm/runtime/relax_vm/vm.h         |   2 -
 src/runtime/relax_vm/executable.cc        |  45 ++---
 src/runtime/relax_vm/vm.cc                | 299 ++++++++++++++++--------------
 5 files changed, 197 insertions(+), 189 deletions(-)

diff --git a/include/tvm/runtime/packed_func.h 
b/include/tvm/runtime/packed_func.h
index eebdb288d1..7266f8c4a5 100644
--- a/include/tvm/runtime/packed_func.h
+++ b/include/tvm/runtime/packed_func.h
@@ -1153,13 +1153,19 @@ struct PackedFuncValueConverter {
   }                                                                            
             \
   }
 
-#define TVM_MODULE_VTABLE_BEGIN(TypeKey)                                       
       \
-  const char* type_key() const final { return TypeKey; }                       
       \
-  PackedFunc GetFunction(const String& _name, const ObjectPtr<Object>& _self) 
final { \
+#define TVM_MODULE_VTABLE_BEGIN(TypeKey)                                       
          \
+  const char* type_key() const final { return TypeKey; }                       
          \
+  PackedFunc GetFunction(const String& _name, const ObjectPtr<Object>& _self) 
override { \
     using SelfPtr = std::remove_cv_t<decltype(this)>;
 #define TVM_MODULE_VTABLE_END() \
   return PackedFunc(nullptr);   \
   }
+#define TVM_MODULE_VTABLE_END_WITH_DEFAULT(MemFunc) \
+  {                                                 \
+    auto f = (MemFunc);                             \
+    return (this->*f)(_name);                       \
+  }                                                 \
+  }  // NOLINT(*)
 #define TVM_MODULE_VTABLE_ENTRY(Name, MemFunc)                                 
                   \
   if (_name == Name) {                                                         
                   \
     return PackedFunc([_self](TVMArgs args, TVMRetValue* rv) -> void {         
                   \
@@ -2234,6 +2240,6 @@ inline TVMArgValue::operator DLDataType() const {
 
 inline TVMArgValue::operator DataType() const { return DataType(operator 
DLDataType()); }
 
-}  // namespace runtime
-}  // namespace tvm
+}  // namespace runtime // NOLINT(*)
+}  // namespace tvm // NOLINT(*)
 #endif  // TVM_RUNTIME_PACKED_FUNC_H_
diff --git a/include/tvm/runtime/relax_vm/executable.h 
b/include/tvm/runtime/relax_vm/executable.h
index 68c3cc7e15..845842f22a 100644
--- a/include/tvm/runtime/relax_vm/executable.h
+++ b/include/tvm/runtime/relax_vm/executable.h
@@ -25,6 +25,7 @@
 
 #include <tvm/runtime/container/closure.h>
 #include <tvm/runtime/object.h>
+#include <tvm/runtime/packed_func.h>
 #include <tvm/runtime/registry.h>
 
 #include <string>
@@ -87,14 +88,6 @@ struct VMFuncInfo {
  */
 class Executable : public runtime::ModuleNode {
  public:
-  /*!
-   * \brief Get a PackedFunc from the executable module.
-   * \param name the name of the function.
-   * \param sptr_to_self The shared_ptr that points to this module node.
-   * \return PackedFunc or nullptr when it is not available.
-   */
-  PackedFunc GetFunction(const String& name, const ObjectPtr<Object>& 
sptr_to_self) final;
-
   /*! \brief Get the property of the runtime module .*/
   int GetPropertyMask() const final { return 
ModulePropertyMask::kBinarySerializable; };
 
@@ -144,6 +137,12 @@ class Executable : public runtime::ModuleNode {
    * \param format The target format of the saved file.
    */
   void SaveToFile(const String& file_name, const String& format) final;
+  /*! \brief Create a Relax virtual machine and load `this` as the executable. 
*/
+  Module VMLoadExecutable() const;
+  /*! \brief Create a Relax virtual machine with profiler and load `this` as 
the executable. */
+  Module VMProfilerLoadExecutable() const;
+  /*! \brief Check if the Executable contains a specific function. */
+  bool HasFunction(const String& name) const;
   /*!
    * \brief Load Executable from the file.
    * \param file_name The path of the file that load the executable from.
@@ -164,7 +163,14 @@ class Executable : public runtime::ModuleNode {
 
   virtual ~Executable() {}
 
-  const char* type_key() const final { return "relax.Executable"; }
+  TVM_MODULE_VTABLE_BEGIN("relax.Executable");
+  TVM_MODULE_VTABLE_ENTRY("stats", &Executable::Stats);
+  TVM_MODULE_VTABLE_ENTRY("as_text", &Executable::AsText);
+  TVM_MODULE_VTABLE_ENTRY("as_python", &Executable::AsPython);
+  TVM_MODULE_VTABLE_ENTRY("vm_load_executable", &Executable::VMLoadExecutable);
+  TVM_MODULE_VTABLE_ENTRY("vm_profiler_load_executable", 
&Executable::VMProfilerLoadExecutable);
+  TVM_MODULE_VTABLE_ENTRY("has_function", &Executable::HasFunction);
+  TVM_MODULE_VTABLE_END();
 
  private:
   /*!
diff --git a/include/tvm/runtime/relax_vm/vm.h 
b/include/tvm/runtime/relax_vm/vm.h
index 4e1f1361be..d2c96e9e97 100644
--- a/include/tvm/runtime/relax_vm/vm.h
+++ b/include/tvm/runtime/relax_vm/vm.h
@@ -176,8 +176,6 @@ class VirtualMachine : public runtime::ModuleNode {
 
   ~VirtualMachine() {}
 
-  const char* type_key() const final { return "relax.VirtualMachine"; }
-
   //--------------------------------------------------------------------------
   // The following section contains states that other builtin can depend on
   //--------------------------------------------------------------------------
diff --git a/src/runtime/relax_vm/executable.cc 
b/src/runtime/relax_vm/executable.cc
index 98be501ec2..f45786c3da 100644
--- a/src/runtime/relax_vm/executable.cc
+++ b/src/runtime/relax_vm/executable.cc
@@ -52,37 +52,6 @@ enum ConstantType : int {
   ICHECK(val) << "Invalid VM file format in the " << section << " section." \
               << "\n";
 
-PackedFunc Executable::GetFunction(const String& name, const 
ObjectPtr<Object>& sptr_to_self) {
-  if (name == "stats") {
-    return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { 
*rv = this->Stats(); });
-  } else if (name == "as_text") {
-    return PackedFunc(
-        [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = 
this->AsText(); });
-  } else if (name == "as_python") {
-    return PackedFunc(
-        [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = 
this->AsPython(); });
-  } else if (name == "vm_load_executable") {
-    return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
-      ObjectPtr<VirtualMachine> vm = VirtualMachine::Create();
-      ICHECK(sptr_to_self.get() == this);
-      vm->LoadExecutable(GetObjectPtr<Executable>(this));
-      *rv = Module(vm);
-    });
-  } else if (name == "vm_profiler_load_executable") {
-    return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
-      ObjectPtr<VirtualMachine> vm = VirtualMachine::CreateProfiler();
-      ICHECK(sptr_to_self.get() == this);
-      vm->LoadExecutable(GetObjectPtr<Executable>(this));
-      *rv = Module(vm);
-    });
-  } else if (name == "has_function") {
-    return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
-      *rv = static_cast<bool>(this->func_map.count(args[0]));
-    });
-  }
-  return nullptr;
-}
-
 std::string Executable::Stats() const {
   std::ostringstream oss;
   oss << "Relax VM executable statistics:" << std::endl;
@@ -435,6 +404,20 @@ std::string RegNameToStr(RegName reg) {
   return "%" + std::to_string(reg);
 }
 
+Module Executable::VMLoadExecutable() const {
+  ObjectPtr<VirtualMachine> vm = VirtualMachine::Create();
+  vm->LoadExecutable(GetObjectPtr<Executable>(const_cast<Executable*>(this)));
+  return Module(vm);
+}
+
+Module Executable::VMProfilerLoadExecutable() const {
+  ObjectPtr<VirtualMachine> vm = VirtualMachine::CreateProfiler();
+  vm->LoadExecutable(GetObjectPtr<Executable>(const_cast<Executable*>(this)));
+  return Module(vm);
+}
+
+bool Executable::HasFunction(const String& name) const { return 
func_map.count(name); }
+
 String Executable::AsText() const {
   auto get_func_name = [&](Index index) -> std::string {
     if (static_cast<size_t>(index) < func_table.size()) {
diff --git a/src/runtime/relax_vm/vm.cc b/src/runtime/relax_vm/vm.cc
index 35d4a61d22..b31268e697 100644
--- a/src/runtime/relax_vm/vm.cc
+++ b/src/runtime/relax_vm/vm.cc
@@ -20,6 +20,8 @@
 /*!
  * \file src/runtime/relax_vm/vm.cc
  */
+#include <dlpack/dlpack.h>
+#include <tvm/runtime/memory/memory_manager.h>
 #include <tvm/runtime/nvtx.h>
 #include <tvm/runtime/packed_func.h>
 #include <tvm/runtime/profiling.h>
@@ -183,21 +185,46 @@ class VirtualMachineImpl : public VirtualMachine {
   // Public facing functions overloading
   //---------------------------------------------------
   void LoadExecutable(ObjectPtr<Executable> exec) final;
-
   void Init(const std::vector<Device>& devices,
             const std::vector<AllocatorType>& alloc_types) final;
-
-  PackedFunc GetFunction(const String& name, const ObjectPtr<Object>& 
sptr_to_self) override;
-
   VMClosure GetClosure(const String& func_name) final {
     return this->GetClosureInternal(func_name, false).value();
   }
-
   void InvokeClosurePacked(const ObjectRef& closure_or_packedfunc, TVMArgs 
args,
                            TVMRetValue* rv) final;
-
   void SetInstrument(PackedFunc instrument) final { this->instrument_ = 
instrument; }
 
+  //---------------------------------------------------
+  // Functions in the vtable of Module
+  //---------------------------------------------------
+  void _Init(TVMArgs args, TVMRetValue* rv);
+  void _SaveClosure(TVMArgs args, TVMRetValue* rv);
+  void _InvokeClosure(TVMArgs args, TVMRetValue* rv);
+  void _InvokeClosureStateful(std::string func_name);
+  void _SetInstrument(TVMArgs args, TVMRetValue* rv);
+  void _GetOutputArity(TVMArgs args, TVMRetValue* rv);
+  void _GetOutput(TVMArgs args, TVMRetValue* rv);
+  void _SetInputWithoutParamModule(TVMArgs args, TVMRetValue* rv);
+  void _SetInputWithParamModule(TVMArgs args, TVMRetValue* rv);
+  int _GetFunctionArity(std::string func_name);
+  std::string _GetFunctionParamName(std::string func_name, int index);
+  PackedFunc _LookupFunction(const String& name);
+
+  TVM_MODULE_VTABLE_BEGIN("relax.VirtualMachine");
+  TVM_MODULE_VTABLE_ENTRY_PACKED("vm_initialization", 
&VirtualMachineImpl::_Init);
+  TVM_MODULE_VTABLE_ENTRY_PACKED("save_function", 
&VirtualMachineImpl::_SaveClosure);
+  TVM_MODULE_VTABLE_ENTRY_PACKED("invoke_closure", 
&VirtualMachineImpl::_InvokeClosure);
+  TVM_MODULE_VTABLE_ENTRY("invoke_stateful", 
&VirtualMachineImpl::_InvokeClosureStateful);
+  TVM_MODULE_VTABLE_ENTRY_PACKED("set_instrument", 
&VirtualMachineImpl::_SetInstrument);
+  TVM_MODULE_VTABLE_ENTRY_PACKED("get_output_arity", 
&VirtualMachineImpl::_GetOutputArity);
+  TVM_MODULE_VTABLE_ENTRY_PACKED("get_output", 
&VirtualMachineImpl::_GetOutput);
+  TVM_MODULE_VTABLE_ENTRY_PACKED("set_input", 
&VirtualMachineImpl::_SetInputWithoutParamModule);
+  TVM_MODULE_VTABLE_ENTRY_PACKED("set_input_with_param_module",
+                                 
&VirtualMachineImpl::_SetInputWithParamModule);
+  TVM_MODULE_VTABLE_ENTRY("get_function_arity", 
&VirtualMachineImpl::_GetFunctionArity);
+  TVM_MODULE_VTABLE_ENTRY("get_function_param_name", 
&VirtualMachineImpl::_GetFunctionParamName);
+  TVM_MODULE_VTABLE_END_WITH_DEFAULT(&VirtualMachineImpl::_LookupFunction);
+
   //--------------------------------------------------
   // Additional support arguments functions for VM
   //--------------------------------------------------
@@ -214,14 +241,13 @@ class VirtualMachineImpl : public VirtualMachine {
    * \param func_name The function name.
    * \param args args[offset:] are arguments to the function. If the arguments 
are not of the
    * correct device for the function, they will be copied to the device.
-   * \param offset Starting offset of the arguments in \p args.
    * \param with_param_module If set to true, the last argument will be a 
module and can be invoked
    *        to get the argument, this is mainly used for debugging purposes 
and setting composite
    * objects. \note This interface works when using VM over RPC by internally 
converting NDArray in
    * the arguments to DLTensor, which is supported in RPC where remote could 
only have a minimal C
    * runtime.
    */
-  void SetInput(std::string func_name, TVMArgs args, int offset, bool 
with_param_module = false);
+  void SetInput(std::string func_name, bool with_param_module, TVMArgs args);
 
   /*!
    * \brief Look up whether the VM has a function by the given name.
@@ -455,149 +481,21 @@ RegType VirtualMachineImpl::LookupVMOutput(const 
std::string& func_name) {
   return outputs_[func_name];
 }
 
-PackedFunc VirtualMachineImpl::GetFunction(const String& name,
-                                           const ObjectPtr<Object>& 
sptr_to_self) {
-  if (name == "vm_initialization") {
-    // initialize the VirtualMachine, takes variable-length arguments
-    // first argument is a runtime::Module, followed by one or more 
device_type, device_id,
-    // and the AllocatorType associated with the device.
-    return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
-      ICHECK_EQ(args.size() % 3, 0);
-      std::vector<Device> devices;
-      std::vector<AllocatorType> alloc_types;
-      for (int i = 0; i < args.size(); i += 3) {
-        Device dev;
-        int device_type = args[i];
-        dev.device_type = DLDeviceType(device_type);
-        dev.device_id = args[i + 1];
-        int type = args[i + 2];
-        devices.push_back(dev);
-        alloc_types.push_back(AllocatorType(type));
-      }
-      this->Init(devices, alloc_types);
-    });
-  } else if (name == "save_function") {
-    return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
-      ICHECK_GE(args.size(), 3);
-      this->SaveClosure(args[0], args[1], args[2],
-                        TVMArgs(args.values + 3, args.type_codes + 3, 
args.size() - 3));
-    });
-  } else if (name == "invoke_closure") {
-    return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
-      VMClosure clo = args[0];
-      this->InvokeClosurePacked(clo, TVMArgs(args.values + 1, args.type_codes 
+ 1, args.size() - 1),
-                                rv);
-    });
-  } else if (name == "set_instrument") {
-    return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
-      PackedFunc func;
-      if (args[0].type_code() != kTVMPackedFuncHandle) {
-        String func_name = args[0];
-        const PackedFunc* factory = Registry::Get(func_name);
-        ICHECK(factory != nullptr) << "Cannot find factory " << func_name;
-        TVMRetValue rv;
-        factory->CallPacked(TVMArgs(args.values + 1, args.type_codes + 1, 
args.num_args - 1), &rv);
-        func = rv;
-      } else {
-        func = args[0];
-      }
-      this->SetInstrument(func);
-    });
-  } else if (name == "invoke_stateful") {
-    return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
-      std::string func_name = args[0];
-      const auto& m = this->exec_->func_map;
-      if (m.find(func_name) == m.end()) {
-        LOG(FATAL) << "ValueError: Unknown function: " << func_name;
-      }
-      Index gf_idx = m.at(func_name);
-      if (!inputs_.count(func_name)) {
-        LOG(FATAL) << "ValueError: No inputs set for stateful call of " << 
func_name
-                   << "; use `set_input` first.";
-        return;
-      }
-      outputs_[func_name] = this->InvokeClosureInternal(func_pool_[gf_idx], 
inputs_[func_name]);
-    });
-  } else if (name == "get_output_arity") {
-    return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
-      std::string func_name = args[0];
-      RegType out = LookupVMOutput(func_name);
-      // use remaining args as indices
-      ObjectRef obj = IndexIntoNestedObject(out.AsObjectRef<ObjectRef>(), 
args, 1);
-      // after chasing through the indices, examine the final object
-      if (const auto* arr = obj.as<ArrayNode>()) {
-        *rv = static_cast<int>(arr->size());
-      } else {
-        *rv = -1;
-      }
-    });
-  } else if (name == "get_output") {
-    return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
-      std::string func_name = args[0];
-      RegType out = LookupVMOutput(func_name);
-      // use remaining args as indices
-      ObjectRef obj = IndexIntoNestedObject(out.AsObjectRef<ObjectRef>(), 
args, 1);
-      if (obj.as<ArrayNode>()) {
-        LOG(FATAL) << "ValueError: `get_output` cannot return a tuple for RPC 
compatibility. "
-                      "Please specify another index argument.";
-        return;
-      }
-      *rv = obj;
-    });
-  } else if (name == "set_input") {
-    return PackedFunc(
-        [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { 
SetInput(args[0], args, 1); });
-  } else if (name == "set_input_with_param_module") {
-    return PackedFunc(
-        [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { 
SetInput(args[0], args, 1, true); });
-  } else if (name == "get_function_arity") {
-    return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
-      std::string func_name = args[0];
-      const VMFuncInfo& vm_func = LookupVMFuncInfo(func_name);
-      *rv = static_cast<int>(vm_func.param_names.size());
-    });
-  } else if (name == "get_function_param_name") {
-    return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
-      std::string func_name = args[0];
-      int index = args[1];
-      const VMFuncInfo& vm_func = LookupVMFuncInfo(func_name);
-      if (static_cast<size_t>(index) >= vm_func.param_names.size()) {
-        LOG(FATAL) << "ValueError: Invalid index for " << func_name << " (" << 
index << " out of "
-                   << vm_func.param_names.size() << ")";
-      }
-      *rv = vm_func.param_names[index];
-    });
-  } else {
-    // default case, look up closure in VM.
-    if (Optional<VMClosure> opt = this->GetClosureInternal(name, true)) {
-      auto clo = opt.value();
-      return PackedFunc([sptr_to_self, this, clo](TVMArgs args, TVMRetValue* 
rv) {
-        this->InvokeClosurePacked(clo, args, rv);
-      });
-    } else {
-      return PackedFunc(nullptr);
-    }
-  }
-}
-
-void VirtualMachineImpl::SetInput(std::string func_name, TVMArgs args, int 
offset,
-                                  bool with_param_module) {
+void VirtualMachineImpl::SetInput(std::string func_name, bool 
with_param_module, TVMArgs args) {
   const auto& m = exec_->func_map;
   if (m.find(func_name) != m.end()) {
     Index gf_idx = m.at(func_name);
     const VMFuncInfo& vm_func = exec_->func_table[gf_idx];
     size_t params_num = vm_func.num_args;
-    ICHECK_EQ(args.size() - offset, params_num)
+    ICHECK_EQ(args.size(), params_num)
         << "The number of provided parameters doesn't match the number of 
arguments for";
     std::vector<RegType> func_args(params_num);
-
-    for (int i = offset; i < args.size(); ++i) {
-      int index = i - offset;
+    for (int i = 0; i < args.size(); ++i) {
       if (with_param_module && i == args.size() - 1) {
         // call param func to get the arguments(usually corresponds to param 
pack.)
-        func_args[index] = (args[i].operator 
Module()).GetFunction("get_params")();
+        func_args[i] = (args[i].operator Module()).GetFunction("get_params")();
       } else {
-        func_args[index] = ConvertArgToDevice(args[i], devices[0], 
allocators[0]);
+        func_args[i] = ConvertArgToDevice(args[i], devices[0], allocators[0]);
       }
     }
     inputs_[func_name] = func_args;
@@ -926,6 +824,123 @@ void VirtualMachineImpl::RunLoop() {
 
 ObjectPtr<VirtualMachine> VirtualMachine::Create() { return 
make_object<VirtualMachineImpl>(); }
 
+//--------------------------------------------------------------------
+// FFI related code
+//--------------------------------------------------------------------
+
+void VirtualMachineImpl::_Init(TVMArgs args, TVMRetValue* rv) {
+  ICHECK_EQ(args.size() % 3, 0);
+  std::vector<Device> devices;
+  std::vector<AllocatorType> alloc_types;
+  for (int i = 0; i < args.size(); i += 3) {
+    int device_type = args[i];
+    int device_id = args[i + 1];
+    int alloc_type = args[i + 2];
+    devices.push_back(Device{DLDeviceType(device_type), device_id});
+    alloc_types.push_back(AllocatorType(alloc_type));
+  }
+  this->Init(devices, alloc_types);
+}
+
+void VirtualMachineImpl::_SaveClosure(TVMArgs args, TVMRetValue* rv) {
+  ICHECK_GE(args.size(), 3);
+  std::string func_name = args[0];
+  this->SaveClosure(func_name, args[1], args[2],
+                    TVMArgs(args.values + 3, args.type_codes + 3, args.size() 
- 3));
+}
+
+void VirtualMachineImpl::_InvokeClosure(TVMArgs args, TVMRetValue* rv) {
+  this->InvokeClosurePacked(args[0], TVMArgs(args.values + 1, args.type_codes 
+ 1, args.size() - 1),
+                            rv);
+}
+
+void VirtualMachineImpl::_InvokeClosureStateful(std::string func_name) {
+  const std::unordered_map<std::string, Index>& m = this->exec_->func_map;
+  if (m.find(func_name) == m.end()) {
+    LOG(FATAL) << "ValueError: Unknown function: " << func_name;
+  }
+  if (!inputs_.count(func_name)) {
+    LOG(FATAL) << "ValueError: No inputs set for stateful call of " << 
func_name
+               << "; use `set_input` first.";
+    return;
+  }
+  outputs_[func_name] =
+      this->InvokeClosureInternal(func_pool_[m.at(func_name)], 
inputs_[func_name]);
+}
+
+void VirtualMachineImpl::_SetInstrument(TVMArgs args, TVMRetValue* rv) {
+  if (args[0].type_code() == kTVMPackedFuncHandle) {
+    this->SetInstrument(args[0]);
+  } else {
+    String func_name = args[0];
+    const PackedFunc* factory = Registry::Get(func_name);
+    CHECK(factory) << "Cannot find factory " << func_name;
+    TVMRetValue rv;
+    factory->CallPacked(TVMArgs(args.values + 1, args.type_codes + 1, 
args.num_args - 1), &rv);
+    this->SetInstrument(rv);
+  }
+}
+
+void VirtualMachineImpl::_GetOutputArity(TVMArgs args, TVMRetValue* rv) {
+  std::string func_name = args[0];
+  RegType out = LookupVMOutput(func_name);
+  ObjectRef obj = IndexIntoNestedObject(out.AsObjectRef<ObjectRef>(), args, 1);
+  if (const auto* arr = obj.as<ArrayNode>()) {
+    *rv = static_cast<int>(arr->size());
+  } else {
+    *rv = -1;
+  }
+}
+
+void VirtualMachineImpl::_GetOutput(TVMArgs args, TVMRetValue* rv) {
+  std::string func_name = args[0];
+  RegType out = LookupVMOutput(func_name);
+  ObjectRef obj = IndexIntoNestedObject(out.AsObjectRef<ObjectRef>(), args, 1);
+  if (obj.as<ArrayNode>()) {
+    LOG(FATAL) << "ValueError: `get_output` cannot return a tuple for RPC 
compatibility. "
+                  "Please specify another index argument.";
+    return;
+  }
+  *rv = obj;
+}
+
+void VirtualMachineImpl::_SetInputWithoutParamModule(TVMArgs args, 
TVMRetValue* rv) {
+  std::string func_name = args[0];
+  this->SetInput(func_name, false,
+                 TVMArgs(args.values + 1, args.type_codes + 1, args.num_args - 
1));
+}
+
+void VirtualMachineImpl::_SetInputWithParamModule(TVMArgs args, TVMRetValue* 
rv) {
+  std::string func_name = args[0];
+  this->SetInput(func_name, true, TVMArgs(args.values + 1, args.type_codes + 
1, args.num_args - 1));
+}
+
+int VirtualMachineImpl::_GetFunctionArity(std::string func_name) {
+  const VMFuncInfo& vm_func = LookupVMFuncInfo(func_name);
+  return vm_func.param_names.size();
+}
+
+std::string VirtualMachineImpl::_GetFunctionParamName(std::string func_name, 
int index) {
+  const VMFuncInfo& vm_func = LookupVMFuncInfo(func_name);
+  if (static_cast<size_t>(index) >= vm_func.param_names.size()) {
+    LOG(FATAL) << "ValueError: Invalid index for " << func_name << " (" << 
index << " out of "
+               << vm_func.param_names.size() << ")";
+  }
+  return vm_func.param_names[index];
+}
+
+PackedFunc VirtualMachineImpl::_LookupFunction(const String& name) {
+  if (Optional<VMClosure> opt = this->GetClosureInternal(name, true)) {
+    return PackedFunc(
+        [clo = opt.value(), _self = GetRef<Module>(this)](TVMArgs args, 
TVMRetValue* rv) -> void {
+          auto* self = 
const_cast<VirtualMachineImpl*>(_self.as<VirtualMachineImpl>());
+          ICHECK(self);
+          self->InvokeClosurePacked(clo, args, rv);
+        });
+  }
+  return PackedFunc(nullptr);
+}
+
 //----------------------------------------------------------------
 // Profiler can be optionally disabled via a macro to reduce dep.
 //----------------------------------------------------------------
@@ -958,7 +973,7 @@ class VirtualMachineProfiler : public VirtualMachineImpl {
         if (inputs.size() == 0) {
           ICHECK(args.num_args > 1) << "No input is provided";
           TVMArgs f_args(args.values + 1, args.type_codes + 1, args.num_args - 
1);
-          SetInput(f_name, args, 1);
+          SetInput(f_name, false, TVMArgs(args.values + 1, args.type_codes + 
1, args.num_args - 1));
           inputs = GetInputsFor(f_name);
           clear_inputs = true;
         } else {

Reply via email to