Posted to commits@tvm.apache.org by mo...@apache.org on 2022/07/13 12:08:11 UTC
[tvm] branch main updated: [TOPI, x86] Properly handle fused ops in TE softmax schedule (#12015)
This is an automated email from the ASF dual-hosted git repository.
moreau pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new c30b420f61 [TOPI, x86] Properly handle fused ops in TE softmax schedule (#12015)
c30b420f61 is described below
commit c30b420f61295fb60530dd01e84f8988605d72a5
Author: masahi <ma...@gmail.com>
AuthorDate: Wed Jul 13 21:08:05 2022 +0900
[TOPI, x86] Properly handle fused ops in TE softmax schedule (#12015)
* fix x86 softmax fusion
* properly handle the case where softmax and the fused op have different layouts
* add test
---
python/tvm/topi/x86/nn.py | 50 +++++++++++++++++----------
tests/python/frontend/pytorch/test_forward.py | 40 +++++++++++++++++++++
2 files changed, 71 insertions(+), 19 deletions(-)
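
For context, here is a minimal sketch of the fused case this schedule now handles, assuming a recent TVM build; the shapes and the trailing elemwise add are invented for illustration. An elemwise op follows softmax, so outs[0].op is the fused op rather than the softmax op itself, which exercises the new fusion path:

    import tvm
    from tvm import te, topi

    # softmax followed by a broadcast add: outs[0].op is the add op,
    # not the softmax op, so _schedule_softmax must fuse the two stages.
    x = te.placeholder((4, 1000), name="x")
    sm = topi.nn.softmax(x, axis=1)
    out = topi.add(sm, 1.0)

    with tvm.target.Target("llvm"):
        s = topi.x86.schedule_softmax([out])

    # Both stages end up under one fused, parallelized outer loop.
    print(tvm.lower(s, [x, out], simple_mode=True))
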
diff --git a/python/tvm/topi/x86/nn.py b/python/tvm/topi/x86/nn.py
index 5475fc772e..5fd1108811 100644
--- a/python/tvm/topi/x86/nn.py
+++ b/python/tvm/topi/x86/nn.py
@@ -18,6 +18,7 @@
"""x86 nn operators"""
from tvm import te
from ..utils import traverse_inline
+from .injective import schedule_injective_from_existing
def _schedule_softmax(softmax_op, s, outs):
@@ -48,28 +49,39 @@ def _schedule_softmax(softmax_op, s, outs):
)
)
- # only parallelize outer dimensions up to axis
- outer_axes = [s[softmax_op].op.axis[i] for i in range(0, axis)]
- fused_outer_axes = s[softmax_op].fuse(*outer_axes)
- s[softmax_op].parallel(fused_outer_axes)
+ output = outs[0]
- # move computations with the same outer dimensions under the same root
- s[max_elem].compute_at(s[softmax_op], fused_outer_axes)
- s[expsum].compute_at(s[softmax_op], fused_outer_axes)
+ def _schedule(output_op, softmax_op):
+ # only parallelize outer dimensions up to axis
+ outer_axes = [output_op.axis[i] for i in range(0, axis)]
+ fused_outer_axes = s[output_op].fuse(*outer_axes)
+ s[output_op].parallel(fused_outer_axes)
- if delta is not None:
- s[exp].compute_inline()
- s[delta].compute_inline()
- if exp is not None:
- s[exp].compute_at(s[softmax_op], fused_outer_axes)
+ if softmax_op != output_op:
+ # fuse softmax output with following elemwise ops.
+ s[softmax_op].compute_at(s[output_op], fused_outer_axes)
- if softmax_op != outs[0].op:
- # fuse softmax output with following elemwise ops.
- output = outs[0]
- outer_axes = [s[output].op.axis[i] for i in range(0, axis)]
- fused_outer_axes = s[output].fuse(*outer_axes)
- s[output].parallel(fused_outer_axes)
- s[softmax_op].compute_at(s[output], fused_outer_axes)
+ # move computations with the same outer dimensions under the same root
+ s[max_elem].compute_at(s[output_op], fused_outer_axes)
+ s[expsum].compute_at(s[output_op], fused_outer_axes)
+
+ if delta is not None:
+ s[exp].compute_inline()
+ s[delta].compute_inline()
+ if exp is not None:
+ s[exp].compute_at(s[output_op], fused_outer_axes)
+
+ if list(output.shape) == list(softmax_op.output(0).shape):
+ _schedule(output.op, softmax_op)
+ else:
+ # This case can happen, for example, if the 4D input to softmax
+ # is in the NCHW layout while the fused elemwise op takes the NCHWc layout.
+ # Since we parallelize over outer axes up to the "axis" parameter of softmax,
+ # softmax and the fused op need to be in the same layout if we want to
+ # fuse them under the same parallel loop.
+ # This case can be removed if softmax supported AlterLayout.
+ schedule_injective_from_existing(s, output)
+ _schedule(softmax_op, softmax_op)
def schedule_softmax(outs):
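
For intuition on the shape check in the hunk above: at higher opt levels, AlterOpLayout can rewrite a conv-fed elemwise op into NCHWc while softmax stays in NCHW, so the final output and the softmax output no longer share a loop nest. A hypothetical sketch of the dispatch, with shapes invented for illustration:

    # Hypothetical shapes: softmax stays NCHW while the fused elemwise op
    # was rewritten to NCHW4c by AlterOpLayout (C=8 split into 2 x 4c).
    softmax_shape = [1, 8, 32, 32]
    fused_out_shape = [1, 2, 32, 32, 4]

    if list(fused_out_shape) == list(softmax_shape):
        print("same layout: fuse softmax under the output's parallel loop")
    else:
        print("layout mismatch: schedule the output injectively and"
              " give softmax its own parallel root")
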
diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py
index 30ba713396..cd7c50d486 100644
--- a/tests/python/frontend/pytorch/test_forward.py
+++ b/tests/python/frontend/pytorch/test_forward.py
@@ -4544,5 +4544,45 @@ def test_mod():
verify_model(test_fn, [torch.tensor([1, 2, 3, 4, 5]), torch.tensor(-1.5)])
+def test_softmax_fuse():
+ # https://github.com/apache/tvm/issues/12001
+ class Model(torch.nn.Module):
+ def __init__(self, nchwc_post_op=False) -> None:
+ super().__init__()
+ self.conv = torch.nn.Conv2d(3, 3, (1, 1), 1)
+ self.nchwc_post_op = nchwc_post_op
+
+ @torch.no_grad()
+ def forward(self, x):
+ t0a = self.conv(x)
+ t0b = torch.floor(x)
+ t2b = torch.softmax(t0a, dim=2)
+
+ if self.nchwc_post_op:
+ t3a = t0a - t0b
+ t4a = t2b - t0b
+ t6a = t3a + t4a
+ return t6a
+
+ return t2b + 1
+
+ sh = [3, 3, 10, 1]
+ inp = torch.ones(*sh, dtype=torch.float32)
+
+ for model in [Model(nchwc_post_op=False).eval(), Model(nchwc_post_op=True).eval()]:
+ output_torch = model(inp).numpy()
+
+ mod, params = relay.frontend.from_pytorch(torch.jit.trace(model, inp), [("inp0", sh)])
+
+ with tvm.transform.PassContext(opt_level=4):
+ out = (
+ relay.create_executor("graph", mod, params=params)
+ .evaluate()(inp0=inp.numpy())
+ .numpy()
+ )
+
+ tvm.testing.assert_allclose(out, output_torch, rtol=1e-5, atol=1e-5)
+
+
if __name__ == "__main__":
pytest.main([__file__])
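
To reproduce locally (assuming a TVM development checkout with PyTorch installed), the new test can be selected on its own by keyword:

    # Run only the new softmax fusion test from the PyTorch frontend suite.
    import pytest
    pytest.main(["tests/python/frontend/pytorch/test_forward.py", "-k", "test_softmax_fuse"])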