This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new 9102fab05c [Unity][Pass] Support Symbolic Shape Deduction during
BindParam (#14154)
9102fab05c is described below
commit 9102fab05c1930f6ec3f2260d52a1f2492580a0f
Author: Siyuan Feng <[email protected]>
AuthorDate: Wed Mar 1 21:46:11 2023 +0800
[Unity][Pass] Support Symbolic Shape Deduction during BindParam (#14154)
`BindParam` replaces function params with constant nodes. However, it will
drop the shape information of the params. Consider the following case:
```python
@R.function
def main(
x: R.Tensor(("batch", "m"), dtype="float32"),
w0: R.Tensor(("n", "m"), dtype="float32"),
b0: R.Tensor(("n",), dtype="float32"),
w1: R.Tensor(("k", "n"), dtype="float32"),
b1: R.Tensor(("k",), dtype="float32"),
) -> R.Tensor(("batch", "k"), dtype="float32"):
batch = T.Var("batch", "int64")
k = T.Var("k", "int64")
m = T.Var("m", "int64")
n = T.Var("n", "int64")
with R.dataflow():
lv0 = R.call_tir("linear0", (x, w0, b0), out_sinfo=R.Tensor((batch,
n), dtype="float32"))
out = R.call_tir("linear1", (lv0, w1, b1),
out_sinfo=R.Tensor((batch, k), dtype="float32"))
R.output(out)
return out
```
The current pass will simply drop the symbolic vars `n` and `k`, causing
undefined vars during build, as shown below:
```python
@R.function
def main(x: R.Tensor((1, "m"), dtype="float32")) ->
R.Tensor(dtype="float32", ndim=2):
m = T.Var("m", "int64")
n = T.Var("n", "int64")
k = T.Var("k", "int64")
with R.dataflow():
lv0 = R.call_tir("linear0", (x, metadata["relax.expr.Constant"][0],
metadata["relax.expr.Constant"][1]), out_sinfo=R.Tensor((1, n),
dtype="float32"))
out = R.call_tir("linear1", (lv0,
metadata["relax.expr.Constant"][2], metadata["relax.expr.Constant"][3]),
out_sinfo=R.Tensor((1, k), dtype="float32"))
R.output(out)
return out
```
This PR updates the pass to bind the symbolic shape during binding.
---
include/tvm/relax/expr_functor.h | 4 +-
include/tvm/relax/utils.h | 4 +-
src/relax/analysis/struct_info_analysis.cc | 2 +-
src/relax/ir/expr_functor.cc | 4 +-
src/relax/transform/bind_params.cc | 64 +++++++++++++++++++++--
src/relax/utils.cc | 66 +++++++++++++++++-------
tests/python/relax/test_frontend_from_fx.py | 2 +-
tests/python/relax/test_transform_bind_params.py | 52 +++++++++++++++++++
8 files changed, 167 insertions(+), 31 deletions(-)
diff --git a/include/tvm/relax/expr_functor.h b/include/tvm/relax/expr_functor.h
index 655ecc52b6..ce209ccd46 100644
--- a/include/tvm/relax/expr_functor.h
+++ b/include/tvm/relax/expr_functor.h
@@ -306,7 +306,7 @@ class ExprVisitor : public ExprFunctor<void(const Expr&)> {
ExprVisitor* parent_;
};
// This visitor is not visible to child classes and only
- // used to supportd default visiting behavior.
+ // used to supported default visiting behavior.
DefaultStructInfoFieldVisitor default_struct_info_field_visitor_{this};
};
@@ -409,7 +409,7 @@ class ExprMutatorBase : public ExprFunctor<Expr(const
Expr&)> {
ExprMutatorBase* parent_;
};
// This visitor is not visible to child classes and only
- // used to supportd default visiting behavior.
+ // used to supported default visiting behavior.
DefaultStructInfoFieldMutator default_struct_info_field_mutator_{this};
};
diff --git a/include/tvm/relax/utils.h b/include/tvm/relax/utils.h
index dd0200623a..e7d928c4ae 100644
--- a/include/tvm/relax/utils.h
+++ b/include/tvm/relax/utils.h
@@ -103,10 +103,12 @@ class NameTable {
* \param expr The input expression.
* \param binds The variable to expression map that will be used to help the
* binding.
+ * \param symbolic_var_map The map from symbolic var to the expr it binds to.
*
* \return The updated expression.
*/
-TVM_DLL Expr Bind(const Expr& expr, const tvm::Map<Var, Expr>& binds);
+TVM_DLL Expr Bind(const Expr& expr, const tvm::Map<Var, Expr>& binds,
+ const tvm::Map<tir::Var, PrimExpr>& symbolic_var_map = {});
/*!
* \brief Check if the given StructInfo is for a boolean scalar (tensor of
rank 0 with a boolean
diff --git a/src/relax/analysis/struct_info_analysis.cc
b/src/relax/analysis/struct_info_analysis.cc
index 2de06fe5d6..7dfcd60c95 100644
--- a/src/relax/analysis/struct_info_analysis.cc
+++ b/src/relax/analysis/struct_info_analysis.cc
@@ -330,7 +330,7 @@ class StructInfoBaseChecker
return BaseCheckResult::kFailL0;
}
- // ndim msiamtch
+ // ndim mismatch
if (!lhs->IsUnknownNdim() && lhs->ndim != rhs->ndim) {
if (rhs->IsUnknownNdim()) return BaseCheckResult::kFailL1;
return BaseCheckResult::kFailL0;
diff --git a/src/relax/ir/expr_functor.cc b/src/relax/ir/expr_functor.cc
index 4c4b68f3d2..174d40053f 100644
--- a/src/relax/ir/expr_functor.cc
+++ b/src/relax/ir/expr_functor.cc
@@ -651,7 +651,7 @@ RELAX_EXPR_MUTATOR_VISIT_BINDING_IMPL(DataTypeImmNode);
void ExprMutator::ReEmitBinding(const VarBindingNode* binding, Expr new_value)
{
Var new_var = this->VisitVarDef(binding->var);
- // fast path: reemit binding if nothing changes
+ // fast path: re-emit binding if nothing changes
if (new_var.same_as(binding->var) && new_value.same_as(binding->value)) {
builder_->EmitNormalized(GetRef<VarBinding>(binding));
return;
@@ -660,8 +660,8 @@ void ExprMutator::ReEmitBinding(const VarBindingNode*
binding, Expr new_value) {
Var temp = WithStructInfo(new_var, GetStructInfo(new_value));
if (!temp.same_as(new_var)) {
new_var = temp;
- this->var_remap_[binding->var->vid] = new_var;
}
+ this->var_remap_[binding->var->vid] = new_var;
builder_->EmitNormalized(VarBinding(new_var, new_value));
}
diff --git a/src/relax/transform/bind_params.cc
b/src/relax/transform/bind_params.cc
index 1de8d94461..c444a84f44 100644
--- a/src/relax/transform/bind_params.cc
+++ b/src/relax/transform/bind_params.cc
@@ -30,6 +30,57 @@
namespace tvm {
namespace relax {
+void MatchSymbolicVar(const Expr& arg, const Expr& constant,
+ Map<tir::Var, PrimExpr>* symbolic_var_map,
arith::Analyzer* analyzer_) {
+ auto opt_arg_sinfo = MatchStructInfo<TensorStructInfo>(arg);
+ CHECK(opt_arg_sinfo)
+ << "The struct info of the bound parameter is expected to be
TensorStructInfo, but got: "
+ << GetStructInfo(arg);
+ auto opt_const_sinfo = MatchStructInfo<TensorStructInfo>(constant);
+ // As the constant is generated by internal codes, we use ICHECK here.
+ ICHECK(opt_const_sinfo)
+ << "The struct info of the bound weight is expected to be
TensorStructInfo, but got: "
+ << GetStructInfo(constant);
+
+ TensorStructInfo arg_sinfo = opt_arg_sinfo.value();
+ TensorStructInfo const_sinfo = opt_const_sinfo.value();
+ ICHECK(!const_sinfo->IsUnknownDtype());
+ ICHECK(!const_sinfo->IsUnknownNdim());
+ ICHECK(const_sinfo->shape.defined());
+
+ // dtype mismatch
+ if (!arg_sinfo->IsUnknownDtype() && arg_sinfo->dtype != const_sinfo->dtype) {
+ LOG(FATAL) << "The dtype of the bound parameter is expected to be " <<
arg_sinfo->dtype
+ << ", but got: " << const_sinfo->dtype;
+ }
+ // ndim mismatch
+ if (!arg_sinfo->IsUnknownNdim() && arg_sinfo->ndim != const_sinfo->ndim) {
+ LOG(FATAL) << "The ndim of the bound parameter is expected to be " <<
arg_sinfo->ndim
+ << ", but got: " << const_sinfo->ndim;
+ }
+ if (!arg_sinfo->shape.defined()) return;
+ const auto* arg_shape = arg_sinfo->shape.value().as<ShapeExprNode>();
+ const auto* const_shape = const_sinfo->shape.value().as<ShapeExprNode>();
+
+ CHECK(arg_shape && const_shape)
+ << "The shape of the bound parameter and weight is expected to be
ShapeExprNode for now";
+
+ for (int i = 0; i < arg_sinfo->ndim; ++i) {
+ const PrimExpr& const_dim = const_shape->values[i];
+ ICHECK(tir::is_const_int(const_dim));
+ if (const auto* shape_var = arg_shape->values[i].as<tir::VarNode>()) {
+ auto it = symbolic_var_map->find(GetRef<tir::Var>(shape_var));
+ if (it == symbolic_var_map->end()) {
+ symbolic_var_map->Set(GetRef<tir::Var>(shape_var), const_dim);
+ } else {
+ CHECK(analyzer_->CanProveEqual((*it).second, const_dim))
+ << "The shape of the bound parameter is expected to be " <<
(*it).second
+ << ", but got: " << const_dim;
+ }
+ }
+ }
+}
+
/*!
* \brief Bind params to function by using name
* \param func Relax function
@@ -48,18 +99,23 @@ inline Function BindParamsByName(Function func, const
Map<String, runtime::NDArr
}
}
- std::unordered_map<Var, Expr, ObjectPtrHash, ObjectPtrEqual> bind_dict;
+ arith::Analyzer analyzer;
+ Map<Var, Expr> bind_dict;
+ Map<tir::Var, PrimExpr> symbolic_var_map;
+
for (auto& kv : params) {
if (name_dict.count(kv.first) == 0) {
continue;
}
- auto arg = name_dict.at(kv.first);
+ const Var& arg = name_dict.at(kv.first);
if (repeat_var.count(arg)) {
LOG(FATAL) << "ValueError: Multiple args in the function have name " <<
kv.first;
}
- bind_dict[arg] = Constant(kv.second);
+ Expr const_expr = Constant(kv.second);
+ bind_dict.Set(arg, const_expr);
+ MatchSymbolicVar(arg, const_expr, &symbolic_var_map, &analyzer);
}
- Expr bound_expr = Bind(func, bind_dict);
+ Expr bound_expr = Bind(func, bind_dict, symbolic_var_map);
Function ret = Downcast<Function>(bound_expr);
ICHECK(ret.defined()) << "The returning type is expected to be a Relax
Function."
<< "\n";
diff --git a/src/relax/utils.cc b/src/relax/utils.cc
index 1cf64cbf64..cf1d9bed98 100644
--- a/src/relax/utils.cc
+++ b/src/relax/utils.cc
@@ -25,7 +25,36 @@ namespace relax {
/*! \brief Helper to implement bind params.*/
class ExprBinder : public ExprMutator {
public:
- explicit ExprBinder(const tvm::Map<Var, Expr>& args_map) :
args_map_(args_map) {}
+ explicit ExprBinder(const tvm::Map<Var, Expr>& args_map,
+ const tvm::Map<tir::Var, PrimExpr>& symbolic_var_map)
+ : args_map_(args_map), symbolic_var_map_(symbolic_var_map) {}
+
+ private:
+ Expr VisitExpr_(const FunctionNode* op) final {
+ tvm::Array<Var> params;
+ bool all_params_unchanged = true;
+ for (const Var& param : op->params) {
+ if (args_map_.count(param)) {
+ all_params_unchanged = false;
+ } else {
+ Var new_param = this->VisitVarDef(param);
+ params.push_back(new_param);
+ if (!param.same_as(new_param)) {
+ this->var_remap_[param->vid] = new_param;
+ all_params_unchanged = false;
+ }
+ }
+ }
+
+ Expr body = this->VisitWithNewScope(op->body, params);
+
+ // FuncStructInfo does not depend on Expr
+ if (all_params_unchanged && body.same_as(op->body)) {
+ return GetRef<Expr>(op);
+ } else {
+ return Function(params, body,
VisitExprDepStructInfoField(op->ret_struct_info), op->attrs);
+ }
+ }
Expr VisitExpr_(const VarNode* op) final {
auto id = GetRef<Var>(op);
@@ -37,34 +66,31 @@ class ExprBinder : public ExprMutator {
}
}
+ PrimExpr VisitPrimExpr(const PrimExpr& expr) final {
+ if (const tir::VarNode* var = expr.as<tir::VarNode>()) {
+ auto it = symbolic_var_map_.find(GetRef<tir::Var>(var));
+ if (it != symbolic_var_map_.end()) {
+ return (*it).second;
+ }
+ }
+ return ExprMutator::VisitPrimExpr(expr);
+ }
+
private:
const tvm::Map<Var, Expr>& args_map_;
+ const tvm::Map<tir::Var, PrimExpr>& symbolic_var_map_;
};
/*!
* \brief Bind params on expr
* \param expr The expr where to bind params
- * \param args_map The map from param var to the expr it binds to
+ * \param binds The map from param var to the expr it binds to
+ * \param symbolic_var_map The map from symbolic var to the expr it binds to
* \return The result expr after bind params
*/
-Expr Bind(const Expr& expr, const tvm::Map<Var, Expr>& args_map) {
- if (const FunctionNode* func = expr.as<FunctionNode>()) {
- Expr new_body = ExprBinder(args_map).VisitExpr(func->body);
- Array<Var> new_params;
- for (size_t i = 0; i < func->params.size(); ++i) {
- if (!args_map.count(func->params[i])) {
- new_params.push_back(func->params[i]);
- }
- }
- if (new_body.same_as(func->body) && new_params.size() ==
func->params.size()) {
- return expr;
- }
- // The checked_type_ of the new function is deduced from the function body
- // TODO(@relax-team): Should infer the shape from the body as well
- return Function(new_params, new_body, NullOpt, func->attrs);
- } else {
- return ExprBinder(args_map).VisitExpr(expr);
- }
+Expr Bind(const Expr& expr, const tvm::Map<Var, Expr>& binds,
+ const tvm::Map<tir::Var, PrimExpr>& symbolic_var_map) {
+ return ExprBinder(binds, symbolic_var_map).VisitExpr(expr);
}
bool IsBoolStructInfo(const StructInfo& sinfo, bool permit_unknown_rank,
diff --git a/tests/python/relax/test_frontend_from_fx.py
b/tests/python/relax/test_frontend_from_fx.py
index 24ed9946a3..e216010667 100644
--- a/tests/python/relax/test_frontend_from_fx.py
+++ b/tests/python/relax/test_frontend_from_fx.py
@@ -142,7 +142,7 @@ def test_linear():
def main(
input_1: R.Tensor((1, 3, 10, 10), dtype="float32"),
w1: R.Tensor((7, 10), dtype="float32"),
- w2: R.Tensor((1, 7), dtype="float32"),
+ w2: R.Tensor((7,), dtype="float32"),
) -> R.Tensor((1, 3, 10, 7), dtype="float32"):
# block 0
with R.dataflow():
diff --git a/tests/python/relax/test_transform_bind_params.py
b/tests/python/relax/test_transform_bind_params.py
index ceaf8fb165..1dfd9e0c8e 100644
--- a/tests/python/relax/test_transform_bind_params.py
+++ b/tests/python/relax/test_transform_bind_params.py
@@ -71,5 +71,57 @@ def test_bind_params(use_np_array):
tvm.testing.assert_allclose(res_before.numpy(), res_after.numpy())
+def test_bind_params_symbolic_vars():
+ @tvm.script.ir_module
+ class Before:
+ @R.function
+ def main(
+ x: R.Tensor(("batch", "m"), dtype="float32"),
+ w0: R.Tensor(("n", "m"), dtype="float32"),
+ b0: R.Tensor(("n",), dtype="float32"),
+ w1: R.Tensor(("k", "n"), dtype="float32"),
+ b1: R.Tensor(("k",), dtype="float32"),
+ ) -> R.Tensor(("batch", "k"), dtype="float32"):
+ batch = T.Var("batch", "int64")
+ k = T.Var("k", "int64")
+ m = T.Var("m", "int64")
+ n = T.Var("n", "int64")
+ with R.dataflow():
+ lv0 = R.call_tir(
+ "linear0", (x, w0, b0), out_sinfo=R.Tensor((batch, n),
dtype="float32")
+ )
+ out = R.call_tir(
+ "linear1", (lv0, w1, b1), out_sinfo=R.Tensor((batch, k),
dtype="float32")
+ )
+ R.output(out)
+ return out
+
+ m, n, k = 4, 6, 8
+ w0_tvm = tvm.nd.array(np.random.rand(n, m).astype(np.float32))
+ b0_tvm = tvm.nd.array(np.random.rand(n).astype(np.float32))
+ w1_tvm = tvm.nd.array(np.random.rand(k, n).astype(np.float32))
+ b1_tvm = tvm.nd.array(np.random.rand(k).astype(np.float32))
+ params_dict = {"w0": w0_tvm, "b0": b0_tvm, "w1": w1_tvm, "b1": b1_tvm}
+ mod = relax.transform.BindParams("main", params_dict)(Before)
+
+ # Since it contains ConstantNode, it's hard to check with structural
equality.
+ func = mod["main"]
+ assert len(func.params) == 1
+ batch = func.params[0].struct_info.shape[0]
+ tvm.ir.assert_structural_equal(
+ func.params[0].struct_info, relax.TensorStructInfo((batch, 4),
"float32")
+ )
+ tvm.ir.assert_structural_equal(
+ func.ret_struct_info, relax.TensorStructInfo((batch, 8), "float32")
+ )
+ bindings = func.body.blocks[0].bindings
+ tvm.ir.assert_structural_equal(
+ bindings[0].var.struct_info, relax.TensorStructInfo((batch, 6),
"float32")
+ )
+ tvm.ir.assert_structural_equal(
+ bindings[1].var.struct_info, relax.TensorStructInfo((batch, 8),
"float32")
+ )
+
+
if __name__ == "__main__":
tvm.testing.main()