This is an automated email from the ASF dual-hosted git repository.
wuwei pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 83e7e9b2eb [Debug] Improve error messages in LiftTransformParams (#16802)
83e7e9b2eb is described below
commit 83e7e9b2eb8dbeeb16dcfdbaf3336caa81071877
Author: Eric Lunderberg <[email protected]>
AuthorDate: Wed Mar 27 19:02:39 2024 -0500
[Debug] Improve error messages in LiftTransformParams (#16802)
The `LiftTransformParams` pass requires Relax functions that have the
`attr::kNumInput` attribute (`"num_input"`). By default, it collects
and applies only to functions with this attribute. If the user
specifies functions that don't meet this criterion,
`LiftTransformParams` will raise an error.
This commit improves the error messages that are raised when the
specified function is missing, is not a Relax function, or is missing
the `kNumInput` attribute. Previously these errors were raised
implicitly by `IRModule::GetGlobalVar`, `Downcast<Function>`, or
`Optional::value`, respectively.
---
src/relax/transform/lift_transform_params.cc | 24 ++++++++++++++++++++++--
1 file changed, 22 insertions(+), 2 deletions(-)
diff --git a/src/relax/transform/lift_transform_params.cc
b/src/relax/transform/lift_transform_params.cc
index abf21189e4..7607d690d4 100644
--- a/src/relax/transform/lift_transform_params.cc
+++ b/src/relax/transform/lift_transform_params.cc
@@ -705,8 +705,28 @@ std::vector<std::pair<GlobalVar, Function>>
GetTargetFunctions(
std::vector<std::pair<GlobalVar, Function>> target_functions;
if (shared_transform.as<Array<String>>().value_or(Array<String>{}).size()) {
for (const auto& name : shared_transform.as<Array<String>>().value()) {
- auto gvar = mod->GetGlobalVar(name);
- target_functions.push_back({gvar,
Downcast<Function>(mod->Lookup(gvar))});
+ auto gvar = mod->global_var_map_.Get(name);
+ CHECK(gvar) << "When LiftTransformParams is called with a list of
function names, "
+ << "all function names must occur within the IRModule. "
+ << "However, the IRModule does not contain a function names
'" << name << "'";
+
+ auto base_func = mod->functions.Get(gvar.value());
+ ICHECK(base_func) << "Ill-formed IRModule. "
+ << "The map from name to GlobalVar found " <<
gvar.value()
+ << " for the function name '" << name
+ << "', but this GlobalVar does not appear in the
IRModule";
+
+ auto func = base_func.as<Function>();
+ CHECK(func) << "When LiftTransformParams is called with a list of
function names, "
+ << "only functions in the list must be relax functions. "
+ << "However, the function " << name << " is of type " <<
base_func->GetTypeKey();
+ CHECK(func.value()->GetAttr<Integer>(attr::kNumInput))
+ << "When LiftTransformParams is called with a list of function
names, "
+ << "all functions in the list must have the kNumInput ('" <<
attr::kNumInput
+ << "') attribute. "
+ << "However, the function " << name << " does not have the kNumInput
attribute";
+
+ target_functions.push_back({gvar.value(), func.value()});
}
} else {
// Get all the functions that have the `num_input` attribute.