You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by tq...@apache.org on 2019/12/26 17:36:38 UTC

[incubator-tvm] branch master updated: [TOPI][AutoTVM] NHWC conv2d templates for ARM (#3859)

This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git


The following commit(s) were added to refs/heads/master by this push:
     new 672b090  [TOPI][AutoTVM] NHWC conv2d templates for ARM (#3859)
672b090 is described below

commit 672b0909dcba37b8adcdf5d963b7629c51033a88
Author: 黎明灰烬 <i...@jackwish.net>
AuthorDate: Fri Dec 27 01:36:31 2019 +0800

    [TOPI][AutoTVM] NHWC conv2d templates for ARM (#3859)
    
    * [AutoTVM][TOPI] NHWC conv2d templates (spatial pack) for ARM
    
    As some frontends (TFLite, for example) use NHWC as the default
    layout, this enables NHWC schedule templates in TOPI and AutoTVM.
    
    * Fix some comments
---
 python/tvm/autotvm/task/topi_integration.py     |   7 +-
 topi/python/topi/arm_cpu/conv2d.py              |  38 +++++-
 topi/python/topi/arm_cpu/conv2d_spatial_pack.py | 160 ++++++++++++++++++++++++
 3 files changed, 201 insertions(+), 4 deletions(-)

diff --git a/python/tvm/autotvm/task/topi_integration.py b/python/tvm/autotvm/task/topi_integration.py
index ce33d3e..8b3ba35 100644
--- a/python/tvm/autotvm/task/topi_integration.py
+++ b/python/tvm/autotvm/task/topi_integration.py
@@ -182,12 +182,15 @@ class TaskExtractEnv:
             args = deserialize_args(args)
             A, W = args[:2]
             layout = args[-2]
-            assert layout == 'NCHW' or layout == 'HWCN', "only support NCHW/HWCN currently"
             C = topi.nn.conv2d(*args, **kwargs)
             if layout == 'NCHW':
                 s = topi.generic.schedule_conv2d_nchw([C])
-            else:
+            elif layout == 'HWCN':
                 s = topi.generic.schedule_conv2d_hwcn([C])
+            elif layout == 'NHWC':
+                s = topi.generic.schedule_conv2d_nhwc([C])
+            else:
+                raise ValueError("Unsupported layout {}".format(layout))
             return s, [A, W, C]
 
         @register("topi_nn_depthwise_conv2d_nchw")
diff --git a/topi/python/topi/arm_cpu/conv2d.py b/topi/python/topi/arm_cpu/conv2d.py
index 6e95de5..673307a 100644
--- a/topi/python/topi/arm_cpu/conv2d.py
+++ b/topi/python/topi/arm_cpu/conv2d.py
@@ -24,7 +24,8 @@ import tvm
 from tvm import autotvm
 import tvm.contrib.nnpack
 
-from ..generic import schedule_conv2d_nchw, schedule_conv2d_winograd_without_weight_transform, \
+from ..generic import schedule_conv2d_nchw, schedule_conv2d_nhwc, \
+                      schedule_conv2d_winograd_without_weight_transform, \
                       schedule_conv2d_winograd_nnpack_without_weight_transform
 from ..util import traverse_inline, get_const_tuple
 from ..nn import dilate, pad, conv2d, conv2d_alter_layout, \
@@ -34,7 +35,9 @@ from ..nn import dilate, pad, conv2d, conv2d_alter_layout, \
 from ..nn.util import get_const_int, get_pad_tuple
 from ..nn.winograd_util import winograd_transform_matrices
 from .conv2d_spatial_pack import conv2d_spatial_pack_nchw, \
-                                 schedule_conv2d_spatial_pack_nchw
+                                 conv2d_spatial_pack_nhwc, \
+                                 schedule_conv2d_spatial_pack_nchw, \
+                                 schedule_conv2d_spatial_pack_nhwc
 
 logger = logging.getLogger('topi')
 
@@ -78,6 +81,9 @@ def conv2d_arm_cpu(cfg, data, kernel, strides, padding, dilation, layout, out_dt
     if layout == 'NCHW':
         return conv2d_spatial_pack_nchw(cfg, data, kernel, strides, padding,
                                         dilation, out_dtype, num_tile=2)
+    elif layout == 'NHWC':
+        return conv2d_spatial_pack_nhwc(cfg, data, kernel, strides, padding,
+                                        dilation, out_dtype)
     else:
         raise ValueError("Unsupported layout {}".format(layout))
 
@@ -136,6 +142,34 @@ def schedule_conv2d_nchw_arm_cpu(cfg, outs):
     traverse_inline(s, outs[0].op, _callback)
     return s
 
+@autotvm.register_topi_schedule(schedule_conv2d_nhwc, 'arm_cpu', ['direct'])
+def schedule_conv2d_nhwc_arm_cpu(cfg, outs):
+    """TOPI schedule callback for NHWC conv2d on arm_cpu ('direct' spatial pack template)
+
+    Parameters
+    ----------
+    cfg: ConfigEntity
+        The config for this template
+
+    outs: Array of Tensor
+        The computation graph description of NHWC conv2d
+        in the format of an array of tensors.
+
+    Returns
+    -------
+    s: Schedule
+        The computation schedule for conv2d.
+    """
+    s = tvm.create_schedule([x.op for x in outs])
+    # only the stage tagged by conv2d_spatial_pack_nhwc gets a custom schedule
+    def _callback(op):
+        if 'spatial_conv_output_NHWC' in op.tag:
+            schedule_conv2d_spatial_pack_nhwc(cfg, s, op, outs[0])
+
+    traverse_inline(s, outs[0].op, _callback)
+    return s
+
+
 @autotvm.register_topi_compute(conv2d, 'arm_cpu', ['winograd'])
 def conv2d_arm_cpu_winograd(cfg, data, kernel, strides, padding, dilation, layout, out_dtype):
     """ TOPI compute callback. Use winograd template """
diff --git a/topi/python/topi/arm_cpu/conv2d_spatial_pack.py b/topi/python/topi/arm_cpu/conv2d_spatial_pack.py
index b566c98..350a022 100644
--- a/topi/python/topi/arm_cpu/conv2d_spatial_pack.py
+++ b/topi/python/topi/arm_cpu/conv2d_spatial_pack.py
@@ -196,3 +196,163 @@ def schedule_conv2d_spatial_pack_nchw(cfg, s, data_vec, kernel_vec,
         s[kernel_vec].parallel(co)
 
     return s
+
+def conv2d_spatial_pack_nhwc(cfg, data, kernel, strides, padding, dilation, out_dtype):
+    """Spatial pack compute for NHWC conv2d; kernel is indexed as [KH][KW][IC][OC]."""
+    out_dtype = out_dtype or data.dtype
+
+    N, IH, IW, IC = get_const_tuple(data.shape)
+    assert len(kernel.shape) == 4, "AlterOpLayout not enabled for NHWC yet"
+    KH, KW, _, OC = get_const_tuple(kernel.shape)
+    # dilation may be a single int or an (h, w) pair
+    if isinstance(dilation, int):
+        dilation_h = dilation_w = dilation
+    else:
+        dilation_h, dilation_w = dilation
+    # extent of the kernel footprint once dilation is applied
+    dilated_kernel_h = (KH - 1) * dilation_h + 1
+    dilated_kernel_w = (KW - 1) * dilation_w + 1
+    # padding is resolved against the dilated kernel size
+    pad_top, pad_left, pad_down, pad_right = \
+            get_pad_tuple(padding, (dilated_kernel_h, dilated_kernel_w))
+    HSTR, WSTR = strides if isinstance(strides, (tuple, list)) else (strides, strides)
+    # output spatial size; only the H and W axes of the NHWC input are padded
+    OH = (IH + pad_top + pad_down - dilated_kernel_h) // HSTR + 1
+    OW = (IW + pad_left + pad_right - dilated_kernel_w) // WSTR + 1
+    data_pad = nn.pad(data, [0, pad_top, pad_left, 0], [0, pad_down, pad_right, 0])
+
+    # ==================== define configuration space ====================
+    n, oc, oh, ow = cfg.axis(N), cfg.axis(OC), cfg.axis(OH), cfg.axis(OW)
+    ic, kh, kw = cfg.reduce_axis(IC), cfg.reduce_axis(KH), cfg.reduce_axis(KW)
+
+    oco, oci = cfg.define_split('tile_co', oc, num_outputs=2)
+    oho, ohi = cfg.define_split('tile_oh', oh, num_outputs=2)
+    owo, owi = cfg.define_split('tile_ow', ow, num_outputs=2)
+    # the schedule must pass conv's axes to 'reorder_conv' in this same order
+    cfg.define_reorder('reorder_conv',
+                       [n, oho, owo, oco, kh, kw, ic, ohi, owi, oci],
+                       policy='candidate', candidate=[
+                           [n, oho, owo, oco, kh, kw, ic, ohi, owi, oci],
+                           [n, oho, owo, oco, ohi, kh, kw, ic, owi, oci],
+                           [n, oho, owo, oco, ohi, kh, kw, owi, ic, oci],
+                           [n, oho, owo, ohi, oco, kh, kw, owi, ic, oci]])
+
+    cfg.define_annotate("ann_reduce", [kh, kw], policy='try_unroll')
+    cfg.define_annotate("ann_spatial", [ohi, owi, oci], policy='try_unroll_vec')
+    # ====================================================================
+    # OCO/OHO/OWO use floor division, so split factors are expected to divide evenly
+    OCI = cfg['tile_co'].size[-1]
+    OHI = cfg['tile_oh'].size[-1]
+    OWI = cfg['tile_ow'].size[-1]
+    OCO = OC // OCI
+    OHO = OH // OHI
+    OWO = OW // OWI
+
+    kvshape = (OCO, KH, KW, IC, OCI)
+    ovshape = (N, OHO, OWO, OCO, OHI, OWI, OCI)
+    oshape = (N, OH, OW, OC)
+
+    if dilation_h != 1 or dilation_w != 1:
+        # undilate input data: one explicit (kh, kw) tap per entry, axes (..., IC, OHI, OWI)
+        dvshape = (N, OHO, OWO, KH, KW, IC, OHI, OWI)
+        data_vec = tvm.compute(dvshape, lambda n, oho, owo, kh, kw, ic, ohi, owi:
+                               data_pad[n][(oho*OHI+ohi)*HSTR+kh*dilation_h]
+                               [(owo*OWI+owi)*WSTR+kw*dilation_w][ic],
+                               name='data_vec_undilated')
+    else:  # pack the input window read by each output tile
+        dvshape = (N, OHO, OWO, KH + (OHI-1)*HSTR, KW + (OWI-1)*WSTR, IC)
+        data_vec = tvm.compute(dvshape, lambda n, oho, owo, ohi, owi, ic:
+                               data_pad[n][oho*OHI*HSTR+ohi][owo*OWI*WSTR+owi][ic],
+                               name='data_vec')
+    kernel_vec = tvm.compute(kvshape, lambda oco, kh, kw, ic, oci: \
+                             kernel[kh][kw][ic][oco*OCI+oci],
+                             name='kernel_vec')
+
+    ic = tvm.reduce_axis((0, IC), name='ic')
+    kh = tvm.reduce_axis((0, KH), name='kh')
+    kw = tvm.reduce_axis((0, KW), name='kw')
+    # dilated case: data_vec axes follow dvshape, i.e. (N, OHO, OWO, KH, KW, IC, OHI, OWI)
+    if dilation_h != 1 or dilation_w != 1:
+        conv = tvm.compute(ovshape, lambda n, oho, owo, oco, ohi, owi, oci: \
+            tvm.sum(data_vec[n, oho, owo, kh, kw, ic, ohi, owi].astype(out_dtype) *
+                    kernel_vec[oco, kh, kw, ic, oci].astype(out_dtype),
+                    axis=[ic, kh, kw]), name='conv')
+    else:
+        conv = tvm.compute(ovshape, lambda n, oho, owo, oco, ohi, owi, oci: \
+            tvm.sum(data_vec[n, oho, owo, ohi*HSTR+kh, owi*WSTR+kw, ic].astype(out_dtype) *
+                    kernel_vec[oco, kh, kw, ic, oci].astype(out_dtype),
+                    axis=[ic, kh, kw]), name='conv')
+    # unpack the tiled conv result back to plain NHWC; this tensor is returned
+    idiv = tvm.indexdiv
+    imod = tvm.indexmod
+    output = tvm.compute(oshape, lambda n, oh, ow, oc:
+                         conv[n][idiv(oh, OHI)][idiv(ow, OWI)][idiv(oc, OCI)]\
+                             [imod(oh, OHI)][imod(ow, OWI)][imod(oc, OCI)],
+                         name='output_unpack', tag='spatial_conv_output_NHWC')
+    return output
+
+def schedule_conv2d_spatial_pack_nhwc(cfg, s, op, output):
+    """Spatial pack schedule for NHWC conv2d; op is the unpack stage, output is outs[0]."""
+    unpack = op.output(0)
+    conv = unpack.op.input_tensors[0]
+    data_vec = conv.op.input_tensors[0]
+    kernel_vec = conv.op.input_tensors[1]
+    # the padding stage feeding data_vec is not scheduled in this function
+    OHI = cfg['tile_oh'].size[-1]
+    OWI = cfg['tile_ow'].size[-1]
+    OCI = cfg['tile_co'].size[-1]
+
+    # schedule unpack/output
+    if output != unpack:
+        s[unpack].compute_inline()
+    n, oh, ow, oc = s[output].op.axis
+    oco, oci = cfg['tile_co'].apply(s, output, oc)
+    oho, ohi = cfg['tile_oh'].apply(s, output, oh)
+    owo, owi = cfg['tile_ow'].apply(s, output, ow)
+    s[output].reorder(n, oho, owo, oco, ohi, owi, oci)
+    cfg['ann_spatial'].apply(s, output, [ohi, owi, oci], axis_lens=[OHI, OWI, OCI],
+                             max_unroll=16, cfg=cfg)
+    cfg.define_knob('compat', [0, 1, 2])  # 0/1: compute_at owo/oco; 2: root
+    if cfg['compat'].val < 2:
+        compat_axis = [owo, oco][cfg['compat'].val] # pylint: disable=R1706
+        s[conv].compute_at(s[output], compat_axis)
+    paxis = s[output].fuse(n, oho)
+    s[output].parallel(paxis)
+
+    # schedule conv (axes passed to 'reorder_conv' in the order define_reorder used)
+    n, oho, owo, oco, ohi, owi, oci = s[conv].op.axis
+    ic, kh, kw = s[conv].op.reduce_axis
+    cfg['reorder_conv'].apply(s, conv, [n, oho, owo, oco, kh, kw, ic, ohi, owi, oci])
+    cfg['ann_reduce'].apply(s, conv, [kh, kw],
+                            axis_lens=[get_const_int(kh.dom.extent),
+                                       get_const_int(kw.dom.extent)],
+                            max_unroll=16,
+                            cfg=cfg)
+    cfg['ann_spatial'].apply(s, conv, [ohi, owi, oci], axis_lens=[OHI, OWI, OCI],
+                             max_unroll=16, cfg=cfg)
+    if cfg['compat'].val < 2:
+        compat_axis = [owo, oco][cfg['compat'].val] # pylint: disable=R1706
+        s[kernel_vec].compute_at(s[conv], compat_axis)
+        s[data_vec].compute_at(s[conv], compat_axis)
+
+    # schedule kernel pack
+    oco, kh, kw, ic, oci = kernel_vec.op.axis
+    s[kernel_vec].vectorize(oci)
+    s[kernel_vec].unroll(ic)
+    if cfg['compat'].val == 2:
+        s[kernel_vec].parallel(oco)
+
+    # schedule data pack
+    if data_vec.op.name == 'data_vec_undilated':
+        n, oho, owo, kh, kw, ic, ohi, owi = s[data_vec].op.axis
+        s[data_vec].vectorize(owi)
+        s[data_vec].unroll(ohi)
+    else:
+        n, oho, owo, ohi, owi, ic = s[data_vec].op.axis
+        s[data_vec].vectorize(ic)
+        s[data_vec].unroll(owi)
+    if cfg['compat'].val == 2:
+        paxis = s[data_vec].fuse(n, oho)
+        s[data_vec].parallel(paxis)
+
+    return s