Posted to commits@singa.apache.org by wa...@apache.org on 2016/06/17 06:38:09 UTC

incubator-singa git commit: SINGA-198 - Change Layer::Setup API to include input Tensor shapes

Repository: incubator-singa
Updated Branches:
  refs/heads/dev a4d9aab52 -> 74f02143a


SINGA-198 - Change Layer::Setup API to include input Tensor shapes

Update the setup function from
```
void Setup(const LayerConf& conf);
```
to
```
void Setup(const Shape& in_sample_shape, const LayerConf& conf);  // for single input
void Setup(const vector<Shape>& in_sample_shapes, const LayerConf& conf); // for multiple inputs
```

Functions for getting the output sample shape are added:
```
const Shape GetOutputSampleShape() const;  // used for single output
const Shape GetOutputSampleShape(int k) const;  // used for multiple outputs
```

All tests pass.
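
Below is a usage sketch of the new API (illustration only, not part of this commit). It uses the Activation layer and the same LayerConf fields as the updated tests; the include path and the function name `SetupSketch` are hypothetical.
```
#include "singa/model/layer.h"
#include "../src/model/layer/activation.h"  // illustrative include path

// Sketch: configure a layer with the per-sample input shape and query its
// per-sample output shape (neither includes the batch dimension).
void SetupSketch() {
  singa::Activation acti;
  singa::LayerConf conf;
  conf.set_type("RELU");
  conf.mutable_relu_conf()->set_negative_slope(0.5f);

  // The input sample shape is now passed explicitly instead of being read
  // from LayerConf fields.
  acti.Setup(singa::Shape{3}, conf);

  // For Activation the output sample shape equals the input sample shape.
  const singa::Shape out_shape = acti.GetOutputSampleShape();  // {3}
}
```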


Project: http://git-wip-us.apache.org/repos/asf/incubator-singa/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-singa/commit/74f02143
Tree: http://git-wip-us.apache.org/repos/asf/incubator-singa/tree/74f02143
Diff: http://git-wip-us.apache.org/repos/asf/incubator-singa/diff/74f02143

Branch: refs/heads/dev
Commit: 74f02143a22b0f3c478a45716beac7fb087b0e7b
Parents: a4d9aab
Author: Wei Wang <wa...@comp.nus.edu.sg>
Authored: Thu Jun 16 19:48:25 2016 +0800
Committer: Wei Wang <wa...@comp.nus.edu.sg>
Committed: Thu Jun 16 19:57:16 2016 +0800

----------------------------------------------------------------------
 include/singa/model/layer.h          | 46 ++++++++++++++++++++++++-------
 src/model/layer/activation.cc        |  5 ++--
 src/model/layer/activation.h         |  7 ++++-
 src/model/layer/batchnorm.cc         | 11 ++++----
 src/model/layer/batchnorm.h          | 12 +++++---
 src/model/layer/convolution.cc       | 13 +++++----
 src/model/layer/convolution.h        |  7 ++++-
 src/model/layer/cudnn_batchnorm.cc   |  4 +--
 src/model/layer/cudnn_batchnorm.h    |  6 ++--
 src/model/layer/cudnn_convolution.cc |  4 +--
 src/model/layer/cudnn_convolution.h  |  2 +-
 src/model/layer/cudnn_pooling.cc     |  4 +--
 src/model/layer/cudnn_pooling.h      |  2 +-
 src/model/layer/dense.cc             |  7 +++--
 src/model/layer/dense.h              |  6 +++-
 src/model/layer/dropout.cc           |  5 ++--
 src/model/layer/dropout.h            |  7 ++++-
 src/model/layer/flatten.cc           | 30 ++++++++++----------
 src/model/layer/flatten.h            |  9 ++++--
 src/model/layer/lrn.cc               |  5 ++--
 src/model/layer/lrn.h                | 13 ++++++---
 src/model/layer/pooling.cc           | 13 +++++----
 src/model/layer/pooling.h            |  8 ++++--
 src/model/layer/prelu.cc             |  5 ++--
 src/model/layer/prelu.h              |  8 +++++-
 src/model/layer/softmax.cc           | 36 +++++++++++++-----------
 src/model/layer/softmax.h            |  7 ++++-
 src/proto/model.proto                | 15 +---------
 test/singa/test_activation.cc        |  7 +++--
 test/singa/test_cudnn_activation.cc  |  7 +++--
 test/singa/test_cudnn_batchnorm.cc   | 17 +++---------
 test/singa/test_cudnn_convolution.cc | 31 +++++----------------
 test/singa/test_cudnn_dropout.cc     |  7 +++--
 test/singa/test_cudnn_lrn.cc         |  8 +++---
 test/singa/test_cudnn_pooling.cc     | 16 +++--------
 test/singa/test_cudnn_softmax.cc     | 17 ++++--------
 test/singa/test_dense.cc             | 16 ++++-------
 test/singa/test_dropout.cc           |  7 +++--
 test/singa/test_flatten.cc           | 11 ++++----
 test/singa/test_prelu.cc             | 11 ++++----
 test/singa/test_softmax.cc           | 17 +++---------
 41 files changed, 246 insertions(+), 223 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/74f02143/include/singa/model/layer.h
----------------------------------------------------------------------
diff --git a/include/singa/model/layer.h b/include/singa/model/layer.h
index 2addc98..5f5c197 100644
--- a/include/singa/model/layer.h
+++ b/include/singa/model/layer.h
@@ -36,12 +36,21 @@ class Layer {
   Layer() = default;
 
   /// Set meta data fields from a string representing a proto message.
-    void Setup(const string& proto_str) {
+  /// 'in_shape' is the shape of the input feature for one sample
+  void Setup(const vector<size_t>& in_shape, const string& proto_str) {
     LayerConf conf;
     conf.ParseFromString(proto_str);
-    this->Setup(conf);
+    this->Setup(in_shape, conf);
   }
 
+  /// 'in_shapes' contains the shapes of the input features for one sample
+  void Setup(const vector<vector<size_t>>& in_shapes, const string& proto_str) {
+    LayerConf conf;
+    conf.ParseFromString(proto_str);
+    this->Setup(in_shapes, conf);
+  }
+
+
   // ============= Following Functions could be override =====================
   /// Destruct objects created by this layer.
   virtual ~Layer() {};
@@ -51,19 +60,36 @@ class Layer {
   virtual const std::string layer_type() const { return "Unknown"; }
 
   /// Set meta data fields configured in 'conf' (a proto message).
-  /// For some layers, which use input tensor shapes for setting its parameter
-  /// shapes (e.g, desen layer and convolution layer), users or wrapper
-  /// functions need to configure ncessary fields inside LayerConf.
+  /// Some layers use the input tensor shape to set their parameter
+  /// shapes (e.g., the dense and convolution layers). 'in_sample' provides such
+  /// shape info. It represents the shape of the Tensor (for a single sample)
+  /// from the last layer.
   /// After calling Setup, the shape info of parameters should be accessed
-  /// correctly. All other info that depends on input tensors (e.g., batchsize)
-  /// should be set inside Forward(). Internal buffer/fields are set assuming
-  /// batchsize is 1.
-  virtual void Setup(const LayerConf& conf) {
+  /// correctly. Internal buffer/fields are set assuming batchsize is 1.
+  virtual void Setup(const Shape& in_sample, const LayerConf& conf) {
+    name_ = conf.name();
+    // TODO(wangwei) load param values from checkpoint files.
+  }
+
+  /// Used for layers that have multiple input tensors, e.g., concatenate layer.
+  virtual void Setup(const vector<Shape>& in_samples,
+                     const LayerConf& conf) {
     name_ = conf.name();
-    // for (const auto& spec : conf.param()) param_specs_.push_back(spec);
     // TODO(wangwei) load param values from checkpoint files.
   }
 
+  /// Return the shape of the generated Tensor without the batchsize dimension
+  virtual const Shape GetOutputSampleShape() {
+    LOG(FATAL) << "Please override this function";
+    return vector<size_t>{};
+  }
+  /// Return the shape of the k-th generated tensor without the batchsize
+  /// dimension. Used for layers that generate multiple tensors.
+  virtual const Shape GetOutputSampleShape(int k) {
+    LOG(FATAL) << "Please override this function";
+    return vector<size_t>{};
+  }
+
   /// Do feature transformation for the given 'input' tensor (denoted as x).
   /// 'flag' is either kTrain or kEval for feed-forward nets, and
   /// would be used for other phases of training other nets. For example, when

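The doc comments above suggest the intended pattern: a net builder passes one layer's output sample shape into the next layer's Setup, so LayerConf no longer needs to carry input-shape fields. A minimal sketch under that assumption (not part of this commit; the function and parameter names are hypothetical):
```
#include "singa/model/layer.h"

// Sketch: connect two already-constructed layers using the new Setup and
// GetOutputSampleShape calls declared in layer.h above.
void ChainTwoLayers(singa::Layer* first, singa::Layer* second,
                    const singa::Shape& in_sample,
                    const singa::LayerConf& first_conf,
                    const singa::LayerConf& second_conf) {
  first->Setup(in_sample, first_conf);
  // The per-sample output shape of 'first' becomes the input sample shape
  // of 'second'.
  second->Setup(first->GetOutputSampleShape(), second_conf);
}
```
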
http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/74f02143/src/model/layer/activation.cc
----------------------------------------------------------------------
diff --git a/src/model/layer/activation.cc b/src/model/layer/activation.cc
index e7c0696..2f76a6d 100644
--- a/src/model/layer/activation.cc
+++ b/src/model/layer/activation.cc
@@ -20,12 +20,13 @@
 #include "./activation.h"
 namespace singa {
 
-void Activation::Setup(const LayerConf& conf) {
-  Layer::Setup(conf);
+void Activation::Setup(const Shape& in_sample, const LayerConf& conf) {
+  Layer::Setup(in_sample, conf);
   mode_ = conf.type();
   if (mode_ == "RELU") {
     neg_slope_ = conf.relu_conf().negative_slope();
   }
+  out_sample_shape_ = in_sample;
 }
 
 const Tensor Activation::Forward(int flag, const Tensor& input) {

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/74f02143/src/model/layer/activation.h
----------------------------------------------------------------------
diff --git a/src/model/layer/activation.h b/src/model/layer/activation.h
index 1747577..1799514 100644
--- a/src/model/layer/activation.h
+++ b/src/model/layer/activation.h
@@ -29,7 +29,11 @@ class Activation : public Layer {
   const std::string layer_type() const override { return "Activation"; }
 
   /// \copydoc Layer::Setup(const LayerConf&);
-  void Setup(const LayerConf& conf) override;
+  void Setup(const Shape& in_sample, const LayerConf& conf) override;
+  const Shape GetOutputSampleShape() {
+    CHECK(out_sample_shape_.size()) << "You may not have called Setup()";
+    return out_sample_shape_;
+  }
 
   /// \copydoc Layer::Forward(int flag, const Tensor&)
   const Tensor Forward(int flag, const Tensor& input) override;
@@ -45,6 +49,7 @@ class Activation : public Layer {
  protected:
   std::string mode_;
   std::stack<Tensor> buf_;
+  Shape out_sample_shape_;
   float neg_slope_;
 };
 }  // namespace singa

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/74f02143/src/model/layer/batchnorm.cc
----------------------------------------------------------------------
diff --git a/src/model/layer/batchnorm.cc b/src/model/layer/batchnorm.cc
index bcd0870..4cc2320 100644
--- a/src/model/layer/batchnorm.cc
+++ b/src/model/layer/batchnorm.cc
@@ -21,12 +21,13 @@
 #include "batchnorm.h"
 
 namespace singa {
-void BatchNorm::Setup(const LayerConf& conf) {
-  Layer::Setup(conf);
+void BatchNorm::Setup(const Shape& in_sample, const LayerConf& conf) {
+  Layer::Setup(in_sample, conf);
+  out_sample_shape_ = in_sample;
   factor_ = conf.batchnorm_conf().factor();
-  channels_ = conf.batchnorm_conf().channels();
-  height_ = conf.batchnorm_conf().height();
-  width_ = conf.batchnorm_conf().width();
+  channels_ = in_sample.at(0);
+  height_ = in_sample.at(1);
+  width_ = in_sample.at(2);
 
   bnScale_.Reshape(Shape{channels_ * height_ * width_});
   bnBias_.ResetLike(bnScale_);

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/74f02143/src/model/layer/batchnorm.h
----------------------------------------------------------------------
diff --git a/src/model/layer/batchnorm.h b/src/model/layer/batchnorm.h
index 0255179..433e0c7 100644
--- a/src/model/layer/batchnorm.h
+++ b/src/model/layer/batchnorm.h
@@ -33,8 +33,12 @@ class BatchNorm : public Layer {
     return "Batch Normalization";
   }
 
-  /// \copydoc Layer::Setup(const LayerConf&)
-  virtual void Setup(const LayerConf& conf) override;
+  /// \copydoc Layer::Setup(const LayerConf&);
+  void Setup(const Shape& in_sample, const LayerConf& conf) override;
+  const Shape GetOutputSampleShape() {
+    CHECK(out_sample_shape_.size()) << "You may not have called Setup()";
+    return out_sample_shape_;
+  }
 
   const Tensor Forward(int flag, const Tensor& input)
     override;
@@ -77,8 +81,8 @@ class BatchNorm : public Layer {
   Tensor runningMean_, runningVariance_;
   // Store intermediate data, i.e., input tensor
   std::stack<Tensor> buf_;
-  
+  Shape out_sample_shape_;
 }; // class batchnorm
-} // namespace 
+} // namespace
 
 #endif  // SINGA_MODEL_LAYER_BATCHNORM_H

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/74f02143/src/model/layer/convolution.cc
----------------------------------------------------------------------
diff --git a/src/model/layer/convolution.cc b/src/model/layer/convolution.cc
index 50ee3c8..c27960d 100644
--- a/src/model/layer/convolution.cc
+++ b/src/model/layer/convolution.cc
@@ -23,8 +23,8 @@
 namespace singa {
 using std::vector;
 
-void Convolution::Setup(const LayerConf &conf) {
-  Layer::Setup(conf);
+void Convolution::Setup(const Shape& in_sample, const LayerConf &conf) {
+  Layer::Setup(in_sample, conf);
   ConvolutionConf conv_conf = conf.convolution_conf();
   // kernel_size, pad, and stride are repeated fields.
   if (conv_conf.kernel_size_size() > 0) {
@@ -73,12 +73,15 @@ void Convolution::Setup(const LayerConf &conf) {
   bias_term_ = conv_conf.bias_term();
 
   // Shape of input image
-  channels_ = conv_conf.channels();
-  height_ = conv_conf.height();
-  width_ = conv_conf.width();
+  CHECK_EQ(in_sample.size(), 3u);
+  channels_ = in_sample.at(0);
+  height_ = in_sample.at(1);
+  width_ = in_sample.at(2);
 
   conv_height_ = (height_ + 2 * pad_h_ - kernel_h_) / stride_h_ + 1;
   conv_width_ = (width_ + 2 * pad_w_ - kernel_w_) / stride_w_ + 1;
+  out_sample_shape_ = vector<size_t>{num_filters_, conv_height_, conv_width_};
+
   col_height_ = channels_ * kernel_w_ * kernel_h_;
   col_width_ = conv_height_ * conv_width_;
 

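As a concrete check of the shape arithmetic above (hypothetical numbers, for illustration only), with in_sample = {1, 3, 3}, a 2x2 kernel, stride 1, no padding and a single filter, Convolution::Setup gives:
```
// conv_height_ = (height_ + 2 * pad_h_ - kernel_h_) / stride_h_ + 1
//              = (3 + 2 * 0 - 2) / 1 + 1 = 2
// conv_width_  = (3 + 2 * 0 - 2) / 1 + 1 = 2
// out_sample_shape_ = {num_filters_, conv_height_, conv_width_} = {1, 2, 2}
```
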
http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/74f02143/src/model/layer/convolution.h
----------------------------------------------------------------------
diff --git a/src/model/layer/convolution.h b/src/model/layer/convolution.h
index 477efb3..3901049 100644
--- a/src/model/layer/convolution.h
+++ b/src/model/layer/convolution.h
@@ -30,7 +30,11 @@ class Convolution : public Layer {
   const std::string layer_type() const override { return "Convolution"; }
 
   /// \copydoc Layer::Setup(const LayerConf&);
-  void Setup(const LayerConf &conf) override;
+  void Setup(const vector<size_t>& in_shape, const LayerConf& conf) override;
+  const Shape GetOutputSampleShape() {
+    CHECK(out_sample_shape_.size()) << "You may not have called Setup()";
+    return out_sample_shape_;
+  }
 
   // void SetupParam(const Tensor &input);
   /// \copydoc Layer::Forward(int flag, const Tensor&)
@@ -72,6 +76,7 @@ class Convolution : public Layer {
   // store intermediate data, i.e., input tensor
   std::stack<Tensor> buf_;
   bool bias_term_;
+  vector<size_t> out_sample_shape_;
 };
 }  // namespace singa
 #endif  // SRC_MODEL_LAYER_CONVOLUTION_H_

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/74f02143/src/model/layer/cudnn_batchnorm.cc
----------------------------------------------------------------------
diff --git a/src/model/layer/cudnn_batchnorm.cc b/src/model/layer/cudnn_batchnorm.cc
index 1393916..0e597fe 100644
--- a/src/model/layer/cudnn_batchnorm.cc
+++ b/src/model/layer/cudnn_batchnorm.cc
@@ -36,8 +36,8 @@ void CudnnBatchNorm::ToDevice(Device* device) {
   resultSaveVariance_.ToDevice(device);
 }
 
-void CudnnBatchNorm::Setup(const LayerConf& conf) {
-  BatchNorm::Setup(conf);
+void CudnnBatchNorm::Setup(const Shape& in_sample, const LayerConf& conf) {
+  BatchNorm::Setup(in_sample, conf);
   bnScale_.Reshape(Shape{1,channels_,1,1});
   bnBias_.ResetLike(bnScale_);
   dbnScale_.ResetLike(bnScale_);

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/74f02143/src/model/layer/cudnn_batchnorm.h
----------------------------------------------------------------------
diff --git a/src/model/layer/cudnn_batchnorm.h b/src/model/layer/cudnn_batchnorm.h
index 83258d2..47fd4c5 100644
--- a/src/model/layer/cudnn_batchnorm.h
+++ b/src/model/layer/cudnn_batchnorm.h
@@ -35,7 +35,7 @@ class CudnnBatchNorm : public BatchNorm {
      return "CudnnBatchNorm";
    }
 
-   void Setup(const LayerConf& conf) override;
+   void Setup(const Shape& in_sample, const LayerConf& conf) override;
 
    const Tensor Forward(int flag, const Tensor& input)
      override;
@@ -52,9 +52,9 @@ class CudnnBatchNorm : public BatchNorm {
    cudnnLRNDescriptor_t lrn_desc_;
    cudnnTensorDescriptor_t shape_desc_, param_desc_;
    Tensor resultSaveMean_, resultSaveVariance_;
-   
+
 }; // class CudnnBatchNorm
 }  // namespace
 
 #endif  // USE_CUDNN
-#endif  // SINGA_MODEL_LAYER_CUDNN_BATCHNORM 
+#endif  // SINGA_MODEL_LAYER_CUDNN_BATCHNORM

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/74f02143/src/model/layer/cudnn_convolution.cc
----------------------------------------------------------------------
diff --git a/src/model/layer/cudnn_convolution.cc b/src/model/layer/cudnn_convolution.cc
index efc7f88..8cdfc07 100644
--- a/src/model/layer/cudnn_convolution.cc
+++ b/src/model/layer/cudnn_convolution.cc
@@ -34,8 +34,8 @@ CudnnConvolution::~CudnnConvolution() {
   if (y_desc_ != nullptr) CUDNN_CHECK(cudnnDestroyTensorDescriptor(y_desc_));
 }
 
-void CudnnConvolution::Setup(const LayerConf &conf) {
-  Convolution::Setup(conf);
+void CudnnConvolution::Setup(const Shape& in_sample, const LayerConf &conf) {
+  Convolution::Setup(in_sample, conf);
   ConvolutionConf conv_conf = conf.convolution_conf();
   // convert MB to bytes
   workspace_byte_limit_ = conv_conf.workspace_byte_limit() << 20;

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/74f02143/src/model/layer/cudnn_convolution.h
----------------------------------------------------------------------
diff --git a/src/model/layer/cudnn_convolution.h b/src/model/layer/cudnn_convolution.h
index b86c576..6c15839 100644
--- a/src/model/layer/cudnn_convolution.h
+++ b/src/model/layer/cudnn_convolution.h
@@ -41,7 +41,7 @@ class CudnnConvolution : public Convolution {
                                                    const Tensor &grad) override;
 
   /// \copydoc Layer::Setup(const LayerConf&);
-  void Setup(const LayerConf &conf) override;
+  void Setup(const Shape& in_sample, const LayerConf &conf) override;
 
   void ToDevice(Device *device) override;
 

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/74f02143/src/model/layer/cudnn_pooling.cc
----------------------------------------------------------------------
diff --git a/src/model/layer/cudnn_pooling.cc b/src/model/layer/cudnn_pooling.cc
index fb8256a..9d288c0 100644
--- a/src/model/layer/cudnn_pooling.cc
+++ b/src/model/layer/cudnn_pooling.cc
@@ -32,8 +32,8 @@ CudnnPooling::~CudnnPooling() {
   if (y_desc_ != nullptr) CUDNN_CHECK(cudnnDestroyTensorDescriptor(y_desc_));
 }
 
-void CudnnPooling::Setup(const LayerConf &conf) {
-  Pooling::Setup(conf);
+void CudnnPooling::Setup(const Shape& in_sample, const LayerConf &conf) {
+  Pooling::Setup(in_sample, conf);
   PoolingConf pool_conf = conf.pooling_conf();
   if (pool_conf.nan_prop())
     nan_prop_ = CUDNN_PROPAGATE_NAN;

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/74f02143/src/model/layer/cudnn_pooling.h
----------------------------------------------------------------------
diff --git a/src/model/layer/cudnn_pooling.h b/src/model/layer/cudnn_pooling.h
index 1a38cd5..c3c7060 100644
--- a/src/model/layer/cudnn_pooling.h
+++ b/src/model/layer/cudnn_pooling.h
@@ -37,7 +37,7 @@ class CudnnPooling : public Pooling {
   /// \copydoc Layer::layer_type()
   const std::string layer_type() const override { return "CudnnPooling"; }
 
-  void Setup(const LayerConf &conf) override;
+  void Setup(const Shape& in_sample, const LayerConf &conf) override;
   const Tensor Forward(int flag, const Tensor &input) override;
   const std::pair<Tensor, vector<Tensor>> Backward(int flag,
                                                    const Tensor &grad) override;

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/74f02143/src/model/layer/dense.cc
----------------------------------------------------------------------
diff --git a/src/model/layer/dense.cc b/src/model/layer/dense.cc
index b349787..dd23329 100644
--- a/src/model/layer/dense.cc
+++ b/src/model/layer/dense.cc
@@ -27,11 +27,12 @@ Dense::~Dense() {
   // delete weight_;
   // delete bias_;
 }
-void Dense::Setup(const LayerConf &conf) {
-  Layer::Setup(conf);
+void Dense::Setup(const Shape& in_sample, const LayerConf &conf) {
+  Layer::Setup(in_sample, conf);
   auto dense_conf = conf.dense_conf();
+  CHECK_EQ(in_sample.size(), 1u);
+  vdim_ = in_sample.at(0);
   hdim_ = dense_conf.num_output();
-  vdim_ = dense_conf.num_input();
   transpose_ = dense_conf.transpose();
   if (transpose_)
     weight_.Reshape(Shape{vdim_, hdim_});

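To illustrate the new Dense setup (numbers match the Shape{2} / num_output = 3 configuration used in test_dense below):
```
// vdim_ = in_sample.at(0) = 2         // previously read from DenseConf::num_input
// hdim_ = dense_conf.num_output() = 3
// GetOutputSampleShape() then returns {hdim_} = {3}
```
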
http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/74f02143/src/model/layer/dense.h
----------------------------------------------------------------------
diff --git a/src/model/layer/dense.h b/src/model/layer/dense.h
index a5a6f66..6704106 100644
--- a/src/model/layer/dense.h
+++ b/src/model/layer/dense.h
@@ -31,7 +31,11 @@ class Dense : public Layer {
   const std::string layer_type() const override { return "Dense"; }
 
   /// \copydoc Layer::Setup(const LayerConf&);
-  void Setup(const LayerConf& conf) override;
+  void Setup(const Shape& in_sample, const LayerConf& conf) override;
+  const Shape GetOutputSampleShape() {
+    CHECK(hdim_) << "You may not have called Setup()";
+    return vector<size_t>{hdim_};
+  }
 
   /// \copydoc Layer::Forward(int flag, const Tensor&)
   const Tensor Forward(int flag, const Tensor& input) override;

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/74f02143/src/model/layer/dropout.cc
----------------------------------------------------------------------
diff --git a/src/model/layer/dropout.cc b/src/model/layer/dropout.cc
index c2c97be..4781576 100644
--- a/src/model/layer/dropout.cc
+++ b/src/model/layer/dropout.cc
@@ -20,9 +20,10 @@
 #include "./dropout.h"
 namespace singa {
 
-void Dropout::Setup(const LayerConf& conf) {
-  Layer::Setup(conf);
+void Dropout::Setup(const Shape& in_sample, const LayerConf& conf) {
+  Layer::Setup(in_sample, conf);
   dropout_ratio_ = conf.dropout_conf().dropout_ratio();
+  out_sample_shape_= in_sample;
 }
 
 const Tensor Dropout::Forward(int flag, const Tensor& input) {

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/74f02143/src/model/layer/dropout.h
----------------------------------------------------------------------
diff --git a/src/model/layer/dropout.h b/src/model/layer/dropout.h
index 5efaf6a..e9ff798 100644
--- a/src/model/layer/dropout.h
+++ b/src/model/layer/dropout.h
@@ -29,7 +29,11 @@ class Dropout : public Layer {
   const std::string layer_type() const override { return "Dropout"; }
 
   /// \copydoc Layer::Setup(const LayerConf&);
-  void Setup(const LayerConf& conf) override;
+  void Setup(const Shape& in_sample, const LayerConf& conf) override;
+  const Shape GetOutputSampleShape() {
+    CHECK(out_sample_shape_.size()) << "You may not have called Setup()";
+    return out_sample_shape_;
+  }
 
   /// \copydoc Layer::Forward(int flag, const Tensor&)
   /// if flag is kTrain, then do dropout with given dropout_ratio;
@@ -57,6 +61,7 @@ class Dropout : public Layer {
   /// the probability to set each element to 0.
   float dropout_ratio_;
   Tensor mask_;
+  vector<size_t> out_sample_shape_;
 };
 }  // namespace singa
 #endif  // SRC_MODEL_LAYER_DROPOUT_H_

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/74f02143/src/model/layer/flatten.cc
----------------------------------------------------------------------
diff --git a/src/model/layer/flatten.cc b/src/model/layer/flatten.cc
index 7341394..bc175a5 100644
--- a/src/model/layer/flatten.cc
+++ b/src/model/layer/flatten.cc
@@ -20,27 +20,25 @@
 #include "./flatten.h"
 namespace singa {
 
-void Flatten::Setup(const LayerConf &conf) {
-  Layer::Setup(conf);
+void Flatten::Setup(const Shape& in_sample, const LayerConf &conf) {
+  Layer::Setup(in_sample, conf);
   axis_ = conf.flatten_conf().axis();
+  size_t len = 1;
+  if (axis_ > 0)
+    for (size_t i = axis_ - 1; i < in_sample.size(); i++)
+      len *= in_sample.at(i);
+  out_sample_shape_.push_back(len);
 }
 
 const Tensor Flatten::Forward(int flag, const Tensor &input) {
-  Tensor output = input;
+  Tensor output;
   input_shape_ = input.shape();
-  if (!Axis()) {
-    // reshape to 1D
-    size_t dim = output.Size();
-    output.Reshape(Shape{dim});
-    output_shape_ = Shape{dim};
-  } else {
-    // reshape to 2D
-    size_t dim1 = 1, dim2;
-    for (int i = 0; i < Axis(); i++) dim1 *= output.shape(i);
-    dim2 = output.Size() / dim1;
-    output.Reshape(Shape{dim1, dim2});
-    output_shape_ = Shape{dim1, dim2};
-  }
+  if (axis_ == 0)
+    output = Reshape(input, vector<size_t>{input.Size()});
+  else
+    output =
+        Reshape(input, vector<size_t>{input.Size() / out_sample_shape_.at(0),
+                                      out_sample_shape_.at(0)});
   return output;
 }
 

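A short illustration of the reworked Flatten logic (the sample shape matches the updated flatten tests; the axis value and batch size are hypothetical):
```
// in_sample = {1, 3, 2}, axis_ = 1
//   => out_sample_shape_ = {1 * 3 * 2} = {6}
//   => Forward reshapes a batch of 4 such samples from (4, 1, 3, 2) to (4, 6)
// axis_ = 0
//   => Forward flattens the whole input, batch included, into a 1-D tensor
```
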
http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/74f02143/src/model/layer/flatten.h
----------------------------------------------------------------------
diff --git a/src/model/layer/flatten.h b/src/model/layer/flatten.h
index 580b2ba..0981f32 100644
--- a/src/model/layer/flatten.h
+++ b/src/model/layer/flatten.h
@@ -29,7 +29,11 @@ class Flatten : public Layer {
   const std::string layer_type() const override { return "Flatten"; }
 
   /// \copydoc Layer::Setup(const LayerConf&);
-  void Setup(const LayerConf &conf) override;
+  void Setup(const Shape& in_sample, const LayerConf& conf) override;
+  const Shape GetOutputSampleShape() {
+    CHECK(out_sample_shape_.size()) << "You may not have called Setup()";
+    return out_sample_shape_;
+  }
 
   /// \copydoc Layer::Forward(int flag, const Tensor&);
   const Tensor Forward(int flag, const Tensor &input) override;
@@ -40,14 +44,13 @@ class Flatten : public Layer {
 
   const int Axis() const { return axis_; }
   const Shape input_shape() const { return input_shape_; }
-  const Shape output_shape() const { return output_shape_; }
 
  protected:
   /// flatten layer reshape the input to 2D, one from 0 to axis_-1, one from
   /// axis_ to end.
   /// if axis_ is 0, reshape the input to 1D.
   int axis_;
-  Shape input_shape_, output_shape_;
+  Shape input_shape_, out_sample_shape_;
 };
 }      // namespace singa
 #endif // SRC_MODEL_LAYER_FLATTEN_H_

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/74f02143/src/model/layer/lrn.cc
----------------------------------------------------------------------
diff --git a/src/model/layer/lrn.cc b/src/model/layer/lrn.cc
index 55135f1..2694fe9 100644
--- a/src/model/layer/lrn.cc
+++ b/src/model/layer/lrn.cc
@@ -21,8 +21,9 @@
 #include "lrn.h"
 
 namespace singa{
-void LRN::Setup(const LayerConf& conf) {
-  Layer::Setup(conf);
+void LRN::Setup(const Shape& in_sample, const LayerConf& conf) {
+  Layer::Setup(in_sample, conf);
+  out_sample_shape_ = in_sample;
   local_size_ = conf.lrn_conf().local_size();
   CHECK_EQ(local_size_ % 2, 1) << "LRN only supports odd values for Localvol";
   k_ = conf.lrn_conf().k();

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/74f02143/src/model/layer/lrn.h
----------------------------------------------------------------------
diff --git a/src/model/layer/lrn.h b/src/model/layer/lrn.h
index 118d062..a165d12 100644
--- a/src/model/layer/lrn.h
+++ b/src/model/layer/lrn.h
@@ -31,8 +31,12 @@ class LRN : public Layer {
     return "LRN";
   }
 
-  /// \copydoc Layer::Setup(const LayerConf&)
-  void Setup(const LayerConf& conf) override;
+  /// \copydoc Layer::Setup(const LayerConf&);
+  void Setup(const Shape& in_sample, const LayerConf& conf) override;
+  const Shape GetOutputSampleShape() {
+    CHECK(out_sample_shape_.size()) << "You may not have called Setup()";
+    return out_sample_shape_;
+  }
 
   /**
    * Local Response Normalization edge
@@ -62,9 +66,10 @@ class LRN : public Layer {
   float alpha_, beta_, k_;
   // store intermediate data, i.e., input tensor
   std::stack<Tensor> buf_;
-  
+  Shape out_sample_shape_;
+
 }; // class LRN
-} // namespace 
+} // namespace
 
 #endif  // SINGA_MODEL_LAYER_LRN_H_
 

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/74f02143/src/model/layer/pooling.cc
----------------------------------------------------------------------
diff --git a/src/model/layer/pooling.cc b/src/model/layer/pooling.cc
index 2655369..094177a 100644
--- a/src/model/layer/pooling.cc
+++ b/src/model/layer/pooling.cc
@@ -20,9 +20,8 @@
 #include "singa/model/layer.h"
 namespace singa {
 
-void Pooling::Setup(const LayerConf& conf) {
-  Layer::Setup(conf);
-
+void Pooling::Setup(const Shape& in_sample, const LayerConf& conf) {
+  Layer::Setup(in_sample, conf);
   PoolingConf pool_conf = conf.pooling_conf();
   if (pool_conf.has_kernel_size()) {
     kernel_w_ = kernel_h_ = pool_conf.kernel_size();
@@ -57,13 +56,15 @@ void Pooling::Setup(const LayerConf& conf) {
         pool_ == PoolingConf_PoolMethod_STOCHASTIC)
       << "Padding implemented only for average and max pooling.";
 
-  channels_ = pool_conf.channels();
-  height_ = pool_conf.height();
-  width_ = pool_conf.width();
+  CHECK_EQ(in_sample.size(), 3u);
+  channels_ = in_sample.at(0);
+  height_ = in_sample.at(1);
+  width_ = in_sample.at(2);
   pooled_height_ =
       static_cast<size_t>((height_ + 2 * pad_h_ - kernel_h_) / stride_h_) + 1;
   pooled_width_ =
       static_cast<size_t>((width_ + 2 * pad_w_ - kernel_w_) / stride_w_) + 1;
+  out_sample_shape_ = vector<size_t>{channels_, pooled_height_, pooled_width_};
 }
 
 const Tensor Pooling::Forward(int flag, const Tensor& input) {

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/74f02143/src/model/layer/pooling.h
----------------------------------------------------------------------
diff --git a/src/model/layer/pooling.h b/src/model/layer/pooling.h
index 522b603..ddee45b 100644
--- a/src/model/layer/pooling.h
+++ b/src/model/layer/pooling.h
@@ -30,8 +30,11 @@ class Pooling : public Layer {
   const std::string layer_type() const override { return "Pooling"; }
 
   /// \copydoc Layer::Setup(const LayerConf&);
-  void Setup(const LayerConf& conf) override;
-
+  void Setup(const Shape& in_sample, const LayerConf& conf) override;
+  const Shape GetOutputSampleShape() {
+    CHECK(out_sample_shape_.size()) << "You may not have called Setup()";
+    return out_sample_shape_;
+  }
   /// \copydoc Layer::Forward(int flag, const Tensor&)
   const Tensor Forward(int flag, const Tensor& input) override;
 
@@ -57,6 +60,7 @@ class Pooling : public Layer {
   PoolingConf_PoolMethod pool_;
   // To store the input and output(of forward) tensors
   std::stack<Tensor> buf_;
+  Shape out_sample_shape_;
 };
 }  // namespace singa
 #endif  // SRC_MODEL_LAYER_POOLING_H_

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/74f02143/src/model/layer/prelu.cc
----------------------------------------------------------------------
diff --git a/src/model/layer/prelu.cc b/src/model/layer/prelu.cc
index 6766d75..6c58dbb 100644
--- a/src/model/layer/prelu.cc
+++ b/src/model/layer/prelu.cc
@@ -20,8 +20,9 @@
 #include "./prelu.h"
 namespace singa {
 
-void PReLU::Setup(const LayerConf &conf) {
-  Layer::Setup(conf);
+void PReLU::Setup(const Shape& in_sample, const LayerConf &conf) {
+  Layer::Setup(in_sample, conf);
+  out_sample_shape_ = in_sample;
   channel_shared_ = conf.prelu_conf().channel_shared();
   format_ = conf.prelu_conf().format();
   // Push back params into param_values_

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/74f02143/src/model/layer/prelu.h
----------------------------------------------------------------------
diff --git a/src/model/layer/prelu.h b/src/model/layer/prelu.h
index 2ee5e9f..7387bfb 100644
--- a/src/model/layer/prelu.h
+++ b/src/model/layer/prelu.h
@@ -29,8 +29,13 @@ class PReLU : public Layer {
   /// \copydoc Layer::layer_type()
   const std::string layer_type() const override { return "PReLU"; }
 
+
   /// \copydoc Layer::Setup(const LayerConf&);
-  void Setup(const LayerConf &conf) override;
+  void Setup(const Shape& in_sample, const LayerConf& conf) override;
+  const Shape GetOutputSampleShape() {
+    CHECK(out_sample_shape_.size()) << "You may not have called Setup()";
+    return out_sample_shape_;
+  }
 
   /// \copydoc Layer::Forward(int flag, const Tensor&)
   const Tensor Forward(int flag, const Tensor &input) override;
@@ -55,6 +60,7 @@ class PReLU : public Layer {
   std::string format_;  // format_ has two valid value, i.e. NCHW, NHWC
   Tensor a_;            // shape of a_ is 2D, i.e. (channels, 1)
   std::stack<Tensor> buf_;
+  Shape out_sample_shape_;
 };
 }  // namespace singa
 #endif  // SINGA_MODEL_LAYER_PRELU_H_

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/74f02143/src/model/layer/softmax.cc
----------------------------------------------------------------------
diff --git a/src/model/layer/softmax.cc b/src/model/layer/softmax.cc
index f554f25..25bb9fe 100644
--- a/src/model/layer/softmax.cc
+++ b/src/model/layer/softmax.cc
@@ -19,20 +19,22 @@
 #include "./softmax.h"
 namespace singa {
 
-void Softmax::Setup(const LayerConf& conf) {
-  Layer::Setup(conf);
-  axis_ = conf.softmax_conf().axis();  // default is 1
+void Softmax::Setup(const Shape& in_sample, const LayerConf& conf) {
+  Layer::Setup(in_sample, conf);
+  // TODO(wangwei) disable axis, use a flatten layer to reshape the tensor.
+  // axis_ = conf.softmax_conf().axis();  // default is 1
+  CHECK_EQ(in_sample.size(), 1u);
+  out_sample_shape_ = in_sample;
 }
 
 const Tensor Softmax::Forward(int flag, const Tensor& input) {
-  Tensor output;
-  if (input.nDim() == 1) {
-    output = SoftMax(input);
-  } else {
-    size_t nrow = Product(input.shape(), 0, axis_);
-    const Tensor& tmp = Reshape(input, Shape{nrow, input.Size() / nrow});
-    output = SoftMax(tmp);
-  }
+  CHECK_LE(input.nDim(), 2u);
+  Tensor output =  SoftMax(input);
+  /*
+  size_t nrow = Product(input.shape(), 0, axis_);
+  const Tensor& tmp = Reshape(input, Shape{nrow, input.Size() / nrow});
+  output = SoftMax(tmp);
+  */
   if (flag & kTrain)
     buf_.push(output);
   return output;
@@ -40,13 +42,15 @@ const Tensor Softmax::Forward(int flag, const Tensor& input) {
 
 const std::pair<Tensor, vector<Tensor>> Softmax::Backward(int flag,
                                                           const Tensor& grad) {
+  CHECK_LE(grad.nDim(), 2u);
   size_t nrow = 1, ncol = grad.Size();
-  if (grad.nDim() > 1 && axis_ > 0) {
-    nrow = Product(grad.shape(), 0, axis_);
-    ncol = Product(grad.shape(), axis_, grad.nDim());
-  }
   Tensor input_grad = grad.Clone();
-  input_grad.Reshape(Shape{nrow, ncol});
+  if (grad.nDim() > 1) {
+    nrow = grad.shape(0);
+    ncol = grad.shape(1);
+  } else {
+    input_grad.Reshape({nrow, ncol});
+  }
   CHECK(!buf_.empty());
   Tensor y = buf_.top();
   buf_.pop();

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/74f02143/src/model/layer/softmax.h
----------------------------------------------------------------------
diff --git a/src/model/layer/softmax.h b/src/model/layer/softmax.h
index ea3a70a..fed544e 100644
--- a/src/model/layer/softmax.h
+++ b/src/model/layer/softmax.h
@@ -26,7 +26,11 @@ class Softmax : public Layer {
   const std::string layer_type() const override { return "Softmax"; }
 
   /// \copydoc Layer::Setup(const LayerConf&);
-  void Setup(const LayerConf& conf) override;
+  void Setup(const Shape& in_sample, const LayerConf& conf) override;
+  const Shape GetOutputSampleShape() {
+    CHECK(out_sample_shape_.size()) << "You may not have called Setup()";
+    return out_sample_shape_;
+  }
 
   /// \copydoc Layer::Forward(int flag, const Tensor&)
   const Tensor Forward(int flag, const Tensor& input) override;
@@ -40,6 +44,7 @@ class Softmax : public Layer {
  protected:
   int axis_;
   std::stack<Tensor> buf_;
+  Shape out_sample_shape_;
 };
 }  // namespace singa
 #endif  // SINGA_MODEL_LAYER_SOFTMAX_H_

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/74f02143/src/proto/model.proto
----------------------------------------------------------------------
diff --git a/src/proto/model.proto b/src/proto/model.proto
index c49f767..e9746c1 100644
--- a/src/proto/model.proto
+++ b/src/proto/model.proto
@@ -384,10 +384,6 @@ message ConvolutionConf {
   // cudnn algorithm preference
   // options: "fastest", "limited_workspace", "no_workspace"
   optional string prefer = 51 [default = "fastest"];
-  // input shape
-  optional int32 channels = 52;
-  optional int32 height = 53;
-  optional int32 width = 54;
 }
 
 /*
@@ -590,7 +586,6 @@ message DenseConf {
   // May be negative to index from the end (e.g., -1 for the last axis).
   optional int32 axis = 5 [default = 1];
 
-  optional uint32 num_input = 20; // The number of inputs for the layer
   optional bool transpose = 21 [default = false]; // whether transpose or not
 }
 
@@ -664,10 +659,6 @@ message PoolingConf {
   // If global_pooling then it will pool over the size of the bottom by doing
   // kernel_h = bottom->height and kernel_w = bottom->width
   optional bool global_pooling = 12 [default = false];
-  // Shape of source
-  optional int32 channels = 50;
-  optional int32 height = 51;
-  optional int32 width = 52;
   // whether to propagate nan
   optional bool nan_prop = 53 [default = false];
 }
@@ -837,7 +828,7 @@ message SoftmaxConf {
   // The axis along which to perform the softmax -- may be negative to index
   // from the end (e.g., -1 for the last axis).
   // Any other axes will be evaluated as independent softmaxes.
-  optional int32 axis = 2 [default = 1];
+  // optional int32 axis = 2 [default = 1];
 }
 
 message TanHConf {
@@ -930,8 +921,4 @@ message BatchNormConf {
   // Used in the moving average computation runningMean =
   // newMean*factor + runningMean*(1-factor).
   optional double factor = 1 [default = 0.9];
-  // input shape
-  optional int32 channels = 2;
-  optional int32 height = 3;
-  optional int32 width = 4;
 }

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/74f02143/test/singa/test_activation.cc
----------------------------------------------------------------------
diff --git a/test/singa/test_activation.cc b/test/singa/test_activation.cc
index 2d88121..504d599 100644
--- a/test/singa/test_activation.cc
+++ b/test/singa/test_activation.cc
@@ -24,6 +24,7 @@
 #include <math.h> // exp, tanh
 
 using singa::Activation;
+using singa::Shape;
 TEST(Activation, Setup) {
   Activation acti;
   EXPECT_EQ("Activation", acti.layer_type());
@@ -33,7 +34,7 @@ TEST(Activation, Setup) {
   singa::ReLUConf* reluconf = conf.mutable_relu_conf();
   reluconf->set_negative_slope(0.5);
 
-  acti.Setup(conf);
+  acti.Setup(Shape{3}, conf);
   EXPECT_EQ("RELU", acti.Mode());
   EXPECT_EQ(0.5f, acti.Negative_slope());
 }
@@ -55,7 +56,7 @@ TEST(Activation, Forward) {
       singa::ReLUConf* reluconf = conf.mutable_relu_conf();
       reluconf->set_negative_slope(neg_slope);
     }
-    acti.Setup(conf);
+    acti.Setup(Shape{n}, conf);
 
     singa::Tensor out = acti.Forward(singa::kTrain, in);
 
@@ -100,7 +101,7 @@ TEST(Activation, Backward) {
       singa::ReLUConf* reluconf = conf.mutable_relu_conf();
       reluconf->set_negative_slope(neg_slope);
     }
-    acti.Setup(conf);
+    acti.Setup(Shape{n}, conf);
 
     singa::Tensor out = acti.Forward(singa::kTrain, in);
     const float* yptr = out.data<const float*>();

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/74f02143/test/singa/test_cudnn_activation.cc
----------------------------------------------------------------------
diff --git a/test/singa/test_cudnn_activation.cc b/test/singa/test_cudnn_activation.cc
index 892b80b..0dac497 100644
--- a/test/singa/test_cudnn_activation.cc
+++ b/test/singa/test_cudnn_activation.cc
@@ -28,6 +28,7 @@
 #include <cudnn.h>
 
 using singa::CudnnActivation;
+using singa::Shape;
 TEST(TCudnnActivation, Setup) {
   CudnnActivation acti;
   EXPECT_EQ("CudnnActivation", acti.layer_type());
@@ -37,7 +38,7 @@ TEST(TCudnnActivation, Setup) {
   singa::ReLUConf* reluconf = conf.mutable_relu_conf();
   reluconf->set_negative_slope(0.5f);
 
-  acti.Setup(conf);
+  acti.Setup(Shape{3}, conf);
   acti.InitCudnn(1, singa::kFloat32);
   EXPECT_EQ(CUDNN_ACTIVATION_RELU, acti.CudnnMode());
   EXPECT_EQ(0.5f, acti.Negative_slope());
@@ -61,7 +62,7 @@ TEST(TCudnnActivation, Forward) {
       singa::ReLUConf* reluconf = conf.mutable_relu_conf();
       reluconf->set_negative_slope(neg_slope);
     }
-    acti.Setup(conf);
+    acti.Setup(Shape{n}, conf);
     // acti.InitCudnn(n, singa::kFloat32);
 
     singa::Tensor out = acti.Forward(singa::kTrain, in);
@@ -101,7 +102,7 @@ TEST(TCudnnActivation, Backward) {
       singa::ReLUConf* reluconf = conf.mutable_relu_conf();
       reluconf->set_negative_slope(neg_slope);
     }
-    acti.Setup(conf);
+    acti.Setup(Shape{n}, conf);
     acti.InitCudnn(n, singa::kFloat32);
     singa::Tensor out = acti.Forward(singa::kTrain, in);
     EXPECT_EQ(n, out.Size());

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/74f02143/test/singa/test_cudnn_batchnorm.cc
----------------------------------------------------------------------
diff --git a/test/singa/test_cudnn_batchnorm.cc b/test/singa/test_cudnn_batchnorm.cc
index d38fdaa..ba090cb 100644
--- a/test/singa/test_cudnn_batchnorm.cc
+++ b/test/singa/test_cudnn_batchnorm.cc
@@ -25,7 +25,7 @@
 #include "gtest/gtest.h"
 
 using singa::CudnnBatchNorm;
-
+using singa::Shape;
 TEST(CudnnBatchNorm, Setup) {
   CudnnBatchNorm batchnorm;
   EXPECT_EQ("CudnnBatchNorm", batchnorm.layer_type());
@@ -33,10 +33,7 @@ TEST(CudnnBatchNorm, Setup) {
   singa::LayerConf conf;
   singa::BatchNormConf *batchnorm_conf = conf.mutable_batchnorm_conf();
   batchnorm_conf->set_factor(0.01);
-  batchnorm_conf->set_channels(2);
-  batchnorm_conf->set_height(4);
-  batchnorm_conf->set_width(4);
-  batchnorm.Setup(conf);
+  batchnorm.Setup(Shape{2, 4, 4}, conf);
 
   EXPECT_FLOAT_EQ(0.01, batchnorm.factor());
   EXPECT_EQ(2u, batchnorm.channels());
@@ -70,10 +67,7 @@ TEST(CudnnBatchNorm, Forward) {
   singa::LayerConf conf;
   singa::BatchNormConf *batchnorm_conf = conf.mutable_batchnorm_conf();
   batchnorm_conf->set_factor(0.9);
-  batchnorm_conf->set_channels(2);
-  batchnorm_conf->set_height(4);
-  batchnorm_conf->set_width(4);
-  batchnorm.Setup(conf);
+  batchnorm.Setup(Shape{2, 4, 4}, conf);
 
   batchnorm.ToDevice(&cuda);
   batchnorm.set_bnScale(alpha);
@@ -143,10 +137,7 @@ TEST(CudnnBatchNorm, Backward) {
   singa::LayerConf conf;
   singa::BatchNormConf *batchnorm_conf = conf.mutable_batchnorm_conf();
   batchnorm_conf->set_factor(1);
-  batchnorm_conf->set_channels(2);
-  batchnorm_conf->set_height(4);
-  batchnorm_conf->set_width(4);
-  batchnorm.Setup(conf);
+  batchnorm.Setup(Shape{2, 4, 4}, conf);
 
   const float dy[] = {
     -0.0064714, 0, 0, 0,

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/74f02143/test/singa/test_cudnn_convolution.cc
----------------------------------------------------------------------
diff --git a/test/singa/test_cudnn_convolution.cc b/test/singa/test_cudnn_convolution.cc
index 2a17da2..3aa70dd 100644
--- a/test/singa/test_cudnn_convolution.cc
+++ b/test/singa/test_cudnn_convolution.cc
@@ -24,6 +24,7 @@
 #include "gtest/gtest.h"
 
 using singa::CudnnConvolution;
+using singa::Shape;
 TEST(CudnnConvolution, Setup) {
   CudnnConvolution conv;
   EXPECT_EQ("CudnnConvolution", conv.layer_type());
@@ -41,10 +42,7 @@ TEST(CudnnConvolution, Setup) {
   // MB
   convconf->set_workspace_byte_limit(256);
   convconf->set_prefer("fastest");
-  convconf->set_channels(1);
-  convconf->set_height(3);
-  convconf->set_width(3);
-  conv.Setup(conf);
+  conv.Setup(Shape{1, 3, 3}, conf);
 
   EXPECT_EQ(2u, conv.kernel_h());
   EXPECT_EQ(2u, conv.kernel_w());
@@ -95,10 +93,7 @@ TEST(CudnnConvolution, Forward) {
   // MB
   convconf->set_workspace_byte_limit(256);
   convconf->set_prefer("fastest");
-  convconf->set_channels(1);
-  convconf->set_height(3);
-  convconf->set_width(3);
-  conv.Setup(conf);
+  conv.Setup(Shape{1, 3, 3}, conf);
 
   // Parameter "flag" does not influence convolution
   singa::Tensor out1 = conv.Forward(singa::kTrain, in);
@@ -149,10 +144,7 @@ TEST(CudnnConvolution, Backward) {
   convconf->set_bias_term(true);
   convconf->set_workspace_byte_limit(256);
   convconf->set_prefer("fastest");
-  convconf->set_channels(1);
-  convconf->set_height(3);
-  convconf->set_width(3);
-  conv.Setup(conf);
+  conv.Setup(Shape{1, 3, 3}, conf);
 
   // Parameter "flag" does not influence convolution
   singa::Tensor out1 = conv.Forward(singa::kTrain, in);
@@ -222,10 +214,7 @@ TEST(CudnnConvolution_AT, Setup) {
   // MB
   convconf->set_workspace_byte_limit(256);
   convconf->set_prefer("autotune");
-  convconf->set_channels(1);
-  convconf->set_height(3);
-  convconf->set_width(3);
-  conv.Setup(conf);
+  conv.Setup(Shape{1, 3, 3}, conf);
 
   EXPECT_EQ(2u, conv.kernel_h());
   EXPECT_EQ(2u, conv.kernel_w());
@@ -276,10 +265,7 @@ TEST(CudnnConvolution_AT, Forward) {
   // MB
   convconf->set_workspace_byte_limit(256);
   convconf->set_prefer("autotune");
-  convconf->set_channels(1);
-  convconf->set_height(3);
-  convconf->set_width(3);
-  conv.Setup(conf);
+  conv.Setup(Shape{1, 3, 3}, conf);
 
   // Parameter "flag" does not influence convolution
   singa::Tensor out1 = conv.Forward(singa::kTrain, in);
@@ -330,10 +316,7 @@ TEST(CudnnConvolution_AT, Backward) {
   convconf->set_bias_term(true);
   convconf->set_workspace_byte_limit(256);
   convconf->set_prefer("autotune");
-  convconf->set_channels(1);
-  convconf->set_height(3);
-  convconf->set_width(3);
-  conv.Setup(conf);
+  conv.Setup(Shape{1, 3, 3}, conf);
 
   // Parameter "flag" does not influence convolution
   singa::Tensor out1 = conv.Forward(singa::kTrain, in);

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/74f02143/test/singa/test_cudnn_dropout.cc
----------------------------------------------------------------------
diff --git a/test/singa/test_cudnn_dropout.cc b/test/singa/test_cudnn_dropout.cc
index 32572d0..b8ce068 100644
--- a/test/singa/test_cudnn_dropout.cc
+++ b/test/singa/test_cudnn_dropout.cc
@@ -33,6 +33,7 @@ bool inline GetBitValue(const char* x, int pos) {
 }
 
 using singa::CudnnDropout;
+using singa::Shape;
 TEST(CudnnDropout, Setup) {
   CudnnDropout drop;
   EXPECT_EQ("CudnnDropout", drop.layer_type());
@@ -41,7 +42,7 @@ TEST(CudnnDropout, Setup) {
   singa::DropoutConf* dropconf = conf.mutable_dropout_conf();
   dropconf->set_dropout_ratio(0.8);
 
-  drop.Setup(conf);
+  drop.Setup(Shape{1}, conf);
   EXPECT_EQ(0.8f, drop.dropout_ratio());
 }
 
@@ -57,7 +58,7 @@ TEST(CudnnDropout, Forward) {
   singa::LayerConf conf;
   singa::DropoutConf* dropconf = conf.mutable_dropout_conf();
   dropconf->set_dropout_ratio(pdrop);
-  drop.Setup(conf);
+  drop.Setup(Shape{1}, conf);
 
   singa::Tensor out1 = drop.Forward(singa::kTrain, in);
 
@@ -101,7 +102,7 @@ TEST(CudnnDropout, Backward) {
   singa::LayerConf conf;
   singa::DropoutConf* dropconf = conf.mutable_dropout_conf();
   dropconf->set_dropout_ratio(pdrop);
-  drop.Setup(conf);
+  drop.Setup(Shape{1}, conf);
   singa::Tensor out1 = drop.Forward(singa::kTrain, in);
 
   const float dy[] = {4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 1.0f, 2.0f, 3.0f};

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/74f02143/test/singa/test_cudnn_lrn.cc
----------------------------------------------------------------------
diff --git a/test/singa/test_cudnn_lrn.cc b/test/singa/test_cudnn_lrn.cc
index 390c588..8576943 100644
--- a/test/singa/test_cudnn_lrn.cc
+++ b/test/singa/test_cudnn_lrn.cc
@@ -27,7 +27,7 @@
 #include "gtest/gtest.h"
 
 using singa::CudnnLRN;
-
+using singa::Shape;
 TEST(CudnnLRN, Setup) {
   CudnnLRN lrn;
   EXPECT_EQ("CudnnLRN", lrn.layer_type());
@@ -38,7 +38,7 @@ TEST(CudnnLRN, Setup) {
   lrn_conf->set_local_size(3);
   lrn_conf->set_alpha(0.1);
   lrn_conf->set_beta(0.75);
-  lrn.Setup(conf);
+  lrn.Setup(Shape{1}, conf);
 
   EXPECT_FLOAT_EQ(1.0, lrn.k());
   EXPECT_EQ(3, lrn.local_size());
@@ -68,7 +68,7 @@ TEST(CudnnLRN, Forward) {
   lrn_conf->set_local_size(3);
   lrn_conf->set_alpha(0.1);
   lrn_conf->set_beta(0.75);
-  lrn.Setup(conf);
+  lrn.Setup(Shape{2, 4, 4}, conf);
 
   singa::Tensor out = lrn.Forward(singa::kTrain, in);
   singa::CppCPU host(0, 1);
@@ -152,7 +152,7 @@ TEST(CudnnLRN, Backward) {
   lrn_conf->set_local_size(3);
   lrn_conf->set_alpha(0.1);
   lrn_conf->set_beta(0.75);
-  lrn.Setup(conf);
+  lrn.Setup(Shape{2, 4, 4}, conf);
 
   lrn.Forward(singa::kTrain, x_tensor);
   const auto ret = lrn.Backward(singa::kTrain, dy_tensor);

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/74f02143/test/singa/test_cudnn_pooling.cc
----------------------------------------------------------------------
diff --git a/test/singa/test_cudnn_pooling.cc b/test/singa/test_cudnn_pooling.cc
index e66f212..c7f9061 100644
--- a/test/singa/test_cudnn_pooling.cc
+++ b/test/singa/test_cudnn_pooling.cc
@@ -24,6 +24,7 @@
 #include "gtest/gtest.h"
 
 using singa::CudnnPooling;
+using singa::Shape;
 TEST(CudnnPooling, Setup) {
   CudnnPooling pool;
   EXPECT_EQ("CudnnPooling", pool.layer_type());
@@ -37,10 +38,7 @@ TEST(CudnnPooling, Setup) {
   poolconf->set_pad_w(0);
   poolconf->set_stride_h(2);
   poolconf->set_stride_w(1);
-  poolconf->set_channels(1);
-  poolconf->set_height(3);
-  poolconf->set_width(3);
-  pool.Setup(conf);
+  pool.Setup(Shape{1, 3, 3}, conf);
 
   EXPECT_EQ(singa::PoolingConf_PoolMethod_MAX, pool.pool_method());
   EXPECT_EQ(1u, pool.kernel_h());
@@ -72,10 +70,7 @@ TEST(CudnnPooling, Forward) {
   poolconf->set_pad_w(0);
   poolconf->set_stride_h(1);
   poolconf->set_stride_w(1);
-  poolconf->set_channels(1);
-  poolconf->set_height(3);
-  poolconf->set_width(3);
-  pool.Setup(conf);
+  pool.Setup(Shape{1, 3, 3}, conf);
 
   // Parameter "flag" does not influence pooling
   singa::Tensor out1 = pool.Forward(singa::kTrain, in);
@@ -109,10 +104,7 @@ TEST(CudnnPooling, Backward) {
   poolconf->set_pad_w(0);
   poolconf->set_stride_h(1);
   poolconf->set_stride_w(1);
-  poolconf->set_channels(1);
-  poolconf->set_height(3);
-  poolconf->set_width(3);
-  pool.Setup(conf);
+  pool.Setup(Shape{1, 3, 3}, conf);
 
   singa::Tensor out1 = pool.Forward(singa::kTrain, in);
 

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/74f02143/test/singa/test_cudnn_softmax.cc
----------------------------------------------------------------------
diff --git a/test/singa/test_cudnn_softmax.cc b/test/singa/test_cudnn_softmax.cc
index 05783e2..d671ecf 100644
--- a/test/singa/test_cudnn_softmax.cc
+++ b/test/singa/test_cudnn_softmax.cc
@@ -26,18 +26,17 @@
 #include <math.h>  // exp
 #include <cudnn.h>
 
+// TODO(wangwei) add test for matrix input
 using singa::CudnnSoftmax;
+using singa::Shape;
 TEST(CudnnSoftmax, Setup) {
   CudnnSoftmax sft;
   EXPECT_EQ("CudnnSoftmax", sft.layer_type());
 
   singa::LayerConf conf;
-  singa::SoftmaxConf* softmaxconf = conf.mutable_softmax_conf();
-  softmaxconf->set_axis(2);
 
-  sft.Setup(conf);
+  sft.Setup(Shape{4}, conf);
   sft.InitCudnn(1, singa::kFloat32);
-  EXPECT_EQ(2, sft.Axis());
 }
 
 TEST(CudnnSoftmax, Forward) {
@@ -47,12 +46,9 @@ TEST(CudnnSoftmax, Forward) {
   singa::Tensor in(singa::Shape{n}, &cuda);
   in.CopyDataFromHostPtr<float>(x, n);
 
-  int axis = 1;
   CudnnSoftmax sft;
   singa::LayerConf conf;
-  singa::SoftmaxConf* softmaxconf = conf.mutable_softmax_conf();
-  softmaxconf->set_axis(axis);
-  sft.Setup(conf);
+  sft.Setup(Shape{1}, conf);
   sft.InitCudnn(n, singa::kFloat32);
 
   singa::Tensor out = sft.Forward(singa::kTrain, in);
@@ -77,12 +73,9 @@ TEST(CudnnSoftmax, Backward) {
   singa::Tensor in(singa::Shape{n}, &cuda);
   in.CopyDataFromHostPtr<float>(x, n);
 
-  int axis = 1;
   CudnnSoftmax sft;
   singa::LayerConf conf;
-  singa::SoftmaxConf* softmaxconf = conf.mutable_softmax_conf();
-  softmaxconf->set_axis(axis);
-  sft.Setup(conf);
+  sft.Setup(Shape{1}, conf);
   singa::Tensor out = sft.Forward(singa::kTrain, in);
   singa::CppCPU host(0, 1);
   out.ToDevice(&host);

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/74f02143/test/singa/test_dense.cc
----------------------------------------------------------------------
diff --git a/test/singa/test_dense.cc b/test/singa/test_dense.cc
index 5050d7e..6f5518e 100644
--- a/test/singa/test_dense.cc
+++ b/test/singa/test_dense.cc
@@ -23,16 +23,16 @@
 #include "singa_config.h"
 
 using singa::Dense;
+using singa::Shape;
 TEST(Dense, Setup) {
   Dense dense;
   EXPECT_EQ("Dense", dense.layer_type());
 
   singa::LayerConf conf;
   singa::DenseConf *denseconf = conf.mutable_dense_conf();
-  denseconf->set_num_input(2);
   denseconf->set_num_output(3);
   denseconf->set_transpose(false);
-  dense.Setup(conf);
+  dense.Setup(Shape{2}, conf);
 
   EXPECT_EQ(3u, dense.num_output());
   EXPECT_EQ(2u, dense.num_input());
@@ -43,10 +43,9 @@ TEST(Dense, ForwardCpp) {
 
   singa::LayerConf conf;
   singa::DenseConf *denseconf = conf.mutable_dense_conf();
-  denseconf->set_num_input(2);
   denseconf->set_num_output(3);
   denseconf->set_transpose(false);
-  dense.Setup(conf);
+  dense.Setup(Shape{2}, conf);
 
   const size_t batchsize = 3, vdim = 2, hdim = 3;
   const float x[batchsize * vdim] = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f};
@@ -82,10 +81,9 @@ TEST(Dense, BackwardCpp) {
 
   singa::LayerConf conf;
   singa::DenseConf *denseconf = conf.mutable_dense_conf();
-  denseconf->set_num_input(2);
   denseconf->set_num_output(3);
   denseconf->set_transpose(false);
-  dense.Setup(conf);
+  dense.Setup(Shape{2}, conf);
 
   const size_t batchsize = 3, vdim = 2, hdim = 3;
   const float x[batchsize * vdim] = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f};
@@ -147,10 +145,9 @@ TEST(Dense, ForwardCuda) {
 
   singa::LayerConf conf;
   singa::DenseConf *denseconf = conf.mutable_dense_conf();
-  denseconf->set_num_input(2);
   denseconf->set_num_output(3);
   denseconf->set_transpose(false);
-  dense.Setup(conf);
+  dense.Setup(Shape{2}, conf);
 
   const size_t batchsize = 3, vdim = 2, hdim = 3;
   const float x[batchsize * vdim] = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f};
@@ -186,10 +183,9 @@ TEST(Dense, BackwardCuda) {
 
   singa::LayerConf conf;
   singa::DenseConf *denseconf = conf.mutable_dense_conf();
-  denseconf->set_num_input(2);
   denseconf->set_num_output(3);
   denseconf->set_transpose(false);
-  dense.Setup(conf);
+  dense.Setup(Shape{2}, conf);
 
   const size_t batchsize = 3, vdim = 2, hdim = 3;
   const float x[batchsize * vdim] = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f};

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/74f02143/test/singa/test_dropout.cc
----------------------------------------------------------------------
diff --git a/test/singa/test_dropout.cc b/test/singa/test_dropout.cc
index d648ff8..b6ae9c6 100644
--- a/test/singa/test_dropout.cc
+++ b/test/singa/test_dropout.cc
@@ -23,6 +23,7 @@
 #include "gtest/gtest.h"
 
 using singa::Dropout;
+using singa::Shape;
 TEST(Dropout, Setup) {
   Dropout drop;
   EXPECT_EQ("Dropout", drop.layer_type());
@@ -31,7 +32,7 @@ TEST(Dropout, Setup) {
   singa::DropoutConf* dropconf = conf.mutable_dropout_conf();
   dropconf->set_dropout_ratio(0.8);
 
-  drop.Setup(conf);
+  drop.Setup(Shape{3}, conf);
   EXPECT_EQ(0.8f, drop.dropout_ratio());
 }
 
@@ -46,7 +47,7 @@ TEST(Dropout, Forward) {
   singa::LayerConf conf;
   singa::DropoutConf* dropconf = conf.mutable_dropout_conf();
   dropconf->set_dropout_ratio(pdrop);
-  drop.Setup(conf);
+  drop.Setup(Shape{1}, conf);
   float scale = 1.0f / (1.0f - pdrop);
 
   singa::Tensor out1 = drop.Forward(singa::kTrain, in);
@@ -84,7 +85,7 @@ TEST(Dropout, Backward) {
   singa::LayerConf conf;
   singa::DropoutConf* dropconf = conf.mutable_dropout_conf();
   dropconf->set_dropout_ratio(pdrop);
-  drop.Setup(conf);
+  drop.Setup(Shape{1}, conf);
   singa::Tensor out1 = drop.Forward(singa::kTrain, in);
 
   const float dy[] = {4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 1.0f, 2.0f, 3.0f};

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/74f02143/test/singa/test_flatten.cc
----------------------------------------------------------------------
diff --git a/test/singa/test_flatten.cc b/test/singa/test_flatten.cc
index 0ba8d3c..f139a75 100644
--- a/test/singa/test_flatten.cc
+++ b/test/singa/test_flatten.cc
@@ -23,6 +23,7 @@
 #include "gtest/gtest.h"
 
 using singa::Flatten;
+using singa::Shape;
 TEST(Flatten, Setup) {
   Flatten flt;
   EXPECT_EQ("Flatten", flt.layer_type());
@@ -31,7 +32,7 @@ TEST(Flatten, Setup) {
   singa::FlattenConf *flattenconf = conf.mutable_flatten_conf();
   flattenconf->set_axis(1);
 
-  flt.Setup(conf);
+  flt.Setup(Shape{2}, conf);
   EXPECT_EQ(1, flt.Axis());
 }
 
@@ -48,7 +49,7 @@ TEST(Flatten, ForwardCPU) {
   singa::LayerConf conf;
   singa::FlattenConf *flattenconf = conf.mutable_flatten_conf();
   flattenconf->set_axis(axis);
-  flt.Setup(conf);
+  flt.Setup(Shape{1, 3, 2}, conf);
 
   singa::Tensor out = flt.Forward(singa::kTrain, in);
   EXPECT_EQ(n, out.Size());
@@ -72,7 +73,7 @@ TEST(Flatten, BackwardCPU) {
   singa::LayerConf conf;
   singa::FlattenConf *flattenconf = conf.mutable_flatten_conf();
   flattenconf->set_axis(axis);
-  flt.Setup(conf);
+  flt.Setup(Shape{1, 3, 2}, conf);
 
   singa::Tensor temp = flt.Forward(singa::kTrain, in);
   const auto out = flt.Backward(singa::kTrain, temp);
@@ -99,7 +100,7 @@ TEST(Flatten, ForwardGPU) {
   singa::LayerConf conf;
   singa::FlattenConf *flattenconf = conf.mutable_flatten_conf();
   flattenconf->set_axis(axis);
-  flt.Setup(conf);
+  flt.Setup(Shape{1, 3, 2}, conf);
 
   singa::Tensor out = flt.Forward(singa::kTrain, in);
   singa::CppCPU host(0, 1);
@@ -126,7 +127,7 @@ TEST(Flatten, BackwardGPU) {
   singa::LayerConf conf;
   singa::FlattenConf *flattenconf = conf.mutable_flatten_conf();
   flattenconf->set_axis(axis);
-  flt.Setup(conf);
+  flt.Setup(Shape{1, 3, 2}, conf);
 
   singa::Tensor out = flt.Forward(singa::kTrain, in);
   const auto ret = flt.Backward(singa::kTrain, out);

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/74f02143/test/singa/test_prelu.cc
----------------------------------------------------------------------
diff --git a/test/singa/test_prelu.cc b/test/singa/test_prelu.cc
index 883935d..faff093 100644
--- a/test/singa/test_prelu.cc
+++ b/test/singa/test_prelu.cc
@@ -24,6 +24,7 @@
 #include "singa_config.h"
 
 using singa::PReLU;
+using singa::Shape;
 TEST(PReLU, Setup) {
   PReLU prelu;
   EXPECT_EQ("PReLU", prelu.layer_type());
@@ -33,7 +34,7 @@ TEST(PReLU, Setup) {
   preluconf->set_channel_shared(true);
   preluconf->set_format("NHWC");
 
-  prelu.Setup(conf);
+  prelu.Setup(Shape{4}, conf);
   EXPECT_EQ(true, prelu.Channel_shared());
   EXPECT_EQ("NHWC", prelu.Format());
 }
@@ -51,7 +52,7 @@ TEST(PReLU, ForwardCPU) {
   singa::PReLUConf *preluconf = conf.mutable_prelu_conf();
   preluconf->set_channel_shared(false);
   preluconf->set_format("NHWC");
-  prelu.Setup(conf);
+  prelu.Setup(Shape{h, w, c}, conf);
 
   const float neg_slope[] = {0.25f, 0.5f, 0.75f};
   singa::Tensor a(singa::Shape{c});
@@ -91,7 +92,7 @@ TEST(PReLU, BackwardCPU) {
   singa::PReLUConf *preluconf = conf.mutable_prelu_conf();
   preluconf->set_channel_shared(false);
   preluconf->set_format("NCHW");
-  prelu.Setup(conf);
+  prelu.Setup(Shape{c, h, w}, conf);
 
   const float neg_slope[] = {0.25f, 0.5f, 0.75f};
   singa::Tensor a(singa::Shape{c});
@@ -151,7 +152,7 @@ TEST(PReLU, ForwardGPU) {
   singa::PReLUConf *preluconf = conf.mutable_prelu_conf();
   preluconf->set_channel_shared(false);
   preluconf->set_format("NHWC");
-  prelu.Setup(conf);
+  prelu.Setup(Shape{h, w, c}, conf);
 
   const float neg_slope[] = {0.25f, 0.5f, 0.75f};
   singa::Tensor a(singa::Shape{c}, &cuda);
@@ -194,7 +195,7 @@ TEST(PReLU, BackwardGPU) {
   singa::PReLUConf *preluconf = conf.mutable_prelu_conf();
   preluconf->set_channel_shared(false);
   preluconf->set_format("NCHW");
-  prelu.Setup(conf);
+  prelu.Setup(Shape{c, h, w}, conf);
 
   const float neg_slope[] = {0.25f, 0.5f, 0.75f};
   singa::Tensor a(singa::Shape{c}, &cuda);

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/74f02143/test/singa/test_softmax.cc
----------------------------------------------------------------------
diff --git a/test/singa/test_softmax.cc b/test/singa/test_softmax.cc
index fff8510..2bf4505 100644
--- a/test/singa/test_softmax.cc
+++ b/test/singa/test_softmax.cc
@@ -24,16 +24,13 @@
 #include <math.h> // exp
 
 using singa::Softmax;
+using singa::Shape;
 TEST(Softmax, Setup) {
   Softmax sft;
   EXPECT_EQ("Softmax", sft.layer_type());
 
   singa::LayerConf conf;
-  singa::SoftmaxConf* softmaxconf = conf.mutable_softmax_conf();
-  softmaxconf->set_axis(2);
-
-  sft.Setup(conf);
-  EXPECT_EQ(2, sft.Axis());
+  sft.Setup(Shape{3}, conf);
 }
 
 #ifdef USE_CBLAS
@@ -45,12 +42,9 @@ TEST(Softmax, Forward) {
   singa::Tensor in(singa::Shape{row, col});
   in.CopyDataFromHostPtr<float>(x, row * col);
 
-  int axis = 1;
   Softmax sft;
   singa::LayerConf conf;
-  singa::SoftmaxConf* softmaxconf = conf.mutable_softmax_conf();
-  softmaxconf->set_axis(axis);
-  sft.Setup(conf);
+  sft.Setup(Shape{col}, conf);
 
   singa::Tensor out = sft.Forward(singa::kTrain, in);
   const float* yptr = out.data<const float*>();
@@ -76,12 +70,9 @@ TEST(Softmax, Backward) {
   singa::Tensor in(singa::Shape{row, col});
   in.CopyDataFromHostPtr<float>(x, n);
 
-  int axis = 1;
   Softmax sft;
   singa::LayerConf conf;
-  singa::SoftmaxConf* softmaxconf = conf.mutable_softmax_conf();
-  softmaxconf->set_axis(axis);
-  sft.Setup(conf);
+  sft.Setup(Shape{col}, conf);
   singa::Tensor out = sft.Forward(singa::kTrain, in);
   const float* yptr = out.data<const float*>();