mulanxiaodingdang commented on issue #18146:
URL: https://github.com/apache/tvm/issues/18146#issuecomment-3069378755

   The following code appears to do the job (the output it produced is not included in this message):
   ```
   #include <cstdint>
   #include <iostream>
   #include <optional>
   #include <string>

   #include <tvm/runtime/relax_vm/executable.h>
   #include <tvm/runtime/logging.h>
   #include <tvm/runtime/memory/memory_manager.h>
   #include <tvm/runtime/data_type.h>
   #include <tvm/ffi/function.h>
   //#include <tvm/runtime/device_api.h>
   
   using tvm::runtime::relax_vm::VMExecutable;
   using tvm::runtime::Module;
   using tvm::runtime::NDArray;
   using tvm::Device;
   using tvm::runtime::DataType;
   using tvm::runtime::memory::AllocatorType;
   
   int main() {
       using namespace tvm;
       using namespace tvm::runtime;
   
       std::string path = "./compiled_artifact.so";
   
       // Step 1: 加载模块
       Module mod = Module::LoadFromFile(path);
       std::cout << mod << std::endl;
   
       // Step 2: 获取 vm_load_executable 并调用,返回 VM 实例
       tvm::ffi::Function vm_load_executable = 
mod.GetFunction("vm_load_executable");
       CHECK(vm_load_executable.defined()) << "vm_load_executable not found in 
module.";
   
       tvm::ffi::Any re = vm_load_executable();
       std::optional<Module> maybe_vm = re.as<Module>();
       if (!maybe_vm.has_value()) {
           std::cerr << "Returned value from vm_load_executable is not a 
Module" << std::endl;
           return -1;
       }
       Module vm = maybe_vm.value();
   
       // Step 3: 初始化虚拟机(关键步骤)
       Device dev{kDLCPU, 0};
       tvm::ffi::Function vm_init = vm.GetFunction("vm_initialization");
       CHECK(vm_init.defined()) << "vm_initialization not found in VM";
       vm_init(static_cast<int>(dev.device_type), 
static_cast<int>(dev.device_id),
               static_cast<int>(AllocatorType::kPooled),
               static_cast<int>(dev.device_type), 
static_cast<int>(dev.device_id),
               static_cast<int>(AllocatorType::kPooled));
   
       // Step 4: 获取 main 函数并运行
       tvm::ffi::Function vm_main = vm.GetFunction("main");
       CHECK(vm_main.defined()) << "main function not found";
   
       NDArray input = NDArray::Empty({3, 3}, DataType::Int(32), dev);
       int* in_data = static_cast<int*>(input->data);
       for (int i = 0; i < 9; ++i) in_data[i] = 42;
   
       tvm::ffi::Any ret = vm_main(input);
       std::optional<NDArray> output_opt = ret.as<NDArray>();
       if (!output_opt.has_value()) {
           std::cerr << "Error: output is not NDArray." << std::endl;
           return -1;
       }
       NDArray output = output_opt.value();
   
       int* out_data = static_cast<int*>(output->data);
       for (int i = 0; i < 9; ++i) std::cout << out_data[i] << " ";
       std::cout << std::endl;
   
       return 0;
   }
   ```
   
   


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to