Posted to commits@tvm.apache.org by an...@apache.org on 2022/09/19 19:38:39 UTC

[tvm] 27/28: div impl

This is an automated email from the ASF dual-hosted git repository.

andrewzhaoluo pushed a commit to branch aluo/rebase-09192022-autotensorization
in repository https://gitbox.apache.org/repos/asf/tvm.git

commit b8a3df098aa5e9d92d0d7202af5250ddf24e9709
Author: Andrew Zhao Luo <an...@gmail.com>
AuthorDate: Fri Sep 16 15:47:02 2022 -0700

    div impl
---
 python/tvm/relay/qnn/op/qnn.py                     |  68 ++++++++++++
 .../transform/fake_quantization_to_integer.py      |  88 +++++++++++++++-
 src/relay/qnn/op/div.cc                            | 117 +++++++++++++++++++++
 3 files changed, 272 insertions(+), 1 deletion(-)

diff --git a/python/tvm/relay/qnn/op/qnn.py b/python/tvm/relay/qnn/op/qnn.py
index 1f38385107..6d1cabeb8d 100644
--- a/python/tvm/relay/qnn/op/qnn.py
+++ b/python/tvm/relay/qnn/op/qnn.py
@@ -788,6 +788,74 @@ def mul(
     )
 
 
+def div(
+    lhs,
+    rhs,
+    lhs_scale,
+    lhs_zero_point,
+    rhs_scale,
+    rhs_zero_point,
+    output_scale,
+    output_zero_point,
+    lhs_axis=-1,
+    rhs_axis=-1,
+):
+    """Quantized division with numpy-style broadcasting.
+
+    Parameters
+    ----------
+    lhs : relay.Expr
+        The left hand side quantized input data.
+
+    rhs : relay.Expr
+        The right hand side quantized input data.
+
+    lhs_scale: relay.Expr
+        The scale of the lhs quantized expr.
+
+    lhs_zero_point: relay.Expr
+       The zero point of lhs quantized expr.
+
+    rhs_scale: relay.Expr
+        The scale of the rhs quantized expr.
+
+    rhs_zero_point: relay.Expr
+       The zero point of rhs quantized expr.
+
+    output_scale: relay.Expr
+        The scale of the output quantized expr.
+
+    output_zero_point: relay.Expr
+       The zero point of output quantized expr.
+
+    lhs_axis: int
+        The channel axis for lhs quantization. Default value is -1 which corresponds
+        to the last axis.
+
+    rhs_axis: int
+        The channel axis for rhs quantization. Default value is -1 which corresponds
+        to the last axis.
+
+    Returns
+    -------
+    result : relay.Expr
+        The computed result.
+
+    """
+    return _make.div(
+        lhs,
+        rhs,
+        lhs_scale,
+        lhs_zero_point,
+        rhs_scale,
+        rhs_zero_point,
+        output_scale,
+        output_zero_point,
+        lhs_axis,
+        rhs_axis,
+    )
+
+
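
For reference, a minimal usage sketch of the qnn.div wrapper added above (not part of the diff); the shapes, dtypes, and quantization constants below are made-up illustrative values:

    from tvm import relay

    # Hypothetical int8 inputs with per-tensor quantization parameters.
    lhs = relay.var("lhs", shape=(1, 16), dtype="int8")
    rhs = relay.var("rhs", shape=(1, 16), dtype="int8")

    out = relay.qnn.op.div(
        lhs,
        rhs,
        lhs_scale=relay.const(0.05, "float32"),
        lhs_zero_point=relay.const(0, "int32"),
        rhs_scale=relay.const(0.10, "float32"),
        rhs_zero_point=relay.const(0, "int32"),
        output_scale=relay.const(0.02, "float32"),
        output_zero_point=relay.const(0, "int32"),
    )
    # `out` is a relay.Expr and can be wrapped in a Function/IRModule as usual.
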
 def tanh(x, scale, zero_point, output_scale, output_zero_point):
     """Quantized tanh.
 
diff --git a/python/tvm/relay/transform/fake_quantization_to_integer.py b/python/tvm/relay/transform/fake_quantization_to_integer.py
index 242740399f..82afb1b4c3 100644
--- a/python/tvm/relay/transform/fake_quantization_to_integer.py
+++ b/python/tvm/relay/transform/fake_quantization_to_integer.py
@@ -19,6 +19,7 @@ import numpy as np
 import tvm
 from tvm import relay
 from tvm.ir import TensorAffineType, TupleAffineType
+from tvm.relay.op.tensor import ones_like
 
 # import to register canonicalization funcs for fq2i
 # pylint: disable=unused-import
@@ -198,6 +199,60 @@ def broadcast_to(expr, type_map):
     return [out, t]
 
 
+@register_fake_quantization_to_integer("take")
+def take(expr, type_map):
+    """Rewrite a take op"""
+    arg1 = expr.args[0]
+    t = type_map[arg1]
+    arg2 = expr.args[1]
+    out = relay.op.take(
+        arg1,
+        arg2,
+        axis=expr.attrs.axis,
+        batch_dims=expr.attrs.batch_dims,
+        mode=expr.attrs.mode,
+    )
+    return [out, t]
+
+
+@register_fake_quantization_to_integer("power")
+def power(expr, type_map):
+    """Rewrite a power op, lowering the constant-exponent case x ** 2 to qnn.mul(x, x)"""
+    base = expr.args[0]
+    exponent = expr.args[1]
+
+    base_type = type_map[base]
+
+    if not isinstance(exponent, relay.Constant):
+        return [expr, type_map[expr]]
+
+    data = exponent.data.numpy()
+    if len(data.shape) != 0:
+        return [expr, type_map[expr]]
+
+    data = data.item()
+    if data != 2:
+        return [expr, type_map[expr]]
+
+    out = relay.qnn.op.mul(
+        base,
+        base,
+        base_type.scale,
+        base_type.zero_point,
+        base_type.scale,
+        base_type.zero_point,
+        output_scale=base_type.scale * base_type.scale,
+        output_zero_point=base_type.zero_point,
+        lhs_axis=base_type.axis,
+        rhs_axis=base_type.axis,
+    )
+    return [
+        out,
+        TensorAffineType(
+            base_type.scale * base_type.scale, base_type.zero_point, base_type.dtype, base_type.axis
+        ),
+    ]
+
+
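
A small NumPy check (illustrative, not from the diff) of the affine-type bookkeeping in the power rewrite above: if the input is x = s * (q - z), then qnn.mul(x, x) with output_scale = s * s and the original zero point recovers x ** 2, which is why the returned TensorAffineType squares the scale:

    import numpy as np

    # Made-up per-tensor parameters for a quantized tensor x = s * (q - z).
    s, z = 0.05, 10
    q = np.array([12, 30, 64], dtype=np.int32)
    x = s * (q - z)

    # qnn.mul(x, x, ..., output_scale=s*s, output_zero_point=z) yields q_out with
    # (s * s) * (q_out - z) ~= x * x (ignoring dtype range/saturation here).
    q_out = np.round((x * x) / (s * s)).astype(np.int32) + z

    print((s * s) * (q_out - z))  # reconstructed x ** 2
    print(x * x)                  # floating-point reference
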
 @register_fake_quantization_to_integer("nn.bias_add")
 def bias_add(expr, type_map):
     """Rewrite a bias_add op"""
@@ -520,6 +575,37 @@ def register_binary_qnn(op_name, op):
 register_binary_qnn("add", lambda *args: relay.qnn.op.add(*args))
 register_binary_qnn("multiply", lambda *args: relay.qnn.op.mul(*args))
 register_binary_qnn("subtract", lambda *args: relay.qnn.op.subtract(*args))
+register_binary_qnn("divide", lambda *args: relay.qnn.op.div(*args))
+
+
+'''
+@register_fake_quantization_to_integer("divide")
+def divide(expr, type_map):
+    """Rewrite an adaptive avgpool op"""
+    numerator = expr.args[0]
+    denominator = expr.args[1]
+    numerator_t = type_map[numerator]
+    denominator_t = type_map[denominator]
+    new_scale = numerator_t.scale / (denominator_t.scale * (denominator - denominator_t.zero_point))
+    out = relay.divide(numerator, ones_like(denominator))
+    assert numerator_t.axis == denominator_t.axis, "Only support identical axis for now."
+    # print(out)
+
+    print("new out:")
+    str_new_out = str(relay.transform.InferType()(tvm.IRModule.from_expr(out)))
+    print("\n".join(str_new_out.split("\n")[-10:]))
+    print("old_out:")
+    str_old_out = str(relay.transform.InferType()(tvm.IRModule.from_expr(expr)))
+    print("\n".join(str_old_out.split("\n")[-10:]))
+    print()
+    breakpoint()
+    # print("yay!")
+    # This is to get broadcasting working to get same shape
+    return [
+        out,
+        TensorAffineType(new_scale, numerator_t.zero_point, numerator_t.dtype, numerator_t.axis),
+    ]
+'''
 
 
 def register_binary_identity(op_name, op):
@@ -585,4 +671,4 @@ register_unary_qnn("sigmoid", relay.qnn.op.sigmoid)
 register_unary_qnn("hardswish", relay.qnn.op.hardswish)
 register_unary_qnn("tanh", relay.qnn.op.tanh)
 register_unary_qnn("abs", relay.qnn.op.abs)
-register_unary_qnn("log", relay.qnn.op.log)
+register_unary_qnn("log", relay.qnn.op.log)
\ No newline at end of file
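
For context, a rough sketch (not part of the diff) of the kind of graph the new "divide" registration targets: a dequantize -> divide -> quantize chain that the FakeQuantizationToInteger pass should now rewrite into qnn.div on the integer tensors. The shapes and quantization constants are illustrative:

    import tvm
    from tvm import relay

    x = relay.var("x", shape=(1, 8), dtype="int8")
    y = relay.var("y", shape=(1, 8), dtype="int8")

    # Fake-quantized divide: dequantize both sides, divide in float, re-quantize.
    xf = relay.qnn.op.dequantize(x, relay.const(0.05), relay.const(0))
    yf = relay.qnn.op.dequantize(y, relay.const(0.10), relay.const(0))
    out = relay.qnn.op.quantize(
        relay.divide(xf, yf), relay.const(0.02), relay.const(0), out_dtype="int8"
    )

    mod = tvm.IRModule.from_expr(relay.Function([x, y], out))
    mod = relay.transform.InferType()(mod)
    # With the "divide" registration above, this pass should replace the float
    # divide with relay.qnn.op.div operating directly on the int8 tensors.
    mod = relay.transform.FakeQuantizationToInteger()(mod)
    print(mod)
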
diff --git a/src/relay/qnn/op/div.cc b/src/relay/qnn/op/div.cc
new file mode 100644
index 0000000000..3c37ed41c4
--- /dev/null
+++ b/src/relay/qnn/op/div.cc
@@ -0,0 +1,117 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relay/qnn/op/div.cc
+ * \brief QNN div operator.
+ */
+#include <tvm/relay/analysis.h>
+#include <tvm/relay/op_attr_types.h>
+#include <tvm/relay/qnn/attrs.h>
+
+#include "../../transforms/pattern_utils.h"
+#include "../utils.h"
+#include "op_common.h"
+
+namespace tvm {
+namespace relay {
+namespace qnn {
+
+/*
+ * \brief Canonicalizes the QNN div op.
+ * \param attrs The QNN div attrs.
+ * \param new_args The new mutated args to the call node.
+ * \param arg_types The types of input and output.
+ * \return The sequence of Relay ops for div op.
+ */
+Expr QnnDivCanonicalize(const Attrs& attrs, const Array<Expr>& new_args,
+                        const Array<tvm::relay::Type>& arg_types) {
+  Expr output;
+
+  // Get the attrs.
+  QnnBinaryOpArguments args(new_args);
+
+  // Get the input dtype and shape.
+  QnnBinaryOpTensorType input_type(arg_types, 0);
+
+  // data types
+  const auto int32_dtype = DataType::Int(32);
+  const auto float32_dtype = DataType::Float(32);
+
+  const auto* broadcast_attrs = attrs.as<BroadcastAttrs>();
+  ICHECK(broadcast_attrs != nullptr);
+
+  if (IsConstScalar(args.lhs_scale) && IsConstScalar(args.rhs_scale)) {
+    /* If both are constant:
+
+    n1/n2 = [s1(q1-z1)] / [s2(q2-z2)]
+    n1/n2 = [s1/s2][(q1-z1)/(q2-z2)]
+
+    As [(q1-z1)/(q2-z2)] is integer division, we may lose significant precision.
+    To get around this we scale the numerator by C to ensure that
+
+    |C(q1-z1)| >> (q2 - z2) and the precision loss from the division is minimal:
+
+    n1/n2 = [s1/(s2 * C)][C(q1-z1)/(q2-z2)]
+    */
+
+    auto lhs_shifted = Cast(args.lhs, int32_dtype);
+    auto rhs_shifted = Cast(args.rhs, int32_dtype);
+
+    auto zero_scalar = MakeConstantScalar(int32_dtype, 0);
+    if (!IsEqualScalar(args.lhs_zero_point, zero_scalar)) {
+      lhs_shifted = Subtract(lhs_shifted, args.lhs_zero_point);
+    }
+
+    if (!IsEqualScalar(args.rhs_zero_point, zero_scalar)) {
+      rhs_shifted = Subtract(rhs_shifted, args.rhs_zero_point);
+    }
+
+    // Multiply the numerator to avoid precision loss: since we accumulate in INT32
+    // and may deal with UINT16 inputs, multiply by 2^15.
+    int divide_scale_factor = 32768;
+    auto divide_scale_factor_constant = MakeConstantScalar(int32_dtype, divide_scale_factor);
+    output = Divide(Multiply(lhs_shifted, divide_scale_factor_constant), rhs_shifted);
+
+    // Get the adjusted new scale and zero points.
+    float lhs_scale_float = GetScalarFromConstant<float>(args.lhs_scale);
+    float rhs_scale_float = GetScalarFromConstant<float>(args.rhs_scale);
+    float new_scale_float = lhs_scale_float / (rhs_scale_float * divide_scale_factor);
+    auto new_input_scale = MakeConstantScalar(float32_dtype, new_scale_float);
+    auto new_input_zero_point = zero_scalar;
+
+    // Requantize to get Q_c
+    output = Requantize(output, input_type.shape, new_input_scale, new_input_zero_point,
+                        args.output_scale, args.output_zero_point, input_type.dtype);
+  } else {
+    LOG(FATAL) << "Non-constant scale_factor not supported yet.";
+  }
+
+  return output;
+}
+
+// QNN Division operator.
+QNN_REGISTER_BINARY_OP("div")
+    .describe("Elementwise div with broadcasting for quantized tensors.")
+    .set_support_level(11)
+    .set_attr<FTVMLegalize>("FTVMQnnCanonicalize", QnnDivCanonicalize);
+
+}  // namespace qnn
+}  // namespace relay
+}  // namespace tvm
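
As a sanity check on the constant-scale path above, here is a short NumPy sketch (not part of the diff) of the arithmetic QnnDivCanonicalize emits; the quantization parameters are made up, and the final requantize step is done in floating point for readability rather than with TVM's fixed-point Requantize:

    import numpy as np

    # Made-up per-tensor parameters: n1 = s1*(q1 - z1), n2 = s2*(q2 - z2).
    s1, z1 = 0.05, 3
    s2, z2 = 0.10, -2
    s_out, z_out = 0.02, 0

    q1 = np.array([40, 80, 120], dtype=np.int32)
    q2 = np.array([12, 18, 30], dtype=np.int32)

    C = 32768  # divide_scale_factor (2**15) from the canonicalization

    # Integer division on the scaled numerator: C*(q1 - z1) / (q2 - z2).
    t = (C * (q1 - z1)) // (q2 - z2)

    # Requantize from scale s1/(s2*C), zero point 0, to the output parameters.
    new_scale = s1 / (s2 * C)
    q_out = np.round(t * new_scale / s_out).astype(np.int32) + z_out

    print(s_out * (q_out - z_out))              # dequantized integer result
    print((s1 * (q1 - z1)) / (s2 * (q2 - z2)))  # floating-point reference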