You are viewing a plain-text version of this content; the canonical (hyperlinked) version is available in the original mailing-list archive.
Posted to commits@mxnet.apache.org by re...@apache.org on 2018/11/23 03:42:17 UTC

[incubator-mxnet] branch master updated: Support full convention in quantized pooling (#13260)

This is an automated email from the ASF dual-hosted git repository.

reminisce pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new cd0ce3b  Support full convention in quantized pooling (#13260)
cd0ce3b is described below

commit cd0ce3b13f416525736316389c55a0b31146220c
Author: Tao Lv <ta...@intel.com>
AuthorDate: Fri Nov 23 11:42:02 2018 +0800

    Support full convention in quantized pooling (#13260)
    
    * fix quantized pooling and enable it in INT8 SqueezeNet
    
    * add test
    
    * fix test
    
    * address review comments
    
    * refine the test for quantized pooling
---
 example/quantization/imagenet_gen_qsym_mkldnn.py |  6 +-----
 src/operator/quantization/quantized_pooling.cc   | 23 ++++++++++++++++++-----
 tests/python/quantization/test_quantization.py   | 18 +++++++++++++-----
 3 files changed, 32 insertions(+), 15 deletions(-)

diff --git a/example/quantization/imagenet_gen_qsym_mkldnn.py b/example/quantization/imagenet_gen_qsym_mkldnn.py
index 9056f79..c38019f 100644
--- a/example/quantization/imagenet_gen_qsym_mkldnn.py
+++ b/example/quantization/imagenet_gen_qsym_mkldnn.py
@@ -225,11 +225,7 @@ if __name__ == '__main__':
         rgb_mean = '123.68,116.779,103.939'
         rgb_std = '58.393, 57.12, 57.375'
         calib_layer = lambda name: name.endswith('_output')
-        excluded_sym_names += ['squeezenet0_flatten0_flatten0',
-                               'squeezenet0_pool0_fwd',
-                               'squeezenet0_pool1_fwd',
-                               'squeezenet0_pool2_fwd',
-                               'squeezenet0_pool3_fwd']
+        excluded_sym_names += ['squeezenet0_flatten0_flatten0']
         if exclude_first_conv:
             excluded_sym_names += ['squeezenet0_conv0_fwd']
     elif args.model == 'mobilenet1.0':
diff --git a/src/operator/quantization/quantized_pooling.cc b/src/operator/quantization/quantized_pooling.cc
index 477830a..8b62db9 100644
--- a/src/operator/quantization/quantized_pooling.cc
+++ b/src/operator/quantization/quantized_pooling.cc
@@ -52,17 +52,30 @@ bool QuantizedPoolingShape(const nnvm::NodeAttrs& attrs,
       << "kernel size (" << param.kernel[1]
       << ") exceeds input (" << dshape[W]
       << " padded to " << (dshape[W] + 2*param.pad[1]) << ")";
-  // only support valid convention
+
   oshape[N] = dshape[N];
   oshape[C] = dshape[C];
   if (param.global_pool) {
     oshape[H] = 1;
     oshape[W] = 1;
   } else {
-    oshape[H] = 1 + (dshape[H] + 2 * param.pad[0] - param.kernel[0]) /
-        param.stride[0];
-    oshape[W] = 1 + (dshape[W] + 2 * param.pad[1] - param.kernel[1]) /
-        param.stride[1];
+    if (param.pooling_convention == pool_enum::kValid) {
+      oshape[H] = 1 +
+                  (dshape[H] + 2 * param.pad[0] - param.kernel[0]) /
+                      param.stride[0];
+      oshape[W] = 1 +
+                  (dshape[W] + 2 * param.pad[1] - param.kernel[1]) /
+                      param.stride[1];
+    } else {
+      oshape[H] = 1 + static_cast<int>(std::ceil(
+                          static_cast<float>(dshape[H] + 2 * param.pad[0] -
+                                             param.kernel[0]) /
+                          param.stride[0]));
+      oshape[W] = 1 + static_cast<int>(std::ceil(
+                          static_cast<float>(dshape[W] + 2 * param.pad[1] -
+                                             param.kernel[1]) /
+                          param.stride[1]));
+    }
   }
 
   SHAPE_ASSIGN_CHECK(*in_shape, 1, TShape{1});
diff --git a/tests/python/quantization/test_quantization.py b/tests/python/quantization/test_quantization.py
index 5ae2c6c..e6212b8 100644
--- a/tests/python/quantization/test_quantization.py
+++ b/tests/python/quantization/test_quantization.py
@@ -214,7 +214,7 @@ def test_quantized_conv():
 
 @with_seed()
 def test_quantized_pooling():
-    def check_quantized_pooling(data_shape, kernel, pool_type, pad, stride, global_pool, qdtype):
+    def check_quantized_pooling(data_shape, kernel, pool_type, pad, stride, global_pool, qdtype, convention='valid'):
         if is_test_for_native_cpu():
             print('skipped testing quantized_pooling for native cpu since it is not supported yet')
             return
@@ -224,7 +224,8 @@ def test_quantized_pooling():
 
         data = mx.sym.Variable(name='data', shape=data_shape, dtype='float32')
         pooling_fp32 = mx.sym.Pooling(data=data, kernel=kernel, pad=pad, stride=stride,
-                                        pool_type=pool_type, global_pool=global_pool, cudnn_off=False)
+                                      pool_type=pool_type, global_pool=global_pool, cudnn_off=False,
+                                      pooling_convention=convention)
         arg_shapes, _, _ = pooling_fp32.infer_shape(data=data_shape)
         arg_names = pooling_fp32.list_arguments()
         pooling_fp32_exe = pooling_fp32.simple_bind(ctx=mx.current_context(), grad_req='null')
@@ -242,9 +243,10 @@ def test_quantized_pooling():
         min_data = mx.sym.Variable(name='min_data')
         max_data = mx.sym.Variable(name='max_data')
         quantized_pooling = mx.sym.contrib.quantized_pooling(data=qdata, min_data=min_data,
-                                                                max_data=max_data, kernel=kernel,
-                                                                pad=pad, stride=stride, pool_type=pool_type,
-                                                                global_pool=global_pool)
+                                                             max_data=max_data, kernel=kernel,
+                                                             pad=pad, stride=stride, pool_type=pool_type,
+                                                             global_pool=global_pool,
+                                                             pooling_convention=convention)
         pooling_int8_exe = quantized_pooling.simple_bind(ctx=mx.current_context(), grad_req='null')
         qarg_names = quantized_pooling.list_arguments()
         pooling_int8_exe.arg_dict[qarg_names[0]][:] = pooling_fp32_exe.arg_dict[arg_names[0]].astype(qdtype)
@@ -266,6 +268,12 @@ def test_quantized_pooling():
         check_quantized_pooling((3, 512, 7, 7), (7, 7), 'avg', (0, 0), (1, 1), False, qdtype)
         check_quantized_pooling((3, 512, 7, 7), (7, 7), 'avg', (0, 0), (1, 1), True, qdtype)
 
+        check_quantized_pooling((3, 4, 56, 56), (3, 3), 'max', (0, 0), (2, 2), False, qdtype, 'full')
+        check_quantized_pooling((3, 4, 56, 56), (3, 3), 'max', (0, 0), (2, 2), True, qdtype, 'full')
+        check_quantized_pooling((3, 512, 7, 7), (7, 7), 'avg', (0, 0), (1, 1), False, qdtype, 'full')
+        check_quantized_pooling((3, 512, 7, 7), (7, 7), 'avg', (0, 0), (1, 1), True, qdtype, 'full')
+
+
 @with_seed()
 def test_quantized_fc():
     def check_quantized_fc(data_shape, num_hidden, no_bias, qdtype, flatten=True):