This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main-mod
in repository https://gitbox.apache.org/repos/asf/tvm.git

commit b7dbb7430747cee3bf6a1b3729a57a0f601a5654
Author: tqchen <[email protected]>
AuthorDate: Tue Aug 12 10:22:42 2025 -0400

    [REFACTOR] Phase out entry_func
    
    This PR phases out the entry function. Previously, the module itself had a
    default entry function, and supporting that feature required extra
    indirection and logic to handle the special case.

    Most of our use cases can move toward explicitly naming the function, as in
    mod["main"]. That does mean we can no longer call the module implicitly;
    instead we need to do mod["name"] to look up the function.
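
    A minimal migration sketch (illustrative; assumes a module built via
    tvm.tir.build that contains a PrimFunc named "main"):

        mod = tvm.tir.build(sch.mod, target="llvm")
        # before: mod(a, b, c) or mod.entry_func(a, b, c)
        f = mod["main"]   # explicit lookup by name
        f(a, b, c)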
---
 include/tvm/runtime/module.h                       |  2 -
 include/tvm/tir/analysis.h                         |  2 +-
 include/tvm/tir/function.h                         | 10 -----
 include/tvm/tir/transform.h                        |  6 ---
 jvm/README.md                                      |  2 +-
 jvm/core/src/main/java/org/apache/tvm/Module.java  | 14 -------
 python/tvm/runtime/executable.py                   |  4 --
 python/tvm/runtime/module.py                       | 22 -----------
 python/tvm/tir/build.py                            |  3 +-
 python/tvm/tir/pipeline.py                         |  1 -
 python/tvm/tir/transform/transform.py              | 11 ------
 src/meta_schedule/arg_info.cc                      | 10 ++---
 src/runtime/cuda/cuda_module.cc                    |  1 -
 src/runtime/library_module.cc                      | 10 +----
 src/runtime/metal/metal_module.mm                  |  1 -
 src/runtime/opencl/opencl_module.cc                |  1 -
 src/runtime/rocm/rocm_module.cc                    |  1 -
 src/runtime/vulkan/vulkan_wrapped_func.cc          |  1 -
 src/target/llvm/codegen_cpu.cc                     | 25 -------------
 src/target/llvm/codegen_hexagon.cc                 | 11 ------
 src/target/llvm/llvm_module.cc                     | 19 +---------
 src/target/source/codegen_c_host.cc                | 14 -------
 src/tir/ir/transform.cc                            |  1 -
 src/tir/transforms/primfunc_utils.cc               | 43 +---------------------
 tests/python/codegen/test_target_codegen_device.py |  2 +-
 .../test_hexagon/test_async_dma_pipeline.py        |  4 +-
 .../contrib/test_hexagon/test_parallel_hvx.py      |  2 +-
 .../test_hexagon/test_parallel_hvx_load_vtcm.py    |  8 +---
 .../contrib/test_hexagon/test_parallel_scalar.py   |  4 +-
 .../contrib/test_hexagon/test_vtcm_bandwidth.py    |  4 +-
 .../test_runtime_builtin_kv_cache_transfer.py      |  2 +-
 ...runtime_builtin_paged_attention_kv_cache_cpu.py |  2 +-
 ..._builtin_paged_attention_kv_cache_flashinfer.py |  2 +-
 ...ltin_paged_attention_kv_cache_mla_flashinfer.py |  2 +-
 ...ime_builtin_paged_attention_kv_cache_mla_tir.py |  2 +-
 ...runtime_builtin_paged_attention_kv_cache_tir.py |  2 +-
 .../python/relax/test_runtime_builtin_rnn_state.py |  2 +-
 .../tir-transform/test_tir_transform_helpers.py    | 31 ----------------
 tests/python/tvmscript/test_tvmscript_roundtrip.py |  2 -
 39 files changed, 25 insertions(+), 261 deletions(-)

diff --git a/include/tvm/runtime/module.h b/include/tvm/runtime/module.h
index efbaa6508a..05b57de39d 100644
--- a/include/tvm/runtime/module.h
+++ b/include/tvm/runtime/module.h
@@ -296,8 +296,6 @@ constexpr const char* tvm_set_device = "__tvm_set_device";
 constexpr const char* tvm_global_barrier_state = "__tvm_global_barrier_state";
 /*! \brief Prepare the global barrier before kernels that uses global barrier. */
 constexpr const char* tvm_prepare_global_barrier = "__tvm_prepare_global_barrier";
-/*! \brief Placeholder for the module's entry function. */
-constexpr const char* tvm_module_main = "__tvm_main__";
 }  // namespace symbol
 
 // implementations of inline functions.
diff --git a/include/tvm/tir/analysis.h b/include/tvm/tir/analysis.h
index a21112b7d6..912a4449d0 100644
--- a/include/tvm/tir/analysis.h
+++ b/include/tvm/tir/analysis.h
@@ -356,7 +356,7 @@ TVM_DLL bool VerifyWellFormed(const IRModule& mod, bool assert_mode = true);
 
 /*!
  * \brief Find the entry function of the given IRModule, i.e, functions marked 
by
- * `tir::attr::kIsEntryFunc`, whose name is `main` or being the only PrimeFunc.
+ * whose name is `main` or being the only PrimeFunc.
  * \param mod The IRModule to find the entry function.
  * \param result_g_var The result GlobalVar of the entry function.
  * \return The entry function.
diff --git a/include/tvm/tir/function.h b/include/tvm/tir/function.h
index 6ea50e9ae0..ff9a6ff927 100644
--- a/include/tvm/tir/function.h
+++ b/include/tvm/tir/function.h
@@ -316,16 +316,6 @@ constexpr const char* kKernelLaunchParams = "tir.kernel_launch_params";
  */
 constexpr const char* kNoAlias = "tir.noalias";
 
-/*!
- * \brief Mark the function as the entry function of
- *        the final generated runtime module.
- *
- * Type: Integer
- *
- * \note There can only be one entry function per module.
- */
-constexpr const char* kIsEntryFunc = "tir.is_entry_func";
-
 /*!
  * \brief Mark the function as the global function called from the host.
  *
diff --git a/include/tvm/tir/transform.h b/include/tvm/tir/transform.h
index eb64d87f95..c7af05e7f2 100644
--- a/include/tvm/tir/transform.h
+++ b/include/tvm/tir/transform.h
@@ -703,12 +703,6 @@ TVM_DLL Pass RenormalizeSplitPattern();
  */
 TVM_DLL Pass BindTarget(Target target);
 
-/*!
- * \brief Set a PrimFunc as the entry point if it is only function in IRModule.
- * \return The pass.
- */
-TVM_DLL Pass AnnotateEntryFunc();
-
 /*!
  * \brief Filter PrimFuncs with a given condition.
  * \return The pass.
diff --git a/jvm/README.md b/jvm/README.md
index 71c737a4d0..051f2ccdc5 100644
--- a/jvm/README.md
+++ b/jvm/README.md
@@ -113,7 +113,7 @@ public class LoadAddFunc {
     arr.copyFrom(new float[]{3f, 4f});
     NDArray res = NDArray.empty(shape, dev);
 
-    fadd.entryFunc().pushArg(arr).pushArg(arr).pushArg(res).invoke();
+    fadd.getFunction("main").pushArg(arr).pushArg(arr).pushArg(res).invoke();
     System.out.println(Arrays.toString(res.asFloatArray()));
 
     arr.release();
diff --git a/jvm/core/src/main/java/org/apache/tvm/Module.java b/jvm/core/src/main/java/org/apache/tvm/Module.java
index 5e78e26ae7..d2c949035a 100644
--- a/jvm/core/src/main/java/org/apache/tvm/Module.java
+++ b/jvm/core/src/main/java/org/apache/tvm/Module.java
@@ -45,10 +45,6 @@ public class Module extends TVMObject {
     super(handle, TypeIndex.kTVMFFIModule);
   }
 
-  private Function entry = null;
-  private final String entryName = "__tvm_main__";
-
-
   /**
    * Easy for user to get the instance from returned TVMValue.
    * @return this
@@ -57,16 +53,6 @@ public class Module extends TVMObject {
     return this;
   }
 
-  /**
-   * Get the entry function.
-   * @return The entry function if exist
-   */
-  public Function entryFunc() {
-    if (entry == null) {
-      entry = getFunction(entryName);
-    }
-    return entry;
-  }
 
   /**
    * Get function from the module.
diff --git a/python/tvm/runtime/executable.py b/python/tvm/runtime/executable.py
index b6e13a65a9..a1a6606765 100644
--- a/python/tvm/runtime/executable.py
+++ b/python/tvm/runtime/executable.py
@@ -36,10 +36,6 @@ class Executable:
         """Get the PackedFunc from the jitted module."""
         return self.jit().get_function(name, query_imports=True)
 
-    def __call__(self, *args, **kwargs) -> Any:
-        """Call the executable."""
-        return self.jit().entry_func(*args, **kwargs)
-
     def jit(
         self,
         *,
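
With Executable.__call__ removed, callers fetch the function explicitly from
the jitted runtime module. A minimal sketch (illustrative; assumes the jitted
module defines a function named "main"):

    # before: out = ex(x)        # implicitly invoked the entry function
    f = ex.jit()["main"]         # explicit lookup on the jitted module
    out = f(x)
    # get_function also works and can search imported modules:
    f = ex.get_function("main")  # wraps jit().get_function(name, query_imports=True)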
diff --git a/python/tvm/runtime/module.py b/python/tvm/runtime/module.py
index 3dd4de5da0..30f83474dc 100644
--- a/python/tvm/runtime/module.py
+++ b/python/tvm/runtime/module.py
@@ -103,24 +103,8 @@ class Module(tvm.ffi.Object):
 
     def __new__(cls):
         instance = super(Module, cls).__new__(cls)  # pylint: disable=no-value-for-parameter
-        instance.entry_name = "__tvm_main__"
-        instance._entry = None
         return instance
 
-    @property
-    def entry_func(self):
-        """Get the entry function
-
-        Returns
-        -------
-        f : tvm.runtime.PackedFunc
-            The entry function if exist
-        """
-        if self._entry:
-            return self._entry
-        self._entry = self.get_function("__tvm_main__")
-        return self._entry
-
     def implements_function(self, name, query_imports=False):
         """Returns True if the module has a definition for the global function with name. Note
         that has_function(name) does not imply get_function(name) is non-null since the module
@@ -179,12 +163,6 @@ class Module(tvm.ffi.Object):
             raise ValueError("Can only take string as function name")
         return self.get_function(name)
 
-    def __call__(self, *args):
-        if self._entry:
-            return self._entry(*args)
-        # pylint: disable=not-callable
-        return self.entry_func(*args)
-
     @property
     def type_key(self):
         """Get type key of the module."""
diff --git a/python/tvm/tir/build.py b/python/tvm/tir/build.py
index 98e549cc9c..431c601e72 100644
--- a/python/tvm/tir/build.py
+++ b/python/tvm/tir/build.py
@@ -80,8 +80,7 @@ def split_host_device_mods(mod: IRModule) -> Tuple[IRModule, Dict[Target, IRModu
             @T.prim_func
            def main(self_handle: T.handle, args: T.handle, num_args: T.int32, result: T.handle):
                T.func_attr({"target": T.target({"keys": ["cpu"], "kind": "c"}),
-                            "calling_conv": 1,  # kCPackedFunc for entry functions
-                            "tir.is_entry_func": True})
+                            "calling_conv": 1})
                 # ... main function implementation
 
     The function will return:
diff --git a/python/tvm/tir/pipeline.py b/python/tvm/tir/pipeline.py
index ae78b05738..1082cd8fac 100644
--- a/python/tvm/tir/pipeline.py
+++ b/python/tvm/tir/pipeline.py
@@ -89,7 +89,6 @@ def default_tir_pipeline():
                 tir.transform.VerifyVTCMLimit(),
                 tir.transform.LowerVtcmAlloc(),
                 tir.transform.VerifyMemory(),
-                tir.transform.AnnotateEntryFunc(),
             ]
         )
         if bool(config.get("tir.detect_global_barrier", False)):
diff --git a/python/tvm/tir/transform/transform.py b/python/tvm/tir/transform/transform.py
index 93a182ca3b..178a203ca5 100644
--- a/python/tvm/tir/transform/transform.py
+++ b/python/tvm/tir/transform/transform.py
@@ -1018,17 +1018,6 @@ def BindTarget(target):
     return _ffi_api.BindTarget(target)  # type: ignore
 
 
-def AnnotateEntryFunc():
-    """Set a PrimFunc as the entry point if it is only function in IRModule.
-
-    Returns
-    -------
-    fpass : tvm.transform.Pass
-        The result pass
-    """
-    return _ffi_api.AnnotateEntryFunc()  # type: ignore
-
-
 def Filter(fcond: Callable):
     """Filter out PrimFuncs that does not satisfy the given condition.
     `fcond` should be a function that takes a primfunc and returns boolean.
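
For completeness, a hedged usage sketch of the remaining Filter pass (the
condition below is illustrative; fcond receives each PrimFunc and returns a
boolean):

    fpass = tvm.tir.transform.Filter(lambda f: f.attrs and "target" in f.attrs)
    filtered_mod = fpass(mod)   # keeps only PrimFuncs satisfying the condition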
diff --git a/src/meta_schedule/arg_info.cc b/src/meta_schedule/arg_info.cc
index 9c2ba084ad..c46a0bf280 100644
--- a/src/meta_schedule/arg_info.cc
+++ b/src/meta_schedule/arg_info.cc
@@ -25,12 +25,12 @@ namespace meta_schedule {
 
 /*!
  * \brief Find the entry function of the given IRModule, i.e, functions marked by
- * `tir::attr::kIsEntryFunc`, whose name is `main` or being the only PrimeFunc.
+ * whose name is `main` or the only PrimFunc.
  * \param mod The IRModule to find the entry function.
  * \return The entry function.
  */
 inline tir::PrimFunc FindEntryFunc(const IRModule& mod) {
-  // Priority 1: PrimFunc marked as `tir::attr::kIsEntryFunc`
+  // Priority 1: PrimFunc named `main`
   int num_prim_func = 0;
   const tir::PrimFuncNode* main_func = nullptr;
   const tir::PrimFuncNode* last_func = nullptr;
@@ -39,9 +39,6 @@ inline tir::PrimFunc FindEntryFunc(const IRModule& mod) {
     BaseFunc base_func = kv.second;
     if (const auto* func = base_func.as<tir::PrimFuncNode>()) {
       last_func = func;
-      if (func->HasNonzeroAttr(tir::attr::kIsEntryFunc)) {
-        return GetRef<tir::PrimFunc>(func);
-      }
       if (gv->name_hint == "main") {
         main_func = func;
       }
@@ -57,8 +54,7 @@ inline tir::PrimFunc FindEntryFunc(const IRModule& mod) {
    LOG(FATAL) << "ValueError: Cannot find any PrimFunc in the given IRModule: " << mod;
   }
   if (num_prim_func > 1) {
-    LOG(FATAL) << "ValueError: Multiple PrimFuncs exist in the IRModule, but none of them are "
-                  "annotated with `kIsEntryFunc`, i.e. `tir.is_entry_func`"
+    LOG(FATAL) << "ValueError: Multiple PrimFuncs exist in the IRModule, but none of them is named `main`"
               << mod;
   }
   return GetRef<tir::PrimFunc>(last_func);
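
A hedged Python sketch of the resolution order implemented above (illustrative
only; the authoritative logic is the C++ in this hunk):

    def find_entry_func(mod):
        prim_funcs = {
            gv.name_hint: func
            for gv, func in mod.functions.items()
            if isinstance(func, tvm.tir.PrimFunc)
        }
        if "main" in prim_funcs:      # priority 1: the PrimFunc named "main"
            return prim_funcs["main"]
        if len(prim_funcs) == 1:      # priority 2: the only PrimFunc
            return next(iter(prim_funcs.values()))
        raise ValueError("multiple PrimFuncs and none of them is named 'main'")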
diff --git a/src/runtime/cuda/cuda_module.cc b/src/runtime/cuda/cuda_module.cc
index 2435cccf0a..6b71df928d 100644
--- a/src/runtime/cuda/cuda_module.cc
+++ b/src/runtime/cuda/cuda_module.cc
@@ -258,7 +258,6 @@ class CUDAPrepGlobalBarrier {
 ffi::Function CUDAModuleNode::GetFunction(const String& name,
                                          const ObjectPtr<Object>& sptr_to_self) {
   ICHECK_EQ(sptr_to_self.get(), this);
-  ICHECK_NE(name, symbol::tvm_module_main) << "Device function do not have main";
   if (name == symbol::tvm_prepare_global_barrier) {
     return ffi::Function(CUDAPrepGlobalBarrier(this, sptr_to_self));
   }
diff --git a/src/runtime/library_module.cc b/src/runtime/library_module.cc
index fffac4adea..24fc7518d6 100644
--- a/src/runtime/library_module.cc
+++ b/src/runtime/library_module.cc
@@ -50,15 +50,7 @@ class LibraryModuleNode final : public ModuleNode {
 
  ffi::Function GetFunction(const String& name, const ObjectPtr<Object>& sptr_to_self) final {
    TVMFFISafeCallType faddr;
-    if (name == runtime::symbol::tvm_module_main) {
-      const char* entry_name =
-          reinterpret_cast<const char*>(lib_->GetSymbol(runtime::symbol::tvm_module_main));
-      ICHECK(entry_name != nullptr)
-          << "Symbol " << runtime::symbol::tvm_module_main << " is not presented";
-      faddr = reinterpret_cast<TVMFFISafeCallType>(lib_->GetSymbol(entry_name));
-    } else {
-      faddr = reinterpret_cast<TVMFFISafeCallType>(lib_->GetSymbol(name.c_str()));
-    }
+    faddr = reinterpret_cast<TVMFFISafeCallType>(lib_->GetSymbol(name.c_str()));
     if (faddr == nullptr) return ffi::Function();
     return packed_func_wrapper_(faddr, sptr_to_self);
   }
diff --git a/src/runtime/metal/metal_module.mm b/src/runtime/metal/metal_module.mm
index be36e6197f..33bb1705c8 100644
--- a/src/runtime/metal/metal_module.mm
+++ b/src/runtime/metal/metal_module.mm
@@ -264,7 +264,6 @@ ffi::Function MetalModuleNode::GetFunction(const String& name,
   ffi::Function ret;
   AUTORELEASEPOOL {
     ICHECK_EQ(sptr_to_self.get(), this);
-    ICHECK_NE(name, symbol::tvm_module_main) << "Device function do not have main";
     auto it = fmap_.find(name);
     if (it == fmap_.end()) {
       ret = ffi::Function();
diff --git a/src/runtime/opencl/opencl_module.cc b/src/runtime/opencl/opencl_module.cc
index 19b426d4b4..1c61eeb596 100644
--- a/src/runtime/opencl/opencl_module.cc
+++ b/src/runtime/opencl/opencl_module.cc
@@ -138,7 +138,6 @@ cl::OpenCLWorkspace* OpenCLModuleNodeBase::GetGlobalWorkspace() {
 ffi::Function OpenCLModuleNodeBase::GetFunction(const String& name,
                                                 const ObjectPtr<Object>& sptr_to_self) {
   ICHECK_EQ(sptr_to_self.get(), this);
-  ICHECK_NE(name, symbol::tvm_module_main) << "Device function do not have main";
   auto it = fmap_.find(name);
   if (it == fmap_.end()) return ffi::Function();
   const FunctionInfo& info = it->second;
diff --git a/src/runtime/rocm/rocm_module.cc b/src/runtime/rocm/rocm_module.cc
index 791e4b1569..a871a41f0f 100644
--- a/src/runtime/rocm/rocm_module.cc
+++ b/src/runtime/rocm/rocm_module.cc
@@ -195,7 +195,6 @@ class ROCMWrappedFunc {
 ffi::Function ROCMModuleNode::GetFunction(const String& name,
                                          const ObjectPtr<Object>& sptr_to_self) {
   ICHECK_EQ(sptr_to_self.get(), this);
-  ICHECK_NE(name, symbol::tvm_module_main) << "Device function do not have main";
   auto it = fmap_.find(name);
   if (it == fmap_.end()) return ffi::Function();
   const FunctionInfo& info = it->second;
diff --git a/src/runtime/vulkan/vulkan_wrapped_func.cc b/src/runtime/vulkan/vulkan_wrapped_func.cc
index f4922a1bf0..db81c959dc 100644
--- a/src/runtime/vulkan/vulkan_wrapped_func.cc
+++ b/src/runtime/vulkan/vulkan_wrapped_func.cc
@@ -208,7 +208,6 @@ VulkanModuleNode::~VulkanModuleNode() {
 ffi::Function VulkanModuleNode::GetFunction(const String& name,
                                            const ObjectPtr<Object>& sptr_to_self) {
   ICHECK_EQ(sptr_to_self.get(), this);
-  ICHECK_NE(name, symbol::tvm_module_main) << "Device function do not have main";
   auto it = fmap_.find(name);
   if (it == fmap_.end()) return ffi::Function();
   const FunctionInfo& info = it->second;
diff --git a/src/target/llvm/codegen_cpu.cc b/src/target/llvm/codegen_cpu.cc
index 4dd24026c0..e75ac025b9 100644
--- a/src/target/llvm/codegen_cpu.cc
+++ b/src/target/llvm/codegen_cpu.cc
@@ -228,31 +228,6 @@ void CodeGenCPU::AddFunction(const GlobalVar& gvar, const PrimFunc& func) {
   AddDebugInformation(function_, func->params.Map(GetType));
 }
 
-void CodeGenCPU::AddMainFunction(const std::string& entry_func_name) {
-  llvm::Function* f = module_->getFunction(entry_func_name);
-  ICHECK(f) << "Function " << entry_func_name << "does not in module";
-  llvm::Type* type = llvm::ArrayType::get(t_char_, entry_func_name.length() + 1);
-  llvm::GlobalVariable* global =
-      new llvm::GlobalVariable(*module_, type, true, llvm::GlobalValue::WeakAnyLinkage, nullptr,
-                               runtime::symbol::tvm_module_main);
-#if TVM_LLVM_VERSION >= 100
-  global->setAlignment(llvm::Align(1));
-#else
-  global->setAlignment(1);
-#endif
-  // comdat is needed for windows select any linking to work
-  // set comdat to Any(weak linking)
-  if (llvm_target_->GetOrCreateTargetMachine()->getTargetTriple().isOSWindows()) {
-    llvm::Comdat* comdat = module_->getOrInsertComdat(runtime::symbol::tvm_module_main);
-    comdat->setSelectionKind(llvm::Comdat::Any);
-    global->setComdat(comdat);
-  }
-
-  global->setInitializer(
-      llvm::ConstantDataArray::getString(*llvm_target_->GetContext(), entry_func_name));
-  global->setDLLStorageClass(llvm::GlobalVariable::DLLExportStorageClass);
-}
-
 std::unique_ptr<llvm::Module> CodeGenCPU::Finish() {
   // link modules
   if (dbg_info_ != nullptr) {
diff --git a/src/target/llvm/codegen_hexagon.cc b/src/target/llvm/codegen_hexagon.cc
index 6f90da3d8a..3d8ed08eee 100644
--- a/src/target/llvm/codegen_hexagon.cc
+++ b/src/target/llvm/codegen_hexagon.cc
@@ -483,9 +483,6 @@ runtime::Module BuildHexagon(IRModule mod, Target target) {
   (void)CallOnce;
 
   auto cg = std::make_unique<CodeGenHexagon>();
-
-  std::string entry_func;
-
   for (auto kv : mod->functions) {
     if (!kv.second->IsInstance<PrimFuncNode>()) {
      // (@jroesch): we relax constraints here, relax functions will just be ignored.
@@ -493,18 +490,10 @@ runtime::Module BuildHexagon(IRModule mod, Target target) {
       continue;
     }
     auto f = Downcast<PrimFunc>(kv.second);
-    if (f->HasNonzeroAttr(tir::attr::kIsEntryFunc)) {
-      auto global_symbol = f->GetAttr<String>(tvm::attr::kGlobalSymbol);
-      ICHECK(global_symbol.has_value());
-      entry_func = global_symbol.value();
-    }
   }
 
   cg->Init("TVMHexagonModule", llvm_target.get(), std::nullopt, false, false);
   cg->AddFunctionsOrdered(mod->functions.begin(), mod->functions.end());
-  if (entry_func.length() != 0) {
-    cg->AddMainFunction(entry_func);
-  }
 
  // Uncomment to get the LLVM module right out of codegen, before optimizations.
   // std::cerr << "HexagonModule.0 {\n" << *cg->GetModulePtr() << "}\n";
diff --git a/src/target/llvm/llvm_module.cc b/src/target/llvm/llvm_module.cc
index a9e09652ee..6a12e59269 100644
--- a/src/target/llvm/llvm_module.cc
+++ b/src/target/llvm/llvm_module.cc
@@ -190,15 +190,7 @@ ffi::Function LLVMModuleNode::GetFunction(const String& name,
 
   TVMFFISafeCallType faddr;
  With<LLVMTarget> llvm_target(*llvm_instance_, LLVMTarget::GetTargetMetadata(*module_));
-  if (name == runtime::symbol::tvm_module_main) {
-    const char* entry_name = reinterpret_cast<const char*>(
-        GetGlobalAddr(runtime::symbol::tvm_module_main, *llvm_target));
-    ICHECK(entry_name != nullptr) << "Symbol " << runtime::symbol::tvm_module_main
-                                  << " is not presented";
-    faddr = reinterpret_cast<TVMFFISafeCallType>(GetFunctionAddr(entry_name, *llvm_target));
-  } else {
-    faddr = reinterpret_cast<TVMFFISafeCallType>(GetFunctionAddr(name, *llvm_target));
-  }
+  faddr = reinterpret_cast<TVMFFISafeCallType>(GetFunctionAddr(name, *llvm_target));
   if (faddr == nullptr) return ffi::Function();
   return tvm::runtime::WrapFFIFunction(faddr, sptr_to_self);
 }
@@ -337,15 +329,9 @@ void LLVMModuleNode::Init(const IRModule& mod, const Target& target) {
     }
     auto f = Downcast<PrimFunc>(kv.second);
     auto global_symbol = f->GetAttr<String>(tvm::attr::kGlobalSymbol);
-    bool is_entry_func = f->HasNonzeroAttr(tir::attr::kIsEntryFunc);
-
-    ICHECK(global_symbol || !is_entry_func) << "The entry func must be exposed externally.";
 
     if (global_symbol) {
       function_names_.push_back(global_symbol.value());
-      if (is_entry_func) {
-        entry_func = global_symbol.value();
-      }
     }
   }
   // TODO(@jroesch): follow up on this condition.
@@ -355,9 +341,6 @@ void LLVMModuleNode::Init(const IRModule& mod, const Target& target) {
   cg->Init("TVMMod", llvm_target.get(), system_lib_prefix, system_lib_prefix.has_value(), false);
   cg->SetFastMathFlags(llvm_target->GetFastMathFlags());
   cg->AddFunctionsOrdered(mod->functions.begin(), mod->functions.end());
-  if (entry_func.length() != 0) {
-    cg->AddMainFunction(entry_func);
-  }
 
   module_owning_ptr_ = cg->Finish();
   module_ = module_owning_ptr_.get();
diff --git a/src/target/source/codegen_c_host.cc b/src/target/source/codegen_c_host.cc
index 6cd12a9319..1c8a3dd2ea 100644
--- a/src/target/source/codegen_c_host.cc
+++ b/src/target/source/codegen_c_host.cc
@@ -72,20 +72,6 @@ void CodeGenCHost::AddFunction(const GlobalVar& gvar, const PrimFunc& func,
 
   emit_fwd_func_decl_ = emit_fwd_func_decl;
   CodeGenC::AddFunction(gvar, func);
-  if (func->HasNonzeroAttr(tir::attr::kIsEntryFunc)) {
-    ICHECK(global_symbol.has_value())
-        << "CodeGenCHost: The entry func must have the global_symbol attribute, "
-        << "but function " << gvar << " only has attributes " << func->attrs;
-
-    function_names_.push_back(runtime::symbol::tvm_module_main);
-    stream << "// CodegenC: NOTE: Auto-generated entry function\n";
-    PrintFuncPrefix(stream);
-    PrintType(func->ret_type, stream);
-    stream << " " << tvm::runtime::symbol::tvm_module_main
-           << "(void* self, void* args,int num_args, void* result) {\n";
-    stream << "  return " << global_symbol.value() << "(self, args, num_args, result);\n";
-    stream << "}\n";
-  }
 }
 
 void CodeGenCHost::GenerateForwardFunctionDeclarations(String global_symbol,
diff --git a/src/tir/ir/transform.cc b/src/tir/ir/transform.cc
index aafe6277e2..6ef7ffdca9 100644
--- a/src/tir/ir/transform.cc
+++ b/src/tir/ir/transform.cc
@@ -42,7 +42,6 @@ TVM_REGISTER_PASS_CONFIG_OPTION("tir.disable_cse_tir", Bool);
 TVM_REGISTER_PASS_CONFIG_OPTION("tir.enable_debug", Bool);
 TVM_REGISTER_PASS_CONFIG_OPTION("tir.enable_equiv_terms_in_cse_tir", Bool);
 TVM_REGISTER_PASS_CONFIG_OPTION("tir.disable_storage_rewrite", Bool);
-TVM_REGISTER_PASS_CONFIG_OPTION("tir.is_entry_func", Bool);
 TVM_REGISTER_PASS_CONFIG_OPTION("tir.add_lower_pass", Array<Array<ObjectRef>>);
 TVM_REGISTER_PASS_CONFIG_OPTION("tir.debug_keep_trivial_loop", Bool);
 TVM_REGISTER_PASS_CONFIG_OPTION("tir.use_async_copy", Bool);
diff --git a/src/tir/transforms/primfunc_utils.cc b/src/tir/transforms/primfunc_utils.cc
index b1f3476eab..99cf901377 100644
--- a/src/tir/transforms/primfunc_utils.cc
+++ b/src/tir/transforms/primfunc_utils.cc
@@ -29,45 +29,6 @@ namespace tvm {
 namespace tir {
 namespace transform {
 
-transform::Pass AnnotateEntryFunc() {
-  auto fpass = [](IRModule mod, transform::PassContext ctx) -> IRModule {
-    // If only a single function exists, that function must be the entry
-    if (mod->functions.size() == 1) {
-      auto [gvar, base_func] = *mod->functions.begin();
-      if (!base_func->HasNonzeroAttr(tir::attr::kIsEntryFunc)) {
-        if (auto ptr = base_func.as<PrimFuncNode>()) {
-          mod->Update(gvar, WithAttr(GetRef<PrimFunc>(ptr), tir::attr::kIsEntryFunc, true));
-        }
-      }
-      return mod;
-    }
-
-    // If the module has multiple functions, but only one is exposed
-    // externally, that function must be the entry.
-    bool has_external_non_primfuncs = false;
-    IRModule with_annotations;
-    for (const auto& [gvar, base_func] : mod->functions) {
-      bool is_external = base_func->GetAttr<String>(tvm::attr::kGlobalSymbol).has_value();
-      if (is_external) {
-        if (auto ptr = base_func.as<PrimFuncNode>()) {
-          with_annotations->Add(gvar,
-                                WithAttr(GetRef<PrimFunc>(ptr), tir::attr::kIsEntryFunc, true));
-        } else {
-          has_external_non_primfuncs = true;
-        }
-      }
-    }
-    if (with_annotations->functions.size() == 1 && !has_external_non_primfuncs) {
-      mod->Update(with_annotations);
-      return mod;
-    }
-
-    // Default fallback, no annotations may be inferred.
-    return mod;
-  };
-  return tvm::transform::CreateModulePass(fpass, 0, "tir.AnnotateEntryFunc", {});
-}
-
 transform::Pass Filter(ffi::TypedFunction<bool(PrimFunc)> fcond) {
  auto fpass = [fcond](tir::PrimFunc f, IRModule m, transform::PassContext ctx) {
    if (fcond(f)) {
@@ -81,9 +42,7 @@ transform::Pass Filter(ffi::TypedFunction<bool(PrimFunc)> fcond) {
 
 TVM_FFI_STATIC_INIT_BLOCK({
   namespace refl = tvm::ffi::reflection;
-  refl::GlobalDef()
-      .def("tir.transform.AnnotateEntryFunc", AnnotateEntryFunc)
-      .def("tir.transform.Filter", Filter);
+  refl::GlobalDef().def("tir.transform.Filter", Filter);
 });
 
 }  // namespace transform
diff --git a/tests/python/codegen/test_target_codegen_device.py b/tests/python/codegen/test_target_codegen_device.py
index 4dad03d700..0089e0bea6 100644
--- a/tests/python/codegen/test_target_codegen_device.py
+++ b/tests/python/codegen/test_target_codegen_device.py
@@ -95,7 +95,7 @@ def test_add_pipeline():
         dev = tvm.device(device, 0)
         target = tvm.target.Target(device, host)
         mhost = tvm.tir.build(sch.mod, target=target)
-        f = mhost.entry_func
+        f = mhost["main"]
         # launch the kernel.
         n = 1027
         a = tvm.nd.array(np.random.uniform(size=n).astype(A.dtype), dev)
diff --git a/tests/python/contrib/test_hexagon/test_async_dma_pipeline.py b/tests/python/contrib/test_hexagon/test_async_dma_pipeline.py
index fe7a615531..1a787703ea 100644
--- a/tests/python/contrib/test_hexagon/test_async_dma_pipeline.py
+++ b/tests/python/contrib/test_hexagon/test_async_dma_pipeline.py
@@ -289,9 +289,9 @@ def evaluate(
 
     if tvm.testing.utils.IS_IN_CI:
         # Run with reduced number and repeat for CI
-        timer = module.time_evaluator("__tvm_main__", hexagon_session.device, number=1, repeat=1)
+        timer = module.time_evaluator("main", hexagon_session.device, number=1, repeat=1)
    else:
-        timer = module.time_evaluator("__tvm_main__", hexagon_session.device, number=10, repeat=10)
+        timer = module.time_evaluator("main", hexagon_session.device, number=10, repeat=10)
 
     time = timer(a_hexagon, b_hexagon, c_hexagon)
     if expected_output is not None:
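
Benchmarking likewise keys off the explicit name now. A minimal usage sketch
(illustrative; assumes a built module with a "main" function and a valid
device):

    timer = module.time_evaluator("main", dev, number=10, repeat=10)
    report = timer(a, b, c)   # runs the function and gathers timing stats
    print(report.mean)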
diff --git a/tests/python/contrib/test_hexagon/test_parallel_hvx.py b/tests/python/contrib/test_hexagon/test_parallel_hvx.py
index 8f77fa1c40..39d46bd1c4 100644
--- a/tests/python/contrib/test_hexagon/test_parallel_hvx.py
+++ b/tests/python/contrib/test_hexagon/test_parallel_hvx.py
@@ -160,7 +160,7 @@ def evaluate(hexagon_session, shape_dtypes, expected_output_producer, sch):
     repeat = 1
 
     timer = module.time_evaluator(
-        "__tvm_main__", hexagon_session.device, number=number, repeat=repeat
+        "main", hexagon_session.device, number=number, repeat=repeat
     )
     runtime = timer(a_hexagon, b_hexagon, c_hexagon)
    tvm.testing.assert_allclose(c_hexagon.numpy(), expected_output_producer(c_shape, a, b))
diff --git a/tests/python/contrib/test_hexagon/test_parallel_hvx_load_vtcm.py b/tests/python/contrib/test_hexagon/test_parallel_hvx_load_vtcm.py
index a584997dd5..63a65f3716 100644
--- a/tests/python/contrib/test_hexagon/test_parallel_hvx_load_vtcm.py
+++ b/tests/python/contrib/test_hexagon/test_parallel_hvx_load_vtcm.py
@@ -330,9 +330,7 @@ def setup_and_run(hexagon_session, sch, a, b, c, operations, mem_scope="global")
     number = 1
     repeat = 1
 
-    timer = module.time_evaluator(
-        "__tvm_main__", hexagon_session.device, number=number, repeat=repeat
-    )
+    timer = module.time_evaluator("main", hexagon_session.device, number=number, repeat=repeat)
     time = timer(a_hexagon, b_hexagon, c_hexagon)
     gops = round(operations * 128 * 3 / time.mean / 1e9, 4)
     return gops, c_hexagon.numpy()
@@ -364,9 +362,7 @@ def setup_and_run_preallocated(hexagon_session, sch, a, b, c, operations):
     number = 1
     repeat = 1
 
-    timer = module.time_evaluator(
-        "__tvm_main__", hexagon_session.device, number=number, repeat=repeat
-    )
+    timer = module.time_evaluator("main", hexagon_session.device, number=number, repeat=repeat)
     time = timer(a_hexagon, b_hexagon, c_hexagon, a_vtcm_hexagon, b_vtcm_hexagon, c_vtcm_hexagon)
     gops = round(operations * 128 * 3 / time.mean / 1e9, 4)
     return gops, c_hexagon.numpy()
diff --git a/tests/python/contrib/test_hexagon/test_parallel_scalar.py b/tests/python/contrib/test_hexagon/test_parallel_scalar.py
index bd9c78d5da..5c8043fdff 100644
--- a/tests/python/contrib/test_hexagon/test_parallel_scalar.py
+++ b/tests/python/contrib/test_hexagon/test_parallel_scalar.py
@@ -104,9 +104,7 @@ def evaluate(hexagon_session, operations, expected, sch):
     number = 1
     repeat = 1
 
-    timer = module.time_evaluator(
-        "__tvm_main__", hexagon_session.device, number=number, repeat=repeat
-    )
+    timer = module.time_evaluator("main", hexagon_session.device, number=number, repeat=repeat)
     runtime = timer(a_hexagon, b_hexagon, c_hexagon)
 
     tvm.testing.assert_allclose(c_hexagon.numpy(), expected(a, b))
diff --git a/tests/python/contrib/test_hexagon/test_vtcm_bandwidth.py b/tests/python/contrib/test_hexagon/test_vtcm_bandwidth.py
index 931f99b2ec..015a9f0656 100644
--- a/tests/python/contrib/test_hexagon/test_vtcm_bandwidth.py
+++ b/tests/python/contrib/test_hexagon/test_vtcm_bandwidth.py
@@ -108,9 +108,9 @@ def evaluate(hexagon_session, sch, size):
 
     if tvm.testing.utils.IS_IN_CI:
         # Run with reduced number and repeat for CI
-        timer = module.time_evaluator("__tvm_main__", hexagon_session.device, number=1, repeat=1)
+        timer = module.time_evaluator("main", hexagon_session.device, number=1, repeat=1)
    else:
-        timer = module.time_evaluator("__tvm_main__", hexagon_session.device, number=10, repeat=10)
+        timer = module.time_evaluator("main", hexagon_session.device, number=10, repeat=10)
 
     runtime = timer(a_hexagon, a_vtcm_hexagon)
 
diff --git a/tests/python/relax/nvshmem/test_runtime_builtin_kv_cache_transfer.py b/tests/python/relax/nvshmem/test_runtime_builtin_kv_cache_transfer.py
index 81acf5ee86..0d2f445cb8 100644
--- a/tests/python/relax/nvshmem/test_runtime_builtin_kv_cache_transfer.py
+++ b/tests/python/relax/nvshmem/test_runtime_builtin_kv_cache_transfer.py
@@ -170,7 +170,7 @@ def set_global_func(head_dim, dtype):
         with target:
             mod = dl.ApplyDefaultSchedule(dl.gpu.Fallback())(mod)
         f = tvm.tir.build(mod["main"], target=target)
-        builts.append(f.entry_func)
+        builts.append(f["main"])
 
     (
         ftranspose_append,
diff --git a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_cpu.py b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_cpu.py
index 1941edeaa7..305fd18f35 100644
--- a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_cpu.py
+++ b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_cpu.py
@@ -140,7 +140,7 @@ def set_global_func(head_dim, dtype):
         with target:
             mod = dl.ApplyDefaultSchedule(dl.gpu.Fallback())(mod)
         f = tvm.tir.build(mod["main"], target=target)
-        builts.append(f.entry_func)
+        builts.append(f["main"])
 
     (
         ftranspose_append,
diff --git a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_flashinfer.py b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_flashinfer.py
index ffd3452292..e13ce1ca7b 100644
--- a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_flashinfer.py
+++ b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_flashinfer.py
@@ -156,7 +156,7 @@ def set_global_func():
         with target:
             mod = dl.ApplyDefaultSchedule(dl.gpu.Fallback())(mod)
         f = tvm.tir.build(mod["main"], target=target)
-        builts.append(f.entry_func)
+        builts.append(f["main"])
 
     (
         ftranspose_append,
diff --git a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_mla_flashinfer.py b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_mla_flashinfer.py
index 2f726064a7..53044a786c 100644
--- a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_mla_flashinfer.py
+++ b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_mla_flashinfer.py
@@ -169,7 +169,7 @@ def set_global_func(dtype):
         with target:
             mod = dl.ApplyDefaultSchedule(dl.gpu.Fallback())(mod)
         f = tvm.tir.build(mod["main"], target=target)
-        builts.append(f.entry_func)
+        builts.append(f["main"])
 
     (
         ftranspose_append,
diff --git a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_mla_tir.py b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_mla_tir.py
index b2982abdb0..73a4d89dad 100644
--- a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_mla_tir.py
+++ b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_mla_tir.py
@@ -134,7 +134,7 @@ def set_global_func(dtype):
         with target:
             mod = dl.ApplyDefaultSchedule(dl.gpu.Fallback())(mod)
         f = tvm.tir.build(mod["main"], target=target)
-        builts.append(f.entry_func)
+        builts.append(f["main"])
 
     (
         ftranspose_append,
diff --git a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_tir.py b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_tir.py
index 8cd3a73740..44169828e2 100644
--- a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_tir.py
+++ b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_tir.py
@@ -142,7 +142,7 @@ def set_global_func(head_dim, dtype):
         with target:
             mod = dl.ApplyDefaultSchedule(dl.gpu.Fallback())(mod)
         f = tvm.tir.build(mod["main"], target=target)
-        builts.append(f.entry_func)
+        builts.append(f["main"])
 
     (
         ftranspose_append,
diff --git a/tests/python/relax/test_runtime_builtin_rnn_state.py b/tests/python/relax/test_runtime_builtin_rnn_state.py
index 095aba8b83..fe8c19257d 100644
--- a/tests/python/relax/test_runtime_builtin_rnn_state.py
+++ b/tests/python/relax/test_runtime_builtin_rnn_state.py
@@ -81,7 +81,7 @@ def set_global_func():
         with target:
            mod = dl.ApplyDefaultSchedule(dl.gpu.Fallback())(mod)  # pylint: disable=not-callable
         f = tvm.tir.build(mod["main"], target=target)
-        return f.entry_func
+        return f["main"]
 
     _f_tir_gets, _f_tir_sets = [], []
     for state in states:
diff --git a/tests/python/tir-transform/test_tir_transform_helpers.py b/tests/python/tir-transform/test_tir_transform_helpers.py
index 0bbd0e7160..a7ac11d7a9 100644
--- a/tests/python/tir-transform/test_tir_transform_helpers.py
+++ b/tests/python/tir-transform/test_tir_transform_helpers.py
@@ -21,27 +21,6 @@ from tvm.script import tir as T, ir as I
 import tvm.testing
 
 
-def test_annotate_entry_func_single_primfunc():
-    @tvm.script.ir_module
-    class MockModule:
-        @T.prim_func(private=True)
-        def func1(A: T.Buffer((16,), "float32")):
-            for i in T.serial(16):
-                if i == 5:
-                    if i == 5:
-                        A[i] = 0.0
-
-    mod = MockModule
-    assert mod
-    assert not mod["func1"].attrs
-    after = tvm.tir.transform.AnnotateEntryFunc()(mod)
-    assert (
-        after["func1"].attrs
-        and "tir.is_entry_func" in after["func1"].attrs
-        and after["func1"].attrs["tir.is_entry_func"]
-    )
-
-
 # Test module
 @tvm.script.ir_module
 class MockModule:
@@ -60,16 +39,6 @@ class MockModule:
                     A[i] = 0.0
 
 
[email protected]
-def test_annotate_entry_func_multiple_primfunc():
-    mod = MockModule
-    assert mod
-    assert not mod["func1"].attrs
-    assert not mod["func2"].attrs
-    # This should fail
-    after = tvm.tir.transform.AnnotateEntryFunc()(mod)
-
-
 def test_bind_target():
     mod = MockModule
     assert mod
diff --git a/tests/python/tvmscript/test_tvmscript_roundtrip.py b/tests/python/tvmscript/test_tvmscript_roundtrip.py
index 0e1b328844..73ca1dad3b 100644
--- a/tests/python/tvmscript/test_tvmscript_roundtrip.py
+++ b/tests/python/tvmscript/test_tvmscript_roundtrip.py
@@ -196,7 +196,6 @@ def opt_gemm_mod_host():
             T.func_attr(
                 {
                     "tir.noalias": True,
-                    "tir.is_entry_func": True,
                     "calling_conv": 1,
                 }
             )
@@ -2242,7 +2241,6 @@ def opt_conv_tensorcore_mod_host():
             {
                 "tir.noalias": True,
                 "global_symbol": "default_function",
-                "tir.is_entry_func": True,
                 "calling_conv": 1,
             }
         )

