comaniac commented on a change in pull request #8909:
URL: https://github.com/apache/tvm/pull/8909#discussion_r701317606



##########
File path: python/tvm/topi/cuda/softmax.py
##########
@@ -71,41 +54,53 @@ def schedule_softmax(outs):
     #
     # TODO(tvm-team) Fix nvptx codegen or deprecate nvptx backend.
     def sched_warp_softmax():
-        if tgt.kind.name == "nvptx" or tgt.kind.name == "rocm":
-            return softmax.dtype == "float32" or softmax.dtype == "int32"
+        if tgt.kind.name in ["nvptx", "rocm"]:
+            dtype = softmax_op.output(0).dtype
+            return dtype in ["float32", "int32"]
         if tgt.kind.name != "cuda":
-            # this is used as the gpu schedule for other arches which may not 
have warp reductions
+            # this is used as the gpu schedule for other arches which
+            # may not have warp reductions
             return False
         return True
 
-    if len(softmax.shape) > 2:
-        ops = [max_elem.op, expsum.op, softmax.op]
+    if len(outs[0].shape) > 2:
+        ops = [max_elem.op, expsum.op, softmax_op]
         if delta is not None:
             ops.append(delta.op)
         if exp is not None:
             ops.append(exp.op)
+        if softmax_op != outs[0]:
+            ops.append(outs[0].op)
 
         for op in ops:
             s = schedule_injective_from_existing(s, op.output(0))
 
-    elif sched_warp_softmax():
+    elif sched_warp_softmax() and softmax_op == outs[0].op:
+        # TODO(masahi): Fix LowerThreadAllreduce pass to remove
+        # softmax_op == outs[0].op condition

Review comment:
       Thanks for the explanation. Do you think it would be better if we added a 
test for this case (expecting it not to be fused for now), and then made it 
fused when sending the follow-up fix PR?




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to