Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2021/11/09 08:10:40 UTC

[GitHub] [tvm] Lyken17 commented on pull request #9465: [Conv2DTransposed] Fix wrong shape check and add new TOPI module to support groups

Lyken17 commented on pull request #9465:
URL: https://github.com/apache/tvm/pull/9465#issuecomment-963906327


   Hi @vin13,
   
   This is a bit complex, so let me explain the situation. The bug was initially found when I tried to calculate the gradients of `nn.conv2d` with `groups`:
   
   ```python
   import tvm
   from tvm import relay
   # parse_module is assumed to be a local helper, e.g. wrapping tvm.parser.fromtext
   
   program = """
   def @main(%input0: Tensor[(1, 32, 224, 224), float32], 
           %v0_0_weight: Tensor[(32, 1, 3, 3), float32], 
           %v1_conv_0_0_weight: Tensor[(32, 1, 3, 3), float32], 
           %v1_conv_1_weight: Tensor[(16, 32, 1, 1), float32]) -> Tensor[(1, 32, 224, 224), float32] {
     %0 = nn.conv2d(%input0, %v0_0_weight, strides=[1, 1], padding=[1, 1, 1, 1], groups=32, channels=32, kernel_size=[3, 3]);
     %0
   }
   """
   
   mod = parse_module(program)
   
   mod = relay.transform.InferType()(mod)
   bwd_ir = relay.transform.gradient(mod['main'], mode="first_order")
   bwd_mod = tvm.IRModule.from_expr(bwd_ir)
   
   """
   fn (%input0: Tensor[(1, 32, 224, 224), float32], %v0_0_weight: Tensor[(32, 1, 3, 3), float32], %v1_conv_0_0_weight: Tensor[(32, 1, 3, 3), float32], %v1_conv_1_weight: Tensor[(16, 32, 1, 1), float32]) -> (Tensor[(1, 32, 224, 224), float32], (Tensor[(1, 32, 224, 224), float32], Tensor[(32, 1, 3, 3), float32], Tensor[(32, 1, 3, 3), float32], Tensor[(16, 32, 1, 1), float32])) {
     let %x = %input0;
     let %x1 = zeros_like(%x);
     let %x2 = %v0_0_weight;
     let %x3 = zeros_like(%x2);
     let %x4 = %v1_conv_0_0_weight;
     let %x5 = zeros_like(%x4);
     let %x6 = %v1_conv_1_weight;
     let %x7 = zeros_like(%x6);
     let %x8 = nn.conv2d(%x, %x2, padding=[1, 1, 1, 1], groups=32, channels=32, kernel_size=[3, 3]) /* ty=Tensor[(1, 32, 224, 224), float32] */;
     let %x9 = zeros_like(%x8);
     %0 = ones_like(%x8);
     %1 = nn.conv2d_transpose(%0, %x2, padding=[1, 1, 1, 1], groups=32);
     %9 = (
       let %x10 = add(%x1, %1);
       %2 = tile(%0, reps=[1, 1, 1, 1]);
       %3 = reshape(%x, newshape=[1, -1, 0, 0]);
       %4 = reshape(%2, newshape=[-1, 1, 0, 0]);
       %5 = nn.conv2d(%3, %4, padding=[1, 1, 1, 1], groups=32);
       %6 = reshape(%5, newshape=[1, 1, 32, 3, 3]);
       %7 = sum(%6, axis=[0]);
       %8 = transpose(%7, axes=[1, 0, 2, 3]);
       let %x11 = add(%x3, %8);
       (%x10, %x11, %x5, %x7)
     );
     (%x8, %9)
   }
   """
   ```
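   For context on why a transposed convolution shows up in the input gradient at all: the forward pass is a cross-correlation, and its gradient scatters each output-gradient value back over the window it read, which is exactly a transposed cross-correlation. A minimal 1-D NumPy check against a numerical gradient (the helper names here are mine for illustration, not TVM or PyTorch APIs):

```python
import numpy as np

def corr1d(x, w):
    # valid cross-correlation: y[i] = sum_k x[i + k] * w[k]
    n = len(x) - len(w) + 1
    return np.array([np.dot(x[i:i + len(w)], w) for i in range(n)])

def corr1d_transpose(dy, w):
    # input gradient of corr1d: scatter each dy[i] back across its window
    dx = np.zeros(len(dy) + len(w) - 1)
    for i, g in enumerate(dy):
        dx[i:i + len(w)] += g * w
    return dx

x = np.array([1., 2., 3., 4., 5.])
w = np.array([1., -1., 2.])
y = corr1d(x, w)

# numerical gradient of sum(y) w.r.t. each x[j]
eps = 1e-6
num = np.array([(corr1d(x + eps * np.eye(len(x))[j], w).sum() - y.sum()) / eps
                for j in range(len(x))])
ana = corr1d_transpose(np.ones_like(y), w)
print(np.allclose(num, ana, atol=1e-4))  # True
```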
   
   As the output shows, the input gradient of the grouped `nn.conv2d` is computed by `%1 = nn.conv2d_transpose(%0, %x2, padding=[1, 1, 1, 1], groups=32);`, reusing the forward weight `%x2`. The shape of `%x2` is (32, 1, 3, 3), which is `OIHW` for Conv2d but `IOHW` for Conv2dTransposed. The `IOHW` convention is consistent with PyTorch [[1](https://pytorch.org/docs/stable/generated/torch.nn.functional.conv2d.html),[2](https://pytorch.org/docs/stable/generated/torch.nn.functional.conv_transpose2d.html#torch.nn.functional.conv_transpose2d)] and PaddlePaddle[[1](https://www.paddlepaddle.org.cn/documentation/docs/en/1.8/api/dygraph/Conv2DTranspose.html)]. If the original intent of `topi.nn.conv2d_transpose` was to use `OIHW` rather than `IOHW`, then the primal gradient registration is wrong.
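   To make the layout ambiguity concrete, here is a small plain-Python sketch (hypothetical helper names, not TVM APIs) of how the same 4-D weight shape is read under each convention. Note that for the depthwise case above (`groups=32`) the two readings happen to coincide, which is part of why the mismatch is easy to miss; with `groups=2` they diverge:

```python
def read_oihw(shape, groups):
    # OIHW: (out_channels, in_channels // groups, KH, KW)
    o, i_per_g, kh, kw = shape
    return {"out_channels": o, "in_channels": i_per_g * groups}

def read_iohw(shape, groups):
    # IOHW: (in_channels, out_channels // groups, KH, KW)
    i, o_per_g, kh, kw = shape
    return {"in_channels": i, "out_channels": o_per_g * groups}

w = (32, 1, 3, 3)  # the weight %x2 from the IR above
print(read_oihw(w, 32))  # {'out_channels': 32, 'in_channels': 32}
print(read_iohw(w, 32))  # {'in_channels': 32, 'out_channels': 32}
print(read_oihw(w, 2))   # {'out_channels': 32, 'in_channels': 2}
print(read_iohw(w, 2))   # {'in_channels': 32, 'out_channels': 2}
```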
   
   From my personal perspective, I prefer `IOHW`, since `conv2d_transpose` should take a transposed weight, and this would make TVM consistent with PyTorch. Is there any specific reason (e.g., performance tuning) that makes `OIHW` preferable to `IOHW`?
   
   
   In either case, `groups` support for `conv2d_transpose` is missing and should be added.
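   As a sketch of what that `groups` support computes (plain NumPy, `IOHW` weight layout, stride 1, no padding; the function names are mine, not TOPI's): split the input channels and the weight into `groups` slices, run an ordinary transposed convolution per group, and concatenate the per-group outputs along the channel axis.

```python
import numpy as np

def conv2d_transpose_1group(x, w):
    # x: (N, Cin, H, W); w: (Cin, Cout, KH, KW) in IOHW layout
    n, cin, h, wd = x.shape
    cin_w, cout, kh, kw = w.shape
    assert cin == cin_w
    out = np.zeros((n, cout, h + kh - 1, wd + kw - 1))
    for b in range(n):
        for ci in range(cin):
            for co in range(cout):
                for i in range(h):
                    for j in range(wd):
                        # scatter each input value across its kernel window
                        out[b, co, i:i + kh, j:j + kw] += x[b, ci, i, j] * w[ci, co]
    return out

def conv2d_transpose_grouped(x, w, groups):
    # split channels and weights per group, then concatenate group outputs
    xs = np.split(x, groups, axis=1)
    ws = np.split(w, groups, axis=0)  # IOHW: first dim is input channels
    return np.concatenate([conv2d_transpose_1group(xg, wg)
                           for xg, wg in zip(xs, ws)], axis=1)

x = np.random.rand(1, 4, 5, 5)
w = np.random.rand(4, 2, 3, 3)  # 4 input channels, 2 output channels per group
y = conv2d_transpose_grouped(x, w, groups=2)
print(y.shape)  # (1, 4, 7, 7)
```

   With `groups=2` the output has `2 * 2 = 4` channels, i.e. `out_channels_per_group * groups`, matching the `IOHW` reading of the weight shape.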
   


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org