This is an automated email from the ASF dual-hosted git repository.

wuwei pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 147ed5e27d [Unity][CodeGen] RunCodegen based on externally-exposed 
functions (#16422)
147ed5e27d is described below

commit 147ed5e27d76432ea02ce1cd590ba2cbbd76378e
Author: Eric Lunderberg <[email protected]>
AuthorDate: Mon Jan 29 20:51:34 2024 -0600

    [Unity][CodeGen] RunCodegen based on externally-exposed functions (#16422)
    
    * [IR] Add utility methods to IRModule
    
    * `IRModule.clone`: Clone the module.  While in C++, a module can be
      copied using `IRModule::CopyOnWrite()`, copying a module in Python
      required passing all members into the `IRModule` initializer.  The
      `IRModule.clone` method provides an easier way to copy an `IRModule`
      from python.
    
    * `IRModule.__delitem__`: Remove a function from the module.  This
      exposes the C++ method `IRModuleNode::Remove` for use in the python
      API.  This uses the python `del` keyword, similar to a native python
      list.  Similar to the existing `IRModule.__getitem__`, this can be
      called with either a `GlobalVar` or a python string.
    
    * `IRModule.__contains__`: Check if a function is in the module.  This
      allows the python keyword `in` to check if a module contains a
      specific function.  Similar to the existing `IRModule.__getitem__`,
      this can be called either with a `GlobalVar` (`if gvar in mod`) or
      with a python string (`if "function_name" in mod`).
    
    * [Unity][CodeGen] RunCodegen based on externally-exposed functions
    
    Prior to this commit, `relax.transform.RunCodegen` required a list of
    entry functions for a module, defaulting to `"main"` if not specified.
    The list of entry functions is duplicate information that could be
    inferred from the module, and should not be required from the user.
    This commit updates `RunCodegen` to treat all externally-exposed
    functions as entry points, in the same manner as
    `DeadCodeElimination`.
    
    For backwards compatibility, the `entry_functions` argument is still
    accepted, and is used to augment the list of externally-exposed
    functions.
---
 python/tvm/ir/module.py                           |  9 +++++
 python/tvm/relax/transform/transform.py           |  3 +-
 src/ir/module.cc                                  | 34 ++++++++++++++++++
 src/relax/transform/run_codegen.cc                | 40 +++++++++++++++++----
 tests/python/relax/test_transform_codegen_pass.py | 43 +++++++++++++++++++----
 5 files changed, 115 insertions(+), 14 deletions(-)

diff --git a/python/tvm/ir/module.py b/python/tvm/ir/module.py
index a3e097947c..ea3ef6d883 100644
--- a/python/tvm/ir/module.py
+++ b/python/tvm/ir/module.py
@@ -80,6 +80,9 @@ class IRModule(Node, Scriptable):
             global_infos,
         )
 
+    def clone(self) -> "IRModule":
+        return _ffi_api.Module_Clone(self)
+
     def functions_items(self):
         """Get items in self.functions.items() in alphabetical order.
 
@@ -138,6 +141,12 @@ class IRModule(Node, Scriptable):
             return _ffi_api.Module_Lookup(self, var)
         return _ffi_api.Module_LookupDef(self, var)
 
+    def __delitem__(self, var: Union[str, _expr.GlobalVar]):
+        _ffi_api.Module_Remove(self, var)
+
+    def __contains__(self, var: Union[str, _expr.GlobalVar]) -> bool:
+        return _ffi_api.Module_Contains(self, var)
+
     def update(self, other):
         """Insert functions in another Module to current one.
 
diff --git a/python/tvm/relax/transform/transform.py 
b/python/tvm/relax/transform/transform.py
index 1f390adb2e..e360c09392 100644
--- a/python/tvm/relax/transform/transform.py
+++ b/python/tvm/relax/transform/transform.py
@@ -574,7 +574,8 @@ def RunCodegen(
         The registered pass to remove unused functions.
     """
     if entry_functions is None:
-        entry_functions = ["main"]
+        entry_functions = []
+
     # enable cutlass byoc registries
     # pylint: disable=unused-import,import-outside-toplevel
     from tvm.contrib import cutlass as _cutlass
diff --git a/src/ir/module.cc b/src/ir/module.cc
index 156158a85f..2e60441e94 100644
--- a/src/ir/module.cc
+++ b/src/ir/module.cc
@@ -413,6 +413,12 @@ TVM_REGISTER_GLOBAL("ir.IRModule")
       return IRModule(funcs, types, {}, {}, dict_attrs, global_infos);
     });
 
+TVM_REGISTER_GLOBAL("ir.Module_Clone").set_body_typed([](IRModule mod) -> 
IRModule {
+  IRModule clone = mod;
+  clone.CopyOnWrite();
+  return clone;
+});
+
 TVM_REGISTER_GLOBAL("ir.Module_Add")
     .set_body_typed([](IRModule mod, GlobalVar var, ObjectRef val, bool 
update) -> IRModule {
       ICHECK(val->IsInstance<RelayExprNode>());
@@ -423,6 +429,34 @@ TVM_REGISTER_GLOBAL("ir.Module_Add")
       return mod;
     });
 
+TVM_REGISTER_GLOBAL("ir.Module_Remove")
+    .set_body_typed([](IRModule mod, Variant<String, GlobalVar> var) -> 
IRModule {
+      GlobalVar gvar = [&]() {
+        if (auto opt = var.as<GlobalVar>()) {
+          return opt.value();
+        } else if (auto opt = var.as<String>()) {
+          return mod->GetGlobalVar(opt.value());
+        } else {
+          LOG(FATAL) << "InternalError: "
+                     << "Variant didn't contain any of the allowed types";
+        }
+      }();
+      mod->Remove(gvar);
+      return mod;
+    });
+
+TVM_REGISTER_GLOBAL("ir.Module_Contains")
+    .set_body_typed([](IRModule mod, Variant<String, GlobalVar> var) -> bool {
+      if (auto opt = var.as<GlobalVar>()) {
+        return mod->functions.count(opt.value());
+      } else if (auto opt = var.as<String>()) {
+        return mod->global_var_map_.count(opt.value());
+      } else {
+        LOG(FATAL) << "InternalError: "
+                   << "Variant didn't contain any of the allowed types";
+      }
+    });
+
 
TVM_REGISTER_GLOBAL("ir.Module_AddDef").set_body_method<IRModule>(&IRModuleNode::AddTypeDef);
 
 TVM_REGISTER_GLOBAL("ir.Module_GetGlobalVar")
diff --git a/src/relax/transform/run_codegen.cc 
b/src/relax/transform/run_codegen.cc
index 9955b5f483..c385ae46ef 100644
--- a/src/relax/transform/run_codegen.cc
+++ b/src/relax/transform/run_codegen.cc
@@ -28,6 +28,7 @@
 
 #include <iostream>
 
+#include "../../support/ordered_set.h"
 #include "utils.h"
 
 namespace tvm {
@@ -39,12 +40,39 @@ class CodeGenRunner : ExprMutator {
 
   explicit CodeGenRunner(IRModule mod) : ExprMutator(mod) {}
 
-  IRModule Run(Optional<Map<String, OptionMap>> target_options, Array<String> 
entry_functions) {
+  IRModule Run(Optional<Map<String, OptionMap>> target_options,
+               Array<String> entry_function_names) {
     IRModule mod = builder_->GetContextIRModule();
-    for (const String& entry_func_name : entry_functions) {
-      auto entry_func = mod->Lookup(entry_func_name);
-      auto gvar = mod->GetGlobalVar(entry_func_name);
-      builder_->UpdateFunction(gvar, 
Downcast<BaseFunc>(VisitExpr(entry_func)));
+
+    support::OrderedSet<GlobalVar> entry_functions;
+    // Any user-provided functions are treated as entry functions.
+    for (const auto& name : entry_function_names) {
+      entry_functions.insert(mod->GetGlobalVar(name));
+    }
+
+    // In addition, any externally-exposed function that does not
+    // belong to a specific codegen may be an entry function.  These
+    // are added in alphabetical order, to ensure consistent order of
+    // evaluation for debug/test purposes.
+    {
+      std::vector<GlobalVar> attr_entry_functions;
+      for (const auto& [gv, func] : mod->functions) {
+        if (func->GetLinkageType() == LinkageType::kExternal &&
+            !func->GetAttr<String>(attr::kCodegen) && 
func->IsInstance<relax::FunctionNode>()) {
+          attr_entry_functions.push_back(gv);
+        }
+      }
+      std::sort(attr_entry_functions.begin(), attr_entry_functions.end(),
+                [](const auto& gvar_a, const auto& gvar_b) {
+                  return gvar_a->name_hint > gvar_b->name_hint;
+                });
+      for (const auto& gvar : attr_entry_functions) {
+        entry_functions.insert(gvar);
+      }
+    }
+
+    for (const auto& gvar : entry_functions) {
+      builder_->UpdateFunction(gvar, 
Downcast<BaseFunc>(VisitExpr(mod->Lookup(gvar))));
     }
 
     auto ext_mods = InvokeCodegen(mod, target_options.value_or({}));
@@ -65,7 +93,7 @@ class CodeGenRunner : ExprMutator {
     }
 
     // TODO(@tvm-team): Implicit pass dependency. Revisit when we have a 
better way to handle this.
-    return DeadCodeElimination(out_mod, entry_functions);
+    return DeadCodeElimination(out_mod, entry_function_names);
   }
 
   using ExprMutator::VisitExpr_;
diff --git a/tests/python/relax/test_transform_codegen_pass.py 
b/tests/python/relax/test_transform_codegen_pass.py
index d103291388..cc8f390b96 100644
--- a/tests/python/relax/test_transform_codegen_pass.py
+++ b/tests/python/relax/test_transform_codegen_pass.py
@@ -48,21 +48,21 @@ target = tvm.target.Target(target_str)
 dev = tvm.cuda()
 
 
-def check_executable(exec, dev, inputs, expected):
+def check_executable(exec, dev, inputs, expected, entry_func_name):
     vm = relax.VirtualMachine(exec, dev)
-    out = vm["main"](*inputs)
+    out = vm[entry_func_name](*inputs)
     tvm.testing.assert_allclose(out.numpy(), expected.numpy(), atol=1e-5, 
rtol=1e-5)
 
 
-def check_roundtrip(exec0, dev, inputs, expected):
+def check_roundtrip(exec0, dev, inputs, expected, entry_func_name="main"):
     exec0.mod.export_library("exec.so")
     exec1 = tvm.runtime.load_module("exec.so")
     os.remove("exec.so")
     assert exec0.stats() == exec1["stats"]()
     assert exec0.as_text() == exec1["as_text"]()
 
-    check_executable(exec0, dev, inputs, expected)
-    check_executable(exec1, dev, inputs, expected)
+    check_executable(exec0, dev, inputs, expected, entry_func_name)
+    check_executable(exec1, dev, inputs, expected, entry_func_name)
 
 
 def gen_ground_truth(mod, target, dev, inputs):
@@ -113,10 +113,17 @@ def setup_test():
     return mod, inputs, expected
 
 
+entry_func_name = tvm.testing.parameter("main", "func")
+
+
 @tvm.testing.requires_gpu
-def test_tensorrt_only():
+def test_tensorrt_only(entry_func_name):
     mod, inputs, expected = setup_test()
 
+    if entry_func_name != "main":
+        mod[entry_func_name] = mod
+        del mod["main"]
+
     # Define patterns that we want to offload to byoc
     # This test will offload entire model
     # Thus, define patterns for both `multiply` and `add` ops
@@ -135,7 +142,7 @@ def test_tensorrt_only():
 
     ex0 = relax.build(new_mod, target, params={})
     # Sanity check for the correctness and roundtrip
-    check_roundtrip(ex0, dev, inputs, expected)
+    check_roundtrip(ex0, dev, inputs, expected, entry_func_name)
 
 
 @tvm.testing.requires_gpu
@@ -248,6 +255,28 @@ def test_multiple_calls_same_extern():
     tvm.ir.assert_structural_equal(mod["main"], Conv2dx2_after["main"])
 
 
+def test_default_entry_func():
+    """The entry function is not necessarily named "main"
+
+    Like `test_multiple_calls_same_extern`, but the main function is
+    named "func".
+    """
+    before_with_main = Conv2dx2
+    after_with_main = relax.transform.RunCodegen()(before_with_main)
+
+    def rename_main(mod):
+        mod = mod.clone()
+        mod["func"] = mod["main"].with_attr("global_symbol", "func")
+        del mod["main"]
+        return mod
+
+    before_with_func = rename_main(before_with_main)
+    expected_with_func = rename_main(after_with_main)
+    after_with_func = relax.transform.RunCodegen()(before_with_func)
+
+    tvm.ir.assert_structural_equal(expected_with_func["func"], 
after_with_func["func"])
+
+
 def test_dynamic_shape():
     import tvm.relax.backend.contrib.cublas
 

Reply via email to