This is an automated email from the ASF dual-hosted git repository.
masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new c7a01a4 [CUTLASS] Support batch_matmul (#9439)
c7a01a4 is described below
commit c7a01a4d457a2d1b1939dad82b572fb738ff4754
Author: masahi <[email protected]>
AuthorDate: Thu Nov 4 15:59:05 2021 +0900
[CUTLASS] Support batch_matmul (#9439)
* Import batched gemm change
commit cfacfa296e2487a189e52d189567b140c675ccc2
Author: Masahiro Masuda <[email protected]>
Date: Mon Nov 1 15:57:49 2021 +0900
change is_constant pattern to wildcard in gelu pattern
commit 84da94306ca81209a8ccc44fd7d606cbce047082
Author: Masahiro Masuda <[email protected]>
Date: Mon Nov 1 05:41:11 2021 +0900
fixed batch stride C
commit 66e5779ee69dc0cd3969f268608b551ec549d79b
Author: Masahiro Masuda <[email protected]>
Date: Sun Oct 31 20:47:16 2021 +0900
refactoring codegen
commit 561daeafa66cddf6a565b537072a5efce0b0dbf1
Author: Masahiro Masuda <[email protected]>
Date: Sun Oct 31 20:05:20 2021 +0900
generated kernel compiled and result match
commit a5740bcf5287097b64dff8adb50f0cddc2c41349
Author: Masahiro Masuda <[email protected]>
Date: Sun Oct 31 19:36:53 2021 +0900
partitioning looks good
commit 59112fdf78a4541905fad9b899737600e0ed9391
Author: Masahiro Masuda <[email protected]>
Date: Sun Oct 31 19:01:47 2021 +0900
[WIP] cutlass batch matmul support
* fixed test
* refactoring
* gelu test fixed
* more refactor
* batch_matmul fp32 accum working
* dynamic batch matmul working
* black
* remove doc TODO
---
python/tvm/contrib/cutlass/build.py | 141 +++++++++++++++++-----
python/tvm/contrib/cutlass/gemm_operation.py | 15 +--
python/tvm/contrib/cutlass/gen_gemm.py | 42 ++++---
python/tvm/contrib/cutlass/library.py | 2 +
python/tvm/relay/op/contrib/cutlass.py | 11 +-
src/relay/backend/contrib/cutlass/codegen.cc | 167 +++++++++++++++++++--------
tests/python/contrib/test_cutlass.py | 109 ++++++++++++++---
7 files changed, 363 insertions(+), 124 deletions(-)
diff --git a/python/tvm/contrib/cutlass/build.py b/python/tvm/contrib/cutlass/build.py
index 58e7a11..615b900 100644
--- a/python/tvm/contrib/cutlass/build.py
+++ b/python/tvm/contrib/cutlass/build.py
@@ -83,6 +83,87 @@ class GemmAnnotator(tvm.relay.ExprVisitor):
self.signature["ret_dtype"] = op.ret_type.dtype
+def select_gemm_kernel(
+ cutlass_profiler, MM, KK, NN, out_dtype, batched, profile_all,
use_multiprocessing
+):
+ """Run CUTLASS profiler to select the best kernel, or return the default
one for dynamic
+ workloads."""
+ if any(isinstance(s, tvm.tir.Any) for s in [MM, KK, NN]):
+ out = cutlass_profiler.get_default(out_dtype, batched=batched)
+ logger.info("Picked the default kernel %s", out["name"])
+ else:
+ out = cutlass_profiler.profile(
+ MM,
+ NN,
+ KK,
+ out_dtype,
+ batched=batched,
+ profile_all=profile_all,
+ use_multiprocessing=use_multiprocessing,
+ )
+ if profile_all:
+ logger.info("The best kernel is %s", out["name"])
+ else:
+ logger.info("Picked the first kernel found %s", out["name"])
+ return out
+
+
+def handle_batch_matmul(
+ cutlass_profiler, op_type, arg0_shape, arg1_shape, out_dtype, profile_all,
use_multiprocessing
+):
+ """Profile and select a kernel for batch_matmul op workload."""
+ MM = arg0_shape[1]
+ KK = arg0_shape[2]
+ NN = arg1_shape[1]
+
+ out = select_gemm_kernel(
+ cutlass_profiler, MM, KK, NN, out_dtype, True, profile_all,
use_multiprocessing
+ )
+
+ if op_type == "cutlass.batch_matmul":
+ cutlass_op_def = out["opdef"]
+ else:
+ raise ValueError("%s pattern is not implemented." % op_type)
+
+ return {
+ "batch": arg0_shape[0],
+ "batch_stride_A": arg0_shape[1] * arg0_shape[2],
+ "batch_stride_B": arg1_shape[1] * arg1_shape[2],
+ "batch_stride_C": arg0_shape[1] * arg1_shape[1],
+ "cutlass_op_def": cutlass_op_def,
+ "cutlass_op_name": out["name"],
+ }
+
+
+def handle_dense(
+ cutlass_profiler, op_type, arg0_shape, arg1_shape, out_dtype, profile_all,
use_multiprocessing
+):
+ """Profile and select a kernel for dense op workload."""
+ MM = arg0_shape[0]
+ KK = arg0_shape[1]
+ NN = arg1_shape[0]
+
+ out = select_gemm_kernel(
+ cutlass_profiler, MM, KK, NN, out_dtype, False, profile_all,
use_multiprocessing
+ )
+
+ if op_type == "cutlass.dense":
+ cutlass_op_def = out["opdef"]
+ elif op_type == "cutlass.dense_bias":
+ cutlass_op_def = out["opdef_bias"]
+ elif op_type == "cutlass.dense_bias_relu":
+ cutlass_op_def = out["opdef_bias_relu"]
+ elif "cutlass.dense_bias_gelu" in op_type:
+ cutlass_op_def = out["opdef_bias_gelu"]
+ else:
+ raise ValueError("%s pattern is not implemented." % op_type)
+
+ return {
+ "cutlass_op_def": cutlass_op_def,
+ "cutlass_op_name": out["name"],
+ }
+
+
def tune_cutlass_kernels(mod, sm, profile_all=True, use_multiprocessing=False,
tmp_dir="./tmp"):
"""Given a module partitioned for CUTLASS offloading, profile each
workload to select which
kernels to emit.
@@ -123,41 +204,41 @@ def tune_cutlass_kernels(mod, sm, profile_all=True, use_multiprocessing=False, t
if "cutlass" in fun_name:
num_cutlass_partition += 1
annotator.visit(func)
- # call cutlass profiler to find best settings, update attr
- new_attrs = {}
+ out_dtype = annotator.signature["ret_dtype"]
+ op_type = annotator.signature["op_type"]
+
+ new_attrs = {"op_type": op_type}
new_attrs.update(annotator.signature)
- for key in func.attrs.keys():
- new_attrs[key] = func.attrs[key]
- # call profiler
+ new_attrs.update(func.attrs)
arg0_shape = new_attrs["arg0_shape"]
arg1_shape = new_attrs["arg1_shape"]
- MM = arg0_shape[0]
- KK = arg0_shape[1]
- NN = arg1_shape[0]
- out_dtype = annotator.signature["ret_dtype"]
- if any(isinstance(s, tvm.tir.Any) for s in [MM, KK, NN]):
- out = cutlass_profiler.get_default(out_dtype)
- logger.info("Picked the default kernel %s", out["name"])
- else:
- out = cutlass_profiler.profile(
- MM, NN, KK, out_dtype, profile_all, use_multiprocessing
+
+ if "batch_matmul" in op_type:
+ new_attrs.update(
+ handle_batch_matmul(
+ cutlass_profiler,
+ op_type,
+ arg0_shape,
+ arg1_shape,
+ out_dtype,
+ profile_all,
+ use_multiprocessing,
+ )
+ )
+ elif "dense" in op_type:
+ new_attrs.update(
+ handle_dense(
+ cutlass_profiler,
+ op_type,
+ arg0_shape,
+ arg1_shape,
+ out_dtype,
+ profile_all,
+ use_multiprocessing,
+ )
)
- if profile_all:
- logger.info("The best kernel is %s", out["name"])
- else:
- logger.info("Picked the first kernel found %s",
out["name"])
-
- if new_attrs["op_type"] == "cutlass.dense":
- new_attrs["cutlass_op_def"] = out["opdef"]
- elif new_attrs["op_type"] == "cutlass.dense_bias":
- new_attrs["cutlass_op_def"] = out["opdef_bias"]
- elif new_attrs["op_type"] == "cutlass.dense_bias_relu":
- new_attrs["cutlass_op_def"] = out["opdef_bias_relu"]
- elif "cutlass.dense_bias_gelu" in new_attrs["op_type"]:
- new_attrs["cutlass_op_def"] = out["opdef_bias_gelu"]
else:
- raise ValueError("%s pattern is not implemented." %
new_attrs["op_type"])
- new_attrs["cutlass_op_name"] = out["name"]
+ raise ValueError("%s unsupported composite" % op_type)
if new_attrs["cutlass_op_name"].find("_tn_align") > 0:
new_attrs["lda"] = "K"
diff --git a/python/tvm/contrib/cutlass/gemm_operation.py b/python/tvm/contrib/cutlass/gemm_operation.py
index e53b3ee..4673b4b 100644
--- a/python/tvm/contrib/cutlass/gemm_operation.py
+++ b/python/tvm/contrib/cutlass/gemm_operation.py
@@ -174,7 +174,7 @@ class EmitGemmInstance:
>"""
self.gemm_template = """
// Gemm operator ${operation_name}
- using Operation_${operation_name} = cutlass::gemm::device::Gemm<
+ using Operation_${operation_name} = cutlass::gemm::device::${kernel_name}<
${element_a}, ${layout_a},
${element_b}, ${layout_b},
${element_c}, ${layout_c},
@@ -189,13 +189,12 @@ class EmitGemmInstance:
${stages},
${align_a},
${align_b},
- false,
+ ${split_k_serial}
${math_operation}
- ${residual}
>;
"""
- def emit(self, operation, no_beta_scaling=False):
+ def emit(self, operation, no_beta_scaling=False, batched=False):
"""Instantiate a GEMM kernel from given `operation`."""
warp_shape = [
operation.tile_description.threadblock_shape[idx]
@@ -206,8 +205,6 @@ class EmitGemmInstance:
min(operation.C.alignment * DataTypeSize[operation.C.element], 128)
// DataTypeSize[operation.C.element]
)
- residual = ""
- complex_transform_tag = "cutlass::ComplexTransform::kNone"
values = {
"operation_name": operation.procedural_name(),
"element_a": DataTypeTag[operation.A.element],
@@ -243,14 +240,14 @@ class EmitGemmInstance:
"stages": str(operation.tile_description.stages),
"align_a": str(operation.A.alignment),
"align_b": str(operation.B.alignment),
- "transform_a": complex_transform_tag,
- "transform_b": complex_transform_tag,
"math_operation": MathOperationTag[
operation.tile_description.math_instruction.math_operation
],
- "residual": residual,
}
+ values["kernel_name"] = "GemmBatched" if batched else "Gemm"
+ values["split_k_serial"] = "" if batched else "false,"
+
gemm_template = substitute_template(
self.gemm_template,
{
diff --git a/python/tvm/contrib/cutlass/gen_gemm.py b/python/tvm/contrib/cutlass/gen_gemm.py
index a43c6d4..1ed4bfe 100644
--- a/python/tvm/contrib/cutlass/gen_gemm.py
+++ b/python/tvm/contrib/cutlass/gen_gemm.py
@@ -47,6 +47,7 @@ def create_gemm_operator(
alignment_constraints,
epilogue_functor=EpilogueFunctor.LinearCombination,
swizzling_functor=SwizzlingFunctor.Identity8,
+ batched=False,
):
"""Exhaustively instantiate all kernels from a given configuration."""
ret = []
@@ -55,6 +56,9 @@ def create_gemm_operator(
element_a, element_b, element_c, element_epilogue = data_type
+ if batched:
+ swizzling_functor = SwizzlingFunctor.Batched
+
for layout in layouts:
for tile_description in tile_descriptions:
for alignment in alignment_constraints:
@@ -109,15 +113,17 @@ def create_gemm_operator(
kernel_emitter = EmitGemmInstance()
op_entry["op"] = op
op_entry["name"] = op.procedural_name()
- op_entry["opdef"] = kernel_emitter.emit(op)
- op_entry["opdef_bias"] = kernel_emitter.emit(op_bias,
no_beta_scaling=True)
+ op_entry["opdef"] = kernel_emitter.emit(op, batched=batched)
+ op_entry["opdef_bias"] = kernel_emitter.emit(
+ op_bias, no_beta_scaling=True, batched=batched
+ )
op_entry["opdef_bias_relu"] = kernel_emitter.emit(
- op_bias_relu, no_beta_scaling=True
+ op_bias_relu, no_beta_scaling=True, batched=batched
)
- op_entry["opdef_bias_gelu"] = kernel_emitter.emit(op_bias_gelu)
+ op_entry["opdef_bias_gelu"] =
kernel_emitter.emit(op_bias_gelu, batched=batched)
op_entry["src"] = profiler_emitter.emit(
op.procedural_name(),
- op_entry["opdef"],
+ kernel_emitter.emit(op, batched=False),
DataTypeTag[element_a],
DataTypeTag[element_b],
DataTypeTag[element_c],
@@ -128,7 +134,9 @@ def create_gemm_operator(
return ret
-def generate_tensor_op_common(math_instructions, alignment_constraints,
get_tile_descriptions):
+def generate_tensor_op_common(
+ math_instructions, alignment_constraints, get_tile_descriptions,
batched=False
+):
"""Common kernel generator to be used by archtecture specific
generators."""
ops = []
layouts = [
@@ -143,14 +151,16 @@ def generate_tensor_op_common(math_instructions, alignment_constraints, get_tile
math_inst.element_accumulator,
]
- out = create_gemm_operator(layouts, tile_descriptions, data_type,
alignment_constraints)
+ out = create_gemm_operator(
+ layouts, tile_descriptions, data_type, alignment_constraints,
batched=batched
+ )
ops.extend(out)
return ops
-def generate_sm75_tensor_op_1688(out_dtype):
+def generate_sm75_tensor_op_1688(out_dtype, batched=False):
"""Generate GEMM kernels for Turing."""
assert out_dtype in ["float32", "float16"]
math_instructions = {
@@ -192,11 +202,11 @@ def generate_sm75_tensor_op_1688(out_dtype):
]
return generate_tensor_op_common(
- math_instructions, alignment_constraints, get_tile_descriptions
+ math_instructions, alignment_constraints, get_tile_descriptions,
batched
)
-def generate_sm80_tensor_op_16816(out_dtype):
+def generate_sm80_tensor_op_16816(out_dtype, batched=False):
"""Generate GEMM kernels for Ampere."""
assert out_dtype in ["float32", "float16"]
math_instructions = {
@@ -250,7 +260,7 @@ def generate_sm80_tensor_op_16816(out_dtype):
]
return generate_tensor_op_common(
- math_instructions, alignment_constraints, get_tile_descriptions
+ math_instructions, alignment_constraints, get_tile_descriptions,
batched
)
@@ -350,17 +360,19 @@ class CutlassGemmProfiler(object):
return False
return True
- def get_default(self, out_dtype):
+ def get_default(self, out_dtype, batched=False):
"""Return the default kernel for the requested architecture.
For now, the default kernel was picked arbitrary.
"""
- ops = GENERATOR_FUNC_TABLE[self.sm](out_dtype)
+ ops = GENERATOR_FUNC_TABLE[self.sm](out_dtype, batched)
default_kernel_name = DEFAULT_KERNELS[self.sm][out_dtype]
filtered = list(filter(lambda op: op["name"] == default_kernel_name,
ops))
assert len(filtered) == 1
return filtered[0]
- def profile(self, M, N, K, out_dtype, profile_all=True,
use_multiprocessing=False):
+ def profile(
+ self, M, N, K, out_dtype, profile_all=True, use_multiprocessing=False,
batched=False
+ ):
"""Profile and select the best kernel from candidate kernels.
If profile_all is False, return immediately after the first applicable
kernel is found.
If use_multiprocessing is True, compile all profiler executables in
parallel.
@@ -368,7 +380,7 @@ class CutlassGemmProfiler(object):
if (M, N, K) in self.cache:
return self.cache[(M, N, K)]
- ops = GENERATOR_FUNC_TABLE[self.sm](out_dtype)
+ ops = GENERATOR_FUNC_TABLE[self.sm](out_dtype, batched)
ops = list(filter(lambda op: self.check_align(op["name"], M), ops))
for op in ops:
diff --git a/python/tvm/contrib/cutlass/library.py b/python/tvm/contrib/cutlass/library.py
index 7d54429..a3b90ff 100644
--- a/python/tvm/contrib/cutlass/library.py
+++ b/python/tvm/contrib/cutlass/library.py
@@ -160,6 +160,7 @@ class SwizzlingFunctor(enum.Enum):
Identity2 = enum_auto()
Identity4 = enum_auto()
Identity8 = enum_auto()
+ Batched = enum_auto()
SwizzlingFunctorTag = {
@@ -167,6 +168,7 @@ SwizzlingFunctorTag = {
SwizzlingFunctor.Identity2:
"cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<2>",
SwizzlingFunctor.Identity4:
"cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<4>",
SwizzlingFunctor.Identity8:
"cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>",
+ SwizzlingFunctor.Batched:
"cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle",
}
diff --git a/python/tvm/relay/op/contrib/cutlass.py b/python/tvm/relay/op/contrib/cutlass.py
index 631089c..8ed3718 100644
--- a/python/tvm/relay/op/contrib/cutlass.py
+++ b/python/tvm/relay/op/contrib/cutlass.py
@@ -20,13 +20,13 @@ from ...dataflow_pattern import wildcard, is_op, is_constant
def make_gelu_pattern(bias_out, out_dtype="float16"):
- mul = is_op("multiply")(bias_out, is_constant())
+ mul = is_op("multiply")(bias_out, is_constant() | wildcard())
if out_dtype == "float16":
erf = is_op("cast")(is_op("erf")(is_op("cast")(mul)))
else:
erf = is_op("erf")(mul)
- mul_half = is_op("multiply")(erf, is_constant())
- add = is_op("add")(mul_half, is_constant())
+ mul_half = is_op("multiply")(erf, is_constant() | wildcard())
+ add = is_op("add")(mul_half, is_constant() | wildcard())
return is_op("multiply")(add, bias_out)
@@ -51,6 +51,10 @@ def make_gemm_pattern(with_bias=True, with_act=None, out_dtype="float16"):
return make_gelu_pattern(gemm_out, out_dtype)
+def make_batch_matmul_pattern():
+ return is_op("nn.batch_matmul")(wildcard(), wildcard())
+
+
def partition_for_cutlass(mod):
"""Partition the input module into CUTLASS-supported subgraphs."""
dense_pat = ("cutlass.dense", make_gemm_pattern(False, None))
@@ -67,6 +71,7 @@ def partition_for_cutlass(mod):
dense_bias_relu_pat,
dense_bias_pat,
dense_pat,
+ ("cutlass.batch_matmul", make_batch_matmul_pattern()),
]
mod = transform.MergeComposite(cutlass_patterns)(mod)
mod = transform.AnnotateTarget(["cutlass"])(mod)
diff --git a/src/relay/backend/contrib/cutlass/codegen.cc b/src/relay/backend/contrib/cutlass/codegen.cc
index c1217a0..f154f86 100644
--- a/src/relay/backend/contrib/cutlass/codegen.cc
+++ b/src/relay/backend/contrib/cutlass/codegen.cc
@@ -54,19 +54,21 @@ std::string GetDimAsStr(ObjectRef dim) {
return kAnyDim;
}
-Str2StrMap DenseArgs(const Map<String, ObjectRef>& attrs) {
+inline void CutlassPrint(std::ostringstream& os, const std::string& stmt, int
indent = 2) {
+ for (int i = 0; i < indent; ++i) {
+ os << " ";
+ }
+ os << stmt;
+}
+
+Str2StrMap GemmArgsCommon(const Map<String, ObjectRef>& attrs) {
Str2StrMap args;
auto arg0_dtype = std::string(attrs["arg0_dtype"].as<StringObj>()->data);
auto arg1_dtype = std::string(attrs["arg1_dtype"].as<StringObj>()->data);
auto ret_dtype = std::string(attrs["ret_dtype"].as<StringObj>()->data);
- auto arg0_shape = attrs["arg0_shape"].as<ArrayNode>();
- auto arg1_shape = attrs["arg1_shape"].as<ArrayNode>();
args["ElementInputA"] = dtype_map.at(arg0_dtype);
args["ElementInputB"] = dtype_map.at(arg1_dtype);
args["ElementOutput"] = dtype_map.at(ret_dtype);
- args["M"] = GetDimAsStr(arg0_shape->at(0));
- args["K"] = GetDimAsStr(arg0_shape->at(1));
- args["N"] = GetDimAsStr(arg1_shape->at(0));
args["op_def"] = std::string(attrs["cutlass_op_def"].as<StringObj>()->data);
args["op_name"] =
std::string(attrs["cutlass_op_name"].as<StringObj>()->data);
args["op_type"] = std::string(attrs["op_type"].as<StringObj>()->data);
@@ -76,23 +78,33 @@ Str2StrMap DenseArgs(const Map<String, ObjectRef>& attrs) {
return args;
}
-inline void CutlassPrint(std::ostringstream& os, const std::string& stmt, int
indent = 2) {
- for (int i = 0; i < indent; ++i) {
- os << " ";
- }
- os << stmt;
+Str2StrMap DenseArgs(const Map<String, ObjectRef>& attrs) {
+ Str2StrMap args = GemmArgsCommon(attrs);
+ auto arg0_shape = attrs["arg0_shape"].as<ArrayNode>();
+ auto arg1_shape = attrs["arg1_shape"].as<ArrayNode>();
+ args["M"] = GetDimAsStr(arg0_shape->at(0));
+ args["K"] = GetDimAsStr(arg0_shape->at(1));
+ args["N"] = GetDimAsStr(arg1_shape->at(0));
+ return args;
}
-std::string DenseOp(std::string id, const Str2StrMap& attrs,
- const std::vector<std::string>& func_args) {
- bool has_bias = false;
- bool is_gelu =
- attrs.at("op_type").find("cutlass.dense_bias_gelu") !=
std::string::npos; // fp32 or fp16
- if (attrs.at("op_type") == "cutlass.dense_bias" ||
- attrs.at("op_type") == "cutlass.dense_bias_relu" || is_gelu) {
- has_bias = true;
- }
- std::ostringstream gemm_decl;
+Str2StrMap BatchMatmulArgs(const Map<String, ObjectRef>& attrs) {
+ Str2StrMap args = GemmArgsCommon(attrs);
+ args["batch"] = GetDimAsStr(attrs["batch"]);
+ args["batch_stride_A"] = GetDimAsStr(attrs["batch_stride_A"]);
+ args["batch_stride_B"] = GetDimAsStr(attrs["batch_stride_B"]);
+ args["batch_stride_C"] = GetDimAsStr(attrs["batch_stride_C"]);
+ auto arg0_shape = attrs["arg0_shape"].as<ArrayNode>();
+ auto arg1_shape = attrs["arg1_shape"].as<ArrayNode>();
+ args["M"] = GetDimAsStr(arg0_shape->at(1));
+ args["K"] = GetDimAsStr(arg0_shape->at(2));
+ args["N"] = GetDimAsStr(arg1_shape->at(1));
+ return args;
+}
+
+void AppendPrologue(std::ostringstream& gemm_decl, const Str2StrMap& attrs,
+ const std::vector<std::string>& func_args, const
std::string& kernel,
+ bool has_bias, bool is_gelu, int m_axis_idx, int
n_axis_idx, int k_axis_idx) {
CutlassPrint(gemm_decl, "using ElementInputA = " + attrs.at("ElementInputA")
+ ";\n");
CutlassPrint(gemm_decl, "using ElementInputB = " + attrs.at("ElementInputB")
+ ";\n");
CutlassPrint(gemm_decl, "using ElementOutput = " + attrs.at("ElementOutput")
+ ";\n");
@@ -107,11 +119,10 @@ std::string DenseOp(std::string id, const Str2StrMap& attrs,
return attrs.at(axis);
}
};
- CutlassPrint(gemm_decl, "int M = " + get_dim("M", 0, 0) + ";\n");
- CutlassPrint(gemm_decl, "int N = " + get_dim("N", 1, 0) + ";\n");
- CutlassPrint(gemm_decl, "int K = " + get_dim("K", 0, 1) + ";\n");
+ CutlassPrint(gemm_decl, "int M = " + get_dim("M", 0, m_axis_idx) + ";\n");
+ CutlassPrint(gemm_decl, "int N = " + get_dim("N", 1, n_axis_idx) + ";\n");
+ CutlassPrint(gemm_decl, "int K = " + get_dim("K", 0, k_axis_idx) + ";\n");
CutlassPrint(gemm_decl, "cutlass::gemm::GemmCoord problem_size(M, N, K);\n");
- // Initialize alpha for dot product computation
CutlassPrint(gemm_decl, "ElementComputeEpilogue alpha =
ElementComputeEpilogue(1);\n");
if (is_gelu) {
// GeLU epilogue does not compile with NoBetaScaling, so we explicitly
specify the scale.
@@ -120,11 +131,6 @@ std::string DenseOp(std::string id, const Str2StrMap& attrs,
CutlassPrint(gemm_decl, "ElementComputeEpilogue beta =
ElementComputeEpilogue(0);\n");
}
- // Split K dimension into 1 partitions
- CutlassPrint(gemm_decl, "int split_k_slices = 1;\n");
-
- // Create a tuple of gemm kernel arguments. This is later passed as
arguments to launch
- // instantiated CUTLASS kernel
ICHECK(func_args.size() >= 2);
CutlassPrint(gemm_decl, "void* ptr_a = (void*)(" + func_args[0] +
"->data);\n");
CutlassPrint(gemm_decl, "void* ptr_b = (void*)(" + func_args[1] +
"->data);\n");
@@ -132,33 +138,24 @@ std::string DenseOp(std::string id, const Str2StrMap& attrs,
ICHECK(func_args.size() >= 3);
CutlassPrint(gemm_decl, "void* ptr_c_bias = (void*)(" + func_args[2] +
"->data);\n");
}
+
CutlassPrint(gemm_decl, "void* ptr_out = (void*)(out0);\n");
- CutlassPrint(gemm_decl, "typename Gemm::Arguments arguments{\n");
+ CutlassPrint(gemm_decl, "using " + kernel + " = Operation_" +
attrs.at("op_name") + ";\n");
+ CutlassPrint(gemm_decl, "typename " + kernel + "::Arguments arguments{\n");
CutlassPrint(gemm_decl, " problem_size,\n");
- CutlassPrint(gemm_decl, " {static_cast<ElementInputA*>(ptr_a), " +
attrs.at("lda") + "},\n");
- CutlassPrint(gemm_decl, " {static_cast<ElementInputB*>(ptr_b), " +
attrs.at("ldb") + "},\n");
- if (has_bias) {
- CutlassPrint(gemm_decl, " {static_cast<ElementOutput*>(ptr_c_bias),
0},\n");
- } else {
- CutlassPrint(gemm_decl, " {static_cast<ElementOutput*>(ptr_out), " +
attrs.at("ldc") + "},\n");
- }
- CutlassPrint(gemm_decl, " {static_cast<ElementOutput*>(ptr_out), " +
attrs.at("ldc") + "},\n");
- if (has_bias && !is_gelu) {
- CutlassPrint(gemm_decl, " {alpha},\n");
- } else {
- // For GeLU, we explicitly specify the scale.
- CutlassPrint(gemm_decl, " {alpha, beta},\n");
- }
- CutlassPrint(gemm_decl, " split_k_slices};\n");
+}
+void AppendGemmExecute(std::ostringstream& gemm_decl, const std::string&
kernel) {
// Using the arguments, query for extra workspace required for matrix
multiplication computation
- CutlassPrint(gemm_decl, "size_t workspace_size =
Gemm::get_workspace_size(arguments);\n");
+ CutlassPrint(gemm_decl,
+ "size_t workspace_size = " + kernel +
"::get_workspace_size(arguments);\n");
// Allocate workspace memory
CutlassPrint(gemm_decl,
"cutlass::device_memory::allocation<uint8_t>
workspace(workspace_size);\n");
// Instantiate CUTLASS kernel depending on template
- CutlassPrint(gemm_decl, "Gemm gemm_op;\n");
+ CutlassPrint(gemm_decl, kernel + " gemm_op;\n");
+
// Check the problem size is supported or not
CutlassPrint(gemm_decl, "cutlass::Status status =
gemm_op.can_implement(arguments);\n");
CutlassPrint(gemm_decl, "CHECK(status == cutlass::Status::kSuccess);\n");
@@ -168,6 +165,72 @@ std::string DenseOp(std::string id, const Str2StrMap& attrs,
// Launch initialized CUTLASS kernel
CutlassPrint(gemm_decl, "status = gemm_op();\n");
CutlassPrint(gemm_decl, "CHECK(status == cutlass::Status::kSuccess);\n");
+}
+
+std::string DenseOp(std::string id, const Str2StrMap& attrs,
+ const std::vector<std::string>& func_args) {
+ bool has_bias = false;
+ bool is_gelu =
+ attrs.at("op_type").find("cutlass.dense_bias_gelu") !=
std::string::npos; // fp32 or fp16
+ if (attrs.at("op_type") == "cutlass.dense_bias" ||
+ attrs.at("op_type") == "cutlass.dense_bias_relu" || is_gelu) {
+ has_bias = true;
+ }
+ std::ostringstream gemm_decl;
+ AppendPrologue(gemm_decl, attrs, func_args, "Gemm", has_bias, is_gelu, 0, 0,
1);
+
+ CutlassPrint(gemm_decl, " {static_cast<ElementInputA*>(ptr_a), " +
attrs.at("lda") + "},\n");
+ CutlassPrint(gemm_decl, " {static_cast<ElementInputB*>(ptr_b), " +
attrs.at("ldb") + "},\n");
+ if (has_bias) {
+ CutlassPrint(gemm_decl, " {static_cast<ElementOutput*>(ptr_c_bias),
0},\n");
+ } else {
+ CutlassPrint(gemm_decl, " {static_cast<ElementOutput*>(ptr_out), " +
attrs.at("ldc") + "},\n");
+ }
+ CutlassPrint(gemm_decl, " {static_cast<ElementOutput*>(ptr_out), " +
attrs.at("ldc") + "},\n");
+ if (has_bias && !is_gelu) {
+ CutlassPrint(gemm_decl, " {alpha},\n");
+ } else {
+ // For GeLU, we explicitly specify the scale.
+ CutlassPrint(gemm_decl, " {alpha, beta},\n");
+ }
+ CutlassPrint(gemm_decl, " 1};\n"); // split_k_slices
+
+ AppendGemmExecute(gemm_decl, "Gemm");
+ return gemm_decl.str();
+}
+
+std::string BatchMatmulOp(std::string id, const Str2StrMap& attrs,
+ const std::vector<std::string>& func_args) {
+ std::ostringstream gemm_decl;
+ AppendPrologue(gemm_decl, attrs, func_args, "BatchedGemm", false, false, 1,
1, 2);
+
+ auto get_batch_stride = [&attrs, &func_args](const std::string& name, int
arg0_idx, int arg1_idx,
+ int arg0_axis_idx, int
arg1_axis_idx) {
+ if (attrs.at(name) == kAnyDim) {
+ return func_args[arg0_idx] + "->shape[" + std::to_string(arg0_axis_idx)
+ "] * " +
+ func_args[arg1_idx] + "->shape[" + std::to_string(arg1_axis_idx)
+ "]";
+ } else {
+ return attrs.at(name);
+ }
+ };
+
+ CutlassPrint(gemm_decl, " {static_cast<ElementInputA*>(ptr_a), " +
attrs.at("lda") + "},\n");
+ CutlassPrint(gemm_decl, get_batch_stride("batch_stride_A", 0, 0, 1, 2) +
",\n");
+ CutlassPrint(gemm_decl, " {static_cast<ElementInputB*>(ptr_b), " +
attrs.at("ldb") + "},\n");
+ CutlassPrint(gemm_decl, get_batch_stride("batch_stride_B", 1, 1, 1, 2) +
",\n");
+ CutlassPrint(gemm_decl, " {static_cast<ElementOutput*>(ptr_out), " +
attrs.at("ldc") + "},\n");
+ CutlassPrint(gemm_decl, get_batch_stride("batch_stride_C", 0, 1, 1, 1) +
",\n");
+ CutlassPrint(gemm_decl, " {static_cast<ElementOutput*>(ptr_out), " +
attrs.at("ldc") + "},\n");
+ CutlassPrint(gemm_decl, get_batch_stride("batch_stride_C", 0, 1, 1, 1) +
",\n");
+ CutlassPrint(gemm_decl, " {alpha, beta},\n");
+
+ if (attrs.at("batch") == kAnyDim) {
+ CutlassPrint(gemm_decl, func_args[0] + "->shape[0]" + "};\n");
+ } else {
+ CutlassPrint(gemm_decl, attrs.at("batch") + "};\n");
+ }
+
+ AppendGemmExecute(gemm_decl, "BatchedGemm");
return gemm_decl.str();
}
@@ -279,6 +342,11 @@ class CodegenCutlass : public MemoizedExprTranslator<std::vector<Output>>, publi
{"nn.dense", add_or_bias_add, "multiply", "erf", "multiply", "add",
"multiply"});
return GenerateBody(dense_call, "cutlass_dense_bias_gelu",
GetArgumentNames(caller),
DenseArgs(std::ref(attrs_)));
+ } else if (pattern_name == "cutlass.batch_matmul") {
+ const auto* batch_matmul_call =
+ GetRootCall(callee->body.as<CallNode>(), 0, {"nn.batch_matmul"});
+ return GenerateBody(batch_matmul_call, "cutlass_batch_matmul",
GetArgumentNames(caller),
+ BatchMatmulArgs(std::ref(attrs_)));
}
LOG(FATAL) << "Unknown composite function: " << pattern_name;
return {};
@@ -322,6 +390,8 @@ class CodegenCutlass : public MemoizedExprTranslator<std::vector<Output>>, publi
if (func_name == "cutlass_dense" || func_name == "cutlass_dense_bias" ||
func_name == "cutlass_dense_bias_relu" || func_name ==
"cutlass_dense_bias_gelu") {
ret.decl = DenseOp(ext_func_id_, attribute_args, func_args);
+ } else if (func_name == "cutlass_batch_matmul") {
+ ret.decl = BatchMatmulOp(ext_func_id_, attribute_args, func_args);
}
return ret;
}
@@ -374,6 +444,7 @@ class CutlassModuleCodegen : public CSourceModuleCodegenBase {
code_stream_ << "#include <cutlass/util/host_tensor.h>\n";
code_stream_ << "#include <cutlass/util/reference/host/tensor_fill.h>\n";
code_stream_ << "#include <cutlass/gemm/device/gemm.h>\n";
+ code_stream_ << "#include <cutlass/gemm/device/gemm_batched.h>\n";
code_stream_ << "#include
<cutlass/epilogue/thread/linear_combination_bias_relu.h>\n";
code_stream_ << "#include
<cutlass/epilogue/thread/linear_combination_gelu.h>\n";
diff --git a/tests/python/contrib/test_cutlass.py b/tests/python/contrib/test_cutlass.py
index 0927c41..5a1ff8b 100644
--- a/tests/python/contrib/test_cutlass.py
+++ b/tests/python/contrib/test_cutlass.py
@@ -56,14 +56,16 @@ def get_ref_vm(mod, params, target="cuda"):
return VirtualMachine(vm_exec, dev), dev
-def get_output(rt_mod, x):
- rt_mod.set_input("data", x)
+def get_output(rt_mod, names, inputs):
+ for name, inp in zip(names, inputs):
+ rt_mod.set_input(name, inp)
rt_mod.run()
return rt_mod.get_output(0).asnumpy()
-def get_output_vm(vm, x):
- return vm.invoke("main", data=x).numpy()
+def get_output_vm(vm, names, inputs):
+ params = dict(zip(names, inputs))
+ return vm.invoke("main", **params).numpy()
def get_dense_with_shape(data_shape, weight_shape, out_dtype="float16"):
@@ -98,6 +100,16 @@ def get_dense_bias_gelu(M, N, K, out_dtype="float16"):
return add * bias_add
+def get_batch_matmul_with_shape(x_shape, y_shape, out_dtype="float16"):
+ x = relay.var("x", shape=x_shape, dtype="float16")
+ y = relay.var("y", shape=y_shape, dtype="float16")
+ return relay.nn.batch_matmul(x, y, out_dtype=out_dtype)
+
+
+def get_batch_matmul(batch, M, N, K, out_dtype="float16"):
+ return get_batch_matmul_with_shape((batch, M, K), (batch, N, K),
out_dtype="float16")
+
+
def profile_and_build(mod, params, sm, tmp_dir="./tmp", lib_path="compile.so"):
mod = partition_for_cutlass(mod)
mod, num_cutlass_partition = tune_cutlass_kernels(
@@ -123,7 +135,9 @@ def profile_and_build_vm(
return VirtualMachine(vm_exec, dev), dev, num_cutlass_partition
-def verify(func, M, N, K, ref_target="cuda", sm=80, atol=1e-5, rtol=1e-5,
run_benchmark=False):
+def verify_dense(
+ func, M, N, K, ref_target="cuda", sm=80, atol=1e-5, rtol=1e-5,
run_benchmark=False
+):
if not has_cutlass():
return
mod = tvm.IRModule.from_expr(func)
@@ -151,14 +165,14 @@ def verify(func, M, N, K, ref_target="cuda", sm=80, atol=1e-5, rtol=1e-5, run_be
rt_mod_ref, dev = get_ref_vm(mod, params, target=ref_target)
x = tvm.nd.array(np_data, device=dev)
- out = get_output_vm(rt_mod, x)
- ref_out = get_output_vm(rt_mod_ref, x)
+ out = get_output_vm(rt_mod, ["data"], [x])
+ ref_out = get_output_vm(rt_mod_ref, ["data"], [x])
else:
rt_mod_ref, dev = get_ref_rt_mod(mod, params, target=ref_target)
rt_mod, dev, num_partition = profile_and_build(mod, params, sm)
x = tvm.nd.array(np_data, device=dev)
- out = get_output(rt_mod, x)
- ref_out = get_output(rt_mod_ref, x)
+ out = get_output(rt_mod, ["data"], [x])
+ ref_out = get_output(rt_mod_ref, ["data"], [x])
assert num_partition > 0
np.testing.assert_allclose(out, ref_out, atol=atol, rtol=rtol)
@@ -168,29 +182,65 @@ def verify(func, M, N, K, ref_target="cuda", sm=80, atol=1e-5, rtol=1e-5, run_be
print("TVM with target %s:" % ref_target, rt_mod_ref.benchmark(dev,
number=1, repeat=600))
+def verify_batch_matmul(
+ func, batch, M, N, K, ref_target="cuda", sm=80, atol=1e-5, rtol=1e-5,
run_benchmark=False
+):
+ if not has_cutlass():
+ return
+ mod = tvm.IRModule.from_expr(func)
+ typ = relay.transform.InferType()(mod)["main"].body.checked_type
+ use_vm = any(isinstance(s, tvm.tir.Any) for s in typ.shape)
+ x_np = np.random.uniform(-1, 1, (batch, M, K)).astype("float16")
+ y_np = np.random.uniform(-1, 1, (batch, N, K)).astype("float16")
+
+ if use_vm:
+ rt_mod, dev, num_partition = profile_and_build_vm(mod, {}, sm)
+ rt_mod_ref, dev = get_ref_vm(mod, {}, target=ref_target)
+ assert num_partition > 0
+ x = tvm.nd.array(x_np, device=dev)
+ y = tvm.nd.array(y_np, device=dev)
+ out = get_output_vm(rt_mod, ["x", "y"], [x, y])
+ ref_out = get_output_vm(rt_mod_ref, ["x", "y"], [x, y])
+ else:
+ rt_mod, dev, num_partition = profile_and_build(mod, {}, sm)
+ rt_mod_ref, dev = get_ref_rt_mod(mod, {})
+ assert num_partition > 0
+
+ x = tvm.nd.array(x_np, device=dev)
+ y = tvm.nd.array(y_np, device=dev)
+ out = get_output(rt_mod, ["x", "y"], [x, y])
+ ref_out = get_output(rt_mod_ref, ["x", "y"], [x, y])
+
+ np.testing.assert_allclose(out, ref_out, atol=atol, rtol=rtol)
+
+ if True:
+ print("CUTLASS:", rt_mod.benchmark(dev, number=1, repeat=600))
+ print("TVM Tensorcore (no tuning):", rt_mod_ref.benchmark(dev,
number=1, repeat=600))
+
+
M = 1820
N = 768
K = 768
def test_dense():
- verify(get_dense(M, N, K), M, N, K)
- verify(get_dense(M, N, K, out_dtype="float32"), M, N, K)
+ verify_dense(get_dense(M, N, K), M, N, K)
+ verify_dense(get_dense(M, N, K, out_dtype="float32"), M, N, K)
def test_dense_bias():
- verify(get_dense_bias(M, N, K), M, N, K)
- verify(get_dense_bias(M, N, K, out_dtype="float32"), M, N, K)
+ verify_dense(get_dense_bias(M, N, K), M, N, K)
+ verify_dense(get_dense_bias(M, N, K, out_dtype="float32"), M, N, K)
def test_dense_bias_relu():
- verify(get_dense_bias_relu(M, N, K), M, N, K)
- verify(get_dense_bias_relu(M, N, K, out_dtype="float32"), M, N, K)
+ verify_dense(get_dense_bias_relu(M, N, K), M, N, K)
+ verify_dense(get_dense_bias_relu(M, N, K, out_dtype="float32"), M, N, K)
def test_dense_bias_gelu():
- verify(get_dense_bias_gelu(M, N, K), M, N, K, atol=1e-3, rtol=1e-3)
- verify(get_dense_bias_gelu(M, N, K, out_dtype="float32"), M, N, K,
atol=1e-3, rtol=1e-3)
+ verify_dense(get_dense_bias_gelu(M, N, K), M, N, K, atol=1e-3, rtol=1e-3)
+ verify_dense(get_dense_bias_gelu(M, N, K, out_dtype="float32"), M, N, K,
atol=1e-3, rtol=1e-3)
def test_dense_dynamic():
@@ -200,7 +250,7 @@ def test_dense_dynamic():
if has_cublas():
# TVM native fp16 dense (without tensorcore), using fp16 accum, seems
to have accuracy issues
# Use cublas as a reference
- verify(
+ verify_dense(
get_dense_with_shape(data_shape, weight_shape),
M,
N,
@@ -208,7 +258,7 @@ def test_dense_dynamic():
ref_target="cuda -libs=cublas",
)
- verify(
+ verify_dense(
get_dense_with_shape(data_shape, weight_shape, out_dtype="float32"),
M,
N,
@@ -218,5 +268,26 @@ def test_dense_dynamic():
)
+def test_batch_matmul():
+ batch = 8
+ verify_batch_matmul(get_batch_matmul(batch, M, N, K), batch, M, N, K)
+ verify_batch_matmul(get_batch_matmul(batch, M, N, K, out_dtype="float32"),
batch, M, N, K)
+
+ if has_cublas():
+ # Test dynamic shape batch_matmul
+ # AutoTVM does not seem to support it
+ x_shape = (relay.Any(), relay.Any(), K)
+ y_shape = (relay.Any(), relay.Any(), K)
+
+ verify_batch_matmul(
+ get_batch_matmul_with_shape(x_shape, y_shape),
+ batch,
+ M,
+ N,
+ K,
+ ref_target="cuda -libs=cublas",
+ )
+
+
if __name__ == "__main__":
pytest.main([__file__])