This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new c7a01a4  [CUTLASS] Support batch_matmul (#9439)
c7a01a4 is described below

commit c7a01a4d457a2d1b1939dad82b572fb738ff4754
Author: masahi <[email protected]>
AuthorDate: Thu Nov 4 15:59:05 2021 +0900

    [CUTLASS] Support batch_matmul (#9439)
    
    * Import batched gemm change
    
    commit cfacfa296e2487a189e52d189567b140c675ccc2
    Author: Masahiro Masuda <[email protected]>
    Date:   Mon Nov 1 15:57:49 2021 +0900
    
        change is_constant pattern to wildcard in gelu pattern
    
    commit 84da94306ca81209a8ccc44fd7d606cbce047082
    Author: Masahiro Masuda <[email protected]>
    Date:   Mon Nov 1 05:41:11 2021 +0900
    
        fixed batch stride C
    
    commit 66e5779ee69dc0cd3969f268608b551ec549d79b
    Author: Masahiro Masuda <[email protected]>
    Date:   Sun Oct 31 20:47:16 2021 +0900
    
        refactoring codegen
    
    commit 561daeafa66cddf6a565b537072a5efce0b0dbf1
    Author: Masahiro Masuda <[email protected]>
    Date:   Sun Oct 31 20:05:20 2021 +0900
    
        generated kernel compiled and result match
    
    commit a5740bcf5287097b64dff8adb50f0cddc2c41349
    Author: Masahiro Masuda <[email protected]>
    Date:   Sun Oct 31 19:36:53 2021 +0900
    
        partitioning looks good
    
    commit 59112fdf78a4541905fad9b899737600e0ed9391
    Author: Masahiro Masuda <[email protected]>
    Date:   Sun Oct 31 19:01:47 2021 +0900
    
        [WIP] cutlass batch matmul support
    
    * fixed test
    
    * refactoring
    
    * gelu test fixed
    
    * more refactor
    
    * batch_matmul fp32 accum working
    
    * dynamic batch matmul working
    
    * black
    
    * remove doc TODO
---
 python/tvm/contrib/cutlass/build.py          | 141 +++++++++++++++++-----
 python/tvm/contrib/cutlass/gemm_operation.py |  15 +--
 python/tvm/contrib/cutlass/gen_gemm.py       |  42 ++++---
 python/tvm/contrib/cutlass/library.py        |   2 +
 python/tvm/relay/op/contrib/cutlass.py       |  11 +-
 src/relay/backend/contrib/cutlass/codegen.cc | 167 +++++++++++++++++++--------
 tests/python/contrib/test_cutlass.py         | 109 ++++++++++++++---
 7 files changed, 363 insertions(+), 124 deletions(-)

diff --git a/python/tvm/contrib/cutlass/build.py 
b/python/tvm/contrib/cutlass/build.py
index 58e7a11..615b900 100644
--- a/python/tvm/contrib/cutlass/build.py
+++ b/python/tvm/contrib/cutlass/build.py
@@ -83,6 +83,87 @@ class GemmAnnotator(tvm.relay.ExprVisitor):
             self.signature["ret_dtype"] = op.ret_type.dtype
 
 
+def select_gemm_kernel(
+    cutlass_profiler, MM, KK, NN, out_dtype, batched, profile_all, 
use_multiprocessing
+):
+    """Run CUTLASS profiler to select the best kernel, or return the default 
one for dynamic
+    workloads."""
+    if any(isinstance(s, tvm.tir.Any) for s in [MM, KK, NN]):
+        out = cutlass_profiler.get_default(out_dtype, batched=batched)
+        logger.info("Picked the default kernel %s", out["name"])
+    else:
+        out = cutlass_profiler.profile(
+            MM,
+            NN,
+            KK,
+            out_dtype,
+            batched=batched,
+            profile_all=profile_all,
+            use_multiprocessing=use_multiprocessing,
+        )
+        if profile_all:
+            logger.info("The best kernel is %s", out["name"])
+        else:
+            logger.info("Picked the first kernel found %s", out["name"])
+    return out
+
+
+def handle_batch_matmul(
+    cutlass_profiler, op_type, arg0_shape, arg1_shape, out_dtype, profile_all, 
use_multiprocessing
+):
+    """Profile and select a kernel for batch_matmul op workload."""
+    MM = arg0_shape[1]
+    KK = arg0_shape[2]
+    NN = arg1_shape[1]
+
+    out = select_gemm_kernel(
+        cutlass_profiler, MM, KK, NN, out_dtype, True, profile_all, 
use_multiprocessing
+    )
+
+    if op_type == "cutlass.batch_matmul":
+        cutlass_op_def = out["opdef"]
+    else:
+        raise ValueError("%s pattern is not implemented." % op_type)
+
+    return {
+        "batch": arg0_shape[0],
+        "batch_stride_A": arg0_shape[1] * arg0_shape[2],
+        "batch_stride_B": arg1_shape[1] * arg1_shape[2],
+        "batch_stride_C": arg0_shape[1] * arg1_shape[1],
+        "cutlass_op_def": cutlass_op_def,
+        "cutlass_op_name": out["name"],
+    }
+
+
+def handle_dense(
+    cutlass_profiler, op_type, arg0_shape, arg1_shape, out_dtype, profile_all, 
use_multiprocessing
+):
+    """Profile and select a kernel for dense op workload."""
+    MM = arg0_shape[0]
+    KK = arg0_shape[1]
+    NN = arg1_shape[0]
+
+    out = select_gemm_kernel(
+        cutlass_profiler, MM, KK, NN, out_dtype, False, profile_all, 
use_multiprocessing
+    )
+
+    if op_type == "cutlass.dense":
+        cutlass_op_def = out["opdef"]
+    elif op_type == "cutlass.dense_bias":
+        cutlass_op_def = out["opdef_bias"]
+    elif op_type == "cutlass.dense_bias_relu":
+        cutlass_op_def = out["opdef_bias_relu"]
+    elif "cutlass.dense_bias_gelu" in op_type:
+        cutlass_op_def = out["opdef_bias_gelu"]
+    else:
+        raise ValueError("%s pattern is not implemented." % op_type)
+
+    return {
+        "cutlass_op_def": cutlass_op_def,
+        "cutlass_op_name": out["name"],
+    }
+
+
 def tune_cutlass_kernels(mod, sm, profile_all=True, use_multiprocessing=False, 
tmp_dir="./tmp"):
     """Given a module partitioned for CUTLASS offloading, profile each 
workload to select which
     kernels to emit.
@@ -123,41 +204,41 @@ def tune_cutlass_kernels(mod, sm, profile_all=True, 
use_multiprocessing=False, t
         if "cutlass" in fun_name:
             num_cutlass_partition += 1
             annotator.visit(func)
-            # call cutlass profiler to find best settings, update attr
-            new_attrs = {}
+            out_dtype = annotator.signature["ret_dtype"]
+            op_type = annotator.signature["op_type"]
+
+            new_attrs = {"op_type": op_type}
             new_attrs.update(annotator.signature)
-            for key in func.attrs.keys():
-                new_attrs[key] = func.attrs[key]
-            # call profiler
+            new_attrs.update(func.attrs)
             arg0_shape = new_attrs["arg0_shape"]
             arg1_shape = new_attrs["arg1_shape"]
-            MM = arg0_shape[0]
-            KK = arg0_shape[1]
-            NN = arg1_shape[0]
-            out_dtype = annotator.signature["ret_dtype"]
-            if any(isinstance(s, tvm.tir.Any) for s in [MM, KK, NN]):
-                out = cutlass_profiler.get_default(out_dtype)
-                logger.info("Picked the default kernel %s", out["name"])
-            else:
-                out = cutlass_profiler.profile(
-                    MM, NN, KK, out_dtype, profile_all, use_multiprocessing
+
+            if "batch_matmul" in op_type:
+                new_attrs.update(
+                    handle_batch_matmul(
+                        cutlass_profiler,
+                        op_type,
+                        arg0_shape,
+                        arg1_shape,
+                        out_dtype,
+                        profile_all,
+                        use_multiprocessing,
+                    )
+                )
+            elif "dense" in op_type:
+                new_attrs.update(
+                    handle_dense(
+                        cutlass_profiler,
+                        op_type,
+                        arg0_shape,
+                        arg1_shape,
+                        out_dtype,
+                        profile_all,
+                        use_multiprocessing,
+                    )
                 )
-                if profile_all:
-                    logger.info("The best kernel is %s", out["name"])
-                else:
-                    logger.info("Picked the first kernel found %s", 
out["name"])
-
-            if new_attrs["op_type"] == "cutlass.dense":
-                new_attrs["cutlass_op_def"] = out["opdef"]
-            elif new_attrs["op_type"] == "cutlass.dense_bias":
-                new_attrs["cutlass_op_def"] = out["opdef_bias"]
-            elif new_attrs["op_type"] == "cutlass.dense_bias_relu":
-                new_attrs["cutlass_op_def"] = out["opdef_bias_relu"]
-            elif "cutlass.dense_bias_gelu" in new_attrs["op_type"]:
-                new_attrs["cutlass_op_def"] = out["opdef_bias_gelu"]
             else:
-                raise ValueError("%s pattern is not implemented." % 
new_attrs["op_type"])
-            new_attrs["cutlass_op_name"] = out["name"]
+                raise ValueError("%s unsupported composite" % op_type)
 
             if new_attrs["cutlass_op_name"].find("_tn_align") > 0:
                 new_attrs["lda"] = "K"
diff --git a/python/tvm/contrib/cutlass/gemm_operation.py 
b/python/tvm/contrib/cutlass/gemm_operation.py
index e53b3ee..4673b4b 100644
--- a/python/tvm/contrib/cutlass/gemm_operation.py
+++ b/python/tvm/contrib/cutlass/gemm_operation.py
@@ -174,7 +174,7 @@ class EmitGemmInstance:
     >"""
         self.gemm_template = """
   // Gemm operator ${operation_name}
-  using Operation_${operation_name} = cutlass::gemm::device::Gemm<
+  using Operation_${operation_name} = cutlass::gemm::device::${kernel_name}<
     ${element_a}, ${layout_a},
     ${element_b}, ${layout_b},
     ${element_c}, ${layout_c},
@@ -189,13 +189,12 @@ class EmitGemmInstance:
     ${stages},
     ${align_a},
     ${align_b},
-    false,
+    ${split_k_serial}
     ${math_operation}
-    ${residual}
   >;
 """
 
-    def emit(self, operation, no_beta_scaling=False):
+    def emit(self, operation, no_beta_scaling=False, batched=False):
         """Instantiate a GEMM kernel from given `operation`."""
         warp_shape = [
             operation.tile_description.threadblock_shape[idx]
@@ -206,8 +205,6 @@ class EmitGemmInstance:
             min(operation.C.alignment * DataTypeSize[operation.C.element], 128)
             // DataTypeSize[operation.C.element]
         )
-        residual = ""
-        complex_transform_tag = "cutlass::ComplexTransform::kNone"
         values = {
             "operation_name": operation.procedural_name(),
             "element_a": DataTypeTag[operation.A.element],
@@ -243,14 +240,14 @@ class EmitGemmInstance:
             "stages": str(operation.tile_description.stages),
             "align_a": str(operation.A.alignment),
             "align_b": str(operation.B.alignment),
-            "transform_a": complex_transform_tag,
-            "transform_b": complex_transform_tag,
             "math_operation": MathOperationTag[
                 operation.tile_description.math_instruction.math_operation
             ],
-            "residual": residual,
         }
 
+        values["kernel_name"] = "GemmBatched" if batched else "Gemm"
+        values["split_k_serial"] = "" if batched else "false,"
+
         gemm_template = substitute_template(
             self.gemm_template,
             {
diff --git a/python/tvm/contrib/cutlass/gen_gemm.py 
b/python/tvm/contrib/cutlass/gen_gemm.py
index a43c6d4..1ed4bfe 100644
--- a/python/tvm/contrib/cutlass/gen_gemm.py
+++ b/python/tvm/contrib/cutlass/gen_gemm.py
@@ -47,6 +47,7 @@ def create_gemm_operator(
     alignment_constraints,
     epilogue_functor=EpilogueFunctor.LinearCombination,
     swizzling_functor=SwizzlingFunctor.Identity8,
+    batched=False,
 ):
     """Exhaustively instantiate all kernels from a given configuration."""
     ret = []
@@ -55,6 +56,9 @@ def create_gemm_operator(
 
     element_a, element_b, element_c, element_epilogue = data_type
 
+    if batched:
+        swizzling_functor = SwizzlingFunctor.Batched
+
     for layout in layouts:
         for tile_description in tile_descriptions:
             for alignment in alignment_constraints:
@@ -109,15 +113,17 @@ def create_gemm_operator(
                 kernel_emitter = EmitGemmInstance()
                 op_entry["op"] = op
                 op_entry["name"] = op.procedural_name()
-                op_entry["opdef"] = kernel_emitter.emit(op)
-                op_entry["opdef_bias"] = kernel_emitter.emit(op_bias, 
no_beta_scaling=True)
+                op_entry["opdef"] = kernel_emitter.emit(op, batched=batched)
+                op_entry["opdef_bias"] = kernel_emitter.emit(
+                    op_bias, no_beta_scaling=True, batched=batched
+                )
                 op_entry["opdef_bias_relu"] = kernel_emitter.emit(
-                    op_bias_relu, no_beta_scaling=True
+                    op_bias_relu, no_beta_scaling=True, batched=batched
                 )
-                op_entry["opdef_bias_gelu"] = kernel_emitter.emit(op_bias_gelu)
+                op_entry["opdef_bias_gelu"] = 
kernel_emitter.emit(op_bias_gelu, batched=batched)
                 op_entry["src"] = profiler_emitter.emit(
                     op.procedural_name(),
-                    op_entry["opdef"],
+                    kernel_emitter.emit(op, batched=False),
                     DataTypeTag[element_a],
                     DataTypeTag[element_b],
                     DataTypeTag[element_c],
@@ -128,7 +134,9 @@ def create_gemm_operator(
     return ret
 
 
-def generate_tensor_op_common(math_instructions, alignment_constraints, 
get_tile_descriptions):
+def generate_tensor_op_common(
+    math_instructions, alignment_constraints, get_tile_descriptions, 
batched=False
+):
+    """Common kernel generator to be used by architecture specific 
generators."""
     ops = []
     layouts = [
@@ -143,14 +151,16 @@ def generate_tensor_op_common(math_instructions, 
alignment_constraints, get_tile
             math_inst.element_accumulator,
         ]
 
-        out = create_gemm_operator(layouts, tile_descriptions, data_type, 
alignment_constraints)
+        out = create_gemm_operator(
+            layouts, tile_descriptions, data_type, alignment_constraints, 
batched=batched
+        )
 
         ops.extend(out)
 
     return ops
 
 
-def generate_sm75_tensor_op_1688(out_dtype):
+def generate_sm75_tensor_op_1688(out_dtype, batched=False):
     """Generate GEMM kernels for Turing."""
     assert out_dtype in ["float32", "float16"]
     math_instructions = {
@@ -192,11 +202,11 @@ def generate_sm75_tensor_op_1688(out_dtype):
         ]
 
     return generate_tensor_op_common(
-        math_instructions, alignment_constraints, get_tile_descriptions
+        math_instructions, alignment_constraints, get_tile_descriptions, 
batched
     )
 
 
-def generate_sm80_tensor_op_16816(out_dtype):
+def generate_sm80_tensor_op_16816(out_dtype, batched=False):
     """Generate GEMM kernels for Ampere."""
     assert out_dtype in ["float32", "float16"]
     math_instructions = {
@@ -250,7 +260,7 @@ def generate_sm80_tensor_op_16816(out_dtype):
         ]
 
     return generate_tensor_op_common(
-        math_instructions, alignment_constraints, get_tile_descriptions
+        math_instructions, alignment_constraints, get_tile_descriptions, 
batched
     )
 
 
@@ -350,17 +360,19 @@ class CutlassGemmProfiler(object):
             return False
         return True
 
-    def get_default(self, out_dtype):
+    def get_default(self, out_dtype, batched=False):
         """Return the default kernel for the requested architecture.
         For now, the default kernel was picked arbitrarily.
         """
-        ops = GENERATOR_FUNC_TABLE[self.sm](out_dtype)
+        ops = GENERATOR_FUNC_TABLE[self.sm](out_dtype, batched)
         default_kernel_name = DEFAULT_KERNELS[self.sm][out_dtype]
         filtered = list(filter(lambda op: op["name"] == default_kernel_name, 
ops))
         assert len(filtered) == 1
         return filtered[0]
 
-    def profile(self, M, N, K, out_dtype, profile_all=True, 
use_multiprocessing=False):
+    def profile(
+        self, M, N, K, out_dtype, profile_all=True, use_multiprocessing=False, 
batched=False
+    ):
         """Profile and select the best kernel from candidate kernels.
         If profile_all is False, return immediately after the first applicable 
kernel is found.
         If use_multiprocessing is True, compile all profiler executables in 
parallel.
@@ -368,7 +380,7 @@ class CutlassGemmProfiler(object):
         if (M, N, K) in self.cache:
             return self.cache[(M, N, K)]
 
-        ops = GENERATOR_FUNC_TABLE[self.sm](out_dtype)
+        ops = GENERATOR_FUNC_TABLE[self.sm](out_dtype, batched)
         ops = list(filter(lambda op: self.check_align(op["name"], M), ops))
 
         for op in ops:
diff --git a/python/tvm/contrib/cutlass/library.py 
b/python/tvm/contrib/cutlass/library.py
index 7d54429..a3b90ff 100644
--- a/python/tvm/contrib/cutlass/library.py
+++ b/python/tvm/contrib/cutlass/library.py
@@ -160,6 +160,7 @@ class SwizzlingFunctor(enum.Enum):
     Identity2 = enum_auto()
     Identity4 = enum_auto()
     Identity8 = enum_auto()
+    Batched = enum_auto()
 
 
 SwizzlingFunctorTag = {
@@ -167,6 +168,7 @@ SwizzlingFunctorTag = {
     SwizzlingFunctor.Identity2: 
"cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<2>",
     SwizzlingFunctor.Identity4: 
"cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<4>",
     SwizzlingFunctor.Identity8: 
"cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>",
+    SwizzlingFunctor.Batched: 
"cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle",
 }
 
 
diff --git a/python/tvm/relay/op/contrib/cutlass.py 
b/python/tvm/relay/op/contrib/cutlass.py
index 631089c..8ed3718 100644
--- a/python/tvm/relay/op/contrib/cutlass.py
+++ b/python/tvm/relay/op/contrib/cutlass.py
@@ -20,13 +20,13 @@ from ...dataflow_pattern import wildcard, is_op, is_constant
 
 
 def make_gelu_pattern(bias_out, out_dtype="float16"):
-    mul = is_op("multiply")(bias_out, is_constant())
+    mul = is_op("multiply")(bias_out, is_constant() | wildcard())
     if out_dtype == "float16":
         erf = is_op("cast")(is_op("erf")(is_op("cast")(mul)))
     else:
         erf = is_op("erf")(mul)
-    mul_half = is_op("multiply")(erf, is_constant())
-    add = is_op("add")(mul_half, is_constant())
+    mul_half = is_op("multiply")(erf, is_constant() | wildcard())
+    add = is_op("add")(mul_half, is_constant() | wildcard())
     return is_op("multiply")(add, bias_out)
 
 
@@ -51,6 +51,10 @@ def make_gemm_pattern(with_bias=True, with_act=None, 
out_dtype="float16"):
     return make_gelu_pattern(gemm_out, out_dtype)
 
 
+def make_batch_matmul_pattern():
+    return is_op("nn.batch_matmul")(wildcard(), wildcard())
+
+
 def partition_for_cutlass(mod):
     """Partition the input module into CUTLASS-supported subgraphs."""
     dense_pat = ("cutlass.dense", make_gemm_pattern(False, None))
@@ -67,6 +71,7 @@ def partition_for_cutlass(mod):
         dense_bias_relu_pat,
         dense_bias_pat,
         dense_pat,
+        ("cutlass.batch_matmul", make_batch_matmul_pattern()),
     ]
     mod = transform.MergeComposite(cutlass_patterns)(mod)
     mod = transform.AnnotateTarget(["cutlass"])(mod)
diff --git a/src/relay/backend/contrib/cutlass/codegen.cc 
b/src/relay/backend/contrib/cutlass/codegen.cc
index c1217a0..f154f86 100644
--- a/src/relay/backend/contrib/cutlass/codegen.cc
+++ b/src/relay/backend/contrib/cutlass/codegen.cc
@@ -54,19 +54,21 @@ std::string GetDimAsStr(ObjectRef dim) {
   return kAnyDim;
 }
 
-Str2StrMap DenseArgs(const Map<String, ObjectRef>& attrs) {
+inline void CutlassPrint(std::ostringstream& os, const std::string& stmt, int 
indent = 2) {
+  for (int i = 0; i < indent; ++i) {
+    os << " ";
+  }
+  os << stmt;
+}
+
+Str2StrMap GemmArgsCommon(const Map<String, ObjectRef>& attrs) {
   Str2StrMap args;
   auto arg0_dtype = std::string(attrs["arg0_dtype"].as<StringObj>()->data);
   auto arg1_dtype = std::string(attrs["arg1_dtype"].as<StringObj>()->data);
   auto ret_dtype = std::string(attrs["ret_dtype"].as<StringObj>()->data);
-  auto arg0_shape = attrs["arg0_shape"].as<ArrayNode>();
-  auto arg1_shape = attrs["arg1_shape"].as<ArrayNode>();
   args["ElementInputA"] = dtype_map.at(arg0_dtype);
   args["ElementInputB"] = dtype_map.at(arg1_dtype);
   args["ElementOutput"] = dtype_map.at(ret_dtype);
-  args["M"] = GetDimAsStr(arg0_shape->at(0));
-  args["K"] = GetDimAsStr(arg0_shape->at(1));
-  args["N"] = GetDimAsStr(arg1_shape->at(0));
   args["op_def"] = std::string(attrs["cutlass_op_def"].as<StringObj>()->data);
   args["op_name"] = 
std::string(attrs["cutlass_op_name"].as<StringObj>()->data);
   args["op_type"] = std::string(attrs["op_type"].as<StringObj>()->data);
@@ -76,23 +78,33 @@ Str2StrMap DenseArgs(const Map<String, ObjectRef>& attrs) {
   return args;
 }
 
-inline void CutlassPrint(std::ostringstream& os, const std::string& stmt, int 
indent = 2) {
-  for (int i = 0; i < indent; ++i) {
-    os << " ";
-  }
-  os << stmt;
+Str2StrMap DenseArgs(const Map<String, ObjectRef>& attrs) {
+  Str2StrMap args = GemmArgsCommon(attrs);
+  auto arg0_shape = attrs["arg0_shape"].as<ArrayNode>();
+  auto arg1_shape = attrs["arg1_shape"].as<ArrayNode>();
+  args["M"] = GetDimAsStr(arg0_shape->at(0));
+  args["K"] = GetDimAsStr(arg0_shape->at(1));
+  args["N"] = GetDimAsStr(arg1_shape->at(0));
+  return args;
 }
 
-std::string DenseOp(std::string id, const Str2StrMap& attrs,
-                    const std::vector<std::string>& func_args) {
-  bool has_bias = false;
-  bool is_gelu =
-      attrs.at("op_type").find("cutlass.dense_bias_gelu") != 
std::string::npos;  // fp32 or fp16
-  if (attrs.at("op_type") == "cutlass.dense_bias" ||
-      attrs.at("op_type") == "cutlass.dense_bias_relu" || is_gelu) {
-    has_bias = true;
-  }
-  std::ostringstream gemm_decl;
+Str2StrMap BatchMatmulArgs(const Map<String, ObjectRef>& attrs) {
+  Str2StrMap args = GemmArgsCommon(attrs);
+  args["batch"] = GetDimAsStr(attrs["batch"]);
+  args["batch_stride_A"] = GetDimAsStr(attrs["batch_stride_A"]);
+  args["batch_stride_B"] = GetDimAsStr(attrs["batch_stride_B"]);
+  args["batch_stride_C"] = GetDimAsStr(attrs["batch_stride_C"]);
+  auto arg0_shape = attrs["arg0_shape"].as<ArrayNode>();
+  auto arg1_shape = attrs["arg1_shape"].as<ArrayNode>();
+  args["M"] = GetDimAsStr(arg0_shape->at(1));
+  args["K"] = GetDimAsStr(arg0_shape->at(2));
+  args["N"] = GetDimAsStr(arg1_shape->at(1));
+  return args;
+}
+
+void AppendPrologue(std::ostringstream& gemm_decl, const Str2StrMap& attrs,
+                    const std::vector<std::string>& func_args, const 
std::string& kernel,
+                    bool has_bias, bool is_gelu, int m_axis_idx, int 
n_axis_idx, int k_axis_idx) {
   CutlassPrint(gemm_decl, "using ElementInputA = " + attrs.at("ElementInputA") 
+ ";\n");
   CutlassPrint(gemm_decl, "using ElementInputB = " + attrs.at("ElementInputB") 
+ ";\n");
   CutlassPrint(gemm_decl, "using ElementOutput = " + attrs.at("ElementOutput") 
+ ";\n");
@@ -107,11 +119,10 @@ std::string DenseOp(std::string id, const Str2StrMap& 
attrs,
       return attrs.at(axis);
     }
   };
-  CutlassPrint(gemm_decl, "int M = " + get_dim("M", 0, 0) + ";\n");
-  CutlassPrint(gemm_decl, "int N = " + get_dim("N", 1, 0) + ";\n");
-  CutlassPrint(gemm_decl, "int K = " + get_dim("K", 0, 1) + ";\n");
+  CutlassPrint(gemm_decl, "int M = " + get_dim("M", 0, m_axis_idx) + ";\n");
+  CutlassPrint(gemm_decl, "int N = " + get_dim("N", 1, n_axis_idx) + ";\n");
+  CutlassPrint(gemm_decl, "int K = " + get_dim("K", 0, k_axis_idx) + ";\n");
   CutlassPrint(gemm_decl, "cutlass::gemm::GemmCoord problem_size(M, N, K);\n");
-  // Initialize alpha for dot product computation
   CutlassPrint(gemm_decl, "ElementComputeEpilogue alpha = 
ElementComputeEpilogue(1);\n");
   if (is_gelu) {
     // GeLU epilogue does not compile with NoBetaScaling, so we explicitly 
specify the scale.
@@ -120,11 +131,6 @@ std::string DenseOp(std::string id, const Str2StrMap& 
attrs,
     CutlassPrint(gemm_decl, "ElementComputeEpilogue beta = 
ElementComputeEpilogue(0);\n");
   }
 
-  // Split K dimension into 1 partitions
-  CutlassPrint(gemm_decl, "int split_k_slices = 1;\n");
-
-  // Create a tuple of gemm kernel arguments. This is later passed as 
arguments to launch
-  // instantiated CUTLASS kernel
   ICHECK(func_args.size() >= 2);
   CutlassPrint(gemm_decl, "void* ptr_a = (void*)(" + func_args[0] + 
"->data);\n");
   CutlassPrint(gemm_decl, "void* ptr_b = (void*)(" + func_args[1] + 
"->data);\n");
@@ -132,33 +138,24 @@ std::string DenseOp(std::string id, const Str2StrMap& 
attrs,
     ICHECK(func_args.size() >= 3);
     CutlassPrint(gemm_decl, "void* ptr_c_bias = (void*)(" + func_args[2] + 
"->data);\n");
   }
+
   CutlassPrint(gemm_decl, "void* ptr_out = (void*)(out0);\n");
 
-  CutlassPrint(gemm_decl, "typename Gemm::Arguments arguments{\n");
+  CutlassPrint(gemm_decl, "using " + kernel + " = Operation_" + 
attrs.at("op_name") + ";\n");
+  CutlassPrint(gemm_decl, "typename " + kernel + "::Arguments arguments{\n");
   CutlassPrint(gemm_decl, " problem_size,\n");
-  CutlassPrint(gemm_decl, " {static_cast<ElementInputA*>(ptr_a), " + 
attrs.at("lda") + "},\n");
-  CutlassPrint(gemm_decl, " {static_cast<ElementInputB*>(ptr_b), " + 
attrs.at("ldb") + "},\n");
-  if (has_bias) {
-    CutlassPrint(gemm_decl, " {static_cast<ElementOutput*>(ptr_c_bias), 
0},\n");
-  } else {
-    CutlassPrint(gemm_decl, " {static_cast<ElementOutput*>(ptr_out), " + 
attrs.at("ldc") + "},\n");
-  }
-  CutlassPrint(gemm_decl, " {static_cast<ElementOutput*>(ptr_out), " + 
attrs.at("ldc") + "},\n");
-  if (has_bias && !is_gelu) {
-    CutlassPrint(gemm_decl, " {alpha},\n");
-  } else {
-    // For GeLU, we explicitly specify the scale.
-    CutlassPrint(gemm_decl, " {alpha, beta},\n");
-  }
-  CutlassPrint(gemm_decl, " split_k_slices};\n");
+}
 
+void AppendGemmExecute(std::ostringstream& gemm_decl, const std::string& 
kernel) {
   // Using the arguments, query for extra workspace required for matrix 
multiplication computation
-  CutlassPrint(gemm_decl, "size_t workspace_size = 
Gemm::get_workspace_size(arguments);\n");
+  CutlassPrint(gemm_decl,
+               "size_t workspace_size = " + kernel + 
"::get_workspace_size(arguments);\n");
   // Allocate workspace memory
   CutlassPrint(gemm_decl,
                "cutlass::device_memory::allocation<uint8_t> 
workspace(workspace_size);\n");
   // Instantiate CUTLASS kernel depending on template
-  CutlassPrint(gemm_decl, "Gemm gemm_op;\n");
+  CutlassPrint(gemm_decl, kernel + " gemm_op;\n");
+
   // Check the problem size is supported or not
   CutlassPrint(gemm_decl, "cutlass::Status status = 
gemm_op.can_implement(arguments);\n");
   CutlassPrint(gemm_decl, "CHECK(status == cutlass::Status::kSuccess);\n");
@@ -168,6 +165,72 @@ std::string DenseOp(std::string id, const Str2StrMap& 
attrs,
   // Launch initialized CUTLASS kernel
   CutlassPrint(gemm_decl, "status = gemm_op();\n");
   CutlassPrint(gemm_decl, "CHECK(status == cutlass::Status::kSuccess);\n");
+}
+
+std::string DenseOp(std::string id, const Str2StrMap& attrs,
+                    const std::vector<std::string>& func_args) {
+  bool has_bias = false;
+  bool is_gelu =
+      attrs.at("op_type").find("cutlass.dense_bias_gelu") != 
std::string::npos;  // fp32 or fp16
+  if (attrs.at("op_type") == "cutlass.dense_bias" ||
+      attrs.at("op_type") == "cutlass.dense_bias_relu" || is_gelu) {
+    has_bias = true;
+  }
+  std::ostringstream gemm_decl;
+  AppendPrologue(gemm_decl, attrs, func_args, "Gemm", has_bias, is_gelu, 0, 0, 
1);
+
+  CutlassPrint(gemm_decl, " {static_cast<ElementInputA*>(ptr_a), " + 
attrs.at("lda") + "},\n");
+  CutlassPrint(gemm_decl, " {static_cast<ElementInputB*>(ptr_b), " + 
attrs.at("ldb") + "},\n");
+  if (has_bias) {
+    CutlassPrint(gemm_decl, " {static_cast<ElementOutput*>(ptr_c_bias), 
0},\n");
+  } else {
+    CutlassPrint(gemm_decl, " {static_cast<ElementOutput*>(ptr_out), " + 
attrs.at("ldc") + "},\n");
+  }
+  CutlassPrint(gemm_decl, " {static_cast<ElementOutput*>(ptr_out), " + 
attrs.at("ldc") + "},\n");
+  if (has_bias && !is_gelu) {
+    CutlassPrint(gemm_decl, " {alpha},\n");
+  } else {
+    // For GeLU, we explicitly specify the scale.
+    CutlassPrint(gemm_decl, " {alpha, beta},\n");
+  }
+  CutlassPrint(gemm_decl, " 1};\n");  // split_k_slices
+
+  AppendGemmExecute(gemm_decl, "Gemm");
+  return gemm_decl.str();
+}
+
+std::string BatchMatmulOp(std::string id, const Str2StrMap& attrs,
+                          const std::vector<std::string>& func_args) {
+  std::ostringstream gemm_decl;
+  AppendPrologue(gemm_decl, attrs, func_args, "BatchedGemm", false, false, 1, 
1, 2);
+
+  auto get_batch_stride = [&attrs, &func_args](const std::string& name, int 
arg0_idx, int arg1_idx,
+                                               int arg0_axis_idx, int 
arg1_axis_idx) {
+    if (attrs.at(name) == kAnyDim) {
+      return func_args[arg0_idx] + "->shape[" + std::to_string(arg0_axis_idx) 
+ "] * " +
+             func_args[arg1_idx] + "->shape[" + std::to_string(arg1_axis_idx) 
+ "]";
+    } else {
+      return attrs.at(name);
+    }
+  };
+
+  CutlassPrint(gemm_decl, " {static_cast<ElementInputA*>(ptr_a), " + 
attrs.at("lda") + "},\n");
+  CutlassPrint(gemm_decl, get_batch_stride("batch_stride_A", 0, 0, 1, 2) + 
",\n");
+  CutlassPrint(gemm_decl, " {static_cast<ElementInputB*>(ptr_b), " + 
attrs.at("ldb") + "},\n");
+  CutlassPrint(gemm_decl, get_batch_stride("batch_stride_B", 1, 1, 1, 2) + 
",\n");
+  CutlassPrint(gemm_decl, " {static_cast<ElementOutput*>(ptr_out), " + 
attrs.at("ldc") + "},\n");
+  CutlassPrint(gemm_decl, get_batch_stride("batch_stride_C", 0, 1, 1, 1) + 
",\n");
+  CutlassPrint(gemm_decl, " {static_cast<ElementOutput*>(ptr_out), " + 
attrs.at("ldc") + "},\n");
+  CutlassPrint(gemm_decl, get_batch_stride("batch_stride_C", 0, 1, 1, 1) + 
",\n");
+  CutlassPrint(gemm_decl, " {alpha, beta},\n");
+
+  if (attrs.at("batch") == kAnyDim) {
+    CutlassPrint(gemm_decl, func_args[0] + "->shape[0]" + "};\n");
+  } else {
+    CutlassPrint(gemm_decl, attrs.at("batch") + "};\n");
+  }
+
+  AppendGemmExecute(gemm_decl, "BatchedGemm");
   return gemm_decl.str();
 }
 
@@ -279,6 +342,11 @@ class CodegenCutlass : public 
MemoizedExprTranslator<std::vector<Output>>, publi
           {"nn.dense", add_or_bias_add, "multiply", "erf", "multiply", "add", 
"multiply"});
       return GenerateBody(dense_call, "cutlass_dense_bias_gelu", 
GetArgumentNames(caller),
                           DenseArgs(std::ref(attrs_)));
+    } else if (pattern_name == "cutlass.batch_matmul") {
+      const auto* batch_matmul_call =
+          GetRootCall(callee->body.as<CallNode>(), 0, {"nn.batch_matmul"});
+      return GenerateBody(batch_matmul_call, "cutlass_batch_matmul", 
GetArgumentNames(caller),
+                          BatchMatmulArgs(std::ref(attrs_)));
     }
     LOG(FATAL) << "Unknown composite function: " << pattern_name;
     return {};
@@ -322,6 +390,8 @@ class CodegenCutlass : public 
MemoizedExprTranslator<std::vector<Output>>, publi
     if (func_name == "cutlass_dense" || func_name == "cutlass_dense_bias" ||
         func_name == "cutlass_dense_bias_relu" || func_name == 
"cutlass_dense_bias_gelu") {
       ret.decl = DenseOp(ext_func_id_, attribute_args, func_args);
+    } else if (func_name == "cutlass_batch_matmul") {
+      ret.decl = BatchMatmulOp(ext_func_id_, attribute_args, func_args);
     }
     return ret;
   }
@@ -374,6 +444,7 @@ class CutlassModuleCodegen : public 
CSourceModuleCodegenBase {
     code_stream_ << "#include <cutlass/util/host_tensor.h>\n";
     code_stream_ << "#include <cutlass/util/reference/host/tensor_fill.h>\n";
     code_stream_ << "#include <cutlass/gemm/device/gemm.h>\n";
+    code_stream_ << "#include <cutlass/gemm/device/gemm_batched.h>\n";
     code_stream_ << "#include 
<cutlass/epilogue/thread/linear_combination_bias_relu.h>\n";
     code_stream_ << "#include 
<cutlass/epilogue/thread/linear_combination_gelu.h>\n";
 
diff --git a/tests/python/contrib/test_cutlass.py 
b/tests/python/contrib/test_cutlass.py
index 0927c41..5a1ff8b 100644
--- a/tests/python/contrib/test_cutlass.py
+++ b/tests/python/contrib/test_cutlass.py
@@ -56,14 +56,16 @@ def get_ref_vm(mod, params, target="cuda"):
     return VirtualMachine(vm_exec, dev), dev
 
 
def get_output(rt_mod, names, inputs):
    """Run a graph-runtime module on the given named inputs and return output 0.

    Parameters
    ----------
    rt_mod : graph-runtime module exposing ``set_input`` / ``run`` / ``get_output``
    names : list of str
        Input parameter names.
    inputs : list
        Input arrays, positionally matching ``names``.

    Returns
    -------
    The first output converted to a numpy array.
    """
    for name, inp in zip(names, inputs):
        rt_mod.set_input(name, inp)
    rt_mod.run()
    # Use numpy() rather than the deprecated asnumpy(), consistent with
    # get_output_vm which already calls .numpy().
    return rt_mod.get_output(0).numpy()
 
 
def get_output_vm(vm, names, inputs):
    """Invoke the VM's "main" function with the given named inputs and
    return the result as a numpy array."""
    return vm.invoke("main", **dict(zip(names, inputs))).numpy()
 
 
 def get_dense_with_shape(data_shape, weight_shape, out_dtype="float16"):
@@ -98,6 +100,16 @@ def get_dense_bias_gelu(M, N, K, out_dtype="float16"):
     return add * bias_add
 
 
def get_batch_matmul_with_shape(x_shape, y_shape, out_dtype="float16"):
    """Build a relay nn.batch_matmul over fp16 inputs of the given shapes."""
    lhs = relay.var("x", shape=x_shape, dtype="float16")
    rhs = relay.var("y", shape=y_shape, dtype="float16")
    return relay.nn.batch_matmul(lhs, rhs, out_dtype=out_dtype)
+
+
def get_batch_matmul(batch, M, N, K, out_dtype="float16"):
    """Build a batch_matmul with x of shape (batch, M, K) and y of (batch, N, K).

    Bug fix: forward ``out_dtype`` instead of hard-coding "float16", so the
    float32-accumulation test variants actually request float32 output.
    """
    return get_batch_matmul_with_shape((batch, M, K), (batch, N, K), out_dtype=out_dtype)
+
+
 def profile_and_build(mod, params, sm, tmp_dir="./tmp", lib_path="compile.so"):
     mod = partition_for_cutlass(mod)
     mod, num_cutlass_partition = tune_cutlass_kernels(
@@ -123,7 +135,9 @@ def profile_and_build_vm(
     return VirtualMachine(vm_exec, dev), dev, num_cutlass_partition
 
 
-def verify(func, M, N, K, ref_target="cuda", sm=80, atol=1e-5, rtol=1e-5, 
run_benchmark=False):
+def verify_dense(
+    func, M, N, K, ref_target="cuda", sm=80, atol=1e-5, rtol=1e-5, 
run_benchmark=False
+):
     if not has_cutlass():
         return
     mod = tvm.IRModule.from_expr(func)
@@ -151,14 +165,14 @@ def verify(func, M, N, K, ref_target="cuda", sm=80, 
atol=1e-5, rtol=1e-5, run_be
 
         rt_mod_ref, dev = get_ref_vm(mod, params, target=ref_target)
         x = tvm.nd.array(np_data, device=dev)
-        out = get_output_vm(rt_mod, x)
-        ref_out = get_output_vm(rt_mod_ref, x)
+        out = get_output_vm(rt_mod, ["data"], [x])
+        ref_out = get_output_vm(rt_mod_ref, ["data"], [x])
     else:
         rt_mod_ref, dev = get_ref_rt_mod(mod, params, target=ref_target)
         rt_mod, dev, num_partition = profile_and_build(mod, params, sm)
         x = tvm.nd.array(np_data, device=dev)
-        out = get_output(rt_mod, x)
-        ref_out = get_output(rt_mod_ref, x)
+        out = get_output(rt_mod, ["data"], [x])
+        ref_out = get_output(rt_mod_ref, ["data"], [x])
 
     assert num_partition > 0
     np.testing.assert_allclose(out, ref_out, atol=atol, rtol=rtol)
@@ -168,29 +182,65 @@ def verify(func, M, N, K, ref_target="cuda", sm=80, 
atol=1e-5, rtol=1e-5, run_be
         print("TVM with target %s:" % ref_target, rt_mod_ref.benchmark(dev, 
number=1, repeat=600))
 
 
def verify_batch_matmul(
    func, batch, M, N, K, ref_target="cuda", sm=80, atol=1e-5, rtol=1e-5, run_benchmark=False
):
    """Partition ``func`` for CUTLASS, run it, and compare against a TVM reference.

    Uses the VM executor when the inferred output shape is dynamic (contains
    ``tir.Any``); otherwise uses the graph runtime. Skips entirely when CUTLASS
    is not available.
    """
    if not has_cutlass():
        return
    mod = tvm.IRModule.from_expr(func)
    typ = relay.transform.InferType()(mod)["main"].body.checked_type
    # Dynamic shapes cannot go through the graph runtime; fall back to the VM.
    use_vm = any(isinstance(s, tvm.tir.Any) for s in typ.shape)
    x_np = np.random.uniform(-1, 1, (batch, M, K)).astype("float16")
    y_np = np.random.uniform(-1, 1, (batch, N, K)).astype("float16")

    if use_vm:
        rt_mod, dev, num_partition = profile_and_build_vm(mod, {}, sm)
        rt_mod_ref, dev = get_ref_vm(mod, {}, target=ref_target)
        assert num_partition > 0
        x = tvm.nd.array(x_np, device=dev)
        y = tvm.nd.array(y_np, device=dev)
        out = get_output_vm(rt_mod, ["x", "y"], [x, y])
        ref_out = get_output_vm(rt_mod_ref, ["x", "y"], [x, y])
    else:
        rt_mod, dev, num_partition = profile_and_build(mod, {}, sm)
        # Pass ref_target through, consistent with verify_dense.
        rt_mod_ref, dev = get_ref_rt_mod(mod, {}, target=ref_target)
        assert num_partition > 0

        x = tvm.nd.array(x_np, device=dev)
        y = tvm.nd.array(y_np, device=dev)
        out = get_output(rt_mod, ["x", "y"], [x, y])
        ref_out = get_output(rt_mod_ref, ["x", "y"], [x, y])

    np.testing.assert_allclose(out, ref_out, atol=atol, rtol=rtol)

    # Benchmark only when requested (was an unconditional `if True:` leftover,
    # which ignored the run_benchmark parameter).
    if run_benchmark:
        print("CUTLASS:", rt_mod.benchmark(dev, number=1, repeat=600))
        print("TVM Tensorcore (no tuning):", rt_mod_ref.benchmark(dev, number=1, repeat=600))
+
+
 M = 1820
 N = 768
 K = 768
 
 
 def test_dense():
-    verify(get_dense(M, N, K), M, N, K)
-    verify(get_dense(M, N, K, out_dtype="float32"), M, N, K)
+    verify_dense(get_dense(M, N, K), M, N, K)
+    verify_dense(get_dense(M, N, K, out_dtype="float32"), M, N, K)
 
 
 def test_dense_bias():
-    verify(get_dense_bias(M, N, K), M, N, K)
-    verify(get_dense_bias(M, N, K, out_dtype="float32"), M, N, K)
+    verify_dense(get_dense_bias(M, N, K), M, N, K)
+    verify_dense(get_dense_bias(M, N, K, out_dtype="float32"), M, N, K)
 
 
 def test_dense_bias_relu():
-    verify(get_dense_bias_relu(M, N, K), M, N, K)
-    verify(get_dense_bias_relu(M, N, K, out_dtype="float32"), M, N, K)
+    verify_dense(get_dense_bias_relu(M, N, K), M, N, K)
+    verify_dense(get_dense_bias_relu(M, N, K, out_dtype="float32"), M, N, K)
 
 
 def test_dense_bias_gelu():
-    verify(get_dense_bias_gelu(M, N, K), M, N, K, atol=1e-3, rtol=1e-3)
-    verify(get_dense_bias_gelu(M, N, K, out_dtype="float32"), M, N, K, 
atol=1e-3, rtol=1e-3)
+    verify_dense(get_dense_bias_gelu(M, N, K), M, N, K, atol=1e-3, rtol=1e-3)
+    verify_dense(get_dense_bias_gelu(M, N, K, out_dtype="float32"), M, N, K, 
atol=1e-3, rtol=1e-3)
 
 
 def test_dense_dynamic():
@@ -200,7 +250,7 @@ def test_dense_dynamic():
     if has_cublas():
         # TVM native fp16 dense (without tensorcore), using fp16 accum, seems 
to have accuracy issues
         # Use cublas as a reference
-        verify(
+        verify_dense(
             get_dense_with_shape(data_shape, weight_shape),
             M,
             N,
@@ -208,7 +258,7 @@ def test_dense_dynamic():
             ref_target="cuda -libs=cublas",
         )
 
-    verify(
+    verify_dense(
         get_dense_with_shape(data_shape, weight_shape, out_dtype="float32"),
         M,
         N,
@@ -218,5 +268,26 @@ def test_dense_dynamic():
     )
 
 
def test_batch_matmul():
    """Static-shape batch_matmul, plus a dynamic-shape case verified against cuBLAS."""
    batch = 8
    for out_dtype in ["float16", "float32"]:
        verify_batch_matmul(get_batch_matmul(batch, M, N, K, out_dtype=out_dtype), batch, M, N, K)

    if has_cublas():
        # Dynamic-shape batch_matmul: AutoTVM does not seem to support it,
        # so use cuBLAS as the reference instead.
        dyn_shape = (relay.Any(), relay.Any(), K)
        verify_batch_matmul(
            get_batch_matmul_with_shape(dyn_shape, dyn_shape),
            batch,
            M,
            N,
            K,
            ref_target="cuda -libs=cublas",
        )
+
+
 if __name__ == "__main__":
     pytest.main([__file__])

Reply via email to