This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new 7c88732056 [Unity][Pass] Wellformed Analysis (#14032)
7c88732056 is described below

commit 7c8873205688686acb82d306d1474a16d9a87b6e
Author: Lesheng Jin <[email protected]>
AuthorDate: Fri Feb 17 18:11:08 2023 -0800

    [Unity][Pass] Wellformed Analysis (#14032)
    
    This PR implements relax wellformed analysis, which checks if the IRModule 
is well-formed. (tests and examples are added).
    
    Co-Authored-by: Ruihang Lai 
[[email protected]](mailto:[email protected])
    Co-Authored-by: Siyuan Feng 
[[email protected]](mailto:[email protected])
    Co-Authored-by: Tianqi Chen 
[[email protected]](mailto:[email protected])
    Co-authored-by: Steven S. Lyubomirsky 
[[email protected]](mailto:[email protected])
    Co-authored-by: Yong Wu [[email protected]](mailto:[email protected])
    Co-Authored-by: Yuchen Jin 
[[email protected]](mailto:[email protected])
    Co-Authored-by: Yixin Dong <[email protected]>
    Co-Authored-by: Chaofan Lin <[email protected]>
    Co-Authored-by: Prakalp Srivastava 
[[email protected]](mailto:[email protected])
    Co-Authored-by: Junru Shao 
[[email protected]](mailto:[email protected])
---
 include/tvm/relax/analysis.h                    |  13 +
 python/tvm/relax/analysis/analysis.py           |  27 ++
 python/tvm/relax/ir/instrument.py               |  37 ++
 src/relax/analysis/well_formed.cc               | 465 ++++++++++++++++++++++++
 tests/python/relax/conftest.py                  |  23 ++
 tests/python/relax/test_analysis_well_formed.py | 438 ++++++++++++++++++++++
 6 files changed, 1003 insertions(+)

diff --git a/include/tvm/relax/analysis.h b/include/tvm/relax/analysis.h
index ff576d4ebb..f9896efdf2 100644
--- a/include/tvm/relax/analysis.h
+++ b/include/tvm/relax/analysis.h
@@ -343,6 +343,19 @@ TVM_DLL relay::OpPatternKind AnalyzeOpPatternKind(const 
tir::PrimFunc& func);
  */
 TVM_DLL bool HasReshapePattern(const tir::PrimFunc& func);
 
+/*!
+ * \brief Check if the IRModule is well formed.
+ *
+ * \param m the IRModule to check.
+ * \param check_struct_info A boolean flag indicating if the property "every 
Expr
+ * must have defined structure info" will be checked.
+ * \return true if the IRModule is well formed, false if not.
+ * \note By default the structure info is always checked. It is only in test 
cases
+ * where `check_struct_info` might be false, so that other well-formed 
requirements
+ * will be well tested and will not be blocked by not having structure info.
+ */
+TVM_DLL bool WellFormed(IRModule m, bool check_struct_info = true);
+
 }  // namespace relax
 }  // namespace tvm
 
diff --git a/python/tvm/relax/analysis/analysis.py 
b/python/tvm/relax/analysis/analysis.py
index 27416c3a79..7107883478 100644
--- a/python/tvm/relax/analysis/analysis.py
+++ b/python/tvm/relax/analysis/analysis.py
@@ -25,6 +25,7 @@ from typing import Dict
 from enum import IntEnum
 
 from tvm import tir
+from tvm import IRModule
 from tvm.relax.ty import Type
 from tvm.relax.struct_info import StructInfo, FuncStructInfo
 from tvm.relax.expr import Var, Expr, Call
@@ -207,3 +208,29 @@ def has_reshape_pattern(func: tir.PrimFunc) -> bool:
     of this function.
     """
     return _ffi_api.has_reshape_pattern(func)  # type: ignore
+
+
+def well_formed(mod: IRModule, check_struct_info: bool = True) -> bool:
+    """Check if the IRModule is well formed.
+
+    Parameters
+    ----------
+    mod : tvm.IRModule
+        The input IRModule.
+
+    check_struct_info : bool
+        A boolean flag indicating if the property "every Expr must
+        have defined structure info" will be checked.
+
+    Returns
+    -------
+    ret: bool
+        True if the IRModule is well formed, False if not.
+
+    Note
+    ----
+    By default the structure info is always checked. It is only in test cases
+    where `check_struct_info` might be false, so that other well-formed 
requirements
+    will be well tested and will not be blocked by not having structure info.
+    """
+    return _ffi_api.well_formed(mod, check_struct_info)  # type: ignore
diff --git a/python/tvm/relax/ir/instrument.py 
b/python/tvm/relax/ir/instrument.py
new file mode 100644
index 0000000000..fc51a796a7
--- /dev/null
+++ b/python/tvm/relax/ir/instrument.py
@@ -0,0 +1,37 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Common relax pass instrumentation across IR variants."""
+import tvm
+from tvm import relax
+
+
[email protected]_instrument
+class WellFormedInstrument:
+    """An instrument that checks the input/output IRModule of the Pass
+    is well formed. It will skip specific passes, like Normalize.
+    """
+
+    def __init__(self):
+        self.skip_pass_name = ["Normalize", "ResolveGlobals"]
+
+    def run_before_pass(self, mod, pass_info):
+        if pass_info.name not in self.skip_pass_name:
+            assert relax.analysis.well_formed(mod)
+
+    def run_after_pass(self, mod, pass_info):
+        if pass_info.name not in self.skip_pass_name:
+            assert relax.analysis.well_formed(mod)
diff --git a/src/relax/analysis/well_formed.cc 
b/src/relax/analysis/well_formed.cc
new file mode 100644
index 0000000000..e7ec237fd5
--- /dev/null
+++ b/src/relax/analysis/well_formed.cc
@@ -0,0 +1,465 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file relax/analysis/well_formed.cc
+ * \brief Check if the IRModule is well-formed.
+ *
+ * This pass is supposed to be applied to normalized Relax AST.
+ * If it's malformed, messages will be logged as Warning.
+ * This pass will check:
+ *    1. Each Expr should have `struct_info_` field already populated, when
+ *      `check_struct_info` is true.
+ *    2. GlobalVars are defined before use.
+ *    3. When a Function has a corresponding GlobalVar and a `global_symbol`
+ *       attribute, the name of the GlobalVar must equal the value of the
+ *       `global_symbol` attribute value.
+ *    4. Any variable cannot be used as different function parameters in the same 
IRModule
+ *    5. Vars are defined before use.
+ *    6. Vars are defined exactly once.
+ *    7. Symbolic Vars are defined before use.
+ *    8. DataflowVars cannot be defined inside BindingBlock.
+ *    9. Vars defined in IfNode, except the return Var, are invisible
+ *       out of the If body. (May change for new AST designs)
+ *    10. SeqExpr only serves as function body, or in the true and
+ *       false branches in IfNode.
+ *    11. The IR is in ANF:
+ *       (a) Expressions cannot contain nested complex expressions.
+ *           Here are the expressions that may be nested inside other 
expressions:
+ *           Var, DataflowVar, GlobalVar, Constant, ShapeExpr,
+ *           Op, Tuple (we call these "leaf" expressions).
+ *       (b) The right-hand side of a binding may contain a non-leaf expression
+ *           (where all expressions nested in it are leaf expressions),
+ *           other than SeqExprs (see rule 10)
+ *       (c) Exceptions: The body of a Function node and the true branch
+ *           and false branch of If nodes *must* be SeqExprs.
+ *       (d) Places where non-leaf expressions cannot appear:
+ *           * The tuple_value field of TupleGetItem nodes
+ *           * The cond field of If nodes
+ *           * The op or args fields of Call nodes
+ *           * Inside the fields of Tuple nodes
+ *    12. Expr always has checked_type_ (with the exception of Op).
+ */
+#include <tvm/relax/analysis.h>
+#include <tvm/relax/expr.h>
+#include <tvm/relax/expr_functor.h>
+#include <tvm/relax/struct_info_functor.h>
+#include <tvm/relax/utils.h>
+#include <tvm/tir/expr_functor.h>
+
+#include <unordered_set>
+
+namespace tvm {
+namespace relax {
+
+// TODO(relax-team): Consider further refactor using
+// Scope Frame to store and manage the var context.
+//
+/*! \brief Helper to implement well formed check.*/
+class WellFormedChecker : public relax::ExprVisitor,
+                          public relax::StructInfoVisitor,
+                          public tir::ExprVisitor {
+ public:
+  static bool Check(IRModule mod, bool check_struct_info) {
+    WellFormedChecker well_formed_checker = WellFormedChecker(mod, 
check_struct_info);
+
+    for (const auto& it : mod->functions) {
+      // visit relax.Function
+      if (auto* n = it.second.as<FunctionNode>()) {
+        Function func = GetRef<Function>(n);
+        well_formed_checker.CheckGlobalVarAndGsymbolConsistency(it.first, 
func);
+        well_formed_checker.VisitExpr(func);
+      }
+    }
+    return well_formed_checker.well_formed_;
+  }
+
+ private:
+  explicit WellFormedChecker(IRModule mod, bool check_struct_info)
+      : mod_(std::move(mod)), check_struct_info_(check_struct_info) {}
+
+  using relax::ExprVisitor::VisitExpr_;
+  using tir::ExprVisitor::VisitExpr;
+  using tir::ExprVisitor::VisitExpr_;
+
+  // Possible mode of visitor
+  enum class VisitMode {
+    /*!
+     * \brief Check all vars are well-defined
+     */
+    kDefault,
+    /*!
+     * \brief Match define the vars on first occurrence.
+     * Do not check the well-defined property of composite expr.
+     */
+    kMatchVarDef
+  };
+
+  void Malformed(Diagnostic diag) {
+    well_formed_ = false;
+    LOG(WARNING) << "This IR is not well formed: " << diag->message;
+  }
+
+  void CheckGlobalVarAndGsymbolConsistency(GlobalVar var, Function func) {
+    // check name in global var and gsymbol
+    Optional<String> gsymbol = func->GetAttr<String>(tvm::attr::kGlobalSymbol);
+    if (gsymbol.defined() && gsymbol != var->name_hint) {
+      Malformed(Diagnostic::Error(func->span)
+                << "Name in GlobalVar is not equal to name in gsymbol: " << 
var->name_hint
+                << " != " << gsymbol.value());
+    }
+  }
+
+  void VisitExpr(const Expr& expr) final {
+    if (!expr.as<OpNode>() && !expr->checked_type_.defined()) {
+      Malformed(Diagnostic::Error(expr) << "The checked_type_ of Expr " << 
expr << " is nullptr.");
+    }
+    relax::ExprVisitor::VisitExpr(expr);
+  }
+
+  void VisitExpr_(const GlobalVarNode* op) final {
+    GlobalVar var = GetRef<GlobalVar>(op);
+    if (!(mod_->ContainGlobalVar(var->name_hint) &&
+          mod_->GetGlobalVar(var->name_hint).same_as(var))) {
+      Malformed(Diagnostic::Error(var) << "GlobalVar " << op->name_hint << " 
is not defined.");
+    }
+
+    if (op->checked_type_.defined()) {
+      if ((!op->checked_type_->IsInstance<FuncTypeNode>()) &&
+          (!op->checked_type_->IsInstance<PackedFuncTypeNode>())) {
+        Malformed(Diagnostic::Error(var) << "The checked_type_ of GlobalVar " 
<< op->name_hint
+                                         << " must be either FuncType or 
PackedFuncType.");
+      }
+    }
+
+    CheckStructInfo(op);
+  }
+
+  void VisitExpr_(const TupleNode* op) final {
+    for (size_t i = 0; i < op->fields.size(); i++) {
+      Expr expr = op->fields[i];
+      if (IsLeafOrTuple(expr)) {
+        this->VisitExpr(expr);
+      } else {
+        Malformed(Diagnostic::Error(expr)
+                  << "Tuple is not in ANF form, field " << i << " gets " << 
expr->GetTypeKey());
+      }
+    }
+
+    CheckStructInfo(op);
+  }
+
+  void VisitExpr_(const TupleGetItemNode* op) final {
+    if (IsLeafOrTuple(op->tuple)) {
+      this->VisitExpr(op->tuple);
+    } else {
+      Malformed(Diagnostic::Error(op)
+                << "The tuple value in a TupleGetItem node must be a leaf 
expression.");
+    }
+    CheckStructInfo(op);
+  }
+
+  void VisitExpr_(const VarNode* op) final {
+    Var var = GetRef<Var>(op);
+    if (var_set_.count(var) == 0) {
+      Malformed(Diagnostic::Error(var) << "Var " << op->name_hint() << " is 
not defined.");
+    }
+    CheckStructInfo(op);
+  }
+
+  void VisitExpr_(const DataflowVarNode* op) final {
+    DataflowVar var = GetRef<DataflowVar>(op);
+    if (!is_dataflow_) {
+      Malformed(Diagnostic::Error(var)
+                << "DataflowVar " << op->name_hint() << " is used outside 
DataflowBlock.");
+    }
+    if (dataflow_var_set_.count(var) == 0) {
+      Malformed(Diagnostic::Error(var) << "DataflowVar " << op->name_hint() << 
" is not defined.");
+    }
+    CheckStructInfo(op);
+  }
+
+  void VisitExpr_(const FunctionNode* op) final {
+    // save the var_set_ for local function
+    auto prev_var_set = var_set_;
+    auto prev_dataflow_var_set = dataflow_var_set_;
+    auto prev_symbolic_var_set = symbolic_var_set_;
+    bool old_dataflow_state = is_dataflow_;
+    // symbolic var is not captured across function boundaries
+    symbolic_var_set_.clear();
+    is_dataflow_ = false;
+
+    // first populate defs in params
+    WithMode(VisitMode::kMatchVarDef, [&]() {
+      ICHECK(mode_ == VisitMode::kMatchVarDef);
+      for (Var param : op->params) {
+        relax::StructInfoVisitor::VisitStructInfo(GetStructInfo(param));
+      }
+    });
+
+    // check all expr are well defined.
+    for (Var param : op->params) {
+      this->VisitVarDef(param);
+
+      if (param_var_func_map_.count(param) == 1) {
+        // TODO(relax-team): Complete this error info after we integrate 
printer
+        Malformed(Diagnostic::Error(param->span)
+                  << "Relax variable " << param->name_hint()
+                  << " is repeatedly used as parameters in function.");
+      }
+      param_var_func_map_.insert({param, GetRef<Function>(op)});
+    }
+
+    if (auto seq = op->body.as<SeqExprNode>()) {
+      this->VisitSeqExpr(seq);
+    } else {
+      Malformed(Diagnostic::Error(op) << "Function bodies must be sequence 
expressions");
+    }
+
+    is_dataflow_ = old_dataflow_state;
+    dataflow_var_set_ = prev_dataflow_var_set;
+    var_set_ = prev_var_set;
+    symbolic_var_set_ = prev_symbolic_var_set;
+  }
+
+  void VisitExpr_(const CallNode* op) final {
+    if (IsLeafOrTuple(op->op)) {
+      this->VisitExpr(op->op);
+    } else {
+      Malformed(Diagnostic::Error(op) << "The called expression must be a leaf 
expression");
+    }
+    for (size_t i = 0; i < op->args.size(); i++) {
+      Expr arg = op->args[i];
+      if (IsLeafOrTuple(arg)) {
+        this->VisitExpr(arg);
+      } else {
+        Malformed(Diagnostic::Error(arg->span)
+                  << "Call is not in ANF form, arg " << i << " gets " << 
arg->GetTypeKey());
+      }
+    }
+
+    for (const StructInfo& sinfo_arg : op->sinfo_args) {
+      this->VisitStructInfo(sinfo_arg);
+    }
+
+    CheckStructInfo(op);
+  }
+
+  void VisitExpr_(const IfNode* op) final {
+    if (IsLeafOrTuple(op->cond)) {
+      this->VisitExpr(op->cond);
+    } else {
+      Malformed(Diagnostic::Error(op) << "The condition for an if node must be 
a leaf expression.");
+    }
+    auto true_seq = op->true_branch.as<SeqExprNode>();
+    auto false_seq = op->false_branch.as<SeqExprNode>();
+    if (true_seq && false_seq) {
+      std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual> previous_var_set 
= var_set_;
+      std::unordered_set<tir::Var, ObjectPtrHash, ObjectPtrEqual> 
previous_symbolic_var_set =
+          symbolic_var_set_;
+      this->VisitSeqExpr(true_seq);
+      var_set_ = previous_var_set;
+      symbolic_var_set_ = previous_symbolic_var_set;
+      this->VisitSeqExpr(false_seq);
+      var_set_ = previous_var_set;
+      symbolic_var_set_ = previous_symbolic_var_set;
+    } else {
+      Malformed(Diagnostic::Error(op) << "If node branches must be seq exprs");
+    }
+    CheckStructInfo(op);
+  }
+
+  void VisitExpr_(const ShapeExprNode* op) final {
+    for (PrimExpr expr : op->values) {
+      // check if the symbolic vars in the expr are defined, e.g, 2 * m
+      tir::ExprVisitor::VisitExpr(expr);
+      if (!expr.dtype().is_int()) {
+        Malformed(Diagnostic::Error(expr)
+                  << "Shape expressions must be of integer type, but got " << 
expr.dtype());
+      }
+    }
+    CheckStructInfo(op);
+  }
+
+  void VisitExpr_(const SeqExprNode* op) final {
+    Malformed(Diagnostic::Error(op) << "SeqExpr only serves as the function 
body in FunctionNode, "
+                                       "or the true/false branch body in 
IfNode.");
+  }
+
+  void VisitSeqExpr(const SeqExprNode* op) {
+    // a special call only if SeqExpr is the function body
+    // in FunctionNode or the true/false branch body in IfNode
+    for (BindingBlock block : op->blocks) {
+      this->VisitBindingBlock(block);
+    }
+    if (!IsLeafOrTuple(op->body)) {
+      Malformed(Diagnostic::Error(op) << "SeqExpr bodies must be leaf 
expressions.");
+    }
+    this->VisitExpr(op->body);
+    CheckStructInfo(op);
+  }
+
+  void VisitBinding_(const VarBindingNode* binding) final {
+    this->VisitExpr(binding->value);
+    this->VisitVarDef(binding->var);
+  }
+
+  void VisitBinding_(const MatchCastNode* binding) final {
+    this->VisitExpr(binding->value);
+    // define the vars
+    WithMode(VisitMode::kMatchVarDef, [&]() { 
this->VisitStructInfo(binding->struct_info); });
+
+    this->VisitStructInfo(binding->struct_info);
+    this->VisitVarDef(binding->var);
+  }
+
+  void VisitBindingBlock_(const DataflowBlockNode* block) final {
+    bool old_is_dataflow_ = is_dataflow_;
+    is_dataflow_ = true;
+    for (Binding binding : block->bindings) {
+      this->VisitBinding(binding);
+    }
+    is_dataflow_ = old_is_dataflow_;
+    dataflow_var_set_.clear();
+  }
+
+  void VisitVarDef_(const DataflowVarNode* var) final {
+    if (!is_dataflow_) {
+      Malformed(Diagnostic::Error(var)
+                << "DataflowVar " << var->name_hint() << " is defined outside 
DataflowBlock.");
+    }
+    DataflowVar lv = GetRef<DataflowVar>(var);
+    if (dataflow_var_set_.count(lv) == 1) {
+      Malformed(Diagnostic::Error(var)
+                << "DataflowVar " << lv->name_hint() << " is defined more than 
once.");
+    }
+    // register DataflowVar
+    dataflow_var_set_.insert(lv);
+    CheckStructInfo(var);
+  }
+
+  void VisitVarDef_(const VarNode* var) final {
+    Var gv = GetRef<Var>(var);
+    if (var_set_.count(gv) == 1) {
+      Malformed(Diagnostic::Error(var)
+                << "Var " << gv->name_hint() << " is defined more than once.");
+    }
+    // register Var
+    var_set_.insert(gv);
+    CheckStructInfo(var);
+  }
+
+  void VisitVarDef(const Var& var) final {
+    if (const DataflowVarNode* lv_node = var.as<DataflowVarNode>()) {
+      VisitVarDef_(lv_node);
+    } else if (const VarNode* gv_node = var.as<VarNode>()) {
+      VisitVarDef_(gv_node);
+    } else {
+      LOG(FATAL) << "TypeError: Invalid type: " << var->GetTypeKey();
+    }
+  }
+
+  void VisitExpr_(const tir::VarNode* op) final {
+    tir::Var var = GetRef<tir::Var>(op);
+    // default mode, check defined.
+    if (symbolic_var_set_.count(var) == 0) {
+      this->Malformed(Diagnostic::Error(var)
+                      << "Symbolic Var " << var->name_hint << " is not 
defined.");
+    }
+  }
+
+  void VisitStructInfoExprField(const Expr& expr) final {
+    if (mode_ == VisitMode::kMatchVarDef) {
+      // populate symbolic var in first occurrence
+      if (auto* op = expr.as<relax::VarNode>()) {
+        auto var = GetRef<relax::Var>(op);
+        if (var_set_.count(var) == 0) {
+          var_set_.insert(var);
+        }
+      }
+      if (auto* shape = expr.as<relax::ShapeExprNode>()) {
+        for (auto val : shape->values) {
+          this->VisitStructInfoExprField(val);
+        }
+      }
+    } else {
+      relax::ExprVisitor::VisitExpr(expr);
+    }
+  }
+
+  void VisitStructInfoExprField(const PrimExpr& expr) final {
+    if (mode_ == VisitMode::kMatchVarDef) {
+      // populate symbolic var in first occurrence
+      if (auto* op = expr.as<tir::VarNode>()) {
+        auto var = GetRef<tir::Var>(op);
+        if (symbolic_var_set_.count(var) == 0) {
+          symbolic_var_set_.insert(var);
+        }
+      }
+    } else {
+      tir::ExprVisitor::VisitExpr(expr);
+    }
+  }
+
+  void CheckStructInfo(const ExprNode* op) {
+    if (!check_struct_info_) {
+      return;
+    }
+
+    auto* sinfo = op->struct_info_.as<StructInfoNode>();
+    if (sinfo != nullptr) {
+      this->VisitStructInfo(GetRef<StructInfo>(sinfo));
+    } else {
+      Malformed(Diagnostic::Error(op) << "Expr must have struct_info 
populated. "
+                                      << " Expr.type_key=" << 
op->GetTypeKey());
+    }
+  }
+
+  // Run callback with mode.
+  template <typename FType>
+  void WithMode(VisitMode mode, FType callback) {
+    std::swap(mode_, mode);
+    callback();
+    std::swap(mode_, mode);
+  }
+
+  IRModule mod_;
+  const bool check_struct_info_;
+  bool well_formed_ = true;
+  bool is_dataflow_;
+  // Current visit mode.
+  VisitMode mode_ = VisitMode::kDefault;
+  // set of context variables.
+  std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual> var_set_;
+  std::unordered_set<DataflowVar, ObjectPtrHash, ObjectPtrEqual> 
dataflow_var_set_;
+  std::unordered_set<tir::Var, ObjectPtrHash, ObjectPtrEqual> 
symbolic_var_set_;
+  std::unordered_map<Var, Function, ObjectPtrHash, ObjectPtrEqual> 
param_var_func_map_;
+};
+
+bool WellFormed(IRModule m, bool check_struct_info) {
+  return WellFormedChecker::Check(std::move(m), check_struct_info);
+}
+
+TVM_REGISTER_GLOBAL(("relax.analysis.well_formed"))
+    .set_body_typed([](IRModule m, bool check_struct_info) {
+      return WellFormed(m, check_struct_info);
+    });
+
+}  // namespace relax
+}  // namespace tvm
diff --git a/tests/python/relax/conftest.py b/tests/python/relax/conftest.py
new file mode 100644
index 0000000000..f1b1187066
--- /dev/null
+++ b/tests/python/relax/conftest.py
@@ -0,0 +1,23 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+
+import pytest
+
+import tvm
+from tvm.relax.ir.instrument import WellFormedInstrument
+
+
+tvm.transform.PassContext.current().override_instruments([WellFormedInstrument()])
diff --git a/tests/python/relax/test_analysis_well_formed.py 
b/tests/python/relax/test_analysis_well_formed.py
new file mode 100644
index 0000000000..cc0de84d53
--- /dev/null
+++ b/tests/python/relax/test_analysis_well_formed.py
@@ -0,0 +1,438 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import pytest
+import tvm
+import tvm.testing
+from tvm import tir
+from tvm import relax as rx
+from tvm.script import relax as R
+
+m = tir.Var("m", "int64")
+n = tir.Var("n", "int64")
+x = rx.Var("x", R.Tensor([m, n], "float32"))
+cond = rx.Var("cond", R.Tensor([], "bool"))
+
+
+def build_function(blocks, params=[]):
+    """Returns relax.function with given blocks"""
+    seq_expr = rx.SeqExpr(blocks, blocks[-1].bindings[-1].var)
+    func = rx.Function([x, cond] + params, seq_expr, 
R.Tensor("float32")).with_attr(
+        "global_symbol", "foo"
+    )
+    return func
+
+
+def test_var():
+    # Error: Var gv0 is not defined
+    gv0 = rx.Var("gv0", R.Tensor([m, n], "float32"))
+    gv1 = rx.Var("gv1", R.Tensor([m, n], "float32"))
+    call_node = rx.op.add(x, gv0)
+    bindings = [rx.VarBinding(gv1, call_node)]
+    blocks = [rx.BindingBlock(bindings)]
+    func = build_function(blocks)
+    mod = tvm.IRModule({rx.GlobalVar("foo"): func})
+    assert not rx.analysis.well_formed(mod, check_struct_info=False)
+
+    # Error: Var gv0 is defined more than once
+    gv0 = rx.Var("gv0", R.Tensor([m, n], "float32"))
+    call_node = rx.op.add(x, x)
+    call_node2 = rx.op.multiply(x, x)
+    bindings = [rx.VarBinding(gv0, call_node), rx.VarBinding(gv0, call_node2)]
+    blocks = [rx.BindingBlock(bindings)]
+    func = build_function(blocks)
+    mod = tvm.IRModule({rx.GlobalVar("foo"): func})
+    assert not rx.analysis.well_formed(mod, check_struct_info=False)
+
+
+def test_dataflow_var():
+    # Error: DataflowVar lv0 is not defined
+    lv0 = rx.DataflowVar("lv0", R.Tensor([m, n], "float32"))
+    gv0 = rx.Var("gv0", R.Tensor([m, n], "float32"))
+    call_node = rx.op.add(x, lv0)
+    bindings = [rx.VarBinding(gv0, call_node)]
+    blocks = [rx.DataflowBlock(bindings)]
+    func = build_function(blocks)
+    mod = tvm.IRModule({rx.GlobalVar("foo"): func})
+    assert not rx.analysis.well_formed(mod, check_struct_info=False)
+
+    # Error: DataflowVar gv0 is defined more than once
+    lv0 = rx.DataflowVar("lv0", R.Tensor([m, n], "float32"))
+    call_node = rx.op.add(x, x)
+    call_node2 = rx.op.multiply(x, x)
+    bindings = [rx.VarBinding(lv0, call_node), rx.VarBinding(lv0, call_node2)]
+    blocks = [rx.DataflowBlock(bindings)]
+    func = build_function(blocks)
+    mod = tvm.IRModule({rx.GlobalVar("foo"): func})
+    assert not rx.analysis.well_formed(mod, check_struct_info=False)
+
+    # Error: DataflowVar lv0 is defined outside DataflowBlock
+    lv0 = rx.DataflowVar("lv0", R.Tensor([m, n], "float32"))
+    call_node = rx.op.add(x, x)
+    bindings = [rx.VarBinding(lv0, call_node)]
+    blocks = [rx.BindingBlock(bindings)]
+    func = build_function(blocks)
+    mod = tvm.IRModule({rx.GlobalVar("foo"): func})
+    assert not rx.analysis.well_formed(mod, check_struct_info=False)
+
+    # Error: DataflowVar lv0 is used outside DataflowBlock
+    lv0 = rx.DataflowVar("lv0", R.Tensor([m, n], "float32"))
+    gv0 = rx.Var("gv0", R.Tensor([m, n], "float32"))
+    call_node = rx.op.add(lv0, x)
+    bindings = [rx.VarBinding(lv0, call_node)]
+    blocks = [rx.BindingBlock(bindings)]
+    func = build_function(blocks)
+    mod = tvm.IRModule({rx.GlobalVar("foo"): func})
+    assert not rx.analysis.well_formed(mod, check_struct_info=False)
+
+
+def test_param_var():
+    v0 = rx.Var("v0", R.Tensor([m, n], "float32"))
+    v1 = rx.Var("v1", R.Tensor([m, n], "float32"))
+    v2 = rx.Var("v2", R.Tensor([m, n], "float32"))
+    bb = rx.BlockBuilder()
+    with bb.function("func1", [v0, v1]):
+        gv0 = bb.emit(rx.op.add(v0, v1))
+        bb.emit_func_output(gv0)
+    with bb.function("func2", [v0, v2]):
+        gv0 = bb.emit(rx.op.add(v2, v1))
+        bb.emit_func_output(gv0)
+    mod = bb.get()
+    assert not rx.analysis.well_formed(mod, check_struct_info=False)
+
+
+def test_global_var():
+    # Error: GlobalVar GlobalVar0 is not defined
+    gv0 = rx.Var("gv0", R.Tensor([m, n], "float32"))
+    globalvar = rx.GlobalVar("GlobalVar0")
+    call_node = rx.Call(
+        op=tvm.ir.Op.get("relax.call_tir"),
+        args=[globalvar, rx.Tuple([x]), rx.ShapeExpr([m, n])],
+    )
+    bindings = [rx.VarBinding(gv0, call_node)]
+    blocks = [rx.BindingBlock(bindings)]
+    func = build_function(blocks)
+    mod = tvm.IRModule({rx.GlobalVar("foo"): func})
+    assert not rx.analysis.well_formed(mod, check_struct_info=False)
+
+
+def test_symbolic_var():
+    # Error: Symbolic Var new_s is not defined
+    new_s = tir.Var("new_s", "int64")
+    gv0 = rx.Var("gv0", R.Tensor([m, new_s], "int64"))
+    call_node = rx.op.add(x, x)
+    bindings = [rx.VarBinding(gv0, call_node)]
+    blocks = [rx.BindingBlock(bindings)]
+    func = build_function(blocks)
+    mod = tvm.IRModule({rx.GlobalVar("foo"): func})
+    assert not rx.analysis.well_formed(mod, check_struct_info=False)
+
+
+def test_symbolic_var_invalid_type():
+    with pytest.raises(
+        tvm.TVMError, match="the value in ShapeStructInfo can only have dtype 
of int64"
+    ):
+        dim = tir.Var("dim", "float32")
+        y = rx.Var("y", R.Tensor([dim], "float32"))
+        gv0 = rx.Var("gv0", R.Tensor([dim], "float32"))
+        call_node = rx.op.add(y, y)
+        bindings = [rx.VarBinding(gv0, call_node)]
+        blocks = [rx.BindingBlock(bindings)]
+        func = build_function(blocks, [y])
+        mod = tvm.IRModule({rx.GlobalVar("foo"): func})
+        assert not rx.analysis.well_formed(mod, check_struct_info=False)
+
+
+def test_seq_expr():
+    # Error: SeqExpr in VarBinding
+    gv0 = rx.Var("gv0", R.Tensor([m, n], "float32"))
+    # build a SeqExpr
+    gv1 = rx.Var("gv1", R.Tensor([m, n], "float32"))
+    call_node = rx.op.add(x, gv0)
+    _bindings = [rx.VarBinding(gv1, call_node)]
+    _blocks = [rx.BindingBlock(_bindings)]
+    _seq_expr = rx.SeqExpr(_blocks, gv1)
+    # build a Binding with the SeqExpr as value
+    bindings = [rx.VarBinding(gv0, _seq_expr)]
+    blocks = [rx.BindingBlock(bindings)]
+    func = build_function(blocks)
+    mod = tvm.IRModule({rx.GlobalVar("foo"): func})
+    assert not rx.analysis.well_formed(mod, check_struct_info=False)
+
+
def test_if():
    """Vars bound inside an If branch (other than the branch result) are not
    visible outside the If expression."""
    # v_in_if is only defined inside the branches; gv0 is each branch's result.
    v_in_if = rx.Var("v_in_if", R.Tensor([m, n], "float32"))
    gv0 = rx.Var("gv0", R.Tensor([m, n], "float32"))

    def make_branch(first, second):
        # One block binding v_in_if then gv0, returning the last bound var.
        block = rx.BindingBlock(
            [rx.VarBinding(v_in_if, first), rx.VarBinding(gv0, second)]
        )
        return rx.SeqExpr([block], block.bindings[-1].var)

    if_node = rx.If(
        cond=cond,
        true_branch=make_branch(rx.op.add(x, x), rx.op.multiply(x, x)),
        false_branch=make_branch(rx.op.multiply(x, x), rx.op.add(x, x)),
    )
    gv1 = rx.Var("gv1", R.Tensor([m, n], "float32"))
    # Error: v_in_if is referenced after the If, outside its defining scope.
    bindings = [rx.VarBinding(gv0, if_node), rx.VarBinding(gv1, v_in_if)]
    func = build_function([rx.BindingBlock(bindings)])
    mod = tvm.IRModule({rx.GlobalVar("foo"): func})
    assert not rx.analysis.well_formed(mod, check_struct_info=True)
+
+
def test_if_non_seq_body():
    """Both branches of an If must be SeqExpr nodes."""

    def bind_if(if_expr):
        # Wrap the If in a function binding it to a fresh gv1.
        gv1 = rx.Var("gv1", R.Tensor([m, n], "float32"))
        blocks = [rx.BindingBlock([rx.VarBinding(gv1, if_expr)])]
        return tvm.IRModule.from_expr(build_function(blocks))

    # Error: branches are bare vars rather than SeqExprs.
    bad_if = rx.If(cond=cond, true_branch=x, false_branch=x)
    assert not rx.analysis.well_formed(bind_if(bad_if), check_struct_info=False)

    # The same body wrapped in a SeqExpr is well-formed.
    seq = rx.SeqExpr([], x)
    good_if = rx.If(cond=cond, true_branch=seq, false_branch=seq)
    # Normalize first so checked_type_ is populated for the strict check.
    normalized = rx.transform.Normalize()(bind_if(good_if))
    assert rx.analysis.well_formed(normalized, check_struct_info=True)
+
+
def test_if_complex_condition():
    """The condition of an If must be a leaf expression."""
    cond_tuple = rx.Tuple([cond])
    cond_idx = rx.TupleGetItem(cond_tuple, 0)

    def wrap(if_expr, extra_bindings=()):
        # Bind the If to a fresh gv1, optionally after extra bindings.
        gv1 = rx.Var("gv1", R.Tensor([m, n], "float32"))
        bindings = list(extra_bindings) + [rx.VarBinding(gv1, if_expr)]
        return tvm.IRModule.from_expr(build_function([rx.BindingBlock(bindings)]))

    # Error: a TupleGetItem used directly as the condition is not a leaf.
    bad_if = rx.If(cond_idx, rx.SeqExpr([], x), rx.SeqExpr([], x))
    assert not rx.analysis.well_formed(wrap(bad_if), check_struct_info=False)

    # Binding the condition to a var first makes the If well-formed.
    cond_var = rx.Var("q", R.Tensor([], "bool"))
    good_if = rx.If(cond_var, rx.SeqExpr([], x), rx.SeqExpr([], x))
    mod = wrap(good_if, [rx.VarBinding(cond_var, cond_idx)])
    # Normalize first so checked_type_ is populated for the strict check.
    normalized = rx.transform.Normalize()(mod)
    assert rx.analysis.well_formed(normalized, check_struct_info=True)
+
+
def test_tuple_get_item_nested():
    """The tuple operand of TupleGetItem must be a leaf expression."""
    nested_tup = rx.Var(
        "t",
        rx.TupleStructInfo([rx.TupleStructInfo([rx.TensorStructInfo([], "int32")])]),
    )
    ret_var = rx.Var("r", R.Tensor([], "int32"))

    def make_func(bindings, symbol):
        # Single-parameter function returning ret_var from one binding block.
        func = rx.Function(
            [nested_tup],
            rx.SeqExpr([rx.BindingBlock(bindings)], ret_var),
            ret_struct_info=R.Tensor(ndim=0, dtype="int32"),
        )
        return func.with_attr("global_symbol", symbol)

    # Error: nested TupleGetItem with no intermediate binding.
    double_idx = rx.TupleGetItem(rx.TupleGetItem(nested_tup, 0), 0)
    mod = tvm.IRModule.from_expr(make_func([rx.VarBinding(ret_var, double_idx)], "f"))
    assert not rx.analysis.well_formed(mod, check_struct_info=False)

    # Binding the intermediate tuple to a var makes it well-formed.
    idx_var = rx.Var("v", rx.TupleStructInfo([rx.TensorStructInfo([], "int32")]))
    bindings = [
        rx.VarBinding(idx_var, rx.TupleGetItem(nested_tup, 0)),
        rx.VarBinding(ret_var, rx.TupleGetItem(idx_var, 0)),
    ]
    mod = tvm.IRModule.from_expr(make_func(bindings, "new_f"))
    # Normalize first so checked types are populated for the strict check.
    normalized = rx.transform.Normalize()(mod)
    assert rx.analysis.well_formed(normalized, check_struct_info=True)
+
+
def test_complex_seq_body():
    """A SeqExpr body must be a leaf expression, not e.g. a Call."""
    x = rx.Var("x", R.Tensor([], "int32"))
    y = rx.Var("y", R.Tensor([], "int32"))

    def make_mod(blocks, body):
        # Two-parameter scalar function with the given blocks and seq body.
        func = rx.Function(
            [x, y],
            rx.SeqExpr(blocks, body),
            R.Tensor(ndim=0, dtype="int32"),
        ).with_attr("global_symbol", "foo")
        return tvm.IRModule.from_expr(func)

    # Error: the seq body is a Call, which is not a leaf expression.
    bad_mod = make_mod([], rx.op.add(x, y))
    assert not rx.analysis.well_formed(bad_mod, check_struct_info=False)

    # Binding the call result to z and returning z is well-formed.
    z = rx.Var("z", R.Tensor([], "int32"))
    blocks = [rx.BindingBlock([rx.VarBinding(var=z, value=rx.op.add(x, y))])]
    # Normalize first so checked types are populated for the strict check.
    normalized = rx.transform.Normalize()(make_mod(blocks, z))
    assert rx.analysis.well_formed(normalized, check_struct_info=True)
+
+
def test_ANF():
    """Expressions nested inside other expressions violate A-normal form."""
    # Error: a call nested directly inside another call.
    gv0 = rx.Var("gv0", R.Tensor([m, n], "float32"))
    nested_call = rx.op.add(x, rx.op.add(x, x))
    blocks = [rx.BindingBlock([rx.VarBinding(gv0, nested_call)])]
    mod = tvm.IRModule({rx.GlobalVar("foo"): build_function(blocks)})
    assert not rx.analysis.well_formed(mod, check_struct_info=False)

    # Error: a call nested inside a tuple literal.
    gv0 = rx.Var("gv0", R.Tensor([m, n], "float32"))
    tup_binding = rx.VarBinding(gv0, rx.Tuple((x, rx.op.add(x, x))))
    mod = tvm.IRModule({rx.GlobalVar("foo"): build_function([rx.BindingBlock([tup_binding])])})
    assert not rx.analysis.well_formed(mod, check_struct_info=False)
+
+
def test_global_var_vs_gsymbol():
    """The "global_symbol" attribute must match the registering GlobalVar."""
    gv0 = rx.Var("gv0", R.Tensor([m, n], "float32"))
    blocks = [rx.DataflowBlock([rx.VarBinding(gv0, x)])]
    # Error: attribute says "main1" while the module registers it as "main".
    func = rx.Function(
        [x],
        rx.SeqExpr(blocks, gv0),
        R.Tensor(ndim=2, dtype="float32"),
    ).with_attr("global_symbol", "main1")
    mod = tvm.IRModule({rx.GlobalVar("main"): func})
    assert not rx.analysis.well_formed(mod, check_struct_info=False)
+
+
def test_nested_dataflow():
    """A dataflow block may bind an inner function that contains its own
    dataflow block; the resulting module is well-formed."""
    scalar_sinfo = rx.TensorStructInfo(shape=[], dtype="int32")
    gv0 = rx.Var("gv0", scalar_sinfo)
    f = rx.DataflowVar("f", rx.FuncStructInfo([], scalar_sinfo))
    x0 = rx.DataflowVar("x0", scalar_sinfo)
    x1 = rx.DataflowVar("x1", scalar_sinfo)
    x2 = rx.DataflowVar("x2", scalar_sinfo)
    y = rx.Var("y", scalar_sinfo)
    # Inner function: binds x0 = 2, y = x0, and returns y.
    inner_block = rx.DataflowBlock(
        [rx.VarBinding(x0, rx.const(2, "int32")), rx.VarBinding(y, x0)]
    )
    inner_func = rx.Function([], rx.SeqExpr([inner_block], y), scalar_sinfo)
    # Outer function binds the inner function to f and calls it.
    outer_block = rx.DataflowBlock(
        [
            rx.VarBinding(x1, rx.const(1, "int32")),
            rx.VarBinding(f, inner_func),
            rx.VarBinding(x2, rx.op.add(x1, rx.Call(f, []))),
            rx.VarBinding(gv0, x2),
        ]
    )
    func = rx.Function([], rx.SeqExpr([outer_block], gv0), scalar_sinfo)
    normalized = rx.transform.Normalize()(tvm.IRModule.from_expr(func))
    assert rx.analysis.well_formed(normalized)
+
+
def test_sinfo_args_tir_var_used_before_define_call_packed():
    """Symbolic vars appearing only in call_packed sinfo_args are undefined."""
    m1 = tir.Var("m1", "int64")
    n1 = tir.Var("n1", "int64")
    # Error: m1/n1 are used in the sinfo_args without ever being defined.
    call = R.call_packed("my_func", x, sinfo_args=R.Tensor((m1, n1), "float32"))
    block = rx.BindingBlock([rx.VarBinding(rx.Var("gv"), call)])
    func = build_function([block])
    mod = rx.transform.Normalize()(tvm.IRModule.from_expr(func))
    assert not rx.analysis.well_formed(mod, check_struct_info=False)
+
+
def test_sinfo_args_tir_var_used_before_define_call_tir():
    """Symbolic vars appearing only in call_tir out_sinfo are undefined."""
    m1 = tir.Var("m1", "int64")
    n1 = tir.Var("n1", "int64")
    # Error: m1/n1 are used in the out_sinfo without ever being defined.
    call = R.call_tir("my_func", x, out_sinfo=R.Tensor((m1, n1), "float32"))
    block = rx.BindingBlock([rx.VarBinding(rx.Var("gv"), call)])
    func = build_function([block])
    mod = rx.transform.Normalize()(tvm.IRModule.from_expr(func))
    assert not rx.analysis.well_formed(mod, check_struct_info=False)
+
+
if __name__ == "__main__":
    # Run this file's tests when invoked directly as a script.
    tvm.testing.main()


Reply via email to