You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2018/03/10 07:11:26 UTC

[GitHub] cjolivier01 closed pull request #9984: [MXNET-38] add reshape predictor function to c_predict_api

cjolivier01 closed pull request #9984: [MXNET-38] add reshape predictor function to c_predict_api
URL: https://github.com/apache/incubator-mxnet/pull/9984
 
 
   

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:

As this is a foreign pull request (from a fork), the diff is supplied
below (as it won't show otherwise due to GitHub magic):

diff --git a/include/mxnet/c_predict_api.h b/include/mxnet/c_predict_api.h
index e4bfb398d53..a77d77702fe 100644
--- a/include/mxnet/c_predict_api.h
+++ b/include/mxnet/c_predict_api.h
@@ -119,6 +119,34 @@ MXNET_DLL int MXPredCreatePartialOut(const char* symbol_json_str,
                                      mx_uint num_output_nodes,
                                      const char** output_keys,
                                      PredictorHandle* out);
+/*!
+ * \brief Change the input shapes of an existing predictor.
+ *
+ *  A new predictor is created from the symbol, context and arrays of
+ *  \p handle, with the inputs named in \p input_keys given new shapes.
+ *  Only those inputs may change shape; every other argument and auxiliary
+ *  state must keep its number of elements. The argument and auxiliary
+ *  arrays are handed over to the new predictor and cleared from \p handle,
+ *  so \p handle should not be used for prediction afterwards.
+ * \param num_input_nodes Number of input nodes whose shapes are given.
+ *    For feedforward net, this is 1.
+ * \param input_keys The names of the input arguments.
+ *    For feedforward net, this is {"data"}
+ * \param input_shape_indptr Index pointer of shapes of each input node.
+ *    The length of this array = num_input_nodes + 1.
+ *    For feedforward net that takes 4 dimensional input, this is {0, 4}.
+ * \param input_shape_data A flattened array of the shapes of each input node.
+ *    For feedforward net that takes 4 dimensional input, this is the shape data.
+ * \param handle The original predictor handle.
+ * \param out The reshaped predictor handle; a new handle distinct from \p handle.
+ * \return 0 when success, -1 when failure.
+ */
+MXNET_DLL int MXPredReshape(mx_uint num_input_nodes,
+                            const char** input_keys,
+                            const mx_uint* input_shape_indptr,
+                            const mx_uint* input_shape_data,
+                            PredictorHandle handle,
+                            PredictorHandle* out);
 /*!
  * \brief Get the shape of output node.
  *  The returned shape_data and shape_ndim is only valid before next call to MXPred function.
diff --git a/src/c_api/c_predict_api.cc b/src/c_api/c_predict_api.cc
index 3a693dbfcb9..becb0cb364f 100644
--- a/src/c_api/c_predict_api.cc
+++ b/src/c_api/c_predict_api.cc
@@ -43,6 +43,8 @@ struct MXAPIPredictor {
   std::vector<NDArray> out_arrays;
   // argument arrays
   std::vector<NDArray> arg_arrays;
+  // auxiliary arrays
+  std::vector<NDArray> aux_arrays;
   // output shapes
   std::vector<TShape> out_shapes;
   // uint32_t buffer for output shapes
@@ -51,6 +53,10 @@ struct MXAPIPredictor {
   std::unordered_map<std::string, size_t> key2arg;
   // executor
   std::unique_ptr<Executor> exec;
+  // symbol
+  nnvm::Symbol sym;
+  // Context
+  Context ctx;
 };
 
 struct MXAPINDList {
@@ -243,6 +249,121 @@ int MXPredCreatePartialOut(const char* symbol_json_str,
   API_END_HANDLE_ERROR(delete ret);
 }
 
+/*!
+ * \brief Create a predictor with new input shapes from an existing one.
+ *
+ *  Shape inference is re-run on the original symbol with the requested input
+ *  shapes. Only the listed inputs may change shape; every other argument and
+ *  every auxiliary state must keep its element count. All shapes are checked
+ *  before anything is modified, so a rejected request leaves \p handle as it
+ *  was. On success the argument and auxiliary arrays are handed to the new
+ *  predictor and cleared from \p handle.
+ *
+ *  NOTE(review): sym, ctx and aux_arrays must already have been stored in
+ *  \p handle by the function that created it (presumably
+ *  MXPredCreatePartialOut); confirm they are populated there.
+ */
+int MXPredReshape(mx_uint num_input_nodes,
+                  const char** input_keys,
+                  const mx_uint* input_shape_indptr,
+                  const mx_uint* input_shape_data,
+                  PredictorHandle handle,
+                  PredictorHandle* out) {
+  MXAPIPredictor* p = static_cast<MXAPIPredictor*>(handle);
+  std::unique_ptr<MXAPIPredictor> ret(new MXAPIPredictor());
+
+  API_BEGIN();
+  // Requested shapes, keyed by input name.
+  std::unordered_map<std::string, TShape> new_shape;
+  for (mx_uint i = 0; i < num_input_nodes; ++i) {
+    new_shape[std::string(input_keys[i])] =
+        TShape(input_shape_data + input_shape_indptr[i],
+            input_shape_data + input_shape_indptr[i + 1]);
+  }
+  ret->sym = p->sym;
+  std::vector<std::string> arg_names = ret->sym.ListInputNames(Symbol::kReadOnlyArgs);
+  std::vector<std::string> aux_names = ret->sym.ListInputNames(Symbol::kAuxiliaryStates);
+  std::vector<TShape> out_shapes(ret->sym.ListOutputNames().size());
+  std::vector<TShape> aux_shapes(aux_names.size());
+  std::vector<TShape> arg_shapes;
+  ret->key2arg = p->key2arg;
+
+  // The loops below index p's arrays by argument / auxiliary position. A
+  // handle whose arrays were already cleared by an earlier reshape fails here
+  // instead of being indexed out of range.
+  CHECK_EQ(p->arg_arrays.size(), arg_names.size())
+    << "The argument arrays of the predictor do not match its symbol";
+  CHECK_EQ(p->aux_arrays.size(), aux_names.size())
+    << "The auxiliary arrays of the predictor do not match its symbol";
+
+  try {
+    // Inputs without a requested shape get a default TShape and are left
+    // for InferShape to determine.
+    const std::vector<std::string> all_inputs = ret->sym.ListInputNames(Symbol::kAll);
+    std::vector<TShape> in_shapes;
+    in_shapes.reserve(all_inputs.size());
+    for (const std::string& key : all_inputs) {
+      auto it = new_shape.find(key);
+      in_shapes.push_back(it != new_shape.end() ? it->second : TShape());
+    }
+    nnvm::Graph g; g.outputs = ret->sym.outputs;
+    g = mxnet::exec::InferShape(std::move(g), std::move(in_shapes), "__shape__");
+    bool infer_complete = (g.GetAttr<size_t>("shape_num_unknown_nodes") == 0);
+    CHECK(infer_complete)
+      << "The shape information is not enough to infer the shapes";
+    CopyAttr(g.indexed_graph(),
+             g.GetAttr<nnvm::ShapeVector>("shape"),
+             &arg_shapes, &out_shapes, &aux_shapes);
+  } catch (const mxnet::op::InferShapeError &err) {
+    throw dmlc::Error(err.msg);
+  }
+
+  // Validate before touching any array: only requested inputs may change size.
+  for (size_t i = 0; i < arg_names.size(); ++i) {
+    if (new_shape.count(arg_names[i]) == 0) {
+      CHECK_EQ(arg_shapes[i].Size(), p->arg_arrays[i].shape().Size())
+        << "arg " << arg_names[i]
+        << " shape has been changed, only allow to change the shape of input data.";
+    }
+  }
+  for (size_t i = 0; i < aux_names.size(); ++i) {
+    CHECK_EQ(aux_shapes[i].Size(), p->aux_arrays[i].shape().Size())
+      << "aux " << aux_names[i]
+      << " shape has been changed, only allow to change the shape of input data.";
+  }
+
+  // All checks passed: build the new predictor's arrays.
+  ret->ctx = p->ctx;
+  ret->arg_arrays = p->arg_arrays;
+  for (size_t i = 0; i < arg_names.size(); ++i) {
+    if (new_shape.count(arg_names[i]) != 0) {
+      ret->arg_arrays[i].ReshapeAndAlloc(arg_shapes[i]);
+    }
+  }
+  ret->aux_arrays = p->aux_arrays;
+
+  // bind
+  {
+    std::map<std::string, Context> ctx_map;
+    std::vector<NDArray> grad_store;
+    grad_store.reserve(ret->arg_arrays.size());
+    std::vector<OpReqType> grad_req(ret->arg_arrays.size(), kNullOp);
+
+    ret->exec.reset(Executor::Bind(ret->sym, ret->ctx, ctx_map,
+                                   ret->arg_arrays,
+                                   grad_store, grad_req,
+                                   ret->aux_arrays,
+                                   p->exec.get()));
+    ret->out_shapes = std::move(out_shapes);
+    ret->out_arrays = ret->exec->outputs();
+  }
+  // The arrays now belong to the new predictor; release them from the original.
+  p->arg_arrays.clear();
+  p->aux_arrays.clear();
+  *out = ret.release();
+  API_END();
+}
+
 int MXPredGetOutputShape(PredictorHandle handle,
                          mx_uint out_index,
                          mx_uint** shape_data,


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services