This is an automated email from the ASF dual-hosted git repository.
sanirudh pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new a2511cc516 [QoL][Relax] Use SeqExpr in IR types when SeqExpr is
required (#16859)
a2511cc516 is described below
commit a2511cc5160fa73131517515c79144bef7f4b076
Author: Eric Lunderberg <[email protected]>
AuthorDate: Sat Apr 20 03:15:52 2024 -0500
[QoL][Relax] Use SeqExpr in IR types when SeqExpr is required (#16859)
* [QoL][Relax] Use SeqExpr in IR types when SeqExpr is required
The Relax IR requires the `FunctionNode::body`, `IfNode::true_branch`,
and `IfNode::false_branch` to be instances of `relax::SeqExpr`.
If these Relax requirements are violated, correctly-implemented
transformations may raise exceptions
(e.g. from `Downcast` in `Downcast<SeqExpr>(func->body)->blocks`), or
even segfault (e.g. when `.as` returns a nullptr in
`func->body.as<SeqExprNode>()->blocks`). Debugging these failure
modes is also difficult, as even the TVMScript printer relies on the
body of the function being a `SeqExprNode`.
This commit updates the C++ type of `FunctionNode::body`,
`IfNode::true_branch`, and `IfNode::false_branch` to be
`relax::SeqExpr` instead of `relax::Expr`. This does not impact any
well-formed Relax IR, and allows this kind of ill-formed Relax IR
to be detected at compile-time. A large number of checks applied
during TVM runtime can now be removed, as they duplicate the new
compile-time check.
To maintain backwards compatibility, this commit adds a new
constructor to `relax::SeqExpr`, which accepts a single `Expr body`
argument. This constructor either provides an additional reference to
the same underlying `relax::SeqExprNode`, if `body` already contains a
`relax::SeqExprNode`, or otherwise wraps the body in a new
`relax::SeqExpr`. For implementations that previously produced
well-formed Relax IR, this change has no effect. For implementations
that previously produced ill-formed Relax IR, this change results in
the equivalent well-formed Relax IR.
Alternate implementations considered:
* Perform the backwards-compatibility wrapping within the
`relax::Function` and `relax::If` constructors. While this would
provide the intended conversion when these constructors are used,
Relax transforms make frequent use of copy-on-write
(e.g. `func.CopyOnWrite()->body = new_body`), which does not use the
constructor. Maintaining backwards compatibility for this usage
requires the implicit conversion constructor that was chosen for
this PR.
* Remove the Relax IR requirement for these expressions to be
`SeqExpr`. While this would make Relax more internally consistent,
such a change would break backwards compatibility that relies on
`SeqExpr` being present. While the callsites within TVM could be
updated to resolve this breakage, callsites outside of TVM
(e.g. MLC-LLM) could not. Exposing the special case within the C++
type, as done in this PR, maintains backwards compatibility.
* Resolve breakages in unit tests
All breakage was the result of callers relying on ill-formed Relax IR
keeping its specific ill-formed structure.
---
include/tvm/relax/expr.h | 190 +++++++++++++---------
src/contrib/msc/core/ir/graph_builder.cc | 9 +-
src/contrib/msc/core/transform/set_expr_layout.cc | 20 +--
src/relax/analysis/well_formed.cc | 32 ++--
src/relax/backend/contrib/utils.cc | 2 +-
src/relax/ir/dataflow_matcher.cc | 29 +++-
src/relax/ir/expr.cc | 8 +
src/relax/training/utils.cc | 7 +-
src/relax/transform/fuse_ops.cc | 14 +-
src/relax/transform/fuse_tir.cc | 4 +-
src/relax/transform/gradient.cc | 2 -
src/script/printer/relax/binding.cc | 4 +-
src/script/printer/relax/function.cc | 3 +-
tests/python/relax/test_expr_functor.py | 2 +-
14 files changed, 189 insertions(+), 137 deletions(-)
diff --git a/include/tvm/relax/expr.h b/include/tvm/relax/expr.h
index e2176cf720..0ca92a01a7 100644
--- a/include/tvm/relax/expr.h
+++ b/include/tvm/relax/expr.h
@@ -213,78 +213,6 @@ Call WithFields(Call call, Optional<Expr> opt_op =
Optional<Expr>(),
Optional<Array<StructInfo>> opt_sinfo_args =
Optional<Array<StructInfo>>(),
Optional<Span> opt_span = Optional<Span>());
-/*!
- * \brief Condition expression
- *
- * Unlike traditional statement `if`s, the if evalutes
- * to the result of the branch taken.
- *
- * x = if (true) { 1 } else { 0 }; // x is 1
- * y = if (false) { 1 } else { 0 }; // y is 0
- *
- * \note This is similar to C's ternary operator.
- */
-class IfNode : public ExprNode {
- public:
- /*! \brief The condition. */
- Expr cond;
- /*! \brief The expression evaluated when condition is true. */
- Expr true_branch;
- /*! \brief The expression evaluated when condition is false */
- Expr false_branch;
-
- void VisitAttrs(tvm::AttrVisitor* v) {
- v->Visit("cond", &cond);
- v->Visit("true_branch", &true_branch);
- v->Visit("false_branch", &false_branch);
- v->Visit("_checked_type_", &checked_type_);
- v->Visit("struct_info_", &struct_info_);
- v->Visit("span", &span);
- }
-
- bool SEqualReduce(const IfNode* other, SEqualReducer equal) const {
- equal->MarkGraphNode();
- return equal(cond, other->cond) && equal(true_branch, other->true_branch)
&&
- equal(false_branch, other->false_branch) && equal(struct_info_,
other->struct_info_);
- }
-
- void SHashReduce(SHashReducer hash_reduce) const {
- hash_reduce->MarkGraphNode();
- hash_reduce(cond);
- hash_reduce(true_branch);
- hash_reduce(false_branch);
- hash_reduce(struct_info_);
- }
-
- static constexpr const char* _type_key = "relax.expr.If";
- TVM_DECLARE_FINAL_OBJECT_INFO(IfNode, ExprNode);
-};
-
-class If : public Expr {
- public:
- /*!
- * \brief The constructor
- * \param cond The condition of a if node.
- * \param true_branch The fall through branch
- * \param false_branch The branch for execution when condition is false.
- * \param span The source span of the expression.
- */
- TVM_DLL If(Expr cond, Expr true_branch, Expr false_branch, Span span =
Span());
-
- TVM_DEFINE_OBJECT_REF_METHODS(If, Expr, IfNode);
- TVM_DEFINE_OBJECT_REF_COW_METHOD(IfNode);
-};
-
-/*!
- * \brief Returns \p if_expr with the given properties. A null property
denotes 'no change'.
- * Returns \p if_expr if all properties are unchanged. Otherwise, returns a
copy with the new
- * fields.
- */
-If WithFields(If if_expr, Optional<Expr> opt_cond = Optional<Expr>(),
- Optional<Expr> opt_true_branch = Optional<Expr>(),
- Optional<Expr> opt_false_branch = Optional<Expr>(),
- Optional<Span> opt_span = Optional<Span>());
-
/*! \brief Tuple container */
class TupleNode : public ExprNode {
public:
@@ -915,18 +843,113 @@ class SeqExprNode : public ExprNode {
class SeqExpr : public Expr {
public:
+ /* \brief Implicit conversion constructor
+ *
+ * Relax nodes that introduce a new scope (e.g. `relax::Function`)
+ * are required to be held as SeqExpr. This implicit conversion
+ * allows callsites to use these member variables when the
+ * C++ compile-time type is a `relax::Expr`. For example,
+ * a transform may use `func.CopyOnWrite()->body = expr;`.
+ *
+ * If the expression is already a `relax::SeqExpr`, the same
+ * underlying `relax::SeqExprNode` is used, and no copies are made.
+ */
+ TVM_DLL SeqExpr(Expr body); // NOLINT(*)
+
TVM_DLL explicit SeqExpr(Array<BindingBlock> blocks, Expr body, Span span =
Span());
TVM_DEFINE_OBJECT_REF_METHODS(SeqExpr, Expr, SeqExprNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(SeqExprNode);
};
+/*!
+ * \brief Condition expression
+ *
+ * Unlike traditional statement `if`s, the if evalutes
+ * to the result of the branch taken.
+ *
+ * x = if (true) { 1 } else { 0 }; // x is 1
+ * y = if (false) { 1 } else { 0 }; // y is 0
+ *
+ * \note This is similar to C's ternary operator.
+ */
+class IfNode : public ExprNode {
+ public:
+ /*! \brief The condition. */
+ Expr cond;
+ /*! \brief The expression evaluated when condition is true. */
+ SeqExpr true_branch;
+ /*! \brief The expression evaluated when condition is false */
+ SeqExpr false_branch;
+
+ void VisitAttrs(tvm::AttrVisitor* v) {
+ v->Visit("cond", &cond);
+ v->Visit("true_branch", &true_branch);
+ v->Visit("false_branch", &false_branch);
+ v->Visit("_checked_type_", &checked_type_);
+ v->Visit("struct_info_", &struct_info_);
+ v->Visit("span", &span);
+ }
+
+ bool SEqualReduce(const IfNode* other, SEqualReducer equal) const {
+ equal->MarkGraphNode();
+ return equal(cond, other->cond) && equal(true_branch, other->true_branch)
&&
+ equal(false_branch, other->false_branch) && equal(struct_info_,
other->struct_info_);
+ }
+
+ void SHashReduce(SHashReducer hash_reduce) const {
+ hash_reduce->MarkGraphNode();
+ hash_reduce(cond);
+ hash_reduce(true_branch);
+ hash_reduce(false_branch);
+ hash_reduce(struct_info_);
+ }
+
+ static constexpr const char* _type_key = "relax.expr.If";
+ TVM_DECLARE_FINAL_OBJECT_INFO(IfNode, ExprNode);
+};
+
+class If : public Expr {
+ public:
+ /*!
+ * \brief The constructor
+ *
+ * \param cond The condition of a if node.
+ *
+ * \param true_branch The fall through branch. If this is not a
+ * SeqExpr, it will be wrapped in a SeqExpr, to satisfy the
+ * Relax IR requirement that all scopes be contained in a
+ * SeqExpr.
+ *
+ * \param false_branch The branch for execution when condition is
+ * false. If this is not a SeqExpr, it will be wrapped in a
+ * SeqExpr, to satisfy the Relax IR requirement that all scopes
+ * be contained in a SeqExpr.
+ *
+ * \param span The source span of the expression.
+ */
+ TVM_DLL If(Expr cond, Expr true_branch, Expr false_branch, Span span =
Span());
+
+ TVM_DEFINE_OBJECT_REF_METHODS(If, Expr, IfNode);
+ TVM_DEFINE_OBJECT_REF_COW_METHOD(IfNode);
+};
+
+/*!
+ * \brief Returns \p if_expr with the given properties. A null property
denotes 'no change'.
+ * Returns \p if_expr if all properties are unchanged. Otherwise, returns a
copy with the new
+ * fields.
+ */
+If WithFields(If if_expr, Optional<Expr> opt_cond = Optional<Expr>(),
+ Optional<Expr> opt_true_branch = Optional<Expr>(),
+ Optional<Expr> opt_false_branch = Optional<Expr>(),
+ Optional<Span> opt_span = Optional<Span>());
+
/*! \brief A Relax function. */
class FunctionNode : public BaseFuncNode {
public:
/*! \brief The parameters to the function. */
Array<Var> params;
/*! \brief The body of the function. */
- Expr body;
+ SeqExpr body;
/*! \brief The return type of the function. */
StructInfo ret_struct_info;
/*! \brief Whether the function is annotated as pure or not. */
@@ -968,6 +991,27 @@ class FunctionNode : public BaseFuncNode {
class Function : public BaseFunc {
public:
+ /*!
+ * \brief Construct a Relax Function
+ *
+ * \param params The parameters accepted by the function
+ *
+ * \param body The body of the function. If this is not a
+ * SeqExpr, it will be wrapped in a SeqExpr, to satisfy the
+ * Relax IR requirement that all scopes be contained in a
+ * SeqExpr.
+ *
+ * \param ret_struct_info The StructInfo returned by the function.
+ * If NullOpt, will be inferred from the StructInfo of the
+ * function's body.
+ *
+ * \param is_pure The purity of the function.
+ *
+ * \param attrs Any attributes associated with the function.
+ * Defaults to an empty dictionary.
+ *
+ * \param span The source span of the expression.
+ */
TVM_DLL explicit Function(Array<Var> params, Expr body, Optional<StructInfo>
ret_struct_info,
bool is_pure = true, DictAttrs attrs =
DictAttrs(), Span span = Span());
diff --git a/src/contrib/msc/core/ir/graph_builder.cc
b/src/contrib/msc/core/ir/graph_builder.cc
index 02b5a2ee67..d35a462579 100644
--- a/src/contrib/msc/core/ir/graph_builder.cc
+++ b/src/contrib/msc/core/ir/graph_builder.cc
@@ -166,12 +166,9 @@ const MSCGraph RelaxGraphBuilder::Build(const
relax::Function& func) {
}
}
VisitExpr(func);
- if (const auto* b_node = func->body.as<relax::SeqExprNode>()) {
- ICHECK(expr_tensor_map_.count(b_node->body)) << "Can not find seqexpr body
" << b_node->body;
- output_names = expr_tensor_map_[b_node->body];
- } else {
- LOG(FATAL) << "Function body should be SeqExpr, get " << func->body;
- }
+ ICHECK(expr_tensor_map_.count(func->body->body))
+ << "Can not find seqexpr body " << func->body->body;
+ output_names = expr_tensor_map_[func->body->body];
// remove const nodes as weights
Array<MSCJoint> valid_nodes;
std::set<String> ignore_inputs;
diff --git a/src/contrib/msc/core/transform/set_expr_layout.cc
b/src/contrib/msc/core/transform/set_expr_layout.cc
index 0ece7a51ca..76775a5ba3 100644
--- a/src/contrib/msc/core/transform/set_expr_layout.cc
+++ b/src/contrib/msc/core/transform/set_expr_layout.cc
@@ -1268,13 +1268,9 @@ class LayoutInfer : public ExprVisitor {
SetExprLayout(call->args[i], var_layout_map_[func->params[i]]);
}
}
- if (const auto* b_node = func->body.as<relax::SeqExprNode>()) {
- if (b_node->body->IsInstance<VarNode>() &&
- var_layout_map_.count(Downcast<Var>(b_node->body))) {
- SetExprLayout(ret, var_layout_map_[Downcast<Var>(b_node->body)]);
- }
- } else {
- LOG(FATAL) << "Function body should be SeqExpr, get " << func->body;
+ if (func->body->body->IsInstance<VarNode>() &&
+ var_layout_map_.count(Downcast<Var>(func->body->body))) {
+ SetExprLayout(ret, var_layout_map_[Downcast<Var>(func->body->body)]);
}
}
@@ -1288,13 +1284,9 @@ class LayoutInfer : public ExprVisitor {
if (producer->IsInstance<CallNode>() &&
local_funcs_.count(Downcast<Call>(producer)->op)) {
const auto& caller = local_funcs_[Downcast<Call>(producer)->op];
- if (const auto* b_node = caller->body.as<relax::SeqExprNode>()) {
- if (b_node->body->IsInstance<VarNode>() &&
- var_map_.count(Downcast<Var>(b_node->body))) {
- SetExprLayout(b_node->body, param_layout);
- }
- } else {
- LOG(FATAL) << "Caller body should be SeqExpr, get " <<
caller->body;
+ if (caller->body->body->IsInstance<VarNode>() &&
+ var_map_.count(Downcast<Var>(caller->body->body))) {
+ SetExprLayout(caller->body->body, param_layout);
}
}
}
diff --git a/src/relax/analysis/well_formed.cc
b/src/relax/analysis/well_formed.cc
index b4a0fc4b98..a73e6fb233 100644
--- a/src/relax/analysis/well_formed.cc
+++ b/src/relax/analysis/well_formed.cc
@@ -281,11 +281,7 @@ class WellFormedChecker : public relax::ExprVisitor,
}
}
- if (auto seq = op->body.as<SeqExprNode>()) {
- this->VisitSeqExpr(seq);
- } else {
- Malformed(Diagnostic::Error(op) << "Function bodies must be sequence
expressions");
- }
+ this->VisitSeqExpr(op->body.get());
is_dataflow_ = old_dataflow_state;
dataflow_var_set_ = prev_dataflow_var_set;
@@ -367,21 +363,17 @@ class WellFormedChecker : public relax::ExprVisitor,
} else {
Malformed(Diagnostic::Error(op) << "The condition for an if node must be
a leaf expression.");
}
- auto true_seq = op->true_branch.as<SeqExprNode>();
- auto false_seq = op->false_branch.as<SeqExprNode>();
- if (true_seq && false_seq) {
- std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual> previous_var_set
= var_set_;
- std::unordered_set<tir::Var, ObjectPtrHash, ObjectPtrEqual>
previous_symbolic_var_set =
- symbolic_var_set_;
- this->VisitSeqExpr(true_seq);
- var_set_ = previous_var_set;
- symbolic_var_set_ = previous_symbolic_var_set;
- this->VisitSeqExpr(false_seq);
- var_set_ = previous_var_set;
- symbolic_var_set_ = previous_symbolic_var_set;
- } else {
- Malformed(Diagnostic::Error(op) << "If node branches must be seq exprs");
- }
+
+ std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual> previous_var_set =
var_set_;
+ std::unordered_set<tir::Var, ObjectPtrHash, ObjectPtrEqual>
previous_symbolic_var_set =
+ symbolic_var_set_;
+ this->VisitSeqExpr(op->true_branch.get());
+ var_set_ = previous_var_set;
+ symbolic_var_set_ = previous_symbolic_var_set;
+ this->VisitSeqExpr(op->false_branch.get());
+ var_set_ = previous_var_set;
+ symbolic_var_set_ = previous_symbolic_var_set;
+
CheckStructInfo(op);
}
diff --git a/src/relax/backend/contrib/utils.cc
b/src/relax/backend/contrib/utils.cc
index 20b2a6fce6..b260ea24be 100644
--- a/src/relax/backend/contrib/utils.cc
+++ b/src/relax/backend/contrib/utils.cc
@@ -36,7 +36,7 @@ Map<String, IntImm> ExtractArgIdx(String pattern_name,
Function f) {
ICHECK(pattern) << "Unsupported op_type " << pattern_name;
auto bindings = AnalyzeVar2Value(f);
- auto inner_body = Downcast<SeqExpr>(f->body)->body;
+ auto inner_body = f->body->body;
auto matched_expr = relax::ExtractMatchedExpr(pattern.value()->pattern,
inner_body, bindings);
ICHECK(matched_expr) << "ValueError: "
<< "For named pattern \"" << pattern_name
diff --git a/src/relax/ir/dataflow_matcher.cc b/src/relax/ir/dataflow_matcher.cc
index cf8934c372..c0b8d1e1df 100644
--- a/src/relax/ir/dataflow_matcher.cc
+++ b/src/relax/ir/dataflow_matcher.cc
@@ -59,13 +59,30 @@ bool DFPatternMatcher::Match(const DFPattern& pattern,
const Expr& expr) {
return VisitDFPattern(pattern, expr);
}
-static Expr TryGetValOfVar(const Expr& expr, const Map<Var, Expr>& var2val) {
- if (var2val.empty()) return expr;
+static Expr TryGetValOfVar(Expr expr, const Map<Var, Expr>& var2val) {
+ auto unwrap = [&](Expr expr) -> Optional<Expr> {
+ // Unwrap variables into the value to which they are bound.
+ if (var2val.size()) {
+ if (const VarNode* var = expr.as<VarNode>()) {
+ if (auto may = var2val.Get(GetRef<Var>(var))) {
+ return may.value();
+ }
+ }
+ }
+
+ // Unwrap SeqExpr with no bindings. These can occur due to Relax
+ // IR constraints for the bodies of Function and If nodes.
+ if (auto seq = expr.as<SeqExprNode>()) {
+ if (seq->blocks.empty()) {
+ return seq->body;
+ }
+ }
+
+ return NullOpt;
+ };
- // if not match, try to match value of var if expr is a var.
- if (const VarNode* var = expr.as<VarNode>()) {
- auto may = var2val.Get(GetRef<Var>(var));
- if (may.defined()) return may.value();
+ while (auto unwrapped = unwrap(expr)) {
+ expr = unwrapped.value();
}
return expr;
diff --git a/src/relax/ir/expr.cc b/src/relax/ir/expr.cc
index dd0f68dca4..eb46775765 100644
--- a/src/relax/ir/expr.cc
+++ b/src/relax/ir/expr.cc
@@ -492,6 +492,14 @@
TVM_REGISTER_GLOBAL("relax.DataflowBlock").set_body_typed([](Array<Binding> bind
TVM_REGISTER_NODE_TYPE(SeqExprNode);
+SeqExpr::SeqExpr(Expr body) {
+ if (auto seq = body.as<SeqExpr>()) {
+ *this = seq.value();
+ } else {
+ *this = SeqExpr(Array<BindingBlock>{}, body);
+ }
+}
+
SeqExpr::SeqExpr(Array<BindingBlock> blocks, Expr body, Span span) {
ObjectPtr<SeqExprNode> n = make_object<SeqExprNode>();
n->blocks = std::move(blocks);
diff --git a/src/relax/training/utils.cc b/src/relax/training/utils.cc
index 19faaad58b..a7348483f6 100644
--- a/src/relax/training/utils.cc
+++ b/src/relax/training/utils.cc
@@ -65,13 +65,10 @@ class AppendLossMutator : private ExprMutator {
num_backbone_outputs_(num_backbone_outputs) {}
Expr VisitExpr_(const FunctionNode* func) final {
- CHECK(func->body->IsInstance<SeqExprNode>() &&
loss_function_->body->IsInstance<SeqExprNode>())
- << "The bodies of the backbone and the loss function must be SeqExpr.";
-
// Well-formed checks and setting up class members
- loss_body_ = Downcast<SeqExpr>(loss_function_->body);
+ loss_body_ = loss_function_->body;
CheckLossBody();
- BackboneReturnToArr(func->body.as<SeqExprNode>()->body);
+ BackboneReturnToArr(func->body->body);
CheckAndRemapBackboneReturn();
CheckAndRemapLossParams(loss_function_->params);
diff --git a/src/relax/transform/fuse_ops.cc b/src/relax/transform/fuse_ops.cc
index ee96f9fa80..04c07c439c 100644
--- a/src/relax/transform/fuse_ops.cc
+++ b/src/relax/transform/fuse_ops.cc
@@ -1266,9 +1266,19 @@ class CompositeFunctionAnnotator : public ExprMutator {
params.push_back(new_v);
}
+ // We cannot delegate to `ExprMutator::VisitExpr_(const FunctionNode*)` at
this point, as it
+ // would recursively visit the Call node. However, we are still required
to generate
+ // well-formed Relax IR. As a result, we need to build the SeqExpr
ourselves.
+ Var local_func_var("local_func", GetStructInfo(f_inner));
+ Var output_var("output", f_inner->ret_struct_info);
+ SeqExpr new_body({BindingBlock({
+ VarBinding(local_func_var, f_inner),
+ VarBinding(output_var, Call(local_func_var, params)),
+ })},
+ output_var);
+
// pure if the inner func is pure (no need to force purity if it's forced
for the inner func)
- return Function(param_vars, Call(f_inner, params),
func_node->ret_struct_info,
- f_inner->is_pure);
+ return Function(param_vars, new_body, func_node->ret_struct_info,
f_inner->is_pure);
}
private:
diff --git a/src/relax/transform/fuse_tir.cc b/src/relax/transform/fuse_tir.cc
index 3df17b29ca..cb8d340f7d 100644
--- a/src/relax/transform/fuse_tir.cc
+++ b/src/relax/transform/fuse_tir.cc
@@ -438,9 +438,7 @@ class FusedTIRConstructor : public ExprVisitor {
ExprVisitor::VisitExpr_(func);
// Step 3. Create and remap buffers for function output
- ICHECK(func->body->IsInstance<SeqExprNode>())
- << "Function body is expected to be a SeqExpr, but got: " <<
func->body->GetTypeKey();
- Expr body = Downcast<SeqExpr>(func->body)->body;
+ Expr body = func->body->body;
auto it = func_info_.expr2buffers.find(body);
ICHECK(it != func_info_.expr2buffers.end())
<< "Fail to detect output buffers for function body";
diff --git a/src/relax/transform/gradient.cc b/src/relax/transform/gradient.cc
index 70e3e37876..cd07af37e0 100644
--- a/src/relax/transform/gradient.cc
+++ b/src/relax/transform/gradient.cc
@@ -664,8 +664,6 @@ class GradientMutator : private ExprMutator {
}
Expr VisitExpr_(const FunctionNode* func) final {
- CHECK(func->body->IsInstance<SeqExprNode>()) << "The body of the function
must be SeqExpr.";
-
orig_params_ = func->params;
Expr new_body = this->VisitExpr(func->body);
diff --git a/src/script/printer/relax/binding.cc
b/src/script/printer/relax/binding.cc
index 44a2cd338c..c8b616b4bc 100644
--- a/src/script/printer/relax/binding.cc
+++ b/src/script/printer/relax/binding.cc
@@ -27,8 +27,8 @@ IfDoc PrintIfExpr(const relax::If& n, const ObjectPath& n_p,
const IRDocsifier&
using relax::SeqExpr;
ExprDoc cond = d->AsDoc<ExprDoc>(n->cond, n_p->Attr("cond"));
std::vector<Array<StmtDoc>> branches{
- PrintSeqExpr(Downcast<SeqExpr>(n->true_branch),
n_p->Attr("true_branch"), d, false),
- PrintSeqExpr(Downcast<SeqExpr>(n->false_branch),
n_p->Attr("false_branch"), d, false),
+ PrintSeqExpr(n->true_branch, n_p->Attr("true_branch"), d, false),
+ PrintSeqExpr(n->false_branch, n_p->Attr("false_branch"), d, false),
};
if (var.defined()) {
for (Array<StmtDoc>& stmts : branches) {
diff --git a/src/script/printer/relax/function.cc
b/src/script/printer/relax/function.cc
index 458eb3766d..3b5302bebc 100644
--- a/src/script/printer/relax/function.cc
+++ b/src/script/printer/relax/function.cc
@@ -119,8 +119,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
}
// Step 6. Print body
- Array<StmtDoc> body =
- PrintSeqExpr(Downcast<relax::SeqExpr>(n->body), n_p->Attr("body"),
d, /*use_ret=*/true);
+ Array<StmtDoc> body = PrintSeqExpr(n->body, n_p->Attr("body"), d,
/*use_ret=*/true);
(*f)->stmts.insert((*f)->stmts.end(), body.begin(), body.end());
return HeaderWrapper(d, FunctionDoc(func_name, params, {decorator},
ret_type, (*f)->stmts));
});
diff --git a/tests/python/relax/test_expr_functor.py
b/tests/python/relax/test_expr_functor.py
index 0daf9d4a1f..f3d2432549 100644
--- a/tests/python/relax/test_expr_functor.py
+++ b/tests/python/relax/test_expr_functor.py
@@ -439,7 +439,7 @@ def test_if():
if_node = relax.If(x, x, x)
basic_check(
if_node,
- "\n".join(["If", "\tVar", "\tVar", "\tVar"]),
+ "\n".join(["If", "\tVar", "\tSeqExpr", "\t\tVar", "\tSeqExpr",
"\t\tVar"]),
"\n".join(["Var", "Var", "SeqExpr", "Var", "SeqExpr", "If"]),
)