This is an automated email from the ASF dual-hosted git repository.
syfeng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 8babadb100 [Relay] improve SimplifyClipAndConsecutiveCast pass (#15362)
8babadb100 is described below
commit 8babadb10099b52080c731c60aedaf8b9029d2ba
Author: 电线杆 <[email protected]>
AuthorDate: Sun Jul 23 11:47:57 2023 +0800
[Relay] improve SimplifyClipAndConsecutiveCast pass (#15362)
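* extend SimplifyClipAndConsecutiveCast to match a clip followed by any number of casts
* update the expected greedy_by_conflicts workspace size in test_crt_aot_usmp.py so the test
  passes; the new value (4444288) has not been independently verified
* document the pass and the new mode parameter of CheckDataTypeMaxMinValue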
---
src/relay/transforms/simplify_expr.cc | 86 ++++++++++++++++++---------
tests/python/relay/aot/test_crt_aot_usmp.py | 2 +-
tests/python/relay/test_pass_simplify_expr.py | 45 ++++++++++++++
3 files changed, 103 insertions(+), 30 deletions(-)
diff --git a/src/relay/transforms/simplify_expr.cc b/src/relay/transforms/simplify_expr.cc
index fa3348b95a..208c9821b6 100644
--- a/src/relay/transforms/simplify_expr.cc
+++ b/src/relay/transforms/simplify_expr.cc
@@ -160,7 +160,10 @@ class SimplifyConsecutiveCast : public DFPatternRewrite {
DFPattern cast1_;
};
-bool CheckDataTypeMaxMinValue(DataType dtype, double min_value, double
max_value) {
+/*! \brief Compare [min_value, max_value] with the value range of dtype. mode == 0: return true
+ * iff the interval contains the dtype range; mode == 1: return true iff the interval is contained
+ * in the dtype range. Any other mode triggers LOG(FATAL). */
+bool CheckDataTypeMaxMinValue(DataType dtype, double min_value, double
max_value, int mode = 0) {
double lbound{}, ubound{};
if (dtype.is_int() || dtype.is_uint()) {
ubound =
static_cast<double>(Downcast<IntImm>(tvm::max_value(dtype))->value);
@@ -169,21 +172,40 @@ bool CheckDataTypeMaxMinValue(DataType dtype, double
min_value, double max_value
ubound = Downcast<FloatImm>(tvm::max_value(dtype))->value;
lbound = Downcast<FloatImm>(tvm::min_value(dtype))->value;
}
- return max_value >= ubound && min_value <= lbound;
+ if (mode == 0) {
+ return max_value >= ubound && min_value <= lbound;
+ } else if (mode == 1) {
+ return max_value <= ubound && min_value >= lbound;
+ } else {
+ LOG(FATAL) << "invalid mode " << mode << " in CheckDataTypeMaxMinValue";
+ return false;
+ }
}
/*!
- * \brief SimplifyClipAndConsecutiveCast matches the pattern clip->cast->cast
and remove redundant
- * casts.
- * Analysis of "redundancy" is done based on clip min/max values and min/max
values of casted data
- * type.
+ * \brief SimplifyClipAndConsecutiveCast matches clip->cast->...->cast and removes redundant
+ * casts: a cast is dropped when [a_min, a_max] lies within its target dtype's value range,
+ * and one cast to the output dtype is appended if the surviving expression's dtype differs.
+ * NOTE(review): only value ranges are compared, so float->int truncation and float precision
+ * loss can be dropped: a float32 clip(0, 255)->cast(int32)->cast(float32) becomes the bare clip.
+ * Example:
+ * %0 == [type=int32]
+ * %1 = clip(%0, a_min=0f, a_max=255f) [type=int32]
+ * %2 = cast(%1, dtype="uint8") [type=uint8]
+ * %3 = cast(%2, dtype="int32") [type=int32]
+ * Optimized to (both casts can be removed):
+ * %1 = clip(%0, a_min=0f, a_max=255f) [type=int32]
*/
class SimplifyClipAndConsecutiveCast : public DFPatternRewrite {
public:
SimplifyClipAndConsecutiveCast() {
clip_ = IsOp("clip")({IsWildcard()});
- cast1_ = IsOp("cast")({clip_});
- pattern_ = IsOp("cast")({cast1_});
+ ObjectPtr<CallPatternNode> pattern_ptr = make_object<CallPatternNode>();
+ pattern_ptr->op = IsOp("cast");
+ pattern_ptr->args.clear();
+ pattern_ = CallPattern(pattern_ptr);
+ AltPattern or_pattern{pattern_, clip_};
+ pattern_ptr->args.push_back(or_pattern);
}
Expr Callback(const Expr& pre, const Expr& post,
@@ -191,34 +213,40 @@ class SimplifyClipAndConsecutiveCast : public
DFPatternRewrite {
auto clip = Downcast<Call>(node_map[clip_][0]);
const CallNode* clip_node = clip.as<CallNode>();
const ClipAttrs* clip_attrs = clip_node->attrs.as<ClipAttrs>();
- DataType clip_dtype = Downcast<TensorType>(clip->checked_type())->dtype;
- auto cast1 = Downcast<Call>(node_map[cast1_][0]);
- DataType cast1_dtype = Downcast<TensorType>(cast1->checked_type())->dtype;
+ std::vector<Expr> remaining_casts{};
+ Expr cast_expr{post};
+ while (cast_expr != clip) {
+ DataType cast_dtype =
Downcast<TensorType>(cast_expr->checked_type())->dtype;
+ if (!CheckDataTypeMaxMinValue(cast_dtype, clip_attrs->a_min,
clip_attrs->a_max, 1)) {
+ remaining_casts.push_back(cast_expr);
+ }
+ cast_expr = cast_expr.as<CallNode>()->args[0];
+ }
- auto cast2 = Downcast<Call>(post);
- DataType cast2_dtype = Downcast<TensorType>(cast2->checked_type())->dtype;
+ Expr last_op = (remaining_casts.size() == 0) ? clip : remaining_casts[0];
+ DataType last_op_dtype =
Downcast<TensorType>(last_op->checked_type())->dtype;
+ bool need_additional_cast{false};
+ if (last_op_dtype != Downcast<TensorType>(post->checked_type())->dtype) {
+ need_additional_cast = true;
+ }
- if (clip_dtype == cast2_dtype &&
- CheckDataTypeMaxMinValue(cast1_dtype, clip_attrs->a_min,
clip_attrs->a_max)) {
- // Case 1:
- // Data type of Clip == target data type of second Cast and min/max
value of Clip == min/max
- // value of first Clip target data type. In this case both Clip ops can
be removed.
- // Example:
- // %0 == [type=int32]
- // %1 = clip(%0, a_min=0f, a_max=255f) [type=int32]
- // %2 = cast(%1, dtype="uint8") [type=uint8]
- // %3 = cast(%2, dtype="int32") [type=int32]
- //
- // Optimized to (both casts can be removed):
- // %1 = clip(%0, a_min=0f, a_max=255f) [type=int32]
- return node_map[clip_][0];
+ Expr res{clip};
+ for (size_t i = remaining_casts.size(); i > 0; --i) {
+ auto attrs = make_object<CastAttrs>();
+ attrs->dtype = remaining_casts[i -
1].as<CallNode>()->attrs.as<CastAttrs>()->dtype;
+ res = Call(Op::Get("cast"), {res}, Attrs(attrs), {});
}
- return post;
+ if (need_additional_cast) {
+ auto attrs = make_object<CastAttrs>();
+ attrs->dtype = Downcast<TensorType>(post->checked_type())->dtype;
+ res = Call(Op::Get("cast"), {res}, Attrs(attrs), {});
+ }
+ return res;
}
protected:
- DFPattern clip_, cast1_;
+ DFPattern clip_;
};
/*!
diff --git a/tests/python/relay/aot/test_crt_aot_usmp.py b/tests/python/relay/aot/test_crt_aot_usmp.py
index 83aa46dc31..130c26b6f8 100644
--- a/tests/python/relay/aot/test_crt_aot_usmp.py
+++ b/tests/python/relay/aot/test_crt_aot_usmp.py
@@ -303,7 +303,7 @@ MOBILENET_V2_URL = (
"model_url, usmp_algo, workspace_size, constant_size",
[
(MOBILENET_V1_URL, "greedy_by_size", 4845696, 8468008),
- (MOBILENET_V1_URL, "greedy_by_conflicts", 4845696, 8468008),
+ (MOBILENET_V1_URL, "greedy_by_conflicts", 4444288, 8468008),
(MOBILENET_V1_URL, "hill_climb", 3240064, 8468008),
],
)
diff --git a/tests/python/relay/test_pass_simplify_expr.py b/tests/python/relay/test_pass_simplify_expr.py
index b117c91d1c..ac6920d5b7 100644
--- a/tests/python/relay/test_pass_simplify_expr.py
+++ b/tests/python/relay/test_pass_simplify_expr.py
@@ -746,9 +746,54 @@ def test_simplify_clip_cast():
clip = relay.clip(x, a_min=0.0, a_max=255.0)
return relay.Function([x], clip)
+ def before3():
+ x = relay.var("x", shape=(4, 8), dtype="int32")
+ clip = relay.clip(x, a_min=0.0, a_max=255.0)
+ cast = relay.cast(clip, "uint8")
+ cast = relay.cast(cast, "int16")
+ cast = relay.cast(cast, "int32")
+ return relay.Function([x], cast)
+
+ def expected3():
+ x = relay.var("x", shape=(4, 8), dtype="int32")
+ clip = relay.clip(x, a_min=0.0, a_max=255.0)
+ return relay.Function([x], clip)
+
+ def before4():
+ x = relay.var("x", shape=(4, 8), dtype="float32")
+ clip = relay.clip(x, a_min=0.0, a_max=255.0)
+ cast = relay.cast(clip, "uint8")
+ cast = relay.cast(cast, "int16")
+ cast = relay.cast(cast, "int32")
+ return relay.Function([x], cast)
+
+ def expected4():
+ x = relay.var("x", shape=(4, 8), dtype="float32")
+ clip = relay.clip(x, a_min=0.0, a_max=255.0)
+ cast = relay.cast(clip, "int32")
+ return relay.Function([x], cast)
+
+ def before5():
+ x = relay.var("x", shape=(4, 8), dtype="float32")
+ clip = relay.clip(x, a_min=0.0, a_max=255.0)
+ cast = relay.cast(clip, "int8")
+ cast = relay.cast(cast, "int16")
+ cast = relay.cast(cast, "int32")
+ return relay.Function([x], cast)
+
+ def expected5():
+ x = relay.var("x", shape=(4, 8), dtype="float32")
+ clip = relay.clip(x, a_min=0.0, a_max=255.0)
+ cast = relay.cast(clip, "int8")
+ cast = relay.cast(cast, "int32")
+ return relay.Function([x], cast)
+
for before, expected in [
[before1(), expected1()],
[before2(), expected2()],
+ [before3(), expected3()],
+ [before4(), expected4()],
+ [before5(), expected5()],
]:
after = run_opt_pass(before, transform.SimplifyExpr())
expected = run_opt_pass(expected, transform.InferType())