This is an automated email from the ASF dual-hosted git repository.

syfeng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 8babadb100 [Relay] improve SimplifyClipAndConsecutiveCast pass (#15362)
8babadb100 is described below

commit 8babadb10099b52080c731c60aedaf8b9029d2ba
Author: 电线杆 <[email protected]>
AuthorDate: Sun Jul 23 11:47:57 2023 +0800

    [Relay] improve SimplifyClipAndConsecutiveCast pass (#15362)
    
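    * Generalize SimplifyClipAndConsecutiveCast from clip->cast->cast to clip->cast->...->cast chains
    
    * Update the expected workspace size of the greedy_by_conflicts case in test_crt_aot_usmp.py
      just to make the test pass. Not sure whether the new value is correct.
    
    * Add documentation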
---
 src/relay/transforms/simplify_expr.cc         | 86 ++++++++++++++++++---------
 tests/python/relay/aot/test_crt_aot_usmp.py   |  2 +-
 tests/python/relay/test_pass_simplify_expr.py | 45 ++++++++++++++
 3 files changed, 103 insertions(+), 30 deletions(-)

diff --git a/src/relay/transforms/simplify_expr.cc 
b/src/relay/transforms/simplify_expr.cc
index fa3348b95a..208c9821b6 100644
--- a/src/relay/transforms/simplify_expr.cc
+++ b/src/relay/transforms/simplify_expr.cc
@@ -160,7 +160,10 @@ class SimplifyConsecutiveCast : public DFPatternRewrite {
   DFPattern cast1_;
 };
 
-bool CheckDataTypeMaxMinValue(DataType dtype, double min_value, double 
max_value) {
+/*! If mode == 0, return true if the interval [min_value, max_value] contains the range of dtype,
+ * and return false otherwise. If mode == 1, return true if the interval [min_value, max_value] is
+ * contained in the range of dtype, and return false otherwise. */
+bool CheckDataTypeMaxMinValue(DataType dtype, double min_value, double 
max_value, int mode = 0) {
   double lbound{}, ubound{};
   if (dtype.is_int() || dtype.is_uint()) {
     ubound = 
static_cast<double>(Downcast<IntImm>(tvm::max_value(dtype))->value);
@@ -169,21 +172,40 @@ bool CheckDataTypeMaxMinValue(DataType dtype, double 
min_value, double max_value
     ubound = Downcast<FloatImm>(tvm::max_value(dtype))->value;
     lbound = Downcast<FloatImm>(tvm::min_value(dtype))->value;
   }
-  return max_value >= ubound && min_value <= lbound;
+  if (mode == 0) {
+    return max_value >= ubound && min_value <= lbound;
+  } else if (mode == 1) {
+    return max_value <= ubound && min_value >= lbound;
+  } else {
+    LOG(FATAL) << "invalid mode " << mode << " in CheckDataTypeMaxMinValue";
+    return false;
+  }
 }
 
 /*!
- * \brief SimplifyClipAndConsecutiveCast matches the pattern clip->cast->cast 
and remove redundant
- *   casts.
- * Analysis of "redundancy" is done based on clip min/max values and min/max 
values of casted data
- * type.
+ * \brief SimplifyClipAndConsecutiveCast matches the pattern clip->cast->...->cast and removes
+ * redundant casts. A cast is treated as redundant when the clip [a_min, a_max] interval fits in its
+ * target dtype's range; fractional truncation (e.g. float->int) is not taken into account.
+ *
+ * Example:
+ *   %0 == [type=int32]
+ *   %1 = clip(%0, a_min=0f, a_max=255f) [type=int32]
+ *   %2 = cast(%1, dtype="uint8") [type=uint8]
+ *   %3 = cast(%2, dtype="int32") [type=int32]
+ *
+ * Optimized to (both casts can be removed):
+ *   %1 = clip(%0, a_min=0f, a_max=255f) [type=int32]
  */
 class SimplifyClipAndConsecutiveCast : public DFPatternRewrite {
  public:
   SimplifyClipAndConsecutiveCast() {
     clip_ = IsOp("clip")({IsWildcard()});
-    cast1_ = IsOp("cast")({clip_});
-    pattern_ = IsOp("cast")({cast1_});
+    ObjectPtr<CallPatternNode> pattern_ptr = make_object<CallPatternNode>();
+    pattern_ptr->op = IsOp("cast");
+    pattern_ptr->args.clear();
+    pattern_ = CallPattern(pattern_ptr);
+    AltPattern or_pattern{pattern_, clip_};
+    pattern_ptr->args.push_back(or_pattern);
   }
 
   Expr Callback(const Expr& pre, const Expr& post,
@@ -191,34 +213,40 @@ class SimplifyClipAndConsecutiveCast : public 
DFPatternRewrite {
     auto clip = Downcast<Call>(node_map[clip_][0]);
     const CallNode* clip_node = clip.as<CallNode>();
     const ClipAttrs* clip_attrs = clip_node->attrs.as<ClipAttrs>();
-    DataType clip_dtype = Downcast<TensorType>(clip->checked_type())->dtype;
 
-    auto cast1 = Downcast<Call>(node_map[cast1_][0]);
-    DataType cast1_dtype = Downcast<TensorType>(cast1->checked_type())->dtype;
+    std::vector<Expr> remaining_casts{};
+    Expr cast_expr{post};
+    while (cast_expr != clip) {
+      DataType cast_dtype = 
Downcast<TensorType>(cast_expr->checked_type())->dtype;
+      if (!CheckDataTypeMaxMinValue(cast_dtype, clip_attrs->a_min, 
clip_attrs->a_max, 1)) {
+        remaining_casts.push_back(cast_expr);
+      }
+      cast_expr = cast_expr.as<CallNode>()->args[0];
+    }
 
-    auto cast2 = Downcast<Call>(post);
-    DataType cast2_dtype = Downcast<TensorType>(cast2->checked_type())->dtype;
+    Expr last_op = (remaining_casts.size() == 0) ? clip : remaining_casts[0];
+    DataType last_op_dtype = 
Downcast<TensorType>(last_op->checked_type())->dtype;
+    bool need_additional_cast{false};
+    if (last_op_dtype != Downcast<TensorType>(post->checked_type())->dtype) {
+      need_additional_cast = true;
+    }
 
-    if (clip_dtype == cast2_dtype &&
-        CheckDataTypeMaxMinValue(cast1_dtype, clip_attrs->a_min, 
clip_attrs->a_max)) {
-      // Case 1:
-      // Data type of Clip == target data type of second Cast and min/max 
value of Clip == min/max
-      // value of first Clip target data type. In this case both Clip ops can 
be removed.
-      // Example:
-      //   %0 == [type=int32]
-      //   %1 = clip(%0, a_min=0f, a_max=255f) [type=int32]
-      //   %2 = cast(%1, dtype="uint8") [type=uint8]
-      //   %3 = cast(%2, dtype="int32") [type=int32]
-      //
-      // Optimized to (both casts can be removed):
-      //   %1 = clip(%0, a_min=0f, a_max=255f) [type=int32]
-      return node_map[clip_][0];
+    Expr res{clip};
+    for (size_t i = remaining_casts.size(); i > 0; --i) {
+      auto attrs = make_object<CastAttrs>();
+      attrs->dtype = remaining_casts[i - 
1].as<CallNode>()->attrs.as<CastAttrs>()->dtype;
+      res = Call(Op::Get("cast"), {res}, Attrs(attrs), {});
     }
-    return post;
+    if (need_additional_cast) {
+      auto attrs = make_object<CastAttrs>();
+      attrs->dtype = Downcast<TensorType>(post->checked_type())->dtype;
+      res = Call(Op::Get("cast"), {res}, Attrs(attrs), {});
+    }
+    return res;
   }
 
  protected:
-  DFPattern clip_, cast1_;
+  DFPattern clip_;
 };
 
 /*!
diff --git a/tests/python/relay/aot/test_crt_aot_usmp.py 
b/tests/python/relay/aot/test_crt_aot_usmp.py
index 83aa46dc31..130c26b6f8 100644
--- a/tests/python/relay/aot/test_crt_aot_usmp.py
+++ b/tests/python/relay/aot/test_crt_aot_usmp.py
@@ -303,7 +303,7 @@ MOBILENET_V2_URL = (
     "model_url, usmp_algo, workspace_size, constant_size",
     [
         (MOBILENET_V1_URL, "greedy_by_size", 4845696, 8468008),
-        (MOBILENET_V1_URL, "greedy_by_conflicts", 4845696, 8468008),
+        (MOBILENET_V1_URL, "greedy_by_conflicts", 4444288, 8468008),
         (MOBILENET_V1_URL, "hill_climb", 3240064, 8468008),
     ],
 )
diff --git a/tests/python/relay/test_pass_simplify_expr.py 
b/tests/python/relay/test_pass_simplify_expr.py
index b117c91d1c..ac6920d5b7 100644
--- a/tests/python/relay/test_pass_simplify_expr.py
+++ b/tests/python/relay/test_pass_simplify_expr.py
@@ -746,9 +746,54 @@ def test_simplify_clip_cast():
         clip = relay.clip(x, a_min=0.0, a_max=255.0)
         return relay.Function([x], clip)
 
+    def before3():
+        x = relay.var("x", shape=(4, 8), dtype="int32")
+        clip = relay.clip(x, a_min=0.0, a_max=255.0)
+        cast = relay.cast(clip, "uint8")
+        cast = relay.cast(cast, "int16")
+        cast = relay.cast(cast, "int32")
+        return relay.Function([x], cast)
+
+    def expected3():
+        x = relay.var("x", shape=(4, 8), dtype="int32")
+        clip = relay.clip(x, a_min=0.0, a_max=255.0)
+        return relay.Function([x], clip)
+
+    def before4():
+        x = relay.var("x", shape=(4, 8), dtype="float32")
+        clip = relay.clip(x, a_min=0.0, a_max=255.0)
+        cast = relay.cast(clip, "uint8")
+        cast = relay.cast(cast, "int16")
+        cast = relay.cast(cast, "int32")
+        return relay.Function([x], cast)
+
+    def expected4():
+        x = relay.var("x", shape=(4, 8), dtype="float32")
+        clip = relay.clip(x, a_min=0.0, a_max=255.0)
+        cast = relay.cast(clip, "int32")
+        return relay.Function([x], cast)
+
+    def before5():
+        x = relay.var("x", shape=(4, 8), dtype="float32")
+        clip = relay.clip(x, a_min=0.0, a_max=255.0)
+        cast = relay.cast(clip, "int8")
+        cast = relay.cast(cast, "int16")
+        cast = relay.cast(cast, "int32")
+        return relay.Function([x], cast)
+
+    def expected5():
+        x = relay.var("x", shape=(4, 8), dtype="float32")
+        clip = relay.clip(x, a_min=0.0, a_max=255.0)
+        cast = relay.cast(clip, "int8")
+        cast = relay.cast(cast, "int32")
+        return relay.Function([x], cast)
+
     for before, expected in [
         [before1(), expected1()],
         [before2(), expected2()],
+        [before3(), expected3()],
+        [before4(), expected4()],
+        [before5(), expected5()],
     ]:
         after = run_opt_pass(before, transform.SimplifyExpr())
         expected = run_opt_pass(expected, transform.InferType())

Reply via email to