You are viewing a plain-text version of this content. The canonical link for it was not preserved in this plain-text copy.
Posted to commits@tvm.apache.org by ma...@apache.org on 2022/07/05 04:04:08 UTC

[tvm] branch main updated: [Adreno] Modify default AutoTVM params for conv2d (#12005)

This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new ef08c36294 [Adreno] Modify default AutoTVM params for conv2d (#12005)
ef08c36294 is described below

commit ef08c36294dc7c90f9d4536948507eca515012bd
Author: Andrey Malyshev <el...@gmail.com>
AuthorDate: Tue Jul 5 07:04:02 2022 +0300

    [Adreno] Modify default AutoTVM params for conv2d (#12005)
---
 python/tvm/topi/adreno/conv2d_nchw.py           |  3 ++
 python/tvm/topi/adreno/conv2d_nhwc.py           |  5 +++
 python/tvm/topi/adreno/depthwise_conv2d_nchw.py | 16 ++++++---
 python/tvm/topi/adreno/depthwise_conv2d_nhwc.py | 15 ++++++---
 python/tvm/topi/adreno/utils.py                 | 44 +++++++++++++++++++++++++
 5 files changed, 73 insertions(+), 10 deletions(-)

diff --git a/python/tvm/topi/adreno/conv2d_nchw.py b/python/tvm/topi/adreno/conv2d_nchw.py
index 96368b3e57..2a8f6028b7 100644
--- a/python/tvm/topi/adreno/conv2d_nchw.py
+++ b/python/tvm/topi/adreno/conv2d_nchw.py
@@ -28,6 +28,7 @@ from .utils import (
     expand_spatial_dimensions,
     add_pad,
     bind_data_copy,
+    get_default_conv2d_config,
 )
 
 
@@ -264,6 +265,8 @@ def schedule_conv2d_NCHWc_KCRSk(cfg, s, output):
     cfg.define_knob("auto_unroll_max_step", [0, 512, 1500])
     cfg.define_knob("unroll_explicit", [0, 1])
 
+    if cfg.is_fallback:
+        get_default_conv2d_config(cfg, conv.shape[1], conv.shape[2], conv.shape[3])
     ##### space definition end #####
 
     pad_data, kernel = s[conv].op.input_tensors
diff --git a/python/tvm/topi/adreno/conv2d_nhwc.py b/python/tvm/topi/adreno/conv2d_nhwc.py
index d40f813fdb..388f606ecb 100644
--- a/python/tvm/topi/adreno/conv2d_nhwc.py
+++ b/python/tvm/topi/adreno/conv2d_nhwc.py
@@ -29,6 +29,7 @@ from .utils import (
     add_pad,
     bind_data_copy,
     get_texture_storage,
+    get_default_conv2d_config,
 )
 
 
@@ -261,6 +262,10 @@ def schedule_conv2d_NHWC(cfg, s, output):
     cfg.define_knob("auto_unroll_max_step", [0, 512, 1500])
     cfg.define_knob("unroll_explicit", [0, 1])
 
+    if cfg.is_fallback:
+        get_default_conv2d_config(cfg, conv.shape[3], conv.shape[1], conv.shape[2])
+    ##### space definition end #####
+
     pad_data, kernel = s[conv].op.input_tensors
     if (
         isinstance(kernel.op, tvm.te.ComputeOp) and "filter_pack" in kernel.op.tag
diff --git a/python/tvm/topi/adreno/depthwise_conv2d_nchw.py b/python/tvm/topi/adreno/depthwise_conv2d_nchw.py
index 298bd11e00..a11c3f3d36 100644
--- a/python/tvm/topi/adreno/depthwise_conv2d_nchw.py
+++ b/python/tvm/topi/adreno/depthwise_conv2d_nchw.py
@@ -28,6 +28,8 @@ from .utils import (
     expand_spatial_dimensions,
     add_pad,
     bind_data_copy,
+    get_texture_storage,
+    get_default_conv2d_config,
 )
 
 
@@ -240,6 +242,9 @@ def schedule_depthwise_conv2d_NCHWc_KCRSk(cfg, s, output):
     cfg.define_split("tile_rx", rx, num_outputs=2)
     cfg.define_knob("auto_unroll_max_step", [0, 512, 1500])
     cfg.define_knob("unroll_explicit", [0, 1])
+
+    if cfg.is_fallback:
+        get_default_conv2d_config(cfg, conv.shape[1], conv.shape[2], conv.shape[3])
     ##### space definition end #####
 
     pad_data, kernel = s[conv].op.input_tensors
@@ -260,11 +265,12 @@ def schedule_depthwise_conv2d_NCHWc_KCRSk(cfg, s, output):
     if latest_blocked == latest and output != latest:
         s[output].compute_inline()
 
-    # create cache stage
-    AT = s.cache_read(pad_data, "global.texture", [conv])
-    WT = s.cache_read(kernel, "global.texture-weight", [conv])
-    bind_data_copy(s[AT])
-    bind_data_copy(s[WT])
+    if autotvm.GLOBAL_SCOPE.in_tuning or len(latest.op.axis) == 4:
+        # create cache stage for tuning only or in case of 4d case
+        AT = s.cache_read(pad_data, get_texture_storage(pad_data.shape), [conv])
+        bind_data_copy(s[AT])
+        WT = s.cache_read(kernel, get_texture_storage(kernel.shape), [conv])
+        bind_data_copy(s[WT])
 
     # tile and bind spatial axes
     n, fc, y, x, fb = s[latest_blocked].op.axis
diff --git a/python/tvm/topi/adreno/depthwise_conv2d_nhwc.py b/python/tvm/topi/adreno/depthwise_conv2d_nhwc.py
index b8a978d3c2..117daf825d 100644
--- a/python/tvm/topi/adreno/depthwise_conv2d_nhwc.py
+++ b/python/tvm/topi/adreno/depthwise_conv2d_nhwc.py
@@ -29,6 +29,7 @@ from .utils import (
     add_pad,
     bind_data_copy,
     get_texture_storage,
+    get_default_conv2d_config,
 )
 
 
@@ -235,6 +236,9 @@ def schedule_depthwise_conv2d_NHWC_HWOI(cfg, s, output):
     cfg.define_split("tile_rx", rx, num_outputs=2)
     cfg.define_knob("auto_unroll_max_step", [0, 512, 1500])
     cfg.define_knob("unroll_explicit", [0, 1])
+
+    if cfg.is_fallback:
+        get_default_conv2d_config(cfg, conv.shape[3], conv.shape[1], conv.shape[2])
     ##### space definition end #####
 
     pad_data, kernel = s[conv].op.input_tensors
@@ -255,11 +259,12 @@ def schedule_depthwise_conv2d_NHWC_HWOI(cfg, s, output):
     if latest_blocked == latest and output != latest:
         s[output].compute_inline()
 
-    # create cache stage
-    AT = s.cache_read(pad_data, get_texture_storage(pad_data.shape), [conv])
-    WT = s.cache_read(kernel, get_texture_storage(kernel.shape), [conv])
-    bind_data_copy(s[AT])
-    bind_data_copy(s[WT])
+    if autotvm.GLOBAL_SCOPE.in_tuning or len(latest.op.axis) == 4:
+        # create cache stage for tuning only or in case of 4d case
+        AT = s.cache_read(pad_data, get_texture_storage(pad_data.shape), [conv])
+        bind_data_copy(s[AT])
+        WT = s.cache_read(kernel, get_texture_storage(kernel.shape), [conv])
+        bind_data_copy(s[WT])
 
     # tile and bind spatial axes
     n, y, x, fc, fb = s[latest_blocked].op.axis
diff --git a/python/tvm/topi/adreno/utils.py b/python/tvm/topi/adreno/utils.py
index 78a992e56a..ea19e7d77d 100644
--- a/python/tvm/topi/adreno/utils.py
+++ b/python/tvm/topi/adreno/utils.py
@@ -22,6 +22,7 @@ import numpy
 from tvm import te
 from tvm.topi.utils import simplify
 from tvm.topi import nn
+from tvm.autotvm.task.space import SplitEntity
 from ..utils import get_const_tuple
 
 
@@ -575,3 +576,46 @@ def infer_tile_size(data, layout):
     if H % 8 == 0:
         return 4
     return 2
+
+
+def get_default_conv2d_config(cfg, fc, y, x):
+    """Defines conv2d default parameters for split axis for Adreno conv2d and depthwise conv2d"""
+    # look for vthread params:
+    vy = 1
+    for n in range(5, 0, -1):
+        if y % n == 0:
+            vy = n
+            break
+
+    vx = 1
+    for n in range(5, 0, -1):
+        if x % n == 0 and vy * n < 9:
+            vx = n
+            break
+
+    y = y // vy
+    x = x // vx
+
+    tfc = 1
+    for n in range(64, 0, -1):
+        if fc % n == 0:
+            tfc = n
+            break
+    ty = 1
+    for n in range(16, 0, -1):
+        if y % n == 0 and tfc * n <= 512:
+            ty = n
+            break
+    tx = 1
+    for n in range(16, 0, -1):
+        if x % n == 0 and tfc * ty * n <= 512:
+            tx = n
+            break
+
+    fc = fc // tfc
+    y = y // ty
+    x = x // tx
+
+    cfg["tile_fc"] = SplitEntity([fc, 1, tfc])
+    cfg["tile_y"] = SplitEntity([y, vy, ty])
+    cfg["tile_x"] = SplitEntity([x, vx, tx])