Posted to commits@singa.apache.org by zh...@apache.org on 2017/05/24 12:12:22 UTC
[3/5] incubator-singa git commit: SINGA-315 Reduce memory footprint by Python generator for parameter gradient
SINGA-315 Reduce memory footprint by Python generator for parameter gradient
Update the API of the net::backward() function.
1. Add arguments dy and output: dy is the input gradient tensor(s), e.g., from the loss function;
output is a list of layer names whose output gradient tensor(s) should be returned in addition to the param gradient tensor(s).
2. backward() now returns a generator iterator that yields (param_names, param_values, param_grads, out_grads) after processing each layer.
The caller can then update the parameters and release the gradient tensors layer by layer.
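For illustration, a minimal sketch of the intended layer-wise update loop (update_param is a hypothetical stand-in for whatever updater is used; net, x, y as usual):

    out = net.forward(kTrain, x)
    lvalue = net.loss.forward(kTrain, out, y)
    dy = net.loss.backward()
    if len(dy.shape) > 1:
        dy /= dy.shape[0]  # average the gradient over the batch, as train() does
    for pnames, pvalues, pgrads, _ in net.backward(dy):
        for pname, pvalue, pgrad in zip(pnames, pvalues, pgrads):
            update_param(pname, pvalue, pgrad)  # hypothetical updater call
        # this layer's pgrads can be released here, keeping memory low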
Project: http://git-wip-us.apache.org/repos/asf/incubator-singa/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-singa/commit/ea078dca
Tree: http://git-wip-us.apache.org/repos/asf/incubator-singa/tree/ea078dca
Diff: http://git-wip-us.apache.org/repos/asf/incubator-singa/diff/ea078dca
Branch: refs/heads/master
Commit: ea078dca9cfafdda9d5a15ae2f2823897d217292
Parents: fa4f631
Author: wangwei <wa...@comp.nus.edu.sg>
Authored: Tue May 23 13:48:39 2017 +0800
Committer: wangwei <wa...@comp.nus.edu.sg>
Committed: Tue May 23 13:48:39 2017 +0800
----------------------------------------------------------------------
python/singa/net.py | 100 ++++++++++++++++++++++++++-----------------
test/python/test_net.py | 24 ++++++++++-
2 files changed, 84 insertions(+), 40 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/ea078dca/python/singa/net.py
----------------------------------------------------------------------
diff --git a/python/singa/net.py b/python/singa/net.py
index 0226864..96a9c79 100644
--- a/python/singa/net.py
+++ b/python/singa/net.py
@@ -169,22 +169,25 @@ class FeedForwardNet(object):
Currently only support nets with a single output layer, and a single
loss objective and metric.
- TODO(wangwei) consider multiple loss objectives and metrics.
+ For nets with multiple outputs (and multiple losses/metrics), please
+ manually call forward(), compute the loss/metric, and call backward().
+ backward() is also more memory efficient than this function.
Args:
x: input data, a single input Tensor or a dict: layer name -> Tensor
y: label data, a single input Tensor.
-
Returns:
gradients of parameters and the loss and metric values.
'''
out = self.forward(kTrain, x)
l = self.loss.forward(kTrain, out, y)
+ m = None
if self.metric is not None:
m = self.metric.evaluate(out, y)
- return self.backward(), (l.l1(), m)
- else:
- return self.backward(), (l.l1(), None)
+ g = self.loss.backward()
+ if len(g.shape) > 1:
+ g /= g.shape[0] # average the gradient over the batch
+ grads = [] # store all gradient tensors; memory inefficient
+ for _, _, grad, _ in self.backward(g):
+ grads.extend(grad[::-1])
+ return grads[::-1], l.l1(), m
def evaluate(self, x, y):
'''Evaluate the loss and metric of the given data.
@@ -250,22 +253,23 @@ class FeedForwardNet(object):
def forward(self, flag, x, output=[]):
'''Forward the input(s) through every layer.
- If a layer has inputs from other layers and from x, the data from x is
- ordered before the data from other layers, e.g., if layer 1 -> layer 2,
- and x['layer 2'] has data, then the input of layer 2 is
- flatten([x['layer 2'], output of layer 1])
-
Args:
flag: True for training; False for evaluation; could also be
model_pb2.kTrain or model_pb2.kEval, or other values for future
use.
- x: a single SINGA tensor or a dictionary: layer name-> singa tensor
+ x: a single SINGA tensor if there is a single input; otherwise, a
+ dictionary: layer name -> SINGA tensor, one entry per layer that
+ accepts input data. Do not associate an input tensor with a layer
+ that is connected from another layer; in that case, add a Dummy()
+ layer to accept the input data and connect the dummy layer to this
+ layer.
output(list): a list of layer names whose output would be returned
- in addition to the default output
+ in addition to the default output.
Returns:
- if there is only one output layer, return its output tensor(s);
- else return a dictionary: layer name -> output tensor(s)
+ if there is only one output layer and the output arg is empty,
+ return the result from the single output layer; otherwise, return
+ a dictionary: layer name -> output tensor(s)
'''
if self.ordered_layers is None:
self.ordered_layers = self.topo_sort(self.layers, self.src_of_layer)
@@ -321,11 +325,26 @@ class FeedForwardNet(object):
else:
return ret
- def backward(self):
+ def backward(self, dy, output=[]):
'''Run back-propagation after forward-propagation.
+ Args:
+ dy: a single tensor if there is a single loss function; otherwise,
+ a dictionary mapping the name of the layer connected to the loss
+ function -> the gradient from that loss function. Do not associate
+ a gradient tensor with a layer that is connected to another layer;
+ in that case, connect this layer to a Dummy() layer and let the
+ dummy layer accept the gradient.
+ output(list): a list of layer names whose output gradient would be
+ returned in addition to the param gradient
+
Returns:
- a list of gradient tensor for all parameters
+ a generator iterator that yields
+ (param_names, param_values, param_grads, layer_grads) after
+ processing each layer h, where the first three lists are for h
+ and the last item is a dictionary mapping
+ layer name -> its output gradient tensor(s). By the end of the
+ iteration, the keys include all layers in the output arg.
'''
if self.dst_of_layer is None:
self.dst_of_layer = {}
@@ -335,29 +354,35 @@ class FeedForwardNet(object):
srcs = self.src_of_layer[cur.name]
for src in srcs:
self.dst_of_layer[src.name].append(cur)
- grad = self.loss.backward()
- if len(grad.shape) > 1:
- grad /= grad.shape[0] # average across the batch
- # print 'grad', grad.l1()
- grads = [grad]
- output_of_layer = {}
- pgrads = []
+ output_of_layer = {} # outputs generated by each layer
+ ret = {} # outputs to return
+ if type(dy) is dict:
+ input_of_layer = dy
+ else:
+ assert isinstance(dy, tensor.Tensor), \
+ 'dy should be a dict or a single Tensor'
+ input_of_layer = {self.ordered_layers[-1].name: dy}
for cur in reversed(self.ordered_layers):
+ inputs = []
+ if cur.name in input_of_layer:
+ if type(input_of_layer[cur.name]) is list:
+ inputs.extend(input_of_layer[cur.name])
+ else:
+ inputs.append(input_of_layer[cur.name])
for dst in self.dst_of_layer[cur.name]:
outputs = output_of_layer[dst.name]
if type(outputs) == list:
assert len(outputs) > 0, \
'the gradient from layer %s is empty' % dst.name
- grads.append(outputs[0])
+ inputs.append(outputs[0])
outputs.pop(0)
else:
- grads.append(outputs)
+ inputs.append(outputs)
output_of_layer[dst.name] = []
# del output_of_layer[dst.name]
- if len(grads) == 1:
- grads = grads[0]
- outs, _pgrads = cur.backward(kTrain, grads)
- pgrads.append(_pgrads)
+ if len(inputs) == 1:
+ inputs = inputs[0]
+ outs, pgrads = cur.backward(kTrain, inputs)
if verbose:
disp_src = '+'.join(
[dst.name for dst in self.dst_of_layer[cur.name]])
@@ -371,12 +396,10 @@ class FeedForwardNet(object):
output_of_layer[cur.name] = outs[::-1]
else:
output_of_layer[cur.name] = outs
- grads = []
-
- ret = []
- for pgrad in reversed(pgrads):
- ret.extend(pgrad)
- return ret
+ if cur.name in output:
+ ret[cur.name] = outs
+ yield (cur.param_names(), cur.param_values(), pgrads, ret)
def save(self, f, buffer_size=10, use_pickle=False):
'''Save model parameters using io/snapshot.
@@ -391,7 +414,7 @@ class FeedForwardNet(object):
'''
if use_pickle:
params = {}
- # since SINGA>=1.1.1
+ # since SINGA>=1.1.1 (1101)
params['SINGA_VERSION'] = __version__
for (name, val) in zip(self.param_names(), self.param_values()):
val.to_host()
@@ -416,10 +439,10 @@ class FeedForwardNet(object):
version = 0
def get_name(name):
- if version < 1011:
+ if version < 1101:
idx = name.rfind('/')
assert idx > 0, '/ must be in the parameter name'
- name = name[:idx-1] + '_' + name[idx:]
+ name = name[:idx] + '_' + name[idx+1:]
return name
if use_pickle:
@@ -442,7 +465,6 @@ class FeedForwardNet(object):
sp = snapshot.Snapshot(f, False, buffer_size)
params = sp.read()
if 'SINGA_VERSION' in params:
- # for SINGA >= 1.1.1
version = params['SINGA_VERSION']
for name, val in zip(self.param_names(), self.param_values()):
name = get_name(name)
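For reference, the get_name fix above maps a current parameter name to its pre-1101 on-disk form; a minimal check (old_style is a hypothetical helper mirroring the fixed code):

    def old_style(name):
        idx = name.rfind('/')
        return name[:idx] + '_' + name[idx+1:]

    assert old_style('conv/weight') == 'conv_weight'
    # the previous (buggy) slicing produced 'con_/weight'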
http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/ea078dca/test/python/test_net.py
----------------------------------------------------------------------
diff --git a/test/python/test_net.py b/test/python/test_net.py
index 50b976c..aad9b12 100644
--- a/test/python/test_net.py
+++ b/test/python/test_net.py
@@ -74,7 +74,7 @@ class TestFeedForwardNet(unittest.TestCase):
out = tensor.to_numpy(out['split1'])
self.assertAlmostEqual(np.average(out), 2)
- def test_save(self):
+ def test_save_load(self):
ffn = net.FeedForwardNet(loss.SoftmaxCrossEntropy())
ffn.add(layer.Conv2D('conv', 4, 3, input_sample_shape=(3, 12, 12)))
ffn.add(layer.Flatten('flat'))
@@ -88,6 +88,28 @@ class TestFeedForwardNet(unittest.TestCase):
ffn.load('test_snaphost')
ffn.load('test_pickle', use_pickle=True)
+ def test_train_one_batch(self):
+ ffn = net.FeedForwardNet(loss.SoftmaxCrossEntropy())
+ ffn.add(layer.Conv2D('conv', 4, 3, input_sample_shape=(3, 12, 12)))
+ ffn.add(layer.Flatten('flat'))
+ ffn.add(layer.Dense('dense', num_output=4))
+ for pname, pval in zip(ffn.param_names(), ffn.param_values()):
+ pval.set_value(0.1)
+ x = tensor.Tensor((4, 3, 12, 12))
+ x.gaussian(0, 0.01)
+ y = np.asarray([[1, 0, 0, 0],
+ [0, 0, 1, 0],
+ [0, 0, 1, 0],
+ [0, 1, 0, 0]], dtype=np.int32)
+ y = tensor.from_numpy(y)
+ o = ffn.forward(True, x)
+ ffn.loss.forward(True, o, y)
+ g = ffn.loss.backward()
+ for pname, pvalue, pgrad, _ in ffn.backward(g):
+ self.assertEqual(len(pvalue), len(pgrad))
+ for p, g in zip(pvalue, pgrad):
+ self.assertEqual(p.size(), g.size())
+
if __name__ == '__main__':
unittest.main()
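For nets with multiple losses, the same test flow would pass a dict as dy and may request extra output gradients, as described in backward()'s docstring. A hedged sketch (layer and loss names here are illustrative only):

    # 'out1' and 'out2' are the layers connected to loss1 and loss2
    dy = {'out1': loss1.backward(), 'out2': loss2.backward()}
    for pnames, pvalues, pgrads, out_grads in ffn.backward(dy, output=['conv']):
        pass  # out_grads['conv'] holds conv's output gradient once that layer is processed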