Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2021/10/22 14:03:42 UTC

[GitHub] [tvm] u99127 commented on a change in pull request #9233: Cortex m7 intrinsic

u99127 commented on a change in pull request #9233:
URL: https://github.com/apache/tvm/pull/9233#discussion_r734055663



##########
File path: python/tvm/relay/op/strategy/arm_cpu.py
##########
@@ -49,6 +49,26 @@ def schedule_concatenate_arm_cpu(_, outs, target):
         return topi.arm_cpu.schedule_concatenate(outs)
 
 
+@schedule_pool.register(["arm_cpu", "micro_dev"])
+def schedule_pool_arm_cpu(attrs, outs, target):
+    """schedule pooling ops arm cpu"""
+    layout = attrs.layout
+    isa = arm_isa.IsaAnalyzer(target)
+    avg_pool = isinstance(attrs, relay.op.op_attrs.AvgPool2DAttrs)
+    with target:
+        if (
+            avg_pool
+            and layout in ("NCW", "NCHW")
+            and "SMLAD" in isa

Review comment:
       SMLAD, SSUB8 and SEL are all part of the DSP extension, so the presence of one implies the presence of the others. Since we are adding all of them together, I think a single check for the DSP extension should be sufficient. Is there any reason why we are testing for individual instructions?
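Here is a minimal sketch of the single-check idea. It assumes the ISA table is keyed by -march value and lists architecture extensions rather than opcodes. `FEATURE_MAP`, `IsaFeatures` and `use_dsp_avg_pool` are hypothetical names for illustration, not TVM's actual `arm_isa` API. The only point is that the pooling condition collapses to one `"DSP" in isa` test.

```python
# Hedged sketch: one feature-level check instead of per-opcode checks.
# FEATURE_MAP, IsaFeatures and use_dsp_avg_pool are hypothetical names,
# not TVM's tvm.target.arm_isa API.

FEATURE_MAP = {
    "armv7e-m": frozenset({"DSP"}),
    # Assumption: treated as DSP-capable, as suggested in this review;
    # on real Armv8-M Mainline parts the DSP extension is optional.
    "armv8-m.main": frozenset({"DSP"}),
    # Baseline cannot implement the DSP extension.
    "armv8-m.base": frozenset(),
}


class IsaFeatures:
    """Supports `"DSP" in isa` for a given -march string."""

    def __init__(self, march):
        # Unknown architectures get no features, so DSP schedules are skipped.
        self.features = FEATURE_MAP.get(march, frozenset())

    def __contains__(self, feature):
        return feature in self.features


def use_dsp_avg_pool(march, layout, is_avg_pool):
    """Same shape as the strategy condition in this hunk, with one DSP check."""
    isa = IsaFeatures(march)
    return is_avg_pool and layout in ("NCW", "NCHW") and "DSP" in isa
```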
   

##########
File path: python/tvm/target/arm_isa.py
##########
@@ -16,18 +16,24 @@
 # under the License.
 """Defines functions to analyze available opcodes in the ARM ISA."""
 
+import argparse
 
 ARM_ISA_MAP = {
-    "armv7e-m": ["SMLAD"],
+    "armv7e-m": ["SMLAD", "SSUB8", "SEL"],

Review comment:
       Should this be armv7e-m: DSP, i.e. a single DSP feature rather than a list of opcodes?
   

##########
File path: python/tvm/topi/arm_cpu/conv1d.py
##########
@@ -0,0 +1,36 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name, unused-variable, no-else-return, unused-argument, import-outside-toplevel
+"""Conv1D schedule for ARM CPU"""
+from __future__ import absolute_import as _abs
+
+from tvm import autotvm
+
+from .cortex_m7.conv1d import direct_simd as direct_simd_conv1d

Review comment:
       Hmm, this import looks like an odd use. Is it intended?

##########
File path: python/tvm/topi/arm_cpu/conv1d.py
##########
@@ -0,0 +1,36 @@
+# pylint: disable=invalid-name, unused-variable, no-else-return, unused-argument, import-outside-toplevel
+"""Conv1D schedule for ARM CPU"""
+from __future__ import absolute_import as _abs
+
+from tvm import autotvm
+
+from .cortex_m7.conv1d import direct_simd as direct_simd_conv1d
+
+
+@autotvm.register_topi_compute("conv1d_nwc_direct_simd.arm_cpu")
+def conv1d_nwc_direct_simd(cfg, data, kernel, strides, padding, dilation, out_dtype):

Review comment:
       I think these are better known as DSP instructions rather than SIMD. While they are SIMD instructions on the integer register set, the MVE (M-profile Vector Extension) instruction set will make the SIMD name more confusing in the future, so sticking to the names used by the architecture would be more appropriate.
   
   Please update this everywhere.

##########
File path: python/tvm/relay/op/strategy/arm_cpu.py
##########
@@ -49,6 +49,26 @@ def schedule_concatenate_arm_cpu(_, outs, target):
         return topi.arm_cpu.schedule_concatenate(outs)
 
 
+@schedule_pool.register(["arm_cpu", "micro_dev"])

Review comment:
       What do we mean by micro_dev here?

##########
File path: python/tvm/topi/arm_cpu/cortex_m7/conv1d/__init__.py
##########
@@ -0,0 +1,19 @@
+"""Conv1d implementations for cortex-m7."""

Review comment:
       If these folders already exist, they should probably be renamed to armv7em/dsp.
   
   That move could perhaps be a separate pull request rather than being merged in here.

##########
File path: python/tvm/testing/utils.py
##########
@@ -674,6 +674,18 @@ def requires_opencl(*args):
     return _compose(args, _requires_opencl)
 
 
+def requires_corstone300(*args):
+    """Mark a test as requiring the corstone300 FVP
+
+    Parameters
+    ----------
+    f : function
+        Function to mark
+    """
+    _requires_corstone300 = [pytest.mark.corstone300]

Review comment:
       I think we need a better way of controlling this; possibly something @mousius could comment on here?
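One option assumes a command-line flag is the desired control. It adapts the pytest documentation's recipe for skipping tests based on a command-line option (the `--runslow` example) to the `corstone300` marker and the `--enable-corstone300-tests` flag used in this PR's conftest.py. In this sketch, unmarked tests always run, and marked tests are skipped unless the flag is given. The helper `corstone300_skip_reason` is hypothetical.

```python
# Hedged conftest.py sketch, after pytest's "skip based on command-line option" recipe.
# corstone300_skip_reason is a hypothetical helper kept separate so it is easy to test.


def corstone300_skip_reason(enabled, keywords):
    """Return a skip reason for a test, or None if the test should run."""
    if "corstone300" in keywords and not enabled:
        return "needs --enable-corstone300-tests to run"
    return None


def pytest_addoption(parser):
    parser.addoption(
        "--enable-corstone300-tests",
        action="store_true",
        default=False,
        help="Run Corstone-300 FVP tests",
    )


def pytest_configure(config):
    # Register the marker so that `pytest --strict-markers` accepts it.
    config.addinivalue_line("markers", "corstone300: needs the Corstone-300 FVP")


def pytest_collection_modifyitems(config, items):
    import pytest  # imported here only so the helper above works without pytest

    enabled = config.getoption("--enable-corstone300-tests")
    for item in items:
        reason = corstone300_skip_reason(enabled, item.keywords)
        if reason is not None:
            item.add_marker(pytest.mark.skip(reason=reason))
```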

##########
File path: python/tvm/topi/arm_cpu/cortex_m7/conv1d/direct_simd.py
##########
@@ -0,0 +1,177 @@
+# pylint: disable=invalid-name, no-value-for-parameter
+"""Direct implementation of conv1d."""
+from tvm import autotvm
+from tvm.autotvm.task import deserialize_args
+from tvm import te
+from tvm.topi.utils import simplify, traverse_inline
+from tvm.topi.nn.pad import pad
+from tvm.topi.nn.utils import get_pad_tuple1d
+from tvm.tir.expr import Mul
+
+from ..micro_kernel.gemm import (
+    intrin_gemm_MxKxN,
+    gemm_MxKxN_impl,
+)
+
+
+def conv1d_nwc_direct_simd(*args, **kwargs):
+    """Defines the Cortex-M7 SIMD implementation of conv1d on NWC layout."""

Review comment:
       I think this could well work in general for Armv7E-M and Armv8-M Mainline, and indeed for any Cortex-M CPU that implements the DSP extension. The biggest win here comes from using these instructions rather than from anything micro-architectural.
   
   Thus I would suggest modeling this properly in terms of the ISA.
   
   
   @Mousius, would you have some time to take a look at this?
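This example illustrates why the gain is ISA-level rather than micro-architectural. SMLAD multiplies the two signed 16-bit halves of one register by the corresponding halves of another and adds both products to an accumulator. Each instruction therefore performs two multiply-accumulates on any core that implements the DSP extension. The code below is a plain-Python model of that behaviour, used for a small int16 dot product. The function names are hypothetical. The Q flag that the real instruction sets on overflow is not modelled. int8 data would first need widening to halfwords.

```python
def _s16(x):
    """Interpret the low 16 bits of x as a signed integer."""
    x &= 0xFFFF
    return x - 0x10000 if x & 0x8000 else x


def _wrap32(x):
    """Wrap to a signed 32-bit value, as the destination register would hold it."""
    x &= 0xFFFFFFFF
    return x - 0x100000000 if x & 0x80000000 else x


def smlad(rn, rm, ra):
    """Model of SMLAD: ra + rn.lo * rm.lo + rn.hi * rm.hi on signed halfwords."""
    lo = _s16(rn) * _s16(rm)
    hi = _s16(rn >> 16) * _s16(rm >> 16)
    return _wrap32(ra + lo + hi)


def pack_s16x2(lo, hi):
    """Pack two int16 values into one 32-bit word (lo in bits 15:0)."""
    return ((hi & 0xFFFF) << 16) | (lo & 0xFFFF)


def dot_int16(a, b):
    """Dot product of even-length int16 lists: two MACs per SMLAD call."""
    acc = 0
    for i in range(0, len(a), 2):
        acc = smlad(pack_s16x2(a[i], a[i + 1]), pack_s16x2(b[i], b[i + 1]), acc)
    return acc
```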

##########
File path: python/tvm/target/arm_isa.py
##########
@@ -16,18 +16,24 @@
 # under the License.
 """Defines functions to analyze available opcodes in the ARM ISA."""
 
+import argparse
 
 ARM_ISA_MAP = {
-    "armv7e-m": ["SMLAD"],
+    "armv7e-m": ["SMLAD", "SSUB8", "SEL"],
+    "armv8-m": ["SMLAD", "SSUB8", "SEL"],

Review comment:
       I think what you want is armv8-m.main: Armv8-M has Baseline and Mainline variants, and only Mainline can implement the DSP extension.

##########
File path: python/tvm/topi/arm_cpu/dense.py
##########
@@ -0,0 +1,25 @@
+# pylint: disable=invalid-name, unused-variable, no-else-return, unused-argument, import-outside-toplevel
+"""Dense schedule for ARM CPU"""
+
+from .cortex_m7.dense import direct_simd

Review comment:
       What would happen with this on AArch64, given that these schedules are available on both AArch64 and AArch32?

##########
File path: python/tvm/topi/arm_cpu/cortex_m7/conv1d/direct_simd.py
##########
@@ -0,0 +1,177 @@
+# pylint: disable=invalid-name, no-value-for-parameter
+"""Direct implementation of conv1d."""
+from tvm import autotvm
+from tvm.autotvm.task import deserialize_args
+from tvm import te
+from tvm.topi.utils import simplify, traverse_inline
+from tvm.topi.nn.pad import pad
+from tvm.topi.nn.utils import get_pad_tuple1d
+from tvm.tir.expr import Mul
+
+from ..micro_kernel.gemm import (
+    intrin_gemm_MxKxN,
+    gemm_MxKxN_impl,
+)
+
+
+def conv1d_nwc_direct_simd(*args, **kwargs):
+    """Defines the Cortex-M7 SIMD implementation of conv1d on NWC layout."""
+    assert not kwargs, "Do not support kwargs in template function call"
+    args = deserialize_args(args)
+    data, kernel = args[:2]
+    layout = args[-2]
+    cfg = autotvm.get_config()
+    args = [cfg] + args
+    assert layout == "NWC"
+    conv = conv1d_nwc_direct_simd_compute(*args)
+    sched = conv1d_nwc_direct_simd_schedule(cfg, [data, kernel, conv])
+    return sched, [data, kernel, conv]
+
+
+conv1d_nwc_direct_simd.template_key = "direct_simd"
+conv1d_nwc_direct_simd.default_data_layout = "NWC"
+conv1d_nwc_direct_simd.default_kernel_layout = "WOI"
+
+
+def conv1d_nwc_direct_simd_compute(cfg, data, kernel, strides, padding, dilation, out_dtype):
+    """Compute function for Cortex-M7 SIMD implementation of conv1d on NWC layout."""
+    if isinstance(strides, (tuple, list)):
+        strides = strides[0]
+    if isinstance(dilation, (tuple, list)):
+        dilation = dilation[0]
+
+    batch_size, data_width, in_channels = data.shape
+    kernel_size, out_channels, _ = kernel.shape
+
+    # Compute the output shape
+    dilated_kernel_size = (kernel_size - 1) * dilation + 1
+    pad_left, pad_right = get_pad_tuple1d(padding, (dilated_kernel_size,))
+    out_channels = simplify(out_channels)
+    out_width = simplify((data_width - dilated_kernel_size + pad_left + pad_right) // strides + 1)
+
+    # Apply padding
+    pad_before = [0, pad_left, 0]
+    pad_after = [0, pad_right, 0]
+    padded_data = pad(data, pad_before, pad_after, name="padded_data")
+
+    # Compute graph
+    rc = te.reduce_axis((0, in_channels), name="rc")
+    rw = te.reduce_axis((0, kernel_size), name="rw")
+
+    conv = te.compute(
+        (batch_size, out_width, out_channels),
+        lambda b, w, c: te.sum(
+            padded_data[b, w * strides + rw * dilation, rc].astype(out_dtype)
+            * kernel[rw, c, rc].astype(out_dtype),
+            axis=[rw, rc],
+        ),
+        name="conv1d",
+        tag="conv1d_nwc",
+    )
+
+    ###########################
+    # Config Space Definition #
+    ###########################
+    n, ow, co = (
+        cfg.axis(batch_size.value),
+        cfg.axis(out_width.value),
+        cfg.axis(out_channels.value),
+    )
+    kw, ci = (
+        cfg.reduce_axis(kernel_size.value),
+        cfg.reduce_axis(in_channels.value),
+    )
+
+    owo, owi = cfg.define_split("tile_ow", ow, policy="factors", num_outputs=2)
+    cio, cii = cfg.define_split(
+        "tile_ci",
+        ci,
+        policy="factors",
+        num_outputs=2,
+        # TODO: check case with in_channels.value % 4 != 0 with AutoTVM
+        filter=None if cfg.is_fallback else lambda x: x.size[-1] % 4 == 0,
+    )
+    coo, coi = cfg.define_split("tile_co", co, policy="factors", num_outputs=2)
+
+    cfg.define_reorder(
+        "reorder_0_simd",
+        [n, owo, owi, coo, coi, kw, cio, cii],
+        policy="candidate",
+        candidate=[
+            [n, kw, owo, coo, cio, owi, coi, cii],
+            [n, kw, coo, owo, cio, owi, coi, cii],
+            [n, kw, owo, coo, cio, owi, coi, cii],
+            [n, kw, coo, owo, cio, owi, coi, cii],
+        ],
+    )
+
+    cfg.define_knob("auto_unroll_max_step", [0, 2, 4, 8, 16, 32])
+    cfg.define_knob("unroll_explicit", [0, 1])
+
+    if cfg.is_fallback:
+        cfg.fallback_split("tile_ow", [-1, out_width.value])
+        cfg.fallback_split("tile_ci", [-1, in_channels.value])
+        cfg.fallback_split("tile_co", [-1, out_channels.value])
+
+    return conv
+
+
+def conv1d_nwc_direct_simd_schedule(cfg, outs):
+    """Schedule function for Cortex-M7 SIMD implementation of conv1d on NWC layout."""

Review comment:
       Please fix the comments to reflect that these use the DSP extension rather than SIMD.
   
   Perhaps say:
   
   "Schedule function for v7em DSP instructions of conv1d on NWC layout"
   
   Please audit the whole file for such usage and fix it everywhere.

##########
File path: tests/python/integration/test_m7_simd.py
##########
@@ -0,0 +1,355 @@
+import sys
+import numpy as np
+import pytest
+import tvm
+from tvm import relay
+from tests.python.relay.aot.aot_test_utils import (
+    AOTTestModel,
+    AOT_CORSTONE300_RUNNER,
+    generate_ref_data,
+    compile_and_run,
+)
+
+
+@tvm.testing.requires_corstone300
+@pytest.mark.parametrize(
+    "data_shape_nhwc, kernel_size, num_filter, strides, padding, dilation",
+    [
+        ((1, 32, 32, 1), (3, 3), 12, 1, 0, 1),
+        ((1, 32, 10, 3), (3, 3), 16, 1, 0, 1),
+        ((1, 49, 10, 1), (10, 4), 64, (2, 1), (4, 1, 5, 1), 1),
+        ((1, 32, 32, 16), (3, 3), 16, 1, (0, 2, 2, 0), 1),
+        ((1, 32, 32, 16), (3, 3), 16, 1, 0, 1),
+        ((1, 32, 32, 16), (3, 3), 16, 1, 0, 1),
+        ((1, 32, 32, 16), (3, 3), 16, 1, (0, 2, 2, 0), 2),
+        ((1, 32, 32, 16), (3, 3), 16, 1, (1, 1, 2, 2), 2),
+        # bug https://github.com/apache/tvm/issues/9226
+        ((1, 49, 10, 1), (10, 4), 64, (2, 2), (4, 1, 5, 1), 1),
+        # from Visual Wake Word model
+        ((1, 96, 96, 3), (3, 3), 8, (2, 2), (0, 0, 1, 1), 1),
+        # from Image Classification model (one of the MLPerfTiny models)
+        ((1, 16, 16, 32), (1, 1), 64, (2, 2), 0, 1),
+        ((4, 16, 16, 8), (5, 5), 8, 2, (0, 4, 4, 0), 1),
+        ((4, 16, 16, 8), (5, 5), 16, 2, (0, 4, 4, 0), 1),
+        ((4, 16, 16, 8), (5, 5), 8, 2, 0, 1),
+        ((4, 16, 16, 8), (5, 5), 16, 2, 0, 1),
+        ((1, 16, 16, 8), (3, 3), 16, 2, (0, 0, 1, 1), 1),
+        ((1, 16, 16, 8), (3, 3), 16, 2, (1, 1, 2, 2), 1),
+        ((1, 16, 16, 8), (5, 5), 16, 2, (3, 3, 2, 2), 1),
+        ((1, 16, 16, 8), (3, 3), 16, 2, (0, 1, 2, 3), 1),
+    ],
+)
+@pytest.mark.parametrize("dtype", ["int8", "int16"])
+def test_conv2d(data_shape_nhwc, kernel_size, num_filter, strides, padding, dilation, dtype):
+    """Test a subgraph with a single conv2d operator."""
+    ishape = data_shape_nhwc
+    wshape = (*kernel_size, data_shape_nhwc[-1], num_filter)
+
+    weight_data = np.random.randint(low=-10, high=10, size=wshape, dtype=dtype)
+
+    input0 = relay.var("input", relay.TensorType(ishape, dtype))
+    weight0 = relay.const(weight_data)
+    out0 = relay.op.nn.conv2d(
+        input0,
+        weight0,
+        kernel_size=kernel_size,
+        strides=strides,
+        padding=padding,
+        dilation=(dilation, dilation),
+        data_layout="NHWC",
+        kernel_layout="HWIO",
+        out_dtype="int32",
+        out_layout="NHWC",
+    )
+    ref_mod = tvm.IRModule.from_expr(relay.Function([input0], out0))
+
+    input1 = relay.var("input", relay.TensorType(ishape, dtype))
+    weight1 = relay.const(np.moveaxis(weight_data, 2, -1))
+    out1 = relay.op.nn.conv2d(
+        input1,
+        weight1,
+        kernel_size=kernel_size,
+        strides=strides,
+        padding=padding,
+        dilation=(dilation, dilation),
+        data_layout="NHWC",
+        kernel_layout="HWOI",
+        out_dtype="int32",
+        out_layout="NHWC",
+    )
+    mod = tvm.IRModule.from_expr(relay.Function([input1], out1))
+
+    inputs = {"input": np.random.randint(low=-128, high=127, size=ishape, dtype=dtype)}
+    output_list = generate_ref_data(ref_mod, inputs)
+
+    compile_and_run(
+        AOTTestModel(module=mod, inputs=inputs, outputs=output_list),
+        runner=AOT_CORSTONE300_RUNNER,
+        interface_api="c",
+        use_unpacked_api=True,
+        target_opts={
+            "-keys": "arm_cpu",
+            "-march": "armv7e-m",

Review comment:
       I'm a bit confused by the use of -march=armv7e-m rather than -mcpu here.
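If the intent is to key the schedule on the CPU, the sketch below shows one way to resolve DSP support from -mcpu, with -march as a fallback. `CPU_TABLE`, `DSP_ARCHES`, `has_dsp` and `target_arch` are hypothetical and deliberately small. The CPU-to-architecture pairs are standard Arm facts, but this is not TVM's target-parsing API. Any CPU missing from the table, such as an AArch64 core, resolves to no DSP. That is one possible answer to the AArch64 question raised earlier.

```python
# Hypothetical table: CPU name -> (architecture, implements the DSP extension).
CPU_TABLE = {
    "cortex-m0": ("armv6-m", False),
    "cortex-m3": ("armv7-m", False),
    "cortex-m4": ("armv7e-m", True),
    "cortex-m7": ("armv7e-m", True),
    "cortex-m23": ("armv8-m.base", False),
    # DSP is optional on Cortex-M33; this sketch assumes it is present.
    "cortex-m33": ("armv8-m.main", True),
}

# Hypothetical fallback used when only -march is given.
DSP_ARCHES = {"armv7e-m"}


def has_dsp(mcpu=None, march=None):
    """Prefer -mcpu, fall back to -march; unknown targets get no DSP."""
    if mcpu is not None:
        _, dsp = CPU_TABLE.get(mcpu, (None, False))
        return dsp
    return march in DSP_ARCHES


def target_arch(mcpu):
    """Architecture implied by -mcpu, or None for CPUs not in the table."""
    return CPU_TABLE.get(mcpu, (None, False))[0]
```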

##########
File path: tests/python/conftest.py
##########
@@ -40,3 +41,26 @@
 
 if tvm.support.libinfo().get("USE_MICRO", "OFF") != "ON":
     collect_ignore.append("unittest/test_micro_transport.py")
+
+
+def pytest_addoption(parser):
+    parser.addoption(
+        "--enable-corstone300-tests",
+        action="store_true",
+        default=False,
+        help="Run Corstone-300 FVP tests",
+    )
+
+
+def pytest_collection_modifyitems(config, items):
+    for item in items:
+        if config.getoption("--enable-corstone300-tests"):
+            if not "corstone300" in item.keywords:
+                item.add_marker(
+                    pytest.mark.skip(reason="Test shold be marked 'corstone300' to run")

Review comment:
       s/shold/should.




-- 
This is an automated message from the Apache Git Service.