This is an automated email from the ASF dual-hosted git repository.
comaniac pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 6d35f0b [CUTLASS] Refactor cutlass kernel generation and selection
(#9800)
6d35f0b is described below
commit 6d35f0bbdf656d393b0722cb93cd213217781c9d
Author: masahi <[email protected]>
AuthorDate: Fri Dec 31 08:20:03 2021 +0900
[CUTLASS] Refactor cutlass kernel generation and selection (#9800)
---
python/tvm/contrib/cutlass/build.py | 78 +++------
python/tvm/contrib/cutlass/conv2d_operation.py | 2 +-
python/tvm/contrib/cutlass/gen_conv2d.py | 198 ++++++++++++---------
python/tvm/contrib/cutlass/gen_gemm.py | 234 ++++++++++++++-----------
python/tvm/contrib/cutlass/gen_tensor_op.py | 18 ++
python/tvm/relay/op/contrib/cutlass.py | 11 +-
6 files changed, 301 insertions(+), 240 deletions(-)
diff --git a/python/tvm/contrib/cutlass/build.py
b/python/tvm/contrib/cutlass/build.py
index 3bc3b5d..e921302 100644
--- a/python/tvm/contrib/cutlass/build.py
+++ b/python/tvm/contrib/cutlass/build.py
@@ -94,15 +94,17 @@ class OpAnnotator(tvm.relay.ExprVisitor):
def select_gemm_kernel(
- cutlass_profiler, MM, KK, NN, out_dtype, batched, profile_all,
use_multiprocessing
+ cutlass_profiler, op_type, MM, KK, NN, out_dtype, batched, profile_all,
use_multiprocessing
):
"""Run CUTLASS profiler to select the best kernel, or return the default
one for dynamic
workloads."""
if any(isinstance(s, tvm.tir.Any) for s in [MM, KK, NN]):
- out = cutlass_profiler.get_default(out_dtype, batched=batched)
- logger.info("Picked the default kernel %s", out["name"])
+ out = cutlass_profiler.get_default(op_type, out_dtype, batched=batched)
+ name, cutlass_op_def = out["name"], out["opdef"]
+ logger.info("Picked the default kernel %s", name)
else:
- out = cutlass_profiler.profile(
+ name, cutlass_op_def, _ = cutlass_profiler.profile(
+ op_type,
MM,
NN,
KK,
@@ -112,10 +114,11 @@ def select_gemm_kernel(
use_multiprocessing=use_multiprocessing,
)
if profile_all:
- logger.info("The best kernel is %s", out["name"])
+ logger.info("The best kernel is %s", name)
else:
- logger.info("Picked the first kernel found %s", out["name"])
- return out
+ logger.info("Picked the first kernel found %s", name)
+
+ return name, cutlass_op_def
def handle_batch_matmul(
@@ -126,24 +129,17 @@ def handle_batch_matmul(
KK = arg0_shape[2]
NN = arg1_shape[1]
- out = select_gemm_kernel(
- cutlass_profiler, MM, KK, NN, out_dtype, True, profile_all,
use_multiprocessing
+ name, cutlass_op_def = select_gemm_kernel(
+ cutlass_profiler, op_type, MM, KK, NN, out_dtype, True, profile_all,
use_multiprocessing
)
- if op_type == "cutlass.batch_matmul":
- cutlass_op_def = out["opdef"]
- else:
- raise ValueError("%s pattern is not implemented." % op_type)
-
- assert "tn_align" in out["name"], "Only supports (row_major, col_major)
input layout for now."
-
return {
"batch": arg0_shape[0],
"batch_stride_A": arg0_shape[1] * arg0_shape[2],
"batch_stride_B": arg1_shape[1] * arg1_shape[2],
"batch_stride_C": arg0_shape[1] * arg1_shape[1],
"cutlass_op_def": cutlass_op_def,
- "cutlass_op_name": out["name"],
+ "cutlass_op_name": name,
"lda": "K",
"ldb": "K",
"ldc": "N",
@@ -158,26 +154,15 @@ def handle_dense(
KK = arg0_shape[1]
NN = arg1_shape[0]
- out = select_gemm_kernel(
- cutlass_profiler, MM, KK, NN, out_dtype, False, profile_all,
use_multiprocessing
+ name, cutlass_op_def = select_gemm_kernel(
+ cutlass_profiler, op_type, MM, KK, NN, out_dtype, False, profile_all,
use_multiprocessing
)
- if op_type == "cutlass.dense":
- cutlass_op_def = out["opdef"]
- elif op_type == "cutlass.dense_bias":
- cutlass_op_def = out["opdef_bias"]
- elif op_type == "cutlass.dense_bias_relu":
- cutlass_op_def = out["opdef_bias_relu"]
- elif "cutlass.dense_bias_gelu" in op_type:
- cutlass_op_def = out["opdef_bias_gelu"]
- else:
- raise ValueError("%s pattern is not implemented." % op_type)
-
- assert "tn_align" in out["name"], "Only supports (row_major, col_major)
input layout for now."
+ assert "tn_align" in name, "Only supports (row_major, col_major) input
layout for now."
return {
"cutlass_op_def": cutlass_op_def,
- "cutlass_op_name": out["name"],
+ "cutlass_op_name": name,
"lda": "K",
"ldb": "K",
"ldc": "N",
@@ -198,10 +183,12 @@ def handle_conv2d(
):
"""Profile and select a kernel for conv2d op workload."""
if any(isinstance(s, tvm.tir.Any) for s in d_shape):
- out = cutlass_profiler.get_default(out_dtype)
- logger.info("Picked the default kernel %s", out["name"])
+ out = cutlass_profiler.get_default(op_type, out_dtype)
+ name, cutlass_op_def = out["name"], out["opdef"]
+ logger.info("Picked the default kernel %s", name)
else:
- out = cutlass_profiler.profile(
+ name, cutlass_op_def, _ = cutlass_profiler.profile(
+ op_type,
d_shape,
w_shape,
padding,
@@ -212,28 +199,13 @@ def handle_conv2d(
use_multiprocessing=use_multiprocessing,
)
if profile_all:
- logger.info("The best kernel is %s", out["name"])
+ logger.info("The best kernel is %s", name)
else:
- logger.info("Picked the first kernel found %s", out["name"])
-
- if op_type == "cutlass.conv2d":
- cutlass_op_def = out["opdef"]
- elif op_type == "cutlass.conv2d_bias":
- cutlass_op_def = out["opdef_bias"]
- elif op_type == "cutlass.conv2d_bias_relu":
- cutlass_op_def = out["opdef_bias_relu"]
- elif op_type == "cutlass.conv2d_bias_sigmoid":
- cutlass_op_def = out["opdef_bias_sigmoid"]
- elif op_type == "cutlass.conv2d_bias_silu":
- cutlass_op_def = out["opdef_bias_silu"]
- elif op_type == "cutlass.conv2d_bias_hardswish":
- cutlass_op_def = out["opdef_bias_hardswish"]
- else:
- raise ValueError("%s pattern is not implemented." % op_type)
+ logger.info("Picked the first kernel found %s", name)
return {
"cutlass_op_def": cutlass_op_def,
- "cutlass_op_name": out["name"],
+ "cutlass_op_name": name,
}
diff --git a/python/tvm/contrib/cutlass/conv2d_operation.py
b/python/tvm/contrib/cutlass/conv2d_operation.py
index 3530892..1c7f9a3 100644
--- a/python/tvm/contrib/cutlass/conv2d_operation.py
+++ b/python/tvm/contrib/cutlass/conv2d_operation.py
@@ -186,7 +186,7 @@ class EmitConv2dInstance:
>::Kernel;
"""
- def emit(self, operation, no_beta_scaling=True):
+ def emit(self, operation, no_beta_scaling=False):
"""Instantiate a Conv2d kernel from given `operation`."""
warp_shape = [
int(
diff --git a/python/tvm/contrib/cutlass/gen_conv2d.py
b/python/tvm/contrib/cutlass/gen_conv2d.py
index 43317f9..4e4a7b2 100644
--- a/python/tvm/contrib/cutlass/gen_conv2d.py
+++ b/python/tvm/contrib/cutlass/gen_conv2d.py
@@ -20,10 +20,7 @@ import re
from .conv2d_operation import Conv2dOperation, EmitConv2dInstance
from .gen_gemm import CutlassGemmProfiler
from .conv2d_profiler import Conv2dProfilerEmitter
-from .gen_tensor_op import (
- ProfilerEngine,
- GENERATOR_FUNC_TABLE,
-)
+from .gen_tensor_op import ProfilerEngine, GENERATOR_FUNC_TABLE, EPILOGUE_MAP
from .library import (
EpilogueFunctor,
SwizzlingFunctor,
@@ -35,7 +32,42 @@ from .library import (
)
-def create_conv2d_operator(
+def create_conv2d_operator_with_epilogue(
+ op_type, tile_description, data_type, alignment, swizzling_functor
+):
+ """
+ Instantiate a cutlass kernel from the given configuration,
+ along with the epilogue functor
+ """
+ epilogue, no_beta_scaling = EPILOGUE_MAP[op_type]
+
+ element_a, element_b, element_c, element_epilogue = data_type
+
+ A = TensorDescription(element_a, LayoutType.TensorNHWC, alignment)
+ B = TensorDescription(element_b, LayoutType.TensorNHWC, alignment)
+ C = TensorDescription(element_c, LayoutType.TensorNHWC, alignment)
+
+ op = Conv2dOperation(
+ ConvKind.Fprop,
+ IteratorAlgorithm.Optimized,
+ tile_description.minimum_compute_capability,
+ tile_description,
+ A,
+ B,
+ C,
+ element_epilogue,
+ StrideSupport.Strided,
+ epilogue,
+ swizzling_functor,
+ )
+
+ name = op.procedural_name()
+ opdef = EmitConv2dInstance().emit(op, no_beta_scaling=no_beta_scaling)
+
+ return name, opdef
+
+
+def enumerate_conv2d_operators(
tile_descriptions,
data_type,
alignment_constraints,
@@ -48,77 +80,38 @@ def create_conv2d_operator(
profiler_emitter = Conv2dProfilerEmitter()
element_a, element_b, element_c, element_epilogue = data_type
- iterator_algorithms = [IteratorAlgorithm.Optimized]
- layout = (LayoutType.TensorNHWC, LayoutType.TensorNHWC,
LayoutType.TensorNHWC)
for tile in tile_descriptions:
for alignment in alignment_constraints:
- alignment_c = min(8, alignment)
-
- A = TensorDescription(element_a, layout[0], alignment)
- B = TensorDescription(element_b, layout[1], alignment)
- C = TensorDescription(element_c, layout[2], alignment_c)
-
- swizzling_functor_ = swizzling_functor
-
- for iterator_algorithm in iterator_algorithms:
- op_entry = {}
-
- op = Conv2dOperation(
- ConvKind.Fprop,
- iterator_algorithm,
- tile.minimum_compute_capability,
- tile,
- A,
- B,
- C,
- element_epilogue,
- StrideSupport.Strided,
- EpilogueFunctor.LinearCombination,
- swizzling_functor_,
- )
-
- op_entry["opdef"] = kernel_emitter.emit(op)
- op_entry["op"] = op
- op_entry["src"] = profiler_emitter.emit(op_entry["opdef"],
op.procedural_name())
- op_entry["name"] = op.procedural_name()
-
- # fused ops
- for epilogue, opdef, no_bias_scaling in zip(
- [
- EpilogueFunctor.LinearCombinationBias,
- EpilogueFunctor.LinearCombinationRelu,
- EpilogueFunctor.LinearCombinationSigmoid,
- EpilogueFunctor.LinearCombinationSilu,
- EpilogueFunctor.LinearCombinationHardSwish,
- ],
- [
- "opdef_bias",
- "opdef_bias_relu",
- "opdef_bias_sigmoid",
- "opdef_bias_silu",
- "opdef_bias_hardswish",
- ],
- [True, True, False, False, False],
- ):
- op = Conv2dOperation(
- ConvKind.Fprop,
- iterator_algorithm,
- tile.minimum_compute_capability,
- tile,
- A,
- B,
- C,
- element_epilogue,
- StrideSupport.Strided,
- epilogue,
- swizzling_functor_,
- )
-
- op_entry[opdef] = kernel_emitter.emit(op, no_bias_scaling)
-
- ret.append(op_entry)
+ A = TensorDescription(element_a, LayoutType.TensorNHWC, alignment)
+ B = TensorDescription(element_b, LayoutType.TensorNHWC, alignment)
+ C = TensorDescription(element_c, LayoutType.TensorNHWC, alignment)
+
+ op = Conv2dOperation(
+ ConvKind.Fprop,
+ IteratorAlgorithm.Optimized,
+ tile.minimum_compute_capability,
+ tile,
+ A,
+ B,
+ C,
+ element_epilogue,
+ StrideSupport.Strided,
+ EpilogueFunctor.LinearCombination,
+ swizzling_functor,
+ )
+
+ ret.append(
+ {
+ "src": profiler_emitter.emit(kernel_emitter.emit(op),
op.procedural_name()),
+ "name": op.procedural_name(),
+ "tile_description": tile,
+ "alignment": alignment,
+ "data_type": data_type,
+ "swizzle_functor": swizzling_functor,
+ }
+ )
return ret
@@ -133,12 +126,15 @@ class CutlassConv2DProfiler:
self.engine = ProfilerEngine(sm, cutlass_path, binary_path)
self.cache = {}
- def get_default(self, out_dtype):
- gemm_profile_result = self.gemm_profiler.get_default(out_dtype)
+ def get_default(self, op_type, out_dtype):
+ gemm_profile_result = self.gemm_profiler.get_default(op_type,
out_dtype)
tile_description = gemm_profile_result["tile_description"]
alignment = gemm_profile_result["alignment"]
data_type = gemm_profile_result["data_type"]
- return create_conv2d_operator([tile_description], data_type,
[alignment])[0]
+ name, opdef = create_conv2d_operator_with_epilogue(
+ op_type, tile_description, data_type, alignment,
SwizzlingFunctor.Identity4
+ )
+ return {"name": name, "opdef": opdef}
def check_align(self, op_name, C, K):
"""Filter out kernels that cannot be supported."""
@@ -147,7 +143,7 @@ class CutlassConv2DProfiler:
align = int(aligns[0][-1])
return all([dim % align == 0 for dim in [C, K]])
- def profile(
+ def select_op(
self,
d_shape,
w_shape,
@@ -158,9 +154,9 @@ class CutlassConv2DProfiler:
profile_all=True,
use_multiprocessing=False,
):
- """Profile and select the best kernel from candidate kernels.
- If profile_all is False, return immediately after the first applicable
kernel is found.
- If use_multiprocessing is True, compile all profiler executables in
parallel.
+ """
+ Profile and select the best kernel from candidate kernels.
+ See the documentation for the profile method below.
"""
N, H, W, IC = d_shape
OC, R, S, _ = w_shape
@@ -183,7 +179,10 @@ class CutlassConv2DProfiler:
if workload in self.cache:
return self.cache[workload]
- ops = GENERATOR_FUNC_TABLE[self.sm](out_dtype,
op_creator=create_conv2d_operator)
+ ops = GENERATOR_FUNC_TABLE[self.sm](
+ out_dtype,
+ op_creator=enumerate_conv2d_operators,
+ )
ops = list(filter(lambda op: self.check_align(op["name"], IC, OC),
ops))
if profile_all:
@@ -201,6 +200,39 @@ class CutlassConv2DProfiler:
self.cache[workload] = op
return op
- output = min(ops, key=lambda i: i["runtime"])
- self.cache[workload] = output
- return output
+ op = min(ops, key=lambda i: i["runtime"])
+ self.cache[workload] = op
+ return op
+
+ def profile(
+ self,
+ op_type,
+ d_shape,
+ w_shape,
+ padding,
+ stride,
+ dilation,
+ out_dtype,
+ profile_all=True,
+ use_multiprocessing=False,
+ ):
+ """Profile and select the best kernel from candidate kernels.
+ If profile_all is False, return immediately after the first applicable
kernel is found.
+ If use_multiprocessing is True, compile all profiler executables in
parallel.
+ """
+ op = self.select_op(
+ d_shape,
+ w_shape,
+ padding,
+ stride,
+ dilation,
+ out_dtype,
+ profile_all=profile_all,
+ use_multiprocessing=use_multiprocessing,
+ )
+
+ name, opdef = create_conv2d_operator_with_epilogue(
+ op_type, op["tile_description"], op["data_type"], op["alignment"],
op["swizzle_functor"]
+ )
+
+ return name, opdef, op["runtime"]
diff --git a/python/tvm/contrib/cutlass/gen_gemm.py
b/python/tvm/contrib/cutlass/gen_gemm.py
index 7048c32..9159ed8 100644
--- a/python/tvm/contrib/cutlass/gen_gemm.py
+++ b/python/tvm/contrib/cutlass/gen_gemm.py
@@ -16,14 +16,10 @@
# under the License.
# pylint: disable=invalid-name
"""GEMM kernel generator and profiler for CUTLASS."""
-from functools import partial
import re
from .gemm_operation import GemmOperation, EmitGemmInstance
from .gemm_profiler import GemmProfilerEmitter
-from .gen_tensor_op import (
- ProfilerEngine,
- GENERATOR_FUNC_TABLE,
-)
+from .gen_tensor_op import ProfilerEngine, GENERATOR_FUNC_TABLE, EPILOGUE_MAP
from .library import (
EpilogueFunctor,
SwizzlingFunctor,
@@ -33,12 +29,50 @@ from .library import (
)
-def create_gemm_operator(
+def create_gemm_operator_with_epilogue(
+ op_type,
+ tile_description,
+ data_type,
+ alignment,
+ swizzling_functor,
+ batched=False,
+):
+ """
+ Instantiate a cutlass kernel from the given configuration,
+ along with the epilogue functor
+ """
+ element_a, element_b, element_c, element_epilogue = data_type
+
+ A = TensorDescription(element_a, LayoutType.RowMajor, alignment)
+ B = TensorDescription(element_b, LayoutType.ColumnMajor, alignment)
+ C = TensorDescription(element_c, LayoutType.RowMajor, alignment)
+
+ if batched:
+ swizzling_functor = SwizzlingFunctor.Batched
+
+ epilogue, no_beta_scaling = EPILOGUE_MAP[op_type]
+
+ op = GemmOperation(
+ tile_description.minimum_compute_capability,
+ tile_description,
+ A,
+ B,
+ C,
+ element_epilogue,
+ epilogue,
+ swizzling_functor,
+ )
+
+ return op.procedural_name(), EmitGemmInstance().emit(
+ op, no_beta_scaling=no_beta_scaling, batched=batched
+ )
+
+
+def enumerate_gemm_operators(
tile_descriptions,
data_type,
alignment_constraints,
swizzling_functor=SwizzlingFunctor.Identity8,
- batched=False,
):
"""Exhaustively instantiate all kernels from a given configuration."""
ret = []
@@ -47,86 +81,44 @@ def create_gemm_operator(
element_a, element_b, element_c, element_epilogue = data_type
- if batched:
- swizzling_functor = SwizzlingFunctor.Batched
+ for tile_description in tile_descriptions:
+ for alignment in alignment_constraints:
+ A = TensorDescription(element_a, LayoutType.RowMajor, alignment)
+ B = TensorDescription(element_b, LayoutType.ColumnMajor, alignment)
+ C = TensorDescription(element_c, LayoutType.RowMajor, alignment)
+
+ op = GemmOperation(
+ tile_description.minimum_compute_capability,
+ tile_description,
+ A,
+ B,
+ C,
+ element_epilogue,
+ EpilogueFunctor.LinearCombination,
+ swizzling_functor,
+ )
+
+ src = profiler_emitter.emit(
+ op.procedural_name(),
+ kernel_emitter.emit(op, batched=False),
+ DataTypeTag[element_a],
+ DataTypeTag[element_b],
+ DataTypeTag[element_c],
+ op.leading_dim(),
+ )
+
+ ret.append(
+ {
+ "src": src,
+ "op": op,
+ "name": op.procedural_name(),
+ "tile_description": tile_description,
+ "alignment": alignment,
+ "data_type": data_type,
+ "swizzle_functor": swizzling_functor,
+ }
+ )
- layouts = [
- (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.RowMajor),
- ]
-
- for layout in layouts:
- for tile_description in tile_descriptions:
- for alignment in alignment_constraints:
- alignment_c = min(8, alignment)
-
- A = TensorDescription(element_a, layout[0], alignment)
- B = TensorDescription(element_b, layout[1], alignment)
- C = TensorDescription(element_c, layout[2], alignment_c)
-
- op_entry = {}
- op = GemmOperation(
- tile_description.minimum_compute_capability,
- tile_description,
- A,
- B,
- C,
- element_epilogue,
- EpilogueFunctor.LinearCombination,
- swizzling_functor,
- )
- op_bias = GemmOperation(
- tile_description.minimum_compute_capability,
- tile_description,
- A,
- B,
- C,
- element_epilogue,
- EpilogueFunctor.LinearCombinationBias,
- swizzling_functor,
- )
- op_bias_relu = GemmOperation(
- tile_description.minimum_compute_capability,
- tile_description,
- A,
- B,
- C,
- element_epilogue,
- EpilogueFunctor.LinearCombinationRelu,
- swizzling_functor,
- )
- op_bias_gelu = GemmOperation(
- tile_description.minimum_compute_capability,
- tile_description,
- A,
- B,
- C,
- element_epilogue,
- EpilogueFunctor.LinearCombinationGelu,
- swizzling_functor,
- )
-
- op_entry["op"] = op
- op_entry["name"] = op.procedural_name()
- op_entry["opdef"] = kernel_emitter.emit(op, batched=batched)
- op_entry["opdef_bias"] = kernel_emitter.emit(
- op_bias, no_beta_scaling=True, batched=batched
- )
- op_entry["opdef_bias_relu"] = kernel_emitter.emit(
- op_bias_relu, no_beta_scaling=True, batched=batched
- )
- op_entry["opdef_bias_gelu"] =
kernel_emitter.emit(op_bias_gelu, batched=batched)
- op_entry["src"] = profiler_emitter.emit(
- op.procedural_name(),
- kernel_emitter.emit(op, batched=False),
- DataTypeTag[element_a],
- DataTypeTag[element_b],
- DataTypeTag[element_c],
- op.leading_dim(),
- )
- op_entry["tile_description"] = tile_description
- op_entry["alignment"] = alignment
- op_entry["data_type"] = data_type
- ret.append(op_entry)
return ret
@@ -164,30 +156,38 @@ class CutlassGemmProfiler:
# When the above issue is resolved, we can remove the alignment check
on M below.
return all([dim % align == 0 for dim in [M, N, K]])
- def get_default(self, out_dtype, batched=False):
+ def get_default(self, op_type, out_dtype, batched=False):
"""Return the default kernel for the requested architecture.
For now, the default kernel was picked arbitrary.
"""
- ops = GENERATOR_FUNC_TABLE[self.sm](
- out_dtype, op_creator=partial(create_gemm_operator,
batched=batched)
- )
+ ops = GENERATOR_FUNC_TABLE[self.sm](out_dtype,
op_creator=enumerate_gemm_operators)
default_kernel_name = DEFAULT_KERNELS[self.sm][out_dtype]
filtered = list(filter(lambda op: op["name"] == default_kernel_name,
ops))
assert len(filtered) == 1
- return filtered[0]
+ op = filtered[0]
+ name, opdef = create_gemm_operator_with_epilogue(
+ op_type,
+ op["tile_description"],
+ op["data_type"],
+ op["alignment"],
+ op["swizzle_functor"],
+ batched=batched,
+ )
+ op.update({"name": name, "opdef": opdef})
+ return op
- def profile(
- self, M, N, K, out_dtype, profile_all=True, use_multiprocessing=False,
batched=False
- ):
- """Profile and select the best kernel from candidate kernels.
- If profile_all is False, return immediately after the first applicable
kernel is found.
- If use_multiprocessing is True, compile all profiler executables in
parallel.
+ def select_op(self, M, N, K, out_dtype, profile_all=True,
use_multiprocessing=False):
+ """
+ Profile and select the best kernel from candidate kernels.
+ See the documentation for the profile method below.
"""
if (M, N, K) in self.cache:
- return self.cache[(M, N, K)]
+ op = self.cache[(M, N, K)]
+ return op
ops = GENERATOR_FUNC_TABLE[self.sm](
- out_dtype, op_creator=partial(create_gemm_operator,
batched=batched)
+ out_dtype,
+ op_creator=enumerate_gemm_operators,
)
ops = list(filter(lambda op: self.check_align(op["name"], M, N, K),
ops))
@@ -201,6 +201,36 @@ class CutlassGemmProfiler:
self.cache[(M, N, K)] = op
return op
- output = min(ops, key=lambda i: i["runtime"])
- self.cache[(M, N, K)] = output
- return output
+ op = min(ops, key=lambda i: i["runtime"])
+ self.cache[(M, N, K)] = op
+ return op
+
+ def profile(
+ self,
+ op_type,
+ M,
+ N,
+ K,
+ out_dtype,
+ profile_all=True,
+ use_multiprocessing=False,
+ batched=False,
+ ):
+ """Profile and select the best kernel from candidate kernels.
+ If profile_all is False, return immediately after the first applicable
kernel is found.
+ If use_multiprocessing is True, compile all profiler executables in
parallel.
+ """
+ op = self.select_op(
+ M, N, K, out_dtype, profile_all=profile_all,
use_multiprocessing=use_multiprocessing
+ )
+
+ name, opdef = create_gemm_operator_with_epilogue(
+ op_type,
+ op["tile_description"],
+ op["data_type"],
+ op["alignment"],
+ op["swizzle_functor"],
+ batched=batched,
+ )
+
+ return name, opdef, op["runtime"]
diff --git a/python/tvm/contrib/cutlass/gen_tensor_op.py
b/python/tvm/contrib/cutlass/gen_tensor_op.py
index 9ccde37..6632b15 100644
--- a/python/tvm/contrib/cutlass/gen_tensor_op.py
+++ b/python/tvm/contrib/cutlass/gen_tensor_op.py
@@ -27,6 +27,7 @@ from .library import (
OpcodeClass,
MathOperation,
TileDescription,
+ EpilogueFunctor,
)
logger = logging.getLogger("cutlass")
@@ -165,6 +166,23 @@ GENERATOR_FUNC_TABLE = {
}
+# (Epilogue functor name, no_beta_scaling)
+EPILOGUE_MAP = {
+ "cutlass.dense": (EpilogueFunctor.LinearCombination, False),
+ "cutlass.dense_bias": (EpilogueFunctor.LinearCombinationBias, True),
+ "cutlass.dense_bias_relu": (EpilogueFunctor.LinearCombinationRelu, True),
+ "cutlass.dense_bias_gelu_fp16": (EpilogueFunctor.LinearCombinationGelu,
False),
+ "cutlass.dense_bias_gelu_fp32": (EpilogueFunctor.LinearCombinationGelu,
False),
+ "cutlass.batch_matmul": (EpilogueFunctor.LinearCombination, False),
+ "cutlass.conv2d_bias_hardswish":
(EpilogueFunctor.LinearCombinationHardSwish, False),
+ "cutlass.conv2d_bias_silu": (EpilogueFunctor.LinearCombinationSilu, False),
+ "cutlass.conv2d_bias_sigmoid": (EpilogueFunctor.LinearCombinationSigmoid,
False),
+ "cutlass.conv2d_bias_relu": (EpilogueFunctor.LinearCombinationRelu, True),
+ "cutlass.conv2d_bias": (EpilogueFunctor.LinearCombinationBias, True),
+ "cutlass.conv2d": (EpilogueFunctor.LinearCombination, False),
+}
+
+
class ProfilerEngine:
"""Compile and run a given profiler executable."""
diff --git a/python/tvm/relay/op/contrib/cutlass.py
b/python/tvm/relay/op/contrib/cutlass.py
index eb36dc2..cbbc45a 100644
--- a/python/tvm/relay/op/contrib/cutlass.py
+++ b/python/tvm/relay/op/contrib/cutlass.py
@@ -16,6 +16,7 @@
# under the License.
# pylint: disable=invalid-name
"""Patterns supported CUTLASS."""
+from tvm import relay
from tvm.ir.transform import Sequential, PassContext
from tvm.relay import transform
from tvm.relay.build_module import bind_params_by_name
@@ -95,6 +96,8 @@ def check_dtype(lhs, rhs):
def get_root_call(call, root_op_name):
+ if not isinstance(call, relay.Call):
+ return None
if str(call.op) == root_op_name:
return call
return get_root_call(call.args[0], root_op_name)
@@ -151,13 +154,17 @@ def partition_for_cutlass(mod, params=None):
make_gemm_pattern(True, "gelu", out_dtype="float32"),
check_gemm,
)
- cutlass_patterns = [
+
+ dense_patterns = [
dense_bias_gelu_fp16_pat,
dense_bias_gelu_fp32_pat,
dense_bias_relu_pat,
dense_bias_pat,
dense_pat,
("cutlass.batch_matmul", make_batch_matmul_pattern(),
check_batch_matmul),
+ ]
+
+ conv2d_patterns = [
(
"cutlass.conv2d_bias_hardswish",
make_conv2d_pattern(with_bias=True, with_act="hardswish"),
@@ -182,6 +189,8 @@ def partition_for_cutlass(mod, params=None):
("cutlass.conv2d", make_conv2d_pattern(), check_conv2d),
]
+ cutlass_patterns = dense_patterns + conv2d_patterns
+
if params is not None:
mod["main"] = bind_params_by_name(mod["main"], params)
remove_bn_pass = Sequential(