This is an automated email from the ASF dual-hosted git repository.

sanirudh pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new a2511cc516 [QoL][Relax] Use SeqExpr in IR types when SeqExpr is required (#16859)
a2511cc516 is described below

commit a2511cc5160fa73131517515c79144bef7f4b076
Author: Eric Lunderberg <[email protected]>
AuthorDate: Sat Apr 20 03:15:52 2024 -0500

    [QoL][Relax] Use SeqExpr in IR types when SeqExpr is required (#16859)
    
    * [QoL][Relax] Use SeqExpr in IR types when SeqExpr is required
    
    The Relax IR requires the `FunctionNode::body`, `IfNode::true_branch`,
    and `IfNode::false_branch` to be instances of `relax::SeqExpr`.
    If these Relax requirements are violated, correctly-implemented
    transformations may raise exceptions
    (e.g. from `Downcast` in `Downcast<SeqExpr>(func->body)->blocks`), or
    even segfault (e.g. when `.as` returns a nullptr in
    `func->body.as<SeqExprNode>()->blocks`).  Debugging these failure
    modes is also difficult, as even the TVMScript printer relies on the
    body of the function being a `SeqExprNode`.
    
    This commit updates the C++ type of `FunctionNode::body`,
    `IfNode::true_branch`, and `IfNode::false_branch` to be
    `relax::SeqExpr` instead of `relax::Expr`.  This does not impact any
    well-formed Relax IR, and allows this kind of ill-formed Relax IR
    to be detected at compile time.  A large number of checks that were
    previously performed at runtime can now be removed, as they duplicate
    the new compile-time check.
    
    To maintain backwards compatibility, this commit adds a new
    constructor to `relax::SeqExpr`, which accepts a single `Expr body`
    argument.  If `body` already contains a `relax::SeqExprNode`, this
    constructor returns an additional reference to that same node;
    otherwise, it wraps `body` in a new `relax::SeqExpr` with no binding
    blocks.  For implementations that previously produced
    well-formed Relax IR, this change has no effect.  For implementations
    that previously produced ill-formed Relax IR, this change results in
    the equivalent well-formed Relax IR.
    
    Alternate implementations considered:
    
    * Perform the backwards-compatibility wrapping within the
      `relax::Function` and `relax::If` constructors.  While this would
      provide the intended conversion when these constructors are used,
      Relax transforms make frequent use of copy-on-write
      (e.g. `func.CopyOnWrite()->body = new_body`), which does not use the
      constructor.  Maintaining backwards compatibility for this usage
      requires the implicit conversion constructor that was chosen for
      this PR.
    
    * Remove the Relax IR requirement for these expressions to be
      `SeqExpr`.  While this would make Relax more internally consistent,
      such a change would break existing code that relies on
      `SeqExpr` being present.  While the callsites within TVM could be
      updated to resolve this breakage, callsites outside of TVM
      (e.g. MLC-LLM) could not.  Exposing the special case within the C++
      type, as done in this PR, maintains backwards compatibility.
    
    * Resolve breakages in unit tests
    
    All breakage was the result of callers relying on ill-formed Relax
    IR retaining its specific form of ill-formedness.
---
 include/tvm/relax/expr.h                          | 190 +++++++++++++---------
 src/contrib/msc/core/ir/graph_builder.cc          |   9 +-
 src/contrib/msc/core/transform/set_expr_layout.cc |  20 +--
 src/relax/analysis/well_formed.cc                 |  32 ++--
 src/relax/backend/contrib/utils.cc                |   2 +-
 src/relax/ir/dataflow_matcher.cc                  |  29 +++-
 src/relax/ir/expr.cc                              |   8 +
 src/relax/training/utils.cc                       |   7 +-
 src/relax/transform/fuse_ops.cc                   |  14 +-
 src/relax/transform/fuse_tir.cc                   |   4 +-
 src/relax/transform/gradient.cc                   |   2 -
 src/script/printer/relax/binding.cc               |   4 +-
 src/script/printer/relax/function.cc              |   3 +-
 tests/python/relax/test_expr_functor.py           |   2 +-
 14 files changed, 189 insertions(+), 137 deletions(-)

diff --git a/include/tvm/relax/expr.h b/include/tvm/relax/expr.h
index e2176cf720..0ca92a01a7 100644
--- a/include/tvm/relax/expr.h
+++ b/include/tvm/relax/expr.h
@@ -213,78 +213,6 @@ Call WithFields(Call call, Optional<Expr> opt_op = 
Optional<Expr>(),
                 Optional<Array<StructInfo>> opt_sinfo_args = 
Optional<Array<StructInfo>>(),
                 Optional<Span> opt_span = Optional<Span>());
 
-/*!
- * \brief Condition expression
- *
- * Unlike traditional statement `if`s, the if evalutes
- * to the result of the branch taken.
- *
- * x = if (true) { 1 } else { 0 }; // x is 1
- * y = if (false) { 1 } else { 0 }; // y is 0
- *
- * \note This is similar to C's ternary operator.
- */
-class IfNode : public ExprNode {
- public:
-  /*! \brief The condition. */
-  Expr cond;
-  /*! \brief The expression evaluated when condition is true. */
-  Expr true_branch;
-  /*! \brief The expression evaluated when condition is false */
-  Expr false_branch;
-
-  void VisitAttrs(tvm::AttrVisitor* v) {
-    v->Visit("cond", &cond);
-    v->Visit("true_branch", &true_branch);
-    v->Visit("false_branch", &false_branch);
-    v->Visit("_checked_type_", &checked_type_);
-    v->Visit("struct_info_", &struct_info_);
-    v->Visit("span", &span);
-  }
-
-  bool SEqualReduce(const IfNode* other, SEqualReducer equal) const {
-    equal->MarkGraphNode();
-    return equal(cond, other->cond) && equal(true_branch, other->true_branch) 
&&
-           equal(false_branch, other->false_branch) && equal(struct_info_, 
other->struct_info_);
-  }
-
-  void SHashReduce(SHashReducer hash_reduce) const {
-    hash_reduce->MarkGraphNode();
-    hash_reduce(cond);
-    hash_reduce(true_branch);
-    hash_reduce(false_branch);
-    hash_reduce(struct_info_);
-  }
-
-  static constexpr const char* _type_key = "relax.expr.If";
-  TVM_DECLARE_FINAL_OBJECT_INFO(IfNode, ExprNode);
-};
-
-class If : public Expr {
- public:
-  /*!
-   * \brief The constructor
-   * \param cond The condition of a if node.
-   * \param true_branch The fall through branch
-   * \param false_branch The branch for execution when condition is false.
-   * \param span The source span of the expression.
-   */
-  TVM_DLL If(Expr cond, Expr true_branch, Expr false_branch, Span span = 
Span());
-
-  TVM_DEFINE_OBJECT_REF_METHODS(If, Expr, IfNode);
-  TVM_DEFINE_OBJECT_REF_COW_METHOD(IfNode);
-};
-
-/*!
- * \brief Returns \p if_expr with the given properties. A null property 
denotes 'no change'.
- * Returns \p if_expr if all properties are unchanged. Otherwise, returns a 
copy with the new
- * fields.
- */
-If WithFields(If if_expr, Optional<Expr> opt_cond = Optional<Expr>(),
-              Optional<Expr> opt_true_branch = Optional<Expr>(),
-              Optional<Expr> opt_false_branch = Optional<Expr>(),
-              Optional<Span> opt_span = Optional<Span>());
-
 /*! \brief Tuple container */
 class TupleNode : public ExprNode {
  public:
@@ -915,18 +843,113 @@ class SeqExprNode : public ExprNode {
 
 class SeqExpr : public Expr {
  public:
+  /* \brief Implicit conversion constructor
+   *
+   * Relax nodes that introduce a new scope (e.g. `relax::Function`)
+   * are required to be held as SeqExpr.  This implicit conversion
+   * provides allows callsites to use these member variables when the
+   * C++ compile-time type is a `relax::Expr`.  For example,
+   * a transform may use `func.CopyOnWrite()->body = expr;`.
+   *
+   * If the expression is already a `relax::SeqExpr`, the same
+   * underlying `relax::SeqExprNode` is used, and no copies are made.
+   */
+  TVM_DLL SeqExpr(Expr body);  // NOLINT(*)
+
   TVM_DLL explicit SeqExpr(Array<BindingBlock> blocks, Expr body, Span span = 
Span());
   TVM_DEFINE_OBJECT_REF_METHODS(SeqExpr, Expr, SeqExprNode);
   TVM_DEFINE_OBJECT_REF_COW_METHOD(SeqExprNode);
 };
 
+/*!
+ * \brief Condition expression
+ *
+ * Unlike traditional statement `if`s, the if evalutes
+ * to the result of the branch taken.
+ *
+ * x = if (true) { 1 } else { 0 }; // x is 1
+ * y = if (false) { 1 } else { 0 }; // y is 0
+ *
+ * \note This is similar to C's ternary operator.
+ */
+class IfNode : public ExprNode {
+ public:
+  /*! \brief The condition. */
+  Expr cond;
+  /*! \brief The expression evaluated when condition is true. */
+  SeqExpr true_branch;
+  /*! \brief The expression evaluated when condition is false */
+  SeqExpr false_branch;
+
+  void VisitAttrs(tvm::AttrVisitor* v) {
+    v->Visit("cond", &cond);
+    v->Visit("true_branch", &true_branch);
+    v->Visit("false_branch", &false_branch);
+    v->Visit("_checked_type_", &checked_type_);
+    v->Visit("struct_info_", &struct_info_);
+    v->Visit("span", &span);
+  }
+
+  bool SEqualReduce(const IfNode* other, SEqualReducer equal) const {
+    equal->MarkGraphNode();
+    return equal(cond, other->cond) && equal(true_branch, other->true_branch) 
&&
+           equal(false_branch, other->false_branch) && equal(struct_info_, 
other->struct_info_);
+  }
+
+  void SHashReduce(SHashReducer hash_reduce) const {
+    hash_reduce->MarkGraphNode();
+    hash_reduce(cond);
+    hash_reduce(true_branch);
+    hash_reduce(false_branch);
+    hash_reduce(struct_info_);
+  }
+
+  static constexpr const char* _type_key = "relax.expr.If";
+  TVM_DECLARE_FINAL_OBJECT_INFO(IfNode, ExprNode);
+};
+
+class If : public Expr {
+ public:
+  /*!
+   * \brief The constructor
+   *
+   * \param cond The condition of a if node.
+   *
+   * \param true_branch The fall through branch.  If this is not a
+   *     SeqExpr, it will be wrapped in a SeqExpr, to satisfy the
+   *     Relax IR requirement that all scopes be contained in a
+   *     SeqExpr.
+   *
+   * \param false_branch The branch for execution when condition is
+   *     false.  If this is not a SeqExpr, it will be wrapped in a
+   *     SeqExpr, to satisfy the Relax IR requirement that all scopes
+   *     be contained in a SeqExpr.
+   *
+   * \param span The source span of the expression.
+   */
+  TVM_DLL If(Expr cond, Expr true_branch, Expr false_branch, Span span = 
Span());
+
+  TVM_DEFINE_OBJECT_REF_METHODS(If, Expr, IfNode);
+  TVM_DEFINE_OBJECT_REF_COW_METHOD(IfNode);
+};
+
+/*!
+ * \brief Returns \p if_expr with the given properties. A null property 
denotes 'no change'.
+ * Returns \p if_expr if all properties are unchanged. Otherwise, returns a 
copy with the new
+ * fields.
+ */
+If WithFields(If if_expr, Optional<Expr> opt_cond = Optional<Expr>(),
+              Optional<Expr> opt_true_branch = Optional<Expr>(),
+              Optional<Expr> opt_false_branch = Optional<Expr>(),
+              Optional<Span> opt_span = Optional<Span>());
+
 /*! \brief A Relax function. */
 class FunctionNode : public BaseFuncNode {
  public:
   /*! \brief The parameters to the function. */
   Array<Var> params;
   /*! \brief The body of the function. */
-  Expr body;
+  SeqExpr body;
   /*! \brief The return type of the function. */
   StructInfo ret_struct_info;
   /*! \brief Whether the function is annotated as pure or not. */
@@ -968,6 +991,27 @@ class FunctionNode : public BaseFuncNode {
 
 class Function : public BaseFunc {
  public:
+  /*!
+   * \brief Construct a Relax Function
+   *
+   * \param params The parameters accepted by the function
+   *
+   * \param body The body of the function.  If this is not a
+   *     SeqExpr, it will be wrapped in a SeqExpr, to satisfy the
+   *     Relax IR requirement that all scopes be contained in a
+   *     SeqExpr.
+   *
+   * \param ret_struct_info The StructInfo returned by the function.
+   *     If NullOpt, will be inferred from the StructInfo of the
+   *     function's body.
+   *
+   * \param is_pure The purity of the function.
+   *
+   * \param attrs Any attributes associated with the function.
+   *     Defaults to an empty dictionary.
+   *
+   * \param span The source span of the expression.
+   */
   TVM_DLL explicit Function(Array<Var> params, Expr body, Optional<StructInfo> 
ret_struct_info,
                             bool is_pure = true, DictAttrs attrs = 
DictAttrs(), Span span = Span());
 
diff --git a/src/contrib/msc/core/ir/graph_builder.cc 
b/src/contrib/msc/core/ir/graph_builder.cc
index 02b5a2ee67..d35a462579 100644
--- a/src/contrib/msc/core/ir/graph_builder.cc
+++ b/src/contrib/msc/core/ir/graph_builder.cc
@@ -166,12 +166,9 @@ const MSCGraph RelaxGraphBuilder::Build(const 
relax::Function& func) {
     }
   }
   VisitExpr(func);
-  if (const auto* b_node = func->body.as<relax::SeqExprNode>()) {
-    ICHECK(expr_tensor_map_.count(b_node->body)) << "Can not find seqexpr body 
" << b_node->body;
-    output_names = expr_tensor_map_[b_node->body];
-  } else {
-    LOG(FATAL) << "Function body should be SeqExpr, get " << func->body;
-  }
+  ICHECK(expr_tensor_map_.count(func->body->body))
+      << "Can not find seqexpr body " << func->body->body;
+  output_names = expr_tensor_map_[func->body->body];
   // remove const nodes as weights
   Array<MSCJoint> valid_nodes;
   std::set<String> ignore_inputs;
diff --git a/src/contrib/msc/core/transform/set_expr_layout.cc 
b/src/contrib/msc/core/transform/set_expr_layout.cc
index 0ece7a51ca..76775a5ba3 100644
--- a/src/contrib/msc/core/transform/set_expr_layout.cc
+++ b/src/contrib/msc/core/transform/set_expr_layout.cc
@@ -1268,13 +1268,9 @@ class LayoutInfer : public ExprVisitor {
         SetExprLayout(call->args[i], var_layout_map_[func->params[i]]);
       }
     }
-    if (const auto* b_node = func->body.as<relax::SeqExprNode>()) {
-      if (b_node->body->IsInstance<VarNode>() &&
-          var_layout_map_.count(Downcast<Var>(b_node->body))) {
-        SetExprLayout(ret, var_layout_map_[Downcast<Var>(b_node->body)]);
-      }
-    } else {
-      LOG(FATAL) << "Function body should be SeqExpr, get " << func->body;
+    if (func->body->body->IsInstance<VarNode>() &&
+        var_layout_map_.count(Downcast<Var>(func->body->body))) {
+      SetExprLayout(ret, var_layout_map_[Downcast<Var>(func->body->body)]);
     }
   }
 
@@ -1288,13 +1284,9 @@ class LayoutInfer : public ExprVisitor {
           if (producer->IsInstance<CallNode>() &&
               local_funcs_.count(Downcast<Call>(producer)->op)) {
             const auto& caller = local_funcs_[Downcast<Call>(producer)->op];
-            if (const auto* b_node = caller->body.as<relax::SeqExprNode>()) {
-              if (b_node->body->IsInstance<VarNode>() &&
-                  var_map_.count(Downcast<Var>(b_node->body))) {
-                SetExprLayout(b_node->body, param_layout);
-              }
-            } else {
-              LOG(FATAL) << "Caller body should be SeqExpr, get " << 
caller->body;
+            if (caller->body->body->IsInstance<VarNode>() &&
+                var_map_.count(Downcast<Var>(caller->body->body))) {
+              SetExprLayout(caller->body->body, param_layout);
             }
           }
         }
diff --git a/src/relax/analysis/well_formed.cc 
b/src/relax/analysis/well_formed.cc
index b4a0fc4b98..a73e6fb233 100644
--- a/src/relax/analysis/well_formed.cc
+++ b/src/relax/analysis/well_formed.cc
@@ -281,11 +281,7 @@ class WellFormedChecker : public relax::ExprVisitor,
       }
     }
 
-    if (auto seq = op->body.as<SeqExprNode>()) {
-      this->VisitSeqExpr(seq);
-    } else {
-      Malformed(Diagnostic::Error(op) << "Function bodies must be sequence 
expressions");
-    }
+    this->VisitSeqExpr(op->body.get());
 
     is_dataflow_ = old_dataflow_state;
     dataflow_var_set_ = prev_dataflow_var_set;
@@ -367,21 +363,17 @@ class WellFormedChecker : public relax::ExprVisitor,
     } else {
       Malformed(Diagnostic::Error(op) << "The condition for an if node must be 
a leaf expression.");
     }
-    auto true_seq = op->true_branch.as<SeqExprNode>();
-    auto false_seq = op->false_branch.as<SeqExprNode>();
-    if (true_seq && false_seq) {
-      std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual> previous_var_set 
= var_set_;
-      std::unordered_set<tir::Var, ObjectPtrHash, ObjectPtrEqual> 
previous_symbolic_var_set =
-          symbolic_var_set_;
-      this->VisitSeqExpr(true_seq);
-      var_set_ = previous_var_set;
-      symbolic_var_set_ = previous_symbolic_var_set;
-      this->VisitSeqExpr(false_seq);
-      var_set_ = previous_var_set;
-      symbolic_var_set_ = previous_symbolic_var_set;
-    } else {
-      Malformed(Diagnostic::Error(op) << "If node branches must be seq exprs");
-    }
+
+    std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual> previous_var_set = 
var_set_;
+    std::unordered_set<tir::Var, ObjectPtrHash, ObjectPtrEqual> 
previous_symbolic_var_set =
+        symbolic_var_set_;
+    this->VisitSeqExpr(op->true_branch.get());
+    var_set_ = previous_var_set;
+    symbolic_var_set_ = previous_symbolic_var_set;
+    this->VisitSeqExpr(op->false_branch.get());
+    var_set_ = previous_var_set;
+    symbolic_var_set_ = previous_symbolic_var_set;
+
     CheckStructInfo(op);
   }
 
diff --git a/src/relax/backend/contrib/utils.cc 
b/src/relax/backend/contrib/utils.cc
index 20b2a6fce6..b260ea24be 100644
--- a/src/relax/backend/contrib/utils.cc
+++ b/src/relax/backend/contrib/utils.cc
@@ -36,7 +36,7 @@ Map<String, IntImm> ExtractArgIdx(String pattern_name, 
Function f) {
   ICHECK(pattern) << "Unsupported op_type " << pattern_name;
 
   auto bindings = AnalyzeVar2Value(f);
-  auto inner_body = Downcast<SeqExpr>(f->body)->body;
+  auto inner_body = f->body->body;
   auto matched_expr = relax::ExtractMatchedExpr(pattern.value()->pattern, 
inner_body, bindings);
   ICHECK(matched_expr) << "ValueError: "
                        << "For named pattern \"" << pattern_name
diff --git a/src/relax/ir/dataflow_matcher.cc b/src/relax/ir/dataflow_matcher.cc
index cf8934c372..c0b8d1e1df 100644
--- a/src/relax/ir/dataflow_matcher.cc
+++ b/src/relax/ir/dataflow_matcher.cc
@@ -59,13 +59,30 @@ bool DFPatternMatcher::Match(const DFPattern& pattern, 
const Expr& expr) {
   return VisitDFPattern(pattern, expr);
 }
 
-static Expr TryGetValOfVar(const Expr& expr, const Map<Var, Expr>& var2val) {
-  if (var2val.empty()) return expr;
+static Expr TryGetValOfVar(Expr expr, const Map<Var, Expr>& var2val) {
+  auto unwrap = [&](Expr expr) -> Optional<Expr> {
+    // Unwrap variables into the value to which they are bound.
+    if (var2val.size()) {
+      if (const VarNode* var = expr.as<VarNode>()) {
+        if (auto may = var2val.Get(GetRef<Var>(var))) {
+          return may.value();
+        }
+      }
+    }
+
+    // Unwrap SeqExpr with no bindings.  These can occur due to Relax
+    // IR constraints for the bodies of Function and If nodes.
+    if (auto seq = expr.as<SeqExprNode>()) {
+      if (seq->blocks.empty()) {
+        return seq->body;
+      }
+    }
+
+    return NullOpt;
+  };
 
-  // if not match, try to match value of var if expr is a var.
-  if (const VarNode* var = expr.as<VarNode>()) {
-    auto may = var2val.Get(GetRef<Var>(var));
-    if (may.defined()) return may.value();
+  while (auto unwrapped = unwrap(expr)) {
+    expr = unwrapped.value();
   }
 
   return expr;
diff --git a/src/relax/ir/expr.cc b/src/relax/ir/expr.cc
index dd0f68dca4..eb46775765 100644
--- a/src/relax/ir/expr.cc
+++ b/src/relax/ir/expr.cc
@@ -492,6 +492,14 @@ 
TVM_REGISTER_GLOBAL("relax.DataflowBlock").set_body_typed([](Array<Binding> bind
 
 TVM_REGISTER_NODE_TYPE(SeqExprNode);
 
+SeqExpr::SeqExpr(Expr body) {
+  if (auto seq = body.as<SeqExpr>()) {
+    *this = seq.value();
+  } else {
+    *this = SeqExpr(Array<BindingBlock>{}, body);
+  }
+}
+
 SeqExpr::SeqExpr(Array<BindingBlock> blocks, Expr body, Span span) {
   ObjectPtr<SeqExprNode> n = make_object<SeqExprNode>();
   n->blocks = std::move(blocks);
diff --git a/src/relax/training/utils.cc b/src/relax/training/utils.cc
index 19faaad58b..a7348483f6 100644
--- a/src/relax/training/utils.cc
+++ b/src/relax/training/utils.cc
@@ -65,13 +65,10 @@ class AppendLossMutator : private ExprMutator {
         num_backbone_outputs_(num_backbone_outputs) {}
 
   Expr VisitExpr_(const FunctionNode* func) final {
-    CHECK(func->body->IsInstance<SeqExprNode>() && 
loss_function_->body->IsInstance<SeqExprNode>())
-        << "The bodies of the backbone and the loss function must be SeqExpr.";
-
     // Well-formed checks and setting up class members
-    loss_body_ = Downcast<SeqExpr>(loss_function_->body);
+    loss_body_ = loss_function_->body;
     CheckLossBody();
-    BackboneReturnToArr(func->body.as<SeqExprNode>()->body);
+    BackboneReturnToArr(func->body->body);
     CheckAndRemapBackboneReturn();
     CheckAndRemapLossParams(loss_function_->params);
 
diff --git a/src/relax/transform/fuse_ops.cc b/src/relax/transform/fuse_ops.cc
index ee96f9fa80..04c07c439c 100644
--- a/src/relax/transform/fuse_ops.cc
+++ b/src/relax/transform/fuse_ops.cc
@@ -1266,9 +1266,19 @@ class CompositeFunctionAnnotator : public ExprMutator {
       params.push_back(new_v);
     }
 
+    // We cannot delegate to `ExprMutator::VisitExpr_(const FunctionNode*)` at 
this point, as it
+    // would recursively visit the Call node.  However, we are still required 
to generate
+    // well-formed Relax IR.  As a result, we need to build the SeqExpr 
ourselves.
+    Var local_func_var("local_func", GetStructInfo(f_inner));
+    Var output_var("output", f_inner->ret_struct_info);
+    SeqExpr new_body({BindingBlock({
+                         VarBinding(local_func_var, f_inner),
+                         VarBinding(output_var, Call(local_func_var, params)),
+                     })},
+                     output_var);
+
     // pure if the inner func is pure (no need to force purity if it's forced 
for the inner func)
-    return Function(param_vars, Call(f_inner, params), 
func_node->ret_struct_info,
-                    f_inner->is_pure);
+    return Function(param_vars, new_body, func_node->ret_struct_info, 
f_inner->is_pure);
   }
 
  private:
diff --git a/src/relax/transform/fuse_tir.cc b/src/relax/transform/fuse_tir.cc
index 3df17b29ca..cb8d340f7d 100644
--- a/src/relax/transform/fuse_tir.cc
+++ b/src/relax/transform/fuse_tir.cc
@@ -438,9 +438,7 @@ class FusedTIRConstructor : public ExprVisitor {
     ExprVisitor::VisitExpr_(func);
 
     // Step 3. Create and remap buffers for function output
-    ICHECK(func->body->IsInstance<SeqExprNode>())
-        << "Function body is expected to be a SeqExpr, but got: " << 
func->body->GetTypeKey();
-    Expr body = Downcast<SeqExpr>(func->body)->body;
+    Expr body = func->body->body;
     auto it = func_info_.expr2buffers.find(body);
     ICHECK(it != func_info_.expr2buffers.end())
         << "Fail to detect output buffers for function body";
diff --git a/src/relax/transform/gradient.cc b/src/relax/transform/gradient.cc
index 70e3e37876..cd07af37e0 100644
--- a/src/relax/transform/gradient.cc
+++ b/src/relax/transform/gradient.cc
@@ -664,8 +664,6 @@ class GradientMutator : private ExprMutator {
   }
 
   Expr VisitExpr_(const FunctionNode* func) final {
-    CHECK(func->body->IsInstance<SeqExprNode>()) << "The body of the function 
must be SeqExpr.";
-
     orig_params_ = func->params;
     Expr new_body = this->VisitExpr(func->body);
 
diff --git a/src/script/printer/relax/binding.cc 
b/src/script/printer/relax/binding.cc
index 44a2cd338c..c8b616b4bc 100644
--- a/src/script/printer/relax/binding.cc
+++ b/src/script/printer/relax/binding.cc
@@ -27,8 +27,8 @@ IfDoc PrintIfExpr(const relax::If& n, const ObjectPath& n_p, 
const IRDocsifier&
   using relax::SeqExpr;
   ExprDoc cond = d->AsDoc<ExprDoc>(n->cond, n_p->Attr("cond"));
   std::vector<Array<StmtDoc>> branches{
-      PrintSeqExpr(Downcast<SeqExpr>(n->true_branch), 
n_p->Attr("true_branch"), d, false),
-      PrintSeqExpr(Downcast<SeqExpr>(n->false_branch), 
n_p->Attr("false_branch"), d, false),
+      PrintSeqExpr(n->true_branch, n_p->Attr("true_branch"), d, false),
+      PrintSeqExpr(n->false_branch, n_p->Attr("false_branch"), d, false),
   };
   if (var.defined()) {
     for (Array<StmtDoc>& stmts : branches) {
diff --git a/src/script/printer/relax/function.cc 
b/src/script/printer/relax/function.cc
index 458eb3766d..3b5302bebc 100644
--- a/src/script/printer/relax/function.cc
+++ b/src/script/printer/relax/function.cc
@@ -119,8 +119,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
       }
 
       // Step 6. Print body
-      Array<StmtDoc> body =
-          PrintSeqExpr(Downcast<relax::SeqExpr>(n->body), n_p->Attr("body"), 
d, /*use_ret=*/true);
+      Array<StmtDoc> body = PrintSeqExpr(n->body, n_p->Attr("body"), d, 
/*use_ret=*/true);
       (*f)->stmts.insert((*f)->stmts.end(), body.begin(), body.end());
       return HeaderWrapper(d, FunctionDoc(func_name, params, {decorator}, 
ret_type, (*f)->stmts));
     });
diff --git a/tests/python/relax/test_expr_functor.py 
b/tests/python/relax/test_expr_functor.py
index 0daf9d4a1f..f3d2432549 100644
--- a/tests/python/relax/test_expr_functor.py
+++ b/tests/python/relax/test_expr_functor.py
@@ -439,7 +439,7 @@ def test_if():
     if_node = relax.If(x, x, x)
     basic_check(
         if_node,
-        "\n".join(["If", "\tVar", "\tVar", "\tVar"]),
+        "\n".join(["If", "\tVar", "\tSeqExpr", "\t\tVar", "\tSeqExpr", 
"\t\tVar"]),
         "\n".join(["Var", "Var", "SeqExpr", "Var", "SeqExpr", "If"]),
     )
 

Reply via email to