This is an automated email from the ASF dual-hosted git repository.
wuwei pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 8e27d6c fix error report on Store (#8895)
8e27d6c is described below
commit 8e27d6c18f3cde541ed065009faf2fc43902cdea
Author: Siyuan Feng <[email protected]>
AuthorDate: Thu Sep 2 05:49:12 2021 +0800
fix error report on Store (#8895)
---
python/tvm/script/parser.py | 2 +-
.../python/unittest/test_tvmscript_error_report.py | 43 ++++++++--------------
2 files changed, 16 insertions(+), 29 deletions(-)
diff --git a/python/tvm/script/parser.py b/python/tvm/script/parser.py
index 9acf21b..60fc496 100644
--- a/python/tvm/script/parser.py
+++ b/python/tvm/script/parser.py
@@ -536,7 +536,7 @@ class TVMScriptParser(Transformer):
if len(indexes) != 1:
self.report_error(
f"Store is only allowed with one index, but {len(indexes)} were provided.",
- tvm.ir.Span.union([x.span for x in indexes]),
+ node.params[1].span,
)
# Store
return tvm.tir.Store(
diff --git a/tests/python/unittest/test_tvmscript_error_report.py b/tests/python/unittest/test_tvmscript_error_report.py
index 7aeceec..70a2aea 100644
--- a/tests/python/unittest/test_tvmscript_error_report.py
+++ b/tests/python/unittest/test_tvmscript_error_report.py
@@ -14,6 +14,9 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+
+import pytest
+import sys
import tvm
from tvm import tir
from tvm.script import ty, from_source
@@ -380,6 +383,17 @@ def test_match_buffer_shape_mismatch():
check_error(buffer_shape_mismatch, 7)
+def high_dim_store() -> None:
+ with tir.block([], "root"):
+ B = tir.allocate([256], "float32", "global")
+ for i, j in tir.grid(16, 16):
+ B[i, j] = 1.0 # error: Store is only allowed with one index
+
+
+def test_high_dim_store():
+ check_error(high_dim_store, 5)
+
+
def check_error(module, rel_lineno):
# Override the default renderer to accumulate errors
_, start_line = inspect.getsourcelines(module)
@@ -404,31 +418,4 @@ def check_error(module, rel_lineno):
if __name__ == "__main__":
- test_buffer_bind()
- test_range_missing_args()
- test_undefined_buffer()
- test_unsupported_stmt()
- test_unsupported_function_call()
- test_missing_type_annotation()
- test_invalid_expr_stmt()
- test_invalid_for_function()
- test_invalid_block_function()
- test_return_not_allowed()
- test_tir_assert()
- test_no_body()
- test_allocate_with_buffers()
- test_inconsistent_binding()
- test_invalid_block_axes()
- test_miss_block_bind()
- test_invalid_loop_var()
- test_inconsistent_grid()
- test_invalid_match_buffer_region()
- test_duplicate_buffer()
- test_duplicate_block_signature()
- test_opaque_access_during_complete()
- test_convert_slice_to_bufferload()
- test_error_index_type()
- test_error_index_with_stop_slice()
- test_mismatch_args()
- test_tvm_exception_catch()
- test_match_buffer_shape_mismatch()
+ sys.exit(pytest.main([__file__] + sys.argv[1:]))