Posted to commits@tvm.apache.org by cs...@apache.org on 2023/03/27 20:57:14 UTC

[tvm] branch unity updated: [Unity][QNN][Hexagon]Support Relax Constants in the QNN TOPI operations (#14386)

This is an automated email from the ASF dual-hosted git repository.

csullivan pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new d97c43b724 [Unity][QNN][Hexagon]Support Relax Constants in the QNN TOPI operations (#14386)
d97c43b724 is described below

commit d97c43b724062129b0f8e8ab320f6085e4877196
Author: Farshid Salemi Parizi <fp...@octoml.ai>
AuthorDate: Mon Mar 27 13:57:04 2023 -0700

    [Unity][QNN][Hexagon]Support Relax Constants in the QNN TOPI operations (#14386)
    
    * Support Relax Constants in the QNN TOPI operations
---
 python/tvm/topi/hexagon/qnn/nn.py                  | 45 +++++++++++---
 .../test_hexagon/test_wo_qnn_canonicalization.py   | 70 ++++++++++++++++++++++
 2 files changed, 106 insertions(+), 9 deletions(-)
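
For context on what the patch distinguishes: when the QNN TOPI computes are reached from Relax (as on the unity branch), scalar quantization parameters can arrive as 0-d te.Tensor arguments whose op carries a relax.Constant, rather than as TE computes whose body is a tir.FloatImm/IntImm. A minimal sketch of the two forms, assuming a unity build of TVM (the Relax-backed form is normally produced by the lowering pipeline rather than built by hand):

import tvm
from tvm import relax, te

# TE form: a 0-d compute whose body is a TIR immediate; this is what the
# pre-existing is_scalar()/get_const_*_value() helpers look for.
te_scale = te.compute((), lambda: tvm.tir.const(0.5, "float32"), name="scale")
assert te_scale.ndim == 0
assert isinstance(te_scale.op.body[0], tvm.tir.FloatImm)

# Relax form: a 0-d relax.Constant; the patch teaches the same helpers to
# recognize it when it is attached to a tensor's op as `op.value`.
relax_scale = relax.const(0.5, "float32")
assert isinstance(relax_scale, relax.expr.Constant)
assert len(relax_scale.data.shape) == 0 and "float" in relax_scale.data.dtype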

diff --git a/python/tvm/topi/hexagon/qnn/nn.py b/python/tvm/topi/hexagon/qnn/nn.py
index e60314b827..1a707cef7e 100644
--- a/python/tvm/topi/hexagon/qnn/nn.py
+++ b/python/tvm/topi/hexagon/qnn/nn.py
@@ -38,24 +38,49 @@ def clip_cast(val, dtype):
     return te.max(tvm.te.min(val, const_max), const_min).astype(dtype)
 
 
+def is_relax_constant(expr):
+    return hasattr(expr.op, "value") and isinstance(expr.op.value, tvm.relax.expr.Constant)
+
+
 # Return True if given expression is scalar constant value.
 def is_scalar(expr):
+    """
+    Return True if given expression is scalar constant value.
+    """
     if isinstance(expr, te.Tensor):
-        return expr.ndim == 0 and (isinstance(expr.op.body[0], (tvm.tir.FloatImm, tvm.tir.IntImm)))
+        if is_relax_constant(expr):
+            shape = expr.op.value.data.shape
+            dtype = expr.op.value.data.dtype
+            return len(shape) == 0 and ("float" in dtype or "int" in dtype)
+        else:
+            return expr.ndim == 0 and (
+                isinstance(expr.op.body[0], (tvm.tir.FloatImm, tvm.tir.IntImm))
+            )
     return isinstance(expr, (tvm.tir.FloatImm, tvm.tir.IntImm))
 
 
+def get_relax_scalar_const_value(expr):
+    assert len(expr.op.value.data.shape) == 0
+    return expr.op.value.data.numpy()[()]
+
+
 def get_const_int_value(expr):
     if isinstance(expr, te.Tensor):
-        assert isinstance(expr.op.body[0], tvm.tir.IntImm)
-        return expr.op.body[0].value
+        if is_relax_constant(expr):
+            return get_relax_scalar_const_value(expr)
+        else:
+            assert isinstance(expr.op.body[0], tvm.tir.IntImm)
+            return expr.op.body[0].value
     return get_const_int(expr)
 
 
 def get_const_float_value(expr):
     if isinstance(expr, te.Tensor):
-        assert isinstance(expr.op.body[0], tvm.tir.FloatImm)
-        return expr.op.body[0].value
+        if is_relax_constant(expr):
+            return get_relax_scalar_const_value(expr)
+        else:
+            assert isinstance(expr.op.body[0], tvm.tir.FloatImm)
+            return expr.op.body[0].value
     return get_const_float(expr)
 
 
@@ -224,7 +249,7 @@ def qnn_requantize(
             # Add output zero point + clip + cast:
             return saturate(te.add(mul, output_zp), out_dtype).astype(out_dtype)
 
-        return te.compute(data.shape, _compute, name="requantize")
+        return te.compute(data.shape, _compute, name="requantize_scalar")
 
     else:
 
@@ -285,8 +310,8 @@ def compute_qnn_binary_op(
 
     def _compute_tensor(x: te.Tensor, input_scale, input_zp):
         if is_scalar(input_scale) and is_scalar(output_scale):
-            iscale = input_scale.op.body[0].value
-            oscale = output_scale.op.body[0].value
+            iscale = get_const_float_value(input_scale)
+            oscale = get_const_float_value(output_scale)
             scale = iscale / oscale
             scale_fixed_point, rsh = get_fixed_point_value(scale, "int16")
             return te.compute(
@@ -406,7 +431,9 @@ def qnn_mul(
     if is_scalar(lhs_scale) and is_scalar(rhs_scale):
         assert isinstance(lhs_scale, te.Tensor)
         assert isinstance(rhs_scale, te.Tensor)
-        iscale = lhs_scale.op.body[0] * rhs_scale.op.body[0]
+        iscale = get_const_float_value(lhs_scale.op.body[0]) * get_const_float_value(
+            rhs_scale.op.body[0]
+        )
     else:
         iscale = lhs_scale * rhs_scale
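
The rename of the scalar requantize compute to "requantize_scalar" above is what the new test below greps for in the lowered Relax text. A rough sketch of exercising that scalar fast path directly; the qnn_requantize argument order is assumed to mirror the Relay qnn.requantize op, so check the signature in python/tvm/topi/hexagon/qnn/nn.py before relying on it:

import tvm
from tvm import te
from tvm.topi.hexagon.qnn import nn as hexagon_qnn

def scalar_const(value, dtype, name):
    # 0-d TE tensor holding a TIR immediate, the form is_scalar() accepts.
    return te.compute((), lambda: tvm.tir.const(value, dtype), name=name)

data = te.placeholder((1, 16), dtype="int32", name="data")
out = hexagon_qnn.qnn_requantize(
    data,
    scalar_const(0.25, "float32", "in_scale"),
    scalar_const(0, "int32", "in_zp"),
    scalar_const(0.5, "float32", "out_scale"),
    scalar_const(0, "int32", "out_zp"),
    out_dtype="uint8",
)
# Both scales are scalar, so the fixed-point branch is taken and the compute
# op carries the new name.
assert out.op.name == "requantize_scalar"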
 
diff --git a/tests/python/contrib/test_hexagon/test_wo_qnn_canonicalization.py b/tests/python/contrib/test_hexagon/test_wo_qnn_canonicalization.py
index f4342f5814..1c68d084f7 100644
--- a/tests/python/contrib/test_hexagon/test_wo_qnn_canonicalization.py
+++ b/tests/python/contrib/test_hexagon/test_wo_qnn_canonicalization.py
@@ -24,6 +24,8 @@ from tvm.contrib.hexagon.session import Session
 from tvm.contrib.hexagon.pytest_plugin import HEXAGON_AOT_LLVM_TARGET
 from tvm.relay.backend import Executor
 from tvm.relay.testing import run_opt_pass, run_infer_type
+from tvm.relax.testing import relay_translator
+from .infrastructure import get_hexagon_target
 
 
 @tvm.testing.requires_hexagon
@@ -471,5 +473,73 @@ class TestQnnOp:
         np.testing.assert_equal(hexagon_output, llvm_output)
 
 
+def test_qnn_conv2d_is_scalar_relax():
+    """Test to check if the input scale and output scale is constant,
+    qnn.requantize will compute with fixed_point_value."""
+
+    data_shape = (1, 64, 56, 56)
+    kernel_shape = (128, 64, 3, 3)
+
+    data_dtype = "uint8"
+    in_data = relay.var("data", shape=data_shape, dtype=data_dtype)
+
+    kernel_dtype = "int8"
+    kernel = relay.var("kernel", shape=kernel_shape, dtype=kernel_dtype)
+    azp = np.array([0]).astype("int32")
+    wzp = np.array([0]).astype("int32")  # assumed zero
+    bias = (np.zeros((1, 512, 1, 1), dtype="uint32") * -12).astype("int32")
+    rqsci = np.array([1]).astype("float32")
+    rqzpi = np.array([0]).astype("int32")
+    rqsco = np.array([1]).astype("float32")
+    rqzpo = np.array([0]).astype("int32")
+    strides = (1, 1)
+
+    input_zero_point = relay.const(azp[0], dtype="int32")
+    kernel_zero_point = relay.const(wzp[0], dtype="int32")
+
+    input_scale = relay.const(1.0, dtype="float32")
+    kernel_scale = relay.const(1.0, dtype="float32")
+
+    conv_op = relay.qnn.op.conv2d(
+        in_data,
+        kernel,
+        input_zero_point=input_zero_point,
+        kernel_zero_point=kernel_zero_point,
+        input_scale=input_scale,
+        kernel_scale=kernel_scale,
+        kernel_size=(kernel_shape[2], kernel_shape[3]),
+        channels=kernel_shape[0],
+        strides=strides,
+        data_layout="NCHW",
+        kernel_layout="OIHW",
+        out_dtype="int32",
+    )
+
+    bias = relay.var("bias", shape=(kernel_shape[0],), dtype="int32")
+    bias_op = relay.nn.bias_add(conv_op, bias, axis=1)
+
+    requant_op = relay.qnn.op.requantize(
+        bias_op,
+        input_scale=relay.const(rqsci[0]),
+        input_zero_point=relay.const(rqzpi[0]),
+        output_scale=relay.const(rqsco[0]),
+        output_zero_point=relay.const(rqzpo[0]),
+        out_dtype="int32",
+    )
+
+    clip_op = relay.op.clip(requant_op, 0.0, 255.0)
+    cast_op = relay.op.cast(clip_op, "uint8")
+
+    func = relay.Function([in_data, kernel, bias], cast_op)
+
+    mod = tvm.IRModule.from_expr(func)
+    target_hexagon = get_hexagon_target("v69")
+    relax_mod = relay_translator.from_relay(
+        mod["main"], target_hexagon, disabled_pass=["qnn.Legalize"]
+    )
+
+    assert "requantize_scalar" in relax_mod.astext(show_meta_data=False)
+
+
 if __name__ == "__main__":
     tvm.testing.main()
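
The new test builds a Relay qnn.conv2d + bias + requantize graph with constant scales, translates it to Relax for Hexagon with qnn.Legalize disabled so the QNN ops reach the Hexagon TOPI computes, and asserts that the lowered text contains the "requantize_scalar" compute introduced above. A stripped-down sketch of the same check on a smaller graph; the bare qnn.requantize graph and the plain v69 target below are illustrative substitutes for the conv2d graph and get_hexagon_target() used in the test:

import tvm
from tvm import relay
from tvm.relax.testing import relay_translator

x = relay.var("x", shape=(1, 8), dtype="int32")
y = relay.qnn.op.requantize(
    x,
    input_scale=relay.const(0.25, "float32"),
    input_zero_point=relay.const(0, "int32"),
    output_scale=relay.const(0.5, "float32"),
    output_zero_point=relay.const(0, "int32"),
    out_dtype="uint8",
)
mod = tvm.IRModule.from_expr(relay.Function([x], y))

# Keep the QNN ops intact (no canonicalization) so the Hexagon QNN computes are used.
relax_mod = relay_translator.from_relay(
    mod["main"], tvm.target.hexagon("v69"), disabled_pass=["qnn.Legalize"]
)
assert "requantize_scalar" in relax_mod.astext(show_meta_data=False)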