This is an automated email from the ASF dual-hosted git repository.
jwfromm pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new 25c6218e64 [Unity] Skip shape checking on transformed params (#15736)
25c6218e64 is described below
commit 25c6218e6488be87c6ebdff391af9234a9d83585
Author: Wuwei Lin <[email protected]>
AuthorDate: Tue Sep 19 13:37:11 2023 -0700
[Unity] Skip shape checking on transformed params (#15736)
* [Unity] Skip shape checking on transformed params
* fix typo
* keep shape checking for dynamic params
---
include/tvm/relax/expr.h | 7 +
src/relax/backend/vm/vm_shape_lower.cc | 62 +++++---
src/relax/transform/bundle_model_params.cc | 6 +-
src/relax/transform/lift_transform_params.cc | 3 +-
src/relax/transform/rewrite_cuda_graph.cc | 3 +-
.../relax/test_backend_transform_shape_lower.py | 177 +++++++++++++++++++++
6 files changed, 230 insertions(+), 28 deletions(-)
diff --git a/include/tvm/relax/expr.h b/include/tvm/relax/expr.h
index 2d1e805a41..02d6f8d276 100644
--- a/include/tvm/relax/expr.h
+++ b/include/tvm/relax/expr.h
@@ -1003,6 +1003,13 @@ constexpr const char* kWorkspaceSize = "WorkspaceSize";
/*! \brief Override checking purity for this function and treat as pure
* (is_pure must be set to true) */
constexpr const char* kForcePure = "relax.force_pure";
+
+/*!
+ * \brief The number of inputs of a function.
+ * If a function has the num_input attribute, the last func->params.size() -
num_input
+ * arguments are assumed to be weights that are fixed across invocations.
+ */
+constexpr const char* kNumInput = "num_input";
} // namespace attr
/*! \brief The extern function, which can represent packed function. */
diff --git a/src/relax/backend/vm/vm_shape_lower.cc
b/src/relax/backend/vm/vm_shape_lower.cc
index a5252be50b..8b8eb33f5b 100644
--- a/src/relax/backend/vm/vm_shape_lower.cc
+++ b/src/relax/backend/vm/vm_shape_lower.cc
@@ -198,7 +198,7 @@ class PrimExprSlotCollector : public ExprVisitor, public
StructInfoVisitor {
*/
class VMShapeLowerMutator
: public ExprMutator,
- public StructInfoFunctor<void(const StructInfo&, Expr, bool, const
String&,
+ public StructInfoFunctor<void(const StructInfo&, Expr, bool, bool, const
String&,
std::vector<MatchShapeTodoItem>*)> {
public:
static IRModule Lower(IRModule mod, bool emit_err_ctx) {
@@ -241,12 +241,19 @@ class VMShapeLowerMutator
builder_->BeginBindingBlock();
this->builder_->EmitNormalized(shape_heap_binding);
std::vector<MatchShapeTodoItem> match_todos;
+ size_t num_input = func->params.size();
+ if (auto opt_num_input = func->attrs.GetAttr<Integer>(attr::kNumInput)) {
+ // If the function has the attribute 'num_input', do shape checking for
the real inputs
+ // and skip weights.
+ num_input = static_cast<size_t>(opt_num_input.value()->value);
+ }
for (size_t i = 0; i < func->params.size(); ++i) {
StructInfo sinfo = GetStructInfo(func->params[i]);
std::ostringstream err_ctx;
err_ctx << "ErrorContext(fn=" << gvar->name_hint << ", loc=param[" << i
<< "], param=" << func->params[i]->name_hint() << ",
annotation=" << sinfo << ") ";
- this->CheckMatchCast(sinfo, func->params[i], true, err_ctx.str(),
&match_todos);
+ this->CheckMatchCast(sinfo, func->params[i], true, i >= num_input,
err_ctx.str(),
+ &match_todos);
}
// insert heap generation logic.
match_todos = this->RunMatch(match_todos, false);
@@ -269,7 +276,7 @@ class VMShapeLowerMutator
<< ", loc=return, annotation=" << func->ret_struct_info << ") ";
std::vector<MatchShapeTodoItem> match_todos;
// NOTE: the return value's shape computation must already be defined.
- this->CheckMatchCast(func->ret_struct_info, body_seq->body, false,
err_ctx.str(),
+ this->CheckMatchCast(func->ret_struct_info, body_seq->body, false,
false, err_ctx.str(),
&match_todos);
// NOTE: the return value's shape computation must already be defined.
this->RunMatch(match_todos, true);
@@ -377,7 +384,7 @@ class VMShapeLowerMutator
std::ostringstream err_ctx;
err_ctx << "ErrorContext(match_cast, struct_info=" << binding->struct_info
<< ") ";
// always_check=false
- this->CheckMatchCast(binding->struct_info, value, false, err_ctx.str(),
&match_todos);
+ this->CheckMatchCast(binding->struct_info, value, false, false,
err_ctx.str(), &match_todos);
match_todos = this->RunMatch(match_todos, false);
this->EmitOutstandingPrimExprCompute();
@@ -556,37 +563,42 @@ class VMShapeLowerMutator
* \param always_check Whether we insert runtime check even if we can prove
* that value's struct info already satisfies the condition.
* This option is necessary for argument checking per our calling
convention.
- *
+ * \param dynamic_only Whether we only check values with dynamic shapes.
* \param err_ctx Extra error context to bring more informative error
reporting.
* \param match_todos List of match shape todo items collected when
recursively
* visit the match cast.
*/
void CheckMatchCast(const StructInfo& struct_info, Expr value, bool
always_check,
- const String& err_ctx, std::vector<MatchShapeTodoItem>*
match_todos) {
- return this->VisitStructInfo(struct_info, value, always_check, err_ctx,
match_todos);
+ bool dynamic_only, const String& err_ctx,
+ std::vector<MatchShapeTodoItem>* match_todos) {
+ return this->VisitStructInfo(struct_info, value, always_check,
dynamic_only, err_ctx,
+ match_todos);
}
void VisitStructInfo(const StructInfo& struct_info, Expr value, bool
always_check,
- const String& err_ctx, std::vector<MatchShapeTodoItem>*
match_todos) final {
+ bool dynamic_only, const String& err_ctx,
+ std::vector<MatchShapeTodoItem>* match_todos) final {
// short-cut, if the struct info already satisfies the
// constraint during match cast, we can skip matching
if (!always_check && IsBaseOf(struct_info, GetStructInfo(value))) return;
- return StructInfoFunctor::VisitStructInfo(struct_info, value,
always_check, err_ctx,
- match_todos);
+ return StructInfoFunctor::VisitStructInfo(struct_info, value,
always_check, dynamic_only,
+ err_ctx, match_todos);
}
void VisitStructInfo_(const ObjectStructInfoNode* op, Expr value, bool
always_check,
- const String& err_ctx,
std::vector<MatchShapeTodoItem>* match_todos) final {
- }
+ bool dynamic_only, const String& err_ctx,
+ std::vector<MatchShapeTodoItem>* match_todos) final {}
void VisitStructInfo_(const PrimStructInfoNode* op, Expr value, bool
always_check,
- const String& err_ctx,
std::vector<MatchShapeTodoItem>* match_todos) final {
+ bool dynamic_only, const String& err_ctx,
+ std::vector<MatchShapeTodoItem>* match_todos) final {
// TODO(relax-team) add PrimValue checks later.
LOG(FATAL) << "MatchCast of PrimValue is not yet supported";
}
void VisitStructInfo_(const ShapeStructInfoNode* op, Expr value, bool
always_check,
- const String& err_ctx,
std::vector<MatchShapeTodoItem>* match_todos) final {
+ bool dynamic_only, const String& err_ctx,
+ std::vector<MatchShapeTodoItem>* match_todos) final {
// emit runtime check of shape
if (always_check || !IsBaseOf(ShapeStructInfo(op->ndim),
GetStructInfo(value))) {
// check_shape_info(value, ndim, err_ctx)
@@ -605,8 +617,16 @@ class VMShapeLowerMutator
}
void VisitStructInfo_(const TensorStructInfoNode* op, Expr value, bool
always_check,
- const String& err_ctx,
std::vector<MatchShapeTodoItem>* match_todos) final {
+ bool dynamic_only, const String& err_ctx,
+ std::vector<MatchShapeTodoItem>* match_todos) final {
// emit runtime check of shape
+ auto* shape_expr = op->shape.as<ShapeExprNode>();
+ if (dynamic_only &&
+ std::all_of(shape_expr->values.begin(), shape_expr->values.end(),
+ [](const PrimExpr& e) { return
e->IsInstance<IntImmNode>(); })) {
+ // if we only check dynamic shapes, and the shape is static, we can skip.
+ return;
+ }
if (always_check || !IsBaseOf(TensorStructInfo(op->dtype, op->ndim),
GetStructInfo(value))) {
// check_tensor_info(value, ndim, dtype, err_ctx)
Call call(builtin_check_tensor_info_,
@@ -615,7 +635,7 @@ class VMShapeLowerMutator
builder_->Emit(call, "_");
}
- if (auto* shape_expr = op->shape.as<ShapeExprNode>()) {
+ if (shape_expr != nullptr) {
MatchShapeTodoItem item;
item.input = value;
item.pattern = shape_expr->values;
@@ -648,7 +668,8 @@ class VMShapeLowerMutator
}
void VisitStructInfo_(const TupleStructInfoNode* op, Expr value, bool
always_check,
- const String& err_ctx,
std::vector<MatchShapeTodoItem>* match_todos) final {
+ bool dynamic_only, const String& err_ctx,
+ std::vector<MatchShapeTodoItem>* match_todos) final {
auto* value_tinfo = GetStructInfoAs<TupleStructInfoNode>(value);
if (value_tinfo) {
CHECK_EQ(value_tinfo->fields.size(), op->fields.size())
@@ -664,13 +685,14 @@ class VMShapeLowerMutator
}
// recursively visit each sub-field and run matching
for (size_t i = 0; i < op->fields.size(); ++i) {
- this->VisitStructInfo(op->fields[i], MakeTupleGetItem(value, i),
always_check, err_ctx,
- match_todos);
+ this->VisitStructInfo(op->fields[i], MakeTupleGetItem(value, i),
always_check, dynamic_only,
+ err_ctx, match_todos);
}
}
void VisitStructInfo_(const FuncStructInfoNode* op, Expr value, bool
always_check,
- const String& err_ctx,
std::vector<MatchShapeTodoItem>* match_todos) final {
+ bool dynamic_only, const String& err_ctx,
+ std::vector<MatchShapeTodoItem>* match_todos) final {
// we only check function is callable.
if (!always_check && MatchStructInfo<FuncStructInfo>(value)) return;
// check_func_info(value, err_ctx)
diff --git a/src/relax/transform/bundle_model_params.cc
b/src/relax/transform/bundle_model_params.cc
index 8f6e7a1291..2cc8902e57 100644
--- a/src/relax/transform/bundle_model_params.cc
+++ b/src/relax/transform/bundle_model_params.cc
@@ -33,15 +33,13 @@
namespace tvm {
namespace relax {
-static const auto kAttrNumInput = "num_input";
-
class ModelParamBundler : public ExprMutator {
public:
ModelParamBundler() {}
Expr VisitExpr_(const FunctionNode* op) override {
Function func = GetRef<Function>(op);
- auto opt_num_input = func->attrs.GetAttr<Integer>(kAttrNumInput);
+ auto opt_num_input = func->attrs.GetAttr<Integer>(attr::kNumInput);
if (!opt_num_input) return func;
auto signed_num_input = opt_num_input.value()->value;
@@ -68,7 +66,7 @@ class ModelParamBundler : public ExprMutator {
var_to_expr_.Set(func->params[i], TupleGetItem(var_param_tuple, i -
num_input));
}
- func = WithoutAttr(func, kAttrNumInput);
+ func = WithoutAttr(func, attr::kNumInput);
func.CopyOnWrite()->params = params;
return ExprMutator::VisitExpr_(func.get());
diff --git a/src/relax/transform/lift_transform_params.cc
b/src/relax/transform/lift_transform_params.cc
index afa1e191f4..7cdf03b10d 100644
--- a/src/relax/transform/lift_transform_params.cc
+++ b/src/relax/transform/lift_transform_params.cc
@@ -234,7 +234,7 @@ class TransformParamsLifter : ExprMutator {
private:
Expr VisitExpr_(const FunctionNode* op) override {
auto func = GetRef<Function>(op);
- Optional<Integer> opt_num_input =
func->attrs.GetAttr<Integer>(attr_num_input_);
+ Optional<Integer> opt_num_input =
func->attrs.GetAttr<Integer>(attr::kNumInput);
if (!opt_num_input) {
return func;
}
@@ -300,7 +300,6 @@ class TransformParamsLifter : ExprMutator {
return VisitExpr_(static_cast<const VarNode*>(var));
}
- const char* attr_num_input_ = "num_input";
// Remap the original parameters to TupleGetItem from the packed tuple of
transformed parameters.
std::unordered_map<Var, Expr, ObjectPtrHash, ObjectPtrEqual> param_remap_;
// The plan of lifting the transform params
diff --git a/src/relax/transform/rewrite_cuda_graph.cc
b/src/relax/transform/rewrite_cuda_graph.cc
index 22c927997d..c2a0754462 100644
--- a/src/relax/transform/rewrite_cuda_graph.cc
+++ b/src/relax/transform/rewrite_cuda_graph.cc
@@ -153,9 +153,8 @@ class CUDAGraphRewritePlanner : public ExprVisitor {
if (pair.second->IsInstance<FunctionNode>()) {
// If a function has the num_input attribute, the last
func->params.size() - num_inputs
// inputs are assumed to be fixed and thus they can be captured into a
cuda graph.
- static const char* attr_num_input = "num_input";
const auto& func = Downcast<Function>(pair.second);
- if (auto num_input = func->attrs.GetAttr<Integer>(attr_num_input)) {
+ if (auto num_input = func->attrs.GetAttr<Integer>(attr::kNumInput)) {
for (size_t i = num_input.value().IntValue(); i <
func->params.size(); ++i) {
static_vars_.insert(func->params[i].get());
}
diff --git a/tests/python/relax/test_backend_transform_shape_lower.py
b/tests/python/relax/test_backend_transform_shape_lower.py
index a5d4395e3c..b9a3537630 100644
--- a/tests/python/relax/test_backend_transform_shape_lower.py
+++ b/tests/python/relax/test_backend_transform_shape_lower.py
@@ -554,5 +554,182 @@ def test_symbolic_shape_multiple_function():
assert_structural_equal(after, expected)
+def test_check_lifted_weights():
+ MS = MatchShapeCode
+
+ @I.ir_module
+ class Before:
+ @R.function
+ def main_transform_params(
+ params: R.Tuple(R.Tensor((16, 16), dtype="float32"))
+ ) -> R.Tuple(R.Tensor((16, 16), dtype="float32")):
+ R.func_attr({"relax.force_pure": 1})
+ return params
+
+ @R.function
+ def main(x: R.Tensor((16, 16), "float32"), param_0: R.Tensor((16, 16),
dtype="float32")):
+ R.func_attr({"relax.force_pure": 1, "num_input": 1})
+ return (x, param_0)
+
+ @I.ir_module
+ class Expected:
+ @R.function
+ def main_transform_params(
+ params: R.Tuple(R.Tensor((16, 16), dtype="float32"))
+ ) -> R.Tuple(R.Tensor((16, 16), dtype="float32")):
+ R.func_attr({"relax.force_pure": 1})
+ shape_heap: R.Object = R.null_value()
+ _: R.Tuple = R.call_packed(
+ "vm.builtin.check_tuple_info",
+ params,
+ R.prim_value(1),
+ R.str(""),
+ sinfo_args=(R.Tuple,),
+ )
+ gv: R.Tensor((16, 16), dtype="float32") = params[0]
+ _1: R.Tuple = R.call_packed(
+ "vm.builtin.check_tensor_info",
+ gv,
+ R.prim_value(2),
+ R.dtype("float32"),
+ R.str(""),
+ sinfo_args=(R.Tuple,),
+ )
+ _2: R.Tuple = R.call_packed(
+ "vm.builtin.match_shape",
+ gv,
+ shape_heap,
+ R.prim_value(2),
+ MS.ASSERT_EQUAL_TO_IMM,
+ R.prim_value(16),
+ MS.ASSERT_EQUAL_TO_IMM,
+ R.prim_value(16),
+ R.str(""),
+ sinfo_args=(R.Tuple,),
+ )
+ return params
+
+ @R.function
+ def main(
+ x: R.Tensor((16, 16), dtype="float32"), param_0: R.Tensor((16,
16), dtype="float32")
+ ) -> R.Tuple(R.Tensor((16, 16), dtype="float32"), R.Tensor((16, 16),
dtype="float32")):
+ R.func_attr({"num_input": 1, "relax.force_pure": 1})
+ shape_heap: R.Object = R.null_value()
+ _: R.Tuple = R.call_packed(
+ "vm.builtin.check_tensor_info",
+ x,
+ R.prim_value(2),
+ R.dtype("float32"),
+ R.str(""),
+ sinfo_args=(R.Tuple,),
+ )
+ _1: R.Tuple = R.call_packed(
+ "vm.builtin.match_shape",
+ x,
+ shape_heap,
+ R.prim_value(2),
+ MS.ASSERT_EQUAL_TO_IMM,
+ R.prim_value(16),
+ MS.ASSERT_EQUAL_TO_IMM,
+ R.prim_value(16),
+ R.str(""),
+ sinfo_args=(R.Tuple,),
+ )
+ return (x, param_0)
+
+ before = Before
+ after = relax.transform.VMShapeLower(emit_err_ctx=False)(before)
+ expected = Expected
+ assert_structural_equal(after, expected)
+
+
+def test_check_weights_with_dynamic_shape():
+ MS = MatchShapeCode
+
+ @I.ir_module
+ class Before:
+ @R.function
+ def main(
+ x: R.Tensor((16, 16), "float32"),
+ params: R.Tuple(R.Tensor((16, 16), dtype="float32"),
R.Tensor(("n",), "float32")),
+ ):
+ R.func_attr({"relax.force_pure": 1, "num_input": 1})
+ n = T.int64()
+ param_0 = params[0]
+ param_1 = params[1]
+ return (x, param_0, param_1)
+
+ @I.ir_module
+ class Expected:
+ @R.function
+ def main(
+ x: R.Tensor((16, 16), "float32"),
+ params: R.Tuple(R.Tensor((16, 16), dtype="float32"),
R.Tensor(("n",), "float32")),
+ ):
+ n = T.int64()
+ R.func_attr({"num_input": 1, "relax.force_pure": 1})
+ shape_heap: R.Tensor(dtype="int64", ndim=1) =
R.call_builtin_with_ctx(
+ "vm.builtin.alloc_shape_heap",
+ (R.prim_value(1),),
+ sinfo_args=(R.Tensor(dtype="int64", ndim=1),),
+ )
+ _: R.Tuple = R.call_packed(
+ "vm.builtin.check_tensor_info",
+ x,
+ R.prim_value(2),
+ R.dtype("float32"),
+ R.str(""),
+ sinfo_args=(R.Tuple,),
+ )
+ _1: R.Tuple = R.call_packed(
+ "vm.builtin.check_tuple_info",
+ params,
+ R.prim_value(2),
+ R.str(""),
+ sinfo_args=(R.Tuple,),
+ )
+ _param_1: R.Tensor((n,), dtype="float32") = params[1]
+ _2: R.Tuple = R.call_packed(
+ "vm.builtin.check_tensor_info",
+ _param_1,
+ R.prim_value(1),
+ R.dtype("float32"),
+ R.str(""),
+ sinfo_args=(R.Tuple,),
+ )
+ _3: R.Tuple = R.call_packed(
+ "vm.builtin.match_shape",
+ x,
+ shape_heap,
+ R.prim_value(2),
+ MS.ASSERT_EQUAL_TO_IMM,
+ R.prim_value(16),
+ MS.ASSERT_EQUAL_TO_IMM,
+ R.prim_value(16),
+ R.str(""),
+ sinfo_args=(R.Tuple,),
+ )
+ _4: R.Tuple = R.call_packed(
+ "vm.builtin.match_shape",
+ _param_1,
+ shape_heap,
+ MS.STORE_TO_HEAP,
+ R.prim_value(1),
+ R.prim_value(0),
+ R.str(""),
+ sinfo_args=(R.Tuple,),
+ )
+
+ param_0 = params[0]
+ param_1 = params[1]
+ return (x, param_0, param_1)
+
+ before = Before
+ after = relax.transform.VMShapeLower(emit_err_ctx=False)(before)
+ print(after)
+ expected = Expected
+ assert_structural_equal(after, expected)
+
+
if __name__ == "__main__":
tvm.testing.main()