This is an automated email from the ASF dual-hosted git repository.
areusch pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new ecfd9692a0 Unify name mangling in TVM (#12066)
ecfd9692a0 is described below
commit ecfd9692a0e1763af0e34cd6203c0d5ee947e993
Author: Florin Blanaru <[email protected]>
AuthorDate: Thu Aug 11 17:18:46 2022 +0100
Unify name mangling in TVM (#12066)
* Add NameSupply and GlobalVarSupply
* Build GlobalVarSupply from IRModules instead of having it attached to an IRModule.
* Pass GlobalVarSupply when lowering shape funcs
* Partially replace instantiations of GlobalVar with GlobalVarSupply
* Construct GlobalVarSupply from IRModule
* Add tests for supply
* Add documentation for NameSupply and GlobalVarSupply
Co-authored-by: Florin-Gabriel Blanaru <[email protected]>
---
include/tvm/driver/driver_api.h | 12 +-
include/tvm/ir/global_var_supply.h | 125 ++++++++++++++++++
include/tvm/ir/module.h | 17 +--
include/tvm/ir/name_supply.h | 123 ++++++++++++++++++
python/tvm/ir/supply.py | 141 +++++++++++++++++++++
src/auto_scheduler/feature.cc | 4 +-
src/contrib/hybrid/codegen_hybrid.cc | 106 +++++++---------
src/contrib/hybrid/codegen_hybrid.h | 10 +-
src/driver/driver_api.cc | 22 ++--
src/ir/global_var_supply.cc | 115 +++++++++++++++++
src/ir/module.cc | 23 +---
src/ir/name_supply.cc | 108 ++++++++++++++++
src/relay/backend/graph_executor_codegen.cc | 22 +---
src/relay/backend/task_extraction.cc | 4 +-
src/relay/backend/te_compiler.cc | 74 +++++------
src/relay/backend/te_compiler.h | 4 +-
src/relay/backend/te_compiler_cache.cc | 71 +++--------
src/relay/backend/te_compiler_cache.h | 8 +-
src/relay/ir/dataflow_matcher.cc | 12 +-
.../transforms/auto_scheduler_layout_rewrite.cc | 3 +-
.../transforms/meta_schedule_layout_rewrite.cc | 3 +-
src/relay/transforms/partition_graph.cc | 7 +-
src/target/source/codegen_c.cc | 54 ++++----
src/target/source/codegen_c_host.cc | 14 +-
src/target/source/codegen_cuda.cc | 14 +-
src/target/source/codegen_metal.cc | 8 +-
src/target/source/codegen_source_base.cc | 28 +---
src/target/source/codegen_source_base.h | 11 +-
src/te/operation/create_primfunc.cc | 23 ++--
src/tir/transforms/split_host_device.cc | 12 +-
tests/cpp/build_module_test.cc | 7 +-
tests/cpp/c_codegen_test.cc | 6 +-
tests/cpp/name_supply_test.cc | 129 +++++++++++++++++++
tests/python/relay/backend/test_pass_lower_te.py | 10 +-
tests/python/relay/test_name_supply.py | 72 +++++++++++
35 files changed, 1052 insertions(+), 350 deletions(-)
diff --git a/include/tvm/driver/driver_api.h b/include/tvm/driver/driver_api.h
index 48800b193c..fffcab4966 100644
--- a/include/tvm/driver/driver_api.h
+++ b/include/tvm/driver/driver_api.h
@@ -29,6 +29,7 @@
#ifndef TVM_DRIVER_DRIVER_API_H_
#define TVM_DRIVER_DRIVER_API_H_
+#include <tvm/ir/global_var_supply.h>
#include <tvm/ir/module.h>
#include <tvm/ir/transform.h>
#include <tvm/runtime/packed_func.h>
@@ -99,6 +100,7 @@ TVM_DLL IRModule LowerPrimFunc(tvm::tir::PrimFunc func,
const std::string& name,
* \param args The arguments to the function.
* \param name The name of the lowered function.
* \param binds Buffer assignments.
+ * \param global_var_supply The GlobalVarSupply to be used in the module.
* \param simple_mode Disables the loop partition pass. Defaults to false.
* \return The result module.
*/
@@ -106,7 +108,7 @@ TVM_DLL IRModule LowerPrimFunc(tvm::tir::PrimFunc func,
const std::string& name,
TVM_DLL IRModule LowerSchedule(te::Schedule sch, const Array<te::Tensor>& args,
const std::string& name,
const std::unordered_map<te::Tensor,
tir::Buffer>& binds,
- bool simple_mode = false);
+ GlobalVarSupply global_var_supply, bool
simple_mode = false);
/*!
* \brief Build an IRModule given a TE schedule, args and binds. This function
also applies
@@ -115,13 +117,14 @@ TVM_DLL IRModule LowerSchedule(te::Schedule sch, const
Array<te::Tensor>& args,
* \param args The arguments to the function (Array of Tensor, Buffer and Vars)
* \param name The name of the lowered function.
* \param binds Buffer assignments.
+ * \param global_var_supply The GlobalVarSupply to be used in the module.
* \param simple_mode Disables the loop partition pass. Defaults to false.
* \return The result module.
*/
TVM_DLL IRModule LowerSchedule(te::Schedule sch, const Array<ObjectRef>& args,
const std::string& name,
const std::unordered_map<te::Tensor,
tir::Buffer>& binds,
- bool simple_mode = false);
+ GlobalVarSupply global_var_supply, bool
simple_mode = false);
/*!
* \brief Create an IRModule out of a TE Schedule. It does not apply lowering
passes. If you want
@@ -130,10 +133,13 @@ TVM_DLL IRModule LowerSchedule(te::Schedule sch, const
Array<ObjectRef>& args,
* \param args The arguments to the function.
* \param name The name of the lowered function.
* \param binds Buffer assignments.
+ * \param global_var_supply The GlobalVarSupply to be used in the module and
when creating
+ * GlobalVars.
* \return The result module.
*/
IRModule ScheduleToModule(te::Schedule sch, const Array<ObjectRef>& args,
const std::string& name,
- const std::unordered_map<te::Tensor, tir::Buffer>&
binds);
+ const std::unordered_map<te::Tensor, tir::Buffer>&
binds,
+ GlobalVarSupply global_var_supply);
/*!
* \brief Build a device and host module for a specific target from an
IRModule.
* \param funcs The functions to be built.
diff --git a/include/tvm/ir/global_var_supply.h
b/include/tvm/ir/global_var_supply.h
new file mode 100644
index 0000000000..276c64a0d7
--- /dev/null
+++ b/include/tvm/ir/global_var_supply.h
@@ -0,0 +1,125 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tvm/ir/global_var_supply.h
+ * \brief GlobalVarSupply that can be used to generate unique GlobalVars.
+ */
+#ifndef TVM_IR_GLOBAL_VAR_SUPPLY_H_
+#define TVM_IR_GLOBAL_VAR_SUPPLY_H_
+
+#include <string>
+#include <unordered_map>
+
+#include "tvm/ir/expr.h"
+#include "tvm/ir/module.h"
+#include "tvm/ir/name_supply.h"
+
+namespace tvm {
+
+/*!
+ * \brief GlobalVarSupply can be used to generate unique GlobalVars.
+ */
+class GlobalVarSupplyNode : public Object {
+ public:
+ /*!
+ * \brief Default constructor. Uses a NameSupply with an empty prefix.
+ */
+ GlobalVarSupplyNode() : GlobalVarSupplyNode(NameSupply("")) {}
+
+ /*!
+ * \brief Constructor.
+ * \param name_supply The NameSupply to use for generating the names of
fresh GlobalVars.
+ * \param name_to_var_map An optional initial map from names to GlobalVars already known to this supply.
+ */
+ explicit GlobalVarSupplyNode(NameSupply name_supply,
+ std::unordered_map<std::string, GlobalVar>
name_to_var_map = {});
+
+ /*!
+ * \brief Generates a unique GlobalVar from this supply.
+ * \param name The name from which the name of the GlobalVar is derived.
+ * \param add_prefix If set to true, then the prefix of the contained
NameSupply will be prepended
+ * to the name. \return A unique GlobalVar.
+ */
+ GlobalVar FreshGlobal(String name, bool add_prefix = true);
+
+ /*!
+ * \brief Looks up a GlobalVar with the given name in this supply.
+ * If no entry is found, creates one, places it in the cache and returns it.
+ * \param name The name of the GlobalVar to search for.
+ * \param add_prefix If set to true, the prefix of the contained NameSupply
will be prepended to
+ * the name before performing the search. \return A cached GlobalVar.
+ */
+ GlobalVar UniqueGlobalFor(const String& name, bool add_prefix = true);
+
+ /*!
+ * \brief Reserves an existing GlobalVar with this supply.
+ * \param var The GlobalVar to be registered.
+ * \param allow_conflict Allow conflict with other GlobalVars that have the
same name.
+ */
+ void ReserveGlobalVar(const GlobalVar& var, bool allow_conflict = false);
+
+ void VisitAttrs(AttrVisitor* v) {}
+
+ /*! \brief The NameSupply used to generate unique name hints for GlobalVars.
*/
+ NameSupply name_supply_;
+
+ static constexpr const char* _type_key = "GlobalVarSupply";
+ static constexpr const bool _type_has_method_sequal_reduce = false;
+ static constexpr const bool _type_has_method_shash_reduce = false;
+ TVM_DECLARE_FINAL_OBJECT_INFO(GlobalVarSupplyNode, Object);
+
+ private:
+ std::unordered_map<std::string, GlobalVar> name_to_var_map_;
+};
+
+/*!
+ * \brief Managed reference class to GlobalVarSupplyNode.
+ * \sa GlobalVarSupplyNode
+ */
+class GlobalVarSupply : public ObjectRef {
+ public:
+ /*!
+ * \brief Constructor.
+ * \param name_supply The NameSupply to be used when generating new
GlobalVars.
+ * \param name_to_var_map An optional map.
+ */
+ TVM_DLL explicit GlobalVarSupply(const NameSupply& name_supply,
+ std::unordered_map<std::string, GlobalVar>
name_to_var_map = {});
+
+ /*!
+ * \brief Constructs a supply from an array of IRModules. GlobalVars
generated by this supply are
+ * guaranteed not to conflict with any GlobalVars that belong to the
modules. \param modules Array
+ * of IRModules.
+ */
+ TVM_DLL explicit GlobalVarSupply(const Array<IRModule>& modules);
+
+ /*!
+ * \brief Constructs a GlobalVarSupply from an IRModule. GlobalVars
generated by this supply are
+ * guaranteed not to conflict with GlobalVars that belong to the module.
\param module The
+ * IRModule.
+ */
+ TVM_DLL explicit GlobalVarSupply(const IRModule module);
+
+ TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(GlobalVarSupply, ObjectRef,
GlobalVarSupplyNode);
+};
+
+} // namespace tvm
+
+#endif // TVM_IR_GLOBAL_VAR_SUPPLY_H_
diff --git a/include/tvm/ir/module.h b/include/tvm/ir/module.h
index f73f2230df..7313b4f783 100644
--- a/include/tvm/ir/module.h
+++ b/include/tvm/ir/module.h
@@ -323,14 +323,6 @@ class IRModuleNode : public Object {
/*! \brief Helper function for registering a typedef's constructors */
void RegisterConstructors(const GlobalTypeVar& var, const TypeData& type);
- /*!
- * \brief Returns a version of \p name which is unique amongst all function
definitions in module.
- *
- * \param name The original name.
- * \return Updated name which is unique.
- */
- String GetUniqueName(const String& name);
-
/*! \brief A map from string names to global variables that
* ensures global uniqueness.
*/
@@ -481,6 +473,15 @@ namespace attr {
// Following are attributes for IRModule only.
+/*!
+ * \brief Name of the module
+ *
+ * Type: String
+ *
+ * \sa tvm::runtime::String
+ */
+constexpr const char* kModuleName = "mod_name";
+
/*!
* \brief Executor targeted by the module
*
diff --git a/include/tvm/ir/name_supply.h b/include/tvm/ir/name_supply.h
new file mode 100644
index 0000000000..a85a6fe70a
--- /dev/null
+++ b/include/tvm/ir/name_supply.h
@@ -0,0 +1,123 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tvm/ir/name_supply.h
+ * \brief NameSupply that can be used to generate unique variable names.
+ */
+#ifndef TVM_IR_NAME_SUPPLY_H_
+#define TVM_IR_NAME_SUPPLY_H_
+
+#include <string>
+#include <unordered_map>
+#include <utility>
+
+#include "tvm/ir/expr.h"
+
+namespace tvm {
+
+/*!
+ * \brief NameSupply can be used to generate unique names.
+ */
+class NameSupplyNode : public Object {
+ public:
+ /*!
+ * \brief Empty constructor. Needed by the TVM_REGISTER_NODE_TYPE macro.
+ */
+ NameSupplyNode() = default;
+
+ /*!
+ * \brief Constructor.
+ * \param prefix The prefix to be used with this NameSupply.
+ * \param name_map The map used to guarantee uniqueness.
+ */
+ NameSupplyNode(const String& prefix, std::unordered_map<std::string, int>
name_map)
+ : prefix_(prefix), name_map(std::move(name_map)) {}
+
+ /*!
+ * \brief Generates a unique name from this NameSupply.
+ * \param name The name from which the generated name is derived.
+ * \param add_prefix If set to true, then the prefix of this NameSupply will
be prepended to the
+ * name. \return A unique name.
+ */
+ String FreshName(const String& name, bool add_prefix = true);
+
+ /*!
+ * \brief Reserves an existing name with this NameSupply.
+ * \param name The name to be reserved.
+ * \param add_prefix If set to true, then the prefix of this NameSupply will
be prepended to the
+ * name before reserving it. \return The name that was reserved with the
NameSupply. It can be
+ * different if a prefix is added.
+ */
+ String ReserveName(const String& name, bool add_prefix = true);
+
+ /*!
+ * \brief Checks if this NameSupply already generated a name.
+ * \param name The name to check.
+ * \param add_prefix If set to true, then the prefix of this NameSupply will
be prepended to the
+ * name before checking for it. \return True if the name has already been
generated. False
+ * otherwise.
+ */
+ bool ContainsName(const String& name, bool add_prefix = true);
+
+ void VisitAttrs(AttrVisitor* v) {}
+
+ // Prefix prepended to names when add_prefix is true. It can be empty.
+ std::string prefix_;
+
+ static constexpr const char* _type_key = "NameSupply";
+ static constexpr const bool _type_has_method_sequal_reduce = false;
+ static constexpr const bool _type_has_method_shash_reduce = false;
+ TVM_DECLARE_FINAL_OBJECT_INFO(NameSupplyNode, Object);
+
+ private:
+ /*! \brief Helper function to add the NameSupply prefix to the name. */
+ String add_prefix_to_name(const String& name);
+
+ /*!
+ * \brief Function that will generate a unique name.
+ * \param name The name to be used as a base.
+ * \return A unique name.
+ */
+ std::string GetUniqueName(std::string name);
+
+ /*! \brief A map that is used to generate unique names. */
+ std::unordered_map<std::string, int> name_map;
+};
+
+/*!
+ * \brief Managed reference class to NameSupplyNode.
+ * \sa NameSupplyNode
+ */
+class NameSupply : public ObjectRef {
+ public:
+ /*!
+ * \brief Constructor.
+ * \param prefix The prefix to be used with this NameSupply.
+ * \param name_map An optional map.
+ */
+ TVM_DLL explicit NameSupply(const String& prefix,
+ std::unordered_map<std::string, int> name_map =
{});
+
+ TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(NameSupply, ObjectRef, NameSupplyNode);
+};
+
+} // namespace tvm
+
+#endif // TVM_IR_NAME_SUPPLY_H_
diff --git a/python/tvm/ir/supply.py b/python/tvm/ir/supply.py
new file mode 100644
index 0000000000..095ac43c03
--- /dev/null
+++ b/python/tvm/ir/supply.py
@@ -0,0 +1,141 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Suppliers that are used to guarantee uniqueness of names and GlobalVars."""
+import tvm
+from tvm import Object, IRModule
+from . import _ffi_api
+
+
+@tvm._ffi.register_object("NameSupply")
+class NameSupply(Object):
+ """NameSupply that can be used to generate unique names.
+
+ Parameters
+ ----------
+ prefix: The prefix to be added to the generated names.
+ """
+
+ def __init__(self, prefix=""):
+ self.__init_handle_by_constructor__(_ffi_api.NameSupply, prefix)
+
+ def fresh_name(self, name, add_prefix=True):
+ """Generates a unique name from this NameSupply.
+
+ Parameters
+ ----------
+ name: String
+ The name from which the generated name is derived.
+
+ add_prefix: bool
+ If set to true, then the prefix of this NameSupply will be
prepended to the name.
+ """
+ return _ffi_api.NameSupply_FreshName(self, name, add_prefix)
+
+ def reserve_name(self, name, add_prefix=True):
+ """Reserves an existing name with this NameSupply.
+
+ Parameters
+ ----------
+ name: String
+ The name to be reserved.
+
+ add_prefix: bool
+ If set to true, then the prefix of this NameSupply will be
prepended to the name
+ before reserving it.
+ """
+ return _ffi_api.NameSupply_ReserveName(self, name, add_prefix)
+
+ def contains_name(self, name, add_prefix=True):
+ """Checks if this NameSupply already generated a name.
+
+ Parameters
+ ----------
+ name: String
+ The name to check.
+
+ add_prefix: bool
+ If set to true, then the prefix of this NameSupply will be
prepended to the name
+ before checking for it.
+ """
+ return _ffi_api.NameSupply_ContainsName(self, name, add_prefix)
+
+
+@tvm._ffi.register_object("GlobalVarSupply")
+class GlobalVarSupply(Object):
+ """GlobalVarSupply that holds a mapping between names and GlobalVars.
+
+ GlobalVarSupply can be used to generate new GlobalVars with a unique name.
+ It also can be used to retrieve previously generated GlobalVars based on a
name.
+
+ Parameters
+ ----------
+ value: Optional[Union[List[IRModule], IRModule, NameSupply]]
+ The IRModules used to build this GlobalVarSupply, or a NameSupply. If None, a NameSupply with an empty prefix is used.
+ """
+
+ def __init__(self, value=None):
+ if value is None:
+ name_supply = NameSupply("")
+
self.__init_handle_by_constructor__(_ffi_api.GlobalVarSupply_NameSupply,
name_supply)
+ elif isinstance(value, NameSupply):
+
self.__init_handle_by_constructor__(_ffi_api.GlobalVarSupply_NameSupply, value)
+ elif isinstance(value, (list, tvm.container.Array)):
+
self.__init_handle_by_constructor__(_ffi_api.GlobalVarSupply_IRModules, value)
+ elif isinstance(value, IRModule):
+
self.__init_handle_by_constructor__(_ffi_api.GlobalVarSupply_IRModule, value)
+
+ def fresh_global(self, name, add_prefix=True):
+ """Generates a unique GlobalVar from this supply.
+
+ Parameters
+ ----------
+ name: String
+ The name from which the name of the GlobalVar is derived.
+
+ add_prefix: bool
+ If set to true, then the prefix of the contained NameSupply will
be prepended
+ to the name.
+ """
+ return _ffi_api.GlobalVarSupply_FreshGlobal(self, name, add_prefix)
+
+ def unique_global_for(self, name, add_prefix=True):
+ """Looks up a GlobalVar with the given name in this supply. If no
+ entry is found, creates one, places it in the cache and returns it.
+
+ Parameters
+ ----------
+ name: String
+ The name of the GlobalVar to search for.
+
+ add_prefix: bool
+ If set to true, the prefix of the contained NameSupply will be
prepended to the
+ name before performing the search.
+ """
+ return _ffi_api.GlobalVarSupply_UniqueGlobalFor(self, name, add_prefix)
+
+ def reserve_global(self, global_var, allow_conflict=False):
+ """Reserves an existing GlobalVar with this supply.
+
+ Parameters
+ ----------
+ global_var: GlobalVar
+ The GlobalVar to be registered.
+
+ allow_conflict: bool
+ Allow conflict with other GlobalVars that have the same name.
+ """
+ return _ffi_api.GlobalVarSupply_ReserveGlobalVar(self, global_var,
allow_conflict)
diff --git a/src/auto_scheduler/feature.cc b/src/auto_scheduler/feature.cc
index ab60aef9ae..c930bf0c4e 100644
--- a/src/auto_scheduler/feature.cc
+++ b/src/auto_scheduler/feature.cc
@@ -27,6 +27,7 @@
#include <tvm/auto_scheduler/measure.h>
#include <tvm/auto_scheduler/measure_record.h>
#include <tvm/driver/driver_api.h>
+#include <tvm/ir/global_var_supply.h>
#include <tvm/runtime/registry.h>
#include <tvm/support/parallel_for.h>
#include <tvm/te/operation.h>
@@ -1371,7 +1372,8 @@ void GetPerStoreFeaturesWorkerFunc(const SearchTask&
task, const State& state, i
auto pass_ctx = tvm::transform::PassContext::Current();
auto mod = ScheduleToModule(sch, Array<ObjectRef>{tensors.begin(),
tensors.end()}, name,
- std::unordered_map<te::Tensor, te::Buffer>());
+ std::unordered_map<te::Tensor, te::Buffer>(),
+ GlobalVarSupply(NameSupply("")));
bool disable_vectorize =
pass_ctx->GetConfig<Bool>("tir.disable_vectorize",
Bool(false)).value();
diff --git a/src/contrib/hybrid/codegen_hybrid.cc
b/src/contrib/hybrid/codegen_hybrid.cc
index 24c7ee74cd..79c9e567b4 100644
--- a/src/contrib/hybrid/codegen_hybrid.cc
+++ b/src/contrib/hybrid/codegen_hybrid.cc
@@ -42,24 +42,6 @@ std::string dot_to_underscore(std::string s) {
return s;
}
-std::string CodeGenHybrid::GetUniqueName(std::string prefix) {
- prefix = dot_to_underscore(prefix);
- auto it = ids_allocated_.find(prefix);
- if (it != ids_allocated_.end()) {
- while (true) {
- std::ostringstream os;
- os << prefix << (++it->second);
- std::string name = os.str();
- if (ids_allocated_.count(name) == 0) {
- prefix = name;
- break;
- }
- }
- }
- ids_allocated_[prefix] = 0;
- return prefix;
-}
-
std::string CodeGenHybrid::Finish() { return stream.str(); }
void CodeGenHybrid::PrintType(DataType t, std::ostream& os) {
@@ -428,7 +410,7 @@ std::string CodeGenHybrid::GetVarID(const VarNode* v) {
if (id_map_.count(key)) {
return id_map_[key];
}
- return id_map_[key] = GetUniqueName(v->name_hint);
+ return id_map_[key] = ids_allocated->FreshName(v->name_hint);
}
std::string CodeGenHybrid::GetTensorID(const Tensor& tensor) {
@@ -440,57 +422,57 @@ std::string CodeGenHybrid::GetTensorID(const Tensor&
tensor) {
if (tensor->op->num_outputs() > 1) {
name_hint += "_v" + std::to_string(tensor->value_index);
}
- return id_map_[key] = GetUniqueName(name_hint);
+ return id_map_[key] = ids_allocated->FreshName(name_hint);
}
void CodeGenHybrid::ReserveKeywords() {
- GetUniqueName("def");
- GetUniqueName("for");
- GetUniqueName("in");
- GetUniqueName("range");
- GetUniqueName("True");
- GetUniqueName("False");
- GetUniqueName("unroll");
- GetUniqueName("const_range");
- GetUniqueName("parallel");
- GetUniqueName("vectorize");
- GetUniqueName("bind");
- GetUniqueName("threadIdx.x");
- GetUniqueName("threadIdx.y");
- GetUniqueName("threadIdx.z");
- GetUniqueName("blockIdx.x");
- GetUniqueName("blockIdx.y");
- GetUniqueName("blockIdx.z");
- GetUniqueName("vthread");
- GetUniqueName("allocate");
- GetUniqueName("output_tensor");
- GetUniqueName("sqrt");
- GetUniqueName("log");
- GetUniqueName("tanh");
- GetUniqueName("power");
- GetUniqueName("exp");
- GetUniqueName("sigmoid");
- GetUniqueName("popcount");
- GetUniqueName("likely");
- GetUniqueName("int8");
- GetUniqueName("int16");
- GetUniqueName("int32");
- GetUniqueName("int64");
- GetUniqueName("uint8");
- GetUniqueName("uint16");
- GetUniqueName("uint32");
- GetUniqueName("uint64");
- GetUniqueName("float16");
- GetUniqueName("float32");
- GetUniqueName("float64");
- GetUniqueName("ceil_div");
- GetUniqueName("max_num_threads");
+ ids_allocated->ReserveName("def");
+ ids_allocated->ReserveName("for");
+ ids_allocated->ReserveName("in");
+ ids_allocated->ReserveName("range");
+ ids_allocated->ReserveName("True");
+ ids_allocated->ReserveName("False");
+ ids_allocated->ReserveName("unroll");
+ ids_allocated->ReserveName("const_range");
+ ids_allocated->ReserveName("parallel");
+ ids_allocated->ReserveName("vectorize");
+ ids_allocated->ReserveName("bind");
+ ids_allocated->ReserveName("threadIdx.x");
+ ids_allocated->ReserveName("threadIdx.y");
+ ids_allocated->ReserveName("threadIdx.z");
+ ids_allocated->ReserveName("blockIdx.x");
+ ids_allocated->ReserveName("blockIdx.y");
+ ids_allocated->ReserveName("blockIdx.z");
+ ids_allocated->ReserveName("vthread");
+ ids_allocated->ReserveName("allocate");
+ ids_allocated->ReserveName("output_tensor");
+ ids_allocated->ReserveName("sqrt");
+ ids_allocated->ReserveName("log");
+ ids_allocated->ReserveName("tanh");
+ ids_allocated->ReserveName("power");
+ ids_allocated->ReserveName("exp");
+ ids_allocated->ReserveName("sigmoid");
+ ids_allocated->ReserveName("popcount");
+ ids_allocated->ReserveName("likely");
+ ids_allocated->ReserveName("int8");
+ ids_allocated->ReserveName("int16");
+ ids_allocated->ReserveName("int32");
+ ids_allocated->ReserveName("int64");
+ ids_allocated->ReserveName("uint8");
+ ids_allocated->ReserveName("uint16");
+ ids_allocated->ReserveName("uint32");
+ ids_allocated->ReserveName("uint64");
+ ids_allocated->ReserveName("float16");
+ ids_allocated->ReserveName("float32");
+ ids_allocated->ReserveName("float64");
+ ids_allocated->ReserveName("ceil_div");
+ ids_allocated->ReserveName("max_num_threads");
}
void CodeGenHybrid::DumpStmt(const Stmt& stmt, const Array<ObjectRef>& inputs,
const Array<Tensor>& outputs, const std::string&
name) {
ReserveKeywords();
- GetUniqueName(name);
+ ids_allocated->ReserveName(name);
stream << "def " << name << "(";
for (size_t i = 0; i < inputs.size(); ++i) {
diff --git a/src/contrib/hybrid/codegen_hybrid.h
b/src/contrib/hybrid/codegen_hybrid.h
index da45ffb6a8..53026c7fc3 100644
--- a/src/contrib/hybrid/codegen_hybrid.h
+++ b/src/contrib/hybrid/codegen_hybrid.h
@@ -24,6 +24,7 @@
#ifndef TVM_CONTRIB_HYBRID_CODEGEN_HYBRID_H_
#define TVM_CONTRIB_HYBRID_CODEGEN_HYBRID_H_
+#include <tvm/ir/name_supply.h>
#include <tvm/target/codegen.h>
#include <tvm/te/operation.h>
#include <tvm/te/schedule.h>
@@ -145,19 +146,14 @@ class CodeGenHybrid : public ExprFunctor<void(const
PrimExpr&, std::ostream&)>,
const int tab_{4};
/*! \brief Print the current indent spaces. */
inline void PrintIndent();
- /*! \brief Keys are ids allocated, and values are the suffix to prevent
double-name. */
- std::map<std::string, int> ids_allocated_;
+ /*! \brief NameSupply for allocated ids. */
+ NameSupply ids_allocated = NameSupply("");
/*!
* \brief Keys are either (tensors, value_index) or (variables, 0).
* Values are the corresponding IDs.*/
std::map<std::pair<const Object*, int>, std::string> id_map_;
/*! \brief Variables (keys) binded to the threads (values). */
std::map<const VarNode*, std::string> binds_;
- /*!
- * \brief Find an unallocated name for the given prefix.
- * \param prefix The given prefix.
- */
- std::string GetUniqueName(std::string prefix);
/*! \brief The output code string builder. */
std::stringstream stream;
/*!
diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc
index 6f4fb618d3..cbf809a267 100644
--- a/src/driver/driver_api.cc
+++ b/src/driver/driver_api.cc
@@ -261,7 +261,8 @@ IRModule ApplyPasses(IRModule mod, transform::Sequential
seq) {
// Convert te schedule to IRModule
IRModule ScheduleToModule(te::Schedule sch, const Array<ObjectRef>& args,
const std::string& name,
- const std::unordered_map<te::Tensor, tir::Buffer>&
binds) {
+ const std::unordered_map<te::Tensor, tir::Buffer>&
binds,
+ GlobalVarSupply global_var_supply) {
sch = sch.normalize();
transform::PassContext pass_ctx = transform::PassContext::Current();
@@ -289,7 +290,8 @@ IRModule ScheduleToModule(te::Schedule sch, const
Array<ObjectRef>& args, const
if (noalias) {
f = WithAttr(std::move(f), "tir.noalias", Bool(true));
}
- return IRModule(Map<GlobalVar, BaseFunc>({{GlobalVar(name), f}}));
+ GlobalVar global_var = global_var_supply->UniqueGlobalFor(name, false);
+ return IRModule(Map<GlobalVar, BaseFunc>({{global_var, f}}));
}
TVM_REGISTER_GLOBAL("driver.schedule_to_module")
@@ -302,7 +304,8 @@ TVM_REGISTER_GLOBAL("driver.schedule_to_module")
c_binds.insert({kv.first, kv.second});
}
}
- IRModule mod = ScheduleToModule(std::move(sch), args, name, c_binds);
+ IRModule mod =
+ ScheduleToModule(std::move(sch), args, name, c_binds,
GlobalVarSupply(NameSupply("")));
return mod;
});
@@ -337,17 +340,19 @@ TVM_REGISTER_GLOBAL("driver.lower_primfunc")
});
IRModule LowerSchedule(te::Schedule sch, const Array<te::Tensor>& args, const
std::string& name,
- const std::unordered_map<te::Tensor, tir::Buffer>&
binds, bool simple_mode) {
+ const std::unordered_map<te::Tensor, tir::Buffer>&
binds,
+ GlobalVarSupply global_var_supply, bool simple_mode) {
Array<ObjectRef> ref_args;
for (ObjectRef x : args) {
ref_args.push_back(x);
}
- return LowerSchedule(std::move(sch), ref_args, name, binds);
+ return LowerSchedule(std::move(sch), ref_args, name, binds,
global_var_supply);
}
IRModule LowerSchedule(te::Schedule sch, const Array<ObjectRef>& args, const
std::string& name,
- const std::unordered_map<te::Tensor, tir::Buffer>&
binds, bool simple_mode) {
- IRModule mod = ScheduleToModule(std::move(sch), args, name, binds);
+ const std::unordered_map<te::Tensor, tir::Buffer>&
binds,
+ GlobalVarSupply global_var_supply, bool simple_mode) {
+ IRModule mod = ScheduleToModule(std::move(sch), args, name, binds,
global_var_supply);
// Get the legacy TE pass list
Array<transform::Pass> pass_list = CreatePassList(simple_mode);
return LowerWithPassList(mod, pass_list);
@@ -363,7 +368,8 @@ TVM_REGISTER_GLOBAL("driver.lower_schedule")
c_binds.insert({kv.first, kv.second});
}
}
- return LowerSchedule(std::move(sch), args, name, c_binds, simple_mode);
+ return LowerSchedule(std::move(sch), args, name, c_binds,
GlobalVarSupply(NameSupply("")),
+ simple_mode);
});
/**
diff --git a/src/ir/global_var_supply.cc b/src/ir/global_var_supply.cc
new file mode 100644
index 0000000000..383d4445ad
--- /dev/null
+++ b/src/ir/global_var_supply.cc
@@ -0,0 +1,115 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file global_var_supply.cc
+ * \brief GlobalVarSupply that can be used to generate unique GlobalVars.
+ */
+#include "tvm/ir/global_var_supply.h"
+
+#include <tvm/runtime/registry.h>
+
+#include <utility>
+
+#include "tvm/ir/expr.h"
+
+namespace tvm {
+/*!
+ * \brief Build a supply from \p name_supply and an initial name -> GlobalVar map.
+ */
+GlobalVarSupply::GlobalVarSupply(const NameSupply& name_supply,
+                                 std::unordered_map<std::string, GlobalVar> name_to_var_map) {
+  // The map is taken by value; move it into the node (as NameSupply's constructor does)
+  // rather than copying it a second time.
+  auto n = make_object<GlobalVarSupplyNode>(name_supply, std::move(name_to_var_map));
+  data_ = std::move(n);
+}
+
+/*! \brief The module's kModuleName attribute, or "tvmgen_default" when it is not set. */
+std::string GetModuleName(const IRModule& module) {
+  Optional<String> attr_name = module->GetAttr<String>(tvm::attr::kModuleName);
+  return attr_name.value_or("tvmgen_default");
+}
+
+/*!
+ * \brief Build a supply that already knows every global defined in \p modules.
+ *
+ * The name prefix is taken from the first module (see GetModuleName); with no modules
+ * the prefix stays empty.
+ */
+GlobalVarSupply::GlobalVarSupply(const Array<IRModule>& modules) : GlobalVarSupply(NameSupply("")) {
+  auto* node = this->operator->();
+  if (!modules.empty()) {
+    node->name_supply_->prefix_ = GetModuleName(modules.front());
+  }
+  // Reserve each existing global so later requests cannot collide with it.
+  for (const auto& module : modules) {
+    for (const auto& entry : module->functions) {
+      node->ReserveGlobalVar(entry.first);
+    }
+  }
+}
+
+/*! \brief Convenience overload for a single module. */
+GlobalVarSupply::GlobalVarSupply(const IRModule module)
+    : GlobalVarSupply(Array<IRModule>{module}) {}
+
+/*!
+ * \brief Record \p var under its name_hint and reserve that name (unprefixed).
+ *
+ * Unless \p allow_conflict is set, it is an error for the name to already be bound.
+ */
+void GlobalVarSupplyNode::ReserveGlobalVar(const GlobalVar& var, bool allow_conflict) {
+  // Validate before mutating anything, so a failed check leaves both the NameSupply
+  // and the name -> GlobalVar map untouched.
+  if (!allow_conflict) {
+    ICHECK(name_to_var_map_.count(var->name_hint) == 0)
+        << "GlobalVar " << var << " conflicts by name in this supply.";
+  }
+  name_supply_->ReserveName(var->name_hint, false);
+  name_to_var_map_[var->name_hint] = var;
+}
+
+// Both arguments arrive by value and are moved straight into the members.
+GlobalVarSupplyNode::GlobalVarSupplyNode(NameSupply name_supply,
+                                         std::unordered_map<std::string, GlobalVar> name_to_var_map)
+    : name_supply_(std::move(name_supply)),
+      name_to_var_map_(std::move(name_to_var_map)) {}
+
+/*!
+ * \brief Reserve \p name (optionally prefixed) and return the GlobalVar bound to it,
+ * creating and recording a new one the first time the name is seen.
+ */
+GlobalVar GlobalVarSupplyNode::UniqueGlobalFor(const String& name, bool add_prefix) {
+  String final_name = name_supply_->ReserveName(name, add_prefix);
+  auto existing = name_to_var_map_.find(final_name);
+  if (existing != name_to_var_map_.end()) {
+    return existing->second;
+  }
+  GlobalVar created(final_name);
+  name_to_var_map_.emplace(final_name, created);
+  return created;
+}
+
+/*!
+ * \brief Create a brand-new GlobalVar under a fresh name derived from \p name.
+ */
+GlobalVar GlobalVarSupplyNode::FreshGlobal(String name, bool add_prefix) {
+  String final_name = name_supply_->FreshName(name, add_prefix);
+  // Guard against a name that is already bound, e.g. when the map was seeded with
+  // entries the NameSupply never reserved.
+  ICHECK(name_to_var_map_.count(final_name) == 0)
+      << "GlobalVar already exists for name " << final_name;
+  GlobalVar fresh(final_name);
+  name_to_var_map_.emplace(final_name, fresh);
+  return fresh;
+}
+
+TVM_REGISTER_NODE_TYPE(GlobalVarSupplyNode);
+
+// Constructors and methods exposed as global packed functions under "ir.*".
+TVM_REGISTER_GLOBAL("ir.GlobalVarSupply_NameSupply").set_body_typed([](const NameSupply& supply) {
+  return GlobalVarSupply(supply);
+});
+
+TVM_REGISTER_GLOBAL("ir.GlobalVarSupply_IRModule").set_body_typed([](IRModule module) {
+  return GlobalVarSupply(std::move(module));
+});
+
+TVM_REGISTER_GLOBAL("ir.GlobalVarSupply_IRModules")
+    .set_body_typed([](const Array<IRModule>& modules) { return GlobalVarSupply(modules); });
+
+TVM_REGISTER_GLOBAL("ir.GlobalVarSupply_FreshGlobal")
+    .set_body_method<GlobalVarSupply>(&GlobalVarSupplyNode::FreshGlobal);
+
+TVM_REGISTER_GLOBAL("ir.GlobalVarSupply_UniqueGlobalFor")
+    .set_body_method<GlobalVarSupply>(&GlobalVarSupplyNode::UniqueGlobalFor);
+
+TVM_REGISTER_GLOBAL("ir.GlobalVarSupply_ReserveGlobalVar")
+    .set_body_method<GlobalVarSupply>(&GlobalVarSupplyNode::ReserveGlobalVar);
+
+} // namespace tvm
diff --git a/src/ir/module.cc b/src/ir/module.cc
index 6f2c9f9fe9..8d6de5a536 100644
--- a/src/ir/module.cc
+++ b/src/ir/module.cc
@@ -21,6 +21,7 @@
* \file module.cc
* \brief The global module in Relay.
*/
+#include <tvm/ir/global_var_supply.h>
#include <tvm/ir/module.h>
#include <tvm/node/structural_equal.h>
#include <tvm/runtime/registry.h>
@@ -292,20 +293,6 @@ Constructor IRModuleNode::LookupTag(const int32_t tag) {
return (*it).second;
}
-String IRModuleNode::GetUniqueName(const String& name) {
- String result = name;
- int suffix = 0;
- while (true) {
- auto it = global_var_map_.find(result);
- if (it == global_var_map_.end()) {
- return result;
- }
- std::ostringstream os;
- os << name << "_" << ++suffix;
- result = os.str();
- }
-}
-
/*!
* \brief Renames global type/term variables to prefer the
GlobalTypeVar/GlobalVar in the lhs
* ('one') side above the rhs ('two').
@@ -397,12 +384,14 @@ std::pair<IRModule, GlobalVar>
IRModule::FromExprInContext(
func = relay::Function(relay::FreeVars(expr), expr, Type(),
relay::FreeTypeVars(expr, mod), {});
}
+ GlobalVar main_gv;
+ auto global_var_supply = GlobalVarSupply(mod);
if (gv_name.empty()) {
// Bind function to 'main' (though rename if would clash with existing
'main').
- gv_name = mod->GetUniqueName("main");
+ main_gv = global_var_supply->FreshGlobal("main", false);
+ } else {
+ main_gv = global_var_supply->UniqueGlobalFor(gv_name, false);
}
-
- GlobalVar main_gv(gv_name);
mod->Add(main_gv, func);
return {mod, main_gv};
}
diff --git a/src/ir/name_supply.cc b/src/ir/name_supply.cc
new file mode 100644
index 0000000000..93f568253c
--- /dev/null
+++ b/src/ir/name_supply.cc
@@ -0,0 +1,108 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file name_supply.cc
+ * \brief NameSupply that can be used to generate unique variable names.
+ */
+#include "tvm/ir/name_supply.h"
+
+#include <tvm/runtime/registry.h>
+
+#include <utility>
+
+namespace tvm {
+
+/*! \brief Build a NameSupply with \p prefix and an initial name -> counter map. */
+NameSupply::NameSupply(const String& prefix, std::unordered_map<std::string, int> name_map) {
+  data_ = make_object<NameSupplyNode>(prefix, std::move(name_map));
+}
+
+/*!
+ * \brief Mark \p name (optionally prefixed) as taken and return the reserved form.
+ * Any suffix counter previously stored for that name is reset to 0.
+ */
+String NameSupplyNode::ReserveName(const String& name, bool add_prefix) {
+  String reserved = add_prefix ? add_prefix_to_name(name) : name;
+  name_map[reserved] = 0;
+  return reserved;
+}
+
+/*! \brief Return a not-yet-used name derived from \p name (optionally prefixed). */
+String NameSupplyNode::FreshName(const String& name, bool add_prefix) {
+  const String candidate = add_prefix ? add_prefix_to_name(name) : name;
+  return GetUniqueName(candidate);
+}
+
+/*! \brief Whether \p name (optionally prefixed) is already recorded in this supply. */
+bool NameSupplyNode::ContainsName(const String& name, bool add_prefix) {
+  const String key = add_prefix ? add_prefix_to_name(name) : name;
+  return name_map.count(key) != 0;
+}
+
+/*! \brief Return "<prefix_>_<name>", or \p name unchanged when prefix_ is empty. */
+String NameSupplyNode::add_prefix_to_name(const String& name) {
+  if (prefix_.empty()) return name;
+
+  ICHECK(name.defined());
+  std::ostringstream joined;
+  joined << prefix_ << "_" << name;
+  return joined.str();
+}
+
+/*!
+ * \brief Claim \p name if unused, otherwise the first free "<name>_<k>".
+ *
+ * Dots are replaced by underscores first. The counter stored for the base name
+ * remembers the last suffix tried, so repeated requests continue from it.
+ */
+std::string NameSupplyNode::GetUniqueName(std::string name) {
+  for (char& ch : name) {
+    if (ch == '.') ch = '_';
+  }
+  auto base = name_map.find(name);
+  if (base == name_map.end()) {
+    // First request for this name: take it as-is.
+    name_map.emplace(name, 0);
+    return name;
+  }
+  // Name is taken: bump its counter until "<name>_<k>" can be inserted. Only failed
+  // inserts happen before `base` is reused, so the iterator stays valid.
+  std::string candidate;
+  do {
+    std::ostringstream os;
+    os << name << "_" << (++base->second);
+    candidate = os.str();
+  } while (!name_map.insert({candidate, 0}).second);
+  return candidate;
+}
+
+TVM_REGISTER_NODE_TYPE(NameSupplyNode);
+
+// Constructor and methods exposed as global packed functions under "ir.*".
+TVM_REGISTER_GLOBAL("ir.NameSupply")
+    .set_body_typed([](String prefix) { return NameSupply(prefix); });
+
+TVM_REGISTER_GLOBAL("ir.NameSupply_FreshName")
+    .set_body_method<NameSupply>(&NameSupplyNode::FreshName);
+
+TVM_REGISTER_GLOBAL("ir.NameSupply_ReserveName")
+    .set_body_method<NameSupply>(&NameSupplyNode::ReserveName);
+
+TVM_REGISTER_GLOBAL("ir.NameSupply_ContainsName")
+    .set_body_method<NameSupply>(&NameSupplyNode::ContainsName);
+
+} // namespace tvm
diff --git a/src/relay/backend/graph_executor_codegen.cc
b/src/relay/backend/graph_executor_codegen.cc
index c72511775a..ab725d82e6 100644
--- a/src/relay/backend/graph_executor_codegen.cc
+++ b/src/relay/backend/graph_executor_codegen.cc
@@ -466,7 +466,7 @@ class GraphExecutorCodegen : public
backend::MemoizedExprTranslator<std::vector<
}
// Compute the operator name, because we used the get unique name when
generating the kernel.
- auto op_name = _GetUniqueName(func_name);
+ auto op_name = name_supply_->FreshName(func_name);
auto node = GraphOpNode::make_node_ptr(op_name, GraphAttrs(), func_name,
inputs, attrs);
return AddNode(node, call);
}
@@ -604,22 +604,6 @@ class GraphExecutorCodegen : public
backend::MemoizedExprTranslator<std::vector<
writer->EndObject();
}
- /*!
- * \brief Get unique name for func
- *
- * \param name
- * \return std::string
- */
- std::string _GetUniqueName(const std::string& name) {
- if (!name_map_.count(name)) {
- name_map_[name] = 1;
- return name;
- }
- auto index = name_map_[name];
- name_map_[name] += 1;
- return _GetUniqueName(name + std::to_string(index));
- }
-
protected:
/*! \brief nodes */
std::vector<GraphObjectPtr> nodes_;
@@ -645,8 +629,8 @@ class GraphExecutorCodegen : public
backend::MemoizedExprTranslator<std::vector<
String mod_name_;
/*! \brief function metadata */
Map<String, FunctionInfo> function_metadata_;
- /*! \brief name map */
- std::unordered_map<std::string, size_t> name_map_;
+ /*! \brief NameSupply */
+ NameSupply name_supply_ = NameSupply("");
};
class GraphExecutorCodegenModule : public runtime::ModuleNode {
diff --git a/src/relay/backend/task_extraction.cc
b/src/relay/backend/task_extraction.cc
index af4b49b4f1..c577e8e356 100644
--- a/src/relay/backend/task_extraction.cc
+++ b/src/relay/backend/task_extraction.cc
@@ -74,9 +74,9 @@ Array<meta_schedule::ExtractedTask> ExtractTask(
});
// Tasks are extracted via post order visit, return the reversed list.
std::reverse(tasks.begin(), tasks.end());
- std::unordered_map<std::string, int> name_map;
+ NameSupply name_supply = NameSupply("");
for (ExtractedTask task : tasks) {
- task->task_name = tec::GetUniqueName(task->task_name, &name_map);
+ task->task_name = name_supply->FreshName(task->task_name);
}
return tasks;
}
diff --git a/src/relay/backend/te_compiler.cc b/src/relay/backend/te_compiler.cc
index 8ca5a32b7f..5c79ed2070 100644
--- a/src/relay/backend/te_compiler.cc
+++ b/src/relay/backend/te_compiler.cc
@@ -92,6 +92,7 @@
#include <tvm/driver/driver_api.h>
#include <tvm/ir/attrs.h>
#include <tvm/ir/function.h>
+#include <tvm/ir/name_supply.h>
#include <tvm/relay/analysis.h>
#include <tvm/relay/attrs/annotation.h>
#include <tvm/relay/attrs/call.h>
@@ -134,30 +135,33 @@ TVM_REGISTER_OBJECT_TYPE(TECompilerNode);
class TECompilerImpl : public TECompilerNode {
public:
- explicit TECompilerImpl(Optional<IRModule> opt_mod) {
+ explicit TECompilerImpl(Optional<IRModule> opt_mod, Optional<String>
opt_mod_name) {
+ String mod_name = opt_mod_name.value_or("");
+ NameSupply name_supply = NameSupply(mod_name /* prefix */);
+ global_var_supply_ = GlobalVarSupply(name_supply);
// Make sure we don't collide with any existing globals in the module.
if (opt_mod) {
for (const auto& kv : opt_mod.value()->functions) {
- name_map_[kv.first->name_hint] = 1;
+ global_var_supply_->name_supply_->ReserveName(kv.first->name_hint,
false);
}
}
}
// Lower the function.
- CachedFunc Lower(const CCacheKey& key, std::function<String(String)>
mangle_fn) {
- return LowerInternal(key, mangle_fn)->cached_func;
+ CachedFunc Lower(const CCacheKey& key) {
+ return LowerInternal(key, global_var_supply_)->cached_func;
}
+ // TODO(gigiblender): Only to be called by the global TE compiler.
+ // Remove this when the global TE compiler is removed.
CachedFunc Lower(const CCacheKey& key, const String mod_name) {
- auto mangle_fn = [mod_name](String name) { return
runtime::get_name_mangled(mod_name, name); };
-
- return Lower(key, mangle_fn);
+ global_var_supply_->name_supply_->prefix_ = mod_name;
+ return LowerInternal(key, global_var_supply_)->cached_func;
}
// For now, build one module per function.
PackedFunc JIT(const CCacheKey& key) final {
- auto mangle_fn = [](String name) { return name; };
- CCacheValue value = LowerInternal(key, mangle_fn);
+ CCacheValue value = LowerInternal(key, GlobalVarSupply(NameSupply("")));
if (value->packed_func != nullptr) {
return value->packed_func;
}
@@ -335,7 +339,7 @@ class TECompilerImpl : public TECompilerNode {
private:
// implement lowered func
- CCacheValue LowerInternal(const CCacheKey& key,
std::function<String(String)> mangle_fn) {
+ CCacheValue LowerInternal(const CCacheKey& key, GlobalVarSupply
global_var_supply) {
VLOG(1) << "lowering:" << std::endl
<< PrettyPrint(key->source_func) << std::endl
<< "for target:" << std::endl
@@ -360,7 +364,7 @@ class TECompilerImpl : public TECompilerNode {
if (opt_compiler.defined()) {
// Don't compile now since we don't have anywhere to put the resulting
runtime module.
// Instead place the original definition in the cache and wait for
LowerExternalFunctions.
- IRModule ir_module;
+ IRModule ir_module({}, {});
Optional<String> opt_global_symbol =
key->source_func->GetAttr<String>(tvm::attr::kGlobalSymbol);
ICHECK(opt_global_symbol.defined()) << "External function has not been
attached a name yet.";
@@ -369,7 +373,7 @@ class TECompilerImpl : public TECompilerNode {
// the module's globals. Furthermore, the external codegen tool must
bind the compiled
// function to the "global_symbol" attribute on the source_func. So do
not use GetUniqueName
// here.
- auto global_var = GlobalVar(opt_global_symbol.value());
+ auto global_var =
global_var_supply->UniqueGlobalFor(opt_global_symbol.value(), false);
global_var->checked_type_ = key->source_func->checked_type();
ir_module->Add(global_var, key->source_func);
value->cached_func = CachedFunc(key->target, global_var, {}, {},
te::Schedule{nullptr},
@@ -388,10 +392,7 @@ class TECompilerImpl : public TECompilerNode {
With<Target> target_scope(key->target);
ICHECK(!value->cached_func.defined());
- value->cached_func = PrimFuncFor(key->source_func, key->target,
[&](std::string name) {
- auto mangled = mangle_fn(name);
- return GetUniqueName(mangled, &name_map_);
- });
+ value->cached_func = PrimFuncFor(key->source_func, key->target,
global_var_supply);
if (value->cached_func->prim_func.defined()) {
VLOG(1) << "Lowering PrimFunc";
@@ -443,16 +444,11 @@ class TECompilerImpl : public TECompilerNode {
}
auto func_name = value->cached_func->prim_fn_var->name_hint;
VLOG(1) << "scheduling";
- IRModule scheduled_module =
- tvm::LowerSchedule(value->cached_func->schedule, all_args,
func_name, binds);
+ IRModule scheduled_module =
tvm::LowerSchedule(value->cached_func->schedule, all_args,
+ func_name, binds,
global_var_supply);
scheduled_module->Update(tir::transform::BindParams(all_consts)(scheduled_module));
- // Unfortunately the above machinery creates its own GlobalVars instead
of using *the*
- // GlobalVar we established above. Fix this before the confusion spreads
any further.
- // TODO(mbs): LowerSchedule should be given prim_fn_gvar instead of
func_name.
for (const auto& kv : scheduled_module->functions) {
- GlobalVar global_var = kv.first->name_hint ==
value->cached_func->prim_fn_var->name_hint
- ? value->cached_func->prim_fn_var
- : kv.first;
+ GlobalVar global_var = kv.first;
auto func = kv.second;
// Propagate the structural hash of the relay function to the tir
// function so associations can be made between the two.
@@ -498,9 +494,7 @@ class TECompilerImpl : public TECompilerNode {
using tvm::transform::PassContext;
With<PassContext> fresh_pass_ctx_scope(PassContext::Create());
- value->cached_func = ShapeFuncFor(key->source_func, key->target,
[&](std::string name) {
- return GetUniqueName(name, &name_map_);
- });
+ value->cached_func = ShapeFuncFor(key->source_func, key->target,
global_var_supply_);
ICHECK(
value->cached_func->funcs->Lookup(value->cached_func->prim_fn_var).as<tir::PrimFuncNode>());
@@ -527,8 +521,8 @@ class TECompilerImpl : public TECompilerNode {
/*! \brief compiler cache lock*/
std::mutex mutex_;
- /*! \brief internal name map to get an unique name */
- std::unordered_map<std::string, int> name_map_;
+ /*! \brief internal GlobalVarSupply to get unique GlobalVars */
+ GlobalVarSupply global_var_supply_;
/*! \brief internal compiler cache */
std::unordered_map<CCacheKey, CCacheValue> cache_;
/*! \brief internal compiler cache for shape funcs */
@@ -539,15 +533,16 @@ class TECompilerImpl : public TECompilerNode {
Map<GlobalVar, String> device_contexts_;
};
-TECompiler::TECompiler(Optional<IRModule> opt_mod) {
- auto object = make_object<TECompilerImpl>(std::move(opt_mod));
+TECompiler::TECompiler(Optional<IRModule> opt_mod, Optional<String> mod_name) {
+ auto object = make_object<TECompilerImpl>(std::move(opt_mod),
std::move(mod_name));
data_ = object;
}
/*! \brief The global TE compiler */
// TODO(mbs): To be terminated with extreme prejudice.
TECompiler& TECompiler::Global() {
- static TECompiler* inst = new
TECompiler(make_object<TECompilerImpl>(Optional<IRModule>()));
+ static TECompiler* inst =
+ new TECompiler(make_object<TECompilerImpl>(Optional<IRModule>(),
Optional<String>()));
return *inst;
}
TVM_REGISTER_PASS_CONFIG_OPTION("relay.backend.use_auto_scheduler", Bool);
@@ -629,12 +624,11 @@ using AnalysisRemapping = std::unordered_map<Expr, Expr,
ObjectHash, ObjectEqual
class LowerTensorExprMutator : public DeviceAwareExprMutator {
public:
LowerTensorExprMutator(IRModule module, ProcessFn process_fn,
CompilationConfig config,
- String module_name, TECompiler compiler)
+ TECompiler compiler)
: DeviceAwareExprMutator(module),
module_(std::move(module)),
process_fn_(std::move(process_fn)),
config_(std::move(config)),
- module_name_(std::move(module_name)),
compiler_(std::move(compiler)),
debug_op_(Op::Get("debug")) {}
@@ -925,7 +919,7 @@ class LowerTensorExprMutator : public
DeviceAwareExprMutator {
// codegen.
CCacheKey key(Downcast<Function>(primitive_func), target,
GetVirtualDevice(GetRef<Call>(call_node)));
- CachedFunc cfunc = compiler_->Lower(key, module_name_);
+ CachedFunc cfunc = compiler_->Lower(key);
ICHECK(cfunc.defined());
return MakeLoweredCall(primitive_func, cfunc->prim_fn_var,
std::move(new_args),
call_node->span, target, cfunc->funcs->functions);
@@ -942,17 +936,15 @@ class LowerTensorExprMutator : public
DeviceAwareExprMutator {
// module we'll ultimately emit for each required device-type. Note that a
primitive may be
// lowered for multiple device types, each which will be assigned a fresh
var.
std::unordered_map<const VarNode*, BaseFunc> primitive_functions_;
- String module_name_;
TECompiler compiler_;
// Cache ops that need to be frequently used later to reduce lookup overhead.
const Op& debug_op_;
};
-Pass LowerTensorExpr(const String& module_name, TECompiler compiler, ProcessFn
process_fn,
- CompilationConfig config) {
+Pass LowerTensorExpr(TECompiler compiler, ProcessFn process_fn,
CompilationConfig config) {
runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)>
pass_func =
[=](Function func, IRModule module, PassContext ctx) {
- LowerTensorExprMutator lower_te(module, process_fn, config,
module_name, compiler);
+ LowerTensorExprMutator lower_te(module, process_fn, config, compiler);
return Downcast<Function>(lower_te.Mutate(func));
};
return CreateFunctionPass(pass_func, 0, "LowerTensorExpr", {});
@@ -1184,7 +1176,7 @@ void UpdateFunctionMetadata(BaseFunc func,
/*! \brief Main lowering driving. */
IRModule LowerTE(const IRModule& module, const String& module_name, ProcessFn
process_fn,
CompilationConfig config) {
- TECompiler compiler(module);
+ TECompiler compiler(module, module_name);
// TODO(mbs): This is all unnecessarily convoluted. Better would be to
accumulate the rewritten
// module as we go (including rewritten Functions, lowered primitives, and
runtime modules
@@ -1199,7 +1191,7 @@ IRModule LowerTE(const IRModule& module, const String&
module_name, ProcessFn pr
// - Calls to functions tagged with "Primitive" are compiled to PrimFuncs,
and calls updated
// (using call_lowered convention).
IRModule updated_module =
- LowerTensorExpr(module_name, compiler, std::move(process_fn),
std::move(config))(module);
+ LowerTensorExpr(compiler, std::move(process_fn),
std::move(config))(module);
// The Functions tagged with "Compiler" are now residing in the cache ready
to be
// compiled by LowerExternalFunctions. However we still need a record of
them in the
diff --git a/src/relay/backend/te_compiler.h b/src/relay/backend/te_compiler.h
index 5d16da4b8b..f2ba84014a 100644
--- a/src/relay/backend/te_compiler.h
+++ b/src/relay/backend/te_compiler.h
@@ -73,7 +73,7 @@ class TECompilerNode : public Object {
* \param key The key to the cached function.
* \return The result.
*/
- virtual CachedFunc Lower(const CCacheKey& key, std::function<String(String)>
mangle_fn) = 0;
+ virtual CachedFunc Lower(const CCacheKey& key) = 0;
/*!
* \brief Get lowered result.
@@ -137,7 +137,7 @@ class TECompilerNode : public Object {
/*! \brief cache entry used in compile engine */
class TECompiler : public ObjectRef {
public:
- explicit TECompiler(Optional<IRModule> opt_mod = {});
+ explicit TECompiler(Optional<IRModule> opt_mod = {}, Optional<String>
mod_name = {});
explicit TECompiler(ObjectPtr<Object> n) : ObjectRef(n) {}
TECompilerNode* operator->() { return
static_cast<TECompilerNode*>(get_mutable()); }
using ContainerType = TECompilerNode;
diff --git a/src/relay/backend/te_compiler_cache.cc
b/src/relay/backend/te_compiler_cache.cc
index bfb351f82b..da52d94b4e 100644
--- a/src/relay/backend/te_compiler_cache.cc
+++ b/src/relay/backend/te_compiler_cache.cc
@@ -137,8 +137,7 @@ class LowerToTECompute : public
backend::MemoizedExprTranslator<Array<te::Tensor
explicit LowerToTECompute(Target target)
: target_(target), device_copy_op_(Op::Get("device_copy")) {}
- Array<te::Tensor> Lower(const Function& relay_func,
- std::function<std::string(std::string)> renamer) {
+ Array<te::Tensor> Lower(const Function& relay_func) {
for (Var param : relay_func->params) {
Array<tvm::te::Tensor> inputs;
for (const auto& ttype : FlattenTupleType(param->checked_type())) {
@@ -327,15 +326,15 @@ class ScheduleBuilder : public ExprVisitor {
}
}
- CachedFunc Create(const Function& relay_func,
std::function<std::string(std::string)> renamer) {
+ CachedFunc Create(const Function& relay_func, GlobalVarSupply
global_var_supply) {
LowerToTECompute lower_te_compute(target_);
- Array<te::Tensor> tensor_outs = lower_te_compute.Lower(relay_func,
renamer);
+ Array<te::Tensor> tensor_outs = lower_te_compute.Lower(relay_func);
Array<te::Tensor> fn_inputs = lower_te_compute.fn_inputs_;
VisitExpr(relay_func->body);
// TODO(mbs): This should be the definitive global by which the PrimFunc
is known and
// no other GlobalVar ctors should appear inside the lowering machinery.
- auto prim_fn_var = GlobalVar(renamer(lower_te_compute.candidate_name_));
+ auto prim_fn_var =
global_var_supply->FreshGlobal(lower_te_compute.candidate_name_);
prim_fn_var->checked_type_ = relay_func->checked_type();
// Fusion over tupled results may leave identity relationships
@@ -402,8 +401,9 @@ class ScheduleBuilder : public ExprVisitor {
}
}
- return CachedFunc(target_, prim_fn_var, fn_inputs, tensor_outs, schedule,
prim_func, {},
- IRModule(Map<GlobalVar, BaseFunc>({})),
lower_te_compute.constant_tensors_);
+ IRModule funcs = IRModule(Map<GlobalVar, BaseFunc>({}));
+ return CachedFunc(target_, prim_fn_var, fn_inputs, tensor_outs, schedule,
prim_func, {}, funcs,
+ lower_te_compute.constant_tensors_);
}
void VisitExpr_(const CallNode* call_node) final {
@@ -446,8 +446,8 @@ class ScheduleBuilder : public ExprVisitor {
* The funcs field in cache is not yet populated.
*/
CachedFunc PrimFuncFor(const Function& source_func, const Target& target,
- std::function<std::string(std::string)> renamer) {
- return ScheduleBuilder(target).Create(source_func, renamer);
+ GlobalVarSupply global_var_supply) {
+ return ScheduleBuilder(target).Create(source_func, global_var_supply);
}
// Creates shape function from functor.
@@ -456,7 +456,7 @@ class MakeShapeFunc : public
backend::MemoizedExprTranslator<Array<te::Tensor>>
MakeShapeFunc() {}
CachedFunc Create(const Function& prim_func, const Target& target,
- std::function<std::string(std::string)> renamer) {
+ GlobalVarSupply global_var_supply) {
VLOG_CONTEXT << "MakeShapeFunc";
TShapeDataDependent shape_func_param_states;
@@ -527,8 +527,7 @@ class MakeShapeFunc : public
backend::MemoizedExprTranslator<Array<te::Tensor>>
// TODO(mbs): This should be the definitive global by which the PrimFunc
is known and
// no other GlobalVar ctors should appear inside the lowering machinery.
- auto func_name = renamer(candidate_name);
- auto prim_fn_gvar = GlobalVar(func_name);
+ auto prim_fn_gvar = global_var_supply->FreshGlobal(candidate_name);
// Gather the result types, again from the p.o.v. of the shape function
rather than
// the primitive it is derived for.
@@ -569,19 +568,10 @@ class MakeShapeFunc : public
backend::MemoizedExprTranslator<Array<te::Tensor>>
With<PassContext> fresh_pass_ctx_scope(PassContext::Create());
std::unordered_map<te::Tensor, tir::Buffer> binds;
- IRModule lowered_module = tvm::LowerSchedule(schedule, all_args,
func_name, binds);
-
- // Unfortunately the above machinery creates its own GlobalVars instead of
using *the*
- // GlobalVar we established above. Fix this before the confusion spreads
any further.
- // TODO(mbs): LowerSchedule should be given prim_fn_gvar instead of
func_name.
- IRModule fixed_lowered_module;
- for (const auto& kv : lowered_module->functions) {
- GlobalVar global_var =
- kv.first->name_hint == prim_fn_gvar->name_hint ? prim_fn_gvar :
kv.first;
- fixed_lowered_module->Add(global_var, kv.second);
- }
+ IRModule lowered_module =
+ tvm::LowerSchedule(schedule, all_args, prim_fn_gvar->name_hint, binds,
global_var_supply);
return CachedFunc(target, prim_fn_gvar, inputs, outputs, schedule,
tir::PrimFunc{nullptr},
- shape_func_param_states, fixed_lowered_module);
+ shape_func_param_states, lowered_module);
}
Array<te::Tensor> VisitExpr(const Expr& expr) final {
@@ -791,15 +781,14 @@ class MakeShapeFunc : public
backend::MemoizedExprTranslator<Array<te::Tensor>>
};
CachedFunc ShapeFuncFor(const Function& prim_func, const Target& target,
- std::function<std::string(std::string)> renamer) {
- return MakeShapeFunc().Create(prim_func, target, renamer);
+ GlobalVarSupply global_var_supply) {
+ return MakeShapeFunc().Create(prim_func, target, global_var_supply);
}
std::pair<Array<te::Tensor>, std::string> LowerTECompute(const Function&
source_func, Target target,
bool return_inputs) {
LowerToTECompute lower_te_compute(target);
- Array<te::Tensor> outputs =
- lower_te_compute.Lower(source_func, [](std::string name) { return name;
});
+ Array<te::Tensor> outputs = lower_te_compute.Lower(source_func);
// Following ScheduleBuilder, remove placeholder ops from outputs.
tvm::Array<te::Tensor> tensor_outs;
for (const auto& tensor : outputs) {
@@ -814,34 +803,10 @@ std::pair<Array<te::Tensor>, std::string>
LowerTECompute(const Function& source_
return std::make_pair(tensor_outs, lower_te_compute.candidate_name_);
}
-/*!
- * \brief Get unique name from name.
- * \param name The orginal name.
- * \return Updated name which is unique.
- */
-std::string GetUniqueName(std::string name, std::unordered_map<std::string,
int>* name_map_) {
- for (size_t i = 0; i < name.length(); ++i) {
- if (name[i] == '.') name[i] = '_';
- }
- while (true) {
- auto it = name_map_->find(name);
- if (it == name_map_->end()) {
- (*name_map_)[name] = 1;
- return name;
- } else {
- std::ostringstream os;
- os << name << "_" << it->second;
- ++(it->second);
- name = os.str();
- }
- }
- return name;
-}
-
TVM_REGISTER_GLOBAL("relay.backend.LowerToTE").set_body_typed([](Function
prim_func) {
auto tgt = tvm::Target("ext_dev");
LowerToTECompute lower_te_compute(tgt);
- auto outputs = lower_te_compute.Lower(prim_func, [&](std::string name) {
return name; });
+ auto outputs = lower_te_compute.Lower(prim_func);
return CachedFunc(tgt, GlobalVar(lower_te_compute.candidate_name_),
lower_te_compute.fn_inputs_,
outputs, te::Schedule(), tir::PrimFunc(), {},
IRModule(Map<GlobalVar, BaseFunc>({})),
lower_te_compute.constant_tensors_);
diff --git a/src/relay/backend/te_compiler_cache.h
b/src/relay/backend/te_compiler_cache.h
index ac26198260..894a5f5be5 100644
--- a/src/relay/backend/te_compiler_cache.h
+++ b/src/relay/backend/te_compiler_cache.h
@@ -24,6 +24,7 @@
#ifndef TVM_RELAY_BACKEND_TE_COMPILER_CACHE_H_
#define TVM_RELAY_BACKEND_TE_COMPILER_CACHE_H_
+#include <tvm/ir/name_supply.h>
#include <tvm/node/structural_equal.h>
#include <tvm/node/structural_hash.h>
#include <tvm/relay/analysis.h>
@@ -227,13 +228,10 @@ std::pair<Array<te::Tensor>, std::string>
LowerTECompute(const Function& source_
* The funcs field in cache is not yet populated.
*/
CachedFunc PrimFuncFor(const Function& source_func, const Target& target,
- std::function<std::string(std::string)> renamer);
+ GlobalVarSupply global_var_supply);
CachedFunc ShapeFuncFor(const Function& prim_func, const Target& target,
- std::function<std::string(std::string)> renamer);
-
-// TODO(mbs): Bring name uniqification under control -- this is replicated in
quite a few places.
-std::string GetUniqueName(std::string name, std::unordered_map<std::string,
int>* name_map);
+ GlobalVarSupply global_var_supply);
// implementations
inline size_t CCacheKeyNode::Hash() const {
diff --git a/src/relay/ir/dataflow_matcher.cc b/src/relay/ir/dataflow_matcher.cc
index b2776a41c5..42fec9e27a 100644
--- a/src/relay/ir/dataflow_matcher.cc
+++ b/src/relay/ir/dataflow_matcher.cc
@@ -22,6 +22,7 @@
* \brief The dataflow pattern matcher for Relay.
*/
+#include <tvm/ir/global_var_supply.h>
#include <tvm/relay/analysis.h>
#include <tvm/relay/dataflow_matcher.h>
#include <tvm/relay/expr_functor.h>
@@ -438,15 +439,8 @@ Expr InferType(const Expr& expr) {
Expr InferTypeWithModule(const Expr& expr, const IRModule& m) {
IRModule mod(m->functions, m->type_definitions, m->Imports());
- int idx = 0;
- std::string gv_name;
- do {
- std::ostringstream oss;
- oss << "_tmp" << idx;
- gv_name = oss.str();
- ++idx;
- } while (mod->ContainGlobalVar(gv_name));
- GlobalVar gvar(gv_name);
+ GlobalVarSupply global_var_supply = GlobalVarSupply(mod);
+ GlobalVar gvar = global_var_supply->FreshGlobal("_tmp", false);
BaseFunc func;
if (expr.as<FunctionNode>()) {
func = Downcast<Function>(expr);
diff --git a/src/relay/transforms/auto_scheduler_layout_rewrite.cc
b/src/relay/transforms/auto_scheduler_layout_rewrite.cc
index c538dac048..25111cec8e 100644
--- a/src/relay/transforms/auto_scheduler_layout_rewrite.cc
+++ b/src/relay/transforms/auto_scheduler_layout_rewrite.cc
@@ -126,8 +126,7 @@ Expr AutoSchedulerLayoutRewriter::VisitExpr_(const CallNode* n) {
CHECK(f) << "Could not find auto_scheduler.enter_layout_rewrite
function.";
(*f)();
- tec::PrimFuncFor(GetRef<Function>(func), Target::Current(),
- [](std::string name) { return name; });
+ tec::PrimFuncFor(GetRef<Function>(func), Target::Current(),
GlobalVarSupply(NameSupply("")));
f = runtime::Registry::Get("auto_scheduler.exit_layout_rewrite");
CHECK(f) << "Could not find ansor.exit_layout_rewrite function.";
diff --git a/src/relay/transforms/meta_schedule_layout_rewrite.cc b/src/relay/transforms/meta_schedule_layout_rewrite.cc
index b817802f17..8a70f224c6 100644
--- a/src/relay/transforms/meta_schedule_layout_rewrite.cc
+++ b/src/relay/transforms/meta_schedule_layout_rewrite.cc
@@ -127,8 +127,7 @@ Expr MetaScheduleLayoutRewriter::VisitExpr_(const CallNode* call) {
if (const auto* func = call->op.as<FunctionNode>()) {
LayoutIndexQueue* self = LayoutIndexQueue::Global();
self->queue_.clear();
- tec::PrimFuncFor(GetRef<Function>(func), Target::Current(),
- [](std::string name) { return name; });
+ tec::PrimFuncFor(GetRef<Function>(func), Target::Current(),
GlobalVarSupply(NameSupply("")));
if (!self->queue_.empty()) {
std::deque<tir::IndexMap> queue = std::move(self->queue_);
self->queue_.clear();
diff --git a/src/relay/transforms/partition_graph.cc b/src/relay/transforms/partition_graph.cc
index bc1ed518d4..e2df2e4272 100644
--- a/src/relay/transforms/partition_graph.cc
+++ b/src/relay/transforms/partition_graph.cc
@@ -333,14 +333,15 @@ class Partitioner : public MixedModeMutator {
WithAttr(std::move(global_region_func), attr::kCompiler,
tvm::runtime::String(target));
global_region_func = WithAttr(std::move(global_region_func),
attr::kInline, tvm::Integer(1));
- std::string fname = name;
- ICHECK(!module_->ContainGlobalVar(fname)) << "Global function " << fname
<< " already exists";
+ GlobalVarSupply global_var_supply = GlobalVarSupply(module_);
+ GlobalVar glob_func = global_var_supply->FreshGlobal(name, false);
+ ICHECK(!module_->ContainGlobalVar(glob_func->name_hint))
+ << "Global function " << glob_func->name_hint << " already exists";
// Create a global function and add it to the IRModule for the region.
// This way we lift the functions that should be handled by external
// codegen to the module scope and rely on the pass manager to prevent
// relay function level passes (i.e. simplify inference and fusion)
// optimizing it.
- GlobalVar glob_func(fname);
module_->Add(glob_func, global_region_func);
module_ = relay::transform::InferType()(module_);
diff --git a/src/target/source/codegen_c.cc b/src/target/source/codegen_c.cc
index 3fe7fa50d3..6b1ca81d85 100644
--- a/src/target/source/codegen_c.cc
+++ b/src/target/source/codegen_c.cc
@@ -45,33 +45,33 @@ void CodeGenC::InitFuncState(const PrimFunc& f) {
void CodeGenC::ReserveKeywordsAsUnique() {
// skip the first underscore, so SSA variable starts from _1
- GetUniqueName("_");
- GetUniqueName("extern");
- GetUniqueName("void");
- GetUniqueName("int");
- GetUniqueName("float");
- GetUniqueName("double");
- GetUniqueName("char");
- GetUniqueName("unsigned");
- GetUniqueName("short");
- GetUniqueName("long");
- GetUniqueName("if");
- GetUniqueName("else");
- GetUniqueName("switch");
- GetUniqueName("case");
- GetUniqueName("default");
- GetUniqueName("for");
- GetUniqueName("do");
- GetUniqueName("while");
- GetUniqueName("goto");
- GetUniqueName("register");
- GetUniqueName("continue");
- GetUniqueName("break");
- GetUniqueName("typedef");
- GetUniqueName("struct");
- GetUniqueName("enum");
- GetUniqueName("union");
- GetUniqueName("return");
+ name_supply_->ReserveName("_");
+ name_supply_->ReserveName("extern");
+ name_supply_->ReserveName("void");
+ name_supply_->ReserveName("int");
+ name_supply_->ReserveName("float");
+ name_supply_->ReserveName("double");
+ name_supply_->ReserveName("char");
+ name_supply_->ReserveName("unsigned");
+ name_supply_->ReserveName("short");
+ name_supply_->ReserveName("long");
+ name_supply_->ReserveName("if");
+ name_supply_->ReserveName("else");
+ name_supply_->ReserveName("switch");
+ name_supply_->ReserveName("case");
+ name_supply_->ReserveName("default");
+ name_supply_->ReserveName("for");
+ name_supply_->ReserveName("do");
+ name_supply_->ReserveName("while");
+ name_supply_->ReserveName("goto");
+ name_supply_->ReserveName("register");
+ name_supply_->ReserveName("continue");
+ name_supply_->ReserveName("break");
+ name_supply_->ReserveName("typedef");
+ name_supply_->ReserveName("struct");
+ name_supply_->ReserveName("enum");
+ name_supply_->ReserveName("union");
+ name_supply_->ReserveName("return");
}
void CodeGenC::AddFunction(const PrimFunc& f) {
diff --git a/src/target/source/codegen_c_host.cc b/src/target/source/codegen_c_host.cc
index 54975d166e..a47158d378 100644
--- a/src/target/source/codegen_c_host.cc
+++ b/src/target/source/codegen_c_host.cc
@@ -42,7 +42,7 @@
namespace tvm {
namespace codegen {
-CodeGenCHost::CodeGenCHost() { module_name_ =
GetUniqueName("__tvm_module_ctx"); }
+CodeGenCHost::CodeGenCHost() { module_name_ =
name_supply_->FreshName("__tvm_module_ctx"); }
void CodeGenCHost::Init(bool output_ssa, bool emit_asserts, std::string
target_str,
const std::unordered_set<std::string>& devices) {
@@ -207,8 +207,8 @@ void CodeGenCHost::PrintGetFuncFromBackend(const std::string& func_name,
void CodeGenCHost::PrintFuncCall(const std::string& packed_func_name, int
num_args) {
this->PrintIndent();
- std::string ret_val = GetUniqueName("ret_val");
- std::string ret_type_code = GetUniqueName("ret_type_code");
+ std::string ret_val = name_supply_->FreshName("ret_val");
+ std::string ret_type_code = name_supply_->FreshName("ret_type_code");
this->stream << "TVMValue " << ret_val << ";\n";
this->PrintIndent();
this->stream << "int " << ret_type_code << ";\n";
@@ -231,8 +231,8 @@ void CodeGenCHost::PrintFuncCall(const std::string& packed_func_name, int num_ar
void CodeGenCHost::PrintFuncCallC(const std::string& packed_func_name, int
num_args,
const std::string& resource_handle_name) {
this->PrintIndent();
- std::string ret_val = GetUniqueName("ret_val");
- std::string ret_type_code = GetUniqueName("ret_type_code");
+ std::string ret_val = name_supply_->FreshName("ret_val");
+ std::string ret_type_code = name_supply_->FreshName("ret_type_code");
this->stream << "TVMValue " << ret_val << ";\n";
this->PrintIndent();
this->stream << "int " << ret_type_code << ";\n";
@@ -264,7 +264,7 @@ std::string CodeGenCHost::GetPackedName(const CallNode* op) {
if (it != declared_globals_.end()) {
unique_name = it->second;
} else {
- unique_name = GetUniqueName(packed_func_name);
+ unique_name = name_supply_->FreshName(packed_func_name);
declared_globals_[packed_func_name] = unique_name;
decl_stream << "static void* " << unique_name << " = NULL;\n";
}
@@ -310,7 +310,7 @@ CodeGenCHost::FunctionInfo CodeGenCHost::GetFunctionInfo(const CallNode* op,
void CodeGenCHost::VisitExpr_(const CallNode* op, std::ostream& os) { //
NOLINT(*)
if (op->op.same_as(builtin::tvm_stack_alloca())) {
- std::string stack_name = GetUniqueName("stack");
+ std::string stack_name = name_supply_->FreshName("stack");
const std::string& type = op->args[0].as<StringImmNode>()->value;
const IntImmNode* num = op->args[1].as<IntImmNode>();
ICHECK(num != nullptr);
diff --git a/src/target/source/codegen_cuda.cc b/src/target/source/codegen_cuda.cc
index dde1d112ed..7350292167 100644
--- a/src/target/source/codegen_cuda.cc
+++ b/src/target/source/codegen_cuda.cc
@@ -43,8 +43,8 @@ CodeGenCUDA::CodeGenCUDA() { restrict_keyword_ = "__restrict__"; }
void CodeGenCUDA::Init(bool output_ssa) {
CodeGenC::Init(output_ssa);
- vid_global_barrier_state_ =
GetUniqueName(runtime::symbol::tvm_global_barrier_state);
- vid_global_barrier_expect_ = GetUniqueName("__barrier_expect");
+ vid_global_barrier_state_ =
name_supply_->FreshName(runtime::symbol::tvm_global_barrier_state);
+ vid_global_barrier_expect_ = name_supply_->FreshName("__barrier_expect");
ICHECK_EQ(vid_global_barrier_state_,
runtime::symbol::tvm_global_barrier_state);
}
@@ -403,7 +403,7 @@ void CodeGenCUDA::PrintType(DataType t, std::ostream& os) {  // NOLINT(*)
void CodeGenCUDA::PrintVecBinaryOp(const std::string& op, DataType t, PrimExpr
lhs, PrimExpr rhs,
std::ostream& os) { // NOLINT(*)
// Delcare the result.
- std::string sret = GetUniqueName("_");
+ std::string sret = name_supply_->FreshName("_");
this->PrintIndent();
this->PrintType(t, stream);
stream << ' ' << sret << ";\n";
@@ -555,7 +555,7 @@ void CodeGenCUDA::PrintStorageSync(const CallNode* op) {
this->PrintIndent();
this->stream << "atomicAdd(&" << vid_global_barrier_state_ << ", 1);\n";
this->PrintIndent();
- std::string ptr = GetUniqueName("pf");
+ std::string ptr = name_supply_->FreshName("pf");
this->stream << "volatile unsigned* " << ptr << " = &" <<
vid_global_barrier_state_ << ";\n";
this->PrintIndent();
this->stream << vid_global_barrier_expect_ << " += " << num_blocks <<
";\n";
@@ -589,7 +589,7 @@ void CodeGenCUDA::VisitExpr_(const CastNode* op, std::ostream& os) {
// We could emit make_float4 like calls, but the emitted code looks
// too compact to read. Emit this as vectorized unary ops.
- std::string sret = GetUniqueName("_");
+ std::string sret = name_supply_->FreshName("_");
this->PrintIndent();
this->PrintType(target_ty, stream);
stream << ' ' << sret << ";\n";
@@ -631,7 +631,7 @@ void CodeGenCUDA::PrintCallExtern(Type ret_type, String global_symbol, const Arr
// v = __ret;
//
// Declare the result vector.
- std::string sret = GetUniqueName("_");
+ std::string sret = name_supply_->FreshName("_");
this->PrintIndent();
this->PrintType(ret_dtype, stream);
stream << ' ' << sret << ";\n";
@@ -1138,7 +1138,7 @@ void CodeGenCUDA::VisitExpr_(const SelectNode* op, std::ostream& os) {
ICHECK(op->false_value->dtype == op->dtype && op->true_value->dtype ==
op->dtype &&
op->dtype.lanes() == op->condition.dtype().lanes());
- std::string r_var = GetUniqueName("_");
+ std::string r_var = name_supply_->FreshName("_");
this->PrintIndent();
this->PrintType(op->dtype, stream);
stream << ' ' << r_var << ";\n";
diff --git a/src/target/source/codegen_metal.cc b/src/target/source/codegen_metal.cc
index 0ec6179115..b3ca3eb461 100644
--- a/src/target/source/codegen_metal.cc
+++ b/src/target/source/codegen_metal.cc
@@ -55,7 +55,7 @@ void CodeGenMetal::AddFunction(const PrimFunc& f) {
// clear previous generated state.
this->InitFuncState(f);
// skip the first underscore, so SSA variable starts from _1
- GetUniqueName("_");
+ name_supply_->FreshName("_");
// add to alloc buffer type.
auto global_symbol = f->GetAttr<String>(tvm::attr::kGlobalSymbol);
@@ -94,7 +94,7 @@ void CodeGenMetal::AddFunction(const PrimFunc& f) {
}
// Setup normal arguments.
size_t nargs = f->params.size() - num_buffer;
- std::string varg = GetUniqueName("arg");
+ std::string varg = name_supply_->FreshName("arg");
if (nargs != 0) {
std::string arg_buf_type = static_cast<std::string>(global_symbol.value())
+ "_args_t";
stream << " constant " << arg_buf_type << "& " << varg << " [[ buffer("
<< num_buffer
@@ -127,8 +127,8 @@ void CodeGenMetal::AddFunction(const PrimFunc& f) {
decl_stream << "};\n\n";
}
// Setup the thread group info.
- ICHECK_EQ(GetUniqueName("threadIdx"), "threadIdx");
- ICHECK_EQ(GetUniqueName("blockIdx"), "blockIdx");
+ ICHECK_EQ(name_supply_->FreshName("threadIdx"), "threadIdx");
+ ICHECK_EQ(name_supply_->FreshName("blockIdx"), "blockIdx");
int work_dim = 0;
auto thread_axis =
f->GetAttr<Array<tir::IterVar>>(tir::attr::kDeviceThreadAxis).value();
diff --git a/src/target/source/codegen_source_base.cc b/src/target/source/codegen_source_base.cc
index 2353d2e6ba..75833fd936 100644
--- a/src/target/source/codegen_source_base.cc
+++ b/src/target/source/codegen_source_base.cc
@@ -28,34 +28,14 @@ namespace tvm {
namespace codegen {
void CodeGenSourceBase::ClearFuncState() {
- name_alloc_map_.clear();
+ name_supply_ = NameSupply("");
ssa_assign_map_.clear();
var_idmap_.clear();
scope_mark_.clear();
}
-std::string CodeGenSourceBase::GetUniqueName(std::string prefix) {
- for (size_t i = 0; i < prefix.size(); ++i) {
- if (prefix[i] == '.') prefix[i] = '_';
- }
- auto it = name_alloc_map_.find(prefix);
- if (it != name_alloc_map_.end()) {
- while (true) {
- std::ostringstream os;
- os << prefix << (++it->second);
- std::string name = os.str();
- if (name_alloc_map_.count(name) == 0) {
- prefix = name;
- break;
- }
- }
- }
- name_alloc_map_[prefix] = 0;
- return prefix;
-}
-
std::string CodeGenSourceBase::SSAGetID(std::string src, DataType t) {
- if (name_alloc_map_.count(src)) return src;
+ if (name_supply_->ContainsName(src)) return src;
auto it = ssa_assign_map_.find(src);
if (it != ssa_assign_map_.end()) {
if (scope_mark_.at(it->second.scope_id)) {
@@ -63,7 +43,7 @@ std::string CodeGenSourceBase::SSAGetID(std::string src, DataType t) {
}
}
SSAEntry e;
- e.vid = GetUniqueName("_");
+ e.vid = name_supply_->FreshName("_");
e.scope_id = static_cast<int>(scope_mark_.size() - 1);
ssa_assign_map_[src] = e;
this->PrintIndent();
@@ -74,7 +54,7 @@ std::string CodeGenSourceBase::SSAGetID(std::string src, DataType t) {
std::string CodeGenSourceBase::AllocVarID(const tir::VarNode* v) {
ICHECK(!var_idmap_.count(v)) << "Need input to be in SSA form dup " <<
v->name_hint;
std::string key = v->name_hint;
- std::string vid = GetUniqueName(key);
+ std::string vid = name_supply_->FreshName(key);
std::replace(vid.begin(), vid.end(), ':', '_');
std::replace(vid.begin(), vid.end(), '-', '_');
std::replace(vid.begin(), vid.end(), '.', '_');
diff --git a/src/target/source/codegen_source_base.h b/src/target/source/codegen_source_base.h
index 66287f9ad1..2fd0abcd68 100644
--- a/src/target/source/codegen_source_base.h
+++ b/src/target/source/codegen_source_base.h
@@ -25,6 +25,7 @@
#ifndef TVM_TARGET_SOURCE_CODEGEN_SOURCE_BASE_H_
#define TVM_TARGET_SOURCE_CODEGEN_SOURCE_BASE_H_
+#include <tvm/ir/name_supply.h>
#include <tvm/runtime/metadata.h>
#include <tvm/target/codegen.h>
#include <tvm/tir/expr.h>
@@ -97,12 +98,6 @@ class CodeGenSourceBase {
* \param t The type of the expression.
*/
std::string SSAGetID(std::string src, DataType t);
- /*!
- * \brief get a unique name with the corresponding prefix
- * \param prefix The prefix of the name
- * \return The returned name.
- */
- std::string GetUniqueName(std::string prefix);
/*!
* \brief mark the beginning of a new scope
* \return The scope id.
@@ -127,12 +122,12 @@ class CodeGenSourceBase {
std::ostringstream stream;
/*! \brief name of each variable */
std::unordered_map<const tir::VarNode*, std::string> var_idmap_;
+ /*! \brief NameSupply for allocation */
+ NameSupply name_supply_ = NameSupply("");
private:
/*! \brief assignment map of ssa */
std::unordered_map<std::string, SSAEntry> ssa_assign_map_;
- /*! \brief name allocation map */
- std::unordered_map<std::string, int> name_alloc_map_;
/*! \brief array to check whether we are inside certain scope */
std::vector<bool> scope_mark_;
/*! \brief The current indentation value */
diff --git a/src/te/operation/create_primfunc.cc b/src/te/operation/create_primfunc.cc
index 68b25a1653..55df71a805 100644
--- a/src/te/operation/create_primfunc.cc
+++ b/src/te/operation/create_primfunc.cc
@@ -18,6 +18,7 @@
*/
#include <tvm/arith/analyzer.h>
+#include <tvm/ir/name_supply.h>
#include <tvm/runtime/registry.h>
#include <tvm/tir/function.h>
#include <tvm/tir/stmt_functor.h>
@@ -61,8 +62,10 @@ struct CreateFuncInfo {
ProducerToBufferTransformer transformer;
/*! \brief The buffers should be allocated at function root. */
Array<Buffer> root_alloc;
- /*! \brief The count map to make block name unique. */
- std::unordered_map<String, int> name_count;
+ /*! \brief The NameSupply to make block name unique. */
+ NameSupply name_supply = NameSupply("");
+
+ String FreshName(String base_name) { return
name_supply->FreshName(base_name); }
explicit CreateFuncInfo(Array<te::Tensor> arg_list)
: arg_list(std::move(arg_list)), transformer(tensor2buffers) {}
@@ -71,16 +74,6 @@ struct CreateFuncInfo {
return std::any_of(arg_list.begin(), arg_list.end(),
[&tensor](const te::Tensor& arg) { return tensor ==
arg; });
}
-
- String GetUniqueName(const String& prefix) {
- String unique_prefix = prefix;
- auto it = name_count.find(prefix);
- while (name_count.count(unique_prefix)) {
- unique_prefix = prefix + "_" + std::to_string(++it->second);
- }
- name_count[unique_prefix] = 0;
- return unique_prefix;
- }
};
class LayoutFreePlaceholdersNormalizer : public StmtMutator {
@@ -179,7 +172,7 @@ BlockRealize GenerateBlockFromTensors(const te::ComputeOp& compute_op,
Stmt body;
if (const auto* reduce = expr_body.as<ReduceNode>()) {
// Case 1. Reduce compute
- block_name = info->GetUniqueName(compute_op->name);
+ block_name = info->FreshName(compute_op->name);
int n_buffers = buffers.size();
Array<PrimExpr> lhs;
@@ -236,7 +229,7 @@ BlockRealize GenerateBlockFromTensors(const te::ComputeOp& compute_op,
} else {
// Case 2. Data parallel compute
ICHECK_EQ(tensors.size(), 1);
- block_name = info->GetUniqueName(tensors[0]->GetNameHint());
+ block_name = info->FreshName(tensors[0]->GetNameHint());
const PrimExpr& compute_body = Substitute(info->transformer(expr_body),
var_map);
body = BufferStore(info->tensor2buffers[tensors[0]],
analyzer->Simplify(compute_body), indices);
}
@@ -387,7 +380,7 @@ Stmt GenerateStmtFromExternOp(const te::ExternOp& extern_op, CreateFuncInfo* inf
Block(/*iter_vars=*/{},
/*reads=*/std::move(reads),
/*writes=*/std::move(writes),
- /*name_hint=*/info->GetUniqueName(extern_op->name),
+ /*name_hint=*/info->FreshName(extern_op->name),
/*body=*/std::move(body),
/*init=*/NullOpt,
/*alloc_buffers=*/{},
diff --git a/src/tir/transforms/split_host_device.cc b/src/tir/transforms/split_host_device.cc
index 85845616f1..dc56a3ce76 100644
--- a/src/tir/transforms/split_host_device.cc
+++ b/src/tir/transforms/split_host_device.cc
@@ -21,6 +21,7 @@
* \file split_host_device.cc
* \brief Split device function from host.
*/
+#include <tvm/ir/global_var_supply.h>
#include <tvm/ir/transform.h>
#include <tvm/runtime/registry.h>
#include <tvm/target/target.h>
@@ -302,12 +303,15 @@ class HostDeviceSplitter : public StmtMutator {
arguments.push_back(var);
}
}
+ GlobalVarSupply global_var_supply = GlobalVarSupply(*device_mod_);
+ GlobalVar kernel_symbol_global =
global_var_supply->FreshGlobal(kernel_symbol, false);
+
PrimFunc device_func(params, Substitute(body, remap_vars));
device_func = WithAttr(std::move(device_func),
tir::attr::kDeviceThreadAxis, m.thread_axis_);
device_func = WithAttr(std::move(device_func), tvm::attr::kCallingConv,
Integer(CallingConv::kDeviceKernelLaunch));
- device_func =
- WithAttr(std::move(device_func), tvm::attr::kGlobalSymbol,
runtime::String(kernel_symbol));
+ device_func = WithAttr(std::move(device_func), tvm::attr::kGlobalSymbol,
+ runtime::String(kernel_symbol_global->name_hint));
device_func = WithAttr(std::move(device_func), tir::attr::kNoAlias,
Integer(1));
device_func = WithAttr(std::move(device_func), tvm::attr::kTarget,
device_target_);
device_func = WithAttr(std::move(device_func), tir::attr::kIsGlobalFunc,
Integer(1));
@@ -315,11 +319,11 @@ class HostDeviceSplitter : public StmtMutator {
device_func =
WithAttr(std::move(device_func),
tir::attr::kDeviceUseDynSharedMemory, Integer(1));
}
- (*device_mod_)->Add(GlobalVar(kernel_symbol), device_func);
+ (*device_mod_)->Add(kernel_symbol_global, device_func);
// generate calls to the device function
Array<PrimExpr> call_args;
- call_args.push_back(StringImm(kernel_symbol));
+ call_args.push_back(StringImm(kernel_symbol_global->name_hint));
for (PrimExpr arg : arguments) {
call_args.push_back(arg);
}
diff --git a/tests/cpp/build_module_test.cc b/tests/cpp/build_module_test.cc
index ff3641cd69..3d2adb2355 100644
--- a/tests/cpp/build_module_test.cc
+++ b/tests/cpp/build_module_test.cc
@@ -52,7 +52,7 @@ TEST(BuildModule, Basic) {
auto target = Target("llvm");
- auto lowered = LowerSchedule(s, args, "func", binds);
+ auto lowered = LowerSchedule(s, args, "func", binds,
GlobalVarSupply(NameSupply("")));
auto module = build(lowered, target, Target());
auto mali_target = Target("opencl -model=Mali-T860MP4@800Mhz -device=mali");
@@ -121,8 +121,9 @@ TEST(BuildModule, Heterogeneous) {
auto args2 = Array<Tensor>({copy, C, elemwise_sub});
std::unordered_map<Tensor, Buffer> binds;
- auto lowered_s1 = LowerSchedule(fcreate_s1(), args1, "elemwise_add", binds);
- auto lowered_s2 = LowerSchedule(fcreate_s2(), args2, "elemwise_sub", binds);
+ GlobalVarSupply global_var_supply = GlobalVarSupply(NameSupply(""));
+ auto lowered_s1 = LowerSchedule(fcreate_s1(), args1, "elemwise_add", binds,
global_var_supply);
+ auto lowered_s2 = LowerSchedule(fcreate_s2(), args2, "elemwise_sub", binds,
global_var_supply);
Map<tvm::Target, IRModule> inputs = {{target_cuda, lowered_s1},
{target_llvm, lowered_s2}};
auto module = build(inputs, Target());
diff --git a/tests/cpp/c_codegen_test.cc b/tests/cpp/c_codegen_test.cc
index 097de862a9..442f76a8cf 100644
--- a/tests/cpp/c_codegen_test.cc
+++ b/tests/cpp/c_codegen_test.cc
@@ -52,7 +52,8 @@ TEST(CCodegen, MainFunctionOrder) {
auto args = Array<Tensor>({A, B, elemwise_add});
std::unordered_map<Tensor, Buffer> binds;
- auto lowered = LowerSchedule(fcreate(), args, "elemwise_add", binds);
+ auto lowered =
+ LowerSchedule(fcreate(), args, "elemwise_add", binds,
GlobalVarSupply(NameSupply("")));
Map<tvm::Target, IRModule> inputs = {{target_c, lowered}};
runtime::Module module = build(inputs, Target());
Array<String> functions = module->GetFunction("get_func_names", false)();
@@ -81,7 +82,8 @@ auto BuildLowered(std::string op_name, tvm::Target target) {
auto args = Array<Tensor>({A, B, op});
std::unordered_map<Tensor, Buffer> binds;
- auto lowered_s = LowerSchedule(fcreate_s(), args, op_name, binds);
+ auto lowered_s =
+ LowerSchedule(fcreate_s(), args, op_name, binds,
GlobalVarSupply(NameSupply("")));
return lowered_s;
}
diff --git a/tests/cpp/name_supply_test.cc b/tests/cpp/name_supply_test.cc
new file mode 100644
index 0000000000..75b9ae86a9
--- /dev/null
+++ b/tests/cpp/name_supply_test.cc
@@ -0,0 +1,129 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include <gtest/gtest.h>
+#include <tvm/ir/global_var_supply.h>
+#include <tvm/ir/module.h>
+#include <tvm/ir/name_supply.h>
+#include <tvm/relay/expr.h>
+#include <tvm/relay/function.h>
+
+using namespace tvm;
+
+NameSupply preambleNameSupply() {
+ NameSupply name_supply = NameSupply("prefix");
+ name_supply->FreshName("test");
+ return name_supply;
+}
+
+TEST(NameSupply, FreshName) {
+ NameSupply name_supply = preambleNameSupply();
+ String fresh = name_supply->FreshName("test");
+
+ EXPECT_EQ(fresh.compare("prefix_test_1"), 0);
+}
+
+TEST(NameSupply, FreshNameNoConflict) {
+ NameSupply name_supply = preambleNameSupply();
+ String fresh = name_supply->FreshName("name_2");
+ EXPECT_EQ(fresh.compare("prefix_name_2"), 0);
+
+ fresh = name_supply->FreshName("name");
+ EXPECT_EQ(fresh.compare("prefix_name"), 0);
+
+ fresh = name_supply->FreshName("name");
+ EXPECT_EQ(fresh.compare("prefix_name_1"), 0);
+
+ fresh = name_supply->FreshName("name");
+ EXPECT_EQ(fresh.compare("prefix_name_3"), 0);
+}
+
+TEST(NameSupply, ContainsName) {
+ NameSupply name_supply = preambleNameSupply();
+
+ EXPECT_TRUE(name_supply->ContainsName("test"));
+ EXPECT_FALSE(name_supply->ContainsName("test_1"));
+}
+
+TEST(NameSupply, ReserveName) {
+ NameSupply name_supply = preambleNameSupply();
+ name_supply->ReserveName("otherTest", false);
+
+ EXPECT_TRUE(name_supply->ContainsName("otherTest", false));
+ EXPECT_FALSE(name_supply->ContainsName("otherTest"));
+
+ name_supply->ReserveName("otherTest");
+ EXPECT_TRUE(name_supply->ContainsName("prefix_otherTest", false));
+ EXPECT_TRUE(name_supply->ContainsName("otherTest"));
+}
+
+GlobalVarSupply preambleVarSupply() {
+ GlobalVarSupply global_var_supply = GlobalVarSupply(NameSupply(""));
+ global_var_supply->FreshGlobal("test");
+ return global_var_supply;
+}
+
+TEST(GlobalVarSupply, FreshGlobal) {
+ GlobalVarSupply global_var_supply = preambleVarSupply();
+ GlobalVar first_var = global_var_supply->FreshGlobal("test");
+ GlobalVar second_var = global_var_supply->FreshGlobal("test");
+
+ EXPECT_FALSE(tvm::StructuralEqual()(first_var, second_var));
+ EXPECT_EQ(first_var->name_hint.compare("test_1"), 0);
+ EXPECT_EQ(second_var->name_hint.compare("test_2"), 0);
+}
+
+TEST(GlobalVarSupply, UniqueGlobalFor) {
+ GlobalVarSupply global_var_supply = preambleVarSupply();
+ GlobalVar first_var = global_var_supply->UniqueGlobalFor("someName");
+ GlobalVar second_var = global_var_supply->UniqueGlobalFor("someName");
+
+ EXPECT_TRUE(tvm::StructuralEqual()(first_var, second_var));
+ EXPECT_EQ(first_var->name_hint.compare("someName"), 0);
+ EXPECT_EQ(second_var->name_hint.compare("someName"), 0);
+}
+
+TEST(GlobalVarSupply, ReserveGlobal) {
+ GlobalVarSupply global_var_supply = preambleVarSupply();
+ GlobalVar var = GlobalVar("someName");
+ global_var_supply->ReserveGlobalVar(var);
+ GlobalVar second_var = global_var_supply->UniqueGlobalFor("someName");
+ GlobalVar third_var = global_var_supply->FreshGlobal("someName");
+
+ EXPECT_TRUE(tvm::StructuralEqual()(var, second_var));
+ EXPECT_FALSE(tvm::StructuralEqual()(var, third_var));
+ EXPECT_EQ(second_var->name_hint.compare("someName"), 0);
+ EXPECT_EQ(third_var->name_hint.compare("someName_1"), 0);
+}
+
+TEST(GlobalVarSupply, BuildIRModule) {
+ auto x = relay::Var("x", relay::Type());
+ auto f = relay::Function(tvm::Array<relay::Var>{x}, x, relay::Type(), {});
+ GlobalVar var = GlobalVar("test");
+ IRModule module = IRModule({{var, f}});
+
+ GlobalVarSupply global_var_supply = GlobalVarSupply(module);
+ GlobalVar second_var = global_var_supply->UniqueGlobalFor("test", false);
+ GlobalVar third_var = global_var_supply->FreshGlobal("test", false);
+
+ EXPECT_TRUE(tvm::StructuralEqual()(var, second_var));
+ EXPECT_FALSE(tvm::StructuralEqual()(var, third_var));
+ EXPECT_EQ(second_var->name_hint.compare("test"), 0);
+ EXPECT_EQ(third_var->name_hint.compare("test_1"), 0);
+}
diff --git a/tests/python/relay/backend/test_pass_lower_te.py b/tests/python/relay/backend/test_pass_lower_te.py
index 310a16e269..fb79c1f2e7 100644
--- a/tests/python/relay/backend/test_pass_lower_te.py
+++ b/tests/python/relay/backend/test_pass_lower_te.py
@@ -203,12 +203,12 @@ def test_lower_extern_with_dynamic_shape():
# Expected:
# def @main(%a: Tensor[(5, 7), float32]) -> Tensor[(?, ?), float32] {
# %0 = (%a, %a);
- # call_lowered(@my_dyn, %0,
metadata={prim_shape_fn_var='shape_func_add', relay_attrs={Extern=1},
prim_shape_fn_states=[2, 2], prim_shape_fn_num_inputs=2,
all_prim_shape_fn_vars=['shape_func_add'], prim_shape_fn_num_outputs=1,
all_prim_fn_vars=[]})
+ # call_lowered(@my_dyn, %0,
metadata={prim_shape_fn_var='test_shape_func_add', relay_attrs={Extern=1},
prim_shape_fn_states=[2, 2], prim_shape_fn_num_inputs=2,
all_prim_shape_fn_vars=['shape_func_add'], prim_shape_fn_num_outputs=1,
all_prim_fn_vars=[]})
# }
# def @my_dyn(%x: Tensor[(5, 7), float32] , %y: Tensor[(5, 7), float32] ,
Extern=1) -> Tensor[(?, ?), float32] {
# add(%x, %y)
# }
- # def @shape_func_add = <shape PrimFunc>
+ # def @test_shape_func_add = <shape PrimFunc>
main = actual_mod["main"]
call = main.body
@@ -218,14 +218,14 @@ def test_lower_extern_with_dynamic_shape():
assert len(call.args[1].fields) == 2
assert call.args[1].fields[0].name_hint == "a"
assert call.args[1].fields[1].name_hint == "a"
- assert call.attrs.metadata["prim_shape_fn_var"].name_hint ==
"shape_func_add"
+ assert call.attrs.metadata["prim_shape_fn_var"].name_hint ==
"test_shape_func_add"
assert call.attrs.metadata["relay_attrs"].Extern == 1
assert len(call.attrs.metadata["prim_shape_fn_states"]) == 2
assert call.attrs.metadata["prim_shape_fn_states"][0] == 2
assert call.attrs.metadata["prim_shape_fn_states"][1] == 2
assert call.attrs.metadata["prim_shape_fn_num_inputs"] == 2
assert len(call.attrs.metadata["all_prim_shape_fn_vars"]) == 1
- assert call.attrs.metadata["all_prim_shape_fn_vars"][0].name_hint ==
"shape_func_add"
+ assert call.attrs.metadata["all_prim_shape_fn_vars"][0].name_hint ==
"test_shape_func_add"
assert call.attrs.metadata["prim_shape_fn_num_outputs"] == 1
assert len(call.attrs.metadata["all_prim_fn_vars"]) == 0
@@ -233,7 +233,7 @@ def test_lower_extern_with_dynamic_shape():
assert isinstance(my_dyn, tvm.relay.Function)
assert my_dyn.attrs["Extern"] == 1
- shape_func_add = actual_mod["shape_func_add"]
+ shape_func_add = actual_mod["test_shape_func_add"]
assert isinstance(shape_func_add, tvm.tir.PrimFunc)
diff --git a/tests/python/relay/test_name_supply.py b/tests/python/relay/test_name_supply.py
new file mode 100644
index 0000000000..688be19c81
--- /dev/null
+++ b/tests/python/relay/test_name_supply.py
@@ -0,0 +1,72 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import tvm
+import tvm.testing
+
+from tvm import relay
+from tvm.ir import GlobalVar, structural_equal
+from tvm.ir.supply import NameSupply
+from tvm.ir.supply import GlobalVarSupply
+
+
+def test_name_supply():  # NameSupply: reserving names, "prefix_" prepending, and "_N" suffixing of fresh names
+ name_supply = NameSupply("prefix")
+ name_supply.reserve_name("test")
+
+ assert name_supply.contains_name("test")
+ assert name_supply.fresh_name("test") == "prefix_test_1"  # "test" already reserved, so the fresh name gets a "_1" suffix
+ assert name_supply.contains_name("test_1")  # the freshly generated name is now recorded as taken
+ assert not name_supply.contains_name("test_1", False)  # with False the "prefix_" is presumably not prepended, so raw "test_1" is unknown
+ assert not name_supply.contains_name("test_2")
+
+
+def test_global_var_supply_from_none():  # GlobalVarSupply built with no NameSupply or IRModule argument
+ var_supply = GlobalVarSupply()
+ global_var = GlobalVar("test")
+ var_supply.reserve_global(global_var)
+
+ assert structural_equal(var_supply.unique_global_for("test"), global_var)  # the reserved var is handed back for its own name
+ assert not structural_equal(var_supply.fresh_global("test"), global_var)  # fresh_global yields a var distinct from the reserved one
+
+
+def test_global_var_supply_from_name_supply():  # GlobalVarSupply backed by an explicit NameSupply with prefix "prefix"
+ name_supply = NameSupply("prefix")
+ var_supply = GlobalVarSupply(name_supply)
+ global_var = GlobalVar("test")
+ var_supply.reserve_global(global_var)  # asserts below: lookup with False finds it, the default (presumably prefixed) lookup does not
+
+ assert structural_equal(var_supply.unique_global_for("test", False),
global_var)
+ assert not structural_equal(var_supply.unique_global_for("test"),
global_var)
+
+
+def test_global_var_supply_from_ir_mod():  # GlobalVarSupply constructed from an IRModule that already defines @test
+ x = relay.var("x")
+ y = relay.var("y")
+ mod = tvm.IRModule()
+ global_var = GlobalVar("test")
+ mod[global_var] = relay.Function([x, y], relay.add(x, y))
+ var_supply = GlobalVarSupply(mod)  # the module's existing global "test" is expected to be known to the supply
+
+ second_global_var = var_supply.fresh_global("test", False)  # requested name clashes with the module's "test"
+
+ assert structural_equal(var_supply.unique_global_for("test", False),
global_var)
+ assert not structural_equal(var_supply.unique_global_for("test"),
global_var)
+ assert not structural_equal(second_global_var, global_var)  # the fresh var must not alias the module's var
+
+
+if __name__ == "__main__":  # lets the test file be run directly as a script
+ tvm.testing.main()