This is an automated email from the ASF dual-hosted git repository.

expye pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 588d1f2e93 [TensorIR][ROCm] AMD Matrix Core Support (#15106)
588d1f2e93 is described below

commit 588d1f2e93a4e50deb635464ce869efa7c116e1f
Author: Lei Wang <[email protected]>
AuthorDate: Wed Jun 28 16:46:07 2023 +0800

    [TensorIR][ROCm] AMD Matrix Core Support (#15106)
    
    * fix rocm arch parse issue,
    
    * test case init.
    
    * mfma scheduler
    
    * support intrinsic.
    
    * test case update ( i8-i32 f16-f32
    
    * detect if the arch support matrix core
    
    * replace requires_rocm to requires_matrixcore
    
    * remove redundant debug print
    
    * fix lint  & replace local with warp for lowering
    
    * replace default set rocm arch from parsing
    
    * auto detect rocm architecture.
    
    * code optimize
    
    * simple typo fix.
    
    * lint
    
    * fix lint
    
    * lint fix
    
    * i8 warp mfma fix
    
    * typo fix
    
    * fix get_rocm_arch.
---
 python/tvm/contrib/rocm.py                         | 119 +++++-
 python/tvm/testing/tir.py                          | 111 ++++++
 python/tvm/testing/utils.py                        |  14 +-
 python/tvm/tir/tensor_intrin/rocm.py               | 433 ++++++++++++++++++++-
 src/target/target_kind.cc                          |  12 +-
 .../unittest/test_tir_schedule_tensorize_mfma.py   | 314 +++++++++++++++
 6 files changed, 987 insertions(+), 16 deletions(-)

diff --git a/python/tvm/contrib/rocm.py b/python/tvm/contrib/rocm.py
index b33e20cbc1..fb50c71e22 100644
--- a/python/tvm/contrib/rocm.py
+++ b/python/tvm/contrib/rocm.py
@@ -15,7 +15,9 @@
 # specific language governing permissions and limitations
 # under the License.
 """Utility for ROCm backend"""
+import re
 import subprocess
+import os
 from os.path import join, exists
 
 import tvm._ffi
@@ -147,9 +149,10 @@ def callback_rocm_bitcode_path(rocdl_dir=None):
         "oclc_daz_opt_off",
         "oclc_finite_only_off",
         "oclc_finite_only_on",
-        "oclc_isa_version_803",  # todo (t-vi): an alternative might be to 
scan for the
-        "oclc_isa_version_900",  #              isa version files (if the 
linker throws out
-        "oclc_isa_version_906",  #              the unneeded ones or we filter 
for the arch we need)
+        # todo (t-vi): an alternative might be to scan for the
+        "oclc_isa_version_803",
+        "oclc_isa_version_900",  # isa version files (if the linker throws out
+        "oclc_isa_version_906",  # the unneeded ones or we filter for the arch 
we need)
         "oclc_isa_version_1030",
         "oclc_unsafe_math_off",
         "oclc_unsafe_math_on",
@@ -168,3 +171,113 @@ def callback_rocm_bitcode_path(rocdl_dir=None):
             raise RuntimeError("could not find bitcode " + n)
 
     return tvm.runtime.convert(bitcode_files)
+
+
+def parse_compute_version(compute_version):
+    """Parse compute capability string to divide major and minor version
+
+    Parameters
+    ----------
+    compute_version : str
+        compute capability of a GPU (e.g. "6.0")
+
+    Returns
+    -------
+    major : int
+        major version number
+    minor : int
+        minor version number
+    """
+    split_ver = compute_version.split(".")
+    try:
+        major = int(split_ver[0])
+        minor = int(split_ver[1])
+        return major, minor
+    except (IndexError, ValueError) as err:
+        # pylint: disable=raise-missing-from
+        raise RuntimeError("Compute version parsing error: " + str(err))
+
+
+def have_matrixcore(compute_version=None):
+    """Either MatrixCore support is provided in the compute capability or not
+
+    Parameters
+    ----------
+    compute_version : str, optional
+        compute capability of a GPU (e.g. "7.0").
+
+    Returns
+    -------
+    have_matrixcore : bool
+        True if MatrixCore support is provided, False otherwise
+    """
+    if compute_version is None:
+        if tvm.rocm(0).exist:
+            compute_version = tvm.rocm(0).compute_version
+        else:
+            raise RuntimeError("No ROCm runtime found")
+    major, _ = parse_compute_version(compute_version)
+    # matrix core first introduced in 8.0
+    if major >= 8:
+        return True
+
+    return False
+
+
+@tvm._ffi.register_func("tvm_callback_rocm_get_arch")
+def get_rocm_arch(rocm_path="/opt/rocm"):
+    """Utility function to get the AMD GPU architecture
+
+    Parameters
+    ----------
+    rocm_path : str
+        The path to rocm installation directory
+
+    Returns
+    -------
+    gpu_arch : str
+        The AMD GPU architecture
+    """
+    gpu_arch = "gfx900"
+    # check if rocm is installed
+    if not os.path.exists(rocm_path):
+        print("ROCm not detected, using default gfx900")
+        return gpu_arch
+    try:
+        # Execute rocminfo command
+        rocminfo_output = 
subprocess.check_output([f"{rocm_path}/bin/rocminfo"]).decode("utf-8")
+
+        # Use regex to match the "Name" field
+        match = re.search(r"Name:\s+(gfx\d+[a-zA-Z]*)", rocminfo_output)
+        if match:
+            gpu_arch = match.group(1)
+        return gpu_arch
+    except subprocess.CalledProcessError:
+        print(
+            f"Unable to execute rocminfo command, \
+                please ensure ROCm is installed and you have an AMD GPU on 
your system.\
+                    using default {gpu_arch}."
+        )
+        return gpu_arch
+
+
+def find_rocm_path():
+    """Utility function to find ROCm path
+
+    Returns
+    -------
+    path : str
+        Path to ROCm root.
+    """
+    if "ROCM_PATH" in os.environ:
+        return os.environ["ROCM_PATH"]
+    cmd = ["which", "hipcc"]
+    proc = subprocess.Popen(cmd, stdout=subprocess.PIPE, 
stderr=subprocess.STDOUT)
+    (out, _) = proc.communicate()
+    out = out.decode("utf-8").strip()
+    if proc.returncode == 0:
+        return os.path.realpath(os.path.join(out, "../.."))
+    rocm_path = "/opt/rocm"
+    if os.path.exists(os.path.join(rocm_path, "bin/hipcc")):
+        return rocm_path
+    raise RuntimeError("Cannot find ROCm path")
diff --git a/python/tvm/testing/tir.py b/python/tvm/testing/tir.py
index 57c1a85c5b..6842f1e519 100644
--- a/python/tvm/testing/tir.py
+++ b/python/tvm/testing/tir.py
@@ -128,3 +128,114 @@ def mma_schedule(
     sch.tensorize(sch.get_loops(C_warp)[-2], mma_store_intrin)
 
     return sch
+
+
+def mfma_schedule(
+    workload,
+    k_inner,
+    in_dtype,
+    b_transposed,
+    i_factors,
+    j_factors,
+    k_factors,
+    index_map_A,
+    index_map_B,
+    index_map_C,
+    ldmatrix_a_intrin,
+    ldmatrix_b_intrin,
+    mfma_intrin,
+    mfma_fill_intrin,
+    mfma_store_intrin,
+    shared_scope="shared",
+):
+    """Create a tensorized schedule for GEMM with MFMA intrinsics."""
+    import tvm
+
+    ir_module = tvm.IRModule({"main": workload})
+    sch = tvm.tir.Schedule(ir_module)
+
+    wmma_m = 16
+    wmma_n = 16
+    wmma_k = k_inner
+    warp_size = 64
+    block = sch.get_block("C")
+    i, j, k = sch.get_loops(block)
+    i, i_tc = sch.split(i, factors=[None, wmma_m])
+    j, j_tc = sch.split(j, factors=[None, wmma_n])
+    k, k_tc = sch.split(k, factors=[None, wmma_k])
+
+    sch.reorder(i, j, k, i_tc, j_tc, k_tc)
+
+    block_inner = sch.blockize(i_tc)
+    block_outer, block_inner = block_inner, block
+
+    num_ty = i_factors[2] * j_factors[2]
+
+    i0, i1, i2, i3, i4 = sch.split(i, factors=i_factors)
+    j0, j1, j2, j3, j4 = sch.split(j, factors=j_factors)
+    k0, k1, k2 = sch.split(k, k_factors)
+
+    sch.reorder(i0, j0, i1, j1, j2, i2, k0, k1, i3, j3, k2, i4, j4)
+
+    block_idx = sch.fuse(i0, j0)
+    block_idy = sch.fuse(i1, j1)
+    thread_idy = sch.fuse(j2, i2)
+    sch.bind(block_idx, "blockIdx.x")
+    sch.bind(block_idy, "blockIdx.y")
+    sch.bind(thread_idy, "threadIdx.y")
+
+    def fetch_to_shared(block, idx, ndim):
+        block_read = sch.cache_read(block, idx, shared_scope)
+        sch.compute_at(block_read, k0)
+        vector_size = 16 if in_dtype == "int8" else 8
+        fused = sch.fuse(*sch.get_loops(block_read)[-ndim:])
+        _, f_1, f_2, f_3 = sch.split(fused, factors=[None, num_ty, warp_size, 
vector_size])
+        sch.bind(f_2, "threadIdx.x")
+        sch.bind(f_1, "threadIdx.y")
+        sch.vectorize(f_3)
+        return block_read
+
+    fetch_to_shared(block_outer, 0, 2)
+    fetch_to_shared(block_outer, 1, 2)
+
+    A_warp = sch.cache_read(block_outer, 0, "warp")
+    B_warp = sch.cache_read(block_outer, 1, "warp")
+
+    sch.compute_at(A_warp, k1)
+    sch.compute_at(B_warp, k1)
+    C_warp = sch.cache_write(block_outer, 0, "warp")
+    sch.reverse_compute_at(C_warp, thread_idy)
+
+    ii, jj = sch.get_loops(C_warp)[-2:]
+    io, ii = sch.split(ii, factors=[None, 16])
+    jo, ji = sch.split(jj, factors=[None, 16])
+    sch.reorder(io, jo, ii, ji)
+
+    sch.decompose_reduction(block_outer, sch.get_loops(block_outer)[3])
+    block_init_c = sch.get_block("C_init")
+
+    def tile_wmma_fragment(block_read, height, width):
+        i, j = sch.get_loops(block_read)[-2:]
+        i0, i1 = sch.split(i, factors=[None, height])
+        j0, j1 = sch.split(j, factors=[None, width])
+        sch.reorder(i0, j0, i1, j1)
+        return i1
+
+    loop_a = tile_wmma_fragment(A_warp, 16, k_inner)
+
+    if b_transposed:
+        loop_b = tile_wmma_fragment(B_warp, 16, k_inner)
+    else:
+        loop_b = tile_wmma_fragment(B_warp, k_inner, 16)
+
+    sch.transform_layout(A_warp, ("write", 0), index_map_A)
+    sch.transform_layout(B_warp, ("write", 0), index_map_B)
+    sch.transform_layout(C_warp, ("read", 0), index_map_C)
+
+    sch.tensorize(loop_a, ldmatrix_a_intrin)
+    sch.tensorize(loop_b, ldmatrix_b_intrin)
+    sch.tensorize(sch.get_loops(block_inner)[-3], mfma_intrin)
+    sch.tensorize(sch.get_loops(block_init_c)[-2], mfma_fill_intrin)
+    sch.tensorize(sch.get_loops(C_warp)[-2], mfma_store_intrin)
+
+    return sch
diff --git a/python/tvm/testing/utils.py b/python/tvm/testing/utils.py
index 884a885fb2..5f81672748 100644
--- a/python/tvm/testing/utils.py
+++ b/python/tvm/testing/utils.py
@@ -91,7 +91,7 @@ import tvm.tir
 import tvm.te
 import tvm._ffi
 
-from tvm.contrib import nvcc, cudnn
+from tvm.contrib import nvcc, cudnn, rocm
 import tvm.contrib.hexagon._ci_env_check as hexagon
 from tvm.driver.tvmc.frontends import load_model
 from tvm.error import TVMError
@@ -914,6 +914,14 @@ requires_rocm = Feature(
     parent_features="gpu",
 )
 
+# Mark a test as requiring a matrixcore to run
+requires_matrixcore = Feature(
+    "matrixcore",
+    "AMD Matrix Core",
+    run_time_check=lambda: tvm.rocm().exist and 
rocm.have_matrixcore(tvm.rocm().compute_version),
+    parent_features="rocm",
+)
+
 # Mark a test as requiring the metal runtime
 requires_metal = Feature(
     "metal",
@@ -1239,7 +1247,6 @@ def requires_package(*packages):
 
 
 def parametrize_targets(*args):
-
     """Parametrize a test over a specific set of targets.
 
     Use this decorator when you want your test to be run over a
@@ -1503,7 +1510,6 @@ def parameters(*value_sets, ids=None):
 
     outputs = []
     for param_values in zip(*value_sets):
-
         # Optional cls parameter in case a parameter is defined inside a
         # class scope.
         def fixture_func(*_cls, request):
@@ -2029,7 +2035,6 @@ class CompareBeforeAfter:
             return inner
 
         if hasattr(transform, "_pytestfixturefunction"):
-
             if not hasattr(cls, "_transform_orig"):
                 cls._transform_orig = transform
 
@@ -2050,7 +2055,6 @@ class CompareBeforeAfter:
                 return apply(transform(self))
 
         else:
-
             raise TypeError(
                 "Expected transform to be a tvm.ir.transform.Pass, or a method 
returning a Pass"
             )
diff --git a/python/tvm/tir/tensor_intrin/rocm.py 
b/python/tvm/tir/tensor_intrin/rocm.py
index 3700f3e8da..4b7c0da955 100644
--- a/python/tvm/tir/tensor_intrin/rocm.py
+++ b/python/tvm/tir/tensor_intrin/rocm.py
@@ -18,8 +18,13 @@
 """Intrinsics for AMDGPU tensorization."""
 from tvm.script import tir as T
 
-from .. import TensorIntrin
+from tvm.runtime import convert
+from tvm.tir.expr import Cast, IntImm
 from .dot_product_common import dp4a_desc
+from .. import TensorIntrin
+
+
+lift = convert
 
 
 @T.prim_func
@@ -46,3 +51,429 @@ def sdot4(
 AMDGPU_SDOT4_INTRIN = "sdot4"
 
 TensorIntrin.register(AMDGPU_SDOT4_INTRIN, dp4a_desc, sdot4)
+
+WARP_SIZE = 64
+M_DIM = 16
+N_DIM = 16
+
+
+def shared_16x4_to_local_64x1_layout_A(i, j):
+    thread_id = j * 16 + i
+    return thread_id, convert(0)
+
+
+def thread_id_shared_access_64x1_to_16x4_layout_A(thread_id, local_id):
+    i = thread_id % 16
+    j = thread_id // 16 + local_id
+    return i, j
+
+
+def shared_4x16_to_local_64x1_layout_B(i, j):
+    thread_id = i * 16 + j
+    return thread_id, convert(0)
+
+
+def thread_id_shared_access_64x1_to_4x16_layout_B(thread_id, local_id):
+    i = thread_id // 16
+    j = thread_id % 16 + local_id
+    return i, j
+
+
+def shared_16x16_to_local_64x4_layout_C(i, j):
+    thread_id = j + (i // 4) * 16
+    local = i % 4
+    return thread_id, local
+
+
+def thread_id_shared_access_64x4_to_16x16_layout_A(thread_id, local_id):
+    i = thread_id % 16
+    j = (thread_id // 16) * 4 + local_id
+    return i, j
+
+
+def shared_16x16_to_local_64x4_layout_A(i, j):
+    thread_id = i + 16 * (j // 4)
+    local = j % 4
+    return thread_id, local
+
+
+def thread_id_shared_access_64x4_to_16x16_layout_B(thread_id, local_id):
+    i = local_id + (thread_id // 16) * 4
+    j = thread_id % 16
+    return i, j
+
+
+def shared_16x16_to_local_64x4_layout_B(i, j):
+    thread_id = j + (i // 4) * 16
+    local = i % 4
+    return thread_id, local
+
+
+def thread_id_shared_access_64x4_to_16x16_layout_C(thread_id, local_id):
+    i = local_id + (thread_id // 16) * 4
+    j = thread_id % 16
+    return i, j
+
+
+def get_mma_fill_intrin(dtype, local_size):
+    zero = IntImm("int32", 0).astype(dtype)
+
+    # Assume M = N = 16
+    index_map = shared_16x16_to_local_64x4_layout_C
+
+    @T.prim_func
+    def mma_fill_desc(a: T.handle) -> None:
+        C_warp = T.match_buffer(a, [WARP_SIZE, local_size], dtype=dtype, 
scope="warp")
+
+        with T.block("root"):
+            T.reads()
+            T.writes(C_warp[0:WARP_SIZE, 0:local_size])
+            for i0, i1 in T.grid(M_DIM, N_DIM):
+                with T.block("C_warp"):
+                    i, j = T.axis.remap("SS", [i0, i1])
+                    thread_id, local_id = T.meta_var(index_map(i, j))
+                    T.reads()
+                    T.writes(C_warp[thread_id, local_id])
+                    C_warp[thread_id, local_id] = zero
+
+    @T.prim_func
+    def mma_fill_impl(a: T.handle) -> None:
+        C_warp = T.match_buffer(
+            a, [WARP_SIZE, local_size], dtype=dtype, scope="warp", 
offset_factor=1
+        )
+
+        with T.block("root"):
+            T.reads()
+            T.writes(C_warp[0:WARP_SIZE, 0:local_size])
+            tx = T.env_thread("threadIdx.x")
+            T.launch_thread(tx, WARP_SIZE)
+            for local_id in T.serial(0, local_size):
+                C_warp[tx, local_id] = zero
+
+    return mma_fill_desc, mma_fill_impl
+
+
+def get_mfma_load_intrin(
+    k_dim=4,
+    dtype="float32",
+    scope="shared",
+    is_b=False,
+    transposed=False,
+):
+    local_size = (M_DIM * k_dim) // WARP_SIZE if not is_b else (N_DIM * k_dim) 
// WARP_SIZE
+    memory_shape = (M_DIM, k_dim)
+    if is_b:
+        memory_shape = (N_DIM, k_dim) if transposed else (k_dim, N_DIM)
+
+    row_dim, col_dim = memory_shape
+
+    if k_dim == 4:
+        index_map = shared_16x4_to_local_64x1_layout_A
+        reverse_index_map = thread_id_shared_access_64x1_to_16x4_layout_A
+        if is_b:
+            index_map = (
+                shared_16x4_to_local_64x1_layout_A
+                if transposed
+                else shared_4x16_to_local_64x1_layout_B
+            )
+            reverse_index_map = (
+                thread_id_shared_access_64x1_to_16x4_layout_A
+                if transposed
+                else thread_id_shared_access_64x1_to_4x16_layout_B
+            )
+    elif k_dim == 16:
+        index_map = shared_16x16_to_local_64x4_layout_A
+        reverse_index_map = thread_id_shared_access_64x4_to_16x16_layout_A
+
+        if is_b:
+            index_map = (
+                shared_16x16_to_local_64x4_layout_A
+                if transposed
+                else shared_16x16_to_local_64x4_layout_B
+            )
+            reverse_index_map = (
+                thread_id_shared_access_64x4_to_16x16_layout_A
+                if transposed
+                else thread_id_shared_access_64x4_to_16x16_layout_B
+            )
+    else:
+        raise ValueError("k_dim must be 4 or 16 currently")
+
+    @T.prim_func
+    def mfma_load_desc(reg_handle: T.handle, memory_handle: T.handle) -> None:
+        memory = T.match_buffer(
+            memory_handle,
+            memory_shape,
+            dtype,
+            offset_factor=1,
+            scope=scope,
+        )
+        reg = T.match_buffer(
+            reg_handle, (WARP_SIZE, local_size), dtype, offset_factor=1, 
scope="warp"
+        )
+
+        with T.block("root"):
+            T.reads(memory[0:row_dim, 0:col_dim])
+            T.writes(reg[0:WARP_SIZE, 0:local_size])
+
+            for ax0, ax1 in T.grid(row_dim, col_dim):
+                with T.block("memory_reg"):
+                    v0, v1 = T.axis.remap("SS", [ax0, ax1])
+                    T.reads(memory[v0, v1])
+
+                    thread_id, local_id = T.meta_var(index_map(v0, v1))
+                    T.writes(reg[thread_id, local_id])
+                    reg[thread_id, local_id] = memory[v0, v1]
+
+    @T.prim_func
+    def mfma_load_impl(reg_handle: T.handle, memory_handle: T.handle) -> None:
+        s0 = T.int32()
+        s1 = T.int32()
+
+        memory = T.match_buffer(
+            memory_handle,
+            memory_shape,
+            dtype,
+            align=64,
+            offset_factor=1,
+            scope=scope,
+            strides=[s0, s1],
+        )
+        reg = T.match_buffer(
+            reg_handle, (WARP_SIZE, local_size), dtype, align=64, 
offset_factor=1, scope="warp"
+        )
+
+        with T.block("root"):
+            T.reads(memory[0:row_dim, 0:col_dim])
+            T.writes(reg[0:WARP_SIZE, 0:local_size])
+            tx = T.env_thread("threadIdx.x")
+            for local_id in T.serial(0, local_size):
+                row, col = T.meta_var(reverse_index_map(tx, local_id))
+                T.launch_thread(tx, WARP_SIZE)
+                reg[tx, local_id] = memory[row, col]
+
+    return mfma_load_desc, mfma_load_impl
+
+
+def get_mfma_intrin(k_dim, in_dtype="float32", out_dtype="float32", 
b_transposed=False):
+    local_size = (M_DIM * k_dim) // WARP_SIZE
+    local_size_out = (M_DIM * N_DIM) // WARP_SIZE
+    if k_dim == 4:
+        index_map_A = shared_16x4_to_local_64x1_layout_A
+        index_map_B = shared_4x16_to_local_64x1_layout_B
+        index_map_C = shared_16x16_to_local_64x4_layout_C
+    elif k_dim == 16:
+        index_map_A = shared_16x16_to_local_64x4_layout_A
+        index_map_B = shared_16x16_to_local_64x4_layout_B
+        index_map_C = shared_16x16_to_local_64x4_layout_C
+    else:
+        raise ValueError("k_dim must be 4 or 16 currently")
+
+    out_dtype_abbrv = {"float16": "f16", "float32": "f32", "int8": "i8", 
"int32": "i32"}[out_dtype]
+
+    in_dtype_abbrv = {"float16": "f16", "float32": "f32", "int8": "i8", 
"int32": "i32"}[in_dtype]
+
+    mfma_intrin = 
f"llvm.amdgcn.mfma.{out_dtype_abbrv}.{M_DIM}x{N_DIM}x{k_dim}{in_dtype_abbrv}"
+
+    def maybe_cast(v):
+        if out_dtype != in_dtype:
+            return Cast(out_dtype, v)
+        return v
+
+    def maybe_swap(i, j):
+        if b_transposed:
+            return j, i
+        return i, j
+
+    @T.prim_func
+    def mfma_sync_desc(a: T.handle, b: T.handle, c: T.handle) -> None:
+        A = T.match_buffer(a, (WARP_SIZE, local_size), in_dtype, 
offset_factor=1, scope="warp")
+        B = T.match_buffer(b, (WARP_SIZE, local_size), in_dtype, 
offset_factor=1, scope="warp")
+        C = T.match_buffer(c, (WARP_SIZE, local_size_out), out_dtype, 
offset_factor=1, scope="warp")
+
+        with T.block("root"):
+            T.reads(
+                C[0:WARP_SIZE, 0:local_size_out],
+                A[0:WARP_SIZE, 0:local_size],
+                B[0:WARP_SIZE, 0:local_size],
+            )
+            T.writes(C[0:WARP_SIZE, 0:local_size_out])
+
+            for i, j, k in T.grid(M_DIM, N_DIM, k_dim):
+                with T.block("C"):
+                    i, j, k = T.axis.remap("SSR", [i, j, k])
+                    b_row_ind, b_col_ind = T.meta_var(maybe_swap(k, j))
+
+                    thread_id_C, local_id_C = T.meta_var(index_map_C(i, j))
+                    thread_id_A, local_id_A = T.meta_var(index_map_A(i, k))
+                    thread_id_B, local_id_B = 
T.meta_var(index_map_B(b_row_ind, b_col_ind))
+
+                    T.reads(
+                        C[thread_id_C, local_id_C],
+                        A[thread_id_A, local_id_A],
+                        B[thread_id_B, local_id_B],
+                    )
+                    T.writes(C[thread_id_C, local_id_C])
+
+                    C[thread_id_C, local_id_C] += maybe_cast(
+                        A[thread_id_A, local_id_A]
+                    ) * maybe_cast(B[thread_id_B, local_id_B])
+
+    @T.prim_func
+    def mfma_sync_impl_float(a: T.handle, b: T.handle, c: T.handle) -> None:
+        A = T.match_buffer(a, (WARP_SIZE, local_size), in_dtype, 
offset_factor=1, scope="warp")
+        B = T.match_buffer(b, (WARP_SIZE, local_size), in_dtype, 
offset_factor=1, scope="warp")
+        C = T.match_buffer(c, (WARP_SIZE, local_size_out), out_dtype, 
offset_factor=1, scope="warp")
+
+        with T.block("root"):
+            T.reads(
+                A[0:WARP_SIZE, 0:local_size],
+                B[0:WARP_SIZE, 0:local_size],
+                C[0:WARP_SIZE, 0:local_size_out],
+            )
+            T.writes(C[0:WARP_SIZE, 0:local_size_out])
+            tx = T.env_thread("threadIdx.x")
+            T.launch_thread(tx, WARP_SIZE)
+            C[tx, 0:local_size_out] = T.call_llvm_pure_intrin(
+                T.llvm_lookup_intrinsic_id(mfma_intrin),
+                T.uint32(6),
+                A[tx, 0:local_size],
+                B[tx, 0:local_size],
+                C[tx, 0:local_size_out],
+                T.int32(0),
+                T.int32(0),
+                T.int32(0),
+                dtype=f"{out_dtype}x4",
+            )
+
+    @T.prim_func
+    def mfma_sync_impl_integer(a: T.handle, b: T.handle, c: T.handle) -> None:
+        A = T.match_buffer(a, (WARP_SIZE, local_size), in_dtype, 
offset_factor=1, scope="warp")
+        B = T.match_buffer(b, (WARP_SIZE, local_size), in_dtype, 
offset_factor=1, scope="warp")
+        C = T.match_buffer(c, (WARP_SIZE, local_size_out), out_dtype, 
offset_factor=1, scope="warp")
+
+        with T.block("root"):
+            T.reads(
+                A[0:WARP_SIZE, 0:local_size],
+                B[0:WARP_SIZE, 0:local_size],
+                C[0:WARP_SIZE, 0:local_size_out],
+            )
+            T.writes(C[0:WARP_SIZE, 0:local_size_out])
+            tx = T.env_thread("threadIdx.x")
+            T.launch_thread(tx, WARP_SIZE)
+
+            C[tx, 0:local_size_out] = T.call_llvm_pure_intrin(
+                T.llvm_lookup_intrinsic_id(mfma_intrin),
+                T.uint32(6),
+                T.call_intrin("int32", "tir.reinterpret", A[tx, 0:local_size]),
+                T.call_intrin("int32", "tir.reinterpret", A[tx, 0:local_size]),
+                C[tx, 0:local_size_out],
+                T.int32(0),
+                T.int32(0),
+                T.int32(0),
+                dtype=f"{out_dtype}x4",
+            )
+
+    return (
+        (mfma_sync_desc, mfma_sync_impl_integer)
+        if in_dtype == "int8"
+        else (mfma_sync_desc, mfma_sync_impl_float)
+    )
+
+
+def get_mfma_store_intrin(local_size=4, dtype="float32", scope="global"):
+    index_map = shared_16x16_to_local_64x4_layout_C
+
+    @T.prim_func
+    def mfma_store_desc(a: T.handle, c: T.handle) -> None:
+        C_warp = T.match_buffer(a, [WARP_SIZE, local_size], dtype=dtype, 
scope="warp")
+        C = T.match_buffer(c, [M_DIM, N_DIM], dtype=dtype, scope=scope)
+
+        with T.block("root"):
+            T.reads(C_warp[0:WARP_SIZE, 0:local_size])
+            T.writes(C[0:M_DIM, 0:N_DIM])
+            for i0, i1 in T.grid(M_DIM, N_DIM):
+                with T.block("C_warp"):
+                    v0, v1 = T.axis.remap("SS", [i0, i1])
+                    thread_id, local_id = T.meta_var(index_map(v0, v1))
+                    T.reads(C_warp[thread_id, local_id])
+                    T.writes(C[v0, v1])
+                    C[v0, v1] = C_warp[thread_id, local_id]
+
+    @T.prim_func
+    def mfma_store_impl(a: T.handle, c: T.handle) -> None:
+        s0 = T.int32()
+        s1 = T.int32()
+
+        C_warp = T.match_buffer(
+            a, [WARP_SIZE, local_size], dtype=dtype, scope="warp", 
offset_factor=1
+        )
+        C = T.match_buffer(
+            c, [M_DIM, N_DIM], dtype=dtype, scope=scope, offset_factor=1, 
strides=[s0, s1]
+        )
+
+        with T.block("root"):
+            T.reads(C_warp[0:WARP_SIZE, 0:local_size])
+            T.writes(C[0:M_DIM, 0:N_DIM])
+            tx = T.env_thread("threadIdx.x")
+            T.launch_thread(tx, WARP_SIZE)
+            for i in range(local_size):
+                C[((tx // 16) * 4) + i, (tx % 16)] = C_warp[tx, i]
+
+    return mfma_store_desc, mfma_store_impl
+
+
+ROCM_MFMA_fill_16x16_f32_INTRIN = "ROCM_mfma_fill_16x16_f32"
+TensorIntrin.register(ROCM_MFMA_fill_16x16_f32_INTRIN, 
*get_mma_fill_intrin("float32", 4))
+
+ROCM_MFMA_fill_16x16_i32_INTRIN = "ROCM_mfma_fill_16x16_i32"
+TensorIntrin.register(ROCM_MFMA_fill_16x16_i32_INTRIN, 
*get_mma_fill_intrin("int", 4))
+
+ROCM_MFMA_LOAD_16x16_A_SHARED_s8_INTRIN = "rocm_mfma_load_16x16_a_shared_s8"
+TensorIntrin.register(
+    ROCM_MFMA_LOAD_16x16_A_SHARED_s8_INTRIN, *get_mfma_load_intrin(16, "int8", 
"shared")
+)
+ROCM_MFMA_LOAD_16x16_B_SHARED_s8_INTRIN = "rocm_mfma_load_b_16x16_shared_s8"
+TensorIntrin.register(
+    ROCM_MFMA_LOAD_16x16_B_SHARED_s8_INTRIN, *get_mfma_load_intrin(16, "int8", 
"shared", is_b=True)
+)
+
+ROCM_MFMA_LOAD_16x16_A_SHARED_f16_INTRIN = "rocm_mfma_load_16x16_a_shared_f16"
+TensorIntrin.register(
+    ROCM_MFMA_LOAD_16x16_A_SHARED_f16_INTRIN, *get_mfma_load_intrin(16, 
"float16", "shared")
+)
+ROCM_MFMA_LOAD_16x16_B_SHARED_f16_INTRIN = "rocm_mfma_load_b_16x16_shared_f16"
+TensorIntrin.register(
+    ROCM_MFMA_LOAD_16x16_B_SHARED_f16_INTRIN,
+    *get_mfma_load_intrin(16, "float16", "shared", is_b=True),
+)
+
+ROCM_MFMA_LOAD_16x4_A_SHARED_f32_INTRIN = "rocm_mfma_load_16x4_a_shared_f32"
+TensorIntrin.register(
+    ROCM_MFMA_LOAD_16x4_A_SHARED_f32_INTRIN, *get_mfma_load_intrin(4, 
"float32", "shared")
+)
+ROCM_MFMA_LOAD_16x4_B_SHARED_f32_INTRIN = "rocm_mfma_load_b_16x4_shared_f32"
+TensorIntrin.register(
+    ROCM_MFMA_LOAD_16x4_B_SHARED_f32_INTRIN,
+    *get_mfma_load_intrin(4, "float32", "shared", is_b=True),
+)
+
+
+ROCM_MFMA_f32f32f32_INTRIN = "rocm_mfma_f32f32f32"
+TensorIntrin.register(ROCM_MFMA_f32f32f32_INTRIN, *get_mfma_intrin(4, 
"float32", "float32"))
+
+ROCM_MFMA_f16f16f32_INTRIN = "rocm_mfma_f16f16f32"
+TensorIntrin.register(ROCM_MFMA_f16f16f32_INTRIN, *get_mfma_intrin(16, 
"float16", "float32"))
+
+ROCM_MFMA_s8s8s32_INTRIN = "rocm_mfma_s8s8s32"
+TensorIntrin.register(ROCM_MFMA_s8s8s32_INTRIN, *get_mfma_intrin(16, "int8", 
"int32"))
+
+ROCM_MFMA_STORE_16x16_s32_INTRIN = "rocm_mfma_store_16x16_s32"
+TensorIntrin.register(
+    ROCM_MFMA_STORE_16x16_s32_INTRIN, *get_mfma_store_intrin(4, "int32", 
"global")
+)
+
+ROCM_MFMA_STORE_16x16_f32_INTRIN = "rocm_mfma_store_16x16_f32"
+TensorIntrin.register(
+    ROCM_MFMA_STORE_16x16_f32_INTRIN, *get_mfma_store_intrin(4, "float32", 
"global")
+)
diff --git a/src/target/target_kind.cc b/src/target/target_kind.cc
index 7d48d9b158..174ae6d43f 100644
--- a/src/target/target_kind.cc
+++ b/src/target/target_kind.cc
@@ -222,22 +222,20 @@ TargetJSON UpdateNVPTXAttrs(TargetJSON target) {
  * \return The updated attributes
  */
 TargetJSON UpdateROCmAttrs(TargetJSON target) {
+  using tvm::runtime::Registry;
   CheckOrSetAttr(&target, "mtriple", "amdgcn-amd-amdhsa-hcc");
   // Update -mcpu=gfx
-  std::string arch;
+  std::string arch = "gfx900";
   if (target.count("mcpu")) {
     String mcpu = Downcast<String>(target.at("mcpu"));
     arch = ExtractStringWithPrefix(mcpu, "gfx");
     ICHECK(!arch.empty()) << "ValueError: ROCm target gets an invalid GFX 
version: -mcpu=" << mcpu;
   } else {
     TVMRetValue val;
-    if (!DetectDeviceFlag({kDLROCM, 0}, runtime::kGcnArch, &val)) {
-      LOG(WARNING) << "Unable to detect ROCm compute arch, default to 
\"-mcpu=gfx900\" instead";
-      arch = "900";
-    } else {
-      arch = val.operator std::string();
+    if (const auto* f_get_rocm_arch = 
Registry::Get("tvm_callback_rocm_get_arch")) {
+      arch = (*f_get_rocm_arch)().operator std::string();
     }
-    target.Set("mcpu", String("gfx") + arch);
+    target.Set("mcpu", String(arch));
   }
   // Update -mattr before ROCm 3.5:
   //   Before ROCm 3.5 we needed code object v2, starting
diff --git a/tests/python/unittest/test_tir_schedule_tensorize_mfma.py 
b/tests/python/unittest/test_tir_schedule_tensorize_mfma.py
new file mode 100644
index 0000000000..8077a603bc
--- /dev/null
+++ b/tests/python/unittest/test_tir_schedule_tensorize_mfma.py
@@ -0,0 +1,314 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=missing-docstring
+import tvm
+from tvm import te
+from tvm.tir.tensor_intrin.rocm import (
+    shared_16x4_to_local_64x1_layout_A,
+    shared_4x16_to_local_64x1_layout_B,
+    shared_16x16_to_local_64x4_layout_A,
+    shared_16x16_to_local_64x4_layout_B,
+    shared_16x16_to_local_64x4_layout_C,
+    ROCM_MFMA_fill_16x16_f32_INTRIN,
+    ROCM_MFMA_LOAD_16x4_A_SHARED_f32_INTRIN,
+    ROCM_MFMA_LOAD_16x4_B_SHARED_f32_INTRIN,
+    ROCM_MFMA_f32f32f32_INTRIN,
+    ROCM_MFMA_STORE_16x16_f32_INTRIN,
+    ROCM_MFMA_LOAD_16x16_A_SHARED_f16_INTRIN,
+    ROCM_MFMA_LOAD_16x16_B_SHARED_f16_INTRIN,
+    ROCM_MFMA_f16f16f32_INTRIN,
+    ROCM_MFMA_STORE_16x16_f32_INTRIN,
+    ROCM_MFMA_fill_16x16_i32_INTRIN,
+    ROCM_MFMA_LOAD_16x16_A_SHARED_s8_INTRIN,
+    ROCM_MFMA_LOAD_16x16_B_SHARED_s8_INTRIN,
+    ROCM_MFMA_s8s8s32_INTRIN,
+    ROCM_MFMA_STORE_16x16_s32_INTRIN,
+)
+import tvm.testing
+import numpy as np
+from tvm.testing.tir import mfma_schedule
+
+
+M = 1024
+N = 1024
+K = 1024
+measure_perf = False
+gflops = (N * M * K) * 2 / 1e9
+
+
+def matmul(m, n, k, in_dtype, out_dtype, b_transposed):
+    b_shape = (n, k) if b_transposed else (k, n)
+    a = te.placeholder((m, k), name="A", dtype=in_dtype)
+    b = te.placeholder(b_shape, name="B", dtype=in_dtype)
+    k = te.reduce_axis((0, k), name="k")
+
+    def maybe_cast(v):
+        if in_dtype != out_dtype:
+            return tvm.tir.Cast(out_dtype, v)
+        return v
+
+    def maybe_swap(i, j):
+        if b_transposed:
+            return j, i
+        return i, j
+
+    c = te.compute(
+        (m, n),
+        lambda i, j: te.sum(maybe_cast(a[i, k]) * maybe_cast(b[maybe_swap(k, 
j)]), axis=[k]),
+        name="C",
+    )
+    return (a, b, c)
+
+
def _make_reference(in_dtype, out_dtype, b_transposed):
    """Generate random host inputs A, B and the reference output C.

    For int8 the reference is accumulated in int32: at K=1024 the exact sums
    can reach 1024 * 128 * 128 = 2**24, the edge of float32 integer
    exactness, so a float32 reference could be off by one.
    """
    b_shape = (N, K) if b_transposed else (K, N)
    if in_dtype == "int8":
        a_np = np.random.randint(-128, 128, (M, K)).astype("int8")
        b_np = np.random.randint(-128, 128, b_shape).astype("int8")
        acc_dtype = "int32"
    else:
        a_np = np.random.uniform(size=(M, K)).astype(in_dtype)
        b_np = np.random.uniform(size=b_shape).astype(in_dtype)
        # Float reference is computed in fp32; an exact fp16 host reference
        # would be too slow.
        acc_dtype = "float32"
    rhs = b_np.astype(acc_dtype)
    if b_transposed:
        rhs = rhs.transpose()
    c_np = np.dot(a_np.astype(acc_dtype), rhs).astype(out_dtype)
    return a_np, b_np, c_np


def run_test(
    k_inner,
    in_dtype,
    out_dtype,
    b_transposed,
    i_factors,
    j_factors,
    k_factors,
    index_map_A,
    index_map_B,
    index_map_C,
    ldmatrix_a_intrin,
    ldmatrix_b_intrin,
    mma_intrin,
    mma_fill_intrin,
    mma_store_intrin,
):
    """Schedule, build, run and (when feasible) verify one MFMA matmul.

    Returns a zero-argument thunk that times the compiled kernel with
    time_evaluator, so the caller can optionally report throughput.
    """
    sch = mfma_schedule(
        te.create_prim_func(matmul(M, N, K, in_dtype, out_dtype, b_transposed)),
        k_inner,
        in_dtype,
        b_transposed,
        i_factors,
        j_factors,
        k_factors,
        index_map_A,
        index_map_B,
        index_map_C,
        ldmatrix_a_intrin,
        ldmatrix_b_intrin,
        mma_intrin,
        mma_fill_intrin,
        mma_store_intrin,
    )

    f = tvm.build(sch.mod["main"], target="rocm", name="dense")

    dev = tvm.device("rocm", 0)
    a_np, b_np, c_np = _make_reference(in_dtype, out_dtype, b_transposed)

    a = tvm.nd.array(a_np, dev)
    b = tvm.nd.array(b_np, dev)
    c = tvm.nd.array(np.zeros((M, N), dtype=out_dtype), dev)

    f(a, b, c)

    if in_dtype != "float16":
        # The numpy reference is computed with fp32 precision (otherwise too
        # slow), so there is a non-trivial accuracy difference if the TVM
        # result is computed with fp16 accumulation; skip the check then.
        tvm.testing.assert_allclose(c.numpy(), c_np, rtol=1e-2, atol=1e-2)

    return lambda: f.time_evaluator(f.entry_name, dev, number=500)(a, b, c)
+
+
[email protected]_matrixcore
+def test_i8i8i32_m16n16k16():
+    def index_map_A(i, j):
+        return (
+            i // 16,
+            j // 16,
+            *shared_16x16_to_local_64x4_layout_A(i % 16, j % 16),
+        )
+
+    def index_map_B(i, j):
+        return (
+            i // 16,
+            j // 16,
+            *shared_16x16_to_local_64x4_layout_B(i % 16, j % 16),
+        )
+
+    def index_map_C(i, j):
+        return (
+            i // 16,
+            j // 16,
+            *shared_16x16_to_local_64x4_layout_C(i % 16, j % 16),
+        )
+
+    k_inner = 16
+    in_dtype = "int8"
+    out_dtype = "int32"
+    i_factors, j_factors, k_factors = [1, 8, 2, 4, 1], [1, 16, 2, 1, 2], [32, 
2, 1]
+
+    timer = run_test(
+        k_inner,
+        in_dtype,
+        out_dtype,
+        False,  # b_transposed
+        i_factors,
+        j_factors,
+        k_factors,
+        index_map_A,
+        index_map_B,
+        index_map_C,
+        ROCM_MFMA_LOAD_16x16_A_SHARED_s8_INTRIN,
+        ROCM_MFMA_LOAD_16x16_B_SHARED_s8_INTRIN,
+        ROCM_MFMA_s8s8s32_INTRIN,
+        ROCM_MFMA_fill_16x16_i32_INTRIN,
+        ROCM_MFMA_STORE_16x16_s32_INTRIN,
+    )
+
+    if measure_perf and timer:
+        print("test_i8i8i32_m16n16k16: %f GFLOPS" % (gflops / (timer().mean)))
+
+
[email protected]_matrixcore
+def test_f16f16f32_m16n16k16():
+    def index_map_A(i, j):
+        return (
+            i // 16,
+            j // 16,
+            *shared_16x16_to_local_64x4_layout_A(i % 16, j % 16),
+        )
+
+    def index_map_B(i, j):
+        return (
+            i // 16,
+            j // 16,
+            *shared_16x16_to_local_64x4_layout_B(i % 16, j % 16),
+        )
+
+    def index_map_C(i, j):
+        return (
+            i // 16,
+            j // 16,
+            *shared_16x16_to_local_64x4_layout_C(i % 16, j % 16),
+        )
+
+    k_inner = 16
+    in_dtype = "float16"
+    out_dtype = "float32"
+    i_factors, j_factors, k_factors = [1, 8, 2, 4, 1], [1, 16, 2, 1, 2], [32, 
2, 1]
+
+    timer = run_test(
+        k_inner,
+        in_dtype,
+        out_dtype,
+        False,  # b_transposed
+        i_factors,
+        j_factors,
+        k_factors,
+        index_map_A,
+        index_map_B,
+        index_map_C,
+        ROCM_MFMA_LOAD_16x16_A_SHARED_f16_INTRIN,
+        ROCM_MFMA_LOAD_16x16_B_SHARED_f16_INTRIN,
+        ROCM_MFMA_f16f16f32_INTRIN,
+        ROCM_MFMA_fill_16x16_f32_INTRIN,
+        ROCM_MFMA_STORE_16x16_f32_INTRIN,
+    )
+
+    if measure_perf and timer:
+        print("f16f16f32_m16n16k16: %f GFLOPS" % (gflops / (timer().mean)))
+
+
[email protected]_matrixcore
+def test_f32f32f32_m16n16k4():
+    def index_map_A(i, j):
+        return (
+            i // 16,
+            j // 16,
+            *shared_16x4_to_local_64x1_layout_A(i % 16, j % 16),
+        )
+
+    def index_map_B(i, j):
+        return (
+            i // 16,
+            j // 16,
+            *shared_4x16_to_local_64x1_layout_B(i % 16, j % 16),
+        )
+
+    def index_map_C(i, j):
+        return (
+            i // 16,
+            j // 16,
+            *shared_16x16_to_local_64x4_layout_C(i % 16, j % 16),
+        )
+
+    k_inner = 4
+    in_dtype = "float32"
+    out_dtype = "float32"
+    i_factors, j_factors, k_factors = [4, 2, 1, 4, 2], [4, 2, 2, 1, 4], [128, 
2, 1]
+
+    timer = run_test(
+        k_inner,
+        in_dtype,
+        out_dtype,
+        False,  # b_transposed
+        i_factors,
+        j_factors,
+        k_factors,
+        index_map_A,
+        index_map_B,
+        index_map_C,
+        ROCM_MFMA_LOAD_16x4_A_SHARED_f32_INTRIN,
+        ROCM_MFMA_LOAD_16x4_B_SHARED_f32_INTRIN,
+        ROCM_MFMA_f32f32f32_INTRIN,
+        ROCM_MFMA_fill_16x16_f32_INTRIN,
+        ROCM_MFMA_STORE_16x16_f32_INTRIN,
+    )
+
+    if measure_perf and timer:
+        print("test_f32f32f32_m16n16k4: %f GFLOPS" % (gflops / (timer().mean)))
+
+
# Run all tests in this file when executed directly as a script.
if __name__ == "__main__":
    tvm.testing.main()


Reply via email to