masahi commented on code in PR #14673:
URL: https://github.com/apache/tvm/pull/14673#discussion_r1171759762
##########
python/tvm/tir/tensor_intrin/cuda.py:
##########
@@ -1182,3 +1182,263 @@ def get_wmma_intrin_group(
"compute": compute_intrin,
"store": store_intrin,
}
+
+
def get_index_A(elem_offset, stride):
    """Translate a flat element offset of an A-fragment tile into the
    fragment index expected by the m16n8k8 ldmatrix/mma intrinsics.

    The offset is decomposed against ``stride`` into a (row, col) pair; the
    region is viewed as a grid of 32x8 tiles, each occupying 8 fragment
    slots per thread.
    """
    row, col = divmod(elem_offset, stride)
    tiles_per_row = stride // 8
    tile_id = (row // 32) * tiles_per_row + col // 8
    # Rows 16..31 of a 32-row tile land 4 slots after rows 0..15.
    return tile_id * 8 + ((row % 32) // 16) * 4
+
+
def get_index_B(elem_offset, stride):
    """Translate a flat element offset of a B-fragment tile into the
    fragment index expected by the m16n8k8 ldmatrix/mma intrinsics.

    The region is viewed as a grid of 8x32 tiles, each occupying 8 fragment
    slots per thread.
    """
    row, col = divmod(elem_offset, stride)
    tiles_per_row = stride // 32
    tile_id = (row // 8) * tiles_per_row + col // 32
    # Each 8-column group inside a 32-wide tile advances the index by 2.
    return tile_id * 8 + ((col % 32) // 8) * 2
+
+
def get_index_C(elem_offset, stride):
    """Translate a flat element offset of a C-fragment tile into the
    fragment index expected by the m16n8k8 mma intrinsics.

    The region is viewed as a grid of 8x8 tiles; pairs of tile-rows are
    grouped, and within a group column tiles advance the index by 2.
    """
    row, col = divmod(elem_offset, stride)
    tiles_per_row = stride // 8
    tile_row = row // 8
    tile_col = col // 8
    return (tile_row // 2) * 2 * tiles_per_row + tile_row % 2 + tile_col * 2
+
+
+# Descriptor for loading a 32x8 fp16 tile (row major) from shared memory into
+# an m16n8k8 A fragment. Semantically a plain element-wise copy; TVM matches
+# this pattern before substituting the ldmatrix implementation registered with it.
[email protected]_func
+def m16n8k8_load_A_row_major_desc(a: T.handle, c: T.handle) -> None:
+ src = T.match_buffer(a, (32, 8), "float16", align=64, offset_factor=1,
scope="shared.dyn")
+ dst = T.match_buffer(c, (32, 8), "float16", align=64, offset_factor=1,
scope="m16n8k8.matrixA")
+
+ with T.block("root"):
+ T.reads(src[0:32, 0:8])
+ T.writes(dst[0:32, 0:8])
+ for i, j in T.grid(32, 8):
+ with T.block("m16n8k8_load_A"):
+ vi, vj = T.axis.remap("SS", [i, j])
+ dst[vi, vj] = src[vi, vj]
+
+
+# Implementation of the A-fragment load: 32 threads issue one ptx_ldmatrix
+# (trans=False, 4 matrices of .b16) to move the 32x8 tile from shared memory
+# into the fragment buffer. Destination slot comes from get_index_A.
[email protected]_func
+def m16n8k8_load_A_row_major_impl(a: T.handle, c: T.handle) -> None:
+ # Symbolic strides so the intrinsic matches tiles embedded in larger buffers.
+ s0 = T.int32()
+ s1 = T.int32()
+ src = T.match_buffer(
+ a, (32, 8), "float16", align=64, offset_factor=1, scope="shared.dyn",
strides=[s0, s1]
+ )
+
+ d0 = T.int32()
+ d1 = T.int32()
+ dst = T.match_buffer(
+ c, (32, 8), "float16", align=64, offset_factor=1,
scope="m16n8k8.matrixA", strides=[d0, d1]
+ )
+
+ with T.block("root"):
+ T.reads(src[0:32, 0:8])
+ T.writes(dst[0:32, 0:8])
+
+ tx = T.env_thread("threadIdx.x")
+ T.launch_thread(tx, 32)
+
+ T.evaluate(
+ T.ptx_ldmatrix(
+ False, # trans
+ 4, # Always load 4 matrices
+ ".b16",
+ dst.data,
+ get_index_A(dst.elem_offset, d0),
+ src.access_ptr("r"),
+ tx * s0, # each lane supplies the address of one source row
+ dtype="float16",
+ )
+ )
+
+
+# Descriptor for loading an 8x32 fp16 tile (row major) from shared memory into
+# an m16n8k8 B fragment. Semantically an element-wise copy, pattern-matched
+# against the transposed-ldmatrix implementation below.
[email protected]_func
+def m16n8k8_load_B_row_major_desc(a: T.handle, c: T.handle) -> None:
+ src = T.match_buffer(a, (8, 32), "float16", align=64, offset_factor=1,
scope="shared.dyn")
+ dst = T.match_buffer(c, (8, 32), "float16", align=64, offset_factor=1,
scope="m16n8k8.matrixB")
+
+ with T.block("root"):
+ T.reads(src[0:8, 0:32])
+ T.writes(dst[0:8, 0:32])
+ for i, j in T.grid(8, 32):
+ with T.block("m16n8k8_load_B"):
+ vi, vj = T.axis.remap("SS", [i, j])
+ dst[vi, vj] = src[vi, vj]
+
+
+# Implementation of the B-fragment load: ptx_ldmatrix with trans=True so the
+# row-major 8x32 shared-memory tile is delivered in the column layout the mma
+# below expects ("col" operand layout). Destination slot via get_index_B.
[email protected]_func
+def m16n8k8_load_B_row_major_impl(a: T.handle, c: T.handle) -> None:
+ s0 = T.int32()
+ s1 = T.int32()
+ src = T.match_buffer(
+ a, (8, 32), "float16", align=64, offset_factor=1, scope="shared.dyn",
strides=[s0, s1]
+ )
+ d0 = T.int32()
+ d1 = T.int32()
+ dst = T.match_buffer(
+ c, (8, 32), "float16", align=64, offset_factor=1,
scope="m16n8k8.matrixB", strides=[d0, d1]
+ )
+
+ with T.block("root"):
+ T.reads(src[0:8, 0:32])
+ T.writes(dst[0:8, 0:32])
+
+ tx = T.env_thread("threadIdx.x")
+ T.launch_thread(tx, 32)
+
+ T.evaluate(
+ T.ptx_ldmatrix(
+ True, # trans
+ 4, # Always load 4 matrices
+ ".b16",
+ dst.data,
+ get_index_B(dst.elem_offset, d0),
+ src.access_ptr("r"),
+ # lane tx addresses row tx%8, 8-column group tx//8 of the tile
+ s0 * (tx % 8) + 8 * (tx // 8),
+ dtype="float16",
+ )
+ )
+
+
+# Descriptor for storing an 8x8 fp16 accumulator (C) fragment back to shared
+# memory, expressed as an element-wise copy for pattern matching.
[email protected]_func
+def m16n8k8_store_C_row_major_desc(a: T.handle, c: T.handle) -> None:
+ src = T.match_buffer(a, (8, 8), "float16", align=64, offset_factor=1,
scope="m16n8k8.matrixC")
+ dst = T.match_buffer(c, (8, 8), "float16", align=64, offset_factor=1,
scope="shared.dyn")
+
+ with T.block("root"):
+ T.reads(src[0:8, 0:8])
+ T.writes(dst[0:8, 0:8])
+ for i, j in T.grid(8, 8):
+ with T.block("m16n8k8_store"):
+ vi, vj = T.axis.remap("SS", [i, j])
+ dst[vi, vj] = src[vi, vj]
+
+
+# Implementation of the C store: each of the 32 lanes writes its 2 contiguous
+# fp16 elements (row tx//4, columns tx%4*2 .. tx%4*2+1) with a vectorized copy,
+# covering the 8x8 tile (32 lanes x 2 elements = 64).
[email protected]_func
+def m16n8k8_store_C_row_major_impl(a: T.handle, c: T.handle) -> None:
+ src = T.match_buffer(a, (8, 8), "float16", align=64, offset_factor=1,
scope="m16n8k8.matrixC")
+ dst = T.match_buffer(c, (8, 8), "float16", align=64, offset_factor=1,
scope="shared.dyn")
+
+ with T.block("root"):
+ T.reads(src[0:8, 0:8])
+ T.writes(dst[0:8, 0:8])
+
+ tx = T.env_thread("threadIdx.x")
+ T.launch_thread(tx, 32)
+
+ for i in T.vectorized(2):
+ dst[tx // 4, tx % 4 * 2 + i] = src[tx // 4, tx % 4 * 2 + i]
+
+
+# Descriptor for one m16n8k8 fp16 matrix-multiply-accumulate:
+# C[16x8] += A[16x8] @ B[8x8], written as an explicit 16x8x8 reduction so TVM
+# can match it against the ptx_mma implementation below.
[email protected]_func
+def m16n8k8_sync_desc(a: T.handle, b: T.handle, c: T.handle) -> None:
+ A = T.match_buffer(a, (16, 8), "float16", align=64, offset_factor=1,
scope="m16n8k8.matrixA")
+ B = T.match_buffer(b, (8, 8), "float16", align=64, offset_factor=1,
scope="m16n8k8.matrixB")
+ C = T.match_buffer(c, (16, 8), "float16", align=64, offset_factor=1,
scope="m16n8k8.matrixC")
+
+ with T.block("root"):
+ T.reads(C[0:16, 0:8], A[0:16, 0:8], B[0:8, 0:8])
+ T.writes(C[0:16, 0:8])
+ for i, j, k in T.grid(16, 8, 8):
+ with T.block("m16n8k8_sync"):
+ vi, vj, vkk = T.axis.remap("SSR", [i, j, k])
+ C[vi, vj] = C[vi, vj] + A[vi, vkk] * B[vkk, vj]
+
+
+# Implementation of the m16n8k8 compute step: a single ptx_mma with row-layout
+# A, col-layout B (matching the trans=True B load), fp16 in/out, no saturation.
+# Fragment slots are derived from each buffer's elem_offset and leading stride.
[email protected]_func
+def m16n8k8_sync_impl(a: T.handle, b: T.handle, c: T.handle) -> None:
+ a0 = T.int32()
+ a1 = T.int32()
+ A = T.match_buffer(
+ a, (16, 8), "float16", align=64, offset_factor=1,
scope="m16n8k8.matrixA", strides=[a0, a1]
+ )
+ b0 = T.int32()
+ b1 = T.int32()
+ B = T.match_buffer(
+ b, (8, 8), "float16", align=64, offset_factor=1,
scope="m16n8k8.matrixB", strides=[b0, b1]
+ )
+ c0 = T.int32()
+ c1 = T.int32()
+ C = T.match_buffer(
+ c, (16, 8), "float16", align=64, offset_factor=1,
scope="m16n8k8.matrixC", strides=[c0, c1]
+ )
+
+ with T.block("root"):
+ T.reads(C[0:16, 0:8], A[0:16, 0:8], B[0:8, 0:8])
+ T.writes(C[0:16, 0:8])
+ T.evaluate(
+ T.ptx_mma(
+ "m16n8k8",
+ "row",
+ "col",
+ "fp16",
+ "fp16",
+ "fp16",
+ A.data,
+ get_index_A(A.elem_offset, a0),
+ B.data,
+ get_index_B(B.elem_offset, b0),
+ C.data,
+ get_index_C(C.elem_offset, c0),
+ False, # no saturation
+ dtype="float16",
+ )
+ )
+
+
+# Descriptor for zero-initializing a 16x8 fp16 accumulator fragment, written
+# as an element-wise fill for pattern matching.
[email protected]_func
+def m16n8k8_init_desc(c: T.handle) -> None:
+ dst = T.match_buffer(c, (16, 8), "float16", align=64, offset_factor=1,
scope="m16n8k8.matrixC")
+
+ with T.block("root"):
+ T.reads()
+ T.writes(dst[0:16, 0:8])
+ for i, j in T.grid(16, 8):
+ with T.block("m16n8k8_store"):
+ vi, vj = T.axis.remap("SS", [i, j])
+ dst[vi, vj] = T.float16(0)
+
+
+# Implementation of accumulator init: each lane zeroes 2 contiguous fp16
+# elements in each of the 2 row-halves (b in {0,1}), so 32 lanes x 2 x 2
+# covers all 16x8 = 128 elements.
[email protected]_func
+def m16n8k8_init_impl(c: T.handle) -> None:
+ dst = T.match_buffer(c, (16, 8), "float16", align=64, offset_factor=1,
scope="m16n8k8.matrixC")
+
+ with T.block("root"):
+ T.reads()
+ T.writes(dst[0:16, 0:8])
+
+ tx = T.env_thread("threadIdx.x")
+ T.launch_thread(tx, 32)
+
+ for b in range(2):
+ for i in T.vectorized(2):
+ dst[b * 8 + tx // 4, tx % 4 * 2 + i] = T.float16(0)
+
+
+# Register the m16n8k8 descriptor/implementation pairs under the names used by
+# schedules (init, A/B loads, compute sync, C store). Registration happens at
+# module import time.
+TensorIntrin.register("m16n8k8_init", m16n8k8_init_desc, m16n8k8_init_impl)
+TensorIntrin.register(
+ "m16n8k8_load_A_row_major", m16n8k8_load_A_row_major_desc,
m16n8k8_load_A_row_major_impl
+)
+TensorIntrin.register(
+ "m16n8k8_load_B_row_major", m16n8k8_load_B_row_major_desc,
m16n8k8_load_B_row_major_impl
+)
+TensorIntrin.register("m16n8k8_sync", m16n8k8_sync_desc, m16n8k8_sync_impl)
+TensorIntrin.register(
+ "m16n8k8_store_C_row_major", m16n8k8_store_C_row_major_desc,
m16n8k8_store_C_row_major_impl
+)
Review Comment:
Why can't the existing intrinsic definitions for m16n16k16 be used?
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]