This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 513fcf4 [TVMSCRIPT] TVMScript Parser support BufferSlice indices (#8408)
513fcf4 is described below
commit 513fcf4bb7ea53f8663a6d56e67d82fe49fde21d
Author: Honghua Cao <[email protected]>
AuthorDate: Sat Jul 10 02:33:23 2021 +0800
[TVMSCRIPT] TVMScript Parser support BufferSlice indices (#8408)
Co-authored-by: honghua.cao <[email protected]>
---
python/tvm/script/node.py | 6 +-
tests/python/unittest/test_tvmscript_complete.py | 66 ++++++++++++++++++++++
.../python/unittest/test_tvmscript_error_report.py | 29 ++++++++++
3 files changed, 100 insertions(+), 1 deletion(-)
diff --git a/python/tvm/script/node.py b/python/tvm/script/node.py
index c459368..cfbc668 100644
--- a/python/tvm/script/node.py
+++ b/python/tvm/script/node.py
@@ -108,7 +108,7 @@ class BufferSlice(ObjectGeneric):
span,
)
- slices: List[Slice] = []
+ slices: List[Union[Slice, BufferSlice]] = []
for index in indices:
if isinstance(index, Slice):
check_index(index.start)
@@ -117,6 +117,10 @@ class BufferSlice(ObjectGeneric):
elif isinstance(index, (PrimExpr, int)):
check_index(index)
slices.append(Slice(index))
+ elif isinstance(index, BufferSlice):
+ buffer_load = index.asobject()
+ check_index(buffer_load)
+ slices.append(Slice(buffer_load))
else:
report_error(
"Unsupported index type for BufferSlice, "
diff --git a/tests/python/unittest/test_tvmscript_complete.py b/tests/python/unittest/test_tvmscript_complete.py
index 76d4d9b..a4d2dec 100644
--- a/tests/python/unittest/test_tvmscript_complete.py
+++ b/tests/python/unittest/test_tvmscript_complete.py
@@ -190,9 +190,75 @@ def test_complete_opaque_block_error():
assert False
+@tvm.script.tir
+def func_with_bufferslice_indices(data: ty.handle, index: ty.handle) -> None:
+ data_buf = tir.match_buffer(data, (16, 16), "float32")
+ index_buf = tir.match_buffer(index, (1,), "int32")
+ out_buf = tir.alloc_buffer((16, 16), "float32")
+
+ with tir.block([16, 16]) as [vi, vj]:
+ out_buf[vi, vj] = data_buf[vi, index_buf[0]]
+
+
+@tvm.script.tir
+def expected_bufferslice_indices(data: ty.handle, index: ty.handle) -> None:
+ index_buf = tir.match_buffer(
+ index, [1], dtype="int32", elem_offset=0, align=128, offset_factor=1
+ )
+ data_buf = tir.match_buffer(data, [16, 16], elem_offset=0, align=128, offset_factor=1)
+ with tir.block([], "root"):
+ tir.reads([])
+ tir.writes([])
+ out_buf = tir.alloc_buffer([16, 16], elem_offset=0, align=128, offset_factor=1)
+ for i0, i1 in tir.grid(16, 16):
+ with tir.block([16, 16], "") as [vi, vj]:
+ tir.bind(vi, i0)
+ tir.bind(vj, i1)
+ tir.reads([data_buf[vi, 0:16], index_buf[0]])
+ tir.writes([out_buf[vi, vj]])
+ out_buf[vi, vj] = data_buf[vi, index_buf[0]]
+
+
+@tvm.script.tir
+def func_with_recursive_bufferslice_indices(data: ty.handle, index: ty.handle) -> None:
+ data_buf = tir.match_buffer(data, (16, 16), "float32")
+ index_buf = tir.match_buffer(index, (1,), "int32")
+ out_buf = tir.alloc_buffer((16, 16), "float32")
+
+ with tir.block([16, 16]) as [vi, vj]:
+ out_buf[vi, vj] = data_buf[index_buf[index_buf[0]], index_buf[0]]
+
+
+@tvm.script.tir
+def expected_recursive_bufferslice_indices(data: ty.handle, index: ty.handle) -> None:
+ index_buf = tir.match_buffer(
+ index, [1], dtype="int32", elem_offset=0, align=128, offset_factor=1
+ )
+ data_buf = tir.match_buffer(data, [16, 16], elem_offset=0, align=128, offset_factor=1)
+ with tir.block([], "root"):
+ tir.reads([])
+ tir.writes([])
+ out_buf = tir.alloc_buffer([16, 16], elem_offset=0, align=128, offset_factor=1)
+ for i0, i1 in tir.grid(16, 16):
+ with tir.block([16, 16], "") as [vi, vj]:
+ tir.bind(vi, i0)
+ tir.bind(vj, i1)
+ tir.reads([data_buf[0:16, 0:16], index_buf[0]])
+ tir.writes([out_buf[vi, vj]])
+ out_buf[vi, vj] = data_buf[index_buf[index_buf[0]], index_buf[0]]
+
+
+def test_complete_buffer_indices():
+ new_func = tvm.script.from_source(tvm.script.asscript(func_with_bufferslice_indices))
+ tvm.ir.assert_structural_equal(new_func, expected_bufferslice_indices)
+ new_func = tvm.script.from_source(tvm.script.asscript(func_with_recursive_bufferslice_indices))
+ tvm.ir.assert_structural_equal(new_func, expected_recursive_bufferslice_indices)
+
+
if __name__ == "__main__":
test_complete_matmul()
test_complete_matmul_original()
test_complete_with_root()
test_complete_opaque_block_error()
test_complete_part_region()
+ test_complete_buffer_indices()
diff --git a/tests/python/unittest/test_tvmscript_error_report.py b/tests/python/unittest/test_tvmscript_error_report.py
index 052217b..a72b13e 100644
--- a/tests/python/unittest/test_tvmscript_error_report.py
+++ b/tests/python/unittest/test_tvmscript_error_report.py
@@ -291,8 +291,36 @@ def error_index_type() -> None:
A[vi, vj] = A[vi, 0.0] + 1 # error
+def error_bufferslice_index_type() -> None:
+ A = tir.alloc_buffer((1,), "float32")
+ B = tir.alloc_buffer((16, 16), "float32")
+ C = tir.alloc_buffer((16, 16), "float32")
+ with tir.block([16, 16]) as [vi, vj]:
+ C[vi, vj] = B[vi, A[0]] # error
+
+
def test_error_index_type():
check_error(error_index_type, 4)
+ check_error(error_bufferslice_index_type, 6)
+
+
+def error_index_with_stop() -> None:
+ A = tir.alloc_buffer((128, 128), "float32")
+ with tir.block([16, 16]) as [vi, vj]:
+ A[vi, vj] = A[vi, 1:10] + 1 # error
+
+
+def error_bufferslice_index_with_stop() -> None:
+ A = tir.alloc_buffer((1,), "int32")
+ B = tir.alloc_buffer((16, 16), "float32")
+ C = tir.alloc_buffer((16, 16), "float32")
+ with tir.block([16, 16]) as [vi, vj]:
+ C[vi, vj] = B[vi, A[0:1]] # error
+
+
+def test_error_index_with_stop_slice():
+ check_error(error_index_with_stop, 4)
+ check_error(error_bufferslice_index_with_stop, 6)
def mismatch_args() -> None:
@@ -383,5 +411,6 @@ if __name__ == "__main__":
test_opaque_access_during_complete()
test_convert_slice_to_bufferload()
test_error_index_type()
+ test_error_index_with_stop_slice()
test_mismatch_args()
test_tvm_exception_catch()