Posted to commits@tvm.apache.org by an...@apache.org on 2022/09/19 23:38:27 UTC

[tvm] 01/01: changes change

This is an automated email from the ASF dual-hosted git repository.

andrewzhaoluo pushed a commit to branch aluo/autotensorization-redux-09192022
in repository https://gitbox.apache.org/repos/asf/tvm.git

commit 72d4ca5218e1a1ee0378bc6681eff756598e3598
Author: Andrew Zhao Luo <an...@gmail.com>
AuthorDate: Mon Sep 19 12:58:43 2022 -0700

    changes change
---
 python/tvm/meta_schedule/default_config.py         | 108 ++++++++++++++++++-
 python/tvm/relay/op/contrib/dnnl.py                |  64 ++++++++---
 python/tvm/relay/qnn/op/qnn.py                     |  68 ++++++++++++
 python/tvm/relay/qnn/transform.py                  |  78 ++++++++++++++
 .../transform/fake_quantization_to_integer.py      |  86 +++++++++++++++
 src/relay/qnn/op/div.cc                            | 117 +++++++++++++++++++++
 6 files changed, 501 insertions(+), 20 deletions(-)

diff --git a/python/tvm/meta_schedule/default_config.py b/python/tvm/meta_schedule/default_config.py
index ac4028ec50..eaa026e3b4 100644
--- a/python/tvm/meta_schedule/default_config.py
+++ b/python/tvm/meta_schedule/default_config.py
@@ -20,6 +20,8 @@ import logging
 from os import path as osp
 from typing import Callable, Dict, List, Optional, Union
 
+from tvm._ffi.registry import register_func
+from tvm.contrib import nvcc
 from tvm.ir import IRModule
 from tvm.target import Target
 from tvm.tir import PrimFunc
@@ -43,6 +45,20 @@ FnPostproc = Callable[[], List[Postproc]]
 FnMutatorProb = Callable[[], Dict[Mutator, float]]
 
 
+def target_has_vnni(mcpu):
+    """Return True if the given `-mcpu` string supports the VNNI instructions."""
+    return mcpu in {
+        "cascadelake",
+        "icelake-client",
+        "icelake-server",
+        "rocketlake",
+        "tigerlake",
+        "cooperlake",
+        "sapphirerapids",
+        "alderlake",
+    }
+
+
+@register_func("tvm.meta_schedule.tune.parse_mod")  # for use in ApplyHistoryBest
 def mod(mod: Union[PrimFunc, IRModule]) -> IRModule:  # pylint: disable=redefined-outer-name
     """Normalize the input to an IRModule"""
     if isinstance(mod, PrimFunc):
@@ -174,9 +190,13 @@ def schedule_rules(  # pylint: disable=redefined-outer-name
         return sch_rules()
     if sch_rules is not None:
         raise TypeError(f"Expected `sch_rules` to be None or callable, but gets: {sch_rules}")
-    if target.kind.name in ["llvm", "hexagon"]:
+    if target.kind.name == "llvm":
+        if target_has_vnni(target.mcpu):
+            return _DefaultLLVMVNNI.schedule_rules()
         return _DefaultLLVM.schedule_rules()
     if target.kind.name in ["cuda", "rocm", "vulkan"]:
+        if target.kind.name == "cuda" and nvcc.have_tensorcore(target=target):
+            return _DefaultCUDATensorCore.schedule_rules()
         return _DefaultCUDA.schedule_rules()
     raise ValueError(f"Unsupported target: {target}")
 
@@ -190,9 +210,13 @@ def postproc(  # pylint: disable=redefined-outer-name
         return postproc()
     if postproc is not None:
         raise TypeError(f"Expected `postproc` to be None or callable, but gets: {postproc}")
-    if target.kind.name in ["llvm", "hexagon"]:
+    if target.kind.name == "llvm":
+        if target_has_vnni(target.mcpu):
+            return _DefaultLLVMVNNI.postprocs()
         return _DefaultLLVM.postprocs()
     if target.kind.name in ["cuda", "rocm", "vulkan"]:
+        if target.kind.name == "cuda" and nvcc.have_tensorcore(target=target):
+            return _DefaultCUDATensorCore.postprocs()
         return _DefaultCUDA.postprocs()
     raise ValueError(f"Unsupported target: {target}")
 
@@ -208,9 +232,13 @@ def mutator_probs(  # pylint: disable=redefined-outer-name
         raise TypeError(
             f"Expected `mutator_probs` to be None or callable, but gets: {mutator_probs}"
         )
-    if target.kind.name in ["llvm", "hexagon"]:
+    if target.kind.name == "llvm":
+        if target_has_vnni(target.mcpu):
+            return _DefaultLLVMVNNI.mutator_probs()
         return _DefaultLLVM.mutator_probs()
     if target.kind.name in ["cuda", "rocm", "vulkan"]:
+        if target.kind.name == "cuda" and nvcc.have_tensorcore(target=target):
+            return _DefaultCUDATensorCore.mutator_probs()
         return _DefaultCUDA.mutator_probs()
     raise ValueError(f"Unsupported target: {target}")
 
@@ -277,6 +305,78 @@ class _DefaultLLVM:
         }
 
 
+class _DefaultLLVMVNNI:
+    """Default tuning configuration for LLVM with VNNI."""
+
+    @staticmethod
+    def schedule_rules() -> List[ScheduleRule]:
+        from tvm.meta_schedule import schedule_rule as M
+        from tvm.tir.tensor_intrin.x86 import VNNI_DOT_16x4_INTRIN
+
+        logger.info("Using schedule rule: LLVM VNNI")
+
+        return [
+            M.AutoInline(
+                into_producer=False,
+                into_consumer=True,
+                inline_const_tensor=True,
+                disallow_if_then_else=True,
+                require_injective=True,
+                require_ordered=True,
+                disallow_op=["tir.exp"],
+            ),
+            M.AddRFactor(max_jobs_per_core=16, max_innermost_factor=64),
+            M.MultiLevelTilingWithIntrin(
+                VNNI_DOT_16x4_INTRIN,
+                structure="SSRSRS",
+                tile_binds=None,
+                max_innermost_factor=64,
+                vector_load_lens=None,
+                reuse_read=None,
+                reuse_write=M.ReuseType(
+                    req="may",
+                    levels=[1, 2],
+                    scope="global",
+                ),
+            ),
+            M.MultiLevelTiling(
+                structure="SSRSRS",
+                tile_binds=None,
+                max_innermost_factor=64,
+                vector_load_lens=None,
+                reuse_read=None,
+                reuse_write=M.ReuseType(
+                    req="may",
+                    levels=[1, 2],
+                    scope="global",
+                ),
+            ),
+            M.ParallelizeVectorizeUnroll(
+                max_jobs_per_core=16,
+                max_vectorize_extent=64,
+                unroll_max_steps=[0, 16, 64, 512],
+                unroll_explicit=True,
+            ),
+            M.RandomComputeLocation(),
+        ]
+
+    @staticmethod
+    def postprocs() -> List[Postproc]:
+        from tvm.meta_schedule import postproc as M
+
+        return [
+            M.DisallowDynamicLoop(),
+            M.RewriteParallelVectorizeUnroll(),
+            M.RewriteReductionBlock(),
+            M.RewriteTensorize(vectorize_init_loop=True),
+            M.RewriteLayout(),
+        ]
+
+    @staticmethod
+    def mutator_probs() -> Dict[Mutator, float]:
+        return _DefaultLLVM.mutator_probs()
+
+
 class _DefaultCUDA:
     """Default tuning configuration for CUDA."""
 
@@ -355,6 +455,8 @@ class _DefaultCUDATensorCore:
         from tvm.meta_schedule import schedule_rule as M
         from tvm.tir.tensor_intrin.cuda import get_wmma_intrin_group
 
+        logger.info("Using schedule rule: CUDA tensorcore")
+
         return [
             M.MultiLevelTilingTensorCore(
                 intrin_groups=[
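
The dispatch above keys off `target.mcpu`, which is populated from the `-mcpu`
flag of the target string. A minimal sketch of how the new VNNI defaults get
selected (the target string below is illustrative):

    # Sketch, not part of the patch: a cascadelake target routes to
    # _DefaultLLVMVNNI via target_has_vnni(target.mcpu).
    from tvm.target import Target

    target = Target("llvm -mcpu=cascadelake")
    assert target.kind.name == "llvm"
    # target.mcpu == "cascadelake", so schedule_rules(), postproc() and
    # mutator_probs() all return the VNNI defaults for this target.
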
diff --git a/python/tvm/relay/op/contrib/dnnl.py b/python/tvm/relay/op/contrib/dnnl.py
index f7752e41b0..67909b04b8 100644
--- a/python/tvm/relay/op/contrib/dnnl.py
+++ b/python/tvm/relay/op/contrib/dnnl.py
@@ -36,22 +36,18 @@ import logging
 from functools import reduce
 
 import tvm.ir
-from tvm.ir import Op
 from tvm import relay
+from tvm.ir import Op
+from tvm.relay import expr as _expr
 from tvm.relay import transform
-from tvm.relay.expr import GlobalVar
-from tvm.relay.expr_functor import ExprMutator, ExprVisitor
-from tvm.relay.expr import const
-
 from tvm.relay.analysis import analysis as _analysis
-from tvm.relay import expr as _expr
+from tvm.relay.expr import Call, GlobalVar, TupleGetItem, const
+from tvm.relay.expr_functor import ExprMutator, ExprVisitor
 
-from tvm.relay.expr import Call, TupleGetItem
 from ... import _ffi_api
-from ...dataflow_pattern import wildcard, is_op, is_constant, is_expr, rewrite, DFPatternCallback
+from ...dataflow_pattern import DFPatternCallback, is_constant, is_expr, is_op, rewrite, wildcard
 from .register import register_pattern_table
 
-
 logger = logging.getLogger("DNNL")
 supported_post_elts = ["nn.relu", "tanh", "sigmoid", "clip", "gelu", "swish", "mish", None]
 
@@ -809,7 +805,7 @@ def prune_dnnl_subgraphs(mod):
     return new_mod
 
 
-class LayerNormRewrite(DFPatternCallback):
+class LayerNormRewritePattern1(DFPatternCallback):
     """
     A callback to rewrite the following operators into a single layer normalization operator.
 
@@ -826,7 +822,42 @@ class LayerNormRewrite(DFPatternCallback):
             /* ty=Tensor[(1, 3136, 64), float32] */;
     10   %13 = add(%12, meta[relay.Constant][3] /* ty=Tensor[(64), float32] */)
             /* ty=Tensor[(1, 3136, 64), float32] */;
+    """
+
+    def __init__(self):
+        super(LayerNormRewritePattern1, self).__init__()
+        self.data = wildcard()
+        self.gamma = wildcard()
+        self.beta = wildcard()
+        mu = is_op("mean")(self.data)
+        diff = is_op("subtract")(self.data, mu)
+        cdiff = is_op("cast")(diff) | diff  # an optional cast may appear here
+        const_two = (
+            is_expr(relay.const(2))
+            | is_expr(relay.const(2.0))
+            | is_expr(relay.const(2.0, dtype="float16"))
+        )
+        p1 = is_op("power")(cdiff, const_two)
+        mp1 = is_op("mean")(p1)
+        eps = is_constant()  # TODO: check that epsilon is reasonable
+        added_eps = is_op("add")(mp1, eps)
+        deno = is_op("sqrt")(added_eps)
+        div_out = is_op("divide")(diff, deno)
+        div_out2 = diff * is_op("rsqrt")(added_eps)
+        weighted = is_op("multiply")(div_out | div_out2, self.gamma)
+        added_bias = is_op("add")(weighted, self.beta)
+        self.pattern = added_bias
 
+    def callback(self, pre, post, node_map):
+        data = node_map[self.data][0]
+        gamma = node_map[self.gamma][0]
+        beta = node_map[self.beta][0]
+        return relay.op.nn.layer_norm(data=data, gamma=gamma, beta=beta)
+
+
+class LayerNormRewritePattern2(DFPatternCallback):
+    """
+    A callback to rewrite the following operators into a single layer normalization operator.
     Pattern #2:
     1   %0 = mean(%input, axis=[-1], keepdims=True);
     2   %1 = variance(%input, %0, axis=[-1], keepdims=True);
@@ -842,19 +873,16 @@ class LayerNormRewrite(DFPatternCallback):
     """
 
     def __init__(self):
-        super(LayerNormRewrite, self).__init__()
+        super(LayerNormRewritePattern2, self).__init__()
         self.data = wildcard()
         self.gamma = wildcard()
         self.beta = wildcard()
         mu = is_op("mean")(self.data)
-        diff = is_op("subtract")(self.data, mu)
-        cdiff = diff | is_op("cast")(diff)
-        const_two = is_expr(relay.const(2)) | is_expr(relay.const(2.0))
-        p1 = is_op("power")(cdiff, const_two)
-        mp1 = is_op("mean")(p1) | is_op("variance")(self.data, mu)
+        mp1 = is_op("variance")(self.data, mu)
         eps = is_expr(relay.const(1e-5)) | is_expr(relay.const(1e-6))
         added_eps = is_op("add")(mp1, eps)
         deno = is_op("sqrt")(added_eps)
+        diff = is_op("subtract")(self.data, mu)
         div_out = is_op("divide")(diff, deno)
         div_out2 = diff * is_op("rsqrt")(added_eps)
         weighted = is_op("multiply")(div_out | div_out2, self.gamma)
@@ -872,7 +900,9 @@ def rewrite_layer_norm(mod):
     """Rewrite the input graph to replace multiple operators with a TVM native layer normalization
     operator so that we can offload them to dnnl layer normalization byoc part.
     """
-    mod["main"] = rewrite(LayerNormRewrite(), mod["main"])
+    mod["main"] = rewrite(LayerNormRewritePattern1(), mod["main"])
+    mod["main"] = rewrite(LayerNormRewritePattern2(), mod["main"])
+
     return mod
 
 
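
Both callbacks lower their matched subgraph to the same native op. A hedged
sketch of the op they produce (shapes taken from the docstring example above):

    # Sketch, not part of the patch: the single op both rewrites emit.
    from tvm import relay

    x = relay.var("x", shape=(1, 3136, 64), dtype="float32")
    gamma = relay.var("gamma", shape=(64,), dtype="float32")
    beta = relay.var("beta", shape=(64,), dtype="float32")
    y = relay.nn.layer_norm(x, gamma, beta)  # replaces the mean/sqrt/divide chain
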
diff --git a/python/tvm/relay/qnn/op/qnn.py b/python/tvm/relay/qnn/op/qnn.py
index 1f38385107..6d1cabeb8d 100644
--- a/python/tvm/relay/qnn/op/qnn.py
+++ b/python/tvm/relay/qnn/op/qnn.py
@@ -788,6 +788,74 @@ def mul(
     )
 
 
+def div(
+    lhs,
+    rhs,
+    lhs_scale,
+    lhs_zero_point,
+    rhs_scale,
+    rhs_zero_point,
+    output_scale,
+    output_zero_point,
+    lhs_axis=-1,
+    rhs_axis=-1,
+):
+    """Quantized division with numpy-style broadcasting.
+
+    Parameters
+    ----------
+    lhs : relay.Expr
+        The left hand side quantized input data.
+
+    rhs : relay.Expr
+        The right hand side quantized input data.
+
+    lhs_scale: relay.Expr
+        The scale of the lhs quantized expr.
+
+    lhs_zero_point: relay.Expr
+       The zero point of lhs quantized expr.
+
+    rhs_scale: relay.Expr
+        The scale of the rhs quantized expr.
+
+    rhs_zero_point: relay.Expr
+       The zero point of rhs quantized expr.
+
+    output_scale: relay.Expr
+        The scale of the output quantized expr.
+
+    output_zero_point: relay.Expr
+       The zero point of output quantized expr.
+
+    lhs_axis: int
+        The channel axis for lhs quantization. Default value is -1 which corresponds
+        to the last axis.
+
+    rhs_axis: int
+        The channel axis for rhs quantization. Default value is -1 which corresponds
+        to the last axis.
+
+    Returns
+    -------
+    result : relay.Expr
+        The computed result.
+
+    """
+    return _make.div(
+        lhs,
+        rhs,
+        lhs_scale,
+        lhs_zero_point,
+        rhs_scale,
+        rhs_zero_point,
+        output_scale,
+        output_zero_point,
+        lhs_axis,
+        rhs_axis,
+    )
+
+
 def tanh(x, scale, zero_point, output_scale, output_zero_point):
     """Quantized tanh.
 
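
For reference, a minimal call of the new op with per-tensor quantization
parameters (all constant values here are made up):

    # Sketch, not part of the patch: scalar scales and zero points.
    from tvm import relay

    lhs = relay.var("lhs", shape=(1, 64), dtype="int8")
    rhs = relay.var("rhs", shape=(1, 64), dtype="int8")
    out = relay.qnn.op.div(
        lhs, rhs,
        lhs_scale=relay.const(0.1), lhs_zero_point=relay.const(0),
        rhs_scale=relay.const(0.2), rhs_zero_point=relay.const(0),
        output_scale=relay.const(0.5), output_zero_point=relay.const(0),
    )
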
diff --git a/python/tvm/relay/qnn/transform.py b/python/tvm/relay/qnn/transform.py
index 0485cecb99..7b42942c8b 100644
--- a/python/tvm/relay/qnn/transform.py
+++ b/python/tvm/relay/qnn/transform.py
@@ -114,3 +114,81 @@ def Legalize():
     """
 
     return relay.transform.Legalize("FTVMQnnLegalize")
+
+
+from tvm.relay.dataflow_pattern import (
+    DFPatternCallback,
+    is_constant,
+    is_expr,
+    is_op,
+    rewrite,
+    wildcard,
+)
+
+
+class RSqrtPattern(DFPatternCallback):
+    """
+    Rewrites QNN RSQRT Pattern
+    """
+
+    def __init__(self):
+        super(RSqrtPattern, self).__init__()
+
+        self.sqrt_data = wildcard()
+        self.sqrt_data_input_scale = wildcard()
+        self.sqrt_data_input_zp = wildcard()
+
+        self.numerator = wildcard()
+        self.numerator_scale = wildcard()
+        self.numerator_zp = wildcard()
+
+        self.output_scale = wildcard()
+        self.output_zp = wildcard()
+
+        self.sqrt = is_op("qnn.sqrt")(
+            self.sqrt_data,
+            self.sqrt_data_input_scale,
+            self.sqrt_data_input_zp,
+            wildcard(),
+            wildcard(),
+        )
+
+        # TODO: match axis properly
+        self.rsqrt = is_op("qnn.div")(
+            self.numerator,
+            self.sqrt,
+            self.numerator_scale,
+            self.numerator_zp,
+            wildcard(),
+            wildcard(),
+            self.output_scale,
+            self.output_zp,
+        )
+
+        self.pattern = self.rsqrt
+
+    def callback(self, pre, post, node_map):
+        sqrt_data = node_map[self.sqrt_data][0]
+        sqrt_data_scale = node_map[self.sqrt_data_input_scale][0]
+        sqrt_data_zp = node_map[self.sqrt_data_input_zp][0]
+
+        numerator = node_map[self.numerator][0]
+        numerator_scale = node_map[self.numerator_scale][0]
+        numerator_zp = node_map[self.numerator_zp][0]
+
+        output_scale = node_map[self.output_scale][0]
+        output_zp = node_map[self.output_zp][0]
+
+        rsqrt = relay.qnn.op.rsqrt(
+            sqrt_data, sqrt_data_scale, sqrt_data_zp, numerator_scale, numerator_zp
+        )
+        return relay.qnn.op.mul(
+            numerator,
+            rsqrt,
+            numerator_scale,
+            numerator_zp,
+            numerator_scale,
+            numerator_zp,
+            output_scale,
+            output_zp,
+        )
diff --git a/python/tvm/relay/transform/fake_quantization_to_integer.py b/python/tvm/relay/transform/fake_quantization_to_integer.py
index 242740399f..3dd2474170 100644
--- a/python/tvm/relay/transform/fake_quantization_to_integer.py
+++ b/python/tvm/relay/transform/fake_quantization_to_integer.py
@@ -19,6 +19,7 @@ import numpy as np
 import tvm
 from tvm import relay
 from tvm.ir import TensorAffineType, TupleAffineType
+from tvm.relay.op.tensor import ones_like
 
 # import to register canonicalization funcs for fq2i
 # pylint: disable=unused-import
@@ -198,6 +199,60 @@ def broadcast_to(expr, type_map):
     return [out, t]
 
 
+@register_fake_quantization_to_integer("take")
+def take(expr, type_map):
+    """Rewrite a take op"""
+    arg1 = expr.args[0]
+    t = type_map[arg1]
+    arg2 = expr.args[1]
+    out = relay.op.take(
+        arg1,
+        arg2,
+        axis=expr.attrs.axis,
+        batch_dims=expr.attrs.batch_dims,
+        mode=expr.attrs.mode,
+    )
+    return [out, t]
+
+
+@register_fake_quantization_to_integer("power")
+def power(expr, type_map):
+    """Rewrite x ** 2 into qnn.mul(x, x) when the exponent is the constant scalar 2."""
+    base = expr.args[0]
+    exponent = expr.args[1]
+
+    base_type = type_map[base]
+
+    if not isinstance(exponent, relay.Constant):
+        return [expr, type_map[expr]]
+
+    data = exponent.data.numpy()
+    if not len(data.shape) == 0:
+        return [expr, type_map[expr]]
+
+    data = data.item()
+    if data != 2:
+        return [expr, type_map[expr]]
+
+    out = relay.qnn.op.mul(
+        base,
+        base,
+        base_type.scale,
+        base_type.zero_point,
+        base_type.scale,
+        base_type.zero_point,
+        output_scale=base_type.scale * base_type.scale,
+        output_zero_point=base_type.zero_point,
+        lhs_axis=base_type.axis,
+        rhs_axis=base_type.axis,
+    )
+    return [
+        out,
+        TensorAffineType(
+            base_type.scale * base_type.scale, base_type.zero_point, base_type.dtype, base_type.axis
+        ),
+    ]
+
+
 @register_fake_quantization_to_integer("nn.bias_add")
 def bias_add(expr, type_map):
     """Rewrite a bias_add op"""
@@ -520,6 +575,37 @@ def register_binary_qnn(op_name, op):
 register_binary_qnn("add", lambda *args: relay.qnn.op.add(*args))
 register_binary_qnn("multiply", lambda *args: relay.qnn.op.mul(*args))
 register_binary_qnn("subtract", lambda *args: relay.qnn.op.subtract(*args))
+# register_binary_qnn("divide", lambda *args: relay.qnn.op.div(*args))
+
+
+'''
+@register_fake_quantization_to_integer("divide")
+def divide(expr, type_map):
+    """Rewrite an adaptive avgpool op"""
+    numerator = expr.args[0]
+    denominator = expr.args[1]
+    numerator_t = type_map[numerator]
+    denominator_t = type_map[denominator]
+    new_scale = numerator_t.scale / (denominator_t.scale * (denominator - denominator_t.zero_point))
+    out = relay.divide(numerator, ones_like(denominator))
+    assert numerator_t.axis == denominator_t.axis, "Only support identical axis for now."
+    # Dividing by ones_like(denominator) keeps broadcasting, so the output
+    # shape matches the original divide.
+    return [
+        out,
+        TensorAffineType(new_scale, numerator_t.zero_point, numerator_t.dtype, numerator_t.axis),
+    ]
+'''
 
 
 def register_binary_identity(op_name, op):
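
The "power" rule above only fires for a constant scalar exponent of 2, in which
case x ** 2 is exactly qnn.mul(x, x) with the output scale squared and the zero
point preserved. A quick numeric check of that bookkeeping (values made up):

    # Sketch, not part of the patch: real value v = s * (q - z).
    s, z, q = 0.5, 0, 6             # represents v = 3.0
    v_squared = (s * (q - z)) ** 2  # 9.0
    q_out = (q - z) * (q - z)       # integer result of squaring, zero point 0
    assert abs(v_squared - (s * s) * q_out) < 1e-9
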
diff --git a/src/relay/qnn/op/div.cc b/src/relay/qnn/op/div.cc
new file mode 100644
index 0000000000..3c37ed41c4
--- /dev/null
+++ b/src/relay/qnn/op/div.cc
@@ -0,0 +1,117 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relay/qnn/op/div.cc
+ * \brief QNN div operator.
+ */
+#include <tvm/relay/analysis.h>
+#include <tvm/relay/op_attr_types.h>
+#include <tvm/relay/qnn/attrs.h>
+
+#include "../../transforms/pattern_utils.h"
+#include "../utils.h"
+#include "op_common.h"
+
+namespace tvm {
+namespace relay {
+namespace qnn {
+
+/*
+ * \brief Canonicalizes the QNN div op.
+ * \param attrs The QNN div attrs.
+ * \param new_args The new mutated args to the call node.
+ * \param arg_types The types of input and output.
+ * \return The sequence of Relay ops for div op.
+ */
+Expr QnnDivCanonicalize(const Attrs& attrs, const Array<Expr>& new_args,
+                        const Array<tvm::relay::Type>& arg_types) {
+  Expr output;
+
+  // Get the attrs.
+  QnnBinaryOpArguments args(new_args);
+
+  // Get the input dtype and shape.
+  QnnBinaryOpTensorType input_type(arg_types, 0);
+
+  // data types
+  const auto int32_dtype = DataType::Int(32);
+  const auto float32_dtype = DataType::Float(32);
+
+  const auto* broadcast_attrs = attrs.as<BroadcastAttrs>();
+  ICHECK(broadcast_attrs != nullptr);
+
+  if (IsConstScalar(args.lhs_scale) && IsConstScalar(args.rhs_scale)) {
+    /* If both are constant:
+
+    n1/n2 = [s1(q1-z1)] / [s2(q2-z2)]
+    n1/n2 = [s1/s2][(q1-z1)/(q2-z2)]
+
+    Since [(q1-z1)/(q2-z2)] is integer division, it can lose significant precision.
+    To get around this we scale the numerator by C, chosen so that
+
+    |C(q1-z1)| >> (q2 - z2), which keeps the precision loss from the division minimal:
+
+    n1/n2 = [s1/(s2 * C)][C(q1-z1)/(q2-z2)]
+    */
+
+    auto lhs_shifted = Cast(args.lhs, int32_dtype);
+    auto rhs_shifted = Cast(args.rhs, int32_dtype);
+
+    auto zero_scalar = MakeConstantScalar(int32_dtype, 0);
+    if (!IsEqualScalar(args.lhs_zero_point, zero_scalar)) {
+      lhs_shifted = Subtract(lhs_shifted, args.lhs_zero_point);
+    }
+
+    if (!IsEqualScalar(args.rhs_zero_point, zero_scalar)) {
+      rhs_shifted = Subtract(rhs_shifted, args.rhs_zero_point);
+    }
+
+    // Scale the numerator to limit precision loss from the integer division.
+    // We accumulate in int32 and the inputs may be uint16, so use C = 2^15 = 32768.
+    int divide_scale_factor = 32768;
+    auto divide_scale_factor_constant = MakeConstantScalar(int32_dtype, divide_scale_factor);
+    output = Divide(Multiply(lhs_shifted, divide_scale_factor_constant), rhs_shifted);
+
+    // Get the adjusted new scale and zero points.
+    float lhs_scale_float = GetScalarFromConstant<float>(args.lhs_scale);
+    float rhs_scale_float = GetScalarFromConstant<float>(args.rhs_scale);
+    float new_scale_float = lhs_scale_float / (rhs_scale_float * divide_scale_factor);
+    auto new_input_scale = MakeConstantScalar(float32_dtype, new_scale_float);
+    auto new_input_zero_point = zero_scalar;
+
+    // Requantize to get Q_c
+    output = Requantize(output, input_type.shape, new_input_scale, new_input_zero_point,
+                        args.output_scale, args.output_zero_point, input_type.dtype);
+  } else {
+    LOG(FATAL) << "Non-constant scale_factor not supported yet.";
+  }
+
+  return output;
+}
+
+// QNN Division operator.
+QNN_REGISTER_BINARY_OP("div")
+    .describe("Elementwise div with broadcasting for quantized tensors.")
+    .set_support_level(11)
+    .set_attr<FTVMLegalize>("FTVMQnnCanonicalize", QnnDivCanonicalize);
+
+}  // namespace qnn
+}  // namespace relay
+}  // namespace tvm
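
A quick numeric check of the canonicalization above (inputs made up): with
C = 32768 the scaled integer division reproduces the real quotient to within
rounding, and the requantize step then maps it onto the requested output scale.

    # Sketch, not part of the patch: the scaled-division scheme in Python.
    s1, z1, q1 = 0.1, 0, 100   # n1 = s1 * (q1 - z1) = 10.0
    s2, z2, q2 = 0.2, 0, 25    # n2 = s2 * (q2 - z2) = 5.0
    C = 32768
    q = (C * (q1 - z1)) // (q2 - z2)  # integer division on the scaled numerator
    new_scale = s1 / (s2 * C)         # scale carried into Requantize
    assert abs(new_scale * q - 2.0) < 1e-3   # n1 / n2 == 2.0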