This is an automated email from the ASF dual-hosted git repository.
syfeng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 0fc047c98b [Compute-inline] Prefer T.where for reverse compute-inlined block with predicate (#17128)
0fc047c98b is described below
commit 0fc047c98b1ebf730b8c9aad8b94ddac28a7b34b
Author: wrongtest <[email protected]>
AuthorDate: Fri Jul 5 11:45:12 2024 +0800
[Compute-inline] Prefer T.where for reverse compute-inlined block with predicate (#17128)
* Prefer a `T.where` block predicate over an `if` in the body when reverse-compute-inlining into a block with a predicate
* Update the unit test (UT) scripts to match the new output
---------
Co-authored-by: wrongtest <[email protected]>
---
src/tir/schedule/primitive/compute_inline.cc | 44 +++++++++-------
tests/python/dlight/test_gpu_matmul.py | 20 ++++----
tests/python/dlight/test_gpu_matmul_tensorize.py | 20 ++++----
.../test_meta_schedule_schedule_rule_mlt_tc.py | 4 +-
.../test_tir_schedule_compute_inline.py | 59 +++++++++++++++++++---
5 files changed, 98 insertions(+), 49 deletions(-)
diff --git a/src/tir/schedule/primitive/compute_inline.cc
b/src/tir/schedule/primitive/compute_inline.cc
index d6be0e5805..df74497b4a 100644
--- a/src/tir/schedule/primitive/compute_inline.cc
+++ b/src/tir/schedule/primitive/compute_inline.cc
@@ -682,11 +682,14 @@ class ReverseComputeInliner : public BaseInliner {
using BaseInliner::VisitStmt_;
/*! \brief Generate the predicate after inlining based on the consumer
predicate */
- Block BuildInlinedConsumerPredicate(const BlockNode* producer_block) {
+ BlockRealize BuildInlinedConsumerPredicate(BlockRealize
producer_block_realize) {
// Bind the producer block iter domains for simplification
Map<Var, PrimExpr> subst_map;
+ Block producer_block = producer_block_realize->block;
for (int i = 0, n = producer_block->iter_vars.size(); i < n; ++i) {
const IterVar& iter = producer_block->iter_vars[i];
+ const PrimExpr& binding = producer_block_realize->iter_values[i];
+ subst_map.Set(iter->var, binding);
analyzer_.Bind(iter->var, Range::FromMinExtent(iter->dom->min,
iter->dom->extent));
}
if (producer_block->annotations.count(tir::attr::auto_copy) != 0) {
@@ -705,30 +708,33 @@ class ReverseComputeInliner : public BaseInliner {
PrimExpr predicate = Substituter(this)(consumer_iter_in_bound_);
// Simplify the predicate using the producer block iter domains
predicate = analyzer_.Simplify(predicate);
- ObjectPtr<BlockNode> block = make_object<BlockNode>(*producer_block);
if (is_one(predicate)) {
- return Block(block);
- }
- if (const auto* if_ = producer_block->body.as<tir::IfThenElseNode>()) {
- PrimExpr if_predicate = analyzer_.Simplify(if_->condition);
- if (!StructuralEqual()(predicate, if_predicate)) {
- predicate = analyzer_.Simplify(predicate && if_->condition);
+ return producer_block_realize;
+ }
+ if (const auto* if_ = producer_block->body.as<IfThenElseNode>()) {
+ if (!if_->else_case.defined()) {
+ PrimExpr if_predicate = analyzer_.Simplify(if_->condition);
+ if (!StructuralEqual()(predicate, if_predicate)) {
+ predicate = analyzer_.Simplify(predicate && if_->condition);
+ producer_block.CopyOnWrite()->body = if_->then_case;
+ }
}
- block->body = IfThenElse(predicate, if_->then_case);
- return Block(block);
}
- block->body = IfThenElse(predicate, block->body);
- return Block(block);
+ PrimExpr outer_predicate = Substitute(predicate, subst_map);
+ auto n = producer_block_realize.CopyOnWrite();
+ n->block = producer_block;
+ n->predicate = analyzer_.Simplify(outer_predicate);
+ return GetRef<BlockRealize>(n);
}
- Stmt VisitStmt_(const BlockNode* op) final {
- Block src_block = GetRef<Block>(op);
- Block tgt_block = Downcast<Block>(BaseInliner::VisitStmt_(op));
- if (op == producer_block_) {
- tgt_block = BuildInlinedConsumerPredicate(tgt_block.get());
- block_reuse.Set(src_block, tgt_block);
+ Stmt VisitStmt_(const BlockRealizeNode* op) final {
+ Block src_block = op->block;
+ BlockRealize tgt_block_realize =
Downcast<BlockRealize>(StmtMutator::VisitStmt_(op));
+ if (src_block.get() == producer_block_) {
+ tgt_block_realize = BuildInlinedConsumerPredicate(tgt_block_realize);
+ block_reuse.Set(src_block, tgt_block_realize->block);
}
- return std::move(tgt_block);
+ return std::move(tgt_block_realize);
}
Stmt VisitStmt_(const BufferStoreNode* _store) final {
diff --git a/tests/python/dlight/test_gpu_matmul.py
b/tests/python/dlight/test_gpu_matmul.py
index 63117073d1..ca32c286ab 100644
--- a/tests/python/dlight/test_gpu_matmul.py
+++ b/tests/python/dlight/test_gpu_matmul.py
@@ -113,10 +113,10 @@ class TestMatmul(BaseBeforeAfter):
v0 = T.axis.spatial(T.int64(1),
ax0)
v1 = T.axis.spatial((m +
T.int64(31)) // T.int64(32) * T.int64(32), ax1_0 * T.int64(32) + ax1_2 *
T.int64(4) + ax1)
v2 = T.axis.spatial(T.int64(4096),
ax0_ax2_0_fused * T.int64(64) + ax2_2 * T.int64(4) + ax2_0 * T.int64(2) +
ax2_1_1)
+ T.where(ax1_0 * T.int64(32) +
ax1_2 * T.int64(4) + ax1 < m)
T.reads(matmul_reindex_pad_local[v0, v1, v2])
T.writes(matmul[T.int64(0), v1,
v2])
- if v1 < m:
- matmul[T.int64(0), v1, v2] =
matmul_reindex_pad_local[v0, v1, v2]
+ matmul[T.int64(0), v1, v2] =
matmul_reindex_pad_local[v0, v1, v2]
# fmt: on
@@ -200,10 +200,10 @@ def test_matmul_int32():
v0 = T.axis.spatial(1, ax0)
v1 = T.axis.spatial((m + 31) // 32
* 32, ax1_0 * 32 + ax1_2 * 4 + ax1)
v2 = T.axis.spatial(4096,
ax0_ax2_0_fused * 64 + ax2_2 * 4 + ax2_0 * 2 + ax2_1_1)
+ T.where(ax1_0 * 32 + ax1_2 * 4 +
ax1 < m)
T.reads(matmul_reindex_pad_local[v0, v1, v2])
T.writes(matmul[0, v1, v2])
- if v1 < m:
- matmul[0, v1, v2] =
matmul_reindex_pad_local[v0, v1, v2]
+ matmul[0, v1, v2] =
matmul_reindex_pad_local[v0, v1, v2]
# fmt: on
mod = tvm.IRModule({"main": func})
@@ -466,10 +466,10 @@ class TestOutputFP32(BaseBeforeAfter):
v0 = T.axis.spatial(T.int64(1),
ax0)
v1 = T.axis.spatial((n +
T.int64(31)) // T.int64(32) * T.int64(32), ax1_0 * T.int64(32) + ax1_2 *
T.int64(4) + ax1)
v2 = T.axis.spatial(T.int64(4096),
ax0_ax2_0_fused * T.int64(64) + ax2_2 * T.int64(4) + ax2_0 * T.int64(2) +
ax2_1_1)
+ T.where(ax1_0 * T.int64(32) +
ax1_2 * T.int64(4) + ax1 < n)
T.reads(var_matmul_intermediate_reindex_pad_local[v0, v1, v2], lv13_1[v2],
lv3[T.int64(0), v1, v2])
T.writes(p_output0_intermediate[T.int64(0), v1, v2])
- if v1 < n:
-
p_output0_intermediate[T.int64(0), v1, v2] = T.Cast("float16",
var_matmul_intermediate_reindex_pad_local[v0, v1, v2] + T.Cast("float32",
lv13_1[v2])) + lv3[T.int64(0), v1, v2]
+ p_output0_intermediate[T.int64(0),
v1, v2] = T.Cast("float16", var_matmul_intermediate_reindex_pad_local[v0, v1,
v2] + T.Cast("float32", lv13_1[v2])) + lv3[T.int64(0), v1, v2]
# fmt: on
@@ -596,9 +596,9 @@ class TestInlineConsumerChain(BaseBeforeAfter):
v1 = T.axis.spatial((n +
T.int64(31)) // T.int64(32) * T.int64(32), ax1_0 * T.int64(32) + ax1_2 *
T.int64(4) + ax1)
v2 = T.axis.spatial(T.int64(2048),
ax0_ax2_0_fused * T.int64(64) + ax2_2 * T.int64(4) + ax2_0 * T.int64(2) +
ax2_1_1)
T.reads(lv52[T.int64(0), v1, v2],
var_NT_matmul_intermediate_reindex_pad_local[v0, v1, v2])
+ T.where(ax1_0 * T.int64(32) +
ax1_2 * T.int64(4) + ax1 < n)
T.writes(var_T_multiply_intermediate[v1, v2])
- if v1 < n:
-
var_T_multiply_intermediate[v1, v2] = T.Cast("float16", lv52[T.int64(0), v1,
v2]) * (var_NT_matmul_intermediate_reindex_pad_local[v0, v1, v2] *
T.sigmoid(var_NT_matmul_intermediate_reindex_pad_local[v0, v1, v2]))
+ var_T_multiply_intermediate[v1,
v2] = T.Cast("float16", lv52[T.int64(0), v1, v2]) *
(var_NT_matmul_intermediate_reindex_pad_local[v0, v1, v2] *
T.sigmoid(var_NT_matmul_intermediate_reindex_pad_local[v0, v1, v2]))
# fmt: on
@@ -666,10 +666,10 @@ class TestMatmulAndroid(AndroidBeforeAfter):
v0 = T.axis.spatial(T.int64(1),
ax0)
v1 = T.axis.spatial((m +
T.int64(31)) // T.int64(32) * T.int64(32), ax0_ax1_0_fused * T.int64(32) +
ax1_2 * T.int64(2) + ax1)
v2 = T.axis.spatial(T.int64(4096),
ax2_0 * T.int64(64) + ax2_2 * T.int64(8) + ax2_0_1 * T.int64(8) + ax2_1_1)
+ T.where(ax0_ax1_0_fused *
T.int64(32) + ax1_2 * T.int64(2) + ax1 < m)
T.reads(matmul_reindex_pad_local[v0, v1, v2])
T.writes(matmul[T.int64(0), v1,
v2])
- if v1 < m:
- matmul[T.int64(0), v1, v2] =
matmul_reindex_pad_local[v0, v1, v2]
+ matmul[T.int64(0), v1, v2] =
matmul_reindex_pad_local[v0, v1, v2]
# fmt: on
diff --git a/tests/python/dlight/test_gpu_matmul_tensorize.py
b/tests/python/dlight/test_gpu_matmul_tensorize.py
index 59ccfec55c..94d6a8e42a 100644
--- a/tests/python/dlight/test_gpu_matmul_tensorize.py
+++ b/tests/python/dlight/test_gpu_matmul_tensorize.py
@@ -254,10 +254,10 @@ class TestMatmulTensorizeTooSmall(BaseBeforeAfter):
v0 = T.axis.spatial(1, ax0)
v1 = T.axis.spatial((m + 31) // 32
* 32, ax1_0 * 32 + ax1_2 * 4 + ax1)
v2 = T.axis.spatial(64, ax2_2 * 4
+ ax2_0 * 2 + ax2_1_1)
+ T.where(ax1_0 * 32 + ax1_2 * 4 +
ax1 < m and ax2_2 * 4 + ax2_0 * 2 + ax2_1_1 < 15)
T.reads(compute_reindex_pad_local[v0, v1, v2])
T.writes(compute[v1, v2])
- if v1 < m and v2 < 15:
- compute[v1, v2] =
compute_reindex_pad_local[v0, v1, v2]
+ compute[v1, v2] =
compute_reindex_pad_local[v0, v1, v2]
# fmt: on
@@ -417,11 +417,11 @@ class TestMatmulTensorizeEpilogue(BaseBeforeAfter):
v0 = T.axis.spatial(1, 0)
v1 = T.axis.spatial((n + 127) // 128 *
128, ax1_0_0_ax2_0_0_fused * 128 + ax2_0_2_ax1_0_2_fused % 4 * 32 +
(ax0_ax1_fused_0 * 128 + ax0_ax1_fused_1 * 4 + ax0_ax1_fused_2) // 32)
v2 = T.axis.spatial(4096,
ax1_0_1_ax2_0_1_fused * 128 + ax2_0_2_ax1_0_2_fused // 4 * 32 +
(ax0_ax1_fused_0 * 128 + ax0_ax1_fused_1 * 4 + ax0_ax1_fused_2) % 32)
+ T.where(ax1_0_0_ax2_0_0_fused * 128 +
ax2_0_2_ax1_0_2_fused % 4 * 32 + ((ax0_ax1_fused_0 * 32 + ax0_ax1_fused_1) * 4
+ ax0_ax1_fused_2) // 32 < n)
T.reads(lv3[0, v1, v2],
var_NT_matmul_intermediate_reindex_pad_shared_dyn[v0, v1, v2])
T.writes(p_output0_intermediate[0, v1,
v2])
T.block_attr({"buffer_dim_align": [[0,
1, 16, 4]]})
- if v1 < n:
- p_output0_intermediate[0, v1, v2]
= lv3[0, v1, v2] * T.float16(0.5) +
var_NT_matmul_intermediate_reindex_pad_shared_dyn[v0, v1, v2]
+ p_output0_intermediate[0, v1, v2] =
lv3[0, v1, v2] * T.float16(0.5) +
var_NT_matmul_intermediate_reindex_pad_shared_dyn[v0, v1, v2]
# fmt: on
@@ -690,11 +690,11 @@ class TestMatmulInt8Tensorize3d2dDyn(BaseBeforeAfter):
v0 = T.axis.spatial(1, 0)
v1 = T.axis.spatial((m + 127) // 128 *
128, ax1_0_0_ax2_0_0_fused * 128 + ax2_0_2_ax1_0_2_fused % 4 * 32 +
(ax0_ax1_fused_0 * 128 + ax0_ax1_fused_1 * 4 + ax0_ax1_fused_2) // 32)
v2 = T.axis.spatial(4096,
ax1_0_1_ax2_0_1_fused * 128 + ax2_0_2_ax1_0_2_fused // 4 * 32 +
(ax0_ax1_fused_0 * 128 + ax0_ax1_fused_1 * 4 + ax0_ax1_fused_2) % 32)
+ T.where(ax1_0_0_ax2_0_0_fused * 128 +
ax2_0_2_ax1_0_2_fused % 4 * 32 + ((ax0_ax1_fused_0 * 32 + ax0_ax1_fused_1) * 4
+ ax0_ax1_fused_2) // 32 < m)
T.reads(matmul_1_reindex_pad_shared_dyn[v0, v1, v2])
T.writes(matmul_1[0, v1, v2])
T.block_attr({"buffer_dim_align": [[0,
1, 16, 4]]})
- if v1 < m:
- matmul_1[0, v1, v2] =
matmul_1_reindex_pad_shared_dyn[v0, v1, v2]
+ matmul_1[0, v1, v2] =
matmul_1_reindex_pad_shared_dyn[v0, v1, v2]
# fmt: on
@@ -831,10 +831,10 @@ class TestMatmulMetal(MetalBeforeAfter):
v0 = T.axis.spatial(1, ax0_1)
v1 = T.axis.spatial((batch_size +
15) // 16 * 16, ax1_0 * 16 + (ax1_ax2_fused_0 * 512 + ax1_ax2_fused_1 * 128 +
ax1_ax2_fused_2 * 128 + ax1_ax2_fused_3 * 4 + ax1_ax2_fused_4) // 64)
v2 = T.axis.spatial(28672, ax2_0 *
64 + (ax1_ax2_fused_0 * 512 + ax1_ax2_fused_1 * 128 + ax1_ax2_fused_2 * 128 +
ax1_ax2_fused_3 * 4 + ax1_ax2_fused_4) % 64)
+ T.where(ax1_0 * 16 +
(((ax1_ax2_fused_0 * 4 + ax1_ax2_fused_1 + ax1_ax2_fused_2) * 32 +
ax1_ax2_fused_3) * 4 + ax1_ax2_fused_4) // 64 < batch_size)
T.reads(C_reindex_pad_shared[v0,
v1, v2])
T.writes(C[v1, 0, v2])
- if v1 < batch_size:
- C[v1, 0, v2] =
C_reindex_pad_shared[v0, v1, v2]
+ C[v1, 0, v2] =
C_reindex_pad_shared[v0, v1, v2]
# fmt: on
@@ -971,10 +971,10 @@ class TestMatmulMetalInt4Quant(MetalBeforeAfter):
v0 = T.axis.spatial(1, ax0_1)
v1 = T.axis.spatial((batch_size +
15) // 16 * 16, ax1_0 * 16 + (ax1_ax2_fused_0 * 512 + ax1_ax2_fused_1 * 128 +
ax1_ax2_fused_2 * 128 + ax1_ax2_fused_3 * 4 + ax1_ax2_fused_4) // 64)
v2 = T.axis.spatial(28672, ax2_0 *
64 + (ax1_ax2_fused_0 * 512 + ax1_ax2_fused_1 * 128 + ax1_ax2_fused_2 * 128 +
ax1_ax2_fused_3 * 4 + ax1_ax2_fused_4) % 64)
+ T.where(ax1_0 * 16 +
(((ax1_ax2_fused_0 * 4 + ax1_ax2_fused_1 + ax1_ax2_fused_2) * 32 +
ax1_ax2_fused_3) * 4 + ax1_ax2_fused_4) // 64 < batch_size)
T.reads(C_reindex_pad_shared[v0,
v1, v2])
T.writes(C[v1, 0, v2])
- if v1 < batch_size:
- C[v1, 0, v2] =
C_reindex_pad_shared[v0, v1, v2]
+ C[v1, 0, v2] =
C_reindex_pad_shared[v0, v1, v2]
if __name__ == "__main__":
diff --git
a/tests/python/meta_schedule/test_meta_schedule_schedule_rule_mlt_tc.py
b/tests/python/meta_schedule/test_meta_schedule_schedule_rule_mlt_tc.py
index da00f294ba..df8607e551 100644
--- a/tests/python/meta_schedule/test_meta_schedule_schedule_rule_mlt_tc.py
+++ b/tests/python/meta_schedule/test_meta_schedule_schedule_rule_mlt_tc.py
@@ -856,11 +856,11 @@ def test_padded_matmul_relu():
v3 = T.axis.spatial(1, 0)
v4 = T.axis.spatial(16, ax0_ax1_ax3_ax4_ax5_fused
% 256 // 16)
v5 = T.axis.spatial(16, ax0_ax1_ax3_ax4_ax5_fused
% 16)
+ T.where(ax0_0_0_ax1_0_0_fused // 2 * 32 + ax2 * 16
+ ax0_ax1_ax3_ax4_ax5_fused % 256 // 16 < 127 and ax0_0_0_ax1_0_0_fused % 2 *
64 + ax0_0_1_ax1_0_1_fused * 32 + ax0_ax1_ax3_ax4_ax5_fused // 256 * 16 +
ax0_ax1_ax3_ax4_ax5_fused % 16 < 127)
T.reads(C_reindex_shared[v0, v1, v2, v3, v4, v5])
T.writes(compute[v4 + v2 * 16 + v0 * 32, v5 + v1 *
16])
T.block_attr({"meta_schedule.cooperative_fetch":
4})
- if v0 * 32 + v2 * 16 + v4 < 127 and v1 * 16 + v5 <
127:
- compute[v4 + v2 * 16 + v0 * 32, v5 + v1 * 16]
= T.max(C_reindex_shared[v0, v1, v2, v3, v4, v5], T.float32(0))
+ compute[v4 + v2 * 16 + v0 * 32, v5 + v1 * 16] =
T.max(C_reindex_shared[v0, v1, v2, v3, v4, v5], T.float32(0))
# fmt: on
decision_0 = [
diff --git a/tests/python/tir-schedule/test_tir_schedule_compute_inline.py
b/tests/python/tir-schedule/test_tir_schedule_compute_inline.py
index 5cf59985d3..2f779612a7 100644
--- a/tests/python/tir-schedule/test_tir_schedule_compute_inline.py
+++ b/tests/python/tir-schedule/test_tir_schedule_compute_inline.py
@@ -624,8 +624,8 @@ def elementwise_overcomputed_producer_reverse_inlined(
for i, j in T.grid(128, 128):
with T.block("B"):
vi, vj = T.axis.remap("SS", [i, j])
- if vi < 127 and vj < 127:
- C[vi, vj] = A[vi, vj] * 2.0 + 1.0
+ T.where(i < 127 and j < 127)
+ C[vi, vj] = A[vi, vj] * 2.0 + 1.0
@T.prim_func
@@ -652,8 +652,8 @@ def
elementwise_overcomputed_producer_simplify_predicate_reverse_inlined(
with T.block("B"):
vi = T.axis.spatial(128, i // 128)
vj = T.axis.spatial(128, i % 128)
- if vi < 127 and vj < 127:
- C[vi, vj] = A[vi, vj] * 2.0 + 1.0
+ T.where(i < 16255 and i % 128 < 127)
+ C[vi, vj] = A[vi, vj] * 2.0 + 1.0
@T.prim_func
@@ -678,8 +678,8 @@ def
elementwise_overcomputed_producer_injective_load_reverse_inlined(
for i0, j0, i1, j1 in T.grid(8, 8, 16, 16):
with T.block("B"):
vi, vj, vm, vn = T.axis.remap("SSSS", [i0, j0, i1, j1])
- if vi * 16 + vm < 127 and vj * 16 + vn < 127:
- C[vm + vi * 16, vn + vj * 16] = A[vi * 16 + vm, vj * 16 + vn]
* 2.0 + 1.0
+ T.where(i0 * 16 + i1 < 127 and j0 * 16 + j1 < 127)
+ C[vm + vi * 16, vn + vj * 16] = A[vi * 16 + vm, vj * 16 + vn] *
2.0 + 1.0
@T.prim_func
@@ -740,8 +740,7 @@ def elementwise_predicate_producer_inlined(a: T.handle, c:
T.handle) -> None:
vi, vj = T.axis.remap("SS", [i, j])
T.reads(A[vi, vj])
T.writes(C[vi, vj])
- if vi < 127:
- C[vi, vj] = A[vi, vj] * T.float32(2) + T.float32(1)
+ C[vi, vj] = A[vi, vj] * T.float32(2) + T.float32(1)
# fmt: off
@@ -1486,5 +1485,49 @@ def test_reverse_compute_inline_layer_norm():
assert_structural_equal_ignore_global_symbol(after, sch.mod["main"])
+def test_reverse_compute_inline_slicing_then_cachewrite():
+ @T.prim_func
+ def before(
+ x: T.Buffer((1, 16, 7, 7), "float32"),
+ T_strided_slice_with_axes: T.Buffer((1, 12, 7, 7), "float32"),
+ ):
+ T_add = T.alloc_buffer((1, 16, 7, 7))
+ for ax0, ax1, ax2, ax3 in T.grid(1, 16, 7, 7):
+ with T.block("T_add"):
+ v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1,
ax2, ax3])
+ T_add[v_ax0, v_ax1, v_ax2, v_ax3] = x[v_ax0, v_ax1, v_ax2,
v_ax3] + T.float32(1)
+ for ax0, ax1, ax2, ax3 in T.grid(1, 12, 7, 7):
+ with T.block("T_strided_slice_with_axes"):
+ v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1,
ax2, ax3])
+ T_strided_slice_with_axes[v_ax0, v_ax1, v_ax2, v_ax3] = T_add[
+ v_ax0, v_ax1, v_ax2, v_ax3
+ ]
+
+ @T.prim_func
+ def after(
+ x: T.Buffer((1, 16, 7, 7), "float32"),
+ T_strided_slice_with_axes: T.Buffer((1, 12, 7, 7), "float32"),
+ ):
+ T_strided_slice_with_axes_global = T.alloc_buffer((1, 12, 7, 7))
+ for ax0, ax1, ax2, ax3 in T.grid(1, 16, 7, 7):
+ with T.block("T_add"):
+ v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1,
ax2, ax3])
+ T.where(ax1 < 12)
+ T_strided_slice_with_axes_global[v_ax0, v_ax1, v_ax2, v_ax3] =
x[
+ v_ax0, v_ax1, v_ax2, v_ax3
+ ] + T.float32(1)
+ for ax0, ax1, ax2, ax3 in T.grid(1, 12, 7, 7):
+ with T.block("T_strided_slice_with_axes_global"):
+ v0, v1, v2, v3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
+ T_strided_slice_with_axes[v0, v1, v2, v3] =
T_strided_slice_with_axes_global[
+ v0, v1, v2, v3
+ ]
+
+ sch = tir.Schedule(before)
+ sch.reverse_compute_inline(sch.get_block("T_strided_slice_with_axes"))
+ sch.cache_write(sch.get_block("T_add"), 0, "global")
+ assert_structural_equal_ignore_global_symbol(after, sch.mod["main"])
+
+
if __name__ == "__main__":
tvm.testing.main()