You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by an...@apache.org on 2022/09/19 23:38:27 UTC
[tvm] 01/01: changes change
This is an automated email from the ASF dual-hosted git repository.
andrewzhaoluo pushed a commit to branch aluo/autotensorization-redux-09192022
in repository https://gitbox.apache.org/repos/asf/tvm.git
commit 72d4ca5218e1a1ee0378bc6681eff756598e3598
Author: Andrew Zhao Luo <an...@gmail.com>
AuthorDate: Mon Sep 19 12:58:43 2022 -0700
changes change
---
python/tvm/meta_schedule/default_config.py | 108 ++++++++++++++++++-
python/tvm/relay/op/contrib/dnnl.py | 64 ++++++++---
python/tvm/relay/qnn/op/qnn.py | 68 ++++++++++++
python/tvm/relay/qnn/transform.py | 78 ++++++++++++++
.../transform/fake_quantization_to_integer.py | 86 +++++++++++++++
src/relay/qnn/op/div.cc | 117 +++++++++++++++++++++
6 files changed, 501 insertions(+), 20 deletions(-)
diff --git a/python/tvm/meta_schedule/default_config.py b/python/tvm/meta_schedule/default_config.py
index ac4028ec50..eaa026e3b4 100644
--- a/python/tvm/meta_schedule/default_config.py
+++ b/python/tvm/meta_schedule/default_config.py
@@ -20,6 +20,8 @@ import logging
from os import path as osp
from typing import Callable, Dict, List, Optional, Union
+from tvm._ffi.registry import register_func
+from tvm.contrib import nvcc
from tvm.ir import IRModule
from tvm.target import Target
from tvm.tir import PrimFunc
@@ -43,6 +45,20 @@ FnPostproc = Callable[[], List[Postproc]]
FnMutatorProb = Callable[[], Dict[Mutator, float]]
+def target_has_vnni(target):
+ return target in {
+ "cascadelake",
+ "icelake-client",
+ "icelake-server",
+ "rocketlake",
+ "tigerlake",
+ "cooperlake",
+ "sapphirerapids",
+ "alderlake",
+ }
+
+
+@register_func("tvm.meta_schedule.tune.parse_mod") # for use in ApplyHistoryBest
def mod(mod: Union[PrimFunc, IRModule]) -> IRModule: # pylint: disable=redefined-outer-name
"""Normalize the input to an IRModule"""
if isinstance(mod, PrimFunc):
@@ -174,9 +190,13 @@ def schedule_rules( # pylint: disable=redefined-outer-name
return sch_rules()
if sch_rules is not None:
raise TypeError(f"Expected `sch_rules` to be None or callable, but gets: {sch_rules}")
- if target.kind.name in ["llvm", "hexagon"]:
+ if target.kind.name == "llvm":
+ if target_has_vnni(target.mcpu):
+ return _DefaultLLVMVNNI.schedule_rules()
return _DefaultLLVM.schedule_rules()
if target.kind.name in ["cuda", "rocm", "vulkan"]:
+ if target.kind.name == "cuda" and nvcc.have_tensorcore(target=target):
+ return _DefaultCUDATensorCore.schedule_rules()
return _DefaultCUDA.schedule_rules()
raise ValueError(f"Unsupported target: {target}")
@@ -190,9 +210,13 @@ def postproc( # pylint: disable=redefined-outer-name
return postproc()
if postproc is not None:
raise TypeError(f"Expected `postproc` to be None or callable, but gets: {postproc}")
- if target.kind.name in ["llvm", "hexagon"]:
+ if target.kind.name == "llvm":
+ if target_has_vnni(target.mcpu):
+ return _DefaultLLVMVNNI.postprocs()
return _DefaultLLVM.postprocs()
if target.kind.name in ["cuda", "rocm", "vulkan"]:
+ if target.kind.name == "cuda" and nvcc.have_tensorcore(target=target):
+ return _DefaultCUDATensorCore.postprocs()
return _DefaultCUDA.postprocs()
raise ValueError(f"Unsupported target: {target}")
@@ -208,9 +232,13 @@ def mutator_probs( # pylint: disable=redefined-outer-name
raise TypeError(
f"Expected `mutator_probs` to be None or callable, but gets: {mutator_probs}"
)
- if target.kind.name in ["llvm", "hexagon"]:
+ if target.kind.name == "llvm":
+ if target_has_vnni(target.mcpu):
+ return _DefaultLLVMVNNI.mutator_probs()
return _DefaultLLVM.mutator_probs()
if target.kind.name in ["cuda", "rocm", "vulkan"]:
+ if target.kind.name == "cuda" and nvcc.have_tensorcore(target=target):
+ return _DefaultCUDATensorCore.mutator_probs()
return _DefaultCUDA.mutator_probs()
raise ValueError(f"Unsupported target: {target}")
@@ -277,6 +305,78 @@ class _DefaultLLVM:
}
+class _DefaultLLVMVNNI:
+ """Default tuning configuration for LLVM with VNNI."""
+
+ @staticmethod
+ def schedule_rules() -> List[ScheduleRule]:
+ from tvm.meta_schedule import schedule_rule as M
+ from tvm.tir.tensor_intrin.x86 import VNNI_DOT_16x4_INTRIN
+
+ logger.info("Using schedule rule: LLVM VNNI")
+
+ return [
+ M.AutoInline(
+ into_producer=False,
+ into_consumer=True,
+ inline_const_tensor=True,
+ disallow_if_then_else=True,
+ require_injective=True,
+ require_ordered=True,
+ disallow_op=["tir.exp"],
+ ),
+ M.AddRFactor(max_jobs_per_core=16, max_innermost_factor=64),
+ M.MultiLevelTilingWithIntrin(
+ VNNI_DOT_16x4_INTRIN,
+ structure="SSRSRS",
+ tile_binds=None,
+ max_innermost_factor=64,
+ vector_load_lens=None,
+ reuse_read=None,
+ reuse_write=M.ReuseType(
+ req="may",
+ levels=[1, 2],
+ scope="global",
+ ),
+ ),
+ M.MultiLevelTiling(
+ structure="SSRSRS",
+ tile_binds=None,
+ max_innermost_factor=64,
+ vector_load_lens=None,
+ reuse_read=None,
+ reuse_write=M.ReuseType(
+ req="may",
+ levels=[1, 2],
+ scope="global",
+ ),
+ ),
+ M.ParallelizeVectorizeUnroll(
+ max_jobs_per_core=16,
+ max_vectorize_extent=64,
+ unroll_max_steps=[0, 16, 64, 512],
+ unroll_explicit=True,
+ ),
+ M.RandomComputeLocation(),
+ ]
+
+ @staticmethod
+ def postprocs() -> List[Postproc]:
+ from tvm.meta_schedule import postproc as M
+
+ return [
+ M.DisallowDynamicLoop(),
+ M.RewriteParallelVectorizeUnroll(),
+ M.RewriteReductionBlock(),
+ M.RewriteTensorize(vectorize_init_loop=True),
+ M.RewriteLayout(),
+ ]
+
+ @staticmethod
+ def mutator_probs() -> Dict[Mutator, float]:
+ return _DefaultLLVM.mutator_probs()
+
+
class _DefaultCUDA:
"""Default tuning configuration for CUDA."""
@@ -355,6 +455,8 @@ class _DefaultCUDATensorCore:
from tvm.meta_schedule import schedule_rule as M
from tvm.tir.tensor_intrin.cuda import get_wmma_intrin_group
+ logger.info("Using schedule rule: CUDA tensorcore")
+
return [
M.MultiLevelTilingTensorCore(
intrin_groups=[
diff --git a/python/tvm/relay/op/contrib/dnnl.py b/python/tvm/relay/op/contrib/dnnl.py
index f7752e41b0..67909b04b8 100644
--- a/python/tvm/relay/op/contrib/dnnl.py
+++ b/python/tvm/relay/op/contrib/dnnl.py
@@ -36,22 +36,18 @@ import logging
from functools import reduce
import tvm.ir
-from tvm.ir import Op
from tvm import relay
+from tvm.ir import Op
+from tvm.relay import expr as _expr
from tvm.relay import transform
-from tvm.relay.expr import GlobalVar
-from tvm.relay.expr_functor import ExprMutator, ExprVisitor
-from tvm.relay.expr import const
-
from tvm.relay.analysis import analysis as _analysis
-from tvm.relay import expr as _expr
+from tvm.relay.expr import Call, GlobalVar, TupleGetItem, const
+from tvm.relay.expr_functor import ExprMutator, ExprVisitor
-from tvm.relay.expr import Call, TupleGetItem
from ... import _ffi_api
-from ...dataflow_pattern import wildcard, is_op, is_constant, is_expr, rewrite, DFPatternCallback
+from ...dataflow_pattern import DFPatternCallback, is_constant, is_expr, is_op, rewrite, wildcard
from .register import register_pattern_table
-
logger = logging.getLogger("DNNL")
supported_post_elts = ["nn.relu", "tanh", "sigmoid", "clip", "gelu", "swish", "mish", None]
@@ -809,7 +805,7 @@ def prune_dnnl_subgraphs(mod):
return new_mod
-class LayerNormRewrite(DFPatternCallback):
+class LayerNormRewritePattern1(DFPatternCallback):
"""
A callback to rewrite the following operators into a single layer normalization operator.
@@ -826,7 +822,42 @@ class LayerNormRewrite(DFPatternCallback):
/* ty=Tensor[(1, 3136, 64), float32] */;
10 %13 = add(%12, meta[relay.Constant][3] /* ty=Tensor[(64), float32] */)
/* ty=Tensor[(1, 3136, 64), float32] */;
+ """
+
+ def __init__(self):
+ super(LayerNormRewritePattern1, self).__init__()
+ self.data = wildcard()
+ self.gamma = wildcard()
+ self.beta = wildcard()
+ mu = is_op("mean")(self.data)
+ diff = is_op("subtract")(self.data, mu)
+ cdiff = is_op("cast")(diff) | diff # cast does not need to be here usually
+ const_two = (
+ is_expr(relay.const(2))
+ | is_expr(relay.const(2.0))
+ | is_expr(relay.const(2.0, dtype="float16"))
+ )
+ p1 = is_op("power")(cdiff, const_two)
+ mp1 = is_op("mean")(p1)
+ eps = is_constant() # TODO: check epsilon is something reasonable
+ added_eps = is_op("add")(mp1, eps)
+ deno = is_op("sqrt")(added_eps)
+ div_out = is_op("divide")(diff, deno)
+ div_out2 = diff * is_op("rsqrt")(added_eps)
+ weighted = is_op("multiply")(div_out | div_out2, self.gamma)
+ added_bias = is_op("add")(weighted, self.beta)
+ self.pattern = added_bias
+ def callback(self, pre, post, node_map):
+ data = node_map[self.data][0]
+ gamma = node_map[self.gamma][0]
+ beta = node_map[self.beta][0]
+ return relay.op.nn.layer_norm(data=data, gamma=gamma, beta=beta)
+
+
+class LayerNormRewritePattern2(DFPatternCallback):
+ """
+ A callback to rewrite the following operators into a single layer normalization operator.
Pattern #2:
1 %0 = mean(%input, axis=[-1], keepdims=True);
2 %1 = variance(%input, %0, axis=[-1], keepdims=True);
@@ -842,19 +873,16 @@ class LayerNormRewrite(DFPatternCallback):
"""
def __init__(self):
- super(LayerNormRewrite, self).__init__()
+ super(LayerNormRewritePattern2, self).__init__()
self.data = wildcard()
self.gamma = wildcard()
self.beta = wildcard()
mu = is_op("mean")(self.data)
- diff = is_op("subtract")(self.data, mu)
- cdiff = diff | is_op("cast")(diff)
- const_two = is_expr(relay.const(2)) | is_expr(relay.const(2.0))
- p1 = is_op("power")(cdiff, const_two)
- mp1 = is_op("mean")(p1) | is_op("variance")(self.data, mu)
+ mp1 = is_op("variance")(self.data, mu)
eps = is_expr(relay.const(1e-5)) | is_expr(relay.const(1e-6))
added_eps = is_op("add")(mp1, eps)
deno = is_op("sqrt")(added_eps)
+ diff = is_op("subtract")(self.data, mu)
div_out = is_op("divide")(diff, deno)
div_out2 = diff * is_op("rsqrt")(added_eps)
weighted = is_op("multiply")(div_out | div_out2, self.gamma)
@@ -872,7 +900,9 @@ def rewrite_layer_norm(mod):
"""Rewrite the input graph to replace multiple operators with a TVM native layer normalization
operator so that we can offload them to dnnl layer normalization byoc part.
"""
- mod["main"] = rewrite(LayerNormRewrite(), mod["main"])
+ mod["main"] = rewrite(LayerNormRewritePattern1(), mod["main"])
+ mod["main"] = rewrite(LayerNormRewritePattern2(), mod["main"])
+
return mod
diff --git a/python/tvm/relay/qnn/op/qnn.py b/python/tvm/relay/qnn/op/qnn.py
index 1f38385107..6d1cabeb8d 100644
--- a/python/tvm/relay/qnn/op/qnn.py
+++ b/python/tvm/relay/qnn/op/qnn.py
@@ -788,6 +788,74 @@ def mul(
)
+def div(
+ lhs,
+ rhs,
+ lhs_scale,
+ lhs_zero_point,
+ rhs_scale,
+ rhs_zero_point,
+ output_scale,
+ output_zero_point,
+ lhs_axis=-1,
+ rhs_axis=-1,
+):
+ """Quantized division with numpy-style broadcasting.
+
+ Parameters
+ ----------
+ lhs : relay.Expr
+ The left hand side quantized input data.
+
+ rhs : relay.Expr
+ The right hand side quantized input data.
+
+ lhs_scale: relay.Expr
+ The scale of the lhs quantized expr.
+
+ lhs_zero_point: relay.Expr
+ The zero point of lhs quantized expr.
+
+ rhs_scale: relay.Expr
+ The scale of the rhs quantized expr.
+
+ rhs_zero_point: relay.Expr
+ The zero point of rhs quantized expr.
+
+ output_scale: relay.Expr
+ The scale of the output quantized expr.
+
+ output_zero_point: relay.Expr
+ The zero point of output quantized expr.
+
+ lhs_axis: int
+ The channel axis for lhs quantization. Default value is -1 which corresponds
+ to the last axis.
+
+ rhs_axis: int
+ The channel axis for rhs quantization. Default value is -1 which corresponds
+ to the last axis.
+
+ Returns
+ -------
+ result : relay.Expr
+ The computed result.
+
+ """
+ return _make.div(
+ lhs,
+ rhs,
+ lhs_scale,
+ lhs_zero_point,
+ rhs_scale,
+ rhs_zero_point,
+ output_scale,
+ output_zero_point,
+ lhs_axis,
+ rhs_axis,
+ )
+
+
def tanh(x, scale, zero_point, output_scale, output_zero_point):
"""Quantized tanh.
diff --git a/python/tvm/relay/qnn/transform.py b/python/tvm/relay/qnn/transform.py
index 0485cecb99..7b42942c8b 100644
--- a/python/tvm/relay/qnn/transform.py
+++ b/python/tvm/relay/qnn/transform.py
@@ -114,3 +114,81 @@ def Legalize():
"""
return relay.transform.Legalize("FTVMQnnLegalize")
+
+
+from tvm.relay.dataflow_pattern import (
+ DFPatternCallback,
+ is_constant,
+ is_expr,
+ is_op,
+ rewrite,
+ wildcard,
+)
+
+
+class RSqrtPattern(DFPatternCallback):
+ """
+ Rewrites qnn.div(numerator, qnn.sqrt(x)) into qnn.mul(numerator, qnn.rsqrt(x)).
+ """
+
+ def __init__(self):
+ super(RSqrtPattern, self).__init__()
+
+ self.sqrt_data = wildcard()
+ self.sqrt_data_input_scale = wildcard()
+ self.sqrt_data_input_zp = wildcard()
+
+ self.numerator = wildcard()
+ self.numerator_scale = wildcard()
+ self.numerator_zp = wildcard()
+
+ self.output_scale = wildcard()
+ self.output_zp = wildcard()
+
+ self.sqrt = is_op("qnn.sqrt")(
+ self.sqrt_data,
+ self.sqrt_data_input_scale,
+ self.sqrt_data_input_zp,
+ wildcard(),
+ wildcard(),
+ )
+
+ # TODO: match axis properly
+ self.rsqrt = is_op("qnn.div")(
+ self.numerator,
+ self.sqrt,
+ self.numerator_scale,
+ self.numerator_zp,
+ wildcard(),
+ wildcard(),
+ self.output_scale,
+ self.output_zp,
+ )
+
+ self.pattern = self.rsqrt
+
+ def callback(self, pre, post, node_map):
+ sqrt_data = node_map[self.sqrt_data][0]
+ sqrt_data_scale = node_map[self.sqrt_data_input_scale][0]
+ sqrt_data_zp = node_map[self.sqrt_data_input_zp][0]
+
+ numerator = node_map[self.numerator][0]
+ numerator_scale = node_map[self.numerator_scale][0]
+ numerator_zp = node_map[self.numerator_zp][0]
+
+ output_scale = node_map[self.output_scale][0]
+ output_zp = node_map[self.output_zp][0]
+
+ rsqrt = relay.qnn.op.rsqrt(
+ sqrt_data, sqrt_data_scale, sqrt_data_zp, numerator_scale, numerator_zp
+ )
+ return relay.qnn.op.mul(
+ numerator,
+ rsqrt,
+ numerator_scale,
+ numerator_zp,
+ numerator_scale,
+ numerator_zp,
+ output_scale,
+ output_zp,
+ )
diff --git a/python/tvm/relay/transform/fake_quantization_to_integer.py b/python/tvm/relay/transform/fake_quantization_to_integer.py
index 242740399f..3dd2474170 100644
--- a/python/tvm/relay/transform/fake_quantization_to_integer.py
+++ b/python/tvm/relay/transform/fake_quantization_to_integer.py
@@ -19,6 +19,7 @@ import numpy as np
import tvm
from tvm import relay
from tvm.ir import TensorAffineType, TupleAffineType
+from tvm.relay.op.tensor import ones_like
# import to register canonicalization funcs for fq2i
# pylint: disable=unused-import
@@ -198,6 +199,60 @@ def broadcast_to(expr, type_map):
return [out, t]
+@register_fake_quantization_to_integer("take")
+def take(expr, type_map):
+ """Rewrite a take op"""
+ arg1 = expr.args[0]
+ t = type_map[arg1]
+ arg2 = expr.args[1]
+ out = relay.op.take(
+ arg1,
+ arg2,
+ axis=expr.attrs.axis,
+ batch_dims=expr.attrs.batch_dims,
+ mode=expr.attrs.mode,
+ )
+ return [out, t]
+
+
+@register_fake_quantization_to_integer("power")
+def power(expr, type_map):
+ base = expr.args[0]
+ exponent = expr.args[1]
+
+ base_type = type_map[base]
+
+ if not isinstance(exponent, relay.Constant):
+ return [expr, type_map[expr]]
+
+ data = exponent.data.numpy()
+ if not len(data.shape) == 0:
+ return [expr, type_map[expr]]
+
+ data = data.item()
+ if data != 2:
+ return [expr, type_map[expr]]
+
+ out = relay.qnn.op.mul(
+ base,
+ base,
+ base_type.scale,
+ base_type.zero_point,
+ base_type.scale,
+ base_type.zero_point,
+ output_scale=base_type.scale * base_type.scale,
+ output_zero_point=base_type.zero_point,
+ lhs_axis=base_type.axis,
+ rhs_axis=base_type.axis,
+ )
+ return [
+ out,
+ TensorAffineType(
+ base_type.scale * base_type.scale, base_type.zero_point, base_type.dtype, base_type.axis
+ ),
+ ]
+
+
@register_fake_quantization_to_integer("nn.bias_add")
def bias_add(expr, type_map):
"""Rewrite a bias_add op"""
@@ -520,6 +575,37 @@ def register_binary_qnn(op_name, op):
register_binary_qnn("add", lambda *args: relay.qnn.op.add(*args))
register_binary_qnn("multiply", lambda *args: relay.qnn.op.mul(*args))
register_binary_qnn("subtract", lambda *args: relay.qnn.op.subtract(*args))
+# register_binary_qnn("divide", lambda *args: relay.qnn.op.div(*args))
+
+
+'''
+@register_fake_quantization_to_integer("divide")
+def divide(expr, type_map):
+ """Rewrite an adaptive avgpool op"""
+ numerator = expr.args[0]
+ denominator = expr.args[1]
+ numerator_t = type_map[numerator]
+ denominator_t = type_map[denominator]
+ new_scale = numerator_t.scale / (denominator_t.scale * (denominator - denominator_t.zero_point))
+ out = relay.divide(numerator, ones_like(denominator))
+ assert numerator_t.axis == denominator_t.axis, "Only support identical axis for now."
+ # print(out)
+
+ print("new out:")
+ str_new_out = str(relay.transform.InferType()(tvm.IRModule.from_expr(out)))
+ print("\n".join(str_new_out.split("\n")[-10:]))
+ print("old_out:")
+ str_old_out = str(relay.transform.InferType()(tvm.IRModule.from_expr(expr)))
+ print("\n".join(str_old_out.split("\n")[-10:]))
+ print()
+ breakpoint()
+ # print("yay!")
+ # This is to get broadcasting working to get same shape
+ return [
+ out,
+ TensorAffineType(new_scale, numerator_t.zero_point, numerator_t.dtype, numerator_t.axis),
+ ]
+'''
def register_binary_identity(op_name, op):
diff --git a/src/relay/qnn/op/div.cc b/src/relay/qnn/op/div.cc
new file mode 100644
index 0000000000..3c37ed41c4
--- /dev/null
+++ b/src/relay/qnn/op/div.cc
@@ -0,0 +1,117 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relay/qnn/op/div.cc
+ * \brief QNN div operator.
+ */
+#include <tvm/relay/analysis.h>
+#include <tvm/relay/op_attr_types.h>
+#include <tvm/relay/qnn/attrs.h>
+
+#include "../../transforms/pattern_utils.h"
+#include "../utils.h"
+#include "op_common.h"
+
+namespace tvm {
+namespace relay {
+namespace qnn {
+
+/*
+ * \brief Canonicalizes the QNN div op.
+ * \param attrs The QNN div attrs.
+ * \param new_args The new mutated args to the call node.
+ * \param arg_types The types of input and output.
+ * \return The sequence of Relay ops for div op.
+ */
+Expr QnnDivCanonicalize(const Attrs& attrs, const Array<Expr>& new_args,
+ const Array<tvm::relay::Type>& arg_types) {
+ Expr output;
+
+ // Get the arguments.
+ QnnBinaryOpArguments args(new_args);
+
+ // Get the input dtype and shape.
+ QnnBinaryOpTensorType input_type(arg_types, 0);
+
+ // data types
+ const auto int32_dtype = DataType::Int(32);
+ const auto float32_dtype = DataType::Float(32);
+
+ const auto* broadcast_attrs = attrs.as<BroadcastAttrs>();
+ ICHECK(broadcast_attrs != nullptr);
+
+ if (IsConstScalar(args.lhs_scale) && IsConstScalar(args.rhs_scale)) {
+ /* If both scales are constant scalars:
+
+ n1/n2 = [s1(q1-z1)] / [s2(q2-z2)]
+ n1/n2 = [s1/s2][(q1-z1)/(q2-z2)]
+
+ Because [(q1-z1)/(q2-z2)] is an integer division, we may lose significant precision.
+ To get around this, we scale the numerator by a constant C chosen so that
+
+ |C(q1-z1)| >> |q2-z2|, which keeps the precision loss from the division minimal:
+
+ n1/n2 = [s1/(s2 * C)][C(q1-z1)/(q2-z2)]
+ */
+
+ auto lhs_shifted = Cast(args.lhs, int32_dtype);
+ auto rhs_shifted = Cast(args.rhs, int32_dtype);
+
+ auto zero_scalar = MakeConstantScalar(int32_dtype, 0);
+ if (!IsEqualScalar(args.lhs_zero_point, zero_scalar)) {
+ lhs_shifted = Subtract(lhs_shifted, args.lhs_zero_point);
+ }
+
+ if (!IsEqualScalar(args.rhs_zero_point, zero_scalar)) {
+ rhs_shifted = Subtract(rhs_shifted, args.rhs_zero_point);
+ }
+
+ // Scale the numerator by 2^15 before dividing to limit the precision lost to integer
+ // division; the arithmetic is done in INT32 and inputs may be as wide as UINT16.
+ int divide_scale_factor = 32768;
+ auto divide_scale_factor_constant = MakeConstantScalar(int32_dtype, divide_scale_factor);
+ output = Divide(Multiply(lhs_shifted, divide_scale_factor_constant), rhs_shifted);
+
+ // Get the adjusted new scale and zero points.
+ float lhs_scale_float = GetScalarFromConstant<float>(args.lhs_scale);
+ float rhs_scale_float = GetScalarFromConstant<float>(args.rhs_scale);
+ float new_scale_float = lhs_scale_float / (rhs_scale_float * divide_scale_factor);
+ auto new_input_scale = MakeConstantScalar(float32_dtype, new_scale_float);
+ auto new_input_zero_point = zero_scalar;
+
+ // Requantize to get Q_c
+ output = Requantize(output, input_type.shape, new_input_scale, new_input_zero_point,
+ args.output_scale, args.output_zero_point, input_type.dtype);
+ } else {
+ LOG(FATAL) << "Non-constant scale_factor not supported yet.";
+ }
+
+ return output;
+}
+
+// QNN Division operator.
+QNN_REGISTER_BINARY_OP("div")
+ .describe("Elementwise div with broadcasting for quantized tensors.")
+ .set_support_level(11)
+ .set_attr<FTVMLegalize>("FTVMQnnCanonicalize", QnnDivCanonicalize);
+
+} // namespace qnn
+} // namespace relay
+} // namespace tvm