This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 920654c [Bugfix][TVMScript] Convert BufferSlice to BufferLoad when
used as range/loop start and end (#10370)
920654c is described below
commit 920654cf91584fedaf92ca8505620adf542cae71
Author: Zihao Ye <[email protected]>
AuthorDate: Fri Feb 25 13:49:05 2022 -0800
[Bugfix][TVMScript] Convert BufferSlice to BufferLoad when used as
range/loop start and end (#10370)
A quick fix for the parser issue reported in #10327.
Ranges and loops require `start` and `stop` to be `PrimExpr`; however, a
`BufferSlice` is not always scalar, so it is not a `PrimExpr`.
This PR converts a `BufferSlice` to a `BufferLoad` when it is used as a
range/loop start or end.
---
python/tvm/script/parser.py | 2 +-
python/tvm/script/tir/node.py | 3 ++-
python/tvm/script/tir/scope_handler.py | 3 ++-
src/printer/tvmscript_printer.cc | 7 ++++--
tests/python/unittest/test_tvmscript_roundtrip.py | 26 +++++++++++++++++++++++
5 files changed, 36 insertions(+), 5 deletions(-)
diff --git a/python/tvm/script/parser.py b/python/tvm/script/parser.py
index caf6bc4..922c654 100644
--- a/python/tvm/script/parser.py
+++ b/python/tvm/script/parser.py
@@ -676,7 +676,7 @@ class TVMScriptParser(Transformer):
self.current_col_offset = node.span.start_column
self.context.enter_scope(nodes=node.body.stmts)
# for scope handler process the scope
- arg_list = self.parse_arg_list(func, node.rhs)
+ arg_list = [tvm.runtime.convert(arg, span=node.rhs.span) for arg in
self.parse_arg_list(func, node.rhs)]
func.enter_scope(node, self.context, arg_list, node.rhs.func_name.span)
func.body = self.parse_body(node)
res = func.exit_scope(node, self.context, arg_list,
node.rhs.func_name.span)
diff --git a/python/tvm/script/tir/node.py b/python/tvm/script/tir/node.py
index cfbc668..8564fc1 100644
--- a/python/tvm/script/tir/node.py
+++ b/python/tvm/script/tir/node.py
@@ -20,7 +20,7 @@
from typing import Optional, Union, List, Callable
import synr
-from tvm.runtime import ObjectGeneric
+from tvm.runtime import ObjectGeneric, convert
from tvm.tir import PrimExpr, Buffer, BufferLoad
from tvm.ir import Span
@@ -111,6 +111,7 @@ class BufferSlice(ObjectGeneric):
slices: List[Union[Slice, BufferSlice]] = []
for index in indices:
if isinstance(index, Slice):
+ index.start, index.stop = [convert(_) for _ in [index.start,
index.stop]]
check_index(index.start)
check_index(index.stop)
slices.append(index)
diff --git a/python/tvm/script/tir/scope_handler.py
b/python/tvm/script/tir/scope_handler.py
index 418bae1..07ba204 100644
--- a/python/tvm/script/tir/scope_handler.py
+++ b/python/tvm/script/tir/scope_handler.py
@@ -21,7 +21,7 @@ from typing import Tuple, Any, Callable, Optional, List,
Union, Mapping
import synr
import numpy as np
import tvm.tir
-from tvm.runtime import Object, String
+from tvm.runtime import Object, String, convert
from tvm.ir import Span, Range
from tvm.tir import Stmt, PrimExpr, IterVar, Var, Buffer, BufferRegion, ForKind
@@ -532,6 +532,7 @@ class ForScopeHandler(ScopeHandler):
for : For
The constructed For.
"""
+ begin, end = [convert(_) for _ in [begin, end]]
assert self.context and self.node, "call 'exit_scope' before
'enter_scope'"
extent = end if begin == 0 else self.context.analyzer.simplify(end -
begin)
self.annotations: Mapping[str, Object] = {}
diff --git a/src/printer/tvmscript_printer.cc b/src/printer/tvmscript_printer.cc
index fe85eb3..e1ccd2f 100644
--- a/src/printer/tvmscript_printer.cc
+++ b/src/printer/tvmscript_printer.cc
@@ -22,6 +22,7 @@
* \brief Printer class to print Tensor IR to python syntax script
*/
+#include <tvm/arith/analyzer.h>
#include <tvm/ir/module.h>
#include <tvm/node/serialization.h>
#include <tvm/runtime/registry.h>
@@ -198,6 +199,8 @@ class TVMScriptPrinter : public StmtFunctor<Doc(const
Stmt&)>,
* than in the header.
*/
Map<Var, Array<Buffer>> buffer_var_usage_;
+ /*! \brief Analyzer to simplify some expressions. */
+ arith::Analyzer ana_;
Doc VisitExpr_(const CastNode* op, ExprPrecedence* out_precedence) override;
Doc VisitExpr_(const VarNode* op, ExprPrecedence* out_precedence) override;
@@ -1607,7 +1610,7 @@ Doc TVMScriptPrinter::PrintBufferRegion(const
BufferRegionNode* op) {
if (i != 0) doc << ", ";
const auto& range = op->region[i];
if (!is_one(range->extent)) {
- doc << Print(range->min) << " : " << Print(range->min + range->extent);
+ doc << Print(range->min) << " : " << Print(ana_.Simplify(range->min +
range->extent));
} else {
doc << Print(range->min);
}
@@ -1641,7 +1644,7 @@ Doc TVMScriptPrinter::PrintLoop(const For& loop) {
if (is_zero(loop->min)) {
res << Print(loop->extent);
} else {
- res << Print(loop->min) << ", " << Print(loop->min + loop->extent);
+ res << Print(loop->min) << ", " << Print(ana_.Simplify(loop->min +
loop->extent));
}
if (loop->thread_binding.defined()) {
res << ", thread=";
diff --git a/tests/python/unittest/test_tvmscript_roundtrip.py
b/tests/python/unittest/test_tvmscript_roundtrip.py
index 1633c05..e3a70bb 100644
--- a/tests/python/unittest/test_tvmscript_roundtrip.py
+++ b/tests/python/unittest/test_tvmscript_roundtrip.py
@@ -3374,5 +3374,31 @@ def test_T_ptr_allocate():
tvm.ir.assert_structural_equal(func, rt_func, True)
+@T.prim_func
+def segment_sum(
+ A_ptr: T.handle, B_ptr: T.handle, indptr_ptr: T.handle, n: T.int32, m:
T.int32
+) -> None:
+ A = T.match_buffer(A_ptr, [m], dtype="float32")
+ B = T.match_buffer(B_ptr, [n], dtype="float32")
+ indptr = T.match_buffer(indptr_ptr, [n + 1], dtype="int32")
+ for i in T.serial(n):
+ with T.block("outer"):
+ vi = T.axis.spatial(n, i)
+ T.reads(indptr[i : i + 2], B[vi], A[indptr[i] : indptr[i + 1]])
+ T.writes(B[vi])
+ for j in T.serial(indptr[i], indptr[i + 1]):
+ with T.block("inner"):
+ vj = T.axis.reduce(m, j)
+ T.reads(B[vi], A[vj])
+ T.writes(B[vi])
+ with T.init():
+ B[vi] = T.float32(0)
+ B[vi] = B[vi] + A[vj]
+
+
+def test_parse_bufferslice_as_range_bound():
+ tvm.ir.assert_structural_equal(segment_sum,
tvm.script.from_source(segment_sum.script()))
+
+
if __name__ == "__main__":
sys.exit(pytest.main([__file__] + sys.argv[1:]))