MasterJH5574 opened a new pull request, #13449:
URL: https://github.com/apache/tvm/pull/13449

   This PR fixes the behavior of IndexDataTypeNormalizer on CastNode.
   
   ## Background
   
   Consider the following case,
   ```python
   from tvm import te, tir, topi

   A = te.placeholder((tir.IntImm("int64", 2), tir.IntImm("int64", 4)), name="A")
   B = topi.reshape(A, (4, 2))
   func = te.create_prim_func([A, B], index_dtype_override=None)
   ```
   the generated PrimFunc is
   ```python
   @T.prim_func
   def func(A: T.Buffer[(T.int64(2), T.int64(4)), "float32"], T_reshape: T.Buffer[(4, 2), "float32"]):
       for i0, i1 in T.grid(4, 2):
           with T.block("T_reshape"):
               ax0, ax1 = T.axis.remap("SS", [i0, i1])
               T.reads(A[(T.Cast("int64", ax0) * T.int64(2) + T.Cast("int64", ax1)) % T.int64(8) // T.int64(4), (T.Cast("int64", ax0) * T.int64(2) + T.Cast("int64", ax1)) % T.int64(4)])
               T.writes(T_reshape[ax0, ax1])
               T_reshape[ax0, ax1] = A[(T.Cast("int64", ax0) * T.int64(2) + T.Cast("int64", ax1)) % T.int64(8) // T.int64(4), (T.Cast("int64", ax0) * T.int64(2) + T.Cast("int64", ax1)) % T.int64(4)]
   ```
   Here the loop variables `ax0` and `ax1` have dtype int32, since the shape of 
the output buffer is in int32. On the other hand, the input buffer has its shape 
in int64. So, as the script above shows, CreatePrimFunc first casts the int32 
variables to int64 and then uses them to access the input buffer.
   
   Now if we use the option `index_dtype_override` to specify an index dtype as 
below,
   ```python
   func = te.create_prim_func([A, B], index_dtype_override="int64")
   ```
   the generated function will be
   ```python
   @T.prim_func
   def func(A: T.Buffer[(T.int64(2), T.int64(4)), "float32"], T_reshape: T.Buffer[(T.int64(4), T.int64(2)), "float32"]):
       for i0, i1 in T.grid(T.int64(4), T.int64(2)):
           with T.block("T_reshape"):
               ax0, ax1 = T.axis.remap("SS", [i0, i1])
               T.reads(A[(T.Cast("int64", ax0) * T.int64(2) + T.Cast("int64", ax1)) % T.int64(8) // T.int64(4), (T.Cast("int64", ax0) * T.int64(2) + T.Cast("int64", ax1)) % T.int64(4)])
               T.writes(T_reshape[ax0, ax1])
               T_reshape[ax0, ax1] = A[(T.Cast("int64", ax0) * T.int64(2) + T.Cast("int64", ax1)) % T.int64(8) // T.int64(4), (T.Cast("int64", ax0) * T.int64(2) + T.Cast("int64", ax1)) % T.int64(4)]
   ```
   Note that although all variables and buffer shapes now have dtype int64, 
CastNodes such as `T.Cast("int64", ax0)` remain even though `ax0` is already an 
int64 variable. Such casts are redundant and should not be generated.
   
   ## Fix
   
   To fix the issue above, this PR overrides the `VisitExpr_(const CastNode* 
cast)` method in IndexDataTypeNormalizer. When the `value` field of a CastNode 
already has the target dtype, we no longer cast it.
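
To illustrate the rule described above, here is a toy model in plain Python. It is not the TVM C++ implementation: the classes `Var`, `IntImm`, `Mul`, `Cast` and the function `normalize` are hypothetical stand-ins invented for this sketch, and dtypes are plain strings. It only shows the idea of visiting the operand of a cast first and dropping the cast when the operand already has the requested dtype.

```python
from dataclasses import dataclass


# Hypothetical, minimal expression nodes (not TVM classes).
@dataclass(frozen=True)
class Var:
    name: str
    dtype: str


@dataclass(frozen=True)
class IntImm:
    dtype: str
    value: int


@dataclass(frozen=True)
class Mul:
    a: object
    b: object


@dataclass(frozen=True)
class Cast:
    dtype: str
    value: object


def dtype_of(expr):
    """Return the dtype of an expression; a Mul takes the dtype of its left operand."""
    if isinstance(expr, Mul):
        return dtype_of(expr.a)
    return expr.dtype


def normalize(expr, target="int64"):
    """Rewrite index variables and constants to `target`, dropping no-op casts."""
    if isinstance(expr, Var):
        return Var(expr.name, target)
    if isinstance(expr, IntImm):
        return IntImm(target, expr.value)
    if isinstance(expr, Mul):
        return Mul(normalize(expr.a, target), normalize(expr.b, target))
    if isinstance(expr, Cast):
        value = normalize(expr.value, target)
        # Skip the cast when the rewritten operand already has the cast's dtype.
        if dtype_of(value) == expr.dtype:
            return value
        return Cast(expr.dtype, value)
    raise TypeError(f"unsupported node: {expr!r}")


# Mirrors `T.Cast("int64", ax0) * T.int64(2)` with an int32 `ax0`.
before = Mul(Cast("int64", Var("ax0", "int32")), IntImm("int64", 2))
after = normalize(before, "int64")
assert after == Mul(Var("ax0", "int64"), IntImm("int64", 2))
```

In this sketch a cast to a different dtype (for example `"float32"`) is kept, since only casts that have become no-ops are removed.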
   
   ---
   
   cc @vinx13 @junrushao 

