This is an automated email from the ASF dual-hosted git repository.

wuwei pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 83e7e9b2eb [Debug] Improve error messages in LiftTransformParams (#16802)
83e7e9b2eb is described below

commit 83e7e9b2eb8dbeeb16dcfdbaf3336caa81071877
Author: Eric Lunderberg <[email protected]>
AuthorDate: Wed Mar 27 19:02:39 2024 -0500

    [Debug] Improve error messages in LiftTransformParams (#16802)
    
    The `LiftTransformParams` pass requires Relax functions that have the
    `attr::kNumInput` attribute (`"num_input"`).  By default, it collects
    and applies only to functions with this attribute.  If the user
    explicitly specifies functions that don't meet this criterion,
    `LiftTransformParams` will raise an error.
    
    This commit improves the error messages that are raised when the
    specified function is missing from the IRModule, is not a Relax
    function, or is missing the `kNumInput` attribute.  Previously the
    error messages were raised implicitly by `IRModule::GetGlobalVar`,
    `Downcast<Function>`, or `Optional::value`, respectively.
---
 src/relax/transform/lift_transform_params.cc | 24 ++++++++++++++++++++++--
 1 file changed, 22 insertions(+), 2 deletions(-)

diff --git a/src/relax/transform/lift_transform_params.cc 
b/src/relax/transform/lift_transform_params.cc
index abf21189e4..7607d690d4 100644
--- a/src/relax/transform/lift_transform_params.cc
+++ b/src/relax/transform/lift_transform_params.cc
@@ -705,8 +705,28 @@ std::vector<std::pair<GlobalVar, Function>> 
GetTargetFunctions(
   std::vector<std::pair<GlobalVar, Function>> target_functions;
   if (shared_transform.as<Array<String>>().value_or(Array<String>{}).size()) {
     for (const auto& name : shared_transform.as<Array<String>>().value()) {
-      auto gvar = mod->GetGlobalVar(name);
-      target_functions.push_back({gvar, 
Downcast<Function>(mod->Lookup(gvar))});
+      auto gvar = mod->global_var_map_.Get(name);
+      CHECK(gvar) << "When LiftTransformParams is called with a list of 
function names, "
+                  << "all function names must occur within the IRModule.  "
+                  << "However, the IRModule does not contain a function names 
'" << name << "'";
+
+      auto base_func = mod->functions.Get(gvar.value());
+      ICHECK(base_func) << "Ill-formed IRModule.  "
+                        << "The map from name to GlobalVar found " << 
gvar.value()
+                        << " for the function name '" << name
+                        << "', but this GlobalVar does not appear in the 
IRModule";
+
+      auto func = base_func.as<Function>();
+      CHECK(func) << "When LiftTransformParams is called with a list of 
function names, "
+                  << "only functions in the list must be relax functions.  "
+                  << "However, the function " << name << " is of type " << 
base_func->GetTypeKey();
+      CHECK(func.value()->GetAttr<Integer>(attr::kNumInput))
+          << "When LiftTransformParams is called with a list of function 
names, "
+          << "all functions in the list must have the kNumInput ('" << 
attr::kNumInput
+          << "') attribute.  "
+          << "However, the function " << name << " does not have the kNumInput 
attribute";
+
+      target_functions.push_back({gvar.value(), func.value()});
     }
   } else {
     // Get all the functions that have the `num_input` attribute.

Reply via email to