You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2019/11/10 03:35:17 UTC
[GitHub] [incubator-tvm] jackwish commented on a change in pull request #4249: [TOPI][AlterOpLayout][ARM] Enabling NHWC to NCHW layout transformation.
jackwish commented on a change in pull request #4249: [TOPI][AlterOpLayout][ARM] Enabling NHWC to NCHW layout transformation.
URL: https://github.com/apache/incubator-tvm/pull/4249#discussion_r344470657
##########
File path: topi/python/topi/arm_cpu/conv2d.py
##########
@@ -508,40 +506,62 @@ def _alter_conv2d_layout_arm(attrs, inputs, tinfos, F):
groups = attrs.get_int('groups')
data_layout_key = "data_layout" if "data_layout" in new_attrs else "layout"
layout = attrs[data_layout_key]
+ kernel_layout = attrs['kernel_layout']
out_dtype = attrs["out_dtype"]
if out_dtype in ("same", ""):
out_dtype = tinfos[0].dtype
- if layout != 'NCHW':
- return None
if dilation != (1, 1):
logger.warning("Does not support weight pre-transform for dilated convolution.")
return None
+ # query config of this workload
data, kernel = tinfos[0:2]
- N, CI, H, W = get_const_tuple(data.shape)
- CO, _, KH, KW = get_const_tuple(kernel.shape)
+ if groups == 1:
+ workload = autotvm.task.args_to_workload(
+ [data, kernel, strides, padding, dilation, layout, out_dtype], conv2d)
+ else:
+ workload = autotvm.task.args_to_workload(
+ [data, kernel, strides, padding, dilation, out_dtype], depthwise_conv2d_nchw)
+
+ if layout == 'NCHW' and kernel_layout == 'OIHW':
+ N, CI, H, W = get_const_tuple(data.shape)
+ CO, _, KH, KW = get_const_tuple(kernel.shape)
+ elif layout == 'NHWC' and kernel_layout == 'HWIO' and groups == 1:
+ N, H, W, CI = get_const_tuple(data.shape)
+ KH, KW, _, CO = get_const_tuple(kernel.shape)
+ # Also modify the workload to pick up because later we convert to NCHW
+ # layout.
+ new_data = tvm.placeholder((N, CI, H, W), dtype=data.dtype)
+ new_kernel = tvm.placeholder((CO, CI, KH, KW), dtype=kernel.dtype)
+ new_layout = 'NCHW'
+ workload = autotvm.task.args_to_workload(
+ [new_data, new_kernel, strides, padding, dilation, new_layout, out_dtype], conv2d)
+ else:
+ return None
idxd = tvm.indexdiv
if groups == 1:
- # query config of this workload
- workload = autotvm.task.args_to_workload(
- [data, kernel, strides, padding, dilation, layout, out_dtype], conv2d)
target = tvm.target.current_target()
dispatch_ctx = autotvm.DispatchContext.current
cfg = dispatch_ctx.query(target, workload)
if cfg.is_fallback: # if is fallback, clear query cache and return None
autotvm.task.clear_fallback_cache(target, workload)
+ if layout != 'NCHW':
Review comment:
What about changing this to something like https://github.com/apache/incubator-tvm/pull/4249/files#diff-7ac4fe4f71257850a5471c9c63647bb0R530 (L530)?
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
users@infra.apache.org
With regards,
Apache Git Services