wpan11nv commented on a change in pull request #5600:
URL: https://github.com/apache/incubator-tvm/pull/5600#discussion_r429998819



##########
File path: topi/python/topi/cuda/softmax.py
##########
@@ -53,13 +54,62 @@ def schedule_softmax(outs):
         raise ValueError('Tag is expected to be softmax_output or 
log_softmax_output. \
                          Got {0}'.format(op_tag))
 
+    # The nvptx backend only supports 32-bits warp shuffle instructions.
+    #
+    # TODO(tvm-team) Fix nvptx codegen or deprecate nvptx backend.
+    def sched_warp_softmax():
+        if tgt.target_name == "nvptx":
+            return softmax.dtype == "float32" or softmax.dtype == "int32"
+        return True
+
     if len(softmax.shape) > 2:
         ops = [max_elem.op, expsum.op, softmax.op]
         if exp is not None:
             ops.append(exp.op)
 
         for op in ops:
             s = schedule_injective_from_existing(s, op.output(0))
+
+    elif sched_warp_softmax():
+        # A warp of 32 threads performs a row reduction.
+        num_thread = tgt.thread_warp_size
+        block_x = te.thread_axis("blockIdx.x")
+        thread_x = te.thread_axis((0, num_thread), "threadIdx.x")
+
+        # (4) softmax
+        xo, xi = s[softmax].split(softmax.op.axis[1], nparts=num_thread)
+        if tgt.target_name != "nvptx":
+            _, xii = s[softmax].split(xi, factor=4)
+            s[softmax].vectorize(xii)
+        s[softmax].bind(xo, thread_x)
+        s[softmax].bind(softmax.op.axis[0], block_x)
+
+        # (3) expsum
+        k = expsum.op.reduce_axis[0]
+        ko, _ = s[expsum].split(k, nparts=num_thread)
+        s[expsum].bind(ko, thread_x)
+        s[expsum].compute_at(s[softmax], xo)
+
+        # (2) exp
+        if exp is not None:
+            xo, xi = s[exp].split(exp.op.axis[1], nparts=num_thread)
+            _, xii = s[exp].split(xi, factor=4)
+            s[exp].vectorize(xii)

Review comment:
       Good point; I forgot why I added this nvptx check. It is now removed.




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


Reply via email to