You are viewing a plain text version of this content. The canonical link for it was not preserved in this plain-text copy.
Posted to commits@tvm.apache.org by me...@apache.org on 2022/05/11 22:26:53 UTC

[tvm] branch main updated: [ARM][Strategy] Fix is_int8_hw_support check function (#11193)

This is an automated email from the ASF dual-hosted git repository.

mehrdadh pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 4eb6497adb [ARM][Strategy] Fix is_int8_hw_support check function (#11193)
4eb6497adb is described below

commit 4eb6497adba48f72a59837fe42ae3f56cba3fe35
Author: Mehrdad Hessar <mh...@octoml.ai>
AuthorDate: Wed May 11 15:26:48 2022 -0700

    [ARM][Strategy] Fix is_int8_hw_support check function (#11193)
    
    * Fix hw schedule condition
    
    * add warning messages to unoptimized schedules
---
 python/tvm/relay/op/strategy/arm_cpu.py |  5 +++
 python/tvm/topi/arm_cpu/conv2d_int8.py  |  2 +-
 tests/python/target/test_arm_target.py  | 56 +++++++++++++++++++++++++++++++++
 3 files changed, 62 insertions(+), 1 deletion(-)

diff --git a/python/tvm/relay/op/strategy/arm_cpu.py b/python/tvm/relay/op/strategy/arm_cpu.py
index d1f2b90706..6ccb449d0e 100644
--- a/python/tvm/relay/op/strategy/arm_cpu.py
+++ b/python/tvm/relay/op/strategy/arm_cpu.py
@@ -66,6 +66,7 @@ def schedule_pool_arm_cpu(attrs, outs, target):
             and layout in ("NWC", "NHWC")
         ):
             return topi.arm_cpu.schedule_pool(outs, layout)
+        logger.warning("pool is not optimized for arm cpu.")
         return topi.generic.schedule_pool(outs, layout)
 
 
@@ -236,6 +237,7 @@ def conv2d_strategy_arm_cpu(attrs, inputs, out_type, target):
                     name="depthwise_conv2d_nhwc.arm_cpu",
                 )
             else:
+                logger.warning("depthwise_conv2d with layout NHWC is not optimized for arm cpu.")
                 strategy.add_implementation(
                     wrap_compute_conv2d(topi.nn.depthwise_conv2d_nhwc),
                     wrap_topi_schedule(conv2d_generic.schedule_depthwise_conv2d_nhwc),
@@ -472,6 +474,7 @@ def schedule_dense_arm_cpu(attrs, inputs, out_type, target):
             name="dense_dsp",
         )
     else:
+        logger.warning("dense is not optimized for arm cpu.")
         strategy.add_implementation(
             wrap_compute_dense(
                 topi.nn.dense, need_auto_scheduler_layout=is_auto_scheduler_enabled()
@@ -508,12 +511,14 @@ def conv1d_strategy_arm_cpu(attrs, inputs, out_type, target):
                 )
             )
     elif layout == "NCW":
+        logger.warning("conv1d with layout %s is not optimized for arm cpu.", layout)
         strategy.add_implementation(
             wrap_compute_conv1d(topi.nn.conv1d_ncw),
             wrap_topi_schedule(topi.generic.schedule_conv1d_ncw),
             name="conv1d_ncw.generic",
         )
     elif layout == "NWC":
+        logger.warning("conv1d with layout %s is not optimized for arm cpu.", layout)
         strategy.add_implementation(
             wrap_compute_conv1d(topi.nn.conv1d_nwc),
             wrap_topi_schedule(topi.generic.schedule_conv1d_nwc),
diff --git a/python/tvm/topi/arm_cpu/conv2d_int8.py b/python/tvm/topi/arm_cpu/conv2d_int8.py
index b6ab89de8b..224d21b34d 100644
--- a/python/tvm/topi/arm_cpu/conv2d_int8.py
+++ b/python/tvm/topi/arm_cpu/conv2d_int8.py
@@ -126,7 +126,7 @@ def is_int8_hw_support(data_dtype, kernel_dtype):
     # 3) Check target
     is_target_support = is_neon_available() or is_dotprod_available()
 
-    return is_dtype_support and is_llvm_support
+    return is_dtype_support and is_llvm_support and is_target_support
 
 
 @autotvm.register_topi_schedule("conv2d_NCHWc_int8.arm_cpu")
diff --git a/tests/python/target/test_arm_target.py b/tests/python/target/test_arm_target.py
new file mode 100644
index 0000000000..9106c169c8
--- /dev/null
+++ b/tests/python/target/test_arm_target.py
@@ -0,0 +1,56 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import pytest
+
+import tvm
+from tvm.topi.arm_cpu.conv2d_int8 import is_int8_hw_support
+from tvm.target import codegen
+
+arm_target, input_dtype, kernel_dtype, is_supported = tvm.testing.parameters(
+    # Testing mcpu type
+    ("c -mcpu=cortex-m4 -keys=arm_cpu", "int8", "int8", False),
+    ("c -mcpu=cortex-m7 -keys=arm_cpu", "int8", "int8", False),
+    ("c -mcpu=cortex-m33 -keys=arm_cpu", "int8", "int8", False),
+    ("c -mcpu=cortex-m55 -keys=arm_cpu", "int8", "int8", False),
+    ("c -mcpu=cortex-m3 -keys=arm_cpu", "int8", "int8", False),
+    ("llvm -keys=arm_cpu -mattr=+neon", "int8", "int8", True),
+    # This fails because of a bug in topi.arm_cpu.arm_utils.get_arch_version
+    # ("llvm -keys=arm_cpu -mattr=v8.4a,+dotprod", "int8", "int8", True),
+    # Testing dtype
+    ("llvm -keys=arm_cpu -mattr=+neon", "int16", "int8", False),
+    ("llvm -keys=arm_cpu -mattr=+neon", "int8", "int16", False),
+    ("llvm -keys=arm_cpu -mattr=+neon", "int16", "int16", False),
+)
+
+
+def test_arm_conv2d_int8_support(arm_target, input_dtype, kernel_dtype, is_supported):
+    """Test ARM conv2d int8 support for different targets.
+
+    Parameters
+    ----------
+    arm_target : str
+        ARM CPU target.
+    input_dtype : str
+        Conv2d input data type.
+    kernel_dtype : Session
+        Conv2d kernel data type.
+    is_supported : bool
+        Expected result.
+    """
+    with tvm.target.Target(arm_target):
+        expected_result = is_supported and (codegen.llvm_version_major() >= 8)
+        assert is_int8_hw_support(input_dtype, kernel_dtype) == expected_result