You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by sa...@apache.org on 2020/12/15 18:15:28 UTC
[incubator-mxnet] branch v1.x updated: [v1.x] Save/Load Gluon
Blocks & HybridBlocks (#19565)
This is an automated email from the ASF dual-hosted git repository.
samskalicky pushed a commit to branch v1.x
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/v1.x by this push:
new b5fd18a [v1.x] Save/Load Gluon Blocks & HybridBlocks (#19565)
b5fd18a is described below
commit b5fd18a139649e411bce19fdad552ef60bc6d2c9
Author: Sam Skalicky <sa...@gmail.com>
AuthorDate: Tue Dec 15 10:12:40 2020 -0800
[v1.x] Save/Load Gluon Blocks & HybridBlocks (#19565)
* initial commit
* small tweaks
* renamed load_json to fromjson
* fixed fromjson
* fixed sanity
* changed tests data to zeros
* undo renaming in quantization
* fix indent
* changed to with open
* undo load_json -> fromjson change
---
python/mxnet/gluon/block.py | 145 ++++++++++++++++++++++++++++++-
tests/python/unittest/test_gluon_save.py | 64 ++++++++++++++
2 files changed, 208 insertions(+), 1 deletion(-)
diff --git a/python/mxnet/gluon/block.py b/python/mxnet/gluon/block.py
index 4ab7290..d415c5f 100644
--- a/python/mxnet/gluon/block.py
+++ b/python/mxnet/gluon/block.py
@@ -24,12 +24,13 @@ import threading
import copy
import warnings
import re
+import json
from collections import OrderedDict, defaultdict
import numpy as np
from ..base import mx_real_t, MXNetError
from .. import symbol, ndarray, initializer, np_symbol
-from ..symbol import Symbol
+from ..symbol import Symbol, load_json
from ..ndarray import NDArray
from .. import name as _name
from .parameter import Parameter, ParameterDict, DeferredInitializationError
@@ -661,6 +662,148 @@ class Block(object):
for cld in self._children.values():
cld.hybridize(active, **kwargs)
+ def save(self, prefix):
+ """Save the model architecture and parameters to load again later
+
+ Saves the model architecture as a nested dictionary where each Block
+ in the model is a dictionary and its children are sub-dictionaries.
+
+ Each Block is uniquely identified by Block class name and a unique ID.
+ We save the child's name that that parent uses for it to restore later
+ in order to match the saved parameters.
+
+ Recursively traverses a Block's children in order (since its an
+ OrderedDict) and uses the unique ID to denote that specific Block.
+ Assumes that the model is created in an identical order every time.
+ If the model is not able to be recreated deterministically do not
+ use this set of APIs to save/load your model.
+
+ For HybridBlocks, the cached_graph (Symbol & inputs) is saved if
+ it has already been hybridized.
+
+ Parameters
+ ----------
+ prefix : str
+ The prefix to use in filenames for saving this model:
+ <prefix>-model.json and <prefix>-model.params
+ """
+ # create empty model structure
+ model = {}
+ def _save_cached_graphs(blk, index, structure):
+ # create new entry for this block
+ mdl = {'orig_name': blk.name}
+ # encode unique name based on block type and ID
+ name = type(blk).__name__.lower()
+ structure[name+str(index[0])] = mdl
+ if isinstance(blk, HybridBlock):
+ if blk._cached_graph:
+ # save in/out formats
+ mdl['in_format'] = blk._in_format
+ mdl['out_format'] = blk._out_format
+ # save cached graph & input symbols
+ syms, out = blk._cached_graph
+ mdl_syms = []
+ for sym in syms:
+ mdl_syms.append(sym.tojson())
+ mdl['inputs'] = mdl_syms
+ mdl['symbol'] = out.tojson()
+ mdl['hybridized'] = True
+ else:
+ mdl['hybridized'] = False
+ children = dict()
+ mdl['children'] = children
+ # recursively save children
+ for ch_name, child in blk._children.items():
+ index[0] += 1
+ # save child's original name in this block's map
+ children[child.name] = ch_name
+ _save_cached_graphs(child, index, mdl)
+ # save top-level block
+ index = [0]
+ _save_cached_graphs(self, index, model)
+ # save model
+ with open(prefix+'-model.json', 'w') as fp:
+ json.dump(model, fp)
+ # save params
+ self.save_parameters(prefix+'-model.params')
+
+ def load(self, prefix):
+ """Load a model saved using the `save` API
+
+ Reconfigures a model using the saved configuration. This function
+ does not regenerate the model architecture. It resets the children's
+ names as they were when saved in order to match the names of the
+ saved parameters.
+
+ This function assumes the Blocks in the model were created in the same
+ order they were when the model was saved. This is because each Block is
+ uniquely identified by Block class name and a unique ID in order (since
+ its an OrderedDict) and uses the unique ID to denote that specific Block.
+ Assumes that the model is created in an identical order every time.
+ If the model is not able to be recreated deterministically do not
+ use this set of APIs to save/load your model.
+
+ For HybridBlocks, the cached_graph (Symbol & inputs) and settings are
+ restored if it had been hybridized before saving.
+
+ Parameters
+ ----------
+ prefix : str
+ The prefix to use in filenames for loading this model:
+ <prefix>-model.json and <prefix>-model.params
+ """
+ # load model json from file
+ with open(prefix+'-model.json') as fp:
+ model = json.load(fp)
+
+ def _load_cached_graphs(blk, index, structure):
+ # get block name
+ name = type(blk).__name__.lower()
+ # lookup previous encoded name based on block type and ID
+ mdl = structure[name+str(index[0])]
+ # rename block to what it was when saved
+ blk._name = mdl['orig_name']
+ if isinstance(blk, HybridBlock):
+ if mdl['hybridized']:
+ # restore in/out formats
+ blk._in_format = mdl['in_format']
+ blk._out_format = mdl['out_format']
+ # get saved symbol
+ out = load_json(mdl['symbol'])
+ syms = []
+ # recreate inputs for this symbol
+ for inp in mdl['inputs']:
+ syms.append(load_json(inp))
+ # reset cached_graph and active status
+ blk._cached_graph = (syms, out)
+ blk._active = True
+ # rename params with updated block name
+ pnames = list(blk.params.keys())
+ for p in pnames:
+ param = blk.params._params[p]
+ new_name = blk.name +'_'+ p[len(blk.params._prefix):]
+ blk.params._params.pop(p)
+ blk.params._params[new_name] = param
+ # recursively reload children
+ for ch_name, child in blk._children.items():
+ index[0] += 1
+ _load_cached_graphs(child, index, mdl)
+ # current set of child names
+ ch_names = list(blk._children.keys())
+ # original child names
+ children = mdl['children']
+ # loop and remap children with original names
+ for ch_name in ch_names:
+ child = blk._children[ch_name]
+ blk._children.pop(ch_name)
+ orig_name = children[child.name]
+ blk._children[orig_name] = child
+ # load top-level block
+ index = [0]
+ _load_cached_graphs(self, index, model)
+ # load params
+ self.load_parameters(prefix+'-model.params')
+
def cast(self, dtype):
"""Cast this Block to use another data type.
diff --git a/tests/python/unittest/test_gluon_save.py b/tests/python/unittest/test_gluon_save.py
new file mode 100644
index 0000000..95ae7d9
--- /dev/null
+++ b/tests/python/unittest/test_gluon_save.py
@@ -0,0 +1,64 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import mxnet as mx
+from common import with_seed
+
+@with_seed()
+def test_save():
+ class MyBlock(mx.gluon.nn.Block):
+ def __init__(self, **kwargs):
+ super(MyBlock, self).__init__(**kwargs)
+ def add(self, block):
+ self._children[block.name + str(len(self._children))] = block
+ def forward(self, x, *args):
+ out = (x,) + args
+ for block in self._children.values():
+ out = block(*out)
+ return out
+
+ def createNet():
+ inside = MyBlock()
+ dense = mx.gluon.nn.Dense(10)
+ inside.add(dense)
+ net = MyBlock()
+ net.add(inside)
+ net.add(mx.gluon.nn.Dense(10))
+ return net
+
+ # create and initialize model
+ net1 = createNet()
+ net1.initialize()
+ # hybridize (the hybridizeable blocks, ie. the Dense layers)
+ net1.hybridize()
+ x = mx.nd.zeros((1,10))
+ out1 = net1(x)
+
+ # save hybridized model
+ net1.save('MyModel')
+
+ # create a new model, uninitialized
+ net2 = createNet()
+ # reload hybridized model
+ net2.load('MyModel')
+ # run inference again
+ out2 = net2(x)
+ mx.test_utils.assert_almost_equal(out1.asnumpy(), out2.asnumpy())
+
+if __name__ == '__main__':
+ import nose
+ nose.runmodule()