This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new 121bca6027 [Unity][Relax] Make RewriteDataflowReshape only rewrite 
volume-preserving ops (#15112)
121bca6027 is described below

commit 121bca6027e1d8f7754c26e0ee7563f39625c4f8
Author: Krzysztof Parzyszek <[email protected]>
AuthorDate: Sun Jun 18 14:23:36 2023 -0500

    [Unity][Relax] Make RewriteDataflowReshape only rewrite volume-preserving 
ops (#15112)
    
    * [Unity] Make RewriteDataflowReshape only rewrite volume-preserving ops
    
    The reshape operator expects that the number of elements in the source
    is the same as the number of elements in the result. There are operators
    that could have a reshape pattern that don't meet this requirement (e.g.
    strided_slice), and they should not be converted to reshape.
    
    * Move shape verification to IsCallingTIRReshape
    
    * Replace ICHECK_EQ(used_arg_indices.size(), 1) with return
    
    Since the check for has-reshape-pattern is done after this check,
    don't abort if the check fails, just return.
---
 src/relax/transform/rewrite_dataflow_reshape.cc    | 51 +++++++++++++++++++---
 .../test_transform_rewrite_dataflow_reshape.py     | 50 ++++++++++++++++++++-
 2 files changed, 93 insertions(+), 8 deletions(-)

diff --git a/src/relax/transform/rewrite_dataflow_reshape.cc 
b/src/relax/transform/rewrite_dataflow_reshape.cc
index daef7f7409..796b761a16 100644
--- a/src/relax/transform/rewrite_dataflow_reshape.cc
+++ b/src/relax/transform/rewrite_dataflow_reshape.cc
@@ -20,12 +20,15 @@
  * \file src/relax/transform/rewrite_dataflow_reshape.cc
  * \brief Transform all reshape within dataflow block to a relax.reshape 
operator
  */
+#include <tvm/arith/analyzer.h>
 #include <tvm/relax/analysis.h>
 #include <tvm/relax/expr_functor.h>
 #include <tvm/relax/transform.h>
 #include <tvm/tir/analysis.h>
 #include <tvm/tir/function.h>
 
+#include <vector>
+
 #include "../op/tensor/manipulate.h"
 
 namespace tvm {
@@ -69,7 +72,7 @@ class DataflowReshapeRewriter : public ExprMutator {
   }
 
   Expr VisitExpr_(const CallNode* call) final {
-    if (!IsCallingTIRReshape(call)) {
+    if (call->args.size() < 2) {
       return GetRef<Call>(call);
     }
 
@@ -85,16 +88,21 @@ class DataflowReshapeRewriter : public ExprMutator {
     // can generate a fused TupleGetItem + reshape function whose input is a 
tuple. FuseTIR
     // then flattens the tuple input so that the fused TIR reshape function 
ends up having
     // multiple input buffers. But only one of them should be accessed and 
reshaped.
-    ICHECK_EQ(used_arg_indices.size(), 1);
+    if (used_arg_indices.size() != 1) {
+      return GetRef<Call>(call);
+    }
 
     auto arg = arg_tuple[used_arg_indices[0]];
 
-    TensorStructInfo res_sinfo = 
Downcast<TensorStructInfo>(call->struct_info_);
-    ICHECK(res_sinfo->shape.defined());
+    if (!IsCallingTIRReshape(call, arg)) {
+      return GetRef<Call>(call);
+    }
+
+    TensorStructInfo res_sinfo = 
Downcast<TensorStructInfo>(call->struct_info_.value());
     return reshape(arg, res_sinfo->shape.value());
   }
 
-  bool IsCallingTIRReshape(const CallNode* call) {
+  bool IsCallingTIRReshape(const CallNode* call, Expr inp) {
     static const Op& call_tir_op = Op::Get("relax.call_tir");
     if (call->op != call_tir_op) {
       return false;
@@ -102,7 +110,38 @@ class DataflowReshapeRewriter : public ExprMutator {
     const GlobalVar& global_var = Downcast<GlobalVar>(call->args[0]);
     const auto* func = mod_->functions.Get(global_var).as<tir::PrimFuncNode>();
     ICHECK_NOTNULL(func);
-    return HasReshapePattern(GetRef<tir::PrimFunc>(func));
+    if (!HasReshapePattern(GetRef<tir::PrimFunc>(func))) {
+      return false;
+    }
+
+    // The reshape operator expects that the number of elements in the source 
is the same
+    // as the number of elements in the result. There are operators that could 
have a reshape
+    // pattern that don't meet this requirement (e.g. strided_slice), and they 
should not be
+    // converted to reshape.
+    ICHECK(inp->struct_info_.defined() && call->struct_info_.defined());
+    TensorStructInfo inp_sinfo = 
Downcast<TensorStructInfo>(inp->struct_info_.value());
+    TensorStructInfo res_sinfo = 
Downcast<TensorStructInfo>(call->struct_info_.value());
+
+    if (inp_sinfo->IsUnknownDtype() || inp_sinfo->dtype != res_sinfo->dtype) {
+      return false;
+    }
+    ICHECK(inp_sinfo->shape.defined() && res_sinfo->shape.defined());
+    if (inp_sinfo->IsUnknownNdim() || res_sinfo->IsUnknownNdim()) {
+      return false;
+    }
+    auto product = [](Array<PrimExpr> args) -> PrimExpr {
+      ICHECK(!args.empty());
+      PrimExpr p = args[0];
+      for (int i = 1, e = args.size(); i < e; ++i) p *= args[i];
+      return p;
+    };
+    auto inp_count = product(inp_sinfo->GetShape().value());
+    auto res_count = product(res_sinfo->GetShape().value());
+    if (!arith::Analyzer().CanProveEqual(inp_count, res_count)) {
+      return false;
+    }
+
+    return true;
   }
 
   const IRModule& mod_;
diff --git a/tests/python/relax/test_transform_rewrite_dataflow_reshape.py 
b/tests/python/relax/test_transform_rewrite_dataflow_reshape.py
index ff802bfb56..3a3da3a7dc 100644
--- a/tests/python/relax/test_transform_rewrite_dataflow_reshape.py
+++ b/tests/python/relax/test_transform_rewrite_dataflow_reshape.py
@@ -48,8 +48,7 @@ def test_reshape_expand_dims():
         def expand_dims(
             rxplaceholder: T.Buffer((T.int64(2), T.int64(4), T.int64(3)), 
"float32"),
             expand_dims: T.Buffer(
-                (T.int64(2), T.int64(1), T.int64(4), T.int64(1), T.int64(3)),
-                "float32",
+                (T.int64(2), T.int64(1), T.int64(4), T.int64(1), T.int64(3)), 
"float32"
             ),
         ):
             for i0, i1, i2, i3, i4 in T.grid(
@@ -381,5 +380,52 @@ def test_tuple_get_reshape():
     tvm.ir.assert_structural_equal(rewritten, Expected)
 
 
+def test_invalid_reshape():
+    @tvm.script.ir_module
+    class Module:
+        # The strided_slice op has the reshape pattern, but it can take only a 
part of the input.
+        # It can't be replaced with the reshape op because reshape expects to 
preserve the "volume"
+        # of the input.
+        @T.prim_func
+        def strided_slice(
+            A: T.Buffer((T.int64(1), T.int64(1024)), "int32"),
+            T_strided_slice: T.Buffer((T.int64(1), T.int64(1000)), "int32"),
+        ):
+            T.func_attr({"tir.noalias": T.bool(True)})
+            for ax0, ax1 in T.grid(T.int64(1), T.int64(1000)):
+                with T.block("T_strided_slice"):
+                    v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
+                    T.reads(A[v_ax0, v_ax1])
+                    T.writes(T_strided_slice[v_ax0, v_ax1])
+                    T_strided_slice[v_ax0, v_ax1] = A[v_ax0, v_ax1]
+
+        @T.prim_func
+        def add_one(
+            A: T.Buffer((T.int64(1), T.int64(1000)), "int32"),
+            T_add_one: T.Buffer((T.int64(1), T.int64(1000)), "int32"),
+        ):
+            for ax0, ax1 in T.grid(T.int64(1), T.int64(1000)):
+                with T.block("T_add_one"):
+                    v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
+                    T.reads(A[v_ax0, v_ax1])
+                    T.writes(T_add_one[v_ax0, v_ax1])
+                    T_add_one[v_ax0, v_ax1] = A[v_ax0, v_ax1] + 1
+
+        @R.function
+        def main(A: R.Tensor((1, 1024), dtype="int32")) -> R.Tensor((1, 1000), 
dtype="int32"):
+            with R.dataflow():
+                cls = Module
+                S = R.call_tir(
+                    cls.strided_slice, (A,), out_sinfo=R.Tensor((1, 1000), 
dtype="int32")
+                )
+                A = R.call_tir(cls.add_one, (S,), out_sinfo=R.Tensor((1, 
1000), dtype="int32"))
+                R.output(A)
+            return A
+
+    assert relax.analysis.has_reshape_pattern(Module["strided_slice"])
+    rewritten = relax.transform.RewriteDataflowReshape()(Module)
+    tvm.ir.assert_structural_equal(rewritten, Module)
+
+
 if __name__ == "__main__":
     tvm.testing.main()

Reply via email to