You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by bg...@apache.org on 2022/03/10 13:44:33 UTC

[incubator-mxnet] branch v1.x updated: [v1.x] Reduce after quantization memory usage (#20925)

This is an automated email from the ASF dual-hosted git repository.

bgawrych pushed a commit to branch v1.x
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/v1.x by this push:
     new 06e5c73  [v1.x] Reduce after quantization memory usage (#20925)
06e5c73 is described below

commit 06e5c7317b0c4b0c4a528df3d04f99be6bc97149
Author: bgawrych <ba...@intel.com>
AuthorDate: Thu Mar 10 14:42:23 2022 +0100

    [v1.x] Reduce after quantization memory usage (#20925)
    
    * [v1.x] Reduce after quantization memory usage
    
    * fix sanity
    
    Co-authored-by: Bartlomiej Gawrych <ba...@intel.com>
---
 python/mxnet/contrib/quantization.py | 8 +++++++-
 1 file changed, 7 insertions(+), 1 deletion(-)

diff --git a/python/mxnet/contrib/quantization.py b/python/mxnet/contrib/quantization.py
index 732938e..609e0c2 100644
--- a/python/mxnet/contrib/quantization.py
+++ b/python/mxnet/contrib/quantization.py
@@ -996,14 +996,20 @@ def quantize_net_v2(network, quantized_dtype='auto', quantize_mode='full', quant
         save_dict.update({('aux:%s' % k): v.as_in_context(cpu())
                           for k, v in aux_params.items()})
         nd_save(param_name, save_dict)
+        for _, v in net.collect_params().items():
+            v.grad_req = 'null'
         net.collect_params().load(param_name, cast_dtype=True, dtype_source='saved')
         net.collect_params().reset_ctx(ctx)
+
         if quantized_dtype == 'auto':
             mx.nd.waitall()
             net.optimize_for(x=data_nd, backend="MKLDNNShiftedQuantization")
             tmp_file = os.path.join(tmpdirname, 'model')
             net.export(tmp_file)
-            net = SymbolBlock.imports(tmp_file + '-symbol.json', data_names, tmp_file + '-0000.params')
+            net = SymbolBlock.imports(tmp_file + '-symbol.json', data_names)
+            for _, v in net.collect_params().items():
+                v.grad_req = 'null'
+            net.collect_params().load(tmp_file + '-0000.params', cast_dtype=True, dtype_source='saved')
     return net
 
 def quantize_net(network, quantized_dtype='auto', quantize_mode='full',