Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2021/09/02 09:48:24 UTC

[GitHub] [tvm] masahi commented on a change in pull request #8909: [Relay, TOPI] Make Softmax op fusible with elemwise ops

masahi commented on a change in pull request #8909:
URL: https://github.com/apache/tvm/pull/8909#discussion_r700927350



##########
File path: python/tvm/topi/cuda/softmax.py
##########
@@ -71,41 +54,53 @@ def schedule_softmax(outs):
     #
     # TODO(tvm-team) Fix nvptx codegen or deprecate nvptx backend.
     def sched_warp_softmax():
-        if tgt.kind.name == "nvptx" or tgt.kind.name == "rocm":
-            return softmax.dtype == "float32" or softmax.dtype == "int32"
+        if tgt.kind.name in ["nvptx", "rocm"]:
+            dtype = softmax_op.output(0).dtype
+            return dtype in ["float32", "int32"]
         if tgt.kind.name != "cuda":
-            # this is used as the gpu schedule for other arches which may not have warp reductions
+            # this is used as the gpu schedule for other arches which
+            # may not have warp reductions
             return False
         return True
 
-    if len(softmax.shape) > 2:
-        ops = [max_elem.op, expsum.op, softmax.op]
+    if len(outs[0].shape) > 2:
+        ops = [max_elem.op, expsum.op, softmax_op]
         if delta is not None:
             ops.append(delta.op)
         if exp is not None:
             ops.append(exp.op)
+        if softmax_op != outs[0].op:
+            ops.append(outs[0].op)
 
         for op in ops:
             s = schedule_injective_from_existing(s, op.output(0))
 
-    elif sched_warp_softmax():
+    elif sched_warp_softmax() and softmax_op == outs[0].op:
+        # TODO(masahi): Fix LowerThreadAllreduce pass to remove
+        # softmax_op == outs[0].op condition

Review comment:
       I had to add the check `softmax_op == outs[0].op` to disable the warp reduction schedule when there are elemwise ops to be fused. This is due to a bug in `lower_thread_allreduce.cc`, for which I have a fix ready locally.
   
    But constructing a test case requires the softmax schedule changes in this PR, so I'll send the bug-fix PR after this one is merged. This workaround will be removed in that follow-up PR.
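
    As an illustration (not from this PR's test suite) of the fusion pattern the guard is about, here is a minimal sketch: a softmax followed by an elementwise cast, which is the case where `outs[0].op` is no longer the softmax op and the schedule must fall back to the non-warp path. The shapes, cast dtype, and target below are placeholders.

```python
# Hypothetical repro sketch, not part of this PR: softmax followed by an
# elementwise cast, i.e. the pattern that makes outs[0].op differ from
# softmax_op in schedule_softmax. Shapes, dtypes, and target are placeholders.
import tvm
from tvm import relay

x = relay.var("x", shape=(1, 16, 32, 32), dtype="float32")
# The cast is the elemwise op that gets fused after the softmax.
y = relay.cast(relay.nn.softmax(x, axis=1), "float16")
mod = tvm.IRModule.from_expr(relay.Function([x], y))

# With the guard added in this diff, the warp-reduction schedule is skipped
# for the fused case and the injective fallback is used instead.
with tvm.transform.PassContext(opt_level=3):
    lib = relay.build(mod, target="cuda")
```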



