Posted to commits@tvm.apache.org by mo...@apache.org on 2022/07/13 12:08:11 UTC

[tvm] branch main updated: [TOPI, x86] Properly handle fused ops in TE softmax schedule (#12015)

This is an automated email from the ASF dual-hosted git repository.

moreau pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new c30b420f61 [TOPI, x86] Properly handle fused ops in TE softmax schedule   (#12015)
c30b420f61 is described below

commit c30b420f61295fb60530dd01e84f8988605d72a5
Author: masahi <ma...@gmail.com>
AuthorDate: Wed Jul 13 21:08:05 2022 +0900

    [TOPI, x86] Properly handle fused ops in TE softmax schedule   (#12015)
    
    * fix x86 softmax fusion
    
    * properly handle the case where softmax and the fused op have different layouts
    
    * add test
---
 python/tvm/topi/x86/nn.py                     | 50 +++++++++++++++++----------
 tests/python/frontend/pytorch/test_forward.py | 40 +++++++++++++++++++++
 2 files changed, 71 insertions(+), 19 deletions(-)
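
For context, the schedule being fixed parallelizes the axes outside the softmax
"axis" and attaches the remaining stages to the fused elementwise output with
compute_at. The following is a minimal sketch (not part of the commit) of that
pattern; the tensor names and shapes are illustrative only, and it only attaches
the final softmax stage, whereas the real schedule also attaches the max/sum stages:

    import tvm
    from tvm import te, topi

    data = te.placeholder((4, 16, 64), name="data")
    sm = topi.nn.softmax(data, axis=2)  # softmax over the last axis
    # A following elementwise op that the fusion pass would attach to softmax.
    out = te.compute(sm.shape, lambda *i: sm(*i) + 1.0, name="add_one")

    s = te.create_schedule(out.op)
    # Fuse and parallelize the axes before the softmax axis, then attach the
    # softmax stage under that parallel loop.
    outer = [s[out].op.axis[i] for i in range(2)]
    fused = s[out].fuse(*outer)
    s[out].parallel(fused)
    s[sm].compute_at(s[out], fused)
    print(tvm.lower(s, [data, out], simple_mode=True))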

diff --git a/python/tvm/topi/x86/nn.py b/python/tvm/topi/x86/nn.py
index 5475fc772e..5fd1108811 100644
--- a/python/tvm/topi/x86/nn.py
+++ b/python/tvm/topi/x86/nn.py
@@ -18,6 +18,7 @@
 """x86 nn operators"""
 from tvm import te
 from ..utils import traverse_inline
+from .injective import schedule_injective_from_existing
 
 
 def _schedule_softmax(softmax_op, s, outs):
@@ -48,28 +49,39 @@ def _schedule_softmax(softmax_op, s, outs):
             )
         )
 
-    # only parallelize outer dimensions up to axis
-    outer_axes = [s[softmax_op].op.axis[i] for i in range(0, axis)]
-    fused_outer_axes = s[softmax_op].fuse(*outer_axes)
-    s[softmax_op].parallel(fused_outer_axes)
+    output = outs[0]
 
-    # move computations with the same outer dimensions under the same root
-    s[max_elem].compute_at(s[softmax_op], fused_outer_axes)
-    s[expsum].compute_at(s[softmax_op], fused_outer_axes)
+    def _schedule(output_op, softmax_op):
+        # only parallelize outer dimensions up to axis
+        outer_axes = [output_op.axis[i] for i in range(0, axis)]
+        fused_outer_axes = s[output_op].fuse(*outer_axes)
+        s[output_op].parallel(fused_outer_axes)
 
-    if delta is not None:
-        s[exp].compute_inline()
-        s[delta].compute_inline()
-    if exp is not None:
-        s[exp].compute_at(s[softmax_op], fused_outer_axes)
+        if softmax_op != output_op:
+            # fuse softmax output with following elemwise ops.
+            s[softmax_op].compute_at(s[output_op], fused_outer_axes)
 
-    if softmax_op != outs[0].op:
-        # fuse softmax output with following elemwise ops.
-        output = outs[0]
-        outer_axes = [s[output].op.axis[i] for i in range(0, axis)]
-        fused_outer_axes = s[output].fuse(*outer_axes)
-        s[output].parallel(fused_outer_axes)
-        s[softmax_op].compute_at(s[output], fused_outer_axes)
+        # move computations with the same outer dimensions under the same root
+        s[max_elem].compute_at(s[output_op], fused_outer_axes)
+        s[expsum].compute_at(s[output_op], fused_outer_axes)
+
+        if delta is not None:
+            s[exp].compute_inline()
+            s[delta].compute_inline()
+        if exp is not None:
+            s[exp].compute_at(s[output_op], fused_outer_axes)
+
+    if list(output.shape) == list(softmax_op.output(0).shape):
+        _schedule(output.op, softmax_op)
+    else:
+        # This case can happen, for example, if the 4D input to softmax
+        # is in the NCHW layout while the fused elemwise op takes the NCHWc layout.
+        # Since we parallelize over outer axes up to the "axis" parameter of softmax,
+        # softmax and the fused op need to be in the same layout if we want to
+        # fuse them under the same parallel loop.
+        # This case can be removed if softmax supported AlterLayout.
+        schedule_injective_from_existing(s, output)
+        _schedule(softmax_op, softmax_op)
 
 
 def schedule_softmax(outs):
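
The new else-branch above handles the case where the softmax output and the
final fused output do not have the same shape. A hypothetical pair of shapes
(for illustration only) shows why the simple shape comparison catches the
NCHW-vs-NCHWc situation:

    # Hypothetical shapes: softmax stays in NCHW while the fused elementwise
    # op was altered to NCHW8c (16 channels split into 2 x 8).
    softmax_shape = [1, 16, 32, 32]      # NCHW
    fused_out_shape = [1, 2, 32, 32, 8]  # NCHW8c
    # list(output.shape) == list(softmax_op.output(0).shape) is False here, so
    # the schedule falls back to schedule_injective_from_existing for the output
    # and gives softmax its own parallel loop.
    assert softmax_shape != fused_out_shape
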
diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py
index 30ba713396..cd7c50d486 100644
--- a/tests/python/frontend/pytorch/test_forward.py
+++ b/tests/python/frontend/pytorch/test_forward.py
@@ -4544,5 +4544,45 @@ def test_mod():
         verify_model(test_fn, [torch.tensor([1, 2, 3, 4, 5]), torch.tensor(-1.5)])
 
 
+def test_softmax_fuse():
+    # https://github.com/apache/tvm/issues/12001
+    class Model(torch.nn.Module):
+        def __init__(self, nchwc_post_op=False) -> None:
+            super().__init__()
+            self.conv = torch.nn.Conv2d(3, 3, (1, 1), 1)
+            self.nchwc_post_op = nchwc_post_op
+
+        @torch.no_grad()
+        def forward(self, x):
+            t0a = self.conv(x)
+            t0b = torch.floor(x)
+            t2b = torch.softmax(t0a, dim=2)
+
+            if self.nchwc_post_op:
+                t3a = t0a - t0b
+                t4a = t2b - t0b
+                t6a = t3a + t4a
+                return t6a
+
+            return t2b + 1
+
+    sh = [3, 3, 10, 1]
+    inp = torch.ones(*sh, dtype=torch.float32)
+
+    for model in [Model(nchwc_post_op=False).eval(), Model(nchwc_post_op=True).eval()]:
+        output_torch = model(inp).numpy()
+
+        mod, params = relay.frontend.from_pytorch(torch.jit.trace(model, inp), [("inp0", sh)])
+
+        with tvm.transform.PassContext(opt_level=4):
+            out = (
+                relay.create_executor("graph", mod, params=params)
+                .evaluate()(inp0=inp.numpy())
+                .numpy()
+            )
+
+        tvm.testing.assert_allclose(out, output_torch, rtol=1e-5, atol=1e-5)
+
+
 if __name__ == "__main__":
     pytest.main([__file__])
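
To run only the new test locally, one option (assuming a PyTorch-enabled TVM
build and pytest, with the path relative to the repository root) is:

    import pytest
    pytest.main(["tests/python/frontend/pytorch/test_forward.py", "-k", "test_softmax_fuse"])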