You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by ma...@apache.org on 2022/11/18 04:06:47 UTC

[tvm] branch main updated: [Hexagon][QNN] Add TOPI strategies for qnn ops mul/tanh/subtract (#13416)

This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 53824d697a [Hexagon][QNN] Add TOPI strategies for qnn ops mul/tanh/subtract (#13416)
53824d697a is described below

commit 53824d697a633260ac62777eafd624c6406d9d42
Author: ibsidorenko <98...@users.noreply.github.com>
AuthorDate: Fri Nov 18 07:06:42 2022 +0300

    [Hexagon][QNN] Add TOPI strategies for qnn ops mul/tanh/subtract (#13416)
    
    This commit adds compute/schedule implementations for the Hexagon target for
    the QNN ops qnn.mul, qnn.subtract and qnn.tanh. They work only if the QNN
    canonicalization pass is disabled.
---
 python/tvm/relay/qnn/op/_qnn.py                    |  11 +-
 python/tvm/relay/qnn/strategy/generic.py           |  27 ++++
 python/tvm/relay/qnn/strategy/hexagon.py           |  36 +++++
 python/tvm/topi/hexagon/qnn/nn.py                  | 179 ++++++++++++++++-----
 src/relay/qnn/op/add.cc                            |   3 +-
 src/relay/qnn/op/mul.cc                            |   3 +-
 src/relay/qnn/op/requantize.cc                     |   3 +
 src/relay/qnn/op/subtract.cc                       |   3 +-
 .../test_hexagon/test_wo_qnn_canonicalization.py   | 178 +++++++++++++++-----
 9 files changed, 362 insertions(+), 81 deletions(-)

diff --git a/python/tvm/relay/qnn/op/_qnn.py b/python/tvm/relay/qnn/op/_qnn.py
index 4e54583a3b..64ef1ee92a 100644
--- a/python/tvm/relay/qnn/op/_qnn.py
+++ b/python/tvm/relay/qnn/op/_qnn.py
@@ -66,7 +66,16 @@ register_pattern("qnn.requantize", OpPattern.ELEMWISE)
 
 # qnn.add
 register_strategy("qnn.add", strategy.qnn_add_strategy)
-register_pattern("qnn.add", OpPattern.BROADCAST)
+
+# qnn.subtract
+register_strategy("qnn.subtract", strategy.qnn_subtract_strategy)
+
+# qnn.mul
+register_strategy("qnn.mul", strategy.qnn_mul_strategy)
+
+# qnn.tanh
+register_strategy("qnn.tanh", strategy.qnn_tanh_strategy)
+register_pattern("qnn.tanh", OpPattern.ELEMWISE)
 
 # qnn.concatenate
 register_strategy("qnn.concatenate", strategy.qnn_concatenate_strategy)
diff --git a/python/tvm/relay/qnn/strategy/generic.py b/python/tvm/relay/qnn/strategy/generic.py
index 57a364f7e0..8275cf7f75 100644
--- a/python/tvm/relay/qnn/strategy/generic.py
+++ b/python/tvm/relay/qnn/strategy/generic.py
@@ -213,6 +213,33 @@ def qnn_add_strategy(attrs, inputs, out_type, target):
     )
 
 
+@override_native_generic_func("qnn_subtract_strategy")
+def qnn_subtract_strategy(attrs, inputs, out_type, target):
+    """qnn.subtract generic strategy"""
+    raise RuntimeError(
+        "qnn.subtract is currently only supported with Hexagon. "
+        "Please run QNN Canonicalize pass to decompose this op into supported ops."
+    )
+
+
+@override_native_generic_func("qnn_mul_strategy")
+def qnn_mul_strategy(attrs, inputs, out_type, target):
+    """qnn.mul generic strategy"""
+    raise RuntimeError(
+        "qnn.mul is currently only supported with Hexagon. "
+        "Please run QNN Canonicalize pass to decompose this op into supported ops."
+    )
+
+
+@override_native_generic_func("qnn_tanh_strategy")
+def qnn_tanh_strategy(attrs, inputs, out_type, target):
+    """qnn.tanh generic strategy"""
+    raise RuntimeError(
+        "qnn.tanh is currently only supported with Hexagon. "
+        "Please run QNN Canonicalize pass to decompose this op into supported ops."
+    )
+
+
 @override_native_generic_func("qnn_concatenate_strategy")
 def qnn_concatenate_strategy(attrs, inputs, out_type, target):
     """qnn.concatenate generic strategy"""
diff --git a/python/tvm/relay/qnn/strategy/hexagon.py b/python/tvm/relay/qnn/strategy/hexagon.py
index c7f59cc096..d17812e3fb 100644
--- a/python/tvm/relay/qnn/strategy/hexagon.py
+++ b/python/tvm/relay/qnn/strategy/hexagon.py
@@ -71,6 +71,42 @@ def qnn_add_strategy_hexagon(attrs, inputs, out_type, target):
     return strategy
 
 
+@qnn_subtract_strategy.register("hexagon")
+def qnn_subtract_strategy_hexagon(attrs, inputs, out_type, target):
+    """qnn.subtract strategy for Hexagon"""
+    strategy = _op.OpStrategy()
+    strategy.add_implementation(
+        wrap_topi_compute(topi.hexagon.qnn_subtract),
+        wrap_topi_schedule(topi.hexagon.schedule_qnn_subtract),
+        name="qnn_subtract.hexagon",
+    )
+    return strategy
+
+
+@qnn_mul_strategy.register("hexagon")
+def qnn_mul_strategy_hexagon(attrs, inputs, out_type, target):
+    """qnn.mul strategy for Hexagon"""
+    strategy = _op.OpStrategy()
+    strategy.add_implementation(
+        wrap_topi_compute(topi.hexagon.qnn_mul),
+        wrap_topi_schedule(topi.hexagon.schedule_qnn_mul),
+        name="qnn_mul.hexagon",
+    )
+    return strategy
+
+
+@qnn_tanh_strategy.register("hexagon")
+def qnn_tanh_strategy_hexagon(attrs, inputs, out_type, target):
+    """qnn.tanh strategy for Hexagon"""
+    strategy = _op.OpStrategy()
+    strategy.add_implementation(
+        wrap_topi_compute(topi.hexagon.qnn_tanh),
+        wrap_topi_schedule(topi.hexagon.schedule_qnn_tanh),
+        name="qnn_tanh.hexagon",
+    )
+    return strategy
+
+
 @qnn_concatenate_strategy.register("hexagon")
 def qnn_concatenate_strategy_hexagon(attrs, inputs, out_type, target):
     """qnn.concatenate strategy for Hexagon"""
diff --git a/python/tvm/topi/hexagon/qnn/nn.py b/python/tvm/topi/hexagon/qnn/nn.py
index 40cfd0ee96..49220d0fd0 100644
--- a/python/tvm/topi/hexagon/qnn/nn.py
+++ b/python/tvm/topi/hexagon/qnn/nn.py
@@ -19,6 +19,7 @@
 
 import tvm
 from tvm import te, topi
+from ..utils import saturate
 from ...utils import get_const_tuple
 from ...nn.utils import get_pad_tuple
 from ...nn.pad import pad
@@ -33,6 +34,11 @@ def clip_cast(val, dtype):
     return te.max(tvm.te.min(val, const_max), const_min).astype(dtype)
 
 
+# Return True if the given Tensor is a scalar (rank-0) tensor, e.g. a scalar constant operand.
+def is_constant(tensor: te.Tensor):
+    return tensor.ndim == 0
+
+
 def get_qnn_param(param, indices, axis):
     # Account scalar and 1D quantization parameters:
     if len(param.shape) == 0:
@@ -62,7 +68,7 @@ def default_schedule(outs):
     return s
 
 
-def qnn_quantize(data, output_scale, output_zero_point, axis, out_dtype):
+def qnn_quantize(data, output_scale, output_zero_point, axis=-1, out_dtype="int8"):
     """Compute for qnn.quantize
 
     Q_output = clamp((round(input_tensor/output_scale) + output_zero_point),
@@ -101,7 +107,7 @@ def schedule_qnn_quantize(outs):
     return default_schedule(outs)
 
 
-def qnn_dequantize(data, input_scale, input_zero_point, axis):
+def qnn_dequantize(data, input_scale, input_zero_point, axis=-1):
     """Compute for qnn.dequantize
 
     fp_output = input_scale * (Q_input - input_zero_point)
@@ -134,7 +140,7 @@ def schedule_qnn_dequantize(outs):
     return default_schedule(outs)
 
 
-def qnn_requantize(data, input_scale, input_zp, output_scale, output_zp, axis, out_dtype):
+def qnn_requantize(data, input_scale, input_zp, output_scale, output_zp, axis=-1, out_dtype="int8"):
     """Compute for qnn.requantize
 
     Q_output = zp_output + round((scale_input)/(scale_output) * (Q_input - zp_input))
@@ -177,37 +183,58 @@ def schedule_qnn_requantize(outs):
     return default_schedule(outs)
 
 
-def qnn_add(
-    lhs, rhs, lhs_scale, lhs_zero_point, rhs_scale, rhs_zero_point, output_scale, output_zero_point
+def compute_qnn_binary_op(
+    lhs, rhs, lhs_scale, lhs_zp, rhs_scale, rhs_zp, output_scale, output_zp, func
 ):
-    """Compute for qnn.add
+    """Compute for QNN binary operation
 
-    Q_output = zp_output + round((lhs_scale)/(scale_output) * (lhs_input - lhs_zp_input))
-                         + round((rhs_scale)/(scale_output) * (rhs_input - rhs_zp_input))
-
-    TODO: support 'axis' argument.
+    Q_output = output_zp + round((lhs_scale)/(output_scale) * (lhs_input - lhs_zp))
+                      _OP_ round((rhs_scale)/(output_scale) * (rhs_input - rhs_zp))
+    where _OP_ is add/subtract
     """
-
     assert lhs.dtype == rhs.dtype
     dtype = lhs.dtype
 
+    def _compute_const(x: te.Tensor, iscale, input_zp):
+        return te.round(te.multiply(te.div(iscale, output_scale), te.subtract(x, input_zp))).astype(
+            "int32"
+        )
+
+    def _compute_tensor(x: te.Tensor, iscale, input_zp):
+        return te.compute(
+            x.shape,
+            lambda *i: te.round(
+                te.multiply(te.div(iscale, output_scale), te.subtract(x(*i), input_zp))
+            ).astype("int32"),
+        )
+
+    if is_constant(lhs):
+        lhs_tensor = _compute_const(lhs, lhs_scale, lhs_zp)
+    else:
+        lhs_tensor = _compute_tensor(lhs, lhs_scale, lhs_zp)
+
+    if is_constant(rhs):
+        rhs_tensor = _compute_const(rhs, rhs_scale, rhs_zp)
+    else:
+        rhs_tensor = _compute_tensor(rhs, rhs_scale, rhs_zp)
+
+    # Binary op with broadcasting
+    tensor = func(lhs_tensor, rhs_tensor)
+
+    # Add output zero point and clip+cast.
     def _compute(*indices):
-        lvalue = lhs(*indices)
-        rvalue = rhs(*indices)
-        q_lv = te.round(
-            te.multiply(te.div(lhs_scale, output_scale), te.subtract(lvalue, lhs_zero_point))
-        ).astype("int32")
-        q_rv = te.round(
-            te.multiply(te.div(rhs_scale, output_scale), te.subtract(rvalue, rhs_zero_point))
-        ).astype("int32")
-        val = te.add(te.add(q_lv, q_rv), output_zero_point)
+        return saturate(te.add(tensor(*indices), output_zp), dtype).astype(dtype)
+
+    return te.compute(tensor.shape, _compute)
 
-        # clip + cast:
-        const_min = tvm.tir.min_value(dtype)
-        const_max = tvm.tir.max_value(dtype)
-        return te.max(tvm.te.min(val, const_max), const_min).astype(dtype)
 
-    return te.compute(lhs.shape, _compute)
+def qnn_add(lhs, rhs, lhs_scale, lhs_zp, rhs_scale, rhs_zp, output_scale, output_zp):
+    """Compute for qnn.add
+    TODO: support 'axis' argument.
+    """
+    return compute_qnn_binary_op(
+        lhs, rhs, lhs_scale, lhs_zp, rhs_scale, rhs_zp, output_scale, output_zp, topi.add
+    )
 
 
 def schedule_qnn_add(outs):
@@ -227,19 +254,99 @@ def schedule_qnn_add(outs):
     return default_schedule(outs)
 
 
-def requantize_tensor(tensor, i_scale, i_zp, o_scale, o_zp, out_dtype):
-    """Requantize tensor"""
+def qnn_subtract(lhs, rhs, lhs_scale, lhs_zp, rhs_scale, rhs_zp, output_scale, output_zp):
+    """Compute for qnn.subtract"""
 
-    def _compute(*indices):
-        value = tensor(*indices)
-        mul_value = te.round(
-            te.multiply(te.div(i_scale, o_scale), te.subtract(value, i_zp))
-        ).astype("int32")
-        rq_value = te.add(mul_value, o_zp)
+    return compute_qnn_binary_op(
+        lhs, rhs, lhs_scale, lhs_zp, rhs_scale, rhs_zp, output_scale, output_zp, topi.subtract
+    )
 
-        return clip_cast(rq_value, out_dtype)
 
-    return te.compute(tensor.shape, _compute)
+def schedule_qnn_subtract(outs):
+    """Schedule for qnn.subtract
+
+    Parameters
+    ----------
+    outs: Array of Tensor
+          The computation graph description of qnn.subtract
+          in the format of an array of tensors.
+
+    Returns
+    -------
+    sch: Schedule
+        The computation schedule for the op.
+    """
+    return default_schedule(outs)
+
+
+def qnn_mul(lhs, rhs, lhs_scale, lhs_zp, rhs_scale, rhs_zp, output_scale, output_zp):
+    """Compute for qnn.mul
+
+    mul = (lhs_input - lhs_zp) * (rhs_input - rhs_zp)
+    Q_output = requantize(mul, lhs_scale * rhs_scale, 0, output_scale, output_zp)
+    """
+    assert lhs.dtype == rhs.dtype
+    odtype = lhs.dtype
+
+    if is_constant(lhs):
+        lhs_tensor = lhs - lhs_zp
+    else:
+        lhs_tensor = te.compute(lhs.shape, lambda *i: te.subtract(lhs(*i), lhs_zp))
+
+    if is_constant(rhs):
+        rhs_tensor = rhs - rhs_zp
+    else:
+        rhs_tensor = te.compute(rhs.shape, lambda *i: te.subtract(rhs(*i), rhs_zp))
+
+    # Multiply with broadcasting.
+    mul = topi.multiply(lhs_tensor, rhs_tensor)
+
+    iscale = lhs_scale * rhs_scale
+    return qnn_requantize(mul, iscale, tvm.tir.const(0), output_scale, output_zp, out_dtype=odtype)
+
+
+def schedule_qnn_mul(outs):
+    """Schedule for qnn.mul
+
+    Parameters
+    ----------
+    outs: Array of Tensor
+          The computation graph description of qnn.mul
+          in the format of an array of tensors.
+
+    Returns
+    -------
+    sch: Schedule
+        The computation schedule for the op.
+    """
+    return default_schedule(outs)
+
+
+def qnn_tanh(data, input_scale, input_zp, output_scale, output_zp):
+    """Compute for qnn.tanh
+
+    Q_output = quantize(tanh(dequantize(data)))
+    """
+    dq_tensor = qnn_dequantize(data, input_scale, input_zp)
+    tanh = te.compute(dq_tensor.shape, lambda *i: te.tanh(dq_tensor(*i)))
+    return qnn_quantize(tanh, output_scale, output_zp, out_dtype=data.dtype)
+
+
+def schedule_qnn_tanh(outs):
+    """Schedule for qnn.tanh
+
+    Parameters
+    ----------
+    outs: Array of Tensor
+          The computation graph description of qnn.tanh
+          in the format of an array of tensors.
+
+    Returns
+    -------
+    sch: Schedule
+        The computation schedule for the op.
+    """
+    return default_schedule(outs)
 
 
 def qnn_concatenate(data, axis, out_dtype):
@@ -282,7 +389,7 @@ def qnn_concatenate(data, axis, out_dtype):
         i_zp = data[i + args_num * 2]
 
         # Requantize tensors and add them to the list.
-        args.append(requantize_tensor(tensor, i_scale, i_zp, o_scale, o_zp, out_dtype))
+        args.append(qnn_requantize(tensor, i_scale, i_zp, o_scale, o_zp, out_dtype=out_dtype))
 
     # Call x86 implementation of concatenate.
     return concatenate(args, axis)
diff --git a/src/relay/qnn/op/add.cc b/src/relay/qnn/op/add.cc
index d087d9fa77..0e0d3fdbc0 100644
--- a/src/relay/qnn/op/add.cc
+++ b/src/relay/qnn/op/add.cc
@@ -96,7 +96,8 @@ Expr QnnAddCanonicalize(const Attrs& attrs, const Array<Expr>& new_args,
 QNN_REGISTER_BINARY_OP("add")
     .describe("Elementwise add with broadcasting for quantized tensors.")
     .set_support_level(11)
-    .set_attr<FTVMLegalize>("FTVMQnnCanonicalize", QnnAddCanonicalize);
+    .set_attr<FTVMLegalize>("FTVMQnnCanonicalize", QnnAddCanonicalize)
+    .set_attr<TOpPattern>("TOpPattern", kBroadcast);
 
 }  // namespace qnn
 }  // namespace relay
diff --git a/src/relay/qnn/op/mul.cc b/src/relay/qnn/op/mul.cc
index 6dde61359d..73c6eed448 100644
--- a/src/relay/qnn/op/mul.cc
+++ b/src/relay/qnn/op/mul.cc
@@ -162,7 +162,8 @@ Expr QnnMulCanonicalize(const Attrs& attrs, const Array<Expr>& new_args,
 QNN_REGISTER_BINARY_OP("mul")
     .describe("Elementwise mul with broadcasting for quantized tensors.")
     .set_support_level(11)
-    .set_attr<FTVMLegalize>("FTVMQnnCanonicalize", QnnMulCanonicalize);
+    .set_attr<FTVMLegalize>("FTVMQnnCanonicalize", QnnMulCanonicalize)
+    .set_attr<TOpPattern>("TOpPattern", kBroadcast);
 
 }  // namespace qnn
 }  // namespace relay
diff --git a/src/relay/qnn/op/requantize.cc b/src/relay/qnn/op/requantize.cc
index e199ea27f1..91df4a287c 100644
--- a/src/relay/qnn/op/requantize.cc
+++ b/src/relay/qnn/op/requantize.cc
@@ -384,6 +384,9 @@ Expr RequantizeLower(const Expr& input_tensor, const Expr& input_scale,
                      const Expr& input_zero_point, const Expr& output_scale,
                      const Expr& output_zero_point, const RequantizeAttrs* param,
                      const Array<IndexExpr>& input_shape, const DataType& out_dtype) {
+  // Check output scale validity.
+  ICHECK_NE(GetScalarFromConstant<float>(output_scale), 0.0)
+      << "QNN requantize output scale can not be equal to 0.0";
   // Check rounding validity.
   ICHECK(param->rounding == "UPWARD" || param->rounding == "TONEAREST")
       << "QNN requantize supports two rounding modes - UPWARD and "
diff --git a/src/relay/qnn/op/subtract.cc b/src/relay/qnn/op/subtract.cc
index 1815019220..962a3434cb 100644
--- a/src/relay/qnn/op/subtract.cc
+++ b/src/relay/qnn/op/subtract.cc
@@ -97,7 +97,8 @@ Expr QnnSubtractCanonicalize(const Attrs& attrs, const Array<Expr>& new_args,
 QNN_REGISTER_BINARY_OP("subtract")
     .describe("Elementwise subtract with broadcasting for quantized tensors.")
     .set_support_level(11)
-    .set_attr<FTVMLegalize>("FTVMQnnCanonicalize", QnnSubtractCanonicalize);
+    .set_attr<FTVMLegalize>("FTVMQnnCanonicalize", QnnSubtractCanonicalize)
+    .set_attr<TOpPattern>("TOpPattern", kBroadcast);
 
 }  // namespace qnn
 }  // namespace relay
diff --git a/tests/python/contrib/test_hexagon/test_wo_qnn_canonicalization.py b/tests/python/contrib/test_hexagon/test_wo_qnn_canonicalization.py
index e4edf2919a..06e738d9b7 100644
--- a/tests/python/contrib/test_hexagon/test_wo_qnn_canonicalization.py
+++ b/tests/python/contrib/test_hexagon/test_wo_qnn_canonicalization.py
@@ -51,13 +51,33 @@ def test_no_qnn_pass():
     assert "qnn.dequantize" in opt_mod_2.astext(show_meta_data=False)
 
 
-def execute(executor, data_np, weight_np, bias_np=None):
-    executor.set_input("data", data_np)
-    executor.set_input("weight", weight_np)
-    if bias_np is not None:
-        executor.set_input("bias", bias_np)
-    executor.run()
-    return executor.get_output(0)
+def execute(mod_executor, inputs: dict):
+    for input_name, input_data in inputs.items():
+        mod_executor.set_input(input_name, input_data)
+    mod_executor.run()
+    return mod_executor.get_output(0).numpy()
+
+
+def build_hexagon_module(mod):
+    with tvm.transform.PassContext(opt_level=3, disabled_pass=["qnn.Legalize"]):
+        hexagon_lowered = tvm.relay.build(
+            mod,
+            tvm.target.Target(HEXAGON_AOT_LLVM_TARGET, host=HEXAGON_AOT_LLVM_TARGET),
+            executor=Executor("aot"),
+        )
+
+    return hexagon_lowered
+
+
+def build_ref_module(mod):
+    target_llvm = tvm.target.Target("llvm")
+    with tvm.transform.PassContext(opt_level=3):
+        llvm_lowered = tvm.relay.build(
+            mod,
+            tvm.target.Target(target_llvm, host=target_llvm),
+            executor=Executor("aot"),
+        )
+    return llvm_lowered
 
 
 @tvm.testing.requires_hexagon
@@ -90,33 +110,24 @@ def test_qnn_conv2d_rq(hexagon_session: Session):
     )
     relay_mod = tvm.IRModule.from_expr(op5)
 
-    target_llvm = tvm.target.Target("llvm")
-    executor = Executor("aot")
-    with tvm.transform.PassContext(opt_level=3, disabled_pass=["qnn.Legalize"]):
-        hexagon_lowered = tvm.relay.build(
-            relay_mod,
-            tvm.target.Target(HEXAGON_AOT_LLVM_TARGET, host=HEXAGON_AOT_LLVM_TARGET),
-            executor=executor,
-        )
+    # Compile for Hexagon
+    hexagon_lowered = build_hexagon_module(relay_mod)
 
-    with tvm.transform.PassContext(opt_level=3):
-        llvm_lowered = tvm.relay.build(
-            relay_mod,
-            tvm.target.Target(target_llvm, host=target_llvm),
-            executor=executor,
-        )
+    # Reference compilation
+    llvm_lowered = build_ref_module(relay_mod)
 
     data_np = np.random.rand(*data_shape) - 0.5
     weight_np = np.random.rand(*weight_shape) - 0.5
+    inputs = {"data": data_np, "weight": weight_np}
 
     hx_m = hexagon_session.get_executor_from_factory(hexagon_lowered)
-    hexagon_output = execute(hx_m, data_np, weight_np)
+    hexagon_output = execute(hx_m, inputs)
 
     dev = tvm.cpu(0)
     llvm_m = tvm.runtime.executor.AotModule(llvm_lowered["default"](dev))
-    llvm_out = execute(llvm_m, data_np, weight_np)
+    llvm_out = execute(llvm_m, inputs)
 
-    np.testing.assert_equal(hexagon_output.numpy(), llvm_out.numpy())
+    np.testing.assert_equal(hexagon_output, llvm_out)
 
 
 @tvm.testing.requires_hexagon
@@ -152,34 +163,119 @@ def test_qnn_dense_bias_rq(hexagon_session: Session):
     )
     relay_mod = tvm.IRModule.from_expr(op5)
 
-    target_llvm = tvm.target.Target("llvm")
-    executor = Executor("aot")
-    with tvm.transform.PassContext(opt_level=3, disabled_pass=["qnn.Legalize"]):
-        hexagon_lowered = tvm.relay.build(
-            relay_mod,
-            tvm.target.Target(HEXAGON_AOT_LLVM_TARGET, host=HEXAGON_AOT_LLVM_TARGET),
-            executor=executor,
-        )
+    # Compile for Hexagon
+    hexagon_lowered = build_hexagon_module(relay_mod)
 
-    with tvm.transform.PassContext(opt_level=3):
-        llvm_lowered = tvm.relay.build(
-            relay_mod,
-            tvm.target.Target(target_llvm, host=target_llvm),
-            executor=executor,
-        )
+    # Reference compilation
+    llvm_lowered = build_ref_module(relay_mod)
 
     data_np = np.random.rand(*data_shape) - 0.5
     weight_np = np.random.rand(*weight_shape) - 0.5
     bias_np = np.random.rand(*bias_shape)
+    inputs = {"data": data_np, "weight": weight_np, "bias": bias_np}
 
     hx_m = hexagon_session.get_executor_from_factory(hexagon_lowered)
-    hexagon_output = execute(hx_m, data_np, weight_np, bias_np)
+    hexagon_output = execute(hx_m, inputs)
 
     dev = tvm.cpu(0)
     llvm_m = tvm.runtime.executor.AotModule(llvm_lowered["default"](dev))
-    llvm_out = execute(llvm_m, data_np, weight_np, bias_np)
+    llvm_out = execute(llvm_m, inputs)
+
+    np.testing.assert_equal(hexagon_output, llvm_out)
+
+
+class TestQnnBinaryOp:
+    """QNN binary op test class"""
+
+    operation = tvm.testing.parameter(
+        relay.qnn.op.add,
+        relay.qnn.op.subtract,
+        relay.qnn.op.mul,
+    )
+    dtype = tvm.testing.parameter("uint8", "int8")
+    input_shape = tvm.testing.parameter([256], [4, 256])
+
+    @tvm.testing.requires_hexagon
+    def test_qnn_binary_op_broadcasting(
+        self, hexagon_session: Session, operation, dtype, input_shape
+    ):
+        """qnn binary op test with broadcasting, without QNN canonicalization."""
+        lhs_shape = [4, 256]
+        rhs_shape = input_shape
+        lhs = relay.var("lhs", shape=lhs_shape, dtype=dtype)
+        rhs = relay.var("rhs", shape=rhs_shape, dtype=dtype)
+        zp_const1 = 1
+        zp_const2 = 3
+
+        op = operation(
+            lhs,
+            rhs,
+            lhs_scale=relay.const(0.041, "float32"),
+            lhs_zero_point=relay.const(zp_const1, "int32"),
+            rhs_scale=relay.const(0.017, "float32"),
+            rhs_zero_point=relay.const(zp_const2, "int32"),
+            output_scale=relay.const(0.039, "float32"),
+            output_zero_point=relay.const(2, "int32"),
+        )
+        mod = tvm.IRModule.from_expr(op)
+
+        # Compile for Hexagon
+        hexagon_lowered = build_hexagon_module(mod)
+
+        # Reference compilation
+        llvm_lowered = build_ref_module(mod)
+
+        lhs_np = np.random.randint(np.iinfo(dtype).min + zp_const1, np.iinfo(dtype).max, lhs_shape)
+        rhs_np = np.random.randint(np.iinfo(dtype).min + zp_const2, np.iinfo(dtype).max, rhs_shape)
+        inputs = {"lhs": lhs_np, "rhs": rhs_np}
+
+        hx_m = hexagon_session.get_executor_from_factory(hexagon_lowered)
+        hexagon_output = execute(hx_m, inputs)
+
+        dev = tvm.cpu(0)
+        llvm_m = tvm.runtime.executor.AotModule(llvm_lowered["default"](dev))
+        llvm_output = execute(llvm_m, inputs)
+
+        # Allow a difference of at most 1 between the outputs.
+        tvm.testing.assert_allclose(hexagon_output, llvm_output, atol=1)
+
+    @tvm.testing.requires_hexagon
+    def test_qnn_binary_op_scalar(self, hexagon_session: Session, operation):
+        """qnn binary op test with a scalar constant rhs, without QNN canonicalization."""
+        lhs_shape = [4, 256]
+        lhs = relay.var("lhs", shape=lhs_shape, dtype="uint8")
+        rhs = relay.const(11, dtype="uint8")
+
+        op = operation(
+            lhs,
+            rhs,
+            lhs_scale=relay.const(0.049, "float32"),
+            lhs_zero_point=relay.const(1, "int32"),
+            rhs_scale=relay.const(0.067, "float32"),
+            rhs_zero_point=relay.const(3, "int32"),
+            output_scale=relay.const(0.041, "float32"),
+            output_zero_point=relay.const(2, "int32"),
+        )
+        mod = tvm.IRModule.from_expr(op)
+
+        # Compile for Hexagon
+        hexagon_lowered = build_hexagon_module(mod)
+
+        # Reference compilation
+        llvm_lowered = build_ref_module(mod)
+
+        lhs_np = np.random.randint(1, 255, size=lhs_shape)
+        inputs = {"lhs": lhs_np}
+
+        hx_m = hexagon_session.get_executor_from_factory(hexagon_lowered)
+        hexagon_output = execute(hx_m, inputs)
+
+        dev = tvm.cpu(0)
+        llvm_m = tvm.runtime.executor.AotModule(llvm_lowered["default"](dev))
+        llvm_output = execute(llvm_m, inputs)
 
-    np.testing.assert_equal(hexagon_output.numpy(), llvm_out.numpy())
+        # Allow a difference of at most 1 between the outputs.
+        tvm.testing.assert_allclose(hexagon_output, llvm_output, atol=1)
 
 
 if __name__ == "__main__":