This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 35a7992 [CUTLASS] Add parallel split-k support to wgrad (#10185)
35a7992 is described below
commit 35a7992fe20154b4da57ab1e9d3aa2585e9d9151
Author: Masahiro Masuda <[email protected]>
AuthorDate: Wed Feb 9 01:34:42 2022 +0900
[CUTLASS] Add parallel split-k support to wgrad (#10185)
* [CUTLASS] Add split-k support to wgrad
commit 60b73a91b79d644d8c95f682eedaf47a89abba0d
Author: Masahiro Masuda <[email protected]>
Date: Tue Feb 8 10:43:11 2022 +0900
pylint
commit ae2e7187256316c48c915c3c187feb5cd4d4dbd4
Author: Masahiro Masuda <[email protected]>
Date: Sun Feb 6 14:51:52 2022 +0900
Add split-k support for wgrad
commit 43820d50055b0bd17b736f5c5830321c7509a20a
Author: Masahiro Masuda <[email protected]>
Date: Sun Feb 6 10:07:34 2022 +0900
fix and add doc
commit 446a95b0aabc5ab69cdd2e414b812aab1c557f42
Author: Masahiro Masuda <[email protected]>
Date: Sun Feb 6 09:48:38 2022 +0900
dw conv2d properly supported for wgrad
commit adc4e22d2e03a99f30ebb6a5e956a1749de693f0
Author: Masahiro Masuda <[email protected]>
Date: Sat Feb 5 16:32:42 2022 +0900
fix overwriting template
commit 040eab000bc5f162c6e9aca70ae6d29378fe65bc
Author: Masahiro Masuda <[email protected]>
Date: Sat Feb 5 16:06:27 2022 +0900
black
commit e5a07c24b7463552b8e545710d25472159bcc127
Author: Masahiro Masuda <[email protected]>
Date: Sat Feb 5 16:03:10 2022 +0900
add reduction in profiler
commit be89334ab981d536d010dd765c9cf601dbdae5e0
Author: Masahiro Masuda <[email protected]>
Date: Sat Feb 5 06:58:03 2022 +0900
adding split k reduction to conv2d profiler
commit ae09b0fbdc3a472eb320d866c054f73b3142f21c
Author: Masahiro Masuda <[email protected]>
Date: Fri Feb 4 11:52:59 2022 +0900
fixed conv2d_backward_weight typerel for dw conv2d
commit 16fe5313fd1219e2e7d531ef9b36f64bb557e5e7
Author: Masahiro Masuda <[email protected]>
Date: Thu Feb 3 12:59:22 2022 +0900
wip
commit 2167c2543340a285bb1985e8fe37e11aed51fb9b
Author: Masahiro Masuda <[email protected]>
Date: Thu Feb 3 04:22:19 2022 +0900
fix conv2d type rel for depth wise and grouped conv2d
commit 14b12e5dd84fc34691d585213387198f091eefc5
Author: Masahiro Masuda <[email protected]>
Date: Fri Feb 4 05:01:03 2022 +0900
remove split_k.py
commit b14127179c43f71c3ce5ccc7b4ca678a099e5497
Author: Masahiro Masuda <[email protected]>
Date: Fri Feb 4 04:48:21 2022 +0900
workaround for invalid split_k_slice
commit 6e4c7e1d77d89f124abc77dbcdab69eff8a5d961
Author: Masahiro Masuda <[email protected]>
Date: Fri Feb 4 02:43:58 2022 +0900
support split k in profiler
commit 2eb1cf43c7f56f0537cf249855054b5cbd357b13
Author: Masahiro Masuda <[email protected]>
Date: Fri Feb 4 02:31:03 2022 +0900
improvement
commit 0bce8f3778a6bb05607232a0997d25681e55ce7c
Author: Masahiro Masuda <[email protected]>
Date: Thu Feb 3 18:20:12 2022 +0900
fixed for fp16 output
commit 30df1bd5282a4d326856382726d4e63ee8c27e8e
Author: Masahiro Masuda <[email protected]>
Date: Thu Feb 3 17:50:33 2022 +0900
fp32 output works
commit 7a519956b8d103464dff83b4f01b75973f4a33b0
Author: Masahiro Masuda <[email protected]>
Date: Thu Feb 3 14:30:22 2022 +0900
fix
commit 4a383e2c7c37148a563e9cf34968fb7da3aaf91f
Author: Masahiro Masuda <[email protected]>
Date: Thu Feb 3 14:05:24 2022 +0900
update c++ codegen
commit 6206e388cc7062cbef0b3c8c47fcd228b44b6818
Author: Masahiro Masuda <[email protected]>
Date: Thu Feb 3 13:46:05 2022 +0900
wip
commit 0ece49b53e773ebc1ea71c7667abc0cbb29d91bf
Author: Masahiro Masuda <[email protected]>
Date: Thu Feb 3 03:05:21 2022 +0900
wip
commit 08a6147940d9911fd65a890a4d90beb68176fc03
Author: Masahiro Masuda <[email protected]>
Date: Wed Feb 2 13:10:21 2022 +0900
test worked with fp32 output
commit 084d5c47666df92ba6c2c1445d5a23de0193a119
Author: Masahiro Masuda <[email protected]>
Date: Wed Feb 2 12:35:18 2022 +0900
fix compile error for fprop
commit 31f25436c5aca1a75336fa1a8d1c8a25a4936ee8
Author: Masahiro Masuda <[email protected]>
Date: Wed Feb 2 12:18:06 2022 +0900
compiled
commit c2098e79ade47117f2c32132da864b1fa73fce4a
Author: Masahiro Masuda <[email protected]>
Date: Wed Feb 2 11:11:43 2022 +0900
wip
commit a14585020151d0e09bb9bac549285dceb13e55e1
Author: Masahiro Masuda <[email protected]>
Date: Sun Feb 6 14:46:16 2022 +0900
fixed for sm75
commit 61515062ef4576bf5b4e7e9e800f7f705738809c
Author: Masahiro Masuda <[email protected]>
Date: Sun Feb 6 14:32:46 2022 +0900
all tests work
commit 041c094b3646e0f521f5bd2c4f6f6b5b1cff7b97
Author: Masahiro Masuda <[email protected]>
Date: Sun Feb 6 14:19:09 2022 +0900
dw conv2d properly supported for wgrad
commit 2191918743a4e9ffb8254f3786d817be57ff49cc
Author: Masahiro Masuda <[email protected]>
Date: Wed Feb 2 09:14:05 2022 +0900
wgrad tests now work under pytest
commit 78f76df1eb1602f66cacb888a97b6b267f8600a7
Author: Masahiro Masuda <[email protected]>
Date: Wed Feb 2 07:31:54 2022 +0900
run black
commit 0a82149fe0586b0bf449fc7f3a1fa9809e9b38d2
Author: Masahiro Masuda <[email protected]>
Date: Wed Feb 2 06:12:39 2022 +0900
[CUTLASS] Add wgrad support (without split-k)
* pylint
* add more doc
* more doc clarification
---
python/tvm/contrib/cutlass/build.py | 14 ++-
python/tvm/contrib/cutlass/conv2d_operation.py | 82 ++++++++++++++--
python/tvm/contrib/cutlass/conv2d_profiler.py | 58 +++++++++--
python/tvm/contrib/cutlass/gen_conv2d.py | 87 +++++++++-------
python/tvm/contrib/cutlass/gen_tensor_op.py | 3 +
src/relay/backend/contrib/cutlass/codegen.cc | 131 +++++++++++++++++++++----
tests/python/contrib/test_cutlass.py | 48 ++++++---
7 files changed, 340 insertions(+), 83 deletions(-)
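For readers unfamiliar with the technique this patch wires up: parallel split-K partitions the GEMM K-loop into independent slices, each writing a partial accumulator into a workspace, followed by a separate global reduction over the slices (CUTLASS's `ReduceSplitK`). A minimal plain-Python sketch of the idea (illustrative only, not the TVM/CUTLASS implementation; `split_k_matmul` is a hypothetical name):

```python
def split_k_matmul(a, b, split_k_slices):
    """Illustrative parallel split-K GEMM on nested lists (a: M x K, b: K x N)."""
    m, k, n = len(a), len(b), len(b[0])
    # Chunk boundaries partitioning the K loop into split_k_slices slices.
    bounds = [k * s // split_k_slices for s in range(split_k_slices + 1)]
    # "Workspace": one partial M x N accumulator per slice.
    workspace = []
    for s in range(split_k_slices):  # slices are independent, so they can run in parallel
        lo, hi = bounds[s], bounds[s + 1]
        partial = [[sum(a[i][p] * b[p][j] for p in range(lo, hi))
                    for j in range(n)] for i in range(m)]
        workspace.append(partial)
    # Separate global reduction step accumulating the partial results.
    return [[sum(w[i][j] for w in workspace) for j in range(n)] for i in range(m)]
```

The result is independent of the split factor; only the work partitioning and the extra workspace change, which is why the profiler below can search over `split_k_slices` candidates.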
diff --git a/python/tvm/contrib/cutlass/build.py b/python/tvm/contrib/cutlass/build.py
index 918eeaf..06c33f2 100644
--- a/python/tvm/contrib/cutlass/build.py
+++ b/python/tvm/contrib/cutlass/build.py
@@ -14,7 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-# pylint: disable=invalid-name
+# pylint: disable=invalid-name, dangerous-default-value
"""Driver for partitioning and building a Relay module for CUTLASS offload."""
import logging
import os
@@ -238,6 +238,7 @@ def handle_conv2d(
data_dtype,
weight_dtype,
use_3xtf32,
+ split_k_slices,
profile_all_alignments,
find_first_valid,
use_multiprocessing,
@@ -269,6 +270,7 @@ def handle_conv2d(
weight_dtype,
use_3xtf32,
conv_kind,
+ split_k_slices,
profile_all_alignments,
find_first_valid=find_first_valid,
use_multiprocessing=use_multiprocessing,
@@ -288,6 +290,7 @@ def tune_cutlass_kernels(
mod,
sm,
use_3xtf32=True,
+ split_k_slices=[1],
profile_all_alignments=False,
find_first_valid=False,
use_multiprocessing=False,
@@ -309,6 +312,14 @@ def tune_cutlass_kernels(
         Whether or not to use the slower but very accurate (compared to tf32) 3xtf32 mode for
         fp32 inputs on tensorcore.
+    split_k_slices : list of int
+        Split factor candidates for split-K GEMM. If split-K > 1, the GEMM K-loop is computed in
+        parallel across split-K blocks, and a separate global reduction kernel is launched to
+        accumulate partial reductions. The profiler will pick the best split-K factor from the
+        given candidate list. Note that a larger split-K factor requires a larger workspace.
+        Currently, parallel split-K has been tested only for wgrad. For GEMM and other conv2d
+        kinds, split_k_slices is ignored.
+
     profile_all_alignments : bool
         When True, profile all kernel variants with smaller alignments than the largest possible.
@@ -380,6 +391,7 @@ def tune_cutlass_kernels(
arg0_dtype,
arg1_dtype,
use_3xtf32,
+ split_k_slices,
profile_all_alignments,
find_first_valid,
use_multiprocessing,
diff --git a/python/tvm/contrib/cutlass/conv2d_operation.py b/python/tvm/contrib/cutlass/conv2d_operation.py
index 5318cc7..7b78c5a 100644
--- a/python/tvm/contrib/cutlass/conv2d_operation.py
+++ b/python/tvm/contrib/cutlass/conv2d_operation.py
@@ -35,6 +35,7 @@ class Conv2dOperation:
stride_support,
epilogue_functor=EpilogueFunctor.LinearCombination,
swizzling_functor=SwizzlingFunctor.Identity1,
+ split_k_slices=1,
):
self.operation_kind = OperationKind.Conv2d
self.arch = arch
@@ -48,6 +49,7 @@ class Conv2dOperation:
self.iterator_algorithm = iterator_algorithm
self.stride_support = stride_support
self.swizzling_functor = swizzling_functor
+ self.split_k_slices = split_k_slices
def accumulator_type(self):
return self.tile_description.math_instruction.element_accumulator
@@ -127,6 +129,9 @@ class Conv2dOperation:
"_${layout}_align${alignment}"
)
+ if self.split_k_slices > 1:
+ configuration_name += "_splitk%d" % self.split_k_slices
+
return substitute_template(
configuration_name,
{
@@ -172,6 +177,14 @@ class EmitConv2dInstance:
${unary_op}
>"""
+ self.epilogue_wgrad = """
+ ${epilogue_functor}<
+ ${element_c},
+ 4,
+ float,
+ float
+ >"""
+
self.template = """
// Conv2d${conv_kind_name} ${iterator_algorithm_name} kernel instance "${operation_name}"
using ${operation_name} =
@@ -197,9 +210,31 @@ class EmitConv2dInstance:
${align_a},
${align_b}
>::Kernel;
+
+ ${reduction}
+"""
+
+ self.reduction_template = """
+using EpilogueOutputOp = ${epilogue};
+using ReductionOp = cutlass::reduction::thread::ReduceAdd<
+ ${element_accumulator},
+ ${element_accumulator},
+ EpilogueOutputOp::kCount
+ >;
+
+using ReductionKernel = cutlass::reduction::kernel::ReduceSplitK<
+ cutlass::MatrixShape<4, 32 * EpilogueOutputOp::kCount>,
+ EpilogueOutputOp,
+ ReductionOp
+ >;
+
+using ReductionDevice = cutlass::reduction::device::ReduceSplitK<ReductionKernel>;
+using ReductionStrideIndex = typename ReductionDevice::StrideIndex;
"""
-    def emit(self, operation, no_beta_scaling=False, residual_block_info=False):
+    def emit(
+        self, operation, no_beta_scaling=False, residual_block_info=False, emit_reduction=False
+    ):
"""Instantiate a Conv2d kernel from given `operation`."""
warp_shape = [
int(
@@ -214,6 +249,31 @@ class EmitConv2dInstance:
/ DataTypeSize[operation.C.element]
)
+        element_c = operation.C.element
+        use_split_k_wgrad = operation.conv_kind == ConvKind.Wgrad and operation.split_k_slices > 1
+        # GEMM output is always fp32 in wgrad with split-K
+        element_c_gemm = DataType.f32 if use_split_k_wgrad else element_c
+
+        if emit_reduction:
+            epilogue_reduction = substitute_template(
+                self.epilogue_wgrad,
+                {
+                    "epilogue_functor": EpilogueFunctorTag[operation.epilogue_functor],
+                    "element_c": DataTypeTag[element_c],
+                },
+            )
+            reduction = substitute_template(
+                self.reduction_template,
+                {
+                    "epilogue": epilogue_reduction,
+                    "operation_name": operation.procedural_name(),
+                    "element_accumulator": DataTypeTag[operation.accumulator_type()],
+                },
+            )
+            gemm_template = substitute_template(self.template, {"reduction": reduction})
+        else:
+            gemm_template = substitute_template(self.template, {"reduction": ""})
+
values = {
"operation_name": operation.procedural_name(),
"conv_kind": ConvKindTag[operation.conv_kind],
@@ -222,7 +282,7 @@ class EmitConv2dInstance:
"layout_a": LayoutTag[operation.A.layout],
"element_b": DataTypeTag[operation.B.element],
"layout_b": LayoutTag[operation.B.layout],
- "element_c": DataTypeTag[operation.C.element],
+ "element_c": DataTypeTag[element_c_gemm],
"layout_c": LayoutTag[operation.C.layout],
"element_accumulator": DataTypeTag[operation.accumulator_type()],
"opcode_class": OpcodeClassTag[
@@ -262,9 +322,19 @@ class EmitConv2dInstance:
"conv_kernel_postfix": "",
}
-        if residual_block_info:
+        if use_split_k_wgrad:
+            # Even if the output is fp16, the GEMM output is always fp32 for split-K wgrad.
+            epilogue_gemm = substitute_template(
+                self.epilogue_wgrad,
+                {
+                    "epilogue_functor": EpilogueFunctorTag[operation.epilogue_functor],
+                    "element_c": "float",
+                },
+            )
+            template = substitute_template(gemm_template, {"epilogue": epilogue_gemm})
+        elif residual_block_info:
template = substitute_template(
- self.template, {"epilogue": self.epilogue_residual_block}
+ gemm_template, {"epilogue": self.epilogue_residual_block}
)
values.update(
{
@@ -276,9 +346,9 @@ class EmitConv2dInstance:
)
elif no_beta_scaling:
template = substitute_template(
- self.template, {"epilogue": self.epilogue_no_beta_scaling}
+ gemm_template, {"epilogue": self.epilogue_no_beta_scaling}
)
else:
-            template = substitute_template(self.template, {"epilogue": self.epilogue_default})
+            template = substitute_template(gemm_template, {"epilogue": self.epilogue_default})
return substitute_template(template, values)
diff --git a/python/tvm/contrib/cutlass/conv2d_profiler.py b/python/tvm/contrib/cutlass/conv2d_profiler.py
index 2f4e769..1ed5550 100644
--- a/python/tvm/contrib/cutlass/conv2d_profiler.py
+++ b/python/tvm/contrib/cutlass/conv2d_profiler.py
@@ -17,6 +17,8 @@
# pylint: disable=import-outside-toplevel, invalid-name
"""Instantiate a C++ source for profiling CUTLASS kernels."""
+from .library import DataTypeTag
+
class Conv2dProfilerEmitter(object):
"""Emit a C++ source for profiling CUTLASS kernels."""
@@ -24,6 +26,32 @@ class Conv2dProfilerEmitter(object):
def __init__(self):
from jinja2 import Template
+        self.reduction = """
+  ReductionDevice reduction_op;
+  static cutlass::conv::Operator const kConvolutionalOperator = ImplicitGemm::kConvolutionalOperator;
+  typename ReductionDevice::Arguments reduction_args(
+      cutlass::conv::implicit_gemm_problem_size(kConvolutionalOperator, problem_size).mn(),
+      problem_size.split_k_slices,
+      cutlass::conv::implicit_gemm_tensor_c_size(kConvolutionalOperator, problem_size),
+      {
+          reinterpret_cast<ImplicitGemm::ElementC*> (workspace.get()),
+          ReductionStrideIndex(tensor_c.stride()[ImplicitGemm::ImplicitGemmKernel::kTensorCStrideIdx])
+      },
+      {
+          tensor_d.device_data(),
+          ReductionStrideIndex(tensor_d.stride()[ImplicitGemm::ImplicitGemmKernel::kTensorCStrideIdx])
+      },
+      {
+          tensor_c.device_data(),
+          ReductionStrideIndex(tensor_c.stride()[ImplicitGemm::ImplicitGemmKernel::kTensorCStrideIdx])
+      },
+      {ElementComputeEpilogue(1), ElementComputeEpilogue(0)}
+  );
+
+  reduction_op.initialize(reduction_args, nullptr);
+  reduction_op();
+"""
+
self.template = Template(
"""
#include <iostream>
@@ -35,6 +63,8 @@ class Conv2dProfilerEmitter(object):
#include "cutlass/util/command_line.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/reference/host/tensor_fill.h"
+#include "cutlass/reduction/device/reduce_split_k.h"
+#include "cutlass/reduction/thread/reduction_operators.h"
#define CUTLASS_CHECK(status)
\
{
\
@@ -88,10 +118,11 @@ struct Options {
};
double profile_convolution(Options const &options) {
- using ElementOutput = typename ImplicitGemm::ElementC;
+ using ElementOutput = {{ElementOutput}};
using ElementInputA = typename ImplicitGemm::ElementA;
using ElementInputB = typename ImplicitGemm::ElementB;
+ int split_k_slices = {{SplitK}};
cutlass::conv::Conv2dProblemSize problem_size(
options.input_size,
options.filter_size,
@@ -100,7 +131,7 @@ double profile_convolution(Options const &options) {
options.dilation,
options.output_size(),
cutlass::conv::Mode::kCrossCorrelation,
- 1
+ split_k_slices
);
auto conv_kind = ImplicitGemm::kConvolutionalOperator;
@@ -108,20 +139,26 @@ double profile_convolution(Options const &options) {
auto b_extent = implicit_gemm_tensor_b_extent(conv_kind, problem_size);
auto c_extent = implicit_gemm_tensor_c_extent(conv_kind, problem_size);
+  using LayoutC = typename ImplicitGemm::LayoutC;
   cutlass::HostTensor<ElementInputA, typename ImplicitGemm::LayoutA> tensor_a(a_extent);
   cutlass::HostTensor<ElementInputB, typename ImplicitGemm::LayoutB> tensor_b(b_extent);
   cutlass::HostTensor<ElementOutput, typename ImplicitGemm::LayoutC> tensor_c(c_extent);
-  cutlass::HostTensor<ElementOutput, typename ImplicitGemm::LayoutC> tensor_ref_c(c_extent);
+  cutlass::HostTensor<ElementOutput, LayoutC> tensor_d(c_extent);
+  cutlass::HostTensor<ImplicitGemm::ElementC, LayoutC> tensor_c_gemm(c_extent);
using ElementComputeEpilogue = typename ImplicitGemm::ElementCompute;
+  cutlass::conv::SplitKMode const split_k_mode = split_k_slices > 1 ?
+      cutlass::conv::SplitKMode::kParallel : cutlass::conv::SplitKMode::kSerial;
+
typename ImplicitGemm::Arguments arguments{
problem_size,
tensor_a.device_ref(),
tensor_b.device_ref(),
- tensor_c.device_ref(),
- tensor_c.device_ref(),
+ tensor_c_gemm.device_ref(),
+ tensor_c_gemm.device_ref(),
{ElementComputeEpilogue(1), ElementComputeEpilogue(0)},
+ split_k_mode,
};
ImplicitGemm implicit_gemm_op;
@@ -144,6 +181,7 @@ double profile_convolution(Options const &options) {
for (int iteration = 0; iteration < 100; ++iteration) {
auto status = implicit_gemm_op();
CUTLASS_CHECK(status);
+ {{Reduction}}
}
cudaEventRecord(events[1]);
@@ -166,6 +204,12 @@ int main(int argc, char const **args) {
"""
)
- def emit(self, op_def, op_name):
- src = self.template.render(OperatorDef=op_def, OperatorName=op_name)
+ def emit(self, op_def, op_name, element_output, split_k_slices=1):
+ src = self.template.render(
+ OperatorDef=op_def,
+ OperatorName=op_name,
+ ElementOutput=DataTypeTag[element_output],
+ SplitK=split_k_slices,
+ Reduction=self.reduction if split_k_slices > 1 else "",
+ )
return src
diff --git a/python/tvm/contrib/cutlass/gen_conv2d.py b/python/tvm/contrib/cutlass/gen_conv2d.py
index 0d46000..b51afdc 100644
--- a/python/tvm/contrib/cutlass/gen_conv2d.py
+++ b/python/tvm/contrib/cutlass/gen_conv2d.py
@@ -14,7 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-# pylint: disable=invalid-name
+# pylint: disable=invalid-name, dangerous-default-value
"""Conv2d kernel generator and profiler for CUTLASS."""
from functools import partial
from .conv2d_operation import Conv2dOperation, EmitConv2dInstance
@@ -40,6 +40,7 @@ def create_conv2d_operator_with_epilogue(
data_type,
alignment,
swizzling_functor,
+ split_k_slices,
):
"""
Instantiate a cutlass kernel from the given configuration,
@@ -90,11 +91,15 @@ def create_conv2d_operator_with_epilogue(
stride_support,
epilogue,
swizzling_functor,
+ split_k_slices,
)
name = op.procedural_name()
opdef = EmitConv2dInstance().emit(
-        op, no_beta_scaling=no_beta_scaling, residual_block_info=residual_block_info
+        op,
+        no_beta_scaling=no_beta_scaling,
+        residual_block_info=residual_block_info,
+        emit_reduction=split_k_slices > 1,
)
return name, opdef
@@ -103,6 +108,7 @@ def create_conv2d_operator_with_epilogue(
def enumerate_conv2d_operators(
conv_kind,
stride_support,
+ split_k_slices,
tile_descriptions,
data_type,
alignment_constraints,
@@ -119,37 +125,45 @@ def enumerate_conv2d_operators(
if conv_kind == ConvKind.Dgrad and stride_support == StrideSupport.Strided:
swizzling_functor = SwizzlingFunctor.StridedDgradIdentity1
- for tile in tile_descriptions:
- for alignment in alignment_constraints:
-
- A = TensorDescription(element_a, LayoutType.TensorNHWC, alignment)
- B = TensorDescription(element_b, LayoutType.TensorNHWC, alignment)
- C = TensorDescription(element_c, LayoutType.TensorNHWC, alignment)
-
- op = Conv2dOperation(
- conv_kind,
- IteratorAlgorithm.Optimized,
- tile.minimum_compute_capability,
- tile,
- A,
- B,
- C,
- element_epilogue,
- stride_support,
- EpilogueFunctor.LinearCombination,
- swizzling_functor,
- )
-
- ret.append(
- {
-                    "src": profiler_emitter.emit(kernel_emitter.emit(op), op.procedural_name()),
- "name": op.procedural_name(),
- "tile_description": tile,
- "alignment": alignment,
- "data_type": data_type,
- "swizzle_functor": swizzling_functor,
- }
- )
+ for split_k_slice in split_k_slices:
+ for tile in tile_descriptions:
+ for alignment in alignment_constraints:
+
+                A = TensorDescription(element_a, LayoutType.TensorNHWC, alignment)
+                B = TensorDescription(element_b, LayoutType.TensorNHWC, alignment)
+                C = TensorDescription(element_c, LayoutType.TensorNHWC, alignment)
+
+ op = Conv2dOperation(
+ conv_kind,
+ IteratorAlgorithm.Optimized,
+ tile.minimum_compute_capability,
+ tile,
+ A,
+ B,
+ C,
+ element_epilogue,
+ stride_support,
+ EpilogueFunctor.LinearCombination,
+ swizzling_functor,
+ split_k_slice,
+ )
+
+ ret.append(
+ {
+ "src": profiler_emitter.emit(
+                            kernel_emitter.emit(op, emit_reduction=split_k_slice > 1),
+ op.procedural_name(),
+ element_output=element_c,
+ split_k_slices=split_k_slice,
+ ),
+ "name": op.procedural_name(),
+ "tile_description": tile,
+ "alignment": alignment,
+ "data_type": data_type,
+ "swizzle_functor": swizzling_functor,
+ "split_k_slices": split_k_slice,
+ }
+ )
return ret
@@ -198,6 +212,7 @@ class CutlassConv2DProfiler:
data_type,
alignment,
swizzling_functor,
+ split_k_slices=1,
)
return {"name": name, "opdef": opdef}
@@ -214,6 +229,7 @@ class CutlassConv2DProfiler:
use_3xtf32,
conv_kind,
stride_support,
+ split_k_slices,
profile_all_alignments=False,
find_first_valid=False,
use_multiprocessing=False,
@@ -248,7 +264,7 @@ class CutlassConv2DProfiler:
out_dtype,
data_dtype,
weight_dtype,
-            partial(enumerate_conv2d_operators, conv_kind, stride_support),
+            partial(enumerate_conv2d_operators, conv_kind, stride_support, split_k_slices),
lambda align: all([dim % align == 0 for dim in [IC, OC]]),
use_3xtf32,
profile_all_alignments,
@@ -288,6 +304,7 @@ class CutlassConv2DProfiler:
weight_dtype,
use_3xtf32=True,
conv_kind=ConvKind.Fprop,
+ split_k_slices=[1],
profile_all_alignments=False,
find_first_valid=False,
use_multiprocessing=False,
@@ -315,6 +332,7 @@ class CutlassConv2DProfiler:
use_3xtf32,
conv_kind,
stride_support,
+ split_k_slices,
profile_all_alignments,
find_first_valid,
use_multiprocessing,
@@ -328,6 +346,7 @@ class CutlassConv2DProfiler:
op["data_type"],
op["alignment"],
op["swizzle_functor"],
+ op["split_k_slices"],
)
return name, opdef, op["runtime"]
diff --git a/python/tvm/contrib/cutlass/gen_tensor_op.py b/python/tvm/contrib/cutlass/gen_tensor_op.py
index 76d4383..b3f40f0 100644
--- a/python/tvm/contrib/cutlass/gen_tensor_op.py
+++ b/python/tvm/contrib/cutlass/gen_tensor_op.py
@@ -363,6 +363,9 @@ class ProfilerEngine:
try:
sp = subprocess.run(cmd, capture_output=True, check=True)
rt = float(sp.stdout)
+                if rt == 0.0:
+                    # This seems to happen with split-K when using an invalid split-k-slices value
+                    rt = float("inf")
logger.info("%s, %f", op_name, rt)
except subprocess.CalledProcessError:
rt = float("inf")
diff --git a/src/relay/backend/contrib/cutlass/codegen.cc b/src/relay/backend/contrib/cutlass/codegen.cc
index fdd268d..b12da1a 100644
--- a/src/relay/backend/contrib/cutlass/codegen.cc
+++ b/src/relay/backend/contrib/cutlass/codegen.cc
@@ -284,15 +284,15 @@ std::string Conv2dOp(std::string id, const Str2StrMap& attrs,
op_type != "cutlass.conv2d_bias_silu" &&
op_type != "cutlass.conv2d_bias_hardswish";
+ const std::string op_name = attrs.at("op_name");
std::ostringstream conv2d_decl;
CutlassPrint(conv2d_decl, attrs.at("op_def"));
-  CutlassPrint(conv2d_decl, "using Operation_" + attrs.at("op_name") +
-                                " = cutlass::conv::device::ImplicitGemmConvolution<" +
-                                attrs.at("op_name") + ">;\n");
-  CutlassPrint(conv2d_decl, "using Conv2d = Operation_" + attrs.at("op_name") + ";\n");
+  CutlassPrint(conv2d_decl, "using Operation_" + op_name +
+               " = cutlass::conv::device::ImplicitGemmConvolution<" + op_name + ">;\n");
+  CutlassPrint(conv2d_decl, "using Conv2d = Operation_" + op_name + ";\n");
   CutlassPrint(conv2d_decl, "using ElementInputA = Conv2d::ElementA;\n");
   CutlassPrint(conv2d_decl, "using ElementInputB = Conv2d::ElementB;\n");
-  CutlassPrint(conv2d_decl, "using ElementOutput = Conv2d::ElementC;\n");
   CutlassPrint(conv2d_decl, "using ElementComputeEpilogue = Conv2d::ElementAccumulator;\n");
   auto get_dim = [&attrs](const std::string& axis, const std::string& var_name, int axis_idx) {
@@ -319,10 +319,25 @@ std::string Conv2dOp(std::string id, const Str2StrMap& attrs,
   CutlassPrint(conv2d_decl, "int dilation_h = " + attrs.at("dilation_h") + ";\n");
   CutlassPrint(conv2d_decl, "int dilation_w = " + attrs.at("dilation_w") + ";\n");
+  const bool use_split_k = op_name.find("splitk") != std::string::npos;
+
+  if (use_split_k) {
+    std::string split_k_slices = op_name.substr(op_name.find_last_not_of("0123456789") + 1);
+    CutlassPrint(conv2d_decl, "int split_k_slices = " + split_k_slices + ";\n");
+  } else {
+    CutlassPrint(conv2d_decl, "int split_k_slices = 1;\n");
+  }
+
   CutlassPrint(
       conv2d_decl,
       "cutlass::conv::Conv2dProblemSize problem_size(N, H, W, C, K, R, S, P, Q, pad_h, pad_w, "
-      "stride_h, stride_w, dilation_h, dilation_w, cutlass::conv::Mode::kCrossCorrelation, 1);\n");
+      "stride_h, stride_w, dilation_h, dilation_w, cutlass::conv::Mode::kCrossCorrelation, "
+      "split_k_slices);\n");
+
+  const std::string split_k_mode = use_split_k ? "kParallel" : "kSerial";
+  CutlassPrint(conv2d_decl,
+               "const cutlass::conv::SplitKMode split_k_mode = cutlass::conv::SplitKMode::" +
+               split_k_mode + ";\n");
bool is_wgrad = op_type.find("backward_weight") != std::string::npos;
bool is_dgrad = op_type.find("conv2d_transpose") != std::string::npos;
@@ -372,32 +387,51 @@ std::string Conv2dOp(std::string id, const Str2StrMap& attrs,
CutlassPrint(conv2d_decl, "TensorNHWC layout_D(output_oshape);\n\n");
}
+  if (use_split_k) {
+    CutlassPrint(conv2d_decl, "using ElementOutput = EpilogueOutputOp::ElementOutput;\n");
+  } else {
+    CutlassPrint(conv2d_decl, "using ElementOutput = Conv2d::ElementC;\n");
+  }
+
+  std::string tensor_c_init = "{static_cast<ElementOutput*>(ptr_out), layout_C}";
+  if (has_residual_block) {
+    tensor_c_init = "{static_cast<ElementOutput*>(ptr_residual), layout_C}";
+  } else if (has_bias) {
+    tensor_c_init =
+        "{static_cast<ElementOutput*>(ptr_c_bias), cutlass::layout::TensorNHWC::Stride(0)}";
+  }
+
+  CutlassPrint(conv2d_decl,
+               "cutlass::TensorRef<ElementOutput, TensorNHWC> tensor_c" + tensor_c_init + ";\n");
+  CutlassPrint(conv2d_decl,
+               "cutlass::TensorRef<ElementOutput, TensorNHWC> "
+               "tensor_d{static_cast<ElementOutput*>(ptr_out),layout_D};\n");
+
CutlassPrint(conv2d_decl, "typename Conv2d::Arguments arguments{\n");
CutlassPrint(conv2d_decl, " problem_size,\n");
   CutlassPrint(conv2d_decl, "  {static_cast<ElementInputA*>(ptr_a), layout_A},\n");
   CutlassPrint(conv2d_decl, "  {static_cast<ElementInputB*>(ptr_b), layout_B},\n");
-  if (has_residual_block) {
-    CutlassPrint(conv2d_decl, "  {static_cast<ElementOutput*>(ptr_residual), layout_C},\n");
-  } else if (has_bias) {
-    CutlassPrint(
-        conv2d_decl,
-        "  {static_cast<ElementOutput*>(ptr_c_bias), cutlass::layout::TensorNHWC::Stride(0)},\n");
+  if (use_split_k) {
+    CutlassPrint(conv2d_decl, "{nullptr, TensorNHWC()},\n");
+    CutlassPrint(conv2d_decl, "{nullptr, TensorNHWC()},\n");
   } else {
-    CutlassPrint(conv2d_decl, "  {static_cast<ElementOutput*>(ptr_out), layout_C},\n");
+    CutlassPrint(conv2d_decl, " tensor_c,\n");
+    CutlassPrint(conv2d_decl, " tensor_d,\n");
   }
-  CutlassPrint(conv2d_decl, "  {static_cast<ElementOutput*>(ptr_out),layout_D},\n");
-
if (has_residual_block) {
+    ICHECK(use_split_k == false) << "Split-k not supported for residual block fusion";
CutlassPrint(conv2d_decl, "{alpha, beta},\n");
     CutlassPrint(conv2d_decl, "cutlass::conv::SplitKMode::kSerial,\n");  // split_k_slices
CutlassPrint(conv2d_decl, "static_cast<ElementOutput*>(ptr_bias),\n");
CutlassPrint(conv2d_decl, "nullptr, 0, K};\n");
} else if (has_bias && no_bias_scaling) {
- CutlassPrint(conv2d_decl, " {alpha}\n};\n");
+ CutlassPrint(conv2d_decl, " {alpha},\n");
+ CutlassPrint(conv2d_decl, "split_k_mode\n};\n");
} else {
- CutlassPrint(conv2d_decl, "{alpha, beta}\n};\n");
+ CutlassPrint(conv2d_decl, "{alpha, beta},\n");
+ CutlassPrint(conv2d_decl, "split_k_mode\n};\n");
}
CutlassPrint(conv2d_decl, "Conv2d conv2d_op;\n");
@@ -408,13 +442,67 @@ std::string Conv2dOp(std::string id, const Str2StrMap& attrs,
               "cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);\n");
   // Check the problem size is supported or not
   CutlassPrint(conv2d_decl, "cutlass::Status status = conv2d_op.can_implement(arguments);\n");
-  CutlassPrint(conv2d_decl, "CHECK(status == cutlass::Status::kSuccess);\n");
+  CutlassPrint(conv2d_decl, "CHECK(status == cutlass::Status::kSuccess);\n\n");
+
+  if (use_split_k) {
+    CutlassPrint(conv2d_decl,
+                 "arguments.ref_D.reset(reinterpret_cast<ElementComputeEpilogue*>(workspace.get()),"
+                 " layout_D);\n\n");
+  }
+
// Initialize CUTLASS kernel with arguments and workspace pointer
   CutlassPrint(conv2d_decl, "status = conv2d_op.initialize(arguments, workspace.get());\n");
- CutlassPrint(conv2d_decl, "CHECK(status == cutlass::Status::kSuccess);\n");
+ CutlassPrint(conv2d_decl, "CHECK(status == cutlass::Status::kSuccess);\n\n");
+
+ if (use_split_k) {
+    CutlassPrint(
+        conv2d_decl,
+        "arguments.output_op = {ElementComputeEpilogue(1), ElementComputeEpilogue(0)}; \n");
+    CutlassPrint(conv2d_decl, "status = conv2d_op.update(arguments, workspace.get()); \n");
+    CutlassPrint(conv2d_decl, "CHECK(status == cutlass::Status::kSuccess);\n\n");
+ }
+
// Launch initialized CUTLASS kernel
CutlassPrint(conv2d_decl, "status = conv2d_op();\n");
- CutlassPrint(conv2d_decl, "CHECK(status == cutlass::Status::kSuccess);\n");
+ CutlassPrint(conv2d_decl, "CHECK(status == cutlass::Status::kSuccess);\n\n");
+
+  if (use_split_k) {
+    CutlassPrint(conv2d_decl, "ReductionDevice reduction_op;\n");
+    CutlassPrint(conv2d_decl,
+                 "const static cutlass::conv::Operator kConvolutionalOperator = "
+                 "Conv2d::kConvolutionalOperator;\n");
+    CutlassPrint(conv2d_decl, "typename ReductionDevice::Arguments reduction_args(\n");
+    CutlassPrint(conv2d_decl,
+                 "cutlass::conv::implicit_gemm_problem_size(kConvolutionalOperator, "
+                 "problem_size).mn(),\n");
+    CutlassPrint(conv2d_decl, "problem_size.split_k_slices,\n");
+    CutlassPrint(conv2d_decl,
+                 "cutlass::conv::implicit_gemm_tensor_c_size(kConvolutionalOperator, "
+                 "problem_size),\n");
+    CutlassPrint(conv2d_decl, "{\n");
+    CutlassPrint(conv2d_decl,
+                 " reinterpret_cast<Conv2d::ElementAccumulator*> (workspace.get()),\n");
+    CutlassPrint(conv2d_decl,
+                 "ReductionStrideIndex(tensor_c.stride()[Conv2d::ImplicitGemmKernel::"
+                 "kTensorCStrideIdx])\n");
+    CutlassPrint(conv2d_decl, "},\n");
+    CutlassPrint(conv2d_decl, "{\n");
+    CutlassPrint(conv2d_decl, "tensor_d.data(),\n");
+    CutlassPrint(conv2d_decl,
+                 "ReductionStrideIndex(tensor_d.stride()[Conv2d::ImplicitGemmKernel::"
+                 "kTensorCStrideIdx])\n");
+    CutlassPrint(conv2d_decl, "},\n");
+    CutlassPrint(conv2d_decl, "{\n");
+    CutlassPrint(conv2d_decl, "tensor_c.data(),\n");
+    CutlassPrint(conv2d_decl,
+                 "ReductionStrideIndex(tensor_c.stride()[Conv2d::ImplicitGemmKernel::"
+                 "kTensorCStrideIdx])\n");
+    CutlassPrint(conv2d_decl, "},\n");
+    CutlassPrint(conv2d_decl, " {alpha, beta}\n");
+    CutlassPrint(conv2d_decl, ");\n\n");
+    CutlassPrint(conv2d_decl, "status = reduction_op.initialize(reduction_args, nullptr);\n");
+    CutlassPrint(conv2d_decl, "status = reduction_op();\n");
+  }
return conv2d_decl.str();
}
@@ -720,6 +808,7 @@ class CutlassModuleCodegen : public CSourceModuleCodegenBase {
code_stream_ << "#include <cuda_fp16.h>\n";
code_stream_ << "#include <cutlass/cutlass.h>\n";
code_stream_ << "#include <cutlass/coord.h>\n";
+ code_stream_ << "#include <cutlass/tensor_ref.h>\n";
code_stream_ << "#include <cutlass/util/host_tensor.h>\n";
code_stream_ << "#include <cutlass/gemm/device/gemm.h>\n";
code_stream_ << "#include <cutlass/gemm/device/gemm_batched.h>\n";
@@ -734,6 +823,8 @@ class CutlassModuleCodegen : public CSourceModuleCodegenBase {
   code_stream_ << "#include <cutlass/epilogue/thread/linear_combination_hardswish.h>\n";
   code_stream_ << "#include <cutlass/epilogue/thread/linear_combination_residual_block.h>\n";
   code_stream_ << "#include <cutlass/conv/kernel/default_conv2d_fprop_with_broadcast.h>\n";
+  code_stream_ << "#include <cutlass/reduction/device/reduce_split_k.h>\n";
+  code_stream_ << "#include <cutlass/reduction/thread/reduction_operators.h>\n";
ICHECK(ref->IsInstance<FunctionNode>());
auto res = GenCutlassFunc(Downcast<Function>(ref));
diff --git a/tests/python/contrib/test_cutlass.py b/tests/python/contrib/test_cutlass.py
index ef55c74..ad75e73 100644
--- a/tests/python/contrib/test_cutlass.py
+++ b/tests/python/contrib/test_cutlass.py
@@ -253,16 +253,24 @@ def get_random_ndarray(shape, dtype):
def profile_and_build(
-    mod, params, sm, tmp_dir="./tmp", lib_path="compile.so", use_fast_math=False, use_3xtf32=True
+ mod,
+ params,
+ sm,
+ split_k_slices=[1],
+ tmp_dir="./tmp",
+ lib_path="compile.so",
+ use_fast_math=False,
+ use_3xtf32=True,
):
mod = partition_for_cutlass(mod)
mod, num_cutlass_partition = tune_cutlass_kernels(
mod,
sm,
use_3xtf32=use_3xtf32,
+ split_k_slices=split_k_slices,
profile_all_alignments=False,
find_first_valid=True,
- use_multiprocessing=False,
+ use_multiprocessing=True,
tmp_dir=tmp_dir,
)
with tvm.transform.PassContext(opt_level=3):
@@ -277,6 +285,7 @@ def profile_and_build_vm(
mod,
params,
sm,
+ split_k_slices=[1],
tmp_dir="./tmp",
lib_path="compile.so",
vmcode_path="vmcode.ro",
@@ -287,6 +296,7 @@ def profile_and_build_vm(
mod, num_cutlass_partition = tune_cutlass_kernels(
mod,
sm,
+ split_k_slices=split_k_slices,
use_3xtf32=use_3xtf32,
profile_all_alignments=False,
find_first_valid=True,
@@ -508,6 +518,7 @@ def verify_conv2d_common(
inputs,
params,
sm=80,
+ split_k_slices=[1],
atol=1e-5,
rtol=1e-5,
use_cudnn_ref=False,
@@ -543,7 +554,7 @@ def verify_conv2d_common(
)
rt_mod, _, num_cutlass_partition = profile_and_build_func(
-        mod_weight_ohwi, params, sm, use_fast_math=use_fast_math
+        mod_weight_ohwi, params, sm, split_k_slices, use_fast_math=use_fast_math
)
out = get_output_func(rt_mod, input_names, inputs)
@@ -597,6 +608,8 @@ def verify_conv2d(
np_bias = get_random_ndarray((w_shape[0],), typ.dtype)
params = {"weight": np_weight, "bias": np_bias}
+ split_k_slices = [1]
+
return verify_conv2d_common(
expr_nchw,
expr_ref,
@@ -604,6 +617,7 @@ def verify_conv2d(
[np_data],
params,
sm,
+ split_k_slices,
atol,
rtol,
use_cudnn_ref,
@@ -620,6 +634,7 @@ def verify_conv2d_backward_weight(
grad_shape,
data_shape,
sm=80,
+ split_k_slices=[1],
atol=1e-5,
rtol=1e-5,
use_cudnn_ref=False,
@@ -640,6 +655,7 @@ def verify_conv2d_backward_weight(
[np_grad, np_data],
params,
sm,
+ split_k_slices,
atol,
rtol,
use_cudnn_ref,
@@ -838,18 +854,20 @@ def test_conv2d_backward_weight():
weight_dtype=dtype,
)
- verify_conv2d_backward_weight(
- mod_nchw,
- mod_nchw,
- o_shape,
- d_shape,
- sm=80,
- atol=1e-3,
- rtol=1e-3,
- use_cudnn_ref=False,
- grad_dtype=dtype,
- data_dtype=dtype,
- )
+ for split_k_slices in [1, 8]:
+ verify_conv2d_backward_weight(
+ mod_nchw,
+ mod_nchw,
+ o_shape,
+ d_shape,
+ sm=80,
+ split_k_slices=[split_k_slices],
+ atol=1e-3,
+ rtol=1e-3,
+ use_cudnn_ref=False,
+ grad_dtype=dtype,
+ data_dtype=dtype,
+ )
def test_conv2d_bwd():