You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2021/07/13 15:19:23 UTC

[GitHub] [incubator-mxnet] bgawrych commented on a change in pull request #20430: [FEATURE] Asymmetric fc fc

bgawrych commented on a change in pull request #20430:
URL: https://github.com/apache/incubator-mxnet/pull/20430#discussion_r668865247



##########
File path: tests/python/quantization/test_quantization.py
##########
@@ -1348,6 +1348,64 @@ def check(number, qdtype):
             check(i, qdtype)
 
 
+@with_seed()
+def test_onednn_shifted_fc_fc():
+    batch_size = 2
+    if not is_test_for_mkldnn():
+        print("Test only for mkldnn")
+        return
+
+    def get_fc_fc_layers():
+        net = mx.gluon.nn.HybridSequential()
+        with net.name_scope():
+            net.add(mx.gluon.nn.Dense(2, use_bias=True, flatten=True,
+                                      weight_initializer=mx.initializer.Normal(),
+                                      bias_initializer=mx.initializer.Normal()))
+            net.add(mx.gluon.nn.Dense(2, use_bias=True, flatten=True,
+                                      weight_initializer=mx.initializer.Normal(),
+                                      bias_initializer=mx.initializer.Normal()))
+        net.initialize()
+        return net
+
+    def quantize_net(net, qdtype, random_data):
+        calib_data = NDArrayIter(data=random_data, batch_size=batch_size)
+        calib_data = DummyIter(calib_data)
+        net = mx.contrib.quant.quantize_net(net, quantized_dtype=qdtype,
+                                            exclude_layers=None,
+                                            exclude_layers_match=[],
+                                            calib_data=calib_data,
+                                            calib_mode='naive',
+                                            num_calib_examples=1,
+                                            ctx=mx.current_context())
+        net.hybridize(static_alloc=True, static_shape=True)
+        print("calibrated, now run to get symbol")
+        out = net(random_data)
+        out.wait_to_read()
+
+        _, sym = net._cached_graph
+        fc_attrs = sym.attr_dict()['quantized_sg_mkldnn_fully_connected_0']
+        return fc_attrs, out
+
+    def check(qdtype, random_data):
+        net_ref = get_fc_fc_layers()
+        out_ref = net_ref(random_data)
+        out_ref.wait_to_read()
+
+        fc_attrs, out_q = quantize_net(net_ref, qdtype, random_data)
+
+        assert_almost_equal(out_ref, out_q)
+
+        if qdtype == 'auto':
+            assert fc_attrs['shifted_output'] == 'True'
+        else:
+            assert 'shifted' not in fc_attrs
+
+    with environment({'MXNET_DISABLE_SHIFTED_QUANTIZATION_OPTIMIZATIONS': '0',
+        'MXNET_DISABLE_SHIFTED_QUANTIZE_FC_OPTIMIZATION': '1'}):

Review comment:
       Can you align these env variables?




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@mxnet.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org