You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by bg...@apache.org on 2022/03/10 13:44:33 UTC
[incubator-mxnet] branch v1.x updated: [v1.x] Reduce after quantization memory usage (#20925)
This is an automated email from the ASF dual-hosted git repository.
bgawrych pushed a commit to branch v1.x
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/v1.x by this push:
new 06e5c73 [v1.x] Reduce after quantization memory usage (#20925)
06e5c73 is described below
commit 06e5c7317b0c4b0c4a528df3d04f99be6bc97149
Author: bgawrych <ba...@intel.com>
AuthorDate: Thu Mar 10 14:42:23 2022 +0100
[v1.x] Reduce after quantization memory usage (#20925)
* [v1.x] Reduce after quantization memory usage
* fix sanity
Co-authored-by: Bartlomiej Gawrych <ba...@intel.com>
---
python/mxnet/contrib/quantization.py | 8 +++++++-
1 file changed, 7 insertions(+), 1 deletion(-)
diff --git a/python/mxnet/contrib/quantization.py b/python/mxnet/contrib/quantization.py
index 732938e..609e0c2 100644
--- a/python/mxnet/contrib/quantization.py
+++ b/python/mxnet/contrib/quantization.py
@@ -996,14 +996,20 @@ def quantize_net_v2(network, quantized_dtype='auto', quantize_mode='full', quant
save_dict.update({('aux:%s' % k): v.as_in_context(cpu())
for k, v in aux_params.items()})
nd_save(param_name, save_dict)
+ for _, v in net.collect_params().items():
+ v.grad_req = 'null'
net.collect_params().load(param_name, cast_dtype=True, dtype_source='saved')
net.collect_params().reset_ctx(ctx)
+
if quantized_dtype == 'auto':
mx.nd.waitall()
net.optimize_for(x=data_nd, backend="MKLDNNShiftedQuantization")
tmp_file = os.path.join(tmpdirname, 'model')
net.export(tmp_file)
- net = SymbolBlock.imports(tmp_file + '-symbol.json', data_names, tmp_file + '-0000.params')
+ net = SymbolBlock.imports(tmp_file + '-symbol.json', data_names)
+ for _, v in net.collect_params().items():
+ v.grad_req = 'null'
+ net.collect_params().load(tmp_file + '-0000.params', cast_dtype=True, dtype_source='saved')
return net
def quantize_net(network, quantized_dtype='auto', quantize_mode='full',