This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 920654c  [Bugfix][TVMScript] Convert BufferSlice to BufferLoad when used as range/loop start and end (#10370)
920654c is described below

commit 920654cf91584fedaf92ca8505620adf542cae71
Author: Zihao Ye <[email protected]>
AuthorDate: Fri Feb 25 13:49:05 2022 -0800

    [Bugfix][TVMScript] Convert BufferSlice to BufferLoad when used as range/loop start and end (#10370)
    
    A quick fix of the parser issue mentioned in #10327.
    Ranges and loops require `start` and `stop` to be `PrimExpr`s; however, a `BufferSlice` is not always a scalar, so it is not a `PrimExpr`.
    This PR converts such a `BufferSlice` into a `BufferLoad` before it is used as a range or loop bound.
---
 python/tvm/script/parser.py                       |  2 +-
 python/tvm/script/tir/node.py                     |  3 ++-
 python/tvm/script/tir/scope_handler.py            |  3 ++-
 src/printer/tvmscript_printer.cc                  |  7 ++++--
 tests/python/unittest/test_tvmscript_roundtrip.py | 26 +++++++++++++++++++++++
 5 files changed, 36 insertions(+), 5 deletions(-)

diff --git a/python/tvm/script/parser.py b/python/tvm/script/parser.py
index caf6bc4..922c654 100644
--- a/python/tvm/script/parser.py
+++ b/python/tvm/script/parser.py
@@ -676,7 +676,7 @@ class TVMScriptParser(Transformer):
         self.current_col_offset = node.span.start_column
         self.context.enter_scope(nodes=node.body.stmts)
         # for scope handler process the scope
-        arg_list = self.parse_arg_list(func, node.rhs)
+        arg_list = [tvm.runtime.convert(arg, span=node.rhs.span) for arg in self.parse_arg_list(func, node.rhs)]
         func.enter_scope(node, self.context, arg_list, node.rhs.func_name.span)
         func.body = self.parse_body(node)
         res = func.exit_scope(node, self.context, arg_list, node.rhs.func_name.span)
diff --git a/python/tvm/script/tir/node.py b/python/tvm/script/tir/node.py
index cfbc668..8564fc1 100644
--- a/python/tvm/script/tir/node.py
+++ b/python/tvm/script/tir/node.py
@@ -20,7 +20,7 @@
 from typing import Optional, Union, List, Callable
 import synr
 
-from tvm.runtime import ObjectGeneric
+from tvm.runtime import ObjectGeneric, convert
 from tvm.tir import PrimExpr, Buffer, BufferLoad
 from tvm.ir import Span
 
@@ -111,6 +111,7 @@ class BufferSlice(ObjectGeneric):
         slices: List[Union[Slice, BufferSlice]] = []
         for index in indices:
             if isinstance(index, Slice):
+                index.start, index.stop = [convert(_) for _ in [index.start, index.stop]]
                 check_index(index.start)
                 check_index(index.stop)
                 slices.append(index)
diff --git a/python/tvm/script/tir/scope_handler.py b/python/tvm/script/tir/scope_handler.py
index 418bae1..07ba204 100644
--- a/python/tvm/script/tir/scope_handler.py
+++ b/python/tvm/script/tir/scope_handler.py
@@ -21,7 +21,7 @@ from typing import Tuple, Any, Callable, Optional, List, Union, Mapping
 import synr
 import numpy as np
 import tvm.tir
-from tvm.runtime import Object, String
+from tvm.runtime import Object, String, convert
 from tvm.ir import Span, Range
 from tvm.tir import Stmt, PrimExpr, IterVar, Var, Buffer, BufferRegion, ForKind
 
@@ -532,6 +532,7 @@ class ForScopeHandler(ScopeHandler):
         for : For
             The constructed For.
         """
+        begin, end = [convert(_) for _ in [begin, end]]
         assert self.context and self.node, "call 'exit_scope' before 'enter_scope'"
         extent = end if begin == 0 else self.context.analyzer.simplify(end - begin)
         self.annotations: Mapping[str, Object] = {}
diff --git a/src/printer/tvmscript_printer.cc b/src/printer/tvmscript_printer.cc
index fe85eb3..e1ccd2f 100644
--- a/src/printer/tvmscript_printer.cc
+++ b/src/printer/tvmscript_printer.cc
@@ -22,6 +22,7 @@
  * \brief Printer class to print Tensor IR to python syntax script
  */
 
+#include <tvm/arith/analyzer.h>
 #include <tvm/ir/module.h>
 #include <tvm/node/serialization.h>
 #include <tvm/runtime/registry.h>
@@ -198,6 +199,8 @@ class TVMScriptPrinter : public StmtFunctor<Doc(const Stmt&)>,
    * than in the header.
    */
   Map<Var, Array<Buffer>> buffer_var_usage_;
+  /*! \brief Analyzer to simplify some expressions. */
+  arith::Analyzer ana_;
 
   Doc VisitExpr_(const CastNode* op, ExprPrecedence* out_precedence) override;
   Doc VisitExpr_(const VarNode* op, ExprPrecedence* out_precedence) override;
@@ -1607,7 +1610,7 @@ Doc TVMScriptPrinter::PrintBufferRegion(const BufferRegionNode* op) {
       if (i != 0) doc << ", ";
       const auto& range = op->region[i];
       if (!is_one(range->extent)) {
-        doc << Print(range->min) << " : " << Print(range->min + range->extent);
+        doc << Print(range->min) << " : " << Print(ana_.Simplify(range->min + range->extent));
       } else {
         doc << Print(range->min);
       }
@@ -1641,7 +1644,7 @@ Doc TVMScriptPrinter::PrintLoop(const For& loop) {
   if (is_zero(loop->min)) {
     res << Print(loop->extent);
   } else {
-    res << Print(loop->min) << ", " << Print(loop->min + loop->extent);
+    res << Print(loop->min) << ", " << Print(ana_.Simplify(loop->min + loop->extent));
   }
   if (loop->thread_binding.defined()) {
     res << ", thread=";
diff --git a/tests/python/unittest/test_tvmscript_roundtrip.py b/tests/python/unittest/test_tvmscript_roundtrip.py
index 1633c05..e3a70bb 100644
--- a/tests/python/unittest/test_tvmscript_roundtrip.py
+++ b/tests/python/unittest/test_tvmscript_roundtrip.py
@@ -3374,5 +3374,31 @@ def test_T_ptr_allocate():
     tvm.ir.assert_structural_equal(func, rt_func, True)
 
 
+@T.prim_func
+def segment_sum(
+    A_ptr: T.handle, B_ptr: T.handle, indptr_ptr: T.handle, n: T.int32, m: T.int32
+) -> None:
+    A = T.match_buffer(A_ptr, [m], dtype="float32")
+    B = T.match_buffer(B_ptr, [n], dtype="float32")
+    indptr = T.match_buffer(indptr_ptr, [n + 1], dtype="int32")
+    for i in T.serial(n):
+        with T.block("outer"):
+            vi = T.axis.spatial(n, i)
+            T.reads(indptr[i : i + 2], B[vi], A[indptr[i] : indptr[i + 1]])
+            T.writes(B[vi])
+            for j in T.serial(indptr[i], indptr[i + 1]):
+                with T.block("inner"):
+                    vj = T.axis.reduce(m, j)
+                    T.reads(B[vi], A[vj])
+                    T.writes(B[vi])
+                    with T.init():
+                        B[vi] = T.float32(0)
+                    B[vi] = B[vi] + A[vj]
+
+
+def test_parse_bufferslice_as_range_bound():
+    tvm.ir.assert_structural_equal(segment_sum, tvm.script.from_source(segment_sum.script()))
+
+
 if __name__ == "__main__":
     sys.exit(pytest.main([__file__] + sys.argv[1:]))

Reply via email to