This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new a4fa5e7c59 [Unity][NestedMsg] Add NestedMsgTo helper function (#15223)
a4fa5e7c59 is described below
commit a4fa5e7c59b0f2e8ef49a8d4fb86aec006c71062
Author: Yixin Dong <[email protected]>
AuthorDate: Wed Jul 5 14:22:46 2023 +0800
[Unity][NestedMsg] Add NestedMsgTo helper function (#15223)
This PR adds a helper function called `NestedMsgTo<TargetType>` to map
a nested message to any specified type. Also, it rewrites `NestedMsgToExpr` using
`NestedMsgTo`.
---
include/tvm/relax/nested_msg.h | 52 ++++++++++++++++++++++++++++++------------
1 file changed, 38 insertions(+), 14 deletions(-)
diff --git a/include/tvm/relax/nested_msg.h b/include/tvm/relax/nested_msg.h
index 761698f437..1ad5f02e07 100644
--- a/include/tvm/relax/nested_msg.h
+++ b/include/tvm/relax/nested_msg.h
@@ -340,19 +340,24 @@ NestedMsg<T> MapToNestedMsgBySInfo(Expr expr, FType
fmapleaf) {
}
/*!
- * \brief Map nested message back to the expr.
+ * \brief Map nested message back to TargetType.
*
* This function will decompose the nested message and
- * run fmapleaf for each leaf message and get the leaf expr,
- * then recursively combines the results as tuple expr.
+ * run fmapleaf for each leaf message and get the leaf value,
+ * then recursively combines the results by fcombine.
*
* \param msg The input nested message.
- * \param fmapleaf The mapping function for each leaf with signature `Expr
fmapleaf(Optional<T>)`.
+ * \param fmapleaf The mapping function for each leaf with signature
+ * `TargetType fmapleaf(Optional<T>)`.
+ * \param fcombine The function for combining all children of a node into
TargetType with signature
+ * `TargetType fcombine(Array<TargetType>)`.
+ * \tparam TargetType the target type to map nested msg to.
* \tparam T the content type of nested msg.
- * \tparam FType The mapping function type.
+ * \tparam FMapLeaf The leaf mapping function type.
+ * \tparam FCombine The combining function type.
*/
-template <typename T, typename FType>
-Expr NestedMsgToExpr(NestedMsg<T> msg, FType fmapleaf) {
+template <typename TargetType, typename T, typename FMapLeaf, typename
FCombine>
+TargetType NestedMsgTo(NestedMsg<T> msg, FMapLeaf fmapleaf, FCombine fcombine)
{
if (msg.IsNull()) {
return fmapleaf(NullOpt);
} else if (msg.IsLeaf()) {
@@ -360,17 +365,36 @@ Expr NestedMsgToExpr(NestedMsg<T> msg, FType fmapleaf) {
} else {
ICHECK(msg.IsNested());
Array<NestedMsg<T>> arr = msg.NestedArray();
- Array<Expr> subexpr;
+ Array<TargetType> subexpr;
subexpr.reserve(arr.size());
for (size_t i = 0; i < arr.size(); ++i) {
- subexpr.push_back(NestedMsgToExpr<T, FType>(arr[i], fmapleaf));
+ subexpr.push_back(NestedMsgTo<TargetType>(arr[i], fmapleaf, fcombine));
}
+ return fcombine(subexpr);
+ }
+}
+
+/*!
+ * \brief Map nested message back to the expr.
+ *
+ * This function will decompose the nested message and
+ * run fmapleaf for each leaf message and get the leaf expr,
+ * then recursively combines the results as tuple expr.
+ *
+ * \param msg The input nested message.
+ * \param fmapleaf The mapping function for each leaf with signature `Expr
fmapleaf(Optional<T>)`.
+ * \tparam T the content type of nested msg.
+ * \tparam FType The mapping function type.
+ */
+template <typename T, typename FType>
+Expr NestedMsgToExpr(NestedMsg<T> msg, FType fmapleaf) {
+ return NestedMsgTo<Expr>(msg, fmapleaf, [](Array<Expr> arr) {
Optional<Expr> simplified_tuple;
bool simplified_flag = false;
- if (subexpr.size() >= 1) {
+ if (arr.size() >= 1) {
simplified_flag = true;
- for (size_t i = 0; i < subexpr.size() && simplified_flag; ++i) {
- auto* node = subexpr[i].as<TupleGetItemNode>();
+ for (size_t i = 0; i < arr.size() && simplified_flag; ++i) {
+ auto* node = arr[i].as<TupleGetItemNode>();
if (node == nullptr || node->index != static_cast<int>(i)) {
simplified_flag = false;
} else {
@@ -383,8 +407,8 @@ Expr NestedMsgToExpr(NestedMsg<T> msg, FType fmapleaf) {
}
}
}
- return simplified_flag ? simplified_tuple.value() : Tuple(subexpr);
- }
+ return simplified_flag ? simplified_tuple.value() : Tuple(arr);
+ });
}
/*!