This is an automated email from the ASF dual-hosted git repository.
masahi pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new 688d9dd068 [Unity][CUTLASS][Cherry-pick] Skip profiling all conv2d output alignments when possible (#15583)
688d9dd068 is described below
commit 688d9dd06815ffee634a8643086bf662c2d1cf66
Author: masahi <[email protected]>
AuthorDate: Thu Aug 17 21:44:57 2023 +0900
[Unity][CUTLASS][Cherry-pick] Skip profiling all conv2d output alignments when possible (#15583)
skip profiling all conv2d output alignments when possible
---
python/tvm/contrib/cutlass/gen_conv2d.py | 7 ++++++-
1 file changed, 6 insertions(+), 1 deletion(-)
diff --git a/python/tvm/contrib/cutlass/gen_conv2d.py b/python/tvm/contrib/cutlass/gen_conv2d.py
index 3a9cbf1e84..3d14a427b1 100644
--- a/python/tvm/contrib/cutlass/gen_conv2d.py
+++ b/python/tvm/contrib/cutlass/gen_conv2d.py
@@ -285,6 +285,11 @@ class CutlassConv2DProfiler:
raise ValueError("Unsupported data type: %s" % dtype)
return alignments
+ alignments_c = [align for align in alignments(out_dtype) if OC % align
== 0]
+
+ if not profile_all_alignments:
+ alignments_c = [alignments_c[0]]
+
ops = GENERATOR_FUNC_TABLE[self.sm](
out_dtype,
data_dtype,
@@ -294,7 +299,7 @@ class CutlassConv2DProfiler:
conv_kind,
stride_support,
split_k_slices,
- [align for align in alignments(out_dtype) if OC % align == 0],
+ alignments_c,
),
lambda align: all([dim % align == 0 for dim in [IC]]),
use_3xtf32,