Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2018/12/19 18:39:04 UTC

[GitHub] KellenSunderland closed pull request #13659: [MXNET-1252][1 of 2] Decouple NNVM to ONNX from NNVM to TensorRT conv…

KellenSunderland closed pull request #13659: [MXNET-1252][1 of 2] Decouple NNVM to ONNX from NNVM to TensorRT conv…
URL: https://github.com/apache/incubator-mxnet/pull/13659
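
In substance the change is a rename and relocation rather than new behaviour: TRTParam and the
tensorrt::{TypeIO, NameToIdx_t, InferenceTuple_t, InferenceMap_t} helpers move from tensorrt-inl.h
into nnvm_to_onnx-inl.h as ONNXParam and the nnvm_to_onnx namespace, and the NvInfer.h /
tensorrt-inl.h includes are dropped from the conversion code, so the NNVM-to-ONNX conversion no
longer depends on TensorRT headers. A condensed sketch of the interface after the change (abridged
from the headers in the diff below; includes and most of the header are omitted):

  namespace mxnet {
  namespace op {

  namespace nnvm_to_onnx {
    enum class TypeIO { Inputs = 0, Outputs = 1 };
    using NameToIdx_t      = std::map<std::string, int32_t>;
    // (index in g.outputs, shape, storage type, dtype)
    using InferenceTuple_t = std::tuple<uint32_t, TShape, int, int>;
    using InferenceMap_t   = std::map<std::string, InferenceTuple_t>;
  }  // namespace nnvm_to_onnx

  // Carries the serialized ONNX graph plus the input/output maps; no TensorRT types.
  struct ONNXParam : public dmlc::Parameter<ONNXParam> {
    std::string serialized_onnx_graph;
    std::string serialized_input_map;
    std::string serialized_output_map;
    nnvm_to_onnx::NameToIdx_t    input_map;
    nnvm_to_onnx::InferenceMap_t output_map;
    ::onnx::ModelProto onnx_pb_graph;
    // constructor and DMLC_DECLARE_PARAMETER block omitted here; see the diff below
  };

  namespace nnvm_to_onnx {
  // Converts an NNVM graph to ONNX, returning the serialized model and the maps.
  ONNXParam ConvertNnvmGraphToOnnx(
      const nnvm::Graph& g,
      std::unordered_map<std::string, NDArray>* const shared_buffer);
  }  // namespace nnvm_to_onnx

  }  // namespace op
  }  // namespace mxnet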
 
 
   

This is a pull request merged from a forked repository.
Because GitHub hides the original diff from a forked ("foreign") pull request
once it is merged, the diff is reproduced below for the sake of provenance:

diff --git a/src/executor/tensorrt_pass.cc b/src/executor/tensorrt_pass.cc
index b5fc8d15f7a..d26704c35cf 100644
--- a/src/executor/tensorrt_pass.cc
+++ b/src/executor/tensorrt_pass.cc
@@ -324,10 +324,10 @@ nnvm::NodePtr ConvertNnvmGraphToOnnx(const nnvm::Graph &g,
                                      std::unordered_map<std::string, NDArray>* const params_map) {
   auto p = nnvm::Node::Create();
   p->attrs.op = nnvm::Op::Get("_trt_op");
-  op::TRTParam trt_param = op::nnvm_to_onnx::ConvertNnvmGraphToOnnx(g, params_map);
-  p->attrs.dict["serialized_output_map"] = trt_param.serialized_output_map;
-  p->attrs.dict["serialized_input_map"]  = trt_param.serialized_input_map;
-  p->attrs.dict["serialized_onnx_graph"] = trt_param.serialized_onnx_graph;
+  op::ONNXParam onnx_param = op::nnvm_to_onnx::ConvertNnvmGraphToOnnx(g, params_map);
+  p->attrs.dict["serialized_output_map"] = onnx_param.serialized_output_map;
+  p->attrs.dict["serialized_input_map"]  = onnx_param.serialized_input_map;
+  p->attrs.dict["serialized_onnx_graph"] = onnx_param.serialized_onnx_graph;
   if (p->op()->attr_parser != nullptr) {
     p->op()->attr_parser(&(p->attrs));
   }
diff --git a/src/operator/contrib/nnvm_to_onnx-inl.h b/src/operator/contrib/nnvm_to_onnx-inl.h
index 58f88b05143..011ffe6b7dd 100644
--- a/src/operator/contrib/nnvm_to_onnx-inl.h
+++ b/src/operator/contrib/nnvm_to_onnx-inl.h
@@ -37,7 +37,6 @@
 #include <nnvm/graph.h>
 #include <nnvm/pass_functions.h>
 
-#include <NvInfer.h>
 #include <onnx/onnx.pb.h>
 
 #include <algorithm>
@@ -49,13 +48,48 @@
 #include <utility>
 #include <string>
 
-#include "./tensorrt-inl.h"
 #include "../operator_common.h"
 #include "../../common/utils.h"
 #include "../../common/serialization.h"
 
 namespace mxnet {
 namespace op {
+
+namespace nnvm_to_onnx {
+    enum class TypeIO { Inputs = 0, Outputs = 1 };
+    using NameToIdx_t = std::map<std::string, int32_t>;
+    using InferenceTuple_t = std::tuple<uint32_t, TShape, int, int>;
+    using InferenceMap_t = std::map<std::string, InferenceTuple_t>;
+}  // namespace nnvm_to_onnx
+
+struct ONNXParam : public dmlc::Parameter<ONNXParam> {
+  std::string serialized_onnx_graph;
+  std::string serialized_input_map;
+  std::string serialized_output_map;
+  nnvm_to_onnx::NameToIdx_t input_map;
+  nnvm_to_onnx::InferenceMap_t output_map;
+  ::onnx::ModelProto onnx_pb_graph;
+
+  ONNXParam() {}
+
+  ONNXParam(const ::onnx::ModelProto& onnx_graph,
+           const nnvm_to_onnx::InferenceMap_t& input_map,
+           const nnvm_to_onnx::NameToIdx_t& output_map) {
+    common::Serialize(input_map, &serialized_input_map);
+    common::Serialize(output_map, &serialized_output_map);
+    onnx_graph.SerializeToString(&serialized_onnx_graph);
+  }
+
+DMLC_DECLARE_PARAMETER(ONNXParam) {
+    DMLC_DECLARE_FIELD(serialized_onnx_graph)
+    .describe("Serialized ONNX graph");
+    DMLC_DECLARE_FIELD(serialized_input_map)
+    .describe("Map from inputs to topological order as input.");
+    DMLC_DECLARE_FIELD(serialized_output_map)
+    .describe("Map from outputs to order in g.outputs.");
+  }
+};
+
 namespace nnvm_to_onnx {
 
 using namespace nnvm;
@@ -76,7 +110,7 @@ void ConvertConstant(GraphProto* const graph_proto,
   const std::string& node_name,
   std::unordered_map<std::string, NDArray>* const shared_buffer);
 
-void ConvertOutput(op::tensorrt::InferenceMap_t* const trt_output_map,
+void ConvertOutput(op::nnvm_to_onnx::InferenceMap_t* const trt_output_map,
                    GraphProto* const graph_proto,
                    const std::unordered_map<std::string, uint32_t>::iterator& out_iter,
                    const std::string& node_name,
@@ -133,7 +167,7 @@ void ConvertElementwiseAdd(NodeProto *node_proto,
                     const nnvm::IndexedGraph &ig,
                     const array_view<IndexedGraph::NodeEntry> &inputs);
 
-TRTParam ConvertNnvmGraphToOnnx(
+ONNXParam ConvertNnvmGraphToOnnx(
     const nnvm::Graph &g,
     std::unordered_map<std::string, NDArray> *const shared_buffer);
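
With NvInfer.h and the tensorrt-inl.h include removed from this header, the conversion can in
principle be consumed without touching TensorRT at all. A hypothetical, TensorRT-free caller, for
illustration only (this is not code from the PR; it assumes g already carries the shape, dtype and
storage_type attributes and params holds the graph's weights, as ConvertNnvmGraphToOnnx expects):

  // Hypothetical consumer of the decoupled conversion (illustration only).
  op::ONNXParam onnx_param = op::nnvm_to_onnx::ConvertNnvmGraphToOnnx(g, &params);

  ::onnx::ModelProto model_proto;
  if (!model_proto.ParseFromString(onnx_param.serialized_onnx_graph)) {
    LOG(FATAL) << "Could not parse the serialized ONNX graph";
  }
  // model_proto can now be handed to any ONNX backend, not only TensorRT.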
 
diff --git a/src/operator/contrib/nnvm_to_onnx.cc b/src/operator/contrib/nnvm_to_onnx.cc
index 902466614c7..784384e94e1 100644
--- a/src/operator/contrib/nnvm_to_onnx.cc
+++ b/src/operator/contrib/nnvm_to_onnx.cc
@@ -47,7 +47,6 @@
 #include "../../operator/nn/fully_connected-inl.h"
 #include "../../operator/nn/pooling-inl.h"
 #include "../../operator/softmax_output-inl.h"
-#include "./tensorrt-inl.h"
 
 #if MXNET_USE_TENSORRT_ONNX_CHECKER
 #include <onnx/checker.h>
@@ -55,14 +54,17 @@
 
 namespace mxnet {
 namespace op {
+
+DMLC_REGISTER_PARAMETER(ONNXParam);
+
 namespace nnvm_to_onnx {
 
-op::TRTParam ConvertNnvmGraphToOnnx(
+op::ONNXParam ConvertNnvmGraphToOnnx(
     const nnvm::Graph& g,
     std::unordered_map<std::string, NDArray>* const shared_buffer) {
-    op::TRTParam trt_param;
-    op::tensorrt::NameToIdx_t trt_input_map;
-    op::tensorrt::InferenceMap_t trt_output_map;
+    op::ONNXParam onnx_param;
+    op::nnvm_to_onnx::NameToIdx_t onnx_input_map;
+    op::nnvm_to_onnx::InferenceMap_t onnx_output_map;
 
   const nnvm::IndexedGraph& ig = g.indexed_graph();
   const auto& storage_types = g.GetAttr<StorageTypeVector>("storage_type");
@@ -105,7 +107,7 @@ op::TRTParam ConvertNnvmGraphToOnnx(
           current_input++;
           continue;
         }
-        trt_input_map.emplace(node_name, current_input++);
+        onnx_input_map.emplace(node_name, current_input++);
         ConvertPlaceholder(node_name, placeholder_shapes, graph_proto);
       } else {
         // If it's not a placeholder, then by exclusion it's a constant.
@@ -140,23 +142,23 @@ op::TRTParam ConvertNnvmGraphToOnnx(
       auto out_iter = output_lookup.find(node_name);
       // We found an output
       if (out_iter != output_lookup.end()) {
-        ConvertOutput(&trt_output_map, graph_proto, out_iter, node_name, g,
+        ConvertOutput(&onnx_output_map, graph_proto, out_iter, node_name, g,
                       storage_types, dtypes);
       }  // output found
     }    // conversion function exists
   }      // loop over i from 0 to num_nodes
 
-  model_proto.SerializeToString(&trt_param.serialized_onnx_graph);
-  common::Serialize<op::tensorrt::NameToIdx_t>(trt_input_map,
-                                          &trt_param.serialized_input_map);
-  common::Serialize<op::tensorrt::InferenceMap_t>(trt_output_map,
-                                             &trt_param.serialized_output_map);
+  model_proto.SerializeToString(&onnx_param.serialized_onnx_graph);
+  common::Serialize<op::nnvm_to_onnx::NameToIdx_t>(onnx_input_map,
+                                          &onnx_param.serialized_input_map);
+  common::Serialize<op::nnvm_to_onnx::InferenceMap_t>(onnx_output_map,
+                                             &onnx_param.serialized_output_map);
 
 #if MXNET_USE_TENSORRT_ONNX_CHECKER
   onnx::checker::check_model(model_proto);
 #endif  // MXNET_USE_TENSORRT_ONNX_CHECKER
 
-  return trt_param;
+  return onnx_param;
 }
 
 void ConvertConvolution(NodeProto* node_proto, const NodeAttrs& attrs,
@@ -489,7 +491,7 @@ void ConvertConstant(
 }
 
 void ConvertOutput(
-    op::tensorrt::InferenceMap_t* const trt_output_map,
+    op::nnvm_to_onnx::InferenceMap_t* const output_map,
     GraphProto* const graph_proto,
     const std::unordered_map<std::string, uint32_t>::iterator& out_iter,
     const std::string& node_name, const nnvm::Graph& g,
@@ -501,10 +503,10 @@ void ConvertOutput(
   int dtype = dtypes[out_idx];
 
   // This should work with fp16 as well
-  op::tensorrt::InferenceTuple_t out_tuple{out_iter->second, out_shape, storage_type,
+  op::nnvm_to_onnx::InferenceTuple_t out_tuple{out_iter->second, out_shape, storage_type,
                                       dtype};
 
-  trt_output_map->emplace(node_name, out_tuple);
+  output_map->emplace(node_name, out_tuple);
 
   auto graph_out = graph_proto->add_output();
   auto tensor_type = graph_out->mutable_type()->mutable_tensor_type();
diff --git a/src/operator/contrib/tensorrt-inl.h b/src/operator/contrib/tensorrt-inl.h
index be335ab1208..062d22e3579 100644
--- a/src/operator/contrib/tensorrt-inl.h
+++ b/src/operator/contrib/tensorrt-inl.h
@@ -38,7 +38,6 @@
 #include <nnvm/pass_functions.h>
 
 #include <NvInfer.h>
-#include <onnx/onnx.pb.h>
 
 #include <algorithm>
 #include <iostream>
@@ -49,6 +48,7 @@
 #include <utility>
 #include <string>
 
+#include "nnvm_to_onnx-inl.h"
 #include "../operator_common.h"
 #include "../../common/utils.h"
 #include "../../common/serialization.h"
@@ -60,49 +60,15 @@ namespace mxnet {
 namespace op {
 
 using namespace nnvm;
-using namespace ::onnx;
 using int64 = ::google::protobuf::int64;
 
-namespace tensorrt {
-  enum class TypeIO { Inputs = 0, Outputs = 1 };
-  using NameToIdx_t = std::map<std::string, int32_t>;
-  using InferenceTuple_t = std::tuple<uint32_t, TShape, int, int>;
-  using InferenceMap_t = std::map<std::string, InferenceTuple_t>;
-}  // namespace tensorrt
 
 using trt_name_to_idx = std::map<std::string, uint32_t>;
 
-struct TRTParam : public dmlc::Parameter<TRTParam> {
-  std::string serialized_onnx_graph;
-  std::string serialized_input_map;
-  std::string serialized_output_map;
-  tensorrt::NameToIdx_t input_map;
-  tensorrt::InferenceMap_t output_map;
-  ::onnx::ModelProto onnx_pb_graph;
-
-  TRTParam() {}
-
-  TRTParam(const ::onnx::ModelProto& onnx_graph,
-           const tensorrt::InferenceMap_t& input_map,
-           const tensorrt::NameToIdx_t& output_map) {
-    common::Serialize(input_map, &serialized_input_map);
-    common::Serialize(output_map, &serialized_output_map);
-    onnx_graph.SerializeToString(&serialized_onnx_graph);
-  }
-
-DMLC_DECLARE_PARAMETER(TRTParam) {
-    DMLC_DECLARE_FIELD(serialized_onnx_graph)
-    .describe("Serialized ONNX graph");
-    DMLC_DECLARE_FIELD(serialized_input_map)
-    .describe("Map from inputs to topological order as input.");
-    DMLC_DECLARE_FIELD(serialized_output_map)
-    .describe("Map from outputs to order in g.outputs.");
-  }
-};
 
 struct TRTEngineParam {
   nvinfer1::IExecutionContext* trt_executor;
-  std::vector<std::pair<uint32_t, tensorrt::TypeIO> > binding_map;
+  std::vector<std::pair<uint32_t, nnvm_to_onnx::TypeIO> > binding_map;
 };
 
 }  // namespace op
diff --git a/src/operator/contrib/tensorrt.cc b/src/operator/contrib/tensorrt.cc
index 619fe1e2b8f..88a65fba3ea 100644
--- a/src/operator/contrib/tensorrt.cc
+++ b/src/operator/contrib/tensorrt.cc
@@ -44,20 +44,18 @@
 namespace mxnet {
 namespace op {
 
-DMLC_REGISTER_PARAMETER(TRTParam);
-
 OpStatePtr GetPtrMapping(nvinfer1::ICudaEngine* trt_engine,
-                         tensorrt::NameToIdx_t input_map,
-                         tensorrt::NameToIdx_t output_map) {
+                         nnvm_to_onnx::NameToIdx_t input_map,
+                         nnvm_to_onnx::NameToIdx_t output_map) {
   TRTEngineParam param;
   for (int b = 0; b < trt_engine->getNbBindings(); ++b) {
     const std::string& binding_name = trt_engine->getBindingName(b);
     if (trt_engine->bindingIsInput(b)) {
       param.binding_map.emplace_back(input_map[binding_name],
-                                     tensorrt::TypeIO::Inputs);
+                                     nnvm_to_onnx::TypeIO::Inputs);
     } else {
       param.binding_map.emplace_back(output_map[binding_name],
-                                     tensorrt::TypeIO::Outputs);
+                                     nnvm_to_onnx::TypeIO::Outputs);
     }
   }
   param.trt_executor = trt_engine->createExecutionContext();
@@ -67,7 +65,7 @@ OpStatePtr GetPtrMapping(nvinfer1::ICudaEngine* trt_engine,
 OpStatePtr TRTCreateState(const nnvm::NodeAttrs& attrs, Context /*ctx*/,
                           const std::vector<TShape>& /*ishape*/,
                           const std::vector<int>& /*itype*/) {
-  const auto& node_param = nnvm::get<TRTParam>(attrs.parsed);
+  const auto& node_param = nnvm::get<ONNXParam>(attrs.parsed);
 
   ::onnx::ModelProto model_proto;
   bool success = model_proto.ParseFromString(node_param.serialized_onnx_graph);
@@ -82,7 +80,7 @@ OpStatePtr TRTCreateState(const nnvm::NodeAttrs& attrs, Context /*ctx*/,
   nvinfer1::ICudaEngine* const trt_engine = ::onnx_to_tensorrt::onnxToTrtCtx(
       node_param.serialized_onnx_graph, batch_size, 1 << 30);
 
-  tensorrt::NameToIdx_t output_map;
+  nnvm_to_onnx::NameToIdx_t output_map;
   for (auto& el : node_param.output_map) {
     output_map[el.first] = std::get<0>(el.second);
   }
@@ -90,7 +88,7 @@ OpStatePtr TRTCreateState(const nnvm::NodeAttrs& attrs, Context /*ctx*/,
 }
 
 void TRTParamParser(nnvm::NodeAttrs* attrs) {
-  TRTParam param_;
+  ONNXParam param_;
 
   try {
     param_.Init(attrs->dict);
@@ -114,7 +112,7 @@ void TRTParamParser(nnvm::NodeAttrs* attrs) {
 
 inline bool TRTInferShape(const NodeAttrs& attrs, std::vector<TShape>* /*in_shape*/,
                           std::vector<TShape>* out_shape) {
-  const auto &node_param = nnvm::get<TRTParam>(attrs.parsed);
+  const auto &node_param = nnvm::get<ONNXParam>(attrs.parsed);
   for (auto& el : node_param.output_map) {
     (*out_shape)[std::get<0>(el.second)] = std::get<1>(el.second);
   }
@@ -131,7 +129,7 @@ inline bool TRTInferStorageType(const NodeAttrs& /*attrs*/, const int /*dev_mask
 
 inline bool TRTInferType(const NodeAttrs& attrs, std::vector<int>* /*in_dtype*/,
                          std::vector<int>* out_dtype) {
-  const auto& node_param = nnvm::get<TRTParam>(attrs.parsed);
+  const auto& node_param = nnvm::get<ONNXParam>(attrs.parsed);
   for (auto& el : node_param.output_map) {
     (*out_dtype)[std::get<0>(el.second)] = std::get<3>(el.second);
   }
@@ -140,7 +138,7 @@ inline bool TRTInferType(const NodeAttrs& attrs, std::vector<int>* /*in_dtype*/,
 
 inline std::vector<std::string> TRTListInputNames(const NodeAttrs& attrs) {
   std::vector<std::string> output;
-  const auto& node_param = nnvm::get<TRTParam>(attrs.parsed);
+  const auto& node_param = nnvm::get<ONNXParam>(attrs.parsed);
   output.resize(node_param.input_map.size());
   for (auto& el : node_param.input_map) {
     output[el.second] = el.first;
@@ -150,7 +148,7 @@ inline std::vector<std::string> TRTListInputNames(const NodeAttrs& attrs) {
 
 inline std::vector<std::string> TRTListOutputNames(const NodeAttrs& attrs) {
   std::vector<std::string> output;
-  const auto& node_param = nnvm::get<TRTParam>(attrs.parsed);
+  const auto& node_param = nnvm::get<ONNXParam>(attrs.parsed);
   output.resize(node_param.output_map.size());
   for (auto& el : node_param.output_map) {
     output[std::get<0>(el.second)] = el.first;
@@ -162,11 +160,11 @@ NNVM_REGISTER_OP(_trt_op)
     .describe(R"code(TRT operation (one engine)
 )code" ADD_FILELINE)
     .set_num_inputs([](const NodeAttrs& attrs) {
-      const auto& node_param = nnvm::get<TRTParam>(attrs.parsed);
+      const auto& node_param = nnvm::get<ONNXParam>(attrs.parsed);
       return node_param.input_map.size();
     })
     .set_num_outputs([](const NodeAttrs& attrs) {
-      const auto& node_param = nnvm::get<TRTParam>(attrs.parsed);
+      const auto& node_param = nnvm::get<ONNXParam>(attrs.parsed);
       return node_param.output_map.size();
     })
     .set_attr_parser(TRTParamParser)
diff --git a/src/operator/contrib/tensorrt.cu b/src/operator/contrib/tensorrt.cu
index 2fe8727b73e..9a9c3c02436 100644
--- a/src/operator/contrib/tensorrt.cu
+++ b/src/operator/contrib/tensorrt.cu
@@ -52,7 +52,7 @@ void TRTCompute(const OpStatePtr& state, const OpContext& ctx,
   std::vector<void*> bindings;
   bindings.reserve(param.binding_map.size());
   for (auto& p : param.binding_map) {
-    if (p.second == tensorrt::TypeIO::Inputs) {
+    if (p.second == nnvm_to_onnx::TypeIO::Inputs) {
       bindings.emplace_back(inputs[p.first].dptr_);
     } else {
       bindings.emplace_back(outputs[p.first].dptr_);
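
Taken together, the runtime path after the rename is unchanged apart from the namespace: the graph
pass stores the serialized ONNX graph and maps in the _trt_op node's attribute dict, TRTParamParser
reconstructs an ONNXParam from them, TRTCreateState builds the TensorRT engine from the serialized
graph, and GetPtrMapping tags each engine binding as an input or an output with the relocated
TypeIO enum, which TRTCompute then uses to pick the right buffer. A condensed sketch of that
binding step, abridged from tensorrt.cc and tensorrt.cu above (declarations and error handling
omitted):

  // Abridged from GetPtrMapping (tensorrt.cc) after this change.
  TRTEngineParam param;
  for (int b = 0; b < trt_engine->getNbBindings(); ++b) {
    const std::string& name = trt_engine->getBindingName(b);
    const bool is_input = trt_engine->bindingIsInput(b);
    param.binding_map.emplace_back(
        is_input ? input_map[name] : output_map[name],
        is_input ? nnvm_to_onnx::TypeIO::Inputs : nnvm_to_onnx::TypeIO::Outputs);
  }
  param.trt_executor = trt_engine->createExecutionContext();

  // Later, in TRTCompute (tensorrt.cu), each pair selects the matching MXNet blob:
  std::vector<void*> bindings;
  bindings.reserve(param.binding_map.size());
  for (auto& p : param.binding_map) {
    bindings.emplace_back(p.second == nnvm_to_onnx::TypeIO::Inputs
                              ? inputs[p.first].dptr_
                              : outputs[p.first].dptr_);
  }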


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services