This is an automated email from the ASF dual-hosted git repository.

mehrdadh pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new b4d4b82dbb [Hexagon] Fix TIR vrmpy tensorization (#13404)
b4d4b82dbb is described below

commit b4d4b82dbb9be2e4d0954f9dfd8e1c46079b66ee
Author: masahi <[email protected]>
AuthorDate: Thu Nov 17 01:52:14 2022 +0900

    [Hexagon] Fix TIR vrmpy tensorization (#13404)
    
    [Hexagon] Fix vrmpy tensorization
---
 python/tvm/tir/tensor_intrin/hexagon.py            |  4 ----
 .../python/unittest/test_tir_schedule_tensorize.py | 26 +++++++++++++++++++---
 2 files changed, 23 insertions(+), 7 deletions(-)

diff --git a/python/tvm/tir/tensor_intrin/hexagon.py b/python/tvm/tir/tensor_intrin/hexagon.py
index 6fa9dd8f00..306c8cd2e1 100644
--- a/python/tvm/tir/tensor_intrin/hexagon.py
+++ b/python/tvm/tir/tensor_intrin/hexagon.py
@@ -32,8 +32,6 @@ def dot_product_32x4_u8u8i32_desc(
         for i in T.serial(0, 32):
             for k in T.serial(0, 4):
                 with T.block("update"):
-                    with T.init():
-                        C[i] = T.int32(0)
                     vi, vk = T.axis.remap("SR", [i, k])
                     C[vi] = C[vi] + T.cast(A[vk], "int32") * T.cast(B[vi, vk], "int32")
 
@@ -76,8 +74,6 @@ def dot_product_32x4_u8i8i32_desc(
         for i in T.serial(0, 32):
             for k in T.serial(0, 4):
                 with T.block("update"):
-                    with T.init():
-                        C[i] = T.int32(0)
                     vi, vk = T.axis.remap("SR", [i, k])
                     C[vi] = C[vi] + T.cast(A[vk], "int32") * T.cast(B[vi, vk], "int32")
 
diff --git a/tests/python/unittest/test_tir_schedule_tensorize.py b/tests/python/unittest/test_tir_schedule_tensorize.py
index f30e91b892..0129cee532 100644
--- a/tests/python/unittest/test_tir_schedule_tensorize.py
+++ b/tests/python/unittest/test_tir_schedule_tensorize.py
@@ -30,6 +30,7 @@ from tvm.tir.tensor_intrin.arm_cpu import (
 )
 from tvm.tir.tensor_intrin.rocm import AMDGPU_SDOT4_INTRIN
 from tvm.tir.tensor_intrin.x86 import VNNI_DOT_16x4_INTRIN
+from tvm.tir.tensor_intrin.hexagon import VRMPY_u8u8i32_INTRIN
 
 # fmt: off
 # pylint: disable=no-member,invalid-name,unused-variable,line-too-long,redefined-outer-name,unexpected-keyword-arg,too-many-nested-blocks
@@ -539,9 +540,9 @@ def test_tensorize_with_annotation():
     verify_trace_roundtrip(sch=s, mod=func)
 
 
-def get_matmul_packed(m, n, k, lhs_type, int32_lanes):
+def get_matmul_packed(m, n, k, lhs_type, int32_lanes, rhs_dtype="int8"):
     X = te.placeholder((m, k), name="X", dtype=lhs_type)
-    packed_W = te.placeholder((n // int32_lanes, k // 4, int32_lanes, 4), name="packedW", dtype="int8")
+    packed_W = te.placeholder((n // int32_lanes, k // 4, int32_lanes, 4), name="packedW", dtype=rhs_dtype)
 
     ak = te.reduce_axis((0, k), name="k")
     matmul = te.compute(
@@ -549,7 +550,7 @@ def get_matmul_packed(m, n, k, lhs_type, int32_lanes):
         lambda i, j: te.sum(
             X[i, ak].astype("int32")
             * packed_W[
-                tvm.tir.indexdiv(j, 16), tvm.tir.indexdiv(ak, 4), j % 16, ak % 4
+                tvm.tir.indexdiv(j, int32_lanes), tvm.tir.indexdiv(ak, 4), j % int32_lanes, ak % 4
             ].astype("int32"),
             axis=ak,
         ),
@@ -598,6 +599,25 @@ def test_tensorize_arm_dot():
         verify_trace_roundtrip(sch=sch, mod=func)
 
 
+def test_tensorize_vrmpy():
+    m, n, k = 128, 128, 128
+
+    func = get_matmul_packed(m, n, k, "uint8", 32, "uint8")
+
+    sch = tir.Schedule(func, debug_mask="all")
+    block = sch.get_block("compute")
+    _, j, k = sch.get_loops(block)
+
+    _, ji = sch.split(j, factors=[None, 32])
+    ko, ki = sch.split(k, factors=[None, 4])
+    sch.reorder(ko, ji, ki)
+
+    sch.decompose_reduction(block, ko)
+    sch.tensorize(ji, VRMPY_u8u8i32_INTRIN)
+
+    verify_trace_roundtrip(sch=sch, mod=func)
+
+
 def test_tensorize_dpa4():
     m, n, k = 128, 128, 128
 

Reply via email to