You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2018/06/30 20:23:41 UTC

[GitHub] szha closed pull request #9860: [WIP] CMake NNPack support

szha closed pull request #9860: [WIP] CMake NNPack support
URL: https://github.com/apache/incubator-mxnet/pull/9860
 
 
   

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:

As this is a foreign pull request (from a fork), the diff is supplied
below (as it won't show otherwise due to GitHub magic):

diff --git a/.gitignore b/.gitignore
index 6be98c50466..90a05eaa2cd 100644
--- a/.gitignore
+++ b/.gitignore
@@ -127,6 +127,11 @@ CMakeFiles
 cmake_install.cmake
 lib
 
+# Kate / Kdevelop files
+*.kate-swp
+*.kdev4
+
+
 # Visual Studio Code
 .vscode
 
diff --git a/.gitmodules b/.gitmodules
index cdb8a553679..ea391ea584a 100644
--- a/.gitmodules
+++ b/.gitmodules
@@ -26,3 +26,6 @@
 [submodule "3rdparty/cub"]
 	path = 3rdparty/cub
 	url = https://github.com/dmlc/cub
+[submodule "3rdparty/NNPACK"]
+	path = 3rdparty/NNPACK
+	url = https://github.com/Maratyszcza/NNPACK
diff --git a/3rdparty/NNPACK b/3rdparty/NNPACK
new file mode 160000
index 00000000000..83af25db118
--- /dev/null
+++ b/3rdparty/NNPACK
@@ -0,0 +1 @@
+Subproject commit 83af25db11883e160e65005f065f260488643c26
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 16d365355ce..6f97af51ad3 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -34,6 +34,7 @@ mxnet_option(USE_VTUNE            "Enable use of Intel Amplifier XE (VTune)" OFF
 mxnet_option(ENABLE_CUDA_RTC      "Build with CUDA runtime compilation support" ON)
 mxnet_option(INSTALL_EXAMPLES     "Install the example source files." OFF)
 mxnet_option(USE_SIGNAL_HANDLER   "Print stack traces on segfaults." OFF)
+mxnet_option(USE_NNPACK           "Build with NNPack support." OFF)
 
 if(USE_CUDA AND NOT USE_OLDCMAKECUDA)
   message(STATUS "CMake version '${CMAKE_VERSION}' using generator '${CMAKE_GENERATOR}'")
@@ -551,6 +552,52 @@ if(NOT EXISTS "${CMAKE_CURRENT_SOURCE_DIR}/nnvm/CMakeLists.txt")
   list(APPEND mxnet_LINKER_LIBS ${nnvm_LINKER_LIBS})
 endif()
 
+# ---[ NNPack
+if(USE_NNPACK)
+  if(USE_MKLDNN)
+    message(FATAL_ERROR "Either MKLDNN or NNPack can be enabled but not both.")
+  endif()
+  # Resolve relative to this CMakeLists.txt so embedding MXNet as a subproject works.
+  set(NNPACK_SOURCE_DIR "${CMAKE_CURRENT_SOURCE_DIR}/3rdparty/NNPACK")
+  if(NOT EXISTS "${NNPACK_SOURCE_DIR}/CMakeLists.txt")
+    # USE_NNPACK was requested explicitly, so do not silently build without it.
+    message(FATAL_ERROR "NNPack submodule not found in ${NNPACK_SOURCE_DIR}. "
+                        "Run 'git submodule update --init --recursive'.")
+  endif()
+
+  # Hand an existing Google Test tree to NNPack unless one was already chosen.
+  # NOTE(review): assumes GTEST_ROOT points at a Google Test source tree - confirm.
+  if(NOT GOOGLETEST_SOURCE_DIR AND GTEST_ROOT AND EXISTS "${GTEST_ROOT}")
+    set(GOOGLETEST_SOURCE_DIR "${GTEST_ROOT}" CACHE STRING "Google Test source directory")
+  endif()
+
+  # Disable NNPack internal testing
+  set(NNPACK_BUILD_TESTS OFF CACHE BOOL "")
+  set(NNPACK_BUILD_BENCHMARKS OFF CACHE BOOL "")
+
+  # Compile statically
+  set(NNPACK_LIBRARY_TYPE "static" CACHE STRING "")
+  set(PTHREADPOOL_LIBRARY_TYPE "static" CACHE STRING "")
+  set(CPUINFO_LIBRARY_TYPE "static" CACHE STRING "")
+
+  # Put NNPack dependencies in appropriate folders; see NNPack for other options.
+  set(CONFU_DEPENDENCIES_SOURCE_DIR "${NNPACK_SOURCE_DIR}/deps"
+      CACHE PATH "Confu-style dependencies source directory")
+  set(CONFU_DEPENDENCIES_BINARY_DIR "${CMAKE_CURRENT_BINARY_DIR}/3rdparty/NNPACK/deps"
+      CACHE PATH "Confu-style dependencies binary directory")
+
+  add_subdirectory("${NNPACK_SOURCE_DIR}")
+
+  # Compile the static NNPack libraries with -fPIC.
+  set_property(TARGET nnpack pthreadpool cpuinfo PROPERTY POSITION_INDEPENDENT_CODE ON)
+
+  # PTHREADPOOL_SOURCE_DIR is set in the cache by NNPack.
+  include_directories("${NNPACK_SOURCE_DIR}/include" "${PTHREADPOOL_SOURCE_DIR}/include")
+  add_definitions(-DMXNET_USE_NNPACK=1)
+  set(NNPack_LINKER_LIBS nnpack)
+  list(APPEND mxnet_LINKER_LIBS ${NNPack_LINKER_LIBS})
+endif()
+
 if(NOT MSVC)
   # Only add c++11 flags and definitions after cuda compiling
   add_definitions(-DDMLC_USE_CXX11)
diff --git a/src/operator/convolution_v1.cc b/src/operator/convolution_v1.cc
index 86c0fbb3329..7d561f7d5ca 100644
--- a/src/operator/convolution_v1.cc
+++ b/src/operator/convolution_v1.cc
@@ -25,9 +25,6 @@
 */
 
 #include "./convolution_v1-inl.h"
-#if MXNET_USE_NNPACK == 1
-#include "./nnpack/nnpack_convolution-inl.h"
-#endif  // MXNET_USE_NNPACK
 
 namespace mxnet {
 namespace op {
diff --git a/src/operator/nn/convolution.cc b/src/operator/nn/convolution.cc
index 951063fb4b2..69ed4713d2c 100644
--- a/src/operator/nn/convolution.cc
+++ b/src/operator/nn/convolution.cc
@@ -28,9 +28,7 @@
 #include "../elemwise_op_common.h"
 #include "./mkldnn/mkldnn_ops-inl.h"
 #include "./mkldnn/mkldnn_base-inl.h"
-#if MXNET_USE_NNPACK == 1
-#include "./nnpack/nnpack_convolution-inl.h"
-#endif  // MXNET_USE_NNPACK
+#include "./nnpack/nnpack_ops-inl.h"
 
 namespace mxnet {
 namespace op {
diff --git a/src/operator/nn/fully_connected.cc b/src/operator/nn/fully_connected.cc
index 4362408a23a..b2b202a4ce0 100644
--- a/src/operator/nn/fully_connected.cc
+++ b/src/operator/nn/fully_connected.cc
@@ -25,9 +25,7 @@
 #include "./fully_connected-inl.h"
 #include "./mkldnn/mkldnn_ops-inl.h"
 #include "./mkldnn/mkldnn_base-inl.h"
-#if MXNET_USE_NNPACK == 1
-#include "./nnpack/nnpack_fully_connected-inl.h"
-#endif  // MXNET_USE_NNPACK
+#include "./nnpack/nnpack_ops-inl.h"
 
 namespace mxnet {
 namespace op {
diff --git a/src/operator/nn/nnpack/nnpack_ops-inl.h b/src/operator/nn/nnpack/nnpack_ops-inl.h
new file mode 100644
index 00000000000..279349206aa
--- /dev/null
+++ b/src/operator/nn/nnpack/nnpack_ops-inl.h
@@ -0,0 +1,68 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file nnpack_ops-inl.h
+ * \brief Declarations of the NNPACK-backed operator implementations.
+ * \author David Braude
+*/
+
+#ifndef MXNET_OPERATOR_NN_NNPACK_NNPACK_OPS_INL_H_
+#define MXNET_OPERATOR_NN_NNPACK_NNPACK_OPS_INL_H_
+
+#if MXNET_USE_NNPACK == 1
+
+#include <mxnet/io.h>
+#include <mxnet/base.h>
+#include <mxnet/ndarray.h>
+#include <mxnet/operator.h>
+#include <mxnet/operator_util.h>
+#include <dmlc/logging.h>
+#include <dmlc/optional.h>
+#include <vector>
+#include <nnpack.h>
+
+// TODO:
+// Convolutional layer
+//     Inference-optimized forward propagation (nnp_convolution_inference)
+//     Training-optimized forward propagation (nnp_convolution_output)
+//     Training-optimized backward input gradient update (nnp_convolution_input_gradient)
+//     Training-optimized backward kernel gradient update (nnp_convolution_kernel_gradient)
+// Fully-connected layer
+//     Inference-optimized forward propagation (nnp_fully_connected_inference and nnp_fully_connected_inference_f16f32 version for FP16 weights)
+//     Training-optimized forward propagation (nnp_fully_connected_output)
+// Max pooling layer
+//     Forward propagation, both for training and inference, (nnp_max_pooling_output)
+// ReLU layer (with parametrized negative slope)
+//     Forward propagation, both for training and inference, optionally in-place, (nnp_relu_output)
+//     Backward input gradient update (nnp_relu_input_gradient)
+
+namespace mxnet {
+namespace op {
+
+/* Softmax forward entry point of the NNPACK backend; defined in nnpack_softmax.cc. */
+void NNPACKSoftmaxForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
+                          const NDArray &in_data, const OpReqType &req,
+                          const NDArray &out_data);
+
+}  // namespace op
+}  // namespace mxnet
+#endif  // MXNET_USE_NNPACK == 1
+
+#endif  // MXNET_OPERATOR_NN_NNPACK_NNPACK_OPS_INL_H_
diff --git a/src/operator/nnpack/nnpack_util.cc b/src/operator/nn/nnpack/nnpack_softmax.cc
similarity index 59%
rename from src/operator/nnpack/nnpack_util.cc
rename to src/operator/nn/nnpack/nnpack_softmax.cc
index 7d075e0554b..a1246545e6c 100644
--- a/src/operator/nnpack/nnpack_util.cc
+++ b/src/operator/nn/nnpack/nnpack_softmax.cc
@@ -18,20 +18,31 @@
  */
 
 /*!
- * Copyright (c) 2016 by Contributors
- * \file nnpack_util.cc
+ * \file nnpack_softmax.cc
  * \brief
- * \author Wei Wu
+ * \author David Braude
 */
 
-#if MXNET_USE_NNPACK == 1
-#include "nnpack_util.h"
+#include "../softmax-inl.h"
+#include "./nnpack_ops-inl.h"
+
 
+#if MXNET_USE_NNPACK == 1
 namespace mxnet {
 namespace op {
 
-NNPACKInitialize nnpackinitialize;
+// Softmax forward pass for the NNPACK backend.
+// Not implemented yet; it is intended to wrap NNPACK's
+//   enum nnp_status nnp_softmax_output(size_t batch_size, size_t channels,
+//                                      const float input[], float output[],
+//                                      pthreadpool_t threadpool);
+void NNPACKSoftmaxForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
+                          const NDArray &in_data, const OpReqType &req,
+                          const NDArray &out_data) {
+  // Fail loudly: returning silently would leave out_data unwritten.
+  LOG(FATAL) << "NNPACKSoftmaxForward is not implemented yet";
+}
 
-}  // namespace op
-}  // namespace mxnet
-#endif  // MXNET_USE_NNPACK
+}   // namespace op
+}   // namespace mxnet
+#endif
diff --git a/src/operator/nn/pooling.cc b/src/operator/nn/pooling.cc
index f719e0753e0..d13d43b1325 100644
--- a/src/operator/nn/pooling.cc
+++ b/src/operator/nn/pooling.cc
@@ -25,12 +25,8 @@
 */
 #include "../elemwise_op_common.h"
 #include "./pooling-inl.h"
-#if MXNET_USE_NNPACK == 1
-#include "./nnpack/nnpack_pooling-inl.h"
-#endif  // MXNET_USE_NNPACK
-#if MXNET_USE_MKLDNN == 1
+// #include "./nnpack/nnpack_pooling-inl.h"
 #include "./mkldnn/mkldnn_pooling-inl.h"
-#endif  // MXNET_USE_MKLDNN
 
 namespace mxnet {
 namespace op {
diff --git a/src/operator/nnpack/nnpack_convolution-inl.h b/src/operator/nnpack/nnpack_convolution-inl.h
deleted file mode 100644
index 0e2c73693d1..00000000000
--- a/src/operator/nnpack/nnpack_convolution-inl.h
+++ /dev/null
@@ -1,124 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *   http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied.  See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-/*!
- * Copyright (c) 2016 by Contributors
- * \file nnpack_convolution-inl.h
- * \brief
- * \author Carwin
-*/
-#ifndef MXNET_OPERATOR_NNPACK_NNPACK_CONVOLUTION_INL_H_
-#define MXNET_OPERATOR_NNPACK_NNPACK_CONVOLUTION_INL_H_
-
-#include <dmlc/logging.h>
-#include <dmlc/parameter.h>
-#include <mxnet/operator.h>
-#include <algorithm>
-#include <map>
-#include <vector>
-#include <string>
-#include <utility>
-#include "../convolution-inl.h"
-#include "nnpack.h"
-#include "nnpack_util.h"
-
-namespace mxnet {
-namespace op {
-
-template <typename xpu, typename DType>
-class NNPACKConvolutionOp : public ConvolutionOp<xpu, DType> {
- private:
-  ConvolutionParam param_;
-
- public:
-  explicit NNPACKConvolutionOp(ConvolutionParam p)
-      : ConvolutionOp<xpu, DType>(p) {
-    this->param_ = p;
-  }
-
- public:
-  virtual void Forward(const OpContext &ctx, const std::vector<TBlob> &in_data,
-                       const std::vector<OpReqType> &req,
-                       const std::vector<TBlob> &out_data,
-                       const std::vector<TBlob> &aux_args) {
-    using namespace mshadow;
-    using namespace mshadow::expr;
-    Stream<xpu> *s = ctx.get_stream<xpu>();
-    Tensor<xpu, 4, DType> data = in_data[conv::kData].get<xpu, 4, DType>(s);
-    const size_t batch_size = data.shape_[0];
-    const size_t input_c = data.shape_[1];
-    const size_t input_h = data.shape_[2];
-    const size_t input_w = data.shape_[3];
-    Shape<3> wmat_shape =
-        Shape3(param_.num_group, param_.num_filter / param_.num_group,
-               input_c / param_.num_group * param_.kernel[0] *
-                   param_.kernel[1]);
-    Tensor<xpu, 3, DType> wmat =
-        in_data[conv::kWeight].get_with_shape<xpu, 3, DType>(wmat_shape, s);
-    Tensor<xpu, 4, DType> out = out_data[conv::kOut].get<xpu, 4, DType>(s);
-    nnp_size input_size = {input_w, input_h};
-    nnp_padding input_padding = {param_.pad[0], param_.pad[1], param_.pad[0],
-                               param_.pad[1]};
-    nnp_size kernel_size = {param_.kernel[1], param_.kernel[0]};
-    nnp_size output_subsampling = {param_.stride[1], param_.stride[0]};
-    Tensor<xpu, 1, DType> bias = in_data[conv::kBias].get<xpu, 1, DType>(s);
-
-    nnp_convolution_algorithm algorithm = nnp_convolution_algorithm_auto;
-    nnp_convolution_transform_strategy kts = nnp_convolution_transform_strategy_tuple_based;
-    nnp_status status = nnp_status_success;
-    if (batch_size == 1) {
-      status = nnp_convolution_inference(
-      algorithm,                    // enum nnp_convolution_algorithm,
-      kts,                          // enum nnp_convolution_transform_strategy,
-      input_c,                      // size_t input_channels,
-      param_.num_filter,            // size_t output_channels,
-      input_size,                   // struct nnp_size input_size,
-      input_padding,                // struct nnp_padding input_padding,
-      kernel_size,                  // struct nnp_size kernel_size,
-      output_subsampling,           // struct nnp_size output_subsampling,
-      data.dptr_,                   // const float input[],
-      wmat.dptr_,                   // const float kernel[],
-      bias.dptr_,                   // const float bias[],
-      out.dptr_,                    // float output[],
-      nnpackinitialize.threadpool,  // pthreadpool_t threadpool,
-      nullptr);
-    } else {
-      status = nnp_convolution_output(
-      algorithm,                    // enum nnp_convolution_algorithm algorithm,
-      batch_size,                   // size_t batch size of input tensor
-      input_c,                      // size_t input_channels,
-      param_.num_filter,            // size_t output_channels,
-      input_size,                   // struct nnp_size input_size,
-      input_padding,                // struct nnp_padding input_padding,
-      kernel_size,                  // struct nnp_size kernel_size,
-      data.dptr_,                   // const float input[],
-      wmat.dptr_,                   // const float kernel[],
-      bias.dptr_,                   // const float bias[],
-      out.dptr_,                    // float output[],
-      nnpackinitialize.threadpool,  // pthreadpool_t threadpool,
-      nullptr);
-    }
-    if (nnp_status_success != status) {
-      LOG(FATAL) << "nnpack convolution feedforward failed status=" << status;
-    }
-  }
-};  // class NNPACKConvolutionOp
-}  // namespace op
-}  // namespace mxnet
-#endif  // MXNET_OPERATOR_NNPACK_NNPACK_CONVOLUTION_INL_H_
diff --git a/src/operator/nnpack/nnpack_fully_connected-inl.h b/src/operator/nnpack/nnpack_fully_connected-inl.h
deleted file mode 100644
index d9412d20d0c..00000000000
--- a/src/operator/nnpack/nnpack_fully_connected-inl.h
+++ /dev/null
@@ -1,108 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *   http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied.  See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-/*!
- * Copyright (c) 2016 by Contributors
- * \file nnpack_fully_connected-inl.h
- * \brief
- * \author Wei Wu
-*/
-#ifndef MXNET_OPERATOR_NNPACK_NNPACK_FULLY_CONNECTED_INL_H_
-#define MXNET_OPERATOR_NNPACK_NNPACK_FULLY_CONNECTED_INL_H_
-
-#include <dmlc/logging.h>
-#include <dmlc/parameter.h>
-#include <mxnet/operator.h>
-#include <algorithm>
-#include <map>
-#include <vector>
-#include <string>
-#include <utility>
-#include "../fully_connected-inl.h"
-#include "nnpack.h"
-#include "nnpack_util.h"
-
-namespace mxnet {
-namespace op {
-
-template <typename xpu, typename DType>
-class NNPACKFullyConnectedOp : public FullyConnectedOp<xpu, DType> {
- private:
-  FullyConnectedParam param_;
-
- public:
-  explicit NNPACKFullyConnectedOp(FullyConnectedParam p)
-      : FullyConnectedOp<xpu, DType>(p) {
-    this->param_ = p;
-  }
-
- public:
-  virtual void Forward(const OpContext &ctx, const std::vector<TBlob> &in_data,
-                       const std::vector<OpReqType> &req,
-                       const std::vector<TBlob> &out_data,
-                       const std::vector<TBlob> &aux_args) {
-    using namespace mshadow;
-    using namespace mshadow::expr;
-    if (req[fullc::kOut] == kNullOp) return;
-    CHECK_EQ(req[fullc::kOut], kWriteTo);
-    size_t expected = param_.no_bias ? 2 : 3;
-    CHECK_EQ(in_data.size(), expected);
-    CHECK_EQ(out_data.size(), 1);
-    const TShape& ishape = in_data[fullc::kData].shape_;
-    const TShape& oshape = out_data[fullc::kOut].shape_;
-    Stream<xpu> *s = ctx.get_stream<xpu>();
-    Tensor<xpu, 2, DType> data = in_data[fullc::kData].get_with_shape<xpu, 2, DType>(
-        Shape2(ishape[0], ishape.ProdShape(1, ishape.ndim())), s);
-    Tensor<xpu, 2, DType> wmat = in_data[fullc::kWeight].get<xpu, 2, DType>(s);
-    Tensor<xpu, 2, DType> out = out_data[fullc::kOut].get_with_shape<xpu, 2, DType>(
-        Shape2(oshape[0], oshape.ProdShape(1, oshape.ndim())), s);
-    const size_t batch_size = data.shape_[0];
-    const size_t input_c = data.shape_[1];
-    nnp_status status = nnp_status_success;
-    if (batch_size == 1) {
-      status = nnp_fully_connected_inference(
-      input_c,                       // size_t input_channels,
-      param_.num_hidden,             // size_t output_channels,
-      data.dptr_,                    // const float input[],
-      wmat.dptr_,                    // const float kernel[],
-      out.dptr_,                     // float output[],
-      nnpackinitialize.threadpool);  // pthreadpool_t threadpool,
-    } else {
-      status = nnp_fully_connected_output(
-      batch_size,                    // size_t batch size of input tensor
-      input_c,                       // size_t input_channels,
-      param_.num_hidden,             // size_t output_channels,
-      data.dptr_,                    // const float input[],
-      wmat.dptr_,                    // const float kernel[],
-      out.dptr_,                     // float output[],
-      nnpackinitialize.threadpool,   // pthreadpool_t threadpool,
-      nullptr);
-    }
-    if (nnp_status_success != status) {
-      LOG(FATAL) << "nnpack fully conneted feedforward failed status=" << status;
-    }
-    if (!param_.no_bias) {
-      Tensor<xpu, 1, DType> bias = in_data[fullc::kBias].get<xpu, 1, DType>(s);
-      out += repmat(bias, data.size(0));
-    }
-  }
-};  // class NNPACKFullyConnectedOp
-}  // namespace op
-}  // namespace mxnet
-#endif  // MXNET_OPERATOR_NNPACK_NNPACK_FULLY_CONNECTED_INL_H_
diff --git a/src/operator/nnpack/nnpack_pooling-inl.h b/src/operator/nnpack/nnpack_pooling-inl.h
deleted file mode 100644
index 25b47832275..00000000000
--- a/src/operator/nnpack/nnpack_pooling-inl.h
+++ /dev/null
@@ -1,91 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *   http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied.  See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-/*!
- * Copyright (c) 2016 by Contributors
- * \file nnpack_pooling-inl.h
- * \brief
- * \author Wei Wu
-*/
-#ifndef MXNET_OPERATOR_NNPACK_NNPACK_POOLING_INL_H_
-#define MXNET_OPERATOR_NNPACK_NNPACK_POOLING_INL_H_
-
-#include <dmlc/logging.h>
-#include <dmlc/parameter.h>
-#include <mxnet/operator.h>
-#include <algorithm>
-#include <map>
-#include <vector>
-#include <string>
-#include <utility>
-#include "../pooling-inl.h"
-#include "nnpack.h"
-#include "nnpack_util.h"
-
-namespace mxnet {
-namespace op {
-
-template <typename xpu, typename DType>
-class NNPACKPoolingOp : public PoolingOp<xpu, DType> {
- private:
-  PoolingParam param_;
-
- public:
-  explicit NNPACKPoolingOp(PoolingParam p)
-      : PoolingOp<xpu, DType>(p) {
-    this->param_ = p;
-  }
-
- public:
-  virtual void Forward(const OpContext &ctx, const std::vector<TBlob> &in_data,
-                       const std::vector<OpReqType> &req,
-                       const std::vector<TBlob> &out_data,
-                       const std::vector<TBlob> &aux_args) {
-    using namespace mshadow;
-    using namespace mshadow::expr;
-    Stream<xpu> *s = ctx.get_stream<xpu>();
-    Tensor<xpu, 4, DType> data = in_data[pool_enum::kData].get<xpu, 4, DType>(s);
-    const size_t batch_size = data.shape_[0];
-    const size_t input_c = data.shape_[1];
-    const size_t input_h = data.shape_[2];
-    const size_t input_w = data.shape_[3];
-    Tensor<xpu, 4, DType> out = out_data[pool_enum::kOut].get<xpu, 4, DType>(s);
-    nnp_size input_size = {input_w, input_h};
-    nnp_padding input_padding = {param_.pad[0], param_.pad[1], param_.pad[0],
-                                 param_.pad[1]};
-    nnp_size kernel_size = {param_.kernel[1], param_.kernel[0]};
-    nnp_size output_subsampling = {param_.stride[1], param_.stride[0]};
-    nnp_status status = nnp_max_pooling_output(
-      batch_size,                    // size_t batch size of input tensor
-      input_c,                       // size_t input_channels,
-      input_size,                    // struct nnp_size input_size,
-      input_padding,                 // struct nnp_padding input_padding,
-      kernel_size,                   // struct nnp_size kernel_size,
-      output_subsampling,            // struct nnp_size output_subsampling,
-      data.dptr_,                    // const float input[],
-      out.dptr_,                     // float output[],
-      nnpackinitialize.threadpool);  // pthreadpool_t threadpool,
-    if (nnp_status_success != status) {
-      LOG(FATAL) << "nnpack max pooling feedforward failed status=" << status;
-    }
-  }
-};  // class NNPACKPoolingOp
-}  // namespace op
-}  // namespace mxnet
-#endif  // MXNET_OPERATOR_NNPACK_NNPACK_POOLING_INL_H_
diff --git a/src/operator/nnpack/nnpack_util.h b/src/operator/nnpack/nnpack_util.h
deleted file mode 100644
index 2edfb79ad46..00000000000
--- a/src/operator/nnpack/nnpack_util.h
+++ /dev/null
@@ -1,64 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *   http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied.  See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-/*!
- * Copyright (c) 2016 by Contributors
- * \file nnpack_util.h
- * \brief
- * \author Carwin
-*/
-#ifndef MXNET_OPERATOR_NNPACK_NNPACK_UTIL_H_
-#define MXNET_OPERATOR_NNPACK_NNPACK_UTIL_H_
-
-#include <dmlc/logging.h>
-#include <dmlc/parameter.h>
-#include <nnpack.h>
-
-namespace mxnet {
-namespace op {
-
-class NNPACKInitialize {
- public:
-  pthreadpool_t threadpool;
-
- public:
-  NNPACKInitialize() {
-    nnp_status status = nnp_initialize();
-    if (nnp_status_success != status) {
-      LOG(FATAL) << "nnp_initialize failed status=" << status;
-    }
-    int num_threads = dmlc::GetEnv("MXNET_CPU_NNPACK_NTHREADS", 4);
-    this->threadpool = pthreadpool_create(num_threads);
-  }
-  virtual ~NNPACKInitialize() {
-    nnp_status status = nnp_deinitialize();
-    if (nnp_status_success != status) {
-      LOG(FATAL) << "nnp_deinitialize failed status=" << status;
-    }
-    pthreadpool_destroy(threadpool);
-  }
-};
-
-// nnpackinitialize will be used in all other nnpack op
-extern NNPACKInitialize nnpackinitialize;
-
-}  // namespace op
-}  // namespace mxnet
-
-#endif  // MXNET_OPERATOR_NNPACK_NNPACK_UTIL_H_


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services