This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new 688d9dd068 [Unity][CUTLASS][Cherry-pick] Skip profiling all conv2d output alignments when possible (#15583)
688d9dd068 is described below

commit 688d9dd06815ffee634a8643086bf662c2d1cf66
Author: masahi <[email protected]>
AuthorDate: Thu Aug 17 21:44:57 2023 +0900

    [Unity][CUTLASS][Cherry-pick] Skip profiling all conv2d output alignments when possible (#15583)
    
    Skip profiling all conv2d output alignments when possible.
---
 python/tvm/contrib/cutlass/gen_conv2d.py | 7 ++++++-
 1 file changed, 6 insertions(+), 1 deletion(-)

diff --git a/python/tvm/contrib/cutlass/gen_conv2d.py b/python/tvm/contrib/cutlass/gen_conv2d.py
index 3a9cbf1e84..3d14a427b1 100644
--- a/python/tvm/contrib/cutlass/gen_conv2d.py
+++ b/python/tvm/contrib/cutlass/gen_conv2d.py
@@ -285,6 +285,11 @@ class CutlassConv2DProfiler:
                 raise ValueError("Unsupported data type: %s" % dtype)
             return alignments
 
+        alignments_c = [align for align in alignments(out_dtype) if OC % align 
== 0]
+
+        if not profile_all_alignments:
+            alignments_c = [alignments_c[0]]
+
         ops = GENERATOR_FUNC_TABLE[self.sm](
             out_dtype,
             data_dtype,
@@ -294,7 +299,7 @@ class CutlassConv2DProfiler:
                 conv_kind,
                 stride_support,
                 split_k_slices,
-                [align for align in alignments(out_dtype) if OC % align == 0],
+                alignments_c,
             ),
             lambda align: all([dim % align == 0 for dim in [IC]]),
             use_3xtf32,

Reply via email to