This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new c7200b7573 [Unity][Pass] Enhance constant folding to fold relax ops by 
evaluating them. (#14146)
c7200b7573 is described below

commit c7200b7573b8ac462cf7a63cecf1715ee4eecbc3
Author: Prakalp Srivastava <[email protected]>
AuthorDate: Tue Feb 28 16:12:48 2023 -0500

    [Unity][Pass] Enhance constant folding to fold relax ops by evaluating 
them. (#14146)
    
    * [Unity][Pass] Enhance constant folding to fold relax ops
    by evaluating them.
    
    This uses the registered legalization function attached to
    the op to lower it to call_tir and uses the existing call_tir
    folding mechanism to fold it.
    
    This kind of op folding is only allowed within a dataflow block,
    as ops could have side effects.
    
    Limitations:
    * This currently does not support folding ops
    that could lower to multiple call_tir bindings.
    * Folding by evaluating ops is not always beneficial.
    We need a heuristic to check if it is useful. This is
    not implemented yet, so folding by evaluating
    expressions is currently always allowed.
    
    * fix ci error
    
    * fix doc
    
    * fix bug
---
 src/relax/transform/fold_constant.cc               |  56 +++++++++--
 tests/python/relax/test_transform_fold_constant.py | 103 +++++++++++++++++++++
 2 files changed, 150 insertions(+), 9 deletions(-)

diff --git a/src/relax/transform/fold_constant.cc 
b/src/relax/transform/fold_constant.cc
index 87b022c8ae..6b28f31889 100644
--- a/src/relax/transform/fold_constant.cc
+++ b/src/relax/transform/fold_constant.cc
@@ -21,6 +21,7 @@
 #include <tvm/ir/function.h>
 #include <tvm/relax/analysis.h>
 #include <tvm/relax/expr_functor.h>
+#include <tvm/relax/op_attr_types.h>
 #include <tvm/relax/transform.h>
 #include <tvm/relax/type.h>
 #include <tvm/tir/function.h>
@@ -38,7 +39,7 @@ class ConstantFolder : public ExprMutator {
   }
 
  private:
-  explicit ConstantFolder(IRModule ctx_module) : ctx_module_(ctx_module) {}
+  explicit ConstantFolder(IRModule ctx_module) : ExprMutator(ctx_module) {}
 
   /*!
    * \brief Pattern match the shape inside the given struct info to a
@@ -88,7 +89,8 @@ class ConstantFolder : public ExprMutator {
   Optional<tir::PrimFunc> MatchPrimFunc(const Expr& op) {
     if (auto* ptr = op.as<GlobalVarNode>()) {
       // NOTE: as check works for nullptr(returns null)
-      Optional<BaseFunc> base_func = 
ctx_module_->functions.Get(GetRef<GlobalVar>(ptr));
+      Optional<BaseFunc> base_func =
+          
builder_->GetContextIRModule()->functions.Get(GetRef<GlobalVar>(ptr));
       if (auto* pfunc = base_func.as<tir::PrimFuncNode>()) {
         return GetRef<tir::PrimFunc>(pfunc);
       }
@@ -127,6 +129,19 @@ class ConstantFolder : public ExprMutator {
     return build_func;
   }
 
+  /*!
+   * \brief Checks if it is useful to fold \p expr.
+   * \details Folding an expr is a trade-off - we are materializing a constant 
in the IRModule and
+   * paying compile time cost to avoid the cost of executing this expr at 
runtime. For example,
+   * folding iota ops could result in large constants being materialized, thus 
increasing the size
+   * of the program.
+   */
+  bool ShouldBeFolded(Expr expr) {
+    // TODO(prakalp): Implement a heuristic to check if folding this expr is 
actually useful or
+    // not.
+    return true;
+  }
+
   // Try constant evaluate the function call
   // if failed return NullOpt
   Optional<Expr> ConstEvaluateCallTIR(tir::PrimFunc tir_func, 
Array<runtime::NDArray> arr_args,
@@ -159,7 +174,8 @@ class ConstantFolder : public ExprMutator {
     return Constant(ret_tensor);
   }
 
-  Expr VisitCallTIR(Call call) {
+  // Returns the folded expr if the call is successfully folded to constant, 
otherwise null.
+  Optional<Expr> VisitCallTIR(Call call) {
     // call_tir needs to have at least three arguments
     ICHECK_GE(call->args.size(), 2);
     Optional<tir::PrimFunc> func = MatchPrimFunc(call->args[0]);
@@ -174,10 +190,10 @@ class ConstantFolder : public ExprMutator {
       DynTensorType ret_type = Downcast<DynTensorType>(call->checked_type());
       // value_or will return value if it is not null, otherwise return or
       return ConstEvaluateCallTIR(func.value(), arr_args.value(), 
shape.value(), ret_type->dtype)
-          .value_or(call);
+          .value_or({});
     }
     // TODO(hongyi): support const-fold tuple outputs
-    return std::move(call);
+    return {};
   }
 
   using ExprMutator::VisitExpr_;
@@ -185,11 +201,35 @@ class ConstantFolder : public ExprMutator {
   Expr VisitExpr_(const CallNode* call) final {
     // post-order mutation
     Call post_call = Downcast<Call>(VisitExprPostOrder_(call));
+
+    // Check if it is useful to fold this call
+    if (!ShouldBeFolded(post_call)) return post_call;
+
     static const Op& call_tir_op = Op::Get("relax.call_tir");
+    static const auto& legalize_map = Op::GetAttrMap<FLegalize>("FLegalize");
+    auto* op_node = post_call->op.as<OpNode>();
 
-    if (call->op.same_as(call_tir_op)) {
-      return VisitCallTIR(post_call);
+    // Not an OpNode
+    if (op_node == nullptr) {
+      return post_call;
     }
+    auto op = GetRef<Op>(op_node);
+
+    if (op.same_as(call_tir_op)) {
+      return VisitCallTIR(post_call).value_or(post_call);
+    }
+
+    // If we are in a dataflow block, we can fold ops by lowering them to 
call_tir.
+    if (builder_->CurrentBlockIsDataFlow() && legalize_map.count(op)) {
+      // Get the legalized expression
+      Expr legalized_expr = builder_->Normalize(legalize_map[op](builder_, 
post_call));
+      // If the legalized expression is call_tir, try to fold it.
+      const CallNode* call = legalized_expr.as<CallNode>();
+      if (call && call->op.same_as(call_tir_op)) {
+        return VisitCallTIR(GetRef<Call>(call)).value_or(post_call);
+      }
+    }
+
     return std::move(post_call);
   }
 
@@ -211,8 +251,6 @@ class ConstantFolder : public ExprMutator {
     return ExprMutator::VisitExpr_(op);
   }
 
-  // the context module to lookup functions
-  IRModule ctx_module_;
   // cache for function build, via structural equality
   std::unordered_map<tir::PrimFunc, Optional<runtime::PackedFunc>, 
StructuralHash, StructuralEqual>
       func_build_cache_;
diff --git a/tests/python/relax/test_transform_fold_constant.py 
b/tests/python/relax/test_transform_fold_constant.py
index da0816ef80..8a4f2599df 100644
--- a/tests/python/relax/test_transform_fold_constant.py
+++ b/tests/python/relax/test_transform_fold_constant.py
@@ -273,5 +273,108 @@ def test_int32_fold():
     tvm.ir.assert_structural_equal(after, expected)
 
 
+def test_fold_single_relax_op():
+    # put before after in a single module
+    @tvm.script.ir_module
+    class Module:
+        @R.function
+        def before(c0: R.Tensor((16, 16), "float32")):
+            with R.dataflow():
+                gv = R.add(c0, c0)
+                R.output(gv)
+            return gv
+
+        @R.function
+        def expected(c1: R.Tensor((16, 16), "float32")):
+            return c1
+
+    c0_np = np.arange((16 * 16)).astype("float32").reshape(16, 16)
+    c1_np = c0_np + c0_np
+    before = gen_mod(Module, "before", {"c0": c0_np})
+    expected = gen_mod(Module, "expected", {"c1": c1_np})
+
+    after = relax.transform.FoldConstant()(before)
+    tvm.ir.assert_structural_equal(after, expected)
+
+
+def test_fold_multiple_relax_ops():
+    # put before after in a single module
+    @tvm.script.ir_module
+    class Module:
+        @R.function
+        def before(c0: R.Tensor((16, 16), "float32"), c1: R.Tensor((16, 16), 
"float32")):
+            with R.dataflow():
+                lv0 = R.add(c0, c1)
+                lv1 = R.multiply(c0, lv0)
+                gv = R.subtract(lv1, c1)
+                R.output(gv)
+            return gv
+
+        @R.function
+        def expected(c4: R.Tensor((16, 16), "float32")):
+            return c4
+
+    c0_np = np.arange((16 * 16)).astype("float32").reshape(16, 16)
+    c1_np = np.arange((16 * 16)).astype("float32").reshape(16, 16)
+    c2_np = c0_np + c1_np
+    c3_np = c0_np * c2_np
+    c4_np = c3_np - c1_np
+    before = gen_mod(Module, "before", {"c0": c0_np, "c1": c1_np})
+    expected = gen_mod(Module, "expected", {"c4": c4_np})
+
+    after = relax.transform.FoldConstant()(before)
+    tvm.ir.assert_structural_equal(after, expected)
+
+
+def test_do_not_fold_ops_outside_dataflow():
+    # put before after in a single module
+    @tvm.script.ir_module
+    class Module:
+        @R.function
+        def before(c0: R.Tensor((16, 16), "float32")):
+            gv = R.add(c0, c0)
+            return gv
+
+    c0_np = np.arange((16 * 16)).astype("float32").reshape(16, 16)
+    before = gen_mod(Module, "before", {"c0": c0_np})
+
+    after = relax.transform.FoldConstant()(before)
+    tvm.ir.assert_structural_equal(after, before)
+
+
+def test_unsupported_fold_ops_legalized_to_multiple_calls():
+    @tvm.script.ir_module
+    class Module:
+        @R.function
+        def before(c0: R.Tensor((16, 16), "float32")):
+            with R.dataflow():
+                gv = R.nn.relu(c0)
+                R.output(gv)
+            return gv
+
+    c0_np = np.arange((16 * 16)).astype("float32").reshape(16, 16)
+    before = gen_mod(Module, "before", {"c0": c0_np})
+
+    from tvm.relax.transform.legalize_ops.common import register_legalize
+
+    def customized_legalize_relu(bb: relax.BlockBuilder, call: relax.Call):
+        from tvm import topi  # pylint: disable=import-outside-toplevel
+
+        x = bb.emit_te(topi.nn.relu, *call.args)
+        return bb.call_te(topi.identity, x)
+
+    # register custom legalization for relu that emits multiple bindings for 
testing
+    relu_legalize = tvm.ir.Op.get("relax.nn.relu").get_attr("FLegalize")
+    tvm.ir.Op.get("relax.nn.relu").reset_attr("FLegalize")
+    register_legalize("relax.nn.relu", customized_legalize_relu)
+
+    after = relax.transform.FoldConstant()(before)
+    tvm.ir.assert_structural_equal(after, before)
+
+    # revert to correct legalization of relu
+    tvm.ir.Op.get("relax.nn.relu").reset_attr("FLegalize")
+    register_legalize("relax.nn.relu", relu_legalize)
+
+
 if __name__ == "__main__":
     tvm.testing.main()

Reply via email to