Posted to commits@tvm.apache.org by ma...@apache.org on 2020/03/15 20:43:58 UTC

[incubator-tvm] branch master updated: [Relay, TOPI] Refactor Adaptive pool and add 3d support (#5049)

This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git


The following commit(s) were added to refs/heads/master by this push:
     new 7c5ff50  [Relay, TOPI] Refactor Adaptive pool and add 3d support (#5049)
7c5ff50 is described below

commit 7c5ff50873e91e9ad27b5f08847c27d58e8b5c4c
Author: masahi <ma...@gmail.com>
AuthorDate: Mon Mar 16 05:43:48 2020 +0900

    [Relay, TOPI] Refactor Adaptive pool and add 3d support (#5049)
    
    * add stub for nd impl
    
    * refactored indices compute
    
    * refactored divide step
    
    * remove unused variables, add doc
    
    * fix lint
    
    * add relay op def
    
    * add python registration
    
    * refactor topi test
    
    * update relay tests, but test result is weird
    
    * workaround for weird bug
    
    * add relay adaptive pool 3d test
    
    * add topi tests
    
    * update doc for 3d
    
    * typo fix
    
    * fix lint
    
    * add more tests including NDHWC
---
 include/tvm/relay/attrs/nn.h                     |  15 ++
 python/tvm/relay/op/nn/_nn.py                    |  10 ++
 python/tvm/relay/op/nn/nn.py                     |  92 ++++++++++++
 src/relay/op/nn/pooling.cc                       | 175 ++++++++++++++++++++++-
 tests/python/relay/test_op_level10.py            |  54 ++++---
 topi/include/topi/nn/pooling.h                   | 106 ++++++++------
 topi/python/topi/nn/pooling.py                   |  10 ++
 topi/python/topi/testing/__init__.py             |   1 +
 topi/python/topi/testing/adaptive_pool_python.py | 111 ++++++++++++++
 topi/src/topi.cc                                 |   7 +
 topi/tests/python/test_topi_pooling.py           |  44 +++---
 11 files changed, 528 insertions(+), 97 deletions(-)
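
For context, the new operators can be exercised from Relay roughly as follows. This is a minimal sketch, not part of the commit; the input shape and output_size values are illustrative only:

    import numpy as np
    import tvm
    from tvm import relay

    # NCDHW input; adaptive pooling picks kernel/stride to hit the requested output size.
    x = relay.var("x", shape=(1, 16, 32, 32, 32), dtype="float32")
    y = relay.nn.adaptive_avg_pool3d(x, output_size=(2, 4, 4), layout="NCDHW")
    func = relay.Function([x], y)

    data = np.random.uniform(size=(1, 16, 32, 32, 32)).astype("float32")
    intrp = relay.create_executor("graph", ctx=tvm.cpu(0), target="llvm")
    out = intrp.evaluate(func)(data)
    print(out.asnumpy().shape)  # expected (1, 16, 2, 4, 4)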

diff --git a/include/tvm/relay/attrs/nn.h b/include/tvm/relay/attrs/nn.h
index 5620feb..9a73358 100644
--- a/include/tvm/relay/attrs/nn.h
+++ b/include/tvm/relay/attrs/nn.h
@@ -528,6 +528,21 @@ struct AdaptivePool2DAttrs : public tvm::AttrsNode<AdaptivePool2DAttrs> {
   }
 };
 
+struct AdaptivePool3DAttrs : public tvm::AttrsNode<AdaptivePool3DAttrs> {
+  Array<IndexExpr> output_size;
+  std::string layout;
+
+  TVM_DECLARE_ATTRS(AdaptivePool3DAttrs, "relay.attrs.AdaptivePool3DAttrs") {
+    TVM_ATTR_FIELD(output_size).set_default(Array<IndexExpr>({}))
+      .describe("Output depth, height and width.");
+    TVM_ATTR_FIELD(layout).set_default("NCDHW")
+      .describe("Dimension ordering of data and weight. Can be 'NCDHW', 'NDHWC', etc."
+                  "'N', 'C', 'D', 'H', 'W' stands for batch, channel, depth, height, and width"
+                  "dimensions respectively. Convolution is applied on 'D', 'H' and"
+                  "'W' dimensions.");
+  }
+};
+
 
 /*! \brief Attributes for 1D max pool operator */
 struct MaxPool1DAttrs : public tvm::AttrsNode<MaxPool1DAttrs> {
diff --git a/python/tvm/relay/op/nn/_nn.py b/python/tvm/relay/op/nn/_nn.py
index c522ef9..c641c9d 100644
--- a/python/tvm/relay/op/nn/_nn.py
+++ b/python/tvm/relay/op/nn/_nn.py
@@ -257,6 +257,16 @@ reg.register_schedule("nn.adaptive_avg_pool2d", strategy.schedule_adaptive_pool)
 reg.register_pattern("nn.adaptive_avg_pool2d", OpPattern.OUT_ELEMWISE_FUSABLE)
 
 
+# adaptive_max_pool3d
+reg.register_schedule("nn.adaptive_max_pool3d", strategy.schedule_adaptive_pool)
+reg.register_pattern("nn.adaptive_max_pool3d", OpPattern.OUT_ELEMWISE_FUSABLE)
+
+
+# adaptive_avg_pool3d
+reg.register_schedule("nn.adaptive_avg_pool3d", strategy.schedule_adaptive_pool)
+reg.register_pattern("nn.adaptive_avg_pool3d", OpPattern.OUT_ELEMWISE_FUSABLE)
+
+
 # leaky_relu
 reg.register_broadcast_schedule("nn.leaky_relu")
 reg.register_pattern("nn.leaky_relu", OpPattern.ELEMWISE)
diff --git a/python/tvm/relay/op/nn/nn.py b/python/tvm/relay/op/nn/nn.py
index 30918a4..c62b1cf 100644
--- a/python/tvm/relay/op/nn/nn.py
+++ b/python/tvm/relay/op/nn/nn.py
@@ -2371,3 +2371,95 @@ def adaptive_avg_pool2d(data,
     """
     output_size = [] or output_size
     return _make.adaptive_avg_pool2d(data, output_size, layout)
+
+
+def adaptive_max_pool3d(data,
+                        output_size=None,
+                        layout="NCDHW"):
+    r"""3D adaptive max pooling operator. This operator is experimental.
+
+    This operator takes data as input and does 3D max value calculation
+    across each window represented by DxHxW.
+
+    In the default case, where the data_layout is `NCDHW`, a data Tensor with
+    shape `(batch_size, in_channels, depth, height, width)` is pooled to produce
+    an output Tensor with shape
+    `(batch_size, in_channels, output_depth, output_height, output_width)`.
+
+    The pooling kernel and stride sizes are automatically chosen for
+    desired output sizes.
+
+    For output_size:
+        If this argument is not provided, input depth, height and width will be used
+        as output depth, height and width.
+
+        If a single integer is provided for output_size, the output size is
+        (N x C x output_size x output_size x output_size) for any input (NCDHW).
+
+        If a tuple of integers (depth, height, width) are provided for output_size,
+        the output size is (N x C x depth x height x width) for any input (NCDHW).
+
+    Parameters
+    ----------
+    data : tvm.relay.Expr
+        The input data to the operator.
+
+    output_size : tuple of int, optional
+        Output depth, height and width.
+
+    layout : str, optional
+        Layout of the input.
+
+    Returns
+    -------
+    result : tvm.relay.Expr
+        The computed result.
+    """
+    output_size = [] or output_size
+    return _make.adaptive_max_pool3d(data, output_size, layout)
+
+
+def adaptive_avg_pool3d(data,
+                        output_size=None,
+                        layout="NCDHW"):
+    r"""3D adaptive avg pooling operator. This operator is experimental.
+
+    This operator takes data as input and does 3D average value calculation
+    across each window represented by DxHxW.
+
+    In the default case, where the data_layout is `NCDHW`, a data Tensor with
+    shape `(batch_size, in_channels, depth, height, width)` is pooled to produce
+    an output Tensor with shape
+    `(batch_size, in_channels, output_depth, output_height, output_width)`.
+
+    The pooling kernel and stride sizes are automatically chosen for
+    desired output sizes.
+
+    For output_size:
+        If this argument is not provided, input depth, height and width will be used
+        as output depth, height and width.
+
+        If a single integer is provided for output_size, the output size is
+        (N x C x output_size x output_size x output_size) for any input (NCDHW).
+
+        If a tuple of integers (depth, height, width) are provided for output_size,
+        the output size is (N x C x depth x height x width) for any input (NCDHW).
+
+    Parameters
+    ----------
+    data : tvm.relay.Expr
+        The input data to the operator.
+
+    output_size : tuple of int, optional
+        Output depth, height and width.
+
+    layout : str, optional
+        Layout of the input.
+
+    Returns
+    -------
+    result : tvm.relay.Expr
+        The computed result.
+    """
+    output_size = [] or output_size
+    return _make.adaptive_avg_pool3d(data, output_size, layout)
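
The output_size handling documented above (None, a single int, or a 3-tuple) can be summarized by the following pure-Python sketch; resolve_output_size is a hypothetical helper introduced only to illustrate the three cases:

    def resolve_output_size(output_size, in_depth, in_height, in_width):
        # None/empty: output matches the input depth, height and width.
        if not output_size:
            return (in_depth, in_height, in_width)
        # Single int: the same size is used for depth, height and width.
        if isinstance(output_size, int):
            return (output_size,) * 3
        # 3-tuple: (depth, height, width) taken as given.
        assert len(output_size) == 3
        return tuple(output_size)

    print(resolve_output_size(None, 32, 32, 32))       # (32, 32, 32)
    print(resolve_output_size(4, 32, 32, 32))          # (4, 4, 4)
    print(resolve_output_size((2, 4, 8), 32, 32, 32))  # (2, 4, 8)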
diff --git a/src/relay/op/nn/pooling.cc b/src/relay/op/nn/pooling.cc
index 2b02c82..b18ed90 100644
--- a/src/relay/op/nn/pooling.cc
+++ b/src/relay/op/nn/pooling.cc
@@ -537,7 +537,6 @@ RELAY_REGISTER_OP("nn.adaptive_avg_pool2d")
                                PoolInferCorrectLayout<AdaptivePool2DAttrs>)
 .set_attr<FTVMCompute>("FTVMCompute", AdaptivePool2DCompute<topi::nn::kAvgPool>);
 
-
 // relay.nn.adaptive_max_pool2d
 Expr MakeAdaptiveMaxPool2D(Expr data,
                            Array<IndexExpr> output_size,
@@ -577,6 +576,180 @@ RELAY_REGISTER_OP("nn.adaptive_max_pool2d")
 .set_attr<FTVMCompute>("FTVMCompute", AdaptivePool2DCompute<topi::nn::kMaxPool>);
 
 
+TVM_REGISTER_NODE_TYPE(AdaptivePool3DAttrs);
+
+bool AdaptivePool3DRel(const Array<Type>& types,
+                       int num_inputs,
+                       const Attrs& attrs,
+                       const TypeReporter& reporter) {
+  CHECK_EQ(types.size(), 2);
+  const auto* data = types[0].as<TensorTypeNode>();
+  if (data == nullptr) { return false; }
+  const auto dshape = data->shape;
+  CHECK_GE(dshape.size(), 3U)
+    << "Pool3D only support input >= 3-D: input must have depth, height and width";
+  const auto* param = attrs.as<AdaptivePool3DAttrs>();
+  CHECK(param != nullptr);
+
+  Layout layout(param->layout);
+  CHECK(layout.Contains(LayoutAxis::Get('D')) && layout.Contains(LayoutAxis::Get('H')) &&
+        layout.Contains(LayoutAxis::Get('W')) && !layout.Contains(LayoutAxis::Get('d')) &&
+       !layout.Contains(LayoutAxis::Get('h')) && !layout.Contains(LayoutAxis::Get('w')))
+    << "Invalid layout " << layout
+    << ". Pool3D layout must have D, H and W, which cannot be split";
+
+  const auto didx = layout.IndexOf(LayoutAxis::Get('D'));
+  const auto hidx = layout.IndexOf(LayoutAxis::Get('H'));
+  const auto widx = layout.IndexOf(LayoutAxis::Get('W'));
+  Array<IndexExpr> oshape(dshape);
+  auto output_size = param->output_size;
+  CHECK_LE(output_size.size(), 3U)
+    << "output_size can have up to 3 elements.";
+  IndexExpr output_depth, output_height, output_width;
+  if (output_size.empty()) {
+    output_depth = dshape[didx];
+    output_height = dshape[hidx];
+    output_width = dshape[widx];
+  } else if (output_size.size() == 1) {
+    output_depth = output_size[0];
+    output_height = output_size[0];
+    output_width = output_size[0];
+  } else {
+    output_depth = output_size[0];
+    output_height = output_size[1];
+    output_width = output_size[2];
+  }
+
+  oshape.Set(didx, output_depth);
+  oshape.Set(hidx, output_height);
+  oshape.Set(widx, output_width);
+
+  // assign output type
+  reporter->Assign(types[1], TensorType(oshape, data->dtype));
+  return true;
+}
+
+template<topi::nn::PoolType mode>
+Array<te::Tensor> AdaptivePool3DCompute(const Attrs& attrs,
+                                        const Array<te::Tensor>& inputs,
+                                        const Type& out_type) {
+  static const Layout kNCDHW("NCDHW");
+  const auto* param = attrs.as<AdaptivePool3DAttrs>();
+  CHECK(param != nullptr);
+  Layout layout(param->layout);
+  CHECK(BijectiveLayoutNode::make(layout, kNCDHW).defined())
+    << "Adaptive pool3d currently only supports layouts that are convertible from NCDHW";
+  CHECK_EQ(layout.IndexOf(LayoutAxis::Get('d')), -1)
+    << "Adaptive pool3d does not support input split on depth";
+  CHECK_EQ(layout.IndexOf(LayoutAxis::Get('h')), -1)
+    << "Adaptive pool3d does not support input split on height";
+  CHECK_EQ(layout.IndexOf(LayoutAxis::Get('w')), -1)
+    << "Adaptive pool3d does not support input split on width";
+
+  CHECK(inputs[0].ndim() == 5U || inputs[0].ndim() == 6U)
+    << "Pool3D only support 5-D input (e.g., NCDHW)"
+    << " or 6-D input (last dimension is a split of channel)";
+
+  auto output_size = param->output_size;
+  const auto didx = layout.IndexOf(LayoutAxis::Get('D'));
+  const auto hidx = layout.IndexOf(LayoutAxis::Get('H'));
+  const auto widx = layout.IndexOf(LayoutAxis::Get('W'));
+  IndexExpr output_depth, output_height, output_width;
+  if (output_size.empty()) {
+    output_depth = inputs[0]->shape[didx];
+    output_height = inputs[0]->shape[hidx];
+    output_width = inputs[0]->shape[widx];
+  } else if (output_size.size() == 1) {
+    output_depth = output_size[0];
+    output_height = output_size[0];
+    output_width = output_size[0];
+  } else {
+    output_depth = output_size[0];
+    output_height = output_size[1];
+    output_width = output_size[2];
+  }
+
+  auto osize = Array<IndexExpr>{ output_depth, output_height, output_width };
+  return Array<te::Tensor> {
+    topi::nn::adaptive_pool3d(inputs[0], osize,  mode, layout.name())
+  };
+}
+
+// relay.nn.adaptive_max_pool3d
+Expr MakeAdaptiveMaxPool3D(Expr data,
+                           Array<IndexExpr> output_size,
+                           std::string layout) {
+  auto attrs = make_object<AdaptivePool3DAttrs>();
+  attrs->output_size = std::move(output_size);
+  attrs->layout = std::move(layout);
+  static const Op& op = Op::Get("nn.adaptive_max_pool3d");
+  return CallNode::make(op, {data}, Attrs(attrs), {});
+}
+
+TVM_REGISTER_GLOBAL("relay.op.nn._make.adaptive_max_pool3d")
+.set_body_typed(MakeAdaptiveMaxPool3D);
+
+RELAY_REGISTER_OP("nn.adaptive_max_pool3d")
+  .describe(R"code(Adaptive max pooling operation for 3D data.
+
+- **data**: This depends on the `layout` parameter. Input is 5D array of shape
+            (batch_size, channels, depth, height, width) if `layout` is `NCDHW`.
+- **output_size**: If this argument is not provided, input depth, height and width will be used
+                   as output depth, height and width.
+                   If a single integer is provided for output_size, the output size is
+                   (N x C x output_size x output_size x output_size) for any input (NCDHW).
+                   If a tuple of integers (depth, height, width) are provided for output_size,
+                   the output size is (N x C x depth x height x width) for any input (NCDHW).
+- **out**: This depends on the `layout` parameter. Output is 5D array of shape
+           (batch_size, channels, output_depth, output_height, output_width)  if `layout` is `NCDHW`.
+
+)code" TVM_ADD_FILELINE)
+.set_attrs_type<AdaptivePool3DAttrs>()
+.set_num_inputs(1)
+.add_argument("data", "Tensor", "The input tensor.")
+.set_support_level(10)
+.add_type_rel("AdaptiveMaxPool3D", AdaptivePool3DRel)
+.set_attr<FInferCorrectLayout>("FInferCorrectLayout",
+                               PoolInferCorrectLayout<AdaptivePool3DAttrs>)
+.set_attr<FTVMCompute>("FTVMCompute", AdaptivePool3DCompute<topi::nn::kMaxPool>);
+
+// relay.nn.adaptive_avg_pool3d
+Expr MakeAdaptiveAvgPool3D(Expr data,
+                           Array<IndexExpr> output_size,
+                           std::string layout) {
+  auto attrs = make_object<AdaptivePool3DAttrs>();
+  attrs->output_size = std::move(output_size);
+  attrs->layout = std::move(layout);
+  static const Op& op = Op::Get("nn.adaptive_avg_pool3d");
+  return CallNode::make(op, {data}, Attrs(attrs), {});
+}
+
+TVM_REGISTER_GLOBAL("relay.op.nn._make.adaptive_avg_pool3d")
+.set_body_typed(MakeAdaptiveAvgPool3D);
+
+RELAY_REGISTER_OP("nn.adaptive_avg_pool3d")
+  .describe(R"code(Adaptive avg pooling operation for 3D data.
+- **data**: This depends on the `layout` parameter. Input is 5D array of shape
+            (batch_size, channels, depth, height, width) if `layout` is `NCDHW`.
+- **output_size**: If this argument is not provided, input depth, height and width will be used
+                   as output depth, height and width.
+                   If a single integer is provided for output_size, the output size is
+                   (N x C x output_size x output_size x output_size) for any input (NCDHW).
+                   If a tuple of integers (depth, height, width) are provided for output_size,
+                   the output size is (N x C x depth x height x width) for any input (NCDHW).
+- **out**: This depends on the `layout` parameter. Output is 5D array of shape
+           (batch_size, channels, output_depth, output_height, output_width)  if `layout` is `NCDHW`.
+)code" TVM_ADD_FILELINE)
+.set_attrs_type<AdaptivePool3DAttrs>()
+.set_num_inputs(1)
+.add_argument("data", "Tensor", "The input tensor.")
+.set_support_level(10)
+.add_type_rel("AdaptiveAvgPool3D", AdaptivePool3DRel)
+.set_attr<FInferCorrectLayout>("FInferCorrectLayout",
+                               PoolInferCorrectLayout<AdaptivePool3DAttrs>)
+.set_attr<FTVMCompute>("FTVMCompute", AdaptivePool3DCompute<topi::nn::kAvgPool>);
+
+
 bool Pool2DGradRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
                    const TypeReporter& reporter) {
   CHECK_EQ(types.size(), 3);
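
AdaptivePool3DRel above only overwrites the D, H and W entries of the input shape, so type inference can be checked from Python with the standard InferType pass. A minimal sketch, with illustrative shapes:

    import tvm
    from tvm import relay

    x = relay.var("x", shape=(1, 16, 32, 32, 32), dtype="float32")
    y = relay.nn.adaptive_max_pool3d(x, output_size=(2, 4, 4), layout="NCDHW")
    mod = tvm.IRModule.from_expr(relay.Function([x], y))
    mod = relay.transform.InferType()(mod)
    print(mod["main"].ret_type)  # expected Tensor[(1, 16, 2, 4, 4), float32]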
diff --git a/tests/python/relay/test_op_level10.py b/tests/python/relay/test_op_level10.py
index 2e6ed62..953760c 100644
--- a/tests/python/relay/test_op_level10.py
+++ b/tests/python/relay/test_op_level10.py
@@ -349,46 +349,43 @@ def test_ndarray_size():
     verify_ndarray_size((2, 3, 5))
     verify_ndarray_size((2, 3, 5, 7))
 
-def verify_adaptive_pool2d(dshape, out_size, pool_type, layout="NCHW", dtype="float32"):
-    def start_index(index, odim, idim):
-        return int(np.floor(index * idim / odim))
-
-    def end_index(index, odim, idim):
-        return int(np.ceil((index + 1) * idim / odim))
-
-    np_data = np.random.uniform(low=0, high=255, size=dshape).astype(dtype)
-    n, c, h, w = dshape
-    oh, ow = out_size
-    oshape = (n, c) + out_size
-    np_out = np.zeros(oshape).astype(dtype)
-    np_op = np.mean if pool_type == "avg" else np.max
-    for i in range(n):
-        for j in range(c):
-            for k in range(oh):
-                k_start = start_index(k, oh, h)
-                k_end = end_index(k, oh, h)
-                k_sl = slice(k_start, k_end)
-                for l in range(ow):
-                    l_start = start_index(l, ow, w)
-                    l_end = end_index(l, ow, w)
-                    l_sl = slice(l_start, l_end)
-                    np_out[i, j, k, l] = np_op(np_data[i, j, k_sl, l_sl])
 
-    opfunc = relay.nn.adaptive_avg_pool2d if pool_type == "avg" else relay.nn.adaptive_max_pool2d
-    x = relay.var("x", relay.TensorType((n, c, h, w), "float32"))
+def verify_adaptive_pool(dshape, out_size, pool_type, layout, dtype, opfunc):
+    x = relay.var("x", relay.TensorType(dshape, "float32"))
     y = opfunc(x, out_size, layout)
     func = relay.Function([x], y)
 
+    np_data = np.random.uniform(low=0, high=255, size=dshape).astype(dtype)
+    np_out = topi.testing.adaptive_pool(np_data, out_size, pool_type, layout)
+
     for target, ctx in ctx_list():
         intrp1 = relay.create_executor("graph", ctx=ctx, target=target)
         relay_out = intrp1.evaluate(func)(np_data)
         tvm.testing.assert_allclose(relay_out.asnumpy(), np_out, rtol=1e-5, atol=1e-5)
 
-def test_adaptive_pool2d():
+
+def verify_adaptive_pool2d(dshape, out_size, pool_type, layout="NCHW", dtype="float32"):
+    opfunc = relay.nn.adaptive_avg_pool2d if pool_type == "avg" else relay.nn.adaptive_max_pool2d
+    verify_adaptive_pool(dshape, out_size, pool_type, layout, dtype, opfunc)
+
+
+def verify_adaptive_pool3d(dshape, out_size, pool_type, layout="NCHW", dtype="float32"):
+    opfunc = relay.nn.adaptive_avg_pool3d if pool_type == "avg" else relay.nn.adaptive_max_pool3d
+    verify_adaptive_pool(dshape, out_size, pool_type, layout, dtype, opfunc)
+
+
+def test_adaptive_pool():
     verify_adaptive_pool2d((1, 9, 224, 224), (1, 1), "max")
     verify_adaptive_pool2d((1, 3, 224, 224), (2, 3), "avg")
     verify_adaptive_pool2d((1, 14, 56, 78), (34, 13), "max")
     verify_adaptive_pool2d((1, 5, 46, 97), (4, 96), "avg")
+    verify_adaptive_pool2d((1, 224, 224, 3), (1, 1), "max", layout="NHWC")
+    verify_adaptive_pool2d((1, 3, 224, 224), (2, 3), "avg", layout="NHWC")
+    verify_adaptive_pool3d((1, 16, 32, 32, 32), (1, 1, 1), "max", layout="NCDHW")
+    verify_adaptive_pool3d((1, 16, 32, 32, 32), (1, 1, 1), "avg", layout="NCDHW")
+    verify_adaptive_pool3d((1, 16, 32, 32, 32), (1, 1, 1), "avg", layout="NDHWC")
+    verify_adaptive_pool3d((1, 16, 32, 32, 32), (2, 4, 4), "max", layout="NDHWC")
+
 
 def test_sequence_mask():
     def _verify(data_shape, mask_value, axis, dtype, itype):
@@ -453,7 +450,7 @@ def test_one_hot():
     _verify((3, 2, 4, 5), 6, 1.0, 0.0, 0, "float32")
 
 if __name__ == "__main__":
-    test_adaptive_pool2d()
+    test_adaptive_pool()
     test_collapse_sum_like()
     test_broadcast_to_like()
     test_slice_like()
@@ -463,4 +460,3 @@ if __name__ == "__main__":
     test_sequence_mask()
     test_ndarray_size()
     test_one_hot()
-
diff --git a/topi/include/topi/nn/pooling.h b/topi/include/topi/nn/pooling.h
index e6947ed..20b7b24 100644
--- a/topi/include/topi/nn/pooling.h
+++ b/topi/include/topi/nn/pooling.h
@@ -491,72 +491,72 @@ inline PrimExpr end_index(const Var& out_index,
 }
 
 /*!
-* \brief Perform adaptive pooling on height and width dimension of data.
+* \brief Perform adaptive pooling on N dimensional data
 *
 * \param x The input tensor
-* \param output_size Vector of two ints: {output_height, output_width}
+* \param output_size Vector of ints giving the output size of each pooled dimension
 * \param pool_type The type of pooling operator
-* \param height_axis index of the height dimension
-* \param width_axis index of the width dimension
+* \param axes Indices of the pooled dimensions
 *
 * \return The output tensor in same layout order
 */
 inline Tensor adaptive_pool_impl(const Tensor& x,
                                  const Array<PrimExpr>& output_size,
                                  PoolType pool_type,
-                                 const size_t height_axis,
-                                 const size_t width_axis) {
-  CHECK_EQ(output_size.size(), 2) << "Pooling kernel_size must have 2 elements";
+                                 const std::vector<int>& axes) {
+  const auto n_dim = output_size.size();
+  CHECK_EQ(axes.size(), n_dim) << "The number of axes must equal the number of pooled output dimensions";
 
-  auto height = x->shape[height_axis];
-  auto width = x->shape[width_axis];
-
-  auto out_height = cast(DataType::Int(32), output_size[0]);
-  auto out_width = cast(DataType::Int(32), output_size[1]);
   Array<PrimExpr> out_shape = x->shape;
-  out_shape.Set(height_axis, out_height);
-  out_shape.Set(width_axis, out_width);
+  Array<PrimExpr> in_size, out_size;
+  for (size_t i = 0; i < n_dim; ++i) {
+    in_size.push_back(x->shape[axes[i]]);
+    out_size.push_back(cast(DataType::Int(32), output_size[i]));
+    out_shape.Set(axes[i], out_size[i]);
+  }
+
+  auto get_iter_vars = [=](const Array<Var>& output, bool reduce_indices) {
+    Array<PrimExpr> indices;
+    for (size_t i = 0; i < output.size(); ++i) indices.push_back(output[i]);
+    Array<tir::IterVar> reduce_axes;
+    for (size_t i = 0; i < n_dim; ++i) {
+      auto i_start = start_index(output[axes[i]], out_size[i], in_size[i]);
+      auto i_end = end_index(output[axes[i]], out_size[i], in_size[i]);
+      auto rv_name = "rv" + std::to_string(i);
+      auto rv_axis = tvm::te::reduce_axis(Range(0, i_end - i_start), rv_name);
+      reduce_axes.push_back(rv_axis);
+      if (reduce_indices) {
+        indices.Set(axes[i], i_start + rv_axis);
+      }
+    }
+    return std::make_tuple(indices, reduce_axes);
+  };
 
   if (pool_type == kMaxPool) {
     return tvm::te::compute(out_shape, [&](const Array<Var>& output) {
       Array<PrimExpr> indices;
-      for (const Var& var : output) indices.push_back(var);
-      auto i_start_h = start_index(output[height_axis], out_height, height);
-      auto i_end_h = end_index(output[height_axis], out_height, height);
-      auto i_start_w = start_index(output[width_axis], out_width, width);
-      auto i_end_w = end_index(output[width_axis], out_width, width);
-      auto dheight = tvm::te::reduce_axis(Range(0, i_end_h - i_start_h), "rv1");
-      auto dwidth = tvm::te::reduce_axis(Range(0, i_end_w - i_start_w), "rv2");
-      indices.Set(height_axis, i_start_h + dheight);
-      indices.Set(width_axis, i_start_w + dwidth);
-      return tvm::max(x(indices), { dheight, dwidth });  // NOLINT(*)
+      Array<tir::IterVar> reduce_axes;
+      std::tie(indices, reduce_axes) = get_iter_vars(output, true);
+      return tvm::max(x(indices), reduce_axes);  // NOLINT(*)
     }, "tensor", "adaptive_pool_max");
   } else if (pool_type == kAvgPool) {
     auto pool_sum = tvm::te::compute(out_shape, [&](const Array<Var>& output) {
       Array<PrimExpr> indices;
-      for (const Var& var : output) indices.push_back(var);
-      auto i_start_h = start_index(output[height_axis], out_height, height);
-      auto i_end_h = end_index(output[height_axis], out_height, height);
-      auto i_start_w = start_index(output[width_axis], out_width, width);
-      auto i_end_w = end_index(output[width_axis], out_width, width);
-      auto divide_factor = tvm::cast(x->dtype, (i_end_h - i_start_h)
-                                               * (i_end_w - i_start_w));
-      auto dheight = tvm::te::reduce_axis(Range(0, i_end_h - i_start_h), "rv1");
-      auto dwidth = tvm::te::reduce_axis(Range(0, i_end_w - i_start_w), "rv2");
-      indices.Set(height_axis, i_start_h + dheight);
-      indices.Set(width_axis, i_start_w + dwidth);
-      return tvm::sum(x(indices), { dheight, dwidth });
+      Array<tir::IterVar> reduce_axes;
+      std::tie(indices, reduce_axes) = get_iter_vars(output, true);
+      return tvm::sum(x(indices), reduce_axes);
     }, "tensor", "adaptive_pool_sum");
 
     return tvm::te::compute(out_shape, [&](const Array<Var>& output) {
       Array<PrimExpr> indices;
-      for (const Var& var : output) indices.push_back(var);
-      auto i_start_h = start_index(output[height_axis], out_height, height);
-      auto i_end_h = end_index(output[height_axis], out_height, height);
-      auto i_start_w = start_index(output[width_axis], out_width, width);
-      auto i_end_w = end_index(output[width_axis], out_width, width);
-      auto divide_factor = tvm::cast(x->dtype, (i_end_h - i_start_h)
-                                               * (i_end_w - i_start_w));
+      Array<tir::IterVar> reduce_axes;
+      std::tie(indices, reduce_axes) = get_iter_vars(output, false);
+
+      PrimExpr divide_factor = tvm::cast(x->dtype, 1);
+      for (size_t i = 0; i < n_dim; ++i) {
+        divide_factor *= tvm::cast(x->dtype, reduce_axes[i]->dom->extent);
+      }
+
       return div(pool_sum(indices), divide_factor);
     }, "tensor", kElementWise);
   } else {
@@ -598,7 +598,25 @@ inline Tensor adaptive_pool(const Tensor& x,
   int height_axis = -1, width_axis = -1;
   CHECK(find_height_width(layout, &height_axis, &width_axis))
     << "Unsupported layout " << layout;
-  return adaptive_pool_impl(x, output_size, pool_type, height_axis, width_axis);
+  return adaptive_pool_impl(x, output_size, pool_type, {height_axis, width_axis});
+}
+
+/*!
+* \brief Adaptively perform pooling on three dimensional data.
+*        See the two dimensional version above for details.
+* \param x The input tensor
+* \param output_size Vector of three ints: {output_depth, output_height, output_width}
+* \param pool_type The type of pooling operator
+* \param layout The input layout. The default is "NCDHW".
+*/
+inline Tensor adaptive_pool3d(const Tensor& x,
+                              const Array<PrimExpr>& output_size,
+                              PoolType pool_type,
+                              const std::string& layout = "NCDHW") {
+  int depth_axis = -1, height_axis = -1, width_axis = -1;
+  CHECK(find_depth_height_width(layout, &depth_axis, &height_axis, &width_axis))
+    << "Unsupported layout " << layout;
+  return adaptive_pool_impl(x, output_size, pool_type, {depth_axis, height_axis, width_axis});
 }
 
 /*!
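
The refactored adaptive_pool_impl derives each pooling window from the start_index/end_index helpers, i.e. the floor and ceiling of the output index scaled by in_dim/out_dim. A small numeric sketch of that window math, independent of TVM:

    import math

    def start_index(out_idx, out_dim, in_dim):
        return math.floor(out_idx * in_dim / out_dim)

    def end_index(out_idx, out_dim, in_dim):
        return math.ceil((out_idx + 1) * in_dim / out_dim)

    # Splitting an input extent of 10 into 4 adaptive windows:
    print([(start_index(i, 4, 10), end_index(i, 4, 10)) for i in range(4)])
    # [(0, 3), (2, 5), (5, 8), (7, 10)]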
diff --git a/topi/python/topi/nn/pooling.py b/topi/python/topi/nn/pooling.py
index e3d57ce..52317c2 100644
--- a/topi/python/topi/nn/pooling.py
+++ b/topi/python/topi/nn/pooling.py
@@ -218,6 +218,16 @@ def adaptive_pool(data,
     return cpp.nn.adaptive_pool(data, output_size, POOL_TYPE_CODE[pool_type], layout)
 
 
+def adaptive_pool3d(data,
+                    output_size,
+                    pool_type,
+                    layout="NCDHW"):
+    """Perform pooling on three dimensional data.
+       See the two dimensional version above for details.
+    """
+    return cpp.nn.adaptive_pool3d(data, output_size, POOL_TYPE_CODE[pool_type], layout)
+
+
 def pool1d(data,
            kernel,
            stride,
diff --git a/topi/python/topi/testing/__init__.py b/topi/python/topi/testing/__init__.py
index b0f4752..36c460e 100644
--- a/topi/python/topi/testing/__init__.py
+++ b/topi/python/topi/testing/__init__.py
@@ -55,3 +55,4 @@ from .space_to_depth import space_to_depth_python
 from .crop_and_resize_python import crop_and_resize_python
 from .common import get_injective_schedule, get_reduce_schedule, get_broadcast_schedule, \
     get_elemwise_schedule, get_conv2d_nchw_implement, dispatch
+from .adaptive_pool_python import adaptive_pool
diff --git a/topi/python/topi/testing/adaptive_pool_python.py b/topi/python/topi/testing/adaptive_pool_python.py
new file mode 100644
index 0000000..3f464ce
--- /dev/null
+++ b/topi/python/topi/testing/adaptive_pool_python.py
@@ -0,0 +1,111 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name, unused-argument, unused-variable
+"""adaptive pool in python"""
+import numpy as np
+
+
+def _start_index(index, odim, idim):
+    return int(np.floor(index * idim / odim))
+
+
+def _end_index(index, odim, idim):
+    return int(np.ceil((index + 1) * idim / odim))
+
+
+def _pool2d(in_size, out_size, np_data, np_op):
+    out = np.zeros(out_size).astype(np_data.dtype)
+    oh, ow = out_size
+    for k in range(oh):
+        k_start = _start_index(k, oh, in_size[0])
+        k_end = _end_index(k, oh, in_size[0])
+        k_sl = slice(k_start, k_end)
+        for l in range(ow):
+            l_start = _start_index(l, ow, in_size[1])
+            l_end = _end_index(l, ow, in_size[1])
+            l_sl = slice(l_start, l_end)
+            out[k, l] = np_op(np_data[k_sl, l_sl])
+    return out
+
+
+def _pool3d(in_size, out_size, np_data, np_op):
+    out = np.zeros(out_size).astype(np_data.dtype)
+    od, oh, ow = out_size
+    for m in range(od):
+        m_start = _start_index(m, od, in_size[0])
+        m_end = _end_index(m, od, in_size[0])
+        m_sl = slice(m_start, m_end)
+        for k in range(oh):
+            k_start = _start_index(k, oh, in_size[1])
+            k_end = _end_index(k, oh, in_size[1])
+            k_sl = slice(k_start, k_end)
+            for l in range(ow):
+                l_start = _start_index(l, ow, in_size[2])
+                l_end = _end_index(l, ow, in_size[2])
+                l_sl = slice(l_start, l_end)
+                out[m, k, l] = np_op(np_data[m_sl, k_sl, l_sl])
+    return out
+
+
+def adaptive_pool_nchw(np_data, out_size, pool_op, np_op):
+    """ The reference function for adaptive pool, nchw layout """
+    ishape = np_data.shape
+    n, c = ishape[:2]
+    oshape = (n, c) + out_size
+    np_out = np.zeros(oshape).astype(np_data.dtype)
+
+    for i in range(n):
+        for j in range(c):
+            np_out[i, j] = pool_op(ishape[2:], out_size, np_data[i, j], np_op)
+
+    return np_out
+
+
+def adaptive_pool_nhwc(np_data, out_size, pool_op, np_op):
+    """ The reference function for adaptive pool, nhwc layout """
+    ishape = np_data.shape
+    n, c = ishape[0], ishape[-1]
+    oshape = (n,) + out_size + (c,)
+    np_out = np.zeros(oshape).astype(np_data.dtype)
+
+    for i in range(n):
+        for j in range(c):
+            if len(out_size) == 2:
+                np_out[i, :, :, j] = pool_op(ishape[1:-1], out_size,
+                                             np_data[i, :, :, j], np_op)
+            else:
+                np_out[i, :, :, :, j] = pool_op(ishape[1:-1], out_size,
+                                                np_data[i, :, :, :, j], np_op)
+
+    return np_out
+
+
+def adaptive_pool(np_data, out_size, pool_type, layout):
+    """ The reference function for adaptive pool, for 2d and 3d """
+    if len(out_size) == 2:
+        pool_op = _pool2d
+    else:
+        assert len(out_size) == 3
+        pool_op = _pool3d
+
+    np_op = np.mean if pool_type == "avg" else np.max
+
+    if layout in ["NCHW", "NCDHW"]:
+        return adaptive_pool_nchw(np_data, out_size, pool_op, np_op)
+
+    assert layout in ["NHWC", "NDHWC"]
+    return adaptive_pool_nhwc(np_data, out_size, pool_op, np_op)
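
A brief usage sketch for this numpy reference (illustrative shapes), of the kind the refactored Relay and TOPI tests below rely on:

    import numpy as np
    import topi.testing

    data = np.random.uniform(size=(1, 4, 8, 8, 8)).astype("float32")
    ref = topi.testing.adaptive_pool(data, (2, 2, 2), "avg", "NCDHW")
    print(ref.shape)  # expected (1, 4, 2, 2, 2)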
diff --git a/topi/src/topi.cc b/topi/src/topi.cc
index add01c2..5581f2b 100644
--- a/topi/src/topi.cc
+++ b/topi/src/topi.cc
@@ -549,6 +549,13 @@ TVM_REGISTER_GLOBAL("topi.nn.adaptive_pool")
                           args[3]);
 });
 
+TVM_REGISTER_GLOBAL("topi.nn.adaptive_pool3d")
+.set_body([](TVMArgs args, TVMRetValue *rv) {
+  *rv = nn::adaptive_pool3d(args[0], args[1],
+                            static_cast<nn::PoolType>(static_cast<int>(args[2])),
+                            args[3]);
+});
+
 TVM_REGISTER_GLOBAL("topi.nn.pool1d")
 .set_body([](TVMArgs args, TVMRetValue *rv) {
   *rv = nn::pool1d(args[0], args[1], args[2], args[3],
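
This packed function is what the Python wrapper topi.nn.adaptive_pool3d reaches through cpp.nn.adaptive_pool3d. A compute-level sketch mirroring the TOPI test below (shapes illustrative; a default schedule is assumed here rather than the tuned topi schedules):

    import tvm
    from tvm import te
    import topi

    data = te.placeholder((1, 16, 32, 32, 32), name="data", dtype="float32")
    out = topi.nn.adaptive_pool3d(data, (2, 4, 4), "avg", layout="NCDHW")
    s = te.create_schedule(out.op)
    f = tvm.build(s, [data, out], "llvm")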
diff --git a/topi/tests/python/test_topi_pooling.py b/topi/tests/python/test_topi_pooling.py
index 64f0841..9bdbb10 100644
--- a/topi/tests/python/test_topi_pooling.py
+++ b/topi/tests/python/test_topi_pooling.py
@@ -244,33 +244,19 @@ def test_global_pool():
     verify_global_pool(1, 1024, 7, 7, 'max', 'NHWC')
     verify_global_pool(4, 1024, 7, 7, 'max', 'NHWC')
 
-def verify_adaptive_pool(dshape, out_size, pool_type, layout="NCHW", dtype="float32"):
-    def start_index(index, odim, idim):
-        return int(np.floor(index * idim / odim))
-
-    def end_index(index, odim, idim):
-        return int(np.ceil((index + 1) * idim / odim))
 
+def verify_adaptive_pool(dshape, out_size, pool_type, layout="NCHW", dtype="float32"):
     np_data = np.random.uniform(low=0, high=255, size=dshape).astype(dtype)
-    n, c, h, w = dshape
-    oh, ow = out_size
-    oshape = (n, c) + out_size
-    np_out = np.zeros(oshape).astype(dtype)
-    np_op = np.mean if pool_type == "avg" else np.max
-    for i in range(n):
-        for j in range(c):
-            for k in range(oh):
-                k_start = start_index(k, oh, h)
-                k_end = end_index(k, oh, h)
-                k_sl = slice(k_start, k_end)
-                for l in range(ow):
-                    l_start = start_index(l, ow, w)
-                    l_end = end_index(l, ow, w)
-                    l_sl = slice(l_start, l_end)
-                    np_out[i, j, k, l] = np_op(np_data[i, j, k_sl, l_sl])
+    np_out = topi.testing.adaptive_pool(np_data, out_size, pool_type, layout)
+    oshape = np_out.shape
 
     data = te.placeholder(dshape, name="data", dtype=dtype)
-    out = topi.nn.adaptive_pool(data, out_size, pool_type, layout)
+    if len(out_size) == 2:
+        out = topi.nn.adaptive_pool(data, out_size, pool_type, layout)
+    else:
+        assert len(out_size) == 3
+        out = topi.nn.adaptive_pool3d(data, out_size, pool_type, layout)
+
     def check_device(device):
         ctx = tvm.context(device, 0)
         if not ctx.exist:
@@ -289,11 +275,23 @@ def verify_adaptive_pool(dshape, out_size, pool_type, layout="NCHW", dtype="floa
     for device in get_all_backend():
         check_device(device)
 
+
 def test_adaptive_pool():
     verify_adaptive_pool((1, 3, 224, 224), (1, 1), "max")
     verify_adaptive_pool((1, 3, 224, 224), (1, 1), "avg")
     verify_adaptive_pool((1, 14, 56, 78), (34, 13), "max")
     verify_adaptive_pool((1, 5, 46, 97), (4, 96), "avg")
+    verify_adaptive_pool((1, 224, 224, 3), (1, 1), "max", layout="NHWC")
+    verify_adaptive_pool((1, 5, 46, 97), (4, 96), "avg", layout="NHWC")
+    verify_adaptive_pool((1, 16, 32, 32, 32), (1, 1, 1), "max", layout="NCDHW")
+    verify_adaptive_pool((1, 16, 32, 32, 32), (1, 1, 1), "avg", layout="NCDHW")
+    verify_adaptive_pool((1, 16, 32, 32, 32), (2, 2, 2), "avg", layout="NCDHW")
+    verify_adaptive_pool((1, 16, 64, 32, 32), (7, 8, 9), "avg", layout="NCDHW")
+    verify_adaptive_pool((1, 16, 64, 32, 32), (8, 16, 16), "avg", layout="NCDHW")
+    verify_adaptive_pool((1, 16, 32, 32, 32), (1, 1, 1), "avg", layout="NDHWC")
+    verify_adaptive_pool((1, 16, 32, 32, 32), (2, 2, 2), "max", layout="NDHWC")
+    verify_adaptive_pool((1, 16, 32, 32, 32), (2, 4, 4), "max", layout="NDHWC")
+
 
 def verify_pool3d(n, ic, ih, kh, sh, padding, pool_type,
                   ceil_mode, count_include_pad=True, layout='NCDHW'):