You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by me...@apache.org on 2022/11/15 22:46:31 UTC

[tvm] branch main updated: [TOPI][Hexagon] Implement quantized adaptive_avg_pool1d for hexagon (#13282)

This is an automated email from the ASF dual-hosted git repository.

mehrdadh pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 4f4b4edafd [TOPI][Hexagon] Implement quantized adaptive_avg_pool1d for hexagon (#13282)
4f4b4edafd is described below

commit 4f4b4edafdb0837972e4df6f570368f9f7cefd20
Author: Tasmia Rahman <89...@users.noreply.github.com>
AuthorDate: Tue Nov 15 16:46:24 2022 -0600

    [TOPI][Hexagon] Implement quantized adaptive_avg_pool1d for hexagon (#13282)
    
    * [TOPI][Hexagon] Implement adaptive_avg_pool1d for hexagon
    
    * Fix lint issues
    
    * Fix some lint issues
    
    * Fix lint issues in test
    
    * Fix import for allocate_hexagon_array
---
 python/tvm/topi/hexagon/qnn/__init__.py            |   1 +
 python/tvm/topi/hexagon/qnn/adaptive_avg_pool1d.py | 120 +++++++++++++
 python/tvm/topi/hexagon/utils.py                   |   7 +
 .../python/contrib/test_hexagon/infrastructure.py  |   9 +
 .../test_hexagon/topi/test_adaptive_avg_pool1d.py  | 185 +++++++++++++++++++++
 5 files changed, 322 insertions(+)

diff --git a/python/tvm/topi/hexagon/qnn/__init__.py b/python/tvm/topi/hexagon/qnn/__init__.py
index f7a018d225..d63b69b2e2 100644
--- a/python/tvm/topi/hexagon/qnn/__init__.py
+++ b/python/tvm/topi/hexagon/qnn/__init__.py
@@ -27,3 +27,4 @@ from .dequantize import (
 from .quantize import quantize_compute, tir_quantize_schedule
 from .nn import *
 from .qdepthwise_conv2d_slice import qdepthwise_conv2d_compute, qdepthwise_conv2d_schedule
+from .adaptive_avg_pool1d import *
diff --git a/python/tvm/topi/hexagon/qnn/adaptive_avg_pool1d.py b/python/tvm/topi/hexagon/qnn/adaptive_avg_pool1d.py
new file mode 100755
index 0000000000..80f1cd1ecf
--- /dev/null
+++ b/python/tvm/topi/hexagon/qnn/adaptive_avg_pool1d.py
@@ -0,0 +1,120 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+""" Compute and schedule for adaptive_avg_pool1d slice op
+
+The following are a few notes and assumptions made by the implementation:
+
+Assumptions:
+1) The input is in NCW layout. Distilbert is the only model that calls
+   nn.adaptive_avg_pool1d and the only layout it uses is 'NCW'.
+2) The op takes output_size as an argument and
+   only handles the specialized case where output_size is 1.
+   The argument output_size is used as the value of output_width.
+3) Both input and output dtypes are uint8/int8, and
+   quantization parameters are provided to the op.
+4) Input is assumed to always be a multiple of the fixed chunk 32c64w.
+
+Notes:
+1) If input width is used as output width, there can be two cases:
+    a. If the quantization parameters of input and output are same,
+       it can return the input as output so the op will be a no-op.
+    b. If the quantization parameters of input and output are different,
+       it will essentially be a requantize op.
+2) If output_size is a value besides 1 or input_width,
+   adaptive_avg_pool1d may use dynamic stride and kernel for each output element.
+   When this case occurs, kernel won't be known at compile time. We want to use
+   the generic implementation nn.adaptive_avg_pool1d() for this case.
+"""
+
+from tvm import te
+from tvm import tir
+from ..utils import get_layout_transform_fn, get_fixed_point_value, saturate
+
+
+def adaptive_avg_pool1d(
+    data: te.Tensor,
+    output_size: list,
+    odtype: str,
+    input_zero_point: int,
+    input_scale: float,
+    output_zero_point: int,
+    output_scale: float,
+):
+    """Quantized adaptive_avg_pool1d compute (NCW input, output_size[0] must be 1)"""
+    n, c, inw = data.shape
+
+    out_width = output_size[0]
+
+    # Kernel is same as input_width since output_width is assumed to be 1
+    if out_width != 1:
+        raise RuntimeError(f"Unsupported output_size '{out_width}'")
+    kw_r = inw
+
+    # Batch and channel extents are kept; only the width axis is reduced
+    oshape = (n, c, out_width)
+
+    # Accumulate in a 32-bit type with the same signedness as the 8-bit output
+    if odtype not in ("uint8", "int8"):
+        raise RuntimeError(f"Unsupported output dtype '{odtype}'")
+    temp_dtype = "uint32" if odtype == "uint8" else "int32"
+
+    # Fixed-point requantization: the 1/kw_r of the mean is folded into the scale,
+    # and both zero points are folded into the additive correction `corr`
+    scale_with_area = input_scale / (output_scale * int(kw_r))
+    scale_fixed_point, rsh = get_fixed_point_value(scale_with_area, "int16")
+    corr = (output_zero_point << rsh) - input_zero_point * kw_r * scale_fixed_point
+
+    rw_r = te.reduce_axis((0, kw_r), name="rw_r")
+
+    # The block name "sum" is what stir_schedule_ncw_32c64w looks up via get_block
+    sum_compute = te.compute(
+        oshape,
+        lambda n, c, w: te.sum(data[n, c, w + rw_r].astype(temp_dtype), axis=[rw_r]),
+        name="sum",
+    )
+
+    avg_compute = te.compute(
+        oshape,
+        lambda n, c, w: saturate(
+            ((sum_compute[n, c, w] * scale_fixed_point) + corr) >> rsh, odtype
+        ).astype(odtype),
+        name="adaptive_avg_1d",
+    )
+    return avg_compute
+
+
+def stir_schedule_ncw_32c64w(outs, ins, input_layout: str):
+    """Return a tir.Schedule in which only the input buffer is re-laid out as per input_layout"""
+    func = te.create_prim_func([ins, outs])  # argument order: input first, then output
+    s = tir.Schedule(func)
+
+    sum_block = s.get_block("sum")  # the reduction stage named "sum" in adaptive_avg_pool1d
+
+    # The input is a multiple of the fixed chunk but the output is NxCx1,
+    # hence transform_layout is only applied to the input (read buffer 0 of the sum block)
+    input_transformed_layout = get_layout_transform_fn(input_layout)
+    s.transform_layout(sum_block, buffer=("read", 0), index_map=input_transformed_layout)
+
+    return s
+
+
+def tir_adaptive_avg_pool1d_schedule(outs, ins, output_layout: str, input_layout: str):
+    """STIR based schedule; dispatches on output_layout, and only "ncw" is supported"""
+    if output_layout == "ncw":
+        return stir_schedule_ncw_32c64w(outs, ins, input_layout)
+    raise RuntimeError(f"Unexpected layout '{output_layout}'")
diff --git a/python/tvm/topi/hexagon/utils.py b/python/tvm/topi/hexagon/utils.py
index 890ebeb9fd..5aeed9aa4f 100644
--- a/python/tvm/topi/hexagon/utils.py
+++ b/python/tvm/topi/hexagon/utils.py
@@ -131,6 +131,11 @@ def ohwi32o_1d(height, width, in_channel, out_channel):
     return [out_channel // 32, height, width, in_channel, out_channel % 32]
 
 
+def ncw_32c64w_2d(n, c, w):
+    """Return index map for ncw_32c64w 2d layout (channels tiled by 32, width tiled by 64)"""
+    return [n, c // 32, w // 64, te.AXIS_SEPARATOR, c % 32, w % 64]
+
+
 def get_layout_transform_fn(layout):
     """Return index map function as per the layout string"""
     if layout == "nhwc-8h2w32c2w-2d":
@@ -173,6 +178,8 @@ def get_layout_transform_fn(layout):
         return n11c_2048c_2d
     if layout == "ohwi32o-1d":
         return ohwi32o_1d
+    if layout == "ncw-32c64w-2d":
+        return ncw_32c64w_2d
     raise RuntimeError(f"Unexpected layout '{layout}'")
 
 
diff --git a/tests/python/contrib/test_hexagon/infrastructure.py b/tests/python/contrib/test_hexagon/infrastructure.py
index c04631156f..c03701f83c 100644
--- a/tests/python/contrib/test_hexagon/infrastructure.py
+++ b/tests/python/contrib/test_hexagon/infrastructure.py
@@ -268,6 +268,15 @@ def transform_numpy(arr_np, current_layout: str, new_layout: str):
 
         raise RuntimeError(f"Unexpected new_layout '{new_layout}'")
 
+    if current_layout == "ncw":
+        if new_layout == "ncw":
+            return arr_np
+        if new_layout in ["ncw-32c64w-2d"]:
+            n, c, w = arr_np.shape
+            return arr_np.reshape([n, c // 32, 32, w // 64, 64]).transpose(0, 1, 3, 2, 4)
+
+        raise RuntimeError(f"Unexpected new_layout '{new_layout}'")
+
     raise RuntimeError(f"Unexpected current_layout '{current_layout}'")
 
 
diff --git a/tests/python/contrib/test_hexagon/topi/test_adaptive_avg_pool1d.py b/tests/python/contrib/test_hexagon/topi/test_adaptive_avg_pool1d.py
new file mode 100755
index 0000000000..4d4aef25e3
--- /dev/null
+++ b/tests/python/contrib/test_hexagon/topi/test_adaptive_avg_pool1d.py
@@ -0,0 +1,185 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""Test code for specialized case of adaptive_avg_pool1d."""
+
+import numpy as np
+
+import tvm
+from tvm import te
+from tvm.topi.testing import adaptive_pool
+import tvm.topi.hexagon.qnn as s1
+from tvm.contrib.hexagon import allocate_hexagon_array
+from ..infrastructure import transform_numpy, quantize_np
+
+
+SCALE_M_VAL = None  # output scale; set by the quantize_expected_output_np fixture
+ZERO_POINT_M_VAL = None  # output zero point; set by the quantize_expected_output_np fixture
+SCALE_VAL = None  # input scale; set by the quantize_input_np fixture
+ZERO_POINT_VAL = None  # input zero point; set by the quantize_input_np fixture
+
+
+class TestAdaptivePool1D:
+    """Test specialized case of adaptive_avg_pool1d."""
+
+    (input_shape,) = tvm.testing.parameters(  # NCW; C is a multiple of 32, W a multiple of 64
+        ([1, 128, 128],),
+        ([1, 64, 64],),
+        ([1, 64, 128],),
+        ([1, 32, 64],),
+        ([1, 128, 768],),
+    )
+
+    # Fixed chunk layout is set as ncw-32c64w-2d for now.
+    # The adaptive_avg_pool1d implementation only handles the specialized case
+    # where output_size is 1, as it appears in the quantized distilbert model.
+    # Since the output size won't be a multiple of the fixed chunk,
+    # output_layout is ncw.
+    # This might be changed later for optimization.
+    input_layout, output_layout, pool_type, layout, output_size, dtype, = tvm.testing.parameters(
+        (
+            "ncw-32c64w-2d",
+            "ncw",
+            "avg",
+            "NCW",
+            [1],
+            "uint8",
+        )
+    )
+
+    @tvm.testing.fixture
+    def expected_output_np(
+        self,
+        input_np,
+        output_size,
+        pool_type,
+        layout,
+    ):
+        """Generate the float reference output with topi.testing.adaptive_pool."""
+        out_width = output_size[0]
+
+        ref_np = adaptive_pool(
+            input_np,
+            out_width,
+            pool_type,
+            layout,
+        )
+        return ref_np
+
+    @tvm.testing.fixture
+    def input_np(self, input_shape, dtype):
+        if dtype in ("uint8", "int8"):  # quantized dtypes start from float32 data
+            dtype = "float32"
+        return np.random.random(input_shape).astype(dtype)
+
+    @tvm.testing.fixture
+    def quantize_input_np(self, input_np, dtype):  # also records SCALE_VAL / ZERO_POINT_VAL
+        if dtype in ("uint8", "int8"):
+            global ZERO_POINT_VAL, SCALE_VAL
+            input_np_quantized, SCALE_VAL, ZERO_POINT_VAL = quantize_np(input_np, dtype)
+            return input_np_quantized
+
+        raise RuntimeError(f"Unsupported data type '{dtype}'")
+
+    @tvm.testing.fixture
+    def transformed_input_np(self, quantize_input_np, input_layout, layout, dtype):
+        if dtype in ("uint8", "int8"):  # quantized input: layout.lower() -> input_layout
+            return transform_numpy(quantize_input_np, layout.lower(), input_layout)
+
+        raise RuntimeError(f"Unsupported data type '{dtype}'")
+
+    @tvm.testing.fixture
+    def quantize_expected_output_np(self, expected_output_np, dtype):
+        """Quantize the reference output and record SCALE_M_VAL / ZERO_POINT_M_VAL."""
+        if dtype in ("uint8", "int8"):
+            global ZERO_POINT_M_VAL, SCALE_M_VAL
+            out_ref_quantized, SCALE_M_VAL, ZERO_POINT_M_VAL = quantize_np(
+                expected_output_np, dtype
+            )
+
+            # Since output_layout is ncw, no transformation is needed.
+            return out_ref_quantized
+
+        raise RuntimeError(f"Unsupported data type '{dtype}'")
+
+    @tvm.testing.requires_hexagon
+    def test_pool1d(
+        self,
+        dtype,
+        output_size,
+        input_layout,
+        output_layout,
+        input_shape,
+        transformed_input_np,
+        quantize_expected_output_np,
+        hexagon_session,
+    ):
+        """Build and run adaptive_avg_pool1d, then compare with the quantized reference."""
+        target_hexagon = tvm.target.hexagon("v69")
+        a_tensor = te.placeholder(input_shape, name="a_tensor", dtype=dtype)
+
+        m_tensor = s1.adaptive_avg_pool1d(  # quantization params are the module-level globals
+            a_tensor,
+            output_size,
+            dtype,
+            ZERO_POINT_VAL,
+            SCALE_VAL,
+            ZERO_POINT_M_VAL,
+            SCALE_M_VAL,
+        )
+
+        tir_schedule = s1.tir_adaptive_avg_pool1d_schedule(
+            m_tensor, a_tensor, output_layout, input_layout
+        )
+
+        sch = tir_schedule.mod
+
+        with tvm.transform.PassContext(opt_level=3):
+            func = tvm.build(
+                sch,
+                [a_tensor, m_tensor],
+                tvm.target.Target(target_hexagon, host=target_hexagon),
+                name="adaptive_pool1d",
+            )
+
+        input_axis_separator = [3]  # index of te.AXIS_SEPARATOR in the ncw_32c64w_2d index map
+
+        a_data_nd = allocate_hexagon_array(
+            hexagon_session.device,
+            data=transformed_input_np,
+            dtype=dtype,
+            axis_separators=input_axis_separator,
+            mem_scope="global.vtcm",
+        )
+
+        m_data_nd = allocate_hexagon_array(
+            hexagon_session.device,
+            quantize_expected_output_np.shape,
+            dtype=dtype,
+        )
+
+        mod = hexagon_session.load_module(func)
+        mod(a_data_nd, m_data_nd)
+
+        # Convert nd to np
+        m_data_np = m_data_nd.numpy()
+
+        np.testing.assert_allclose(quantize_expected_output_np, m_data_np, atol=2)
+
+
+if __name__ == "__main__":  # script entry point
+    tvm.testing.main()