Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2021/03/16 21:41:44 UTC

[GitHub] [tvm] comaniac commented on pull request #7675: [Torch] Remove unnecessary reshapes for batch_matmul

comaniac commented on pull request #7675:
URL: https://github.com/apache/tvm/pull/7675#issuecomment-800626924


   Pushed a new commit to also reorder `reshape_b` and the `transpose` so that the SimplifyExpr pass can be applied.
   
   Before this PR:
   
   ```
   fn (%input0: Tensor[(10, 3, 4), float32], %input1: Tensor[(10, 4, 5), float32]) -> Tensor[(10, 3, 5), float32] {
     %0 = reshape(%input0, newshape=[-1, 3, 4]) /* ty=Tensor[(10, 3, 4), float32] */;
     %1 = reshape(%input1, newshape=[-1, 4, 5]) /* ty=Tensor[(10, 4, 5), float32] */;
     %2 = transpose(%1, axes=[0, 2, 1]) /* ty=Tensor[(10, 5, 4), float32] */;
     %3 = nn.batch_matmul(%0, %2, meta[relay.attrs.BatchMatmulAttrs][0]) /* ty=Tensor[(10, 3, 5), float32] */;
     reshape(%3, newshape=[10, 3, 5]) /* ty=Tensor[(10, 3, 5), float32] */
   }
   
   fn (%input0: Tensor[(10, 3, 4), float32], %input1: Tensor[(4, 5), float32]) -> Tensor[(10, 3, 5), float32] {
     %0 = reshape(%input0, newshape=[-1, 3, 4]) /* ty=Tensor[(10, 3, 4), float32] */;
     %1 = reshape(%input1, newshape=[-1, 4, 5]) /* ty=Tensor[(1, 4, 5), float32] */;
     %2 = transpose(%1, axes=[0, 2, 1]) /* ty=Tensor[(1, 5, 4), float32] */;
     %3 = nn.batch_matmul(%0, %2, meta[relay.attrs.BatchMatmulAttrs][0]) /* ty=Tensor[(10, 3, 5), float32] */;
     reshape(%3, newshape=[10, 3, 5]) /* ty=Tensor[(10, 3, 5), float32] */
   }
   
   fn (%input0: Tensor[(1, 12, 14, 64), float32], %input1: Tensor[(1, 12, 64, 14), float32]) -> Tensor[(1, 12, 14, 14), float32] {
     %0 = reshape(%input0, newshape=[-1, 14, 64]) /* ty=Tensor[(12, 14, 64), float32] */;
     %1 = reshape(%input1, newshape=[-1, 64, 14]) /* ty=Tensor[(12, 64, 14), float32] */;
     %2 = transpose(%1, axes=[0, 2, 1]) /* ty=Tensor[(12, 14, 64), float32] */;
     %3 = nn.batch_matmul(%0, %2, meta[relay.attrs.BatchMatmulAttrs][0]) /* ty=Tensor[(12, 14, 14), float32] */;
     reshape(%3, newshape=[1, 12, 14, 14]) /* ty=Tensor[(1, 12, 14, 14), float32] */
   }
   ```
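   The third graph above can be checked against a direct 4-D matmul. A minimal NumPy sketch (my own, not from the PR) modeling `nn.batch_matmul` as `A @ B^T` over the last two axes, with the flatten-to-3-D reshapes reproduced as in the IR:

   ```python
   import numpy as np

   # Model Relay's nn.batch_matmul: out[b, i, j] = sum_k A[b, i, k] * B[b, j, k],
   # i.e. A @ B^T over the last two axes, with batch dims broadcast.
   def batch_matmul(a, b):
       return np.matmul(a, np.swapaxes(b, -1, -2))

   rng = np.random.default_rng(0)
   a = rng.standard_normal((1, 12, 14, 64)).astype("float32")
   b = rng.standard_normal((1, 12, 64, 14)).astype("float32")

   # "Before" graph: flatten to 3-D, transpose B, batch_matmul, reshape back.
   flat = batch_matmul(a.reshape(-1, 14, 64),
                       b.reshape(-1, 64, 14).transpose(0, 2, 1)).reshape(1, 12, 14, 14)

   # Reference: a direct 4-D matmul computes the same thing.
   ref = np.matmul(a, b)
   assert np.allclose(flat, ref, atol=1e-5)
   ```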
   
   After this PR:
   
   ```
   fn (%input0: Tensor[(10, 3, 4), float32], %input1: Tensor[(10, 4, 5), float32]) -> Tensor[(10, 3, 5), float32] {
     %0 = transpose(%input1, axes=[0, 2, 1]) /* ty=Tensor[(10, 5, 4), float32] */;
     nn.batch_matmul(%input0, %0, meta[relay.attrs.BatchMatmulAttrs][0]) /* ty=Tensor[(10, 3, 5), float32] */
   }
   
   fn (%input0: Tensor[(10, 3, 4), float32], %input1: Tensor[(4, 5), float32]) -> Tensor[(10, 3, 5), float32] {
     %0 = transpose(%input1, axes=[1, 0]) /* ty=Tensor[(5, 4), float32] */;
     %1 = reshape(%0, newshape=[-1, 5, 4]) /* ty=Tensor[(1, 5, 4), float32] */;
     nn.batch_matmul(%input0, %1, meta[relay.attrs.BatchMatmulAttrs][0]) /* ty=Tensor[(10, 3, 5), float32] */
   }
   
   fn (%input0: Tensor[(1, 12, 14, 64), float32], %input1: Tensor[(1, 12, 64, 14), float32]) -> Tensor[(1, 12, 14, 14), float32] {
     %0 = reshape(%input0, newshape=[-1, 14, 64]) /* ty=Tensor[(12, 14, 64), float32] */;
     %1 = transpose(%input1, axes=[0, 1, 3, 2]) /* ty=Tensor[(1, 12, 14, 64), float32] */;
     %2 = reshape(%1, newshape=[-1, 14, 64]) /* ty=Tensor[(12, 14, 64), float32] */;
     %3 = nn.batch_matmul(%0, %2, meta[relay.attrs.BatchMatmulAttrs][0]) /* ty=Tensor[(12, 14, 14), float32] */;
     reshape(%3, newshape=[1, 12, 14, 14]) /* ty=Tensor[(1, 12, 14, 14), float32] */
   }
   ```
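   To sanity-check that the rewrite preserves semantics, here is a NumPy sketch (my own, not part of the PR) showing that for the first case the "before" graph with its redundant reshapes and the simplified "after" graph agree:

   ```python
   import numpy as np

   # Model Relay's nn.batch_matmul: A @ B^T over the last two axes,
   # with batch dims broadcast.
   def batch_matmul(a, b):
       return np.matmul(a, np.swapaxes(b, -1, -2))

   rng = np.random.default_rng(0)
   a = rng.standard_normal((10, 3, 4)).astype("float32")
   b = rng.standard_normal((10, 4, 5)).astype("float32")

   # Before this PR: reshape both inputs, transpose B, matmul, reshape back.
   before = batch_matmul(a.reshape(-1, 3, 4),
                         b.reshape(-1, 4, 5).transpose(0, 2, 1)).reshape(10, 3, 5)

   # After this PR: the reshapes are gone; only the transpose of B remains.
   after = batch_matmul(a, b.transpose(0, 2, 1))

   assert np.allclose(before, after)
   ```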
   
   In particular, since the weights in most PyTorch models have to be transposed when converted to Relay, the second case could look like:
   
   ```
   fn (%input0: Tensor[(10, 3, 4), float32], %input1: Tensor[(5, 4), float32]) -> Tensor[(10, 3, 5), float32] {
     %0 = transpose(%input1, axes=[1, 0]) /* ty=Tensor[(4, 5), float32] */; <- Not added by matmul
     %1 = transpose(%0, axes=[1, 0]) /* ty=Tensor[(5, 4), float32] */; <- Added by matmul
     %2 = reshape(%1, newshape=[-1, 5, 4]) /* ty=Tensor[(1, 5, 4), float32] */;
     nn.batch_matmul(%input0, %2, meta[relay.attrs.BatchMatmulAttrs][0]) /* ty=Tensor[(10, 3, 5), float32] */
   }
   ```
   
   By applying SimplifyExpr to cancel the unnecessary `transpose` pair, we get:
   
   ```
   fn (%input0: Tensor[(10, 3, 4), float32], %input1: Tensor[(5, 4), float32]) -> Tensor[(10, 3, 5), float32] {
     %0 = reshape(%input1, newshape=[-1, 5, 4]) /* ty=Tensor[(1, 5, 4), float32] */;
     nn.batch_matmul(%input0, %0, meta[relay.attrs.BatchMatmulAttrs][0]) /* ty=Tensor[(10, 3, 5), float32] */
   }
   ```
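   The cancellation is straightforward to verify numerically. A NumPy sketch (my own, assuming the same `A @ B^T` modeling of `nn.batch_matmul`) of the PyTorch-weight case, with and without the redundant transpose pair:

   ```python
   import numpy as np

   # Model Relay's nn.batch_matmul: A @ B^T over the last two axes,
   # with batch dims broadcast.
   def batch_matmul(a, b):
       return np.matmul(a, np.swapaxes(b, -1, -2))

   rng = np.random.default_rng(0)
   a = rng.standard_normal((10, 3, 4)).astype("float32")
   w = rng.standard_normal((5, 4)).astype("float32")  # PyTorch-style (out, in) weight

   # With the redundant pair: the converter's transpose followed by the
   # transpose inserted for matmul (w.T.T is just w).
   redundant = batch_matmul(a, w.T.T.reshape(-1, 5, 4))

   # After SimplifyExpr cancels the pair, only the reshape remains.
   simplified = batch_matmul(a, w.reshape(-1, 5, 4))

   assert np.allclose(redundant, simplified)
   assert simplified.shape == (10, 3, 5)
   ```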

