You are viewing a plain-text version of this content; the hyperlink to the canonical version was not preserved in this copy.
Posted to commits@mxnet.apache.org by zh...@apache.org on 2019/03/07 04:53:44 UTC
[incubator-mxnet] branch master updated: print error message for mxnet::cpp::Operator::Invoke when failed (#14318)
This is an automated email from the ASF dual-hosted git repository.
zhasheng pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new 6caaa38 print error message for mxnet::cpp::Operator::Invoke when failed (#14318)
6caaa38 is described below
commit 6caaa38bdad5452e2872f382a7e61f173f9a0a6b
Author: JackieWu <wk...@live.cn>
AuthorDate: Thu Mar 7 12:53:27 2019 +0800
print error message for mxnet::cpp::Operator::Invoke when failed (#14318)
* raise exceptions for mxnet::cpp::Operator::Invoke when failed
* fix input shape
* fix cpplint
* fix cpp-package example shape
---
cpp-package/example/alexnet.cpp | 75 +++++++++++++++++++--------
cpp-package/example/inception_bn.cpp | 35 ++++++++++---
cpp-package/example/lenet_with_mxdataiter.cpp | 32 +++++++++---
cpp-package/example/resnet.cpp | 55 ++++++++++++++++----
cpp-package/include/mxnet-cpp/operator.hpp | 8 +--
5 files changed, 155 insertions(+), 50 deletions(-)
diff --git a/cpp-package/example/alexnet.cpp b/cpp-package/example/alexnet.cpp
index e2083a0..2b2d7b4 100644
--- a/cpp-package/example/alexnet.cpp
+++ b/cpp-package/example/alexnet.cpp
@@ -196,19 +196,39 @@ Symbol AlexnetSymbol(int num_classes) {
return softmax;
}
+NDArray ResizeInput(NDArray data, const Shape new_shape) {
+ NDArray pic = data.Reshape(Shape(0, 1, 28, 28));
+ NDArray pic_1channel;
+ Operator("_contrib_BilinearResize2D")
+ .SetParam("height", new_shape[2])
+ .SetParam("width", new_shape[3])
+ (pic).Invoke(pic_1channel);
+ NDArray output;
+ Operator("tile")
+ .SetParam("reps", Shape(1, 3, 1, 1))
+ (pic_1channel).Invoke(output);
+ return output;
+}
+
int main(int argc, char const *argv[]) {
/*basic config*/
- int batch_size = 256;
int max_epo = argc > 1 ? strtol(argv[1], NULL, 10) : 100;
float learning_rate = 1e-4;
float weight_decay = 1e-4;
- /*context and net symbol*/
- auto ctx = Context::gpu();
-#if MXNET_USE_CPU
- ctx = Context::cpu();
+ /*context*/
+ auto ctx = Context::cpu();
+ int num_gpu;
+ MXGetGPUCount(&num_gpu);
+ int batch_size = 32;
+#if !MXNET_USE_CPU
+ if (num_gpu > 0) {
+ ctx = Context::gpu();
+ batch_size = 256;
+ }
#endif
+ /*net symbol*/
auto Net = AlexnetSymbol(10);
/*args_map and aux_map is used for parameters' saving*/
@@ -216,8 +236,10 @@ int main(int argc, char const *argv[]) {
std::map<std::string, NDArray> aux_map;
/*we should tell mxnet the shape of data and label*/
- args_map["data"] = NDArray(Shape(batch_size, 3, 256, 256), ctx);
- args_map["label"] = NDArray(Shape(batch_size), ctx);
+ const Shape data_shape = Shape(batch_size, 3, 256, 256),
+ label_shape = Shape(batch_size);
+ args_map["data"] = NDArray(data_shape, ctx);
+ args_map["label"] = NDArray(label_shape, ctx);
/*with data and label, executor can be generated automatically*/
auto *exec = Net.SimpleBind(ctx, args_map);
@@ -261,17 +283,18 @@ int main(int argc, char const *argv[]) {
->SetParam("wd", weight_decay);
Accuracy acu_train, acu_val;
- LogLoss logloss_val;
- for (int iter = 0; iter < max_epo; ++iter) {
- LG << "Train Epoch: " << iter;
+ LogLoss logloss_train, logloss_val;
+ for (int epoch = 0; epoch < max_epo; ++epoch) {
+ LG << "Train Epoch: " << epoch;
/*reset the metric every epoch*/
acu_train.Reset();
/*reset the data iter every epoch*/
train_iter.Reset();
+ int iter = 0;
while (train_iter.Next()) {
auto batch = train_iter.GetDataBatch();
/*use copyto to feed new data and label to the executor*/
- batch.data.CopyTo(&args_map["data"]);
+ ResizeInput(batch.data, data_shape).CopyTo(&args_map["data"]);
batch.label.CopyTo(&args_map["label"]);
exec->Forward(true);
exec->Backward();
@@ -282,39 +305,47 @@ int main(int argc, char const *argv[]) {
NDArray::WaitAll();
acu_train.Update(batch.label, exec->outputs[0]);
+ logloss_train.Reset();
+ logloss_train.Update(batch.label, exec->outputs[0]);
+ ++iter;
+ LG << "EPOCH: " << epoch << " ITER: " << iter
+ << " Train Accuracy: " << acu_train.Get()
+ << " Train Loss: " << logloss_train.Get();
}
- LG << "ITER: " << iter << " Train Accuracy: " << acu_train.Get();
+ LG << "EPOCH: " << epoch << " Train Accuracy: " << acu_train.Get();
- LG << "Val Epoch: " << iter;
+ LG << "Val Epoch: " << epoch;
acu_val.Reset();
val_iter.Reset();
logloss_val.Reset();
+ iter = 0;
while (val_iter.Next()) {
auto batch = val_iter.GetDataBatch();
- LG << val_iter.GetDataBatch().index.size();
- batch.data.CopyTo(&args_map["data"]);
+ ResizeInput(batch.data, data_shape).CopyTo(&args_map["data"]);
batch.label.CopyTo(&args_map["label"]);
exec->Forward(false);
NDArray::WaitAll();
acu_val.Update(batch.label, exec->outputs[0]);
logloss_val.Update(batch.label, exec->outputs[0]);
+ LG << "EPOCH: " << epoch << " ITER: " << iter << " Val Accuracy: " << acu_val.Get();
+ ++iter;
}
- LG << "ITER: " << iter << " Val Accuracy: " << acu_val.Get();
- LG << "ITER: " << iter << " Val LogLoss: " << logloss_val.Get();
+ LG << "EPOCH: " << epoch << " Val Accuracy: " << acu_val.Get();
+ LG << "EPOCH: " << epoch << " Val LogLoss: " << logloss_val.Get();
/*save the parameters*/
std::stringstream ss;
- ss << iter;
- std::string iter_str;
- ss >> iter_str;
- std::string save_path_param = "alex_param_" + iter_str;
+ ss << epoch;
+ std::string epoch_str;
+ ss >> epoch_str;
+ std::string save_path_param = "alex_param_" + epoch_str;
auto save_args = args_map;
/*we do not want to save the data and label*/
save_args.erase(save_args.find("data"));
save_args.erase(save_args.find("label"));
/*the alexnet does not get any aux array, so we do not need to save
* aux_map*/
- LG << "ITER: " << iter << " Saving to..." << save_path_param;
+ LG << "EPOCH: " << epoch << " Saving to..." << save_path_param;
NDArray::Save(save_path_param, save_args);
}
/*don't foget to release the executor*/
diff --git a/cpp-package/example/inception_bn.cpp b/cpp-package/example/inception_bn.cpp
index 2073ebe..a29ef2d 100644
--- a/cpp-package/example/inception_bn.cpp
+++ b/cpp-package/example/inception_bn.cpp
@@ -142,23 +142,44 @@ Symbol InceptionSymbol(int num_classes) {
return SoftmaxOutput("softmax", fc1, data_label);
}
+NDArray ResizeInput(NDArray data, const Shape new_shape) {
+ NDArray pic = data.Reshape(Shape(0, 1, 28, 28));
+ NDArray pic_1channel;
+ Operator("_contrib_BilinearResize2D")
+ .SetParam("height", new_shape[2])
+ .SetParam("width", new_shape[3])
+ (pic).Invoke(pic_1channel);
+ NDArray output;
+ Operator("tile")
+ .SetParam("reps", Shape(1, 3, 1, 1))
+ (pic_1channel).Invoke(output);
+ return output;
+}
+
int main(int argc, char const *argv[]) {
int batch_size = 40;
int max_epoch = argc > 1 ? strtol(argv[1], NULL, 10) : 100;
float learning_rate = 1e-2;
float weight_decay = 1e-4;
- auto ctx = Context::gpu();
-#if MXNET_USE_CPU
- ctx = Context::cpu();
+ /*context*/
+ auto ctx = Context::cpu();
+ int num_gpu;
+ MXGetGPUCount(&num_gpu);
+#if !MXNET_USE_CPU
+ if (num_gpu > 0) {
+ ctx = Context::gpu();
+ }
#endif
auto inception_bn_net = InceptionSymbol(10);
std::map<std::string, NDArray> args_map;
std::map<std::string, NDArray> aux_map;
- args_map["data"] = NDArray(Shape(batch_size, 3, 224, 224), ctx);
- args_map["data_label"] = NDArray(Shape(batch_size), ctx);
+ const Shape data_shape = Shape(batch_size, 3, 224, 224),
+ label_shape = Shape(batch_size);
+ args_map["data"] = NDArray(data_shape, ctx);
+ args_map["data_label"] = NDArray(label_shape, ctx);
inception_bn_net.InferArgsMap(ctx, &args_map, args_map);
std::vector<std::string> data_files = { "./data/mnist_data/train-images-idx3-ubyte",
@@ -201,7 +222,7 @@ int main(int argc, char const *argv[]) {
train_acc.Reset();
while (train_iter.Next()) {
auto data_batch = train_iter.GetDataBatch();
- data_batch.data.CopyTo(&args_map["data"]);
+ ResizeInput(data_batch.data, data_shape).CopyTo(&args_map["data"]);
data_batch.label.CopyTo(&args_map["data_label"]);
NDArray::WaitAll();
@@ -221,7 +242,7 @@ int main(int argc, char const *argv[]) {
val_acc.Reset();
while (val_iter.Next()) {
auto data_batch = val_iter.GetDataBatch();
- data_batch.data.CopyTo(&args_map["data"]);
+ ResizeInput(data_batch.data, data_shape).CopyTo(&args_map["data"]);
data_batch.label.CopyTo(&args_map["data_label"]);
NDArray::WaitAll();
exec->Forward(false);
diff --git a/cpp-package/example/lenet_with_mxdataiter.cpp b/cpp-package/example/lenet_with_mxdataiter.cpp
index 33110fe..fac624b 100644
--- a/cpp-package/example/lenet_with_mxdataiter.cpp
+++ b/cpp-package/example/lenet_with_mxdataiter.cpp
@@ -66,6 +66,16 @@ Symbol LenetSymbol() {
return lenet;
}
+NDArray ResizeInput(NDArray data, const Shape new_shape) {
+ NDArray pic = data.Reshape(Shape(0, 1, 28, 28));
+ NDArray output;
+ Operator("_contrib_BilinearResize2D")
+ .SetParam("height", new_shape[2])
+ .SetParam("width", new_shape[3])
+ (pic).Invoke(output);
+ return output;
+}
+
int main(int argc, char const *argv[]) {
/*setup basic configs*/
int W = 28;
@@ -74,15 +84,23 @@ int main(int argc, char const *argv[]) {
int max_epoch = argc > 1 ? strtol(argv[1], NULL, 10) : 100;
float learning_rate = 1e-4;
float weight_decay = 1e-4;
- auto dev_ctx = Context::gpu();
-#if MXNET_USE_CPU
- dev_ctx = Context::cpu();
+
+ auto dev_ctx = Context::cpu();
+ int num_gpu;
+ MXGetGPUCount(&num_gpu);
+#if !MXNET_USE_CPU
+ if (num_gpu > 0) {
+ dev_ctx = Context::gpu();
+ }
#endif
+
auto lenet = LenetSymbol();
std::map<std::string, NDArray> args_map;
- args_map["data"] = NDArray(Shape(batch_size, 1, W, H), dev_ctx);
- args_map["data_label"] = NDArray(Shape(batch_size), dev_ctx);
+ const Shape data_shape = Shape(batch_size, 1, H, W),
+ label_shape = Shape(batch_size);
+ args_map["data"] = NDArray(data_shape, dev_ctx);
+ args_map["data_label"] = NDArray(label_shape, dev_ctx);
lenet.InferArgsMap(dev_ctx, &args_map, args_map);
args_map["fc1_w"] = NDArray(Shape(500, 4 * 4 * 50), dev_ctx);
@@ -131,7 +149,7 @@ int main(int argc, char const *argv[]) {
samples += batch_size;
auto data_batch = train_iter.GetDataBatch();
- data_batch.data.CopyTo(&args_map["data"]);
+ ResizeInput(data_batch.data, data_shape).CopyTo(&args_map["data"]);
data_batch.label.CopyTo(&args_map["data_label"]);
NDArray::WaitAll();
@@ -163,7 +181,7 @@ int main(int argc, char const *argv[]) {
val_iter.Reset();
while (val_iter.Next()) {
auto data_batch = val_iter.GetDataBatch();
- data_batch.data.CopyTo(&args_map["data"]);
+ ResizeInput(data_batch.data, data_shape).CopyTo(&args_map["data"]);
data_batch.label.CopyTo(&args_map["data_label"]);
NDArray::WaitAll();
diff --git a/cpp-package/example/resnet.cpp b/cpp-package/example/resnet.cpp
index 7200bd4..29071bd 100644
--- a/cpp-package/example/resnet.cpp
+++ b/cpp-package/example/resnet.cpp
@@ -153,8 +153,21 @@ Symbol ResNetSymbol(int num_class, int num_level = 3, int num_block = 9,
return SoftmaxOutput("softmax", fc, data_label);
}
+NDArray ResizeInput(NDArray data, const Shape new_shape) {
+ NDArray pic = data.Reshape(Shape(0, 1, 28, 28));
+ NDArray pic_1channel;
+ Operator("_contrib_BilinearResize2D")
+ .SetParam("height", new_shape[2])
+ .SetParam("width", new_shape[3])
+ (pic).Invoke(pic_1channel);
+ NDArray output;
+ Operator("tile")
+ .SetParam("reps", Shape(1, 3, 1, 1))
+ (pic_1channel).Invoke(output);
+ return output;
+}
+
int main(int argc, char const *argv[]) {
- int batch_size = 50;
int max_epoch = argc > 1 ? strtol(argv[1], NULL, 10) : 100;
float learning_rate = 1e-4;
float weight_decay = 1e-4;
@@ -163,13 +176,22 @@ int main(int argc, char const *argv[]) {
std::map<std::string, NDArray> args_map;
std::map<std::string, NDArray> aux_map;
- auto ctx = Context::gpu();
-#if MXNET_USE_CPU
- ctx = Context::cpu();;
+ /*context*/
+ auto ctx = Context::cpu();
+ int num_gpu;
+ MXGetGPUCount(&num_gpu);
+ int batch_size = 8;
+#if !MXNET_USE_CPU
+ if (num_gpu > 0) {
+ ctx = Context::gpu();
+ batch_size = 50;
+ }
#endif
- args_map["data"] = NDArray(Shape(batch_size, 3, 256, 256), ctx);
- args_map["data_label"] = NDArray(Shape(batch_size), ctx);
+ const Shape data_shape = Shape(batch_size, 3, 224, 224),
+ label_shape = Shape(batch_size);
+ args_map["data"] = NDArray(data_shape, ctx);
+ args_map["data_label"] = NDArray(label_shape, ctx);
resnet.InferArgsMap(ctx, &args_map, args_map);
std::vector<std::string> data_files = { "./data/mnist_data/train-images-idx3-ubyte",
@@ -206,13 +228,15 @@ int main(int argc, char const *argv[]) {
// Create metrics
Accuracy train_acc, val_acc;
- for (int iter = 0; iter < max_epoch; ++iter) {
- LG << "Epoch: " << iter;
+ LogLoss logloss_train, logloss_val;
+ for (int epoch = 0; epoch < max_epoch; ++epoch) {
+ LG << "Epoch: " << epoch;
train_iter.Reset();
train_acc.Reset();
+ int iter = 0;
while (train_iter.Next()) {
auto data_batch = train_iter.GetDataBatch();
- data_batch.data.CopyTo(&args_map["data"]);
+ ResizeInput(data_batch.data, data_shape).CopyTo(&args_map["data"]);
data_batch.label.CopyTo(&args_map["data_label"]);
NDArray::WaitAll();
@@ -225,20 +249,29 @@ int main(int argc, char const *argv[]) {
}
NDArray::WaitAll();
train_acc.Update(data_batch.label, exec->outputs[0]);
+ logloss_train.Reset();
+ logloss_train.Update(data_batch.label, exec->outputs[0]);
+ ++iter;
+ LG << "EPOCH: " << epoch << " ITER: " << iter
+ << " Train Accuracy: " << train_acc.Get()
+ << " Train Loss: " << logloss_train.Get();
}
+ LG << "EPOCH: " << epoch << " Train Accuracy: " << train_acc.Get();
val_iter.Reset();
val_acc.Reset();
+ iter = 0;
while (val_iter.Next()) {
auto data_batch = val_iter.GetDataBatch();
- data_batch.data.CopyTo(&args_map["data"]);
+ ResizeInput(data_batch.data, data_shape).CopyTo(&args_map["data"]);
data_batch.label.CopyTo(&args_map["data_label"]);
NDArray::WaitAll();
exec->Forward(false);
NDArray::WaitAll();
val_acc.Update(data_batch.label, exec->outputs[0]);
+ LG << "EPOCH: " << epoch << " ITER: " << iter << " Val Accuracy: " << val_acc.Get();
+ ++iter;
}
- LG << "Train Accuracy: " << train_acc.Get();
LG << "Validation Accuracy: " << val_acc.Get();
}
delete exec;
diff --git a/cpp-package/include/mxnet-cpp/operator.hpp b/cpp-package/include/mxnet-cpp/operator.hpp
index edc396f..8cdd78d 100644
--- a/cpp-package/include/mxnet-cpp/operator.hpp
+++ b/cpp-package/include/mxnet-cpp/operator.hpp
@@ -134,9 +134,11 @@ inline void Operator::Invoke(std::vector<NDArray> &outputs) {
outputs_receiver = output_handles.data();
}
- MXImperativeInvoke(handle_, num_inputs, input_ndarrays_.data(),
- &num_outputs, &outputs_receiver,
- param_keys.size(), param_keys.data(), param_values.data());
+ if (MXImperativeInvoke(handle_, num_inputs, input_ndarrays_.data(),
+ &num_outputs, &outputs_receiver,
+ param_keys.size(), param_keys.data(),
+ param_values.data()))
+ LOG(FATAL) << MXGetLastError();
if (outputs.size() > 0)
return;