You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by me...@apache.org on 2022/07/07 18:19:08 UTC

[tvm] branch main updated: [TOPI] [Hexagon] Reshape slice op (#11983)

This is an automated email from the ASF dual-hosted git repository.

mehrdadh pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new c76d8e2bdb [TOPI] [Hexagon] Reshape slice op (#11983)
c76d8e2bdb is described below

commit c76d8e2bdb0a11cbdc30bdfd631963ba9813662a
Author: abhikran-quic <63...@users.noreply.github.com>
AuthorDate: Thu Jul 7 23:49:01 2022 +0530

    [TOPI] [Hexagon] Reshape slice op (#11983)
    
    * Reshape slice op. This patch adds the initial Python implementation of the reshape slice op for Hexagon.
    
    * Add tests for reshape op
---
 python/tvm/topi/hexagon/slice_ops/__init__.py      |   1 +
 python/tvm/topi/hexagon/slice_ops/reshape.py       | 108 +++++++++++++
 .../test_hexagon/topi/test_batch_flatten.py        | 101 -------------
 .../contrib/test_hexagon/topi/test_reshape.py      | 168 +++++++++++++++++++++
 4 files changed, 277 insertions(+), 101 deletions(-)

diff --git a/python/tvm/topi/hexagon/slice_ops/__init__.py b/python/tvm/topi/hexagon/slice_ops/__init__.py
old mode 100755
new mode 100644
index ce1641bfda..617aaed920
--- a/python/tvm/topi/hexagon/slice_ops/__init__.py
+++ b/python/tvm/topi/hexagon/slice_ops/__init__.py
@@ -24,3 +24,4 @@ from .batch_flatten import batch_flatten_compute, batch_flatten_stir_schedule
 from .softmax_slice import *
 from .clip import *
 from .conv2d import *
+from .reshape import reshape_compute, reshape_stir_schedule
diff --git a/python/tvm/topi/hexagon/slice_ops/reshape.py b/python/tvm/topi/hexagon/slice_ops/reshape.py
new file mode 100644
index 0000000000..374c20bb72
--- /dev/null
+++ b/python/tvm/topi/hexagon/slice_ops/reshape.py
@@ -0,0 +1,108 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""Hexagon slice reshape compute and schedule"""
+from tvm import te, tir, topi
+from ..utils import get_layout_transform_fn
+
+
+def reshape_compute(inp: te.Tensor, new_shape: tuple) -> te.Tensor:
+    """Compute for the Hexagon slice reshape op; wraps topi.transform.reshape.
+    This op makes the following assumptions:
+    1. This op is written for a sliced reshape operation.
+    2. The input is assumed to be in NHWC layout.
+
+    Parameters
+    ----------
+    inp : te.Tensor
+        Input tensor
+    new_shape : tuple
+        Target shape; expected to hold the same number of elements as inp
+    Returns
+    -------
+    output : te.Tensor
+        Result of topi.transform.reshape(inp, new_shape)
+    """
+    return topi.transform.reshape(inp, new_shape)
+
+
+def stir_schedule_nhwc_1024c(
+    out: te.Tensor,
+    inp: te.Tensor,
+    out_layout: str,
+    in_layout: str,
+) -> tir.Schedule:
+    """Schedule for a 4-D NHWC input reshaped to the 2-D "nc-1024-2d" output layout."""
+    reshape_func = te.create_prim_func([inp, out])
+    sch = tir.Schedule(reshape_func, debug_mask="all")
+    compute = sch.get_block("T_reshape")  # block created by topi.transform.reshape
+    # Re-lay out both buffers, then tile the 2-D loop nest around inp's H, W and C extents.
+    sch.transform_layout(compute, inp.name, get_layout_transform_fn(in_layout))
+    sch.transform_layout(compute, out.name, get_layout_transform_fn(out_layout))
+    i, j = sch.get_loops(compute)  # two loops: the output is expected to be 2-D
+    jout, channel = sch.split(j, [None, inp.shape[3]])  # innermost extent = inp's C (NHWC)
+    height, width = sch.split(jout, [inp.shape[1], inp.shape[2]])  # remaining h*w -> H, W
+    channelo, channeli = sch.split(channel, [None, 1024])  # chunks of 1024 channels
+    channelio, channelii = sch.split(channeli, [None, 64])  # NOTE: 64 presumably one fp16 HVX vector
+    sch.reorder(i, height, width, channelo, channelio, channelii)
+    sch.vectorize(channelii)
+    return sch
+
+
+def stir_schedule_nhwc_8h2w32c2w(
+    out: te.Tensor,
+    inp: te.Tensor,
+    out_layout: str,
+    in_layout: str,
+) -> tir.Schedule:
+    """Schedule for the nhwc-8h2w32c2w-2d layout: applies buffer layout transforms only."""
+    reshape_func = te.create_prim_func([inp, out])
+    sch = tir.Schedule(reshape_func, debug_mask="all")
+    compute = sch.get_block("T_reshape")  # block created by topi.transform.reshape
+    # Only buffer layouts change; the loop nest is left as create_prim_func built it.
+    sch.transform_layout(compute, inp.name, get_layout_transform_fn(in_layout))
+    sch.transform_layout(compute, out.name, get_layout_transform_fn(out_layout))
+    return sch
+
+
+def reshape_stir_schedule(
+    out: te.Tensor,
+    inp: te.Tensor,
+    output_layout: str,
+    input_layout: str,
+) -> tir.Schedule:
+    """STIR schedule for the slice reshape compute, chosen by output_layout.
+    Parameters
+    ----------
+    out : te.Tensor
+        The output tensor as returned by a call to reshape_compute
+    inp : te.Tensor
+        Input tensor to reshape
+    output_layout : str
+        Output layout name; "nhwc-8h2w32c2w-2d" and "nc-1024-2d" are supported
+    input_layout : str
+        Input layout name, resolved via get_layout_transform_fn
+    Returns
+    -------
+    sch : tvm.tir.Schedule
+        The STIR schedule; RuntimeError is raised for any other output_layout
+    """
+    if output_layout == "nhwc-8h2w32c2w-2d":
+        return stir_schedule_nhwc_8h2w32c2w(out, inp, output_layout, input_layout)
+    if output_layout == "nc-1024-2d":
+        return stir_schedule_nhwc_1024c(out, inp, output_layout, input_layout)
+    raise RuntimeError(f"Unexpected layout '{output_layout}'")
diff --git a/tests/python/contrib/test_hexagon/topi/test_batch_flatten.py b/tests/python/contrib/test_hexagon/topi/test_batch_flatten.py
deleted file mode 100644
index 3a056116d4..0000000000
--- a/tests/python/contrib/test_hexagon/topi/test_batch_flatten.py
+++ /dev/null
@@ -1,101 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied.  See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-import numpy as np
-import pytest
-
-import tvm
-import tvm.testing
-import tvm.topi.hexagon.slice_ops as sl
-from tvm import te, topi
-from tvm.contrib.hexagon.build import HexagonLauncher
-from tvm.topi import testing
-
-from ..infrastructure import allocate_hexagon_array, transform_numpy
-
-
-class BaseTestBatchFlatten:
-    input_shape = tvm.testing.parameter(
-        (1, 1, 1, 2048),
-        (1, 2, 4, 2048),
-        (1, 8, 8, 1024),
-        (2, 4, 8, 1024),
-        (2, 3, 5, 2048),
-    )
-    input_layout, input_axis_sep = tvm.testing.parameters(("nhwc-1024c-2d", [4]))
-    output_layout, output_axis_sep = tvm.testing.parameters(("nc-1024-2d", [2]))
-    data_type = tvm.testing.parameter("float16")
-
-
-class TestBatchFlatten(BaseTestBatchFlatten):
-    @tvm.testing.fixture
-    def output_shape(self, input_shape):
-        return input_shape[0], input_shape[1] * input_shape[2] * input_shape[3]
-
-    @tvm.testing.requires_hexagon
-    def test_batch_flatten(
-        self,
-        data_type,
-        input_shape,
-        input_layout,
-        input_axis_sep,
-        output_shape,
-        output_layout,
-        output_axis_sep,
-        hexagon_session,
-    ):
-        target_hexagon = tvm.target.hexagon("v69")
-        target = tvm.target.Target(target_hexagon, host=target_hexagon)
-        A = te.placeholder(input_shape, name="A", dtype=data_type)
-        D = sl.batch_flatten_compute(A)
-        tir_s = sl.batch_flatten_stir_schedule(
-            D,
-            A,
-            output_layout,
-            input_layout,
-        )
-        func_name = "batch_flatten"
-        with tvm.transform.PassContext(opt_level=3):
-            runtime_module = tvm.build(tir_s.mod, target=target, name=func_name)
-
-        mod = hexagon_session.load_module(runtime_module)
-
-        a_numpy = (np.random.uniform(-1, 1, input_shape)).astype(data_type)
-        ref = np.reshape(a_numpy, output_shape)
-
-        input_np_transformed = transform_numpy(a_numpy, "nhwc", input_layout)
-        ref_np_transformed = transform_numpy(ref, "nhwc", output_layout)
-
-        a_tvm = allocate_hexagon_array(
-            hexagon_session.device,
-            data=input_np_transformed,
-            axis_separators=input_axis_sep,
-            mem_scope="global.vtcm",
-        )
-        output = allocate_hexagon_array(
-            hexagon_session.device,
-            ref_np_transformed.shape,
-            data_type,
-            axis_separators=output_axis_sep,
-            mem_scope="global.vtcm",
-        )
-        mod(a_tvm, output)
-        np.testing.assert_allclose(output.numpy(), ref_np_transformed, atol=1e-07, rtol=0)
-
-
-if __name__ == "__main__":
-    tvm.testing.main(pytest.main(sys.argv))
diff --git a/tests/python/contrib/test_hexagon/topi/test_reshape.py b/tests/python/contrib/test_hexagon/topi/test_reshape.py
new file mode 100644
index 0000000000..2def86ad83
--- /dev/null
+++ b/tests/python/contrib/test_hexagon/topi/test_reshape.py
@@ -0,0 +1,168 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import numpy as np
+import pytest
+
+import tvm
+import tvm.testing
+import tvm.topi.hexagon.slice_ops as sl
+from tvm import te, topi
+from tvm.contrib.hexagon.build import HexagonLauncher
+from tvm.topi import testing
+
+from ..infrastructure import allocate_hexagon_array, transform_numpy
+
+
+def reshape_helper(
+    func,
+    fcompute,
+    fschedule,
+    data_type,
+    input_shape,
+    input_layout,
+    output_shape,
+    output_layout,
+    hexagon_session,
+):
+    """Build func via fcompute/fschedule for Hexagon v69, run it, and compare with np.reshape."""
+    target_hexagon = tvm.target.hexagon("v69")
+    target = tvm.target.Target(target_hexagon, host=target_hexagon)
+    A = te.placeholder(input_shape, name="A", dtype=data_type)
+    if func == "reshape":  # reshape_compute takes the target shape explicitly
+        D = fcompute(A, output_shape)
+    elif func == "batch_flatten":  # batch_flatten_compute derives the output shape itself
+        D = fcompute(A)
+    else:
+        raise RuntimeError(f"Unexpected func'{func}'")  # NOTE(review): no space before the quote
+    tir_s = fschedule(
+        D,
+        A,
+        output_layout,
+        input_layout,
+    )
+    with tvm.transform.PassContext(opt_level=3):
+        print("output of tvm.lower", tvm.lower(tir_s.mod, name=func))  # NOTE(review): debug print
+        runtime_module = tvm.build(tir_s.mod, target=target, name=func)
+
+    mod = hexagon_session.load_module(runtime_module)
+    # Random (unseeded) input; the expected result is a plain numpy reshape of it.
+    a_numpy = (np.random.uniform(-1, 1, input_shape)).astype(data_type)
+    ref = np.reshape(a_numpy, output_shape)
+    # transform_numpy presumably rearranges NHWC data into the named physical layout.
+    input_np_transformed = transform_numpy(a_numpy, "nhwc", input_layout)
+    ref_np_transformed = transform_numpy(ref, "nhwc", output_layout)
+    input_axis_sep = [4]  # used for both input layouts exercised by these tests
+    if output_layout == "nhwc-8h2w32c2w-2d":
+        output_axis_sep = [4]
+    elif output_layout == "nc-1024-2d":
+        output_axis_sep = [2]
+    else:
+        raise RuntimeError(f"Unexpected layout '{output_layout}'")
+    a_tvm = allocate_hexagon_array(
+        hexagon_session.device,
+        data=input_np_transformed,
+        axis_separators=input_axis_sep,
+        mem_scope="global.vtcm",
+    )
+    output = allocate_hexagon_array(
+        hexagon_session.device,
+        ref_np_transformed.shape,
+        data_type,
+        axis_separators=output_axis_sep,
+        mem_scope="global.vtcm",
+    )
+    mod(a_tvm, output)  # pure data movement, hence the near-exact tolerance below
+    np.testing.assert_allclose(output.numpy(), ref_np_transformed, atol=1e-07, rtol=0)
+
+
+batch_flatten_tests = (  # (input_shape, output_shape, input_layout, output_layout, dtype)
+    ([1, 1, 1, 2048], [1, 2048], "nhwc-1024c-2d", "nc-1024-2d", "float16"),
+    ([1, 2, 4, 2048], [1, 2 * 4 * 2048], "nhwc-1024c-2d", "nc-1024-2d", "float16"),
+    ([1, 8, 8, 1024], [1, 8 * 8 * 1024], "nhwc-1024c-2d", "nc-1024-2d", "float16"),
+    ([2, 4, 8, 1024], [2, 4 * 8 * 1024], "nhwc-1024c-2d", "nc-1024-2d", "float16"),
+)
+
+
+class BaseTestBatchFlatten:  # fixtures: one parameter set per batch_flatten_tests entry
+    (
+        input_shape,
+        output_shape,
+        input_layout,
+        output_layout,
+        data_type,
+    ) = tvm.testing.parameters(*batch_flatten_tests)
+
+
+class TestBatchFlatten(BaseTestBatchFlatten):
+    @tvm.testing.requires_hexagon
+    def test_batch_flatten(  # flatten via batch_flatten_compute / batch_flatten_stir_schedule
+        self,
+        data_type,
+        input_shape,
+        input_layout,
+        output_shape,
+        output_layout,
+        hexagon_session,
+    ):
+        reshape_helper(
+            "batch_flatten",
+            sl.batch_flatten_compute,
+            sl.batch_flatten_stir_schedule,
+            data_type,
+            input_shape,
+            input_layout,
+            output_shape,
+            output_layout,
+            hexagon_session,
+        )
+
+
+class BaseTestReshape(BaseTestBatchFlatten):  # redefines (overrides) all five inherited fixtures
+    (input_shape, output_shape, input_layout, output_layout, data_type,) = tvm.testing.parameters(
+        *batch_flatten_tests,  # flatten cases; the two below are 4-D -> 4-D reshapes
+        ([1, 8, 4, 64], [1, 8, 8, 32], "nhwc-8h2w32c2w-2d", "nhwc-8h2w32c2w-2d", "float16"),
+        ([1, 16, 8, 128], [1, 16, 16, 64], "nhwc-8h2w32c2w-2d", "nhwc-8h2w32c2w-2d", "float16"),
+    )
+
+
+class TestReshape(BaseTestReshape):
+    @tvm.testing.requires_hexagon
+    def test_reshape(  # generic reshape via reshape_compute / reshape_stir_schedule
+        self,
+        data_type,
+        input_shape,
+        input_layout,
+        output_shape,
+        output_layout,
+        hexagon_session,
+    ):
+        reshape_helper(
+            "reshape",
+            sl.reshape_compute,
+            sl.reshape_stir_schedule,
+            data_type,
+            input_shape,
+            input_layout,
+            output_shape,
+            output_layout,
+            hexagon_session,
+        )
+
+
+if __name__ == "__main__":
+    tvm.testing.main()  # run this file's tests under pytest