Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2021/11/06 18:08:33 UTC

[GitHub] [tvm] Lyken17 opened a new pull request #9465: [Conv2DTransposed] Fix wrong shape check

Lyken17 opened a new pull request #9465:
URL: https://github.com/apache/tvm/pull/9465


   The default shape format in TVM is `N x C x iH x iW` for the input and `O x I x kH x kW` for the weight, so the proper shapes for Conv2dTransposed should be
   
   * input: (batch, in_channels, iH, iW)
   * weight: (out_channels, in_channels // groups, kH, kW)
   
   Thus the original check
   ```
   ICHECK(reporter->AssertEQ(indexdiv(dshape_nchw[1], param->groups), wshape[0]));
   ```  
   is wrong: the proper dimension to compare against is `wshape[1]` rather than `wshape[0]`.
   
   Besides, the debug name is not correct either: all logging messages use `conv2d` rather than `conv2d_transpose`, which is confusing. 
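   
   For a concrete illustration, here is a minimal pure-Python sketch of what the check should enforce under the layouts above (this mirrors the Relay C++ check only conceptually; the shapes are taken from the example below):
   
   ```python
   # Illustrative only -- not the actual Relay C++ code.
   # input  (NCHW): (batch, in_channels, iH, iW)
   # weight       : (out_channels, in_channels // groups, kH, kW)
   dshape_nchw = (1, 32, 224, 224)
   wshape = (32, 1, 3, 3)
   groups = 32
   
   # Original check compares in_channels // groups against wshape[0]:
   print(dshape_nchw[1] // groups == wshape[0])  # False (1 vs 32) -> a valid call is rejected
   
   # Fixed check compares against wshape[1] instead:
   print(dshape_nchw[1] // groups == wshape[1])  # True (1 vs 1)
   ```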
   
   ## Example to trigger error in current implementation
   
   ```python
   import tvm
   from tvm import relay
   
   
   SEMVER = '#[version = "0.0.5"]\n'
   
   def assert_graph_equal(lhs, rhs):
       tvm.ir.assert_structural_equal(lhs, rhs, map_free_vars=True)
   
   def roundtrip(expr):
       x = tvm.parser.fromtext(expr.astext())
       assert_graph_equal(x, expr)
   
   # Testing Utilities for full modules.
   def parse_module(code):
       mod = tvm.parser.parse(SEMVER + code)
       roundtrip(mod)
       return mod
   
   
   program = """
   def @main(%input0: Tensor[(1, 32, 224, 224), float32], 
           %v0_0_weight: Tensor[(32, 1, 3, 3), float32]) -> Tensor[(1, 32, 224, 224), float32] {
     /* test comment */
     %0 = nn.conv2d_transpose(%input0, %v0_0_weight, strides=[1, 1], padding=[1, 1, 1, 1], groups=32, channels=32, kernel_size=[3, 3]);
     %0
   }
   """
   
   mod = parse_module(program)
   print(mod)
   
   target = "llvm"
   lib = relay.build(mod, target=target, params=None)
   print("build [fwd] pass successful")
   
   ```
   
   
   





[GitHub] [tvm] Lyken17 commented on pull request #9465: [Conv2DTransposed] Fix wrong shape check and add new TOPI module to support groups

Posted by GitBox <gi...@apache.org>.
Lyken17 commented on pull request #9465:
URL: https://github.com/apache/tvm/pull/9465#issuecomment-966434464


   @vinx13 I've addressed the issue you mentioned and linted the file. 
   
   The code should be ready to merge; feel free to point out anything that still needs to be updated.
   
   For test cases in https://github.com/apache/tvm/blob/main/tests/python/relay/test_op_level2.py, I am afraid the tests cannot be added yet. The shape check in Relay is still somewhat buggy; it may be related to https://github.com/apache/tvm/pull/9336 and require that fix first.
   
   Currently:
   
   ```python
   from tvm import te, relay
   from tvm.relay.testing import run_infer_type  # helper from the relay test utilities
   
   n, h, w, c = te.size_var("n"), 10, 10, 12
   g = 4
   x = relay.var("x", relay.TensorType((n, h, w, c), "float32"))
   w = relay.var("w", relay.TensorType((12, 16 // g, 5, 5), "float32"))
   y = relay.nn.conv2d_transpose(x, w, output_padding=(1, 1), channels=16, groups=g, data_layout="NHWC")
   yy = run_infer_type(y)
   assert yy.checked_type == relay.TensorType((n, 15, 15, 16), "float32"), yy.checked_type
   # FAILED tests/python/relay/test_op_level2.py::test_conv2d_transpose_infer_type - AssertionError: TensorType([{n|n>=0}, 15, 15, 4], float32)
   ```
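   
   For reference, a short sketch of the shape arithmetic behind the expected `(n, 15, 15, 16)` type in the assertion (standard transposed-convolution output-size formula; illustrative only, not Relay's type-inference code):
   
   ```python
   # Spatial size: out = (in - 1) * stride - 2 * padding + kernel + output_padding
   h_in, stride, padding, kernel, output_padding = 10, 1, 0, 5, 1
   print((h_in - 1) * stride - 2 * padding + kernel + output_padding)  # 15
   
   # With an IOHW-style weight (in, out // groups, kH, kW) of shape (12, 4, 5, 5),
   # the output channel count should be wshape[1] * groups = 4 * 4 = 16,
   # whereas the buggy inference above reports only 4 channels.
   wshape, groups = (12, 16 // 4, 5, 5), 4
   print(wshape[1] * groups)  # 16
   ```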
   





[GitHub] [tvm] junrushao1994 commented on pull request #9465: [Conv2DTransposed] Fix wrong shape check and add new TOPI module to support groups

Posted by GitBox <gi...@apache.org>.
junrushao1994 commented on pull request #9465:
URL: https://github.com/apache/tvm/pull/9465#issuecomment-966729192


   Thank you all for the effort! @Lyken17 let's do the last mile: fix the merge conflict and get the CI green :-)





[GitHub] [tvm] Lyken17 commented on a change in pull request #9465: [Conv2DTransposed] Fix wrong shape check and add new TOPI module to support groups

Posted by GitBox <gi...@apache.org>.
Lyken17 commented on a change in pull request #9465:
URL: https://github.com/apache/tvm/pull/9465#discussion_r747569408



##########
File path: python/tvm/relay/op/strategy/x86.py
##########
@@ -281,13 +281,25 @@ def conv2d_transpose_strategy_cpu(attrs, inputs, out_type, target):
     groups = attrs.groups
     assert layout == "NCHW", "only support nchw for now"
     assert dilation == (1, 1), "not support dilate now"
-    assert groups == 1, "only support groups == 1 for now"
+    # assert groups == 1, "only support groups == 1 for now"

Review comment:
       Removed

##########
File path: python/tvm/relay/op/strategy/generic.py
##########
@@ -471,13 +475,20 @@ def conv2d_transpose_strategy(attrs, inputs, out_type, target):
     groups = attrs.groups
     assert layout == "NCHW", "only support nchw for now"
     assert dilation == (1, 1), "not support dilate now"
-    assert groups == 1, "only support groups == 1 for now"
+    # assert groups == 1, "only support groups == 1 for now"

Review comment:
       removed







[GitHub] [tvm] Lyken17 commented on pull request #9465: [Conv2DTransposed] Fix wrong shape check and add new TOPI module to support groups

Posted by GitBox <gi...@apache.org>.
Lyken17 commented on pull request #9465:
URL: https://github.com/apache/tvm/pull/9465#issuecomment-966893188


   Lint fixed. Let's wait for the green CI!





[GitHub] [tvm] vinx13 commented on pull request #9465: [Conv2DTransposed] Fix wrong shape check and add new TOPI module to support groups

Posted by GitBox <gi...@apache.org>.
vinx13 commented on pull request #9465:
URL: https://github.com/apache/tvm/pull/9465#issuecomment-964455350


   I agree `IOHW` is preferred for `conv2d_transpose`. This is consistent with the current TOPI and Relay implementations. I think the source of confusion is that Relay assigns a different meaning to the [`OIHW` layout](https://github.com/apache/tvm/blob/main/include/tvm/relay/attrs/nn.h#L581-L586) for `conv2d_transpose`; it is actually lowered to TOPI's [IOHW implementation](https://github.com/apache/tvm/blob/main/python/tvm/topi/nn/conv2d_transpose.py#L80).
   Anyway, I agree that supporting `groups` is a great thing to add.





[GitHub] [tvm] vinx13 edited a comment on pull request #9465: [Conv2DTransposed] Fix wrong shape check and add new TOPI module to support groups

Posted by GitBox <gi...@apache.org>.
vinx13 edited a comment on pull request #9465:
URL: https://github.com/apache/tvm/pull/9465#issuecomment-963464553


   I'm confused. Currently the TOPI implementations expect an NCHW input and an IOHW kernel for `conv2d_transpose`, so the shape checking in Relay is actually correct, isn't it?





[GitHub] [tvm] vinx13 commented on pull request #9465: [Conv2DTransposed] Fix wrong shape check and add new TOPI module to support groups

Posted by GitBox <gi...@apache.org>.
vinx13 commented on pull request #9465:
URL: https://github.com/apache/tvm/pull/9465#issuecomment-963464553


   I'm confused. Currently the TOPI implementations expect an NCHW input and an IOHW kernel, so the shape checking in Relay is actually correct, isn't it?





[GitHub] [tvm] Lyken17 commented on pull request #9465: [Conv2DTransposed] Fix wrong shape check

Posted by GitBox <gi...@apache.org>.
Lyken17 commented on pull request #9465:
URL: https://github.com/apache/tvm/pull/9465#issuecomment-963151253


   @Hzfengsy @junrushao1994 I've linted the code and added the test cases. However, the current TOPI does not support groups in conv2d_transpose, so I temporarily marked the test as `skip`. This relies on another PR (https://github.com/apache/tvm/pull/8799) being fixed; I will come back to this PR after that is resolved. 





[GitHub] [tvm] junrushao1994 commented on pull request #9465: [Conv2DTransposed] Fix wrong shape check and add new TOPI module to support groups

Posted by GitBox <gi...@apache.org>.
junrushao1994 commented on pull request #9465:
URL: https://github.com/apache/tvm/pull/9465#issuecomment-966883682


   Looks good! Please fix the lint btw





[GitHub] [tvm] junrushao1994 commented on pull request #9465: [Conv2DTransposed] Fix wrong shape check and add new TOPI module to support groups

Posted by GitBox <gi...@apache.org>.
junrushao1994 commented on pull request #9465:
URL: https://github.com/apache/tvm/pull/9465#issuecomment-967273964


   Thank you SO much @Lyken17 @alicja-SiMa-ai @vinx13 @AndrewZhaoLuo @Hzfengsy and many others!





[GitHub] [tvm] vinx13 commented on a change in pull request #9465: [Conv2DTransposed] Fix wrong shape check and add new TOPI module to support groups

Posted by GitBox <gi...@apache.org>.
vinx13 commented on a change in pull request #9465:
URL: https://github.com/apache/tvm/pull/9465#discussion_r747556160



##########
File path: python/tvm/relay/op/strategy/x86.py
##########
@@ -281,13 +281,25 @@ def conv2d_transpose_strategy_cpu(attrs, inputs, out_type, target):
     groups = attrs.groups
     assert layout == "NCHW", "only support nchw for now"
     assert dilation == (1, 1), "not support dilate now"
-    assert groups == 1, "only support groups == 1 for now"
+    # assert groups == 1, "only support groups == 1 for now"

Review comment:
       Remove this

##########
File path: python/tvm/relay/op/strategy/generic.py
##########
@@ -471,13 +475,20 @@ def conv2d_transpose_strategy(attrs, inputs, out_type, target):
     groups = attrs.groups
     assert layout == "NCHW", "only support nchw for now"
     assert dilation == (1, 1), "not support dilate now"
-    assert groups == 1, "only support groups == 1 for now"
+    # assert groups == 1, "only support groups == 1 for now"

Review comment:
       remove this
   







[GitHub] [tvm] Lyken17 commented on pull request #9465: [Conv2DTransposed] Fix wrong shape check and add new TOPI module to support groups

Posted by GitBox <gi...@apache.org>.
Lyken17 commented on pull request #9465:
URL: https://github.com/apache/tvm/pull/9465#issuecomment-966870153


   @junrushao1994 resolved, please take a look.





[GitHub] [tvm] vinx13 edited a comment on pull request #9465: [Conv2DTransposed] Fix wrong shape check and add new TOPI module to support groups

Posted by GitBox <gi...@apache.org>.
vinx13 edited a comment on pull request #9465:
URL: https://github.com/apache/tvm/pull/9465#issuecomment-963464553


   I'm confused. Currently the TOPI implementations expect an NCHW input and an IOHW kernel for `conv2d_transpose`, so the shape checking in Relay is actually correct, isn't it?
   https://github.com/apache/tvm/blob/main/python/tvm/topi/nn/conv2d_transpose.py#L33-L37





[GitHub] [tvm] Lyken17 commented on pull request #9465: [Conv2DTransposed] Fix wrong shape check and add new TOPI module to support groups

Posted by GitBox <gi...@apache.org>.
Lyken17 commented on pull request #9465:
URL: https://github.com/apache/tvm/pull/9465#issuecomment-963906327


   Hi @vinx13,
   
   This is a bit complex; let me explain the situation. The bug was initially found when I tried to calculate the gradients of `nn.Conv2d` with `groups`:
   
   ```python
   program = """
   def @main(%input0: Tensor[(1, 32, 224, 224), float32], 
           %v0_0_weight: Tensor[(32, 1, 3, 3), float32], 
           %v1_conv_0_0_weight: Tensor[(32, 1, 3, 3), float32], 
           %v1_conv_1_weight: Tensor[(16, 32, 1, 1), float32]) -> Tensor[(1, 32, 224, 224), float32] {
     %0 = nn.conv2d(%input0, %v0_0_weight, strides=[1, 1], padding=[1, 1, 1, 1], groups=32, channels=32, kernel_size=[3, 3]);
     %0
   }
   """
   
   mod = parse_module(program)
   
   mod = relay.transform.InferType()(mod)
   bwd_ir = relay.transform.gradient(mod['main'], mode="first_order")
   bwd_mod = tvm.IRModule.from_expr(bwd_ir)
   
   """
   fn (%input0: Tensor[(1, 32, 224, 224), float32], %v0_0_weight: Tensor[(32, 1, 3, 3), float32], %v1_conv_0_0_weight: Tensor[(32, 1, 3, 3), float32], %v1_conv_1_weight: Tensor[(16, 32, 1, 1), float32]) -> (Tensor[(1, 32, 224, 224), float32], (Tensor[(1, 32, 224, 224), float32], Tensor[(32, 1, 3, 3), float32], Tensor[(32, 1, 3, 3), float32], Tensor[(16, 32, 1, 1), float32])) {
     let %x = %input0;
     let %x1 = zeros_like(%x);
     let %x2 = %v0_0_weight;
     let %x3 = zeros_like(%x2);
     let %x4 = %v1_conv_0_0_weight;
     let %x5 = zeros_like(%x4);
     let %x6 = %v1_conv_1_weight;
     let %x7 = zeros_like(%x6);
     let %x8 = nn.conv2d(%x, %x2, padding=[1, 1, 1, 1], groups=32, channels=32, kernel_size=[3, 3]) /* ty=Tensor[(1, 32, 224, 224), float32] */;
     let %x9 = zeros_like(%x8);
     %0 = ones_like(%x8);
     %1 = nn.conv2d_transpose(%0, %x2, padding=[1, 1, 1, 1], groups=32);
     %9 = (
       let %x10 = add(%x1, %1);
       %2 = tile(%0, reps=[1, 1, 1, 1]);
       %3 = reshape(%x, newshape=[1, -1, 0, 0]);
       %4 = reshape(%2, newshape=[-1, 1, 0, 0]);
       %5 = nn.conv2d(%3, %4, padding=[1, 1, 1, 1], groups=32);
       %6 = reshape(%5, newshape=[1, 1, 32, 3, 3]);
       %7 = sum(%6, axis=[0]);
       %8 = transpose(%7, axes=[1, 0, 2, 3]);
       let %x11 = add(%x3, %8);
       (%x10, %x11, %x5, %x7)
     );
     (%x8, %9)
   }
   """
   ```
   
   As shown in the printed IR, the input gradient is computed as `%1 = nn.conv2d_transpose(%0, %x2, padding=[1, 1, 1, 1], groups=32);`. The shape of `%x2` is (32, 1, 3, 3), which is `OIHW` for Conv2d and `IOHW` for Conv2dTransposed. This is consistent with PyTorch [[1](https://pytorch.org/docs/stable/generated/torch.nn.functional.conv2d.html), [2](https://pytorch.org/docs/stable/generated/torch.nn.functional.conv_transpose2d.html#torch.nn.functional.conv_transpose2d)] and PaddlePaddle [[1](https://www.paddlepaddle.org.cn/documentation/docs/en/1.8/api/dygraph/Conv2DTranspose.html)]. If the original purpose of `topi.nn.conv2d_transpose` is to use `OIHW` rather than `IOHW`, then the primal gradient registration is wrong. 
   
   From my personal perspective, I prefer `IOHW`, since `conv2d_transpose` should have a `transposed` weight and this makes TVM consistent with PyTorch. Is there any specific reason (e.g., performance tuning) that makes `OIHW` better than `IOHW`?
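   
   To make the two conventions concrete, here is a small PyTorch sketch (shapes only, for illustration) showing that the same `(32, 1, 3, 3)` tensor is a valid `OIHW` weight for the grouped conv2d and a valid `IOHW` weight for the matching conv2d_transpose:
   
   ```python
   import torch.nn as nn
   
   # Grouped (depthwise) convolution: weight is OIHW = (out, in // groups, kH, kW).
   print(nn.Conv2d(32, 32, 3, groups=32, bias=False).weight.shape)           # [32, 1, 3, 3]
   
   # Its transposed counterpart: weight is IOHW = (in, out // groups, kH, kW).
   print(nn.ConvTranspose2d(32, 32, 3, groups=32, bias=False).weight.shape)  # [32, 1, 3, 3]
   
   # With groups=1 and distinct channel counts, the two layouts differ:
   print(nn.Conv2d(16, 32, 3, bias=False).weight.shape)            # [32, 16, 3, 3]
   print(nn.ConvTranspose2d(16, 32, 3, bias=False).weight.shape)   # [16, 32, 3, 3]
   ```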
   
   
   In either case, `groups` support for `conv2d_transpose` is missing and should be added. 
   
   
   
   
   
   
   
   
    





[GitHub] [tvm] vinx13 commented on pull request #9465: [Conv2DTransposed] Fix wrong shape check and add new TOPI module to support groups

Posted by GitBox <gi...@apache.org>.
vinx13 commented on pull request #9465:
URL: https://github.com/apache/tvm/pull/9465#issuecomment-964752985


   FYI there is another open PR https://github.com/apache/tvm/pull/9443





[GitHub] [tvm] junrushao1994 merged pull request #9465: [Conv2DTransposed] Fix wrong shape check and add new TOPI module to support groups

Posted by GitBox <gi...@apache.org>.
junrushao1994 merged pull request #9465:
URL: https://github.com/apache/tvm/pull/9465


   





[GitHub] [tvm] Lyken17 edited a comment on pull request #9465: [Conv2DTransposed] Fix wrong shape check and add new TOPI module to support groups

Posted by GitBox <gi...@apache.org>.
Lyken17 edited a comment on pull request #9465:
URL: https://github.com/apache/tvm/pull/9465#issuecomment-963906327


   Hi @vinx13,
   
   The situation is a bit complex: the bug was initially found when I tried to calculate the gradients of `nn.Conv2d` with `groups`:
   
   ```python
   program = """
   def @main(%input0: Tensor[(1, 32, 224, 224), float32], 
           %v0_0_weight: Tensor[(32, 1, 3, 3), float32], 
           %v1_conv_0_0_weight: Tensor[(32, 1, 3, 3), float32], 
           %v1_conv_1_weight: Tensor[(16, 32, 1, 1), float32]) -> Tensor[(1, 32, 224, 224), float32] {
     %0 = nn.conv2d(%input0, %v0_0_weight, strides=[1, 1], padding=[1, 1, 1, 1], groups=32, channels=32, kernel_size=[3, 3]);
     %0
   }
   """
   
   mod = parse_module(program)
   
   mod = relay.transform.InferType()(mod)
   bwd_ir = relay.transform.gradient(mod['main'], mode="first_order")
   bwd_mod = tvm.IRModule.from_expr(bwd_ir)
   
   """
   fn (%input0: Tensor[(1, 32, 224, 224), float32], %v0_0_weight: Tensor[(32, 1, 3, 3), float32], %v1_conv_0_0_weight: Tensor[(32, 1, 3, 3), float32], %v1_conv_1_weight: Tensor[(16, 32, 1, 1), float32]) -> (Tensor[(1, 32, 224, 224), float32], (Tensor[(1, 32, 224, 224), float32], Tensor[(32, 1, 3, 3), float32], Tensor[(32, 1, 3, 3), float32], Tensor[(16, 32, 1, 1), float32])) {
     let %x = %input0;
     let %x1 = zeros_like(%x);
     let %x2 = %v0_0_weight;
     let %x3 = zeros_like(%x2);
     let %x4 = %v1_conv_0_0_weight;
     let %x5 = zeros_like(%x4);
     let %x6 = %v1_conv_1_weight;
     let %x7 = zeros_like(%x6);
     let %x8 = nn.conv2d(%x, %x2, padding=[1, 1, 1, 1], groups=32, channels=32, kernel_size=[3, 3]) /* ty=Tensor[(1, 32, 224, 224), float32] */;
     let %x9 = zeros_like(%x8);
     %0 = ones_like(%x8);
     %1 = nn.conv2d_transpose(%0, %x2, padding=[1, 1, 1, 1], groups=32);
     %9 = (
       let %x10 = add(%x1, %1);
       %2 = tile(%0, reps=[1, 1, 1, 1]);
       %3 = reshape(%x, newshape=[1, -1, 0, 0]);
       %4 = reshape(%2, newshape=[-1, 1, 0, 0]);
       %5 = nn.conv2d(%3, %4, padding=[1, 1, 1, 1], groups=32);
       %6 = reshape(%5, newshape=[1, 1, 32, 3, 3]);
       %7 = sum(%6, axis=[0]);
       %8 = transpose(%7, axes=[1, 0, 2, 3]);
       let %x11 = add(%x3, %8);
       (%x10, %x11, %x5, %x7)
     );
     (%x8, %9)
   }
   """
   ```
   
   As shown in the printed IR, the input gradient is computed as `%1 = nn.conv2d_transpose(%0, %x2, padding=[1, 1, 1, 1], groups=32);`. The shape of `%x2` is (32, 1, 3, 3), which is `OIHW` for Conv2d and `IOHW` for Conv2dTransposed. This is consistent with PyTorch [[1](https://pytorch.org/docs/stable/generated/torch.nn.functional.conv2d.html), [2](https://pytorch.org/docs/stable/generated/torch.nn.functional.conv_transpose2d.html#torch.nn.functional.conv_transpose2d)] and PaddlePaddle [[1](https://www.paddlepaddle.org.cn/documentation/docs/en/1.8/api/dygraph/Conv2DTranspose.html)]. 
   
   However, if the original purpose of `topi.nn.conv2d_transpose` is to use `OIHW` rather than `IOHW`, then the primal gradient registration is wrong. From my personal perspective, I prefer `IOHW`, since `conv2d_transpose` should have a `transposed` weight and this makes TVM consistent with PyTorch. Is there any specific reason (e.g., performance tuning) that makes `OIHW` better than `IOHW`?
   
   There are two ways to fix the issue:
   * The first is to refactor `topi.nn.conv2d_transpose` to take an `OIHW` weight, so that the gradient calculation does not need to change and users benefit from the consistent layout and APIs with other frameworks such as PyTorch. 
   * The second is to rewrite `nn.conv2d`'s primal gradient to make it compatible with the `IOHW` weight layout. The advantage is that only one place needs to be updated, but users may get confused in the future because of the inconsistent layout. 
   
   In either case, `groups` support for `conv2d_transpose` is missing and should be added. 
   
   
   
   
   
   
   
   
    





[GitHub] [tvm] Lyken17 commented on pull request #9465: [Conv2DTransposed] Fix wrong shape check and add new TOPI module to support groups

Posted by GitBox <gi...@apache.org>.
Lyken17 commented on pull request #9465:
URL: https://github.com/apache/tvm/pull/9465#issuecomment-964702295


   It seems we have reached a consensus:
   
   1. `Groups` support should be added for `conv2d_transpose`.
   2. The weight layout of `conv2d_transpose` should be `IOHW`, i.e., (input_channel, output_channel // groups, kernel_height, kernel_width). 
   
   The second point involves a series of changes to `topi.nn`, `topi.strategy`, and `src/relay/op/convolution.h`, and updating only `conv2d_transpose` would make it inconsistent with `conv1d_transpose` and `conv3d_transpose`. Since there is a ticket blocking the v0.8.0 release, https://github.com/apache/tvm/issues/8976, I am concerned that I might not be able to update everything within one week. Let me explore whether it is possible to keep `OIHW` for now while still supporting gradients and groups > 1.  
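   
   As a quick sanity check on point 2, a tiny pure-Python sketch of the agreed `IOHW` convention (the helper name is made up for illustration), checked against the shapes discussed in this thread:
   
   ```python
   def iohw_weight_shape(in_channels, out_channels, kernel_size, groups=1):
       """IOHW weight shape for conv2d_transpose: (in, out // groups, kH, kW)."""
       kh, kw = kernel_size
       return (in_channels, out_channels // groups, kh, kw)
   
   # Depthwise example from the PR description: 32 -> 32 channels, groups=32.
   print(iohw_weight_shape(32, 32, (3, 3), groups=32))  # (32, 1, 3, 3)
   
   # Grouped example from the failing NHWC test: 12 -> 16 channels, groups=4.
   print(iohw_weight_shape(12, 16, (5, 5), groups=4))   # (12, 4, 5, 5)
   ```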





[GitHub] [tvm] vinx13 edited a comment on pull request #9465: [Conv2DTransposed] Fix wrong shape check and add new TOPI module to support groups

Posted by GitBox <gi...@apache.org>.
vinx13 edited a comment on pull request #9465:
URL: https://github.com/apache/tvm/pull/9465#issuecomment-963464553


   I'm confused. Currently the TOPI implementations expect an NCHW input and an IOHW kernel for `conv2d_transpose`, so the shape checking in Relay is actually correct, isn't it? (While `kernel_layout` in the Relay attrs is still `OIHW`, it has a different meaning for `conv2d_transpose`.)
   https://github.com/apache/tvm/blob/main/python/tvm/topi/nn/conv2d_transpose.py#L33-L37


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org