This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new 9102fab05c [Unity][Pass] Support Symbolic Shape Deduction during BindParam (#14154)
9102fab05c is described below

commit 9102fab05c1930f6ec3f2260d52a1f2492580a0f
Author: Siyuan Feng <[email protected]>
AuthorDate: Wed Mar 1 21:46:11 2023 +0800

    [Unity][Pass] Support Symbolic Shape Deduction during BindParam (#14154)
    
    `BindParam` replaces function params with constant nodes. However, it
    drops the shape information of the params. Consider the following case:
    
    ```python
    @R.function
    def main(
        x: R.Tensor(("batch", "m"), dtype="float32"),
        w0: R.Tensor(("n", "m"), dtype="float32"),
        b0: R.Tensor(("n",), dtype="float32"),
        w1: R.Tensor(("k", "n"), dtype="float32"),
        b1: R.Tensor(("k",), dtype="float32"),
    ) -> R.Tensor(("batch", "k"), dtype="float32"):
        batch = T.Var("batch", "int64")
        k = T.Var("k", "int64")
        m = T.Var("m", "int64")
        n = T.Var("n", "int64")
        with R.dataflow():
            lv0 = R.call_tir("linear0", (x, w0, b0), out_sinfo=R.Tensor((batch, 
n), dtype="float32"))
            out = R.call_tir("linear1", (lv0, w1, b1), 
out_sinfo=R.Tensor((batch, k), dtype="float32"))
            R.output(out)
        return out
    ```
    
    The current pass simply drops the symbolic vars `n` and `k`, causing
    undefined vars during build:
    ```python
    @R.function
    def main(x: R.Tensor((1, "m"), dtype="float32")) -> 
R.Tensor(dtype="float32", ndim=2):
        m = T.Var("m", "int64")
        n = T.Var("n", "int64")
        k = T.Var("k", "int64")
        with R.dataflow():
            lv0 = R.call_tir("linear0", (x, metadata["relax.expr.Constant"][0], 
metadata["relax.expr.Constant"][1]), out_sinfo=R.Tensor((1, n), 
dtype="float32"))
            out = R.call_tir("linear1", (lv0, 
metadata["relax.expr.Constant"][2], metadata["relax.expr.Constant"][3]), 
out_sinfo=R.Tensor((1, k), dtype="float32"))
            R.output(out)
        return out
    ```
    
    This PR updates the pass to deduce and bind the symbolic shapes while
    binding the parameters.
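    
    For illustration, here is a rough sketch (mirroring the new unit test
    below) of how the pass is applied and what the bound module is expected
    to look like, assuming the `main` function above is wrapped in an
    `@tvm.script.ir_module` class named `Before`; the resulting shapes are
    approximate, not the exact printed output:
    
    ```python
    import numpy as np
    import tvm
    from tvm import relax
    
    # Concrete weight shapes, as in the new test: m = 4, n = 6, k = 8.
    m, n, k = 4, 6, 8
    params_dict = {
        "w0": tvm.nd.array(np.random.rand(n, m).astype(np.float32)),
        "b0": tvm.nd.array(np.random.rand(n).astype(np.float32)),
        "w1": tvm.nd.array(np.random.rand(k, n).astype(np.float32)),
        "b1": tvm.nd.array(np.random.rand(k).astype(np.float32)),
    }
    mod = relax.transform.BindParams("main", params_dict)(Before)
    
    # Expected result (approximately): only `x` remains a parameter; `m`, `n`
    # and `k` are deduced from the constant weight shapes, while `batch`
    # stays symbolic, e.g.
    #   x:                R.Tensor(("batch", 4), "float32")
    #   lv0 out_sinfo:    R.Tensor((batch, 6), "float32")
    #   out out_sinfo /
    #   return type:      R.Tensor((batch, 8), "float32")
    ```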
---
 include/tvm/relax/expr_functor.h                 |  4 +-
 include/tvm/relax/utils.h                        |  4 +-
 src/relax/analysis/struct_info_analysis.cc       |  2 +-
 src/relax/ir/expr_functor.cc                     |  4 +-
 src/relax/transform/bind_params.cc               | 64 +++++++++++++++++++++--
 src/relax/utils.cc                               | 66 +++++++++++++++++-------
 tests/python/relax/test_frontend_from_fx.py      |  2 +-
 tests/python/relax/test_transform_bind_params.py | 52 +++++++++++++++++++
 8 files changed, 167 insertions(+), 31 deletions(-)

diff --git a/include/tvm/relax/expr_functor.h b/include/tvm/relax/expr_functor.h
index 655ecc52b6..ce209ccd46 100644
--- a/include/tvm/relax/expr_functor.h
+++ b/include/tvm/relax/expr_functor.h
@@ -306,7 +306,7 @@ class ExprVisitor : public ExprFunctor<void(const Expr&)> {
     ExprVisitor* parent_;
   };
   // This visitor is not visible to child classes and only
-  // used to supportd default visiting behavior.
+  // used to support default visiting behavior.
   DefaultStructInfoFieldVisitor default_struct_info_field_visitor_{this};
 };
 
@@ -409,7 +409,7 @@ class ExprMutatorBase : public ExprFunctor<Expr(const Expr&)> {
     ExprMutatorBase* parent_;
   };
   // This visitor is not visible to child classes and only
-  // used to supportd default visiting behavior.
+  // used to support default visiting behavior.
   DefaultStructInfoFieldMutator default_struct_info_field_mutator_{this};
 };
 
diff --git a/include/tvm/relax/utils.h b/include/tvm/relax/utils.h
index dd0200623a..e7d928c4ae 100644
--- a/include/tvm/relax/utils.h
+++ b/include/tvm/relax/utils.h
@@ -103,10 +103,12 @@ class NameTable {
  * \param expr The input expression.
  * \param binds The variable to expression map that will be used to help the
  *        binding.
+ * \param symbolic_var_map The map from symbolic var to the expr it binds to.
  *
  * \return The updated expression.
  */
-TVM_DLL Expr Bind(const Expr& expr, const tvm::Map<Var, Expr>& binds);
+TVM_DLL Expr Bind(const Expr& expr, const tvm::Map<Var, Expr>& binds,
+                  const tvm::Map<tir::Var, PrimExpr>& symbolic_var_map = {});
 
 /*!
 * \brief Check if the given StructInfo is for a boolean scalar (tensor of rank 0 with a boolean
diff --git a/src/relax/analysis/struct_info_analysis.cc b/src/relax/analysis/struct_info_analysis.cc
index 2de06fe5d6..7dfcd60c95 100644
--- a/src/relax/analysis/struct_info_analysis.cc
+++ b/src/relax/analysis/struct_info_analysis.cc
@@ -330,7 +330,7 @@ class StructInfoBaseChecker
       return BaseCheckResult::kFailL0;
     }
 
-    // ndim msiamtch
+    // ndim mismatch
     if (!lhs->IsUnknownNdim() && lhs->ndim != rhs->ndim) {
       if (rhs->IsUnknownNdim()) return BaseCheckResult::kFailL1;
       return BaseCheckResult::kFailL0;
diff --git a/src/relax/ir/expr_functor.cc b/src/relax/ir/expr_functor.cc
index 4c4b68f3d2..174d40053f 100644
--- a/src/relax/ir/expr_functor.cc
+++ b/src/relax/ir/expr_functor.cc
@@ -651,7 +651,7 @@ RELAX_EXPR_MUTATOR_VISIT_BINDING_IMPL(DataTypeImmNode);
 void ExprMutator::ReEmitBinding(const VarBindingNode* binding, Expr new_value) {
   Var new_var = this->VisitVarDef(binding->var);
 
-  // fast path: reemit binding if nothing changes
+  // fast path: re-emit binding if nothing changes
   if (new_var.same_as(binding->var) && new_value.same_as(binding->value)) {
     builder_->EmitNormalized(GetRef<VarBinding>(binding));
     return;
@@ -660,8 +660,8 @@ void ExprMutator::ReEmitBinding(const VarBindingNode* binding, Expr new_value) {
   Var temp = WithStructInfo(new_var, GetStructInfo(new_value));
   if (!temp.same_as(new_var)) {
     new_var = temp;
-    this->var_remap_[binding->var->vid] = new_var;
   }
+  this->var_remap_[binding->var->vid] = new_var;
 
   builder_->EmitNormalized(VarBinding(new_var, new_value));
 }
diff --git a/src/relax/transform/bind_params.cc b/src/relax/transform/bind_params.cc
index 1de8d94461..c444a84f44 100644
--- a/src/relax/transform/bind_params.cc
+++ b/src/relax/transform/bind_params.cc
@@ -30,6 +30,57 @@
 namespace tvm {
 namespace relax {
 
+void MatchSymbolicVar(const Expr& arg, const Expr& constant,
+                      Map<tir::Var, PrimExpr>* symbolic_var_map, arith::Analyzer* analyzer_) {
+  auto opt_arg_sinfo = MatchStructInfo<TensorStructInfo>(arg);
+  CHECK(opt_arg_sinfo)
+      << "The struct info of the bound parameter is expected to be TensorStructInfo, but got: "
+      << GetStructInfo(arg);
+  auto opt_const_sinfo = MatchStructInfo<TensorStructInfo>(constant);
+  // As the constant is generated by internal code, we use ICHECK here.
+  ICHECK(opt_const_sinfo)
+      << "The struct info of the bound weight is expected to be TensorStructInfo, but got: "
+      << GetStructInfo(constant);
+
+  TensorStructInfo arg_sinfo = opt_arg_sinfo.value();
+  TensorStructInfo const_sinfo = opt_const_sinfo.value();
+  ICHECK(!const_sinfo->IsUnknownDtype());
+  ICHECK(!const_sinfo->IsUnknownNdim());
+  ICHECK(const_sinfo->shape.defined());
+
+  // dtype mismatch
+  if (!arg_sinfo->IsUnknownDtype() && arg_sinfo->dtype != const_sinfo->dtype) {
+    LOG(FATAL) << "The dtype of the bound parameter is expected to be " << 
arg_sinfo->dtype
+               << ", but got: " << const_sinfo->dtype;
+  }
+  // ndim mismatch
+  if (!arg_sinfo->IsUnknownNdim() && arg_sinfo->ndim != const_sinfo->ndim) {
+    LOG(FATAL) << "The ndim of the bound parameter is expected to be " << 
arg_sinfo->ndim
+               << ", but got: " << const_sinfo->ndim;
+  }
+  if (!arg_sinfo->shape.defined()) return;
+  const auto* arg_shape = arg_sinfo->shape.value().as<ShapeExprNode>();
+  const auto* const_shape = const_sinfo->shape.value().as<ShapeExprNode>();
+
+  CHECK(arg_shape && const_shape)
+      << "The shape of the bound parameter and weight is expected to be ShapeExprNode for now";
+
+  for (int i = 0; i < arg_sinfo->ndim; ++i) {
+    const PrimExpr& const_dim = const_shape->values[i];
+    ICHECK(tir::is_const_int(const_dim));
+    if (const auto* shape_var = arg_shape->values[i].as<tir::VarNode>()) {
+      auto it = symbolic_var_map->find(GetRef<tir::Var>(shape_var));
+      if (it == symbolic_var_map->end()) {
+        symbolic_var_map->Set(GetRef<tir::Var>(shape_var), const_dim);
+      } else {
+        CHECK(analyzer_->CanProveEqual((*it).second, const_dim))
+            << "The shape of the bound parameter is expected to be " << (*it).second
+            << ", but got: " << const_dim;
+      }
+    }
+  }
+}
+
 /*!
  * \brief Bind params to function by using name
  * \param func Relax function
@@ -48,18 +99,23 @@ inline Function BindParamsByName(Function func, const Map<String, runtime::NDArr
     }
   }
 
-  std::unordered_map<Var, Expr, ObjectPtrHash, ObjectPtrEqual> bind_dict;
+  arith::Analyzer analyzer;
+  Map<Var, Expr> bind_dict;
+  Map<tir::Var, PrimExpr> symbolic_var_map;
+
   for (auto& kv : params) {
     if (name_dict.count(kv.first) == 0) {
       continue;
     }
-    auto arg = name_dict.at(kv.first);
+    const Var& arg = name_dict.at(kv.first);
     if (repeat_var.count(arg)) {
       LOG(FATAL) << "ValueError: Multiple args in the function have name " << 
kv.first;
     }
-    bind_dict[arg] = Constant(kv.second);
+    Expr const_expr = Constant(kv.second);
+    bind_dict.Set(arg, const_expr);
+    MatchSymbolicVar(arg, const_expr, &symbolic_var_map, &analyzer);
   }
-  Expr bound_expr = Bind(func, bind_dict);
+  Expr bound_expr = Bind(func, bind_dict, symbolic_var_map);
   Function ret = Downcast<Function>(bound_expr);
   ICHECK(ret.defined()) << "The returning type is expected to be a Relax 
Function."
                         << "\n";
diff --git a/src/relax/utils.cc b/src/relax/utils.cc
index 1cf64cbf64..cf1d9bed98 100644
--- a/src/relax/utils.cc
+++ b/src/relax/utils.cc
@@ -25,7 +25,36 @@ namespace relax {
 /*! \brief Helper to implement bind params.*/
 class ExprBinder : public ExprMutator {
  public:
-  explicit ExprBinder(const tvm::Map<Var, Expr>& args_map) : args_map_(args_map) {}
+  explicit ExprBinder(const tvm::Map<Var, Expr>& args_map,
+                      const tvm::Map<tir::Var, PrimExpr>& symbolic_var_map)
+      : args_map_(args_map), symbolic_var_map_(symbolic_var_map) {}
+
+ private:
+  Expr VisitExpr_(const FunctionNode* op) final {
+    tvm::Array<Var> params;
+    bool all_params_unchanged = true;
+    for (const Var& param : op->params) {
+      if (args_map_.count(param)) {
+        all_params_unchanged = false;
+      } else {
+        Var new_param = this->VisitVarDef(param);
+        params.push_back(new_param);
+        if (!param.same_as(new_param)) {
+          this->var_remap_[param->vid] = new_param;
+          all_params_unchanged = false;
+        }
+      }
+    }
+
+    Expr body = this->VisitWithNewScope(op->body, params);
+
+    // FuncStructInfo does not depend on Expr
+    if (all_params_unchanged && body.same_as(op->body)) {
+      return GetRef<Expr>(op);
+    } else {
+      return Function(params, body, VisitExprDepStructInfoField(op->ret_struct_info), op->attrs);
+    }
+  }
 
   Expr VisitExpr_(const VarNode* op) final {
     auto id = GetRef<Var>(op);
@@ -37,34 +66,31 @@ class ExprBinder : public ExprMutator {
     }
   }
 
+  PrimExpr VisitPrimExpr(const PrimExpr& expr) final {
+    if (const tir::VarNode* var = expr.as<tir::VarNode>()) {
+      auto it = symbolic_var_map_.find(GetRef<tir::Var>(var));
+      if (it != symbolic_var_map_.end()) {
+        return (*it).second;
+      }
+    }
+    return ExprMutator::VisitPrimExpr(expr);
+  }
+
  private:
   const tvm::Map<Var, Expr>& args_map_;
+  const tvm::Map<tir::Var, PrimExpr>& symbolic_var_map_;
 };
 
 /*!
  * \brief Bind params on expr
  * \param expr The expr where to bind params
- * \param args_map The map from param var to the expr it binds to
+ * \param binds The map from param var to the expr it binds to
+ * \param symbolic_var_map The map from symbolic var to the expr it binds to
  * \return The result expr after bind params
  */
-Expr Bind(const Expr& expr, const tvm::Map<Var, Expr>& args_map) {
-  if (const FunctionNode* func = expr.as<FunctionNode>()) {
-    Expr new_body = ExprBinder(args_map).VisitExpr(func->body);
-    Array<Var> new_params;
-    for (size_t i = 0; i < func->params.size(); ++i) {
-      if (!args_map.count(func->params[i])) {
-        new_params.push_back(func->params[i]);
-      }
-    }
-    if (new_body.same_as(func->body) && new_params.size() == func->params.size()) {
-      return expr;
-    }
-    // The checked_type_ of the new function is deduced from the function body
-    // TODO(@relax-team): Should infer the shape from the body as well
-    return Function(new_params, new_body, NullOpt, func->attrs);
-  } else {
-    return ExprBinder(args_map).VisitExpr(expr);
-  }
+Expr Bind(const Expr& expr, const tvm::Map<Var, Expr>& binds,
+          const tvm::Map<tir::Var, PrimExpr>& symbolic_var_map) {
+  return ExprBinder(binds, symbolic_var_map).VisitExpr(expr);
 }
 
 bool IsBoolStructInfo(const StructInfo& sinfo, bool permit_unknown_rank,
diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py
index 24ed9946a3..e216010667 100644
--- a/tests/python/relax/test_frontend_from_fx.py
+++ b/tests/python/relax/test_frontend_from_fx.py
@@ -142,7 +142,7 @@ def test_linear():
         def main(
             input_1: R.Tensor((1, 3, 10, 10), dtype="float32"),
             w1: R.Tensor((7, 10), dtype="float32"),
-            w2: R.Tensor((1, 7), dtype="float32"),
+            w2: R.Tensor((7,), dtype="float32"),
         ) -> R.Tensor((1, 3, 10, 7), dtype="float32"):
             # block 0
             with R.dataflow():
diff --git a/tests/python/relax/test_transform_bind_params.py b/tests/python/relax/test_transform_bind_params.py
index ceaf8fb165..1dfd9e0c8e 100644
--- a/tests/python/relax/test_transform_bind_params.py
+++ b/tests/python/relax/test_transform_bind_params.py
@@ -71,5 +71,57 @@ def test_bind_params(use_np_array):
     tvm.testing.assert_allclose(res_before.numpy(), res_after.numpy())
 
 
+def test_bind_params_symbolic_vars():
+    @tvm.script.ir_module
+    class Before:
+        @R.function
+        def main(
+            x: R.Tensor(("batch", "m"), dtype="float32"),
+            w0: R.Tensor(("n", "m"), dtype="float32"),
+            b0: R.Tensor(("n",), dtype="float32"),
+            w1: R.Tensor(("k", "n"), dtype="float32"),
+            b1: R.Tensor(("k",), dtype="float32"),
+        ) -> R.Tensor(("batch", "k"), dtype="float32"):
+            batch = T.Var("batch", "int64")
+            k = T.Var("k", "int64")
+            m = T.Var("m", "int64")
+            n = T.Var("n", "int64")
+            with R.dataflow():
+                lv0 = R.call_tir(
+                    "linear0", (x, w0, b0), out_sinfo=R.Tensor((batch, n), 
dtype="float32")
+                )
+                out = R.call_tir(
+                    "linear1", (lv0, w1, b1), out_sinfo=R.Tensor((batch, k), 
dtype="float32")
+                )
+                R.output(out)
+            return out
+
+    m, n, k = 4, 6, 8
+    w0_tvm = tvm.nd.array(np.random.rand(n, m).astype(np.float32))
+    b0_tvm = tvm.nd.array(np.random.rand(n).astype(np.float32))
+    w1_tvm = tvm.nd.array(np.random.rand(k, n).astype(np.float32))
+    b1_tvm = tvm.nd.array(np.random.rand(k).astype(np.float32))
+    params_dict = {"w0": w0_tvm, "b0": b0_tvm, "w1": w1_tvm, "b1": b1_tvm}
+    mod = relax.transform.BindParams("main", params_dict)(Before)
+
+    # Since it contains ConstantNode, it's hard to check with structural equality.
+    func = mod["main"]
+    assert len(func.params) == 1
+    batch = func.params[0].struct_info.shape[0]
+    tvm.ir.assert_structural_equal(
+        func.params[0].struct_info, relax.TensorStructInfo((batch, 4), "float32")
+    )
+    tvm.ir.assert_structural_equal(
+        func.ret_struct_info, relax.TensorStructInfo((batch, 8), "float32")
+    )
+    bindings = func.body.blocks[0].bindings
+    tvm.ir.assert_structural_equal(
+        bindings[0].var.struct_info, relax.TensorStructInfo((batch, 6), "float32")
+    )
+    tvm.ir.assert_structural_equal(
+        bindings[1].var.struct_info, relax.TensorStructInfo((batch, 8), "float32")
+    )
+
+
 if __name__ == "__main__":
     tvm.testing.main()
