This is an automated email from the ASF dual-hosted git repository.
ruihangl pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new dbbcf904b0 [FFI][REFACTOR] Cleanup entry function to redirect (#18205)
dbbcf904b0 is described below
commit dbbcf904b0ed5604b1f582ddb6e05443d8926247
Author: Tianqi Chen <[email protected]>
AuthorDate: Tue Aug 12 17:25:14 2025 -0400
[FFI][REFACTOR] Cleanup entry function to redirect (#18205)
This PR updates the entry function mechanism to create a stub that
redirects to the real function.
This new behavior helps to simplify the runtime logic supporting entry
function.
Also updates the name to `__tvm_ffi_main__`
---
include/tvm/runtime/module.h | 4 +-
jvm/core/src/main/java/org/apache/tvm/Module.java | 2 +-
python/tvm/runtime/module.py | 4 +-
src/runtime/cuda/cuda_module.cc | 1 -
src/runtime/library_module.cc | 10 +----
src/runtime/metal/metal_module.mm | 1 -
src/runtime/opencl/opencl_module.cc | 1 -
src/runtime/rocm/rocm_module.cc | 1 -
src/runtime/vulkan/vulkan_wrapped_func.cc | 1 -
src/target/llvm/codegen_cpu.cc | 50 ++++++++++++++--------
src/target/llvm/llvm_module.cc | 10 +----
src/target/source/codegen_c_host.cc | 4 +-
.../test_hexagon/test_async_dma_pipeline.py | 8 +++-
.../contrib/test_hexagon/test_parallel_hvx.py | 2 +-
.../test_hexagon/test_parallel_hvx_load_vtcm.py | 4 +-
.../contrib/test_hexagon/test_parallel_scalar.py | 2 +-
.../contrib/test_hexagon/test_vtcm_bandwidth.py | 8 +++-
17 files changed, 57 insertions(+), 56 deletions(-)
diff --git a/include/tvm/runtime/module.h b/include/tvm/runtime/module.h
index efbaa6508a..80c03ea751 100644
--- a/include/tvm/runtime/module.h
+++ b/include/tvm/runtime/module.h
@@ -290,14 +290,14 @@ namespace symbol {
constexpr const char* tvm_ffi_library_ctx = "__tvm_ffi_library_ctx";
/*! \brief Global variable to store binary data alongside a library module. */
constexpr const char* tvm_ffi_library_bin = "__tvm_ffi_library_bin";
+/*! \brief Placeholder for the module's entry function. */
+constexpr const char* tvm_ffi_main = "__tvm_ffi_main__";
/*! \brief global function to set device */
constexpr const char* tvm_set_device = "__tvm_set_device";
/*! \brief Auxiliary counter to global barrier. */
constexpr const char* tvm_global_barrier_state = "__tvm_global_barrier_state";
/*! \brief Prepare the global barrier before kernels that uses global barrier.
*/
constexpr const char* tvm_prepare_global_barrier =
"__tvm_prepare_global_barrier";
-/*! \brief Placeholder for the module's entry function. */
-constexpr const char* tvm_module_main = "__tvm_main__";
} // namespace symbol
// implementations of inline functions.
diff --git a/jvm/core/src/main/java/org/apache/tvm/Module.java
b/jvm/core/src/main/java/org/apache/tvm/Module.java
index 5e78e26ae7..9fa65054f9 100644
--- a/jvm/core/src/main/java/org/apache/tvm/Module.java
+++ b/jvm/core/src/main/java/org/apache/tvm/Module.java
@@ -46,7 +46,7 @@ public class Module extends TVMObject {
}
private Function entry = null;
- private final String entryName = "__tvm_main__";
+ private final String entryName = "__tvm_ffi_main__";
/**
diff --git a/python/tvm/runtime/module.py b/python/tvm/runtime/module.py
index 3dd4de5da0..e645d3a2b6 100644
--- a/python/tvm/runtime/module.py
+++ b/python/tvm/runtime/module.py
@@ -103,7 +103,7 @@ class Module(tvm.ffi.Object):
def __new__(cls):
instance = super(Module, cls).__new__(cls) # pylint:
disable=no-value-for-parameter
- instance.entry_name = "__tvm_main__"
+ instance.entry_name = "__tvm_ffi_main__"
instance._entry = None
return instance
@@ -118,7 +118,7 @@ class Module(tvm.ffi.Object):
"""
if self._entry:
return self._entry
- self._entry = self.get_function("__tvm_main__")
+ self._entry = self.get_function("__tvm_ffi_main__")
return self._entry
def implements_function(self, name, query_imports=False):
diff --git a/src/runtime/cuda/cuda_module.cc b/src/runtime/cuda/cuda_module.cc
index 2435cccf0a..6b71df928d 100644
--- a/src/runtime/cuda/cuda_module.cc
+++ b/src/runtime/cuda/cuda_module.cc
@@ -258,7 +258,6 @@ class CUDAPrepGlobalBarrier {
ffi::Function CUDAModuleNode::GetFunction(const String& name,
const ObjectPtr<Object>&
sptr_to_self) {
ICHECK_EQ(sptr_to_self.get(), this);
- ICHECK_NE(name, symbol::tvm_module_main) << "Device function do not have
main";
if (name == symbol::tvm_prepare_global_barrier) {
return ffi::Function(CUDAPrepGlobalBarrier(this, sptr_to_self));
}
diff --git a/src/runtime/library_module.cc b/src/runtime/library_module.cc
index fffac4adea..24fc7518d6 100644
--- a/src/runtime/library_module.cc
+++ b/src/runtime/library_module.cc
@@ -50,15 +50,7 @@ class LibraryModuleNode final : public ModuleNode {
ffi::Function GetFunction(const String& name, const ObjectPtr<Object>&
sptr_to_self) final {
TVMFFISafeCallType faddr;
- if (name == runtime::symbol::tvm_module_main) {
- const char* entry_name =
- reinterpret_cast<const
char*>(lib_->GetSymbol(runtime::symbol::tvm_module_main));
- ICHECK(entry_name != nullptr)
- << "Symbol " << runtime::symbol::tvm_module_main << " is not
presented";
- faddr =
reinterpret_cast<TVMFFISafeCallType>(lib_->GetSymbol(entry_name));
- } else {
- faddr =
reinterpret_cast<TVMFFISafeCallType>(lib_->GetSymbol(name.c_str()));
- }
+ faddr =
reinterpret_cast<TVMFFISafeCallType>(lib_->GetSymbol(name.c_str()));
if (faddr == nullptr) return ffi::Function();
return packed_func_wrapper_(faddr, sptr_to_self);
}
diff --git a/src/runtime/metal/metal_module.mm
b/src/runtime/metal/metal_module.mm
index be36e6197f..33bb1705c8 100644
--- a/src/runtime/metal/metal_module.mm
+++ b/src/runtime/metal/metal_module.mm
@@ -264,7 +264,6 @@ ffi::Function MetalModuleNode::GetFunction(const String&
name,
ffi::Function ret;
AUTORELEASEPOOL {
ICHECK_EQ(sptr_to_self.get(), this);
- ICHECK_NE(name, symbol::tvm_module_main) << "Device function do not have
main";
auto it = fmap_.find(name);
if (it == fmap_.end()) {
ret = ffi::Function();
diff --git a/src/runtime/opencl/opencl_module.cc
b/src/runtime/opencl/opencl_module.cc
index 19b426d4b4..1c61eeb596 100644
--- a/src/runtime/opencl/opencl_module.cc
+++ b/src/runtime/opencl/opencl_module.cc
@@ -138,7 +138,6 @@ cl::OpenCLWorkspace*
OpenCLModuleNodeBase::GetGlobalWorkspace() {
ffi::Function OpenCLModuleNodeBase::GetFunction(const String& name,
const ObjectPtr<Object>&
sptr_to_self) {
ICHECK_EQ(sptr_to_self.get(), this);
- ICHECK_NE(name, symbol::tvm_module_main) << "Device function do not have
main";
auto it = fmap_.find(name);
if (it == fmap_.end()) return ffi::Function();
const FunctionInfo& info = it->second;
diff --git a/src/runtime/rocm/rocm_module.cc b/src/runtime/rocm/rocm_module.cc
index 791e4b1569..a871a41f0f 100644
--- a/src/runtime/rocm/rocm_module.cc
+++ b/src/runtime/rocm/rocm_module.cc
@@ -195,7 +195,6 @@ class ROCMWrappedFunc {
ffi::Function ROCMModuleNode::GetFunction(const String& name,
const ObjectPtr<Object>&
sptr_to_self) {
ICHECK_EQ(sptr_to_self.get(), this);
- ICHECK_NE(name, symbol::tvm_module_main) << "Device function do not have
main";
auto it = fmap_.find(name);
if (it == fmap_.end()) return ffi::Function();
const FunctionInfo& info = it->second;
diff --git a/src/runtime/vulkan/vulkan_wrapped_func.cc
b/src/runtime/vulkan/vulkan_wrapped_func.cc
index f4922a1bf0..db81c959dc 100644
--- a/src/runtime/vulkan/vulkan_wrapped_func.cc
+++ b/src/runtime/vulkan/vulkan_wrapped_func.cc
@@ -208,7 +208,6 @@ VulkanModuleNode::~VulkanModuleNode() {
ffi::Function VulkanModuleNode::GetFunction(const String& name,
const ObjectPtr<Object>&
sptr_to_self) {
ICHECK_EQ(sptr_to_self.get(), this);
- ICHECK_NE(name, symbol::tvm_module_main) << "Device function do not have
main";
auto it = fmap_.find(name);
if (it == fmap_.end()) return ffi::Function();
const FunctionInfo& info = it->second;
diff --git a/src/target/llvm/codegen_cpu.cc b/src/target/llvm/codegen_cpu.cc
index 4dd24026c0..6271d4edbe 100644
--- a/src/target/llvm/codegen_cpu.cc
+++ b/src/target/llvm/codegen_cpu.cc
@@ -229,28 +229,42 @@ void CodeGenCPU::AddFunction(const GlobalVar& gvar, const
PrimFunc& func) {
}
void CodeGenCPU::AddMainFunction(const std::string& entry_func_name) {
- llvm::Function* f = module_->getFunction(entry_func_name);
- ICHECK(f) << "Function " << entry_func_name << "does not in module";
- llvm::Type* type = llvm::ArrayType::get(t_char_, entry_func_name.length() +
1);
- llvm::GlobalVariable* global =
- new llvm::GlobalVariable(*module_, type, true,
llvm::GlobalValue::WeakAnyLinkage, nullptr,
- runtime::symbol::tvm_module_main);
-#if TVM_LLVM_VERSION >= 100
- global->setAlignment(llvm::Align(1));
-#else
- global->setAlignment(1);
-#endif
- // comdat is needed for windows select any linking to work
- // set comdat to Any(weak linking)
+ // create a wrapper function with tvm_ffi_main name and redirects to the
entry function
+ llvm::Function* target_func = module_->getFunction(entry_func_name);
+ ICHECK(target_func) << "Function " << entry_func_name << " does not exist in
module";
+
+ // Create wrapper function
+ llvm::Function* wrapper_func =
+ llvm::Function::Create(target_func->getFunctionType(),
llvm::Function::WeakAnyLinkage,
+ runtime::symbol::tvm_ffi_main, module_.get());
+
+ // Set attributes (Windows comdat, DLL export, etc.)
if
(llvm_target_->GetOrCreateTargetMachine()->getTargetTriple().isOSWindows()) {
- llvm::Comdat* comdat =
module_->getOrInsertComdat(runtime::symbol::tvm_module_main);
+ llvm::Comdat* comdat =
module_->getOrInsertComdat(runtime::symbol::tvm_ffi_main);
comdat->setSelectionKind(llvm::Comdat::Any);
- global->setComdat(comdat);
+ wrapper_func->setComdat(comdat);
}
- global->setInitializer(
- llvm::ConstantDataArray::getString(*llvm_target_->GetContext(),
entry_func_name));
- global->setDLLStorageClass(llvm::GlobalVariable::DLLExportStorageClass);
+ wrapper_func->setCallingConv(llvm::CallingConv::C);
+
wrapper_func->setDLLStorageClass(llvm::GlobalValue::DLLStorageClassTypes::DLLExportStorageClass);
+
+ // Create simple tail call
+ llvm::BasicBlock* entry =
+ llvm::BasicBlock::Create(*llvm_target_->GetContext(), "entry",
wrapper_func);
+ builder_->SetInsertPoint(entry);
+
+ // Forward all arguments to target function
+ std::vector<llvm::Value*> call_args;
+ for (llvm::Value& arg : wrapper_func->args()) {
+ call_args.push_back(&arg);
+ }
+
+ llvm::Value* result = builder_->CreateCall(target_func, call_args);
+ if (target_func->getReturnType()->isVoidTy()) {
+ builder_->CreateRetVoid();
+ } else {
+ builder_->CreateRet(result);
+ }
}
std::unique_ptr<llvm::Module> CodeGenCPU::Finish() {
diff --git a/src/target/llvm/llvm_module.cc b/src/target/llvm/llvm_module.cc
index a9e09652ee..2daf941edf 100644
--- a/src/target/llvm/llvm_module.cc
+++ b/src/target/llvm/llvm_module.cc
@@ -190,15 +190,7 @@ ffi::Function LLVMModuleNode::GetFunction(const String&
name,
TVMFFISafeCallType faddr;
With<LLVMTarget> llvm_target(*llvm_instance_,
LLVMTarget::GetTargetMetadata(*module_));
- if (name == runtime::symbol::tvm_module_main) {
- const char* entry_name = reinterpret_cast<const char*>(
- GetGlobalAddr(runtime::symbol::tvm_module_main, *llvm_target));
- ICHECK(entry_name != nullptr) << "Symbol " <<
runtime::symbol::tvm_module_main
- << " is not presented";
- faddr = reinterpret_cast<TVMFFISafeCallType>(GetFunctionAddr(entry_name,
*llvm_target));
- } else {
- faddr = reinterpret_cast<TVMFFISafeCallType>(GetFunctionAddr(name,
*llvm_target));
- }
+ faddr = reinterpret_cast<TVMFFISafeCallType>(GetFunctionAddr(name,
*llvm_target));
if (faddr == nullptr) return ffi::Function();
return tvm::runtime::WrapFFIFunction(faddr, sptr_to_self);
}
diff --git a/src/target/source/codegen_c_host.cc
b/src/target/source/codegen_c_host.cc
index 6cd12a9319..020054b3e1 100644
--- a/src/target/source/codegen_c_host.cc
+++ b/src/target/source/codegen_c_host.cc
@@ -77,11 +77,11 @@ void CodeGenCHost::AddFunction(const GlobalVar& gvar, const
PrimFunc& func,
<< "CodeGenCHost: The entry func must have the global_symbol
attribute, "
<< "but function " << gvar << " only has attributes " << func->attrs;
- function_names_.push_back(runtime::symbol::tvm_module_main);
+ function_names_.push_back(runtime::symbol::tvm_ffi_main);
stream << "// CodegenC: NOTE: Auto-generated entry function\n";
PrintFuncPrefix(stream);
PrintType(func->ret_type, stream);
- stream << " " << tvm::runtime::symbol::tvm_module_main
+ stream << " " << tvm::runtime::symbol::tvm_ffi_main
<< "(void* self, void* args,int num_args, void* result) {\n";
stream << " return " << global_symbol.value() << "(self, args, num_args,
result);\n";
stream << "}\n";
diff --git a/tests/python/contrib/test_hexagon/test_async_dma_pipeline.py
b/tests/python/contrib/test_hexagon/test_async_dma_pipeline.py
index fe7a615531..9461da2277 100644
--- a/tests/python/contrib/test_hexagon/test_async_dma_pipeline.py
+++ b/tests/python/contrib/test_hexagon/test_async_dma_pipeline.py
@@ -289,9 +289,13 @@ def evaluate(
if tvm.testing.utils.IS_IN_CI:
# Run with reduced number and repeat for CI
- timer = module.time_evaluator("__tvm_main__", hexagon_session.device,
number=1, repeat=1)
+ timer = module.time_evaluator(
+ "__tvm_ffi_main__", hexagon_session.device, number=1, repeat=1
+ )
else:
- timer = module.time_evaluator("__tvm_main__", hexagon_session.device,
number=10, repeat=10)
+ timer = module.time_evaluator(
+ "__tvm_ffi_main__", hexagon_session.device, number=10, repeat=10
+ )
time = timer(a_hexagon, b_hexagon, c_hexagon)
if expected_output is not None:
diff --git a/tests/python/contrib/test_hexagon/test_parallel_hvx.py
b/tests/python/contrib/test_hexagon/test_parallel_hvx.py
index 8f77fa1c40..6822352568 100644
--- a/tests/python/contrib/test_hexagon/test_parallel_hvx.py
+++ b/tests/python/contrib/test_hexagon/test_parallel_hvx.py
@@ -160,7 +160,7 @@ def evaluate(hexagon_session, shape_dtypes,
expected_output_producer, sch):
repeat = 1
timer = module.time_evaluator(
- "__tvm_main__", hexagon_session.device, number=number, repeat=repeat
+ "__tvm_ffi_main__", hexagon_session.device, number=number,
repeat=repeat
)
runtime = timer(a_hexagon, b_hexagon, c_hexagon)
tvm.testing.assert_allclose(c_hexagon.numpy(),
expected_output_producer(c_shape, a, b))
diff --git a/tests/python/contrib/test_hexagon/test_parallel_hvx_load_vtcm.py
b/tests/python/contrib/test_hexagon/test_parallel_hvx_load_vtcm.py
index a584997dd5..17e31af0a7 100644
--- a/tests/python/contrib/test_hexagon/test_parallel_hvx_load_vtcm.py
+++ b/tests/python/contrib/test_hexagon/test_parallel_hvx_load_vtcm.py
@@ -331,7 +331,7 @@ def setup_and_run(hexagon_session, sch, a, b, c,
operations, mem_scope="global")
repeat = 1
timer = module.time_evaluator(
- "__tvm_main__", hexagon_session.device, number=number, repeat=repeat
+ "__tvm_ffi_main__", hexagon_session.device, number=number,
repeat=repeat
)
time = timer(a_hexagon, b_hexagon, c_hexagon)
gops = round(operations * 128 * 3 / time.mean / 1e9, 4)
@@ -365,7 +365,7 @@ def setup_and_run_preallocated(hexagon_session, sch, a, b,
c, operations):
repeat = 1
timer = module.time_evaluator(
- "__tvm_main__", hexagon_session.device, number=number, repeat=repeat
+ "__tvm_ffi_main__", hexagon_session.device, number=number,
repeat=repeat
)
time = timer(a_hexagon, b_hexagon, c_hexagon, a_vtcm_hexagon,
b_vtcm_hexagon, c_vtcm_hexagon)
gops = round(operations * 128 * 3 / time.mean / 1e9, 4)
diff --git a/tests/python/contrib/test_hexagon/test_parallel_scalar.py
b/tests/python/contrib/test_hexagon/test_parallel_scalar.py
index bd9c78d5da..dd765178dc 100644
--- a/tests/python/contrib/test_hexagon/test_parallel_scalar.py
+++ b/tests/python/contrib/test_hexagon/test_parallel_scalar.py
@@ -105,7 +105,7 @@ def evaluate(hexagon_session, operations, expected, sch):
repeat = 1
timer = module.time_evaluator(
- "__tvm_main__", hexagon_session.device, number=number, repeat=repeat
+ "__tvm_ffi_main__", hexagon_session.device, number=number,
repeat=repeat
)
runtime = timer(a_hexagon, b_hexagon, c_hexagon)
diff --git a/tests/python/contrib/test_hexagon/test_vtcm_bandwidth.py
b/tests/python/contrib/test_hexagon/test_vtcm_bandwidth.py
index 931f99b2ec..265f2bf5fd 100644
--- a/tests/python/contrib/test_hexagon/test_vtcm_bandwidth.py
+++ b/tests/python/contrib/test_hexagon/test_vtcm_bandwidth.py
@@ -108,9 +108,13 @@ def evaluate(hexagon_session, sch, size):
if tvm.testing.utils.IS_IN_CI:
# Run with reduced number and repeat for CI
- timer = module.time_evaluator("__tvm_main__", hexagon_session.device,
number=1, repeat=1)
+ timer = module.time_evaluator(
+ "__tvm_ffi_main__", hexagon_session.device, number=1, repeat=1
+ )
else:
- timer = module.time_evaluator("__tvm_main__", hexagon_session.device,
number=10, repeat=10)
+ timer = module.time_evaluator(
+ "__tvm_ffi_main__", hexagon_session.device, number=10, repeat=10
+ )
runtime = timer(a_hexagon, a_vtcm_hexagon)