This is an automated email from the ASF dual-hosted git repository.

syfeng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 429f601284 [TIR] Enhance TVMScript Buffer Slice Access (#14693)
429f601284 is described below

commit 429f60128423d01fdfdc42b37de687ebc011c846
Author: lightzhan <[email protected]>
AuthorDate: Sun Jun 4 17:10:50 2023 +0800

    [TIR] Enhance TVMScript Buffer Slice Access (#14693)
---
 python/tvm/script/parser/core/evaluator.py  | 38 ++++++++++++
 python/tvm/script/parser/tir/operation.py   | 92 ++++++++++++++++++++++++++---
 python/tvm/script/parser/tir/parser.py      | 22 +++++++
 python/tvm/tir/buffer.py                    |  5 +-
 tests/python/unittest/test_tvmscript_ops.py | 69 ++++++++++++++++++++++
 5 files changed, 217 insertions(+), 9 deletions(-)

diff --git a/python/tvm/script/parser/core/evaluator.py b/python/tvm/script/parser/core/evaluator.py
index 96901c522d..e2b67341dc 100644
--- a/python/tvm/script/parser/core/evaluator.py
+++ b/python/tvm/script/parser/core/evaluator.py
@@ -158,6 +158,44 @@ class ExprEvaluator:
         res : Any
             The evaluation result.
         """
+        args = []
+        if (
+            isinstance(node, doc.Call)
+            and hasattr(node.func, "attr")
+            and node.func.attr not in ["reads", "writes", "match_buffer", "realize"]
+        ) or isinstance(node, (doc.BinOp, doc.UnaryOp, doc.Compare, doc.BoolOp)):
+            if isinstance(node, doc.BinOp):
+                args = [node.left, node.right]
+            elif isinstance(node, doc.UnaryOp):
+                args = [node.operand]
+            elif isinstance(node, doc.Compare):
+                args = [node.left, *node.comparators]
+            else:
+                if isinstance(node, doc.Call):
+                    args = node.args
+                elif isinstance(node, doc.BoolOp):
+                    args = node.values
+        for arg in args:
+            if isinstance(arg, doc.Subscript) and isinstance(arg.slice, (doc.Slice, doc.Tuple)):
+                if isinstance(arg.slice, doc.Slice):
+                    check_slices = [arg.slice]
+                else:
+                    check_slices = []
+                    for p in arg.slice.elts:
+                        if isinstance(p, doc.Slice):
+                            check_slices.append(p)
+                for s in check_slices:
+                    if not s.step and s.upper and s.lower:
+                        s.step = doc.Constant(
+                            1,
+                            None,
+                            1,
+                            1,
+                            s.upper.lineno,
+                            s.upper.end_col_offset + 1,
+                            s.upper.lineno,
+                            s.upper.end_col_offset + 2,
+                        )
         if isinstance(node, list):
             return [self._visit(n) for n in node]
         if isinstance(node, tuple):
diff --git a/python/tvm/script/parser/tir/operation.py b/python/tvm/script/parser/tir/operation.py
index 3e120339a6..ab01d91c02 100644
--- a/python/tvm/script/parser/tir/operation.py
+++ b/python/tvm/script/parser/tir/operation.py
@@ -19,7 +19,9 @@
 from typing import Type
 
 from tvm import tir
+from tvm._ffi.runtime_ctypes import DataType, DataTypeCode
 from tvm.tir import IntImm
+from tvm.tir.expr import FloatImm
 
 from .._core import OpMethod, doc, register_op
 
@@ -32,14 +34,88 @@ def _register_expr_op(ty: Type):  # pylint: disable=invalid-name
             a = IntImm("bool", a)
         if isinstance(b, bool):
             b = IntImm("bool", b)
-        return tir.And(a, b)
+        if DataType(a.dtype).lanes > 1 or DataType(b.dtype).lanes > 1:
+            return a & b
+        else:
+            return tir.And(a, b)
 
     def _or(a, b):
         if isinstance(a, bool):
             a = IntImm("bool", a)
         if isinstance(b, bool):
             b = IntImm("bool", b)
-        return tir.Or(a, b)
+        if DataType(a.dtype).lanes > 1 or DataType(b.dtype).lanes > 1:
+            return a | b
+        else:
+            return tir.Or(a, b)
+
+    def _get_type_str(dtype: str):
+        if DataType(dtype).lanes == 1:
+            return dtype
+        index = dtype.find("x")
+        return dtype[0:index]
+
+    def _auto_broadcast(a, b, op):
+
+        if isinstance(a, int):
+            if hasattr(b, "dtype"):
+                if (
+                    DataType(b.dtype).type_code == DataTypeCode.INT
+                    or DataType(b.dtype).type_code == DataTypeCode.UINT
+                ):
+                    a = IntImm(_get_type_str(b.dtype), a)
+                elif DataType(b.dtype).type_code == DataTypeCode.FLOAT:
+                    a = FloatImm(_get_type_str(b.dtype), a)
+            elif isinstance(b, float):
+                a = FloatImm("float32", a)
+            else:
+                a = IntImm("int32", a)
+        elif isinstance(a, float):
+            if DataType(b.dtype).type_code == DataTypeCode.FLOAT:
+                a = FloatImm(_get_type_str(b.dtype), a)
+            else:
+                a = FloatImm("float32", a)
+
+        assert isinstance(a, tir.PrimExpr), "Operand should be a PrimExpr."
+        if isinstance(b, int):
+            if (
+                DataType(a.dtype).type_code == DataTypeCode.INT
+                or DataType(a.dtype).type_code == DataTypeCode.UINT
+            ):
+                b = IntImm(_get_type_str(a.dtype), b)
+            elif DataType(a.dtype).type_code == DataTypeCode.FLOAT:
+                b = FloatImm(_get_type_str(a.dtype), b)
+        elif isinstance(b, float):
+            b = FloatImm(_get_type_str(a.dtype), b)
+
+        if DataType(a.dtype).lanes == DataType(b.dtype).lanes:
+            return op(a, b)
+        elif DataType(a.dtype).lanes == 1 and DataType(a.dtype).lanes != DataType(b.dtype).lanes:
+            broadcast_a = tir.Broadcast(a, DataType(b.dtype).lanes)
+            return op(broadcast_a, b)
+        elif DataType(b.dtype).lanes == 1 and DataType(a.dtype).lanes != DataType(b.dtype).lanes:
+            broadcast_b = tir.Broadcast(b, DataType(a.dtype).lanes)
+            return op(a, broadcast_b)
+        else:
+            raise TypeError("do not know how to deal with it.")
+
+    def _eq(a, b):
+        return _auto_broadcast(a, b, tir.EQ)
+
+    def _ne(a, b):
+        return _auto_broadcast(a, b, tir.NE)
+
+    def _lt(a, b):
+        return _auto_broadcast(a, b, tir.LT)
+
+    def _le(a, b):
+        return _auto_broadcast(a, b, tir.LE)
+
+    def _gt(a, b):
+        return _auto_broadcast(a, b, tir.GT)
+
+    def _ge(a, b):
+        return _auto_broadcast(a, b, tir.GE)
 
     def r(op: Type, i: int, m: OpMethod):  # pylint: disable=invalid-name
         register_op(ty, op, i)(m)
@@ -60,12 +136,12 @@ def _register_expr_op(ty: Type):  # pylint: disable=invalid-name
         # doc.MatMult <-- not implemented
         # doc.Pow <-- not implemented
         # Case 2. cmpop
-        r(doc.Eq, i, tir.EQ)
-        r(doc.NotEq, i, tir.NE)
-        r(doc.Lt, i, tir.LT)
-        r(doc.LtE, i, tir.LE)
-        r(doc.Gt, i, tir.GT)
-        r(doc.GtE, i, tir.GE)
+        r(doc.Eq, i, _eq)
+        r(doc.NotEq, i, _ne)
+        r(doc.Lt, i, _lt)
+        r(doc.LtE, i, _le)
+        r(doc.Gt, i, _gt)
+        r(doc.GtE, i, _ge)
         # doc.Is <-- not implemented
         # doc.IsNot <-- not implemented
         # doc.In <-- not implemented
diff --git a/python/tvm/script/parser/tir/parser.py b/python/tvm/script/parser/tir/parser.py
index 7d81fecedb..f81f9bd9ea 100644
--- a/python/tvm/script/parser/tir/parser.py
+++ b/python/tvm/script/parser/tir/parser.py
@@ -211,6 +211,28 @@ def visit_assign(self: Parser, node: doc.Assign) -> None:
     if len(node.targets) != 1:
         self.report_error(node, "Consequential assignments like 'a = b = c' are not supported.")
     lhs = node.targets[0]
+
+    if isinstance(node.value, doc.Subscript):
+        check_slices = []
+        if isinstance(node.value.slice, doc.Slice):
+            check_slices = [node.value.slice]
+        elif isinstance(node.value.slice, doc.Tuple):
+            for p in node.value.slice.elts:
+                if isinstance(p, doc.Slice):
+                    check_slices.append(p)
+        for s in check_slices:
+            if not s.step and s.upper and s.lower:
+                s.step = doc.Constant(
+                    1,
+                    None,
+                    1,
+                    1,
+                    s.upper.lineno,
+                    s.upper.end_col_offset + 1,
+                    s.upper.lineno,
+                    s.upper.end_col_offset + 2,
+                )
+
     rhs = self.eval_expr(node.value)
     if isinstance(lhs, doc.Subscript):
         if isinstance(lhs.slice, doc.Tuple):
diff --git a/python/tvm/tir/buffer.py b/python/tvm/tir/buffer.py
index 764b8a3dd3..ec57ad7801 100644
--- a/python/tvm/tir/buffer.py
+++ b/python/tvm/tir/buffer.py
@@ -203,11 +203,14 @@ class Buffer(Object, Scriptable):
             return BufferRegion(self, region)
         else:
             expr_indices = []
-            for index in indices:
+            for i, index in enumerate(indices):
                 if isinstance(index, slice):
                     start = 0 if index.start is None else index.start
                     stop = self.shape[i] if index.stop is None else index.stop
                     step = 1 if index.step is None else index.step
+                    # We should ensure the dtype of start is the same with that of step.
+                    if isinstance(start, tvm.tir.expr.PrimExpr) and isinstance(step, int):
+                        step = tvm.tir.expr.IntImm(start.dtype, step)
                     lanes = analyzer.simplify((stop - start + step - 1) // step)
                     if lanes == 1:
                         expr_indices.append(start)
diff --git a/tests/python/unittest/test_tvmscript_ops.py b/tests/python/unittest/test_tvmscript_ops.py
index 8eba301fe7..671fe3cc19 100644
--- a/tests/python/unittest/test_tvmscript_ops.py
+++ b/tests/python/unittest/test_tvmscript_ops.py
@@ -177,6 +177,75 @@ def test_ceildiv():
     tvm.testing.assert_allclose(a.numpy(), ref)
 
 
+@T.prim_func
+def slice_op_test(
+    A: T.Buffer((10,), "float32"), B: T.Buffer((10,), "float32"), C: T.Buffer((10,), "uint32")
+):
+    B[0:5] = A[0:5] + B[0:5]
+    B[0:5] = A[0:5] - B[0:5]
+    B[0:5] = A[0:5] * B[0:5]
+    B[0:5] = A[0:5] / B[0:5]
+    C[0:5] = C[0:5] % T.broadcast(T.uint32(5), 5)
+    B[0:5] = -B[0:5]
+    C[0:5] = C[0:5] >> 4
+    C[0:5] = C[0:5] << 4
+    C[0:5] = C[0:5] << C[0:5]
+    C[0:5] = C[0:5] >> C[0:5]
+    T.evaluate(A[0:5] > B[0:5])
+    T.evaluate(A[0:5] > 5)
+    T.evaluate(A[0:5] >= B[0:5])
+    T.evaluate(A[0:5] >= 5)
+    T.evaluate(A[0:5] < B[0:5])
+    T.evaluate(A[0:5] < 5)
+    T.evaluate(A[0:5] <= B[0:5])
+    T.evaluate(A[0:5] <= 5)
+    T.evaluate(A[0:5] == B[0:5])
+    T.evaluate(A[0:5] == 5)
+    T.evaluate(A[0:5] != B[0:5])
+    T.evaluate(A[0:5] != 5)
+    T.evaluate((A[0:5] > 0) and (B[0:5] > 0))
+    T.evaluate((A[0:5] > 0) or (B[0:5] > 0))
+    T.evaluate((A[0:5] < 0) and (1 > 0))
+    T.evaluate((A[0:5] > 0) or (1 > 0))
+
+
+@T.prim_func
+def slice_op_test_ref(
+    A: T.Buffer((10,), "float32"), B: T.Buffer((10,), "float32"), C: T.Buffer((10,), "uint32")
+):
+    B[0:5] = A[0:5] + B[0:5]
+    B[0:5] = A[0:5] - B[0:5]
+    B[0:5] = A[0:5] * B[0:5]
+    B[0:5] = A[0:5] / B[0:5]
+    C[0:5] = C[0:5] % T.Broadcast(T.uint32(5), 5)
+    B[0:5] = B[0:5] * T.Broadcast(T.float32(-1), 5)
+    C[0:5] = T.shift_right(C[0:5], T.Broadcast(T.uint32(4), 5))
+    C[0:5] = T.shift_left(C[0:5], T.Broadcast(T.uint32(4), 5))
+    C[0:5] = T.shift_left(C[0:5], C[0:5])
+    C[0:5] = T.shift_right(C[0:5], C[0:5])
+    T.evaluate(A[0:5] > B[0:5])
+    T.evaluate(A[0:5] > T.Broadcast(T.float32(5), 5))
+    T.evaluate(A[0:5] >= B[0:5])
+    T.evaluate(A[0:5] >= T.Broadcast(T.float32(5), 5))
+    T.evaluate(A[0:5] < B[0:5])
+    T.evaluate(A[0:5] < T.Broadcast(T.float32(5), 5))
+    T.evaluate(A[0:5] <= B[0:5])
+    T.evaluate(A[0:5] <= T.Broadcast(T.float32(5), 5))
+    T.evaluate(A[0:5] == B[0:5])
+    T.evaluate(A[0:5] == T.Broadcast(T.float32(5), 5))
+    T.evaluate(A[0:5] != B[0:5])
+    T.evaluate(A[0:5] != T.Broadcast(T.float32(5), 5))
+    T.bitwise_and(A[0:5] > T.Broadcast(T.float32(0), 5), B[0:5] > T.Broadcast(T.float32(0), 5))
+    T.bitwise_or(A[0:5] > T.Broadcast(T.float32(0), 5), B[0:5] > T.Broadcast(T.float32(0), 5))
+    T.bitwise_and(A[0:5] < T.Broadcast(T.float32(0), 5), T.Broadcast(T.bool(1), 5))
+    T.bitwise_or(A[0:5] > T.Broadcast(T.float32(0), 5), T.Broadcast(T.bool(1), 5))
+
+
+def test_slice_op():
+    tvm.ir.assert_structural_equal(slice_op_test, slice_op_test_ref)
+
+
 if __name__ == "__main__":
     test_get_valid_counts_script_func()
     test_alloc_zero_dim_buffer_round_trip()
+    test_slice_op()

Reply via email to