This is an automated email from the ASF dual-hosted git repository.

syfeng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 424c749a3d [MetaSchedule] Tile and pack intermediate output for CUDA 
TensorCore (#14108)
424c749a3d is described below

commit 424c749a3dac0ba42e89d3cbd04b024658d7d104
Author: Wuwei Lin <[email protected]>
AuthorDate: Mon Mar 6 03:24:33 2023 -0800

    [MetaSchedule] Tile and pack intermediate output for CUDA TensorCore 
(#14108)
    
    * [MetaSchedule] Tile and pack intermediate output for CUDA TensorCore
    
    * clean up schedule rule mltc
    
    * add lhs analyzer
    
    * prevent simplifying single point
    
    * clean up
    
    * lint
    
    * fix rewrite_tensorize test
    
    * fix software pipeline test
    
    * fix compile on mac
    
    * fix test cases
    
    * remove unused
    
    * rebase
    
    * only use json format for roundtrip
    
    * lint
    
    * Update src/tir/schedule/ir_comparator.h
    
    Co-authored-by: Siyuan Feng <[email protected]>
    
    ---------
    
    Co-authored-by: Tianqi Chen <[email protected]>
    Co-authored-by: Siyuan Feng <[email protected]>
---
 .../tvm/meta_schedule/testing/space_generation.py  |   2 +-
 src/meta_schedule/postproc/verify_gpu_code.cc      |   1 +
 .../schedule_rule/multi_level_tiling.cc            |  13 +-
 .../schedule_rule/multi_level_tiling.h             |   8 +-
 .../multi_level_tiling_tensor_core.cc              | 176 ++++-
 .../multi_level_tiling_wide_vector.cc              |  15 +-
 src/tir/analysis/block_access_region_detector.cc   |   9 +-
 src/tir/schedule/ir_comparator.cc                  |  10 +-
 src/tir/schedule/ir_comparator.h                   |   7 +-
 .../test_meta_schedule_schedule_rule_mlt_tc.py     | 783 +++++++++------------
 .../test_tir_transform_inject_software_pipeline.py |  16 +-
 11 files changed, 567 insertions(+), 473 deletions(-)

diff --git a/python/tvm/meta_schedule/testing/space_generation.py 
b/python/tvm/meta_schedule/testing/space_generation.py
index 0b7072b65a..45cd6659b6 100644
--- a/python/tvm/meta_schedule/testing/space_generation.py
+++ b/python/tvm/meta_schedule/testing/space_generation.py
@@ -88,7 +88,7 @@ def _find_match_sketch_id(
             decisions=new_decisions,
         ).apply_to_schedule(sch, remove_postproc=True)
         if structural_equal(sch.mod, expected_mod):
-            verify_trace_roundtrip(sch=sch, mod=mod, debug_mask=debug_mask)
+            verify_trace_roundtrip(sch=sch, mod=mod, debug_mask=debug_mask, 
text_format="json")
             return sketch_id
     return None
 
diff --git a/src/meta_schedule/postproc/verify_gpu_code.cc 
b/src/meta_schedule/postproc/verify_gpu_code.cc
index 99ffc1bfcd..6f9b46a0f7 100644
--- a/src/meta_schedule/postproc/verify_gpu_code.cc
+++ b/src/meta_schedule/postproc/verify_gpu_code.cc
@@ -162,6 +162,7 @@ class VerifyGPUCodeNode : public PostprocNode {
           
pass_list.push_back(tir::transform::PlanAndUpdateBufferAllocationLocation());
           pass_list.push_back(tir::transform::ConvertBlocksToOpaque());
           pass_list.push_back(tir::transform::UnifyThreadBinding());
+          
pass_list.push_back(tir::transform::ManifestSharedMemoryLocalStage());
           pass_list.push_back(tir::transform::CompactBufferAllocation());
           pass_list.push_back(tir::transform::LowerMatchBuffer());
           pass_list.push_back(tir::transform::InjectSoftwarePipeline());
diff --git a/src/meta_schedule/schedule_rule/multi_level_tiling.cc 
b/src/meta_schedule/schedule_rule/multi_level_tiling.cc
index 779114e9cf..0312c100b5 100644
--- a/src/meta_schedule/schedule_rule/multi_level_tiling.cc
+++ b/src/meta_schedule/schedule_rule/multi_level_tiling.cc
@@ -186,15 +186,15 @@ std::vector<State> 
MultiLevelTilingNode::AddWriteReuse(State state) const {
   return results;
 }
 
-Array<tir::LoopRV> MultiLevelTilingNode::SplitLoop(const Schedule& sch, 
BlockRV block, LoopRV loop,
-                                                   int n_tiles) const {
+std::pair<Array<tir::ExprRV>, Array<tir::LoopRV>> 
MultiLevelTilingNode::SplitLoop(
+    const Schedule& sch, BlockRV block, LoopRV loop, int n_tiles) const {
   Array<tir::ExprRV> factors = sch->SamplePerfectTile(
       /*loop=*/loop,
       /*n=*/n_tiles,
       /*max_innermost_factor=*/max_innermost_factor);
   Array<tir::LoopRV> splits = sch->Split(/*loop=*/loop,
                                          /*factors=*/{factors.begin(), 
factors.end()});
-  return splits;
+  return {factors, splits};
 }
 
 std::vector<State> MultiLevelTilingNode::TileLoopNest(State state) const {
@@ -207,6 +207,9 @@ std::vector<State> MultiLevelTilingNode::TileLoopNest(State 
state) const {
   // Step 2. For each loop axis, tile it
   int64_t spatial_loop_product = 1;
   std::vector<Array<LoopRV>> tiles(s_indices_.size() + r_indices_.size());
+  state->tile_factors.resize(tiles.size());
+  std::vector<Array<tir::ExprRV>> tile_factors;
+  tile_factors.resize(tiles.size());
   for (int i = 0, n = loops.size(); i < n; ++i) {
     LoopRV loop = loops[i];
     const std::vector<int>* idx = nullptr;
@@ -231,14 +234,16 @@ std::vector<State> 
MultiLevelTilingNode::TileLoopNest(State state) const {
     if (n_tiles == 1) {
       tiles[idx->at(0)].push_back(loop);
     } else {
-      auto splits = SplitLoop(sch, block_rv, loop, n_tiles);
+      auto [factors, splits] = SplitLoop(sch, block_rv, loop, n_tiles);
 
       // Put every tile to its slot
       for (int j = 0; j < n_tiles; ++j) {
         tiles[idx->at(j)].push_back(splits[j]);
+        tile_factors[idx->at(j)].push_back(factors[j]);
       }
     }
   }
+  state->tile_factors = std::move(tile_factors);
   // Step 3. Reorder to organize the tiles
   sch->Reorder(support::ConcatArrayList<LoopRV>(tiles.begin(), tiles.end()));
   // Step 4. Bind the tiles to threads
diff --git a/src/meta_schedule/schedule_rule/multi_level_tiling.h 
b/src/meta_schedule/schedule_rule/multi_level_tiling.h
index ff38756ff0..41b3ca9f26 100644
--- a/src/meta_schedule/schedule_rule/multi_level_tiling.h
+++ b/src/meta_schedule/schedule_rule/multi_level_tiling.h
@@ -94,6 +94,8 @@ class StateNode : public Object {
   tir::BlockRV block_rv;
   /*! \brief The loop tiles */
   Array<Array<tir::LoopRV>> tiles;
+  /*! \brief The factors of the loop tiles. */
+  Array<Array<tir::ExprRV>> tile_factors;
   /*! \brief The mapping from buffer index to read cache block. */
   std::unordered_map<int, tir::BlockRV> read_reuse;
   /*! \brief The mapping from buffer index to write cache block. */
@@ -163,8 +165,10 @@ class MultiLevelTilingNode : public ScheduleRuleNode {
  protected:
   virtual std::vector<State> ApplySubRules(std::vector<State> states);
 
-  virtual Array<tir::LoopRV> SplitLoop(const tir::Schedule& sch, tir::BlockRV 
block,
-                                       tir::LoopRV loop, int n_tiles) const;
+  virtual std::pair<Array<tir::ExprRV>, Array<tir::LoopRV>> SplitLoop(const 
tir::Schedule& sch,
+                                                                      
tir::BlockRV block,
+                                                                      
tir::LoopRV loop,
+                                                                      int 
n_tiles) const;
 
   // Annotate a block to use cooperative fetching
   void AnnotateCooperativeFetching(tir::Schedule* sch, const tir::BlockRV& 
block) const;
diff --git a/src/meta_schedule/schedule_rule/multi_level_tiling_tensor_core.cc 
b/src/meta_schedule/schedule_rule/multi_level_tiling_tensor_core.cc
index d5cca52d41..1f9945022b 100644
--- a/src/meta_schedule/schedule_rule/multi_level_tiling_tensor_core.cc
+++ b/src/meta_schedule/schedule_rule/multi_level_tiling_tensor_core.cc
@@ -17,6 +17,7 @@
  * under the License.
  */
 #include <tvm/meta_schedule/schedule_rule.h>
+#include <tvm/tir/op.h>
 
 #include <algorithm>
 #include <utility>
@@ -124,6 +125,9 @@ class MultiLevelTilingTensorCoreNode : public 
MultiLevelTilingNode {
  private:
   // SubRule: Add tensorization-related transformations
   inline std::vector<State> TransformForTensorization(TensorCoreState state) 
const;
+  // Subrule: Transform the layout of the output. This is necessary for an efficient cache
+  // write of the output to shared memory.
+  std::vector<State> TransformIntermediateOutputLayout(TensorCoreState state);
   // Subrule: Add tensorized load
   inline std::vector<State> AddReadReuseTensorCore(TensorCoreState state) 
const;
   // Subrule: Add tensorized store
@@ -225,6 +229,9 @@ std::vector<State> 
MultiLevelTilingTensorCoreNode::ApplySubRules(std::vector<Sta
     return TransformForTensorization(Downcast<TensorCoreState>(state));
   });
   states = SubRule(std::move(states), [&](State state) { return 
TileLoopNest(state); });
+  states = SubRule(std::move(states), [&](State state) {
+    return TransformIntermediateOutputLayout(Downcast<TensorCoreState>(state));
+  });
   states = SubRule(std::move(states), [&](State state) { return 
AddWriteReuse(state); });
   states = SubRule(std::move(states), [&](State state) {
     return AddWriteReuseTensorCore(Downcast<TensorCoreState>(state));
@@ -248,25 +255,162 @@ void 
MultiLevelTilingTensorCoreNode::TileAndAnnotateTensorize(Schedule* sch,
   (*sch)->Annotate(blockized_outer, tir::attr::meta_schedule_auto_tensorize, 
intrin_name);
 }
 
+std::vector<State> 
MultiLevelTilingTensorCoreNode::TransformIntermediateOutputLayout(
+    TensorCoreState state) {
+  // Transform the intermediate output to packed layout
+  //   [..., warp_m, warp_n, accum_frag_m, accum_frag_n, accum_elem_m, 
accum_elem_n]
+  // where warp_m, warp_n are thread indices bound to the warp id, 
accum_frag_m, accum_frag_n are
+  // the index of the fragments in each warp, accum_elem_m, accum_elem_n are 
the index of the
+  // elements in each accumulator fragment.
+
+  // Get the shape of the wmma accumulator
+  auto [frag_shape_m, frag_shape_n] = [&]() {
+    tir::Block intrin_block =
+        Downcast<tir::BlockRealize>(
+            
tir::TensorIntrin::Get(state->intrin_group.init_intrin).value()->desc->body)
+            ->block;
+    tir::For loop_m = Downcast<tir::For>(intrin_block->body);
+    tir::For loop_n = Downcast<tir::For>(loop_m->body);
+    return std::make_tuple(loop_m->extent, loop_n->extent);
+  }();
+
+  // Get the tile index of the warp id (i.e. threadIdx.y)
+  auto it = std::find(tile_binds.begin(), tile_binds.end(), "threadIdx.y");
+  ICHECK(it != tile_binds.end());
+  auto tile_index_warp_id = std::distance(tile_binds.begin(), it);
+
+  // Get the extent of loop indicated by `loop_idx` inside the warp scope.
+  // For example, after spatial loops i, j are tiled, we will have
+  // tile_factors = ((i0, j0), (i1, j1), ..., (in, jn))
+  // It computes the product of tile_factors[s_indices_[i]][loop_idx] for i > tile_index_warp_id.
+  // `loop_idx` can be negative, in which case it is counted from the end.
+  auto f_get_inner_tile_product = [&](int loop_idx) {
+    Array<tir::ExprRV> factors;
+    for (int i = tile_index_warp_id + 1; i < 
static_cast<int>(s_indices_.size()); ++i) {
+      auto s_factors = state->tile_factors[s_indices_[i]];
+      if (loop_idx < 0) {
+        loop_idx += s_factors.size();
+      }
+      factors.push_back(s_factors[loop_idx]);
+    }
+    ICHECK(!factors.empty());
+    if (factors.size() == 1) {
+      return factors[0];
+    }
+    auto result = factors[0];
+    for (int i = 1; i < static_cast<int>(factors.size()); ++i) {
+      result = result * factors[i];
+    }
+    return result;
+  };
+
+  // Compute the number of output fragments in each warp
+  auto warp_num_frag_m = f_get_inner_tile_product(-2);
+  auto warp_num_frag_n = f_get_inner_tile_product(-1);
+
+  Schedule& sch = state->sch;
+  int buffer_ndim = 
static_cast<int>(sch->Get(state->block_rv)->writes[0]->buffer->shape.size());
+  // The buffer must have at least as many dimensions as the tensor intrinsic (i.e. 2).
+  ICHECK_GE(buffer_ndim, 2);
+  int num_higher_dims = buffer_ndim - 2;
+
+  auto index_map =
+      tir::IndexMap::FromFunc(buffer_ndim,
+                              // frag_shape_m and frag_shape_n are structured bindings, which
+                              // cannot be captured by a lambda until C++20
+                              [&, frag_shape_m = frag_shape_m,
+                               frag_shape_n = frag_shape_n](const 
Array<tir::Var>& indices) {
+                                Array<PrimExpr> result;
+                                result.reserve(indices.size() + 4);
+                                for (int i = 0; i < num_higher_dims; ++i) {
+                                  result.push_back(indices[i]);
+                                }
+                                const auto& m = indices[num_higher_dims];
+                                const auto& n = indices[num_higher_dims + 1];
+                                auto accum_m = floormod(m, frag_shape_m);
+                                auto accum_n = floormod(n, frag_shape_n);
+                                auto outer_m = floordiv(m, frag_shape_m);
+                                auto outer_n = floordiv(n, frag_shape_n);
+
+                                result.push_back(floordiv(outer_m, 
warp_num_frag_m));
+                                result.push_back(floordiv(outer_n, 
warp_num_frag_n));
+                                result.push_back(floormod(outer_m, 
warp_num_frag_m));
+                                result.push_back(floormod(outer_n, 
warp_num_frag_n));
+                                result.push_back(accum_m);
+                                result.push_back(accum_n);
+                                return result;
+                              });
+  sch->TransformLayout(state->block_rv, 0, tir::BufferIndexType::kWrite, 
index_map,
+                       /*pad_value=*/NullOpt, 
/*assume_injective_transform=*/true);
+
+  return {state};
+}
+
 std::vector<State> MultiLevelTilingTensorCoreNode::AddWriteReuseTensorCore(
     TensorCoreState state) const {
   // Add the cache write stage for Tensor Core
-  int level = r_indices_.front() - 1;
-  const LoopRV& loop = state->tiles[level].back();
   Schedule& sch = state->sch;
   auto cache_write = sch->CacheWrite(state->block_rv, 0, "wmma.accumulator");
-  sch->ReverseComputeAt(cache_write, loop, true);
-
-  if (state->write_reuse.count(0)) {
-    // Fuse the iterators of the cache_write
-    Array<LoopRV> buffer_loops = sch->GetLoops(state->write_reuse[0]);
-    ICHECK_GT(buffer_loops.size(), 2);
-    sch->Fuse(Array<LoopRV>{buffer_loops.end() - 2,  // The src shmem is 
always 2D
-                            buffer_loops.end()});
-    AnnotateCooperativeFetching(&sch, state->write_reuse[0]);
+
+  // The compute block has been tiled by the warp shape and the fragment shape.
+  // We need to bind the cache write block (from the accumulator to the shared 
memory) to the warp
+  // id. The schedule is as follows:
+  //
+  // After adding cache write for wmma.accumulator, we will have
+  //   for i0, j0, i1, j1, accum_m, accum_n:
+  //     shared_mem[i0, j0, i1, j1, accum_m, accum_n] = accum[i0, j0, i1, j1, 
accum_m, accum_n]
+  //   for i0', j0', i1', j1', accum_m', accum_n':
+  //      global_mem[i0', j0', i1', j1', accum_m', accum_n'] =
+  //        shared_mem[i0', j0', i1', j1', accum_m', accum_n']
+  // where i0' and j0' are already bound to the block id and warp id.
+  //
+  // To reduce the shared memory usage and allow efficient data movement, we 
will apply
+  // transformations to generate the following schedule:
+  //
+  //   for i1':
+  //     for i0_j0 (fused and bound to threadIdx.y):
+  //       for j1, accum_m, accum_n:
+  //         shared_mem[i0, j0, i1, j1, accum_m, accum_n] = accum[i0, j0, i1, 
j1, accum_m, accum_n]
+  //     for i0', j0', j1', accum_m', accum_n':
+  //       global_mem[i0', j0', i1', j1', accum_m', accum_n'] =
+  //         shared_mem[i0', j0', i1', j1', accum_m', accum_n']
+  //
+  // i1' is reordered to the outermost position. As a result, only one row of fragments (i.e.
+  // one iteration of loop i1') is moved to the shared memory and then to the global memory at
+  // a time, so the shared memory for the output only has the shape [j1, accum_m, accum_n]
+  // instead of [i0 * i1 * accum_m, j0 * j1 * accum_n].
+
+  // Get the four loops immediately enclosing the innermost two loops (accum_m and accum_n).
+  auto f_get_loops = [&](const BlockRV& block_rv) -> std::array<LoopRV, 4> {
+    Array<LoopRV> buffer_loops = sch->GetLoops(block_rv);
+    ICHECK_GT(buffer_loops.size(), 6);
+    return {buffer_loops[buffer_loops.size() - 6], 
buffer_loops[buffer_loops.size() - 5],
+            buffer_loops[buffer_loops.size() - 4], 
buffer_loops[buffer_loops.size() - 3]};
+  };
+  {
+    const auto& [i0, j0, i1, j1] = f_get_loops(state->write_reuse[0]);
+    sch->Reorder({i1, i0, j0, j1});
+    sch->ComputeAt(cache_write, i1, true);
+  }
+  {
+    auto loops = f_get_loops(cache_write);
+    const auto& i0 = loops[0];
+    const auto& j0 = loops[1];
+    auto fused = sch->Fuse({i0, j0});
+    sch->Bind(fused, "threadIdx.y");
   }
+
   sch->ReverseComputeInline(state->tensor_core_reindex_store);
-  TileAndAnnotateTensorize(&sch, cache_write, 
state->intrin_group.store_intrin);
+  auto loops = sch->GetLoops(cache_write);
+  auto blockized_store = sch->Blockize(loops[loops.size() - 2]);
+  sch->Annotate(blockized_store, tir::attr::meta_schedule_auto_tensorize,
+                state->intrin_group.store_intrin);
+
+  Array<LoopRV> buffer_loops = sch->GetLoops(state->write_reuse[0]);
+  ICHECK_GT(buffer_loops.size(), 5);
+  sch->Fuse(Array<LoopRV>{buffer_loops.end() - 5,  // Fuse i0', j0', j1', accum_m', accum_n'
+                          buffer_loops.end()});
+  AnnotateCooperativeFetching(&sch, state->write_reuse[0]);
   return {state};
 }
 
@@ -508,7 +652,8 @@ Optional<LoopRV> 
MultiLevelTilingTensorCoreNode::TransformWithTensorIntrin(
         state->sch->state(), GetRef<tir::Block>(block), buffer_index, 
index_type);
     auto sub_index_map = f_get_sub_index_map(lhs_buffer, 
reindexed_buffer_region->region);
     buffer_sub_index_map.Set(lhs_buffer, sub_index_map);
-    state->sch->TransformLayout(state->block_rv, buffer_index, index_type, 
sub_index_map, NullOpt);
+    state->sch->TransformLayout(state->block_rv, buffer_index, index_type, 
sub_index_map,
+                                /*pad_value=*/NullOpt, 
/*assume_injective_transform=*/true);
   };
 
   for (int i = 0, n = block_before_reindex->reads.size(); i < n; ++i) {
@@ -569,6 +714,11 @@ ScheduleRule ScheduleRule::MultiLevelTilingTensorCore(
   auto node = MultiLevelTilingInitCommon<MultiLevelTilingTensorCoreNode>(
       structure, tile_binds, max_innermost_factor, vector_load_lens, 
reuse_read, reuse_write);
 
+  CHECK(node->reuse_write_.req == ReuseType::kMustReuse &&
+        runtime::StorageScope::Create(node->reuse_write_.scope).rank ==
+            runtime::StorageRank::kShared)
+      << "ValueError: Shared memory write reuse must be enabled for 
MultiLevelTilingTensorCore.";
+
   node->intrin_groups.reserve(intrin_groups.size());
   for (const auto& intrin_group_config : intrin_groups) {
     
node->intrin_groups.emplace_back(TensorCoreIntrinGroup::FromConfig(intrin_group_config));
diff --git a/src/meta_schedule/schedule_rule/multi_level_tiling_wide_vector.cc 
b/src/meta_schedule/schedule_rule/multi_level_tiling_wide_vector.cc
index d4c4a10fdd..e68b64ea2d 100644
--- a/src/meta_schedule/schedule_rule/multi_level_tiling_wide_vector.cc
+++ b/src/meta_schedule/schedule_rule/multi_level_tiling_wide_vector.cc
@@ -48,11 +48,12 @@ class MultiLevelTilingWideVectorNode : public 
MultiLevelTilingNode {
     return ScheduleRule(n);
   }
 
-  Array<tir::LoopRV> SplitLoop(const Schedule& sch, BlockRV block, LoopRV 
loop, int n_tiles) const;
+  std::pair<Array<tir::ExprRV>, Array<tir::LoopRV>> SplitLoop(const Schedule& 
sch, BlockRV block,
+                                                              LoopRV loop, int 
n_tiles) const;
 };
 
-Array<tir::LoopRV> MultiLevelTilingWideVectorNode::SplitLoop(const Schedule& 
sch, BlockRV block_rv,
-                                                             LoopRV loop_rv, 
int n_tiles) const {
+std::pair<Array<tir::ExprRV>, Array<tir::LoopRV>> 
MultiLevelTilingWideVectorNode::SplitLoop(
+    const Schedule& sch, BlockRV block_rv, LoopRV loop_rv, int n_tiles) const {
   const tir::ForNode* loop = TVM_SREF_TO_FOR(sch->GetSRef(loop_rv));
   const tir::StmtSRef block_sref = sch->GetSRef(block_rv);
   const tir::BlockNode* block_node = block_sref->StmtAs<tir::BlockNode>();
@@ -99,12 +100,14 @@ Array<tir::LoopRV> 
MultiLevelTilingWideVectorNode::SplitLoop(const Schedule& sch
       Array<tir::LoopRV> outer_splits = sch->Split(
           /*loop=*/inner_splits[0], /*factors=*/{outer_factors.begin(), 
outer_factors.end()});
       outer_splits.push_back(inner_splits[1]);
-      return outer_splits;
+      outer_factors.push_back(PrimExpr(vec_len));
+      return {outer_factors, outer_splits};
     } else {
       Array<tir::ExprRV> factors(n_tiles - 1, PrimExpr(1));
       factors.push_back(loop->extent);
-      return sch->Split(/*loop=*/loop_rv,
-                        /*factors=*/{factors.begin(), factors.end()});
+      Array<tir::LoopRV> splits = sch->Split(/*loop=*/loop_rv,
+                                             /*factors=*/{factors.begin(), 
factors.end()});
+      return {factors, splits};
     }
   }
 }
diff --git a/src/tir/analysis/block_access_region_detector.cc 
b/src/tir/analysis/block_access_region_detector.cc
index e9bff1b6fd..ab328efaa6 100644
--- a/src/tir/analysis/block_access_region_detector.cc
+++ b/src/tir/analysis/block_access_region_detector.cc
@@ -76,8 +76,6 @@ class BlockReadWriteDetector : public StmtExprVisitor {
   Map<Var, Buffer> buffer_var_map_;
   /*! \brief The target buffer var mapping to its matching */
   std::unordered_map<const VarNode*, MatchBufferRegion> match_buffers_;
-  /*! \brief The analyzer for simplifying*/
-  arith::Analyzer analyzer_;
 
   /*!
    * \brief Update read/write buffers and regions with provided buffer and 
region
@@ -330,7 +328,12 @@ Array<BufferRegion> BlockReadWriteDetector::CollectRegions(
     ICHECK_EQ(buffers[i]->shape.size(), regions[i].size());
     for (size_t j = 0; j < regions[i].size(); j++) {
       const tvm::arith::IntSet& range = regions[i][j];
-      region.push_back(range.CoverRange(Range::FromMinExtent(0, 
buffers[i]->shape[j])));
+      if (range.IsSinglePoint()) {
+        PrimExpr min = range.min();
+        region.push_back(Range::FromMinExtent(min, make_const(min.dtype(), 
1)));
+      } else {
+        region.push_back(range.CoverRange(Range::FromMinExtent(0, 
buffers[i]->shape[j])));
+      }
     }
     res.push_back(BufferRegion(buffers[i], region));
   }
diff --git a/src/tir/schedule/ir_comparator.cc 
b/src/tir/schedule/ir_comparator.cc
index 9d89c64163..5353a051a6 100644
--- a/src/tir/schedule/ir_comparator.cc
+++ b/src/tir/schedule/ir_comparator.cc
@@ -43,7 +43,7 @@ class TensorIntrinMismatchError : public ScheduleError {
     std::ostringstream os;
     os << "The stmt {0} doesn't match the tensor intrin\nThe pattern 
attempting to be matched:\n"
        << lhs_stmt_ << "\nDoes not match the tensorize description:\n"
-       << rhs_stmt_;
+       << rhs_stmt_ << '\n';
     for (const auto& msg : error_messages_) {
       os << msg << std::endl;
     }
@@ -173,6 +173,9 @@ bool TensorizeComparator::VisitStmt_(const 
BlockRealizeNode* op, const Stmt& oth
 
 bool TensorizeComparator::VisitStmt_(const BlockNode* op, const Stmt& other) {
   const auto* rhs = other.as<BlockNode>();
+  for (const IterVar& iter : op->iter_vars) {
+    lhs_analyzer_.Bind(iter->var, iter->dom);
+  }
   // Check block equality.
   // All iter vars and buffer regions including the order should match.
   // When checking iter vars, DefEqual is used to remap variables.
@@ -465,7 +468,7 @@ bool TensorizeComparator::CompareBufferRegion(const 
BufferRegion& lhs, const Buf
         }
         return false;
       }
-      if (!analyzer_.CanProveEqual(indices_base[i], lhs->region[i]->min)) {
+      if (!lhs_analyzer_.CanProveEqual(indices_base[i], lhs->region[i]->min)) {
         if (assert_mode_) {
           std::ostringstream os;
           os << "Buffer base index consistency check failed due to unequal 
index base: "
@@ -487,7 +490,8 @@ bool TensorizeComparator::CompareBufferRegion(const 
BufferRegion& lhs, const Buf
         }
         return false;
       }
-      PrimExpr normalized_lhs_min = (lhs->region[i + offset]->min - 
indices_base[i + offset]);
+      PrimExpr normalized_lhs_min =
+          lhs_analyzer_.Simplify((lhs->region[i + offset]->min - 
indices_base[i + offset]));
       if (!analyzer_.CanProveEqual(normalized_lhs_min, rhs->region[i]->min)) {
         if (assert_mode_) {
           std::ostringstream os;
diff --git a/src/tir/schedule/ir_comparator.h b/src/tir/schedule/ir_comparator.h
index 394d828673..debf0f946e 100644
--- a/src/tir/schedule/ir_comparator.h
+++ b/src/tir/schedule/ir_comparator.h
@@ -102,8 +102,13 @@ class TensorizeComparator : public ExprComparator, public 
StmtComparator {
   bool assert_mode_;
   /*! \brief Whether it is visiting the scope block (the outermost block). */
   bool is_scope_block = true;
-  /*! \brief The arithmetic analyzer. */
+  /*! \brief The arithmetic analyzer for comparing LHS and RHS */
   arith::Analyzer analyzer_;
+  /*!
+   * \brief The arithmetic analyzer for simplifying expressions on LHS.
+   *  This analyzer only contains the domains of the iterators on LHS.
+   */
+  arith::Analyzer lhs_analyzer_;
   /*! \brief Additional error messages. Only used when assert_mode is true. */
   std::vector<std::string> error_messages_;
   // variable remap if any
diff --git a/tests/python/unittest/test_meta_schedule_schedule_rule_mlt_tc.py 
b/tests/python/unittest/test_meta_schedule_schedule_rule_mlt_tc.py
index 9b869b4436..1cab2554e8 100644
--- a/tests/python/unittest/test_meta_schedule_schedule_rule_mlt_tc.py
+++ b/tests/python/unittest/test_meta_schedule_schedule_rule_mlt_tc.py
@@ -83,39 +83,39 @@ def test_matmul_relu(shared_scope):
     @T.prim_func
     def matmul_relu_0(A: T.Buffer((128, 128), "float16"), B: T.Buffer((128, 
128), "float16"), compute: T.Buffer((128, 128), "float32")) -> None:
         T.func_attr({"global_symbol": "main", "tir.noalias": True})
-        C_reindex_shared = T.alloc_buffer([128, 128], dtype="float32", 
scope=shared_scope)
-        C_reindex_shared_wmma_accumulator = T.alloc_buffer([128, 128], 
dtype="float32", scope="wmma.accumulator")
-        A_reindex_shared = T.alloc_buffer([128, 128], dtype="float16", 
scope=shared_scope)
-        B_reindex_shared = T.alloc_buffer([128, 128], dtype="float16", 
scope=shared_scope)
-        A_reindex_shared_wmma_matrix_a = T.alloc_buffer([128, 128], 
dtype="float16", scope="wmma.matrix_a")
-        B_reindex_shared_wmma_matrix_b = T.alloc_buffer([128, 128], 
dtype="float16", scope="wmma.matrix_b")
+        C_reindex_shared = T.alloc_buffer((4, 8, 2, 1, 16, 16), 
scope=shared_scope)
+        C_reindex_shared_wmma_accumulator = T.alloc_buffer((4, 8, 2, 1, 16, 
16), scope="wmma.accumulator")
+        A_reindex_shared = T.alloc_buffer((128, 128), "float16", 
scope=shared_scope)
+        B_reindex_shared = T.alloc_buffer((128, 128), "float16", 
scope=shared_scope)
+        A_reindex_shared_wmma_matrix_a = T.alloc_buffer((128, 128), "float16", 
scope="wmma.matrix_a")
+        B_reindex_shared_wmma_matrix_b = T.alloc_buffer((128, 128), "float16", 
scope="wmma.matrix_b")
         for ax0_0_0_ax1_0_0_fused in T.thread_binding(8, thread="blockIdx.y"):
             for ax0_0_1_ax1_0_1_fused in T.thread_binding(2, 
thread="blockIdx.x"):
                 for ax0_0_2_ax1_0_2_fused in T.thread_binding(2, 
thread="threadIdx.y"):
-                    for ax2_0_0 in T.serial(1):
-                        for ax0_ax1_fused in T.serial(4096):
+                    for ax2_0_0 in range(1):
+                        for ax0_ax1_fused in range(4096):
                             with T.block("A_reindex_shared"):
                                 v0 = T.axis.spatial(128, ax0_0_0_ax1_0_0_fused 
// 2 * 32 + ax0_ax1_fused // 128)
                                 v1 = T.axis.spatial(128, ax0_ax1_fused % 128)
                                 T.reads(A[v0, v1])
                                 T.writes(A_reindex_shared[v0, v1])
-                                T.block_attr({"buffer_dim_align":[[0, 0, 32, 
8]], "meta_schedule.cooperative_fetch":8})
+                                T.block_attr({"buffer_dim_align": [[0, 0, 32, 
8]], "meta_schedule.cooperative_fetch": 8})
                                 A_reindex_shared[v0, v1] = A[v0, v1]
-                        for ax0_ax1_fused in T.serial(4096):
+                        for ax0_ax1_fused in range(4096):
                             with T.block("B_reindex_shared"):
                                 v0 = T.axis.spatial(128, ax0_ax1_fused // 32)
                                 v1 = T.axis.spatial(128, ax0_0_0_ax1_0_0_fused 
% 2 * 64 + ax0_0_1_ax1_0_1_fused * 32 + ax0_ax1_fused % 32)
                                 T.reads(B[v0, v1])
                                 T.writes(B_reindex_shared[v0, v1])
-                                T.block_attr({"buffer_dim_align":[[0, 0, 32, 
8]], "meta_schedule.cooperative_fetch":1})
+                                T.block_attr({"buffer_dim_align": [[0, 0, 32, 
8]], "meta_schedule.cooperative_fetch": 1})
                                 B_reindex_shared[v0, v1] = B[v0, v1]
-                        for ax2_0_1 in T.serial(4):
+                        for ax2_0_1 in range(4):
                             for ax0_0, ax1_0 in T.grid(2, 2):
                                 with 
T.block("A_reindex_shared_wmma.matrix_a_o"):
                                     v0_o = T.axis.spatial(8, 
ax0_0_0_ax1_0_0_fused // 2 * 2 + ax0_0)
                                     v1_o = T.axis.spatial(8, ax2_0_1 * 2 + 
ax1_0)
-                                    T.reads(A_reindex_shared[v0_o * 16 : v0_o 
* 16 + 16, v1_o * 16 : v1_o * 16 + 16])
-                                    
T.writes(A_reindex_shared_wmma_matrix_a[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : 
v1_o * 16 + 16])
+                                    T.reads(A_reindex_shared[v0_o * 16:v0_o * 
16 + 16, v1_o * 16:v1_o * 16 + 16])
+                                    
T.writes(A_reindex_shared_wmma_matrix_a[v0_o * 16:v0_o * 16 + 16, v1_o * 
16:v1_o * 16 + 16])
                                     
T.block_attr({"meta_schedule.auto_tensorize": 
f"wmma_load_16x16x16_f16_a_{intrin_suffix}"})
                                     for ax0_1, ax1_1 in T.grid(16, 16):
                                         with 
T.block("A_reindex_shared_wmma.matrix_a"):
@@ -127,8 +127,8 @@ def test_matmul_relu(shared_scope):
                                 with 
T.block("B_reindex_shared_wmma.matrix_b_o"):
                                     v0_o = T.axis.spatial(8, ax2_0_1 * 2 + 
ax0_0)
                                     v1_o = T.axis.spatial(8, 
ax0_0_0_ax1_0_0_fused % 2 * 4 + ax0_0_1_ax1_0_1_fused * 2 + 
ax0_0_2_ax1_0_2_fused + ax1_0)
-                                    T.reads(B_reindex_shared[v0_o * 16 : v0_o 
* 16 + 16, v1_o * 16 : v1_o * 16 + 16])
-                                    
T.writes(B_reindex_shared_wmma_matrix_b[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : 
v1_o * 16 + 16])
+                                    T.reads(B_reindex_shared[v0_o * 16:v0_o * 
16 + 16, v1_o * 16:v1_o * 16 + 16])
+                                    
T.writes(B_reindex_shared_wmma_matrix_b[v0_o * 16:v0_o * 16 + 16, v1_o * 
16:v1_o * 16 + 16])
                                     
T.block_attr({"meta_schedule.auto_tensorize": 
f"wmma_load_16x16x16_f16_b_{intrin_suffix}"})
                                     for ax0_1, ax1_1 in T.grid(16, 16):
                                         with 
T.block("B_reindex_shared_wmma.matrix_b"):
@@ -141,44 +141,54 @@ def test_matmul_relu(shared_scope):
                                     v0_o = T.axis.spatial(8, 
ax0_0_0_ax1_0_0_fused // 2 * 2 + ax0_0_3 * 2 + ax0_0_4)
                                     v1_o = T.axis.spatial(8, ax1_0_4 + 
ax0_0_0_ax1_0_0_fused % 2 * 4 + ax0_0_1_ax1_0_1_fused * 2 + 
ax0_0_2_ax1_0_2_fused + ax1_0_3)
                                     v2_o = T.axis.reduce(8, ax2_0_0 * 8 + 
ax2_0_1 * 2 + ax2_0_2)
-                                    
T.reads(A_reindex_shared_wmma_matrix_a[v0_o * 16 : v0_o * 16 + 16, v2_o * 16 : 
v2_o * 16 + 16], B_reindex_shared_wmma_matrix_b[v2_o * 16 : v2_o * 16 + 16, 
v1_o * 16 : v1_o * 16 + 16])
-                                    
T.writes(C_reindex_shared_wmma_accumulator[v0_o * 16 : v0_o * 16 + 16, v1_o * 
16 : v1_o * 16 + 16])
-                                    
T.block_attr({"meta_schedule.auto_tensorize":"wmma_sync_16x16x16_f16f16f32", 
"meta_schedule.auto_tensorize_init":"wmma_fill_16x16x16_f32", 
"warp_execution":1})
+                                    
T.reads(A_reindex_shared_wmma_matrix_a[v0_o * 16:v0_o * 16 + 16, v2_o * 16:v2_o 
* 16 + 16], B_reindex_shared_wmma_matrix_b[v2_o * 16:v2_o * 16 + 16, v1_o * 
16:v1_o * 16 + 16])
+                                    
T.writes(C_reindex_shared_wmma_accumulator[v0_o // 2, v1_o, v0_o % 2, 0, 0:16, 
0:16])
+                                    
T.block_attr({"meta_schedule.auto_tensorize": "wmma_sync_16x16x16_f16f16f32", 
"meta_schedule.auto_tensorize_init": "wmma_fill_16x16x16_f32", 
"warp_execution": 1})
                                     with T.init():
                                         for ax0_1, ax1_1 in T.grid(16, 16):
                                             with T.block("C_init"):
                                                 v0_i_init, v1_i_init = 
T.axis.remap("SS", [ax0_1, ax1_1])
                                                 T.reads()
-                                                
T.writes(C_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i_init, v1_o * 16 + 
v1_i_init])
-                                                
C_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i_init, v1_o * 16 + v1_i_init] 
= T.float32(0)
+                                                
T.writes(C_reindex_shared_wmma_accumulator[v0_o // 2, v1_o, v0_o % 2, 0, 
v0_i_init, v1_i_init])
+                                                
C_reindex_shared_wmma_accumulator[v0_o // 2, v1_o, v0_o % 2, 0, v0_i_init, 
v1_i_init] = T.float32(0)
                                     for ax0_1, ax1_1, ax2_1 in T.grid(16, 16, 
16):
                                         with T.block("C"):
                                             v0_i, v1_i, v2_i = 
T.axis.remap("SSR", [ax0_1, ax1_1, ax2_1])
-                                            
T.reads(C_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i], 
A_reindex_shared_wmma_matrix_a[v0_o * 16 + v0_i, v2_o * 16 + v2_i], 
B_reindex_shared_wmma_matrix_b[v2_o * 16 + v2_i, v1_o * 16 + v1_i])
-                                            
T.writes(C_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i])
-                                            
T.block_attr({"meta_schedule.tiling_structure":"SSSRRSRS"})
-                                            
C_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i] = 
C_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i] + 
T.cast(A_reindex_shared_wmma_matrix_a[v0_o * 16 + v0_i, v2_o * 16 + v2_i], 
"float32") * T.cast(B_reindex_shared_wmma_matrix_b[v2_o * 16 + v2_i, v1_o * 16 
+ v1_i], "float32")
-                    for ax0_0, ax1_0 in T.grid(2, 1):
-                        with T.block("C_reindex_shared_wmma.accumulator_o"):
-                            v0_o = T.axis.spatial(8, ax0_0_0_ax1_0_0_fused // 
2 * 2 + ax0_0)
-                            v1_o = T.axis.spatial(8, ax0_0_0_ax1_0_0_fused % 2 
* 4 + ax0_0_1_ax1_0_1_fused * 2 + ax0_0_2_ax1_0_2_fused + ax1_0)
-                            T.reads(C_reindex_shared_wmma_accumulator[v0_o * 
16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16])
-                            T.writes(C_reindex_shared[v0_o * 16 : v0_o * 16 + 
16, v1_o * 16 : v1_o * 16 + 16])
-                            T.block_attr({"meta_schedule.auto_tensorize": 
f"wmma_store_16x16x16_f32_{intrin_suffix}"})
-                            for ax0_1, ax1_1 in T.grid(16, 16):
-                                with 
T.block("C_reindex_shared_wmma.accumulator"):
-                                    v0_i, v1_i = T.axis.remap("SS", [ax0_1, 
ax1_1])
-                                    
T.reads(C_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i])
-                                    T.writes(C_reindex_shared[v0_o * 16 + 
v0_i, v1_o * 16 + v1_i])
-                                    C_reindex_shared[v0_o * 16 + v0_i, v1_o * 
16 + v1_i] = C_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + 
v1_i]
-                for ax0_ax1_fused in T.serial(1024):
-                    with T.block("C_reindex_shared"):
-                        v0 = T.axis.spatial(128, ax0_0_0_ax1_0_0_fused // 2 * 
32 + ax0_ax1_fused // 32)
-                        v1 = T.axis.spatial(128, ax0_0_0_ax1_0_0_fused % 2 * 
64 + ax0_0_1_ax1_0_1_fused * 32 + ax0_ax1_fused % 32)
-                        T.reads(C_reindex_shared[v0, v1])
-                        T.writes(compute[v0, v1])
-                        T.block_attr({"meta_schedule.cooperative_fetch":4})
-                        compute[v0, v1] = T.max(C_reindex_shared[v0, v1], 
T.float32(0))
+                                            
T.reads(C_reindex_shared_wmma_accumulator[v0_o // 2, v1_o, v0_o % 2, 0, v0_i, 
v1_i], A_reindex_shared_wmma_matrix_a[v0_o * 16 + v0_i, v2_o * 16 + v2_i], 
B_reindex_shared_wmma_matrix_b[v2_o * 16 + v2_i, v1_o * 16 + v1_i])
+                                            
T.writes(C_reindex_shared_wmma_accumulator[v0_o // 2, v1_o, v0_o % 2, 0, v0_i, 
v1_i])
+                                            
T.block_attr({"meta_schedule.tiling_structure": "SSSRRSRS"})
+                                            
C_reindex_shared_wmma_accumulator[v0_o // 2, v1_o, v0_o % 2, 0, v0_i, v1_i] = 
C_reindex_shared_wmma_accumulator[v0_o // 2, v1_o, v0_o % 2, 0, v0_i, v1_i] + 
T.Cast("float32", A_reindex_shared_wmma_matrix_a[v0_o * 16 + v0_i, v2_o * 16 + 
v2_i]) * T.Cast("float32", B_reindex_shared_wmma_matrix_b[v2_o * 16 + v2_i, 
v1_o * 16 + v1_i])
+                for ax2 in range(2):
+                    for ax0_ax1_fused in T.thread_binding(2, 
thread="threadIdx.y"):
+                        for ax2_1, ax3 in T.grid(1, 1):
+                            with 
T.block("C_reindex_shared_wmma.accumulator_o"):
+                                v0 = T.axis.spatial(4, ax0_0_0_ax1_0_0_fused 
// 2)
+                                v1 = T.axis.spatial(8, ax0_0_0_ax1_0_0_fused % 
2 * 4 + ax0_0_1_ax1_0_1_fused * 2 + ax0_ax1_fused)
+                                v2 = T.axis.spatial(2, ax2 + ax2_1)
+                                v3 = T.axis.spatial(1, ax3)
+                                v4_o = T.axis.spatial(1, 0)
+                                v5_o = T.axis.spatial(1, 0)
+                                T.reads(C_reindex_shared_wmma_accumulator[v0, 
v1, v2, v3, 0:16, 0:16])
+                                T.writes(C_reindex_shared[v0, v1, v2, v3, 
0:16, 0:16])
+                                T.block_attr({"meta_schedule.auto_tensorize": 
f"wmma_store_16x16x16_f32_{intrin_suffix}"})
+                                for ax4, ax5 in T.grid(16, 16):
+                                    with 
T.block("C_reindex_shared_wmma.accumulator"):
+                                        v4_i, v5_i = T.axis.remap("SS", [ax4, 
ax5])
+                                        
T.reads(C_reindex_shared_wmma_accumulator[v0, v1, v2, v3, v4_i, v5_i])
+                                        T.writes(C_reindex_shared[v0, v1, v2, 
v3, v4_i, v5_i])
+                                        C_reindex_shared[v0, v1, v2, v3, v4_i, 
v5_i] = C_reindex_shared_wmma_accumulator[v0, v1, v2, v3, v4_i, v5_i]
+                    for ax0_ax1_ax3_ax4_ax5_fused in range(512):
+                        with T.block("C_reindex_shared"):
+                            v0 = T.axis.spatial(4, ax0_0_0_ax1_0_0_fused // 2)
+                            v1 = T.axis.spatial(8, ax0_0_0_ax1_0_0_fused % 2 * 
4 + ax0_0_1_ax1_0_1_fused * 2 + ax0_ax1_ax3_ax4_ax5_fused // 256)
+                            v2 = T.axis.spatial(2, ax2)
+                            v3 = T.axis.spatial(1, 0)
+                            v4 = T.axis.spatial(16, ax0_ax1_ax3_ax4_ax5_fused 
% 256 // 16)
+                            v5 = T.axis.spatial(16, ax0_ax1_ax3_ax4_ax5_fused 
% 16)
+                            T.reads(C_reindex_shared[v0, v1, v2, v3, v4, v5])
+                            T.writes(compute[v4 + v2 * 16 + v0 * 32, v5 + v1 * 
16])
+                            T.block_attr({"meta_schedule.cooperative_fetch": 
4})
+                            compute[v4 + v2 * 16 + v0 * 32, v5 + v1 * 16] = 
T.max(C_reindex_shared[v0, v1, v2, v3, v4, v5], T.float32(0))
 
     # fmt: on
     decision_0 = [
@@ -223,44 +233,42 @@ def test_matmul_relu_with_fallback():
     # fmt: off
     @T.prim_func
     def matmul_relu_fallback_0(A: T.Buffer((128, 128), "float16"), B: 
T.Buffer((128, 128), "float16"), compute: T.Buffer((128, 128), "float32")) -> 
None:
-        # function attr dict
         T.func_attr({"global_symbol": "main", "tir.noalias": True})
-        # body
-        # with T.block("root")
-        C_reindex_shared = T.alloc_buffer([128, 128], dtype="float32", 
scope="shared")
-        C_reindex_shared_wmma_accumulator = T.alloc_buffer([128, 128], 
dtype="float32", scope="wmma.accumulator")
-        A_reindex_shared = T.alloc_buffer([128, 128], dtype="float16", 
scope="shared")
-        B_reindex_shared = T.alloc_buffer([128, 128], dtype="float16", 
scope="shared")
-        A_reindex_shared_wmma_matrix_a = T.alloc_buffer([128, 128], 
dtype="float16", scope="wmma.matrix_a")
-        B_reindex_shared_wmma_matrix_b = T.alloc_buffer([128, 128], 
dtype="float16", scope="wmma.matrix_b")
+        # with T.block("root"):
+        C_reindex_shared = T.alloc_buffer((4, 2, 2, 4, 16, 16), scope="shared")
+        C_reindex_shared_wmma_accumulator = T.alloc_buffer((4, 2, 2, 4, 16, 
16), scope="wmma.accumulator")
+        A_reindex_shared = T.alloc_buffer((128, 128), "float16", 
scope="shared")
+        B_reindex_shared = T.alloc_buffer((128, 128), "float16", 
scope="shared")
+        A_reindex_shared_wmma_matrix_a = T.alloc_buffer((128, 128), "float16", 
scope="wmma.matrix_a")
+        B_reindex_shared_wmma_matrix_b = T.alloc_buffer((128, 128), "float16", 
scope="wmma.matrix_b")
         for ax0_0_0_ax1_0_0_fused in T.thread_binding(2, thread="blockIdx.y"):
             for ax0_0_1_ax1_0_1_fused in T.thread_binding(2, 
thread="blockIdx.x"):
                 for ax0_0_2_ax1_0_2_fused in T.thread_binding(2, 
thread="threadIdx.y"):
-                    for ax2_0_0 in T.serial(2):
-                        for ax0_ax1_fused in T.serial(2048):
+                    for ax2_0_0 in range(2):
+                        for ax0_ax1_fused in range(2048):
                             with T.block("A_reindex_shared"):
                                 v0 = T.axis.spatial(128, ax0_0_0_ax1_0_0_fused 
* 64 + ax0_0_1_ax1_0_1_fused * 32 + ax0_ax1_fused // 64)
                                 v1 = T.axis.spatial(128, ax2_0_0 * 64 + 
ax0_ax1_fused % 64)
                                 T.reads(A[v0, v1])
                                 T.writes(A_reindex_shared[v0, v1])
-                                T.block_attr({"buffer_dim_align":[[0, 0, 32, 
8]], "meta_schedule.cooperative_fetch":4})
+                                T.block_attr({"buffer_dim_align": [[0, 0, 32, 
8]], "meta_schedule.cooperative_fetch": 4})
                                 A_reindex_shared[v0, v1] = A[v0, v1]
-                        for ax0_ax1_fused in T.serial(8192):
+                        for ax0_ax1_fused in range(8192):
                             with T.block("B_reindex_shared"):
                                 v0 = T.axis.spatial(128, ax2_0_0 * 64 + 
ax0_ax1_fused // 128)
                                 v1 = T.axis.spatial(128, ax0_ax1_fused % 128)
                                 T.reads(B[v0, v1])
                                 T.writes(B_reindex_shared[v0, v1])
-                                T.block_attr({"buffer_dim_align":[[0, 0, 32, 
8]], "meta_schedule.cooperative_fetch":2})
+                                T.block_attr({"buffer_dim_align": [[0, 0, 32, 
8]], "meta_schedule.cooperative_fetch": 2})
                                 B_reindex_shared[v0, v1] = B[v0, v1]
-                        for ax2_0_1 in T.serial(1):
+                        for ax2_0_1 in range(1):
                             for ax0_0, ax1_0 in T.grid(2, 4):
                                 with 
T.block("A_reindex_shared_wmma.matrix_a_o"):
                                     v0_o = T.axis.spatial(8, 
ax0_0_0_ax1_0_0_fused * 4 + ax0_0_1_ax1_0_1_fused * 2 + ax0_0)
                                     v1_o = T.axis.spatial(8, ax2_0_0 * 4 + 
ax1_0)
-                                    T.reads(A_reindex_shared[v0_o * 16 : v0_o 
* 16 + 16, v1_o * 16 : v1_o * 16 + 16])
-                                    
T.writes(A_reindex_shared_wmma_matrix_a[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : 
v1_o * 16 + 16])
-                                    
T.block_attr({"meta_schedule.auto_tensorize":"wmma_load_16x16x16_f16_a_shared"})
+                                    T.reads(A_reindex_shared[v0_o * 16:v0_o * 
16 + 16, v1_o * 16:v1_o * 16 + 16])
+                                    
T.writes(A_reindex_shared_wmma_matrix_a[v0_o * 16:v0_o * 16 + 16, v1_o * 
16:v1_o * 16 + 16])
+                                    
T.block_attr({"meta_schedule.auto_tensorize": 
"wmma_load_16x16x16_f16_a_shared"})
                                     for ax0_1, ax1_1 in T.grid(16, 16):
                                         with 
T.block("A_reindex_shared_wmma.matrix_a"):
                                             v0_i, v1_i = T.axis.remap("SS", 
[ax0_1, ax1_1])
@@ -271,9 +279,9 @@ def test_matmul_relu_with_fallback():
                                 with 
T.block("B_reindex_shared_wmma.matrix_b_o"):
                                     v0_o = T.axis.spatial(8, ax2_0_0 * 4 + 
ax0_0)
                                     v1_o = T.axis.spatial(8, 
ax0_0_2_ax1_0_2_fused * 4 + ax1_0)
-                                    T.reads(B_reindex_shared[v0_o * 16 : v0_o 
* 16 + 16, v1_o * 16 : v1_o * 16 + 16])
-                                    
T.writes(B_reindex_shared_wmma_matrix_b[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : 
v1_o * 16 + 16])
-                                    
T.block_attr({"meta_schedule.auto_tensorize":"wmma_load_16x16x16_f16_b_shared"})
+                                    T.reads(B_reindex_shared[v0_o * 16:v0_o * 
16 + 16, v1_o * 16:v1_o * 16 + 16])
+                                    
T.writes(B_reindex_shared_wmma_matrix_b[v0_o * 16:v0_o * 16 + 16, v1_o * 
16:v1_o * 16 + 16])
+                                    
T.block_attr({"meta_schedule.auto_tensorize": 
"wmma_load_16x16x16_f16_b_shared"})
                                     for ax0_1, ax1_1 in T.grid(16, 16):
                                         with 
T.block("B_reindex_shared_wmma.matrix_b"):
                                             v0_i, v1_i = T.axis.remap("SS", 
[ax0_1, ax1_1])
@@ -285,44 +293,54 @@ def test_matmul_relu_with_fallback():
                                     v0_o = T.axis.spatial(8, 
ax0_0_0_ax1_0_0_fused * 4 + ax0_0_1_ax1_0_1_fused * 2 + ax0_0_3 * 2 + ax0_0_4)
                                     v1_o = T.axis.spatial(8, 
ax0_0_2_ax1_0_2_fused * 4 + ax1_0_3 * 4 + ax1_0_4)
                                     v2_o = T.axis.reduce(8, ax2_0_0 * 4 + 
ax2_0_1 * 4 + ax2_0_2)
-                                    
T.reads(A_reindex_shared_wmma_matrix_a[v0_o * 16 : v0_o * 16 + 16, v2_o * 16 : 
v2_o * 16 + 16], B_reindex_shared_wmma_matrix_b[v2_o * 16 : v2_o * 16 + 16, 
v1_o * 16 : v1_o * 16 + 16])
-                                    
T.writes(C_reindex_shared_wmma_accumulator[v0_o * 16 : v0_o * 16 + 16, v1_o * 
16 : v1_o * 16 + 16])
-                                    
T.block_attr({"meta_schedule.auto_tensorize":"wmma_sync_16x16x16_f16f16f32", 
"meta_schedule.auto_tensorize_init":"wmma_fill_16x16x16_f32", 
"warp_execution":1})
+                                    
T.reads(A_reindex_shared_wmma_matrix_a[v0_o * 16:v0_o * 16 + 16, v2_o * 16:v2_o 
* 16 + 16], B_reindex_shared_wmma_matrix_b[v2_o * 16:v2_o * 16 + 16, v1_o * 
16:v1_o * 16 + 16])
+                                    
T.writes(C_reindex_shared_wmma_accumulator[v0_o // 2, v1_o // 4, v0_o % 2, v1_o 
% 4, 0:16, 0:16])
+                                    
T.block_attr({"meta_schedule.auto_tensorize": "wmma_sync_16x16x16_f16f16f32", 
"meta_schedule.auto_tensorize_init": "wmma_fill_16x16x16_f32", 
"warp_execution": 1})
                                     with T.init():
                                         for ax0_1, ax1_1 in T.grid(16, 16):
                                             with T.block("C_init"):
                                                 v0_i_init, v1_i_init = 
T.axis.remap("SS", [ax0_1, ax1_1])
                                                 T.reads()
-                                                
T.writes(C_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i_init, v1_o * 16 + 
v1_i_init])
-                                                
C_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i_init, v1_o * 16 + v1_i_init] 
= T.float32(0)
+                                                
T.writes(C_reindex_shared_wmma_accumulator[v0_o // 2, v1_o // 4, v0_o % 2, v1_o 
% 4, v0_i_init, v1_i_init])
+                                                
C_reindex_shared_wmma_accumulator[v0_o // 2, v1_o // 4, v0_o % 2, v1_o % 4, 
v0_i_init, v1_i_init] = T.float32(0)
                                     for ax0_1, ax1_1, ax2_1 in T.grid(16, 16, 
16):
                                         with T.block("C"):
                                             v0_i, v1_i, v2_i = 
T.axis.remap("SSR", [ax0_1, ax1_1, ax2_1])
-                                            
T.reads(C_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i], 
A_reindex_shared_wmma_matrix_a[v0_o * 16 + v0_i, v2_o * 16 + v2_i], 
B_reindex_shared_wmma_matrix_b[v2_o * 16 + v2_i, v1_o * 16 + v1_i])
-                                            
T.writes(C_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i])
-                                            
T.block_attr({"meta_schedule.tiling_structure":"SSSRRSRS"})
-                                            
C_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i] = 
C_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i] + 
T.cast(A_reindex_shared_wmma_matrix_a[v0_o * 16 + v0_i, v2_o * 16 + v2_i], 
"float32") * T.cast(B_reindex_shared_wmma_matrix_b[v2_o * 16 + v2_i, v1_o * 16 
+ v1_i], "float32")
-                    for ax0_0, ax1_0 in T.grid(2, 4):
-                        with T.block("C_reindex_shared_wmma.accumulator_o"):
-                            v0_o = T.axis.spatial(8, ax0_0_0_ax1_0_0_fused * 4 
+ ax0_0_1_ax1_0_1_fused * 2 + ax0_0)
-                            v1_o = T.axis.spatial(8, ax0_0_2_ax1_0_2_fused * 4 
+ ax1_0)
-                            T.reads(C_reindex_shared_wmma_accumulator[v0_o * 
16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16])
-                            T.writes(C_reindex_shared[v0_o * 16 : v0_o * 16 + 
16, v1_o * 16 : v1_o * 16 + 16])
-                            
T.block_attr({"meta_schedule.auto_tensorize":"wmma_store_16x16x16_f32_shared"})
-                            for ax0_1, ax1_1 in T.grid(16, 16):
-                                with 
T.block("C_reindex_shared_wmma.accumulator"):
-                                    v0_i, v1_i = T.axis.remap("SS", [ax0_1, 
ax1_1])
-                                    
T.reads(C_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i])
-                                    T.writes(C_reindex_shared[v0_o * 16 + 
v0_i, v1_o * 16 + v1_i])
-                                    C_reindex_shared[v0_o * 16 + v0_i, v1_o * 
16 + v1_i] = C_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + 
v1_i]
-                for ax0_ax1_fused in T.serial(4096):
-                    with T.block("C_reindex_shared"):
-                        v0 = T.axis.spatial(128, ax0_0_0_ax1_0_0_fused * 64 + 
ax0_0_1_ax1_0_1_fused * 32 + ax0_ax1_fused // 128)
-                        v1 = T.axis.spatial(128, ax0_ax1_fused % 128)
-                        T.reads(C_reindex_shared[v0, v1])
-                        T.writes(compute[v0, v1])
-                        T.block_attr({"meta_schedule.cooperative_fetch":4})
-                        compute[v0, v1] = T.max(C_reindex_shared[v0, v1], 
T.float32(0))
+                                            
T.reads(C_reindex_shared_wmma_accumulator[v0_o // 2, v1_o // 4, v0_o % 2, v1_o 
% 4, v0_i, v1_i], A_reindex_shared_wmma_matrix_a[v0_o * 16 + v0_i, v2_o * 16 + 
v2_i], B_reindex_shared_wmma_matrix_b[v2_o * 16 + v2_i, v1_o * 16 + v1_i])
+                                            
T.writes(C_reindex_shared_wmma_accumulator[v0_o // 2, v1_o // 4, v0_o % 2, v1_o 
% 4, v0_i, v1_i])
+                                            
T.block_attr({"meta_schedule.tiling_structure": "SSSRRSRS"})
+                                            
C_reindex_shared_wmma_accumulator[v0_o // 2, v1_o // 4, v0_o % 2, v1_o % 4, 
v0_i, v1_i] = C_reindex_shared_wmma_accumulator[v0_o // 2, v1_o // 4, v0_o % 2, 
v1_o % 4, v0_i, v1_i] + T.Cast("float32", A_reindex_shared_wmma_matrix_a[v0_o * 
16 + v0_i, v2_o * 16 + v2_i]) * T.Cast("float32", 
B_reindex_shared_wmma_matrix_b[v2_o * 16 + v2_i, v1_o * 16 + v1_i])
+                for ax2 in range(2):
+                    for ax0_ax1_fused in T.thread_binding(2, 
thread="threadIdx.y"):
+                        for ax2_1, ax3 in T.grid(1, 4):
+                            with 
T.block("C_reindex_shared_wmma.accumulator_o"):
+                                v0 = T.axis.spatial(4, ax0_0_0_ax1_0_0_fused * 
2 + ax0_0_1_ax1_0_1_fused)
+                                v1 = T.axis.spatial(2, ax0_ax1_fused)
+                                v2 = T.axis.spatial(2, ax2 + ax2_1)
+                                v3 = T.axis.spatial(4, ax3)
+                                v4_o = T.axis.spatial(1, 0)
+                                v5_o = T.axis.spatial(1, 0)
+                                T.reads(C_reindex_shared_wmma_accumulator[v0, 
v1, v2, v3, 0:16, 0:16])
+                                T.writes(C_reindex_shared[v0, v1, v2, v3, 
0:16, 0:16])
+                                T.block_attr({"meta_schedule.auto_tensorize": 
"wmma_store_16x16x16_f32_shared"})
+                                for ax4, ax5 in T.grid(16, 16):
+                                    with 
T.block("C_reindex_shared_wmma.accumulator"):
+                                        v4_i, v5_i = T.axis.remap("SS", [ax4, 
ax5])
+                                        
T.reads(C_reindex_shared_wmma_accumulator[v0, v1, v2, v3, v4_i, v5_i])
+                                        T.writes(C_reindex_shared[v0, v1, v2, 
v3, v4_i, v5_i])
+                                        C_reindex_shared[v0, v1, v2, v3, v4_i, 
v5_i] = C_reindex_shared_wmma_accumulator[v0, v1, v2, v3, v4_i, v5_i]
+                    for ax0_ax1_ax3_ax4_ax5_fused in range(2048):
+                        with T.block("C_reindex_shared"):
+                            v0 = T.axis.spatial(4, ax0_0_0_ax1_0_0_fused * 2 + 
ax0_0_1_ax1_0_1_fused)
+                            v1 = T.axis.spatial(2, ax0_ax1_ax3_ax4_ax5_fused 
// 1024)
+                            v2 = T.axis.spatial(2, ax2)
+                            v3 = T.axis.spatial(4, ax0_ax1_ax3_ax4_ax5_fused % 
1024 // 256)
+                            v4 = T.axis.spatial(16, ax0_ax1_ax3_ax4_ax5_fused 
% 256 // 16)
+                            v5 = T.axis.spatial(16, ax0_ax1_ax3_ax4_ax5_fused 
% 16)
+                            T.reads(C_reindex_shared[v0, v1, v2, v3, v4, v5])
+                            T.writes(compute[v4 + v2 * 16 + v0 * 32, v5 + v3 * 
16 + v1 * 64])
+                            T.block_attr({"meta_schedule.cooperative_fetch": 
4})
+                            compute[v4 + v2 * 16 + v0 * 32, v5 + v3 * 16 + v1 
* 64] = T.max(C_reindex_shared[v0, v1, v2, v3, v4, v5], T.float32(0))
     # fmt: on
     decision_0 = [
         ("SamplePerfectTile", [2, 2, 1, 1, 2]),
@@ -373,46 +391,46 @@ def test_conv2d(shared_scope):
     @T.prim_func
     def conv2d_0(inputs: T.Buffer((1, 16, 16, 32), "float16"), weight: 
T.Buffer((3, 3, 32, 32), "float16"), conv2d_nhwc: T.Buffer((1, 16, 16, 32), 
"float32")) -> None:
         T.func_attr({"global_symbol": "main", "tir.noalias": True})
-        PadInput = T.alloc_buffer([1, 18, 18, 32], dtype="float16")
-        conv2d_nhwc_reindex_shared = T.alloc_buffer([256, 32], 
dtype="float32", scope=shared_scope)
-        conv2d_nhwc_reindex_shared_wmma_accumulator = T.alloc_buffer([256, 
32], dtype="float32", scope="wmma.accumulator")
-        PadInput_reindex_shared = T.alloc_buffer([256, 288], dtype="float16", 
scope=shared_scope)
-        weight_reindex_shared = T.alloc_buffer([288, 32], dtype="float16", 
scope=shared_scope)
-        PadInput_reindex_shared_wmma_matrix_a = T.alloc_buffer([256, 288], 
dtype="float16", scope="wmma.matrix_a")
-        weight_reindex_shared_wmma_matrix_b = T.alloc_buffer([288, 32], 
dtype="float16", scope="wmma.matrix_b")
+        PadInput = T.alloc_buffer((1, 18, 18, 32), "float16")
+        conv2d_nhwc_reindex_shared = T.alloc_buffer((16, 2, 1, 1, 16, 16), 
scope=shared_scope)
+        conv2d_nhwc_reindex_shared_wmma_accumulator = T.alloc_buffer((16, 2, 
1, 1, 16, 16), scope="wmma.accumulator")
+        PadInput_reindex_shared = T.alloc_buffer((256, 288), "float16", 
scope=shared_scope)
+        weight_reindex_shared = T.alloc_buffer((288, 32), "float16", 
scope=shared_scope)
+        PadInput_reindex_shared_wmma_matrix_a = T.alloc_buffer((256, 288), 
"float16", scope="wmma.matrix_a")
+        weight_reindex_shared_wmma_matrix_b = T.alloc_buffer((288, 32), 
"float16", scope="wmma.matrix_b")
         for i0, i1, i2, i3 in T.grid(1, 18, 18, 32):
             with T.block("PadInput"):
-                i0_1, i1_1, i2_1, i3_1 = T.axis.remap("SSSS", [i0, i1, i2, i3])
-                T.reads(inputs[i0_1, i1_1 - 1, i2_1 - 1, i3_1])
-                T.writes(PadInput[i0_1, i1_1, i2_1, i3_1])
-                PadInput[i0_1, i1_1, i2_1, i3_1] = T.if_then_else(1 <= i1_1 
and i1_1 < 17 and 1 <= i2_1 and i2_1 < 17, inputs[i0_1, i1_1 - 1, i2_1 - 1, 
i3_1], T.float16(0), dtype="float16")
+                v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3])
+                T.reads(inputs[v_i0, v_i1 - 1, v_i2 - 1, v_i3])
+                T.writes(PadInput[v_i0, v_i1, v_i2, v_i3])
+                PadInput[v_i0, v_i1, v_i2, v_i3] = T.if_then_else(1 <= v_i1 
and v_i1 < 17 and 1 <= v_i2 and v_i2 < 17, inputs[v_i0, v_i1 - 1, v_i2 - 1, 
v_i3], T.float16(0))
         for ax0_0_0_ax1_0_0_fused in T.thread_binding(2, thread="blockIdx.y"):
             for ax0_0_1_ax1_0_1_fused in T.thread_binding(16, 
thread="blockIdx.x"):
                 for ax0_0_2_ax1_0_2_fused in T.thread_binding(1, 
thread="threadIdx.y"):
-                    for ax2_0_0 in T.serial(1):
-                        for ax0_ax1_fused in T.serial(4608):
+                    for ax2_0_0 in range(1):
+                        for ax0_ax1_fused in range(4608):
                             with T.block("PadInput_reindex_shared"):
                                 v0 = T.axis.spatial(256, ax0_0_1_ax1_0_1_fused 
* 16 + ax0_ax1_fused // 288)
                                 v1 = T.axis.spatial(288, ax0_ax1_fused % 288)
                                 T.reads(PadInput[v0 // 256, v1 // 96 + v0 // 
16, v1 % 96 // 32 + v0 % 16, v1 % 32])
                                 T.writes(PadInput_reindex_shared[v0, v1])
-                                T.block_attr({"buffer_dim_align":[[0, 0, 32, 
8]], "meta_schedule.cooperative_fetch":2})
+                                T.block_attr({"buffer_dim_align": [[0, 0, 32, 
8]], "meta_schedule.cooperative_fetch": 2})
                                 PadInput_reindex_shared[v0, v1] = PadInput[v0 
// 256, v1 // 96 + v0 // 16, v1 % 96 // 32 + v0 % 16, v1 % 32]
-                        for ax0_ax1_fused in T.serial(4608):
+                        for ax0_ax1_fused in range(4608):
                             with T.block("weight_reindex_shared"):
                                 v0 = T.axis.spatial(288, ax0_ax1_fused // 16)
                                 v1 = T.axis.spatial(32, ax0_0_0_ax1_0_0_fused 
* 16 + ax0_ax1_fused % 16)
                                 T.reads(weight[v0 // 96, v0 % 96 // 32, v0 % 
32, v1])
                                 T.writes(weight_reindex_shared[v0, v1])
-                                T.block_attr({"buffer_dim_align":[[0, 0, 32, 
8]], "meta_schedule.cooperative_fetch":8})
+                                T.block_attr({"buffer_dim_align": [[0, 0, 32, 
8]], "meta_schedule.cooperative_fetch": 8})
                                 weight_reindex_shared[v0, v1] = weight[v0 // 
96, v0 % 96 // 32, v0 % 32, v1]
-                        for ax2_0_1 in T.serial(18):
+                        for ax2_0_1 in range(18):
                             for ax0_0, ax1_0 in T.grid(1, 1):
                                 with 
T.block("PadInput_reindex_shared_wmma.matrix_a_o"):
                                     v0_o = T.axis.spatial(16, 
ax0_0_1_ax1_0_1_fused + ax0_0)
                                     v1_o = T.axis.spatial(18, ax2_0_1 + ax1_0)
-                                    T.reads(PadInput_reindex_shared[v0_o * 16 
: v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16])
-                                    
T.writes(PadInput_reindex_shared_wmma_matrix_a[v0_o * 16 : v0_o * 16 + 16, v1_o 
* 16 : v1_o * 16 + 16])
+                                    T.reads(PadInput_reindex_shared[v0_o * 
16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16])
+                                    
T.writes(PadInput_reindex_shared_wmma_matrix_a[v0_o * 16:v0_o * 16 + 16, v1_o * 
16:v1_o * 16 + 16])
                                     
T.block_attr({"meta_schedule.auto_tensorize": 
f"wmma_load_16x16x16_f16_a_{intrin_suffix}"})
                                     for ax0_1, ax1_1 in T.grid(16, 16):
                                         with 
T.block("PadInput_reindex_shared_wmma.matrix_a"):
@@ -424,8 +442,8 @@ def test_conv2d(shared_scope):
                                 with 
T.block("weight_reindex_shared_wmma.matrix_b_o"):
                                     v0_o = T.axis.spatial(18, ax2_0_1 + ax0_0)
                                     v1_o = T.axis.spatial(2, 
ax0_0_0_ax1_0_0_fused + ax1_0)
-                                    T.reads(weight_reindex_shared[v0_o * 16 : 
v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16])
-                                    
T.writes(weight_reindex_shared_wmma_matrix_b[v0_o * 16 : v0_o * 16 + 16, v1_o * 
16 : v1_o * 16 + 16])
+                                    T.reads(weight_reindex_shared[v0_o * 
16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16])
+                                    
T.writes(weight_reindex_shared_wmma_matrix_b[v0_o * 16:v0_o * 16 + 16, v1_o * 
16:v1_o * 16 + 16])
                                     
T.block_attr({"meta_schedule.auto_tensorize": 
f"wmma_load_16x16x16_f16_b_{intrin_suffix}"})
                                     for ax0_1, ax1_1 in T.grid(16, 16):
                                         with 
T.block("weight_reindex_shared_wmma.matrix_b"):
@@ -438,44 +456,49 @@ def test_conv2d(shared_scope):
                                     v0_o = T.axis.spatial(16, ax0_0_4 + 
ax0_0_1_ax1_0_1_fused + ax0_0_3)
                                     v1_o = T.axis.spatial(2, 
ax0_0_0_ax1_0_0_fused + ax1_0_3 + ax1_0_4)
                                     v2_o = T.axis.reduce(18, ax2_0_0 * 18 + 
ax2_0_1 + ax2_0_2)
-                                    
T.reads(PadInput_reindex_shared_wmma_matrix_a[v0_o * 16 : v0_o * 16 + 16, v2_o 
* 16 : v2_o * 16 + 16], weight_reindex_shared_wmma_matrix_b[v2_o * 16 : v2_o * 
16 + 16, v1_o * 16 : v1_o * 16 + 16])
-                                    
T.writes(conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o * 16 : v0_o * 16 + 
16, v1_o * 16 : v1_o * 16 + 16])
-                                    
T.block_attr({"meta_schedule.auto_tensorize":"wmma_sync_16x16x16_f16f16f32", 
"meta_schedule.auto_tensorize_init":"wmma_fill_16x16x16_f32", 
"warp_execution":1})
+                                    
T.reads(PadInput_reindex_shared_wmma_matrix_a[v0_o * 16:v0_o * 16 + 16, v2_o * 
16:v2_o * 16 + 16], weight_reindex_shared_wmma_matrix_b[v2_o * 16:v2_o * 16 + 
16, v1_o * 16:v1_o * 16 + 16])
+                                    
T.writes(conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o, v1_o, 0, 0, 0:16, 
0:16])
+                                    
T.block_attr({"meta_schedule.auto_tensorize": "wmma_sync_16x16x16_f16f16f32", 
"meta_schedule.auto_tensorize_init": "wmma_fill_16x16x16_f32", 
"warp_execution": 1})
                                     with T.init():
                                         for ax0_1, ax1_1 in T.grid(16, 16):
                                             with T.block("conv2d_nhwc_init"):
                                                 v0_i_init, v1_i_init = 
T.axis.remap("SS", [ax0_1, ax1_1])
                                                 T.reads()
-                                                
T.writes(conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i_init, 
v1_o * 16 + v1_i_init])
-                                                
conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i_init, v1_o * 16 + 
v1_i_init] = T.float32(0)
+                                                
T.writes(conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o, v1_o, 0, 0, 
v0_i_init, v1_i_init])
+                                                
conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o, v1_o, 0, 0, v0_i_init, 
v1_i_init] = T.float32(0)
                                     for ax0_1, ax1_1, ax2_1 in T.grid(16, 16, 
16):
                                         with T.block("conv2d_nhwc"):
                                             v0_i, v1_i, v2_i = 
T.axis.remap("SSR", [ax0_1, ax1_1, ax2_1])
-                                            
T.reads(conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 
+ v1_i], PadInput_reindex_shared_wmma_matrix_a[v0_o * 16 + v0_i, v2_o * 16 + 
v2_i], weight_reindex_shared_wmma_matrix_b[v2_o * 16 + v2_i, v1_o * 16 + v1_i])
-                                            
T.writes(conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 
16 + v1_i])
-                                            
T.block_attr({"meta_schedule.tiling_structure":"SSSRRSRS"})
-                                            
conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i] 
= conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + 
v1_i] + T.cast(PadInput_reindex_shared_wmma_matrix_a[v0_o * 16 + v0_i, v2_o * 
16 + v2_i], "float32") * T.cast(weight_reindex_shared_wmma_matrix_b[v2_o * 16 + 
v2_i, v1_o * 16 + v1_i], "float32")
-                    for ax0_0, ax1_0 in T.grid(1, 1):
-                        with 
T.block("conv2d_nhwc_reindex_shared_wmma.accumulator_o"):
-                            v0_o = T.axis.spatial(16, ax0_0_1_ax1_0_1_fused + 
ax0_0)
-                            v1_o = T.axis.spatial(2, ax0_0_0_ax1_0_0_fused + 
ax1_0)
-                            
T.reads(conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o * 16 : v0_o * 16 + 16, 
v1_o * 16 : v1_o * 16 + 16])
-                            T.writes(conv2d_nhwc_reindex_shared[v0_o * 16 : 
v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16])
-                            T.block_attr({"meta_schedule.auto_tensorize": 
f"wmma_store_16x16x16_f32_{intrin_suffix}"})
-                            for ax0_1, ax1_1 in T.grid(16, 16):
-                                with 
T.block("conv2d_nhwc_reindex_shared_wmma.accumulator"):
-                                    v0_i, v1_i = T.axis.remap("SS", [ax0_1, 
ax1_1])
-                                    
T.reads(conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 
+ v1_i])
-                                    T.writes(conv2d_nhwc_reindex_shared[v0_o * 
16 + v0_i, v1_o * 16 + v1_i])
-                                    conv2d_nhwc_reindex_shared[v0_o * 16 + 
v0_i, v1_o * 16 + v1_i] = conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o * 16 
+ v0_i, v1_o * 16 + v1_i]
-                for ax0_ax1_fused in T.serial(256):
-                    with T.block("conv2d_nhwc_reindex_shared"):
-                        v0 = T.axis.spatial(256, ax0_0_1_ax1_0_1_fused * 16 + 
ax0_ax1_fused // 16)
-                        v1 = T.axis.spatial(32, ax0_0_0_ax1_0_0_fused * 16 +  
ax0_ax1_fused % 16)
-                        T.reads(conv2d_nhwc_reindex_shared[v0, v1])
-                        T.writes(conv2d_nhwc[v0 // 256, v0 // 16, v0 % 16, v1])
-                        T.block_attr({"meta_schedule.cooperative_fetch":3})
-                        conv2d_nhwc[v0 // 256, v0 // 16, v0 % 16, v1] = 
conv2d_nhwc_reindex_shared[v0, v1]
+                                            
T.reads(conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o, v1_o, 0, 0, v0_i, 
v1_i], PadInput_reindex_shared_wmma_matrix_a[v0_o * 16 + v0_i, v2_o * 16 + 
v2_i], weight_reindex_shared_wmma_matrix_b[v2_o * 16 + v2_i, v1_o * 16 + v1_i])
+                                            
T.writes(conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o, v1_o, 0, 0, v0_i, 
v1_i])
+                                            
T.block_attr({"meta_schedule.tiling_structure": "SSSRRSRS"})
+                                            
conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o, v1_o, 0, 0, v0_i, v1_i] = 
conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o, v1_o, 0, 0, v0_i, v1_i] + 
T.Cast("float32", PadInput_reindex_shared_wmma_matrix_a[v0_o * 16 + v0_i, v2_o 
* 16 + v2_i]) * T.Cast("float32", weight_reindex_shared_wmma_matrix_b[v2_o * 16 
+ v2_i, v1_o * 16 + v1_i])
+                for ax2 in range(1):
+                    for ax0_ax1_fused in T.thread_binding(1, 
thread="threadIdx.y"):
+                        for ax2_1, ax3 in T.grid(1, 1):
+                            with 
T.block("conv2d_nhwc_reindex_shared_wmma.accumulator_o"):
+                                v0, v1, v2, v3 = T.axis.remap("SSSS", 
[ax0_0_1_ax1_0_1_fused, ax0_0_0_ax1_0_0_fused, ax2_1, ax3])
+                                v4_o = T.axis.spatial(1, 0)
+                                v5_o = T.axis.spatial(1, 0)
+                                
T.reads(conv2d_nhwc_reindex_shared_wmma_accumulator[v0, v1, v2, v3, 0:16, 0:16])
+                                T.writes(conv2d_nhwc_reindex_shared[v0, v1, 
v2, v3, 0:16, 0:16])
+                                T.block_attr({"meta_schedule.auto_tensorize": 
f"wmma_store_16x16x16_f32_{intrin_suffix}"})
+                                for ax4, ax5 in T.grid(16, 16):
+                                    with 
T.block("conv2d_nhwc_reindex_shared_wmma.accumulator"):
+                                        v4_i, v5_i = T.axis.remap("SS", [ax4, 
ax5])
+                                        
T.reads(conv2d_nhwc_reindex_shared_wmma_accumulator[v0, v1, v2, v3, v4_i, v5_i])
+                                        
T.writes(conv2d_nhwc_reindex_shared[v0, v1, v2, v3, v4_i, v5_i])
+                                        conv2d_nhwc_reindex_shared[v0, v1, v2, 
v3, v4_i, v5_i] = conv2d_nhwc_reindex_shared_wmma_accumulator[v0, v1, v2, v3, 
v4_i, v5_i]
+                    for ax0_ax1_ax3_ax4_ax5_fused in range(256):
+                        with T.block("conv2d_nhwc_reindex_shared"):
+                            v0, v1, v2 = T.axis.remap("SSS", 
[ax0_0_1_ax1_0_1_fused, ax0_0_0_ax1_0_0_fused, ax2])
+                            v3 = T.axis.spatial(1, 0)
+                            v4 = T.axis.spatial(16, ax0_ax1_ax3_ax4_ax5_fused 
// 16)
+                            v5 = T.axis.spatial(16, ax0_ax1_ax3_ax4_ax5_fused 
% 16)
+                            T.reads(conv2d_nhwc_reindex_shared[v0, v1, v2, v3, 
v4, v5])
+                            T.writes(conv2d_nhwc[(v4 + v0 * 16) // 256, (v4 + 
v0 * 16) // 16, (v4 + v0 * 16) % 16, v5 + v1 * 16])
+                            T.block_attr({"meta_schedule.cooperative_fetch": 
3})
+                            conv2d_nhwc[(v4 + v0 * 16) // 256, (v4 + v0 * 16) 
// 16, (v4 + v0 * 16) % 16, v5 + v1 * 16] = conv2d_nhwc_reindex_shared[v0, v1, 
v2, v3, v4, v5]
     # fmt: on
     decision_0 = [
         ("SamplePerfectTile", [1, 16, 1, 1, 1]),
@@ -551,40 +574,40 @@ def test_matmul_relu_pipeline(shared_scope):
         T.func_attr({"global_symbol": "main", "tir.noalias": True})
         # body
         # with T.block("root")
-        C = T.alloc_buffer([128, 128], dtype="float32")
-        C_reindex_shared = T.alloc_buffer([128, 128], dtype="float32", 
scope=shared_scope)
-        C_reindex_shared_wmma_accumulator = T.alloc_buffer([128, 128], 
dtype="float32", scope="wmma.accumulator")
-        A_reindex_shared = T.alloc_buffer([128, 128], dtype="float16", 
scope=shared_scope)
-        B_reindex_shared = T.alloc_buffer([128, 128], dtype="float16", 
scope=shared_scope)
-        A_reindex_shared_wmma_matrix_a = T.alloc_buffer([128, 128], 
dtype="float16", scope="wmma.matrix_a")
-        B_reindex_shared_wmma_matrix_b = T.alloc_buffer([128, 128], 
dtype="float16", scope="wmma.matrix_b")
+        C = T.alloc_buffer((128, 128))
+        C_reindex_shared = T.alloc_buffer((4, 4, 2, 2, 16, 16), 
scope=shared_scope)
+        C_reindex_shared_wmma_accumulator = T.alloc_buffer((4, 4, 2, 2, 16, 
16), scope="wmma.accumulator")
+        A_reindex_shared = T.alloc_buffer((128, 128), "float16", 
scope=shared_scope)
+        B_reindex_shared = T.alloc_buffer((128, 128), "float16", 
scope=shared_scope)
+        A_reindex_shared_wmma_matrix_a = T.alloc_buffer((128, 128), "float16", 
scope="wmma.matrix_a")
+        B_reindex_shared_wmma_matrix_b = T.alloc_buffer((128, 128), "float16", 
scope="wmma.matrix_b")
         for ax0_0_0_ax1_0_0_fused in T.thread_binding(1, thread="blockIdx.y"):
             for ax0_0_1_ax1_0_1_fused in T.thread_binding(16, 
thread="blockIdx.x"):
                 for ax0_0_2_ax1_0_2_fused in T.thread_binding(1, 
thread="threadIdx.y"):
-                    for ax2_0_0 in T.serial(4, 
annotations={"software_pipeline_order":[0, 3, 1, 4, 5, 2, 6], 
"software_pipeline_stage":[0, 0, 0, 0, 0, 1, 1]}):
-                        for ax0_ax1_fused in T.serial(1024):
+                    for ax2_0_0 in T.serial(4, 
annotations={"software_pipeline_order": [0, 3, 1, 4, 5, 2, 6], 
"software_pipeline_stage": [0, 0, 0, 0, 0, 1, 1]}):
+                        for ax0_ax1_fused in range(1024):
                             with T.block("A_reindex_shared"):
                                 v0 = T.axis.spatial(128, ax0_0_1_ax1_0_1_fused 
// 4 * 32 + ax0_ax1_fused // 32)
                                 v1 = T.axis.spatial(128, ax2_0_0 * 32 + 
ax0_ax1_fused % 32)
                                 T.reads(A[v0, v1])
                                 T.writes(A_reindex_shared[v0, v1])
-                                T.block_attr({"buffer_dim_align":[[0, 0, 32, 
8]], "double_buffer_scope":0, "meta_schedule.cooperative_fetch":4, 
"tir.manifest_shared_memory_local_stage":1})
+                                T.block_attr({"buffer_dim_align": [[0, 0, 32, 
8]], "double_buffer_scope": 0, "meta_schedule.cooperative_fetch": 4, 
"tir.manifest_shared_memory_local_stage": 1})
                                 A_reindex_shared[v0, v1] = A[v0, v1]
-                        for ax0_ax1_fused in T.serial(1024):
+                        for ax0_ax1_fused in range(1024):
                             with T.block("B_reindex_shared"):
                                 v0 = T.axis.spatial(128, ax2_0_0 * 32 + 
ax0_ax1_fused // 32)
                                 v1 = T.axis.spatial(128, ax0_0_1_ax1_0_1_fused 
% 4 * 32 + ax0_ax1_fused % 32)
                                 T.reads(B[v0, v1])
                                 T.writes(B_reindex_shared[v0, v1])
-                                T.block_attr({"buffer_dim_align":[[0, 0, 32, 
8]], "double_buffer_scope":0, "meta_schedule.cooperative_fetch":2, 
"tir.manifest_shared_memory_local_stage":1})
+                                T.block_attr({"buffer_dim_align": [[0, 0, 32, 
8]], "double_buffer_scope": 0, "meta_schedule.cooperative_fetch": 2, 
"tir.manifest_shared_memory_local_stage": 1})
                                 B_reindex_shared[v0, v1] = B[v0, v1]
-                        for ax2_0_1 in T.serial(2, 
annotations={"software_pipeline_order":[0, 1, 2], "software_pipeline_stage":[0, 
0, 1]}):
+                        for ax2_0_1 in T.serial(2, 
annotations={"software_pipeline_order": [0, 1, 2], "software_pipeline_stage": 
[0, 0, 1]}):
                             for ax0_0, ax1_0 in T.grid(2, 1):
                                 with 
T.block("A_reindex_shared_wmma.matrix_a_o"):
                                     v0_o = T.axis.spatial(8, 
ax0_0_1_ax1_0_1_fused // 4 * 2 + ax0_0)
                                     v1_o = T.axis.spatial(8, ax2_0_0 * 2 + 
ax2_0_1 + ax1_0)
-                                    T.reads(A_reindex_shared[v0_o * 16 : v0_o 
* 16 + 16, v1_o * 16 : v1_o * 16 + 16])
-                                    
T.writes(A_reindex_shared_wmma_matrix_a[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : 
v1_o * 16 + 16])
+                                    T.reads(A_reindex_shared[v0_o * 16:v0_o * 
16 + 16, v1_o * 16:v1_o * 16 + 16])
+                                    
T.writes(A_reindex_shared_wmma_matrix_a[v0_o * 16:v0_o * 16 + 16, v1_o * 
16:v1_o * 16 + 16])
                                     
T.block_attr({"meta_schedule.auto_tensorize": 
f"wmma_load_16x16x16_f16_a_{intrin_suffix}"})
                                     for ax0_1, ax1_1 in T.grid(16, 16):
                                         with 
T.block("A_reindex_shared_wmma.matrix_a"):
@@ -596,8 +619,8 @@ def test_matmul_relu_pipeline(shared_scope):
                                 with 
T.block("B_reindex_shared_wmma.matrix_b_o"):
                                     v0_o = T.axis.spatial(8, ax2_0_0 * 2 + 
ax2_0_1 + ax0_0)
                                     v1_o = T.axis.spatial(8, 
ax0_0_1_ax1_0_1_fused % 4 * 2 + ax1_0)
-                                    T.reads(B_reindex_shared[v0_o * 16 : v0_o 
* 16 + 16, v1_o * 16 : v1_o * 16 + 16])
-                                    
T.writes(B_reindex_shared_wmma_matrix_b[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : 
v1_o * 16 + 16])
+                                    T.reads(B_reindex_shared[v0_o * 16:v0_o * 
16 + 16, v1_o * 16:v1_o * 16 + 16])
+                                    
T.writes(B_reindex_shared_wmma_matrix_b[v0_o * 16:v0_o * 16 + 16, v1_o * 
16:v1_o * 16 + 16])
                                     
T.block_attr({"meta_schedule.auto_tensorize": 
f"wmma_load_16x16x16_f16_b_{intrin_suffix}"})
                                     for ax0_1, ax1_1 in T.grid(16, 16):
                                         with 
T.block("B_reindex_shared_wmma.matrix_b"):
@@ -610,50 +633,61 @@ def test_matmul_relu_pipeline(shared_scope):
                                     v0_o = T.axis.spatial(8, 
ax0_0_1_ax1_0_1_fused // 4 * 2 + ax0_0_3 * 2 + ax0_0_4)
                                     v1_o = T.axis.spatial(8, 
ax0_0_1_ax1_0_1_fused % 4 * 2 + ax1_0_3 * 2 + ax1_0_4)
                                     v2_o = T.axis.reduce(8, ax2_0_0 * 2 + 
ax2_0_1 + ax2_0_2)
-                                    
T.reads(A_reindex_shared_wmma_matrix_a[v0_o * 16 : v0_o * 16 + 16, v2_o * 16 : 
v2_o * 16 + 16], B_reindex_shared_wmma_matrix_b[v2_o * 16 : v2_o * 16 + 16, 
v1_o * 16 : v1_o * 16 + 16])
-                                    
T.writes(C_reindex_shared_wmma_accumulator[v0_o * 16 : v0_o * 16 + 16, v1_o * 
16 : v1_o * 16 + 16])
-                                    
T.block_attr({"meta_schedule.auto_tensorize":"wmma_sync_16x16x16_f16f16f32", 
"meta_schedule.auto_tensorize_init":"wmma_fill_16x16x16_f32", 
"warp_execution":1})
+                                    
T.reads(A_reindex_shared_wmma_matrix_a[v0_o * 16:v0_o * 16 + 16, v2_o * 16:v2_o 
* 16 + 16], B_reindex_shared_wmma_matrix_b[v2_o * 16:v2_o * 16 + 16, v1_o * 
16:v1_o * 16 + 16])
+                                    
T.writes(C_reindex_shared_wmma_accumulator[v0_o // 2, v1_o // 2, v0_o % 2, v1_o 
% 2, 0:16, 0:16])
+                                    
T.block_attr({"meta_schedule.auto_tensorize": "wmma_sync_16x16x16_f16f16f32", 
"meta_schedule.auto_tensorize_init": "wmma_fill_16x16x16_f32", 
"warp_execution": 1})
                                     with T.init():
                                         for ax0_1, ax1_1 in T.grid(16, 16):
                                             with T.block("C_init"):
                                                 v0_i_init, v1_i_init = 
T.axis.remap("SS", [ax0_1, ax1_1])
                                                 T.reads()
-                                                
T.writes(C_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i_init, v1_o * 16 + 
v1_i_init])
-                                                
C_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i_init, v1_o * 16 + v1_i_init] 
= T.float32(0)
+                                                
T.writes(C_reindex_shared_wmma_accumulator[v0_o // 2, v1_o // 2, v0_o % 2, v1_o 
% 2, v0_i_init, v1_i_init])
+                                                
C_reindex_shared_wmma_accumulator[v0_o // 2, v1_o // 2, v0_o % 2, v1_o % 2, 
v0_i_init, v1_i_init] = T.float32(0)
                                     for ax0_1, ax1_1, ax2_1 in T.grid(16, 16, 
16):
                                         with T.block("C"):
                                             v0_i, v1_i, v2_i = 
T.axis.remap("SSR", [ax0_1, ax1_1, ax2_1])
-                                            
T.reads(C_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i], 
A_reindex_shared_wmma_matrix_a[v0_o * 16 + v0_i, v2_o * 16 + v2_i], 
B_reindex_shared_wmma_matrix_b[v2_o * 16 + v2_i, v1_o * 16 + v1_i])
-                                            
T.writes(C_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i])
-                                            
T.block_attr({"meta_schedule.tiling_structure":"SSSRRSRS"})
-                                            
C_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i] = 
C_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i] + 
T.cast(A_reindex_shared_wmma_matrix_a[v0_o * 16 + v0_i, v2_o * 16 + v2_i], 
"float32") * T.cast(B_reindex_shared_wmma_matrix_b[v2_o * 16 + v2_i, v1_o * 16 
+ v1_i], "float32")
-                    for ax0_0, ax1_0 in T.grid(2, 2):
-                        with T.block("C_reindex_shared_wmma.accumulator_o"):
-                            v0_o = T.axis.spatial(8, ax0_0_1_ax1_0_1_fused // 
4 * 2 + ax0_0)
-                            v1_o = T.axis.spatial(8, ax0_0_1_ax1_0_1_fused % 4 
* 2 + ax1_0)
-                            T.reads(C_reindex_shared_wmma_accumulator[v0_o * 
16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16])
-                            T.writes(C_reindex_shared[v0_o * 16 : v0_o * 16 + 
16, v1_o * 16 : v1_o * 16 + 16])
-                            T.block_attr({"meta_schedule.auto_tensorize": 
f"wmma_store_16x16x16_f32_{intrin_suffix}"})
-                            for ax0_1, ax1_1 in T.grid(16, 16):
-                                with 
T.block("C_reindex_shared_wmma.accumulator"):
-                                    v0_i, v1_i = T.axis.remap("SS", [ax0_1, 
ax1_1])
-                                    
T.reads(C_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i])
-                                    T.writes(C_reindex_shared[v0_o * 16 + 
v0_i, v1_o * 16 + v1_i])
-                                    C_reindex_shared[v0_o * 16 + v0_i, v1_o * 
16 + v1_i] = C_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + 
v1_i]
-                for ax0_ax1_fused in T.grid(1024):
-                    with T.block("C_reindex_shared"):
-                        v0 = T.axis.spatial(128, ax0_0_1_ax1_0_1_fused // 4 * 
32 + ax0_ax1_fused // 32)
-                        v1 = T.axis.spatial(128, ax0_0_1_ax1_0_1_fused % 4 * 
32 + ax0_ax1_fused % 32)
-                        T.reads(C_reindex_shared[v0, v1])
-                        T.writes(C[v0, v1])
-                        T.block_attr({"meta_schedule.cooperative_fetch":3})
-                        C[v0, v1] = C_reindex_shared[v0, v1]
+                                            
T.reads(C_reindex_shared_wmma_accumulator[v0_o // 2, v1_o // 2, v0_o % 2, v1_o 
% 2, v0_i, v1_i], A_reindex_shared_wmma_matrix_a[v0_o * 16 + v0_i, v2_o * 16 + 
v2_i], B_reindex_shared_wmma_matrix_b[v2_o * 16 + v2_i, v1_o * 16 + v1_i])
+                                            
T.writes(C_reindex_shared_wmma_accumulator[v0_o // 2, v1_o // 2, v0_o % 2, v1_o 
% 2, v0_i, v1_i])
+                                            
T.block_attr({"meta_schedule.tiling_structure": "SSSRRSRS"})
+                                            
C_reindex_shared_wmma_accumulator[v0_o // 2, v1_o // 2, v0_o % 2, v1_o % 2, 
v0_i, v1_i] = C_reindex_shared_wmma_accumulator[v0_o // 2, v1_o // 2, v0_o % 2, 
v1_o % 2, v0_i, v1_i] + T.Cast("float32", A_reindex_shared_wmma_matrix_a[v0_o * 
16 + v0_i, v2_o * 16 + v2_i]) * T.Cast("float32", 
B_reindex_shared_wmma_matrix_b[v2_o * 16 + v2_i, v1_o * 16 + v1_i])
+                for ax2 in range(2):
+                    for ax0_ax1_fused in T.thread_binding(1, 
thread="threadIdx.y"):
+                        for ax2_1, ax3 in T.grid(1, 2):
+                            with 
T.block("C_reindex_shared_wmma.accumulator_o"):
+                                v0 = T.axis.spatial(4, ax0_0_1_ax1_0_1_fused 
// 4)
+                                v1 = T.axis.spatial(4, ax0_0_1_ax1_0_1_fused % 
4)
+                                v2 = T.axis.spatial(2, ax2 + ax2_1)
+                                v3 = T.axis.spatial(2, ax3)
+                                v4_o = T.axis.spatial(1, 0)
+                                v5_o = T.axis.spatial(1, 0)
+                                T.reads(C_reindex_shared_wmma_accumulator[v0, 
v1, v2, v3, 0:16, 0:16])
+                                T.writes(C_reindex_shared[v0, v1, v2, v3, 
0:16, 0:16])
+                                T.block_attr({"meta_schedule.auto_tensorize": 
f"wmma_store_16x16x16_f32_{intrin_suffix}"})
+                                for ax4, ax5 in T.grid(16, 16):
+                                    with 
T.block("C_reindex_shared_wmma.accumulator"):
+                                        v4_i, v5_i = T.axis.remap("SS", [ax4, 
ax5])
+                                        
T.reads(C_reindex_shared_wmma_accumulator[v0, v1, v2, v3, v4_i, v5_i])
+                                        T.writes(C_reindex_shared[v0, v1, v2, 
v3, v4_i, v5_i])
+                                        C_reindex_shared[v0, v1, v2, v3, v4_i, 
v5_i] = C_reindex_shared_wmma_accumulator[v0, v1, v2, v3, v4_i, v5_i]
+                    for ax0_ax1_ax3_ax4_ax5_fused in range(512):
+                        with T.block("C_reindex_shared"):
+                            v0 = T.axis.spatial(4, ax0_0_1_ax1_0_1_fused // 4)
+                            v1 = T.axis.spatial(4, ax0_0_1_ax1_0_1_fused % 4)
+                            v2 = T.axis.spatial(2, ax2)
+                            v3 = T.axis.spatial(2, ax0_ax1_ax3_ax4_ax5_fused 
// 256)
+                            v4 = T.axis.spatial(16, ax0_ax1_ax3_ax4_ax5_fused 
% 256 // 16)
+                            v5 = T.axis.spatial(16, ax0_ax1_ax3_ax4_ax5_fused 
% 16)
+                            T.reads(C_reindex_shared[v0, v1, v2, v3, v4, v5])
+                            T.writes(C[v4 + v2 * 16 + v0 * 32, v5 + v3 * 16 + 
v1 * 32])
+                            T.block_attr({"meta_schedule.cooperative_fetch": 
3})
+                            C[v4 + v2 * 16 + v0 * 32, v5 + v3 * 16 + v1 * 32] 
= C_reindex_shared[v0, v1, v2, v3, v4, v5]
         for i0, i1 in T.grid(128, 128):
             with T.block("compute"):
-                i0_1, i1_1 = T.axis.remap("SS", [i0, i1])
-                T.reads(C[i0_1, i1_1])
-                T.writes(compute[i0_1, i1_1])
-                compute[i0_1, i1_1] = T.max(C[i0_1, i1_1], T.float32(0))
+                v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
+                T.reads(C[v_i0, v_i1])
+                T.writes(compute[v_i0, v_i1])
+                compute[v_i0, v_i1] = T.max(C[v_i0, v_i1], T.float32(0))
+
     # fmt: on
     decision_0 = [
         ("SamplePerfectTile", [1, 4, 1, 1, 2]),
@@ -693,141 +727,6 @@ def test_matmul_relu_pipeline(shared_scope):
     )
 
 
-def test_matmul_relu_global():
-    # fmt: off
-    @T.prim_func
-    def matmul_relu_global_0(A: T.Buffer((128, 128), "float16"), B: 
T.Buffer((128, 128), "float16"), compute: T.Buffer((128, 128), "float32")) -> 
None:
-        # function attr dict
-        T.func_attr({"global_symbol": "main", "tir.noalias": True})
-        # body
-        # with T.block("root")
-        C = T.alloc_buffer([128, 128], dtype="float32")
-        C_reindex_wmma_accumulator = T.alloc_buffer([128, 128], 
dtype="float32", scope="wmma.accumulator")
-        A_reindex_shared = T.alloc_buffer([128, 128], dtype="float16", 
scope="shared")
-        B_reindex_shared = T.alloc_buffer([128, 128], dtype="float16", 
scope="shared")
-        A_reindex_shared_wmma_matrix_a = T.alloc_buffer([128, 128], 
dtype="float16", scope="wmma.matrix_a")
-        B_reindex_shared_wmma_matrix_b = T.alloc_buffer([128, 128], 
dtype="float16", scope="wmma.matrix_b")
-        for ax0_0_0_ax1_0_0_fused in T.thread_binding(1, thread="blockIdx.y"):
-            for ax0_0_1_ax1_0_1_fused in T.thread_binding(1, 
thread="blockIdx.x"):
-                for ax0_0_2_ax1_0_2_fused in T.thread_binding(16, 
thread="threadIdx.y"):
-                    for ax2_0_0 in T.serial(2):
-                        for ax0_ax1_fused in T.serial(8192):
-                            with T.block("A_reindex_shared"):
-                                v0 = T.axis.spatial(128, ax0_ax1_fused // 64)
-                                v1 = T.axis.spatial(128, ax2_0_0 * 64 + 
ax0_ax1_fused % 64)
-                                T.reads(A[v0, v1])
-                                T.writes(A_reindex_shared[v0, v1])
-                                T.block_attr({"buffer_dim_align":[[0, 0, 32, 
8]], "meta_schedule.cooperative_fetch":1})
-                                A_reindex_shared[v0, v1] = A[v0, v1]
-                        for ax0_ax1_fused in T.serial(8192):
-                            with T.block("B_reindex_shared"):
-                                v0 = T.axis.spatial(128, ax2_0_0 * 64 + 
ax0_ax1_fused // 128)
-                                v1 = T.axis.spatial(128, ax0_ax1_fused % 128)
-                                T.reads(B[v0, v1])
-                                T.writes(B_reindex_shared[v0, v1])
-                                T.block_attr({"buffer_dim_align":[[0, 0, 32, 
8]], "meta_schedule.cooperative_fetch":1})
-                                B_reindex_shared[v0, v1] = B[v0, v1]
-                        for ax2_0_1 in T.serial(2):
-                            for ax0_0, ax1_0 in T.grid(1, 2):
-                                with 
T.block("A_reindex_shared_wmma.matrix_a_o"):
-                                    v0_o = T.axis.spatial(8, 
ax0_0_2_ax1_0_2_fused // 2 + ax0_0)
-                                    v1_o = T.axis.spatial(8, ax2_0_0 * 4 + 
ax2_0_1 * 2 + ax1_0)
-                                    T.reads(A_reindex_shared[v0_o * 16 : v0_o 
* 16 + 16, v1_o * 16 : v1_o * 16 + 16])
-                                    
T.writes(A_reindex_shared_wmma_matrix_a[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : 
v1_o * 16 + 16])
-                                    
T.block_attr({"meta_schedule.auto_tensorize":"wmma_load_16x16x16_f16_a_shared"})
-                                    for ax0_1, ax1_1 in T.grid(16, 16):
-                                        with 
T.block("A_reindex_shared_wmma.matrix_a"):
-                                            v0_i, v1_i = T.axis.remap("SS", 
[ax0_1, ax1_1])
-                                            T.reads(A_reindex_shared[v0_o * 16 
+ v0_i, v1_o * 16 + v1_i])
-                                            
T.writes(A_reindex_shared_wmma_matrix_a[v0_o * 16 + v0_i, v1_o * 16 + v1_i])
-                                            
A_reindex_shared_wmma_matrix_a[v0_o * 16 + v0_i, v1_o * 16 + v1_i] = 
A_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i]
-                            for ax0_0, ax1_0 in T.grid(2, 4):
-                                with 
T.block("B_reindex_shared_wmma.matrix_b_o"):
-                                    v0_o = T.axis.spatial(8, ax2_0_0 * 4 + 
ax2_0_1 * 2 + ax0_0)
-                                    v1_o = T.axis.spatial(8, 
ax0_0_2_ax1_0_2_fused % 2 * 4 + ax1_0)
-                                    T.reads(B_reindex_shared[v0_o * 16 : v0_o 
* 16 + 16, v1_o * 16 : v1_o * 16 + 16])
-                                    
T.writes(B_reindex_shared_wmma_matrix_b[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : 
v1_o * 16 + 16])
-                                    
T.block_attr({"meta_schedule.auto_tensorize":"wmma_load_16x16x16_f16_b_shared"})
-                                    for ax0_1, ax1_1 in T.grid(16, 16):
-                                        with 
T.block("B_reindex_shared_wmma.matrix_b"):
-                                            v0_i, v1_i = T.axis.remap("SS", 
[ax0_1, ax1_1])
-                                            T.reads(B_reindex_shared[v0_o * 16 
+ v0_i, v1_o * 16 + v1_i])
-                                            
T.writes(B_reindex_shared_wmma_matrix_b[v0_o * 16 + v0_i, v1_o * 16 + v1_i])
-                                            
B_reindex_shared_wmma_matrix_b[v0_o * 16 + v0_i, v1_o * 16 + v1_i] = 
B_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i]
-                            for ax0_0_3, ax1_0_3, ax2_0_2, ax0_0_4, ax1_0_4 in 
T.grid(1, 4, 2, 1, 1):
-                                with T.block("C_o"):
-                                    v0_o = T.axis.spatial(8, 
ax0_0_2_ax1_0_2_fused // 2 + ax0_0_3 + ax0_0_4)
-                                    v1_o = T.axis.spatial(8, ax1_0_4 + 
ax0_0_2_ax1_0_2_fused % 2 * 4 + ax1_0_3)
-                                    v2_o = T.axis.reduce(8, ax2_0_0 * 4 + 
ax2_0_1 * 2 + ax2_0_2)
-                                    
T.reads(A_reindex_shared_wmma_matrix_a[v0_o * 16 : v0_o * 16 + 16, v2_o * 16 : 
v2_o * 16 + 16], B_reindex_shared_wmma_matrix_b[v2_o * 16 : v2_o * 16 + 16, 
v1_o * 16 : v1_o * 16 + 16])
-                                    T.writes(C_reindex_wmma_accumulator[v0_o * 
16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16])
-                                    
T.block_attr({"meta_schedule.auto_tensorize":"wmma_sync_16x16x16_f16f16f32", 
"meta_schedule.auto_tensorize_init":"wmma_fill_16x16x16_f32", 
"warp_execution":1})
-                                    with T.init():
-                                        for ax0_1, ax1_1 in T.grid(16, 16):
-                                            with T.block("C_init"):
-                                                v0_i_init, v1_i_init = 
T.axis.remap("SS", [ax0_1, ax1_1])
-                                                T.reads()
-                                                
T.writes(C_reindex_wmma_accumulator[v0_o * 16 + v0_i_init, v1_o * 16 + 
v1_i_init])
-                                                
C_reindex_wmma_accumulator[v0_o * 16 + v0_i_init, v1_o * 16 + v1_i_init] = 
T.float32(0)
-                                    for ax0_1, ax1_1, ax2_1 in T.grid(16, 16, 
16):
-                                        with T.block("C"):
-                                            v0_i, v1_i, v2_i = 
T.axis.remap("SSR", [ax0_1, ax1_1, ax2_1])
-                                            
T.reads(C_reindex_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i], 
A_reindex_shared_wmma_matrix_a[v0_o * 16 + v0_i, v2_o * 16 + v2_i], 
B_reindex_shared_wmma_matrix_b[v2_o * 16 + v2_i, v1_o * 16 + v1_i])
-                                            
T.writes(C_reindex_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i])
-                                            
T.block_attr({"meta_schedule.tiling_structure":"SSSRRSRS"})
-                                            C_reindex_wmma_accumulator[v0_o * 
16 + v0_i, v1_o * 16 + v1_i] = C_reindex_wmma_accumulator[v0_o * 16 + v0_i, 
v1_o * 16 + v1_i] + T.cast(A_reindex_shared_wmma_matrix_a[v0_o * 16 + v0_i, 
v2_o * 16 + v2_i], "float32") * T.cast(B_reindex_shared_wmma_matrix_b[v2_o * 16 
+ v2_i, v1_o * 16 + v1_i], "float32")
-                    for ax0_0, ax1_0 in T.grid(1, 4):
-                        with T.block("C_reindex_wmma.accumulator_o"):
-                            v0_o = T.axis.spatial(8, ax0_0_2_ax1_0_2_fused // 
2 + ax0_0)
-                            v1_o = T.axis.spatial(8, ax0_0_2_ax1_0_2_fused % 2 
* 4 + ax1_0)
-                            T.reads(C_reindex_wmma_accumulator[v0_o * 16 : 
v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16])
-                            T.writes(C[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : 
v1_o * 16 + 16])
-                            
T.block_attr({"meta_schedule.auto_tensorize":"wmma_store_16x16x16_f32_global"})
-                            for ax0_1, ax1_1 in T.grid(16, 16):
-                                with T.block("C_reindex_wmma.accumulator"):
-                                    v0_i, v1_i = T.axis.remap("SS", [ax0_1, 
ax1_1])
-                                    T.reads(C_reindex_wmma_accumulator[v0_o * 
16 + v0_i, v1_o * 16 + v1_i])
-                                    T.writes(C[v0_o * 16 + v0_i, v1_o * 16 + 
v1_i])
-                                    C[v0_o * 16 + v0_i, v1_o * 16 + v1_i] = 
C_reindex_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i]
-        for i0, i1 in T.grid(128, 128):
-            with T.block("compute"):
-                i0_1, i1_1 = T.axis.remap("SS", [i0, i1])
-                T.reads(C[i0_1, i1_1])
-                T.writes(compute[i0_1, i1_1])
-                compute[i0_1, i1_1] = T.max(C[i0_1, i1_1], T.float32(0))
-    # fmt: on
-    decision_0 = [
-        ("SamplePerfectTile", [1, 1, 8, 1, 1]),
-        ("SamplePerfectTile", [1, 1, 2, 4, 1]),
-        ("SamplePerfectTile", [2, 2, 2]),
-        ("SampleCategorical", 0),
-        ("SampleCategorical", 0),
-    ]
-    mod = te.create_prim_func(
-        te_workload.matmul_relu(
-            n=128,
-            m=128,
-            k=128,
-            in_dtype="float16",
-            out_dtype="float32",
-        )
-    )
-    actual = generate_design_space(
-        kind="cuda",
-        mod=mod,
-        target=tvm.target.Target("cuda"),
-        types=None,
-        sch_rules=[multi_level_tiling_tensor_core(write_reuse_scope="global")]
-        + get_rules("cuda", ms.schedule_rule.AutoInline),
-    )
-    check_sketches(
-        mod,
-        sketches=actual,
-        expected_mods=[matmul_relu_global_0],
-        expected_decisions=[decision_0],
-    )
-
-
 def test_matmul_relu_non_tensorizable():
     # expected to do nothing on non-tensorizable workloads
     mod = te.create_prim_func(
@@ -842,7 +741,7 @@ def test_matmul_relu_non_tensorizable():
         mod=mod,
         target=tvm.target.Target("cuda"),
         types=None,
-        sch_rules=[multi_level_tiling_tensor_core(write_reuse_scope="global")]
+        sch_rules=[multi_level_tiling_tensor_core(write_reuse_scope="shared")]
         + get_rules("cuda", ms.schedule_rule.AutoInline),
     )
     tvm.ir.assert_structural_equal(mod, sch.mod["main"])
@@ -856,40 +755,40 @@ def test_padded_matmul_relu():
         T.func_attr({"global_symbol": "main", "tir.noalias": True})
         # body
         # with T.block("root")
-        C_reindex_shared = T.alloc_buffer([128, 128], dtype="float32", 
scope="shared")
-        C_reindex_shared_wmma_accumulator = T.alloc_buffer([128, 128], 
dtype="float32", scope="wmma.accumulator")
-        A_reindex_shared = T.alloc_buffer([128, 128], dtype="float16", 
scope="shared")
-        B_reindex_shared = T.alloc_buffer([128, 128], dtype="float16", 
scope="shared")
-        A_reindex_shared_wmma_matrix_a = T.alloc_buffer([128, 128], 
dtype="float16", scope="wmma.matrix_a")
-        B_reindex_shared_wmma_matrix_b = T.alloc_buffer([128, 128], 
dtype="float16", scope="wmma.matrix_b")
+        C_reindex_shared = T.alloc_buffer((4, 8, 2, 1, 16, 16), scope="shared")
+        C_reindex_shared_wmma_accumulator = T.alloc_buffer((4, 8, 2, 1, 16, 
16), scope="wmma.accumulator")
+        A_reindex_shared = T.alloc_buffer((128, 128), "float16", 
scope="shared")
+        B_reindex_shared = T.alloc_buffer((128, 128), "float16", 
scope="shared")
+        A_reindex_shared_wmma_matrix_a = T.alloc_buffer((128, 128), "float16", 
scope="wmma.matrix_a")
+        B_reindex_shared_wmma_matrix_b = T.alloc_buffer((128, 128), "float16", 
scope="wmma.matrix_b")
         for ax0_0_0_ax1_0_0_fused in T.thread_binding(8, thread="blockIdx.y"):
             for ax0_0_1_ax1_0_1_fused in T.thread_binding(2, 
thread="blockIdx.x"):
                 for ax0_0_2_ax1_0_2_fused in T.thread_binding(2, 
thread="threadIdx.y"):
-                    for ax2_0_0 in T.serial(1):
-                        for ax0_ax1_fused in T.serial(4096):
+                    for ax2_0_0 in range(1):
+                        for ax0_ax1_fused in range(4096):
                             with T.block("A_reindex_shared"):
                                 v0 = T.axis.spatial(128, ax0_0_0_ax1_0_0_fused 
// 2 * 32 + ax0_ax1_fused // 128)
                                 v1 = T.axis.spatial(128, ax0_ax1_fused % 128)
                                 T.reads(A[v0, v1])
                                 T.writes(A_reindex_shared[v0, v1])
-                                T.block_attr({"buffer_dim_align":[[0, 0, 32, 
8]], "meta_schedule.cooperative_fetch":8})
-                                A_reindex_shared[v0, v1] = T.if_then_else(v0 < 
127 and v1 < 127, A[v0, v1], T.float16(0), dtype="float16")
-                        for ax0_ax1_fused in T.serial(4096):
+                                T.block_attr({"buffer_dim_align": [[0, 0, 32, 
8]], "meta_schedule.cooperative_fetch": 8})
+                                A_reindex_shared[v0, v1] = T.if_then_else(v0 < 
127 and v1 < 127, A[v0, v1], T.float16(0))
+                        for ax0_ax1_fused in range(4096):
                             with T.block("B_reindex_shared"):
                                 v0 = T.axis.spatial(128, ax0_ax1_fused // 32)
                                 v1 = T.axis.spatial(128, ax0_0_0_ax1_0_0_fused 
% 2 * 64 + ax0_0_1_ax1_0_1_fused * 32 + ax0_ax1_fused % 32)
                                 T.reads(B[v0, v1])
                                 T.writes(B_reindex_shared[v0, v1])
-                                T.block_attr({"buffer_dim_align":[[0, 0, 32, 
8]], "meta_schedule.cooperative_fetch":1})
-                                B_reindex_shared[v0, v1] = T.if_then_else(v0 < 
127 and v1 < 127, B[v0, v1], T.float16(0), dtype="float16")
-                        for ax2_0_1 in T.serial(4):
+                                T.block_attr({"buffer_dim_align": [[0, 0, 32, 
8]], "meta_schedule.cooperative_fetch": 1})
+                                B_reindex_shared[v0, v1] = T.if_then_else(v0 < 
127 and v1 < 127, B[v0, v1], T.float16(0))
+                        for ax2_0_1 in range(4):
                             for ax0_0, ax1_0 in T.grid(2, 2):
                                 with 
T.block("A_reindex_shared_wmma.matrix_a_o"):
                                     v0_o = T.axis.spatial(8, 
ax0_0_0_ax1_0_0_fused // 2 * 2 + ax0_0)
                                     v1_o = T.axis.spatial(8, ax2_0_1 * 2 + 
ax1_0)
-                                    T.reads(A_reindex_shared[v0_o * 16 : v0_o 
* 16 + 16, v1_o * 16 : v1_o * 16 + 16])
-                                    
T.writes(A_reindex_shared_wmma_matrix_a[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : 
v1_o * 16 + 16])
-                                    
T.block_attr({"meta_schedule.auto_tensorize":"wmma_load_16x16x16_f16_a_shared"})
+                                    T.reads(A_reindex_shared[v0_o * 16:v0_o * 
16 + 16, v1_o * 16:v1_o * 16 + 16])
+                                    
T.writes(A_reindex_shared_wmma_matrix_a[v0_o * 16:v0_o * 16 + 16, v1_o * 
16:v1_o * 16 + 16])
+                                    
T.block_attr({"meta_schedule.auto_tensorize": 
"wmma_load_16x16x16_f16_a_shared"})
                                     for ax0_1, ax1_1 in T.grid(16, 16):
                                         with 
T.block("A_reindex_shared_wmma.matrix_a"):
                                             v0_i, v1_i = T.axis.remap("SS", 
[ax0_1, ax1_1])
@@ -900,9 +799,9 @@ def test_padded_matmul_relu():
                                 with 
T.block("B_reindex_shared_wmma.matrix_b_o"):
                                     v0_o = T.axis.spatial(8, ax2_0_1 * 2 + 
ax0_0)
                                     v1_o = T.axis.spatial(8, 
ax0_0_0_ax1_0_0_fused % 2 * 4 + ax0_0_1_ax1_0_1_fused * 2 + 
ax0_0_2_ax1_0_2_fused + ax1_0)
-                                    T.reads(B_reindex_shared[v0_o * 16 : v0_o 
* 16 + 16, v1_o * 16 : v1_o * 16 + 16])
-                                    
T.writes(B_reindex_shared_wmma_matrix_b[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : 
v1_o * 16 + 16])
-                                    
T.block_attr({"meta_schedule.auto_tensorize":"wmma_load_16x16x16_f16_b_shared"})
+                                    T.reads(B_reindex_shared[v0_o * 16:v0_o * 
16 + 16, v1_o * 16:v1_o * 16 + 16])
+                                    
T.writes(B_reindex_shared_wmma_matrix_b[v0_o * 16:v0_o * 16 + 16, v1_o * 
16:v1_o * 16 + 16])
+                                    
T.block_attr({"meta_schedule.auto_tensorize": 
"wmma_load_16x16x16_f16_b_shared"})
                                     for ax0_1, ax1_1 in T.grid(16, 16):
                                         with 
T.block("B_reindex_shared_wmma.matrix_b"):
                                             v0_i, v1_i = T.axis.remap("SS", 
[ax0_1, ax1_1])
@@ -914,45 +813,56 @@ def test_padded_matmul_relu():
                                     v0_o = T.axis.spatial(8, 
ax0_0_0_ax1_0_0_fused // 2 * 2 + ax0_0_3 * 2 + ax0_0_4)
                                     v1_o = T.axis.spatial(8, ax1_0_4 + 
ax0_0_0_ax1_0_0_fused % 2 * 4 + ax0_0_1_ax1_0_1_fused * 2 + 
ax0_0_2_ax1_0_2_fused + ax1_0_3)
                                     v2_o = T.axis.reduce(8, ax2_0_0 * 8 + 
ax2_0_1 * 2 + ax2_0_2)
-                                    
T.reads(A_reindex_shared_wmma_matrix_a[v0_o * 16 : v0_o * 16 + 16, v2_o * 16 : 
v2_o * 16 + 16], B_reindex_shared_wmma_matrix_b[v2_o * 16 : v2_o * 16 + 16, 
v1_o * 16 : v1_o * 16 + 16])
-                                    
T.writes(C_reindex_shared_wmma_accumulator[v0_o * 16 : v0_o * 16 + 16, v1_o * 
16 : v1_o * 16 + 16])
-                                    
T.block_attr({"meta_schedule.auto_tensorize":"wmma_sync_16x16x16_f16f16f32", 
"meta_schedule.auto_tensorize_init":"wmma_fill_16x16x16_f32", 
"warp_execution":1})
+                                    
T.reads(A_reindex_shared_wmma_matrix_a[v0_o * 16:v0_o * 16 + 16, v2_o * 16:v2_o 
* 16 + 16], B_reindex_shared_wmma_matrix_b[v2_o * 16:v2_o * 16 + 16, v1_o * 
16:v1_o * 16 + 16])
+                                    
T.writes(C_reindex_shared_wmma_accumulator[v0_o // 2, v1_o, v0_o % 2, 0, 0:16, 
0:16])
+                                    
T.block_attr({"meta_schedule.auto_tensorize": "wmma_sync_16x16x16_f16f16f32", 
"meta_schedule.auto_tensorize_init": "wmma_fill_16x16x16_f32", 
"warp_execution": 1})
                                     with T.init():
                                         for ax0_1, ax1_1 in T.grid(16, 16):
                                             with T.block("C_init"):
                                                 v0_i_init, v1_i_init = 
T.axis.remap("SS", [ax0_1, ax1_1])
                                                 T.reads()
-                                                
T.writes(C_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i_init, v1_o * 16 + 
v1_i_init])
-                                                
C_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i_init, v1_o * 16 + v1_i_init] 
= T.float32(0)
+                                                
T.writes(C_reindex_shared_wmma_accumulator[v0_o // 2, v1_o, v0_o % 2, 0, 
v0_i_init, v1_i_init])
+                                                
C_reindex_shared_wmma_accumulator[v0_o // 2, v1_o, v0_o % 2, 0, v0_i_init, 
v1_i_init] = T.float32(0)
                                     for ax0_1, ax1_1, ax2_1 in T.grid(16, 16, 
16):
                                         with T.block("C"):
                                             v0_i, v1_i, v2_i = 
T.axis.remap("SSR", [ax0_1, ax1_1, ax2_1])
-                                            
T.reads(C_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i], 
A_reindex_shared_wmma_matrix_a[v0_o * 16 + v0_i, v2_o * 16 + v2_i], 
B_reindex_shared_wmma_matrix_b[v2_o * 16 + v2_i, v1_o * 16 + v1_i])
-                                            
T.writes(C_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i])
-                                            
T.block_attr({"meta_schedule.tiling_structure":"SSSRRSRS"})
-                                            
C_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i] = 
C_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i] + 
T.cast(A_reindex_shared_wmma_matrix_a[v0_o * 16 + v0_i, v2_o * 16 + v2_i], 
"float32") * T.cast(B_reindex_shared_wmma_matrix_b[v2_o * 16 + v2_i, v1_o * 16 
+ v1_i], "float32")
-                    for ax0_0, ax1_0 in T.grid(2, 1):
-                        with T.block("C_reindex_shared_wmma.accumulator_o"):
-                            v0_o = T.axis.spatial(8, ax0_0_0_ax1_0_0_fused // 
2 * 2 + ax0_0)
-                            v1_o = T.axis.spatial(8, ax0_0_0_ax1_0_0_fused % 2 
* 4 + ax0_0_1_ax1_0_1_fused * 2 + ax0_0_2_ax1_0_2_fused + ax1_0)
-                            T.reads(C_reindex_shared_wmma_accumulator[v0_o * 
16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16])
-                            T.writes(C_reindex_shared[v0_o * 16 : v0_o * 16 + 
16, v1_o * 16 : v1_o * 16 + 16])
-                            
T.block_attr({"meta_schedule.auto_tensorize":"wmma_store_16x16x16_f32_shared"})
-                            for ax0_1, ax1_1 in T.grid(16, 16):
-                                with 
T.block("C_reindex_shared_wmma.accumulator"):
-                                    v0_i, v1_i = T.axis.remap("SS", [ax0_1, 
ax1_1])
-                                    
T.reads(C_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i])
-                                    T.writes(C_reindex_shared[v0_o * 16 + 
v0_i, v1_o * 16 + v1_i])
-                                    C_reindex_shared[v0_o * 16 + v0_i, v1_o * 
16 + v1_i] = C_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + 
v1_i]
-                for ax0_ax1_fused in T.serial(1024):
-                    with T.block("C_reindex_shared"):
-                        T.where(ax0_0_0_ax1_0_0_fused // 2 * 32 + 
ax0_ax1_fused // 32 < 127 and ax0_0_0_ax1_0_0_fused % 2 * 64 + 
ax0_0_1_ax1_0_1_fused * 32 + ax0_ax1_fused % 32 < 127)
-                        v0 = T.axis.spatial(128, ax0_0_0_ax1_0_0_fused // 2 * 
32 + ax0_ax1_fused // 32)
-                        v1 = T.axis.spatial(128, ax0_0_0_ax1_0_0_fused % 2 * 
64 + ax0_0_1_ax1_0_1_fused * 32 + ax0_ax1_fused % 32)
-                        T.reads(C_reindex_shared[v0, v1])
-                        T.writes(compute[v0, v1])
-                        T.block_attr({"meta_schedule.cooperative_fetch":4})
-                        compute[v0, v1] = T.max(C_reindex_shared[v0, v1], 
T.float32(0))
+                                            
T.reads(C_reindex_shared_wmma_accumulator[v0_o // 2, v1_o, v0_o % 2, 0, v0_i, 
v1_i], A_reindex_shared_wmma_matrix_a[v0_o * 16 + v0_i, v2_o * 16 + v2_i], 
B_reindex_shared_wmma_matrix_b[v2_o * 16 + v2_i, v1_o * 16 + v1_i])
+                                            
T.writes(C_reindex_shared_wmma_accumulator[v0_o // 2, v1_o, v0_o % 2, 0, v0_i, 
v1_i])
+                                            
T.block_attr({"meta_schedule.tiling_structure": "SSSRRSRS"})
+                                            
C_reindex_shared_wmma_accumulator[v0_o // 2, v1_o, v0_o % 2, 0, v0_i, v1_i] = 
C_reindex_shared_wmma_accumulator[v0_o // 2, v1_o, v0_o % 2, 0, v0_i, v1_i] + 
T.Cast("float32", A_reindex_shared_wmma_matrix_a[v0_o * 16 + v0_i, v2_o * 16 + 
v2_i]) * T.Cast("float32", B_reindex_shared_wmma_matrix_b[v2_o * 16 + v2_i, 
v1_o * 16 + v1_i])
+                for ax2 in range(2):
+                    for ax0_ax1_fused in T.thread_binding(2, 
thread="threadIdx.y"):
+                        for ax2_1, ax3 in T.grid(1, 1):
+                            with 
T.block("C_reindex_shared_wmma.accumulator_o"):
+                                v0 = T.axis.spatial(4, ax0_0_0_ax1_0_0_fused 
// 2)
+                                v1 = T.axis.spatial(8, ax0_0_0_ax1_0_0_fused % 
2 * 4 + ax0_0_1_ax1_0_1_fused * 2 + ax0_ax1_fused)
+                                v2 = T.axis.spatial(2, ax2 + ax2_1)
+                                v3 = T.axis.spatial(1, ax3)
+                                v4_o = T.axis.spatial(1, 0)
+                                v5_o = T.axis.spatial(1, 0)
+                                T.reads(C_reindex_shared_wmma_accumulator[v0, 
v1, v2, v3, 0:16, 0:16])
+                                T.writes(C_reindex_shared[v0, v1, v2, v3, 
0:16, 0:16])
+                                T.block_attr({"meta_schedule.auto_tensorize": 
"wmma_store_16x16x16_f32_shared"})
+                                for ax4, ax5 in T.grid(16, 16):
+                                    with 
T.block("C_reindex_shared_wmma.accumulator"):
+                                        v4_i, v5_i = T.axis.remap("SS", [ax4, 
ax5])
+                                        
T.reads(C_reindex_shared_wmma_accumulator[v0, v1, v2, v3, v4_i, v5_i])
+                                        T.writes(C_reindex_shared[v0, v1, v2, 
v3, v4_i, v5_i])
+                                        C_reindex_shared[v0, v1, v2, v3, v4_i, 
v5_i] = C_reindex_shared_wmma_accumulator[v0, v1, v2, v3, v4_i, v5_i]
+                    for ax0_ax1_ax3_ax4_ax5_fused in range(512):
+                        with T.block("C_reindex_shared"):
+                            v0 = T.axis.spatial(4, ax0_0_0_ax1_0_0_fused // 2 
+ 0)
+                            v1 = T.axis.spatial(8, ax0_0_0_ax1_0_0_fused % 2 * 
4 + ax0_0_1_ax1_0_1_fused * 2 + ax0_ax1_ax3_ax4_ax5_fused % 512 // 256)
+                            v2 = T.axis.spatial(2, ax2)
+                            v3 = T.axis.spatial(1, 0)
+                            v4 = T.axis.spatial(16, ax0_ax1_ax3_ax4_ax5_fused 
% 256 // 16)
+                            v5 = T.axis.spatial(16, ax0_ax1_ax3_ax4_ax5_fused 
% 16)
+                            T.where(ax0_0_0_ax1_0_0_fused // 2 * 32 + ax2 * 16 
+ ax0_ax1_ax3_ax4_ax5_fused % 256 // 16 < 127 and ax0_0_0_ax1_0_0_fused % 2 * 
64 + ax0_0_1_ax1_0_1_fused * 32 + ax0_ax1_ax3_ax4_ax5_fused % 512 // 256 * 16 + 
ax0_ax1_ax3_ax4_ax5_fused % 16 < 127)
+                            T.reads(C_reindex_shared[v0, v1, v2, v3, v4, v5])
+                            T.writes(compute[v4 + v2 * 16 + v0 * 32, v5 + v1 * 
16])
+                            T.block_attr({"meta_schedule.cooperative_fetch": 
4})
+                            compute[v4 + v2 * 16 + v0 * 32, v5 + v1 * 16] = 
T.max(C_reindex_shared[v0, v1, v2, v3, v4, v5], T.float32(0))
+
     # fmt: on
 
     decision_0 = [
@@ -994,25 +904,25 @@ def test_conv_1x1():
     @T.prim_func
     def conv2d_1x1_0(inputs: T.Buffer((1, 16, 16, 64), "float16"), weight: 
T.Buffer((1, 1, 64, 64), "float16"), conv2d_nhwc: T.Buffer((1, 16, 16, 64), 
"float32")) -> None:
         T.func_attr({"global_symbol": "main", "tir.noalias": True})
-        conv2d_nhwc_reindex_shared = T.alloc_buffer([256, 64], 
dtype="float32", scope="shared")
-        conv2d_nhwc_reindex_shared_wmma_accumulator = T.alloc_buffer([256, 
64], dtype="float32", scope="wmma.accumulator")
-        PadInput_reindex_shared = T.alloc_buffer([256, 64], dtype="float16", 
scope="shared")
-        weight_reindex_shared = T.alloc_buffer([1, 1, 64, 64], 
dtype="float16", scope="shared")
-        PadInput_reindex_shared_wmma_matrix_a = T.alloc_buffer([256, 64], 
dtype="float16", scope="wmma.matrix_a")
-        weight_reindex_shared_wmma_matrix_b = T.alloc_buffer([1, 1, 64, 64], 
dtype="float16", scope="wmma.matrix_b")
+        conv2d_nhwc_reindex_shared = T.alloc_buffer((16, 4, 1, 1, 16, 16), 
scope="shared")
+        conv2d_nhwc_reindex_shared_wmma_accumulator = T.alloc_buffer((16, 4, 
1, 1, 16, 16), scope="wmma.accumulator")
+        PadInput_reindex_shared = T.alloc_buffer((256, 64), "float16", 
scope="shared")
+        weight_reindex_shared = T.alloc_buffer((1, 1, 64, 64), "float16", 
scope="shared")
+        PadInput_reindex_shared_wmma_matrix_a = T.alloc_buffer((256, 64), 
"float16", scope="wmma.matrix_a")
+        weight_reindex_shared_wmma_matrix_b = T.alloc_buffer((1, 1, 64, 64), 
"float16", scope="wmma.matrix_b")
         for ax2_0_0_ax3_0_0_fused in T.thread_binding(16, thread="blockIdx.y"):
             for ax2_0_1_ax3_0_1_fused in T.thread_binding(2, 
thread="blockIdx.x"):
                 for ax2_0_2_ax3_0_2_fused in T.thread_binding(2, 
thread="threadIdx.y"):
                     for ax0_0, ax1_0, ax4_0_0 in T.grid(1, 1, 1):
-                        for ax0_ax1_fused in T.serial(1024):
+                        for ax0_ax1_fused in range(1024):
                             with T.block("PadInput_reindex_shared"):
                                 v0 = T.axis.spatial(256, ax2_0_0_ax3_0_0_fused 
// 2 * 32 + ax2_0_1_ax3_0_1_fused * 16 + ax0_ax1_fused // 64)
                                 v1 = T.axis.spatial(64, ax0_ax1_fused % 64)
                                 T.reads(inputs[v0 // 256, v0 // 16, v0 % 16, 
v1])
                                 T.writes(PadInput_reindex_shared[v0, v1])
-                                T.block_attr({"buffer_dim_align":[[0, 0, 32, 
8]], "meta_schedule.cooperative_fetch":1})
+                                T.block_attr({"buffer_dim_align": [[0, 0, 32, 
8]], "meta_schedule.cooperative_fetch": 1})
                                 PadInput_reindex_shared[v0, v1] = inputs[v0 // 
256, v0 // 16, v0 % 16, v1]
-                        for ax0_ax1_ax2_ax3_fused in T.serial(2048):
+                        for ax0_ax1_ax2_ax3_fused in range(2048):
                             with T.block("weight_reindex_shared"):
                                 v0 = T.axis.spatial(1, 0)
                                 v1 = T.axis.spatial(1, 0)
@@ -1020,16 +930,16 @@ def test_conv_1x1():
                                 v3 = T.axis.spatial(64, ax2_0_0_ax3_0_0_fused 
% 2 * 32 + ax0_ax1_ax2_ax3_fused % 32)
                                 T.reads(weight[v0, v1, v2, v3])
                                 T.writes(weight_reindex_shared[v0, v1, v2, v3])
-                                T.block_attr({"buffer_dim_align":[[0, 2, 32, 
8]], "meta_schedule.cooperative_fetch":4})
+                                T.block_attr({"buffer_dim_align": [[0, 2, 32, 
8]], "meta_schedule.cooperative_fetch": 4})
                                 weight_reindex_shared[v0, v1, v2, v3] = 
weight[v0, v1, v2, v3]
                         for ax0_1, ax1_1, ax4_0_1 in T.grid(1, 1, 1):
                             for ax0_0_1, ax1_0_1 in T.grid(1, 4):
                                 with 
T.block("PadInput_reindex_shared_wmma.matrix_a_o"):
                                     v0_o = T.axis.spatial(16, 
ax2_0_0_ax3_0_0_fused // 2 * 2 + ax2_0_1_ax3_0_1_fused + ax0_0_1)
                                     v1_o = T.axis.spatial(4, ax1_0_1)
-                                    T.reads(PadInput_reindex_shared[v0_o * 16 
: v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16])
-                                    
T.writes(PadInput_reindex_shared_wmma_matrix_a[v0_o * 16 : v0_o * 16 + 16, v1_o 
* 16 : v1_o * 16 + 16])
-                                    
T.block_attr({"meta_schedule.auto_tensorize":"wmma_load_16x16x16_f16_a_shared"})
+                                    T.reads(PadInput_reindex_shared[v0_o * 
16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16])
+                                    
T.writes(PadInput_reindex_shared_wmma_matrix_a[v0_o * 16:v0_o * 16 + 16, v1_o * 
16:v1_o * 16 + 16])
+                                    
T.block_attr({"meta_schedule.auto_tensorize": 
"wmma_load_16x16x16_f16_a_shared"})
                                     for ax0_1_1, ax1_1_1 in T.grid(16, 16):
                                         with 
T.block("PadInput_reindex_shared_wmma.matrix_a"):
                                             v0_i, v1_i = T.axis.remap("SS", 
[ax0_1_1, ax1_1_1])
@@ -1040,9 +950,9 @@ def test_conv_1x1():
                                 with 
T.block("weight_reindex_shared_wmma.matrix_b_o"):
                                     v0, v1, v2_o = T.axis.remap("SSS", [ax0, 
ax1, ax2_0])
                                     v3_o = T.axis.spatial(4, 
ax2_0_0_ax3_0_0_fused % 2 * 2 + ax2_0_2_ax3_0_2_fused + ax3_0)
-                                    T.reads(weight_reindex_shared[v0, v1, v2_o 
* 16 : v2_o * 16 + 16, v3_o * 16 : v3_o * 16 + 16])
-                                    
T.writes(weight_reindex_shared_wmma_matrix_b[v0, v1, v2_o * 16 : v2_o * 16 + 
16, v3_o * 16 : v3_o * 16 + 16])
-                                    
T.block_attr({"meta_schedule.auto_tensorize":"wmma_load_16x16x16_f16_b_shared"})
+                                    T.reads(weight_reindex_shared[v0, v1, v2_o 
* 16:v2_o * 16 + 16, v3_o * 16:v3_o * 16 + 16])
+                                    
T.writes(weight_reindex_shared_wmma_matrix_b[v0, v1, v2_o * 16:v2_o * 16 + 16, 
v3_o * 16:v3_o * 16 + 16])
+                                    
T.block_attr({"meta_schedule.auto_tensorize": 
"wmma_load_16x16x16_f16_b_shared"})
                                     for ax2_1, ax3_1 in T.grid(16, 16):
                                         with 
T.block("weight_reindex_shared_wmma.matrix_b"):
                                             v2_i, v3_i = T.axis.remap("SS", 
[ax2_1, ax3_1])
@@ -1056,44 +966,53 @@ def test_conv_1x1():
                                     v2_o = T.axis.spatial(16, ax2_0_4 + 
ax2_0_0_ax3_0_0_fused // 2 * 2 + ax2_0_1_ax3_0_1_fused + ax2_0_3)
                                     v3_o = T.axis.spatial(4, ax3_0_4 + 
ax2_0_0_ax3_0_0_fused % 2 * 2 + ax2_0_2_ax3_0_2_fused + ax3_0_3)
                                     v4_o = T.axis.reduce(4, ax4_0_0 * 4 + 
ax4_0_1 * 4 + ax4_0_2)
-                                    
T.reads(PadInput_reindex_shared_wmma_matrix_a[v2_o * 16 : v2_o * 16 + 16, v4_o 
* 16 : v4_o * 16 + 16], weight_reindex_shared_wmma_matrix_b[v0, v1, v4_o * 16 : 
v4_o * 16 + 16, v3_o * 16 : v3_o * 16 + 16])
-                                    
T.writes(conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o * 16 : v2_o * 16 + 
16, v3_o * 16 : v3_o * 16 + 16])
-                                    
T.block_attr({"meta_schedule.auto_tensorize":"wmma_sync_16x16x16_f16f16f32", 
"meta_schedule.auto_tensorize_init":"wmma_fill_16x16x16_f32", 
"warp_execution":1})
+                                    
T.reads(PadInput_reindex_shared_wmma_matrix_a[v2_o * 16:v2_o * 16 + 16, v4_o * 
16:v4_o * 16 + 16], weight_reindex_shared_wmma_matrix_b[v0, v1, v4_o * 16:v4_o 
* 16 + 16, v3_o * 16:v3_o * 16 + 16])
+                                    
T.writes(conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o, v3_o, 0, 0, 0:16, 
0:16])
+                                    
T.block_attr({"meta_schedule.auto_tensorize": "wmma_sync_16x16x16_f16f16f32", 
"meta_schedule.auto_tensorize_init": "wmma_fill_16x16x16_f32", 
"warp_execution": 1})
                                     with T.init():
                                         for ax2_1, ax3_1 in T.grid(16, 16):
                                             with T.block("conv2d_nhwc_init"):
                                                 v2_i_init, v3_i_init = 
T.axis.remap("SS", [ax2_1, ax3_1])
                                                 T.reads()
-                                                
T.writes(conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o * 16 + v2_i_init, 
v3_o * 16 + v3_i_init])
-                                                
conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o * 16 + v2_i_init, v3_o * 16 + 
v3_i_init] = T.float32(0)
+                                                
T.writes(conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o, v3_o, 0, 0, 
v2_i_init, v3_i_init])
+                                                
conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o, v3_o, 0, 0, v2_i_init, 
v3_i_init] = T.float32(0)
                                     for ax2_1, ax3_1, ax4_1 in T.grid(16, 16, 
16):
                                         with T.block("conv2d_nhwc"):
                                             v2_i, v3_i, v4_i = 
T.axis.remap("SSR", [ax2_1, ax3_1, ax4_1])
-                                            
T.reads(conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o * 16 + v2_i, v3_o * 16 
+ v3_i], PadInput_reindex_shared_wmma_matrix_a[v2_o * 16 + v2_i, v4_o * 16 + 
v4_i], weight_reindex_shared_wmma_matrix_b[v0, v1, v4_o * 16 + v4_i, v3_o * 16 
+ v3_i])
-                                            
T.writes(conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o * 16 + v2_i, v3_o * 
16 + v3_i])
-                                            
T.block_attr({"meta_schedule.tiling_structure":"SSSRRSRS"})
-                                            
conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o * 16 + v2_i, v3_o * 16 + v3_i] 
= conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o * 16 + v2_i, v3_o * 16 + 
v3_i] + T.cast(PadInput_reindex_shared_wmma_matrix_a[v2_o * 16 + v2_i, v4_o * 
16 + v4_i], "float32") * T.cast(weight_reindex_shared_wmma_matrix_b[v0, v1, 
v4_o * 16 + v4_i, v3_o * 16 + v3_i], "float32")
-                    for ax0_0, ax1_0 in T.grid(1, 1):
-                        with 
T.block("conv2d_nhwc_reindex_shared_wmma.accumulator_o"):
-                            v0_o = T.axis.spatial(16, ax2_0_0_ax3_0_0_fused // 
2 * 2 + ax2_0_1_ax3_0_1_fused + ax0_0)
-                            v1_o = T.axis.spatial(4, ax2_0_0_ax3_0_0_fused % 2 
* 2 + ax2_0_2_ax3_0_2_fused + ax1_0)
-                            
T.reads(conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o * 16 : v0_o * 16 + 16, 
v1_o * 16 : v1_o * 16 + 16])
-                            T.writes(conv2d_nhwc_reindex_shared[v0_o * 16 : 
v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16])
-                            
T.block_attr({"meta_schedule.auto_tensorize":"wmma_store_16x16x16_f32_shared"})
-                            for ax0_1, ax1_1 in T.grid(16, 16):
-                                with 
T.block("conv2d_nhwc_reindex_shared_wmma.accumulator"):
-                                    v0_i, v1_i = T.axis.remap("SS", [ax0_1, 
ax1_1])
-                                    
T.reads(conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 
+ v1_i])
-                                    T.writes(conv2d_nhwc_reindex_shared[v0_o * 
16 + v0_i, v1_o * 16 + v1_i])
-                                    conv2d_nhwc_reindex_shared[v0_o * 16 + 
v0_i, v1_o * 16 + v1_i] = conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o * 16 
+ v0_i, v1_o * 16 + v1_i]
-                for ax0_ax1_fused in T.serial(512):
-                    with T.block("conv2d_nhwc_reindex_shared"):
-                        v0 = T.axis.spatial(256, ax2_0_0_ax3_0_0_fused // 2 * 
32 + ax2_0_1_ax3_0_1_fused * 16 + ax0_ax1_fused // 32)
-                        v1 = T.axis.spatial(64, ax2_0_0_ax3_0_0_fused % 2 * 32 
+ ax0_ax1_fused % 32)
-                        T.reads(conv2d_nhwc_reindex_shared[v0, v1])
-                        T.writes(conv2d_nhwc[v0 // 256, v0 // 16, v0 % 16, v1])
-                        T.block_attr({"meta_schedule.cooperative_fetch":2})
-                        conv2d_nhwc[v0 // 256, v0 // 16, v0 % 16, v1] = 
conv2d_nhwc_reindex_shared[v0, v1]
+                                            
T.reads(conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o, v3_o, 0, 0, v2_i, 
v3_i], PadInput_reindex_shared_wmma_matrix_a[v2_o * 16 + v2_i, v4_o * 16 + 
v4_i], weight_reindex_shared_wmma_matrix_b[v0, v1, v4_o * 16 + v4_i, v3_o * 16 
+ v3_i])
+                                            
T.writes(conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o, v3_o, 0, 0, v2_i, 
v3_i])
+                                            
T.block_attr({"meta_schedule.tiling_structure": "SSSRRSRS"})
+                                            
conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o, v3_o, 0, 0, v2_i, v3_i] = 
conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o, v3_o, 0, 0, v2_i, v3_i] + 
T.Cast("float32", PadInput_reindex_shared_wmma_matrix_a[v2_o * 16 + v2_i, v4_o 
* 16 + v4_i]) * T.Cast("float32", weight_reindex_shared_wmma_matrix_b[v0, v1, 
v4_o * 16 + v4_i, v3_o * 16 + v3_i])
+                for ax2 in range(1):
+                    for ax0_ax1_fused in T.thread_binding(2, 
thread="threadIdx.y"):
+                        for ax2_1, ax3 in T.grid(1, 1):
+                            with 
T.block("conv2d_nhwc_reindex_shared_wmma.accumulator_o"):
+                                v0 = T.axis.spatial(16, ax2_0_0_ax3_0_0_fused 
// 2 * 2 + ax2_0_1_ax3_0_1_fused)
+                                v1 = T.axis.spatial(4, ax2_0_0_ax3_0_0_fused % 
2 * 2 + ax0_ax1_fused)
+                                v2, v3 = T.axis.remap("SS", [ax2_1, ax3])
+                                v4_o = T.axis.spatial(1, 0)
+                                v5_o = T.axis.spatial(1, 0)
+                                
T.reads(conv2d_nhwc_reindex_shared_wmma_accumulator[v0, v1, v2, v3, 0:16, 0:16])
+                                T.writes(conv2d_nhwc_reindex_shared[v0, v1, 
v2, v3, 0:16, 0:16])
+                                T.block_attr({"meta_schedule.auto_tensorize": 
"wmma_store_16x16x16_f32_shared"})
+                                for ax4, ax5 in T.grid(16, 16):
+                                    with 
T.block("conv2d_nhwc_reindex_shared_wmma.accumulator"):
+                                        v4_i, v5_i = T.axis.remap("SS", [ax4, 
ax5])
+                                        
T.reads(conv2d_nhwc_reindex_shared_wmma_accumulator[v0, v1, v2, v3, v4_i, v5_i])
+                                        
T.writes(conv2d_nhwc_reindex_shared[v0, v1, v2, v3, v4_i, v5_i])
+                                        conv2d_nhwc_reindex_shared[v0, v1, v2, 
v3, v4_i, v5_i] = conv2d_nhwc_reindex_shared_wmma_accumulator[v0, v1, v2, v3, 
v4_i, v5_i]
+                    for ax0_ax1_ax3_ax4_ax5_fused in range(512):
+                        with T.block("conv2d_nhwc_reindex_shared"):
+                            v0 = T.axis.spatial(16, ax2_0_0_ax3_0_0_fused // 2 
* 2 + ax2_0_1_ax3_0_1_fused)
+                            v1 = T.axis.spatial(4, ax2_0_0_ax3_0_0_fused % 2 * 
2 + ax0_ax1_ax3_ax4_ax5_fused // 256)
+                            v2 = T.axis.spatial(1, ax2)
+                            v3 = T.axis.spatial(1, 0)
+                            v4 = T.axis.spatial(16, ax0_ax1_ax3_ax4_ax5_fused 
% 256 // 16)
+                            v5 = T.axis.spatial(16, ax0_ax1_ax3_ax4_ax5_fused 
% 16)
+                            T.reads(conv2d_nhwc_reindex_shared[v0, v1, v2, v3, 
v4, v5])
+                            T.writes(conv2d_nhwc[(v4 + v0 * 16) // 256, (v4 + 
v0 * 16) // 16, (v4 + v0 * 16) % 16, v5 + v1 * 16])
+                            T.block_attr({"meta_schedule.cooperative_fetch": 
2})
+                            conv2d_nhwc[(v4 + v0 * 16) // 256, (v4 + v0 * 16) 
// 16, (v4 + v0 * 16) % 16, v5 + v1 * 16] = conv2d_nhwc_reindex_shared[v0, v1, 
v2, v3, v4, v5]
     # fmt: on
 
     decision_0 = [
diff --git 
a/tests/python/unittest/test_tir_transform_inject_software_pipeline.py 
b/tests/python/unittest/test_tir_transform_inject_software_pipeline.py
index 1e5fd8843b..b9f35ed553 100644
--- a/tests/python/unittest/test_tir_transform_inject_software_pipeline.py
+++ b/tests/python/unittest/test_tir_transform_inject_software_pipeline.py
@@ -1124,7 +1124,7 @@ def test_simple_compute_async():
                 B = T.alloc_buffer([2, 16, 1], dtype="float32", scope="shared")
                 with T.block():
                     T.reads(A[tx, 0])
-                    T.writes(B[0, tx, 0])
+                    T.writes(B[T.FloorMod(0, 2), tx, 0])
                     with T.attr(0, "async_commit_queue_scope", 0):
                         with T.attr(0, "async_scope", 1):
                             B[T.FloorMod(0, 2), tx, 0] = A[tx, 0] * 
T.float32(2)
@@ -1350,8 +1350,8 @@ def test_three_stage_compute_two_stage_async():
                                     B[i % 2, tx, 0] = A[tx, i] * T.float32(2)
                         with T.block():
                             T.where(i == 1 and i - 1 < 16)
-                            T.reads(B[(i + 1) % 2, tx, 0])
-                            T.writes(C[(i + 1) % 2, tx, 0])
+                            T.reads(B[(i - 1) % 2, tx, 0])
+                            T.writes(C[(i - 1) % 2, tx, 0])
                             with T.attr(0, "async_commit_queue_scope", 1):
                                 with T.attr(0, "async_wait_queue_scope", 0):
                                     with T.attr(0, 
"async_wait_inflight_count", 1):
@@ -1366,14 +1366,14 @@ def test_three_stage_compute_two_stage_async():
                         with T.block():
                             T.where(i + 2 < 16)
                             T.reads(A[tx, i + 2])
-                            T.writes(B[i % 2, tx, 0])
+                            T.writes(B[(i + 2) % 2, tx, 0])
                             with T.attr(0, "async_commit_queue_scope", 0):
                                 with T.attr(0, "async_scope", 1):
                                     B[(i + 2) % 2, tx, 0] = A[tx, i + 2] * 
T.float32(2)
                         with T.block():
                             T.where(i + 2 - 1 < 16)
-                            T.reads(B[(i + 1) % 2, tx, 0])
-                            T.writes(C[(i + 1) % 2, tx, 0])
+                            T.reads(B[(i - 1 + 2) % 2, tx, 0])
+                            T.writes(C[(i - 1 + 2) % 2, tx, 0])
                             with T.attr(0, "async_commit_queue_scope", 1):
                                 with T.attr(0, "async_wait_queue_scope", 0):
                                     with T.attr(0, 
"async_wait_inflight_count", 1):
@@ -1394,8 +1394,8 @@ def test_three_stage_compute_two_stage_async():
                     for i in T.unroll(2):
                         with T.block():
                             T.where(i + 16 - 1 < 16)
-                            T.reads(B[(i + 1) % 2, tx, 0])
-                            T.writes(C[(i + 1) % 2, tx, 0])
+                            T.reads(B[(i - 1 + 16) % 2, tx, 0])
+                            T.writes(C[(i - 1 + 16) % 2, tx, 0])
                             with T.attr(0, "async_commit_queue_scope", 1):
                                 with T.attr(0, "async_wait_queue_scope", 0):
                                     with T.attr(0, 
"async_wait_inflight_count", 0 - i):

Reply via email to