Posted to commits@mxnet.apache.org by pt...@apache.org on 2020/07/13 18:47:06 UTC
[incubator-mxnet] branch master updated: Partition API adding and
deleting new params to Block and Symbol (#18405)
This is an automated email from the ASF dual-hosted git repository.
ptrendx pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new 9c5b95a Partition API adding and deleting new params to Block and Symbol (#18405)
9c5b95a is described below
commit 9c5b95a9c5d6f83a067504fb47fac4e3aed27e81
Author: Serge Panev <sp...@nvidia.com>
AuthorDate: Mon Jul 13 11:45:29 2020 -0700
Partition API adding and deleting new params to Block and Symbol (#18405)
* Add deleting of args/aux to Partition API
Signed-off-by: Serge Panev <sp...@nvidia.com>
* Delete args from Block.params
Signed-off-by: Serge Panev <sp...@nvidia.com>
* Fix to use arg/auxdict when optimize_for is called in HybridBlock
Signed-off-by: Serge Panev <sp...@nvidia.com>
* Address PR comments
Signed-off-by: Serge Panev <sp...@nvidia.com>
---
python/mxnet/gluon/block.py | 105 ++++++++++++++++++++++++++++++------------
python/mxnet/symbol/symbol.py | 32 ++++++++++++-
2 files changed, 106 insertions(+), 31 deletions(-)
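For context before the diff: the Partition API lets an accelerator backend rewrite a model's graph during hybridization, and that rewrite may create new parameters or delete existing ones. This commit propagates those changes back to the Gluon Block and the Symbol. A minimal sketch of the Gluon entry point follows; the backend name 'myBackend' is hypothetical and stands in for any backend registered with the Partition API (the call fails if no such backend is registered):

    import mxnet as mx
    from mxnet.gluon import nn

    net = nn.Dense(10)
    net.initialize()
    # Partitioning runs on the first forward pass after hybridize(backend=...).
    net.hybridize(backend='myBackend')  # 'myBackend' is a hypothetical backend
    out = net(mx.nd.ones((1, 4)))

    # With this commit, the Block's cached-op arguments reflect any params the
    # backend added and omit any it deleted; the export path changed below
    # iterates those cached-op arguments instead of collect_params().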
diff --git a/python/mxnet/gluon/block.py b/python/mxnet/gluon/block.py
index 2fda080..1f9cd43 100644
--- a/python/mxnet/gluon/block.py
+++ b/python/mxnet/gluon/block.py
@@ -1040,41 +1040,69 @@ class HybridBlock(Block):
warnings.warn("Parameter %s is not used by any computation. "
"Is this intended?"%unused, stacklevel=4)
- data_indices = []
- param_indices = []
- self._cached_op_args = []
- for i, name in enumerate(input_names):
- if name in data_names:
- data_indices.append(i)
- self._cached_op_args.append((True, data_names[name]))
- else:
- param_indices.append(i)
- self._cached_op_args.append((False, params[name]))
- flags = [('data_indices', data_indices), ('param_indices', param_indices)] + \
- self._flags
-
args, _ = _flatten(args, "input")
try:
- for is_arg, i in self._cached_op_args:
- if not is_arg:
- i.data()
+ for name in input_names:
+ if name in params:
+ params[name].data()
except DeferredInitializationError:
self._deferred_infer_shape(*args)
- for is_arg, i in self._cached_op_args:
- if not is_arg:
- i._finish_deferred_init()
+ for name in input_names:
+ if name in params:
+ params[name]._finish_deferred_init()
+ arg_dict, aux_dict = dict(), dict()
if self._backend:
ctx = args[0].context
# get list of params in the order of out.list_arguments
- arg_dict = {name:args[data_names[name]] if name in data_names.keys() else params[name].data()
- for name in out.list_arguments()}
- aux_dict = {name:args[data_names[name]] if name in data_names.keys() else params[name].data()
- for name in out.list_auxiliary_states()}
+ arg_dict.update({name:args[data_names[name]] if name in data_names.keys() else params[name].data()
+ for name in out.list_arguments()})
+ aux_dict.update({name:args[data_names[name]] if name in data_names.keys() else params[name].data()
+ for name in out.list_auxiliary_states()})
# Partition the graph.
out = out.optimize_for(self._backend, arg_dict, aux_dict, ctx, **self._backend_opts)
+
# update cached graph with partitioned graph
self._cached_graph = data, out
+
+ input_names = out.list_inputs()
+ data_indices = []
+ param_indices = []
+
+ # In the default case, _cached_op_args contains all the parameters from params (the two sets are identical).
+ # In the case of a graph optimized by the Partition API, _cached_op_args may contain some parameters from params,
+ # may contain new parameters created during optimization and added to arg_dict/aux_dict,
+ # and may be missing parameters that were deleted during optimization.
+ self._cached_op_args = []
+ for i, name in enumerate(input_names):
+ pair = None
+ if name in data_names:
+ data_indices.append(i)
+ pair = (True, data_names[name])
+ else:
+ param_indices.append(i)
+ if name in params:
+ param = params[name]
+ else:
+ # The param is missing from the original params dictionary, which means the param must have
+ # been added by the Partition API backend
+ if name in arg_dict:
+ param_data = arg_dict[name]
+ elif name in aux_dict:
+ param_data = aux_dict[name]
+ else:
+ raise RuntimeError('A parameter was added to the graph during optimization but it was not '
+ 'added to the parameter dicts.\n'
+ 'Please check the backend.')
+
+ param = Parameter(name)
+ param._load_init(param_data, args[0].context)
+ pair = (False, param)
+
+ self._cached_op_args.append(pair)
+
+ flags = [('data_indices', data_indices), ('param_indices', param_indices)] + \
+ self._flags
self._cached_op = ndarray.CachedOp(out, flags)
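A hedged note on the data structure rebuilt in the hunk above: each entry of _cached_op_args is a pair whose first element says whether that graph input is fed from the runtime data arguments. A standalone reconstruction, with made-up input names:

    from mxnet.gluon import Parameter

    # For a partitioned graph whose inputs are ['data', 'fc1_weight']:
    # data inputs are stored as (True, position_in_flattened_args),
    # parameter inputs as (False, Parameter).
    cached_op_args = [
        (True, 0),                         # 'data' sits at args position 0
        (False, Parameter('fc1_weight')),  # looked up in params, or rebuilt
                                           # from arg_dict/aux_dict if the
                                           # backend created it
    ]
    data_indices = [0]   # input slots fed from runtime args
    param_indices = [1]  # input slots fed from stored Parameters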
@@ -1321,12 +1349,14 @@ class HybridBlock(Block):
arg_names = set(sym.list_arguments())
aux_names = set(sym.list_auxiliary_states())
arg_dict = {}
- for param in self.collect_params().values():
- if param.name in arg_names:
- arg_dict['arg:%s'%param.name] = param._reduce()
- else:
- assert param.name in aux_names
- arg_dict['aux:%s'%param.name] = param._reduce()
+ for is_arg, param in self._cached_op_args:
+ if not is_arg:
+ name = param.name
+ if name in arg_names:
+ arg_dict['arg:{}'.format(name)] = param._reduce()
+ else:
+ assert name in aux_names
+ arg_dict['aux:{}'.format(name)] = param._reduce()
save_fn = _mx_npx.save if is_np_array() else ndarray.save
params_filename = '%s-%04d.params'%(path, epoch)
save_fn(params_filename, arg_dict)
@@ -1437,6 +1467,23 @@ class HybridBlock(Block):
# pylint: disable= invalid-name
raise NotImplementedError
+ def reset_ctx(self, ctx):
+ """Re-assign all Parameters to other contexts. If the Block is hybridized, it will reset the _cached_op_args.
+
+ Parameters
+ ----------
+ ctx : Context or list of Context, default :py:meth:`context.current_context()`.
+ Assign Parameter to given context. If ctx is a list of Context, a
+ copy will be made for each context.
+ """
+ params = self.collect_params()
+ if self._cached_op:
+ for is_arg, p in self._cached_op_args:
+ # reset parameters created by the partitioning backend
+ if not is_arg and p.name not in params:
+ p.reset_ctx(ctx)
+ for p in params.values():
+ p.reset_ctx(ctx)
class SymbolBlock(HybridBlock):
"""Construct block from symbol. This is useful for using pre-trained models
diff --git a/python/mxnet/symbol/symbol.py b/python/mxnet/symbol/symbol.py
index 39b8799..89ff6bf 100644
--- a/python/mxnet/symbol/symbol.py
+++ b/python/mxnet/symbol/symbol.py
@@ -1544,8 +1544,36 @@ class Symbol(SymbolBase):
raise RuntimeError('Cannot add new aux in optimize_for since aux is None\n' +
'Provide a dictionary to the aux argument to optimize_for')
- # return modified symbol
- return Symbol(out)
+ new_sym = Symbol(out)
+
+ arg_names = self.list_arguments()
+ new_arg_names = new_sym.list_arguments()
+ deleted_arg_names = set(arg_names) - set(new_arg_names)
+
+ if deleted_arg_names:
+ if args is not None:
+ for a_n in deleted_arg_names:
+ if a_n in args:
+ args.pop(a_n)
+ else:
+ warnings.warn('A param was deleted during optimization, but no args dictionary was provided.\n' +
+ 'Please ensure that your model weights match the newly optimized model.')
+
+ aux_names = self.list_auxiliary_states()
+ new_aux_names = new_sym.list_auxiliary_states()
+ deleted_aux_names = set(aux_names) - set(new_aux_names)
+ if deleted_aux_names:
+ if aux is not None:
+ for a_n in deleted_aux_names:
+ if a_n in aux:
+ aux.pop(a_n)
+ else:
+ warnings.warn('A param was deleted during optimization, but no aux dictionary was provided.\n' +
+ 'Please ensure that your model weights match the newly optimized model.')
+
+ return new_sym
# pylint: disable=too-many-locals
def _simple_bind(self, ctx, grad_req='write', type_dict=None, stype_dict=None,
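To close, a hedged sketch of the new in-place contract in Symbol.optimize_for above; the backend name is hypothetical, and the assertions only restate what the code now guarantees:

    import mxnet as mx

    data = mx.sym.var('data')
    sym = mx.sym.FullyConnected(data, num_hidden=10, name='fc1')
    args = {'fc1_weight': mx.nd.ones((10, 4)), 'fc1_bias': mx.nd.zeros((10,))}
    aux = {}

    optimized = sym.optimize_for('myBackend', args, aux, ctx=mx.cpu())

    # Entries for params the backend deleted were popped from args/aux, so
    # both dicts are now subsets of the partitioned symbol's inputs.
    assert set(args).issubset(optimized.list_arguments())
    assert set(aux).issubset(optimized.list_auxiliary_states())

    # Had args/aux been None while the backend deleted params, optimize_for
    # would emit a UserWarning instead of mutating anything.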