This is an automated email from the ASF dual-hosted git repository.
wuwei pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 7688db7 [PTX] Support mma.sp to use Sparse Tensor Cores and refactor
mma codegen (#10339)
7688db7 is described below
commit 7688db7ac5c4e1a043bf0dddeed75780ec49e70a
Author: Zihao Ye <[email protected]>
AuthorDate: Tue Mar 8 11:33:25 2022 -0800
[PTX] Support mma.sp to use Sparse Tensor Cores and refactor mma codegen
(#10339)
* init
* upd
* upd
* lint
* lint again
* upd
* add m16n8k32 testcase
* format
* use make_tuple instead of initializer list
* add metadata offset
* upd
* docstring and sanity
* add u8s8s32 back
* improvement
* compatible #9727
---
include/tvm/tir/builtin.h | 13 +
src/target/source/codegen_cuda.cc | 49 +-
src/target/source/ptx_mma.cc | 1806 +++++++-------------------
src/target/source/ptx_mma.h | 30 +-
src/tir/op/builtin.cc | 3 +
tests/python/unittest/test_tir_ptx_mma.py | 1 +
tests/python/unittest/test_tir_ptx_mma_sp.py | 346 +++++
7 files changed, 934 insertions(+), 1314 deletions(-)
diff --git a/include/tvm/tir/builtin.h b/include/tvm/tir/builtin.h
index f7e1cfbc..0d9f823 100644
--- a/include/tvm/tir/builtin.h
+++ b/include/tvm/tir/builtin.h
@@ -596,6 +596,19 @@ TVM_DLL const Op& tvm_store_matrix_sync();
*/
TVM_DLL const Op& ptx_mma();
+/*!
+ * \brief tvm intrinsic for sparse tensor core ptx instructions.
+ *
+ * void ptx_mma_sp(StringImm shape, StringImm A_layout, StringImm B_layout,
+ * StringImm A_dtype, StringImm B_dtype, StringImm C_dtype,
+ * Var multiplicand_a, Expr a_index,
+ * Var multiplicand_b, Expr b_index,
+ * Var accumulator, Expr c_index,
+ * Var metadata, Expr meta_index,
+ * Var sparse_selector, bool saturate);
+ */
+TVM_DLL const Op& ptx_mma_sp();
+
// TODO(tvm-team) replace the usage of the vector operations by Shuffle.
/*!
* \brief Get the high level half of the vector
diff --git a/src/target/source/codegen_cuda.cc
b/src/target/source/codegen_cuda.cc
index 0dda079..f74d5cf 100644
--- a/src/target/source/codegen_cuda.cc
+++ b/src/target/source/codegen_cuda.cc
@@ -744,7 +744,8 @@ void CodeGenCUDA::VisitExpr_(const CallNode* op,
std::ostream& os) {
// arg 10: C accumulator
// arg 11: C accumulator index
// arg 12: saturate
- ICHECK_EQ(op->args.size(), 13U);
+ // arg 13: (optional) 1-bit operator (xor or and)
+ ICHECK(op->args.size() == 13U || op->args.size() == 14U);
std::string shape = Downcast<StringImm>(op->args[0])->value;
std::string A_layout = Downcast<StringImm>(op->args[1])->value;
std::string B_layout = Downcast<StringImm>(op->args[2])->value;
@@ -757,11 +758,51 @@ void CodeGenCUDA::VisitExpr_(const CallNode* op,
std::ostream& os) {
std::string b_bias = this->PrintExpr(op->args[9]);
std::string c_ref = this->PrintExpr(op->args[10]);
std::string c_bias = this->PrintExpr(op->args[11]);
- bool saturate = (Downcast<IntImm>(op->args[12])->value != 0);
- std::string asm_code = PrintMMAAssembly(shape, A_layout, B_layout,
A_dtype, B_dtype, C_dtype,
- a_ref, a_bias, b_ref, b_bias,
c_ref, c_bias, saturate);
+ bool saturate = Downcast<Bool>(op->args[12])->value;
+ std::string bit_op = op->args.size() > 13 ?
Downcast<StringImm>(op->args[13])->value : "";
+ std::string asm_code =
+ PrintMMAAssembly(shape, A_layout, B_layout, A_dtype, B_dtype, C_dtype,
a_ref, a_bias, b_ref,
+ b_bias, c_ref, c_bias, "", "", "", bit_op, false,
saturate);
this->stream << asm_code;
+ } else if (op->op.same_as(builtin::ptx_mma_sp())) {
+ // arg 0: shape: mXnXkX
+ // arg 1: A layout: row/col
+ // arg 2: B layout: row/col
+ // arg 3: A precision: fp16, fp32, ...
+ // arg 4: B precision: fp16, fp32, ...
+ // arg 5: C precision: fp16, fp32, ...
+ // arg 6: A multiplicand
+ // arg 7: A multiplicand index
+ // arg 8: B multiplicand
+ // arg 9: B multiplicand index
+ // arg 10: C accumulator
+ // arg 11: C accumulator index
+ // arg 12: metadata
+ // arg 13: metadata index
+ // arg 14: sparse_selector
+ // arg 15: saturate
+ ICHECK_EQ(op->args.size(), 16U);
+ std::string shape = Downcast<StringImm>(op->args[0])->value;
+ std::string A_layout = Downcast<StringImm>(op->args[1])->value;
+ std::string B_layout = Downcast<StringImm>(op->args[2])->value;
+ std::string A_dtype = Downcast<StringImm>(op->args[3])->value;
+ std::string B_dtype = Downcast<StringImm>(op->args[4])->value;
+ std::string C_dtype = Downcast<StringImm>(op->args[5])->value;
+ std::string a_ref = this->PrintExpr(op->args[6]);
+ std::string a_offset = this->PrintExpr(op->args[7]);
+ std::string b_ref = this->PrintExpr(op->args[8]);
+ std::string b_offset = this->PrintExpr(op->args[9]);
+ std::string c_ref = this->PrintExpr(op->args[10]);
+ std::string c_offset = this->PrintExpr(op->args[11]);
+ std::string metadata = this->PrintExpr(op->args[12]);
+ std::string metadata_offset = this->PrintExpr(op->args[13]);
+ std::string sparse_selector = this->PrintExpr(op->args[14]);
+ bool saturate = Downcast<Bool>(op->args[15])->value;
+ std::string asm_code = PrintMMAAssembly(
+ shape, A_layout, B_layout, A_dtype, B_dtype, C_dtype, a_ref, a_offset,
b_ref, b_offset,
+ c_ref, c_offset, metadata, metadata_offset, sparse_selector, "", true,
saturate);
+ this->stream << asm_code;
} else {
CodeGenC::VisitExpr_(op, os);
}
diff --git a/src/target/source/ptx_mma.cc b/src/target/source/ptx_mma.cc
index b618272..d04c018 100644
--- a/src/target/source/ptx_mma.cc
+++ b/src/target/source/ptx_mma.cc
@@ -23,1351 +23,543 @@
#include "ptx_mma.h"
+#include <algorithm>
+#include <string>
+#include <tuple>
+#include <utility>
+#include <vector>
+
namespace tvm {
namespace codegen {
-std::string ReplaceMMAArgument(std::string asm_code, const std::string&
original,
- const std::string& new_arg) {
- size_t len = original.size();
- size_t new_len = new_arg.size();
- size_t pos = asm_code.find(original);
- while (pos != std::string::npos) {
- asm_code = asm_code.replace(pos, len, new_arg);
- pos = asm_code.find(original, pos + new_len);
- }
- return asm_code;
-}
+// PTX related data structures and functions.
+namespace ptx {
-std::string PrintMMAm8n8k4Assembly(const std::string& A_layout, const
std::string& B_layout,
- const std::string& A_dtype, const
std::string& B_dtype,
- const std::string& C_dtype, const
std::string& a_ref,
- const std::string& a_bias, const
std::string& b_ref,
- const std::string& b_bias, const
std::string& c_ref,
- const std::string& c_bias, bool saturate) {
- std::string asm_code = "";
- std::string new_a_ref = "";
- std::string new_b_ref = "";
- std::string new_c_ref = "";
- ICHECK(((A_dtype == "fp16") && (B_dtype == "fp16")) ||
- ((A_dtype == "fp64") && (B_dtype == "fp64")));
- ICHECK(saturate == false) << "Saturate is not allowed for m8n8k4 mma.";
- if ((A_dtype == "fp16") && (B_dtype == "fp16")) {
- // A/B multiplicand is fp16, SM 70 Tensor Core instructions
- ICHECK((C_dtype == "fp16") || (C_dtype == "fp32"));
- if (C_dtype == "fp16") {
- // C accumulator is fp16
- new_a_ref = "((unsigned *)(" + a_ref + " + " + a_bias + "))";
- new_b_ref = "((unsigned *)(" + b_ref + " + " + b_bias + "))";
- new_c_ref = "((unsigned *)(" + c_ref + " + " + c_bias + "))";
- asm_code = R"(
- {
- __asm__ __volatile__(
-
"mma.sync.aligned.m8n8k4.left_layout.right_layout.f16.f16.f16.f16 "
- "{%0,%1,%2,%3}, {%4,%5}, {%6,%7}, "
- "{%8,%9,%10,%11};\n"
- : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3])
- : "r"(A[0]), "r"(A[1]), "r"(B[0]), "r"(B[1]),
- "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]));
- }
- )";
- } else {
- // C accumulator is fp32
- new_a_ref = "((unsigned *)(" + a_ref + " + " + a_bias + "))";
- new_b_ref = "((unsigned *)(" + b_ref + " + " + b_bias + "))";
- new_c_ref = "((float *)(" + c_ref + " + " + c_bias + "))";
- asm_code = R"(
- {
- __asm__ __volatile__(
-
"mma.sync.aligned.m8n8k4.left_layout.right_layout.f32.f16.f16.f32 "
- "{%0,%1,%2,%3,%4,%5,%6,%7}, {%8,%9}, {%10,%11}, "
- "{%12,%13,%14,%15,%16,%17,%18,%19};\n"
- : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3]),
- "=f"(D[4]), "=f"(D[5]), "=f"(D[6]), "=f"(D[7])
- : "r"(A[0]), "r"(A[1]), "r"(B[0]), "r"(B[1]),
- "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3]),
- "f"(C[4]), "f"(C[5]), "f"(C[6]), "f"(C[7]));
- }
- )";
- }
+/*!
+ * \brief PTX data type.
+ * \note
+ * PTX fundamental data types:
+ *
https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#fundamental-types
+ * PTX matrix data types:
+ *
https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-data-types
+ */
+enum class DataType : int {
+ kInt4 = 0,
+ kUInt4 = 1,
+ kInt8 = 2,
+ kUInt8 = 3,
+ kInt16 = 4,
+ kUInt16 = 5,
+ kInt32 = 6,
+ kUInt32 = 7,
+ kInt64 = 8,
+ kUInt64 = 9,
+ kFloat16 = 10,
+ kBFloat16 = 11,
+ kFloat16x2 = 12,
+ kFloat32 = 13,
+ kTensorFloat32 = 14,
+ kFloat64 = 15,
+ kBit1 = 16
+};
+
+static const char* dtype_str[] = {".s4", ".u4", ".s8", ".u8", ".s16",
".u16",
+ ".s32", ".u32", ".s64", ".u64", ".f16",
".bf16",
+ ".f16x2", ".f32", ".tf32", ".f64", ".b1"};
+static const uint32_t num_bits[] = {4, 4, 8, 8, 16, 16, 32, 32, 64, 64, 16,
16, 32, 32, 32, 64, 1};
+
+/*!
+ * \brief Create PTX data type from string.
+ */
+inline DataType DTypeFromString(const std::string str) {
+ if (str == "int4" || str == ".s4") {
+ return DataType::kInt4;
+ } else if (str == "uint4" || str == ".u4") {
+ return DataType::kUInt4;
+ } else if (str == "int8" || str == ".s8") {
+ return DataType::kInt8;
+ } else if (str == "uint8" || str == ".u8") {
+ return DataType::kUInt8;
+ } else if (str == "int16" || str == ".s16") {
+ return DataType::kInt16;
+ } else if (str == "uint16" || str == ".u16") {
+ return DataType::kUInt16;
+ } else if (str == "int32" || str == ".s32") {
+ return DataType::kInt32;
+ } else if (str == "uint32" || str == ".u32") {
+ return DataType::kUInt32;
+ } else if (str == "int64" || str == ".s64") {
+ return DataType::kInt64;
+ } else if (str == "uint64" || str == ".u64") {
+ return DataType::kUInt64;
+ } else if (str == "float16" || str == "fp16" || str == ".f16") {
+ return DataType::kFloat16;
+ } else if (str == "bfloat16" || str == "bf16") {
+ return DataType::kBFloat16;
+ } else if (str == ".f16x2") {
+ return DataType::kFloat16x2;
+ } else if (str == "float32" || str == "fp32" || str == ".f32") {
+ return DataType::kFloat32;
+ } else if (str == "tf32") {
+ return DataType::kTensorFloat32;
+ } else if (str == "float64" || str == "fp64" || str == ".f64") {
+ return DataType::kFloat64;
+ } else if (str == "int1" || str == ".b1") {
+ return DataType::kBit1;
} else {
- // A/B multiplicand is fp64, SM 80 Tensor Core instructions
- ICHECK(C_dtype == "fp64");
- ICHECK((A_layout == "row") && (B_layout == "col"))
- << "SM80 Fp64 Tensor Core instructions "
- << "with shape m8n8k4 expect A layout is row major and B layout is col
major.";
- // C accumulator is fp64
- new_a_ref = "((double *)(" + a_ref + " + " + a_bias + "))";
- new_b_ref = "((double *)(" + b_ref + " + " + b_bias + "))";
- new_c_ref = "((double *)(" + c_ref + " + " + c_bias + "))";
- asm_code = R"(
- {
- __asm__ __volatile__(
- "mma.sync.aligned.m8n8k4.row.col.f64.f64.f64.f64 "
- "{%0,%1}, {%2}, {%3}, "
- "{%4,%5};\n"
- : "=d"(D[0]), "=d"(D[1])
- : "d"(A[0]), "d"(B[0]),
- "d"(C[0]), "d"(C[1]));
- }
- )";
+ LOG(FATAL) << "Unrecognized PTX data type " << str;
+ return DataType(0);
}
- asm_code = ReplaceMMAArgument(asm_code, "left_layout", A_layout);
- asm_code = ReplaceMMAArgument(asm_code, "right_layout", B_layout);
- asm_code = ReplaceMMAArgument(asm_code, "A", new_a_ref);
- asm_code = ReplaceMMAArgument(asm_code, "B", new_b_ref);
- asm_code = ReplaceMMAArgument(asm_code, "C", new_c_ref);
- asm_code = ReplaceMMAArgument(asm_code, "D", new_c_ref);
- return asm_code;
}
-std::string PrintMMAm16n8k8Assembly(const std::string& A_layout, const
std::string& B_layout,
- const std::string& A_dtype, const
std::string& B_dtype,
- const std::string& C_dtype, const
std::string& a_ref,
- const std::string& a_bias, const
std::string& b_ref,
- const std::string& b_bias, const
std::string& c_ref,
- const std::string& c_bias, bool saturate) {
- std::string asm_code = "";
- std::string new_a_ref = "";
- std::string new_b_ref = "";
- std::string new_c_ref = "";
- ICHECK(((A_dtype == "fp16") && (B_dtype == "fp16")) ||
- ((A_dtype == "bf16") && (B_dtype == "bf16")));
- ICHECK(saturate == false) << "Saturate is not allowed for m16n8k8 mma.";
- if ((A_dtype == "fp16") && (B_dtype == "fp16")) {
- // A/B multiplicand is fp16, SM 75 Tensor Core instructions
- ICHECK((C_dtype == "fp16") || (C_dtype == "fp32"));
- ICHECK((A_layout == "row") && (B_layout == "col"))
- << "SM75 Tensor Core instructions "
- << "with shape m16n8k8 expect A layout is row major and B layout is
col major.";
- if (C_dtype == "fp16") {
- // C accumulator is fp16
- new_a_ref = "((unsigned *)(" + a_ref + " + " + a_bias + "))";
- new_b_ref = "((unsigned *)(" + b_ref + " + " + b_bias + "))";
- new_c_ref = "((unsigned *)(" + c_ref + " + " + c_bias + "))";
- asm_code = R"(
- {
- __asm__ __volatile__(
- "mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 "
- "{%0,%1}, {%2,%3}, {%5}, "
- "{%5,%6};\n"
- : "=r"(D[0]), "=r"(D[1])
- : "r"(A[0]), "r"(A[1]), "r"(B[0]),
- "r"(C[0]), "r"(C[1]));
- }
- )";
- } else {
- // C accumulator is fp32
- new_a_ref = "((unsigned *)(" + a_ref + " + " + a_bias + "))";
- new_b_ref = "((unsigned *)(" + b_ref + " + " + b_bias + "))";
- new_c_ref = "((float *)(" + c_ref + " + " + c_bias + "))";
- asm_code = R"(
- {
- __asm__ __volatile__(
- "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 "
- "{%0,%1,%2,%3}, {%4,%5}, {%6}, "
- "{%7,%8,%9,%10};\n"
- : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3])
- : "r"(A[0]), "r"(A[1]), "r"(B[0]),
- "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3]));
- }
- )";
- }
+/*!
+ * \brief Get the string representation of given PTX data type.
+ */
+inline std::string DTypeToString(DataType dtype) { return
dtype_str[static_cast<int>(dtype)]; }
+
+/*!
+ * \brief Get the number of bits of given PTX data type.
+ */
+inline uint32_t DTypeBits(DataType dtype) { return
num_bits[static_cast<int>(dtype)]; }
+
+/*!
+ * \brief Extract the value m, n, k from string m*n*k*
+ */
+inline std::tuple<int, int, int> ParseMMAShape(const std::string& str) {
+ size_t pos_m = str.find("m"), pos_n = str.find("n"), pos_k = str.find("k");
+ CHECK(pos_m != str.npos && pos_n != str.npos && pos_k != str.npos)
+ << "Cannot parse MMA shape " << str;
+ int m = std::stoi(str.substr(pos_m + 1, pos_n - pos_m - 1)),
+ n = std::stoi(str.substr(pos_n + 1, pos_k - pos_n - 1)), k =
std::stoi(str.substr(pos_k + 1));
+ return std::make_tuple(m, n, k);
+}
+
+/*!
+ * \brief Layout Type
+ */
+enum class LayoutType : int { kRowMajor = 0, kColumnMajor = 1 };
+
+/*!
+ * \brief Parse layout type
+ */
+LayoutType LayoutTypeFromString(const std::string& str) {
+ if (str == "row") {
+ return LayoutType::kRowMajor;
+ } else if (str == "col") {
+ return LayoutType::kColumnMajor;
} else {
- // A/B multiplicand is bf16, SM 80 Tensor Core instructions
- ICHECK(C_dtype == "fp32");
- ICHECK((A_layout == "row") && (B_layout == "col"))
- << "SM80 Tensor Core instructions "
- << "with shape m16n8k8 expect A layout is row major and B layout is
col major.";
- // C accumulator is fp32
- new_a_ref = "((unsigned *)(" + a_ref + " + " + a_bias + "))";
- new_b_ref = "((unsigned *)(" + b_ref + " + " + b_bias + "))";
- new_c_ref = "((float *)(" + c_ref + " + " + c_bias + "))";
- asm_code = R"(
- {
- __asm__ __volatile__(
- "mma.sync.aligned.m16n8k8.row.col.f32.bf16.bf16.f32 "
- "{%0,%1,%2,%3}, {%4,%5}, {%6}, "
- "{%7,%8,%9,%10};\n"
- : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3])
- : "r"(A[0]), "r"(A[1]), "r"(B[0]),
- "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3]));
- }
- )";
+ LOG(FATAL) << "Unrecognized layout type " << str;
+ return LayoutType::kRowMajor;
}
- asm_code = ReplaceMMAArgument(asm_code, "left_layout", A_layout);
- asm_code = ReplaceMMAArgument(asm_code, "right_layout", B_layout);
- asm_code = ReplaceMMAArgument(asm_code, "A", new_a_ref);
- asm_code = ReplaceMMAArgument(asm_code, "B", new_b_ref);
- asm_code = ReplaceMMAArgument(asm_code, "C", new_c_ref);
- asm_code = ReplaceMMAArgument(asm_code, "D", new_c_ref);
- return asm_code;
}
-std::string PrintMMAm8n8k16Assembly(const std::string& A_layout, const
std::string& B_layout,
- const std::string& A_dtype, const
std::string& B_dtype,
- const std::string& C_dtype, const
std::string& a_ref,
- const std::string& a_bias, const
std::string& b_ref,
- const std::string& b_bias, const
std::string& c_ref,
- const std::string& c_bias, bool saturate) {
- std::string asm_code = "";
- std::string new_a_ref = "";
- std::string new_b_ref = "";
- std::string new_c_ref = "";
- ICHECK(((A_dtype == "int8") && (B_dtype == "int8")) ||
- ((A_dtype == "uint8") && (B_dtype == "int8")) ||
- ((A_dtype == "int8") && (B_dtype == "uint8")) ||
- ((A_dtype == "uint8") && (B_dtype == "uint8")));
- if ((A_dtype == "int8") && (B_dtype == "int8")) {
- // A/B multiplicand is int8, SM 75 Tensor Core instructions
- ICHECK(C_dtype == "int32");
- ICHECK((A_layout == "row") && (B_layout == "col"))
- << "SM75 Tensor Core instructions "
- << "with shape m8n8k16 expect A layout is row major and B layout is
col major.";
- // C accumulator is int32
- if (!saturate) {
- // no saturate
- new_a_ref = "((unsigned *)(" + a_ref + " + " + a_bias + "))";
- new_b_ref = "((unsigned *)(" + b_ref + " + " + b_bias + "))";
- new_c_ref = "((int *)(" + c_ref + " + " + c_bias + "))";
- asm_code = R"(
- {
- __asm__ __volatile__(
- "mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 "
- "{%0,%1}, {%2}, {%3}, "
- "{%4,%5};\n"
- : "=r"(D[0]), "=r"(D[1])
- : "r"(A[0]), "r"(B[0]),
- "r"(C[0]), "r"(C[1]));
- }
- )";
- } else {
- // saturate
- new_a_ref = "((unsigned *)(" + a_ref + " + " + a_bias + "))";
- new_b_ref = "((unsigned *)(" + b_ref + " + " + b_bias + "))";
- new_c_ref = "((int *)(" + c_ref + " + " + c_bias + "))";
- asm_code = R"(
- {
- __asm__ __volatile__(
- "mma.sync.aligned.m8n8k16.row.col.satfinite.s32.s8.s8.s32 "
- "{%0,%1}, {%2}, {%3}, "
- "{%4,%5};\n"
- : "=r"(D[0]), "=r"(D[1])
- : "r"(A[0]), "r"(B[0]),
- "r"(C[0]), "r"(C[1]));
- }
- )";
- }
- } else if ((A_dtype == "uint8") && (B_dtype == "int8")) {
- // A multiplicand is uint8, B multiplicand is int8
- // SM 75 Tensor Core instructions
- ICHECK(C_dtype == "int32");
- ICHECK((A_layout == "row") && (B_layout == "col"))
- << "SM75 Tensor Core instructions "
- << "with shape m8n8k16 expect A layout is row major and B layout is
col major.";
- // C accumulator is int32
- if (!saturate) {
- // no saturate
- new_a_ref = "((unsigned *)(" + a_ref + " + " + a_bias + "))";
- new_b_ref = "((unsigned *)(" + b_ref + " + " + b_bias + "))";
- new_c_ref = "((int *)(" + c_ref + " + " + c_bias + "))";
- asm_code = R"(
- {
- __asm__ __volatile__(
- "mma.sync.aligned.m8n8k16.row.col.s32.u8.s8.s32 "
- "{%0,%1}, {%2}, {%3}, "
- "{%4,%5};\n"
- : "=r"(D[0]), "=r"(D[1])
- : "r"(A[0]), "r"(B[0]),
- "r"(C[0]), "r"(C[1]));
- }
- )";
- } else {
- // saturate
- new_a_ref = "((unsigned *)(" + a_ref + " + " + a_bias + "))";
- new_b_ref = "((unsigned *)(" + b_ref + " + " + b_bias + "))";
- new_c_ref = "((int *)(" + c_ref + " + " + c_bias + "))";
- asm_code = R"(
- {
- __asm__ __volatile__(
- "mma.sync.aligned.m8n8k16.row.col.satfinite.s32.u8.s8.s32 "
- "{%0,%1}, {%2}, {%3}, "
- "{%4,%5};\n"
- : "=r"(D[0]), "=r"(D[1])
- : "r"(A[0]), "r"(B[0]),
- "r"(C[0]), "r"(C[1]));
- }
- )";
- }
- } else if ((A_dtype == "int8") && (B_dtype == "uint8")) {
- // A multiplicand is int8, B multiplicand is uint8
- // SM 75 Tensor Core instructions
- ICHECK(C_dtype == "int32");
- ICHECK((A_layout == "row") && (B_layout == "col"))
- << "SM75 Tensor Core instructions "
- << "with shape m8n8k16 expect A layout is row major and B layout is
col major.";
- // C accumulator is int32
- if (!saturate) {
- // no saturate
- new_a_ref = "((unsigned *)(" + a_ref + " + " + a_bias + "))";
- new_b_ref = "((unsigned *)(" + b_ref + " + " + b_bias + "))";
- new_c_ref = "((int *)(" + c_ref + " + " + c_bias + "))";
- asm_code = R"(
- {
- __asm__ __volatile__(
- "mma.sync.aligned.m8n8k16.row.col.s32.s8.u8.s32 "
- "{%0,%1}, {%2}, {%3}, "
- "{%4,%5};\n"
- : "=r"(D[0]), "=r"(D[1])
- : "r"(A[0]), "r"(B[0]),
- "r"(C[0]), "r"(C[1]));
- }
- )";
- } else {
- // saturate
- new_a_ref = "((unsigned *)(" + a_ref + " + " + a_bias + "))";
- new_b_ref = "((unsigned *)(" + b_ref + " + " + b_bias + "))";
- new_c_ref = "((int *)(" + c_ref + " + " + c_bias + "))";
- asm_code = R"(
- {
- __asm__ __volatile__(
- "mma.sync.aligned.m8n8k16.row.col.satfinite.s32.s8.u8.s32 "
- "{%0,%1}, {%2}, {%3}, "
- "{%4,%5};\n"
- : "=r"(D[0]), "=r"(D[1])
- : "r"(A[0]), "r"(B[0]),
- "r"(C[0]), "r"(C[1]));
- }
- )";
- }
- } else {
- // A/B multiplicand is uint8, SM 75 Tensor Core instructions
- ICHECK(C_dtype == "int32");
- ICHECK((A_layout == "row") && (B_layout == "col"))
- << "SM75 Tensor Core instructions "
- << "with shape m8n8k16 expect A layout is row major and B layout is
col major.";
- // C accumulator is int32
- if (!saturate) {
- // no saturate
- new_a_ref = "((unsigned *)(" + a_ref + " + " + a_bias + "))";
- new_b_ref = "((unsigned *)(" + b_ref + " + " + b_bias + "))";
- new_c_ref = "((int *)(" + c_ref + " + " + c_bias + "))";
- asm_code = R"(
- {
- __asm__ __volatile__(
- "mma.sync.aligned.m8n8k16.row.col.s32.u8.u8.s32 "
- "{%0,%1}, {%2}, {%3}, "
- "{%4,%5};\n"
- : "=r"(D[0]), "=r"(D[1])
- : "r"(A[0]), "r"(B[0]),
- "r"(C[0]), "r"(C[1]));
- }
- )";
- } else {
- // saturate
- new_a_ref = "((unsigned *)(" + a_ref + " + " + a_bias + "))";
- new_b_ref = "((unsigned *)(" + b_ref + " + " + b_bias + "))";
- new_c_ref = "((int *)(" + c_ref + " + " + c_bias + "))";
- asm_code = R"(
- {
- __asm__ __volatile__(
- "mma.sync.aligned.m8n8k16.row.col.satfinite.s32.u8.u8.s32 "
- "{%0,%1}, {%2}, {%3}, "
- "{%4,%5};\n"
- : "=r"(D[0]), "=r"(D[1])
- : "r"(A[0]), "r"(B[0]),
- "r"(C[0]), "r"(C[1]));
- }
- )";
- }
+static const char* layout_type_str[] = {"row", "col"};
+
+/*!
+ * \brief Convert layout type to string.
+ */
+inline std::string LayoutTypeToString(LayoutType layout) {
+ return layout_type_str[static_cast<int>(layout)];
+}
+
+/*!
+ * \brief MMA Configurations, used to determine validity.
+ */
+struct MMAConfig {
+ explicit MMAConfig(int m, int n, int k, DataType dtype_mul, bool use_bit_op,
bool sparse)
+ : m(m), n(n), k(k), dtype_mul(dtype_mul), use_bit_op(use_bit_op),
sparse(sparse) {}
+ int m, n, k;
+ DataType dtype_mul;
+ bool use_bit_op;
+ bool sparse;
+ inline bool operator==(const MMAConfig& other) {
+ return m == other.m && n == other.n && k == other.k && dtype_mul ==
other.dtype_mul &&
+ use_bit_op == other.use_bit_op && sparse == other.sparse;
+ }
+};
+
+/*!
+ * \brief Valid MMA configurations
+ * \note Reference:
+ *
https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-shape
+ */
+const MMAConfig valid_mma_configs[] = {
+ MMAConfig(8, 8, 4, DataType::kFloat64, false, false),
+ MMAConfig(8, 8, 4, DataType::kFloat16, false, false),
+ MMAConfig(16, 8, 8, DataType::kFloat16, false, false),
+ MMAConfig(16, 8, 16, DataType::kFloat16, false, false),
+ MMAConfig(16, 8, 8, DataType::kBFloat16, false, false),
+ MMAConfig(16, 8, 16, DataType::kBFloat16, false, false),
+ MMAConfig(16, 8, 4, DataType::kTensorFloat32, false, false),
+ MMAConfig(16, 8, 8, DataType::kTensorFloat32, false, false),
+ MMAConfig(8, 8, 16, DataType::kInt8, false, false),
+ MMAConfig(16, 8, 16, DataType::kInt8, false, false),
+ MMAConfig(16, 8, 32, DataType::kInt8, false, false),
+ MMAConfig(8, 8, 16, DataType::kUInt8, false, false),
+ MMAConfig(16, 8, 16, DataType::kUInt8, false, false),
+ MMAConfig(16, 8, 32, DataType::kUInt8, false, false),
+ MMAConfig(8, 8, 32, DataType::kInt4, false, false),
+ MMAConfig(16, 8, 32, DataType::kInt4, false, false),
+ MMAConfig(16, 8, 64, DataType::kInt4, false, false),
+ MMAConfig(8, 8, 32, DataType::kUInt4, false, false),
+ MMAConfig(16, 8, 32, DataType::kUInt4, false, false),
+ MMAConfig(16, 8, 64, DataType::kUInt4, false, false),
+ MMAConfig(8, 8, 128, DataType::kBit1, true, false),
+ MMAConfig(16, 8, 128, DataType::kBit1, true, false),
+ MMAConfig(16, 8, 256, DataType::kBit1, true, false),
+ MMAConfig(16, 8, 16, DataType::kFloat16, false, true),
+ MMAConfig(16, 8, 32, DataType::kFloat16, false, true),
+ MMAConfig(16, 8, 16, DataType::kBFloat16, false, true),
+ MMAConfig(16, 8, 32, DataType::kBFloat16, false, true),
+ MMAConfig(16, 8, 8, DataType::kTensorFloat32, false, true),
+ MMAConfig(16, 8, 16, DataType::kTensorFloat32, false, true),
+ MMAConfig(16, 8, 32, DataType::kInt8, false, true),
+ MMAConfig(16, 8, 64, DataType::kInt8, false, true),
+ MMAConfig(16, 8, 32, DataType::kUInt8, false, true),
+ MMAConfig(16, 8, 64, DataType::kUInt8, false, true),
+ MMAConfig(16, 8, 64, DataType::kInt4, false, true),
+ MMAConfig(16, 8, 128, DataType::kInt4, false, true),
+ MMAConfig(16, 8, 64, DataType::kUInt4, false, true),
+ MMAConfig(16, 8, 128, DataType::kUInt4, false, true),
+};
+
+/*!
+ * \brief Check whether the multiplicand data type and accumulator data type
is valid for MMA
+ * computation.
+ * \param dtype_a The data type of multiplicand a.
+ * \param dtype_b The data type of multiplicand b.
+ * \param dtype_c The data type of accumulator c.
+ * \note Reference:
+ *
https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-data-types
+ */
+void CheckMMADTypeCompatible(DataType dtype_a, DataType dtype_b, DataType
dtype_c) {
+ std::string ab_not_match_err_str = "The multiplicands' data type " +
DTypeToString(dtype_a) +
+ DTypeToString(dtype_b) + " do not match.";
+ // check a and b
+ switch (dtype_a) {
+ case DataType::kBit1:
+ case DataType::kFloat16:
+ case DataType::kBFloat16:
+ case DataType::kTensorFloat32:
+ case DataType::kFloat64:
+ CHECK(dtype_a == dtype_b) << ab_not_match_err_str;
+ break;
+ case DataType::kInt4:
+ case DataType::kUInt4:
+ CHECK(dtype_b == DataType::kInt4 || dtype_b == DataType::kUInt4) <<
ab_not_match_err_str;
+ break;
+ case DataType::kInt8:
+ case DataType::kUInt8:
+ CHECK(dtype_b == DataType::kInt8 || dtype_b == DataType::kUInt8) <<
ab_not_match_err_str;
+ break;
+ default:
+ CHECK(false) << "Invalid multiplicand data types: " <<
DTypeToString(dtype_a)
+ << DTypeToString(dtype_b);
+ }
+ // check a,b and c
+ switch (dtype_a) {
+ case DataType::kBit1:
+ case DataType::kInt4:
+ case DataType::kUInt4:
+ case DataType::kInt8:
+ case DataType::kUInt8:
+ CHECK(dtype_c == DataType::kInt32)
+ << "For multiplicand data type " << DTypeToString(dtype_a) <<
DTypeToString(dtype_b)
+ << ", accumulator data type should be s32.";
+ break;
+ case DataType::kFloat16:
+ CHECK(dtype_c == DataType::kFloat16 || dtype_c == DataType::kFloat32)
+ << "For multiplicand data type f16, accumulator data type should be
f16/f32.";
+ break;
+ case DataType::kBFloat16:
+ case DataType::kTensorFloat32:
+ CHECK(dtype_c == DataType::kFloat32)
+ << "For multiplicand data type bf16/tf32, accumulator data type can
only be f32.";
+ break;
+ case DataType::kFloat64:
+ CHECK(dtype_c == DataType::kFloat64)
+ << "For multiplicand data type f64, accumulator data type can only
be f64.";
+ break;
+ default:
+ CHECK(false) << "Invalid multiplicand/accumulator data types: " <<
DTypeToString(dtype_a)
+ << DTypeToString(dtype_b) << DTypeToString(dtype_c) << ".";
}
- asm_code = ReplaceMMAArgument(asm_code, "left_layout", A_layout);
- asm_code = ReplaceMMAArgument(asm_code, "right_layout", B_layout);
- asm_code = ReplaceMMAArgument(asm_code, "A", new_a_ref);
- asm_code = ReplaceMMAArgument(asm_code, "B", new_b_ref);
- asm_code = ReplaceMMAArgument(asm_code, "C", new_c_ref);
- asm_code = ReplaceMMAArgument(asm_code, "D", new_c_ref);
- return asm_code;
}
-std::string PrintMMAm8n8k32Assembly(const std::string& A_layout, const
std::string& B_layout,
- const std::string& A_dtype, const
std::string& B_dtype,
- const std::string& C_dtype, const
std::string& a_ref,
- const std::string& a_bias, const
std::string& b_ref,
- const std::string& b_bias, const
std::string& c_ref,
- const std::string& c_bias, bool saturate) {
- std::string asm_code = "";
- std::string new_a_ref = "";
- std::string new_b_ref = "";
- std::string new_c_ref = "";
- ICHECK(((A_dtype == "int4") && (B_dtype == "int4")) ||
- ((A_dtype == "uint4") && (B_dtype == "int4")) ||
- ((A_dtype == "int4") && (B_dtype == "uint4")) ||
- ((A_dtype == "uint4") && (B_dtype == "uint4")));
- if ((A_dtype == "int4") && (B_dtype == "int4")) {
- // A/B multiplicand is int4, SM 75 Tensor Core instructions
- ICHECK(C_dtype == "int32");
- ICHECK((A_layout == "row") && (B_layout == "col"))
- << "SM75 Tensor Core instructions "
- << "with shape m8n8k32 expect A layout is row major and B layout is
col major.";
- // C accumulator is int32
- if (!saturate) {
- // no saturate
- new_a_ref = "((unsigned *)(" + a_ref + " + " + a_bias + "))";
- new_b_ref = "((unsigned *)(" + b_ref + " + " + b_bias + "))";
- new_c_ref = "((int *)(" + c_ref + " + " + c_bias + "))";
- asm_code = R"(
- {
- __asm__ __volatile__(
- "mma.sync.aligned.m8n8k32.row.col.s32.s4.s4.s32 "
- "{%0,%1}, {%2}, {%3}, "
- "{%4,%5};\n"
- : "=r"(D[0]), "=r"(D[1])
- : "r"(A[0]), "r"(B[0]),
- "r"(C[0]), "r"(C[1]));
- }
- )";
- } else {
- // saturate
- new_a_ref = "((unsigned *)(" + a_ref + " + " + a_bias + "))";
- new_b_ref = "((unsigned *)(" + b_ref + " + " + b_bias + "))";
- new_c_ref = "((int *)(" + c_ref + " + " + c_bias + "))";
- asm_code = R"(
- {
- __asm__ __volatile__(
- "mma.sync.aligned.m8n8k32.row.col.satfinite.s32.s4.s4.s32 "
- "{%0,%1}, {%2}, {%3}, "
- "{%4,%5};\n"
- : "=r"(D[0]), "=r"(D[1])
- : "r"(A[0]), "r"(B[0]),
- "r"(C[0]), "r"(C[1]));
- }
- )";
- }
- } else if ((A_dtype == "uint4") && (B_dtype == "int4")) {
- // A multiplicand is uint4, B multiplicand is int4
- // SM 75 Tensor Core instructions
- ICHECK(C_dtype == "int32");
- ICHECK((A_layout == "row") && (B_layout == "col"))
- << "SM75 Tensor Core instructions "
- << "with shape m8n8k32 expect A layout is row major and B layout is
col major.";
- // C accumulator is int32
- if (!saturate) {
- // no saturate
- new_a_ref = "((unsigned *)(" + a_ref + " + " + a_bias + "))";
- new_b_ref = "((unsigned *)(" + b_ref + " + " + b_bias + "))";
- new_c_ref = "((int *)(" + c_ref + " + " + c_bias + "))";
- asm_code = R"(
- {
- __asm__ __volatile__(
- "mma.sync.aligned.m8n8k32.row.col.s32.u4.s4.s32 "
- "{%0,%1}, {%2}, {%3}, "
- "{%4,%5};\n"
- : "=r"(D[0]), "=r"(D[1])
- : "r"(A[0]), "r"(B[0]),
- "r"(C[0]), "r"(C[1]));
- }
- )";
- } else {
- // saturate
- new_a_ref = "((unsigned *)(" + a_ref + " + " + a_bias + "))";
- new_b_ref = "((unsigned *)(" + b_ref + " + " + b_bias + "))";
- new_c_ref = "((int *)(" + c_ref + " + " + c_bias + "))";
- asm_code = R"(
- {
- __asm__ __volatile__(
- "mma.sync.aligned.m8n8k32.row.col.satfinite.s32.u4.s4.s32 "
- "{%0,%1}, {%2}, {%3}, "
- "{%4,%5};\n"
- : "=r"(D[0]), "=r"(D[1])
- : "r"(A[0]), "r"(B[0]),
- "r"(C[0]), "r"(C[1]));
- }
- )";
- }
- } else if ((A_dtype == "int4") && (B_dtype == "uint4")) {
- // A multiplicand is int4, B multiplicand is uint4
- // SM 75 Tensor Core instructions
- ICHECK(C_dtype == "int32");
- ICHECK((A_layout == "row") && (B_layout == "col"))
- << "SM75 Tensor Core instructions "
- << "with shape m8n8k32 expect A layout is row major and B layout is
col major.";
- // C accumulator is int32
- if (!saturate) {
- // no saturate
- new_a_ref = "((unsigned *)(" + a_ref + " + " + a_bias + "))";
- new_b_ref = "((unsigned *)(" + b_ref + " + " + b_bias + "))";
- new_c_ref = "((int *)(" + c_ref + " + " + c_bias + "))";
- asm_code = R"(
- {
- __asm__ __volatile__(
- "mma.sync.aligned.m8n8k32.row.col.s32.s4.u4.s32 "
- "{%0,%1}, {%2}, {%3}, "
- "{%4,%5};\n"
- : "=r"(D[0]), "=r"(D[1])
- : "r"(A[0]), "r"(B[0]),
- "r"(C[0]), "r"(C[1]));
- }
- )";
- } else {
- // saturate
- new_a_ref = "((unsigned *)(" + a_ref + " + " + a_bias + "))";
- new_b_ref = "((unsigned *)(" + b_ref + " + " + b_bias + "))";
- new_c_ref = "((int *)(" + c_ref + " + " + c_bias + "))";
- asm_code = R"(
- {
- __asm__ __volatile__(
- "mma.sync.aligned.m8n8k32.row.col.satfinite.s32.s4.u4.s32 "
- "{%0,%1}, {%2}, {%3}, "
- "{%4,%5};\n"
- : "=r"(D[0]), "=r"(D[1])
- : "r"(A[0]), "r"(B[0]),
- "r"(C[0]), "r"(C[1]));
- }
- )";
- }
- } else {
- // A/B multiplicand is uint4, SM 75 Tensor Core instructions
- ICHECK(C_dtype == "int32");
- ICHECK((A_layout == "row") && (B_layout == "col"))
- << "SM75 Tensor Core instructions "
- << "with shape m8n8k32 expect A layout is row major and B layout is
col major.";
- // C accumulator is int32
- if (!saturate) {
- // no saturate
- new_a_ref = "((unsigned *)(" + a_ref + " + " + a_bias + "))";
- new_b_ref = "((unsigned *)(" + b_ref + " + " + b_bias + "))";
- new_c_ref = "((int *)(" + c_ref + " + " + c_bias + "))";
- asm_code = R"(
- {
- __asm__ __volatile__(
- "mma.sync.aligned.m8n8k32.row.col.s32.u4.u4.s32 "
- "{%0,%1}, {%2}, {%3}, "
- "{%4,%5};\n"
- : "=r"(D[0]), "=r"(D[1])
- : "r"(A[0]), "r"(B[0]),
- "r"(C[0]), "r"(C[1]));
- }
- )";
- } else {
- // saturate
- new_a_ref = "((unsigned *)(" + a_ref + " + " + a_bias + "))";
- new_b_ref = "((unsigned *)(" + b_ref + " + " + b_bias + "))";
- new_c_ref = "((int *)(" + c_ref + " + " + c_bias + "))";
- asm_code = R"(
- {
- __asm__ __volatile__(
- "mma.sync.aligned.m8n8k32.row.col.satfinite.s32.u4.u4.s32 "
- "{%0,%1}, {%2}, {%3}, "
- "{%4,%5};\n"
- : "=r"(D[0]), "=r"(D[1])
- : "r"(A[0]), "r"(B[0]),
- "r"(C[0]), "r"(C[1]));
- }
- )";
+/*!
+ * \brief Check whether the given configuration is valid for MMA computation.
+ * \param m The M in mMnNkK of MMA instructions.
+ * \param n The N in mMnNkK of MMA instructions.
+ * \param k The K in mMnNkK of MMA instructions.
+ * \param layout_a The layout of multiplicand A (row/col).
+ * \param layout_b The layout of multiplicand B (row/col).
+ * \param dtype_a The data type of multiplicand A.
+ * \param dtype_b The data type of multiplicand B.
+ * \param dtype_c The data type of accumulator C.
+ * \param bit_op The bit operator for 1-bit MMA computation, can be
"xor"/"and" or ""(if it's not
+ * 1-bit MMA).
+ * \param sparse Whether it's Sparse MMA or not.
+ * \param saturate Whether saturate output or not.
+ */
+void CheckMMAConfigValidity(int m, int n, int k, LayoutType layout_a,
LayoutType layout_b,
+ DataType dtype_a, DataType dtype_b, DataType
dtype_c,
+ const std::string& bit_op, bool sparse, bool
saturate) {
+ CHECK(bit_op == "xor" || bit_op == "and" || bit_op == "")
+ << "Unrecognized 1-bit operation " << bit_op << " , can only be
xor/and.";
+ bool use_bit_op = !bit_op.empty();
+ if (use_bit_op) {
+ CHECK(dtype_a == DataType::kBit1) << "Bit operator is only compatible with
1-bit multiplicand.";
+ }
+ CheckMMADTypeCompatible(dtype_a, dtype_b, dtype_c);
+ if (saturate) {
+ CHECK(dtype_a == DataType::kInt4 || dtype_a == DataType::kUInt4 || dtype_a
== DataType::kInt8 ||
+ dtype_a == DataType::kUInt8)
+ << "Output saturation only applicable to multiplicand type
s4/u4/s8/u8.";
+ }
+
+ if (!(m == 8 && n == 8 && k == 4 && dtype_a == ptx::DataType::kFloat16)) {
+ // Only MMA on m8n8k4 for fp16 supports customized layouts.
+ CHECK(layout_a == LayoutType::kRowMajor && layout_b ==
LayoutType::kColumnMajor)
+ << "Invalid layout combination " << LayoutTypeToString(layout_a) << ","
+ << LayoutTypeToString(layout_b) << ".";
+ }
+
+ MMAConfig config(m, n, k, dtype_a, use_bit_op, sparse);
+ bool match = false;
+ for (const MMAConfig& valid_config : valid_mma_configs) {
+ if (config == valid_config) {
+ match = true;
+ break;
}
}
- asm_code = ReplaceMMAArgument(asm_code, "left_layout", A_layout);
- asm_code = ReplaceMMAArgument(asm_code, "right_layout", B_layout);
- asm_code = ReplaceMMAArgument(asm_code, "A", new_a_ref);
- asm_code = ReplaceMMAArgument(asm_code, "B", new_b_ref);
- asm_code = ReplaceMMAArgument(asm_code, "C", new_c_ref);
- asm_code = ReplaceMMAArgument(asm_code, "D", new_c_ref);
- return asm_code;
+ CHECK(match) << "Cannot find matched MMA configurations.";
}
-std::string PrintMMAm16n8k4Assembly(const std::string& A_layout, const
std::string& B_layout,
- const std::string& A_dtype, const
std::string& B_dtype,
- const std::string& C_dtype, const
std::string& a_ref,
- const std::string& a_bias, const
std::string& b_ref,
- const std::string& b_bias, const
std::string& c_ref,
- const std::string& c_bias, bool saturate) {
- std::string asm_code = "";
- std::string new_a_ref = "";
- std::string new_b_ref = "";
- std::string new_c_ref = "";
- ICHECK((A_dtype == "tf32") && (B_dtype == "tf32"));
- ICHECK(saturate == false) << "Saturate is not allowed for m16n8k4 mma.";
- // A/B multiplicand is tf32, SM 80 Tensor Core instructions
- ICHECK(C_dtype == "fp32");
- ICHECK((A_layout == "row") && (B_layout == "col"))
- << "SM80 Tensor Core instructions "
- << "with shape m16n8k4 expect A layout is row major and B layout is col
major.";
- // C accumulator is fp32
- new_a_ref = "((unsigned *)(" + a_ref + " + " + a_bias + "))";
- new_b_ref = "((unsigned *)(" + b_ref + " + " + b_bias + "))";
- new_c_ref = "((float *)(" + c_ref + " + " + c_bias + "))";
- asm_code = R"(
- {
- __asm__ __volatile__(
- "mma.sync.aligned.m16n8k4.row.col.f32.tf32.tf32.f32 "
- "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, "
- "{%10,%11,%12,%13};\n"
- : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3])
- : "r"(A[0]), "r"(A[1]), "f"(A[2]), "r"(A[3]),
- "r"(B[0]), "r"(B[1]),
- "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3]));
- }
- )";
- asm_code = ReplaceMMAArgument(asm_code, "left_layout", A_layout);
- asm_code = ReplaceMMAArgument(asm_code, "right_layout", B_layout);
- asm_code = ReplaceMMAArgument(asm_code, "A", new_a_ref);
- asm_code = ReplaceMMAArgument(asm_code, "B", new_b_ref);
- asm_code = ReplaceMMAArgument(asm_code, "C", new_c_ref);
- asm_code = ReplaceMMAArgument(asm_code, "D", new_c_ref);
- return asm_code;
-}
+/*!
+ * \brief Fragment attributes
+ */
+class FragAttrs {
+ public:
+ explicit FragAttrs(char reg_type, uint32_t size, std::string ptr_type)
+ : reg_type(reg_type), size(size), ptr_type(ptr_type) {}
+ /*! \brief PTX register type */
+ char reg_type;
+ /*! \brief Fragment size */
+ uint32_t size;
+ /*! \brief Fragment pointer type */
+ std::string ptr_type;
+};
-std::string PrintMMAm16n8k16Assembly(const std::string& A_layout, const
std::string& B_layout,
- const std::string& A_dtype, const
std::string& B_dtype,
- const std::string& C_dtype, const
std::string& a_ref,
- const std::string& a_bias, const
std::string& b_ref,
- const std::string& b_bias, const
std::string& c_ref,
- const std::string& c_bias, bool saturate)
{
- std::string asm_code = "";
- std::string new_a_ref = "";
- std::string new_b_ref = "";
- std::string new_c_ref = "";
- ICHECK(((A_dtype == "fp16") && (B_dtype == "fp16")) ||
- ((A_dtype == "bf16") && (B_dtype == "bf16")) ||
- ((A_dtype == "int8") && (B_dtype == "int8")) ||
- ((A_dtype == "uint8") && (B_dtype == "int8")) ||
- ((A_dtype == "int8") && (B_dtype == "uint8")) ||
- ((A_dtype == "uint8") && (B_dtype == "uint8")));
- if ((A_dtype == "fp16") && (B_dtype == "fp16")) {
- ICHECK(saturate == false) << "Saturate is not allowed for m16n8k8 fp16
mma.";
- // A/B multiplicand is fp16, SM 80 Tensor Core instructions
- ICHECK((C_dtype == "fp16") || (C_dtype == "fp32"));
- ICHECK((A_layout == "row") && (B_layout == "col"))
- << "SM80 Tensor Core instructions "
- << "with shape m16n8k16 expect A layout is row major and B layout is
col major.";
- if (C_dtype == "fp16") {
- // C accumulator is fp16
- new_a_ref = "((unsigned *)(" + a_ref + " + " + a_bias + "))";
- new_b_ref = "((unsigned *)(" + b_ref + " + " + b_bias + "))";
- new_c_ref = "((unsigned *)(" + c_ref + " + " + c_bias + "))";
- asm_code = R"(
- {
- __asm__ __volatile__(
- "mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 "
- "{%0,%1}, {%2,%3,%4,%5}, {%6,%7}, "
- "{%8,%9};\n"
- : "=r"(D[0]), "=r"(D[1])
- : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]),
- "r"(B[0]), "r"(B[1]),
- "r"(C[0]), "r"(C[1]));
- }
- )";
- } else {
- // C accumulator is fp32
- new_a_ref = "((unsigned *)(" + a_ref + " + " + a_bias + "))";
- new_b_ref = "((unsigned *)(" + b_ref + " + " + b_bias + "))";
- new_c_ref = "((float *)(" + c_ref + " + " + c_bias + "))";
- asm_code = R"(
- {
- __asm__ __volatile__(
- "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
- "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, "
- "{%10,%11,%12,%13};\n"
- : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3])
- : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]),
- "r"(B[0]), "r"(B[1]),
- "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3]));
- }
- )";
- }
- } else if ((A_dtype == "bf16") && (B_dtype == "bf16")) {
- // A/B multiplicand is bf16, SM 80 Tensor Core instructions
- ICHECK(saturate == false) << "Saturate is not allowed for m16n8k8 bf16
mma.";
- ICHECK(C_dtype == "fp32");
- ICHECK((A_layout == "row") && (B_layout == "col"))
- << "SM80 Tensor Core instructions "
- << "with shape m16n8k16 expect A layout is row major and B layout is
col major.";
- // C accumulator is fp32
- new_a_ref = "((unsigned *)(" + a_ref + " + " + a_bias + "))";
- new_b_ref = "((unsigned *)(" + b_ref + " + " + b_bias + "))";
- new_c_ref = "((float *)(" + c_ref + " + " + c_bias + "))";
- asm_code = R"(
- {
- __asm__ __volatile__(
- "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 "
- "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, "
- "{%10,%11,%12,%13};\n"
- : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3])
- : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]),
- "r"(B[0]), "r"(B[1]),
- "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3]));
- }
- )";
- } else if ((A_dtype == "int8") && (B_dtype == "int8")) {
- // A/B multiplicand is int8, SM 80 Tensor Core instructions
- ICHECK(C_dtype == "int32");
- ICHECK((A_layout == "row") && (B_layout == "col"))
- << "SM80 Tensor Core instructions "
- << "with shape m16n8k16 expect A layout is row major and B layout is
col major.";
- // C accumulator is int32
- if (!saturate) {
- // no saturate
- new_a_ref = "((unsigned *)(" + a_ref + " + " + a_bias + "))";
- new_b_ref = "((unsigned *)(" + b_ref + " + " + b_bias + "))";
- new_c_ref = "((int *)(" + c_ref + " + " + c_bias + "))";
- asm_code = R"(
- {
- __asm__ __volatile__(
- "mma.sync.aligned.m16n8k16.row.col.s32.s8.s8.s32 "
- "{%0,%1,%2,%3}, {%4,%5}, {%6}, "
- "{%7,%8,%9,%10};\n"
- : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3])
- : "r"(A[0]), "r"(A[1]), "r"(B[0]),
- "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]));
- }
- )";
- } else {
- // saturate
- new_a_ref = "((unsigned *)(" + a_ref + " + " + a_bias + "))";
- new_b_ref = "((unsigned *)(" + b_ref + " + " + b_bias + "))";
- new_c_ref = "((int *)(" + c_ref + " + " + c_bias + "))";
- asm_code = R"(
- {
- __asm__ __volatile__(
- "mma.sync.aligned.m16n8k16.row.col.s32.s8.s8.s32.satfinite "
- "{%0,%1,%2,%3}, {%4,%5}, {%6}, "
- "{%7,%8,%9,%10};\n"
- : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3])
- : "r"(A[0]), "r"(A[1]), "r"(B[0]),
- "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]));
- }
- )";
- }
- } else if ((A_dtype == "uint8") && (B_dtype == "int8")) {
- // A multiplicand is uint8, B multiplicand is int8
- // SM 80 Tensor Core instructions
- ICHECK(C_dtype == "int32");
- ICHECK((A_layout == "row") && (B_layout == "col"))
- << "SM80 Tensor Core instructions "
- << "with shape m16n8k16 expect A layout is row major and B layout is
col major.";
- // C accumulator is int32
- if (!saturate) {
- // no saturate
- new_a_ref = "((unsigned *)(" + a_ref + " + " + a_bias + "))";
- new_b_ref = "((unsigned *)(" + b_ref + " + " + b_bias + "))";
- new_c_ref = "((int *)(" + c_ref + " + " + c_bias + "))";
- asm_code = R"(
- {
- __asm__ __volatile__(
- "mma.sync.aligned.m16n8k16.row.col.s32.u8.s8.s32 "
- "{%0,%1,%2,%3}, {%4,%5}, {%6}, "
- "{%7,%8,%9,%10};\n"
- : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3])
- : "r"(A[0]), "r"(A[1]), "r"(B[0]),
- "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]));
- }
- )";
- } else {
- // saturate
- new_a_ref = "((unsigned *)(" + a_ref + " + " + a_bias + "))";
- new_b_ref = "((unsigned *)(" + b_ref + " + " + b_bias + "))";
- new_c_ref = "((int *)(" + c_ref + " + " + c_bias + "))";
- asm_code = R"(
- {
- __asm__ __volatile__(
- "mma.sync.aligned.m16n8k16.row.col.s32.u8.s8.s32.satfinite "
- "{%0,%1,%2,%3}, {%4,%5}, {%6}, "
- "{%7,%8,%9,%10};\n"
- : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3])
- : "r"(A[0]), "r"(A[1]), "r"(B[0]),
- "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]));
- }
- )";
- }
- } else if ((A_dtype == "int8") && (B_dtype == "uint8")) {
- // A multiplicand is int8, B multiplicand is uint8
- // SM 80 Tensor Core instructions
- ICHECK(C_dtype == "int32");
- ICHECK((A_layout == "row") && (B_layout == "col"))
- << "SM80 Tensor Core instructions "
- << "with shape m16n8k16 expect A layout is row major and B layout is
col major.";
- // C accumulator is int32
- if (!saturate) {
- // no saturate
- new_a_ref = "((unsigned *)(" + a_ref + " + " + a_bias + "))";
- new_b_ref = "((unsigned *)(" + b_ref + " + " + b_bias + "))";
- new_c_ref = "((int *)(" + c_ref + " + " + c_bias + "))";
- asm_code = R"(
- {
- __asm__ __volatile__(
- "mma.sync.aligned.m16n8k16.row.col.s32.s8.u8.s32 "
- "{%0,%1,%2,%3}, {%4,%5}, {%6}, "
- "{%7,%8,%9,%10};\n"
- : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3])
- : "r"(A[0]), "r"(A[1]), "r"(B[0]),
- "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]));
- }
- )";
- } else {
- // saturate
- new_a_ref = "((unsigned *)(" + a_ref + " + " + a_bias + "))";
- new_b_ref = "((unsigned *)(" + b_ref + " + " + b_bias + "))";
- new_c_ref = "((int *)(" + c_ref + " + " + c_bias + "))";
- asm_code = R"(
- {
- __asm__ __volatile__(
- "mma.sync.aligned.m16n8k16.row.col.s32.s8.u8.s32.satfinite "
- "{%0,%1,%2,%3}, {%4,%5}, {%6}, "
- "{%7,%8,%9,%10};\n"
- : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3])
- : "r"(A[0]), "r"(A[1]), "r"(B[0]),
- "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]));
- }
- )";
- }
- } else {
- // A/B multiplicand is uint8, SM 80 Tensor Core instructions
- ICHECK(C_dtype == "int32");
- ICHECK((A_layout == "row") && (B_layout == "col"))
- << "SM80 Tensor Core instructions "
- << "with shape m16n8k16 expect A layout is row major and B layout is
col major.";
- // C accumulator is int32
- if (!saturate) {
- // no saturate
- new_a_ref = "((unsigned *)(" + a_ref + " + " + a_bias + "))";
- new_b_ref = "((unsigned *)(" + b_ref + " + " + b_bias + "))";
- new_c_ref = "((int *)(" + c_ref + " + " + c_bias + "))";
- asm_code = R"(
- {
- __asm__ __volatile__(
- "mma.sync.aligned.m16n8k16.row.col.s32.u8.u8.s32 "
- "{%0,%1,%2,%3}, {%4,%5}, {%6}, "
- "{%7,%8,%9,%10};\n"
- : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3])
- : "r"(A[0]), "r"(A[1]), "r"(B[0]),
- "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]));
- }
- )";
- } else {
- // saturate
- new_a_ref = "((unsigned *)(" + a_ref + " + " + a_bias + "))";
- new_b_ref = "((unsigned *)(" + b_ref + " + " + b_bias + "))";
- new_c_ref = "((int *)(" + c_ref + " + " + c_bias + "))";
- asm_code = R"(
- {
- __asm__ __volatile__(
- "mma.sync.aligned.m16n8k16.row.col.s32.u8.u8.s32.satfinite "
- "{%0,%1,%2,%3}, {%4,%5}, {%6}, "
- "{%7,%8,%9,%10};\n"
- : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3])
- : "r"(A[0]), "r"(A[1]), "r"(B[0]),
- "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]));
- }
- )";
- }
+/*!
+ * \brief Fragment attributes of given data type.
+ */
+inline FragAttrs GetFragAttrs(DataType dtype) {
+ switch (dtype) {
+ case DataType::kBit1:
+ case DataType::kInt4:
+ case DataType::kUInt4:
+ case DataType::kInt8:
+ case DataType::kUInt8:
+ case DataType::kFloat16: // .f16x2 register
+ case DataType::kBFloat16:
+ case DataType::kTensorFloat32:
+ return FragAttrs('r', 32, "(unsigned *)");
+ case DataType::kInt32:
+ return FragAttrs('r', 32, "(int *)");
+ case DataType::kFloat32:
+ return FragAttrs('f', 32, "(float *)");
+ case DataType::kFloat64:
+ return FragAttrs('d', 64, "(double *)");
+ default:
+ ICHECK(false) << DTypeToString(dtype) << " is not matrix data type in
MMA.";
+ return FragAttrs('\0', 0, "");
}
- asm_code = ReplaceMMAArgument(asm_code, "left_layout", A_layout);
- asm_code = ReplaceMMAArgument(asm_code, "right_layout", B_layout);
- asm_code = ReplaceMMAArgument(asm_code, "A", new_a_ref);
- asm_code = ReplaceMMAArgument(asm_code, "B", new_b_ref);
- asm_code = ReplaceMMAArgument(asm_code, "C", new_c_ref);
- asm_code = ReplaceMMAArgument(asm_code, "D", new_c_ref);
- return asm_code;
}
-std::string PrintMMAm16n8k32Assembly(const std::string& A_layout, const
std::string& B_layout,
- const std::string& A_dtype, const
std::string& B_dtype,
- const std::string& C_dtype, const
std::string& a_ref,
- const std::string& a_bias, const
std::string& b_ref,
- const std::string& b_bias, const
std::string& c_ref,
- const std::string& c_bias, bool saturate)
{
- std::string asm_code = "";
- std::string new_a_ref = "";
- std::string new_b_ref = "";
- std::string new_c_ref = "";
- ICHECK(((A_dtype == "int8") && (B_dtype == "int8")) ||
- ((A_dtype == "uint8") && (B_dtype == "int8")) ||
- ((A_dtype == "int8") && (B_dtype == "uint8")) ||
- ((A_dtype == "uint8") && (B_dtype == "uint8")));
- if ((A_dtype == "int8") && (B_dtype == "int8")) {
- // A/B multiplicand is int8, SM 80 Tensor Core instructions
- ICHECK(C_dtype == "int32");
- ICHECK((A_layout == "row") && (B_layout == "col"))
- << "SM80 Tensor Core instructions "
- << "with shape m16n8k32 expect A layout is row major and B layout is
col major.";
- // C accumulator is int32
- if (!saturate) {
- // no saturate
- new_a_ref = "((unsigned *)(" + a_ref + " + " + a_bias + "))";
- new_b_ref = "((unsigned *)(" + b_ref + " + " + b_bias + "))";
- new_c_ref = "((int *)(" + c_ref + " + " + c_bias + "))";
- asm_code = R"(
- {
- __asm__ __volatile__(
- "mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32 "
- "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, "
- "{%10,%11,%12,%13};\n"
- : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3])
- : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]),
- "r"(B[0]), "r"(B[1]),
- "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]));
- }
- )";
- } else {
- // saturate
- new_a_ref = "((unsigned *)(" + a_ref + " + " + a_bias + "))";
- new_b_ref = "((unsigned *)(" + b_ref + " + " + b_bias + "))";
- new_c_ref = "((int *)(" + c_ref + " + " + c_bias + "))";
- asm_code = R"(
- {
- __asm__ __volatile__(
- "mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32.satfinite "
- "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, "
- "{%10,%11,%12,%13};\n"
- : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3])
- : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]),
- "r"(B[0]), "r"(B[1]),
- "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]));
- }
- )";
- }
- } else if ((A_dtype == "uint8") && (B_dtype == "int8")) {
- // A multiplicand is uint8, B multiplicand is int8
- // SM 80 Tensor Core instructions
- ICHECK(C_dtype == "int32");
- ICHECK((A_layout == "row") && (B_layout == "col"))
- << "SM80 Tensor Core instructions "
- << "with shape m16n8k32 expect A layout is row major and B layout is
col major.";
- // C accumulator is int32
- if (!saturate) {
- // no saturate
- new_a_ref = "((unsigned *)(" + a_ref + " + " + a_bias + "))";
- new_b_ref = "((unsigned *)(" + b_ref + " + " + b_bias + "))";
- new_c_ref = "((int *)(" + c_ref + " + " + c_bias + "))";
- asm_code = R"(
- {
- __asm__ __volatile__(
- "mma.sync.aligned.m16n8k32.row.col.s32.u8.s8.s32 "
- "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, "
- "{%10,%11,%12,%13};\n"
- : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3])
- : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]),
- "r"(B[0]), "r"(B[1]),
- "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]));
- }
- )";
- } else {
- // saturate
- new_a_ref = "((unsigned *)(" + a_ref + " + " + a_bias + "))";
- new_b_ref = "((unsigned *)(" + b_ref + " + " + b_bias + "))";
- new_c_ref = "((int *)(" + c_ref + " + " + c_bias + "))";
- asm_code = R"(
- {
- __asm__ __volatile__(
- "mma.sync.aligned.m16n8k32.row.col.s32.u8.s8.s32.satfinite "
- "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, "
- "{%10,%11,%12,%13};\n"
- : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3])
- : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]),
- "r"(B[0]), "r"(B[1]),
- "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]));
- }
- )";
- }
- } else if ((A_dtype == "int8") && (B_dtype == "uint8")) {
- // A multiplicand is int8, B multiplicand is uint8
- // SM 80 Tensor Core instructions
- ICHECK(C_dtype == "int32");
- ICHECK((A_layout == "row") && (B_layout == "col"))
- << "SM80 Tensor Core instructions "
- << "with shape m16n8k32 expect A layout is row major and B layout is
col major.";
- // C accumulator is int32
- if (!saturate) {
- // no saturate
- new_a_ref = "((unsigned *)(" + a_ref + " + " + a_bias + "))";
- new_b_ref = "((unsigned *)(" + b_ref + " + " + b_bias + "))";
- new_c_ref = "((int *)(" + c_ref + " + " + c_bias + "))";
- asm_code = R"(
- {
- __asm__ __volatile__(
- "mma.sync.aligned.m16n8k32.row.col.s32.s8.u8.s32 "
- "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, "
- "{%10,%11,%12,%13};\n"
- : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3])
- : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]),
- "r"(B[0]), "r"(B[1]),
- "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]));
- }
- )";
- } else {
- // saturate
- new_a_ref = "((unsigned *)(" + a_ref + " + " + a_bias + "))";
- new_b_ref = "((unsigned *)(" + b_ref + " + " + b_bias + "))";
- new_c_ref = "((int *)(" + c_ref + " + " + c_bias + "))";
- asm_code = R"(
- {
- __asm__ __volatile__(
- "mma.sync.aligned.m16n8k32.row.col.s32.s8.u8.s32.satfinite "
- "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, "
- "{%10,%11,%12,%13};\n"
- : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3])
- : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]),
- "r"(B[0]), "r"(B[1]),
- "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]));
- }
- )";
+}; // namespace ptx
+
+/*!
+ * \brief Replace patterns with replacement strings.
+ * \note should use std::format instead when codebase is ported to C++20.
+ */
+class Replacer {
+ public:
+ void register_rule(const std::string& pattern, const std::string&
replacement) {
+ _rules.emplace_back(pattern, replacement);
+ }
+ std::string rewrite(std::string str) {
+ for (auto&& rule : _rules) {
+ std::string pattern, replacement;
+ std::tie(pattern, replacement) = rule;
+ size_t len = pattern.size();
+ size_t new_len = replacement.size();
+ size_t pos = str.find(pattern);
+ while (pos != std::string::npos) {
+ str = str.replace(pos, len, replacement);
+ pos = str.find(pattern, pos + new_len);
+ }
}
+ return str;
+ }
+ void empty_rules() { _rules.clear(); }
+
+ private:
+ std::vector<std::pair<std::string, std::string>> _rules;
+};
+
+/*!
+ * \brief Get the number of MMA computations for given shape and datatype.
+ */
+inline uint32_t GetNumMMAComputations(int m, int n, int k, ptx::DataType
dtype) {
+ if (m == 8 && n == 8 && k == 4 && dtype == ptx::DataType::kFloat16) {
+ // MMA for m8n8k4 on fp16 would launch 4 MMA computations instead of one.
+ return 4;
} else {
- // A/B multiplicand is uint8, SM 80 Tensor Core instructions
- ICHECK(C_dtype == "int32");
- ICHECK((A_layout == "row") && (B_layout == "col"))
- << "SM80 Tensor Core instructions "
- << "with shape m16n8k32 expect A layout is row major and B layout is
col major.";
- // C accumulator is int32
- if (!saturate) {
- // no saturate
- new_a_ref = "((unsigned *)(" + a_ref + " + " + a_bias + "))";
- new_b_ref = "((unsigned *)(" + b_ref + " + " + b_bias + "))";
- new_c_ref = "((int *)(" + c_ref + " + " + c_bias + "))";
- asm_code = R"(
- {
- __asm__ __volatile__(
- "mma.sync.aligned.m16n8k32.row.col.s32.u8.u8.s32 "
- "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, "
- "{%10,%11,%12,%13};\n"
- : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3])
- : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]),
- "r"(B[0]), "r"(B[1]),
- "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]));
- }
- )";
- } else {
- // saturate
- new_a_ref = "((unsigned *)(" + a_ref + " + " + a_bias + "))";
- new_b_ref = "((unsigned *)(" + b_ref + " + " + b_bias + "))";
- new_c_ref = "((int *)(" + c_ref + " + " + c_bias + "))";
- asm_code = R"(
- {
- __asm__ __volatile__(
- "mma.sync.aligned.m16n8k32.row.col.s32.u8.u8.s32.satfinite "
- "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, "
- "{%10,%11,%12,%13};\n"
- : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3])
- : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]),
- "r"(B[0]), "r"(B[1]),
- "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]));
- }
- )";
- }
+ return 1;
}
- asm_code = ReplaceMMAArgument(asm_code, "left_layout", A_layout);
- asm_code = ReplaceMMAArgument(asm_code, "right_layout", B_layout);
- asm_code = ReplaceMMAArgument(asm_code, "A", new_a_ref);
- asm_code = ReplaceMMAArgument(asm_code, "B", new_b_ref);
- asm_code = ReplaceMMAArgument(asm_code, "C", new_c_ref);
- asm_code = ReplaceMMAArgument(asm_code, "D", new_c_ref);
- return asm_code;
}
-std::string PrintMMAm16n8k64Assembly(const std::string& A_layout, const
std::string& B_layout,
- const std::string& A_dtype, const
std::string& B_dtype,
- const std::string& C_dtype, const
std::string& a_ref,
- const std::string& a_bias, const
std::string& b_ref,
- const std::string& b_bias, const
std::string& c_ref,
- const std::string& c_bias, bool saturate)
{
- std::string asm_code = "";
- std::string new_a_ref = "";
- std::string new_b_ref = "";
- std::string new_c_ref = "";
- ICHECK(((A_dtype == "int4") && (B_dtype == "int4")) ||
- ((A_dtype == "uint4") && (B_dtype == "int4")) ||
- ((A_dtype == "int4") && (B_dtype == "uint4")) ||
- ((A_dtype == "uint4") && (B_dtype == "uint4")));
- if ((A_dtype == "int4") && (B_dtype == "int4")) {
- // A/B multiplicand is int4, SM 80 Tensor Core instructions
- ICHECK(C_dtype == "int32");
- ICHECK((A_layout == "row") && (B_layout == "col"))
- << "SM80 Tensor Core instructions "
- << "with shape m16n8k64 expect A layout is row major and B layout is
col major.";
- // C accumulator is int32
- if (!saturate) {
- // no saturate
- new_a_ref = "((unsigned *)(" + a_ref + " + " + a_bias + "))";
- new_b_ref = "((unsigned *)(" + b_ref + " + " + b_bias + "))";
- new_c_ref = "((int *)(" + c_ref + " + " + c_bias + "))";
- asm_code = R"(
- {
- __asm__ __volatile__(
- "mma.sync.aligned.m16n8k64.row.col.s32.s4.s4.s32 "
- "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, "
- "{%10,%11,%12,%13};\n"
- : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3])
- : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]),
- "r"(B[0]), "r"(B[1]),
- "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]));
- }
- )";
- } else {
- // saturate
- new_a_ref = "((unsigned *)(" + a_ref + " + " + a_bias + "))";
- new_b_ref = "((unsigned *)(" + b_ref + " + " + b_bias + "))";
- new_c_ref = "((int *)(" + c_ref + " + " + c_bias + "))";
- asm_code = R"(
- {
- __asm__ __volatile__(
- "mma.sync.aligned.m16n8k64.row.col.s32.s4.s4.s32.satfinite "
- "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, "
- "{%10,%11,%12,%13};\n"
- : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3])
- : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]),
- "r"(B[0]), "r"(B[1]),
- "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]));
- }
- )";
- }
- } else if ((A_dtype == "uint4") && (B_dtype == "int4")) {
- // A multiplicand is uint4, B multiplicand is int4
- // SM 80 Tensor Core instructions
- ICHECK(C_dtype == "int32");
- ICHECK((A_layout == "row") && (B_layout == "col"))
- << "SM80 Tensor Core instructions "
- << "with shape m16n8k64 expect A layout is row major and B layout is
col major.";
- // C accumulator is int32
- if (!saturate) {
- // no saturate
- new_a_ref = "((unsigned *)(" + a_ref + " + " + a_bias + "))";
- new_b_ref = "((unsigned *)(" + b_ref + " + " + b_bias + "))";
- new_c_ref = "((int *)(" + c_ref + " + " + c_bias + "))";
- asm_code = R"(
- {
- __asm__ __volatile__(
- "mma.sync.aligned.m16n8k64.row.col.s32.u4.s4.s32 "
- "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, "
- "{%10,%11,%12,%13};\n"
- : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3])
- : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]),
- "r"(B[0]), "r"(B[1]),
- "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]));
- }
- )";
- } else {
- // saturate
- new_a_ref = "((unsigned *)(" + a_ref + " + " + a_bias + "))";
- new_b_ref = "((unsigned *)(" + b_ref + " + " + b_bias + "))";
- new_c_ref = "((int *)(" + c_ref + " + " + c_bias + "))";
- asm_code = R"(
- {
- __asm__ __volatile__(
- "mma.sync.aligned.m16n8k64.row.col.s32.u4.s4.s32.satfinite "
- "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, "
- "{%10,%11,%12,%13};\n"
- : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3])
- : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]),
- "r"(B[0]), "r"(B[1]),
- "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]));
- }
- )";
- }
- } else if ((A_dtype == "int4") && (B_dtype == "uint4")) {
- // A multiplicand is int4, B multiplicand is uint4
- // SM 80 Tensor Core instructions
- ICHECK(C_dtype == "int32");
- ICHECK((A_layout == "row") && (B_layout == "col"))
- << "SM80 Tensor Core instructions "
- << "with shape m16n8k64 expect A layout is row major and B layout is
col major.";
- // C accumulator is int32
- if (!saturate) {
- // no saturate
- new_a_ref = "((unsigned *)(" + a_ref + " + " + a_bias + "))";
- new_b_ref = "((unsigned *)(" + b_ref + " + " + b_bias + "))";
- new_c_ref = "((int *)(" + c_ref + " + " + c_bias + "))";
- asm_code = R"(
- {
- __asm__ __volatile__(
- "mma.sync.aligned.m16n8k64.row.col.s32.s4.u4.s32 "
- "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, "
- "{%10,%11,%12,%13};\n"
- : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3])
- : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]),
- "r"(B[0]), "r"(B[1]),
- "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]));
- }
- )";
- } else {
- // saturate
- new_a_ref = "((unsigned *)(" + a_ref + " + " + a_bias + "))";
- new_b_ref = "((unsigned *)(" + b_ref + " + " + b_bias + "))";
- new_c_ref = "((int *)(" + c_ref + " + " + c_bias + "))";
- asm_code = R"(
- {
- __asm__ __volatile__(
- "mma.sync.aligned.m16n8k64.row.col.s32.s4.u4.s32.satfinite "
- "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, "
- "{%10,%11,%12,%13};\n"
- : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3])
- : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]),
- "r"(B[0]), "r"(B[1]),
- "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]));
- }
- )";
- }
- } else {
- // A/B multiplicand is uint4, SM 75 Tensor Core instructions
- ICHECK(C_dtype == "int32");
- ICHECK((A_layout == "row") && (B_layout == "col"))
- << "SM80 Tensor Core instructions "
- << "with shape m16n8k64 expect A layout is row major and B layout is
col major.";
- // C accumulator is int32
- if (!saturate) {
- // no saturate
- new_a_ref = "((unsigned *)(" + a_ref + " + " + a_bias + "))";
- new_b_ref = "((unsigned *)(" + b_ref + " + " + b_bias + "))";
- new_c_ref = "((int *)(" + c_ref + " + " + c_bias + "))";
- asm_code = R"(
- {
- __asm__ __volatile__(
- "mma.sync.aligned.m16n8k64.row.col.s32.u4.u4.s32 "
- "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, "
- "{%10,%11,%12,%13};\n"
- : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3])
- : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]),
- "r"(B[0]), "r"(B[1]),
- "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]));
- }
- )";
- } else {
- // saturate
- new_a_ref = "((unsigned *)(" + a_ref + " + " + a_bias + "))";
- new_b_ref = "((unsigned *)(" + b_ref + " + " + b_bias + "))";
- new_c_ref = "((int *)(" + c_ref + " + " + c_bias + "))";
- asm_code = R"(
- {
- __asm__ __volatile__(
- "mma.sync.aligned.m16n8k64.row.col.s32.u4.u4.s32.satfinite "
- "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, "
- "{%10,%11,%12,%13};\n"
- : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3])
- : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]),
- "r"(B[0]), "r"(B[1]),
- "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]));
- }
- )";
+/*!
+ * \brief Return template string, input operands string and output operands
string.
+ * \param m The M in mMnNkK of MMA instructions.
+ * \param n The N in mMnNkK of MMA instructions.
+ * \param k The K in mMnNkK of MMA instructions.
+ * \param dtype_a The data type of multiplicand a.
+ * \param dtype_b The data type of multiplicand b.
+ * \param dtype_c The data type of accumulator c.
+ * \param sparse Whether it's Sparse MMA or not.
+ */
+inline std::tuple<std::string, std::string, std::string> GetMMAOperands(int m,
int n, int k,
+
ptx::DataType dtype_a,
+
ptx::DataType dtype_b,
+
ptx::DataType dtype_c,
+ bool
sparse) {
+ std::stringstream templates, inputs, outputs;
+ const ptx::FragAttrs frag_attr_a = ptx::GetFragAttrs(dtype_a),
+ frag_attr_b = ptx::GetFragAttrs(dtype_b),
+ frag_attr_c = ptx::GetFragAttrs(dtype_c);
+ constexpr uint32_t warp_size = 32;
+ const uint32_t threads = warp_size / GetNumMMAComputations(m, n, k, dtype_a);
+ const int num_operands_a =
+ (m * k) * ptx::DTypeBits(dtype_a) / frag_attr_a.size / threads
/ (sparse ? 2 : 1),
+ num_operands_b = (k * n) * ptx::DTypeBits(dtype_b) /
frag_attr_b.size / threads,
+ num_operands_c = (m * n) * ptx::DTypeBits(dtype_c) /
frag_attr_c.size / threads;
+
+ // generate templates;
+ int arg_counter = 0;
+ templates << "{"
+ << "%" << arg_counter++;
+ for (int i = 1; i < num_operands_c; ++i) {
+ templates << ", %" << arg_counter++;
+ }
+ templates << "}, {"
+ << "%" << arg_counter++;
+ for (int i = 1; i < num_operands_a; ++i) {
+ templates << ", %" << arg_counter++;
+ }
+ templates << "}, {"
+ << "%" << arg_counter++;
+ for (int i = 1; i < num_operands_b; ++i) {
+ templates << ", %" << arg_counter++;
+ }
+ templates << "}, {"
+ << "%" << arg_counter++;
+ for (int i = 1; i < num_operands_c; ++i) {
+ templates << ", %" << arg_counter++;
+ }
+ templates << "}";
+ // templates of metadata and sparse selector for sparse mma.
+ if (sparse) {
+ templates << ", %" << (arg_counter++) << ", F";
+ }
+
+ // generate inputs
+ for (int i = 0; i < num_operands_a; ++i) {
+ if (i != 0) {
+ inputs << ", ";
}
+ inputs << "\"" << frag_attr_a.reg_type << "\"((" << frag_attr_a.ptr_type
<< "(A))[" << i
+ << "])";
+ }
+ for (int i = 0; i < num_operands_b; ++i) {
+ inputs << ", \"" << frag_attr_b.reg_type << "\"((" << frag_attr_b.ptr_type
<< "(B))[" << i
+ << "])";
+ }
+ for (int i = 0; i < num_operands_c; ++i) {
+ inputs << ", \"" << frag_attr_c.reg_type << "\"((" << frag_attr_c.ptr_type
<< "(C))[" << i
+ << "])";
+ }
+ // input of metadata for sparse mma.
+ if (sparse) {
+ inputs << ", \"r\"(((unsigned *)(E))[0])";
}
- asm_code = ReplaceMMAArgument(asm_code, "left_layout", A_layout);
- asm_code = ReplaceMMAArgument(asm_code, "right_layout", B_layout);
- asm_code = ReplaceMMAArgument(asm_code, "A", new_a_ref);
- asm_code = ReplaceMMAArgument(asm_code, "B", new_b_ref);
- asm_code = ReplaceMMAArgument(asm_code, "C", new_c_ref);
- asm_code = ReplaceMMAArgument(asm_code, "D", new_c_ref);
- return asm_code;
-}
-std::string PrintMMAm16n8k256Assembly(const std::string& A_layout, const
std::string& B_layout,
- const std::string& A_dtype, const
std::string& B_dtype,
- const std::string& C_dtype, const
std::string& a_ref,
- const std::string& a_bias, const
std::string& b_ref,
- const std::string& b_bias, const
std::string& c_ref,
- const std::string& c_bias, bool
saturate) {
- std::string asm_code = "";
- std::string new_a_ref = "";
- std::string new_b_ref = "";
- std::string new_c_ref = "";
- ICHECK(((A_dtype == "uint1") && (B_dtype == "uint1")) ||
- ((A_dtype == "int1") && (B_dtype == "int1")));
- if ((A_dtype == "uint1") && (B_dtype == "uint1")) {
- // A/B multiplicand is uint1, SM 80 Tensor Core instructions
- ICHECK(C_dtype == "int32");
- ICHECK((A_layout == "row") && (B_layout == "col"))
- << "SM80 Tensor Core instructions "
- << "with shape m16n8k256 expect A layout is row major and B layout is
col major.";
- // C accumulator is int32
- new_a_ref = "((unsigned *)(" + a_ref + " + " + a_bias + "))";
- new_b_ref = "((unsigned *)(" + b_ref + " + " + b_bias + "))";
- new_c_ref = "((int *)(" + c_ref + " + " + c_bias + "))";
- asm_code = R"(
- {
- __asm__ __volatile__(
- "mma.sync.aligned.m16n8k256.row.col.s32.b1.b1.s32.and.popc "
- "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, "
- "{%10,%11,%12,%13};\n"
- : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3])
- : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]),
- "r"(B[0]), "r"(B[1]),
- "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]));
- }
- )";
- } else {
- // A/B multiplicand is int1, SM 80 Tensor Core instructions
- ICHECK(C_dtype == "int32");
- ICHECK((A_layout == "row") && (B_layout == "col"))
- << "SM80 Tensor Core instructions "
- << "with shape m16n8k256 expect A layout is row major and B layout is
col major.";
- // C accumulator is int32
- new_a_ref = "((unsigned *)(" + a_ref + " + " + a_bias + "))";
- new_b_ref = "((unsigned *)(" + b_ref + " + " + b_bias + "))";
- new_c_ref = "((int *)(" + c_ref + " + " + c_bias + "))";
- asm_code = R"(
- {
- __asm__ __volatile__(
- "mma.sync.aligned.m16n8k256.row.col.s32.b1.b1.s32.xor.popc "
- "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, "
- "{%10,%11,%12,%13};\n"
- : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3])
- : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]),
- "r"(B[0]), "r"(B[1]),
- "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]));
- }
- )";
+ // generate outputs
+ for (int i = 0; i < num_operands_c; ++i) {
+ if (i != 0) {
+ outputs << ",";
+ }
+ outputs << " \"=" << frag_attr_c.reg_type << "\"((" <<
frag_attr_c.ptr_type << "(D))[" << i
+ << "])";
}
- asm_code = ReplaceMMAArgument(asm_code, "left_layout", A_layout);
- asm_code = ReplaceMMAArgument(asm_code, "right_layout", B_layout);
- asm_code = ReplaceMMAArgument(asm_code, "A", new_a_ref);
- asm_code = ReplaceMMAArgument(asm_code, "B", new_b_ref);
- asm_code = ReplaceMMAArgument(asm_code, "C", new_c_ref);
- asm_code = ReplaceMMAArgument(asm_code, "D", new_c_ref);
- return asm_code;
+ return std::make_tuple(templates.str(), inputs.str(), outputs.str());
}
std::string PrintMMAAssembly(const std::string& shape, const std::string&
A_layout,
const std::string& B_layout, const std::string&
A_dtype,
const std::string& B_dtype, const std::string&
C_dtype,
- const std::string& a_ref, const std::string&
a_bias,
- const std::string& b_ref, const std::string&
b_bias,
- const std::string& c_ref, const std::string&
c_bias, bool saturate) {
- ICHECK((shape == "m8n8k4") || (shape == "m16n8k8") || (shape == "m8n8k16") ||
- (shape == "m8n8k32") || (shape == "m16n8k4") || (shape == "m16n8k16")
||
- (shape == "m16n8k32") || (shape == "m16n8k64") || (shape ==
"m16n8k256"));
- ICHECK((A_layout == "row") || (A_layout == "col")) << "Unknown A layout: "
<< A_layout;
- ICHECK((B_layout == "row") || (B_layout == "col")) << "Unknown B layout: "
<< B_layout;
-
- if (shape == "m8n8k4") {
- return PrintMMAm8n8k4Assembly(A_layout, B_layout, A_dtype, B_dtype,
C_dtype, a_ref, a_bias,
- b_ref, b_bias, c_ref, c_bias, saturate);
- } else if (shape == "m16n8k8") {
- return PrintMMAm16n8k8Assembly(A_layout, B_layout, A_dtype, B_dtype,
C_dtype, a_ref, a_bias,
- b_ref, b_bias, c_ref, c_bias, saturate);
- } else if (shape == "m8n8k16") {
- return PrintMMAm8n8k16Assembly(A_layout, B_layout, A_dtype, B_dtype,
C_dtype, a_ref, a_bias,
- b_ref, b_bias, c_ref, c_bias, saturate);
- } else if (shape == "m8n8k32") {
- return PrintMMAm8n8k32Assembly(A_layout, B_layout, A_dtype, B_dtype,
C_dtype, a_ref, a_bias,
- b_ref, b_bias, c_ref, c_bias, saturate);
- } else if (shape == "m16n8k4") {
- return PrintMMAm16n8k4Assembly(A_layout, B_layout, A_dtype, B_dtype,
C_dtype, a_ref, a_bias,
- b_ref, b_bias, c_ref, c_bias, saturate);
- } else if (shape == "m16n8k16") {
- return PrintMMAm16n8k16Assembly(A_layout, B_layout, A_dtype, B_dtype,
C_dtype, a_ref, a_bias,
- b_ref, b_bias, c_ref, c_bias, saturate);
- } else if (shape == "m16n8k32") {
- return PrintMMAm16n8k32Assembly(A_layout, B_layout, A_dtype, B_dtype,
C_dtype, a_ref, a_bias,
- b_ref, b_bias, c_ref, c_bias, saturate);
- } else if (shape == "m16n8k64") {
- return PrintMMAm16n8k64Assembly(A_layout, B_layout, A_dtype, B_dtype,
C_dtype, a_ref, a_bias,
- b_ref, b_bias, c_ref, c_bias, saturate);
- } else if (shape == "m16n8k256") {
- return PrintMMAm16n8k256Assembly(A_layout, B_layout, A_dtype, B_dtype,
C_dtype, a_ref, a_bias,
- b_ref, b_bias, c_ref, c_bias, saturate);
+ const std::string& a_ref, const std::string&
a_offset,
+ const std::string& b_ref, const std::string&
b_offset,
+ const std::string& c_ref, const std::string&
c_offset,
+ const std::string& metadata, const std::string&
metadata_offset,
+ const std::string& sparsity_selector, const
std::string& bit_op,
+ bool sparse, bool saturate) {
+ ptx::DataType dtype_a = ptx::DTypeFromString(A_dtype), dtype_b =
ptx::DTypeFromString(B_dtype),
+ dtype_c = ptx::DTypeFromString(C_dtype);
+ ptx::LayoutType layout_a = ptx::LayoutTypeFromString(A_layout),
+ layout_b = ptx::LayoutTypeFromString(B_layout);
+ int m, n, k;
+ std::tie(m, n, k) = ptx::ParseMMAShape(shape);
+ CheckMMAConfigValidity(m, n, k, layout_a, layout_b, dtype_a, dtype_b,
dtype_c, bit_op, sparse,
+ saturate);
+ std::string asm_code = R"(
+ {
+ __asm__ __volatile__(
+
"mma{sparse}.sync.aligned.{shape}.{alayout}.{blayout}{saturate}{dtype}{atype}{btype}{ctype}{bitop}"
+ "{templates};\n"
+ : {outputs}
+ : {inputs});
}
- /*
- * TODO: add mma.m16n8k128
- */
- throw Error("Unknown PTX mma instructions.");
+)";
+ std::string templates_str, inputs_str, outputs_str;
+ std::tie(templates_str, inputs_str, outputs_str) =
+ GetMMAOperands(m, n, k, dtype_a, dtype_b, dtype_c, sparse);
+
+ // replace patterns
+ Replacer replacer;
+ replacer.register_rule("{sparse}", sparse ? ".sp" : "");
+ replacer.register_rule("{shape}", shape);
+ replacer.register_rule("{saturate}", saturate ? ".satfinite" : "");
+ replacer.register_rule("{alayout}", A_layout);
+ replacer.register_rule("{blayout}", B_layout);
+ replacer.register_rule("{atype}", ptx::DTypeToString(dtype_a));
+ replacer.register_rule("{btype}", ptx::DTypeToString(dtype_b));
+ replacer.register_rule("{ctype}", ptx::DTypeToString(dtype_c));
+ replacer.register_rule("{dtype}", ptx::DTypeToString(dtype_c));
+ replacer.register_rule("{bitop}", bit_op.empty() ? "" : "." + bit_op +
".popc");
+ replacer.register_rule("{templates}", templates_str);
+ replacer.register_rule("{outputs}", outputs_str);
+ replacer.register_rule("{inputs}", inputs_str);
+ asm_code = replacer.rewrite(asm_code);
+ replacer.empty_rules();
+ replacer.register_rule("A", a_ref + " + " + a_offset);
+ replacer.register_rule("B", b_ref + " + " + b_offset);
+ replacer.register_rule("C", c_ref + " + " + c_offset);
+ replacer.register_rule("D", c_ref + " + " + c_offset);
+ replacer.register_rule("E", metadata + " + " + metadata_offset);
+ replacer.register_rule("F", sparsity_selector);
+ asm_code = replacer.rewrite(asm_code);
+ return asm_code;
}
} // namespace codegen
diff --git a/src/target/source/ptx_mma.h b/src/target/source/ptx_mma.h
index d2a7a67..728478c 100644
--- a/src/target/source/ptx_mma.h
+++ b/src/target/source/ptx_mma.h
@@ -32,12 +32,36 @@
namespace tvm {
namespace codegen {
+/*!
+ * \brief Print MMA assembly string given parameters.
+ * \param shape The shape string mMnNkK
+ * \param A_layout The layout of multiplicand A, can be either "row" or "col".
+ * \param B_layout The layout of multiplicand B, can be either "row" or "col".
+ * \param A_dtype The data type of multiplicand A.
+ * \param B_dtype The data type of multiplicand B.
+ * \param C_dtype The data type of accumulator C.
+ * \param a_ref Pointer to buffer A.
+ * \param a_offset The offset of element in A.
+ * \param b_ref Pointer to buffer B.
+ * \param b_offset The offset of element in B.
+ * \param c_ref Pointer to buffer C.
+ * \param c_offset The offset of element in C.
+ * \param metadata Pointer to metadata buffer (only used for sparse mma).
+ * \param metadata_offset The offset of element in metadata.
+ * \param sparsity_selector The sparsity selector in sparse mma.
+ * \param bit_op The bit operator used in 1-bit mma, can be either "xor" or
"and".
+ * \param sparse Whether it's sparse mma or not.
+ * \param saturate Whether to saturate the output or not.
+ */
std::string PrintMMAAssembly(const std::string& shape, const std::string&
A_layout,
const std::string& B_layout, const std::string&
A_dtype,
const std::string& B_dtype, const std::string&
C_dtype,
- const std::string& a_ref, const std::string&
a_bias,
- const std::string& b_ref, const std::string&
b_bias,
- const std::string& c_ref, const std::string&
c_bias, bool saturate);
+ const std::string& a_ref, const std::string&
a_offset,
+ const std::string& b_ref, const std::string&
b_offset,
+ const std::string& c_ref, const std::string&
c_offset,
+ const std::string& metadata, const std::string&
metadata_offset,
+ const std::string& sparsity_selector, const
std::string& bit_op,
+ bool sparse, bool saturate);
} // namespace codegen
} // namespace tvm
diff --git a/src/tir/op/builtin.cc b/src/tir/op/builtin.cc
index 0e767ea..977050a 100644
--- a/src/tir/op/builtin.cc
+++ b/src/tir/op/builtin.cc
@@ -237,6 +237,9 @@ TIR_DEFINE_BUILTIN_FUNC(tvm_store_matrix_sync)
TIR_DEFINE_BUILTIN_FUNC(ptx_mma).set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));
+TIR_DEFINE_BUILTIN_FUNC(ptx_mma_sp)
+ .set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));
+
TIR_DEFINE_BUILTIN_FUNC(vectorhigh)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kPure));
diff --git a/tests/python/unittest/test_tir_ptx_mma.py
b/tests/python/unittest/test_tir_ptx_mma.py
index 8f653c6..23405fd 100644
--- a/tests/python/unittest/test_tir_ptx_mma.py
+++ b/tests/python/unittest/test_tir_ptx_mma.py
@@ -1311,6 +1311,7 @@ def gemm_mma_m16n8k256_row_col_b1b1s32(a: T.handle, b:
T.handle, c: T.handle):
Accum.data,
0,
False,
+ "xor",
dtype="int32",
)
)
diff --git a/tests/python/unittest/test_tir_ptx_mma_sp.py
b/tests/python/unittest/test_tir_ptx_mma_sp.py
new file mode 100644
index 0000000..321cd28
--- /dev/null
+++ b/tests/python/unittest/test_tir_ptx_mma_sp.py
@@ -0,0 +1,346 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import tvm
+from tvm.script import tir as T
+import numpy as np
+import tvm.testing
+
+
+def gen_2in4_mask(m: int, n: int):
+ assert n % 4 == 0
+ return np.array(
+ [[np.sort(np.random.choice(4, 2, replace=False)) for _ in range(n //
4)] for _ in range(m)]
+ ).astype("uint8")
+
+
+def get_dense_mat_by_mask(val, mask):
+ m, n_chunks, _ = mask.shape
+ val = val.reshape(m, n_chunks, 2)
+ ret = np.zeros((m, n_chunks, 4)).astype(val.dtype)
+ for i in range(m):
+ for j in range(n_chunks):
+ for k in range(2):
+ ret[i, j, mask[i, j, k]] = val[i, j, k]
+ return ret.reshape(m, n_chunks * 4)
+
+
[email protected]_func
+def mma_sp_m16n8k16_f16f16f16(a: T.handle, b: T.handle, c: T.handle,
_metadata: T.handle):
+ T.func_attr({"global_symbol": "default_function", "tir.noalias": True})
+ A = T.match_buffer(a, [16, 8], dtype="float16")
+ B = T.match_buffer(b, [16, 8], dtype="float16")
+ C = T.match_buffer(c, [16, 8], dtype="float16")
+ metadata = T.match_buffer(_metadata, [8], dtype="uint32")
+ brow = T.env_thread("blockIdx.y")
+ bcol = T.env_thread("blockIdx.x")
+ tx = T.env_thread("threadIdx.x")
+ T.launch_thread(brow, 1)
+ T.launch_thread(bcol, 1)
+ T.launch_thread(tx, 32)
+ multi_a = T.allocate([4], "float16", scope="local")
+ multi_b = T.allocate([4], "float16", scope="local")
+ accum = T.allocate([4], "float16", scope="local")
+ meta_local = T.allocate([1], "uint32", scope="local")
+ for i in range(4):
+ accum[i] = T.float16(0)
+
+ for i in range(4):
+ multi_a[i] = A[tx // 4 + i // 2 * 8, tx % 4 * 2 + i % 2]
+
+ for i in range(4):
+ multi_b[i] = B[tx % 4 * 2 + i % 2 + i // 2 * 8, tx // 4]
+
+ meta_local[0] = metadata[tx // 4]
+
+ T.evaluate(
+ T.ptx_mma_sp(
+ "m16n8k16",
+ "row",
+ "col",
+ "fp16",
+ "fp16",
+ "fp16",
+ multi_a.data,
+ 0,
+ multi_b.data,
+ 0,
+ accum.data,
+ 0,
+ meta_local.data,
+ 0,
+ 0,
+ False,
+ dtype="float16",
+ )
+ )
+
+ for i in range(4):
+ C[i // 2 * 8 + tx // 4, tx % 4 * 2 + i % 2] = accum[i]
+
+
[email protected]_func
+def mma_sp_m16n8k16_f16f16f32(a: T.handle, b: T.handle, c: T.handle,
_metadata: T.handle):
+ T.func_attr({"global_symbol": "default_function", "tir.noalias": True})
+ A = T.match_buffer(a, [16, 8], dtype="float16")
+ B = T.match_buffer(b, [16, 8], dtype="float16")
+ C = T.match_buffer(c, [16, 8], dtype="float32")
+ metadata = T.match_buffer(_metadata, [8], dtype="uint32")
+ brow = T.env_thread("blockIdx.y")
+ bcol = T.env_thread("blockIdx.x")
+ tx = T.env_thread("threadIdx.x")
+ T.launch_thread(brow, 1)
+ T.launch_thread(bcol, 1)
+ T.launch_thread(tx, 32)
+ multi_a = T.allocate([4], "float16", scope="local")
+ multi_b = T.allocate([4], "float16", scope="local")
+ accum = T.allocate([4], "float32", scope="local")
+ meta_local = T.allocate([1], "uint32", scope="local")
+ for i in range(4):
+ accum[i] = T.float16(0)
+
+ for i in range(4):
+ multi_a[i] = A[tx // 4 + i // 2 * 8, tx % 4 * 2 + i % 2]
+
+ for i in range(4):
+ multi_b[i] = B[tx % 4 * 2 + i % 2 + i // 2 * 8, tx // 4]
+
+ meta_local[0] = metadata[tx // 4]
+
+ T.evaluate(
+ T.ptx_mma_sp(
+ "m16n8k16",
+ "row",
+ "col",
+ "fp16",
+ "fp16",
+ "fp32",
+ multi_a.data,
+ 0,
+ multi_b.data,
+ 0,
+ accum.data,
+ 0,
+ meta_local.data,
+ 0,
+ 0,
+ False,
+ dtype="float32",
+ )
+ )
+
+ for i in range(4):
+ C[i // 2 * 8 + tx // 4, tx % 4 * 2 + i % 2] = accum[i]
+
+
[email protected]_func
+def mma_sp_m16n8k32_f16f16f16(a: T.handle, b: T.handle, c: T.handle,
_metadata: T.handle):
+ T.func_attr({"global_symbol": "default_function", "tir.noalias": True})
+ A = T.match_buffer(a, [16, 16], dtype="float16")
+ B = T.match_buffer(b, [32, 8], dtype="float16")
+ C = T.match_buffer(c, [16, 8], dtype="float16")
+ metadata = T.match_buffer(_metadata, [16], dtype="uint32")
+ brow = T.env_thread("blockIdx.y")
+ bcol = T.env_thread("blockIdx.x")
+ tx = T.env_thread("threadIdx.x")
+ T.launch_thread(brow, 1)
+ T.launch_thread(bcol, 1)
+ T.launch_thread(tx, 32)
+ multi_a = T.allocate([8], "float16", scope="local")
+ multi_b = T.allocate([8], "float16", scope="local")
+ accum = T.allocate([4], "float16", scope="local")
+ meta_local = T.allocate([1], "uint32", scope="local")
+ for i in range(4):
+ accum[i] = T.float16(0)
+
+ for i in range(8):
+ multi_a[i] = A[(i % 4) // 2 * 8 + tx // 4, i // 4 * 8 + tx % 4 * 2 + i
% 2]
+
+ for i in range(8):
+ multi_b[i] = B[i // 2 * 8 + tx % 4 * 2 + i % 2, tx // 4]
+
+ meta_local[0] = metadata[tx // 4 * 2 + tx % 2]
+
+ T.evaluate(
+ T.ptx_mma_sp(
+ "m16n8k32",
+ "row",
+ "col",
+ "fp16",
+ "fp16",
+ "fp16",
+ multi_a.data,
+ 0,
+ multi_b.data,
+ 0,
+ accum.data,
+ 0,
+ meta_local.data,
+ 0,
+ 0,
+ False,
+ dtype="float16",
+ )
+ )
+
+ for i in range(4):
+ C[i // 2 * 8 + tx // 4, tx % 4 * 2 + i % 2] = accum[i]
+
+
[email protected]_func
+def mma_sp_m16n8k32_f16f16f32(a: T.handle, b: T.handle, c: T.handle,
_metadata: T.handle):
+ T.func_attr({"global_symbol": "default_function", "tir.noalias": True})
+ A = T.match_buffer(a, [16, 16], dtype="float16")
+ B = T.match_buffer(b, [32, 8], dtype="float16")
+ C = T.match_buffer(c, [16, 8], dtype="float32")
+ metadata = T.match_buffer(_metadata, [16], dtype="uint32")
+ brow = T.env_thread("blockIdx.y")
+ bcol = T.env_thread("blockIdx.x")
+ tx = T.env_thread("threadIdx.x")
+ T.launch_thread(brow, 1)
+ T.launch_thread(bcol, 1)
+ T.launch_thread(tx, 32)
+ multi_a = T.allocate([8], "float16", scope="local")
+ multi_b = T.allocate([8], "float16", scope="local")
+ accum = T.allocate([4], "float32", scope="local")
+ meta_local = T.allocate([1], "uint32", scope="local")
+ for i in range(4):
+ accum[i] = T.float16(0)
+
+ for i in range(8):
+ multi_a[i] = A[(i % 4) // 2 * 8 + tx // 4, i // 4 * 8 + tx % 4 * 2 + i
% 2]
+
+ for i in range(8):
+ multi_b[i] = B[i // 2 * 8 + tx % 4 * 2 + i % 2, tx // 4]
+
+ meta_local[0] = metadata[tx // 4 * 2 + tx % 2]
+
+ T.evaluate(
+ T.ptx_mma_sp(
+ "m16n8k32",
+ "row",
+ "col",
+ "fp16",
+ "fp16",
+ "fp32",
+ multi_a.data,
+ 0,
+ multi_b.data,
+ 0,
+ accum.data,
+ 0,
+ meta_local.data,
+ 0,
+ 0,
+ False,
+ dtype="float32",
+ )
+ )
+
+ for i in range(4):
+ C[i // 2 * 8 + tx // 4, tx % 4 * 2 + i % 2] = accum[i]
+
+
[email protected]_cuda
+def test_mma_sp_m16n8k16_f16():
+ def get_meta_m16n8k16_half(mask):
+ assert mask.shape == (16, 4, 2)
+ mask = mask.reshape(16, 8)
+ ret = np.zeros((8,)).astype("uint32")
+
+ for i in range(8):
+ base = 1
+ for blk in range(2):
+ for j in range(8):
+ ret[i] |= int(mask[blk * 8 + i, j]) * base
+ base = base << 2
+ return ret
+
+ for out_dtype in ["float16", "float32"]:
+ func = mma_sp_m16n8k16_f16f16f16 if out_dtype == "float16" else
mma_sp_m16n8k16_f16f16f32
+ sch = tvm.tir.Schedule(func)
+ arch = tvm.contrib.nvcc.get_target_compute_version()
+ major, _ = tvm.contrib.nvcc.parse_compute_version(arch)
+ if major < 8:
+ # Requires SM80+
+ return
+ cuda_mod = tvm.build(sch.mod, target="cuda")
+
+ A_np = np.random.uniform(-1, 1, [16, 8]).astype("float16")
+ B_np = np.random.uniform(-1, 1, [16, 8]).astype("float16")
+ mask = gen_2in4_mask(16, 16)
+ A_dense_np = get_dense_mat_by_mask(A_np, mask)
+ C_np = np.matmul(A_dense_np, B_np).astype(out_dtype)
+ meta = get_meta_m16n8k16_half(mask)
+
+ ctx = tvm.cuda()
+ A_tvm = tvm.nd.array(A_np, ctx)
+ B_tvm = tvm.nd.array(B_np, ctx)
+ C_tvm = tvm.nd.array(np.zeros_like(C_np), ctx)
+ meta_tvm = tvm.nd.array(meta, ctx)
+ cuda_mod(A_tvm, B_tvm, C_tvm, meta_tvm)
+
+ tvm.testing.assert_allclose(C_tvm.numpy(), C_np, atol=1e-3, rtol=1e-3)
+
+
[email protected]_cuda
+def test_mma_sp_m16n8k32_f16():
+ def get_meta_m16n8k32_half(mask):
+ assert mask.shape == (16, 8, 2)
+ mask = mask.reshape(16, 2, 8)
+ ret = np.zeros((8, 2)).astype("uint32")
+
+ for i in range(8):
+ for k in range(2):
+ base = 1
+ for blk in range(2):
+ for j in range(8):
+ ret[i, k] |= int(mask[blk * 8 + i, k, j]) * base
+ base = base << 2
+
+ return ret.reshape(16)
+
+ for out_dtype in ["float16", "float32"]:
+ func = mma_sp_m16n8k32_f16f16f16 if out_dtype == "float16" else
mma_sp_m16n8k32_f16f16f32
+ sch = tvm.tir.Schedule(func)
+ arch = tvm.contrib.nvcc.get_target_compute_version()
+ major, _ = tvm.contrib.nvcc.parse_compute_version(arch)
+ if major < 8:
+ # Requires SM80+
+ return
+ cuda_mod = tvm.build(sch.mod, target="cuda")
+
+ A_np = np.random.uniform(-1, 1, [16, 16]).astype("float16")
+ B_np = np.random.uniform(-1, 1, [32, 8]).astype("float16")
+ mask = gen_2in4_mask(16, 32)
+ A_dense_np = get_dense_mat_by_mask(A_np, mask)
+ C_np = np.matmul(A_dense_np, B_np).astype(out_dtype)
+ meta = get_meta_m16n8k32_half(mask)
+
+ ctx = tvm.cuda()
+ A_tvm = tvm.nd.array(A_np, ctx)
+ B_tvm = tvm.nd.array(B_np, ctx)
+ C_tvm = tvm.nd.array(np.zeros_like(C_np), ctx)
+ meta_tvm = tvm.nd.array(meta, ctx)
+ cuda_mod(A_tvm, B_tvm, C_tvm, meta_tvm)
+
+ tvm.testing.assert_allclose(C_tvm.numpy(), C_np, atol=1e-3, rtol=1e-3)
+
+
+if __name__ == "__main__":
+ test_mma_sp_m16n8k16_f16()
+ test_mma_sp_m16n8k32_f16()