This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new a4fa5e7c59 [Unity][NestedMsg] Add NestedMsgTo helper function (#15223)
a4fa5e7c59 is described below

commit a4fa5e7c59b0f2e8ef49a8d4fb86aec006c71062
Author: Yixin Dong <[email protected]>
AuthorDate: Wed Jul 5 14:22:46 2023 +0800

    [Unity][NestedMsg] Add NestedMsgTo helper function (#15223)
    
    This PR adds a helper function called `NestedMsgTo<TargetType>` to map a
    nested message to any specified type. It also rewrites `NestedMsgToExpr`
    using `NestedMsgTo`.
---
 include/tvm/relax/nested_msg.h | 52 ++++++++++++++++++++++++++++++------------
 1 file changed, 38 insertions(+), 14 deletions(-)

diff --git a/include/tvm/relax/nested_msg.h b/include/tvm/relax/nested_msg.h
index 761698f437..1ad5f02e07 100644
--- a/include/tvm/relax/nested_msg.h
+++ b/include/tvm/relax/nested_msg.h
@@ -340,19 +340,24 @@ NestedMsg<T> MapToNestedMsgBySInfo(Expr expr, FType fmapleaf) {
 }
 
 /*!
- * \brief Map nested message back to the expr.
+ * \brief Map nested message back to TargetType.
  *
  * This function will decompose the nested message and
- * run fmapleaf for each leaf message and get the leaf expr,
- * then recursively combines the results as tuple expr.
+ * run fmapleaf for each leaf message to get the leaf value,
+ * then recursively combine the results with fcombine.
  *
  * \param msg The input nested message.
- * \param fmapleaf The mapping function for each leaf with signature `Expr fmapleaf(Optional<T>)`.
+ * \param fmapleaf The mapping function for each leaf with signature
+ * `TargetType fmapleaf(Optional<T>)`.
+ * \param fcombine The function for combining all children of a node into TargetType with signature
+ * `TargetType fcombine(Array<TargetType>)`.
+ * \tparam TargetType the target type to map nested msg to.
  * \tparam T the content type of nested msg.
- * \tparam FType The mapping function type.
+ * \tparam FMapLeaf The leaf mapping function type.
+ * \tparam FCombine The combining function type.
  */
-template <typename T, typename FType>
-Expr NestedMsgToExpr(NestedMsg<T> msg, FType fmapleaf) {
+template <typename TargetType, typename T, typename FMapLeaf, typename 
FCombine>
+TargetType NestedMsgTo(NestedMsg<T> msg, FMapLeaf fmapleaf, FCombine fcombine) 
{
   if (msg.IsNull()) {
     return fmapleaf(NullOpt);
   } else if (msg.IsLeaf()) {
@@ -360,17 +365,36 @@ Expr NestedMsgToExpr(NestedMsg<T> msg, FType fmapleaf) {
   } else {
     ICHECK(msg.IsNested());
     Array<NestedMsg<T>> arr = msg.NestedArray();
-    Array<Expr> subexpr;
+    Array<TargetType> subexpr;
     subexpr.reserve(arr.size());
     for (size_t i = 0; i < arr.size(); ++i) {
-      subexpr.push_back(NestedMsgToExpr<T, FType>(arr[i], fmapleaf));
+      subexpr.push_back(NestedMsgTo<TargetType>(arr[i], fmapleaf, fcombine));
     }
+    return fcombine(subexpr);
+  }
+}
+
+/*!
+ * \brief Map nested message back to the expr.
+ *
+ * This function will decompose the nested message and
+ * run fmapleaf for each leaf message to get the leaf expr,
+ * then recursively combine the results into a tuple expr.
+ *
+ * \param msg The input nested message.
+ * \param fmapleaf The mapping function for each leaf with signature `Expr fmapleaf(Optional<T>)`.
+ * \tparam T the content type of nested msg.
+ * \tparam FType The mapping function type.
+ */
+template <typename T, typename FType>
+Expr NestedMsgToExpr(NestedMsg<T> msg, FType fmapleaf) {
+  return NestedMsgTo<Expr>(msg, fmapleaf, [](Array<Expr> arr) {
     Optional<Expr> simplified_tuple;
     bool simplified_flag = false;
-    if (subexpr.size() >= 1) {
+    if (arr.size() >= 1) {
       simplified_flag = true;
-      for (size_t i = 0; i < subexpr.size() && simplified_flag; ++i) {
-        auto* node = subexpr[i].as<TupleGetItemNode>();
+      for (size_t i = 0; i < arr.size() && simplified_flag; ++i) {
+        auto* node = arr[i].as<TupleGetItemNode>();
         if (node == nullptr || node->index != static_cast<int>(i)) {
           simplified_flag = false;
         } else {
@@ -383,8 +407,8 @@ Expr NestedMsgToExpr(NestedMsg<T> msg, FType fmapleaf) {
         }
       }
     }
-    return simplified_flag ? simplified_tuple.value() : Tuple(subexpr);
-  }
+    return simplified_flag ? simplified_tuple.value() : Tuple(arr);
+  });
 }
 
 /*!

Reply via email to