This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 35a7992  [CUTLASS] Add parallel split-k support to wgrad (#10185)
35a7992 is described below

commit 35a7992fe20154b4da57ab1e9d3aa2585e9d9151
Author: Masahiro Masuda <[email protected]>
AuthorDate: Wed Feb 9 01:34:42 2022 +0900

    [CUTLASS] Add parallel split-k support to wgrad (#10185)
    
    * [CUTLASS] Add split-k support to wgrad
    
    commit 60b73a91b79d644d8c95f682eedaf47a89abba0d
    Author: Masahiro Masuda <[email protected]>
    Date:   Tue Feb 8 10:43:11 2022 +0900
    
        pylint
    
    commit ae2e7187256316c48c915c3c187feb5cd4d4dbd4
    Author: Masahiro Masuda <[email protected]>
    Date:   Sun Feb 6 14:51:52 2022 +0900
    
        Add split-k support for wgrad
    
        commit 43820d50055b0bd17b736f5c5830321c7509a20a
        Author: Masahiro Masuda <[email protected]>
        Date:   Sun Feb 6 10:07:34 2022 +0900
    
            fix and add doc
    
        commit 446a95b0aabc5ab69cdd2e414b812aab1c557f42
        Author: Masahiro Masuda <[email protected]>
        Date:   Sun Feb 6 09:48:38 2022 +0900
    
            dw conv2d properly supported for wgrad
    
        commit adc4e22d2e03a99f30ebb6a5e956a1749de693f0
        Author: Masahiro Masuda <[email protected]>
        Date:   Sat Feb 5 16:32:42 2022 +0900
    
            fix overwriting template
    
        commit 040eab000bc5f162c6e9aca70ae6d29378fe65bc
        Author: Masahiro Masuda <[email protected]>
        Date:   Sat Feb 5 16:06:27 2022 +0900
    
            black
    
        commit e5a07c24b7463552b8e545710d25472159bcc127
        Author: Masahiro Masuda <[email protected]>
        Date:   Sat Feb 5 16:03:10 2022 +0900
    
            add reduction in profiler
    
        commit be89334ab981d536d010dd765c9cf601dbdae5e0
        Author: Masahiro Masuda <[email protected]>
        Date:   Sat Feb 5 06:58:03 2022 +0900
    
            adding split k reduction to conv2d profiler
    
        commit ae09b0fbdc3a472eb320d866c054f73b3142f21c
        Author: Masahiro Masuda <[email protected]>
        Date:   Fri Feb 4 11:52:59 2022 +0900
    
            fixed conv2d_backward_weight typerel for dw conv2d
    
            commit 16fe5313fd1219e2e7d531ef9b36f64bb557e5e7
            Author: Masahiro Masuda <[email protected]>
            Date:   Thu Feb 3 12:59:22 2022 +0900
    
                wip
    
            commit 2167c2543340a285bb1985e8fe37e11aed51fb9b
            Author: Masahiro Masuda <[email protected]>
            Date:   Thu Feb 3 04:22:19 2022 +0900
    
                fix conv2d type rel for depth wise and grouped conv2d
    
        commit 14b12e5dd84fc34691d585213387198f091eefc5
        Author: Masahiro Masuda <[email protected]>
        Date:   Fri Feb 4 05:01:03 2022 +0900
    
            remove split_k.py
    
        commit b14127179c43f71c3ce5ccc7b4ca678a099e5497
        Author: Masahiro Masuda <[email protected]>
        Date:   Fri Feb 4 04:48:21 2022 +0900
    
            workaround for invalid split_k_slice
    
        commit 6e4c7e1d77d89f124abc77dbcdab69eff8a5d961
        Author: Masahiro Masuda <[email protected]>
        Date:   Fri Feb 4 02:43:58 2022 +0900
    
            support split k in profiler
    
        commit 2eb1cf43c7f56f0537cf249855054b5cbd357b13
        Author: Masahiro Masuda <[email protected]>
        Date:   Fri Feb 4 02:31:03 2022 +0900
    
            improvement
    
        commit 0bce8f3778a6bb05607232a0997d25681e55ce7c
        Author: Masahiro Masuda <[email protected]>
        Date:   Thu Feb 3 18:20:12 2022 +0900
    
            fixed for fp16 output
    
        commit 30df1bd5282a4d326856382726d4e63ee8c27e8e
        Author: Masahiro Masuda <[email protected]>
        Date:   Thu Feb 3 17:50:33 2022 +0900
    
            fp32 output works
    
        commit 7a519956b8d103464dff83b4f01b75973f4a33b0
        Author: Masahiro Masuda <[email protected]>
        Date:   Thu Feb 3 14:30:22 2022 +0900
    
            fix
    
        commit 4a383e2c7c37148a563e9cf34968fb7da3aaf91f
        Author: Masahiro Masuda <[email protected]>
        Date:   Thu Feb 3 14:05:24 2022 +0900
    
            update c++ codegen
    
        commit 6206e388cc7062cbef0b3c8c47fcd228b44b6818
        Author: Masahiro Masuda <[email protected]>
        Date:   Thu Feb 3 13:46:05 2022 +0900
    
            wip
    
        commit 0ece49b53e773ebc1ea71c7667abc0cbb29d91bf
        Author: Masahiro Masuda <[email protected]>
        Date:   Thu Feb 3 03:05:21 2022 +0900
    
            wip
    
        commit 08a6147940d9911fd65a890a4d90beb68176fc03
        Author: Masahiro Masuda <[email protected]>
        Date:   Wed Feb 2 13:10:21 2022 +0900
    
            test worked with fp32 output
    
        commit 084d5c47666df92ba6c2c1445d5a23de0193a119
        Author: Masahiro Masuda <[email protected]>
        Date:   Wed Feb 2 12:35:18 2022 +0900
    
            fix compile error for fprop
    
        commit 31f25436c5aca1a75336fa1a8d1c8a25a4936ee8
        Author: Masahiro Masuda <[email protected]>
        Date:   Wed Feb 2 12:18:06 2022 +0900
    
            compiled
    
        commit c2098e79ade47117f2c32132da864b1fa73fce4a
        Author: Masahiro Masuda <[email protected]>
        Date:   Wed Feb 2 11:11:43 2022 +0900
    
            wip
    
    commit a14585020151d0e09bb9bac549285dceb13e55e1
    Author: Masahiro Masuda <[email protected]>
    Date:   Sun Feb 6 14:46:16 2022 +0900
    
        fixed for sm75
    
    commit 61515062ef4576bf5b4e7e9e800f7f705738809c
    Author: Masahiro Masuda <[email protected]>
    Date:   Sun Feb 6 14:32:46 2022 +0900
    
        all tests work
    
    commit 041c094b3646e0f521f5bd2c4f6f6b5b1cff7b97
    Author: Masahiro Masuda <[email protected]>
    Date:   Sun Feb 6 14:19:09 2022 +0900
    
        dw conv2d properly supported for wgrad
    
    commit 2191918743a4e9ffb8254f3786d817be57ff49cc
    Author: Masahiro Masuda <[email protected]>
    Date:   Wed Feb 2 09:14:05 2022 +0900
    
        wgrad tests now work under pytest
    
    commit 78f76df1eb1602f66cacb888a97b6b267f8600a7
    Author: Masahiro Masuda <[email protected]>
    Date:   Wed Feb 2 07:31:54 2022 +0900
    
        run black
    
    commit 0a82149fe0586b0bf449fc7f3a1fa9809e9b38d2
    Author: Masahiro Masuda <[email protected]>
    Date:   Wed Feb 2 06:12:39 2022 +0900
    
        [CUTLASS] Add wgrad support (without split-k)
    
    * pylint
    
    * add more doc
    
    * more doc clarification
---
 python/tvm/contrib/cutlass/build.py            |  14 ++-
 python/tvm/contrib/cutlass/conv2d_operation.py |  82 ++++++++++++++--
 python/tvm/contrib/cutlass/conv2d_profiler.py  |  58 +++++++++--
 python/tvm/contrib/cutlass/gen_conv2d.py       |  87 +++++++++-------
 python/tvm/contrib/cutlass/gen_tensor_op.py    |   3 +
 src/relay/backend/contrib/cutlass/codegen.cc   | 131 +++++++++++++++++++++----
 tests/python/contrib/test_cutlass.py           |  48 ++++++---
 7 files changed, 340 insertions(+), 83 deletions(-)

diff --git a/python/tvm/contrib/cutlass/build.py 
b/python/tvm/contrib/cutlass/build.py
index 918eeaf..06c33f2 100644
--- a/python/tvm/contrib/cutlass/build.py
+++ b/python/tvm/contrib/cutlass/build.py
@@ -14,7 +14,7 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-# pylint: disable=invalid-name
+# pylint: disable=invalid-name, dangerous-default-value
 """Driver for partitioning and building a Relay module for CUTLASS offload."""
 import logging
 import os
@@ -238,6 +238,7 @@ def handle_conv2d(
     data_dtype,
     weight_dtype,
     use_3xtf32,
+    split_k_slices,
     profile_all_alignments,
     find_first_valid,
     use_multiprocessing,
@@ -269,6 +270,7 @@ def handle_conv2d(
             weight_dtype,
             use_3xtf32,
             conv_kind,
+            split_k_slices,
             profile_all_alignments,
             find_first_valid=find_first_valid,
             use_multiprocessing=use_multiprocessing,
@@ -288,6 +290,7 @@ def tune_cutlass_kernels(
     mod,
     sm,
     use_3xtf32=True,
+    split_k_slices=[1],
     profile_all_alignments=False,
     find_first_valid=False,
     use_multiprocessing=False,
@@ -309,6 +312,14 @@ def tune_cutlass_kernels(
         Wheter or not use slower but very accurate (compared to tf32) 3xtf32 
mode for
         fp32 inputs on tensorcore.
 
+    split_k_slices : list of int
+        Split factor candidates for split-K GEMM. If split-K > 1, the GEMM 
K-loop is computed in
+        parallel across split-K blocks, and a separate global reduction 
kernel is launched to
+        accumulate partial reductions. The profiler will pick the best split-k 
factor from the
+        given candidate list. Note that a larger split-K factor requires a 
larger workspace.
+        Currently, parallel split-k has been tested only for wgrad. For GEMM 
and other conv2d
+        kinds, split_k_slices is ignored.
+
     profile_all_alignments : bool
         When True, profile all kernal variants with smaller alignments than 
the largest possible.
 
@@ -380,6 +391,7 @@ def tune_cutlass_kernels(
                         arg0_dtype,
                         arg1_dtype,
                         use_3xtf32,
+                        split_k_slices,
                         profile_all_alignments,
                         find_first_valid,
                         use_multiprocessing,
diff --git a/python/tvm/contrib/cutlass/conv2d_operation.py 
b/python/tvm/contrib/cutlass/conv2d_operation.py
index 5318cc7..7b78c5a 100644
--- a/python/tvm/contrib/cutlass/conv2d_operation.py
+++ b/python/tvm/contrib/cutlass/conv2d_operation.py
@@ -35,6 +35,7 @@ class Conv2dOperation:
         stride_support,
         epilogue_functor=EpilogueFunctor.LinearCombination,
         swizzling_functor=SwizzlingFunctor.Identity1,
+        split_k_slices=1,
     ):
         self.operation_kind = OperationKind.Conv2d
         self.arch = arch
@@ -48,6 +49,7 @@ class Conv2dOperation:
         self.iterator_algorithm = iterator_algorithm
         self.stride_support = stride_support
         self.swizzling_functor = swizzling_functor
+        self.split_k_slices = split_k_slices
 
     def accumulator_type(self):
         return self.tile_description.math_instruction.element_accumulator
@@ -127,6 +129,9 @@ class Conv2dOperation:
                 "_${layout}_align${alignment}"
             )
 
+        if self.split_k_slices > 1:
+            configuration_name += "_splitk%d" % self.split_k_slices
+
         return substitute_template(
             configuration_name,
             {
@@ -172,6 +177,14 @@ class EmitConv2dInstance:
       ${unary_op}
     >"""
 
+        self.epilogue_wgrad = """
+    ${epilogue_functor}<
+      ${element_c},
+      4,
+      float,
+      float
+    >"""
+
         self.template = """
   // Conv2d${conv_kind_name} ${iterator_algorithm_name} kernel instance 
"${operation_name}"
   using ${operation_name} =
@@ -197,9 +210,31 @@ class EmitConv2dInstance:
     ${align_a},
     ${align_b}
   >::Kernel;
+
+  ${reduction}
+"""
+
+        self.reduction_template = """
+using EpilogueOutputOp = ${epilogue};
+using ReductionOp = cutlass::reduction::thread::ReduceAdd<
+    ${element_accumulator},
+    ${element_accumulator},
+      EpilogueOutputOp::kCount
+      >;
+
+using ReductionKernel = cutlass::reduction::kernel::ReduceSplitK<
+    cutlass::MatrixShape<4, 32 * EpilogueOutputOp::kCount>,
+      EpilogueOutputOp,
+      ReductionOp
+      >;
+
+using ReductionDevice = 
cutlass::reduction::device::ReduceSplitK<ReductionKernel>;
+using ReductionStrideIndex = typename ReductionDevice::StrideIndex;
 """
 
-    def emit(self, operation, no_beta_scaling=False, 
residual_block_info=False):
+    def emit(
+        self, operation, no_beta_scaling=False, residual_block_info=False, 
emit_reduction=False
+    ):
         """Instantiate a Conv2d kernel from given `operation`."""
         warp_shape = [
             int(
@@ -214,6 +249,31 @@ class EmitConv2dInstance:
             / DataTypeSize[operation.C.element]
         )
 
+        element_c = operation.C.element
+        use_split_k_wgrad = operation.conv_kind == ConvKind.Wgrad and 
operation.split_k_slices > 1
+        # The GEMM output is always fp32 for wgrad with split-k
+        element_c_gemm = DataType.f32 if use_split_k_wgrad else element_c
+
+        if emit_reduction:
+            epilogue_reduction = substitute_template(
+                self.epilogue_wgrad,
+                {
+                    "epilogue_functor": 
EpilogueFunctorTag[operation.epilogue_functor],
+                    "element_c": DataTypeTag[element_c],
+                },
+            )
+            reduction = substitute_template(
+                self.reduction_template,
+                {
+                    "epilogue": epilogue_reduction,
+                    "operation_name": operation.procedural_name(),
+                    "element_accumulator": 
DataTypeTag[operation.accumulator_type()],
+                },
+            )
+            gemm_template = substitute_template(self.template, {"reduction": 
reduction})
+        else:
+            gemm_template = substitute_template(self.template, {"reduction": 
""})
+
         values = {
             "operation_name": operation.procedural_name(),
             "conv_kind": ConvKindTag[operation.conv_kind],
@@ -222,7 +282,7 @@ class EmitConv2dInstance:
             "layout_a": LayoutTag[operation.A.layout],
             "element_b": DataTypeTag[operation.B.element],
             "layout_b": LayoutTag[operation.B.layout],
-            "element_c": DataTypeTag[operation.C.element],
+            "element_c": DataTypeTag[element_c_gemm],
             "layout_c": LayoutTag[operation.C.layout],
             "element_accumulator": DataTypeTag[operation.accumulator_type()],
             "opcode_class": OpcodeClassTag[
@@ -262,9 +322,19 @@ class EmitConv2dInstance:
             "conv_kernel_postfix": "",
         }
 
-        if residual_block_info:
+        if use_split_k_wgrad:
+            # Even if the output is fp16, gemm output is always fp32 for split 
k wgrad.
+            epilogue_gemm = substitute_template(
+                self.epilogue_wgrad,
+                {
+                    "epilogue_functor": 
EpilogueFunctorTag[operation.epilogue_functor],
+                    "element_c": "float",
+                },
+            )
+            template = substitute_template(gemm_template, {"epilogue": 
epilogue_gemm})
+        elif residual_block_info:
             template = substitute_template(
-                self.template, {"epilogue": self.epilogue_residual_block}
+                gemm_template, {"epilogue": self.epilogue_residual_block}
             )
             values.update(
                 {
@@ -276,9 +346,9 @@ class EmitConv2dInstance:
             )
         elif no_beta_scaling:
             template = substitute_template(
-                self.template, {"epilogue": self.epilogue_no_beta_scaling}
+                gemm_template, {"epilogue": self.epilogue_no_beta_scaling}
             )
         else:
-            template = substitute_template(self.template, {"epilogue": 
self.epilogue_default})
+            template = substitute_template(gemm_template, {"epilogue": 
self.epilogue_default})
 
         return substitute_template(template, values)
diff --git a/python/tvm/contrib/cutlass/conv2d_profiler.py 
b/python/tvm/contrib/cutlass/conv2d_profiler.py
index 2f4e769..1ed5550 100644
--- a/python/tvm/contrib/cutlass/conv2d_profiler.py
+++ b/python/tvm/contrib/cutlass/conv2d_profiler.py
@@ -17,6 +17,8 @@
 # pylint: disable=import-outside-toplevel, invalid-name
 """Instantiate a C++ source for profiling CUTLASS kernels."""
 
+from .library import DataTypeTag
+
 
 class Conv2dProfilerEmitter(object):
     """Emit a C++ source for profiling CUTLASS kernels."""
@@ -24,6 +26,32 @@ class Conv2dProfilerEmitter(object):
     def __init__(self):
         from jinja2 import Template
 
+        self.reduction = """
+      ReductionDevice reduction_op;
+      static cutlass::conv::Operator const kConvolutionalOperator = 
ImplicitGemm::kConvolutionalOperator;
+      typename ReductionDevice::Arguments reduction_args(
+        cutlass::conv::implicit_gemm_problem_size(kConvolutionalOperator, 
problem_size).mn(),
+        problem_size.split_k_slices,
+        cutlass::conv::implicit_gemm_tensor_c_size(kConvolutionalOperator, 
problem_size),
+        {
+        reinterpret_cast<ImplicitGemm::ElementC*> (workspace.get()),
+          
ReductionStrideIndex(tensor_c.stride()[ImplicitGemm::ImplicitGemmKernel::kTensorCStrideIdx])
+          },
+        {
+        tensor_d.device_data(),
+          
ReductionStrideIndex(tensor_d.stride()[ImplicitGemm::ImplicitGemmKernel::kTensorCStrideIdx])
+          },
+        {
+        tensor_c.device_data(),
+          
ReductionStrideIndex(tensor_c.stride()[ImplicitGemm::ImplicitGemmKernel::kTensorCStrideIdx])
+          },
+        {ElementComputeEpilogue(1), ElementComputeEpilogue(0)}
+        );
+
+      reduction_op.initialize(reduction_args, nullptr);
+      reduction_op();
+"""
+
         self.template = Template(
             """
 #include <iostream>
@@ -35,6 +63,8 @@ class Conv2dProfilerEmitter(object):
 #include "cutlass/util/command_line.h"
 #include "cutlass/util/host_tensor.h"
 #include "cutlass/util/reference/host/tensor_fill.h"
+#include "cutlass/reduction/device/reduce_split_k.h"
+#include "cutlass/reduction/thread/reduction_operators.h"
 
 #define CUTLASS_CHECK(status)                                                  
                  \
   {                                                                            
                  \
@@ -88,10 +118,11 @@ struct Options {
 };
 
 double profile_convolution(Options const &options) {
-  using ElementOutput = typename ImplicitGemm::ElementC;
+  using ElementOutput = {{ElementOutput}};
   using ElementInputA = typename ImplicitGemm::ElementA;
   using ElementInputB = typename ImplicitGemm::ElementB;
 
+  int split_k_slices = {{SplitK}};
   cutlass::conv::Conv2dProblemSize problem_size(
                                                options.input_size,
                                                options.filter_size,
@@ -100,7 +131,7 @@ double profile_convolution(Options const &options) {
                                                options.dilation,
                                                options.output_size(),
                                                
cutlass::conv::Mode::kCrossCorrelation,
-                                               1
+                                               split_k_slices
                                                );
 
   auto conv_kind = ImplicitGemm::kConvolutionalOperator;
@@ -108,20 +139,26 @@ double profile_convolution(Options const &options) {
   auto b_extent = implicit_gemm_tensor_b_extent(conv_kind, problem_size);
   auto c_extent = implicit_gemm_tensor_c_extent(conv_kind, problem_size);
 
+  using LayoutC = typename ImplicitGemm::LayoutC;
   cutlass::HostTensor<ElementInputA, typename ImplicitGemm::LayoutA> 
tensor_a(a_extent);
   cutlass::HostTensor<ElementInputB, typename ImplicitGemm::LayoutB> 
tensor_b(b_extent);
   cutlass::HostTensor<ElementOutput, typename ImplicitGemm::LayoutC> 
tensor_c(c_extent);
-  cutlass::HostTensor<ElementOutput, typename ImplicitGemm::LayoutC> 
tensor_ref_c(c_extent);
+  cutlass::HostTensor<ElementOutput, LayoutC> tensor_d(c_extent);
+  cutlass::HostTensor<ImplicitGemm::ElementC, LayoutC> tensor_c_gemm(c_extent);
 
   using ElementComputeEpilogue = typename ImplicitGemm::ElementCompute;
 
+  cutlass::conv::SplitKMode const split_k_mode = split_k_slices > 1 ?
+      cutlass::conv::SplitKMode::kParallel : 
cutlass::conv::SplitKMode::kSerial;
+
   typename ImplicitGemm::Arguments arguments{
     problem_size,
     tensor_a.device_ref(),
     tensor_b.device_ref(),
-    tensor_c.device_ref(),
-    tensor_c.device_ref(),
+    tensor_c_gemm.device_ref(),
+    tensor_c_gemm.device_ref(),
     {ElementComputeEpilogue(1), ElementComputeEpilogue(0)},
+    split_k_mode,
   };
 
   ImplicitGemm implicit_gemm_op;
@@ -144,6 +181,7 @@ double profile_convolution(Options const &options) {
   for (int iteration = 0; iteration < 100; ++iteration) {
     auto status = implicit_gemm_op();
     CUTLASS_CHECK(status);
+    {{Reduction}}
   }
 
   cudaEventRecord(events[1]);
@@ -166,6 +204,12 @@ int main(int argc, char const **args) {
 """
         )
 
-    def emit(self, op_def, op_name):
-        src = self.template.render(OperatorDef=op_def, OperatorName=op_name)
+    def emit(self, op_def, op_name, element_output, split_k_slices=1):
+        src = self.template.render(
+            OperatorDef=op_def,
+            OperatorName=op_name,
+            ElementOutput=DataTypeTag[element_output],
+            SplitK=split_k_slices,
+            Reduction=self.reduction if split_k_slices > 1 else "",
+        )
         return src
diff --git a/python/tvm/contrib/cutlass/gen_conv2d.py 
b/python/tvm/contrib/cutlass/gen_conv2d.py
index 0d46000..b51afdc 100644
--- a/python/tvm/contrib/cutlass/gen_conv2d.py
+++ b/python/tvm/contrib/cutlass/gen_conv2d.py
@@ -14,7 +14,7 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-# pylint: disable=invalid-name
+# pylint: disable=invalid-name, dangerous-default-value
 """Conv2d kernel generator and profiler for CUTLASS."""
 from functools import partial
 from .conv2d_operation import Conv2dOperation, EmitConv2dInstance
@@ -40,6 +40,7 @@ def create_conv2d_operator_with_epilogue(
     data_type,
     alignment,
     swizzling_functor,
+    split_k_slices,
 ):
     """
     Instantiate a cutlass kernel from the given configuration,
@@ -90,11 +91,15 @@ def create_conv2d_operator_with_epilogue(
         stride_support,
         epilogue,
         swizzling_functor,
+        split_k_slices,
     )
 
     name = op.procedural_name()
     opdef = EmitConv2dInstance().emit(
-        op, no_beta_scaling=no_beta_scaling, 
residual_block_info=residual_block_info
+        op,
+        no_beta_scaling=no_beta_scaling,
+        residual_block_info=residual_block_info,
+        emit_reduction=split_k_slices > 1,
     )
 
     return name, opdef
@@ -103,6 +108,7 @@ def create_conv2d_operator_with_epilogue(
 def enumerate_conv2d_operators(
     conv_kind,
     stride_support,
+    split_k_slices,
     tile_descriptions,
     data_type,
     alignment_constraints,
@@ -119,37 +125,45 @@ def enumerate_conv2d_operators(
     if conv_kind == ConvKind.Dgrad and stride_support == StrideSupport.Strided:
         swizzling_functor = SwizzlingFunctor.StridedDgradIdentity1
 
-    for tile in tile_descriptions:
-        for alignment in alignment_constraints:
-
-            A = TensorDescription(element_a, LayoutType.TensorNHWC, alignment)
-            B = TensorDescription(element_b, LayoutType.TensorNHWC, alignment)
-            C = TensorDescription(element_c, LayoutType.TensorNHWC, alignment)
-
-            op = Conv2dOperation(
-                conv_kind,
-                IteratorAlgorithm.Optimized,
-                tile.minimum_compute_capability,
-                tile,
-                A,
-                B,
-                C,
-                element_epilogue,
-                stride_support,
-                EpilogueFunctor.LinearCombination,
-                swizzling_functor,
-            )
-
-            ret.append(
-                {
-                    "src": profiler_emitter.emit(kernel_emitter.emit(op), 
op.procedural_name()),
-                    "name": op.procedural_name(),
-                    "tile_description": tile,
-                    "alignment": alignment,
-                    "data_type": data_type,
-                    "swizzle_functor": swizzling_functor,
-                }
-            )
+    for split_k_slice in split_k_slices:
+        for tile in tile_descriptions:
+            for alignment in alignment_constraints:
+
+                A = TensorDescription(element_a, LayoutType.TensorNHWC, 
alignment)
+                B = TensorDescription(element_b, LayoutType.TensorNHWC, 
alignment)
+                C = TensorDescription(element_c, LayoutType.TensorNHWC, 
alignment)
+
+                op = Conv2dOperation(
+                    conv_kind,
+                    IteratorAlgorithm.Optimized,
+                    tile.minimum_compute_capability,
+                    tile,
+                    A,
+                    B,
+                    C,
+                    element_epilogue,
+                    stride_support,
+                    EpilogueFunctor.LinearCombination,
+                    swizzling_functor,
+                    split_k_slice,
+                )
+
+                ret.append(
+                    {
+                        "src": profiler_emitter.emit(
+                            kernel_emitter.emit(op, 
emit_reduction=split_k_slice > 1),
+                            op.procedural_name(),
+                            element_output=element_c,
+                            split_k_slices=split_k_slice,
+                        ),
+                        "name": op.procedural_name(),
+                        "tile_description": tile,
+                        "alignment": alignment,
+                        "data_type": data_type,
+                        "swizzle_functor": swizzling_functor,
+                        "split_k_slices": split_k_slice,
+                    }
+                )
 
     return ret
 
@@ -198,6 +212,7 @@ class CutlassConv2DProfiler:
             data_type,
             alignment,
             swizzling_functor,
+            split_k_slices=1,
         )
         return {"name": name, "opdef": opdef}
 
@@ -214,6 +229,7 @@ class CutlassConv2DProfiler:
         use_3xtf32,
         conv_kind,
         stride_support,
+        split_k_slices,
         profile_all_alignments=False,
         find_first_valid=False,
         use_multiprocessing=False,
@@ -248,7 +264,7 @@ class CutlassConv2DProfiler:
             out_dtype,
             data_dtype,
             weight_dtype,
-            partial(enumerate_conv2d_operators, conv_kind, stride_support),
+            partial(enumerate_conv2d_operators, conv_kind, stride_support, 
split_k_slices),
             lambda align: all([dim % align == 0 for dim in [IC, OC]]),
             use_3xtf32,
             profile_all_alignments,
@@ -288,6 +304,7 @@ class CutlassConv2DProfiler:
         weight_dtype,
         use_3xtf32=True,
         conv_kind=ConvKind.Fprop,
+        split_k_slices=[1],
         profile_all_alignments=False,
         find_first_valid=False,
         use_multiprocessing=False,
@@ -315,6 +332,7 @@ class CutlassConv2DProfiler:
             use_3xtf32,
             conv_kind,
             stride_support,
+            split_k_slices,
             profile_all_alignments,
             find_first_valid,
             use_multiprocessing,
@@ -328,6 +346,7 @@ class CutlassConv2DProfiler:
             op["data_type"],
             op["alignment"],
             op["swizzle_functor"],
+            op["split_k_slices"],
         )
 
         return name, opdef, op["runtime"]
diff --git a/python/tvm/contrib/cutlass/gen_tensor_op.py 
b/python/tvm/contrib/cutlass/gen_tensor_op.py
index 76d4383..b3f40f0 100644
--- a/python/tvm/contrib/cutlass/gen_tensor_op.py
+++ b/python/tvm/contrib/cutlass/gen_tensor_op.py
@@ -363,6 +363,9 @@ class ProfilerEngine:
         try:
             sp = subprocess.run(cmd, capture_output=True, check=True)
             rt = float(sp.stdout)
+            if rt == 0.0:
+                # This seems to happen with split-k using invalid 
split-k-slices
+                rt = float("inf")
             logger.info("%s, %f", op_name, rt)
         except subprocess.CalledProcessError:
             rt = float("inf")
diff --git a/src/relay/backend/contrib/cutlass/codegen.cc 
b/src/relay/backend/contrib/cutlass/codegen.cc
index fdd268d..b12da1a 100644
--- a/src/relay/backend/contrib/cutlass/codegen.cc
+++ b/src/relay/backend/contrib/cutlass/codegen.cc
@@ -284,15 +284,15 @@ std::string Conv2dOp(std::string id, const Str2StrMap& 
attrs,
                          op_type != "cutlass.conv2d_bias_silu" &&
                          op_type != "cutlass.conv2d_bias_hardswish";
 
+  const std::string op_name = attrs.at("op_name");
   std::ostringstream conv2d_decl;
   CutlassPrint(conv2d_decl, attrs.at("op_def"));
-  CutlassPrint(conv2d_decl, "using Operation_" + attrs.at("op_name") +
-                                " = 
cutlass::conv::device::ImplicitGemmConvolution<" +
-                                attrs.at("op_name") + ">;\n");
-  CutlassPrint(conv2d_decl, "using Conv2d = Operation_" + attrs.at("op_name") 
+ ";\n");
+  CutlassPrint(conv2d_decl, "using Operation_" + op_name +
+                                " = 
cutlass::conv::device::ImplicitGemmConvolution<" + op_name +
+                                ">;\n");
+  CutlassPrint(conv2d_decl, "using Conv2d = Operation_" + op_name + ";\n");
   CutlassPrint(conv2d_decl, "using ElementInputA = Conv2d::ElementA;\n");
   CutlassPrint(conv2d_decl, "using ElementInputB = Conv2d::ElementB;\n");
-  CutlassPrint(conv2d_decl, "using ElementOutput = Conv2d::ElementC;\n");
   CutlassPrint(conv2d_decl, "using ElementComputeEpilogue = 
Conv2d::ElementAccumulator;\n");
 
   auto get_dim = [&attrs](const std::string& axis, const std::string& 
var_name, int axis_idx) {
@@ -319,10 +319,25 @@ std::string Conv2dOp(std::string id, const Str2StrMap& 
attrs,
   CutlassPrint(conv2d_decl, "int dilation_h = " + attrs.at("dilation_h") + 
";\n");
   CutlassPrint(conv2d_decl, "int dilation_w = " + attrs.at("dilation_w") + 
";\n");
 
+  const bool use_split_k = op_name.find("splitk") != std::string::npos;
+
+  if (use_split_k) {
+    std::string split_k_slices = 
op_name.substr(op_name.find_last_not_of("0123456789") + 1);
+    CutlassPrint(conv2d_decl, "int split_k_slices = " + split_k_slices + 
";\n");
+  } else {
+    CutlassPrint(conv2d_decl, "int split_k_slices = 1;\n");
+  }
+
   CutlassPrint(
       conv2d_decl,
       "cutlass::conv::Conv2dProblemSize problem_size(N, H, W, C, K, R, S, P, 
Q, pad_h, pad_w, "
-      "stride_h, stride_w, dilation_h, dilation_w, 
cutlass::conv::Mode::kCrossCorrelation, 1);\n");
+      "stride_h, stride_w, dilation_h, dilation_w, 
cutlass::conv::Mode::kCrossCorrelation, "
+      "split_k_slices);\n");
+
+  const std::string split_k_mode = use_split_k ? "kParallel" : "kSerial";
+  CutlassPrint(conv2d_decl,
+               "const cutlass::conv::SplitKMode split_k_mode = 
cutlass::conv::SplitKMode::" +
+                   split_k_mode + ";\n");
 
   bool is_wgrad = op_type.find("backward_weight") != std::string::npos;
   bool is_dgrad = op_type.find("conv2d_transpose") != std::string::npos;
@@ -372,32 +387,51 @@ std::string Conv2dOp(std::string id, const Str2StrMap& attrs,
     CutlassPrint(conv2d_decl, "TensorNHWC layout_D(output_oshape);\n\n");
   }
 
+  if (use_split_k) {
+    CutlassPrint(conv2d_decl, "using ElementOutput = EpilogueOutputOp::ElementOutput;\n");
+  } else {
+    CutlassPrint(conv2d_decl, "using ElementOutput = Conv2d::ElementC;\n");
+  }
+
+  std::string tensor_c_init = "{static_cast<ElementOutput*>(ptr_out), layout_C}";
+  if (has_residual_block) {
+    tensor_c_init = "{static_cast<ElementOutput*>(ptr_residual), layout_C}";
+  } else if (has_bias) {
+    tensor_c_init =
+        "{static_cast<ElementOutput*>(ptr_c_bias), cutlass::layout::TensorNHWC::Stride(0)}";
+  }
+
+  CutlassPrint(conv2d_decl,
+               "cutlass::TensorRef<ElementOutput, TensorNHWC> tensor_c" + tensor_c_init + ";\n");
+  CutlassPrint(conv2d_decl,
+               "cutlass::TensorRef<ElementOutput, TensorNHWC> "
+               "tensor_d{static_cast<ElementOutput*>(ptr_out),layout_D};\n");
+
   CutlassPrint(conv2d_decl, "typename Conv2d::Arguments arguments{\n");
   CutlassPrint(conv2d_decl, " problem_size,\n");
   CutlassPrint(conv2d_decl, " {static_cast<ElementInputA*>(ptr_a), layout_A},\n");
   CutlassPrint(conv2d_decl, " {static_cast<ElementInputB*>(ptr_b), layout_B},\n");
 
-  if (has_residual_block) {
-    CutlassPrint(conv2d_decl, " {static_cast<ElementOutput*>(ptr_residual), layout_C},\n");
-  } else if (has_bias) {
-    CutlassPrint(
-        conv2d_decl,
-        " {static_cast<ElementOutput*>(ptr_c_bias), cutlass::layout::TensorNHWC::Stride(0)},\n");
+  if (use_split_k) {
+    CutlassPrint(conv2d_decl, "{nullptr, TensorNHWC()},\n");
+    CutlassPrint(conv2d_decl, "{nullptr, TensorNHWC()},\n");
   } else {
-    CutlassPrint(conv2d_decl, " {static_cast<ElementOutput*>(ptr_out), layout_C},\n");
+    CutlassPrint(conv2d_decl, " tensor_c,\n");
+    CutlassPrint(conv2d_decl, " tensor_d,\n");
   }
 
-  CutlassPrint(conv2d_decl, " {static_cast<ElementOutput*>(ptr_out),layout_D},\n");
-
   if (has_residual_block) {
+    ICHECK(use_split_k == false) << "Split-k not supported for residual block fusion";
     CutlassPrint(conv2d_decl, "{alpha, beta},\n");
     CutlassPrint(conv2d_decl, "cutlass::conv::SplitKMode::kSerial,\n");  // split_k_slices
     CutlassPrint(conv2d_decl, "static_cast<ElementOutput*>(ptr_bias),\n");
     CutlassPrint(conv2d_decl, "nullptr, 0, K};\n");
   } else if (has_bias && no_bias_scaling) {
-    CutlassPrint(conv2d_decl, " {alpha}\n};\n");
+    CutlassPrint(conv2d_decl, " {alpha},\n");
+    CutlassPrint(conv2d_decl, "split_k_mode\n};\n");
   } else {
-    CutlassPrint(conv2d_decl, "{alpha, beta}\n};\n");
+    CutlassPrint(conv2d_decl, "{alpha, beta},\n");
+    CutlassPrint(conv2d_decl, "split_k_mode\n};\n");
   }
 
   CutlassPrint(conv2d_decl, "Conv2d conv2d_op;\n");
@@ -408,13 +442,67 @@ std::string Conv2dOp(std::string id, const Str2StrMap& attrs,
                "cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);\n");
   // Check the problem size is supported or not
   CutlassPrint(conv2d_decl, "cutlass::Status status = conv2d_op.can_implement(arguments);\n");
-  CutlassPrint(conv2d_decl, "CHECK(status == cutlass::Status::kSuccess);\n");
+  CutlassPrint(conv2d_decl, "CHECK(status == cutlass::Status::kSuccess);\n\n");
+
+  if (use_split_k) {
+    CutlassPrint(conv2d_decl,
+                 "arguments.ref_D.reset(reinterpret_cast<ElementComputeEpilogue*>(workspace.get()),"
+                 " layout_D);\n\n");
+  }
+
   // Initialize CUTLASS kernel with arguments and workspace pointer
   CutlassPrint(conv2d_decl, "status = conv2d_op.initialize(arguments, workspace.get());\n");
-  CutlassPrint(conv2d_decl, "CHECK(status == cutlass::Status::kSuccess);\n");
+  CutlassPrint(conv2d_decl, "CHECK(status == cutlass::Status::kSuccess);\n\n");
+
+  if (use_split_k) {
+    CutlassPrint(
+        conv2d_decl,
+        "arguments.output_op = {ElementComputeEpilogue(1), ElementComputeEpilogue(0)}; \n");
+    CutlassPrint(conv2d_decl, "status = conv2d_op.update(arguments, workspace.get()); \n");
+    CutlassPrint(conv2d_decl, "CHECK(status == cutlass::Status::kSuccess);\n\n");
+  }
+
   // Launch initialized CUTLASS kernel
   CutlassPrint(conv2d_decl, "status = conv2d_op();\n");
-  CutlassPrint(conv2d_decl, "CHECK(status == cutlass::Status::kSuccess);\n");
+  CutlassPrint(conv2d_decl, "CHECK(status == cutlass::Status::kSuccess);\n\n");
+
+  if (use_split_k) {
+    CutlassPrint(conv2d_decl, "ReductionDevice reduction_op;\n");
+    CutlassPrint(conv2d_decl,
+                 "const static cutlass::conv::Operator kConvolutionalOperator = "
+                 "Conv2d::kConvolutionalOperator;\n");
+    CutlassPrint(conv2d_decl, "typename ReductionDevice::Arguments reduction_args(\n");
+    CutlassPrint(conv2d_decl,
+                 "cutlass::conv::implicit_gemm_problem_size(kConvolutionalOperator, "
+                 "problem_size).mn(),\n");
+    CutlassPrint(conv2d_decl, "problem_size.split_k_slices,\n");
+    CutlassPrint(conv2d_decl,
+                 "cutlass::conv::implicit_gemm_tensor_c_size(kConvolutionalOperator, "
+                 "problem_size),\n");
+    CutlassPrint(conv2d_decl, "{\n");
+    CutlassPrint(conv2d_decl,
+                 " reinterpret_cast<Conv2d::ElementAccumulator*> (workspace.get()),\n");
+    CutlassPrint(conv2d_decl,
+                 "ReductionStrideIndex(tensor_c.stride()[Conv2d::ImplicitGemmKernel::"
+                 "kTensorCStrideIdx])\n");
+    CutlassPrint(conv2d_decl, "},\n");
+    CutlassPrint(conv2d_decl, "{\n");
+    CutlassPrint(conv2d_decl, "tensor_d.data(),\n");
+    CutlassPrint(conv2d_decl,
+                 "ReductionStrideIndex(tensor_d.stride()[Conv2d::ImplicitGemmKernel::"
+                 "kTensorCStrideIdx])\n");
+    CutlassPrint(conv2d_decl, "},\n");
+    CutlassPrint(conv2d_decl, "{\n");
+    CutlassPrint(conv2d_decl, "tensor_c.data(),\n");
+    CutlassPrint(conv2d_decl,
+                 "ReductionStrideIndex(tensor_c.stride()[Conv2d::ImplicitGemmKernel::"
+                 "kTensorCStrideIdx])\n");
+    CutlassPrint(conv2d_decl, "},\n");
+    CutlassPrint(conv2d_decl, "   {alpha, beta}\n");
+    CutlassPrint(conv2d_decl, ");\n\n");
+    CutlassPrint(conv2d_decl, "status = reduction_op.initialize(reduction_args, nullptr);\n");
+    CutlassPrint(conv2d_decl, "status = reduction_op();\n");
+  }
 
   return conv2d_decl.str();
 }
@@ -720,6 +808,7 @@ class CutlassModuleCodegen : public CSourceModuleCodegenBase {
     code_stream_ << "#include <cuda_fp16.h>\n";
     code_stream_ << "#include <cutlass/cutlass.h>\n";
     code_stream_ << "#include <cutlass/coord.h>\n";
+    code_stream_ << "#include <cutlass/tensor_ref.h>\n";
     code_stream_ << "#include <cutlass/util/host_tensor.h>\n";
     code_stream_ << "#include <cutlass/gemm/device/gemm.h>\n";
     code_stream_ << "#include <cutlass/gemm/device/gemm_batched.h>\n";
@@ -734,6 +823,8 @@ class CutlassModuleCodegen : public CSourceModuleCodegenBase {
     code_stream_ << "#include <cutlass/epilogue/thread/linear_combination_hardswish.h>\n";
     code_stream_ << "#include <cutlass/epilogue/thread/linear_combination_residual_block.h>\n";
     code_stream_ << "#include <cutlass/conv/kernel/default_conv2d_fprop_with_broadcast.h>\n";
+    code_stream_ << "#include <cutlass/reduction/device/reduce_split_k.h>\n";
+    code_stream_ << "#include <cutlass/reduction/thread/reduction_operators.h>\n";
 
     ICHECK(ref->IsInstance<FunctionNode>());
     auto res = GenCutlassFunc(Downcast<Function>(ref));
diff --git a/tests/python/contrib/test_cutlass.py b/tests/python/contrib/test_cutlass.py
index ef55c74..ad75e73 100644
--- a/tests/python/contrib/test_cutlass.py
+++ b/tests/python/contrib/test_cutlass.py
@@ -253,16 +253,24 @@ def get_random_ndarray(shape, dtype):
 
 
 def profile_and_build(
-    mod, params, sm, tmp_dir="./tmp", lib_path="compile.so", use_fast_math=False, use_3xtf32=True
+    mod,
+    params,
+    sm,
+    split_k_slices=[1],
+    tmp_dir="./tmp",
+    lib_path="compile.so",
+    use_fast_math=False,
+    use_3xtf32=True,
 ):
     mod = partition_for_cutlass(mod)
     mod, num_cutlass_partition = tune_cutlass_kernels(
         mod,
         sm,
         use_3xtf32=use_3xtf32,
+        split_k_slices=split_k_slices,
         profile_all_alignments=False,
         find_first_valid=True,
-        use_multiprocessing=False,
+        use_multiprocessing=True,
         tmp_dir=tmp_dir,
     )
     with tvm.transform.PassContext(opt_level=3):
@@ -277,6 +285,7 @@ def profile_and_build_vm(
     mod,
     params,
     sm,
+    split_k_slices=[1],
     tmp_dir="./tmp",
     lib_path="compile.so",
     vmcode_path="vmcode.ro",
@@ -287,6 +296,7 @@ def profile_and_build_vm(
     mod, num_cutlass_partition = tune_cutlass_kernels(
         mod,
         sm,
+        split_k_slices=split_k_slices,
         use_3xtf32=use_3xtf32,
         profile_all_alignments=False,
         find_first_valid=True,
@@ -508,6 +518,7 @@ def verify_conv2d_common(
     inputs,
     params,
     sm=80,
+    split_k_slices=[1],
     atol=1e-5,
     rtol=1e-5,
     use_cudnn_ref=False,
@@ -543,7 +554,7 @@ def verify_conv2d_common(
     )
 
     rt_mod, _, num_cutlass_partition = profile_and_build_func(
-        mod_weight_ohwi, params, sm, use_fast_math=use_fast_math
+        mod_weight_ohwi, params, sm, split_k_slices, use_fast_math=use_fast_math
     )
     out = get_output_func(rt_mod, input_names, inputs)
 
@@ -597,6 +608,8 @@ def verify_conv2d(
     np_bias = get_random_ndarray((w_shape[0],), typ.dtype)
     params = {"weight": np_weight, "bias": np_bias}
 
+    split_k_slices = [1]
+
     return verify_conv2d_common(
         expr_nchw,
         expr_ref,
@@ -604,6 +617,7 @@ def verify_conv2d(
         [np_data],
         params,
         sm,
+        split_k_slices,
         atol,
         rtol,
         use_cudnn_ref,
@@ -620,6 +634,7 @@ def verify_conv2d_backward_weight(
     grad_shape,
     data_shape,
     sm=80,
+    split_k_slices=[1],
     atol=1e-5,
     rtol=1e-5,
     use_cudnn_ref=False,
@@ -640,6 +655,7 @@ def verify_conv2d_backward_weight(
         [np_grad, np_data],
         params,
         sm,
+        split_k_slices,
         atol,
         rtol,
         use_cudnn_ref,
@@ -838,18 +854,20 @@ def test_conv2d_backward_weight():
             weight_dtype=dtype,
         )
 
-        verify_conv2d_backward_weight(
-            mod_nchw,
-            mod_nchw,
-            o_shape,
-            d_shape,
-            sm=80,
-            atol=1e-3,
-            rtol=1e-3,
-            use_cudnn_ref=False,
-            grad_dtype=dtype,
-            data_dtype=dtype,
-        )
+        for split_k_slices in [1, 8]:
+            verify_conv2d_backward_weight(
+                mod_nchw,
+                mod_nchw,
+                o_shape,
+                d_shape,
+                sm=80,
+                split_k_slices=[split_k_slices],
+                atol=1e-3,
+                rtol=1e-3,
+                use_cudnn_ref=False,
+                grad_dtype=dtype,
+                data_dtype=dtype,
+            )
 
 
 def test_conv2d_bwd():

Reply via email to