You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by me...@apache.org on 2022/07/07 18:19:08 UTC
[tvm] branch main updated: [TOPI] [Hexagon] Reshape slice op (#11983)
This is an automated email from the ASF dual-hosted git repository.
mehrdadh pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new c76d8e2bdb [TOPI] [Hexagon] Reshape slice op (#11983)
c76d8e2bdb is described below
commit c76d8e2bdb0a11cbdc30bdfd631963ba9813662a
Author: abhikran-quic <63...@users.noreply.github.com>
AuthorDate: Thu Jul 7 23:49:01 2022 +0530
[TOPI] [Hexagon] Reshape slice op (#11983)
* Reshape slice op. This patch adds the initial python implementation of the reshape slice op for hexagon.
* Add tests for reshape op
---
python/tvm/topi/hexagon/slice_ops/__init__.py | 1 +
python/tvm/topi/hexagon/slice_ops/reshape.py | 108 +++++++++++++
.../test_hexagon/topi/test_batch_flatten.py | 101 -------------
.../contrib/test_hexagon/topi/test_reshape.py | 168 +++++++++++++++++++++
4 files changed, 277 insertions(+), 101 deletions(-)
diff --git a/python/tvm/topi/hexagon/slice_ops/__init__.py b/python/tvm/topi/hexagon/slice_ops/__init__.py
old mode 100755
new mode 100644
index ce1641bfda..617aaed920
--- a/python/tvm/topi/hexagon/slice_ops/__init__.py
+++ b/python/tvm/topi/hexagon/slice_ops/__init__.py
@@ -24,3 +24,4 @@ from .batch_flatten import batch_flatten_compute, batch_flatten_stir_schedule
from .softmax_slice import *
from .clip import *
from .conv2d import *
+from .reshape import reshape_compute, reshape_stir_schedule
diff --git a/python/tvm/topi/hexagon/slice_ops/reshape.py b/python/tvm/topi/hexagon/slice_ops/reshape.py
new file mode 100644
index 0000000000..374c20bb72
--- /dev/null
+++ b/python/tvm/topi/hexagon/slice_ops/reshape.py
@@ -0,0 +1,108 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""Hexagon slice reshape compute and schedule"""
+from tvm import te, tir, topi
+from ..utils import get_layout_transform_fn
+
+
def reshape_compute(inp: te.Tensor, new_shape: tuple) -> te.Tensor:
    """Compute for slice reshape op for hexagon.

    This op makes the following assumptions:
    1. This op is written for a sliced reshape operation.
    2. The input is assumed to be in NHWC layout.

    Parameters
    ----------
    inp : te.Tensor
        Input tensor
    new_shape : tuple
        Output shape

    Returns
    -------
    te.Tensor
        Output of applying reshape operation on input
    """
    # Delegates directly to the generic TOPI reshape; the slice-specific
    # behavior comes from the STIR schedules below, not from the compute.
    return topi.transform.reshape(inp, new_shape)
+
+
def stir_schedule_nhwc_1024c(
    out: te.Tensor,
    inp: te.Tensor,
    out_layout: str,
    in_layout: str,
) -> tir.Schedule:
    """Schedule for output layout: nhwc-1024c-2d.

    NOTE(review): reshape_stir_schedule dispatches here when
    output_layout == "nc-1024-2d" (with a 4-D NHWC input), so the layout
    named in this summary looks stale — confirm intended naming.
    """
    reshape_func = te.create_prim_func([inp, out])
    sch = tir.Schedule(reshape_func, debug_mask="all")
    compute = sch.get_block("T_reshape")

    # Apply the physical layout transforms to the input and output buffers.
    sch.transform_layout(compute, inp.name, get_layout_transform_fn(in_layout))
    sch.transform_layout(compute, out.name, get_layout_transform_fn(out_layout))
    # The reshape block has two loops (batch i, flattened j).  Rebuild j as
    # (height, width, channel) using the input's NHWC extents, then tile the
    # channel axis into 1024-element chunks with a 64-wide innermost loop.
    i, j = sch.get_loops(compute)
    jout, channel = sch.split(j, [None, inp.shape[3]])
    height, width = sch.split(jout, [inp.shape[1], inp.shape[2]])
    channelo, channeli = sch.split(channel, [None, 1024])
    channelio, channelii = sch.split(channeli, [None, 64])
    sch.reorder(i, height, width, channelo, channelio, channelii)
    # Vectorize the innermost 64-element loop — presumably one HVX vector of
    # float16 (64 x 2 bytes = 128 bytes); TODO confirm dtype assumption.
    sch.vectorize(channelii)
    return sch
+
+
def stir_schedule_nhwc_8h2w32c2w(
    out: te.Tensor,
    inp: te.Tensor,
    out_layout: str,
    in_layout: str,
) -> tir.Schedule:
    """Schedule for input and output layout nhwc-8h2w32c2w"""
    prim_func = te.create_prim_func([inp, out])
    schedule = tir.Schedule(prim_func, debug_mask="all")
    reshape_block = schedule.get_block("T_reshape")

    # Only the buffer layouts change; the loop structure is left untouched.
    for tensor, layout in ((inp, in_layout), (out, out_layout)):
        schedule.transform_layout(reshape_block, tensor.name, get_layout_transform_fn(layout))
    return schedule
+
+
def reshape_stir_schedule(
    out: te.Tensor,
    inp: te.Tensor,
    output_layout: str,
    input_layout: str,
) -> tir.Schedule:
    """STIR schedule definition for the compute of reshape compute.

    Parameters
    ----------
    out : te.Tensor
        The output tensor as returned by a call to reshape_compute
    inp : te.Tensor
        Input tensor to reshape
    output_layout : str
        The transformation function definition for the expected output layout
    input_layout : str
        The transformation function definition for the input layout

    Returns
    -------
    sch : tvm.tir.Schedule
        The STIR schedule for slice reshape compute

    Raises
    ------
    RuntimeError
        If output_layout is not one of the supported layouts.
    """
    # Dispatch on the output layout: the two supported layouts need very
    # different schedules (see the per-layout helpers above).
    if output_layout == "nhwc-8h2w32c2w-2d":
        return stir_schedule_nhwc_8h2w32c2w(out, inp, output_layout, input_layout)
    if output_layout == "nc-1024-2d":
        return stir_schedule_nhwc_1024c(out, inp, output_layout, input_layout)
    raise RuntimeError(f"Unexpected layout '{output_layout}'")
diff --git a/tests/python/contrib/test_hexagon/topi/test_batch_flatten.py b/tests/python/contrib/test_hexagon/topi/test_batch_flatten.py
deleted file mode 100644
index 3a056116d4..0000000000
--- a/tests/python/contrib/test_hexagon/topi/test_batch_flatten.py
+++ /dev/null
@@ -1,101 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-import numpy as np
-import pytest
-
-import tvm
-import tvm.testing
-import tvm.topi.hexagon.slice_ops as sl
-from tvm import te, topi
-from tvm.contrib.hexagon.build import HexagonLauncher
-from tvm.topi import testing
-
-from ..infrastructure import allocate_hexagon_array, transform_numpy
-
-
-class BaseTestBatchFlatten:
- input_shape = tvm.testing.parameter(
- (1, 1, 1, 2048),
- (1, 2, 4, 2048),
- (1, 8, 8, 1024),
- (2, 4, 8, 1024),
- (2, 3, 5, 2048),
- )
- input_layout, input_axis_sep = tvm.testing.parameters(("nhwc-1024c-2d", [4]))
- output_layout, output_axis_sep = tvm.testing.parameters(("nc-1024-2d", [2]))
- data_type = tvm.testing.parameter("float16")
-
-
-class TestBatchFlatten(BaseTestBatchFlatten):
- @tvm.testing.fixture
- def output_shape(self, input_shape):
- return input_shape[0], input_shape[1] * input_shape[2] * input_shape[3]
-
- @tvm.testing.requires_hexagon
- def test_batch_flatten(
- self,
- data_type,
- input_shape,
- input_layout,
- input_axis_sep,
- output_shape,
- output_layout,
- output_axis_sep,
- hexagon_session,
- ):
- target_hexagon = tvm.target.hexagon("v69")
- target = tvm.target.Target(target_hexagon, host=target_hexagon)
- A = te.placeholder(input_shape, name="A", dtype=data_type)
- D = sl.batch_flatten_compute(A)
- tir_s = sl.batch_flatten_stir_schedule(
- D,
- A,
- output_layout,
- input_layout,
- )
- func_name = "batch_flatten"
- with tvm.transform.PassContext(opt_level=3):
- runtime_module = tvm.build(tir_s.mod, target=target, name=func_name)
-
- mod = hexagon_session.load_module(runtime_module)
-
- a_numpy = (np.random.uniform(-1, 1, input_shape)).astype(data_type)
- ref = np.reshape(a_numpy, output_shape)
-
- input_np_transformed = transform_numpy(a_numpy, "nhwc", input_layout)
- ref_np_transformed = transform_numpy(ref, "nhwc", output_layout)
-
- a_tvm = allocate_hexagon_array(
- hexagon_session.device,
- data=input_np_transformed,
- axis_separators=input_axis_sep,
- mem_scope="global.vtcm",
- )
- output = allocate_hexagon_array(
- hexagon_session.device,
- ref_np_transformed.shape,
- data_type,
- axis_separators=output_axis_sep,
- mem_scope="global.vtcm",
- )
- mod(a_tvm, output)
- np.testing.assert_allclose(output.numpy(), ref_np_transformed, atol=1e-07, rtol=0)
-
-
-if __name__ == "__main__":
- tvm.testing.main(pytest.main(sys.argv))
diff --git a/tests/python/contrib/test_hexagon/topi/test_reshape.py b/tests/python/contrib/test_hexagon/topi/test_reshape.py
new file mode 100644
index 0000000000..2def86ad83
--- /dev/null
+++ b/tests/python/contrib/test_hexagon/topi/test_reshape.py
@@ -0,0 +1,168 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import numpy as np
+import pytest
+
+import tvm
+import tvm.testing
+import tvm.topi.hexagon.slice_ops as sl
+from tvm import te, topi
+from tvm.contrib.hexagon.build import HexagonLauncher
+from tvm.topi import testing
+
+from ..infrastructure import allocate_hexagon_array, transform_numpy
+
+
def reshape_helper(
    func,
    fcompute,
    fschedule,
    data_type,
    input_shape,
    input_layout,
    output_shape,
    output_layout,
    hexagon_session,
):
    """Build, run, and validate a sliced reshape/batch_flatten op on Hexagon.

    Parameters
    ----------
    func : str
        Which op to exercise: "reshape" or "batch_flatten".
    fcompute : callable
        Compute function (takes the placeholder, and for "reshape" also
        the output shape).
    fschedule : callable
        Matching STIR schedule function.
    data_type : str
        Numpy/TVM dtype of the tensors (e.g. "float16").
    input_shape, output_shape : sequence of int
        Logical shapes before layout transformation.
    input_layout, output_layout : str
        Layout transform names understood by transform_numpy.
    hexagon_session
        Active Hexagon session used to load and run the built module.

    Raises
    ------
    RuntimeError
        If func or output_layout is not recognized.
    """
    target_hexagon = tvm.target.hexagon("v69")
    target = tvm.target.Target(target_hexagon, host=target_hexagon)
    A = te.placeholder(input_shape, name="A", dtype=data_type)
    if func == "reshape":
        D = fcompute(A, output_shape)
    elif func == "batch_flatten":
        D = fcompute(A)
    else:
        raise RuntimeError(f"Unexpected func '{func}'")
    tir_s = fschedule(
        D,
        A,
        output_layout,
        input_layout,
    )
    # NOTE: removed leftover debug print of tvm.lower(...) here.
    with tvm.transform.PassContext(opt_level=3):
        runtime_module = tvm.build(tir_s.mod, target=target, name=func)

    mod = hexagon_session.load_module(runtime_module)

    # Reference result computed on host with plain numpy.
    a_numpy = (np.random.uniform(-1, 1, input_shape)).astype(data_type)
    ref = np.reshape(a_numpy, output_shape)

    input_np_transformed = transform_numpy(a_numpy, "nhwc", input_layout)
    ref_np_transformed = transform_numpy(ref, "nhwc", output_layout)
    # Axis separators describe where the 2-D physical buffer is split;
    # they must match the chosen layouts.
    input_axis_sep = [4]
    if output_layout == "nhwc-8h2w32c2w-2d":
        output_axis_sep = [4]
    elif output_layout == "nc-1024-2d":
        output_axis_sep = [2]
    else:
        raise RuntimeError(f"Unexpected layout '{output_layout}'")
    a_tvm = allocate_hexagon_array(
        hexagon_session.device,
        data=input_np_transformed,
        axis_separators=input_axis_sep,
        mem_scope="global.vtcm",
    )
    output = allocate_hexagon_array(
        hexagon_session.device,
        ref_np_transformed.shape,
        data_type,
        axis_separators=output_axis_sep,
        mem_scope="global.vtcm",
    )
    mod(a_tvm, output)
    # Reshape moves data without arithmetic, so the match must be exact
    # up to float representation (rtol=0).
    np.testing.assert_allclose(output.numpy(), ref_np_transformed, atol=1e-07, rtol=0)
+
+
# (input_shape, output_shape, input_layout, output_layout, data_type) cases
# shared by TestBatchFlatten and TestReshape below: each flattens an NHWC
# float16 input ("nhwc-1024c-2d") to a 2-D "nc-1024-2d" output.
batch_flatten_tests = (
    ([1, 1, 1, 2048], [1, 2048], "nhwc-1024c-2d", "nc-1024-2d", "float16"),
    ([1, 2, 4, 2048], [1, 2 * 4 * 2048], "nhwc-1024c-2d", "nc-1024-2d", "float16"),
    ([1, 8, 8, 1024], [1, 8 * 8 * 1024], "nhwc-1024c-2d", "nc-1024-2d", "float16"),
    ([2, 4, 8, 1024], [2, 4 * 8 * 1024], "nhwc-1024c-2d", "nc-1024-2d", "float16"),
)
+
+
class BaseTestBatchFlatten:
    """Parameter definitions for the batch_flatten tests."""

    # tvm.testing.parameters unpacks each batch_flatten_tests entry into
    # per-test fixture values with these names.
    (
        input_shape,
        output_shape,
        input_layout,
        output_layout,
        data_type,
    ) = tvm.testing.parameters(*batch_flatten_tests)
+
+
class TestBatchFlatten(BaseTestBatchFlatten):
    """Validate the batch_flatten slice op via the shared reshape helper."""

    @tvm.testing.requires_hexagon
    def test_batch_flatten(
        self,
        data_type,
        input_shape,
        input_layout,
        output_shape,
        output_layout,
        hexagon_session,
    ):
        reshape_helper(
            "batch_flatten",
            sl.batch_flatten_compute,
            sl.batch_flatten_stir_schedule,
            data_type=data_type,
            input_shape=input_shape,
            input_layout=input_layout,
            output_shape=output_shape,
            output_layout=output_layout,
            hexagon_session=hexagon_session,
        )
+
+
class BaseTestReshape(BaseTestBatchFlatten):
    """Parameter definitions for the reshape tests.

    Re-declares the batch_flatten cases and adds NHWC-to-NHWC cases using
    the "nhwc-8h2w32c2w-2d" layout on both input and output.
    """

    (input_shape, output_shape, input_layout, output_layout, data_type,) = tvm.testing.parameters(
        *batch_flatten_tests,
        ([1, 8, 4, 64], [1, 8, 8, 32], "nhwc-8h2w32c2w-2d", "nhwc-8h2w32c2w-2d", "float16"),
        ([1, 16, 8, 128], [1, 16, 16, 64], "nhwc-8h2w32c2w-2d", "nhwc-8h2w32c2w-2d", "float16"),
    )
+
+
class TestReshape(BaseTestReshape):
    """Validate the reshape slice op via the shared reshape helper."""

    @tvm.testing.requires_hexagon
    def test_reshape(
        self,
        data_type,
        input_shape,
        input_layout,
        output_shape,
        output_layout,
        hexagon_session,
    ):
        reshape_helper(
            "reshape",
            sl.reshape_compute,
            sl.reshape_stir_schedule,
            data_type=data_type,
            input_shape=input_shape,
            input_layout=input_layout,
            output_shape=output_shape,
            output_layout=output_layout,
            hexagon_session=hexagon_session,
        )
+
+
# Allow running this test file directly (outside a pytest invocation).
if __name__ == "__main__":
    tvm.testing.main()