You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by tq...@apache.org on 2019/12/26 17:36:38 UTC
[incubator-tvm] branch master updated: [TOPI][AutoTVM] NHWC conv2d
templates for ARM (#3859)
This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git
The following commit(s) were added to refs/heads/master by this push:
new 672b090 [TOPI][AutoTVM] NHWC conv2d templates for ARM (#3859)
672b090 is described below
commit 672b0909dcba37b8adcdf5d963b7629c51033a88
Author: 黎明灰烬 <i...@jackwish.net>
AuthorDate: Fri Dec 27 01:36:31 2019 +0800
[TOPI][AutoTVM] NHWC conv2d templates for ARM (#3859)
* [AutoTVM][TOPI] NHWC conv2d templates (spatial pack) for ARM
As some frontends (tflite for example) are using NHWC as the default
layout, we are enabling NHWC schedule templates in TOPI and AutoTVM.
* fix some comments
---
python/tvm/autotvm/task/topi_integration.py | 7 +-
topi/python/topi/arm_cpu/conv2d.py | 38 +++++-
topi/python/topi/arm_cpu/conv2d_spatial_pack.py | 160 ++++++++++++++++++++++++
3 files changed, 201 insertions(+), 4 deletions(-)
diff --git a/python/tvm/autotvm/task/topi_integration.py b/python/tvm/autotvm/task/topi_integration.py
index ce33d3e..8b3ba35 100644
--- a/python/tvm/autotvm/task/topi_integration.py
+++ b/python/tvm/autotvm/task/topi_integration.py
@@ -182,12 +182,15 @@ class TaskExtractEnv:
args = deserialize_args(args)
A, W = args[:2]
layout = args[-2]
- assert layout == 'NCHW' or layout == 'HWCN', "only support NCHW/HWCN currently"
C = topi.nn.conv2d(*args, **kwargs)
if layout == 'NCHW':
s = topi.generic.schedule_conv2d_nchw([C])
- else:
+ elif layout == 'HWCN':
s = topi.generic.schedule_conv2d_hwcn([C])
+ elif layout == 'NHWC':
+ s = topi.generic.schedule_conv2d_nhwc([C])
+ else:
+ raise ValueError("Unsupported layout {}".format(layout))
return s, [A, W, C]
@register("topi_nn_depthwise_conv2d_nchw")
diff --git a/topi/python/topi/arm_cpu/conv2d.py b/topi/python/topi/arm_cpu/conv2d.py
index 6e95de5..673307a 100644
--- a/topi/python/topi/arm_cpu/conv2d.py
+++ b/topi/python/topi/arm_cpu/conv2d.py
@@ -24,7 +24,8 @@ import tvm
from tvm import autotvm
import tvm.contrib.nnpack
-from ..generic import schedule_conv2d_nchw, schedule_conv2d_winograd_without_weight_transform, \
+from ..generic import schedule_conv2d_nchw, schedule_conv2d_nhwc, \
+ schedule_conv2d_winograd_without_weight_transform, \
schedule_conv2d_winograd_nnpack_without_weight_transform
from ..util import traverse_inline, get_const_tuple
from ..nn import dilate, pad, conv2d, conv2d_alter_layout, \
@@ -34,7 +35,9 @@ from ..nn import dilate, pad, conv2d, conv2d_alter_layout, \
from ..nn.util import get_const_int, get_pad_tuple
from ..nn.winograd_util import winograd_transform_matrices
from .conv2d_spatial_pack import conv2d_spatial_pack_nchw, \
- schedule_conv2d_spatial_pack_nchw
+ conv2d_spatial_pack_nhwc, \
+ schedule_conv2d_spatial_pack_nchw, \
+ schedule_conv2d_spatial_pack_nhwc
logger = logging.getLogger('topi')
@@ -78,6 +81,9 @@ def conv2d_arm_cpu(cfg, data, kernel, strides, padding, dilation, layout, out_dt
if layout == 'NCHW':
return conv2d_spatial_pack_nchw(cfg, data, kernel, strides, padding,
dilation, out_dtype, num_tile=2)
+ elif layout == 'NHWC':
+ return conv2d_spatial_pack_nhwc(cfg, data, kernel, strides, padding,
+ dilation, out_dtype)
else:
raise ValueError("Unsupported layout {}".format(layout))
@@ -136,6 +142,34 @@ def schedule_conv2d_nchw_arm_cpu(cfg, outs):
traverse_inline(s, outs[0].op, _callback)
return s
+@autotvm.register_topi_schedule(schedule_conv2d_nhwc, 'arm_cpu', ['direct'])
+def schedule_conv2d_nhwc_arm_cpu(cfg, outs):
+ """TOPI schedule callback for conv2d
+
+ Parameters
+ ----------
+ cfg: ConfigEntity
+ The config for this template
+
+ outs: Array of Tensor
+ The computation graph description of conv2d
+ in the format of an array of tensors.
+
+ Returns
+ -------
+ s: Schedule
+ The computation schedule for conv2d.
+ """
+ s = tvm.create_schedule([x.op for x in outs])
+
+ def _callback(op):
+ if 'spatial_conv_output_NHWC' in op.tag:
+ schedule_conv2d_spatial_pack_nhwc(cfg, s, op, outs[0])
+
+ traverse_inline(s, outs[0].op, _callback)
+ return s
+
+
@autotvm.register_topi_compute(conv2d, 'arm_cpu', ['winograd'])
def conv2d_arm_cpu_winograd(cfg, data, kernel, strides, padding, dilation, layout, out_dtype):
""" TOPI compute callback. Use winograd template """
diff --git a/topi/python/topi/arm_cpu/conv2d_spatial_pack.py b/topi/python/topi/arm_cpu/conv2d_spatial_pack.py
index b566c98..350a022 100644
--- a/topi/python/topi/arm_cpu/conv2d_spatial_pack.py
+++ b/topi/python/topi/arm_cpu/conv2d_spatial_pack.py
@@ -196,3 +196,163 @@ def schedule_conv2d_spatial_pack_nchw(cfg, s, data_vec, kernel_vec,
s[kernel_vec].parallel(co)
return s
+
+def conv2d_spatial_pack_nhwc(cfg, data, kernel, strides, padding, dilation, out_dtype):
+ """Spatial pack compute for Conv2d NHWC"""
+ out_dtype = out_dtype or data.dtype
+
+ N, IH, IW, IC = get_const_tuple(data.shape)
+ assert len(kernel.shape) == 4, "AlterOpLayout not enabled for NHWC yet"
+ KH, KW, _, OC = get_const_tuple(kernel.shape)
+
+ if isinstance(dilation, int):
+ dilation_h = dilation_w = dilation
+ else:
+ dilation_h, dilation_w = dilation
+
+ dilated_kernel_h = (KH - 1) * dilation_h + 1
+ dilated_kernel_w = (KW - 1) * dilation_w + 1
+
+ pad_top, pad_left, pad_down, pad_right = \
+ get_pad_tuple(padding, (dilated_kernel_h, dilated_kernel_w))
+ HSTR, WSTR = strides if isinstance(strides, (tuple, list)) else (strides, strides)
+
+ OH = (IH + pad_top + pad_down - dilated_kernel_h) // HSTR + 1
+ OW = (IW + pad_left + pad_right - dilated_kernel_w) // WSTR + 1
+ data_pad = nn.pad(data, [0, pad_top, pad_left, 0], [0, pad_down, pad_right, 0])
+
+ # ==================== define configuration space ====================
+ n, oc, oh, ow = cfg.axis(N), cfg.axis(OC), cfg.axis(OH), cfg.axis(OW)
+ ic, kh, kw = cfg.reduce_axis(IC), cfg.reduce_axis(KH), cfg.reduce_axis(KW)
+
+ oco, oci = cfg.define_split('tile_co', oc, num_outputs=2)
+ oho, ohi = cfg.define_split('tile_oh', oh, num_outputs=2)
+ owo, owi = cfg.define_split('tile_ow', ow, num_outputs=2)
+
+ cfg.define_reorder('reorder_conv',
+ [n, oho, owo, oco, kh, kw, ic, ohi, owi, oci],
+ policy='candidate', candidate=[
+ [n, oho, owo, oco, kh, kw, ic, ohi, owi, oci],
+ [n, oho, owo, oco, ohi, kh, kw, ic, owi, oci],
+ [n, oho, owo, oco, ohi, kh, kw, owi, ic, oci],
+ [n, oho, owo, ohi, oco, kh, kw, owi, ic, oci]])
+
+ cfg.define_annotate("ann_reduce", [kh, kw], policy='try_unroll')
+ cfg.define_annotate("ann_spatial", [ohi, owi, oci], policy='try_unroll_vec')
+ # ====================================================================
+
+ OCI = cfg['tile_co'].size[-1]
+ OHI = cfg['tile_oh'].size[-1]
+ OWI = cfg['tile_ow'].size[-1]
+ OCO = OC // OCI
+ OHO = OH // OHI
+ OWO = OW // OWI
+
+ kvshape = (OCO, KH, KW, IC, OCI)
+ ovshape = (N, OHO, OWO, OCO, OHI, OWI, OCI)
+ oshape = (N, OH, OW, OC)
+
+ if dilation_h != 1 or dilation_w != 1:
+ # undilate input data
+ dvshape = (N, OHO, OWO, KH, KW, IC, OHI, OWI)
+ data_vec = tvm.compute(dvshape, lambda n, oho, owo, kh, kw, ic, ohi, owi:
+ data_pad[n][(oho*OHI+ohi)*HSTR+kh*dilation_h]
+ [(owo*OWI+owi)*WSTR+kw*dilation_w][ic],
+ name='data_vec_undilated')
+ else:
+ dvshape = (N, OHO, OWO, KH + (OHI-1)*HSTR, KW + (OWI-1)*WSTR, IC)
+ data_vec = tvm.compute(dvshape, lambda n, oho, owo, ohi, owi, ic:
+ data_pad[n][oho*OHI*HSTR+ohi][owo*OWI*WSTR+owi][ic],
+ name='data_vec')
+ kernel_vec = tvm.compute(kvshape, lambda oco, kh, kw, ic, oci: \
+ kernel[kh][kw][ic][oco*OCI+oci],
+ name='kernel_vec')
+
+ ic = tvm.reduce_axis((0, IC), name='ic')
+ kh = tvm.reduce_axis((0, KH), name='kh')
+ kw = tvm.reduce_axis((0, KW), name='kw')
+
+ if dilation_h != 1 or dilation_w != 1:
+ conv = tvm.compute(ovshape, lambda n, oho, owo, oco, ohi, owi, oci: \
+ tvm.sum(data_vec[n, oho, owo, kh, kw, ohi, owi, ic].astype(out_dtype) *
+ kernel_vec[oco, kh, kw, ic, oci].astype(out_dtype),
+ axis=[ic, kh, kw]), name='conv')
+ else:
+ conv = tvm.compute(ovshape, lambda n, oho, owo, oco, ohi, owi, oci: \
+ tvm.sum(data_vec[n, oho, owo, ohi*HSTR+kh, owi*WSTR+kw, ic].astype(out_dtype) *
+ kernel_vec[oco, kh, kw, ic, oci].astype(out_dtype),
+ axis=[ic, kh, kw]), name='conv')
+
+ idiv = tvm.indexdiv
+ imod = tvm.indexmod
+ output = tvm.compute(oshape, lambda n, oho, owo, oc:
+ conv[n][idiv(oho, OHI)][idiv(owo, OWI)][idiv(oc, OCI)]\
+ [imod(oho, OHI)][imod(owo, OWI)][imod(oc, OCI)],
+ name='output_unpack', tag='spatial_conv_output_NHWC')
+ return output
+
+def schedule_conv2d_spatial_pack_nhwc(cfg, s, op, output):
+ """Spatial Pack schedule for Conv2d NHWC"""
+ unpack = op.output(0)
+ conv = unpack.op.input_tensors[0]
+ data_vec = conv.op.input_tensors[0]
+ kernel_vec = conv.op.input_tensors[1]
+ data_pad = data_vec.op.input_tensors[0]
+ OHI = cfg['tile_oh'].size[-1]
+ OWI = cfg['tile_ow'].size[-1]
+ OCI = cfg['tile_co'].size[-1]
+
+ # schedule unpack/output
+ if output != unpack:
+ s[unpack].compute_inline()
+ n, oh, ow, oc = s[output].op.axis
+ oco, oci = cfg['tile_co'].apply(s, output, oc)
+ oho, ohi = cfg['tile_oh'].apply(s, output, oh)
+ owo, owi = cfg['tile_ow'].apply(s, output, ow)
+ s[output].reorder(n, oho, owo, oco, ohi, owi, oci)
+ cfg['ann_spatial'].apply(s, output, [ohi, owi, oci], axis_lens=[OHI, OWI, OCI],
+ max_unroll=16, cfg=cfg)
+ cfg.define_knob('compat', [0, 1, 2])
+ if cfg['compat'].val < 2:
+ compat_axis = [owo, oco][cfg['compat'].val] # pylint: disable=R1706
+ s[conv].compute_at(s[output], compat_axis)
+ paxis = s[output].fuse(n, oho)
+ s[output].parallel(paxis)
+
+ # schedule conv
+ n, oho, owo, oco, ohi, owi, oci = s[conv].op.axis
+ ic, kh, kw = s[conv].op.reduce_axis
+ cfg['reorder_conv'].apply(s, conv, [n, oho, owo, oco, kh, kw, ohi, owi, ic, oci])
+ cfg['ann_reduce'].apply(s, conv, [kh, kw],
+ axis_lens=[get_const_int(kh.dom.extent),
+ get_const_int(kw.dom.extent)],
+ max_unroll=16,
+ cfg=cfg)
+ cfg['ann_spatial'].apply(s, conv, [ohi, owi, oci], axis_lens=[OHI, OWI, OCI],
+ max_unroll=16, cfg=cfg)
+ if cfg['compat'].val < 2:
+ compat_axis = [owo, oco][cfg['compat'].val] # pylint: disable=R1706
+ s[kernel_vec].compute_at(s[conv], compat_axis)
+ s[data_vec].compute_at(s[conv], compat_axis)
+
+ # schedule kernel pack
+ oco, kh, kw, ic, oci = kernel_vec.op.axis
+ s[kernel_vec].vectorize(oci)
+ s[kernel_vec].unroll(ic)
+ if cfg['compat'].val == 2:
+ s[kernel_vec].parallel(oco)
+
+ # schedule data pack
+ if data_vec.op.name == 'data_vec_undilated':
+ n, oho, owo, kh, kw, ic, ohi, owi = s[data_vec].op.axis
+ s[data_vec].vectorize(owi)
+ s[data_vec].unroll(ohi)
+ else:
+ n, oho, owo, ohi, owi, ic = s[data_vec].op.axis
+ s[data_vec].vectorize(ic)
+ s[data_vec].unroll(owi)
+ if cfg['compat'].val == 2:
+ paxis = s[data_vec].fuse(n, oho)
+ s[data_vec].parallel(paxis)
+
+ return s