This is an automated email from the ASF dual-hosted git repository.

syfeng pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new 6ce16026f7 [Unity] support symbolic var in RewriteDataflowReshape (#16086)
6ce16026f7 is described below

commit 6ce16026f71800a8c242280e1d7b79f63d7bdaeb
Author: Hongyi Jin <[email protected]>
AuthorDate: Tue Nov 7 18:54:48 2023 -0800

    [Unity] support symbolic var in RewriteDataflowReshape (#16086)
    
    * rewrite dynamic shape reshape
    
    * format
    
    * fix test
---
 src/relax/analysis/tir_op_pattern_kind.cc          | 13 +++-
 .../test_transform_rewrite_dataflow_reshape.py     | 81 ++++++++++++++++++++++
 2 files changed, 92 insertions(+), 2 deletions(-)

diff --git a/src/relax/analysis/tir_op_pattern_kind.cc b/src/relax/analysis/tir_op_pattern_kind.cc
index 55b5d76213..c56f019e6b 100644
--- a/src/relax/analysis/tir_op_pattern_kind.cc
+++ b/src/relax/analysis/tir_op_pattern_kind.cc
@@ -399,8 +399,10 @@ bool HasReshapePattern(const PrimFunc& func) {
         return;
       }
 
+      Map<tir::Var, Range> var_range;
       for (const IterVar& v : block->iter_vars) {
         ana_.Bind(v->var, Range::FromMinExtent(v->dom->min, v->dom->extent));
+        var_range.Set(v->var, Range::FromMinExtent(v->dom->min, v->dom->extent));
       }
 
       // Step 1. Get the load/store pattern of the block body.
@@ -425,14 +427,21 @@ bool HasReshapePattern(const PrimFunc& func) {
       // This check requires at least one of the src/dst side is a trivial buffer
       // access (e.g., buf[ax0, ax1, ax2]).
 
-      auto f_calc_flattened_idx = [](const Buffer& buffer, const Array<PrimExpr>& indices) {
+      auto f_calc_flattened_idx = [&](const Buffer& buffer, const Array<PrimExpr>& indices) {
         ICHECK_EQ(indices.size(), buffer->shape.size());
         int ndim = indices.size();
         PrimExpr idx = 0;
         for (int i = 0; i < ndim; ++i) {
           idx = idx * buffer->shape[i] + indices[i];
         }
-        return idx;
+        idx = ana_.Simplify(idx);
+        return arith::IterMapSimplify(
+            /*indices=*/{idx},
+            /*input_iters=*/var_range,
+            /*input_pred=*/Bool(true),
+            /*check_level=*/arith::IterMapLevel::Surjective,
+            /*analyzer=*/&ana_,
+            /*simplify_trivial_iterators=*/true)[0];
       };
 
       auto f_is_trivial_indices = [block, this](const Buffer& buffer,
diff --git a/tests/python/relax/test_transform_rewrite_dataflow_reshape.py b/tests/python/relax/test_transform_rewrite_dataflow_reshape.py
index 01065bea21..26578393fe 100644
--- a/tests/python/relax/test_transform_rewrite_dataflow_reshape.py
+++ b/tests/python/relax/test_transform_rewrite_dataflow_reshape.py
@@ -221,6 +221,87 @@ def test_reshape_pattern_detect():
     tvm.ir.assert_structural_equal(mod, Expected)
 
 
+def test_reshape_dynamic_shape():
+    @tvm.script.ir_module
+    class Module:
+        @T.prim_func(private=True)
+        def reshape(var_A: T.handle, var_T_reshape: T.handle):
+            T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
+            n = T.int32()
+            A = T.match_buffer(var_A, (n, 16, 128), "float16")
+            T_reshape = T.match_buffer(var_T_reshape, (1, n, 16, 128), "float16")
+            # with T.block("root"):
+            for ax0_ax1_ax2_fused_0 in T.thread_binding(n * 2, thread="blockIdx.x"):
+                for ax0_ax1_ax2_fused_1 in T.thread_binding(1024, thread="threadIdx.x"):
+                    with T.block("T_reshape"):
+                        v0 = T.axis.spatial(
+                            n, (ax0_ax1_ax2_fused_0 * 1024 + ax0_ax1_ax2_fused_1) // 2048
+                        )
+                        v1 = T.axis.spatial(
+                            16, (ax0_ax1_ax2_fused_0 * 1024 + ax0_ax1_ax2_fused_1) % 2048 // 128
+                        )
+                        v2 = T.axis.spatial(
+                            128, (ax0_ax1_ax2_fused_0 * 1024 + ax0_ax1_ax2_fused_1) % 128
+                        )
+                        T.reads(
+                            A[((v2 // 128 + v1) // 32 + v0) % n, (v2 // 128 + v1) % 32, v2 % 128]
+                        )
+                        T.writes(T_reshape[0, v0, v1, v2])
+                        T_reshape[0, v0, v1, v2] = A[
+                            ((v2 // 128 + v1) // 32 + v0) % n, (v2 // 128 + v1) % 32, v2 % 128
+                        ]
+
+        @R.function
+        def main(x: R.Tensor((8, 3), dtype="float32")) -> R.Tensor((2, 4, 3), dtype="float32"):
+            cls = Module
+            with R.dataflow():
+                y = R.call_tir(cls.reshape, (x,), out_sinfo=R.Tensor((2, 4, 3), dtype="float32"))
+                z = R.add(y, R.const(1, "float32"))
+                R.output(z)
+            return z
+
+    @tvm.script.ir_module
+    class Expected:
+        @T.prim_func(private=True)
+        def reshape(var_A: T.handle, var_T_reshape: T.handle):
+            T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
+            n = T.int32()
+            A = T.match_buffer(var_A, (n, 16, 128), "float16")
+            T_reshape = T.match_buffer(var_T_reshape, (1, n, 16, 128), "float16")
+            # with T.block("root"):
+            for ax0_ax1_ax2_fused_0 in T.thread_binding(n * 2, thread="blockIdx.x"):
+                for ax0_ax1_ax2_fused_1 in T.thread_binding(1024, thread="threadIdx.x"):
+                    with T.block("T_reshape"):
+                        v0 = T.axis.spatial(
+                            n, (ax0_ax1_ax2_fused_0 * 1024 + ax0_ax1_ax2_fused_1) // 2048
+                        )
+                        v1 = T.axis.spatial(
+                            16, (ax0_ax1_ax2_fused_0 * 1024 + ax0_ax1_ax2_fused_1) % 2048 // 128
+                        )
+                        v2 = T.axis.spatial(
+                            128, (ax0_ax1_ax2_fused_0 * 1024 + ax0_ax1_ax2_fused_1) % 128
+                        )
+                        T.reads(
+                            A[((v2 // 128 + v1) // 32 + v0) % n, (v2 // 128 + v1) % 32, v2 % 128]
+                        )
+                        T.writes(T_reshape[0, v0, v1, v2])
+                        T_reshape[0, v0, v1, v2] = A[
+                            ((v2 // 128 + v1) // 32 + v0) % n, (v2 // 128 + v1) % 32, v2 % 128
+                        ]
+
+        @R.function
+        def main(x: R.Tensor((8, 3), dtype="float32")) -> R.Tensor((2, 4, 3), dtype="float32"):
+            with R.dataflow():
+                y: R.Tensor((2, 4, 3), dtype="float32") = R.reshape(x, R.shape([2, 4, 3]))
+                z: R.Tensor((2, 4, 3), dtype="float32") = R.add(y, R.const(1, "float32"))
+                R.output(z)
+            return z
+
+    assert relax.analysis.has_reshape_pattern(Module["reshape"])
+    mod = relax.transform.RewriteDataflowReshape()(Module)
+    tvm.ir.assert_structural_equal(mod, Expected)
+
+
 def test_reshape_non_dataflow():
     @tvm.script.ir_module
     class Module:

Reply via email to