You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@singa.apache.org by zh...@apache.org on 2017/07/05 07:41:12 UTC

incubator-singa git commit: SINGA-329 - Support layer freezing during training (fine-tuning)

Repository: incubator-singa
Updated Branches:
  refs/heads/master 913417ad9 -> b6874d4f0


SINGA-329 - Support layer freezing during training (fine-tuning)

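Add a 'freeze' argument (a layer name) to the forward() and backward() functions of FeedForwardNet in net.py.
In forward(), every layer before the 'freeze' layer runs with flag set to False; the backward function stops back-propagation after the 'freeze' layer.
Also fix FeedForwardNet.save(), which assigned params['SINGA_VERSION'] before params was created; the assignment now happens inside the use_pickle branch after params is initialized.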


Project: http://git-wip-us.apache.org/repos/asf/incubator-singa/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-singa/commit/b6874d4f
Tree: http://git-wip-us.apache.org/repos/asf/incubator-singa/tree/b6874d4f
Diff: http://git-wip-us.apache.org/repos/asf/incubator-singa/diff/b6874d4f

Branch: refs/heads/master
Commit: b6874d4f0c368068ab1c7954a14e7590b1d5a53f
Parents: 913417a
Author: wangwei <wa...@comp.nus.edu.sg>
Authored: Wed Jul 5 13:42:56 2017 +0800
Committer: wangwei <wa...@comp.nus.edu.sg>
Committed: Wed Jul 5 13:42:56 2017 +0800

----------------------------------------------------------------------
 python/singa/net.py | 20 +++++++++++++++++---
 1 file changed, 17 insertions(+), 3 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/b6874d4f/python/singa/net.py
----------------------------------------------------------------------
diff --git a/python/singa/net.py b/python/singa/net.py
index 3bd960a..faaef2b 100644
--- a/python/singa/net.py
+++ b/python/singa/net.py
@@ -252,7 +252,7 @@ class FeedForwardNet(object):
                     order.append(lyr)
         return order
 
-    def forward(self, flag, x, output=[]):
+    def forward(self, flag, x, output=[], freeze=None):
         '''Forward the input(s) through every layer.
 
         Args:
@@ -267,6 +267,8 @@ class FeedForwardNet(object):
                 layer.
             output(list): a list of layer names whose output would be returned
                 in addition to the default output.
+            freeze(str): layer name, freeze all layers before this layer; flag
+                is set to false for these layers.
 
         Returns:
             if there is only one output layer and output arg is empty, return
@@ -283,7 +285,16 @@ class FeedForwardNet(object):
             input_of_layer = {self.ordered_layers[0].name: x}
         output_of_layer = {}  # outputs generated by each layer
         ret = {}  # outputs to return
+        if freeze is not None:
+            is_valid = False
+            for lyr in self.ordered_layers:
+                is_valid |= lyr.name == freeze
+            assert is_valid, 'Invalid freeze layer name =%s' % freeze
+            old_flag = flag
+            flag = False
         for cur in self.ordered_layers:
+            if cur.name == freeze:
+                flag = old_flag
             inputs = []
             if cur.name in input_of_layer:
                 if type(input_of_layer[cur.name]) is list:
@@ -327,7 +338,7 @@ class FeedForwardNet(object):
         else:
             return ret
 
-    def backward(self, dy, output=[]):
+    def backward(self, dy, output=[], freeze=None):
         '''Run back-propagation after forward-propagation.
 
         Args:
@@ -339,6 +350,7 @@ class FeedForwardNet(object):
                 dummy layer to accept the gradient.
             output(list): a list of layer names whose output gradient would be
                 returned in addition to the param gradient
+            freeze(str): layer name, stop backward after this layer.
 
         Returns:
                 a geneartor iterator that generates
@@ -402,6 +414,8 @@ class FeedForwardNet(object):
                 ret[cur.name] = outs
             # ret.update(output_of_layer)
             yield (cur.param_names(), cur.param_values(), pgrads, ret)
+            if cur.name == freeze:
+                break
 
     def save(self, f, buffer_size=10, use_pickle=False):
         '''Save model parameters using io/snapshot.
@@ -414,10 +428,10 @@ class FeedForwardNet(object):
                 otherwise, it would use protobuf for serialization, which uses
                 less space.
         '''
-        params['SINGA_VERSION'] = __version__
         if use_pickle:
             params = {}
             # since SINGA>=1.1.1  (1101)
+            params['SINGA_VERSION'] = __version__
             for (name, val) in zip(self.param_names(), self.param_values()):
                 val.to_host()
                 params[name] = tensor.to_numpy(val)