You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by an...@apache.org on 2019/06/14 18:31:33 UTC

[incubator-mxnet] branch master updated: [MXNET-1415]Add MXEnginePushAsyncND and MXEnginePushSyncND C APIs (#15177)

This is an automated email from the ASF dual-hosted git repository.

anirudh2290 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 3b663ef  [MXNET-1415]Add MXEnginePushAsyncND and MXEnginePushSyncND C APIs (#15177)
3b663ef is described below

commit 3b663ef6cc9bc6992d769b8ae7fa8e72e4f3201b
Author: JackieWu <wk...@live.cn>
AuthorDate: Sat Jun 15 02:31:11 2019 +0800

    [MXNET-1415]Add MXEnginePushAsyncND and MXEnginePushSyncND C APIs (#15177)
    
    * add MXEnginePushAsyncND and MXEnginePushSyncND
    
    * fix test build
    
    * return exception value
    
    * retrigger CI
---
 include/mxnet/c_api.h                    | 55 +++++++++++++++++++++++++++++---
 src/c_api/c_api.cc                       | 40 +++++++++++++++++++++++
 tests/cpp/engine/threaded_engine_test.cc | 48 ++++++++++++++++++++++++++++
 3 files changed, 139 insertions(+), 4 deletions(-)

diff --git a/include/mxnet/c_api.h b/include/mxnet/c_api.h
index 2d5122c..b3dca69 100644
--- a/include/mxnet/c_api.h
+++ b/include/mxnet/c_api.h
@@ -2792,9 +2792,9 @@ MXNET_DLL int MXNDArrayCreateFromSharedMemEx(int shared_pid, int shared_id, cons
   * \param ctx_handle Execution context.
   * \param const_vars_handle The variables that current operation will use
   *                          but not mutate.
-  * \param num_const_vars The number of const_vars.
+  * \param num_const_vars The number of const_vars_handle.
   * \param mutable_vars_handle The variables that current operation will mutate.
-  * \param num_mutable_vars The number of mutable_vars.
+  * \param num_mutable_vars The number of mutable_vars_handle.
   * \param prop_handle Property of the function.
   * \param priority Priority of the action, as hint to the engine.
   * \param opr_name The operation name.
@@ -2816,9 +2816,9 @@ MXNET_DLL int MXEnginePushAsync(EngineAsyncFunc async_func, void* func_param,
   * \param ctx_handle Execution context.
   * \param const_vars_handle The variables that current operation will use
   *                          but not mutate.
-  * \param num_const_vars The number of const_vars.
+  * \param num_const_vars The number of const_vars_handle.
   * \param mutable_vars_handle The variables that current operation will mutate.
-  * \param num_mutable_vars The number of mutable_vars.
+  * \param num_mutable_vars The number of mutable_vars_handle.
   * \param prop_handle Property of the function.
   * \param priority Priority of the action, as hint to the engine.
   * \param opr_name The operation name.
@@ -2830,6 +2830,53 @@ MXNET_DLL int MXEnginePushSync(EngineSyncFunc sync_func, void* func_param,
                                EngineFnPropertyHandle prop_handle DEFAULT(NULL),
                                int priority DEFAULT(0), const char* opr_name DEFAULT(NULL));
 
+/*!
+  * \brief Push an asynchronous operation to the engine, with NDArray dependencies.
+  * \param async_func Execution function which takes a parameter on_complete
+  *                   that must be called when the execution completes.
+  * \param func_param The parameter set on calling async_func, can be NULL.
+  * \param deleter The callback to free func_param, can be NULL.
+  * \param ctx_handle Execution context.
+  * \param const_nds_handle The consecutive NDArray objects the operation will use
+  *                         but not mutate.
+  * \param num_const_nds The number of NDArrays in const_nds_handle.
+  * \param mutable_nds_handle The consecutive NDArray objects the operation will mutate.
+  * \param num_mutable_nds The number of NDArrays in mutable_nds_handle.
+  * \param prop_handle Property of the function.
+  * \param priority Priority of the action, as hint to the engine.
+  * \param opr_name The operation name.
+  * \param wait Whether this is a WaitForVar operation.
+  */
+MXNET_DLL int MXEnginePushAsyncND(EngineAsyncFunc async_func, void* func_param,
+                                EngineFuncParamDeleter deleter, ContextHandle ctx_handle,
+                                NDArrayHandle const_nds_handle, int num_const_nds,
+                                NDArrayHandle mutable_nds_handle, int num_mutable_nds,
+                                EngineFnPropertyHandle prop_handle DEFAULT(NULL),
+                                int priority DEFAULT(0), const char* opr_name DEFAULT(NULL),
+                                bool wait DEFAULT(false));
+
+/*!
+  * \brief Push a synchronous operation to the engine, with NDArray dependencies.
+  * \param sync_func Execution function that executes the operation.
+  * \param func_param The parameter set on calling sync_func, can be NULL.
+  * \param deleter The callback to free func_param, can be NULL.
+  * \param ctx_handle Execution context.
+  * \param const_nds_handle The consecutive NDArray objects the operation will use
+  *                         but not mutate.
+  * \param num_const_nds The number of NDArrays in const_nds_handle.
+  * \param mutable_nds_handle The consecutive NDArray objects the operation will mutate.
+  * \param num_mutable_nds The number of NDArrays in mutable_nds_handle.
+  * \param prop_handle Property of the function.
+  * \param priority Priority of the action, as hint to the engine.
+  * \param opr_name The operation name.
+  */
+MXNET_DLL int MXEnginePushSyncND(EngineSyncFunc sync_func, void* func_param,
+                               EngineFuncParamDeleter deleter, ContextHandle ctx_handle,
+                               NDArrayHandle const_nds_handle, int num_const_nds,
+                               NDArrayHandle mutable_nds_handle, int num_mutable_nds,
+                               EngineFnPropertyHandle prop_handle DEFAULT(NULL),
+                               int priority DEFAULT(0), const char* opr_name DEFAULT(NULL));
+
 #ifdef __cplusplus
 }
 #endif  // __cplusplus
diff --git a/src/c_api/c_api.cc b/src/c_api/c_api.cc
index a9a49b0..35bd3ee 100644
--- a/src/c_api/c_api.cc
+++ b/src/c_api/c_api.cc
@@ -1534,6 +1534,71 @@ int MXEnginePushSync(EngineSyncFunc sync_func, void* func_param,
   API_END();
 }

+/*!
+ * \brief Collect the engine variables of the NDArrays passed to MXEnginePush*ND.
+ * \param nds_handle Read as a pointer to `num` consecutive NDArray objects
+ *        (NOT an array of NDArrayHandle); may be NULL only when `num` is 0.
+ * \param num Number of NDArrays; must be non-negative.
+ * \param what Human-readable description used in error messages.
+ * \return The NDArrays' engine variables, in the same order.
+ */
+static std::vector<VarHandle> NDArraysToVarHandles(NDArrayHandle nds_handle, int num,
+                                                   const char* what) {
+  CHECK_GE(num, 0) << "Invalid number of " << what << ": " << num;
+  CHECK(num == 0 || nds_handle != nullptr) << "NULL handle for " << num << " " << what;
+  NDArray* nds = static_cast<NDArray*>(nds_handle);
+  std::vector<VarHandle> vars;
+  vars.reserve(num);
+  for (int i = 0; i < num; ++i) vars.push_back(nds[i].var());
+  return vars;
+}
+
+/*!
+ * \brief NDArray-based variant of MXEnginePushAsync: the dependencies are taken
+ *        from the engine variables of the given NDArrays.
+ */
+int MXEnginePushAsyncND(EngineAsyncFunc async_func, void* func_param,
+                      EngineFuncParamDeleter deleter, ContextHandle ctx_handle,
+                      NDArrayHandle const_nds_handle, int num_const_nds,
+                      NDArrayHandle mutable_nds_handle, int num_mutable_nds,
+                      EngineFnPropertyHandle prop_handle, int priority,
+                      const char* opr_name, bool wait) {
+  API_BEGIN();
+  std::vector<VarHandle> const_var_vec =
+      NDArraysToVarHandles(const_nds_handle, num_const_nds, "const NDArrays");
+  std::vector<VarHandle> mutable_var_vec =
+      NDArraysToVarHandles(mutable_nds_handle, num_mutable_nds, "mutable NDArrays");
+  // Return the status of the VarHandle-based API unchanged.
+  return MXEnginePushAsync(async_func, func_param, deleter, ctx_handle,
+                           const_var_vec.data(), num_const_nds,
+                           mutable_var_vec.data(), num_mutable_nds,
+                           prop_handle, priority, opr_name, wait);
+  API_END();
+}
+
+/*!
+ * \brief NDArray-based variant of MXEnginePushSync: the dependencies are taken
+ *        from the engine variables of the given NDArrays.
+ */
+int MXEnginePushSyncND(EngineSyncFunc sync_func, void* func_param,
+                     EngineFuncParamDeleter deleter, ContextHandle ctx_handle,
+                     NDArrayHandle const_nds_handle, int num_const_nds,
+                     NDArrayHandle mutable_nds_handle, int num_mutable_nds,
+                     EngineFnPropertyHandle prop_handle, int priority,
+                     const char* opr_name) {
+  API_BEGIN();
+  std::vector<VarHandle> const_var_vec =
+      NDArraysToVarHandles(const_nds_handle, num_const_nds, "const NDArrays");
+  std::vector<VarHandle> mutable_var_vec =
+      NDArraysToVarHandles(mutable_nds_handle, num_mutable_nds, "mutable NDArrays");
+  // Return the status of the VarHandle-based API unchanged.
+  return MXEnginePushSync(sync_func, func_param, deleter, ctx_handle,
+                          const_var_vec.data(), num_const_nds,
+                          mutable_var_vec.data(), num_mutable_nds,
+                          prop_handle, priority, opr_name);
+  API_END();
+}
+
 int MXStorageEmptyCache(int dev_type, int dev_id) {
   API_BEGIN();
   Context ctx = Context::Create(static_cast<Context::DeviceType>(dev_type), dev_id);
diff --git a/tests/cpp/engine/threaded_engine_test.cc b/tests/cpp/engine/threaded_engine_test.cc
index ef3aec1..6b863f8 100644
--- a/tests/cpp/engine/threaded_engine_test.cc
+++ b/tests/cpp/engine/threaded_engine_test.cc
@@ -29,6 +29,7 @@
 #include <gtest/gtest.h>
 #include <mxnet/c_api.h>
 #include <mxnet/engine.h>
+#include <mxnet/ndarray.h>
 #include <dmlc/timer.h>
 #include <cstdio>
 #include <thread>
@@ -254,6 +255,53 @@ TEST(Engine, PushFunc) {
   EXPECT_EQ(res, -1);
 }
 
+TEST(Engine, PushFuncND) {
+  auto ctx = mxnet::Context{};
+  mxnet::NDArray nd(ctx);
+  // Each case below checks only the return code of the ND push C APIs (0 ok, -1 error).
+  // Test #1: async push, `nd` as the single const dependency, heap param freed by the deleter.
+  LOG(INFO) << "===== Test #1: PushAsyncND param and deleter =====";
+  int* a = new int(100);
+  int res = MXEnginePushAsyncND(FooAsyncFunc, a, FooFuncDeleter, &ctx, &nd, 1, nullptr, 0);
+  EXPECT_EQ(res, 0);
+
+  // Test #2: async push without param/deleter; mutable pointer given but its count is 0.
+  LOG(INFO) << "===== Test #2: PushAsyncND NULL param and NULL deleter =====";
+  res = MXEnginePushAsyncND(FooAsyncFunc, nullptr, nullptr, &ctx, nullptr, 0, &nd, 0);
+  EXPECT_EQ(res, 0);
+
+  // Test #3: a negative const count must make the async push fail with -1.
+  LOG(INFO) << "===== Test #3: PushAsyncND invalid number of const nds =====";
+  res = MXEnginePushAsyncND(FooAsyncFunc, nullptr, nullptr, &ctx, &nd, -1, nullptr, 0);
+  EXPECT_EQ(res, -1);
+
+  // Test #4: a negative mutable count must make the async push fail with -1.
+  LOG(INFO) << "===== Test #4: PushAsyncND invalid number of mutable nds =====";
+  res = MXEnginePushAsyncND(FooAsyncFunc, nullptr, nullptr, &ctx, nullptr, 0, &nd, -1);
+  EXPECT_EQ(res, -1);
+
+  // Test #5: sync push, `nd` as the single const dependency, heap param freed by the deleter.
+  LOG(INFO) << "===== Test #5: PushSyncND param and deleter =====";
+  int* b = new int(101);
+  res = MXEnginePushSyncND(FooSyncFunc, b, FooFuncDeleter, &ctx, &nd, 1, nullptr, 0);
+  EXPECT_EQ(res, 0);
+
+  // Test #6: sync push without param/deleter; `nd` as the single mutable dependency.
+  LOG(INFO) << "===== Test #6: PushSyncND NULL param and NULL deleter =====";
+  res = MXEnginePushSyncND(FooSyncFunc, nullptr, nullptr, &ctx, nullptr, 0, &nd, 1);
+  EXPECT_EQ(res, 0);
+
+  // Test #7: a negative const count must make the sync push fail with -1.
+  LOG(INFO) << "===== Test #7: PushSyncND invalid number of const nds =====";
+  res = MXEnginePushSyncND(FooSyncFunc, nullptr, nullptr, &ctx, &nd, -1, nullptr, 0);
+  EXPECT_EQ(res, -1);
+
+  // Test #8: a negative mutable count must make the sync push fail with -1.
+  LOG(INFO) << "===== Test #8: PushSyncND invalid number of mutable nds =====";
+  res = MXEnginePushSyncND(FooSyncFunc, nullptr, nullptr, &ctx, nullptr, 0, &nd, -1);
+  EXPECT_EQ(res, -1);
+}
+
 TEST(Engine, basics) {
   auto&& engine = mxnet::Engine::Get();
   auto&& var = engine->NewVariable();