MasterJH5574 opened a new pull request, #13449:
URL: https://github.com/apache/tvm/pull/13449
This PR fixes the behavior of IndexDataTypeNormalizer on CastNode.
## Background
Consider the following case,
```python
from tvm import te, tir, topi

A = te.placeholder((tir.IntImm("int64", 2), tir.IntImm("int64", 4)), name="A")
B = topi.reshape(A, (4, 2))
func = te.create_prim_func([A, B], index_dtype_override=None)
```
the generated PrimFunc is
```python
@T.prim_func
def func(A: T.Buffer[(T.int64(2), T.int64(4)), "float32"],
         T_reshape: T.Buffer[(4, 2), "float32"]):
    for i0, i1 in T.grid(4, 2):
        with T.block("T_reshape"):
            ax0, ax1 = T.axis.remap("SS", [i0, i1])
            T.reads(A[(T.Cast("int64", ax0) * T.int64(2) + T.Cast("int64", ax1)) % T.int64(8) // T.int64(4),
                      (T.Cast("int64", ax0) * T.int64(2) + T.Cast("int64", ax1)) % T.int64(4)])
            T.writes(T_reshape[ax0, ax1])
            T_reshape[ax0, ax1] = A[(T.Cast("int64", ax0) * T.int64(2) + T.Cast("int64", ax1)) % T.int64(8) // T.int64(4),
                                    (T.Cast("int64", ax0) * T.int64(2) + T.Cast("int64", ax1)) % T.int64(4)]
```
Here the block variables `ax0` and `ax1` have dtype int32, since the shape of the output buffer is in int32. On the other hand, the input buffer has an int64 shape. So, as the script above shows, CreatePrimFunc first casts the int32 variables to int64 and then uses them to index the input buffer.
Now, if we use the `index_dtype_override` option to request int64 as the index dtype, as below,
```python
func = te.create_prim_func([A, B], index_dtype_override="int64")
```
the generated function will be
```python
@T.prim_func
def func(A: T.Buffer[(T.int64(2), T.int64(4)), "float32"],
         T_reshape: T.Buffer[(T.int64(4), T.int64(2)), "float32"]):
    for i0, i1 in T.grid(T.int64(4), T.int64(2)):
        with T.block("T_reshape"):
            ax0, ax1 = T.axis.remap("SS", [i0, i1])
            T.reads(A[(T.Cast("int64", ax0) * T.int64(2) + T.Cast("int64", ax1)) % T.int64(8) // T.int64(4),
                      (T.Cast("int64", ax0) * T.int64(2) + T.Cast("int64", ax1)) % T.int64(4)])
            T.writes(T_reshape[ax0, ax1])
            T_reshape[ax0, ax1] = A[(T.Cast("int64", ax0) * T.int64(2) + T.Cast("int64", ax1)) % T.int64(8) // T.int64(4),
                                    (T.Cast("int64", ax0) * T.int64(2) + T.Cast("int64", ax1)) % T.int64(4)]
```
Note that even though all variables and buffer shapes now have dtype int64, the function still contains CastNodes such as `T.Cast("int64", ax0)`, although `ax0` is already an int64 variable. We don't want such redundant casts.
## Fix
To fix the issue above, this PR overrides the `VisitExpr_(const CastNode* cast)` method in IndexDataTypeNormalizer. When the `value` field of a CastNode already has the target dtype, the cast is dropped and the value is used directly.
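To make the intended behaviour concrete, here is a toy Python model of the rule described above. It is not TVM's C++ `IndexDataTypeNormalizer`: the node classes (`Var`, `IntImm`, `Cast`, `BinOp`), `ToyIndexNormalizer`, and `show` are hypothetical stand-ins, and the sketch assumes the cast's operand is normalized first and only then compared against the cast's target dtype.

```python
from dataclasses import dataclass
from typing import Union


# Hypothetical toy IR: these classes only mimic the shape of TIR nodes.
@dataclass(frozen=True)
class Var:
    name: str
    dtype: str


@dataclass(frozen=True)
class IntImm:
    value: int
    dtype: str


@dataclass(frozen=True)
class Cast:
    dtype: str
    value: "Expr"


@dataclass(frozen=True)
class BinOp:
    op: str  # "+", "*", "%", "//"
    a: "Expr"
    b: "Expr"

    @property
    def dtype(self) -> str:
        # Operands share a dtype after normalization; report the left one.
        return self.a.dtype


Expr = Union[Var, IntImm, Cast, BinOp]


class ToyIndexNormalizer:
    """Hypothetical: rewrite integer indices to `target` and drop no-op casts."""

    def __init__(self, target: str = "int64") -> None:
        self.target = target

    def visit(self, e: Expr) -> Expr:
        if isinstance(e, Var):
            return Var(e.name, self.target) if e.dtype.startswith("int") else e
        if isinstance(e, IntImm):
            return IntImm(e.value, self.target)
        if isinstance(e, BinOp):
            return BinOp(e.op, self.visit(e.a), self.visit(e.b))
        if isinstance(e, Cast):
            value = self.visit(e.value)  # normalize the operand first
            if value.dtype == e.dtype:
                return value  # the cast would not change the dtype: drop it
            return Cast(e.dtype, value)
        raise TypeError(f"unsupported node: {e!r}")


def show(e: Expr) -> str:
    """Print a toy expression in a TVMScript-like syntax."""
    if isinstance(e, Var):
        return e.name
    if isinstance(e, IntImm):
        return str(e.value) if e.dtype == "int32" else f"T.{e.dtype}({e.value})"
    if isinstance(e, Cast):
        return f'T.Cast("{e.dtype}", {show(e.value)})'
    return f"({show(e.a)} {e.op} {show(e.b)})"


# T.Cast("int64", ax0) * T.int64(2) + T.Cast("int64", ax1), with int32 block vars
expr = BinOp("+",
             BinOp("*", Cast("int64", Var("ax0", "int32")), IntImm(2, "int64")),
             Cast("int64", Var("ax1", "int32")))
normalized = ToyIndexNormalizer("int64").visit(expr)
print(show(normalized))  # ((ax0 * T.int64(2)) + ax1)
```

A cast to a different dtype, such as `Cast("float32", Var("ax0", "int32"))`, survives in this sketch, because after normalization its operand is int64 rather than float32; only casts that would leave the dtype unchanged are removed.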
---
cc @vinx13 @junrushao