LeiWang1999 commented on PR #15264:
URL: https://github.com/apache/tvm/pull/15264#issuecomment-1864277870

   Hi @junrushao , I have encountered an issue and bisected it to this pull request. Here is my case:
   
   ```python
   import tvm
   from tvm.script import tir as T
   from tvm.tir import IndexMap
   
   def ldmatrix_trans_32x8_to_shared_16x16_layout(thread_id, local_id):
       row = 8 * (thread_id // 16) + (thread_id % 8)
       col = 8 * ((thread_id % 16) // 8) + local_id % 8
       return row, col
   
   def ldmatrix_trans_permutation_16x16_32x8_16x16(kernel_i, kernel_j):
       thread_id = kernel_i * 2 + kernel_j // 8
       local_id = kernel_j % 8
       return ldmatrix_trans_32x8_to_shared_16x16_layout(thread_id, local_id)
   
   @tvm.script.ir_module
   class MyModule:
       @T.prim_func
       def main(a: T.handle, b: T.handle):
           T.func_attr({"global_symbol": "main", "tir.noalias": True})
           A = T.match_buffer(a, [16, 16], dtype="float16")
           B = T.match_buffer(b, [16, 16], dtype="float16")
           
           for i, j in T.grid(16, 16):
               with T.block("B"):
                   vi, vj = T.axis.remap("SS", [i, j])
                   T.reads(B[vi, vj])
                   T.writes(A[vi, vj])
                   A[vi, vj] = B[vi, vj]
   
   ir_module = MyModule
   sch = tvm.tir.Schedule(ir_module)
   
   block_b = sch.get_block("B")
   sch.transform_layout(block_b, ('read', 0), ldmatrix_trans_permutation_16x16_32x8_16x16)
   print("========================inject transform=============================")
   print(sch.mod["main"].script())
   
   index_map = IndexMap.from_func(ldmatrix_trans_permutation_16x16_32x8_16x16)
   inversed_index_map = index_map.inverse([16, 16])
   def inverse_permutation(i, j):
       return inversed_index_map.map_indices([i, j])
   sch.transform_layout(block_b, ('read', 0), inverse_permutation)
   print("========================inverse inject transform=============================")
   print(sch.mod["main"].script())
   ```
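
A quick plain-Python sanity check (no TVM required; my own addition, not part of the repro) confirms that the permutation above is a bijection on the 16x16 domain, so mathematically its inverse is well defined and `IndexMap.inverse` should be able to derive it:

```python
# Standalone check: the ldmatrix layout mapping should hit every (row, col)
# in [0, 16) x [0, 16) exactly once, i.e. it is a permutation of the domain.

def ldmatrix_trans_32x8_to_shared_16x16_layout(thread_id, local_id):
    row = 8 * (thread_id // 16) + (thread_id % 8)
    col = 8 * ((thread_id % 16) // 8) + local_id % 8
    return row, col

def ldmatrix_trans_permutation_16x16_32x8_16x16(i, j):
    thread_id = i * 2 + j // 8
    local_id = j % 8
    return ldmatrix_trans_32x8_to_shared_16x16_layout(thread_id, local_id)

images = {ldmatrix_trans_permutation_16x16_32x8_16x16(i, j)
          for i in range(16) for j in range(16)}
# 256 distinct images out of 256 inputs -> the mapping is bijective.
assert len(images) == 256
```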
   
   Before this PR, the output (after applying both the transform and its inverse) is:
   ```bash
   # from tvm.script import tir as T
   
   @T.prim_func
   def main(A: T.Buffer((16, 16), "float16"), B: T.Buffer((16, 16), "float16")):
       T.func_attr({"global_symbol": "main", "tir.noalias": T.bool(True)})
       # with T.block("root"):
       for i, j in T.grid(16, 16):
           with T.block("B"):
               vi, vj = T.axis.remap("SS", [i, j])
               T.reads(B[vi, vj])
               T.writes(A[vi, vj])
               A[vi, vj] = B[vi, vj]
   ``` 
   
   As we can see, the index map is simplified (the transform and its inverse cancel to the identity), and the map can be inverted.
   
   After this PR, the output is:
   
   ```bash
   ========================inject transform=============================
   # from tvm.script import tir as T
   
   @T.prim_func
   def main(A: T.Buffer((16, 16), "float16"), B: T.Buffer((16, 16), "float16")):
       T.func_attr({"global_symbol": "main", "tir.noalias": T.bool(True)})
       # with T.block("root"):
       for i, j in T.grid(16, 16):
           with T.block("B"):
               vi, vj = T.axis.remap("SS", [i, j])
               T.reads(B[(vi * 2 + vj // 8) // 16 * 8 + (vi * 2 + vj // 8) % 8, (vi * 2 + vj // 8) % 16 // 8 * 8 + vj % 8])
               T.writes(A[vi, vj])
               A[vi, vj] = B[(vi * 2 + vj // 8) // 16 * 8 + (vi * 2 + vj // 8) % 8, (vi * 2 + vj // 8) % 16 // 8 * 8 + vj % 8]
   Traceback (most recent call last):
     File "/home/t-leiwang/ladder_workspace/tvm_gpu_gemm/discuss_inversemap.py", line 42, in <module>
       sch.transform_layout(block_b, ('read', 0), inverse_permutation)
     File "/home/t-leiwang/mlc_workspace/tvm_rebase/python/tvm/tir/schedule/_type_checker.py", line 340, in wrap
       return func(*args, **kwargs)
     File "/home/t-leiwang/mlc_workspace/tvm_rebase/python/tvm/tir/schedule/schedule.py", line 3296, in transform_layout
       _ffi_api.ScheduleTransformLayout(  # type: ignore # pylint: disable=no-member
     File "/home/t-leiwang/mlc_workspace/tvm_rebase/python/tvm/_ffi/_ctypes/packed_func.py", line 238, in __call__
       raise get_last_ffi_error()
   tvm._ffi.base.TVMError: Traceback (most recent call last):
     7: TVMFuncCall
     6: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::runtime::TypedPackedFunc<void (tvm::tir::Schedule, tvm::tir::BlockRV const&, int, int, tvm::tir::IndexMap const&, tvm::runtime::Optional<tvm::tir::IndexMap> const&, bool)>::AssignTypedLambda<tvm::tir::{lambda(tvm::tir::Schedule, tvm::tir::BlockRV const&, int, int, tvm::tir::IndexMap const&, tvm::runtime::Optional<tvm::tir::IndexMap> const&, bool)#17}>(tvm::tir::{lambda(tvm::tir::Schedule, tvm::tir::BlockRV const&, int, int, tvm::tir::IndexMap const&, tvm::runtime::Optional<tvm::tir::IndexMap> const&, bool)#17}, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}> >::Call(tvm::runtime::PackedFuncObj const*, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, tvm::runtime::TVMRetValue)
     5: tvm::runtime::TypedPackedFunc<void (tvm::tir::Schedule, tvm::tir::BlockRV const&, int, int, tvm::tir::IndexMap const&, tvm::runtime::Optional<tvm::tir::IndexMap> const&, bool)>::AssignTypedLambda<tvm::tir::{lambda(tvm::tir::Schedule, tvm::tir::BlockRV const&, int, int, tvm::tir::IndexMap const&, tvm::runtime::Optional<tvm::tir::IndexMap> const&, bool)#17}>(tvm::tir::{lambda(tvm::tir::Schedule, tvm::tir::BlockRV const&, int, int, tvm::tir::IndexMap const&, tvm::runtime::Optional<tvm::tir::IndexMap> const&, bool)#17}, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}::operator()(tvm::runtime::TVMArgs const, tvm::runtime::TVMRetValue) const [clone .isra.0]
     4: tvm::tir::TracedScheduleNode::TransformLayout(tvm::tir::BlockRV const&, int, tvm::tir::BufferIndexType, tvm::tir::IndexMap const&, tvm::runtime::Optional<tvm::tir::IndexMap> const&, bool)
     3: tvm::tir::ConcreteScheduleNode::TransformLayout(tvm::tir::BlockRV const&, int, tvm::tir::BufferIndexType, tvm::tir::IndexMap const&, tvm::runtime::Optional<tvm::tir::IndexMap> const&, bool)
     2: tvm::tir::TransformLayout(tvm::tir::ScheduleState, tvm::tir::StmtSRef const&, int, tvm::tir::BufferIndexType, tvm::tir::IndexMap const&, tvm::runtime::Optional<tvm::tir::IndexMap> const&, bool)
     1: tvm::tir::IndexMap::NonSurjectiveInverse(tvm::runtime::Array<tvm::Range, void>, tvm::arith::Analyzer*) const
     0: tvm::tir::IndexMapInverseImpl(tvm::tir::IndexMap const&, tvm::runtime::Array<tvm::Range, void> const&, tvm::arith::IterMapLevel, tvm::arith::Analyzer*)
     File "/home/t-leiwang/mlc_workspace/tvm_rebase/src/tir/ir/index_map.cc", line 96
   TVMError: Check failed: (padded_iter_map->errors.empty()) is false: Could not parse mapping as sum of iterators.  Error: IterMapExpr or subclasses should only result from calls in IterMapRewriter using DirectMutate.  Indirect return occurred in i
   ```
   
   The index map is no longer simplified, and inverting it now throws the error above.
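
For what it's worth, the inverse does exist in closed form. Here is a hand-derived version (plain Python, my own derivation, not TVM output) that round-trips with the forward mapping, which suggests the failure is in the iter-map analysis rather than in the mapping itself:

```python
# Forward mapping, inlined from the repro above.
def forward(i, j):
    thread_id = i * 2 + j // 8
    local_id = j % 8
    row = 8 * (thread_id // 16) + (thread_id % 8)
    col = 8 * ((thread_id % 16) // 8) + local_id % 8
    return row, col

# Hand-derived inverse: reassemble thread_id from its three bit-fields
# (row // 8, col // 8, row % 8), then split it back into (i, j).
def inverse(row, col):
    thread_id = 16 * (row // 8) + 8 * (col // 8) + row % 8
    local_id = col % 8
    return thread_id // 2, 8 * (thread_id % 2) + local_id

# Round-trip over the whole 16x16 domain.
assert all(inverse(*forward(i, j)) == (i, j)
           for i in range(16) for j in range(16))
```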

