This is an automated email from the ASF dual-hosted git repository.

comaniac pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 6d35f0b  [CUTLASS] Refactor cutlass kernel generation and selection (#9800)
6d35f0b is described below

commit 6d35f0bbdf656d393b0722cb93cd213217781c9d
Author: masahi <[email protected]>
AuthorDate: Fri Dec 31 08:20:03 2021 +0900

    [CUTLASS] Refactor cutlass kernel generation and selection (#9800)
---
 python/tvm/contrib/cutlass/build.py            |  78 +++------
 python/tvm/contrib/cutlass/conv2d_operation.py |   2 +-
 python/tvm/contrib/cutlass/gen_conv2d.py       | 198 ++++++++++++---------
 python/tvm/contrib/cutlass/gen_gemm.py         | 234 ++++++++++++++-----------
 python/tvm/contrib/cutlass/gen_tensor_op.py    |  18 ++
 python/tvm/relay/op/contrib/cutlass.py         |  11 +-
 6 files changed, 301 insertions(+), 240 deletions(-)

diff --git a/python/tvm/contrib/cutlass/build.py b/python/tvm/contrib/cutlass/build.py
index 3bc3b5d..e921302 100644
--- a/python/tvm/contrib/cutlass/build.py
+++ b/python/tvm/contrib/cutlass/build.py
@@ -94,15 +94,17 @@ class OpAnnotator(tvm.relay.ExprVisitor):
 
 
 def select_gemm_kernel(
-    cutlass_profiler, MM, KK, NN, out_dtype, batched, profile_all, 
use_multiprocessing
+    cutlass_profiler, op_type, MM, KK, NN, out_dtype, batched, profile_all, 
use_multiprocessing
 ):
     """Run CUTLASS profiler to select the best kernel, or return the default 
one for dynamic
     workloads."""
     if any(isinstance(s, tvm.tir.Any) for s in [MM, KK, NN]):
-        out = cutlass_profiler.get_default(out_dtype, batched=batched)
-        logger.info("Picked the default kernel %s", out["name"])
+        out = cutlass_profiler.get_default(op_type, out_dtype, batched=batched)
+        name, cutlass_op_def = out["name"], out["opdef"]
+        logger.info("Picked the default kernel %s", name)
     else:
-        out = cutlass_profiler.profile(
+        name, cutlass_op_def, _ = cutlass_profiler.profile(
+            op_type,
             MM,
             NN,
             KK,
@@ -112,10 +114,11 @@ def select_gemm_kernel(
             use_multiprocessing=use_multiprocessing,
         )
         if profile_all:
-            logger.info("The best kernel is %s", out["name"])
+            logger.info("The best kernel is %s", name)
         else:
-            logger.info("Picked the first kernel found %s", out["name"])
-    return out
+            logger.info("Picked the first kernel found %s", name)
+
+    return name, cutlass_op_def
 
 
 def handle_batch_matmul(
@@ -126,24 +129,17 @@ def handle_batch_matmul(
     KK = arg0_shape[2]
     NN = arg1_shape[1]
 
-    out = select_gemm_kernel(
-        cutlass_profiler, MM, KK, NN, out_dtype, True, profile_all, 
use_multiprocessing
+    name, cutlass_op_def = select_gemm_kernel(
+        cutlass_profiler, op_type, MM, KK, NN, out_dtype, True, profile_all, 
use_multiprocessing
     )
 
-    if op_type == "cutlass.batch_matmul":
-        cutlass_op_def = out["opdef"]
-    else:
-        raise ValueError("%s pattern is not implemented." % op_type)
-
-    assert "tn_align" in out["name"], "Only supports (row_major, col_major) 
input layout for now."
-
     return {
         "batch": arg0_shape[0],
         "batch_stride_A": arg0_shape[1] * arg0_shape[2],
         "batch_stride_B": arg1_shape[1] * arg1_shape[2],
         "batch_stride_C": arg0_shape[1] * arg1_shape[1],
         "cutlass_op_def": cutlass_op_def,
-        "cutlass_op_name": out["name"],
+        "cutlass_op_name": name,
         "lda": "K",
         "ldb": "K",
         "ldc": "N",
@@ -158,26 +154,15 @@ def handle_dense(
     KK = arg0_shape[1]
     NN = arg1_shape[0]
 
-    out = select_gemm_kernel(
-        cutlass_profiler, MM, KK, NN, out_dtype, False, profile_all, 
use_multiprocessing
+    name, cutlass_op_def = select_gemm_kernel(
+        cutlass_profiler, op_type, MM, KK, NN, out_dtype, False, profile_all, 
use_multiprocessing
     )
 
-    if op_type == "cutlass.dense":
-        cutlass_op_def = out["opdef"]
-    elif op_type == "cutlass.dense_bias":
-        cutlass_op_def = out["opdef_bias"]
-    elif op_type == "cutlass.dense_bias_relu":
-        cutlass_op_def = out["opdef_bias_relu"]
-    elif "cutlass.dense_bias_gelu" in op_type:
-        cutlass_op_def = out["opdef_bias_gelu"]
-    else:
-        raise ValueError("%s pattern is not implemented." % op_type)
-
-    assert "tn_align" in out["name"], "Only supports (row_major, col_major) 
input layout for now."
+    assert "tn_align" in name, "Only supports (row_major, col_major) input 
layout for now."
 
     return {
         "cutlass_op_def": cutlass_op_def,
-        "cutlass_op_name": out["name"],
+        "cutlass_op_name": name,
         "lda": "K",
         "ldb": "K",
         "ldc": "N",
@@ -198,10 +183,12 @@ def handle_conv2d(
 ):
     """Profile and select a kernel for conv2d op workload."""
     if any(isinstance(s, tvm.tir.Any) for s in d_shape):
-        out = cutlass_profiler.get_default(out_dtype)
-        logger.info("Picked the default kernel %s", out["name"])
+        out = cutlass_profiler.get_default(op_type, out_dtype)
+        name, cutlass_op_def = out["name"], out["opdef"]
+        logger.info("Picked the default kernel %s", name)
     else:
-        out = cutlass_profiler.profile(
+        name, cutlass_op_def, _ = cutlass_profiler.profile(
+            op_type,
             d_shape,
             w_shape,
             padding,
@@ -212,28 +199,13 @@ def handle_conv2d(
             use_multiprocessing=use_multiprocessing,
         )
         if profile_all:
-            logger.info("The best kernel is %s", out["name"])
+            logger.info("The best kernel is %s", name)
         else:
-            logger.info("Picked the first kernel found %s", out["name"])
-
-    if op_type == "cutlass.conv2d":
-        cutlass_op_def = out["opdef"]
-    elif op_type == "cutlass.conv2d_bias":
-        cutlass_op_def = out["opdef_bias"]
-    elif op_type == "cutlass.conv2d_bias_relu":
-        cutlass_op_def = out["opdef_bias_relu"]
-    elif op_type == "cutlass.conv2d_bias_sigmoid":
-        cutlass_op_def = out["opdef_bias_sigmoid"]
-    elif op_type == "cutlass.conv2d_bias_silu":
-        cutlass_op_def = out["opdef_bias_silu"]
-    elif op_type == "cutlass.conv2d_bias_hardswish":
-        cutlass_op_def = out["opdef_bias_hardswish"]
-    else:
-        raise ValueError("%s pattern is not implemented." % op_type)
+            logger.info("Picked the first kernel found %s", name)
 
     return {
         "cutlass_op_def": cutlass_op_def,
-        "cutlass_op_name": out["name"],
+        "cutlass_op_name": name,
     }
 
 
diff --git a/python/tvm/contrib/cutlass/conv2d_operation.py b/python/tvm/contrib/cutlass/conv2d_operation.py
index 3530892..1c7f9a3 100644
--- a/python/tvm/contrib/cutlass/conv2d_operation.py
+++ b/python/tvm/contrib/cutlass/conv2d_operation.py
@@ -186,7 +186,7 @@ class EmitConv2dInstance:
   >::Kernel;
 """
 
-    def emit(self, operation, no_beta_scaling=True):
+    def emit(self, operation, no_beta_scaling=False):
         """Instantiate a Conv2d kernel from given `operation`."""
         warp_shape = [
             int(
diff --git a/python/tvm/contrib/cutlass/gen_conv2d.py b/python/tvm/contrib/cutlass/gen_conv2d.py
index 43317f9..4e4a7b2 100644
--- a/python/tvm/contrib/cutlass/gen_conv2d.py
+++ b/python/tvm/contrib/cutlass/gen_conv2d.py
@@ -20,10 +20,7 @@ import re
 from .conv2d_operation import Conv2dOperation, EmitConv2dInstance
 from .gen_gemm import CutlassGemmProfiler
 from .conv2d_profiler import Conv2dProfilerEmitter
-from .gen_tensor_op import (
-    ProfilerEngine,
-    GENERATOR_FUNC_TABLE,
-)
+from .gen_tensor_op import ProfilerEngine, GENERATOR_FUNC_TABLE, EPILOGUE_MAP
 from .library import (
     EpilogueFunctor,
     SwizzlingFunctor,
@@ -35,7 +32,42 @@ from .library import (
 )
 
 
-def create_conv2d_operator(
+def create_conv2d_operator_with_epilogue(
+    op_type, tile_description, data_type, alignment, swizzling_functor
+):
+    """
+    Instantiate a cutlass kernel from the given configuration,
+    along with the epilouge functor
+    """
+    epilogue, no_beta_scaling = EPILOGUE_MAP[op_type]
+
+    element_a, element_b, element_c, element_epilogue = data_type
+
+    A = TensorDescription(element_a, LayoutType.TensorNHWC, alignment)
+    B = TensorDescription(element_b, LayoutType.TensorNHWC, alignment)
+    C = TensorDescription(element_c, LayoutType.TensorNHWC, alignment)
+
+    op = Conv2dOperation(
+        ConvKind.Fprop,
+        IteratorAlgorithm.Optimized,
+        tile_description.minimum_compute_capability,
+        tile_description,
+        A,
+        B,
+        C,
+        element_epilogue,
+        StrideSupport.Strided,
+        epilogue,
+        swizzling_functor,
+    )
+
+    name = op.procedural_name()
+    opdef = EmitConv2dInstance().emit(op, no_beta_scaling=no_beta_scaling)
+
+    return name, opdef
+
+
+def enumerate_conv2d_operators(
     tile_descriptions,
     data_type,
     alignment_constraints,
@@ -48,77 +80,38 @@ def create_conv2d_operator(
     profiler_emitter = Conv2dProfilerEmitter()
 
     element_a, element_b, element_c, element_epilogue = data_type
-    iterator_algorithms = [IteratorAlgorithm.Optimized]
 
-    layout = (LayoutType.TensorNHWC, LayoutType.TensorNHWC, 
LayoutType.TensorNHWC)
     for tile in tile_descriptions:
         for alignment in alignment_constraints:
 
-            alignment_c = min(8, alignment)
-
-            A = TensorDescription(element_a, layout[0], alignment)
-            B = TensorDescription(element_b, layout[1], alignment)
-            C = TensorDescription(element_c, layout[2], alignment_c)
-
-            swizzling_functor_ = swizzling_functor
-
-            for iterator_algorithm in iterator_algorithms:
-                op_entry = {}
-
-                op = Conv2dOperation(
-                    ConvKind.Fprop,
-                    iterator_algorithm,
-                    tile.minimum_compute_capability,
-                    tile,
-                    A,
-                    B,
-                    C,
-                    element_epilogue,
-                    StrideSupport.Strided,
-                    EpilogueFunctor.LinearCombination,
-                    swizzling_functor_,
-                )
-
-                op_entry["opdef"] = kernel_emitter.emit(op)
-                op_entry["op"] = op
-                op_entry["src"] = profiler_emitter.emit(op_entry["opdef"], 
op.procedural_name())
-                op_entry["name"] = op.procedural_name()
-
-                # fused ops
-                for epilogue, opdef, no_bias_scaling in zip(
-                    [
-                        EpilogueFunctor.LinearCombinationBias,
-                        EpilogueFunctor.LinearCombinationRelu,
-                        EpilogueFunctor.LinearCombinationSigmoid,
-                        EpilogueFunctor.LinearCombinationSilu,
-                        EpilogueFunctor.LinearCombinationHardSwish,
-                    ],
-                    [
-                        "opdef_bias",
-                        "opdef_bias_relu",
-                        "opdef_bias_sigmoid",
-                        "opdef_bias_silu",
-                        "opdef_bias_hardswish",
-                    ],
-                    [True, True, False, False, False],
-                ):
-                    op = Conv2dOperation(
-                        ConvKind.Fprop,
-                        iterator_algorithm,
-                        tile.minimum_compute_capability,
-                        tile,
-                        A,
-                        B,
-                        C,
-                        element_epilogue,
-                        StrideSupport.Strided,
-                        epilogue,
-                        swizzling_functor_,
-                    )
-
-                    op_entry[opdef] = kernel_emitter.emit(op, no_bias_scaling)
-
-                ret.append(op_entry)
+            A = TensorDescription(element_a, LayoutType.TensorNHWC, alignment)
+            B = TensorDescription(element_b, LayoutType.TensorNHWC, alignment)
+            C = TensorDescription(element_c, LayoutType.TensorNHWC, alignment)
+
+            op = Conv2dOperation(
+                ConvKind.Fprop,
+                IteratorAlgorithm.Optimized,
+                tile.minimum_compute_capability,
+                tile,
+                A,
+                B,
+                C,
+                element_epilogue,
+                StrideSupport.Strided,
+                EpilogueFunctor.LinearCombination,
+                swizzling_functor,
+            )
+
+            ret.append(
+                {
+                    "src": profiler_emitter.emit(kernel_emitter.emit(op), 
op.procedural_name()),
+                    "name": op.procedural_name(),
+                    "tile_description": tile,
+                    "alignment": alignment,
+                    "data_type": data_type,
+                    "swizzle_functor": swizzling_functor,
+                }
+            )
 
     return ret
 
@@ -133,12 +126,15 @@ class CutlassConv2DProfiler:
         self.engine = ProfilerEngine(sm, cutlass_path, binary_path)
         self.cache = {}
 
-    def get_default(self, out_dtype):
-        gemm_profile_result = self.gemm_profiler.get_default(out_dtype)
+    def get_default(self, op_type, out_dtype):
+        gemm_profile_result = self.gemm_profiler.get_default(op_type, 
out_dtype)
         tile_description = gemm_profile_result["tile_description"]
         alignment = gemm_profile_result["alignment"]
         data_type = gemm_profile_result["data_type"]
-        return create_conv2d_operator([tile_description], data_type, 
[alignment])[0]
+        name, opdef = create_conv2d_operator_with_epilogue(
+            op_type, tile_description, data_type, alignment, 
SwizzlingFunctor.Identity4
+        )
+        return {"name": name, "opdef": opdef}
 
     def check_align(self, op_name, C, K):
         """Filter out kernels that cannot be supported."""
@@ -147,7 +143,7 @@ class CutlassConv2DProfiler:
         align = int(aligns[0][-1])
         return all([dim % align == 0 for dim in [C, K]])
 
-    def profile(
+    def select_op(
         self,
         d_shape,
         w_shape,
@@ -158,9 +154,9 @@ class CutlassConv2DProfiler:
         profile_all=True,
         use_multiprocessing=False,
     ):
-        """Profile and select the best kernel from candidate kernels.
-        If profile_all is False, return immediately after the first applicable 
kernel is found.
-        If use_multiprocessing is True, compile all profiler executables in 
parallel.
+        """
+        Profile and select the best kernel from candidate kernels.
+        See the documentation for the profile method below.
         """
         N, H, W, IC = d_shape
         OC, R, S, _ = w_shape
@@ -183,7 +179,10 @@ class CutlassConv2DProfiler:
         if workload in self.cache:
             return self.cache[workload]
 
-        ops = GENERATOR_FUNC_TABLE[self.sm](out_dtype, 
op_creator=create_conv2d_operator)
+        ops = GENERATOR_FUNC_TABLE[self.sm](
+            out_dtype,
+            op_creator=enumerate_conv2d_operators,
+        )
         ops = list(filter(lambda op: self.check_align(op["name"], IC, OC), 
ops))
 
         if profile_all:
@@ -201,6 +200,39 @@ class CutlassConv2DProfiler:
                 self.cache[workload] = op
                 return op
 
-        output = min(ops, key=lambda i: i["runtime"])
-        self.cache[workload] = output
-        return output
+        op = min(ops, key=lambda i: i["runtime"])
+        self.cache[workload] = op
+        return op
+
+    def profile(
+        self,
+        op_type,
+        d_shape,
+        w_shape,
+        padding,
+        stride,
+        dilation,
+        out_dtype,
+        profile_all=True,
+        use_multiprocessing=False,
+    ):
+        """Profile and select the best kernel from candidate kernels.
+        If profile_all is False, return immediately after the first applicable 
kernel is found.
+        If use_multiprocessing is True, compile all profiler executables in 
parallel.
+        """
+        op = self.select_op(
+            d_shape,
+            w_shape,
+            padding,
+            stride,
+            dilation,
+            out_dtype,
+            profile_all=profile_all,
+            use_multiprocessing=use_multiprocessing,
+        )
+
+        name, opdef = create_conv2d_operator_with_epilogue(
+            op_type, op["tile_description"], op["data_type"], op["alignment"], 
op["swizzle_functor"]
+        )
+
+        return name, opdef, op["runtime"]
diff --git a/python/tvm/contrib/cutlass/gen_gemm.py b/python/tvm/contrib/cutlass/gen_gemm.py
index 7048c32..9159ed8 100644
--- a/python/tvm/contrib/cutlass/gen_gemm.py
+++ b/python/tvm/contrib/cutlass/gen_gemm.py
@@ -16,14 +16,10 @@
 # under the License.
 # pylint: disable=invalid-name
 """GEMM kernel generator and profiler for CUTLASS."""
-from functools import partial
 import re
 from .gemm_operation import GemmOperation, EmitGemmInstance
 from .gemm_profiler import GemmProfilerEmitter
-from .gen_tensor_op import (
-    ProfilerEngine,
-    GENERATOR_FUNC_TABLE,
-)
+from .gen_tensor_op import ProfilerEngine, GENERATOR_FUNC_TABLE, EPILOGUE_MAP
 from .library import (
     EpilogueFunctor,
     SwizzlingFunctor,
@@ -33,12 +29,50 @@ from .library import (
 )
 
 
-def create_gemm_operator(
+def create_gemm_operator_with_epilogue(
+    op_type,
+    tile_description,
+    data_type,
+    alignment,
+    swizzling_functor,
+    batched=False,
+):
+    """
+    Instantiate a cutlass kernel from the given configuration,
+    along with the epilouge functor
+    """
+    element_a, element_b, element_c, element_epilogue = data_type
+
+    A = TensorDescription(element_a, LayoutType.RowMajor, alignment)
+    B = TensorDescription(element_b, LayoutType.ColumnMajor, alignment)
+    C = TensorDescription(element_c, LayoutType.RowMajor, alignment)
+
+    if batched:
+        swizzling_functor = SwizzlingFunctor.Batched
+
+    epilogue, no_beta_scaling = EPILOGUE_MAP[op_type]
+
+    op = GemmOperation(
+        tile_description.minimum_compute_capability,
+        tile_description,
+        A,
+        B,
+        C,
+        element_epilogue,
+        epilogue,
+        swizzling_functor,
+    )
+
+    return op.procedural_name(), EmitGemmInstance().emit(
+        op, no_beta_scaling=no_beta_scaling, batched=batched
+    )
+
+
+def enumerate_gemm_operators(
     tile_descriptions,
     data_type,
     alignment_constraints,
     swizzling_functor=SwizzlingFunctor.Identity8,
-    batched=False,
 ):
     """Exhaustively instantiate all kernels from a given configuration."""
     ret = []
@@ -47,86 +81,44 @@ def create_gemm_operator(
 
     element_a, element_b, element_c, element_epilogue = data_type
 
-    if batched:
-        swizzling_functor = SwizzlingFunctor.Batched
+    for tile_description in tile_descriptions:
+        for alignment in alignment_constraints:
+            A = TensorDescription(element_a, LayoutType.RowMajor, alignment)
+            B = TensorDescription(element_b, LayoutType.ColumnMajor, alignment)
+            C = TensorDescription(element_c, LayoutType.RowMajor, alignment)
+
+            op = GemmOperation(
+                tile_description.minimum_compute_capability,
+                tile_description,
+                A,
+                B,
+                C,
+                element_epilogue,
+                EpilogueFunctor.LinearCombination,
+                swizzling_functor,
+            )
+
+            src = profiler_emitter.emit(
+                op.procedural_name(),
+                kernel_emitter.emit(op, batched=False),
+                DataTypeTag[element_a],
+                DataTypeTag[element_b],
+                DataTypeTag[element_c],
+                op.leading_dim(),
+            )
+
+            ret.append(
+                {
+                    "src": src,
+                    "op": op,
+                    "name": op.procedural_name(),
+                    "tile_description": tile_description,
+                    "alignment": alignment,
+                    "data_type": data_type,
+                    "swizzle_functor": swizzling_functor,
+                }
+            )
 
-    layouts = [
-        (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.RowMajor),
-    ]
-
-    for layout in layouts:
-        for tile_description in tile_descriptions:
-            for alignment in alignment_constraints:
-                alignment_c = min(8, alignment)
-
-                A = TensorDescription(element_a, layout[0], alignment)
-                B = TensorDescription(element_b, layout[1], alignment)
-                C = TensorDescription(element_c, layout[2], alignment_c)
-
-                op_entry = {}
-                op = GemmOperation(
-                    tile_description.minimum_compute_capability,
-                    tile_description,
-                    A,
-                    B,
-                    C,
-                    element_epilogue,
-                    EpilogueFunctor.LinearCombination,
-                    swizzling_functor,
-                )
-                op_bias = GemmOperation(
-                    tile_description.minimum_compute_capability,
-                    tile_description,
-                    A,
-                    B,
-                    C,
-                    element_epilogue,
-                    EpilogueFunctor.LinearCombinationBias,
-                    swizzling_functor,
-                )
-                op_bias_relu = GemmOperation(
-                    tile_description.minimum_compute_capability,
-                    tile_description,
-                    A,
-                    B,
-                    C,
-                    element_epilogue,
-                    EpilogueFunctor.LinearCombinationRelu,
-                    swizzling_functor,
-                )
-                op_bias_gelu = GemmOperation(
-                    tile_description.minimum_compute_capability,
-                    tile_description,
-                    A,
-                    B,
-                    C,
-                    element_epilogue,
-                    EpilogueFunctor.LinearCombinationGelu,
-                    swizzling_functor,
-                )
-
-                op_entry["op"] = op
-                op_entry["name"] = op.procedural_name()
-                op_entry["opdef"] = kernel_emitter.emit(op, batched=batched)
-                op_entry["opdef_bias"] = kernel_emitter.emit(
-                    op_bias, no_beta_scaling=True, batched=batched
-                )
-                op_entry["opdef_bias_relu"] = kernel_emitter.emit(
-                    op_bias_relu, no_beta_scaling=True, batched=batched
-                )
-                op_entry["opdef_bias_gelu"] = 
kernel_emitter.emit(op_bias_gelu, batched=batched)
-                op_entry["src"] = profiler_emitter.emit(
-                    op.procedural_name(),
-                    kernel_emitter.emit(op, batched=False),
-                    DataTypeTag[element_a],
-                    DataTypeTag[element_b],
-                    DataTypeTag[element_c],
-                    op.leading_dim(),
-                )
-                op_entry["tile_description"] = tile_description
-                op_entry["alignment"] = alignment
-                op_entry["data_type"] = data_type
-                ret.append(op_entry)
     return ret
 
 
@@ -164,30 +156,38 @@ class CutlassGemmProfiler:
         # When the above issue is resolved, we can remove the alignment check 
on M below.
         return all([dim % align == 0 for dim in [M, N, K]])
 
-    def get_default(self, out_dtype, batched=False):
+    def get_default(self, op_type, out_dtype, batched=False):
         """Return the default kernel for the requested architecture.
         For now, the default kernel was picked arbitrary.
         """
-        ops = GENERATOR_FUNC_TABLE[self.sm](
-            out_dtype, op_creator=partial(create_gemm_operator, 
batched=batched)
-        )
+        ops = GENERATOR_FUNC_TABLE[self.sm](out_dtype, 
op_creator=enumerate_gemm_operators)
         default_kernel_name = DEFAULT_KERNELS[self.sm][out_dtype]
         filtered = list(filter(lambda op: op["name"] == default_kernel_name, 
ops))
         assert len(filtered) == 1
-        return filtered[0]
+        op = filtered[0]
+        name, opdef = create_gemm_operator_with_epilogue(
+            op_type,
+            op["tile_description"],
+            op["data_type"],
+            op["alignment"],
+            op["swizzle_functor"],
+            batched=batched,
+        )
+        op.update({"name": name, "opdef": opdef})
+        return op
 
-    def profile(
-        self, M, N, K, out_dtype, profile_all=True, use_multiprocessing=False, 
batched=False
-    ):
-        """Profile and select the best kernel from candidate kernels.
-        If profile_all is False, return immediately after the first applicable 
kernel is found.
-        If use_multiprocessing is True, compile all profiler executables in 
parallel.
+    def select_op(self, M, N, K, out_dtype, profile_all=True, 
use_multiprocessing=False):
+        """
+        Profile and select the best kernel from candidate kernels.
+        See the documentation for the profile method below.
         """
         if (M, N, K) in self.cache:
-            return self.cache[(M, N, K)]
+            op = self.cache[(M, N, K)]
+            return op
 
         ops = GENERATOR_FUNC_TABLE[self.sm](
-            out_dtype, op_creator=partial(create_gemm_operator, 
batched=batched)
+            out_dtype,
+            op_creator=enumerate_gemm_operators,
         )
         ops = list(filter(lambda op: self.check_align(op["name"], M, N, K), 
ops))
 
@@ -201,6 +201,36 @@ class CutlassGemmProfiler:
                 self.cache[(M, N, K)] = op
                 return op
 
-        output = min(ops, key=lambda i: i["runtime"])
-        self.cache[(M, N, K)] = output
-        return output
+        op = min(ops, key=lambda i: i["runtime"])
+        self.cache[(M, N, K)] = op
+        return op
+
+    def profile(
+        self,
+        op_type,
+        M,
+        N,
+        K,
+        out_dtype,
+        profile_all=True,
+        use_multiprocessing=False,
+        batched=False,
+    ):
+        """Profile and select the best kernel from candidate kernels.
+        If profile_all is False, return immediately after the first applicable 
kernel is found.
+        If use_multiprocessing is True, compile all profiler executables in 
parallel.
+        """
+        op = self.select_op(
+            M, N, K, out_dtype, profile_all=profile_all, 
use_multiprocessing=use_multiprocessing
+        )
+
+        name, opdef = create_gemm_operator_with_epilogue(
+            op_type,
+            op["tile_description"],
+            op["data_type"],
+            op["alignment"],
+            op["swizzle_functor"],
+            batched=batched,
+        )
+
+        return name, opdef, op["runtime"]
diff --git a/python/tvm/contrib/cutlass/gen_tensor_op.py b/python/tvm/contrib/cutlass/gen_tensor_op.py
index 9ccde37..6632b15 100644
--- a/python/tvm/contrib/cutlass/gen_tensor_op.py
+++ b/python/tvm/contrib/cutlass/gen_tensor_op.py
@@ -27,6 +27,7 @@ from .library import (
     OpcodeClass,
     MathOperation,
     TileDescription,
+    EpilogueFunctor,
 )
 
 logger = logging.getLogger("cutlass")
@@ -165,6 +166,23 @@ GENERATOR_FUNC_TABLE = {
 }
 
 
+# (Epilogue functor name, no_beta_scaling)
+EPILOGUE_MAP = {
+    "cutlass.dense": (EpilogueFunctor.LinearCombination, False),
+    "cutlass.dense_bias": (EpilogueFunctor.LinearCombinationBias, True),
+    "cutlass.dense_bias_relu": (EpilogueFunctor.LinearCombinationRelu, True),
+    "cutlass.dense_bias_gelu_fp16": (EpilogueFunctor.LinearCombinationGelu, 
False),
+    "cutlass.dense_bias_gelu_fp32": (EpilogueFunctor.LinearCombinationGelu, 
False),
+    "cutlass.batch_matmul": (EpilogueFunctor.LinearCombination, False),
+    "cutlass.conv2d_bias_hardswish": 
(EpilogueFunctor.LinearCombinationHardSwish, False),
+    "cutlass.conv2d_bias_silu": (EpilogueFunctor.LinearCombinationSilu, False),
+    "cutlass.conv2d_bias_sigmoid": (EpilogueFunctor.LinearCombinationSigmoid, 
False),
+    "cutlass.conv2d_bias_relu": (EpilogueFunctor.LinearCombinationRelu, True),
+    "cutlass.conv2d_bias": (EpilogueFunctor.LinearCombinationBias, True),
+    "cutlass.conv2d": (EpilogueFunctor.LinearCombination, False),
+}
+
+
 class ProfilerEngine:
     """Compile and run a given profiler executable."""
 
diff --git a/python/tvm/relay/op/contrib/cutlass.py b/python/tvm/relay/op/contrib/cutlass.py
index eb36dc2..cbbc45a 100644
--- a/python/tvm/relay/op/contrib/cutlass.py
+++ b/python/tvm/relay/op/contrib/cutlass.py
@@ -16,6 +16,7 @@
 # under the License.
 # pylint: disable=invalid-name
 """Patterns supported CUTLASS."""
+from tvm import relay
 from tvm.ir.transform import Sequential, PassContext
 from tvm.relay import transform
 from tvm.relay.build_module import bind_params_by_name
@@ -95,6 +96,8 @@ def check_dtype(lhs, rhs):
 
 
 def get_root_call(call, root_op_name):
+    if not isinstance(call, relay.Call):
+        return None
     if str(call.op) == root_op_name:
         return call
     return get_root_call(call.args[0], root_op_name)
@@ -151,13 +154,17 @@ def partition_for_cutlass(mod, params=None):
         make_gemm_pattern(True, "gelu", out_dtype="float32"),
         check_gemm,
     )
-    cutlass_patterns = [
+
+    dense_patterns = [
         dense_bias_gelu_fp16_pat,
         dense_bias_gelu_fp32_pat,
         dense_bias_relu_pat,
         dense_bias_pat,
         dense_pat,
         ("cutlass.batch_matmul", make_batch_matmul_pattern(), 
check_batch_matmul),
+    ]
+
+    conv2d_patterns = [
         (
             "cutlass.conv2d_bias_hardswish",
             make_conv2d_pattern(with_bias=True, with_act="hardswish"),
@@ -182,6 +189,8 @@ def partition_for_cutlass(mod, params=None):
         ("cutlass.conv2d", make_conv2d_pattern(), check_conv2d),
     ]
 
+    cutlass_patterns = dense_patterns + conv2d_patterns
+
     if params is not None:
         mod["main"] = bind_params_by_name(mod["main"], params)
         remove_bn_pass = Sequential(

Reply via email to