You are viewing a plain-text version of this content; the hyperlink to its canonical version was not preserved in this copy.
Posted to commits@tvm.apache.org by wu...@apache.org on 2020/06/30 07:05:56 UTC

[incubator-tvm] branch master updated: Fix the meaning of conv{1,2}d_transpose output_padding parameter. (#5758)

This is an automated email from the ASF dual-hosted git repository.

wuwei pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git


The following commit(s) were added to refs/heads/master by this push:
     new bc22fb9  Fix the meaning of conv{1,2}d_transpose output_padding parameter. (#5758)
bc22fb9 is described below

commit bc22fb9dd9edaec263bcfa03290ad1d963ce3f56
Author: abergeron <ab...@gmail.com>
AuthorDate: Tue Jun 30 03:05:43 2020 -0400

    Fix the meaning of conv{1,2}d_transpose output_padding parameter. (#5758)
    
    * Add output_padding to generic
    
    * Add output_padding to the reference impl
    
    * Add output_padding to arm_cpu
    
    * Add output_padding to the test
    
    * Add output_padding for cuda
    
    * Add output_padding for x86
    
    * Make use of the new output_padding argument in Relay
    
    * Adjust conv2d_transpose Relay test
    
    * Fix lint errors
    
    * Fix the VTA declaration of conv2d_transpose
    
    * support for output padding in conv2d transpose
    
    * some output padding will break IR pass
    
    * Fix new conv2d_transpose test
    
    * Update tophub
    
    * Fix conv1d output_padding too.
    
    * Fix the conv1d_transpose reference function.
    
    * Fix the cuda impl
    
    * fix the topi test for conv1d
    
    * format
    
    * Add tests for conv1d_transpose output_padding and checks that the values are valid.
    
    * Add check in the implementations
    
    * Add checks to the implementations of conv2d
    
    * Make use of the output_padding argument from topi in relay.
    
    * Fix relay tests asking for invalid output_padding
    
    * Fix line length
    
    * Fix vta tests
    
    * Update tophub references
    
    * Trigger CI
    
    Co-authored-by: Thierry Moreau <tm...@octoml.ai>
---
 python/tvm/autotvm/tophub.py                       |  7 ++--
 python/tvm/relay/op/nn/_nn.py                      |  1 +
 python/tvm/relay/op/nn/nn.py                       |  4 +--
 python/tvm/relay/op/strategy/generic.py            |  9 ++----
 tests/python/relay/test_op_level2.py               | 37 +++++++++++-----------
 topi/python/topi/arm_cpu/conv2d_transpose.py       | 26 ++++++++++-----
 topi/python/topi/cuda/conv1d_transpose_ncw.py      | 13 ++++++--
 topi/python/topi/cuda/conv2d_transpose_nchw.py     | 12 +++++--
 topi/python/topi/nn/conv1d_transpose.py            | 13 ++++++--
 topi/python/topi/nn/conv2d_transpose.py            | 34 ++++++++++++--------
 .../topi/testing/conv1d_transpose_ncw_python.py    | 13 ++++++--
 .../python/topi/testing/conv2d_transpose_python.py | 24 ++++++++++----
 topi/python/topi/x86/conv2d_transpose.py           |  7 ++--
 .../tests/python/test_topi_conv1d_transpose_ncw.py | 32 ++++++++++---------
 .../python/test_topi_conv2d_transpose_nchw.py      | 31 ++++++++++--------
 vta/python/vta/top/vta_conv2d_transpose.py         | 14 +++++---
 vta/scripts/tune_conv2d_transpose.py               | 19 +++++++----
 .../test_benchmark_topi_conv2d_transpose.py        | 19 ++++++-----
 18 files changed, 195 insertions(+), 120 deletions(-)

diff --git a/python/tvm/autotvm/tophub.py b/python/tvm/autotvm/tophub.py
index c7c55ed..a11c16b 100644
--- a/python/tvm/autotvm/tophub.py
+++ b/python/tvm/autotvm/tophub.py
@@ -46,16 +46,15 @@ AUTOTVM_TOPHUB_ROOT_PATH = os.path.join(os.path.expanduser('~'), ".tvm", "tophub
 
 # the version of each package
 PACKAGE_VERSION = {
-    'arm_cpu':          "v0.06",
+    'arm_cpu':          "v0.07",
     'llvm':             "v0.04",
 
-    'cuda':             "v0.08",
+    'cuda':             "v0.09",
     'rocm':             "v0.05",
     'opencl':           "v0.04",
     'mali':             "v0.06",
     'intel_graphics':   "v0.02",
-
-    'vta':              "v0.08",
+    'vta':              "v0.09",
     'amd_apu':          "v0.01",
 }
 
diff --git a/python/tvm/relay/op/nn/_nn.py b/python/tvm/relay/op/nn/_nn.py
index 564d6f7..0757e96 100644
--- a/python/tvm/relay/op/nn/_nn.py
+++ b/python/tvm/relay/op/nn/_nn.py
@@ -171,6 +171,7 @@ def convert_conv2d(attrs, inputs, tinfos, desired_layouts):
 reg.register_strategy("nn.conv2d_transpose", strategy.conv2d_transpose_strategy)
 reg.register_pattern("nn.conv2d_transpose", OpPattern.OUT_ELEMWISE_FUSABLE)
 
+
 @reg.register_legalize("nn.conv2d_transpose")
 def legalize_conv2d_transpose(attrs, inputs, types):
     """Legalize conv2d_transpose op.
diff --git a/python/tvm/relay/op/nn/nn.py b/python/tvm/relay/op/nn/nn.py
index 3c47cf7..e5009d3 100644
--- a/python/tvm/relay/op/nn/nn.py
+++ b/python/tvm/relay/op/nn/nn.py
@@ -494,7 +494,7 @@ def conv2d_transpose(data,
         Layout of the output, by default, out_layout is the same as data_layout
 
     output_padding : Tuple[int], optional
-        Additional zero-padding to be added to one side of the output.
+        Used to disambiguate the output shape.
 
     out_dtype : str, optional
         Specifies the output data type for mixed precision conv2d.
@@ -562,7 +562,7 @@ def conv1d_transpose(data,
         Layout of the output, by default, out_layout is the same as data_layout
 
     output_padding : Tuple[int], optional
-        Additional zero-padding to be added to one side of the output.
+        Used to disambiguate the output shape.
 
     out_dtype : str, optional
         Specifies the output data type for mixed precision conv2d.
diff --git a/python/tvm/relay/op/strategy/generic.py b/python/tvm/relay/op/strategy/generic.py
index e9feee6..632445b 100644
--- a/python/tvm/relay/op/strategy/generic.py
+++ b/python/tvm/relay/op/strategy/generic.py
@@ -333,11 +333,9 @@ def wrap_compute_conv2d_transpose(topi_compute):
         out_dtype = attrs.out_dtype
         out_dtype = (inputs[0].dtype if out_dtype in ("same", "")
                      else out_dtype)
-        out = topi_compute(
-            inputs[0], inputs[1], strides, padding, out_dtype)
         output_padding = get_const_tuple(attrs.output_padding)
-        out = topi.nn.pad(out, [0, 0, 0, 0],
-                          [0, 0, output_padding[0], output_padding[1]])
+        out = topi_compute(
+            inputs[0], inputs[1], strides, padding, out_dtype, output_padding)
         return [out]
     return compute_conv2d_transpose
 
@@ -502,9 +500,8 @@ def wrap_compute_conv1d_transpose(topi_compute):
         strides = get_const_tuple(attrs.strides)
         out_dtype = attrs.out_dtype
         out_dtype = (inputs[0].dtype if out_dtype in ("same", "") else out_dtype)
-        out = topi_compute(inputs[0], inputs[1], strides, padding, out_dtype)
         output_padding = get_const_tuple(attrs.output_padding)
-        out = topi.nn.pad(out, [0, 0, 0], [0, 0, output_padding[0]])
+        out = topi_compute(inputs[0], inputs[1], strides, padding, out_dtype, output_padding)
         return [out]
     return _compute_conv1d_tranpsoe
 
diff --git a/tests/python/relay/test_op_level2.py b/tests/python/relay/test_op_level2.py
index d45372e..cd54d9f 100644
--- a/tests/python/relay/test_op_level2.py
+++ b/tests/python/relay/test_op_level2.py
@@ -704,21 +704,18 @@ def test_conv2d_transpose_infer_type():
 def test_conv2d_transpose_nchw_run():
     dshape = (1, 3, 18, 18)
     kshape = (3, 10, 3, 3)
-    oshape = (1, 10, 37, 37)
+    oshape = (1, 10, 36, 36)
     x = relay.var("x", shape=dshape)
     w = relay.var("w")
     y = relay.nn.conv2d_transpose(x, w,
                                   channels=10, kernel_size=(3,3), strides=(2,2),
-                                  padding=(1,1), output_padding=(2, 2))
+                                  padding=(1,1), output_padding=(1, 1))
     func = relay.Function([x, w], y)
     dtype = "float32"
     data = np.random.uniform(size=dshape).astype(dtype)
     kernel = np.random.uniform(size=kshape).astype(dtype)
-    c_np = topi.testing.conv2d_transpose_nchw_python(
-        data, kernel, 2, 1)
-    d_np = np.zeros(shape=oshape)
-    d_np[:,:,0:c_np.shape[2],0:c_np.shape[3]] = c_np
-    ref_res = d_np
+    ref_res = topi.testing.conv2d_transpose_nchw_python(
+        data, kernel, 2, 1, (1, 1))
 
     for target, ctx in ctx_list():
         intrp1 = relay.create_executor("graph", ctx=ctx, target=target)
@@ -729,43 +726,45 @@ def test_conv2d_transpose_nchw_run():
 def test_conv2d_transpose_nhwc_run():
     dshape_nhwc = (1, 18, 18, 3)
     kshape_hwoi = (3, 3, 10, 3)
-    oshape_nhwc = (1, 37, 37, 10)
+    oshape_nhwc = (1, 36, 36, 10)
     x = relay.var("x", shape=dshape_nhwc)
     w = relay.var("w")
     # kshape and kernel_layout should have swapped IO.
     # kshape is HWOI and kernel_layout is HWIO
     y = relay.nn.conv2d_transpose(x, w,
                                   channels=10, kernel_size=(3, 3), strides=(2, 2),
-                                  padding=(1, 1), output_padding=(2, 2),
+                                  padding=(1, 1), output_padding=(1, 1),
                                   data_layout="NHWC", kernel_layout="HWIO")
     func = relay.Function([x, w], y)
     dtype = "float32"
     data = np.random.uniform(size=dshape_nhwc).astype(dtype)
     kernel = np.random.uniform(size=kshape_hwoi).astype(dtype)
     # use true kshape layout here - HWOI
-    c_np = topi.testing.conv2d_transpose_nhwc_python(data, kernel, 'HWOI', 2, 1)
-    d_np = np.zeros(shape=oshape_nhwc)
-    d_np[:,0:c_np.shape[1],0:c_np.shape[2],:] = c_np
+
+    ref_res = topi.testing.conv2d_transpose_nhwc_python(data, kernel, 'HWOI',
+                                                        2, 1, output_padding=(1, 1))
+
+    for target, ctx in ctx_list():
+        intrp1 = relay.create_executor("graph", ctx=ctx, target=target)
+        op_res1 = intrp1.evaluate(func)(data, kernel)
+        tvm.testing.assert_allclose(op_res1.asnumpy(), ref_res, rtol=1e-5, atol=1e-5)
 
 
 def test_conv1d_transpose_ncw_run():
     dshape = (1, 3, 18)
     kshape = (3, 10, 3)
-    oshape = (1, 10, 37)
+    oshape = (1, 10, 36)
     x = relay.var("x", shape=dshape)
     w = relay.var("w")
     y = relay.nn.conv1d_transpose(x, w,
                                   channels=10, kernel_size=(3,), strides=(2,),
-                                  padding=(1,), output_padding=(2,))
+                                  padding=(1,), output_padding=(1,))
     func = relay.Function([x, w], y)
     dtype = "float32"
     data = np.random.uniform(size=dshape).astype(dtype)
     kernel = np.random.uniform(size=kshape).astype(dtype)
-    c_np = topi.testing.conv1d_transpose_ncw_python(
-        data, kernel, 2, 1)
-    d_np = np.zeros(shape=oshape)
-    d_np[:,:,0:c_np.shape[2]] = c_np
-    ref_res = d_np
+    ref_res = topi.testing.conv1d_transpose_ncw_python(
+        data, kernel, 2, 1, output_padding=(1,))
 
     for target, ctx in ctx_list():
         intrp1 = relay.create_executor("graph", ctx=ctx, target=target)
diff --git a/topi/python/topi/arm_cpu/conv2d_transpose.py b/topi/python/topi/arm_cpu/conv2d_transpose.py
index 7eaa5ee..8152ae2 100644
--- a/topi/python/topi/arm_cpu/conv2d_transpose.py
+++ b/topi/python/topi/arm_cpu/conv2d_transpose.py
@@ -26,8 +26,11 @@ from ..nn import dilate, pad, get_pad_tuple
 from ..util import get_const_tuple, traverse_inline
 from .conv2d_spatial_pack import schedule_conv2d_spatial_pack_nchw
 
+
+
 @autotvm.register_topi_compute("conv2d_transpose_nchw.arm_cpu")
-def conv2d_transpose_nchw(cfg, Input, Filter, strides, padding, out_dtype):
+def conv2d_transpose_nchw(cfg, Input, Filter, strides, padding, out_dtype,
+                          output_padding):
     """Transposed 2D convolution nchw forward operator.
 
     Parameters
@@ -47,27 +50,34 @@ def conv2d_transpose_nchw(cfg, Input, Filter, strides, padding, out_dtype):
     out_dtype: str
         The output data type. This is used for mixed precision.
 
+    output_padding : tuple of int
+        Used to get the right output shape in gradients
+
     Returns
     -------
     Output : tvm.te.Tensor
         4-D with shape [batch, out_channel, out_height, out_width]
     """
-    return _decl_spatial_pack(cfg, Input, Filter, strides, padding, "NCHW", out_dtype, 2)
+    return _decl_spatial_pack(cfg, Input, Filter, strides, padding, "NCHW", out_dtype, 2,
+                              output_padding)
 
-def _decl_spatial_pack(cfg, data, kernel, strides, padding, layout, out_dtype, num_tile):
+def _decl_spatial_pack(cfg, data, kernel, strides, padding, layout, out_dtype, num_tile,
+                       output_padding):
     assert layout == "NCHW", "Only support NCHW"
     out_dtype = out_dtype or data.dtype
 
     N, CI, IH, IW = get_const_tuple(data.shape)
     _, CO, KH, KW = get_const_tuple(kernel.shape)
+    HSTR, WSTR = strides if isinstance(strides, (tuple, list)) else (strides, strides)
+    opad_h, opad_w = output_padding
+    assert opad_h < HSTR and opad_w < WSTR
 
     pad_top, pad_left, pad_bottom, pad_right = get_pad_tuple(padding, (KH, KW))
-    bpad_top, bpad_bottom = KH - 1 - pad_top, KH - 1 - pad_bottom
-    bpad_left, bpad_right = KW - 1 - pad_left, KW - 1 - pad_right
-    HSTR, WSTR = strides if isinstance(strides, (tuple, list)) else (strides, strides)
+    bpad_top, bpad_bottom = KH - 1 - pad_top, KH - 1 - pad_bottom + opad_h
+    bpad_left, bpad_right = KW - 1 - pad_left, KW - 1 - pad_right + opad_w
 
-    OH = (IH - 1) * HSTR - pad_top - pad_bottom + KH
-    OW = (IW - 1) * WSTR - pad_left - pad_right + KW
+    OH = (IH - 1) * HSTR - pad_top - pad_bottom + KH + opad_h
+    OW = (IW - 1) * WSTR - pad_left - pad_right + KW + opad_w
 
     dilated_input = dilate(data, [1, 1, HSTR, WSTR])
     data_pad = pad(dilated_input, [0, 0, bpad_top, bpad_left], [0, 0, bpad_bottom, bpad_right])
diff --git a/topi/python/topi/cuda/conv1d_transpose_ncw.py b/topi/python/topi/cuda/conv1d_transpose_ncw.py
index cf1b66c..a2ac7e1 100644
--- a/topi/python/topi/cuda/conv1d_transpose_ncw.py
+++ b/topi/python/topi/cuda/conv1d_transpose_ncw.py
@@ -24,7 +24,8 @@ from .. import nn
 from ..util import get_const_tuple, traverse_inline
 
 @autotvm.task.register_topi_compute("conv1d_transpose_nchw.cuda")
-def conv1d_transpose_ncw(cfg, data, kernel, stride, padding, out_dtype):
+def conv1d_transpose_ncw(cfg, data, kernel, stride, padding, out_dtype,
+                         output_padding):
     """Transposed 1D convolution ncw forward operator.
 
     Parameters
@@ -43,6 +44,8 @@ def conv1d_transpose_ncw(cfg, data, kernel, stride, padding, out_dtype):
         string: ['VALID', 'SAME']
     out_dtype: str
         The output type. This is used in mixed precision
+    output_padding : ints
+        Used to disambiguate the output shape.
 
     Returns
     -------
@@ -51,13 +54,17 @@ def conv1d_transpose_ncw(cfg, data, kernel, stride, padding, out_dtype):
     """
     if isinstance(stride, (tuple, list)):
         stride = stride[0]
+    if isinstance(output_padding, (tuple, list)):
+        output_padding = output_padding[0]
+    assert output_padding < stride
     cfg.stride = stride
+    cfg.output_padding = output_padding
     batch, inp_channels, inp_width = get_const_tuple(data.shape)
     _, out_channels, kernel_size = get_const_tuple(kernel.shape)
     pad_left, pad_right = nn.get_pad_tuple1d(padding, kernel_size)
-    out_width = (inp_width - 1) * stride + kernel_size - pad_left - pad_right
+    out_width = (inp_width - 1) * stride + kernel_size - pad_left - pad_right + output_padding
     pad_left = kernel_size - 1 - pad_left
-    pad_right = kernel_size - 1 - pad_right
+    pad_right = kernel_size - 1 - pad_right + output_padding
     dilated_width = stride * (inp_width - 1) + 1
     data = te.compute(
         (batch, inp_channels, pad_left + dilated_width + pad_right),
diff --git a/topi/python/topi/cuda/conv2d_transpose_nchw.py b/topi/python/topi/cuda/conv2d_transpose_nchw.py
index 17bd37d..5ad4947 100644
--- a/topi/python/topi/cuda/conv2d_transpose_nchw.py
+++ b/topi/python/topi/cuda/conv2d_transpose_nchw.py
@@ -25,8 +25,10 @@ from .. import nn
 from ..util import get_const_tuple, traverse_inline
 
 
+
 @autotvm.register_topi_compute("conv2d_transpose_nchw.cuda")
-def conv2d_transpose_nchw(cfg, data, kernel, stride, padding, out_dtype):
+def conv2d_transpose_nchw(cfg, data, kernel, stride, padding, out_dtype,
+                          output_padding):
     """Transposed 2D convolution nchw forward operator.
 
     Parameters
@@ -43,6 +45,8 @@ def conv2d_transpose_nchw(cfg, data, kernel, stride, padding, out_dtype):
         Padding size, or ['VALID', 'SAME']
     out_dtype: str
         The output type. This is used in mixed precision
+    output_padding : tuple of two ints
+        Used to disambiguate output shape.
 
     Returns
     -------
@@ -52,18 +56,20 @@ def conv2d_transpose_nchw(cfg, data, kernel, stride, padding, out_dtype):
     batch, inp_channels, inp_height, inp_width = get_const_tuple(data.shape)
     _, out_channels, kernel_height, kernel_width = get_const_tuple(kernel.shape)
     stride_height, stride_width = stride
+    outpad_height, outpad_width = output_padding
+    assert outpad_height < stride_height and outpad_width < stride_width
     cfg.stride = stride
     pad_top, pad_left, pad_bottom, pad_right = nn.get_pad_tuple(
         padding, (kernel_height, kernel_width))
 
     out_width = (inp_width - 1) * stride_width + \
-        kernel_width - pad_left - pad_right
+        kernel_width - pad_left - pad_right + outpad_width
     pad_left = kernel_width - 1 - pad_left
     pad_right = kernel_width - 1 - pad_right
     dilated_width = stride_width * (inp_width - 1) + 1
 
     out_height = (inp_height - 1) * stride_height + \
-        kernel_height - pad_top - pad_bottom
+        kernel_height - pad_top - pad_bottom + outpad_height
     pad_top = kernel_height - 1 - pad_top
     pad_bottom = kernel_height - 1 - pad_bottom
     dilated_height = stride_height * (inp_height - 1) + 1
diff --git a/topi/python/topi/nn/conv1d_transpose.py b/topi/python/topi/nn/conv1d_transpose.py
index 1895b1f..b5b55d2 100644
--- a/topi/python/topi/nn/conv1d_transpose.py
+++ b/topi/python/topi/nn/conv1d_transpose.py
@@ -23,7 +23,8 @@ from ..util import simplify
 from .util import get_pad_tuple1d
 
 
-def conv1d_transpose_ncw(data, kernel, stride, padding, out_dtype):
+def conv1d_transpose_ncw(data, kernel, stride, padding, out_dtype,
+                         output_padding):
     """Transposed 1D convolution ncw forward operator.
 
     Parameters
@@ -43,22 +44,30 @@ def conv1d_transpose_ncw(data, kernel, stride, padding, out_dtype):
     out_dtype : str
         The output data type. This is used for mixed precision.
 
+    output_padding : ints
+        Used to recover the actual output shape in case there are more
+        than one possible shape.  Must be smaller than stride.
+
     Returns
     -------
     output : tvm.te.Tensor
         3-D with shape [batch, out_channel, out_width]
+
     """
 
     # dilate and pad
     if isinstance(stride, (tuple, list)):
         stride = stride[0]
+    if isinstance(output_padding, (tuple, list)):
+        output_padding = output_padding[0]
     batch, channels_in, data_width = data.shape
     _, channels_out, kernel_width = kernel.shape
+    assert output_padding < stride
     channels_out = simplify(channels_out)
     data = dilate(data, [1, 1, stride], name='data_dilate')
     pad_left, pad_right = get_pad_tuple1d(padding, (kernel_width,))
     pad_left = kernel_width - 1 - pad_left
-    pad_right = kernel_width - 1 - pad_right
+    pad_right = kernel_width - 1 - pad_right + output_padding
     data = pad(data, [0, 0, pad_left], [0, 0, pad_right], name='data_pad')
 
     # transpose kernel, switch kernel layout to IOW
diff --git a/topi/python/topi/nn/conv2d_transpose.py b/topi/python/topi/nn/conv2d_transpose.py
index 3563112..1fe981d 100644
--- a/topi/python/topi/nn/conv2d_transpose.py
+++ b/topi/python/topi/nn/conv2d_transpose.py
@@ -25,7 +25,9 @@ from .util import get_pad_tuple
 from ..util import simplify
 
 
-def conv2d_transpose_nchw(Input, Filter, strides, padding, out_dtype):
+
+def conv2d_transpose_nchw(Input, Filter, strides, padding, out_dtype,
+                          output_padding):
     """Transposed 2D convolution nchw forward operator.
 
     Parameters
@@ -45,28 +47,34 @@ def conv2d_transpose_nchw(Input, Filter, strides, padding, out_dtype):
     out_dtype : str
         The output data type. This is used for mixed precision.
 
+    output_padding : tuple of ints
+        Used to get the right output shape for gradients
+
     Returns
     -------
     Output : tvm.te.Tensor
         4-D with shape [batch, out_channel, out_height, out_width]
     """
-    return declaration_conv2d_transpose_impl(Input, Filter, strides, padding, out_dtype)
+    return declaration_conv2d_transpose_impl(Input, Filter, strides, padding, out_dtype,
+                                             output_padding=output_padding)
 
 
-def conv2d_transpose_nchw_preprocess(data, kernel, strides, padding, out_dtype):
+def conv2d_transpose_nchw_preprocess(data, kernel, strides, padding, out_dtype, output_padding):
     """Preprocess data and kernel to make the compute pattern
        of conv2d_transpose the same as conv2d"""
     batch, in_c, in_h, in_w = data.shape
     _, out_c, filter_h, filter_w = kernel.shape
     stride_h, stride_w = strides
+    opad_h, opad_w = output_padding
+    assert opad_h < stride_h and opad_w < stride_w
     # dilate data
     data_dilate = dilate(data, [1, 1, stride_h, stride_w], name='data_dilate')
     # pad data
     fpad_top, fpad_left, fpad_bottom, fpad_right = get_pad_tuple(padding, (filter_h, filter_w))
     bpad_top = filter_h - 1 - fpad_top
-    bpad_bottom = filter_h - 1 - fpad_bottom
+    bpad_bottom = filter_h - 1 - fpad_bottom + opad_h
     bpad_left = filter_w - 1 - fpad_left
-    bpad_right = filter_w - 1 - fpad_right
+    bpad_right = filter_w - 1 - fpad_right + opad_w
     data_pad = pad(data_dilate, \
                    [0, 0, bpad_top, bpad_left], \
                    [0, 0, bpad_bottom, bpad_right], \
@@ -78,21 +86,21 @@ def conv2d_transpose_nchw_preprocess(data, kernel, strides, padding, out_dtype):
     return data_pad, kernel_transform
 
 
-def declaration_conv2d_transpose_impl(data, kernel, strides, padding, out_dtype):
+def declaration_conv2d_transpose_impl(data, kernel, strides, padding, out_dtype, output_padding):
     """Implementation of conv2d transpose"""
     data_pad, kernel_transform = \
-        conv2d_transpose_nchw_preprocess(data, kernel, strides, padding, out_dtype)
+        conv2d_transpose_nchw_preprocess(data, kernel, strides, padding, out_dtype, output_padding)
     batch, in_c, in_h, in_w = data_pad.shape
     out_c, _, filter_h, filter_w = kernel_transform.shape
-    stride_h, stride_w = strides
 
     # convolution stage
     out_c = simplify(out_c)
-    out_h = simplify(in_h - filter_h + 1)
-    out_w = simplify(in_w - filter_w + 1)
-    dc = te.reduce_axis((0, in_c), name='dc')
-    dh = te.reduce_axis((0, filter_h), name='dh')
-    dw = te.reduce_axis((0, filter_w), name='dw')
+
+    out_h = simplify(in_h - filter_h + 1 + output_padding[0])
+    out_w = simplify(in_w - filter_w + 1 + output_padding[1])
+    dc = tvm.reduce_axis((0, in_c), name='dc')
+    dh = tvm.reduce_axis((0, filter_h), name='dh')
+    dw = tvm.reduce_axis((0, filter_w), name='dw')
 
     Output = te.compute(
         (batch, out_c, out_h, out_w),
diff --git a/topi/python/topi/testing/conv1d_transpose_ncw_python.py b/topi/python/topi/testing/conv1d_transpose_ncw_python.py
index cb78bbf..b472f33 100644
--- a/topi/python/topi/testing/conv1d_transpose_ncw_python.py
+++ b/topi/python/topi/testing/conv1d_transpose_ncw_python.py
@@ -21,7 +21,7 @@ import scipy
 import topi
 from topi.nn.util import get_pad_tuple1d
 
-def conv1d_transpose_ncw_python(a_np, w_np, stride, padding):
+def conv1d_transpose_ncw_python(a_np, w_np, stride, padding, output_padding):
     """Transposed 1D convolution operator in NCW layout.
 
     Parameters
@@ -40,27 +40,34 @@ def conv1d_transpose_ncw_python(a_np, w_np, stride, padding):
         tuple of 2 ints for left and right padding, or
         ['VALID', 'SAME']
 
+    output_padding : tuple
+        Used to recover the actual output shape in case more than one
+        is possible
+
     Returns
     -------
     b_np : np.ndarray
         3-D with shape [batch, out_channel, out_width]
+
     """
     batch, in_c, in_w = a_np.shape
     _, out_c, filter_w = w_np.shape
+    opad = output_padding[0]
     if isinstance(stride, int):
         stride_w = stride
     else:
         stride_w = stride[0]
+    assert opad < stride_w
     fpad_left, fpad_right = get_pad_tuple1d(padding, filter_w)
     # dilate stage
     dilated_a_np = topi.testing.dilate_python(a_np, [1, 1, stride_w])
     # padding stage
     bpad_left = filter_w - 1 - fpad_left
-    bpad_right = filter_w - 1 - fpad_right
+    bpad_right = filter_w - 1 - fpad_right + opad
     padded_a_np = np.zeros((batch, in_c, dilated_a_np.shape[2]+bpad_left+bpad_right))
     padded_a_np[:, :, bpad_left:dilated_a_np.shape[2]+bpad_left] = dilated_a_np
     # convolution stage
-    out_w = (in_w - 1) * stride_w - fpad_left - fpad_right + filter_w
+    out_w = (in_w - 1) * stride_w - fpad_left - fpad_right + filter_w + opad
     b_np = np.zeros((batch, out_c, out_w))
     for n in range(batch):
         for f in range(out_c):
diff --git a/topi/python/topi/testing/conv2d_transpose_python.py b/topi/python/topi/testing/conv2d_transpose_python.py
index c789fec..83e9287 100644
--- a/topi/python/topi/testing/conv2d_transpose_python.py
+++ b/topi/python/topi/testing/conv2d_transpose_python.py
@@ -22,7 +22,7 @@ import topi
 from topi.nn.util import get_pad_tuple
 
 
-def conv2d_transpose_nchw_python(a_np, w_np, stride, padding):
+def conv2d_transpose_nchw_python(a_np, w_np, stride, padding, output_padding):
     """Transposed convolution operator in NCHW layout.
 
     Parameters
@@ -39,6 +39,9 @@ def conv2d_transpose_nchw_python(a_np, w_np, stride, padding):
     padding : int or str
         Padding size, or ['VALID', 'SAME']
 
+    output_padding : int or a list/tuple of two ints
+        Use to disambiguate the output shape.
+
     Returns
     -------
     b_np : np.ndarray
@@ -50,21 +53,26 @@ def conv2d_transpose_nchw_python(a_np, w_np, stride, padding):
         stride_h = stride_w = stride
     else:
         stride_h, stride_w = stride
+    if isinstance(output_padding, int):
+        opad_h = opad_w = output_padding
+    else:
+        opad_h, opad_w = output_padding
+    assert opad_h < stride_h and opad_w < stride_w
     # dilate stage
     dilated_a_np = topi.testing.dilate_python(a_np, [1, 1, stride_h, stride_w])
     # padding stage
     fpad_top, fpad_left, fpad_bottom, fpad_right = get_pad_tuple(padding, (filter_h, filter_w))
     bpad_top = filter_h - 1 - fpad_top
-    bpad_bottom = filter_h - 1 - fpad_bottom
+    bpad_bottom = filter_h - 1 - fpad_bottom + opad_h
     bpad_left = filter_w - 1 - fpad_left
-    bpad_right = filter_w - 1 - fpad_right
+    bpad_right = filter_w - 1 - fpad_right + opad_w
     padded_a_np = np.zeros((batch, in_c, dilated_a_np.shape[2]+bpad_top+bpad_bottom, \
                             dilated_a_np.shape[3]+bpad_left+bpad_right))
     padded_a_np[:, :, bpad_top:dilated_a_np.shape[2]+bpad_top, \
                 bpad_left:dilated_a_np.shape[3]+bpad_left] = dilated_a_np
     # convolution stage
-    out_h = (in_h - 1) * stride_h - fpad_top - fpad_bottom + filter_h
-    out_w = (in_w - 1) * stride_w - fpad_left - fpad_right + filter_w
+    out_h = (in_h - 1) * stride_h - fpad_top - fpad_bottom + filter_h + opad_h
+    out_w = (in_w - 1) * stride_w - fpad_left - fpad_right + filter_w + opad_w
     b_np = np.zeros((batch, out_c, out_h, out_w))
     for n in range(batch):
         for f in range(out_c):
@@ -75,7 +83,8 @@ def conv2d_transpose_nchw_python(a_np, w_np, stride, padding):
     return b_np
 
 
-def conv2d_transpose_nhwc_python(a_nhwc, weight, weight_format, stride, padding):
+def conv2d_transpose_nhwc_python(a_nhwc, weight, weight_format, stride, padding,
+                                 output_padding=(0, 0)):
     """Transposed convolution operator in NHWC layout.
 
     Parameters
@@ -117,6 +126,7 @@ def conv2d_transpose_nhwc_python(a_nhwc, weight, weight_format, stride, padding)
     else:
         raise ValueError('Valid weight_formats are HWIO, HWOI, OIHW or IOHW')
 
-    res_nchw = conv2d_transpose_nchw_python(a_nchw, w_iohw, stride, padding)
+    res_nchw = conv2d_transpose_nchw_python(a_nchw, w_iohw, stride, padding,
+                                            output_padding=output_padding)
     res_nhwc = np.transpose(res_nchw, (0, 2, 3, 1))
     return res_nhwc
diff --git a/topi/python/topi/x86/conv2d_transpose.py b/topi/python/topi/x86/conv2d_transpose.py
index f90edb5..d490b28 100644
--- a/topi/python/topi/x86/conv2d_transpose.py
+++ b/topi/python/topi/x86/conv2d_transpose.py
@@ -21,13 +21,16 @@ from ..util import traverse_inline
 from .. import nn
 from .conv2d import conv2d_nchw, schedule_conv2d_nchw
 
-def conv2d_transpose_nchw(data, kernel, strides, padding, out_dtype):
+
+def conv2d_transpose_nchw(data, kernel, strides, padding, out_dtype, output_padding):
     data_pad, kernel_transform = \
-        nn.conv2d_transpose_nchw_preprocess(data, kernel, strides, padding, out_dtype)
+        nn.conv2d_transpose_nchw_preprocess(data, kernel, strides, padding,
+                                            out_dtype, output_padding)
     # reuse conv2d_nchw implementation
     return conv2d_nchw(data_pad, kernel_transform, strides=(1, 1),
                        padding=(0, 0), dilation=(1, 1), out_dtype=out_dtype)
 
+
 def schedule_conv2d_transpose_nchw(outs):
     """Create schedule for tensors"""
     outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs
diff --git a/topi/tests/python/test_topi_conv1d_transpose_ncw.py b/topi/tests/python/test_topi_conv1d_transpose_ncw.py
index 4d015bf..0cecbef 100644
--- a/topi/tests/python/test_topi_conv1d_transpose_ncw.py
+++ b/topi/tests/python/test_topi_conv1d_transpose_ncw.py
@@ -30,7 +30,7 @@ _conv1d_transpose_ncw_implement = {
     "gpu": (topi.cuda.conv1d_transpose_ncw, topi.cuda.schedule_conv1d_transpose_ncw)
 }
 
-def verify_conv1d_transpose_ncw(batch, in_channel, in_size, num_filter, kernel, stride, padding):
+def verify_conv1d_transpose_ncw(batch, in_channel, in_size, num_filter, kernel, stride, padding, output_padding):
     in_width = in_size
     A = te.placeholder((batch, in_channel, in_width), name='A')
     W = te.placeholder((in_channel, num_filter, kernel), name='W')
@@ -43,7 +43,7 @@ def verify_conv1d_transpose_ncw(batch, in_channel, in_size, num_filter, kernel,
     def get_ref_data():
         a_np = np.random.uniform(size=a_shape).astype(dtype)
         w_np = np.random.uniform(size=w_shape).astype(dtype)
-        b_np = topi.testing.conv1d_transpose_ncw_python(a_np, w_np, stride, padding)
+        b_np = topi.testing.conv1d_transpose_ncw_python(a_np, w_np, stride, padding, output_padding)
         c_np = np.maximum(b_np, 0)
         return a_np, w_np, b_np, c_np
 
@@ -56,7 +56,7 @@ def verify_conv1d_transpose_ncw(batch, in_channel, in_size, num_filter, kernel,
             return
         with tvm.target.create(device):
             fcompute, fschedule = topi.testing.dispatch(device, _conv1d_transpose_ncw_implement)
-            B = fcompute(A, W, stride, padding, A.dtype)
+            B = fcompute(A, W, stride, padding, A.dtype, output_padding)
             C = topi.nn.relu(B)
             s1 = fschedule([B])
             s2 = fschedule([C])
@@ -77,18 +77,20 @@ def verify_conv1d_transpose_ncw(batch, in_channel, in_size, num_filter, kernel,
 
 
 def test_conv1d_transpose_ncw():
-    verify_conv1d_transpose_ncw(1, 3, 224, 32, 5, 1, 0)
-    verify_conv1d_transpose_ncw(1, 3, 224, 32, 7, 1, 2)
-    verify_conv1d_transpose_ncw(1, 3, 224, 32, 5, 2, 1)
-    verify_conv1d_transpose_ncw(1, 3, 224, 32, 5, 2, 0)
-    verify_conv1d_transpose_ncw(1, 32, 32, 128, 5, 1, 0)
-    verify_conv1d_transpose_ncw(1, 32, 32, 128, 5, 2, 1)
-    verify_conv1d_transpose_ncw(1, 1, 1024, 1, 512, 1, 256)
-    verify_conv1d_transpose_ncw(1, 1, 1024, 1, 512, 2, 256)
-    verify_conv1d_transpose_ncw(1, 1, 1024, 1, 512, 5, 256)
-    verify_conv1d_transpose_ncw(1, 1, 10, 1, 5, 1, (0,3))
-    verify_conv1d_transpose_ncw(1, 1, 10, 1, 5, 1, (1,3))
-    verify_conv1d_transpose_ncw(1, 1, 10, 1, 5, 1, (2,3))
+    verify_conv1d_transpose_ncw(1, 3, 224, 32, 5, 1, 0, (0,))
+    verify_conv1d_transpose_ncw(1, 3, 224, 32, 7, 1, 2, (0,))
+    verify_conv1d_transpose_ncw(1, 3, 224, 32, 5, 2, 1, (0,))
+    verify_conv1d_transpose_ncw(1, 3, 224, 32, 5, 2, 1, (1,))
+    verify_conv1d_transpose_ncw(1, 3, 224, 32, 5, 2, 0, (0,))
+    verify_conv1d_transpose_ncw(1, 32, 32, 128, 5, 1, 0, (0,))
+    verify_conv1d_transpose_ncw(1, 32, 32, 128, 5, 2, 1, (0,))
+    verify_conv1d_transpose_ncw(1, 1, 1024, 1, 512, 1, 256, (0,))
+    verify_conv1d_transpose_ncw(1, 1, 1024, 1, 512, 2, 256, (0,))
+    verify_conv1d_transpose_ncw(1, 1, 1024, 1, 512, 5, 256, (0,))
+    verify_conv1d_transpose_ncw(1, 1, 1024, 1, 512, 5, 256, (3,))
+    verify_conv1d_transpose_ncw(1, 1, 10, 1, 5, 1, (0,3), (0,))
+    verify_conv1d_transpose_ncw(1, 1, 10, 1, 5, 1, (1,3), (0,))
+    verify_conv1d_transpose_ncw(1, 1, 10, 1, 5, 1, (2,3), (0,))
 
 if __name__ == "__main__":
     test_conv1d_transpose_ncw()
diff --git a/topi/tests/python/test_topi_conv2d_transpose_nchw.py b/topi/tests/python/test_topi_conv2d_transpose_nchw.py
index e8e1fce..11f1903 100644
--- a/topi/tests/python/test_topi_conv2d_transpose_nchw.py
+++ b/topi/tests/python/test_topi_conv2d_transpose_nchw.py
@@ -25,6 +25,7 @@ from topi.util import get_const_tuple
 
 from common import get_all_backend
 
+
 _conv2d_transpose_nchw_implement = {
     "generic": (topi.nn.conv2d_transpose_nchw, topi.generic.schedule_conv2d_transpose_nchw),
     "cpu": (topi.x86.conv2d_transpose_nchw, topi.x86.schedule_conv2d_transpose_nchw),
@@ -33,7 +34,7 @@ _conv2d_transpose_nchw_implement = {
     "hls": (topi.nn.conv2d_transpose_nchw, topi.hls.schedule_conv2d_transpose_nchw),
 }
 
-def verify_conv2d_transpose_nchw(batch, in_channel, in_size, num_filter, kernel, stride, padding):
+def verify_conv2d_transpose_nchw(batch, in_channel, in_size, num_filter, kernel, stride, padding, output_padding):
     in_height, in_width = in_size
     kernel_height, kernel_width = kernel
     stride_height, stride_width = stride
@@ -50,7 +51,7 @@ def verify_conv2d_transpose_nchw(batch, in_channel, in_size, num_filter, kernel,
     def get_ref_data():
         a_np = np.random.uniform(size=a_shape).astype(dtype)
         w_np = np.random.uniform(size=w_shape).astype(dtype)
-        b_np = topi.testing.conv2d_transpose_nchw_python(a_np, w_np, stride, padding)
+        b_np = topi.testing.conv2d_transpose_nchw_python(a_np, w_np, stride, padding, output_padding)
         c_np = np.maximum(b_np, 0)
         return a_np, w_np, b_np, c_np
 
@@ -67,7 +68,7 @@ def verify_conv2d_transpose_nchw(batch, in_channel, in_size, num_filter, kernel,
             B = fcompute(A, W,
                          [stride_height, stride_width],
                          [pad_top, pad_left, pad_bottom, pad_right],
-                         A.dtype)
+                         A.dtype, output_padding)
             C = topi.nn.relu(B)
             s1 = fschedule([B])
             s2 = fschedule([C])
@@ -87,16 +88,20 @@ def verify_conv2d_transpose_nchw(batch, in_channel, in_size, num_filter, kernel,
 
 
 def test_conv2d_transpose_nchw():
-    verify_conv2d_transpose_nchw(1, 3, (224, 224),  1, (1, 1), (1, 1), (0, 0, 0, 0))
-    verify_conv2d_transpose_nchw(1, 3, (224, 224),  32, (3, 3), (1, 1), (0, 0, 0, 0))
-    verify_conv2d_transpose_nchw(1, 3, (224, 224),  32, (3, 3), (3, 3), (0, 0, 0, 0))
-    verify_conv2d_transpose_nchw(1, 3, (224, 224),  32, (3, 3), (1, 1), (0, 0, 0, 0))
-    verify_conv2d_transpose_nchw(1, 3, (224, 224),  32, (3, 3), (2, 2), (1, 1, 1, 1))
-    verify_conv2d_transpose_nchw(1, 3, (224, 224),  32, (2, 2), (2, 2), (0, 0, 0, 0))
-    verify_conv2d_transpose_nchw(1, 32, (32, 32), 128, (5, 5), (1, 1), (0, 0, 0, 0))
-    verify_conv2d_transpose_nchw(1, 32, (32, 32), 128, (5, 5), (2, 2), (1, 1, 1, 1))
-    verify_conv2d_transpose_nchw(16, 32, (8192, 1), 8, (31, 1), (2, 1), (14, 0, 15, 0))
-    verify_conv2d_transpose_nchw(16, 512, (8, 1), 128, (31, 1), (2, 1), (14, 0, 15, 0))
+    verify_conv2d_transpose_nchw(1, 3, (224, 224),  1, (1, 1), (1, 1), (0, 0, 0, 0), (0, 0))
+    verify_conv2d_transpose_nchw(1, 3, (224, 224),  32, (3, 3), (1, 1), (0, 0, 0, 0), (0, 0))
+    verify_conv2d_transpose_nchw(1, 3, (224, 224),  32, (3, 3), (3, 3), (0, 0, 0, 0), (0, 0))
+    verify_conv2d_transpose_nchw(1, 3, (224, 224),  32, (3, 3), (1, 1), (0, 0, 0, 0), (0, 0))
+    verify_conv2d_transpose_nchw(1, 3, (224, 224),  32, (3, 3), (2, 2), (1, 1, 1, 1), (0, 0))
+    verify_conv2d_transpose_nchw(1, 3, (224, 224),  32, (3, 3), (2, 2), (1, 1, 1, 1), (1, 0))
+    verify_conv2d_transpose_nchw(1, 3, (224, 224),  32, (2, 2), (2, 2), (0, 0, 0, 0), (0, 0))
+    verify_conv2d_transpose_nchw(1, 3, (224, 224),  32, (2, 2), (2, 2), (0, 0, 0, 0), (1, 1))
+    verify_conv2d_transpose_nchw(1, 32, (32, 32), 128, (5, 5), (1, 1), (0, 0, 0, 0), (0, 0))
+    verify_conv2d_transpose_nchw(1, 32, (32, 32), 128, (5, 5), (2, 2), (1, 1, 1, 1), (0, 0))
+    verify_conv2d_transpose_nchw(16, 32, (8192, 1), 8, (31, 1), (2, 1), (14, 0, 15, 0), (0, 0))
+    verify_conv2d_transpose_nchw(16, 512, (8, 1), 128, (31, 1), (2, 1), (14, 0, 15, 0), (0, 0))
+    verify_conv2d_transpose_nchw(16, 512, (8, 1), 128, (31, 1), (2, 1), (14, 0, 15, 0), (1, 0))
+
 
 if __name__ == "__main__":
     test_conv2d_transpose_nchw()
diff --git a/vta/python/vta/top/vta_conv2d_transpose.py b/vta/python/vta/top/vta_conv2d_transpose.py
index 4f213f6..ddfebc2 100644
--- a/vta/python/vta/top/vta_conv2d_transpose.py
+++ b/vta/python/vta/top/vta_conv2d_transpose.py
@@ -28,20 +28,24 @@ from topi.nn.util import get_pad_tuple
 from ..environment import get_env
 
 @autotvm.register_topi_compute("conv2d_transpose_packed.vta")
-def conv2d_transpose_packed(cfg, data, kernel, strides, padding, out_dtype):
+def conv2d_transpose_packed(cfg, data, kernel, strides, padding, out_dtype,
+                            output_padding=(0, 0)):
     """Packed conv2d_transpose compute"""
     ishape = get_const_tuple(data.shape)
     kshape = get_const_tuple(kernel.shape)
     b, c_i, i_h, i_w, t_b, t_ci = ishape
     c_o, _, k_h, k_w, t_co, t_ci = kshape
     stride_h, stride_w = strides
+    opad_h, opad_w = output_padding
+    # FIXME(tmoreau89): currently IR pass breaks when output padding != (0,0)
+    assert opad_h == 0 and opad_w == 0, "VTA does not support output padding for now"
 
     # derive padding parameters
     fpad_top, fpad_left, fpad_bottom, fpad_right = get_pad_tuple(padding, (k_h, k_w))
     bpad_top = k_h - 1 - fpad_top
-    bpad_bottom = k_h - 1 - fpad_bottom
+    bpad_bottom = k_h - 1 - fpad_bottom + opad_h
     bpad_left = k_w - 1 - fpad_left
-    bpad_right = k_w - 1 - fpad_right
+    bpad_right = k_w - 1 - fpad_right + opad_w
 
     # padding stage
     dilated_input = topi.nn.dilate(data, [1, 1, stride_h, stride_w, 1, 1])
@@ -50,8 +54,8 @@ def conv2d_transpose_packed(cfg, data, kernel, strides, padding, out_dtype):
                            [0, 0, bpad_bottom, bpad_right, 0, 0])
 
     # convolution transpose stage
-    out_h = (i_h - 1) * stride_h - fpad_top - fpad_bottom + k_h
-    out_w = (i_w - 1) * stride_w - fpad_left - fpad_right + k_w
+    out_h = (i_h - 1) * stride_h - fpad_top - fpad_bottom + k_h + opad_h
+    out_w = (i_w - 1) * stride_w - fpad_left - fpad_right + k_w + opad_w
     oshape = (b, c_o, out_h, out_w, t_b, t_co)
     d_c = te.reduce_axis((0, c_i), name='d_c')
     d_h = te.reduce_axis((0, k_h), name='d_h')
diff --git a/vta/scripts/tune_conv2d_transpose.py b/vta/scripts/tune_conv2d_transpose.py
index 0871367..b7c380e 100644
--- a/vta/scripts/tune_conv2d_transpose.py
+++ b/vta/scripts/tune_conv2d_transpose.py
@@ -33,13 +33,15 @@ env = vta.get_env()
 
 Workload = namedtuple("Conv2DTransposeWorkload",
                       ['batch', 'height', 'width', 'in_filter', 'out_filter',
-                       'hkernel', 'wkernel', 'hpad', 'wpad', 'hstride', 'wstride'])
+                       'hkernel', 'wkernel', 'hpad', 'wpad', 'hstride', 'wstride',
+                       'o_hpad', 'o_wpad'])
 
+# DCGAN workloads
 dcgan_wkls = [
     # dcgan
-    ('DCGAN.CT1', Workload(env.BATCH,  4,  4, 1024, 512, 4, 4, 1, 1, 2, 2)),
-    ('DCGAN.CT2', Workload(env.BATCH,  8,  8,  512, 256, 4, 4, 1, 1, 2, 2)),
-    ('DCGAN.CT3', Workload(env.BATCH, 16, 16,  256, 128, 4, 4, 1, 1, 2, 2)),
+    ('DCGAN.CT1', Workload(env.BATCH,  4,  4, 1024, 512, 4, 4, 1, 1, 2, 2, 0, 0)),
+    ('DCGAN.CT2', Workload(env.BATCH,  8,  8,  512, 256, 4, 4, 1, 1, 2, 2, 0, 0)),
+    ('DCGAN.CT3', Workload(env.BATCH, 16, 16,  256, 128, 4, 4, 1, 1, 2, 2, 0, 0)),
 ]
 
 @tvm.te.tag_scope(tag=topi.tag.ELEMWISE)
@@ -51,7 +53,7 @@ def my_clip(x, a_min, a_max):
     x = te.compute(x.shape, lambda *i: tvm.te.max(x(*i), const_min), name="clipB")
     return x
 
-def conv2d_transpose(N, CI, H, W, CO, KH, KW, strides, padding):
+def conv2d_transpose(N, CI, H, W, CO, KH, KW, strides, padding, opadding):
     data_shape = (N//env.BATCH, CI//env.BLOCK_IN, H, W, env.BATCH, env.BLOCK_IN)
     kernel_shape = (CO//env.BLOCK_OUT, CI//env.BLOCK_IN, KH, KW, env.BLOCK_OUT, env.BLOCK_IN)
 
@@ -64,7 +66,9 @@ def conv2d_transpose(N, CI, H, W, CO, KH, KW, strides, padding):
             Filter=kernel,
             strides=strides,
             padding=padding,
-            out_dtype=env.acc_dtype)
+            out_dtype=env.acc_dtype,
+            output_padding=opadding
+        )
         res = topi.right_shift(res, env.WGT_WIDTH)
         res = my_clip(res, 0, (1 << env.OUT_WIDTH - 1) - 1)
         res = topi.cast(res, env.out_dtype)
@@ -109,11 +113,12 @@ if __name__ == '__main__':
         KW = wl.wkernel
         strides = (wl.hstride, wl.wstride)
         padding = (wl.hpad, wl.wpad)
+        opadding = (wl.o_hpad, wl.o_wpad)
 
         # Create task
         task = autotvm.task.create(
                 conv2d_transpose,
-                args=(N, CI, H, W, CO, KH, KW, strides, padding),
+                args=(N, CI, H, W, CO, KH, KW, strides, padding, opadding),
                 target=tvm.target.vta(),
                 target_host=env.target_host,
                 template_key='direct')
diff --git a/vta/tests/python/integration/test_benchmark_topi_conv2d_transpose.py b/vta/tests/python/integration/test_benchmark_topi_conv2d_transpose.py
index 90cc21f..558c3ab 100644
--- a/vta/tests/python/integration/test_benchmark_topi_conv2d_transpose.py
+++ b/vta/tests/python/integration/test_benchmark_topi_conv2d_transpose.py
@@ -40,7 +40,8 @@ from vta.testing import simulator
 
 Workload = namedtuple("Conv2DTransposeWorkload",
                       ['batch', 'height', 'width', 'in_filter', 'out_filter',
-                       'hkernel', 'wkernel', 'hpad', 'wpad', 'hstride', 'wstride'])
+                       'hkernel', 'wkernel', 'hpad', 'wpad', 'hstride', 'wstride',
+                       'o_hpad', 'o_wpad'])
 
 # Get batch info from env
 env = vta.get_env()
@@ -48,9 +49,9 @@ env = vta.get_env()
 # DCGAN workloads
 dcgan_wklds = [
     # dcgan
-    ('DCGAN.CT1', Workload(env.BATCH,  4,  4, 1024, 512, 4, 4, 1, 1, 2, 2)),
-    ('DCGAN.CT2', Workload(env.BATCH,  8,  8,  512, 256, 4, 4, 1, 1, 2, 2)),
-    ('DCGAN.CT3', Workload(env.BATCH, 16, 16,  256, 128, 4, 4, 1, 1, 2, 2)),
+    ('DCGAN.CT1', Workload(env.BATCH,  4,  4, 1024, 512, 4, 4, 1, 1, 2, 2, 0, 0)),
+    ('DCGAN.CT2', Workload(env.BATCH,  8,  8,  512, 256, 4, 4, 1, 1, 2, 2, 0, 0)),
+    ('DCGAN.CT3', Workload(env.BATCH, 16, 16,  256, 128, 4, 4, 1, 1, 2, 2, 0, 0)),
 ]
 
 # FIXME: we need a custom clip operator to circumvent a pattern detection limitation
@@ -109,8 +110,10 @@ def run_conv2d_transpose(env, remote, wl, target,
 
     # Define base computation schedule
     with target:
+
         res = fcompute(
-            data, kernel, (wl.hstride, wl.wstride), padding, env.acc_dtype)
+            data, kernel, (wl.hstride, wl.wstride), padding, env.acc_dtype,
+            (wl.o_hpad, wl.o_wpad))
         res = topi.right_shift(res, env.WGT_WIDTH)
         res = my_clip(res, 0, (1 << env.OUT_WIDTH - 1) - 1)
         res = topi.cast(res, env.out_dtype)
@@ -120,8 +123,8 @@ def run_conv2d_transpose(env, remote, wl, target,
             print(vta.lower(s, [data, kernel, res], simple_mode=True))
 
     # Derive number of ops
-    fout_height = (wl.height - 1) * wl.hstride - 2 * wl.hpad + wl.hkernel
-    fout_width = (wl.width - 1) * wl.wstride - 2 * wl.wpad + wl.wkernel
+    fout_height = (wl.height - 1) * wl.hstride - 2 * wl.hpad + wl.hkernel + wl.o_hpad
+    fout_width = (wl.width - 1) * wl.wstride - 2 * wl.wpad + wl.wkernel + wl.o_wpad
     num_ops = 2 * wl.batch * fout_height * fout_width * wl.hkernel * wl.wkernel * wl.out_filter * wl.in_filter
 
     # @memoize("vta.tests.test_benchmark_topi.conv2d.verify_nhwc")
@@ -132,7 +135,7 @@ def run_conv2d_transpose(env, remote, wl, target,
         a_np = np.random.randint(a_min, a_max, size=a_shape).astype(data.dtype)
         w_np = np.random.randint(w_min, w_max, size=(wl.in_filter, wl.out_filter, wl.hkernel, wl.wkernel)).astype(kernel.dtype)
         r_np = topi.testing.conv2d_transpose_nchw_python(
-            a_np.astype(env.acc_dtype), w_np.astype(env.acc_dtype), (wl.hstride, wl.wstride), wl.hpad).astype(env.acc_dtype)
+            a_np.astype(env.acc_dtype), w_np.astype(env.acc_dtype), (wl.hstride, wl.wstride), wl.hpad, (wl.o_hpad, wl.o_wpad)).astype(env.acc_dtype)
         return a_np, w_np, r_np
 
     # Data in original format