You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2020/12/10 22:42:33 UTC

[GitHub] [incubator-mxnet] leezu commented on a change in pull request #19564: Save/Load Gluon Blocks & HybridBlocks

leezu commented on a change in pull request #19564:
URL: https://github.com/apache/incubator-mxnet/pull/19564#discussion_r540550518



##########
File path: python/mxnet/gluon/block.py
##########
@@ -571,6 +572,140 @@ def initialize(self, init=initializer.Uniform(), ctx=None, verbose=False,
         for v in params.values():
             v.initialize(None, ctx, init, force_reinit=force_reinit)
 
+    def save(self, prefix):
+        """Save the model architecture and parameters to load again later
+
+        Saves the model architecture as a nested dictionary where each Block
+        in the model is a dictionary and its children are sub-dictionaries.
+
+        Each Block is uniquely identified by Block class name and a unique ID.
+        We save each Block's parameter UUID to restore later in order to match
+        the saved parameters.
+
+        Recursively traverses a Block's children in order (since its an
+        OrderedDict) and uses the unique ID to denote that specific Block.
+
+        Assumes that the model is created in an identical order every time.
+        If the model is not able to be recreated deterministically do not
+        use this set of APIs to save/load your model.
+
+        For HybridBlocks, the cached_graph is saved (Symbol & inputs) if
+        it has already been hybridized.
+
+        Parameters
+        ----------
+        prefix : str
+            The prefix to use in filenames for saving this model:
+            <prefix>-model.json and <prefix>-model.params
+        """
+        # create empty model structure
+        model = {}
+        def _save_cached_graphs(blk, index, structure):
+            # create new entry for this block
+            mdl = {}
+            # encode unique name based on block type and ID
+            name = type(blk).__name__.lower()
+            structure[name+str(index[0])] = mdl
+            if isinstance(blk, HybridBlock):
+                if blk._cached_graph:
+                    # save in/out formats
+                    mdl['in_format'] = blk._in_format
+                    mdl['out_format'] = blk._out_format
+                    # save cached graph & input symbols
+                    syms, out = blk._cached_graph
+                    mdl_syms = []
+                    for sym in syms:
+                        mdl_syms.append(sym.tojson())
+                    mdl['inputs'] = mdl_syms
+                    mdl['symbol'] = out.tojson()
+                    mdl['hybridized'] = True
+                else:
+                    mdl['hybridized'] = False
+            # save param uuids
+            pmap = {}
+            mdl['params'] = pmap
+            pnames = list(blk.params.keys())
+            for p in pnames:
+                param = blk.params[p]
+                pmap[p] = param._uuid
+            # recursively save children
+            for child in blk._children.values():
+                index[0] += 1
+                _save_cached_graphs(child(), index, mdl)
+        # save top-level block
+        index = [0]

Review comment:
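       Why is `index` a one-element list rather than a plain integer? If it only serves as a mutable counter shared across the recursive `_save_cached_graphs` calls, please say so in a comment (or use a `nonlocal` counter instead).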

##########
File path: python/mxnet/gluon/block.py
##########
@@ -571,6 +572,140 @@ def initialize(self, init=initializer.Uniform(), ctx=None, verbose=False,
         for v in params.values():
             v.initialize(None, ctx, init, force_reinit=force_reinit)
 
+    def save(self, prefix):
+        """Save the model architecture and parameters to load again later
+
+        Saves the model architecture as a nested dictionary where each Block
+        in the model is a dictionary and its children are sub-dictionaries.
+
+        Each Block is uniquely identified by Block class name and a unique ID.
+        We save each Block's parameter UUID to restore later in order to match
+        the saved parameters.
+
+        Recursively traverses a Block's children in order (since its an
+        OrderedDict) and uses the unique ID to denote that specific Block.
+
+        Assumes that the model is created in an identical order every time.
+        If the model is not able to be recreated deterministically do not
+        use this set of APIs to save/load your model.
+
+        For HybridBlocks, the cached_graph is saved (Symbol & inputs) if
+        it has already been hybridized.
+
+        Parameters
+        ----------
+        prefix : str
+            The prefix to use in filenames for saving this model:
+            <prefix>-model.json and <prefix>-model.params
+        """
+        # create empty model structure
+        model = {}
+        def _save_cached_graphs(blk, index, structure):
+            # create new entry for this block
+            mdl = {}
+            # encode unique name based on block type and ID
+            name = type(blk).__name__.lower()
+            structure[name+str(index[0])] = mdl
+            if isinstance(blk, HybridBlock):
+                if blk._cached_graph:
+                    # save in/out formats
+                    mdl['in_format'] = blk._in_format
+                    mdl['out_format'] = blk._out_format
+                    # save cached graph & input symbols
+                    syms, out = blk._cached_graph
+                    mdl_syms = []
+                    for sym in syms:
+                        mdl_syms.append(sym.tojson())
+                    mdl['inputs'] = mdl_syms
+                    mdl['symbol'] = out.tojson()
+                    mdl['hybridized'] = True
+                else:
+                    mdl['hybridized'] = False
+            # save param uuids
+            pmap = {}
+            mdl['params'] = pmap
+            pnames = list(blk.params.keys())
+            for p in pnames:
+                param = blk.params[p]
+                pmap[p] = param._uuid
+            # recursively save children
+            for child in blk._children.values():
+                index[0] += 1
+                _save_cached_graphs(child(), index, mdl)
+        # save top-level block
+        index = [0]
+        _save_cached_graphs(self, index, model)
+        # save model
+        fp = open(prefix+'-model.json', 'w')
+        json.dump(model, fp)
+        fp.close()

Review comment:
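       Prefer `with open(...) as fp:` here: the context manager closes the file even if `json.dump` raises, so the explicit `fp.close()` is no longer needed.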




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org