This is an automated email from the ASF dual-hosted git repository.
comaniac pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 5c1a1cf [CUDA] Improve injective schedule to enable half2 (#8457)
5c1a1cf is described below
commit 5c1a1cf7289b439b0042a85b63b0007dc1d9b98a
Author: Cody Yu <[email protected]>
AuthorDate: Tue Jul 13 17:57:19 2021 -0700
[CUDA] Improve injective schedule to enable half2 (#8457)
* [CUDA] Improve injective schedule to enable half2
* lint
* fix
* trigger ci
---
python/tvm/topi/cuda/injective.py | 36 +++++++++++++++++++++++++++++++++---
1 file changed, 33 insertions(+), 3 deletions(-)
diff --git a/python/tvm/topi/cuda/injective.py b/python/tvm/topi/cuda/injective.py
index cce56b7..0faddc3 100644
--- a/python/tvm/topi/cuda/injective.py
+++ b/python/tvm/topi/cuda/injective.py
@@ -16,6 +16,8 @@
# under the License.
# pylint: disable=invalid-name, unused-variable,
"""Schedule for composition of injective operator"""
+import numpy as np
+
import tvm
from tvm import te
from .. import utils
@@ -36,13 +38,21 @@ def schedule_injective_from_existing(sch, out):
sch: Schedule
The updated schedule.
"""
+
+ def find_nearest_small_factor(num, target):
+ """Find the nearest factor of the given number that is smaller than the target."""
+ for i in range(target, 0, -1):
+ if num % i == 0:
+ return i
+ # Unreachable because i=1 must hold.
+ return -1
+
fused = sch[out].fuse(*sch[out].op.axis)
num_thread = tvm.target.Target.current(allow_none=False).max_num_threads
max_block = 256
- # vectorize on fp16 data type. This allows to better utilize the memory
- # bandwidth.
- vector_width = 4 if out.dtype == "float16" else 1
+ # Vectorize on fp16 data type to enable half2 for better memory bandwidth utilization.
+ vector_width = 2 if out.dtype == "float16" else 1
is_dynamic_output = False
for dim in out.shape:
@@ -54,6 +64,26 @@ def schedule_injective_from_existing(sch, out):
try:
const_size = utils.get_const_int(out_len)
+
+ # Adjust block and thread to make sure they are dividable so that vectorize can be
+ # correctly applied.
+ if vector_width > 1 and const_size % vector_width == 0:
+ remain_total_size = const_size // vector_width
+ cand_sizes = []
+ for max_size in [num_thread, max_block]:
+ cand_sizes.append(
+ max_size
+ if remain_total_size % max_size == 0
+ else find_nearest_small_factor(remain_total_size, max_size)
+ )
+ remain_total_size //= cand_sizes[-1]
+
+ # If the product of candidate dividable (block * thread) is too small,
+ # then the performance may be worse even half2 is enabled. Note that 0.7
+ # is just a heuristic ratio and may not be optimal for all workloads.
+ if np.prod(cand_sizes) / (max_block * num_thread) >= 0.7:
+ num_thread, max_block = cand_sizes
+
need_block_split = const_size > max_block * num_thread * vector_width
except ValueError:
need_block_split = False