This is an automated email from the ASF dual-hosted git repository.

jwfromm pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new 25c6218e64 [Unity] Skip shape checking on transformed params (#15736)
25c6218e64 is described below

commit 25c6218e6488be87c6ebdff391af9234a9d83585
Author: Wuwei Lin <[email protected]>
AuthorDate: Tue Sep 19 13:37:11 2023 -0700

    [Unity] Skip shape checking on transformed params (#15736)
    
    * [Unity] Skip shape checking on transformed params
    
    * fix typo
    
    * keep shape checking for dynamic params
---
 include/tvm/relax/expr.h                           |   7 +
 src/relax/backend/vm/vm_shape_lower.cc             |  62 +++++---
 src/relax/transform/bundle_model_params.cc         |   6 +-
 src/relax/transform/lift_transform_params.cc       |   3 +-
 src/relax/transform/rewrite_cuda_graph.cc          |   3 +-
 .../relax/test_backend_transform_shape_lower.py    | 177 +++++++++++++++++++++
 6 files changed, 230 insertions(+), 28 deletions(-)

diff --git a/include/tvm/relax/expr.h b/include/tvm/relax/expr.h
index 2d1e805a41..02d6f8d276 100644
--- a/include/tvm/relax/expr.h
+++ b/include/tvm/relax/expr.h
@@ -1003,6 +1003,13 @@ constexpr const char* kWorkspaceSize = "WorkspaceSize";
 /*! \brief Override checking purity for this function and treat as pure
  * (is_pure must be set to true) */
 constexpr const char* kForcePure = "relax.force_pure";
+
+/*!
+ * \brief The number of inputs of a function.
+ * If a function has the num_input attribute, the last func->params.size() - num_input
+ * arguments are assumed to be weights that are fixed across invocations.
+ */
+constexpr const char* kNumInput = "num_input";
 }  // namespace attr
 
 /*! \brief The extern function, which can represent packed function. */
diff --git a/src/relax/backend/vm/vm_shape_lower.cc 
b/src/relax/backend/vm/vm_shape_lower.cc
index a5252be50b..8b8eb33f5b 100644
--- a/src/relax/backend/vm/vm_shape_lower.cc
+++ b/src/relax/backend/vm/vm_shape_lower.cc
@@ -198,7 +198,7 @@ class PrimExprSlotCollector : public ExprVisitor, public 
StructInfoVisitor {
  */
 class VMShapeLowerMutator
     : public ExprMutator,
-      public StructInfoFunctor<void(const StructInfo&, Expr, bool, const 
String&,
+      public StructInfoFunctor<void(const StructInfo&, Expr, bool, bool, const 
String&,
                                     std::vector<MatchShapeTodoItem>*)> {
  public:
   static IRModule Lower(IRModule mod, bool emit_err_ctx) {
@@ -241,12 +241,19 @@ class VMShapeLowerMutator
       builder_->BeginBindingBlock();
       this->builder_->EmitNormalized(shape_heap_binding);
       std::vector<MatchShapeTodoItem> match_todos;
+      size_t num_input = func->params.size();
+      if (auto opt_num_input = func->attrs.GetAttr<Integer>(attr::kNumInput)) {
+        // If the function has the attribute 'num_input', do shape checking only
+        // for the real inputs and skip weights.
+        num_input = static_cast<size_t>(opt_num_input.value()->value);
+      }
       for (size_t i = 0; i < func->params.size(); ++i) {
         StructInfo sinfo = GetStructInfo(func->params[i]);
         std::ostringstream err_ctx;
         err_ctx << "ErrorContext(fn=" << gvar->name_hint << ", loc=param[" << i
                 << "], param=" << func->params[i]->name_hint() << ", 
annotation=" << sinfo << ") ";
-        this->CheckMatchCast(sinfo, func->params[i], true, err_ctx.str(), 
&match_todos);
+        this->CheckMatchCast(sinfo, func->params[i], true, i >= num_input, 
err_ctx.str(),
+                             &match_todos);
       }
       // insert heap generation logic.
       match_todos = this->RunMatch(match_todos, false);
@@ -269,7 +276,7 @@ class VMShapeLowerMutator
               << ", loc=return, annotation=" << func->ret_struct_info << ") ";
       std::vector<MatchShapeTodoItem> match_todos;
       // NOTE: the return value's shape computation must already be defined.
-      this->CheckMatchCast(func->ret_struct_info, body_seq->body, false, 
err_ctx.str(),
+      this->CheckMatchCast(func->ret_struct_info, body_seq->body, false, 
false, err_ctx.str(),
                            &match_todos);
       // NOTE: the return value's shape computation must already be defined.
       this->RunMatch(match_todos, true);
@@ -377,7 +384,7 @@ class VMShapeLowerMutator
     std::ostringstream err_ctx;
     err_ctx << "ErrorContext(match_cast, struct_info=" << binding->struct_info 
<< ") ";
     // always_check=false
-    this->CheckMatchCast(binding->struct_info, value, false, err_ctx.str(), 
&match_todos);
+    this->CheckMatchCast(binding->struct_info, value, false, false, 
err_ctx.str(), &match_todos);
 
     match_todos = this->RunMatch(match_todos, false);
     this->EmitOutstandingPrimExprCompute();
@@ -556,37 +563,42 @@ class VMShapeLowerMutator
    * \param always_check Whether we insert runtime check even if we can prove
    *        that value's struct info already satisfies the condition.
    *        This option is necessary for argument checking per our calling 
convention.
-   *
+   * \param dynamic_only Whether we only check values with dynamic shapes.
    * \param err_ctx Extra error context to bring more informative error 
reporting.
    * \param match_todos List of match shape todo items collected when 
recursively
    *                    visit the match cast.
    */
   void CheckMatchCast(const StructInfo& struct_info, Expr value, bool 
always_check,
-                      const String& err_ctx, std::vector<MatchShapeTodoItem>* 
match_todos) {
-    return this->VisitStructInfo(struct_info, value, always_check, err_ctx, 
match_todos);
+                      bool dynamic_only, const String& err_ctx,
+                      std::vector<MatchShapeTodoItem>* match_todos) {
+    return this->VisitStructInfo(struct_info, value, always_check, 
dynamic_only, err_ctx,
+                                 match_todos);
   }
 
   void VisitStructInfo(const StructInfo& struct_info, Expr value, bool 
always_check,
-                       const String& err_ctx, std::vector<MatchShapeTodoItem>* 
match_todos) final {
+                       bool dynamic_only, const String& err_ctx,
+                       std::vector<MatchShapeTodoItem>* match_todos) final {
     // short-cut, if the struct info already satisfies the
     // constraint during match cast, we can skip matching
     if (!always_check && IsBaseOf(struct_info, GetStructInfo(value))) return;
-    return StructInfoFunctor::VisitStructInfo(struct_info, value, 
always_check, err_ctx,
-                                              match_todos);
+    return StructInfoFunctor::VisitStructInfo(struct_info, value, 
always_check, dynamic_only,
+                                              err_ctx, match_todos);
   }
 
   void VisitStructInfo_(const ObjectStructInfoNode* op, Expr value, bool 
always_check,
-                        const String& err_ctx, 
std::vector<MatchShapeTodoItem>* match_todos) final {
-  }
+                        bool dynamic_only, const String& err_ctx,
+                        std::vector<MatchShapeTodoItem>* match_todos) final {}
 
   void VisitStructInfo_(const PrimStructInfoNode* op, Expr value, bool 
always_check,
-                        const String& err_ctx, 
std::vector<MatchShapeTodoItem>* match_todos) final {
+                        bool dynamic_only, const String& err_ctx,
+                        std::vector<MatchShapeTodoItem>* match_todos) final {
     // TODO(relax-team) add PrimValue checks later.
     LOG(FATAL) << "MatchCast of PrimValue is not yet supported";
   }
 
   void VisitStructInfo_(const ShapeStructInfoNode* op, Expr value, bool 
always_check,
-                        const String& err_ctx, 
std::vector<MatchShapeTodoItem>* match_todos) final {
+                        bool dynamic_only, const String& err_ctx,
+                        std::vector<MatchShapeTodoItem>* match_todos) final {
     // emit runtime check of shape
     if (always_check || !IsBaseOf(ShapeStructInfo(op->ndim), 
GetStructInfo(value))) {
       // check_shape_info(value, ndim, err_ctx)
@@ -605,8 +617,16 @@ class VMShapeLowerMutator
   }
 
   void VisitStructInfo_(const TensorStructInfoNode* op, Expr value, bool 
always_check,
-                        const String& err_ctx, 
std::vector<MatchShapeTodoItem>* match_todos) final {
+                        bool dynamic_only, const String& err_ctx,
+                        std::vector<MatchShapeTodoItem>* match_todos) final {
     // emit runtime check of shape
+    auto* shape_expr = op->shape.as<ShapeExprNode>();
+    if (dynamic_only &&
+        std::all_of(shape_expr->values.begin(), shape_expr->values.end(),
+                    [](const PrimExpr& e) { return 
e->IsInstance<IntImmNode>(); })) {
+      // if we only check dynamic shapes, and the shape is static, we can skip.
+      return;
+    }
     if (always_check || !IsBaseOf(TensorStructInfo(op->dtype, op->ndim), 
GetStructInfo(value))) {
       // check_tensor_info(value, ndim, dtype, err_ctx)
       Call call(builtin_check_tensor_info_,
@@ -615,7 +635,7 @@ class VMShapeLowerMutator
       builder_->Emit(call, "_");
     }
 
-    if (auto* shape_expr = op->shape.as<ShapeExprNode>()) {
+    if (shape_expr != nullptr) {
       MatchShapeTodoItem item;
       item.input = value;
       item.pattern = shape_expr->values;
@@ -648,7 +668,8 @@ class VMShapeLowerMutator
   }
 
   void VisitStructInfo_(const TupleStructInfoNode* op, Expr value, bool 
always_check,
-                        const String& err_ctx, 
std::vector<MatchShapeTodoItem>* match_todos) final {
+                        bool dynamic_only, const String& err_ctx,
+                        std::vector<MatchShapeTodoItem>* match_todos) final {
     auto* value_tinfo = GetStructInfoAs<TupleStructInfoNode>(value);
     if (value_tinfo) {
       CHECK_EQ(value_tinfo->fields.size(), op->fields.size())
@@ -664,13 +685,14 @@ class VMShapeLowerMutator
     }
     // recursively visit each sub-field and run matching
     for (size_t i = 0; i < op->fields.size(); ++i) {
-      this->VisitStructInfo(op->fields[i], MakeTupleGetItem(value, i), 
always_check, err_ctx,
-                            match_todos);
+      this->VisitStructInfo(op->fields[i], MakeTupleGetItem(value, i), 
always_check, dynamic_only,
+                            err_ctx, match_todos);
     }
   }
 
   void VisitStructInfo_(const FuncStructInfoNode* op, Expr value, bool 
always_check,
-                        const String& err_ctx, 
std::vector<MatchShapeTodoItem>* match_todos) final {
+                        bool dynamic_only, const String& err_ctx,
+                        std::vector<MatchShapeTodoItem>* match_todos) final {
     // we only check function is callable.
     if (!always_check && MatchStructInfo<FuncStructInfo>(value)) return;
     // check_func_info(value, err_ctx)
diff --git a/src/relax/transform/bundle_model_params.cc 
b/src/relax/transform/bundle_model_params.cc
index 8f6e7a1291..2cc8902e57 100644
--- a/src/relax/transform/bundle_model_params.cc
+++ b/src/relax/transform/bundle_model_params.cc
@@ -33,15 +33,13 @@
 namespace tvm {
 namespace relax {
 
-static const auto kAttrNumInput = "num_input";
-
 class ModelParamBundler : public ExprMutator {
  public:
   ModelParamBundler() {}
 
   Expr VisitExpr_(const FunctionNode* op) override {
     Function func = GetRef<Function>(op);
-    auto opt_num_input = func->attrs.GetAttr<Integer>(kAttrNumInput);
+    auto opt_num_input = func->attrs.GetAttr<Integer>(attr::kNumInput);
     if (!opt_num_input) return func;
     auto signed_num_input = opt_num_input.value()->value;
 
@@ -68,7 +66,7 @@ class ModelParamBundler : public ExprMutator {
       var_to_expr_.Set(func->params[i], TupleGetItem(var_param_tuple, i - 
num_input));
     }
 
-    func = WithoutAttr(func, kAttrNumInput);
+    func = WithoutAttr(func, attr::kNumInput);
     func.CopyOnWrite()->params = params;
 
     return ExprMutator::VisitExpr_(func.get());
diff --git a/src/relax/transform/lift_transform_params.cc 
b/src/relax/transform/lift_transform_params.cc
index afa1e191f4..7cdf03b10d 100644
--- a/src/relax/transform/lift_transform_params.cc
+++ b/src/relax/transform/lift_transform_params.cc
@@ -234,7 +234,7 @@ class TransformParamsLifter : ExprMutator {
  private:
   Expr VisitExpr_(const FunctionNode* op) override {
     auto func = GetRef<Function>(op);
-    Optional<Integer> opt_num_input = 
func->attrs.GetAttr<Integer>(attr_num_input_);
+    Optional<Integer> opt_num_input = 
func->attrs.GetAttr<Integer>(attr::kNumInput);
     if (!opt_num_input) {
       return func;
     }
@@ -300,7 +300,6 @@ class TransformParamsLifter : ExprMutator {
     return VisitExpr_(static_cast<const VarNode*>(var));
   }
 
-  const char* attr_num_input_ = "num_input";
   // Remap the original parameters to TupleGetItem from the packed tuple of 
transformed parameters.
   std::unordered_map<Var, Expr, ObjectPtrHash, ObjectPtrEqual> param_remap_;
   // The plan of lifting the transform params
diff --git a/src/relax/transform/rewrite_cuda_graph.cc 
b/src/relax/transform/rewrite_cuda_graph.cc
index 22c927997d..c2a0754462 100644
--- a/src/relax/transform/rewrite_cuda_graph.cc
+++ b/src/relax/transform/rewrite_cuda_graph.cc
@@ -153,9 +153,8 @@ class CUDAGraphRewritePlanner : public ExprVisitor {
       if (pair.second->IsInstance<FunctionNode>()) {
         // If a function has the num_input attribute, the last 
func->params.size() - num_inputs
         // inputs are assumed to be fixed and thus they can be captured into a 
cuda graph.
-        static const char* attr_num_input = "num_input";
         const auto& func = Downcast<Function>(pair.second);
-        if (auto num_input = func->attrs.GetAttr<Integer>(attr_num_input)) {
+        if (auto num_input = func->attrs.GetAttr<Integer>(attr::kNumInput)) {
           for (size_t i = num_input.value().IntValue(); i < 
func->params.size(); ++i) {
             static_vars_.insert(func->params[i].get());
           }
diff --git a/tests/python/relax/test_backend_transform_shape_lower.py 
b/tests/python/relax/test_backend_transform_shape_lower.py
index a5d4395e3c..b9a3537630 100644
--- a/tests/python/relax/test_backend_transform_shape_lower.py
+++ b/tests/python/relax/test_backend_transform_shape_lower.py
@@ -554,5 +554,182 @@ def test_symbolic_shape_multiple_function():
     assert_structural_equal(after, expected)
 
 
+def test_check_lifted_weights():
+    MS = MatchShapeCode
+
+    @I.ir_module
+    class Before:
+        @R.function
+        def main_transform_params(
+            params: R.Tuple(R.Tensor((16, 16), dtype="float32"))
+        ) -> R.Tuple(R.Tensor((16, 16), dtype="float32")):
+            R.func_attr({"relax.force_pure": 1})
+            return params
+
+        @R.function
+        def main(x: R.Tensor((16, 16), "float32"), param_0: R.Tensor((16, 16), 
dtype="float32")):
+            R.func_attr({"relax.force_pure": 1, "num_input": 1})
+            return (x, param_0)
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main_transform_params(
+            params: R.Tuple(R.Tensor((16, 16), dtype="float32"))
+        ) -> R.Tuple(R.Tensor((16, 16), dtype="float32")):
+            R.func_attr({"relax.force_pure": 1})
+            shape_heap: R.Object = R.null_value()
+            _: R.Tuple = R.call_packed(
+                "vm.builtin.check_tuple_info",
+                params,
+                R.prim_value(1),
+                R.str(""),
+                sinfo_args=(R.Tuple,),
+            )
+            gv: R.Tensor((16, 16), dtype="float32") = params[0]
+            _1: R.Tuple = R.call_packed(
+                "vm.builtin.check_tensor_info",
+                gv,
+                R.prim_value(2),
+                R.dtype("float32"),
+                R.str(""),
+                sinfo_args=(R.Tuple,),
+            )
+            _2: R.Tuple = R.call_packed(
+                "vm.builtin.match_shape",
+                gv,
+                shape_heap,
+                R.prim_value(2),
+                MS.ASSERT_EQUAL_TO_IMM,
+                R.prim_value(16),
+                MS.ASSERT_EQUAL_TO_IMM,
+                R.prim_value(16),
+                R.str(""),
+                sinfo_args=(R.Tuple,),
+            )
+            return params
+
+        @R.function
+        def main(
+            x: R.Tensor((16, 16), dtype="float32"), param_0: R.Tensor((16, 
16), dtype="float32")
+        ) -> R.Tuple(R.Tensor((16, 16), dtype="float32"), R.Tensor((16, 16), 
dtype="float32")):
+            R.func_attr({"num_input": 1, "relax.force_pure": 1})
+            shape_heap: R.Object = R.null_value()
+            _: R.Tuple = R.call_packed(
+                "vm.builtin.check_tensor_info",
+                x,
+                R.prim_value(2),
+                R.dtype("float32"),
+                R.str(""),
+                sinfo_args=(R.Tuple,),
+            )
+            _1: R.Tuple = R.call_packed(
+                "vm.builtin.match_shape",
+                x,
+                shape_heap,
+                R.prim_value(2),
+                MS.ASSERT_EQUAL_TO_IMM,
+                R.prim_value(16),
+                MS.ASSERT_EQUAL_TO_IMM,
+                R.prim_value(16),
+                R.str(""),
+                sinfo_args=(R.Tuple,),
+            )
+            return (x, param_0)
+
+    before = Before
+    after = relax.transform.VMShapeLower(emit_err_ctx=False)(before)
+    expected = Expected
+    assert_structural_equal(after, expected)
+
+
+def test_check_weights_with_dynamic_shape():
+    MS = MatchShapeCode
+
+    @I.ir_module
+    class Before:
+        @R.function
+        def main(
+            x: R.Tensor((16, 16), "float32"),
+            params: R.Tuple(R.Tensor((16, 16), dtype="float32"), 
R.Tensor(("n",), "float32")),
+        ):
+            R.func_attr({"relax.force_pure": 1, "num_input": 1})
+            n = T.int64()
+            param_0 = params[0]
+            param_1 = params[1]
+            return (x, param_0, param_1)
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(
+            x: R.Tensor((16, 16), "float32"),
+            params: R.Tuple(R.Tensor((16, 16), dtype="float32"), 
R.Tensor(("n",), "float32")),
+        ):
+            n = T.int64()
+            R.func_attr({"num_input": 1, "relax.force_pure": 1})
+            shape_heap: R.Tensor(dtype="int64", ndim=1) = 
R.call_builtin_with_ctx(
+                "vm.builtin.alloc_shape_heap",
+                (R.prim_value(1),),
+                sinfo_args=(R.Tensor(dtype="int64", ndim=1),),
+            )
+            _: R.Tuple = R.call_packed(
+                "vm.builtin.check_tensor_info",
+                x,
+                R.prim_value(2),
+                R.dtype("float32"),
+                R.str(""),
+                sinfo_args=(R.Tuple,),
+            )
+            _1: R.Tuple = R.call_packed(
+                "vm.builtin.check_tuple_info",
+                params,
+                R.prim_value(2),
+                R.str(""),
+                sinfo_args=(R.Tuple,),
+            )
+            _param_1: R.Tensor((n,), dtype="float32") = params[1]
+            _2: R.Tuple = R.call_packed(
+                "vm.builtin.check_tensor_info",
+                _param_1,
+                R.prim_value(1),
+                R.dtype("float32"),
+                R.str(""),
+                sinfo_args=(R.Tuple,),
+            )
+            _3: R.Tuple = R.call_packed(
+                "vm.builtin.match_shape",
+                x,
+                shape_heap,
+                R.prim_value(2),
+                MS.ASSERT_EQUAL_TO_IMM,
+                R.prim_value(16),
+                MS.ASSERT_EQUAL_TO_IMM,
+                R.prim_value(16),
+                R.str(""),
+                sinfo_args=(R.Tuple,),
+            )
+            _4: R.Tuple = R.call_packed(
+                "vm.builtin.match_shape",
+                _param_1,
+                shape_heap,
+                MS.STORE_TO_HEAP,
+                R.prim_value(1),
+                R.prim_value(0),
+                R.str(""),
+                sinfo_args=(R.Tuple,),
+            )
+
+            param_0 = params[0]
+            param_1 = params[1]
+            return (x, param_0, param_1)
+
+    before = Before
+    after = relax.transform.VMShapeLower(emit_err_ctx=False)(before)
+    print(after)
+    expected = Expected
+    assert_structural_equal(after, expected)
+
+
 if __name__ == "__main__":
     tvm.testing.main()

Reply via email to