This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main-mod
in repository https://gitbox.apache.org/repos/asf/tvm.git

commit 0fc6fa82e8cc24887c2a0667c8aef3beb1a96704
Author: tqchen <[email protected]>
AuthorDate: Tue Aug 12 12:21:27 2025 -0400

    [FFI][REFACTOR] Clean up entry function to redirect
    
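    This PR updates the entry function mechanism to create a stub function that
    redirects to the real entry function. Previously, the entry function's name
    was stored in a global string that the runtime had to look up and resolve.
    This new behavior simplifies the runtime logic that supports the entry function.
    It also renames the entry symbol from __tvm_main__ to __tvm_ffi_main__.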
---
 include/tvm/runtime/module.h                       |  4 +-
 jvm/core/src/main/java/org/apache/tvm/Module.java  |  2 +-
 python/tvm/runtime/module.py                       |  4 +-
 src/runtime/cuda/cuda_module.cc                    |  1 -
 src/runtime/library_module.cc                      | 10 +----
 src/runtime/metal/metal_module.mm                  |  1 -
 src/runtime/opencl/opencl_module.cc                |  1 -
 src/runtime/rocm/rocm_module.cc                    |  1 -
 src/runtime/vulkan/vulkan_wrapped_func.cc          |  1 -
 src/target/llvm/codegen_cpu.cc                     | 46 +++++++++++++---------
 src/target/llvm/llvm_module.cc                     | 10 +----
 src/target/source/codegen_c_host.cc                |  4 +-
 .../test_hexagon/test_async_dma_pipeline.py        |  4 +-
 .../contrib/test_hexagon/test_parallel_hvx.py      |  2 +-
 .../test_hexagon/test_parallel_hvx_load_vtcm.py    |  4 +-
 .../contrib/test_hexagon/test_parallel_scalar.py   |  2 +-
 .../contrib/test_hexagon/test_vtcm_bandwidth.py    |  4 +-
 17 files changed, 45 insertions(+), 56 deletions(-)

diff --git a/include/tvm/runtime/module.h b/include/tvm/runtime/module.h
index efbaa6508a..80c03ea751 100644
--- a/include/tvm/runtime/module.h
+++ b/include/tvm/runtime/module.h
@@ -290,14 +290,14 @@ namespace symbol {
 constexpr const char* tvm_ffi_library_ctx = "__tvm_ffi_library_ctx";
 /*! \brief Global variable to store binary data alongside a library module. */
 constexpr const char* tvm_ffi_library_bin = "__tvm_ffi_library_bin";
+/*! \brief Placeholder for the module's entry function. */
+constexpr const char* tvm_ffi_main = "__tvm_ffi_main__";
 /*! \brief global function to set device */
 constexpr const char* tvm_set_device = "__tvm_set_device";
 /*! \brief Auxiliary counter to global barrier. */
 constexpr const char* tvm_global_barrier_state = "__tvm_global_barrier_state";
 /*! \brief Prepare the global barrier before kernels that uses global barrier. 
*/
 constexpr const char* tvm_prepare_global_barrier = 
"__tvm_prepare_global_barrier";
-/*! \brief Placeholder for the module's entry function. */
-constexpr const char* tvm_module_main = "__tvm_main__";
 }  // namespace symbol
 
 // implementations of inline functions.
diff --git a/jvm/core/src/main/java/org/apache/tvm/Module.java 
b/jvm/core/src/main/java/org/apache/tvm/Module.java
index 5e78e26ae7..9fa65054f9 100644
--- a/jvm/core/src/main/java/org/apache/tvm/Module.java
+++ b/jvm/core/src/main/java/org/apache/tvm/Module.java
@@ -46,7 +46,7 @@ public class Module extends TVMObject {
   }
 
   private Function entry = null;
-  private final String entryName = "__tvm_main__";
+  private final String entryName = "__tvm_ffi_main__";
 
 
   /**
diff --git a/python/tvm/runtime/module.py b/python/tvm/runtime/module.py
index 3dd4de5da0..e645d3a2b6 100644
--- a/python/tvm/runtime/module.py
+++ b/python/tvm/runtime/module.py
@@ -103,7 +103,7 @@ class Module(tvm.ffi.Object):
 
     def __new__(cls):
         instance = super(Module, cls).__new__(cls)  # pylint: 
disable=no-value-for-parameter
-        instance.entry_name = "__tvm_main__"
+        instance.entry_name = "__tvm_ffi_main__"
         instance._entry = None
         return instance
 
@@ -118,7 +118,7 @@ class Module(tvm.ffi.Object):
         """
         if self._entry:
             return self._entry
-        self._entry = self.get_function("__tvm_main__")
+        self._entry = self.get_function("__tvm_ffi_main__")
         return self._entry
 
     def implements_function(self, name, query_imports=False):
diff --git a/src/runtime/cuda/cuda_module.cc b/src/runtime/cuda/cuda_module.cc
index 2435cccf0a..6b71df928d 100644
--- a/src/runtime/cuda/cuda_module.cc
+++ b/src/runtime/cuda/cuda_module.cc
@@ -258,7 +258,6 @@ class CUDAPrepGlobalBarrier {
 ffi::Function CUDAModuleNode::GetFunction(const String& name,
                                           const ObjectPtr<Object>& 
sptr_to_self) {
   ICHECK_EQ(sptr_to_self.get(), this);
-  ICHECK_NE(name, symbol::tvm_module_main) << "Device function do not have 
main";
   if (name == symbol::tvm_prepare_global_barrier) {
     return ffi::Function(CUDAPrepGlobalBarrier(this, sptr_to_self));
   }
diff --git a/src/runtime/library_module.cc b/src/runtime/library_module.cc
index fffac4adea..24fc7518d6 100644
--- a/src/runtime/library_module.cc
+++ b/src/runtime/library_module.cc
@@ -50,15 +50,7 @@ class LibraryModuleNode final : public ModuleNode {
 
   ffi::Function GetFunction(const String& name, const ObjectPtr<Object>& 
sptr_to_self) final {
     TVMFFISafeCallType faddr;
-    if (name == runtime::symbol::tvm_module_main) {
-      const char* entry_name =
-          reinterpret_cast<const 
char*>(lib_->GetSymbol(runtime::symbol::tvm_module_main));
-      ICHECK(entry_name != nullptr)
-          << "Symbol " << runtime::symbol::tvm_module_main << " is not 
presented";
-      faddr = 
reinterpret_cast<TVMFFISafeCallType>(lib_->GetSymbol(entry_name));
-    } else {
-      faddr = 
reinterpret_cast<TVMFFISafeCallType>(lib_->GetSymbol(name.c_str()));
-    }
+    faddr = 
reinterpret_cast<TVMFFISafeCallType>(lib_->GetSymbol(name.c_str()));
     if (faddr == nullptr) return ffi::Function();
     return packed_func_wrapper_(faddr, sptr_to_self);
   }
diff --git a/src/runtime/metal/metal_module.mm 
b/src/runtime/metal/metal_module.mm
index be36e6197f..33bb1705c8 100644
--- a/src/runtime/metal/metal_module.mm
+++ b/src/runtime/metal/metal_module.mm
@@ -264,7 +264,6 @@ ffi::Function MetalModuleNode::GetFunction(const String& 
name,
   ffi::Function ret;
   AUTORELEASEPOOL {
     ICHECK_EQ(sptr_to_self.get(), this);
-    ICHECK_NE(name, symbol::tvm_module_main) << "Device function do not have 
main";
     auto it = fmap_.find(name);
     if (it == fmap_.end()) {
       ret = ffi::Function();
diff --git a/src/runtime/opencl/opencl_module.cc 
b/src/runtime/opencl/opencl_module.cc
index 19b426d4b4..1c61eeb596 100644
--- a/src/runtime/opencl/opencl_module.cc
+++ b/src/runtime/opencl/opencl_module.cc
@@ -138,7 +138,6 @@ cl::OpenCLWorkspace* 
OpenCLModuleNodeBase::GetGlobalWorkspace() {
 ffi::Function OpenCLModuleNodeBase::GetFunction(const String& name,
                                                 const ObjectPtr<Object>& 
sptr_to_self) {
   ICHECK_EQ(sptr_to_self.get(), this);
-  ICHECK_NE(name, symbol::tvm_module_main) << "Device function do not have 
main";
   auto it = fmap_.find(name);
   if (it == fmap_.end()) return ffi::Function();
   const FunctionInfo& info = it->second;
diff --git a/src/runtime/rocm/rocm_module.cc b/src/runtime/rocm/rocm_module.cc
index 791e4b1569..a871a41f0f 100644
--- a/src/runtime/rocm/rocm_module.cc
+++ b/src/runtime/rocm/rocm_module.cc
@@ -195,7 +195,6 @@ class ROCMWrappedFunc {
 ffi::Function ROCMModuleNode::GetFunction(const String& name,
                                           const ObjectPtr<Object>& 
sptr_to_self) {
   ICHECK_EQ(sptr_to_self.get(), this);
-  ICHECK_NE(name, symbol::tvm_module_main) << "Device function do not have 
main";
   auto it = fmap_.find(name);
   if (it == fmap_.end()) return ffi::Function();
   const FunctionInfo& info = it->second;
diff --git a/src/runtime/vulkan/vulkan_wrapped_func.cc 
b/src/runtime/vulkan/vulkan_wrapped_func.cc
index f4922a1bf0..db81c959dc 100644
--- a/src/runtime/vulkan/vulkan_wrapped_func.cc
+++ b/src/runtime/vulkan/vulkan_wrapped_func.cc
@@ -208,7 +208,6 @@ VulkanModuleNode::~VulkanModuleNode() {
 ffi::Function VulkanModuleNode::GetFunction(const String& name,
                                             const ObjectPtr<Object>& 
sptr_to_self) {
   ICHECK_EQ(sptr_to_self.get(), this);
-  ICHECK_NE(name, symbol::tvm_module_main) << "Device function do not have 
main";
   auto it = fmap_.find(name);
   if (it == fmap_.end()) return ffi::Function();
   const FunctionInfo& info = it->second;
diff --git a/src/target/llvm/codegen_cpu.cc b/src/target/llvm/codegen_cpu.cc
index 4dd24026c0..69862ced0f 100644
--- a/src/target/llvm/codegen_cpu.cc
+++ b/src/target/llvm/codegen_cpu.cc
@@ -229,28 +229,38 @@ void CodeGenCPU::AddFunction(const GlobalVar& gvar, const 
PrimFunc& func) {
 }
 
 void CodeGenCPU::AddMainFunction(const std::string& entry_func_name) {
-  llvm::Function* f = module_->getFunction(entry_func_name);
-  ICHECK(f) << "Function " << entry_func_name << "does not in module";
-  llvm::Type* type = llvm::ArrayType::get(t_char_, entry_func_name.length() + 
1);
-  llvm::GlobalVariable* global =
-      new llvm::GlobalVariable(*module_, type, true, 
llvm::GlobalValue::WeakAnyLinkage, nullptr,
-                               runtime::symbol::tvm_module_main);
-#if TVM_LLVM_VERSION >= 100
-  global->setAlignment(llvm::Align(1));
-#else
-  global->setAlignment(1);
-#endif
-  // comdat is needed for windows select any linking to work
-  // set comdat to Any(weak linking)
+  // create a wrapper function with tvm_ffi_main name and redirects to the 
entry function
+  llvm::Function* target_func = module_->getFunction(entry_func_name);
+  ICHECK(target_func) << "Function " << entry_func_name << " does not exist in 
module";
+
+  // Create wrapper function
+  llvm::Function* wrapper_func =
+      llvm::Function::Create(ftype_tvm_ffi_c_func_, 
llvm::Function::WeakAnyLinkage,
+                             runtime::symbol::tvm_ffi_main, module_.get());
+
+  // Set attributes (Windows comdat, DLL export, etc.)
   if 
(llvm_target_->GetOrCreateTargetMachine()->getTargetTriple().isOSWindows()) {
-    llvm::Comdat* comdat = 
module_->getOrInsertComdat(runtime::symbol::tvm_module_main);
+    llvm::Comdat* comdat = 
module_->getOrInsertComdat(runtime::symbol::tvm_ffi_main);
     comdat->setSelectionKind(llvm::Comdat::Any);
-    global->setComdat(comdat);
+    wrapper_func->setComdat(comdat);
+  }
+
+  wrapper_func->setCallingConv(llvm::CallingConv::C);
+  
wrapper_func->setDLLStorageClass(llvm::GlobalValue::DLLStorageClassTypes::DLLExportStorageClass);
+
+  // Create simple tail call
+  llvm::BasicBlock* entry =
+      llvm::BasicBlock::Create(*llvm_target_->GetContext(), "entry", 
wrapper_func);
+  builder_->SetInsertPoint(entry);
+
+  // Forward all arguments to target function
+  std::vector<llvm::Value*> call_args;
+  for (llvm::Value& arg : wrapper_func->args()) {
+    call_args.push_back(&arg);
   }
 
-  global->setInitializer(
-      llvm::ConstantDataArray::getString(*llvm_target_->GetContext(), 
entry_func_name));
-  global->setDLLStorageClass(llvm::GlobalVariable::DLLExportStorageClass);
+  llvm::Value* result = builder_->CreateCall(target_func, call_args);
+  builder_->CreateRet(result);
 }
 
 std::unique_ptr<llvm::Module> CodeGenCPU::Finish() {
diff --git a/src/target/llvm/llvm_module.cc b/src/target/llvm/llvm_module.cc
index a9e09652ee..2daf941edf 100644
--- a/src/target/llvm/llvm_module.cc
+++ b/src/target/llvm/llvm_module.cc
@@ -190,15 +190,7 @@ ffi::Function LLVMModuleNode::GetFunction(const String& 
name,
 
   TVMFFISafeCallType faddr;
   With<LLVMTarget> llvm_target(*llvm_instance_, 
LLVMTarget::GetTargetMetadata(*module_));
-  if (name == runtime::symbol::tvm_module_main) {
-    const char* entry_name = reinterpret_cast<const char*>(
-        GetGlobalAddr(runtime::symbol::tvm_module_main, *llvm_target));
-    ICHECK(entry_name != nullptr) << "Symbol " << 
runtime::symbol::tvm_module_main
-                                  << " is not presented";
-    faddr = reinterpret_cast<TVMFFISafeCallType>(GetFunctionAddr(entry_name, 
*llvm_target));
-  } else {
-    faddr = reinterpret_cast<TVMFFISafeCallType>(GetFunctionAddr(name, 
*llvm_target));
-  }
+  faddr = reinterpret_cast<TVMFFISafeCallType>(GetFunctionAddr(name, 
*llvm_target));
   if (faddr == nullptr) return ffi::Function();
   return tvm::runtime::WrapFFIFunction(faddr, sptr_to_self);
 }
diff --git a/src/target/source/codegen_c_host.cc 
b/src/target/source/codegen_c_host.cc
index 6cd12a9319..020054b3e1 100644
--- a/src/target/source/codegen_c_host.cc
+++ b/src/target/source/codegen_c_host.cc
@@ -77,11 +77,11 @@ void CodeGenCHost::AddFunction(const GlobalVar& gvar, const 
PrimFunc& func,
         << "CodeGenCHost: The entry func must have the global_symbol 
attribute, "
         << "but function " << gvar << " only has attributes " << func->attrs;
 
-    function_names_.push_back(runtime::symbol::tvm_module_main);
+    function_names_.push_back(runtime::symbol::tvm_ffi_main);
     stream << "// CodegenC: NOTE: Auto-generated entry function\n";
     PrintFuncPrefix(stream);
     PrintType(func->ret_type, stream);
-    stream << " " << tvm::runtime::symbol::tvm_module_main
+    stream << " " << tvm::runtime::symbol::tvm_ffi_main
            << "(void* self, void* args,int num_args, void* result) {\n";
     stream << "  return " << global_symbol.value() << "(self, args, num_args, 
result);\n";
     stream << "}\n";
diff --git a/tests/python/contrib/test_hexagon/test_async_dma_pipeline.py 
b/tests/python/contrib/test_hexagon/test_async_dma_pipeline.py
index fe7a615531..52a7ffbc24 100644
--- a/tests/python/contrib/test_hexagon/test_async_dma_pipeline.py
+++ b/tests/python/contrib/test_hexagon/test_async_dma_pipeline.py
@@ -289,9 +289,9 @@ def evaluate(
 
     if tvm.testing.utils.IS_IN_CI:
         # Run with reduced number and repeat for CI
-        timer = module.time_evaluator("__tvm_main__", hexagon_session.device, 
number=1, repeat=1)
+        timer = module.time_evaluator("__tvm_ffi_main__", 
hexagon_session.device, number=1, repeat=1)
     else:
-        timer = module.time_evaluator("__tvm_main__", hexagon_session.device, 
number=10, repeat=10)
+        timer = module.time_evaluator("__tvm_ffi_main__", 
hexagon_session.device, number=10, repeat=10)
 
     time = timer(a_hexagon, b_hexagon, c_hexagon)
     if expected_output is not None:
diff --git a/tests/python/contrib/test_hexagon/test_parallel_hvx.py 
b/tests/python/contrib/test_hexagon/test_parallel_hvx.py
index 8f77fa1c40..6822352568 100644
--- a/tests/python/contrib/test_hexagon/test_parallel_hvx.py
+++ b/tests/python/contrib/test_hexagon/test_parallel_hvx.py
@@ -160,7 +160,7 @@ def evaluate(hexagon_session, shape_dtypes, 
expected_output_producer, sch):
     repeat = 1
 
     timer = module.time_evaluator(
-        "__tvm_main__", hexagon_session.device, number=number, repeat=repeat
+        "__tvm_ffi_main__", hexagon_session.device, number=number, 
repeat=repeat
     )
     runtime = timer(a_hexagon, b_hexagon, c_hexagon)
     tvm.testing.assert_allclose(c_hexagon.numpy(), 
expected_output_producer(c_shape, a, b))
diff --git a/tests/python/contrib/test_hexagon/test_parallel_hvx_load_vtcm.py 
b/tests/python/contrib/test_hexagon/test_parallel_hvx_load_vtcm.py
index a584997dd5..17e31af0a7 100644
--- a/tests/python/contrib/test_hexagon/test_parallel_hvx_load_vtcm.py
+++ b/tests/python/contrib/test_hexagon/test_parallel_hvx_load_vtcm.py
@@ -331,7 +331,7 @@ def setup_and_run(hexagon_session, sch, a, b, c, 
operations, mem_scope="global")
     repeat = 1
 
     timer = module.time_evaluator(
-        "__tvm_main__", hexagon_session.device, number=number, repeat=repeat
+        "__tvm_ffi_main__", hexagon_session.device, number=number, 
repeat=repeat
     )
     time = timer(a_hexagon, b_hexagon, c_hexagon)
     gops = round(operations * 128 * 3 / time.mean / 1e9, 4)
@@ -365,7 +365,7 @@ def setup_and_run_preallocated(hexagon_session, sch, a, b, 
c, operations):
     repeat = 1
 
     timer = module.time_evaluator(
-        "__tvm_main__", hexagon_session.device, number=number, repeat=repeat
+        "__tvm_ffi_main__", hexagon_session.device, number=number, 
repeat=repeat
     )
     time = timer(a_hexagon, b_hexagon, c_hexagon, a_vtcm_hexagon, 
b_vtcm_hexagon, c_vtcm_hexagon)
     gops = round(operations * 128 * 3 / time.mean / 1e9, 4)
diff --git a/tests/python/contrib/test_hexagon/test_parallel_scalar.py 
b/tests/python/contrib/test_hexagon/test_parallel_scalar.py
index bd9c78d5da..dd765178dc 100644
--- a/tests/python/contrib/test_hexagon/test_parallel_scalar.py
+++ b/tests/python/contrib/test_hexagon/test_parallel_scalar.py
@@ -105,7 +105,7 @@ def evaluate(hexagon_session, operations, expected, sch):
     repeat = 1
 
     timer = module.time_evaluator(
-        "__tvm_main__", hexagon_session.device, number=number, repeat=repeat
+        "__tvm_ffi_main__", hexagon_session.device, number=number, 
repeat=repeat
     )
     runtime = timer(a_hexagon, b_hexagon, c_hexagon)
 
diff --git a/tests/python/contrib/test_hexagon/test_vtcm_bandwidth.py 
b/tests/python/contrib/test_hexagon/test_vtcm_bandwidth.py
index 931f99b2ec..551f441357 100644
--- a/tests/python/contrib/test_hexagon/test_vtcm_bandwidth.py
+++ b/tests/python/contrib/test_hexagon/test_vtcm_bandwidth.py
@@ -108,9 +108,9 @@ def evaluate(hexagon_session, sch, size):
 
     if tvm.testing.utils.IS_IN_CI:
         # Run with reduced number and repeat for CI
-        timer = module.time_evaluator("__tvm_main__", hexagon_session.device, 
number=1, repeat=1)
+        timer = module.time_evaluator("__tvm_ffi_main__", 
hexagon_session.device, number=1, repeat=1)
     else:
-        timer = module.time_evaluator("__tvm_main__", hexagon_session.device, 
number=10, repeat=10)
+        timer = module.time_evaluator("__tvm_ffi_main__", 
hexagon_session.device, number=10, repeat=10)
 
     runtime = timer(a_hexagon, a_vtcm_hexagon)
 

Reply via email to