You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by sa...@apache.org on 2020/12/15 18:15:28 UTC

[incubator-mxnet] branch v1.x updated: [v1.x] Save/Load Gluon Blocks & HybridBlocks (#19565)

This is an automated email from the ASF dual-hosted git repository.

samskalicky pushed a commit to branch v1.x
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/v1.x by this push:
     new b5fd18a  [v1.x] Save/Load Gluon Blocks & HybridBlocks (#19565)
b5fd18a is described below

commit b5fd18a139649e411bce19fdad552ef60bc6d2c9
Author: Sam Skalicky <sa...@gmail.com>
AuthorDate: Tue Dec 15 10:12:40 2020 -0800

    [v1.x] Save/Load Gluon Blocks & HybridBlocks (#19565)
    
    * initial commit
    
    * small tweaks
    
    * renamed load_json to fromjson
    
    * fixed fromjson
    
    * fixed sanity
    
    * changed tests data to zeros
    
    * undo renaming in quantization
    
    * fix indent
    
    * changed to with open
    
    * undo load_json -> fromjson change
---
 python/mxnet/gluon/block.py              | 145 ++++++++++++++++++++++++++++++-
 tests/python/unittest/test_gluon_save.py |  64 ++++++++++++++
 2 files changed, 208 insertions(+), 1 deletion(-)

diff --git a/python/mxnet/gluon/block.py b/python/mxnet/gluon/block.py
index 4ab7290..d415c5f 100644
--- a/python/mxnet/gluon/block.py
+++ b/python/mxnet/gluon/block.py
@@ -24,12 +24,13 @@ import threading
 import copy
 import warnings
 import re
+import json
 from collections import OrderedDict, defaultdict
 import numpy as np
 
 from ..base import mx_real_t, MXNetError
 from .. import symbol, ndarray, initializer, np_symbol
-from ..symbol import Symbol
+from ..symbol import Symbol, load_json
 from ..ndarray import NDArray
 from .. import name as _name
 from .parameter import Parameter, ParameterDict, DeferredInitializationError
@@ -661,6 +662,158 @@ class Block(object):
         for cld in self._children.values():
             cld.hybridize(active, **kwargs)
 
+    def save(self, prefix):
+        """Save the model architecture and parameters to load again later
+
+        Saves the model architecture as a nested dictionary where each Block
+        in the model is a dictionary and its children are sub-dictionaries.
+
+        Each Block is uniquely identified by Block class name and a unique ID.
+        Each Block's name and prefix, and the name its parent uses for it, are
+        saved so they can be restored later in order to match the saved
+        parameters.
+
+        Recursively traverses a Block's children in order (since it's an
+        OrderedDict) and uses the unique ID to denote that specific Block.
+        Assumes that the model is created in an identical order every time.
+        If the model is not able to be recreated deterministically do not
+        use this set of APIs to save/load your model.
+
+        For HybridBlocks, the cached_graph (Symbol & inputs) is saved if
+        it has already been hybridized.
+
+        Parameters
+        ----------
+        prefix : str
+            The prefix to use in filenames for saving this model:
+            <prefix>-model.json and <prefix>-model.params
+        """
+        def _save_cached_graphs(blk, index, structure):
+            # create new entry for this block; the prefix is recorded as well
+            # because it is not always name + '_' (e.g. a user-given prefix)
+            mdl = {'orig_name': blk.name, 'orig_prefix': blk.prefix}
+            # encode unique name based on block type and ID
+            name = type(blk).__name__.lower()
+            structure[name+str(index[0])] = mdl
+            if isinstance(blk, HybridBlock):
+                if blk._cached_graph:
+                    # save in/out formats
+                    mdl['in_format'] = blk._in_format
+                    mdl['out_format'] = blk._out_format
+                    # save cached graph & input symbols
+                    syms, out = blk._cached_graph
+                    mdl['inputs'] = [sym.tojson() for sym in syms]
+                    mdl['symbol'] = out.tojson()
+                    mdl['hybridized'] = True
+                else:
+                    mdl['hybridized'] = False
+            children = dict()
+            mdl['children'] = children
+            # recursively save children; index[0] is a shared counter, so IDs
+            # are assigned in depth-first pre-order
+            for ch_name, child in blk._children.items():
+                index[0] += 1
+                # save child's original name in this block's map
+                children[child.name] = ch_name
+                _save_cached_graphs(child, index, mdl)
+        # build the model structure from the top-level block (ID 0) and save it
+        model = {}
+        _save_cached_graphs(self, [0], model)
+        with open(prefix+'-model.json', 'w') as fp:
+            json.dump(model, fp)
+        # save params
+        self.save_parameters(prefix+'-model.params')
+
+    def load(self, prefix):
+        """Load a model saved using the `save` API
+
+        Reconfigures a model using the saved configuration. This function
+        does not regenerate the model architecture. It resets each Block's
+        name and prefix, and the name its parent uses for it, to what they
+        were when saved in order to match the names of the saved parameters.
+
+        This function assumes the Blocks in the model were created in the same
+        order they were when the model was saved. This is because each Block is
+        uniquely identified by Block class name and a unique ID in order (since
+        it's an OrderedDict) and uses the unique ID to denote that specific Block.
+        Assumes that the model is created in an identical order every time.
+        If the model is not able to be recreated deterministically do not
+        use this set of APIs to save/load your model.
+
+        For HybridBlocks, the cached_graph (Symbol & inputs) and settings are
+        restored if it had been hybridized before saving.
+
+        Parameters
+        ----------
+        prefix : str
+            The prefix to use in filenames for loading this model:
+            <prefix>-model.json and <prefix>-model.params
+
+        Raises
+        ------
+        KeyError
+            If a Block in this model has no matching entry in the saved
+            structure (e.g. the model was built in a different order).
+        """
+        # load model json from file
+        with open(prefix+'-model.json') as fp:
+            model = json.load(fp)
+
+        def _load_cached_graphs(blk, index, structure):
+            # lookup previous encoded name based on block type and ID
+            key = type(blk).__name__.lower() + str(index[0])
+            if key not in structure:
+                raise KeyError('Block "%s" (ID %s) has no entry in the saved model; '
+                               'create the model in the same order as when it was saved'
+                               % (blk.name, key))
+            mdl = structure[key]
+            # rename block to what it was when saved; files saved without
+            # 'orig_prefix' fall back to the previous '<name>_' convention
+            blk._name = mdl['orig_name']
+            orig_prefix = mdl.get('orig_prefix', blk._name + '_')
+            old_prefix = blk.params._prefix
+            blk._prefix = orig_prefix
+            if isinstance(blk, HybridBlock) and mdl['hybridized']:
+                # restore in/out formats
+                blk._in_format = mdl['in_format']
+                blk._out_format = mdl['out_format']
+                # recreate the saved input symbols and output symbol
+                syms = [load_json(inp) for inp in mdl['inputs']]
+                out = load_json(mdl['symbol'])
+                # reset cached_graph and active status; drop any CachedOp built
+                # from a previous graph so the next call rebuilds it from this one
+                blk._cached_graph = (syms, out)
+                blk._cached_op = None
+                blk._active = True
+            # re-key params with the restored prefix so they match the saved
+            # symbol; keys not starting with the current prefix are kept as-is
+            # and the dict is rebuilt in place so parameter order is preserved
+            params = blk.params._params
+            renamed = []
+            for pname, param in params.items():
+                if pname.startswith(old_prefix):
+                    pname = orig_prefix + pname[len(old_prefix):]
+                renamed.append((pname, param))
+            params.clear()
+            params.update(renamed)
+            blk.params._prefix = orig_prefix
+            # recursively reload children (same traversal order as save)
+            for child in blk._children.values():
+                index[0] += 1
+                _load_cached_graphs(child, index, mdl)
+            # re-key children with the names the parent used when saved; the
+            # dict is rebuilt in one step so a restored name that equals a
+            # not-yet-processed current name cannot overwrite that sibling
+            children = mdl['children']
+            remapped = [(children[child.name], child)
+                        for child in blk._children.values()]
+            blk._children.clear()
+            blk._children.update(remapped)
+        # load top-level block (ID 0)
+        _load_cached_graphs(self, [0], model)
+        # load params
+        self.load_parameters(prefix+'-model.params')
+
     def cast(self, dtype):
         """Cast this Block to use another data type.
 
diff --git a/tests/python/unittest/test_gluon_save.py b/tests/python/unittest/test_gluon_save.py
new file mode 100644
index 0000000..95ae7d9
--- /dev/null
+++ b/tests/python/unittest/test_gluon_save.py
@@ -0,0 +1,78 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import os
+import tempfile
+
+import mxnet as mx
+from common import with_seed
+
+@with_seed()
+def test_save():
+    class MyBlock(mx.gluon.nn.Block):
+        def __init__(self, **kwargs):
+            super(MyBlock, self).__init__(**kwargs)
+        def add(self, block):
+            self._children[block.name + str(len(self._children))] = block
+        def forward(self, x, *args):
+            out = (x,) + args
+            for block in self._children.values():
+                out = block(*out)
+            return out
+
+    def createNet():
+        inside = MyBlock()
+        dense = mx.gluon.nn.Dense(10)
+        inside.add(dense)
+        net = MyBlock()
+        net.add(inside)
+        net.add(mx.gluon.nn.Dense(10))
+        return net
+
+    # create and initialize model
+    net1 = createNet()
+    net1.initialize()
+    # hybridize (the hybridizeable blocks, ie. the Dense layers)
+    net1.hybridize()
+    x = mx.nd.zeros((1,10))
+    out1 = net1(x)
+
+    # save/load through a temporary directory so no files are left behind
+    with tempfile.TemporaryDirectory() as tmpdir:
+        prefix = os.path.join(tmpdir, 'MyModel')
+        # save hybridized model
+        net1.save(prefix)
+
+        # create a new model, uninitialized
+        net2 = createNet()
+        # reload hybridized model
+        net2.load(prefix)
+
+    # an all-zero input makes the output independent of the first Dense
+    # layer's weight, so also compare the restored parameters directly
+    params1 = list(net1.collect_params().values())
+    params2 = list(net2.collect_params().values())
+    assert len(params1) == len(params2)
+    for p1, p2 in zip(params1, params2):
+        mx.test_utils.assert_almost_equal(p1.data().asnumpy(), p2.data().asnumpy())
+    # run inference again
+    out2 = net2(x)
+    mx.test_utils.assert_almost_equal(out1.asnumpy(), out2.asnumpy())
+
+if __name__ == '__main__':
+    import nose
+    nose.runmodule()