Posted to commits@tvm.apache.org by ma...@apache.org on 2022/11/18 04:06:47 UTC
[tvm] branch main updated: [Hexagon][QNN] Add TOPI strategies for qnn ops mul/tanh/subtract (#13416)
This is an automated email from the ASF dual-hosted git repository.
masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 53824d697a [Hexagon][QNN] Add TOPI strategies for qnn ops mul/tanh/subtract (#13416)
53824d697a is described below
commit 53824d697a633260ac62777eafd624c6406d9d42
Author: ibsidorenko <98...@users.noreply.github.com>
AuthorDate: Fri Nov 18 07:06:42 2022 +0300
[Hexagon][QNN] Add TOPI strategies for qnn ops mul/tanh/subtract (#13416)
This commit adds compute/schedule implementations for the Hexagon target for
the QNN ops qnn.mul, qnn.subtract and qnn.tanh. These implementations are used
only if the QNN canonicalization pass is disabled. The commit also sets the
BROADCAST op pattern for qnn.add/mul/subtract in their C++ registrations and
makes qnn.requantize reject a zero output scale.
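For illustration, a minimal sketch of how the new path can be exercised
(it mirrors the tests below; the Hexagon CPU version and the AOT executor
setup are assumptions, not prescribed by this commit):

    import tvm
    from tvm import relay
    from tvm.relay.backend import Executor

    lhs = relay.var("lhs", shape=[4, 256], dtype="int8")
    rhs = relay.var("rhs", shape=[4, 256], dtype="int8")
    mod = tvm.IRModule.from_expr(relay.qnn.op.mul(
        lhs, rhs,
        lhs_scale=relay.const(0.041, "float32"),
        lhs_zero_point=relay.const(1, "int32"),
        rhs_scale=relay.const(0.017, "float32"),
        rhs_zero_point=relay.const(3, "int32"),
        output_scale=relay.const(0.039, "float32"),
        output_zero_point=relay.const(2, "int32"),
    ))

    # Disabling qnn.Legalize keeps qnn.mul intact, so lowering picks the
    # new TOPI strategy (qnn_mul.hexagon) instead of decomposed ops.
    target = tvm.target.hexagon("v68")  # assumed CPU version
    with tvm.transform.PassContext(opt_level=3, disabled_pass=["qnn.Legalize"]):
        lib = tvm.relay.build(mod, tvm.target.Target(target, host=target),
                              executor=Executor("aot"))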
---
python/tvm/relay/qnn/op/_qnn.py | 11 +-
python/tvm/relay/qnn/strategy/generic.py | 27 ++++
python/tvm/relay/qnn/strategy/hexagon.py | 36 +++++
python/tvm/topi/hexagon/qnn/nn.py | 179 ++++++++++++++++-----
src/relay/qnn/op/add.cc | 3 +-
src/relay/qnn/op/mul.cc | 3 +-
src/relay/qnn/op/requantize.cc | 3 +
src/relay/qnn/op/subtract.cc | 3 +-
.../test_hexagon/test_wo_qnn_canonicalization.py | 178 +++++++++++++++-----
9 files changed, 362 insertions(+), 81 deletions(-)
diff --git a/python/tvm/relay/qnn/op/_qnn.py b/python/tvm/relay/qnn/op/_qnn.py
index 4e54583a3b..64ef1ee92a 100644
--- a/python/tvm/relay/qnn/op/_qnn.py
+++ b/python/tvm/relay/qnn/op/_qnn.py
@@ -66,7 +66,16 @@ register_pattern("qnn.requantize", OpPattern.ELEMWISE)
# qnn.add
register_strategy("qnn.add", strategy.qnn_add_strategy)
-register_pattern("qnn.add", OpPattern.BROADCAST)
+
+# qnn.subtract
+register_strategy("qnn.subtract", strategy.qnn_subtract_strategy)
+
+# qnn.mul
+register_strategy("qnn.mul", strategy.qnn_mul_strategy)
+
+# qnn.tanh
+register_strategy("qnn.tanh", strategy.qnn_tanh_strategy)
+register_pattern("qnn.tanh", OpPattern.ELEMWISE)
# qnn.concatenate
register_strategy("qnn.concatenate", strategy.qnn_concatenate_strategy)
diff --git a/python/tvm/relay/qnn/strategy/generic.py b/python/tvm/relay/qnn/strategy/generic.py
index 57a364f7e0..8275cf7f75 100644
--- a/python/tvm/relay/qnn/strategy/generic.py
+++ b/python/tvm/relay/qnn/strategy/generic.py
@@ -213,6 +213,33 @@ def qnn_add_strategy(attrs, inputs, out_type, target):
)
+@override_native_generic_func("qnn_subtract_strategy")
+def qnn_subtract_strategy(attrs, inputs, out_type, target):
+ """qnn.subtract generic strategy"""
+ raise RuntimeError(
+ "qnn.subtract is currently only supported with Hexagon. "
+ "Please run QNN Canonicalize pass to decompose this op into supported ops."
+ )
+
+
+@override_native_generic_func("qnn_mul_strategy")
+def qnn_mul_strategy(attrs, inputs, out_type, target):
+ """qnn.mul generic strategy"""
+ raise RuntimeError(
+ "qnn.mul is currently only supported with Hexagon. "
+ "Please run QNN Canonicalize pass to decompose this op into supported ops."
+ )
+
+
+@override_native_generic_func("qnn_tanh_strategy")
+def qnn_tanh_strategy(attrs, inputs, out_type, target):
+ """qnn.tanh generic strategy"""
+ raise RuntimeError(
+ "qnn.tanh is currently only supported with Hexagon. "
+ "Please run QNN Canonicalize pass to decompose this op into supported ops."
+ )
+
+
@override_native_generic_func("qnn_concatenate_strategy")
def qnn_concatenate_strategy(attrs, inputs, out_type, target):
"""qnn.concatenate generic strategy"""
diff --git a/python/tvm/relay/qnn/strategy/hexagon.py b/python/tvm/relay/qnn/strategy/hexagon.py
index c7f59cc096..d17812e3fb 100644
--- a/python/tvm/relay/qnn/strategy/hexagon.py
+++ b/python/tvm/relay/qnn/strategy/hexagon.py
@@ -71,6 +71,42 @@ def qnn_add_strategy_hexagon(attrs, inputs, out_type, target):
return strategy
+@qnn_subtract_strategy.register("hexagon")
+def qnn_subtract_strategy_hexagon(attrs, inputs, out_type, target):
+ """qnn.subtract strategy for Hexagon"""
+ strategy = _op.OpStrategy()
+ strategy.add_implementation(
+ wrap_topi_compute(topi.hexagon.qnn_subtract),
+ wrap_topi_schedule(topi.hexagon.schedule_qnn_subtract),
+ name="qnn_subtract.hexagon",
+ )
+ return strategy
+
+
+@qnn_mul_strategy.register("hexagon")
+def qnn_mul_strategy_hexagon(attrs, inputs, out_type, target):
+ """qnn.mul strategy for Hexagon"""
+ strategy = _op.OpStrategy()
+ strategy.add_implementation(
+ wrap_topi_compute(topi.hexagon.qnn_mul),
+ wrap_topi_schedule(topi.hexagon.schedule_qnn_mul),
+ name="qnn_mul.hexagon",
+ )
+ return strategy
+
+
+@qnn_tanh_strategy.register("hexagon")
+def qnn_tanh_strategy_hexagon(attrs, inputs, out_type, target):
+ """qnn.tanh strategy for Hexagon"""
+ strategy = _op.OpStrategy()
+ strategy.add_implementation(
+ wrap_topi_compute(topi.hexagon.qnn_tanh),
+ wrap_topi_schedule(topi.hexagon.schedule_qnn_tanh),
+ name="qnn_tanh.hexagon",
+ )
+ return strategy
+
+
@qnn_concatenate_strategy.register("hexagon")
def qnn_concatenate_strategy_hexagon(attrs, inputs, out_type, target):
"""qnn.concatenate strategy for Hexagon"""
diff --git a/python/tvm/topi/hexagon/qnn/nn.py b/python/tvm/topi/hexagon/qnn/nn.py
index 40cfd0ee96..49220d0fd0 100644
--- a/python/tvm/topi/hexagon/qnn/nn.py
+++ b/python/tvm/topi/hexagon/qnn/nn.py
@@ -19,6 +19,7 @@
import tvm
from tvm import te, topi
+from ..utils import saturate
from ...utils import get_const_tuple
from ...nn.utils import get_pad_tuple
from ...nn.pad import pad
@@ -33,6 +34,11 @@ def clip_cast(val, dtype):
return te.max(tvm.te.min(val, const_max), const_min).astype(dtype)
+# Return True if the given tensor is a scalar constant value.
+def is_constant(tensor: te.Tensor):
+ return tensor.ndim == 0
+
+
def get_qnn_param(param, indices, axis):
# Account for scalar and 1D quantization parameters:
if len(param.shape) == 0:
@@ -62,7 +68,7 @@ def default_schedule(outs):
return s
-def qnn_quantize(data, output_scale, output_zero_point, axis, out_dtype):
+def qnn_quantize(data, output_scale, output_zero_point, axis=-1, out_dtype="int8"):
"""Compute for qnn.quantize
Q_output = clamp((round(input_tensor/output_scale) + output_zero_point),
@@ -101,7 +107,7 @@ def schedule_qnn_quantize(outs):
return default_schedule(outs)
-def qnn_dequantize(data, input_scale, input_zero_point, axis):
+def qnn_dequantize(data, input_scale, input_zero_point, axis=-1):
"""Compute for qnn.dequantize
fp_output = input_scale * (Q_input - input_zero_point)
@@ -134,7 +140,7 @@ def schedule_qnn_dequantize(outs):
return default_schedule(outs)
-def qnn_requantize(data, input_scale, input_zp, output_scale, output_zp, axis, out_dtype):
+def qnn_requantize(data, input_scale, input_zp, output_scale, output_zp, axis=-1, out_dtype="int8"):
"""Compute for qnn.requantize
Q_output = zp_output + round((scale_input)/(scale_output) * (Q_input - zp_input))
@@ -177,37 +183,58 @@ def schedule_qnn_requantize(outs):
return default_schedule(outs)
-def qnn_add(
- lhs, rhs, lhs_scale, lhs_zero_point, rhs_scale, rhs_zero_point, output_scale, output_zero_point
+def compute_qnn_binary_op(
+ lhs, rhs, lhs_scale, lhs_zp, rhs_scale, rhs_zp, output_scale, output_zp, func
):
- """Compute for qnn.add
+ """Compute for QNN binary operation
- Q_output = zp_output + round((lhs_scale)/(scale_output) * (lhs_input - lhs_zp_input))
- + round((rhs_scale)/(scale_output) * (rhs_input - rhs_zp_input))
-
- TODO: support 'axis' argument.
+ Q_output = output_zp + round((lhs_scale)/(output_scale) * (lhs_input - lhs_zp))
+ _OP_ round((rhs_scale)/(output_scale) * (rhs_input - rhs_zp))
+ where _OP_ is add/subtract
"""
-
assert lhs.dtype == rhs.dtype
dtype = lhs.dtype
+ def _compute_const(x: te.Tensor, iscale, input_zp):
+ return te.round(te.multiply(te.div(iscale, output_scale), te.subtract(x, input_zp))).astype(
+ "int32"
+ )
+
+ def _compute_tensor(x: te.Tensor, iscale, input_zp):
+ return te.compute(
+ x.shape,
+ lambda *i: te.round(
+ te.multiply(te.div(iscale, output_scale), te.subtract(x(*i), input_zp))
+ ).astype("int32"),
+ )
+
+ if is_constant(lhs):
+ lhs_tensor = _compute_const(lhs, lhs_scale, lhs_zp)
+ else:
+ lhs_tensor = _compute_tensor(lhs, lhs_scale, lhs_zp)
+
+ if is_constant(rhs):
+ rhs_tensor = _compute_const(rhs, rhs_scale, rhs_zp)
+ else:
+ rhs_tensor = _compute_tensor(rhs, rhs_scale, rhs_zp)
+
+ # Binary op with broadcasting
+ tensor = func(lhs_tensor, rhs_tensor)
+
+ # Add output zero point and clip+cast.
def _compute(*indices):
- lvalue = lhs(*indices)
- rvalue = rhs(*indices)
- q_lv = te.round(
- te.multiply(te.div(lhs_scale, output_scale), te.subtract(lvalue, lhs_zero_point))
- ).astype("int32")
- q_rv = te.round(
- te.multiply(te.div(rhs_scale, output_scale), te.subtract(rvalue, rhs_zero_point))
- ).astype("int32")
- val = te.add(te.add(q_lv, q_rv), output_zero_point)
+ return saturate(te.add(tensor(*indices), output_zp), dtype).astype(dtype)
+
+ return te.compute(tensor.shape, _compute)
- # clip + cast:
- const_min = tvm.tir.min_value(dtype)
- const_max = tvm.tir.max_value(dtype)
- return te.max(tvm.te.min(val, const_max), const_min).astype(dtype)
- return te.compute(lhs.shape, _compute)
+def qnn_add(lhs, rhs, lhs_scale, lhs_zp, rhs_scale, rhs_zp, output_scale, output_zp):
+ """Compute for qnn.add
+ TODO: support 'axis' argument.
+ """
+ return compute_qnn_binary_op(
+ lhs, rhs, lhs_scale, lhs_zp, rhs_scale, rhs_zp, output_scale, output_zp, topi.add
+ )
def schedule_qnn_add(outs):
@@ -227,19 +254,99 @@ def schedule_qnn_add(outs):
return default_schedule(outs)
-def requantize_tensor(tensor, i_scale, i_zp, o_scale, o_zp, out_dtype):
- """Requantize tensor"""
+def qnn_subtract(lhs, rhs, lhs_scale, lhs_zp, rhs_scale, rhs_zp, output_scale, output_zp):
+ """Compute for qnn.subtract"""
- def _compute(*indices):
- value = tensor(*indices)
- mul_value = te.round(
- te.multiply(te.div(i_scale, o_scale), te.subtract(value, i_zp))
- ).astype("int32")
- rq_value = te.add(mul_value, o_zp)
+ return compute_qnn_binary_op(
+ lhs, rhs, lhs_scale, lhs_zp, rhs_scale, rhs_zp, output_scale, output_zp, topi.subtract
+ )
- return clip_cast(rq_value, out_dtype)
- return te.compute(tensor.shape, _compute)
+def schedule_qnn_subtract(outs):
+ """Schedule for qnn.subtract
+
+ Parameters
+ ----------
+ outs: Array of Tensor
+ The computation graph description of qnn.subtract
+ in the format of an array of tensors.
+
+ Returns
+ -------
+ sch: Schedule
+ The computation schedule for the op.
+ """
+ return default_schedule(outs)
+
+
+def qnn_mul(lhs, rhs, lhs_scale, lhs_zp, rhs_scale, rhs_zp, output_scale, output_zp):
+ """Compute for qnn.mul
+
+ mul = (lhs_input - lhs_zp) * (rhs_input - rhs_zp)
+ Q_output = requantize(mul, lhs_scale * rhs_scale, 0, output_scale, output_zp)
+ """
+ assert lhs.dtype == rhs.dtype
+ odtype = lhs.dtype
+
+ if is_constant(lhs):
+ lhs_tensor = lhs - lhs_zp
+ else:
+ lhs_tensor = te.compute(lhs.shape, lambda *i: te.subtract(lhs(*i), lhs_zp))
+
+ if is_constant(rhs):
+ rhs_tensor = rhs - rhs_zp
+ else:
+ rhs_tensor = te.compute(rhs.shape, lambda *i: te.subtract(rhs(*i), rhs_zp))
+
+ # Multiply with broadcasting.
+ mul = topi.multiply(lhs_tensor, rhs_tensor)
+
+ iscale = lhs_scale * rhs_scale
+ return qnn_requantize(mul, iscale, tvm.tir.const(0), output_scale, output_zp, out_dtype=odtype)
+
+
+def schedule_qnn_mul(outs):
+ """Schedule for qnn.mul
+
+ Parameters
+ ----------
+ outs: Array of Tensor
+ The computation graph description of qnn.mul
+ in the format of an array of tensors.
+
+ Returns
+ -------
+ sch: Schedule
+ The computation schedule for the op.
+ """
+ return default_schedule(outs)
+
+
+def qnn_tanh(data, input_scale, input_zp, output_scale, output_zp):
+ """Compute for qnn.tanh
+
+ Q_output = quantize(tanh(dequantize(data)))
+ """
+ dq_tensor = qnn_dequantize(data, input_scale, input_zp)
+ tanh = te.compute(dq_tensor.shape, lambda *i: te.tanh(dq_tensor(*i)))
+ return qnn_quantize(tanh, output_scale, output_zp, out_dtype=data.dtype)
+
+
+def schedule_qnn_tanh(outs):
+ """Schedule for qnn.tanh
+
+ Parameters
+ ----------
+ outs: Array of Tensor
+ The computation graph description of qnn.tanh
+ in the format of an array of tensors.
+
+ Returns
+ -------
+ sch: Schedule
+ The computation schedule for the op.
+ """
+ return default_schedule(outs)
def qnn_concatenate(data, axis, out_dtype):
@@ -282,7 +389,7 @@ def qnn_concatenate(data, axis, out_dtype):
i_zp = data[i + args_num * 2]
# Requantize tensors and add them to the list.
- args.append(requantize_tensor(tensor, i_scale, i_zp, o_scale, o_zp, out_dtype))
+ args.append(qnn_requantize(tensor, i_scale, i_zp, o_scale, o_zp, out_dtype=out_dtype))
# Call x86 implementation of concatenate.
return concatenate(args, axis)
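As a plain NumPy reference for the formulas in the docstrings above (a sketch
of the math only, not of the TE computes; saturation bounds follow the output
dtype):

    import numpy as np

    def qnn_binary_ref(op, lhs, rhs, ls, lzp, rs, rzp, o_s, ozp, dtype="int8"):
        # Q_out = ozp + round(ls/o_s * (lhs - lzp)) _OP_ round(rs/o_s * (rhs - rzp))
        q_l = np.round(ls / o_s * (lhs.astype("int32") - lzp)).astype("int32")
        q_r = np.round(rs / o_s * (rhs.astype("int32") - rzp)).astype("int32")
        info = np.iinfo(dtype)
        return np.clip(op(q_l, q_r) + ozp, info.min, info.max).astype(dtype)

    def qnn_mul_ref(lhs, rhs, ls, lzp, rs, rzp, o_s, ozp, dtype="int8"):
        # mul = (lhs - lzp) * (rhs - rzp), then requantize(mul, ls*rs, 0, o_s, ozp)
        mul = (lhs.astype("int32") - lzp) * (rhs.astype("int32") - rzp)
        info = np.iinfo(dtype)
        return np.clip(np.round(ls * rs / o_s * mul) + ozp,
                       info.min, info.max).astype(dtype)

    def qnn_tanh_ref(x, i_s, izp, o_s, ozp, dtype="int8"):
        # Q_out = quantize(tanh(dequantize(x)))
        fp = i_s * (x.astype("float32") - izp)
        info = np.iinfo(dtype)
        return np.clip(np.round(np.tanh(fp) / o_s) + ozp,
                       info.min, info.max).astype(dtype)

For example, qnn_binary_ref(np.add, lhs_np, rhs_np, 0.041, 1, 0.017, 3, 0.039, 2)
mirrors qnn.add with the quantization parameters used in the tests below.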
diff --git a/src/relay/qnn/op/add.cc b/src/relay/qnn/op/add.cc
index d087d9fa77..0e0d3fdbc0 100644
--- a/src/relay/qnn/op/add.cc
+++ b/src/relay/qnn/op/add.cc
@@ -96,7 +96,8 @@ Expr QnnAddCanonicalize(const Attrs& attrs, const Array<Expr>& new_args,
QNN_REGISTER_BINARY_OP("add")
.describe("Elementwise add with broadcasting for quantized tensors.")
.set_support_level(11)
- .set_attr<FTVMLegalize>("FTVMQnnCanonicalize", QnnAddCanonicalize);
+ .set_attr<FTVMLegalize>("FTVMQnnCanonicalize", QnnAddCanonicalize)
+ .set_attr<TOpPattern>("TOpPattern", kBroadcast);
} // namespace qnn
} // namespace relay
diff --git a/src/relay/qnn/op/mul.cc b/src/relay/qnn/op/mul.cc
index 6dde61359d..73c6eed448 100644
--- a/src/relay/qnn/op/mul.cc
+++ b/src/relay/qnn/op/mul.cc
@@ -162,7 +162,8 @@ Expr QnnMulCanonicalize(const Attrs& attrs, const Array<Expr>& new_args,
QNN_REGISTER_BINARY_OP("mul")
.describe("Elementwise mul with broadcasting for quantized tensors.")
.set_support_level(11)
- .set_attr<FTVMLegalize>("FTVMQnnCanonicalize", QnnMulCanonicalize);
+ .set_attr<FTVMLegalize>("FTVMQnnCanonicalize", QnnMulCanonicalize)
+ .set_attr<TOpPattern>("TOpPattern", kBroadcast);
} // namespace qnn
} // namespace relay
diff --git a/src/relay/qnn/op/requantize.cc b/src/relay/qnn/op/requantize.cc
index e199ea27f1..91df4a287c 100644
--- a/src/relay/qnn/op/requantize.cc
+++ b/src/relay/qnn/op/requantize.cc
@@ -384,6 +384,9 @@ Expr RequantizeLower(const Expr& input_tensor, const Expr& input_scale,
const Expr& input_zero_point, const Expr& output_scale,
const Expr& output_zero_point, const RequantizeAttrs* param,
const Array<IndexExpr>& input_shape, const DataType& out_dtype) {
+ // Check output scale validity.
+ ICHECK_NE(GetScalarFromConstant<float>(output_scale), 0.0)
+ << "QNN requantize output scale can not be equal to 0.0";
// Check rounding validity.
ICHECK(param->rounding == "UPWARD" || param->rounding == "TONEAREST")
<< "QNN requantize supports two rounding modes - UPWARD and "
diff --git a/src/relay/qnn/op/subtract.cc b/src/relay/qnn/op/subtract.cc
index 1815019220..962a3434cb 100644
--- a/src/relay/qnn/op/subtract.cc
+++ b/src/relay/qnn/op/subtract.cc
@@ -97,7 +97,8 @@ Expr QnnSubtractCanonicalize(const Attrs& attrs, const Array<Expr>& new_args,
QNN_REGISTER_BINARY_OP("subtract")
.describe("Elementwise subtract with broadcasting for quantized tensors.")
.set_support_level(11)
- .set_attr<FTVMLegalize>("FTVMQnnCanonicalize", QnnSubtractCanonicalize);
+ .set_attr<FTVMLegalize>("FTVMQnnCanonicalize", QnnSubtractCanonicalize)
+ .set_attr<TOpPattern>("TOpPattern", kBroadcast);
} // namespace qnn
} // namespace relay
diff --git a/tests/python/contrib/test_hexagon/test_wo_qnn_canonicalization.py b/tests/python/contrib/test_hexagon/test_wo_qnn_canonicalization.py
index e4edf2919a..06e738d9b7 100644
--- a/tests/python/contrib/test_hexagon/test_wo_qnn_canonicalization.py
+++ b/tests/python/contrib/test_hexagon/test_wo_qnn_canonicalization.py
@@ -51,13 +51,33 @@ def test_no_qnn_pass():
assert "qnn.dequantize" in opt_mod_2.astext(show_meta_data=False)
-def execute(executor, data_np, weight_np, bias_np=None):
- executor.set_input("data", data_np)
- executor.set_input("weight", weight_np)
- if bias_np is not None:
- executor.set_input("bias", bias_np)
- executor.run()
- return executor.get_output(0)
+def execute(mod_executor, inputs: dict):
+ for input_name, input_data in inputs.items():
+ mod_executor.set_input(input_name, input_data)
+ mod_executor.run()
+ return mod_executor.get_output(0).numpy()
+
+
+def build_hexagon_module(mod):
+ with tvm.transform.PassContext(opt_level=3, disabled_pass=["qnn.Legalize"]):
+ hexagon_lowered = tvm.relay.build(
+ mod,
+ tvm.target.Target(HEXAGON_AOT_LLVM_TARGET, host=HEXAGON_AOT_LLVM_TARGET),
+ executor=Executor("aot"),
+ )
+
+ return hexagon_lowered
+
+
+def build_ref_module(mod):
+ target_llvm = tvm.target.Target("llvm")
+ with tvm.transform.PassContext(opt_level=3):
+ llvm_lowered = tvm.relay.build(
+ mod,
+ tvm.target.Target(target_llvm, host=target_llvm),
+ executor=Executor("aot"),
+ )
+ return llvm_lowered
@tvm.testing.requires_hexagon
@@ -90,33 +110,24 @@ def test_qnn_conv2d_rq(hexagon_session: Session):
)
relay_mod = tvm.IRModule.from_expr(op5)
- target_llvm = tvm.target.Target("llvm")
- executor = Executor("aot")
- with tvm.transform.PassContext(opt_level=3, disabled_pass=["qnn.Legalize"]):
- hexagon_lowered = tvm.relay.build(
- relay_mod,
- tvm.target.Target(HEXAGON_AOT_LLVM_TARGET, host=HEXAGON_AOT_LLVM_TARGET),
- executor=executor,
- )
+ # Compile for Hexagon
+ hexagon_lowered = build_hexagon_module(relay_mod)
- with tvm.transform.PassContext(opt_level=3):
- llvm_lowered = tvm.relay.build(
- relay_mod,
- tvm.target.Target(target_llvm, host=target_llvm),
- executor=executor,
- )
+ # Reference compilation
+ llvm_lowered = build_ref_module(relay_mod)
data_np = np.random.rand(*data_shape) - 0.5
weight_np = np.random.rand(*weight_shape) - 0.5
+ inputs = {"data": data_np, "weight": weight_np}
hx_m = hexagon_session.get_executor_from_factory(hexagon_lowered)
- hexagon_output = execute(hx_m, data_np, weight_np)
+ hexagon_output = execute(hx_m, inputs)
dev = tvm.cpu(0)
llvm_m = tvm.runtime.executor.AotModule(llvm_lowered["default"](dev))
- llvm_out = execute(llvm_m, data_np, weight_np)
+ llvm_out = execute(llvm_m, inputs)
- np.testing.assert_equal(hexagon_output.numpy(), llvm_out.numpy())
+ np.testing.assert_equal(hexagon_output, llvm_out)
@tvm.testing.requires_hexagon
@@ -152,34 +163,119 @@ def test_qnn_dense_bias_rq(hexagon_session: Session):
)
relay_mod = tvm.IRModule.from_expr(op5)
- target_llvm = tvm.target.Target("llvm")
- executor = Executor("aot")
- with tvm.transform.PassContext(opt_level=3, disabled_pass=["qnn.Legalize"]):
- hexagon_lowered = tvm.relay.build(
- relay_mod,
- tvm.target.Target(HEXAGON_AOT_LLVM_TARGET, host=HEXAGON_AOT_LLVM_TARGET),
- executor=executor,
- )
+ # Compile for Hexagon
+ hexagon_lowered = build_hexagon_module(relay_mod)
- with tvm.transform.PassContext(opt_level=3):
- llvm_lowered = tvm.relay.build(
- relay_mod,
- tvm.target.Target(target_llvm, host=target_llvm),
- executor=executor,
- )
+ # Reference compilation
+ llvm_lowered = build_ref_module(relay_mod)
data_np = np.random.rand(*data_shape) - 0.5
weight_np = np.random.rand(*weight_shape) - 0.5
bias_np = np.random.rand(*bias_shape)
+ inputs = {"data": data_np, "weight": weight_np, "bias": bias_np}
hx_m = hexagon_session.get_executor_from_factory(hexagon_lowered)
- hexagon_output = execute(hx_m, data_np, weight_np, bias_np)
+ hexagon_output = execute(hx_m, inputs)
dev = tvm.cpu(0)
llvm_m = tvm.runtime.executor.AotModule(llvm_lowered["default"](dev))
- llvm_out = execute(llvm_m, data_np, weight_np, bias_np)
+ llvm_out = execute(llvm_m, inputs)
+
+ np.testing.assert_equal(hexagon_output, llvm_out)
+
+
+class TestQnnBinaryOp:
+ """QNN binary op test class"""
+
+ operation = tvm.testing.parameter(
+ relay.qnn.op.add,
+ relay.qnn.op.subtract,
+ relay.qnn.op.mul,
+ )
+ dtype = tvm.testing.parameter("uint8", "int8")
+ input_shape = tvm.testing.parameter([256], [4, 256])
+
+ @tvm.testing.requires_hexagon
+ def test_qnn_binary_op_broadcasting(
+ self, hexagon_session: Session, operation, dtype, input_shape
+ ):
+ """qnn binary op test without QNN canonicalization."""
+ lhs_shape = [4, 256]
+ rhs_shape = input_shape
+ lhs = relay.var("lhs", shape=lhs_shape, dtype=dtype)
+ rhs = relay.var("rhs", shape=rhs_shape, dtype=dtype)
+ zp_const1 = 1
+ zp_const2 = 3
+
+ op = operation(
+ lhs,
+ rhs,
+ lhs_scale=relay.const(0.041, "float32"),
+ lhs_zero_point=relay.const(zp_const1, "int32"),
+ rhs_scale=relay.const(0.017, "float32"),
+ rhs_zero_point=relay.const(zp_const2, "int32"),
+ output_scale=relay.const(0.039, "float32"),
+ output_zero_point=relay.const(2, "int32"),
+ )
+ mod = tvm.IRModule.from_expr(op)
+
+ # Compile for Hexagon
+ hexagon_lowered = build_hexagon_module(mod)
+
+ # Reference compilation
+ llvm_lowered = build_ref_module(mod)
+
+ lhs_np = np.random.randint(np.iinfo(dtype).min + zp_const1, np.iinfo(dtype).max, lhs_shape)
+ rhs_np = np.random.randint(np.iinfo(dtype).min + zp_const2, np.iinfo(dtype).max, rhs_shape)
+ inputs = {"lhs": lhs_np, "rhs": rhs_np}
+
+ hx_m = hexagon_session.get_executor_from_factory(hexagon_lowered)
+ hexagon_output = execute(hx_m, inputs)
+
+ dev = tvm.cpu(0)
+ llvm_m = tvm.runtime.executor.AotModule(llvm_lowered["default"](dev))
+ llvm_output = execute(llvm_m, inputs)
+
+ # A difference of 1 is OK (rounding).
+ tvm.testing.assert_allclose(hexagon_output, llvm_output, atol=1)
+
+ @tvm.testing.requires_hexagon
+ def test_qnn_binary_op_scalar(self, hexagon_session: Session, operation):
+ """qnn binary op test without QNN canonicalization."""
+ lhs_shape = [4, 256]
+ lhs = relay.var("lhs", shape=lhs_shape, dtype="uint8")
+ rhs = relay.const(11, dtype="uint8")
+
+ op = operation(
+ lhs,
+ rhs,
+ lhs_scale=relay.const(0.049, "float32"),
+ lhs_zero_point=relay.const(1, "int32"),
+ rhs_scale=relay.const(0.067, "float32"),
+ rhs_zero_point=relay.const(3, "int32"),
+ output_scale=relay.const(0.041, "float32"),
+ output_zero_point=relay.const(2, "int32"),
+ )
+ mod = tvm.IRModule.from_expr(op)
+
+ # Compile for Hexagon
+ hexagon_lowered = build_hexagon_module(mod)
+
+ # Reference compilation
+ llvm_lowered = build_ref_module(mod)
+
+ lhs_np = np.random.randint(1, 255, size=lhs_shape)
+ inputs = {"lhs": lhs_np}
+
+ hx_m = hexagon_session.get_executor_from_factory(hexagon_lowered)
+ hexagon_output = execute(hx_m, inputs)
+
+ dev = tvm.cpu(0)
+ llvm_m = tvm.runtime.executor.AotModule(llvm_lowered["default"](dev))
+ llvm_output = execute(llvm_m, inputs)
- np.testing.assert_equal(hexagon_output.numpy(), llvm_out.numpy())
+ # A difference of 1 is OK (rounding).
+ tvm.testing.assert_allclose(hexagon_output, llvm_output, atol=1)
if __name__ == "__main__":