This is an automated email from the ASF dual-hosted git repository.

syfeng pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new 4c77f0fc24 [TIR] Extend DP4A tensor intrin (#16293)
4c77f0fc24 is described below

commit 4c77f0fc24b3b8dc4cf840ebaed215b4da9732b9
Author: Lufang Chen <[email protected]>
AuthorDate: Mon Jan 8 17:53:20 2024 +0800

    [TIR] Extend DP4A tensor intrin (#16293)
    
    * update dp4a tensor intrin
    
    * update dp4a tensor intrin
    
    * lint
    
    ---------
    
    Co-authored-by: Lufang CHEN 陈橹方 <[email protected]>
---
 python/tvm/tir/tensor_intrin/arm_cpu.py            |  9 ++-
 python/tvm/tir/tensor_intrin/dot_product_common.py | 82 +++++++++++++---------
 python/tvm/tir/tensor_intrin/rocm.py               |  3 +-
 src/target/source/codegen_cuda.cc                  |  2 +
 src/target/source/literal/cuda_int8_t.h            | 64 +++++++++++++++++
 .../tir-schedule/test_tir_schedule_tensorize.py    | 50 +++++++------
 6 files changed, 154 insertions(+), 56 deletions(-)

diff --git a/python/tvm/tir/tensor_intrin/arm_cpu.py b/python/tvm/tir/tensor_intrin/arm_cpu.py
index c518f64f5a..a5003d41a8 100644
--- a/python/tvm/tir/tensor_intrin/arm_cpu.py
+++ b/python/tvm/tir/tensor_intrin/arm_cpu.py
@@ -14,11 +14,16 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-# pylint: disable=invalid-name,missing-function-docstring
+# pylint: disable=invalid-name,missing-function-docstring,unused-import
 """Intrinsics for ARM tensorization."""
 from tvm.script import tir as T
 from .. import TensorIntrin
-from .dot_product_common import DP4A_INTRIN  # pylint: disable=unused-import
+from .dot_product_common import (
+    DP4A_S8S8S32_INTRIN,
+    DP4A_S8U8S32_INTRIN,
+    DP4A_U8S8S32_INTRIN,
+    DP4A_U8U8U32_INTRIN,
+)
 
 
 # TODO(masahi): Parametrize the TVMScript description of dot product by
diff --git a/python/tvm/tir/tensor_intrin/dot_product_common.py b/python/tvm/tir/tensor_intrin/dot_product_common.py
index c531b80380..db10422c8e 100644
--- a/python/tvm/tir/tensor_intrin/dot_product_common.py
+++ b/python/tvm/tir/tensor_intrin/dot_product_common.py
@@ -20,36 +20,52 @@ from tvm.script import tir as T
 from .. import TensorIntrin
 
 
-@T.prim_func
-def dp4a_desc(
-    A: T.Buffer((4,), "int8", offset_factor=1, align=4, scope="shared"),
-    B: T.Buffer((4,), "int8", offset_factor=1, align=4, scope="shared"),
-    C: T.Buffer((1,), "int32", offset_factor=1, align=4, scope="local"),
-) -> None:
-    with T.block("root"):
-        T.reads(C[0], A[0:4], B[0:4])
-        T.writes(C[0])
-        for i in range(0, 4):
-            with T.block("update"):
-                vi = T.axis.remap("R", [i])
-                C[0] = C[0] + T.cast(A[vi], "int32") * T.cast(B[vi], "int32")
-
-
-@T.prim_func
-def dp4a_impl(
-    A: T.Buffer((4,), "int8", offset_factor=1, align=4, scope="shared"),
-    B: T.Buffer((4,), "int8", offset_factor=1, align=4, scope="shared"),
-    C: T.Buffer((1,), "int32", offset_factor=1, align=4, scope="local"),
-) -> None:
-    with T.block("root"):
-        T.reads(C[0], A[0:4], B[0:4])
-        T.writes(C[0])
-
-        C[0] += T.call_pure_extern(
-            "__dp4a", A.vload([0], "int8x4"), B.vload([0], "int8x4"), T.int32(0), dtype="int32"
-        )
-
-
-DP4A_INTRIN = "dp4a"
-
-TensorIntrin.register(DP4A_INTRIN, dp4a_desc, dp4a_impl)
+def get_dp4a_intrin(dtype_a, dtype_b, dtype_c):
+    if dtype_c == "uint32":
+        assert dtype_a == dtype_b == "uint8"
+    vec_type_a = "int8x4" if dtype_a == "int8" else "uint8x4"
+    vec_type_b = "int8x4" if dtype_b == "int8" else "uint8x4"
+
+    @T.prim_func
+    def dp4a_desc(
+        A: T.Buffer((4,), dtype_a, offset_factor=1, align=4, scope="shared"),
+        B: T.Buffer((4,), dtype_b, offset_factor=1, align=4, scope="shared"),
+        C: T.Buffer((1,), dtype_c, offset_factor=1, align=4, scope="local"),
+    ) -> None:
+        with T.block("root"):
+            T.reads(C[0], A[0:4], B[0:4])
+            T.writes(C[0])
+            for i in range(0, 4):
+                with T.block("update"):
+                    vi = T.axis.remap("R", [i])
+                    C[0] = C[0] + T.cast(A[vi], dtype_c) * T.cast(B[vi], dtype_c)
+
+    @T.prim_func
+    def dp4a_impl(
+        A: T.Buffer((4,), dtype_a, offset_factor=1, align=4, scope="shared"),
+        B: T.Buffer((4,), dtype_b, offset_factor=1, align=4, scope="shared"),
+        C: T.Buffer((1,), dtype_c, offset_factor=1, align=4, scope="local"),
+    ) -> None:
+        with T.block("root"):
+            T.reads(C[0], A[0:4], B[0:4])
+            T.writes(C[0])
+
+            C[0] += T.call_pure_extern(
+                "__dp4a",
+                A.vload([0], vec_type_a),
+                B.vload([0], vec_type_b),
+                T.uint32(0) if dtype_c == "uint32" else T.int32(0),
+                dtype=dtype_c,
+            )
+
+    return dp4a_desc, dp4a_impl
+
+
+DP4A_S8S8S32_INTRIN = "dp4a_s8s8s32"
+TensorIntrin.register(DP4A_S8S8S32_INTRIN, *get_dp4a_intrin("int8", "int8", "int32"))
+DP4A_U8S8S32_INTRIN = "dp4a_u8s8s32"
+TensorIntrin.register(DP4A_U8S8S32_INTRIN, *get_dp4a_intrin("uint8", "int8", "int32"))
+DP4A_S8U8S32_INTRIN = "dp4a_s8u8s32"
+TensorIntrin.register(DP4A_S8U8S32_INTRIN, *get_dp4a_intrin("int8", "uint8", "int32"))
+DP4A_U8U8U32_INTRIN = "dp4a_u8u8u32"
+TensorIntrin.register(DP4A_U8U8U32_INTRIN, *get_dp4a_intrin("uint8", "uint8", "uint32"))
diff --git a/python/tvm/tir/tensor_intrin/rocm.py b/python/tvm/tir/tensor_intrin/rocm.py
index 4b7c0da955..12dabfb2cd 100644
--- a/python/tvm/tir/tensor_intrin/rocm.py
+++ b/python/tvm/tir/tensor_intrin/rocm.py
@@ -20,7 +20,7 @@ from tvm.script import tir as T
 
 from tvm.runtime import convert
 from tvm.tir.expr import Cast, IntImm
-from .dot_product_common import dp4a_desc
+from .dot_product_common import get_dp4a_intrin
 from .. import TensorIntrin
 
 
@@ -50,6 +50,7 @@ def sdot4(
 
 AMDGPU_SDOT4_INTRIN = "sdot4"
 
+dp4a_desc, _ = get_dp4a_intrin("int8", "int8", "int32")
 TensorIntrin.register(AMDGPU_SDOT4_INTRIN, dp4a_desc, sdot4)
 
 WARP_SIZE = 64
diff --git a/src/target/source/codegen_cuda.cc b/src/target/source/codegen_cuda.cc
index ef69b7a7d1..efed5c02f1 100644
--- a/src/target/source/codegen_cuda.cc
+++ b/src/target/source/codegen_cuda.cc
@@ -35,6 +35,7 @@
 
 #include "../../tir/transforms/ir_utils.h"
 #include "literal/cuda_half_t.h"
+#include "literal/cuda_int8_t.h"
 #include "ptx.h"
 
 namespace tvm {
@@ -130,6 +131,7 @@ std::string CodeGenCUDA::Finish() {
   if (enable_int8_) {
     decl_stream << "#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 610)\n";
     decl_stream << "#include <sm_61_intrinsics.h>\n";
+    decl_stream << _cuda_int8_t_def;
     decl_stream << "#endif\n";
   }
 
diff --git a/src/target/source/literal/cuda_int8_t.h b/src/target/source/literal/cuda_int8_t.h
new file mode 100644
index 0000000000..ce166ea8f3
--- /dev/null
+++ b/src/target/source/literal/cuda_int8_t.h
@@ -0,0 +1,64 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file cuda_int8_t.h
+ * \brief Extra int8 intrisic for cuda codegen.
+ */
+#ifndef TVM_TARGET_SOURCE_LITERAL_CUDA_INT8_T_H_
+#define TVM_TARGET_SOURCE_LITERAL_CUDA_INT8_T_H_
+
+static constexpr const char* _cuda_int8_t_def = R"(
+
+#if defined(__CUDACC_RTC__)
+#define __SM_61_INTRINSICS_DECL__ __device__
+#else /* !__CUDACC_RTC__ */
+#define __SM_61_INTRINSICS_DECL__ static __device__ __inline__
+#endif /* __CUDACC_RTC__ */
+
+#ifndef __CUDA_ARCH__
+#define __DEF_IF_HOST { }
+#else  /* !__CUDA_ARCH__ */
+#define __DEF_IF_HOST ;
+#endif /* __CUDA_ARCH__ */
+
+__SM_61_INTRINSICS_DECL__ int __dp4a(unsigned int srcA, int srcB, int c) __DEF_IF_HOST
+__SM_61_INTRINSICS_DECL__ int __dp4a(int srcA, unsigned int srcB, int c) __DEF_IF_HOST
+
+#undef __DEF_IF_HOST
+
+#if !defined(__CUDACC_RTC__) && defined(__CUDA_ARCH__)
+__SM_61_INTRINSICS_DECL__ int __dp4a(unsigned int srcA, int srcB, int c) {
+    int ret;
+    asm volatile ("dp4a.u32.s32 %0, %1, %2, %3;" : "=r"(ret) : "r"(srcA), "r"(srcB), "r"(c));
+    return ret;
+}
+
+__SM_61_INTRINSICS_DECL__ int __dp4a(int srcA, unsigned int srcB, int c) {
+    int ret;
+    asm volatile ("dp4a.s32.u32 %0, %1, %2, %3;" : "=r"(ret) : "r"(srcA), "r"(srcB), "r"(c));
+    return ret;
+}
+#endif /* !__CUDACC_RTC__ && defined(__CUDA_ARCH__) */
+
+#undef __SM_61_INTRINSICS_DECL__
+
+)";
+
+#endif  // TVM_TARGET_SOURCE_LITERAL_CUDA_INT8_T_H_
diff --git a/tests/python/tir-schedule/test_tir_schedule_tensorize.py b/tests/python/tir-schedule/test_tir_schedule_tensorize.py
index 9646355f0a..1891914bc0 100644
--- a/tests/python/tir-schedule/test_tir_schedule_tensorize.py
+++ b/tests/python/tir-schedule/test_tir_schedule_tensorize.py
@@ -26,7 +26,10 @@ from tvm.tir.schedule.testing import (
     verify_trace_roundtrip,
 )
 from tvm.tir.tensor_intrin.arm_cpu import (
-    DP4A_INTRIN,
+    DP4A_S8S8S32_INTRIN,
+    DP4A_U8U8U32_INTRIN,
+    DP4A_U8S8S32_INTRIN,
+    DP4A_S8U8S32_INTRIN,
     ARM_DOT_4x4_i8_NEON_INTRIN,
     ARM_DOT_4x4_i8_SDOT_INTRIN,
 )
@@ -687,26 +690,25 @@ def test_tensorize_vdmpy():
     verify_trace_roundtrip(sch=sch, mod=func)
 
 
-def test_tensorize_dpa4():
-    m, n, k = 128, 128, 128
-
-    X = te.placeholder((m, k), name="X", dtype="int8")
-    W = te.placeholder((n, k), name="W", dtype="int8")
-    ak = te.reduce_axis((0, k), name="k")
-
-    matmul = te.compute(
-        (m, n),
-        lambda i, j: te.sum(
-            X[i, ak].astype("int32")
-            * W[j, ak].astype("int32"),
-            axis=ak,
-        ),
-        name="compute",
-    )
+def test_tensorize_dp4a():
+    # pylint: disable=too-many-locals
+    def _test_intrin(dtype_a, dtype_b, dtype_c, intrin):
+        m, n, k = 128, 128, 128
+        X = te.placeholder((m, k), name="X", dtype=dtype_a)
+        W = te.placeholder((n, k), name="W", dtype=dtype_b)
+        ak = te.reduce_axis((0, k), name="k")
+
+        matmul = te.compute(
+            (m, n),
+            lambda i, j: te.sum(
+                X[i, ak].astype(dtype_c) * W[j, ak].astype(dtype_c),
+                axis=ak,
+            ),
+            name="compute",
+        )
 
-    func = te.create_prim_func([X, W, matmul])
+        func = te.create_prim_func([X, W, matmul])
 
-    for intrin in [AMDGPU_SDOT4_INTRIN, DP4A_INTRIN]:
         sch = tir.Schedule(func, debug_mask="all")
         block = sch.get_block("compute")
         i, j, k = sch.get_loops(block)
@@ -717,7 +719,6 @@ def test_tensorize_dpa4():
         ko, kt = sch.split(ko, factors=sch.sample_perfect_tile(ko, n=2))
 
         sch.reorder(by, bx, ty, tx, yi, xi)
-
         CC = sch.cache_write(block, 0, "local")
         sch.reverse_compute_at(CC, tx)
 
@@ -734,6 +735,15 @@ def test_tensorize_dpa4():
 
         verify_trace_roundtrip(sch=sch, mod=func)
 
+    for args in [
+        ("int8", "int8", "int32", AMDGPU_SDOT4_INTRIN),
+        ("int8", "int8", "int32", DP4A_S8S8S32_INTRIN),
+        ("int8", "uint8", "int32", DP4A_S8U8S32_INTRIN),
+        ("uint8", "int8", "int32", DP4A_U8S8S32_INTRIN),
+        ("uint8", "uint8", "uint32", DP4A_U8U8U32_INTRIN),
+    ]:
+        _test_intrin(*args)
+
 
 def test_tensor_intrin_look_up():
     intrin_name = 'non_existent_intrin'

Reply via email to