You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by zh...@apache.org on 2020/01/17 17:49:12 UTC

[incubator-tvm] branch master updated: [QNN] Conv2D type checking for kernel per-channel scales. (#4732)

This is an automated email from the ASF dual-hosted git repository.

zhic pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git


The following commit(s) were added to refs/heads/master by this push:
     new a5bb789  [QNN] Conv2D type checking for kernel per-channel scales. (#4732)
a5bb789 is described below

commit a5bb789a7b22957fafd1f91e4f7a4da5dc761ec4
Author: Animesh Jain <an...@umich.edu>
AuthorDate: Fri Jan 17 09:49:07 2020 -0800

    [QNN] Conv2D type checking for kernel per-channel scales. (#4732)
    
    * [QNN] Conv2D type checking for kernel per-channel scales.
    
    * Address comments.
    
    * Address comments.
    
    * Add safety checks for downcasts.
    
    Co-authored-by: shoubhik <sh...@gmail.com>
---
 src/relay/qnn/op/convolution.cc          |  5 ++++-
 src/relay/qnn/util.h                     |  4 ++++
 tests/python/relay/test_op_qnn_conv2d.py | 38 ++++++++++++++++++++++++++++----
 3 files changed, 42 insertions(+), 5 deletions(-)

diff --git a/src/relay/qnn/op/convolution.cc b/src/relay/qnn/op/convolution.cc
index c9ce0ec..2335c59 100644
--- a/src/relay/qnn/op/convolution.cc
+++ b/src/relay/qnn/op/convolution.cc
@@ -57,7 +57,10 @@ bool QnnConv2DRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
   CHECK(IsScalarType(types[2], DataType::Int(32)));    // input_zero_point
   CHECK(IsScalarType(types[3], DataType::Int(32)));    // kernel_zero_point
   CHECK(IsScalarType(types[4], DataType::Float(32)));  // input_scale
-  CHECK(IsScalarType(types[5], DataType::Float(32)));  // kernel_scale
+  // Kernel scale can be a vector of length output_channels or a scalar.
+  size_t axis = param->kernel_layout.find('O');
+  CHECK(axis != std::string::npos) << "Kernel layout attribute is not defined";
+  AssignType(types[5], DataType::Float(32), weight->shape[axis], reporter);  // kernel scale
 
   // Collect the input tensor and output tensor devoid of scale and zero points to reuse Relay
   // Conv2D infer type function.
diff --git a/src/relay/qnn/util.h b/src/relay/qnn/util.h
index 2e33241..2316bed 100644
--- a/src/relay/qnn/util.h
+++ b/src/relay/qnn/util.h
@@ -152,6 +152,8 @@ Expr FixedPointMultiplyPerChannel(Expr tensor, std::vector<double> multiplier,
  */
 static inline bool IsScalarType(const Type& expr_type, const DataType& dtype) {
   const auto* tensor_type = expr_type.as<TensorTypeNode>();
+  CHECK(tensor_type) << "Only tensor type can be checked for scalar values. But got"
+                     << AsText(expr_type, false);
   CHECK_EQ(tensor_type->shape.size(), 0);
   CHECK(tensor_type->dtype == dtype) << "Expected " << dtype << " but got " << tensor_type->dtype;
   return true;
@@ -168,6 +170,8 @@ static inline void AssignType(const Type& expr_type, const DataType& dtype, cons
                               const TypeReporter& reporter) {
   // Scale/Zero_points can be either const scalar or a vector with C axis num elems.
   const auto* tensor_type = expr_type.as<TensorTypeNode>();
+  CHECK(tensor_type) << "Can assign type to Tensor type only. But got "
+                     << AsText(expr_type, false);
   const auto tensor_dtype = tensor_type->dtype;
   CHECK(tensor_dtype == dtype) << "Expected type is " << dtype << " but received " << tensor_dtype;
   if (tensor_type->shape.size() != 0) {
diff --git a/tests/python/relay/test_op_qnn_conv2d.py b/tests/python/relay/test_op_qnn_conv2d.py
index 9effa6f..9631ffc 100644
--- a/tests/python/relay/test_op_qnn_conv2d.py
+++ b/tests/python/relay/test_op_qnn_conv2d.py
@@ -768,8 +768,8 @@ def test_depthwise_depth_multiplier():
                                        channels=4)
         verify(ref_func, qnn_func, data_shape, data_dtype,
                 kernel_shape, kernel_dtype)
-        
-        
+
+
         # Depthwise multiplier = 2
         data_shape = (10, 4, 16, 16)
         data_dtype = 'uint8'
@@ -794,7 +794,7 @@ def test_depthwise_depth_multiplier():
                                        channels=8)
         verify(ref_func, qnn_func, data_shape, data_dtype,
                 kernel_shape, kernel_dtype)
-        
+
         # uint8 input, NHWC and HWOI
         # Depthwise multiplier = 1
         data_shape = (2, 16, 16, 4)
@@ -820,7 +820,7 @@ def test_depthwise_depth_multiplier():
                                        channels=4)
         verify(ref_func, qnn_func, data_shape, data_dtype,
                 kernel_shape, kernel_dtype)
-        
+
         # Depthwise multiplier = 2
         data_shape = (2, 16, 16, 4)
         data_dtype = 'uint8'
@@ -846,6 +846,35 @@ def test_depthwise_depth_multiplier():
         verify(ref_func, qnn_func, data_shape, data_dtype,
                 kernel_shape, kernel_dtype)
 
+def test_per_channel_kernel_scale():
+    with TempOpAttr("qnn.conv2d", "FTVMQnnLegalize", legalize_qnn_conv2d):
+        data_shape = (2, 1, 2, 4)
+        data_dtype = 'uint8'
+        kernel_shape = (3, 1, 2, 2)
+        kernel_dtype = 'uint8'
+        data = relay.var("data", shape=data_shape,
+                dtype=data_dtype)
+        kernel = relay.var("kernel", shape=kernel_shape,
+                dtype=kernel_dtype)
+        kernel_scales = [2, 2, 2]
+        kernel_scales = relay.const(np.array(kernel_scales).astype('float32'))
+        func = relay.qnn.op.conv2d(
+                data, kernel,
+                input_zero_point=relay.const(0, 'int32'),
+                kernel_zero_point=relay.const(0, 'int32'),
+                input_scale=relay.const(2.0, 'float32'),
+                kernel_scale=kernel_scales,
+                kernel_size=(2, 2),
+                padding=(0, 0),
+                strides=(1, 1),
+                dilation=(1, 1),
+                data_layout="NCHW",
+                kernel_layout="OIHW",
+                out_dtype="int32")
+
+        mod = relay.Function(relay.analysis.free_vars(func), func)
+        mod = relay.Module.from_expr(mod)
+
 if __name__ == "__main__":
     test_no_zero_point()
     test_input_zero_point()
@@ -861,3 +890,4 @@ if __name__ == "__main__":
     test_tflite_output_multiplier_greater_than_one()
     test_tflite_anistropic_strides()
     test_depthwise_depth_multiplier()
+    test_per_channel_kernel_scale()