Posted to commits@mxnet.apache.org by ke...@apache.org on 2018/12/19 18:39:19 UTC

[incubator-mxnet] branch master updated: [MXNET-1252][1 of 2] Decouple NNVM to ONNX from NNVM to TensorRT conversion (#13659)

This is an automated email from the ASF dual-hosted git repository.

kellen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new f85b17b  [MXNET-1252][1 of 2] Decouple NNVM to ONNX from NNVM to TensorRT conversion (#13659)
f85b17b is described below

commit f85b17b8dd647f8d286ebc51cd2e582b2a829d41
Author: Haohuan Wang <ha...@umich.edu>
AuthorDate: Wed Dec 19 10:39:01 2018 -0800

    [MXNET-1252][1 of 2] Decouple NNVM to ONNX from NNVM to TensorRT conversion (#13659)
---
 src/executor/tensorrt_pass.cc           |  8 +++----
 src/operator/contrib/nnvm_to_onnx-inl.h | 42 +++++++++++++++++++++++++++++----
 src/operator/contrib/nnvm_to_onnx.cc    | 34 +++++++++++++-------------
 src/operator/contrib/tensorrt-inl.h     | 38 ++---------------------------
 src/operator/contrib/tensorrt.cc        | 28 ++++++++++------------
 src/operator/contrib/tensorrt.cu        |  2 +-
 6 files changed, 76 insertions(+), 76 deletions(-)
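
The diff is easier to follow knowing that the conversion-side types this commit relocates from the tensorrt namespace into nnvm_to_onnx are plain standard-library containers. A minimal, self-contained sketch of those aliases (std::vector<int64_t> stands in for mxnet::TShape, and all values are illustrative):

    #include <cstdint>
    #include <map>
    #include <string>
    #include <tuple>
    #include <vector>

    using Shape = std::vector<int64_t>;  // stand-in for mxnet::TShape

    enum class TypeIO { Inputs = 0, Outputs = 1 };

    // Graph input name -> position in the op's input list.
    using NameToIdx_t = std::map<std::string, int32_t>;

    // Per output: (index in g.outputs, shape, storage type, dtype).
    using InferenceTuple_t = std::tuple<uint32_t, Shape, int, int>;
    using InferenceMap_t = std::map<std::string, InferenceTuple_t>;

    int main() {
      NameToIdx_t input_map{{"data", 0}};
      InferenceMap_t output_map;
      output_map.emplace("softmax", InferenceTuple_t{0u, Shape{32, 10}, 0, 0});
      return input_map.size() == 1 && output_map.size() == 1 ? 0 : 1;
    }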

diff --git a/src/executor/tensorrt_pass.cc b/src/executor/tensorrt_pass.cc
index b5fc8d1..d26704c 100644
--- a/src/executor/tensorrt_pass.cc
+++ b/src/executor/tensorrt_pass.cc
@@ -324,10 +324,10 @@ nnvm::NodePtr ConvertNnvmGraphToOnnx(const nnvm::Graph &g,
                                      std::unordered_map<std::string, NDArray>* const params_map) {
   auto p = nnvm::Node::Create();
   p->attrs.op = nnvm::Op::Get("_trt_op");
-  op::TRTParam trt_param = op::nnvm_to_onnx::ConvertNnvmGraphToOnnx(g, params_map);
-  p->attrs.dict["serialized_output_map"] = trt_param.serialized_output_map;
-  p->attrs.dict["serialized_input_map"]  = trt_param.serialized_input_map;
-  p->attrs.dict["serialized_onnx_graph"] = trt_param.serialized_onnx_graph;
+  op::ONNXParam onnx_param = op::nnvm_to_onnx::ConvertNnvmGraphToOnnx(g, params_map);
+  p->attrs.dict["serialized_output_map"] = onnx_param.serialized_output_map;
+  p->attrs.dict["serialized_input_map"]  = onnx_param.serialized_input_map;
+  p->attrs.dict["serialized_onnx_graph"] = onnx_param.serialized_onnx_graph;
   if (p->op()->attr_parser != nullptr) {
     p->op()->attr_parser(&(p->attrs));
   }
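
The hunk above only renames TRTParam to ONNXParam; the mechanism is unchanged: the graph pass stores the three serialized payloads as strings in the node's attribute dictionary, and the op's registered attr_parser later rebuilds the typed parameter from them. A minimal sketch of that flow, with all names here as illustrative stand-ins rather than the real nnvm::NodeAttrs API:

    #include <map>
    #include <string>

    struct Param {
      std::string serialized_onnx_graph;
      std::string serialized_input_map;
      std::string serialized_output_map;
    };

    // Stand-in for the op's attr_parser: pull each serialized payload
    // back out of the node's string dictionary.
    Param ParseAttrs(const std::map<std::string, std::string>& dict) {
      Param p;
      p.serialized_onnx_graph = dict.at("serialized_onnx_graph");
      p.serialized_input_map  = dict.at("serialized_input_map");
      p.serialized_output_map = dict.at("serialized_output_map");
      return p;
    }

    int main() {
      std::map<std::string, std::string> dict{
          {"serialized_onnx_graph", "<onnx bytes>"},
          {"serialized_input_map",  "<input map bytes>"},
          {"serialized_output_map", "<output map bytes>"}};
      return ParseAttrs(dict).serialized_onnx_graph.empty() ? 1 : 0;
    }
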
diff --git a/src/operator/contrib/nnvm_to_onnx-inl.h b/src/operator/contrib/nnvm_to_onnx-inl.h
index 58f88b0..011ffe6 100644
--- a/src/operator/contrib/nnvm_to_onnx-inl.h
+++ b/src/operator/contrib/nnvm_to_onnx-inl.h
@@ -37,7 +37,6 @@
 #include <nnvm/graph.h>
 #include <nnvm/pass_functions.h>
 
-#include <NvInfer.h>
 #include <onnx/onnx.pb.h>
 
 #include <algorithm>
@@ -49,13 +48,48 @@
 #include <utility>
 #include <string>
 
-#include "./tensorrt-inl.h"
 #include "../operator_common.h"
 #include "../../common/utils.h"
 #include "../../common/serialization.h"
 
 namespace mxnet {
 namespace op {
+
+namespace nnvm_to_onnx {
+    enum class TypeIO { Inputs = 0, Outputs = 1 };
+    using NameToIdx_t = std::map<std::string, int32_t>;
+    using InferenceTuple_t = std::tuple<uint32_t, TShape, int, int>;
+    using InferenceMap_t = std::map<std::string, InferenceTuple_t>;
+}  // namespace nnvm_to_onnx
+
+struct ONNXParam : public dmlc::Parameter<ONNXParam> {
+  std::string serialized_onnx_graph;
+  std::string serialized_input_map;
+  std::string serialized_output_map;
+  nnvm_to_onnx::NameToIdx_t input_map;
+  nnvm_to_onnx::InferenceMap_t output_map;
+  ::onnx::ModelProto onnx_pb_graph;
+
+  ONNXParam() {}
+
+  ONNXParam(const ::onnx::ModelProto& onnx_graph,
+           const nnvm_to_onnx::InferenceMap_t& input_map,
+           const nnvm_to_onnx::NameToIdx_t& output_map) {
+    common::Serialize(input_map, &serialized_input_map);
+    common::Serialize(output_map, &serialized_output_map);
+    onnx_graph.SerializeToString(&serialized_onnx_graph);
+  }
+
+DMLC_DECLARE_PARAMETER(ONNXParam) {
+    DMLC_DECLARE_FIELD(serialized_onnx_graph)
+    .describe("Serialized ONNX graph");
+    DMLC_DECLARE_FIELD(serialized_input_map)
+    .describe("Map from inputs to topological order as input.");
+    DMLC_DECLARE_FIELD(serialized_output_map)
+    .describe("Map from outputs to order in g.outputs.");
+  }
+};
+
 namespace nnvm_to_onnx {
 
 using namespace nnvm;
@@ -76,7 +110,7 @@ void ConvertConstant(GraphProto* const graph_proto,
   const std::string& node_name,
   std::unordered_map<std::string, NDArray>* const shared_buffer);
 
-void ConvertOutput(op::tensorrt::InferenceMap_t* const trt_output_map,
+void ConvertOutput(op::nnvm_to_onnx::InferenceMap_t* const trt_output_map,
                    GraphProto* const graph_proto,
                    const std::unordered_map<std::string, uint32_t>::iterator& out_iter,
                    const std::string& node_name,
@@ -133,7 +167,7 @@ void ConvertElementwiseAdd(NodeProto *node_proto,
                     const nnvm::IndexedGraph &ig,
                     const array_view<IndexedGraph::NodeEntry> &inputs);
 
-TRTParam ConvertNnvmGraphToOnnx(
+ONNXParam ConvertNnvmGraphToOnnx(
     const nnvm::Graph &g,
     std::unordered_map<std::string, NDArray> *const shared_buffer);
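
Both the ONNXParam constructor above and ConvertNnvmGraphToOnnx rely on common::Serialize to flatten these maps into strings. A hypothetical stand-in for that round-trip (the real mxnet::common implementation is binary, not line-based; this only illustrates the contract that Deserialize(Serialize(m)) == m):

    #include <cstdint>
    #include <map>
    #include <sstream>
    #include <string>

    using NameToIdx_t = std::map<std::string, int32_t>;

    // Illustrative stand-ins for mxnet::common::Serialize/Deserialize.
    std::string Serialize(const NameToIdx_t& m) {
      std::ostringstream os;
      for (const auto& kv : m) os << kv.first << '\n' << kv.second << '\n';
      return os.str();
    }

    NameToIdx_t Deserialize(const std::string& s) {
      NameToIdx_t m;
      std::istringstream is(s);
      std::string name, idx;
      while (std::getline(is, name) && std::getline(is, idx)) m[name] = std::stoi(idx);
      return m;
    }

    int main() {
      const NameToIdx_t in{{"data", 0}, {"softmax_label", 1}};
      return Deserialize(Serialize(in)) == in ? 0 : 1;  // round-trip check
    }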
 
diff --git a/src/operator/contrib/nnvm_to_onnx.cc b/src/operator/contrib/nnvm_to_onnx.cc
index 9024666..784384e 100644
--- a/src/operator/contrib/nnvm_to_onnx.cc
+++ b/src/operator/contrib/nnvm_to_onnx.cc
@@ -47,7 +47,6 @@
 #include "../../operator/nn/fully_connected-inl.h"
 #include "../../operator/nn/pooling-inl.h"
 #include "../../operator/softmax_output-inl.h"
-#include "./tensorrt-inl.h"
 
 #if MXNET_USE_TENSORRT_ONNX_CHECKER
 #include <onnx/checker.h>
@@ -55,14 +54,17 @@
 
 namespace mxnet {
 namespace op {
+
+DMLC_REGISTER_PARAMETER(ONNXParam);
+
 namespace nnvm_to_onnx {
 
-op::TRTParam ConvertNnvmGraphToOnnx(
+op::ONNXParam ConvertNnvmGraphToOnnx(
     const nnvm::Graph& g,
     std::unordered_map<std::string, NDArray>* const shared_buffer) {
-    op::TRTParam trt_param;
-    op::tensorrt::NameToIdx_t trt_input_map;
-    op::tensorrt::InferenceMap_t trt_output_map;
+    op::ONNXParam onnx_param;
+    op::nnvm_to_onnx::NameToIdx_t onnx_input_map;
+    op::nnvm_to_onnx::InferenceMap_t onnx_output_map;
 
   const nnvm::IndexedGraph& ig = g.indexed_graph();
   const auto& storage_types = g.GetAttr<StorageTypeVector>("storage_type");
@@ -105,7 +107,7 @@ op::TRTParam ConvertNnvmGraphToOnnx(
           current_input++;
           continue;
         }
-        trt_input_map.emplace(node_name, current_input++);
+        onnx_input_map.emplace(node_name, current_input++);
         ConvertPlaceholder(node_name, placeholder_shapes, graph_proto);
       } else {
         // If it's not a placeholder, then by exclusion it's a constant.
@@ -140,23 +142,23 @@ op::TRTParam ConvertNnvmGraphToOnnx(
       auto out_iter = output_lookup.find(node_name);
       // We found an output
       if (out_iter != output_lookup.end()) {
-        ConvertOutput(&trt_output_map, graph_proto, out_iter, node_name, g,
+        ConvertOutput(&onnx_output_map, graph_proto, out_iter, node_name, g,
                       storage_types, dtypes);
       }  // output found
     }    // conversion function exists
   }      // loop over i from 0 to num_nodes
 
-  model_proto.SerializeToString(&trt_param.serialized_onnx_graph);
-  common::Serialize<op::tensorrt::NameToIdx_t>(trt_input_map,
-                                          &trt_param.serialized_input_map);
-  common::Serialize<op::tensorrt::InferenceMap_t>(trt_output_map,
-                                             &trt_param.serialized_output_map);
+  model_proto.SerializeToString(&onnx_param.serialized_onnx_graph);
+  common::Serialize<op::nnvm_to_onnx::NameToIdx_t>(onnx_input_map,
+                                          &onnx_param.serialized_input_map);
+  common::Serialize<op::nnvm_to_onnx::InferenceMap_t>(onnx_output_map,
+                                             &onnx_param.serialized_output_map);
 
 #if MXNET_USE_TENSORRT_ONNX_CHECKER
   onnx::checker::check_model(model_proto);
 #endif  // MXNET_USE_TENSORRT_ONNX_CHECKER
 
-  return trt_param;
+  return onnx_param;
 }
 
 void ConvertConvolution(NodeProto* node_proto, const NodeAttrs& attrs,
@@ -489,7 +491,7 @@ void ConvertConstant(
 }
 
 void ConvertOutput(
-    op::tensorrt::InferenceMap_t* const trt_output_map,
+    op::nnvm_to_onnx::InferenceMap_t* const output_map,
     GraphProto* const graph_proto,
     const std::unordered_map<std::string, uint32_t>::iterator& out_iter,
     const std::string& node_name, const nnvm::Graph& g,
@@ -501,10 +503,10 @@ void ConvertOutput(
   int dtype = dtypes[out_idx];
 
   // This should work with fp16 as well
-  op::tensorrt::InferenceTuple_t out_tuple{out_iter->second, out_shape, storage_type,
+  op::nnvm_to_onnx::InferenceTuple_t out_tuple{out_iter->second, out_shape, storage_type,
                                       dtype};
 
-  trt_output_map->emplace(node_name, out_tuple);
+  output_map->emplace(node_name, out_tuple);
 
   auto graph_out = graph_proto->add_output();
   auto tensor_type = graph_out->mutable_type()->mutable_tensor_type();
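
For reference, each entry ConvertOutput records keys a node name to an InferenceTuple_t of (position in g.outputs, shape, storage type, dtype). A standard-library sketch with assumed values:

    #include <cstdint>
    #include <map>
    #include <string>
    #include <tuple>
    #include <vector>

    using Shape = std::vector<int64_t>;  // stand-in for mxnet::TShape
    using InferenceTuple_t = std::tuple<uint32_t, Shape, int, int>;
    using InferenceMap_t = std::map<std::string, InferenceTuple_t>;

    int main() {
      InferenceMap_t output_map;
      uint32_t out_idx = 0;      // the output's position in g.outputs
      Shape out_shape{32, 10};   // inferred shape for this output
      int storage_type = 0;      // dense storage (value assumed for illustration)
      int dtype = 0;             // float32 (value assumed for illustration)
      output_map.emplace("softmax",
                         InferenceTuple_t{out_idx, out_shape, storage_type, dtype});
      return output_map.count("softmax") == 1 ? 0 : 1;
    }
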
diff --git a/src/operator/contrib/tensorrt-inl.h b/src/operator/contrib/tensorrt-inl.h
index be335ab..062d22e 100644
--- a/src/operator/contrib/tensorrt-inl.h
+++ b/src/operator/contrib/tensorrt-inl.h
@@ -38,7 +38,6 @@
 #include <nnvm/pass_functions.h>
 
 #include <NvInfer.h>
-#include <onnx/onnx.pb.h>
 
 #include <algorithm>
 #include <iostream>
@@ -49,6 +48,7 @@
 #include <utility>
 #include <string>
 
+#include "nnvm_to_onnx-inl.h"
 #include "../operator_common.h"
 #include "../../common/utils.h"
 #include "../../common/serialization.h"
@@ -60,49 +60,15 @@ namespace mxnet {
 namespace op {
 
 using namespace nnvm;
-using namespace ::onnx;
 using int64 = ::google::protobuf::int64;
 
-namespace tensorrt {
-  enum class TypeIO { Inputs = 0, Outputs = 1 };
-  using NameToIdx_t = std::map<std::string, int32_t>;
-  using InferenceTuple_t = std::tuple<uint32_t, TShape, int, int>;
-  using InferenceMap_t = std::map<std::string, InferenceTuple_t>;
-}  // namespace tensorrt
 
 using trt_name_to_idx = std::map<std::string, uint32_t>;
 
-struct TRTParam : public dmlc::Parameter<TRTParam> {
-  std::string serialized_onnx_graph;
-  std::string serialized_input_map;
-  std::string serialized_output_map;
-  tensorrt::NameToIdx_t input_map;
-  tensorrt::InferenceMap_t output_map;
-  ::onnx::ModelProto onnx_pb_graph;
-
-  TRTParam() {}
-
-  TRTParam(const ::onnx::ModelProto& onnx_graph,
-           const tensorrt::InferenceMap_t& input_map,
-           const tensorrt::NameToIdx_t& output_map) {
-    common::Serialize(input_map, &serialized_input_map);
-    common::Serialize(output_map, &serialized_output_map);
-    onnx_graph.SerializeToString(&serialized_onnx_graph);
-  }
-
-DMLC_DECLARE_PARAMETER(TRTParam) {
-    DMLC_DECLARE_FIELD(serialized_onnx_graph)
-    .describe("Serialized ONNX graph");
-    DMLC_DECLARE_FIELD(serialized_input_map)
-    .describe("Map from inputs to topological order as input.");
-    DMLC_DECLARE_FIELD(serialized_output_map)
-    .describe("Map from outputs to order in g.outputs.");
-  }
-};
 
 struct TRTEngineParam {
   nvinfer1::IExecutionContext* trt_executor;
-  std::vector<std::pair<uint32_t, tensorrt::TypeIO> > binding_map;
+  std::vector<std::pair<uint32_t, nnvm_to_onnx::TypeIO> > binding_map;
 };
 
 }  // namespace op
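
TRTEngineParam's binding_map pairs an MXNet-side tensor index with a TypeIO tag. It is filled in by GetPtrMapping (see the tensorrt.cc hunk below), which walks the engine's bindings in order; a sketch of that logic with the TensorRT engine mocked out:

    #include <cstdint>
    #include <map>
    #include <string>
    #include <utility>
    #include <vector>

    enum class TypeIO { Inputs = 0, Outputs = 1 };
    using NameToIdx_t = std::map<std::string, int32_t>;

    // Mock of one engine binding slot (the real API is nvinfer1::ICudaEngine).
    struct MockBinding { std::string name; bool is_input; };

    int main() {
      std::vector<MockBinding> engine_bindings{{"data", true}, {"softmax", false}};
      NameToIdx_t input_map{{"data", 0}};
      NameToIdx_t output_map{{"softmax", 0}};
      std::vector<std::pair<uint32_t, TypeIO>> binding_map;
      for (const auto& b : engine_bindings) {
        if (b.is_input) {
          binding_map.emplace_back(input_map[b.name], TypeIO::Inputs);
        } else {
          binding_map.emplace_back(output_map[b.name], TypeIO::Outputs);
        }
      }
      return binding_map.size() == 2 ? 0 : 1;
    }
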
diff --git a/src/operator/contrib/tensorrt.cc b/src/operator/contrib/tensorrt.cc
index 619fe1e..88a65fb 100644
--- a/src/operator/contrib/tensorrt.cc
+++ b/src/operator/contrib/tensorrt.cc
@@ -44,20 +44,18 @@
 namespace mxnet {
 namespace op {
 
-DMLC_REGISTER_PARAMETER(TRTParam);
-
 OpStatePtr GetPtrMapping(nvinfer1::ICudaEngine* trt_engine,
-                         tensorrt::NameToIdx_t input_map,
-                         tensorrt::NameToIdx_t output_map) {
+                         nnvm_to_onnx::NameToIdx_t input_map,
+                         nnvm_to_onnx::NameToIdx_t output_map) {
   TRTEngineParam param;
   for (int b = 0; b < trt_engine->getNbBindings(); ++b) {
     const std::string& binding_name = trt_engine->getBindingName(b);
     if (trt_engine->bindingIsInput(b)) {
       param.binding_map.emplace_back(input_map[binding_name],
-                                     tensorrt::TypeIO::Inputs);
+                                     nnvm_to_onnx::TypeIO::Inputs);
     } else {
       param.binding_map.emplace_back(output_map[binding_name],
-                                     tensorrt::TypeIO::Outputs);
+                                     nnvm_to_onnx::TypeIO::Outputs);
     }
   }
   param.trt_executor = trt_engine->createExecutionContext();
@@ -67,7 +65,7 @@ OpStatePtr GetPtrMapping(nvinfer1::ICudaEngine* trt_engine,
 OpStatePtr TRTCreateState(const nnvm::NodeAttrs& attrs, Context /*ctx*/,
                           const std::vector<TShape>& /*ishape*/,
                           const std::vector<int>& /*itype*/) {
-  const auto& node_param = nnvm::get<TRTParam>(attrs.parsed);
+  const auto& node_param = nnvm::get<ONNXParam>(attrs.parsed);
 
   ::onnx::ModelProto model_proto;
   bool success = model_proto.ParseFromString(node_param.serialized_onnx_graph);
@@ -82,7 +80,7 @@ OpStatePtr TRTCreateState(const nnvm::NodeAttrs& attrs, Context /*ctx*/,
   nvinfer1::ICudaEngine* const trt_engine = ::onnx_to_tensorrt::onnxToTrtCtx(
       node_param.serialized_onnx_graph, batch_size, 1 << 30);
 
-  tensorrt::NameToIdx_t output_map;
+  nnvm_to_onnx::NameToIdx_t output_map;
   for (auto& el : node_param.output_map) {
     output_map[el.first] = std::get<0>(el.second);
   }
@@ -90,7 +88,7 @@ OpStatePtr TRTCreateState(const nnvm::NodeAttrs& attrs, Context /*ctx*/,
 }
 
 void TRTParamParser(nnvm::NodeAttrs* attrs) {
-  TRTParam param_;
+  ONNXParam param_;
 
   try {
     param_.Init(attrs->dict);
@@ -114,7 +112,7 @@ void TRTParamParser(nnvm::NodeAttrs* attrs) {
 
 inline bool TRTInferShape(const NodeAttrs& attrs, std::vector<TShape>* /*in_shape*/,
                           std::vector<TShape>* out_shape) {
-  const auto &node_param = nnvm::get<TRTParam>(attrs.parsed);
+  const auto &node_param = nnvm::get<ONNXParam>(attrs.parsed);
   for (auto& el : node_param.output_map) {
     (*out_shape)[std::get<0>(el.second)] = std::get<1>(el.second);
   }
@@ -131,7 +129,7 @@ inline bool TRTInferStorageType(const NodeAttrs& /*attrs*/, const int /*dev_mask
 
 inline bool TRTInferType(const NodeAttrs& attrs, std::vector<int>* /*in_dtype*/,
                          std::vector<int>* out_dtype) {
-  const auto& node_param = nnvm::get<TRTParam>(attrs.parsed);
+  const auto& node_param = nnvm::get<ONNXParam>(attrs.parsed);
   for (auto& el : node_param.output_map) {
     (*out_dtype)[std::get<0>(el.second)] = std::get<3>(el.second);
   }
@@ -140,7 +138,7 @@ inline bool TRTInferType(const NodeAttrs& attrs, std::vector<int>* /*in_dtype*/,
 
 inline std::vector<std::string> TRTListInputNames(const NodeAttrs& attrs) {
   std::vector<std::string> output;
-  const auto& node_param = nnvm::get<TRTParam>(attrs.parsed);
+  const auto& node_param = nnvm::get<ONNXParam>(attrs.parsed);
   output.resize(node_param.input_map.size());
   for (auto& el : node_param.input_map) {
     output[el.second] = el.first;
@@ -150,7 +148,7 @@ inline std::vector<std::string> TRTListInputNames(const NodeAttrs& attrs) {
 
 inline std::vector<std::string> TRTListOutputNames(const NodeAttrs& attrs) {
   std::vector<std::string> output;
-  const auto& node_param = nnvm::get<TRTParam>(attrs.parsed);
+  const auto& node_param = nnvm::get<ONNXParam>(attrs.parsed);
   output.resize(node_param.output_map.size());
   for (auto& el : node_param.output_map) {
     output[std::get<0>(el.second)] = el.first;
@@ -162,11 +160,11 @@ NNVM_REGISTER_OP(_trt_op)
     .describe(R"code(TRT operation (one engine)
 )code" ADD_FILELINE)
     .set_num_inputs([](const NodeAttrs& attrs) {
-      const auto& node_param = nnvm::get<TRTParam>(attrs.parsed);
+      const auto& node_param = nnvm::get<ONNXParam>(attrs.parsed);
       return node_param.input_map.size();
     })
     .set_num_outputs([](const NodeAttrs& attrs) {
-      const auto& node_param = nnvm::get<TRTParam>(attrs.parsed);
+      const auto& node_param = nnvm::get<ONNXParam>(attrs.parsed);
       return node_param.output_map.size();
     })
     .set_attr_parser(TRTParamParser)
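
The inference hooks above all follow the same pattern: element 0 of each output's tuple selects the slot to fill, element 1 carries the shape, and element 3 the dtype. A runnable sketch of TRTInferShape's loop (std::vector<int64_t> again standing in for TShape):

    #include <cstdint>
    #include <map>
    #include <string>
    #include <tuple>
    #include <vector>

    using Shape = std::vector<int64_t>;  // stand-in for mxnet::TShape
    using InferenceTuple_t = std::tuple<uint32_t, Shape, int, int>;
    using InferenceMap_t = std::map<std::string, InferenceTuple_t>;

    // Fill each output's shape slot from the tuple recorded at conversion time.
    void InferShapes(const InferenceMap_t& output_map, std::vector<Shape>* out_shape) {
      for (const auto& el : output_map) {
        (*out_shape)[std::get<0>(el.second)] = std::get<1>(el.second);
      }
    }

    int main() {
      InferenceMap_t output_map{{"softmax", InferenceTuple_t{0u, Shape{32, 10}, 0, 0}}};
      std::vector<Shape> shapes(1);
      InferShapes(output_map, &shapes);
      return shapes[0] == Shape{32, 10} ? 0 : 1;
    }
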
diff --git a/src/operator/contrib/tensorrt.cu b/src/operator/contrib/tensorrt.cu
index 2fe8727..9a9c3c0 100644
--- a/src/operator/contrib/tensorrt.cu
+++ b/src/operator/contrib/tensorrt.cu
@@ -52,7 +52,7 @@ void TRTCompute(const OpStatePtr& state, const OpContext& ctx,
   std::vector<void*> bindings;
   bindings.reserve(param.binding_map.size());
   for (auto& p : param.binding_map) {
-    if (p.second == tensorrt::TypeIO::Inputs) {
+    if (p.second == nnvm_to_onnx::TypeIO::Inputs) {
       bindings.emplace_back(inputs[p.first].dptr_);
     } else {
       bindings.emplace_back(outputs[p.first].dptr_);
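
Finally, this tensorrt.cu hunk shows the consumer of binding_map: TRTCompute assembles the raw pointer list TensorRT expects by picking each pointer from the inputs or the outputs according to the TypeIO tag. A standard-library sketch of that dispatch (plain float buffers stand in for NDArray data pointers):

    #include <cstdint>
    #include <utility>
    #include <vector>

    enum class TypeIO { Inputs = 0, Outputs = 1 };

    int main() {
      float in0 = 1.f, out0 = 0.f;
      std::vector<void*> input_ptrs{&in0}, output_ptrs{&out0};
      std::vector<std::pair<uint32_t, TypeIO>> binding_map{
          {0u, TypeIO::Inputs}, {0u, TypeIO::Outputs}};
      // Each (index, TypeIO) pair selects a pointer from the matching list,
      // in the engine's binding order.
      std::vector<void*> bindings;
      bindings.reserve(binding_map.size());
      for (const auto& p : binding_map) {
        bindings.emplace_back(p.second == TypeIO::Inputs ? input_ptrs[p.first]
                                                         : output_ptrs[p.first]);
      }
      return bindings.size() == 2 ? 0 : 1;
    }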