This is an automated email from the ASF dual-hosted git repository.

syfeng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 0fc047c98b [Compute-inline] Prefer T.where for reverse compute-inlined block with predicate (#17128)
0fc047c98b is described below

commit 0fc047c98b1ebf730b8c9aad8b94ddac28a7b34b
Author: wrongtest <[email protected]>
AuthorDate: Fri Jul 5 11:45:12 2024 +0800

    [Compute-inline] Prefer T.where for reverse compute-inlined block with predicate (#17128)
    
    * prefer T.where for reverse compute-inlined block with predicate
    
    * update unit test scripts
    
    ---------
    
    Co-authored-by: wrongtest <[email protected]>
---
 src/tir/schedule/primitive/compute_inline.cc       | 44 +++++++++-------
 tests/python/dlight/test_gpu_matmul.py             | 20 ++++----
 tests/python/dlight/test_gpu_matmul_tensorize.py   | 20 ++++----
 .../test_meta_schedule_schedule_rule_mlt_tc.py     |  4 +-
 .../test_tir_schedule_compute_inline.py            | 59 +++++++++++++++++++---
 5 files changed, 98 insertions(+), 49 deletions(-)

diff --git a/src/tir/schedule/primitive/compute_inline.cc 
b/src/tir/schedule/primitive/compute_inline.cc
index d6be0e5805..df74497b4a 100644
--- a/src/tir/schedule/primitive/compute_inline.cc
+++ b/src/tir/schedule/primitive/compute_inline.cc
@@ -682,11 +682,14 @@ class ReverseComputeInliner : public BaseInliner {
   using BaseInliner::VisitStmt_;
 
   /*! \brief Generate the predicate after inlining based on the consumer 
predicate */
-  Block BuildInlinedConsumerPredicate(const BlockNode* producer_block) {
+  BlockRealize BuildInlinedConsumerPredicate(BlockRealize 
producer_block_realize) {
     // Bind the producer block iter domains for simplification
     Map<Var, PrimExpr> subst_map;
+    Block producer_block = producer_block_realize->block;
     for (int i = 0, n = producer_block->iter_vars.size(); i < n; ++i) {
       const IterVar& iter = producer_block->iter_vars[i];
+      const PrimExpr& binding = producer_block_realize->iter_values[i];
+      subst_map.Set(iter->var, binding);
       analyzer_.Bind(iter->var, Range::FromMinExtent(iter->dom->min, 
iter->dom->extent));
     }
     if (producer_block->annotations.count(tir::attr::auto_copy) != 0) {
@@ -705,30 +708,33 @@ class ReverseComputeInliner : public BaseInliner {
     PrimExpr predicate = Substituter(this)(consumer_iter_in_bound_);
     // Simplify the predicate using the producer block iter domains
     predicate = analyzer_.Simplify(predicate);
-    ObjectPtr<BlockNode> block = make_object<BlockNode>(*producer_block);
     if (is_one(predicate)) {
-      return Block(block);
-    }
-    if (const auto* if_ = producer_block->body.as<tir::IfThenElseNode>()) {
-      PrimExpr if_predicate = analyzer_.Simplify(if_->condition);
-      if (!StructuralEqual()(predicate, if_predicate)) {
-        predicate = analyzer_.Simplify(predicate && if_->condition);
+      return producer_block_realize;
+    }
+    if (const auto* if_ = producer_block->body.as<IfThenElseNode>()) {
+      if (!if_->else_case.defined()) {
+        PrimExpr if_predicate = analyzer_.Simplify(if_->condition);
+        if (!StructuralEqual()(predicate, if_predicate)) {
+          predicate = analyzer_.Simplify(predicate && if_->condition);
+          producer_block.CopyOnWrite()->body = if_->then_case;
+        }
       }
-      block->body = IfThenElse(predicate, if_->then_case);
-      return Block(block);
     }
-    block->body = IfThenElse(predicate, block->body);
-    return Block(block);
+    PrimExpr outer_predicate = Substitute(predicate, subst_map);
+    auto n = producer_block_realize.CopyOnWrite();
+    n->block = producer_block;
+    n->predicate = analyzer_.Simplify(outer_predicate);
+    return GetRef<BlockRealize>(n);
   }
 
-  Stmt VisitStmt_(const BlockNode* op) final {
-    Block src_block = GetRef<Block>(op);
-    Block tgt_block = Downcast<Block>(BaseInliner::VisitStmt_(op));
-    if (op == producer_block_) {
-      tgt_block = BuildInlinedConsumerPredicate(tgt_block.get());
-      block_reuse.Set(src_block, tgt_block);
+  Stmt VisitStmt_(const BlockRealizeNode* op) final {
+    Block src_block = op->block;
+    BlockRealize tgt_block_realize = 
Downcast<BlockRealize>(StmtMutator::VisitStmt_(op));
+    if (src_block.get() == producer_block_) {
+      tgt_block_realize = BuildInlinedConsumerPredicate(tgt_block_realize);
+      block_reuse.Set(src_block, tgt_block_realize->block);
     }
-    return std::move(tgt_block);
+    return std::move(tgt_block_realize);
   }
 
   Stmt VisitStmt_(const BufferStoreNode* _store) final {
diff --git a/tests/python/dlight/test_gpu_matmul.py 
b/tests/python/dlight/test_gpu_matmul.py
index 63117073d1..ca32c286ab 100644
--- a/tests/python/dlight/test_gpu_matmul.py
+++ b/tests/python/dlight/test_gpu_matmul.py
@@ -113,10 +113,10 @@ class TestMatmul(BaseBeforeAfter):
                                             v0 = T.axis.spatial(T.int64(1), 
ax0)
                                             v1 = T.axis.spatial((m + 
T.int64(31)) // T.int64(32) * T.int64(32), ax1_0 * T.int64(32) + ax1_2 * 
T.int64(4) + ax1)
                                             v2 = T.axis.spatial(T.int64(4096), 
ax0_ax2_0_fused * T.int64(64) + ax2_2 * T.int64(4) + ax2_0 * T.int64(2) + 
ax2_1_1)
+                                            T.where(ax1_0 * T.int64(32) + 
ax1_2 * T.int64(4) + ax1 < m)
                                             
T.reads(matmul_reindex_pad_local[v0, v1, v2])
                                             T.writes(matmul[T.int64(0), v1, 
v2])
-                                            if v1 < m:
-                                                matmul[T.int64(0), v1, v2] = 
matmul_reindex_pad_local[v0, v1, v2]
+                                            matmul[T.int64(0), v1, v2] = 
matmul_reindex_pad_local[v0, v1, v2]
     # fmt: on
 
 
@@ -200,10 +200,10 @@ def test_matmul_int32():
                                             v0 = T.axis.spatial(1, ax0)
                                             v1 = T.axis.spatial((m + 31) // 32 
* 32, ax1_0 * 32 + ax1_2 * 4 + ax1)
                                             v2 = T.axis.spatial(4096, 
ax0_ax2_0_fused * 64 + ax2_2 * 4 + ax2_0 * 2 + ax2_1_1)
+                                            T.where(ax1_0 * 32 + ax1_2 * 4 + 
ax1 < m)
                                             
T.reads(matmul_reindex_pad_local[v0, v1, v2])
                                             T.writes(matmul[0, v1, v2])
-                                            if v1 < m:
-                                                matmul[0, v1, v2] = 
matmul_reindex_pad_local[v0, v1, v2]
+                                            matmul[0, v1, v2] = 
matmul_reindex_pad_local[v0, v1, v2]
     # fmt: on
 
     mod = tvm.IRModule({"main": func})
@@ -466,10 +466,10 @@ class TestOutputFP32(BaseBeforeAfter):
                                             v0 = T.axis.spatial(T.int64(1), 
ax0)
                                             v1 = T.axis.spatial((n + 
T.int64(31)) // T.int64(32) * T.int64(32), ax1_0 * T.int64(32) + ax1_2 * 
T.int64(4) + ax1)
                                             v2 = T.axis.spatial(T.int64(4096), 
ax0_ax2_0_fused * T.int64(64) + ax2_2 * T.int64(4) + ax2_0 * T.int64(2) + 
ax2_1_1)
+                                            T.where(ax1_0 * T.int64(32) + 
ax1_2 * T.int64(4) + ax1 < n)
                                             
T.reads(var_matmul_intermediate_reindex_pad_local[v0, v1, v2], lv13_1[v2], 
lv3[T.int64(0), v1, v2])
                                             
T.writes(p_output0_intermediate[T.int64(0), v1, v2])
-                                            if v1 < n:
-                                                
p_output0_intermediate[T.int64(0), v1, v2] = T.Cast("float16", 
var_matmul_intermediate_reindex_pad_local[v0, v1, v2] + T.Cast("float32", 
lv13_1[v2])) + lv3[T.int64(0), v1, v2]
+                                            p_output0_intermediate[T.int64(0), 
v1, v2] = T.Cast("float16", var_matmul_intermediate_reindex_pad_local[v0, v1, 
v2] + T.Cast("float32", lv13_1[v2])) + lv3[T.int64(0), v1, v2]
 
     # fmt: on
 
@@ -596,9 +596,9 @@ class TestInlineConsumerChain(BaseBeforeAfter):
                                             v1 = T.axis.spatial((n + 
T.int64(31)) // T.int64(32) * T.int64(32), ax1_0 * T.int64(32) + ax1_2 * 
T.int64(4) + ax1)
                                             v2 = T.axis.spatial(T.int64(2048), 
ax0_ax2_0_fused * T.int64(64) + ax2_2 * T.int64(4) + ax2_0 * T.int64(2) + 
ax2_1_1)
                                             T.reads(lv52[T.int64(0), v1, v2], 
var_NT_matmul_intermediate_reindex_pad_local[v0, v1, v2])
+                                            T.where(ax1_0 * T.int64(32) + 
ax1_2 * T.int64(4) + ax1 < n)
                                             
T.writes(var_T_multiply_intermediate[v1, v2])
-                                            if v1 < n:
-                                                
var_T_multiply_intermediate[v1, v2] = T.Cast("float16", lv52[T.int64(0), v1, 
v2]) * (var_NT_matmul_intermediate_reindex_pad_local[v0, v1, v2] * 
T.sigmoid(var_NT_matmul_intermediate_reindex_pad_local[v0, v1, v2]))
+                                            var_T_multiply_intermediate[v1, 
v2] = T.Cast("float16", lv52[T.int64(0), v1, v2]) * 
(var_NT_matmul_intermediate_reindex_pad_local[v0, v1, v2] * 
T.sigmoid(var_NT_matmul_intermediate_reindex_pad_local[v0, v1, v2]))
 
     # fmt: on
 
@@ -666,10 +666,10 @@ class TestMatmulAndroid(AndroidBeforeAfter):
                                             v0 = T.axis.spatial(T.int64(1), 
ax0)
                                             v1 = T.axis.spatial((m + 
T.int64(31)) // T.int64(32) * T.int64(32), ax0_ax1_0_fused * T.int64(32) + 
ax1_2 * T.int64(2) + ax1)
                                             v2 = T.axis.spatial(T.int64(4096), 
ax2_0 * T.int64(64) + ax2_2 * T.int64(8) + ax2_0_1 * T.int64(8) + ax2_1_1)
+                                            T.where(ax0_ax1_0_fused * 
T.int64(32) + ax1_2 * T.int64(2) + ax1 < m)
                                             
T.reads(matmul_reindex_pad_local[v0, v1, v2])
                                             T.writes(matmul[T.int64(0), v1, 
v2])
-                                            if v1 < m:
-                                                matmul[T.int64(0), v1, v2] = 
matmul_reindex_pad_local[v0, v1, v2]
+                                            matmul[T.int64(0), v1, v2] = 
matmul_reindex_pad_local[v0, v1, v2]
     # fmt: on
 
 
diff --git a/tests/python/dlight/test_gpu_matmul_tensorize.py 
b/tests/python/dlight/test_gpu_matmul_tensorize.py
index 59ccfec55c..94d6a8e42a 100644
--- a/tests/python/dlight/test_gpu_matmul_tensorize.py
+++ b/tests/python/dlight/test_gpu_matmul_tensorize.py
@@ -254,10 +254,10 @@ class TestMatmulTensorizeTooSmall(BaseBeforeAfter):
                                             v0 = T.axis.spatial(1, ax0)
                                             v1 = T.axis.spatial((m + 31) // 32 
* 32, ax1_0 * 32 + ax1_2 * 4 + ax1)
                                             v2 = T.axis.spatial(64, ax2_2 * 4 
+ ax2_0 * 2 + ax2_1_1)
+                                            T.where(ax1_0 * 32 + ax1_2 * 4 + 
ax1 < m and ax2_2 * 4 + ax2_0 * 2 + ax2_1_1 < 15)
                                             
T.reads(compute_reindex_pad_local[v0, v1, v2])
                                             T.writes(compute[v1, v2])
-                                            if v1 < m and v2 < 15:
-                                                compute[v1, v2] = 
compute_reindex_pad_local[v0, v1, v2]
+                                            compute[v1, v2] = 
compute_reindex_pad_local[v0, v1, v2]
     # fmt: on
 
 
@@ -417,11 +417,11 @@ class TestMatmulTensorizeEpilogue(BaseBeforeAfter):
                                         v0 = T.axis.spatial(1, 0)
                                         v1 = T.axis.spatial((n + 127) // 128 * 
128, ax1_0_0_ax2_0_0_fused * 128 + ax2_0_2_ax1_0_2_fused % 4 * 32 + 
(ax0_ax1_fused_0 * 128 + ax0_ax1_fused_1 * 4 + ax0_ax1_fused_2) // 32)
                                         v2 = T.axis.spatial(4096, 
ax1_0_1_ax2_0_1_fused * 128 + ax2_0_2_ax1_0_2_fused // 4 * 32 + 
(ax0_ax1_fused_0 * 128 + ax0_ax1_fused_1 * 4 + ax0_ax1_fused_2) % 32)
+                                        T.where(ax1_0_0_ax2_0_0_fused * 128 + 
ax2_0_2_ax1_0_2_fused % 4 * 32 + ((ax0_ax1_fused_0 * 32 + ax0_ax1_fused_1) * 4 
+ ax0_ax1_fused_2) // 32 < n)
                                         T.reads(lv3[0, v1, v2], 
var_NT_matmul_intermediate_reindex_pad_shared_dyn[v0, v1, v2])
                                         T.writes(p_output0_intermediate[0, v1, 
v2])
                                         T.block_attr({"buffer_dim_align": [[0, 
1, 16, 4]]})
-                                        if v1 < n:
-                                            p_output0_intermediate[0, v1, v2] 
= lv3[0, v1, v2] * T.float16(0.5) + 
var_NT_matmul_intermediate_reindex_pad_shared_dyn[v0, v1, v2]
+                                        p_output0_intermediate[0, v1, v2] = 
lv3[0, v1, v2] * T.float16(0.5) + 
var_NT_matmul_intermediate_reindex_pad_shared_dyn[v0, v1, v2]
     # fmt: on
 
 
@@ -690,11 +690,11 @@ class TestMatmulInt8Tensorize3d2dDyn(BaseBeforeAfter):
                                         v0 = T.axis.spatial(1, 0)
                                         v1 = T.axis.spatial((m + 127) // 128 * 
128, ax1_0_0_ax2_0_0_fused * 128 + ax2_0_2_ax1_0_2_fused % 4 * 32 + 
(ax0_ax1_fused_0 * 128 + ax0_ax1_fused_1 * 4 + ax0_ax1_fused_2) // 32)
                                         v2 = T.axis.spatial(4096, 
ax1_0_1_ax2_0_1_fused * 128 + ax2_0_2_ax1_0_2_fused // 4 * 32 + 
(ax0_ax1_fused_0 * 128 + ax0_ax1_fused_1 * 4 + ax0_ax1_fused_2) % 32)
+                                        T.where(ax1_0_0_ax2_0_0_fused * 128 + 
ax2_0_2_ax1_0_2_fused % 4 * 32 + ((ax0_ax1_fused_0 * 32 + ax0_ax1_fused_1) * 4 
+ ax0_ax1_fused_2) // 32 < m)
                                         
T.reads(matmul_1_reindex_pad_shared_dyn[v0, v1, v2])
                                         T.writes(matmul_1[0, v1, v2])
                                         T.block_attr({"buffer_dim_align": [[0, 
1, 16, 4]]})
-                                        if v1 < m:
-                                            matmul_1[0, v1, v2] = 
matmul_1_reindex_pad_shared_dyn[v0, v1, v2]
+                                        matmul_1[0, v1, v2] = 
matmul_1_reindex_pad_shared_dyn[v0, v1, v2]
     # fmt: on
 
 
@@ -831,10 +831,10 @@ class TestMatmulMetal(MetalBeforeAfter):
                                             v0 = T.axis.spatial(1, ax0_1)
                                             v1 = T.axis.spatial((batch_size + 
15) // 16 * 16, ax1_0 * 16 + (ax1_ax2_fused_0 * 512 + ax1_ax2_fused_1 * 128 + 
ax1_ax2_fused_2 * 128 + ax1_ax2_fused_3 * 4 + ax1_ax2_fused_4) // 64)
                                             v2 = T.axis.spatial(28672, ax2_0 * 
64 + (ax1_ax2_fused_0 * 512 + ax1_ax2_fused_1 * 128 + ax1_ax2_fused_2 * 128 + 
ax1_ax2_fused_3 * 4 + ax1_ax2_fused_4) % 64)
+                                            T.where(ax1_0 * 16 + 
(((ax1_ax2_fused_0 * 4 + ax1_ax2_fused_1 + ax1_ax2_fused_2) * 32 + 
ax1_ax2_fused_3) * 4 + ax1_ax2_fused_4) // 64 < batch_size)
                                             T.reads(C_reindex_pad_shared[v0, 
v1, v2])
                                             T.writes(C[v1, 0, v2])
-                                            if v1 < batch_size:
-                                                C[v1, 0, v2] = 
C_reindex_pad_shared[v0, v1, v2]
+                                            C[v1, 0, v2] = 
C_reindex_pad_shared[v0, v1, v2]
     # fmt: on
 
 
@@ -971,10 +971,10 @@ class TestMatmulMetalInt4Quant(MetalBeforeAfter):
                                             v0 = T.axis.spatial(1, ax0_1)
                                             v1 = T.axis.spatial((batch_size + 
15) // 16 * 16, ax1_0 * 16 + (ax1_ax2_fused_0 * 512 + ax1_ax2_fused_1 * 128 + 
ax1_ax2_fused_2 * 128 + ax1_ax2_fused_3 * 4 + ax1_ax2_fused_4) // 64)
                                             v2 = T.axis.spatial(28672, ax2_0 * 
64 + (ax1_ax2_fused_0 * 512 + ax1_ax2_fused_1 * 128 + ax1_ax2_fused_2 * 128 + 
ax1_ax2_fused_3 * 4 + ax1_ax2_fused_4) % 64)
+                                            T.where(ax1_0 * 16 + 
(((ax1_ax2_fused_0 * 4 + ax1_ax2_fused_1 + ax1_ax2_fused_2) * 32 + 
ax1_ax2_fused_3) * 4 + ax1_ax2_fused_4) // 64 < batch_size)
                                             T.reads(C_reindex_pad_shared[v0, 
v1, v2])
                                             T.writes(C[v1, 0, v2])
-                                            if v1 < batch_size:
-                                                C[v1, 0, v2] = 
C_reindex_pad_shared[v0, v1, v2]
+                                            C[v1, 0, v2] = 
C_reindex_pad_shared[v0, v1, v2]
 
 
 if __name__ == "__main__":
diff --git 
a/tests/python/meta_schedule/test_meta_schedule_schedule_rule_mlt_tc.py 
b/tests/python/meta_schedule/test_meta_schedule_schedule_rule_mlt_tc.py
index da00f294ba..df8607e551 100644
--- a/tests/python/meta_schedule/test_meta_schedule_schedule_rule_mlt_tc.py
+++ b/tests/python/meta_schedule/test_meta_schedule_schedule_rule_mlt_tc.py
@@ -856,11 +856,11 @@ def test_padded_matmul_relu():
                             v3 = T.axis.spatial(1, 0)
                             v4 = T.axis.spatial(16, ax0_ax1_ax3_ax4_ax5_fused 
% 256 // 16)
                             v5 = T.axis.spatial(16, ax0_ax1_ax3_ax4_ax5_fused 
% 16)
+                            T.where(ax0_0_0_ax1_0_0_fused // 2 * 32 + ax2 * 16 
+ ax0_ax1_ax3_ax4_ax5_fused % 256 // 16 < 127 and ax0_0_0_ax1_0_0_fused % 2 * 
64 + ax0_0_1_ax1_0_1_fused * 32 + ax0_ax1_ax3_ax4_ax5_fused // 256 * 16 + 
ax0_ax1_ax3_ax4_ax5_fused % 16 < 127)
                             T.reads(C_reindex_shared[v0, v1, v2, v3, v4, v5])
                             T.writes(compute[v4 + v2 * 16 + v0 * 32, v5 + v1 * 
16])
                             T.block_attr({"meta_schedule.cooperative_fetch": 
4})
-                            if v0 * 32 + v2 * 16 + v4 < 127 and v1 * 16 + v5 < 
127:
-                                compute[v4 + v2 * 16 + v0 * 32, v5 + v1 * 16] 
= T.max(C_reindex_shared[v0, v1, v2, v3, v4, v5], T.float32(0))
+                            compute[v4 + v2 * 16 + v0 * 32, v5 + v1 * 16] = 
T.max(C_reindex_shared[v0, v1, v2, v3, v4, v5], T.float32(0))
     # fmt: on
 
     decision_0 = [
diff --git a/tests/python/tir-schedule/test_tir_schedule_compute_inline.py 
b/tests/python/tir-schedule/test_tir_schedule_compute_inline.py
index 5cf59985d3..2f779612a7 100644
--- a/tests/python/tir-schedule/test_tir_schedule_compute_inline.py
+++ b/tests/python/tir-schedule/test_tir_schedule_compute_inline.py
@@ -624,8 +624,8 @@ def elementwise_overcomputed_producer_reverse_inlined(
     for i, j in T.grid(128, 128):
         with T.block("B"):
             vi, vj = T.axis.remap("SS", [i, j])
-            if vi < 127 and vj < 127:
-                C[vi, vj] = A[vi, vj] * 2.0 + 1.0
+            T.where(i < 127 and j < 127)
+            C[vi, vj] = A[vi, vj] * 2.0 + 1.0
 
 
 @T.prim_func
@@ -652,8 +652,8 @@ def 
elementwise_overcomputed_producer_simplify_predicate_reverse_inlined(
         with T.block("B"):
             vi = T.axis.spatial(128, i // 128)
             vj = T.axis.spatial(128, i % 128)
-            if vi < 127 and vj < 127:
-                C[vi, vj] = A[vi, vj] * 2.0 + 1.0
+            T.where(i < 16255 and i % 128 < 127)
+            C[vi, vj] = A[vi, vj] * 2.0 + 1.0
 
 
 @T.prim_func
@@ -678,8 +678,8 @@ def 
elementwise_overcomputed_producer_injective_load_reverse_inlined(
     for i0, j0, i1, j1 in T.grid(8, 8, 16, 16):
         with T.block("B"):
             vi, vj, vm, vn = T.axis.remap("SSSS", [i0, j0, i1, j1])
-            if vi * 16 + vm < 127 and vj * 16 + vn < 127:
-                C[vm + vi * 16, vn + vj * 16] = A[vi * 16 + vm, vj * 16 + vn] 
* 2.0 + 1.0
+            T.where(i0 * 16 + i1 < 127 and j0 * 16 + j1 < 127)
+            C[vm + vi * 16, vn + vj * 16] = A[vi * 16 + vm, vj * 16 + vn] * 
2.0 + 1.0
 
 
 @T.prim_func
@@ -740,8 +740,7 @@ def elementwise_predicate_producer_inlined(a: T.handle, c: 
T.handle) -> None:
             vi, vj = T.axis.remap("SS", [i, j])
             T.reads(A[vi, vj])
             T.writes(C[vi, vj])
-            if vi < 127:
-                C[vi, vj] = A[vi, vj] * T.float32(2) + T.float32(1)
+            C[vi, vj] = A[vi, vj] * T.float32(2) + T.float32(1)
 
 
 # fmt: off
@@ -1486,5 +1485,49 @@ def test_reverse_compute_inline_layer_norm():
     assert_structural_equal_ignore_global_symbol(after, sch.mod["main"])
 
 
+def test_reverse_compute_inline_slicing_then_cachewrite():
+    @T.prim_func
+    def before(
+        x: T.Buffer((1, 16, 7, 7), "float32"),
+        T_strided_slice_with_axes: T.Buffer((1, 12, 7, 7), "float32"),
+    ):
+        T_add = T.alloc_buffer((1, 16, 7, 7))
+        for ax0, ax1, ax2, ax3 in T.grid(1, 16, 7, 7):
+            with T.block("T_add"):
+                v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, 
ax2, ax3])
+                T_add[v_ax0, v_ax1, v_ax2, v_ax3] = x[v_ax0, v_ax1, v_ax2, 
v_ax3] + T.float32(1)
+        for ax0, ax1, ax2, ax3 in T.grid(1, 12, 7, 7):
+            with T.block("T_strided_slice_with_axes"):
+                v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, 
ax2, ax3])
+                T_strided_slice_with_axes[v_ax0, v_ax1, v_ax2, v_ax3] = T_add[
+                    v_ax0, v_ax1, v_ax2, v_ax3
+                ]
+
+    @T.prim_func
+    def after(
+        x: T.Buffer((1, 16, 7, 7), "float32"),
+        T_strided_slice_with_axes: T.Buffer((1, 12, 7, 7), "float32"),
+    ):
+        T_strided_slice_with_axes_global = T.alloc_buffer((1, 12, 7, 7))
+        for ax0, ax1, ax2, ax3 in T.grid(1, 16, 7, 7):
+            with T.block("T_add"):
+                v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, 
ax2, ax3])
+                T.where(ax1 < 12)
+                T_strided_slice_with_axes_global[v_ax0, v_ax1, v_ax2, v_ax3] = 
x[
+                    v_ax0, v_ax1, v_ax2, v_ax3
+                ] + T.float32(1)
+        for ax0, ax1, ax2, ax3 in T.grid(1, 12, 7, 7):
+            with T.block("T_strided_slice_with_axes_global"):
+                v0, v1, v2, v3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
+                T_strided_slice_with_axes[v0, v1, v2, v3] = 
T_strided_slice_with_axes_global[
+                    v0, v1, v2, v3
+                ]
+
+    sch = tir.Schedule(before)
+    sch.reverse_compute_inline(sch.get_block("T_strided_slice_with_axes"))
+    sch.cache_write(sch.get_block("T_add"), 0, "global")
+    assert_structural_equal_ignore_global_symbol(after, sch.mod["main"])
+
+
 if __name__ == "__main__":
     tvm.testing.main()

Reply via email to