Johnson9009 opened a new issue #8093:
URL: https://github.com/apache/tvm/issues/8093


   The simple case below reproduces the issue.
   ```python
   import tvm
   from tvm import te, nd
   
   _dtype = tvm.DataType("int8")
   dshape = (1, 14, 14, 1024)
   
   A = te.placeholder(dshape, name="A", dtype=_dtype)
   C = te.compute(dshape, lambda *i: A(*i) + 3, name="C")
   
   s = te.create_schedule(C.op)
   
   c_axis = s[C].fuse(*C.op.axis)
   outer, inner = s[C].split(c_axis, nparts=4)
   outer, inner = s[C].split(inner, 28*1024)
   
   ir_mod = tvm.lower(s, [A, C], name='fadd')
   ```
   
   The IR before and after the "StorageFlatten" pass looks like the following.
   ```
   PrintIR(Before StorageFlatten):
   primfn(A_1: handle, C_1: handle) -> ()
     attr = {"global_symbol": "fadd", "tir.noalias": True}
     buffers = {C: Buffer(C_2: Pointer(int32), int32, [1, 14, 14, 1024], []),
                A: Buffer(A_2: Pointer(int8), int8, [1, 14, 14, 1024], [])}
     buffer_map = {A_1: A, C_1: C} {
     attr [C] "realize_scope" = "";
     realize(C, [0:1, 0:14, 0:14, 0:1024], True {
       for (i0.i1.fused.i2.fused.i3.fused.outer: int32, 0, 4) {
         for (i0.i1.fused.i2.fused.i3.fused.inner.outer: int32, 0, 2) {
           for (i0.i1.fused.i2.fused.i3.fused.inner.inner: int32, 0, 28672) {
             if 
@tir.likely((floordiv(floordiv(floordiv(((i0.i1.fused.i2.fused.i3.fused.inner.inner
 + (i0.i1.fused.i2.fused.i3.fused.inner.outer*28672)) + 
(i0.i1.fused.i2.fused.i3.fused.outer*50176)), 1024), 14), 14) < 1), dtype=bool) 
{
               if 
@tir.likely((floordiv(floordiv(((i0.i1.fused.i2.fused.i3.fused.inner.inner + 
(i0.i1.fused.i2.fused.i3.fused.inner.outer*28672)) + 
(i0.i1.fused.i2.fused.i3.fused.outer*50176)), 1024), 14) < 14), dtype=bool) {
                 if 
@tir.likely((floordiv(((i0.i1.fused.i2.fused.i3.fused.inner.inner + 
(i0.i1.fused.i2.fused.i3.fused.inner.outer*28672)) + 
(i0.i1.fused.i2.fused.i3.fused.outer*50176)), 1024) < 196), dtype=bool) {
                   if @tir.likely((((i0.i1.fused.i2.fused.i3.fused.inner.inner 
+ (i0.i1.fused.i2.fused.i3.fused.inner.outer*28672)) + 
(i0.i1.fused.i2.fused.i3.fused.outer*50176)) < 200704), dtype=bool) {
                     if @tir.likely(((i0.i1.fused.i2.fused.i3.fused.inner.inner 
+ (i0.i1.fused.i2.fused.i3.fused.inner.outer*28672)) < 50176), dtype=bool) {
                       
C[floordiv(floordiv(floordiv(((i0.i1.fused.i2.fused.i3.fused.inner.inner + 
(i0.i1.fused.i2.fused.i3.fused.inner.outer*28672)) + 
(i0.i1.fused.i2.fused.i3.fused.outer*50176)), 1024), 14), 14), 
floormod(floordiv(floordiv(((i0.i1.fused.i2.fused.i3.fused.inner.inner + 
(i0.i1.fused.i2.fused.i3.fused.inner.outer*28672)) + 
(i0.i1.fused.i2.fused.i3.fused.outer*50176)), 1024), 14), 14), 
floormod(floordiv(((i0.i1.fused.i2.fused.i3.fused.inner.inner + 
(i0.i1.fused.i2.fused.i3.fused.inner.outer*28672)) + 
(i0.i1.fused.i2.fused.i3.fused.outer*50176)), 1024), 14), 
floormod(((i0.i1.fused.i2.fused.i3.fused.inner.inner + 
(i0.i1.fused.i2.fused.i3.fused.inner.outer*28672)) + 
(i0.i1.fused.i2.fused.i3.fused.outer*50176)), 1024)] = (cast(int32, 
A[floordiv(floordiv(floordiv(((i0.i1.fused.i2.fused.i3.fused.inner.inner + 
(i0.i1.fused.i2.fused.i3.fused.inner.outer*28672)) + 
(i0.i1.fused.i2.fused.i3.fused.outer*50176)), 1024), 14), 14), 
floormod(floordiv(floordiv(((i0.i1.fused.i2.fused.i3.fused.inner.inner + 
(i0.i1.fused.i2.fused.i3.fused.inner.outer*28672)) + 
(i0.i1.fused.i2.fused.i3.fused.outer*50176)), 1024), 14), 14), 
floormod(floordiv(((i0.i1.fused.i2.fused.i3.fused.inner.inner + 
(i0.i1.fused.i2.fused.i3.fused.inner.outer*28672)) + 
(i0.i1.fused.i2.fused.i3.fused.outer*50176)), 1024), 14), 
floormod(((i0.i1.fused.i2.fused.i3.fused.inner.inner + 
(i0.i1.fused.i2.fused.i3.fused.inner.outer*28672)) + 
(i0.i1.fused.i2.fused.i3.fused.outer*50176)), 1024)]) + 3)
                     }
                   }
                 }
               }
             }
           }
         }
       }
     })
   }
   
   PrintIR(After StorageFlatten):
   primfn(A_1: handle, C_1: handle) -> ()
     attr = {"global_symbol": "fadd", "tir.noalias": True}
     buffers = {C: Buffer(C_2: Pointer(int32), int32, [1, 14, 14, 1024], []),
                A: Buffer(A_2: Pointer(int8), int8, [1, 14, 14, 1024], [])}
     buffer_map = {A_1: A, C_1: C} {
     for (i0.i1.fused.i2.fused.i3.fused.outer: int32, 0, 4) {
       for (i0.i1.fused.i2.fused.i3.fused.inner.outer: int32, 0, 2) {
         for (i0.i1.fused.i2.fused.i3.fused.inner.inner: int32, 0, 28672) {
           if 
@tir.likely((floordiv(floordiv(floordiv(((i0.i1.fused.i2.fused.i3.fused.inner.inner
 + (i0.i1.fused.i2.fused.i3.fused.inner.outer*28672)) + 
(i0.i1.fused.i2.fused.i3.fused.outer*50176)), 1024), 14), 14) < 1), dtype=bool) 
{
             if 
@tir.likely((floordiv(floordiv(((i0.i1.fused.i2.fused.i3.fused.inner.inner + 
(i0.i1.fused.i2.fused.i3.fused.inner.outer*28672)) + 
(i0.i1.fused.i2.fused.i3.fused.outer*50176)), 1024), 14) < 14), dtype=bool) {
               if 
@tir.likely((floordiv(((i0.i1.fused.i2.fused.i3.fused.inner.inner + 
(i0.i1.fused.i2.fused.i3.fused.inner.outer*28672)) + 
(i0.i1.fused.i2.fused.i3.fused.outer*50176)), 1024) < 196), dtype=bool) {
                 if @tir.likely((((i0.i1.fused.i2.fused.i3.fused.inner.inner + 
(i0.i1.fused.i2.fused.i3.fused.inner.outer*28672)) + 
(i0.i1.fused.i2.fused.i3.fused.outer*50176)) < 200704), dtype=bool) {
                   if @tir.likely(((i0.i1.fused.i2.fused.i3.fused.inner.inner + 
(i0.i1.fused.i2.fused.i3.fused.inner.outer*28672)) < 50176), dtype=bool) {
                     
C_2[((((floordiv((((i0.i1.fused.i2.fused.i3.fused.outer*49) + 
(i0.i1.fused.i2.fused.i3.fused.inner.outer*28)) + 
floordiv(i0.i1.fused.i2.fused.i3.fused.inner.inner, 1024)), 196)*200704) + 
(floormod(((i0.i1.fused.i2.fused.i3.fused.inner.outer*2) + 
floordiv(((i0.i1.fused.i2.fused.i3.fused.outer*49) + 
floordiv(i0.i1.fused.i2.fused.i3.fused.inner.inner, 1024)), 14)), 14)*14336)) + 
(floormod(((i0.i1.fused.i2.fused.i3.fused.outer*49) + 
floordiv(i0.i1.fused.i2.fused.i3.fused.inner.inner, 1024)), 14)*1024)) + 
floormod(i0.i1.fused.i2.fused.i3.fused.inner.inner, 1024))] = (cast(int32, 
(int8*)A_2[((((floordiv((((i0.i1.fused.i2.fused.i3.fused.outer*49) + 
(i0.i1.fused.i2.fused.i3.fused.inner.outer*28)) + 
floordiv(i0.i1.fused.i2.fused.i3.fused.inner.inner, 1024)), 196)*200704) + 
(floormod(((i0.i1.fused.i2.fused.i3.fused.inner.outer*2) + 
floordiv(((i0.i1.fused.i2.fused.i3.fused.outer*49) + 
floordiv(i0.i1.fused.i2.fused.i3.fused.inner.inner, 1024)), 14)), 14)*14336)) + 
(floormod(
 ((i0.i1.fused.i2.fused.i3.fused.outer*49) + 
floordiv(i0.i1.fused.i2.fused.i3.fused.inner.inner, 1024)), 14)*1024)) + 
floormod(i0.i1.fused.i2.fused.i3.fused.inner.inner, 1024))]) + 3)
                   }
                 }
               }
             }
           }
         }
       }
     }
   }
   ```
   I couldn't find a smaller case that reproduces this issue; the data shape is 
taken from a real ResNet-50 layer.
   The fix can be simple, similar to what is done in the temporary, 
experimental PR #8090.
   I have done some debugging and analysis of it:
   ```
   1st Dimension Index:   floordiv(floordiv(floordiv((k + j*28672 + i*50176), 
1024), 14), 14)
   2nd Dimension Index:  floormod(floordiv(floordiv((k + j*28672 + i*50176), 
1024), 14), 14)
   3rd Dimension Index:  floormod(floordiv((k + j*28672 + i*50176), 1024), 14)
   4th Dimension Index:  floormod((k + j*28672 + i*50176), 1024)
   
   Now merge the 1st and 2nd dimensions:
   The merged expression is 1st_dim_index * 14 + 2nd_dim_index.
   Pick out the common part of 1st_dim_index and 2nd_dim_index and call it "x1".
   x1 = floordiv(floordiv((k + j*28672 + i*50176), 1024), 14)
   x2 = floordiv(x1, 14)*14
   x3 = floormod(x1, 14)
   
   Then the whole merged expression is "x2 + x3", which can obviously be 
simplified to "x1" (since floordiv(x1, 14)*14 + floormod(x1, 14) == x1).
   
   Round 1:
   1. x1 inside x2 -> floordiv((floordiv((k + j*28672), 1024) + i*49), 14)
      The (i*50176) term is moved out of the inner floordiv, since 50176 = 49*1024.
   2. x2 -> floordiv((floordiv((k + j*28672), 1024) + i*49), 196)*14
      The two divisions by 14 are merged into a single division by 196, because 
floordiv(floordiv(xxx, 14), 14) == floordiv(xxx, 196).
   3. x1 inside x3 -> floordiv((floordiv((k + j*28672), 1024) + i*49), 14)
   4. x3 -> floormod(floordiv((floordiv((k + j*28672), 1024) + i*49), 14), 14)
   
   Status (note: x1 now denotes the inner expression, i.e. the argument of the floordiv by 14):
   x1 = floordiv((k + j*28672), 1024) + i*49
   x2 = floordiv(x1, 196)*14
   x3 = floormod(floordiv(x1, 14), 14)
   
   all: x2 + x3
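   As a result, the add-node simplification rule "floordiv(xxx, c) * c + 
floormod(xxx, c) => xxx" can no longer be applied: x2 now divides x1 by 196 while x3 divides it by 14, so the two operands no longer share the same floordiv/floormod argument. 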
   ```


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to