This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 513fcf4  [TVMSCRIPT] TVMScript Parser support BufferSlice indices (#8408)
513fcf4 is described below

commit 513fcf4bb7ea53f8663a6d56e67d82fe49fde21d
Author: Honghua Cao <[email protected]>
AuthorDate: Sat Jul 10 02:33:23 2021 +0800

    [TVMSCRIPT] TVMScript Parser support BufferSlice indices (#8408)
    
    
    
    Co-authored-by: honghua.cao <[email protected]>
---
 python/tvm/script/node.py                          |  6 +-
 tests/python/unittest/test_tvmscript_complete.py   | 66 ++++++++++++++++++++++
 .../python/unittest/test_tvmscript_error_report.py | 29 ++++++++++
 3 files changed, 100 insertions(+), 1 deletion(-)

diff --git a/python/tvm/script/node.py b/python/tvm/script/node.py
index c459368..cfbc668 100644
--- a/python/tvm/script/node.py
+++ b/python/tvm/script/node.py
@@ -108,7 +108,7 @@ class BufferSlice(ObjectGeneric):
                     span,
                 )
 
-        slices: List[Slice] = []
+        slices: List[Union[Slice, BufferSlice]] = []
         for index in indices:
             if isinstance(index, Slice):
                 check_index(index.start)
@@ -117,6 +117,10 @@ class BufferSlice(ObjectGeneric):
             elif isinstance(index, (PrimExpr, int)):
                 check_index(index)
                 slices.append(Slice(index))
+            elif isinstance(index, BufferSlice):
+                buffer_load = index.asobject()
+                check_index(buffer_load)
+                slices.append(Slice(buffer_load))
             else:
                 report_error(
                     "Unsupported index type for BufferSlice, "
diff --git a/tests/python/unittest/test_tvmscript_complete.py b/tests/python/unittest/test_tvmscript_complete.py
index 76d4d9b..a4d2dec 100644
--- a/tests/python/unittest/test_tvmscript_complete.py
+++ b/tests/python/unittest/test_tvmscript_complete.py
@@ -190,9 +190,75 @@ def test_complete_opaque_block_error():
     assert False
 
 
+@tvm.script.tir
+def func_with_bufferslice_indices(data: ty.handle, index: ty.handle) -> None:
+    data_buf = tir.match_buffer(data, (16, 16), "float32")
+    index_buf = tir.match_buffer(index, (1,), "int32")
+    out_buf = tir.alloc_buffer((16, 16), "float32")
+
+    with tir.block([16, 16]) as [vi, vj]:
+        out_buf[vi, vj] = data_buf[vi, index_buf[0]]
+
+
+@tvm.script.tir
+def expected_bufferslice_indices(data: ty.handle, index: ty.handle) -> None:
+    index_buf = tir.match_buffer(
+        index, [1], dtype="int32", elem_offset=0, align=128, offset_factor=1
+    )
+    data_buf = tir.match_buffer(data, [16, 16], elem_offset=0, align=128, offset_factor=1)
+    with tir.block([], "root"):
+        tir.reads([])
+        tir.writes([])
+        out_buf = tir.alloc_buffer([16, 16], elem_offset=0, align=128, offset_factor=1)
+        for i0, i1 in tir.grid(16, 16):
+            with tir.block([16, 16], "") as [vi, vj]:
+                tir.bind(vi, i0)
+                tir.bind(vj, i1)
+                tir.reads([data_buf[vi, 0:16], index_buf[0]])
+                tir.writes([out_buf[vi, vj]])
+                out_buf[vi, vj] = data_buf[vi, index_buf[0]]
+
+
+@tvm.script.tir
+def func_with_recursive_bufferslice_indices(data: ty.handle, index: ty.handle) -> None:
+    data_buf = tir.match_buffer(data, (16, 16), "float32")
+    index_buf = tir.match_buffer(index, (1,), "int32")
+    out_buf = tir.alloc_buffer((16, 16), "float32")
+
+    with tir.block([16, 16]) as [vi, vj]:
+        out_buf[vi, vj] = data_buf[index_buf[index_buf[0]], index_buf[0]]
+
+
+@tvm.script.tir
+def expected_recursive_bufferslice_indices(data: ty.handle, index: ty.handle) -> None:
+    index_buf = tir.match_buffer(
+        index, [1], dtype="int32", elem_offset=0, align=128, offset_factor=1
+    )
+    data_buf = tir.match_buffer(data, [16, 16], elem_offset=0, align=128, offset_factor=1)
+    with tir.block([], "root"):
+        tir.reads([])
+        tir.writes([])
+        out_buf = tir.alloc_buffer([16, 16], elem_offset=0, align=128, offset_factor=1)
+        for i0, i1 in tir.grid(16, 16):
+            with tir.block([16, 16], "") as [vi, vj]:
+                tir.bind(vi, i0)
+                tir.bind(vj, i1)
+                tir.reads([data_buf[0:16, 0:16], index_buf[0]])
+                tir.writes([out_buf[vi, vj]])
+                out_buf[vi, vj] = data_buf[index_buf[index_buf[0]], index_buf[0]]
+
+
+def test_complete_buffer_indices():
+    new_func = tvm.script.from_source(tvm.script.asscript(func_with_bufferslice_indices))
+    tvm.ir.assert_structural_equal(new_func, expected_bufferslice_indices)
+    new_func = tvm.script.from_source(tvm.script.asscript(func_with_recursive_bufferslice_indices))
+    tvm.ir.assert_structural_equal(new_func, expected_recursive_bufferslice_indices)
+
+
 if __name__ == "__main__":
     test_complete_matmul()
     test_complete_matmul_original()
     test_complete_with_root()
     test_complete_opaque_block_error()
     test_complete_part_region()
+    test_complete_buffer_indices()
diff --git a/tests/python/unittest/test_tvmscript_error_report.py b/tests/python/unittest/test_tvmscript_error_report.py
index 052217b..a72b13e 100644
--- a/tests/python/unittest/test_tvmscript_error_report.py
+++ b/tests/python/unittest/test_tvmscript_error_report.py
@@ -291,8 +291,36 @@ def error_index_type() -> None:
         A[vi, vj] = A[vi, 0.0] + 1  # error
 
 
+def error_bufferslice_index_type() -> None:
+    A = tir.alloc_buffer((1,), "float32")
+    B = tir.alloc_buffer((16, 16), "float32")
+    C = tir.alloc_buffer((16, 16), "float32")
+    with tir.block([16, 16]) as [vi, vj]:
+        C[vi, vj] = B[vi, A[0]]  # error
+
+
 def test_error_index_type():
     check_error(error_index_type, 4)
+    check_error(error_bufferslice_index_type, 6)
+
+
+def error_index_with_stop() -> None:
+    A = tir.alloc_buffer((128, 128), "float32")
+    with tir.block([16, 16]) as [vi, vj]:
+        A[vi, vj] = A[vi, 1:10] + 1  # error
+
+
+def error_bufferslice_index_with_stop() -> None:
+    A = tir.alloc_buffer((1,), "int32")
+    B = tir.alloc_buffer((16, 16), "float32")
+    C = tir.alloc_buffer((16, 16), "float32")
+    with tir.block([16, 16]) as [vi, vj]:
+        C[vi, vj] = B[vi, A[0:1]]  # error
+
+
+def test_error_index_with_stop_slice():
+    check_error(error_index_with_stop, 4)
+    check_error(error_bufferslice_index_with_stop, 6)
 
 
 def mismatch_args() -> None:
@@ -383,5 +411,6 @@ if __name__ == "__main__":
     test_opaque_access_during_complete()
     test_convert_slice_to_bufferload()
     test_error_index_type()
+    test_error_index_with_stop_slice()
     test_mismatch_args()
     test_tvm_exception_catch()

Reply via email to