This is an automated email from the ASF dual-hosted git repository.
syfeng pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new 6ce16026f7 [Unity] support symbolic var in RewriteDataflowReshape (#16086)
6ce16026f7 is described below
commit 6ce16026f71800a8c242280e1d7b79f63d7bdaeb
Author: Hongyi Jin <[email protected]>
AuthorDate: Tue Nov 7 18:54:48 2023 -0800
[Unity] support symbolic var in RewriteDataflowReshape (#16086)
* rewrite dynamic shape reshape
* format
* fix test
---
src/relax/analysis/tir_op_pattern_kind.cc | 13 +++-
.../test_transform_rewrite_dataflow_reshape.py | 81 ++++++++++++++++++++++
2 files changed, 92 insertions(+), 2 deletions(-)
diff --git a/src/relax/analysis/tir_op_pattern_kind.cc b/src/relax/analysis/tir_op_pattern_kind.cc
index 55b5d76213..c56f019e6b 100644
--- a/src/relax/analysis/tir_op_pattern_kind.cc
+++ b/src/relax/analysis/tir_op_pattern_kind.cc
@@ -399,8 +399,10 @@ bool HasReshapePattern(const PrimFunc& func) {
return;
}
+ Map<tir::Var, Range> var_range;
for (const IterVar& v : block->iter_vars) {
ana_.Bind(v->var, Range::FromMinExtent(v->dom->min, v->dom->extent));
+ var_range.Set(v->var, Range::FromMinExtent(v->dom->min, v->dom->extent));
}
// Step 1. Get the load/store pattern of the block body.
@@ -425,14 +427,21 @@ bool HasReshapePattern(const PrimFunc& func) {
// This check requires at least one of the src/dst side is a trivial buffer
// access (e.g., buf[ax0, ax1, ax2]).
- auto f_calc_flattened_idx = [](const Buffer& buffer, const Array<PrimExpr>& indices) {
+ auto f_calc_flattened_idx = [&](const Buffer& buffer, const Array<PrimExpr>& indices) {
ICHECK_EQ(indices.size(), buffer->shape.size());
int ndim = indices.size();
PrimExpr idx = 0;
for (int i = 0; i < ndim; ++i) {
idx = idx * buffer->shape[i] + indices[i];
}
- return idx;
+ idx = ana_.Simplify(idx);
+ return arith::IterMapSimplify(
+ /*indices=*/{idx},
+ /*input_iters=*/var_range,
+ /*input_pred=*/Bool(true),
+ /*check_level=*/arith::IterMapLevel::Surjective,
+ /*analyzer=*/&ana_,
+ /*simplify_trivial_iterators=*/true)[0];
};
auto f_is_trivial_indices = [block, this](const Buffer& buffer,
diff --git a/tests/python/relax/test_transform_rewrite_dataflow_reshape.py b/tests/python/relax/test_transform_rewrite_dataflow_reshape.py
index 01065bea21..26578393fe 100644
--- a/tests/python/relax/test_transform_rewrite_dataflow_reshape.py
+++ b/tests/python/relax/test_transform_rewrite_dataflow_reshape.py
@@ -221,6 +221,87 @@ def test_reshape_pattern_detect():
tvm.ir.assert_structural_equal(mod, Expected)
+def test_reshape_dynamic_shape():
+ @tvm.script.ir_module
+ class Module:
+ @T.prim_func(private=True)
+ def reshape(var_A: T.handle, var_T_reshape: T.handle):
+ T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
+ n = T.int32()
+ A = T.match_buffer(var_A, (n, 16, 128), "float16")
+ T_reshape = T.match_buffer(var_T_reshape, (1, n, 16, 128), "float16")
+ # with T.block("root"):
+ for ax0_ax1_ax2_fused_0 in T.thread_binding(n * 2, thread="blockIdx.x"):
+ for ax0_ax1_ax2_fused_1 in T.thread_binding(1024, thread="threadIdx.x"):
+ with T.block("T_reshape"):
+ v0 = T.axis.spatial(
+ n, (ax0_ax1_ax2_fused_0 * 1024 + ax0_ax1_ax2_fused_1) // 2048
+ )
+ v1 = T.axis.spatial(
+ 16, (ax0_ax1_ax2_fused_0 * 1024 + ax0_ax1_ax2_fused_1) % 2048 // 128
+ )
+ v2 = T.axis.spatial(
+ 128, (ax0_ax1_ax2_fused_0 * 1024 + ax0_ax1_ax2_fused_1) % 128
+ )
+ T.reads(
+ A[((v2 // 128 + v1) // 32 + v0) % n, (v2 // 128 + v1) % 32, v2 % 128]
+ )
+ T.writes(T_reshape[0, v0, v1, v2])
+ T_reshape[0, v0, v1, v2] = A[
+ ((v2 // 128 + v1) // 32 + v0) % n, (v2 // 128 + v1) % 32, v2 % 128
+ ]
+
+ @R.function
+ def main(x: R.Tensor((8, 3), dtype="float32")) -> R.Tensor((2, 4, 3), dtype="float32"):
+ cls = Module
+ with R.dataflow():
+ y = R.call_tir(cls.reshape, (x,), out_sinfo=R.Tensor((2, 4, 3), dtype="float32"))
+ z = R.add(y, R.const(1, "float32"))
+ R.output(z)
+ return z
+
+
+ @tvm.script.ir_module
+ class Expected:
+ @T.prim_func(private=True)
+ def reshape(var_A: T.handle, var_T_reshape: T.handle):
+ T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
+ n = T.int32()
+ A = T.match_buffer(var_A, (n, 16, 128), "float16")
+ T_reshape = T.match_buffer(var_T_reshape, (1, n, 16, 128), "float16")
+ # with T.block("root"):
+ for ax0_ax1_ax2_fused_0 in T.thread_binding(n * 2, thread="blockIdx.x"):
+ for ax0_ax1_ax2_fused_1 in T.thread_binding(1024, thread="threadIdx.x"):
+ with T.block("T_reshape"):
+ v0 = T.axis.spatial(
+ n, (ax0_ax1_ax2_fused_0 * 1024 + ax0_ax1_ax2_fused_1) // 2048
+ )
+ v1 = T.axis.spatial(
+ 16, (ax0_ax1_ax2_fused_0 * 1024 + ax0_ax1_ax2_fused_1) % 2048 // 128
+ )
+ v2 = T.axis.spatial(
+ 128, (ax0_ax1_ax2_fused_0 * 1024 + ax0_ax1_ax2_fused_1) % 128
+ )
+ T.reads(
+ A[((v2 // 128 + v1) // 32 + v0) % n, (v2 // 128 + v1) % 32, v2 % 128]
+ )
+ T.writes(T_reshape[0, v0, v1, v2])
+ T_reshape[0, v0, v1, v2] = A[
+ ((v2 // 128 + v1) // 32 + v0) % n, (v2 // 128 + v1) % 32, v2 % 128
+ ]
+
+ @R.function
+ def main(x: R.Tensor((8, 3), dtype="float32")) -> R.Tensor((2, 4, 3), dtype="float32"):
+ with R.dataflow():
+ y: R.Tensor((2, 4, 3), dtype="float32") = R.reshape(x, R.shape([2, 4, 3]))
+ z: R.Tensor((2, 4, 3), dtype="float32") = R.add(y, R.const(1, "float32"))
+ R.output(z)
+ return z
+
+ assert relax.analysis.has_reshape_pattern(Module["reshape"])
+ mod = relax.transform.RewriteDataflowReshape()(Module)
+ tvm.ir.assert_structural_equal(mod, Expected)
+
+
def test_reshape_non_dataflow():
@tvm.script.ir_module
class Module: