Posted to commits@tvm.apache.org by me...@apache.org on 2022/04/28 15:07:50 UTC

[tvm] branch main updated: [Hexagon] Add test for depthwise conv2d schedule (#11138)

This is an automated email from the ASF dual-hosted git repository.

mehrdadh pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new efbd721618 [Hexagon] Add test for depthwise conv2d schedule (#11138)
efbd721618 is described below

commit efbd721618c70c4ad9304a9c69be5407b02b013e
Author: Farshid Salemi Parizi <fp...@octoml.ai>
AuthorDate: Thu Apr 28 08:07:42 2022 -0700

    [Hexagon] Add test for depthwise conv2d schedule (#11138)
    
    * Add test for registered schedules - depthwise_conv2d
---
 .../test_hexagon/topi/test_depthwise_conv2d.py     | 298 +++++++++++++++++++++
 1 file changed, 298 insertions(+)

diff --git a/tests/python/contrib/test_hexagon/topi/test_depthwise_conv2d.py b/tests/python/contrib/test_hexagon/topi/test_depthwise_conv2d.py
new file mode 100644
index 0000000000..6343a10f1f
--- /dev/null
+++ b/tests/python/contrib/test_hexagon/topi/test_depthwise_conv2d.py
@@ -0,0 +1,298 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import sys
+
+import numpy as np
+import pytest
+
+import tvm
+import tvm.testing
+import tvm.topi.testing
+
+from tvm import te, topi
+from tvm.topi.utils import get_const_tuple
+from tvm.topi.nn.utils import get_pad_tuple
+from ..conftest import requires_hexagon_toolchain
+
+
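+# Fixed seed so the randomly generated reference data is reproducible across runs.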
+random_seed = tvm.testing.parameter(0)
+
+in_dtype, out_dtype = tvm.testing.parameters(
+    ("float32", "float32"),
+)
+
+
+@tvm.testing.fixture
+def input_shape(layout, batch, in_channel, in_size, filter_shape):
+    if layout == "NCHW":
+        return (batch, in_channel, in_size, in_size)
+    elif layout == "NHWC":
+        return (batch, in_size, in_size, in_channel)
+    elif layout == "NCHWc":
+        oc_block = filter_shape[-1]
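+        # Pick the largest input-channel block size <= oc_block that evenly
+        # divides in_channel.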
+        ic_block = next(bn for bn in range(oc_block, 0, -1) if in_channel % bn == 0)
+        return (batch, in_channel // ic_block, in_size, in_size, ic_block)
+
+
+@tvm.testing.fixture
+def filter_shape(layout, in_channel, channel_multiplier, kernel):
+    filter_channel = in_channel
+    if layout == "NCHW":
+        return (filter_channel, channel_multiplier, kernel, kernel)
+    elif layout == "NHWC":
+        return (kernel, kernel, filter_channel, channel_multiplier)
+    elif layout == "NCHWc":
+        out_channel = in_channel * channel_multiplier
+        # For testing the functionality, we choose an arbitrary block
+        # size that can divide out_channel, regardless of the
+        # performance.
+        oc_block = next(bn for bn in range(16, 0, -1) if out_channel % bn == 0)
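+        # e.g., in_channel=64, channel_multiplier=1 -> out_channel=64, oc_block=16.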
+        return (out_channel // oc_block, 1, kernel, kernel, 1, oc_block)
+
+
+@tvm.testing.fixture
+def scale_shape(layout, in_channel, channel_multiplier, filter_shape):
+    out_channel = in_channel * channel_multiplier
+
+    if layout in ("NCHW", "NHWC"):
+        return (out_channel,)
+
+    if layout == "NCHWc":
+        oc_block = filter_shape[-1]
+        return (out_channel // oc_block, oc_block)
+
+    raise ValueError("Unknown layout {}".format(layout))
+
+
+@tvm.testing.fixture
+def shift_shape(scale_shape):
+    return scale_shape
+
+
+@tvm.testing.fixture(cache_return_value=True)
+def ref_data(
+    random_seed,
+    in_dtype,
+    out_dtype,
+    layout,
+    input_shape,
+    filter_shape,
+    dilation,
+    stride,
+    padding,
+    scale_shape,
+    shift_shape,
+    use_scale_shift,
+    apply_relu,
+):
+    np.random.seed(random_seed)
+
+    # scipy.signal.convolve2d does not support float16 data types, and
+    # the python fallback is too slow for general use.  Computing
+    # ref_data in float32 will have fewer rounding errors than the TVM
+    # float16 compute, but those vary based on schedule anyway.
+    conv_dtype = "float32" if in_dtype == "float16" else in_dtype
+
+    input_np = np.random.uniform(size=input_shape).astype(in_dtype)
+    filter_np = np.random.uniform(size=filter_shape).astype(in_dtype)
+    scale_np = np.random.uniform(size=scale_shape).astype(out_dtype)
+    shift_np = np.random.uniform(size=shift_shape).astype(out_dtype)
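+
+    # Each layout needs a matching dilation tuple for dilate_python and a
+    # broadcast shape so that scale/shift apply along the channel axis.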
+    if layout == "NCHW":
+        np_depthwise_conv2d = tvm.topi.testing.depthwise_conv2d_python_nchw
+        dilation = (1, 1, dilation, dilation)
+        reshape = (1, -1, 1, 1)
+    elif layout == "NHWC":
+        np_depthwise_conv2d = tvm.topi.testing.depthwise_conv2d_python_nhwc
+        dilation = (dilation, dilation, 1, 1)
+        reshape = (1, 1, 1, -1)
+    elif layout == "NCHWc":
+        np_depthwise_conv2d = tvm.topi.testing.depthwise_conv2d_python_nchwc
+        dilation = (1, 1, dilation, dilation, 1, 1)
+        reshape = (1, scale_shape[0], 1, 1, scale_shape[1])
+
+    dilated_filter_np = tvm.topi.testing.dilate_python(filter_np, dilation)
+    output_np = np_depthwise_conv2d(
+        input_np.astype(conv_dtype), dilated_filter_np.astype(conv_dtype), stride, padding
+    ).astype(out_dtype)
+
+    if use_scale_shift:
+        output_np = output_np * scale_np.reshape(reshape) + shift_np.reshape(reshape)
+    if apply_relu:
+        output_np = np.maximum(output_np, 0)
+
+    return (
+        input_np,
+        filter_np,
+        scale_np,
+        shift_np,
+        output_np,
+    )
+
+
+class BaseDepthwiseConv2D:
+    """Provides the test_conv2d test function, to be used by other test classes.
+
+    Test parameter sets are split out into different classes for
+    readability (e.g. the workloads used by mobilenet) and for
+    restrictions (e.g. schedules implemented only for llvm).
+    """
+
+    @requires_hexagon_toolchain
+    def test_conv2d(
+        self,
+        hexagon_session,
+        in_dtype,
+        out_dtype,
+        layout,
+        input_shape,
+        filter_shape,
+        scale_shape,
+        shift_shape,
+        use_scale_shift,
+        apply_relu,
+        batch,
+        in_channel,
+        channel_multiplier,
+        kernel,
+        stride,
+        padding,
+        dilation,
+        ref_data,
+    ):
+        target_hexagon = tvm.target.hexagon("v68")
+
+        # Transform the padding argument from 'str' to 'tuple' to
+        # match the "workload" tuple in TopHub.  The padding_args
+        # indices used for each layout were chosen to reproduce
+        # previous behavior.
+        if dilation == 1:
+            padding_args = get_pad_tuple(padding, (kernel, kernel))
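+            # e.g., get_pad_tuple("SAME", (3, 3)) returns (1, 1, 1, 1) as
+            # (pad_top, pad_left, pad_down, pad_right).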
+            padding_args_i = [0, 1, 2, 3] if layout == "NCHW" else [0, 1]
+            padding_args = [padding_args[i] for i in padding_args_i]
+        else:
+            padding_args = padding
+
+        # placeholder
+        Input = te.placeholder(input_shape, name="Input", dtype=in_dtype)
+        Filter = te.placeholder(filter_shape, name="Filter", dtype=in_dtype)
+        Scale = te.placeholder(scale_shape, name="Scale", dtype=out_dtype)
+        Shift = te.placeholder(shift_shape, name="Shift", dtype=out_dtype)
+
+        if layout == "NCHW":
+            topi_scale_shift = topi.nn.scale_shift_nchw
+            fcompute_args = (Input, Filter, stride, padding_args, dilation, out_dtype)
+
+        elif layout == "NHWC":
+            topi_scale_shift = topi.nn.scale_shift_nhwc
+            fcompute_args = (Input, Filter, stride, padding_args, dilation, out_dtype)
+
+        elif layout == "NCHWc":
+            topi_scale_shift = topi.nn.scale_shift_nchwc
+            in_layout = "NCHW{}c".format(input_shape[-1])
+            out_layout = "NCHW{}c".format(filter_shape[-1])
+            fcompute_args = (
+                Input,
+                Filter,
+                stride,
+                padding,
+                dilation,
+                in_layout,
+                out_layout,
+                out_dtype,
+            )
+
+        with tvm.target.Target(target_hexagon):
+            # Declare, build schedule
+            if layout == "NCHW":
+                fcompute = topi.nn.depthwise_conv2d_nchw
+                fschedule = topi.hexagon.schedule_depthwise_conv2d_nchw
+            elif layout == "NHWC":
+                fcompute = topi.nn.depthwise_conv2d_nhwc
+                fschedule = topi.hexagon.schedule_depthwise_conv2d_nhwc
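+            # NOTE: the NCHWc layout is not wired up here yet; see the TODO at
+            # the end of this file.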
+            C = fcompute(*fcompute_args)
+            if use_scale_shift:
+                C = topi_scale_shift(C, Scale, Shift)
+            if apply_relu:
+                C = topi.nn.relu(C)
+
+            s = fschedule([C])
+
+            # Build and run
+            f = tvm.build(
+                s,
+                [Input, Filter, Scale, Shift, C],
+                tvm.target.Target(target_hexagon, host=target_hexagon),
+            )
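+            # Upload the built module to the Hexagon device through the session.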
+            mod = hexagon_session.load_module(f)
+
+            input_np, filter_np, scale_np, shift_np, output_np = ref_data
+
+            dev = hexagon_session.device
+            input_tvm = tvm.nd.array(input_np, dev)
+            filter_tvm = tvm.nd.array(filter_np, dev)
+            scale_tvm = tvm.nd.array(scale_np, dev)
+            shift_tvm = tvm.nd.array(shift_np, dev)
+            output_tvm = tvm.nd.array(
+                np.zeros(shape=get_const_tuple(C.shape), dtype=C.dtype),
+                dev,
+            )
+
+            mod(input_tvm, filter_tvm, scale_tvm, shift_tvm, output_tvm)
+
+            tol = {"rtol": 1e-4, "atol": 1e-5}
+            tvm.testing.assert_allclose(output_np, output_tvm.numpy(), **tol)
+
+
+class TestDepthwiseConv2D_MobilenetWorkloads(BaseDepthwiseConv2D):
+    """Extra tests to verify functionality for workloads used by mobilenet."""
+
+    layout = tvm.testing.parameter("NCHW", "NHWC")
+    use_scale_shift = tvm.testing.parameter(False, ids=["no_scale_shift"])
+    apply_relu = tvm.testing.parameter(False, ids=["no_relu"])
+
+    batch = tvm.testing.parameter(1)
+    channel_multiplier = tvm.testing.parameter(1)
+    kernel = tvm.testing.parameter(3)
+    padding = tvm.testing.parameter("SAME")
+    dilation = tvm.testing.parameter(1)
+
+    in_channel, in_size, stride = tvm.testing.parameters(
+        (32, 112, 1),
+        (64, 112, 2),
+        (128, 56, 1),
+        (128, 56, 2),
+        (256, 28, 1),
+    )
+
+
+class TestDepthwiseConv2D(BaseDepthwiseConv2D):
+
+    layout = tvm.testing.parameter("NCHW", "NHWC")
+    use_scale_shift = tvm.testing.parameter(True, False, ids=["with_scale_shift", "no_scale_shift"])
+    apply_relu = tvm.testing.parameter(True, False, ids=["with_relu", "no_relu"])
+
+    (batch, in_channel, in_size, channel_multiplier, kernel, stride) = tvm.testing.parameters(
+        (1, 64, 32, 1, 3, 1),
+        (1, 128, 64, 2, 5, 2),
+    )
+    padding = tvm.testing.parameter("VALID")
+    dilation = tvm.testing.parameter(1)
+
+
+# TODO(hexagon-team): add TestDepthwiseConv2D_NCHWc test.
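+#
+# A possible shape for that class, mirroring TestDepthwiseConv2D above
+# (hypothetical until an NCHWc depthwise schedule is registered for Hexagon
+# and the base class selects an NCHWc fcompute/fschedule):
+#
+# class TestDepthwiseConv2D_NCHWc(BaseDepthwiseConv2D):
+#     layout = tvm.testing.parameter("NCHWc")
+#     use_scale_shift = tvm.testing.parameter(True, False, ids=["with_scale_shift", "no_scale_shift"])
+#     apply_relu = tvm.testing.parameter(True, False, ids=["with_relu", "no_relu"])
+#
+#     (batch, in_channel, in_size, channel_multiplier, kernel, stride) = tvm.testing.parameters(
+#         (1, 64, 32, 1, 3, 1),
+#     )
+#     padding = tvm.testing.parameter("VALID")
+#     dilation = tvm.testing.parameter(1)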