Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2018/01/04 19:15:06 UTC

[GitHub] piiswrong closed pull request #9290: add epoch parameter to export() of HybridBlock

URL: https://github.com/apache/incubator-mxnet/pull/9290

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:

diff --git a/python/mxnet/gluon/block.py b/python/mxnet/gluon/block.py
index 37734ac389..0d49def8cf 100644
--- a/python/mxnet/gluon/block.py
+++ b/python/mxnet/gluon/block.py
@@ -470,7 +470,7 @@ def infer_type(self, *args):
         """Infers data type of Parameters from inputs."""
         self._infer_attrs('infer_type', 'dtype', *args)
 
-    def export(self, path):
+    def export(self, path, epoch=0):
         """Export HybridBlock to json format that can be loaded by `mxnet.mod.Module`
         or the C++ interface.
 
@@ -480,8 +480,10 @@ def export(self, path):
         Parameters
         ----------
         path : str
-            Path to save model. Two files `path-symbol.json` and `path-0000.params`
-            will be created.
+            Path to save model. Two files `path-symbol.json` and `path-xxxx.params`
+            will be created, where xxxx is the 4-digit epoch number.
+        epoch : int
+            Epoch number of saved model.
         """
         if not self._cached_graph:
             raise RuntimeError(
@@ -499,7 +501,7 @@ def export(self, path):
             else:
                 assert name in aux_names
                 arg_dict['aux:%s'%name] = param._reduce()
-        ndarray.save('%s-0000.params'%path, arg_dict)
+        ndarray.save('%s-%04d.params'%(path, epoch), arg_dict)
 
     def forward(self, x, *args):
         """Defines the forward computation. Arguments can be either
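The effect of the patch on file naming can be illustrated with a small, standalone Python sketch. The helper names export_filenames and split_saved_params are hypothetical and are not part of MXNet; they only reproduce the format strings used in the patched export(). With the default epoch=0 the params file is still named path-0000.params, as before the change. The 'arg:' key prefix is an assumption here, since this excerpt of the diff only shows the 'aux:' branch. A file pair named this way should match what mxnet.model.load_checkpoint(prefix, epoch) looks for.

```python
# Hypothetical helpers (not MXNet APIs) that mirror the naming scheme of the
# patched HybridBlock.export(path, epoch): '<path>-symbol.json' for the graph
# and '<path>-%04d.params' % epoch for the parameters. No MXNet import needed.


def export_filenames(path, epoch=0):
    """Return (symbol_file, params_file) as export(path, epoch) would name them."""
    symbol_file = '%s-symbol.json' % path
    # %04d zero-pads to at least four digits; larger epochs are not truncated.
    params_file = '%s-%04d.params' % (path, epoch)
    return symbol_file, params_file


def split_saved_params(saved):
    """Split a dict keyed 'arg:<name>' / 'aux:<name>' into two plain dicts.

    Assumes export() prefixes arguments with 'arg:' and auxiliary states
    with 'aux:' (only the 'aux:' branch is visible in the diff excerpt).
    """
    arg_params, aux_params = {}, {}
    for key, value in saved.items():
        kind, name = key.split(':', 1)
        if kind == 'arg':
            arg_params[name] = value
        elif kind == 'aux':
            aux_params[name] = value
    return arg_params, aux_params


if __name__ == '__main__':
    print(export_filenames('model'))         # default epoch, same name as before the patch
    print(export_filenames('model', 7))      # zero-padded epoch number
    print(export_filenames('model', 12345))  # five digits, not truncated
```

Note that %04d pads but never truncates, so an epoch of 12345 produces model-12345.params rather than a four-digit name.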


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services