This is an automated email from the ASF dual-hosted git repository.
expye pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 588d1f2e93 [TensorIR][ROCm] AMD Matrix Core Support (#15106)
588d1f2e93 is described below
commit 588d1f2e93a4e50deb635464ce869efa7c116e1f
Author: Lei Wang <[email protected]>
AuthorDate: Wed Jun 28 16:46:07 2023 +0800
[TensorIR][ROCm] AMD Matrix Core Support (#15106)
* fix rocm arch parse issue,
* test case init.
* mfma scheduler
* support intrinsic.
* test case update (i8-i32, f16-f32)
* detect if the arch support matrix core
* replace requires_rocm to requires_matrixcore
* remove redundant debug print
* fix lint & replace local with warp for lowering
* replace default set rocm arch from parsing
* auto detect rocm architecture.
* code optimize
* simple typo fix.
* lint
* fix lint
* lint fix
* i8 warp mfma fix
* typo fix
* fix get_rocm_arch.
---
python/tvm/contrib/rocm.py | 119 +++++-
python/tvm/testing/tir.py | 111 ++++++
python/tvm/testing/utils.py | 14 +-
python/tvm/tir/tensor_intrin/rocm.py | 433 ++++++++++++++++++++-
src/target/target_kind.cc | 12 +-
.../unittest/test_tir_schedule_tensorize_mfma.py | 314 +++++++++++++++
6 files changed, 987 insertions(+), 16 deletions(-)
diff --git a/python/tvm/contrib/rocm.py b/python/tvm/contrib/rocm.py
index b33e20cbc1..fb50c71e22 100644
--- a/python/tvm/contrib/rocm.py
+++ b/python/tvm/contrib/rocm.py
@@ -15,7 +15,9 @@
# specific language governing permissions and limitations
# under the License.
"""Utility for ROCm backend"""
+import re
import subprocess
+import os
from os.path import join, exists
import tvm._ffi
@@ -147,9 +149,10 @@ def callback_rocm_bitcode_path(rocdl_dir=None):
"oclc_daz_opt_off",
"oclc_finite_only_off",
"oclc_finite_only_on",
- "oclc_isa_version_803", # todo (t-vi): an alternative might be to
scan for the
- "oclc_isa_version_900", # isa version files (if the
linker throws out
- "oclc_isa_version_906", # the unneeded ones or we filter
for the arch we need)
+ # todo (t-vi): an alternative might be to scan for the
+ "oclc_isa_version_803",
+ "oclc_isa_version_900", # isa version files (if the linker throws out
+ "oclc_isa_version_906", # the unneeded ones or we filter for the arch
we need)
"oclc_isa_version_1030",
"oclc_unsafe_math_off",
"oclc_unsafe_math_on",
@@ -168,3 +171,113 @@ def callback_rocm_bitcode_path(rocdl_dir=None):
raise RuntimeError("could not find bitcode " + n)
return tvm.runtime.convert(bitcode_files)
+
+
def parse_compute_version(compute_version):
    """Parse compute capability string to divide major and minor version

    Parameters
    ----------
    compute_version : str
        compute capability of a GPU (e.g. "6.0")

    Returns
    -------
    major : int
        major version number
    minor : int
        minor version number

    Raises
    ------
    RuntimeError
        If the string does not contain two dot-separated integers.
    """
    split_ver = compute_version.split(".")
    try:
        major = int(split_ver[0])
        minor = int(split_ver[1])
        return major, minor
    except (IndexError, ValueError) as err:
        # Chain the original exception so the root cause stays visible,
        # instead of suppressing the pylint raise-missing-from warning.
        raise RuntimeError("Compute version parsing error: " + str(err)) from err
+
+
def have_matrixcore(compute_version=None):
    """Whether the given compute capability provides MatrixCore support.

    Parameters
    ----------
    compute_version : str, optional
        compute capability of a GPU (e.g. "7.0").  When omitted, the
        capability of ROCm device 0 is queried at runtime.

    Returns
    -------
    have_matrixcore : bool
        True if MatrixCore support is provided, False otherwise
    """
    if compute_version is None:
        device = tvm.rocm(0)
        if not device.exist:
            raise RuntimeError("No ROCm runtime found")
        compute_version = device.compute_version
    major, _ = parse_compute_version(compute_version)
    # matrix core first introduced in 8.0
    return major >= 8
+
+
@tvm._ffi.register_func("tvm_callback_rocm_get_arch")
def get_rocm_arch(rocm_path="/opt/rocm"):
    """Utility function to get the AMD GPU architecture

    Parameters
    ----------
    rocm_path : str
        The path to rocm installation directory

    Returns
    -------
    gpu_arch : str
        The AMD GPU architecture (e.g. "gfx908").  Falls back to
        "gfx900" whenever ROCm or a GPU cannot be detected.
    """
    gpu_arch = "gfx900"
    # check if rocm is installed
    if not os.path.exists(rocm_path):
        print("ROCm not detected, using default gfx900")
        return gpu_arch
    try:
        # Execute rocminfo command; it reports one "Name: gfx..." line per GPU agent.
        rocminfo_output = subprocess.check_output([f"{rocm_path}/bin/rocminfo"]).decode("utf-8")

        # Use regex to match the "Name" field; the first match is taken as the arch.
        match = re.search(r"Name:\s+(gfx\d+[a-zA-Z]*)", rocminfo_output)
        if match:
            gpu_arch = match.group(1)
        return gpu_arch
    except (subprocess.CalledProcessError, FileNotFoundError):
        # FileNotFoundError covers a missing rocminfo binary (the directory
        # existence check above does not guarantee the tool is present);
        # CalledProcessError covers rocminfo failing, e.g. with no AMD GPU.
        # The original message was a backslash-continued f-string, which
        # embedded raw indentation whitespace into the printed text.
        print(
            "Unable to execute rocminfo command, please ensure ROCm is installed "
            f"and you have an AMD GPU on your system. Using default {gpu_arch}."
        )
    return gpu_arch
+
+
def find_rocm_path():
    """Utility function to find ROCm path

    Checks, in order: the ROCM_PATH environment variable, the location
    of ``hipcc`` on PATH, and finally the conventional ``/opt/rocm``.

    Returns
    -------
    path : str
        Path to ROCm root.

    Raises
    ------
    RuntimeError
        If no ROCm installation can be located.
    """
    if "ROCM_PATH" in os.environ:
        return os.environ["ROCM_PATH"]
    # Locate hipcc on PATH with the standard library instead of spawning
    # a `which` subprocess (portable, and no returncode handling needed).
    import shutil

    hipcc = shutil.which("hipcc")
    if hipcc is not None:
        # hipcc lives in <rocm>/bin, so the ROCm root is two levels up.
        return os.path.realpath(os.path.join(hipcc, "../.."))
    rocm_path = "/opt/rocm"
    if os.path.exists(os.path.join(rocm_path, "bin/hipcc")):
        return rocm_path
    raise RuntimeError("Cannot find ROCm path")
diff --git a/python/tvm/testing/tir.py b/python/tvm/testing/tir.py
index 57c1a85c5b..6842f1e519 100644
--- a/python/tvm/testing/tir.py
+++ b/python/tvm/testing/tir.py
@@ -128,3 +128,114 @@ def mma_schedule(
sch.tensorize(sch.get_loops(C_warp)[-2], mma_store_intrin)
return sch
+
+
+def mfma_schedule(
+ workload,
+ k_inner,
+ in_dtype,
+ b_transposed,
+ i_factors,
+ j_factors,
+ k_factors,
+ index_map_A,
+ index_map_B,
+ index_map_C,
+ ldmatrix_a_intrin,
+ ldmatrix_b_intrin,
+ mfma_intrin,
+ mfma_fill_intrin,
+ mfma_store_intrin,
+ shared_scope="shared",
+):
+ """Create a tensorized schedule for GEMM with MFMA intrinsics."""
+ import tvm
+
+ ir_module = tvm.IRModule({"main": workload})
+ sch = tvm.tir.Schedule(ir_module)
+
+ wmma_m = 16
+ wmma_n = 16
+ wmma_k = k_inner
+ warp_size = 64
+ block = sch.get_block("C")
+ i, j, k = sch.get_loops(block)
+ i, i_tc = sch.split(i, factors=[None, wmma_m])
+ j, j_tc = sch.split(j, factors=[None, wmma_n])
+ k, k_tc = sch.split(k, factors=[None, wmma_k])
+
+ sch.reorder(i, j, k, i_tc, j_tc, k_tc)
+
+ block_inner = sch.blockize(i_tc)
+ block_outer, block_inner = block_inner, block
+
+ num_ty = i_factors[2] * j_factors[2]
+
+ i0, i1, i2, i3, i4 = sch.split(i, factors=i_factors)
+ j0, j1, j2, j3, j4 = sch.split(j, factors=j_factors)
+ k0, k1, k2 = sch.split(k, k_factors)
+
+ sch.reorder(i0, j0, i1, j1, j2, i2, k0, k1, i3, j3, k2, i4, j4)
+
+ block_idx = sch.fuse(i0, j0)
+ block_idy = sch.fuse(i1, j1)
+ thread_idy = sch.fuse(j2, i2)
+ sch.bind(block_idx, "blockIdx.x")
+ sch.bind(block_idy, "blockIdx.y")
+ sch.bind(thread_idy, "threadIdx.y")
+
+ def fetch_to_shared(block, idx, ndim):
+ block_read = sch.cache_read(block, idx, shared_scope)
+ sch.compute_at(block_read, k0)
+ vector_size = 16 if in_dtype == "int8" else 8
+ fused = sch.fuse(*sch.get_loops(block_read)[-ndim:])
+ _, f_1, f_2, f_3 = sch.split(fused, factors=[None, num_ty, warp_size,
vector_size])
+ sch.bind(f_2, "threadIdx.x")
+ sch.bind(f_1, "threadIdx.y")
+ sch.vectorize(f_3)
+ return block_read
+
+ fetch_to_shared(block_outer, 0, 2)
+ fetch_to_shared(block_outer, 1, 2)
+
+ A_warp = sch.cache_read(block_outer, 0, "warp")
+ B_warp = sch.cache_read(block_outer, 1, "warp")
+
+ sch.compute_at(A_warp, k1)
+ sch.compute_at(B_warp, k1)
+ C_warp = sch.cache_write(block_outer, 0, "warp")
+ sch.reverse_compute_at(C_warp, thread_idy)
+
+ ii, jj = sch.get_loops(C_warp)[-2:]
+ io, ii = sch.split(ii, factors=[None, 16])
+ jo, ji = sch.split(jj, factors=[None, 16])
+ sch.reorder(io, jo, ii, ji)
+
+ sch.decompose_reduction(block_outer, sch.get_loops(block_outer)[3])
+ block_init_c = sch.get_block("C_init")
+
+ def tile_wmma_fragment(block_read, height, width):
+ i, j = sch.get_loops(block_read)[-2:]
+ i0, i1 = sch.split(i, factors=[None, height])
+ j0, j1 = sch.split(j, factors=[None, width])
+ sch.reorder(i0, j0, i1, j1)
+ return i1
+
+ loop_a = tile_wmma_fragment(A_warp, 16, k_inner)
+
+ if b_transposed:
+ loop_b = tile_wmma_fragment(B_warp, 16, k_inner)
+ else:
+ loop_b = tile_wmma_fragment(B_warp, k_inner, 16)
+
+ sch.transform_layout(A_warp, ("write", 0), index_map_A)
+ sch.transform_layout(B_warp, ("write", 0), index_map_B)
+ sch.transform_layout(C_warp, ("read", 0), index_map_C)
+
+ sch.tensorize(loop_a, ldmatrix_a_intrin)
+ sch.tensorize(loop_b, ldmatrix_b_intrin)
+ sch.tensorize(sch.get_loops(block_inner)[-3], mfma_intrin)
+ sch.tensorize(sch.get_loops(block_init_c)[-2], mfma_fill_intrin)
+ sch.tensorize(sch.get_loops(C_warp)[-2], mfma_store_intrin)
+
+ return sch
diff --git a/python/tvm/testing/utils.py b/python/tvm/testing/utils.py
index 884a885fb2..5f81672748 100644
--- a/python/tvm/testing/utils.py
+++ b/python/tvm/testing/utils.py
@@ -91,7 +91,7 @@ import tvm.tir
import tvm.te
import tvm._ffi
-from tvm.contrib import nvcc, cudnn
+from tvm.contrib import nvcc, cudnn, rocm
import tvm.contrib.hexagon._ci_env_check as hexagon
from tvm.driver.tvmc.frontends import load_model
from tvm.error import TVMError
@@ -914,6 +914,14 @@ requires_rocm = Feature(
parent_features="gpu",
)
+# Mark a test as requiring a matrixcore to run
+requires_matrixcore = Feature(
+ "matrixcore",
+ "AMD Matrix Core",
+ run_time_check=lambda: tvm.rocm().exist and
rocm.have_matrixcore(tvm.rocm().compute_version),
+ parent_features="rocm",
+)
+
# Mark a test as requiring the metal runtime
requires_metal = Feature(
"metal",
@@ -1239,7 +1247,6 @@ def requires_package(*packages):
def parametrize_targets(*args):
-
"""Parametrize a test over a specific set of targets.
Use this decorator when you want your test to be run over a
@@ -1503,7 +1510,6 @@ def parameters(*value_sets, ids=None):
outputs = []
for param_values in zip(*value_sets):
-
# Optional cls parameter in case a parameter is defined inside a
# class scope.
def fixture_func(*_cls, request):
@@ -2029,7 +2035,6 @@ class CompareBeforeAfter:
return inner
if hasattr(transform, "_pytestfixturefunction"):
-
if not hasattr(cls, "_transform_orig"):
cls._transform_orig = transform
@@ -2050,7 +2055,6 @@ class CompareBeforeAfter:
return apply(transform(self))
else:
-
raise TypeError(
"Expected transform to be a tvm.ir.transform.Pass, or a method
returning a Pass"
)
diff --git a/python/tvm/tir/tensor_intrin/rocm.py
b/python/tvm/tir/tensor_intrin/rocm.py
index 3700f3e8da..4b7c0da955 100644
--- a/python/tvm/tir/tensor_intrin/rocm.py
+++ b/python/tvm/tir/tensor_intrin/rocm.py
@@ -18,8 +18,13 @@
"""Intrinsics for AMDGPU tensorization."""
from tvm.script import tir as T
-from .. import TensorIntrin
+from tvm.runtime import convert
+from tvm.tir.expr import Cast, IntImm
from .dot_product_common import dp4a_desc
+from .. import TensorIntrin
+
+
+lift = convert
@T.prim_func
@@ -46,3 +51,429 @@ def sdot4(
AMDGPU_SDOT4_INTRIN = "sdot4"
TensorIntrin.register(AMDGPU_SDOT4_INTRIN, dp4a_desc, sdot4)
+
+WARP_SIZE = 64
+M_DIM = 16
+N_DIM = 16
+
+
+def shared_16x4_to_local_64x1_layout_A(i, j):
+ thread_id = j * 16 + i
+ return thread_id, convert(0)
+
+
def thread_id_shared_access_64x1_to_16x4_layout_A(thread_id, local_id):
    """Map a (thread_id, local_id) warp coordinate back to a 16x4 A-fragment index."""
    row, col = thread_id % 16, thread_id // 16 + local_id
    return row, col
+
+
+def shared_4x16_to_local_64x1_layout_B(i, j):
+ thread_id = i * 16 + j
+ return thread_id, convert(0)
+
+
+def thread_id_shared_access_64x1_to_4x16_layout_B(thread_id, local_id):
+ i = thread_id // 16
+ j = thread_id % 16 + local_id
+ return i, j
+
+
+def shared_16x16_to_local_64x4_layout_C(i, j):
+ thread_id = j + (i // 4) * 16
+ local = i % 4
+ return thread_id, local
+
+
def thread_id_shared_access_64x4_to_16x16_layout_A(thread_id, local_id):
    """Map a (thread_id, local_id) warp coordinate back to a 16x16 A-fragment index."""
    return thread_id % 16, 4 * (thread_id // 16) + local_id
+
+
+def shared_16x16_to_local_64x4_layout_A(i, j):
+ thread_id = i + 16 * (j // 4)
+ local = j % 4
+ return thread_id, local
+
+
+def thread_id_shared_access_64x4_to_16x16_layout_B(thread_id, local_id):
+ i = local_id + (thread_id // 16) * 4
+ j = thread_id % 16
+ return i, j
+
+
+def shared_16x16_to_local_64x4_layout_B(i, j):
+ thread_id = j + (i // 4) * 16
+ local = i % 4
+ return thread_id, local
+
+
def thread_id_shared_access_64x4_to_16x16_layout_C(thread_id, local_id):
    """Map a (thread_id, local_id) warp coordinate back to a 16x16 C-fragment index."""
    return 4 * (thread_id // 16) + local_id, thread_id % 16
+
+
+def get_mma_fill_intrin(dtype, local_size):
+ zero = IntImm("int32", 0).astype(dtype)
+
+ # Assume M = N = 16
+ index_map = shared_16x16_to_local_64x4_layout_C
+
+ @T.prim_func
+ def mma_fill_desc(a: T.handle) -> None:
+ C_warp = T.match_buffer(a, [WARP_SIZE, local_size], dtype=dtype,
scope="warp")
+
+ with T.block("root"):
+ T.reads()
+ T.writes(C_warp[0:WARP_SIZE, 0:local_size])
+ for i0, i1 in T.grid(M_DIM, N_DIM):
+ with T.block("C_warp"):
+ i, j = T.axis.remap("SS", [i0, i1])
+ thread_id, local_id = T.meta_var(index_map(i, j))
+ T.reads()
+ T.writes(C_warp[thread_id, local_id])
+ C_warp[thread_id, local_id] = zero
+
+ @T.prim_func
+ def mma_fill_impl(a: T.handle) -> None:
+ C_warp = T.match_buffer(
+ a, [WARP_SIZE, local_size], dtype=dtype, scope="warp",
offset_factor=1
+ )
+
+ with T.block("root"):
+ T.reads()
+ T.writes(C_warp[0:WARP_SIZE, 0:local_size])
+ tx = T.env_thread("threadIdx.x")
+ T.launch_thread(tx, WARP_SIZE)
+ for local_id in T.serial(0, local_size):
+ C_warp[tx, local_id] = zero
+
+ return mma_fill_desc, mma_fill_impl
+
+
+def get_mfma_load_intrin(
+ k_dim=4,
+ dtype="float32",
+ scope="shared",
+ is_b=False,
+ transposed=False,
+):
+ local_size = (M_DIM * k_dim) // WARP_SIZE if not is_b else (N_DIM * k_dim)
// WARP_SIZE
+ memory_shape = (M_DIM, k_dim)
+ if is_b:
+ memory_shape = (N_DIM, k_dim) if transposed else (k_dim, N_DIM)
+
+ row_dim, col_dim = memory_shape
+
+ if k_dim == 4:
+ index_map = shared_16x4_to_local_64x1_layout_A
+ reverse_index_map = thread_id_shared_access_64x1_to_16x4_layout_A
+ if is_b:
+ index_map = (
+ shared_16x4_to_local_64x1_layout_A
+ if transposed
+ else shared_4x16_to_local_64x1_layout_B
+ )
+ reverse_index_map = (
+ thread_id_shared_access_64x1_to_16x4_layout_A
+ if transposed
+ else thread_id_shared_access_64x1_to_4x16_layout_B
+ )
+ elif k_dim == 16:
+ index_map = shared_16x16_to_local_64x4_layout_A
+ reverse_index_map = thread_id_shared_access_64x4_to_16x16_layout_A
+
+ if is_b:
+ index_map = (
+ shared_16x16_to_local_64x4_layout_A
+ if transposed
+ else shared_16x16_to_local_64x4_layout_B
+ )
+ reverse_index_map = (
+ thread_id_shared_access_64x4_to_16x16_layout_A
+ if transposed
+ else thread_id_shared_access_64x4_to_16x16_layout_B
+ )
+ else:
+ raise ValueError("k_dim must be 4 or 16 currently")
+
+ @T.prim_func
+ def mfma_load_desc(reg_handle: T.handle, memory_handle: T.handle) -> None:
+ memory = T.match_buffer(
+ memory_handle,
+ memory_shape,
+ dtype,
+ offset_factor=1,
+ scope=scope,
+ )
+ reg = T.match_buffer(
+ reg_handle, (WARP_SIZE, local_size), dtype, offset_factor=1,
scope="warp"
+ )
+
+ with T.block("root"):
+ T.reads(memory[0:row_dim, 0:col_dim])
+ T.writes(reg[0:WARP_SIZE, 0:local_size])
+
+ for ax0, ax1 in T.grid(row_dim, col_dim):
+ with T.block("memory_reg"):
+ v0, v1 = T.axis.remap("SS", [ax0, ax1])
+ T.reads(memory[v0, v1])
+
+ thread_id, local_id = T.meta_var(index_map(v0, v1))
+ T.writes(reg[thread_id, local_id])
+ reg[thread_id, local_id] = memory[v0, v1]
+
+ @T.prim_func
+ def mfma_load_impl(reg_handle: T.handle, memory_handle: T.handle) -> None:
+ s0 = T.int32()
+ s1 = T.int32()
+
+ memory = T.match_buffer(
+ memory_handle,
+ memory_shape,
+ dtype,
+ align=64,
+ offset_factor=1,
+ scope=scope,
+ strides=[s0, s1],
+ )
+ reg = T.match_buffer(
+ reg_handle, (WARP_SIZE, local_size), dtype, align=64,
offset_factor=1, scope="warp"
+ )
+
+ with T.block("root"):
+ T.reads(memory[0:row_dim, 0:col_dim])
+ T.writes(reg[0:WARP_SIZE, 0:local_size])
+ tx = T.env_thread("threadIdx.x")
+ for local_id in T.serial(0, local_size):
+ row, col = T.meta_var(reverse_index_map(tx, local_id))
+ T.launch_thread(tx, WARP_SIZE)
+ reg[tx, local_id] = memory[row, col]
+
+ return mfma_load_desc, mfma_load_impl
+
+
def get_mfma_intrin(k_dim, in_dtype="float32", out_dtype="float32", b_transposed=False):
    """Build the MFMA compute (desc, impl) pair for a 16x16xk_dim GEMM tile.

    Parameters
    ----------
    k_dim : int
        Reduction extent of the fragment; must be 4 or 16.
    in_dtype : str
        Element type of the A/B operands.
    out_dtype : str
        Element type of the C accumulator.
    b_transposed : bool
        Whether B is stored transposed, i.e. indexed as (j, k) instead of (k, j).

    Returns
    -------
    (desc, impl)
        Description and hardware implementation PrimFuncs suitable for
        ``TensorIntrin.register``.
    """
    local_size = (M_DIM * k_dim) // WARP_SIZE
    local_size_out = (M_DIM * N_DIM) // WARP_SIZE
    if k_dim == 4:
        index_map_A = shared_16x4_to_local_64x1_layout_A
        index_map_B = shared_4x16_to_local_64x1_layout_B
        index_map_C = shared_16x16_to_local_64x4_layout_C
    elif k_dim == 16:
        index_map_A = shared_16x16_to_local_64x4_layout_A
        index_map_B = shared_16x16_to_local_64x4_layout_B
        index_map_C = shared_16x16_to_local_64x4_layout_C
    else:
        raise ValueError("k_dim must be 4 or 16 currently")

    out_dtype_abbrv = {"float16": "f16", "float32": "f32", "int8": "i8", "int32": "i32"}[out_dtype]

    in_dtype_abbrv = {"float16": "f16", "float32": "f32", "int8": "i8", "int32": "i32"}[in_dtype]

    # LLVM intrinsic name, e.g. "llvm.amdgcn.mfma.f32.16x16x16f16".
    mfma_intrin = f"llvm.amdgcn.mfma.{out_dtype_abbrv}.{M_DIM}x{N_DIM}x{k_dim}{in_dtype_abbrv}"

    def maybe_cast(v):
        # Promote operands when the accumulator type differs from the input type.
        if out_dtype != in_dtype:
            return Cast(out_dtype, v)
        return v

    def maybe_swap(i, j):
        # B is indexed (col, row) when stored transposed.
        if b_transposed:
            return j, i
        return i, j

    @T.prim_func
    def mfma_sync_desc(a: T.handle, b: T.handle, c: T.handle) -> None:
        A = T.match_buffer(a, (WARP_SIZE, local_size), in_dtype, offset_factor=1, scope="warp")
        B = T.match_buffer(b, (WARP_SIZE, local_size), in_dtype, offset_factor=1, scope="warp")
        C = T.match_buffer(c, (WARP_SIZE, local_size_out), out_dtype, offset_factor=1, scope="warp")

        with T.block("root"):
            T.reads(
                C[0:WARP_SIZE, 0:local_size_out],
                A[0:WARP_SIZE, 0:local_size],
                B[0:WARP_SIZE, 0:local_size],
            )
            T.writes(C[0:WARP_SIZE, 0:local_size_out])

            for i, j, k in T.grid(M_DIM, N_DIM, k_dim):
                with T.block("C"):
                    i, j, k = T.axis.remap("SSR", [i, j, k])
                    b_row_ind, b_col_ind = T.meta_var(maybe_swap(k, j))

                    thread_id_C, local_id_C = T.meta_var(index_map_C(i, j))
                    thread_id_A, local_id_A = T.meta_var(index_map_A(i, k))
                    thread_id_B, local_id_B = T.meta_var(index_map_B(b_row_ind, b_col_ind))

                    T.reads(
                        C[thread_id_C, local_id_C],
                        A[thread_id_A, local_id_A],
                        B[thread_id_B, local_id_B],
                    )
                    T.writes(C[thread_id_C, local_id_C])

                    C[thread_id_C, local_id_C] += maybe_cast(
                        A[thread_id_A, local_id_A]
                    ) * maybe_cast(B[thread_id_B, local_id_B])

    @T.prim_func
    def mfma_sync_impl_float(a: T.handle, b: T.handle, c: T.handle) -> None:
        A = T.match_buffer(a, (WARP_SIZE, local_size), in_dtype, offset_factor=1, scope="warp")
        B = T.match_buffer(b, (WARP_SIZE, local_size), in_dtype, offset_factor=1, scope="warp")
        C = T.match_buffer(c, (WARP_SIZE, local_size_out), out_dtype, offset_factor=1, scope="warp")

        with T.block("root"):
            T.reads(
                A[0:WARP_SIZE, 0:local_size],
                B[0:WARP_SIZE, 0:local_size],
                C[0:WARP_SIZE, 0:local_size_out],
            )
            T.writes(C[0:WARP_SIZE, 0:local_size_out])
            tx = T.env_thread("threadIdx.x")
            T.launch_thread(tx, WARP_SIZE)
            C[tx, 0:local_size_out] = T.call_llvm_pure_intrin(
                T.llvm_lookup_intrinsic_id(mfma_intrin),
                T.uint32(6),
                A[tx, 0:local_size],
                B[tx, 0:local_size],
                C[tx, 0:local_size_out],
                T.int32(0),
                T.int32(0),
                T.int32(0),
                dtype=f"{out_dtype}x4",
            )

    @T.prim_func
    def mfma_sync_impl_integer(a: T.handle, b: T.handle, c: T.handle) -> None:
        A = T.match_buffer(a, (WARP_SIZE, local_size), in_dtype, offset_factor=1, scope="warp")
        B = T.match_buffer(b, (WARP_SIZE, local_size), in_dtype, offset_factor=1, scope="warp")
        C = T.match_buffer(c, (WARP_SIZE, local_size_out), out_dtype, offset_factor=1, scope="warp")

        with T.block("root"):
            T.reads(
                A[0:WARP_SIZE, 0:local_size],
                B[0:WARP_SIZE, 0:local_size],
                C[0:WARP_SIZE, 0:local_size_out],
            )
            T.writes(C[0:WARP_SIZE, 0:local_size_out])
            tx = T.env_thread("threadIdx.x")
            T.launch_thread(tx, WARP_SIZE)

            C[tx, 0:local_size_out] = T.call_llvm_pure_intrin(
                T.llvm_lookup_intrinsic_id(mfma_intrin),
                T.uint32(6),
                T.call_intrin("int32", "tir.reinterpret", A[tx, 0:local_size]),
                # BUG FIX: the second MFMA operand must be B; the original
                # passed A twice, computing A*A instead of A*B on the int8 path.
                T.call_intrin("int32", "tir.reinterpret", B[tx, 0:local_size]),
                C[tx, 0:local_size_out],
                T.int32(0),
                T.int32(0),
                T.int32(0),
                dtype=f"{out_dtype}x4",
            )

    return (
        (mfma_sync_desc, mfma_sync_impl_integer)
        if in_dtype == "int8"
        else (mfma_sync_desc, mfma_sync_impl_float)
    )
+
+
+def get_mfma_store_intrin(local_size=4, dtype="float32", scope="global"):
+ index_map = shared_16x16_to_local_64x4_layout_C
+
+ @T.prim_func
+ def mfma_store_desc(a: T.handle, c: T.handle) -> None:
+ C_warp = T.match_buffer(a, [WARP_SIZE, local_size], dtype=dtype,
scope="warp")
+ C = T.match_buffer(c, [M_DIM, N_DIM], dtype=dtype, scope=scope)
+
+ with T.block("root"):
+ T.reads(C_warp[0:WARP_SIZE, 0:local_size])
+ T.writes(C[0:M_DIM, 0:N_DIM])
+ for i0, i1 in T.grid(M_DIM, N_DIM):
+ with T.block("C_warp"):
+ v0, v1 = T.axis.remap("SS", [i0, i1])
+ thread_id, local_id = T.meta_var(index_map(v0, v1))
+ T.reads(C_warp[thread_id, local_id])
+ T.writes(C[v0, v1])
+ C[v0, v1] = C_warp[thread_id, local_id]
+
+ @T.prim_func
+ def mfma_store_impl(a: T.handle, c: T.handle) -> None:
+ s0 = T.int32()
+ s1 = T.int32()
+
+ C_warp = T.match_buffer(
+ a, [WARP_SIZE, local_size], dtype=dtype, scope="warp",
offset_factor=1
+ )
+ C = T.match_buffer(
+ c, [M_DIM, N_DIM], dtype=dtype, scope=scope, offset_factor=1,
strides=[s0, s1]
+ )
+
+ with T.block("root"):
+ T.reads(C_warp[0:WARP_SIZE, 0:local_size])
+ T.writes(C[0:M_DIM, 0:N_DIM])
+ tx = T.env_thread("threadIdx.x")
+ T.launch_thread(tx, WARP_SIZE)
+ for i in range(local_size):
+ C[((tx // 16) * 4) + i, (tx % 16)] = C_warp[tx, i]
+
+ return mfma_store_desc, mfma_store_impl
+
+
+ROCM_MFMA_fill_16x16_f32_INTRIN = "ROCM_mfma_fill_16x16_f32"
+TensorIntrin.register(ROCM_MFMA_fill_16x16_f32_INTRIN,
*get_mma_fill_intrin("float32", 4))
+
+ROCM_MFMA_fill_16x16_i32_INTRIN = "ROCM_mfma_fill_16x16_i32"
+TensorIntrin.register(ROCM_MFMA_fill_16x16_i32_INTRIN,
*get_mma_fill_intrin("int", 4))
+
+ROCM_MFMA_LOAD_16x16_A_SHARED_s8_INTRIN = "rocm_mfma_load_16x16_a_shared_s8"
+TensorIntrin.register(
+ ROCM_MFMA_LOAD_16x16_A_SHARED_s8_INTRIN, *get_mfma_load_intrin(16, "int8",
"shared")
+)
+ROCM_MFMA_LOAD_16x16_B_SHARED_s8_INTRIN = "rocm_mfma_load_b_16x16_shared_s8"
+TensorIntrin.register(
+ ROCM_MFMA_LOAD_16x16_B_SHARED_s8_INTRIN, *get_mfma_load_intrin(16, "int8",
"shared", is_b=True)
+)
+
+ROCM_MFMA_LOAD_16x16_A_SHARED_f16_INTRIN = "rocm_mfma_load_16x16_a_shared_f16"
+TensorIntrin.register(
+ ROCM_MFMA_LOAD_16x16_A_SHARED_f16_INTRIN, *get_mfma_load_intrin(16,
"float16", "shared")
+)
+ROCM_MFMA_LOAD_16x16_B_SHARED_f16_INTRIN = "rocm_mfma_load_b_16x16_shared_f16"
+TensorIntrin.register(
+ ROCM_MFMA_LOAD_16x16_B_SHARED_f16_INTRIN,
+ *get_mfma_load_intrin(16, "float16", "shared", is_b=True),
+)
+
+ROCM_MFMA_LOAD_16x4_A_SHARED_f32_INTRIN = "rocm_mfma_load_16x4_a_shared_f32"
+TensorIntrin.register(
+ ROCM_MFMA_LOAD_16x4_A_SHARED_f32_INTRIN, *get_mfma_load_intrin(4,
"float32", "shared")
+)
+ROCM_MFMA_LOAD_16x4_B_SHARED_f32_INTRIN = "rocm_mfma_load_b_16x4_shared_f32"
+TensorIntrin.register(
+ ROCM_MFMA_LOAD_16x4_B_SHARED_f32_INTRIN,
+ *get_mfma_load_intrin(4, "float32", "shared", is_b=True),
+)
+
+
+ROCM_MFMA_f32f32f32_INTRIN = "rocm_mfma_f32f32f32"
+TensorIntrin.register(ROCM_MFMA_f32f32f32_INTRIN, *get_mfma_intrin(4,
"float32", "float32"))
+
+ROCM_MFMA_f16f16f32_INTRIN = "rocm_mfma_f16f16f32"
+TensorIntrin.register(ROCM_MFMA_f16f16f32_INTRIN, *get_mfma_intrin(16,
"float16", "float32"))
+
+ROCM_MFMA_s8s8s32_INTRIN = "rocm_mfma_s8s8s32"
+TensorIntrin.register(ROCM_MFMA_s8s8s32_INTRIN, *get_mfma_intrin(16, "int8",
"int32"))
+
+ROCM_MFMA_STORE_16x16_s32_INTRIN = "rocm_mfma_store_16x16_s32"
+TensorIntrin.register(
+ ROCM_MFMA_STORE_16x16_s32_INTRIN, *get_mfma_store_intrin(4, "int32",
"global")
+)
+
+ROCM_MFMA_STORE_16x16_f32_INTRIN = "rocm_mfma_store_16x16_f32"
+TensorIntrin.register(
+ ROCM_MFMA_STORE_16x16_f32_INTRIN, *get_mfma_store_intrin(4, "float32",
"global")
+)
diff --git a/src/target/target_kind.cc b/src/target/target_kind.cc
index 7d48d9b158..174ae6d43f 100644
--- a/src/target/target_kind.cc
+++ b/src/target/target_kind.cc
@@ -222,22 +222,20 @@ TargetJSON UpdateNVPTXAttrs(TargetJSON target) {
* \return The updated attributes
*/
TargetJSON UpdateROCmAttrs(TargetJSON target) {
+ using tvm::runtime::Registry;
CheckOrSetAttr(&target, "mtriple", "amdgcn-amd-amdhsa-hcc");
// Update -mcpu=gfx
- std::string arch;
+ std::string arch = "gfx900";
if (target.count("mcpu")) {
String mcpu = Downcast<String>(target.at("mcpu"));
arch = ExtractStringWithPrefix(mcpu, "gfx");
ICHECK(!arch.empty()) << "ValueError: ROCm target gets an invalid GFX
version: -mcpu=" << mcpu;
} else {
TVMRetValue val;
- if (!DetectDeviceFlag({kDLROCM, 0}, runtime::kGcnArch, &val)) {
- LOG(WARNING) << "Unable to detect ROCm compute arch, default to
\"-mcpu=gfx900\" instead";
- arch = "900";
- } else {
- arch = val.operator std::string();
+ if (const auto* f_get_rocm_arch =
Registry::Get("tvm_callback_rocm_get_arch")) {
+ arch = (*f_get_rocm_arch)().operator std::string();
}
- target.Set("mcpu", String("gfx") + arch);
+ target.Set("mcpu", String(arch));
}
// Update -mattr before ROCm 3.5:
// Before ROCm 3.5 we needed code object v2, starting
diff --git a/tests/python/unittest/test_tir_schedule_tensorize_mfma.py
b/tests/python/unittest/test_tir_schedule_tensorize_mfma.py
new file mode 100644
index 0000000000..8077a603bc
--- /dev/null
+++ b/tests/python/unittest/test_tir_schedule_tensorize_mfma.py
@@ -0,0 +1,314 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=missing-docstring
+import tvm
+from tvm import te
+from tvm.tir.tensor_intrin.rocm import (
+ shared_16x4_to_local_64x1_layout_A,
+ shared_4x16_to_local_64x1_layout_B,
+ shared_16x16_to_local_64x4_layout_A,
+ shared_16x16_to_local_64x4_layout_B,
+ shared_16x16_to_local_64x4_layout_C,
+ ROCM_MFMA_fill_16x16_f32_INTRIN,
+ ROCM_MFMA_LOAD_16x4_A_SHARED_f32_INTRIN,
+ ROCM_MFMA_LOAD_16x4_B_SHARED_f32_INTRIN,
+ ROCM_MFMA_f32f32f32_INTRIN,
+ ROCM_MFMA_STORE_16x16_f32_INTRIN,
+ ROCM_MFMA_LOAD_16x16_A_SHARED_f16_INTRIN,
+ ROCM_MFMA_LOAD_16x16_B_SHARED_f16_INTRIN,
+ ROCM_MFMA_f16f16f32_INTRIN,
+ ROCM_MFMA_STORE_16x16_f32_INTRIN,
+ ROCM_MFMA_fill_16x16_i32_INTRIN,
+ ROCM_MFMA_LOAD_16x16_A_SHARED_s8_INTRIN,
+ ROCM_MFMA_LOAD_16x16_B_SHARED_s8_INTRIN,
+ ROCM_MFMA_s8s8s32_INTRIN,
+ ROCM_MFMA_STORE_16x16_s32_INTRIN,
+)
+import tvm.testing
+import numpy as np
+from tvm.testing.tir import mfma_schedule
+
+
+M = 1024
+N = 1024
+K = 1024
+measure_perf = False
+gflops = (N * M * K) * 2 / 1e9
+
+
+def matmul(m, n, k, in_dtype, out_dtype, b_transposed):
+ b_shape = (n, k) if b_transposed else (k, n)
+ a = te.placeholder((m, k), name="A", dtype=in_dtype)
+ b = te.placeholder(b_shape, name="B", dtype=in_dtype)
+ k = te.reduce_axis((0, k), name="k")
+
+ def maybe_cast(v):
+ if in_dtype != out_dtype:
+ return tvm.tir.Cast(out_dtype, v)
+ return v
+
+ def maybe_swap(i, j):
+ if b_transposed:
+ return j, i
+ return i, j
+
+ c = te.compute(
+ (m, n),
+ lambda i, j: te.sum(maybe_cast(a[i, k]) * maybe_cast(b[maybe_swap(k,
j)]), axis=[k]),
+ name="C",
+ )
+ return (a, b, c)
+
+
+def run_test(
+ k_inner,
+ in_dtype,
+ out_dtype,
+ b_transposed,
+ i_factors,
+ j_factors,
+ k_factors,
+ index_map_A,
+ index_map_B,
+ index_map_C,
+ ldmatrix_a_intrin,
+ ldmatrix_b_intrin,
+ mma_intrin,
+ mma_fill_intrin,
+ mma_store_intrin,
+):
+ sch = mfma_schedule(
+ te.create_prim_func(matmul(M, N, K, in_dtype, out_dtype,
b_transposed)),
+ k_inner,
+ in_dtype,
+ b_transposed,
+ i_factors,
+ j_factors,
+ k_factors,
+ index_map_A,
+ index_map_B,
+ index_map_C,
+ ldmatrix_a_intrin,
+ ldmatrix_b_intrin,
+ mma_intrin,
+ mma_fill_intrin,
+ mma_store_intrin,
+ )
+
+ f = tvm.build(sch.mod["main"], target="rocm", name="dense")
+
+ dev = tvm.device("rocm", 0)
+ if in_dtype == "float32":
+ a_np = np.random.uniform(size=(M, K)).astype("float32")
+
+ if b_transposed:
+ b_np = np.random.uniform(size=(N, K)).astype("float32")
+ c_np = np.dot(a_np.astype("float32"),
b_np.astype("float32").transpose()).astype(
+ out_dtype
+ )
+ else:
+ b_np = np.random.uniform(size=(K, N)).astype("float32")
+ c_np = np.dot(a_np.astype("float32"),
b_np.astype("float32")).astype(out_dtype)
+ elif in_dtype == "float16":
+ a_np = np.random.uniform(size=(M, K)).astype("float16")
+
+ if b_transposed:
+ b_np = np.random.uniform(size=(N, K)).astype("float16")
+ c_np = np.dot(a_np.astype("float32"),
b_np.astype("float32").transpose()).astype(
+ out_dtype
+ )
+ else:
+ b_np = np.random.uniform(size=(K, N)).astype("float16")
+ c_np = np.dot(a_np.astype("float32"),
b_np.astype("float32")).astype(out_dtype)
+ else:
+ a_np = np.random.randint(-128, 128, (M, K)).astype("int8")
+
+ if b_transposed:
+ b_np = np.random.randint(-128, 128, (N, K)).astype("int8")
+ c_np = np.dot(a_np.astype("float32"),
b_np.astype("float32").transpose()).astype(
+ "int32"
+ )
+ else:
+ b_np = np.random.randint(-128, 128, (K, N)).astype("int8")
+ c_np = np.dot(a_np.astype("float32"),
b_np.astype("float32")).astype("int32")
+
+ a = tvm.nd.array(a_np, dev)
+ b = tvm.nd.array(b_np, dev)
+ c = tvm.nd.array(np.zeros((M, N), dtype=out_dtype), dev)
+
+ f(a, b, c)
+
+ if in_dtype != "float16":
+ # The numpy reference is computed with fp32 precision (otherwise too
slow).
+ # So there is non-trivial accuracy difference if TVM result is
computed with fp16 accumulation.
+ tvm.testing.assert_allclose(c.numpy(), c_np, rtol=1e-2, atol=1e-2)
+
+ return lambda: f.time_evaluator(f.entry_name, dev, number=500)(a, b, c)
+
+
[email protected]_matrixcore
+def test_i8i8i32_m16n16k16():
+ def index_map_A(i, j):
+ return (
+ i // 16,
+ j // 16,
+ *shared_16x16_to_local_64x4_layout_A(i % 16, j % 16),
+ )
+
+ def index_map_B(i, j):
+ return (
+ i // 16,
+ j // 16,
+ *shared_16x16_to_local_64x4_layout_B(i % 16, j % 16),
+ )
+
+ def index_map_C(i, j):
+ return (
+ i // 16,
+ j // 16,
+ *shared_16x16_to_local_64x4_layout_C(i % 16, j % 16),
+ )
+
+ k_inner = 16
+ in_dtype = "int8"
+ out_dtype = "int32"
+ i_factors, j_factors, k_factors = [1, 8, 2, 4, 1], [1, 16, 2, 1, 2], [32,
2, 1]
+
+ timer = run_test(
+ k_inner,
+ in_dtype,
+ out_dtype,
+ False, # b_transposed
+ i_factors,
+ j_factors,
+ k_factors,
+ index_map_A,
+ index_map_B,
+ index_map_C,
+ ROCM_MFMA_LOAD_16x16_A_SHARED_s8_INTRIN,
+ ROCM_MFMA_LOAD_16x16_B_SHARED_s8_INTRIN,
+ ROCM_MFMA_s8s8s32_INTRIN,
+ ROCM_MFMA_fill_16x16_i32_INTRIN,
+ ROCM_MFMA_STORE_16x16_s32_INTRIN,
+ )
+
+ if measure_perf and timer:
+ print("test_i8i8i32_m16n16k16: %f GFLOPS" % (gflops / (timer().mean)))
+
+
[email protected]_matrixcore
+def test_f16f16f32_m16n16k16():
+ def index_map_A(i, j):
+ return (
+ i // 16,
+ j // 16,
+ *shared_16x16_to_local_64x4_layout_A(i % 16, j % 16),
+ )
+
+ def index_map_B(i, j):
+ return (
+ i // 16,
+ j // 16,
+ *shared_16x16_to_local_64x4_layout_B(i % 16, j % 16),
+ )
+
+ def index_map_C(i, j):
+ return (
+ i // 16,
+ j // 16,
+ *shared_16x16_to_local_64x4_layout_C(i % 16, j % 16),
+ )
+
+ k_inner = 16
+ in_dtype = "float16"
+ out_dtype = "float32"
+ i_factors, j_factors, k_factors = [1, 8, 2, 4, 1], [1, 16, 2, 1, 2], [32,
2, 1]
+
+ timer = run_test(
+ k_inner,
+ in_dtype,
+ out_dtype,
+ False, # b_transposed
+ i_factors,
+ j_factors,
+ k_factors,
+ index_map_A,
+ index_map_B,
+ index_map_C,
+ ROCM_MFMA_LOAD_16x16_A_SHARED_f16_INTRIN,
+ ROCM_MFMA_LOAD_16x16_B_SHARED_f16_INTRIN,
+ ROCM_MFMA_f16f16f32_INTRIN,
+ ROCM_MFMA_fill_16x16_f32_INTRIN,
+ ROCM_MFMA_STORE_16x16_f32_INTRIN,
+ )
+
+ if measure_perf and timer:
+ print("f16f16f32_m16n16k16: %f GFLOPS" % (gflops / (timer().mean)))
+
+
[email protected]_matrixcore
+def test_f32f32f32_m16n16k4():
+ def index_map_A(i, j):
+ return (
+ i // 16,
+ j // 16,
+ *shared_16x4_to_local_64x1_layout_A(i % 16, j % 16),
+ )
+
+ def index_map_B(i, j):
+ return (
+ i // 16,
+ j // 16,
+ *shared_4x16_to_local_64x1_layout_B(i % 16, j % 16),
+ )
+
+ def index_map_C(i, j):
+ return (
+ i // 16,
+ j // 16,
+ *shared_16x16_to_local_64x4_layout_C(i % 16, j % 16),
+ )
+
+ k_inner = 4
+ in_dtype = "float32"
+ out_dtype = "float32"
+ i_factors, j_factors, k_factors = [4, 2, 1, 4, 2], [4, 2, 2, 1, 4], [128,
2, 1]
+
+ timer = run_test(
+ k_inner,
+ in_dtype,
+ out_dtype,
+ False, # b_transposed
+ i_factors,
+ j_factors,
+ k_factors,
+ index_map_A,
+ index_map_B,
+ index_map_C,
+ ROCM_MFMA_LOAD_16x4_A_SHARED_f32_INTRIN,
+ ROCM_MFMA_LOAD_16x4_B_SHARED_f32_INTRIN,
+ ROCM_MFMA_f32f32f32_INTRIN,
+ ROCM_MFMA_fill_16x16_f32_INTRIN,
+ ROCM_MFMA_STORE_16x16_f32_INTRIN,
+ )
+
+ if measure_perf and timer:
+ print("test_f32f32f32_m16n16k4: %f GFLOPS" % (gflops / (timer().mean)))
+
+
+if __name__ == "__main__":
+ tvm.testing.main()