You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by ma...@apache.org on 2021/04/29 20:14:23 UTC
[incubator-mxnet] branch v1.8.x updated: Reuse params from cached_op_args (#20221) (#20229)
This is an automated email from the ASF dual-hosted git repository.
manuseth pushed a commit to branch v1.8.x
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/v1.8.x by this push:
new e8a0a93 Reuse params from cached_op_args (#20221) (#20229)
e8a0a93 is described below
commit e8a0a93b500caa909874d341138715d21add0130
Author: Sam Skalicky <sa...@gmail.com>
AuthorDate: Thu Apr 29 13:10:53 2021 -0700
Reuse params from cached_op_args (#20221) (#20229)
* initial commit
* fixed handling
* fixed formatting
Co-authored-by: Ubuntu <ub...@ip-172-31-6-220.us-west-2.compute.internal>
Co-authored-by: Ubuntu <ub...@ip-172-31-6-220.us-west-2.compute.internal>
---
python/mxnet/gluon/block.py | 20 +++++++++++++++-----
1 file changed, 15 insertions(+), 5 deletions(-)
diff --git a/python/mxnet/gluon/block.py b/python/mxnet/gluon/block.py
index edd3372..561762e 100644
--- a/python/mxnet/gluon/block.py
+++ b/python/mxnet/gluon/block.py
@@ -872,6 +872,7 @@ class HybridBlock(Block):
super(HybridBlock, self).__init__(prefix=prefix, params=params)
self._cached_graph = ()
self._cached_op = None
+ self._cached_op_args = []
self._out_format = None
self._in_format = None
self._active = False
@@ -923,10 +924,17 @@ class HybridBlock(Block):
def _build_cache(self, *args):
data, out = self._get_graph(*args)
data_names = {data.name: i for i, data in enumerate(data)}
- params = self.collect_params()
input_names = out.list_inputs()
- param_names = set(params.keys())
expected_names = set(input_names)
+
+ # try to reuse cached_op_args for params
+ if len(self._cached_op_args) > 0:
+ params = {param_tuple[1].name:param_tuple[1]
+ for param_tuple in self._cached_op_args
+ if isinstance(param_tuple[1], Parameter)}
+ else:
+ params = self.collect_params()
+ param_names = set(params.keys())
for name in expected_names:
assert name in param_names or name in data_names, \
"Unknown input to HybridBlock: %s" %name
@@ -1133,10 +1141,11 @@ class HybridBlock(Block):
"""
if len(kwargs) > 0:
self._backend_opts = kwargs
+ if not backend:
+ raise ValueError('Must specify "backend" to optimize_for')
- if clear or not self._active:
- self.hybridize(True, backend, clear, static_alloc, static_shape,
- inline_limit, forward_bulk_size, backward_bulk_size)
+ self.hybridize(True, backend, clear, static_alloc, static_shape,
+ inline_limit, forward_bulk_size, backward_bulk_size)
# do part of forward API call
has_symbol, has_ndarray, ctx_set, _ = _gather_type_ctx_info([x] + list(args))
@@ -1160,6 +1169,7 @@ class HybridBlock(Block):
def _clear_cached_op(self):
self._cached_graph = ()
self._cached_op = None
+ self._cached_op_args = []
def register_child(self, block, name=None):
if not isinstance(block, HybridBlock):