You are viewing a plain text version of this content. The canonical link for it was not preserved in this plain-text copy.
Posted to commits@tvm.apache.org by kp...@apache.org on 2022/07/16 15:10:44 UTC
[tvm] branch main updated: [TOPI] [Hexagon] Uint8 Reshape and batch flatten slice ops (#12037)
This is an automated email from the ASF dual-hosted git repository.
kparzysz pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new c0e996e291 [TOPI] [Hexagon] Uint8 Reshape and batch flatten slice ops (#12037)
c0e996e291 is described below
commit c0e996e2914585fe6b0c11fb2efdaea5c6b9daf9
Author: abhikran-quic <63...@users.noreply.github.com>
AuthorDate: Sat Jul 16 20:40:39 2022 +0530
[TOPI] [Hexagon] Uint8 Reshape and batch flatten slice ops (#12037)
* [TOPI] [Hexagon] Uint8 Reshape and batch flatten slice ops
* Fix documentation
---
python/tvm/topi/hexagon/slice_ops/reshape.py | 13 +++---
python/tvm/topi/hexagon/utils.py | 21 ++++++++++
.../python/contrib/test_hexagon/infrastructure.py | 12 +++++-
.../contrib/test_hexagon/topi/test_reshape.py | 47 +++++++++++++++-------
4 files changed, 72 insertions(+), 21 deletions(-)
diff --git a/python/tvm/topi/hexagon/slice_ops/reshape.py b/python/tvm/topi/hexagon/slice_ops/reshape.py
index 374c20bb72..2220253e21 100644
--- a/python/tvm/topi/hexagon/slice_ops/reshape.py
+++ b/python/tvm/topi/hexagon/slice_ops/reshape.py
@@ -40,13 +40,14 @@ def reshape_compute(inp: te.Tensor, new_shape: tuple) -> te.Tensor:
return topi.transform.reshape(inp, new_shape)
-def stir_schedule_nhwc_1024c(
+def stir_sched_nhwc_2d_op(
out: te.Tensor,
inp: te.Tensor,
out_layout: str,
in_layout: str,
+ c_split: int,
) -> tir.Schedule:
- """Schedule for output layout: nhwc-1024c-2d"""
+ """Schedule for output layout: nc-1024-2d, nc-2048-2d"""
reshape_func = te.create_prim_func([inp, out])
sch = tir.Schedule(reshape_func, debug_mask="all")
compute = sch.get_block("T_reshape")
@@ -57,7 +58,7 @@ def stir_schedule_nhwc_1024c(
jout, channel = sch.split(j, [None, inp.shape[3]])
height, width = sch.split(jout, [inp.shape[1], inp.shape[2]])
channelo, channeli = sch.split(channel, [None, 1024])
- channelio, channelii = sch.split(channeli, [None, 64])
+ channelio, channelii = sch.split(channeli, [None, c_split])
sch.reorder(i, height, width, channelo, channelio, channelii)
sch.vectorize(channelii)
return sch
@@ -101,8 +102,10 @@ def reshape_stir_schedule(
sch : tvm.tir.Schedule
The STIR schedule for slice reshape compute
"""
- if output_layout == "nhwc-8h2w32c2w-2d":
+ if output_layout in ["nhwc-8h2w32c2w-2d", "nhwc-8h8w32c-2d"]:
return stir_schedule_nhwc_8h2w32c2w(out, inp, output_layout, input_layout)
if output_layout == "nc-1024-2d":
- return stir_schedule_nhwc_1024c(out, inp, output_layout, input_layout)
+ return stir_sched_nhwc_2d_op(out, inp, output_layout, input_layout, 64)
+ if output_layout == "nc-2048-2d":
+ return stir_sched_nhwc_2d_op(out, inp, output_layout, input_layout, 128)
raise RuntimeError(f"Unexpected layout '{output_layout}'")
diff --git a/python/tvm/topi/hexagon/utils.py b/python/tvm/topi/hexagon/utils.py
index 4458c55e62..3b8914ffe9 100644
--- a/python/tvm/topi/hexagon/utils.py
+++ b/python/tvm/topi/hexagon/utils.py
@@ -87,6 +87,21 @@ def nc_1024_2d(n, c):
return [n, c // 1024, te.AXIS_SEPARATOR, c % 1024]
+def nhwc_2048c_2d(n, h, w, c):
+ """Return index map for nhwc_2048 2d layout"""
+ return [n, h, w, c // 2048, te.AXIS_SEPARATOR, c % 2048]
+
+
+def nc_2048_2d(n, c):
+ """Return index map for nc_2048 2d layout"""
+ return [n, c // 2048, te.AXIS_SEPARATOR, c % 2048]
+
+
+def nhwc_8h8w32c_2d(n, h, w, c):
+ """Return index map for nhwc_8h8w32c 2d layout"""
+ return [n, h // 8, w // 8, c // 32, te.AXIS_SEPARATOR, h % 8, w % 8, c % 32]
+
+
def iohw_16i32o2i_1d(height, width, in_channel, out_channel):
return [
in_channel // 32,
@@ -129,4 +144,10 @@ def get_layout_transform_fn(layout):
return nc_1024c_2d
if layout == "iohw-16i32o2i-1d":
return iohw_16i32o2i_1d
+ if layout == "nhwc-2048c-2d":
+ return nhwc_2048c_2d
+ if layout == "nc-2048-2d":
+ return nc_2048_2d
+ if layout == "nhwc-8h8w32c-2d":
+ return nhwc_8h8w32c_2d
raise RuntimeError(f"Unexpected layout '{layout}'")
diff --git a/tests/python/contrib/test_hexagon/infrastructure.py b/tests/python/contrib/test_hexagon/infrastructure.py
index a1fbfdefcd..7108ac5598 100644
--- a/tests/python/contrib/test_hexagon/infrastructure.py
+++ b/tests/python/contrib/test_hexagon/infrastructure.py
@@ -256,7 +256,17 @@ def transform_numpy(arr_np, current_layout: str, new_layout: str):
if new_layout == "nhwc-1024c-2d":
N, H, W, C = arr_np.shape
return arr_np.reshape([N, H, W, C // 1024, 1024])
- raise RuntimeError(f"Unexpected new_layout '{new_layout}'")
+ if new_layout == "nc-2048-2d":
+ N, C = arr_np.shape
+ return arr_np.reshape([N, C // 2048, 2048])
+ if new_layout == "nhwc-2048c-2d":
+ N, H, W, C = arr_np.shape
+ return arr_np.reshape([N, H, W, C // 2048, 2048])
+ if new_layout in ["nhwc-8h8w32c-2d"]:
+ n, h, w, c = arr_np.shape
+ return arr_np.reshape([n, h // 8, 8, w // 8, 8, c // 32, 32]).transpose(
+ 0, 1, 3, 5, 2, 4, 6
+ )
if current_layout == "nc":
n, c = arr_np.shape
diff --git a/tests/python/contrib/test_hexagon/topi/test_reshape.py b/tests/python/contrib/test_hexagon/topi/test_reshape.py
index 2def86ad83..7df29a02ab 100644
--- a/tests/python/contrib/test_hexagon/topi/test_reshape.py
+++ b/tests/python/contrib/test_hexagon/topi/test_reshape.py
@@ -56,23 +56,23 @@ def reshape_helper(
input_layout,
)
with tvm.transform.PassContext(opt_level=3):
- print("output of tvm.lower", tvm.lower(tir_s.mod, name=func))
runtime_module = tvm.build(tir_s.mod, target=target, name=func)
mod = hexagon_session.load_module(runtime_module)
- a_numpy = (np.random.uniform(-1, 1, input_shape)).astype(data_type)
+ a_numpy = (np.random.uniform(-10, 10, input_shape)).astype(data_type)
ref = np.reshape(a_numpy, output_shape)
input_np_transformed = transform_numpy(a_numpy, "nhwc", input_layout)
ref_np_transformed = transform_numpy(ref, "nhwc", output_layout)
input_axis_sep = [4]
- if output_layout == "nhwc-8h2w32c2w-2d":
+ if output_layout in ["nhwc-8h2w32c2w-2d", "nhwc-8h8w32c-2d"]:
output_axis_sep = [4]
- elif output_layout == "nc-1024-2d":
+ elif output_layout in ["nc-1024-2d", "nc-2048-2d"]:
output_axis_sep = [2]
else:
raise RuntimeError(f"Unexpected layout '{output_layout}'")
+
a_tvm = allocate_hexagon_array(
hexagon_session.device,
data=input_np_transformed,
@@ -86,11 +86,12 @@ def reshape_helper(
axis_separators=output_axis_sep,
mem_scope="global.vtcm",
)
+
mod(a_tvm, output)
np.testing.assert_allclose(output.numpy(), ref_np_transformed, atol=1e-07, rtol=0)
-batch_flatten_tests = (
+batch_flatten_fp16_tests = (
([1, 1, 1, 2048], [1, 2048], "nhwc-1024c-2d", "nc-1024-2d", "float16"),
([1, 2, 4, 2048], [1, 2 * 4 * 2048], "nhwc-1024c-2d", "nc-1024-2d", "float16"),
([1, 8, 8, 1024], [1, 8 * 8 * 1024], "nhwc-1024c-2d", "nc-1024-2d", "float16"),
@@ -98,14 +99,17 @@ batch_flatten_tests = (
)
+batch_flatten_uint8_tests = (
+ ([1, 1, 1, 2048], [1, 2048], "nhwc-2048c-2d", "nc-2048-2d", "uint8"),
+ ([1, 2, 4, 2048], [1, 2 * 4 * 2048], "nhwc-2048c-2d", "nc-2048-2d", "uint8"),
+)
+
+
class BaseTestBatchFlatten:
- (
- input_shape,
- output_shape,
- input_layout,
- output_layout,
- data_type,
- ) = tvm.testing.parameters(*batch_flatten_tests)
+ (input_shape, output_shape, input_layout, output_layout, data_type,) = tvm.testing.parameters(
+ *batch_flatten_fp16_tests,
+ *batch_flatten_uint8_tests,
+ )
class TestBatchFlatten(BaseTestBatchFlatten):
@@ -132,11 +136,24 @@ class TestBatchFlatten(BaseTestBatchFlatten):
)
+reshape_fp16_tests = (
+ ([1, 8, 4, 64], [1, 8, 8, 32], "nhwc-8h2w32c2w-2d", "nhwc-8h2w32c2w-2d", "float16"),
+ ([1, 16, 8, 128], [1, 16, 16, 64], "nhwc-8h2w32c2w-2d", "nhwc-8h2w32c2w-2d", "float16"),
+)
+
+
+reshape_uint8_tests = (
+ ([1, 8, 8, 128], [1, 8, 16, 64], "nhwc-8h8w32c-2d", "nhwc-8h8w32c-2d", "uint8"),
+ ([1, 16, 64, 128], [1, 16, 128, 64], "nhwc-8h8w32c-2d", "nhwc-8h8w32c-2d", "uint8"),
+)
+
+
class BaseTestReshape(BaseTestBatchFlatten):
(input_shape, output_shape, input_layout, output_layout, data_type,) = tvm.testing.parameters(
- *batch_flatten_tests,
- ([1, 8, 4, 64], [1, 8, 8, 32], "nhwc-8h2w32c2w-2d", "nhwc-8h2w32c2w-2d", "float16"),
- ([1, 16, 8, 128], [1, 16, 16, 64], "nhwc-8h2w32c2w-2d", "nhwc-8h2w32c2w-2d", "float16"),
+ *batch_flatten_fp16_tests,
+ *batch_flatten_uint8_tests,
+ *reshape_fp16_tests,
+ *reshape_uint8_tests,
)