Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2018/08/13 05:58:19 UTC

[GitHub] YouRancestor closed pull request #12136: accept GPU data as input

YouRancestor closed pull request #12136: accept GPU data as input
URL: https://github.com/apache/incubator-mxnet/pull/12136
 
 
   

This pull request comes from a forked repository. As GitHub does not keep
the original diff visible in this case, it is reproduced below for the
sake of provenance:

diff --git a/include/mxnet/c_predict_api.h b/include/mxnet/c_predict_api.h
index a77d77702fe..8af078b992d 100644
--- a/include/mxnet/c_predict_api.h
+++ b/include/mxnet/c_predict_api.h
@@ -166,6 +166,25 @@ MXNET_DLL int MXPredSetInput(PredictorHandle handle,
                              const char* key,
                              const mx_float* data,
                              mx_uint size);
+
+/*!
+ * \brief Set the input data of the predictor from memory that already resides on a GPU.
+ * \param handle The predictor handle.
+ * \param key The name of input node to set. For feedforward net, this is "data".
+ * \param gpu_data Pointer to the GPU-resident data; it must be in IEEE float format.
+ * \param size The size of the data array, counted in float elements, used for safety check.
+ * \param dev_id The id of the device on which @em gpu_data resides.
+ * \return 0 when success, -1 when failure.
+ * This function saves the data pointed to by @em gpu_data as a copy, so the caller keeps ownership of the buffer.
+ * It is better to allocate @em gpu_data on the same device as the one specified in @em MXPredCreate, to avoid an extra cross-device copy.
+ */
+MXNET_DLL int MXPredSetInputGPU(PredictorHandle handle,
+                                const char* key,
+                                const mx_float *gpu_data,
+                                mx_uint size,
+                                int dev_id);
+
+
 /*!
  * \brief Run a forward pass to get the output.
  * \param handle The handle of the predictor.
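
For illustration, a minimal caller-side sketch of how the new entry point could be
used once this patch is applied. Everything outside the diff is an assumption made
for the example: the helper name feed_gpu_input is made up, the predictor is assumed
to have been created by MXPredCreate with dev_type 2 (GPU) and dev_id 0 and a single
"data" input of n floats, the data is staged with plain CUDA runtime calls (in a real
application it would typically already live on the GPU), and the buffer is freed right
after the call on the assumption that the internal copy completes before the call
returns, as the copy semantics in the doc comment suggest:

    #include <stddef.h>
    #include <cuda_runtime.h>
    #include <mxnet/c_predict_api.h>

    /* Hypothetical helper: pred is a handle created by MXPredCreate with
     * dev_type 2 (GPU) and dev_id 0; host_input holds n floats. */
    int feed_gpu_input(PredictorHandle pred, const float* host_input, size_t n) {
      float* d_input = NULL;
      if (cudaMalloc((void**)&d_input, n * sizeof(float)) != cudaSuccess)
        return -1;
      cudaMemcpy(d_input, host_input, n * sizeof(float), cudaMemcpyHostToDevice);

      /* Hand the device pointer to the predictor; the data is saved as a copy,
       * so the buffer can be released once the call has returned. */
      int ret = MXPredSetInputGPU(pred, "data", d_input, (mx_uint)n, /*dev_id=*/0);
      cudaFree(d_input);
      return (ret == 0) ? MXPredForward(pred) : ret;
    }
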
diff --git a/src/c_api/c_predict_api.cc b/src/c_api/c_predict_api.cc
index becb0cb364f..ef589c1d0f8 100644
--- a/src/c_api/c_predict_api.cc
+++ b/src/c_api/c_predict_api.cc
@@ -372,6 +372,26 @@ int MXPredSetInput(PredictorHandle handle,
   API_END();
 }
 
+int MXPredSetInputGPU(PredictorHandle handle,
+                      const char* key,
+                      const mx_float* gpu_data,
+                      mx_uint size,
+                      int dev_id) {
+  MXAPIPredictor* p = static_cast<MXAPIPredictor*>(handle);
+  API_BEGIN();
+  auto it = p->key2arg.find(key);
+  if (it == p->key2arg.end()) {
+    LOG(FATAL) << "cannot find input key " << key;
+  }
+  NDArray& nd = p->arg_arrays[it->second];
+  TShape shape = nd.shape();
+  CHECK_EQ(shape.Size(), size) << "Input size mismatch.";
+  TBlob blob((void*)gpu_data, shape, gpu::kDevMask, mshadow::DataType<float>::kFlag, dev_id);  // NOLINT(*)
+  NDArray arr(blob, dev_id);
+  nd.SyncCopyFromNDArray(arr);
+  API_END();
+}
+
 int MXPredForward(PredictorHandle handle) {
   MXAPIPredictor* p = static_cast<MXAPIPredictor*>(handle);
   API_BEGIN();
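
One caller-facing detail of the implementation above: the missing-key and size checks
fail through LOG(FATAL) and CHECK_EQ inside API_BEGIN()/API_END(), so in a typical
MXNet build (where LOG(FATAL) throws) a failed call surfaces to C code as the
documented -1 return, with the message retrievable via MXGetLastError(). A brief
sketch, reusing pred, d_input and n from the earlier example and assuming <stdio.h>
is included:

    if (MXPredSetInputGPU(pred, "data", d_input, (mx_uint)n, 0) != 0) {
      /* e.g. "cannot find input key ..." or "Input size mismatch." */
      fprintf(stderr, "MXPredSetInputGPU failed: %s\n", MXGetLastError());
    }
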


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services