You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by li...@apache.org on 2019/11/11 03:09:22 UTC

[incubator-tvm] branch master updated: [TOPI][AlterOpLayout][ARM] Enabling NHWC to NCHW layout transformation. (#4249)

This is an automated email from the ASF dual-hosted git repository.

liuyizhi pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git


The following commit(s) were added to refs/heads/master by this push:
     new 1d24366  [TOPI][AlterOpLayout][ARM] Enabling NHWC to NCHW layout transformation. (#4249)
1d24366 is described below

commit 1d2436647bffdcbb1e133b55dc4c7365f604fc3d
Author: Animesh Jain <an...@umich.edu>
AuthorDate: Sun Nov 10 19:09:16 2019 -0800

    [TOPI][AlterOpLayout][ARM] Enabling NHWC to NCHW layout transformation. (#4249)
---
 tests/python/relay/test_pass_alter_op_layout.py |  62 +++++++++++
 tests/python/relay/test_pass_legalize.py        |  44 --------
 topi/python/topi/arm_cpu/conv2d.py              | 130 ++++++++++++------------
 3 files changed, 129 insertions(+), 107 deletions(-)

diff --git a/tests/python/relay/test_pass_alter_op_layout.py b/tests/python/relay/test_pass_alter_op_layout.py
index f1200ec..2738690 100644
--- a/tests/python/relay/test_pass_alter_op_layout.py
+++ b/tests/python/relay/test_pass_alter_op_layout.py
@@ -916,6 +916,67 @@ def test_alter_layout_sum():
     assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
 
 
+def test_alter_layout_nhwc_nchw_arm():
+    """ Check NHWC to NCHW conversion for a small sequence of ops."""
+    # Register alter op layout. "level" is used to override the previously registered functions.
+    @register_alter_op_layout("nn.conv2d", level=115)
+    def alter_conv2d(attrs, inputs, tinfos):
+        from topi.arm_cpu.conv2d import _alter_conv2d_layout_arm
+        return _alter_conv2d_layout_arm(attrs, inputs, tinfos, tvm.relay)
+
+    # Check NHWC conversion.
+    def before_nhwc():
+        x = relay.var("x", shape=(1, 56, 56, 64))
+        weight1 = relay.var('weight1', shape=(3, 3, 64, 64))
+        weight2 = relay.var('weight2', shape=(3, 3, 64, 64))
+        y = relay.nn.conv2d(x, weight1,
+                            channels=64,
+                            kernel_size=(3, 3),
+                            data_layout='NHWC',
+                            kernel_layout='HWIO')
+        y = relay.nn.relu(y)
+        y = relay.nn.avg_pool2d(y,
+                                pool_size=(1,1),
+                                layout='NHWC')
+        y = relay.nn.conv2d(y, weight2,
+                            channels=64,
+                            kernel_size=(3, 3),
+                            data_layout='NHWC',
+                            kernel_layout='HWIO')
+        y = relay.nn.relu(y)
+        y = relay.Function(analysis.free_vars(y), y)
+        return y
+
+    def expected_nhwc():
+        x = relay.var("x", shape=(1, 56, 56, 64))
+        weight1 = relay.var('weight1', shape=(3, 3, 64, 64))
+        weight2 = relay.var('weight2', shape=(3, 3, 64, 64))
+        y = relay.layout_transform(x, "NHWC", "NCHW")
+        weight1 = relay.layout_transform(weight1, "HWIO", "OIHW")
+        weight2 = relay.layout_transform(weight2, "HWIO", "OIHW")
+        y = relay.nn.conv2d(y, weight1,
+                            channels=64,
+                            kernel_size=(3, 3))
+        y = relay.nn.relu(y)
+        y = relay.nn.avg_pool2d(y,
+                                pool_size=(1,1))
+        y = relay.nn.conv2d(y, weight2,
+                            channels=64,
+                            kernel_size=(3, 3))
+        y = relay.nn.relu(y)
+        y = relay.layout_transform(y, "NCHW", "NHWC")
+        y = relay.Function(analysis.free_vars(y), y)
+        return y
+
+    a = before_nhwc()
+    a = run_opt_pass(a, transform.AlterOpLayout())
+
+    b = expected_nhwc()
+    b = run_opt_pass(b, transform.InferType())
+
+    assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
+
+
 if __name__ == "__main__":
     test_alter_op()
     test_alter_return_none()
@@ -932,3 +993,4 @@ if __name__ == "__main__":
     test_alter_layout_pad()
     test_alter_layout_pool()
     test_alter_layout_sum()
+    test_alter_layout_nhwc_nchw_arm()
diff --git a/tests/python/relay/test_pass_legalize.py b/tests/python/relay/test_pass_legalize.py
index c5303ef..2f0fbee 100644
--- a/tests/python/relay/test_pass_legalize.py
+++ b/tests/python/relay/test_pass_legalize.py
@@ -171,53 +171,9 @@ def test_legalize_multi_input():
 
     assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
 
-def test_legalize_arm_layout_functional():
-    """Test if the legalized conversion yields same result as original"""
-    def get_output(func, data_val, parameters):
-        with relay.build_config(opt_level=0):
-            graph, lib, params = relay.build(func, target='llvm', params=parameters)
-        m = graph_runtime.create(graph, lib, tvm.cpu())
-        m.set_input("data", data_val)
-        m.set_input(**params)
-        m.run()
-        out = m.get_output(0, tvm.nd.empty((1, 224, 224, 32), 'float32')).asnumpy()
-        return out
-
-    def before():
-        n, ic, ih, iw, oc, kh, kw = 1, 16, 224, 224, 32, 3, 3
-        data = relay.var("data", relay.TensorType((n, ih, iw, ic), 'float32'))
-        kernel = relay.var("kernel", relay.TensorType((kh, kw, ic, oc), 'float32'))
-        y = relay.nn.conv2d(data, kernel,
-                            kernel_size=(kh, kw),
-                            channels=oc,
-                            padding=(1, 1),
-                            dilation=(1, 1),
-                            data_layout='NHWC',
-                            kernel_layout='HWIO',
-                            out_dtype='float32')
-        func = relay.Function([data, kernel], y)
-        return func
-
-    @register_legalize("nn.conv2d", level=105)
-    def legalize_conv2d(attrs, inputs, types):
-        from topi.arm_cpu.conv2d import _conv2d_legalize
-        return _conv2d_legalize(attrs, inputs, types)
-
-    a = before()
-    b = run_opt_pass(a, transform.Legalize())
-    assert b.astext().count('transpose') == 3
-
-    wdata = np.random.rand(3, 3, 16, 32) * 10
-    parameters = {"kernel": tvm.nd.array(wdata.astype('float32'))}
-    data_val = np.random.rand(1, 224, 224, 16).astype('float32')
-    ref_out = get_output(a, data_val, parameters)
-    legalized_out = get_output(b, data_val, parameters)
-    np.testing.assert_allclose(ref_out, legalized_out, rtol=0.01)
-
 
 if __name__ == "__main__":
     test_legalize()
     test_legalize_none()
     test_legalize_multiple_ops()
     test_legalize_multi_input()
-    test_legalize_arm_layout_functional()
diff --git a/topi/python/topi/arm_cpu/conv2d.py b/topi/python/topi/arm_cpu/conv2d.py
index c06c739..cbb6085 100644
--- a/topi/python/topi/arm_cpu/conv2d.py
+++ b/topi/python/topi/arm_cpu/conv2d.py
@@ -22,7 +22,6 @@ import logging
 
 import tvm
 from tvm import autotvm
-from tvm import relay
 import tvm.contrib.nnpack
 
 from ..generic import schedule_conv2d_nchw, schedule_conv2d_winograd_without_weight_transform, \
@@ -32,7 +31,6 @@ from ..nn import dilate, pad, conv2d, conv2d_alter_layout, \
                  conv2d_winograd_without_weight_transform, \
                  conv2d_winograd_nnpack_without_weight_transform, \
                  depthwise_conv2d_nchw
-from ..nn import conv2d_legalize
 from ..nn.util import get_const_int, get_pad_tuple
 from ..nn.winograd_util import winograd_transform_matrices
 from .conv2d_spatial_pack import conv2d_spatial_pack_nchw, \
@@ -508,32 +506,63 @@ def _alter_conv2d_layout_arm(attrs, inputs, tinfos, F):
     groups = attrs.get_int('groups')
     data_layout_key = "data_layout" if "data_layout" in new_attrs else "layout"
     layout = attrs[data_layout_key]
+    kernel_layout = attrs['kernel_layout']
     out_dtype = attrs["out_dtype"]
     if out_dtype in ("same", ""):
         out_dtype = tinfos[0].dtype
 
-    if layout != 'NCHW':
-        return None
     if dilation != (1, 1):
         logger.warning("Does not support weight pre-transform for dilated convolution.")
         return None
 
+    # query config of this workload
     data, kernel = tinfos[0:2]
-    N, CI, H, W = get_const_tuple(data.shape)
-    CO, _, KH, KW = get_const_tuple(kernel.shape)
+    if groups == 1:
+        workload = autotvm.task.args_to_workload(
+            [data, kernel, strides, padding, dilation, layout, out_dtype], conv2d)
+    else:
+        workload = autotvm.task.args_to_workload(
+            [data, kernel, strides, padding, dilation, out_dtype], depthwise_conv2d_nchw)
+
+    if layout == 'NCHW' and kernel_layout == 'OIHW':
+        N, CI, H, W = get_const_tuple(data.shape)
+        CO, _, KH, KW = get_const_tuple(kernel.shape)
+    elif layout == 'NHWC' and kernel_layout == 'HWIO':
+        N, H, W, CI = get_const_tuple(data.shape)
+        KH, KW, _, CO = get_const_tuple(kernel.shape)
+        # Rebuild the workload with NCHW/OIHW shapes: the op is converted to NCHW
+        # layout below, so the config must be queried for that layout.
+        new_data = tvm.placeholder((N, CI, H, W), dtype=data.dtype)
+        new_kernel = tvm.placeholder((CO, CI, KH, KW), dtype=kernel.dtype)
+        new_layout = 'NCHW'
+        workload = autotvm.task.args_to_workload(
+            [new_data, new_kernel, strides, padding, dilation, new_layout, out_dtype], conv2d)
+    elif layout == 'NHWC' and kernel_layout == 'HWOI':
+        # This is the case for depthwise convolution.
+        N, H, W, CI = get_const_tuple(data.shape)
+        KH, KW, CO, M = get_const_tuple(kernel.shape)
+        # Rebuild the depthwise workload with NCHW-ordered shapes: the op is
+        # converted to NCHW layout below, so the config is queried for that layout.
+        new_data = tvm.placeholder((N, CI, H, W), dtype=data.dtype)
+        new_kernel = tvm.placeholder((CO, M, KH, KW), dtype=kernel.dtype)
+        workload = autotvm.task.args_to_workload(
+            [new_data, new_kernel, strides, padding, dilation, out_dtype], depthwise_conv2d_nchw)
+    else:
+        return None
 
     idxd = tvm.indexdiv
 
     if groups == 1:
-        # query config of this workload
-        workload = autotvm.task.args_to_workload(
-            [data, kernel, strides, padding, dilation, layout, out_dtype], conv2d)
         target = tvm.target.current_target()
         dispatch_ctx = autotvm.DispatchContext.current
         cfg = dispatch_ctx.query(target, workload)
 
         if cfg.is_fallback:  # if is fallback, clear query cache and return None
             autotvm.task.clear_fallback_cache(target, workload)
+            if layout == 'NHWC' and kernel_layout == 'HWIO':
+                new_attrs['data_layout'] = 'NCHW'
+                new_attrs['kernel_layout'] = 'OIHW'
+                return F.nn.conv2d(*copy_inputs, **new_attrs)
             return None
 
         if cfg.template_key == 'direct':  # pack weight tensor
@@ -541,7 +570,8 @@ def _alter_conv2d_layout_arm(attrs, inputs, tinfos, F):
             new_attrs['kernel_layout'] = 'OIHW%do' % VC
 
             # Store the same config for the altered operator (workload)
-            new_data = data
+            new_data = tvm.placeholder((N, CI, H, W), dtype=data.dtype)
+            new_attrs[data_layout_key] = 'NCHW'
             new_kernel = tvm.placeholder((idxd(CO, VC), CI, KH, KW, VC), dtype=kernel.dtype)
             new_workload = autotvm.task.args_to_workload(
                 [new_data, new_kernel, strides, padding, dilation, 'NCHW', out_dtype], conv2d)
@@ -560,7 +590,10 @@ def _alter_conv2d_layout_arm(attrs, inputs, tinfos, F):
                 tile_size = _pick_tile_size(tinfos[0], tinfos[1])
                 VC = cfg['tile_bna'].val
 
-            weight = F.nn.contrib_conv2d_winograd_weight_transform(copy_inputs[1],
+            weight = copy_inputs[1]
+            if kernel_layout != 'OIHW':
+                weight = F.transpose(weight, axes=(2, 3, 0, 1))
+            weight = F.nn.contrib_conv2d_winograd_weight_transform(weight,
                                                                    tile_size=tile_size)
             if VC > 0:
                 weight = F.reshape(weight,
@@ -581,9 +614,10 @@ def _alter_conv2d_layout_arm(attrs, inputs, tinfos, F):
 
             copy_inputs[1] = weight
             new_attrs['tile_size'] = tile_size
+            new_attrs[data_layout_key] = 'NCHW'
 
             # Store the same config for the altered operator (workload)
-            new_data = data
+            new_data = tvm.placeholder((N, CI, H, W), dtype=data.dtype)
             new_workload = autotvm.task.args_to_workload(
                 [new_data, new_weight, strides, padding, dilation,
                  new_attrs[data_layout_key], out_dtype, tile_size],
@@ -596,14 +630,21 @@ def _alter_conv2d_layout_arm(attrs, inputs, tinfos, F):
             # for winograd_nnpack_fp16, the the precomputeprune pass must run on device,
             # where float16 is supported
             weight_dtype = 'float32'
+            weight = copy_inputs[1]
+            if kernel_layout != 'OIHW':
+                weight = F.transpose(weight, axes=(2, 3, 0, 1))
+            weight = F.nn.contrib_conv2d_winograd_weight_transform(weight,
+                                                                   tile_size=tile_size)
             transformed_kernel = F.nn.contrib_conv2d_winograd_nnpack_weight_transform(
-                copy_inputs[1],
+                weight,
                 convolution_algorithm=cfg['winograd_nnpack_algorithm'].val,
                 out_dtype=weight_dtype)
             copy_inputs[1] = transformed_kernel
-            new_data = data
+
+            new_data = tvm.placeholder((N, CI, H, W), dtype=data.dtype)
             new_kernel = tvm.placeholder((CO, CI, 8, 8), "float32")
             bias = tvm.placeholder((CO, ), "float32")
+            new_attrs[data_layout_key] = 'NCHW'
             new_workload = autotvm.task.args_to_workload(
                 [new_data, new_kernel, bias, strides,
                  padding, dilation, new_attrs[data_layout_key], out_dtype]
@@ -617,22 +658,30 @@ def _alter_conv2d_layout_arm(attrs, inputs, tinfos, F):
         else:
             raise RuntimeError("Unsupported template_key '%s'" % cfg.template_key)
     else:
-        workload = autotvm.task.args_to_workload(
-            [data, kernel, strides, padding, dilation, out_dtype], depthwise_conv2d_nchw)
         target = tvm.target.current_target()
         dispatch_ctx = autotvm.DispatchContext.current
         cfg = dispatch_ctx.query(target, workload)
 
         if cfg.is_fallback:  # if is fallback, clear query cache and return None
             autotvm.task.clear_fallback_cache(tvm.target.current_target(), workload)
+            if layout == 'NHWC' and kernel_layout == 'HWOI':
+                new_attrs['data_layout'] = 'NCHW'
+                new_attrs['kernel_layout'] = 'OIHW'
+                return F.nn.conv2d(*copy_inputs, **new_attrs)
             return None
         if cfg.template_key == 'contrib_spatial_pack':
             VC = cfg['tile_co'].size[-1]
             new_attrs['kernel_layout'] = 'OIHW%do' % (cfg['tile_co'].size[-1])
 
             # Store the same config for the altered operator (workload)
-            new_data = data
-            CO, M, KH, KW = get_const_tuple(kernel.shape)
+            new_data = tvm.placeholder((N, CI, H, W), dtype=data.dtype)
+            new_attrs[data_layout_key] = 'NCHW'
+            if attrs['kernel_layout'] == 'OIHW':
+                CO, M, KH, KW = get_const_tuple(kernel.shape)
+            elif attrs['kernel_layout'] == 'HWOI':
+                KH, KW, CO, M = get_const_tuple(kernel.shape)
+            else:
+                raise RuntimeError("Depthwise conv should either have OIHW/HWIO kernel layout")
             new_kernel = tvm.placeholder((idxd(CO, VC), M, KH, KW, VC), dtype=kernel.dtype)
             new_workload = autotvm.task.args_to_workload(
                 [new_data, new_kernel, strides, padding, dilation, out_dtype],
@@ -644,48 +693,3 @@ def _alter_conv2d_layout_arm(attrs, inputs, tinfos, F):
             # currently we only have contrib_spatial_pack and direct template
             # add more schedule templates.
             return None
-
-@conv2d_legalize.register("arm_cpu")
-def _conv2d_legalize(attrs, inputs, arg_types):
-    """Legalizes Conv2D op.
-
-    Parameters
-    ----------
-    attrs : tvm.attrs.Attrs
-        Attributes of current convolution
-    inputs : list of tvm.relay.Expr
-        The args of the Relay expr to be legalized
-    types : list of types
-        List of input and output types
-
-    Returns
-    -------
-    result : tvm.relay.Expr
-        The legalized expr
-    """
-
-    if attrs['data_layout'] == 'NHWC':
-        data, kernel = inputs
-        if attrs['kernel_layout'] == 'HWIO':
-            # Handle HWIO layout. This is common in TF graph.
-            kernel = relay.transpose(kernel, axes=(3, 2, 0, 1))
-        elif attrs['kernel_layout'] == 'HWOI':
-            # Handle HWOI layout. This is common in TF depthwise conv2d graph.
-            kernel = relay.transpose(kernel, axes=(2, 3, 0, 1))
-        elif attrs['kernel_layout'] != 'OIHW':
-            return None
-
-        logger.warning("Legalize arm_cpu - NHWC schedule absent. Inserting layout transforms to "
-                       + "fallback to NCHW. This can result in performance degradation.")
-        # Set new attrs for the tranposed conv.
-        new_attrs = {k: attrs[k] for k in attrs.keys()}
-        new_attrs['data_layout'] = 'NCHW'
-        new_attrs['kernel_layout'] = 'OIHW'
-
-        # Convert from NHWC to NCHW.
-        data = relay.transpose(data, axes=(0, 3, 1, 2))
-        conv = relay.nn.conv2d(data, kernel, **new_attrs)
-        # Convert back to original NHWC layout.
-        out = relay.transpose(conv, axes=(0, 2, 3, 1))
-        return out
-    return None