Posted to commits@tvm.apache.org by kp...@apache.org on 2022/07/16 15:10:44 UTC

[tvm] branch main updated: [TOPI] [Hexagon] Uint8 Reshape and batch flatten slice ops (#12037)

This is an automated email from the ASF dual-hosted git repository.

kparzysz pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new c0e996e291 [TOPI] [Hexagon] Uint8 Reshape and batch flatten slice ops (#12037)
c0e996e291 is described below

commit c0e996e2914585fe6b0c11fb2efdaea5c6b9daf9
Author: abhikran-quic <63...@users.noreply.github.com>
AuthorDate: Sat Jul 16 20:40:39 2022 +0530

    [TOPI] [Hexagon] Uint8 Reshape and batch flatten slice ops (#12037)
    
    * [TOPI] [Hexagon] Uint8 Reshape and batch flatten slice ops
    
    * Fix documentation
---
 python/tvm/topi/hexagon/slice_ops/reshape.py       | 13 +++---
 python/tvm/topi/hexagon/utils.py                   | 21 ++++++++++
 .../python/contrib/test_hexagon/infrastructure.py  | 12 +++++-
 .../contrib/test_hexagon/topi/test_reshape.py      | 47 +++++++++++++++-------
 4 files changed, 72 insertions(+), 21 deletions(-)

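The split-factor changes below track the HVX vector width. A minimal arithmetic sketch (an editorial reading of the diff, not part of the commit message): Hexagon HVX registers are 1024 bits, so one full vector holds 64 float16 lanes or 128 uint8 lanes, and a 2048-byte 2-D block is 1024 fp16 elements or 2048 uint8 elements.

    # Sketch: why c_split is 64 for float16 but 128 for uint8.
    VECTOR_BYTES = 128                 # HVX vector register: 1024 bits
    assert VECTOR_BYTES // 2 == 64     # float16 lanes -> c_split in fp16 schedules
    assert VECTOR_BYTES // 1 == 128    # uint8 lanes   -> c_split in uint8 schedules
    # The blocked layouts keep the same byte footprint per block:
    assert 1024 * 2 == 2048 * 1        # nc-1024-2d (fp16) vs nc-2048-2d (uint8)
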
diff --git a/python/tvm/topi/hexagon/slice_ops/reshape.py b/python/tvm/topi/hexagon/slice_ops/reshape.py
index 374c20bb72..2220253e21 100644
--- a/python/tvm/topi/hexagon/slice_ops/reshape.py
+++ b/python/tvm/topi/hexagon/slice_ops/reshape.py
@@ -40,13 +40,14 @@ def reshape_compute(inp: te.Tensor, new_shape: tuple) -> te.Tensor:
     return topi.transform.reshape(inp, new_shape)
 
 
-def stir_schedule_nhwc_1024c(
+def stir_sched_nhwc_2d_op(
     out: te.Tensor,
     inp: te.Tensor,
     out_layout: str,
     in_layout: str,
+    c_split: int,
 ) -> tir.Schedule:
-    """Schedule for output layout: nhwc-1024c-2d"""
+    """Schedule for output layout: nc-1024-2d, nc-2048-2d"""
     reshape_func = te.create_prim_func([inp, out])
     sch = tir.Schedule(reshape_func, debug_mask="all")
     compute = sch.get_block("T_reshape")
@@ -57,7 +58,7 @@ def stir_schedule_nhwc_1024c(
     jout, channel = sch.split(j, [None, inp.shape[3]])
     height, width = sch.split(jout, [inp.shape[1], inp.shape[2]])
     channelo, channeli = sch.split(channel, [None, 1024])
-    channelio, channelii = sch.split(channeli, [None, 64])
+    channelio, channelii = sch.split(channeli, [None, c_split])
     sch.reorder(i, height, width, channelo, channelio, channelii)
     sch.vectorize(channelii)
     return sch
@@ -101,8 +102,10 @@ def reshape_stir_schedule(
     sch : tvm.tir.Schedule
         The STIR schedule for slice reshape compute
     """
-    if output_layout == "nhwc-8h2w32c2w-2d":
+    if output_layout in ["nhwc-8h2w32c2w-2d", "nhwc-8h8w32c-2d"]:
         return stir_schedule_nhwc_8h2w32c2w(out, inp, output_layout, input_layout)
     if output_layout == "nc-1024-2d":
-        return stir_schedule_nhwc_1024c(out, inp, output_layout, input_layout)
+        return stir_sched_nhwc_2d_op(out, inp, output_layout, input_layout, 64)
+    if output_layout == "nc-2048-2d":
+        return stir_sched_nhwc_2d_op(out, inp, output_layout, input_layout, 128)
     raise RuntimeError(f"Unexpected layout '{output_layout}'")
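
For context, a minimal usage sketch of the dispatch above, using the uint8 batch-flatten case from the new tests (argument order taken from this diff; building and running additionally require a Hexagon target and session, as in the test helper):

    from tvm import te
    from tvm.topi.hexagon.slice_ops.reshape import (
        reshape_compute,
        reshape_stir_schedule,
    )

    # Shapes and layouts from batch_flatten_uint8_tests in this commit.
    inp = te.placeholder((1, 2, 4, 2048), dtype="uint8", name="inp")
    out = reshape_compute(inp, (1, 2 * 4 * 2048))
    # "nc-2048-2d" routes to stir_sched_nhwc_2d_op with c_split=128.
    sch = reshape_stir_schedule(out, inp, "nc-2048-2d", "nhwc-2048c-2d")
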
diff --git a/python/tvm/topi/hexagon/utils.py b/python/tvm/topi/hexagon/utils.py
index 4458c55e62..3b8914ffe9 100644
--- a/python/tvm/topi/hexagon/utils.py
+++ b/python/tvm/topi/hexagon/utils.py
@@ -87,6 +87,21 @@ def nc_1024_2d(n, c):
     return [n, c // 1024, te.AXIS_SEPARATOR, c % 1024]
 
 
+def nhwc_2048c_2d(n, h, w, c):
+    """Return index map for nhwc_2048 2d layout"""
+    return [n, h, w, c // 2048, te.AXIS_SEPARATOR, c % 2048]
+
+
+def nc_2048_2d(n, c):
+    """Return index map for nc_2048 2d layout"""
+    return [n, c // 2048, te.AXIS_SEPARATOR, c % 2048]
+
+
+def nhwc_8h8w32c_2d(n, h, w, c):
+    """Return index map for nhwc_8h8w32c 2d layout"""
+    return [n, h // 8, w // 8, c // 32, te.AXIS_SEPARATOR, h % 8, w % 8, c % 32]
+
+
 def iohw_16i32o2i_1d(height, width, in_channel, out_channel):
     return [
         in_channel // 32,
@@ -129,4 +144,10 @@ def get_layout_transform_fn(layout):
         return nc_1024c_2d
     if layout == "iohw-16i32o2i-1d":
         return iohw_16i32o2i_1d
+    if layout == "nhwc-2048c-2d":
+        return nhwc_2048c_2d
+    if layout == "nc-2048-2d":
+        return nc_2048_2d
+    if layout == "nhwc-8h8w32c-2d":
+        return nhwc_8h8w32c_2d
     raise RuntimeError(f"Unexpected layout '{layout}'")
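
A minimal sketch of what the new index maps express (illustrative, not part of the commit): each returns the block indices, then te.AXIS_SEPARATOR, then the in-block indices, with the separator marking where the buffer is split into a discontiguous 2-D allocation.

    # nhwc_8h8w32c_2d: logical (n, h, w, c) -> block + in-block offset.
    n, h, w, c = 0, 9, 10, 40
    block = (n, h // 8, w // 8, c // 32)   # (0, 1, 1, 1)
    offset = (h % 8, w % 8, c % 32)        # (1, 2, 8)
    # Each 8*8*32 block holds 2048 elements, i.e. 2048 bytes for uint8.
    assert 8 * 8 * 32 == 2048
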
diff --git a/tests/python/contrib/test_hexagon/infrastructure.py b/tests/python/contrib/test_hexagon/infrastructure.py
index a1fbfdefcd..7108ac5598 100644
--- a/tests/python/contrib/test_hexagon/infrastructure.py
+++ b/tests/python/contrib/test_hexagon/infrastructure.py
@@ -256,7 +256,17 @@ def transform_numpy(arr_np, current_layout: str, new_layout: str):
         if new_layout == "nhwc-1024c-2d":
             N, H, W, C = arr_np.shape
             return arr_np.reshape([N, H, W, C // 1024, 1024])
-        raise RuntimeError(f"Unexpected new_layout '{new_layout}'")
+        if new_layout == "nc-2048-2d":
+            N, C = arr_np.shape
+            return arr_np.reshape([N, C // 2048, 2048])
+        if new_layout == "nhwc-2048c-2d":
+            N, H, W, C = arr_np.shape
+            return arr_np.reshape([N, H, W, C // 2048, 2048])
+        if new_layout in ["nhwc-8h8w32c-2d"]:
+            n, h, w, c = arr_np.shape
+            return arr_np.reshape([n, h // 8, 8, w // 8, 8, c // 32, 32]).transpose(
+                0, 1, 3, 5, 2, 4, 6
+            )
 
     if current_layout == "nc":
         n, c = arr_np.shape
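
A quick NumPy check of the nhwc-8h8w32c-2d branch above (illustrative; the shapes are assumptions chosen to divide evenly):

    import numpy as np

    arr = np.random.randint(0, 255, (1, 8, 16, 64), dtype="uint8")
    # Block H, W, and C, then hoist block indices ahead of in-block indices.
    out = arr.reshape(1, 1, 8, 2, 8, 2, 32).transpose(0, 1, 3, 5, 2, 4, 6)
    assert out.shape == (1, 1, 2, 2, 8, 8, 32)  # (n, h//8, w//8, c//32, 8, 8, 32)
    # Element (hb, wb, cb, hi, wi, ci) maps back to (hb*8+hi, wb*8+wi, cb*32+ci).
    assert out[0, 0, 1, 1, 2, 3, 5] == arr[0, 2, 11, 37]
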
diff --git a/tests/python/contrib/test_hexagon/topi/test_reshape.py b/tests/python/contrib/test_hexagon/topi/test_reshape.py
index 2def86ad83..7df29a02ab 100644
--- a/tests/python/contrib/test_hexagon/topi/test_reshape.py
+++ b/tests/python/contrib/test_hexagon/topi/test_reshape.py
@@ -56,23 +56,23 @@ def reshape_helper(
         input_layout,
     )
     with tvm.transform.PassContext(opt_level=3):
-        print("output of tvm.lower", tvm.lower(tir_s.mod, name=func))
         runtime_module = tvm.build(tir_s.mod, target=target, name=func)
 
     mod = hexagon_session.load_module(runtime_module)
 
-    a_numpy = (np.random.uniform(-1, 1, input_shape)).astype(data_type)
+    a_numpy = (np.random.uniform(-10, 10, input_shape)).astype(data_type)
     ref = np.reshape(a_numpy, output_shape)
 
     input_np_transformed = transform_numpy(a_numpy, "nhwc", input_layout)
     ref_np_transformed = transform_numpy(ref, "nhwc", output_layout)
     input_axis_sep = [4]
-    if output_layout == "nhwc-8h2w32c2w-2d":
+    if output_layout in ["nhwc-8h2w32c2w-2d", "nhwc-8h8w32c-2d"]:
         output_axis_sep = [4]
-    elif output_layout == "nc-1024-2d":
+    elif output_layout in ["nc-1024-2d", "nc-2048-2d"]:
         output_axis_sep = [2]
     else:
         raise RuntimeError(f"Unexpected layout '{output_layout}'")
+
     a_tvm = allocate_hexagon_array(
         hexagon_session.device,
         data=input_np_transformed,
@@ -86,11 +86,12 @@ def reshape_helper(
         axis_separators=output_axis_sep,
         mem_scope="global.vtcm",
     )
+
     mod(a_tvm, output)
     np.testing.assert_allclose(output.numpy(), ref_np_transformed, atol=1e-07, rtol=0)
 
 
-batch_flatten_tests = (
+batch_flatten_fp16_tests = (
     ([1, 1, 1, 2048], [1, 2048], "nhwc-1024c-2d", "nc-1024-2d", "float16"),
     ([1, 2, 4, 2048], [1, 2 * 4 * 2048], "nhwc-1024c-2d", "nc-1024-2d", "float16"),
     ([1, 8, 8, 1024], [1, 8 * 8 * 1024], "nhwc-1024c-2d", "nc-1024-2d", "float16"),
@@ -98,14 +99,17 @@ batch_flatten_tests = (
 )
 
 
+batch_flatten_uint8_tests = (
+    ([1, 1, 1, 2048], [1, 2048], "nhwc-2048c-2d", "nc-2048-2d", "uint8"),
+    ([1, 2, 4, 2048], [1, 2 * 4 * 2048], "nhwc-2048c-2d", "nc-2048-2d", "uint8"),
+)
+
+
 class BaseTestBatchFlatten:
-    (
-        input_shape,
-        output_shape,
-        input_layout,
-        output_layout,
-        data_type,
-    ) = tvm.testing.parameters(*batch_flatten_tests)
+    (input_shape, output_shape, input_layout, output_layout, data_type,) = tvm.testing.parameters(
+        *batch_flatten_fp16_tests,
+        *batch_flatten_uint8_tests,
+    )
 
 
 class TestBatchFlatten(BaseTestBatchFlatten):
@@ -132,11 +136,24 @@ class TestBatchFlatten(BaseTestBatchFlatten):
         )
 
 
+reshape_fp16_tests = (
+    ([1, 8, 4, 64], [1, 8, 8, 32], "nhwc-8h2w32c2w-2d", "nhwc-8h2w32c2w-2d", "float16"),
+    ([1, 16, 8, 128], [1, 16, 16, 64], "nhwc-8h2w32c2w-2d", "nhwc-8h2w32c2w-2d", "float16"),
+)
+
+
+reshape_uint8_tests = (
+    ([1, 8, 8, 128], [1, 8, 16, 64], "nhwc-8h8w32c-2d", "nhwc-8h8w32c-2d", "uint8"),
+    ([1, 16, 64, 128], [1, 16, 128, 64], "nhwc-8h8w32c-2d", "nhwc-8h8w32c-2d", "uint8"),
+)
+
+
 class BaseTestReshape(BaseTestBatchFlatten):
     (input_shape, output_shape, input_layout, output_layout, data_type,) = tvm.testing.parameters(
-        *batch_flatten_tests,
-        ([1, 8, 4, 64], [1, 8, 8, 32], "nhwc-8h2w32c2w-2d", "nhwc-8h2w32c2w-2d", "float16"),
-        ([1, 16, 8, 128], [1, 16, 16, 64], "nhwc-8h2w32c2w-2d", "nhwc-8h2w32c2w-2d", "float16"),
+        *batch_flatten_fp16_tests,
+        *batch_flatten_uint8_tests,
+        *reshape_fp16_tests,
+        *reshape_uint8_tests,
     )
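
In summary, the parameter tuples now cover both dtypes; field order as declared above:

    # (input_shape, output_shape, input_layout, output_layout, data_type)
    # fp16 batch-flatten:  nhwc-1024c-2d     -> nc-1024-2d
    # uint8 batch-flatten: nhwc-2048c-2d     -> nc-2048-2d
    # fp16 reshape:        nhwc-8h2w32c2w-2d -> nhwc-8h2w32c2w-2d
    # uint8 reshape:       nhwc-8h8w32c-2d   -> nhwc-8h8w32c-2d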