You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by an...@apache.org on 2019/06/14 18:31:33 UTC
[incubator-mxnet] branch master updated: [MXNET-1415]Add
MXEnginePushAsyncND and MXEnginePushSyncND C APIs (#15177)
This is an automated email from the ASF dual-hosted git repository.
anirudh2290 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new 3b663ef [MXNET-1415]Add MXEnginePushAsyncND and MXEnginePushSyncND C APIs (#15177)
3b663ef is described below
commit 3b663ef6cc9bc6992d769b8ae7fa8e72e4f3201b
Author: JackieWu <wk...@live.cn>
AuthorDate: Sat Jun 15 02:31:11 2019 +0800
[MXNET-1415]Add MXEnginePushAsyncND and MXEnginePushSyncND C APIs (#15177)
* add MXEnginePushAsyncND and MXEnginePushSyncND
* fix test build
* return exception value
* retrigger CI
---
include/mxnet/c_api.h | 55 +++++++++++++++++++++++++++++---
src/c_api/c_api.cc | 40 +++++++++++++++++++++++
tests/cpp/engine/threaded_engine_test.cc | 48 ++++++++++++++++++++++++++++
3 files changed, 139 insertions(+), 4 deletions(-)
diff --git a/include/mxnet/c_api.h b/include/mxnet/c_api.h
index 2d5122c..b3dca69 100644
--- a/include/mxnet/c_api.h
+++ b/include/mxnet/c_api.h
@@ -2792,9 +2792,9 @@ MXNET_DLL int MXNDArrayCreateFromSharedMemEx(int shared_pid, int shared_id, cons
* \param ctx_handle Execution context.
* \param const_vars_handle The variables that current operation will use
* but not mutate.
- * \param num_const_vars The number of const_vars.
+ * \param num_const_vars The number of const_vars_handle.
* \param mutable_vars_handle The variables that current operation will mutate.
- * \param num_mutable_vars The number of mutable_vars.
+ * \param num_mutable_vars The number of mutable_vars_handle.
* \param prop_handle Property of the function.
* \param priority Priority of the action, as hint to the engine.
* \param opr_name The operation name.
@@ -2816,9 +2816,9 @@ MXNET_DLL int MXEnginePushAsync(EngineAsyncFunc async_func, void* func_param,
* \param ctx_handle Execution context.
* \param const_vars_handle The variables that current operation will use
* but not mutate.
- * \param num_const_vars The number of const_vars.
+ * \param num_const_vars The number of const_vars_handle.
* \param mutable_vars_handle The variables that current operation will mutate.
- * \param num_mutable_vars The number of mutable_vars.
+ * \param num_mutable_vars The number of mutable_vars_handle.
* \param prop_handle Property of the function.
* \param priority Priority of the action, as hint to the engine.
* \param opr_name The operation name.
@@ -2830,6 +2830,53 @@ MXNET_DLL int MXEnginePushSync(EngineSyncFunc sync_func, void* func_param,
EngineFnPropertyHandle prop_handle DEFAULT(NULL),
int priority DEFAULT(0), const char* opr_name DEFAULT(NULL));
+/*!
+ * \brief Push an asynchronous operation to the engine, expressing its read and
+ *        write dependencies as NDArrays instead of raw engine variables.
+ *        The engine variable of each NDArray is collected and the call is
+ *        forwarded to MXEnginePushAsync.
+ * \param async_func Execution function which takes a parameter on_complete
+ * that must be called when the execution completes.
+ * \param func_param The parameter set on calling async_func, can be NULL.
+ * \param deleter The callback to free func_param, can be NULL.
+ * \param ctx_handle Execution context.
+ * \param const_nds_handle The NDArrays that current operation will use
+ * but not mutate. NOTE: the implementation reads this handle as a pointer to
+ * the first of num_const_nds contiguous NDArray objects (an NDArray*), not as
+ * an array of NDArrayHandle. May be NULL when num_const_nds is 0.
+ * \param num_const_nds The number of NDArrays in const_nds_handle; a negative
+ * value makes the call fail.
+ * \param mutable_nds_handle The NDArrays that current operation will mutate.
+ * Same layout rules as const_nds_handle; may be NULL when num_mutable_nds is 0.
+ * \param num_mutable_nds The number of NDArrays in mutable_nds_handle; a
+ * negative value makes the call fail.
+ * \param prop_handle Property of the function.
+ * \param priority Priority of the action, as hint to the engine.
+ * \param opr_name The operation name.
+ * \param wait Whether this is a WaitForVar operation.
+ * \return 0 when success, -1 when failure happens.
+ */
+MXNET_DLL int MXEnginePushAsyncND(EngineAsyncFunc async_func, void* func_param,
+ EngineFuncParamDeleter deleter, ContextHandle ctx_handle,
+ NDArrayHandle const_nds_handle, int num_const_nds,
+ NDArrayHandle mutable_nds_handle, int num_mutable_nds,
+ EngineFnPropertyHandle prop_handle DEFAULT(NULL),
+ int priority DEFAULT(0), const char* opr_name DEFAULT(NULL),
+ bool wait DEFAULT(false));
+
+/*!
+ * \brief Push a synchronous operation to the engine, expressing its read and
+ *        write dependencies as NDArrays instead of raw engine variables.
+ *        The engine variable of each NDArray is collected and the call is
+ *        forwarded to MXEnginePushSync.
+ * \param sync_func Execution function that executes the operation.
+ * \param func_param The parameter set on calling sync_func, can be NULL.
+ * \param deleter The callback to free func_param, can be NULL.
+ * \param ctx_handle Execution context.
+ * \param const_nds_handle The NDArrays that current operation will use
+ * but not mutate. NOTE: the implementation reads this handle as a pointer to
+ * the first of num_const_nds contiguous NDArray objects (an NDArray*), not as
+ * an array of NDArrayHandle. May be NULL when num_const_nds is 0.
+ * \param num_const_nds The number of NDArrays in const_nds_handle; a negative
+ * value makes the call fail.
+ * \param mutable_nds_handle The NDArrays that current operation will mutate.
+ * Same layout rules as const_nds_handle; may be NULL when num_mutable_nds is 0.
+ * \param num_mutable_nds The number of NDArrays in mutable_nds_handle; a
+ * negative value makes the call fail.
+ * \param prop_handle Property of the function.
+ * \param priority Priority of the action, as hint to the engine.
+ * \param opr_name The operation name.
+ * \return 0 when success, -1 when failure happens.
+ */
+MXNET_DLL int MXEnginePushSyncND(EngineSyncFunc sync_func, void* func_param,
+ EngineFuncParamDeleter deleter, ContextHandle ctx_handle,
+ NDArrayHandle const_nds_handle, int num_const_nds,
+ NDArrayHandle mutable_nds_handle, int num_mutable_nds,
+ EngineFnPropertyHandle prop_handle DEFAULT(NULL),
+ int priority DEFAULT(0), const char* opr_name DEFAULT(NULL));
+
#ifdef __cplusplus
}
#endif // __cplusplus
diff --git a/src/c_api/c_api.cc b/src/c_api/c_api.cc
index a9a49b0..35bd3ee 100644
--- a/src/c_api/c_api.cc
+++ b/src/c_api/c_api.cc
@@ -1534,6 +1534,46 @@ int MXEnginePushSync(EngineSyncFunc sync_func, void* func_param,
API_END();
}
+/*!
+ * \brief Collect the engine variables of a contiguous run of NDArrays.
+ * \param nds_handle Pointer to the first of num_nds contiguous NDArray objects
+ *        (read as an NDArray*, not as an array of NDArrayHandle); may be NULL
+ *        when num_nds is 0.
+ * \param num_nds Number of NDArrays to read; must be non-negative.
+ * \param kind Label used in error messages ("const" or "mutable").
+ * \return The engine variable of each NDArray, in order.
+ * \note Raises a dmlc::Error (via CHECK) when num_nds is negative, or when
+ *       nds_handle is NULL while num_nds is positive.
+ */
+static std::vector<VarHandle> NDArrayHandlesToEngineVars(NDArrayHandle nds_handle,
+                                                         int num_nds,
+                                                         const char* kind) {
+  // Reject bad input with a descriptive dmlc::Error (turned into a -1 return
+  // by API_END). Otherwise std::vector would be asked for size_t(-1) elements
+  // and throw an opaque std::length_error, or a NULL handle would be read.
+  CHECK_GE(num_nds, 0) << "Invalid number of " << kind << " NDArrays: " << num_nds;
+  CHECK(num_nds == 0 || nds_handle != nullptr)
+      << "NULL " << kind << " NDArray handle with " << num_nds << " elements";
+  NDArray* nds = static_cast<NDArray*>(nds_handle);
+  std::vector<VarHandle> vars;
+  vars.reserve(num_nds);
+  for (int i = 0; i < num_nds; ++i) vars.push_back(nds[i].var());
+  return vars;
+}
+
+/*!
+ * \brief NDArray-based front end of MXEnginePushAsync: translates the const and
+ *        mutable NDArrays into engine variables and forwards every other
+ *        argument unchanged.
+ * \return The status of MXEnginePushAsync, or -1 if the arguments are invalid.
+ */
+int MXEnginePushAsyncND(EngineAsyncFunc async_func, void* func_param,
+                        EngineFuncParamDeleter deleter, ContextHandle ctx_handle,
+                        NDArrayHandle const_nds_handle, int num_const_nds,
+                        NDArrayHandle mutable_nds_handle, int num_mutable_nds,
+                        EngineFnPropertyHandle prop_handle, int priority,
+                        const char* opr_name, bool wait) {
+  API_BEGIN();
+  std::vector<VarHandle> const_var_vec =
+      NDArrayHandlesToEngineVars(const_nds_handle, num_const_nds, "const");
+  std::vector<VarHandle> mutable_var_vec =
+      NDArrayHandlesToEngineVars(mutable_nds_handle, num_mutable_nds, "mutable");
+  // Return the inner call's status directly so a -1 from it is not masked by
+  // the unconditional "return 0" at the end of API_END.
+  return MXEnginePushAsync(async_func, func_param, deleter, ctx_handle,
+                           const_var_vec.data(), num_const_nds,
+                           mutable_var_vec.data(), num_mutable_nds,
+                           prop_handle, priority, opr_name, wait);
+  API_END();
+}
+
+/*!
+ * \brief NDArray-based front end of MXEnginePushSync: translates the const and
+ *        mutable NDArrays into engine variables and forwards every other
+ *        argument unchanged.
+ * \return The status of MXEnginePushSync, or -1 if the arguments are invalid.
+ */
+int MXEnginePushSyncND(EngineSyncFunc sync_func, void* func_param,
+                       EngineFuncParamDeleter deleter, ContextHandle ctx_handle,
+                       NDArrayHandle const_nds_handle, int num_const_nds,
+                       NDArrayHandle mutable_nds_handle, int num_mutable_nds,
+                       EngineFnPropertyHandle prop_handle, int priority,
+                       const char* opr_name) {
+  API_BEGIN();
+  std::vector<VarHandle> const_var_vec =
+      NDArrayHandlesToEngineVars(const_nds_handle, num_const_nds, "const");
+  std::vector<VarHandle> mutable_var_vec =
+      NDArrayHandlesToEngineVars(mutable_nds_handle, num_mutable_nds, "mutable");
+  // See MXEnginePushAsyncND: propagate the inner call's status.
+  return MXEnginePushSync(sync_func, func_param, deleter, ctx_handle,
+                          const_var_vec.data(), num_const_nds,
+                          mutable_var_vec.data(), num_mutable_nds,
+                          prop_handle, priority, opr_name);
+  API_END();
+}
+
int MXStorageEmptyCache(int dev_type, int dev_id) {
API_BEGIN();
Context ctx = Context::Create(static_cast<Context::DeviceType>(dev_type), dev_id);
diff --git a/tests/cpp/engine/threaded_engine_test.cc b/tests/cpp/engine/threaded_engine_test.cc
index ef3aec1..6b863f8 100644
--- a/tests/cpp/engine/threaded_engine_test.cc
+++ b/tests/cpp/engine/threaded_engine_test.cc
@@ -29,6 +29,7 @@
#include <gtest/gtest.h>
#include <mxnet/c_api.h>
#include <mxnet/engine.h>
+#include <mxnet/ndarray.h>
#include <dmlc/timer.h>
#include <cstdio>
#include <thread>
@@ -254,6 +255,53 @@ TEST(Engine, PushFunc) {
EXPECT_EQ(res, -1);
}
+TEST(Engine, PushFuncND) {
+  // A single NDArray whose engine variable acts as the const (read) or
+  // mutable (write) dependency of the pushed operations below.
+  mxnet::Context ctx{};
+  mxnet::NDArray nd(ctx);
+
+  // ---------------- MXEnginePushAsyncND ----------------
+  LOG(INFO) << "===== Test #1: PushAsyncND param and deleter =====";
+  // The heap-allocated param travels with FooFuncDeleter, the callback that
+  // is responsible for freeing it.
+  int* async_param = new int(100);
+  EXPECT_EQ(MXEnginePushAsyncND(FooAsyncFunc, async_param, FooFuncDeleter, &ctx,
+                                &nd, 1, nullptr, 0), 0);
+
+  LOG(INFO) << "===== Test #2: PushAsyncND NULL param and NULL deleter =====";
+  EXPECT_EQ(MXEnginePushAsyncND(FooAsyncFunc, nullptr, nullptr, &ctx,
+                                nullptr, 0, &nd, 0), 0);
+
+  // Negative NDArray counts must be reported as a failed call (-1).
+  LOG(INFO) << "===== Test #3: PushAsyncND invalid number of const nds =====";
+  EXPECT_EQ(MXEnginePushAsyncND(FooAsyncFunc, nullptr, nullptr, &ctx,
+                                &nd, -1, nullptr, 0), -1);
+
+  LOG(INFO) << "===== Test #4: PushAsyncND invalid number of mutable nds =====";
+  EXPECT_EQ(MXEnginePushAsyncND(FooAsyncFunc, nullptr, nullptr, &ctx,
+                                nullptr, 0, &nd, -1), -1);
+
+  // ---------------- MXEnginePushSyncND ----------------
+  LOG(INFO) << "===== Test #5: PushSyncND param and deleter =====";
+  int* sync_param = new int(101);
+  EXPECT_EQ(MXEnginePushSyncND(FooSyncFunc, sync_param, FooFuncDeleter, &ctx,
+                               &nd, 1, nullptr, 0), 0);
+
+  LOG(INFO) << "===== Test #6: PushSyncND NULL param and NULL deleter =====";
+  EXPECT_EQ(MXEnginePushSyncND(FooSyncFunc, nullptr, nullptr, &ctx,
+                               nullptr, 0, &nd, 1), 0);
+
+  // Negative NDArray counts must be reported as a failed call (-1).
+  LOG(INFO) << "===== Test #7: PushSyncND invalid number of const nds =====";
+  EXPECT_EQ(MXEnginePushSyncND(FooSyncFunc, nullptr, nullptr, &ctx,
+                               &nd, -1, nullptr, 0), -1);
+
+  LOG(INFO) << "===== Test #8: PushSyncND invalid number of mutable nds =====";
+  EXPECT_EQ(MXEnginePushSyncND(FooSyncFunc, nullptr, nullptr, &ctx,
+                               nullptr, 0, &nd, -1), -1);
+}
+
TEST(Engine, basics) {
auto&& engine = mxnet::Engine::Get();
auto&& var = engine->NewVariable();