wrongtest-intellif commented on code in PR #15961:
URL: https://github.com/apache/tvm/pull/15961#discussion_r1368057844


##########
src/tir/transforms/inject_ptx_async_copy.cc:
##########
@@ -113,9 +116,11 @@ class PTXAsyncCopyInjector : public StmtMutator {
             return PrimExpr();
           }();
           if (src_offset.defined() && dst_offset.defined()) {

Review Comment:
   Could you add a unit test for this case? Thank you!
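   For reference, a minimal sketch of what such a test could look like, following the pattern of the existing cases in `tests/python/unittest/test_tir_transform_inject_ptx_async_copy.py` (the kernel below is a hypothetical placeholder; the real test would need a copy that exercises the new predicated path):
   ```python
   import tvm
   import tvm.testing
   from tvm.script import tir as T


   @T.prim_func
   def global_to_shared_copy(A: T.Buffer((32, 128), "float32"), B: T.Buffer((32, 128), "float32")):
       T.func_attr({"global_symbol": "main", "tir.noalias": True})
       bx = T.env_thread("blockIdx.x")
       tx = T.env_thread("threadIdx.x")
       T.launch_thread(bx, 1)
       T.launch_thread(tx, 32)
       with T.block():
           A_shared = T.alloc_buffer((32, 128), "float32", scope="shared")
           T.reads(A[0:32, 0:128])
           T.writes(B[0:32, 0:128])
           # Mark the global->shared copy as asynchronous so the pass rewrites it.
           T.attr("default", "async_scope", 1)
           for i in T.serial(128):
               A_shared[tx, i] = A[tx, i]
           T.evaluate(T.ptx_commit_group(dtype=""))
           T.evaluate(T.ptx_wait_group(0, dtype=""))
           for i in T.serial(128):
               B[tx, i] = A_shared[tx, i]


   @tvm.testing.requires_cuda
   def test_inject_ptx_async_copy():
       mod = tvm.build(global_to_shared_copy, target="cuda")
       # The copy loop should have been lowered to a cp.async instruction.
       assert "cp.async" in mod.imported_modules[0].get_source()
   ```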



##########
src/tir/transforms/ir_utils.cc:
##########
@@ -417,7 +417,8 @@ Array<PrimExpr> GetBufferAllocationShape(const Buffer& buffer) {
   if (buffer->strides.size()) {
     ICHECK_EQ(buffer->shape.size(), buffer->strides.size());
     for (size_t i = buffer->strides.size() - 1; i > 0; --i) {
-      ICHECK(is_zero(floormod(buffer->strides[i - 1], buffer->strides[i])));

Review Comment:
   Could you add a unit test for the change? E.g., check the TIR before and after `LowerOpaqueBlock`. The whole script is below; of course, we can keep only the part of it that includes `lv55_reindex_pad_shared_dyn = T.alloc_buffer((T.int64(1), T.min((n + T.int64(63)) // T.int64(64) * T.int64(64), T.int64(96)), T.int64(64)), "float16", strides=(T.int64(72) * T.min((n + T.int64(63)) // T.int64(64) * T.int64(64), T.int64(96)), T.int64(72), T.int64(1)), scope="shared.dyn")`.
   ```python
   @T.prim_func(private=True)
   def fused_NT_matmul4_add1(p_lv55: T.handle, lv11: T.Buffer((T.int64(2560), T.int64(10240)), "float16"), lv12: T.Buffer((T.int64(2560),), "float16"), p_output0: T.handle):
       T.func_attr({"tir.noalias": T.bool(True)})
       n = T.int64()
       lv55 = T.match_buffer(p_lv55, (T.int64(1), n, T.int64(10240)), "float16")
       var_T_add_intermediate = T.match_buffer(p_output0, (T.int64(1), n, T.int64(2560)), "float16")
       # with T.block("root"):
       for blockIdx_z in T.thread_binding(T.int64(1), thread="blockIdx.z"):
           for blockIdx_x in T.thread_binding(((n + T.int64(63)) // T.int64(64) * T.int64(4) + T.int64(7)) // T.int64(8), thread="blockIdx.x"):
               for blockIdx_y in T.thread_binding(T.int64(20), thread="blockIdx.y"):
                   for threadIdx_x in T.thread_binding(T.int64(32), thread="threadIdx.x"):
                       for threadIdx_y in T.thread_binding(T.int64(4), thread="threadIdx.y"):
                           with T.block(""):
                               T.reads(lv55[T.int64(0), blockIdx_x * T.int64(128):blockIdx_x * T.int64(128) + T.int64(128), T.int64(0):T.int64(10240)], lv11[blockIdx_y * T.int64(128):blockIdx_y * T.int64(128) + T.int64(128), T.int64(0):T.int64(10240)], lv12[blockIdx_y * T.int64(128) + threadIdx_y // T.int64(2) * T.int64(64):blockIdx_y * T.int64(128) + threadIdx_y // T.int64(2) * T.int64(64) + T.int64(64)])
                               T.writes(var_T_add_intermediate[T.int64(0), blockIdx_x * T.int64(128) + threadIdx_y % T.int64(2) * T.int64(64):blockIdx_x * T.int64(128) + threadIdx_y % T.int64(2) * T.int64(64) + T.int64(64), blockIdx_y * T.int64(128) + threadIdx_y // T.int64(2) * T.int64(64):blockIdx_y * T.int64(128) + threadIdx_y // T.int64(2) * T.int64(64) + T.int64(64)])
                               var_NT_matmul_intermediate_reindex_pad_shared_dyn = T.alloc_buffer((T.int64(1), T.int64(128), T.int64(128)), "float16", scope="shared.dyn")
                               var_NT_matmul_intermediate_reindex_pad_shared_dyn_wmma_accumulator = T.alloc_buffer((T.int64(1), T.int64(64), T.int64(64)), "float16", scope="wmma.accumulator")
                               for ax1_0_3, ax2_0_3 in T.grid(T.int64(2), T.int64(2)):
                                   for ax1_0_4_init, ax2_0_4_init in T.grid(T.int64(2), T.int64(2)):
                                       with T.block("NT_matmul_o_init"):
                                           T.where(blockIdx_x * T.int64(2) + threadIdx_y % T.int64(2) - (n + T.int64(63)) // T.int64(64) < T.int64(0))
                                           T.reads()
                                           T.writes(var_NT_matmul_intermediate_reindex_pad_shared_dyn_wmma_accumulator[T.int64(0), ax1_0_3 * T.int64(32) + ax1_0_4_init * T.int64(16):ax1_0_3 * T.int64(32) + ax1_0_4_init * T.int64(16) + T.int64(16), ax2_0_3 * T.int64(32) + ax2_0_4_init * T.int64(16):ax2_0_3 * T.int64(32) + ax2_0_4_init * T.int64(16) + T.int64(16)])
                                           with T.block("NT_matmul_init_o"):
                                               T.reads()
                                               T.writes(var_NT_matmul_intermediate_reindex_pad_shared_dyn_wmma_accumulator[T.int64(0), ax1_0_3 * T.int64(32) + ax1_0_4_init * T.int64(16):ax1_0_3 * T.int64(32) + ax1_0_4_init * T.int64(16) + T.int64(16), ax2_0_3 * T.int64(32) + ax2_0_4_init * T.int64(16):ax2_0_3 * T.int64(32) + ax2_0_4_init * T.int64(16) + T.int64(16)])
                                               T.tvm_fill_fragment(var_NT_matmul_intermediate_reindex_pad_shared_dyn_wmma_accumulator.data, 16, 16, 16, ax1_0_3 * T.int64(8) + ax1_0_4_init * T.int64(4) + ax2_0_3 * T.int64(2) + ax2_0_4_init, T.float32(0))
                                   for ax3_0_0 in range(T.int64(160)):
                                       with T.block(""):
                                           T.reads(lv55[T.int64(0), blockIdx_x * T.int64(128) + ax1_0_3 * T.int64(32):blockIdx_x * T.int64(128) + ax1_0_3 * T.int64(32) + T.int64(96), ax3_0_0 * T.int64(64):ax3_0_0 * T.int64(64) + T.int64(64)], lv11[blockIdx_y * T.int64(128) + ax2_0_3 * T.int64(32):blockIdx_y * T.int64(128) + ax2_0_3 * T.int64(32) + T.int64(96), ax3_0_0 * T.int64(64):ax3_0_0 * T.int64(64) + T.int64(64)], var_NT_matmul_intermediate_reindex_pad_shared_dyn_wmma_accumulator[T.int64(0), ax1_0_3 * T.int64(32):ax1_0_3 * T.int64(32) + T.int64(32), ax2_0_3 * T.int64(32):ax2_0_3 * T.int64(32) + T.int64(32)])
                                           T.writes(var_NT_matmul_intermediate_reindex_pad_shared_dyn_wmma_accumulator[T.int64(0), ax1_0_3 * T.int64(32):ax1_0_3 * T.int64(32) + T.int64(32), ax2_0_3 * T.int64(32):ax2_0_3 * T.int64(32) + T.int64(32)])
                                           lv55_reindex_pad_shared_dyn = T.alloc_buffer((T.int64(1), T.min((n + T.int64(63)) // T.int64(64) * T.int64(64), T.int64(96)), T.int64(64)), "float16", strides=(T.int64(72) * T.min((n + T.int64(63)) // T.int64(64) * T.int64(64), T.int64(96)), T.int64(72), T.int64(1)), scope="shared.dyn")
                                           lv11_reindex_shared_dyn = T.alloc_buffer((T.int64(1), T.int64(96), T.int64(64)), "float16", strides=(T.int64(6912), T.int64(72), T.int64(1)), scope="shared.dyn")
                                           for ax0_ax1_fused_0 in range(T.int64(12)):
                                               for ax0_ax1_fused_3 in T.vectorized(T.int64(4)):
                                                   with T.block("lv55_reindex_pad_shared.dyn"):
                                                       T.where(blockIdx_x * T.int64(16) + ax1_0_3 * T.int64(4) + ax0_ax1_fused_0 - (n + T.int64(63)) // T.int64(64) * T.int64(8) < T.int64(0))
                                                       T.reads(lv55[T.int64(0), blockIdx_x * T.int64(128) + ax1_0_3 * T.int64(32) + ax0_ax1_fused_0 * T.int64(8) + threadIdx_y * T.int64(2) + threadIdx_x // T.int64(16), ax3_0_0 * T.int64(64) + threadIdx_x % T.int64(16) * T.int64(4) + ax0_ax1_fused_3])
                                                       T.writes(lv55_reindex_pad_shared_dyn[T.int64(0), ax0_ax1_fused_0 * T.int64(8) + threadIdx_y * T.int64(2) + threadIdx_x // T.int64(16), threadIdx_x % T.int64(16) * T.int64(4) + ax0_ax1_fused_3])
                                                       T.block_attr({"buffer_dim_align": [[0, 1, 16, 8]]})
                                                       lv55_reindex_pad_shared_dyn[T.int64(0), ax0_ax1_fused_0 * T.int64(8) + threadIdx_y * T.int64(2) + threadIdx_x // T.int64(16), threadIdx_x % T.int64(16) * T.int64(4) + ax0_ax1_fused_3] = T.if_then_else(blockIdx_x * T.int64(128) + ax1_0_3 * T.int64(32) + ax0_ax1_fused_0 * T.int64(8) + threadIdx_y * T.int64(2) + threadIdx_x // T.int64(16) < n, lv55[T.int64(0), blockIdx_x * T.int64(128) + ax1_0_3 * T.int64(32) + ax0_ax1_fused_0 * T.int64(8) + threadIdx_y * T.int64(2) + threadIdx_x // T.int64(16), ax3_0_0 * T.int64(64) + threadIdx_x % T.int64(16) * T.int64(4) + ax0_ax1_fused_3], T.float16(0))
                                           for ax0_ax1_fused_0 in range(T.int64(12)):
                                               for ax0_ax1_fused_3 in T.vectorized(T.int64(4)):
                                                   with T.block("lv11_reindex_shared.dyn"):
                                                       T.reads(lv11[blockIdx_y * T.int64(128) + ax2_0_3 * T.int64(32) + ax0_ax1_fused_0 * T.int64(8) + threadIdx_y * T.int64(2) + threadIdx_x // T.int64(16), ax3_0_0 * T.int64(64) + threadIdx_x % T.int64(16) * T.int64(4) + ax0_ax1_fused_3])
                                                       T.writes(lv11_reindex_shared_dyn[T.int64(0), ax0_ax1_fused_0 * T.int64(8) + threadIdx_y * T.int64(2) + threadIdx_x // T.int64(16), threadIdx_x % T.int64(16) * T.int64(4) + ax0_ax1_fused_3])
                                                       T.block_attr({"buffer_dim_align": [[0, 1, 16, 8]]})
                                                       lv11_reindex_shared_dyn[T.int64(0), ax0_ax1_fused_0 * T.int64(8) + threadIdx_y * T.int64(2) + threadIdx_x // T.int64(16), threadIdx_x % T.int64(16) * T.int64(4) + ax0_ax1_fused_3] = lv11[blockIdx_y * T.int64(128) + ax2_0_3 * T.int64(32) + ax0_ax1_fused_0 * T.int64(8) + threadIdx_y * T.int64(2) + threadIdx_x // T.int64(16), ax3_0_0 * T.int64(64) + threadIdx_x % T.int64(16) * T.int64(4) + ax0_ax1_fused_3]
                                           for ax3_0_1 in range(T.int64(4)):
                                               with T.block(""):
                                                   T.reads(lv55_reindex_pad_shared_dyn[T.int64(0), threadIdx_y % T.int64(2) * T.int64(64):threadIdx_y % T.int64(2) * T.int64(64) + T.int64(32), ax3_0_1 * T.int64(16):ax3_0_1 * T.int64(16) + T.int64(16)], lv11_reindex_shared_dyn[T.int64(0), threadIdx_y // T.int64(2) * T.int64(64):threadIdx_y // T.int64(2) * T.int64(64) + T.int64(32), ax3_0_1 * T.int64(16):ax3_0_1 * T.int64(16) + T.int64(16)], var_NT_matmul_intermediate_reindex_pad_shared_dyn_wmma_accumulator[T.int64(0), ax1_0_3 * T.int64(32):ax1_0_3 * T.int64(32) + T.int64(32), ax2_0_3 * T.int64(32):ax2_0_3 * T.int64(32) + T.int64(32)])
                                                   T.writes(var_NT_matmul_intermediate_reindex_pad_shared_dyn_wmma_accumulator[T.int64(0), ax1_0_3 * T.int64(32):ax1_0_3 * T.int64(32) + T.int64(32), ax2_0_3 * T.int64(32):ax2_0_3 * T.int64(32) + T.int64(32)])
                                                   lv55_reindex_pad_shared_dyn_wmma_matrix_a = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(16)), "float16", scope="wmma.matrix_a")
                                                   lv11_reindex_shared_dyn_wmma_matrix_b = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(16)), "float16", scope="wmma.matrix_b")
                                                   for ax0_0 in T.unroll(T.int64(2)):
                                                       for ax1_0 in T.unroll(T.int64(1)):
                                                           for ax0_1, ax1_1 in T.grid(T.int64(16), T.int64(16)):
                                                               with T.block("lv55_reindex_pad_shared.dyn_wmma.matrix_a"):
                                                                   T.where(blockIdx_x * T.int64(2) + threadIdx_y % T.int64(2) - (n + T.int64(63)) // T.int64(64) < T.int64(0))
                                                                   T.reads(lv55_reindex_pad_shared_dyn[T.int64(0), threadIdx_y % T.int64(2) * T.int64(64) + ax0_0 * T.int64(16) + ax0_1, ax3_0_1 * T.int64(16) + ax1_1])
                                                                   T.writes(lv55_reindex_pad_shared_dyn_wmma_matrix_a[T.int64(0), ax0_0 * T.int64(16) + ax0_1, ax1_1])
                                                                   lv55_reindex_pad_shared_dyn_wmma_matrix_a[T.int64(0), ax0_0 * T.int64(16) + ax0_1, ax1_1] = lv55_reindex_pad_shared_dyn[T.int64(0), threadIdx_y % T.int64(2) * T.int64(64) + ax0_0 * T.int64(16) + ax0_1, ax3_0_1 * T.int64(16) + ax1_1]
                                                   for ax0, ax1 in T.grid(T.int64(32), T.int64(16)):
                                                       with T.block("lv11_reindex_shared.dyn_wmma.matrix_b"):
                                                           T.reads(lv11_reindex_shared_dyn[T.int64(0), threadIdx_y // T.int64(2) * T.int64(64) + ax0, ax3_0_1 * T.int64(16) + ax1])
                                                           T.writes(lv11_reindex_shared_dyn_wmma_matrix_b[T.int64(0), ax0, ax1])
                                                           lv11_reindex_shared_dyn_wmma_matrix_b[T.int64(0), ax0, ax1] = lv11_reindex_shared_dyn[T.int64(0), threadIdx_y // T.int64(2) * T.int64(64) + ax0, ax3_0_1 * T.int64(16) + ax1]
                                                   for ax1_0_4, ax2_0_4 in T.grid(T.int64(2), T.int64(2)):
                                                       with T.block("NT_matmul_o_update"):
                                                           T.where(blockIdx_x * T.int64(2) + threadIdx_y % T.int64(2) - (n + T.int64(63)) // T.int64(64) < T.int64(0))
                                                           T.reads(var_NT_matmul_intermediate_reindex_pad_shared_dyn_wmma_accumulator[T.int64(0), ax1_0_3 * T.int64(32) + ax1_0_4 * T.int64(16):ax1_0_3 * T.int64(32) + ax1_0_4 * T.int64(16) + T.int64(16), ax2_0_3 * T.int64(32) + ax2_0_4 * T.int64(16):ax2_0_3 * T.int64(32) + ax2_0_4 * T.int64(16) + T.int64(16)], lv55_reindex_pad_shared_dyn_wmma_matrix_a[T.int64(0), ax1_0_4 * T.int64(16):ax1_0_4 * T.int64(16) + T.int64(16), T.int64(0):T.int64(16)], lv11_reindex_shared_dyn_wmma_matrix_b[T.int64(0), ax2_0_4 * T.int64(16):ax2_0_4 * T.int64(16) + T.int64(16), T.int64(0):T.int64(16)])
                                                           T.writes(var_NT_matmul_intermediate_reindex_pad_shared_dyn_wmma_accumulator[T.int64(0), ax1_0_3 * T.int64(32) + ax1_0_4 * T.int64(16):ax1_0_3 * T.int64(32) + ax1_0_4 * T.int64(16) + T.int64(16), ax2_0_3 * T.int64(32) + ax2_0_4 * T.int64(16):ax2_0_3 * T.int64(32) + ax2_0_4 * T.int64(16) + T.int64(16)])
                                                           for ax1_1, ax2_1, ax3_1 in T.grid(T.int64(16), T.int64(16), T.int64(16)):
                                                               with T.block("NT_matmul"):
                                                                   T.reads(var_NT_matmul_intermediate_reindex_pad_shared_dyn_wmma_accumulator[T.int64(0), ax1_0_3 * T.int64(32) + ax1_0_4 * T.int64(16) + ax1_1, ax2_0_3 * T.int64(32) + ax2_0_4 * T.int64(16) + ax2_1], lv55_reindex_pad_shared_dyn_wmma_matrix_a[T.int64(0), ax1_0_4 * T.int64(16) + ax1_1, ax3_1], lv11_reindex_shared_dyn_wmma_matrix_b[T.int64(0), ax2_0_4 * T.int64(16) + ax2_1, ax3_1])
                                                                   T.writes(var_NT_matmul_intermediate_reindex_pad_shared_dyn_wmma_accumulator[T.int64(0), ax1_0_3 * T.int64(32) + ax1_0_4 * T.int64(16) + ax1_1, ax2_0_3 * T.int64(32) + ax2_0_4 * T.int64(16) + ax2_1])
                                                                   var_NT_matmul_intermediate_reindex_pad_shared_dyn_wmma_accumulator[T.int64(0), ax1_0_3 * T.int64(32) + ax1_0_4 * T.int64(16) + ax1_1, ax2_0_3 * T.int64(32) + ax2_0_4 * T.int64(16) + ax2_1] = var_NT_matmul_intermediate_reindex_pad_shared_dyn_wmma_accumulator[T.int64(0), ax1_0_3 * T.int64(32) + ax1_0_4 * T.int64(16) + ax1_1, ax2_0_3 * T.int64(32) + ax2_0_4 * T.int64(16) + ax2_1] + lv55_reindex_pad_shared_dyn_wmma_matrix_a[T.int64(0), ax1_0_4 * T.int64(16) + ax1_1, ax3_1] * lv11_reindex_shared_dyn_wmma_matrix_b[T.int64(0), ax2_0_4 * T.int64(16) + ax2_1, ax3_1]
                               for ax0_0, ax1_0, ax0_1, ax1_1 in T.grid(T.int64(4), T.int64(4), T.int64(16), T.int64(16)):
                                   with T.block("var_NT_matmul_intermediate_reindex_pad_shared.dyn_wmma.accumulator"):
                                       T.where(blockIdx_x * T.int64(2) + threadIdx_y % T.int64(2) - (n + T.int64(63)) // T.int64(64) < T.int64(0))
                                       T.reads(var_NT_matmul_intermediate_reindex_pad_shared_dyn_wmma_accumulator[T.int64(0), ax0_0 * T.int64(16) + ax0_1, ax1_0 * T.int64(16) + ax1_1])
                                       T.writes(var_NT_matmul_intermediate_reindex_pad_shared_dyn[T.int64(0), threadIdx_y % T.int64(2) * T.int64(64) + ax0_0 * T.int64(16) + ax0_1, threadIdx_y // T.int64(2) * T.int64(64) + ax1_0 * T.int64(16) + ax1_1])
                                       var_NT_matmul_intermediate_reindex_pad_shared_dyn[T.int64(0), threadIdx_y % T.int64(2) * T.int64(64) + ax0_0 * T.int64(16) + ax0_1, threadIdx_y // T.int64(2) * T.int64(64) + ax1_0 * T.int64(16) + ax1_1] = var_NT_matmul_intermediate_reindex_pad_shared_dyn_wmma_accumulator[T.int64(0), ax0_0 * T.int64(16) + ax0_1, ax1_0 * T.int64(16) + ax1_1]
                               for ax0_ax1_fused_0 in range(T.int64(32)):
                                   for ax0_ax1_fused_2 in T.vectorized(T.int64(4)):
                                       with T.block("var_NT_matmul_intermediate_reindex_pad_shared.dyn"):
                                           T.where(blockIdx_x * T.int64(2) + threadIdx_y % T.int64(2) - (n + T.int64(63)) // T.int64(64) < T.int64(0))
                                           T.reads(var_NT_matmul_intermediate_reindex_pad_shared_dyn[T.int64(0), threadIdx_y % T.int64(2) * T.int64(64) + ax0_ax1_fused_0 * T.int64(2) + threadIdx_x // T.int64(16), threadIdx_y // T.int64(2) * T.int64(64) + threadIdx_x % T.int64(16) * T.int64(4) + ax0_ax1_fused_2], lv12[blockIdx_y * T.int64(128) + threadIdx_y // T.int64(2) * T.int64(64) + threadIdx_x % T.int64(16) * T.int64(4) + ax0_ax1_fused_2])
                                           T.writes(var_T_add_intermediate[T.int64(0), blockIdx_x * T.int64(128) + threadIdx_y % T.int64(2) * T.int64(64) + ax0_ax1_fused_0 * T.int64(2) + threadIdx_x // T.int64(16), blockIdx_y * T.int64(128) + threadIdx_y // T.int64(2) * T.int64(64) + threadIdx_x % T.int64(16) * T.int64(4) + ax0_ax1_fused_2])
                                           T.block_attr({"buffer_dim_align": [[0, 1, 16, 4]]})
                                           if blockIdx_x * T.int64(128) + threadIdx_y % T.int64(2) * T.int64(64) + ax0_ax1_fused_0 * T.int64(2) + threadIdx_x // T.int64(16) < n:
                                               var_T_add_intermediate[T.int64(0), blockIdx_x * T.int64(128) + threadIdx_y % T.int64(2) * T.int64(64) + ax0_ax1_fused_0 * T.int64(2) + threadIdx_x // T.int64(16), blockIdx_y * T.int64(128) + threadIdx_y // T.int64(2) * T.int64(64) + threadIdx_x % T.int64(16) * T.int64(4) + ax0_ax1_fused_2] = var_NT_matmul_intermediate_reindex_pad_shared_dyn[T.int64(0), threadIdx_y % T.int64(2) * T.int64(64) + ax0_ax1_fused_0 * T.int64(2) + threadIdx_x // T.int64(16), threadIdx_y // T.int64(2) * T.int64(64) + threadIdx_x % T.int64(16) * T.int64(4) + ax0_ax1_fused_2] + lv12[blockIdx_y * T.int64(128) + threadIdx_y // T.int64(2) * T.int64(64) + threadIdx_x % T.int64(16) * T.int64(4) + ax0_ax1_fused_2]
   ```
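   For the before/after check, something along these lines could work (a sketch only; `Before` and `Expected` are placeholders for an IRModule wrapping the script above and its hand-written lowering):
   ```python
   import tvm


   def test_lower_opaque_block_symbolic_strides():
       # `Before` wraps the script above (or just the part with the strided,
       # symbolically shaped shared.dyn allocation); `Expected` is the
       # hand-written TIR after the pass.
       After = tvm.tir.transform.LowerOpaqueBlock()(Before)
       tvm.ir.assert_structural_equal(After["main"], Expected["main"], map_free_vars=True)
   ```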



##########
src/tir/transforms/inject_ptx_async_copy.cc:
##########
@@ -79,8 +80,10 @@ class PTXAsyncCopyInjector : public StmtMutator {
         if (indices_lanes == 1) {
           auto src_offset = load->indices[0];
           auto dst_offset = store->indices[0];
-          Array<PrimExpr> args = {store->buffer->data, tir::Mul(dst_offset, PrimExpr(index_factor)),

Review Comment:
   We can use the helper function `TVM_DLL PrimExpr mul(PrimExpr a, PrimExpr b, Span span = Span())` to do automatic integer type upcasting.
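   For illustration, the difference is also visible from Python (a sketch; the dtypes are arbitrary): the raw `tir.Mul` node constructor requires both operands to already have the same dtype, while the arithmetic helper routes through the C++ `mul` and promotes the narrower integer type:
   ```python
   import tvm
   from tvm import tir

   a = tir.IntImm("int32", 2)
   b = tir.IntImm("int64", 3)

   # The overloaded `*` operator goes through the type-promoting helper,
   # so the int32 operand is upcast and the result dtype is int64.
   print((a * b).dtype)  # "int64"

   # By contrast, tir.Mul(a, b) constructs the node directly and would
   # fail its dtype-equality check on these mismatched operands.
   ```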


