Posted to commits@tvm.apache.org by cs...@apache.org on 2022/05/03 21:13:01 UTC

[tvm] branch main updated: [Hexagon] Add schedule and test for conv2d_transpose_nchw (#11175)

This is an automated email from the ASF dual-hosted git repository.

csullivan pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new eb3ce911d0 [Hexagon] Add schedule and test for conv2d_transpose_nchw (#11175)
eb3ce911d0 is described below

commit eb3ce911d04c2d7915d1b2d5f29f333595785b2b
Author: Farshid Salemi Parizi <fp...@octoml.ai>
AuthorDate: Tue May 3 14:12:56 2022 -0700

    [Hexagon] Add schedule and test for conv2d_transpose_nchw (#11175)
    
    * Add test for registered schedules - depthwise_conv2d
    
    * added more tests to depthwise_conv2d
    
    * added a newline at the end of the file
    
    * reformatted the file
    
    * resolve comments
    
    * add schedule and tests for conv2d_transpose_nchw
    
    * registered the conv2d_transpose strategy and cleaned up tests
---
 python/tvm/relay/op/strategy/hexagon.py            |  20 +++
 python/tvm/topi/hexagon/conv2d.py                  |  26 ++++
 .../test_hexagon/topi/test_conv2d_transpose.py     | 157 +++++++++++++++++++++
 3 files changed, 203 insertions(+)

diff --git a/python/tvm/relay/op/strategy/hexagon.py b/python/tvm/relay/op/strategy/hexagon.py
index cfd9a8b5dd..da15a54125 100644
--- a/python/tvm/relay/op/strategy/hexagon.py
+++ b/python/tvm/relay/op/strategy/hexagon.py
@@ -112,6 +112,26 @@ def softmax_strategy_hexagon(attrs, inputs, out_type, target):
     return strategy
 
 
+@conv2d_transpose_strategy.register("hexagon")
+def conv2d_transpose_strategy_hexagon(attrs, inputs, out_type, target):
+    """conv2d_transpose hexagon strategy"""
+    layout = attrs.data_layout
+    dilation = get_const_tuple(attrs.dilation)
+    groups = attrs.groups
+    assert layout == "NCHW", "only NCHW layout is supported for now"
+    assert dilation == (1, 1), "dilation is not supported for now"
+    strategy = _op.OpStrategy()
+    if groups == 1:
+        strategy.add_implementation(
+            wrap_compute_conv2d_transpose(topi.nn.conv2d_transpose_nchw),
+            wrap_topi_schedule(topi.hexagon.schedule_conv2d_transpose_nchw),
+            name="conv2d_transpose_nchw.generic",
+        )
+    else:
+        raise RuntimeError("Unsupported conv2d_transpose groups {}".format(groups))
+    return strategy
+
+
 # --- Op schedule registration
 
 
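With this registration in place, Relay lowers any nn.conv2d_transpose in NCHW layout with groups == 1 through topi.nn.conv2d_transpose_nchw and the new Hexagon schedule whenever the build targets Hexagon. A minimal sketch of exercising the strategy from the graph level follows; the shapes, the v68 target, and the explicit IOHW weight convention are illustrative assumptions, not part of this commit:

    import tvm
    from tvm import relay

    # Hypothetical NCHW workload; groups defaults to 1, so it selects the
    # strategy registered above.
    data = relay.var("data", shape=(1, 3, 8, 8), dtype="float32")
    # Weight shape follows the (in_channel, out_channel, kH, kW) convention
    # that the TOPI compute expects.
    weight = relay.var("weight", shape=(3, 4, 2, 2), dtype="float32")
    out = relay.nn.conv2d_transpose(
        data,
        weight,
        channels=4,
        kernel_size=(2, 2),
        strides=(1, 1),
        data_layout="NCHW",
        kernel_layout="IOHW",
    )
    mod = tvm.IRModule.from_expr(relay.Function([data, weight], out))

    target = tvm.target.hexagon("v68")
    with tvm.transform.PassContext(opt_level=3):
        lib = relay.build(mod, target=tvm.target.Target(target, host=target))
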
diff --git a/python/tvm/topi/hexagon/conv2d.py b/python/tvm/topi/hexagon/conv2d.py
index 4f564faa0a..d8f44d6638 100644
--- a/python/tvm/topi/hexagon/conv2d.py
+++ b/python/tvm/topi/hexagon/conv2d.py
@@ -18,6 +18,7 @@
 """Schedule for conv2d"""
 
 import tvm
+from ..utils import traverse_inline
 
 
 def schedule_conv2d_nhwc(outs):
@@ -60,3 +61,28 @@ def schedule_depthwise_conv2d_nchw(outs):
 
 def schedule_depthwise_conv2d_nhwc(out):
     return schedule_conv2d_nhwc(out)
+
+
+def schedule_conv2d_transpose_nchw(outs):
+    """Create schedule for tensors"""
+    outs = [outs] if isinstance(outs, tvm.te.tensor.Tensor) else outs
+    s = schedule_conv2d_nchw(outs)
+
+    def _callback(op):
+        if "unpack_nchwc" in op.tag:
+            conv_out = op.input_tensors[0]
+            # retrieve data
+            data_vec = conv_out.op.input_tensors[0]
+            if isinstance(data_vec.op, tvm.te.ComputeOp):
+                data_pad = data_vec.op.input_tensors[0]
+                data_dilate = data_pad.op.input_tensors[0]
+                s[data_dilate].compute_inline()
+                s[data_pad].compute_inline()
+            # retrieve kernel
+            kernel_vec = conv_out.op.input_tensors[1]
+            if isinstance(kernel_vec.op, tvm.te.ComputeOp):
+                kernel_transform = kernel_vec.op.input_tensors[0]
+                s[kernel_transform].compute_inline()
+
+    traverse_inline(s, outs[0].op, _callback)
+    return s
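The schedule builds on schedule_conv2d_nchw and then walks the dataflow graph with traverse_inline: when it encounters a stage tagged "unpack_nchwc", the data dilation and padding stages and the kernel transform feeding the convolution are computed inline rather than materialized in memory. A minimal sketch of driving the schedule directly at the TE level (the shapes and the v68 target are illustrative assumptions):

    import tvm
    from tvm import te, topi

    # Hypothetical NCHW workload; weight shape is (in_channel, out_channel, kH, kW).
    A = te.placeholder((1, 3, 8, 8), name="A")
    W = te.placeholder((3, 4, 2, 2), name="W")

    with tvm.target.Target(tvm.target.hexagon("v68")):
        B = topi.nn.conv2d_transpose_nchw(
            A, W, (1, 1), (0, 0, 0, 0), A.dtype, output_padding=(0, 0)
        )
        s = topi.hexagon.schedule_conv2d_transpose_nchw([B])

    # Inspect the lowered loop nest produced by the schedule.
    print(tvm.lower(s, [A, W, B], simple_mode=True))
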
diff --git a/tests/python/contrib/test_hexagon/topi/test_conv2d_transpose.py b/tests/python/contrib/test_hexagon/topi/test_conv2d_transpose.py
new file mode 100644
index 0000000000..1dbac67aeb
--- /dev/null
+++ b/tests/python/contrib/test_hexagon/topi/test_conv2d_transpose.py
@@ -0,0 +1,157 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Test code for transposed convolution."""
+import numpy as np
+import tvm
+import tvm.testing
+from tvm import te
+from tvm import topi
+import tvm.topi.testing
+from tvm.contrib.pickle_memoize import memoize
+from tvm.topi.utils import get_const_tuple
+from ..conftest import requires_hexagon_toolchain
+
+
+# TODO: Should add kernel to tvm.testing.fixture
+
+random_seed = tvm.testing.parameter(0)
+
+
+@tvm.testing.fixture
+def shift_shape_batch(batch):
+    return batch
+
+
+@tvm.testing.fixture
+def shift_shape_in_channel(in_channel):
+    return in_channel
+
+
+@tvm.testing.fixture
+def shift_shape_in_size(in_size):
+    return in_size
+
+
+@tvm.testing.fixture
+def shift_shape_num_filter(num_filter):
+    return num_filter
+
+
+@tvm.testing.fixture
+def shift_shape_stride(stride):
+    return stride
+
+
+@tvm.testing.fixture
+def shift_shape_padding(padding):
+    return padding
+
+
+@tvm.testing.fixture
+def shift_shape_output_padding(output_padding):
+    return output_padding
+
+
+class BaseConv2DTransposeTests:
+    @requires_hexagon_toolchain
+    def test_conv2d(
+        self,
+        hexagon_session,
+        batch,
+        in_channel,
+        in_size,
+        num_filter,
+        stride,
+        padding,
+        output_padding,
+        random_seed,
+    ):
+
+        target_hexagon = tvm.target.hexagon("v68")
+
+        in_height, in_width = in_size
+        kernel_height, kernel_width = (1, 1)
+        stride_height, stride_width = stride
+        pad_top, pad_left, pad_bottom, pad_right = padding
+
+        A = te.placeholder((batch, in_channel, in_height, in_width), name="A")
+        W = te.placeholder((in_channel, num_filter, kernel_height, kernel_width), name="W")
+
+        a_shape = get_const_tuple(A.shape)
+        w_shape = get_const_tuple(W.shape)
+        dtype = A.dtype
+
+        def get_ref_data():
+
+            np.random.seed(random_seed)
+            a_np = np.random.uniform(size=a_shape).astype(dtype)
+            w_np = np.random.uniform(size=w_shape).astype(dtype)
+            b_np = tvm.topi.testing.conv2d_transpose_nchw_python(
+                a_np, w_np, stride, padding, output_padding
+            )
+            c_np = np.maximum(b_np, 0)
+            return a_np, w_np, b_np, c_np
+
+        a_np, w_np, b_np, c_np = get_ref_data()
+
+        fcompute_args = (
+            A,
+            W,
+            [stride_height, stride_width],
+            [pad_top, pad_left, pad_bottom, pad_right],
+            A.dtype,
+            output_padding,
+        )
+
+        with tvm.target.Target(target_hexagon):
+            fcompute = topi.nn.conv2d_transpose_nchw
+            fschedule = topi.hexagon.schedule_conv2d_transpose_nchw
+            B = fcompute(*fcompute_args)
+            C = topi.nn.relu(B)
+            s1 = fschedule([B])
+            s2 = fschedule([C])
+
+            dev = hexagon_session.device
+
+            a = tvm.nd.array(a_np, dev)
+            w = tvm.nd.array(w_np, dev)
+            b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), dev)
+            c = tvm.nd.array(np.zeros(get_const_tuple(C.shape), dtype=C.dtype), dev)
+
+            func1 = tvm.build(s1, [A, W, B], tvm.target.Target(target_hexagon, host=target_hexagon))
+            func2 = tvm.build(s2, [A, W, C], tvm.target.Target(target_hexagon, host=target_hexagon))
+
+            mod1 = hexagon_session.load_module(func1)
+            mod2 = hexagon_session.load_module(func2)
+
+            mod1(a, w, b)
+            mod2(a, w, c)
+            tvm.testing.assert_allclose(b.numpy(), b_np, rtol=1e-5)
+            tvm.testing.assert_allclose(c.numpy(), c_np, rtol=1e-5)
+
+
+class TestConv2DTranspose(BaseConv2DTransposeTests):
+
+    (batch, in_channel, in_size, num_filter, stride) = tvm.testing.parameters(
+        (1, 3, (224, 224), 1, (1, 1)),
+        (1, 8, (224, 224), 1, (1, 1)),
+        (1, 512, (8, 1), 128, (31, 1)),
+        (1, 32, (8192, 1), 1, (1, 1)),
+    )
+
+    padding = tvm.testing.parameter((0, 0, 0, 0))
+    output_padding = tvm.testing.parameter((0, 0))
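Note that tvm.testing.parameters generates one test per tuple, so TestConv2DTranspose covers four NCHW workloads, each with zero padding and zero output padding. The tests run only under the requires_hexagon_toolchain marker with a live hexagon_session, for example via

    pytest tests/python/contrib/test_hexagon/topi/test_conv2d_transpose.py

with the Hexagon toolchain and device environment configured; without them the tests are skipped.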