You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@singa.apache.org by ka...@apache.org on 2016/12/15 03:26:30 UTC

incubator-singa git commit: SINGA-271 Add Concat and Slice layers

Repository: incubator-singa
Updated Branches:
  refs/heads/master 33cf5461b -> 82c12c2b9


SINGA-271 Add Concat and Slice layers

Fixed some bugs in the concat and slice layers:

1. the input shape (list) of the layer setup function should be the
shape of a single sample
2. the order of source layers of the concat layer should match the
order of the gradient tensors for the net::backward() function
3. the concat layer's forward maps a list of tensors -> a single tensor;
the slice layer's forward maps a single tensor -> a list of tensors; for backward, the input/output types are the reverse of those of the forward function.


Project: http://git-wip-us.apache.org/repos/asf/incubator-singa/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-singa/commit/82c12c2b
Tree: http://git-wip-us.apache.org/repos/asf/incubator-singa/tree/82c12c2b
Diff: http://git-wip-us.apache.org/repos/asf/incubator-singa/diff/82c12c2b

Branch: refs/heads/master
Commit: 82c12c2b9e81926cbd46f621fa7f8da98f2217e8
Parents: 33cf546
Author: Wei Wang <wa...@comp.nus.edu.sg>
Authored: Fri Dec 2 16:27:20 2016 +0800
Committer: Wei Wang <wa...@gmail.com>
Committed: Wed Dec 14 14:46:23 2016 +0800

----------------------------------------------------------------------
 python/singa/layer.py     | 126 +++++++++++++++++++++++++++++++++--------
 python/singa/net.py       |   5 +-
 src/model/layer/concat.cc |  34 ++++++-----
 src/model/layer/concat.h  |   5 +-
 src/model/layer/slice.cc  |  31 +++++-----
 src/model/layer/slice.h   |   2 +
 test/python/test_layer.py |  18 +++++-
 test/singa/test_concat.cc |  23 ++++++--
 test/singa/test_slice.cc  |  19 +++----
 9 files changed, 191 insertions(+), 72 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/82c12c2b/python/singa/layer.py
----------------------------------------------------------------------
diff --git a/python/singa/layer.py b/python/singa/layer.py
index a7a14c5..f82d2df 100644
--- a/python/singa/layer.py
+++ b/python/singa/layer.py
@@ -116,26 +116,6 @@ class Layer(object):
 
         self.has_setup = False
 
-    def caffe_layer(self):
-        '''
-        Create a singa layer based on caffe layer configuration.
-        '''
-        _check_engine(engine, ['cudnn', 'singacpp', 'singacuda'])
-        if self.conf.type == 'InnerProduct' or self.conf.type == 14:
-            self.layer = _create_layer(engine, 'Dense')
-        else:
-            self.layer = _create_layer(engine, str(self.conf.type))
-
-    def param_names(self):
-        '''
-        Returns:
-            a list of strings, one for the name of one parameter Tensor
-        '''
-        names = []
-        for x in self.param_specs:
-            names.append(x['name'])
-        return names
-
     def setup(self, in_shapes):
         '''Call the C++ setup function to create params and set some meta data.
 
@@ -154,6 +134,17 @@ class Layer(object):
             self.layer.Setup(list(in_shapes), self.conf.SerializeToString())
         self.has_setup = True
 
+    def caffe_layer(self):
+        '''
+        Create a singa layer based on caffe layer configuration.
+        '''
+        _check_engine(engine, ['cudnn', 'singacpp', 'singacuda'])
+        if self.conf.type == 'InnerProduct' or self.conf.type == 14:
+            self.layer = _create_layer(engine, 'Dense')
+        else:
+            self.layer = _create_layer(engine, str(self.conf.type))
+
+
     def get_output_sample_shape(self):
         '''Called after setup to get the shape of the output sample(s).
 
@@ -165,6 +156,16 @@ class Layer(object):
             'Must call setup() before get_output_sample_shape()'
         return self.layer.GetOutputSampleShape()
 
+    def param_names(self):
+        '''
+        Returns:
+            a list of strings, one for the name of one parameter Tensor
+        '''
+        names = []
+        for x in self.param_specs:
+            names.append(x['name'])
+        return names
+
     def param_values(self):
         '''Return param value tensors.
 
@@ -278,10 +279,12 @@ class Dummy(Layer):
         self.has_setup = True
 
     def forward(self, flag, x):
+        '''Return the input x'''
         return x
 
     def backward(self, falg, dy):
-        return dy
+        '''Return dy, []'''
+        return dy, []
 
 
 class Conv2D(Layer):
@@ -732,6 +735,7 @@ class Merge(Layer):
     def forward(self, flag, inputs):
         '''Merge all input tensors by summation.
 
+        TODO(wangwei) do element-wise merge operations, e.g., avg, count
         Args:
             flag: not used.
             inputs (list): a list of tensors
@@ -749,6 +753,14 @@ class Merge(Layer):
         return output
 
     def backward(self, flag, grad):
+        '''Replicate the grad for each input source layer.
+
+        Args:
+            grad(Tensor), the gradient tensor of the merged result from forward
+
+        Returns:
+            A list of replicated grad, one per source layer
+        '''
         assert isinstance(grad, tensor.Tensor), 'The input must be Tensor'
         return [grad] * self.num_input, []  # * self.num_input
 
@@ -789,6 +801,14 @@ class Split(Layer):
         return outputs
 
     def backward(self, flag, grads):
+        '''Sum all grad tensors to generate a single output tensor.
+
+        Args:
+            grads(list of Tensor), one per dest layer
+
+        Returns:
+            a single tensor as the sum of all grads
+        '''
         assert len(grads) > 1, 'There must be multiple gradients'
         dx = tensor.Tensor()
         dx.reset_like(grads[0])
@@ -805,7 +825,7 @@ class Concat(Layer):
 
     Args:
         axis(int): 0 for concat row; 1 for concat columns;
-        input_sample_shapes: a list of shape tuples, one per input tensor
+        input_sample_shapes: a list of sample shape tuples, one per input tensor
     '''
 
     def __init__(self, name, axis, input_sample_shapes=None):
@@ -820,6 +840,36 @@ class Concat(Layer):
         if input_sample_shapes is not None:
             self.setup(input_sample_shapes)
 
+    def forward(self, flag, inputs):
+        '''Concatenate all input tensors.
+
+        Args:
+            flag: same as Layer::forward()
+            input: a list of tensors
+
+        Returns:
+            a single concatenated tensor
+        '''
+        assert type(inputs) is list, 'Must be a list of Tensors'
+        ys = super(Concat, self).forward(flag, inputs)
+        return ys[0]
+
+
+    def backward(self, flag, dy):
+        '''Backward propagate gradients through this layer.
+
+        Args:
+            flag: same as Layer::backward()
+            dy(Tensor): the gradient tensors of y w.r.t objective loss
+        Return:
+            <dx, []>, dx is a list tensors for the gradient of the inputs; []
+               is an empty list.
+        '''
+        if type(dy) is tensor.Tensor:
+            dy = [dy]
+        assert type(dy) is list, 'Must be a list(Tensor)'
+        return super(Concat, self).backward(flag, dy)
+
 
 class Slice(Layer):
     '''Slice the input tensor into multiple sub-tensors vertially (axis=0) or
@@ -829,7 +879,7 @@ class Slice(Layer):
         axis (int): 0 for slice rows; 1 for slice columns;
         slice_point(list): positions along the axis to do slice; there are n-1
             points for n sub-tensors;
-        input_sample_shape: input tensor shape
+        input_sample_shape: input tensor sample shape
     '''
 
     def __init__(self, name, axis, slice_point, input_sample_shape=None):
@@ -850,6 +900,36 @@ class Slice(Layer):
         for i in range(len(self.conf.slice_conf.slice_point) + 1):
             out.append(self.layer.GetOutputSampleShape(i))
 
+    def forward(self, flag, x):
+        '''Slice the input tensor on the given axis.
+
+        Args:
+            flag: same as Layer::forward()
+            x: a single input tensor
+
+        Returns:
+            a list a output tensor
+        '''
+        if type(x) is tensor.Tensor:
+            x = [x]
+        assert type(x) is list, 'Must be a list of Tensor'
+        return super(Slice, self).forward(flag, x)
+
+    def backward(self, flag, grads):
+        '''Concate all grad tensors to generate a single output tensor
+
+        Args:
+            flag: same as Layer::backward()
+            grads: a list of tensors, one for the gradient of one sliced tensor
+
+        Returns:
+            a single tensor for the gradient of the original user, and an empty
+                list.
+        '''
+        assert len(grads) > 1, 'There must be multiple gradients'
+        dxs, _ = super(Slice, self).backward(flag, grads)
+        return dxs[0], []
+
 
 class RNN(Layer):
     '''Recurrent layer with 4 types of units, namely lstm, gru, tanh and relu.

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/82c12c2b/python/singa/net.py
----------------------------------------------------------------------
diff --git a/python/singa/net.py b/python/singa/net.py
index d34afbc..b6f973d 100644
--- a/python/singa/net.py
+++ b/python/singa/net.py
@@ -301,7 +301,10 @@ class FeedForwardNet(object):
                 grads = grads[0]
             outs, _pgrads = cur.backward(kTrain, grads)
             pgrads.append(_pgrads)
-            output_of_layer[cur.name] = outs
+            if type(outs) is list:
+                output_of_layer[cur.name] = reversed(outs)
+            else:
+                output_of_layer[cur.name] = outs
             grads = []
 
         ret = []

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/82c12c2b/src/model/layer/concat.cc
----------------------------------------------------------------------
diff --git a/src/model/layer/concat.cc b/src/model/layer/concat.cc
index b1c0b11..88c2409 100644
--- a/src/model/layer/concat.cc
+++ b/src/model/layer/concat.cc
@@ -27,19 +27,28 @@ RegisterLayerClass(singacl_concat, Concat);
 
 void Concat::Setup(const vector<Shape>& in_shapes, const LayerConf& conf) {
   Layer::Setup(in_shapes, conf);
-  dim_size_.clear();
+  out_sample_shape_.clear();
+  slice_point_.clear();
   axis_ = conf.concat_conf().axis();
-  out_sample_shape_ = {0, 0};
-  out_sample_shape_[1 - axis_] = in_shapes[0][1 - axis_];
-  for (auto& s: in_shapes) {
-    out_sample_shape_[axis_] += s[axis_];
-    dim_size_.push_back(s[axis_]);
-    // LOG(ERROR) << s[axis_];
+  if (axis_ == 0)
+    out_sample_shape_.push_back(in_shapes[0][0]);
+  else {
+    size_t l = 0;
+    for (auto& s: in_shapes) {
+       l += s[0];
+    }
+    out_sample_shape_.push_back(l);
   }
 }
 
 const vector<Tensor> Concat::Forward(int flag, const vector<Tensor>& inputs) {
   vector<Tensor> outputs;
+  slice_point_.clear();
+  size_t offset = 0;
+  for (auto& x : inputs) {
+    offset += x.shape(axis_);
+    slice_point_.push_back(offset);
+  }
   if (inputs.size() == 1u) {
     outputs = inputs;
   } else {
@@ -55,14 +64,13 @@ const std::pair<vector<Tensor>, vector<Tensor>> Concat::Backward(
     int flag, const vector<Tensor>& grads) {
   vector<Tensor> input_grad, param_grad;
   CHECK_EQ(grads.size(), 1u) << "Concat layer only have one output tensor.";
-  for (size_t i = 0, offset = 0; i < dim_size_.size(); i++) {
+  size_t last_offset = 0u;
+  for (auto p : slice_point_) {
     if (axis_ == 0)
-      input_grad.push_back(SliceRows(grads.at(0), offset,
-            offset + dim_size_[i]));
+      input_grad.push_back(SliceRows(grads.at(0), last_offset, p));
     else
-      input_grad.push_back(SliceColumns(grads.at(0), offset,
-            offset + dim_size_[i]));
-    offset += dim_size_[i];
+      input_grad.push_back(SliceColumns(grads.at(0), last_offset, p));
+    last_offset = p;
   }
   return std::make_pair(input_grad, param_grad);
 }

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/82c12c2b/src/model/layer/concat.h
----------------------------------------------------------------------
diff --git a/src/model/layer/concat.h b/src/model/layer/concat.h
index 59293d7..4e9a967 100644
--- a/src/model/layer/concat.h
+++ b/src/model/layer/concat.h
@@ -45,8 +45,9 @@ class Concat : public Layer {
  protected:
   /// 0 for concat rows; 1 for concat cols
   int axis_ = 0;
-  /// dim_size_[i] the size of the i-th source tensor on the concat dim
-  vector<int> dim_size_;
+  /// slice_point_[i] the end offset of the i-th source tensor on the concat
+  /// axis of the result tensor
+  vector<int> slice_point_;
   Shape out_sample_shape_;
 };
 

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/82c12c2b/src/model/layer/slice.cc
----------------------------------------------------------------------
diff --git a/src/model/layer/slice.cc b/src/model/layer/slice.cc
index 690a03e..66c05ee 100644
--- a/src/model/layer/slice.cc
+++ b/src/model/layer/slice.cc
@@ -28,32 +28,37 @@ RegisterLayerClass(singacl_slice, Slice);
 void Slice::Setup(const Shape& in_sample, const LayerConf& conf) {
   Layer::Setup(in_sample, conf);
   out_sample_shapes_.clear();
+  slice_point_.clear();
   axis_ = conf.slice_conf().axis();
   int offset = 0;
   // #slice point = # out tensors - 1
   for (size_t p : conf.slice_conf().slice_point()) {
-    Shape s{0, 0};
-    s[1 - axis_] = in_sample[1 - axis_];
-    s[axis_] = p - offset;
-    offset = p;
-    out_sample_shapes_.push_back(s);
+    slice_point_.push_back(p);
+    if (axis_ == 1) {
+      out_sample_shapes_.push_back({p - offset});
+      offset = p;
+    } else {
+      out_sample_shapes_.push_back(in_sample);
+    }
+  }
+  slice_point_.push_back(in_sample[0]);
+  if (axis_ == 1) {
+    out_sample_shapes_.push_back({in_sample[0] - offset});
+  } else {
+    out_sample_shapes_.push_back(in_sample);
   }
-  Shape s{0, 0};
-  s[1 - axis_] = in_sample[1 - axis_];
-  s[axis_] = in_sample[axis_] - offset;
-  out_sample_shapes_.push_back(s);
 }
 
 const vector<Tensor> Slice::Forward(int flag, const vector<Tensor>& inputs) {
   vector<Tensor> outputs;
   CHECK_EQ(inputs.size(), 1u) << "Split layer only have one input tensor.";
   size_t offset = 0;
-  for (auto& s : out_sample_shapes_) {
+  for (auto& s : slice_point_) {
     if (axis_ == 0)
-      outputs.push_back(SliceRows(inputs.at(0), offset, offset + s[axis_]));
+      outputs.push_back(SliceRows(inputs.at(0), offset, s));
     else
-      outputs.push_back(SliceColumns(inputs.at(0), offset, offset + s[axis_]));
-    offset += s[axis_];
+      outputs.push_back(SliceColumns(inputs.at(0), offset, s));
+    offset = s;
   }
   return outputs;
 }

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/82c12c2b/src/model/layer/slice.h
----------------------------------------------------------------------
diff --git a/src/model/layer/slice.h b/src/model/layer/slice.h
index 99ce468..7ed61fc 100644
--- a/src/model/layer/slice.h
+++ b/src/model/layer/slice.h
@@ -48,6 +48,8 @@ class Slice : public Layer {
   int axis_ = 0;
   /// out_sample_shapes_[i] is the shape of the i-th output tensor
   vector<Shape> out_sample_shapes_;
+  /// slice point, end offset of each output
+  vector<size_t> slice_point_;
 };
 
 }  // namespace singa

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/82c12c2b/test/python/test_layer.py
----------------------------------------------------------------------
diff --git a/test/python/test_layer.py b/test/python/test_layer.py
index d22207f..c0f19f3 100644
--- a/test/python/test_layer.py
+++ b/test/python/test_layer.py
@@ -209,21 +209,33 @@ class TestPythonLayer(unittest.TestCase):
         t2 = tensor.Tensor((1, 3))
         t1.set_value(1)
         t2.set_value(2)
-        lyr = layer.Concat('concat', 0, [t1.shape, t2.shape])
+        lyr = layer.Concat('concat', 0, [(3,), (3,)])
         t = lyr.forward(model_pb2.kTrain, [t1, t2])
-        tnp = tensor.to_numpy(t[0])
+        tnp = tensor.to_numpy(t)
         self.assertEquals(np.sum(tnp), 12)
+        t3 = tensor.Tensor((3, 3))
+        t3.set_value(1.5)
+        grads, _ = lyr.backward(model_pb2.kTrain, [t3])
+        gnp = tensor.to_numpy(grads[0])
+        self.assertEquals(np.sum(gnp), 6 * 1.5)
 
     def test_slice(self):
         t = np.zeros((3, 3))
         t[:, :2] = float(2)
         t[:, 2] = float(1)
-        lyr = layer.Slice('slice', 1, [2], t.shape)
+        lyr = layer.Slice('slice', 1, [2], (3,))
         out = lyr.forward(model_pb2.kTrain, [tensor.from_numpy(t)])
         t1 = tensor.to_numpy(out[0])
         t2 = tensor.to_numpy(out[1])
         self.assertEquals(np.average(t1), 2)
         self.assertEquals(np.average(t2), 1)
+        t1 = tensor.Tensor((3, 2))
+        t2 = tensor.Tensor((3, 1))
+        t1.set_value(1)
+        t2.set_value(2)
+        grad,_ = lyr.backward(model_pb2.kTrain, [t1, t2])
+        gnp = tensor.to_numpy(grad)
+        self.assertEquals(np.sum(gnp), 12)
 
 
 if __name__ == '__main__':

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/82c12c2b/test/singa/test_concat.cc
----------------------------------------------------------------------
diff --git a/test/singa/test_concat.cc b/test/singa/test_concat.cc
index 80183a7..d7f1060 100644
--- a/test/singa/test_concat.cc
+++ b/test/singa/test_concat.cc
@@ -31,10 +31,9 @@ TEST(Concat, Setup) {
   conf.set_type("singa_concat");
   conf.mutable_concat_conf()->set_axis(0);
   singa::Concat layer;
-  layer.Setup({s1, s2}, conf);
+  layer.Setup({{3u}, {3u}}, conf);
   auto s = layer.GetOutputSampleShape();
   EXPECT_EQ(s[0], 3u);
-  EXPECT_EQ(s[1], 3u);
 }
 
 void ForwardConcatRowTest(std::shared_ptr<singa::Device> dev) {
@@ -45,7 +44,7 @@ void ForwardConcatRowTest(std::shared_ptr<singa::Device> dev) {
   conf.set_type("singa_concat");
   conf.mutable_concat_conf()->set_axis(0);
   singa::Concat layer;
-  layer.Setup({t1.shape(), t2.shape()}, conf);
+  layer.Setup({{c}, {c}}, conf);
   layer.ToDevice(dev);
 
   t1.SetValue(1.0f);
@@ -74,7 +73,7 @@ void ForwardConcatColumnTest(std::shared_ptr<singa::Device> dev) {
   conf.set_type("singa_concat");
   conf.mutable_concat_conf()->set_axis(1);
   singa::Concat layer;
-  layer.Setup({t1.shape(), t2.shape()}, conf);
+  layer.Setup({{a}, {b}}, conf);
   layer.ToDevice(dev);
 
   t1.SetValue(1.0f);
@@ -119,9 +118,15 @@ void BackwardConcatRowTest(std::shared_ptr<singa::Device> dev) {
   conf.set_type("singa_concat");
   conf.mutable_concat_conf()->set_axis(0);
   singa::Concat layer;
-  layer.Setup({{a, c}, {b, c}}, conf);
+  layer.Setup({{c}, {c}}, conf);
   layer.ToDevice(dev);
 
+  singa::Tensor t1({a, c}, dev);
+  singa::Tensor t2({b, c}, dev);
+  t1.SetValue(1.0f);
+  t2.SetValue(2.0f);
+  layer.Forward(singa::kTrain, {t1, t2});
+
   singa::Tensor t({a + b, c}, dev);
   singa::Uniform(-1.f, 1.f, &t);
   auto out = layer.Backward(singa::kTrain, {t});
@@ -149,9 +154,15 @@ void BackwardConcatColumnTest(std::shared_ptr<singa::Device> dev) {
   conf.set_type("singa_concat");
   conf.mutable_concat_conf()->set_axis(1);
   singa::Concat layer;
-  layer.Setup({{c, a}, {c, b}}, conf);
+  layer.Setup({{a}, {b}}, conf);
   layer.ToDevice(dev);
 
+  singa::Tensor t1({c, a}, dev);
+  singa::Tensor t2({c, b}, dev);
+  t1.SetValue(1.0f);
+  t2.SetValue(2.0f);
+  layer.Forward(singa::kTrain, {t1, t2});
+
   singa::Tensor t({c, a + b}, dev);
   singa::Uniform(-1.f, 1.f, &t);
   auto out = layer.Backward(singa::kTrain, {t});

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/82c12c2b/test/singa/test_slice.cc
----------------------------------------------------------------------
diff --git a/test/singa/test_slice.cc b/test/singa/test_slice.cc
index 6039c47..f6b8997 100644
--- a/test/singa/test_slice.cc
+++ b/test/singa/test_slice.cc
@@ -24,20 +24,17 @@
 
 using singa::Shape;
 TEST(Slice, Setup) {
-  Shape s{2u, 3u};
   singa::LayerConf conf;
   conf.set_type("singa_slice");
   auto slice_conf = conf.mutable_slice_conf();
   slice_conf->set_axis(1);
   slice_conf->add_slice_point(2);
   singa::Slice layer;
-  layer.Setup(s, conf);
+  layer.Setup({3u}, conf);
   auto s1 = layer.GetOutputSampleShape(0);
   EXPECT_EQ(s1[0], 2u);
-  EXPECT_EQ(s1[1], 2u);
   auto s2 = layer.GetOutputSampleShape(1);
-  EXPECT_EQ(s2[0], 2u);
-  EXPECT_EQ(s2[1], 1u);
+  EXPECT_EQ(s2[0], 1u);
 }
 
 void ForwardSliceRowTest(std::shared_ptr<singa::Device> dev) {
@@ -46,9 +43,9 @@ void ForwardSliceRowTest(std::shared_ptr<singa::Device> dev) {
   conf.set_type("singa_slice");
   auto slice_conf = conf.mutable_slice_conf();
   slice_conf->set_axis(0);
-  slice_conf->add_slice_point(2);
+  slice_conf->add_slice_point(a);
   singa::Slice layer;
-  layer.Setup({a + b ,c}, conf);
+  layer.Setup({c}, conf);
   layer.ToDevice(dev);
 
   singa::Tensor t({a + b, c}, dev);
@@ -77,9 +74,9 @@ void ForwardSliceColumnTest(std::shared_ptr<singa::Device> dev) {
   conf.set_type("singa_slice");
   auto slice_conf = conf.mutable_slice_conf();
   slice_conf->set_axis(1);
-  slice_conf->add_slice_point(2);
+  slice_conf->add_slice_point(a);
   singa::Slice layer;
-  layer.Setup({c, a + b}, conf);
+  layer.Setup({a + b}, conf);
   layer.ToDevice(dev);
 
   singa::Tensor t({c, a + b}, dev);
@@ -132,7 +129,7 @@ void BackwardSliceRowTest(std::shared_ptr<singa::Device> dev) {
   slice_conf->set_axis(0);
   slice_conf->add_slice_point(2);
   singa::Slice layer;
-  layer.Setup({a + b ,c}, conf);
+  layer.Setup({c}, conf);
   layer.ToDevice(dev);
 
   singa::Tensor t1({a, c}, dev);
@@ -162,7 +159,7 @@ void BackwardSliceColumnTest(std::shared_ptr<singa::Device> dev) {
   slice_conf->set_axis(1);
   slice_conf->add_slice_point(2);
   singa::Slice layer;
-  layer.Setup({c , a + b}, conf);
+  layer.Setup({a + b}, conf);
   layer.ToDevice(dev);
 
   singa::Tensor t1({c, a}, dev);