You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@singa.apache.org by ka...@apache.org on 2016/06/27 14:11:49 UTC

[2/6] incubator-singa git commit: SINGA-204 Support the training of feed-forward neural nets

SINGA-204 Support the training of feed-forward neural nets

Fix the pre-/post-increment bug in Block::IncRefCount()/DecRefCount(), which
resulted in out-of-memory errors. Cifar10 data loading now works.


Project: http://git-wip-us.apache.org/repos/asf/incubator-singa/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-singa/commit/cf1d8418
Tree: http://git-wip-us.apache.org/repos/asf/incubator-singa/tree/cf1d8418
Diff: http://git-wip-us.apache.org/repos/asf/incubator-singa/diff/cf1d8418

Branch: refs/heads/dev
Commit: cf1d841890842c6cf1573491f4fc9d7e1eca30f4
Parents: d826b2e
Author: Wei Wang <wa...@comp.nus.edu.sg>
Authored: Sat Jun 25 19:34:13 2016 +0800
Committer: Wei Wang <wa...@comp.nus.edu.sg>
Committed: Mon Jun 27 15:27:19 2016 +0800

----------------------------------------------------------------------
 examples/cifar10/alexnet.cc            | 105 +++++++++++------
 examples/cifar10/cifar10.cc            |  98 ----------------
 examples/cifar10/cifar10.h             |  99 ++++++++++++++++
 examples/cifar10/make.sh               |   1 +
 include/singa/core/common.h            |  16 +--
 include/singa/core/tensor.h            |   4 +-
 include/singa/model/feed_forward_net.h |  63 +++++-----
 include/singa/model/layer.h            |   2 +-
 include/singa/model/loss.h             |   4 +-
 include/singa/model/metric.h           |  23 ++++
 include/singa/model/optimizer.h        |   2 +-
 src/CMakeLists.txt                     |   3 +-
 src/core/device/cuda_gpu.cc            |   4 +-
 src/core/tensor/tensor.cc              |  14 ++-
 src/model/feed_forward_net.cc          | 176 ++++++++++++++--------------
 src/model/layer/convolution.cc         |   6 +-
 src/model/layer/convolution.h          |   1 +
 src/model/layer/cudnn_activation.cc    |   1 +
 src/model/layer/cudnn_convolution.cc   |   1 +
 src/model/layer/cudnn_dropout.cc       |   4 +
 src/model/layer/cudnn_dropout.h        |   1 +
 src/model/layer/cudnn_pooling.cc       |   1 +
 src/model/layer/cudnn_softmax.cc       |   1 +
 src/model/layer/dense.cc               |   3 +
 src/model/metric/accuracy.cc           |  62 ++++++++++
 src/model/metric/accuracy.h            |  84 -------------
 test/singa/test_accuracy.cc            |   2 +-
 27 files changed, 419 insertions(+), 362 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/cf1d8418/examples/cifar10/alexnet.cc
----------------------------------------------------------------------
diff --git a/examples/cifar10/alexnet.cc b/examples/cifar10/alexnet.cc
index 2917dd2..45d8571 100644
--- a/examples/cifar10/alexnet.cc
+++ b/examples/cifar10/alexnet.cc
@@ -22,7 +22,14 @@
 #include "singa/model/feed_forward_net.h"
 #include "singa/model/optimizer.h"
 #include "singa/model/initializer.h"
-
+#include "singa/model/metric.h"
+#include "singa/utils/channel.h"
+#include "singa/utils/string.h"
+#include "../../src/model/layer/cudnn_convolution.h"
+#include "../../src/model/layer/cudnn_activation.h"
+#include "../../src/model/layer/cudnn_pooling.h"
+#include "../../src/model/layer/dense.h"
+#include "../../src/model/layer/flatten.h"
 namespace singa {
 
 LayerConf GenConvConf(string name, int nb_filter, int kernel, int stride,
@@ -32,9 +39,9 @@ LayerConf GenConvConf(string name, int nb_filter, int kernel, int stride,
   conf.set_type("CudnnConvolution");
   ConvolutionConf *conv = conf.mutable_convolution_conf();
   conv->set_num_output(nb_filter);
-  conv->set_kernel_size(kernel);
-  conv->set_stride(stride);
-  conv->set_pad(pad);
+  conv->add_kernel_size(kernel);
+  conv->add_stride(stride);
+  conv->add_pad(pad);
 
   FillerConf *weight = conv->mutable_weight_filler();
   weight->set_type("Xavier");
@@ -50,7 +57,7 @@ LayerConf GenPoolingConf(string name, bool max_pool, int kernel, int stride, int
   pool->set_stride(stride);
   pool->set_pad(pad);
   if (!max_pool)
-    pool->set_pool(PoolingConf_AVE);
+    pool->set_pool(PoolingConf_PoolMethod_AVE);
   return conf;
 }
 
@@ -65,9 +72,9 @@ LayerConf GenDenseConf(string name, int num_output) {
   LayerConf conf;
   conf.set_name(name);
   conf.set_type("Dense");
-  DenseConf *dense = conf->mutable_dense_conf();
+  DenseConf *dense = conf.mutable_dense_conf();
   dense->set_num_output(num_output);
-  FillerConf *weight = conv->mutable_weight_filler();
+  FillerConf *weight = dense->mutable_weight_filler();
   weight->set_type("Xavier");
   return conf;
 }
@@ -79,22 +86,27 @@ LayerConf GenSoftmaxConf(string name) {
   return conf;
 }
 
-
-FeedForwordNet CreateNet(Optimizer* opt, Loss* loss, Metric* metric) {
-  FeedForwordNet net;
+LayerConf GenFlattenConf(string name) {  // LayerConf named 'name' of type "Flatten"
+  LayerConf conf;
+  conf.set_name(name);
+  conf.set_type("Flatten");  // no layer-specific sub-conf is filled in
+  return conf;
+}
+FeedForwardNet CreateNet(Optimizer* opt, Loss<Tensor>* loss, Metric<Tensor>* metric) {
+  FeedForwardNet net;
   Shape s{3, 32, 32};
-  net.AddLayer(GenConvConf("conv1", 32, 5, 1, 2), &s);
-  net.AddLayer(GenReLUConf("relu1"));
-  net.AddLayer(GenConvConf("pool1", 3, 2, 0));
-  net.AddLayer(GenConvConf("conv2", 32, 5, 1, 2));
-  net.AddLayer(GenReLUConf("relu2"));
-  net.AddLayer(GenConvConf("pool2", 3, 2, 0));
-  net.AddLayer(GenConvConf("conv3", 64, 5, 1, 2));
-  net.AddLayer(GenReLUConf("relu3"));
-  net.AddLayer(GenConvConf("pool3", 3, 2, 0));
-  net.AddLayer(GenDenseConf("ip1", 10));
-  net.AddLayer(GenSoftmaxConf("softmax"));
 
+  net.Add(new CudnnConvolution(), GenConvConf("conv1", 32, 5, 1, 2), &s);
+  net.Add(new CudnnActivation(), GenReLUConf("relu1"));
+  net.Add(new CudnnPooling(), GenPoolingConf("pool1", true, 3, 2, 0));
+  net.Add(new CudnnConvolution(), GenConvConf("conv2", 32, 5, 1, 2));
+  net.Add(new CudnnActivation(), GenReLUConf("relu2"));
+  net.Add(new CudnnPooling(), GenPoolingConf("pool2", true, 3, 2, 0));
+  net.Add(new CudnnConvolution(), GenConvConf("conv3", 64, 5, 1, 2));
+  net.Add(new CudnnActivation(), GenReLUConf("relu3"));
+  net.Add(new CudnnPooling(), GenPoolingConf("pool3", true, 3, 2, 0));
+  net.Add(new Flatten(), GenFlattenConf("flat"));
+  net.Add(new Dense(), GenDenseConf("ip1", 10));
   OptimizerConf opt_conf;
   opt_conf.set_momentum(0.9);
   opt->Setup(opt_conf);
@@ -103,42 +115,57 @@ FeedForwordNet CreateNet(Optimizer* opt, Loss* loss, Metric* metric) {
 }
 
 void Train(float lr, int num_epoch, string data_dir) {
-  SoftmaxCrossEntropy loss;
-  Accuracy acc;
-  SGD sgd;
-  sgd.SetLearningRate([lr](int step) {return lr;});
-  auto net = CreateNet(&opt, &loss, &metric);
   Cifar10 data(data_dir);
-  Tensor train_x, tain_y, test_x, test_y;
+  Tensor train_x, train_y, test_x, test_y;
   {
     auto train = data.ReadTrainData();
-    const auto mean = Average(train.first, 0);
-    train_x = SubRow(train.first, mean);
-    auto test = data.ReadTestData();
-    test_x = SubRow(test.first, mean);
+    size_t nsamples = train.first.shape(0);
+    auto matx = Reshape(train.first, Shape{nsamples, train.first.Size() / nsamples});
+    const auto mean = Average(matx, 0);
+    SubRow(mean, &matx);
+    train_x = Reshape(matx, train.first.shape());
     train_y = train.second;
+    auto test = data.ReadTestData();
+    nsamples = test.first.shape(0);
+    auto maty = Reshape(test.first, Shape{nsamples, test.first.Size() / nsamples});
+    SubRow(mean, &maty);
+    test_x = Reshape(maty, test.first.shape());
     test_y = test.second;
   }
-  net.Train(100, num_epoch, train_x, train_y, test_x, test_y);
+  LOG(ERROR) << "creating net";
+  SoftmaxCrossEntropy loss;
+  Accuracy acc;
+  SGD sgd;
+  sgd.SetLearningRateGenerator([lr](int step) {return lr;});
+  auto net = CreateNet(&sgd, &loss, &acc);
+
+  auto cuda = std::make_shared<CudaGPU>();
+  net.ToDevice(cuda);
+
+  train_x.ToDevice(cuda);
+  train_y.ToDevice(cuda);
+  net.Train(50, num_epoch, train_x, train_y); // test_x, test_y);
+}
+
+
 }
 
 int main(int argc, char** argv) {
-  InitChannel();
-  int pos = ArgPos(argc, argv, "-epoch");
+  singa::InitChannel(nullptr);
+  int pos = singa::ArgPos(argc, argv, "-epoch");
   int nEpoch = 5;
   if (pos != -1)
     nEpoch = atoi(argv[pos + 1]);
-  pos = ArgPos(argc, argv, "-lr");
+  pos = singa::ArgPos(argc, argv, "-lr");
   float lr = 0.01;
   if (pos != -1)
     lr = atof(argv[pos + 1]);
-  pos = ArgPos(argc, argv, "-data");
-  string data = "cifar-10-batch-bin";
+  pos = singa::ArgPos(argc, argv, "-data");
+  string data = "cifar-10-batches-bin";
   if (pos != -1)
     data = argv[pos + 1];
 
   LOG(INFO) << "Start training";
-  Train(lr, nEpoch, data);
+  singa::Train(lr, nEpoch, data);
   LOG(INFO) << "End training";
 }
-}

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/cf1d8418/examples/cifar10/cifar10.cc
----------------------------------------------------------------------
diff --git a/examples/cifar10/cifar10.cc b/examples/cifar10/cifar10.cc
deleted file mode 100644
index 7efc18f..0000000
--- a/examples/cifar10/cifar10.cc
+++ /dev/null
@@ -1,98 +0,0 @@
-/************************************************************
-*
-* Licensed to the Apache Software Foundation (ASF) under one
-* or more contributor license agreements.  See the NOTICE file
-* distributed with this work for additional information
-* regarding copyright ownership.  The ASF licenses this file
-* to you under the Apache License, Version 2.0 (the
-* "License"); you may not use this file except in compliance
-* with the License.  You may obtain a copy of the License at
-*
-*   http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing,
-* software distributed under the License is distributed on an
-* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-* KIND, either express or implied.  See the License for the
-* specific language governing permissions and limitations
-* under the License.
-*
-*************************************************************/
-#include <fstream>
-#include <string>
-#include <cstdint>
-#include <iostream>
-
-using std::string;
-namespace singa {
-/// For reading cifar10 binary data as tensors.
-class Cifar10 {
- public:
-  /// 'dir_path': path to the folder including the *.bin files
-  Cifar10(string dir_path, bool normalize = true)
-      : dir_path_(dir_path), normalize_(normalize) {}
-
-  /// read all training data into an image Tensor and a label Tensor
-  const std::pair<Tensor, Tensor> ReadTrainData(bool shuffle = false);
-  /// read all test data into an image Tensor and a label Tensor
-  const std::pair<Tensor, Tensor> ReadTestData();
-  /// read data from one file into an image Tensor and a label Tensor
-  const std::pair<Tensor, Tensor> ReadFile(string file, bool shuffle = false);
-
- private:
-  const int kImageSize = 32;
-  const int kImageVol = 3072;
-  const int kBatchSize = 10000;
-  const int kTrainFiles = 5;
-
-  string dir_path_;
-  bool normalize_;
-};
-
-void read_image(std::ifstream* file, int* label, char* buffer) {
-  char label_char;
-  file->read(&label_char, 1);
-  *label = label_char;
-  file->read(buffer, kImageVol);
-  return;
-}
-const std::pair<Tensor, Tensor> Cifar10::ReadFile(string file,
-                                                  bool shuffle = false) {
-  Tensor images(Shape{kTrainFiles, 3, kImageSize, kImageSize});
-  Tensor labels(Shape{kTrainFiles}, kInt);
-  if (dir_path_.back() != '/') dir_path_.push_back('/');
-  LOG(INFO) << "Reading file " << dir_path_ + file;
-  std::ifstream data_file((dir_path_ + file).c_str(),
-                          std::ios::in | std::ios::binary);
-  CHECK(data_file.is_open()) << "Unable to open file " << file;
-  int label;
-  char image[kImageVol];
-  float float_image[kImageVol];
-  int tmplabels[kBatchSize];
-  for (int itemid = 0; itemid < kBatchSize; ++itemid) {
-    read_image(&data_file, &label, image);
-    for (int i = 0; i < kImageVol; i++)
-      float_image[i] = static_cast<float>(static_cast<int>(image[i]));
-    images.CopyDataFromHostPtr(float_image, kImageVol, itemid * kImageVol);
-    tmplabels[itemid] = label;
-  }
-  labels.CopyDataFromHostPtr(tmplabels, kBatchSize);
-  return std::make_pair(images, labels);
-}
-
-const std::pair<Tensor, Tensor> Cifar10::ReadTrainData(bool shuffle = false) {
-  Tensor images(Shape{kBatchSize * kTrainFiles, 3, kImageSize, kImageSize});
-  Tensor labels(Shape{kBatchSize * kTrainFiles, 3, kImageSize, kImageSize});
-  for (int fileid = 0; fileid < kTrainFiles; ++fileid) {
-    string file = "data_batch_" + std::to_string(fileid + 1) + ".bin";
-    const auto ret = ReadFile(file);
-    CopyDataToFrom(&images, ret.first, ret.first.Size(),
-                   fileid * ret.first.Size());
-    CopyDataToFrom(&labels, ret.second, kBatchSize, fileid * kBatchSize);
-  }
-  return std::make_pair(images, labels);
-}
-const std::pair<Tensor, Tensor> Cifar10::ReadTrainData() {
-  return ReadFile("test_batch.bin");
-}
-}  // namespace singa

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/cf1d8418/examples/cifar10/cifar10.h
----------------------------------------------------------------------
diff --git a/examples/cifar10/cifar10.h b/examples/cifar10/cifar10.h
new file mode 100644
index 0000000..261c048
--- /dev/null
+++ b/examples/cifar10/cifar10.h
@@ -0,0 +1,99 @@
+/************************************************************
+*
+* Licensed to the Apache Software Foundation (ASF) under one
+* or more contributor license agreements.  See the NOTICE file
+* distributed with this work for additional information
+* regarding copyright ownership.  The ASF licenses this file
+* to you under the Apache License, Version 2.0 (the
+* "License"); you may not use this file except in compliance
+* with the License.  You may obtain a copy of the License at
+*
+*   http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing,
+* software distributed under the License is distributed on an
+* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+* KIND, either express or implied.  See the License for the
+* specific language governing permissions and limitations
+* under the License.
+*
+*************************************************************/
+#include <fstream>
+#include <string>
+#include <cstdint>
+#include <iostream>
+#include "singa/core/tensor.h"
+using std::string;
+namespace singa {
+/// For reading cifar10 binary data as tensors.
+class Cifar10 {
+ public:
+  /// 'dir_path': folder containing the CIFAR-10 binary *.bin batch files.
+  explicit Cifar10(string dir_path, bool normalize = true)
+      : dir_path_(dir_path), normalize_(normalize) {}
+
+  /// Read all kTrainFiles training batches into (images, labels) Tensors.
+  const std::pair<Tensor, Tensor> ReadTrainData(bool shuffle = false);
+  /// Read "test_batch.bin" into (images, labels) Tensors.
+  const std::pair<Tensor, Tensor> ReadTestData();
+  /// Read one batch file (relative to dir_path) into (images, labels).
+  const std::pair<Tensor, Tensor> ReadFile(string file, bool shuffle = false);
+  /// Read one record (1 label byte + kImageVol pixel bytes) from 'file'.
+  void ReadImage(std::ifstream* file, int* label, char* buffer);
+ private:
+  const size_t kImageSize = 32;
+  const size_t kImageVol = 3072;  // 3 channels x kImageSize x kImageSize
+  const size_t kBatchSize = 10000;  // records read per *.bin file
+  const size_t kTrainFiles = 5;  // data_batch_1.bin .. data_batch_5.bin
+
+  string dir_path_;
+  bool normalize_;  // NOTE(review): not used by the readers in this file
+};
+
+inline void Cifar10::ReadImage(std::ifstream* file, int* label, char* buffer) {
+  char label_char = 0;  // record layout: 1 label byte, then kImageVol pixel bytes
+  file->read(&label_char, 1);
+  *label = static_cast<int>(static_cast<unsigned char>(label_char));
+  file->read(buffer, kImageVol);
+  CHECK(file->good()) << "Truncated or unreadable CIFAR-10 record";
+}
+// Reads one batch file (kBatchSize records) from dir_path into (images, labels).
+inline const std::pair<Tensor, Tensor> Cifar10::ReadFile(string file, bool shuffle) {
+  Tensor images(Shape{kBatchSize, 3, kImageSize, kImageSize});
+  Tensor labels(Shape{kBatchSize}, kInt);
+  if (!dir_path_.empty() && dir_path_.back() != '/') dir_path_.push_back('/');
+  const string path = dir_path_ + file;
+  LOG(INFO) << "Reading file " << path;
+  std::ifstream data_file(path.c_str(), std::ios::in | std::ios::binary);
+  CHECK(data_file.is_open()) << "Unable to open file " << path;
+  int label = 0;
+  std::vector<char> image(kImageVol);  // heap buffers instead of non-standard VLAs
+  std::vector<float> float_image(kImageVol);
+  std::vector<int> tmplabels(kBatchSize);
+  for (size_t itemid = 0; itemid < kBatchSize; ++itemid) {
+    ReadImage(&data_file, &label, image.data());
+    for (size_t i = 0; i < kImageVol; i++)  // pixel bytes are unsigned (0..255)
+      float_image[i] = static_cast<float>(static_cast<unsigned char>(image[i]));
+    images.CopyDataFromHostPtr(float_image.data(), kImageVol, itemid * kImageVol);
+    tmplabels[itemid] = label;
+  }
+  labels.CopyDataFromHostPtr(tmplabels.data(), kBatchSize);
+  return std::make_pair(images, labels);
+}
+
+// Concatenates data_batch_1.bin .. data_batch_<kTrainFiles>.bin into one pair.
+inline const std::pair<Tensor, Tensor> Cifar10::ReadTrainData(bool shuffle) {
+  Tensor images(Shape{kBatchSize * kTrainFiles, 3, kImageSize, kImageSize});
+  Tensor labels(Shape{kBatchSize * kTrainFiles}, kInt);
+  for (size_t fileid = 0; fileid < kTrainFiles; ++fileid) {
+    const auto ret = ReadFile("data_batch_" + std::to_string(fileid + 1) + ".bin");
+    const size_t nimg = ret.first.Size();  // image elements in this file
+    CopyDataToFrom(&images, ret.first, nimg, fileid * nimg);
+    CopyDataToFrom(&labels, ret.second, kBatchSize, fileid * kBatchSize);
+  }
+  return std::make_pair(images, labels);  // NOTE: 'shuffle' is currently ignored
+}
+inline const std::pair<Tensor, Tensor> Cifar10::ReadTestData() {
+  return ReadFile("test_batch.bin");  // kBatchSize test records
+}
+}  // namespace singa

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/cf1d8418/examples/cifar10/make.sh
----------------------------------------------------------------------
diff --git a/examples/cifar10/make.sh b/examples/cifar10/make.sh
new file mode 100755
index 0000000..17e4b39
--- /dev/null
+++ b/examples/cifar10/make.sh
@@ -0,0 +1 @@
+g++ -g --std=c++11 alexnet.cc -o alexnet -I../../include -I../../build/include -I"${CUDNN_HOME:-/home/wangwei/local/cudnn4}"/include -I"${LOCAL_HOME:-/home/wangwei/local}"/include -I"${CUDA_HOME:-/usr/local/cuda}"/include/ -I../../lib/cnmem/include -L../../build/lib/ -lsinga_core -lsinga_model  -lsinga_utils -lcudart -lcublas -lcurand -lcudnn -L"${CUDA_HOME:-/usr/local/cuda}"/lib64 -L"${CUDNN_HOME:-/home/wangwei/local/cudnn4}"/lib64 ../../build/lib/libproto.a -lprotobuf

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/cf1d8418/include/singa/core/common.h
----------------------------------------------------------------------
diff --git a/include/singa/core/common.h b/include/singa/core/common.h
index cb1bdca..691d7d4 100644
--- a/include/singa/core/common.h
+++ b/include/singa/core/common.h
@@ -49,27 +49,29 @@ class Block {
  public:
   Block(void* ptr, size_t size, size_t offset = 0)
       : data_(ptr), size_(size), offset_(offset) {
-    ref_count_ = std::make_shared<std::atomic<int>>(1);
+    ref_count_ = 1; //std::make_shared<std::atomic<int>>(1);
   }
-  Block(void* ptr, size_t size, size_t offset, std::shared_ptr<atomic<int>> ref)
-      : data_(ptr), size_(size), offset_(offset), ref_count_(ref) {}
+//  Block(void* ptr, size_t size, size_t offset, std::shared_ptr<atomic<int>> ref)
+//      : data_(ptr), size_(size), offset_(offset), ref_count_(ref) {}
   void* mutable_data() const { return static_cast<char*>(data_) + offset_; }
   const void* data() const { return static_cast<char*>(data_) + offset_; }
   size_t size() const { return size_; }
   size_t offset() const { return offset_; }
   int IncRefCount() {
-    return (*ref_count_)++;
+    return ++ref_count_;  //(*ref_count_)++;
   }
   int DecRefCount() {
-    return  (*ref_count_)--;
+    return --ref_count_; // (*ref_count_)--;
   }
-  int ref_count() const { return ref_count_->load(); }
+  int ref_count() const { return ref_count_.load(); }
 
  private:
+  Block() {}
   void* data_ = nullptr;
   size_t size_ = 0;
   size_t offset_ = 0;
-  std::shared_ptr<std::atomic<int>> ref_count_ = nullptr;
+  // std::shared_ptr<std::atomic<int>> ref_count_ = nullptr;
+  std::atomic<int> ref_count_;
 };
 
 typedef struct _Context {

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/cf1d8418/include/singa/core/tensor.h
----------------------------------------------------------------------
diff --git a/include/singa/core/tensor.h b/include/singa/core/tensor.h
index 6de5c0c..3b496d9 100644
--- a/include/singa/core/tensor.h
+++ b/include/singa/core/tensor.h
@@ -96,6 +96,8 @@ class Tensor {
 
   /// return number of total elements
   size_t Size() const {
+    if (block_ == nullptr)
+      return 0u;
     CHECK_EQ(block_->size() % SizeOf(data_type_), 0u);
     return block_->size() / SizeOf(data_type_);
   }
@@ -315,7 +317,7 @@ Tensor Div(const SType x, const Tensor &in);
 template <typename SType>
 void Div(const SType x, const Tensor &in, Tensor *out);
 
-template <typename SType>
+template <typename SType = float>
 SType Sum(const Tensor &in);
 // ============Matrix (row/column) operations==================================
 /// Average elements in the Tensor, currently only support vector and matrix.

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/cf1d8418/include/singa/model/feed_forward_net.h
----------------------------------------------------------------------
diff --git a/include/singa/model/feed_forward_net.h b/include/singa/model/feed_forward_net.h
index 173600b..9beeb7a 100644
--- a/include/singa/model/feed_forward_net.h
+++ b/include/singa/model/feed_forward_net.h
@@ -18,7 +18,9 @@
 #ifndef SINGA_MODEL_FEED_FORWARD_NET_H_
 #define SINGA_MODEL_FEED_FORWARD_NET_H_
 #include "singa/model/layer.h"
-
+#include "singa/model/loss.h"
+#include "singa/model/metric.h"
+#include "singa/model/optimizer.h"
 namespace singa {
 
 /// The feed-forward neural net.
@@ -26,14 +28,14 @@ namespace singa {
 /// and conducting training, evaluation and prediction.
 class FeedForwardNet {
  public:
-  FeedForwardNet() = explicit;
+  FeedForwardNet() = default;
   ~FeedForwardNet();
 
   /// Add a layer with the assumption that
   /// 1. this function is called in correct order, i.e., the layers are added
   ///    following the topological order.
   /// 2. this layer has already been setup (Setup function is called outside).
-  void Add(Layer *layer);
+  Layer* Add(Layer* layer);
 
   // TODO(wangwei) add ConcatenateLayer and SliceLayer
   // AddConcatenateLayer(vector<Layer*> src, Layer *dst);
@@ -43,36 +45,34 @@ class FeedForwardNet {
   /// Assume the layer is added in corret order.
   /// For the first layer, 'sample_shape' (the input sample shape) is necessary
   /// for calling Setup().
-  void Add(const LayerConf &conf, const Shape *sample_shape = nullptr);
+  Layer* Add(const LayerConf& conf, const Shape* sample_shape = nullptr);
 
+  Layer* Add(Layer* layer, const LayerConf& conf, const Shape* sample_shape = nullptr);
   /// Set some fields used for training and evaluating the neural net.
   /// If the neural net is constructed for evaluation only, then 'opt' is not
   /// necessary; But for training, both 'opt' and 'loss' are necessary.
   /// 'shuffle' indicates shuffling training samples within one epoch it is
   /// valid using Train();
-  void Compile(bool shuffle, Optimizer *opt, Loss *loss, Metric *metric);
+  void Compile(bool shuffle, Optimizer* opt, Loss<Tensor>* loss,
+               Metric<Tensor>* metric);
 
   /// Conduct the training giving the training data 'x' and label 'y'.
-  /// Due to memory limit, 'x' and 'y' could not be very large. Hence, it is
-  /// typically used for small training datasets, e.g., cifar10 and MNIST which
-  /// can be stored in main memory.
-  void Train(int batchsize, int nb_epoch, Tensor x, Tensor y);
-  /// Conduct the training giving the training data 'x' and label 'y'.
-  /// 'val_split' is a ratio for splitting (1-'val_split') of training data for
+  /// 'val_split' of training data is used for
   /// validation. Validation is performance before every epoch.
   /// Due to memory limit, 'x' and 'y' could not be very large. Hence, it is
   /// typically used for small training datasets, e.g., cifar10 and MNIST which
   /// can be stored in main memory.
-  void Train(int batchsize, int nb_epoch, float val_split, Tensor x, Tensor y);
+  void Train(size_t batchsize, int nb_epoch, const Tensor& x, const Tensor& y,
+             float val_split = 0.0f);
   /// Conduct the training given the training and validation data.
   /// Validation is performance before every epoch.
   /// Due to memory limit, 'x' and 'y' could not be very large. Hence, it is
   /// typically used for small training datasets, e.g., cifar10 and MNIST which
   /// can be stored in main memory.
-  void Train(int batchsize, int nb_epoch, Tensor x, Tensor y, Tensor val_x,
-             Tensor val_y);
+  void Train(size_t batchsize, int nb_epoch, const Tensor& x, const Tensor& y,
+             const Tensor& val_x, const Tensor& val_y);
   /// Train the neural net over one batch of training data.
-  Tensor TrainOnBatch(Tensor x, Tensor y);
+  const std::pair<float, float> TrainOnBatch(const Tensor& x, const Tensor& y);
 
   /// Evaluate the neural net with given data.
   /// Returns one tensor for loss values and one tensor for metric values;
@@ -82,9 +82,10 @@ class FeedForwardNet {
   /// Due to memory limit, 'x' and 'y' could not be very large. Hence, it is
   /// typically used for small training datasets, e.g., cifar10 and MNIST which
   /// can be stored in main memory.
-  std::pair<Tensor, Tensor> Evaluate(Tensor x, Tensor y, int batchsize = 128);
+  std::pair<Tensor, Tensor> Evaluate(const Tensor& x, const Tensor& y,
+                                     size_t batchsize = 128);
   /// Evaluate the neural net for one batch of data
-  std::pair<Tensor, Tensor> EvaluateOnBatch(Tensor x, Tensor y);
+  std::pair<Tensor, Tensor> EvaluateOnBatch(const Tensor& x, const Tensor& y);
 
   /// Predict the probability distributation over candicate classes for each
   /// data sample. 'batchsize' is used for controlling the memory footprint.
@@ -92,35 +93,37 @@ class FeedForwardNet {
   /// Due to memory limit, 'x' and 'y' could not be very large. Hence, it is
   /// typically used for small training datasets, e.g., cifar10 and MNIST which
   /// can be stored in main memory.
-  Tensor Predict(const Tensor &x, int batchsize = 128);
+  const Tensor Predict(const Tensor& x, size_t batchsize = 128);
   /// Predict for one batch data.
-  Tensor PredictOnBatch(const Tensor &x);
+  const Tensor PredictOnBatch(const Tensor& x);
 
   /// Forward layers one by one using the data batch 'x'.
   /// Returns the prediction results (from the last layer).
-  Tensor Forward(const Tensor& x);
+  const Tensor Forward(int flag, const Tensor& x);
   /// Backward layers one by one using the gradient batch 'grad'.
   /// Returns the parameter gradients.
-  const vector<Tensor> Backward(const Tensor& grad);
+  const vector<Tensor> Backward(int flag, const Tensor& grad);
 
   /// Clone the neuaral net by cloning every layer to the given device.
   /// If 'device' is nullptr, then clone it one the current device.
-  FeedForwardNet Clone(std::shared_ptr<Device> device = nullptr);
+  FeedForwardNet Clone(std::shared_ptr<Device> device);
   /// Move the layer data to the given device.
-  void ToDevice(Device *device);
+  void ToDevice(std::shared_ptr<Device> device);
+  void ToHost() { ToDevice(defaultDevice); }
   /// Set the data type of each layer.
   void AsType(DataType dtype);
 
-  const vector<Layer *> layers() const { return layers_; }
+  const vector<Layer*> layers() const { return layers_; }
   const vector<string> GetParamNames() const;
-  const vector<Tensor *> GetParamValues() const;
-  const vector<Tensor *> GetParamGrads() const;
+  const vector<ParamSpec> GetParamSpecs() const;
+  const vector<Tensor*> GetParamValues() const;
+  const vector<Tensor*> GetParamGrads() const;
 
  protected:
-  vector<Layer *> layers_;
-  Optimizer *opt_;
-  Loss *loss_;
-  Metric *metric_;
+  vector<Layer*> layers_;
+  Optimizer* opt_;
+  Loss<Tensor>* loss_;
+  Metric<Tensor>* metric_;
 
   bool shuffle_ = true;
   Device* device_ = nullptr;

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/cf1d8418/include/singa/model/layer.h
----------------------------------------------------------------------
diff --git a/include/singa/model/layer.h b/include/singa/model/layer.h
index 79eb069..ce8007c 100644
--- a/include/singa/model/layer.h
+++ b/include/singa/model/layer.h
@@ -151,7 +151,7 @@ class Layer {
 
   /// Clone the layer to the given device. Layer data (e.g., parameters) are
   /// deep copied. If 'device' is nullptr, then clone it one the current device.
-  virtual Layer* Clone(std::shared_ptr<Device> device);
+  // virtual Layer* Clone(std::shared_ptr<Device> device);
   /// Move the layer (including its parameters and other internal Tensor) onto
   /// the given device
   virtual void ToDevice(std::shared_ptr<Device> device) {

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/cf1d8418/include/singa/model/loss.h
----------------------------------------------------------------------
diff --git a/include/singa/model/loss.h b/include/singa/model/loss.h
index 79abace..41ec701 100644
--- a/include/singa/model/loss.h
+++ b/include/singa/model/loss.h
@@ -37,6 +37,7 @@ class Loss {
     Setup(loss);
   }
 	virtual ~Loss(){};
+  virtual void ToDevice(std::shared_ptr<Device> device) {}
   /// Set meta fields from user configurations.
   virtual void Setup(const LossConf& conf) {}
 
@@ -48,7 +49,8 @@ class Loss {
   /// It calls Forward() internally. The calling pattern should be
   /// [Evaluate|Forward] Backward.
   float Evaluate(const Tensor& prediction, const T& target) {
-    const Tensor& loss = Forward(prediction, target);
+    Tensor loss = Forward(prediction, target);
+    loss.ToHost();
     return Sum<float>(loss) / (1.0f * loss.Size());
   }
 

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/cf1d8418/include/singa/model/metric.h
----------------------------------------------------------------------
diff --git a/include/singa/model/metric.h b/include/singa/model/metric.h
index b99ff0d..d013fa4 100644
--- a/include/singa/model/metric.h
+++ b/include/singa/model/metric.h
@@ -33,6 +33,7 @@ class Metric {
  public:
   // TODO(wangwei) call Setup using a default MetricConf.
   Metric() = default;
+  virtual void ToDevice(std::shared_ptr<Device> device) {}
   void Setup(const string& conf) {
     MetricConf metric;
     metric.ParseFromString(conf);
@@ -51,6 +52,28 @@ class Metric {
     return Sum<float>(metric) / (1.0f * metric.Size());
   }
 };
+/// Compute the accuracy of the prediction, which is matched against the
+/// ground truth labels.
+/// TODO(wangwei) consider multi-label cases.
+class Accuracy : public Metric<Tensor> {
+ public:
+  // Un-hide the base Setup(const string&) overload (hidden by the one below).
+  using Metric<Tensor>::Setup;
+  /// Set meta fields (top_k) from user configurations.
+  void Setup(const MetricConf& conf) override { top_k_ = conf.top_k(); }
+  /// Check the prediction against the target (ground truth) for each data
+  /// sample. The returned Tensor has a float value for each sample, 0 for wrong
+  /// and 1 for correct. Users can call Sum(const Tensor&) / Tensor::Size() to
+  /// get the accuracy.
+  Tensor Forward(const Tensor& prediction, const Tensor& target);
+ private:
+  /// Helper for Forward(); 'target' presumably holds one label id per sample.
+  Tensor Match(const Tensor& prediction, const vector<int>& target);
+  /// If the ground truth label is in the top k predicted labels, then the
+  /// prediction is correct.
+  size_t top_k_ = 1;
+};
+
 
 }  // namespace singa
 

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/cf1d8418/include/singa/model/optimizer.h
----------------------------------------------------------------------
diff --git a/include/singa/model/optimizer.h b/include/singa/model/optimizer.h
index f912668..a268126 100644
--- a/include/singa/model/optimizer.h
+++ b/include/singa/model/optimizer.h
@@ -155,7 +155,7 @@ class Regularizer {
 };
 
 // =============Vallina SGD with Momentum=====================================
-class SGD : Optimizer {
+class SGD : public Optimizer {
  public:
   void Setup(const OptimizerConf& conf);
   /// Apply the updating algorithm.

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/cf1d8418/src/CMakeLists.txt
----------------------------------------------------------------------
diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt
index c51d454..af09799 100644
--- a/src/CMakeLists.txt
+++ b/src/CMakeLists.txt
@@ -36,7 +36,7 @@ IF (USE_CUDA)
     SET(FLAGS_BACKUP ${CMAKE_CXX_FLAGS})
     SET(CMAKE_CXX_FLAGS "")
     IF (CMAKE_BUILD_TYPE MATCHES DEBUG)
-        CUDA_COMPILE(cuda_objs SHARED ${cuda_source} 
+        CUDA_COMPILE(cuda_objs SHARED ${cuda_source}
             OPTIONS "-Xcompiler -fPIC -G -g")
     ELSE (CMAKE_BUILD_TYPE MATCHES  DEBUG)
         CUDA_COMPILE(cuda_objs SHARED ${cuda_source} OPTIONS "-Xcompiler -fPIC")
@@ -57,6 +57,7 @@ AUX_SOURCE_DIRECTORY(model model_source)
 AUX_SOURCE_DIRECTORY(model/layer model_source)
 AUX_SOURCE_DIRECTORY(model/optimizer model_source)
 AUX_SOURCE_DIRECTORY(model/loss model_source)
+AUX_SOURCE_DIRECTORY(model/metric model_source)
 #MESSAGE(STATUS "MODEL ${model_source}")
 ADD_LIBRARY(singa_model SHARED ${model_source})
 TARGET_LINK_LIBRARIES(singa_model ${SINGA_LINKER_LIBS})

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/cf1d8418/src/core/device/cuda_gpu.cc
----------------------------------------------------------------------
diff --git a/src/core/device/cuda_gpu.cc b/src/core/device/cuda_gpu.cc
index 5879c58..5f6ac17 100644
--- a/src/core/device/cuda_gpu.cc
+++ b/src/core/device/cuda_gpu.cc
@@ -121,7 +121,7 @@ void CudaGPU::CopyToFrom(void* dst, const void* src, size_t nBytes,
   // cudaMemcpyAsync(dst, src, nBytes,cudaMemcpyDefault, ctx_.stream);
 }
 
-/// Allocate cpu memory.
+/// Allocate gpu memory.
 void* CudaGPU::Malloc(int size) {
   void* ptr = nullptr;
   if (size > 0) {
@@ -132,7 +132,7 @@ void* CudaGPU::Malloc(int size) {
   return ptr;
 }
 
-/// Free cpu memory.
+/// Free gpu memory.
 void CudaGPU::Free(void* ptr) {
   if (ptr != nullptr) {
     // CUDA_CHECK(cudaFree(ptr));

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/cf1d8418/src/core/tensor/tensor.cc
----------------------------------------------------------------------
diff --git a/src/core/tensor/tensor.cc b/src/core/tensor/tensor.cc
index ec59aaa..898cdc6 100644
--- a/src/core/tensor/tensor.cc
+++ b/src/core/tensor/tensor.cc
@@ -118,7 +118,8 @@ void Tensor::ToDevice(std::shared_ptr<Device> dst) {
   // TODO(wangwei) the comparison is very strict. May compare against device ID?
   if (device_ != dst) {
     Tensor tmp(shape_, dst, data_type_);
-    tmp.CopyData(*this);
+    if (block_ != nullptr && Size())
+      tmp.CopyData(*this);
     if (block_ != nullptr && block_->DecRefCount() == 0)
       device_->FreeBlock(block_);
     block_ = tmp.block_;
@@ -136,7 +137,8 @@ void Tensor::CopyDataFromHostPtr(const DType *src, const size_t num,
       << "data_type is " << DataType_Name(data_type_)
       << " user given type is of size " << sizeof(DType);
   if (src != nullptr) {
-    device_->CopyDataFromHostPtr(block(), src, sizeof(DType) * num, offset);
+    device_->CopyDataFromHostPtr(block(), src, sizeof(DType) * num,
+        sizeof(DType) * offset);
   } else {
     LOG(WARNING) << "Copy data from null host ptr";
   }
@@ -637,13 +639,13 @@ Tensor ConcatenateColumns(const vector<Tensor> &in) {
   return out;
 }
 Tensor CopyRows(const Tensor &in, const size_t start, const size_t end) {
-  CHECK_EQ(in.nDim(), 2u);
   CHECK_LT(start, end);
   CHECK_GE(in.shape(0), end);
-  Shape s;
-  s = Shape{end - start, in.shape(1)};
+  Shape s = in.shape();
+  s[0] = end - start;
+  size_t sample_size = in.Size() / in.shape(0);
   Tensor out(s, in.device(), in.data_type());
-  CopyDataToFrom(&out, in, out.Size(), 0, start * out.shape(1));
+  CopyDataToFrom(&out, in, out.Size(), 0, start * sample_size);
   return out;
 }
 Tensor CopyColumns(const Tensor &in, const size_t start, const size_t end) {

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/cf1d8418/src/model/feed_forward_net.cc
----------------------------------------------------------------------
diff --git a/src/model/feed_forward_net.cc b/src/model/feed_forward_net.cc
index f9e6480..a24d36a 100644
--- a/src/model/feed_forward_net.cc
+++ b/src/model/feed_forward_net.cc
@@ -18,11 +18,11 @@
 
 #include "singa/model/feed_forward_net.h"
 #include "singa/utils/logging.h"
+#include "singa/utils/channel.h"
 namespace singa {
 
-~FeedForwardNet::FeedForwardNet() {
-  for (auto layer : layers_)
-    delete layer;
+FeedForwardNet::~FeedForwardNet() {
+  for (auto layer : layers_) delete layer;
 }
 Layer* FeedForwardNet::Add(Layer* layer) {
   layers_.push_back(layer);
@@ -32,7 +32,12 @@ Layer* FeedForwardNet::Add(Layer* layer) {
 Layer* FeedForwardNet::Add(const LayerConf& conf, const Shape* sample_shape) {
   CHECK(sample_shape != nullptr || layers_.size())
       << "Must provide the input sample shape for the first layer";
-  Layer* layer = CreateLayer(conf.type());
+  Layer* layer = nullptr;  // TODO(wangwei) use CreateLayer(conf.type());
+  Add(layer, conf, sample_shape);
+  return layer;
+}
+
+Layer* FeedForwardNet::Add(Layer* layer, const LayerConf& conf, const Shape* sample_shape) {
   if (sample_shape == nullptr)
     layer->Setup(layers_.back()->GetOutputSampleShape(), conf);
   else
@@ -44,28 +49,25 @@ Layer* FeedForwardNet::Add(const LayerConf& conf, const Shape* sample_shape) {
 const vector<string> FeedForwardNet::GetParamNames() const {
   vector<string> names;
   for (auto layer : layers_)
-    for (const auto name : layer->param_names())
-      names.push_back(name);
+    for (const auto name : layer->param_names()) names.push_back(name);
   return names;
 }
-const vector<Tensor *> FeedForwardNet::GetParamValues() const {
-  vector<Tensor *> values;
+const vector<Tensor*> FeedForwardNet::GetParamValues() const {
+  vector<Tensor*> values;
   for (auto layer : layers_)
-    for (const auto value : layer->param_values())
-      values.push_back(value);
+    for (const auto value : layer->param_values()) values.push_back(value);
   return values;
 }
 
-const vector<Tensor *> FeedForwardNet::GetParamSpecs() const {
-  vector<ParamSpec *> specs;
+const vector<ParamSpec> FeedForwardNet::GetParamSpecs() const {
+  vector<ParamSpec> specs;
   for (auto layer : layers_)
-    for (const auto spec : layer->param_specs())
-      specs.push_back(spec);
+    for (const auto spec : layer->param_specs()) specs.push_back(spec);
   return specs;
 }
 
-void FeedForwardNet::Compile(bool shuffle, Optimizer* opt, Loss* loss,
-                             Metric* metric) {
+void FeedForwardNet::Compile(bool shuffle, Optimizer* opt, Loss<Tensor>* loss,
+                             Metric<Tensor>* metric) {
   shuffle_ = shuffle;
   bool train = (opt != nullptr) && (loss != nullptr);
   bool test = metric != nullptr;
@@ -73,14 +75,17 @@ void FeedForwardNet::Compile(bool shuffle, Optimizer* opt, Loss* loss,
   opt_ = opt;
   loss_ = loss;
   metric_ = metric;
+  // init params and register them to sgd
 }
 
 void FeedForwardNet::ToDevice(std::shared_ptr<Device> device) {
   for (auto layer: layers_)
     layer->ToDevice(device);
+  /*
   opt_->ToDevice(device);
   loss_->ToDevice(device);
   metric_->ToDevice(device);
+  */
 }
 
 FeedForwardNet FeedForwardNet::Clone(std::shared_ptr<Device> device) {
@@ -98,118 +103,110 @@ FeedForwardNet FeedForwardNet::Clone(std::shared_ptr<Device> device) {
   net.device_ = device;
   net.dtype_ = dtype;
   */
+  return net;
 }
 
 void FeedForwardNet::AsType(DataType dtype) {
   LOG(FATAL) << "FeedForwardNet::AsType not implemented";
 }
 
-void FeedForwardNet::Train(int batchsize, int nb_epoch, Tensor x, Tensor y) {
-  CHECK_EQ(x.shape(0), y.shape(0)) << "Diff num of sampels in x and y";
-  int num_extra_samples = x.shape(0) % batchsize;
-  if (num_extra_samples != 0)
-    LOG(WARNING) << "The last " << num_extra_samples << " would not be used";
-  Channel *ch = GetChannel("perf");
-  for (int epoch = 0; epoch < nb_epoch; epoch++) {
-    float loss = 0.0f, metric = 0.0f;
-    int batch = 0;
-    for (; batch < x.shape(0) / batchsize; batch++) {
-      Tesnor bx = x.Slice(batch * batchsize, batch * batchsize + batchsize);
-      Tesnor by = y.Slice(batch * batchsize, batch * batchsize + batchsize);
-      const auto ret = TrainOnBatch(bx, by);
-      loss += ret.first;
-      metric += ret.second;
-    }
-    loss /= batch;
-    metric /= batch;
-    ch->Send("Epoch " + std::to_string(epoch) + ", training loss = " +
-             std::to_string(loss) + ", accuracy = " + std::to_string(metric));
-  }
-}
-
-void FeedForwardNet::Train(int batchsize, int nb_epoch, Tensor x, Tensor y,
-   float val_split) {
+void FeedForwardNet::Train(size_t batchsize, int nb_epoch, const Tensor& x,
+                           const Tensor& y, float val_split) {
   CHECK_EQ(x.shape(0), y.shape(0)) << "Diff num of sampels in x and y";
   size_t num_train = x.shape(0) * val_split;
-  const Tensor train_x = CopyRows(x, 0, num_train);
-  const Tensor train_y = CopyRows(y, 0, num_train);
-  const Tensor val_x = CopyRows(x, num_train, x.shape(0));
-  const Tensor val_y = CopyRows(y, num_train, x.shape(0));
-  Train(batchsize, nb_epoch, train_x, train_y, val_x, val_y);
+  if (val_split == 0.0f) {
+    Tensor dummy;
+    Train(batchsize, nb_epoch, x, y, dummy, dummy);
+  } else {
+    const Tensor train_x = CopyRows(x, 0, num_train);
+    const Tensor train_y = CopyRows(y, 0, num_train);
+    const Tensor test_x = CopyRows(x, num_train, x.shape(0));
+    const Tensor test_y = CopyRows(y, num_train, y.shape(0));
+    Train(batchsize, nb_epoch, train_x, train_y, test_x, test_y);
+  }
 }
 
-
-void FeedForwardNet::Train(int batchsize, int nb_epoch, Tensor x, Tensor y,
-    const Tensor & val_x, const Tensor &val_y) {
+void FeedForwardNet::Train(size_t batchsize, int nb_epoch, const Tensor& x,
+                           const Tensor& y, const Tensor& val_x,
+                           const Tensor& val_y) {
+  InitNetParams();
   CHECK_EQ(x.shape(0), y.shape(0)) << "Diff num of sampels in x and y";
   int num_extra_samples = x.shape(0) % batchsize;
   if (num_extra_samples != 0)
     LOG(WARNING) << "The last " << num_extra_samples << " would not be used";
-  Channel *train_ch = GetChannel("train_perf");
-  Channel *test_ch = GetChannel("test_perf");
+  Channel* train_ch = GetChannel("train_perf");
+  train_ch->EnableDestStderr(true);
+  Channel* val_ch = GetChannel("val_perf");
   for (int epoch = 0; epoch < nb_epoch; epoch++) {
     float loss = 0.0f, metric = 0.0f;
-    int b = 0;
-    for (;b < x.shape(0) / batchsize; b++) {
-      Tesnor bx = CopyRows(x, b * batchsize, b * batchsize + batchsize);
-      Tesnor by = CopyRows(y, b * batchsize, b * batchsize + batchsize);
+    size_t b = 0;
+    for (; b < x.shape(0) / batchsize; b++) {
+      const Tensor bx = CopyRows(x, b * batchsize, b * batchsize + batchsize);
+      const Tensor by = CopyRows(y, b * batchsize, b * batchsize + batchsize);
       const auto ret = TrainOnBatch(bx, by);
       loss += ret.first;
       metric += ret.second;
     }
-    loss /= batch;
-    metric /= batch;
+    loss /= b;
+    metric /= b;
     train_ch->Send("Epoch " + std::to_string(epoch) + ", training loss = " +
-             std::to_string(loss) + ", accuracy = " + std::to_string(metric));
-    const auto val_perf = Evaluate(val_x, val_y, batchsize);
-    test_ch->Send("Epoch " + std::to_string(epoch)
-        + ", test loss = " + std::to_string(Average(val_perf.first))
-        + ", metric = " + std::to_string(Average(val_perf.second)));
+                   std::to_string(loss) + ", accuracy = " +
+                   std::to_string(metric));
+    if (val_x.Size() && val_y.Size()) {
+      const auto val_perf = Evaluate(val_x, val_y, batchsize);
+      val_ch->Send("Epoch " + std::to_string(epoch) + ", val loss = " +
+                   std::to_string(Sum(val_perf.first) / val_y.Size()) +
+                   ", metric = " +
+                   std::to_string(Sum(val_perf.second) / val_y.Size()));
+    }
   }
 }
 
-const std::pair<float, float> FeedForwardNet::TrainOnBatch(const Tensor x,
-                                                           const Tensor y) {
-  const Tensor fea = Forward(kTrain, bx);
-  float loss = loss->Evaluate(fea, fy);
-  float metric = metric->Evaluate(fea, by);
-  const Tensor grad = loss->Backward();
-  Backward(kTrain, grad);
+const std::pair<float, float> FeedForwardNet::TrainOnBatch(const Tensor& x,
+                                                           const Tensor& y) {
+  int flag = kTrain;
+  const Tensor fea = Forward(flag, x);
+  float loss = loss_->Evaluate(fea, y);
+  float metric = metric_->Evaluate(fea, y);
+  const Tensor grad = loss_->Backward();
+  const auto grads = Backward(kTrain, grad);
   return std::make_pair(loss, metric);
 }
 
-const Tensor FeedForwardNet::Forward(int flag, const Tensor data) {
-  Tensor tmp = data;
+const Tensor FeedForwardNet::Forward(int flag, const Tensor& data) {
+  Tensor input = data, output;
   for (auto layer : layers_) {
-    tmp = layer->Forward(flag, tmp);
+//    LOG(INFO) << layer->name();
+    output = layer->Forward(flag, input);
+    input = output;
   }
-  return tmp;
+  return output;
 }
 
-cons vector<Tensor> FeedForwardNet::Backward(int flag, const Tensor grad) {
+const vector<Tensor> FeedForwardNet::Backward(int flag, const Tensor& grad) {
   vector<Tensor> param_grads;
   Tensor tmp = grad;
-  for (size_t i = layers_.size() - 1; i >= 0; i--) {
+  for (int i = layers_.size() - 1; i >= 0; i--) {
+ //   LOG(INFO) << layers_.at(i)->name();
     auto ret = layers_.at(i)->Backward(flag, tmp);
-    tmp =ret.first;
+    tmp = ret.first;
     if (ret.second.size())
-      for (const auto x: ret.second)
-        param_grads.push_back(x);
+      for (const auto x : ret.second) param_grads.push_back(x);
   }
   return param_grads;
 }
 
-std::pair<Tensor, Tensor> Evaluate(Tensor x, Tensor y, int batchsize) {
+std::pair<Tensor, Tensor> FeedForwardNet::Evaluate(const Tensor& x,
+                                                   const Tensor& y,
+                                                   size_t batchsize) {
   CHECK_EQ(x.shape(0), y.shape(0)) << "Diff num of sampels in x and y";
   CHECK_GE(x.shape(0), batchsize);
   int num_extra_samples = x.shape(0) % batchsize;
-  int b = 0;
   Tensor loss(Shape{x.shape(0)}), metric(Shape{x.shape(0)});
-  for (; b < x.shape(0) / batchsize; b++) {
+  for (size_t b = 0; b < x.shape(0) / batchsize; b++) {
     int start = b * batchsize, end = start + batchsize;
     const Tensor bx = CopyRows(x, start, end);
     const Tensor by = CopyRows(y, start, end);
-    const Tensor fea = Forward(kEval, bx);
     const auto ret = EvaluateOnBatch(bx, by);
     CopyDataToFrom(&loss, ret.first, batchsize, start, 0);
     CopyDataToFrom(&metric, ret.second, batchsize, start, 0);
@@ -230,18 +227,19 @@ std::pair<Tensor, Tensor> Evaluate(Tensor x, Tensor y, int batchsize) {
 
 std::pair<Tensor, Tensor> FeedForwardNet::EvaluateOnBatch(const Tensor& x,
                                                           const Tensor& y) {
-  const Tensor fea = Forward(kEval, bx);
-  const Tensor m = metric_->Forward(fea, by);
-  const Tensor l = loss_->Forward(fea, by);
+  int flag = kEval;
+  const Tensor fea = Forward(flag, x);
+  const Tensor m = metric_->Forward(fea, y);
+  const Tensor l = loss_->Forward(fea, y);
   return std::make_pair(m, l);
 }
 
-const Tensor FeedForwardNet::Predict(const Tensor& x, int batchsize) {
+const Tensor FeedForwardNet::Predict(const Tensor& x, size_t batchsize) {
   CHECK_GE(x.shape(0), batchsize);
   int num_extra_samples = x.shape(0) % batchsize;
-  const auto outshape = layers_.back().GetOutputSampleShape();
+  const auto outshape = layers_.back()->GetOutputSampleShape();
   Tensor y(Shape{x.shape(0), Product(outshape)}, x.device());
-  for (int b = 0; b < x.shape(0) / batchsize; b++) {
+  for (size_t b = 0; b < x.shape(0) / batchsize; b++) {
     int start = b * batchsize, end = start + batchsize;
     const Tensor bx = CopyRows(x, start, end);
     CopyDataToFrom(&y, PredictOnBatch(bx), batchsize * y.shape(1),
@@ -258,6 +256,6 @@ const Tensor FeedForwardNet::Predict(const Tensor& x, int batchsize) {
 }
 
 const Tensor FeedForwardNet::PredictOnBatch(const Tensor& x) {
-  return Foward(kEval, x);
+  return Forward(kEval, x);
 }
 }  // namespace singa

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/cf1d8418/src/model/layer/convolution.cc
----------------------------------------------------------------------
diff --git a/src/model/layer/convolution.cc b/src/model/layer/convolution.cc
index c27960d..e4991a4 100644
--- a/src/model/layer/convolution.cc
+++ b/src/model/layer/convolution.cc
@@ -112,5 +112,9 @@ const std::pair<Tensor, vector<Tensor>> Convolution::Backward(
 
   return std::make_pair(input_grad, param_grad);
 }
-
+void Convolution::ToDevice(std::shared_ptr<Device> device) {
+  Layer::ToDevice(device);   // base-class part first
+  weight_.ToDevice(device);  // then this layer's own tensors
+  bias_.ToDevice(device);
+}
 }  // namespace singa

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/cf1d8418/src/model/layer/convolution.h
----------------------------------------------------------------------
diff --git a/src/model/layer/convolution.h b/src/model/layer/convolution.h
index 7ea5712..0e0b160 100644
--- a/src/model/layer/convolution.h
+++ b/src/model/layer/convolution.h
@@ -44,6 +44,7 @@ class Convolution : public Layer {
   const std::pair<Tensor, vector<Tensor>> Backward(int flag,
                                                    const Tensor &grad) override;
 
+  void ToDevice(std::shared_ptr<Device> device) override;
   size_t kernel_w() const { return kernel_w_; }
   size_t kernel_h() const { return kernel_h_; }
   size_t pad_w() const { return pad_w_; }

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/cf1d8418/src/model/layer/cudnn_activation.cc
----------------------------------------------------------------------
diff --git a/src/model/layer/cudnn_activation.cc b/src/model/layer/cudnn_activation.cc
index 4216fea..df728ce 100644
--- a/src/model/layer/cudnn_activation.cc
+++ b/src/model/layer/cudnn_activation.cc
@@ -54,6 +54,7 @@ void CudnnActivation::InitCudnn(size_t size, DataType dtype) {
 }
 
 const Tensor CudnnActivation::Forward(int flag, const Tensor& input) {
+  CHECK(buf_.empty());
   auto size = input.Size();
   DataType dtype = input.data_type();
   if (!has_init_cudnn_) {

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/cf1d8418/src/model/layer/cudnn_convolution.cc
----------------------------------------------------------------------
diff --git a/src/model/layer/cudnn_convolution.cc b/src/model/layer/cudnn_convolution.cc
index d5ac2a3..eb507b2 100644
--- a/src/model/layer/cudnn_convolution.cc
+++ b/src/model/layer/cudnn_convolution.cc
@@ -156,6 +156,7 @@ void CudnnConvolution::InitCudnn(const Tensor &input) {
 }
 
 const Tensor CudnnConvolution::Forward(int flag, const Tensor &input) {
+  CHECK(buf_.empty());
   CHECK_EQ(input.device()->lang(), kCuda);
   CHECK_EQ(input.nDim(), 4u);
   if (flag & kTrain) buf_.push(input);  // buffer the input for backward

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/cf1d8418/src/model/layer/cudnn_dropout.cc
----------------------------------------------------------------------
diff --git a/src/model/layer/cudnn_dropout.cc b/src/model/layer/cudnn_dropout.cc
index 2e2e12b..f9b9dbf 100644
--- a/src/model/layer/cudnn_dropout.cc
+++ b/src/model/layer/cudnn_dropout.cc
@@ -106,6 +106,10 @@ const std::pair<Tensor, vector<Tensor>> CudnnDropout::Backward(
   }
   return std::make_pair(dx, param_grad);
 }
+void CudnnDropout::ToDevice(std::shared_ptr<Device> device) {
+  Dropout::ToDevice(device);  // base-class part first
+  state.ToDevice(device);  // NOTE(review): presumably cudnn dropout state
+}
 }  // namespace singa
 #endif  // CUDNN_VERSION_MAJOR>=5
 #endif  // USE_CUDNN

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/cf1d8418/src/model/layer/cudnn_dropout.h
----------------------------------------------------------------------
diff --git a/src/model/layer/cudnn_dropout.h b/src/model/layer/cudnn_dropout.h
index 6809653..9e0cb9e 100644
--- a/src/model/layer/cudnn_dropout.h
+++ b/src/model/layer/cudnn_dropout.h
@@ -42,6 +42,7 @@ class CudnnDropout : public Dropout {
   const std::pair<Tensor, vector<Tensor>> Backward(int flag,
                                                    const Tensor& grad) override;
 
+  void ToDevice(std::shared_ptr<Device> device) override;
  private:
   /// Init cudnn related data structures.
   void InitCudnn(int size, DataType dtype, std::shared_ptr<Device> dev,

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/cf1d8418/src/model/layer/cudnn_pooling.cc
----------------------------------------------------------------------
diff --git a/src/model/layer/cudnn_pooling.cc b/src/model/layer/cudnn_pooling.cc
index 6d7a5b1..e49a1ec 100644
--- a/src/model/layer/cudnn_pooling.cc
+++ b/src/model/layer/cudnn_pooling.cc
@@ -78,6 +78,7 @@ void CudnnPooling::InitCudnn(const Tensor &input) {
 }
 
 const Tensor CudnnPooling::Forward(int flag, const Tensor &input) {
+  CHECK(buf_.empty());
   CHECK_EQ(input.device()->lang(), kCuda);
   CHECK_EQ(input.nDim(), 4u);
   size_t batchsize = input.shape(0);

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/cf1d8418/src/model/layer/cudnn_softmax.cc
----------------------------------------------------------------------
diff --git a/src/model/layer/cudnn_softmax.cc b/src/model/layer/cudnn_softmax.cc
index 77eab98..1d9e0b8 100644
--- a/src/model/layer/cudnn_softmax.cc
+++ b/src/model/layer/cudnn_softmax.cc
@@ -57,6 +57,7 @@ void CudnnSoftmax::InitCudnn(Shape shape, DataType dtype) {
 }
 
 const Tensor CudnnSoftmax::Forward(int flag, const Tensor& input) {
+  CHECK(buf_.empty());
   auto shape = input.shape();
   DataType dtype = input.data_type();
   if (!has_init_cudnn_) {

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/cf1d8418/src/model/layer/dense.cc
----------------------------------------------------------------------
diff --git a/src/model/layer/dense.cc b/src/model/layer/dense.cc
index bad26a8..c6a9f8a 100644
--- a/src/model/layer/dense.cc
+++ b/src/model/layer/dense.cc
@@ -45,7 +45,9 @@ void Dense::Setup(const Shape& in_sample, const LayerConf &conf) {
 
 /// \copydoc Layer::Forward(int flag, const Tensor&)
 const Tensor Dense::Forward(int flag, const Tensor &input) {
+  CHECK(buf_.empty());
   Tensor output;
+  CHECK_EQ(input.nDim(), 2);
   if (transpose_)  // use the transposed version of weight_ for computing
     output = Mult(input, weight_);
   else
@@ -81,6 +83,7 @@ const std::pair<Tensor, vector<Tensor>> Dense::Backward(int flag,
 }
 
 void Dense::ToDevice(std::shared_ptr<Device> device) {
+  Layer::ToDevice(device);
   weight_.ToDevice(device);
   bias_.ToDevice(device);
 }

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/cf1d8418/src/model/metric/accuracy.cc
----------------------------------------------------------------------
diff --git a/src/model/metric/accuracy.cc b/src/model/metric/accuracy.cc
new file mode 100644
index 0000000..1b667b1
--- /dev/null
+++ b/src/model/metric/accuracy.cc
@@ -0,0 +1,62 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "singa/model/metric.h"
+#include <algorithm>
+namespace singa {
+
+Tensor Accuracy::Match(const Tensor& predict, const vector<int>& target) {
+  Tensor prediction(predict.shape());
+  prediction.CopyData(predict);
+  size_t batchsize = target.size();
+  size_t nb_classes = prediction.Size() / batchsize;
+  // each row of prediction is the prob distribution for one sample
+  CHECK_EQ(prediction.shape().at(0), batchsize);
+  CHECK_LE(top_k_, nb_classes);  // partial_sort needs begin()+top_k_ <= end()
+  // TODO(wangwei) CloneToDevice(host);
+  const float* prob = prediction.data<float>();
+  vector<float> score(batchsize, 0.0f);  // 0 (wrong) unless label is in top-k
+  for (size_t b = 0; b < batchsize; b++) {
+    vector<std::pair<float, int>> prob_class;
+    for (size_t c = 0; c < nb_classes; c++) {
+      prob_class.push_back(std::make_pair(prob[b * nb_classes + c], c));
+    }
+    std::partial_sort(prob_class.begin(), prob_class.begin() + top_k_,
+                      prob_class.end(), std::greater<std::pair<float, int>>());
+    for (size_t k = 0; k < top_k_; k++)
+      if (prob_class.at(k).second == target.at(b)) score[b] = 1;
+  }
+  Tensor ret(Shape{batchsize});
+  ret.CopyDataFromHostPtr(score.data(), batchsize);
+  return ret;
+}
+
+// TODO(wangwei) consider multi-label cases, where target is of shape
+// nb_samples * nb_classes
+Tensor Accuracy::Forward(const Tensor& prediction, const Tensor& t) {
+  Tensor target(t.shape(), t.data_type());  // NOTE(review): default device?
+  target.CopyData(t);
+  vector<int> target_vec;
+  // TODO(wangwei) copy target to host.
+  const int* target_value = target.data<int>();  // labels are read as int
+  for (size_t i = 0; i < target.Size(); i++)
+    target_vec.push_back(target_value[i]);
+  return Match(prediction, target_vec);
+}
+
+}  // namespace singa

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/cf1d8418/src/model/metric/accuracy.h
----------------------------------------------------------------------
diff --git a/src/model/metric/accuracy.h b/src/model/metric/accuracy.h
deleted file mode 100644
index 69bd96b..0000000
--- a/src/model/metric/accuracy.h
+++ /dev/null
@@ -1,84 +0,0 @@
-/**
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *     http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef SINGA_MODEL_METRIC_ACCURACY_H_
-#define SINGA_MODEL_METRIC_ACCURACY_H_
-#include "singa/model/metric.h"
-#include <algorithm>
-namespace singa {
-
-/// Compute the accuray of the prediction, which is matched against the
-/// ground truth labels.
-/// TODO(wangwei) consider multi-label cases.
-class Accuracy : public Metric<Tensor> {
- public:
-  /// Set meta fields from user configurations.
-  void Setup(const MetricConf& conf) override { top_k_ = conf.top_k(); }
-
-  /// Check the prediction against the target (ground truth) for each data
-  /// sample. The returned Tensor has a float value for each sample, 0 for wrong
-  /// and 1 for correct. Users can call Sum(const Tensor&) / Tensor::Size() to
-  /// get the accuracy.
-  Tensor Forward(const Tensor& prediction, const Tensor& target);
-
- private:
-  /// \copydoc Match(const Tensor&, const Tensor&);
-  Tensor Match(const Tensor& prediction, const vector<int>& target);
-  /// If the ground truth label is in the top k predicted labels, then the
-  /// prediction is correct.
-  size_t top_k_ = 1;
-};
-
-Tensor Accuracy::Match(const Tensor& prediction, const vector<int>& target) {
-  size_t batchsize = target.size();
-  size_t nb_classes = prediction.Size() / batchsize;
-  // each row of prediction is the prob distribution for one sample
-  CHECK_EQ(prediction.shape().at(0), batchsize);
-  // TODO(wangwei) CloneToDevice(host);
-  const float* prob = prediction.data<float>();
-  float* score = new float[batchsize];
-  for (size_t b = 0; b < batchsize; b++) {
-    vector<std::pair<float, int>> prob_class;
-    for (size_t c = 0; c < nb_classes; c++) {
-      prob_class.push_back(std::make_pair(prob[b * nb_classes + c], c));
-    }
-    std::partial_sort(prob_class.begin(), prob_class.begin() + top_k_,
-                      prob_class.end(), std::greater<std::pair<float, int>>());
-
-    for (size_t k = 0; k < top_k_; k++)
-      if (prob_class.at(k).second == target.at(b)) score[b] = 1;
-  }
-  Tensor ret(Shape{batchsize});
-  ret.CopyDataFromHostPtr(score, batchsize);
-  return ret;
-}
-
-// TODO(wangwei) consider multi-label cases, where target is of shape
-// nb_samples * nb_classes
-Tensor Accuracy::Forward(const Tensor& prediction, const Tensor& target) {
-  vector<int> target_vec;
-  // TODO(wangwei) copy target to host.
-  const int* target_value = target.data<int>();
-  for (size_t i = 0; i < target.Size(); i++)
-    target_vec.push_back(target_value[i]);
-  return Match(prediction, target_vec);
-}
-
-}  // namespace singa
-
-#endif  // SINGA_MODEL_METRIC_ACCURACY_H_

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/cf1d8418/test/singa/test_accuracy.cc
----------------------------------------------------------------------
diff --git a/test/singa/test_accuracy.cc b/test/singa/test_accuracy.cc
index dc7719b..4ff14c0 100644
--- a/test/singa/test_accuracy.cc
+++ b/test/singa/test_accuracy.cc
@@ -20,7 +20,7 @@
 *************************************************************/
 
 #include "gtest/gtest.h"
-#include "../src/model/metric/accuracy.h"
+#include "singa/model/metric.h"
 
 TEST(Accuracy, Compute) {
   singa::Accuracy acc;