This is an automated email from the ASF dual-hosted git repository.

ruihangl pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new 9ddd39be56 [Unity][UX] Symbolic Variables Used in Multiple Functions 
(#14606)
9ddd39be56 is described below

commit 9ddd39be56bcd62905269d6386b44e9989639550
Author: Chaofan Lin <[email protected]>
AuthorDate: Fri Apr 14 01:37:41 2023 +0800

    [Unity][UX] Symbolic Variables Used in Multiple Functions (#14606)
    
    Prior to this PR, there was no constraint to prevent users from defining multiple 
functions that use the same symbolic TIR var. For example, a user may write 
the following script:
    ```
    batch_size = tir.Var("batch_size", "int64")
    
    @I.ir_module
    class Test:
        @R.function
        def main(
            x: R.Tensor((batch_size, 10), "float32"),
        ):
            with R.dataflow():
                lv = R.sum(x, axis=1)
                lv1 = R.mean(x, axis=1)
                out = R.add(lv, lv1)
                R.output(out)
            return out
    
        @R.function
        def main1(
            x: R.Tensor((batch_size, 10), "float32"),
            y: R.Tensor((batch_size, 10), "float32"),
        ):
            with R.dataflow():
                out = R.subtract(x, y)
                R.output(out)
            return out
    
    lowered_mod = LegalizeOps()(Test)
    
    ex = relax.build(lowered_mod, "llvm")
    vm = relax.VirtualMachine(ex, tvm.cpu(0))
    ```
    But this script cannot be built because `VMShapeLower` cannot handle this 
case well.
    Since this is illegal behaviour, we need to prevent users from doing it. 
Specifically, this PR contains two parts of work:
    - Make the well-formed checker detect this case.
    - Make the `CopyWithNewVars` util copy the symbolic vars in the struct info 
inside the function.
---
 src/relax/analysis/well_formed.cc               | 57 ++++++++++++++++++-----
 src/relax/transform/fuse_ops.cc                 | 40 +---------------
 src/relax/transform/utils.h                     | 62 +++++++++++++++++++++++++
 src/relax/utils.cc                              | 11 ++++-
 tests/python/relax/test_analysis_well_formed.py | 14 ++++++
 tests/python/relax/test_utils.py                | 15 ++++++
 6 files changed, 147 insertions(+), 52 deletions(-)

diff --git a/src/relax/analysis/well_formed.cc 
b/src/relax/analysis/well_formed.cc
index 3eeefd0be5..aeae975bf5 100644
--- a/src/relax/analysis/well_formed.cc
+++ b/src/relax/analysis/well_formed.cc
@@ -30,16 +30,17 @@
  *    3. When a Function has a corresponding GlobalVar and a `global_symbol`
  *       attribute, the name of the GlobalVar must equal the value of the
  *       `global_symbol` attribute value.
- *    4. Any variable cannot used as different function parameters in the same 
IRModule
- *    5. Vars are defined before use.
- *    6. Vars are defined exactly once.
- *    7. Symbolic Vars are defined before use.
- *    8. DataflowVars cannot be defined inside BindingBlock.
- *    9. Vars defined in IfNode, except the return Var, are invisible
+ *    4. Any variable cannot used as different function parameters in the same 
IRModule.
+ *    5. Any symbolic var cannot present across different functions in the 
same IRModule.
+ *    6. Vars are defined before use.
+ *    7. Vars are defined exactly once.
+ *    8. Symbolic Vars are defined before use.
+ *    9. DataflowVars cannot be defined inside BindingBlock.
+ *    10. Vars defined in IfNode, except the return Var, are invisible
  *       out of the If body.(May change for new AST designs)
- *    10. SeqExpr only serves as function body, or in the true and
+ *    11. SeqExpr only serves as function body, or in the true and
  *       false branches in IfNode.
- *    11. The IR is in ANF:
+ *    12. The IR is in ANF:
  *       (a) Expressions cannot contain nested complex expressions.
  *           Here are the expressions that may be nested inside other 
expressions:
  *           Var, DataflowVar, GlobalVar, Constant, ShapeExpr,
@@ -54,7 +55,7 @@
  *           * The cond field of If nodes
  *           * The op or args fields of Call nodes
  *           * Inside the fields of Tuple nodes
- *    12. Expr always has checked_type_ (with the exception of Op).
+ *    13. Expr always has checked_type_ (with the exception of Op).
  */
 #include <tvm/relax/analysis.h>
 #include <tvm/relax/expr.h>
@@ -92,7 +93,7 @@ class WellFormedChecker : public relax::ExprVisitor,
 
  private:
   explicit WellFormedChecker(IRModule mod, bool check_struct_info)
-      : mod_(std::move(mod)), check_struct_info_(check_struct_info) {}
+      : mod_(std::move(mod)), check_struct_info_(check_struct_info), 
cur_visited_func_(nullptr) {}
 
   using relax::ExprVisitor::VisitExpr_;
   using tir::ExprVisitor::VisitExpr;
@@ -196,6 +197,12 @@ class WellFormedChecker : public relax::ExprVisitor,
   }
 
   void VisitExpr_(const FunctionNode* op) final {
+    // set current visited function.
+    // for nested functions, we only set the outermost function.
+    if (cur_visited_func_ == nullptr) {
+      cur_visited_func_ = op;
+    }
+
     // save the var_set_ for local function
     auto prev_var_set = var_set_;
     auto prev_dataflow_var_set = dataflow_var_set_;
@@ -223,7 +230,7 @@ class WellFormedChecker : public relax::ExprVisitor,
                   << "Relax variable " << param->name_hint()
                   << " is repeatedly used as parameters in function.");
       }
-      param_var_func_map_.insert({param, GetRef<Function>(op)});
+      param_var_func_map_.insert({param, cur_visited_func_});
     }
     // check function ret_struct_info
     if (op->ret_struct_info.defined()) {
@@ -242,11 +249,18 @@ class WellFormedChecker : public relax::ExprVisitor,
     dataflow_var_set_ = prev_dataflow_var_set;
     var_set_ = prev_var_set;
     symbolic_var_set_ = prev_symbolic_var_set;
+
+    if (cur_visited_func_ == op) {
+      cur_visited_func_ = nullptr;
+    }
   }
 
   void VisitExpr_(const CallNode* op) final {
     if (IsLeafOrTuple(op->op)) {
+      const FunctionNode* prev_visited_func = cur_visited_func_;
+      cur_visited_func_ = nullptr;  // close the symbolic var dup check
       this->VisitExpr(op->op);
+      cur_visited_func_ = prev_visited_func;
     } else {
       Malformed(Diagnostic::Error(op) << "The called expression must be a leaf 
expression");
     }
@@ -400,6 +414,21 @@ class WellFormedChecker : public relax::ExprVisitor,
       this->Malformed(Diagnostic::Error(var)
                       << "Symbolic Var " << var->name_hint << " is not 
defined.");
     }
+
+    // don't perform the check
+    if (cur_visited_func_ == nullptr) {
+      return;
+    }
+
+    // check across functions presence
+    auto it = symbolic_var_func_map_.find(var);
+    if (it != symbolic_var_func_map_.end() && it->second != cur_visited_func_) 
{
+      // TODO(relax-team): Complete this error info after we integrate printer
+      Malformed(Diagnostic::Error(var->span)
+                << "Symbolic Var " << var->name_hint
+                << " presents in different functions in the same Module.");
+    }
+    symbolic_var_func_map_.insert({var, cur_visited_func_});
   }
 
   void VisitStructInfo_(const FuncStructInfoNode* op) final {
@@ -473,6 +502,8 @@ class WellFormedChecker : public relax::ExprVisitor,
   const bool check_struct_info_;
   bool well_formed_ = true;
   bool is_dataflow_;
+  // Current visited function.
+  const FunctionNode* cur_visited_func_;
   // Current visit mode.
   VisitMode mode_ = VisitMode::kDefault;
   // set of context variables.
@@ -480,7 +511,9 @@ class WellFormedChecker : public relax::ExprVisitor,
   std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual> recur_vars_;
   std::unordered_set<DataflowVar, ObjectPtrHash, ObjectPtrEqual> 
dataflow_var_set_;
   std::unordered_set<tir::Var, ObjectPtrHash, ObjectPtrEqual> 
symbolic_var_set_;
-  std::unordered_map<Var, Function, ObjectPtrHash, ObjectPtrEqual> 
param_var_func_map_;
+  std::unordered_map<Var, const FunctionNode*, ObjectPtrHash, ObjectPtrEqual> 
param_var_func_map_;
+  std::unordered_map<tir::Var, const FunctionNode*, ObjectPtrHash, 
ObjectPtrEqual>
+      symbolic_var_func_map_;
 };
 
 bool WellFormed(IRModule m, bool check_struct_info) {
diff --git a/src/relax/transform/fuse_ops.cc b/src/relax/transform/fuse_ops.cc
index a49ae86267..adce61f4b8 100644
--- a/src/relax/transform/fuse_ops.cc
+++ b/src/relax/transform/fuse_ops.cc
@@ -42,6 +42,7 @@
 #include "../../relay/analysis/graph_partitioner.h"
 #include "../../support/arena.h"
 #include "tvm/relax/expr.h"
+#include "utils.h"
 
 namespace tvm {
 namespace relax {
@@ -345,45 +346,6 @@ class GraphCreator : public ExprVisitor {
   std::unordered_set<IndexedForwardGraph::Node*> initialized_nodes_;
 };
 
-/*!
- * \brief Renew the definition of symbolic vars in Relax.
- * \details This mutator is used to prevent the same symbolic var from being 
used in different
- *          functions, which is malformed.
- */
-class SymbolicVarRenewMutator : public ExprMutator, tir::ExprMutator {
- public:
-  static Function Renew(const Function& function) {
-    SymbolicVarRenewMutator mutator;
-    return Downcast<Function>(mutator.VisitExpr(function));
-  }
-
- private:
-  SymbolicVarRenewMutator() = default;
-  using relax::ExprMutator::VisitExpr;
-  using relax::ExprMutator::VisitExpr_;
-  using tir::ExprMutator::VisitExpr_;
-
-  PrimExpr VisitPrimExpr(const PrimExpr& expr) final { return 
tir::ExprMutator::VisitExpr(expr); }
-
-  // TODO(Siyuan): enhance the method to the following steps:
-  // 1. Visit and replace all tir::Vars at the definition point
-  // 2. Revisit the function again and update the use side.
-  PrimExpr VisitExpr_(const tir::VarNode* op) final {
-    auto it = var_map_.find(GetRef<tir::Var>(op));
-    if (it != var_map_.end()) {
-      return (*it).second;
-    } else {
-      auto n = make_object<tir::VarNode>(*op);
-      tir::Var v(n);
-      var_map_.Set(GetRef<tir::Var>(op), v);
-      return v;
-    }
-  }
-
- private:
-  Map<tir::Var, tir::Var> var_map_;
-};
-
 /*!
  * \brief The ExprMutator used to create a new grouped function
  * \details The workflow of this ExprMutator is:
diff --git a/src/relax/transform/utils.h b/src/relax/transform/utils.h
index 9334fd8347..5363bca68b 100644
--- a/src/relax/transform/utils.h
+++ b/src/relax/transform/utils.h
@@ -28,6 +28,7 @@
 #include <tvm/ir/module.h>
 #include <tvm/relax/expr.h>
 #include <tvm/relax/expr_functor.h>
+#include <tvm/tir/expr_functor.h>
 
 #include <algorithm>
 #include <string>
@@ -222,6 +223,67 @@ class VarReplacer : public ExprMutator {
   const VarMap& var_remap_;
 };
 
+/*!
+ * \brief Renew the definition of symbolic vars in Relax.
+ * \details This mutator is used to prevent the same symbolic var from being 
used in different
+ *          functions, which is malformed.
+ */
+class SymbolicVarRenewMutator : public ExprMutator, tir::ExprMutator {
+ public:
+  static Function Renew(const Function& function) {
+    SymbolicVarRenewMutator mutator;
+    return Downcast<Function>(mutator.VisitExpr(function));
+  }
+
+ private:
+  SymbolicVarRenewMutator() = default;
+  using relax::ExprMutator::VisitExpr;
+  using relax::ExprMutator::VisitExpr_;
+  using tir::ExprMutator::VisitExpr_;
+
+  PrimExpr VisitPrimExpr(const PrimExpr& expr) final { return 
tir::ExprMutator::VisitExpr(expr); }
+
+  // TODO(Siyuan): enhance the method to the following steps:
+  // 1. Visit and replace all tir::Vars at the definition point
+  // 2. Revisit the function again and update the use side.
+  PrimExpr VisitExpr_(const tir::VarNode* op) final {
+    auto it = var_map_.find(GetRef<tir::Var>(op));
+    if (it != var_map_.end()) {
+      return (*it).second;
+    } else {
+      auto n = make_object<tir::VarNode>(*op);
+      tir::Var v(n);
+      var_map_.Set(GetRef<tir::Var>(op), v);
+      return v;
+    }
+  }
+
+  Expr VisitExpr_(const FunctionNode* op) {
+    tvm::Array<Var> params;
+    bool all_params_unchanged = true;
+    for (Var param : op->params) {
+      Var new_param = this->VisitVarDef(param);
+      params.push_back(new_param);
+      if (!param.same_as(new_param)) {
+        var_remap_[param->vid] = new_param;
+        all_params_unchanged = false;
+      }
+    }
+
+    Expr body = this->VisitWithNewScope(op->body, params);
+
+    if (all_params_unchanged && body.same_as(op->body)) {
+      return GetRef<Expr>(op);
+    } else {
+      auto new_ret_sinfo = 
this->VisitExprDepStructInfoField(op->ret_struct_info);
+      return Function(params, body, new_ret_sinfo, op->attrs);
+    }
+  }
+
+ private:
+  Map<tir::Var, tir::Var> var_map_;
+};
+
 /*!
  * \brief Create a Constant with a scalar
  *
diff --git a/src/relax/utils.cc b/src/relax/utils.cc
index cf1d9bed98..131ed6c7d0 100644
--- a/src/relax/utils.cc
+++ b/src/relax/utils.cc
@@ -17,6 +17,8 @@
  * under the License.
  */
 
+#include "transform/utils.h"
+
 #include <tvm/relax/expr_functor.h>
 
 namespace tvm {
@@ -109,6 +111,7 @@ bool IsLeafOrTuple(const Expr& expr) {
          expr.as<OpNode>() || expr.as<TupleNode>();
 }
 
+/*! \brief Helper to implement CopyWithNewVars.*/
 class FunctionCopier : public ExprMutator {
  public:
   static Function Transform(Function func) {
@@ -116,7 +119,8 @@ class FunctionCopier : public ExprMutator {
     // All variables that are bound inside the original function would be 
copied
     // to satisfy the restriction in the well-formed check: Variables in Relax
     // must be bound exactly once.
-    return Downcast<Function>(copier.VisitExpr(func));
+    auto new_func = Downcast<Function>(copier.VisitExpr(func));
+    return SymbolicVarRenewMutator::Renew(new_func);
   }
 
   Var VisitVarDef_(const DataflowVarNode* var) override {
@@ -134,6 +138,11 @@ class FunctionCopier : public ExprMutator {
   }
 };
 
+/*!
+ * \brief Copy a new Relax function with new remapped vars and symbolic vars.
+ * \param func The Relax function we want to copy.
+ * \return The copied function.
+ */
 Function CopyWithNewVars(Function func) { return 
FunctionCopier::Transform(func); }
 
 TVM_REGISTER_GLOBAL("relax.CopyWithNewVars").set_body_typed(CopyWithNewVars);
diff --git a/tests/python/relax/test_analysis_well_formed.py 
b/tests/python/relax/test_analysis_well_formed.py
index b4b68504a4..97f076dc6c 100644
--- a/tests/python/relax/test_analysis_well_formed.py
+++ b/tests/python/relax/test_analysis_well_formed.py
@@ -142,6 +142,20 @@ def test_symbolic_var():
     assert not rx.analysis.well_formed(mod, check_struct_info=False)
 
 
+def test_symbolic_var_across_functions():
+    # Error: Symbolic Var s presents across different functions
+    s = tir.Var("s", "int64")
+    v0 = rx.Var("v0", R.Tensor([5, s], "float32"))
+    v1 = rx.Var("v1", R.Tensor([s, 7], "float32"))
+    bb = rx.BlockBuilder()
+    with bb.function("func1", [v0]):
+        bb.emit_func_output(v0)
+    with bb.function("func2", [v1]):
+        bb.emit_func_output(v1)
+    mod = bb.get()
+    assert not rx.analysis.well_formed(mod, check_struct_info=False)
+
+
 def test_symbolic_var_invalid_type():
     with pytest.raises(
         tvm.TVMError, match="the value in ShapeStructInfo can only have dtype 
of int64"
diff --git a/tests/python/relax/test_utils.py b/tests/python/relax/test_utils.py
index 15122dab37..c55876a3ba 100644
--- a/tests/python/relax/test_utils.py
+++ b/tests/python/relax/test_utils.py
@@ -36,6 +36,21 @@ def test_copy_with_new_vars():
         assert before_var != after_var
 
 
+def test_copy_with_new_vars_copied_symbolic_vars():
+    @R.function
+    def before(x: R.Tensor(("m",), "float32"), y: R.Tensor(("m",), "float32")):
+        gv = R.add(x, y)
+        return gv
+
+    after = relax.utils.copy_with_new_vars(before)
+    assert_structural_equal(after, before)
+
+    assert len(after.params) == len(before.params)
+    for before_var, after_var in zip(before.params, after.params):
+        assert before_var != after_var
+        assert before_var.struct_info.shape[0] != 
after_var.struct_info.shape[0]
+
+
 def test_copy_with_new_vars_on_ir_module():
     @tvm.script.ir_module
     class Actual:

Reply via email to