Posted to commits@tvm.apache.org by li...@apache.org on 2020/04/12 06:10:59 UTC

[incubator-tvm] branch master updated: [Requantize] Cleanup and Optimize Lowering (#5286)

This is an automated email from the ASF dual-hosted git repository.

liuyizhi pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git


The following commit(s) were added to refs/heads/master by this push:
     new 92d0ec1  [Requantize] Cleanup and Optimize Lowering (#5286)
92d0ec1 is described below

commit 92d0ec148683b7a59d39dd74c72e74e7d65f14c8
Author: Animesh Jain <an...@umich.edu>
AuthorDate: Sat Apr 11 23:10:52 2020 -0700

    [Requantize] Cleanup and Optimize Lowering (#5286)
    
    * Adding Cast back to Int32 in FixedPointMultiply.
    
    * Removing extra clip.
    
    * Fix space.
    
    * Retrigger.
    
    * Retrigger.
---
 src/relay/qnn/op/requantize.cc | 36 +++++++++++++-----------------------
 src/relay/qnn/util.cc          |  8 ++++++--
 src/relay/quantize/realize.cc  |  3 +--
 3 files changed, 20 insertions(+), 27 deletions(-)
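
In short, the patch moves the int64 upcast out of FixedPointMultiply's
callers and into the function itself, which now also casts its result back
to int32 on return. A minimal sketch of the call pattern before and after
this commit (illustrative, condensed from the diffs below):

    // Before: callers upcast to int64 and stayed in int64 end to end.
    auto t = Cast(input_tensor, DataType::Int(64));
    t = FixedPointMultiply(t, multiplier, input_shape, rounding);  // int64 out

    // After: callers stay in int32; the int64 intermediate lives only
    // inside FixedPointMultiply, which casts back to int32 before returning.
    auto t = Cast(input_tensor, DataType::Int(32));
    t = FixedPointMultiply(t, multiplier, input_shape, rounding);  // int32 out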

diff --git a/src/relay/qnn/op/requantize.cc b/src/relay/qnn/op/requantize.cc
index 4ceb359..a2a4649 100644
--- a/src/relay/qnn/op/requantize.cc
+++ b/src/relay/qnn/op/requantize.cc
@@ -132,36 +132,28 @@ Expr RequantizeLower(const Expr& input_tensor, const Expr& input_scale,
                      const Expr& input_zero_point, const Expr& output_scale,
                      const Expr& output_zero_point, const RequantizeAttrs* param,
                      const Array<IndexExpr>& input_shape, const DataType& out_dtype) {
-  DataType hp_dtype = DataType::Int(64);
-
-  auto tensor = Cast(input_tensor, hp_dtype);
+  auto tensor = Cast(input_tensor, DataType::Int(32));
   // 1) Subtract the input_zero_point
   auto zero_scalar = MakeConstantScalar(DataType::Int(32), 0);
   if (!IsEqualScalar(input_zero_point, zero_scalar)) {
-    tensor = Subtract(tensor, Cast(input_zero_point, hp_dtype));
+    tensor = Subtract(tensor, Cast(input_zero_point, DataType::Int(32)));
   }
 
-  // Check if multiplier is greater than 1.
-  bool is_multiplier_gt_one = false;
-
   // 2) If the input and output scales are the same, we can skip the fixed point multiplication.
   // Check if the input scale is per-tensor or per-channel. If it is per-tensor, there is a single
   // scale for the whole tensor. For per-channel (aka per-axis), there is a vector of scales for the
   // input tensor. Depending on the quantization type, the fixed point multiplication routine is called.
-  auto scaled_int64_t = tensor;
+  auto scaled_int32_t = tensor;
   float output_scale_float = GetScalarFromConstant<float>(output_scale);
   if (IsConstScalar(input_scale)) {
     // This is per-tensor quantization. Single scale.
     float input_scale_float = GetScalarFromConstant<float>(input_scale);
     double double_multiplier =
         static_cast<double>(input_scale_float) / static_cast<double>(output_scale_float);
-    if (double_multiplier > 1) {
-      is_multiplier_gt_one = true;
-    }
     // Skip if the input and output scales are the same.
     if (!IsEqualScalar(input_scale, output_scale)) {
-      scaled_int64_t =
-          FixedPointMultiply(scaled_int64_t, double_multiplier, input_shape, param->rounding);
+      scaled_int32_t =
+          FixedPointMultiply(scaled_int32_t, double_multiplier, input_shape, param->rounding);
     }
   } else {
     // This is per-channel (per-axis) quantization.
@@ -171,30 +163,28 @@ Expr RequantizeLower(const Expr& input_tensor, const Expr& input_scale,
       double multiplier =
           static_cast<double>(input_axis_scale) / static_cast<double>(output_scale_float);
       double_multipliers.push_back(multiplier);
-      if (multiplier > 1) {
-        is_multiplier_gt_one = true;
-      }
     }
     int axis = param->axis;
     axis = (axis == -1) ? input_shape.size() - 1 : axis;
-    scaled_int64_t = FixedPointMultiplyPerChannel(scaled_int64_t, double_multipliers, input_shape,
+    scaled_int32_t = FixedPointMultiplyPerChannel(scaled_int32_t, double_multipliers, input_shape,
                                                   axis, param->rounding);
   }
 
   // 3) Add the output zero point.
-  auto shifted_int64_t = scaled_int64_t;
+  auto shifted_int32_t = scaled_int32_t;
   if (!IsEqualScalar(output_zero_point, zero_scalar)) {
-    shifted_int64_t = Add(Cast(output_zero_point, hp_dtype), scaled_int64_t);
+    shifted_int32_t = Add(Cast(output_zero_point, DataType::Int(32)), scaled_int32_t);
   }
 
   // 4) Clip to the out_dtype min/max. Skip clipping if out_dtype is Int32. The fixed point
-  // multiplication keeps the value in int32 range if the requantize scale is less than 1.
-  if (out_dtype == DataType::Int(32) && !is_multiplier_gt_one) {
-    return Cast(shifted_int64_t, out_dtype);
+  // multiplication keeps the value in int32 range.
+  if (out_dtype == DataType::Int(32)) {
+    return shifted_int32_t;
   }
+
   auto q_min = GetQmin(out_dtype);
   auto q_max = GetQmax(out_dtype);
-  auto clipped_t = Clip(shifted_int64_t, q_min, q_max);
+  auto clipped_t = Clip(shifted_int32_t, q_min, q_max);
   return Cast(clipped_t, out_dtype);
 }
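
For reference, the lowering above implements the standard requantize
computation; written as a plain formula rather than code from the patch:

    q_out = clamp(round((q_in - zp_in) * s_in / s_out) + zp_out, qmin, qmax)

A quick worked example with illustrative numbers: for q_in = 130,
zp_in = 128, s_in = 0.5, s_out = 0.25, and zp_out = 0, the requantize
multiplier is 0.5 / 0.25 = 2.0, so q_out = (130 - 128) * 2 = 4. A
multiplier greater than 1 is exactly the case the deleted
is_multiplier_gt_one flag used to track; since FixedPointMultiply now
guarantees an int32-range result, the int32 output path no longer needs
that check or the extra clip.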
 
diff --git a/src/relay/qnn/util.cc b/src/relay/qnn/util.cc
index 648de53..91fe3ca 100644
--- a/src/relay/qnn/util.cc
+++ b/src/relay/qnn/util.cc
@@ -80,6 +80,7 @@ Expr FixedPointMultiply(Expr tensor, double multiplier, const Array<IndexExpr>&
   // Choose int64 as the high-precision datatype. This avoids overflow when
   // multiplying two int32 values.
   DataType hp_dtype = DataType::Int(64);
+  tensor = Cast(tensor, hp_dtype);
 
   // 1) Calculating the integer multiplier and integer shift
   int32_t fixed_point_multiplier, shift;
@@ -130,7 +131,8 @@ Expr FixedPointMultiply(Expr tensor, double multiplier, const Array<IndexExpr>&
   tensor =
       RightShift(tensor, MakeConstantScalar(hp_dtype, total_right_shift));
 
-  return tensor;
+  // 6) The fixed point multiplication keeps the value in the int32 range, so cast back to int32.
+  return Cast(tensor, DataType::Int(32));
 }
 
 Expr FixedPointMultiplyPerChannel(Expr tensor, std::vector<double> multipliers,
@@ -145,6 +147,7 @@ Expr FixedPointMultiplyPerChannel(Expr tensor, std::vector<double> multipliers,
   // Choose int64 as the high-precision datatype. This avoids overflow when
   // multiplying two int32 values.
   DataType hp_dtype = DataType::Int(64);
+  tensor = Cast(tensor, hp_dtype);
 
   // 1) Calculating the integer multiplier and integer shift. These are calculated per axis/per
   // channel.
@@ -218,7 +221,8 @@ Expr FixedPointMultiplyPerChannel(Expr tensor, std::vector<double> multipliers,
   auto exp_total_rshift_expr = ExpandBiasToMatchAxis(total_rshift_expr, n_dim, {channel_axis});
   tensor = RightShift(tensor, exp_total_rshift_expr);
 
-  return tensor;
+  // 6) The fixed point multiplication keeps the value in the int32 range, so cast back to int32.
+  return Cast(tensor, DataType::Int(32));
 }
 
 }  // namespace qnn
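
For context, step 1 in both hunks above decomposes the double multiplier
into a 32-bit fixed-point significand and a power-of-two shift. A
self-contained sketch of that standard decomposition (in the style of
gemmlowp's QuantizeMultiplier; the helper name and signature here are
illustrative, not TVM's API):

    #include <cmath>
    #include <cstdint>
    #include <cstdio>

    // Illustrative helper (not TVM's API): split a positive double
    // multiplier so that multiplier ~= (fixed_point / 2^31) * 2^shift.
    void DecomposeMultiplier(double multiplier, int32_t* fixed_point, int* shift) {
      // frexp returns a significand in [0.5, 1) with multiplier = s * 2^shift.
      double significand = std::frexp(multiplier, shift);
      int64_t q = static_cast<int64_t>(std::round(significand * (1LL << 31)));
      if (q == (1LL << 31)) {  // rounding pushed the significand up to 1.0
        q /= 2;
        ++*shift;
      }
      *fixed_point = static_cast<int32_t>(q);
    }

    int main() {
      int32_t m = 0;
      int s = 0;
      DecomposeMultiplier(0.5 / 0.25, &m, &s);  // requantize scale 2.0
      // Prints fixed_point=1073741824 (i.e. 2^30) and shift=2.
      std::printf("fixed_point=%d shift=%d\n", static_cast<int>(m), s);
      return 0;
    }

The int64 hp_dtype exists because the tensor value times the int32
fixed-point multiplier needs 64 bits before the right shift brings the
result back into int32 range, which is why the casts can now live
entirely inside these two functions.
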
diff --git a/src/relay/quantize/realize.cc b/src/relay/quantize/realize.cc
index 8e04a99..6d56e19 100644
--- a/src/relay/quantize/realize.cc
+++ b/src/relay/quantize/realize.cc
@@ -117,8 +117,7 @@ inline Expr MulAndDiv(Expr data, float s1, float s2, DataType dtype,
   } else if (static_cast<int>(factor) == factor) {
     return Multiply(data, MakeConstantScalar(dtype, factor));
   } else {
-    data = qnn::FixedPointMultiply(
-        Cast(data, DataType::Int(64)), factor, data_shape, cfg->rounding);
+    data = qnn::FixedPointMultiply(data, factor, data_shape, cfg->rounding);
     return Cast(data, dtype);
   }
 }
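
The realize.cc hunk is the same simplification at a second call site:
MulAndDiv rescales data by a factor derived from s1 and s2 (presumably
s1 / s2) and, when that factor is neither trivial nor an integer, defers
to qnn::FixedPointMultiply, which now takes the int32 data directly. A
sketch of its three cases (simplified; the first branch is outside the
hunk above and is an assumption here):

    // factor is trivial (no rescale needed) -> return data unchanged
    // factor is an integer                  -> Multiply(data, MakeConstantScalar(dtype, factor))
    // otherwise                             -> qnn::FixedPointMultiply(data, factor, ...), then Cast to dtype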