This is an automated email from the ASF dual-hosted git repository.

areusch pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new ecfd9692a0 Unify name mangling in TVM (#12066)
ecfd9692a0 is described below

commit ecfd9692a0e1763af0e34cd6203c0d5ee947e993
Author: Florin Blanaru <[email protected]>
AuthorDate: Thu Aug 11 17:18:46 2022 +0100

    Unify name mangling in TVM (#12066)
    
    * Add NameSupply and GlobalVarSupply
    
    * Build GlobalVarSupply from IRModules instead of having it attached to an IRModule.
    
    * Pass GlobalVarSupply when lowering shape funcs
    
    * Partially replace instantiations of GlobalVar with GlobalVarSupply
    
    * Construct GlobalVarSupply from IRModule
    
    * Add tests for supply
    
    * Add documentation for NameSupply and GlobalVarSupply
    
    Co-authored-by: Florin-Gabriel Blanaru <[email protected]>
---
 include/tvm/driver/driver_api.h                    |  12 +-
 include/tvm/ir/global_var_supply.h                 | 125 ++++++++++++++++++
 include/tvm/ir/module.h                            |  17 +--
 include/tvm/ir/name_supply.h                       | 123 ++++++++++++++++++
 python/tvm/ir/supply.py                            | 141 +++++++++++++++++++++
 src/auto_scheduler/feature.cc                      |   4 +-
 src/contrib/hybrid/codegen_hybrid.cc               | 106 +++++++---------
 src/contrib/hybrid/codegen_hybrid.h                |  10 +-
 src/driver/driver_api.cc                           |  22 ++--
 src/ir/global_var_supply.cc                        | 115 +++++++++++++++++
 src/ir/module.cc                                   |  23 +---
 src/ir/name_supply.cc                              | 108 ++++++++++++++++
 src/relay/backend/graph_executor_codegen.cc        |  22 +---
 src/relay/backend/task_extraction.cc               |   4 +-
 src/relay/backend/te_compiler.cc                   |  74 +++++------
 src/relay/backend/te_compiler.h                    |   4 +-
 src/relay/backend/te_compiler_cache.cc             |  71 +++--------
 src/relay/backend/te_compiler_cache.h              |   8 +-
 src/relay/ir/dataflow_matcher.cc                   |  12 +-
 .../transforms/auto_scheduler_layout_rewrite.cc    |   3 +-
 .../transforms/meta_schedule_layout_rewrite.cc     |   3 +-
 src/relay/transforms/partition_graph.cc            |   7 +-
 src/target/source/codegen_c.cc                     |  54 ++++----
 src/target/source/codegen_c_host.cc                |  14 +-
 src/target/source/codegen_cuda.cc                  |  14 +-
 src/target/source/codegen_metal.cc                 |   8 +-
 src/target/source/codegen_source_base.cc           |  28 +---
 src/target/source/codegen_source_base.h            |  11 +-
 src/te/operation/create_primfunc.cc                |  23 ++--
 src/tir/transforms/split_host_device.cc            |  12 +-
 tests/cpp/build_module_test.cc                     |   7 +-
 tests/cpp/c_codegen_test.cc                        |   6 +-
 tests/cpp/name_supply_test.cc                      | 129 +++++++++++++++++++
 tests/python/relay/backend/test_pass_lower_te.py   |  10 +-
 tests/python/relay/test_name_supply.py             |  72 +++++++++++
 35 files changed, 1052 insertions(+), 350 deletions(-)
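
In short, the new supply API replaces the scattered per-file GetUniqueName
helpers removed below. A minimal sketch using the Python bindings added in
python/tvm/ir/supply.py (the "tvmgen_default_add" spelling is an assumption:
NameSupply joins its prefix and the name with an underscore, per
src/ir/name_supply.cc):

    from tvm.ir.supply import NameSupply, GlobalVarSupply

    # One supply per module; the prefix plays the role of the old
    # per-module name mangling (runtime::get_name_mangled).
    supply = GlobalVarSupply(NameSupply("tvmgen_default"))
    gv = supply.fresh_global("add")
    assert gv.name_hint == "tvmgen_default_add"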

diff --git a/include/tvm/driver/driver_api.h b/include/tvm/driver/driver_api.h
index 48800b193c..fffcab4966 100644
--- a/include/tvm/driver/driver_api.h
+++ b/include/tvm/driver/driver_api.h
@@ -29,6 +29,7 @@
 #ifndef TVM_DRIVER_DRIVER_API_H_
 #define TVM_DRIVER_DRIVER_API_H_
 
+#include <tvm/ir/global_var_supply.h>
 #include <tvm/ir/module.h>
 #include <tvm/ir/transform.h>
 #include <tvm/runtime/packed_func.h>
@@ -99,6 +100,7 @@ TVM_DLL IRModule LowerPrimFunc(tvm::tir::PrimFunc func, const std::string& name,
  * \param args The arguments to the function.
  * \param name The name of the lowered function.
  * \param binds Buffer assignments.
+ * \param global_var_supply The GlobalVarSupply to be used in the module.
  * \param simple_mode Disables the loop partition pass. Defaults to false.
  * \return The result module.
  */
@@ -106,7 +108,7 @@ TVM_DLL IRModule LowerPrimFunc(tvm::tir::PrimFunc func, const std::string& name,
 TVM_DLL IRModule LowerSchedule(te::Schedule sch, const Array<te::Tensor>& args,
                                const std::string& name,
                                const std::unordered_map<te::Tensor, tir::Buffer>& binds,
-                               bool simple_mode = false);
+                               GlobalVarSupply global_var_supply, bool simple_mode = false);
 
 /*!
  * \brief Build an IRModule given a TE schedule, args and binds. This function also applies
@@ -115,13 +117,14 @@ TVM_DLL IRModule LowerSchedule(te::Schedule sch, const Array<te::Tensor>& args,
  * \param args The arguments to the function (Array of Tensor, Buffer and Vars)
  * \param name The name of the lowered function.
  * \param binds Buffer assignments.
+ * \param global_var_supply The GlobalVarSupply to be used in the module.
  * \param simple_mode Disables the loop partition pass. Defaults to false.
  * \return The result module.
  */
 TVM_DLL IRModule LowerSchedule(te::Schedule sch, const Array<ObjectRef>& args,
                                const std::string& name,
                                const std::unordered_map<te::Tensor, tir::Buffer>& binds,
-                               bool simple_mode = false);
+                               GlobalVarSupply global_var_supply, bool simple_mode = false);
 
 /*!
  * \brief Create an IRModule out of a TE Schedule. It does not apply lowering passes. If you want
@@ -130,10 +133,13 @@ TVM_DLL IRModule LowerSchedule(te::Schedule sch, const Array<ObjectRef>& args,
  * \param args The arguments to the function.
  * \param name The name of the lowered function.
  * \param binds Buffer assignments.
+ * \param global_var_supply The GlobalVarSupply to be used in the module and when creating
+ * GlobalVars.
  * \return The result module.
  */
 IRModule ScheduleToModule(te::Schedule sch, const Array<ObjectRef>& args, const std::string& name,
-                          const std::unordered_map<te::Tensor, tir::Buffer>& binds);
+                          const std::unordered_map<te::Tensor, tir::Buffer>& binds,
+                          GlobalVarSupply global_var_supply);
 /*!
  * \brief Build a device and host module for a specific target from an IRModule.
  * \param funcs The functions to be built.
diff --git a/include/tvm/ir/global_var_supply.h b/include/tvm/ir/global_var_supply.h
new file mode 100644
index 0000000000..276c64a0d7
--- /dev/null
+++ b/include/tvm/ir/global_var_supply.h
@@ -0,0 +1,125 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tvm/ir/global_var_supply.h
+ * \brief GlobalVarSupply that can be used to generate unique \class GlobalVar.
+ */
+#ifndef TVM_IR_GLOBAL_VAR_SUPPLY_H_
+#define TVM_IR_GLOBAL_VAR_SUPPLY_H_
+
+#include <string>
+#include <unordered_map>
+
+#include "tvm/ir/expr.h"
+#include "tvm/ir/module.h"
+#include "tvm/ir/name_supply.h"
+
+namespace tvm {
+
+/*!
+ * \brief GlobalVarSupply can be used to generate unique GlobalVars.
+ */
+class GlobalVarSupplyNode : public Object {
+ public:
+  /*!
+   * \brief Empty constructor. Will use an empty NameSupply.
+   */
+  GlobalVarSupplyNode() : GlobalVarSupplyNode(NameSupply("")) {}
+
+  /*!
+   * \brief Constructor.
+   * \param name_supply The NameSupply to use for generating the names of fresh GlobalVars.
+   * \param name_to_var_map An optional map from names to pre-existing GlobalVars.
+   */
+  explicit GlobalVarSupplyNode(NameSupply name_supply,
+                               std::unordered_map<std::string, GlobalVar> name_to_var_map = {});
+
+  /*!
+   * \brief Generates a unique GlobalVar from this supply.
+   * \param name The name from which the name of the GlobalVar is derived.
+   * \param add_prefix If set to true, then the prefix of the contained NameSupply will be prepended
+   * to the name.
+   * \return A unique GlobalVar.
+   */
+  GlobalVar FreshGlobal(String name, bool add_prefix = true);
+
+  /*!
+   * \brief Looks up a GlobalVar with the given name in this supply.
+   * If no entry is found, creates one, places it in the cache, and returns it.
+   * \param name The name of the GlobalVar to search for.
+   * \param add_prefix If set to true, the prefix of the contained NameSupply will be prepended to
+   * the name before performing the search.
+   * \return A cached GlobalVar.
+   */
+  GlobalVar UniqueGlobalFor(const String& name, bool add_prefix = true);
+
+  /*!
+   * \brief Reserves an existing GlobalVar with this supply.
+   * \param var The GlobalVar to be registered.
+   * \param allow_conflict Allow conflict with other GlobalVars that have the same name.
+   */
+  void ReserveGlobalVar(const GlobalVar& var, bool allow_conflict = false);
+
+  void VisitAttrs(AttrVisitor* v) {}
+
+  /*! \brief The NameSupply used to generate unique name hints for GlobalVars. */
+  NameSupply name_supply_;
+
+  static constexpr const char* _type_key = "GlobalVarSupply";
+  static constexpr const bool _type_has_method_sequal_reduce = false;
+  static constexpr const bool _type_has_method_shash_reduce = false;
+  TVM_DECLARE_FINAL_OBJECT_INFO(GlobalVarSupplyNode, Object);
+
+ private:
+  std::unordered_map<std::string, GlobalVar> name_to_var_map_;
+};
+
+/*!
+ * \brief Managed reference class to GlobalVarSupplyNode.
+ * \sa GlobalVarSupplyNode
+ */
+class GlobalVarSupply : public ObjectRef {
+ public:
+  /*!
+   * \brief Constructor.
+   * \param name_supply The NameSupply to be used when generating new GlobalVars.
+   * \param name_to_var_map An optional map from names to pre-existing GlobalVars.
+   */
+  TVM_DLL explicit GlobalVarSupply(const NameSupply& name_supply,
+                                   std::unordered_map<std::string, GlobalVar> name_to_var_map = {});
+
+  /*!
+   * \brief Constructs a supply from an array of IRModules. GlobalVars generated by this supply are
+   * guaranteed not to conflict with any GlobalVars that belong to the modules.
+   * \param modules Array of IRModules.
+   */
+  TVM_DLL explicit GlobalVarSupply(const Array<IRModule>& modules);
+
+  /*!
+   * \brief Constructs a GlobalVarSupply from an IRModule. GlobalVars generated by this supply are
+   * guaranteed not to conflict with GlobalVars that belong to the module.
+   * \param module The IRModule.
+   */
+  TVM_DLL explicit GlobalVarSupply(const IRModule module);
+
+  TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(GlobalVarSupply, ObjectRef, GlobalVarSupplyNode);
+};
+
+}  // namespace tvm
+
+#endif  // TVM_IR_GLOBAL_VAR_SUPPLY_H_
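
A usage sketch of the header above, written against the Python bindings added
later in this patch (assumed importable from tvm.ir.supply; the "my_func_1"
suffix is inferred from the NameSupply uniquing logic in src/ir/name_supply.cc):

    from tvm.ir.supply import NameSupply, GlobalVarSupply

    supply = GlobalVarSupply(NameSupply(""))
    # FreshGlobal always returns a GlobalVar with a new, unique name.
    gv = supply.fresh_global("my_func")
    # UniqueGlobalFor returns the cached GlobalVar for a known name...
    assert supply.unique_global_for("my_func").same_as(gv)
    # ...whereas a second FreshGlobal gets a suffixed one, e.g. "my_func_1".
    gv2 = supply.fresh_global("my_func")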
diff --git a/include/tvm/ir/module.h b/include/tvm/ir/module.h
index f73f2230df..7313b4f783 100644
--- a/include/tvm/ir/module.h
+++ b/include/tvm/ir/module.h
@@ -323,14 +323,6 @@ class IRModuleNode : public Object {
   /*! \brief Helper function for registering a typedef's constructors */
   void RegisterConstructors(const GlobalTypeVar& var, const TypeData& type);
 
-  /*!
-   * \brief Returns a version of \p name which is unique amongst all function definitions in module.
-   *
-   * \param name The original name.
-   * \return Updated name which is unique.
-   */
-  String GetUniqueName(const String& name);
-
   /*! \brief A map from string names to global variables that
    * ensures global uniqueness.
    */
@@ -481,6 +473,15 @@ namespace attr {
 
 // Following are attributes for IRModule only.
 
+/*!
+ * \brief Name of the module
+ *
+ * Type: String
+ *
+ * \sa tvm::runtime::String
+ */
+constexpr const char* kModuleName = "mod_name";
+
 /*!
  * \brief Executor targeted by the module
  *
diff --git a/include/tvm/ir/name_supply.h b/include/tvm/ir/name_supply.h
new file mode 100644
index 0000000000..a85a6fe70a
--- /dev/null
+++ b/include/tvm/ir/name_supply.h
@@ -0,0 +1,123 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tvm/ir/name_supply.h
+ * \brief NameSupply that can be used to generate unique variable names.
+ */
+#ifndef TVM_IR_NAME_SUPPLY_H_
+#define TVM_IR_NAME_SUPPLY_H_
+
+#include <string>
+#include <unordered_map>
+#include <utility>
+
+#include "tvm/ir/expr.h"
+
+namespace tvm {
+
+/*!
+ * \brief NameSupply can be used to generate unique names.
+ */
+class NameSupplyNode : public Object {
+ public:
+  /*!
+   * \brief Empty constructor. Needed by the TVM_REGISTER_NODE_TYPE macro.
+   */
+  NameSupplyNode() = default;
+
+  /*!
+   * \brief Constructor.
+   * \param prefix The prefix to be used with this NameSupply.
+   * \param name_map The map used to guarantee uniqueness.
+   */
+  NameSupplyNode(const String& prefix, std::unordered_map<std::string, int> name_map)
+      : prefix_(prefix), name_map(std::move(name_map)) {}
+
+  /*!
+   * \brief Generates a unique name from this NameSupply.
+   * \param name The name from which the generated name is derived.
+   * \param add_prefix If set to true, then the prefix of this NameSupply will be prepended to the
+   * name.
+   * \return A unique name.
+   */
+  String FreshName(const String& name, bool add_prefix = true);
+
+  /*!
+   * \brief Reserves an existing name with this NameSupply.
+   * \param name The name to be reserved.
+   * \param add_prefix If set to true, then the prefix of this NameSupply will be prepended to the
+   * name before reserving it.
+   * \return The name that was reserved with the NameSupply. It can be different
+   * if a prefix is added.
+   */
+  String ReserveName(const String& name, bool add_prefix = true);
+
+  /*!
+   * \brief Checks if this NameSupply already generated a name.
+   * \param name The name to check.
+   * \param add_prefix If set to true, then the prefix of this NameSupply will be prepended to the
+   * name before checking for it.
+   * \return True if the name has already been generated, false otherwise.
+   */
+  bool ContainsName(const String& name, bool add_prefix = true);
+
+  void VisitAttrs(AttrVisitor* v) {}
+
+  // Prefix prepended to all generated names. It can be empty.
+  std::string prefix_;
+
+  static constexpr const char* _type_key = "NameSupply";
+  static constexpr const bool _type_has_method_sequal_reduce = false;
+  static constexpr const bool _type_has_method_shash_reduce = false;
+  TVM_DECLARE_FINAL_OBJECT_INFO(NameSupplyNode, Object);
+
+ private:
+  /*! \brief Helper function to add the NameSupply prefix to the name. */
+  String add_prefix_to_name(const String& name);
+
+  /*!
+   * \brief Function that will generate a unique name.
+   * \param name The name to be used as a base.
+   * \return A unique name.
+   */
+  std::string GetUniqueName(std::string name);
+
+  /*! \brief A map that is used to generate unique names. */
+  std::unordered_map<std::string, int> name_map;
+};
+
+/*!
+ * \brief Managed reference class to NameSupplyNode.
+ * \sa NameSupplyNode
+ */
+class NameSupply : public ObjectRef {
+ public:
+  /*!
+   * \brief Constructor.
+   * \param prefix The prefix to be used with this NameSupply.
+   * \param name_map An optional map.
+   */
+  TVM_DLL explicit NameSupply(const String& prefix,
+                              std::unordered_map<std::string, int> name_map = {});
+
+  TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(NameSupply, ObjectRef, NameSupplyNode);
+};
+
+}  // namespace tvm
+
+#endif  // TVM_IR_NAME_SUPPLY_H_
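
A minimal sketch of the intended NameSupply semantics via the Python bindings
added below (an assumption; the underscore joining follows add_prefix_to_name
in src/ir/name_supply.cc):

    from tvm.ir.supply import NameSupply

    supply = NameSupply("tvmgen")
    # With add_prefix=True (the default) the prefix is joined with "_".
    assert supply.fresh_name("conv2d") == "tvmgen_conv2d"
    # reserve_name marks a name as taken without generating a fresh one.
    supply.reserve_name("main", add_prefix=False)
    assert supply.contains_name("main", add_prefix=False)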
diff --git a/python/tvm/ir/supply.py b/python/tvm/ir/supply.py
new file mode 100644
index 0000000000..095ac43c03
--- /dev/null
+++ b/python/tvm/ir/supply.py
@@ -0,0 +1,141 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Suppliers that are used to guarantee uniqueness of names and GlobalVars."""
+import tvm
+from tvm import Object, IRModule
+from . import _ffi_api
+
+
+@tvm._ffi.register_object("NameSupply")
+class NameSupply(Object):
+    """NameSupply that can be used to generate unique names.
+
+    Parameters
+    ----------
+    prefix: The prefix to be added to the generated names.
+    """
+
+    def __init__(self, prefix=""):
+        self.__init_handle_by_constructor__(_ffi_api.NameSupply, prefix)
+
+    def fresh_name(self, name, add_prefix=True):
+        """Generates a unique name from this NameSupply.
+
+        Parameters
+        ----------
+        name: String
+            The name from which the generated name is derived.
+
+        add_prefix: bool
+            If set to true, then the prefix of this NameSupply will be prepended to the name.
+        """
+        return _ffi_api.NameSupply_FreshName(self, name, add_prefix)
+
+    def reserve_name(self, name, add_prefix=True):
+        """Reserves an existing name with this NameSupply.
+
+        Parameters
+        ----------
+        name: String
+            The name to be reserved.
+
+        add_prefix: bool
+            If set to true, then the prefix of this NameSupply will be prepended to the name
+            before reserving it.
+        """
+        return _ffi_api.NameSupply_ReserveName(self, name, add_prefix)
+
+    def contains_name(self, name, add_prefix=True):
+        """Checks if this NameSupply already generated a name.
+
+        Parameters
+        ----------
+        name: String
+            The name to check.
+
+        add_prefix: bool
+            If set to true, then the prefix of this NameSupply will be prepended to the name
+            before checking for it.
+        """
+        return _ffi_api.NameSupply_ContainsName(self, name, add_prefix)
+
+
+@tvm._ffi.register_object("GlobalVarSupply")
+class GlobalVarSupply(Object):
+    """GlobalVarSupply that holds a mapping between names and GlobalVars.
+
+    GlobalVarSupply can be used to generate new GlobalVars with a unique name.
+    It can also be used to retrieve previously generated GlobalVars based on a name.
+
+    Parameters
+    ----------
+    value: Union[List[IRModule], IRModule, NameSupply]
+        The IRModules used to build this GlobalVarSupply or a NameSupply.
+    """
+
+    def __init__(self, value=None):
+        if value is None:
+            name_supply = NameSupply("")
+            self.__init_handle_by_constructor__(_ffi_api.GlobalVarSupply_NameSupply, name_supply)
+        elif isinstance(value, NameSupply):
+            self.__init_handle_by_constructor__(_ffi_api.GlobalVarSupply_NameSupply, value)
+        elif isinstance(value, (list, tvm.container.Array)):
+            self.__init_handle_by_constructor__(_ffi_api.GlobalVarSupply_IRModules, value)
+        elif isinstance(value, IRModule):
+            self.__init_handle_by_constructor__(_ffi_api.GlobalVarSupply_IRModule, value)
+
+    def fresh_global(self, name, add_prefix=True):
+        """Generates a unique GlobalVar from this supply.
+
+        Parameters
+        ----------
+        name: String
+            The name from which the name of the GlobalVar is derived.
+
+        add_prefix: bool
+            If set to true, then the prefix of the contained NameSupply will be prepended
+            to the name.
+        """
+        return _ffi_api.GlobalVarSupply_FreshGlobal(self, name, add_prefix)
+
+    def unique_global_for(self, name, add_prefix=True):
+        """Looks up a GlobalVar with the given name in this supply. If no entry
+        is found, creates one, places it in the cache, and returns it.
+
+        Parameters
+        ----------
+        name: String
+            The name of the GlobalVar to search for.
+
+        add_prefix: bool
+            If set to true, the prefix of the contained NameSupply will be prepended to the
+            name before performing the search.
+        """
+        return _ffi_api.GlobalVarSupply_UniqueGlobalFor(self, name, add_prefix)
+
+    def reserve_global(self, global_var, allow_conflict=False):
+        """Reserves an existing GlobalVar with this supply.
+
+        Parameters
+        ----------
+        global_var: GlobalVar
+            The GlobalVar to be registered.
+
+        allow_conflict: bool
+            Allow conflict with other GlobalVars that have the same name.
+        """
+        return _ffi_api.GlobalVarSupply_ReserveGlobalVar(self, global_var, allow_conflict)
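
For example, a supply can be seeded with a module's existing GlobalVars so
that fresh vars cannot collide with them (a sketch; the exact "main_1" output
is an assumption derived from the C++ uniquing logic):

    import tvm
    from tvm import relay
    from tvm.ir.supply import GlobalVarSupply

    mod = tvm.IRModule.from_expr(relay.Function([], relay.const(0)))
    supply = GlobalVarSupply(mod)
    # "main" is already taken by the module, so a fresh var is suffixed.
    gv = supply.fresh_global("main", add_prefix=False)
    assert gv.name_hint == "main_1"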
diff --git a/src/auto_scheduler/feature.cc b/src/auto_scheduler/feature.cc
index ab60aef9ae..c930bf0c4e 100644
--- a/src/auto_scheduler/feature.cc
+++ b/src/auto_scheduler/feature.cc
@@ -27,6 +27,7 @@
 #include <tvm/auto_scheduler/measure.h>
 #include <tvm/auto_scheduler/measure_record.h>
 #include <tvm/driver/driver_api.h>
+#include <tvm/ir/global_var_supply.h>
 #include <tvm/runtime/registry.h>
 #include <tvm/support/parallel_for.h>
 #include <tvm/te/operation.h>
@@ -1371,7 +1372,8 @@ void GetPerStoreFeaturesWorkerFunc(const SearchTask& task, const State& state, i
     auto pass_ctx = tvm::transform::PassContext::Current();
 
     auto mod = ScheduleToModule(sch, Array<ObjectRef>{tensors.begin(), tensors.end()}, name,
-                                std::unordered_map<te::Tensor, te::Buffer>());
+                                std::unordered_map<te::Tensor, te::Buffer>(),
+                                GlobalVarSupply(NameSupply("")));
 
     bool disable_vectorize =
         pass_ctx->GetConfig<Bool>("tir.disable_vectorize", Bool(false)).value();
diff --git a/src/contrib/hybrid/codegen_hybrid.cc b/src/contrib/hybrid/codegen_hybrid.cc
index 24c7ee74cd..79c9e567b4 100644
--- a/src/contrib/hybrid/codegen_hybrid.cc
+++ b/src/contrib/hybrid/codegen_hybrid.cc
@@ -42,24 +42,6 @@ std::string dot_to_underscore(std::string s) {
   return s;
 }
 
-std::string CodeGenHybrid::GetUniqueName(std::string prefix) {
-  prefix = dot_to_underscore(prefix);
-  auto it = ids_allocated_.find(prefix);
-  if (it != ids_allocated_.end()) {
-    while (true) {
-      std::ostringstream os;
-      os << prefix << (++it->second);
-      std::string name = os.str();
-      if (ids_allocated_.count(name) == 0) {
-        prefix = name;
-        break;
-      }
-    }
-  }
-  ids_allocated_[prefix] = 0;
-  return prefix;
-}
-
 std::string CodeGenHybrid::Finish() { return stream.str(); }
 
 void CodeGenHybrid::PrintType(DataType t, std::ostream& os) {
@@ -428,7 +410,7 @@ std::string CodeGenHybrid::GetVarID(const VarNode* v) {
   if (id_map_.count(key)) {
     return id_map_[key];
   }
-  return id_map_[key] = GetUniqueName(v->name_hint);
+  return id_map_[key] = ids_allocated->FreshName(v->name_hint);
 }
 
 std::string CodeGenHybrid::GetTensorID(const Tensor& tensor) {
@@ -440,57 +422,57 @@ std::string CodeGenHybrid::GetTensorID(const Tensor& tensor) {
   if (tensor->op->num_outputs() > 1) {
     name_hint += "_v" + std::to_string(tensor->value_index);
   }
-  return id_map_[key] = GetUniqueName(name_hint);
+  return id_map_[key] = ids_allocated->FreshName(name_hint);
 }
 
 void CodeGenHybrid::ReserveKeywords() {
-  GetUniqueName("def");
-  GetUniqueName("for");
-  GetUniqueName("in");
-  GetUniqueName("range");
-  GetUniqueName("True");
-  GetUniqueName("False");
-  GetUniqueName("unroll");
-  GetUniqueName("const_range");
-  GetUniqueName("parallel");
-  GetUniqueName("vectorize");
-  GetUniqueName("bind");
-  GetUniqueName("threadIdx.x");
-  GetUniqueName("threadIdx.y");
-  GetUniqueName("threadIdx.z");
-  GetUniqueName("blockIdx.x");
-  GetUniqueName("blockIdx.y");
-  GetUniqueName("blockIdx.z");
-  GetUniqueName("vthread");
-  GetUniqueName("allocate");
-  GetUniqueName("output_tensor");
-  GetUniqueName("sqrt");
-  GetUniqueName("log");
-  GetUniqueName("tanh");
-  GetUniqueName("power");
-  GetUniqueName("exp");
-  GetUniqueName("sigmoid");
-  GetUniqueName("popcount");
-  GetUniqueName("likely");
-  GetUniqueName("int8");
-  GetUniqueName("int16");
-  GetUniqueName("int32");
-  GetUniqueName("int64");
-  GetUniqueName("uint8");
-  GetUniqueName("uint16");
-  GetUniqueName("uint32");
-  GetUniqueName("uint64");
-  GetUniqueName("float16");
-  GetUniqueName("float32");
-  GetUniqueName("float64");
-  GetUniqueName("ceil_div");
-  GetUniqueName("max_num_threads");
+  ids_allocated->ReserveName("def");
+  ids_allocated->ReserveName("for");
+  ids_allocated->ReserveName("in");
+  ids_allocated->ReserveName("range");
+  ids_allocated->ReserveName("True");
+  ids_allocated->ReserveName("False");
+  ids_allocated->ReserveName("unroll");
+  ids_allocated->ReserveName("const_range");
+  ids_allocated->ReserveName("parallel");
+  ids_allocated->ReserveName("vectorize");
+  ids_allocated->ReserveName("bind");
+  ids_allocated->ReserveName("threadIdx.x");
+  ids_allocated->ReserveName("threadIdx.y");
+  ids_allocated->ReserveName("threadIdx.z");
+  ids_allocated->ReserveName("blockIdx.x");
+  ids_allocated->ReserveName("blockIdx.y");
+  ids_allocated->ReserveName("blockIdx.z");
+  ids_allocated->ReserveName("vthread");
+  ids_allocated->ReserveName("allocate");
+  ids_allocated->ReserveName("output_tensor");
+  ids_allocated->ReserveName("sqrt");
+  ids_allocated->ReserveName("log");
+  ids_allocated->ReserveName("tanh");
+  ids_allocated->ReserveName("power");
+  ids_allocated->ReserveName("exp");
+  ids_allocated->ReserveName("sigmoid");
+  ids_allocated->ReserveName("popcount");
+  ids_allocated->ReserveName("likely");
+  ids_allocated->ReserveName("int8");
+  ids_allocated->ReserveName("int16");
+  ids_allocated->ReserveName("int32");
+  ids_allocated->ReserveName("int64");
+  ids_allocated->ReserveName("uint8");
+  ids_allocated->ReserveName("uint16");
+  ids_allocated->ReserveName("uint32");
+  ids_allocated->ReserveName("uint64");
+  ids_allocated->ReserveName("float16");
+  ids_allocated->ReserveName("float32");
+  ids_allocated->ReserveName("float64");
+  ids_allocated->ReserveName("ceil_div");
+  ids_allocated->ReserveName("max_num_threads");
 }
 
 void CodeGenHybrid::DumpStmt(const Stmt& stmt, const Array<ObjectRef>& inputs,
                              const Array<Tensor>& outputs, const std::string& name) {
   ReserveKeywords();
-  GetUniqueName(name);
+  ids_allocated->ReserveName(name);
 
   stream << "def " << name << "(";
   for (size_t i = 0; i < inputs.size(); ++i) {
diff --git a/src/contrib/hybrid/codegen_hybrid.h b/src/contrib/hybrid/codegen_hybrid.h
index da45ffb6a8..53026c7fc3 100644
--- a/src/contrib/hybrid/codegen_hybrid.h
+++ b/src/contrib/hybrid/codegen_hybrid.h
@@ -24,6 +24,7 @@
 #ifndef TVM_CONTRIB_HYBRID_CODEGEN_HYBRID_H_
 #define TVM_CONTRIB_HYBRID_CODEGEN_HYBRID_H_
 
+#include <tvm/ir/name_supply.h>
 #include <tvm/target/codegen.h>
 #include <tvm/te/operation.h>
 #include <tvm/te/schedule.h>
@@ -145,19 +146,14 @@ class CodeGenHybrid : public ExprFunctor<void(const PrimExpr&, std::ostream&)>,
   const int tab_{4};
   /*! \brief Print the current indent spaces. */
   inline void PrintIndent();
-  /*! \brief Keys are ids allocated, and values are the suffix to prevent double-name.  */
-  std::map<std::string, int> ids_allocated_;
+  /*! \brief NameSupply for allocated ids.  */
+  NameSupply ids_allocated = NameSupply("");
   /*!
    * \brief Keys are either (tensors, value_index) or (variables, 0).
    *        Values are the corresponding IDs.*/
   std::map<std::pair<const Object*, int>, std::string> id_map_;
   /*! \brief Variables (keys) bound to the threads (values). */
   std::map<const VarNode*, std::string> binds_;
-  /*!
-   * \brief Find an unallocated name for the given prefix.
-   * \param prefix The given prefix.
-   */
-  std::string GetUniqueName(std::string prefix);
   /*! \brief The output code string builder. */
   std::stringstream stream;
   /*!
diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc
index 6f4fb618d3..cbf809a267 100644
--- a/src/driver/driver_api.cc
+++ b/src/driver/driver_api.cc
@@ -261,7 +261,8 @@ IRModule ApplyPasses(IRModule mod, transform::Sequential seq) {
 
 // Convert te schedule to IRModule
 IRModule ScheduleToModule(te::Schedule sch, const Array<ObjectRef>& args, const std::string& name,
-                          const std::unordered_map<te::Tensor, tir::Buffer>& binds) {
+                          const std::unordered_map<te::Tensor, tir::Buffer>& binds,
+                          GlobalVarSupply global_var_supply) {
   sch = sch.normalize();
 
   transform::PassContext pass_ctx = transform::PassContext::Current();
@@ -289,7 +290,8 @@ IRModule ScheduleToModule(te::Schedule sch, const Array<ObjectRef>& args, const
   if (noalias) {
     f = WithAttr(std::move(f), "tir.noalias", Bool(true));
   }
-  return IRModule(Map<GlobalVar, BaseFunc>({{GlobalVar(name), f}}));
+  GlobalVar global_var = global_var_supply->UniqueGlobalFor(name, false);
+  return IRModule(Map<GlobalVar, BaseFunc>({{global_var, f}}));
 }
 
 TVM_REGISTER_GLOBAL("driver.schedule_to_module")
@@ -302,7 +304,8 @@ TVM_REGISTER_GLOBAL("driver.schedule_to_module")
           c_binds.insert({kv.first, kv.second});
         }
       }
-      IRModule mod = ScheduleToModule(std::move(sch), args, name, c_binds);
+      IRModule mod =
+          ScheduleToModule(std::move(sch), args, name, c_binds, GlobalVarSupply(NameSupply("")));
       return mod;
     });
 
@@ -337,17 +340,19 @@ TVM_REGISTER_GLOBAL("driver.lower_primfunc")
     });
 
 IRModule LowerSchedule(te::Schedule sch, const Array<te::Tensor>& args, const std::string& name,
-                       const std::unordered_map<te::Tensor, tir::Buffer>& binds, bool simple_mode) {
+                       const std::unordered_map<te::Tensor, tir::Buffer>& binds,
+                       GlobalVarSupply global_var_supply, bool simple_mode) {
   Array<ObjectRef> ref_args;
   for (ObjectRef x : args) {
     ref_args.push_back(x);
   }
-  return LowerSchedule(std::move(sch), ref_args, name, binds);
+  return LowerSchedule(std::move(sch), ref_args, name, binds, global_var_supply);
 }
 
 IRModule LowerSchedule(te::Schedule sch, const Array<ObjectRef>& args, const std::string& name,
-                       const std::unordered_map<te::Tensor, tir::Buffer>& binds, bool simple_mode) {
-  IRModule mod = ScheduleToModule(std::move(sch), args, name, binds);
+                       const std::unordered_map<te::Tensor, tir::Buffer>& binds,
+                       GlobalVarSupply global_var_supply, bool simple_mode) {
+  IRModule mod = ScheduleToModule(std::move(sch), args, name, binds, global_var_supply);
   // Get the legacy TE pass list
   Array<transform::Pass> pass_list = CreatePassList(simple_mode);
   return LowerWithPassList(mod, pass_list);
@@ -363,7 +368,8 @@ TVM_REGISTER_GLOBAL("driver.lower_schedule")
           c_binds.insert({kv.first, kv.second});
         }
       }
-      return LowerSchedule(std::move(sch), args, name, c_binds, simple_mode);
+      return LowerSchedule(std::move(sch), args, name, c_binds, GlobalVarSupply(NameSupply("")),
+                           simple_mode);
     });
 
 /**
diff --git a/src/ir/global_var_supply.cc b/src/ir/global_var_supply.cc
new file mode 100644
index 0000000000..383d4445ad
--- /dev/null
+++ b/src/ir/global_var_supply.cc
@@ -0,0 +1,115 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file global_var_supply.cc
+ * \brief GlobalVarSupply that can be used to generate unique GlobalVars.
+ */
+#include "tvm/ir/global_var_supply.h"
+
+#include <tvm/runtime/registry.h>
+
+#include <utility>
+
+#include "tvm/ir/expr.h"
+
+namespace tvm {
+GlobalVarSupply::GlobalVarSupply(const NameSupply& name_supply,
+                                 std::unordered_map<std::string, GlobalVar> name_to_var_map) {
+  auto n = make_object<GlobalVarSupplyNode>(name_supply, name_to_var_map);
+  data_ = std::move(n);
+}
+
+std::string GetModuleName(const IRModule& module) {
+  return module->GetAttr<String>(tvm::attr::kModuleName).value_or("tvmgen_default");
+}
+
+GlobalVarSupply::GlobalVarSupply(const Array<IRModule>& modules) : GlobalVarSupply(NameSupply("")) {
+  if (!modules.empty()) {
+    IRModule first_mod = modules.front();
+    this->operator->()->name_supply_->prefix_ = GetModuleName(first_mod);
+  }
+  for (auto& mod : modules) {
+    for (auto kv : mod->functions) {
+      this->operator->()->ReserveGlobalVar(kv.first);
+    }
+  }
+}
+
+GlobalVarSupply::GlobalVarSupply(const IRModule module)
+    : GlobalVarSupply(Array<IRModule>{module}) {}
+
+void GlobalVarSupplyNode::ReserveGlobalVar(const GlobalVar& var, bool allow_conflict) {
+  name_supply_->ReserveName(var->name_hint, false);
+  if (!allow_conflict) {
+    ICHECK(name_to_var_map_.count(var->name_hint) == 0)
+        << "GlobalVar " << var << " conflicts by name in this supply.";
+  }
+  name_to_var_map_[var->name_hint] = var;
+}
+
+GlobalVarSupplyNode::GlobalVarSupplyNode(NameSupply name_supply,
+                                         std::unordered_map<std::string, GlobalVar> name_to_var_map)
+    : name_supply_(std::move(name_supply)), name_to_var_map_(std::move(name_to_var_map)) {}
+
+GlobalVar GlobalVarSupplyNode::UniqueGlobalFor(const String& name, bool add_prefix) {
+  String final_name = name_supply_->ReserveName(name, add_prefix);
+
+  auto it = name_to_var_map_.find(final_name);
+  if (it != name_to_var_map_.end()) {
+    return it->second;
+  } else {
+    GlobalVar var = GlobalVar(final_name);
+    name_to_var_map_.emplace(final_name, var);
+    return var;
+  }
+}
+
+GlobalVar GlobalVarSupplyNode::FreshGlobal(String name, bool add_prefix) {
+  String final_name = name_supply_->FreshName(name, add_prefix);
+  ICHECK(name_to_var_map_.find(final_name) == name_to_var_map_.end())
+      << "GlobalVar already exists for name " << final_name;
+  GlobalVar var = GlobalVar(final_name);
+  name_to_var_map_.emplace(final_name, var);
+  return var;
+}
+
+TVM_REGISTER_NODE_TYPE(GlobalVarSupplyNode);
+
+TVM_REGISTER_GLOBAL("ir.GlobalVarSupply_NameSupply")
+    .set_body_typed([](const NameSupply& name_supply) { return GlobalVarSupply(name_supply); });
+
+TVM_REGISTER_GLOBAL("ir.GlobalVarSupply_IRModule").set_body_typed([](IRModule mod) {
+  return GlobalVarSupply(std::move(mod));
+});
+
+TVM_REGISTER_GLOBAL("ir.GlobalVarSupply_IRModules").set_body_typed([](const Array<IRModule>& mods) {
+  return GlobalVarSupply(mods);
+});
+
+TVM_REGISTER_GLOBAL("ir.GlobalVarSupply_FreshGlobal")
+    .set_body_method<GlobalVarSupply>(&GlobalVarSupplyNode::FreshGlobal);
+
+TVM_REGISTER_GLOBAL("ir.GlobalVarSupply_UniqueGlobalFor")
+    .set_body_method<GlobalVarSupply>(&GlobalVarSupplyNode::UniqueGlobalFor);
+
+TVM_REGISTER_GLOBAL("ir.GlobalVarSupply_ReserveGlobalVar")
+    .set_body_method<GlobalVarSupply>(&GlobalVarSupplyNode::ReserveGlobalVar);
+
+}  // namespace tvm
diff --git a/src/ir/module.cc b/src/ir/module.cc
index 6f2c9f9fe9..8d6de5a536 100644
--- a/src/ir/module.cc
+++ b/src/ir/module.cc
@@ -21,6 +21,7 @@
  * \file  module.cc
  * \brief The global module in Relay.
  */
+#include <tvm/ir/global_var_supply.h>
 #include <tvm/ir/module.h>
 #include <tvm/node/structural_equal.h>
 #include <tvm/runtime/registry.h>
@@ -292,20 +293,6 @@ Constructor IRModuleNode::LookupTag(const int32_t tag) {
   return (*it).second;
 }
 
-String IRModuleNode::GetUniqueName(const String& name) {
-  String result = name;
-  int suffix = 0;
-  while (true) {
-    auto it = global_var_map_.find(result);
-    if (it == global_var_map_.end()) {
-      return result;
-    }
-    std::ostringstream os;
-    os << name << "_" << ++suffix;
-    result = os.str();
-  }
-}
-
 /*!
  * \brief Renames global type/term variables to prefer the GlobalTypeVar/GlobalVar in the lhs
  * ('one') side above the rhs ('two').
@@ -397,12 +384,14 @@ std::pair<IRModule, GlobalVar> IRModule::FromExprInContext(
     func = relay::Function(relay::FreeVars(expr), expr, Type(), relay::FreeTypeVars(expr, mod), {});
   }
 
+  GlobalVar main_gv;
+  auto global_var_supply = GlobalVarSupply(mod);
   if (gv_name.empty()) {
     // Bind function to 'main' (though rename if would clash with existing 'main').
-    gv_name = mod->GetUniqueName("main");
+    main_gv = global_var_supply->FreshGlobal("main", false);
+  } else {
+    main_gv = global_var_supply->UniqueGlobalFor(gv_name, false);
   }
-
-  GlobalVar main_gv(gv_name);
   mod->Add(main_gv, func);
   return {mod, main_gv};
 }
diff --git a/src/ir/name_supply.cc b/src/ir/name_supply.cc
new file mode 100644
index 0000000000..93f568253c
--- /dev/null
+++ b/src/ir/name_supply.cc
@@ -0,0 +1,108 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file name_supply.cc
+ * \brief NameSupply that can be used to generate unique variable names.
+ */
+#include "tvm/ir/name_supply.h"
+
+#include <tvm/runtime/registry.h>
+
+#include <utility>
+
+namespace tvm {
+
+NameSupply::NameSupply(const String& prefix, std::unordered_map<std::string, int> name_map) {
+  auto n = make_object<NameSupplyNode>(prefix, std::move(name_map));
+  data_ = std::move(n);
+}
+
+String NameSupplyNode::ReserveName(const String& name, bool add_prefix) {
+  String final_name = name;
+  if (add_prefix) {
+    final_name = add_prefix_to_name(name);
+  }
+  name_map[final_name] = 0;
+  return final_name;
+}
+
+String NameSupplyNode::FreshName(const String& name, bool add_prefix) {
+  String unique_name = name;
+  if (add_prefix) {
+    unique_name = add_prefix_to_name(name);
+  }
+  unique_name = GetUniqueName(unique_name);
+  return unique_name;
+}
+
+bool NameSupplyNode::ContainsName(const String& name, bool add_prefix) {
+  String unique_name = name;
+  if (add_prefix) {
+    unique_name = add_prefix_to_name(name);
+  }
+
+  return name_map.count(unique_name);
+}
+
+String NameSupplyNode::add_prefix_to_name(const String& name) {
+  if (prefix_.empty()) {
+    return name;
+  }
+
+  std::ostringstream ss;
+  ICHECK(name.defined());
+  ss << prefix_ << "_" << name;
+  return ss.str();
+}
+
+std::string NameSupplyNode::GetUniqueName(std::string name) {
+  for (size_t i = 0; i < name.size(); ++i) {
+    if (name[i] == '.') name[i] = '_';
+  }
+  auto it = name_map.find(name);
+  if (it != name_map.end()) {
+    auto new_name = name;
+    while (!name_map.insert({new_name, 0}).second) {
+      std::ostringstream os;
+      os << name << "_" << (++it->second);
+      new_name = os.str();
+    }
+    return new_name;
+  }
+  name_map[name] = 0;
+  return name;
+}
+
+TVM_REGISTER_NODE_TYPE(NameSupplyNode);
+
+TVM_REGISTER_GLOBAL("ir.NameSupply").set_body_typed([](String prefix) {
+  return NameSupply(prefix);
+});
+
+TVM_REGISTER_GLOBAL("ir.NameSupply_FreshName")
+    .set_body_method<NameSupply>(&NameSupplyNode::FreshName);
+
+TVM_REGISTER_GLOBAL("ir.NameSupply_ReserveName")
+    .set_body_method<NameSupply>(&NameSupplyNode::ReserveName);
+
+TVM_REGISTER_GLOBAL("ir.NameSupply_ContainsName")
+    .set_body_method<NameSupply>(&NameSupplyNode::ContainsName);
+
+}  // namespace tvm
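
The uniquing behavior of GetUniqueName above, illustrated through the Python
binding (a sketch; the concrete outputs are assumptions read off the
implementation):

    from tvm.ir.supply import NameSupply

    supply = NameSupply("")
    # Dots are rewritten to underscores before uniquing.
    assert supply.fresh_name("threadIdx.x") == "threadIdx_x"
    # Colliding requests receive increasing numeric suffixes.
    assert supply.fresh_name("threadIdx.x") == "threadIdx_x_1"
    assert supply.fresh_name("threadIdx.x") == "threadIdx_x_2"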
diff --git a/src/relay/backend/graph_executor_codegen.cc b/src/relay/backend/graph_executor_codegen.cc
index c72511775a..ab725d82e6 100644
--- a/src/relay/backend/graph_executor_codegen.cc
+++ b/src/relay/backend/graph_executor_codegen.cc
@@ -466,7 +466,7 @@ class GraphExecutorCodegen : public backend::MemoizedExprTranslator<std::vector<
     }
 
     // Compute the operator name, because we used the get unique name when generating the kernel.
-    auto op_name = _GetUniqueName(func_name);
+    auto op_name = name_supply_->FreshName(func_name);
     auto node = GraphOpNode::make_node_ptr(op_name, GraphAttrs(), func_name, inputs, attrs);
     return AddNode(node, call);
   }
@@ -604,22 +604,6 @@ class GraphExecutorCodegen : public backend::MemoizedExprTranslator<std::vector<
     writer->EndObject();
   }
 
-  /*!
-   * \brief Get unique name for func
-   *
-   * \param name
-   * \return std::string
-   */
-  std::string _GetUniqueName(const std::string& name) {
-    if (!name_map_.count(name)) {
-      name_map_[name] = 1;
-      return name;
-    }
-    auto index = name_map_[name];
-    name_map_[name] += 1;
-    return _GetUniqueName(name + std::to_string(index));
-  }
-
  protected:
   /*! \brief nodes */
   std::vector<GraphObjectPtr> nodes_;
@@ -645,8 +629,8 @@ class GraphExecutorCodegen : public backend::MemoizedExprTranslator<std::vector<
   String mod_name_;
   /*! \brief function metadata */
   Map<String, FunctionInfo> function_metadata_;
-  /*! \brief name map */
-  std::unordered_map<std::string, size_t> name_map_;
+  /*! \brief NameSupply */
+  NameSupply name_supply_ = NameSupply("");
 };
 
 class GraphExecutorCodegenModule : public runtime::ModuleNode {
diff --git a/src/relay/backend/task_extraction.cc b/src/relay/backend/task_extraction.cc
index af4b49b4f1..c577e8e356 100644
--- a/src/relay/backend/task_extraction.cc
+++ b/src/relay/backend/task_extraction.cc
@@ -74,9 +74,9 @@ Array<meta_schedule::ExtractedTask> ExtractTask(
   });
   // Tasks are extracted via post order visit, return the reversed list.
   std::reverse(tasks.begin(), tasks.end());
-  std::unordered_map<std::string, int> name_map;
+  NameSupply name_supply = NameSupply("");
   for (ExtractedTask task : tasks) {
-    task->task_name = tec::GetUniqueName(task->task_name, &name_map);
+    task->task_name = name_supply->FreshName(task->task_name);
   }
   return tasks;
 }
diff --git a/src/relay/backend/te_compiler.cc b/src/relay/backend/te_compiler.cc
index 8ca5a32b7f..5c79ed2070 100644
--- a/src/relay/backend/te_compiler.cc
+++ b/src/relay/backend/te_compiler.cc
@@ -92,6 +92,7 @@
 #include <tvm/driver/driver_api.h>
 #include <tvm/ir/attrs.h>
 #include <tvm/ir/function.h>
+#include <tvm/ir/name_supply.h>
 #include <tvm/relay/analysis.h>
 #include <tvm/relay/attrs/annotation.h>
 #include <tvm/relay/attrs/call.h>
@@ -134,30 +135,33 @@ TVM_REGISTER_OBJECT_TYPE(TECompilerNode);
 
 class TECompilerImpl : public TECompilerNode {
  public:
-  explicit TECompilerImpl(Optional<IRModule> opt_mod) {
+  explicit TECompilerImpl(Optional<IRModule> opt_mod, Optional<String> opt_mod_name) {
+    String mod_name = opt_mod_name.value_or("");
+    NameSupply name_supply = NameSupply(mod_name /* prefix */);
+    global_var_supply_ = GlobalVarSupply(name_supply);
     // Make sure we don't collide with any existing globals in the module.
     if (opt_mod) {
       for (const auto& kv : opt_mod.value()->functions) {
-        name_map_[kv.first->name_hint] = 1;
+        global_var_supply_->name_supply_->ReserveName(kv.first->name_hint, false);
       }
     }
   }
 
   // Lower the function.
-  CachedFunc Lower(const CCacheKey& key, std::function<String(String)> mangle_fn) {
-    return LowerInternal(key, mangle_fn)->cached_func;
+  CachedFunc Lower(const CCacheKey& key) {
+    return LowerInternal(key, global_var_supply_)->cached_func;
   }
 
+  // TODO(gigiblender): Only to be called by the global TE compiler.
+  //  Remove this when the global TE compiler is removed.
   CachedFunc Lower(const CCacheKey& key, const String mod_name) {
-    auto mangle_fn = [mod_name](String name) { return runtime::get_name_mangled(mod_name, name); };
-
-    return Lower(key, mangle_fn);
+    global_var_supply_->name_supply_->prefix_ = mod_name;
+    return LowerInternal(key, global_var_supply_)->cached_func;
   }
 
   // For now, build one module per function.
   PackedFunc JIT(const CCacheKey& key) final {
-    auto mangle_fn = [](String name) { return name; };
-    CCacheValue value = LowerInternal(key, mangle_fn);
+    CCacheValue value = LowerInternal(key, GlobalVarSupply(NameSupply("")));
     if (value->packed_func != nullptr) {
       return value->packed_func;
     }
@@ -335,7 +339,7 @@ class TECompilerImpl : public TECompilerNode {
 
  private:
   // implement lowered func
-  CCacheValue LowerInternal(const CCacheKey& key, std::function<String(String)> mangle_fn) {
+  CCacheValue LowerInternal(const CCacheKey& key, GlobalVarSupply global_var_supply) {
     VLOG(1) << "lowering:" << std::endl
             << PrettyPrint(key->source_func) << std::endl
             << "for target:" << std::endl
@@ -360,7 +364,7 @@ class TECompilerImpl : public TECompilerNode {
     if (opt_compiler.defined()) {
       // Don't compile now since we don't have anywhere to put the resulting runtime module.
       // Instead place the original definition in the cache and wait for LowerExternalFunctions.
-      IRModule ir_module;
+      IRModule ir_module({}, {});
       Optional<String> opt_global_symbol =
           key->source_func->GetAttr<String>(tvm::attr::kGlobalSymbol);
       ICHECK(opt_global_symbol.defined()) << "External function has not been attached a name yet.";
@@ -369,7 +373,7 @@ class TECompilerImpl : public TECompilerNode {
       // the module's globals. Furthermore, the external codegen tool must bind the compiled
       // function to the "global_symbol" attribute on the source_func. So do not use GetUniqueName
       // here.
-      auto global_var = GlobalVar(opt_global_symbol.value());
+      auto global_var = global_var_supply->UniqueGlobalFor(opt_global_symbol.value(), false);
       global_var->checked_type_ = key->source_func->checked_type();
       ir_module->Add(global_var, key->source_func);
       value->cached_func = CachedFunc(key->target, global_var, {}, {}, te::Schedule{nullptr},
@@ -388,10 +392,7 @@ class TECompilerImpl : public TECompilerNode {
     With<Target> target_scope(key->target);
 
     ICHECK(!value->cached_func.defined());
-    value->cached_func = PrimFuncFor(key->source_func, key->target, [&](std::string name) {
-      auto mangled = mangle_fn(name);
-      return GetUniqueName(mangled, &name_map_);
-    });
+    value->cached_func = PrimFuncFor(key->source_func, key->target, global_var_supply);
 
     if (value->cached_func->prim_func.defined()) {
       VLOG(1) << "Lowering PrimFunc";
@@ -443,16 +444,11 @@ class TECompilerImpl : public TECompilerNode {
       }
       auto func_name = value->cached_func->prim_fn_var->name_hint;
       VLOG(1) << "scheduling";
-      IRModule scheduled_module =
-          tvm::LowerSchedule(value->cached_func->schedule, all_args, func_name, binds);
+      IRModule scheduled_module = tvm::LowerSchedule(value->cached_func->schedule, all_args,
+                                                     func_name, binds, global_var_supply);
       scheduled_module->Update(tir::transform::BindParams(all_consts)(scheduled_module));
-      // Unfortunately the above machinery creates its own GlobalVars instead of using *the*
-      // GlobalVar we established above. Fix this before the confusion spreads any further.
-      // TODO(mbs): LowerSchedule should be given prim_fn_gvar instead of func_name.
       for (const auto& kv : scheduled_module->functions) {
-        GlobalVar global_var = kv.first->name_hint == value->cached_func->prim_fn_var->name_hint
-                                   ? value->cached_func->prim_fn_var
-                                   : kv.first;
+        GlobalVar global_var = kv.first;
         auto func = kv.second;
         // Propagate the structural hash of the relay function to the tir
         // function so associations can be made between the two.
@@ -498,9 +494,7 @@ class TECompilerImpl : public TECompilerNode {
 
     using tvm::transform::PassContext;
     With<PassContext> fresh_pass_ctx_scope(PassContext::Create());
-    value->cached_func = ShapeFuncFor(key->source_func, key->target, [&](std::string name) {
-      return GetUniqueName(name, &name_map_);
-    });
+    value->cached_func = ShapeFuncFor(key->source_func, key->target, global_var_supply_);
 
     ICHECK(
         value->cached_func->funcs->Lookup(value->cached_func->prim_fn_var).as<tir::PrimFuncNode>());
@@ -527,8 +521,8 @@ class TECompilerImpl : public TECompilerNode {
 
   /*! \brief compiler cache lock*/
   std::mutex mutex_;
-  /*! \brief internal name map to get an unique name */
-  std::unordered_map<std::string, int> name_map_;
+  /*! \brief internal GlobalVarSupply to get unique GlobalVars  */
+  GlobalVarSupply global_var_supply_;
   /*! \brief internal compiler cache */
   std::unordered_map<CCacheKey, CCacheValue> cache_;
   /*! \brief internal compiler cache for shape funcs */
@@ -539,15 +533,16 @@ class TECompilerImpl : public TECompilerNode {
   Map<GlobalVar, String> device_contexts_;
 };
 
-TECompiler::TECompiler(Optional<IRModule> opt_mod) {
-  auto object = make_object<TECompilerImpl>(std::move(opt_mod));
+TECompiler::TECompiler(Optional<IRModule> opt_mod, Optional<String> mod_name) {
+  auto object = make_object<TECompilerImpl>(std::move(opt_mod), std::move(mod_name));
   data_ = object;
 }
 
 /*! \brief The global TE compiler */
 // TODO(mbs): To be terminated with extreme prejudice.
 TECompiler& TECompiler::Global() {
-  static TECompiler* inst = new TECompiler(make_object<TECompilerImpl>(Optional<IRModule>()));
+  static TECompiler* inst =
+      new TECompiler(make_object<TECompilerImpl>(Optional<IRModule>(), Optional<String>()));
   return *inst;
 }
 TVM_REGISTER_PASS_CONFIG_OPTION("relay.backend.use_auto_scheduler", Bool);
@@ -629,12 +624,11 @@ using AnalysisRemapping = std::unordered_map<Expr, Expr, ObjectHash, ObjectEqual
 class LowerTensorExprMutator : public DeviceAwareExprMutator {
  public:
   LowerTensorExprMutator(IRModule module, ProcessFn process_fn, CompilationConfig config,
-                         String module_name, TECompiler compiler)
+                         TECompiler compiler)
       : DeviceAwareExprMutator(module),
         module_(std::move(module)),
         process_fn_(std::move(process_fn)),
         config_(std::move(config)),
-        module_name_(std::move(module_name)),
         compiler_(std::move(compiler)),
         debug_op_(Op::Get("debug")) {}
 
@@ -925,7 +919,7 @@ class LowerTensorExprMutator : public DeviceAwareExprMutator {
       // codegen.
       CCacheKey key(Downcast<Function>(primitive_func), target,
                     GetVirtualDevice(GetRef<Call>(call_node)));
-      CachedFunc cfunc = compiler_->Lower(key, module_name_);
+      CachedFunc cfunc = compiler_->Lower(key);
       ICHECK(cfunc.defined());
       return MakeLoweredCall(primitive_func, cfunc->prim_fn_var, std::move(new_args),
                              call_node->span, target, cfunc->funcs->functions);
@@ -942,17 +936,15 @@ class LowerTensorExprMutator : public DeviceAwareExprMutator {
   // module we'll ultimately emit for each required device-type. Note that a primitive may be
   // lowered for multiple device types, each which will be assigned a fresh var.
   std::unordered_map<const VarNode*, BaseFunc> primitive_functions_;
-  String module_name_;
   TECompiler compiler_;
   // Cache ops that need to be frequently used later to reduce lookup overhead.
   const Op& debug_op_;
 };
 
-Pass LowerTensorExpr(const String& module_name, TECompiler compiler, ProcessFn process_fn,
-                     CompilationConfig config) {
+Pass LowerTensorExpr(TECompiler compiler, ProcessFn process_fn, CompilationConfig config) {
   runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
       [=](Function func, IRModule module, PassContext ctx) {
-        LowerTensorExprMutator lower_te(module, process_fn, config, module_name, compiler);
+        LowerTensorExprMutator lower_te(module, process_fn, config, compiler);
         return Downcast<Function>(lower_te.Mutate(func));
       };
   return CreateFunctionPass(pass_func, 0, "LowerTensorExpr", {});
@@ -1184,7 +1176,7 @@ void UpdateFunctionMetadata(BaseFunc func,
 /*! \brief Main lowering driving. */
 IRModule LowerTE(const IRModule& module, const String& module_name, ProcessFn process_fn,
                  CompilationConfig config) {
-  TECompiler compiler(module);
+  TECompiler compiler(module, module_name);
 
   // TODO(mbs): This is all unnecessarily convoluted. Better would be to accumulate the rewritten
   // module as we go (including rewritten Functions, lowered primitives, and runtime modules
@@ -1199,7 +1191,7 @@ IRModule LowerTE(const IRModule& module, const String& module_name, ProcessFn pr
   //  - Calls to functions tagged with "Primitive" are compiled to PrimFuncs, and calls updated
   //    (using call_lowered convention).
   IRModule updated_module =
-      LowerTensorExpr(module_name, compiler, std::move(process_fn), std::move(config))(module);
+      LowerTensorExpr(compiler, std::move(process_fn), std::move(config))(module);
 
   // The Functions tagged with "Compiler" are now residing in the cache ready to be
   // compiled by LowerExternalFunctions. However we still need a record of them in the
diff --git a/src/relay/backend/te_compiler.h b/src/relay/backend/te_compiler.h
index 5d16da4b8b..f2ba84014a 100644
--- a/src/relay/backend/te_compiler.h
+++ b/src/relay/backend/te_compiler.h
@@ -73,7 +73,7 @@ class TECompilerNode : public Object {
    * \param key The key to the cached function.
    * \return The result.
    */
-  virtual CachedFunc Lower(const CCacheKey& key, std::function<String(String)> mangle_fn) = 0;
+  virtual CachedFunc Lower(const CCacheKey& key) = 0;
 
   /*!
    * \brief Get lowered result.
@@ -137,7 +137,7 @@ class TECompilerNode : public Object {
 /*! \brief cache entry used in compile engine */
 class TECompiler : public ObjectRef {
  public:
-  explicit TECompiler(Optional<IRModule> opt_mod = {});
+  explicit TECompiler(Optional<IRModule> opt_mod = {}, Optional<String> mod_name = {});
   explicit TECompiler(ObjectPtr<Object> n) : ObjectRef(n) {}
   TECompilerNode* operator->() { return static_cast<TECompilerNode*>(get_mutable()); }
   using ContainerType = TECompilerNode;
diff --git a/src/relay/backend/te_compiler_cache.cc b/src/relay/backend/te_compiler_cache.cc
index bfb351f82b..da52d94b4e 100644
--- a/src/relay/backend/te_compiler_cache.cc
+++ b/src/relay/backend/te_compiler_cache.cc
@@ -137,8 +137,7 @@ class LowerToTECompute : public backend::MemoizedExprTranslator<Array<te::Tensor
   explicit LowerToTECompute(Target target)
       : target_(target), device_copy_op_(Op::Get("device_copy")) {}
 
-  Array<te::Tensor> Lower(const Function& relay_func,
-                          std::function<std::string(std::string)> renamer) {
+  Array<te::Tensor> Lower(const Function& relay_func) {
     for (Var param : relay_func->params) {
       Array<tvm::te::Tensor> inputs;
       for (const auto& ttype : FlattenTupleType(param->checked_type())) {
@@ -327,15 +326,15 @@ class ScheduleBuilder : public ExprVisitor {
     }
   }
 
-  CachedFunc Create(const Function& relay_func, std::function<std::string(std::string)> renamer) {
+  CachedFunc Create(const Function& relay_func, GlobalVarSupply global_var_supply) {
     LowerToTECompute lower_te_compute(target_);
-    Array<te::Tensor> tensor_outs = lower_te_compute.Lower(relay_func, renamer);
+    Array<te::Tensor> tensor_outs = lower_te_compute.Lower(relay_func);
     Array<te::Tensor> fn_inputs = lower_te_compute.fn_inputs_;
     VisitExpr(relay_func->body);
 
     // TODO(mbs): This should be the definitive global by which the PrimFunc is known and
     // no other GlobalVar ctors should appear inside the lowering machinery.
-    auto prim_fn_var = GlobalVar(renamer(lower_te_compute.candidate_name_));
+    auto prim_fn_var = global_var_supply->FreshGlobal(lower_te_compute.candidate_name_);
     prim_fn_var->checked_type_ = relay_func->checked_type();
 
     // Fusion over tupled results may leave identity relationships
@@ -402,8 +401,9 @@ class ScheduleBuilder : public ExprVisitor {
       }
     }
 
-    return CachedFunc(target_, prim_fn_var, fn_inputs, tensor_outs, schedule, prim_func, {},
-                      IRModule(Map<GlobalVar, BaseFunc>({})), lower_te_compute.constant_tensors_);
+    IRModule funcs = IRModule(Map<GlobalVar, BaseFunc>({}));
+    return CachedFunc(target_, prim_fn_var, fn_inputs, tensor_outs, schedule, prim_func, {}, funcs,
+                      lower_te_compute.constant_tensors_);
   }
 
   void VisitExpr_(const CallNode* call_node) final {
@@ -446,8 +446,8 @@ class ScheduleBuilder : public ExprVisitor {
  *  The funcs field in cache is not yet populated.
  */
 CachedFunc PrimFuncFor(const Function& source_func, const Target& target,
-                       std::function<std::string(std::string)> renamer) {
-  return ScheduleBuilder(target).Create(source_func, renamer);
+                       GlobalVarSupply global_var_supply) {
+  return ScheduleBuilder(target).Create(source_func, global_var_supply);
 }
 
 // Creates shape function from functor.
@@ -456,7 +456,7 @@ class MakeShapeFunc : public backend::MemoizedExprTranslator<Array<te::Tensor>>
   MakeShapeFunc() {}
 
   CachedFunc Create(const Function& prim_func, const Target& target,
-                    std::function<std::string(std::string)> renamer) {
+                    GlobalVarSupply global_var_supply) {
     VLOG_CONTEXT << "MakeShapeFunc";
     TShapeDataDependent shape_func_param_states;
 
@@ -527,8 +527,7 @@ class MakeShapeFunc : public backend::MemoizedExprTranslator<Array<te::Tensor>>
 
     // TODO(mbs): This should be the definitive global by which the PrimFunc is known and
     // no other GlobalVar ctors should appear inside the lowering machinery.
-    auto func_name = renamer(candidate_name);
-    auto prim_fn_gvar = GlobalVar(func_name);
+    auto prim_fn_gvar = global_var_supply->FreshGlobal(candidate_name);
 
     // Gather the result types, again from the p.o.v. of the shape function rather than
     // the primitive it is derived for.
@@ -569,19 +568,10 @@ class MakeShapeFunc : public backend::MemoizedExprTranslator<Array<te::Tensor>>
     With<PassContext> fresh_pass_ctx_scope(PassContext::Create());
 
     std::unordered_map<te::Tensor, tir::Buffer> binds;
-    IRModule lowered_module = tvm::LowerSchedule(schedule, all_args, func_name, binds);
-
-    // Unfortunately the above machinery creates its own GlobalVars instead of using *the*
-    // GlobalVar we established above. Fix this before the confusion spreads any further.
-    // TODO(mbs): LowerSchedule should be given prim_fn_gvar instead of func_name.
-    IRModule fixed_lowered_module;
-    for (const auto& kv : lowered_module->functions) {
-      GlobalVar global_var =
-          kv.first->name_hint == prim_fn_gvar->name_hint ? prim_fn_gvar : kv.first;
-      fixed_lowered_module->Add(global_var, kv.second);
-    }
+    IRModule lowered_module =
+        tvm::LowerSchedule(schedule, all_args, prim_fn_gvar->name_hint, binds, global_var_supply);
     return CachedFunc(target, prim_fn_gvar, inputs, outputs, schedule, tir::PrimFunc{nullptr},
-                      shape_func_param_states, fixed_lowered_module);
+                      shape_func_param_states, lowered_module);
   }
 
   Array<te::Tensor> VisitExpr(const Expr& expr) final {
@@ -791,15 +781,14 @@ class MakeShapeFunc : public backend::MemoizedExprTranslator<Array<te::Tensor>>
 };
 
 CachedFunc ShapeFuncFor(const Function& prim_func, const Target& target,
-                        std::function<std::string(std::string)> renamer) {
-  return MakeShapeFunc().Create(prim_func, target, renamer);
+                        GlobalVarSupply global_var_supply) {
+  return MakeShapeFunc().Create(prim_func, target, global_var_supply);
 }
 
 std::pair<Array<te::Tensor>, std::string> LowerTECompute(const Function& source_func, Target target,
                                                          bool return_inputs) {
   LowerToTECompute lower_te_compute(target);
-  Array<te::Tensor> outputs =
-      lower_te_compute.Lower(source_func, [](std::string name) { return name; });
+  Array<te::Tensor> outputs = lower_te_compute.Lower(source_func);
   // Following ScheduleBuilder, remove placeholder ops from outputs.
   tvm::Array<te::Tensor> tensor_outs;
   for (const auto& tensor : outputs) {
@@ -814,34 +803,10 @@ std::pair<Array<te::Tensor>, std::string> LowerTECompute(const Function& source_
   return std::make_pair(tensor_outs, lower_te_compute.candidate_name_);
 }
 
-/*!
- * \brief Get unique name from name.
- * \param name The orginal name.
- * \return Updated name which is unique.
- */
-std::string GetUniqueName(std::string name, std::unordered_map<std::string, int>* name_map_) {
-  for (size_t i = 0; i < name.length(); ++i) {
-    if (name[i] == '.') name[i] = '_';
-  }
-  while (true) {
-    auto it = name_map_->find(name);
-    if (it == name_map_->end()) {
-      (*name_map_)[name] = 1;
-      return name;
-    } else {
-      std::ostringstream os;
-      os << name << "_" << it->second;
-      ++(it->second);
-      name = os.str();
-    }
-  }
-  return name;
-}
-
 TVM_REGISTER_GLOBAL("relay.backend.LowerToTE").set_body_typed([](Function prim_func) {
   auto tgt = tvm::Target("ext_dev");
   LowerToTECompute lower_te_compute(tgt);
-  auto outputs = lower_te_compute.Lower(prim_func, [&](std::string name) { return name; });
+  auto outputs = lower_te_compute.Lower(prim_func);
   return CachedFunc(tgt, GlobalVar(lower_te_compute.candidate_name_), lower_te_compute.fn_inputs_,
                     outputs, te::Schedule(), tir::PrimFunc(), {},
                     IRModule(Map<GlobalVar, BaseFunc>({})), lower_te_compute.constant_tensors_);
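
A short sketch of the supply-based naming that replaces the `renamer` lambdas throughout this file. The base name "fused_add" is hypothetical; the numbering follows the behaviour exercised by the new tests added later in this commit:

    GlobalVarSupply supply = GlobalVarSupply(NameSupply(""));
    GlobalVar v0 = supply->FreshGlobal("fused_add");  // "fused_add": first use comes back verbatim
    GlobalVar v1 = supply->FreshGlobal("fused_add");  // "fused_add_1": deduplicated, never reused
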
diff --git a/src/relay/backend/te_compiler_cache.h b/src/relay/backend/te_compiler_cache.h
index ac26198260..894a5f5be5 100644
--- a/src/relay/backend/te_compiler_cache.h
+++ b/src/relay/backend/te_compiler_cache.h
@@ -24,6 +24,7 @@
 #ifndef TVM_RELAY_BACKEND_TE_COMPILER_CACHE_H_
 #define TVM_RELAY_BACKEND_TE_COMPILER_CACHE_H_
 
+#include <tvm/ir/name_supply.h>
 #include <tvm/node/structural_equal.h>
 #include <tvm/node/structural_hash.h>
 #include <tvm/relay/analysis.h>
@@ -227,13 +228,10 @@ std::pair<Array<te::Tensor>, std::string> LowerTECompute(const Function& source_
  *  The funcs field in cache is not yet populated.
  */
 CachedFunc PrimFuncFor(const Function& source_func, const Target& target,
-                       std::function<std::string(std::string)> renamer);
+                       GlobalVarSupply global_var_supply);
 
 CachedFunc ShapeFuncFor(const Function& prim_func, const Target& target,
-                        std::function<std::string(std::string)> renamer);
-
-// TODO(mbs): Bring name uniqification under control -- this is replicated in quite a few places.
-std::string GetUniqueName(std::string name, std::unordered_map<std::string, int>* name_map);
+                        GlobalVarSupply global_var_supply);
 
 // implementations
 inline size_t CCacheKeyNode::Hash() const {
diff --git a/src/relay/ir/dataflow_matcher.cc b/src/relay/ir/dataflow_matcher.cc
index b2776a41c5..42fec9e27a 100644
--- a/src/relay/ir/dataflow_matcher.cc
+++ b/src/relay/ir/dataflow_matcher.cc
@@ -22,6 +22,7 @@
  * \brief The dataflow pattern matcher for Relay.
  */
 
+#include <tvm/ir/global_var_supply.h>
 #include <tvm/relay/analysis.h>
 #include <tvm/relay/dataflow_matcher.h>
 #include <tvm/relay/expr_functor.h>
@@ -438,15 +439,8 @@ Expr InferType(const Expr& expr) {
 
 Expr InferTypeWithModule(const Expr& expr, const IRModule& m) {
   IRModule mod(m->functions, m->type_definitions, m->Imports());
-  int idx = 0;
-  std::string gv_name;
-  do {
-    std::ostringstream oss;
-    oss << "_tmp" << idx;
-    gv_name = oss.str();
-    ++idx;
-  } while (mod->ContainGlobalVar(gv_name));
-  GlobalVar gvar(gv_name);
+  GlobalVarSupply global_var_supply = GlobalVarSupply(mod);
+  GlobalVar gvar = global_var_supply->FreshGlobal("_tmp", false);
   BaseFunc func;
   if (expr.as<FunctionNode>()) {
     func = Downcast<Function>(expr);
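
The hand-rolled `_tmp<idx>` probing loop deleted above is exactly what `GlobalVarSupply` now encapsulates. A sketch of the equivalent behaviour, assuming `mod` already contains some globals (the second argument, `false`, skips the supply's name prefix, matching its use in the diff):

    GlobalVarSupply supply = GlobalVarSupply(mod);        // reserves mod's existing GlobalVars
    GlobalVar gvar = supply->FreshGlobal("_tmp", false);  // first unused "_tmp..." name
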
diff --git a/src/relay/transforms/auto_scheduler_layout_rewrite.cc b/src/relay/transforms/auto_scheduler_layout_rewrite.cc
index c538dac048..25111cec8e 100644
--- a/src/relay/transforms/auto_scheduler_layout_rewrite.cc
+++ b/src/relay/transforms/auto_scheduler_layout_rewrite.cc
@@ -126,8 +126,7 @@ Expr AutoSchedulerLayoutRewriter::VisitExpr_(const CallNode* n) {
       CHECK(f) << "Could not find auto_scheduler.enter_layout_rewrite function.";
       (*f)();
 
-      tec::PrimFuncFor(GetRef<Function>(func), Target::Current(),
-                       [](std::string name) { return name; });
+      tec::PrimFuncFor(GetRef<Function>(func), Target::Current(), GlobalVarSupply(NameSupply("")));
 
       f = runtime::Registry::Get("auto_scheduler.exit_layout_rewrite");
       CHECK(f) << "Could not find ansor.exit_layout_rewrite function.";
diff --git a/src/relay/transforms/meta_schedule_layout_rewrite.cc b/src/relay/transforms/meta_schedule_layout_rewrite.cc
index b817802f17..8a70f224c6 100644
--- a/src/relay/transforms/meta_schedule_layout_rewrite.cc
+++ b/src/relay/transforms/meta_schedule_layout_rewrite.cc
@@ -127,8 +127,7 @@ Expr MetaScheduleLayoutRewriter::VisitExpr_(const CallNode* call) {
     if (const auto* func = call->op.as<FunctionNode>()) {
       LayoutIndexQueue* self = LayoutIndexQueue::Global();
       self->queue_.clear();
-      tec::PrimFuncFor(GetRef<Function>(func), Target::Current(),
-                       [](std::string name) { return name; });
+      tec::PrimFuncFor(GetRef<Function>(func), Target::Current(), GlobalVarSupply(NameSupply("")));
       if (!self->queue_.empty()) {
         std::deque<tir::IndexMap> queue = std::move(self->queue_);
         self->queue_.clear();
diff --git a/src/relay/transforms/partition_graph.cc b/src/relay/transforms/partition_graph.cc
index bc1ed518d4..e2df2e4272 100644
--- a/src/relay/transforms/partition_graph.cc
+++ b/src/relay/transforms/partition_graph.cc
@@ -333,14 +333,15 @@ class Partitioner : public MixedModeMutator {
         WithAttr(std::move(global_region_func), attr::kCompiler, tvm::runtime::String(target));
     global_region_func = WithAttr(std::move(global_region_func), attr::kInline, tvm::Integer(1));
 
-    std::string fname = name;
-    ICHECK(!module_->ContainGlobalVar(fname)) << "Global function " << fname << " already exists";
+    GlobalVarSupply global_var_supply = GlobalVarSupply(module_);
+    GlobalVar glob_func = global_var_supply->FreshGlobal(name, false);
+    ICHECK(!module_->ContainGlobalVar(glob_func->name_hint))
+        << "Global function " << glob_func->name_hint << " already exists";
     // Create a global function and add it to the IRModule for the region.
     // This way we lift the functions that should be handled by external
     // codegen to the module scope and rely on the pass manager to prevent
     // relay function level passes (i.e. simplify inference and fusion)
     // optimizing it.
-    GlobalVar glob_func(fname);
     module_->Add(glob_func, global_region_func);
     module_ = relay::transform::InferType()(module_);
 
diff --git a/src/target/source/codegen_c.cc b/src/target/source/codegen_c.cc
index 3fe7fa50d3..6b1ca81d85 100644
--- a/src/target/source/codegen_c.cc
+++ b/src/target/source/codegen_c.cc
@@ -45,33 +45,33 @@ void CodeGenC::InitFuncState(const PrimFunc& f) {
 
 void CodeGenC::ReserveKeywordsAsUnique() {
   // skip the first underscore, so SSA variable starts from _1
-  GetUniqueName("_");
-  GetUniqueName("extern");
-  GetUniqueName("void");
-  GetUniqueName("int");
-  GetUniqueName("float");
-  GetUniqueName("double");
-  GetUniqueName("char");
-  GetUniqueName("unsigned");
-  GetUniqueName("short");
-  GetUniqueName("long");
-  GetUniqueName("if");
-  GetUniqueName("else");
-  GetUniqueName("switch");
-  GetUniqueName("case");
-  GetUniqueName("default");
-  GetUniqueName("for");
-  GetUniqueName("do");
-  GetUniqueName("while");
-  GetUniqueName("goto");
-  GetUniqueName("register");
-  GetUniqueName("continue");
-  GetUniqueName("break");
-  GetUniqueName("typedef");
-  GetUniqueName("struct");
-  GetUniqueName("enum");
-  GetUniqueName("union");
-  GetUniqueName("return");
+  name_supply_->ReserveName("_");
+  name_supply_->ReserveName("extern");
+  name_supply_->ReserveName("void");
+  name_supply_->ReserveName("int");
+  name_supply_->ReserveName("float");
+  name_supply_->ReserveName("double");
+  name_supply_->ReserveName("char");
+  name_supply_->ReserveName("unsigned");
+  name_supply_->ReserveName("short");
+  name_supply_->ReserveName("long");
+  name_supply_->ReserveName("if");
+  name_supply_->ReserveName("else");
+  name_supply_->ReserveName("switch");
+  name_supply_->ReserveName("case");
+  name_supply_->ReserveName("default");
+  name_supply_->ReserveName("for");
+  name_supply_->ReserveName("do");
+  name_supply_->ReserveName("while");
+  name_supply_->ReserveName("goto");
+  name_supply_->ReserveName("register");
+  name_supply_->ReserveName("continue");
+  name_supply_->ReserveName("break");
+  name_supply_->ReserveName("typedef");
+  name_supply_->ReserveName("struct");
+  name_supply_->ReserveName("enum");
+  name_supply_->ReserveName("union");
+  name_supply_->ReserveName("return");
 }
 
 void CodeGenC::AddFunction(const PrimFunc& f) {
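
A sketch of what `ReserveName` buys the C codegen: reserved tokens are marked as taken without minting a new name, so later `FreshName` calls can never hand a C keyword to generated code. The numbering follows the new NameSupply tests:

    NameSupply supply("");                        // the codegen member uses an empty prefix
    supply->ReserveName("extern");
    std::string n = supply->FreshName("extern");  // decorated, e.g. "extern_1", never "extern"
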
diff --git a/src/target/source/codegen_c_host.cc b/src/target/source/codegen_c_host.cc
index 54975d166e..a47158d378 100644
--- a/src/target/source/codegen_c_host.cc
+++ b/src/target/source/codegen_c_host.cc
@@ -42,7 +42,7 @@
 namespace tvm {
 namespace codegen {
 
-CodeGenCHost::CodeGenCHost() { module_name_ = GetUniqueName("__tvm_module_ctx"); }
+CodeGenCHost::CodeGenCHost() { module_name_ = name_supply_->FreshName("__tvm_module_ctx"); }
 
 void CodeGenCHost::Init(bool output_ssa, bool emit_asserts, std::string target_str,
                         const std::unordered_set<std::string>& devices) {
@@ -207,8 +207,8 @@ void CodeGenCHost::PrintGetFuncFromBackend(const std::string& func_name,
 
 void CodeGenCHost::PrintFuncCall(const std::string& packed_func_name, int num_args) {
   this->PrintIndent();
-  std::string ret_val = GetUniqueName("ret_val");
-  std::string ret_type_code = GetUniqueName("ret_type_code");
+  std::string ret_val = name_supply_->FreshName("ret_val");
+  std::string ret_type_code = name_supply_->FreshName("ret_type_code");
   this->stream << "TVMValue " << ret_val << ";\n";
   this->PrintIndent();
   this->stream << "int " << ret_type_code << ";\n";
@@ -231,8 +231,8 @@ void CodeGenCHost::PrintFuncCall(const std::string& packed_func_name, int num_ar
 void CodeGenCHost::PrintFuncCallC(const std::string& packed_func_name, int num_args,
                                   const std::string& resource_handle_name) {
   this->PrintIndent();
-  std::string ret_val = GetUniqueName("ret_val");
-  std::string ret_type_code = GetUniqueName("ret_type_code");
+  std::string ret_val = name_supply_->FreshName("ret_val");
+  std::string ret_type_code = name_supply_->FreshName("ret_type_code");
   this->stream << "TVMValue " << ret_val << ";\n";
   this->PrintIndent();
   this->stream << "int " << ret_type_code << ";\n";
@@ -264,7 +264,7 @@ std::string CodeGenCHost::GetPackedName(const CallNode* op) {
   if (it != declared_globals_.end()) {
     unique_name = it->second;
   } else {
-    unique_name = GetUniqueName(packed_func_name);
+    unique_name = name_supply_->FreshName(packed_func_name);
     declared_globals_[packed_func_name] = unique_name;
     decl_stream << "static void* " << unique_name << " = NULL;\n";
   }
@@ -310,7 +310,7 @@ CodeGenCHost::FunctionInfo CodeGenCHost::GetFunctionInfo(const CallNode* op,
 
 void CodeGenCHost::VisitExpr_(const CallNode* op, std::ostream& os) {  // NOLINT(*)
   if (op->op.same_as(builtin::tvm_stack_alloca())) {
-    std::string stack_name = GetUniqueName("stack");
+    std::string stack_name = name_supply_->FreshName("stack");
     const std::string& type = op->args[0].as<StringImmNode>()->value;
     const IntImmNode* num = op->args[1].as<IntImmNode>();
     ICHECK(num != nullptr);
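
The practical effect in this file, sketched: every call site draws distinct temporaries from the shared supply, so nested packed-func calls cannot shadow each other's locals:

    std::string a = name_supply_->FreshName("ret_val");  // "ret_val"
    std::string b = name_supply_->FreshName("ret_val");  // "ret_val_1"
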
diff --git a/src/target/source/codegen_cuda.cc b/src/target/source/codegen_cuda.cc
index dde1d112ed..7350292167 100644
--- a/src/target/source/codegen_cuda.cc
+++ b/src/target/source/codegen_cuda.cc
@@ -43,8 +43,8 @@ CodeGenCUDA::CodeGenCUDA() { restrict_keyword_ = "__restrict__"; }
 
 void CodeGenCUDA::Init(bool output_ssa) {
   CodeGenC::Init(output_ssa);
-  vid_global_barrier_state_ = GetUniqueName(runtime::symbol::tvm_global_barrier_state);
-  vid_global_barrier_expect_ = GetUniqueName("__barrier_expect");
+  vid_global_barrier_state_ = name_supply_->FreshName(runtime::symbol::tvm_global_barrier_state);
+  vid_global_barrier_expect_ = name_supply_->FreshName("__barrier_expect");
   ICHECK_EQ(vid_global_barrier_state_, runtime::symbol::tvm_global_barrier_state);
 }
 
@@ -403,7 +403,7 @@ void CodeGenCUDA::PrintType(DataType t, std::ostream& os) {  // NOLINT(*)
 void CodeGenCUDA::PrintVecBinaryOp(const std::string& op, DataType t, PrimExpr lhs, PrimExpr rhs,
                                    std::ostream& os) {  // NOLINT(*)
   // Declare the result.
-  std::string sret = GetUniqueName("_");
+  std::string sret = name_supply_->FreshName("_");
   this->PrintIndent();
   this->PrintType(t, stream);
   stream << ' ' << sret << ";\n";
@@ -555,7 +555,7 @@ void CodeGenCUDA::PrintStorageSync(const CallNode* op) {
     this->PrintIndent();
     this->stream << "atomicAdd(&" << vid_global_barrier_state_ << ", 1);\n";
     this->PrintIndent();
-    std::string ptr = GetUniqueName("pf");
+    std::string ptr = name_supply_->FreshName("pf");
     this->stream << "volatile unsigned* " << ptr << " = &" << vid_global_barrier_state_ << ";\n";
     this->PrintIndent();
     this->stream << vid_global_barrier_expect_ << " += " << num_blocks << ";\n";
@@ -589,7 +589,7 @@ void CodeGenCUDA::VisitExpr_(const CastNode* op, std::ostream& os) {
 
   // We could emit make_float4 like calls, but the emitted code looks
   // too compact to read. Emit this as vectorized unary ops.
-  std::string sret = GetUniqueName("_");
+  std::string sret = name_supply_->FreshName("_");
   this->PrintIndent();
   this->PrintType(target_ty, stream);
   stream << ' ' << sret << ";\n";
@@ -631,7 +631,7 @@ void CodeGenCUDA::PrintCallExtern(Type ret_type, String global_symbol, const Arr
     // v = __ret;
     //
     // Declare the result vector.
-    std::string sret = GetUniqueName("_");
+    std::string sret = name_supply_->FreshName("_");
     this->PrintIndent();
     this->PrintType(ret_dtype, stream);
     stream << ' ' << sret << ";\n";
@@ -1138,7 +1138,7 @@ void CodeGenCUDA::VisitExpr_(const SelectNode* op, std::ostream& os) {
   ICHECK(op->false_value->dtype == op->dtype && op->true_value->dtype == op->dtype &&
          op->dtype.lanes() == op->condition.dtype().lanes());
 
-  std::string r_var = GetUniqueName("_");
+  std::string r_var = name_supply_->FreshName("_");
   this->PrintIndent();
   this->PrintType(op->dtype, stream);
   stream << ' ' << r_var << ";\n";
diff --git a/src/target/source/codegen_metal.cc b/src/target/source/codegen_metal.cc
index 0ec6179115..b3ca3eb461 100644
--- a/src/target/source/codegen_metal.cc
+++ b/src/target/source/codegen_metal.cc
@@ -55,7 +55,7 @@ void CodeGenMetal::AddFunction(const PrimFunc& f) {
   // clear previous generated state.
   this->InitFuncState(f);
   // skip the first underscore, so SSA variable starts from _1
-  GetUniqueName("_");
+  name_supply_->FreshName("_");
 
   // add to alloc buffer type.
   auto global_symbol = f->GetAttr<String>(tvm::attr::kGlobalSymbol);
@@ -94,7 +94,7 @@ void CodeGenMetal::AddFunction(const PrimFunc& f) {
   }
   // Setup normal arguments.
   size_t nargs = f->params.size() - num_buffer;
-  std::string varg = GetUniqueName("arg");
+  std::string varg = name_supply_->FreshName("arg");
   if (nargs != 0) {
     std::string arg_buf_type = static_cast<std::string>(global_symbol.value()) + "_args_t";
     stream << "  constant " << arg_buf_type << "& " << varg << " [[ buffer(" << num_buffer
@@ -127,8 +127,8 @@ void CodeGenMetal::AddFunction(const PrimFunc& f) {
     decl_stream << "};\n\n";
   }
   // Setup the thread group info.
-  ICHECK_EQ(GetUniqueName("threadIdx"), "threadIdx");
-  ICHECK_EQ(GetUniqueName("blockIdx"), "blockIdx");
+  ICHECK_EQ(name_supply_->FreshName("threadIdx"), "threadIdx");
+  ICHECK_EQ(name_supply_->FreshName("blockIdx"), "blockIdx");
   int work_dim = 0;
   auto thread_axis = f->GetAttr<Array<tir::IterVar>>(tir::attr::kDeviceThreadAxis).value();
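
The `ICHECK_EQ` pattern above relies on a property worth spelling out: a name that has never been requested or reserved comes back from `FreshName` verbatim. A sketch:

    NameSupply supply("");
    ICHECK_EQ(supply->FreshName("threadIdx"), "threadIdx");  // first request: unchanged
    // A second request would be decorated instead, e.g. "threadIdx_1".
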
 
diff --git a/src/target/source/codegen_source_base.cc b/src/target/source/codegen_source_base.cc
index 2353d2e6ba..75833fd936 100644
--- a/src/target/source/codegen_source_base.cc
+++ b/src/target/source/codegen_source_base.cc
@@ -28,34 +28,14 @@ namespace tvm {
 namespace codegen {
 
 void CodeGenSourceBase::ClearFuncState() {
-  name_alloc_map_.clear();
+  name_supply_ = NameSupply("");
   ssa_assign_map_.clear();
   var_idmap_.clear();
   scope_mark_.clear();
 }
 
-std::string CodeGenSourceBase::GetUniqueName(std::string prefix) {
-  for (size_t i = 0; i < prefix.size(); ++i) {
-    if (prefix[i] == '.') prefix[i] = '_';
-  }
-  auto it = name_alloc_map_.find(prefix);
-  if (it != name_alloc_map_.end()) {
-    while (true) {
-      std::ostringstream os;
-      os << prefix << (++it->second);
-      std::string name = os.str();
-      if (name_alloc_map_.count(name) == 0) {
-        prefix = name;
-        break;
-      }
-    }
-  }
-  name_alloc_map_[prefix] = 0;
-  return prefix;
-}
-
 std::string CodeGenSourceBase::SSAGetID(std::string src, DataType t) {
-  if (name_alloc_map_.count(src)) return src;
+  if (name_supply_->ContainsName(src)) return src;
   auto it = ssa_assign_map_.find(src);
   if (it != ssa_assign_map_.end()) {
     if (scope_mark_.at(it->second.scope_id)) {
@@ -63,7 +43,7 @@ std::string CodeGenSourceBase::SSAGetID(std::string src, DataType t) {
     }
   }
   SSAEntry e;
-  e.vid = GetUniqueName("_");
+  e.vid = name_supply_->FreshName("_");
   e.scope_id = static_cast<int>(scope_mark_.size() - 1);
   ssa_assign_map_[src] = e;
   this->PrintIndent();
@@ -74,7 +54,7 @@ std::string CodeGenSourceBase::SSAGetID(std::string src, DataType t) {
 std::string CodeGenSourceBase::AllocVarID(const tir::VarNode* v) {
   ICHECK(!var_idmap_.count(v)) << "Need input to be in SSA form dup " << v->name_hint;
   std::string key = v->name_hint;
-  std::string vid = GetUniqueName(key);
+  std::string vid = name_supply_->FreshName(key);
   std::replace(vid.begin(), vid.end(), ':', '_');
   std::replace(vid.begin(), vid.end(), '-', '_');
   std::replace(vid.begin(), vid.end(), '.', '_');
diff --git a/src/target/source/codegen_source_base.h b/src/target/source/codegen_source_base.h
index 66287f9ad1..2fd0abcd68 100644
--- a/src/target/source/codegen_source_base.h
+++ b/src/target/source/codegen_source_base.h
@@ -25,6 +25,7 @@
 #ifndef TVM_TARGET_SOURCE_CODEGEN_SOURCE_BASE_H_
 #define TVM_TARGET_SOURCE_CODEGEN_SOURCE_BASE_H_
 
+#include <tvm/ir/name_supply.h>
 #include <tvm/runtime/metadata.h>
 #include <tvm/target/codegen.h>
 #include <tvm/tir/expr.h>
@@ -97,12 +98,6 @@ class CodeGenSourceBase {
    * \param t The type of the expression.
    */
   std::string SSAGetID(std::string src, DataType t);
-  /*!
-   * \brief get a unique name with the corresponding prefix
-   * \param prefix The prefix of the name
-   * \return The returned name.
-   */
-  std::string GetUniqueName(std::string prefix);
   /*!
    * \brief mark the beginning of a new scope
    * \return The scope id.
@@ -127,12 +122,12 @@ class CodeGenSourceBase {
   std::ostringstream stream;
   /*! \brief name of each variable */
   std::unordered_map<const tir::VarNode*, std::string> var_idmap_;
+  /*! \brief NameSupply for allocation */
+  NameSupply name_supply_ = NameSupply("");
 
  private:
   /*! \brief assignment map of ssa */
   std::unordered_map<std::string, SSAEntry> ssa_assign_map_;
-  /*! \brief name allocation map */
-  std::unordered_map<std::string, int> name_alloc_map_;
   /*! \brief array to check whether we are inside certain scope */
   std::vector<bool> scope_mark_;
   /*! \brief The current indentation value */
diff --git a/src/te/operation/create_primfunc.cc b/src/te/operation/create_primfunc.cc
index 68b25a1653..55df71a805 100644
--- a/src/te/operation/create_primfunc.cc
+++ b/src/te/operation/create_primfunc.cc
@@ -18,6 +18,7 @@
  */
 
 #include <tvm/arith/analyzer.h>
+#include <tvm/ir/name_supply.h>
 #include <tvm/runtime/registry.h>
 #include <tvm/tir/function.h>
 #include <tvm/tir/stmt_functor.h>
@@ -61,8 +62,10 @@ struct CreateFuncInfo {
   ProducerToBufferTransformer transformer;
   /*! \brief The buffers should be allocated at function root. */
   Array<Buffer> root_alloc;
-  /*! \brief The count map to make block name unique. */
-  std::unordered_map<String, int> name_count;
+  /*! \brief The NameSupply to make block name unique. */
+  NameSupply name_supply = NameSupply("");
+
+  String FreshName(String base_name) { return name_supply->FreshName(base_name); }
 
   explicit CreateFuncInfo(Array<te::Tensor> arg_list)
       : arg_list(std::move(arg_list)), transformer(tensor2buffers) {}
@@ -71,16 +74,6 @@ struct CreateFuncInfo {
     return std::any_of(arg_list.begin(), arg_list.end(),
                        [&tensor](const te::Tensor& arg) { return tensor == arg; });
   }
-
-  String GetUniqueName(const String& prefix) {
-    String unique_prefix = prefix;
-    auto it = name_count.find(prefix);
-    while (name_count.count(unique_prefix)) {
-      unique_prefix = prefix + "_" + std::to_string(++it->second);
-    }
-    name_count[unique_prefix] = 0;
-    return unique_prefix;
-  }
 };
 
 class LayoutFreePlaceholdersNormalizer : public StmtMutator {
@@ -179,7 +172,7 @@ BlockRealize GenerateBlockFromTensors(const te::ComputeOp& compute_op,
   Stmt body;
   if (const auto* reduce = expr_body.as<ReduceNode>()) {
     // Case 1. Reduce compute
-    block_name = info->GetUniqueName(compute_op->name);
+    block_name = info->FreshName(compute_op->name);
     int n_buffers = buffers.size();
 
     Array<PrimExpr> lhs;
@@ -236,7 +229,7 @@ BlockRealize GenerateBlockFromTensors(const te::ComputeOp& compute_op,
   } else {
     // Case 2. Data parallel compute
     ICHECK_EQ(tensors.size(), 1);
-    block_name = info->GetUniqueName(tensors[0]->GetNameHint());
+    block_name = info->FreshName(tensors[0]->GetNameHint());
     const PrimExpr& compute_body = Substitute(info->transformer(expr_body), var_map);
     body = BufferStore(info->tensor2buffers[tensors[0]], analyzer->Simplify(compute_body), indices);
   }
@@ -387,7 +380,7 @@ Stmt GenerateStmtFromExternOp(const te::ExternOp& extern_op, CreateFuncInfo* inf
                       Block(/*iter_vars=*/{},
                             /*reads=*/std::move(reads),
                             /*writes=*/std::move(writes),
-                            /*name_hint=*/info->GetUniqueName(extern_op->name),
+                            /*name_hint=*/info->FreshName(extern_op->name),
                             /*body=*/std::move(body),
                             /*init=*/NullOpt,
                             /*alloc_buffers=*/{},
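
A sketch of the delegation just introduced in `CreateFuncInfo`: block-name uniquing now flows through the shared `NameSupply` rather than a local count map. The empty tensor list and the base name "compute" are purely illustrative:

    CreateFuncInfo info(/*arg_list=*/{});
    String b0 = info.FreshName("compute");  // "compute"
    String b1 = info.FreshName("compute");  // "compute_1"
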
diff --git a/src/tir/transforms/split_host_device.cc b/src/tir/transforms/split_host_device.cc
index 85845616f1..dc56a3ce76 100644
--- a/src/tir/transforms/split_host_device.cc
+++ b/src/tir/transforms/split_host_device.cc
@@ -21,6 +21,7 @@
  * \file split_host_device.cc
  * \brief Split device function from host.
  */
+#include <tvm/ir/global_var_supply.h>
 #include <tvm/ir/transform.h>
 #include <tvm/runtime/registry.h>
 #include <tvm/target/target.h>
@@ -302,12 +303,15 @@ class HostDeviceSplitter : public StmtMutator {
         arguments.push_back(var);
       }
     }
+    GlobalVarSupply global_var_supply = GlobalVarSupply(*device_mod_);
+    GlobalVar kernel_symbol_global = global_var_supply->FreshGlobal(kernel_symbol, false);
+
     PrimFunc device_func(params, Substitute(body, remap_vars));
     device_func = WithAttr(std::move(device_func), tir::attr::kDeviceThreadAxis, m.thread_axis_);
     device_func = WithAttr(std::move(device_func), tvm::attr::kCallingConv,
                            Integer(CallingConv::kDeviceKernelLaunch));
-    device_func =
-        WithAttr(std::move(device_func), tvm::attr::kGlobalSymbol, runtime::String(kernel_symbol));
+    device_func = WithAttr(std::move(device_func), tvm::attr::kGlobalSymbol,
+                           runtime::String(kernel_symbol_global->name_hint));
     device_func = WithAttr(std::move(device_func), tir::attr::kNoAlias, Integer(1));
     device_func = WithAttr(std::move(device_func), tvm::attr::kTarget, device_target_);
     device_func = WithAttr(std::move(device_func), tir::attr::kIsGlobalFunc, Integer(1));
@@ -315,11 +319,11 @@ class HostDeviceSplitter : public StmtMutator {
       device_func =
           WithAttr(std::move(device_func), tir::attr::kDeviceUseDynSharedMemory, Integer(1));
     }
-    (*device_mod_)->Add(GlobalVar(kernel_symbol), device_func);
+    (*device_mod_)->Add(kernel_symbol_global, device_func);
 
     // generate calls to the device function
     Array<PrimExpr> call_args;
-    call_args.push_back(StringImm(kernel_symbol));
+    call_args.push_back(StringImm(kernel_symbol_global->name_hint));
     for (PrimExpr arg : arguments) {
       call_args.push_back(arg);
     }
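
Sketching the intent of the change above: the supply is rebuilt from the accumulating device module on each split, so repeatedly emitted kernels get distinct symbols instead of silently colliding. The base name here is hypothetical:

    GlobalVarSupply supply = GlobalVarSupply(*device_mod_);   // reserves kernels emitted so far
    GlobalVar sym = supply->FreshGlobal("my_kernel", false);  // "my_kernel", or "my_kernel_1" on a clash
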
diff --git a/tests/cpp/build_module_test.cc b/tests/cpp/build_module_test.cc
index ff3641cd69..3d2adb2355 100644
--- a/tests/cpp/build_module_test.cc
+++ b/tests/cpp/build_module_test.cc
@@ -52,7 +52,7 @@ TEST(BuildModule, Basic) {
 
   auto target = Target("llvm");
 
-  auto lowered = LowerSchedule(s, args, "func", binds);
+  auto lowered = LowerSchedule(s, args, "func", binds, GlobalVarSupply(NameSupply("")));
   auto module = build(lowered, target, Target());
 
   auto mali_target = Target("opencl -model=Mali-T860MP4@800Mhz -device=mali");
@@ -121,8 +121,9 @@ TEST(BuildModule, Heterogeneous) {
   auto args2 = Array<Tensor>({copy, C, elemwise_sub});
 
   std::unordered_map<Tensor, Buffer> binds;
-  auto lowered_s1 = LowerSchedule(fcreate_s1(), args1, "elemwise_add", binds);
-  auto lowered_s2 = LowerSchedule(fcreate_s2(), args2, "elemwise_sub", binds);
+  GlobalVarSupply global_var_supply = GlobalVarSupply(NameSupply(""));
+  auto lowered_s1 = LowerSchedule(fcreate_s1(), args1, "elemwise_add", binds, global_var_supply);
+  auto lowered_s2 = LowerSchedule(fcreate_s2(), args2, "elemwise_sub", binds, global_var_supply);
   Map<tvm::Target, IRModule> inputs = {{target_cuda, lowered_s1}, {target_llvm, lowered_s2}};
   auto module = build(inputs, Target());
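
The heterogeneous test above shows the idiom this commit encourages: thread one `GlobalVarSupply` through every `LowerSchedule` call that feeds a single build, so the resulting modules cannot carry colliding globals. A condensed sketch, with `s1`, `s2`, `args1`, `args2`, and `binds` assumed in scope:

    GlobalVarSupply supply = GlobalVarSupply(NameSupply(""));
    auto mod1 = LowerSchedule(s1, args1, "elemwise_add", binds, supply);
    auto mod2 = LowerSchedule(s2, args2, "elemwise_sub", binds, supply);  // shares one namespace
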
 
diff --git a/tests/cpp/c_codegen_test.cc b/tests/cpp/c_codegen_test.cc
index 097de862a9..442f76a8cf 100644
--- a/tests/cpp/c_codegen_test.cc
+++ b/tests/cpp/c_codegen_test.cc
@@ -52,7 +52,8 @@ TEST(CCodegen, MainFunctionOrder) {
   auto args = Array<Tensor>({A, B, elemwise_add});
 
   std::unordered_map<Tensor, Buffer> binds;
-  auto lowered = LowerSchedule(fcreate(), args, "elemwise_add", binds);
+  auto lowered =
+      LowerSchedule(fcreate(), args, "elemwise_add", binds, GlobalVarSupply(NameSupply("")));
   Map<tvm::Target, IRModule> inputs = {{target_c, lowered}};
   runtime::Module module = build(inputs, Target());
   Array<String> functions = module->GetFunction("get_func_names", false)();
@@ -81,7 +82,8 @@ auto BuildLowered(std::string op_name, tvm::Target target) {
 
   auto args = Array<Tensor>({A, B, op});
   std::unordered_map<Tensor, Buffer> binds;
-  auto lowered_s = LowerSchedule(fcreate_s(), args, op_name, binds);
+  auto lowered_s =
+      LowerSchedule(fcreate_s(), args, op_name, binds, GlobalVarSupply(NameSupply("")));
   return lowered_s;
 }
 
diff --git a/tests/cpp/name_supply_test.cc b/tests/cpp/name_supply_test.cc
new file mode 100644
index 0000000000..75b9ae86a9
--- /dev/null
+++ b/tests/cpp/name_supply_test.cc
@@ -0,0 +1,129 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include <gtest/gtest.h>
+#include <tvm/ir/global_var_supply.h>
+#include <tvm/ir/module.h>
+#include <tvm/ir/name_supply.h>
+#include <tvm/relay/expr.h>
+#include <tvm/relay/function.h>
+
+using namespace tvm;
+
+NameSupply preambleNameSupply() {
+  NameSupply name_supply = NameSupply("prefix");
+  name_supply->FreshName("test");
+  return name_supply;
+}
+
+TEST(NameSupply, FreshName) {
+  NameSupply name_supply = preambleNameSupply();
+  String fresh = name_supply->FreshName("test");
+
+  EXPECT_EQ(fresh.compare("prefix_test_1"), 0);
+}
+
+TEST(NameSupply, FreshNameNoConflict) {
+  NameSupply name_supply = preambleNameSupply();
+  String fresh = name_supply->FreshName("name_2");
+  EXPECT_EQ(fresh.compare("prefix_name_2"), 0);
+
+  fresh = name_supply->FreshName("name");
+  EXPECT_EQ(fresh.compare("prefix_name"), 0);
+
+  fresh = name_supply->FreshName("name");
+  EXPECT_EQ(fresh.compare("prefix_name_1"), 0);
+
+  fresh = name_supply->FreshName("name");
+  EXPECT_EQ(fresh.compare("prefix_name_3"), 0);
+}
+
+TEST(NameSupply, ContainsName) {
+  NameSupply name_supply = preambleNameSupply();
+
+  EXPECT_TRUE(name_supply->ContainsName("test"));
+  EXPECT_FALSE(name_supply->ContainsName("test_1"));
+}
+
+TEST(NameSupply, ReserveName) {
+  NameSupply name_supply = preambleNameSupply();
+  name_supply->ReserveName("otherTest", false);
+
+  EXPECT_TRUE(name_supply->ContainsName("otherTest", false));
+  EXPECT_FALSE(name_supply->ContainsName("otherTest"));
+
+  name_supply->ReserveName("otherTest");
+  EXPECT_TRUE(name_supply->ContainsName("prefix_otherTest", false));
+  EXPECT_TRUE(name_supply->ContainsName("otherTest"));
+}
+
+GlobalVarSupply preambleVarSupply() {
+  GlobalVarSupply global_var_supply = GlobalVarSupply(NameSupply(""));
+  global_var_supply->FreshGlobal("test");
+  return global_var_supply;
+}
+
+TEST(GlobalVarSupply, FreshGlobal) {
+  GlobalVarSupply global_var_supply = preambleVarSupply();
+  GlobalVar first_var = global_var_supply->FreshGlobal("test");
+  GlobalVar second_var = global_var_supply->FreshGlobal("test");
+
+  EXPECT_FALSE(tvm::StructuralEqual()(first_var, second_var));
+  EXPECT_EQ(first_var->name_hint.compare("test_1"), 0);
+  EXPECT_EQ(second_var->name_hint.compare("test_2"), 0);
+}
+
+TEST(GlobalVarSupply, UniqueGlobalFor) {
+  GlobalVarSupply global_var_supply = preambleVarSupply();
+  GlobalVar first_var = global_var_supply->UniqueGlobalFor("someName");
+  GlobalVar second_var = global_var_supply->UniqueGlobalFor("someName");
+
+  EXPECT_TRUE(tvm::StructuralEqual()(first_var, second_var));
+  EXPECT_EQ(first_var->name_hint.compare("someName"), 0);
+  EXPECT_EQ(second_var->name_hint.compare("someName"), 0);
+}
+
+TEST(GlobalVarSupply, ReserveGlobal) {
+  GlobalVarSupply global_var_supply = preambleVarSupply();
+  GlobalVar var = GlobalVar("someName");
+  global_var_supply->ReserveGlobalVar(var);
+  GlobalVar second_var = global_var_supply->UniqueGlobalFor("someName");
+  GlobalVar third_var = global_var_supply->FreshGlobal("someName");
+
+  EXPECT_TRUE(tvm::StructuralEqual()(var, second_var));
+  EXPECT_FALSE(tvm::StructuralEqual()(var, third_var));
+  EXPECT_EQ(second_var->name_hint.compare("someName"), 0);
+  EXPECT_EQ(third_var->name_hint.compare("someName_1"), 0);
+}
+
+TEST(GlobalVarSupply, BuildIRModule) {
+  auto x = relay::Var("x", relay::Type());
+  auto f = relay::Function(tvm::Array<relay::Var>{x}, x, relay::Type(), {});
+  GlobalVar var = GlobalVar("test");
+  IRModule module = IRModule({{var, f}});
+
+  GlobalVarSupply global_var_supply = GlobalVarSupply(module);
+  GlobalVar second_var = global_var_supply->UniqueGlobalFor("test", false);
+  GlobalVar third_var = global_var_supply->FreshGlobal("test", false);
+
+  EXPECT_TRUE(tvm::StructuralEqual()(var, second_var));
+  EXPECT_FALSE(tvm::StructuralEqual()(var, third_var));
+  EXPECT_EQ(second_var->name_hint.compare("test"), 0);
+  EXPECT_EQ(third_var->name_hint.compare("test_1"), 0);
+}
diff --git a/tests/python/relay/backend/test_pass_lower_te.py b/tests/python/relay/backend/test_pass_lower_te.py
index 310a16e269..fb79c1f2e7 100644
--- a/tests/python/relay/backend/test_pass_lower_te.py
+++ b/tests/python/relay/backend/test_pass_lower_te.py
@@ -203,12 +203,12 @@ def test_lower_extern_with_dynamic_shape():
     # Expected:
     # def @main(%a: Tensor[(5, 7), float32]) -> Tensor[(?, ?), float32] {
     #   %0 = (%a, %a);
-    #   call_lowered(@my_dyn, %0, metadata={prim_shape_fn_var='shape_func_add', relay_attrs={Extern=1}, prim_shape_fn_states=[2, 2], prim_shape_fn_num_inputs=2, all_prim_shape_fn_vars=['shape_func_add'], prim_shape_fn_num_outputs=1, all_prim_fn_vars=[]})
+    #   call_lowered(@my_dyn, %0, metadata={prim_shape_fn_var='test_shape_func_add', relay_attrs={Extern=1}, prim_shape_fn_states=[2, 2], prim_shape_fn_num_inputs=2, all_prim_shape_fn_vars=['shape_func_add'], prim_shape_fn_num_outputs=1, all_prim_fn_vars=[]})
     # }
     # def @my_dyn(%x: Tensor[(5, 7), float32] , %y: Tensor[(5, 7), float32] , Extern=1) -> Tensor[(?, ?), float32] {
     #   add(%x, %y)
     # }
-    # def @shape_func_add = <shape PrimFunc>
+    # def @test_shape_func_add = <shape PrimFunc>
 
     main = actual_mod["main"]
     call = main.body
@@ -218,14 +218,14 @@ def test_lower_extern_with_dynamic_shape():
     assert len(call.args[1].fields) == 2
     assert call.args[1].fields[0].name_hint == "a"
     assert call.args[1].fields[1].name_hint == "a"
-    assert call.attrs.metadata["prim_shape_fn_var"].name_hint == "shape_func_add"
+    assert call.attrs.metadata["prim_shape_fn_var"].name_hint == "test_shape_func_add"
     assert call.attrs.metadata["relay_attrs"].Extern == 1
     assert len(call.attrs.metadata["prim_shape_fn_states"]) == 2
     assert call.attrs.metadata["prim_shape_fn_states"][0] == 2
     assert call.attrs.metadata["prim_shape_fn_states"][1] == 2
     assert call.attrs.metadata["prim_shape_fn_num_inputs"] == 2
     assert len(call.attrs.metadata["all_prim_shape_fn_vars"]) == 1
-    assert call.attrs.metadata["all_prim_shape_fn_vars"][0].name_hint == "shape_func_add"
+    assert call.attrs.metadata["all_prim_shape_fn_vars"][0].name_hint == "test_shape_func_add"
     assert call.attrs.metadata["prim_shape_fn_num_outputs"] == 1
     assert len(call.attrs.metadata["all_prim_fn_vars"]) == 0
 
@@ -233,7 +233,7 @@ def test_lower_extern_with_dynamic_shape():
     assert isinstance(my_dyn, tvm.relay.Function)
     assert my_dyn.attrs["Extern"] == 1
 
-    shape_func_add = actual_mod["shape_func_add"]
+    shape_func_add = actual_mod["test_shape_func_add"]
     assert isinstance(shape_func_add, tvm.tir.PrimFunc)
 
 
diff --git a/tests/python/relay/test_name_supply.py b/tests/python/relay/test_name_supply.py
new file mode 100644
index 0000000000..688be19c81
--- /dev/null
+++ b/tests/python/relay/test_name_supply.py
@@ -0,0 +1,72 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import tvm
+import tvm.testing
+
+from tvm import relay
+from tvm.ir import GlobalVar, structural_equal
+from tvm.ir.supply import NameSupply
+from tvm.ir.supply import GlobalVarSupply
+
+
+def test_name_supply():
+    name_supply = NameSupply("prefix")
+    name_supply.reserve_name("test")
+
+    assert name_supply.contains_name("test")
+    assert name_supply.fresh_name("test") == "prefix_test_1"
+    assert name_supply.contains_name("test_1")
+    assert not name_supply.contains_name("test_1", False)
+    assert not name_supply.contains_name("test_2")
+
+
+def test_global_var_supply_from_none():
+    var_supply = GlobalVarSupply()
+    global_var = GlobalVar("test")
+    var_supply.reserve_global(global_var)
+
+    assert structural_equal(var_supply.unique_global_for("test"), global_var)
+    assert not structural_equal(var_supply.fresh_global("test"), global_var)
+
+
+def test_global_var_supply_from_name_supply():
+    name_supply = NameSupply("prefix")
+    var_supply = GlobalVarSupply(name_supply)
+    global_var = GlobalVar("test")
+    var_supply.reserve_global(global_var)
+
+    assert structural_equal(var_supply.unique_global_for("test", False), global_var)
+    assert not structural_equal(var_supply.unique_global_for("test"), global_var)
+
+
+def test_global_var_supply_from_ir_mod():
+    x = relay.var("x")
+    y = relay.var("y")
+    mod = tvm.IRModule()
+    global_var = GlobalVar("test")
+    mod[global_var] = relay.Function([x, y], relay.add(x, y))
+    var_supply = GlobalVarSupply(mod)
+
+    second_global_var = var_supply.fresh_global("test", False)
+
+    assert structural_equal(var_supply.unique_global_for("test", False), global_var)
+    assert not structural_equal(var_supply.unique_global_for("test"), global_var)
+    assert not structural_equal(second_global_var, global_var)
+
+
+if __name__ == "__main__":
+    tvm.testing.main()
