You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by ec...@apache.org on 2022/12/08 10:29:12 UTC

[tvm] branch main updated: [Adreno] Add global pooling schedule (#13573)

This is an automated email from the ASF dual-hosted git repository.

echuraev pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new f4cfcafba5 [Adreno] Add global pooling schedule (#13573)
f4cfcafba5 is described below

commit f4cfcafba5024d9bbef1b8bf422c6a25368837f3
Author: Andrey Malyshev <el...@gmail.com>
AuthorDate: Thu Dec 8 12:29:00 2022 +0200

    [Adreno] Add global pooling schedule (#13573)
    
    * [Adreno] Add global pooling schedule
    
    The parallelism opportunities in the case of global pooling are
    limited by the number of channels, so the schedule needs to be changed
    to parallelize over the reduction axis (using rfactor).
    
    * address pylint hits
    
    * address PR comments
    
    * switch spatial axis to blk binding
---
 python/tvm/relay/op/strategy/adreno.py             |   7 ++
 python/tvm/topi/adreno/pooling.py                  | 107 ++++++++++++++++
 .../relay/opencl_texture/test_pool_texture.py      | 135 +++++++++++++++++++++
 3 files changed, 249 insertions(+)

diff --git a/python/tvm/relay/op/strategy/adreno.py b/python/tvm/relay/op/strategy/adreno.py
index 21252215fc..b606ab05d7 100644
--- a/python/tvm/relay/op/strategy/adreno.py
+++ b/python/tvm/relay/op/strategy/adreno.py
@@ -215,6 +215,13 @@ def schedule_reduce_adreno(attrs, outs, target):
         return topi.adreno.schedule_reduce(outs)
 
 
+@schedule_adaptive_pool.register(["adreno"])
+def schedule_adaptive_pool_adreno(attrs, outs, target):
+    """schedule adaptive pooling ops for adreno"""
+    with target:
+        return topi.adreno.schedule_adaptive_pool(outs, attrs.layout)
+
+
 @concatenate_strategy.register(["adreno"])
 def concatenate_strategy_adreno(attrs, inputs, out_type, target):
     strategy = _op.OpStrategy()
diff --git a/python/tvm/topi/adreno/pooling.py b/python/tvm/topi/adreno/pooling.py
index 49f103c04a..f02af0c01f 100644
--- a/python/tvm/topi/adreno/pooling.py
+++ b/python/tvm/topi/adreno/pooling.py
@@ -19,6 +19,113 @@
 import tvm
 from tvm import te
 from .. import tag
+from .utils import get_div
+
+
+def schedule_adaptive_pool(outs, layout="NCHW"):
+    """Schedule for adaptive_pool.
+
+    Parameters
+    ----------
+    outs: Array of Tensor
+        The computation graph description of adaptive_pool
+        in the format of an array of tensors.
+
+    Returns
+    -------
+    s: Schedule
+        The computation schedule for adaptive_pool.
+    """
+    outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs
+    s = te.create_schedule([x.op for x in outs])
+
+    def _schedule_global(Pool, layout):
+        # examples of latest pool op is global max pool and non latest is global avg pooling
+        # OL - an Expr will be used for rfactor
+        # Out - programming of the parallelizm on the global level
+        # shared is not required, local could be enough but shared scope gives quite significant
+        # perf boost
+        if Pool.op in s.outputs:
+            Out = Pool
+            OL = s.cache_write(Pool, "shared")
+        else:
+            Out = outs[0].op.output(0)
+            s[Pool].set_scope("shared")
+            OL = Pool
+
+        PaddedInput = Pool.op.input_tensors[0]
+
+        # detect axis for later reorder and binding of batch/channel to blocks and
+        # spatial to threads
+        if layout in ("NCHW", "NCHW4c"):
+            channel_index = 1
+            height_index = 2
+            width_index = 3
+        else:
+            channel_index = 3
+            height_index = 1
+            width_index = 2
+
+        if isinstance(PaddedInput.op, tvm.te.ComputeOp):
+            s[PaddedInput].compute_inline()
+
+        fused_reduce = s[OL].fuse(*s[OL].op.reduce_axis)
+
+        spatial = PaddedInput.shape[height_index].value * PaddedInput.shape[width_index].value
+        # below values were selected empirically assuming that we should have some work in each
+        # thread (currently from 25-49) and number of threads not exceeding some threshold that
+        # was selected as 256 from performance point of view after experiments on Adreno 660
+        max_threads = spatial // 25 if spatial > 25 else 1
+        max_threads = 256 if max_threads > 256 else max_threads
+        num_thread = get_div(spatial, max_threads)
+
+        thread_y = te.thread_axis((0, num_thread), "threadIdx.y")
+
+        _, ki = s[OL].split(fused_reduce, factor=num_thread)
+        data_out_rf = s.rfactor(OL, ki)
+        s[data_out_rf].compute_at(s[OL], s[OL].op.reduce_axis[0])
+        s[OL].bind(s[OL].op.reduce_axis[0], thread_y)
+
+        naxis = s[Out].op.axis[0]
+        caxis = s[Out].op.axis[channel_index]
+        haxis = s[Out].op.axis[height_index]
+        waxis = s[Out].op.axis[width_index]
+
+        if layout in ("NHWC4c", "NCHW4c"):
+            texture_axis = s[Out].op.axis[-1]
+            s[Out].reorder(naxis, caxis, haxis, waxis, texture_axis)
+            s[Out].vectorize(texture_axis)
+        else:
+            texture_axis = None
+            s[Out].reorder(naxis, caxis, haxis, waxis)
+
+        bx = s[Out].fuse(naxis, caxis, haxis, waxis)
+        s[Out].bind(bx, te.thread_axis("blockIdx.x"))
+
+        s[OL].compute_at(s[Out], bx)
+
+    scheduled_ops = []
+
+    def traverse(OP):
+        """Internal traverse function"""
+        # inline all one-to-one-mapping operators except the last stage (output)
+        if tag.is_injective(OP.tag):
+            if OP not in s.outputs:
+                s[OP].compute_inline()
+            for tensor in OP.input_tensors:
+                if isinstance(tensor.op, te.tensor.ComputeOp) and tensor.op not in scheduled_ops:
+                    traverse(tensor.op)
+        # schedule global_pool
+        elif OP.tag.startswith("adaptive_pool"):
+            Pool = OP.output(0)
+            _schedule_global(Pool, layout)
+        else:
+            raise RuntimeError("Unsupported operator: %s" % OP.tag)
+
+        scheduled_ops.append(OP)
+
+    traverse(outs[0].op)
+    return s
 
 
 def schedule_pool(outs, layout):
diff --git a/tests/python/relay/opencl_texture/test_pool_texture.py b/tests/python/relay/opencl_texture/test_pool_texture.py
new file mode 100644
index 0000000000..faeb121c80
--- /dev/null
+++ b/tests/python/relay/opencl_texture/test_pool_texture.py
@@ -0,0 +1,135 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import tvm
+from tvm import relay
+from utils.adreno_utils import build_run_compare
+
+
+dtype = tvm.testing.parameter("float32")
+
+
+@tvm.testing.requires_opencl
+@tvm.testing.parametrize_targets("opencl -device=adreno")
+def test_global_pool2d_nchw_wide(remote, target, dtype):
+    """
+    Use case of NCHW global pooling with big spatial valies
+    """
+    input_shape = (1, 32, 160, 160)
+    A = relay.var("data", shape=input_shape, dtype=dtype)
+    C = relay.nn.global_avg_pool2d(A)
+    mod = relay.Function([A], C)
+
+    build_run_compare(remote, mod, {}, {"data": input_shape}, {"data": dtype}, target)
+
+
+@tvm.testing.requires_opencl
+@tvm.testing.parametrize_targets("opencl -device=adreno")
+def test_global_pool2d_nchw4c_wide(remote, target, dtype):
+    """
+    Use case of blocked NCHW4c global pooling with big spatial valies
+    """
+    input_shape = (1, 8, 160, 160, 4)
+    A = relay.var("data", shape=input_shape, dtype=dtype)
+    C = relay.nn.global_avg_pool2d(A, layout="NCHW4c")
+    mod = relay.Function([A], C)
+
+    build_run_compare(remote, mod, {}, {"data": input_shape}, {"data": dtype}, target)
+
+
+@tvm.testing.requires_opencl
+@tvm.testing.parametrize_targets("opencl -device=adreno")
+def test_global_pool2d_nchw_deep(remote, target, dtype):
+    """
+    Use case of NCHW deep global pooling
+    """
+    input_shape = (1, 2048, 20, 20)
+    A = relay.var("data", shape=input_shape, dtype=dtype)
+    C = relay.nn.global_avg_pool2d(A)
+    mod = relay.Function([A], C)
+
+    build_run_compare(remote, mod, {}, {"data": input_shape}, {"data": dtype}, target)
+
+
+@tvm.testing.requires_opencl
+@tvm.testing.parametrize_targets("opencl -device=adreno")
+def test_global_pool2d_nchw4c_deep(remote, target, dtype):
+    """
+    Use case of blocked NCHW4c deep global pooling
+    """
+    input_shape = (1, 512, 20, 20, 4)
+    A = relay.var("data", shape=input_shape, dtype=dtype)
+    C = relay.nn.global_avg_pool2d(A, layout="NCHW4c")
+    mod = relay.Function([A], C)
+
+    build_run_compare(remote, mod, {}, {"data": input_shape}, {"data": dtype}, target)
+
+
+@tvm.testing.requires_opencl
+@tvm.testing.parametrize_targets("opencl -device=adreno")
+def test_global_pool2d_nhwc(remote, target, dtype):
+    """
+    Use case of NHWC global pooling with big spatial valies
+    """
+    input_shape = (1, 160, 160, 32)
+    A = relay.var("data", shape=input_shape, dtype=dtype)
+    C = relay.nn.global_avg_pool2d(A, layout="NHWC")
+    mod = relay.Function([A], C)
+
+    build_run_compare(remote, mod, {}, {"data": input_shape}, {"data": dtype}, target)
+
+
+@tvm.testing.requires_opencl
+@tvm.testing.parametrize_targets("opencl -device=adreno")
+def test_global_pool2d_nhwc4c(remote, target, dtype):
+    """
+    Use case of NHWC deep global pooling
+    """
+    input_shape = (1, 160, 160, 8, 4)
+    A = relay.var("data", shape=input_shape, dtype=dtype)
+    C = relay.nn.global_avg_pool2d(A, layout="NHWC4c")
+    mod = relay.Function([A], C)
+
+    build_run_compare(remote, mod, {}, {"data": input_shape}, {"data": dtype}, target)
+
+
+@tvm.testing.requires_opencl
+@tvm.testing.parametrize_targets("opencl -device=adreno")
+def test_global_max_pool2d_nchw_wide(remote, target, dtype):
+    """
+    Use case of NCHW global pooling with big spatial valies
+    """
+    input_shape = (1, 32, 160, 160)
+    A = relay.var("data", shape=input_shape, dtype=dtype)
+    C = relay.nn.global_max_pool2d(A)
+    mod = relay.Function([A], C)
+
+    build_run_compare(remote, mod, {}, {"data": input_shape}, {"data": dtype}, target)
+
+
+@tvm.testing.requires_opencl
+@tvm.testing.parametrize_targets("opencl -device=adreno")
+def test_global_max_pool2d_nchw4c_wide(remote, target, dtype):
+    """
+    Use case of blocked NCHW4c global pooling with big spatial valies
+    """
+    input_shape = (1, 8, 160, 160, 4)
+    A = relay.var("data", shape=input_shape, dtype=dtype)
+    C = relay.nn.global_max_pool2d(A, layout="NCHW4c")
+    mod = relay.Function([A], C)
+
+    build_run_compare(remote, mod, {}, {"data": input_shape}, {"data": dtype}, target)