LeiWang1999 commented on PR #15264:
URL: https://github.com/apache/tvm/pull/15264#issuecomment-1864277870
Hi @junrushao , I have encountered an issue and bisected it to this pull
request. Here is my case:
```python
import tvm
from tvm.script import tir as T
from tvm.tir import IndexMap
def ldmatrix_trans_32x8_to_shared_16x16_layout(thread_id, local_id):
    row = 8 * (thread_id // 16) + (thread_id % 8)
    col = 8 * ((thread_id % 16) // 8) + local_id % 8
    return row, col

def ldmatrix_trans_permutation_16x16_32x8_16x16(kernel_i, kernel_j):
    thread_id = kernel_i * 2 + kernel_j // 8
    local_id = kernel_j % 8
    return ldmatrix_trans_32x8_to_shared_16x16_layout(thread_id, local_id)

@tvm.script.ir_module
class MyModule:
    @T.prim_func
    def main(a: T.handle, b: T.handle):
        T.func_attr({"global_symbol": "main", "tir.noalias": True})
        A = T.match_buffer(a, [16, 16], dtype="float16")
        B = T.match_buffer(b, [16, 16], dtype="float16")
        for i, j in T.grid(16, 16):
            with T.block("B"):
                vi, vj = T.axis.remap("SS", [i, j])
                T.reads(B[vi, vj])
                T.writes(A[vi, vj])
                A[vi, vj] = B[vi, vj]

ir_module = MyModule
sch = tvm.tir.Schedule(ir_module)
block_b = sch.get_block("B")
sch.transform_layout(block_b, ("read", 0), ldmatrix_trans_permutation_16x16_32x8_16x16)
print("========================inject transform=============================")
print(sch.mod["main"].script())

index_map = IndexMap.from_func(ldmatrix_trans_permutation_16x16_32x8_16x16)
inversed_index_map = index_map.inverse([16, 16])

def inverse_permutation(i, j):
    return inversed_index_map.map_indices([i, j])

sch.transform_layout(block_b, ("read", 0), inverse_permutation)
print("========================inverse inject transform=============================")
print(sch.mod["main"].script())
```
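As a side note, the permutation above is a true bijection on the 16x16 index space, so a well-defined inverse must exist. This can be checked without TVM at all (a standalone sketch; the functions are copied verbatim from the repro above):

```python
# Standalone check (no TVM needed) that the ldmatrix-trans layout map
# is a bijection on the 16x16 index space, so an inverse exists.

def ldmatrix_trans_32x8_to_shared_16x16_layout(thread_id, local_id):
    row = 8 * (thread_id // 16) + (thread_id % 8)
    col = 8 * ((thread_id % 16) // 8) + local_id % 8
    return row, col

def ldmatrix_trans_permutation_16x16_32x8_16x16(kernel_i, kernel_j):
    thread_id = kernel_i * 2 + kernel_j // 8
    local_id = kernel_j % 8
    return ldmatrix_trans_32x8_to_shared_16x16_layout(thread_id, local_id)

image = {
    ldmatrix_trans_permutation_16x16_32x8_16x16(i, j)
    for i in range(16)
    for j in range(16)
}
# 256 distinct targets for 256 inputs means the map is injective,
# hence bijective on the 16x16 space.
assert len(image) == 256
```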
Before this PR, the output is:
```bash
# from tvm.script import tir as T
@T.prim_func
def main(A: T.Buffer((16, 16), "float16"), B: T.Buffer((16, 16), "float16")):
    T.func_attr({"global_symbol": "main", "tir.noalias": T.bool(True)})
    # with T.block("root"):
    for i, j in T.grid(16, 16):
        with T.block("B"):
            vi, vj = T.axis.remap("SS", [i, j])
            T.reads(B[vi, vj])
            T.writes(A[vi, vj])
            A[vi, vj] = B[vi, vj]
```
As the output shows, the index map is fully simplified and its inverse can be computed.
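This simplification is exactly what one would expect: the permutation composed with its inverse is the identity, which the arithmetic simplifier folds back to `B[vi, vj]`. A pure-Python sanity check (no TVM; the closed-form `inverse` below is my own hand derivation, not the `IndexMap` that TVM computes):

```python
# Hand-derived inverse of the ldmatrix-trans permutation (my own
# derivation, for illustration only). forward followed by inverse
# should be the identity on every index of the 16x16 space.

def forward(i, j):
    thread_id = i * 2 + j // 8
    local_id = j % 8
    return (8 * (thread_id // 16) + thread_id % 8,
            8 * ((thread_id % 16) // 8) + local_id % 8)

def inverse(row, col):
    # Reassemble thread_id from its scattered bits, then undo forward.
    thread_id = 16 * (row // 8) + 8 * (col // 8) + row % 8
    local_id = col % 8
    return thread_id // 2, 8 * (thread_id % 2) + local_id

assert all(inverse(*forward(i, j)) == (i, j)
           for i in range(16) for j in range(16))
```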
After this PR, the output is:
```bash
========================inject transform=============================
# from tvm.script import tir as T
@T.prim_func
def main(A: T.Buffer((16, 16), "float16"), B: T.Buffer((16, 16), "float16")):
    T.func_attr({"global_symbol": "main", "tir.noalias": T.bool(True)})
    # with T.block("root"):
    for i, j in T.grid(16, 16):
        with T.block("B"):
            vi, vj = T.axis.remap("SS", [i, j])
            T.reads(B[(vi * 2 + vj // 8) // 16 * 8 + (vi * 2 + vj // 8) % 8, (vi * 2 + vj // 8) % 16 // 8 * 8 + vj % 8])
            T.writes(A[vi, vj])
            A[vi, vj] = B[(vi * 2 + vj // 8) // 16 * 8 + (vi * 2 + vj // 8) % 8, (vi * 2 + vj // 8) % 16 // 8 * 8 + vj % 8]
Traceback (most recent call last):
  File "/home/t-leiwang/ladder_workspace/tvm_gpu_gemm/discuss_inversemap.py", line 42, in <module>
    sch.transform_layout(block_b, ('read', 0), inverse_permutation)
  File "/home/t-leiwang/mlc_workspace/tvm_rebase/python/tvm/tir/schedule/_type_checker.py", line 340, in wrap
    return func(*args, **kwargs)
  File "/home/t-leiwang/mlc_workspace/tvm_rebase/python/tvm/tir/schedule/schedule.py", line 3296, in transform_layout
    _ffi_api.ScheduleTransformLayout(  # type: ignore # pylint: disable=no-member
  File "/home/t-leiwang/mlc_workspace/tvm_rebase/python/tvm/_ffi/_ctypes/packed_func.py", line 238, in __call__
    raise get_last_ffi_error()
tvm._ffi.base.TVMError: Traceback (most recent call last):
  7: TVMFuncCall
  6: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::runtime::TypedPackedFunc<void (tvm::tir::Schedule, tvm::tir::BlockRV const&, int, int, tvm::tir::IndexMap const&, tvm::runtime::Optional<tvm::tir::IndexMap> const&, bool)>::AssignTypedLambda<tvm::tir::{lambda(tvm::tir::Schedule, tvm::tir::BlockRV const&, int, int, tvm::tir::IndexMap const&, tvm::runtime::Optional<tvm::tir::IndexMap> const&, bool)#17}>(tvm::tir::{lambda(tvm::tir::Schedule, tvm::tir::BlockRV const&, int, int, tvm::tir::IndexMap const&, tvm::runtime::Optional<tvm::tir::IndexMap> const&, bool)#17}, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}> >::Call(tvm::runtime::PackedFuncObj const*, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, tvm::runtime::TVMRetValue)
  5: tvm::runtime::TypedPackedFunc<void (tvm::tir::Schedule, tvm::tir::BlockRV const&, int, int, tvm::tir::IndexMap const&, tvm::runtime::Optional<tvm::tir::IndexMap> const&, bool)>::AssignTypedLambda<tvm::tir::{lambda(tvm::tir::Schedule, tvm::tir::BlockRV const&, int, int, tvm::tir::IndexMap const&, tvm::runtime::Optional<tvm::tir::IndexMap> const&, bool)#17}>(tvm::tir::{lambda(tvm::tir::Schedule, tvm::tir::BlockRV const&, int, int, tvm::tir::IndexMap const&, tvm::runtime::Optional<tvm::tir::IndexMap> const&, bool)#17}, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}::operator()(tvm::runtime::TVMArgs const, tvm::runtime::TVMRetValue) const [clone .isra.0]
  4: tvm::tir::TracedScheduleNode::TransformLayout(tvm::tir::BlockRV const&, int, tvm::tir::BufferIndexType, tvm::tir::IndexMap const&, tvm::runtime::Optional<tvm::tir::IndexMap> const&, bool)
  3: tvm::tir::ConcreteScheduleNode::TransformLayout(tvm::tir::BlockRV const&, int, tvm::tir::BufferIndexType, tvm::tir::IndexMap const&, tvm::runtime::Optional<tvm::tir::IndexMap> const&, bool)
  2: tvm::tir::TransformLayout(tvm::tir::ScheduleState, tvm::tir::StmtSRef const&, int, tvm::tir::BufferIndexType, tvm::tir::IndexMap const&, tvm::runtime::Optional<tvm::tir::IndexMap> const&, bool)
  1: tvm::tir::IndexMap::NonSurjectiveInverse(tvm::runtime::Array<tvm::Range, void>, tvm::arith::Analyzer*) const
  0: tvm::tir::IndexMapInverseImpl(tvm::tir::IndexMap const&, tvm::runtime::Array<tvm::Range, void> const&, tvm::arith::IterMapLevel, tvm::arith::Analyzer*)
  File "/home/t-leiwang/mlc_workspace/tvm_rebase/src/tir/ir/index_map.cc", line 96
TVMError: Check failed: (padded_iter_map->errors.empty()) is false: Could not parse mapping as sum of iterators. Error: IterMapExpr or subclasses should only result from calls in IterMapRewriter using DirectMutate. Indirect return occurred in i
```
The index map is no longer simplified, and inverting it throws the error above.
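Note that the unsimplified expression in the after-PR `T.reads` is still mathematically the same permutation; only the folding is lost. A quick numeric check (pure Python, no TVM; `printed` transcribes the expression from the TVMScript dump above):

```python
# Check that the unsimplified index expression TVM prints after this PR
# agrees with the original permutation on every index, i.e. the map was
# injected verbatim rather than being simplified or miscompiled.

def forward(vi, vj):
    thread_id = vi * 2 + vj // 8
    local_id = vj % 8
    return (8 * (thread_id // 16) + thread_id % 8,
            8 * ((thread_id % 16) // 8) + local_id % 8)

def printed(vi, vj):
    # Expression exactly as it appears in the after-PR TVMScript output.
    return ((vi * 2 + vj // 8) // 16 * 8 + (vi * 2 + vj // 8) % 8,
            (vi * 2 + vj // 8) % 16 // 8 * 8 + vj % 8)

assert all(printed(i, j) == forward(i, j)
           for i in range(16) for j in range(16))
```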
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]