This is an automated email from the ASF dual-hosted git repository.
masahi pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new c7200b7573 [Unity][Pass] Enhance constant folding to fold relax ops by
evaluating them. (#14146)
c7200b7573 is described below
commit c7200b7573b8ac462cf7a63cecf1715ee4eecbc3
Author: Prakalp Srivastava <[email protected]>
AuthorDate: Tue Feb 28 16:12:48 2023 -0500
[Unity][Pass] Enhance constant folding to fold relax ops by evaluating
them. (#14146)
* [Unity][Pass] Enhance constant folding to fold relax ops
by evaluating them.
This uses the registered legalization function attached to
the op to lower it to call_tir and uses the existing call_tir
folding mechanism to fold it.
This kind of op folding is only allowed within a dataflow block,
since ops outside dataflow blocks could have side effects.
Limitations:
* This currently does not support folding ops
that could lower to multiple call_tir bindings.
* Folding by evaluating ops is not always beneficial.
We need a heuristic to check whether it is useful. This
heuristic is not implemented yet, so folding by evaluating
expressions is currently always allowed.
* fix ci error
* fix doc
* fix bug
---
src/relax/transform/fold_constant.cc | 56 +++++++++--
tests/python/relax/test_transform_fold_constant.py | 103 +++++++++++++++++++++
2 files changed, 150 insertions(+), 9 deletions(-)
diff --git a/src/relax/transform/fold_constant.cc
b/src/relax/transform/fold_constant.cc
index 87b022c8ae..6b28f31889 100644
--- a/src/relax/transform/fold_constant.cc
+++ b/src/relax/transform/fold_constant.cc
@@ -21,6 +21,7 @@
#include <tvm/ir/function.h>
#include <tvm/relax/analysis.h>
#include <tvm/relax/expr_functor.h>
+#include <tvm/relax/op_attr_types.h>
#include <tvm/relax/transform.h>
#include <tvm/relax/type.h>
#include <tvm/tir/function.h>
@@ -38,7 +39,7 @@ class ConstantFolder : public ExprMutator {
}
private:
- explicit ConstantFolder(IRModule ctx_module) : ctx_module_(ctx_module) {}
+ explicit ConstantFolder(IRModule ctx_module) : ExprMutator(ctx_module) {}
/*!
* \brief Pattern match the shape inside the given struct info to a
@@ -88,7 +89,8 @@ class ConstantFolder : public ExprMutator {
Optional<tir::PrimFunc> MatchPrimFunc(const Expr& op) {
if (auto* ptr = op.as<GlobalVarNode>()) {
// NOTE: as check works for nullptr(returns null)
- Optional<BaseFunc> base_func =
ctx_module_->functions.Get(GetRef<GlobalVar>(ptr));
+ Optional<BaseFunc> base_func =
+
builder_->GetContextIRModule()->functions.Get(GetRef<GlobalVar>(ptr));
if (auto* pfunc = base_func.as<tir::PrimFuncNode>()) {
return GetRef<tir::PrimFunc>(pfunc);
}
@@ -127,6 +129,19 @@ class ConstantFolder : public ExprMutator {
return build_func;
}
+ /*!
+ * \brief Checks if it is useful to fold \p expr.
+ * \details Folding an expr is a trade-off - we are materializing a constant
in the IRModule and
+ * paying compile time cost to avoid the cost of executing this expr at
runtime. For example,
+ * folding iota ops could result in large constants being materialized, thus
increasing the size
+ * of the program.
+ */
+ bool ShouldBeFolded(Expr expr) {
+ // TODO(prakalp): Implement a heuristic to check if folding this expr is
actually useful or
+ // not.
+ return true;
+ }
+
// Try constant evaluate the function call
// if failed return NullOpt
Optional<Expr> ConstEvaluateCallTIR(tir::PrimFunc tir_func,
Array<runtime::NDArray> arr_args,
@@ -159,7 +174,8 @@ class ConstantFolder : public ExprMutator {
return Constant(ret_tensor);
}
- Expr VisitCallTIR(Call call) {
+ // Returns the folded expr if the call is successfully folded to constant,
otherwise null.
+ Optional<Expr> VisitCallTIR(Call call) {
// call_tir needs to have at least three arguments
ICHECK_GE(call->args.size(), 2);
Optional<tir::PrimFunc> func = MatchPrimFunc(call->args[0]);
@@ -174,10 +190,10 @@ class ConstantFolder : public ExprMutator {
DynTensorType ret_type = Downcast<DynTensorType>(call->checked_type());
// value_or will return value if it is not null, otherwise return or
return ConstEvaluateCallTIR(func.value(), arr_args.value(),
shape.value(), ret_type->dtype)
- .value_or(call);
+ .value_or({});
}
// TODO(hongyi): support const-fold tuple outputs
- return std::move(call);
+ return {};
}
using ExprMutator::VisitExpr_;
@@ -185,11 +201,35 @@ class ConstantFolder : public ExprMutator {
Expr VisitExpr_(const CallNode* call) final {
// post-order mutation
Call post_call = Downcast<Call>(VisitExprPostOrder_(call));
+
+ // Check if it is useful to fold this call
+ if (!ShouldBeFolded(post_call)) return post_call;
+
static const Op& call_tir_op = Op::Get("relax.call_tir");
+ static const auto& legalize_map = Op::GetAttrMap<FLegalize>("FLegalize");
+ auto* op_node = post_call->op.as<OpNode>();
- if (call->op.same_as(call_tir_op)) {
- return VisitCallTIR(post_call);
+ // Not an OpNode
+ if (op_node == nullptr) {
+ return post_call;
}
+ auto op = GetRef<Op>(op_node);
+
+ if (op.same_as(call_tir_op)) {
+ return VisitCallTIR(post_call).value_or(post_call);
+ }
+
+ // If we are in a dataflow block, we can fold ops by lowering them to
call_tir.
+ if (builder_->CurrentBlockIsDataFlow() && legalize_map.count(op)) {
+ // Get the legalized expression
+ Expr legalized_expr = builder_->Normalize(legalize_map[op](builder_,
post_call));
+ // If the legalized expression is call_tir, try to fold it.
+ const CallNode* call = legalized_expr.as<CallNode>();
+ if (call && call->op.same_as(call_tir_op)) {
+ return VisitCallTIR(GetRef<Call>(call)).value_or(post_call);
+ }
+ }
+
return std::move(post_call);
}
@@ -211,8 +251,6 @@ class ConstantFolder : public ExprMutator {
return ExprMutator::VisitExpr_(op);
}
- // the context module to lookup functions
- IRModule ctx_module_;
// cache for function build, via structural equality
std::unordered_map<tir::PrimFunc, Optional<runtime::PackedFunc>,
StructuralHash, StructuralEqual>
func_build_cache_;
diff --git a/tests/python/relax/test_transform_fold_constant.py
b/tests/python/relax/test_transform_fold_constant.py
index da0816ef80..8a4f2599df 100644
--- a/tests/python/relax/test_transform_fold_constant.py
+++ b/tests/python/relax/test_transform_fold_constant.py
@@ -273,5 +273,108 @@ def test_int32_fold():
tvm.ir.assert_structural_equal(after, expected)
+def test_fold_single_relax_op():
+ # put before after in a single module
+ @tvm.script.ir_module
+ class Module:
+ @R.function
+ def before(c0: R.Tensor((16, 16), "float32")):
+ with R.dataflow():
+ gv = R.add(c0, c0)
+ R.output(gv)
+ return gv
+
+ @R.function
+ def expected(c1: R.Tensor((16, 16), "float32")):
+ return c1
+
+ c0_np = np.arange((16 * 16)).astype("float32").reshape(16, 16)
+ c1_np = c0_np + c0_np
+ before = gen_mod(Module, "before", {"c0": c0_np})
+ expected = gen_mod(Module, "expected", {"c1": c1_np})
+
+ after = relax.transform.FoldConstant()(before)
+ tvm.ir.assert_structural_equal(after, expected)
+
+
+def test_fold_multiple_relax_ops():
+ # put before after in a single module
+ @tvm.script.ir_module
+ class Module:
+ @R.function
+ def before(c0: R.Tensor((16, 16), "float32"), c1: R.Tensor((16, 16),
"float32")):
+ with R.dataflow():
+ lv0 = R.add(c0, c1)
+ lv1 = R.multiply(c0, lv0)
+ gv = R.subtract(lv1, c1)
+ R.output(gv)
+ return gv
+
+ @R.function
+ def expected(c4: R.Tensor((16, 16), "float32")):
+ return c4
+
+ c0_np = np.arange((16 * 16)).astype("float32").reshape(16, 16)
+ c1_np = np.arange((16 * 16)).astype("float32").reshape(16, 16)
+ c2_np = c0_np + c1_np
+ c3_np = c0_np * c2_np
+ c4_np = c3_np - c1_np
+ before = gen_mod(Module, "before", {"c0": c0_np, "c1": c1_np})
+ expected = gen_mod(Module, "expected", {"c4": c4_np})
+
+ after = relax.transform.FoldConstant()(before)
+ tvm.ir.assert_structural_equal(after, expected)
+
+
+def test_do_not_fold_ops_outside_dataflow():
+ # put before after in a single module
+ @tvm.script.ir_module
+ class Module:
+ @R.function
+ def before(c0: R.Tensor((16, 16), "float32")):
+ gv = R.add(c0, c0)
+ return gv
+
+ c0_np = np.arange((16 * 16)).astype("float32").reshape(16, 16)
+ before = gen_mod(Module, "before", {"c0": c0_np})
+
+ after = relax.transform.FoldConstant()(before)
+ tvm.ir.assert_structural_equal(after, before)
+
+
+def test_unsupported_fold_ops_legalized_to_multiple_calls():
+ @tvm.script.ir_module
+ class Module:
+ @R.function
+ def before(c0: R.Tensor((16, 16), "float32")):
+ with R.dataflow():
+ gv = R.nn.relu(c0)
+ R.output(gv)
+ return gv
+
+ c0_np = np.arange((16 * 16)).astype("float32").reshape(16, 16)
+ before = gen_mod(Module, "before", {"c0": c0_np})
+
+ from tvm.relax.transform.legalize_ops.common import register_legalize
+
+ def customized_legalize_relu(bb: relax.BlockBuilder, call: relax.Call):
+ from tvm import topi # pylint: disable=import-outside-toplevel
+
+ x = bb.emit_te(topi.nn.relu, *call.args)
+ return bb.call_te(topi.identity, x)
+
+ # register custom legalization for relu that emits multiple bindings for
testing
+ relu_legalize = tvm.ir.Op.get("relax.nn.relu").get_attr("FLegalize")
+ tvm.ir.Op.get("relax.nn.relu").reset_attr("FLegalize")
+ register_legalize("relax.nn.relu", customized_legalize_relu)
+
+ after = relax.transform.FoldConstant()(before)
+ tvm.ir.assert_structural_equal(after, before)
+
+ # revert to correct legalization of relu
+ tvm.ir.Op.get("relax.nn.relu").reset_attr("FLegalize")
+ register_legalize("relax.nn.relu", relu_legalize)
+
+
if __name__ == "__main__":
tvm.testing.main()