This is an automated email from the ASF dual-hosted git repository.
ruihangl pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new 2dc5046afc [Unity][Transform] Enhance RewriteDataflowReshape transform
(#14265)
2dc5046afc is described below
commit 2dc5046afc265bd83cd694c13f951d8c51f82fc4
Author: Bohan Hou <[email protected]>
AuthorDate: Sat Mar 18 11:49:41 2023 -0400
[Unity][Transform] Enhance RewriteDataflowReshape transform (#14265)
This PR enhances the current RewriteDataflowReshape transformation.
Originally, it used loop vars to prove the equality of the addresses of the LHS
and RHS, which caused some cases to fail due to limitations of the arith
module.
Instead, we can just use block vars to do the proof, which is supposed to
be equivalent. The resulting expressions are simpler, which allows us to cover
more cases.
---
src/relax/analysis/tir_op_pattern_kind.cc | 12 ++-
.../test_transform_rewrite_dataflow_reshape.py | 93 +++++++++++++++++++++-
2 files changed, 97 insertions(+), 8 deletions(-)
diff --git a/src/relax/analysis/tir_op_pattern_kind.cc
b/src/relax/analysis/tir_op_pattern_kind.cc
index dfa073fd9c..aed984781c 100644
--- a/src/relax/analysis/tir_op_pattern_kind.cc
+++ b/src/relax/analysis/tir_op_pattern_kind.cc
@@ -364,7 +364,6 @@ bool HasReshapePattern(const PrimFunc& func) {
if (block_iter[i]->iter_type != tir::IterVarType::kDataPar) {
return;
}
- var_map_.Set(block_iter[i]->var, iter_values[i]);
}
// Recurse into the block.
@@ -378,6 +377,10 @@ bool HasReshapePattern(const PrimFunc& func) {
return;
}
+ for (const IterVar& v : block->iter_vars) {
+ ana_.Bind(v->var, Range::FromMinExtent(v->dom->min, v->dom->extent));
+ }
+
// Step 1. Get the load/store pattern of the block body.
// To detect the reshape pattern, we require the block body to be a
// BufferStore, which has a BufferLoad as value.
@@ -409,18 +412,13 @@ bool HasReshapePattern(const PrimFunc& func) {
PrimExpr src_idx = f_calc_flattened_idx(src_buffer_,
buffer_load->indices);
PrimExpr dst_idx = f_calc_flattened_idx(dst_buffer_,
buffer_store->indices);
- // Step 4. Substitute the block iterators in the flattened index
- // with loop variables, and check if we can prove their equality.
- src_idx = tir::Substitute(std::move(src_idx), var_map_);
- dst_idx = tir::Substitute(std::move(dst_idx), var_map_);
+ // Step 4. Check if we can prove the equality of flattened indices.
if (ana_.CanProveEqual(src_idx, dst_idx)) {
this->is_reshape_ = true;
}
}
bool is_reshape_;
- /*! \brief The mapping from block vars to block binding values. */
- Map<tir::Var, PrimExpr> var_map_;
const Buffer& src_buffer_;
const Buffer& dst_buffer_;
arith::Analyzer ana_;
diff --git a/tests/python/relax/test_transform_rewrite_dataflow_reshape.py
b/tests/python/relax/test_transform_rewrite_dataflow_reshape.py
index b94c6864c4..ecf2a96064 100644
--- a/tests/python/relax/test_transform_rewrite_dataflow_reshape.py
+++ b/tests/python/relax/test_transform_rewrite_dataflow_reshape.py
@@ -132,6 +132,96 @@ def test_reshape_expand_dims():
tvm.ir.assert_structural_equal(mod, Expected)
+def test_reshape_pattern_detect():
+ # fmt: off
+ @tvm.script.ir_module
+ class Module:
+ @T.prim_func
+ def reshape(rxplaceholder: T.Buffer((T.int64(2), T.int64(4096),
T.int64(320)), "float32"), T_reshape: T.Buffer((T.int64(2), T.int64(4096),
T.int64(5), T.int64(64)), "float32")):
+ for ax0_ax1_ax2_ax3_fused_1 in T.thread_binding(T.int64(256),
thread="blockIdx.x"):
+ for ax0_ax1_ax2_ax3_fused_2 in T.thread_binding(T.int64(1024),
thread="threadIdx.x"):
+ for ax0_ax1_ax2_ax3_fused_0 in range(T.int64(10)):
+ with T.block("T_reshape"):
+ v_ax0 = T.axis.spatial(T.int64(2),
(ax0_ax1_ax2_ax3_fused_0 * T.int64(262144) + ax0_ax1_ax2_ax3_fused_1 *
T.int64(1024) + ax0_ax1_ax2_ax3_fused_2) // T.int64(1310720))
+ v_ax1 = T.axis.spatial(T.int64(4096),
(ax0_ax1_ax2_ax3_fused_0 * T.int64(262144) + ax0_ax1_ax2_ax3_fused_1 *
T.int64(1024) + ax0_ax1_ax2_ax3_fused_2) % T.int64(1310720) // T.int64(320))
+ v_ax2 = T.axis.spatial(T.int64(5),
(ax0_ax1_ax2_ax3_fused_0 * T.int64(262144) + ax0_ax1_ax2_ax3_fused_1 *
T.int64(1024) + ax0_ax1_ax2_ax3_fused_2) % T.int64(320) // T.int64(64))
+ v_ax3 = T.axis.spatial(T.int64(64),
(ax0_ax1_ax2_ax3_fused_0 * T.int64(262144) + ax0_ax1_ax2_ax3_fused_1 *
T.int64(1024) + ax0_ax1_ax2_ax3_fused_2) % T.int64(64))
+ T.reads(rxplaceholder[(((v_ax2 * T.int64(64) +
v_ax3) // T.int64(320) + v_ax1) // T.int64(4096) + v_ax0) % T.int64(2), ((v_ax2
* T.int64(64) + v_ax3) // T.int64(320) + v_ax1) % T.int64(4096), (v_ax2 *
T.int64(64) + v_ax3) % T.int64(320)])
+ T.writes(T_reshape[v_ax0, v_ax1, v_ax2, v_ax3])
+ T_reshape[v_ax0, v_ax1, v_ax2, v_ax3] =
rxplaceholder[(((v_ax2 * T.int64(64) + v_ax3) // T.int64(320) + v_ax1) //
T.int64(4096) + v_ax0) % T.int64(2), ((v_ax2 * T.int64(64) + v_ax3) //
T.int64(320) + v_ax1) % T.int64(4096), (v_ax2 * T.int64(64) + v_ax3) %
T.int64(320)]
+
+ @T.prim_func
+ def expand_dims(
+ rxplaceholder: T.Buffer((T.int64(2), T.int64(4096), T.int64(5),
T.int64(64)), "float32"),
+ expand_dims: T.Buffer(
+ (T.int64(2), T.int64(1), T.int64(4096), T.int64(1),
T.int64(5), T.int64(64)),
+ "float32",
+ ),
+ ):
+ for i0, i1, i2, i3, i4, i5 in T.grid(
+ T.int64(2), T.int64(1), T.int64(4096), T.int64(1), T.int64(5),
T.int64(64)
+ ):
+ with T.block("expand_dims"):
+ i0_1, i1_1, i2_1, i3_1, i4_1, i5_1 =
T.axis.remap("SSSSSS", [i0, i1, i2, i3, i4, i5])
+ T.reads(rxplaceholder[i0_1, i2_1, i4_1, i5_1])
+ T.writes(expand_dims[i0_1, i1_1, i2_1, i3_1, i4_1, i5_1])
+ expand_dims[i0_1, i1_1, i2_1, i3_1, i4_1, i5_1] =
rxplaceholder[i0_1, i2_1, i4_1, i5_1]
+
+ @R.function
+ def main(
+ x: R.Tensor((2, 4096, 320), dtype="float32")
+ ) -> R.Tensor((2, 1, 4096, 1, 5, 64), dtype="float32"):
+ cls = Module
+ with R.dataflow():
+ y = R.call_tir(cls.reshape, (x,), out_sinfo=R.Tensor((2, 4096,
5, 64), dtype="float32"))
+ z = R.call_tir(
+ cls.expand_dims, (y,), out_sinfo=R.Tensor((2, 1, 4096, 1,
5, 64), "float32")
+ )
+ R.output(z)
+ return z
+
+ @tvm.script.ir_module
+ class Expected:
+ @T.prim_func
+ def expand_dims(rxplaceholder: T.Buffer((T.int64(2), T.int64(4096),
T.int64(5), T.int64(64)), "float32"), expand_dims_1: T.Buffer((T.int64(2),
T.int64(1), T.int64(4096), T.int64(1), T.int64(5), T.int64(64)), "float32")):
+ # with T.block("root"):
+ for i0, i1, i2, i3, i4, i5 in T.grid(T.int64(2), T.int64(1),
T.int64(4096), T.int64(1), T.int64(5), T.int64(64)):
+ with T.block("expand_dims"):
+ i0_1, i1_1, i2_1, i3_1, i4_1, i5_1 =
T.axis.remap("SSSSSS", [i0, i1, i2, i3, i4, i5])
+ T.reads(rxplaceholder[i0_1, i2_1, i4_1, i5_1])
+ T.writes(expand_dims_1[i0_1, i1_1, i2_1, i3_1, i4_1, i5_1])
+ expand_dims_1[i0_1, i1_1, i2_1, i3_1, i4_1, i5_1] =
rxplaceholder[i0_1, i2_1, i4_1, i5_1]
+
+ @T.prim_func
+ def reshape(rxplaceholder: T.Buffer((T.int64(2), T.int64(4096),
T.int64(320)), "float32"), T_reshape: T.Buffer((T.int64(2), T.int64(4096),
T.int64(5), T.int64(64)), "float32")):
+ # with T.block("root"):
+ for ax0_ax1_ax2_ax3_fused_1 in T.thread_binding(T.int64(256),
thread="blockIdx.x"):
+ for ax0_ax1_ax2_ax3_fused_2 in T.thread_binding(T.int64(1024),
thread="threadIdx.x"):
+ for ax0_ax1_ax2_ax3_fused_0 in range(T.int64(10)):
+ with T.block("T_reshape"):
+ v_ax0 = T.axis.spatial(T.int64(2),
(ax0_ax1_ax2_ax3_fused_0 * T.int64(262144) + ax0_ax1_ax2_ax3_fused_1 *
T.int64(1024) + ax0_ax1_ax2_ax3_fused_2) // T.int64(1310720))
+ v_ax1 = T.axis.spatial(T.int64(4096),
(ax0_ax1_ax2_ax3_fused_0 * T.int64(262144) + ax0_ax1_ax2_ax3_fused_1 *
T.int64(1024) + ax0_ax1_ax2_ax3_fused_2) % T.int64(1310720) // T.int64(320))
+ v_ax2 = T.axis.spatial(T.int64(5),
(ax0_ax1_ax2_ax3_fused_0 * T.int64(262144) + ax0_ax1_ax2_ax3_fused_1 *
T.int64(1024) + ax0_ax1_ax2_ax3_fused_2) % T.int64(320) // T.int64(64))
+ v_ax3 = T.axis.spatial(T.int64(64),
(ax0_ax1_ax2_ax3_fused_0 * T.int64(262144) + ax0_ax1_ax2_ax3_fused_1 *
T.int64(1024) + ax0_ax1_ax2_ax3_fused_2) % T.int64(64))
+ T.reads(rxplaceholder[(((v_ax2 * T.int64(64) +
v_ax3) // T.int64(320) + v_ax1) // T.int64(4096) + v_ax0) % T.int64(2), ((v_ax2
* T.int64(64) + v_ax3) // T.int64(320) + v_ax1) % T.int64(4096), (v_ax2 *
T.int64(64) + v_ax3) % T.int64(320)])
+ T.writes(T_reshape[v_ax0, v_ax1, v_ax2, v_ax3])
+ T_reshape[v_ax0, v_ax1, v_ax2, v_ax3] =
rxplaceholder[(((v_ax2 * T.int64(64) + v_ax3) // T.int64(320) + v_ax1) //
T.int64(4096) + v_ax0) % T.int64(2), ((v_ax2 * T.int64(64) + v_ax3) //
T.int64(320) + v_ax1) % T.int64(4096), (v_ax2 * T.int64(64) + v_ax3) %
T.int64(320)]
+
+ @R.function
+ def main(x: R.Tensor((2, 4096, 320), dtype="float32")) -> R.Tensor((2,
1, 4096, 1, 5, 64), dtype="float32"):
+ cls = Expected
+ with R.dataflow():
+ y: R.Tensor((2, 4096, 5, 64), dtype="float32") = R.reshape(x,
R.shape([2, 4096, 5, 64]))
+ z = R.call_tir(cls.expand_dims, (y,), out_sinfo=R.Tensor((2,
1, 4096, 1, 5, 64), dtype="float32"))
+ R.output(z)
+ return z
+ # fmt: on
+
+ assert relax.analysis.has_reshape_pattern(Module["reshape"])
+ mod = relax.transform.RewriteDataflowReshape()(Module)
+ tvm.ir.assert_structural_equal(mod, Expected)
+
+
def test_reshape_non_dataflow():
@tvm.script.ir_module
class Module:
@@ -168,4 +258,5 @@ def test_reshape_non_dataflow():
if __name__ == "__main__":
- tvm.testing.main()
+ test_reshape_pattern_detect()
+ # tvm.testing.main()