This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new 0ffd24c9f3 [Unity][Contrib] Introduce several features of cutlass profiler (#14275)
0ffd24c9f3 is described below

commit 0ffd24c9f32d16a8f3b250d85f3802de3468db1b
Author: Bohan Hou <[email protected]>
AuthorDate: Sat Mar 18 17:08:36 2023 -0400

    [Unity][Contrib] Introduce several features of cutlass profiler (#14275)
    
    - allow Conv2d to use different alignment factors for input and epilogue, which can influence performance
    - store the profiler cache on disk, reducing CUTLASS profiler overhead across different runs
    - use the same set of default tile configurations as CUTLASS for sm80:
      https://github.com/NVIDIA/cutlass/blob/master/tools/library/scripts/generator.py#L1881
---
 python/tvm/contrib/cutlass/gen_conv2d.py    | 119 +++++++++++++++++-----------
 python/tvm/contrib/cutlass/gen_gemm.py      |  10 ++-
 python/tvm/contrib/cutlass/gen_tensor_op.py |   6 +-
 3 files changed, 88 insertions(+), 47 deletions(-)

diff --git a/python/tvm/contrib/cutlass/gen_conv2d.py 
b/python/tvm/contrib/cutlass/gen_conv2d.py
index bb26a47a55..9e9e16426b 100644
--- a/python/tvm/contrib/cutlass/gen_conv2d.py
+++ b/python/tvm/contrib/cutlass/gen_conv2d.py
@@ -16,6 +16,8 @@
 # under the License.
 # pylint: disable=invalid-name, dangerous-default-value
 """Conv2d kernel generator and profiler for CUTLASS."""
+import os
+import pickle
 from functools import partial
 from .conv2d_operation import Conv2dOperation, EmitConv2dInstance
 from .gen_gemm import CutlassGemmProfiler
@@ -40,6 +42,7 @@ def create_conv2d_operator_with_epilogue(
     tile_description,
     data_type,
     alignment,
+    alignment_epilogue,
     swizzling_functor,
     split_k_slices,
 ):
@@ -78,7 +81,7 @@ def create_conv2d_operator_with_epilogue(
 
     A = TensorDescription(element_a, LayoutType.TensorNHWC, alignment)
     B = TensorDescription(element_b, LayoutType.TensorNHWC, alignment)
-    C = TensorDescription(element_c, LayoutType.TensorNHWC, alignment)
+    C = TensorDescription(element_c, LayoutType.TensorNHWC, alignment_epilogue)
 
     op = Conv2dOperation(
         conv_kind,
@@ -110,6 +113,7 @@ def enumerate_conv2d_operators(
     conv_kind,
     stride_support,
     split_k_slices,
+    alignment_c,
     tile_descriptions,
     data_type,
     alignment_constraints,
@@ -128,47 +132,49 @@ def enumerate_conv2d_operators(
 
     for split_k_slice in split_k_slices:
         for tile in tile_descriptions:
-            for alignment in alignment_constraints:
-
-                A = TensorDescription(element_a, LayoutType.TensorNHWC, 
alignment)
-                B = TensorDescription(element_b, LayoutType.TensorNHWC, 
alignment)
-                C = TensorDescription(element_c, LayoutType.TensorNHWC, 
alignment)
-
-                if element_c == DataType.s32 and A.alignment == 1:
-                    tile.threadblock_shape[0] = min(tile.threadblock_shape[0], 
128)
-                    tile.threadblock_shape[1] = min(tile.threadblock_shape[1], 
128)
-
-                op = Conv2dOperation(
-                    conv_kind,
-                    IteratorAlgorithm.Optimized,
-                    tile.minimum_compute_capability,
-                    tile,
-                    A,
-                    B,
-                    C,
-                    element_epilogue,
-                    stride_support,
-                    EpilogueFunctor.LinearCombination,
-                    swizzling_functor,
-                    split_k_slice,
-                )
-
-                ret.append(
-                    {
-                        "src": profiler_emitter.emit(
-                            kernel_emitter.emit(op, 
emit_reduction=split_k_slice > 1),
-                            op.procedural_name(),
-                            element_output=element_c,
-                            split_k_slices=split_k_slice,
-                        ),
-                        "name": op.procedural_name(),
-                        "tile_description": tile,
-                        "alignment": alignment,
-                        "data_type": data_type,
-                        "swizzle_functor": swizzling_functor,
-                        "split_k_slices": split_k_slice,
-                    }
-                )
+            for alignmentAB in alignment_constraints:
+                for alignmentC in alignment_c:
+
+                    A = TensorDescription(element_a, LayoutType.TensorNHWC, 
alignmentAB)
+                    B = TensorDescription(element_b, LayoutType.TensorNHWC, 
alignmentAB)
+                    C = TensorDescription(element_c, LayoutType.TensorNHWC, 
alignmentC)
+
+                    if element_c == DataType.s32 and A.alignment == 1:
+                        tile.threadblock_shape[0] = 
min(tile.threadblock_shape[0], 128)
+                        tile.threadblock_shape[1] = 
min(tile.threadblock_shape[1], 128)
+
+                    op = Conv2dOperation(
+                        conv_kind,
+                        IteratorAlgorithm.Optimized,
+                        tile.minimum_compute_capability,
+                        tile,
+                        A,
+                        B,
+                        C,
+                        element_epilogue,
+                        stride_support,
+                        EpilogueFunctor.LinearCombination,
+                        swizzling_functor,
+                        split_k_slice,
+                    )
+
+                    ret.append(
+                        {
+                            "src": profiler_emitter.emit(
+                                kernel_emitter.emit(op, 
emit_reduction=split_k_slice > 1),
+                                op.procedural_name(),
+                                element_output=element_c,
+                                split_k_slices=split_k_slice,
+                            ),
+                            "name": op.procedural_name(),
+                            "tile_description": tile,
+                            "alignment": alignmentAB,
+                            "alignment_epilogue": alignmentC,
+                            "data_type": data_type,
+                            "swizzle_functor": swizzling_functor,
+                            "split_k_slices": split_k_slice,
+                        }
+                    )
 
     return ret
 
@@ -181,7 +187,11 @@ class CutlassConv2DProfiler:
         self.sm = sm
         assert sm in GENERATOR_FUNC_TABLE, "sm%d not supported yet." % sm
         self.engine = ProfilerEngine(sm, cutlass_path, binary_path)
-        self.cache = {}
+        self.cache_path = os.path.join(binary_path, 
"cutlass_conv2d_cache.pickle")
+        if os.path.exists(self.cache_path):
+            self.cache = pickle.load(open(self.cache_path, "rb"))
+        else:
+            self.cache = {}
 
     def get_default(
         self,
@@ -216,6 +226,7 @@ class CutlassConv2DProfiler:
             tile_description,
             data_type,
             alignment,
+            alignment,
             swizzling_functor,
             split_k_slices=1,
         )
@@ -265,12 +276,27 @@ class CutlassConv2DProfiler:
         if workload in self.cache:
             return self.cache[workload]
 
+        def alignments(dtype):
+            if dtype in ["float16"]:
+                alignments = [8, 4, 2, 1]
+            elif dtype in ["float", "float32"]:
+                alignments = [4, 2, 1]
+            else:
+                raise ValueError("Unsupported data type: %s" % dtype)
+            return alignments
+
         ops = GENERATOR_FUNC_TABLE[self.sm](
             out_dtype,
             data_dtype,
             weight_dtype,
-            partial(enumerate_conv2d_operators, conv_kind, stride_support, 
split_k_slices),
-            lambda align: all([dim % align == 0 for dim in [IC, OC]]),
+            partial(
+                enumerate_conv2d_operators,
+                conv_kind,
+                stride_support,
+                split_k_slices,
+                [align for align in alignments(out_dtype) if OC % align == 0],
+            ),
+            lambda align: all([dim % align == 0 for dim in [IC]]),
             use_3xtf32,
             profile_all_alignments,
             # Use fp32 accumulation for wgrad to align with cuDNN
@@ -294,6 +320,8 @@ class CutlassConv2DProfiler:
 
         op = min(ops, key=lambda i: i["runtime"])
         self.cache[workload] = op
+        with open(self.cache_path, "wb") as f:
+            pickle.dump(self.cache, f)
         return op
 
     def profile(
@@ -350,6 +378,7 @@ class CutlassConv2DProfiler:
             op["tile_description"],
             op["data_type"],
             op["alignment"],
+            op["alignment_epilogue"],
             op["swizzle_functor"],
             op["split_k_slices"],
         )
diff --git a/python/tvm/contrib/cutlass/gen_gemm.py 
b/python/tvm/contrib/cutlass/gen_gemm.py
index f5f160a400..0ea6231b81 100644
--- a/python/tvm/contrib/cutlass/gen_gemm.py
+++ b/python/tvm/contrib/cutlass/gen_gemm.py
@@ -16,6 +16,8 @@
 # under the License.
 # pylint: disable=invalid-name
 """GEMM kernel generator and profiler for CUTLASS."""
+import os
+import pickle
 from functools import partial
 
 from .gemm_operation import EmitGemmInstance, GemmOperation
@@ -195,7 +197,11 @@ class CutlassGemmProfiler:
         assert sm in GENERATOR_FUNC_TABLE and sm in DEFAULT_KERNELS, "sm%d not 
supported yet." % sm
         self.engine = ProfilerEngine(sm, cutlass_path, binary_path)
         self.sm = sm
-        self.cache = {}
+        self.cache_path = os.path.join(binary_path, 
"cutlass_gemm_cache.pickle")
+        if os.path.exists(self.cache_path):
+            self.cache = pickle.load(open(self.cache_path, "rb"))
+        else:
+            self.cache = {}
 
     def get_default(
         self,
@@ -294,6 +300,8 @@ class CutlassGemmProfiler:
 
         op = min(ops, key=lambda i: i["runtime"])
         self.cache[(M, N, K)] = op
+        with open(self.cache_path, "wb") as f:
+            pickle.dump(self.cache, f)
         return op
 
     def profile(
diff --git a/python/tvm/contrib/cutlass/gen_tensor_op.py 
b/python/tvm/contrib/cutlass/gen_tensor_op.py
index 6b2587a0b0..177304e102 100644
--- a/python/tvm/contrib/cutlass/gen_tensor_op.py
+++ b/python/tvm/contrib/cutlass/gen_tensor_op.py
@@ -230,8 +230,9 @@ def generate_sm80_tensor_op_16816(
 
     def get_default_tile_descriptions(block_k_factor):
         return [
-            ([256, 128, int(32 * block_k_factor)], 3, [4, 2, 1], min_cc, 
max_cc),
             ([128, 256, int(32 * block_k_factor)], 3, [2, 4, 1], min_cc, 
max_cc),
+            ([256, 128, int(32 * block_k_factor)], 3, [4, 2, 1], min_cc, 
max_cc),
+            ([256, 64, int(32 * block_k_factor)], 3, [4, 1, 1], min_cc, 
max_cc),
             ([256, 64, int(32 * block_k_factor)], 4, [4, 1, 1], min_cc, 
max_cc),
             ([64, 256, int(32 * block_k_factor)], 4, [1, 4, 1], min_cc, 
max_cc),
             ([128, 128, int(32 * block_k_factor)], 3, [2, 2, 1], min_cc, 
max_cc),
@@ -245,6 +246,9 @@ def generate_sm80_tensor_op_16816(
             ([256, 64, int(64 * block_k_factor)], 4, [4, 1, 1], min_cc, 
max_cc_smem_limited),
             ([64, 256, int(64 * block_k_factor)], 4, [1, 4, 1], min_cc, 
max_cc_smem_limited),
             ([128, 128, int(64 * block_k_factor)], 4, [2, 2, 1], min_cc, 
max_cc),
+            ([256, 64, int(64 * block_k_factor)], 3, [4, 1, 1], min_cc, 
max_cc),
+            ([64, 256, int(64 * block_k_factor)], 3, [1, 4, 1], min_cc, 
max_cc),
+            ([128, 128, int(64 * block_k_factor)], 3, [2, 2, 1], min_cc, 
max_cc),
             ([128, 64, int(64 * block_k_factor)], 3, [2, 2, 1], min_cc, 
max_cc),
             ([64, 128, int(64 * block_k_factor)], 3, [2, 2, 1], min_cc, 
max_cc),
             ([64, 64, int(64 * block_k_factor)], 5, [2, 2, 1], min_cc, max_cc),

Reply via email to