You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by ma...@apache.org on 2020/01/01 09:43:03 UTC

[incubator-tvm] branch master updated: [FRONTEND][TF] Add conv3d (#4604)

This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git


The following commit(s) were added to refs/heads/master by this push:
     new 1ef1605  [FRONTEND][TF] Add conv3d (#4604)
1ef1605 is described below

commit 1ef1605a378c5a44f645b2f1459a58e0092f8726
Author: optima2005 <56...@users.noreply.github.com>
AuthorDate: Wed Jan 1 17:42:54 2020 +0800

    [FRONTEND][TF] Add conv3d (#4604)
    
    * [FRONTEND][TF] Add conv3d
    
    * fix high rtol
---
 include/tvm/relay/attrs/nn.h                     |   6 +-
 python/tvm/relay/frontend/tensorflow.py          | 133 ++++++++++++++++++++++-
 python/tvm/relay/op/nn/_nn.py                    |   5 +-
 src/relay/op/nn/convolution.cc                   |  19 ++--
 src/relay/op/nn/convolution.h                    |  13 ++-
 src/relay/op/op_common.h                         |  39 +++++++
 tests/python/frontend/tensorflow/test_forward.py |  62 ++++++++++-
 tests/python/relay/test_op_level2.py             |  96 ++++++++++++++++
 topi/python/topi/cuda/conv3d.py                  |  57 ++++++++--
 topi/python/topi/generic/nn.py                   |  16 +++
 topi/python/topi/nn/conv3d.py                    |  72 +++++++++++-
 topi/python/topi/nn/util.py                      |  12 +-
 topi/python/topi/testing/__init__.py             |   1 +
 topi/python/topi/testing/conv3d_ncdhw_python.py  |  37 ++-----
 topi/python/topi/testing/conv3d_ndhwc_python.py  |  82 ++++++++++++++
 topi/tests/python/test_topi_conv3d_ncdhw.py      |  20 +++-
 topi/tests/python/test_topi_conv3d_ndhwc.py      |  79 ++++++++++++++
 17 files changed, 683 insertions(+), 66 deletions(-)

diff --git a/include/tvm/relay/attrs/nn.h b/include/tvm/relay/attrs/nn.h
index d724f81..a2cad94 100644
--- a/include/tvm/relay/attrs/nn.h
+++ b/include/tvm/relay/attrs/nn.h
@@ -212,7 +212,11 @@ struct Conv3DAttrs : public tvm::AttrsNode<Conv3DAttrs> {
         .describe("Specifies the strides of the convolution.");
     TVM_ATTR_FIELD(padding).set_default(Array<IndexExpr>({0, 0, 0}))
         .describe("If padding is non-zero, then the input is implicitly zero-padded"
-                  "on both sides for padding number of points");
+                  "Padding support both symmetric and asymmetric as"
+                  "one int : same padding used on all sides"
+                  "three int : back, bottom, right will use same padding as front, top, left"
+                  "six int : padding width in the order of (front, top, left, back, bottom,"
+                  "right)");
     TVM_ATTR_FIELD(dilation).set_default(Array<IndexExpr>({1, 1, 1}))
         .describe("Specifies the dilation rate to use for dilated convolution.");
     TVM_ATTR_FIELD(groups).set_default(1)
diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py
index f748fe8..db037e4 100644
--- a/python/tvm/relay/frontend/tensorflow.py
+++ b/python/tvm/relay/frontend/tensorflow.py
@@ -66,16 +66,18 @@ def _dimension_picker(prefix, surfix=''):
         kernel = attr['kernel_shape']
         if len(kernel) == 2:
             return prefix + '2d' + surfix
+        if len(kernel) == 3:
+            return prefix + '3d' + surfix
         raise tvm.error.OpAttributeInvalid(
-            'Only 2D kernels are supported for operator {}'.format(prefix + '2d'))
+            'Only 2D or 3D kernels are supported for operator {}'.format(prefix + '2d or 3d'))
     return _impl
 
 def _dimension_constraint():
     def _dim_check(attrs):
-        if len(attrs['kernel_shape']) == 2:
+        if len(attrs['kernel_shape']) in (2, 3):  # 2-D and 3-D kernels are accepted
             return True
         return False
-    return _dim_check, "Only 2d kernel supported."
+    return _dim_check, "Only 2d or 3d kernel supported."  # (predicate, message) pair
 
 def _get_param(params, input_node):
     if isinstance(input_node, _expr.Constant):
@@ -425,6 +427,130 @@ def _conv(opname):
         return out
     return _impl
 
+def _conv3d(opname):
+    """Return a converter turning a TF 3D convolution node into relay conv3d (+ bias_add)."""
+    def _impl(inputs, attr, params):
+        attr['data_format'] = attr['data_format'].decode("utf-8")
+        flip_layout = False
+        inputs_data = inputs[0] if opname != 'conv_transpose' else inputs[2]
+
+        # NCDHW layout requires a weight transpose: TF DHWIO -> relay OIDHW.
+        if attr['data_format'] == 'NCDHW':
+            tmp_shape = attr['_input_shapes'][inputs[1]]
+            tmp_shape = [tmp_shape[ii] for ii in (4, 3, 0, 1, 2)]
+            inputs[1] = _op.transpose(inputs[1], axes=(4, 3, 0, 1, 2))
+            attr['_input_shapes'][inputs[1]] = tmp_shape
+
+        input_shape = attr['_input_shapes'][inputs_data]
+        weights_shape = attr['_input_shapes'][inputs[1]]
+
+        if attr['_target_layout'] == "NCDHW" and attr['data_format'] == "NDHWC":
+            input_shape = [input_shape[ii] for ii in (0, 4, 1, 2, 3)]
+            inputs_data = _op.transpose(inputs_data, axes=(0, 4, 1, 2, 3))
+            weights_shape = [weights_shape[ii] for ii in (4, 3, 0, 1, 2)]
+            inputs[1] = _op.transpose(inputs[1], axes=(4, 3, 0, 1, 2))
+            attr['data_format'] = "NCDHW"
+            attr['strides'] = [attr['strides'][ii] for ii in (0, 4, 1, 2, 3)]
+            if 'dilations' in attr:  # permute like strides, or D/H/W get mixed up
+                attr['dilations'] = [attr['dilations'][ii] for ii in (0, 4, 1, 2, 3)]
+            flip_layout = True
+
+        if attr['data_format'] == 'NDHWC':
+            kernel_d, kernel_h, kernel_w, _, _ = weights_shape
+            attr['kernel_shape'] = (kernel_d, kernel_h, kernel_w)
+            if opname == 'conv':
+                attr['channels'] = weights_shape[4]
+            elif opname == 'conv_transpose':
+                attr['channels'] = weights_shape[3]
+            # Keep only the spatial (D, H, W) entries of the 5-element TF attrs.
+            if 'dilations' in attr:
+                attr['dilations'] = tuple(attr['dilations'][ii] for ii in (1, 2, 3))
+            attr['strides'] = tuple(attr['strides'][ii] for ii in (1, 2, 3))
+        elif attr['data_format'] == 'NCDHW':
+            _, _, kernel_d, kernel_h, kernel_w = weights_shape
+            attr['kernel_shape'] = (kernel_d, kernel_h, kernel_w)
+            if opname == 'conv':
+                attr['channels'] = weights_shape[0]
+            elif opname == 'conv_transpose':
+                attr['channels'] = weights_shape[1]
+            # Keep only the spatial (D, H, W) entries of the 5-element TF attrs.
+            if 'dilations' in attr:
+                attr['dilations'] = tuple(attr['dilations'][ii] for ii in (2, 3, 4))
+            attr['strides'] = tuple(attr['strides'][ii] for ii in (2, 3, 4))
+        else:
+            msg = 'Value {} in attribute "data_format" of operator Conv is ' \
+                  'not valid.'
+            raise tvm.error.OpAttributeInvalid(msg.format(attr['data_format']))
+
+        # Fix padding: 'VALID' means none; 'SAME' becomes explicit per-side pads
+        # ordered (front, top, left, back, bottom, right).
+        attr['padding'] = attr['padding'].decode("utf-8")
+
+        if attr['padding'] == 'VALID':
+            attr['padding'] = [0, 0, 0]
+        elif attr['padding'] == 'SAME':
+            stride_d, stride_h, stride_w = attr['strides']
+            kernel_d, kernel_h, kernel_w = attr['kernel_shape']
+
+            pdata_shape = input_shape
+            if opname == 'conv_transpose' and len(attr['_output_shapes']) > 0:
+                pdata_shape = attr['_output_shapes'][0]
+
+            if attr['data_format'] == 'NDHWC':
+                in_d = pdata_shape[1]
+                in_h = pdata_shape[2]
+                in_w = pdata_shape[3]
+            else:
+                in_d = pdata_shape[2]
+                in_h = pdata_shape[3]
+                in_w = pdata_shape[4]
+
+            # A missing 'dilations' attribute means no dilation.
+            dilation_d, dilation_h, dilation_w = attr.get('dilations', (1, 1, 1))
+            # Effective kernel extent once dilation is applied.
+            dilated_kernel_d = (kernel_d - 1) * dilation_d + 1
+            dilated_kernel_h = (kernel_h - 1) * dilation_h + 1
+            dilated_kernel_w = (kernel_w - 1) * dilation_w + 1
+            pad_d = _get_pad_pair(in_d, dilated_kernel_d, stride_d)
+            pad_v = _get_pad_pair(in_h, dilated_kernel_h, stride_h)
+            pad_h = _get_pad_pair(in_w, dilated_kernel_w, stride_w)
+            # Index [0] is the leading and [1] the trailing pad of each axis.
+            attr['padding'] = [pad_d[0], pad_v[0], pad_h[0], pad_d[1], pad_v[1], pad_h[1]]
+
+        else:
+            msg = 'Value {} in attribute "padding" of operator Conv is not ' \
+                  'valid.'
+            raise tvm.error.OpAttributeInvalid(msg.format(attr['padding']))
+
+        if 'kernel_layout' not in attr:
+            attr['kernel_layout'] = 'DHWIO' if attr['data_format'] == 'NDHWC' else 'OIDHW'
+
+        use_bias = len(inputs) == (3 if opname != 'conv_transpose' else 4)
+        channel_axis = 1 if attr['data_format'] == "NCDHW" else 4  # C is last in NDHWC
+
+        # Ignore the new attributes from TF2.0, for now.
+        out = AttrCvt(
+            op_name=_dimension_picker('conv', \
+                surfix="_transpose" if opname == 'conv_transpose' else ""),
+            ignores=['explicit_paddings'],
+            transforms={
+                'kernel_shape': 'kernel_size',
+                'data_format': 'data_layout',
+                'dilations': ('dilation', (1, 1, 1)),
+                'group': ('groups', 1)},
+            custom_check=_dimension_constraint())([inputs_data, inputs[1]], attr)
+
+        if use_bias:
+            out = _op.nn.bias_add(out,
+                                  inputs[2] if opname != 'conv_transpose' else inputs[3],
+                                  axis=channel_axis)
+        # Restore the graph's original NDHWC layout if the input was transposed above.
+        if flip_layout:
+            out = _op.transpose(out, axes=(0, 2, 3, 4, 1))
+
+        return out
+    return _impl
+
 def _decode_image():
     def _impl(inputs, attr, params):
         # Image decode wrapper: Expecting user to feed decoded input to next layer drop this layer.
@@ -1442,6 +1568,7 @@ _convert_map = {
     'Concat'                            : _concat(),
     'ConcatV2'                          : _concatV2(),
     'Conv2D'                            : _conv('conv'),
+    'Conv3D'                            : _conv3d('conv'),
     'Conv2DBackpropInput'               : _conv('conv_transpose'),
     'CropAndResize'                     : _crop_and_resize(),
     'DecodeJpeg'                        : _decode_image(),
diff --git a/python/tvm/relay/op/nn/_nn.py b/python/tvm/relay/op/nn/_nn.py
index 3223258..452eb27 100644
--- a/python/tvm/relay/op/nn/_nn.py
+++ b/python/tvm/relay/op/nn/_nn.py
@@ -173,6 +173,7 @@ def compute_conv2d(attrs, inputs, out_type, target):
             assert len(weight_shape) == 5
             C, M, _, _, VC = weight_shape
             return C * VC * M
+
     if groups == 1:
         out = topi.nn.conv2d(
             inputs[0], inputs[1], strides, padding,
@@ -330,7 +331,7 @@ def compute_conv3d(attrs, inputs, out_type, target):
     out_dtype = (inputs[0].dtype if out_dtype in ("same", "")
                  else out_dtype)
 
-    assert layout in ["NCDHW"]
+    assert layout in ["NCDHW", "NDHWC"]
     (dilation_d, dilation_h, dilation_w) = dilation
     if dilation_d < 1 or dilation_h < 1 or dilation_w < 1:
         raise ValueError("dilation should be positive value")
@@ -353,6 +354,8 @@ def schedule_conv3d(attrs, outs, target):
     with target:
         if groups == 1 and layout == "NCDHW":
             return topi.generic.schedule_conv3d_ncdhw(outs)
+        elif groups == 1 and layout == "NDHWC":
+            return topi.generic.schedule_conv3d_ndhwc(outs)
 
     raise ValueError("No compatible schedule")
 
diff --git a/src/relay/op/nn/convolution.cc b/src/relay/op/nn/convolution.cc
index 40c2446..e49c9d6 100644
--- a/src/relay/op/nn/convolution.cc
+++ b/src/relay/op/nn/convolution.cc
@@ -38,7 +38,7 @@ namespace relay {
 TVM_REGISTER_NODE_TYPE(Conv2DAttrs);
 
 template<typename T>
-Array<Array<Layout> > Conv2DInferCorrectLayout(
+Array<Array<Layout> > ConvInferCorrectLayout(
     const Attrs& attrs,
     const Array<Layout>& new_in_layouts,
     const Array<Layout>& old_in_layouts,
@@ -105,7 +105,7 @@ with the layer input to produce a tensor of outputs.
 .add_argument("weight", "Tensor", "The weight tensor.")
 .set_support_level(2)
 .add_type_rel("Conv2D", Conv2DRel<Conv2DAttrs>)
-.set_attr<FInferCorrectLayout>("FInferCorrectLayout", Conv2DInferCorrectLayout<Conv2DAttrs>);
+.set_attr<FInferCorrectLayout>("FInferCorrectLayout", ConvInferCorrectLayout<Conv2DAttrs>);
 
 // relay.nn.conv3d
 TVM_REGISTER_NODE_TYPE(Conv3DAttrs);
@@ -163,7 +163,8 @@ with the layer input to produce a tensor of outputs.
 .add_argument("data", "Tensor", "The input tensor.")
 .add_argument("weight", "Tensor", "The weight tensor.")
 .set_support_level(2)
-.add_type_rel("Conv3D", Conv3DRel<Conv3DAttrs>);
+.add_type_rel("Conv3D", Conv3DRel<Conv3DAttrs>)
+.set_attr<FInferCorrectLayout>("FInferCorrectLayout", ConvInferCorrectLayout<Conv3DAttrs>);
 
 
 // relay.nn.conv2d_transpose
@@ -337,7 +338,7 @@ v            (batch_size, channels, out_height, out_width) if `layout` is `NCHW`
 .add_argument("weight", "Tensor", "The weight tensor.")
 .set_support_level(2)
 .set_attr<FInferCorrectLayout>("FInferCorrectLayout",
-                               Conv2DInferCorrectLayout<Conv2DTransposeAttrs>)
+                               ConvInferCorrectLayout<Conv2DTransposeAttrs>)
 .add_type_rel("Conv2DTranspose", Conv2DTransposeRel);
 
 
@@ -635,7 +636,7 @@ RELAY_REGISTER_OP("nn.contrib_conv2d_winograd_without_weight_transform")
 .set_support_level(10)
 .add_type_rel("Conv2DWinograd", Conv2DWinogradRel<Conv2DWinogradAttrs>)
 .set_attr<FInferCorrectLayout>("FInferCorrectLayout",
-        Conv2DInferCorrectLayout<Conv2DWinogradAttrs>);
+        ConvInferCorrectLayout<Conv2DWinogradAttrs>);
 
 // relay.nn.contrib_conv2d_winograd_weight_transform
 TVM_REGISTER_NODE_TYPE(Conv2DWinogradWeightTransformAttrs);
@@ -744,7 +745,7 @@ RELAY_REGISTER_OP("nn.contrib_conv2d_winograd_nnpack_without_weight_transform")
 .add_argument("weight", "Tensor", "The weight tensor.")
 .set_support_level(10)
 .add_type_rel("Conv2DWinogradNNPACKRel", Conv2DWinogradRel<Conv2DAttrs>)
-.set_attr<FInferCorrectLayout>("FInferCorrectLayout", Conv2DInferCorrectLayout<Conv2DAttrs>);
+.set_attr<FInferCorrectLayout>("FInferCorrectLayout", ConvInferCorrectLayout<Conv2DAttrs>);
 
 // relay.nn.contrib_conv2d_winograd_nnpack_weight_transform
 TVM_REGISTER_NODE_TYPE(Conv2DWinogradNNPACKWeightTransformAttrs);
@@ -854,7 +855,7 @@ RELAY_REGISTER_OP("nn.contrib_conv2d_NCHWc_int8")
 .set_support_level(10)
 .add_type_rel("Conv2DNCHWcInt8", Conv2DWinogradRel<Conv2DAttrs>)
 .set_attr<FInferCorrectLayout>("FInferCorrectLayout",
-        Conv2DInferCorrectLayout<Conv2DAttrs>);
+        ConvInferCorrectLayout<Conv2DAttrs>);
 
 // Positional relay function to create conv2d NCHWc operator
 // used by frontend FFI.
@@ -903,7 +904,7 @@ RELAY_REGISTER_OP("nn.contrib_conv2d_NCHWc")
 .set_support_level(10)
 .add_type_rel("Conv2DNCHWc", Conv2DWinogradRel<Conv2DAttrs>)
 .set_attr<FInferCorrectLayout>("FInferCorrectLayout",
-        Conv2DInferCorrectLayout<Conv2DAttrs>);
+        ConvInferCorrectLayout<Conv2DAttrs>);
 
 
 // Positional relay function to create depthwise conv2d NCHWc operator
@@ -953,7 +954,7 @@ RELAY_REGISTER_OP("nn.contrib_depthwise_conv2d_NCHWc")
 .set_support_level(10)
 .add_type_rel("Conv2D", Conv2DRel<Conv2DAttrs>)
 .set_attr<FInferCorrectLayout>("FInferCorrectLayout",
-        Conv2DInferCorrectLayout<Conv2DAttrs>);
+        ConvInferCorrectLayout<Conv2DAttrs>);
 
 
 bool DeformableConv2DRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
diff --git a/src/relay/op/nn/convolution.h b/src/relay/op/nn/convolution.h
index efcf7df..0f4bb05 100644
--- a/src/relay/op/nn/convolution.h
+++ b/src/relay/op/nn/convolution.h
@@ -28,6 +28,8 @@
 #include <string>
 #include <utility>
 
+#include "../op_common.h"
+
 namespace tvm {
 namespace relay {
 
@@ -187,7 +189,7 @@ bool Conv3DRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
                  param->kernel_size[1], param->kernel_size[2]}};
     }
 
-    /*wshape = trans_kernel_layout.BackwardShape(wshape); */
+    wshape = trans_kernel_layout.BackwardShape(wshape);
     channels = param->channels;
     dilated_ksize_z = 1 + (param->kernel_size[0] - 1) * param->dilation[0];
     dilated_ksize_y = 1 + (param->kernel_size[1] - 1) * param->dilation[1];
@@ -196,6 +198,7 @@ bool Conv3DRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
     if (weight != nullptr) {
       weight_dtype = weight->dtype;
     }
+
     // assign result to reporter
     reporter->Assign(types[1], TensorTypeNode::make(wshape, weight_dtype));
   } else {
@@ -225,22 +228,24 @@ bool Conv3DRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
   // dilation
   Array<IndexExpr> oshape({dshape_ncdhw[0], channels, 0, 0, 0});
 
+  IndexExpr pad_d, pad_h, pad_w;
+  GetPaddingDepthHeightWidth(param->padding, &pad_d, &pad_h, &pad_w);
   if (!dshape_ncdhw[2].as<ir::Any>()) {
-    oshape.Set(2, indexdiv(dshape_ncdhw[2] + param->padding[0] * 2 - dilated_ksize_z,
+    oshape.Set(2, indexdiv(dshape_ncdhw[2] + pad_d - dilated_ksize_z,
                            param->strides[0]) + 1);
   } else {
     oshape.Set(2, dshape_ncdhw[2]);
   }
 
   if (!dshape_ncdhw[3].as<ir::Any>()) {
-    oshape.Set(3, indexdiv(dshape_ncdhw[3] + param->padding[1] * 2 - dilated_ksize_y,
+    oshape.Set(3, indexdiv(dshape_ncdhw[3] + pad_h - dilated_ksize_y,
                            param->strides[1]) + 1);
   } else {
     oshape.Set(3, dshape_ncdhw[3]);
   }
 
   if (!dshape_ncdhw[4].as<ir::Any>()) {
-    oshape.Set(4, indexdiv(dshape_ncdhw[4] + param->padding[2] * 2 - dilated_ksize_x,
+    oshape.Set(4, indexdiv(dshape_ncdhw[4] + pad_w - dilated_ksize_x,
                            param->strides[2]) + 1);
   } else {
     oshape.Set(4, dshape_ncdhw[4]);
diff --git a/src/relay/op/op_common.h b/src/relay/op/op_common.h
index 53495cc..04f26b9 100644
--- a/src/relay/op/op_common.h
+++ b/src/relay/op/op_common.h
@@ -162,6 +162,45 @@ inline void GetPaddingWidth(const Array<IndexExpr>& padding, IndexExpr* pad_w) {
   }
 }
 
+/*! \brief Get the total (both sides) height and width padding from a 1, 2 or 4 ints tuple. */
+inline void GetPaddingHeightWidth(const Array<IndexExpr>& padding, IndexExpr* pad_h,
+                                  IndexExpr* pad_w) {
+  if (padding.size() == 1) {  // one value used on every side
+    *pad_h = padding[0] * 2;
+    *pad_w = padding[0] * 2;
+  } else if (padding.size() == 2) {  // (h, w), each applied on both sides
+    *pad_h = padding[0] * 2;
+    *pad_w = padding[1] * 2;
+  } else if (padding.size() == 4) {  // indices 0/2 pad height, 1/3 pad width
+    *pad_h = padding[0] + padding[2];
+    *pad_w = padding[1] + padding[3];
+  } else {  // unsupported size: the CHECK below always fails with a message
+    CHECK_EQ(padding.size(), 4) << " Padding size should be 1, 2 or 4, but got "
+        << padding.size();
+  }
+}
+
+/*! \brief Get total (both sides) depth, height and width padding from a 1, 3 or 6 ints tuple. */
+inline void GetPaddingDepthHeightWidth(const Array<IndexExpr>& padding, IndexExpr* pad_d,
+                                       IndexExpr* pad_h, IndexExpr* pad_w) {
+  if (padding.size() == 1) {  // one value used on every side
+    *pad_d = padding[0] * 2;
+    *pad_h = padding[0] * 2;
+    *pad_w = padding[0] * 2;
+  } else if (padding.size() == 3) {  // (d, h, w), each applied on both sides
+    *pad_d = padding[0] * 2;
+    *pad_h = padding[1] * 2;
+    *pad_w = padding[2] * 2;
+  } else if (padding.size() == 6) {  // (front, top, left, back, bottom, right)
+    *pad_d = padding[0] + padding[3];
+    *pad_h = padding[1] + padding[4];
+    *pad_w = padding[2] + padding[5];
+  } else {  // unsupported size: the CHECK below always fails with a message
+    CHECK_EQ(padding.size(), 6) << " Padding size should be 1, 3 or 6, but got "
+        << padding.size();
+  }
+}
+
 }  // namespace relay
 }  // namespace tvm
 
diff --git a/tests/python/frontend/tensorflow/test_forward.py b/tests/python/frontend/tensorflow/test_forward.py
index 9b7fe62..97557d3 100644
--- a/tests/python/frontend/tensorflow/test_forward.py
+++ b/tests/python/frontend/tensorflow/test_forward.py
@@ -94,13 +94,14 @@ def vmobj_to_list(o):
 
 
 def run_tvm_graph(graph_def, input_data, input_node, num_output=1,
-                  target='llvm', out_names=None, opt_level=3, mode='graph_runtime'):
+                  target='llvm', out_names=None, opt_level=3, mode='graph_runtime',
+                  cuda_layout="NCHW"):
     """ Generic function to compile on relay and execute on tvm """
     input_data = convert_to_list(input_data)
     input_node = convert_to_list(input_node)
     layout = None
     if target == "cuda":
-        layout = "NCHW"
+        layout = cuda_layout
     target_host = None
     shape_dict = {e: i.shape for e, i in zip(input_node, input_data)}
     mod, params = relay.frontend.from_tensorflow(graph_def,
@@ -160,7 +161,8 @@ def run_tf_graph(sess, input_data, input_node, output_node):
 
 
 def compare_tf_with_tvm(in_data, in_name, out_name, init_global_variables=False,
-                        no_gpu=False, opt_level=3, mode='graph_runtime'):
+                        no_gpu=False, opt_level=3, mode='graph_runtime',
+                        cuda_layout="NCHW"):
     """Generic function to generate and compare tensorflow and TVM output"""
     def name_without_num(name):
         return name.split(':')[0] if ":" in name else name
@@ -191,7 +193,8 @@ def compare_tf_with_tvm(in_data, in_name, out_name, init_global_variables=False,
 
             tvm_output = run_tvm_graph(final_graph_def, in_data, in_node,
                                        target=device, out_names=out_name,
-                                       num_output=len(out_name), opt_level=opt_level, mode=mode)
+                                       num_output=len(out_name), opt_level=opt_level, mode=mode,
+                                       cuda_layout=cuda_layout)
             # since the names from tensorflow and relay runs are not exactly same,
             # first len(tf_output) will be compared
             for i in range(len(tf_output)):
@@ -470,6 +473,57 @@ def test_forward_convolution():
 
 
 #######################################################################
+# Convolution3D
+# -----------
+
+
+def _test_convolution3d(opname, tensor_in_sizes, filter_in_sizes,
+                        dilations, strides, padding, data_format,
+                        deconv_output_shape=None):
+    """ One iteration of 3D convolution with given shapes and attributes.
+
+    Only opname == 'conv' builds a graph; deconv_output_shape is unused for now. """
+    total_size_1 = np.prod(tensor_in_sizes)
+    total_size_2 = np.prod(filter_in_sizes)
+    # Initializes the input and filter tensors with incrementing numbers from 1.
+    data_array = [f * 1.0 for f in range(1, total_size_1 + 1)]
+    filter_array = [f * 1.0 for f in range(1, total_size_2 + 1)]
+
+    with tf.Graph().as_default():
+        in_data = array_ops.placeholder(shape=tensor_in_sizes, dtype='float32')
+        in_filter = constant_op.constant(
+            filter_array, shape=filter_in_sizes, dtype='float32')
+        # Expand to 5-element strides/dilations with 1 on the batch and channel axes.
+        if data_format == 'NDHWC':
+            strides = [1] + strides + [1]
+            dilations = [1] + dilations + [1]
+        else:
+            strides = [1, 1] + strides
+            dilations = [1, 1] + dilations
+        if opname == 'conv':
+            nn_ops.conv3d(in_data,
+                          in_filter,
+                          strides=strides,
+                          dilations=dilations,
+                          padding=padding,
+                          data_format=data_format)
+            # Compare TF and TVM on the default-named Conv3D output.
+            compare_tf_with_tvm(np.reshape(data_array, tensor_in_sizes).astype('float32'),
+                                'Placeholder:0', 'Conv3D:0', cuda_layout="NCDHW")
+
+def test_forward_convolution3d():
+    """Compare TF and TVM Conv3D; the NCDHW variants run only when a GPU is available."""
+    # (spatial size, input channels, DHWIO filter shape, stride, padding); batch is 4
+    cases = [(8, 176, [1, 1, 1, 176, 32], 1, 'SAME'), (17, 19, [3, 3, 3, 19, 19], 2, 'VALID'),
+             (17, 124, [1, 1, 1, 124, 19], 1, 'SAME'), (17, 12, [3, 3, 3, 12, 32], 2, 'VALID')]
+    for data_format in (['NCDHW'] if is_gpu_available() else []) + ['NDHWC']:
+        for size, chans, filt, stride, pad in cases:
+            spatial = [size] * 3
+            shape = [4, chans] + spatial if data_format == 'NCDHW' else [4] + spatial + [chans]
+            _test_convolution3d('conv', shape, filt, [1, 1, 1], [stride] * 3, pad, data_format)
+
+
+#######################################################################
 # BiasAdd
 # -----------
 
diff --git a/tests/python/relay/test_op_level2.py b/tests/python/relay/test_op_level2.py
index 2f19f7a..ceb5d09 100644
--- a/tests/python/relay/test_op_level2.py
+++ b/tests/python/relay/test_op_level2.py
@@ -294,6 +294,56 @@ def test_conv2d_winograd():
                          padding=(2, 2), channels=192, kernel_size=(7, 7))
 
 
+def test_conv3d_infer_type():
+    """Check nn.conv3d shape/dtype inference for NCDHW and NDHWC inputs."""
+    n, c, d, h, w = tvm.var("n"), 10, 224, 224, 224  # symbolic batch dimension
+    data = relay.var("x", relay.ty.TensorType((n, c, d, h, w), "float32"))
+    weight = relay.var("w")  # untyped: shape is inferred from kernel_size/channels
+    out = relay.nn.conv3d(data, weight,
+                          kernel_size=(3, 3, 3),
+                          padding=(1, 1, 1),
+                          channels=2)
+    typed = run_infer_type(out)
+    assert typed.checked_type == relay.TensorType(
+        (n, 2, 224, 224, 224), "float32")
+    assert typed.args[1].checked_type == relay.TensorType(
+        (2, 10, 3, 3, 3), "float32")
+
+    # Shape taken from the typed weight; int8 inputs with an int32 result.
+    n = tvm.var("n")
+    data = relay.var("x", relay.TensorType((n, c, d, h, w), "int8"))
+    weight = relay.var("w", relay.TensorType((2, 10, 3, 3, 3), "int8"))
+    out = relay.nn.conv3d(data, weight, out_dtype="int32")
+    assert "out_dtype=\"int32\"" in out.astext()
+    typed = run_infer_type(out)
+    assert typed.checked_type == relay.TensorType(
+        (n, 2, 222, 222, 222), "int32")
+
+    # Input and weight with different dtypes (uint8 data, int8 weight).
+    n = tvm.var("n")
+    data = relay.var("x", relay.TensorType((n, c, d, h, w), "uint8"))
+    weight = relay.var("w", relay.TensorType((2, 10, 3, 3, 3), "int8"))
+    out = relay.nn.conv3d(data, weight, out_dtype="int32")
+    assert "out_dtype=\"int32\"" in out.astext()
+    typed = run_infer_type(out)
+    assert typed.checked_type == relay.TensorType(
+        (n, 2, 222, 222, 222), "int32")
+
+    # Channel-last (NDHWC) data with a static batch size.
+    n, c, d, h, w = 4, 32, 224, 224, 224
+    data = relay.var("x", relay.TensorType((n, d, h, w, c), "int8"))
+    weight = relay.var("w")
+    out = relay.nn.conv3d(data, weight,
+                          kernel_size=(3, 3, 3),
+                          padding=(1, 1, 1),
+                          channels=16,
+                          data_layout="NDHWC",
+                          out_dtype="int32")
+    typed = run_infer_type(out)
+    assert typed.checked_type == relay.TensorType(
+        (n, d, h, w, 16), "int32")
+
+
 def test_conv3d_run():
     def run_test_conv3d(dtype, out_dtype, scale, dshape, kshape,
                         padding=(1, 1, 1),
@@ -338,6 +388,50 @@ def test_conv3d_run():
     run_test_conv3d("float32", "float32", 1, dshape, kshape,
             padding=(1, 1, 1), channels=10, kernel_size=(3, 3 ,3))
 
+def test_conv3d_ndhwc_run():
+    """Run NDHWC/DHWIO nn.conv3d and compare against topi.testing.conv3d_ndhwc_python."""
+    def run_test_conv3d(dtype, out_dtype, scale, dshape, kshape,
+                        padding=(1, 1, 1),
+                        fref=None,
+                        groups=1,
+                        dilation=(1, 1, 1),
+                        except_targets=None,
+                        **attrs):
+        if except_targets is None:
+            except_targets = []
+        x = relay.var("x", shape=dshape, dtype=dtype)
+        w = relay.var("w", dtype=dtype)
+        y = relay.nn.conv3d(x, w,
+                            padding=padding,
+                            dilation=dilation,
+                            groups=groups,
+                            data_layout="NDHWC", kernel_layout="DHWIO",
+                            **attrs)
+        func = relay.Function([x, w], y)
+        data = np.random.uniform(-scale, scale, size=dshape).astype(dtype)
+        kernel = np.random.uniform(-scale, scale, size=kshape).astype(dtype)
+        # DHWIO kernel: dilate the leading D, H, W axes, never the I/O channel axes.
+        dkernel = topi.testing.dilate_python(kernel, tuple(dilation) + (1, 1))
+        if fref is None:
+            ref_res = topi.testing.conv3d_ndhwc_python(
+                data.astype(out_dtype), dkernel.astype(out_dtype), 1, padding)
+        else:
+            ref_res = fref(data.astype(out_dtype), dkernel.astype(out_dtype))
+
+        for target, ctx in ctx_list():
+            if target in except_targets:
+                continue
+
+            intrp1 = relay.create_executor("graph", ctx=ctx, target=target)
+            op_res1 = intrp1.evaluate(func)(data, kernel)
+            tvm.testing.assert_allclose(op_res1.asnumpy(), ref_res, rtol=1e-5, atol=1e-5)
+
+    # normal conv3d
+    dshape = (1, 5, 224, 224, 6)
+    kshape = (3, 3, 3, 6, 10)
+    run_test_conv3d("float32", "float32", 1, dshape, kshape,
+            padding=(1, 1, 1), channels=10, kernel_size=(3, 3, 3), except_targets=["cuda"])
+
 
 def test_conv2d_transpose_infer_type():
     # symbolic in batch dimension
@@ -993,6 +1087,7 @@ if __name__ == "__main__":
     test_lrn()
     test_l2_normalize()
     test_conv2d_infer_type()
+    test_conv3d_infer_type()
     test_bitpack_infer_type()
     test_upsampling_infer_type()
     test_upsampling3d_infer_type()
@@ -1006,6 +1101,7 @@ if __name__ == "__main__":
     test_conv2d_run()
     test_conv2d_winograd()
     test_conv3d_run()
+    test_conv3d_ndhwc_run()
     test_bitserial_conv2d_infer_type()
     test_batch_flatten()
     test_upsampling()
diff --git a/topi/python/topi/cuda/conv3d.py b/topi/python/topi/cuda/conv3d.py
index 8d3c720..7d3c0b4 100644
--- a/topi/python/topi/cuda/conv3d.py
+++ b/topi/python/topi/cuda/conv3d.py
@@ -21,6 +21,7 @@ from tvm import autotvm
 from tvm.contrib import cudnn
 
 from .. import nn, generic
+from ..nn.util import get_pad_tuple3d
 from ..util import get_const_tuple, traverse_inline
 
 from .conv3d_direct import schedule_direct_3d_cuda
@@ -44,8 +45,10 @@ def conv3d_cuda(cfg, data, kernel, strides, padding, dilation, layout='NCDHW', o
     strides : int or a list/tuple of three ints
         stride size, or [stride_depth, stride_height, stride_width]
 
-    padding : int or a list/tuple of three ints
-        padding size, or [pad_depth, pad_height, pad_width]
+    padding : int or a list/tuple of 3 or 6 ints
+        padding size, or
+        [pad_depth, pad_height, pad_width] for 3 ints, or
+        [pad_front, pad_top, pad_left, pad_back, pad_bottom, pad_right] for 6 ints
 
     dilation: int or a list/tuple of three ints
         dilation size, or [dilation_depth, dilation_height, dilation_width]
@@ -77,25 +80,27 @@ def conv3d_cuda(cfg, data, kernel, strides, padding, dilation, layout='NCDHW', o
         # handle dilation
         stride_d, stride_h, stride_w = (strides, strides, strides) if isinstance(strides, int) \
             else strides
-        pad_d, pad_h, pad_w = (padding, padding, padding) if isinstance(padding, int) else padding
+        pf, pt, pl, pk, pb, pr = get_pad_tuple3d(padding, (KD, KH, KW))
+        if (pf, pt, pl) != (pk, pb, pr):  # cudnn applies one pad value to both sides of a dim
+            raise ValueError("Cudnn doesn't support asymmetric padding.")
         dilation_d, dilation_h, dilation_w = (dilation, dilation, dilation) if \
             isinstance(dilation, int) else dilation
 
-        OD = (D + 2 * pad_d - KD) // stride_d + 1
-        OH = (H + 2 * pad_h - KH) // stride_h + 1
-        OW = (W + 2 * pad_w - KW) // stride_w + 1
-        cfg.add_flop(2 * N * OD * OH * OW * CO * CI * ((DH - 1) * dilation_d + 1) *\
+        OD = (D + pf + pk - KD) // stride_d + 1
+        OH = (H + pt + pb - KH) // stride_h + 1
+        OW = (W + pl + pr - KW) // stride_w + 1
+        cfg.add_flop(2 * N * OD * OH * OW * CO * CI * ((KD - 1) * dilation_d + 1) *\
                      ((KH - 1) * dilation_h + 1) * ((KW - 1) * dilation_w + 1))
 
         return cudnn.conv_forward(data,
                                   kernel,
-                                  [pad_d, pad_h, pad_w],
+                                  [pf, pt, pl],  # cudnn pads pf, pt, pl on both sides of each dim
                                   [stride_d, stride_h, stride_w],
                                   [dilation_d, dilation_h, dilation_w],
                                   conv_mode=1,
                                   tensor_format=tensor_format,
                                   algo=-1,         # let CUDNN choose the best algo
-                                  conv_dtype=dtype)
+                                  conv_dtype=data.dtype)
 
     if layout == 'NCDHW':
         return nn.conv3d_ncdhw(data, kernel, strides, padding, dilation, out_dtype)
@@ -134,3 +139,37 @@ def schedule_conv3d_ncdhw_cuda(cfg, outs):
 
     traverse_inline(s, outs[0].op, _callback)
     return s
+
+
+@autotvm.register_topi_schedule(generic.schedule_conv3d_ndhwc, ["cuda", "gpu"],
+                                ["direct"])
+def schedule_conv3d_ndhwc_cuda(cfg, outs):
+    """TOPI schedule callback of conv3d in NDHWC layout for cuda gpu
+
+    Parameters
+    ----------
+    cfg: ConfigEntity
+        The config for this template
+
+    outs: Array of Tensor
+        The computation graph description of conv3d_ndhwc
+        in the format of an array of tensors.
+
+    Returns
+    -------
+    s: Schedule
+        The computation schedule for conv3d_ndhwc.
+    """
+    target = tvm.target.current_target()
+    if 'cudnn' in target.libs:
+        return generic.schedule_extern(outs)
+
+    outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
+    s = tvm.create_schedule([x.op for x in outs])
+
+    def _callback(op):
+        if op.tag == 'conv3d_ndhwc':  # tag assigned by topi.nn.conv3d_ndhwc
+            schedule_direct_3d_cuda(cfg, s, op.output(0))  # NOTE(review): confirm it suits NDHWC axes
+
+    traverse_inline(s, outs[0].op, _callback)
+    return s
diff --git a/topi/python/topi/generic/nn.py b/topi/python/topi/generic/nn.py
index 77f8cad..980db65 100644
--- a/topi/python/topi/generic/nn.py
+++ b/topi/python/topi/generic/nn.py
@@ -242,6 +242,22 @@ def schedule_conv3d_ncdhw(outs):
     """
     return _default_schedule(outs, False)
 
+@tvm.target.generic_func
+def schedule_conv3d_ndhwc(outs):
+    """Fallback schedule for conv3d_ndhwc; targets (e.g. cuda) register overrides
+
+    Parameters
+    ----------
+    outs: Array of Tensor
+          The computation graph description of conv3d_ndhwc
+          in the format of an array of tensors.
+
+    Returns
+    -------
+    sch: Schedule
+        The computation schedule for the op.
+    """
+    return _default_schedule(outs, False)  # NOTE(review): False presumably disables auto-inline
 
 @tvm.target.generic_func
 def schedule_conv2d_transpose_nchw(outs):
diff --git a/topi/python/topi/nn/conv3d.py b/topi/python/topi/nn/conv3d.py
index 928f32f..21d893f 100644
--- a/topi/python/topi/nn/conv3d.py
+++ b/topi/python/topi/nn/conv3d.py
@@ -15,7 +15,7 @@
 # specific language governing permissions and limitations
 # under the License.
 # pylint: disable=invalid-name, unused-variable, too-many-locals
-# pylint: disable=unused-argument, redefined-builtin
+# pylint: disable=unused-argument, redefined-builtin, no-else-return
 """Conv3D operators"""
 from __future__ import absolute_import as _abs
 import tvm
@@ -58,6 +58,8 @@ def conv3d(input, filter, strides, padding, dilation, layout='NCDHW', out_dtype=
     # default declaration
     if layout == 'NCDHW':
         return conv3d_ncdhw(input, filter, strides, padding, dilation, out_dtype)
+    elif layout == 'NDHWC':
+        return conv3d_ndhwc(input, filter, strides, padding, dilation, out_dtype)
     raise ValueError("not support this layout {} yet".format(layout))
 
 
@@ -128,3 +130,71 @@ def conv3d_ncdhw(Input, Filter, stride, padding, dilation, out_dtype=None):
                  xx * stride_w + rx * dilation_w].astype(out_dtype) *
             Filter[ff, rc, rz, ry, rx].astype(out_dtype),
             axis=[rc, rz, ry, rx]), tag="conv3d_ncdhw")
+
+
+def conv3d_ndhwc(Input, Filter, stride, padding, dilation, out_dtype='float32'):
+    """Convolution operator in NDHWC layout.
+
+    Parameters
+    ----------
+    Input : tvm.Tensor
+        5-D with shape [batch, in_depth, in_height, in_width, in_channel]
+
+    Filter : tvm.Tensor
+        5-D with shape [filter_depth, filter_height, filter_width, in_channel, num_filter]
+
+    stride : int or a list/tuple of three ints
+        Stride size, or [stride_depth, stride_height, stride_width]
+
+    padding : int, str, or a list/tuple of 3 or 6 ints
+        Padding size, 'VALID'/'SAME', [d, h, w] or [front, top, left, back, bottom, right]
+
+    dilation: int or a list/tuple of three ints
+        dilation size, or [dilation_depth, dilation_height, dilation_width]
+
+    Returns
+    -------
+    Output : tvm.Tensor
+        5-D with shape [batch, out_depth, out_height, out_width, out_channel]
+    """
+    assert isinstance(stride, int) or len(stride) == 3
+    assert isinstance(dilation, int) or len(dilation) == 3
+
+    if isinstance(stride, int):
+        stride_d = stride_h = stride_w = stride
+    else:
+        stride_d, stride_h, stride_w = stride
+
+    if isinstance(dilation, int):
+        dilation_d = dilation_h = dilation_w = dilation
+    else:
+        dilation_d, dilation_h, dilation_w = dilation
+
+    batch, in_depth, in_height, in_width, in_channel = Input.shape
+    kernel_d, kernel_h, kernel_w, channel, num_filter = Filter.shape
+    # compute the output shape
+    dilated_kernel_d = (kernel_d - 1) * dilation_d + 1
+    dilated_kernel_h = (kernel_h - 1) * dilation_h + 1
+    dilated_kernel_w = (kernel_w - 1) * dilation_w + 1
+
+    pad_front, pad_top, pad_left, pad_back, pad_down, pad_right = get_pad_tuple3d(
+        padding, (dilated_kernel_d, dilated_kernel_h, dilated_kernel_w))
+    out_channel = num_filter
+    out_depth = simplify((in_depth - dilated_kernel_d + pad_front + pad_back) // stride_d + 1)
+    out_height = simplify((in_height - dilated_kernel_h + pad_top + pad_down) // stride_h + 1)
+    out_width = simplify((in_width - dilated_kernel_w + pad_left + pad_right) // stride_w + 1)
+    pad_before = [0, pad_front, pad_top, pad_left, 0]
+    pad_after = [0, pad_back, pad_down, pad_right, 0]
+    PaddedInput = pad(Input, pad_before, pad_after, name="PaddedInput")
+    rc = tvm.reduce_axis((0, in_channel), name='rc')
+    rz = tvm.reduce_axis((0, kernel_d), name='rz')
+    ry = tvm.reduce_axis((0, kernel_h), name='ry')
+    rx = tvm.reduce_axis((0, kernel_w), name='rx')
+    Output = tvm.compute(
+        (batch, out_depth, out_height, out_width, out_channel),
+        lambda nn, zz, yy, xx, ff: tvm.sum(
+            PaddedInput[nn, zz * stride_d + rz * dilation_d, yy * stride_h + ry * dilation_h,
+                        xx * stride_w + rx * dilation_w, rc].astype(out_dtype) *
+            Filter[rz, ry, rx, rc, ff].astype(out_dtype), axis=[rz, ry, rx, rc]),
+        name="Conv3dOutput", tag="conv3d_ndhwc")
+    return Output
diff --git a/topi/python/topi/nn/util.py b/topi/python/topi/nn/util.py
index 847a5c8..c2c5c2b 100644
--- a/topi/python/topi/nn/util.py
+++ b/topi/python/topi/nn/util.py
@@ -158,9 +158,15 @@ def get_pad_tuple3d(padding, kernel):
     """
     # compute the padding size
     if isinstance(padding, (tuple, list)):
-        pad_h = padding[0] * 2
-        pad_w = padding[1] * 2
-        pad_d = padding[2] * 2
+        if len(padding) == 3:
+            pad_d = padding[0] * 2
+            pad_h = padding[1] * 2
+            pad_w = padding[2] * 2
+        elif len(padding) == 6:
+            return padding[0], padding[1], padding[2], padding[3], \
+                padding[4], padding[5]
+        else:
+            raise ValueError("Size of padding can only be 3 or 6")
     elif isinstance(padding, int):
         pad_d = pad_w = pad_h = padding * 2
     elif padding == "VALID":
diff --git a/topi/python/topi/testing/__init__.py b/topi/python/topi/testing/__init__.py
index 2826a2b..5e2f3fe 100644
--- a/topi/python/topi/testing/__init__.py
+++ b/topi/python/topi/testing/__init__.py
@@ -25,6 +25,7 @@ from .conv2d_hwcn_python import conv2d_hwcn_python
 from .conv2d_nchw_python import conv2d_nchw_python
 from .conv2d_nhwc_python import conv2d_nhwc_python
 from .conv3d_ncdhw_python import conv3d_ncdhw_python
+from .conv3d_ndhwc_python import conv3d_ndhwc_python
 from .conv2d_transpose_python import conv2d_transpose_nchw_python, conv2d_transpose_nhwc_python
 from .conv1d_transpose_ncw_python import conv1d_transpose_ncw_python
 from .deformable_conv2d_nchw_python import deformable_conv2d_nchw_python
diff --git a/topi/python/topi/testing/conv3d_ncdhw_python.py b/topi/python/topi/testing/conv3d_ncdhw_python.py
index 3a4db25..825ec62 100644
--- a/topi/python/topi/testing/conv3d_ncdhw_python.py
+++ b/topi/python/topi/testing/conv3d_ncdhw_python.py
@@ -18,6 +18,7 @@
 """Convolution 3D in python"""
 import numpy as np
 import scipy.signal
+from topi.nn.util import get_pad_tuple3d
 
 
 def _conv3d_ncdhw_python(a_np, w_np, stride, padding):
@@ -27,20 +28,13 @@ def _conv3d_ncdhw_python(a_np, w_np, stride, padding):
         stride_d = stride_h = stride_w = stride
     else:
         stride_d, stride_h, stride_w = stride
-    if isinstance(padding, int):
-        pad_d = pad_h = pad_w = padding * 2
-    elif isinstance(padding, (list, tuple)):
-        pad_d, pad_h, pad_w = padding[0] * 2, padding[1] * 2, padding[2] * 2
-    else:
-        pad_d = 0 if padding == 'VALID' else kernel_d - 1
-        pad_h = 0 if padding == 'VALID' else kernel_h - 1
-        pad_w = 0 if padding == 'VALID' else kernel_w - 1
-    pad_front = int(np.ceil(float(pad_d) / 2))
-    pad_back = pad_d - pad_front
-    pad_top = int(np.ceil(float(pad_h) / 2))
-    pad_bottom = pad_h - pad_top
-    pad_left = int(np.ceil(float(pad_w) / 2))
-    pad_right = pad_w - pad_left
+
+    pad_front, pad_top, pad_left, pad_back, pad_bottom, pad_right = \
+        get_pad_tuple3d(padding, (kernel_d, kernel_h, kernel_w))
+    pad_d = pad_front + pad_back
+    pad_h = pad_top + pad_bottom
+    pad_w = pad_left + pad_right
+
     # compute the output shape
     out_channel = num_filter
     out_depth = (in_depth - kernel_d + pad_d) // stride_d + 1
@@ -53,19 +47,8 @@ def _conv3d_ncdhw_python(a_np, w_np, stride, padding):
             for c in range(in_channel):
                 if pad_d > 0 or pad_h > 0 or pad_w > 0:
                     apad = np.zeros((in_depth + pad_d, in_height + pad_h, in_width + pad_w))
-                    if pad_d == 0 and pad_h == 0:
-                        apad[:, :, pad_left:-pad_right] = a_np[n, c]
-                    elif pad_d == 0 and pad_w == 0:
-                        apad[:, pad_top:-pad_bottom, :] = a_np[n, c]
-                    elif pad_d == 0 and pad_h != 0 and pad_w != 0:
-                        apad[:, pad_top:-pad_bottom, pad_left:-pad_right] = a_np[n, c]
-                    elif pad_d != 0 and pad_h == 0:
-                        apad[pad_front:-pad_back, :, pad_left:-pad_right] = a_np[n, c]
-                    elif pad_d != 0 and pad_w == 0:
-                        apad[pad_front:-pad_back, pad_top:-pad_bottom, :] = a_np[n, c]
-                    elif pad_d != 0 and pad_h != 0 and pad_w != 0:
-                        apad[pad_front:-pad_back, pad_top:-pad_bottom, pad_left:-pad_right] = a_np[n, c]
-
+                    apad[pad_front:pad_front + in_depth, pad_top:pad_top + in_height,\
+                        pad_left:pad_left + in_width] = a_np[n, c]
                 else:
                     apad = a_np[n, c]
                 out = scipy.signal.convolve(
diff --git a/topi/python/topi/testing/conv3d_ndhwc_python.py b/topi/python/topi/testing/conv3d_ndhwc_python.py
new file mode 100644
index 0000000..2810f72
--- /dev/null
+++ b/topi/python/topi/testing/conv3d_ndhwc_python.py
@@ -0,0 +1,82 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name, line-too-long, unused-variable, too-many-locals
+"""Convolution 3D in python"""
+import numpy as np
+import scipy.signal
+from topi.nn.util import get_pad_tuple3d
+
+
+def conv3d_ndhwc_python(a_np, w_np, stride, padding):
+    """Convolution 3D operator in NDHWC layout.
+
+    Parameters
+    ----------
+    a_np : numpy.ndarray
+        5-D with shape [batch, in_depth, in_height, in_width, in_channel]
+
+    w_np : numpy.ndarray
+        5-D with shape [filter_depth, filter_height, filter_width, in_channel, num_filter]
+
+    stride : int or a list/tuple of three ints
+        Stride size, or [stride_depth, stride_height, stride_width]
+
+    padding : int or str or a list/tuple of 3 or 6 ints
+        Padding size, or ['VALID', 'SAME'], or [pad_depth, pad_height, pad_width]
+        for symmetric padding, or [pad_front, pad_top, pad_left, pad_back,
+        pad_bottom, pad_right] for asymmetric padding.
+
+    Returns
+    -------
+    b_np : np.ndarray
+        5-D float64 array with shape [batch, out_depth, out_height, out_width, out_channel]
+    """
+    batch, in_depth, in_height, in_width, in_channel = a_np.shape
+    kernel_d, kernel_h, kernel_w, _, num_filter = w_np.shape
+    if isinstance(stride, int):
+        stride_d = stride_h = stride_w = stride
+    else:
+        stride_d, stride_h, stride_w = stride
+
+    pad_front, pad_top, pad_left, pad_back, pad_bottom, pad_right = \
+        get_pad_tuple3d(padding, (kernel_d, kernel_h, kernel_w))
+    pad_d = pad_front + pad_back
+    pad_h = pad_top + pad_bottom
+    pad_w = pad_left + pad_right
+    # compute the output shape
+    out_channel = num_filter
+    out_depth = (in_depth - kernel_d + pad_d) // stride_d + 1
+    out_height = (in_height - kernel_h + pad_h) // stride_h + 1
+    out_width = (in_width - kernel_w + pad_w) // stride_w + 1
+    # change the layout from NDHWC to NCDHW (kernel: DHWIO -> OIDHW)
+    at = a_np.transpose((0, 4, 1, 2, 3))
+    wt = w_np.transpose((4, 3, 0, 1, 2))
+    bt = np.zeros((batch, out_channel, out_depth, out_height, out_width))
+    # computation
+    for n in range(batch):
+        for f in range(out_channel):
+            for c in range(in_channel):
+                if pad_d > 0 or pad_h > 0 or pad_w > 0:
+                    apad = np.zeros((in_depth + pad_d, in_height + pad_h, in_width + pad_w))
+                    apad[pad_front:pad_front + in_depth, pad_top:pad_top + in_height,\
+                        pad_left:pad_left + in_width] = at[n, c]
+                else:
+                    apad = at[n, c]
+                out = scipy.signal.convolve(
+                    apad, np.flip(wt[f, c]), mode='valid')
+                bt[n, f] += out[::stride_d, ::stride_h, ::stride_w]
+    return bt.transpose((0, 2, 3, 4, 1))
diff --git a/topi/tests/python/test_topi_conv3d_ncdhw.py b/topi/tests/python/test_topi_conv3d_ncdhw.py
index 78827e4..6811906 100644
--- a/topi/tests/python/test_topi_conv3d_ncdhw.py
+++ b/topi/tests/python/test_topi_conv3d_ncdhw.py
@@ -22,12 +22,16 @@ from tvm import autotvm
 import topi
 import topi.testing
 from tvm.contrib.pickle_memoize import memoize
+from topi.nn.util import get_pad_tuple3d
 from topi.util import get_const_tuple
 
 from common import get_all_backend
 
 def verify_conv3d_ncdhw(batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation=1, add_bias=False, add_relu=False):
-    print("Workload: (%d, %d, %d, %d, %d, %d, %d, %d)" % (batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation))
+    pad_front, pad_top, pad_left, pad_back, pad_bottom, pad_right = get_pad_tuple3d(padding, (kernel, kernel, kernel))
+    padding_sum = pad_front + pad_back + pad_top + pad_left + pad_bottom + pad_right
+    print("Workload: (%d, %d, %d, %d, %d, %d, %d, %d)" % (batch, in_channel, in_size, num_filter, kernel, stride,
+          padding_sum, dilation))
 
     in_depth = in_height = in_width = in_size
 
@@ -62,7 +66,7 @@ def verify_conv3d_ncdhw(batch, in_channel, in_size, num_filter, kernel, stride,
             return
         print("Running on target: %s" % device)
         with tvm.target.create(device):
-            C = topi.nn.conv3d(A, W, (stride, stride, stride), (padding, padding, padding),
+            C = topi.nn.conv3d(A, W, (stride, stride, stride), padding,
                                (dilation, dilation, dilation), layout='NCDHW', out_dtype=dtype)
             if add_bias:
                 C = topi.add(C, bias)
@@ -75,10 +79,10 @@ def verify_conv3d_ncdhw(batch, in_channel, in_size, num_filter, kernel, stride,
         b = tvm.nd.array(b_np, ctx)
         c = tvm.nd.array(np.zeros(get_const_tuple(C.shape), dtype=C.dtype), ctx)
         if add_bias:
-            func = tvm.build(s, [A, W, bias, C], device, name="relu_%d_%d_%d_%d_%d_%d_%d_%d" % (batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation))
+            func = tvm.build(s, [A, W, bias, C], device, name="relu_%d_%d_%d_%d_%d_%d_%d_%d" % (batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation))
             func(a, w, b, c)
         else:
-            func = tvm.build(s, [A, W, C], device, name="relu_%d_%d_%d_%d_%d_%d_%d_%d" % (batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation))
+            func = tvm.build(s, [A, W, C], device, name="relu_%d_%d_%d_%d_%d_%d_%d_%d" % (batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation))
             func(a, w, c)
         tvm.testing.assert_allclose(c.asnumpy(), c_np, rtol=1e-4)
 
@@ -109,6 +113,14 @@ def test_conv3d_ncdhw():
     verify_conv3d_ncdhw(2, 2, 2, 2, 2, 2, 2)
     verify_conv3d_ncdhw(3, 3, 3, 3, 3, 3, 3)
 
+    # Asymmetric padding
+    verify_conv3d_ncdhw(1, 32, 32, 5, 1, 1, (0, 0, 0, 1, 1, 1))
+    verify_conv3d_ncdhw(1, 32, 32, 1, 1, 1, (2, 1, 2, 1, 2, 1))
+    verify_conv3d_ncdhw(1, 64, 56, 3, 3, 1, (2, 2, 2, 1, 1, 1), dilation=2)
+    verify_conv3d_ncdhw(1, 32, 32, 5, 1, 1, (0, 1, 1))
+    verify_conv3d_ncdhw(1, 32, 32, 1, 1, 1, (2, 1, 0))
+    verify_conv3d_ncdhw(1, 32, 32, 1, 3, 1, "VALID")
+    verify_conv3d_ncdhw(1, 32, 32, 5, 1, 1, "VALID")
 
 
 if __name__ == "__main__":
diff --git a/topi/tests/python/test_topi_conv3d_ndhwc.py b/topi/tests/python/test_topi_conv3d_ndhwc.py
new file mode 100644
index 0000000..66ccf08
--- /dev/null
+++ b/topi/tests/python/test_topi_conv3d_ndhwc.py
@@ -0,0 +1,79 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Example code to do convolution."""
+import os
+import numpy as np
+import tvm
+import topi
+import topi.testing
+from tvm.contrib.pickle_memoize import memoize
+from topi.util import get_const_tuple
+
+
+def verify_conv3d_ndhwc(batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation=1):
+    in_depth = in_height = in_width = in_size  # cubic input volume
+
+    A = tvm.placeholder((batch, in_depth, in_height, in_width, in_channel), name='A')
+    W = tvm.placeholder((kernel, kernel, kernel, in_channel, num_filter), name='W')
+    B = topi.nn.conv3d_ndhwc(A, W, stride, padding, dilation)
+
+    a_shape = get_const_tuple(A.shape)
+    w_shape = get_const_tuple(W.shape)
+    dtype = A.dtype
+
+    @memoize("topi.tests.test_topi_conv3d_ndhwc.verify_ndhwc.v2")  # NOTE(review): confirm cache key covers closure values
+    def get_ref_data():  # random inputs plus the numpy reference output
+        a_np = np.random.uniform(size=a_shape).astype(dtype)
+        w_np = np.random.uniform(size=w_shape).astype(dtype)
+        dw_np = topi.testing.dilate_python(w_np, (dilation, dilation, dilation, 1, 1))  # dilate the 3 kernel spatial axes
+        b_np = topi.testing.conv3d_ndhwc_python(a_np, dw_np, stride, padding)  # reference takes the dilated kernel
+        return a_np, w_np, b_np
+    a_np, w_np, b_np = get_ref_data()
+
+    def check_device(device):
+        if not tvm.module.enabled(device):
+            print("Skip because %s is not enabled" % device)
+            return
+        print("Running on target: %s" % device)
+        with tvm.target.create(device):
+            s = topi.generic.schedule_conv3d_ndhwc([B])
+        ctx = tvm.context(device, 0)
+        a = tvm.nd.array(a_np, ctx)
+        w = tvm.nd.array(w_np, ctx)
+        b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), ctx)
+        func = tvm.build(s, [A, W, B], device)
+        func(a, w, b)
+        tvm.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5)
+
+    for device in ['llvm']:  # only the llvm (CPU) target is exercised
+        check_device(device)
+
+
+def test_conv3d_ndhwc():
+    verify_conv3d_ndhwc(1, 16, 32, 16, 3, 1, "SAME")
+    verify_conv3d_ndhwc(4, 32, 16, 32, 5, 2, "SAME")
+    verify_conv3d_ndhwc(4, 32, 16, 64, 5, 2, "SAME")
+    verify_conv3d_ndhwc(1, 64, 32, 64, 3, 1, "VALID")
+    verify_conv3d_ndhwc(4, 32, 16, 32, 5, 2, "VALID")
+    verify_conv3d_ndhwc(4, 32, 16, 64, 5, 2, "VALID")
+    verify_conv3d_ndhwc(1, 16, 16, 16, 3, 1, (1, 1, 1, 0, 0, 0))  # asymmetric padding
+    # dilation = 2
+    verify_conv3d_ndhwc(1, 64, 32, 64, 3, 1, "SAME", dilation=2)
+
+
+if __name__ == "__main__":
+    test_conv3d_ndhwc()