You are viewing a plain-text version of this content. The link to the canonical version was a hyperlink that was not preserved in this plain-text copy.
Posted to commits@mxnet.apache.org by bg...@apache.org on 2022/05/23 05:54:58 UTC
[incubator-mxnet] branch master updated: [master] Enabled tests using the whole batch for calibration (#21008)
This is an automated email from the ASF dual-hosted git repository.
bgawrych pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new 9ca3b27635 [master] Enabled tests using the whole batch for calibration (#21008)
9ca3b27635 is described below
commit 9ca3b27635a289ef197a17eeb627464ce1c20b76
Author: DominikaJedynak <do...@intel.com>
AuthorDate: Mon May 23 07:54:38 2022 +0200
[master] Enabled tests using the whole batch for calibration (#21008)
* Added possibility to rerun test with all batches used for calibration
* Review suggestions
* Calibration dataset fix
---
tests/python/dnnl/subgraphs/subgraph_common.py | 21 +++++++++++++++++++--
1 file changed, 19 insertions(+), 2 deletions(-)
diff --git a/tests/python/dnnl/subgraphs/subgraph_common.py b/tests/python/dnnl/subgraphs/subgraph_common.py
index 5615a398fb..009d9cc785 100644
--- a/tests/python/dnnl/subgraphs/subgraph_common.py
+++ b/tests/python/dnnl/subgraphs/subgraph_common.py
@@ -159,9 +159,26 @@ def check_quantize(net_original, data_shapes, out_type, name='conv',
# make a list to have a common path for one and multiple outputs
ref_out = [ref_out]
- dataArray= mx.gluon.data.ArrayDataset(*data)
+ class TestDataLoader(mx.gluon.data.DataLoader):
+ def __init__(self, data):
+ self.data = data
+ self.finish = False
+
+ def __iter__(self):
+ self.finish = False
+ return self
+
+ def __next__(self):
+ if self.finish:
+ raise StopIteration
+ self.finish = True
+ return self.data
+
+ def __del__(self):
+ pass
+
+ calib_data = TestDataLoader(data)
- calib_data = mx.gluon.data.DataLoader(dataArray, batch_size=1)
for quantize_granularity in quantize_granularity_list:
qnet = quantization.quantize_net(net_original,
device=mx.cpu(),