This is an automated email from the ASF dual-hosted git repository.

comaniac pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 5c1a1cf  [CUDA] Improve injective schedule to enable half2 (#8457)
5c1a1cf is described below

commit 5c1a1cf7289b439b0042a85b63b0007dc1d9b98a
Author: Cody Yu <[email protected]>
AuthorDate: Tue Jul 13 17:57:19 2021 -0700

    [CUDA] Improve injective schedule to enable half2 (#8457)
    
    * [CUDA] Improve injective schedule to enable half2
    
    * lint
    
    * fix
    
    * trigger ci
---
 python/tvm/topi/cuda/injective.py | 36 +++++++++++++++++++++++++++++++++---
 1 file changed, 33 insertions(+), 3 deletions(-)

diff --git a/python/tvm/topi/cuda/injective.py b/python/tvm/topi/cuda/injective.py
index cce56b7..0faddc3 100644
--- a/python/tvm/topi/cuda/injective.py
+++ b/python/tvm/topi/cuda/injective.py
@@ -16,6 +16,8 @@
 # under the License.
 # pylint: disable=invalid-name, unused-variable,
 """Schedule for composition of injective operator"""
+import numpy as np
+
 import tvm
 from tvm import te
 from .. import utils
@@ -36,13 +38,21 @@ def schedule_injective_from_existing(sch, out):
     sch: Schedule
          The updated schedule.
     """
+
+    def find_nearest_small_factor(num, target):
+        """Find the nearest factor of the given number that is smaller than the target."""
+        for i in range(target, 0, -1):
+            if num % i == 0:
+                return i
+        # Unreachable because i=1 must hold.
+        return -1
+
     fused = sch[out].fuse(*sch[out].op.axis)
     num_thread = tvm.target.Target.current(allow_none=False).max_num_threads
     max_block = 256
 
-    # vectorize on fp16 data type. This allows to better utilize the memory
-    # bandwidth.
-    vector_width = 4 if out.dtype == "float16" else 1
+    # Vectorize on fp16 data type to enable half2 for better memory bandwidth utilization.
+    vector_width = 2 if out.dtype == "float16" else 1
 
     is_dynamic_output = False
     for dim in out.shape:
@@ -54,6 +64,26 @@ def schedule_injective_from_existing(sch, out):
 
     try:
         const_size = utils.get_const_int(out_len)
+
+        # Adjust block and thread to make sure they are dividable so that vectorize can be
+        # correctly applied.
+        if vector_width > 1 and const_size % vector_width == 0:
+            remain_total_size = const_size // vector_width
+            cand_sizes = []
+            for max_size in [num_thread, max_block]:
+                cand_sizes.append(
+                    max_size
+                    if remain_total_size % max_size == 0
+                    else find_nearest_small_factor(remain_total_size, max_size)
+                )
+                remain_total_size //= cand_sizes[-1]
+
+            # If the product of candidate dividable (block * thread) is too small,
+            # then the performance may be worse even half2 is enabled. Note that 0.7
+            # is just a heuristic ratio and may not be optimal for all workloads.
+            if np.prod(cand_sizes) / (max_block * num_thread) >= 0.7:
+                num_thread, max_block = cand_sizes
+
         need_block_split = const_size > max_block * num_thread * vector_width
     except ValueError:
         need_block_split = False

Reply via email to