vinx13 commented on code in PR #11677:
URL: https://github.com/apache/tvm/pull/11677#discussion_r894961383


##########
python/tvm/tir/tensor_intrin/cuda.py:
##########
@@ -482,3 +484,325 @@ def mma_store_impl(a: T.handle, c: T.handle) -> None:
# Register the m16n16k16 int32 MMA store intrin writing the accumulator back
# to "global" scope. (The meaning of the `8` argument is defined by
# get_mma_store_intrin, which is outside this hunk — presumably a local/vector
# size; confirm against its definition above.)
TensorIntrin.register(
    MMA_store_16x16_i32_global_INTRIN, *get_mma_store_intrin("int32", 8, "global")
)
+
+
+######## WMMA intrinsics ########
+
+
def get_wmma_fragment_index(buffer, m_dim, n_dim):
    """Derive the wmma fragment index from *buffer*'s elem_offset.

    The element offset is decomposed into a count of whole fragments of
    ``m_dim * n_dim`` elements, plus the row offset (in units of ``n_dim``)
    within the trailing partial fragment.
    """
    elements_per_fragment = lift(m_dim * n_dim)
    whole_fragments = buffer.elem_offset // elements_per_fragment
    row_within_fragment = (buffer.elem_offset % elements_per_fragment) // n_dim
    return whole_fragments + row_within_fragment
+
+
def get_wmma_load_intrin(
    m_dim: int,
    n_dim: int,
    k_dim: int,
    dtype: str,
    shared_scope: str,
    is_b: bool,
    is_col_major: bool,
) -> Tuple[PrimFunc, PrimFunc]:
    """Generator of wmma_load intrins.

    Returns a (desc, impl) pair of PrimFuncs: ``desc`` describes the
    computation as an element-wise copy from a ``shared_scope`` buffer into a
    wmma fragment buffer, and ``impl`` lowers it to a single
    ``tvm_load_matrix_sync`` intrinsic call.

    Parameters
    ----------
    m_dim, n_dim : int
        Shape of the 2-D tile being loaded.
    k_dim : int
        Third wmma shape parameter, forwarded verbatim to
        ``tvm_load_matrix_sync``.
    dtype : str
        Element type of both source and destination buffers.
    shared_scope : str
        Memory scope of the source buffer.
    is_b : bool
        If True the destination scope is "wmma.matrix_b", else
        "wmma.matrix_a".
    is_col_major : bool
        If True the layout string passed to the intrinsic is "col_major",
        else "row_major".
    """
    wmma_fragment_scope = "wmma.matrix_{}".format("b" if is_b else "a")
    layout = "col_major" if is_col_major else "row_major"

    @T.prim_func
    def wmma_load_desc(a: T.handle, c: T.handle) -> None:
        # Semantic description: a plain element-wise copy A -> C.
        A = T.match_buffer(
            a, (m_dim, n_dim), dtype, align=128, offset_factor=16, scope=shared_scope
        )
        C = T.match_buffer(
            c, (m_dim, n_dim), dtype, align=128, offset_factor=16, scope=wmma_fragment_scope
        )
        with T.block("root"):
            T.reads(A[0:m_dim, 0:n_dim])
            T.writes(C[0:m_dim, 0:n_dim])
            for i, j in T.grid(m_dim, n_dim):
                with T.block("load"):
                    vii, vjj = T.axis.remap("SS", [i, j])
                    C[vii, vjj] = A[vii, vjj]

    @T.prim_func
    def wmma_load_impl(a: T.handle, c: T.handle) -> None:
        # Symbolic strides so the intrin matches a source buffer with an
        # arbitrary row stride s1 (s0 is the innermost-dimension stride).
        s1 = T.var("int32")
        s0 = T.var("int32")
        A = T.match_buffer(
            a,
            (m_dim, n_dim),
            dtype,
            align=128,
            offset_factor=16,
            scope=shared_scope,
            strides=[s1, s0],
        )
        C = T.match_buffer(
            c, (m_dim, n_dim), dtype, align=128, offset_factor=16, scope=wmma_fragment_scope
        )
        with T.block("root"):
            T.reads(A[0:m_dim, 0:n_dim])
            T.writes(C[0:m_dim, 0:n_dim])
            T.evaluate(
                T.tvm_load_matrix_sync(
                    C.data,
                    m_dim,
                    n_dim,
                    k_dim,
                    # Which fragment within C's backing storage to load into,
                    # derived from C's elem_offset.
                    get_wmma_fragment_index(C, m_dim, n_dim),
                    A.access_ptr("r"),
                    s1,  # row stride of the source buffer
                    layout,
                    dtype="handle",
                )
            )

    return wmma_load_desc, wmma_load_impl
+
+
def get_wmma_fill_intrin(
    m_dim: int, n_dim: int, k_dim: int, dtype: str
) -> Tuple[PrimFunc, PrimFunc]:
    """Generator of wmma_fill intrins.

    Returns a (desc, impl) pair of PrimFuncs: ``desc`` zero-initializes an
    (m_dim, n_dim) "wmma.accumulator" buffer element-wise, and ``impl``
    lowers it to a single ``tvm_fill_fragment`` intrinsic call.
    """
    # Zero constant cast to the accumulator dtype, used in the description.
    zero = IntImm("int32", 0).astype(dtype)

    @T.prim_func
    def wmma_fill_desc(c: T.handle) -> None:
        C = T.match_buffer(
            c, (m_dim, n_dim), dtype, align=128, offset_factor=16, scope="wmma.accumulator"
        )
        with T.block("root"):
            T.reads()
            T.writes(C[0:m_dim, 0:n_dim])
            for i, j in T.grid(m_dim, n_dim):
                with T.block("init"):
                    vii, vjj = T.axis.remap("SS", [i, j])
                    C[vii, vjj] = zero

    @T.prim_func
    def wmma_fill_impl(c: T.handle) -> None:
        C = T.match_buffer(
            c, (m_dim, n_dim), dtype, align=128, offset_factor=16, scope="wmma.accumulator"
        )
        with T.block("root"):
            T.reads()
            T.writes(C[0:m_dim, 0:n_dim])
            T.evaluate(
                T.tvm_fill_fragment(
                    C.data,
                    m_dim,
                    n_dim,
                    k_dim,
                    # Fragment index within C's backing storage.
                    get_wmma_fragment_index(C, m_dim, n_dim),
                    # NOTE(review): the fill value is always float32(0) here
                    # while the desc uses a dtype-typed zero — presumably
                    # tvm_fill_fragment converts as needed; confirm for
                    # non-float accumulator dtypes.
                    T.float32(0),
                    dtype="handle",
                )
            )

    return wmma_fill_desc, wmma_fill_impl
+
+
def get_wmma_store_intrin(
    m_dim: int, n_dim: int, k_dim: int, dtype: str, scope: str
) -> Tuple[PrimFunc, PrimFunc]:
    """Generator of wmma_store intrins.

    Returns a (desc, impl) pair of PrimFuncs: ``desc`` describes an
    element-wise copy from a "wmma.accumulator" fragment buffer into a
    ``scope`` buffer, and ``impl`` lowers it to a single
    ``tvm_store_matrix_sync`` intrinsic call.
    """

    @T.prim_func
    def wmma_store_desc(a: T.handle, c: T.handle) -> None:
        # Semantic description: a plain element-wise copy A -> C.
        A = T.match_buffer(
            a, (m_dim, n_dim), dtype, align=128, offset_factor=16, scope="wmma.accumulator"
        )
        C = T.match_buffer(c, (m_dim, n_dim), dtype, align=128, offset_factor=16, scope=scope)
        with T.block("root"):
            T.reads(A[0:m_dim, 0:n_dim])
            T.writes(C[0:m_dim, 0:n_dim])
            for i, j in T.grid(m_dim, n_dim):
                with T.block("store"):
                    vii, vjj = T.axis.remap("SS", [i, j])
                    C[vii, vjj] = A[vii, vjj]

    @T.prim_func
    def wmma_store_impl(a: T.handle, c: T.handle) -> None:
        # Symbolic strides so the intrin matches a destination buffer with an
        # arbitrary row stride s1 (s0 is the innermost-dimension stride).
        s1 = T.var("int32")
        s0 = T.var("int32")
        A = T.match_buffer(
            a, (m_dim, n_dim), dtype, align=128, offset_factor=16, scope="wmma.accumulator"
        )
        C = T.match_buffer(
            c, (m_dim, n_dim), dtype, align=128, offset_factor=16, scope=scope, strides=[s1, s0]
        )
        with T.block("root"):
            T.reads(A[0:m_dim, 0:n_dim])
            T.writes(C[0:m_dim, 0:n_dim])
            T.evaluate(
                T.tvm_store_matrix_sync(
                    A.data,
                    m_dim,
                    n_dim,
                    k_dim,
                    # Fragment index within A's backing storage.
                    get_wmma_fragment_index(A, m_dim, n_dim),
                    C.access_ptr("w"),
                    s1,  # row stride of the destination buffer
                    # Stores always use "row_major" layout here, unlike loads
                    # which take a layout parameter.
                    "row_major",
                    dtype="handle",
                )
            )

    return wmma_store_desc, wmma_store_impl
+
+
def get_wmma_sync_intrin(
    m_dim: int, n_dim: int, k_dim: int, in_dtype: str, out_dtype: str, b_transposed: bool
) -> Tuple[PrimFunc, PrimFunc]:
    """Generator of wmma_sync intrins.

    Returns a (desc, impl) pair of PrimFuncs: ``desc`` expresses the tile
    matmul-accumulate ``C[i, j] += cast(A[i, k]) * cast(B[...])`` over wmma
    fragment buffers, and ``impl`` lowers it to a single ``tvm_mma_sync``
    intrinsic call.

    Parameters
    ----------
    m_dim, n_dim, k_dim : int
        wmma tile shape: A is (m_dim, k_dim), C is (m_dim, n_dim).
    in_dtype : str
        Element type of the A and B fragments.
    out_dtype : str
        Element type of the C (accumulator) fragment; A and B elements are
        cast to it when it differs from ``in_dtype``.
    b_transposed : bool
        When True, B is shaped (n_dim, k_dim) and indexed as B[j, k] rather
        than (k_dim, n_dim) / B[k, j].
    """

    def maybe_cast(v):
        # Insert a cast to out_dtype only for mixed-precision accumulation.
        if in_dtype != out_dtype:
            return Cast(out_dtype, v)
        return v

    def maybe_swap(i, j):
        # Swap the two B indices/extents when B is stored transposed.
        if b_transposed:
            return j, i
        return i, j

    b_shape_0, b_shape_1 = maybe_swap(k_dim, n_dim)

    @T.prim_func
    def wmma_sync_desc(a: T.handle, b: T.handle, c: T.handle) -> None:
        A = T.match_buffer(
            a, (m_dim, k_dim), in_dtype, align=128, offset_factor=16, scope="wmma.matrix_a"
        )
        B = T.match_buffer(
            b,
            maybe_swap(k_dim, n_dim),
            in_dtype,
            align=128,
            offset_factor=16,
            scope="wmma.matrix_b",
        )
        C = T.match_buffer(
            c, (m_dim, n_dim), out_dtype, align=128, offset_factor=16, scope="wmma.accumulator"
        )

        with T.block("root"):
            # C is both read and written: this is an accumulate, not a store.
            T.reads(C[0:m_dim, 0:n_dim], A[0:m_dim, 0:k_dim], B[0:b_shape_0, 0:b_shape_1])
            T.writes(C[0:m_dim, 0:n_dim])
            for i, j, k in T.grid(m_dim, n_dim, k_dim):
                with T.block(""):
                    vii, vjj, vkk = T.axis.remap("SSR", [i, j, k])
                    B_index_0, B_index_1 = maybe_swap(vkk, vjj)
                    C[vii, vjj] = C[vii, vjj] + maybe_cast(A[vii, vkk]) * maybe_cast(
                        B[B_index_0, B_index_1]
                    )

    @T.prim_func
    def wmma_sync_impl(a: T.handle, b: T.handle, c: T.handle) -> None:
        A = T.match_buffer(
            a, (m_dim, k_dim), in_dtype, align=128, offset_factor=16, scope="wmma.matrix_a"
        )
        B = T.match_buffer(
            b,
            maybe_swap(k_dim, n_dim),
            in_dtype,
            align=128,
            offset_factor=16,
            scope="wmma.matrix_b",
        )
        C = T.match_buffer(
            c, (m_dim, n_dim), out_dtype, align=128, offset_factor=16, scope="wmma.accumulator"
        )

        with T.block("root"):
            T.reads(C[0:m_dim, 0:n_dim], A[0:m_dim, 0:k_dim], B[0:b_shape_0, 0:b_shape_1])
            T.writes(C[0:m_dim, 0:n_dim])
            T.evaluate(
                # C (with its fragment index) is passed twice: as the
                # destination and as the accumulated addend, matching the
                # desc's C = C + A * B.
                T.tvm_mma_sync(
                    C.data,
                    get_wmma_fragment_index(C, m_dim, n_dim),
                    A.data,
                    get_wmma_fragment_index(A, m_dim, k_dim),
                    B.data,
                    get_wmma_fragment_index(B, b_shape_0, b_shape_1),
                    C.data,
                    get_wmma_fragment_index(C, m_dim, n_dim),
                    dtype="handle",
                )
            )

    return wmma_sync_desc, wmma_sync_impl
+
+
# m16n16k16 wmma sync: fp16 x fp16 -> fp32 accumulate, B not transposed
# (B shaped (k, n), indexed B[k, j]).
WMMA_SYNC_16x16x16_f16f16f32_INTRIN = "wmma_sync_16x16x16_f16f16f32"
TensorIntrin.register(
    WMMA_SYNC_16x16x16_f16f16f32_INTRIN,
    *get_wmma_sync_intrin(16, 16, 16, "float16", "float32", False),
)
+
# m16n16k16 wmma sync: fp16 x fp16 -> fp32 accumulate, B transposed
# (B shaped (n, k), indexed B[j, k]).
WMMA_SYNC_16x16x16_f16f16f32_TRANS_INTRIN = "wmma_sync_16x16x16_f16f16f32_trans"
TensorIntrin.register(
    WMMA_SYNC_16x16x16_f16f16f32_TRANS_INTRIN,
    # Fix: the "_trans" variant must pass b_transposed=True; the original
    # passed False, making it identical to the non-trans intrin above.
    *get_wmma_sync_intrin(16, 16, 16, "float16", "float32", True),
)

Review Comment:
   Good catch



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to