This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new 0ffd24c9f3 [Unity][Contrib] Introduce several features of cutlass profiler (#14275)
0ffd24c9f3 is described below
commit 0ffd24c9f32d16a8f3b250d85f3802de3468db1b
Author: Bohan Hou <[email protected]>
AuthorDate: Sat Mar 18 17:08:36 2023 -0400
[Unity][Contrib] Introduce several features of cutlass profiler (#14275)
- allow Conv2d using different alignment factors for input and epilogue,
which can influence performance
- store the profiler cache on disk, reducing CUTLASS profiler overhead
across different runs
- use the same set of default tile configurations as CUTLASS for sm80
https://github.com/NVIDIA/cutlass/blob/master/tools/library/scripts/generator.py#L1881
---
python/tvm/contrib/cutlass/gen_conv2d.py | 119 +++++++++++++++++-----------
python/tvm/contrib/cutlass/gen_gemm.py | 10 ++-
python/tvm/contrib/cutlass/gen_tensor_op.py | 6 +-
3 files changed, 88 insertions(+), 47 deletions(-)
diff --git a/python/tvm/contrib/cutlass/gen_conv2d.py b/python/tvm/contrib/cutlass/gen_conv2d.py
index bb26a47a55..9e9e16426b 100644
--- a/python/tvm/contrib/cutlass/gen_conv2d.py
+++ b/python/tvm/contrib/cutlass/gen_conv2d.py
@@ -16,6 +16,8 @@
# under the License.
# pylint: disable=invalid-name, dangerous-default-value
"""Conv2d kernel generator and profiler for CUTLASS."""
+import os
+import pickle
from functools import partial
from .conv2d_operation import Conv2dOperation, EmitConv2dInstance
from .gen_gemm import CutlassGemmProfiler
@@ -40,6 +42,7 @@ def create_conv2d_operator_with_epilogue(
tile_description,
data_type,
alignment,
+ alignment_epilogue,
swizzling_functor,
split_k_slices,
):
@@ -78,7 +81,7 @@ def create_conv2d_operator_with_epilogue(
A = TensorDescription(element_a, LayoutType.TensorNHWC, alignment)
B = TensorDescription(element_b, LayoutType.TensorNHWC, alignment)
- C = TensorDescription(element_c, LayoutType.TensorNHWC, alignment)
+ C = TensorDescription(element_c, LayoutType.TensorNHWC, alignment_epilogue)
op = Conv2dOperation(
conv_kind,
@@ -110,6 +113,7 @@ def enumerate_conv2d_operators(
conv_kind,
stride_support,
split_k_slices,
+ alignment_c,
tile_descriptions,
data_type,
alignment_constraints,
@@ -128,47 +132,49 @@ def enumerate_conv2d_operators(
for split_k_slice in split_k_slices:
for tile in tile_descriptions:
- for alignment in alignment_constraints:
-
- A = TensorDescription(element_a, LayoutType.TensorNHWC, alignment)
- B = TensorDescription(element_b, LayoutType.TensorNHWC, alignment)
- C = TensorDescription(element_c, LayoutType.TensorNHWC, alignment)
-
- if element_c == DataType.s32 and A.alignment == 1:
- tile.threadblock_shape[0] = min(tile.threadblock_shape[0], 128)
- tile.threadblock_shape[1] = min(tile.threadblock_shape[1], 128)
-
- op = Conv2dOperation(
- conv_kind,
- IteratorAlgorithm.Optimized,
- tile.minimum_compute_capability,
- tile,
- A,
- B,
- C,
- element_epilogue,
- stride_support,
- EpilogueFunctor.LinearCombination,
- swizzling_functor,
- split_k_slice,
- )
-
- ret.append(
- {
- "src": profiler_emitter.emit(
- kernel_emitter.emit(op, emit_reduction=split_k_slice > 1),
- op.procedural_name(),
- element_output=element_c,
- split_k_slices=split_k_slice,
- ),
- "name": op.procedural_name(),
- "tile_description": tile,
- "alignment": alignment,
- "data_type": data_type,
- "swizzle_functor": swizzling_functor,
- "split_k_slices": split_k_slice,
- }
- )
+ for alignmentAB in alignment_constraints:
+ for alignmentC in alignment_c:
+
+ A = TensorDescription(element_a, LayoutType.TensorNHWC, alignmentAB)
+ B = TensorDescription(element_b, LayoutType.TensorNHWC, alignmentAB)
+ C = TensorDescription(element_c, LayoutType.TensorNHWC, alignmentC)
+
+ if element_c == DataType.s32 and A.alignment == 1:
+ tile.threadblock_shape[0] = min(tile.threadblock_shape[0], 128)
+ tile.threadblock_shape[1] = min(tile.threadblock_shape[1], 128)
+
+ op = Conv2dOperation(
+ conv_kind,
+ IteratorAlgorithm.Optimized,
+ tile.minimum_compute_capability,
+ tile,
+ A,
+ B,
+ C,
+ element_epilogue,
+ stride_support,
+ EpilogueFunctor.LinearCombination,
+ swizzling_functor,
+ split_k_slice,
+ )
+
+ ret.append(
+ {
+ "src": profiler_emitter.emit(
+ kernel_emitter.emit(op, emit_reduction=split_k_slice > 1),
+ op.procedural_name(),
+ element_output=element_c,
+ split_k_slices=split_k_slice,
+ ),
+ "name": op.procedural_name(),
+ "tile_description": tile,
+ "alignment": alignmentAB,
+ "alignment_epilogue": alignmentC,
+ "data_type": data_type,
+ "swizzle_functor": swizzling_functor,
+ "split_k_slices": split_k_slice,
+ }
+ )
return ret
@@ -181,7 +187,11 @@ class CutlassConv2DProfiler:
self.sm = sm
assert sm in GENERATOR_FUNC_TABLE, "sm%d not supported yet." % sm
self.engine = ProfilerEngine(sm, cutlass_path, binary_path)
- self.cache = {}
+ self.cache_path = os.path.join(binary_path, "cutlass_conv2d_cache.pickle")
+ if os.path.exists(self.cache_path):
+ self.cache = pickle.load(open(self.cache_path, "rb"))
+ else:
+ self.cache = {}
def get_default(
self,
@@ -216,6 +226,7 @@ class CutlassConv2DProfiler:
tile_description,
data_type,
alignment,
+ alignment,
swizzling_functor,
split_k_slices=1,
)
@@ -265,12 +276,27 @@ class CutlassConv2DProfiler:
if workload in self.cache:
return self.cache[workload]
+ def alignments(dtype):
+ if dtype in ["float16"]:
+ alignments = [8, 4, 2, 1]
+ elif dtype in ["float", "float32"]:
+ alignments = [4, 2, 1]
+ else:
+ raise ValueError("Unsupported data type: %s" % dtype)
+ return alignments
+
ops = GENERATOR_FUNC_TABLE[self.sm](
out_dtype,
data_dtype,
weight_dtype,
- partial(enumerate_conv2d_operators, conv_kind, stride_support, split_k_slices),
- lambda align: all([dim % align == 0 for dim in [IC, OC]]),
+ partial(
+ enumerate_conv2d_operators,
+ conv_kind,
+ stride_support,
+ split_k_slices,
+ [align for align in alignments(out_dtype) if OC % align == 0],
+ ),
+ lambda align: all([dim % align == 0 for dim in [IC]]),
use_3xtf32,
profile_all_alignments,
# Use fp32 accumulation for wgrad to align with cuDNN
@@ -294,6 +320,8 @@ class CutlassConv2DProfiler:
op = min(ops, key=lambda i: i["runtime"])
self.cache[workload] = op
+ with open(self.cache_path, "wb") as f:
+ pickle.dump(self.cache, f)
return op
def profile(
@@ -350,6 +378,7 @@ class CutlassConv2DProfiler:
op["tile_description"],
op["data_type"],
op["alignment"],
+ op["alignment_epilogue"],
op["swizzle_functor"],
op["split_k_slices"],
)
diff --git a/python/tvm/contrib/cutlass/gen_gemm.py b/python/tvm/contrib/cutlass/gen_gemm.py
index f5f160a400..0ea6231b81 100644
--- a/python/tvm/contrib/cutlass/gen_gemm.py
+++ b/python/tvm/contrib/cutlass/gen_gemm.py
@@ -16,6 +16,8 @@
# under the License.
# pylint: disable=invalid-name
"""GEMM kernel generator and profiler for CUTLASS."""
+import os
+import pickle
from functools import partial
from .gemm_operation import EmitGemmInstance, GemmOperation
@@ -195,7 +197,11 @@ class CutlassGemmProfiler:
assert sm in GENERATOR_FUNC_TABLE and sm in DEFAULT_KERNELS, "sm%d not supported yet." % sm
self.engine = ProfilerEngine(sm, cutlass_path, binary_path)
self.sm = sm
- self.cache = {}
+ self.cache_path = os.path.join(binary_path, "cutlass_gemm_cache.pickle")
+ if os.path.exists(self.cache_path):
+ self.cache = pickle.load(open(self.cache_path, "rb"))
+ else:
+ self.cache = {}
def get_default(
self,
@@ -294,6 +300,8 @@ class CutlassGemmProfiler:
op = min(ops, key=lambda i: i["runtime"])
self.cache[(M, N, K)] = op
+ with open(self.cache_path, "wb") as f:
+ pickle.dump(self.cache, f)
return op
def profile(
diff --git a/python/tvm/contrib/cutlass/gen_tensor_op.py b/python/tvm/contrib/cutlass/gen_tensor_op.py
index 6b2587a0b0..177304e102 100644
--- a/python/tvm/contrib/cutlass/gen_tensor_op.py
+++ b/python/tvm/contrib/cutlass/gen_tensor_op.py
@@ -230,8 +230,9 @@ def generate_sm80_tensor_op_16816(
def get_default_tile_descriptions(block_k_factor):
return [
- ([256, 128, int(32 * block_k_factor)], 3, [4, 2, 1], min_cc, max_cc),
([128, 256, int(32 * block_k_factor)], 3, [2, 4, 1], min_cc, max_cc),
+ ([256, 128, int(32 * block_k_factor)], 3, [4, 2, 1], min_cc, max_cc),
+ ([256, 64, int(32 * block_k_factor)], 3, [4, 1, 1], min_cc, max_cc),
([256, 64, int(32 * block_k_factor)], 4, [4, 1, 1], min_cc, max_cc),
([64, 256, int(32 * block_k_factor)], 4, [1, 4, 1], min_cc, max_cc),
([128, 128, int(32 * block_k_factor)], 3, [2, 2, 1], min_cc, max_cc),
@@ -245,6 +246,9 @@ def generate_sm80_tensor_op_16816(
([256, 64, int(64 * block_k_factor)], 4, [4, 1, 1], min_cc, max_cc_smem_limited),
([64, 256, int(64 * block_k_factor)], 4, [1, 4, 1], min_cc, max_cc_smem_limited),
([128, 128, int(64 * block_k_factor)], 4, [2, 2, 1], min_cc, max_cc),
+ ([256, 64, int(64 * block_k_factor)], 3, [4, 1, 1], min_cc, max_cc),
+ ([64, 256, int(64 * block_k_factor)], 3, [1, 4, 1], min_cc, max_cc),
+ ([128, 128, int(64 * block_k_factor)], 3, [2, 2, 1], min_cc, max_cc),
([128, 64, int(64 * block_k_factor)], 3, [2, 2, 1], min_cc, max_cc),
([64, 128, int(64 * block_k_factor)], 3, [2, 2, 1], min_cc, max_cc),
([64, 64, int(64 * block_k_factor)], 5, [2, 2, 1], min_cc, max_cc),