You are viewing a plain-text version of this content. The canonical link for it was a hyperlink in the original page and is not preserved in this plain-text copy.
Posted to commits@mxnet.apache.org by ma...@apache.org on 2021/04/29 20:14:23 UTC

[incubator-mxnet] branch v1.8.x updated: Reuse params from cached_op_args (#20221) (#20229)

This is an automated email from the ASF dual-hosted git repository.

manuseth pushed a commit to branch v1.8.x
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/v1.8.x by this push:
     new e8a0a93  Reuse params from cached_op_args (#20221) (#20229)
e8a0a93 is described below

commit e8a0a93b500caa909874d341138715d21add0130
Author: Sam Skalicky <sa...@gmail.com>
AuthorDate: Thu Apr 29 13:10:53 2021 -0700

    Reuse params from cached_op_args (#20221) (#20229)
    
    * initial commit
    
    * fixed handling
    
    * fixed formatting
    
    Co-authored-by: Ubuntu <ub...@ip-172-31-6-220.us-west-2.compute.internal>
    
    Co-authored-by: Ubuntu <ub...@ip-172-31-6-220.us-west-2.compute.internal>
---
 python/mxnet/gluon/block.py | 20 +++++++++++++++-----
 1 file changed, 15 insertions(+), 5 deletions(-)

diff --git a/python/mxnet/gluon/block.py b/python/mxnet/gluon/block.py
index edd3372..561762e 100644
--- a/python/mxnet/gluon/block.py
+++ b/python/mxnet/gluon/block.py
@@ -872,6 +872,7 @@ class HybridBlock(Block):
         super(HybridBlock, self).__init__(prefix=prefix, params=params)
         self._cached_graph = ()
         self._cached_op = None
+        self._cached_op_args = []
         self._out_format = None
         self._in_format = None
         self._active = False
@@ -923,10 +924,17 @@ class HybridBlock(Block):
     def _build_cache(self, *args):
         data, out = self._get_graph(*args)
         data_names = {data.name: i for i, data in enumerate(data)}
-        params = self.collect_params()
         input_names = out.list_inputs()
-        param_names = set(params.keys())
         expected_names = set(input_names)
+
+        # try to reuse cached_op_args for params
+        if len(self._cached_op_args) > 0:
+            params = {param_tuple[1].name:param_tuple[1]
+                      for param_tuple in self._cached_op_args
+                      if isinstance(param_tuple[1], Parameter)}
+        else:
+            params = self.collect_params()
+        param_names = set(params.keys())
         for name in expected_names:
             assert name in param_names or name in data_names, \
                 "Unknown input to HybridBlock: %s" %name
@@ -1133,10 +1141,11 @@ class HybridBlock(Block):
         """
         if len(kwargs) > 0:
             self._backend_opts = kwargs
+        if not backend:
+            raise ValueError('Must specify "backend" to optimize_for')
 
-        if clear or not self._active:
-            self.hybridize(True, backend, clear, static_alloc, static_shape,
-                           inline_limit, forward_bulk_size, backward_bulk_size)
+        self.hybridize(True, backend, clear, static_alloc, static_shape,
+                       inline_limit, forward_bulk_size, backward_bulk_size)
 
         # do part of forward API call
         has_symbol, has_ndarray, ctx_set, _ = _gather_type_ctx_info([x] + list(args))
@@ -1160,6 +1169,7 @@ class HybridBlock(Block):
     def _clear_cached_op(self):
         self._cached_graph = ()
         self._cached_op = None
+        self._cached_op_args = []
 
     def register_child(self, block, name=None):
         if not isinstance(block, HybridBlock):