Posted to commits@tvm.apache.org by le...@apache.org on 2023/07/07 08:56:06 UTC

[tvm] branch main updated: [QNN] Support Dequantize to "float16" and Quantize to "uint16" (#15235)

This is an automated email from the ASF dual-hosted git repository.

leandron pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new d9d6a88a0a [QNN] Support Dequantize to "float16" and Quantize to "uint16" (#15235)
d9d6a88a0a is described below

commit d9d6a88a0aaafc63b954a1c435c7242af30083af
Author: Qiang Zhang <jo...@163.com>
AuthorDate: Fri Jul 7 16:56:00 2023 +0800

    [QNN] Support Dequantize to "float16" and Quantize to "uint16" (#15235)
---
 include/tvm/relay/qnn/attrs.h                |  2 ++
 python/tvm/relay/qnn/op/qnn.py               | 21 ++++++++++-------
 src/relay/qnn/op/dequantize.cc               | 28 ++++++++++++++++------
 src/relay/qnn/op/quantize.cc                 |  5 ++--
 src/relay/qnn/utils.h                        |  3 ++-
 tests/python/relay/test_op_qnn_dequantize.py | 35 +++++++++++++++++++++++++---
 tests/python/relay/test_op_qnn_quantize.py   | 23 ++++++++++++++++++
 7 files changed, 95 insertions(+), 22 deletions(-)
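
For orientation, a minimal Relay-level sketch of how the extended ops can be called after this patch; the tensor shape, scale, and zero-point values below are illustrative, not taken from the diff:

    import tvm
    from tvm import relay

    # Dequantize directly to float16 via the new out_dtype argument.
    x = relay.var("x", shape=(2, 5), dtype="int8")
    deq = relay.qnn.op.dequantize(
        x,
        input_scale=relay.const(0.5, "float32"),
        input_zero_point=relay.const(-1, "int32"),
        out_dtype="float16",
    )

    # Quantize to uint16, which QuantizeRel now accepts.
    y = relay.var("y", shape=(2, 5), dtype="float32")
    q = relay.qnn.op.quantize(
        y,
        output_scale=relay.const(0.2, "float32"),
        output_zero_point=relay.const(32765, "int32"),
        out_dtype="uint16",
    )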

diff --git a/include/tvm/relay/qnn/attrs.h b/include/tvm/relay/qnn/attrs.h
index 64b2dc2098..85e0085286 100644
--- a/include/tvm/relay/qnn/attrs.h
+++ b/include/tvm/relay/qnn/attrs.h
@@ -95,9 +95,11 @@ struct SimulatedQuantizeAttrs : public tvm::AttrsNode<SimulatedQuantizeAttrs> {
 
 /*! \brief Attribute for dequantize operator */
 struct DequantizeAttrs : public tvm::AttrsNode<DequantizeAttrs> {
+  DataType out_dtype;
   int axis;
 
   TVM_DECLARE_ATTRS(DequantizeAttrs, "relay.attrs.DequantizeAttrs") {
+    TVM_ATTR_FIELD(out_dtype).describe("Output data type, can be one of [float16, float32].");
     TVM_ATTR_FIELD(axis)
         .describe(
             "The channel axis for channel wise dequantization. Default value is -1,"
diff --git a/python/tvm/relay/qnn/op/qnn.py b/python/tvm/relay/qnn/op/qnn.py
index eb64b56e82..968935f062 100644
--- a/python/tvm/relay/qnn/op/qnn.py
+++ b/python/tvm/relay/qnn/op/qnn.py
@@ -186,8 +186,8 @@ def requantize(
 
 def quantize(data, output_scale, output_zero_point, axis=-1, out_dtype="int8"):
     r"""Quantize op
-    This operator takes float32 as input and produces quantized int8 or unit8 as output.
-    The input tensor can be of any shape. The output shape is the same as input shape.
+    This operator takes float32 input and produces quantized output. The input
+    tensor can be of any shape. The output shape is the same as input shape.
 
     Q_output = clamp((round(input_tensor/output_scale) + output_zero_point),
                      out_dtype::min,
@@ -206,8 +206,9 @@ def quantize(data, output_scale, output_zero_point, axis=-1, out_dtype="int8"):
 
     axis : int
         The channel axis for quantization. Default value is -1 which corresponds to the last axis.
+
     out_dtype : str, optional
-        The data type of the input tensor. Can be [int8, uint8, int32]
+        The data type of the output tensor. Can be [int8, uint8, int16, uint16, int32].
 
     Returns
     -------
@@ -256,16 +257,15 @@ def simulated_quantize(data, output_scale, output_zero_point, axis=-1, out_dtype
     return _make.simulated_quantize(data, out_dtype, output_scale, output_zero_point, axis)
 
 
-def dequantize(data, input_scale, input_zero_point, axis=-1):
+def dequantize(data, input_scale, input_zero_point, axis=-1, out_dtype="float32"):
     r"""Dequantize op
-    This operator takes quantized int8 and unit8 as input and produces
-    dequantized float32 as output. The output shape is the same as input shape. The input
-    tensor can be of any shape.
+    This operator takes quantized input and produces dequantized float output.
+    The output shape is the same as input shape. The input tensor can be of any shape.
 
     Parameters
     ----------
     data : tvm.relay.Expr
-        The input tensor to be dequantized. Can be of type [int8, uint8, int32].
+        The input tensor to be dequantized. Can be of type [int8, uint8, int16, uint16, int32].
 
     input_scale : tvm.relay.Expr
         The input scale.
@@ -276,13 +276,16 @@ def dequantize(data, input_scale, input_zero_point, axis=-1):
     axis : int
         The channel axis for quantization. Default value is -1 which corresponds to the last axis.
 
+    out_dtype : str, optional
+        The data type of the output tensor. Can be [float16, float32].
+
     Returns
     -------
     result : tvm.relay.Expr
         The computed result.
     """
 
-    return _make.dequantize(data, input_scale, input_zero_point, axis)
+    return _make.dequantize(data, input_scale, input_zero_point, axis, out_dtype)
 
 
 def simulated_dequantize(data, input_scale, input_zero_point, axis=-1, in_dtype="int8"):
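
As a worked instance of the clamp formula in the quantize docstring above, applied to the new uint16 output type (these are the endpoint values exercised by test_float32_to_uint16 further down):

    Q_output = clamp(round(input / output_scale) + output_zero_point, 0, 65535)

    with output_scale = 0.2 and output_zero_point = 32765:
      input = -6553.0  ->  round(-32765) + 32765 = 0      (uint16 min)
      input =  6554.0  ->  round( 32770) + 32765 = 65535  (uint16 max)
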
diff --git a/src/relay/qnn/op/dequantize.cc b/src/relay/qnn/op/dequantize.cc
index 1ddcde8123..5e2ef39eda 100644
--- a/src/relay/qnn/op/dequantize.cc
+++ b/src/relay/qnn/op/dequantize.cc
@@ -47,9 +47,10 @@ bool DequantizeRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
 
   const auto input_dtype = data->dtype;
   ICHECK(input_dtype == DataType::Int(8) || input_dtype == DataType::UInt(8) ||
-         input_dtype == DataType::Int(16) || input_dtype == DataType::Int(32))
-      << "Input type should be one of the quantized types [unit8, int8, int16, int32] but was "
-      << input_dtype;
+         input_dtype == DataType::Int(16) || input_dtype == DataType::UInt(16) ||
+         input_dtype == DataType::Int(32))
+      << "Input type should be one of the quantized types [int8, unit8, int16, uint16, int32] but "
+      << "was " << input_dtype;
 
   const auto* dequantize_attrs = attrs.as<DequantizeAttrs>();
   int axis = dequantize_attrs->axis;
@@ -77,18 +78,24 @@ bool DequantizeRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
   // Check and assign types for scale and zero points.
   AssignType(types[1], DataType::Float(32), axis_shape, reporter);  // scale
   AssignType(types[2], DataType::Int(32), axis_shape, reporter);    // zero point
+
   const Array<tvm::PrimExpr> oshape = data->shape;
-  // assign output type, output will always be float 32.
-  reporter->Assign(types[3], TensorType(oshape, DataType::Float(32)));
+  const DataType out_dtype = dequantize_attrs->out_dtype;
+  ICHECK(out_dtype == DataType::Float(16) || out_dtype == DataType::Float(32))
+      << "Output type should be one of [float16, float32] but was " << out_dtype;
+  // assign output type.
+  reporter->Assign(types[3], TensorType(oshape, out_dtype));
   return true;
 }
 
-Expr MakeDequantize(Expr data, Expr input_scale, Expr input_zero_point, int axis) {
+Expr MakeDequantize(Expr data, Expr input_scale, Expr input_zero_point, int axis,
+                    DataType out_dtype) {
   // real_value = scale * (quantized_value - zero_point)
   // A more detailed explanation can be found here -
   // https://github.com/google/gemmlowp/blob/master/doc/quantization.md
   auto attrs = make_object<DequantizeAttrs>();
   attrs->axis = axis;
+  attrs->out_dtype = out_dtype;
   static const Op& op = Op::Get("qnn.dequantize");
   return Call(op, {data, input_scale, input_zero_point}, Attrs(attrs), {});
 }
@@ -125,7 +132,14 @@ Expr DequantizeLower(const Expr& input_tensor, const Expr& input_scale,
 
   auto shift = Subtract(Cast(input_tensor, DataType::Int(32)), expanded_input_zero_point);
   auto scaled_output = Multiply(Cast(shift, DataType::Float(32)), expanded_input_scale);
-  return scaled_output;
+
+  const DataType out_dtype = attrs->out_dtype;
+  if (out_dtype.is_float() && out_dtype.bits() == 32) return scaled_output;
+
+  double min_val = tvm::min_value(out_dtype).as<FloatImmNode>()->value;
+  double max_val = tvm::max_value(out_dtype).as<FloatImmNode>()->value;
+  auto clamped_output = Clip(scaled_output, min_val, max_val);
+  return Cast(clamped_output, out_dtype);
 }
 
 Expr DequantizeQnnCanonicalize(const Attrs& attrs, const Array<Expr>& new_args,
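
The lowering change above clamps the scaled result to the range of the target float type before the final cast. A rough NumPy sketch of the same computation (an approximation for readers, not the TVM lowering itself), relying on float16's finite range of roughly +/-65504:

    import numpy as np

    def dequantize_ref(q, scale, zero_point, out_dtype="float16"):
        # real_value = scale * (quantized_value - zero_point)
        scaled = scale * (q.astype("int32") - zero_point).astype("float32")
        if out_dtype == "float32":
            return scaled
        # Clamp to the finite range of the target type before casting,
        # mirroring the Clip + Cast added in DequantizeLower.
        info = np.finfo(out_dtype)
        return np.clip(scaled, info.min, info.max).astype(out_dtype)
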
diff --git a/src/relay/qnn/op/quantize.cc b/src/relay/qnn/op/quantize.cc
index 1a16705932..8ed1f9ef4c 100644
--- a/src/relay/qnn/op/quantize.cc
+++ b/src/relay/qnn/op/quantize.cc
@@ -91,8 +91,9 @@ bool QuantizeRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
   const Array<tvm::PrimExpr> oshape = data->shape;
   const DataType out_dtype = quantize_attrs->out_dtype;
   ICHECK(out_dtype == DataType::Int(8) || out_dtype == DataType::UInt(8) ||
-         out_dtype == DataType::Int(16) || out_dtype == DataType::Int(32))
-      << "Output type should be one of [int8, unit8, int16, int32] but was " << out_dtype;
+         out_dtype == DataType::Int(16) || out_dtype == DataType::UInt(16) ||
+         out_dtype == DataType::Int(32))
+      << "Output type should be one of [int8, unit8, int16, uint16, int32] but was " << out_dtype;
   // assign output type
   reporter->Assign(types[3], TensorType(oshape, out_dtype));
   return true;
diff --git a/src/relay/qnn/utils.h b/src/relay/qnn/utils.h
index 5005d60685..4102fb29a6 100644
--- a/src/relay/qnn/utils.h
+++ b/src/relay/qnn/utils.h
@@ -135,7 +135,8 @@ static inline Expr Dequantize(const Expr& data, const Expr& input_scale,
 
   return DequantizeLower(data, input_scale, input_zero_point, types, attrs.operator->());
 }
-Expr MakeDequantize(Expr data, Expr input_scale, Expr input_zero_point, int axis);
+Expr MakeDequantize(Expr data, Expr input_scale, Expr input_zero_point, int axis,
+                    DataType out_dtype = DataType::Float(32));
 
 Expr QuantizeLower(const Expr& input_tensor, const Expr& output_scale,
                    const Expr& output_zero_point, const Array<tvm::relay::Type>& types,
diff --git a/tests/python/relay/test_op_qnn_dequantize.py b/tests/python/relay/test_op_qnn_dequantize.py
index b332bd94f3..3b2ae97eb6 100644
--- a/tests/python/relay/test_op_qnn_dequantize.py
+++ b/tests/python/relay/test_op_qnn_dequantize.py
@@ -23,13 +23,19 @@ from tvm.contrib import graph_executor
 from tvm.relay.testing import run_infer_type
 
 
-def dequantize_test_driver(in_dtype, quant_args, in_data, verify_output_data, axis):
+def dequantize_test_driver(
+    in_dtype, quant_args, in_data, verify_output_data, axis, out_dtype="float32"
+):
     shape = in_data.shape
     input_data = relay.var("input_data", shape=shape, dtype=in_dtype)
     input_zero_point = relay.const(quant_args["in_zero_point"], "int32")
     input_scale = relay.const(quant_args["in_scale"], "float32")
     quantized_output = relay.qnn.op.dequantize(
-        input_data, input_scale=input_scale, input_zero_point=input_zero_point, axis=axis
+        input_data,
+        input_scale=input_scale,
+        input_zero_point=input_zero_point,
+        axis=axis,
+        out_dtype=out_dtype,
     )
     mod = relay.Function(relay.analysis.free_vars(quantized_output), quantized_output)
     mod = tvm.IRModule.from_expr(mod)
@@ -41,7 +47,7 @@ def dequantize_test_driver(in_dtype, quant_args, in_data, verify_output_data, ax
         rt_mod.run()
         res = rt_mod.get_output(0).numpy()
         np.testing.assert_equal(res, verify_output_data)
-        assert res.dtype == np.float32
+        assert res.dtype == out_dtype
 
 
 def test_uint8_to_float32():
@@ -74,6 +80,28 @@ def test_int8_to_float32():
     )
 
 
+def test_int8_to_float16():
+    data = (
+        np.array([-128, -127, -126, -125, -124, 123, 124, 125, 126, 127])
+        .astype("int8")
+        .reshape((2, 5))
+    )
+    output = (
+        np.array([-63.5, -63, -62.5, -62, -61.5, 62, 62.5, 63, 63.5, 64])
+        .astype("float16")
+        .reshape((2, 5))
+    )
+    quant_args = {"in_zero_point": -1, "in_scale": 0.5}
+    dequantize_test_driver(
+        in_dtype="int8",
+        quant_args=quant_args,
+        in_data=data,
+        verify_output_data=output,
+        axis=-1,
+        out_dtype="float16",
+    )
+
+
 def test_scalar_int8_to_float32():
     data = np.array(-128).astype("int8")
     output = np.array(-63.5).astype("float32")
@@ -171,6 +199,7 @@ def test_dynamic_dequantize():
 if __name__ == "__main__":
     test_uint8_to_float32()
     test_int8_to_float32()
+    test_int8_to_float16()
     test_scalar_int8_to_float32()
     test_int32_to_float32()
     test_channelwise_axis_1()
diff --git a/tests/python/relay/test_op_qnn_quantize.py b/tests/python/relay/test_op_qnn_quantize.py
index 322382ca00..3a3521b11e 100644
--- a/tests/python/relay/test_op_qnn_quantize.py
+++ b/tests/python/relay/test_op_qnn_quantize.py
@@ -88,6 +88,28 @@ def test_float32_to_int8():
     )
 
 
+def test_float32_to_uint16():
+    data = (
+        np.array([-6553, -6552.8, -6552.6, -6552.4, -6552.2, 6553.2, 6553.4, 6553.6, 6553.8, 6554])
+        .astype("float32")
+        .reshape((2, 5))
+    )
+    output = (
+        np.array([0, 1, 2, 3, 4, 65531, 65532, 65533, 65534, 65535])
+        .astype("uint16")
+        .reshape((2, 5))
+    )
+    quant_args = {"out_zero_point": np.int32(32765), "out_scale": np.float32(0.2)}
+    quantize_test_driver(
+        in_dtype="float32",
+        quant_args=quant_args,
+        axis=-1,
+        out_dtype="uint16",
+        in_data=data,
+        verify_output_data=output,
+    )
+
+
 def test_scalar_float32_to_int8():
     data = np.array(-63.5).astype("float32")
     output = np.array(-128).astype("int8")
@@ -177,6 +199,7 @@ def test_dynamic_quantize():
 if __name__ == "__main__":
     test_float32_to_uint8()
     test_float32_to_int8()
+    test_float32_to_uint16()
     test_scalar_float32_to_int8()
     test_channelwise_axis_0()
     test_channelwise_axis_1()
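
As a quick sanity check of the new test expectations, the reference math can be reproduced with plain NumPy (this mirrors the test data above and is not the TVM lowering):

    import numpy as np

    # Dequantize int8 -> float16 (scale 0.5, zero point -1), as in test_int8_to_float16:
    q = np.array([-128, 127], dtype="int8")
    print(0.5 * (q.astype("int32") - (-1)))  # [-63.5  64. ]

    # Quantize float32 -> uint16 (scale 0.2, zero point 32765), as in test_float32_to_uint16:
    x = np.array([-6553.0, 6554.0], dtype="float32")
    print(np.clip(np.round(x / 0.2) + 32765, 0, 65535).astype("uint16"))  # [    0 65535]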