You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by as...@apache.org on 2022/11/17 09:49:11 UTC

[tvm] branch main updated: [ACL] Enable int8 data type in QNN ADD (#13407)

This is an automated email from the ASF dual-hosted git repository.

ashutoshp pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new c98f3cd6f8 [ACL] Enable int8 data type in QNN ADD (#13407)
c98f3cd6f8 is described below

commit c98f3cd6f8b0dcd8b6b07fecd5a60174ec13dc5b
Author: Leandro Nunes <le...@arm.com>
AuthorDate: Thu Nov 17 09:49:04 2022 +0000

    [ACL] Enable int8 data type in QNN ADD (#13407)
    
    This enables the int8 data type to be used for QNN add in the
    Compute Library for the Arm(r) Architecture (ACL) BYOC integration.
---
 python/tvm/relay/op/contrib/arm_compute_lib.py        | 2 +-
 src/relay/backend/contrib/arm_compute_lib/codegen.cc  | 2 +-
 src/runtime/contrib/arm_compute_lib/acl_utils.cc      | 2 ++
 tests/python/contrib/test_arm_compute_lib/test_add.py | 4 +++-
 4 files changed, 7 insertions(+), 3 deletions(-)

diff --git a/python/tvm/relay/op/contrib/arm_compute_lib.py b/python/tvm/relay/op/contrib/arm_compute_lib.py
index 9abd320b29..d63cd8c83a 100644
--- a/python/tvm/relay/op/contrib/arm_compute_lib.py
+++ b/python/tvm/relay/op/contrib/arm_compute_lib.py
@@ -511,7 +511,7 @@ def qnn_add(expr):
     """Check if the external ACL codegen for add should be used."""
     args = expr.args
     for typ in [args[0].checked_type, args[1].checked_type]:
-        if typ.dtype != "uint8":
+        if typ.dtype not in ["int8", "uint8"]:
             return False
 
     return True
diff --git a/src/relay/backend/contrib/arm_compute_lib/codegen.cc b/src/relay/backend/contrib/arm_compute_lib/codegen.cc
index 81a5b5bbd9..3f11e63c73 100644
--- a/src/relay/backend/contrib/arm_compute_lib/codegen.cc
+++ b/src/relay/backend/contrib/arm_compute_lib/codegen.cc
@@ -292,7 +292,7 @@ class ACLJSONSerializer : public backend::contrib::JSONSerializer {
   /*!
    * \brief Create a JSON representation of a composite (global) average pooling operator.
    *
-   * A composite function is only created when using the uint8 datatype for these operators.
+   * A composite function is only created when using the int8/uint8 datatype for these operators.
    *
    * \param cn The call to be represented.
    * \return A JSON representation of a specific operator.
diff --git a/src/runtime/contrib/arm_compute_lib/acl_utils.cc b/src/runtime/contrib/arm_compute_lib/acl_utils.cc
index 238b7355de..0f2dde5e36 100644
--- a/src/runtime/contrib/arm_compute_lib/acl_utils.cc
+++ b/src/runtime/contrib/arm_compute_lib/acl_utils.cc
@@ -130,6 +130,8 @@ arm_compute::DataType MakeACLDataType(const DLDataType& data_type) {
     return arm_compute::DataType::F32;
   } else if (data_type.code == DLDataTypeCode::kDLUInt && data_type.bits == 8) {
     return arm_compute::DataType::QASYMM8;
+  } else if (data_type.code == DLDataTypeCode::kDLInt && data_type.bits == 8) {
+    return arm_compute::DataType::QASYMM8_SIGNED;
   } else if (data_type.code == DLDataTypeCode::kDLInt && data_type.bits == 32) {
     return arm_compute::DataType::S32;
   } else {
diff --git a/tests/python/contrib/test_arm_compute_lib/test_add.py b/tests/python/contrib/test_arm_compute_lib/test_add.py
index ba324358f8..ee6fcf603c 100644
--- a/tests/python/contrib/test_arm_compute_lib/test_add.py
+++ b/tests/python/contrib/test_arm_compute_lib/test_add.py
@@ -92,7 +92,8 @@ def test_runtime_add():
 
     for dtype, low, high, atol, rtol, op, op_params in [
         ("float32", -127, 128, 1e-7, 1e-7, relay.add, {}),
-        ("uint8", 0, 255, 0.0, 1.0, relay.qnn.op.add, _qnn_params),
+        ("uint8", 0, 255, 1.0, 0.0, relay.qnn.op.add, _qnn_params),
+        ("int8", -127, 128, 1.0, 0.0, relay.qnn.op.add, _qnn_params),
     ]:
         shape = (2, 2)
         for inputs in [
@@ -125,6 +126,7 @@ def test_codegen_add():
     for dtype, op_name, op, qnn_params in [
         ("float32", "add", relay.add, {}),
         ("uint8", "qnn.add", relay.qnn.op.add, _qnn_params),
+        ("int8", "qnn.add", relay.qnn.op.add, _qnn_params),
     ]:
         for shape in [(1, 1), (2, 2, 2), (3, 3, 3, 3)]:
             func = _get_model(shape, dtype, iter(inputs), op, qnn_params)