You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by ke...@apache.org on 2020/03/17 06:54:40 UTC

[incubator-tvm] branch master updated: [Relay, TF Frontend] Dilation2D operator support (#5033)

This is an automated email from the ASF dual-hosted git repository.

kevinthesun pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git


The following commit(s) were added to refs/heads/master by this push:
     new 646cfc6  [Relay, TF Frontend] Dilation2D operator support (#5033)
646cfc6 is described below

commit 646cfc637169e85065fcae46b12068eab1ff8dbc
Author: Mahesh Ambule <15...@users.noreply.github.com>
AuthorDate: Tue Mar 17 12:24:32 2020 +0530

    [Relay, TF Frontend] Dilation2D operator support (#5033)
    
    * update docs for dilation 2d
    
    * dilation2d compute
    
    * dilation2d register
    
    * dilation2d rel compute
    
    * dilation2d strategy
    
    * dilation2d attrs
    
    * dilation2d generic schedule
    
    * dilation2d tf frontend support
    
    * dilation2d tf frontend test case
    
    * dilation2d test cases
    
    * pylint fixes
    
    * add exception for cuda target
    
    * Update docstring
    
    * Update docstring
    
    * change rates to dilations
    
    * removed unused param
    
    * merge master
    
    * Update nn.py
    
    * Update nn.py
---
 docs/api/python/topi.rst                         |   2 +
 docs/frontend/tensorflow.rst                     |   1 +
 docs/langref/relay_op.rst                        |   2 +
 include/tvm/relay/attrs/nn.h                     |  36 +++++
 python/tvm/relay/frontend/tensorflow.py          |  86 ++++++++++++
 python/tvm/relay/op/nn/_nn.py                    |   3 +
 python/tvm/relay/op/nn/nn.py                     |  57 ++++++++
 python/tvm/relay/op/op_attrs.py                  |   3 +
 python/tvm/relay/op/strategy/generic.py          |  51 +++++++
 src/relay/op/nn/convolution.cc                   |  60 +++++++++
 src/relay/op/nn/convolution.h                    |  71 ++++++++++
 tests/python/frontend/tensorflow/test_forward.py |  47 ++++++-
 tests/python/relay/test_op_level2.py             | 109 +++++++++++++++
 topi/python/topi/generic/nn.py                   |  30 +++++
 topi/python/topi/nn/__init__.py                  |   1 +
 topi/python/topi/nn/dilation2d.py                | 165 +++++++++++++++++++++++
 16 files changed, 723 insertions(+), 1 deletion(-)

diff --git a/docs/api/python/topi.rst b/docs/api/python/topi.rst
index 269d42d..39b120f 100644
--- a/docs/api/python/topi.rst
+++ b/docs/api/python/topi.rst
@@ -57,6 +57,7 @@ List of operators
    topi.nn.relu
    topi.nn.leaky_relu
    topi.nn.dilate
+   topi.nn.dilation2d
    topi.nn.pool
    topi.nn.global_pool
    topi.nn.adaptive_pool
@@ -197,6 +198,7 @@ topi.nn
 .. autofunction:: topi.nn.upsampling
 .. autofunction:: topi.nn.softmax
 .. autofunction:: topi.nn.dense
+.. autofunction:: topi.nn.dilation2d
 .. autofunction:: topi.nn.batch_matmul
 .. autofunction:: topi.nn.log_softmax
 .. autofunction:: topi.nn.conv2d_nchw
diff --git a/docs/frontend/tensorflow.rst b/docs/frontend/tensorflow.rst
index 8a54033..e06794d 100644
--- a/docs/frontend/tensorflow.rst
+++ b/docs/frontend/tensorflow.rst
@@ -140,6 +140,7 @@ Supported Ops
 - DecodeJpeg
 - DepthwiseConv2dNative
 - DepthToSpace
+- Dilation2D
 - Equal
 - Elu
 - Enter
diff --git a/docs/langref/relay_op.rst b/docs/langref/relay_op.rst
index 1be2eb5..94ba19d 100644
--- a/docs/langref/relay_op.rst
+++ b/docs/langref/relay_op.rst
@@ -70,6 +70,7 @@ This level enables typical convnet models.
    tvm.relay.nn.conv2d
    tvm.relay.nn.conv2d_transpose
    tvm.relay.nn.dense
+   tvm.relay.nn.dilation2d
    tvm.relay.nn.max_pool2d
    tvm.relay.nn.max_pool3d
    tvm.relay.nn.avg_pool2d
@@ -249,6 +250,7 @@ Level 2 Definitions
 .. autofunction:: tvm.relay.nn.conv2d
 .. autofunction:: tvm.relay.nn.conv2d_transpose
 .. autofunction:: tvm.relay.nn.dense
+.. autofunction:: tvm.relay.nn.dilation2d
 .. autofunction:: tvm.relay.nn.max_pool2d
 .. autofunction:: tvm.relay.nn.max_pool3d
 .. autofunction:: tvm.relay.nn.avg_pool2d
diff --git a/include/tvm/relay/attrs/nn.h b/include/tvm/relay/attrs/nn.h
index 9a73358..6a7ee41 100644
--- a/include/tvm/relay/attrs/nn.h
+++ b/include/tvm/relay/attrs/nn.h
@@ -156,6 +156,42 @@ struct Conv2DAttrs : public tvm::AttrsNode<Conv2DAttrs> {
 };
 
 
+/*! \brief Attributes used in dilation operators */
+struct Dilation2DAttrs : public tvm::AttrsNode<Dilation2DAttrs> {
+  Array<IndexExpr> strides;
+  Array<IndexExpr> padding;
+  Array<IndexExpr> dilations;
+  std::string data_layout;
+  std::string kernel_layout;
+  DataType out_dtype;
+
+  TVM_DECLARE_ATTRS(Dilation2DAttrs, "relay.attrs.Dilation2DAttrs") {
+    TVM_ATTR_FIELD(strides).set_default(Array<IndexExpr>({1, 1}))
+        .describe("Specifies the strides of the sliding window. [stride_height, stride_width].");
+    TVM_ATTR_FIELD(padding).set_default(Array<IndexExpr>({0, 0}))
+        .describe("If padding is non-zero, then the input is implicitly zero-padded"
+                  "Padding support both symmetric and asymmetric as"
+                  "one int : same padding used on all sides"
+                  "two int : bottom, right will use same padding as top, left"
+                  "four int : padding width in the order of (top, left, bottom, right)");
+    TVM_ATTR_FIELD(dilations).set_default(Array<IndexExpr>({1, 1}))
+        .describe("Specifies the dilation rate to use. [dilation_height, dilation_width]");
+    TVM_ATTR_FIELD(data_layout).set_default("NCHW")
+        .describe("Dimension ordering of input data. Can be 'NCHW', 'NHWC', etc."
+                  "'N', 'C', 'H', 'W' stands for batch, channel, height, and width"
+                  "dimensions respectively. Convolution is applied on the 'H' and"
+                  "'W' dimensions.");
+    TVM_ATTR_FIELD(kernel_layout).set_default("IHW")
+        .describe("Dimension ordering of weight. Can be 'IHW', 'HWI', etc."
+                  "'I', 'H', 'W' stands for input_channel, height, and width"
+                  "dimensions respectively.");
+    TVM_ATTR_FIELD(out_dtype)
+        .set_default(NullValue<DataType>())
+        .describe("Output data type, set to explicit type under mixed precision setting");
+  }
+};
+
+
 /*! \brief Attributes used in winograd weight transformation operators */
 struct Conv2DWinogradWeightTransformAttrs :
     public tvm::AttrsNode<Conv2DWinogradWeightTransformAttrs> {
diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py
index fdfcea8..3dca365 100644
--- a/python/tvm/relay/frontend/tensorflow.py
+++ b/python/tvm/relay/frontend/tensorflow.py
@@ -410,6 +410,91 @@ def _conv(opname):
         return out
     return _impl
 
+
+# Dilation2d
+def _dilation2d():
+    def _impl(inputs, attr, params):
+        if 'data_format' not in attr:
+            attr['data_format'] = 'NHWC'
+
+        input_shape = attr['_input_shapes'][inputs[0]]
+        weights_shape = attr['_input_shapes'][inputs[1]]
+
+        if attr['_target_layout'] == "NCHW" and attr['data_format'] == "NHWC":
+            input_shape = [input_shape[ii] for ii in (0, 3, 1, 2)]
+            inputs[0] = _op.transpose(inputs[0], axes=(0, 3, 1, 2))
+            weights_shape = [weights_shape[ii] for ii in (2, 0, 1)]
+            inputs[1] = _op.transpose(inputs[1], axes=(2, 0, 1))
+            attr['data_format'] = "NCHW"
+
+        if attr['data_format'] in ['NHWC', 'NCHW']:
+            if 'rates' in attr:
+                attr['dilations'] = attr['rates']
+            if 'dilations' in attr:
+                attr['dilations'] = (attr['dilations'][1], attr['dilations'][2])
+            attr['strides'] = (attr['strides'][1], attr['strides'][2])
+        else:
+            msg = 'Value {} in attribute "data_format" of operator Dilation2D is ' \
+                  'not valid.'
+            raise tvm.error.OpAttributeInvalid(msg.format(attr['data_format']))
+
+        attr['padding'] = attr['padding'].decode("utf-8")
+        if attr['padding'] == 'VALID':
+            attr['padding'] = [0, 0]
+        elif attr['padding'] == 'SAME':
+            stride_h, stride_w = attr['strides']
+            if attr['data_format'] == 'NHWC':
+                kernel_h, kernel_w = weights_shape[0], weights_shape[1]
+            else:
+                kernel_h, kernel_w = weights_shape[1], weights_shape[2]
+            if attr['data_format'] == 'NHWC':
+                in_h = input_shape[1]
+                in_w = input_shape[2]
+            else:
+                in_h = input_shape[2]
+                in_w = input_shape[3]
+
+            dilation_h = attr['dilations'][0]
+            dilation_w = attr['dilations'][1]
+            dilated_kernel_h = (kernel_h - 1) * dilation_h + 1
+            dilated_kernel_w = (kernel_w - 1) * dilation_w + 1
+            pad_v = _get_pad_pair(in_h, dilated_kernel_h, stride_h)
+            pad_h = _get_pad_pair(in_w, dilated_kernel_w, stride_w)
+
+            if attr['data_format'] == 'NHWC':
+                inputs[0] = _op.nn.pad(data=inputs[0],
+                                       pad_width=((0, 0),
+                                                  (pad_v[0], pad_v[1]),
+                                                  (pad_h[0], pad_h[1]),
+                                                  (0, 0)))
+            else:
+                inputs[0] = _op.nn.pad(data=inputs[0],
+                                       pad_width=((0, 0),
+                                                  (0, 0),
+                                                  (pad_v[0], pad_v[1]),
+                                                  (pad_h[0], pad_h[1])))
+
+            attr['padding'] = [0, 0]
+
+        else:
+            msg = 'Value {} in attribute "padding" of operator Dilation2d is not ' \
+                  'valid.'
+            raise tvm.error.OpAttributeInvalid(msg.format(attr['padding']))
+
+        attr['kernel_layout'] = 'HWI' if attr['data_format'] == 'NHWC' else 'IHW'
+        out = AttrCvt(
+            op_name='dilation2d',
+            ignores=['explicit_paddings', 'rates'],
+            transforms={
+                'data_format': 'data_layout',
+            })([inputs[0], inputs[1]], attr)
+        if attr['_target_layout'] == "NCHW":
+            out = _op.transpose(out, axes=(0, 2, 3, 1))
+        return out
+
+    return _impl
+
+
 def _conv3d(opname):
     def _impl(inputs, attr, params):
         attr['data_format'] = attr['data_format'].decode("utf-8")
@@ -1550,6 +1635,7 @@ _convert_map = {
     'DecodeJpeg'                        : _decode_image(),
     'DepthwiseConv2dNative'             : _conv('depthwise'),
     'DepthToSpace'                      : _depth_to_space(),
+    'Dilation2D'                        : _dilation2d(),
     'Equal'                             : _broadcast('equal'),
     'Elu'                               : _elu(),
     'Erf'                               : AttrCvt('erf'),
diff --git a/python/tvm/relay/op/nn/_nn.py b/python/tvm/relay/op/nn/_nn.py
index c641c9d..c2fe6d0 100644
--- a/python/tvm/relay/op/nn/_nn.py
+++ b/python/tvm/relay/op/nn/_nn.py
@@ -186,6 +186,9 @@ def legalize_conv2d_transpose(attrs, inputs, types):
 reg.register_strategy("nn.conv3d", strategy.conv3d_strategy)
 reg.register_pattern("nn.conv3d", OpPattern.OUT_ELEMWISE_FUSABLE)
 
+# dilation2d
+reg.register_strategy("nn.dilation2d", strategy.dilation2d_strategy)
+reg.register_pattern("nn.dilation2d", OpPattern.OUT_ELEMWISE_FUSABLE)
 
 # conv1d_transpose
 reg.register_strategy("nn.conv1d_transpose", strategy.conv1d_transpose_strategy)
diff --git a/python/tvm/relay/op/nn/nn.py b/python/tvm/relay/op/nn/nn.py
index c62b1cf..66c4ec3 100644
--- a/python/tvm/relay/op/nn/nn.py
+++ b/python/tvm/relay/op/nn/nn.py
@@ -2463,3 +2463,60 @@ def adaptive_avg_pool3d(data,
     """
     output_size = [] or output_size
     return _make.adaptive_avg_pool3d(data, output_size, layout)
+
+
+def dilation2d(data,
+               weight,
+               strides=(1, 1),
+               padding=(0, 0),
+               dilations=(1, 1),
+               data_layout="NCHW",
+               kernel_layout="IHW",
+               out_dtype=""):
+    r"""Dilation 2D.
+    This operator takes the weight as the dilation kernel and uses it to dilate
+    the data, producing an output. In the default case, where the data_layout is `NCHW`
+    and kernel_layout is `IHW`, dilation2d takes in a data Tensor with shape
+    `(batch_size, in_channels, height, width)`, and a weight Tensor with shape
+    `(in_channels, kernel_height, kernel_width)` to produce an output Tensor
+    with the following rule:
+
+    .. math::
+        \mbox{out}[b, c, y, x] = \max_{dy, dx}
+           \mbox{data}[b, c, \mbox{strides}[0] * y  + dy, \mbox{strides}[1] * x + dx] +
+           \mbox{weight}[c, dy, dx]
+
+    Padding and dilation are applied to data and weight respectively before the computation.
+    This operator accepts data layout specification. Semantically, the operator
+    will convert the layout to the canonical layout
+    (`NCHW` for data and `IHW` for weight) and perform the computation.
+
+    weight : tvm.relay.Expr
+        The weight expressions.
+
+    strides : Optional[Tuple[int]]
+        The strides of the sliding window.
+
+    padding : Optional[Tuple[int]]
+        The padding applied on both sides of the input before dilation.
+
+    dilations : Optional[Tuple[int]]
+        Specifies the dilation rate applied to the weight (kernel).
+
+    data_layout : Optional[str]
+        Layout of the input.
+
+    kernel_layout : Optional[str]
+        Layout of the weight.
+
+    out_dtype : Optional[str]
+        Specifies the output data type.
+
+    Returns
+    -------
+    result : tvm.relay.Expr
+        The computed result.
+    """
+
+    return _make.dilation2d(data, weight, strides, padding, dilations, data_layout,
+                            kernel_layout, out_dtype)
diff --git a/python/tvm/relay/op/op_attrs.py b/python/tvm/relay/op/op_attrs.py
index 12abf4a..37141e4 100644
--- a/python/tvm/relay/op/op_attrs.py
+++ b/python/tvm/relay/op/op_attrs.py
@@ -44,6 +44,9 @@ class Conv2DWinogradWeightTransformAttrs(Attrs):
 class Conv2DWinogradNNPACKWeightTransformAttrs(Attrs):
     """Attributes for nn.contrib_conv2d_winograd_nnpack_weight_transform"""
 
+@register_relay_attr_node
+class Dilation2DAttrs(Attrs):
+    """Attributes for nn.dilation2d"""
 
 @register_relay_attr_node
 class GlobalPool2DAttrs(Attrs):
diff --git a/python/tvm/relay/op/strategy/generic.py b/python/tvm/relay/op/strategy/generic.py
index 312ce95..e849f8c 100644
--- a/python/tvm/relay/op/strategy/generic.py
+++ b/python/tvm/relay/op/strategy/generic.py
@@ -442,6 +442,57 @@ def conv1d_transpose_strategy(attrs, inputs, out_type, target):
                                 name="conv1d_transpose_ncw.generic")
     return strategy
 
+
+# dilation2d
+def wrap_compute_dilation2d(topi_compute, need_data_layout=False):
+    """Wrap dilation2d topi compute"""
+    def _compute_dilation2d(attrs, inputs, out_type):
+        padding = get_const_tuple(attrs.padding)
+        strides = get_const_tuple(attrs.strides)
+        dilations = get_const_tuple(attrs.dilations)
+        data_layout = attrs.get_str("data_layout")
+        out_dtype = attrs.out_dtype
+        out_dtype = (inputs[0].dtype if out_dtype in ("same", "")
+                     else out_dtype)
+        args = [inputs[0], inputs[1], strides, padding, dilations]
+        if need_data_layout:
+            args.append(data_layout)
+        args.append(out_dtype)
+        return [topi_compute(*args)]
+    return _compute_dilation2d
+
+
+@override_native_generic_func("dilation2d_strategy")
+def dilation2d_strategy(attrs, inputs, out_type, target):
+    """dilation2d_strategy generic strategy"""
+    logger.warning("dilation2d_strategy is not optimized for this platform.")
+    strategy = _op.OpStrategy()
+    dilations = get_const_tuple(attrs.dilations)
+    layout = attrs.data_layout
+    kernel_layout = attrs.kernel_layout
+
+    assert layout in ["NCHW", "NHWC"]
+    (dilation_h, dilation_w) = dilations
+    if dilation_h < 1 or dilation_w < 1:
+        raise ValueError("dilation should be positive value")
+
+    if layout == "NCHW":
+        assert kernel_layout == "IHW"
+        strategy.add_implementation(
+            wrap_compute_dilation2d(topi.nn.dilation2d_nchw),
+            wrap_topi_schedule(topi.generic.schedule_dilation2d_nchw),
+            name="dilation2d_nchw.generic")
+    elif layout == "NHWC":
+        assert kernel_layout == "HWI"
+        strategy.add_implementation(
+            wrap_compute_dilation2d(topi.nn.dilation2d_nhwc),
+            wrap_topi_schedule(topi.generic.schedule_dilation2d_nhwc),
+            name="dilation2d_nhwc.generic")
+    else:
+        raise RuntimeError("Unsupported dilation2d layout {}".format(layout))
+    return strategy
+
+
 # dense
 def wrap_compute_dense(topi_compute):
     """wrap dense topi compute"""
diff --git a/src/relay/op/nn/convolution.cc b/src/relay/op/nn/convolution.cc
index 5906d3a..d642e2f 100644
--- a/src/relay/op/nn/convolution.cc
+++ b/src/relay/op/nn/convolution.cc
@@ -1040,6 +1040,66 @@ Expr MakeDeformableConv2D(Expr data,
 TVM_REGISTER_GLOBAL("relay.op.nn._make.deformable_conv2d")
 .set_body_typed(MakeDeformableConv2D);
 
+// relay.nn.dilation2d
+TVM_REGISTER_NODE_TYPE(Dilation2DAttrs);
+
+template<typename T>
+Array<Array<Layout> > Dilation2DInferCorrectLayout(
+    const Attrs& attrs,
+    const Array<Layout>& new_in_layouts,
+    const Array<Layout>& old_in_layouts,
+    const Array<Array<IndexExpr>> &old_in_shapes) {
+  const T* params = attrs.as<T>();
+
+  // Other operators are always made to fit the layouts of this operator, so this
+  // inference ignores the incoming layouts and returns the layouts from the attributes.
+  return Array<Array<Layout> >{{params->data_layout, params->kernel_layout},
+                               {params->data_layout}};
+}
+
+// Positional relay function to create dilation2d operator
+// used by frontend FFI.
+Expr MakeDilation2D(Expr data,
+                    Expr weight,
+                    Array<IndexExpr> strides,
+                    Array<IndexExpr> padding,
+                    Array<IndexExpr> dilations,
+                    std::string data_layout,
+                    std::string kernel_layout,
+                    DataType out_dtype) {
+  auto attrs = make_object<Dilation2DAttrs>();
+  attrs->strides = std::move(strides);
+  attrs->padding = std::move(padding);
+  attrs->dilations = std::move(dilations);
+  attrs->data_layout = std::move(data_layout);
+  attrs->kernel_layout = std::move(kernel_layout);
+  attrs->out_dtype = std::move(out_dtype);
+  static const Op& op = Op::Get("nn.dilation2d");
+  return CallNode::make(op, {data, weight}, Attrs(attrs), {});
+}
+
+
+TVM_REGISTER_GLOBAL("relay.op.nn._make.dilation2d")
+.set_body_typed(MakeDilation2D);
+
+
+RELAY_REGISTER_OP("nn.dilation2d")
+.describe(R"code(Computes grayscale dilation of 4D input and 3D filter.
+- **data**: This depends on the `layout` parameter. Input is 4D array of shape
+            (batch_size, in_channels, height, width) if `layout` is `NCHW`.
+- **weight**: (in_channels, height, width)
+- **out**:  This depends on the `layout` parameter. Output is 4D array of shape
+            (batch_size, channels, out_height, out_width) if `layout` is `NCHW`.
+)code" TVM_ADD_FILELINE)
+.set_attrs_type<Dilation2DAttrs>()
+.set_num_inputs(2)
+.add_argument("data", "Tensor", "The input tensor.")
+.add_argument("weight", "Tensor", "The weight tensor.")
+.set_support_level(2)
+.add_type_rel("Dilation2D", Dilation2DRel<Dilation2DAttrs>)
+.set_attr<FInferCorrectLayout>("FInferCorrectLayout",
+         Dilation2DInferCorrectLayout<Dilation2DAttrs>);
+
 
 }  // namespace relay
 }  // namespace tvm
diff --git a/src/relay/op/nn/convolution.h b/src/relay/op/nn/convolution.h
index 9ee84a0..d451215 100644
--- a/src/relay/op/nn/convolution.h
+++ b/src/relay/op/nn/convolution.h
@@ -360,6 +360,77 @@ bool Conv3DRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
   return true;
 }
 
+template <typename AttrType>
+bool Dilation2DRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
+               const TypeReporter& reporter) {
+  CHECK_EQ(types.size(), 3);
+  const auto* data = types[0].as<TensorTypeNode>();
+  const auto* weight = types[1].as<TensorTypeNode>();
+  if (data == nullptr) return false;
+  static const Layout kNCHW("NCHW");
+  static const Layout kOIHW("IHW");
+
+  const AttrType* param = attrs.as<AttrType>();
+  CHECK(param != nullptr);
+  const Layout in_layout(param->data_layout);
+  const Layout kernel_layout(param->kernel_layout);
+
+  const auto trans_in_layout = BijectiveLayoutNode::make(in_layout, kNCHW);
+  CHECK(trans_in_layout.defined())
+      << "Dilation2D only support input layouts that are convertible from NCHW."
+      << " But got " << in_layout;
+
+  const auto trans_kernel_layout = BijectiveLayoutNode::make(kernel_layout, kOIHW);
+  CHECK(trans_kernel_layout.defined())
+      << "Dilation2D only support kernel layouts that are convertible from OIHW."
+      << " But got " << kernel_layout;
+
+  Layout out_layout(param->data_layout);
+  const auto trans_out_layout = BijectiveLayoutNode::make(out_layout, kNCHW);
+  CHECK(trans_out_layout.defined())
+      << "Dilation2D only support output layouts that are convertible from NCHW."
+      << " But got " << out_layout;
+
+  Array<IndexExpr> dshape_nchw = trans_in_layout.ForwardShape(data->shape);
+
+  IndexExpr channels, dilated_ksize_y, dilated_ksize_x;
+
+  // use the weight shape to infer the output channels and the dilated kernel size.
+  if (weight == nullptr) return false;
+  auto wshape = trans_kernel_layout.ForwardShape(weight->shape);
+  channels = wshape[0];
+
+  dilated_ksize_y = 1 + (wshape[1] - 1) * param->dilations[0];
+  dilated_ksize_x = 1 + (wshape[2] - 1) * param->dilations[1];
+
+  // compute the output shape (in NCHW order) from the padded input and dilated kernel
+  Array<IndexExpr> oshape({dshape_nchw[0], channels, 0, 0});
+  IndexExpr pad_h, pad_w;
+  GetPaddingHeightWidth(param->padding, &pad_h, &pad_w);
+  if (!dshape_nchw[2].as<tir::AnyNode>()) {
+    oshape.Set(2, indexdiv(dshape_nchw[2] + pad_h - dilated_ksize_y,
+                           param->strides[0]) + 1);
+  } else {
+    oshape.Set(2, dshape_nchw[2]);
+  }
+
+  if (!dshape_nchw[3].as<tir::AnyNode>()) {
+    oshape.Set(3, indexdiv(dshape_nchw[3] + pad_w - dilated_ksize_x,
+                           param->strides[1]) + 1);
+  } else {
+    oshape.Set(3, dshape_nchw[3]);
+  }
+
+  DataType out_dtype = param->out_dtype;
+  if (out_dtype.bits() == 0) {
+    out_dtype = data->dtype;
+  }
+  oshape = trans_out_layout.BackwardShape(oshape);
+  // assign output type
+  reporter->Assign(types[2], TensorType(oshape, out_dtype));
+  return true;
+}
+
 }  // namespace relay
 }  // namespace tvm
 #endif  // TVM_RELAY_OP_NN_CONVOLUTION_H_
diff --git a/tests/python/frontend/tensorflow/test_forward.py b/tests/python/frontend/tensorflow/test_forward.py
index bb52695..a57d50c 100644
--- a/tests/python/frontend/tensorflow/test_forward.py
+++ b/tests/python/frontend/tensorflow/test_forward.py
@@ -3037,7 +3037,51 @@ def test_forward_add_n():
     _test_forward_add_n(in5)
 
 
-#######################################################################
+def _test_dilation2d(tensor_in_sizes, filter_in_sizes,
+                     strides, dilations, padding):
+    """ One iteration of dilation2d with given shapes and attributes """
+
+    total_size_1 = np.prod(tensor_in_sizes)
+    total_size_2 = np.prod(filter_in_sizes)
+    # Initializes the input tensor with array containing incrementing
+    # numbers from 1.
+    data_array = [f * 1.0 for f in range(1, total_size_1 + 1)]
+    filter_array = [f * 1.0 for f in range(1, total_size_2 + 1)]
+
+    with tf.Graph().as_default():
+        in_data = array_ops.placeholder(shape=tensor_in_sizes, dtype='float32')
+        in_filter = constant_op.constant(
+            filter_array, shape=filter_in_sizes, dtype='float32')
+
+        nn_ops.dilation2d(in_data,
+                          in_filter,
+                          strides=strides,
+                          rates=dilations,
+                          padding=padding)
+
+        compare_tf_with_tvm(np.reshape(data_array, tensor_in_sizes).astype('float32'),
+                            'Placeholder:0', 'Dilation2D:0', no_gpu=True)
+
+
+def test_forward_dilation():
+    _test_dilation2d([1, 18, 18, 32], [4, 4, 32], [1, 1, 1, 1], [1, 2, 1, 1], "VALID")
+    _test_dilation2d([1, 15, 15, 32], [4, 4, 32], [1, 1, 1, 1], [1, 2, 1, 1], "SAME")
+    _test_dilation2d([1, 5, 5, 1], [2, 2, 1], [1, 1, 1, 1], [1, 1, 1, 1], "VALID")
+    _test_dilation2d([1, 5, 5, 1], [3, 3, 1], [1, 1, 1, 1], [1, 2, 2, 1], "VALID")
+    _test_dilation2d([1, 5, 5, 3], [3, 3, 3], [1, 1, 1, 1], [1, 1, 1, 1], "SAME")
+    _test_dilation2d([1, 28, 28, 3], [5, 5, 3], [1, 2, 2, 1], [1, 1, 1, 1], "VALID")
+    _test_dilation2d([1, 224, 224, 10], [8, 8, 10], [1, 1, 1, 1], [1, 1, 1, 1], "VALID")
+    _test_dilation2d([1, 18, 18, 32], [4, 4, 32], [1, 1, 1, 1], [1, 2, 1, 1], "SAME")
+    _test_dilation2d([1, 15, 15, 32], [4, 4, 32], [1, 1, 1, 1], [1, 2, 1, 1], "VALID")
+    _test_dilation2d([1, 5, 5, 1], [7, 2, 1], [1, 3, 1, 1], [1, 1, 1, 1], "SAME")
+    _test_dilation2d([1, 5, 5, 1], [3, 4, 1], [1, 2, 1, 1], [1, 2, 2, 1], "SAME")
+    _test_dilation2d([1, 5, 5, 3], [3, 3, 3], [1, 1, 4, 1], [1, 1, 1, 1], "VALID")
+    _test_dilation2d([1, 28, 28, 3], [5, 6, 3], [1, 1, 2, 1], [1, 1, 1, 1], "SAME")
+    _test_dilation2d([1, 224, 224, 10], [8, 8, 10], [1, 3, 1, 1], [1, 1, 1, 1], "SAME")
+    _test_dilation2d([1, 3, 3, 1], [2, 2, 1], [1, 1, 1, 1], [1, 2, 2, 1], "SAME")
+    _test_dilation2d([1, 3, 3, 1], [2, 2, 1], [1, 1, 1, 1], [1, 1, 2, 1], "VALID")
+
+# #######################################################################
 # Main
 # ----
 if __name__ == '__main__':
@@ -3131,6 +3175,7 @@ if __name__ == '__main__':
     test_forward_l2_normalize()
     test_forward_space_to_batch_nd()
     test_forward_batch_to_space_nd()
+    test_forward_dilation()
 
     # End to End
     test_forward_inception_v3()
diff --git a/tests/python/relay/test_op_level2.py b/tests/python/relay/test_op_level2.py
index 7a42fc3..53a5aa3 100644
--- a/tests/python/relay/test_op_level2.py
+++ b/tests/python/relay/test_op_level2.py
@@ -1219,6 +1219,113 @@ def test_depthwise_conv2d_int8():
                 graph, lib, params = relay.build(func, target, params=parameters)
 
 
+def test_dilation2d_infer_type():
+    # symbolic in batch dimension
+    n, h, w, c = te.var("n"), 224, 224, 10
+    x = relay.var("x", relay.ty.TensorType((n, c, h, w), "float32"))
+    kc, kh, kw = 10, 8, 8
+    w = relay.var("w", relay.ty.TensorType((kc, kw, kh), "float32"))
+    y = relay.nn.dilation2d(x, w,
+                            # kernel_size=(3, 3),
+                            strides=[1, 1, 1, 1],
+                            dilations=[1, 1, 1, 1],
+                            padding=[0, 0, 0, 0])
+    yy = run_infer_type(y)
+    assert yy.checked_type == relay.TensorType(
+        (n, 10, 217, 217), "float32")
+
+
+def test_dilation2d_run():
+    def run_test_dilation2d(indata, kernel, out,
+                            dtype='float32',
+                            strides=[1, 1],
+                            padding=[0, 0],
+                            dilations=[1, 1],
+                            except_targets=['cuda'],
+                            **attrs):
+
+        dshape = indata.shape
+        kshape = kernel.shape
+
+        if except_targets is None:
+            except_targets = []
+
+        x = relay.var("x", shape=dshape, dtype=dtype)
+        w = relay.var("w", shape=kshape, dtype=dtype)
+        y = relay.nn.dilation2d(x, w,
+                                strides=strides,
+                                dilations=dilations,
+                                padding=padding,
+                                **attrs)
+        func = relay.Function([x, w], y)
+
+        for target, ctx in ctx_list():
+            if target in except_targets:
+                continue
+            intrp = relay.create_executor("graph", ctx=ctx, target=target)
+            op_res = intrp.evaluate(func)(indata, kernel)
+            tvm.testing.assert_allclose(op_res.asnumpy(), out, rtol=1e-5, atol=1e-5)
+
+    def _convert_data(indata, kernel, out, layout=None):
+        indata = np.asarray(indata)
+        kernel = np.asarray(kernel)
+        out = np.asarray(out)
+        if layout == 'NCHW':
+            indata = indata.transpose([0, 3, 1, 2])
+            kernel = kernel.transpose([2, 0, 1])
+            out = out.transpose([0, 3, 1, 2])
+        return indata, kernel, out
+
+    image = [[[[.1], [.2]], [[.3], [.4]]]]
+    kernel = [[[.4], [.3]], [[.1], [.0]]]
+    out = [[[[.5]]]]
+    run_test_dilation2d(*_convert_data(image, kernel, out, layout='NCHW'))
+    run_test_dilation2d(*_convert_data(image, kernel, out), data_layout='NHWC', kernel_layout='HWI')
+
+    image = [[[[.1], [.2]], [[.3], [.4]]]]
+    kernel = [[[.4], [.3]], [[.1], [.0]]]
+    out = [[[[.5], [.6]], [[.7], [.8]]]]
+    run_test_dilation2d(*_convert_data(image, kernel, out, layout='NCHW'), padding=[0, 0, 1, 1])
+    run_test_dilation2d(*_convert_data(image, kernel, out), padding=[0, 0, 1, 1],
+                        data_layout='NHWC', kernel_layout='HWI')
+
+    image = [[[[.1, .2, .0], [.2, .3, .1]], [[.3, .4, .2], [.4, .5, .3]]]]
+    kernel = [[[.4, .5, .3], [.3, .4, .2]], [[.1, .2, .0], [.0, .1, -.1]]]
+    out = [[[[.5, .7, .3], [.6, .8, .4]], [[.7, .9, .5], [.8, 1., .6]]]]
+    run_test_dilation2d(*_convert_data(image, kernel, out, layout='NCHW'), padding=[0, 0, 1, 1])
+    run_test_dilation2d(*_convert_data(image, kernel, out), padding=[0, 0, 1, 1],
+                        data_layout='NHWC', kernel_layout='HWI')
+
+    image = [[[[.1], [.2]], [[.3], [.4]]], [[[.2], [.3]], [[.4], [.5]]]]
+    kernel = [[[.4], [.3]], [[.1], [.0]]]
+    out = [[[[.5], [.6]], [[.7], [.8]]], [[[.6], [.7]], [[.8], [.9]]]]
+    run_test_dilation2d(*_convert_data(image, kernel, out, layout='NCHW'), padding=[0, 0, 1, 1])
+    run_test_dilation2d(*_convert_data(image, kernel, out), padding=[0, 0, 1, 1],
+                        data_layout='NHWC', kernel_layout='HWI')
+
+    image = [[[[.1], [.2]], [[.3], [.4]]]]
+    kernel = [[[.4], [.3]]]
+    out = [[[[.5]], [[.7]]]]
+    run_test_dilation2d(*_convert_data(image, kernel, out, layout='NCHW'))
+    run_test_dilation2d(*_convert_data(image, kernel, out),
+                        data_layout='NHWC', kernel_layout='HWI')
+
+    image = [[[[.1], [.2], [.3]], [[.4], [.5], [.6]], [[.7], [.8], [.9]]]]
+    kernel = [[[.4], [.3]], [[.1], [.2]]]
+    out = [[[[.7], [.8], [.6]], [[1.0], [1.1], [.9]], [[.8], [.9], [.9]]]]
+    run_test_dilation2d(*_convert_data(image, kernel, out, layout='NCHW'), padding=[1, 1], dilations=[2, 2])
+    run_test_dilation2d(*_convert_data(image, kernel, out), padding=[1, 1], dilations=[2, 2],
+                        data_layout='NHWC', kernel_layout='HWI')
+
+    image = [[[[.1], [.2], [.3], [.4]], [[.5], [.6], [.7], [.8]],
+              [[.9], [1.0], [1.1], [1.2]]]]
+    kernel = [[[.4], [.3]], [[.1], [.2]]]
+    out = [[[[.8], [1.0]], [[1.2], [1.4]]]]
+    run_test_dilation2d(*_convert_data(image, kernel, out, layout='NCHW'), strides=[1, 2])
+    run_test_dilation2d(*_convert_data(image, kernel, out), strides=[1, 2],
+                        data_layout='NHWC', kernel_layout='HWI')
+
+
 def test_bitserial_conv2d_infer_type():
     # Basic shape test with ambiguous batch.
     n, c, h, w = te.size_var("n"), 32, 224, 224
@@ -1274,3 +1381,5 @@ if __name__ == "__main__":
     test_upsampling3d()
     test_conv2d_int8_intrinsics()
     test_depthwise_conv2d_int8()
+    test_dilation2d_infer_type()
+    test_dilation2d_run()
diff --git a/topi/python/topi/generic/nn.py b/topi/python/topi/generic/nn.py
index 25b5501..7177d04 100644
--- a/topi/python/topi/generic/nn.py
+++ b/topi/python/topi/generic/nn.py
@@ -648,3 +648,33 @@ def schedule_batch_matmul(outs):
         The computation schedule for the op.
     """
     return _default_schedule(outs, False)
+
+
+def schedule_dilation2d_nchw(outs):
+    """Schedule for dilation2d
+    Parameters
+    ----------
+    outs : Array of Tensor
+        The computation graph description of dilation2d
+        in the format of an array of tensors.
+    Returns
+    -------
+    sch : Schedule
+        The computation schedule for the op.
+    """
+    return _default_schedule(outs, False)
+
+
+def schedule_dilation2d_nhwc(outs):
+    """Schedule for dilation2d
+    Parameters
+    ----------
+    outs : Array of Tensor
+        The computation graph description of dilation2d
+        in the format of an array of tensors.
+    Returns
+    -------
+    sch : Schedule
+        The computation schedule for the op.
+    """
+    return _default_schedule(outs, False)
diff --git a/topi/python/topi/nn/__init__.py b/topi/python/topi/nn/__init__.py
index bd806b9..1b1067f 100644
--- a/topi/python/topi/nn/__init__.py
+++ b/topi/python/topi/nn/__init__.py
@@ -24,6 +24,7 @@ from .conv2d import *
 from .conv3d import *
 from .deformable_conv2d import *
 from .depthwise_conv2d import *
+from .dilation2d import *
 from .elemwise import *
 from .dilate import *
 from .flatten import *
diff --git a/topi/python/topi/nn/dilation2d.py b/topi/python/topi/nn/dilation2d.py
new file mode 100644
index 0000000..9cb4284
--- /dev/null
+++ b/topi/python/topi/nn/dilation2d.py
@@ -0,0 +1,165 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name, unused-variable, too-many-locals
+# pylint: disable=unused-argument, redefined-builtin
+"""Dilation2D operators"""
+from __future__ import absolute_import as _abs
+from tvm import te
+from topi.util import simplify
+from .pad import pad
+from .util import get_pad_tuple
+
+
+def dilation2d_nchw(input, filter, stride, padding, dilations, out_dtype=None):
+    """Dilation2D operator in NCHW layout.
+
+    Parameters
+    ----------
+    input : tvm.Tensor
+        4-D with shape [batch, in_channel, in_height, in_width]
+
+    filter : tvm.Tensor
+        3-D with shape [ in_channel, filter_height, filter_width]
+
+    stride : int or a list/tuple of two ints
+        Stride size, or [stride_height, stride_width]
+
+    padding : int or str
+        Padding size
+
+    dilations: int or a list/tuple of two ints
+        dilation size, or [dilation_height, dilation_width]
+
+    out_dtype : Optional[str]
+        Specifies the output data type.
+
+    Returns
+    -------
+    Output : tvm.Tensor
+        4-D with shape [batch, in_channel, out_height, out_width]
+    """
+    if out_dtype is None:
+        out_dtype = input.dtype
+    assert isinstance(stride, int) or len(stride) == 2
+    assert isinstance(dilations, int) or len(dilations) == 2
+    if isinstance(stride, int):
+        stride_h = stride_w = stride
+    else:
+        stride_h, stride_w = stride
+
+    if isinstance(dilations, int):
+        dilation_h = dilation_w = dilations
+    else:
+        dilation_h, dilation_w = dilations
+
+    batch, in_channel, in_height, in_width = input.shape
+    channel, kernel_h, kernel_w = filter.shape
+    assert in_channel.value == channel.value, \
+        "For Dilation2D input and filter channels should be same."
+
+    # compute the output shape
+    dilated_kernel_h = (kernel_h - 1) * dilation_h + 1
+    dilated_kernel_w = (kernel_w - 1) * dilation_w + 1
+    pad_top, pad_left, pad_down, pad_right = get_pad_tuple(
+        padding, (dilated_kernel_h, dilated_kernel_w))
+
+    out_height = simplify((in_height - dilated_kernel_h + pad_top + pad_down) // stride_h + 1)
+    out_width = simplify((in_width - dilated_kernel_w + pad_left + pad_right) // stride_w + 1)
+    # compute graph
+    pad_before = [0, 0, pad_top, pad_left]
+    pad_after = [0, 0, pad_down, pad_right]
+    temp = pad(input, pad_before, pad_after, name="pad_temp")
+    ry = te.reduce_axis((0, kernel_h), name='ry')
+    rx = te.reduce_axis((0, kernel_w), name='rx')
+
+    return te.compute(
+        (batch, in_channel, out_height, out_width),
+        lambda nn, ff, yy, xx: te.max(
+            temp[nn, ff, yy * stride_h + ry * dilation_h,
+                 xx * stride_w + rx * dilation_w].astype(out_dtype) +
+            filter[ff, ry, rx].astype(out_dtype),
+            axis=[ry, rx]), tag="dilation2d_nchw")
+
+
+def dilation2d_nhwc(input, filter, stride, padding, dilations, out_dtype=None):
+    """Dilation2D operator in NHWC layout.
+
+    Parameters
+    ----------
+    input : tvm.Tensor
+        4-D with shape [batch, in_height, in_width, in_channel]
+
+    filter : tvm.Tensor
+        3-D with shape [filter_height, filter_width, in_channel]
+
+    stride : int or a list/tuple of two ints
+        Stride size, or [stride_height, stride_width]
+
+    padding : int
+        Padding size
+
+    dilations: int or a list/tuple of two ints
+        dilation size, or [dilation_height, dilation_width]
+
+    out_dtype : Optional[str]
+        Specifies the output data type.
+
+    Returns
+    -------
+    Output : tvm.Tensor
+        4-D with shape [batch, out_height, out_width, in_channel]
+    """
+    if out_dtype is None:
+        out_dtype = input.dtype
+    assert isinstance(stride, int) or len(stride) == 2
+    assert isinstance(dilations, int) or len(dilations) == 2
+    if isinstance(stride, int):
+        stride_h = stride_w = stride
+    else:
+        stride_h, stride_w = stride
+
+    if isinstance(dilations, int):
+        dilation_h = dilation_w = dilations
+    else:
+        dilation_h, dilation_w = dilations
+
+    batch, in_height, in_width, in_channel = input.shape
+    kernel_h, kernel_w, channel = filter.shape
+    assert in_channel.value == channel.value, \
+        "For Dilation2D input and filter channels should be same."
+
+    # compute the output shape
+    dilated_kernel_h = (kernel_h - 1) * dilation_h + 1
+    dilated_kernel_w = (kernel_w - 1) * dilation_w + 1
+    pad_top, pad_left, pad_down, pad_right = get_pad_tuple(
+        padding, (dilated_kernel_h, dilated_kernel_w))
+
+    out_height = simplify((in_height - dilated_kernel_h + pad_top + pad_down) // stride_h + 1)
+    out_width = simplify((in_width - dilated_kernel_w + pad_left + pad_right) // stride_w + 1)
+    pad_before = [0, pad_top, pad_left, 0]
+    pad_after = [0, pad_down, pad_right, 0]
+    padded_input = pad(input, pad_before, pad_after, name="padded_input")
+    ry = te.reduce_axis((0, kernel_h), name='ry')
+    rx = te.reduce_axis((0, kernel_w), name='rx')
+
+    return te.compute(
+        (batch, out_height, out_width, in_channel),
+        lambda nn, yy, xx, ff: te.max(
+            padded_input[nn, yy * stride_h + ry * dilation_h,
+                         xx * stride_w + rx * dilation_w, ff].astype(out_dtype) +
+            filter[ry, rx, ff].astype(out_dtype),
+            axis=[ry, rx]), tag="dilation2d_nhcw")