Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2021/10/08 19:48:57 UTC

[GitHub] [tvm] sergey-grovety opened a new pull request #9233: Cortex m7 intrinsic

sergey-grovety opened a new pull request #9233:
URL: https://github.com/apache/tvm/pull/9233


   # TVM operator implementations using Cortex-M7 SIMD instructions
   
   ### nn.conv2d
   - We added an implementation of the gemm function for 16-bit input
   - The restriction that shape[-1] be a multiple of 4 is lifted
   - The 8-bit intrinsic needs a data preparation step, which costs too much time for small tensors, so we added a check and a simple loop to handle that case
   - As an optimization, calculations were moved from inside the intrinsic to outside it
   - One of the buffers, which was not in use, was cut, reducing memory requirements by roughly half
   
   ### nn.max_pool2d
   - Implemented with the __SSUB8 and __SEL intrinsics, which process four 8-bit input values at a time and give a notable speedup
   - Feature: the implementation handles input data that is not word-aligned
   - Feature: handles data sizes that are not a multiple of 4
   - memset is used to initialize the minimum values, for maximum speed
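
As a rough illustration of the max-pooling step, the sketch below models in plain Python what `__SSUB8` followed by `__SEL` does to four packed signed 8-bit lanes. The helper names (`pack_i8`, `unpack_i8`, `ssub8`, `sel`, `max4_i8`, `max_pool_words`) are hypothetical and not taken from the PR, and the `0x80` fill byte is an assumption about what the memset writes; the real kernel is C and uses the intrinsics directly.

```python
import struct

# Assumption: memset with byte 0x80 leaves every int8 lane at -128 (INT8_MIN).
INT8_MIN_WORD = 0x80808080


def unpack_i8(word):
    """Split a 32-bit word into four signed 8-bit lanes (little-endian)."""
    return struct.unpack("<4b", struct.pack("<I", word & 0xFFFFFFFF))


def pack_i8(lanes):
    """Pack four signed 8-bit values into one 32-bit word."""
    return struct.unpack("<I", struct.pack("<4b", *lanes))[0]


def ssub8(a, b):
    """Model of __SSUB8: per-lane a - b, plus one GE flag per lane."""
    diffs = [x - y for x, y in zip(unpack_i8(a), unpack_i8(b))]
    ge = [d >= 0 for d in diffs]  # GE[i] is set when lane i of a >= lane i of b
    result = 0
    for i, d in enumerate(diffs):
        result |= (d & 0xFF) << (8 * i)  # each lane is truncated to 8 bits
    return result, ge


def sel(a, b, ge):
    """Model of __SEL: take the lane from a where GE is set, else from b."""
    return pack_i8([x if g else y for x, y, g in zip(unpack_i8(a), unpack_i8(b), ge)])


def max4_i8(a, b):
    """Lane-wise maximum of two packed words, as SSUB8 + SEL would compute it."""
    _, ge = ssub8(a, b)
    return sel(a, b, ge)


def max_pool_words(words):
    """Reduce a window of packed words to its lane-wise maximum."""
    acc = INT8_MIN_WORD
    for w in words:
        acc = max4_i8(acc, w)
    return acc
```

Calling `unpack_i8(max_pool_words([...]))` on a window of packed words returns the four per-lane maxima; the subtraction result itself is discarded, only the GE flags matter.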
   
   ### nn.avg_pool2d
   - Because there is no intrinsic that sums four 8-bit values, the implementation is only possible for 16-bit data
   - The __SMLAD intrinsic is used to process two 16-bit values at a time
   - Feature: the implementation handles input data that is not word-aligned
   - Feature: handles data sizes that are not a multiple of 4
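
To make the role of `__SMLAD` concrete, here is a small Python model of the instruction: it multiplies the two signed 16-bit halves of its operands pairwise and adds both products to a 32-bit accumulator. Multiplying by a word of packed ones, as done below, is an assumption about how a sum could be formed for averaging, not a quote from the kernel; the names `smlad`, `pack_i16_pair`, `ONES` and `sum_i16` are hypothetical.

```python
import struct


def pack_i16_pair(lo, hi):
    """Pack two signed 16-bit values into one 32-bit word (little-endian)."""
    return struct.unpack("<I", struct.pack("<2h", lo, hi))[0]


def smlad(x, y, acc):
    """Model of __SMLAD: acc + x.lo * y.lo + x.hi * y.hi, wrapped to int32."""
    x_lo, x_hi = struct.unpack("<2h", struct.pack("<I", x & 0xFFFFFFFF))
    y_lo, y_hi = struct.unpack("<2h", struct.pack("<I", y & 0xFFFFFFFF))
    total = (acc + x_lo * y_lo + x_hi * y_hi) & 0xFFFFFFFF
    return total - (1 << 32) if total & 0x80000000 else total


# Assumption: a multiplier of (1, 1) turns the dual multiply-accumulate into a sum.
ONES = pack_i16_pair(1, 1)


def sum_i16(values):
    """Sum int16 values two at a time; a trailing odd element is added on its own."""
    acc = 0
    for i in range(0, len(values) - 1, 2):
        acc = smlad(pack_i16_pair(values[i], values[i + 1]), ONES, acc)
    if len(values) % 2:
        acc += values[-1]  # scalar tail for sizes that are not a multiple of 2
    return acc
```

The average then follows by dividing the accumulated sum by the number of elements in the pooling window.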
   
   ### nn.dense
   Implemented with the same gemm method described above
   
   ### nn.conv1d
   A special case of gemm usage, with one of the data dimensions equal to 1
   
   ### nn.avg_pool1d
   Implemented for NCW layout with the same intrinsic as the 2d version of the operation
   
   ### nn.max_pool1d
   Implemented with the same intrinsic as the 2d version of the operation
   
   
   # Benchmarking:
   HW platform: STM32F746 Nucleo; GCC 10, optimization flags: -O3
   To enable intrinsic code generation, specify the -march=armv7e-m parameter of the target:
   `target_str = f"c -keys=arm_cpu -mcpu=cortex-m7 -march=armv7e-m -model=stm32f746xx -runtime=c -link-params=1 --executor=aot --unpacked-api=1 --interface-api=c"`
   
   ## Results
   
   Model | No intrinsic (ms) | Intrinsic enabled (ms)
   -- | -- | --
   mnist8 | 8.625 | 6.574
   cifar10 | 788.36 | 144.59
   
   
   - No intrinsic: march parameter not specified, no intrinsic code generated
   - Intrinsic enabled: _march=armv7e-m_
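
For readers who want the speedup factors rather than raw times, the short calculation below derives them from the numbers in the results table. Nothing here is measured anew; the dictionary simply restates the table.

```python
# (no intrinsic, intrinsic enabled) inference times in milliseconds, from the table
results_ms = {
    "mnist8": (8.625, 6.574),
    "cifar10": (788.36, 144.59),
}

# Speedup = baseline time / time with intrinsics enabled
speedups = {model: base / intrin for model, (base, intrin) in results_ms.items()}

for model, factor in speedups.items():
    print(f"{model}: {factor:.2f}x")
# prints:
# mnist8: 1.31x
# cifar10: 5.45x
```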


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] areusch commented on pull request #9233: [Topi] Cortex-M DSP support

Posted by GitBox <gi...@apache.org>.
areusch commented on pull request #9233:
URL: https://github.com/apache/tvm/pull/9233#issuecomment-964651141


   looks like we are busted at head, https://github.com/apache/tvm/pull/9480 is the fix





[GitHub] [tvm] areusch commented on a change in pull request #9233: Cortex m7 intrinsic

Posted by GitBox <gi...@apache.org>.
areusch commented on a change in pull request #9233:
URL: https://github.com/apache/tvm/pull/9233#discussion_r737943687



##########
File path: python/tvm/topi/arm_cpu/cortex_m7/conv1d/direct_simd.py
##########
@@ -0,0 +1,177 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name, no-value-for-parameter
+"""Direct implementation of conv1d."""
+from tvm import autotvm
+from tvm.autotvm.task import deserialize_args
+from tvm import te
+from tvm.topi.utils import simplify, traverse_inline
+from tvm.topi.nn.pad import pad
+from tvm.topi.nn.utils import get_pad_tuple1d
+from tvm.tir.expr import Mul
+
+from ..micro_kernel.gemm import (
+    intrin_gemm_MxKxN,
+    gemm_MxKxN_impl,
+)
+
+
+def conv1d_nwc_direct_simd(*args, **kwargs):
+    """Defines the Cortex-M7 SIMD implementation of conv1d on NWC layout."""

Review comment:
       yeah i'm up for moving away from isa_analyzer. that was just an initial stab, but i agree that modeling this at the level of architecture rather than instruction makes more sense. however, i think it would be good to do that in a follow-on PR. this PR could then move forward isolated to what was tested already (on STM32F746 nucleo, I believe), and a follow-on could expand support to the broader architecture. what do you think of this?







[GitHub] [tvm] ilyag-grovety commented on a change in pull request #9233: Cortex m7 intrinsic

Posted by GitBox <gi...@apache.org>.
ilyag-grovety commented on a change in pull request #9233:
URL: https://github.com/apache/tvm/pull/9233#discussion_r738094680



##########
File path: python/tvm/topi/arm_cpu/conv1d.py
##########
@@ -0,0 +1,36 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name, unused-variable, no-else-return, unused-argument, import-outside-toplevel
+"""Conv1D schedule for ARM CPU"""
+from __future__ import absolute_import as _abs
+
+from tvm import autotvm
+
+from .cortex_m7.conv1d import direct_simd as direct_simd_conv1d
+
+
+@autotvm.register_topi_compute("conv1d_nwc_direct_simd.arm_cpu")
+def conv1d_nwc_direct_simd(cfg, data, kernel, strides, padding, dilation, out_dtype):

Review comment:
       OK, we will change every "direct_simd" in the naming to "dsp". Should the "direct_simd.py" files also be renamed to "dsp.py"?







[GitHub] [tvm] areusch commented on a change in pull request #9233: Cortex m7 intrinsic

Posted by GitBox <gi...@apache.org>.
areusch commented on a change in pull request #9233:
URL: https://github.com/apache/tvm/pull/9233#discussion_r742417398



##########
File path: tests/python/conftest.py
##########
@@ -40,3 +41,26 @@
 
 if tvm.support.libinfo().get("USE_MICRO", "OFF") != "ON":
     collect_ignore.append("unittest/test_micro_transport.py")
+
+
+def pytest_addoption(parser):
+    parser.addoption(
+        "--enable-corstone300-tests",
+        action="store_true",
+        default=False,
+        help="Run Corstone-300 FVP tests",
+    )
+
+
+def pytest_collection_modifyitems(config, items):
+    for item in items:
+        if config.getoption("--enable-corstone300-tests"):
+            if not "corstone300" in item.keywords:
+                item.add_marker(
+                    pytest.mark.skip(reason="Test should be marked 'corstone300' to run")

Review comment:
       i think we just need one skip, right? doesn't this skip all other tests aside from corstone300?
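
The concern can be sketched with two plain predicates, written without pytest so they run on their own. `diff_branch_skips` mirrors only the branch visible in the diff above; `single_skip` is one possible reading of the "one skip" suggestion and is an assumption, not code from the PR. Both function names are hypothetical.

```python
def diff_branch_skips(test_keywords, corstone300_enabled):
    """Branch shown in the diff: with the flag set, unmarked tests are skipped."""
    return corstone300_enabled and "corstone300" not in test_keywords


def single_skip(test_keywords, corstone300_enabled):
    """Assumed alternative: only skip corstone300 tests when the flag is absent."""
    return "corstone300" in test_keywords and not corstone300_enabled


# An ordinary unit test, run with --enable-corstone300-tests:
ordinary = {"unit"}
print(diff_branch_skips(ordinary, True))  # True: the ordinary test is skipped
print(single_skip(ordinary, True))        # False: the ordinary test still runs
```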

##########
File path: python/tvm/target/arm_isa.py
##########
@@ -16,18 +16,24 @@
 # under the License.
 """Defines functions to analyze available opcodes in the ARM ISA."""
 
+import argparse
 
 ARM_ISA_MAP = {
-    "armv7e-m": ["SMLAD"],
+    "armv7e-m": ["SMLAD", "SSUB8", "SEL"],
+    "armv8-m": ["SMLAD", "SSUB8", "SEL"],
 }
 
 
 class IsaAnalyzer(object):
+    """Checks ISA support for given target"""
+
     def __init__(self, target):
         self.target = target
-        # TODO: actually parse -mcpu
-        arch = "armv7e-m"
-        self._isa_map = ARM_ISA_MAP[arch]
+        parser = argparse.ArgumentParser()

Review comment:
       you should use the built-in Target parsing logic here rather than argparse:
   ```suggestion
           target = tvm.target.Target(target)
           march = target.attrs.get("-march", None)
           self._isa_map = ARM_ISA_MAP[march] if march is not None else []
   ```
   (also need to delete the following lines 33-36--suggestion didn't quite get the diff)
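
A self-contained sketch of the lookup this suggestion describes is shown below. It deliberately uses a plain dict as a stand-in for the parsed target attributes instead of `tvm.target.Target`, so the class name `IsaAnalyzerSketch` and the `target_attrs` parameter are hypothetical. It also uses `dict.get` for the map lookup, which is a small deviation from the suggestion: an unknown architecture then yields an empty instruction list instead of a `KeyError`.

```python
ARM_ISA_MAP = {
    "armv7e-m": ["SMLAD", "SSUB8", "SEL"],
}


class IsaAnalyzerSketch:
    """Checks ISA support for a given target (dict stand-in for Target attrs)."""

    def __init__(self, target_attrs):
        # Key name follows the suggestion above; the real attribute key may differ.
        march = target_attrs.get("-march", None)
        # Missing or unknown architectures map to "no special instructions".
        self._isa_map = ARM_ISA_MAP.get(march, [])

    def __contains__(self, instruction):
        return instruction in self._isa_map


isa = IsaAnalyzerSketch({"-march": "armv7e-m"})
print("SMLAD" in isa)                    # True
print("SMLAD" in IsaAnalyzerSketch({}))  # False
```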

##########
File path: python/tvm/topi/arm_cpu/dense.py
##########
@@ -0,0 +1,25 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name, unused-variable, no-else-return, unused-argument, import-outside-toplevel
+"""Dense schedule for ARM CPU"""
+
+from .cortex_m7.dense import direct_simd

Review comment:
       i think it makes sense then to not import the cortex_m7 direct_simd into this module. can we reorganize as discussed in the earlier thread?

##########
File path: python/tvm/target/arm_isa.py
##########
@@ -16,18 +16,24 @@
 # under the License.
 """Defines functions to analyze available opcodes in the ARM ISA."""
 
+import argparse
 
 ARM_ISA_MAP = {
-    "armv7e-m": ["SMLAD"],
+    "armv7e-m": ["SMLAD", "SSUB8", "SEL"],

Review comment:
       @u99127 as discussed, let's punt the architecture labelling to the next PR.

##########
File path: python/tvm/relay/op/strategy/arm_cpu.py
##########
@@ -49,6 +49,26 @@ def schedule_concatenate_arm_cpu(_, outs, target):
         return topi.arm_cpu.schedule_concatenate(outs)
 
 
+@schedule_pool.register(["arm_cpu", "micro_dev"])
+def schedule_pool_arm_cpu(attrs, outs, target):
+    """schedule pooling ops arm cpu"""
+    layout = attrs.layout
+    isa = arm_isa.IsaAnalyzer(target)
+    avg_pool = isinstance(attrs, relay.op.op_attrs.AvgPool2DAttrs)
+    with target:
+        if (
+            avg_pool
+            and layout in ("NCW", "NCHW")
+            and "SMLAD" in isa

Review comment:
       i agree with you that we should refactor this. this was left over from the initial implementation which did propose to test for presence of instructions in the ISA; however, you're right that we should just need to determine which architecture is in use. since this PR just adds additional schedules which are purported to be compatible with cortex-m7 devices, perhaps we can address the question of lookup-by-architecture in a follow-on.

##########
File path: tests/python/integration/test_m7_simd.py
##########
@@ -0,0 +1,355 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import sys
+import numpy as np
+import pytest
+import tvm
+from tvm import relay
+from tests.python.relay.aot.aot_test_utils import (
+    AOTTestModel,
+    AOT_CORSTONE300_RUNNER,
+    generate_ref_data,
+    compile_and_run,
+)
+
+
+@tvm.testing.requires_corstone300
+@pytest.mark.parametrize(
+    "data_shape_nhwc, kernel_size, num_filter, strides, padding, dilation",
+    [
+        ((1, 32, 32, 1), (3, 3), 12, 1, 0, 1),
+        ((1, 32, 10, 3), (3, 3), 16, 1, 0, 1),
+        ((1, 49, 10, 1), (10, 4), 64, (2, 1), (4, 1, 5, 1), 1),
+        ((1, 32, 32, 16), (3, 3), 16, 1, (0, 2, 2, 0), 1),
+        ((1, 32, 32, 16), (3, 3), 16, 1, 0, 1),
+        ((1, 32, 32, 16), (3, 3), 16, 1, 0, 1),
+        ((1, 32, 32, 16), (3, 3), 16, 1, (0, 2, 2, 0), 2),
+        ((1, 32, 32, 16), (3, 3), 16, 1, (1, 1, 2, 2), 2),
+        # bug https://github.com/apache/tvm/issues/9226
+        ((1, 49, 10, 1), (10, 4), 64, (2, 2), (4, 1, 5, 1), 1),
+        # from Visual Wake Word model
+        ((1, 96, 96, 3), (3, 3), 8, (2, 2), (0, 0, 1, 1), 1),
+        # from Image Classification model (one of the MLPerfTiny models)
+        ((1, 16, 16, 32), (1, 1), 64, (2, 2), 0, 1),
+        ((4, 16, 16, 8), (5, 5), 8, 2, (0, 4, 4, 0), 1),
+        ((4, 16, 16, 8), (5, 5), 16, 2, (0, 4, 4, 0), 1),
+        ((4, 16, 16, 8), (5, 5), 8, 2, 0, 1),
+        ((4, 16, 16, 8), (5, 5), 16, 2, 0, 1),
+        ((1, 16, 16, 8), (3, 3), 16, 2, (0, 0, 1, 1), 1),
+        ((1, 16, 16, 8), (3, 3), 16, 2, (1, 1, 2, 2), 1),
+        ((1, 16, 16, 8), (5, 5), 16, 2, (3, 3, 2, 2), 1),
+        ((1, 16, 16, 8), (3, 3), 16, 2, (0, 1, 2, 3), 1),
+    ],
+)
+@pytest.mark.parametrize("dtype", ["int8", "int16"])
+def test_conv2d(data_shape_nhwc, kernel_size, num_filter, strides, padding, dilation, dtype):
+    """Test a subgraph with a single conv2d operator."""
+    ishape = data_shape_nhwc
+    wshape = (*kernel_size, data_shape_nhwc[-1], num_filter)
+
+    weight_data = np.random.randint(low=-10, high=10, size=wshape, dtype=dtype)
+
+    input0 = relay.var("input", relay.TensorType(ishape, dtype))
+    weight0 = relay.const(weight_data)
+    out0 = relay.op.nn.conv2d(
+        input0,
+        weight0,
+        kernel_size=kernel_size,
+        strides=strides,
+        padding=padding,
+        dilation=(dilation, dilation),
+        data_layout="NHWC",
+        kernel_layout="HWIO",
+        out_dtype="int32",
+        out_layout="NHWC",
+    )
+    ref_mod = tvm.IRModule.from_expr(relay.Function([input0], out0))
+
+    input1 = relay.var("input", relay.TensorType(ishape, dtype))
+    weight1 = relay.const(np.moveaxis(weight_data, 2, -1))
+    out1 = relay.op.nn.conv2d(
+        input1,
+        weight1,
+        kernel_size=kernel_size,
+        strides=strides,
+        padding=padding,
+        dilation=(dilation, dilation),
+        data_layout="NHWC",
+        kernel_layout="HWOI",
+        out_dtype="int32",
+        out_layout="NHWC",
+    )
+    mod = tvm.IRModule.from_expr(relay.Function([input1], out1))
+
+    inputs = {"input": np.random.randint(low=-128, high=127, size=ishape, dtype=dtype)}
+    output_list = generate_ref_data(ref_mod, inputs)
+
+    compile_and_run(
+        AOTTestModel(module=mod, inputs=inputs, outputs=output_list),
+        runner=AOT_CORSTONE300_RUNNER,
+        interface_api="c",
+        use_unpacked_api=True,
+        target_opts={
+            "-keys": "arm_cpu",
+            "-march": "armv7e-m",

Review comment:
       agreed--i think -mcpu was used to key the IsaAnalyzer, correct?

##########
File path: python/tvm/relay/qnn/op/legalizations.py
##########
@@ -374,6 +374,8 @@ def _qnn_conv2d_legalize_arm_cpu(attrs, inputs, types):
         attrs["kernel_layout"],
         attrs["groups"],
     )
+
+    # Use int8 for Cortex-M7

Review comment:
       @sergey-grovety can you revert this comment or fix the set of CPUs indicated?

##########
File path: python/tvm/target/arm_isa.py
##########
@@ -16,18 +16,24 @@
 # under the License.
 """Defines functions to analyze available opcodes in the ARM ISA."""
 
+import argparse
 
 ARM_ISA_MAP = {
-    "armv7e-m": ["SMLAD"],
+    "armv7e-m": ["SMLAD", "SSUB8", "SEL"],
+    "armv8-m": ["SMLAD", "SSUB8", "SEL"],

Review comment:
       same thing here

##########
File path: python/tvm/topi/arm_cpu/conv1d.py
##########
@@ -0,0 +1,36 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name, unused-variable, no-else-return, unused-argument, import-outside-toplevel
+"""Conv1D schedule for ARM CPU"""
+from __future__ import absolute_import as _abs
+
+from tvm import autotvm
+
+from .cortex_m7.conv1d import direct_simd as direct_simd_conv1d
+
+
+@autotvm.register_topi_compute("conv1d_nwc_direct_simd.arm_cpu")
+def conv1d_nwc_direct_simd(cfg, data, kernel, strides, padding, dilation, out_dtype):

Review comment:
       mprofile seems good to me.







[GitHub] [tvm] areusch commented on a change in pull request #9233: Cortex m7 intrinsic

Posted by GitBox <gi...@apache.org>.
areusch commented on a change in pull request #9233:
URL: https://github.com/apache/tvm/pull/9233#discussion_r737944176



##########
File path: python/tvm/testing/utils.py
##########
@@ -674,6 +674,18 @@ def requires_opencl(*args):
     return _compose(args, _requires_opencl)
 
 
+def requires_corstone300(*args):
+    """Mark a test as requiring the corstone300 FVP
+
+    Parameters
+    ----------
+    f : function
+        Function to mark
+    """
+    _requires_corstone300 = [pytest.mark.corstone300]

Review comment:
       i discussed this with @grant-arm a bit and it seems the consensus was that there isn't a good way to auto-detect the FVP. perhaps we've missed something though?







[GitHub] [tvm] areusch commented on a change in pull request #9233: Cortex m7 intrinsic

Posted by GitBox <gi...@apache.org>.
areusch commented on a change in pull request #9233:
URL: https://github.com/apache/tvm/pull/9233#discussion_r737942611



##########
File path: python/tvm/topi/arm_cpu/conv1d.py
##########
@@ -0,0 +1,36 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name, unused-variable, no-else-return, unused-argument, import-outside-toplevel
+"""Conv1D schedule for ARM CPU"""
+from __future__ import absolute_import as _abs
+
+from tvm import autotvm
+
+from .cortex_m7.conv1d import direct_simd as direct_simd_conv1d
+
+
+@autotvm.register_topi_compute("conv1d_nwc_direct_simd.arm_cpu")
+def conv1d_nwc_direct_simd(cfg, data, kernel, strides, padding, dilation, out_dtype):

Review comment:
       yeah this makes sense--the simd was a historical thing from our initial implementation. i'm good with renaming this e.g. `conv1d_nwc_dsp`.







[GitHub] [tvm] areusch commented on a change in pull request #9233: Cortex m7 intrinsic

Posted by GitBox <gi...@apache.org>.
areusch commented on a change in pull request #9233:
URL: https://github.com/apache/tvm/pull/9233#discussion_r737942698



##########
File path: python/tvm/topi/arm_cpu/conv1d.py
##########
@@ -0,0 +1,36 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name, unused-variable, no-else-return, unused-argument, import-outside-toplevel
+"""Conv1D schedule for ARM CPU"""
+from __future__ import absolute_import as _abs
+
+from tvm import autotvm
+
+from .cortex_m7.conv1d import direct_simd as direct_simd_conv1d

Review comment:
       ```suggestion
   from .cortex_m7.conv1d import direct_simd
   ```







[GitHub] [tvm] areusch commented on pull request #9233: Cortex m7 intrinsic

Posted by GitBox <gi...@apache.org>.
areusch commented on pull request #9233:
URL: https://github.com/apache/tvm/pull/9233#issuecomment-946030647


   @u99127 would you like to take a look? also, do you have any suggestions as to how to implement the `requires_corstone300` primitive? a simple implementation could be to add a pytest flag `--run-corstone300-tests=/path/to/opt/arm/ethosu` but am curious if there is a better way. the flag method requires us to make a Jenkinsfile change to differentiate between ci-i386 (with no Corstone 300) and ci-cpu (with Corstone 300).





[GitHub] [tvm] sergey-grovety commented on a change in pull request #9233: Cortex m7 intrinsic

Posted by GitBox <gi...@apache.org>.
sergey-grovety commented on a change in pull request #9233:
URL: https://github.com/apache/tvm/pull/9233#discussion_r726534192



##########
File path: python/tvm/topi/arm_cpu/cortex_m7/micro_kernel/relu.py
##########
@@ -0,0 +1,78 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name, no-value-for-parameter
+"""Defines relu intrinsics for SIMD relu operation."""
+
+
+def relu_MxN_impl(M, N, uniq_id):
+    """Emit C code for relu impl."""
+    cc_code = f"""
+#ifndef __STATIC_FORCEINLINE
+    #define __STATIC_FORCEINLINE  static inline
+#endif
+
+#ifdef __cplusplus
+extern "C"
+#endif
+__STATIC_FORCEINLINE int32_t relu_rest(
+    int N,
+    int8_t *mat) {{
+  for (int j = 0; j < N; j++)
+    mat[j] = mat[j] > 0 ? mat[j] : 0;
+  return 0;
+}}
+
+#ifdef __cplusplus
+extern "C"
+#endif
+__STATIC_FORCEINLINE int32_t relu_{M}x{N}_loop_{uniq_id}(
+    int8_t *mat) {{
+  for (int i = 0; i < {M}; i++)
+    for (int j = 0; j < {N}; j++)
+			mat[i * {N} + j] > 0 ? mat[i * {N} + j] : 0;
+  return 0;
+}}
+
+#ifdef __cplusplus
+extern "C"
+#endif
+__STATIC_FORCEINLINE int32_t relu_{M}x{N}_{uniq_id}(
+    int8_t *mat) {{
+
+	int32_t *pmat32 = (int32_t *)mat;
+
+#ifdef GROVETY_OP_BENCHMARK

Review comment:
       The relu implementation is not ready yet, and this file shouldn't appear here.







[GitHub] [tvm] areusch merged pull request #9233: [Topi] Cortex-M DSP support

Posted by GitBox <gi...@apache.org>.
areusch merged pull request #9233:
URL: https://github.com/apache/tvm/pull/9233


   





[GitHub] [tvm] ilyag-grovety commented on a change in pull request #9233: Cortex m7 intrinsic

Posted by GitBox <gi...@apache.org>.
ilyag-grovety commented on a change in pull request #9233:
URL: https://github.com/apache/tvm/pull/9233#discussion_r738092988



##########
File path: python/tvm/topi/arm_cpu/conv1d.py
##########
@@ -0,0 +1,36 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name, unused-variable, no-else-return, unused-argument, import-outside-toplevel
+"""Conv1D schedule for ARM CPU"""
+from __future__ import absolute_import as _abs
+
+from tvm import autotvm
+
+from .cortex_m7.conv1d import direct_simd as direct_simd_conv1d

Review comment:
       Yes, it's left over from the commit where the conv1d schedules were in the conv2d file.







[GitHub] [tvm] areusch commented on a change in pull request #9233: Cortex m7 intrinsic

Posted by GitBox <gi...@apache.org>.
areusch commented on a change in pull request #9233:
URL: https://github.com/apache/tvm/pull/9233#discussion_r737944378



##########
File path: python/tvm/topi/arm_cpu/cortex_m7/conv1d/__init__.py
##########
@@ -0,0 +1,19 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Conv1d implementations for cortex-m7."""

Review comment:
       agree with this, perhaps we do this with the follow-on to expand to architecture?







[GitHub] [tvm] ilyag-grovety commented on a change in pull request #9233: Cortex m7 intrinsic

Posted by GitBox <gi...@apache.org>.
ilyag-grovety commented on a change in pull request #9233:
URL: https://github.com/apache/tvm/pull/9233#discussion_r738098603



##########
File path: python/tvm/relay/op/strategy/arm_cpu.py
##########
@@ -49,6 +49,26 @@ def schedule_concatenate_arm_cpu(_, outs, target):
         return topi.arm_cpu.schedule_concatenate(outs)
 
 
+@schedule_pool.register(["arm_cpu", "micro_dev"])

Review comment:
       It was wrongly copied from a similar schedule declaration; fixed in fb173291342758541fc31ac1f3eda24c6eb26bb2.







[GitHub] [tvm] ilyag-grovety commented on a change in pull request #9233: Cortex m7 intrinsic

Posted by GitBox <gi...@apache.org>.
ilyag-grovety commented on a change in pull request #9233:
URL: https://github.com/apache/tvm/pull/9233#discussion_r737429903



##########
File path: python/tvm/topi/arm_cpu/dense.py
##########
@@ -0,0 +1,25 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name, unused-variable, no-else-return, unused-argument, import-outside-toplevel
+"""Dense schedule for ARM CPU"""
+
+from .cortex_m7.dense import direct_simd

Review comment:
       This schedule is used only by the v7e-m strategy (selected by a check on the ISA). On AArch64, the "dense.generic" strategy will be chosen.







[GitHub] [tvm] areusch commented on a change in pull request #9233: Cortex m7 intrinsic

Posted by GitBox <gi...@apache.org>.
areusch commented on a change in pull request #9233:
URL: https://github.com/apache/tvm/pull/9233#discussion_r742417398



##########
File path: tests/python/conftest.py
##########
@@ -40,3 +41,26 @@
 
 if tvm.support.libinfo().get("USE_MICRO", "OFF") != "ON":
     collect_ignore.append("unittest/test_micro_transport.py")
+
+
+def pytest_addoption(parser):
+    parser.addoption(
+        "--enable-corstone300-tests",
+        action="store_true",
+        default=False,
+        help="Run Corstone-300 FVP tests",
+    )
+
+
+def pytest_collection_modifyitems(config, items):
+    for item in items:
+        if config.getoption("--enable-corstone300-tests"):
+            if not "corstone300" in item.keywords:
+                item.add_marker(
+                    pytest.mark.skip(reason="Test should be marked 'corstone300' to run")

Review comment:
       i think we just need one skip, right? doesn't this skip all other tests aside from corstone300?
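
One reading of this comment is that a single skip is enough: skip tests marked `corstone300` unless the flag is given, and leave every other test alone. The sketch below shows that behaviour under this assumption. `skip_reason` is a hypothetical helper introduced only for illustration, and the skip message is made up; the option name and the pytest hooks are the ones in the diff above.

```python
# Hypothetical conftest.py sketch; skip_reason() is an illustrative helper,
# not part of TVM or pytest.


def skip_reason(flag_enabled, keywords):
    """Return a skip message for a test, or None if the test should run."""
    if "corstone300" in keywords and not flag_enabled:
        return "Need --enable-corstone300-tests option to run this test"
    return None


def pytest_addoption(parser):
    parser.addoption(
        "--enable-corstone300-tests",
        action="store_true",
        default=False,
        help="Run Corstone-300 FVP tests",
    )


def pytest_collection_modifyitems(config, items):
    import pytest  # imported here so skip_reason() stays usable without pytest

    enabled = config.getoption("--enable-corstone300-tests")
    for item in items:
        reason = skip_reason(enabled, item.keywords)
        if reason is not None:
            # Only corstone300-marked tests are ever skipped.
            item.add_marker(pytest.mark.skip(reason=reason))
```

Keeping the decision in a plain function makes it easy to check all four flag/marker combinations without running a pytest session.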

##########
File path: python/tvm/target/arm_isa.py
##########
@@ -16,18 +16,24 @@
 # under the License.
 """Defines functions to analyze available opcodes in the ARM ISA."""
 
+import argparse
 
 ARM_ISA_MAP = {
-    "armv7e-m": ["SMLAD"],
+    "armv7e-m": ["SMLAD", "SSUB8", "SEL"],
+    "armv8-m": ["SMLAD", "SSUB8", "SEL"],
 }
 
 
 class IsaAnalyzer(object):
+    """Checks ISA support for given target"""
+
     def __init__(self, target):
         self.target = target
-        # TODO: actually parse -mcpu
-        arch = "armv7e-m"
-        self._isa_map = ARM_ISA_MAP[arch]
+        parser = argparse.ArgumentParser()

Review comment:
       you should use the built-in Target parsing logic here rather than argparse:
   ```suggestion
           target = tvm.target.Target(target)
           march = target.attrs.get("-march", None)
           self._isa_map = ARM_ISA_MAP[march] if march is not None else []
   ```
   (also need to delete the following lines 33-36--suggestion didn't quite get the diff)
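
To illustrate the lookup described in the suggestion above, here is a sketch that runs without TVM. It assumes the architecture string has already been read from the parsed target and is passed in as `march`, which is a simplification for illustration and not the signature in the PR. The ISA map is copied from the diff. Unlike the suggestion, this version also falls back to an empty list for an architecture that is missing from the map.

```python
ARM_ISA_MAP = {
    "armv7e-m": ["SMLAD", "SSUB8", "SEL"],
    "armv8-m": ["SMLAD", "SSUB8", "SEL"],
}


class IsaAnalyzer:
    """Checks which DSP instructions are available for an architecture."""

    def __init__(self, march=None):
        # `march` stands in for the value read from the parsed target;
        # an unknown or missing architecture yields no instructions.
        self._isa_map = ARM_ISA_MAP.get(march, [])

    def __contains__(self, instruction):
        # Supports the `"SMLAD" in isa` checks used by the strategy code.
        return instruction in self._isa_map
```

With this shape, a strategy can keep writing `if "SMLAD" in isa:` and simply falls through to the generic schedule when no architecture is given.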

##########
File path: python/tvm/topi/arm_cpu/dense.py
##########
@@ -0,0 +1,25 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name, unused-variable, no-else-return, unused-argument, import-outside-toplevel
+"""Dense schedule for ARM CPU"""
+
+from .cortex_m7.dense import direct_simd

Review comment:
       i think it makes sense then to not import the cortex_m7 direct_simd into this module. can we reorganize as discussed in the earlier thread?

##########
File path: python/tvm/target/arm_isa.py
##########
@@ -16,18 +16,24 @@
 # under the License.
 """Defines functions to analyze available opcodes in the ARM ISA."""
 
+import argparse
 
 ARM_ISA_MAP = {
-    "armv7e-m": ["SMLAD"],
+    "armv7e-m": ["SMLAD", "SSUB8", "SEL"],

Review comment:
       @u99127 as discussed, let's punt the architecture labelling to the next PR.

##########
File path: python/tvm/relay/op/strategy/arm_cpu.py
##########
@@ -49,6 +49,26 @@ def schedule_concatenate_arm_cpu(_, outs, target):
         return topi.arm_cpu.schedule_concatenate(outs)
 
 
+@schedule_pool.register(["arm_cpu", "micro_dev"])
+def schedule_pool_arm_cpu(attrs, outs, target):
+    """schedule pooling ops arm cpu"""
+    layout = attrs.layout
+    isa = arm_isa.IsaAnalyzer(target)
+    avg_pool = isinstance(attrs, relay.op.op_attrs.AvgPool2DAttrs)
+    with target:
+        if (
+            avg_pool
+            and layout in ("NCW", "NCHW")
+            and "SMLAD" in isa

Review comment:
       i agree with you that we should refactor this. this was left over from the initial implementation which did propose to test for presence of instructions in the ISA; however, you're right that we should just need to determine which architecture is in use. since this PR just adds additional schedules which are purported to be compatible with cortex-m7 devices, perhaps we can address the question of lookup-by-architecture in a follow-on.

##########
File path: tests/python/integration/test_m7_simd.py
##########
@@ -0,0 +1,355 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import sys
+import numpy as np
+import pytest
+import tvm
+from tvm import relay
+from tests.python.relay.aot.aot_test_utils import (
+    AOTTestModel,
+    AOT_CORSTONE300_RUNNER,
+    generate_ref_data,
+    compile_and_run,
+)
+
+
+@tvm.testing.requires_corstone300
+@pytest.mark.parametrize(
+    "data_shape_nhwc, kernel_size, num_filter, strides, padding, dilation",
+    [
+        ((1, 32, 32, 1), (3, 3), 12, 1, 0, 1),
+        ((1, 32, 10, 3), (3, 3), 16, 1, 0, 1),
+        ((1, 49, 10, 1), (10, 4), 64, (2, 1), (4, 1, 5, 1), 1),
+        ((1, 32, 32, 16), (3, 3), 16, 1, (0, 2, 2, 0), 1),
+        ((1, 32, 32, 16), (3, 3), 16, 1, 0, 1),
+        ((1, 32, 32, 16), (3, 3), 16, 1, 0, 1),
+        ((1, 32, 32, 16), (3, 3), 16, 1, (0, 2, 2, 0), 2),
+        ((1, 32, 32, 16), (3, 3), 16, 1, (1, 1, 2, 2), 2),
+        # bug https://github.com/apache/tvm/issues/9226
+        ((1, 49, 10, 1), (10, 4), 64, (2, 2), (4, 1, 5, 1), 1),
+        # from Visual Wake Word model
+        ((1, 96, 96, 3), (3, 3), 8, (2, 2), (0, 0, 1, 1), 1),
+        # from Image Classification model (one of the MLPerfTiny models)
+        ((1, 16, 16, 32), (1, 1), 64, (2, 2), 0, 1),
+        ((4, 16, 16, 8), (5, 5), 8, 2, (0, 4, 4, 0), 1),
+        ((4, 16, 16, 8), (5, 5), 16, 2, (0, 4, 4, 0), 1),
+        ((4, 16, 16, 8), (5, 5), 8, 2, 0, 1),
+        ((4, 16, 16, 8), (5, 5), 16, 2, 0, 1),
+        ((1, 16, 16, 8), (3, 3), 16, 2, (0, 0, 1, 1), 1),
+        ((1, 16, 16, 8), (3, 3), 16, 2, (1, 1, 2, 2), 1),
+        ((1, 16, 16, 8), (5, 5), 16, 2, (3, 3, 2, 2), 1),
+        ((1, 16, 16, 8), (3, 3), 16, 2, (0, 1, 2, 3), 1),
+    ],
+)
+@pytest.mark.parametrize("dtype", ["int8", "int16"])
+def test_conv2d(data_shape_nhwc, kernel_size, num_filter, strides, padding, dilation, dtype):
+    """Test a subgraph with a single conv2d operator."""
+    ishape = data_shape_nhwc
+    wshape = (*kernel_size, data_shape_nhwc[-1], num_filter)
+
+    weight_data = np.random.randint(low=-10, high=10, size=wshape, dtype=dtype)
+
+    input0 = relay.var("input", relay.TensorType(ishape, dtype))
+    weight0 = relay.const(weight_data)
+    out0 = relay.op.nn.conv2d(
+        input0,
+        weight0,
+        kernel_size=kernel_size,
+        strides=strides,
+        padding=padding,
+        dilation=(dilation, dilation),
+        data_layout="NHWC",
+        kernel_layout="HWIO",
+        out_dtype="int32",
+        out_layout="NHWC",
+    )
+    ref_mod = tvm.IRModule.from_expr(relay.Function([input0], out0))
+
+    input1 = relay.var("input", relay.TensorType(ishape, dtype))
+    weight1 = relay.const(np.moveaxis(weight_data, 2, -1))
+    out1 = relay.op.nn.conv2d(
+        input1,
+        weight1,
+        kernel_size=kernel_size,
+        strides=strides,
+        padding=padding,
+        dilation=(dilation, dilation),
+        data_layout="NHWC",
+        kernel_layout="HWOI",
+        out_dtype="int32",
+        out_layout="NHWC",
+    )
+    mod = tvm.IRModule.from_expr(relay.Function([input1], out1))
+
+    inputs = {"input": np.random.randint(low=-128, high=127, size=ishape, dtype=dtype)}
+    output_list = generate_ref_data(ref_mod, inputs)
+
+    compile_and_run(
+        AOTTestModel(module=mod, inputs=inputs, outputs=output_list),
+        runner=AOT_CORSTONE300_RUNNER,
+        interface_api="c",
+        use_unpacked_api=True,
+        target_opts={
+            "-keys": "arm_cpu",
+            "-march": "armv7e-m",

Review comment:
       agreed--i think -mcpu was used to key the IsaAnalyzer, correct?

##########
File path: python/tvm/relay/qnn/op/legalizations.py
##########
@@ -374,6 +374,8 @@ def _qnn_conv2d_legalize_arm_cpu(attrs, inputs, types):
         attrs["kernel_layout"],
         attrs["groups"],
     )
+
+    # Use int8 for Cortex-M7

Review comment:
       @sergey-grovety can you revert this comment or fix the set of CPUs indicated?

##########
File path: python/tvm/target/arm_isa.py
##########
@@ -16,18 +16,24 @@
 # under the License.
 """Defines functions to analyze available opcodes in the ARM ISA."""
 
+import argparse
 
 ARM_ISA_MAP = {
-    "armv7e-m": ["SMLAD"],
+    "armv7e-m": ["SMLAD", "SSUB8", "SEL"],
+    "armv8-m": ["SMLAD", "SSUB8", "SEL"],

Review comment:
       same thing here

##########
File path: python/tvm/topi/arm_cpu/conv1d.py
##########
@@ -0,0 +1,36 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name, unused-variable, no-else-return, unused-argument, import-outside-toplevel
+"""Conv1D schedule for ARM CPU"""
+from __future__ import absolute_import as _abs
+
+from tvm import autotvm
+
+from .cortex_m7.conv1d import direct_simd as direct_simd_conv1d
+
+
+@autotvm.register_topi_compute("conv1d_nwc_direct_simd.arm_cpu")
+def conv1d_nwc_direct_simd(cfg, data, kernel, strides, padding, dilation, out_dtype):

Review comment:
       mprofile seems good to me.







[GitHub] [tvm] vinx13 commented on a change in pull request #9233: Cortex m7 intrinsic

Posted by GitBox <gi...@apache.org>.
vinx13 commented on a change in pull request #9233:
URL: https://github.com/apache/tvm/pull/9233#discussion_r730034447



##########
File path: python/tvm/relay/op/strategy/arm_cpu.py
##########
@@ -49,6 +49,26 @@ def schedule_concatenate_arm_cpu(_, outs, target):
         return topi.arm_cpu.schedule_concatenate(outs)
 
 
+@schedule_pool.register(["arm_cpu", "micro_dev"])
+def schedule_pool_arm_cpu(attrs, outs, target):
+    """schedule pooling ops arm cpu"""
+    layout = attrs.layout
+    isa = arm_isa.IsaAnalyzer(target)
+    avg_pool = hasattr(attrs, "count_include_pad")

Review comment:
       what about `isinstance(attrs, relay.op.op_attrs.AvgPool2DAttrs)`
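
A small sketch of the difference between the two checks, using hypothetical stand-in classes instead of the real Relay attribute types (none of the classes below are TVM's). A field check accepts any object that happens to carry `count_include_pad`, while a type check accepts only the named class.

```python
from dataclasses import dataclass


# Hypothetical stand-ins for Relay attribute classes; not the real TVM types.
@dataclass
class MaxPool2DAttrs:
    layout: str = "NCHW"


@dataclass
class AvgPool2DAttrs:
    layout: str = "NCHW"
    count_include_pad: bool = False


@dataclass
class OtherAttrs:
    # Unrelated attrs that happen to share the field name.
    count_include_pad: bool = True


def is_avg_pool_by_field(attrs):
    """Duck-typed check, as in the original diff."""
    return hasattr(attrs, "count_include_pad")


def is_avg_pool_by_type(attrs):
    """Explicit type check, as proposed in the review."""
    return isinstance(attrs, AvgPool2DAttrs)
```

The two checks agree on the pooling classes here and differ only for `OtherAttrs`, which the field check would treat as average pooling.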







[GitHub] [tvm] areusch commented on pull request #9233: Cortex m7 intrinsic

Posted by GitBox <gi...@apache.org>.
areusch commented on pull request #9233:
URL: https://github.com/apache/tvm/pull/9233#issuecomment-954095841


   discussed a bit with @grant-arm and he and @Mousius @manupa-arm will propose a solution to detecting the Corstone-300 FVP binary and `ETHOSU_PATH`, which are the two things we need to configure based on [corstone300.mk](https://github.com/apache/tvm/blob/main/tests/python/relay/aot/corstone300.mk#L30).
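
For illustration only, a sketch of what checking for those two things could look like with the Python standard library. This is not the solution referred to above. The binary name `FVP_Corstone_SSE-300_Ethos-U55` and the helper `corstone300_config` are assumptions, and the sketch expects the FVP to be on `PATH` and `ETHOSU_PATH` to be set in the environment.

```python
import os
import shutil

# Assumed binary name; adjust to match the local FVP installation.
FVP_BINARY = "FVP_Corstone_SSE-300_Ethos-U55"


def corstone300_config(environ=None, which=shutil.which):
    """Return (fvp_path, ethosu_path) if both are found, else None."""
    environ = os.environ if environ is None else environ
    fvp_path = which(FVP_BINARY)  # search PATH for the FVP executable
    ethosu_path = environ.get("ETHOSU_PATH")
    if fvp_path is None or not ethosu_path:
        return None
    return fvp_path, ethosu_path
```

The `environ` and `which` parameters exist only so the function can be exercised without a real FVP installation.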





[GitHub] [tvm] mehrdadh commented on a change in pull request #9233: Cortex m7 intrinsic

Posted by GitBox <gi...@apache.org>.
mehrdadh commented on a change in pull request #9233:
URL: https://github.com/apache/tvm/pull/9233#discussion_r726247929



##########
File path: python/tvm/topi/arm_cpu/cortex_m7/micro_kernel/relu.py
##########
@@ -0,0 +1,78 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name, no-value-for-parameter
+"""Defines relu intrinsics for SIMD relu operation."""
+
+
+def relu_MxN_impl(M, N, uniq_id):
+    """Emit C code for relu impl."""
+    cc_code = f"""
+#ifndef __STATIC_FORCEINLINE
+    #define __STATIC_FORCEINLINE  static inline
+#endif
+
+#ifdef __cplusplus
+extern "C"
+#endif
+__STATIC_FORCEINLINE int32_t relu_rest(
+    int N,
+    int8_t *mat) {{
+  for (int j = 0; j < N; j++)
+    mat[j] = mat[j] > 0 ? mat[j] : 0;
+  return 0;
+}}
+
+#ifdef __cplusplus
+extern "C"
+#endif
+__STATIC_FORCEINLINE int32_t relu_{M}x{N}_loop_{uniq_id}(
+    int8_t *mat) {{
+  for (int i = 0; i < {M}; i++)
+    for (int j = 0; j < {N}; j++)
+			mat[i * {N} + j] > 0 ? mat[i * {N} + j] : 0;
+  return 0;
+}}
+
+#ifdef __cplusplus
+extern "C"
+#endif
+__STATIC_FORCEINLINE int32_t relu_{M}x{N}_{uniq_id}(
+    int8_t *mat) {{
+
+	int32_t *pmat32 = (int32_t *)mat;
+
+#ifdef GROVETY_OP_BENCHMARK

Review comment:
       What's the use of this macro?

##########
File path: python/tvm/relay/op/strategy/arm_cpu.py
##########
@@ -415,3 +435,67 @@ def schedule_bitserial_dense_arm_cpu(attrs, inputs, out_type, target):
         name="bitserial_dense.arm_cpu",
     )
     return strategy
+
+
+@dense_strategy.register(["arm_cpu", "micro_dev"])
+def schedule_dense_arm_cpu(attrs, inputs, out_type, target):
+    """dense arm cpu strategy"""
+    strategy = _op.OpStrategy()
+    isa = arm_isa.IsaAnalyzer(target)
+    if "SMLAD" in isa:
+        strategy.add_implementation(
+            wrap_compute_dense(topi.nn.dense),
+            wrap_topi_schedule(topi.arm_cpu.schedule_dense_direct_simd),
+            name="dense_direct_simd.micro_dev",
+        )
+    else:
+        strategy.add_implementation(
+            wrap_compute_dense(topi.nn.dense),
+            wrap_topi_schedule(topi.generic.schedule_dense),
+            name="dense.generic",
+        )
+    return strategy
+
+
+@conv1d_strategy.register("arm_cpu")
+def conv1d_strategy_arm_cpu(attrs, inputs, out_type, target):
+    """conv1d strategy"""
+    strategy = _op.OpStrategy()
+    layout = attrs.data_layout
+    kernel_layout = attrs.kernel_layout
+    dilation = get_const_tuple(attrs.dilation)
+    if dilation[0] < 1:
+        raise ValueError("dilation should be a positive value")
+
+    isa = arm_isa.IsaAnalyzer(target)
+
+    if kernel_layout == "WOI":
+        if layout == "NWC" and "SMLAD" in isa:
+            strategy.add_implementation(
+                wrap_compute_conv1d(topi.arm_cpu.conv1d_direct_simd),
+                wrap_topi_schedule(topi.arm_cpu.schedule_conv1d_direct_simd),

Review comment:
       maybe rename `schedule_conv1d_direct_simd` to `schedule_conv1d_nwc_direct_simd` to be more readable and consistent with others.

##########
File path: python/tvm/topi/arm_cpu/conv2d.py
##########
@@ -508,12 +509,25 @@ def _callback(op):
 @autotvm.register_topi_compute("conv2d_direct_simd.arm_cpu")
 def conv2d_direct_simd(cfg, data, kernel, strides, padding, dilation, out_dtype):
     """Compute conv2d with SIMD (v7e-m)."""
-    return direct_simd.conv2d_direct_simd_compute(
+    return direct_simd_conv2d.conv2d_direct_simd_compute(
         cfg, data, kernel, strides, padding, dilation, out_dtype
     )
 
 
 @autotvm.register_topi_schedule("conv2d_direct_simd.arm_cpu")
 def schedule_conv2d_direct_simd(cfg, outs):
     """Create schedule for conv2d_direct_simd"""
-    return direct_simd.conv2d_direct_simd_nhwc_schedule(cfg, outs)
+    return direct_simd_conv2d.conv2d_direct_simd_nhwc_schedule(cfg, outs)
+
+
+@autotvm.register_topi_compute("conv1d_direct_simd.arm_cpu")
+def conv1d_direct_simd(cfg, data, kernel, strides, padding, dilation, out_dtype):

Review comment:
       Please move the `conv1d`-related functions to a `conv1d.py` file in the same directory.

##########
File path: tests/python/integration/test_m7_simd.py
##########
@@ -0,0 +1,331 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import sys
+import numpy as np
+import pytest
+import tvm
+from tvm import relay
+from tests.python.relay.aot.aot_test_utils import (
+    AOTTestModel,
+    AOT_CORSTONE300_RUNNER,
+    generate_ref_data,
+    compile_and_run,
+)
+
+
+@tvm.testing.requires_corstone300
+@pytest.mark.parametrize(
+    "data_shape_nhwc, kernel_size, num_filter, strides, padding",
+    [
+        ((1, 32, 32, 1), (3, 3), 12, 1, 0),
+        ((1, 32, 10, 3), (3, 3), 16, 1, 0),
+        ((1, 49, 10, 1), (10, 4), 64, (2, 1), (4, 1, 5, 1)),
+        # TOFIX: https://github.com/apache/tvm/issues/9226
+        # ((1, 49, 10, 1), (10, 4), 64, (2, 2), (4, 1, 5, 1)),

Review comment:
       Do you plan to have a follow-on PR for this issue?

##########
File path: python/tvm/topi/arm_cpu/conv2d.py
##########
@@ -508,12 +509,25 @@ def _callback(op):
 @autotvm.register_topi_compute("conv2d_direct_simd.arm_cpu")
 def conv2d_direct_simd(cfg, data, kernel, strides, padding, dilation, out_dtype):
     """Compute conv2d with SIMD (v7e-m)."""
-    return direct_simd.conv2d_direct_simd_compute(
+    return direct_simd_conv2d.conv2d_direct_simd_compute(
         cfg, data, kernel, strides, padding, dilation, out_dtype
     )
 
 
 @autotvm.register_topi_schedule("conv2d_direct_simd.arm_cpu")
 def schedule_conv2d_direct_simd(cfg, outs):
     """Create schedule for conv2d_direct_simd"""
-    return direct_simd.conv2d_direct_simd_nhwc_schedule(cfg, outs)
+    return direct_simd_conv2d.conv2d_direct_simd_nhwc_schedule(cfg, outs)
+
+
+@autotvm.register_topi_compute("conv1d_direct_simd.arm_cpu")
+def conv1d_direct_simd(cfg, data, kernel, strides, padding, dilation, out_dtype):
+    """Compute conv1d with SIMD (v7e-m)."""
+    return direct_simd_conv1d.conv1d_direct_simd_compute(
+        cfg, data, kernel, strides, padding, dilation, out_dtype
+    )
+
+
+@autotvm.register_topi_schedule("conv1d_direct_simd.arm_cpu")
+def schedule_conv1d_direct_simd(cfg, outs):

Review comment:
       Same comment here. Since we only support the `nwc` format, I think we should be more explicit in the name.

##########
File path: python/tvm/topi/arm_cpu/cortex_m7/conv1d/direct_simd.py
##########
@@ -0,0 +1,175 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name, no-value-for-parameter
+"""Direct implementation of conv1d."""
+from tvm import autotvm
+from tvm.autotvm.task import deserialize_args
+from tvm import te
+from tvm.topi.utils import simplify, traverse_inline
+from tvm.topi.nn.pad import pad
+from tvm.topi.nn.utils import get_pad_tuple1d
+
+from ..micro_kernel.gemm import (
+    intrin_gemm_MxKxN,
+    gemm_MxKxN_impl,
+)
+
+
+def conv1d_direct_simd(*args, **kwargs):

Review comment:
       same here.




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] Mousius commented on a change in pull request #9233: Cortex m7 intrinsic

Posted by GitBox <gi...@apache.org>.
Mousius commented on a change in pull request #9233:
URL: https://github.com/apache/tvm/pull/9233#discussion_r739328601



##########
File path: python/tvm/relay/qnn/op/legalizations.py
##########
@@ -374,6 +374,8 @@ def _qnn_conv2d_legalize_arm_cpu(attrs, inputs, types):
         attrs["kernel_layout"],
         attrs["groups"],
     )
+
+    # Use int8 for Cortex-M7

Review comment:
       This is not limited to this CPU?

##########
File path: python/tvm/topi/arm_cpu/conv1d.py
##########
@@ -0,0 +1,36 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name, unused-variable, no-else-return, unused-argument, import-outside-toplevel
+"""Conv1D schedule for ARM CPU"""
+from __future__ import absolute_import as _abs
+
+from tvm import autotvm
+
+from .cortex_m7.conv1d import direct_simd as direct_simd_conv1d
+
+
+@autotvm.register_topi_compute("conv1d_nwc_direct_simd.arm_cpu")
+def conv1d_nwc_direct_simd(cfg, data, kernel, strides, padding, dilation, out_dtype):

Review comment:
       I'd suggested a file structure such as:
   ```
   arm_cpu/mprofile/dsp/conv1d.py
   ```
   This leaves room to add other architecture extensions in future rather than stacking them all in one directory.







[GitHub] [tvm] Mousius commented on a change in pull request #9233: Cortex m7 intrinsic

Posted by GitBox <gi...@apache.org>.
Mousius commented on a change in pull request #9233:
URL: https://github.com/apache/tvm/pull/9233#discussion_r738614748



##########
File path: python/tvm/testing/utils.py
##########
@@ -674,6 +674,18 @@ def requires_opencl(*args):
     return _compose(args, _requires_opencl)
 
 
+def requires_corstone300(*args):
+    """Mark a test as requiring the corstone300 FVP
+
+    Parameters
+    ----------
+    f : function
+        Function to mark
+    """
+    _requires_corstone300 = [pytest.mark.corstone300]

Review comment:
       Depending on how these tests are run, we could use the slightly icky AOT skip logic:
   https://github.com/apache/tvm/blob/f4dae23478f41ee899e58533bdb31efc5b1b709e/tests/python/relay/aot/aot_test_utils.py#L196-L201
   
   This would at least automate it if these tests are designed to run in CPU containers. Otherwise, we should just be able to check for the path since we know exactly where we're checking it out in the container:
   https://github.com/apache/tvm/blob/f4dae23478f41ee899e58533bdb31efc5b1b709e/docker/install/ubuntu_install_ethosu_driver_stack.sh#L23
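
A minimal sketch of the second option (checking for the checkout path), written as a plain helper so that it does not depend on pytest. The name `fvp_available` and the `install_dir` argument are hypothetical; the real location is whatever the install script linked above uses, and it is not reproduced here.

```python
import os


def fvp_available(install_dir):
    """Return True if `install_dir` exists and contains at least one entry.

    `install_dir` is a placeholder for the directory where the container
    checks out the FVP; it is an assumption, not the path from the script.
    """
    return os.path.isdir(install_dir) and bool(os.listdir(install_dir))
```

A decorator such as `requires_corstone300` could then pass `not fvp_available(path)` as the condition of `pytest.mark.skipif`, so the tests skip themselves when the directory is missing.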







[GitHub] [tvm] areusch commented on a change in pull request #9233: Cortex m7 intrinsic

Posted by GitBox <gi...@apache.org>.
areusch commented on a change in pull request #9233:
URL: https://github.com/apache/tvm/pull/9233#discussion_r742417398



##########
File path: tests/python/conftest.py
##########
@@ -40,3 +41,26 @@
 
 if tvm.support.libinfo().get("USE_MICRO", "OFF") != "ON":
     collect_ignore.append("unittest/test_micro_transport.py")
+
+
+def pytest_addoption(parser):
+    parser.addoption(
+        "--enable-corstone300-tests",
+        action="store_true",
+        default=False,
+        help="Run Corstone-300 FVP tests",
+    )
+
+
+def pytest_collection_modifyitems(config, items):
+    for item in items:
+        if config.getoption("--enable-corstone300-tests"):
+            if not "corstone300" in item.keywords:
+                item.add_marker(
+                    pytest.mark.skip(reason="Test should be marked 'corstone300' to run")

Review comment:
       i think we just need one skip, right? doesn't this skip all other tests aside from corstone300?
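
The single-skip rule asked about here can be sketched as a plain decision function. This is an assumption about the intended behaviour (skip only the tests marked `corstone300`, and only when the option is absent), not the code that was merged; `skip_reason` is a hypothetical helper.

```python
def skip_reason(enable_corstone300, keywords):
    """Return a skip reason for a test item, or None if it should run.

    enable_corstone300: value of the --enable-corstone300-tests option.
    keywords: marker names attached to the test item.
    """
    if "corstone300" in keywords and not enable_corstone300:
        return "need --enable-corstone300-tests option to run this test"
    # Tests without the marker are never skipped by this rule.
    return None
```

Inside `pytest_collection_modifyitems`, a non-`None` result would be passed to `pytest.mark.skip(reason=...)` and attached with `item.add_marker`. With this rule, enabling the option no longer skips the unmarked tests.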

##########
File path: python/tvm/target/arm_isa.py
##########
@@ -16,18 +16,24 @@
 # under the License.
 """Defines functions to analyze available opcodes in the ARM ISA."""
 
+import argparse
 
 ARM_ISA_MAP = {
-    "armv7e-m": ["SMLAD"],
+    "armv7e-m": ["SMLAD", "SSUB8", "SEL"],
+    "armv8-m": ["SMLAD", "SSUB8", "SEL"],
 }
 
 
 class IsaAnalyzer(object):
+    """Checks ISA support for given target"""
+
     def __init__(self, target):
         self.target = target
-        # TODO: actually parse -mcpu
-        arch = "armv7e-m"
-        self._isa_map = ARM_ISA_MAP[arch]
+        parser = argparse.ArgumentParser()

Review comment:
       you should use the built-in Target parsing logic here rather than argparse:
   ```suggestion
           target = tvm.target.Target(target)
           march = target.attrs.get("-march", None)
           self._isa_map = ARM_ISA_MAP[march] if march is not None else []
   ```
   (also need to delete the following lines 33-36--suggestion didn't quite get the diff)
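
A minimal sketch of the lookup-with-fallback idea, independent of TVM. The `attrs` argument is a plain dict standing in for parsed target attributes, and the key name `march` is an assumption made for illustration; `IsaAnalyzerSketch` is not the real `IsaAnalyzer`.

```python
ARM_ISA_MAP = {
    "armv7e-m": ["SMLAD", "SSUB8", "SEL"],
}


class IsaAnalyzerSketch:
    """Stand-in for IsaAnalyzer that reads the architecture from a dict."""

    def __init__(self, attrs):
        # A missing or unknown architecture maps to an empty instruction list.
        march = attrs.get("march")
        self._isa_map = ARM_ISA_MAP.get(march, [])

    def __contains__(self, instruction):
        return instruction in self._isa_map
```

Using `ARM_ISA_MAP.get(march, [])` rather than indexing means that an architecture string missing from the map yields "no known instructions" instead of raising `KeyError`, so a strategy can still fall back to the generic schedule.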

##########
File path: python/tvm/topi/arm_cpu/dense.py
##########
@@ -0,0 +1,25 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name, unused-variable, no-else-return, unused-argument, import-outside-toplevel
+"""Dense schedule for ARM CPU"""
+
+from .cortex_m7.dense import direct_simd

Review comment:
       i think it makes sense then to not import the cortex_m7 direct_simd into this module. can we reorganize as discussed in the earlier thread?

##########
File path: python/tvm/target/arm_isa.py
##########
@@ -16,18 +16,24 @@
 # under the License.
 """Defines functions to analyze available opcodes in the ARM ISA."""
 
+import argparse
 
 ARM_ISA_MAP = {
-    "armv7e-m": ["SMLAD"],
+    "armv7e-m": ["SMLAD", "SSUB8", "SEL"],

Review comment:
       @u99127 as discussed, let's punt the architecture labelling to the next PR.

##########
File path: python/tvm/relay/op/strategy/arm_cpu.py
##########
@@ -49,6 +49,26 @@ def schedule_concatenate_arm_cpu(_, outs, target):
         return topi.arm_cpu.schedule_concatenate(outs)
 
 
+@schedule_pool.register(["arm_cpu", "micro_dev"])
+def schedule_pool_arm_cpu(attrs, outs, target):
+    """schedule pooling ops arm cpu"""
+    layout = attrs.layout
+    isa = arm_isa.IsaAnalyzer(target)
+    avg_pool = isinstance(attrs, relay.op.op_attrs.AvgPool2DAttrs)
+    with target:
+        if (
+            avg_pool
+            and layout in ("NCW", "NCHW")
+            and "SMLAD" in isa

Review comment:
       i agree with you that we should refactor this. this was left over from the initial implementation which did propose to test for presence of instructions in the ISA; however, you're right that we should just need to determine which architecture is in use. since this PR just adds additional schedules which are purported to be compatible with cortex-m7 devices, perhaps we can address the question of lookup-by-architecture in a follow-on.

##########
File path: tests/python/integration/test_m7_simd.py
##########
@@ -0,0 +1,355 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import sys
+import numpy as np
+import pytest
+import tvm
+from tvm import relay
+from tests.python.relay.aot.aot_test_utils import (
+    AOTTestModel,
+    AOT_CORSTONE300_RUNNER,
+    generate_ref_data,
+    compile_and_run,
+)
+
+
+@tvm.testing.requires_corstone300
+@pytest.mark.parametrize(
+    "data_shape_nhwc, kernel_size, num_filter, strides, padding, dilation",
+    [
+        ((1, 32, 32, 1), (3, 3), 12, 1, 0, 1),
+        ((1, 32, 10, 3), (3, 3), 16, 1, 0, 1),
+        ((1, 49, 10, 1), (10, 4), 64, (2, 1), (4, 1, 5, 1), 1),
+        ((1, 32, 32, 16), (3, 3), 16, 1, (0, 2, 2, 0), 1),
+        ((1, 32, 32, 16), (3, 3), 16, 1, 0, 1),
+        ((1, 32, 32, 16), (3, 3), 16, 1, 0, 1),
+        ((1, 32, 32, 16), (3, 3), 16, 1, (0, 2, 2, 0), 2),
+        ((1, 32, 32, 16), (3, 3), 16, 1, (1, 1, 2, 2), 2),
+        # bug https://github.com/apache/tvm/issues/9226
+        ((1, 49, 10, 1), (10, 4), 64, (2, 2), (4, 1, 5, 1), 1),
+        # from Visual Wake Word model
+        ((1, 96, 96, 3), (3, 3), 8, (2, 2), (0, 0, 1, 1), 1),
+        # from Image Classification model (one of the MLPerfTiny models)
+        ((1, 16, 16, 32), (1, 1), 64, (2, 2), 0, 1),
+        ((4, 16, 16, 8), (5, 5), 8, 2, (0, 4, 4, 0), 1),
+        ((4, 16, 16, 8), (5, 5), 16, 2, (0, 4, 4, 0), 1),
+        ((4, 16, 16, 8), (5, 5), 8, 2, 0, 1),
+        ((4, 16, 16, 8), (5, 5), 16, 2, 0, 1),
+        ((1, 16, 16, 8), (3, 3), 16, 2, (0, 0, 1, 1), 1),
+        ((1, 16, 16, 8), (3, 3), 16, 2, (1, 1, 2, 2), 1),
+        ((1, 16, 16, 8), (5, 5), 16, 2, (3, 3, 2, 2), 1),
+        ((1, 16, 16, 8), (3, 3), 16, 2, (0, 1, 2, 3), 1),
+    ],
+)
+@pytest.mark.parametrize("dtype", ["int8", "int16"])
+def test_conv2d(data_shape_nhwc, kernel_size, num_filter, strides, padding, dilation, dtype):
+    """Test a subgraph with a single conv2d operator."""
+    ishape = data_shape_nhwc
+    wshape = (*kernel_size, data_shape_nhwc[-1], num_filter)
+
+    weight_data = np.random.randint(low=-10, high=10, size=wshape, dtype=dtype)
+
+    input0 = relay.var("input", relay.TensorType(ishape, dtype))
+    weight0 = relay.const(weight_data)
+    out0 = relay.op.nn.conv2d(
+        input0,
+        weight0,
+        kernel_size=kernel_size,
+        strides=strides,
+        padding=padding,
+        dilation=(dilation, dilation),
+        data_layout="NHWC",
+        kernel_layout="HWIO",
+        out_dtype="int32",
+        out_layout="NHWC",
+    )
+    ref_mod = tvm.IRModule.from_expr(relay.Function([input0], out0))
+
+    input1 = relay.var("input", relay.TensorType(ishape, dtype))
+    weight1 = relay.const(np.moveaxis(weight_data, 2, -1))
+    out1 = relay.op.nn.conv2d(
+        input1,
+        weight1,
+        kernel_size=kernel_size,
+        strides=strides,
+        padding=padding,
+        dilation=(dilation, dilation),
+        data_layout="NHWC",
+        kernel_layout="HWOI",
+        out_dtype="int32",
+        out_layout="NHWC",
+    )
+    mod = tvm.IRModule.from_expr(relay.Function([input1], out1))
+
+    inputs = {"input": np.random.randint(low=-128, high=127, size=ishape, dtype=dtype)}
+    output_list = generate_ref_data(ref_mod, inputs)
+
+    compile_and_run(
+        AOTTestModel(module=mod, inputs=inputs, outputs=output_list),
+        runner=AOT_CORSTONE300_RUNNER,
+        interface_api="c",
+        use_unpacked_api=True,
+        target_opts={
+            "-keys": "arm_cpu",
+            "-march": "armv7e-m",

Review comment:
       agreed--i think -mcpu was used to key the IsaAnalyzer, correct?

##########
File path: python/tvm/relay/qnn/op/legalizations.py
##########
@@ -374,6 +374,8 @@ def _qnn_conv2d_legalize_arm_cpu(attrs, inputs, types):
         attrs["kernel_layout"],
         attrs["groups"],
     )
+
+    # Use int8 for Cortex-M7

Review comment:
       @sergey-grovety can you revert this comment or fix the set of CPUs indicated?

##########
File path: python/tvm/target/arm_isa.py
##########
@@ -16,18 +16,24 @@
 # under the License.
 """Defines functions to analyze available opcodes in the ARM ISA."""
 
+import argparse
 
 ARM_ISA_MAP = {
-    "armv7e-m": ["SMLAD"],
+    "armv7e-m": ["SMLAD", "SSUB8", "SEL"],
+    "armv8-m": ["SMLAD", "SSUB8", "SEL"],

Review comment:
       same thing here

##########
File path: python/tvm/topi/arm_cpu/conv1d.py
##########
@@ -0,0 +1,36 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name, unused-variable, no-else-return, unused-argument, import-outside-toplevel
+"""Conv1D schedule for ARM CPU"""
+from __future__ import absolute_import as _abs
+
+from tvm import autotvm
+
+from .cortex_m7.conv1d import direct_simd as direct_simd_conv1d
+
+
+@autotvm.register_topi_compute("conv1d_nwc_direct_simd.arm_cpu")
+def conv1d_nwc_direct_simd(cfg, data, kernel, strides, padding, dilation, out_dtype):

Review comment:
       mprofile seems good to me.







[GitHub] [tvm] u99127 commented on a change in pull request #9233: Cortex m7 intrinsic

Posted by GitBox <gi...@apache.org>.
u99127 commented on a change in pull request #9233:
URL: https://github.com/apache/tvm/pull/9233#discussion_r734055663



##########
File path: python/tvm/relay/op/strategy/arm_cpu.py
##########
@@ -49,6 +49,26 @@ def schedule_concatenate_arm_cpu(_, outs, target):
         return topi.arm_cpu.schedule_concatenate(outs)
 
 
+@schedule_pool.register(["arm_cpu", "micro_dev"])
+def schedule_pool_arm_cpu(attrs, outs, target):
+    """schedule pooling ops arm cpu"""
+    layout = attrs.layout
+    isa = arm_isa.IsaAnalyzer(target)
+    avg_pool = isinstance(attrs, relay.op.op_attrs.AvgPool2DAttrs)
+    with target:
+        if (
+            avg_pool
+            and layout in ("NCW", "NCHW")
+            and "SMLAD" in isa

Review comment:
       SMLAD, SSUB8 and SEL are all part of the DSP instructions, and the presence of one implies the presence of the others. Since we are adding all of these together, I also think that globbing them into a single check for the use of the DSP extensions should be sufficient. Is there any reason why we are testing individual instructions?
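
A rough sketch of what a single DSP check could look like, assuming the architecture string is already known. The table `ARCH_HAS_DSP` and the functions `has_dsp` and `supports` are hypothetical; only the `armv7e-m` entry already present in `ARM_ISA_MAP` is listed, and no claim is made here about other architectures.

```python
# Instructions that come together with the DSP extension, per the comment above.
DSP_INSTRUCTIONS = frozenset({"SMLAD", "SSUB8", "SEL"})

# Hypothetical table of architecture strings treated as having DSP.
ARCH_HAS_DSP = {"armv7e-m": True}


def has_dsp(march):
    """Single check that replaces per-instruction membership tests."""
    return ARCH_HAS_DSP.get(march, False)


def supports(march, instruction):
    """Per-instruction query derived from the single DSP flag."""
    return has_dsp(march) and instruction in DSP_INSTRUCTIONS
```

A strategy function would then branch once on `has_dsp(march)` instead of testing `"SMLAD" in isa` in one place and `"SSUB8" in isa` in another.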
   

##########
File path: python/tvm/target/arm_isa.py
##########
@@ -16,18 +16,24 @@
 # under the License.
 """Defines functions to analyze available opcodes in the ARM ISA."""
 
+import argparse
 
 ARM_ISA_MAP = {
-    "armv7e-m": ["SMLAD"],
+    "armv7e-m": ["SMLAD", "SSUB8", "SEL"],

Review comment:
       armv7e-m : DSP ? 
   

##########
File path: python/tvm/topi/arm_cpu/conv1d.py
##########
@@ -0,0 +1,36 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name, unused-variable, no-else-return, unused-argument, import-outside-toplevel
+"""Conv1D schedule for ARM CPU"""
+from __future__ import absolute_import as _abs
+
+from tvm import autotvm
+
+from .cortex_m7.conv1d import direct_simd as direct_simd_conv1d

Review comment:
       Hmmm .. This looks like an odd use ?

##########
File path: python/tvm/topi/arm_cpu/conv1d.py
##########
@@ -0,0 +1,36 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name, unused-variable, no-else-return, unused-argument, import-outside-toplevel
+"""Conv1D schedule for ARM CPU"""
+from __future__ import absolute_import as _abs
+
+from tvm import autotvm
+
+from .cortex_m7.conv1d import direct_simd as direct_simd_conv1d
+
+
+@autotvm.register_topi_compute("conv1d_nwc_direct_simd.arm_cpu")
+def conv1d_nwc_direct_simd(cfg, data, kernel, strides, padding, dilation, out_dtype):

Review comment:
       I think these are better known as DSP instructions rather than SIMD. While these are SIMD instructions on the integer register set, the presence of the MVE instruction set will cause more confusion in the future and thus sticking to consistent names from the architecture would be more appropriate.
   
   Please update this everywhere.

##########
File path: python/tvm/relay/op/strategy/arm_cpu.py
##########
@@ -49,6 +49,26 @@ def schedule_concatenate_arm_cpu(_, outs, target):
         return topi.arm_cpu.schedule_concatenate(outs)
 
 
+@schedule_pool.register(["arm_cpu", "micro_dev"])

Review comment:
       What do we mean by micro_dev here ? 

##########
File path: python/tvm/topi/arm_cpu/cortex_m7/conv1d/__init__.py
##########
@@ -0,0 +1,19 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Conv1d implementations for cortex-m7."""

Review comment:
       If these folders already exist, they should probably be renamed to armv7em/dsp.
   
   Maybe that move is a separate pull request rather than being merged in here.

##########
File path: python/tvm/testing/utils.py
##########
@@ -674,6 +674,18 @@ def requires_opencl(*args):
     return _compose(args, _requires_opencl)
 
 
+def requires_corstone300(*args):
+    """Mark a test as requiring the corstone300 FVP
+
+    Parameters
+    ----------
+    f : function
+        Function to mark
+    """
+    _requires_corstone300 = [pytest.mark.corstone300]

Review comment:
       I think we need a better way of controlling this - possibly something @mousius could comment on here ? 

##########
File path: python/tvm/topi/arm_cpu/cortex_m7/conv1d/direct_simd.py
##########
@@ -0,0 +1,177 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name, no-value-for-parameter
+"""Direct implementation of conv1d."""
+from tvm import autotvm
+from tvm.autotvm.task import deserialize_args
+from tvm import te
+from tvm.topi.utils import simplify, traverse_inline
+from tvm.topi.nn.pad import pad
+from tvm.topi.nn.utils import get_pad_tuple1d
+from tvm.tir.expr import Mul
+
+from ..micro_kernel.gemm import (
+    intrin_gemm_MxKxN,
+    gemm_MxKxN_impl,
+)
+
+
+def conv1d_nwc_direct_simd(*args, **kwargs):
+    """Defines the Cortex-M7 SIMD implementation of conv1d on NWC layout."""

Review comment:
       I think this could well work in general for Armv7em and Armv8m.main and indeed any Cortex-M CPU that implements the DSP instruction set. The biggest win that one would get is in the use of these instructions rather than anything micro-architectural here.
   
   Thus I would suggest trying to model this properly in terms of the ISA . 
   
   
   @Mousius would you have some time to take a look at this ? 

##########
File path: python/tvm/target/arm_isa.py
##########
@@ -16,18 +16,24 @@
 # under the License.
 """Defines functions to analyze available opcodes in the ARM ISA."""
 
+import argparse
 
 ARM_ISA_MAP = {
-    "armv7e-m": ["SMLAD"],
+    "armv7e-m": ["SMLAD", "SSUB8", "SEL"],
+    "armv8-m": ["SMLAD", "SSUB8", "SEL"],

Review comment:
       I think what you want is armv8-m.main. 

##########
File path: python/tvm/topi/arm_cpu/dense.py
##########
@@ -0,0 +1,25 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name, unused-variable, no-else-return, unused-argument, import-outside-toplevel
+"""Dense schedule for ARM CPU"""
+
+from .cortex_m7.dense import direct_simd

Review comment:
       What would happen with this on AArch64, since these schedules are available on both AArch64 and AArch32?

##########
File path: python/tvm/topi/arm_cpu/cortex_m7/conv1d/direct_simd.py
##########
@@ -0,0 +1,177 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name, no-value-for-parameter
+"""Direct implementation of conv1d."""
+from tvm import autotvm
+from tvm.autotvm.task import deserialize_args
+from tvm import te
+from tvm.topi.utils import simplify, traverse_inline
+from tvm.topi.nn.pad import pad
+from tvm.topi.nn.utils import get_pad_tuple1d
+from tvm.tir.expr import Mul
+
+from ..micro_kernel.gemm import (
+    intrin_gemm_MxKxN,
+    gemm_MxKxN_impl,
+)
+
+
+def conv1d_nwc_direct_simd(*args, **kwargs):
+    """Defines the Cortex-M7 SIMD implementation of conv1d on NWC layout."""
+    assert not kwargs, "Do not support kwargs in template function call"
+    args = deserialize_args(args)
+    data, kernel = args[:2]
+    layout = args[-2]
+    cfg = autotvm.get_config()
+    args = [cfg] + args
+    assert layout == "NWC"
+    conv = conv1d_nwc_direct_simd_compute(*args)
+    sched = conv1d_nwc_direct_simd_schedule(cfg, [data, kernel, conv])
+    return sched, [data, kernel, conv]
+
+
+conv1d_nwc_direct_simd.template_key = "direct_simd"
+conv1d_nwc_direct_simd.default_data_layout = "NWC"
+conv1d_nwc_direct_simd.default_kernel_layout = "WOI"
+
+
+def conv1d_nwc_direct_simd_compute(cfg, data, kernel, strides, padding, dilation, out_dtype):
+    """Compute function for Cortex-M7 SIMD implementation of conv1d on NWC layout."""
+    if isinstance(strides, (tuple, list)):
+        strides = strides[0]
+    if isinstance(dilation, (tuple, list)):
+        dilation = dilation[0]
+
+    batch_size, data_width, in_channels = data.shape
+    kernel_size, out_channels, _ = kernel.shape
+
+    # Compute the output shape
+    dilated_kernel_size = (kernel_size - 1) * dilation + 1
+    pad_left, pad_right = get_pad_tuple1d(padding, (dilated_kernel_size,))
+    out_channels = simplify(out_channels)
+    out_width = simplify((data_width - dilated_kernel_size + pad_left + pad_right) // strides + 1)
+
+    # Apply padding
+    pad_before = [0, pad_left, 0]
+    pad_after = [0, pad_right, 0]
+    padded_data = pad(data, pad_before, pad_after, name="padded_data")
+
+    # Compute graph
+    rc = te.reduce_axis((0, in_channels), name="rc")
+    rw = te.reduce_axis((0, kernel_size), name="rw")
+
+    conv = te.compute(
+        (batch_size, out_width, out_channels),
+        lambda b, w, c: te.sum(
+            padded_data[b, w * strides + rw * dilation, rc].astype(out_dtype)
+            * kernel[rw, c, rc].astype(out_dtype),
+            axis=[rw, rc],
+        ),
+        name="conv1d",
+        tag="conv1d_nwc",
+    )
+
+    ###########################
+    # Config Space Definition #
+    ###########################
+    n, ow, co = (
+        cfg.axis(batch_size.value),
+        cfg.axis(out_width.value),
+        cfg.axis(out_channels.value),
+    )
+    kw, ci = (
+        cfg.reduce_axis(kernel_size.value),
+        cfg.reduce_axis(in_channels.value),
+    )
+
+    owo, owi = cfg.define_split("tile_ow", ow, policy="factors", num_outputs=2)
+    cio, cii = cfg.define_split(
+        "tile_ci",
+        ci,
+        policy="factors",
+        num_outputs=2,
+        # TODO: check case with in_channels.value % 4 != 0 with AutoTVM
+        filter=None if cfg.is_fallback else lambda x: x.size[-1] % 4 == 0,
+    )
+    coo, coi = cfg.define_split("tile_co", co, policy="factors", num_outputs=2)
+
+    cfg.define_reorder(
+        "reorder_0_simd",
+        [n, owo, owi, coo, coi, kw, cio, cii],
+        policy="candidate",
+        candidate=[
+            [n, kw, owo, coo, cio, owi, coi, cii],
+            [n, kw, coo, owo, cio, owi, coi, cii],
+            [n, kw, owo, coo, cio, owi, coi, cii],
+            [n, kw, coo, owo, cio, owi, coi, cii],
+        ],
+    )
+
+    cfg.define_knob("auto_unroll_max_step", [0, 2, 4, 8, 16, 32])
+    cfg.define_knob("unroll_explicit", [0, 1])
+
+    if cfg.is_fallback:
+        cfg.fallback_split("tile_ow", [-1, out_width.value])
+        cfg.fallback_split("tile_ci", [-1, in_channels.value])
+        cfg.fallback_split("tile_co", [-1, out_channels.value])
+
+    return conv
+
+
+def conv1d_nwc_direct_simd_schedule(cfg, outs):
+    """Schedule function for Cortex-M7 SIMD implementation of conv1d on NWC layout."""

Review comment:
       Fix comments to reflect that these are using the DSP extensions rather than SIMD.
   
   Perhaps say:
   
   "Schedule function for v7em DSP instructions of conv1d on NWC layout"
   
   Please audit the whole file for such usage and fix it everywhere.

##########
File path: tests/python/integration/test_m7_simd.py
##########
@@ -0,0 +1,355 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import sys
+import numpy as np
+import pytest
+import tvm
+from tvm import relay
+from tests.python.relay.aot.aot_test_utils import (
+    AOTTestModel,
+    AOT_CORSTONE300_RUNNER,
+    generate_ref_data,
+    compile_and_run,
+)
+
+
+@tvm.testing.requires_corstone300
+@pytest.mark.parametrize(
+    "data_shape_nhwc, kernel_size, num_filter, strides, padding, dilation",
+    [
+        ((1, 32, 32, 1), (3, 3), 12, 1, 0, 1),
+        ((1, 32, 10, 3), (3, 3), 16, 1, 0, 1),
+        ((1, 49, 10, 1), (10, 4), 64, (2, 1), (4, 1, 5, 1), 1),
+        ((1, 32, 32, 16), (3, 3), 16, 1, (0, 2, 2, 0), 1),
+        ((1, 32, 32, 16), (3, 3), 16, 1, 0, 1),
+        ((1, 32, 32, 16), (3, 3), 16, 1, 0, 1),
+        ((1, 32, 32, 16), (3, 3), 16, 1, (0, 2, 2, 0), 2),
+        ((1, 32, 32, 16), (3, 3), 16, 1, (1, 1, 2, 2), 2),
+        # bug https://github.com/apache/tvm/issues/9226
+        ((1, 49, 10, 1), (10, 4), 64, (2, 2), (4, 1, 5, 1), 1),
+        # from Visual Wake Word model
+        ((1, 96, 96, 3), (3, 3), 8, (2, 2), (0, 0, 1, 1), 1),
+        # from Image Classification model (one of the MLPerfTiny models)
+        ((1, 16, 16, 32), (1, 1), 64, (2, 2), 0, 1),
+        ((4, 16, 16, 8), (5, 5), 8, 2, (0, 4, 4, 0), 1),
+        ((4, 16, 16, 8), (5, 5), 16, 2, (0, 4, 4, 0), 1),
+        ((4, 16, 16, 8), (5, 5), 8, 2, 0, 1),
+        ((4, 16, 16, 8), (5, 5), 16, 2, 0, 1),
+        ((1, 16, 16, 8), (3, 3), 16, 2, (0, 0, 1, 1), 1),
+        ((1, 16, 16, 8), (3, 3), 16, 2, (1, 1, 2, 2), 1),
+        ((1, 16, 16, 8), (5, 5), 16, 2, (3, 3, 2, 2), 1),
+        ((1, 16, 16, 8), (3, 3), 16, 2, (0, 1, 2, 3), 1),
+    ],
+)
+@pytest.mark.parametrize("dtype", ["int8", "int16"])
+def test_conv2d(data_shape_nhwc, kernel_size, num_filter, strides, padding, dilation, dtype):
+    """Test a subgraph with a single conv2d operator."""
+    ishape = data_shape_nhwc
+    wshape = (*kernel_size, data_shape_nhwc[-1], num_filter)
+
+    weight_data = np.random.randint(low=-10, high=10, size=wshape, dtype=dtype)
+
+    input0 = relay.var("input", relay.TensorType(ishape, dtype))
+    weight0 = relay.const(weight_data)
+    out0 = relay.op.nn.conv2d(
+        input0,
+        weight0,
+        kernel_size=kernel_size,
+        strides=strides,
+        padding=padding,
+        dilation=(dilation, dilation),
+        data_layout="NHWC",
+        kernel_layout="HWIO",
+        out_dtype="int32",
+        out_layout="NHWC",
+    )
+    ref_mod = tvm.IRModule.from_expr(relay.Function([input0], out0))
+
+    input1 = relay.var("input", relay.TensorType(ishape, dtype))
+    weight1 = relay.const(np.moveaxis(weight_data, 2, -1))
+    out1 = relay.op.nn.conv2d(
+        input1,
+        weight1,
+        kernel_size=kernel_size,
+        strides=strides,
+        padding=padding,
+        dilation=(dilation, dilation),
+        data_layout="NHWC",
+        kernel_layout="HWOI",
+        out_dtype="int32",
+        out_layout="NHWC",
+    )
+    mod = tvm.IRModule.from_expr(relay.Function([input1], out1))
+
+    inputs = {"input": np.random.randint(low=-128, high=127, size=ishape, dtype=dtype)}
+    output_list = generate_ref_data(ref_mod, inputs)
+
+    compile_and_run(
+        AOTTestModel(module=mod, inputs=inputs, outputs=output_list),
+        runner=AOT_CORSTONE300_RUNNER,
+        interface_api="c",
+        use_unpacked_api=True,
+        target_opts={
+            "-keys": "arm_cpu",
+            "-march": "armv7e-m",

Review comment:
       I'm a bit confused by the use of -march=armv7e-m rather than -mcpu here.

##########
File path: tests/python/conftest.py
##########
@@ -40,3 +41,26 @@
 
 if tvm.support.libinfo().get("USE_MICRO", "OFF") != "ON":
     collect_ignore.append("unittest/test_micro_transport.py")
+
+
+def pytest_addoption(parser):
+    parser.addoption(
+        "--enable-corstone300-tests",
+        action="store_true",
+        default=False,
+        help="Run Corstone-300 FVP tests",
+    )
+
+
+def pytest_collection_modifyitems(config, items):
+    for item in items:
+        if config.getoption("--enable-corstone300-tests"):
+            if not "corstone300" in item.keywords:
+                item.add_marker(
+                    pytest.mark.skip(reason="Test shold be marked 'corstone300' to run")

Review comment:
       s/shold/should.
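
The visible part of the hook boils down to one predicate: with `--enable-corstone300-tests` set, any test that lacks the `corstone300` keyword is skipped. A minimal sketch of that predicate as a standalone function follows. The name `should_skip` is hypothetical, and the branch taken when the option is off is not shown in the hunk, so it is deliberately left out.

```python
def should_skip(corstone300_enabled, keywords):
    """Mirror the visible branch of the hook.

    Returns True only when the Corstone-300 option is enabled and the test
    does not carry the 'corstone300' keyword.
    """
    return corstone300_enabled and "corstone300" not in keywords
```

Writing the membership test as `"corstone300" not in keywords` is also the more idiomatic spelling of `not "corstone300" in item.keywords`.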




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] sergey-grovety commented on a change in pull request #9233: Cortex m7 intrinsic

Posted by GitBox <gi...@apache.org>.
sergey-grovety commented on a change in pull request #9233:
URL: https://github.com/apache/tvm/pull/9233#discussion_r729871733



##########
File path: python/tvm/target/arm_isa.py
##########
@@ -16,18 +16,24 @@
 # under the License.
 """Defines functions to analyze available opcodes in the ARM ISA."""
 
+import argparse
 
 ARM_ISA_MAP = {
-    "armv7e-m": ["SMLAD"],
+    "armv7e-m": ["SMLAD", "SSUB8", "SEL"],
+    "armv8-m": ["SMLAD", "SSUB8", "SEL"],
 }
 
 
 class IsaAnalyzer(object):
+    """Checks ISA support for given target"""
+
     def __init__(self, target):
         self.target = target
-        # TODO: actually parse -mcpu
-        arch = "armv7e-m"
-        self._isa_map = ARM_ISA_MAP[arch]
+        parser = argparse.ArgumentParser()

Review comment:
       Not sure if I understand you correctly, but _target_ here is a string.
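
For illustration, a minimal sketch of how a string target could be parsed with argparse, assuming it looks like `c -keys=arm_cpu -march=armv7e-m`. The helper names `parse_march` and `supported_instructions` are hypothetical and not taken from the PR; the map copies the entries from the hunk above.

```python
import argparse

# Entries copied from the hunk under review.
ARM_ISA_MAP = {
    "armv7e-m": ["SMLAD", "SSUB8", "SEL"],
    "armv8-m": ["SMLAD", "SSUB8", "SEL"],
}


def parse_march(target):
    """Extract the -march value from a target string, or None if absent."""
    parser = argparse.ArgumentParser(add_help=False, allow_abbrev=False)
    parser.add_argument("-march", default=None)
    # parse_known_args ignores options this parser does not know (e.g. -keys).
    known, _unknown = parser.parse_known_args(str(target).split())
    return known.march


def supported_instructions(target):
    """Look up the instruction list for the target's architecture (empty if unknown)."""
    return ARM_ISA_MAP.get(parse_march(target), [])
```

Falling back to an empty list for an unknown or missing `-march` keeps the lookup from raising `KeyError`; whether that is the right default for the analyzer is a design choice left open here.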




