You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by bg...@apache.org on 2022/05/23 05:54:58 UTC

[incubator-mxnet] branch master updated: [master] Enabled tests using the whole batch for calibration (#21008)

This is an automated email from the ASF dual-hosted git repository.

bgawrych pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 9ca3b27635 [master] Enabled tests using the whole batch for calibration (#21008)
9ca3b27635 is described below

commit 9ca3b27635a289ef197a17eeb627464ce1c20b76
Author: DominikaJedynak <do...@intel.com>
AuthorDate: Mon May 23 07:54:38 2022 +0200

    [master] Enabled tests using the whole batch for calibration (#21008)
    
    * Added possibility to rerun test with all batches used for calibration
    
    * Review suggestions
    
    * Calibration dataset fix
---
 tests/python/dnnl/subgraphs/subgraph_common.py | 21 +++++++++++++++++++--
 1 file changed, 19 insertions(+), 2 deletions(-)

diff --git a/tests/python/dnnl/subgraphs/subgraph_common.py b/tests/python/dnnl/subgraphs/subgraph_common.py
index 5615a398fb..009d9cc785 100644
--- a/tests/python/dnnl/subgraphs/subgraph_common.py
+++ b/tests/python/dnnl/subgraphs/subgraph_common.py
@@ -159,9 +159,26 @@ def check_quantize(net_original, data_shapes, out_type, name='conv',
       # make a list to have a common path for one and multiple outputs
       ref_out = [ref_out]
 
-  dataArray= mx.gluon.data.ArrayDataset(*data)
+  class TestDataLoader(mx.gluon.data.DataLoader):
+    def __init__(self, data):
+      self.data = data
+      self.finish = False
+
+    def __iter__(self):
+      self.finish = False
+      return self
+
+    def __next__(self):
+      if self.finish:
+        raise StopIteration
+      self.finish = True
+      return self.data
+
+    def __del__(self):
+      pass
+
+  calib_data = TestDataLoader(data)
 
-  calib_data = mx.gluon.data.DataLoader(dataArray, batch_size=1)
   for quantize_granularity in quantize_granularity_list:
     qnet = quantization.quantize_net(net_original,
                                      device=mx.cpu(),