This is an automated email from the ASF dual-hosted git repository.
wuwei pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 147ed5e27d [Unity][CodeGen] RunCodegen based on externally-exposed
functions (#16422)
147ed5e27d is described below
commit 147ed5e27d76432ea02ce1cd590ba2cbbd76378e
Author: Eric Lunderberg <[email protected]>
AuthorDate: Mon Jan 29 20:51:34 2024 -0600
[Unity][CodeGen] RunCodegen based on externally-exposed functions (#16422)
* [IR] Add utility methods to IRModule
* `IRModule.clone`: Clone the module. While in C++, a module can be
copied using `IRModule::CopyOnWrite()`, copying a module in Python
required passing all members into the `IRModule` initializer. The
`IRModule.clone` method provides an easier way to copy an `IRModule`
from python.
* `IRModule.__delitem__`: Remove a function from the module. This
exposes the C++ method `IRModuleNode::Remove` for use in the python
API. This uses the python `del` keyword, similar to a native python
list. Similar to the existing `IRModule.__getitem__`, this can be
called with either a `GlobalVar` or a python string.
* `IRModule.__contains__`: Check if a function is in the module. This
allows the python keyword `in` to check if a module contains a
specific function. Similar to the existing `IRModule.__getitem__`,
this can be called either with a `GlobalVar` (`if gvar in mod`) or
with a python string (`if "function_name" in mod`).
* [Unity][CodeGen] RunCodegen based on externally-exposed functions
Prior to this commit, `relax.transform.RunCodegen` required a list of
entry functions for a module, defaulting to `"main"` if not specified.
The list of entry functions is duplicate information that could be
inferred from the module, and should not be required from the user.
This commit updates `RunCodegen` to treat all externally-exposed
functions as entry points, in the same manner as
`DeadCodeElimination`.
For backwards compatibility, the `entry_functions` argument is still
accepted, and is used to augment the list of externally-exposed
functions.
---
python/tvm/ir/module.py | 9 +++++
python/tvm/relax/transform/transform.py | 3 +-
src/ir/module.cc | 34 ++++++++++++++++++
src/relax/transform/run_codegen.cc | 40 +++++++++++++++++----
tests/python/relax/test_transform_codegen_pass.py | 43 +++++++++++++++++++----
5 files changed, 115 insertions(+), 14 deletions(-)
diff --git a/python/tvm/ir/module.py b/python/tvm/ir/module.py
index a3e097947c..ea3ef6d883 100644
--- a/python/tvm/ir/module.py
+++ b/python/tvm/ir/module.py
@@ -80,6 +80,9 @@ class IRModule(Node, Scriptable):
global_infos,
)
+ def clone(self) -> "IRModule":
+ return _ffi_api.Module_Clone(self)
+
def functions_items(self):
"""Get items in self.functions.items() in alphabetical order.
@@ -138,6 +141,12 @@ class IRModule(Node, Scriptable):
return _ffi_api.Module_Lookup(self, var)
return _ffi_api.Module_LookupDef(self, var)
+ def __delitem__(self, var: Union[str, _expr.GlobalVar]):
+ _ffi_api.Module_Remove(self, var)
+
+ def __contains__(self, var: Union[str, _expr.GlobalVar]) -> bool:
+ return _ffi_api.Module_Contains(self, var)
+
def update(self, other):
"""Insert functions in another Module to current one.
diff --git a/python/tvm/relax/transform/transform.py
b/python/tvm/relax/transform/transform.py
index 1f390adb2e..e360c09392 100644
--- a/python/tvm/relax/transform/transform.py
+++ b/python/tvm/relax/transform/transform.py
@@ -574,7 +574,8 @@ def RunCodegen(
The registered pass to remove unused functions.
"""
if entry_functions is None:
- entry_functions = ["main"]
+ entry_functions = []
+
# enable cutlass byoc registries
# pylint: disable=unused-import,import-outside-toplevel
from tvm.contrib import cutlass as _cutlass
diff --git a/src/ir/module.cc b/src/ir/module.cc
index 156158a85f..2e60441e94 100644
--- a/src/ir/module.cc
+++ b/src/ir/module.cc
@@ -413,6 +413,12 @@ TVM_REGISTER_GLOBAL("ir.IRModule")
return IRModule(funcs, types, {}, {}, dict_attrs, global_infos);
});
+TVM_REGISTER_GLOBAL("ir.Module_Clone").set_body_typed([](IRModule mod) ->
IRModule {
+ IRModule clone = mod;
+ clone.CopyOnWrite();
+ return clone;
+});
+
TVM_REGISTER_GLOBAL("ir.Module_Add")
.set_body_typed([](IRModule mod, GlobalVar var, ObjectRef val, bool
update) -> IRModule {
ICHECK(val->IsInstance<RelayExprNode>());
@@ -423,6 +429,34 @@ TVM_REGISTER_GLOBAL("ir.Module_Add")
return mod;
});
+TVM_REGISTER_GLOBAL("ir.Module_Remove")
+ .set_body_typed([](IRModule mod, Variant<String, GlobalVar> var) ->
IRModule {
+ GlobalVar gvar = [&]() {
+ if (auto opt = var.as<GlobalVar>()) {
+ return opt.value();
+ } else if (auto opt = var.as<String>()) {
+ return mod->GetGlobalVar(opt.value());
+ } else {
+ LOG(FATAL) << "InternalError: "
+ << "Variant didn't contain any of the allowed types";
+ }
+ }();
+ mod->Remove(gvar);
+ return mod;
+ });
+
+TVM_REGISTER_GLOBAL("ir.Module_Contains")
+ .set_body_typed([](IRModule mod, Variant<String, GlobalVar> var) -> bool {
+ if (auto opt = var.as<GlobalVar>()) {
+ return mod->functions.count(opt.value());
+ } else if (auto opt = var.as<String>()) {
+ return mod->global_var_map_.count(opt.value());
+ } else {
+ LOG(FATAL) << "InternalError: "
+ << "Variant didn't contain any of the allowed types";
+ }
+ });
+
TVM_REGISTER_GLOBAL("ir.Module_AddDef").set_body_method<IRModule>(&IRModuleNode::AddTypeDef);
TVM_REGISTER_GLOBAL("ir.Module_GetGlobalVar")
diff --git a/src/relax/transform/run_codegen.cc
b/src/relax/transform/run_codegen.cc
index 9955b5f483..c385ae46ef 100644
--- a/src/relax/transform/run_codegen.cc
+++ b/src/relax/transform/run_codegen.cc
@@ -28,6 +28,7 @@
#include <iostream>
+#include "../../support/ordered_set.h"
#include "utils.h"
namespace tvm {
@@ -39,12 +40,39 @@ class CodeGenRunner : ExprMutator {
explicit CodeGenRunner(IRModule mod) : ExprMutator(mod) {}
- IRModule Run(Optional<Map<String, OptionMap>> target_options, Array<String>
entry_functions) {
+ IRModule Run(Optional<Map<String, OptionMap>> target_options,
+ Array<String> entry_function_names) {
IRModule mod = builder_->GetContextIRModule();
- for (const String& entry_func_name : entry_functions) {
- auto entry_func = mod->Lookup(entry_func_name);
- auto gvar = mod->GetGlobalVar(entry_func_name);
- builder_->UpdateFunction(gvar,
Downcast<BaseFunc>(VisitExpr(entry_func)));
+
+ support::OrderedSet<GlobalVar> entry_functions;
+ // Any user-provided functions are treated as entry functions.
+ for (const auto& name : entry_function_names) {
+ entry_functions.insert(mod->GetGlobalVar(name));
+ }
+
+ // In addtion, any externally-exposed function that does not
+ // belong to a specific codegen may be an entry function. These
+ // are added in alphabetical order, to ensure consistent order of
+ // evaluation for debug/test purposes.
+ {
+ std::vector<GlobalVar> attr_entry_functions;
+ for (const auto& [gv, func] : mod->functions) {
+ if (func->GetLinkageType() == LinkageType::kExternal &&
+ !func->GetAttr<String>(attr::kCodegen) &&
func->IsInstance<relax::FunctionNode>()) {
+ attr_entry_functions.push_back(gv);
+ }
+ }
+ std::sort(attr_entry_functions.begin(), attr_entry_functions.end(),
+ [](const auto& gvar_a, const auto& gvar_b) {
+ return gvar_a->name_hint > gvar_b->name_hint;
+ });
+ for (const auto& gvar : attr_entry_functions) {
+ entry_functions.insert(gvar);
+ }
+ }
+
+ for (const auto& gvar : entry_functions) {
+ builder_->UpdateFunction(gvar,
Downcast<BaseFunc>(VisitExpr(mod->Lookup(gvar))));
}
auto ext_mods = InvokeCodegen(mod, target_options.value_or({}));
@@ -65,7 +93,7 @@ class CodeGenRunner : ExprMutator {
}
// TODO(@tvm-team): Implicit pass dependency. Revisit when we have a
better way to handle this.
- return DeadCodeElimination(out_mod, entry_functions);
+ return DeadCodeElimination(out_mod, entry_function_names);
}
using ExprMutator::VisitExpr_;
diff --git a/tests/python/relax/test_transform_codegen_pass.py
b/tests/python/relax/test_transform_codegen_pass.py
index d103291388..cc8f390b96 100644
--- a/tests/python/relax/test_transform_codegen_pass.py
+++ b/tests/python/relax/test_transform_codegen_pass.py
@@ -48,21 +48,21 @@ target = tvm.target.Target(target_str)
dev = tvm.cuda()
-def check_executable(exec, dev, inputs, expected):
+def check_executable(exec, dev, inputs, expected, entry_func_name):
vm = relax.VirtualMachine(exec, dev)
- out = vm["main"](*inputs)
+ out = vm[entry_func_name](*inputs)
tvm.testing.assert_allclose(out.numpy(), expected.numpy(), atol=1e-5,
rtol=1e-5)
-def check_roundtrip(exec0, dev, inputs, expected):
+def check_roundtrip(exec0, dev, inputs, expected, entry_func_name="main"):
exec0.mod.export_library("exec.so")
exec1 = tvm.runtime.load_module("exec.so")
os.remove("exec.so")
assert exec0.stats() == exec1["stats"]()
assert exec0.as_text() == exec1["as_text"]()
- check_executable(exec0, dev, inputs, expected)
- check_executable(exec1, dev, inputs, expected)
+ check_executable(exec0, dev, inputs, expected, entry_func_name)
+ check_executable(exec1, dev, inputs, expected, entry_func_name)
def gen_ground_truth(mod, target, dev, inputs):
@@ -113,10 +113,17 @@ def setup_test():
return mod, inputs, expected
+entry_func_name = tvm.testing.parameter("main", "func")
+
+
@tvm.testing.requires_gpu
-def test_tensorrt_only():
+def test_tensorrt_only(entry_func_name):
mod, inputs, expected = setup_test()
+ if entry_func_name != "main":
+ mod[entry_func_name] = mod
+ del mod["main"]
+
# Define patterns that we want to offload to byoc
# This test will offload entire model
# Thus, define patterns for both `multiply` and `add` ops
@@ -135,7 +142,7 @@ def test_tensorrt_only():
ex0 = relax.build(new_mod, target, params={})
# Sanity check for the correctness and roundtrip
- check_roundtrip(ex0, dev, inputs, expected)
+ check_roundtrip(ex0, dev, inputs, expected, entry_func_name)
@tvm.testing.requires_gpu
@@ -248,6 +255,28 @@ def test_multiple_calls_same_extern():
tvm.ir.assert_structural_equal(mod["main"], Conv2dx2_after["main"])
+def test_default_entry_func():
+ """The entry function is not necessarily named "main"
+
+ Like `test_multiple_calls_same_extern`, but the main function is
+ named "func".
+ """
+ before_with_main = Conv2dx2
+ after_with_main = relax.transform.RunCodegen()(before_with_main)
+
+ def rename_main(mod):
+ mod = mod.clone()
+ mod["func"] = mod["main"].with_attr("global_symbol", "func")
+ del mod["main"]
+ return mod
+
+ before_with_func = rename_main(before_with_main)
+ expected_with_func = rename_main(after_with_main)
+ after_with_func = relax.transform.RunCodegen()(before_with_func)
+
+ tvm.ir.assert_structural_equal(expected_with_func["func"],
after_with_func["func"])
+
+
def test_dynamic_shape():
import tvm.relax.backend.contrib.cublas