LeiWang1999 commented on PR #15264:
URL: https://github.com/apache/tvm/pull/15264#issuecomment-1864872348

   Thanks, tq. However, I found that with `index_map = IndexMap.from_func(ldmatrix_trans_permutation_16x16_32x8_16x16, index_dtype="int32")`, the inverse map does not behave as expected. The output is:
   
   ```python
   ========================inject transform=============================
   # from tvm.script import tir as T
   
   @T.prim_func
   def main(A: T.Buffer((16, 16), "float16"), B: T.Buffer((16, 16), "float16")):
       T.func_attr({"tir.noalias": T.bool(True)})
       # with T.block("root"):
       for i, j in T.grid(16, 16):
           with T.block("B"):
               vi, vj = T.axis.remap("SS", [i, j])
               T.reads(B[(vi * 2 + vj // 8) // 16 * 8 + (vi * 2 + vj // 8) % 8, (vi * 2 + vj // 8) % 16 // 8 * 8 + vj % 8])
               T.writes(A[vi, vj])
               A[vi, vj] = B[(vi * 2 + vj // 8) // 16 * 8 + (vi * 2 + vj // 8) % 8, (vi * 2 + vj // 8) % 16 // 8 * 8 + vj % 8]
   ========================inverse inject transform=============================
   # from tvm.script import tir as T
   
   @T.prim_func
   def main(A: T.Buffer((16, 16), "float16"), B: T.Buffer((16, 16), "float16")):
       T.func_attr({"tir.noalias": T.bool(True)})
       # with T.block("root"):
       for i, j in T.grid(16, 16):
           with T.block("B"):
               vi, vj = T.axis.remap("SS", [i, j])
               T.reads(B[(vi * 2 + vj // 8) // 2, vj % 16])
               T.writes(A[vi, vj])
               A[vi, vj] = B[(vi * 2 + vj // 8) // 2, vj % 16]
   ```
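
As a quick check (a plain-Python sketch, not part of the PR; `new_round_trip` is a hypothetical helper), the access printed under "inverse inject transform" can be evaluated at every point of the 16x16 domain. Python's `//` and `%` agree with TIR's floor division and floor modulo for these non-negative indices, so this shows whether the composed map selects the wrong elements of `B` or the right ones through an unsimplified expression.

```python
# Access printed under "inverse inject transform" (post-PR output),
# transcribed by hand from the TIR above.
def new_round_trip(vi, vj):
    return ((vi * 2 + vj // 8) // 2, vj % 16)

# All (vi, vj) pairs produced by T.grid(16, 16).
domain = [(vi, vj) for vi in range(16) for vj in range(16)]

# The expected composed access is B[vi, vj], i.e. the identity.
round_trip_is_identity = all(new_round_trip(vi, vj) == (vi, vj) for vi, vj in domain)
print("round trip is the identity pointwise:", round_trip_is_identity)
```

If this prints `True`, the difference from the expected output lies in how the composed index expression is simplified (for example, `(vi * 2 + vj // 8) // 2` is not folded to `vi` even though `vj // 8` is always 0 or 1), not in which elements are read.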
   
   Applying `inverse_map` after the forward map should bring the layout back to its state before the transformation, i.e. the composed access should be `B[vi, vj]`.
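
To make the round-trip property concrete, the sketch below takes the forward permutation in its pre-PR printed form and an inverse derived by hand, then checks both compositions over the full 16x16 buffer. The `forward` and `inverse` helpers are hypothetical plain-Python illustrations; `inverse` is not taken from TVM and is not what `IndexMap.inverse` prints.

```python
# Forward permutation as printed before this PR: (i, j) -> (row, col) in B.
def forward(i, j):
    return (i // 8 * 8 + i % 4 * 2 + j // 8, i % 8 // 4 * 8 + j % 8)

# Hand-derived inverse (assumption for illustration, not TVM output):
# row r stores i // 8 (r // 8), i % 4 (r % 8 // 2) and j // 8 (r % 2);
# col c stores i % 8 // 4 (c // 8) and j % 8 (c % 8).
def inverse(r, c):
    i = r // 8 * 8 + c // 8 * 4 + r % 8 // 2
    j = r % 2 * 8 + c % 8
    return (i, j)

grid = [(a, b) for a in range(16) for b in range(16)]

# inverse after forward must be the identity (the B[vi, vj] result);
# forward after inverse must be the identity too, since the map is a bijection.
left_ok = all(inverse(*forward(i, j)) == (i, j) for i, j in grid)
right_ok = all(forward(*inverse(r, c)) == (r, c) for r, c in grid)
print("inverse(forward(x)) == x:", left_ok)
print("forward(inverse(y)) == y:", right_ok)
```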
   
   Before this pull request, the same script produced the expected result: the inverse transform reduced the access to `B[vi, vj]`.
   
   ```python
   ========================inject transform=============================
   # from tvm.script import tir as T
   @T.prim_func
   def func(A: T.Buffer[(16, 16), "float16"], B: T.Buffer[(16, 16), "float16"]):
       # function attr dict
       T.func_attr({"tir.noalias": True, "global_symbol": "main"})
       # body
       # with T.block("root")
       for i, j in T.grid(16, 16):
           with T.block("B"):
               vi, vj = T.axis.remap("SS", [i, j])
               T.reads(B[vi // 8 * 8 + vi % 4 * 2 + vj // 8, vi % 8 // 4 * 8 + vj % 8])
               T.writes(A[vi, vj])
               A[vi, vj] = B[vi // 8 * 8 + vi % 4 * 2 + vj // 8, vi % 8 // 4 * 8 + vj % 8]
   
   ========================inverse inject transform=============================
   # from tvm.script import tir as T
   @T.prim_func
   def func(A: T.Buffer[(16, 16), "float16"], B: T.Buffer[(16, 16), "float16"]):
       # function attr dict
       T.func_attr({"tir.noalias": True, "global_symbol": "main"})
       # body
       # with T.block("root")
       for i, j in T.grid(16, 16):
           with T.block("B"):
               vi, vj = T.axis.remap("SS", [i, j])
               T.reads(B[vi, vj])
               T.writes(A[vi, vj])
               A[vi, vj] = B[vi, vj]
   ```
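
Comparing the pre-PR and post-PR `T.reads` expressions the same way (again a plain-Python sketch; `old_forward` and `new_forward` are hypothetical helper names) checks whether the forward transform itself changed, or only its printed form.

```python
# Forward access printed before this PR.
def old_forward(vi, vj):
    return (vi // 8 * 8 + vi % 4 * 2 + vj // 8, vi % 8 // 4 * 8 + vj % 8)

# Forward access printed after this PR, written in terms of the shared
# subexpression t = vi * 2 + vj // 8 that appears in the TIR.
def new_forward(vi, vj):
    t = vi * 2 + vj // 8
    return (t // 16 * 8 + t % 8, t % 16 // 8 * 8 + vj % 8)

# Collect every (vi, vj) where the two accesses pick different elements of B.
mismatches = [
    (vi, vj)
    for vi in range(16)
    for vj in range(16)
    if old_forward(vi, vj) != new_forward(vi, vj)
]
print("points where the forward accesses differ:", mismatches)
```

An empty list means both forward accesses read the same element of `B` at every point, so the visible difference between the two outputs comes from expression simplification rather than from the mapping.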
   
   I'll take a look tomorrow.


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
