You are viewing a plain text version of this content. The link to its canonical version is not preserved in this plain-text copy.
Posted to commits@mxnet.apache.org by jx...@apache.org on 2017/08/16 17:32:16 UTC

[incubator-mxnet] branch master updated: add depthwise convolution's gpu version optimization (#7393)

This is an automated email from the ASF dual-hosted git repository.

jxie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new d7d31b2  add depthwise convolution's gpu version optimization (#7393)
d7d31b2 is described below

commit d7d31b2d9ce2eee98a3b9f41bc8c526a3125ce78
Author: shuqian.qu <qs...@gmail.com>
AuthorDate: Thu Aug 17 01:32:14 2017 +0800

    add depthwise convolution's gpu version optimization (#7393)
    
    * add depthwise convolution's gpu version optimization
    
    * add more config for test_depthwise_convolution
    
    * remove CUDA_1D_KERNEL_LOOP
    
    * fix windows compiling error
    
    * add support for kAddTo when calculating input's backward
    
    * remove depthwise_conv_off params
    
    * Update convolution.cu
    
    * Update test_operator.py
---
 src/common/cuda_utils.h                   |  28 +-
 src/operator/convolution.cu               |  14 +
 src/operator/depthwise_convolution-inl.h  | 349 +++++++++++++++
 src/operator/depthwise_convolution_tf.cuh | 703 ++++++++++++++++++++++++++++++
 tests/python/unittest/test_operator.py    |  38 ++
 5 files changed, 1131 insertions(+), 1 deletion(-)

diff --git a/src/common/cuda_utils.h b/src/common/cuda_utils.h
index 8897007..483390f 100644
--- a/src/common/cuda_utils.h
+++ b/src/common/cuda_utils.h
@@ -153,6 +153,16 @@ inline const char* CurandGetErrorString(curandStatus_t status) {
   return "Unknown cuRAND status";
 }
 
+/*!
+ * \brief Device-side maximum of two values.
+ * \return a when (a > b) holds, otherwise b.
+ */
+template <typename DType>
+__device__ inline DType CudaMax(DType a, DType b) {
+  if (a > b) return a;
+  return b;
+}
+
+/*!
+ * \brief Device-side minimum of two values.
+ * \return a when (a < b) holds, otherwise b.
+ */
+template <typename DType>
+__device__ inline DType CudaMin(DType a, DType b) {
+  if (a < b) return a;
+  return b;
+}
+
 }  // namespace cuda
 }  // namespace common
 }  // namespace mxnet
@@ -219,6 +229,14 @@ inline const char* CurandGetErrorString(curandStatus_t status) {
         << "cuRAND: " << common::cuda::CurandGetErrorString(e); \
   }
 
+// Loop-unrolling hints for device loops. `#pragma unroll 1` is the form the
+// CUDA C Programming Guide documents for disabling unrolling and clang accepts
+// it as well; `#pragma nounroll` is a clang extension not documented for nvcc.
+// Under MSVC both hints expand to nothing.
+#if !defined(_MSC_VER)
+#define CUDA_UNROLL _Pragma("unroll")
+#define CUDA_NOUNROLL _Pragma("unroll 1")
+#else
+#define CUDA_UNROLL
+#define CUDA_NOUNROLL
+#endif
+
 /*!
  * \brief Determine major version number of the gpu's cuda compute architecture.
  * \param device_id The device index of the cuda-capable gpu of interest.
@@ -291,7 +309,6 @@ inline bool GetEnvAllowTensorCore() {
   return dmlc::GetEnv("MXNET_CUDA_ALLOW_TENSOR_CORE",
                       dmlc::optional<bool>(default_value)).value();
 }
-
 #endif  // MXNET_USE_CUDA
 
 #if MXNET_USE_CUDNN
@@ -401,6 +418,15 @@ static inline __device__ void atomicAdd(mshadow::half::half_t *address,
     old = atomicCAS(address_as_ui, assumed, old);
   } while (assumed != old);
 }
+
+/*!
+ * \brief Loads *address. When compiling device code for compute capability
+ *        3.5 or newer the load goes through __ldg; otherwise it is a plain
+ *        dereference.
+ */
+template <typename DType>
+__device__ inline DType ldg(const DType* address) {
+#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 350)
+  return __ldg(address);
+#else
+  return address[0];
+#endif
+}
 #endif
 
 #endif  // MXNET_COMMON_CUDA_UTILS_H_
diff --git a/src/operator/convolution.cu b/src/operator/convolution.cu
index ab35484..f5777c1 100644
--- a/src/operator/convolution.cu
+++ b/src/operator/convolution.cu
@@ -29,6 +29,8 @@
 #include "./cudnn_convolution-inl.h"
 #endif  // MXNET_USE_CUDNN
 
+#include "./depthwise_convolution-inl.h"
+
 namespace mxnet {
 namespace op {
 
@@ -45,6 +47,18 @@ Operator* CreateOp<gpu>(ConvolutionParam param, int dtype,
     })
     return op;
   }
+
+  // depth wise conv
+  if (param.num_filter == param.num_group &&
+      param.layout.value() == mshadow::kNCHW &&
+      param.num_filter == (*in_shape)[conv::kData][1] &&
+      param.kernel.ndim() == 2 &&
+      param.dilate == mshadow::Shape2(1, 1) &&
+      dtype == mshadow::kFloat32) {
+    op = new DepthwiseConvolutionOp<float>(param, *in_shape, *out_shape);
+    return op;
+  }
+
 #if MXNET_USE_CUDNN == 1
   // The NVIDIA Pascal architecture was the first to include 16-bit ALUs.
   // Thus, when the framework is compiled with MSHADOW_USE_PASCAL == 1, we
diff --git a/src/operator/depthwise_convolution-inl.h b/src/operator/depthwise_convolution-inl.h
new file mode 100644
index 0000000..5beea45
--- /dev/null
+++ b/src/operator/depthwise_convolution-inl.h
@@ -0,0 +1,349 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file depthwise_convolution-inl.h
+ * \brief CUDA depthwise convolution code
+ * \author shuqian.qu@hobot.cc
+*/
+#ifndef MXNET_OPERATOR_DEPTHWISE_CONVOLUTION_INL_H_
+#define MXNET_OPERATOR_DEPTHWISE_CONVOLUTION_INL_H_
+#include <algorithm>
+#include <vector>
+#include "./convolution-inl.h"
+#include "../common/cuda_utils.h"
+
+#if MXNET_USE_CUDA
+#include <cub/cub.cuh>
+#include "./depthwise_convolution_tf.cuh"
+
+namespace mxnet {
+namespace op {
+using namespace tf::depthwise_conv;
+/*!
+ * \brief GPU operator for 2-D depthwise convolution in NCHW layout.
+ *
+ * The convolution geometry is captured once from the inferred shapes in the
+ * constructor; Forward/Backward reuse it on every call.
+ */
+template<typename DType>
+class DepthwiseConvolutionOp : public Operator {
+ public:
+  explicit DepthwiseConvolutionOp(const ConvolutionParam& param,
+                                  const std::vector<TShape>& in_shape,
+                                  const std::vector<TShape>& out_shape) {
+    const TShape& data_shape = in_shape[conv::kData];      // read as (N, C, H, W)
+    const TShape& weight_shape = in_shape[conv::kWeight];  // only dims 2/3 (kernel H/W) read
+    const TShape& output_shape = out_shape[conv::kOut];    // read as (N, C, H, W)
+    args_.batch = data_shape[0];
+    args_.in_channel = data_shape[1];
+    args_.in_height = data_shape[2];
+    args_.in_width = data_shape[3];
+    args_.filter_height = weight_shape[2];
+    args_.filter_width = weight_shape[3];
+    args_.stride_height = param.stride[0];
+    args_.stride_width = param.stride[1];
+    args_.pad_height = param.pad[0];
+    args_.pad_width = param.pad[1];
+    args_.out_channel = output_shape[1];
+    args_.out_height = output_shape[2];
+    args_.out_width = output_shape[3];
+    bias_term_ = !param.no_bias;
+  }
+
+  ~DepthwiseConvolutionOp() {}
+
+  virtual void Forward(const OpContext &ctx,
+                       const std::vector<TBlob> &in_data,
+                       const std::vector<OpReqType> &req,
+                       const std::vector<TBlob> &out_data,
+                       const std::vector<TBlob> &aux_args);
+
+  virtual void Backward(const OpContext &ctx,
+                        const std::vector<TBlob> &out_grad,
+                        const std::vector<TBlob> &in_data,
+                        const std::vector<TBlob> &out_data,
+                        const std::vector<OpReqType> &req,
+                        const std::vector<TBlob> &in_grad,
+                        const std::vector<TBlob> &aux_args);
+
+ private:
+  DepthwiseArgs args_;  // geometry passed by value to every kernel launch
+  bool bias_term_;      // true unless param.no_bias
+};  // class DepthwiseConvolutionOp
+
+namespace depthwise_conv {
+namespace cuda {
+/*!
+ * \brief Accumulates the filter gradient of a depthwise convolution (NCHW)
+ *        into filter_grad.
+ *
+ * Each block processes one (batch, channel) slice at a time, striding over
+ * slices by gridDim.x. Within a slice the block walks over the output pixels
+ * in chunks of blockDim.x. For every filter tap, the per-thread products
+ * input * out_grad are summed with cub::BlockReduce, and thread 0 atomically
+ * adds the block total to filter_grad. filter_grad is therefore accumulated
+ * into, never overwritten.
+ *
+ * All threads of a block run the same number of chunk iterations; threads
+ * past the end of the slice contribute zero. This is required because the
+ * loop body contains block-wide collectives (BlockReduce, __syncthreads) that
+ * every thread of the block must reach. Letting only some threads run the
+ * last, partial chunk would make those barriers divergent (undefined behavior).
+ *
+ * Assumes blockDim.x == mshadow::cuda::kBaseThreadNum, the block size that
+ * BlockReduceT is specialized for.
+ */
+template<typename DType, int kFilterWidth, int kFilterHeight>
+__global__ void __launch_bounds__(1024, 2)
+DepthwiseConv2dBackwardFilterKernel(const DepthwiseArgs args,
+                                     const DType* out_grad,
+                                     const DType* input,
+                                     DType* filter_grad) {
+  typedef cub::BlockReduce<DType, mshadow::cuda::kBaseThreadNum> BlockReduceT;
+  __shared__ typename BlockReduceT::TempStorage temp_storage_reduce;
+
+  const int in_height = args.in_height;
+  const int in_width = args.in_width;
+  const int channel = args.in_channel;
+  const int filter_height = kFilterHeight > 0 ? kFilterHeight : args.filter_height;
+  const int filter_width = kFilterWidth > 0 ? kFilterWidth : args.filter_width;
+  const int stride_height = args.stride_height;
+  const int stride_width = args.stride_width;
+  const int pad_height = args.pad_height;
+  const int pad_width = args.pad_width;
+  const int out_height = args.out_height;
+  const int out_width = args.out_width;
+
+  const int filter_pixels = filter_width * filter_height;
+  const int out_pixels = out_height * out_width;
+  const int in_pixels = in_height * in_width;
+  const int batch_channel_num = channel * args.batch;
+
+  for (int b = blockIdx.x; b < batch_channel_num; b += gridDim.x) {
+    const int local_batch = b / channel;
+    const int local_channel = b % channel;
+    const int filter_offset_temp = local_channel * filter_pixels;
+    const int out_grad_offset_temp = (local_batch * channel * out_pixels) +
+        (local_channel * out_pixels);
+    const int input_offset_slice = (local_batch * channel * in_pixels) +
+        (local_channel * in_pixels);
+
+    // Uniform trip count for every thread of the block (see above).
+    for (int out_base = 0; out_base < out_pixels; out_base += blockDim.x) {
+      const int out_id = out_base + threadIdx.x;
+      const bool out_valid = out_id < out_pixels;
+      const int out_w = out_id % out_width;
+      const int out_h = out_id / out_width;
+      const DType out_g = out_valid ? ldg(out_grad + out_grad_offset_temp + out_id)
+                                    : DType(0.0f);
+
+      const int in_h_start = out_h * stride_height - pad_height;
+      const int in_w_start = out_w * stride_width - pad_width;
+      CUDA_UNROLL for (int f_h = 0; f_h < filter_height; ++f_h) {
+        const int in_h = in_h_start + f_h;
+        const bool row_valid = out_valid && in_h >= 0 && in_h < in_height;
+        const int input_offset_temp = input_offset_slice + (in_h * in_width);
+        const int filter_offset_h = filter_width * f_h;
+
+        CUDA_UNROLL for (int f_w = 0; f_w < filter_width; ++f_w) {
+          const int in_w = in_w_start + f_w;
+          DType partial_grad = DType(0.0f);
+          if (row_valid && in_w >= 0 && in_w < in_width) {
+            partial_grad = ldg(input + input_offset_temp + in_w) * out_g;
+          }
+          // Block-wide sum; inactive threads and padding taps contribute 0.
+          const DType aggregate = BlockReduceT(temp_storage_reduce).Sum(partial_grad);
+          if (threadIdx.x == 0) {
+            atomicAdd(filter_grad + filter_offset_temp + filter_offset_h + f_w, aggregate);
+          }
+          // temp_storage_reduce is reused by the next reduction.
+          __syncthreads();
+        }  // for filter_width
+      }  // for filter_height
+    }  // for out_pixels
+    __syncthreads();
+  }  // for batch_channel_num
+}
+}  // namespace cuda
+
+/*!
+ * \brief Launches the depthwise convolution forward pass on `stream`, writing
+ *        out_data[conv::kOut] from in_data[conv::kData] and in_data[conv::kWeight].
+ *
+ * When CanLaunchDepthwiseConv2dGPUSmall(args) holds, the small-image launcher
+ * is used. Otherwise DepthwiseConv2dForwardKernel is launched over every
+ * output element, using its compile-time 3x3 specialization when the filter
+ * is 3x3 and the generic instantiation otherwise.
+ */
+template<typename DType>
+void DepthwiseConv2dForwardGpu(mshadow::Stream<gpu> *stream,
+                               const DepthwiseArgs& args,
+                               const std::vector<TBlob> &in_data,
+                               const std::vector<TBlob> &out_data) {
+  using namespace mshadow;
+  using namespace mshadow::expr;
+  using namespace tf::depthwise_conv;
+  using namespace tf::depthwise_conv::cuda;
+  Tensor<gpu, 4, DType> input = in_data[conv::kData].get<gpu, 4, DType>(stream);
+  Tensor<gpu, 4, DType> filter = in_data[conv::kWeight].get<gpu, 4, DType>(stream);
+  Tensor<gpu, 4, DType> output = out_data[conv::kOut].get<gpu, 4, DType>(stream);
+
+  if (CanLaunchDepthwiseConv2dGPUSmall(args)) {
+    LaunchDepthwiseConv2dGPUSmall<DType, DIRECTION_FORWARD>(
+        stream, args, input.dptr_, filter.dptr_, output.dptr_);
+    return;
+  }
+
+  // Generic path: enough blocks for one thread per output, capped at kMaxGridNum.
+  const int num_output = out_data[conv::kOut].shape_.Size();
+  const int block_num = std::min(num_output / mshadow::cuda::kBaseThreadNum + 1,
+                                 mshadow::cuda::kMaxGridNum);
+  auto cuda_stream = mshadow::Stream<gpu>::GetStream(stream);
+  const bool is_3x3 = args.filter_height == 3 && args.filter_width == 3;
+  if (is_3x3) {
+    DepthwiseConv2dForwardKernel<DType, 3, 3>
+        <<<block_num, mshadow::cuda::kBaseThreadNum, 0, cuda_stream>>>(
+            input.dptr_, filter.dptr_, args, num_output, output.dptr_);
+  } else {
+    DepthwiseConv2dForwardKernel<DType, -1, -1>
+        <<<block_num, mshadow::cuda::kBaseThreadNum, 0, cuda_stream>>>(
+            input.dptr_, filter.dptr_, args, num_output, output.dptr_);
+  }
+  MSHADOW_CUDA_POST_KERNEL_CHECK(DepthwiseConv2dForwardKernel);
+}
+
+/*!
+ * \brief Launches the computation of the data gradient into
+ *        in_grad[conv::kData] from out_grad[conv::kOut] and the filter.
+ *
+ * Uses the small-image launcher in the DIRECTION_BACKWARD direction when
+ * CanLaunchDepthwiseConv2dGPUSmall(args) holds. Otherwise it launches
+ * DepthwiseConv2dBackwardDataKernel, sized for one thread per input-gradient
+ * element and capped at kMaxGridNum blocks.
+ */
+template<typename DType>
+void DepthwiseConv2dBackwardDataGpu(mshadow::Stream<gpu> *stream,
+                                    const DepthwiseArgs& args,
+                                    const std::vector<TBlob> &out_grad,
+                                    const std::vector<TBlob> &in_data,
+                                    const std::vector<TBlob> &in_grad) {
+  using namespace mshadow;
+  using namespace mshadow::expr;
+  using namespace tf::depthwise_conv;
+  using namespace tf::depthwise_conv::cuda;
+  Tensor<gpu, 4, DType> grad_out = out_grad[conv::kOut].get<gpu, 4, DType>(stream);
+  Tensor<gpu, 4, DType> filter = in_data[conv::kWeight].get<gpu, 4, DType>(stream);
+  Tensor<gpu, 4, DType> grad_in = in_grad[conv::kData].get<gpu, 4, DType>(stream);
+
+  if (CanLaunchDepthwiseConv2dGPUSmall(args)) {
+    LaunchDepthwiseConv2dGPUSmall<DType, DIRECTION_BACKWARD>(
+        stream, args, grad_out.dptr_, filter.dptr_, grad_in.dptr_);
+    return;
+  }
+
+  const int num_in_grad = in_grad[conv::kData].shape_.Size();
+  const int block_num = std::min(num_in_grad / mshadow::cuda::kBaseThreadNum + 1,
+                                 mshadow::cuda::kMaxGridNum);
+  auto cuda_stream = mshadow::Stream<gpu>::GetStream(stream);
+  DepthwiseConv2dBackwardDataKernel<DType>
+      <<<block_num, mshadow::cuda::kBaseThreadNum, 0, cuda_stream>>>(
+          args, grad_out.dptr_, filter.dptr_, grad_in.dptr_, num_in_grad);
+  MSHADOW_CUDA_POST_KERNEL_CHECK(DepthwiseConv2dBackwardDataKernel);
+}
+
+/*!
+ * \brief Launches the computation of the filter gradient into
+ *        in_grad[conv::kWeight] from out_grad[conv::kOut] and the input data.
+ *
+ * TryLaunchDepthwiseConv2dBackwardFilterGPUSmall is tried first. When it
+ * returns false (presumably meaning it cannot handle `args`), the function
+ * falls back to DepthwiseConv2dBackwardFilterKernel with one block per
+ * (batch, channel) slice, capped at kMaxGridNum. That kernel adds into the
+ * gradient with atomicAdd, so the caller clears it unless req is kAddTo.
+ */
+template<typename DType>
+void DepthwiseConv2dBackwardFilterGpu(mshadow::Stream<gpu> *stream,
+                                      const DepthwiseArgs& args,
+                                      const std::vector<TBlob> &out_grad,
+                                      const std::vector<TBlob> &in_data,
+                                      const std::vector<TBlob> &in_grad) {
+  using namespace mshadow;
+  using namespace mshadow::expr;
+  using namespace tf::depthwise_conv;
+  Tensor<gpu, 4, DType> out_g = out_grad[conv::kOut].get<gpu, 4, DType>(stream);
+  Tensor<gpu, 4, DType> in_d = in_data[conv::kData].get<gpu, 4, DType>(stream);
+  Tensor<gpu, 4, DType> weight_grad = in_grad[conv::kWeight].get<gpu, 4, DType>(stream);
+  if (TryLaunchDepthwiseConv2dBackwardFilterGPUSmall<DType>(stream, args,
+                                                            out_g.dptr_,
+                                                            in_d.dptr_,
+                                                            weight_grad.dptr_)) {
+    return;
+  }
+
+  // Generic path: one block per (batch, channel) slice.
+  auto s = mshadow::Stream<gpu>::GetStream(stream);
+  const int block_num = std::min(args.out_channel * args.batch, mshadow::cuda::kMaxGridNum);
+  if (args.filter_width == 3 && args.filter_height == 3) {
+    cuda::DepthwiseConv2dBackwardFilterKernel<DType, 3, 3>
+        <<<block_num, mshadow::cuda::kBaseThreadNum, 0, s>>>(args,
+                                                             out_g.dptr_,
+                                                             in_d.dptr_,
+                                                             weight_grad.dptr_);
+  } else {
+    cuda::DepthwiseConv2dBackwardFilterKernel<DType, -1, -1>
+        <<<block_num, mshadow::cuda::kBaseThreadNum, 0, s>>>(args,
+                                                             out_g.dptr_,
+                                                             in_d.dptr_,
+                                                             weight_grad.dptr_);
+  }
+  MSHADOW_CUDA_POST_KERNEL_CHECK(DepthwiseConv2dBackwardFilterKernel);
+}
+}  // namespace depthwise_conv
+
+/*!
+ * \brief Forward pass. Writes the depthwise convolution of in_data[conv::kData]
+ *        with in_data[conv::kWeight] into out_data[conv::kOut]. If the op has
+ *        a bias, it then adds the per-channel bias. Only kWriteTo is accepted
+ *        for the output.
+ */
+template<typename DType>
+void DepthwiseConvolutionOp<DType>::Forward(const OpContext &ctx,
+                                            const std::vector<TBlob> &in_data,
+                                            const std::vector<OpReqType> &req,
+                                            const std::vector<TBlob> &out_data,
+                                            const std::vector<TBlob> &aux_states) {
+  using namespace mshadow;
+  using namespace mshadow::expr;
+  Stream<gpu> *s = ctx.get_stream<gpu>();
+  CHECK_EQ(req[conv::kOut], kWriteTo);
+  depthwise_conv::DepthwiseConv2dForwardGpu<DType>(s, args_, in_data, out_data);
+  if (!bias_term_) {
+    return;
+  }
+  // View the output as (batch, channel, pixels) and broadcast the 1-D bias
+  // along dim 1 (channels).
+  Tensor<gpu, 1, DType> bias = in_data[conv::kBias].get<gpu, 1, DType>(s);
+  Tensor<gpu, 3, DType> out_3d = out_data[conv::kOut].get_with_shape<gpu, 3, DType>(
+      Shape3(args_.batch, args_.out_channel, args_.out_height * args_.out_width), s);
+  out_3d += mshadow::expr::broadcast<1>(bias, out_3d.shape_);
+}
+
+/*!
+ * \brief Backward pass. Depending on req, computes:
+ *   - the data gradient in_grad[conv::kData],
+ *   - the filter gradient in_grad[conv::kWeight],
+ *   - the bias gradient in_grad[conv::kBias] (only when the op has a bias).
+ * The data/filter launchers are treated as adding into their outputs (the
+ * filter kernel uses atomicAdd; the small backward kernel uses +=). Those
+ * buffers are therefore cleared first unless the request is kAddTo.
+ */
+template<typename DType>
+void DepthwiseConvolutionOp<DType>::Backward(const OpContext &ctx,
+                                             const std::vector<TBlob> &out_grad,
+                                             const std::vector<TBlob> &in_data,
+                                             const std::vector<TBlob> &out_data,
+                                             const std::vector<OpReqType> &req,
+                                             const std::vector<TBlob> &in_grad,
+                                             const std::vector<TBlob> &aux_states) {
+  using namespace mshadow;
+  using namespace mshadow::expr;
+  Stream<gpu> *s = ctx.get_stream<gpu>();
+
+  const OpReqType data_req = req[conv::kData];
+  if (data_req != kNullOp) {
+    if (data_req != kAddTo) {
+      Tensor<gpu, 4, DType> data_grad = in_grad[conv::kData].get<gpu, 4, DType>(s);
+      data_grad = 0.0f;
+    }
+    depthwise_conv::DepthwiseConv2dBackwardDataGpu<DType>(s, args_, out_grad,
+                                                          in_data, in_grad);
+  }
+
+  const OpReqType weight_req = req[conv::kWeight];
+  if (weight_req != kNullOp) {
+    if (weight_req != kAddTo) {
+      Tensor<gpu, 4, DType> weight_grad = in_grad[conv::kWeight].get<gpu, 4, DType>(s);
+      weight_grad = 0.0f;
+    }
+    depthwise_conv::DepthwiseConv2dBackwardFilterGpu<DType>(s, args_, out_grad,
+                                                            in_data, in_grad);
+  }
+
+  if (bias_term_) {
+    // Bias gradient: sum of out_grad over every axis except the channel axis.
+    Tensor<gpu, 1, DType> bias_grad = in_grad[conv::kBias].get<gpu, 1, DType>(s);
+    Tensor<gpu, 3, DType> grad_out_3d = out_grad[conv::kOut].get_with_shape<gpu, 3, DType>(
+        Shape3(args_.batch, args_.out_channel, args_.out_height * args_.out_width), s);
+    ASSIGN_DISPATCH(bias_grad, req[conv::kBias], sumall_except_dim<1>(grad_out_3d));
+  }
+}
+}  // namespace op
+}  // namespace mxnet
+#endif
+
+#endif  // MXNET_OPERATOR_DEPTHWISE_CONVOLUTION_INL_H_
diff --git a/src/operator/depthwise_convolution_tf.cuh b/src/operator/depthwise_convolution_tf.cuh
new file mode 100644
index 0000000..a1538b6
--- /dev/null
+++ b/src/operator/depthwise_convolution_tf.cuh
@@ -0,0 +1,703 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file depthwise_convolution_tf.cuh
+ * \brief Depthwise convolution CUDA kernels. The main logic comes from
+ *        TensorFlow, but the filter layout and many argument names differ
+ *        from the original version.
+ * \author shuqian.qu@hobot.cc
+*/
+#ifndef MXNET_OPERATOR_DEPTHWISE_CONVOLUTION_TF_CUH_
+#define MXNET_OPERATOR_DEPTHWISE_CONVOLUTION_TF_CUH_
+#include "../common/cuda_utils.h"
+#include "./mxnet_op.h"
+
+namespace tf {
+namespace depthwise_conv {
+// Geometry of a depthwise convolution. The operator fills it in from the
+// inferred shapes, and it is passed by value to the CUDA kernels.
+struct DepthwiseArgs {
+  // Input layer dimensions
+  int batch;
+  int in_height;
+  int in_width;
+  int in_channel;
+  // Filter size, stride and padding along height / width. Padded positions
+  // contribute nothing to the sums in the kernels.
+  int filter_height;
+  int filter_width;
+  int stride_height;
+  int stride_width;
+  int pad_height;
+  int pad_width;
+
+  // Output layer dimensions
+  int out_height;
+  int out_width;
+  int out_channel;
+};
+
+namespace cuda {
+/*!
+ * \brief Depthwise convolution forward pass (NCHW), one output element per
+ *        CUDA_KERNEL_LOOP iteration.
+ *
+ * \param input        input data, indexed as (batch, in_channel, in_height, in_width).
+ * \param filter       filters; channel c's filter_height x filter_width taps start
+ *                     at c * filter_height * filter_width.
+ * \param args         convolution geometry.
+ * \param num_outputs  number of output elements to compute.
+ * \param output       output, indexed as (batch, out_channel, out_height, out_width);
+ *                     written, not accumulated.
+ * kFilterHeight / kFilterWidth > 0 fix the filter size at compile time;
+ * otherwise args.filter_height / args.filter_width are used.
+ */
+template<typename DType, int kFilterHeight, int kFilterWidth>
+__global__ void __launch_bounds__(1024, 2)
+DepthwiseConv2dForwardKernel(const DType* input,
+                             const DType* filter,
+                             const DepthwiseArgs args,
+                             int num_outputs,
+                             DType* output) {
+  const int in_channel = args.in_channel;
+  const int in_height = args.in_height;
+  const int in_width = args.in_width;
+  const int filter_height = kFilterHeight > 0 ? kFilterHeight : args.filter_height;
+  const int filter_width = kFilterWidth > 0 ? kFilterWidth : args.filter_width;
+  const int stride_height = args.stride_height;
+  const int stride_width = args.stride_width;
+  const int pad_height = args.pad_height;
+  const int pad_width = args.pad_width;
+  const int out_channel = args.out_channel;
+  const int out_height = args.out_height;
+  const int out_width = args.out_width;
+
+  CUDA_KERNEL_LOOP(thread_id, num_outputs) {
+    // Decompose the flat NCHW output index. Consecutive thread ids map to
+    // consecutive out_w, so the output writes of a warp are contiguous.
+    const int out_w = thread_id % out_width;
+    const int out_h = (thread_id / out_width) % out_height;
+    const int out_c = (thread_id / out_width / out_height) % out_channel;
+    const int out_b = thread_id / out_width / out_height / out_channel;
+    // Depth multiplier of 1: output channel c reads only input channel c.
+    const int in_c = out_c;
+
+    // Offsets of the (out_b, in_c) input plane and of channel in_c's filter.
+    const int input_offset_temp = (out_b * in_channel + in_c) * (in_height * in_width);
+    const int filter_offset_temp = in_c * filter_height * filter_width;
+
+    // Input window [input_h_start, input_h_end) x [input_w_start, input_w_end)
+    // covered by the filter for this output pixel.
+    const int input_h_start = out_h * stride_height - pad_height;
+    const int input_w_start = out_w * stride_width - pad_width;
+    const int input_h_end = input_h_start + filter_height;
+    const int input_w_end = input_w_start + filter_width;
+
+    DType sum = 0;
+    // The end bounds are exclusive, so a window ending exactly at the image
+    // border is still fully inside and can skip per-tap bounds checks.
+    if (input_h_start >= 0 && input_w_start >= 0 &&
+        input_h_end <= in_height && input_w_end <= in_width) {
+      CUDA_UNROLL for (int f_h = 0; f_h < filter_height; ++f_h) {
+        const int in_h = input_h_start + f_h;
+        const int filter_offset_h = filter_width * f_h;
+        CUDA_UNROLL for (int f_w = 0; f_w < filter_width; ++f_w) {
+          const int in_w = input_w_start + f_w;
+          const int input_offset = (input_offset_temp) + (in_h * in_width) + in_w;
+          const int filter_offset = filter_offset_temp + filter_offset_h + f_w;
+          sum += ldg(input + input_offset) * ldg(filter + filter_offset);
+        }
+      }
+    } else {
+      // Window overlaps the padding: taps outside the image are skipped.
+      CUDA_UNROLL for (int f_h = 0; f_h < filter_height; ++f_h) {
+        const int in_h = input_h_start + f_h;
+        const int filter_offset_h = filter_width * f_h;
+        CUDA_UNROLL for (int f_w = 0; f_w < filter_width; ++f_w) {
+          const int in_w = input_w_start + f_w;
+          // TODO(vrv): the in_h check can be done outside of this loop;
+          // benchmark both methods to determine the better decision.
+          if (in_h >= 0 && in_h < in_height && in_w >= 0 && in_w < in_width) {
+            const int input_offset = input_offset_temp + (in_h * in_width) + in_w;
+            const int filter_offset = filter_offset_temp + filter_offset_h + f_w;
+            sum += ldg(input + input_offset) * ldg(filter + filter_offset);
+          }
+        }
+      }
+    }
+    output[thread_id] = sum;
+  }
+}
+
+// Template argument selecting what DepthwiseConv2dKernelSmall computes: the
+// forward convolution, or the gradient w.r.t. the input (the same computation
+// with the filter taps read in reverse order).
+enum DepthwiseConv2dDirection {
+  DIRECTION_FORWARD,
+  DIRECTION_BACKWARD
+};
+
+// CUDA kernel computing the depthwise convolution in NCHW format, either the
+// forward pass or (kDirection == DIRECTION_BACKWARD) the gradient w.r.t. the
+// input. The backward case is the same computation with the filter rotated by
+// 180 degrees: the shared copy of the filter is read back to front.
+// Intended for small images; only use this kernel if
+// CanLaunchDepthwiseConv2dGPUSmall(args) returns true. The exact limits live in
+// that function; the original TF comment said "up to 32x32".
+//
+// Launch layout, as read by the code below:
+//   threadIdx.x -> image column (presumably blockDim.x == in_width),
+//   threadIdx.y -> image row in the first block_height rows (presumably
+//                  blockDim.y == (in_height + 1) / 2),
+//   threadIdx.z -> which of the kBlockSlices (batch * channel) slices.
+// Each block loops over groups of kBlockSlices slices with stride gridDim.x.
+// Dynamic shared memory holds kBlockSlices zero-initialized, padded input
+// tiles (tile_size elements), followed by the filters of those slices.
+// Each thread handles two elements per iteration: its own row, and the row
+// block_height below it. skip_second drops the second element for the last
+// thread row when in_height is odd.
+// The output is indexed with the same offsets as the input, so this kernel
+// presumes the output has the input's spatial size.
+// In the backward direction the results are added (+=) to `output`.
+template <typename DType, DepthwiseConv2dDirection kDirection,
+          int kBlockSlices, bool kEvenHeight, int kFilterHeight, int kFilterWidth>
+__global__ __launch_bounds__(1024, 2) void DepthwiseConv2dKernelSmall(
+    const DepthwiseArgs args, const DType* input, const DType* filter, DType* output) {
+  extern __shared__ __align__(sizeof(DType)) unsigned char shared_memory[];
+  DType* const shared_data = reinterpret_cast<DType*>(shared_memory);
+
+  const int in_height = args.in_height;
+  const int in_width = args.in_width;
+  const int in_channel = args.in_channel;
+  const int filter_height = kFilterHeight > 0 ? kFilterHeight : args.filter_height;
+  const int filter_width = kFilterWidth > 0 ? kFilterWidth : args.filter_width;
+  const int pad_height = args.pad_height;
+  const int pad_width = args.pad_width;
+
+  // Rows covered by one pass of the thread rows (blockDim.y). Thread row y
+  // handles image rows y and y + block_height.
+  const int block_height = blockDim.y;
+
+  // These values are the same for all threads and could
+  // be precomputed on the CPU.
+  const int block_pixels = in_width * block_height;
+  const int block_size = block_pixels * kBlockSlices;
+  const int in_pixels = in_width * in_height;
+  // Added after each filter row so that filter_width + in_increment == tile_width,
+  // i.e. the read position moves to the same column of the next tile row.
+  const int in_increment = in_width - 1;
+  const int filter_pixels = filter_height * filter_width;
+  const int tile_width = in_width + filter_width - 1;
+  // even_height is 1 when in_height is even (or known even at compile time).
+  // For odd heights the tile gets one extra zero row, presumably so that the
+  // unused second-element reads of the last thread row stay inside the tile.
+  const int even_height = kEvenHeight || (1 & ~in_height);
+  const int tile_height = in_height + filter_height - even_height;
+  const int tile_pixels = tile_width * tile_height;
+  const int tile_size = tile_pixels * kBlockSlices;
+  const int tile_offset = block_height * tile_width;
+  const int pad_offset = pad_height * tile_width + pad_width;
+  const int in_slices = in_channel * args.batch;
+  const int in_blocks = (in_slices + kBlockSlices - 1) / kBlockSlices;
+
+  const int thread_width = threadIdx.x;
+  const int thread_height = threadIdx.y;
+  const int thread_channel = threadIdx.z;
+
+  // Position in block.
+  const int thread_pix = thread_height * in_width + thread_width;
+  const int thread_idx = thread_channel * block_pixels + thread_pix;
+
+  // Initialize tile, in particular the padding.
+  for (int i = thread_idx; i < tile_size; i += block_size) {
+    shared_data[i] = DType(0);
+  }
+  __syncthreads();
+
+  // Position in tensors.
+  const int tensor_idx = thread_channel * in_pixels + thread_pix;
+
+  // Position in (padded) shared memory.
+  const int data_pix = thread_height * tile_width + thread_width;
+  const int data_idx = thread_channel * tile_pixels + data_pix;
+
+  // Position in shared memory, offset by pad_height / pad_width.
+  const int tile_idx = data_idx + pad_offset;
+
+  // The first filter_pixels threads of each slice also load that slice's filter.
+  const int filter_pix = thread_pix;
+  const int filter_channel = thread_channel;
+  const int filter_idx = filter_pixels * filter_channel + filter_pix;
+
+  const int max_slice = in_slices - thread_channel;
+  // 0 means "this thread loads no filter value" (a real offset is >= tile_size).
+  const int filter_write_offset = filter_pix < filter_pixels ? tile_size + filter_idx : 0;
+  // Backward starts one past this slice's filter and pre-decrements (reverse read).
+  const int filter_read_offset = tile_size +
+    (kDirection == DIRECTION_FORWARD ?
+     filter_pixels * filter_channel : filter_pixels * (filter_channel + 1));
+  const bool skip_second = !kEvenHeight && thread_height + (in_height & 1) == block_height;
+
+  for (int b = blockIdx.x; b < in_blocks; b += gridDim.x) {
+    const int slice = b * kBlockSlices;
+
+    const int inout_offset = slice * in_pixels + tensor_idx;
+    const bool slice_in_range = slice < max_slice;
+
+    // Stage this thread's two input elements into the padded tile.
+    if (slice_in_range) {
+      const DType* const in_ptr = inout_offset + input;
+      DType* const tile_ptr = tile_idx + shared_data;
+      tile_ptr[0] = ldg(in_ptr);
+      if (!skip_second) {
+        tile_ptr[tile_offset] = ldg(block_pixels + in_ptr);
+      }
+    }
+
+    // (slice + filter_channel) % in_channel is the channel of this slice.
+    if (filter_write_offset != 0) {
+      const int filter_offset = ((slice + filter_channel) % in_channel)* filter_pixels + filter_pix;
+      shared_data[filter_write_offset] = ldg(filter_offset + filter);
+    }
+
+    // Note: the condition to reach this is uniform across the entire block.
+    __syncthreads();
+
+    if (slice_in_range) {
+      DType sum1 = 0;
+      DType sum2 = 0;
+      // Unpadded position: the window's top-left is the output position
+      // shifted up/left by the padding.
+      int shared_offset = data_idx;
+      const DType* filter_ptr = filter_read_offset + shared_data;
+      CUDA_UNROLL for (int r = 0; r < filter_height; ++r) {
+        CUDA_UNROLL for (int c = 0; c < filter_width; ++c) {
+          if (kDirection == DIRECTION_BACKWARD) {
+            filter_ptr--;
+          }
+          const DType filter_value = *filter_ptr;
+          const DType* const tile_ptr = shared_offset + shared_data;
+          sum1 += filter_value * tile_ptr[0];
+          sum2 += filter_value * tile_ptr[tile_offset];
+          ++shared_offset;
+          if (kDirection == DIRECTION_FORWARD) {
+            filter_ptr++;
+          }
+        }
+        shared_offset += in_increment;
+      }
+      DType* const out_ptr = inout_offset + output;
+      if (kDirection == DIRECTION_FORWARD) {
+        out_ptr[0] = sum1;
+        if (!skip_second) {
+          out_ptr[block_pixels] = sum2;
+        }
+      } else {
+        // Accumulate, so the caller controls write vs. add-to semantics.
+        out_ptr[0] += sum1;
+        if (!skip_second) {
+          out_ptr[block_pixels] += sum2;
+        }
+      }
+    }
+
+    // Note: the condition to reach this is uniform across the entire block.
+    __syncthreads();
+  }
+}
+
+template<typename DType>
+__global__ void __launch_bounds__(640, 2)
+DepthwiseConv2dBackwardDataKernel(const DepthwiseArgs args,
+                                  const DType* out_grad,
+                                  const DType* filter, DType* in_grad,
+                                  int num_in_grad) {
+  // Gather formulation: every thread owns one element of in_grad (flat NCHW
+  // index) and sums out_grad * filter over all output positions whose
+  // receptive field contains that input element. The result is added to
+  // in_grad, so existing contents are accumulated into.
+  const int channels = args.in_channel;
+  const int ih = args.in_height;
+  const int iw = args.in_width;
+  const int kh = args.filter_height;
+  const int kw = args.filter_width;
+  const int sh = args.stride_height;
+  const int sw = args.stride_width;
+  const int ph = args.pad_height;
+  const int pw = args.pad_width;
+  const int oh = args.out_height;
+  const int ow = args.out_width;
+
+  const int in_plane = ih * iw;
+  const int out_plane = oh * ow;
+
+  CUDA_KERNEL_LOOP(thread_id, num_in_grad) {
+    // Split the flat index into (n, c, y, x).
+    const int x = thread_id % iw;
+    const int y = (thread_id / iw) % ih;
+    const int c = (thread_id / iw / ih) % channels;
+    const int n = thread_id / channels / iw / ih;
+
+    // Inclusive range of output rows / columns that read input (y, x).
+    const int oy_begin = mxnet::common::cuda::CudaMax<int>(
+        0, (y - kh + ph + sh) / sh);
+    const int oy_last = mxnet::common::cuda::CudaMin(
+        oh - 1, (y + ph) / sh);
+    const int ox_begin = mxnet::common::cuda::CudaMax<int>(
+        0, (x - kw + pw + sw) / sw);
+    const int ox_last = mxnet::common::cuda::CudaMin(
+        ow - 1, (x + pw) / sw);
+
+    // Filter plane of channel c and out_grad plane of (n, c).
+    const DType* const filter_c = filter + channel_plane_offset_unused_guard(0) * 0 + c * kh * kw;
+    const DType* const out_grad_nc = out_grad + (n * channels * out_plane) + (c * out_plane);
+
+    DType acc = 0.0f;
+    for (int oy = oy_begin; oy <= oy_last; ++oy) {
+      const int fy = y + ph - oy * sh;
+      for (int ox = ox_begin; ox <= ox_last; ++ox) {
+        const int fx = x + pw - ox * sw;
+        acc += ldg(out_grad_nc + oy * ow + ox) * ldg(filter_c + fy * kw + fx);
+      }
+    }
+    in_grad[(n * channels * in_plane) + (c * in_plane) + (y * iw) + x] += acc;
+  }
+}
+
+// CUDA kernel to compute the depthwise convolution backward w.r.t. filter in
+// NCHW format, tailored for small images up to 32x32. Only use this kernel if
+// CanLaunchDepthwiseConv2dGPUSmall(args) returns true.
+// Tiles of the input tensor are loaded into shared memory before performing the
+// convolution. Per iteration and filter element, each thread first performs
+// a partial convolution for two elements, one each in the lower and upper half
+// of a tile. The intermediate result of all pixels of a warp are then
+// accumulated and written to shared memory. Finally, the values in shared
+// memory are warp-accumulated (in chunks of kAccumPixels elements) and summed
+// up in global memory using atomics.
+// Requirements: threads per block must be multiple of 32 and <= launch_bounds,
+// kAccumPixels * 64 >= args.in_height * args.in_width * kBlockSlices.
+template <typename DType, int kBlockSlices, int kAccumPixels, int kFilterHeight, int kFilterWidth>
+__global__
+__launch_bounds__(1024, 2) void DepthwiseConv2dBackwardFilterKernelSmall(
+    const DepthwiseArgs args, const DType* output, const DType* input, DType* filter) {
+  extern __shared__ __align__(sizeof(DType)) unsigned char shared_memory[];
+  DType* const shared_data = reinterpret_cast<DType*>(shared_memory);
+
+  const int in_height = args.in_height;
+  const int in_width = blockDim.x;  // slower (see b/62280718): args.in_width;
+  const int in_channel = args.in_channel;
+  const int filter_height = kFilterHeight > 0 ? kFilterHeight : args.filter_height;
+  const int filter_width = kFilterWidth > 0 ? kFilterWidth : args.filter_width;
+  const int pad_height = args.pad_height;
+  const int pad_width = args.pad_width;
+
+  const int block_height = blockDim.y;
+
+  // These values are the same for all threads and could
+  // be precomputed on the CPU.
+  const int block_pixels = in_width * block_height;
+  const int block_size = block_pixels * kBlockSlices;
+  assert((block_size & 31) == 0);
+  const int in_pixels = in_width * in_height;
+  const int in_increment = in_width - 1;
+  const int filter_pixels = filter_height * filter_width;
+  const int tile_width = in_width + filter_width - 1;
+  const int tile_height = 2 * block_height + filter_height - 1;
+  const int tile_pixels = tile_width * tile_height;
+  const int tile_size = tile_pixels * kBlockSlices;
+  const int tile_offset = block_height * tile_width;
+  const int pad_offset = pad_height * tile_width + pad_width;
+  const int in_slices = in_channel * args.batch;
+  const int in_blocks = (in_slices + kBlockSlices - 1) / kBlockSlices;
+  // The accumulator has a fixed number of pixels that can be reduced by one
+  // warp. Pixels beyond ceil(in_pixels * kBlockSlices / 64) are never written.
+  assert(kAccumPixels * 64 >= in_height * in_width * kBlockSlices);
+  const int accum_increment = kAccumPixels * kBlockSlices;
+  const int accum_size = filter_pixels * accum_increment;
+
+  const int thread_width = threadIdx.x;
+  const int thread_height = threadIdx.y;
+  const int thread_channel = threadIdx.z;
+
+  // Position in block.
+  const int thread_pix = thread_height * in_width + thread_width;
+  const int thread_idx = thread_channel * block_pixels + thread_pix;
+
+  // Initialize tile, in particular the padding and accumulator.
+  for (int i = thread_idx; i < tile_size + accum_size; i += block_size) {
+    shared_data[i] = DType(0);
+  }
+  __syncthreads();
+
+  // Position in tensors.
+  const int tensor_idx = thread_channel * in_pixels + thread_pix;
+
+  // Position in (padded) shared memory.
+  const int data_pix = thread_height * tile_width + thread_width;
+  const int data_idx = thread_channel * tile_pixels + data_pix;
+
+  // Position in shared memory, offset by pad_height / pad_width.
+  const int tile_idx = data_idx + pad_offset;
+
+  // Position in accumulator (kBlockSlices per warp, depth major).
+  const int accum_pix = thread_pix / (32 / kBlockSlices);
+  const int accum_idx = thread_channel * kAccumPixels + accum_pix;
+
+  const int max_slice = in_slices - thread_channel;
+  const int accum_offset = tile_size + accum_idx;
+  const bool skip_second = block_height + thread_height >= in_height;
+
+  for (int b = blockIdx.x; b < in_blocks; b += gridDim.x) {
+    const int slice = b * kBlockSlices;
+
+    const int inout_offset = slice * in_pixels + tensor_idx;
+    const bool slice_in_range = slice < max_slice;
+
+    if (slice_in_range) {
+      const DType* const in_ptr = inout_offset + input;
+      DType* const tile_ptr = tile_idx + shared_data;
+      tile_ptr[0] = ldg(in_ptr);
+      if (!skip_second) {
+        tile_ptr[tile_offset] = ldg(block_pixels + in_ptr);
+      }
+    }
+
+    // Note: the condition to reach this is uniform across the entire block.
+    __syncthreads();
+
+    if (slice_in_range) {
+      const DType* const out_ptr = inout_offset + output;
+      const DType out1 = ldg(out_ptr);
+      const DType out2 = skip_second ? DType(0) : ldg(block_pixels + out_ptr);
+      int shared_offset = data_idx;
+      DType* accum_ptr = accum_offset + shared_data;
+      CUDA_UNROLL for (int r = 0; r < filter_height; ++r) {
+        CUDA_UNROLL for (int c = 0; c < filter_width; ++c) {
+          const DType* const tile_ptr = shared_offset + shared_data;
+          DType val = out1 * tile_ptr[0] + out2 * tile_ptr[tile_offset];
+          // Warp-accumulate pixels of the same depth and write to accumulator.
+          for (int delta = 16 / kBlockSlices; delta > 0; delta /= 2) {
+            val += __shfl_down(val, delta);
+          }
+          if (!(thread_idx & 32 / kBlockSlices - 1)) {
+            *accum_ptr = val;
+          }
+          ++shared_offset;
+          accum_ptr += accum_increment;
+        }
+        shared_offset += in_increment;
+      }
+    }
+
+    // Note: the condition to reach this is uniform across the entire block.
+    __syncthreads();
+
+    const DType* const accum_data = tile_size + shared_data;
+    for (int i = thread_idx; i < accum_size; i += block_size) {
+      const int filter_idx = i / kAccumPixels;
+      const int filter_pix = filter_idx / kBlockSlices;
+      const int filter_channel = (slice + filter_idx % kBlockSlices) % in_channel;
+      // convert to CHW
+      const int filter_offset = filter_channel * filter_pixels +
+          (filter_pix/filter_width) * filter_height + filter_pix % filter_width;
+
+      if (filter_channel < in_channel) {
+        DType val = accum_data[i];
+        // Warp-accumulate pixels of the same depth from the accumulator.
+        for (int delta = kAccumPixels / 2; delta > 0; delta /= 2) {
+          val += __shfl_down(val, delta);
+        }
+        if (!(thread_idx & kAccumPixels - 1)) {
+          atomicAdd(filter_offset + filter, val);
+        }
+      }
+    }
+  }
+}
+
+
+}  // namespace cuda
+
+// Returns whether depthwise convolution forward or backward input pass can be
+// performed using the faster ('Small') variant of the kernel.
+bool CanLaunchDepthwiseConv2dGPUSmall(const DepthwiseArgs& args) {
+  return args.stride_height == 1 && args.stride_width == 1 && args.in_height <= 32 &&
+      args.in_width <= 32 && args.in_height == args.out_height &&
+      args.in_width == args.out_width && args.pad_height >= 0 &&
+      args.pad_height < args.filter_height && args.pad_width >= 0 &&
+      args.pad_width < args.filter_width &&
+      args.filter_height * args.filter_width <= (args.in_height + 1) / 2 * args.in_width;
+}
+
+// Returns whether depthwise convolution backward filter pass can be performed
+// using the faster ('Small') variant of the kernel.
+bool CanLaunchDepthwiseConv2dBackwardFilterGPUSmall(const DepthwiseArgs args,
+                                                    const int block_height) {
+  return args.stride_height == 1 && args.stride_width == 1 && args.in_height <= 32 &&
+      args.in_width <= 32 && args.in_height == args.out_height &&
+      args.in_width == args.out_width && args.pad_height >= 0 &&
+      args.pad_height < args.filter_height && args.pad_width >= 0 &&
+      args.pad_width < args.filter_width && block_height <= args.in_height &&
+      args.filter_height * args.filter_width <= block_height * args.in_width;
+}
+
+template <typename DType, cuda::DepthwiseConv2dDirection kDirection,
+          int kBlockSlices, bool kEvenHeight>
+void LaunchDepthwiseConv2dGPUSmall(mshadow::Stream<mxnet::gpu> *stream,
+                                   const DepthwiseArgs args,
+                                   const DType* input, const DType* filter, DType* output) {
+  const int block_height = (args.in_height + 1) / 2;
+  dim3 block_dim = dim3(args.in_width, block_height, kBlockSlices);
+
+  const int tile_width = args.in_width + args.filter_width - 1;
+  const int tile_height = block_height * 2 + args.filter_height - 1;
+  const int tile_pixels = tile_height * tile_width;
+  const int filter_pixels = args.filter_height * args.filter_width;
+  const int shared_memory_size =
+      kBlockSlices * (tile_pixels + filter_pixels) * sizeof(DType);
+  const int num_outputs =
+      args.batch * args.out_height * args.out_width * args.out_channel;
+  int block_count = std::min(num_outputs/(block_dim.x * block_dim.y * block_dim.z) + 1,
+                             (unsigned)mshadow::cuda::kMaxGridNum);
+  auto s = mshadow::Stream<mxnet::gpu>::GetStream(stream);
+  if (args.filter_height == 3 && args.filter_width == 3) {
+    cuda::DepthwiseConv2dKernelSmall<DType, kDirection, kBlockSlices, kEvenHeight, 3, 3>
+        <<<block_count, block_dim, shared_memory_size, s>>>(args, input, filter, output);
+  } else {
+    cuda::DepthwiseConv2dKernelSmall<DType, kDirection, kBlockSlices, kEvenHeight, -1, -1>
+        <<<block_count, block_dim, shared_memory_size, s>>>(args, input, filter, output);
+  }
+  MSHADOW_CUDA_POST_KERNEL_CHECK(DepthwiseConv2dKernelSmall);
+}
+
+template <typename DType, cuda::DepthwiseConv2dDirection kDirection, int kBlockSlices>
+void LaunchDepthwiseConv2dGPUSmall(mshadow::Stream<mxnet::gpu> *stream,
+                                   const DepthwiseArgs args,
+                                   const DType* input, const DType* filter, DType* output) {
+  if (args.in_height & 1) {
+    LaunchDepthwiseConv2dGPUSmall<DType, kDirection, kBlockSlices, false>(
+        stream, args, input, filter, output);
+  } else {
+    LaunchDepthwiseConv2dGPUSmall<DType, kDirection, kBlockSlices, true>(
+        stream, args, input, filter, output);
+  }
+}
+
+template <typename DType, cuda::DepthwiseConv2dDirection kDirection>
+void LaunchDepthwiseConv2dGPUSmall(mshadow::Stream<mxnet::gpu> *stream,
+                                   const DepthwiseArgs args,
+                                   const DType* input, const DType* filter, DType* output) {
+  // Maximize (power of two) kBlockSlices while keeping a block within 1024
+  // threads (2 pixels per thread).
+  const int block_pixels = (args.in_height + 1) / 2 * args.in_width;
+  if (block_pixels > 256) {
+    LaunchDepthwiseConv2dGPUSmall<DType, kDirection, 2>(stream, args, input, filter, output);
+  } else if (block_pixels > 128) {
+    LaunchDepthwiseConv2dGPUSmall<DType, kDirection, 4>(stream, args, input, filter, output);
+  } else {
+    LaunchDepthwiseConv2dGPUSmall<DType, kDirection, 8>(stream, args, input, filter, output);
+  }
+}
+
+template <typename DType, int kBlockSlices, int kAccumPixels>
+bool TryLaunchDepthwiseConv2dBackwardFilterGPUSmall(mshadow::Stream<mxnet::gpu> *stream,
+                                                    const DepthwiseArgs args,
+                                                    const int block_height,
+                                                    const DType* out_grad,
+                                                    const DType* input,
+                                                    DType* filter_grad) {
+  const int tile_width = args.in_width + args.filter_width - 1;
+  const int tile_height = block_height * 2 + args.filter_height - 1;
+  const int tile_pixels = tile_height * tile_width;
+  const int filter_pixels = args.filter_height * args.filter_width;
+  const int shared_memory_size =
+      kBlockSlices * (tile_pixels + filter_pixels * kAccumPixels) * sizeof(DType);
+  if (shared_memory_size > 46 * 1024) {
+    return false;
+  }
+
+  dim3 block_dim = dim3(args.in_width, block_height, kBlockSlices);
+  const int num_out_grad =
+      args.batch * args.out_height * args.out_width * args.out_channel;
+  int block_count = num_out_grad/(block_dim.x * block_dim.y * block_dim.z) + 1;
+  auto s = mshadow::Stream<mxnet::gpu>::GetStream(stream);
+  if (args.filter_height == 3 && args.filter_width == 3) {
+    cuda::DepthwiseConv2dBackwardFilterKernelSmall<DType, kBlockSlices, kAccumPixels, 3, 3>
+        <<<block_count, block_dim, shared_memory_size, s>>>(
+            args, out_grad, input, filter_grad);
+  } else {
+    cuda::DepthwiseConv2dBackwardFilterKernelSmall<DType, kBlockSlices, kAccumPixels, -1, -1>
+        <<<block_count, block_dim, shared_memory_size, s>>>(
+            args, out_grad, input, filter_grad);
+  }
+  MSHADOW_CUDA_POST_KERNEL_CHECK(DepthwiseConv2dBackwardFilterKernelSmall);
+  return true;
+}
+
+template <typename DType, int kBlockSlices>
+bool TryLaunchDepthwiseConv2dBackwardFilterGPUSmall(mshadow::Stream<mxnet::gpu> *stream,
+                                                    const DepthwiseArgs args,
+                                                    const int block_height,
+                                                    const DType* out_grad,
+                                                    const DType* input,
+                                                    DType* filter_grad) {
+  // Minimize (power of two) kAccumPixels, while satisfying
+  // kAccumPixels * 32 >= block_height * in_width * kBlockSlices.
+  const int block_pixels = block_height * args.in_width * kBlockSlices;
+  if (block_pixels > 512) {
+    return TryLaunchDepthwiseConv2dBackwardFilterGPUSmall<DType, kBlockSlices, 32>(
+        stream, args, block_height, out_grad, input, filter_grad);
+  } else if (block_pixels > 256) {
+    return TryLaunchDepthwiseConv2dBackwardFilterGPUSmall<DType, kBlockSlices, 16>(
+        stream, args, block_height, out_grad, input, filter_grad);
+  } else {
+    return TryLaunchDepthwiseConv2dBackwardFilterGPUSmall<DType, kBlockSlices, 8>(
+        stream, args, block_height, out_grad, input, filter_grad);
+  }
+}
+
+template <typename DType>
+bool TryLaunchDepthwiseConv2dBackwardFilterGPUSmall(mshadow::Stream<mxnet::gpu> *stream,
+                                                    const DepthwiseArgs args,
+                                                    const DType* out_grad,
+                                                    const DType* input,
+                                                    DType* filter_grad) {
+  // Maximize (power of two) kBlockSlices while keeping a block within 1024
+  // threads (2 pixels per thread).
+  int block_slices = 8;
+  int block_height = (args.in_height + 1) / 2;
+  int round_mask = 1;
+  for (; block_slices > 1; block_slices /= 2) {
+    // args.in_width * block_height * kBlockSlices must be multiple of 32.
+    for (; block_height * args.in_width * block_slices & 31;
+         round_mask = round_mask * 2 + 1) {
+      block_height = block_height + round_mask & ~round_mask;
+    }
+    int block_size = block_height * args.in_width * block_slices;
+    if (block_size <= 1024) {
+      break;
+    }
+  }
+
+  if (!CanLaunchDepthwiseConv2dBackwardFilterGPUSmall(args, block_height)) {
+    return false;
+  }
+
+  switch (block_slices) {
+    case 8:
+      return TryLaunchDepthwiseConv2dBackwardFilterGPUSmall<DType, 8>(
+          stream, args, block_height, out_grad, input, filter_grad);
+    case 4:
+      return TryLaunchDepthwiseConv2dBackwardFilterGPUSmall<DType, 4>(
+          stream, args, block_height, out_grad, input, filter_grad);
+    case 2:
+      return TryLaunchDepthwiseConv2dBackwardFilterGPUSmall<DType, 2>(
+          stream, args, block_height, out_grad, input, filter_grad);
+    default:
+      return false;
+  }
+}
+
+}  // namespace depthwise_conv
+}  // namespace tf
+
+#endif  // MXNET_OPERATOR_DEPTHWISE_CONVOLUTION_TF_CUH_
diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py
index 7d56b46..a33cb03 100644
--- a/tests/python/unittest/test_operator.py
+++ b/tests/python/unittest/test_operator.py
@@ -955,6 +955,44 @@ def test_convolution_grouping():
     for arr1, arr2 in zip(exe1.outputs + exe1.grad_arrays, exe2.outputs + exe2.grad_arrays):
         np.testing.assert_allclose(arr1.asnumpy(), arr2.asnumpy(), rtol=1e-3, atol=1e-4)
 
+
+def test_depthwise_convolution():
+    for num_base in [32, 64]:
+        for kernel in [(3,3), (5,5)]:
+            for stride in [(1,1), (2,2)]:
+                for pad in [(0,0), (1,1)]:
+                    num_filter = num_base
+                    num_group = num_base
+                    shape = (2, num_base, 32, 32)
+
+                    x = mx.sym.Variable('x')
+                    w = mx.sym.Variable('w')
+                    b = mx.sym.Variable('b')
+                    y1 = mx.sym.Convolution(data=x, weight=w, bias=b, num_filter=num_filter, num_group=num_group,
+                            kernel=kernel, stride=stride, pad=pad)
+                    xslice = mx.sym.SliceChannel(data=x, num_outputs=num_group, axis=1)
+                    wslice = mx.sym.SliceChannel(data=w, num_outputs=num_group, axis=0)
+                    bslice = mx.sym.SliceChannel(data=b, num_outputs=num_group, axis=0)
+                    y2 = mx.sym.Concat(*[mx.sym.Convolution(data=xslice[i], weight=wslice[i], bias=bslice[i],
+                                                            num_filter=num_filter//num_group, kernel=kernel,
+                                                            stride=stride, pad=pad)
+                                       for i in range(num_group)])
+
+                    dev = default_context()
+                    exe1 = y1.simple_bind(dev, x=shape)
+                    exe2 = y2.simple_bind(mx.cpu(), x=shape, w=(num_filter, shape[1]//num_group, kernel[0], kernel[1]),
+                            b=(num_filter,))
+                    for arr1, arr2 in zip(exe1.arg_arrays, exe2.arg_arrays):
+                        arr1[:] = np.random.normal(size=arr1.shape)
+                        arr2[:] = arr1
+                    exe1.forward(is_train=True)
+                    exe1.backward(exe1.outputs[0])
+                    exe2.forward(is_train=True)
+                    exe2.backward(exe2.outputs[0])
+
+                    for arr1, arr2 in zip(exe1.outputs + exe1.grad_arrays, exe2.outputs + exe2.grad_arrays):
+                        np.testing.assert_allclose(arr1.asnumpy(), arr2.asnumpy(), rtol=1e-3, atol=1e-4)
+
 def gen_broadcast_data(idx):
     # Manually set test cases
     binary_op_data_shape = np.array(

-- 
To stop receiving notification emails like this one, please contact
['"commits@mxnet.apache.org" <co...@mxnet.apache.org>'].