Posted to commits@singa.apache.org by zh...@apache.org on 2016/11/22 08:49:30 UTC
[2/3] incubator-singa git commit: SINGA-271 Add Concat and Slice layers
SINGA-271 Add Concat and Slice layers
Export the C++ Slice and Concat layers to Python.
Python unit tests pass.
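
A minimal usage sketch of the new Python API (it mirrors the new unit
tests at the end of this diff; tensor values are illustrative, not part
of the commit):

    import numpy as np
    from singa import layer, tensor
    from singa.proto import model_pb2

    # Concat: stack a (2, 3) tensor of 1s on a (1, 3) tensor of 2s.
    t1 = tensor.Tensor((2, 3))
    t2 = tensor.Tensor((1, 3))
    t1.set_value(1)
    t2.set_value(2)
    cat = layer.Concat('concat', 0, [t1.shape, t2.shape])
    out = cat.forward(model_pb2.kTrain, [t1, t2])
    # out[0] is the concatenated (3, 3) tensor; its entries sum to 12.

    # Slice: cut a (3, 3) tensor into column blocks [0:2) and [2:3).
    sli = layer.Slice('slice', 1, [2], (3, 3))
    parts = sli.forward(model_pb2.kTrain, [tensor.from_numpy(np.zeros((3, 3)))])
    # parts[0] has shape (3, 2); parts[1] has shape (3, 1).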
Project: http://git-wip-us.apache.org/repos/asf/incubator-singa/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-singa/commit/d84af801
Tree: http://git-wip-us.apache.org/repos/asf/incubator-singa/tree/d84af801
Diff: http://git-wip-us.apache.org/repos/asf/incubator-singa/diff/d84af801
Branch: refs/heads/master
Commit: d84af80172cab4094ffeb28293d8cb0820d75cbd
Parents: 16f3bf6
Author: wang wei <wa...@comp.nus.edu.sg>
Authored: Sun Nov 20 15:47:11 2016 +0000
Committer: wang wei <wa...@comp.nus.edu.sg>
Committed: Tue Nov 22 06:33:32 2016 +0000
----------------------------------------------------------------------
include/singa/model/layer.h |  3 +-
python/singa/layer.py       | 79 +++++++++++++++++++++++++++++++++-------
python/singa/net.py         | 32 ++++++++++++----
src/api/model_layer.i       | 37 ++++++++++++-------
test/python/test_layer.py   | 35 +++++++++++++-----
5 files changed, 140 insertions(+), 46 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/d84af801/include/singa/model/layer.h
----------------------------------------------------------------------
diff --git a/include/singa/model/layer.h b/include/singa/model/layer.h
index e67fcc5..ca07a19 100644
--- a/include/singa/model/layer.h
+++ b/include/singa/model/layer.h
@@ -75,8 +75,7 @@ class Layer {
}
/// Used for layers that have multiple input tensors, e.g., concatenate layer.
- virtual void Setup(const vector<Shape>& in_samples,
- const LayerConf& conf) {
+ virtual void Setup(const vector<Shape>& in_samples, const LayerConf& conf) {
name_ = conf.name();
// TODO(wangwei) load param values from checkpoint files.
}
http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/d84af801/python/singa/layer.py
----------------------------------------------------------------------
diff --git a/python/singa/layer.py b/python/singa/layer.py
index 730bea0..964ec17 100644
--- a/python/singa/layer.py
+++ b/python/singa/layer.py
@@ -94,20 +94,19 @@ class Layer(object):
# case1: parameters of conv and dense layers
# case2: type of activation layers
if (conf.type == 'Convolution' or conf.type == 4) or \
- (conf.type == 'InnerProduct' or conf.type == 14):
+ (conf.type == 'InnerProduct' or conf.type == 14):
w, b = _construct_param_specs_from_caffe_proto(conf)
del conf.param[:]
conf.param.extend([w, b])
self.param_specs.append(w)
self.param_specs.append(b)
- #print 'conf:\n', conf
+ # print 'conf:\n', conf
if conf.type == 'Pooling':
conf.pooling_conf.ceil = True
- #print 'conf:\n', conf
-
- elif (conf.type == 'ReLU' or conf.type == 18) or \
- (conf.type == 'Sigmoid' or conf.type == 19) or \
- (conf.type == 'TanH' or conf.type == 23):
+ # print 'conf:\n', conf
+ elif (conf.type == 'ReLU' or conf.type == 18 or
+ conf.type == 'Sigmoid' or conf.type == 19 or
+ conf.type == 'TanH' or conf.type == 23):
conf.type = (engine + '_' + conf.type).lower()
self.conf = conf
@@ -123,7 +122,6 @@ class Layer(object):
else:
self.layer = _create_layer(engine, str(self.conf.type))
-
def param_names(self):
'''
Returns:
@@ -145,8 +143,11 @@ class Layer(object):
'''
if self.has_setup:
return
- self.layer.Setup(list(in_shapes),
- self.conf.SerializeToString())
+ if type(in_shapes[0]) is tuple:
+ self.layer.SetupWithMultInputs([list(s) for s in in_shapes],
+ self.conf.SerializeToString())
+ else:
+ self.layer.Setup(list(in_shapes), self.conf.SerializeToString())
self.has_setup = True
def get_output_sample_shape(self):
@@ -194,6 +195,7 @@ class Layer(object):
xs = []
for t in x:
xs.append(t.singa_tensor)
+ y = self.layer.ForwardWithMultInputs(flag, xs)
else:
assert isinstance(x, tensor.Tensor), \
'input must be a Tensor or a list of Tensor'
@@ -204,7 +206,7 @@ class Layer(object):
else:
flag = model_pb2.kEval
y = self.layer.Forward(flag, xs)
- if type(y) == list:
+ if type(y) is tuple:
return tensor.from_raw_tensors(y)
else:
return tensor.from_raw_tensor(y)
@@ -224,12 +226,13 @@ class Layer(object):
dys = []
for t in dy:
dys.append(t.singa_tensor)
+ ret = self.layer.BackwardWithMultInputs(flag, dys)
else:
assert isinstance(dy, tensor.Tensor), \
'the input must be a Tensor or a set of Tensor'
dys = dy.singa_tensor
- ret = self.layer.Backward(flag, dys)
- if type(ret[0]) == list:
+ ret = self.layer.Backward(flag, dys)
+ if type(ret[0]) is tuple:
dxs = tensor.from_raw_tensors(ret[0])
else:
dxs = tensor.from_raw_tensor(ret[0])
@@ -275,6 +278,7 @@ class Dummy(Layer):
def backward(self, flag, dy):
return dy
+
class Conv2D(Layer):
"""Construct a layer for 2D convolution.
@@ -763,7 +767,7 @@ class Split(Layer):
self.has_setup = True
def get_output_sample_shape(self):
- return self.in_shape
+ return [self.in_shape] * self.num_output
def forward(self, flag, input):
'''Replicate the input tensor into multiple tensors.
@@ -789,6 +793,53 @@ class Split(Layer):
return dx, []
+class Concat(Layer):
+ '''Concatenate tensors vertically (axis = 0) or horizontally (axis = 1).
+
+ Currently, only tensors with 2 dimensions are supported.
+
+ Args:
+ axis (int): 0 to concatenate rows; 1 to concatenate columns;
+ input_sample_shapes: a list of shape tuples, one per input tensor
+ '''
+
+ def __init__(self, name, axis, input_sample_shapes=None):
+ super(Concat, self).__init__(name)
+ self.in_shapes = input_sample_shapes
+ self.axis = axis
+ self.conf.concat_conf.axis = axis
+ self.layer = _create_layer(engine, 'Concat')
+ if input_sample_shapes is not None:
+ self.setup(input_sample_shapes)
+
+
+class Slice(Layer):
+ '''Slice the input tensor into multiple sub-tensors vertically (axis=0) or
+ horizontally (axis=1).
+
+ Args:
+ axis (int): 0 to slice rows; 1 to slice columns;
+ slice_point (list): positions along the axis at which to slice; there
+ are n-1 points for n sub-tensors;
+ input_sample_shape: input tensor shape
+ '''
+
+ def __init__(self, name, axis, slice_point, input_sample_shape=None):
+ super(Slice, self).__init__(name)
+ self.in_shape = input_sample_shape
+ self.axis = axis
+ self.conf.slice_conf.axis = axis
+ self.conf.slice_conf.slice_point.extend(slice_point)
+ self.layer = _create_layer(engine, 'Slice')
+ if input_sample_shape is not None:
+ self.setup(input_sample_shape)
+
+ def get_output_sample_shape(self):
+ out = []
+ for i in range(len(self.conf.slice_conf.slice_point) + 1):
+ out.append(self.layer.GetOutputSampleShapeAt(i))
+ return out
+
+
class RNN(Layer):
'''Recurrent layer with 4 types of units, namely lstm, gru, tanh and relu.
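
Note that the Layer.setup change above dispatches on the input type: a
list of shape tuples routes to the new SetupWithMultInputs binding,
while a single shape tuple keeps using the plain Setup binding. A short
sketch of both paths (shapes taken from the new unit tests):

    from singa import layer

    cat = layer.Concat('cat', 0)
    # A list of shape tuples -> SetupWithMultInputs.
    cat.setup([(2, 3), (1, 3)])

    # A single shape tuple -> plain Setup; here Slice.__init__ calls
    # setup itself because input_sample_shape is given.
    sli = layer.Slice('slice', 1, [2], (3, 3))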
http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/d84af801/python/singa/net.py
----------------------------------------------------------------------
diff --git a/python/singa/net.py b/python/singa/net.py
index 293e97c..d34afbc 100644
--- a/python/singa/net.py
+++ b/python/singa/net.py
@@ -39,6 +39,7 @@ class FeedForwardNet(object):
self.src_of_layer = {}
self.dst_of_layer = None
self.ordered_layers = None
+ self.out_sample_shape_of_layer = {}
def to_device(self, dev):
for lyr in self.layers:
@@ -47,9 +48,11 @@ class FeedForwardNet(object):
def add(self, lyr, src=None):
"""Append a layer into the layer list.
- This function will get the sample shape from the last layer to setup
- the newly added layer. For the first layer, it is setup outside.
- The calling function should ensure the correctness of the layer order.
+ This function will get the sample shape from the src layers to set up
+ the newly added layer. For the first layer, it is set up outside. The
+ calling function should ensure the correctness of the layer order. If
+ src is None, the last layer is used as the src layer. If there are
+ multiple src layers, src is a list of those layers.
Args:
lyr (Layer): the layer to be added
@@ -70,11 +73,24 @@ class FeedForwardNet(object):
else:
self.src_of_layer[lyr.name] = []
if lyr.has_setup is False:
- # print shape
- in_shape = self.src_of_layer[lyr.name][0].get_output_sample_shape()
- lyr.setup(in_shape)
- print lyr.name, lyr.get_output_sample_shape()
+ in_shape = []
+ for src in self.src_of_layer[lyr.name]:
+ shapes = self.out_sample_shape_of_layer[src.name]
+ assert len(shapes) > 0, \
+ 'Cannot get output shape of layer %s' % lyr.name
+ in_shape.append(shapes[0])
+ shapes.pop(0)
+ if len(in_shape) == 1:
+ lyr.setup(in_shape[0])
+ else:
+ lyr.setup(in_shape)
+ out_shape = lyr.get_output_sample_shape()
+ if type(out_shape[0]) is tuple:
+ self.out_sample_shape_of_layer[lyr.name] = out_shape
+ else:
+ self.out_sample_shape_of_layer[lyr.name] = [out_shape]
self.layers.append(lyr)
+ print lyr.name, out_shape
return lyr
def param_values(self):
@@ -239,7 +255,7 @@ class FeedForwardNet(object):
disp_src += '-->' + cur.name
if type(out) is list:
print '%s: %s' % (disp_src,
- ' '.join([str(o.l1()) for o in out]))
+ ' '.join([str(o.l1()) for o in out]))
else:
print '%s: %f' % (disp_src, out.l1())
output_of_layer[cur.name] = out
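
The bookkeeping above queues each layer's output sample shapes and pops
one per consumer, which is what lets a branched net be wired through
add(). A sketch (layer names and sizes are illustrative; Dense and
Split signatures as in the existing python/singa/layer.py):

    from singa import layer, net as ffnet

    net = ffnet.FeedForwardNet()
    net.add(layer.Dense('dense', 6, input_sample_shape=(4,)))
    sp = net.add(layer.Split('split', 2))  # src defaults to the last layer
    b1 = net.add(layer.Dense('branch1', 3), sp)
    b2 = net.add(layer.Dense('branch2', 3), sp)
    # Multiple src layers are passed as a list; their queued output
    # sample shapes are popped in order to set up the Concat layer.
    net.add(layer.Concat('concat', 1), [b1, b2])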
http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/d84af801/src/api/model_layer.i
----------------------------------------------------------------------
diff --git a/src/api/model_layer.i b/src/api/model_layer.i
index 3878873..7f582e7 100644
--- a/src/api/model_layer.i
+++ b/src/api/model_layer.i
@@ -40,6 +40,7 @@ using singa::ParamSpec;
using singa::DataType;
using singa::Device;
using singa::LayerConf;
+using singa::Shape;
%}
%shared_ptr(singa::Layer)
@@ -52,26 +53,36 @@ namespace std {
%template(VecStr) vector<string>;
%template(VecParamSpec) vector<singa::ParamSpec>;
%template(VecTensor) vector<singa::Tensor>;
+ %template(VecVecSize) vector<vector<size_t>>;
%template(PairTensorVecTensor) pair<singa::Tensor, vector<singa::Tensor>>;
%template(PairVecTensor) pair<vector<singa::Tensor>, vector<singa::Tensor>>;
}
-
namespace singa {
class Layer {
- public:
- Layer();
-// virtual void Setup(const std::vector<vector<size_t>>&, const string&);
- void Setup(const std::vector<size_t>& in_sample_shape,
- const std::string& proto_str);
- virtual const std::vector<Tensor> param_values();
- virtual const std::vector<size_t> GetOutputSampleShape() const;
- virtual void ToDevice(std::shared_ptr<Device> device);
- virtual void AsType(DataType dtype);
- virtual const Tensor Forward(int flag, const Tensor& input);
- virtual const std::pair<Tensor, std::vector<Tensor>> Backward(
- int flag, const Tensor& grad);
+ public:
+ Layer();
+ void Setup(const std::vector<size_t>&, const std::string&);
+ %rename(SetupWithMultInputs) Setup(const std::vector<std::vector<size_t>>&,
+ const std::string&);
+ void Setup(const std::vector<std::vector<size_t>>&, const std::string&);
+
+ virtual const std::vector<Tensor> param_values();
+ virtual const std::vector<size_t> GetOutputSampleShape() const;
+ %rename(GetOutputSampleShapeAt) GetOutputSampleShape(int k);
+ virtual const std::vector<size_t> GetOutputSampleShape(int k);
+ virtual void ToDevice(std::shared_ptr<Device> device);
+ virtual void AsType(DataType dtype);
+ virtual const Tensor Forward(int flag, const Tensor& input);
+ %rename(ForwardWithMultInputs) Forward(int flag, const std::vector<Tensor>&);
+ virtual const std::vector<Tensor> Forward(
+ int flag, const std::vector<Tensor>& inputs);
+ virtual const std::pair<Tensor, std::vector<Tensor>> Backward(
+ int flag, const Tensor& grad);
+ %rename(BackwardWithMultInputs) Backward(int, const vector<Tensor>&);
+ virtual const std::pair<std::vector<Tensor>, std::vector<Tensor>>
+ Backward(int flag, const vector<Tensor>& grads);
};
std::shared_ptr<Layer> CreateLayer(const std::string& type);
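
Since Python cannot dispatch on C++ argument types, the %rename
directives above give each overload a distinct name on the generated
wrapper. A sketch of the resulting surface, as consumed by
python/singa/layer.py (shapes from the new unit tests):

    from singa import layer

    cat = layer.Concat('cat', 0, [(2, 3), (1, 3)])
    raw = cat.layer  # the SWIG-wrapped singa::Layer
    # Each C++ overload is exposed as a distinct Python method:
    assert hasattr(raw, 'Forward') and hasattr(raw, 'ForwardWithMultInputs')
    assert hasattr(raw, 'Backward') and hasattr(raw, 'BackwardWithMultInputs')
    assert hasattr(raw, 'GetOutputSampleShapeAt')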
http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/d84af801/test/python/test_layer.py
----------------------------------------------------------------------
diff --git a/test/python/test_layer.py b/test/python/test_layer.py
index 141cf56..d22207f 100644
--- a/test/python/test_layer.py
+++ b/test/python/test_layer.py
@@ -1,4 +1,4 @@
-#
+#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
@@ -6,25 +6,21 @@
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
-#
+#
# http://www.apache.org/licenses/LICENSE-2.0
-#
+#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-#
+#
-import sys
-import os
import unittest
import numpy as np
-#sys.path.append(os.path.join(os.path.dirname(__file__), '../../build/python'))
from singa import layer
-from singa import device
from singa import tensor
from singa.proto import model_pb2
@@ -43,7 +39,7 @@ class TestPythonLayer(unittest.TestCase):
)
def setUp(self):
- layer.engine='singacpp'
+ layer.engine = 'singacpp'
self.w = {'init': 'Xavier', 'regularizer': 1e-4}
self.b = {'init': 'Constant', 'value': 0}
self.sample_shape = None
@@ -208,6 +204,27 @@ class TestPythonLayer(unittest.TestCase):
out_sample_shape = flatten.get_output_sample_shape()
self.check_shape(out_sample_shape, (12,))
+ def test_concat(self):
+ t1 = tensor.Tensor((2, 3))
+ t2 = tensor.Tensor((1, 3))
+ t1.set_value(1)
+ t2.set_value(2)
+ lyr = layer.Concat('concat', 0, [t1.shape, t2.shape])
+ t = lyr.forward(model_pb2.kTrain, [t1, t2])
+ tnp = tensor.to_numpy(t[0])
+ self.assertEquals(np.sum(tnp), 12)
+
+ def test_slice(self):
+ t = np.zeros((3, 3))
+ t[:, :2] = float(2)
+ t[:, 2] = float(1)
+ lyr = layer.Slice('slice', 1, [2], t.shape)
+ out = lyr.forward(model_pb2.kTrain, [tensor.from_numpy(t)])
+ t1 = tensor.to_numpy(out[0])
+ t2 = tensor.to_numpy(out[1])
+ self.assertEquals(np.average(t1), 2)
+ self.assertEquals(np.average(t2), 1)
+
if __name__ == '__main__':
unittest.main()
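
For reference, a plain numpy cross-check of the expected values in the
two new tests (not part of the commit):

    import numpy as np

    # test_concat: a (2, 3) block of 1s stacked on a (1, 3) block of 2s.
    cat = np.concatenate([np.ones((2, 3)), 2 * np.ones((1, 3))], axis=0)
    assert cat.shape == (3, 3) and np.sum(cat) == 12

    # test_slice: columns [0:2) hold 2s, column [2:3) holds 1s.
    t = np.zeros((3, 3))
    t[:, :2] = 2
    t[:, 2] = 1
    left, right = np.split(t, [2], axis=1)
    assert np.average(left) == 2 and np.average(right) == 1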