masahi commented on a change in pull request #8909:
URL: https://github.com/apache/tvm/pull/8909#discussion_r701406339



##########
File path: python/tvm/topi/cuda/softmax.py
##########
@@ -71,41 +54,53 @@ def schedule_softmax(outs):
     #
     # TODO(tvm-team) Fix nvptx codegen or deprecate nvptx backend.
     def sched_warp_softmax():
-        if tgt.kind.name == "nvptx" or tgt.kind.name == "rocm":
-            return softmax.dtype == "float32" or softmax.dtype == "int32"
+        if tgt.kind.name in ["nvptx", "rocm"]:
+            dtype = softmax_op.output(0).dtype
+            return dtype in ["float32", "int32"]
         if tgt.kind.name != "cuda":
-        # this is used as the gpu schedule for other arches which may not have warp reductions
+            # this is used as the gpu schedule for other arches which
+            # may not have warp reductions
             return False
         return True
 
-    if len(softmax.shape) > 2:
-        ops = [max_elem.op, expsum.op, softmax.op]
+    if len(outs[0].shape) > 2:
+        ops = [max_elem.op, expsum.op, softmax_op]
         if delta is not None:
             ops.append(delta.op)
         if exp is not None:
             ops.append(exp.op)
+        if softmax_op != outs[0]:
+            ops.append(outs[0].op)
 
         for op in ops:
             s = schedule_injective_from_existing(s, op.output(0))
 
-    elif sched_warp_softmax():
+    elif sched_warp_softmax() and softmax_op == outs[0].op:
+        # TODO(masahi): Fix LowerThreadAllreduce pass to remove
+        # softmax_op == outs[0].op condition

Review comment:
       For the cuda backend, the softmax schedule always tries to use this warp 
reduction schedule. So the test case added in this PR would hit this pass, 
which will be skipped until my next bug fix PR (will be coming as soon as this 
PR is merged).




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to