Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2020/07/20 19:13:25 UTC

[GitHub] [incubator-tvm] giuseros opened a new pull request #6095: Improve NHWC depthwise convolution for AArch64

giuseros opened a new pull request #6095:
URL: https://github.com/apache/incubator-tvm/pull/6095


   We created a default schedule (no auto-tuning or tensorization) named
   `depthwise_conv2d_nhwc`, which does a decent job of optimizing depthwise
   convolution for NHWC layouts on AArch64 architectures.
   
   The schedule lives in `topi/python/topi/arm_cpu/depthwise_conv2d.py`,
   while the strategy is registered in `python/tvm/relay/op/strategy/arm_cpu.py`.
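
   For reference, a rough sketch of what such a strategy registration typically looks like, using the standard relay strategy helpers (`wrap_compute_conv2d`, `wrap_topi_schedule`); this is illustrative, not the exact PR diff, and the surrounding branch logic of the arm_cpu conv2d strategy is omitted:

   ```python
   # Sketch only: register the NHWC depthwise compute/schedule pair
   # under the arm_cpu strategy.
   from tvm import topi
   from tvm.relay.op.strategy.generic import wrap_compute_conv2d, wrap_topi_schedule

   def register_depthwise_nhwc(strategy):
       strategy.add_implementation(
           wrap_compute_conv2d(topi.arm_cpu.compute_depthwise_conv2d_nhwc),
           wrap_topi_schedule(topi.arm_cpu.schedule_depthwise_conv2d_nhwc),
           name="depthwise_conv2d_nhwc.arm_cpu",
       )
   ```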


[GitHub] [incubator-tvm] giuseros commented on pull request #6095: Improve NHWC depthwise convolution for AArch64

giuseros commented on pull request #6095:
URL: https://github.com/apache/incubator-tvm/pull/6095#issuecomment-663233922


   I probably enabled a CUDA test that was making the CI hang. I reverted the test, hoping that was the issue.


[GitHub] [incubator-tvm] FrozenGene commented on a change in pull request #6095: Improve NHWC depthwise convolution for AArch64

FrozenGene commented on a change in pull request #6095:
URL: https://github.com/apache/incubator-tvm/pull/6095#discussion_r459441879



##########
File path: topi/python/topi/arm_cpu/depthwise_conv2d.py
##########
@@ -181,6 +181,170 @@ def depthwise_conv2d_nchw_spatial_pack(cfg, data, kernel, strides, padding, dila
 
     return _decl_spatial_pack(cfg, data, kernel, strides, padding, dilation, out_dtype, num_tile=2)
 
+@autotvm.register_topi_compute("depthwise_conv2d_nhwc.arm_cpu")
+def compute_depthwise_conv2d_nhwc(_, data, kernel, strides, padding, dilation, out_dtype):
+    """TOPI compute callback for depthwise_conv2d nhwc
+
+    Parameters
+    ----------
+    cfg: ConfigEntity
+        The config for this template
+
+    data : tvm.te.Tensor
+        4-D with shape [batch, in_height, in_width, in_channel]
+
+    kernel : tvm.te.Tensor
+        4-D with shape [filter_height, filter_width, in_channel, channel_multiplier]
+
+    strides : list of two ints
+        [stride_height, stride_width]
+
+    padding : list of two ints
+        [pad_height, pad_width]
+
+    dilation : list of two ints
+        [dilation_height, dilation_width]
+
+    out_dtype: str
+        The output type. This is used for mixed precision.
+
+    Returns
+    -------
+    output : tvm.te.Tensor
+        4-D with shape [batch, out_height, out_width, out_channel]
+    """
+
+    out_dtype = out_dtype or data.dtype
+
+    N, IH, IW, IC = get_const_tuple(data.shape)
+
+    if isinstance(dilation, int):
+        dilation_h = dilation_w = dilation
+    else:
+        dilation_h, dilation_w = dilation
+
+    KH, KW, IC, channel_multiplier = get_const_tuple(kernel.shape)
+
+    dilated_kernel_h = (KH - 1) * dilation_h + 1
+    dilated_kernel_w = (KW - 1) * dilation_w + 1
+
+    pad_top, pad_left, pad_down, pad_right = get_pad_tuple(
+        padding, (dilated_kernel_h, dilated_kernel_w))
+    HSTR, WSTR = strides if isinstance(strides, (tuple, list)) else (strides, strides)
+
+    OH = (IH + pad_top + pad_down - dilated_kernel_h) // HSTR + 1
+    OW = (IW + pad_left + pad_right - dilated_kernel_w) // WSTR + 1
+
+    if pad_top or pad_left or pad_down or pad_right:
+        data_pad = nn.pad(data, [0, pad_top, pad_left, 0], [0, pad_down, pad_right, 0],
+                          name="data_pad")
+    else:
+        data_pad = data
+
+    output_shape = (N, OH, OW, IC*channel_multiplier)
+
+    idxdiv = tvm.tir.indexdiv
+    idxmod = tvm.tir.indexmod
+
+    reduce_h = te.reduce_axis((0, KH), name='reduce_h')
+    reduce_w = te.reduce_axis((0, KW), name='reduce_w')
+
+    out = te.compute(output_shape, lambda n, h, w, c:
+                     te.sum(data_pad[n,
+                                     HSTR*h+dilation_h*reduce_h,
+                                     w*WSTR+reduce_w*dilation_w,
+                                     idxdiv(c, channel_multiplier)].astype(out_dtype) *
+                            kernel[reduce_h,
+                                   reduce_w,
+                                   idxdiv(c, channel_multiplier),
+                                   idxmod(c, channel_multiplier)].astype(out_dtype),
+                            axis=[reduce_h, reduce_w]),
+                     name='depthwise_conv2d_nhwc_output')
+    return out
+
+@autotvm.register_topi_schedule("depthwise_conv2d_nhwc.arm_cpu")
+def schedule_depthwise_conv2d_nhwc(cfg, outs):
+    """Create the schedule for depthwise_conv2d_nchw_spatial_pack"""
+    outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs
+    s = te.create_schedule([x.op for x in outs])
+    out = outs[0]
+
+    ##### space definition begin #####
+    n, h, w, c = s[out].op.axis
+    cfg.define_split('tile_c', c, num_outputs=2)
+    _, hi = cfg.define_split('tile_h', h, num_outputs=2)
+    _, wi = cfg.define_split('tile_w', w, num_outputs=2)
+    cfg.define_knob('locate_output', [0, 1])
+
+    # fallback support
+    if cfg.is_fallback:
+        cfg['tile_c'] = SplitEntity([-1, 8])
+        cfg['tile_h'] = SplitEntity([-1, 2])
+        cfg['tile_w'] = SplitEntity([-1, 2])
+        cfg['locate_output'] = OtherOptionEntity(1)
+    ##### space definition end #####
+
+    def schedule_conv(conv):
+        conv_data = conv.op.input_tensors[0]
+
+        n, w, h, c = conv.op.axis
+        r_h, r_w = conv.op.reduce_axis
+        ho, hi = cfg['tile_h'].apply(s, conv, h)
+        wo, wi = cfg['tile_w'].apply(s, conv, w)
+        co, ci = cfg['tile_c'].apply(s, conv, c)
+
+        if conv_data.name == "data_pad":
+            # Define a policy for padding computation

Review comment:
       Let us add an assert: `assert isinstance(conv_data.op, tvm.te.ComputeOp)`
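
   A minimal sketch of where such an assert could sit inside `schedule_conv` (names follow the diff above; illustrative only):

   ```python
   conv_data = conv.op.input_tensors[0]
   if conv_data.name == "data_pad":
       # data_pad should be a ComputeOp produced by nn.pad; guard the
       # padding-policy scheduling below on that assumption
       assert isinstance(conv_data.op, tvm.te.ComputeOp)
       # ... apply the chosen padding policy here ...
   ```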




[GitHub] [incubator-tvm] giuseros commented on a change in pull request #6095: Improve NHWC depthwise convolution for AArch64

giuseros commented on a change in pull request #6095:
URL: https://github.com/apache/incubator-tvm/pull/6095#discussion_r458905799



##########
File path: topi/python/topi/arm_cpu/depthwise_conv2d.py
##########
@@ -181,6 +181,154 @@ def depthwise_conv2d_nchw_spatial_pack(cfg, data, kernel, strides, padding, dila
 
     return _decl_spatial_pack(cfg, data, kernel, strides, padding, dilation, out_dtype, num_tile=2)
 
+@autotvm.register_topi_compute("depthwise_conv2d_nhwc.arm_cpu")
+def compute_depthwise_conv2d_nhwc(_, data, kernel, strides, padding, dilation, out_dtype):
+    """TOPI compute callback for depthwise_conv2d nhwc
+
+    Parameters
+    ----------
+    cfg: ConfigEntity
+        The config for this template
+
+    data : tvm.te.Tensor
+        4-D with shape [batch, in_height, in_width, in_channel]
+
+    kernel : tvm.te.Tensor
+        4-D with shape [filter_height, filter_width, in_channel, channel_multiplier]
+
+    strides : list of two ints
+        [stride_height, stride_width]
+
+    padding : list of two ints
+        [pad_height, pad_width]
+
+    dilation : list of two ints
+        [dilation_height, dilation_width]
+
+    out_dtype: str
+        The output type. This is used for mixed precision.
+
+    Returns
+    -------
+    output : tvm.te.Tensor
+        4-D with shape [batch, out_height, out_width, out_channel]
+    """
+
+    out_dtype = out_dtype or data.dtype
+
+    N, IH, IW, IC = get_const_tuple(data.shape)
+
+    if isinstance(dilation, int):
+        dilation_h = dilation_w = dilation
+    else:
+        dilation_h, dilation_w = dilation
+
+    KH, KW, IC, channel_multiplier = get_const_tuple(kernel.shape)
+
+    dilated_kernel_h = (KH - 1) * dilation_h + 1
+    dilated_kernel_w = (KW - 1) * dilation_w + 1
+
+    pad_top, pad_left, pad_down, pad_right = get_pad_tuple(
+        padding, (dilated_kernel_h, dilated_kernel_w))
+    HSTR, WSTR = strides if isinstance(strides, (tuple, list)) else (strides, strides)
+
+    OH = (IH + pad_top + pad_down - dilated_kernel_h) // HSTR + 1
+    OW = (IW + pad_left + pad_right - dilated_kernel_w) // WSTR + 1
+
+    if pad_top or pad_left:
+        data_pad = nn.pad(data, [0, pad_top, pad_left, 0], [0, pad_down, pad_right, 0],
+                          name="data_pad")
+    else:
+        data_pad = data
+
+    output_shape = (N, OH, OW, IC*channel_multiplier)
+
+    idxdiv = tvm.tir.indexdiv
+    idxmod = tvm.tir.indexmod
+
+    reduce_h = te.reduce_axis((0, KH), name='reduce_h')
+    reduce_w = te.reduce_axis((0, KW), name='reduce_w')
+
+    out = te.compute(output_shape, lambda n, h, w, c:
+                     te.sum(data_pad[n,
+                                     HSTR*h+dilation_h*reduce_h,
+                                     w*WSTR+reduce_w*dilation_w,
+                                     idxdiv(c, channel_multiplier)].astype(out_dtype) *
+                            kernel[reduce_h,
+                                   reduce_w,
+                                   idxdiv(c, channel_multiplier),
+                                   idxmod(c, channel_multiplier)].astype(out_dtype),
+                            axis=[reduce_h, reduce_w]),
+                     name='depthwise_conv2d_nhwc_output')
+    return out
+
+@autotvm.register_topi_schedule("depthwise_conv2d_nhwc.arm_cpu")
+def schedule_depthwise_conv2d_nhwc(cfg, outs):
+    """Create the schedule for depthwise_conv2d_nchw_spatial_pack"""
+    outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs
+    s = te.create_schedule([x.op for x in outs])
+    out = outs[0]
+
+    ##### space definition begin #####
+    n, h, w, c = s[out].op.axis
+    cfg.define_split('tile_c', c, num_outputs=2)
+    _, hi = cfg.define_split('tile_h', h, num_outputs=2)
+    _, wi = cfg.define_split('tile_w', w, num_outputs=2)
+    cfg.define_annotate('locate_output', [hi, wi], 'locate_cache', num_anchor=1)
+
+    # fallback support
+    if cfg.is_fallback:
+        cfg['tile_c'] = SplitEntity([-1, 8])
+        cfg['tile_h'] = SplitEntity([-1, 2])
+        cfg['tile_w'] = SplitEntity([-1, 2])
+        cfg['locate_output'] = AnnotateEntity([1])
+    ##### space definition end #####
+
+    def schedule_conv(conv):
+        conv_data = conv.op.input_tensors[0]
+        if conv_data.name == "data_pad":
+            s[conv_data].compute_inline()

Review comment:
       I would resolve this conversation now (as soon as this is merged I will update the RFC on Discuss).




[GitHub] [incubator-tvm] giuseros commented on a change in pull request #6095: Improve NHWC depthwise convolution for AArch64

giuseros commented on a change in pull request #6095:
URL: https://github.com/apache/incubator-tvm/pull/6095#discussion_r458814745



##########
File path: topi/python/topi/arm_cpu/depthwise_conv2d.py
##########
@@ -181,6 +181,154 @@ def depthwise_conv2d_nchw_spatial_pack(cfg, data, kernel, strides, padding, dila
 
     return _decl_spatial_pack(cfg, data, kernel, strides, padding, dilation, out_dtype, num_tile=2)
 
+@autotvm.register_topi_compute("depthwise_conv2d_nhwc.arm_cpu")
+def compute_depthwise_conv2d_nhwc(_, data, kernel, strides, padding, dilation, out_dtype):
+    """TOPI compute callback for depthwise_conv2d nhwc
+
+    Parameters
+    ----------
+    cfg: ConfigEntity
+        The config for this template
+
+    data : tvm.te.Tensor
+        4-D with shape [batch, in_height, in_width, in_channel]
+
+    kernel : tvm.te.Tensor
+        4-D with shape [filter_height, filter_width, in_channel, channel_multiplier]
+
+    strides : list of two ints
+        [stride_height, stride_width]
+
+    padding : list of two ints
+        [pad_height, pad_width]
+
+    dilation : list of two ints
+        [dilation_height, dilation_width]
+
+    out_dtype: str
+        The output type. This is used for mixed precision.
+
+    Returns
+    -------
+    output : tvm.te.Tensor
+        4-D with shape [batch, out_height, out_width, out_channel]
+    """
+
+    out_dtype = out_dtype or data.dtype
+
+    N, IH, IW, IC = get_const_tuple(data.shape)
+
+    if isinstance(dilation, int):
+        dilation_h = dilation_w = dilation
+    else:
+        dilation_h, dilation_w = dilation
+
+    KH, KW, IC, channel_multiplier = get_const_tuple(kernel.shape)
+
+    dilated_kernel_h = (KH - 1) * dilation_h + 1
+    dilated_kernel_w = (KW - 1) * dilation_w + 1
+
+    pad_top, pad_left, pad_down, pad_right = get_pad_tuple(
+        padding, (dilated_kernel_h, dilated_kernel_w))
+    HSTR, WSTR = strides if isinstance(strides, (tuple, list)) else (strides, strides)
+
+    OH = (IH + pad_top + pad_down - dilated_kernel_h) // HSTR + 1
+    OW = (IW + pad_left + pad_right - dilated_kernel_w) // WSTR + 1
+
+    if pad_top or pad_left:
+        data_pad = nn.pad(data, [0, pad_top, pad_left, 0], [0, pad_down, pad_right, 0],
+                          name="data_pad")
+    else:
+        data_pad = data
+
+    output_shape = (N, OH, OW, IC*channel_multiplier)
+
+    idxdiv = tvm.tir.indexdiv
+    idxmod = tvm.tir.indexmod
+
+    reduce_h = te.reduce_axis((0, KH), name='reduce_h')
+    reduce_w = te.reduce_axis((0, KW), name='reduce_w')
+
+    out = te.compute(output_shape, lambda n, h, w, c:
+                     te.sum(data_pad[n,
+                                     HSTR*h+dilation_h*reduce_h,
+                                     w*WSTR+reduce_w*dilation_w,
+                                     idxdiv(c, channel_multiplier)].astype(out_dtype) *
+                            kernel[reduce_h,
+                                   reduce_w,
+                                   idxdiv(c, channel_multiplier),
+                                   idxmod(c, channel_multiplier)].astype(out_dtype),
+                            axis=[reduce_h, reduce_w]),
+                     name='depthwise_conv2d_nhwc_output')
+    return out
+
+@autotvm.register_topi_schedule("depthwise_conv2d_nhwc.arm_cpu")
+def schedule_depthwise_conv2d_nhwc(cfg, outs):
+    """Create the schedule for depthwise_conv2d_nchw_spatial_pack"""
+    outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs
+    s = te.create_schedule([x.op for x in outs])
+    out = outs[0]
+
+    ##### space definition begin #####
+    n, h, w, c = s[out].op.axis
+    cfg.define_split('tile_c', c, num_outputs=2)
+    _, hi = cfg.define_split('tile_h', h, num_outputs=2)
+    _, wi = cfg.define_split('tile_w', w, num_outputs=2)
+    cfg.define_annotate('locate_output', [hi, wi], 'locate_cache', num_anchor=1)
+
+    # fallback support
+    if cfg.is_fallback:
+        cfg['tile_c'] = SplitEntity([-1, 8])
+        cfg['tile_h'] = SplitEntity([-1, 2])
+        cfg['tile_w'] = SplitEntity([-1, 2])
+        cfg['locate_output'] = AnnotateEntity([1])
+    ##### space definition end #####
+
+    def schedule_conv(conv):
+        conv_data = conv.op.input_tensors[0]
+        if conv_data.name == "data_pad":
+            s[conv_data].compute_inline()

Review comment:
       Since those would be 4 more knobs to add, I extracted and tuned the depthwise operators in mobilenet_v2 with the 4 different padding policies (no-inline, inline, `compute_at{ho,wo}`) and reported the results as a TFlite/TVM time ratio, so higher is better (H/W = spatial size, C = channels, S = stride):
   
   |H/W | C   |S  |  inline                |  compute_at(ho)      |compute_at(wo)       | no-inline            |
   |----|-----|---|------------------------|----------------------|---------------------|----------------------|
   |112 | 96  |2  |  1.452941176470588     |  0.7042857142857142  |0.8355932203389829   | 0.36249999999999993  |
   |56  |144  |1  |  1.7249999999999999    |  0.85                |0.9714285714285715   | 1.38                 |
   |56  |144  |2  |  3.028571428571429     |  1.3187499999999999  |0.45869565217391306  | 1.5214285714285716   |
   |28  |192  |1  |  1.711111111111111     |  0.76                |0.5166666666666667   | 1.409090909090909    |
   |28  |192  |2  |  1.6833333333333333    |  0.5666666666666667  |0.48095238095238096  | 1.442857142857143    |
   |14  |384  |1  |  3.15                  |  1.26                |0.63                 | 0.5727272727272728   |
   |14  |576  |1  |  0.8863636363636364    |  0.97                |0.40625              | 0.527027027027027    |
   |14  |576  |2  |  2.4                   |  0.6857142857142858  |0.7000000000000001   | 0.6714285714285715   |
   |7   |960  |1  |  2.9272727272727272    |  1.211320754716981   |0.9056338028169014   | 1.3416666666666668   |
   
   * Since this is a memory-bound operator, not inlining padding is always going to behave poorly (I think), so I would remove that policy to reduce tuning time. One could argue for dropping the `compute_at` policies as well, but since I didn't try other networks (and for small spatial sizes they seem to run better) I would keep those 3 knobs (instead of 4). An illustrative sketch of the four policies follows below.
   * Except in one case, we are always faster (sometimes a lot faster) than TFlite. Once I am done with this and a few other improvements I will compare with ACL as well.
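
   A sketch of the four padding policies compared in the table, assuming `s` is the schedule, `conv` the depthwise stage, `data_pad` the padding stage, and `ho`/`wo` the outer tile axes; `policy` is a hypothetical selector variable, not code from the PR:

   ```python
   if policy == "inline":
       s[data_pad].compute_inline()           # fold padding into the conv loop nest
   elif policy == "compute_at_ho":
       s[data_pad].compute_at(s[conv], ho)    # materialize padding per h-tile
   elif policy == "compute_at_wo":
       s[data_pad].compute_at(s[conv], wo)    # materialize padding per w-tile
   # "no-inline": leave data_pad as a standalone stage (the TE default)
   ```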
   




[GitHub] [incubator-tvm] FrozenGene commented on a change in pull request #6095: Improve NHWC depthwise convolution for AArch64

FrozenGene commented on a change in pull request #6095:
URL: https://github.com/apache/incubator-tvm/pull/6095#discussion_r458834973



##########
File path: topi/python/topi/arm_cpu/depthwise_conv2d.py
##########
@@ -181,6 +181,154 @@ def depthwise_conv2d_nchw_spatial_pack(cfg, data, kernel, strides, padding, dila
 
     return _decl_spatial_pack(cfg, data, kernel, strides, padding, dilation, out_dtype, num_tile=2)
 
+@autotvm.register_topi_compute("depthwise_conv2d_nhwc.arm_cpu")
+def compute_depthwise_conv2d_nhwc(_, data, kernel, strides, padding, dilation, out_dtype):
+    """TOPI compute callback for depthwise_conv2d nhwc
+
+    Parameters
+    ----------
+    cfg: ConfigEntity
+        The config for this template
+
+    data : tvm.te.Tensor
+        4-D with shape [batch, in_height, in_width, in_channel]
+
+    kernel : tvm.te.Tensor
+        4-D with shape [filter_height, filter_width, in_channel, channel_multiplier]
+
+    strides : list of two ints
+        [stride_height, stride_width]
+
+    padding : list of two ints
+        [pad_height, pad_width]
+
+    dilation : list of two ints
+        [dilation_height, dilation_width]
+
+    out_dtype: str
+        The output type. This is used for mixed precision.
+
+    Returns
+    -------
+    output : tvm.te.Tensor
+        4-D with shape [batch, out_height, out_width, out_channel]
+    """
+
+    out_dtype = out_dtype or data.dtype
+
+    N, IH, IW, IC = get_const_tuple(data.shape)
+
+    if isinstance(dilation, int):
+        dilation_h = dilation_w = dilation
+    else:
+        dilation_h, dilation_w = dilation
+
+    KH, KW, IC, channel_multiplier = get_const_tuple(kernel.shape)
+
+    dilated_kernel_h = (KH - 1) * dilation_h + 1
+    dilated_kernel_w = (KW - 1) * dilation_w + 1
+
+    pad_top, pad_left, pad_down, pad_right = get_pad_tuple(
+        padding, (dilated_kernel_h, dilated_kernel_w))
+    HSTR, WSTR = strides if isinstance(strides, (tuple, list)) else (strides, strides)
+
+    OH = (IH + pad_top + pad_down - dilated_kernel_h) // HSTR + 1
+    OW = (IW + pad_left + pad_right - dilated_kernel_w) // WSTR + 1
+
+    if pad_top or pad_left:
+        data_pad = nn.pad(data, [0, pad_top, pad_left, 0], [0, pad_down, pad_right, 0],
+                          name="data_pad")
+    else:
+        data_pad = data
+
+    output_shape = (N, OH, OW, IC*channel_multiplier)
+
+    idxdiv = tvm.tir.indexdiv
+    idxmod = tvm.tir.indexmod
+
+    reduce_h = te.reduce_axis((0, KH), name='reduce_h')
+    reduce_w = te.reduce_axis((0, KW), name='reduce_w')
+
+    out = te.compute(output_shape, lambda n, h, w, c:
+                     te.sum(data_pad[n,
+                                     HSTR*h+dilation_h*reduce_h,
+                                     w*WSTR+reduce_w*dilation_w,
+                                     idxdiv(c, channel_multiplier)].astype(out_dtype) *
+                            kernel[reduce_h,
+                                   reduce_w,
+                                   idxdiv(c, channel_multiplier),
+                                   idxmod(c, channel_multiplier)].astype(out_dtype),
+                            axis=[reduce_h, reduce_w]),
+                     name='depthwise_conv2d_nhwc_output')
+    return out
+
+@autotvm.register_topi_schedule("depthwise_conv2d_nhwc.arm_cpu")
+def schedule_depthwise_conv2d_nhwc(cfg, outs):
+    """Create the schedule for depthwise_conv2d_nchw_spatial_pack"""
+    outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs
+    s = te.create_schedule([x.op for x in outs])
+    out = outs[0]
+
+    ##### space definition begin #####
+    n, h, w, c = s[out].op.axis
+    cfg.define_split('tile_c', c, num_outputs=2)
+    _, hi = cfg.define_split('tile_h', h, num_outputs=2)
+    _, wi = cfg.define_split('tile_w', w, num_outputs=2)
+    cfg.define_annotate('locate_output', [hi, wi], 'locate_cache', num_anchor=1)
+
+    # fallback support
+    if cfg.is_fallback:
+        cfg['tile_c'] = SplitEntity([-1, 8])
+        cfg['tile_h'] = SplitEntity([-1, 2])
+        cfg['tile_w'] = SplitEntity([-1, 2])
+        cfg['locate_output'] = AnnotateEntity([1])
+    ##### space definition end #####
+
+    def schedule_conv(conv):
+        conv_data = conv.op.input_tensors[0]
+        if conv_data.name == "data_pad":
+            s[conv_data].compute_inline()

Review comment:
       @giuseros Thanks for the detailed experiment! I don't fully understand the data table, though. It seems compute_at doesn't behave better than compute_inline?




[GitHub] [incubator-tvm] giuseros commented on a change in pull request #6095: Improve NHWC depthwise convolution for AArch64

giuseros commented on a change in pull request #6095:
URL: https://github.com/apache/incubator-tvm/pull/6095#discussion_r457940669



##########
File path: src/relay/op/tensor/reduce.cc
##########
@@ -295,7 +295,6 @@ bool ReduceRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
 }
 
 Expr MakeReduce(Expr data, Array<Integer> axis, bool keepdims, bool exclude, String op_name) {
-  std::cout << "making " << op_name << std::endl;

Review comment:
       Sorry, I probably forgot to rebase




[GitHub] [incubator-tvm] FrozenGene commented on pull request #6095: Improve NHWC depthwise convolution for AArch64

FrozenGene commented on pull request #6095:
URL: https://github.com/apache/incubator-tvm/pull/6095#issuecomment-661756915


   > ACL implementation
   
   Hi @giuseros, thanks for the work. I fully understand your purpose and the smooth development path. Since this schedule will be the default NHWC depthwise convolution, my opinion is that we should try to get as close to the best achievable performance as we can. To be clear, I don't mean we must match ACL's ultimate performance before we can merge; optimization is not a one-shot deal. But I think we could enable AutoTVM here to help us reach better performance, and it is worth introducing in this PR.
   
   - This schedule will be applied to both arm32 and arm64, so we shouldn't only consider arm64; AutoTVM can help us avoid this issue.
   
   - A tuning knob for `compute_at` (especially for `data_pad`) could help us solve the parallelism/compute/locality trade-off (we cannot assume the kernel runs on only a single core); see Figure 2 of http://people.csail.mit.edu/jrk/halide-pldi13.pdf for more detail. A sketch of such a knob follows below.
   
   I agree we should limit the number of tuning knobs and keep tuning time reasonable, but if a knob helps performance we should introduce it; otherwise we can leave it out.
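
   For illustration, one hypothetical way to expose that trade-off as an AutoTVM knob (not the PR's exact code; `cfg`, `s`, `data_pad`, `conv`, `ho` and `wo` as in the diffs above):

   ```python
   # Hypothetical knob letting the tuner pick the data_pad policy per workload
   cfg.define_knob("data_pad_policy", ["inline", "compute_at_ho", "compute_at_wo"])
   policy = cfg["data_pad_policy"].val
   if policy == "inline":
       s[data_pad].compute_inline()
   elif policy == "compute_at_ho":
       s[data_pad].compute_at(s[conv], ho)
   else:
       s[data_pad].compute_at(s[conv], wo)
   ```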


[GitHub] [incubator-tvm] FrozenGene commented on pull request #6095: Improve NHWC depthwise convolution for AArch64

FrozenGene commented on pull request #6095:
URL: https://github.com/apache/incubator-tvm/pull/6095#issuecomment-673519634


   Thanks @giuseros, merged now.


[GitHub] [incubator-tvm] giuseros commented on a change in pull request #6095: Improve NHWC depthwise convolution for AArch64

giuseros commented on a change in pull request #6095:
URL: https://github.com/apache/incubator-tvm/pull/6095#discussion_r458842593



##########
File path: topi/python/topi/arm_cpu/depthwise_conv2d.py
##########
@@ -181,6 +181,154 @@ def depthwise_conv2d_nchw_spatial_pack(cfg, data, kernel, strides, padding, dila
 
     return _decl_spatial_pack(cfg, data, kernel, strides, padding, dilation, out_dtype, num_tile=2)
 
+@autotvm.register_topi_compute("depthwise_conv2d_nhwc.arm_cpu")
+def compute_depthwise_conv2d_nhwc(_, data, kernel, strides, padding, dilation, out_dtype):
+    """TOPI compute callback for depthwise_conv2d nhwc
+
+    Parameters
+    ----------
+    cfg: ConfigEntity
+        The config for this template
+
+    data : tvm.te.Tensor
+        4-D with shape [batch, in_height, in_width, in_channel]
+
+    kernel : tvm.te.Tensor
+        4-D with shape [filter_height, filter_width, in_channel, channel_multiplier]
+
+    strides : list of two ints
+        [stride_height, stride_width]
+
+    padding : list of two ints
+        [pad_height, pad_width]
+
+    dilation : list of two ints
+        [dilation_height, dilation_width]
+
+    out_dtype: str
+        The output type. This is used for mixed precision.
+
+    Returns
+    -------
+    output : tvm.te.Tensor
+        4-D with shape [batch, out_height, out_width, out_channel]
+    """
+
+    out_dtype = out_dtype or data.dtype
+
+    N, IH, IW, IC = get_const_tuple(data.shape)
+
+    if isinstance(dilation, int):
+        dilation_h = dilation_w = dilation
+    else:
+        dilation_h, dilation_w = dilation
+
+    KH, KW, IC, channel_multiplier = get_const_tuple(kernel.shape)
+
+    dilated_kernel_h = (KH - 1) * dilation_h + 1
+    dilated_kernel_w = (KW - 1) * dilation_w + 1
+
+    pad_top, pad_left, pad_down, pad_right = get_pad_tuple(
+        padding, (dilated_kernel_h, dilated_kernel_w))
+    HSTR, WSTR = strides if isinstance(strides, (tuple, list)) else (strides, strides)
+
+    OH = (IH + pad_top + pad_down - dilated_kernel_h) // HSTR + 1
+    OW = (IW + pad_left + pad_right - dilated_kernel_w) // WSTR + 1
+
+    if pad_top or pad_left:
+        data_pad = nn.pad(data, [0, pad_top, pad_left, 0], [0, pad_down, pad_right, 0],
+                          name="data_pad")
+    else:
+        data_pad = data
+
+    output_shape = (N, OH, OW, IC*channel_multiplier)
+
+    idxdiv = tvm.tir.indexdiv
+    idxmod = tvm.tir.indexmod
+
+    reduce_h = te.reduce_axis((0, KH), name='reduce_h')
+    reduce_w = te.reduce_axis((0, KW), name='reduce_w')
+
+    out = te.compute(output_shape, lambda n, h, w, c:
+                     te.sum(data_pad[n,
+                                     HSTR*h+dilation_h*reduce_h,
+                                     w*WSTR+reduce_w*dilation_w,
+                                     idxdiv(c, channel_multiplier)].astype(out_dtype) *
+                            kernel[reduce_h,
+                                   reduce_w,
+                                   idxdiv(c, channel_multiplier),
+                                   idxmod(c, channel_multiplier)].astype(out_dtype),
+                            axis=[reduce_h, reduce_w]),
+                     name='depthwise_conv2d_nhwc_output')
+    return out
+
+@autotvm.register_topi_schedule("depthwise_conv2d_nhwc.arm_cpu")
+def schedule_depthwise_conv2d_nhwc(cfg, outs):
+    """Create the schedule for depthwise_conv2d_nchw_spatial_pack"""
+    outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs
+    s = te.create_schedule([x.op for x in outs])
+    out = outs[0]
+
+    ##### space definition begin #####
+    n, h, w, c = s[out].op.axis
+    cfg.define_split('tile_c', c, num_outputs=2)
+    _, hi = cfg.define_split('tile_h', h, num_outputs=2)
+    _, wi = cfg.define_split('tile_w', w, num_outputs=2)
+    cfg.define_annotate('locate_output', [hi, wi], 'locate_cache', num_anchor=1)
+
+    # fallback support
+    if cfg.is_fallback:
+        cfg['tile_c'] = SplitEntity([-1, 8])
+        cfg['tile_h'] = SplitEntity([-1, 2])
+        cfg['tile_w'] = SplitEntity([-1, 2])
+        cfg['locate_output'] = AnnotateEntity([1])
+    ##### space definition end #####
+
+    def schedule_conv(conv):
+        conv_data = conv.op.input_tensors[0]
+        if conv_data.name == "data_pad":
+            s[conv_data].compute_inline()

Review comment:
       No problem! I tend to be very conservative when it comes to adding new knobs (because I try to keep tuning time low), so I prefer to experiment a bit before adding one.
   
   Anyway, your suggestions made everything go faster than TFlite, so I am the one who is thankful :)
   
   About `compute_at`: yes, in general it seems to go slower, but for very small spatial sizes (see the row with height/width=14, stride=2 and 576 channels) it brings improvements. This is why in the end I kept both `compute_at` policies. I can always come back later with more evidence and remove one of them.
   
   This was tested on multiple cores (4 threads). I ran fewer experiments on a single core, but it seems to behave the same.




[GitHub] [incubator-tvm] FrozenGene commented on pull request #6095: Improve NHWC depthwise convolution for AArch64

FrozenGene commented on pull request #6095:
URL: https://github.com/apache/incubator-tvm/pull/6095#issuecomment-662241492


   > Hi @FrozenGene ,
   > 
   > Let me thank you for the review and the pointers (the Halide paper is quite interesting)!
   > 
   > I tried most of your suggestions and all I got was a tiny improvement on `mobilenetv2` but a significant increase in tuning time.
   > 
   > Since this looks like it will take more time for more analysis I would prefer if we could take this PR in as is and I can follow up with further improvements in the future.
   
   I agree. We can drop a tuning knob if we find it has no effect.


[GitHub] [incubator-tvm] giuseros commented on pull request #6095: Improve NHWC depthwise convolution for AArch64

giuseros commented on pull request #6095:
URL: https://github.com/apache/incubator-tvm/pull/6095#issuecomment-661282843


   cc @u99127 @anijain2305 @FrozenGene 


[GitHub] [incubator-tvm] giuseros commented on a change in pull request #6095: Improve NHWC depthwise convolution for AArch64

giuseros commented on a change in pull request #6095:
URL: https://github.com/apache/incubator-tvm/pull/6095#discussion_r458145956



##########
File path: topi/python/topi/arm_cpu/depthwise_conv2d.py
##########
@@ -181,6 +180,130 @@ def depthwise_conv2d_nchw_spatial_pack(cfg, data, kernel, strides, padding, dila
 
     return _decl_spatial_pack(cfg, data, kernel, strides, padding, dilation, out_dtype, num_tile=2)
 
+@autotvm.register_topi_compute("depthwise_conv2d_nhwc.arm_cpu")
+def compute_depthwise_conv2d_nhwc(_, data, kernel, strides, padding, dilation, out_dtype):
+    """TOPI compute callback for depthwise_conv2d nhwc
+
+    Parameters
+    ----------
+    cfg: ConfigEntity
+        The config for this template
+
+    data : tvm.te.Tensor
+        4-D with shape [batch, in_height, in_width, in_channel]
+
+    kernel : tvm.te.Tensor
+        4-D with shape [filter_height, filter_width, in_channel, channel_multiplier]
+
+    strides : list of two ints
+        [stride_height, stride_width]
+
+    padding : list of two ints
+        [pad_height, pad_width]
+
+    dilation : list of two ints
+        [dilation_height, dilation_width]
+
+    out_dtype: str
+        The output type. This is used for mixed precision.
+
+    Returns
+    -------
+    output : tvm.te.Tensor
+        4-D with shape [batch, out_height, out_width, out_channel]
+    """
+
+    out_dtype = out_dtype or data.dtype
+
+    N, IH, IW, IC = get_const_tuple(data.shape)
+
+    if isinstance(dilation, int):
+        dilation_h = dilation_w = dilation
+    else:
+        dilation_h, dilation_w = dilation
+
+    KH, KW, IC, channel_multiplier = get_const_tuple(kernel.shape)
+
+    dilated_kernel_h = (KH - 1) * dilation_h + 1
+    dilated_kernel_w = (KW - 1) * dilation_w + 1
+
+    pad_top, pad_left, pad_down, pad_right = get_pad_tuple(
+        padding, (dilated_kernel_h, dilated_kernel_w))
+    HSTR, WSTR = strides if isinstance(strides, (tuple, list)) else (strides, strides)
+
+    OH = (IH + pad_top + pad_down - dilated_kernel_h) // HSTR + 1
+    OW = (IW + pad_left + pad_right - dilated_kernel_w) // WSTR + 1
+
+    if pad_top or pad_left:
+        data_pad = nn.pad(data, [0, pad_top, pad_left, 0], [0, pad_down, pad_right, 0],
+                          name="data_pad")
+    else:
+        data_pad = data
+
+    output_shape = (N, OH, OW, IC*channel_multiplier)
+
+    idxdiv = tvm.tir.indexdiv
+    idxmod = tvm.tir.indexmod
+
+    reduce_h = te.reduce_axis((0, KH), name='reduce_h')
+    reduce_w = te.reduce_axis((0, KW), name='reduce_w')
+
+    out = te.compute(output_shape, lambda n, h, w, c:
+                     te.sum(data_pad[n,
+                                     HSTR*h+dilation_h*reduce_h,
+                                     w*WSTR+reduce_w*dilation_w,
+                                     idxdiv(c, channel_multiplier)].astype(out_dtype) *
+                            kernel[reduce_h,
+                                   reduce_w,
+                                   idxdiv(c, channel_multiplier),
+                                   idxmod(c, channel_multiplier)].astype(out_dtype),
+                            axis=[reduce_h, reduce_w]),
+                     name='depthwise_conv2d_nhwc_output')
+
+    return out
+
+@autotvm.register_topi_schedule("depthwise_conv2d_nhwc.arm_cpu")
+def schedule_depthwise_conv2d_nhwc(_, outs):
+    """Create the schedule for depthwise_conv2d_nchw_spatial_pack"""
+    outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs
+    s = te.create_schedule([x.op for x in outs])
+    out = outs[0]
+
+    def schedule_conv(conv):
+        n, w, h, c = conv.op.axis
+        r_h, r_w = conv.op.reduce_axis
+        co, ci = s[conv].split(c, 8)
+        wo, wi = s[conv].split(w, 2)
+        ho, hi = s[conv].split(h, 2)
+
+        s[conv].reorder(n, wo, ho, co, wi, hi, r_h, r_w, ci)
+        s[conv].parallel(wo)

Review comment:
       Actually the `ho, wo` reordering gave me a 5% improvement. Thanks!
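
   For reference, an illustrative before/after of the reordering being discussed (the exact final loop order is in the merged schedule, not quoted here):

   ```python
   # before: s[conv].reorder(n, wo, ho, co, wi, hi, r_h, r_w, ci)
   s[conv].reorder(n, ho, wo, co, hi, wi, r_h, r_w, ci)  # ho/wo (and hi/wi) swapped
   ```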




[GitHub] [incubator-tvm] FrozenGene merged pull request #6095: Improve NHWC depthwise convolution for AArch64

FrozenGene merged pull request #6095:
URL: https://github.com/apache/incubator-tvm/pull/6095


   


[GitHub] [incubator-tvm] FrozenGene commented on a change in pull request #6095: Improve NHWC depthwise convolution for AArch64

FrozenGene commented on a change in pull request #6095:
URL: https://github.com/apache/incubator-tvm/pull/6095#discussion_r457806536



##########
File path: src/relay/op/tensor/reduce.cc
##########
@@ -295,7 +295,6 @@ bool ReduceRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
 }
 
 Expr MakeReduce(Expr data, Array<Integer> axis, bool keepdims, bool exclude, String op_name) {
-  std::cout << "making " << op_name << std::endl;

Review comment:
       This should already be fixed by PR https://github.com/apache/incubator-tvm/pull/6072; could you update your code to the latest master?

##########
File path: topi/python/topi/arm_cpu/depthwise_conv2d.py
##########
@@ -181,6 +180,130 @@ def depthwise_conv2d_nchw_spatial_pack(cfg, data, kernel, strides, padding, dila
 
     return _decl_spatial_pack(cfg, data, kernel, strides, padding, dilation, out_dtype, num_tile=2)
 
+@autotvm.register_topi_compute("depthwise_conv2d_nhwc.arm_cpu")
+def compute_depthwise_conv2d_nhwc(_, data, kernel, strides, padding, dilation, out_dtype):
+    """TOPI compute callback for depthwise_conv2d nhwc
+
+    Parameters
+    ----------
+    cfg: ConfigEntity
+        The config for this template
+
+    data : tvm.te.Tensor
+        4-D with shape [batch, in_height, in_width, in_channel]
+
+    kernel : tvm.te.Tensor
+        4-D with shape [filter_height, filter_width, in_channel, channel_multiplier]
+
+    strides : list of two ints
+        [stride_height, stride_width]
+
+    padding : list of two ints
+        [pad_height, pad_width]
+
+    dilation : list of two ints
+        [dilation_height, dilation_width]
+
+    out_dtype: str
+        The output type. This is used for mixed precision.
+
+    Returns
+    -------
+    output : tvm.te.Tensor
+        4-D with shape [batch, out_height, out_width, out_channel]
+    """
+
+    out_dtype = out_dtype or data.dtype
+
+    N, IH, IW, IC = get_const_tuple(data.shape)
+
+    if isinstance(dilation, int):
+        dilation_h = dilation_w = dilation
+    else:
+        dilation_h, dilation_w = dilation
+
+    KH, KW, IC, channel_multiplier = get_const_tuple(kernel.shape)
+
+    dilated_kernel_h = (KH - 1) * dilation_h + 1
+    dilated_kernel_w = (KW - 1) * dilation_w + 1
+
+    pad_top, pad_left, pad_down, pad_right = get_pad_tuple(
+        padding, (dilated_kernel_h, dilated_kernel_w))
+    HSTR, WSTR = strides if isinstance(strides, (tuple, list)) else (strides, strides)
+
+    OH = (IH + pad_top + pad_down - dilated_kernel_h) // HSTR + 1
+    OW = (IW + pad_left + pad_right - dilated_kernel_w) // WSTR + 1
+
+    if pad_top or pad_left:
+        data_pad = nn.pad(data, [0, pad_top, pad_left, 0], [0, pad_down, pad_right, 0],
+                          name="data_pad")
+    else:
+        data_pad = data
+
+    output_shape = (N, OH, OW, IC*channel_multiplier)
+
+    idxdiv = tvm.tir.indexdiv
+    idxmod = tvm.tir.indexmod
+
+    reduce_h = te.reduce_axis((0, KH), name='reduce_h')
+    reduce_w = te.reduce_axis((0, KW), name='reduce_w')
+
+    out = te.compute(output_shape, lambda n, h, w, c:
+                     te.sum(data_pad[n,
+                                     HSTR*h+dilation_h*reduce_h,
+                                     w*WSTR+reduce_w*dilation_w,
+                                     idxdiv(c, channel_multiplier)].astype(out_dtype) *
+                            kernel[reduce_h,
+                                   reduce_w,
+                                   idxdiv(c, channel_multiplier),
+                                   idxmod(c, channel_multiplier)].astype(out_dtype),
+                            axis=[reduce_h, reduce_w]),
+                     name='depthwise_conv2d_nhwc_output')
+
+    return out
+
+@autotvm.register_topi_schedule("depthwise_conv2d_nhwc.arm_cpu")
+def schedule_depthwise_conv2d_nhwc(_, outs):
+    """Create the schedule for depthwise_conv2d_nchw_spatial_pack"""
+    outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs
+    s = te.create_schedule([x.op for x in outs])
+    out = outs[0]
+
+    def schedule_conv(conv):
+        n, w, h, c = conv.op.axis
+        r_h, r_w = conv.op.reduce_axis
+        co, ci = s[conv].split(c, 8)
+        wo, wi = s[conv].split(w, 2)
+        ho, hi = s[conv].split(h, 2)
+

Review comment:
       Let us leverage the AutoTVM mechanism and let it search for the best parameters; a sketch follows below.
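
   A minimal sketch of that change, assuming `cfg` is threaded into the schedule function (illustrative; the later revisions in this thread do essentially this):

   ```python
   # Replace the fixed split factors with AutoTVM-searchable splits
   cfg.define_split("tile_c", c, num_outputs=2)
   cfg.define_split("tile_h", h, num_outputs=2)
   cfg.define_split("tile_w", w, num_outputs=2)
   co, ci = cfg["tile_c"].apply(s, conv, c)
   ho, hi = cfg["tile_h"].apply(s, conv, h)
   wo, wi = cfg["tile_w"].apply(s, conv, w)
   ```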

##########
File path: topi/python/topi/arm_cpu/depthwise_conv2d.py
##########
@@ -181,6 +180,130 @@ def depthwise_conv2d_nchw_spatial_pack(cfg, data, kernel, strides, padding, dila
 
     return _decl_spatial_pack(cfg, data, kernel, strides, padding, dilation, out_dtype, num_tile=2)
 
+@autotvm.register_topi_compute("depthwise_conv2d_nhwc.arm_cpu")
+def compute_depthwise_conv2d_nhwc(_, data, kernel, strides, padding, dilation, out_dtype):
+    """TOPI compute callback for depthwise_conv2d nhwc
+
+    Parameters
+    ----------
+    cfg: ConfigEntity
+        The config for this template
+
+    data : tvm.te.Tensor
+        4-D with shape [batch, in_height, in_width, in_channel]
+
+    kernel : tvm.te.Tensor
+        4-D with shape [filter_height, filter_width, in_channel, channel_multiplier]
+
+    strides : list of two ints
+        [stride_height, stride_width]
+
+    padding : list of two ints
+        [pad_height, pad_width]
+
+    dilation : list of two ints
+        [dilation_height, dilation_width]
+
+    out_dtype: str
+        The output type. This is used for mixed precision.
+
+    Returns
+    -------
+    output : tvm.te.Tensor
+        4-D with shape [batch, out_height, out_width, out_channel]
+    """
+
+    out_dtype = out_dtype or data.dtype
+
+    N, IH, IW, IC = get_const_tuple(data.shape)
+
+    if isinstance(dilation, int):
+        dilation_h = dilation_w = dilation
+    else:
+        dilation_h, dilation_w = dilation
+
+    KH, KW, IC, channel_multiplier = get_const_tuple(kernel.shape)
+
+    dilated_kernel_h = (KH - 1) * dilation_h + 1
+    dilated_kernel_w = (KW - 1) * dilation_w + 1
+
+    pad_top, pad_left, pad_down, pad_right = get_pad_tuple(
+        padding, (dilated_kernel_h, dilated_kernel_w))
+    HSTR, WSTR = strides if isinstance(strides, (tuple, list)) else (strides, strides)
+
+    OH = (IH + pad_top + pad_down - dilated_kernel_h) // HSTR + 1
+    OW = (IW + pad_left + pad_right - dilated_kernel_w) // WSTR + 1
+
+    if pad_top or pad_left:
+        data_pad = nn.pad(data, [0, pad_top, pad_left, 0], [0, pad_down, pad_right, 0],
+                          name="data_pad")
+    else:
+        data_pad = data
+
+    output_shape = (N, OH, OW, IC*channel_multiplier)
+
+    idxdiv = tvm.tir.indexdiv
+    idxmod = tvm.tir.indexmod
+
+    reduce_h = te.reduce_axis((0, KH), name='reduce_h')
+    reduce_w = te.reduce_axis((0, KW), name='reduce_w')
+
+    out = te.compute(output_shape, lambda n, h, w, c:
+                     te.sum(data_pad[n,
+                                     HSTR*h+dilation_h*reduce_h,
+                                     w*WSTR+reduce_w*dilation_w,
+                                     idxdiv(c, channel_multiplier)].astype(out_dtype) *
+                            kernel[reduce_h,
+                                   reduce_w,
+                                   idxdiv(c, channel_multiplier),
+                                   idxmod(c, channel_multiplier)].astype(out_dtype),
+                            axis=[reduce_h, reduce_w]),
+                     name='depthwise_conv2d_nhwc_output')
+
+    return out
+
+@autotvm.register_topi_schedule("depthwise_conv2d_nhwc.arm_cpu")
+def schedule_depthwise_conv2d_nhwc(_, outs):
+    """Create the schedule for depthwise_conv2d_nchw_spatial_pack"""
+    outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs
+    s = te.create_schedule([x.op for x in outs])
+    out = outs[0]
+

Review comment:
       Let us add a schedule for `data_pad`, i.e. add a `compute_at` stage, which could help us solve the parallelism/compute/locality trade-off and improve performance.

##########
File path: topi/python/topi/arm_cpu/depthwise_conv2d.py
##########
@@ -181,6 +180,130 @@ def depthwise_conv2d_nchw_spatial_pack(cfg, data, kernel, strides, padding, dila
 
     return _decl_spatial_pack(cfg, data, kernel, strides, padding, dilation, out_dtype, num_tile=2)
 
+@autotvm.register_topi_compute("depthwise_conv2d_nhwc.arm_cpu")
+def compute_depthwise_conv2d_nhwc(_, data, kernel, strides, padding, dilation, out_dtype):
+    """TOPI compute callback for depthwise_conv2d nhwc
+
+    Parameters
+    ----------
+    cfg: ConfigEntity
+        The config for this template
+
+    data : tvm.te.Tensor
+        4-D with shape [batch, in_height, in_width, in_channel]
+
+    kernel : tvm.te.Tensor
+        4-D with shape [filter_height, filter_width, in_channel, channel_multiplier]
+
+    strides : list of two ints
+        [stride_height, stride_width]
+
+    padding : list of two ints
+        [pad_height, pad_width]
+
+    dilation : list of two ints
+        [dilation_height, dilation_width]
+
+    out_dtype: str
+        The output type. This is used for mixed precision.
+
+    Returns
+    -------
+    output : tvm.te.Tensor
+        4-D with shape [batch, out_height, out_width, out_channel]
+    """
+
+    out_dtype = out_dtype or data.dtype
+
+    N, IH, IW, IC = get_const_tuple(data.shape)
+
+    if isinstance(dilation, int):
+        dilation_h = dilation_w = dilation
+    else:
+        dilation_h, dilation_w = dilation
+
+    KH, KW, IC, channel_multiplier = get_const_tuple(kernel.shape)
+
+    dilated_kernel_h = (KH - 1) * dilation_h + 1
+    dilated_kernel_w = (KW - 1) * dilation_w + 1
+
+    pad_top, pad_left, pad_down, pad_right = get_pad_tuple(
+        padding, (dilated_kernel_h, dilated_kernel_w))
+    HSTR, WSTR = strides if isinstance(strides, (tuple, list)) else (strides, strides)
+
+    OH = (IH + pad_top + pad_down - dilated_kernel_h) // HSTR + 1
+    OW = (IW + pad_left + pad_right - dilated_kernel_w) // WSTR + 1
+
+    if pad_top or pad_left:
+        data_pad = nn.pad(data, [0, pad_top, pad_left, 0], [0, pad_down, pad_right, 0],
+                          name="data_pad")
+    else:
+        data_pad = data
+
+    output_shape = (N, OH, OW, IC*channel_multiplier)
+
+    idxdiv = tvm.tir.indexdiv
+    idxmod = tvm.tir.indexmod
+
+    reduce_h = te.reduce_axis((0, KH), name='reduce_h')
+    reduce_w = te.reduce_axis((0, KW), name='reduce_w')
+
+    out = te.compute(output_shape, lambda n, h, w, c:
+                     te.sum(data_pad[n,
+                                     HSTR*h+dilation_h*reduce_h,
+                                     w*WSTR+reduce_w*dilation_w,
+                                     idxdiv(c, channel_multiplier)].astype(out_dtype) *
+                            kernel[reduce_h,
+                                   reduce_w,
+                                   idxdiv(c, channel_multiplier),
+                                   idxmod(c, channel_multiplier)].astype(out_dtype),
+                            axis=[reduce_h, reduce_w]),
+                     name='depthwise_conv2d_nhwc_output')
+
+    return out
+
+@autotvm.register_topi_schedule("depthwise_conv2d_nhwc.arm_cpu")
+def schedule_depthwise_conv2d_nhwc(_, outs):
+    """Create the schedule for depthwise_conv2d_nchw_spatial_pack"""
+    outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs
+    s = te.create_schedule([x.op for x in outs])
+    out = outs[0]
+
+    def schedule_conv(conv):
+        n, w, h, c = conv.op.axis
+        r_h, r_w = conv.op.reduce_axis
+        co, ci = s[conv].split(c, 8)
+        wo, wi = s[conv].split(w, 2)
+        ho, hi = s[conv].split(h, 2)
+
+        s[conv].reorder(n, wo, ho, co, wi, hi, r_h, r_w, ci)
+        s[conv].parallel(wo)
+        s[conv].vectorize(ci)
+
+    def schedule_conv_out(out):
+        n, h, w, c = out.op.axis
+        co, ci = s[out].split(c, 8)
+        wo, wi = s[out].split(w, 2)
+        ho, hi = s[out].split(h, 2)
+        ci_outer, ci_inner = s[out].split(ci, 4)
+        s[out].reorder(n, wo, ho, co, wi, hi)
+        s[out].vectorize(ci_inner)
+        compute_at_axis = hi

Review comment:
       Let us make `compute_at_axis` tunable, i.e. at least we could choose between `hi` and `wi`; see the sketch below.
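
   A hypothetical tunable version (the revision at the top of this thread ends up using a similar `locate_output` knob):

   ```python
   # Let the tuner choose where the conv stage is computed within the output
   cfg.define_knob("compute_at_axis", ["hi", "wi"])
   compute_at_axis = hi if cfg["compute_at_axis"].val == "hi" else wi
   s[conv].compute_at(s[out], compute_at_axis)  # presumed use of the anchor
   ```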

##########
File path: topi/python/topi/arm_cpu/depthwise_conv2d.py
##########
@@ -181,6 +180,130 @@ def depthwise_conv2d_nchw_spatial_pack(cfg, data, kernel, strides, padding, dila
 
     return _decl_spatial_pack(cfg, data, kernel, strides, padding, dilation, out_dtype, num_tile=2)
 
+@autotvm.register_topi_compute("depthwise_conv2d_nhwc.arm_cpu")
+def compute_depthwise_conv2d_nhwc(_, data, kernel, strides, padding, dilation, out_dtype):
+    """TOPI compute callback for depthwise_conv2d nhwc
+
+    Parameters
+    ----------
+    cfg: ConfigEntity
+        The config for this template
+
+    data : tvm.te.Tensor
+        4-D with shape [batch, in_height, in_width, in_channel]
+
+    kernel : tvm.te.Tensor
+        4-D with shape [filter_height, filter_width, in_channel, channel_multiplier]
+
+    strides : list of two ints
+        [stride_height, stride_width]
+
+    padding : list of two ints
+        [pad_height, pad_width]
+
+    dilation : list of two ints
+        [dilation_height, dilation_width]
+
+    out_dtype: str
+        The output type. This is used for mixed precision.
+
+    Returns
+    -------
+    output : tvm.te.Tensor
+        4-D with shape [batch, out_height, out_width, out_channel]
+    """
+
+    out_dtype = out_dtype or data.dtype
+
+    N, IH, IW, IC = get_const_tuple(data.shape)
+
+    if isinstance(dilation, int):
+        dilation_h = dilation_w = dilation
+    else:
+        dilation_h, dilation_w = dilation
+
+    KH, KW, IC, channel_multiplier = get_const_tuple(kernel.shape)
+
+    dilated_kernel_h = (KH - 1) * dilation_h + 1
+    dilated_kernel_w = (KW - 1) * dilation_w + 1
+
+    pad_top, pad_left, pad_down, pad_right = get_pad_tuple(
+        padding, (dilated_kernel_h, dilated_kernel_w))
+    HSTR, WSTR = strides if isinstance(strides, (tuple, list)) else (strides, strides)
+
+    OH = (IH + pad_top + pad_down - dilated_kernel_h) // HSTR + 1
+    OW = (IW + pad_left + pad_right - dilated_kernel_w) // WSTR + 1
+
+    if pad_top or pad_left:
+        data_pad = nn.pad(data, [0, pad_top, pad_left, 0], [0, pad_down, pad_right, 0],
+                          name="data_pad")
+    else:
+        data_pad = data
+
+    output_shape = (N, OH, OW, IC*channel_multiplier)
+
+    idxdiv = tvm.tir.indexdiv
+    idxmod = tvm.tir.indexmod
+
+    reduce_h = te.reduce_axis((0, KH), name='reduce_h')
+    reduce_w = te.reduce_axis((0, KW), name='reduce_w')
+
+    out = te.compute(output_shape, lambda n, h, w, c:
+                     te.sum(data_pad[n,
+                                     HSTR*h+dilation_h*reduce_h,
+                                     w*WSTR+reduce_w*dilation_w,
+                                     idxdiv(c, channel_multiplier)].astype(out_dtype) *
+                            kernel[reduce_h,
+                                   reduce_w,
+                                   idxdiv(c, channel_multiplier),
+                                   idxmod(c, channel_multiplier)].astype(out_dtype),
+                            axis=[reduce_h, reduce_w]),
+                     name='depthwise_conv2d_nhwc_output')
+
+    return out
+
+@autotvm.register_topi_schedule("depthwise_conv2d_nhwc.arm_cpu")
+def schedule_depthwise_conv2d_nhwc(_, outs):
+    """Create the schedule for depthwise_conv2d_nchw_spatial_pack"""
+    outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs
+    s = te.create_schedule([x.op for x in outs])
+    out = outs[0]
+
+    def schedule_conv(conv):
+        n, w, h, c = conv.op.axis
+        r_h, r_w = conv.op.reduce_axis
+        co, ci = s[conv].split(c, 8)
+        wo, wi = s[conv].split(w, 2)
+        ho, hi = s[conv].split(h, 2)
+
+        s[conv].reorder(n, wo, ho, co, wi, hi, r_h, r_w, ci)
+        s[conv].parallel(wo)

Review comment:
       1. Why do we reorder to `wo, ho`?
   
   2. Let us fuse `n` and `wo` and parallelize the fused axis instead of parallelizing `wo` directly, even if `n` is 1; see the sketch below.
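
   A sketch of the suggested fuse-then-parallel pattern (illustrative):

   ```python
   # Fuse batch and outer-w, then parallelize the fused axis;
   # this stays correct and robust even when n == 1
   fused = s[conv].fuse(n, wo)
   s[conv].parallel(fused)
   ```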

##########
File path: topi/python/topi/arm_cpu/depthwise_conv2d.py
##########
@@ -181,6 +180,130 @@ def depthwise_conv2d_nchw_spatial_pack(cfg, data, kernel, strides, padding, dila
 
     return _decl_spatial_pack(cfg, data, kernel, strides, padding, dilation, out_dtype, num_tile=2)
 
+@autotvm.register_topi_compute("depthwise_conv2d_nhwc.arm_cpu")
+def compute_depthwise_conv2d_nhwc(_, data, kernel, strides, padding, dilation, out_dtype):
+    """TOPI compute callback for depthwise_conv2d nhwc
+
+    Parameters
+    ----------
+    cfg: ConfigEntity
+        The config for this template
+
+    data : tvm.te.Tensor
+        4-D with shape [batch, in_height, in_width, in_channel]
+
+    kernel : tvm.te.Tensor
+        4-D with shape [filter_height, filter_width, in_channel, channel_multiplier]
+
+    strides : list of two ints
+        [stride_height, stride_width]
+
+    padding : list of two ints
+        [pad_height, pad_width]
+
+    dilation : list of two ints
+        [dilation_height, dilation_width]
+
+    out_dtype: str
+        The output type. This is used for mixed precision.
+
+    Returns
+    -------
+    output : tvm.te.Tensor
+        4-D with shape [batch, out_height, out_width, out_channel]
+    """
+
+    out_dtype = out_dtype or data.dtype
+
+    N, IH, IW, IC = get_const_tuple(data.shape)
+
+    if isinstance(dilation, int):
+        dilation_h = dilation_w = dilation
+    else:
+        dilation_h, dilation_w = dilation
+
+    KH, KW, IC, channel_multiplier = get_const_tuple(kernel.shape)
+
+    dilated_kernel_h = (KH - 1) * dilation_h + 1
+    dilated_kernel_w = (KW - 1) * dilation_w + 1
+
+    pad_top, pad_left, pad_down, pad_right = get_pad_tuple(
+        padding, (dilated_kernel_h, dilated_kernel_w))
+    HSTR, WSTR = strides if isinstance(strides, (tuple, list)) else (strides, strides)
+
+    OH = (IH + pad_top + pad_down - dilated_kernel_h) // HSTR + 1
+    OW = (IW + pad_left + pad_right - dilated_kernel_w) // WSTR + 1
+
+    if pad_top or pad_left:
+        data_pad = nn.pad(data, [0, pad_top, pad_left, 0], [0, pad_down, pad_right, 0],
+                          name="data_pad")
+    else:
+        data_pad = data
+
+    output_shape = (N, OH, OW, IC*channel_multiplier)
+
+    idxdiv = tvm.tir.indexdiv
+    idxmod = tvm.tir.indexmod
+
+    reduce_h = te.reduce_axis((0, KH), name='reduce_h')
+    reduce_w = te.reduce_axis((0, KW), name='reduce_w')
+
+    out = te.compute(output_shape, lambda n, h, w, c:
+                     te.sum(data_pad[n,
+                                     HSTR*h+dilation_h*reduce_h,
+                                     w*WSTR+reduce_w*dilation_w,
+                                     idxdiv(c, channel_multiplier)].astype(out_dtype) *
+                            kernel[reduce_h,
+                                   reduce_w,
+                                   idxdiv(c, channel_multiplier),
+                                   idxmod(c, channel_multiplier)].astype(out_dtype),
+                            axis=[reduce_h, reduce_w]),
+                     name='depthwise_conv2d_nhwc_output')
+
+    return out
+
+@autotvm.register_topi_schedule("depthwise_conv2d_nhwc.arm_cpu")
+def schedule_depthwise_conv2d_nhwc(_, outs):
+    """Create the schedule for depthwise_conv2d_nchw_spatial_pack"""
+    outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs
+    s = te.create_schedule([x.op for x in outs])
+    out = outs[0]
+
+    def schedule_conv(conv):
+        n, w, h, c = conv.op.axis
+        r_h, r_w = conv.op.reduce_axis
+        co, ci = s[conv].split(c, 8)
+        wo, wi = s[conv].split(w, 2)
+        ho, hi = s[conv].split(h, 2)
+
+        s[conv].reorder(n, wo, ho, co, wi, hi, r_h, r_w, ci)
+        s[conv].parallel(wo)
+        s[conv].vectorize(ci)
+
+    def schedule_conv_out(out):
+        n, h, w, c = out.op.axis
+        co, ci = s[out].split(c, 8)
+        wo, wi = s[out].split(w, 2)
+        ho, hi = s[out].split(h, 2)
+        ci_outer, ci_inner = s[out].split(ci, 4)

Review comment:
       ditto

##########
File path: topi/python/topi/arm_cpu/depthwise_conv2d.py
##########
@@ -181,6 +180,130 @@ def depthwise_conv2d_nchw_spatial_pack(cfg, data, kernel, strides, padding, dila
 
     return _decl_spatial_pack(cfg, data, kernel, strides, padding, dilation, out_dtype, num_tile=2)
 
+@autotvm.register_topi_compute("depthwise_conv2d_nhwc.arm_cpu")
+def compute_depthwise_conv2d_nhwc(_, data, kernel, strides, padding, dilation, out_dtype):
+    """TOPI compute callback for depthwise_conv2d nhwc
+
+    Parameters
+    ----------
+    cfg: ConfigEntity
+        The config for this template
+
+    data : tvm.te.Tensor
+        4-D with shape [batch, in_height, in_width, in_channel]
+
+    kernel : tvm.te.Tensor
+        4-D with shape [filter_height, filter_width, in_channel, channel_multiplier]
+
+    strides : list of two ints
+        [stride_height, stride_width]
+
+    padding : list of two ints
+        [pad_height, pad_width]
+
+    dilation : list of two ints
+        [dilation_height, dilation_width]
+
+    out_dtype: str
+        The output type. This is used for mixed precision.
+
+    Returns
+    -------
+    output : tvm.te.Tensor
+        4-D with shape [batch, out_height, out_width, out_channel]
+    """
+
+    out_dtype = out_dtype or data.dtype
+
+    N, IH, IW, IC = get_const_tuple(data.shape)
+
+    if isinstance(dilation, int):
+        dilation_h = dilation_w = dilation
+    else:
+        dilation_h, dilation_w = dilation
+
+    KH, KW, IC, channel_multiplier = get_const_tuple(kernel.shape)
+
+    dilated_kernel_h = (KH - 1) * dilation_h + 1
+    dilated_kernel_w = (KW - 1) * dilation_w + 1
+
+    pad_top, pad_left, pad_down, pad_right = get_pad_tuple(
+        padding, (dilated_kernel_h, dilated_kernel_w))
+    HSTR, WSTR = strides if isinstance(strides, (tuple, list)) else (strides, strides)
+
+    OH = (IH + pad_top + pad_down - dilated_kernel_h) // HSTR + 1
+    OW = (IW + pad_left + pad_right - dilated_kernel_w) // WSTR + 1
+
+    if pad_top or pad_left:
+        data_pad = nn.pad(data, [0, pad_top, pad_left, 0], [0, pad_down, pad_right, 0],
+                          name="data_pad")
+    else:
+        data_pad = data
+
+    output_shape = (N, OH, OW, IC*channel_multiplier)
+
+    idxdiv = tvm.tir.indexdiv
+    idxmod = tvm.tir.indexmod
+
+    reduce_h = te.reduce_axis((0, KH), name='reduce_h')
+    reduce_w = te.reduce_axis((0, KW), name='reduce_w')
+
+    out = te.compute(output_shape, lambda n, h, w, c:
+                     te.sum(data_pad[n,
+                                     HSTR*h+dilation_h*reduce_h,
+                                     w*WSTR+reduce_w*dilation_w,
+                                     idxdiv(c, channel_multiplier)].astype(out_dtype) *
+                            kernel[reduce_h,
+                                   reduce_w,
+                                   idxdiv(c, channel_multiplier),
+                                   idxmod(c, channel_multiplier)].astype(out_dtype),
+                            axis=[reduce_h, reduce_w]),
+                     name='depthwise_conv2d_nhwc_output')
+
+    return out
+
+@autotvm.register_topi_schedule("depthwise_conv2d_nhwc.arm_cpu")
+def schedule_depthwise_conv2d_nhwc(_, outs):
+    """Create the schedule for depthwise_conv2d_nchw_spatial_pack"""
+    outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs
+    s = te.create_schedule([x.op for x in outs])
+    out = outs[0]
+
+    def schedule_conv(conv):
+        n, w, h, c = conv.op.axis
+        r_h, r_w = conv.op.reduce_axis
+        co, ci = s[conv].split(c, 8)
+        wo, wi = s[conv].split(w, 2)
+        ho, hi = s[conv].split(h, 2)
+
+        s[conv].reorder(n, wo, ho, co, wi, hi, r_h, r_w, ci)
+        s[conv].parallel(wo)
+        s[conv].vectorize(ci)
+
+    def schedule_conv_out(out):
+        n, h, w, c = out.op.axis
+        co, ci = s[out].split(c, 8)
+        wo, wi = s[out].split(w, 2)
+        ho, hi = s[out].split(h, 2)
+        ci_outer, ci_inner = s[out].split(ci, 4)
+        s[out].reorder(n, wo, ho, co, wi, hi)
+        s[out].vectorize(ci_inner)
+        compute_at_axis = hi
+        s[out].parallel(wo)

Review comment:
       ditto







[GitHub] [incubator-tvm] giuseros commented on pull request #6095: Improve NHWC depthwise convolution for AArch64

Posted by GitBox <gi...@apache.org>.
giuseros commented on pull request #6095:
URL: https://github.com/apache/incubator-tvm/pull/6095#issuecomment-661964405


   Hi @tqchen , 
   
   Is this an issue with my changes or with the CI? It seems to point to an `import` in a particular file, but I am not able to see anything wrong with that `import`.





[GitHub] [incubator-tvm] giuseros commented on pull request #6095: Improve NHWC depthwise convolution for AArch64

Posted by GitBox <gi...@apache.org>.
giuseros commented on pull request #6095:
URL: https://github.com/apache/incubator-tvm/pull/6095#issuecomment-663266858


   @FrozenGene , @anijain2305 ,
   This is my last commit before holidays. I enabled only the `arm_cpu` tests (everything passes locally). The case with `dilation>1` will be untested for now (as it is for CUDA). I hope this will make CI happy. If not, I will pick this up when I am back. 





[GitHub] [incubator-tvm] giuseros commented on pull request #6095: Improve NHWC depthwise convolution for AArch64

Posted by GitBox <gi...@apache.org>.
giuseros commented on pull request #6095:
URL: https://github.com/apache/incubator-tvm/pull/6095#issuecomment-661920128


   Hi @FrozenGene , 
   
   Let me thank you for the review and the pointers (the Halide paper is quite interesting)!
   
   I tried most of your suggestions, and all I got was a tiny improvement on `mobilenetv2` but a significant increase in tuning time.
   
   Since this looks like it will take more time and more analysis, I would prefer that we take this PR in as is; I can follow up with further improvements in the future.





[GitHub] [incubator-tvm] anijain2305 commented on a change in pull request #6095: Improve NHWC depthwise convolution for AArch64

Posted by GitBox <gi...@apache.org>.
anijain2305 commented on a change in pull request #6095:
URL: https://github.com/apache/incubator-tvm/pull/6095#discussion_r459545611



##########
File path: python/tvm/relay/op/strategy/arm_cpu.py
##########
@@ -161,11 +161,11 @@ def conv2d_strategy_arm_cpu(attrs, inputs, out_type, target):
                     name="depthwise_conv2d_nchw.x86")
         elif layout == "NHWC":
             assert kernel_layout == "HWOI"
-            logger.warning("depthwise_conv2d with layout NHWC is not optimized for arm cpu.")
+            #logger.warning("depthwise_conv2d with layout NHWC is not optimized for arm cpu.")

Review comment:
       Let's delete this line.







[GitHub] [incubator-tvm] FrozenGene edited a comment on pull request #6095: Improve NHWC depthwise convolution for AArch64

Posted by GitBox <gi...@apache.org>.
FrozenGene edited a comment on pull request #6095:
URL: https://github.com/apache/incubator-tvm/pull/6095#issuecomment-661756915


   > ACL implementation
   
   Hi @giuseros Thanks for the work. I fully understand your purpose and the smooth development path. As this schedule will be the default NHWC depthwise convolution, my opinion is that we should try to achieve performance as good as we can. Notably, I don't mean we must match ACL's ultimate performance before we can merge; optimization is not a one-shot deal. But here I think we could enable AutoTVM to help us achieve better performance, and it is worth introducing into this PR.
   
   - This schedule will be applied to both arm32 and arm64, so we shouldn't only consider arm64. AutoTVM could help us avoid this issue (see the sketch below).
   
   - A tuning knob for `compute_at` (especially for `data_pad`) / `split` could help us solve the parallelism-compute-locality issue (we cannot assume the kernel runs on only a single core). See Figure 2 of http://people.csail.mit.edu/jrk/halide-pldi13.pdf for more detail.
   
   I agree we should reduce the number of tuning knobs and keep tuning time reasonable, but if a knob helps us improve performance, I think we should introduce it; otherwise we can leave it out.
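   
   A sketch of the first point, assuming `cfg`, `s`, `conv` and the axis variables from the schedule under review (these `define_split` calls mirror the ones that appear later in this thread):
   
   ```python
   # Replace the hard-coded split factors (8/2/2) with AutoTVM knobs, so the
   # same template can be tuned separately for arm32 and arm64 register files.
   cfg.define_split('tile_c', c, num_outputs=2)
   cfg.define_split('tile_h', h, num_outputs=2)
   cfg.define_split('tile_w', w, num_outputs=2)
   
   co, ci = cfg['tile_c'].apply(s, conv, c)
   ho, hi = cfg['tile_h'].apply(s, conv, h)
   wo, wi = cfg['tile_w'].apply(s, conv, w)
   ```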





[GitHub] [incubator-tvm] giuseros commented on a change in pull request #6095: Improve NHWC depthwise convolution for AArch64

Posted by GitBox <gi...@apache.org>.
giuseros commented on a change in pull request #6095:
URL: https://github.com/apache/incubator-tvm/pull/6095#discussion_r458893850



##########
File path: topi/python/topi/arm_cpu/depthwise_conv2d.py
##########
@@ -181,6 +181,154 @@ def depthwise_conv2d_nchw_spatial_pack(cfg, data, kernel, strides, padding, dila
 
     return _decl_spatial_pack(cfg, data, kernel, strides, padding, dilation, out_dtype, num_tile=2)
 
+@autotvm.register_topi_compute("depthwise_conv2d_nhwc.arm_cpu")
+def compute_depthwise_conv2d_nhwc(_, data, kernel, strides, padding, dilation, out_dtype):
+    """TOPI compute callback for depthwise_conv2d nhwc
+
+    Parameters
+    ----------
+    cfg: ConfigEntity
+        The config for this template
+
+    data : tvm.te.Tensor
+        4-D with shape [batch, in_height, in_width, in_channel]
+
+    kernel : tvm.te.Tensor
+        4-D with shape [filter_height, filter_width, in_channel, channel_multiplier]
+
+    strides : list of two ints
+        [stride_height, stride_width]
+
+    padding : list of two ints
+        [pad_height, pad_width]
+
+    dilation : list of two ints
+        [dilation_height, dilation_width]
+
+    out_dtype: str
+        The output type. This is used for mixed precision.
+
+    Returns
+    -------
+    output : tvm.te.Tensor
+        4-D with shape [batch, out_height, out_width, out_channel]
+    """
+
+    out_dtype = out_dtype or data.dtype
+
+    N, IH, IW, IC = get_const_tuple(data.shape)
+
+    if isinstance(dilation, int):
+        dilation_h = dilation_w = dilation
+    else:
+        dilation_h, dilation_w = dilation
+
+    KH, KW, IC, channel_multiplier = get_const_tuple(kernel.shape)
+
+    dilated_kernel_h = (KH - 1) * dilation_h + 1
+    dilated_kernel_w = (KW - 1) * dilation_w + 1
+
+    pad_top, pad_left, pad_down, pad_right = get_pad_tuple(
+        padding, (dilated_kernel_h, dilated_kernel_w))
+    HSTR, WSTR = strides if isinstance(strides, (tuple, list)) else (strides, strides)
+
+    OH = (IH + pad_top + pad_down - dilated_kernel_h) // HSTR + 1
+    OW = (IW + pad_left + pad_right - dilated_kernel_w) // WSTR + 1
+
+    if pad_top or pad_left:
+        data_pad = nn.pad(data, [0, pad_top, pad_left, 0], [0, pad_down, pad_right, 0],
+                          name="data_pad")
+    else:
+        data_pad = data
+
+    output_shape = (N, OH, OW, IC*channel_multiplier)
+
+    idxdiv = tvm.tir.indexdiv
+    idxmod = tvm.tir.indexmod
+
+    reduce_h = te.reduce_axis((0, KH), name='reduce_h')
+    reduce_w = te.reduce_axis((0, KW), name='reduce_w')
+
+    out = te.compute(output_shape, lambda n, h, w, c:
+                     te.sum(data_pad[n,
+                                     HSTR*h+dilation_h*reduce_h,
+                                     w*WSTR+reduce_w*dilation_w,
+                                     idxdiv(c, channel_multiplier)].astype(out_dtype) *
+                            kernel[reduce_h,
+                                   reduce_w,
+                                   idxdiv(c, channel_multiplier),
+                                   idxmod(c, channel_multiplier)].astype(out_dtype),
+                            axis=[reduce_h, reduce_w]),
+                     name='depthwise_conv2d_nhwc_output')
+    return out
+
+@autotvm.register_topi_schedule("depthwise_conv2d_nhwc.arm_cpu")
+def schedule_depthwise_conv2d_nhwc(cfg, outs):
+    """Create the schedule for depthwise_conv2d_nchw_spatial_pack"""
+    outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs
+    s = te.create_schedule([x.op for x in outs])
+    out = outs[0]
+
+    ##### space definition begin #####
+    n, h, w, c = s[out].op.axis
+    cfg.define_split('tile_c', c, num_outputs=2)
+    _, hi = cfg.define_split('tile_h', h, num_outputs=2)
+    _, wi = cfg.define_split('tile_w', w, num_outputs=2)
+    cfg.define_annotate('locate_output', [hi, wi], 'locate_cache', num_anchor=1)
+
+    # fallback support
+    if cfg.is_fallback:
+        cfg['tile_c'] = SplitEntity([-1, 8])
+        cfg['tile_h'] = SplitEntity([-1, 2])
+        cfg['tile_w'] = SplitEntity([-1, 2])
+        cfg['locate_output'] = AnnotateEntity([1])
+    ##### space definition end #####
+
+    def schedule_conv(conv):
+        conv_data = conv.op.input_tensors[0]
+        if conv_data.name == "data_pad":
+            s[conv_data].compute_inline()

Review comment:
       Yes, that made quite a difference:
   |H/W | C   |S  |  inline                |  compute_at(ho)      |compute_at(wo)       | no-inline            |
   |----|-----|---|------------------------|----------------------|---------------------|----------------------|
   |112 | 96  |2  |  1.452941176470588     |  2.5349999999999997  |0.9236363636363636   | 0.36249999999999993  |
   |56  |144  |1  |  1.7249999999999999    |  1.7                 |3.4                  | 1.38                 |
   |56  |144  |2  |  3.028571428571429     |  3.028571428571429   |2.3777777777777778   | 1.5214285714285716   |
   |28  |192  |1  |  1.711111111111111     |  1.9124999999999999  |1.9374999999999998   | 1.409090909090909    |
   |28  |192  |2  |  1.6833333333333333    |  3.3333333333333335  |1.275                | 1.442857142857143    |
   |14  |384  |1  |  3.15                  |  3.15                |2.066666666666667    | 0.5727272727272728   |
   |14  |576  |1  |  0.8863636363636364    |  2.425               |2.4374999999999996   | 0.527027027027027    |
   |14  |576  |2  |  2.4                   |  2.4                 |2.4                  | 0.6714285714285715   |
   |7   |960  |1  |  2.9272727272727272    |   2.496153846153846  |0.7833333333333332   | 1.3416666666666668   |
   
   So now it even makes more sense to have them as knobs :)
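   
   For reference, a sketch of how these policies can be folded into one tuning knob (assuming `data_pad`, `conv`, `ho` and `wo` from the schedule above; a concrete version of this idea is also quoted elsewhere in this thread):
   
   ```python
   # One knob value per padding policy measured in the table above
   # (0: keep data_pad as a separate stage, i.e. no inline).
   cfg.define_knob('data_pad_inline', [0, 1, 2, 3])
   if cfg['data_pad_inline'].val == 1:
       s[data_pad].compute_at(s[conv], ho)
   elif cfg['data_pad_inline'].val == 2:
       s[data_pad].compute_at(s[conv], wo)
   elif cfg['data_pad_inline'].val == 3:
       s[data_pad].compute_inline()
   ```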







[GitHub] [incubator-tvm] giuseros commented on pull request #6095: Improve NHWC depthwise convolution for AArch64

Posted by GitBox <gi...@apache.org>.
giuseros commented on pull request #6095:
URL: https://github.com/apache/incubator-tvm/pull/6095#issuecomment-673507098


   Hi @FrozenGene , @anijain2305 ,
   This PR finally passed the CI. Would it be possible to merge it?
   
   Thanks!





[GitHub] [incubator-tvm] giuseros commented on pull request #6095: Improve NHWC depthwise convolution for AArch64

Posted by GitBox <gi...@apache.org>.
giuseros commented on pull request #6095:
URL: https://github.com/apache/incubator-tvm/pull/6095#issuecomment-661738954


   Hi @FrozenGene , 
   Before introducing tuning knobs, I first wanted to do an analysis to find the minimum set of tuning parameters that brings the best performance.
   
   The aim is to reduce the tuning time. The point is that sometimes we are constrained by the number of registers available on AArch64, so trying out different splits might only increase the tuning time without giving any benefit.
   
   So the idea was to have a "default" schedule which mimics the [ACL implementation](https://github.com/ARM-software/ComputeLibrary/blob/master/src/core/NEON/kernels/convolution/depthwise/impl_qa8_qa8.hpp#L292-L314) and then introduce (the minimal set of) tuning knobs + tensorization to speed things up.
   
   What do you think? If you want to add tuning knobs in this PR, I will try to do the tuning analysis today.





[GitHub] [incubator-tvm] FrozenGene commented on pull request #6095: Improve NHWC depthwise convolution for AArch64

Posted by GitBox <gi...@apache.org>.
FrozenGene commented on pull request #6095:
URL: https://github.com/apache/incubator-tvm/pull/6095#issuecomment-662242235


   > Hi @tqchen ,
   > 
   > Is this an issue with my changes or with the CI? It seems to point to an `import` in a particular file, but I am not able to see anything wrong with that `import`.
   
   Seems not related to your change. Others have hit this CI error too.





[GitHub] [incubator-tvm] FrozenGene commented on a change in pull request #6095: Improve NHWC depthwise convolution for AArch64

Posted by GitBox <gi...@apache.org>.
FrozenGene commented on a change in pull request #6095:
URL: https://github.com/apache/incubator-tvm/pull/6095#discussion_r457807134



##########
File path: topi/python/topi/arm_cpu/depthwise_conv2d.py
##########
@@ -181,6 +180,130 @@ def depthwise_conv2d_nchw_spatial_pack(cfg, data, kernel, strides, padding, dila
 
     return _decl_spatial_pack(cfg, data, kernel, strides, padding, dilation, out_dtype, num_tile=2)
 
+@autotvm.register_topi_compute("depthwise_conv2d_nhwc.arm_cpu")
+def compute_depthwise_conv2d_nhwc(_, data, kernel, strides, padding, dilation, out_dtype):
+    """TOPI compute callback for depthwise_conv2d nhwc
+
+    Parameters
+    ----------
+    cfg: ConfigEntity
+        The config for this template
+
+    data : tvm.te.Tensor
+        4-D with shape [batch, in_height, in_width, in_channel]
+
+    kernel : tvm.te.Tensor
+        4-D with shape [filter_height, filter_width, in_channel, channel_multiplier]
+
+    strides : list of two ints
+        [stride_height, stride_width]
+
+    padding : list of two ints
+        [pad_height, pad_width]
+
+    dilation : list of two ints
+        [dilation_height, dilation_width]
+
+    out_dtype: str
+        The output type. This is used for mixed precision.
+
+    Returns
+    -------
+    output : tvm.te.Tensor
+        4-D with shape [batch, out_height, out_width, out_channel]
+    """
+
+    out_dtype = out_dtype or data.dtype
+
+    N, IH, IW, IC = get_const_tuple(data.shape)
+
+    if isinstance(dilation, int):
+        dilation_h = dilation_w = dilation
+    else:
+        dilation_h, dilation_w = dilation
+
+    KH, KW, IC, channel_multiplier = get_const_tuple(kernel.shape)
+
+    dilated_kernel_h = (KH - 1) * dilation_h + 1
+    dilated_kernel_w = (KW - 1) * dilation_w + 1
+
+    pad_top, pad_left, pad_down, pad_right = get_pad_tuple(
+        padding, (dilated_kernel_h, dilated_kernel_w))
+    HSTR, WSTR = strides if isinstance(strides, (tuple, list)) else (strides, strides)
+
+    OH = (IH + pad_top + pad_down - dilated_kernel_h) // HSTR + 1
+    OW = (IW + pad_left + pad_right - dilated_kernel_w) // WSTR + 1
+
+    if pad_top or pad_left:
+        data_pad = nn.pad(data, [0, pad_top, pad_left, 0], [0, pad_down, pad_right, 0],
+                          name="data_pad")
+    else:
+        data_pad = data
+
+    output_shape = (N, OH, OW, IC*channel_multiplier)
+
+    idxdiv = tvm.tir.indexdiv
+    idxmod = tvm.tir.indexmod
+
+    reduce_h = te.reduce_axis((0, KH), name='reduce_h')
+    reduce_w = te.reduce_axis((0, KW), name='reduce_w')
+
+    out = te.compute(output_shape, lambda n, h, w, c:
+                     te.sum(data_pad[n,
+                                     HSTR*h+dilation_h*reduce_h,
+                                     w*WSTR+reduce_w*dilation_w,
+                                     idxdiv(c, channel_multiplier)].astype(out_dtype) *
+                            kernel[reduce_h,
+                                   reduce_w,
+                                   idxdiv(c, channel_multiplier),
+                                   idxmod(c, channel_multiplier)].astype(out_dtype),
+                            axis=[reduce_h, reduce_w]),
+                     name='depthwise_conv2d_nhwc_output')
+
+    return out
+
+@autotvm.register_topi_schedule("depthwise_conv2d_nhwc.arm_cpu")
+def schedule_depthwise_conv2d_nhwc(_, outs):
+    """Create the schedule for depthwise_conv2d_nchw_spatial_pack"""
+    outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs
+    s = te.create_schedule([x.op for x in outs])
+    out = outs[0]
+
+    def schedule_conv(conv):
+        n, w, h, c = conv.op.axis
+        r_h, r_w = conv.op.reduce_axis
+        co, ci = s[conv].split(c, 8)
+        wo, wi = s[conv].split(w, 2)
+        ho, hi = s[conv].split(h, 2)
+

Review comment:
       Let us leverage the AutoTVM mechanism and let it search for the best parameters, since we already have it; I cannot see a reason not to use it.
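   
   One way to keep the current hand-picked factors while still searching, as later versions in this thread do: register them as the fallback configuration used when no tuning log is available. A minimal sketch, assuming the `cfg` and split definitions from the surrounding schedule:
   
   ```python
   from tvm.autotvm.task.space import SplitEntity
   
   # Preserve the hand-written 8/2/2 schedule as the untuned default.
   if cfg.is_fallback:
       cfg['tile_c'] = SplitEntity([-1, 8])
       cfg['tile_h'] = SplitEntity([-1, 2])
       cfg['tile_w'] = SplitEntity([-1, 2])
   ```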







[GitHub] [incubator-tvm] giuseros commented on a change in pull request #6095: Improve NHWC depthwise convolution for AArch64

Posted by GitBox <gi...@apache.org>.
giuseros commented on a change in pull request #6095:
URL: https://github.com/apache/incubator-tvm/pull/6095#discussion_r459455600



##########
File path: topi/python/topi/arm_cpu/depthwise_conv2d.py
##########
@@ -181,6 +181,170 @@ def depthwise_conv2d_nchw_spatial_pack(cfg, data, kernel, strides, padding, dila
 
     return _decl_spatial_pack(cfg, data, kernel, strides, padding, dilation, out_dtype, num_tile=2)
 
+@autotvm.register_topi_compute("depthwise_conv2d_nhwc.arm_cpu")
+def compute_depthwise_conv2d_nhwc(_, data, kernel, strides, padding, dilation, out_dtype):
+    """TOPI compute callback for depthwise_conv2d nhwc
+
+    Parameters
+    ----------
+    cfg: ConfigEntity
+        The config for this template
+
+    data : tvm.te.Tensor
+        4-D with shape [batch, in_height, in_width, in_channel]
+
+    kernel : tvm.te.Tensor
+        4-D with shape [filter_height, filter_width, in_channel, channel_multiplier]
+
+    strides : list of two ints
+        [stride_height, stride_width]
+
+    padding : list of two ints
+        [pad_height, pad_width]
+
+    dilation : list of two ints
+        [dilation_height, dilation_width]
+
+    out_dtype: str
+        The output type. This is used for mixed precision.
+
+    Returns
+    -------
+    output : tvm.te.Tensor
+        4-D with shape [batch, out_height, out_width, out_channel]
+    """
+
+    out_dtype = out_dtype or data.dtype
+
+    N, IH, IW, IC = get_const_tuple(data.shape)
+
+    if isinstance(dilation, int):
+        dilation_h = dilation_w = dilation
+    else:
+        dilation_h, dilation_w = dilation
+
+    KH, KW, IC, channel_multiplier = get_const_tuple(kernel.shape)
+
+    dilated_kernel_h = (KH - 1) * dilation_h + 1
+    dilated_kernel_w = (KW - 1) * dilation_w + 1
+
+    pad_top, pad_left, pad_down, pad_right = get_pad_tuple(
+        padding, (dilated_kernel_h, dilated_kernel_w))
+    HSTR, WSTR = strides if isinstance(strides, (tuple, list)) else (strides, strides)
+
+    OH = (IH + pad_top + pad_down - dilated_kernel_h) // HSTR + 1
+    OW = (IW + pad_left + pad_right - dilated_kernel_w) // WSTR + 1
+
+    if pad_top or pad_left or pad_down or pad_right:
+        data_pad = nn.pad(data, [0, pad_top, pad_left, 0], [0, pad_down, pad_right, 0],
+                          name="data_pad")
+    else:
+        data_pad = data
+
+    output_shape = (N, OH, OW, IC*channel_multiplier)
+
+    idxdiv = tvm.tir.indexdiv
+    idxmod = tvm.tir.indexmod
+
+    reduce_h = te.reduce_axis((0, KH), name='reduce_h')
+    reduce_w = te.reduce_axis((0, KW), name='reduce_w')
+
+    out = te.compute(output_shape, lambda n, h, w, c:
+                     te.sum(data_pad[n,
+                                     HSTR*h+dilation_h*reduce_h,
+                                     w*WSTR+reduce_w*dilation_w,
+                                     idxdiv(c, channel_multiplier)].astype(out_dtype) *
+                            kernel[reduce_h,
+                                   reduce_w,
+                                   idxdiv(c, channel_multiplier),
+                                   idxmod(c, channel_multiplier)].astype(out_dtype),
+                            axis=[reduce_h, reduce_w]),
+                     name='depthwise_conv2d_nhwc_output')
+    return out
+
+@autotvm.register_topi_schedule("depthwise_conv2d_nhwc.arm_cpu")
+def schedule_depthwise_conv2d_nhwc(cfg, outs):
+    """Create the schedule for depthwise_conv2d_nchw_spatial_pack"""
+    outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs
+    s = te.create_schedule([x.op for x in outs])
+    out = outs[0]
+
+    ##### space definition begin #####
+    n, h, w, c = s[out].op.axis
+    cfg.define_split('tile_c', c, num_outputs=2)
+    _, hi = cfg.define_split('tile_h', h, num_outputs=2)
+    _, wi = cfg.define_split('tile_w', w, num_outputs=2)
+    cfg.define_knob('locate_output', [0, 1])
+
+    # fallback support
+    if cfg.is_fallback:
+        cfg['tile_c'] = SplitEntity([-1, 8])
+        cfg['tile_h'] = SplitEntity([-1, 2])
+        cfg['tile_w'] = SplitEntity([-1, 2])
+        cfg['locate_output'] = OtherOptionEntity(1)
+    ##### space definition end #####
+
+    def schedule_conv(conv):
+        conv_data = conv.op.input_tensors[0]
+
+        n, w, h, c = conv.op.axis
+        r_h, r_w = conv.op.reduce_axis
+        ho, hi = cfg['tile_h'].apply(s, conv, h)
+        wo, wi = cfg['tile_w'].apply(s, conv, w)
+        co, ci = cfg['tile_c'].apply(s, conv, c)
+
+        if conv_data.name == "data_pad":
+            # Define a policy for padding computation

Review comment:
       Done. 







[GitHub] [incubator-tvm] giuseros commented on pull request #6095: Improve NHWC depthwise convolution for AArch64

Posted by GitBox <gi...@apache.org>.
giuseros commented on pull request #6095:
URL: https://github.com/apache/incubator-tvm/pull/6095#issuecomment-662986592


   Hi @FrozenGene , 
   Final changes to this PR. 
   
   * Introducing the `compute_at` knobs forces us to use the `xgb_knob` tuner, which does not support `locate_cache` annotations (so I switched to a custom knob).
   * I also noticed that we were not legalizing depthwise, which means we were running pooling, reductions, etc., in order to compute the offset contribution. Legalizing depthwise gives a boost of 2x (for quantized).
   
   If this isn't approved by tonight, I will turn it into a draft and pick it up after I come back from holidays (i.e., in 15 days).





[GitHub] [incubator-tvm] FrozenGene edited a comment on pull request #6095: Improve NHWC depthwise convolution for AArch64

Posted by GitBox <gi...@apache.org>.
FrozenGene edited a comment on pull request #6095:
URL: https://github.com/apache/incubator-tvm/pull/6095#issuecomment-661756915


   > ACL implementation
   
   Hi @giuseros Thanks for the work. I fully understand your purpose and the smooth development path. As this schedule will be the default NHWC depthwise convolution, my opinion is that we should try to achieve performance as good as we can. Notably, I don't mean we must match ACL's ultimate performance before we can merge; optimization is not a one-shot deal. But here I think we could enable AutoTVM to help us achieve better performance, and it is worth introducing into this PR.
   
   - This schedule will be applied to both arm32 and arm64, so we shouldn't only consider arm64. AutoTVM (`split`) could help us avoid this issue.
   
   - A tuning knob for `compute_at` (especially for `data_pad`) could help us solve the parallelism-compute-locality issue (we cannot assume the kernel runs on only a single core). See Figure 2 of http://people.csail.mit.edu/jrk/halide-pldi13.pdf for more detail.
   
   I agree we should reduce the number of tuning knobs and keep tuning time reasonable, but if a knob helps us improve performance, I think we should introduce it; otherwise we can leave it out.





[GitHub] [incubator-tvm] giuseros commented on a change in pull request #6095: Improve NHWC depthwise convolution for AArch64

Posted by GitBox <gi...@apache.org>.
giuseros commented on a change in pull request #6095:
URL: https://github.com/apache/incubator-tvm/pull/6095#discussion_r458145575



##########
File path: topi/python/topi/arm_cpu/depthwise_conv2d.py
##########
@@ -181,6 +180,130 @@ def depthwise_conv2d_nchw_spatial_pack(cfg, data, kernel, strides, padding, dila
 
     return _decl_spatial_pack(cfg, data, kernel, strides, padding, dilation, out_dtype, num_tile=2)
 
+@autotvm.register_topi_compute("depthwise_conv2d_nhwc.arm_cpu")
+def compute_depthwise_conv2d_nhwc(_, data, kernel, strides, padding, dilation, out_dtype):
+    """TOPI compute callback for depthwise_conv2d nhwc
+
+    Parameters
+    ----------
+    cfg: ConfigEntity
+        The config for this template
+
+    data : tvm.te.Tensor
+        4-D with shape [batch, in_height, in_width, in_channel]
+
+    kernel : tvm.te.Tensor
+        4-D with shape [filter_height, filter_width, in_channel, channel_multiplier]
+
+    strides : list of two ints
+        [stride_height, stride_width]
+
+    padding : list of two ints
+        [pad_height, pad_width]
+
+    dilation : list of two ints
+        [dilation_height, dilation_width]
+
+    out_dtype: str
+        The output type. This is used for mixed precision.
+
+    Returns
+    -------
+    output : tvm.te.Tensor
+        4-D with shape [batch, out_height, out_width, out_channel]
+    """
+
+    out_dtype = out_dtype or data.dtype
+
+    N, IH, IW, IC = get_const_tuple(data.shape)
+
+    if isinstance(dilation, int):
+        dilation_h = dilation_w = dilation
+    else:
+        dilation_h, dilation_w = dilation
+
+    KH, KW, IC, channel_multiplier = get_const_tuple(kernel.shape)
+
+    dilated_kernel_h = (KH - 1) * dilation_h + 1
+    dilated_kernel_w = (KW - 1) * dilation_w + 1
+
+    pad_top, pad_left, pad_down, pad_right = get_pad_tuple(
+        padding, (dilated_kernel_h, dilated_kernel_w))
+    HSTR, WSTR = strides if isinstance(strides, (tuple, list)) else (strides, strides)
+
+    OH = (IH + pad_top + pad_down - dilated_kernel_h) // HSTR + 1
+    OW = (IW + pad_left + pad_right - dilated_kernel_w) // WSTR + 1
+
+    if pad_top or pad_left:
+        data_pad = nn.pad(data, [0, pad_top, pad_left, 0], [0, pad_down, pad_right, 0],
+                          name="data_pad")
+    else:
+        data_pad = data
+
+    output_shape = (N, OH, OW, IC*channel_multiplier)
+
+    idxdiv = tvm.tir.indexdiv
+    idxmod = tvm.tir.indexmod
+
+    reduce_h = te.reduce_axis((0, KH), name='reduce_h')
+    reduce_w = te.reduce_axis((0, KW), name='reduce_w')
+
+    out = te.compute(output_shape, lambda n, h, w, c:
+                     te.sum(data_pad[n,
+                                     HSTR*h+dilation_h*reduce_h,
+                                     w*WSTR+reduce_w*dilation_w,
+                                     idxdiv(c, channel_multiplier)].astype(out_dtype) *
+                            kernel[reduce_h,
+                                   reduce_w,
+                                   idxdiv(c, channel_multiplier),
+                                   idxmod(c, channel_multiplier)].astype(out_dtype),
+                            axis=[reduce_h, reduce_w]),
+                     name='depthwise_conv2d_nhwc_output')
+
+    return out
+
+@autotvm.register_topi_schedule("depthwise_conv2d_nhwc.arm_cpu")
+def schedule_depthwise_conv2d_nhwc(_, outs):
+    """Create the schedule for depthwise_conv2d_nchw_spatial_pack"""
+    outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs
+    s = te.create_schedule([x.op for x in outs])
+    out = outs[0]
+
+    def schedule_conv(conv):
+        n, w, h, c = conv.op.axis
+        r_h, r_w = conv.op.reduce_axis
+        co, ci = s[conv].split(c, 8)
+        wo, wi = s[conv].split(w, 2)
+        ho, hi = s[conv].split(h, 2)
+
+        s[conv].reorder(n, wo, ho, co, wi, hi, r_h, r_w, ci)
+        s[conv].parallel(wo)
+        s[conv].vectorize(ci)
+
+    def schedule_conv_out(out):
+        n, h, w, c = out.op.axis
+        co, ci = s[out].split(c, 8)
+        wo, wi = s[out].split(w, 2)
+        ho, hi = s[out].split(h, 2)
+        ci_outer, ci_inner = s[out].split(ci, 4)

Review comment:
       In the case of the last split, `s[out].split(ci, 4)`, I need this split for the quantized case because I want to use Arm intrinsics. I added a comment to explain this.
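   
   In schedule form, the point is to expose a fixed 4-lane inner axis that the quantized code can map onto Arm NEON instructions; a minimal sketch, using the names from the hunk above:
   
   ```python
   # Split the inner channel axis by 4 so the vectorized loop matches the
   # 4-lane vectors expected by the Arm intrinsics in the quantized case.
   ci_outer, ci_inner = s[out].split(ci, 4)
   s[out].vectorize(ci_inner)
   ```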







[GitHub] [incubator-tvm] FrozenGene commented on a change in pull request #6095: Improve NHWC depthwise convolution for AArch64

Posted by GitBox <gi...@apache.org>.
FrozenGene commented on a change in pull request #6095:
URL: https://github.com/apache/incubator-tvm/pull/6095#discussion_r458534291



##########
File path: topi/python/topi/arm_cpu/depthwise_conv2d.py
##########
@@ -181,6 +181,154 @@ def depthwise_conv2d_nchw_spatial_pack(cfg, data, kernel, strides, padding, dila
 
     return _decl_spatial_pack(cfg, data, kernel, strides, padding, dilation, out_dtype, num_tile=2)
 
+@autotvm.register_topi_compute("depthwise_conv2d_nhwc.arm_cpu")
+def compute_depthwise_conv2d_nhwc(_, data, kernel, strides, padding, dilation, out_dtype):
+    """TOPI compute callback for depthwise_conv2d nhwc
+
+    Parameters
+    ----------
+    cfg: ConfigEntity
+        The config for this template
+
+    data : tvm.te.Tensor
+        4-D with shape [batch, in_height, in_width, in_channel]
+
+    kernel : tvm.te.Tensor
+        4-D with shape [filter_height, filter_width, in_channel, channel_multiplier]
+
+    strides : list of two ints
+        [stride_height, stride_width]
+
+    padding : list of two ints
+        [pad_height, pad_width]
+
+    dilation : list of two ints
+        [dilation_height, dilation_width]
+
+    out_dtype: str
+        The output type. This is used for mixed precision.
+
+    Returns
+    -------
+    output : tvm.te.Tensor
+        4-D with shape [batch, out_height, out_width, out_channel]
+    """
+
+    out_dtype = out_dtype or data.dtype
+
+    N, IH, IW, IC = get_const_tuple(data.shape)
+
+    if isinstance(dilation, int):
+        dilation_h = dilation_w = dilation
+    else:
+        dilation_h, dilation_w = dilation
+
+    KH, KW, IC, channel_multiplier = get_const_tuple(kernel.shape)
+
+    dilated_kernel_h = (KH - 1) * dilation_h + 1
+    dilated_kernel_w = (KW - 1) * dilation_w + 1
+
+    pad_top, pad_left, pad_down, pad_right = get_pad_tuple(
+        padding, (dilated_kernel_h, dilated_kernel_w))
+    HSTR, WSTR = strides if isinstance(strides, (tuple, list)) else (strides, strides)
+
+    OH = (IH + pad_top + pad_down - dilated_kernel_h) // HSTR + 1
+    OW = (IW + pad_left + pad_right - dilated_kernel_w) // WSTR + 1
+
+    if pad_top or pad_left:
+        data_pad = nn.pad(data, [0, pad_top, pad_left, 0], [0, pad_down, pad_right, 0],
+                          name="data_pad")
+    else:
+        data_pad = data
+
+    output_shape = (N, OH, OW, IC*channel_multiplier)
+
+    idxdiv = tvm.tir.indexdiv
+    idxmod = tvm.tir.indexmod
+
+    reduce_h = te.reduce_axis((0, KH), name='reduce_h')
+    reduce_w = te.reduce_axis((0, KW), name='reduce_w')
+
+    out = te.compute(output_shape, lambda n, h, w, c:
+                     te.sum(data_pad[n,
+                                     HSTR*h+dilation_h*reduce_h,
+                                     w*WSTR+reduce_w*dilation_w,
+                                     idxdiv(c, channel_multiplier)].astype(out_dtype) *
+                            kernel[reduce_h,
+                                   reduce_w,
+                                   idxdiv(c, channel_multiplier),
+                                   idxmod(c, channel_multiplier)].astype(out_dtype),
+                            axis=[reduce_h, reduce_w]),
+                     name='depthwise_conv2d_nhwc_output')
+    return out
+
+@autotvm.register_topi_schedule("depthwise_conv2d_nhwc.arm_cpu")
+def schedule_depthwise_conv2d_nhwc(cfg, outs):
+    """Create the schedule for depthwise_conv2d_nchw_spatial_pack"""
+    outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs
+    s = te.create_schedule([x.op for x in outs])
+    out = outs[0]
+
+    ##### space definition begin #####
+    n, h, w, c = s[out].op.axis
+    cfg.define_split('tile_c', c, num_outputs=2)
+    _, hi = cfg.define_split('tile_h', h, num_outputs=2)
+    _, wi = cfg.define_split('tile_w', w, num_outputs=2)
+    cfg.define_annotate('locate_output', [hi, wi], 'locate_cache', num_anchor=1)
+
+    # fallback support
+    if cfg.is_fallback:
+        cfg['tile_c'] = SplitEntity([-1, 8])
+        cfg['tile_h'] = SplitEntity([-1, 2])
+        cfg['tile_w'] = SplitEntity([-1, 2])
+        cfg['locate_output'] = AnnotateEntity([1])
+    ##### space definition end #####
+
+    def schedule_conv(conv):
+        conv_data = conv.op.input_tensors[0]
+        if conv_data.name == "data_pad":
+            s[conv_data].compute_inline()

Review comment:
       Could you try an experiment? After the convolution split, the logic in pseudocode would be like this:
   ```python
           # val == 0: leave data_pad as a separate stage (no inline)
           cfg.define_knob('data_pad_inline', [0, 1, 2, 3])
           if cfg['data_pad_inline'].val == 1:
               s[conv_data].compute_at(s[conv], ho)
           elif cfg['data_pad_inline'].val == 2:
               s[conv_data].compute_at(s[conv], wo)
           elif cfg['data_pad_inline'].val == 3:
               s[conv_data].compute_inline()
   ```
   I suppose this could give better performance, especially when we run on multiple cores. This is also a core improvement, like my previous PR: https://github.com/apache/incubator-tvm/pull/2345
   







[GitHub] [incubator-tvm] giuseros commented on a change in pull request #6095: Improve NHWC depthwise convolution for AArch64

Posted by GitBox <gi...@apache.org>.
giuseros commented on a change in pull request #6095:
URL: https://github.com/apache/incubator-tvm/pull/6095#discussion_r458814745



##########
File path: topi/python/topi/arm_cpu/depthwise_conv2d.py
##########
@@ -181,6 +181,154 @@ def depthwise_conv2d_nchw_spatial_pack(cfg, data, kernel, strides, padding, dila
 
     return _decl_spatial_pack(cfg, data, kernel, strides, padding, dilation, out_dtype, num_tile=2)
 
+@autotvm.register_topi_compute("depthwise_conv2d_nhwc.arm_cpu")
+def compute_depthwise_conv2d_nhwc(_, data, kernel, strides, padding, dilation, out_dtype):
+    """TOPI compute callback for depthwise_conv2d nhwc
+
+    Parameters
+    ----------
+    cfg: ConfigEntity
+        The config for this template
+
+    data : tvm.te.Tensor
+        4-D with shape [batch, in_height, in_width, in_channel]
+
+    kernel : tvm.te.Tensor
+        4-D with shape [filter_height, filter_width, in_channel, channel_multiplier]
+
+    strides : list of two ints
+        [stride_height, stride_width]
+
+    padding : list of two ints
+        [pad_height, pad_width]
+
+    dilation : list of two ints
+        [dilation_height, dilation_width]
+
+    out_dtype: str
+        The output type. This is used for mixed precision.
+
+    Returns
+    -------
+    output : tvm.te.Tensor
+        4-D with shape [batch, out_height, out_width, out_channel]
+    """
+
+    out_dtype = out_dtype or data.dtype
+
+    N, IH, IW, IC = get_const_tuple(data.shape)
+
+    if isinstance(dilation, int):
+        dilation_h = dilation_w = dilation
+    else:
+        dilation_h, dilation_w = dilation
+
+    KH, KW, IC, channel_multiplier = get_const_tuple(kernel.shape)
+
+    dilated_kernel_h = (KH - 1) * dilation_h + 1
+    dilated_kernel_w = (KW - 1) * dilation_w + 1
+
+    pad_top, pad_left, pad_down, pad_right = get_pad_tuple(
+        padding, (dilated_kernel_h, dilated_kernel_w))
+    HSTR, WSTR = strides if isinstance(strides, (tuple, list)) else (strides, strides)
+
+    OH = (IH + pad_top + pad_down - dilated_kernel_h) // HSTR + 1
+    OW = (IW + pad_left + pad_right - dilated_kernel_w) // WSTR + 1
+
+    if pad_top or pad_left:
+        data_pad = nn.pad(data, [0, pad_top, pad_left, 0], [0, pad_down, pad_right, 0],
+                          name="data_pad")
+    else:
+        data_pad = data
+
+    output_shape = (N, OH, OW, IC*channel_multiplier)
+
+    idxdiv = tvm.tir.indexdiv
+    idxmod = tvm.tir.indexmod
+
+    reduce_h = te.reduce_axis((0, KH), name='reduce_h')
+    reduce_w = te.reduce_axis((0, KW), name='reduce_w')
+
+    out = te.compute(output_shape, lambda n, h, w, c:
+                     te.sum(data_pad[n,
+                                     HSTR*h+dilation_h*reduce_h,
+                                     w*WSTR+reduce_w*dilation_w,
+                                     idxdiv(c, channel_multiplier)].astype(out_dtype) *
+                            kernel[reduce_h,
+                                   reduce_w,
+                                   idxdiv(c, channel_multiplier),
+                                   idxmod(c, channel_multiplier)].astype(out_dtype),
+                            axis=[reduce_h, reduce_w]),
+                     name='depthwise_conv2d_nhwc_output')
+    return out
+
+@autotvm.register_topi_schedule("depthwise_conv2d_nhwc.arm_cpu")
+def schedule_depthwise_conv2d_nhwc(cfg, outs):
+    """Create the schedule for depthwise_conv2d_nchw_spatial_pack"""
+    outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs
+    s = te.create_schedule([x.op for x in outs])
+    out = outs[0]
+
+    ##### space definition begin #####
+    n, h, w, c = s[out].op.axis
+    cfg.define_split('tile_c', c, num_outputs=2)
+    _, hi = cfg.define_split('tile_h', h, num_outputs=2)
+    _, wi = cfg.define_split('tile_w', w, num_outputs=2)
+    cfg.define_annotate('locate_output', [hi, wi], 'locate_cache', num_anchor=1)
+
+    # fallback support
+    if cfg.is_fallback:
+        cfg['tile_c'] = SplitEntity([-1, 8])
+        cfg['tile_h'] = SplitEntity([-1, 2])
+        cfg['tile_w'] = SplitEntity([-1, 2])
+        cfg['locate_output'] = AnnotateEntity([1])
+    ##### space definition end #####
+
+    def schedule_conv(conv):
+        conv_data = conv.op.input_tensors[0]
+        if conv_data.name == "data_pad":
+            s[conv_data].compute_inline()

Review comment:
       Since those would be 4 more knob values to add, I extracted and tuned the depthwise operators in mobilenet_v2 with the 4 different policies (no-inline, inline, `compute_at{ho,wo}`) and reported the results in terms of the TFLite/TVM ratio (higher is better):
   
   |H/W | C   |S  |  inline                |  compute_at(ho)      |compute_at(wo)       | no-inline            |
   |----|-----|---|------------------------|----------------------|---------------------|----------------------|
   |112 | 96  |2  |  1.452941176470588     |  0.7042857142857142  |0.8355932203389829   | 0.36249999999999993  |
   |56  |144  |1  |  1.7249999999999999    |  0.85                |0.9714285714285715   | 1.38                 |
   |56  |144  |2  |  3.028571428571429     |  1.3187499999999999  |0.45869565217391306  | 1.5214285714285716   |
   |28  |192  |1  |  1.711111111111111     |  0.76                |0.5166666666666667   | 1.409090909090909    |
   |28  |192  |2  |  1.6833333333333333    |  0.5666666666666667  |0.48095238095238096  | 1.442857142857143    |
   |14  |384  |1  |  3.15                  |  1.26                |0.63                 | 0.5727272727272728   |
   |14  |576  |1  |  0.8863636363636364    |  0.97                |0.40625              | 0.527027027027027    |
   |14  |576  |2  |  2.4                   |  0.6857142857142858  |0.7000000000000001   | 0.6714285714285715   |
   |7   |960  |1  |  2.9272727272727272    |  1.211320754716981   |0.9056338028169014   | 1.3416666666666668   |
   
   * Since this is a memory-bound operator, not inlining the padding is always going to behave poorly, so I would remove that option to reduce the tuning time. There would also be an argument for avoiding the `compute_at` policies, but since I didn't try other networks (and in the small cases they seem to run better) I would keep those 3 knob values (instead of 4).
   * Except for one case, we are always faster (sometimes a lot faster) than TFLite. Once I am done with this and a few other improvements I will compare with ACL as well.
   







[GitHub] [incubator-tvm] giuseros commented on a change in pull request #6095: Improve NHWC depthwise convolution for AArch64

Posted by GitBox <gi...@apache.org>.
giuseros commented on a change in pull request #6095:
URL: https://github.com/apache/incubator-tvm/pull/6095#discussion_r458144144



##########
File path: topi/python/topi/arm_cpu/depthwise_conv2d.py
##########
@@ -181,6 +180,130 @@ def depthwise_conv2d_nchw_spatial_pack(cfg, data, kernel, strides, padding, dila
 
     return _decl_spatial_pack(cfg, data, kernel, strides, padding, dilation, out_dtype, num_tile=2)
 
+@autotvm.register_topi_compute("depthwise_conv2d_nhwc.arm_cpu")
+def compute_depthwise_conv2d_nhwc(_, data, kernel, strides, padding, dilation, out_dtype):
+    """TOPI compute callback for depthwise_conv2d nhwc
+
+    Parameters
+    ----------
+    cfg: ConfigEntity
+        The config for this template
+
+    data : tvm.te.Tensor
+        4-D with shape [batch, in_height, in_width, in_channel]
+
+    kernel : tvm.te.Tensor
+        4-D with shape [filter_height, filter_width, in_channel, channel_multiplier]
+
+    strides : list of two ints
+        [stride_height, stride_width]
+
+    padding : list of two ints
+        [pad_height, pad_width]
+
+    dilation : list of two ints
+        [dilation_height, dilation_width]
+
+    out_dtype: str
+        The output type. This is used for mixed precision.
+
+    Returns
+    -------
+    output : tvm.te.Tensor
+        4-D with shape [batch, out_height, out_width, out_channel]
+    """
+
+    out_dtype = out_dtype or data.dtype
+
+    N, IH, IW, IC = get_const_tuple(data.shape)
+
+    if isinstance(dilation, int):
+        dilation_h = dilation_w = dilation
+    else:
+        dilation_h, dilation_w = dilation
+
+    KH, KW, IC, channel_multiplier = get_const_tuple(kernel.shape)
+
+    dilated_kernel_h = (KH - 1) * dilation_h + 1
+    dilated_kernel_w = (KW - 1) * dilation_w + 1
+
+    pad_top, pad_left, pad_down, pad_right = get_pad_tuple(
+        padding, (dilated_kernel_h, dilated_kernel_w))
+    HSTR, WSTR = strides if isinstance(strides, (tuple, list)) else (strides, strides)
+
+    OH = (IH + pad_top + pad_down - dilated_kernel_h) // HSTR + 1
+    OW = (IW + pad_left + pad_right - dilated_kernel_w) // WSTR + 1
+
+    if pad_top or pad_left:
+        data_pad = nn.pad(data, [0, pad_top, pad_left, 0], [0, pad_down, pad_right, 0],
+                          name="data_pad")
+    else:
+        data_pad = data
+
+    output_shape = (N, OH, OW, IC*channel_multiplier)
+
+    idxdiv = tvm.tir.indexdiv
+    idxmod = tvm.tir.indexmod
+
+    reduce_h = te.reduce_axis((0, KH), name='reduce_h')
+    reduce_w = te.reduce_axis((0, KW), name='reduce_w')
+
+    out = te.compute(output_shape, lambda n, h, w, c:
+                     te.sum(data_pad[n,
+                                     HSTR*h+dilation_h*reduce_h,
+                                     w*WSTR+reduce_w*dilation_w,
+                                     idxdiv(c, channel_multiplier)].astype(out_dtype) *
+                            kernel[reduce_h,
+                                   reduce_w,
+                                   idxdiv(c, channel_multiplier),
+                                   idxmod(c, channel_multiplier)].astype(out_dtype),
+                            axis=[reduce_h, reduce_w]),
+                     name='depthwise_conv2d_nhwc_output')
+
+    return out
+
+@autotvm.register_topi_schedule("depthwise_conv2d_nhwc.arm_cpu")
+def schedule_depthwise_conv2d_nhwc(_, outs):
+    """Create the schedule for depthwise_conv2d_nchw_spatial_pack"""
+    outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs
+    s = te.create_schedule([x.op for x in outs])
+    out = outs[0]
+

Review comment:
       I added a `compute_inline` to basically merge the padding into the main loop. If I understand correctly, this will be used only for `fp32`, as for quantized the padding will always be 0 (since it is handled by a different Relay pass).
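   
   In schedule form, the change described here is simply (matching the hunks above):
   
   ```python
   # Fold the padding stage into the convolution loop nest so no
   # intermediate padded buffer is materialized.
   conv_data = conv.op.input_tensors[0]
   if conv_data.name == "data_pad":
       s[conv_data].compute_inline()
   ```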







+        if conv_data.name == "data_pad":
+            s[conv_data].compute_inline()

Review comment:
       Since those would be 4 more knobs to add, I extracted and tuned the depthwise operators in mobilenet_v2 with the 4 different policies (no inlining of the padding, inline, `compute_at{ho,wo}`) and reported the results as the ratio of TFLite time to TVM time (higher is better)
   
   |H/W |Channels|Stride| inline | compute_at(ho) | compute_at(wo) | no-inline |
   |----|--------|------|--------|----------------|----------------|-----------|
   |112 | 96     | 2    | 1.45   | 0.70           | 0.84           | 0.36      |
   |56  | 144    | 1    | 1.72   | 0.85           | 0.97           | 1.38      |
   |56  | 144    | 2    | 3.03   | 1.32           | 0.46           | 1.52      |
   |28  | 192    | 1    | 1.71   | 0.76           | 0.52           | 1.41      |
   |28  | 192    | 2    | 1.68   | 0.57           | 0.48           | 1.44      |
   |14  | 384    | 1    | 3.15   | 1.26           | 0.63           | 0.57      |
   |14  | 576    | 1    | 0.89   | 0.97           | 0.41           | 0.53      |
   |14  | 576    | 2    | 2.40   | 0.69           | 0.70           | 0.67      |
   |7   | 960    | 1    | 2.93   | 1.21           | 0.91           | 1.34      |
   
   * Since this is a memory-bound operator, not inlining the padding is always going to behave poorly, so I would remove that policy to reduce the tuning time. One could argue for dropping the `compute_at` policies as well, but since I didn't try other networks (and on small cases they seem to run better) I would keep those 3 knobs (instead of 4); the sketch after this list shows what each policy maps to.
   * Except in one case, we are always faster (sometimes much faster) than TFLite. Once I am done with this and a few other improvements I will compare against ACL as well.
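   For reference, a rough standalone sketch of what the four policies above map to in schedule primitives (shapes and split factors are made up for illustration, not taken from the tuned configs):

   ```python
   import tvm
   from tvm import te, topi

   data = te.placeholder((1, 28, 28, 192), name="data")
   kernel = te.placeholder((3, 3, 192, 1), name="kernel")
   data_pad = topi.nn.pad(data, [0, 1, 1, 0], [0, 1, 1, 0], name="data_pad")

   rh = te.reduce_axis((0, 3), name="rh")
   rw = te.reduce_axis((0, 3), name="rw")
   out = te.compute(
       (1, 28, 28, 192),
       lambda n, h, w, c: te.sum(
           data_pad[n, h + rh, w + rw, c] * kernel[rh, rw, c, 0],
           axis=[rh, rw]),
       name="depthwise_out")

   s = te.create_schedule(out.op)
   n, h, w, c = s[out].op.axis
   ho, hi = s[out].split(h, factor=2)
   wo, wi = s[out].split(w, factor=2)
   s[out].reorder(n, ho, wo, hi, wi, c)

   policy = "inline"  # "inline" | "compute_at_ho" | "compute_at_wo" | "no_inline"
   if policy == "inline":
       # fold padding into the main loop nest (best in the table above)
       s[data_pad].compute_inline()
   elif policy == "compute_at_ho":
       # materialize the padded slab once per outer-h iteration
       s[data_pad].compute_at(s[out], ho)
   elif policy == "compute_at_wo":
       # materialize the padded slab once per outer-w iteration
       s[data_pad].compute_at(s[out], wo)
   # "no_inline": do nothing, data_pad stays a standalone root stage
   ```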
   




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [incubator-tvm] FrozenGene commented on a change in pull request #6095: Improve NHWC depthwise convolution for AArch64

Posted by GitBox <gi...@apache.org>.
FrozenGene commented on a change in pull request #6095:
URL: https://github.com/apache/incubator-tvm/pull/6095#discussion_r458834973



##########
File path: topi/python/topi/arm_cpu/depthwise_conv2d.py
##########
@@ -181,6 +181,154 @@ def depthwise_conv2d_nchw_spatial_pack(cfg, data, kernel, strides, padding, dila
 
     return _decl_spatial_pack(cfg, data, kernel, strides, padding, dilation, out_dtype, num_tile=2)
 
+@autotvm.register_topi_compute("depthwise_conv2d_nhwc.arm_cpu")
+def compute_depthwise_conv2d_nhwc(_, data, kernel, strides, padding, dilation, out_dtype):
+    """TOPI compute callback for depthwise_conv2d nhwc
+
+    Parameters
+    ----------
+    cfg: ConfigEntity
+        The config for this template
+
+    data : tvm.te.Tensor
+        4-D with shape [batch, in_height, in_width, in_channel]
+
+    kernel : tvm.te.Tensor
+        4-D with shape [filter_height, filter_width, in_channel, channel_multiplier]
+
+    strides : list of two ints
+        [stride_height, stride_width]
+
+    padding : list of two ints
+        [pad_height, pad_width]
+
+    dilation : list of two ints
+        [dilation_height, dilation_width]
+
+    out_dtype: str
+        The output type. This is used for mixed precision.
+
+    Returns
+    -------
+    output : tvm.te.Tensor
+        4-D with shape [batch, out_height, out_width, out_channel]
+    """
+
+    out_dtype = out_dtype or data.dtype
+
+    N, IH, IW, IC = get_const_tuple(data.shape)
+
+    if isinstance(dilation, int):
+        dilation_h = dilation_w = dilation
+    else:
+        dilation_h, dilation_w = dilation
+
+    KH, KW, IC, channel_multiplier = get_const_tuple(kernel.shape)
+
+    dilated_kernel_h = (KH - 1) * dilation_h + 1
+    dilated_kernel_w = (KW - 1) * dilation_w + 1
+
+    pad_top, pad_left, pad_down, pad_right = get_pad_tuple(
+        padding, (dilated_kernel_h, dilated_kernel_w))
+    HSTR, WSTR = strides if isinstance(strides, (tuple, list)) else (strides, strides)
+
+    OH = (IH + pad_top + pad_down - dilated_kernel_h) // HSTR + 1
+    OW = (IW + pad_left + pad_right - dilated_kernel_w) // WSTR + 1
+
+    if pad_top or pad_left:
+        data_pad = nn.pad(data, [0, pad_top, pad_left, 0], [0, pad_down, pad_right, 0],
+                          name="data_pad")
+    else:
+        data_pad = data
+
+    output_shape = (N, OH, OW, IC*channel_multiplier)
+
+    idxdiv = tvm.tir.indexdiv
+    idxmod = tvm.tir.indexmod
+
+    reduce_h = te.reduce_axis((0, KH), name='reduce_h')
+    reduce_w = te.reduce_axis((0, KW), name='reduce_w')
+
+    out = te.compute(output_shape, lambda n, h, w, c:
+                     te.sum(data_pad[n,
+                                     HSTR*h+dilation_h*reduce_h,
+                                     w*WSTR+reduce_w*dilation_w,
+                                     idxdiv(c, channel_multiplier)].astype(out_dtype) *
+                            kernel[reduce_h,
+                                   reduce_w,
+                                   idxdiv(c, channel_multiplier),
+                                   idxmod(c, channel_multiplier)].astype(out_dtype),
+                            axis=[reduce_h, reduce_w]),
+                     name='depthwise_conv2d_nhwc_output')
+    return out
+
+@autotvm.register_topi_schedule("depthwise_conv2d_nhwc.arm_cpu")
+def schedule_depthwise_conv2d_nhwc(cfg, outs):
+    """Create the schedule for depthwise_conv2d_nchw_spatial_pack"""
+    outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs
+    s = te.create_schedule([x.op for x in outs])
+    out = outs[0]
+
+    ##### space definition begin #####
+    n, h, w, c = s[out].op.axis
+    cfg.define_split('tile_c', c, num_outputs=2)
+    _, hi = cfg.define_split('tile_h', h, num_outputs=2)
+    _, wi = cfg.define_split('tile_w', w, num_outputs=2)
+    cfg.define_annotate('locate_output', [hi, wi], 'locate_cache', num_anchor=1)
+
+    # fallback support
+    if cfg.is_fallback:
+        cfg['tile_c'] = SplitEntity([-1, 8])
+        cfg['tile_h'] = SplitEntity([-1, 2])
+        cfg['tile_w'] = SplitEntity([-1, 2])
+        cfg['locate_output'] = AnnotateEntity([1])
+    ##### space definition end #####
+
+    def schedule_conv(conv):
+        conv_data = conv.op.input_tensors[0]
+        if conv_data.name == "data_pad":
+            s[conv_data].compute_inline()

Review comment:
       @giuseros Thanks for the detailed experiment! I don't fully understand the data table, though. It seems compute_at doesn't behave better than compute_inline?
   
   Is the benchmark single-core or multi-core?




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [incubator-tvm] FrozenGene commented on a change in pull request #6095: Improve NHWC depthwise convolution for AArch64

Posted by GitBox <gi...@apache.org>.
FrozenGene commented on a change in pull request #6095:
URL: https://github.com/apache/incubator-tvm/pull/6095#discussion_r458874360



##########
File path: topi/python/topi/arm_cpu/depthwise_conv2d.py
##########
@@ -181,6 +181,154 @@ def depthwise_conv2d_nchw_spatial_pack(cfg, data, kernel, strides, padding, dila
 
     return _decl_spatial_pack(cfg, data, kernel, strides, padding, dilation, out_dtype, num_tile=2)
 
+@autotvm.register_topi_compute("depthwise_conv2d_nhwc.arm_cpu")
+def compute_depthwise_conv2d_nhwc(_, data, kernel, strides, padding, dilation, out_dtype):
+    """TOPI compute callback for depthwise_conv2d nhwc
+
+    Parameters
+    ----------
+    cfg: ConfigEntity
+        The config for this template
+
+    data : tvm.te.Tensor
+        4-D with shape [batch, in_height, in_width, in_channel]
+
+    kernel : tvm.te.Tensor
+        4-D with shape [filter_height, filter_width, in_channel, channel_multiplier]
+
+    strides : list of two ints
+        [stride_height, stride_width]
+
+    padding : list of two ints
+        [pad_height, pad_width]
+
+    dilation : list of two ints
+        [dilation_height, dilation_width]
+
+    out_dtype: str
+        The output type. This is used for mixed precision.
+
+    Returns
+    -------
+    output : tvm.te.Tensor
+        4-D with shape [batch, out_height, out_width, out_channel]
+    """
+
+    out_dtype = out_dtype or data.dtype
+
+    N, IH, IW, IC = get_const_tuple(data.shape)
+
+    if isinstance(dilation, int):
+        dilation_h = dilation_w = dilation
+    else:
+        dilation_h, dilation_w = dilation
+
+    KH, KW, IC, channel_multiplier = get_const_tuple(kernel.shape)
+
+    dilated_kernel_h = (KH - 1) * dilation_h + 1
+    dilated_kernel_w = (KW - 1) * dilation_w + 1
+
+    pad_top, pad_left, pad_down, pad_right = get_pad_tuple(
+        padding, (dilated_kernel_h, dilated_kernel_w))
+    HSTR, WSTR = strides if isinstance(strides, (tuple, list)) else (strides, strides)
+
+    OH = (IH + pad_top + pad_down - dilated_kernel_h) // HSTR + 1
+    OW = (IW + pad_left + pad_right - dilated_kernel_w) // WSTR + 1
+
+    if pad_top or pad_left:
+        data_pad = nn.pad(data, [0, pad_top, pad_left, 0], [0, pad_down, pad_right, 0],
+                          name="data_pad")
+    else:
+        data_pad = data
+
+    output_shape = (N, OH, OW, IC*channel_multiplier)
+
+    idxdiv = tvm.tir.indexdiv
+    idxmod = tvm.tir.indexmod
+
+    reduce_h = te.reduce_axis((0, KH), name='reduce_h')
+    reduce_w = te.reduce_axis((0, KW), name='reduce_w')
+
+    out = te.compute(output_shape, lambda n, h, w, c:
+                     te.sum(data_pad[n,
+                                     HSTR*h+dilation_h*reduce_h,
+                                     w*WSTR+reduce_w*dilation_w,
+                                     idxdiv(c, channel_multiplier)].astype(out_dtype) *
+                            kernel[reduce_h,
+                                   reduce_w,
+                                   idxdiv(c, channel_multiplier),
+                                   idxmod(c, channel_multiplier)].astype(out_dtype),
+                            axis=[reduce_h, reduce_w]),
+                     name='depthwise_conv2d_nhwc_output')
+    return out
+
+@autotvm.register_topi_schedule("depthwise_conv2d_nhwc.arm_cpu")
+def schedule_depthwise_conv2d_nhwc(cfg, outs):
+    """Create the schedule for depthwise_conv2d_nchw_spatial_pack"""
+    outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs
+    s = te.create_schedule([x.op for x in outs])
+    out = outs[0]
+
+    ##### space definition begin #####
+    n, h, w, c = s[out].op.axis
+    cfg.define_split('tile_c', c, num_outputs=2)
+    _, hi = cfg.define_split('tile_h', h, num_outputs=2)
+    _, wi = cfg.define_split('tile_w', w, num_outputs=2)
+    cfg.define_annotate('locate_output', [hi, wi], 'locate_cache', num_anchor=1)
+
+    # fallback support
+    if cfg.is_fallback:
+        cfg['tile_c'] = SplitEntity([-1, 8])
+        cfg['tile_h'] = SplitEntity([-1, 2])
+        cfg['tile_w'] = SplitEntity([-1, 2])
+        cfg['locate_output'] = AnnotateEntity([1])
+    ##### space definition end #####
+
+    def schedule_conv(conv):
+        conv_data = conv.op.input_tensors[0]
+        if conv_data.name == "data_pad":
+            s[conv_data].compute_inline()

Review comment:
       Ah... it just occurred to me that we left out one vectorize operation when we introduce compute_at, which in fact creates a new stage, i.e. we should do something like `s[data_pad].vectorize(list(s[data_pad].op.axis)[-1])`. I am not sure how much improvement this will bring, but we should do it, and I think you could try it very quickly; a sketch is below.
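   A quick sketch of that suggestion, under the same illustrative assumptions as the earlier snippets (made-up shapes, current `tvm.te`/`tvm.topi` APIs): once `data_pad` is computed at a loop of the output instead of being inlined, it is a real stage again, so its innermost (channel) axis can be vectorized.

   ```python
   import tvm
   from tvm import te, topi

   data = te.placeholder((1, 28, 28, 192), name="data")
   kernel = te.placeholder((3, 3, 192, 1), name="kernel")
   data_pad = topi.nn.pad(data, [0, 1, 1, 0], [0, 1, 1, 0], name="data_pad")

   rh = te.reduce_axis((0, 3), name="rh")
   rw = te.reduce_axis((0, 3), name="rw")
   out = te.compute(
       (1, 28, 28, 192),
       lambda n, h, w, c: te.sum(
           data_pad[n, h + rh, w + rw, c] * kernel[rh, rw, c, 0],
           axis=[rh, rw]),
       name="depthwise_out")

   s = te.create_schedule(out.op)
   n, h, w, c = s[out].op.axis
   ho, hi = s[out].split(h, factor=2)
   s[data_pad].compute_at(s[out], ho)
   # data_pad is a new stage here, so vectorize its last (channel) axis,
   # exactly as suggested above
   s[data_pad].vectorize(list(s[data_pad].op.axis)[-1])
   ```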




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org