Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2022/03/02 06:59:54 UTC

[GitHub] [tvm] Lyken17 opened a new pull request #10439: Fix gradient OP for nn.conv2d

Lyken17 opened a new pull request #10439:
URL: https://github.com/apache/tvm/pull/10439


   The current backward implementation raises an error for `nn.conv2d`, for both normal and depth-wise convolutions. See the reproduction code below.
   
   ```python
   import numpy as np
   
   import tvm
   from tvm import relay
   from tvm.contrib import graph_executor
   
   # Depth-wise conv: groups == channels == 3, weight is (3, 1, 3, 3)
   depthwise_conv_code = """
   fn (%input0: Tensor[(1, 3, 32, 32), float32], %v0_weight: Tensor[(3, 1, 3, 3), float32], %v0_bias: Tensor[(3), float32]) {
     %0 = nn.conv2d(%input0, %v0_weight, padding=[1, 1, 1, 1], groups=3, channels=3, kernel_size=[3, 3]);
     nn.bias_add(%0, %v0_bias)
   }
   """
   
   # Normal conv: groups == 1, weight is (3, 3, 3, 3)
   normal_conv_code = """
   fn (%input0: Tensor[(1, 3, 32, 32), float32], %v0_weight: Tensor[(3, 3, 3, 3), float32], %v0_bias: Tensor[(3), float32]) {
     %0 = nn.conv2d(%input0, %v0_weight, padding=[1, 1, 1, 1], groups=1, channels=3, kernel_size=[3, 3]);
     nn.bias_add(%0, %v0_bias)
   }
   """
   
   SEMVER = '#[version = "0.0.5"]\n'
   # Parse the depth-wise program (swap in normal_conv_code for the normal case)
   expr = tvm.parser.parse_expr(SEMVER + depthwise_conv_code)
   fmod = tvm.IRModule.from_expr(expr)
   
   mod = relay.transform.InferType()(fmod)
   bwd_expr = relay.transform.gradient(mod["main"], mode="first_order")
   
   bwd_mod = tvm.IRModule.from_expr(bwd_expr)
   bwd_mod = relay.transform.InferType()(bwd_mod)
   ```
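For reference when reading the two programs: in OIHW layout a grouped conv2d weight has shape (out_channels, in_channels // groups, kH, kW), so with 3 input and 3 output channels the `groups=1` program expects a (3, 3, 3, 3) weight while the `groups=3` (depth-wise) program expects (3, 1, 3, 3). The helper below is a hypothetical illustration of that rule, not a TVM function.

```python
def conv2d_weight_shape(in_channels, out_channels, kernel_size, groups=1):
    """Expected OIHW weight shape of a grouped 2-D convolution (illustrative helper)."""
    if in_channels % groups or out_channels % groups:
        raise ValueError("channels must be divisible by groups")
    kh, kw = kernel_size
    # Each output channel only sees in_channels // groups input channels.
    return (out_channels, in_channels // groups, kh, kw)


normal = conv2d_weight_shape(3, 3, (3, 3), groups=1)
depthwise = conv2d_weight_shape(3, 3, (3, 3), groups=3)
```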
   
   This PR rolls the implementation back to the previous version while fixing its bug for depth-wise convolution (the previous backward implementation did not work for depth-wise conv).
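To make the depth-wise case concrete, the quantity the backward pass must produce for the weight can be written out directly. The following is a minimal NumPy sketch, not TVM's implementation, assuming NCHW input, a weight of shape (C, 1, kH, kW) as in the `groups=3` program, stride 1 and symmetric zero padding; `depthwise_conv2d` and `depthwise_conv2d_backward_weight` are hypothetical helper names.

```python
import numpy as np


def depthwise_conv2d(x, w, pad):
    """Forward depth-wise conv2d: x is NCHW, w is (C, 1, kH, kW), stride 1."""
    n, c, h, wd = x.shape
    kh, kw = w.shape[2], w.shape[3]
    xp = np.pad(x, ((0, 0), (0, 0), (pad, pad), (pad, pad)))
    oh, ow = h + 2 * pad - kh + 1, wd + 2 * pad - kw + 1
    out = np.zeros((n, c, oh, ow), dtype=x.dtype)
    for i in range(kh):
        for j in range(kw):
            # Each channel is convolved only with its own filter (groups == C).
            out += xp[:, :, i:i + oh, j:j + ow] * w[:, 0, i, j][None, :, None, None]
    return out


def depthwise_conv2d_backward_weight(dout, x, kernel_size, pad):
    """dw[c, 0, i, j] = sum over n, y, x of dout[n, c, y, x] * xp[n, c, y + i, x + j]."""
    kh, kw = kernel_size
    _, c, oh, ow = dout.shape
    xp = np.pad(x, ((0, 0), (0, 0), (pad, pad), (pad, pad)))
    dw = np.zeros((c, 1, kh, kw), dtype=dout.dtype)
    for i in range(kh):
        for j in range(kw):
            # Reduce over batch and spatial axes, but never across channels.
            dw[:, 0, i, j] = (dout * xp[:, :, i:i + oh, j:j + ow]).sum(axis=(0, 2, 3))
    return dw
```

Because the forward pass is linear in the weight, a central finite difference over each weight entry should match this gradient up to rounding, which is a quick way to sanity-check a backward implementation.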
   
   


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] Lyken17 edited a comment on pull request #10439: Fix gradient OP for nn.conv2d

Posted by GitBox <gi...@apache.org>.
Lyken17 edited a comment on pull request #10439:
URL: https://github.com/apache/tvm/pull/10439#issuecomment-1056647885


   Thanks for the explanation. I agree it is important to annotate the return dtype in the gradient calculation. But this information is missing when loading models with `relay.frontend.from_xxx`. For example:
   
   ```python
   import numpy as np
   
   import torch
   import torch.nn as nn
   
   import tvm
   from tvm import relay
   from tvm.contrib import graph_executor
   
   net = nn.Sequential(
       nn.Conv2d(3, 3, 3, padding=1, groups=3)
   )
   
   input_shape = [1, 3, 32, 32]
   input_data = torch.randn(input_shape)
   input_name = "input0"
   shape_list = [(input_name, input_data.shape)]
   
   scripted_model = torch.jit.trace(net, input_data).eval()
   fmod, params = relay.frontend.from_pytorch(scripted_model, shape_list, default_dtype="float32")
   
   mod = relay.transform.InferType()(fmod)
   bwd_expr = relay.transform.gradient(mod["main"], mode="first_order")
   
   bwd_mod = tvm.IRModule.from_expr(bwd_expr)
   bwd_mod = relay.transform.InferType()(bwd_mod)
   ```
   
   ```bash
   data types float32 and void do not match in BroadcastRel
   note: run with `TVM_BACKTRACE=1` environment variable to display a backtrace.
   ```



[GitHub] [tvm] Lyken17 commented on pull request #10439: Fix gradient OP for nn.conv2d

Posted by GitBox <gi...@apache.org>.
Lyken17 commented on pull request #10439:
URL: https://github.com/apache/tvm/pull/10439#issuecomment-1056438757


   cc @masahi, since this is related to https://github.com/apache/tvm/pull/9954
   
   My environment is built from commit 111b2da1372ec53cae76d63a21a45eb0c26e4a64, and the error can be easily reproduced with the program above.
   
   While I agree that a dedicated op for conv2d_grad makes it easier to accelerate with cuDNN / CUTLASS, there is something wrong with the current implementation.



[GitHub] [tvm] Lyken17 commented on pull request #10439: Fix gradient OP for nn.conv2d

Posted by GitBox <gi...@apache.org>.
Lyken17 commented on pull request #10439:
URL: https://github.com/apache/tvm/pull/10439#issuecomment-1056656951


   Also, the dtype information is not provided when I build the model with the Relay API:
   
   ```python
   input = relay.var("input", shape=[1,3,32,32], dtype="float32")
   weight = relay.var("weight", shape=[3,1,3,3], dtype="float32")
   out = relay.nn.conv2d(input, weight, groups=3, channels=3)
   fn = relay.Function([input, weight], out)
   fmod = tvm.IRModule.from_expr(fn)
   mod = relay.transform.InferType()(fmod)
   ```
   
   ```bash
   data types float32 and void do not match in BroadcastRel
   note: run with `TVM_BACKTRACE=1` environment variable to display a backtrace.
   ```



[GitHub] [tvm] Lyken17 edited a comment on pull request #10439: Fix gradient OP for nn.conv2d

Posted by GitBox <gi...@apache.org>.
Lyken17 edited a comment on pull request #10439:
URL: https://github.com/apache/tvm/pull/10439#issuecomment-1056656951


   Also, when I build the model with the Relay API, this dtype is still not handled properly:
   
   ```python
   input = relay.var("input", shape=[1,3,32,32], dtype="float32")
   weight = relay.var("weight", shape=[3,1,3,3], dtype="float32")
   out = relay.nn.conv2d(input, weight, groups=3, channels=3)
   fn = relay.Function([input, weight], out)
   fmod = tvm.IRModule.from_expr(fn)
   
   mod = relay.transform.InferType()(fmod)
   
   bwd_expr = relay.transform.gradient(mod["main"], mode="first_order")
   bwd_mod = tvm.IRModule.from_expr(bwd_expr)
   bwd_mod = relay.transform.InferType()(bwd_mod)
   ```
   
   ```bash
   data types float32 and void do not match in BroadcastRel
   note: run with `TVM_BACKTRACE=1` environment variable to display a backtrace.
   ```
   
   Something may be missing that leaves `out_dtype` as `void` in this case. For now, I would recommend using "float32" as the default value to avoid these errors.



[GitHub] [tvm] Lyken17 edited a comment on pull request #10439: Fix gradient OP for nn.conv2d

Posted by GitBox <gi...@apache.org>.
Lyken17 edited a comment on pull request #10439:
URL: https://github.com/apache/tvm/pull/10439#issuecomment-1056647885


   Thanks for the explanation. I agree it is important to annotate the return dtype in the gradient calculation. But this information is missing when loading models with `relay.frontend.from_xxx`. For example:
   
   ```python
   import numpy as np
   
   import torch
   import torch.nn as nn
   
   import tvm
   from tvm import relay
   from tvm.contrib import graph_executor
   
   net = nn.Sequential(
       nn.Conv2d(3, 3, 3, padding=1, groups=3)
   )
   
   input_shape = [1, 3, 32, 32]
   input_data = torch.randn(input_shape)
   input_name = "input0"
   shape_list = [(input_name, input_data.shape)]
   
   scripted_model = torch.jit.trace(net, input_data).eval()
   fmod, params = relay.frontend.from_pytorch(scripted_model, shape_list, default_dtype="float32")
   
   mod = relay.transform.InferType()(fmod)
   bwd_expr = relay.transform.gradient(mod["main"], mode="first_order")
   
   bwd_mod = tvm.IRModule.from_expr(bwd_expr)
   bwd_mod = relay.transform.InferType()(bwd_mod)
   ```
   
   



[GitHub] [tvm] masahi commented on pull request #10439: Fix gradient OP for nn.conv2d

Posted by GitBox <gi...@apache.org>.
masahi commented on pull request #10439:
URL: https://github.com/apache/tvm/pull/10439#issuecomment-1056467063


   See https://github.com/apache/tvm/blob/be176974a03b6f7e69fee2186f3847d9c092c546/src/relay/op/nn/convolution.cc#L1803
   
   Usually, if `out_dtype` is not provided, we can use the dtype of the inputs. But in your case, the parser somehow returns a conv2d with `out_dtype == void`, which is weird, so the output dtype of the wgrad becomes `void`. The correct fix is therefore a change in the parser.



[GitHub] [tvm] junrushao1994 commented on pull request #10439: Fix gradient OP for nn.conv2d

Posted by GitBox <gi...@apache.org>.
junrushao1994 commented on pull request #10439:
URL: https://github.com/apache/tvm/pull/10439#issuecomment-1056447459


   CC @Hzfengsy @YuchenJin @ZihengJiang 



[GitHub] [tvm] masahi closed pull request #10439: Fix gradient OP for nn.conv2d

Posted by GitBox <gi...@apache.org>.
masahi closed pull request #10439:
URL: https://github.com/apache/tvm/pull/10439


   



[GitHub] [tvm] masahi commented on pull request #10439: Fix gradient OP for nn.conv2d

Posted by GitBox <gi...@apache.org>.
masahi commented on pull request #10439:
URL: https://github.com/apache/tvm/pull/10439#issuecomment-1056691758


   We can update https://github.com/apache/tvm/blob/be176974a03b6f7e69fee2186f3847d9c092c546/src/relay/op/nn/convolution.cc#L1803 
   
   to 
   
   ```cpp
     const auto dw_dtype = (param->out_dtype == DataType() || param->out_dtype.is_void())
                               ? grad->dtype
                               : param->out_dtype;
   ```
   
   I don't know why the default out dtype is `void`, but this should do the job. A PR is welcome.
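The fallback rule in the proposed snippet can be modelled in plain Python to make the intended behaviour explicit. This is a hypothetical sketch: dtypes are represented as strings, and `None`, the empty string and `"void"` all stand in for an unset `out_dtype`, whereas TVM itself works with `DataType` objects in C++.

```python
from typing import Optional


def resolve_wgrad_out_dtype(out_dtype: Optional[str], grad_dtype: str) -> str:
    """Pick the output dtype of the conv2d weight gradient (illustrative only).

    An unset or "void" out_dtype falls back to the dtype of the incoming
    gradient; an explicit out_dtype (e.g. fp32 for fp16 inputs) is kept.
    """
    if out_dtype is None or out_dtype in ("", "void"):
        return grad_dtype
    return out_dtype
```

With this rule, models coming from the frontends (where `out_dtype` ends up `void`) get a well-typed wgrad, while an explicit `out_dtype="float32"` on an fp16 model is still honoured.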



[GitHub] [tvm] Lyken17 edited a comment on pull request #10439: Fix gradient OP for nn.conv2d

Posted by GitBox <gi...@apache.org>.
Lyken17 edited a comment on pull request #10439:
URL: https://github.com/apache/tvm/pull/10439#issuecomment-1056656951


   Also, the dtype information is not provided when I build the model with the Relay API:
   
   ```python
   input = relay.var("input", shape=[1,3,32,32], dtype="float32")
   weight = relay.var("weight", shape=[3,1,3,3], dtype="float32")
   out = relay.nn.conv2d(input, weight, groups=3, channels=3)
   fn = relay.Function([input, weight], out)
   fmod = tvm.IRModule.from_expr(fn)
   
   mod = relay.transform.InferType()(fmod)
   
   bwd_expr = relay.transform.gradient(mod["main"], mode="first_order")
   bwd_mod = tvm.IRModule.from_expr(bwd_expr)
   bwd_mod = relay.transform.InferType()(bwd_mod)
   ```
   
   ```bash
   data types float32 and void do not match in BroadcastRel
   note: run with `TVM_BACKTRACE=1` environment variable to display a backtrace.
   ```



[GitHub] [tvm] masahi edited a comment on pull request #10439: Fix gradient OP for nn.conv2d

Posted by GitBox <gi...@apache.org>.
masahi edited a comment on pull request #10439:
URL: https://github.com/apache/tvm/pull/10439#issuecomment-1056467063


   See https://github.com/apache/tvm/blob/be176974a03b6f7e69fee2186f3847d9c092c546/src/relay/op/nn/convolution.cc#L1803
   
   Usually, if `out_dtype` is not provided, we can use the dtype of the inputs. But in your case, the parser somehow returns a conv2d with `out_dtype == void`, which is weird, so the output dtype of the wgrad becomes `void`. The correct fix is therefore a change in the parser.
   
   `out_dtype` matters for wgrad because, if the input is fp16, we might want the out dtype to be fp32. Without an out dtype, an fp16 input means wgrad is computed with fp16 precision, which is probably not what we want.
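To make the precision argument concrete, here is a toy NumPy illustration. It is not a model of how TVM or cuDNN actually schedules the reduction: it assumes a naive sequential reduction where the running sum is rounded to the accumulator dtype after every addition, and the 4096 equal terms (e.g. one wgrad element reduced over a hypothetical batch of 4 feature maps of 32x32) are made up for illustration.

```python
import numpy as np


def accumulate(values, dtype):
    """Sequentially sum `values`, rounding the running total to `dtype` after every add."""
    acc = dtype(0)
    for v in values:
        acc = dtype(acc + v)
    return acc


# Hypothetical reduction: 4096 fp16 products of about 0.01 each.
terms = np.full(4096, 0.01, dtype=np.float16)
exact = float(terms.astype(np.float64).sum())
fp16_sum = float(accumulate(terms, np.float16))  # fp16 accumulator
fp32_sum = float(accumulate(terms, np.float32))  # fp32 accumulator, fp16 inputs
```

Once the fp16 running total grows large, each new 0.01 term is smaller than half the spacing between neighbouring fp16 values and is rounded away, so the fp16 total stalls far below the exact sum while the fp32 accumulator stays accurate.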



[GitHub] [tvm] masahi edited a comment on pull request #10439: Fix gradient OP for nn.conv2d

Posted by GitBox <gi...@apache.org>.
masahi edited a comment on pull request #10439:
URL: https://github.com/apache/tvm/pull/10439#issuecomment-1056455121


   You need to add `out_dtype="float32"` to the conv2d in your module. With that, it works:
   
   ```
   def @main(%input0: Tensor[(1, 3, 32, 32), float32], %v0_weight: Tensor[(3, 1, 3, 3), float32], %v0_bias: Tensor[(3), float32]) -> (Tensor[(1, 3, 32, 32), float32], (Tensor[(1, 3, 32, 32), float32], Tensor[(3, 1, 3, 3), float32], Tensor[(3), float32])) {
     let %x_0: Tensor[(1, 3, 32, 32), float32] = %input0;
     let %x_1: Tensor[(1, 3, 32, 32), float32] = zeros_like(%x_0) /* ty=Tensor[(1, 3, 32, 32), float32] */;
     let %x_2: Tensor[(3, 1, 3, 3), float32] = %v0_weight;
     let %x_3: Tensor[(3, 1, 3, 3), float32] = zeros_like(%x_2) /* ty=Tensor[(3, 1, 3, 3), float32] */;
     let %x_4: Tensor[(3), float32] = %v0_bias;
     let %x_5: Tensor[(3), float32] = zeros_like(%x_4) /* ty=Tensor[(3), float32] */;
     let %x_6: Tensor[(1, 3, 32, 32), float32] = nn.conv2d(%x_0, %x_2, padding=[1, 1, 1, 1], groups=3, channels=3, kernel_size=[3, 3], out_dtype="float32") /* ty=Tensor[(1, 3, 32, 32), float32] */;
     let %x_7: Tensor[(1, 3, 32, 32), float32] = zeros_like(%x_6) /* ty=Tensor[(1, 3, 32, 32), float32] */;
     let %x_8: Tensor[(1, 3, 32, 32), float32] = nn.bias_add(%x_6, %x_4) /* ty=Tensor[(1, 3, 32, 32), float32] */;
     let %x_9: Tensor[(1, 3, 32, 32), float32] = zeros_like(%x_8) /* ty=Tensor[(1, 3, 32, 32), float32] */;
     %0 = ones_like(%x_8) /* ty=Tensor[(1, 3, 32, 32), float32] */;
     %1 = collapse_sum_like(%0, %x_6) /* ty=Tensor[(1, 3, 32, 32), float32] */;
     %5 = (
       let %x_10: Tensor[(1, 3, 32, 32), float32] = add(%x_7, %1) /* ty=Tensor[(1, 3, 32, 32), float32] */;
       %2 = sum(%0, axis=[1], exclude=True) /* ty=Tensor[(3), float32] */;
       let %x_11: Tensor[(3), float32] = add(%x_5, %2) /* ty=Tensor[(3), float32] */;
       %3 = nn.conv2d_transpose(%x_10, %x_2, padding=[1, 1, 1, 1], groups=3, kernel_layout="IOHW") /* ty=Tensor[(1, 1, 32, 32), float32] */;
       let %x_12: Tensor[(1, 3, 32, 32), float32] = add(%x_1, %3) /* ty=Tensor[(1, 3, 32, 32), float32] */;
       %4 = nn.conv2d_backward_weight(%x_10, %x_0, padding=[1, 1, 1, 1], groups=3, channels=3, kernel_size=[3, 3], kernel_layout="NCHW", out_layout="OIHW", out_dtype="float32") /* ty=Tensor[(3, 1, 3, 3), float32] */;
       let %x_13: Tensor[(3, 1, 3, 3), float32] = add(%x_3, %4) /* ty=Tensor[(3, 1, 3, 3), float32] */;
       (%x_12, %x_13, %x_11)
     );
     (%x_8, %5)
   }
   ```
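For models written in the Relay text format, the workaround above can be applied mechanically before parsing. The helper below is a hypothetical sketch, not a TVM API: a regex that appends `out_dtype="float32"` to every `nn.conv2d(...)` call that does not already set it, assuming the call's arguments contain no nested parentheses (true for the reproduction programs in this thread).

```python
import re

# An `nn.conv2d(...)` call (not conv2d_transpose etc.) whose argument list has
# no nested parentheses and does not already contain `out_dtype`.
_CONV2D_WITHOUT_OUT_DTYPE = re.compile(r"(nn\.conv2d\((?:(?!out_dtype)[^()])*)\)")


def add_conv2d_out_dtype(relay_text: str, dtype: str = "float32") -> str:
    """Append out_dtype="<dtype>" to each nn.conv2d call in Relay text that lacks one."""
    return _CONV2D_WITHOUT_OUT_DTYPE.sub(rf'\1, out_dtype="{dtype}")', relay_text)
```

The rewritten string can then be passed to `tvm.parser.parse_expr` in place of the original text before running `relay.transform.gradient`. It does not directly apply to modules built by `relay.frontend.from_pytorch` or the Relay Python API, which is the case the rest of this thread is about.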



[GitHub] [tvm] masahi commented on pull request #10439: Fix gradient OP for nn.conv2d

Posted by GitBox <gi...@apache.org>.
masahi commented on pull request #10439:
URL: https://github.com/apache/tvm/pull/10439#issuecomment-1056455121


   You need to add `out_dtype="float32"` to the conv2d in your module. With that, it works.



[GitHub] [tvm] Lyken17 commented on pull request #10439: Fix gradient OP for nn.conv2d

Posted by GitBox <gi...@apache.org>.
Lyken17 commented on pull request #10439:
URL: https://github.com/apache/tvm/pull/10439#issuecomment-1056647885


   Thanks for the explanation. But this dtype information is missing when using `from_pytorch` and `from_onnx`:
   
   ```python
   import numpy as np
   
   import torch
   import torch.nn as nn
   
   import tvm
   from tvm import relay
   from tvm.contrib import graph_executor
   
   net = nn.Sequential(
       nn.Conv2d(3, 3, 3, padding=1, groups=3)
   )
   
   input_shape = [1, 3, 32, 32]
   input_data = torch.randn(input_shape)
   input_name = "input0"
   shape_list = [(input_name, input_data.shape)]
   
   scripted_model = torch.jit.trace(net, input_data).eval()
   fmod, params = relay.frontend.from_pytorch(scripted_model, shape_list)
   
   mod = relay.transform.InferType()(fmod)
   bwd_expr = relay.transform.gradient(mod["main"], mode="first_order")
   
   bwd_mod = tvm.IRModule.from_expr(bwd_expr)
   bwd_mod = relay.transform.InferType()(bwd_mod)
   ```
   
   



[GitHub] [tvm] Lyken17 edited a comment on pull request #10439: Fix gradient OP for nn.conv2d

Posted by GitBox <gi...@apache.org>.
Lyken17 edited a comment on pull request #10439:
URL: https://github.com/apache/tvm/pull/10439#issuecomment-1056647885


   Thanks for the explanation. I agree it is important to annotate the return dtype in the gradient calculation. But this dtype information is missing in most TVM use cases. For example, when we load a model with `tvm.relay.frontend.from_xxx`, the current implementation cannot handle it properly:
   
   ```python
   import numpy as np
   
   import torch
   import torch.nn as nn
   
   import tvm
   from tvm import relay
   from tvm.contrib import graph_executor
   
   net = nn.Sequential(
       nn.Conv2d(3, 3, 3, padding=1, groups=3)
   )
   
   input_shape = [1, 3, 32, 32]
   input_data = torch.randn(input_shape)
   input_name = "input0"
   shape_list = [(input_name, input_data.shape)]
   
   scripted_model = torch.jit.trace(net, input_data).eval()
   fmod, params = relay.frontend.from_pytorch(scripted_model, shape_list, default_dtype="float32")
   
   mod = relay.transform.InferType()(fmod)
   bwd_expr = relay.transform.gradient(mod["main"], mode="first_order")
   
   bwd_mod = tvm.IRModule.from_expr(bwd_expr)
   bwd_mod = relay.transform.InferType()(bwd_mod)
   ```
   
   ```bash
   data types float32 and void do not match in BroadcastRel
   note: run with `TVM_BACKTRACE=1` environment variable to display a backtrace.
   ```



[GitHub] [tvm] masahi edited a comment on pull request #10439: Fix gradient OP for nn.conv2d

Posted by GitBox <gi...@apache.org>.
masahi edited a comment on pull request #10439:
URL: https://github.com/apache/tvm/pull/10439#issuecomment-1056467063


   See https://github.com/apache/tvm/blob/be176974a03b6f7e69fee2186f3847d9c092c546/src/relay/op/nn/convolution.cc#L1803
   
   Usually, if `out_dtype` is not provided, we can use the dtype of the inputs. But in your case, the parser somehow returns a conv2d with `out_dtype == void`, which is weird, so the output dtype of the wgrad becomes `void`. The correct fix is therefore a change in the parser.
   
   `out_dtype` matters for wgrad because, if the input is fp16, we might want the out dtype to be fp32.



[GitHub] [tvm] Lyken17 edited a comment on pull request #10439: Fix gradient OP for nn.conv2d

Posted by GitBox <gi...@apache.org>.
Lyken17 edited a comment on pull request #10439:
URL: https://github.com/apache/tvm/pull/10439#issuecomment-1056647885


   Thanks for the explanation. I agree it is important to annotate the return dtype in the gradient calculation. But this information is missing when loading models with `relay.frontend.from_xxx`. For example:
   
   ```python
   import numpy as np
   
   import torch
   import torch.nn as nn
   
   import tvm
   from tvm import relay
   from tvm.contrib import graph_executor
   
   net = nn.Sequential(
       nn.Conv2d(3, 3, 3, padding=1, groups=3)
   )
   
   input_shape = [1, 3, 32, 32]
   input_data = torch.randn(input_shape)
   input_name = "input0"
   shape_list = [(input_name, input_data.shape)]
   
   scripted_model = torch.jit.trace(net, input_data).eval()
   fmod, params = relay.frontend.from_pytorch(scripted_model, shape_list)
   
   mod = relay.transform.InferType()(fmod)
   bwd_expr = relay.transform.gradient(mod["main"], mode="first_order")
   
   bwd_mod = tvm.IRModule.from_expr(bwd_expr)
   bwd_mod = relay.transform.InferType()(bwd_mod)
   ```
   
   



[GitHub] [tvm] masahi commented on pull request #10439: Fix gradient OP for nn.conv2d

Posted by GitBox <gi...@apache.org>.
masahi commented on pull request #10439:
URL: https://github.com/apache/tvm/pull/10439#issuecomment-1057636263


   https://github.com/apache/tvm/pull/10459



[GitHub] [tvm] Lyken17 edited a comment on pull request #10439: Fix gradient OP for nn.conv2d

Posted by GitBox <gi...@apache.org>.
Lyken17 edited a comment on pull request #10439:
URL: https://github.com/apache/tvm/pull/10439#issuecomment-1056656951





