Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2021/03/16 20:39:28 UTC

[GitHub] [tvm] comaniac opened a new pull request #7675: [Torch] Remove unnecessary reshapes for batch_matmul

comaniac opened a new pull request #7675:
URL: https://github.com/apache/tvm/pull/7675


   This PR removes unnecessary reshape ops in the PyTorch frontend when converting to batch_matmul. This should help the performance of NLP models such as BERT.
   
   cc @siju-samuel @masahi 
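
   As background, a NumPy sketch (not the frontend code itself) of why these reshapes are removable: when an operand is already a rank-3 tensor of shape `(b, m, n)`, a `reshape` to `newshape=[-1, m, n]` is the identity, so the op contributes nothing to the graph.

   ```python
   import numpy as np

   # A rank-3 operand as fed to batch_matmul: (batch, m, n)
   a = np.random.rand(10, 3, 4)

   # reshape(..., newshape=[-1, 3, 4]) leaves the shape (10, 3, 4)
   # unchanged, so the frontend can simply skip emitting it.
   reshaped = a.reshape(-1, 3, 4)
   assert reshaped.shape == a.shape
   assert np.array_equal(reshaped, a)
   ```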
   


----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] comaniac commented on pull request #7675: [Torch] Remove unnecessary reshapes for batch_matmul

Posted by GitBox <gi...@apache.org>.
comaniac commented on pull request #7675:
URL: https://github.com/apache/tvm/pull/7675#issuecomment-800626924


   Pushed a new commit that also reorders the `reshape_b` and `transpose` so that SimplifyExpr can be applied.
   
   Before this PR:
   
   ```
   fn (%input0: Tensor[(10, 3, 4), float32], %input1: Tensor[(10, 4, 5), float32]) -> Tensor[(10, 3, 5), float32] {
     %0 = reshape(%input0, newshape=[-1, 3, 4]) /* ty=Tensor[(10, 3, 4), float32] */;
     %1 = reshape(%input1, newshape=[-1, 4, 5]) /* ty=Tensor[(10, 4, 5), float32] */;
     %2 = transpose(%1, axes=[0, 2, 1]) /* ty=Tensor[(10, 5, 4), float32] */;
     %3 = nn.batch_matmul(%0, %2, meta[relay.attrs.BatchMatmulAttrs][0]) /* ty=Tensor[(10, 3, 5), float32] */;
     reshape(%3, newshape=[10, 3, 5]) /* ty=Tensor[(10, 3, 5), float32] */
   }
   
   fn (%input0: Tensor[(10, 3, 4), float32], %input1: Tensor[(4, 5), float32]) -> Tensor[(10, 3, 5), float32] {
     %0 = reshape(%input0, newshape=[-1, 3, 4]) /* ty=Tensor[(10, 3, 4), float32] */;
     %1 = reshape(%input1, newshape=[-1, 4, 5]) /* ty=Tensor[(1, 4, 5), float32] */;
     %2 = transpose(%1, axes=[0, 2, 1]) /* ty=Tensor[(1, 5, 4), float32] */;
     %3 = nn.batch_matmul(%0, %2, meta[relay.attrs.BatchMatmulAttrs][0]) /* ty=Tensor[(10, 3, 5), float32] */;
     reshape(%3, newshape=[10, 3, 5]) /* ty=Tensor[(10, 3, 5), float32] */
   }
   
   fn (%input0: Tensor[(1, 12, 14, 64), float32], %input1: Tensor[(1, 12, 64, 14), float32]) -> Tensor[(1, 12, 14, 14), float32] {
     %0 = reshape(%input0, newshape=[-1, 14, 64]) /* ty=Tensor[(12, 14, 64), float32] */;
     %1 = reshape(%input1, newshape=[-1, 64, 14]) /* ty=Tensor[(12, 64, 14), float32] */;
     %2 = transpose(%1, axes=[0, 2, 1]) /* ty=Tensor[(12, 14, 64), float32] */;
     %3 = nn.batch_matmul(%0, %2, meta[relay.attrs.BatchMatmulAttrs][0]) /* ty=Tensor[(12, 14, 14), float32] */;
     reshape(%3, newshape=[1, 12, 14, 14]) /* ty=Tensor[(1, 12, 14, 14), float32] */
   }
   ```
   
   After this PR:
   
   ```
   fn (%input0: Tensor[(10, 3, 4), float32], %input1: Tensor[(10, 4, 5), float32]) -> Tensor[(10, 3, 5), float32] {
     %0 = transpose(%input1, axes=[0, 2, 1]) /* ty=Tensor[(10, 5, 4), float32] */;
     nn.batch_matmul(%input0, %0, meta[relay.attrs.BatchMatmulAttrs][0]) /* ty=Tensor[(10, 3, 5), float32] */
   }
   
   fn (%input0: Tensor[(10, 3, 4), float32], %input1: Tensor[(4, 5), float32]) -> Tensor[(10, 3, 5), float32] {
     %0 = transpose(%input1, axes=[1, 0]) /* ty=Tensor[(5, 4), float32] */;
     %1 = reshape(%0, newshape=[-1, 5, 4]) /* ty=Tensor[(1, 5, 4), float32] */;
     nn.batch_matmul(%input0, %1, meta[relay.attrs.BatchMatmulAttrs][0]) /* ty=Tensor[(10, 3, 5), float32] */
   }
   
   fn (%input0: Tensor[(1, 12, 14, 64), float32], %input1: Tensor[(1, 12, 64, 14), float32]) -> Tensor[(1, 12, 14, 14), float32] {
     %0 = reshape(%input0, newshape=[-1, 14, 64]) /* ty=Tensor[(12, 14, 64), float32] */;
     %1 = transpose(%input1, axes=[0, 1, 3, 2]) /* ty=Tensor[(1, 12, 14, 64), float32] */;
     %2 = reshape(%1, newshape=[-1, 14, 64]) /* ty=Tensor[(12, 14, 64), float32] */;
     %3 = nn.batch_matmul(%0, %2, meta[relay.attrs.BatchMatmulAttrs][0]) /* ty=Tensor[(12, 14, 14), float32] */;
     reshape(%3, newshape=[1, 12, 14, 14]) /* ty=Tensor[(1, 12, 14, 14), float32] */
   }
   ```
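
   To see that the before/after graphs agree, recall that Relay's `nn.batch_matmul(A, B)` multiplies `A` by `B` transposed on its last two axes (the transpose_b convention), so `batch_matmul(A, transpose(B))` is plain `A @ B`. A NumPy sketch of the first case above:

   ```python
   import numpy as np

   def batch_matmul(x, y):
       # Relay's nn.batch_matmul transposes the last two axes of y:
       # out[b] = x[b] @ y[b].T
       return np.matmul(x, np.swapaxes(y, -1, -2))

   a = np.random.rand(10, 3, 4)   # %input0
   b = np.random.rand(10, 4, 5)   # %input1

   # After this PR: a single transpose feeding batch_matmul, no reshapes.
   out = batch_matmul(a, np.transpose(b, (0, 2, 1)))

   # The two transposes cancel, so the result is just a @ b.
   assert out.shape == (10, 3, 5)
   assert np.allclose(out, np.matmul(a, b))
   ```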
   
   In particular, since the weights in most PyTorch models have to be transposed when converting to Relay, the second case could become:
   
   ```
   fn (%input0: Tensor[(10, 3, 4), float32], %input1: Tensor[(5, 4), float32]) -> Tensor[(10, 3, 5), float32] {
     %0 = transpose(%input1, axes=[1, 0]) /* ty=Tensor[(4, 5), float32] */; <- Not added by matmul
     %1 = transpose(%0, axes=[1, 0]) /* ty=Tensor[(5, 4), float32] */; <- Added by matmul
     %2 = reshape(%1, newshape=[-1, 5, 4]) /* ty=Tensor[(1, 5, 4), float32] */;
     nn.batch_matmul(%input0, %2, meta[relay.attrs.BatchMatmulAttrs][0]) /* ty=Tensor[(10, 3, 5), float32] */
   }
   ```
   
   By applying SimplifyExpr to cancel the unnecessary `transpose` ops, we end up with:
   
   ```
   fn (%input0: Tensor[(10, 3, 4), float32], %input1: Tensor[(5, 4), float32]) -> Tensor[(10, 3, 5), float32] {
     %0 = reshape(%input1, newshape=[-1, 5, 4]) /* ty=Tensor[(1, 5, 4), float32] */;
     nn.batch_matmul(%input0, %0, meta[relay.attrs.BatchMatmulAttrs][0]) /* ty=Tensor[(10, 3, 5), float32] */
   }
   ```
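
   The cancellation itself is just the fact that two back-to-back transposes compose to the identity; a NumPy sketch of the `%0`/`%1` pair from the graph before simplification:

   ```python
   import numpy as np

   w = np.random.rand(5, 4)  # %input1

   # Before SimplifyExpr: transpose, transpose back, then reshape.
   before = w.T.T.reshape(-1, 5, 4)

   # After SimplifyExpr folds the transposes: only the reshape remains.
   after = w.reshape(-1, 5, 4)

   assert before.shape == (1, 5, 4)
   assert np.array_equal(before, after)
   ```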





[GitHub] [tvm] masahi merged pull request #7675: [Torch] Remove unnecessary reshapes for batch_matmul

Posted by GitBox <gi...@apache.org>.
masahi merged pull request #7675:
URL: https://github.com/apache/tvm/pull/7675


   





[GitHub] [tvm] comaniac commented on a change in pull request #7675: [Torch] Remove unnecessary reshapes for batch_matmul

Posted by GitBox <gi...@apache.org>.
comaniac commented on a change in pull request #7675:
URL: https://github.com/apache/tvm/pull/7675#discussion_r595554053



##########
File path: tests/python/frontend/pytorch/test_forward.py
##########
@@ -201,6 +201,7 @@ def verify_model(model_name, input_data=[], custom_convert_map={}, rtol=1e-5, at
     input_names = ["input{}".format(idx) for idx, inp in enumerate(baseline_input)]
     input_shapes = list(zip(input_names, [inp.shape for inp in baseline_input]))
     mod, params = relay.frontend.from_pytorch(trace, input_shapes, custom_convert_map)
+    print(relay.transform.InferType()(mod)["main"])

Review comment:
       Yeah just found that. Thanks :)







[GitHub] [tvm] masahi commented on a change in pull request #7675: [Torch] Remove unnecessary reshapes for batch_matmul

Posted by GitBox <gi...@apache.org>.
masahi commented on a change in pull request #7675:
URL: https://github.com/apache/tvm/pull/7675#discussion_r595552313



##########
File path: tests/python/frontend/pytorch/test_forward.py
##########
@@ -201,6 +201,7 @@ def verify_model(model_name, input_data=[], custom_convert_map={}, rtol=1e-5, at
     input_names = ["input{}".format(idx) for idx, inp in enumerate(baseline_input)]
     input_shapes = list(zip(input_names, [inp.shape for inp in baseline_input]))
     mod, params = relay.frontend.from_pytorch(trace, input_shapes, custom_convert_map)
+    print(relay.transform.InferType()(mod)["main"])

Review comment:
       remove this?







[GitHub] [tvm] masahi commented on pull request #7675: [Torch] Remove unnecessary reshapes for batch_matmul

Posted by GitBox <gi...@apache.org>.
masahi commented on pull request #7675:
URL: https://github.com/apache/tvm/pull/7675#issuecomment-800782361


   Thanks @comaniac 

