This is an automated email from the ASF dual-hosted git repository.
masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new f60b08c9a4 [QoL][IR] Provide default constructor for
NameSupply/GlobalVarSupply (#17135)
f60b08c9a4 is described below
commit f60b08c9a421d24c7627038064526b5cd7e2610a
Author: Eric Lunderberg <[email protected]>
AuthorDate: Fri Jul 12 18:41:57 2024 -0500
[QoL][IR] Provide default constructor for NameSupply/GlobalVarSupply
(#17135)
Prior to this commit, a `tvm::NameSupply` had to be constructed
with an explicit `const String& prefix` argument. Omitting this
argument fell back to the default constructor provided by the
`TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS` macro, producing a
`NameSupply` that holds a nullptr. Any later use of that null
`NameSupply` then causes a segfault.
The vast majority of call sites of the `NameSupply` constructor
(29 out of 31) pass an empty `prefix` string. The remaining two
pass a non-empty `prefix` string. No call site intentionally
constructs a null `NameSupply`.
This commit updates `NameSupply` to use the
`TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS` macro instead of
`TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS`. This lets the default
constructor produce the common case, a `NameSupply` with an empty
prefix, instead of the error-prone null `NameSupply`.
The same change is made for `GlobalVarSupply`, since the majority
of its uses (11 out of 13) also start from an empty prefix.
---
include/tvm/ir/global_var_supply.h | 7 ++++---
include/tvm/ir/name_supply.h | 4 ++--
src/auto_scheduler/feature.cc | 3 +--
src/contrib/hybrid/codegen_hybrid.h | 2 +-
src/driver/driver_api.cc | 6 ++----
src/ir/global_var_supply.cc | 2 +-
src/relax/backend/contrib/cutlass/codegen.cc | 2 +-
src/relax/ir/block_builder.cc | 3 +--
src/relax/transform/allocate_workspace.cc | 2 +-
src/relax/transform/normalize.cc | 2 +-
src/relay/backend/graph_executor_codegen.cc | 2 +-
src/relay/backend/task_extraction.cc | 4 ++--
src/relay/backend/te_compiler.cc | 5 ++---
src/relay/backend/te_compiler_cache.cc | 4 ++--
src/relay/backend/te_compiler_cache.h | 2 +-
src/target/source/codegen_c.h | 2 +-
src/target/source/codegen_source_base.cc | 2 +-
src/target/source/codegen_source_base.h | 2 +-
src/te/operation/create_primfunc.cc | 2 +-
src/tir/ir/index_map.cc | 2 +-
tests/cpp/build_module_test.cc | 4 ++--
tests/cpp/c_codegen_test.cc | 6 ++----
tests/cpp/name_supply_test.cc | 4 ++--
23 files changed, 34 insertions(+), 40 deletions(-)
diff --git a/include/tvm/ir/global_var_supply.h
b/include/tvm/ir/global_var_supply.h
index 276c64a0d7..9ce0da5e02 100644
--- a/include/tvm/ir/global_var_supply.h
+++ b/include/tvm/ir/global_var_supply.h
@@ -41,7 +41,7 @@ class GlobalVarSupplyNode : public Object {
/*!
* \brief Empty constructor. Will use an empty NameSupply.
*/
- GlobalVarSupplyNode() : GlobalVarSupplyNode(NameSupply("")) {}
+ GlobalVarSupplyNode() : GlobalVarSupplyNode(NameSupply()) {}
/*!
* \brief Constructor.
@@ -100,7 +100,7 @@ class GlobalVarSupply : public ObjectRef {
* \param name_supply The NameSupply to be used when generating new
GlobalVars.
* \param name_to_var_map An optional map.
*/
- TVM_DLL explicit GlobalVarSupply(const NameSupply& name_supply,
+ TVM_DLL explicit GlobalVarSupply(const NameSupply& name_supply =
NameSupply(),
std::unordered_map<std::string, GlobalVar>
name_to_var_map = {});
/*!
@@ -117,7 +117,8 @@ class GlobalVarSupply : public ObjectRef {
*/
TVM_DLL explicit GlobalVarSupply(const IRModule module);
- TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(GlobalVarSupply, ObjectRef,
GlobalVarSupplyNode);
+ TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(GlobalVarSupply, ObjectRef,
+ GlobalVarSupplyNode);
};
} // namespace tvm
diff --git a/include/tvm/ir/name_supply.h b/include/tvm/ir/name_supply.h
index f2c9af4926..11dac3fe52 100644
--- a/include/tvm/ir/name_supply.h
+++ b/include/tvm/ir/name_supply.h
@@ -116,7 +116,7 @@ class NameSupply : public ObjectRef {
* \param prefix The prefix to be used with this NameSupply.
* \param name_map An optional map.
*/
- TVM_DLL explicit NameSupply(const String& prefix,
+ TVM_DLL explicit NameSupply(const String& prefix = "",
std::unordered_map<std::string, int> name_map =
{});
/*!
@@ -129,7 +129,7 @@ class NameSupply : public ObjectRef {
TVM_DLL explicit NameSupply(Iter begin, Iter end, Lambda f)
: NameSupply("", GetNameMap(begin, end, f)) {}
- TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(NameSupply, ObjectRef, NameSupplyNode);
+ TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(NameSupply, ObjectRef,
NameSupplyNode);
private:
template <typename Iter, typename Lambda>
diff --git a/src/auto_scheduler/feature.cc b/src/auto_scheduler/feature.cc
index 65cc13eb61..09255b5da5 100644
--- a/src/auto_scheduler/feature.cc
+++ b/src/auto_scheduler/feature.cc
@@ -1375,8 +1375,7 @@ void GetPerStoreFeaturesWorkerFunc(const SearchTask&
task, const State& state, i
auto pass_ctx = tvm::transform::PassContext::Current();
auto mod = ScheduleToModule(sch, Array<ObjectRef>{tensors.begin(),
tensors.end()}, name,
- std::unordered_map<te::Tensor, te::Buffer>(),
- GlobalVarSupply(NameSupply("")));
+ std::unordered_map<te::Tensor, te::Buffer>(),
GlobalVarSupply());
bool disable_vectorize =
pass_ctx->GetConfig<Bool>("tir.disable_vectorize",
Bool(false)).value();
diff --git a/src/contrib/hybrid/codegen_hybrid.h
b/src/contrib/hybrid/codegen_hybrid.h
index d1f578efdd..58be2cf112 100644
--- a/src/contrib/hybrid/codegen_hybrid.h
+++ b/src/contrib/hybrid/codegen_hybrid.h
@@ -145,7 +145,7 @@ class CodeGenHybrid : public ExprFunctor<void(const
PrimExpr&, std::ostream&)>,
/*! \brief Print the current indent spaces. */
inline void PrintIndent();
/*! \brief NameSupply for allocated ids. */
- NameSupply ids_allocated = NameSupply("");
+ NameSupply ids_allocated;
/*!
* \brief Keys are either (tensors, value_index) or (variables, 0).
* Values are the corresponding IDs.*/
diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc
index 3026f6e58f..105ac063e0 100644
--- a/src/driver/driver_api.cc
+++ b/src/driver/driver_api.cc
@@ -336,8 +336,7 @@ TVM_REGISTER_GLOBAL("driver.schedule_to_module")
c_binds.insert({kv.first, kv.second});
}
}
- IRModule mod =
- ScheduleToModule(std::move(sch), args, name, c_binds,
GlobalVarSupply(NameSupply("")));
+ IRModule mod = ScheduleToModule(std::move(sch), args, name, c_binds,
GlobalVarSupply());
return mod;
});
@@ -400,8 +399,7 @@ TVM_REGISTER_GLOBAL("driver.lower_schedule")
c_binds.insert({kv.first, kv.second});
}
}
- return LowerSchedule(std::move(sch), args, name, c_binds,
GlobalVarSupply(NameSupply("")),
- simple_mode);
+ return LowerSchedule(std::move(sch), args, name, c_binds,
GlobalVarSupply(), simple_mode);
});
/**
diff --git a/src/ir/global_var_supply.cc b/src/ir/global_var_supply.cc
index 383d4445ad..571a7f304c 100644
--- a/src/ir/global_var_supply.cc
+++ b/src/ir/global_var_supply.cc
@@ -40,7 +40,7 @@ std::string GetModuleName(const IRModule& module) {
return
module->GetAttr<String>(tvm::attr::kModuleName).value_or("tvmgen_default");
}
-GlobalVarSupply::GlobalVarSupply(const Array<IRModule>& modules) :
GlobalVarSupply(NameSupply("")) {
+GlobalVarSupply::GlobalVarSupply(const Array<IRModule>& modules) :
GlobalVarSupply() {
if (!modules.empty()) {
IRModule first_mod = modules.front();
this->operator->()->name_supply_->prefix_ = GetModuleName(first_mod);
diff --git a/src/relax/backend/contrib/cutlass/codegen.cc
b/src/relax/backend/contrib/cutlass/codegen.cc
index d4b0038be3..8ae0036db7 100644
--- a/src/relax/backend/contrib/cutlass/codegen.cc
+++ b/src/relax/backend/contrib/cutlass/codegen.cc
@@ -52,7 +52,7 @@ class CodegenCutlass : public
relax::MemoizedExprTranslator<OutputType>,
public relay::contrib::CodegenCBase {
public:
CodegenCutlass(const std::string& id, const Map<Var, Expr>& bindings)
- : ext_func_id_(id), bindings_(bindings), name_sup_("") {}
+ : ext_func_id_(id), bindings_(bindings) {}
void AddParm(Var param) {
ext_func_args_.push_back(param);
diff --git a/src/relax/ir/block_builder.cc b/src/relax/ir/block_builder.cc
index e9a513c317..f6aec79a4a 100644
--- a/src/relax/ir/block_builder.cc
+++ b/src/relax/ir/block_builder.cc
@@ -58,8 +58,7 @@ namespace relax {
//---------------------------------------
class BlockBuilderImpl : public BlockBuilderNode {
public:
- explicit BlockBuilderImpl(IRModule context_mod)
- : name_supply_(""), context_mod_(std::move(context_mod)) {}
+ explicit BlockBuilderImpl(IRModule context_mod) :
context_mod_(std::move(context_mod)) {}
~BlockBuilderImpl() {
if (!block_stack_.empty()) {
diff --git a/src/relax/transform/allocate_workspace.cc
b/src/relax/transform/allocate_workspace.cc
index fcfbf18771..1d4a017712 100644
--- a/src/relax/transform/allocate_workspace.cc
+++ b/src/relax/transform/allocate_workspace.cc
@@ -37,7 +37,7 @@ class ExternFunctionRewriter : ExprMutator {
using ExprMutator::VisitExpr_;
ExternFunctionRewriter(IRModule mod, size_t max_workspace_size)
- : ExprMutator(mod), name_sup_(""),
max_workspace_size_(max_workspace_size) {}
+ : ExprMutator(mod), max_workspace_size_(max_workspace_size) {}
std::unordered_map<const GlobalVarNode*, Function> Run() {
std::unordered_map<const GlobalVarNode*, Function> ret;
diff --git a/src/relax/transform/normalize.cc b/src/relax/transform/normalize.cc
index 0939674e81..89080ebc3e 100644
--- a/src/relax/transform/normalize.cc
+++ b/src/relax/transform/normalize.cc
@@ -178,7 +178,7 @@ class GlobalVarNormalizer : private ExprMutator {
}
private:
- explicit GlobalVarNormalizer(const IRModule& m) : ExprMutator(), module_(m),
name_supply_("") {}
+ explicit GlobalVarNormalizer(const IRModule& m) : ExprMutator(), module_(m)
{}
using ExprMutator::VisitExpr_;
diff --git a/src/relay/backend/graph_executor_codegen.cc
b/src/relay/backend/graph_executor_codegen.cc
index 868173d28c..734b3d6e43 100644
--- a/src/relay/backend/graph_executor_codegen.cc
+++ b/src/relay/backend/graph_executor_codegen.cc
@@ -622,7 +622,7 @@ class GraphExecutorCodegen : public
backend::MemoizedExprTranslator<std::vector<
/*! \brief function metadata */
Map<String, FunctionInfo> function_metadata_;
/*! \brief NameSupply */
- NameSupply name_supply_ = NameSupply("");
+ NameSupply name_supply_;
};
class GraphExecutorCodegenModule : public runtime::ModuleNode {
diff --git a/src/relay/backend/task_extraction.cc
b/src/relay/backend/task_extraction.cc
index fc45311e08..6ac7a99d35 100644
--- a/src/relay/backend/task_extraction.cc
+++ b/src/relay/backend/task_extraction.cc
@@ -75,7 +75,7 @@ Array<meta_schedule::ExtractedTask> ExtractTask(IRModule mod,
Target target,
std::vector<std::tuple<std::string, Function, IRModule>> lower_results;
- NameSupply constant_name_supply("");
+ NameSupply constant_name_supply;
PostOrderVisit(mod->Lookup("main"), [&](const Expr& exp) {
if (exp->IsInstance<FunctionNode>()) {
@@ -129,7 +129,7 @@ Array<meta_schedule::ExtractedTask> ExtractTask(IRModule
mod, Target target,
// Tasks are extracted via post order visit, return the reversed list.
std::reverse(tasks.begin(), tasks.end());
- NameSupply name_supply = NameSupply("");
+ NameSupply name_supply;
for (ExtractedTask task : tasks) {
task->task_name = name_supply->FreshName(task->task_name);
}
diff --git a/src/relay/backend/te_compiler.cc b/src/relay/backend/te_compiler.cc
index 8165954749..eab4837ba8 100644
--- a/src/relay/backend/te_compiler.cc
+++ b/src/relay/backend/te_compiler.cc
@@ -136,8 +136,7 @@ TVM_REGISTER_OBJECT_TYPE(TECompilerNode);
class TECompilerImpl : public TECompilerNode {
public:
explicit TECompilerImpl(Optional<IRModule> opt_mod, Optional<String>
opt_mod_name)
- :
global_var_supply_(GlobalVarSupply(NameSupply(opt_mod_name.value_or("")))),
- constant_name_supply_(NameSupply("")) {
+ :
global_var_supply_(GlobalVarSupply(NameSupply(opt_mod_name.value_or("")))) {
// Make sure we don't collide with any existing globals in the module.
if (opt_mod) {
for (const auto& kv : opt_mod.value()->functions) {
@@ -160,7 +159,7 @@ class TECompilerImpl : public TECompilerNode {
// For now, build one module per function.
PackedFunc JIT(const CCacheKey& key) final {
- CCacheValue value = LowerInternal(key, GlobalVarSupply(NameSupply("")));
+ CCacheValue value = LowerInternal(key, GlobalVarSupply());
if (value->packed_func != nullptr) {
return value->packed_func;
}
diff --git a/src/relay/backend/te_compiler_cache.cc
b/src/relay/backend/te_compiler_cache.cc
index 2655cf6671..79a41ae050 100644
--- a/src/relay/backend/te_compiler_cache.cc
+++ b/src/relay/backend/te_compiler_cache.cc
@@ -1127,7 +1127,7 @@ std::pair<Optional<tir::PrimFunc>, std::string>
LowerToPrimFunc(const Function&
}
tir::PrimFunc LowerToPrimFunc(const Function& relay_func, Target target) {
- auto [f_opt, _] = LowerToPrimFunc(relay_func, target, NameSupply(""));
+ auto [f_opt, _] = LowerToPrimFunc(relay_func, target, NameSupply());
(void)_; // to suppress -Werror=unused-variable warning
if (f_opt) {
return f_opt.value();
@@ -1143,7 +1143,7 @@ TVM_REGISTER_GLOBAL("relay.backend.LowerToPrimFunc")
TVM_REGISTER_GLOBAL("relay.backend.LowerToTE").set_body_typed([](Function
prim_func) {
auto tgt = tvm::Target("ext_dev");
- LowerToTECompute lower_te_compute(tgt, NameSupply(""));
+ LowerToTECompute lower_te_compute(tgt, NameSupply());
auto outputs = lower_te_compute.Lower(prim_func);
return CachedFunc(tgt, GlobalVar(lower_te_compute.candidate_name_),
lower_te_compute.fn_inputs_,
outputs, te::Schedule(), tir::PrimFunc(), {},
diff --git a/src/relay/backend/te_compiler_cache.h
b/src/relay/backend/te_compiler_cache.h
index 76939a923c..502e006322 100644
--- a/src/relay/backend/te_compiler_cache.h
+++ b/src/relay/backend/te_compiler_cache.h
@@ -251,7 +251,7 @@ CachedFunc PrimFuncFor(const Function& source_func, const
Target& target,
/*! \brief A specialization of PrimFuncFor, meant to be used when the names of
constants do not
* matter. */
inline CachedFunc PrimFuncFor(const Function& source_func, const Target&
target) {
- return PrimFuncFor(source_func, target, GlobalVarSupply(NameSupply("")),
NameSupply(""));
+ return PrimFuncFor(source_func, target, GlobalVarSupply(), NameSupply());
}
CachedFunc ShapeFuncFor(const Function& prim_func, const Target& target,
diff --git a/src/target/source/codegen_c.h b/src/target/source/codegen_c.h
index e739df0ca1..8c5e1ffd89 100644
--- a/src/target/source/codegen_c.h
+++ b/src/target/source/codegen_c.h
@@ -340,7 +340,7 @@ class CodeGenC : public ExprFunctor<void(const PrimExpr&,
std::ostream&)>,
std::unordered_map<GlobalVar, String> internal_functions_;
/* \brief Name supply to generate unique function names */
- NameSupply func_name_supply_{""};
+ NameSupply func_name_supply_;
};
} // namespace codegen
diff --git a/src/target/source/codegen_source_base.cc
b/src/target/source/codegen_source_base.cc
index 9c17458bf2..60fa786d52 100644
--- a/src/target/source/codegen_source_base.cc
+++ b/src/target/source/codegen_source_base.cc
@@ -28,7 +28,7 @@ namespace tvm {
namespace codegen {
void CodeGenSourceBase::ClearFuncState() {
- name_supply_ = NameSupply("");
+ name_supply_ = NameSupply();
ssa_assign_map_.clear();
var_idmap_.clear();
scope_mark_.clear();
diff --git a/src/target/source/codegen_source_base.h
b/src/target/source/codegen_source_base.h
index 8191ad43aa..e2312ddb77 100644
--- a/src/target/source/codegen_source_base.h
+++ b/src/target/source/codegen_source_base.h
@@ -125,7 +125,7 @@ class CodeGenSourceBase {
/*! \brief name of each variable */
std::unordered_map<const tir::VarNode*, std::string> var_idmap_;
/*! \brief NameSupply for allocation */
- NameSupply name_supply_ = NameSupply("");
+ NameSupply name_supply_;
private:
/*! \brief assignment map of ssa */
diff --git a/src/te/operation/create_primfunc.cc
b/src/te/operation/create_primfunc.cc
index c7dbf3f5e0..2eb0693685 100644
--- a/src/te/operation/create_primfunc.cc
+++ b/src/te/operation/create_primfunc.cc
@@ -109,7 +109,7 @@ struct CreateFuncInfo {
/*! \brief The buffers should be allocated at function root. */
Array<Buffer> root_alloc;
/*! \brief The NameSupply to make block name unique. */
- NameSupply name_supply = NameSupply("");
+ NameSupply name_supply;
String FreshName(String base_name) { return
name_supply->FreshName(base_name); }
diff --git a/src/tir/ir/index_map.cc b/src/tir/ir/index_map.cc
index 149e4cecd4..aed8361d04 100644
--- a/src/tir/ir/index_map.cc
+++ b/src/tir/ir/index_map.cc
@@ -311,7 +311,7 @@ IndexMap IndexMap::RenameVariables(
const std::function<Optional<String>(const Var& var)>& f_name_map) const {
std::unordered_set<std::string> used_names;
Map<Var, Var> var_remap;
- NameSupply name_supply{""};
+ NameSupply name_supply;
const IndexMapNode* n = this->get();
if (f_name_map != nullptr) {
// Collect variables with pre-defined names provided by f_name_map.
diff --git a/tests/cpp/build_module_test.cc b/tests/cpp/build_module_test.cc
index 3d2adb2355..181a1fa3de 100644
--- a/tests/cpp/build_module_test.cc
+++ b/tests/cpp/build_module_test.cc
@@ -52,7 +52,7 @@ TEST(BuildModule, Basic) {
auto target = Target("llvm");
- auto lowered = LowerSchedule(s, args, "func", binds,
GlobalVarSupply(NameSupply("")));
+ auto lowered = LowerSchedule(s, args, "func", binds, GlobalVarSupply());
auto module = build(lowered, target, Target());
auto mali_target = Target("opencl -model=Mali-T860MP4@800Mhz -device=mali");
@@ -121,7 +121,7 @@ TEST(BuildModule, Heterogeneous) {
auto args2 = Array<Tensor>({copy, C, elemwise_sub});
std::unordered_map<Tensor, Buffer> binds;
- GlobalVarSupply global_var_supply = GlobalVarSupply(NameSupply(""));
+ GlobalVarSupply global_var_supply = GlobalVarSupply();
auto lowered_s1 = LowerSchedule(fcreate_s1(), args1, "elemwise_add", binds,
global_var_supply);
auto lowered_s2 = LowerSchedule(fcreate_s2(), args2, "elemwise_sub", binds,
global_var_supply);
Map<tvm::Target, IRModule> inputs = {{target_cuda, lowered_s1},
{target_llvm, lowered_s2}};
diff --git a/tests/cpp/c_codegen_test.cc b/tests/cpp/c_codegen_test.cc
index a01921239a..5f78383049 100644
--- a/tests/cpp/c_codegen_test.cc
+++ b/tests/cpp/c_codegen_test.cc
@@ -52,8 +52,7 @@ TEST(CCodegen, MainFunctionOrder) {
auto args = Array<Tensor>({A, B, elemwise_add});
std::unordered_map<Tensor, Buffer> binds;
- auto lowered =
- LowerSchedule(fcreate(), args, "elemwise_add", binds,
GlobalVarSupply(NameSupply("")));
+ auto lowered = LowerSchedule(fcreate(), args, "elemwise_add", binds,
GlobalVarSupply());
Map<tvm::Target, IRModule> inputs = {{target_c, lowered}};
runtime::Module module = build(inputs, Target());
Array<String> functions = module->GetFunction("get_func_names", false)();
@@ -82,8 +81,7 @@ auto BuildLowered(std::string op_name, tvm::Target target) {
auto args = Array<Tensor>({A, B, op});
std::unordered_map<Tensor, Buffer> binds;
- auto lowered_s =
- LowerSchedule(fcreate_s(), args, op_name, binds,
GlobalVarSupply(NameSupply("")));
+ auto lowered_s = LowerSchedule(fcreate_s(), args, op_name, binds,
GlobalVarSupply());
return lowered_s;
}
diff --git a/tests/cpp/name_supply_test.cc b/tests/cpp/name_supply_test.cc
index 75b9ae86a9..023d2e903a 100644
--- a/tests/cpp/name_supply_test.cc
+++ b/tests/cpp/name_supply_test.cc
@@ -27,7 +27,7 @@
using namespace tvm;
NameSupply preambleNameSupply() {
- NameSupply name_supply = NameSupply("prefix");
+ NameSupply name_supply("prefix");
name_supply->FreshName("test");
return name_supply;
}
@@ -74,7 +74,7 @@ TEST(NameSupply, ReserveName) {
}
GlobalVarSupply preambleVarSupply() {
- GlobalVarSupply global_var_supply = GlobalVarSupply(NameSupply(""));
+ GlobalVarSupply global_var_supply;
global_var_supply->FreshGlobal("test");
return global_var_supply;
}