You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2019/11/10 03:35:17 UTC

[GitHub] [incubator-tvm] jackwish commented on a change in pull request #4249: [TOPI][AlterOpLayout][ARM] Enabling NHWC to NCHW layout transformation.

jackwish commented on a change in pull request #4249: [TOPI][AlterOpLayout][ARM] Enabling NHWC to NCHW layout transformation.
URL: https://github.com/apache/incubator-tvm/pull/4249#discussion_r344470657
 
 

 ##########
 File path: topi/python/topi/arm_cpu/conv2d.py
 ##########
 @@ -508,40 +506,62 @@ def _alter_conv2d_layout_arm(attrs, inputs, tinfos, F):
     groups = attrs.get_int('groups')
     data_layout_key = "data_layout" if "data_layout" in new_attrs else "layout"
     layout = attrs[data_layout_key]
+    kernel_layout = attrs['kernel_layout']
     out_dtype = attrs["out_dtype"]
     if out_dtype in ("same", ""):
         out_dtype = tinfos[0].dtype
 
-    if layout != 'NCHW':
-        return None
     if dilation != (1, 1):
         logger.warning("Does not support weight pre-transform for dilated convolution.")
         return None
 
+    # query config of this workload
     data, kernel = tinfos[0:2]
-    N, CI, H, W = get_const_tuple(data.shape)
-    CO, _, KH, KW = get_const_tuple(kernel.shape)
+    if groups == 1:
+        workload = autotvm.task.args_to_workload(
+            [data, kernel, strides, padding, dilation, layout, out_dtype], conv2d)
+    else:
+        workload = autotvm.task.args_to_workload(
+            [data, kernel, strides, padding, dilation, out_dtype], depthwise_conv2d_nchw)
+
+    if layout == 'NCHW' and kernel_layout == 'OIHW':
+        N, CI, H, W = get_const_tuple(data.shape)
+        CO, _, KH, KW = get_const_tuple(kernel.shape)
+    elif layout == 'NHWC' and kernel_layout == 'HWIO' and groups == 1:
+        N, H, W, CI = get_const_tuple(data.shape)
+        KH, KW, _, CO = get_const_tuple(kernel.shape)
+        # Also modify the workload to pick up because later we convert to NCHW
+        # layout.
+        new_data = tvm.placeholder((N, CI, H, W), dtype=data.dtype)
+        new_kernel = tvm.placeholder((CO, CI, KH, KW), dtype=kernel.dtype)
+        new_layout = 'NCHW'
+        workload = autotvm.task.args_to_workload(
+            [new_data, new_kernel, strides, padding, dilation, new_layout, out_dtype], conv2d)
+    else:
+        return None
 
     idxd = tvm.indexdiv
 
     if groups == 1:
-        # query config of this workload
-        workload = autotvm.task.args_to_workload(
-            [data, kernel, strides, padding, dilation, layout, out_dtype], conv2d)
         target = tvm.target.current_target()
         dispatch_ctx = autotvm.DispatchContext.current
         cfg = dispatch_ctx.query(target, workload)
 
         if cfg.is_fallback:  # if is fallback, clear query cache and return None
             autotvm.task.clear_fallback_cache(target, workload)
+            if layout != 'NCHW':
 
 Review comment:
  What about changing this to something like https://github.com/apache/incubator-tvm/pull/4249/files#diff-7ac4fe4f71257850a5471c9c63647bb0R530 (L530)?

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services