This is an automated email from the ASF dual-hosted git repository.

wuwei pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 7688db7  [PTX] Support mma.sp to use Sparse Tensor Cores and refactor mma codegen (#10339)
7688db7 is described below

commit 7688db7ac5c4e1a043bf0dddeed75780ec49e70a
Author: Zihao Ye <[email protected]>
AuthorDate: Tue Mar 8 11:33:25 2022 -0800

    [PTX] Support mma.sp to use Sparse Tensor Cores and refactor mma codegen (#10339)
    
    * init
    
    * upd
    
    * upd
    
    * lint
    
    * lint again
    
    * upd
    
    * add m16n8k32 testcase
    
    * format
    
    * use make_tuple instead of initializer list
    
    * add metadata offset
    
    * upd
    
    * docstring and sanity
    
    * add u8s8s32 back
    
    * improvement
    
    * compatible #9727
---
 include/tvm/tir/builtin.h                    |   13 +
 src/target/source/codegen_cuda.cc            |   49 +-
 src/target/source/ptx_mma.cc                 | 1806 +++++++-------------------
 src/target/source/ptx_mma.h                  |   30 +-
 src/tir/op/builtin.cc                        |    3 +
 tests/python/unittest/test_tir_ptx_mma.py    |    1 +
 tests/python/unittest/test_tir_ptx_mma_sp.py |  346 +++++
 7 files changed, 934 insertions(+), 1314 deletions(-)

diff --git a/include/tvm/tir/builtin.h b/include/tvm/tir/builtin.h
index f7e1cfbc..0d9f823 100644
--- a/include/tvm/tir/builtin.h
+++ b/include/tvm/tir/builtin.h
@@ -596,6 +596,19 @@ TVM_DLL const Op& tvm_store_matrix_sync();
  */
 TVM_DLL const Op& ptx_mma();
 
+/*!
+ * \brief tvm intrinsic for sparse tensor core ptx instructions.
+ *
+ * void ptx_mma_sp(StringImm shape, StringImm A_layout, StringImm B_layout,
+ *                 StringImm A_dtype, StringImm B_dtype, StringImm C_dtype,
+ *                 Var multiplicand_a, Expr a_index,
+ *                 Var multiplicand_b, Expr b_index,
+ *                 Var accumulator, Expr c_index,
+ *                 Var metadata, Expr meta_index,
+ *                 Var sparse_selector, bool saturate);
+ */
+TVM_DLL const Op& ptx_mma_sp();
+
 // TODO(tvm-team) replace the usage of the vector operations by Shuffle.
 /*!
  * \brief Get the high level half of the vector
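
For orientation, ptx_mma_sp is invoked like any other TIR builtin; below is a
minimal C++ sketch of constructing the call (MakeSparseMMACall and the vars
A, B, C, E are hypothetical; the real invocations live in the new TVMScript
tests in tests/python/unittest/test_tir_ptx_mma_sp.py):

    #include <tvm/tir/builtin.h>
    #include <tvm/tir/expr.h>

    using namespace tvm;
    using namespace tvm::tir;

    // Sparse m16n8k16 f16f16f32 MMA; all fragment offsets zero for brevity.
    PrimExpr MakeSparseMMACall(Var A, Var B, Var C, Var E) {
      return Call(DataType::Float(32), builtin::ptx_mma_sp(),
                  {StringImm("m16n8k16"), StringImm("row"), StringImm("col"),
                   StringImm("fp16"), StringImm("fp16"), StringImm("fp32"),
                   A, 0, B, 0, C, 0,  // multiplicands, accumulator, indices
                   E, 0,              // metadata and its index
                   0,                 // sparse_selector
                   Bool(false)});     // saturate
    }
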
diff --git a/src/target/source/codegen_cuda.cc b/src/target/source/codegen_cuda.cc
index 0dda079..f74d5cf 100644
--- a/src/target/source/codegen_cuda.cc
+++ b/src/target/source/codegen_cuda.cc
@@ -744,7 +744,8 @@ void CodeGenCUDA::VisitExpr_(const CallNode* op, std::ostream& os) {
     // arg 10: C accumulator
     // arg 11: C accumulator index
     // arg 12: saturate
-    ICHECK_EQ(op->args.size(), 13U);
+    // arg 13: (optional) 1-bit operator (xor or and)
+    ICHECK(op->args.size() == 13U || op->args.size() == 14U);
     std::string shape = Downcast<StringImm>(op->args[0])->value;
     std::string A_layout = Downcast<StringImm>(op->args[1])->value;
     std::string B_layout = Downcast<StringImm>(op->args[2])->value;
@@ -757,11 +758,51 @@ void CodeGenCUDA::VisitExpr_(const CallNode* op, std::ostream& os) {
     std::string b_bias = this->PrintExpr(op->args[9]);
     std::string c_ref = this->PrintExpr(op->args[10]);
     std::string c_bias = this->PrintExpr(op->args[11]);
-    bool saturate = (Downcast<IntImm>(op->args[12])->value != 0);
-    std::string asm_code = PrintMMAAssembly(shape, A_layout, B_layout, A_dtype, B_dtype, C_dtype,
-                                            a_ref, a_bias, b_ref, b_bias, c_ref, c_bias, saturate);
+    bool saturate = Downcast<Bool>(op->args[12])->value;
+    std::string bit_op = op->args.size() > 13 ? Downcast<StringImm>(op->args[13])->value : "";
+    std::string asm_code =
+        PrintMMAAssembly(shape, A_layout, B_layout, A_dtype, B_dtype, C_dtype, a_ref, a_bias, b_ref,
+                         b_bias, c_ref, c_bias, "", "", "", bit_op, false, saturate);
 
     this->stream << asm_code;
+  } else if (op->op.same_as(builtin::ptx_mma_sp())) {
+    // arg 0: shape: mXnXkX
+    // arg 1: A layout: row/col
+    // arg 2: B layout: row/col
+    // arg 3: A precision: fp16, fp32, ...
+    // arg 4: B precision: fp16, fp32, ...
+    // arg 5: C precision: fp16, fp32, ...
+    // arg 6: A multiplicand
+    // arg 7: A multiplicand index
+    // arg 8: B multiplicand
+    // arg 9: B multiplicand index
+    // arg 10: C accumulator
+    // arg 11: C accumulator index
+    // arg 12: metadata
+    // arg 13: metadata index
+    // arg 14: sparse_selector
+    // arg 15: saturate
+    ICHECK_EQ(op->args.size(), 16U);
+    std::string shape = Downcast<StringImm>(op->args[0])->value;
+    std::string A_layout = Downcast<StringImm>(op->args[1])->value;
+    std::string B_layout = Downcast<StringImm>(op->args[2])->value;
+    std::string A_dtype = Downcast<StringImm>(op->args[3])->value;
+    std::string B_dtype = Downcast<StringImm>(op->args[4])->value;
+    std::string C_dtype = Downcast<StringImm>(op->args[5])->value;
+    std::string a_ref = this->PrintExpr(op->args[6]);
+    std::string a_offset = this->PrintExpr(op->args[7]);
+    std::string b_ref = this->PrintExpr(op->args[8]);
+    std::string b_offset = this->PrintExpr(op->args[9]);
+    std::string c_ref = this->PrintExpr(op->args[10]);
+    std::string c_offset = this->PrintExpr(op->args[11]);
+    std::string metadata = this->PrintExpr(op->args[12]);
+    std::string metadata_offset = this->PrintExpr(op->args[13]);
+    std::string sparse_selector = this->PrintExpr(op->args[14]);
+    bool saturate = Downcast<Bool>(op->args[15])->value;
+    std::string asm_code = PrintMMAAssembly(
+        shape, A_layout, B_layout, A_dtype, B_dtype, C_dtype, a_ref, a_offset, b_ref, b_offset,
+        c_ref, c_offset, metadata, metadata_offset, sparse_selector, "", true, saturate);
+    this->stream << asm_code;
   } else {
     CodeGenC::VisitExpr_(op, os);
   }
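
For reference, the asm_code produced on the sparse path wraps a single mma.sp
instruction: the metadata register supplies the 2-bit column indices of the
kept elements in each 2:4 group, and the trailing immediate is the sparsity
selector. Roughly, for m16n8k16 f32.f16.f16.f32 (an illustration following the
PTX ISA, not the exact generated text; D, A, B, C, E are assumed fragment
arrays):

    __asm__ __volatile__(
        "mma.sp.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
        "{%0,%1,%2,%3}, {%4,%5}, {%6,%7}, {%8,%9,%10,%11}, %12, 0x0;\n"
        : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3])
        : "r"(A[0]), "r"(A[1]), "r"(B[0]), "r"(B[1]),
          "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3]), "r"(E[0]));

Note that A needs only half the registers of the dense m16n8k16 form, since
the zero elements are compressed away.
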
diff --git a/src/target/source/ptx_mma.cc b/src/target/source/ptx_mma.cc
index b618272..d04c018 100644
--- a/src/target/source/ptx_mma.cc
+++ b/src/target/source/ptx_mma.cc
@@ -23,1351 +23,543 @@
 
 #include "ptx_mma.h"
 
+#include <algorithm>
+#include <string>
+#include <tuple>
+#include <utility>
+#include <vector>
+
 namespace tvm {
 namespace codegen {
 
-std::string ReplaceMMAArgument(std::string asm_code, const std::string& original,
-                               const std::string& new_arg) {
-  size_t len = original.size();
-  size_t new_len = new_arg.size();
-  size_t pos = asm_code.find(original);
-  while (pos != std::string::npos) {
-    asm_code = asm_code.replace(pos, len, new_arg);
-    pos = asm_code.find(original, pos + new_len);
-  }
-  return asm_code;
-}
+// PTX related data structures and functions.
+namespace ptx {
 
-std::string PrintMMAm8n8k4Assembly(const std::string& A_layout, const std::string& B_layout,
-                                   const std::string& A_dtype, const std::string& B_dtype,
-                                   const std::string& C_dtype, const std::string& a_ref,
-                                   const std::string& a_bias, const std::string& b_ref,
-                                   const std::string& b_bias, const std::string& c_ref,
-                                   const std::string& c_bias, bool saturate) {
-  std::string asm_code = "";
-  std::string new_a_ref = "";
-  std::string new_b_ref = "";
-  std::string new_c_ref = "";
-  ICHECK(((A_dtype == "fp16") && (B_dtype == "fp16")) ||
-         ((A_dtype == "fp64") && (B_dtype == "fp64")));
-  ICHECK(saturate == false) << "Saturate is not allowed for m8n8k4 mma.";
-  if ((A_dtype == "fp16") && (B_dtype == "fp16")) {
-    // A/B multiplicand is fp16, SM 70 Tensor Core instructions
-    ICHECK((C_dtype == "fp16") || (C_dtype == "fp32"));
-    if (C_dtype == "fp16") {
-      // C accumulator is fp16
-      new_a_ref = "((unsigned *)(" + a_ref + " + " + a_bias + "))";
-      new_b_ref = "((unsigned *)(" + b_ref + " + " + b_bias + "))";
-      new_c_ref = "((unsigned *)(" + c_ref + " + " + c_bias + "))";
-      asm_code = R"(
-            {
-              __asm__ __volatile__(
-                  "mma.sync.aligned.m8n8k4.left_layout.right_layout.f16.f16.f16.f16 "
-                  "{%0,%1,%2,%3}, {%4,%5}, {%6,%7}, "
-                  "{%8,%9,%10,%11};\n"
-                  : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3])
-                  : "r"(A[0]), "r"(A[1]), "r"(B[0]), "r"(B[1]), 
-                    "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]));
-            }
-          )";
-    } else {
-      // C accumulator is fp32
-      new_a_ref = "((unsigned *)(" + a_ref + " + " + a_bias + "))";
-      new_b_ref = "((unsigned *)(" + b_ref + " + " + b_bias + "))";
-      new_c_ref = "((float *)(" + c_ref + " + " + c_bias + "))";
-      asm_code = R"(
-            {
-              __asm__ __volatile__(
-                  "mma.sync.aligned.m8n8k4.left_layout.right_layout.f32.f16.f16.f32 "
-                  "{%0,%1,%2,%3,%4,%5,%6,%7}, {%8,%9}, {%10,%11}, "
-                  "{%12,%13,%14,%15,%16,%17,%18,%19};\n"
-                  : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3]),
-                    "=f"(D[4]), "=f"(D[5]), "=f"(D[6]), "=f"(D[7])
-                  : "r"(A[0]), "r"(A[1]), "r"(B[0]), "r"(B[1]), 
-                    "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3]),
-                    "f"(C[4]), "f"(C[5]), "f"(C[6]), "f"(C[7]));
-            }
-          )";
-    }
+/*!
+ * \brief PTX data type.
+ * \note
+ * PTX fundamental data types:
+ * https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#fundamental-types
+ * PTX matrix data types:
+ * https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-data-types
+ */
+enum class DataType : int {
+  kInt4 = 0,
+  kUInt4 = 1,
+  kInt8 = 2,
+  kUInt8 = 3,
+  kInt16 = 4,
+  kUInt16 = 5,
+  kInt32 = 6,
+  kUInt32 = 7,
+  kInt64 = 8,
+  kUInt64 = 9,
+  kFloat16 = 10,
+  kBFloat16 = 11,
+  kFloat16x2 = 12,
+  kFloat32 = 13,
+  kTensorFloat32 = 14,
+  kFloat64 = 15,
+  kBit1 = 16
+};
+
+static const char* dtype_str[] = {".s4",    ".u4",  ".s8",   ".u8",  ".s16", ".u16",
+                                  ".s32",   ".u32", ".s64",  ".u64", ".f16", ".bf16",
+                                  ".f16x2", ".f32", ".tf32", ".f64", ".b1"};
+static const uint32_t num_bits[] = {4, 4, 8, 8, 16, 16, 32, 32, 64, 64, 16, 16, 32, 32, 32, 64, 1};
+
+/*!
+ * \brief Create PTX data type from string.
+ */
+inline DataType DTypeFromString(const std::string str) {
+  if (str == "int4" || str == ".s4") {
+    return DataType::kInt4;
+  } else if (str == "uint4" || str == ".u4") {
+    return DataType::kUInt4;
+  } else if (str == "int8" || str == ".s8") {
+    return DataType::kInt8;
+  } else if (str == "uint8" || str == ".u8") {
+    return DataType::kUInt8;
+  } else if (str == "int16" || str == ".s16") {
+    return DataType::kInt16;
+  } else if (str == "uint16" || str == ".u16") {
+    return DataType::kUInt16;
+  } else if (str == "int32" || str == ".s32") {
+    return DataType::kInt32;
+  } else if (str == "uint32" || str == ".u32") {
+    return DataType::kUInt32;
+  } else if (str == "int64" || str == ".s64") {
+    return DataType::kInt64;
+  } else if (str == "uint64" || str == ".u64") {
+    return DataType::kUInt64;
+  } else if (str == "float16" || str == "fp16" || str == ".f16") {
+    return DataType::kFloat16;
+  } else if (str == "bfloat16" || str == "bf16") {
+    return DataType::kBFloat16;
+  } else if (str == ".f16x2") {
+    return DataType::kFloat16x2;
+  } else if (str == "float32" || str == "fp32" || str == ".f32") {
+    return DataType::kFloat32;
+  } else if (str == "tf32") {
+    return DataType::kTensorFloat32;
+  } else if (str == "float64" || str == "fp64" || str == ".f64") {
+    return DataType::kFloat64;
+  } else if (str == "int1" || str == ".b1") {
+    return DataType::kBit1;
   } else {
-    // A/B multiplicand is fp64, SM 80 Tensor Core instructions
-    ICHECK(C_dtype == "fp64");
-    ICHECK((A_layout == "row") && (B_layout == "col"))
-        << "SM80 Fp64 Tensor Core instructions "
-        << "with shape m8n8k4 expect A layout is row major and B layout is col 
major.";
-    // C accumulator is fp64
-    new_a_ref = "((double *)(" + a_ref + " + " + a_bias + "))";
-    new_b_ref = "((double *)(" + b_ref + " + " + b_bias + "))";
-    new_c_ref = "((double *)(" + c_ref + " + " + c_bias + "))";
-    asm_code = R"(
-          {
-            __asm__ __volatile__(
-                "mma.sync.aligned.m8n8k4.row.col.f64.f64.f64.f64 "
-                "{%0,%1}, {%2}, {%3}, "
-                "{%4,%5};\n"
-                : "=d"(D[0]), "=d"(D[1])
-                : "d"(A[0]), "d"(B[0]), 
-                  "d"(C[0]), "d"(C[1]));
-          }
-        )";
+    LOG(FATAL) << "Unrecognized PTX data type " << str;
+    return DataType(0);
   }
-  asm_code = ReplaceMMAArgument(asm_code, "left_layout", A_layout);
-  asm_code = ReplaceMMAArgument(asm_code, "right_layout", B_layout);
-  asm_code = ReplaceMMAArgument(asm_code, "A", new_a_ref);
-  asm_code = ReplaceMMAArgument(asm_code, "B", new_b_ref);
-  asm_code = ReplaceMMAArgument(asm_code, "C", new_c_ref);
-  asm_code = ReplaceMMAArgument(asm_code, "D", new_c_ref);
-  return asm_code;
 }
 
-std::string PrintMMAm16n8k8Assembly(const std::string& A_layout, const std::string& B_layout,
-                                    const std::string& A_dtype, const std::string& B_dtype,
-                                    const std::string& C_dtype, const std::string& a_ref,
-                                    const std::string& a_bias, const std::string& b_ref,
-                                    const std::string& b_bias, const std::string& c_ref,
-                                    const std::string& c_bias, bool saturate) {
-  std::string asm_code = "";
-  std::string new_a_ref = "";
-  std::string new_b_ref = "";
-  std::string new_c_ref = "";
-  ICHECK(((A_dtype == "fp16") && (B_dtype == "fp16")) ||
-         ((A_dtype == "bf16") && (B_dtype == "bf16")));
-  ICHECK(saturate == false) << "Saturate is not allowed for m16n8k8 mma.";
-  if ((A_dtype == "fp16") && (B_dtype == "fp16")) {
-    // A/B multiplicand is fp16, SM 75 Tensor Core instructions
-    ICHECK((C_dtype == "fp16") || (C_dtype == "fp32"));
-    ICHECK((A_layout == "row") && (B_layout == "col"))
-        << "SM75 Tensor Core instructions "
-        << "with shape m16n8k8 expect A layout is row major and B layout is 
col major.";
-    if (C_dtype == "fp16") {
-      // C accumulator is fp16
-      new_a_ref = "((unsigned *)(" + a_ref + " + " + a_bias + "))";
-      new_b_ref = "((unsigned *)(" + b_ref + " + " + b_bias + "))";
-      new_c_ref = "((unsigned *)(" + c_ref + " + " + c_bias + "))";
-      asm_code = R"(
-            {
-              __asm__ __volatile__(
-                  "mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 "
-                  "{%0,%1}, {%2,%3}, {%5}, "
-                  "{%5,%6};\n"
-                  : "=r"(D[0]), "=r"(D[1])
-                  : "r"(A[0]), "r"(A[1]), "r"(B[0]), 
-                    "r"(C[0]), "r"(C[1]));
-            }
-          )";
-    } else {
-      // C accumulator is fp32
-      new_a_ref = "((unsigned *)(" + a_ref + " + " + a_bias + "))";
-      new_b_ref = "((unsigned *)(" + b_ref + " + " + b_bias + "))";
-      new_c_ref = "((float *)(" + c_ref + " + " + c_bias + "))";
-      asm_code = R"(
-            {
-              __asm__ __volatile__(
-                  "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 "
-                  "{%0,%1,%2,%3}, {%4,%5}, {%6}, "
-                  "{%7,%8,%9,%10};\n"
-                  : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3])
-                  : "r"(A[0]), "r"(A[1]), "r"(B[0]), 
-                    "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3]));
-            }
-          )";
-    }
+/*!
+ * \brief Get the string representation of given PTX data type.
+ */
+inline std::string DTypeToString(DataType dtype) { return dtype_str[static_cast<int>(dtype)]; }
+
+/*!
+ * \brief Get the number of bits of given PTX data type.
+ */
+inline uint32_t DTypeBits(DataType dtype) { return num_bits[static_cast<int>(dtype)]; }
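
A quick round-trip through these helpers, as a sketch:

    ptx::DataType t = ptx::DTypeFromString("fp16");  // also accepts "float16", ".f16"
    std::string s = ptx::DTypeToString(t);           // ".f16"
    uint32_t bits = ptx::DTypeBits(t);               // 16
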
+
+/*!
+ * \brief Extract the value m, n, k from string m*n*k*
+ */
+inline std::tuple<int, int, int> ParseMMAShape(const std::string& str) {
+  size_t pos_m = str.find("m"), pos_n = str.find("n"), pos_k = str.find("k");
+  CHECK(pos_m != str.npos && pos_n != str.npos && pos_k != str.npos)
+      << "Cannot parse MMA shape " << str;
+  int m = std::stoi(str.substr(pos_m + 1, pos_n - pos_m - 1)),
+      n = std::stoi(str.substr(pos_n + 1, pos_k - pos_n - 1)), k = std::stoi(str.substr(pos_k + 1));
+  return std::make_tuple(m, n, k);
+}
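
For example:

    int m, n, k;
    std::tie(m, n, k) = ptx::ParseMMAShape("m16n8k32");  // m = 16, n = 8, k = 32
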
+
+/*!
+ * \brief Layout Type
+ */
+enum class LayoutType : int { kRowMajor = 0, kColumnMajor = 1 };
+
+/*!
+ * \brief Parse layout type
+ */
+LayoutType LayoutTypeFromString(const std::string& str) {
+  if (str == "row") {
+    return LayoutType::kRowMajor;
+  } else if (str == "col") {
+    return LayoutType::kColumnMajor;
   } else {
-    // A/B multiplicand is bf16, SM 80 Tensor Core instructions
-    ICHECK(C_dtype == "fp32");
-    ICHECK((A_layout == "row") && (B_layout == "col"))
-        << "SM80 Tensor Core instructions "
-        << "with shape m16n8k8 expect A layout is row major and B layout is 
col major.";
-    // C accumulator is fp32
-    new_a_ref = "((unsigned *)(" + a_ref + " + " + a_bias + "))";
-    new_b_ref = "((unsigned *)(" + b_ref + " + " + b_bias + "))";
-    new_c_ref = "((float *)(" + c_ref + " + " + c_bias + "))";
-    asm_code = R"(
-          {
-            __asm__ __volatile__(
-                "mma.sync.aligned.m16n8k8.row.col.f32.bf16.bf16.f32 "
-                "{%0,%1,%2,%3}, {%4,%5}, {%6}, "
-                "{%7,%8,%9,%10};\n"
-                : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3])
-                : "r"(A[0]), "r"(A[1]), "r"(B[0]), 
-                  "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3]));
-          }
-        )";
+    LOG(FATAL) << "Unrecognized layout type " << str;
+    return LayoutType::kRowMajor;
   }
-  asm_code = ReplaceMMAArgument(asm_code, "left_layout", A_layout);
-  asm_code = ReplaceMMAArgument(asm_code, "right_layout", B_layout);
-  asm_code = ReplaceMMAArgument(asm_code, "A", new_a_ref);
-  asm_code = ReplaceMMAArgument(asm_code, "B", new_b_ref);
-  asm_code = ReplaceMMAArgument(asm_code, "C", new_c_ref);
-  asm_code = ReplaceMMAArgument(asm_code, "D", new_c_ref);
-  return asm_code;
 }
 
-std::string PrintMMAm8n8k16Assembly(const std::string& A_layout, const std::string& B_layout,
-                                    const std::string& A_dtype, const std::string& B_dtype,
-                                    const std::string& C_dtype, const std::string& a_ref,
-                                    const std::string& a_bias, const std::string& b_ref,
-                                    const std::string& b_bias, const std::string& c_ref,
-                                    const std::string& c_bias, bool saturate) {
-  std::string asm_code = "";
-  std::string new_a_ref = "";
-  std::string new_b_ref = "";
-  std::string new_c_ref = "";
-  ICHECK(((A_dtype == "int8") && (B_dtype == "int8")) ||
-         ((A_dtype == "uint8") && (B_dtype == "int8")) ||
-         ((A_dtype == "int8") && (B_dtype == "uint8")) ||
-         ((A_dtype == "uint8") && (B_dtype == "uint8")));
-  if ((A_dtype == "int8") && (B_dtype == "int8")) {
-    // A/B multiplicand is int8, SM 75 Tensor Core instructions
-    ICHECK(C_dtype == "int32");
-    ICHECK((A_layout == "row") && (B_layout == "col"))
-        << "SM75 Tensor Core instructions "
-        << "with shape m8n8k16 expect A layout is row major and B layout is 
col major.";
-    // C accumulator is int32
-    if (!saturate) {
-      // no saturate
-      new_a_ref = "((unsigned *)(" + a_ref + " + " + a_bias + "))";
-      new_b_ref = "((unsigned *)(" + b_ref + " + " + b_bias + "))";
-      new_c_ref = "((int *)(" + c_ref + " + " + c_bias + "))";
-      asm_code = R"(
-            {
-              __asm__ __volatile__(
-                  "mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 "
-                  "{%0,%1}, {%2}, {%3}, "
-                  "{%4,%5};\n"
-                  : "=r"(D[0]), "=r"(D[1])
-                  : "r"(A[0]), "r"(B[0]), 
-                    "r"(C[0]), "r"(C[1]));
-            }
-          )";
-    } else {
-      // saturate
-      new_a_ref = "((unsigned *)(" + a_ref + " + " + a_bias + "))";
-      new_b_ref = "((unsigned *)(" + b_ref + " + " + b_bias + "))";
-      new_c_ref = "((int *)(" + c_ref + " + " + c_bias + "))";
-      asm_code = R"(
-            {
-              __asm__ __volatile__(
-                  "mma.sync.aligned.m8n8k16.row.col.satfinite.s32.s8.s8.s32 "
-                  "{%0,%1}, {%2}, {%3}, "
-                  "{%4,%5};\n"
-                  : "=r"(D[0]), "=r"(D[1])
-                  : "r"(A[0]), "r"(B[0]), 
-                    "r"(C[0]), "r"(C[1]));
-            }
-          )";
-    }
-  } else if ((A_dtype == "uint8") && (B_dtype == "int8")) {
-    // A multiplicand is uint8, B multiplicand is int8
-    // SM 75 Tensor Core instructions
-    ICHECK(C_dtype == "int32");
-    ICHECK((A_layout == "row") && (B_layout == "col"))
-        << "SM75 Tensor Core instructions "
-        << "with shape m8n8k16 expect A layout is row major and B layout is 
col major.";
-    // C accumulator is int32
-    if (!saturate) {
-      // no saturate
-      new_a_ref = "((unsigned *)(" + a_ref + " + " + a_bias + "))";
-      new_b_ref = "((unsigned *)(" + b_ref + " + " + b_bias + "))";
-      new_c_ref = "((int *)(" + c_ref + " + " + c_bias + "))";
-      asm_code = R"(
-            {
-              __asm__ __volatile__(
-                  "mma.sync.aligned.m8n8k16.row.col.s32.u8.s8.s32 "
-                  "{%0,%1}, {%2}, {%3}, "
-                  "{%4,%5};\n"
-                  : "=r"(D[0]), "=r"(D[1])
-                  : "r"(A[0]), "r"(B[0]), 
-                    "r"(C[0]), "r"(C[1]));
-            }
-          )";
-    } else {
-      // saturate
-      new_a_ref = "((unsigned *)(" + a_ref + " + " + a_bias + "))";
-      new_b_ref = "((unsigned *)(" + b_ref + " + " + b_bias + "))";
-      new_c_ref = "((int *)(" + c_ref + " + " + c_bias + "))";
-      asm_code = R"(
-            {
-              __asm__ __volatile__(
-                  "mma.sync.aligned.m8n8k16.row.col.satfinite.s32.u8.s8.s32 "
-                  "{%0,%1}, {%2}, {%3}, "
-                  "{%4,%5};\n"
-                  : "=r"(D[0]), "=r"(D[1])
-                  : "r"(A[0]), "r"(B[0]), 
-                    "r"(C[0]), "r"(C[1]));
-            }
-          )";
-    }
-  } else if ((A_dtype == "int8") && (B_dtype == "uint8")) {
-    // A multiplicand is int8, B multiplicand is uint8
-    // SM 75 Tensor Core instructions
-    ICHECK(C_dtype == "int32");
-    ICHECK((A_layout == "row") && (B_layout == "col"))
-        << "SM75 Tensor Core instructions "
-        << "with shape m8n8k16 expect A layout is row major and B layout is 
col major.";
-    // C accumulator is int32
-    if (!saturate) {
-      // no saturate
-      new_a_ref = "((unsigned *)(" + a_ref + " + " + a_bias + "))";
-      new_b_ref = "((unsigned *)(" + b_ref + " + " + b_bias + "))";
-      new_c_ref = "((int *)(" + c_ref + " + " + c_bias + "))";
-      asm_code = R"(
-            {
-              __asm__ __volatile__(
-                  "mma.sync.aligned.m8n8k16.row.col.s32.s8.u8.s32 "
-                  "{%0,%1}, {%2}, {%3}, "
-                  "{%4,%5};\n"
-                  : "=r"(D[0]), "=r"(D[1])
-                  : "r"(A[0]), "r"(B[0]), 
-                    "r"(C[0]), "r"(C[1]));
-            }
-          )";
-    } else {
-      // saturate
-      new_a_ref = "((unsigned *)(" + a_ref + " + " + a_bias + "))";
-      new_b_ref = "((unsigned *)(" + b_ref + " + " + b_bias + "))";
-      new_c_ref = "((int *)(" + c_ref + " + " + c_bias + "))";
-      asm_code = R"(
-            {
-              __asm__ __volatile__(
-                  "mma.sync.aligned.m8n8k16.row.col.satfinite.s32.s8.u8.s32 "
-                  "{%0,%1}, {%2}, {%3}, "
-                  "{%4,%5};\n"
-                  : "=r"(D[0]), "=r"(D[1])
-                  : "r"(A[0]), "r"(B[0]), 
-                    "r"(C[0]), "r"(C[1]));
-            }
-          )";
-    }
-  } else {
-    // A/B multiplicand is uint8, SM 75 Tensor Core instructions
-    ICHECK(C_dtype == "int32");
-    ICHECK((A_layout == "row") && (B_layout == "col"))
-        << "SM75 Tensor Core instructions "
-        << "with shape m8n8k16 expect A layout is row major and B layout is 
col major.";
-    // C accumulator is int32
-    if (!saturate) {
-      // no saturate
-      new_a_ref = "((unsigned *)(" + a_ref + " + " + a_bias + "))";
-      new_b_ref = "((unsigned *)(" + b_ref + " + " + b_bias + "))";
-      new_c_ref = "((int *)(" + c_ref + " + " + c_bias + "))";
-      asm_code = R"(
-            {
-              __asm__ __volatile__(
-                  "mma.sync.aligned.m8n8k16.row.col.s32.u8.u8.s32 "
-                  "{%0,%1}, {%2}, {%3}, "
-                  "{%4,%5};\n"
-                  : "=r"(D[0]), "=r"(D[1])
-                  : "r"(A[0]), "r"(B[0]), 
-                    "r"(C[0]), "r"(C[1]));
-            }
-          )";
-    } else {
-      // saturate
-      new_a_ref = "((unsigned *)(" + a_ref + " + " + a_bias + "))";
-      new_b_ref = "((unsigned *)(" + b_ref + " + " + b_bias + "))";
-      new_c_ref = "((int *)(" + c_ref + " + " + c_bias + "))";
-      asm_code = R"(
-            {
-              __asm__ __volatile__(
-                  "mma.sync.aligned.m8n8k16.row.col.satfinite.s32.u8.u8.s32 "
-                  "{%0,%1}, {%2}, {%3}, "
-                  "{%4,%5};\n"
-                  : "=r"(D[0]), "=r"(D[1])
-                  : "r"(A[0]), "r"(B[0]), 
-                    "r"(C[0]), "r"(C[1]));
-            }
-          )";
-    }
+static const char* layout_type_str[] = {"row", "col"};
+
+/*!
+ * \brief Convert layout type to string.
+ */
+inline std::string LayoutTypeToString(LayoutType layout) {
+  return layout_type_str[static_cast<int>(layout)];
+}
+
+/*!
+ * \brief MMA Configurations, used to determine validity.
+ */
+struct MMAConfig {
+  explicit MMAConfig(int m, int n, int k, DataType dtype_mul, bool use_bit_op, bool sparse)
+      : m(m), n(n), k(k), dtype_mul(dtype_mul), use_bit_op(use_bit_op), sparse(sparse) {}
+  int m, n, k;
+  DataType dtype_mul;
+  bool use_bit_op;
+  bool sparse;
+  inline bool operator==(const MMAConfig& other) {
+    return m == other.m && n == other.n && k == other.k && dtype_mul == other.dtype_mul &&
+           use_bit_op == other.use_bit_op && sparse == other.sparse;
+  }
+};
+
+/*!
+ * \brief Valid MMA configurations
+ * \note Reference:
+ * https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-shape
+ */
+const MMAConfig valid_mma_configs[] = {
+    MMAConfig(8, 8, 4, DataType::kFloat64, false, false),
+    MMAConfig(8, 8, 4, DataType::kFloat16, false, false),
+    MMAConfig(16, 8, 8, DataType::kFloat16, false, false),
+    MMAConfig(16, 8, 16, DataType::kFloat16, false, false),
+    MMAConfig(16, 8, 8, DataType::kBFloat16, false, false),
+    MMAConfig(16, 8, 16, DataType::kBFloat16, false, false),
+    MMAConfig(16, 8, 4, DataType::kTensorFloat32, false, false),
+    MMAConfig(16, 8, 8, DataType::kTensorFloat32, false, false),
+    MMAConfig(8, 8, 16, DataType::kInt8, false, false),
+    MMAConfig(16, 8, 16, DataType::kInt8, false, false),
+    MMAConfig(16, 8, 32, DataType::kInt8, false, false),
+    MMAConfig(8, 8, 16, DataType::kUInt8, false, false),
+    MMAConfig(16, 8, 16, DataType::kUInt8, false, false),
+    MMAConfig(16, 8, 32, DataType::kUInt8, false, false),
+    MMAConfig(8, 8, 32, DataType::kInt4, false, false),
+    MMAConfig(16, 8, 32, DataType::kInt4, false, false),
+    MMAConfig(16, 8, 64, DataType::kInt4, false, false),
+    MMAConfig(8, 8, 32, DataType::kUInt4, false, false),
+    MMAConfig(16, 8, 32, DataType::kUInt4, false, false),
+    MMAConfig(16, 8, 64, DataType::kUInt4, false, false),
+    MMAConfig(8, 8, 128, DataType::kBit1, true, false),
+    MMAConfig(16, 8, 128, DataType::kBit1, true, false),
+    MMAConfig(16, 8, 256, DataType::kBit1, true, false),
+    MMAConfig(16, 8, 16, DataType::kFloat16, false, true),
+    MMAConfig(16, 8, 32, DataType::kFloat16, false, true),
+    MMAConfig(16, 8, 16, DataType::kBFloat16, false, true),
+    MMAConfig(16, 8, 32, DataType::kBFloat16, false, true),
+    MMAConfig(16, 8, 8, DataType::kTensorFloat32, false, true),
+    MMAConfig(16, 8, 16, DataType::kTensorFloat32, false, true),
+    MMAConfig(16, 8, 32, DataType::kInt8, false, true),
+    MMAConfig(16, 8, 64, DataType::kInt8, false, true),
+    MMAConfig(16, 8, 32, DataType::kUInt8, false, true),
+    MMAConfig(16, 8, 64, DataType::kUInt8, false, true),
+    MMAConfig(16, 8, 64, DataType::kInt4, false, true),
+    MMAConfig(16, 8, 128, DataType::kInt4, false, true),
+    MMAConfig(16, 8, 64, DataType::kUInt4, false, true),
+    MMAConfig(16, 8, 128, DataType::kUInt4, false, true),
+};
+
+/*!
+ * \brief Check whether the multiplicand data types and accumulator data type are valid for MMA
+ * computation.
+ * \param dtype_a The data type of multiplicand a.
+ * \param dtype_b The data type of multiplicand b.
+ * \param dtype_c The data type of accumulator c.
+ * \note Reference:
+ * https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-data-types
+ */
+void CheckMMADTypeCompatible(DataType dtype_a, DataType dtype_b, DataType dtype_c) {
+  std::string ab_not_match_err_str = "The multiplicands' data types " + DTypeToString(dtype_a) +
+                                     DTypeToString(dtype_b) + " do not match.";
+  // check a and b
+  switch (dtype_a) {
+    case DataType::kBit1:
+    case DataType::kFloat16:
+    case DataType::kBFloat16:
+    case DataType::kTensorFloat32:
+    case DataType::kFloat64:
+      CHECK(dtype_a == dtype_b) << ab_not_match_err_str;
+      break;
+    case DataType::kInt4:
+    case DataType::kUInt4:
+      CHECK(dtype_b == DataType::kInt4 || dtype_b == DataType::kUInt4) << ab_not_match_err_str;
+      break;
+    case DataType::kInt8:
+    case DataType::kUInt8:
+      CHECK(dtype_b == DataType::kInt8 || dtype_b == DataType::kUInt8) << ab_not_match_err_str;
+      break;
+    default:
+      CHECK(false) << "Invalid multiplicand data types: " << DTypeToString(dtype_a)
+                   << DTypeToString(dtype_b);
+  }
+  // check a,b and c
+  switch (dtype_a) {
+    case DataType::kBit1:
+    case DataType::kInt4:
+    case DataType::kUInt4:
+    case DataType::kInt8:
+    case DataType::kUInt8:
+      CHECK(dtype_c == DataType::kInt32)
+          << "For multiplicand data type " << DTypeToString(dtype_a) << 
DTypeToString(dtype_b)
+          << ", accumulator data type should be s32.";
+      break;
+    case DataType::kFloat16:
+      CHECK(dtype_c == DataType::kFloat16 || dtype_c == DataType::kFloat32)
+          << "For multiplicand data type f16, accumulator data type should be 
f16/f32.";
+      break;
+    case DataType::kBFloat16:
+    case DataType::kTensorFloat32:
+      CHECK(dtype_c == DataType::kFloat32)
+          << "For multiplicand data type bf16/tf32, accumulator data type can 
only be f32.";
+      break;
+    case DataType::kFloat64:
+      CHECK(dtype_c == DataType::kFloat64)
+          << "For multiplicand data type f64, accumulator data type can only 
be f64.";
+      break;
+    default:
+      CHECK(false) << "Invalid multiplicand/accumulator data types: " << DTypeToString(dtype_a)
+                   << DTypeToString(dtype_b) << DTypeToString(dtype_c) << ".";
   }
-  asm_code = ReplaceMMAArgument(asm_code, "left_layout", A_layout);
-  asm_code = ReplaceMMAArgument(asm_code, "right_layout", B_layout);
-  asm_code = ReplaceMMAArgument(asm_code, "A", new_a_ref);
-  asm_code = ReplaceMMAArgument(asm_code, "B", new_b_ref);
-  asm_code = ReplaceMMAArgument(asm_code, "C", new_c_ref);
-  asm_code = ReplaceMMAArgument(asm_code, "D", new_c_ref);
-  return asm_code;
 }
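
Illustratively, under the rules above (a sketch, not part of the patch):

    // OK: mixed-sign 8-bit multiplicands, s32 accumulator.
    ptx::CheckMMADTypeCompatible(ptx::DataType::kInt8, ptx::DataType::kUInt8,
                                 ptx::DataType::kInt32);
    // Aborts: f16 multiplicands may only accumulate into f16 or f32.
    // ptx::CheckMMADTypeCompatible(ptx::DataType::kFloat16, ptx::DataType::kFloat16,
    //                              ptx::DataType::kFloat64);
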
 
-std::string PrintMMAm8n8k32Assembly(const std::string& A_layout, const std::string& B_layout,
-                                    const std::string& A_dtype, const std::string& B_dtype,
-                                    const std::string& C_dtype, const std::string& a_ref,
-                                    const std::string& a_bias, const std::string& b_ref,
-                                    const std::string& b_bias, const std::string& c_ref,
-                                    const std::string& c_bias, bool saturate) {
-  std::string asm_code = "";
-  std::string new_a_ref = "";
-  std::string new_b_ref = "";
-  std::string new_c_ref = "";
-  ICHECK(((A_dtype == "int4") && (B_dtype == "int4")) ||
-         ((A_dtype == "uint4") && (B_dtype == "int4")) ||
-         ((A_dtype == "int4") && (B_dtype == "uint4")) ||
-         ((A_dtype == "uint4") && (B_dtype == "uint4")));
-  if ((A_dtype == "int4") && (B_dtype == "int4")) {
-    // A/B multiplicand is int4, SM 75 Tensor Core instructions
-    ICHECK(C_dtype == "int32");
-    ICHECK((A_layout == "row") && (B_layout == "col"))
-        << "SM75 Tensor Core instructions "
-        << "with shape m8n8k32 expect A layout is row major and B layout is 
col major.";
-    // C accumulator is int32
-    if (!saturate) {
-      // no saturate
-      new_a_ref = "((unsigned *)(" + a_ref + " + " + a_bias + "))";
-      new_b_ref = "((unsigned *)(" + b_ref + " + " + b_bias + "))";
-      new_c_ref = "((int *)(" + c_ref + " + " + c_bias + "))";
-      asm_code = R"(
-            {
-              __asm__ __volatile__(
-                  "mma.sync.aligned.m8n8k32.row.col.s32.s4.s4.s32 "
-                  "{%0,%1}, {%2}, {%3}, "
-                  "{%4,%5};\n"
-                  : "=r"(D[0]), "=r"(D[1])
-                  : "r"(A[0]), "r"(B[0]), 
-                    "r"(C[0]), "r"(C[1]));
-            }
-          )";
-    } else {
-      // saturate
-      new_a_ref = "((unsigned *)(" + a_ref + " + " + a_bias + "))";
-      new_b_ref = "((unsigned *)(" + b_ref + " + " + b_bias + "))";
-      new_c_ref = "((int *)(" + c_ref + " + " + c_bias + "))";
-      asm_code = R"(
-            {
-              __asm__ __volatile__(
-                  "mma.sync.aligned.m8n8k32.row.col.satfinite.s32.s4.s4.s32 "
-                  "{%0,%1}, {%2}, {%3}, "
-                  "{%4,%5};\n"
-                  : "=r"(D[0]), "=r"(D[1])
-                  : "r"(A[0]), "r"(B[0]), 
-                    "r"(C[0]), "r"(C[1]));
-            }
-          )";
-    }
-  } else if ((A_dtype == "uint4") && (B_dtype == "int4")) {
-    // A multiplicand is uint4, B multiplicand is int4
-    // SM 75 Tensor Core instructions
-    ICHECK(C_dtype == "int32");
-    ICHECK((A_layout == "row") && (B_layout == "col"))
-        << "SM75 Tensor Core instructions "
-        << "with shape m8n8k32 expect A layout is row major and B layout is 
col major.";
-    // C accumulator is int32
-    if (!saturate) {
-      // no saturate
-      new_a_ref = "((unsigned *)(" + a_ref + " + " + a_bias + "))";
-      new_b_ref = "((unsigned *)(" + b_ref + " + " + b_bias + "))";
-      new_c_ref = "((int *)(" + c_ref + " + " + c_bias + "))";
-      asm_code = R"(
-            {
-              __asm__ __volatile__(
-                  "mma.sync.aligned.m8n8k32.row.col.s32.u4.s4.s32 "
-                  "{%0,%1}, {%2}, {%3}, "
-                  "{%4,%5};\n"
-                  : "=r"(D[0]), "=r"(D[1])
-                  : "r"(A[0]), "r"(B[0]), 
-                    "r"(C[0]), "r"(C[1]));
-            }
-          )";
-    } else {
-      // saturate
-      new_a_ref = "((unsigned *)(" + a_ref + " + " + a_bias + "))";
-      new_b_ref = "((unsigned *)(" + b_ref + " + " + b_bias + "))";
-      new_c_ref = "((int *)(" + c_ref + " + " + c_bias + "))";
-      asm_code = R"(
-            {
-              __asm__ __volatile__(
-                  "mma.sync.aligned.m8n8k32.row.col.satfinite.s32.u4.s4.s32 "
-                  "{%0,%1}, {%2}, {%3}, "
-                  "{%4,%5};\n"
-                  : "=r"(D[0]), "=r"(D[1])
-                  : "r"(A[0]), "r"(B[0]), 
-                    "r"(C[0]), "r"(C[1]));
-            }
-          )";
-    }
-  } else if ((A_dtype == "int4") && (B_dtype == "uint4")) {
-    // A multiplicand is int4, B multiplicand is uint4
-    // SM 75 Tensor Core instructions
-    ICHECK(C_dtype == "int32");
-    ICHECK((A_layout == "row") && (B_layout == "col"))
-        << "SM75 Tensor Core instructions "
-        << "with shape m8n8k32 expect A layout is row major and B layout is 
col major.";
-    // C accumulator is int32
-    if (!saturate) {
-      // no saturate
-      new_a_ref = "((unsigned *)(" + a_ref + " + " + a_bias + "))";
-      new_b_ref = "((unsigned *)(" + b_ref + " + " + b_bias + "))";
-      new_c_ref = "((int *)(" + c_ref + " + " + c_bias + "))";
-      asm_code = R"(
-            {
-              __asm__ __volatile__(
-                  "mma.sync.aligned.m8n8k32.row.col.s32.s4.u4.s32 "
-                  "{%0,%1}, {%2}, {%3}, "
-                  "{%4,%5};\n"
-                  : "=r"(D[0]), "=r"(D[1])
-                  : "r"(A[0]), "r"(B[0]), 
-                    "r"(C[0]), "r"(C[1]));
-            }
-          )";
-    } else {
-      // saturate
-      new_a_ref = "((unsigned *)(" + a_ref + " + " + a_bias + "))";
-      new_b_ref = "((unsigned *)(" + b_ref + " + " + b_bias + "))";
-      new_c_ref = "((int *)(" + c_ref + " + " + c_bias + "))";
-      asm_code = R"(
-            {
-              __asm__ __volatile__(
-                  "mma.sync.aligned.m8n8k32.row.col.satfinite.s32.s4.u4.s32 "
-                  "{%0,%1}, {%2}, {%3}, "
-                  "{%4,%5};\n"
-                  : "=r"(D[0]), "=r"(D[1])
-                  : "r"(A[0]), "r"(B[0]), 
-                    "r"(C[0]), "r"(C[1]));
-            }
-          )";
-    }
-  } else {
-    // A/B multiplicand is uint4, SM 75 Tensor Core instructions
-    ICHECK(C_dtype == "int32");
-    ICHECK((A_layout == "row") && (B_layout == "col"))
-        << "SM75 Tensor Core instructions "
-        << "with shape m8n8k32 expect A layout is row major and B layout is 
col major.";
-    // C accumulator is int32
-    if (!saturate) {
-      // no saturate
-      new_a_ref = "((unsigned *)(" + a_ref + " + " + a_bias + "))";
-      new_b_ref = "((unsigned *)(" + b_ref + " + " + b_bias + "))";
-      new_c_ref = "((int *)(" + c_ref + " + " + c_bias + "))";
-      asm_code = R"(
-            {
-              __asm__ __volatile__(
-                  "mma.sync.aligned.m8n8k32.row.col.s32.u4.u4.s32 "
-                  "{%0,%1}, {%2}, {%3}, "
-                  "{%4,%5};\n"
-                  : "=r"(D[0]), "=r"(D[1])
-                  : "r"(A[0]), "r"(B[0]), 
-                    "r"(C[0]), "r"(C[1]));
-            }
-          )";
-    } else {
-      // saturate
-      new_a_ref = "((unsigned *)(" + a_ref + " + " + a_bias + "))";
-      new_b_ref = "((unsigned *)(" + b_ref + " + " + b_bias + "))";
-      new_c_ref = "((int *)(" + c_ref + " + " + c_bias + "))";
-      asm_code = R"(
-            {
-              __asm__ __volatile__(
-                  "mma.sync.aligned.m8n8k32.row.col.satfinite.s32.u4.u4.s32 "
-                  "{%0,%1}, {%2}, {%3}, "
-                  "{%4,%5};\n"
-                  : "=r"(D[0]), "=r"(D[1])
-                  : "r"(A[0]), "r"(B[0]), 
-                    "r"(C[0]), "r"(C[1]));
-            }
-          )";
+/*!
+ * \brief Check whether the given configuration is valid for MMA computation.
+ * \param m The M in mMnNkK of MMA instructions.
+ * \param n The N in mMnNkK of MMA instructions.
+ * \param k The K in mMnNkK of MMA instructions.
+ * \param layout_a The layout of multiplicand A (row/col).
+ * \param layout_b The layout of multiplicand B (row/col).
+ * \param dtype_a The data type of multiplicand A.
+ * \param dtype_b The data type of multiplicand B.
+ * \param dtype_c The data type of accumulator C.
+ * \param bit_op The bit operator for 1-bit MMA computation, can be "xor"/"and" or ""(if it's not
+ * 1-bit MMA).
+ * \param sparse Whether it's Sparse MMA or not.
+ * \param saturate Whether saturate output or not.
+ */
+void CheckMMAConfigValidity(int m, int n, int k, LayoutType layout_a, LayoutType layout_b,
+                            DataType dtype_a, DataType dtype_b, DataType dtype_c,
+                            const std::string& bit_op, bool sparse, bool saturate) {
+  CHECK(bit_op == "xor" || bit_op == "and" || bit_op == "")
+      << "Unrecognized 1-bit operation " << bit_op << " , can only be 
xor/and.";
+  bool use_bit_op = !bit_op.empty();
+  if (use_bit_op) {
+    CHECK(dtype_a == DataType::kBit1) << "Bit operator is only compatible with 1-bit multiplicand.";
+  }
+  CheckMMADTypeCompatible(dtype_a, dtype_b, dtype_c);
+  if (saturate) {
+    CHECK(dtype_a == DataType::kInt4 || dtype_a == DataType::kUInt4 || dtype_a == DataType::kInt8 ||
+          dtype_a == DataType::kUInt8)
+        << "Output saturation only applicable to multiplicand type s4/u4/s8/u8.";
+  }
+
+  if (!(m == 8 && n == 8 && k == 4 && dtype_a == ptx::DataType::kFloat16)) {
+    // Only MMA on m8n8k4 for fp16 supports customized layouts.
+    CHECK(layout_a == LayoutType::kRowMajor && layout_b == LayoutType::kColumnMajor)
+        << "Invalid layout combination " << LayoutTypeToString(layout_a) << ","
+        << LayoutTypeToString(layout_b) << ".";
+  }
+
+  MMAConfig config(m, n, k, dtype_a, use_bit_op, sparse);
+  bool match = false;
+  for (const MMAConfig& valid_config : valid_mma_configs) {
+    if (config == valid_config) {
+      match = true;
+      break;
     }
   }
-  asm_code = ReplaceMMAArgument(asm_code, "left_layout", A_layout);
-  asm_code = ReplaceMMAArgument(asm_code, "right_layout", B_layout);
-  asm_code = ReplaceMMAArgument(asm_code, "A", new_a_ref);
-  asm_code = ReplaceMMAArgument(asm_code, "B", new_b_ref);
-  asm_code = ReplaceMMAArgument(asm_code, "C", new_c_ref);
-  asm_code = ReplaceMMAArgument(asm_code, "D", new_c_ref);
-  return asm_code;
+  CHECK(match) << "Cannot find matched MMA configurations.";
 }
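
For instance, valid_mma_configs lists m16n8k32 f16 only with sparse = true, so
(a hypothetical call, not part of the patch):

    // Passes: sparse m16n8k32 f16 -> f32 is a valid Sparse Tensor Core shape.
    ptx::CheckMMAConfigValidity(16, 8, 32, ptx::LayoutType::kRowMajor,
                                ptx::LayoutType::kColumnMajor,
                                ptx::DataType::kFloat16, ptx::DataType::kFloat16,
                                ptx::DataType::kFloat32, "",
                                /*sparse=*/true, /*saturate=*/false);
    // The same shape with sparse = false aborts: dense m16n8k32 f16 does not exist.
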
 
-std::string PrintMMAm16n8k4Assembly(const std::string& A_layout, const std::string& B_layout,
-                                    const std::string& A_dtype, const std::string& B_dtype,
-                                    const std::string& C_dtype, const std::string& a_ref,
-                                    const std::string& a_bias, const std::string& b_ref,
-                                    const std::string& b_bias, const std::string& c_ref,
-                                    const std::string& c_bias, bool saturate) {
-  std::string asm_code = "";
-  std::string new_a_ref = "";
-  std::string new_b_ref = "";
-  std::string new_c_ref = "";
-  ICHECK((A_dtype == "tf32") && (B_dtype == "tf32"));
-  ICHECK(saturate == false) << "Saturate is not allowed for m16n8k4 mma.";
-  // A/B multiplicand is tf32, SM 80 Tensor Core instructions
-  ICHECK(C_dtype == "fp32");
-  ICHECK((A_layout == "row") && (B_layout == "col"))
-      << "SM80 Tensor Core instructions "
-      << "with shape m16n8k4 expect A layout is row major and B layout is col 
major.";
-  // C accumulator is fp32
-  new_a_ref = "((unsigned *)(" + a_ref + " + " + a_bias + "))";
-  new_b_ref = "((unsigned *)(" + b_ref + " + " + b_bias + "))";
-  new_c_ref = "((float *)(" + c_ref + " + " + c_bias + "))";
-  asm_code = R"(
-        {
-          __asm__ __volatile__(
-              "mma.sync.aligned.m16n8k4.row.col.f32.tf32.tf32.f32 "
-              "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, "
-              "{%10,%11,%12,%13};\n"
-              : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3])
-              : "r"(A[0]), "r"(A[1]), "f"(A[2]), "r"(A[3]),
-                "r"(B[0]), "r"(B[1]), 
-                "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3]));
-        }
-      )";
-  asm_code = ReplaceMMAArgument(asm_code, "left_layout", A_layout);
-  asm_code = ReplaceMMAArgument(asm_code, "right_layout", B_layout);
-  asm_code = ReplaceMMAArgument(asm_code, "A", new_a_ref);
-  asm_code = ReplaceMMAArgument(asm_code, "B", new_b_ref);
-  asm_code = ReplaceMMAArgument(asm_code, "C", new_c_ref);
-  asm_code = ReplaceMMAArgument(asm_code, "D", new_c_ref);
-  return asm_code;
-}
+/*!
+ * \brief Fragment attributes
+ */
+class FragAttrs {
+ public:
+  explicit FragAttrs(char reg_type, uint32_t size, std::string ptr_type)
+      : reg_type(reg_type), size(size), ptr_type(ptr_type) {}
+  /*! \brief PTX register type */
+  char reg_type;
+  /*! \brief Fragment size */
+  uint32_t size;
+  /*! \brief Fragment pointer type */
+  std::string ptr_type;
+};
 
-std::string PrintMMAm16n8k16Assembly(const std::string& A_layout, const std::string& B_layout,
-                                     const std::string& A_dtype, const std::string& B_dtype,
-                                     const std::string& C_dtype, const std::string& a_ref,
-                                     const std::string& a_bias, const std::string& b_ref,
-                                     const std::string& b_bias, const std::string& c_ref,
-                                     const std::string& c_bias, bool saturate) {
-  std::string asm_code = "";
-  std::string new_a_ref = "";
-  std::string new_b_ref = "";
-  std::string new_c_ref = "";
-  ICHECK(((A_dtype == "fp16") && (B_dtype == "fp16")) ||
-         ((A_dtype == "bf16") && (B_dtype == "bf16")) ||
-         ((A_dtype == "int8") && (B_dtype == "int8")) ||
-         ((A_dtype == "uint8") && (B_dtype == "int8")) ||
-         ((A_dtype == "int8") && (B_dtype == "uint8")) ||
-         ((A_dtype == "uint8") && (B_dtype == "uint8")));
-  if ((A_dtype == "fp16") && (B_dtype == "fp16")) {
-    ICHECK(saturate == false) << "Saturate is not allowed for m16n8k8 fp16 mma.";
-    // A/B multiplicand is fp16, SM 80 Tensor Core instructions
-    ICHECK((C_dtype == "fp16") || (C_dtype == "fp32"));
-    ICHECK((A_layout == "row") && (B_layout == "col"))
-        << "SM80 Tensor Core instructions "
-        << "with shape m16n8k16 expect A layout is row major and B layout is 
col major.";
-    if (C_dtype == "fp16") {
-      // C accumulator is fp16
-      new_a_ref = "((unsigned *)(" + a_ref + " + " + a_bias + "))";
-      new_b_ref = "((unsigned *)(" + b_ref + " + " + b_bias + "))";
-      new_c_ref = "((unsigned *)(" + c_ref + " + " + c_bias + "))";
-      asm_code = R"(
-            {
-              __asm__ __volatile__(
-                  "mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 "
-                  "{%0,%1}, {%2,%3,%4,%5}, {%6,%7}, "
-                  "{%8,%9};\n"
-                  : "=r"(D[0]), "=r"(D[1])
-                  : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]),
-                    "r"(B[0]), "r"(B[1]), 
-                    "r"(C[0]), "r"(C[1]));
-            }
-          )";
-    } else {
-      // C accumulator is fp32
-      new_a_ref = "((unsigned *)(" + a_ref + " + " + a_bias + "))";
-      new_b_ref = "((unsigned *)(" + b_ref + " + " + b_bias + "))";
-      new_c_ref = "((float *)(" + c_ref + " + " + c_bias + "))";
-      asm_code = R"(
-            {
-              __asm__ __volatile__(
-                  "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
-                  "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, "
-                  "{%10,%11,%12,%13};\n"
-                  : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3])
-                  : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]),
-                    "r"(B[0]), "r"(B[1]),
-                    "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3]));
-            }
-          )";
-    }
-  } else if ((A_dtype == "bf16") && (B_dtype == "bf16")) {
-    // A/B multiplicand is bf16, SM 80 Tensor Core instructions
-    ICHECK(saturate == false) << "Saturate is not allowed for m16n8k8 bf16 mma.";
-    ICHECK(C_dtype == "fp32");
-    ICHECK((A_layout == "row") && (B_layout == "col"))
-        << "SM80 Tensor Core instructions "
-        << "with shape m16n8k16 expect A layout is row major and B layout is 
col major.";
-    // C accumulator is fp32
-    new_a_ref = "((unsigned *)(" + a_ref + " + " + a_bias + "))";
-    new_b_ref = "((unsigned *)(" + b_ref + " + " + b_bias + "))";
-    new_c_ref = "((float *)(" + c_ref + " + " + c_bias + "))";
-    asm_code = R"(
-          {
-            __asm__ __volatile__(
-                "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 "
-                "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, "
-                "{%10,%11,%12,%13};\n"
-                : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3])
-                : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]),
-                  "r"(B[0]), "r"(B[1]),
-                  "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3]));
-          }
-        )";
-  } else if ((A_dtype == "int8") && (B_dtype == "int8")) {
-    // A/B multiplicand is int8, SM 80 Tensor Core instructions
-    ICHECK(C_dtype == "int32");
-    ICHECK((A_layout == "row") && (B_layout == "col"))
-        << "SM80 Tensor Core instructions "
-        << "with shape m16n8k16 expect A layout is row major and B layout is 
col major.";
-    // C accumulator is int32
-    if (!saturate) {
-      // no saturate
-      new_a_ref = "((unsigned *)(" + a_ref + " + " + a_bias + "))";
-      new_b_ref = "((unsigned *)(" + b_ref + " + " + b_bias + "))";
-      new_c_ref = "((int *)(" + c_ref + " + " + c_bias + "))";
-      asm_code = R"(
-            {
-              __asm__ __volatile__(
-                  "mma.sync.aligned.m16n8k16.row.col.s32.s8.s8.s32 "
-                  "{%0,%1,%2,%3}, {%4,%5}, {%6}, "
-                  "{%7,%8,%9,%10};\n"
-                  : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3])
-                  : "r"(A[0]), "r"(A[1]), "r"(B[0]), 
-                    "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]));
-            }
-          )";
-    } else {
-      // saturate
-      new_a_ref = "((unsigned *)(" + a_ref + " + " + a_bias + "))";
-      new_b_ref = "((unsigned *)(" + b_ref + " + " + b_bias + "))";
-      new_c_ref = "((int *)(" + c_ref + " + " + c_bias + "))";
-      asm_code = R"(
-            {
-              __asm__ __volatile__(
-                  "mma.sync.aligned.m16n8k16.row.col.s32.s8.s8.s32.satfinite "
-                  "{%0,%1,%2,%3}, {%4,%5}, {%6}, "
-                  "{%7,%8,%9,%10};\n"
-                  : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3])
-                  : "r"(A[0]), "r"(A[1]), "r"(B[0]), 
-                    "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]));
-            }
-          )";
-    }
-  } else if ((A_dtype == "uint8") && (B_dtype == "int8")) {
-    // A multiplicand is uint8, B multiplicand is int8
-    // SM 80 Tensor Core instructions
-    ICHECK(C_dtype == "int32");
-    ICHECK((A_layout == "row") && (B_layout == "col"))
-        << "SM80 Tensor Core instructions "
-        << "with shape m16n8k16 expect A layout is row major and B layout is 
col major.";
-    // C accumulator is int32
-    if (!saturate) {
-      // no saturate
-      new_a_ref = "((unsigned *)(" + a_ref + " + " + a_bias + "))";
-      new_b_ref = "((unsigned *)(" + b_ref + " + " + b_bias + "))";
-      new_c_ref = "((int *)(" + c_ref + " + " + c_bias + "))";
-      asm_code = R"(
-            {
-              __asm__ __volatile__(
-                  "mma.sync.aligned.m16n8k16.row.col.s32.u8.s8.s32 "
-                  "{%0,%1,%2,%3}, {%4,%5}, {%6}, "
-                  "{%7,%8,%9,%10};\n"
-                  : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3])
-                  : "r"(A[0]), "r"(A[1]), "r"(B[0]), 
-                    "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]));
-            }
-          )";
-    } else {
-      // saturate
-      new_a_ref = "((unsigned *)(" + a_ref + " + " + a_bias + "))";
-      new_b_ref = "((unsigned *)(" + b_ref + " + " + b_bias + "))";
-      new_c_ref = "((int *)(" + c_ref + " + " + c_bias + "))";
-      asm_code = R"(
-            {
-              __asm__ __volatile__(
-                  "mma.sync.aligned.m16n8k16.row.col.s32.u8.s8.s32.satfinite "
-                  "{%0,%1,%2,%3}, {%4,%5}, {%6}, "
-                  "{%7,%8,%9,%10};\n"
-                  : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3])
-                  : "r"(A[0]), "r"(A[1]), "r"(B[0]), 
-                    "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]));
-            }
-          )";
-    }
-  } else if ((A_dtype == "int8") && (B_dtype == "uint8")) {
-    // A multiplicand is int8, B multiplicand is uint8
-    // SM 80 Tensor Core instructions
-    ICHECK(C_dtype == "int32");
-    ICHECK((A_layout == "row") && (B_layout == "col"))
-        << "SM80 Tensor Core instructions "
-        << "with shape m16n8k16 expect A layout is row major and B layout is 
col major.";
-    // C accumulator is int32
-    if (!saturate) {
-      // no saturate
-      new_a_ref = "((unsigned *)(" + a_ref + " + " + a_bias + "))";
-      new_b_ref = "((unsigned *)(" + b_ref + " + " + b_bias + "))";
-      new_c_ref = "((int *)(" + c_ref + " + " + c_bias + "))";
-      asm_code = R"(
-            {
-              __asm__ __volatile__(
-                  "mma.sync.aligned.m16n8k16.row.col.s32.s8.u8.s32 "
-                  "{%0,%1,%2,%3}, {%4,%5}, {%6}, "
-                  "{%7,%8,%9,%10};\n"
-                  : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3])
-                  : "r"(A[0]), "r"(A[1]), "r"(B[0]), 
-                    "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]));
-            }
-          )";
-    } else {
-      // saturate
-      new_a_ref = "((unsigned *)(" + a_ref + " + " + a_bias + "))";
-      new_b_ref = "((unsigned *)(" + b_ref + " + " + b_bias + "))";
-      new_c_ref = "((int *)(" + c_ref + " + " + c_bias + "))";
-      asm_code = R"(
-            {
-              __asm__ __volatile__(
-                  "mma.sync.aligned.m16n8k16.row.col.s32.s8.u8.s32.satfinite "
-                  "{%0,%1,%2,%3}, {%4,%5}, {%6}, "
-                  "{%7,%8,%9,%10};\n"
-                  : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3])
-                  : "r"(A[0]), "r"(A[1]), "r"(B[0]), 
-                    "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]));
-            }
-          )";
-    }
-  } else {
-    // A/B multiplicand is uint8, SM 80 Tensor Core instructions
-    ICHECK(C_dtype == "int32");
-    ICHECK((A_layout == "row") && (B_layout == "col"))
-        << "SM80 Tensor Core instructions "
-        << "with shape m16n8k16 expect A layout is row major and B layout is 
col major.";
-    // C accumulator is int32
-    if (!saturate) {
-      // no saturate
-      new_a_ref = "((unsigned *)(" + a_ref + " + " + a_bias + "))";
-      new_b_ref = "((unsigned *)(" + b_ref + " + " + b_bias + "))";
-      new_c_ref = "((int *)(" + c_ref + " + " + c_bias + "))";
-      asm_code = R"(
-            {
-              __asm__ __volatile__(
-                  "mma.sync.aligned.m16n8k16.row.col.s32.u8.u8.s32 "
-                  "{%0,%1,%2,%3}, {%4,%5}, {%6}, "
-                  "{%7,%8,%9,%10};\n"
-                  : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3])
-                  : "r"(A[0]), "r"(A[1]), "r"(B[0]), 
-                    "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]));
-            }
-          )";
-    } else {
-      // saturate
-      new_a_ref = "((unsigned *)(" + a_ref + " + " + a_bias + "))";
-      new_b_ref = "((unsigned *)(" + b_ref + " + " + b_bias + "))";
-      new_c_ref = "((int *)(" + c_ref + " + " + c_bias + "))";
-      asm_code = R"(
-            {
-              __asm__ __volatile__(
-                  "mma.sync.aligned.m16n8k16.row.col.s32.u8.u8.s32.satfinite "
-                  "{%0,%1,%2,%3}, {%4,%5}, {%6}, "
-                  "{%7,%8,%9,%10};\n"
-                  : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3])
-                  : "r"(A[0]), "r"(A[1]), "r"(B[0]), 
-                    "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]));
-            }
-          )";
-    }
+/*!
+ * \brief Fragment attributes of given data type.
+ */
+inline FragAttrs GetFragAttrs(DataType dtype) {
+  switch (dtype) {
+    case DataType::kBit1:
+    case DataType::kInt4:
+    case DataType::kUInt4:
+    case DataType::kInt8:
+    case DataType::kUInt8:
+    case DataType::kFloat16:  // .f16x2 register
+    case DataType::kBFloat16:
+    case DataType::kTensorFloat32:
+      return FragAttrs('r', 32, "(unsigned *)");
+    case DataType::kInt32:
+      return FragAttrs('r', 32, "(int *)");
+    case DataType::kFloat32:
+      return FragAttrs('f', 32, "(float *)");
+    case DataType::kFloat64:
+      return FragAttrs('d', 64, "(double *)");
+    default:
+      ICHECK(false) << DTypeToString(dtype) << " is not a matrix data type in MMA.";
+      return FragAttrs('\0', 0, "");
   }
-  asm_code = ReplaceMMAArgument(asm_code, "left_layout", A_layout);
-  asm_code = ReplaceMMAArgument(asm_code, "right_layout", B_layout);
-  asm_code = ReplaceMMAArgument(asm_code, "A", new_a_ref);
-  asm_code = ReplaceMMAArgument(asm_code, "B", new_b_ref);
-  asm_code = ReplaceMMAArgument(asm_code, "C", new_c_ref);
-  asm_code = ReplaceMMAArgument(asm_code, "D", new_c_ref);
-  return asm_code;
 }
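
For reference, a hedged sketch of what this mapping yields (member names taken from the
FragAttrs uses later in this patch; illustrative, not part of the commit):

    // fp16 fragments travel as packed .f16x2 values in 32-bit "r" registers,
    // fp64 fragments in 64-bit "d" registers.
    ptx::FragAttrs fp16_attrs = ptx::GetFragAttrs(ptx::DataType::kFloat16);
    // fp16_attrs.reg_type == 'r', fp16_attrs.size == 32, fp16_attrs.ptr_type == "(unsigned *)"
    ptx::FragAttrs fp64_attrs = ptx::GetFragAttrs(ptx::DataType::kFloat64);
    // fp64_attrs.reg_type == 'd', fp64_attrs.size == 64, fp64_attrs.ptr_type == "(double *)"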
 
-std::string PrintMMAm16n8k32Assembly(const std::string& A_layout, const std::string& B_layout,
-                                     const std::string& A_dtype, const std::string& B_dtype,
-                                     const std::string& C_dtype, const std::string& a_ref,
-                                     const std::string& a_bias, const std::string& b_ref,
-                                     const std::string& b_bias, const std::string& c_ref,
-                                     const std::string& c_bias, bool saturate) {
-  std::string asm_code = "";
-  std::string new_a_ref = "";
-  std::string new_b_ref = "";
-  std::string new_c_ref = "";
-  ICHECK(((A_dtype == "int8") && (B_dtype == "int8")) ||
-         ((A_dtype == "uint8") && (B_dtype == "int8")) ||
-         ((A_dtype == "int8") && (B_dtype == "uint8")) ||
-         ((A_dtype == "uint8") && (B_dtype == "uint8")));
-  if ((A_dtype == "int8") && (B_dtype == "int8")) {
-    // A/B multiplicand is int8, SM 80 Tensor Core instructions
-    ICHECK(C_dtype == "int32");
-    ICHECK((A_layout == "row") && (B_layout == "col"))
-        << "SM80 Tensor Core instructions "
-        << "with shape m16n8k32 expect A layout is row major and B layout is 
col major.";
-    // C accumulator is int32
-    if (!saturate) {
-      // no saturate
-      new_a_ref = "((unsigned *)(" + a_ref + " + " + a_bias + "))";
-      new_b_ref = "((unsigned *)(" + b_ref + " + " + b_bias + "))";
-      new_c_ref = "((int *)(" + c_ref + " + " + c_bias + "))";
-      asm_code = R"(
-            {
-              __asm__ __volatile__(
-                  "mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32 "
-                  "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, "
-                  "{%10,%11,%12,%13};\n"
-                  : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3])
-                  : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]),
-                    "r"(B[0]), "r"(B[1]),
-                    "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]));
-            }
-          )";
-    } else {
-      // saturate
-      new_a_ref = "((unsigned *)(" + a_ref + " + " + a_bias + "))";
-      new_b_ref = "((unsigned *)(" + b_ref + " + " + b_bias + "))";
-      new_c_ref = "((int *)(" + c_ref + " + " + c_bias + "))";
-      asm_code = R"(
-            {
-              __asm__ __volatile__(
-                  "mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32.satfinite "
-                  "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, "
-                  "{%10,%11,%12,%13};\n"
-                  : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3])
-                  : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]),
-                    "r"(B[0]), "r"(B[1]),
-                    "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]));
-            }
-          )";
-    }
-  } else if ((A_dtype == "uint8") && (B_dtype == "int8")) {
-    // A multiplicand is uint8, B multiplicand is int8
-    // SM 80 Tensor Core instructions
-    ICHECK(C_dtype == "int32");
-    ICHECK((A_layout == "row") && (B_layout == "col"))
-        << "SM80 Tensor Core instructions "
-        << "with shape m16n8k32 expect A layout is row major and B layout is 
col major.";
-    // C accumulator is int32
-    if (!saturate) {
-      // no saturate
-      new_a_ref = "((unsigned *)(" + a_ref + " + " + a_bias + "))";
-      new_b_ref = "((unsigned *)(" + b_ref + " + " + b_bias + "))";
-      new_c_ref = "((int *)(" + c_ref + " + " + c_bias + "))";
-      asm_code = R"(
-            {
-              __asm__ __volatile__(
-                  "mma.sync.aligned.m16n8k32.row.col.s32.u8.s8.s32 "
-                  "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, "
-                  "{%10,%11,%12,%13};\n"
-                  : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3])
-                  : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]),
-                    "r"(B[0]), "r"(B[1]),
-                    "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]));
-            }
-          )";
-    } else {
-      // saturate
-      new_a_ref = "((unsigned *)(" + a_ref + " + " + a_bias + "))";
-      new_b_ref = "((unsigned *)(" + b_ref + " + " + b_bias + "))";
-      new_c_ref = "((int *)(" + c_ref + " + " + c_bias + "))";
-      asm_code = R"(
-            {
-              __asm__ __volatile__(
-                  "mma.sync.aligned.m16n8k32.row.col.s32.u8.s8.s32.satfinite "
-                  "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, "
-                  "{%10,%11,%12,%13};\n"
-                  : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3])
-                  : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]),
-                    "r"(B[0]), "r"(B[1]),
-                    "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]));
-            }
-          )";
-    }
-  } else if ((A_dtype == "int8") && (B_dtype == "uint8")) {
-    // A multiplicand is int8, B multiplicand is uint8
-    // SM 80 Tensor Core instructions
-    ICHECK(C_dtype == "int32");
-    ICHECK((A_layout == "row") && (B_layout == "col"))
-        << "SM80 Tensor Core instructions "
-        << "with shape m16n8k32 expect A layout is row major and B layout is 
col major.";
-    // C accumulator is int32
-    if (!saturate) {
-      // no saturate
-      new_a_ref = "((unsigned *)(" + a_ref + " + " + a_bias + "))";
-      new_b_ref = "((unsigned *)(" + b_ref + " + " + b_bias + "))";
-      new_c_ref = "((int *)(" + c_ref + " + " + c_bias + "))";
-      asm_code = R"(
-            {
-              __asm__ __volatile__(
-                  "mma.sync.aligned.m16n8k32.row.col.s32.s8.u8.s32 "
-                  "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, "
-                  "{%10,%11,%12,%13};\n"
-                  : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3])
-                  : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]),
-                    "r"(B[0]), "r"(B[1]),
-                    "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]));
-            }
-          )";
-    } else {
-      // saturate
-      new_a_ref = "((unsigned *)(" + a_ref + " + " + a_bias + "))";
-      new_b_ref = "((unsigned *)(" + b_ref + " + " + b_bias + "))";
-      new_c_ref = "((int *)(" + c_ref + " + " + c_bias + "))";
-      asm_code = R"(
-            {
-              __asm__ __volatile__(
-                  "mma.sync.aligned.m16n8k32.row.col.s32.s8.u8.s32.satfinite "
-                  "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, "
-                  "{%10,%11,%12,%13};\n"
-                  : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3])
-                  : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]),
-                    "r"(B[0]), "r"(B[1]),
-                    "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]));
-            }
-          )";
+}  // namespace ptx
+
+/*!
+ * \brief Replace patterns with replacement strings.
+ * \note Should use std::format instead once the codebase is ported to C++20.
+ */
+class Replacer {
+ public:
+  void register_rule(const std::string& pattern, const std::string& replacement) {
+    _rules.emplace_back(pattern, replacement);
+  }
+  std::string rewrite(std::string str) {
+    for (auto&& rule : _rules) {
+      std::string pattern, replacement;
+      std::tie(pattern, replacement) = rule;
+      size_t len = pattern.size();
+      size_t new_len = replacement.size();
+      size_t pos = str.find(pattern);
+      while (pos != std::string::npos) {
+        str = str.replace(pos, len, replacement);
+        pos = str.find(pattern, pos + new_len);
+      }
     }
+    return str;
+  }
+  void empty_rules() { _rules.clear(); }
+
+ private:
+  std::vector<std::pair<std::string, std::string>> _rules;
+};
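
A minimal usage sketch of the Replacer above (pattern and values illustrative):

    Replacer replacer;
    replacer.register_rule("{shape}", "m16n8k16");
    replacer.register_rule("{dtype}", "f32");
    std::string s = replacer.rewrite("mma.sync.aligned.{shape}.row.col.{dtype}");
    // s == "mma.sync.aligned.m16n8k16.row.col.f32"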
+
+/*!
+ * \brief Get the number of MMA computations for a given shape and datatype.
+ */
+inline uint32_t GetNumMMAComputations(int m, int n, int k, ptx::DataType dtype) {
+  if (m == 8 && n == 8 && k == 4 && dtype == ptx::DataType::kFloat16) {
+    // MMA for m8n8k4 on fp16 would launch 4 MMA computations instead of one.
+    return 4;
   } else {
-    // A/B multiplicand is uint8, SM 80 Tensor Core instructions
-    ICHECK(C_dtype == "int32");
-    ICHECK((A_layout == "row") && (B_layout == "col"))
-        << "SM80 Tensor Core instructions "
-        << "with shape m16n8k32 expect A layout is row major and B layout is 
col major.";
-    // C accumulator is int32
-    if (!saturate) {
-      // no saturate
-      new_a_ref = "((unsigned *)(" + a_ref + " + " + a_bias + "))";
-      new_b_ref = "((unsigned *)(" + b_ref + " + " + b_bias + "))";
-      new_c_ref = "((int *)(" + c_ref + " + " + c_bias + "))";
-      asm_code = R"(
-            {
-              __asm__ __volatile__(
-                  "mma.sync.aligned.m16n8k32.row.col.s32.u8.u8.s32 "
-                  "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, "
-                  "{%10,%11,%12,%13};\n"
-                  : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3])
-                  : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]),
-                    "r"(B[0]), "r"(B[1]),
-                    "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]));
-            }
-          )";
-    } else {
-      // saturate
-      new_a_ref = "((unsigned *)(" + a_ref + " + " + a_bias + "))";
-      new_b_ref = "((unsigned *)(" + b_ref + " + " + b_bias + "))";
-      new_c_ref = "((int *)(" + c_ref + " + " + c_bias + "))";
-      asm_code = R"(
-            {
-              __asm__ __volatile__(
-                  "mma.sync.aligned.m16n8k32.row.col.s32.u8.u8.s32.satfinite "
-                  "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, "
-                  "{%10,%11,%12,%13};\n"
-                  : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3])
-                  : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]),
-                    "r"(B[0]), "r"(B[1]),
-                    "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]));
-            }
-          )";
-    }
+    return 1;
   }
-  asm_code = ReplaceMMAArgument(asm_code, "left_layout", A_layout);
-  asm_code = ReplaceMMAArgument(asm_code, "right_layout", B_layout);
-  asm_code = ReplaceMMAArgument(asm_code, "A", new_a_ref);
-  asm_code = ReplaceMMAArgument(asm_code, "B", new_b_ref);
-  asm_code = ReplaceMMAArgument(asm_code, "C", new_c_ref);
-  asm_code = ReplaceMMAArgument(asm_code, "D", new_c_ref);
-  return asm_code;
 }
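
In practice the only shape treated specially here is the legacy m8n8k4 fp16 MMA, where one
PTX instruction performs four MMA computations, so each computation spans a quarter warp
(8 threads); every other shape uses the full 32-lane warp. A hedged check (assuming <cassert>):

    assert(GetNumMMAComputations(8, 8, 4, ptx::DataType::kFloat16) == 4);    // quarter-warp MMA
    assert(GetNumMMAComputations(16, 8, 16, ptx::DataType::kFloat16) == 1);  // full-warp MMA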
 
-std::string PrintMMAm16n8k64Assembly(const std::string& A_layout, const std::string& B_layout,
-                                     const std::string& A_dtype, const std::string& B_dtype,
-                                     const std::string& C_dtype, const std::string& a_ref,
-                                     const std::string& a_bias, const std::string& b_ref,
-                                     const std::string& b_bias, const std::string& c_ref,
-                                     const std::string& c_bias, bool saturate) {
-  std::string asm_code = "";
-  std::string new_a_ref = "";
-  std::string new_b_ref = "";
-  std::string new_c_ref = "";
-  ICHECK(((A_dtype == "int4") && (B_dtype == "int4")) ||
-         ((A_dtype == "uint4") && (B_dtype == "int4")) ||
-         ((A_dtype == "int4") && (B_dtype == "uint4")) ||
-         ((A_dtype == "uint4") && (B_dtype == "uint4")));
-  if ((A_dtype == "int4") && (B_dtype == "int4")) {
-    // A/B multiplicand is int4, SM 80 Tensor Core instructions
-    ICHECK(C_dtype == "int32");
-    ICHECK((A_layout == "row") && (B_layout == "col"))
-        << "SM80 Tensor Core instructions "
-        << "with shape m16n8k64 expect A layout is row major and B layout is 
col major.";
-    // C accumulator is int32
-    if (!saturate) {
-      // no saturate
-      new_a_ref = "((unsigned *)(" + a_ref + " + " + a_bias + "))";
-      new_b_ref = "((unsigned *)(" + b_ref + " + " + b_bias + "))";
-      new_c_ref = "((int *)(" + c_ref + " + " + c_bias + "))";
-      asm_code = R"(
-            {
-              __asm__ __volatile__(
-                  "mma.sync.aligned.m16n8k64.row.col.s32.s4.s4.s32 "
-                  "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, "
-                  "{%10,%11,%12,%13};\n"
-                  : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3])
-                  : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]),
-                    "r"(B[0]), "r"(B[1]),
-                    "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]));
-            }
-          )";
-    } else {
-      // saturate
-      new_a_ref = "((unsigned *)(" + a_ref + " + " + a_bias + "))";
-      new_b_ref = "((unsigned *)(" + b_ref + " + " + b_bias + "))";
-      new_c_ref = "((int *)(" + c_ref + " + " + c_bias + "))";
-      asm_code = R"(
-            {
-              __asm__ __volatile__(
-                  "mma.sync.aligned.m16n8k64.row.col.s32.s4.s4.s32.satfinite "
-                  "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, "
-                  "{%10,%11,%12,%13};\n"
-                  : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3])
-                  : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]),
-                    "r"(B[0]), "r"(B[1]),
-                    "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]));
-            }
-          )";
-    }
-  } else if ((A_dtype == "uint4") && (B_dtype == "int4")) {
-    // A multiplicand is uint4, B multiplicand is int4
-    // SM 80 Tensor Core instructions
-    ICHECK(C_dtype == "int32");
-    ICHECK((A_layout == "row") && (B_layout == "col"))
-        << "SM80 Tensor Core instructions "
-        << "with shape m16n8k64 expect A layout is row major and B layout is 
col major.";
-    // C accumulator is int32
-    if (!saturate) {
-      // no saturate
-      new_a_ref = "((unsigned *)(" + a_ref + " + " + a_bias + "))";
-      new_b_ref = "((unsigned *)(" + b_ref + " + " + b_bias + "))";
-      new_c_ref = "((int *)(" + c_ref + " + " + c_bias + "))";
-      asm_code = R"(
-            {
-              __asm__ __volatile__(
-                  "mma.sync.aligned.m16n8k64.row.col.s32.u4.s4.s32 "
-                  "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, "
-                  "{%10,%11,%12,%13};\n"
-                  : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3])
-                  : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]),
-                    "r"(B[0]), "r"(B[1]),
-                    "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]));
-            }
-          )";
-    } else {
-      // saturate
-      new_a_ref = "((unsigned *)(" + a_ref + " + " + a_bias + "))";
-      new_b_ref = "((unsigned *)(" + b_ref + " + " + b_bias + "))";
-      new_c_ref = "((int *)(" + c_ref + " + " + c_bias + "))";
-      asm_code = R"(
-            {
-              __asm__ __volatile__(
-                  "mma.sync.aligned.m16n8k64.row.col.s32.u4.s4.s32.satfinite "
-                  "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, "
-                  "{%10,%11,%12,%13};\n"
-                  : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3])
-                  : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]),
-                    "r"(B[0]), "r"(B[1]),
-                    "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]));
-            }
-          )";
-    }
-  } else if ((A_dtype == "int4") && (B_dtype == "uint4")) {
-    // A multiplicand is int4, B multiplicand is uint4
-    // SM 80 Tensor Core instructions
-    ICHECK(C_dtype == "int32");
-    ICHECK((A_layout == "row") && (B_layout == "col"))
-        << "SM80 Tensor Core instructions "
-        << "with shape m16n8k64 expect A layout is row major and B layout is 
col major.";
-    // C accumulator is int32
-    if (!saturate) {
-      // no saturate
-      new_a_ref = "((unsigned *)(" + a_ref + " + " + a_bias + "))";
-      new_b_ref = "((unsigned *)(" + b_ref + " + " + b_bias + "))";
-      new_c_ref = "((int *)(" + c_ref + " + " + c_bias + "))";
-      asm_code = R"(
-            {
-              __asm__ __volatile__(
-                  "mma.sync.aligned.m16n8k64.row.col.s32.s4.u4.s32 "
-                  "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, "
-                  "{%10,%11,%12,%13};\n"
-                  : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3])
-                  : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]),
-                    "r"(B[0]), "r"(B[1]),
-                    "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]));
-            }
-          )";
-    } else {
-      // saturate
-      new_a_ref = "((unsigned *)(" + a_ref + " + " + a_bias + "))";
-      new_b_ref = "((unsigned *)(" + b_ref + " + " + b_bias + "))";
-      new_c_ref = "((int *)(" + c_ref + " + " + c_bias + "))";
-      asm_code = R"(
-            {
-              __asm__ __volatile__(
-                  "mma.sync.aligned.m16n8k64.row.col.s32.s4.u4.s32.satfinite "
-                  "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, "
-                  "{%10,%11,%12,%13};\n"
-                  : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3])
-                  : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]),
-                    "r"(B[0]), "r"(B[1]),
-                    "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]));
-            }
-          )";
-    }
-  } else {
-    // A/B multiplicand is uint4, SM 80 Tensor Core instructions
-    ICHECK(C_dtype == "int32");
-    ICHECK((A_layout == "row") && (B_layout == "col"))
-        << "SM80 Tensor Core instructions "
-        << "with shape m16n8k64 expect A layout is row major and B layout is 
col major.";
-    // C accumulator is int32
-    if (!saturate) {
-      // no saturate
-      new_a_ref = "((unsigned *)(" + a_ref + " + " + a_bias + "))";
-      new_b_ref = "((unsigned *)(" + b_ref + " + " + b_bias + "))";
-      new_c_ref = "((int *)(" + c_ref + " + " + c_bias + "))";
-      asm_code = R"(
-            {
-              __asm__ __volatile__(
-                  "mma.sync.aligned.m16n8k64.row.col.s32.u4.u4.s32 "
-                  "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, "
-                  "{%10,%11,%12,%13};\n"
-                  : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3])
-                  : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]),
-                    "r"(B[0]), "r"(B[1]),
-                    "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]));
-            }
-          )";
-    } else {
-      // saturate
-      new_a_ref = "((unsigned *)(" + a_ref + " + " + a_bias + "))";
-      new_b_ref = "((unsigned *)(" + b_ref + " + " + b_bias + "))";
-      new_c_ref = "((int *)(" + c_ref + " + " + c_bias + "))";
-      asm_code = R"(
-            {
-              __asm__ __volatile__(
-                  "mma.sync.aligned.m16n8k64.row.col.s32.u4.u4.s32.satfinite "
-                  "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, "
-                  "{%10,%11,%12,%13};\n"
-                  : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3])
-                  : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]),
-                    "r"(B[0]), "r"(B[1]),
-                    "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]));
-            }
-          )";
+/*!
+ * \brief Return the template string, input operands string, and output operands string.
+ * \param m The M in mMnNkK of MMA instructions.
+ * \param n The N in mMnNkK of MMA instructions.
+ * \param k The K in mMnNkK of MMA instructions.
+ * \param dtype_a The data type of multiplicand a.
+ * \param dtype_b The data type of multiplicand b.
+ * \param dtype_c The data type of accumulator c.
+ * \param sparse Whether it's Sparse MMA or not.
+ */
+inline std::tuple<std::string, std::string, std::string> GetMMAOperands(int m, int n, int k,
+                                                                        ptx::DataType dtype_a,
+                                                                        ptx::DataType dtype_b,
+                                                                        ptx::DataType dtype_c,
+                                                                        bool sparse) {
+  std::stringstream templates, inputs, outputs;
+  const ptx::FragAttrs frag_attr_a = ptx::GetFragAttrs(dtype_a),
+                       frag_attr_b = ptx::GetFragAttrs(dtype_b),
+                       frag_attr_c = ptx::GetFragAttrs(dtype_c);
+  constexpr uint32_t warp_size = 32;
+  const uint32_t threads = warp_size / GetNumMMAComputations(m, n, k, dtype_a);
+  const int num_operands_a =
+                (m * k) * ptx::DTypeBits(dtype_a) / frag_attr_a.size / threads / (sparse ? 2 : 1),
+            num_operands_b = (k * n) * ptx::DTypeBits(dtype_b) / frag_attr_b.size / threads,
+            num_operands_c = (m * n) * ptx::DTypeBits(dtype_c) / frag_attr_c.size / threads;
+
+  // generate templates
+  int arg_counter = 0;
+  templates << "{"
+            << "%" << arg_counter++;
+  for (int i = 1; i < num_operands_c; ++i) {
+    templates << ", %" << arg_counter++;
+  }
+  templates << "}, {"
+            << "%" << arg_counter++;
+  for (int i = 1; i < num_operands_a; ++i) {
+    templates << ", %" << arg_counter++;
+  }
+  templates << "}, {"
+            << "%" << arg_counter++;
+  for (int i = 1; i < num_operands_b; ++i) {
+    templates << ", %" << arg_counter++;
+  }
+  templates << "}, {"
+            << "%" << arg_counter++;
+  for (int i = 1; i < num_operands_c; ++i) {
+    templates << ", %" << arg_counter++;
+  }
+  templates << "}";
+  // templates of metadata and sparse selector for sparse mma.
+  if (sparse) {
+    templates << ", %" << (arg_counter++) << ", F";
+  }
+
+  // generate inputs
+  for (int i = 0; i < num_operands_a; ++i) {
+    if (i != 0) {
+      inputs << ", ";
     }
+    inputs << "\"" << frag_attr_a.reg_type << "\"((" << frag_attr_a.ptr_type 
<< "(A))[" << i
+           << "])";
+  }
+  for (int i = 0; i < num_operands_b; ++i) {
+    inputs << ", \"" << frag_attr_b.reg_type << "\"((" << frag_attr_b.ptr_type 
<< "(B))[" << i
+           << "])";
+  }
+  for (int i = 0; i < num_operands_c; ++i) {
+    inputs << ", \"" << frag_attr_c.reg_type << "\"((" << frag_attr_c.ptr_type 
<< "(C))[" << i
+           << "])";
+  }
+  // input of metadata for sparse mma.
+  if (sparse) {
+    inputs << ", \"r\"(((unsigned *)(E))[0])";
   }
-  asm_code = ReplaceMMAArgument(asm_code, "left_layout", A_layout);
-  asm_code = ReplaceMMAArgument(asm_code, "right_layout", B_layout);
-  asm_code = ReplaceMMAArgument(asm_code, "A", new_a_ref);
-  asm_code = ReplaceMMAArgument(asm_code, "B", new_b_ref);
-  asm_code = ReplaceMMAArgument(asm_code, "C", new_c_ref);
-  asm_code = ReplaceMMAArgument(asm_code, "D", new_c_ref);
-  return asm_code;
-}
 
-std::string PrintMMAm16n8k256Assembly(const std::string& A_layout, const std::string& B_layout,
-                                      const std::string& A_dtype, const std::string& B_dtype,
-                                      const std::string& C_dtype, const std::string& a_ref,
-                                      const std::string& a_bias, const std::string& b_ref,
-                                      const std::string& b_bias, const std::string& c_ref,
-                                      const std::string& c_bias, bool saturate) {
-  std::string asm_code = "";
-  std::string new_a_ref = "";
-  std::string new_b_ref = "";
-  std::string new_c_ref = "";
-  ICHECK(((A_dtype == "uint1") && (B_dtype == "uint1")) ||
-         ((A_dtype == "int1") && (B_dtype == "int1")));
-  if ((A_dtype == "uint1") && (B_dtype == "uint1")) {
-    // A/B multiplicand is uint1, SM 80 Tensor Core instructions
-    ICHECK(C_dtype == "int32");
-    ICHECK((A_layout == "row") && (B_layout == "col"))
-        << "SM80 Tensor Core instructions "
-        << "with shape m16n8k256 expect A layout is row major and B layout is 
col major.";
-    // C accumulator is int32
-    new_a_ref = "((unsigned *)(" + a_ref + " + " + a_bias + "))";
-    new_b_ref = "((unsigned *)(" + b_ref + " + " + b_bias + "))";
-    new_c_ref = "((int *)(" + c_ref + " + " + c_bias + "))";
-    asm_code = R"(
-          {
-            __asm__ __volatile__(
-                "mma.sync.aligned.m16n8k256.row.col.s32.b1.b1.s32.and.popc "
-                "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, "
-                "{%10,%11,%12,%13};\n"
-                : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3])
-                : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]),
-                  "r"(B[0]), "r"(B[1]),
-                  "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]));
-          }
-        )";
-  } else {
-    // A/B multiplicand is int1, SM 80 Tensor Core instructions
-    ICHECK(C_dtype == "int32");
-    ICHECK((A_layout == "row") && (B_layout == "col"))
-        << "SM80 Tensor Core instructions "
-        << "with shape m16n8k256 expect A layout is row major and B layout is 
col major.";
-    // C accumulator is int32
-    new_a_ref = "((unsigned *)(" + a_ref + " + " + a_bias + "))";
-    new_b_ref = "((unsigned *)(" + b_ref + " + " + b_bias + "))";
-    new_c_ref = "((int *)(" + c_ref + " + " + c_bias + "))";
-    asm_code = R"(
-          {
-            __asm__ __volatile__(
-                "mma.sync.aligned.m16n8k256.row.col.s32.b1.b1.s32.xor.popc "
-                "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, "
-                "{%10,%11,%12,%13};\n"
-                : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3])
-                : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]),
-                  "r"(B[0]), "r"(B[1]),
-                  "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]));
-          }
-        )";
+  // generate outputs
+  for (int i = 0; i < num_operands_c; ++i) {
+    if (i != 0) {
+      outputs << ",";
+    }
+    outputs << " \"=" << frag_attr_c.reg_type << "\"((" << 
frag_attr_c.ptr_type << "(D))[" << i
+            << "])";
   }
-  asm_code = ReplaceMMAArgument(asm_code, "left_layout", A_layout);
-  asm_code = ReplaceMMAArgument(asm_code, "right_layout", B_layout);
-  asm_code = ReplaceMMAArgument(asm_code, "A", new_a_ref);
-  asm_code = ReplaceMMAArgument(asm_code, "B", new_b_ref);
-  asm_code = ReplaceMMAArgument(asm_code, "C", new_c_ref);
-  asm_code = ReplaceMMAArgument(asm_code, "D", new_c_ref);
-  return asm_code;
+  return std::make_tuple(templates.str(), inputs.str(), outputs.str());
 }
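
To make the operand arithmetic concrete, a worked example (hand-derived, not part of the
patch) for dense m16n8k16 with fp16 multiplicands and an fp32 accumulator:

    // threads = 32 / 1 = 32 (m16n8k16 is a full-warp MMA)
    // num_operands_a = (16*16) * 16 / 32 / 32 = 4  -> four packed .f16x2 "r" registers
    // num_operands_b = (16*8)  * 16 / 32 / 32 = 2
    // num_operands_c = (16*8)  * 32 / 32 / 32 = 4  -> four "f" registers
    std::string templates, inputs, outputs;
    std::tie(templates, inputs, outputs) = GetMMAOperands(
        16, 8, 16, ptx::DataType::kFloat16, ptx::DataType::kFloat16, ptx::DataType::kFloat32,
        /*sparse=*/false);
    // templates == "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13}"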
 
 std::string PrintMMAAssembly(const std::string& shape, const std::string& A_layout,
                              const std::string& B_layout, const std::string& A_dtype,
                              const std::string& B_dtype, const std::string& C_dtype,
-                             const std::string& a_ref, const std::string& a_bias,
-                             const std::string& b_ref, const std::string& b_bias,
-                             const std::string& c_ref, const std::string& c_bias, bool saturate) {
-  ICHECK((shape == "m8n8k4") || (shape == "m16n8k8") || (shape == "m8n8k16") ||
-         (shape == "m8n8k32") || (shape == "m16n8k4") || (shape == "m16n8k16") ||
-         (shape == "m16n8k32") || (shape == "m16n8k64") || (shape == "m16n8k256"));
-  ICHECK((A_layout == "row") || (A_layout == "col")) << "Unknown A layout: " << A_layout;
-  ICHECK((B_layout == "row") || (B_layout == "col")) << "Unknown B layout: " << B_layout;
-
-  if (shape == "m8n8k4") {
-    return PrintMMAm8n8k4Assembly(A_layout, B_layout, A_dtype, B_dtype, C_dtype, a_ref, a_bias,
-                                  b_ref, b_bias, c_ref, c_bias, saturate);
-  } else if (shape == "m16n8k8") {
-    return PrintMMAm16n8k8Assembly(A_layout, B_layout, A_dtype, B_dtype, C_dtype, a_ref, a_bias,
-                                   b_ref, b_bias, c_ref, c_bias, saturate);
-  } else if (shape == "m8n8k16") {
-    return PrintMMAm8n8k16Assembly(A_layout, B_layout, A_dtype, B_dtype, C_dtype, a_ref, a_bias,
-                                   b_ref, b_bias, c_ref, c_bias, saturate);
-  } else if (shape == "m8n8k32") {
-    return PrintMMAm8n8k32Assembly(A_layout, B_layout, A_dtype, B_dtype, C_dtype, a_ref, a_bias,
-                                   b_ref, b_bias, c_ref, c_bias, saturate);
-  } else if (shape == "m16n8k4") {
-    return PrintMMAm16n8k4Assembly(A_layout, B_layout, A_dtype, B_dtype, C_dtype, a_ref, a_bias,
-                                   b_ref, b_bias, c_ref, c_bias, saturate);
-  } else if (shape == "m16n8k16") {
-    return PrintMMAm16n8k16Assembly(A_layout, B_layout, A_dtype, B_dtype, C_dtype, a_ref, a_bias,
-                                    b_ref, b_bias, c_ref, c_bias, saturate);
-  } else if (shape == "m16n8k32") {
-    return PrintMMAm16n8k32Assembly(A_layout, B_layout, A_dtype, B_dtype, C_dtype, a_ref, a_bias,
-                                    b_ref, b_bias, c_ref, c_bias, saturate);
-  } else if (shape == "m16n8k64") {
-    return PrintMMAm16n8k64Assembly(A_layout, B_layout, A_dtype, B_dtype, C_dtype, a_ref, a_bias,
-                                    b_ref, b_bias, c_ref, c_bias, saturate);
-  } else if (shape == "m16n8k256") {
-    return PrintMMAm16n8k256Assembly(A_layout, B_layout, A_dtype, B_dtype, C_dtype, a_ref, a_bias,
-                                     b_ref, b_bias, c_ref, c_bias, saturate);
+                             const std::string& a_ref, const std::string& a_offset,
+                             const std::string& b_ref, const std::string& b_offset,
+                             const std::string& c_ref, const std::string& c_offset,
+                             const std::string& metadata, const std::string& metadata_offset,
+                             const std::string& sparsity_selector, const std::string& bit_op,
+                             bool sparse, bool saturate) {
+  ptx::DataType dtype_a = ptx::DTypeFromString(A_dtype), dtype_b = ptx::DTypeFromString(B_dtype),
+                dtype_c = ptx::DTypeFromString(C_dtype);
+  ptx::LayoutType layout_a = ptx::LayoutTypeFromString(A_layout),
+                  layout_b = ptx::LayoutTypeFromString(B_layout);
+  int m, n, k;
+  std::tie(m, n, k) = ptx::ParseMMAShape(shape);
+  CheckMMAConfigValidity(m, n, k, layout_a, layout_b, dtype_a, dtype_b, dtype_c, bit_op, sparse,
+                         saturate);
+  std::string asm_code = R"(
+  {
+    __asm__ __volatile__(
+      "mma{sparse}.sync.aligned.{shape}.{alayout}.{blayout}{saturate}.{dtype}.{atype}.{btype}.{ctype}{bitop}"
+      "{templates};\n"
+      : {outputs}
+      : {inputs});
   }
-  /*
-   * TODO: add mma.m16n8k128
-   */
-  throw Error("Unknown PTX mma instructions.");
+)";
+  std::string templates_str, inputs_str, outputs_str;
+  std::tie(templates_str, inputs_str, outputs_str) =
+      GetMMAOperands(m, n, k, dtype_a, dtype_b, dtype_c, sparse);
+
+  // replace patterns
+  Replacer replacer;
+  replacer.register_rule("{sparse}", sparse ? ".sp" : "");
+  replacer.register_rule("{shape}", shape);
+  replacer.register_rule("{saturate}", saturate ? ".satfinite" : "");
+  replacer.register_rule("{alayout}", A_layout);
+  replacer.register_rule("{blayout}", B_layout);
+  replacer.register_rule("{atype}", ptx::DTypeToString(dtype_a));
+  replacer.register_rule("{btype}", ptx::DTypeToString(dtype_b));
+  replacer.register_rule("{ctype}", ptx::DTypeToString(dtype_c));
+  replacer.register_rule("{dtype}", ptx::DTypeToString(dtype_c));
+  replacer.register_rule("{bitop}", bit_op.empty() ? "" : "." + bit_op + 
".popc");
+  replacer.register_rule("{templates}", templates_str);
+  replacer.register_rule("{outputs}", outputs_str);
+  replacer.register_rule("{inputs}", inputs_str);
+  asm_code = replacer.rewrite(asm_code);
+  replacer.empty_rules();
+  replacer.register_rule("A", a_ref + " + " + a_offset);
+  replacer.register_rule("B", b_ref + " + " + b_offset);
+  replacer.register_rule("C", c_ref + " + " + c_offset);
+  replacer.register_rule("D", c_ref + " + " + c_offset);
+  replacer.register_rule("E", metadata + " + " + metadata_offset);
+  replacer.register_rule("F", sparsity_selector);
+  asm_code = replacer.rewrite(asm_code);
+  return asm_code;
 }
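
For orientation, a hedged sketch of a sparse call and the instruction it expands to
(buffer names hypothetical; whitespace in the emitted asm is approximate):

    std::string code = PrintMMAAssembly(
        "m16n8k16", "row", "col", "fp16", "fp16", "fp32",
        "multi_a", "0", "multi_b", "0", "accum", "0",
        "meta_local", "0", /*sparsity_selector=*/"0", /*bit_op=*/"",
        /*sparse=*/true, /*saturate=*/false);
    // Per thread: 2 "r" registers for A (halved by 2:4 sparsity), 2 for B,
    // 4 "f" registers for C/D, 1 "r" register for the metadata E, selector 0:
    //   mma.sp.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32
    //     {%0, %1, %2, %3}, {%4, %5}, {%6, %7}, {%8, %9, %10, %11}, %12, 0;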
 
 }  // namespace codegen
diff --git a/src/target/source/ptx_mma.h b/src/target/source/ptx_mma.h
index d2a7a67..728478c 100644
--- a/src/target/source/ptx_mma.h
+++ b/src/target/source/ptx_mma.h
@@ -32,12 +32,36 @@
 namespace tvm {
 namespace codegen {
 
+/*!
+ * \brief Print MMA assembly string given the parameters.
+ * \param shape The shape string in mMnNkK format, e.g. "m16n8k16".
+ * \param A_layout The layout of multiplicand A, can be either "row" or "col".
+ * \param B_layout The layout of multiplicand B, can be either "row" or "col".
+ * \param A_dtype The data type of multiplicand A.
+ * \param B_dtype The data type of multiplicand B.
+ * \param C_dtype The data type of accumulator C.
+ * \param a_ref Pointer to buffer A.
+ * \param a_offset The offset of element in A.
+ * \param b_ref Pointer to buffer B.
+ * \param b_offset The offset of element in B.
+ * \param c_ref Pointer to buffer C.
+ * \param c_offset The offset of element in C.
+ * \param metadata Pointer to metadata buffer (only used for sparse mma).
+ * \param metadata_offset The offset of element in metadata.
+ * \param sparsity_selector The sparsity selector in sparse mma.
+ * \param bit_op The bit operator used in 1-bit mma, can be either "xor" or "and".
+ * \param sparse Whether it's sparse mma or not.
+ * \param saturate Whether to saturate the output or not.
+ */
 std::string PrintMMAAssembly(const std::string& shape, const std::string& A_layout,
                              const std::string& B_layout, const std::string& A_dtype,
                              const std::string& B_dtype, const std::string& C_dtype,
-                             const std::string& a_ref, const std::string& a_bias,
-                             const std::string& b_ref, const std::string& b_bias,
-                             const std::string& c_ref, const std::string& c_bias, bool saturate);
+                             const std::string& a_ref, const std::string& a_offset,
+                             const std::string& b_ref, const std::string& b_offset,
+                             const std::string& c_ref, const std::string& c_offset,
+                             const std::string& metadata, const std::string& metadata_offset,
+                             const std::string& sparsity_selector, const std::string& bit_op,
+                             bool sparse, bool saturate);
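
A dense caller passes empty strings for the sparse-only arguments; a hedged example
(values illustrative, 1-bit mma with XOR reduction):

    std::string code = PrintMMAAssembly(
        "m16n8k256", "row", "col", "uint1", "uint1", "int32",
        "a_frag", "0", "b_frag", "0", "c_frag", "0",
        /*metadata=*/"", /*metadata_offset=*/"", /*sparsity_selector=*/"",
        /*bit_op=*/"xor", /*sparse=*/false, /*saturate=*/false);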
 
 }  // namespace codegen
 }  // namespace tvm
diff --git a/src/tir/op/builtin.cc b/src/tir/op/builtin.cc
index 0e767ea..977050a 100644
--- a/src/tir/op/builtin.cc
+++ b/src/tir/op/builtin.cc
@@ -237,6 +237,9 @@ TIR_DEFINE_BUILTIN_FUNC(tvm_store_matrix_sync)
 TIR_DEFINE_BUILTIN_FUNC(ptx_mma).set_attr<TCallEffectKind>("TCallEffectKind",
                                                             Integer(CallEffectKind::kOpaque));
 
+TIR_DEFINE_BUILTIN_FUNC(ptx_mma_sp)
+    .set_attr<TCallEffectKind>("TCallEffectKind", 
Integer(CallEffectKind::kOpaque));
+
 TIR_DEFINE_BUILTIN_FUNC(vectorhigh)
     .set_attr<TCallEffectKind>("TCallEffectKind", 
Integer(CallEffectKind::kPure));
 
diff --git a/tests/python/unittest/test_tir_ptx_mma.py b/tests/python/unittest/test_tir_ptx_mma.py
index 8f653c6..23405fd 100644
--- a/tests/python/unittest/test_tir_ptx_mma.py
+++ b/tests/python/unittest/test_tir_ptx_mma.py
@@ -1311,6 +1311,7 @@ def gemm_mma_m16n8k256_row_col_b1b1s32(a: T.handle, b: T.handle, c: T.handle):
             Accum.data,
             0,
             False,
+            "xor",
             dtype="int32",
         )
     )
diff --git a/tests/python/unittest/test_tir_ptx_mma_sp.py b/tests/python/unittest/test_tir_ptx_mma_sp.py
new file mode 100644
index 0000000..321cd28
--- /dev/null
+++ b/tests/python/unittest/test_tir_ptx_mma_sp.py
@@ -0,0 +1,346 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import tvm
+from tvm.script import tir as T
+import numpy as np
+import tvm.testing
+
+
+def gen_2in4_mask(m: int, n: int):
+    assert n % 4 == 0
+    return np.array(
+        [[np.sort(np.random.choice(4, 2, replace=False)) for _ in range(n // 4)] for _ in range(m)]
+    ).astype("uint8")
+
+
+def get_dense_mat_by_mask(val, mask):
+    m, n_chunks, _ = mask.shape
+    val = val.reshape(m, n_chunks, 2)
+    ret = np.zeros((m, n_chunks, 4)).astype(val.dtype)
+    for i in range(m):
+        for j in range(n_chunks):
+            for k in range(2):
+                ret[i, j, mask[i, j, k]] = val[i, j, k]
+    return ret.reshape(m, n_chunks * 4)
+
+
+@T.prim_func
+def mma_sp_m16n8k16_f16f16f16(a: T.handle, b: T.handle, c: T.handle, _metadata: T.handle):
+    T.func_attr({"global_symbol": "default_function", "tir.noalias": True})
+    A = T.match_buffer(a, [16, 8], dtype="float16")
+    B = T.match_buffer(b, [16, 8], dtype="float16")
+    C = T.match_buffer(c, [16, 8], dtype="float16")
+    metadata = T.match_buffer(_metadata, [8], dtype="uint32")
+    brow = T.env_thread("blockIdx.y")
+    bcol = T.env_thread("blockIdx.x")
+    tx = T.env_thread("threadIdx.x")
+    T.launch_thread(brow, 1)
+    T.launch_thread(bcol, 1)
+    T.launch_thread(tx, 32)
+    multi_a = T.allocate([4], "float16", scope="local")
+    multi_b = T.allocate([4], "float16", scope="local")
+    accum = T.allocate([4], "float16", scope="local")
+    meta_local = T.allocate([1], "uint32", scope="local")
+    for i in range(4):
+        accum[i] = T.float16(0)
+
+    for i in range(4):
+        multi_a[i] = A[tx // 4 + i // 2 * 8, tx % 4 * 2 + i % 2]
+
+    for i in range(4):
+        multi_b[i] = B[tx % 4 * 2 + i % 2 + i // 2 * 8, tx // 4]
+
+    meta_local[0] = metadata[tx // 4]
+
+    T.evaluate(
+        T.ptx_mma_sp(
+            "m16n8k16",
+            "row",
+            "col",
+            "fp16",
+            "fp16",
+            "fp16",
+            multi_a.data,
+            0,
+            multi_b.data,
+            0,
+            accum.data,
+            0,
+            meta_local.data,
+            0,
+            0,
+            False,
+            dtype="float16",
+        )
+    )
+
+    for i in range(4):
+        C[i // 2 * 8 + tx // 4, tx % 4 * 2 + i % 2] = accum[i]
+
+
+@T.prim_func
+def mma_sp_m16n8k16_f16f16f32(a: T.handle, b: T.handle, c: T.handle, _metadata: T.handle):
+    T.func_attr({"global_symbol": "default_function", "tir.noalias": True})
+    A = T.match_buffer(a, [16, 8], dtype="float16")
+    B = T.match_buffer(b, [16, 8], dtype="float16")
+    C = T.match_buffer(c, [16, 8], dtype="float32")
+    metadata = T.match_buffer(_metadata, [8], dtype="uint32")
+    brow = T.env_thread("blockIdx.y")
+    bcol = T.env_thread("blockIdx.x")
+    tx = T.env_thread("threadIdx.x")
+    T.launch_thread(brow, 1)
+    T.launch_thread(bcol, 1)
+    T.launch_thread(tx, 32)
+    multi_a = T.allocate([4], "float16", scope="local")
+    multi_b = T.allocate([4], "float16", scope="local")
+    accum = T.allocate([4], "float32", scope="local")
+    meta_local = T.allocate([1], "uint32", scope="local")
+    for i in range(4):
+        accum[i] = T.float32(0)
+
+    for i in range(4):
+        multi_a[i] = A[tx // 4 + i // 2 * 8, tx % 4 * 2 + i % 2]
+
+    for i in range(4):
+        multi_b[i] = B[tx % 4 * 2 + i % 2 + i // 2 * 8, tx // 4]
+
+    meta_local[0] = metadata[tx // 4]
+
+    T.evaluate(
+        T.ptx_mma_sp(
+            "m16n8k16",
+            "row",
+            "col",
+            "fp16",
+            "fp16",
+            "fp32",
+            multi_a.data,
+            0,
+            multi_b.data,
+            0,
+            accum.data,
+            0,
+            meta_local.data,
+            0,
+            0,
+            False,
+            dtype="float32",
+        )
+    )
+
+    for i in range(4):
+        C[i // 2 * 8 + tx // 4, tx % 4 * 2 + i % 2] = accum[i]
+
+
+@T.prim_func
+def mma_sp_m16n8k32_f16f16f16(a: T.handle, b: T.handle, c: T.handle, _metadata: T.handle):
+    T.func_attr({"global_symbol": "default_function", "tir.noalias": True})
+    A = T.match_buffer(a, [16, 16], dtype="float16")
+    B = T.match_buffer(b, [32, 8], dtype="float16")
+    C = T.match_buffer(c, [16, 8], dtype="float16")
+    metadata = T.match_buffer(_metadata, [16], dtype="uint32")
+    brow = T.env_thread("blockIdx.y")
+    bcol = T.env_thread("blockIdx.x")
+    tx = T.env_thread("threadIdx.x")
+    T.launch_thread(brow, 1)
+    T.launch_thread(bcol, 1)
+    T.launch_thread(tx, 32)
+    multi_a = T.allocate([8], "float16", scope="local")
+    multi_b = T.allocate([8], "float16", scope="local")
+    accum = T.allocate([4], "float16", scope="local")
+    meta_local = T.allocate([1], "uint32", scope="local")
+    for i in range(4):
+        accum[i] = T.float16(0)
+
+    for i in range(8):
+        multi_a[i] = A[(i % 4) // 2 * 8 + tx // 4, i // 4 * 8 + tx % 4 * 2 + i % 2]
+
+    for i in range(8):
+        multi_b[i] = B[i // 2 * 8 + tx % 4 * 2 + i % 2, tx // 4]
+
+    meta_local[0] = metadata[tx // 4 * 2 + tx % 2]
+
+    T.evaluate(
+        T.ptx_mma_sp(
+            "m16n8k32",
+            "row",
+            "col",
+            "fp16",
+            "fp16",
+            "fp16",
+            multi_a.data,
+            0,
+            multi_b.data,
+            0,
+            accum.data,
+            0,
+            meta_local.data,
+            0,
+            0,
+            False,
+            dtype="float16",
+        )
+    )
+
+    for i in range(4):
+        C[i // 2 * 8 + tx // 4, tx % 4 * 2 + i % 2] = accum[i]
+
+
+@T.prim_func
+def mma_sp_m16n8k32_f16f16f32(a: T.handle, b: T.handle, c: T.handle, _metadata: T.handle):
+    T.func_attr({"global_symbol": "default_function", "tir.noalias": True})
+    A = T.match_buffer(a, [16, 16], dtype="float16")
+    B = T.match_buffer(b, [32, 8], dtype="float16")
+    C = T.match_buffer(c, [16, 8], dtype="float32")
+    metadata = T.match_buffer(_metadata, [16], dtype="uint32")
+    brow = T.env_thread("blockIdx.y")
+    bcol = T.env_thread("blockIdx.x")
+    tx = T.env_thread("threadIdx.x")
+    T.launch_thread(brow, 1)
+    T.launch_thread(bcol, 1)
+    T.launch_thread(tx, 32)
+    multi_a = T.allocate([8], "float16", scope="local")
+    multi_b = T.allocate([8], "float16", scope="local")
+    accum = T.allocate([4], "float32", scope="local")
+    meta_local = T.allocate([1], "uint32", scope="local")
+    for i in range(4):
+        accum[i] = T.float32(0)
+
+    for i in range(8):
+        multi_a[i] = A[(i % 4) // 2 * 8 + tx // 4, i // 4 * 8 + tx % 4 * 2 + i % 2]
+
+    for i in range(8):
+        multi_b[i] = B[i // 2 * 8 + tx % 4 * 2 + i % 2, tx // 4]
+
+    meta_local[0] = metadata[tx // 4 * 2 + tx % 2]
+
+    T.evaluate(
+        T.ptx_mma_sp(
+            "m16n8k32",
+            "row",
+            "col",
+            "fp16",
+            "fp16",
+            "fp32",
+            multi_a.data,
+            0,
+            multi_b.data,
+            0,
+            accum.data,
+            0,
+            meta_local.data,
+            0,
+            0,
+            False,
+            dtype="float32",
+        )
+    )
+
+    for i in range(4):
+        C[i // 2 * 8 + tx // 4, tx % 4 * 2 + i % 2] = accum[i]
+
+
+@tvm.testing.requires_cuda
+def test_mma_sp_m16n8k16_f16():
+    def get_meta_m16n8k16_half(mask):
+        assert mask.shape == (16, 4, 2)
+        mask = mask.reshape(16, 8)
+        ret = np.zeros((8,)).astype("uint32")
+
+        for i in range(8):
+            base = 1
+            for blk in range(2):
+                for j in range(8):
+                    ret[i] |= int(mask[blk * 8 + i, j]) * base
+                    base = base << 2
+        return ret
+
+    for out_dtype in ["float16", "float32"]:
+        func = mma_sp_m16n8k16_f16f16f16 if out_dtype == "float16" else mma_sp_m16n8k16_f16f16f32
+        sch = tvm.tir.Schedule(func)
+        arch = tvm.contrib.nvcc.get_target_compute_version()
+        major, _ = tvm.contrib.nvcc.parse_compute_version(arch)
+        if major < 8:
+            # Requires SM80+
+            return
+        cuda_mod = tvm.build(sch.mod, target="cuda")
+
+        A_np = np.random.uniform(-1, 1, [16, 8]).astype("float16")
+        B_np = np.random.uniform(-1, 1, [16, 8]).astype("float16")
+        mask = gen_2in4_mask(16, 16)
+        A_dense_np = get_dense_mat_by_mask(A_np, mask)
+        C_np = np.matmul(A_dense_np, B_np).astype(out_dtype)
+        meta = get_meta_m16n8k16_half(mask)
+
+        ctx = tvm.cuda()
+        A_tvm = tvm.nd.array(A_np, ctx)
+        B_tvm = tvm.nd.array(B_np, ctx)
+        C_tvm = tvm.nd.array(np.zeros_like(C_np), ctx)
+        meta_tvm = tvm.nd.array(meta, ctx)
+        cuda_mod(A_tvm, B_tvm, C_tvm, meta_tvm)
+
+        tvm.testing.assert_allclose(C_tvm.numpy(), C_np, atol=1e-3, rtol=1e-3)
+
+
+@tvm.testing.requires_cuda
+def test_mma_sp_m16n8k32_f16():
+    def get_meta_m16n8k32_half(mask):
+        assert mask.shape == (16, 8, 2)
+        mask = mask.reshape(16, 2, 8)
+        ret = np.zeros((8, 2)).astype("uint32")
+
+        for i in range(8):
+            for k in range(2):
+                base = 1
+                for blk in range(2):
+                    for j in range(8):
+                        ret[i, k] |= int(mask[blk * 8 + i, k, j]) * base
+                        base = base << 2
+
+        return ret.reshape(16)
+
+    for out_dtype in ["float16", "float32"]:
+        func = mma_sp_m16n8k32_f16f16f16 if out_dtype == "float16" else mma_sp_m16n8k32_f16f16f32
+        sch = tvm.tir.Schedule(func)
+        arch = tvm.contrib.nvcc.get_target_compute_version()
+        major, _ = tvm.contrib.nvcc.parse_compute_version(arch)
+        if major < 8:
+            # Requires SM80+
+            return
+        cuda_mod = tvm.build(sch.mod, target="cuda")
+
+        A_np = np.random.uniform(-1, 1, [16, 16]).astype("float16")
+        B_np = np.random.uniform(-1, 1, [32, 8]).astype("float16")
+        mask = gen_2in4_mask(16, 32)
+        A_dense_np = get_dense_mat_by_mask(A_np, mask)
+        C_np = np.matmul(A_dense_np, B_np).astype(out_dtype)
+        meta = get_meta_m16n8k32_half(mask)
+
+        ctx = tvm.cuda()
+        A_tvm = tvm.nd.array(A_np, ctx)
+        B_tvm = tvm.nd.array(B_np, ctx)
+        C_tvm = tvm.nd.array(np.zeros_like(C_np), ctx)
+        meta_tvm = tvm.nd.array(meta, ctx)
+        cuda_mod(A_tvm, B_tvm, C_tvm, meta_tvm)
+
+        tvm.testing.assert_allclose(C_tvm.numpy(), C_np, atol=1e-3, rtol=1e-3)
+
+
+if __name__ == "__main__":
+    test_mma_sp_m16n8k16_f16()
+    test_mma_sp_m16n8k32_f16()
