kparzysz-quic commented on code in PR #14251:
URL: https://github.com/apache/tvm/pull/14251#discussion_r1145236799


##########
include/tvm/tir/stmt_functor.h:
##########
@@ -368,44 +368,134 @@ TVM_DLL PrimExpr Substitute(PrimExpr expr, 
std::function<Optional<PrimExpr>(cons
 
 /*!
  * \brief Substitute the var specified by vmap.
- * \param region The object whose vars are to be substituted
- * \param vmap The map of new values.
+ * \param arr The array of Stmt/PrimExpr to be substituted
+ * \param vmap returns a new value if re-mapping is needed, otherwise returns 
nullptr.
  * \return The result.
  */
-TVM_DLL Array<Range> Substitute(const Array<Range>& region, const Map<Var, 
PrimExpr>& vmap);
+template <typename T>
+Array<T> Substitute(const Array<T>& arr, 
std::function<Optional<PrimExpr>(const Var& var)> vmap) {
+  return arr.Map([&vmap](const auto& elem) { return Substitute(elem, vmap); });
+}
 
 /*!
- * \brief Sugar for substitute via a given map.
- * \param input The input to be updated.
- * \param value_map The map of new values.
- * \return The result.
- * \tparam T the input type, can be PrimExpr or Stmt.
+ * \brief Substitute the vars specified by vmap.
+ * \param range The array of Stmt/PrimExpr to be substituted
+ * \param vmap returns a new value if re-mapping is needed, otherwise returns 
nullptr.
+ * \return The modified Range.
  */
-template <typename T>
-inline auto Substitute(T input, const Map<Var, PrimExpr>& value_map) {
-  auto vmap = [&](const Var& var) -> Optional<PrimExpr> {
-    auto it = value_map.find(var);
-    if (it != value_map.end()) return (*it).second;
-    return Optional<PrimExpr>(nullptr);
+inline Range Substitute(const Range& range,
+                        std::function<Optional<PrimExpr>(const Var& var)> 
vmap) {
+  return Range::FromMinExtent(Substitute(range->min, vmap), 
Substitute(range->extent, vmap));
+}
+
+/*!
+ * \brief Substitute the vars specified by vmap.
+ *
+ * Delegates to the Substitute methods that use std::function.  This
+ * overload allows braced-initialization of the Map, whereas the
+ * template<typename Expr> overload cannot.
+ *
+ * \param obj The object in which TIR variables should be substituted
+ * \param vmap Map defining the TIR variables to be replaced
+ * \return The modified Range.
+ */
+template <typename Obj>
+auto Substitute(Obj&& obj, const Map<Var, PrimExpr>& vmap) {
+  auto func = [&vmap](const Var& var) -> Optional<PrimExpr> { return 
vmap.Get(var); };
+  return Substitute(std::forward<Obj>(obj), func);
+}
+
+/*!
+ * \brief Substitute the vars specified by vmap.
+ *
+ * Delegates to the Substitute methods that use std::function.
+ *
+ * \param obj The object in which TIR variables should be substituted
+ * \param vmap Map defining the TIR variables to be replaced
+ * \return The modified Range.

Review Comment:
   \return The modified object (not Range).



##########
include/tvm/tir/stmt_functor.h:
##########
@@ -368,44 +368,134 @@ TVM_DLL PrimExpr Substitute(PrimExpr expr, 
std::function<Optional<PrimExpr>(cons
 
 /*!
  * \brief Substitute the var specified by vmap.
- * \param region The object whose vars are to be substituted
- * \param vmap The map of new values.
+ * \param arr The array of Stmt/PrimExpr to be substituted
+ * \param vmap returns a new value if re-mapping is needed, otherwise returns 
nullptr.
  * \return The result.
  */
-TVM_DLL Array<Range> Substitute(const Array<Range>& region, const Map<Var, 
PrimExpr>& vmap);
+template <typename T>
+Array<T> Substitute(const Array<T>& arr, 
std::function<Optional<PrimExpr>(const Var& var)> vmap) {
+  return arr.Map([&vmap](const auto& elem) { return Substitute(elem, vmap); });
+}
 
 /*!
- * \brief Sugar for substitute via a given map.
- * \param input The input to be updated.
- * \param value_map The map of new values.
- * \return The result.
- * \tparam T the input type, can be PrimExpr or Stmt.
+ * \brief Substitute the vars specified by vmap.
+ * \param range The array of Stmt/PrimExpr to be substituted
+ * \param vmap returns a new value if re-mapping is needed, otherwise returns 
nullptr.
+ * \return The modified Range.
  */
-template <typename T>
-inline auto Substitute(T input, const Map<Var, PrimExpr>& value_map) {
-  auto vmap = [&](const Var& var) -> Optional<PrimExpr> {
-    auto it = value_map.find(var);
-    if (it != value_map.end()) return (*it).second;
-    return Optional<PrimExpr>(nullptr);
+inline Range Substitute(const Range& range,
+                        std::function<Optional<PrimExpr>(const Var& var)> 
vmap) {
+  return Range::FromMinExtent(Substitute(range->min, vmap), 
Substitute(range->extent, vmap));
+}
+
+/*!
+ * \brief Substitute the vars specified by vmap.
+ *
+ * Delegates to the Substitute methods that use std::function.  This
+ * overload allows braced-initialization of the Map, whereas the
+ * template<typename Expr> overload cannot.
+ *
+ * \param obj The object in which TIR variables should be substituted
+ * \param vmap Map defining the TIR variables to be replaced
+ * \return The modified Range.
+ */
+template <typename Obj>
+auto Substitute(Obj&& obj, const Map<Var, PrimExpr>& vmap) {
+  auto func = [&vmap](const Var& var) -> Optional<PrimExpr> { return 
vmap.Get(var); };
+  return Substitute(std::forward<Obj>(obj), func);
+}
+
+/*!
+ * \brief Substitute the vars specified by vmap.
+ *
+ * Delegates to the Substitute methods that use std::function.
+ *
+ * \param obj The object in which TIR variables should be substituted
+ * \param vmap Map defining the TIR variables to be replaced
+ * \return The modified Range.
+ */
+template <typename Obj, typename Expr,
+          typename = std::enable_if_t<std::is_base_of_v<PrimExpr, Expr>>>
+auto Substitute(Obj&& obj, const Map<Var, Expr>& vmap) {
+  auto func = [&vmap](const Var& var) -> Optional<PrimExpr> {
+    if (auto opt = vmap.Get(var)) {
+      return opt.value();
+    } else {
+      return NullOpt;
+    }
   };
-  return Substitute(std::move(input), vmap);
+  return Substitute(std::forward<Obj>(obj), func);
 }
 
 /*!
- * \brief Sugar for substitute via a given map.
- * \param input The input to be updated.
- * \param value_map The map of new values.
- * \return The result.
- * \tparam T the input type, can be PrimExpr or Stmt.
+ * \brief Substitute the vars specified by vmap.
+ *
+ * Delegates to the Substitute methods that use std::function.
+ *
+ * \param obj The object in which TIR variables should be substituted
+ * \param vmap Map defining the TIR variables to be replaced
+ * \return The modified Range.
  */
-template <typename T>
-inline T Substitute(T input, const std::unordered_map<const VarNode*, 
PrimExpr>& value_map) {
-  auto vmap = [&](const Var& var) -> Optional<PrimExpr> {
-    auto it = value_map.find(var.get());
-    if (it != value_map.end()) return (*it).second;
-    return Optional<PrimExpr>(nullptr);
+template <typename Obj, typename Expr,
+          typename = std::enable_if_t<std::is_base_of_v<PrimExpr, Expr>>>
+auto Substitute(Obj&& obj, const std::unordered_map<const VarNode*, Expr>& 
vmap) {
+  auto func = [&vmap](const Var& var) -> Optional<PrimExpr> {
+    if (auto it = vmap.find(var.get()); it != vmap.end()) {
+      return it->second;
+    } else {
+      return NullOpt;
+    }
+  };
+  return Substitute(std::forward<Obj>(obj), func);
+}
+
+/*!
+ * \brief Substitute the vars specified by vmap.
+ *
+ * Delegates to the Substitute methods that use std::function.
+ *
+ * \param obj The object in which TIR variables should be substituted
+ * \param vmap Map defining the TIR variables to be replaced
+ * \return The modified Range.

Review Comment:
   Same here: `\return The modified Range.` should be `\return The modified object.`



##########
include/tvm/tir/stmt_functor.h:
##########
@@ -368,44 +368,134 @@ TVM_DLL PrimExpr Substitute(PrimExpr expr, 
std::function<Optional<PrimExpr>(cons
 
 /*!
  * \brief Substitute the var specified by vmap.
- * \param region The object whose vars are to be substituted
- * \param vmap The map of new values.
+ * \param arr The array of Stmt/PrimExpr to be substituted
+ * \param vmap returns a new value if re-mapping is needed, otherwise returns 
nullptr.
  * \return The result.
  */
-TVM_DLL Array<Range> Substitute(const Array<Range>& region, const Map<Var, 
PrimExpr>& vmap);
+template <typename T>
+Array<T> Substitute(const Array<T>& arr, 
std::function<Optional<PrimExpr>(const Var& var)> vmap) {
+  return arr.Map([&vmap](const auto& elem) { return Substitute(elem, vmap); });
+}
 
 /*!
- * \brief Sugar for substitute via a given map.
- * \param input The input to be updated.
- * \param value_map The map of new values.
- * \return The result.
- * \tparam T the input type, can be PrimExpr or Stmt.
+ * \brief Substitute the vars specified by vmap.
+ * \param range The array of Stmt/PrimExpr to be substituted
+ * \param vmap returns a new value if re-mapping is needed, otherwise returns 
nullptr.
+ * \return The modified Range.
  */
-template <typename T>
-inline auto Substitute(T input, const Map<Var, PrimExpr>& value_map) {
-  auto vmap = [&](const Var& var) -> Optional<PrimExpr> {
-    auto it = value_map.find(var);
-    if (it != value_map.end()) return (*it).second;
-    return Optional<PrimExpr>(nullptr);
+inline Range Substitute(const Range& range,
+                        std::function<Optional<PrimExpr>(const Var& var)> 
vmap) {
+  return Range::FromMinExtent(Substitute(range->min, vmap), 
Substitute(range->extent, vmap));
+}
+
+/*!
+ * \brief Substitute the vars specified by vmap.
+ *
+ * Delegates to the Substitute methods that use std::function.  This
+ * overload allows braced-initialization of the Map, whereas the
+ * template<typename Expr> overload cannot.
+ *
+ * \param obj The object in which TIR variables should be substituted
+ * \param vmap Map defining the TIR variables to be replaced
+ * \return The modified Range.
+ */
+template <typename Obj>
+auto Substitute(Obj&& obj, const Map<Var, PrimExpr>& vmap) {
+  auto func = [&vmap](const Var& var) -> Optional<PrimExpr> { return 
vmap.Get(var); };
+  return Substitute(std::forward<Obj>(obj), func);
+}
+
+/*!
+ * \brief Substitute the vars specified by vmap.
+ *
+ * Delegates to the Substitute methods that use std::function.
+ *
+ * \param obj The object in which TIR variables should be substituted
+ * \param vmap Map defining the TIR variables to be replaced
+ * \return The modified Range.
+ */
+template <typename Obj, typename Expr,
+          typename = std::enable_if_t<std::is_base_of_v<PrimExpr, Expr>>>
+auto Substitute(Obj&& obj, const Map<Var, Expr>& vmap) {
+  auto func = [&vmap](const Var& var) -> Optional<PrimExpr> {
+    if (auto opt = vmap.Get(var)) {
+      return opt.value();
+    } else {
+      return NullOpt;
+    }
   };
-  return Substitute(std::move(input), vmap);
+  return Substitute(std::forward<Obj>(obj), func);
 }
 
 /*!
- * \brief Sugar for substitute via a given map.
- * \param input The input to be updated.
- * \param value_map The map of new values.
- * \return The result.
- * \tparam T the input type, can be PrimExpr or Stmt.
+ * \brief Substitute the vars specified by vmap.
+ *
+ * Delegates to the Substitute methods that use std::function.
+ *
+ * \param obj The object in which TIR variables should be substituted
+ * \param vmap Map defining the TIR variables to be replaced
+ * \return The modified Range.

Review Comment:
   Same here: Range -> object.



##########
include/tvm/tir/stmt_functor.h:
##########
@@ -368,44 +368,134 @@ TVM_DLL PrimExpr Substitute(PrimExpr expr, 
std::function<Optional<PrimExpr>(cons
 
 /*!
  * \brief Substitute the var specified by vmap.
- * \param region The object whose vars are to be substituted
- * \param vmap The map of new values.
+ * \param arr The array of Stmt/PrimExpr to be substituted
+ * \param vmap returns a new value if re-mapping is needed, otherwise returns 
nullptr.
  * \return The result.
  */
-TVM_DLL Array<Range> Substitute(const Array<Range>& region, const Map<Var, 
PrimExpr>& vmap);
+template <typename T>
+Array<T> Substitute(const Array<T>& arr, 
std::function<Optional<PrimExpr>(const Var& var)> vmap) {
+  return arr.Map([&vmap](const auto& elem) { return Substitute(elem, vmap); });
+}
 
 /*!
- * \brief Sugar for substitute via a given map.
- * \param input The input to be updated.
- * \param value_map The map of new values.
- * \return The result.
- * \tparam T the input type, can be PrimExpr or Stmt.
+ * \brief Substitute the vars specified by vmap.
+ * \param range The array of Stmt/PrimExpr to be substituted
+ * \param vmap returns a new value if re-mapping is needed, otherwise returns 
nullptr.
+ * \return The modified Range.
  */
-template <typename T>
-inline auto Substitute(T input, const Map<Var, PrimExpr>& value_map) {
-  auto vmap = [&](const Var& var) -> Optional<PrimExpr> {
-    auto it = value_map.find(var);
-    if (it != value_map.end()) return (*it).second;
-    return Optional<PrimExpr>(nullptr);
+inline Range Substitute(const Range& range,
+                        std::function<Optional<PrimExpr>(const Var& var)> 
vmap) {
+  return Range::FromMinExtent(Substitute(range->min, vmap), 
Substitute(range->extent, vmap));
+}
+
+/*!
+ * \brief Substitute the vars specified by vmap.
+ *
+ * Delegates to the Substitute methods that use std::function.  This
+ * overload allows braced-initialization of the Map, whereas the
+ * template<typename Expr> overload cannot.
+ *
+ * \param obj The object in which TIR variables should be substituted
+ * \param vmap Map defining the TIR variables to be replaced
+ * \return The modified Range.
+ */
+template <typename Obj>
+auto Substitute(Obj&& obj, const Map<Var, PrimExpr>& vmap) {
+  auto func = [&vmap](const Var& var) -> Optional<PrimExpr> { return 
vmap.Get(var); };
+  return Substitute(std::forward<Obj>(obj), func);
+}
+
+/*!
+ * \brief Substitute the vars specified by vmap.
+ *
+ * Delegates to the Substitute methods that use std::function.
+ *
+ * \param obj The object in which TIR variables should be substituted
+ * \param vmap Map defining the TIR variables to be replaced
+ * \return The modified Range.
+ */
+template <typename Obj, typename Expr,
+          typename = std::enable_if_t<std::is_base_of_v<PrimExpr, Expr>>>
+auto Substitute(Obj&& obj, const Map<Var, Expr>& vmap) {
+  auto func = [&vmap](const Var& var) -> Optional<PrimExpr> {
+    if (auto opt = vmap.Get(var)) {
+      return opt.value();
+    } else {
+      return NullOpt;
+    }
   };
-  return Substitute(std::move(input), vmap);
+  return Substitute(std::forward<Obj>(obj), func);
 }
 
 /*!
- * \brief Sugar for substitute via a given map.
- * \param input The input to be updated.
- * \param value_map The map of new values.
- * \return The result.
- * \tparam T the input type, can be PrimExpr or Stmt.
+ * \brief Substitute the vars specified by vmap.
+ *
+ * Delegates to the Substitute methods that use std::function.
+ *
+ * \param obj The object in which TIR variables should be substituted
+ * \param vmap Map defining the TIR variables to be replaced
+ * \return The modified Range.
  */
-template <typename T>
-inline T Substitute(T input, const std::unordered_map<const VarNode*, 
PrimExpr>& value_map) {
-  auto vmap = [&](const Var& var) -> Optional<PrimExpr> {
-    auto it = value_map.find(var.get());
-    if (it != value_map.end()) return (*it).second;
-    return Optional<PrimExpr>(nullptr);
+template <typename Obj, typename Expr,
+          typename = std::enable_if_t<std::is_base_of_v<PrimExpr, Expr>>>
+auto Substitute(Obj&& obj, const std::unordered_map<const VarNode*, Expr>& 
vmap) {
+  auto func = [&vmap](const Var& var) -> Optional<PrimExpr> {
+    if (auto it = vmap.find(var.get()); it != vmap.end()) {
+      return it->second;
+    } else {
+      return NullOpt;
+    }
+  };
+  return Substitute(std::forward<Obj>(obj), func);
+}
+
+/*!
+ * \brief Substitute the vars specified by vmap.
+ *
+ * Delegates to the Substitute methods that use std::function.
+ *
+ * \param obj The object in which TIR variables should be substituted
+ * \param vmap Map defining the TIR variables to be replaced
+ * \return The modified Range.
+ */
+template <typename Obj, typename Expr, typename Hasher, typename 
EqualityChecker,
+          typename = std::enable_if_t<std::is_base_of_v<PrimExpr, Expr>>>
+auto Substitute(Obj&& obj, const std::unordered_map<Var, Expr, Hasher, 
EqualityChecker>& vmap) {
+  auto func = [&vmap](const Var& var) -> Optional<PrimExpr> {
+    if (auto it = vmap.find(var); it != vmap.end()) {
+      return it->second;
+    } else {
+      return NullOpt;
+    }
+  };
+  return Substitute(std::forward<Obj>(obj), func);
+}
+
+/*!
+ * \brief Substitute the vars specified by vmap.
+ *
+ * Delegates to the Substitute methods that use std::function.
+ *
+ * \param obj The object in which TIR variables should be substituted
+ * \param iter_vmap Map defining the TIR variables to be replaced
+ * \return The modified Range.

Review Comment:
   And one last time: `\return The modified Range.` should be `\return The modified object.`



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to