You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@singa.apache.org by zh...@apache.org on 2017/05/24 12:12:22 UTC

[3/5] incubator-singa git commit: SINGA-315 Reduce memory footprint by Python generator for parameter gradient

SINGA-315 Reduce memory footprint by Python generator for parameter gradient

Update the API of net::backward() function.
1. add arguments dy and output. dy for the input gradient tensor(s), e.g. from the loss functions.
   output is a list of layer names, whose output gradient tensor(s) would be returned in addition to the param gradient tensor(s).
2. returns a generator iterator that generates (param_names, param_values, param_grads, out_grads) after processing each layer.
The caller can then update the parameters and release the gradient tensors layer by layer.


Project: http://git-wip-us.apache.org/repos/asf/incubator-singa/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-singa/commit/ea078dca
Tree: http://git-wip-us.apache.org/repos/asf/incubator-singa/tree/ea078dca
Diff: http://git-wip-us.apache.org/repos/asf/incubator-singa/diff/ea078dca

Branch: refs/heads/master
Commit: ea078dca9cfafdda9d5a15ae2f2823897d217292
Parents: fa4f631
Author: wangwei <wa...@comp.nus.edu.sg>
Authored: Tue May 23 13:48:39 2017 +0800
Committer: wangwei <wa...@comp.nus.edu.sg>
Committed: Tue May 23 13:48:39 2017 +0800

----------------------------------------------------------------------
 python/singa/net.py     | 100 ++++++++++++++++++++++++++-----------------
 test/python/test_net.py |  24 ++++++++++-
 2 files changed, 84 insertions(+), 40 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/ea078dca/python/singa/net.py
----------------------------------------------------------------------
diff --git a/python/singa/net.py b/python/singa/net.py
index 0226864..96a9c79 100644
--- a/python/singa/net.py
+++ b/python/singa/net.py
@@ -169,22 +169,25 @@ class FeedForwardNet(object):
 
         Currently only support nets with a single output layer, and a single
         loss objective and metric.
-        TODO(wangwei) consider multiple loss objectives and metrics.
+        For multiple outputs (with multiple loss/metric), please manually
+        call forward, compute loss/metric and call backward. backward() is also
+        more memory efficient than this function.
 
         Args:
             x: input data, a single input Tensor or a dict: layer name -> Tensor
             y: label data, a single input Tensor.
-
         Returns:
             gradients of parameters and the loss and metric values.
         '''
         out = self.forward(kTrain, x)
         l = self.loss.forward(kTrain, out, y)
+        m = None
         if self.metric is not None:
             m = self.metric.evaluate(out, y)
-            return self.backward(), (l.l1(), m)
-        else:
-            return self.backward(), (l.l1(), None)
+        grads = []  # store all gradient tensors; memory inefficient
+        for _, _, grad, _ in self.backward():
+            grads.append(grad)
+        return grads[::-1], l.l1(), m
 
     def evaluate(self, x, y):
         '''Evaluate the loss and metric of the given data.
@@ -250,22 +253,23 @@ class FeedForwardNet(object):
     def forward(self, flag, x, output=[]):
         '''Forward the input(s) through every layer.
 
-        If a layer has inputs from other layers and from x, the data from x is
-        ordered before the data from other layers, e.g., if layer 1 -> layer 2,
-        and x['layer 2'] has data, then the input of layer 2 is
-        flatten([x['layer 2'], output of layer 1])
-
         Args:
             flag: True for training; False for evaluation; could also be
                 model_pb2.kTrain or model_pb2.kEval, or other values for future
                 use.
-            x: a single SINGA tensor or a dictionary: layer name-> singa tensor
+            x: a single SINGA tensor if there is a single input; otherwise, a
+                dictionary: layer name-> singa tensor, for each layer accepting
+                input data. Do not associate a layer with input tensor if it is
+                connected from another layer. For such case, use a Dummy() layer
+                to accept the input data and connect the dummy layer to this
+                layer.
             output(list): a list of layer names whose output would be returned
-                in addition to the default output
+                in addition to the default output.
 
         Returns:
-            if there is only one output layer, return its output tensor(s);
-            else return a dictionary: layer name -> output tensor(s)
+            if there is only one output layer and output arg is empty, return
+                the result from the single output layer; otherwise, return a
+                dictionary: layer name -> output tensor(s)
         '''
         if self.ordered_layers is None:
             self.ordered_layers = self.topo_sort(self.layers, self.src_of_layer)
@@ -321,11 +325,26 @@ class FeedForwardNet(object):
         else:
             return ret
 
-    def backward(self):
+    def backward(self, dy, output=[]):
         '''Run back-propagation after forward-propagation.
 
+        Args:
+            dy: a single tensor if there is a single loss function; otherwise,
+                a dictionary maps the name of the layer connecting to the loss
+                function -> gradient from the loss function. Do not associate a
+                layer with gradient tensor if it is connecting to another layer.
+                For such case, connect this layer to a Dummy() layer and use the
+                dummy layer to accept the gradient.
+            output(list): a list of layer names whose output gradient would be
+                returned in addition to the param gradient
+
         Returns:
-            a list of gradient tensor for all parameters
+                a generator iterator that generates
+                (param_names, param_values, param_grads, layer_grads) after
+                processing each layer h, where the first three lists are for h
+                and the last item is a dictionary which maps
+                layer name -> its output gradient tensor(s). At the end of this
+                function, the key set includes all layers in the output arg.
         '''
         if self.dst_of_layer is None:
             self.dst_of_layer = {}
@@ -335,29 +354,35 @@ class FeedForwardNet(object):
                 srcs = self.src_of_layer[cur.name]
                 for src in srcs:
                     self.dst_of_layer[src.name].append(cur)
-        grad = self.loss.backward()
-        if len(grad.shape) > 1:
-            grad /= grad.shape[0]  # average across the batch
-        # print 'grad', grad.l1()
-        grads = [grad]
-        output_of_layer = {}
-        pgrads = []
+        output_of_layer = {}  # outputs generated by each layer
+        ret = {}  # outputs to return
+        if type(dy) is dict:
+            input_of_layer = dy
+        else:
+            assert isinstance(dy, tensor.Tensor), \
+                'The inputs of a net should be dict or a single tensor'
+            input_of_layer = {self.ordered_layers[-1].name: dy}
         for cur in reversed(self.ordered_layers):
+            inputs = []
+            if cur.name in input_of_layer:
+                if type(input_of_layer[cur.name]) is list:
+                    inputs.extend(input_of_layer[cur.name])
+                else:
+                    inputs.append(input_of_layer[cur.name])
             for dst in self.dst_of_layer[cur.name]:
                 outputs = output_of_layer[dst.name]
                 if type(outputs) == list:
                     assert len(outputs) > 0, \
                             'the gradient from layer %s is empty' % dst.name
-                    grads.append(outputs[0])
+                    inputs.append(outputs[0])
                     outputs.pop(0)
                 else:
-                    grads.append(outputs)
+                    inputs.append(outputs)
                     output_of_layer[dst.name] = []
                 # del output_of_layer[dst.name]
-            if len(grads) == 1:
-                grads = grads[0]
-            outs, _pgrads = cur.backward(kTrain, grads)
-            pgrads.append(_pgrads)
+            if len(inputs) == 1:
+                inputs = inputs[0]
+            outs, pgrads = cur.backward(kTrain, inputs)
             if verbose:
                 disp_src = '+'.join(
                         [dst.name for dst in self.dst_of_layer[cur.name]])
@@ -371,12 +396,10 @@ class FeedForwardNet(object):
                 output_of_layer[cur.name] = outs[::-1]
             else:
                 output_of_layer[cur.name] = outs
-            grads = []
-
-        ret = []
-        for pgrad in reversed(pgrads):
-            ret.extend(pgrad)
-        return ret
+            if cur.name in output:
+                ret[cur.name] = outs
+            # ret.update(output_of_layer)
+            yield (cur.param_names(), cur.param_values(), pgrads)
 
     def save(self, f, buffer_size=10, use_pickle=False):
         '''Save model parameters using io/snapshot.
@@ -391,7 +414,7 @@ class FeedForwardNet(object):
         '''
         if use_pickle:
             params = {}
-            # since SINGA>=1.1.1
+            # since SINGA>=1.1.1  (1101)
             params['SINGA_VERSION'] = __version__
             for (name, val) in zip(self.param_names(), self.param_values()):
                 val.to_host()
@@ -416,10 +439,10 @@ class FeedForwardNet(object):
         version = 0
 
         def get_name(name):
-            if version < 1011:
+            if version < 1101:
                 idx = name.rfind('/')
                 assert idx > 0, '/ must be in the parameter name'
-                name = name[:idx-1] + '_' + name[idx:]
+                name = name[:idx] + '_' + name[idx+1:]
             return name
 
         if use_pickle:
@@ -442,7 +465,6 @@ class FeedForwardNet(object):
             sp = snapshot.Snapshot(f, False, buffer_size)
             params = sp.read()
         if 'SINGA_VERSION' in params:
-            # for SINGA >= 1.1.1
             version = params['SINGA_VERSION']
         for name, val in zip(self.param_names(), self.param_values()):
             name = get_name(name)

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/ea078dca/test/python/test_net.py
----------------------------------------------------------------------
diff --git a/test/python/test_net.py b/test/python/test_net.py
index 50b976c..aad9b12 100644
--- a/test/python/test_net.py
+++ b/test/python/test_net.py
@@ -74,7 +74,7 @@ class TestFeedForwardNet(unittest.TestCase):
         out = tensor.to_numpy(out['split1'])
         self.assertAlmostEqual(np.average(out), 2)
 
-    def test_save(self):
+    def test_save_load(self):
         ffn = net.FeedForwardNet(loss.SoftmaxCrossEntropy())
         ffn.add(layer.Conv2D('conv', 4, 3, input_sample_shape=(3, 12, 12)))
         ffn.add(layer.Flatten('flat'))
@@ -88,6 +88,28 @@ class TestFeedForwardNet(unittest.TestCase):
         ffn.load('test_snaphost')
         ffn.load('test_pickle', use_pickle=True)
 
+    def test_train_one_batch(self):
+        ffn = net.FeedForwardNet(loss.SoftmaxCrossEntropy())
+        ffn.add(layer.Conv2D('conv', 4, 3, input_sample_shape=(3, 12, 12)))
+        ffn.add(layer.Flatten('flat'))
+        ffn.add(layer.Dense('dense', num_output=4))
+        for pname, pval in zip(ffn.param_names(), ffn.param_values()):
+            pval.set_value(0.1)
+        x = tensor.Tensor((4, 3, 12, 12))
+        x.gaussian(0, 0.01)
+        y = np.asarray([[1, 0, 0],
+                        [0, 0, 1],
+                        [0, 0, 1],
+                        [0, 1, 0]], dtype=np.int32)
+        y = tensor.from_numpy(y)
+        o = ffn.forward(True, x)
+        ffn.loss.forward(True, o, y)
+        g = ffn.loss.backward()
+        for pname, pvalue, pgrad in ffn.backward(g):
+            self.assertEqual(len(pvalue), len(pgrad))
+            for p, g in zip(pvalue, pgrad):
+                self.assertEqual(p.size(), g.size())
+
 
 if __name__ == '__main__':
     unittest.main()