You are viewing a plain text version of this content. The canonical link for it is here.
Posted to by GitBox <> on 2021/01/06 18:28:13 UTC

[GitHub] [tvm] Wheest commented on a change in pull request #6137: Better grouped convolution for CPU targets

Wheest commented on a change in pull request #6137:

File path: python/tvm/topi/x86/group_conv2d.py
@@ -0,0 +1,371 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
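+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY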
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name,unused-variable,unused-argument,no-member
+# pylint: disable=no-value-for-parameter,import-outside-toplevel
+"""Grouped Spatial Pack Convolution (Group Conv2D) schedule on x86"""
+import tvm
+from tvm import autotvm
+from tvm import te
+from import SplitEntity, OtherOptionEntity
+from .utils import get_fp32_len
+from ..utils import get_const_tuple
+from ..nn.pad import pad
+from .. import tag
+from ..nn.utils import infer_pad
+from ..nn.conv2d import _get_workload as _get_conv2d_workload
def group_conv2d_nchw(data, kernel, strides, padding, dilation, groups, out_dtype):
    """Compute group_conv2d with NCHW layout.

    Every argument is forwarded, in the same order and unchanged, to
    ``group_conv2d_nchw_spatial_pack``, whose result is returned as-is.

    NOTE(review): ``group_conv2d_nchw_spatial_pack`` declares a leading ``cfg``
    parameter that is not passed here; presumably it is supplied by an autotvm
    registration decorator not visible in this view -- confirm.
    """
    forwarded = (data, kernel, strides, padding, dilation, groups, out_dtype)
    return group_conv2d_nchw_spatial_pack(*forwarded)
def schedule_group_conv2d_nchw(outs):
    """Create the schedule for group_conv2d with NCHW layout.

    Parameters
    ----------
    outs :
        Output tensor(s) of the group_conv2d compute graph, forwarded
        unchanged to ``schedule_group_conv2d_nchwc``.

    Returns
    -------
    The schedule produced by ``schedule_group_conv2d_nchwc``.
    """
    # NOTE(review): schedule_group_conv2d_nchwc is declared as (cfg, outs);
    # cfg is presumably injected by an autotvm registration decorator that is
    # not visible here -- confirm.
    return schedule_group_conv2d_nchwc(outs)
def _get_default_config(cfg, data, kernel, strides, padding, groups, out_dtype, layout="NCHW"):
    """
    Fill ``cfg`` with the default (fallback) schedule config for a workload.

    Any symbolic ``tvm.tir.Var`` dimension of ``data`` is replaced by 1 so that a
    fully static placeholder can be turned into a conv2d workload, which is then
    handed to ``_fallback_schedule``. ``groups`` is accepted but not read here.
    """
    concrete_shape = [
        1 if isinstance(extent, tvm.tir.Var) else extent
        for extent in get_const_tuple(data.shape)
    ]
    static_data = te.placeholder(concrete_shape, dtype=data.dtype)
    workload = _get_conv2d_workload(static_data, kernel, strides, padding, out_dtype, layout)
    _fallback_schedule(cfg, workload)
def _fallback_schedule(cfg, wkl):
    """Write a heuristic default schedule for workload ``wkl`` into ``cfg``.

    Chooses:

    * ``oc_bn`` -- the largest divisor of the per-group output-channel count
      that is <= ``get_fp32_len()``;
    * ``ic_bn`` -- the largest divisor of the per-group input-channel count
      that is <= ``oc_bn``;
    * ``reg_n`` -- the largest divisor of the output width that is <= 31;

    and disables ``unroll_kw``.

    Parameters
    ----------
    cfg :
        autotvm config; ``tile_ic``, ``tile_oc``, ``tile_ow`` and ``unroll_kw``
        are assigned in place.
    wkl :
        conv2d workload (fields read: width, wpad, wkernel, wstride, groups,
        in_filter, out_filter).
    """
    simd_width = get_fp32_len()
    # The output *width* depends on the *width* padding. This previously read
    # ``wkl.hpad``, which mis-sizes ``tile_ow`` whenever hpad != wpad.
    wpad = wkl.wpad
    stride_w = wkl.wstride
    # NOTE(review): assumes symmetric left/right padding (2 * wpad) and no
    # dilation; the workload tuple used here appears to carry neither -- confirm.
    out_width = (wkl.width + 2 * wpad - wkl.wkernel) // stride_w + 1

    groups = wkl.groups
    kernels_per_group = wkl.out_filter // groups
    kernel_depth = wkl.in_filter // groups

    # Output-channel block: largest divisor of the per-group output channels
    # that does not exceed the fp32 SIMD width.
    oc_bn = 1
    for bn in range(simd_width, 0, -1):
        if kernels_per_group % bn == 0:
            oc_bn = bn
            break
    # Defensive clamp; only reachable for a degenerate (zero) channel count.
    if oc_bn > kernels_per_group:
        oc_bn = kernels_per_group

    # Input-channel block: largest divisor of the per-group input channels
    # that does not exceed oc_bn.
    ic_bn = 1
    for bn in range(oc_bn, 0, -1):
        if kernel_depth % bn == 0:
            ic_bn = bn
            break
    if ic_bn > kernel_depth:
        ic_bn = kernel_depth

    # Register blocking along the output width: largest divisor <= 31.
    reg_n = 1
    for n in range(31, 0, -1):
        if out_width % n == 0:
            reg_n = n
            break

    cfg["tile_ic"] = SplitEntity([wkl.in_filter // ic_bn, ic_bn])
    cfg["tile_oc"] = SplitEntity([wkl.out_filter // oc_bn, oc_bn])
    cfg["tile_ow"] = SplitEntity([out_width // reg_n, reg_n])
    cfg["unroll_kw"] = OtherOptionEntity(False)
def group_conv2d_nchw_spatial_pack(
    cfg, data, kernel, strides, padding, dilation, groups, out_dtype="float32"
):
    """
    Compute group conv2d with NCHW layout, using GSPC algorithm.

    Parameters
    ----------
    cfg :
        autotvm config; defines the ``tile_ic``, ``tile_oc``, ``tile_ow`` splits
        and the ``unroll_kw`` knob, falling back to defaults when unset.
    data : te.Tensor
        4-D input, unpacked as (batch, in_channel, in_height, in_width).
    kernel : te.Tensor
        4-D weights, unpacked as (out_channel, kernel_depth, k_height, k_width).
    strides : int or (stride_h, stride_w)
    padding : int, (pad_h, pad_w) or (top, left, bottom, right)
    dilation : int or (dilation_h, dilation_w)
    groups : int
        Number of groups; each group produces ``out_channel // groups`` channels.
    out_dtype : str
        Accumulation / output dtype.

    Returns
    -------
    te.Tensor
        Shape (batch, out_channel, out_height, out_width), tagged
        ``"group_conv2d_nchw"``.

    Raises
    ------
    autotvm.task.space.InstantiationError
        If the chosen ``tile_ic`` / ``tile_oc`` block sizes do not evenly divide
        the per-group input / output channel counts.
    """
    assert isinstance(dilation, int) or len(dilation) == 2
    if isinstance(dilation, int):
        dilation_h, dilation_w = dilation, dilation
    else:
        dilation_h, dilation_w = dilation

    assert isinstance(padding, int) or len(padding) == 2 or len(padding) == 4
    if isinstance(padding, int):
        pad_top, pad_left, pad_bottom, pad_right = padding, padding, padding, padding
    elif len(padding) == 2:
        hpad, wpad = padding
        pad_top, pad_bottom = hpad, hpad
        pad_left, pad_right = wpad, wpad
    else:
        pad_top, pad_left, pad_bottom, pad_right = padding

    # Total padding per axis; only used below to decide whether to pad at all.
    hpad = pad_top + pad_bottom
    wpad = pad_left + pad_right

    assert isinstance(strides, int) or len(strides) == 2
    if isinstance(strides, int):
        stride_h, stride_w = strides, strides
    else:
        stride_h, stride_w = strides

    batch_size, in_channel, in_height, in_width = get_const_tuple(data.shape)
    out_channel, kernel_depth, k_height, k_width = get_const_tuple(kernel.shape)

    pad_height = in_height + pad_top + pad_bottom
    pad_width = in_width + pad_left + pad_right

    dilated_kernel_h = (k_height - 1) * dilation_h + 1
    dilated_kernel_w = (k_width - 1) * dilation_w + 1
    out_height = (in_height + pad_top + pad_bottom - dilated_kernel_h) // stride_h + 1
    out_width = (in_width + pad_left + pad_right - dilated_kernel_w) // stride_w + 1

    kernels_per_group = out_channel // groups

    cfg.define_split("tile_ic", in_channel, num_outputs=2)
    cfg.define_split("tile_oc", out_channel, num_outputs=2)
    cfg.define_split("tile_ow", out_width, num_outputs=2, filter=lambda y: y.size[-1] <= 64)
    cfg.define_knob("unroll_kw", [True, False])

    # If no config was set, we can fallback to default config.
    if cfg.is_fallback:
        _get_default_config(
            cfg,
            te.placeholder((batch_size, in_channel, in_height, in_width), dtype=data.dtype),
            te.placeholder(
                (out_channel, in_channel // groups, k_height, k_width), dtype=kernel.dtype
            ),
            strides,
            padding,
            groups,
            out_dtype,
        )

    oc_bn = cfg["tile_oc"].size[-1]
    ic_bn = cfg["tile_ic"].size[-1]

    # The packed layouts below assume the channel blocks tile each group
    # exactly. The splits are defined over the *total* channel counts, so a
    # tuned config may pick a block that does not divide a group. The ``ic``
    # reduction would then index past the end of data_vec/kernel_vec, and the
    # unpack would drop or misplace output channels. Raising
    # InstantiationError lets autotvm skip such configs; the fallback config
    # always picks dividing block sizes.
    if kernel_depth % ic_bn != 0 or kernels_per_group % oc_bn != 0:
        raise autotvm.task.space.InstantiationError(
            "tile_ic/tile_oc block sizes must divide the per-group channel counts"
        )

    # pack data
    DOPAD = hpad != 0 or wpad != 0
    if DOPAD:
        data_pad = pad(
            data, (0, 0, pad_top, pad_left), (0, 0, pad_bottom, pad_right), name="data_pad"
        )
    else:
        data_pad = data

    # data_vec[g, n, C, h, c, w] = data_pad[n, g * kernel_depth + C * ic_bn + c, h, w]
    shape = (groups, batch_size, kernel_depth // ic_bn, pad_height, ic_bn, pad_width)

    data_vec = te.compute(
        shape,
        lambda g, n, C, h, c, w: data_pad[n, C * ic_bn + c + kernel_depth * g, h, w],
        name="data_vec",
    )

    # pack kernel
    shape = (
        groups,
        kernels_per_group // oc_bn,
        kernel_depth // ic_bn,
        k_height,
        k_width,
        ic_bn,
        oc_bn,
    )

    kernel_vec = te.compute(
        shape,
        lambda g, out_channel, in_channel, h, w, ci, co: kernel[
            (out_channel * oc_bn + co + g * kernels_per_group), in_channel * ic_bn + ci, h, w
        ],
        name="kernel_vec",
    )

    # convolution
    oshape = (groups, batch_size, kernels_per_group // oc_bn, out_height, out_width, oc_bn)
    unpack_shape = (batch_size, out_channel, out_height, out_width)

    ic = te.reduce_axis((0, (kernel_depth)), name="ic")
    kh = te.reduce_axis((0, k_height), name="kh")
    kw = te.reduce_axis((0, k_width), name="kw")

    idxmod = tvm.tir.indexmod
    idxdiv = tvm.tir.indexdiv

    conv = te.compute(
        oshape,
        lambda g, n, oc_chunk, oh, ow, oc_block: te.sum(
            data_vec[
                g,
                n,
                idxdiv(ic, ic_bn),
                oh * stride_h + kh * dilation_h,
                idxmod(ic, ic_bn),
                ow * stride_w + kw * dilation_w,
            ].astype(out_dtype)
            * kernel_vec[
                g, oc_chunk, idxdiv(ic, ic_bn), kh, kw, idxmod(ic, ic_bn), oc_block
            ].astype(out_dtype),
            axis=[ic, kh, kw],
        ),
        name="conv",
    )

    # Map the blocked (group, chunk, block) channel layout back to plain NCHW.
    unpack = te.compute(
        unpack_shape,
        lambda n, c, h, w: conv[
            idxdiv(c, kernels_per_group),
            n,
            idxmod(idxdiv(c, oc_bn), (kernels_per_group // oc_bn)),
            h,
            w,
            idxmod(idxmod(c, oc_bn), kernels_per_group),
        ].astype(out_dtype),
        name="output_unpack",
        tag="group_conv2d_nchw",
    )

    return unpack
def schedule_group_conv2d_nchwc(cfg, outs):
    """Create schedule for tensors.

    Builds a ``te`` schedule over ``outs`` and walks the graph from ``outs[0]``.
    Broadcast-tagged ops that are not outputs are inlined. Every op whose tag
    contains ``"group_conv2d_nchw"`` is scheduled by ``_schedule_gspc_nchw``,
    provided ``cfg`` defines ``tile_ic``. Returns the schedule.
    """
    sch = te.create_schedule([out.op for out in outs])
    visited_ops = []

    def traverse(op):
        """Recursively visit ``op`` and the compute ops that feed it."""
        # Inline one-to-one (broadcast) ops into their consumers, but never an
        # op that is itself an output of the schedule.
        if tag.is_broadcast(op.tag):
            if op not in sch.outputs:
                sch[op].compute_inline()
            for producer in op.input_tensors:
                is_compute = isinstance(producer.op, tvm.te.ComputeOp)
                if is_compute and producer.op not in visited_ops:
                    traverse(producer.op)

        if "group_conv2d_nchw" in op.tag:
            output = op.output(0)

            # Without the tuning knobs there is nothing to schedule; note this
            # early exit also skips recording ``op`` as visited.
            if "tile_ic" not in cfg:
                return

            # Producer chain: output_unpack <- conv <- (data_vec, kernel_vec).
            conv_out = op.input_tensors[0]
            kernel_vec = conv_out.op.input_tensors[1]
            kernel = kernel_vec.op.input_tensors[0]
            if isinstance(kernel.op, tvm.te.ComputeOp) and "dilate" in kernel.op.tag:
                sch[kernel].compute_inline()

            data_vec = conv_out.op.input_tensors[0]
            data = data_vec.op.input_tensors[0]
            data_pad = None
            # Look through an explicit padding stage to reach the raw input.
            if isinstance(data.op, tvm.te.ComputeOp) and "pad" in data.op.tag:
                data_pad, data = data, data.op.input_tensors[0]

            _schedule_gspc_nchw(
                sch, cfg, data, data_pad, data_vec, kernel_vec, conv_out, output, outs[0]
            )

        visited_ops.append(op)

    traverse(outs[0].op)
    return sch
+def _schedule_gspc_nchw(s, cfg, data, data_pad, data_vec, kernel_vec, conv_out, output, last):
+    """Schedule GSPC"""
+    ic_bn, oc_bn, reg_n, unroll_kw = (
+        cfg["tile_ic"].size[-1],
+        cfg["tile_oc"].size[-1],
+        cfg["tile_ow"].size[-1],
+        cfg["unroll_kw"].val,
+    )
+    # no stride and padding info here
+    padding = infer_pad(data, data_pad)
+    hpad, wpad = padding

Review comment:
       We only get `hpad` and `wpad` from `nn.utils.infer_pad`, because the full per-side padding cannot be recovered from `data` and `data_pad` alone.  This has no effect on the schedule, since we only use it to compute the `DOPAD` variable.

This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at: