This is an automated email from the ASF dual-hosted git repository.
syfeng pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new 4c77f0fc24 [TIR] Extend DP4A tensor intrin (#16293)
4c77f0fc24 is described below
commit 4c77f0fc24b3b8dc4cf840ebaed215b4da9732b9
Author: Lufang Chen <[email protected]>
AuthorDate: Mon Jan 8 17:53:20 2024 +0800
[TIR] Extend DP4A tensor intrin (#16293)
* update dp4a tensor intrin
* update dp4a tensor intrin
* lint
---------
Co-authored-by: Lufang CHEN 陈橹方 <[email protected]>
---
python/tvm/tir/tensor_intrin/arm_cpu.py | 9 ++-
python/tvm/tir/tensor_intrin/dot_product_common.py | 82 +++++++++++++---------
python/tvm/tir/tensor_intrin/rocm.py | 3 +-
src/target/source/codegen_cuda.cc | 2 +
src/target/source/literal/cuda_int8_t.h | 64 +++++++++++++++++
.../tir-schedule/test_tir_schedule_tensorize.py | 50 +++++++------
6 files changed, 154 insertions(+), 56 deletions(-)
diff --git a/python/tvm/tir/tensor_intrin/arm_cpu.py b/python/tvm/tir/tensor_intrin/arm_cpu.py
index c518f64f5a..a5003d41a8 100644
--- a/python/tvm/tir/tensor_intrin/arm_cpu.py
+++ b/python/tvm/tir/tensor_intrin/arm_cpu.py
@@ -14,11 +14,16 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-# pylint: disable=invalid-name,missing-function-docstring
+# pylint: disable=invalid-name,missing-function-docstring,unused-import
"""Intrinsics for ARM tensorization."""
from tvm.script import tir as T
from .. import TensorIntrin
-from .dot_product_common import DP4A_INTRIN # pylint: disable=unused-import
+from .dot_product_common import (
+ DP4A_S8S8S32_INTRIN,
+ DP4A_S8U8S32_INTRIN,
+ DP4A_U8S8S32_INTRIN,
+ DP4A_U8U8U32_INTRIN,
+)
# TODO(masahi): Parametrize the TVMScript description of dot product by
diff --git a/python/tvm/tir/tensor_intrin/dot_product_common.py b/python/tvm/tir/tensor_intrin/dot_product_common.py
index c531b80380..db10422c8e 100644
--- a/python/tvm/tir/tensor_intrin/dot_product_common.py
+++ b/python/tvm/tir/tensor_intrin/dot_product_common.py
@@ -20,36 +20,52 @@ from tvm.script import tir as T
from .. import TensorIntrin
[email protected]_func
-def dp4a_desc(
- A: T.Buffer((4,), "int8", offset_factor=1, align=4, scope="shared"),
- B: T.Buffer((4,), "int8", offset_factor=1, align=4, scope="shared"),
- C: T.Buffer((1,), "int32", offset_factor=1, align=4, scope="local"),
-) -> None:
- with T.block("root"):
- T.reads(C[0], A[0:4], B[0:4])
- T.writes(C[0])
- for i in range(0, 4):
- with T.block("update"):
- vi = T.axis.remap("R", [i])
- C[0] = C[0] + T.cast(A[vi], "int32") * T.cast(B[vi], "int32")
-
-
[email protected]_func
-def dp4a_impl(
- A: T.Buffer((4,), "int8", offset_factor=1, align=4, scope="shared"),
- B: T.Buffer((4,), "int8", offset_factor=1, align=4, scope="shared"),
- C: T.Buffer((1,), "int32", offset_factor=1, align=4, scope="local"),
-) -> None:
- with T.block("root"):
- T.reads(C[0], A[0:4], B[0:4])
- T.writes(C[0])
-
- C[0] += T.call_pure_extern(
-            "__dp4a", A.vload([0], "int8x4"), B.vload([0], "int8x4"), T.int32(0), dtype="int32"
- )
-
-
-DP4A_INTRIN = "dp4a"
-
-TensorIntrin.register(DP4A_INTRIN, dp4a_desc, dp4a_impl)
+def get_dp4a_intrin(dtype_a, dtype_b, dtype_c):
+ if dtype_c == "uint32":
+ assert dtype_a == dtype_b == "uint8"
+ vec_type_a = "int8x4" if dtype_a == "int8" else "uint8x4"
+ vec_type_b = "int8x4" if dtype_b == "int8" else "uint8x4"
+
+ @T.prim_func
+ def dp4a_desc(
+ A: T.Buffer((4,), dtype_a, offset_factor=1, align=4, scope="shared"),
+ B: T.Buffer((4,), dtype_b, offset_factor=1, align=4, scope="shared"),
+ C: T.Buffer((1,), dtype_c, offset_factor=1, align=4, scope="local"),
+ ) -> None:
+ with T.block("root"):
+ T.reads(C[0], A[0:4], B[0:4])
+ T.writes(C[0])
+ for i in range(0, 4):
+ with T.block("update"):
+ vi = T.axis.remap("R", [i])
+                    C[0] = C[0] + T.cast(A[vi], dtype_c) * T.cast(B[vi], dtype_c)
+
+ @T.prim_func
+ def dp4a_impl(
+ A: T.Buffer((4,), dtype_a, offset_factor=1, align=4, scope="shared"),
+ B: T.Buffer((4,), dtype_b, offset_factor=1, align=4, scope="shared"),
+ C: T.Buffer((1,), dtype_c, offset_factor=1, align=4, scope="local"),
+ ) -> None:
+ with T.block("root"):
+ T.reads(C[0], A[0:4], B[0:4])
+ T.writes(C[0])
+
+ C[0] += T.call_pure_extern(
+ "__dp4a",
+ A.vload([0], vec_type_a),
+ B.vload([0], vec_type_b),
+ T.uint32(0) if dtype_c == "uint32" else T.int32(0),
+ dtype=dtype_c,
+ )
+
+ return dp4a_desc, dp4a_impl
+
+
+DP4A_S8S8S32_INTRIN = "dp4a_s8s8s32"
+TensorIntrin.register(DP4A_S8S8S32_INTRIN, *get_dp4a_intrin("int8", "int8", "int32"))
+DP4A_U8S8S32_INTRIN = "dp4a_u8s8s32"
+TensorIntrin.register(DP4A_U8S8S32_INTRIN, *get_dp4a_intrin("uint8", "int8", "int32"))
+DP4A_S8U8S32_INTRIN = "dp4a_s8u8s32"
+TensorIntrin.register(DP4A_S8U8S32_INTRIN, *get_dp4a_intrin("int8", "uint8", "int32"))
+DP4A_U8U8U32_INTRIN = "dp4a_u8u8u32"
+TensorIntrin.register(DP4A_U8U8U32_INTRIN, *get_dp4a_intrin("uint8", "uint8", "uint32"))
diff --git a/python/tvm/tir/tensor_intrin/rocm.py b/python/tvm/tir/tensor_intrin/rocm.py
index 4b7c0da955..12dabfb2cd 100644
--- a/python/tvm/tir/tensor_intrin/rocm.py
+++ b/python/tvm/tir/tensor_intrin/rocm.py
@@ -20,7 +20,7 @@ from tvm.script import tir as T
from tvm.runtime import convert
from tvm.tir.expr import Cast, IntImm
-from .dot_product_common import dp4a_desc
+from .dot_product_common import get_dp4a_intrin
from .. import TensorIntrin
@@ -50,6 +50,7 @@ def sdot4(
AMDGPU_SDOT4_INTRIN = "sdot4"
+dp4a_desc, _ = get_dp4a_intrin("int8", "int8", "int32")
TensorIntrin.register(AMDGPU_SDOT4_INTRIN, dp4a_desc, sdot4)
WARP_SIZE = 64
diff --git a/src/target/source/codegen_cuda.cc b/src/target/source/codegen_cuda.cc
index ef69b7a7d1..efed5c02f1 100644
--- a/src/target/source/codegen_cuda.cc
+++ b/src/target/source/codegen_cuda.cc
@@ -35,6 +35,7 @@
#include "../../tir/transforms/ir_utils.h"
#include "literal/cuda_half_t.h"
+#include "literal/cuda_int8_t.h"
#include "ptx.h"
namespace tvm {
@@ -130,6 +131,7 @@ std::string CodeGenCUDA::Finish() {
if (enable_int8_) {
decl_stream << "#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 610)\n";
decl_stream << "#include <sm_61_intrinsics.h>\n";
+ decl_stream << _cuda_int8_t_def;
decl_stream << "#endif\n";
}
diff --git a/src/target/source/literal/cuda_int8_t.h b/src/target/source/literal/cuda_int8_t.h
new file mode 100644
index 0000000000..ce166ea8f3
--- /dev/null
+++ b/src/target/source/literal/cuda_int8_t.h
@@ -0,0 +1,64 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file cuda_int8_t.h
+ * \brief Extra int8 intrinsic for cuda codegen.
+ */
+#ifndef TVM_TARGET_SOURCE_LITERAL_CUDA_INT8_T_H_
+#define TVM_TARGET_SOURCE_LITERAL_CUDA_INT8_T_H_
+
+static constexpr const char* _cuda_int8_t_def = R"(
+
+#if defined(__CUDACC_RTC__)
+#define __SM_61_INTRINSICS_DECL__ __device__
+#else /* !__CUDACC_RTC__ */
+#define __SM_61_INTRINSICS_DECL__ static __device__ __inline__
+#endif /* __CUDACC_RTC__ */
+
+#ifndef __CUDA_ARCH__
+#define __DEF_IF_HOST { }
+#else /* !__CUDA_ARCH__ */
+#define __DEF_IF_HOST ;
+#endif /* __CUDA_ARCH__ */
+
+__SM_61_INTRINSICS_DECL__ int __dp4a(unsigned int srcA, int srcB, int c) __DEF_IF_HOST
+__SM_61_INTRINSICS_DECL__ int __dp4a(int srcA, unsigned int srcB, int c) __DEF_IF_HOST
+
+#undef __DEF_IF_HOST
+
+#if !defined(__CUDACC_RTC__) && defined(__CUDA_ARCH__)
+__SM_61_INTRINSICS_DECL__ int __dp4a(unsigned int srcA, int srcB, int c) {
+ int ret;
+  asm volatile ("dp4a.u32.s32 %0, %1, %2, %3;" : "=r"(ret) : "r"(srcA), "r"(srcB), "r"(c));
+ return ret;
+}
+
+__SM_61_INTRINSICS_DECL__ int __dp4a(int srcA, unsigned int srcB, int c) {
+ int ret;
+  asm volatile ("dp4a.s32.u32 %0, %1, %2, %3;" : "=r"(ret) : "r"(srcA), "r"(srcB), "r"(c));
+ return ret;
+}
+#endif /* !__CUDACC_RTC__ && defined(__CUDA_ARCH__) */
+
+#undef __SM_61_INTRINSICS_DECL__
+
+)";
+
+#endif // TVM_TARGET_SOURCE_LITERAL_CUDA_INT8_T_H_
diff --git a/tests/python/tir-schedule/test_tir_schedule_tensorize.py b/tests/python/tir-schedule/test_tir_schedule_tensorize.py
index 9646355f0a..1891914bc0 100644
--- a/tests/python/tir-schedule/test_tir_schedule_tensorize.py
+++ b/tests/python/tir-schedule/test_tir_schedule_tensorize.py
@@ -26,7 +26,10 @@ from tvm.tir.schedule.testing import (
verify_trace_roundtrip,
)
from tvm.tir.tensor_intrin.arm_cpu import (
- DP4A_INTRIN,
+ DP4A_S8S8S32_INTRIN,
+ DP4A_U8U8U32_INTRIN,
+ DP4A_U8S8S32_INTRIN,
+ DP4A_S8U8S32_INTRIN,
ARM_DOT_4x4_i8_NEON_INTRIN,
ARM_DOT_4x4_i8_SDOT_INTRIN,
)
@@ -687,26 +690,25 @@ def test_tensorize_vdmpy():
verify_trace_roundtrip(sch=sch, mod=func)
-def test_tensorize_dpa4():
- m, n, k = 128, 128, 128
-
- X = te.placeholder((m, k), name="X", dtype="int8")
- W = te.placeholder((n, k), name="W", dtype="int8")
- ak = te.reduce_axis((0, k), name="k")
-
- matmul = te.compute(
- (m, n),
- lambda i, j: te.sum(
- X[i, ak].astype("int32")
- * W[j, ak].astype("int32"),
- axis=ak,
- ),
- name="compute",
- )
+def test_tensorize_dp4a():
+ # pylint: disable=too-many-locals
+ def _test_intrin(dtype_a, dtype_b, dtype_c, intrin):
+ m, n, k = 128, 128, 128
+ X = te.placeholder((m, k), name="X", dtype=dtype_a)
+ W = te.placeholder((n, k), name="W", dtype=dtype_b)
+ ak = te.reduce_axis((0, k), name="k")
+
+ matmul = te.compute(
+ (m, n),
+ lambda i, j: te.sum(
+ X[i, ak].astype(dtype_c) * W[j, ak].astype(dtype_c),
+ axis=ak,
+ ),
+ name="compute",
+ )
- func = te.create_prim_func([X, W, matmul])
+ func = te.create_prim_func([X, W, matmul])
- for intrin in [AMDGPU_SDOT4_INTRIN, DP4A_INTRIN]:
sch = tir.Schedule(func, debug_mask="all")
block = sch.get_block("compute")
i, j, k = sch.get_loops(block)
@@ -717,7 +719,6 @@ def test_tensorize_dpa4():
ko, kt = sch.split(ko, factors=sch.sample_perfect_tile(ko, n=2))
sch.reorder(by, bx, ty, tx, yi, xi)
-
CC = sch.cache_write(block, 0, "local")
sch.reverse_compute_at(CC, tx)
@@ -734,6 +735,15 @@ def test_tensorize_dpa4():
verify_trace_roundtrip(sch=sch, mod=func)
+ for args in [
+ ("int8", "int8", "int32", AMDGPU_SDOT4_INTRIN),
+ ("int8", "int8", "int32", DP4A_S8S8S32_INTRIN),
+ ("int8", "uint8", "int32", DP4A_S8U8S32_INTRIN),
+ ("uint8", "int8", "int32", DP4A_U8S8S32_INTRIN),
+ ("uint8", "uint8", "uint32", DP4A_U8U8U32_INTRIN),
+ ]:
+ _test_intrin(*args)
+
def test_tensor_intrin_look_up():
intrin_name = 'non_existent_intrin'