This is an automated email from the ASF dual-hosted git repository.
mehrdadh pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new b4d4b82dbb [Hexagon] Fix TIR vrmpy tensorization (#13404)
b4d4b82dbb is described below
commit b4d4b82dbb9be2e4d0954f9dfd8e1c46079b66ee
Author: masahi <[email protected]>
AuthorDate: Thu Nov 17 01:52:14 2022 +0900
[Hexagon] Fix TIR vrmpy tensorization (#13404)
[Hexagon] Fix vrmpy tensorization
---
python/tvm/tir/tensor_intrin/hexagon.py | 4 ----
.../python/unittest/test_tir_schedule_tensorize.py | 26 +++++++++++++++++++---
2 files changed, 23 insertions(+), 7 deletions(-)
diff --git a/python/tvm/tir/tensor_intrin/hexagon.py b/python/tvm/tir/tensor_intrin/hexagon.py
index 6fa9dd8f00..306c8cd2e1 100644
--- a/python/tvm/tir/tensor_intrin/hexagon.py
+++ b/python/tvm/tir/tensor_intrin/hexagon.py
@@ -32,8 +32,6 @@ def dot_product_32x4_u8u8i32_desc(
for i in T.serial(0, 32):
for k in T.serial(0, 4):
with T.block("update"):
- with T.init():
- C[i] = T.int32(0)
vi, vk = T.axis.remap("SR", [i, k])
C[vi] = C[vi] + T.cast(A[vk], "int32") * T.cast(B[vi, vk], "int32")
@@ -76,8 +74,6 @@ def dot_product_32x4_u8i8i32_desc(
for i in T.serial(0, 32):
for k in T.serial(0, 4):
with T.block("update"):
- with T.init():
- C[i] = T.int32(0)
vi, vk = T.axis.remap("SR", [i, k])
C[vi] = C[vi] + T.cast(A[vk], "int32") * T.cast(B[vi, vk], "int32")
diff --git a/tests/python/unittest/test_tir_schedule_tensorize.py b/tests/python/unittest/test_tir_schedule_tensorize.py
index f30e91b892..0129cee532 100644
--- a/tests/python/unittest/test_tir_schedule_tensorize.py
+++ b/tests/python/unittest/test_tir_schedule_tensorize.py
@@ -30,6 +30,7 @@ from tvm.tir.tensor_intrin.arm_cpu import (
)
from tvm.tir.tensor_intrin.rocm import AMDGPU_SDOT4_INTRIN
from tvm.tir.tensor_intrin.x86 import VNNI_DOT_16x4_INTRIN
+from tvm.tir.tensor_intrin.hexagon import VRMPY_u8u8i32_INTRIN
# fmt: off
# pylint: disable=no-member,invalid-name,unused-variable,line-too-long,redefined-outer-name,unexpected-keyword-arg,too-many-nested-blocks
@@ -539,9 +540,9 @@ def test_tensorize_with_annotation():
verify_trace_roundtrip(sch=s, mod=func)
-def get_matmul_packed(m, n, k, lhs_type, int32_lanes):
+def get_matmul_packed(m, n, k, lhs_type, int32_lanes, rhs_dtype="int8"):
X = te.placeholder((m, k), name="X", dtype=lhs_type)
- packed_W = te.placeholder((n // int32_lanes, k // 4, int32_lanes, 4), name="packedW", dtype="int8")
+ packed_W = te.placeholder((n // int32_lanes, k // 4, int32_lanes, 4), name="packedW", dtype=rhs_dtype)
ak = te.reduce_axis((0, k), name="k")
matmul = te.compute(
@@ -549,7 +550,7 @@ def get_matmul_packed(m, n, k, lhs_type, int32_lanes):
lambda i, j: te.sum(
X[i, ak].astype("int32")
* packed_W[
- tvm.tir.indexdiv(j, 16), tvm.tir.indexdiv(ak, 4), j % 16, ak % 4
+ tvm.tir.indexdiv(j, int32_lanes), tvm.tir.indexdiv(ak, 4), j % int32_lanes, ak % 4
].astype("int32"),
axis=ak,
),
@@ -598,6 +599,25 @@ def test_tensorize_arm_dot():
verify_trace_roundtrip(sch=sch, mod=func)
+def test_tensorize_vrmpy():
+ m, n, k = 128, 128, 128
+
+ func = get_matmul_packed(m, n, k, "uint8", 32, "uint8")
+
+ sch = tir.Schedule(func, debug_mask="all")
+ block = sch.get_block("compute")
+ _, j, k = sch.get_loops(block)
+
+ _, ji = sch.split(j, factors=[None, 32])
+ ko, ki = sch.split(k, factors=[None, 4])
+ sch.reorder(ko, ji, ki)
+
+ sch.decompose_reduction(block, ko)
+ sch.tensorize(ji, VRMPY_u8u8i32_INTRIN)
+
+ verify_trace_roundtrip(sch=sch, mod=func)
+
+
def test_tensorize_dpa4():
m, n, k = 128, 128, 128