This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new 7c88732056 [Unity][Pass] Wellformed Analysis (#14032)
7c88732056 is described below
commit 7c8873205688686acb82d306d1474a16d9a87b6e
Author: Lesheng Jin <[email protected]>
AuthorDate: Fri Feb 17 18:11:08 2023 -0800
[Unity][Pass] Wellformed Analysis (#14032)
This PR implements the Relax well-formed analysis, which checks whether an IRModule
is well-formed. Tests and examples are added.
Co-Authored-by: Ruihang Lai
[[email protected]](mailto:[email protected])
Co-Authored-by: Siyuan Feng
[[email protected]](mailto:[email protected])
Co-Authored-by: Tianqi Chen
[[email protected]](mailto:[email protected])
Co-authored-by: Steven S. Lyubomirsky
[[email protected]](mailto:[email protected])
Co-authored-by: Yong Wu [[email protected]](mailto:[email protected])
Co-Authored-by: Yuchen Jin
[[email protected]](mailto:[email protected])
Co-Authored-by: Yixin Dong <[email protected]>
Co-Authored-by: Chaofan Lin <[email protected]>
Co-Authored-by: Prakalp Srivastava
[[email protected]](mailto:[email protected])
Co-Authored-by: Junru Shao
[[email protected]](mailto:[email protected])
---
include/tvm/relax/analysis.h | 13 +
python/tvm/relax/analysis/analysis.py | 27 ++
python/tvm/relax/ir/instrument.py | 37 ++
src/relax/analysis/well_formed.cc | 465 ++++++++++++++++++++++++
tests/python/relax/conftest.py | 23 ++
tests/python/relax/test_analysis_well_formed.py | 438 ++++++++++++++++++++++
6 files changed, 1003 insertions(+)
diff --git a/include/tvm/relax/analysis.h b/include/tvm/relax/analysis.h
index ff576d4ebb..f9896efdf2 100644
--- a/include/tvm/relax/analysis.h
+++ b/include/tvm/relax/analysis.h
@@ -343,6 +343,19 @@ TVM_DLL relay::OpPatternKind AnalyzeOpPatternKind(const
tir::PrimFunc& func);
*/
TVM_DLL bool HasReshapePattern(const tir::PrimFunc& func);
+/*!
+ * \brief Check if the IRModule is well formed.
+ *
+ * Each violation found is logged as a warning; the return value summarizes the result.
+ *
+ * \param m the IRModule to check.
+ * \param check_struct_info A boolean flag indicating if the property "every Expr
+ * must have defined structure info" will be checked.
+ * \return true if the IRModule is well formed, false if not.
+ * \note By default the structure info is always checked. It is only in test cases
+ * where `check_struct_info` might be false, so that other well-formed requirements
+ * will be well tested and will not be blocked by not having structure info.
+ */
+TVM_DLL bool WellFormed(IRModule m, bool check_struct_info = true);
+
} // namespace relax
} // namespace tvm
diff --git a/python/tvm/relax/analysis/analysis.py
b/python/tvm/relax/analysis/analysis.py
index 27416c3a79..7107883478 100644
--- a/python/tvm/relax/analysis/analysis.py
+++ b/python/tvm/relax/analysis/analysis.py
@@ -25,6 +25,7 @@ from typing import Dict
from enum import IntEnum
from tvm import tir
+from tvm import IRModule
from tvm.relax.ty import Type
from tvm.relax.struct_info import StructInfo, FuncStructInfo
from tvm.relax.expr import Var, Expr, Call
@@ -207,3 +208,29 @@ def has_reshape_pattern(func: tir.PrimFunc) -> bool:
of this function.
"""
return _ffi_api.has_reshape_pattern(func) # type: ignore
+
+
+def well_formed(mod: IRModule, check_struct_info: bool = True) -> bool:
+ """Check if the IRModule is well formed.
+
+ Parameters
+ ----------
+ mod : tvm.IRModule
+ The input IRModule.
+
+ check_struct_info : bool
+ A boolean flag indicating if the property "every Expr must
+ have defined structure info" will be checked.
+
+ Returns
+ -------
+ ret: bool
+ True if the IRModule is well formed, False if not.
+
+ Note
+ ----
+ By default the structure info is always checked. It is only in test cases
+ where `check_struct_info` might be false, so that other well-formed
requirements
+ will be well tested and will not be blocked by not having structure info.
+ """
+ return _ffi_api.well_formed(mod, check_struct_info) # type: ignore
diff --git a/python/tvm/relax/ir/instrument.py
b/python/tvm/relax/ir/instrument.py
new file mode 100644
index 0000000000..fc51a796a7
--- /dev/null
+++ b/python/tvm/relax/ir/instrument.py
@@ -0,0 +1,37 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Common relax pass instrumentation across IR variants."""
+import tvm
+from tvm import relax
+
+
[email protected]_instrument
+class WellFormedInstrument:
+ """An instrument that checks the input/output IRModule of the Pass
+ is well formed. It will skip specific passes, like Normalize.
+ """
+
+ def __init__(self):
+ self.skip_pass_name = ["Normalize", "ResolveGlobals"]
+
+ def run_before_pass(self, mod, pass_info):
+ if pass_info.name not in self.skip_pass_name:
+ assert relax.analysis.well_formed(mod)
+
+ def run_after_pass(self, mod, pass_info):
+ if pass_info.name not in self.skip_pass_name:
+ assert relax.analysis.well_formed(mod)
diff --git a/src/relax/analysis/well_formed.cc
b/src/relax/analysis/well_formed.cc
new file mode 100644
index 0000000000..e7ec237fd5
--- /dev/null
+++ b/src/relax/analysis/well_formed.cc
@@ -0,0 +1,465 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file relax/analysis/well_formed.cc
+ * \brief Check if the IRModule is well-formed.
+ *
+ * This pass is supposed to be applied to a normalized Relax AST.
+ * If the IR is malformed, diagnostic messages are logged as warnings.
+ * This pass will check:
+ *    1. Each Expr should have its `struct_info_` field already populated, when
+ *       `check_struct_info` is true.
+ *    2. GlobalVars are defined before use.
+ *    3. When a Function has a corresponding GlobalVar and a `global_symbol`
+ *       attribute, the name of the GlobalVar must equal the value of the
+ *       `global_symbol` attribute.
+ *    4. A variable cannot be used as a parameter of more than one function in the same IRModule.
+ *    5. Vars are defined before use.
+ *    6. Vars are defined exactly once.
+ *    7. Symbolic Vars are defined before use.
+ *    8. DataflowVars can only be defined and used inside a DataflowBlock
+ *       (never inside a plain BindingBlock).
+ *    9. Vars defined in IfNode, except the return Var, are invisible
+ *       outside the If body. (May change for new AST designs.)
+ *    10. SeqExpr only serves as a function body, or as the true and
+ *        false branches of an IfNode.
+ *    11. The IR is in ANF:
+ *        (a) Expressions cannot contain nested complex expressions.
+ *            Here are the expressions that may be nested inside other expressions:
+ *            Var, DataflowVar, GlobalVar, Constant, ShapeExpr,
+ *            Op, Tuple (we call these "leaf" expressions).
+ *        (b) The right-hand side of a binding may contain a non-leaf expression
+ *            (where all expressions nested in it are leaf expressions),
+ *            other than SeqExprs (see rule 10).
+ *        (c) Exceptions: The body of a Function node and the true branch
+ *            and false branch of If nodes *must* be SeqExprs.
+ *        (d) Places where non-leaf expressions cannot appear:
+ *            * The tuple_value field of TupleGetItem nodes
+ *            * The cond field of If nodes
+ *            * The op or args fields of Call nodes
+ *            * Inside the fields of Tuple nodes
+ *    12. Expr always has checked_type_ (with the exception of Op).
+ */
+#include <tvm/relax/analysis.h>
+#include <tvm/relax/expr.h>
+#include <tvm/relax/expr_functor.h>
+#include <tvm/relax/struct_info_functor.h>
+#include <tvm/relax/utils.h>
+#include <tvm/tir/expr_functor.h>
+
+#include <unordered_set>
+
+namespace tvm {
+namespace relax {
+
+// TODO(relax-team): Consider a further refactor that uses a
+// scope frame to store and manage the var context.
+//
+/*! \brief Helper to implement the well-formed check. */
+class WellFormedChecker : public relax::ExprVisitor,
+ public relax::StructInfoVisitor,
+ public tir::ExprVisitor {
+ public:
+ static bool Check(IRModule mod, bool check_struct_info) {
+ WellFormedChecker well_formed_checker = WellFormedChecker(mod,
check_struct_info);
+
+ for (const auto& it : mod->functions) {
+ // visit relax.Function
+ if (auto* n = it.second.as<FunctionNode>()) {
+ Function func = GetRef<Function>(n);
+ well_formed_checker.CheckGlobalVarAndGsymbolConsistency(it.first,
func);
+ well_formed_checker.VisitExpr(func);
+ }
+ }
+ return well_formed_checker.well_formed_;
+ }
+
+ private:
+ explicit WellFormedChecker(IRModule mod, bool check_struct_info)
+ : mod_(std::move(mod)), check_struct_info_(check_struct_info) {}
+
+ using relax::ExprVisitor::VisitExpr_;
+ using tir::ExprVisitor::VisitExpr;
+ using tir::ExprVisitor::VisitExpr_;
+
+ // Possible mode of visitor
+ enum class VisitMode {
+ /*!
+ * \brief Check all vars are well-defined
+ */
+ kDefault,
+ /*!
+ * \brief Match define the vars on first occurance.
+ * Do not check the well-defined property of composite expr.
+ */
+ kMatchVarDef
+ };
+
+ void Malformed(Diagnostic diag) {
+ well_formed_ = false;
+ LOG(WARNING) << "This IR is not well formed: " << diag->message;
+ }
+
+ void CheckGlobalVarAndGsymbolConsistency(GlobalVar var, Function func) {
+ // check name in global var and gsymbol
+ Optional<String> gsymbol = func->GetAttr<String>(tvm::attr::kGlobalSymbol);
+ if (gsymbol.defined() && gsymbol != var->name_hint) {
+ Malformed(Diagnostic::Error(func->span)
+ << "Name in GlobalVar is not equal to name in gsymbol: " <<
var->name_hint
+ << " != " << gsymbol.value());
+ }
+ }
+
+ void VisitExpr(const Expr& expr) final {
+ if (!expr.as<OpNode>() && !expr->checked_type_.defined()) {
+ Malformed(Diagnostic::Error(expr) << "The checked_type_ of Expr " <<
expr << " is nullptr.");
+ }
+ relax::ExprVisitor::VisitExpr(expr);
+ }
+
+ void VisitExpr_(const GlobalVarNode* op) final {
+ GlobalVar var = GetRef<GlobalVar>(op);
+ if (!(mod_->ContainGlobalVar(var->name_hint) &&
+ mod_->GetGlobalVar(var->name_hint).same_as(var))) {
+ Malformed(Diagnostic::Error(var) << "GlobalVar " << op->name_hint << "
is not defined.");
+ }
+
+ if (op->checked_type_.defined()) {
+ if ((!op->checked_type_->IsInstance<FuncTypeNode>()) &&
+ (!op->checked_type_->IsInstance<PackedFuncTypeNode>())) {
+ Malformed(Diagnostic::Error(var) << "The checked_type_ of GlobalVar "
<< op->name_hint
+ << " must be either FuncType or
PackedFuncType.");
+ }
+ }
+
+ CheckStructInfo(op);
+ }
+
+ void VisitExpr_(const TupleNode* op) final {
+ for (size_t i = 0; i < op->fields.size(); i++) {
+ Expr expr = op->fields[i];
+ if (IsLeafOrTuple(expr)) {
+ this->VisitExpr(expr);
+ } else {
+ Malformed(Diagnostic::Error(expr)
+ << "Tuple is not in ANF form, field " << i << " gets " <<
expr->GetTypeKey());
+ }
+ }
+
+ CheckStructInfo(op);
+ }
+
+ void VisitExpr_(const TupleGetItemNode* op) final {
+ if (IsLeafOrTuple(op->tuple)) {
+ this->VisitExpr(op->tuple);
+ } else {
+ Malformed(Diagnostic::Error(op)
+ << "The tuple value in a TupleGetItem node must be a leaf
expression.");
+ }
+ CheckStructInfo(op);
+ }
+
+ void VisitExpr_(const VarNode* op) final {
+ Var var = GetRef<Var>(op);
+ if (var_set_.count(var) == 0) {
+ Malformed(Diagnostic::Error(var) << "Var " << op->name_hint() << " is
not defined.");
+ }
+ CheckStructInfo(op);
+ }
+
+ void VisitExpr_(const DataflowVarNode* op) final {
+ DataflowVar var = GetRef<DataflowVar>(op);
+ if (!is_dataflow_) {
+ Malformed(Diagnostic::Error(var)
+ << "DataflowVar " << op->name_hint() << " is used outside
DataflowBlock.");
+ }
+ if (dataflow_var_set_.count(var) == 0) {
+ Malformed(Diagnostic::Error(var) << "DataflowVar " << op->name_hint() <<
" is not defined.");
+ }
+ CheckStructInfo(op);
+ }
+
+ void VisitExpr_(const FunctionNode* op) final {
+ // save the var_set_ for local function
+ auto prev_var_set = var_set_;
+ auto prev_dataflow_var_set = dataflow_var_set_;
+ auto prev_symbolic_var_set = symbolic_var_set_;
+ bool old_dataflow_state = is_dataflow_;
+ // symbolic var is not captured across function boundaries
+ symbolic_var_set_.clear();
+ is_dataflow_ = false;
+
+ // first populate defs in params
+ WithMode(VisitMode::kMatchVarDef, [&]() {
+ ICHECK(mode_ == VisitMode::kMatchVarDef);
+ for (Var param : op->params) {
+ relax::StructInfoVisitor::VisitStructInfo(GetStructInfo(param));
+ }
+ });
+
+ // check all expr are well defined.
+ for (Var param : op->params) {
+ this->VisitVarDef(param);
+
+ if (param_var_func_map_.count(param) == 1) {
+ // TODO(relax-team): Complete this error info after we integrate
printer
+ Malformed(Diagnostic::Error(param->span)
+ << "Relax variable " << param->name_hint()
+ << " is repeatedly used as parameters in function.");
+ }
+ param_var_func_map_.insert({param, GetRef<Function>(op)});
+ }
+
+ if (auto seq = op->body.as<SeqExprNode>()) {
+ this->VisitSeqExpr(seq);
+ } else {
+ Malformed(Diagnostic::Error(op) << "Function bodies must be sequence
expressions");
+ }
+
+ is_dataflow_ = old_dataflow_state;
+ dataflow_var_set_ = prev_dataflow_var_set;
+ var_set_ = prev_var_set;
+ symbolic_var_set_ = prev_symbolic_var_set;
+ }
+
+ void VisitExpr_(const CallNode* op) final {
+ if (IsLeafOrTuple(op->op)) {
+ this->VisitExpr(op->op);
+ } else {
+ Malformed(Diagnostic::Error(op) << "The called expression must be a leaf
expression");
+ }
+ for (size_t i = 0; i < op->args.size(); i++) {
+ Expr arg = op->args[i];
+ if (IsLeafOrTuple(arg)) {
+ this->VisitExpr(arg);
+ } else {
+ Malformed(Diagnostic::Error(arg->span)
+ << "Call is not in ANF form, arg " << i << " gets " <<
arg->GetTypeKey());
+ }
+ }
+
+ for (const StructInfo& sinfo_arg : op->sinfo_args) {
+ this->VisitStructInfo(sinfo_arg);
+ }
+
+ CheckStructInfo(op);
+ }
+
+ void VisitExpr_(const IfNode* op) final {
+ if (IsLeafOrTuple(op->cond)) {
+ this->VisitExpr(op->cond);
+ } else {
+ Malformed(Diagnostic::Error(op) << "The condition for an if node must be
a leaf expression.");
+ }
+ auto true_seq = op->true_branch.as<SeqExprNode>();
+ auto false_seq = op->false_branch.as<SeqExprNode>();
+ if (true_seq && false_seq) {
+ std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual> previous_var_set
= var_set_;
+ std::unordered_set<tir::Var, ObjectPtrHash, ObjectPtrEqual>
previous_symbolic_var_set =
+ symbolic_var_set_;
+ this->VisitSeqExpr(true_seq);
+ var_set_ = previous_var_set;
+ symbolic_var_set_ = previous_symbolic_var_set;
+ this->VisitSeqExpr(false_seq);
+ var_set_ = previous_var_set;
+ symbolic_var_set_ = previous_symbolic_var_set;
+ } else {
+ Malformed(Diagnostic::Error(op) << "If node branches must be seq exprs");
+ }
+ CheckStructInfo(op);
+ }
+
+ void VisitExpr_(const ShapeExprNode* op) final {
+ for (PrimExpr expr : op->values) {
+ // check if the symbolic vars in the expr are defined, e.g, 2 * m
+ tir::ExprVisitor::VisitExpr(expr);
+ if (!expr.dtype().is_int()) {
+ Malformed(Diagnostic::Error(expr)
+ << "Shape expressions must be of integer type, but got " <<
expr.dtype());
+ }
+ }
+ CheckStructInfo(op);
+ }
+
+ void VisitExpr_(const SeqExprNode* op) final {
+ Malformed(Diagnostic::Error(op) << "SeqExpr only serves as the function
body in FunctionNode, "
+ "or the true/false branch body in
IfNode.");
+ }
+
+ void VisitSeqExpr(const SeqExprNode* op) {
+ // a special call only if SeqExpr is the function body
+ // in FunctionNode or the true/false branch body in IfNode
+ for (BindingBlock block : op->blocks) {
+ this->VisitBindingBlock(block);
+ }
+ if (!IsLeafOrTuple(op->body)) {
+ Malformed(Diagnostic::Error(op) << "SeqExpr bodies must be leaf
expressions.");
+ }
+ this->VisitExpr(op->body);
+ CheckStructInfo(op);
+ }
+
+ void VisitBinding_(const VarBindingNode* binding) final {
+ this->VisitExpr(binding->value);
+ this->VisitVarDef(binding->var);
+ }
+
+ void VisitBinding_(const MatchCastNode* binding) final {
+ this->VisitExpr(binding->value);
+ // define the vars
+ WithMode(VisitMode::kMatchVarDef, [&]() {
this->VisitStructInfo(binding->struct_info); });
+
+ this->VisitStructInfo(binding->struct_info);
+ this->VisitVarDef(binding->var);
+ }
+
+ void VisitBindingBlock_(const DataflowBlockNode* block) final {
+ bool old_is_dataflow_ = is_dataflow_;
+ is_dataflow_ = true;
+ for (Binding binding : block->bindings) {
+ this->VisitBinding(binding);
+ }
+ is_dataflow_ = old_is_dataflow_;
+ dataflow_var_set_.clear();
+ }
+
+ void VisitVarDef_(const DataflowVarNode* var) final {
+ if (!is_dataflow_) {
+ Malformed(Diagnostic::Error(var)
+ << "DataflowVar " << var->name_hint() << " is defined outside
DataflowBlock.");
+ }
+ DataflowVar lv = GetRef<DataflowVar>(var);
+ if (dataflow_var_set_.count(lv) == 1) {
+ Malformed(Diagnostic::Error(var)
+ << "DataflowVar " << lv->name_hint() << " is defined more than
once.");
+ }
+ // register DataflowVar
+ dataflow_var_set_.insert(lv);
+ CheckStructInfo(var);
+ }
+
+ void VisitVarDef_(const VarNode* var) final {
+ Var gv = GetRef<Var>(var);
+ if (var_set_.count(gv) == 1) {
+ Malformed(Diagnostic::Error(var)
+ << "Var " << gv->name_hint() << " is defined more than once.");
+ }
+ // register Var
+ var_set_.insert(gv);
+ CheckStructInfo(var);
+ }
+
+ void VisitVarDef(const Var& var) final {
+ if (const DataflowVarNode* lv_node = var.as<DataflowVarNode>()) {
+ VisitVarDef_(lv_node);
+ } else if (const VarNode* gv_node = var.as<VarNode>()) {
+ VisitVarDef_(gv_node);
+ } else {
+ LOG(FATAL) << "TypeError: Invalid type: " << var->GetTypeKey();
+ }
+ }
+
+ void VisitExpr_(const tir::VarNode* op) final {
+ tir::Var var = GetRef<tir::Var>(op);
+ // default mode, check defined.
+ if (symbolic_var_set_.count(var) == 0) {
+ this->Malformed(Diagnostic::Error(var)
+ << "Symbolic Var " << var->name_hint << " is not
defined.");
+ }
+ }
+
+ void VisitStructInfoExprField(const Expr& expr) final {
+ if (mode_ == VisitMode::kMatchVarDef) {
+ // populate symbolic var in first occurrence
+ if (auto* op = expr.as<relax::VarNode>()) {
+ auto var = GetRef<relax::Var>(op);
+ if (var_set_.count(var) == 0) {
+ var_set_.insert(var);
+ }
+ }
+ if (auto* shape = expr.as<relax::ShapeExprNode>()) {
+ for (auto val : shape->values) {
+ this->VisitStructInfoExprField(val);
+ }
+ }
+ } else {
+ relax::ExprVisitor::VisitExpr(expr);
+ }
+ }
+
+ void VisitStructInfoExprField(const PrimExpr& expr) final {
+ if (mode_ == VisitMode::kMatchVarDef) {
+ // populate symbolic var in first occurrence
+ if (auto* op = expr.as<tir::VarNode>()) {
+ auto var = GetRef<tir::Var>(op);
+ if (symbolic_var_set_.count(var) == 0) {
+ symbolic_var_set_.insert(var);
+ }
+ }
+ } else {
+ tir::ExprVisitor::VisitExpr(expr);
+ }
+ }
+
+ void CheckStructInfo(const ExprNode* op) {
+ if (!check_struct_info_) {
+ return;
+ }
+
+ auto* sinfo = op->struct_info_.as<StructInfoNode>();
+ if (sinfo != nullptr) {
+ this->VisitStructInfo(GetRef<StructInfo>(sinfo));
+ } else {
+ Malformed(Diagnostic::Error(op) << "Expr must have struct_info
populated. "
+ << " Expr.type_key=" <<
op->GetTypeKey());
+ }
+ }
+
+ // Run callback with mode.
+ template <typename FType>
+ void WithMode(VisitMode mode, FType callback) {
+ std::swap(mode_, mode);
+ callback();
+ std::swap(mode_, mode);
+ }
+
+ IRModule mod_;
+ const bool check_struct_info_;
+ bool well_formed_ = true;
+ bool is_dataflow_;
+ // Current visit mode.
+ VisitMode mode_ = VisitMode::kDefault;
+ // set of context variables.
+ std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual> var_set_;
+ std::unordered_set<DataflowVar, ObjectPtrHash, ObjectPtrEqual>
dataflow_var_set_;
+ std::unordered_set<tir::Var, ObjectPtrHash, ObjectPtrEqual>
symbolic_var_set_;
+ std::unordered_map<Var, Function, ObjectPtrHash, ObjectPtrEqual>
param_var_func_map_;
+};
+
+// Entry point declared in tvm/relax/analysis.h; all work is done by WellFormedChecker.
+bool WellFormed(IRModule m, bool check_struct_info) {
+  return WellFormedChecker::Check(std::move(m), check_struct_info);
+}
+
+// FFI registration; the Python wrapper relax.analysis.well_formed calls this via _ffi_api.
+TVM_REGISTER_GLOBAL(("relax.analysis.well_formed"))
+    .set_body_typed([](IRModule m, bool check_struct_info) {
+      return WellFormed(m, check_struct_info);
+    });
+
+} // namespace relax
+} // namespace tvm
diff --git a/tests/python/relax/conftest.py b/tests/python/relax/conftest.py
new file mode 100644
index 0000000000..f1b1187066
--- /dev/null
+++ b/tests/python/relax/conftest.py
@@ -0,0 +1,23 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import pytest
+
+import tvm
+from tvm.relax.ir.instrument import WellFormedInstrument
+
+
+tvm.transform.PassContext.current().override_instruments([WellFormedInstrument()])
diff --git a/tests/python/relax/test_analysis_well_formed.py
b/tests/python/relax/test_analysis_well_formed.py
new file mode 100644
index 0000000000..cc0de84d53
--- /dev/null
+++ b/tests/python/relax/test_analysis_well_formed.py
@@ -0,0 +1,438 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import pytest
+import tvm
+import tvm.testing
+from tvm import tir
+from tvm import relax as rx
+from tvm.script import relax as R
+
+# Shared symbolic dims (m, n) and the relax vars x and cond, which build_function
+# always uses as the first two parameters of the function it builds.
+m = tir.Var("m", "int64")
+n = tir.Var("n", "int64")
+x = rx.Var("x", R.Tensor([m, n], "float32"))
+cond = rx.Var("cond", R.Tensor([], "bool"))
+
+
+def build_function(blocks, params=[]):
+ """Returns relax.function with given blocks"""
+ seq_expr = rx.SeqExpr(blocks, blocks[-1].bindings[-1].var)
+ func = rx.Function([x, cond] + params, seq_expr,
R.Tensor("float32")).with_attr(
+ "global_symbol", "foo"
+ )
+ return func
+
+
+def test_var():
+ # Error: Var gv0 is not defined
+ gv0 = rx.Var("gv0", R.Tensor([m, n], "float32"))
+ gv1 = rx.Var("gv1", R.Tensor([m, n], "float32"))
+ call_node = rx.op.add(x, gv0)
+ bindings = [rx.VarBinding(gv1, call_node)]
+ blocks = [rx.BindingBlock(bindings)]
+ func = build_function(blocks)
+ mod = tvm.IRModule({rx.GlobalVar("foo"): func})
+ assert not rx.analysis.well_formed(mod, check_struct_info=False)
+
+ # Error: Var gv0 is defined more than once
+ gv0 = rx.Var("gv0", R.Tensor([m, n], "float32"))
+ call_node = rx.op.add(x, x)
+ call_node2 = rx.op.multiply(x, x)
+ bindings = [rx.VarBinding(gv0, call_node), rx.VarBinding(gv0, call_node2)]
+ blocks = [rx.BindingBlock(bindings)]
+ func = build_function(blocks)
+ mod = tvm.IRModule({rx.GlobalVar("foo"): func})
+ assert not rx.analysis.well_formed(mod, check_struct_info=False)
+
+
+def test_dataflow_var():
+ # Error: DataflowVar lv0 is not defined
+ lv0 = rx.DataflowVar("lv0", R.Tensor([m, n], "float32"))
+ gv0 = rx.Var("gv0", R.Tensor([m, n], "float32"))
+ call_node = rx.op.add(x, lv0)
+ bindings = [rx.VarBinding(gv0, call_node)]
+ blocks = [rx.DataflowBlock(bindings)]
+ func = build_function(blocks)
+ mod = tvm.IRModule({rx.GlobalVar("foo"): func})
+ assert not rx.analysis.well_formed(mod, check_struct_info=False)
+
+ # Error: DataflowVar gv0 is defined more than once
+ lv0 = rx.DataflowVar("lv0", R.Tensor([m, n], "float32"))
+ call_node = rx.op.add(x, x)
+ call_node2 = rx.op.multiply(x, x)
+ bindings = [rx.VarBinding(lv0, call_node), rx.VarBinding(lv0, call_node2)]
+ blocks = [rx.DataflowBlock(bindings)]
+ func = build_function(blocks)
+ mod = tvm.IRModule({rx.GlobalVar("foo"): func})
+ assert not rx.analysis.well_formed(mod, check_struct_info=False)
+
+ # Error: DataflowVar lv0 is defined outside DataflowBlock
+ lv0 = rx.DataflowVar("lv0", R.Tensor([m, n], "float32"))
+ call_node = rx.op.add(x, x)
+ bindings = [rx.VarBinding(lv0, call_node)]
+ blocks = [rx.BindingBlock(bindings)]
+ func = build_function(blocks)
+ mod = tvm.IRModule({rx.GlobalVar("foo"): func})
+ assert not rx.analysis.well_formed(mod, check_struct_info=False)
+
+ # Error: DataflowVar lv0 is used outside DataflowBlock
+ lv0 = rx.DataflowVar("lv0", R.Tensor([m, n], "float32"))
+ gv0 = rx.Var("gv0", R.Tensor([m, n], "float32"))
+ call_node = rx.op.add(lv0, x)
+ bindings = [rx.VarBinding(lv0, call_node)]
+ blocks = [rx.BindingBlock(bindings)]
+ func = build_function(blocks)
+ mod = tvm.IRModule({rx.GlobalVar("foo"): func})
+ assert not rx.analysis.well_formed(mod, check_struct_info=False)
+
+
+def test_param_var():
+ v0 = rx.Var("v0", R.Tensor([m, n], "float32"))
+ v1 = rx.Var("v1", R.Tensor([m, n], "float32"))
+ v2 = rx.Var("v2", R.Tensor([m, n], "float32"))
+ bb = rx.BlockBuilder()
+ with bb.function("func1", [v0, v1]):
+ gv0 = bb.emit(rx.op.add(v0, v1))
+ bb.emit_func_output(gv0)
+ with bb.function("func2", [v0, v2]):
+ gv0 = bb.emit(rx.op.add(v2, v1))
+ bb.emit_func_output(gv0)
+ mod = bb.get()
+ assert not rx.analysis.well_formed(mod, check_struct_info=False)
+
+
+def test_global_var():
+ # Error: GlobalVar GlobalVar0 is not defined
+ gv0 = rx.Var("gv0", R.Tensor([m, n], "float32"))
+ globalvar = rx.GlobalVar("GlobalVar0")
+ call_node = rx.Call(
+ op=tvm.ir.Op.get("relax.call_tir"),
+ args=[globalvar, rx.Tuple([x]), rx.ShapeExpr([m, n])],
+ )
+ bindings = [rx.VarBinding(gv0, call_node)]
+ blocks = [rx.BindingBlock(bindings)]
+ func = build_function(blocks)
+ mod = tvm.IRModule({rx.GlobalVar("foo"): func})
+ assert not rx.analysis.well_formed(mod, check_struct_info=False)
+
+
+def test_symbolic_var():
+ # Error: Symbolic Var new_s is not defined
+ new_s = tir.Var("new_s", "int64")
+ gv0 = rx.Var("gv0", R.Tensor([m, new_s], "int64"))
+ call_node = rx.op.add(x, x)
+ bindings = [rx.VarBinding(gv0, call_node)]
+ blocks = [rx.BindingBlock(bindings)]
+ func = build_function(blocks)
+ mod = tvm.IRModule({rx.GlobalVar("foo"): func})
+ assert not rx.analysis.well_formed(mod, check_struct_info=False)
+
+
+def test_symbolic_var_invalid_type():
+ with pytest.raises(
+ tvm.TVMError, match="the value in ShapeStructInfo can only have dtype
of int64"
+ ):
+ dim = tir.Var("dim", "float32")
+ y = rx.Var("y", R.Tensor([dim], "float32"))
+ gv0 = rx.Var("gv0", R.Tensor([dim], "float32"))
+ call_node = rx.op.add(y, y)
+ bindings = [rx.VarBinding(gv0, call_node)]
+ blocks = [rx.BindingBlock(bindings)]
+ func = build_function(blocks, [y])
+ mod = tvm.IRModule({rx.GlobalVar("foo"): func})
+ assert not rx.analysis.well_formed(mod, check_struct_info=False)
+
+
+def test_seq_expr():
+ # Error: SeqExpr in VarBinding
+ gv0 = rx.Var("gv0", R.Tensor([m, n], "float32"))
+ # build a SeqExpr
+ gv1 = rx.Var("gv1", R.Tensor([m, n], "float32"))
+ call_node = rx.op.add(x, gv0)
+ _bindings = [rx.VarBinding(gv1, call_node)]
+ _blocks = [rx.BindingBlock(_bindings)]
+ _seq_expr = rx.SeqExpr(_blocks, gv1)
+ # build a Binding with the SeqExpr as value
+ bindings = [rx.VarBinding(gv0, _seq_expr)]
+ blocks = [rx.BindingBlock(bindings)]
+ func = build_function(blocks)
+ mod = tvm.IRModule({rx.GlobalVar("foo"): func})
+ assert not rx.analysis.well_formed(mod, check_struct_info=False)
+
+
+def test_if():
+ # Error: Var defined in true/false branch is invisible in the outer scope
+ # except the return Var, i.e the var in the last stmt
+ # v_in_if is invisible in the outer scope
+ v_in_if = rx.Var("v_in_if", R.Tensor([m, n], "float32"))
+ # gv0 is visible in the outer scope
+ gv0 = rx.Var("gv0", R.Tensor([m, n], "float32"))
+ # build true branch
+ true_bindings = [
+ rx.VarBinding(v_in_if, rx.op.add(x, x)),
+ rx.VarBinding(gv0, rx.op.multiply(x, x)),
+ ]
+ true_blocks = [rx.BindingBlock(true_bindings)]
+ true_seq_expr = rx.SeqExpr(true_blocks, true_blocks[-1].bindings[-1].var)
+ # build false branch
+ false_bindings = [
+ rx.VarBinding(v_in_if, rx.op.multiply(x, x)),
+ rx.VarBinding(gv0, rx.op.add(x, x)),
+ ]
+ false_blocks = [rx.BindingBlock(false_bindings)]
+ false_seq_expr = rx.SeqExpr(false_blocks,
false_blocks[-1].bindings[-1].var)
+ # build If node
+ if_node = rx.If(cond=cond, true_branch=true_seq_expr,
false_branch=false_seq_expr)
+ gv1 = rx.Var("gv1", R.Tensor([m, n], "float32"))
+ # try to call v_in_if defined in the true/false branch
+ bindings = [rx.VarBinding(gv0, if_node), rx.VarBinding(gv1, v_in_if)]
+ blocks = [rx.BindingBlock(bindings)]
+ func = build_function(blocks)
+ mod = tvm.IRModule({rx.GlobalVar("foo"): func})
+ assert not rx.analysis.well_formed(mod, check_struct_info=True)
+
+
+def test_if_non_seq_body():
+ # Error: If node has a body that is not a seq node
+ if_node = rx.If(cond=cond, true_branch=x, false_branch=x)
+ blocks = [
+ rx.BindingBlock(
+ [
+ rx.VarBinding(
+ rx.Var("gv1", R.Tensor([m, n], "float32")),
+ if_node,
+ )
+ ]
+ )
+ ]
+ func = build_function(blocks)
+ mod = tvm.IRModule.from_expr(func)
+ assert not rx.analysis.well_formed(mod, check_struct_info=False)
+
+ # on the other hand, if they're wrapped in a seq node, it's fine
+ seq = rx.SeqExpr([], x)
+ new_if_node = rx.If(cond=cond, true_branch=seq, false_branch=seq)
+ new_blocks = [
+ rx.BindingBlock(
+ [
+ rx.VarBinding(
+ rx.Var("gv1", R.Tensor([m, n], "float32")),
+ new_if_node,
+ )
+ ]
+ )
+ ]
+ new_func = build_function(new_blocks)
+ new_mod = tvm.IRModule.from_expr(new_func)
+ # apply normalization to fill in checked_type_
+ normalized = rx.transform.Normalize()(new_mod)
+ assert rx.analysis.well_formed(normalized, check_struct_info=True)
+
+
def test_if_complex_condition():
    """An If condition has to be a leaf expression.

    Using a TupleGetItem directly as the condition is rejected. Binding it to a
    var first and branching on that var is accepted after normalization.
    """

    def trivial_branch():
        # A fresh empty SeqExpr that simply yields `x`.
        return rx.SeqExpr([], x)

    def gv1_binding(if_expr):
        return rx.VarBinding(rx.Var("gv1", R.Tensor([m, n], "float32")), if_expr)

    def to_module(bindings):
        return tvm.IRModule.from_expr(build_function([rx.BindingBlock(bindings)]))

    # cond_idx is TupleGetItem(Tuple([cond]), 0), which is not a leaf expression.
    cond_idx = rx.TupleGetItem(rx.Tuple([cond]), 0)

    # Error: If condition must be a leaf expression
    inline_if = rx.If(cond_idx, trivial_branch(), trivial_branch())
    bad_mod = to_module([gv1_binding(inline_if)])
    assert not rx.analysis.well_formed(bad_mod, check_struct_info=False)

    # Binding the condition to its own var first turns it into a leaf.
    cond_var = rx.Var("q", R.Tensor([], "bool"))
    var_if = rx.If(cond_var, trivial_branch(), trivial_branch())
    good_mod = to_module([rx.VarBinding(cond_var, cond_idx), gv1_binding(var_if)])
    # apply normalization to fill in checked_type_
    normalized = rx.transform.Normalize()(good_mod)
    assert rx.analysis.well_formed(normalized, check_struct_info=True)
+
+
def test_tuple_get_item_nested():
    """The tuple operand of a TupleGetItem must itself be a leaf expression.

    Indexing straight into the result of another TupleGetItem is malformed.
    Going through an intermediate binding is well-formed after normalization.
    """

    def scalar_i32():
        return rx.TensorStructInfo([], "int32")

    nested_tup = rx.Var("t", rx.TupleStructInfo([rx.TupleStructInfo([scalar_i32()])]))
    ret_var = rx.Var("r", R.Tensor([], "int32"))

    def make_module(bindings, symbol):
        # One-parameter function over `nested_tup` that returns `ret_var`,
        # exported under `symbol`.
        func = rx.Function(
            [nested_tup],
            rx.SeqExpr([rx.BindingBlock(bindings)], ret_var),
            ret_struct_info=R.Tensor(ndim=0, dtype="int32"),
        ).with_attr("global_symbol", symbol)
        return tvm.IRModule.from_expr(func)

    # Error: The tuple value in tuple get item must be a leaf expression
    double_idx = rx.TupleGetItem(rx.TupleGetItem(nested_tup, 0), 0)
    mod = make_module([rx.VarBinding(ret_var, double_idx)], "f")
    assert not rx.analysis.well_formed(mod, check_struct_info=False)

    # okay with an intermediate binding
    idx_var = rx.Var("v", rx.TupleStructInfo([scalar_i32()]))
    mod = make_module(
        [
            rx.VarBinding(idx_var, rx.TupleGetItem(nested_tup, 0)),
            rx.VarBinding(ret_var, rx.TupleGetItem(idx_var, 0)),
        ],
        "new_f",
    )
    # normalize in order to fill in checked type
    assert rx.analysis.well_formed(rx.transform.Normalize()(mod), check_struct_info=True)
+
+
def test_complex_seq_body():
    """The body of a SeqExpr must be a leaf expression.

    Returning ``add(x, y)`` directly from the SeqExpr is malformed. Binding the
    sum to ``z`` and returning ``z`` is well-formed after normalization.
    """
    lhs = rx.Var("x", R.Tensor([], "int32"))
    rhs = rx.Var("y", R.Tensor([], "int32"))

    def foo_module(seq_body):
        # Two-parameter function (lhs, rhs) exported as "foo".
        func = rx.Function([lhs, rhs], seq_body, R.Tensor(ndim=0, dtype="int32"))
        return tvm.IRModule.from_expr(func.with_attr("global_symbol", "foo"))

    # Error: seq expr with a body that is not a leaf expression is not permitted
    bad_mod = foo_module(rx.SeqExpr([], rx.op.add(lhs, rhs)))
    assert not rx.analysis.well_formed(bad_mod, check_struct_info=False)

    # but if the result is bound, then it's okay
    out = rx.Var("z", R.Tensor([], "int32"))
    sum_binding = rx.VarBinding(var=out, value=rx.op.add(lhs, rhs))
    good_mod = foo_module(rx.SeqExpr([rx.BindingBlock([sum_binding])], out))
    # normalize in order to fill in checked type
    normalized = rx.transform.Normalize()(good_mod)
    assert rx.analysis.well_formed(normalized, check_struct_info=True)
+
+
def test_ANF():
    """Bindings must be in A-normal form: call nodes may not be nested.

    A call nested inside another call and a call nested inside a tuple are both
    rejected. A positive control, with the inner call bound to its own var,
    checks that the rejections come from the nesting itself and not from the
    surrounding test scaffolding.
    """

    def foo_module(bindings):
        # Wrap `bindings` in a single binding block and register the resulting
        # function under GlobalVar "foo".
        func = build_function([rx.BindingBlock(bindings)])
        return tvm.IRModule({rx.GlobalVar("foo"): func})

    # Error: Nested Call
    gv0 = rx.Var("gv0", R.Tensor([m, n], "float32"))
    mod = foo_module([rx.VarBinding(gv0, rx.op.add(x, rx.op.add(x, x)))])
    assert not rx.analysis.well_formed(mod, check_struct_info=False)

    # Error: Call Node in Tuple
    gv0 = rx.Var("gv0", R.Tensor([m, n], "float32"))
    mod = foo_module([rx.VarBinding(gv0, rx.Tuple((x, rx.op.add(x, x))))])
    assert not rx.analysis.well_formed(mod, check_struct_info=False)

    # OK: the inner call is bound to its own var first, so every call
    # argument is a leaf.
    lv0 = rx.Var("lv0", R.Tensor([m, n], "float32"))
    gv0 = rx.Var("gv0", R.Tensor([m, n], "float32"))
    mod = foo_module(
        [rx.VarBinding(lv0, rx.op.add(x, x)), rx.VarBinding(gv0, rx.op.add(x, lv0))]
    )
    assert rx.analysis.well_formed(mod, check_struct_info=False)
+
+
def test_global_var_vs_gsymbol():
    """A function's "global_symbol" attribute must equal its GlobalVar's name.

    The same function is checked twice: once with a mismatching symbol, which
    must be rejected, and once with a matching symbol, which must be accepted.
    The positive case shows the rejection comes from the name mismatch alone.
    """
    gv0 = rx.Var("gv0", R.Tensor([m, n], "float32"))
    bindings = [rx.VarBinding(gv0, x)]
    blocks = [rx.DataflowBlock(bindings)]
    base_func = rx.Function(
        [x],
        rx.SeqExpr(blocks, gv0),
        R.Tensor(ndim=2, dtype="float32"),
    )

    # Error: gsymbol "main1" does not match the global var name "main"
    bad_func = base_func.with_attr("global_symbol", "main1")
    mismatched = tvm.IRModule({rx.GlobalVar("main"): bad_func})
    assert not rx.analysis.well_formed(mismatched, check_struct_info=False)

    # OK: the identical function with a matching gsymbol is accepted
    good_func = base_func.with_attr("global_symbol", "main")
    matched = tvm.IRModule({rx.GlobalVar("main"): good_func})
    assert rx.analysis.well_formed(matched, check_struct_info=False)
+
+
def test_nested_dataflow():
    """A function bound inside a dataflow block may contain its own dataflow block.

    The outer binding for "x2" nests a call, ``add(x1, f())``. The test relies
    on Normalize to bring that into A-normal form before the check.
    """
    i32_scalar = rx.TensorStructInfo(shape=[], dtype="int32")

    # Inner function: its own dataflow block binds x0 = 2 and exports y = x0.
    inner_tmp = rx.DataflowVar("x0", i32_scalar)
    inner_out = rx.Var("y", i32_scalar)
    inner_block = rx.DataflowBlock(
        [rx.VarBinding(inner_tmp, rx.const(2, "int32")), rx.VarBinding(inner_out, inner_tmp)]
    )
    inner_func = rx.Function([], rx.SeqExpr([inner_block], inner_out), i32_scalar)

    # Outer function: binds the inner function to a dataflow var and calls it.
    result = rx.Var("gv0", i32_scalar)
    callee = rx.DataflowVar("f", rx.FuncStructInfo([], i32_scalar))
    one = rx.DataflowVar("x1", i32_scalar)
    total = rx.DataflowVar("x2", i32_scalar)
    outer_block = rx.DataflowBlock(
        [
            rx.VarBinding(one, rx.const(1, "int32")),
            rx.VarBinding(callee, inner_func),
            rx.VarBinding(total, rx.op.add(one, rx.Call(callee, []))),
            rx.VarBinding(result, total),
        ]
    )
    outer_func = rx.Function([], rx.SeqExpr([outer_block], result), i32_scalar)

    normalized = rx.transform.Normalize()(tvm.IRModule.from_expr(outer_func))
    assert rx.analysis.well_formed(normalized)
+
+
def test_sinfo_args_tir_var_used_before_define_call_packed():
    """call_packed's sinfo_args must not reference undefined symbolic vars."""
    # m1 and n1 are fresh symbolic vars that nothing in the function defines.
    undefined_m, undefined_n = tir.Var("m1", "int64"), tir.Var("n1", "int64")
    packed_call = R.call_packed(
        "my_func", x, sinfo_args=R.Tensor((undefined_m, undefined_n), "float32")
    )
    block = rx.BindingBlock([rx.VarBinding(rx.Var("gv"), packed_call)])
    mod = rx.transform.Normalize()(tvm.IRModule.from_expr(build_function([block])))
    # Error: Symbolic Var m1, n1 are not defined
    assert not rx.analysis.well_formed(mod, check_struct_info=False)
+
+
def test_sinfo_args_tir_var_used_before_define_call_tir():
    """call_tir's out_sinfo must not reference undefined symbolic vars."""
    # m1 and n1 are fresh symbolic vars that nothing in the function defines.
    undefined_m, undefined_n = tir.Var("m1", "int64"), tir.Var("n1", "int64")
    tir_call = R.call_tir(
        "my_func", x, out_sinfo=R.Tensor((undefined_m, undefined_n), "float32")
    )
    block = rx.BindingBlock([rx.VarBinding(rx.Var("gv"), tir_call)])
    mod = rx.transform.Normalize()(tvm.IRModule.from_expr(build_function([block])))
    # Error: Symbolic Var m1, n1 are not defined
    assert not rx.analysis.well_formed(mod, check_struct_info=False)
+
+
if __name__ == "__main__":
    # Only runs when this file is executed directly rather than imported or
    # collected as a test module; hands control to tvm.testing.main().
    tvm.testing.main()