This is an automated email from the ASF dual-hosted git repository. tqchen pushed a commit to branch unity-staging in repository https://gitbox.apache.org/repos/asf/tvm.git
commit e97060fd231610f8c5a80663e9d810f69d4e1328 Author: Lesheng Jin <[email protected]> AuthorDate: Fri Feb 17 18:11:08 2023 -0800 [Unity][Pass] Wellformed Analysis (#14032) This PR implements relax wellformed analysis, which checks if the IRModule is well-formed. (tests and examples are added). Co-Authored-by: Ruihang Lai <[email protected]> Co-Authored-by: Siyuan Feng <[email protected]> Co-Authored-by: Tianqi Chen <[email protected]> Co-authored-by: Steven S. Lyubomirsky <[email protected]> Co-authored-by: Yong Wu <[email protected]> Co-Authored-by: Yuchen Jin <[email protected]> Co-Authored-by: Yixin Dong <[email protected]> Co-Authored-by: Chaofan Lin <[email protected]> Co-Authored-by: Prakalp Srivastava <[email protected]> Co-Authored-by: Junru Shao <[email protected]> --- include/tvm/relax/analysis.h | 13 + python/tvm/relax/analysis/analysis.py | 27 ++ python/tvm/relax/ir/instrument.py | 37 ++ src/relax/analysis/well_formed.cc | 465 ++++++++++++++++++++++++ tests/python/relax/conftest.py | 23 ++ tests/python/relax/test_analysis_well_formed.py | 438 ++++++++++++++++++++++ 6 files changed, 1003 insertions(+) diff --git a/include/tvm/relax/analysis.h b/include/tvm/relax/analysis.h index ff576d4ebb..f9896efdf2 100644 --- a/include/tvm/relax/analysis.h +++ b/include/tvm/relax/analysis.h @@ -343,6 +343,19 @@ TVM_DLL relay::OpPatternKind AnalyzeOpPatternKind(const tir::PrimFunc& func); */ TVM_DLL bool HasReshapePattern(const tir::PrimFunc& func); +/*! + * \brief Check if the IRModule is well formed. + * + * \param m the IRModule to check. + * \param check_struct_info A boolean flag indicating if the property "every Expr + * must have defined structure info" will be checked. + * \return true if the IRModule is well formed, false if not. + * \note By default the structure info is always checked. It is only in test cases + * where `check_struct_info` might be false, so that other well-formed requirements + * will be well tested and will not be blocked by not having structure info. + */ +TVM_DLL bool WellFormed(IRModule m, bool check_struct_info = true); + } // namespace relax } // namespace tvm diff --git a/python/tvm/relax/analysis/analysis.py b/python/tvm/relax/analysis/analysis.py index 27416c3a79..7107883478 100644 --- a/python/tvm/relax/analysis/analysis.py +++ b/python/tvm/relax/analysis/analysis.py @@ -25,6 +25,7 @@ from typing import Dict from enum import IntEnum from tvm import tir +from tvm import IRModule from tvm.relax.ty import Type from tvm.relax.struct_info import StructInfo, FuncStructInfo from tvm.relax.expr import Var, Expr, Call @@ -207,3 +208,29 @@ def has_reshape_pattern(func: tir.PrimFunc) -> bool: of this function. """ return _ffi_api.has_reshape_pattern(func) # type: ignore + + +def well_formed(mod: IRModule, check_struct_info: bool = True) -> bool: + """Check if the IRModule is well formed. + + Parameters + ---------- + mod : tvm.IRModule + The input IRModule. + + check_struct_info : bool + A boolean flag indicating if the property "every Expr must + have defined structure info" will be checked. + + Returns + ------- + ret: bool + True if the IRModule is well formed, False if not. + + Note + ---- + By default the structure info is always checked. It is only in test cases + where `check_struct_info` might be false, so that other well-formed requirements + will be well tested and will not be blocked by not having structure info. + """ + return _ffi_api.well_formed(mod, check_struct_info) # type: ignore diff --git a/python/tvm/relax/ir/instrument.py b/python/tvm/relax/ir/instrument.py new file mode 100644 index 0000000000..fc51a796a7 --- /dev/null +++ b/python/tvm/relax/ir/instrument.py @@ -0,0 +1,37 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Common relax pass instrumentation across IR variants.""" +import tvm +from tvm import relax + + [email protected]_instrument +class WellFormedInstrument: + """An instrument that checks the input/output IRModule of the Pass + is well formed. It will skip specific passes, like Normalize. + """ + + def __init__(self): + self.skip_pass_name = ["Normalize", "ResolveGlobals"] + + def run_before_pass(self, mod, pass_info): + if pass_info.name not in self.skip_pass_name: + assert relax.analysis.well_formed(mod) + + def run_after_pass(self, mod, pass_info): + if pass_info.name not in self.skip_pass_name: + assert relax.analysis.well_formed(mod) diff --git a/src/relax/analysis/well_formed.cc b/src/relax/analysis/well_formed.cc new file mode 100644 index 0000000000..e7ec237fd5 --- /dev/null +++ b/src/relax/analysis/well_formed.cc @@ -0,0 +1,465 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file relax/analysis/well_formed.cc + * \brief Check if the IRModule is well-formed. + * + * This pass is supposed to be applied to normalized Relax AST. + * If it's malformed, messages will be logged as Warning. + * This pass will check: + * 1. Each Expr should have `struct_info_` field already populated, when + * `check_struct_info` is true. + * 2. GlobalVars are defined before use. + * 3. When a Function has a corresponding GlobalVar and a `global_symbol` + * attribute, the name of the GlobalVar must equal the value of the + * `global_symbol` attribute value. + * 4. Any variable cannot used as different function parameters in the same IRModule + * 5. Vars are defined before use. + * 6. Vars are defined exactly once. + * 7. Symbolic Vars are defined before use. + * 8. DataflowVars cannot be defined inside BindingBlock. + * 9. Vars defined in IfNode, except the return Var, are invisible + * out of the If body.(May change for new AST designs) + * 10. SeqExpr only serves as function body, or in the true and + * false branches in IfNode. + * 11. The IR is in ANF: + * (a) Expressions cannot contain nested complex expressions. + * Here are the expressions that may be nested inside other expressions: + * Var, DataflowVar, GlobalVar, Constant, ShapeExpr, + * Op, Tuple (we call these "leaf" expressions). + * (b) The right-hand side of a binding may contain a non-leaf expression + * (where all expressions nested in it are leaf expressions), + * other than SeqExprs (see rule 6) + * (c) Exceptions: The body of a Function node and the true branch + * and false branch of If nodes *must* be SeqExprs. + * (d) Places where non-leaf expressions cannot appear: + * * The tuple_value field of TupleGetItem nodes + * * The cond field of If nodes + * * The op or args fields of Call nodes + * * Inside the fields of Tuple nodes + * 12. Expr always has checked_type_ (with the exception of Op). + */ +#include <tvm/relax/analysis.h> +#include <tvm/relax/expr.h> +#include <tvm/relax/expr_functor.h> +#include <tvm/relax/struct_info_functor.h> +#include <tvm/relax/utils.h> +#include <tvm/tir/expr_functor.h> + +#include <unordered_set> + +namespace tvm { +namespace relax { + +// TODO(relax-team): Consider further refactor using +// Scope Frame to store manage the var context. +// +/*! \brief Helper to implement well formed check.*/ +class WellFormedChecker : public relax::ExprVisitor, + public relax::StructInfoVisitor, + public tir::ExprVisitor { + public: + static bool Check(IRModule mod, bool check_struct_info) { + WellFormedChecker well_formed_checker = WellFormedChecker(mod, check_struct_info); + + for (const auto& it : mod->functions) { + // visit relax.Function + if (auto* n = it.second.as<FunctionNode>()) { + Function func = GetRef<Function>(n); + well_formed_checker.CheckGlobalVarAndGsymbolConsistency(it.first, func); + well_formed_checker.VisitExpr(func); + } + } + return well_formed_checker.well_formed_; + } + + private: + explicit WellFormedChecker(IRModule mod, bool check_struct_info) + : mod_(std::move(mod)), check_struct_info_(check_struct_info) {} + + using relax::ExprVisitor::VisitExpr_; + using tir::ExprVisitor::VisitExpr; + using tir::ExprVisitor::VisitExpr_; + + // Possible mode of visitor + enum class VisitMode { + /*! + * \brief Check all vars are well-defined + */ + kDefault, + /*! + * \brief Match define the vars on first occurance. + * Do not check the well-defined property of composite expr. + */ + kMatchVarDef + }; + + void Malformed(Diagnostic diag) { + well_formed_ = false; + LOG(WARNING) << "This IR is not well formed: " << diag->message; + } + + void CheckGlobalVarAndGsymbolConsistency(GlobalVar var, Function func) { + // check name in global var and gsymbol + Optional<String> gsymbol = func->GetAttr<String>(tvm::attr::kGlobalSymbol); + if (gsymbol.defined() && gsymbol != var->name_hint) { + Malformed(Diagnostic::Error(func->span) + << "Name in GlobalVar is not equal to name in gsymbol: " << var->name_hint + << " != " << gsymbol.value()); + } + } + + void VisitExpr(const Expr& expr) final { + if (!expr.as<OpNode>() && !expr->checked_type_.defined()) { + Malformed(Diagnostic::Error(expr) << "The checked_type_ of Expr " << expr << " is nullptr."); + } + relax::ExprVisitor::VisitExpr(expr); + } + + void VisitExpr_(const GlobalVarNode* op) final { + GlobalVar var = GetRef<GlobalVar>(op); + if (!(mod_->ContainGlobalVar(var->name_hint) && + mod_->GetGlobalVar(var->name_hint).same_as(var))) { + Malformed(Diagnostic::Error(var) << "GlobalVar " << op->name_hint << " is not defined."); + } + + if (op->checked_type_.defined()) { + if ((!op->checked_type_->IsInstance<FuncTypeNode>()) && + (!op->checked_type_->IsInstance<PackedFuncTypeNode>())) { + Malformed(Diagnostic::Error(var) << "The checked_type_ of GlobalVar " << op->name_hint + << " must be either FuncType or PackedFuncType."); + } + } + + CheckStructInfo(op); + } + + void VisitExpr_(const TupleNode* op) final { + for (size_t i = 0; i < op->fields.size(); i++) { + Expr expr = op->fields[i]; + if (IsLeafOrTuple(expr)) { + this->VisitExpr(expr); + } else { + Malformed(Diagnostic::Error(expr) + << "Tuple is not in ANF form, field " << i << " gets " << expr->GetTypeKey()); + } + } + + CheckStructInfo(op); + } + + void VisitExpr_(const TupleGetItemNode* op) final { + if (IsLeafOrTuple(op->tuple)) { + this->VisitExpr(op->tuple); + } else { + Malformed(Diagnostic::Error(op) + << "The tuple value in a TupleGetItem node must be a leaf expression."); + } + CheckStructInfo(op); + } + + void VisitExpr_(const VarNode* op) final { + Var var = GetRef<Var>(op); + if (var_set_.count(var) == 0) { + Malformed(Diagnostic::Error(var) << "Var " << op->name_hint() << " is not defined."); + } + CheckStructInfo(op); + } + + void VisitExpr_(const DataflowVarNode* op) final { + DataflowVar var = GetRef<DataflowVar>(op); + if (!is_dataflow_) { + Malformed(Diagnostic::Error(var) + << "DataflowVar " << op->name_hint() << " is used outside DataflowBlock."); + } + if (dataflow_var_set_.count(var) == 0) { + Malformed(Diagnostic::Error(var) << "DataflowVar " << op->name_hint() << " is not defined."); + } + CheckStructInfo(op); + } + + void VisitExpr_(const FunctionNode* op) final { + // save the var_set_ for local function + auto prev_var_set = var_set_; + auto prev_dataflow_var_set = dataflow_var_set_; + auto prev_symbolic_var_set = symbolic_var_set_; + bool old_dataflow_state = is_dataflow_; + // symbolic var is not captured across function boundaries + symbolic_var_set_.clear(); + is_dataflow_ = false; + + // first populate defs in params + WithMode(VisitMode::kMatchVarDef, [&]() { + ICHECK(mode_ == VisitMode::kMatchVarDef); + for (Var param : op->params) { + relax::StructInfoVisitor::VisitStructInfo(GetStructInfo(param)); + } + }); + + // check all expr are well defined. + for (Var param : op->params) { + this->VisitVarDef(param); + + if (param_var_func_map_.count(param) == 1) { + // TODO(relax-team): Complete this error info after we integrate printer + Malformed(Diagnostic::Error(param->span) + << "Relax variable " << param->name_hint() + << " is repeatedly used as parameters in function."); + } + param_var_func_map_.insert({param, GetRef<Function>(op)}); + } + + if (auto seq = op->body.as<SeqExprNode>()) { + this->VisitSeqExpr(seq); + } else { + Malformed(Diagnostic::Error(op) << "Function bodies must be sequence expressions"); + } + + is_dataflow_ = old_dataflow_state; + dataflow_var_set_ = prev_dataflow_var_set; + var_set_ = prev_var_set; + symbolic_var_set_ = prev_symbolic_var_set; + } + + void VisitExpr_(const CallNode* op) final { + if (IsLeafOrTuple(op->op)) { + this->VisitExpr(op->op); + } else { + Malformed(Diagnostic::Error(op) << "The called expression must be a leaf expression"); + } + for (size_t i = 0; i < op->args.size(); i++) { + Expr arg = op->args[i]; + if (IsLeafOrTuple(arg)) { + this->VisitExpr(arg); + } else { + Malformed(Diagnostic::Error(arg->span) + << "Call is not in ANF form, arg " << i << " gets " << arg->GetTypeKey()); + } + } + + for (const StructInfo& sinfo_arg : op->sinfo_args) { + this->VisitStructInfo(sinfo_arg); + } + + CheckStructInfo(op); + } + + void VisitExpr_(const IfNode* op) final { + if (IsLeafOrTuple(op->cond)) { + this->VisitExpr(op->cond); + } else { + Malformed(Diagnostic::Error(op) << "The condition for an if node must be a leaf expression."); + } + auto true_seq = op->true_branch.as<SeqExprNode>(); + auto false_seq = op->false_branch.as<SeqExprNode>(); + if (true_seq && false_seq) { + std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual> previous_var_set = var_set_; + std::unordered_set<tir::Var, ObjectPtrHash, ObjectPtrEqual> previous_symbolic_var_set = + symbolic_var_set_; + this->VisitSeqExpr(true_seq); + var_set_ = previous_var_set; + symbolic_var_set_ = previous_symbolic_var_set; + this->VisitSeqExpr(false_seq); + var_set_ = previous_var_set; + symbolic_var_set_ = previous_symbolic_var_set; + } else { + Malformed(Diagnostic::Error(op) << "If node branches must be seq exprs"); + } + CheckStructInfo(op); + } + + void VisitExpr_(const ShapeExprNode* op) final { + for (PrimExpr expr : op->values) { + // check if the symbolic vars in the expr are defined, e.g, 2 * m + tir::ExprVisitor::VisitExpr(expr); + if (!expr.dtype().is_int()) { + Malformed(Diagnostic::Error(expr) + << "Shape expressions must be of integer type, but got " << expr.dtype()); + } + } + CheckStructInfo(op); + } + + void VisitExpr_(const SeqExprNode* op) final { + Malformed(Diagnostic::Error(op) << "SeqExpr only serves as the function body in FunctionNode, " + "or the true/false branch body in IfNode."); + } + + void VisitSeqExpr(const SeqExprNode* op) { + // a special call only if SeqExpr is the function body + // in FunctionNode or the true/false branch body in IfNode + for (BindingBlock block : op->blocks) { + this->VisitBindingBlock(block); + } + if (!IsLeafOrTuple(op->body)) { + Malformed(Diagnostic::Error(op) << "SeqExpr bodies must be leaf expressions."); + } + this->VisitExpr(op->body); + CheckStructInfo(op); + } + + void VisitBinding_(const VarBindingNode* binding) final { + this->VisitExpr(binding->value); + this->VisitVarDef(binding->var); + } + + void VisitBinding_(const MatchCastNode* binding) final { + this->VisitExpr(binding->value); + // define the vars + WithMode(VisitMode::kMatchVarDef, [&]() { this->VisitStructInfo(binding->struct_info); }); + + this->VisitStructInfo(binding->struct_info); + this->VisitVarDef(binding->var); + } + + void VisitBindingBlock_(const DataflowBlockNode* block) final { + bool old_is_dataflow_ = is_dataflow_; + is_dataflow_ = true; + for (Binding binding : block->bindings) { + this->VisitBinding(binding); + } + is_dataflow_ = old_is_dataflow_; + dataflow_var_set_.clear(); + } + + void VisitVarDef_(const DataflowVarNode* var) final { + if (!is_dataflow_) { + Malformed(Diagnostic::Error(var) + << "DataflowVar " << var->name_hint() << " is defined outside DataflowBlock."); + } + DataflowVar lv = GetRef<DataflowVar>(var); + if (dataflow_var_set_.count(lv) == 1) { + Malformed(Diagnostic::Error(var) + << "DataflowVar " << lv->name_hint() << " is defined more than once."); + } + // register DataflowVar + dataflow_var_set_.insert(lv); + CheckStructInfo(var); + } + + void VisitVarDef_(const VarNode* var) final { + Var gv = GetRef<Var>(var); + if (var_set_.count(gv) == 1) { + Malformed(Diagnostic::Error(var) + << "Var " << gv->name_hint() << " is defined more than once."); + } + // register Var + var_set_.insert(gv); + CheckStructInfo(var); + } + + void VisitVarDef(const Var& var) final { + if (const DataflowVarNode* lv_node = var.as<DataflowVarNode>()) { + VisitVarDef_(lv_node); + } else if (const VarNode* gv_node = var.as<VarNode>()) { + VisitVarDef_(gv_node); + } else { + LOG(FATAL) << "TypeError: Invalid type: " << var->GetTypeKey(); + } + } + + void VisitExpr_(const tir::VarNode* op) final { + tir::Var var = GetRef<tir::Var>(op); + // default mode, check defined. + if (symbolic_var_set_.count(var) == 0) { + this->Malformed(Diagnostic::Error(var) + << "Symbolic Var " << var->name_hint << " is not defined."); + } + } + + void VisitStructInfoExprField(const Expr& expr) final { + if (mode_ == VisitMode::kMatchVarDef) { + // populate symbolic var in first occurrence + if (auto* op = expr.as<relax::VarNode>()) { + auto var = GetRef<relax::Var>(op); + if (var_set_.count(var) == 0) { + var_set_.insert(var); + } + } + if (auto* shape = expr.as<relax::ShapeExprNode>()) { + for (auto val : shape->values) { + this->VisitStructInfoExprField(val); + } + } + } else { + relax::ExprVisitor::VisitExpr(expr); + } + } + + void VisitStructInfoExprField(const PrimExpr& expr) final { + if (mode_ == VisitMode::kMatchVarDef) { + // populate symbolic var in first occurrence + if (auto* op = expr.as<tir::VarNode>()) { + auto var = GetRef<tir::Var>(op); + if (symbolic_var_set_.count(var) == 0) { + symbolic_var_set_.insert(var); + } + } + } else { + tir::ExprVisitor::VisitExpr(expr); + } + } + + void CheckStructInfo(const ExprNode* op) { + if (!check_struct_info_) { + return; + } + + auto* sinfo = op->struct_info_.as<StructInfoNode>(); + if (sinfo != nullptr) { + this->VisitStructInfo(GetRef<StructInfo>(sinfo)); + } else { + Malformed(Diagnostic::Error(op) << "Expr must have struct_info populated. " + << " Expr.type_key=" << op->GetTypeKey()); + } + } + + // Run callback with mode. + template <typename FType> + void WithMode(VisitMode mode, FType callback) { + std::swap(mode_, mode); + callback(); + std::swap(mode_, mode); + } + + IRModule mod_; + const bool check_struct_info_; + bool well_formed_ = true; + bool is_dataflow_; + // Current visit mode. + VisitMode mode_ = VisitMode::kDefault; + // set of context variables. + std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual> var_set_; + std::unordered_set<DataflowVar, ObjectPtrHash, ObjectPtrEqual> dataflow_var_set_; + std::unordered_set<tir::Var, ObjectPtrHash, ObjectPtrEqual> symbolic_var_set_; + std::unordered_map<Var, Function, ObjectPtrHash, ObjectPtrEqual> param_var_func_map_; +}; + +bool WellFormed(IRModule m, bool check_struct_info) { + return WellFormedChecker::Check(std::move(m), check_struct_info); +} + +TVM_REGISTER_GLOBAL(("relax.analysis.well_formed")) + .set_body_typed([](IRModule m, bool check_struct_info) { + return WellFormed(m, check_struct_info); + }); + +} // namespace relax +} // namespace tvm diff --git a/tests/python/relax/conftest.py b/tests/python/relax/conftest.py new file mode 100644 index 0000000000..f1b1187066 --- /dev/null +++ b/tests/python/relax/conftest.py @@ -0,0 +1,23 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations + +import pytest + +import tvm +from tvm.relax.ir.instrument import WellFormedInstrument + + +tvm.transform.PassContext.current().override_instruments([WellFormedInstrument()]) diff --git a/tests/python/relax/test_analysis_well_formed.py b/tests/python/relax/test_analysis_well_formed.py new file mode 100644 index 0000000000..cc0de84d53 --- /dev/null +++ b/tests/python/relax/test_analysis_well_formed.py @@ -0,0 +1,438 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import pytest +import tvm +import tvm.testing +from tvm import tir +from tvm import relax as rx +from tvm.script import relax as R + +m = tir.Var("m", "int64") +n = tir.Var("n", "int64") +x = rx.Var("x", R.Tensor([m, n], "float32")) +cond = rx.Var("cond", R.Tensor([], "bool")) + + +def build_function(blocks, params=[]): + """Returns relax.function with given blocks""" + seq_expr = rx.SeqExpr(blocks, blocks[-1].bindings[-1].var) + func = rx.Function([x, cond] + params, seq_expr, R.Tensor("float32")).with_attr( + "global_symbol", "foo" + ) + return func + + +def test_var(): + # Error: Var gv0 is not defined + gv0 = rx.Var("gv0", R.Tensor([m, n], "float32")) + gv1 = rx.Var("gv1", R.Tensor([m, n], "float32")) + call_node = rx.op.add(x, gv0) + bindings = [rx.VarBinding(gv1, call_node)] + blocks = [rx.BindingBlock(bindings)] + func = build_function(blocks) + mod = tvm.IRModule({rx.GlobalVar("foo"): func}) + assert not rx.analysis.well_formed(mod, check_struct_info=False) + + # Error: Var gv0 is defined more than once + gv0 = rx.Var("gv0", R.Tensor([m, n], "float32")) + call_node = rx.op.add(x, x) + call_node2 = rx.op.multiply(x, x) + bindings = [rx.VarBinding(gv0, call_node), rx.VarBinding(gv0, call_node2)] + blocks = [rx.BindingBlock(bindings)] + func = build_function(blocks) + mod = tvm.IRModule({rx.GlobalVar("foo"): func}) + assert not rx.analysis.well_formed(mod, check_struct_info=False) + + +def test_dataflow_var(): + # Error: DataflowVar lv0 is not defined + lv0 = rx.DataflowVar("lv0", R.Tensor([m, n], "float32")) + gv0 = rx.Var("gv0", R.Tensor([m, n], "float32")) + call_node = rx.op.add(x, lv0) + bindings = [rx.VarBinding(gv0, call_node)] + blocks = [rx.DataflowBlock(bindings)] + func = build_function(blocks) + mod = tvm.IRModule({rx.GlobalVar("foo"): func}) + assert not rx.analysis.well_formed(mod, check_struct_info=False) + + # Error: DataflowVar gv0 is defined more than once + lv0 = rx.DataflowVar("lv0", R.Tensor([m, n], "float32")) + call_node = rx.op.add(x, x) + call_node2 = rx.op.multiply(x, x) + bindings = [rx.VarBinding(lv0, call_node), rx.VarBinding(lv0, call_node2)] + blocks = [rx.DataflowBlock(bindings)] + func = build_function(blocks) + mod = tvm.IRModule({rx.GlobalVar("foo"): func}) + assert not rx.analysis.well_formed(mod, check_struct_info=False) + + # Error: DataflowVar lv0 is defined outside DataflowBlock + lv0 = rx.DataflowVar("lv0", R.Tensor([m, n], "float32")) + call_node = rx.op.add(x, x) + bindings = [rx.VarBinding(lv0, call_node)] + blocks = [rx.BindingBlock(bindings)] + func = build_function(blocks) + mod = tvm.IRModule({rx.GlobalVar("foo"): func}) + assert not rx.analysis.well_formed(mod, check_struct_info=False) + + # Error: DataflowVar lv0 is used outside DataflowBlock + lv0 = rx.DataflowVar("lv0", R.Tensor([m, n], "float32")) + gv0 = rx.Var("gv0", R.Tensor([m, n], "float32")) + call_node = rx.op.add(lv0, x) + bindings = [rx.VarBinding(lv0, call_node)] + blocks = [rx.BindingBlock(bindings)] + func = build_function(blocks) + mod = tvm.IRModule({rx.GlobalVar("foo"): func}) + assert not rx.analysis.well_formed(mod, check_struct_info=False) + + +def test_param_var(): + v0 = rx.Var("v0", R.Tensor([m, n], "float32")) + v1 = rx.Var("v1", R.Tensor([m, n], "float32")) + v2 = rx.Var("v2", R.Tensor([m, n], "float32")) + bb = rx.BlockBuilder() + with bb.function("func1", [v0, v1]): + gv0 = bb.emit(rx.op.add(v0, v1)) + bb.emit_func_output(gv0) + with bb.function("func2", [v0, v2]): + gv0 = bb.emit(rx.op.add(v2, v1)) + bb.emit_func_output(gv0) + mod = bb.get() + assert not rx.analysis.well_formed(mod, check_struct_info=False) + + +def test_global_var(): + # Error: GlobalVar GlobalVar0 is not defined + gv0 = rx.Var("gv0", R.Tensor([m, n], "float32")) + globalvar = rx.GlobalVar("GlobalVar0") + call_node = rx.Call( + op=tvm.ir.Op.get("relax.call_tir"), + args=[globalvar, rx.Tuple([x]), rx.ShapeExpr([m, n])], + ) + bindings = [rx.VarBinding(gv0, call_node)] + blocks = [rx.BindingBlock(bindings)] + func = build_function(blocks) + mod = tvm.IRModule({rx.GlobalVar("foo"): func}) + assert not rx.analysis.well_formed(mod, check_struct_info=False) + + +def test_symbolic_var(): + # Error: Symbolic Var new_s is not defined + new_s = tir.Var("new_s", "int64") + gv0 = rx.Var("gv0", R.Tensor([m, new_s], "int64")) + call_node = rx.op.add(x, x) + bindings = [rx.VarBinding(gv0, call_node)] + blocks = [rx.BindingBlock(bindings)] + func = build_function(blocks) + mod = tvm.IRModule({rx.GlobalVar("foo"): func}) + assert not rx.analysis.well_formed(mod, check_struct_info=False) + + +def test_symbolic_var_invalid_type(): + with pytest.raises( + tvm.TVMError, match="the value in ShapeStructInfo can only have dtype of int64" + ): + dim = tir.Var("dim", "float32") + y = rx.Var("y", R.Tensor([dim], "float32")) + gv0 = rx.Var("gv0", R.Tensor([dim], "float32")) + call_node = rx.op.add(y, y) + bindings = [rx.VarBinding(gv0, call_node)] + blocks = [rx.BindingBlock(bindings)] + func = build_function(blocks, [y]) + mod = tvm.IRModule({rx.GlobalVar("foo"): func}) + assert not rx.analysis.well_formed(mod, check_struct_info=False) + + +def test_seq_expr(): + # Error: SeqExpr in VarBinding + gv0 = rx.Var("gv0", R.Tensor([m, n], "float32")) + # build a SeqExpr + gv1 = rx.Var("gv1", R.Tensor([m, n], "float32")) + call_node = rx.op.add(x, gv0) + _bindings = [rx.VarBinding(gv1, call_node)] + _blocks = [rx.BindingBlock(_bindings)] + _seq_expr = rx.SeqExpr(_blocks, gv1) + # build a Binding with the SeqExpr as value + bindings = [rx.VarBinding(gv0, _seq_expr)] + blocks = [rx.BindingBlock(bindings)] + func = build_function(blocks) + mod = tvm.IRModule({rx.GlobalVar("foo"): func}) + assert not rx.analysis.well_formed(mod, check_struct_info=False) + + +def test_if(): + # Error: Var defined in true/false branch is invisible in the outer scope + # except the return Var, i.e the var in the last stmt + # v_in_if is invisible in the outer scope + v_in_if = rx.Var("v_in_if", R.Tensor([m, n], "float32")) + # gv0 is visible in the outer scope + gv0 = rx.Var("gv0", R.Tensor([m, n], "float32")) + # build true branch + true_bindings = [ + rx.VarBinding(v_in_if, rx.op.add(x, x)), + rx.VarBinding(gv0, rx.op.multiply(x, x)), + ] + true_blocks = [rx.BindingBlock(true_bindings)] + true_seq_expr = rx.SeqExpr(true_blocks, true_blocks[-1].bindings[-1].var) + # build false branch + false_bindings = [ + rx.VarBinding(v_in_if, rx.op.multiply(x, x)), + rx.VarBinding(gv0, rx.op.add(x, x)), + ] + false_blocks = [rx.BindingBlock(false_bindings)] + false_seq_expr = rx.SeqExpr(false_blocks, false_blocks[-1].bindings[-1].var) + # build If node + if_node = rx.If(cond=cond, true_branch=true_seq_expr, false_branch=false_seq_expr) + gv1 = rx.Var("gv1", R.Tensor([m, n], "float32")) + # try to call v_in_if defined in the true/false branch + bindings = [rx.VarBinding(gv0, if_node), rx.VarBinding(gv1, v_in_if)] + blocks = [rx.BindingBlock(bindings)] + func = build_function(blocks) + mod = tvm.IRModule({rx.GlobalVar("foo"): func}) + assert not rx.analysis.well_formed(mod, check_struct_info=True) + + +def test_if_non_seq_body(): + # Error: If node has a body that is not a seq node + if_node = rx.If(cond=cond, true_branch=x, false_branch=x) + blocks = [ + rx.BindingBlock( + [ + rx.VarBinding( + rx.Var("gv1", R.Tensor([m, n], "float32")), + if_node, + ) + ] + ) + ] + func = build_function(blocks) + mod = tvm.IRModule.from_expr(func) + assert not rx.analysis.well_formed(mod, check_struct_info=False) + + # on the other hand, if they're wrapped in a seq node, it's fine + seq = rx.SeqExpr([], x) + new_if_node = rx.If(cond=cond, true_branch=seq, false_branch=seq) + new_blocks = [ + rx.BindingBlock( + [ + rx.VarBinding( + rx.Var("gv1", R.Tensor([m, n], "float32")), + new_if_node, + ) + ] + ) + ] + new_func = build_function(new_blocks) + new_mod = tvm.IRModule.from_expr(new_func) + # apply normalization to fill in checked_type_ + normalized = rx.transform.Normalize()(new_mod) + assert rx.analysis.well_formed(normalized, check_struct_info=True) + + +def test_if_complex_condition(): + # Error: If condition must be a leaf expression + cond_tuple = rx.Tuple([cond]) + cond_idx = rx.TupleGetItem(cond_tuple, 0) + if_node = rx.If(cond_idx, rx.SeqExpr([], x), rx.SeqExpr([], x)) + blocks = [ + rx.BindingBlock( + [ + rx.VarBinding( + rx.Var("gv1", R.Tensor([m, n], "float32")), + if_node, + ) + ] + ) + ] + func = build_function(blocks) + mod = tvm.IRModule.from_expr(func) + assert not rx.analysis.well_formed(mod, check_struct_info=False) + + cond_var = rx.Var("q", R.Tensor([], "bool")) + new_if = rx.If(cond_var, rx.SeqExpr([], x), rx.SeqExpr([], x)) + blocks = [ + rx.BindingBlock( + [ + rx.VarBinding(cond_var, cond_idx), + rx.VarBinding( + rx.Var("gv1", R.Tensor([m, n], "float32")), + new_if, + ), + ] + ) + ] + func = build_function(blocks) + mod = tvm.IRModule.from_expr(func) + # apply normalization to fill in checked_type_ + normalized = rx.transform.Normalize()(mod) + assert rx.analysis.well_formed(normalized, check_struct_info=True) + + +def test_tuple_get_item_nested(): + # Error: The tuple value in tuple get item must be a leaf expression + nested_tup = rx.Var( + "t", rx.TupleStructInfo([rx.TupleStructInfo([rx.TensorStructInfo([], "int32")])]) + ) + double_idx = rx.TupleGetItem(rx.TupleGetItem(nested_tup, 0), 0) + ret_var = rx.Var("r", R.Tensor([], "int32")) + f = rx.Function( + [nested_tup], + rx.SeqExpr([rx.BindingBlock([rx.VarBinding(ret_var, double_idx)])], ret_var), + ret_struct_info=R.Tensor(ndim=0, dtype="int32"), + ) + f = f.with_attr("global_symbol", "f") + mod = tvm.IRModule.from_expr(f) + assert not rx.analysis.well_formed(mod, check_struct_info=False) + + # okay with an intermediate binding + first_idx = rx.TupleGetItem(nested_tup, 0) + idx_var = rx.Var("v", rx.TupleStructInfo([rx.TensorStructInfo([], "int32")])) + second_idx = rx.TupleGetItem(idx_var, 0) + new_f = rx.Function( + [nested_tup], + rx.SeqExpr( + [ + rx.BindingBlock( + [rx.VarBinding(idx_var, first_idx), rx.VarBinding(ret_var, second_idx)] + ) + ], + ret_var, + ), + ret_struct_info=R.Tensor(ndim=0, dtype="int32"), + ) + new_f = new_f.with_attr("global_symbol", "new_f") + mod = tvm.IRModule.from_expr(new_f) + # normalize in order to fill in checked type + normalized = rx.transform.Normalize()(mod) + assert rx.analysis.well_formed(normalized, check_struct_info=True) + + +def test_complex_seq_body(): + # Error: seq expr with a body that is not a leaf expression is not permitted + x = rx.Var("x", R.Tensor([], "int32")) + y = rx.Var("y", R.Tensor([], "int32")) + func = rx.Function( + [x, y], + rx.SeqExpr([], rx.op.add(x, y)), + R.Tensor(ndim=0, dtype="int32"), + ).with_attr("global_symbol", "foo") + mod = tvm.IRModule.from_expr(func) + assert not rx.analysis.well_formed(mod, check_struct_info=False) + + # but if the result is bound, then it's okay + z = rx.Var("z", R.Tensor([], "int32")) + new_func = rx.Function( + [x, y], + rx.SeqExpr( + [ + rx.BindingBlock( + [ + rx.VarBinding( + var=z, + value=rx.op.add(x, y), + ) + ] + ) + ], + z, + ), + R.Tensor(ndim=0, dtype="int32"), + ).with_attr("global_symbol", "foo") + new_mod = tvm.IRModule.from_expr(new_func) + # normalize in order to fill in checked type + normalized = rx.transform.Normalize()(new_mod) + assert rx.analysis.well_formed(normalized, check_struct_info=True) + + +def test_ANF(): + # Error: Nested Call + gv0 = rx.Var("gv0", R.Tensor([m, n], "float32")) + call_node = rx.op.add(x, rx.op.add(x, x)) + bindings = [rx.VarBinding(gv0, call_node)] + blocks = [rx.BindingBlock(bindings)] + func = build_function(blocks) + mod = tvm.IRModule({rx.GlobalVar("foo"): func}) + assert not rx.analysis.well_formed(mod, check_struct_info=False) + + # Error: Call Node in Tuple + gv0 = rx.Var("gv0", R.Tensor([m, n], "float32")) + bindings = [rx.VarBinding(gv0, rx.Tuple((x, rx.op.add(x, x))))] + blocks = [rx.BindingBlock(bindings)] + func = build_function(blocks) + mod = tvm.IRModule({rx.GlobalVar("foo"): func}) + assert not rx.analysis.well_formed(mod, check_struct_info=False) + + +def test_global_var_vs_gsymbol(): + # Error: gsymbol "main1" not equals to the name in global var "main" + gv0 = rx.Var("gv0", R.Tensor([m, n], "float32")) + bindings = [rx.VarBinding(gv0, x)] + blocks = [rx.DataflowBlock(bindings)] + func = rx.Function( + [x], + rx.SeqExpr(blocks, gv0), + R.Tensor(ndim=2, dtype="float32"), + ).with_attr("global_symbol", "main1") + mod = tvm.IRModule({rx.GlobalVar("main"): func}) + assert not rx.analysis.well_formed(mod, check_struct_info=False) + + +def test_nested_dataflow(): + scalar_struct_info = rx.TensorStructInfo(shape=[], dtype="int32") + gv0 = rx.Var("gv0", scalar_struct_info) + f = rx.DataflowVar("f", rx.FuncStructInfo([], scalar_struct_info)) + x0 = rx.DataflowVar("x0", scalar_struct_info) + x1 = rx.DataflowVar("x1", scalar_struct_info) + x2 = rx.DataflowVar("x2", scalar_struct_info) + y = rx.Var("y", scalar_struct_info) + inner_block = rx.DataflowBlock([rx.VarBinding(x0, rx.const(2, "int32")), rx.VarBinding(y, x0)]) + inner_func = rx.Function([], rx.SeqExpr([inner_block], y), scalar_struct_info) + outer_block = rx.DataflowBlock( + [ + rx.VarBinding(x1, rx.const(1, "int32")), + rx.VarBinding(f, inner_func), + rx.VarBinding(x2, rx.op.add(x1, rx.Call(f, []))), + rx.VarBinding(gv0, x2), + ] + ) + func = rx.Function([], rx.SeqExpr([outer_block], gv0), scalar_struct_info) + mod = tvm.IRModule.from_expr(func) + normalized = rx.transform.Normalize()(mod) + assert rx.analysis.well_formed(normalized) + + +def test_sinfo_args_tir_var_used_before_define_call_packed(): + # Error: Symbolic Var m1, n1 are not defined + m1 = tir.Var("m1", "int64") + n1 = tir.Var("n1", "int64") + call = R.call_packed("my_func", x, sinfo_args=R.Tensor((m1, n1), "float32")) + func = build_function([rx.BindingBlock([rx.VarBinding(rx.Var("gv"), call)])]) + mod = rx.transform.Normalize()(tvm.IRModule.from_expr(func)) + assert not rx.analysis.well_formed(mod, check_struct_info=False) + + +def test_sinfo_args_tir_var_used_before_define_call_tir(): + # Error: Symbolic Var m1, n1 are not defined + m1 = tir.Var("m1", "int64") + n1 = tir.Var("n1", "int64") + call = R.call_tir("my_func", x, out_sinfo=R.Tensor((m1, n1), "float32")) + func = build_function([rx.BindingBlock([rx.VarBinding(rx.Var("gv"), call)])]) + mod = rx.transform.Normalize()(tvm.IRModule.from_expr(func)) + assert not rx.analysis.well_formed(mod, check_struct_info=False) + + +if __name__ == "__main__": + tvm.testing.main()
