Lunderberg commented on issue #14100:
URL: https://github.com/apache/tvm/issues/14100#issuecomment-1456332508

   @LeiWang1999 I have a possible solution, which works for the test script you 
provided.  Can you apply the diff below, and see if it resolves the issue?
   
   <details>
   <summary>Click to expand diff</summary>
   
   ```
   diff --git a/src/tir/schedule/primitive/blockize_tensorize.cc 
b/src/tir/schedule/primitive/blockize_tensorize.cc
   index 6860927c4..b97a1d872 100644
   --- a/src/tir/schedule/primitive/blockize_tensorize.cc
   +++ b/src/tir/schedule/primitive/blockize_tensorize.cc
   @@ -172,17 +172,19 @@ Array<Array<arith::IterMark>> SubspaceDivide(const 
BlockRealize& realize,
          inner = false;
        }
      }
   -  Array<Array<arith::IterMark>> result =
   -      arith::SubspaceDivide(realize->iter_values, loop_var_domain, 
inner_vars, realize->predicate,
   -                            arith::IterMapLevel::Surjective, analyzer,
   -                            
/*simplify_trivial_iterators=*/!preserve_unit_iters);
   +
   +  auto result = TrivialSubspaceDivision(realize->block->iter_vars,
   +                                        realize->iter_values,  //
   +                                        realize->predicate,    //
   +                                        outer_vars, inner_vars);
   +
      if (!result.empty()) {
        return result;
      }
   -  return TrivialSubspaceDivision(realize->block->iter_vars,
   -                                 realize->iter_values,  //
   -                                 realize->predicate,    //
   -                                 outer_vars, inner_vars);
   +
   +  return arith::SubspaceDivide(realize->iter_values, loop_var_domain, 
inner_vars,
   +                               realize->predicate, 
arith::IterMapLevel::Surjective, analyzer,
   +                               
/*simplify_trivial_iterators=*/!preserve_unit_iters);
    }
    
    /*!
   ```
   </details>
   
   <details>
   <summary>Click to expand TVMScript after `sch.tensorize` with fix</summary>
   
   ```python
   @T.prim_func
   def main(
       A: T.Buffer((16, 14, 14, 16, 16, 16), "float16"),
       W: T.Buffer((3, 3, 16, 32, 16, 16), "float16"),
       Conv: T.Buffer((16, 14, 14, 32, 16, 16), "float16"),
   ):
       T.func_attr({"global_symbol": "main", "tir.noalias": True})
       # with T.block("root"):
       Apad = T.alloc_buffer((16, 16, 16, 16, 16, 16), "float16")
       Apad_shared = T.alloc_buffer((16, 16, 16, 16, 16, 16), "float16", 
scope="shared")
       Apad_shared_wmma_matrix_a = T.alloc_buffer(
           (16, 16, 16, 16, 16, 16), "float16", scope="wmma.matrix_a"
       )
       W_shared = T.alloc_buffer((3, 3, 16, 32, 16, 16), "float16", 
scope="shared")
       W_shared_wmma_matrix_b = T.alloc_buffer(
           (3, 3, 16, 32, 16, 16), "float16", scope="wmma.matrix_b"
       )
       Conv_wmma_accumulator = T.alloc_buffer(
           (16, 14, 14, 32, 16, 16), "float16", scope="wmma.accumulator"
       )
       for n, h, w, i, nn, ii in T.grid(16, 16, 16, 16, 16, 16):
           with T.block("Apad_pad_const"):
               v_n, v_h, v_w, v_i, v_nn, v_ii = T.axis.remap(
                   "SSSSSS", [n, h, w, i, nn, ii]
               )
               T.reads()
               T.writes(Apad[v_n, v_h, v_w, v_i, v_nn, v_ii])
               Apad[v_n, v_h, v_w, v_i, v_nn, v_ii] = T.float16(0)
       for n, h, w, i, nn, ii in T.grid(16, 14, 14, 16, 16, 16):
           with T.block("Apad"):
               v_n, v_h, v_w, v_i, v_nn, v_ii = T.axis.remap(
                   "SSSSSS", [n, h, w, i, nn, ii]
               )
               T.reads(A[v_n, v_h, v_w, v_i, v_nn, v_ii])
               T.writes(Apad[v_n, v_h + 1, v_w + 1, v_i, v_nn, v_ii])
               Apad[v_n, v_h + 1, v_w + 1, v_i, v_nn, v_ii] = A[
                   v_n, v_h, v_w, v_i, v_nn, v_ii
               ]
       for n_0_0 in T.thread_binding(2, thread="blockIdx.x"):
           for o_0_0 in T.thread_binding(4, thread="blockIdx.y"):
               for n_0_1 in T.thread_binding(4, thread="threadIdx.y"):
                   for h, w in T.grid(14, 14):
                       for o_0_1 in T.thread_binding(2, thread="threadIdx.z"):
                           for n_1_init, o_1_init, nn_init, oo_init in T.grid(
                               2, 4, 16, 16
                           ):
                               with T.block("Conv_init"):
                                   v_n = T.axis.spatial(
                                       16, n_0_0 * 8 + n_0_1 * 2 + n_1_init
                                   )
                                   v_h, v_w = T.axis.remap("SS", [h, w])
                                   v_o = T.axis.spatial(
                                       32, o_0_0 * 8 + o_0_1 * 4 + o_1_init
                                   )
                                   v_nn, v_oo = T.axis.remap("SS", [nn_init, 
oo_init])
                                   T.reads()
                                   T.writes(
                                       Conv_wmma_accumulator[
                                           v_n, v_h, v_w, v_o, v_nn, v_oo
                                       ]
                                   )
                                   Conv_wmma_accumulator[
                                       v_n, v_h, v_w, v_o, v_nn, v_oo
                                   ] = T.float16(0)
                           for ic_0, kh in T.grid(8, 3):
                               for ax0_1_0 in T.thread_binding(2, 
thread="threadIdx.z"):
                                   for ax0_0 in T.thread_binding(4, 
thread="threadIdx.y"):
                                       for ax3_ax4_fused_1 in T.thread_binding(
                                           32, thread="threadIdx.x"
                                       ):
                                           for (
                                               ax1,
                                               ax2,
                                               ax3_ax4_fused_0,
                                               ax0_1_1,
                                           ) in T.grid(3, 2, 8, 1):
                                               with T.block("Apad_shared"):
                                                   v0 = T.axis.spatial(
                                                       16,
                                                       ax0_1_1
                                                       + n_0_0 * 8
                                                       + ax0_0 * 2
                                                       + ax0_1_0,
                                                   )
                                                   v1 = T.axis.spatial(16, h + 
kh)
                                                   v2 = T.axis.spatial(16, w + 
ax1)
                                                   v3 = T.axis.spatial(16, ic_0 
* 2 + ax2)
                                                   v4 = T.axis.spatial(
                                                       16,
                                                       (
                                                           ax3_ax4_fused_0 * 32
                                                           + ax3_ax4_fused_1
                                                       )
                                                       // 16,
                                                   )
                                                   v5 = T.axis.spatial(
                                                       16,
                                                       (
                                                           ax3_ax4_fused_0 * 32
                                                           + ax3_ax4_fused_1
                                                       )
                                                       % 16,
                                                   )
                                                   T.reads(Apad[v0, v1, v2, v3, 
v4, v5])
                                                   T.writes(
                                                       Apad_shared[v0, v1, v2, 
v3, v4, v5]
                                                   )
                                                   Apad_shared[
                                                       v0, v1, v2, v3, v4, v5
                                                   ] = Apad[v0, v1, v2, v3, v4, 
v5]
                               for ax0_ax1_ax2_fused_0 in T.thread_binding(
                                   4, thread="threadIdx.y"
                               ):
                                   for ax0_ax1_ax2_fused_1_0 in 
T.thread_binding(
                                       2, thread="threadIdx.z"
                                   ):
                                       for (
                                           ax0_ax1_ax2_fused_1_1,
                                           ax3_ax4_fused_0,
                                       ) in T.grid(6, 8):
                                           for ax3_ax4_fused_1 in 
T.thread_binding(
                                               32, thread="threadIdx.x"
                                           ):
                                               with T.block("W_shared"):
                                                   v0 = T.axis.spatial(3, kh)
                                                   v1 = T.axis.spatial(
                                                       3,
                                                       (
                                                           ax0_ax1_ax2_fused_0 
* 12
                                                           + 
ax0_ax1_ax2_fused_1_0 * 6
                                                           + 
ax0_ax1_ax2_fused_1_1
                                                       )
                                                       // 16,
                                                   )
                                                   v2 = T.axis.spatial(
                                                       16,
                                                       ic_0 * 2
                                                       + (
                                                           ax0_ax1_ax2_fused_0 
* 12
                                                           + 
ax0_ax1_ax2_fused_1_0 * 6
                                                           + 
ax0_ax1_ax2_fused_1_1
                                                       )
                                                       % 16
                                                       // 8,
                                                   )
                                                   v3 = T.axis.spatial(
                                                       32,
                                                       o_0_0 * 8
                                                       + (
                                                           ax0_ax1_ax2_fused_0 
* 12
                                                           + 
ax0_ax1_ax2_fused_1_0 * 6
                                                           + 
ax0_ax1_ax2_fused_1_1
                                                       )
                                                       % 8,
                                                   )
                                                   v4 = T.axis.spatial(
                                                       16,
                                                       (
                                                           ax3_ax4_fused_0 * 32
                                                           + ax3_ax4_fused_1
                                                       )
                                                       // 16,
                                                   )
                                                   v5 = T.axis.spatial(
                                                       16,
                                                       (
                                                           ax3_ax4_fused_0 * 32
                                                           + ax3_ax4_fused_1
                                                       )
                                                       % 16,
                                                   )
                                                   T.reads(W[v0, v1, v2, v3, 
v4, v5])
                                                   T.writes(
                                                       W_shared[v0, v1, v2, v3, 
v4, v5]
                                                   )
                                                   W_shared[v0, v1, v2, v3, v4, 
v5] = W[
                                                       v0, v1, v2, v3, v4, v5
                                                   ]
                               for ic_1, kw in T.grid(2, 3):
                                   for ax0 in range(2):
                                       with 
T.block("Apad_shared_wmma.matrix_a_o"):
                                           v0 = T.axis.spatial(
                                               16, n_0_0 * 8 + n_0_1 * 2 + ax0
                                           )
                                           v1 = T.axis.spatial(16, h + kh)
                                           v2 = T.axis.spatial(16, w + kw)
                                           v3 = T.axis.spatial(16, ic_0 * 2 + 
ic_1)
                                           v4_o = T.axis.spatial(1, 0)
                                           v5_o = T.axis.spatial(1, 0)
                                           T.reads(Apad_shared[v0, v1, v2, v3, 
0:16, 0:16])
                                           T.writes(
                                               Apad_shared_wmma_matrix_a[
                                                   v0, v1, v2, v3, 0:16, 0:16
                                               ]
                                           )
                                           A_s0 = T.int32()
                                           A_s1 = T.int32()
                                           A_1 = T.match_buffer(
                                               Apad_shared[v0, v1, v2, v3, 
0:16, 0:16],
                                               (16, 16),
                                               "float16",
                                               strides=(A_s0, A_s1),
                                               scope="shared",
                                               offset_factor=16,
                                           )
                                           C_s0 = T.int32()
                                           C_s1 = T.int32()
                                           C = T.match_buffer(
                                               Apad_shared_wmma_matrix_a[
                                                   v0, v1, v2, v3, 0:16, 0:16
                                               ],
                                               (16, 16),
                                               "float16",
                                               strides=(C_s0, C_s1),
                                               scope="wmma.matrix_a",
                                               offset_factor=16,
                                           )
                                           T.tvm_load_matrix_sync(
                                               C.data,
                                               16,
                                               16,
                                               16,
                                               C.elem_offset // C_s0 // 16 * 
(C_s0 // 16)
                                               + C.elem_offset % C_s0 // 16,
                                               T.tvm_access_ptr(
                                                   T.type_annotation("float16"),
                                                   A_1.data,
                                                   A_1.elem_offset,
                                                   A_s0 * 16,
                                                   1,
                                               ),
                                               A_s0,
                                               "row_major",
                                           )
                                   for ax0, ax1, ax2 in T.grid(4, 16, 16):
                                       with T.block("W_shared_wmma.matrix_b"):
                                           v0, v1 = T.axis.remap("SS", [kh, kw])
                                           v2 = T.axis.spatial(16, ic_0 * 2 + 
ic_1)
                                           v3 = T.axis.spatial(
                                               32, o_0_0 * 8 + o_0_1 * 4 + ax0
                                           )
                                           v4, v5 = T.axis.remap("SS", [ax1, 
ax2])
                                           T.reads(W_shared[v0, v1, v2, v3, v4, 
v5])
                                           T.writes(
                                               W_shared_wmma_matrix_b[
                                                   v0, v1, v2, v3, v4, v5
                                               ]
                                           )
                                           W_shared_wmma_matrix_b[
                                               v0, v1, v2, v3, v4, v5
                                           ] = W_shared[v0, v1, v2, v3, v4, v5]
                                   for n_1, o_1, nn, oo, ii in T.grid(2, 4, 16, 
16, 16):
                                       with T.block("Conv_update"):
                                           v_n = T.axis.spatial(
                                               16, n_0_0 * 8 + n_0_1 * 2 + n_1
                                           )
                                           v_h, v_w = T.axis.remap("SS", [h, w])
                                           v_o = T.axis.spatial(
                                               32, o_0_0 * 8 + o_0_1 * 4 + o_1
                                           )
                                            v_nn, v_oo = 
T.axis.remap("SS", [nn, oo])
                                           v_ic = T.axis.reduce(16, ic_0 * 2 + 
ic_1)
                                           v_kh, v_kw, v_ii = T.axis.remap(
                                               "RRR", [kh, kw, ii]
                                           )
                                           T.reads(
                                               Conv_wmma_accumulator[
                                                   v_n, v_h, v_w, v_o, v_nn, 
v_oo
                                               ],
                                               Apad_shared_wmma_matrix_a[
                                                   v_n,
                                                   v_h + v_kh,
                                                   v_w + v_kw,
                                                   v_ic,
                                                   v_nn,
                                                   v_ii,
                                               ],
                                               W_shared_wmma_matrix_b[
                                                   v_kh, v_kw, v_ic, v_o, v_ii, 
v_oo
                                               ],
                                           )
                                           T.writes(
                                               Conv_wmma_accumulator[
                                                   v_n, v_h, v_w, v_o, v_nn, 
v_oo
                                               ]
                                           )
                                           Conv_wmma_accumulator[
                                               v_n, v_h, v_w, v_o, v_nn, v_oo
                                           ] = (
                                               Conv_wmma_accumulator[
                                                   v_n, v_h, v_w, v_o, v_nn, 
v_oo
                                               ]
                                               + Apad_shared_wmma_matrix_a[
                                                   v_n,
                                                   v_h + v_kh,
                                                   v_w + v_kw,
                                                   v_ic,
                                                   v_nn,
                                                   v_ii,
                                               ]
                                               * W_shared_wmma_matrix_b[
                                                   v_kh, v_kw, v_ic, v_o, v_ii, 
v_oo
                                               ]
                                           )
                           for ax0, ax1, ax2, ax3 in T.grid(2, 4, 16, 16):
                               with T.block("Conv_wmma.accumulator"):
                                   v0 = T.axis.spatial(16, n_0_0 * 8 + n_0_1 * 
2 + ax0)
                                   v1, v2 = T.axis.remap("SS", [h, w])
                                   v3 = T.axis.spatial(32, o_0_0 * 8 + o_0_1 * 
4 + ax1)
                                   v4, v5 = T.axis.remap("SS", [ax2, ax3])
                                   T.reads(Conv_wmma_accumulator[v0, v1, v2, 
v3, v4, v5])
                                   T.writes(Conv[v0, v1, v2, v3, v4, v5])
                                   Conv[v0, v1, v2, v3, v4, v5] = 
Conv_wmma_accumulator[
                                       v0, v1, v2, v3, v4, v5
                                   ]
   ```
   </details>
   
   It looks like there were a couple of moving pieces that resulted in the 
error.  First, the `compute_at` did the right thing, and rewrote the iter vars 
`ax1` and `ax2` in terms of `h`, `w`, `kh`, and `kw`.  The `T.axis.spatial(16, 
h + kh)` was correct, as this produces a spatial axis for the 
`Apad_shared_wmma_matrix_a` buffer, corresponding to `ax1`.  However, this 
isn't injective with respect to `h` and `kh`, and so `sch.tensorize` fails 
because the call to `arith::SubspaceDivide` only handles injective transforms.
   
   The fix swaps the order in which `sch.tensorize` attempts to use 
`arith::SubspaceDivide` and a simpler `TrivialSubspaceDivision`, so that the 
`TrivialSubspaceDivision` check is applied first.  This check only inspects the 
loop iterators used by each expression, and doesn't require injective axes like 
`arith::SubspaceDivide`, and is able to recognize the tensorizable portion in 
the test script.


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to