Posted to commits@tvm.apache.org by me...@apache.org on 2022/06/24 17:17:30 UTC

[tvm] branch main updated: [Hexagon] Softmax slice op initial version (#11559)

This is an automated email from the ASF dual-hosted git repository.

mehrdadh pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new f1d30a27b2 [Hexagon] Softmax slice op initial version (#11559)
f1d30a27b2 is described below

commit f1d30a27b2efe5b15e6492f785be1d41c9a75ab9
Author: Gayatri P K <qu...@quicinc.com>
AuthorDate: Fri Jun 24 22:47:24 2022 +0530

    [Hexagon] Softmax slice op initial version (#11559)
    
    Resolve merge conflict in utils.py
---
 python/tvm/topi/hexagon/slice_ops/__init__.py      |   1 +
 python/tvm/topi/hexagon/slice_ops/softmax_slice.py |  76 +++++++++++
 python/tvm/topi/hexagon/utils.py                   |  31 +++++
 .../contrib/test_hexagon/test_softmax_slice.py     | 140 +++++++++++++++++++++
 4 files changed, 248 insertions(+)
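
For readers skimming the patch, the new pieces compose roughly as the test
below exercises them: declare the compute, build the STIR schedule from
layout strings, then lower for Hexagon. A minimal sketch (shapes, target
flags, and layout strings are taken from the test in this commit; session
setup and execution are omitted):

    import tvm
    from tvm import te
    import tvm.topi.hexagon.slice_ops as sl

    # 2D input in the NC layout the op assumes.
    A = te.placeholder((1, 1024), name="A", dtype="float32")
    O = sl.softmax_compute(A)

    # Layout strings are resolved to index maps inside the schedule.
    tir_s = sl.softmax_stir_schedule(O, A, "nc-512c-2d", "nc-512c-2d")

    target_hexagon = tvm.target.hexagon(
        "v69", llvm_options="--disable-loop-unrolling-pass"
    )
    func = tvm.build(
        tir_s.mod,
        [A, O],
        tvm.target.Target(target_hexagon, host=target_hexagon),
        name="softmax_slice",
    )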

diff --git a/python/tvm/topi/hexagon/slice_ops/__init__.py b/python/tvm/topi/hexagon/slice_ops/__init__.py
index 70531c629e..5b3ef530b0 100644
--- a/python/tvm/topi/hexagon/slice_ops/__init__.py
+++ b/python/tvm/topi/hexagon/slice_ops/__init__.py
@@ -19,3 +19,4 @@
 
 from .avg_pool2d import avg_pool2d_compute, avg_pool2d_STIR_schedule
 from .add_subtract_multiply import *
+from .softmax_slice import *
diff --git a/python/tvm/topi/hexagon/slice_ops/softmax_slice.py b/python/tvm/topi/hexagon/slice_ops/softmax_slice.py
new file mode 100644
index 0000000000..f95e58f3ae
--- /dev/null
+++ b/python/tvm/topi/hexagon/slice_ops/softmax_slice.py
@@ -0,0 +1,76 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Hexagon slice softmax compute and schedule"""
+
+import typing
+
+from tvm import te, tir, topi
+from ..utils import get_layout_transform_fn
+
+
+def softmax_compute(in_tensor):
+    """
+    Compute for slice softmax op for hexagon.
+    This op makes the following assumptions:
+    1. This op is written for a sliced softmax operation.
+    2. The input is assumed to be in NC layout.
+    """
+    return topi.nn.softmax(in_tensor, axis=1)
+
+
+def softmax_stir_schedule(
+    out: te.Tensor, inp: te.Tensor, out_layout: str, in_layout: str
+):
+    """
+    STIR schedule definition for the compute of softmax
+    """
+
+    in_layout = get_layout_transform_fn(in_layout)
+    out_layout = get_layout_transform_fn(out_layout)
+
+    func = te.create_prim_func([inp, out])
+    sch = tir.Schedule(func, debug_mask="all")
+
+    max_tensor = sch.get_block("T_softmax_maxelem")
+    exp_tensor = sch.get_block("T_softmax_exp")
+    sum_tensor = sch.get_block("T_softmax_expsum")
+    out_tensor = sch.get_block("T_softmax_norm")
+
+    sch.transform_layout(max_tensor, inp.name, in_layout)
+    sch.transform_layout(out_tensor, out.name, out_layout)
+
+    _, c_inner = sch.get_loops(max_tensor)
+    _, c_inner_i = sch.split(c_inner, [None, 64])
+    rf_max = sch.rfactor(c_inner_i, 0)
+    _, _, max_inner = sch.get_loops(rf_max)
+    sch.vectorize(max_inner)
+
+    _, loopi = sch.get_loops(exp_tensor)
+    _, loopi_i = sch.split(loopi, [None, 512])
+    sch.vectorize(loopi_i)
+
+    _, c_sum_inner = sch.get_loops(sum_tensor)
+    _, c_sum_inner_i = sch.split(c_sum_inner, [None, 64])
+    rf_sum = sch.rfactor(c_sum_inner_i, 0)
+    _, _, sum_inner = sch.get_loops(rf_sum)
+    sch.vectorize(sum_inner)
+
+    _, c_out_inner = sch.get_loops(out_tensor)
+    _, c_out_inner_i = sch.split(c_out_inner, [None, 512])
+    sch.vectorize(c_out_inner_i)
+
+    return sch
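
The split/rfactor/vectorize pattern above turns each channel reduction into a
vectorizable inner loop. A standalone sketch of the same idiom on a plain row
sum (the names and toy shape are illustrative, not from the patch):

    from tvm import te, tir

    n, c = 1, 1024
    A = te.placeholder((n, c), name="A", dtype="float32")
    k = te.reduce_axis((0, c), name="k")
    S = te.compute((n,), lambda i: te.sum(A[i, k], axis=k), name="S")

    sch = tir.Schedule(te.create_prim_func([A, S]))
    blk = sch.get_block("S")
    i_loop, k_loop = sch.get_loops(blk)
    k_outer, k_inner = sch.split(k_loop, [None, 64])  # fixed-width inner chunk
    rf_blk = sch.rfactor(k_inner, 0)   # partial sums, one per inner-chunk lane
    _, _, inner = sch.get_loops(rf_blk)
    sch.vectorize(inner)               # lanes map to vector operations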
diff --git a/python/tvm/topi/hexagon/utils.py b/python/tvm/topi/hexagon/utils.py
index af6e3de9c3..3efc48c4d0 100644
--- a/python/tvm/topi/hexagon/utils.py
+++ b/python/tvm/topi/hexagon/utils.py
@@ -14,7 +14,10 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
+
 # pylint: disable=invalid-name
+
+
 """Common hexagon specific utilities"""
 from tvm import te
 
@@ -39,6 +42,26 @@ def nhwc_8h2w32c2w_1d(n, h, w, c):
     return [n, h // 8, w // 4, c // 32, h % 8, (w % 4) // 2, c % 32, w % 2]
 
 
+def nhwc_4h4w32c_1d(n, h, w, c):
+    """Return index map for nhwc_4h4232c 1d layout"""
+    return [n, h // 4, w // 4, c // 32, h % 4, w % 4, c % 32]
+
+
+def nhwc_4h4w32c_2d(n, h, w, c):
+    """Return index map for nhwc_4h4w32c 2d layout"""
+    return [n, h // 4, w // 4, c // 32, te.AXIS_SEPARATOR, h % 4, w % 4, c % 32]
+
+
+def nc_512c_1d(n, c):
+    """Return index map for nc_512c 1d layout"""
+    return [n, c // 512, c % 512]
+
+
+def nc_512c_2d(n, c):
+    """Return index map for nc_512c 2d layout"""
+    return [n, c // 512, te.AXIS_SEPARATOR, c % 512]
+
+
 def get_layout_transform_fn(layout):
     """Return index map function as per the layout string"""
     if layout == "nhwc-8h2w32c2w-2d":
@@ -49,4 +72,12 @@ def get_layout_transform_fn(layout):
         return n11c_1024c_2d
     if layout == "n11c-1024c-1d":
         return n11c_1024c_1d
+    if layout == "nhwc-4h4w32c-2d":
+        return nhwc_4h4w32c_2d
+    if layout == "nhwc-4h4w32c-1d":
+        return nhwc_4h4w32c_1d
+    if layout == "nc-512c-2d":
+        return nc_512c_2d
+    if layout == "nc-512c-1d":
+        return nc_512c_1d
     raise RuntimeError(f"Unexpected layout '{layout}'")
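
As a quick sanity check on the new index maps (a worked example, not part of
the patch): under the nc-512c layouts, element (n=0, c=1000) lands in channel
block 1000 // 512 = 1 at offset 1000 % 512 = 488.

    n, c = 0, 1000
    print([n, c // 512, c % 512])  # nc_512c_1d -> [0, 1, 488]
    # nc_512c_2d yields the same indices, but te.AXIS_SEPARATOR after
    # c // 512 marks where the buffer is realized as a 2-d allocation
    # (one 512-element chunk per row); the test's axis_sep=[2] refers to
    # that separator position.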
diff --git a/tests/python/contrib/test_hexagon/test_softmax_slice.py b/tests/python/contrib/test_hexagon/test_softmax_slice.py
new file mode 100644
index 0000000000..a4745d62a7
--- /dev/null
+++ b/tests/python/contrib/test_hexagon/test_softmax_slice.py
@@ -0,0 +1,140 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import sys
+
+import pytest
+import numpy as np
+from tvm import te, topi
+
+import tvm.testing
+from tvm.topi import testing
+from tvm.contrib.hexagon.build import HexagonLauncher
+
+import tvm.topi.hexagon.slice_ops as sl
+from .infrastructure import allocate_hexagon_array
+
+
+def transform_numpy(arr_np, layout):
+    if layout in ["nc-512c-2d"]:
+        N, C = arr_np.shape
+        return arr_np.reshape([N, C // 512, 512])
+    raise RuntimeError(f"Unexpected layout '{layout}'")
+
+
+@tvm.testing.fixture
+def input_np(input_shape, dtype):
+    return (np.random.uniform(size=input_shape)).astype(dtype)
+
+
+@tvm.testing.fixture
+def transformed_expected_output_np(expected_output_np, output_layout):
+    return transform_numpy(expected_output_np, output_layout)
+
+
+@tvm.testing.fixture
+def transformed_input_np(input_np, input_layout):
+    return transform_numpy(input_np, input_layout)
+
+
+class Basesoftmax2d:
+
+    input_shape, input_layout, output_layout, axis_sep = tvm.testing.parameters(
+        ((1, 1024), "nc-512c-2d", "nc-512c-2d", [2])
+    )
+    dtype = tvm.testing.parameter("float32")
+    working_scope = tvm.testing.parameter("global.vtcm")
+
+
+class TestSoftmax2d(Basesoftmax2d):
+    @tvm.testing.fixture
+    def expected_output_np(self, input_np):
+        if len(input_np.shape) == 2:
+            ref_np_2d = tvm.topi.testing.softmax_python(input_np)
+            return ref_np_2d
+        raise RuntimeError(f"Unexpected input shape '{input_np.shape}'")
+
+    @tvm.testing.requires_hexagon
+    def test_softmax_f32(
+        self,
+        dtype,
+        input_layout,
+        output_layout,
+        input_shape,
+        input_np,
+        transformed_input_np,
+        transformed_expected_output_np,
+        expected_output_np,
+        working_scope,
+        axis_sep,
+        hexagon_session,
+    ):
+
+        target_hexagon = tvm.target.hexagon(
+            "v69",
+            llvm_options="--disable-loop-unrolling-pass",
+        )
+        A = te.placeholder(input_shape, name="A", dtype=dtype)
+
+        O = sl.softmax_compute(A)
+
+        if input_layout == "nc-512c-2d":
+            tir_s = sl.softmax_stir_schedule(O, A, output_layout, input_layout)
+            sch = tir_s.mod
+        else:
+            raise RuntimeError(f"Unexpected input layout '{input_layout}'")
+
+        with tvm.transform.PassContext(
+            opt_level=3,
+            config={
+                "tir.LoopPartition": {"partition_const_loop": True},
+            },
+        ):
+
+            func = tvm.build(
+                sch,
+                [A, O],
+                tvm.target.Target(target_hexagon, host=target_hexagon),
+                name="softmax_slice",
+            )
+
+        input_arr = allocate_hexagon_array(
+            hexagon_session.device,
+            data=transformed_input_np,
+            axis_separators=axis_sep,
+            mem_scope=working_scope,
+        )
+
+        output_arr = allocate_hexagon_array(
+            hexagon_session.device,
+            tensor_shape=transformed_expected_output_np.shape,
+            dtype=transformed_expected_output_np.dtype,
+            axis_separators=axis_sep,
+            mem_scope=working_scope,
+        )
+
+        mod = hexagon_session.load_module(func)
+        mod(input_arr, output_arr)
+
+        n, c = input_np.shape
+        output_np = output_arr.numpy().reshape(n, c // 512, 512)
+
+        np.testing.assert_allclose(output_np, transformed_expected_output_np, rtol=1e-4, atol=1e-4)
+
+
+if __name__ == "__main__":
+
+    sys.exit(pytest.main(sys.argv))
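
For reference, the comparison baseline used above, topi.testing.softmax_python,
is the usual numerically-stable softmax; an equivalent sketch in plain numpy
(assuming a 2-d input, as in this test):

    import numpy as np

    def softmax_ref(x):
        # Subtract the row max before exponentiating for numerical stability.
        e = np.exp(x - x.max(axis=-1, keepdims=True))
        return e / e.sum(axis=-1, keepdims=True)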