This is an automated email from the ASF dual-hosted git repository.
syfeng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 429f601284 [TIR] Enhance TVMScript Buffer Slice Access (#14693)
429f601284 is described below
commit 429f60128423d01fdfdc42b37de687ebc011c846
Author: lightzhan <[email protected]>
AuthorDate: Sun Jun 4 17:10:50 2023 +0800
[TIR] Enhance TVMScript Buffer Slice Access (#14693)
---
python/tvm/script/parser/core/evaluator.py | 38 ++++++++++++
python/tvm/script/parser/tir/operation.py | 92 ++++++++++++++++++++++++++---
python/tvm/script/parser/tir/parser.py | 22 +++++++
python/tvm/tir/buffer.py | 5 +-
tests/python/unittest/test_tvmscript_ops.py | 69 ++++++++++++++++++++++
5 files changed, 217 insertions(+), 9 deletions(-)
diff --git a/python/tvm/script/parser/core/evaluator.py b/python/tvm/script/parser/core/evaluator.py
index 96901c522d..e2b67341dc 100644
--- a/python/tvm/script/parser/core/evaluator.py
+++ b/python/tvm/script/parser/core/evaluator.py
@@ -158,6 +158,44 @@ class ExprEvaluator:
res : Any
The evaluation result.
"""
+ args = []
+ if (
+ isinstance(node, doc.Call)
+ and hasattr(node.func, "attr")
+ and node.func.attr not in ["reads", "writes", "match_buffer",
"realize"]
+ ) or isinstance(node, (doc.BinOp, doc.UnaryOp, doc.Compare,
doc.BoolOp)):
+ if isinstance(node, doc.BinOp):
+ args = [node.left, node.right]
+ elif isinstance(node, doc.UnaryOp):
+ args = [node.operand]
+ elif isinstance(node, doc.Compare):
+ args = [node.left, *node.comparators]
+ else:
+ if isinstance(node, doc.Call):
+ args = node.args
+ elif isinstance(node, doc.BoolOp):
+ args = node.values
+ for arg in args:
+ if isinstance(arg, doc.Subscript) and isinstance(arg.slice,
(doc.Slice, doc.Tuple)):
+ if isinstance(arg.slice, doc.Slice):
+ check_slices = [arg.slice]
+ else:
+ check_slices = []
+ for p in arg.slice.elts:
+ if isinstance(p, doc.Slice):
+ check_slices.append(p)
+ for s in check_slices:
+ if not s.step and s.upper and s.lower:
+ s.step = doc.Constant(
+ 1,
+ None,
+ 1,
+ 1,
+ s.upper.lineno,
+ s.upper.end_col_offset + 1,
+ s.upper.lineno,
+ s.upper.end_col_offset + 2,
+ )
if isinstance(node, list):
return [self._visit(n) for n in node]
if isinstance(node, tuple):
diff --git a/python/tvm/script/parser/tir/operation.py b/python/tvm/script/parser/tir/operation.py
index 3e120339a6..ab01d91c02 100644
--- a/python/tvm/script/parser/tir/operation.py
+++ b/python/tvm/script/parser/tir/operation.py
@@ -19,7 +19,9 @@
from typing import Type
from tvm import tir
+from tvm._ffi.runtime_ctypes import DataType, DataTypeCode
from tvm.tir import IntImm
+from tvm.tir.expr import FloatImm
from .._core import OpMethod, doc, register_op
@@ -32,14 +34,88 @@ def _register_expr_op(ty: Type): # pylint: disable=invalid-name
a = IntImm("bool", a)
if isinstance(b, bool):
b = IntImm("bool", b)
- return tir.And(a, b)
+ if DataType(a.dtype).lanes > 1 or DataType(b.dtype).lanes > 1:
+ return a & b
+ else:
+ return tir.And(a, b)
def _or(a, b):
if isinstance(a, bool):
a = IntImm("bool", a)
if isinstance(b, bool):
b = IntImm("bool", b)
- return tir.Or(a, b)
+ if DataType(a.dtype).lanes > 1 or DataType(b.dtype).lanes > 1:
+ return a | b
+ else:
+ return tir.Or(a, b)
+
+ def _get_type_str(dtype: str):
+ if DataType(dtype).lanes == 1:
+ return dtype
+ index = dtype.find("x")
+ return dtype[0:index]
+
+ def _auto_broadcast(a, b, op):
+
+ if isinstance(a, int):
+ if hasattr(b, "dtype"):
+ if (
+ DataType(b.dtype).type_code == DataTypeCode.INT
+ or DataType(b.dtype).type_code == DataTypeCode.UINT
+ ):
+ a = IntImm(_get_type_str(b.dtype), a)
+ elif DataType(b.dtype).type_code == DataTypeCode.FLOAT:
+ a = FloatImm(_get_type_str(b.dtype), a)
+ elif isinstance(b, float):
+ a = FloatImm("float32", a)
+ else:
+ a = IntImm("int32", a)
+ elif isinstance(a, float):
+ if DataType(b.dtype).type_code == DataTypeCode.FLOAT:
+ a = FloatImm(_get_type_str(b.dtype), a)
+ else:
+ a = FloatImm("float32", a)
+
+ assert isinstance(a, tir.PrimExpr), "Operand should be a PrimExpr."
+ if isinstance(b, int):
+ if (
+ DataType(a.dtype).type_code == DataTypeCode.INT
+ or DataType(a.dtype).type_code == DataTypeCode.UINT
+ ):
+ b = IntImm(_get_type_str(a.dtype), b)
+ elif DataType(a.dtype).type_code == DataTypeCode.FLOAT:
+ b = FloatImm(_get_type_str(a.dtype), b)
+ elif isinstance(b, float):
+ b = FloatImm(_get_type_str(a.dtype), b)
+
+ if DataType(a.dtype).lanes == DataType(b.dtype).lanes:
+ return op(a, b)
+ elif DataType(a.dtype).lanes == 1 and DataType(a.dtype).lanes !=
DataType(b.dtype).lanes:
+ broadcast_a = tir.Broadcast(a, DataType(b.dtype).lanes)
+ return op(broadcast_a, b)
+ elif DataType(b.dtype).lanes == 1 and DataType(a.dtype).lanes !=
DataType(b.dtype).lanes:
+ broadcast_b = tir.Broadcast(b, DataType(a.dtype).lanes)
+ return op(a, broadcast_b)
+ else:
+ raise TypeError("do not know how to deal with it.")
+
+ def _eq(a, b):
+ return _auto_broadcast(a, b, tir.EQ)
+
+ def _ne(a, b):
+ return _auto_broadcast(a, b, tir.NE)
+
+ def _lt(a, b):
+ return _auto_broadcast(a, b, tir.LT)
+
+ def _le(a, b):
+ return _auto_broadcast(a, b, tir.LE)
+
+ def _gt(a, b):
+ return _auto_broadcast(a, b, tir.GT)
+
+ def _ge(a, b):
+ return _auto_broadcast(a, b, tir.GE)
def r(op: Type, i: int, m: OpMethod): # pylint: disable=invalid-name
register_op(ty, op, i)(m)
@@ -60,12 +136,12 @@ def _register_expr_op(ty: Type): # pylint: disable=invalid-name
# doc.MatMult <-- not implemented
# doc.Pow <-- not implemented
# Case 2. cmpop
- r(doc.Eq, i, tir.EQ)
- r(doc.NotEq, i, tir.NE)
- r(doc.Lt, i, tir.LT)
- r(doc.LtE, i, tir.LE)
- r(doc.Gt, i, tir.GT)
- r(doc.GtE, i, tir.GE)
+ r(doc.Eq, i, _eq)
+ r(doc.NotEq, i, _ne)
+ r(doc.Lt, i, _lt)
+ r(doc.LtE, i, _le)
+ r(doc.Gt, i, _gt)
+ r(doc.GtE, i, _ge)
# doc.Is <-- not implemented
# doc.IsNot <-- not implemented
# doc.In <-- not implemented
diff --git a/python/tvm/script/parser/tir/parser.py b/python/tvm/script/parser/tir/parser.py
index 7d81fecedb..f81f9bd9ea 100644
--- a/python/tvm/script/parser/tir/parser.py
+++ b/python/tvm/script/parser/tir/parser.py
@@ -211,6 +211,28 @@ def visit_assign(self: Parser, node: doc.Assign) -> None:
if len(node.targets) != 1:
self.report_error(node, "Consequential assignments like 'a = b = c'
are not supported.")
lhs = node.targets[0]
+
+ if isinstance(node.value, doc.Subscript):
+ check_slices = []
+ if isinstance(node.value.slice, doc.Slice):
+ check_slices = [node.value.slice]
+ elif isinstance(node.value.slice, doc.Tuple):
+ for p in node.value.slice.elts:
+ if isinstance(p, doc.Slice):
+ check_slices.append(p)
+ for s in check_slices:
+ if not s.step and s.upper and s.lower:
+ s.step = doc.Constant(
+ 1,
+ None,
+ 1,
+ 1,
+ s.upper.lineno,
+ s.upper.end_col_offset + 1,
+ s.upper.lineno,
+ s.upper.end_col_offset + 2,
+ )
+
rhs = self.eval_expr(node.value)
if isinstance(lhs, doc.Subscript):
if isinstance(lhs.slice, doc.Tuple):
diff --git a/python/tvm/tir/buffer.py b/python/tvm/tir/buffer.py
index 764b8a3dd3..ec57ad7801 100644
--- a/python/tvm/tir/buffer.py
+++ b/python/tvm/tir/buffer.py
@@ -203,11 +203,14 @@ class Buffer(Object, Scriptable):
return BufferRegion(self, region)
else:
expr_indices = []
- for index in indices:
+ for i, index in enumerate(indices):
if isinstance(index, slice):
start = 0 if index.start is None else index.start
stop = self.shape[i] if index.stop is None else index.stop
step = 1 if index.step is None else index.step
+ # We should ensure the dtype of start is the same with
that of step.
+ if isinstance(start, tvm.tir.expr.PrimExpr) and
isinstance(step, int):
+ step = tvm.tir.expr.IntImm(start.dtype, step)
lanes = analyzer.simplify((stop - start + step - 1) //
step)
if lanes == 1:
expr_indices.append(start)
diff --git a/tests/python/unittest/test_tvmscript_ops.py b/tests/python/unittest/test_tvmscript_ops.py
index 8eba301fe7..671fe3cc19 100644
--- a/tests/python/unittest/test_tvmscript_ops.py
+++ b/tests/python/unittest/test_tvmscript_ops.py
@@ -177,6 +177,75 @@ def test_ceildiv():
tvm.testing.assert_allclose(a.numpy(), ref)
+@T.prim_func
+def slice_op_test(
+ A: T.Buffer((10,), "float32"), B: T.Buffer((10,), "float32"), C:
T.Buffer((10,), "uint32")
+):
+ B[0:5] = A[0:5] + B[0:5]
+ B[0:5] = A[0:5] - B[0:5]
+ B[0:5] = A[0:5] * B[0:5]
+ B[0:5] = A[0:5] / B[0:5]
+ C[0:5] = C[0:5] % T.broadcast(T.uint32(5), 5)
+ B[0:5] = -B[0:5]
+ C[0:5] = C[0:5] >> 4
+ C[0:5] = C[0:5] << 4
+ C[0:5] = C[0:5] << C[0:5]
+ C[0:5] = C[0:5] >> C[0:5]
+ T.evaluate(A[0:5] > B[0:5])
+ T.evaluate(A[0:5] > 5)
+ T.evaluate(A[0:5] >= B[0:5])
+ T.evaluate(A[0:5] >= 5)
+ T.evaluate(A[0:5] < B[0:5])
+ T.evaluate(A[0:5] < 5)
+ T.evaluate(A[0:5] <= B[0:5])
+ T.evaluate(A[0:5] <= 5)
+ T.evaluate(A[0:5] == B[0:5])
+ T.evaluate(A[0:5] == 5)
+ T.evaluate(A[0:5] != B[0:5])
+ T.evaluate(A[0:5] != 5)
+ T.evaluate((A[0:5] > 0) and (B[0:5] > 0))
+ T.evaluate((A[0:5] > 0) or (B[0:5] > 0))
+ T.evaluate((A[0:5] < 0) and (1 > 0))
+ T.evaluate((A[0:5] > 0) or (1 > 0))
+
+
+@T.prim_func
+def slice_op_test_ref(
+ A: T.Buffer((10,), "float32"), B: T.Buffer((10,), "float32"), C:
T.Buffer((10,), "uint32")
+):
+ B[0:5] = A[0:5] + B[0:5]
+ B[0:5] = A[0:5] - B[0:5]
+ B[0:5] = A[0:5] * B[0:5]
+ B[0:5] = A[0:5] / B[0:5]
+ C[0:5] = C[0:5] % T.Broadcast(T.uint32(5), 5)
+ B[0:5] = B[0:5] * T.Broadcast(T.float32(-1), 5)
+ C[0:5] = T.shift_right(C[0:5], T.Broadcast(T.uint32(4), 5))
+ C[0:5] = T.shift_left(C[0:5], T.Broadcast(T.uint32(4), 5))
+ C[0:5] = T.shift_left(C[0:5], C[0:5])
+ C[0:5] = T.shift_right(C[0:5], C[0:5])
+ T.evaluate(A[0:5] > B[0:5])
+ T.evaluate(A[0:5] > T.Broadcast(T.float32(5), 5))
+ T.evaluate(A[0:5] >= B[0:5])
+ T.evaluate(A[0:5] >= T.Broadcast(T.float32(5), 5))
+ T.evaluate(A[0:5] < B[0:5])
+ T.evaluate(A[0:5] < T.Broadcast(T.float32(5), 5))
+ T.evaluate(A[0:5] <= B[0:5])
+ T.evaluate(A[0:5] <= T.Broadcast(T.float32(5), 5))
+ T.evaluate(A[0:5] == B[0:5])
+ T.evaluate(A[0:5] == T.Broadcast(T.float32(5), 5))
+ T.evaluate(A[0:5] != B[0:5])
+ T.evaluate(A[0:5] != T.Broadcast(T.float32(5), 5))
+ T.bitwise_and(A[0:5] > T.Broadcast(T.float32(0), 5), B[0:5] >
T.Broadcast(T.float32(0), 5))
+ T.bitwise_or(A[0:5] > T.Broadcast(T.float32(0), 5), B[0:5] >
T.Broadcast(T.float32(0), 5))
+ T.bitwise_and(A[0:5] < T.Broadcast(T.float32(0), 5),
T.Broadcast(T.bool(1), 5))
+ T.bitwise_or(A[0:5] > T.Broadcast(T.float32(0), 5), T.Broadcast(T.bool(1),
5))
+
+
+def test_slice_op():
+ tvm.ir.assert_structural_equal(slice_op_test, slice_op_test_ref)
+
+
if __name__ == "__main__":
test_get_valid_counts_script_func()
test_alloc_zero_dim_buffer_round_trip()
+ test_slice_op()