Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2022/02/01 18:29:52 UTC

[GitHub] [tvm] anwang2009 commented on a change in pull request #10086: [QNN] Register a bunch of unary elementwise ops

anwang2009 commented on a change in pull request #10086:
URL: https://github.com/apache/tvm/pull/10086#discussion_r796287746



##########
File path: src/relay/transforms/pattern_utils.h
##########
@@ -520,8 +525,8 @@ inline Expr FastSoftmax(Expr e, tvm::Attrs attr) {
   return Call(op, {e}, attr);
 }
 
-inline Expr Log(Expr e) {
-  static const Op& op = Op::Get("log");

Review comment:
       why remove log?
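
For context, the removed helper presumably followed the same pattern as the neighboring helpers in pattern_utils.h (the hunk above is truncated). A rough reconstruction, not a quote of the actual file:

    // Reconstructed from the two removed lines shown above plus the pattern of
    // the surrounding helpers (e.g. FastSoftmax); the body after Op::Get is assumed.
    inline Expr Log(Expr e) {
      static const Op& op = Op::Get("log");
      return Call(op, {e});
    }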

##########
File path: src/relay/qnn/op/op_common.h
##########
@@ -289,6 +289,89 @@ static inline bool QnnBroadcastRel(const Array<Type>& types, int num_inputs, con
       .set_attr<TNonComputational>("TNonComputational", true)                                      \
       .set_attr<FInferCorrectLayout>("FInferCorrectLayout", QnnBinaryBroadcastLayout)
 
+static inline bool QnnElementwiseUnaryFuncRel(const Array<Type>& types, int num_inputs,
+                                              const Attrs& attrs, const TypeReporter& reporter) {
+  // Expected Types: data, scale, zero_point, output_scale, output_zero_point
+  ICHECK_EQ(types.size(), 6);
+  const auto* x = types[0].as<TensorTypeNode>();
+  if (x == nullptr) return false;
+  ICHECK(x->dtype == DataType::Int(8) || x->dtype == DataType::UInt(8))
+      << "Expected quantized type(int8, uint8) for input but was " << x->dtype;
+
+  // Check the types of scale and zero points.
+  for (size_t i = 1; i < 5; ++i) {
+    if (types[i].as<IncompleteTypeNode>()) {
+      return false;
+    }
+  }
+  ICHECK(IsScalarType(types[1], DataType::Float(32)));  // scale
+  ICHECK(IsScalarType(types[2], DataType::Int(32)));    // zero_point
+  ICHECK(IsScalarType(types[3], DataType::Float(32)));  // output_scale
+  ICHECK(IsScalarType(types[4], DataType::Int(32)));    // output_zero_point
+
+  // Assign types for scale and zero points.
+  reporter->Assign(types[1], TensorType({}, DataType::Float(32)));  // scale
+  reporter->Assign(types[2], TensorType({}, DataType::Int(32)));    // zero_point
+  reporter->Assign(types[3], TensorType({}, DataType::Float(32)));  // output_scale
+  reporter->Assign(types[4], TensorType({}, DataType::Int(32)));    // output_zero_point
+
+  // Collect the input tensor and output tensor devoid of scale and zero points to reuse Relay
+  // IdentityRel infer type function.
+  Array<Type> tensor_types = {types[0], types[5]};
+  return IdentityRel(tensor_types, 2, attrs, reporter);
+}
+
+/*! Quick helper macro
+ * - Expose a positional make function to construct the node.
+ * - Register op to the registry.
+ *
+ * For Unary Operators which also take in QParams.
+ *
+ * \param OpName the name of registry.
+ */
+#define QNN_CREATE_UNARY_ELEMENTWISE_OP(OpName)                                                 \
+  TVM_REGISTER_GLOBAL("relay.qnn.op._make." OpName)                                             \
+      .set_body_typed(                                                                          \
+          [](Expr x, Expr scale, Expr zero_point, Expr output_scale, Expr output_zero_point) {  \
+            return Call(Op::Get("qnn." OpName),                                                 \
+                        {x, scale, zero_point, output_scale, output_zero_point}, Attrs(), {});  \
+          });                                                                                   \
+                                                                                                \
+  RELAY_REGISTER_OP("qnn." OpName)                                                              \
+      .describe("Elementwise " OpName " for quantized tensors.")                                \
+      .set_num_inputs(5)                                                                        \
+      .add_argument("data", "Quantized Tensor", "The input data.")                              \
+      .add_argument("scale", "Tensor", "The quantization scale of the input tensor.")           \
+      .add_argument("zero_point", "Tensor", "The quantization zero_point of the input tensor.") \
+      .add_argument("output_scale", "Tensor", "The quantization scale of the output tensor.")   \
+      .add_argument("output_zero_point", "Tensor",                                              \
+                    "The quantization zero_point of the output tensor.")                        \
+      .set_support_level(11)                                                                    \
+      .add_type_rel("qnn." OpName, QnnElementwiseUnaryFuncRel)                                  \
+      .set_attr<TNonComputational>("TNonComputational", true)
+
+/*! Quick helper macro
+ * Create a default canonicalization for a QNN operator, which dequantizes the operator
+ * runs the calculation using the provided Call func, and then requantizes.
+ *
+ * FloatingPointFunc is usually a handle from "src/relay/transforms/pattern_utils.h"
+ *
+ * \param OpName the name of registry.

Review comment:
      nit: update this \param to describe "FloatingPointFunc"
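
One way the nit could be addressed (a sketch only; the exact wording is the author's call) is to make the parameter documentation describe the macro's actual argument:

    /*! Quick helper macro
     * Create a default canonicalization for a QNN operator: dequantize the input,
     * run the computation with the given floating-point function, then requantize.
     *
     * \param FloatingPointFunc the floating-point op handle used between the
     *        dequantize and requantize steps, usually taken from
     *        "src/relay/transforms/pattern_utils.h".
     */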

##########
File path: src/relay/qnn/op/op_common.h
##########
@@ -289,6 +289,89 @@ static inline bool QnnBroadcastRel(const Array<Type>& types, int num_inputs, con
       .set_attr<TNonComputational>("TNonComputational", true)                                      \
       .set_attr<FInferCorrectLayout>("FInferCorrectLayout", QnnBinaryBroadcastLayout)
 
+static inline bool QnnElementwiseUnaryFuncRel(const Array<Type>& types, int num_inputs,
+                                              const Attrs& attrs, const TypeReporter& reporter) {
+  // Expected Types: data, scale, zero_point, output_scale, output_zero_point
+  ICHECK_EQ(types.size(), 6);
+  const auto* x = types[0].as<TensorTypeNode>();
+  if (x == nullptr) return false;
+  ICHECK(x->dtype == DataType::Int(8) || x->dtype == DataType::UInt(8))
+      << "Expected quantized type(int8, uint8) for input but was " << x->dtype;
+
+  // Check the types of scale and zero points.
+  for (size_t i = 1; i < 5; ++i) {
+    if (types[i].as<IncompleteTypeNode>()) {
+      return false;
+    }
+  }
+  ICHECK(IsScalarType(types[1], DataType::Float(32)));  // scale
+  ICHECK(IsScalarType(types[2], DataType::Int(32)));    // zero_point
+  ICHECK(IsScalarType(types[3], DataType::Float(32)));  // output_scale
+  ICHECK(IsScalarType(types[4], DataType::Int(32)));    // output_zero_point
+
+  // Assign types for scale and zero points.
+  reporter->Assign(types[1], TensorType({}, DataType::Float(32)));  // scale
+  reporter->Assign(types[2], TensorType({}, DataType::Int(32)));    // zero_point
+  reporter->Assign(types[3], TensorType({}, DataType::Float(32)));  // output_scale
+  reporter->Assign(types[4], TensorType({}, DataType::Int(32)));    // output_zero_point
+
+  // Collect the input tensor and output tensor devoid of scale and zero points to reuse Relay
+  // IdentityRel infer type function.
+  Array<Type> tensor_types = {types[0], types[5]};
+  return IdentityRel(tensor_types, 2, attrs, reporter);
+}
+
+/*! Quick helper macro
+ * - Expose a positional make function to construct the node.
+ * - Register op to the registry.
+ *
+ * For Unary Operators which also take in QParams.
+ *
+ * \param OpName the name of registry.
+ */
+#define QNN_CREATE_UNARY_ELEMENTWISE_OP(OpName)                                                 \
+  TVM_REGISTER_GLOBAL("relay.qnn.op._make." OpName)                                             \
+      .set_body_typed(                                                                          \
+          [](Expr x, Expr scale, Expr zero_point, Expr output_scale, Expr output_zero_point) {  \
+            return Call(Op::Get("qnn." OpName),                                                 \
+                        {x, scale, zero_point, output_scale, output_zero_point}, Attrs(), {});  \
+          });                                                                                   \
+                                                                                                \
+  RELAY_REGISTER_OP("qnn." OpName)                                                              \
+      .describe("Elementwise " OpName " for quantized tensors.")                                \
+      .set_num_inputs(5)                                                                        \
+      .add_argument("data", "Quantized Tensor", "The input data.")                              \
+      .add_argument("scale", "Tensor", "The quantization scale of the input tensor.")           \
+      .add_argument("zero_point", "Tensor", "The quantization zero_point of the input tensor.") \
+      .add_argument("output_scale", "Tensor", "The quantization scale of the output tensor.")   \
+      .add_argument("output_zero_point", "Tensor",                                              \
+                    "The quantization zero_point of the output tensor.")                        \
+      .set_support_level(11)                                                                    \
+      .add_type_rel("qnn." OpName, QnnElementwiseUnaryFuncRel)                                  \
+      .set_attr<TNonComputational>("TNonComputational", true)
+
+/*! Quick helper macro
+ * Create a default canonicalization for a QNN operator, which dequantizes the operator
+ * runs the calculation using the provided Call func, and then requantizes.
+ *
+ * FloatingPointFunc is usually a handle from "src/relay/transforms/pattern_utils.h"
+ *
+ * \param OpName the name of registry.
+ */
+#define QNN_UNARY_OP_DEFAULT_CANONICALIZATION(FloatingPointFunc)                                  \
+  [](const Attrs& attrs, const Array<Expr>& new_args, const Array<tvm::relay::Type>& arg_types) { \
+    QnnUnaryOpArguments args(new_args);                                                           \
+    QnnUnaryOpTensorType input_type(arg_types, 0);                                                \
+    Array<tvm::relay::Type> types;                                                                \
+    for (size_t i = 1; i < 5; ++i) {                                                              \
+      types.push_back(arg_types[i]);                                                              \
+    }                                                                                             \
+    auto dequantized_arg = Dequantize(args.x, args.scale, args.zero_point, types, -1);            \

Review comment:
       Looks like Dequantize -> DequantizeLower expects `types` to start with the input data type, but in this code `types` starts with the scale type.
   
   https://github.com/apache/tvm/blob/main/src/relay/qnn/op/dequantize.cc#L99-L105
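
To make the index mismatch concrete, here is a hedged sketch of the reordering this comment points at (illustrative only; `dequantize_types` is a hypothetical name, and the exact set of types DequantizeLower expects should be checked against the linked code):

    // Hypothetical fix sketch: DequantizeLower reads types[0] as the type of the
    // tensor being dequantized, so the array handed to Dequantize should start
    // with the input data type rather than the scale type.
    Array<tvm::relay::Type> dequantize_types;
    dequantize_types.push_back(arg_types[0]);  // input data type first
    dequantize_types.push_back(arg_types[1]);  // scale
    dequantize_types.push_back(arg_types[2]);  // zero_point
    // ...plus whatever trailing type(s) DequantizeLower expects, instead of the
    // current arg_types[1..4] slice that leaves the scalar scale type at index 0.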

##########
File path: tests/python/relay/test_op_qnn_unary_elementwise.py
##########
@@ -0,0 +1,214 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from typing import Callable, List
+
+import numpy as np
+import pytest
+import scipy.special
+import tvm
+from tvm import relay
+
+
+def dequantize(data, scale, zp):
+    return scale * (np.asarray(data) - zp)
+
+
+def generate_golden_output(
+    floating_point_golden_func, dequantized_x, output_scale, output_zero_point, dtype
+):
+    output = floating_point_golden_func(dequantized_x)
+    output = np.around(output / output_scale + output_zero_point)
+
+    np_dtype = {"int8": np.int8, "uint8": np.uint8}[dtype]
+
+    q_min = np.iinfo(np_dtype).min
+    q_max = np.iinfo(np_dtype).max
+    return np.clip(output, q_min, q_max)
+
+
+def run_qnn_func(func: relay.Function, args: List[relay.Expr]):
+    mod = tvm.IRModule.from_expr(func)
+    mod = relay.transform.InferType()(mod)
+    mod = relay.qnn.transform.Legalize()(mod)
+    mod = relay.qnn.transform.CanonicalizeOps()(mod)
+    func = mod["main"]
+
+    op_res = relay.create_executor("graph", device=tvm.cpu(0), target="llvm").evaluate(func)(*args)
+    return op_res.numpy()
+
+
+def create_qnn_func(
+    qnn_op: Callable[[relay.Expr, relay.Expr, relay.Expr, relay.Expr, relay.Expr], relay.Call],
+    x_data: np.ndarray,
+    input_scale: float,
+    input_zero_point: int,
+    output_scale: float,
+    output_zero_point: int,
+    input_dtype: str = "uint8",
+):
+    x = relay.var("x", shape=x_data.shape, dtype=input_dtype)
+    y = qnn_op(
+        x=x,
+        scale=relay.const(input_scale, "float32"),
+        zero_point=relay.const(input_zero_point, "int32"),
+        output_scale=relay.const(output_scale, "float32"),
+        output_zero_point=relay.const(output_zero_point, "int32"),
+    )
+    return relay.Function([x], y)
+
+
+def run_condition(
+    qnn_op: Callable[[relay.Expr, relay.Expr, relay.Expr, relay.Expr, relay.Expr], relay.Call],
+    floating_point_golden_func: Callable[[np.ndarray], np.ndarray],
+    x_data: np.ndarray,
+    input_scale: float,
+    input_zero_point: int,
+    output_scale: float,
+    output_zero_point: int,
+    input_dtype: str = "uint8",
+):
+    func = create_qnn_func(
+        qnn_op,
+        x_data,
+        input_scale=input_scale,
+        input_zero_point=input_zero_point,
+        output_scale=output_scale,
+        output_zero_point=output_zero_point,
+        input_dtype=input_dtype,
+    )
+
+    x_dequantized = dequantize(x_data, input_scale, input_zero_point)
+    golden_output = generate_golden_output(
+        floating_point_golden_func,
+        x_dequantized,
+        output_scale,
+        output_zero_point,
+        dtype=input_dtype,
+    )
+
+    op_res = run_qnn_func(func, [x_data])
+    np.testing.assert_equal(op_res, golden_output.astype(input_dtype))
+
+
+def generic_test(
+    qnn_op: Callable[[relay.Expr, relay.Expr, relay.Expr, relay.Expr, relay.Expr], relay.Call],
+    floating_point_golden_func: Callable[[np.ndarray], np.ndarray],
+    input_dtype: str = "uint8",
+    x_data: np.ndarray = np.arange(0, 256, dtype="uint8"),
+):
+    x_data = x_data.view(input_dtype)
+    return run_condition(
+        qnn_op,
+        floating_point_golden_func,
+        x_data,
+        input_scale=0.125,
+        input_zero_point=0,
+        output_scale=0.125,
+        output_zero_point=0,
+        input_dtype=input_dtype,
+    )
+
+
+class TestRSqrt:
+    def test_saturation(self):
+        # Same qparams in and out
+        x_data = np.array((255, 133, 0, 9)).reshape((1, 4))
+        run_condition(
+            relay.qnn.op.rsqrt,
+            lambda x: 1 / np.sqrt(x),
+            x_data,
+            input_scale=0.125,
+            input_zero_point=0,
+            output_scale=0.125,
+            output_zero_point=0,
+            input_dtype="uint8",
+        )
+
+        # Different scale
+        run_condition(
+            relay.qnn.op.rsqrt,
+            lambda x: 1 / np.sqrt(x),
+            x_data,
+            input_scale=0.125,
+            input_zero_point=0,
+            output_scale=0.25,
+            output_zero_point=0,
+            input_dtype="uint8",
+        )
+
+    def test_all_numbers_uint8(self):
+        generic_test(relay.qnn.op.rsqrt, lambda x: 1 / np.sqrt(x), input_dtype="uint8")
+
+    def test_all_numbers_int8(self):
+        generic_test(
+            relay.qnn.op.rsqrt,
+            lambda x: 1 / np.sqrt(x),
+            input_dtype="int8",
+            x_data=np.arange(1, 128, dtype="int8"),
+        )
+
+
+class Sqrt:
+    def test_all_numbers_uint8(self):
+        generic_test(relay.qnn.op.sqrt, np.sqrt, input_dtype="uint8")
+
+    def test_all_numbers_int8(self):
+        generic_test(
+            relay.qnn.op.sqrt,
+            np.sqrt,
+            input_dtype="int8",
+            x_data=np.arange(1, 128, dtype="int8"),

Review comment:
       any reason this test has `x_data` specified but the other int8 tests don't?




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org