You are viewing a plain-text version of this content; the hyperlink to the canonical version was not preserved in this copy.
Posted to commits@tvm.apache.org by ma...@apache.org on 2022/07/28 05:00:55 UTC

[tvm] branch main updated: [Adreno] Fix winograd tests and accuracy (#12202)

This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 4bcaecf979 [Adreno] Fix winograd tests and accuracy (#12202)
4bcaecf979 is described below

commit 4bcaecf979fb17eaec4df80da534c8bc82933fb3
Author: Egor Churaev <eg...@gmail.com>
AuthorDate: Thu Jul 28 08:00:50 2022 +0300

    [Adreno] Fix winograd tests and accuracy (#12202)
    
    * [Adreno] Fix winograd tests and accuracy
    
    * Fix lint
    
    * Fix test on cpu
---
 python/tvm/topi/adreno/conv2d_winograd_common.py | 16 ++++---
 tests/python/relay/test_conv2d_nchw_texture.py   | 60 +++++++++++++++++++++++-
 tests/python/relay/utils/adreno_utils.py         | 25 ++++++++--
 3 files changed, 88 insertions(+), 13 deletions(-)

diff --git a/python/tvm/topi/adreno/conv2d_winograd_common.py b/python/tvm/topi/adreno/conv2d_winograd_common.py
index 6d11c1fe73..b0cec0f702 100644
--- a/python/tvm/topi/adreno/conv2d_winograd_common.py
+++ b/python/tvm/topi/adreno/conv2d_winograd_common.py
@@ -90,6 +90,7 @@ def conv2d_winograd_comp(
 
     convert_from4d = False
     if len(data.shape) == 4:
+        convert_from4d = True
         if layout == "NCHW":
             N, DCI, H, W = get_const_tuple(data.shape)
         else:
@@ -120,7 +121,6 @@ def conv2d_winograd_comp(
             data = tvm.te.placeholder(dshape, data.dtype, name="data_placeholder")
             kernel = tvm.te.placeholder(kshape, kernel.dtype, name="kernel_placeholder")
         else:
-            convert_from4d = True
             data = pack_input(
                 data, layout, N, in_channel_chunks, in_channel_block, in_channel_tail, H, W
             )
@@ -220,9 +220,9 @@ def conv2d_winograd_comp(
     idxdiv = tvm.tir.indexdiv
     idxmod = tvm.tir.indexmod
     if layout == "NCHW":
-        N, CI, H, W, CB = get_const_tuple(data.shape)
+        N, CI, _, _, CB = get_const_tuple(data.shape)
     else:
-        N, H, W, CI, CB = get_const_tuple(data.shape)
+        N, _, _, CI, CB = get_const_tuple(data.shape)
 
     # pack input tile
     if layout == "NCHW":
@@ -494,16 +494,18 @@ def schedule_conv2d_winograd(cfg, s, output, pre_computed):
         s[OL].set_scope("local")
         output = s.outputs[0]
 
-    m = alpha - 3 + 1
     if len(s[output].op.axis) == 4:
         n, co, h, w = s[output].op.axis
+        cb = None
     else:
-        n, co, h, w, _ = s[output].op.axis
-    ho, wo, hi, wi = s[output].tile(h, w, m, m)
+        n, co, h, w, cb = s[output].op.axis
     inverse_scope, n = s[output].split(n, nparts=1)
 
-    fused = s[output].fuse(n, co, ho, wo)
+    fused = s[output].fuse(n, co, h, w)
     bb, tt = s[output].split(fused, 128)
+    if cb is not None:
+        s[output].reorder(bb, tt, cb)
+        s[output].vectorize(cb)
 
     s[output].bind(bb, te.thread_axis("blockIdx.x"))
     s[output].bind(tt, te.thread_axis("threadIdx.x"))
diff --git a/tests/python/relay/test_conv2d_nchw_texture.py b/tests/python/relay/test_conv2d_nchw_texture.py
index 89f68dacbd..2dd88f6118 100644
--- a/tests/python/relay/test_conv2d_nchw_texture.py
+++ b/tests/python/relay/test_conv2d_nchw_texture.py
@@ -20,6 +20,7 @@ import tvm
 import numpy as np
 from tvm import relay
 from tvm.relay import testing
+from tvm.contrib import utils
 from utils.adreno_utils import gpu_preprocess, build_run_compare
 
 
@@ -432,6 +433,63 @@ def test_conv2d_vgg16_winograd_4d():
         "bias": tvm.nd.array(bias_data),
     }
 
-    graph = build_run_compare(mod, params1, {"data": input_shape}, dtype, target)
+    temp = utils.tempdir()
+    stat_file = temp.relpath("stat.log")
+    with open(stat_file, "w") as f:
+        f.write(
+            '{"input": ["opencl -keys=adreno,opencl,gpu -device=adreno -max_num_threads=256", "conv2d_nchw_winograd_acc32.image2d", [["TENSOR", [1, 512, 28, 28], "float16"], ["TENSOR", [512, 512, 3, 3], "float16"], [1, 1], [1, 1, 1, 1], [1, 1], "float16"], {}], "config": {"index": 1591, "code_hash": null, "entity": [["auto_unroll_max_step", "ot", 4], ["tile_y", "sp", [-1, 1, 32]], ["tile_x", "sp", [-1, 4, 2]], ["tile_rc", "sp", [-1, 8]]]}, "result": [[0.0037244], 0, 7.06374192237854, 165 [...]
+        )
+    graph = build_run_compare(
+        mod, params1, {"data": input_shape}, dtype, target, stat_file=stat_file
+    )
+    matches = re.findall("winograd", graph)
+    assert len(matches) > 0
+
+
+@tvm.testing.requires_opencl
+def test_conv2d_winograd_conv():
+    target = "opencl --device=adreno"
+    dtype = "float16"
+
+    input_shape = (1, 4, 3, 3)
+    A = relay.var("data", shape=input_shape, dtype=dtype)
+    filter_shape3 = (8, 4, 3, 3)
+    bias_shape3 = (8,)
+    B3 = relay.var("weight3", shape=filter_shape3, dtype=dtype)
+    D = relay.nn.conv2d(
+        A, B3, padding=[1, 1, 1, 1], channels=8, kernel_size=[3, 3], out_dtype=dtype
+    )
+
+    filter_shape4 = (8, 8, 3, 3)
+    bias_shape4 = (8,)
+    B4 = relay.var("weight4", shape=filter_shape4, dtype=dtype)
+    D = relay.nn.conv2d(
+        D, B4, padding=[1, 1, 1, 1], channels=8, kernel_size=[3, 3], out_dtype=dtype
+    )
+    mod = relay.Function([A, B3, B4], D)
+    np.random.seed(1)
+    initializer = relay.testing.init.Xavier()
+    filter_data3 = np.zeros(filter_shape3).astype(dtype)
+    bias_data3 = np.zeros(bias_shape3).astype(dtype)
+    filter_data4 = np.zeros(filter_shape4).astype(dtype)
+    bias_data4 = np.zeros(bias_shape4).astype(dtype)
+    initializer("weight", filter_data3)
+    initializer("bias", bias_data3)
+    initializer("weight", filter_data4)
+    initializer("bias", bias_data4)
+    params1 = {
+        "weight3": tvm.nd.array(filter_data3),
+        "weight4": tvm.nd.array(filter_data4),
+    }
+
+    temp = utils.tempdir()
+    stat_file = temp.relpath("stat.log")
+    with open(stat_file, "w") as f:
+        f.write(
+            '{"input": ["opencl -keys=adreno,opencl,gpu -device=adreno -max_num_threads=256", "conv2d_nchw_winograd_acc32.image2d", [["TENSOR", [1, 4, 3, 3], "float16"], ["TENSOR", [8, 4, 3, 3], "float16"], [1, 1], [1, 1, 1, 1], [1, 1], "float16"], {}], "config": {"index": 1591, "code_hash": null, "entity": [["auto_unroll_max_step", "ot", 4], ["tile_y", "sp", [-1, 1, 32]], ["tile_x", "sp", [-1, 4, 2]], ["tile_rc", "sp", [-1, 8]]]}, "result": [[0.0037244], 0, 7.06374192237854, 1653898629. [...]
+        )
+    graph = build_run_compare(
+        mod, params1, {"data": input_shape}, dtype, target, stat_file=stat_file
+    )
     matches = re.findall("winograd", graph)
     assert len(matches) > 0
diff --git a/tests/python/relay/utils/adreno_utils.py b/tests/python/relay/utils/adreno_utils.py
index 3bb4a6ada4..6e353b22cd 100644
--- a/tests/python/relay/utils/adreno_utils.py
+++ b/tests/python/relay/utils/adreno_utils.py
@@ -20,6 +20,7 @@ import os
 import tvm
 import numpy as np
 from tvm import relay
+from tvm import autotvm
 from tvm.relay import testing
 from tvm.relay.transform import recast
 from tvm.contrib import graph_runtime
@@ -45,7 +46,13 @@ def get_cpu_reference(mod, params1, input_shape, inputs):
 
 # build module run with opencl and cpu, compare results
 def build_run_compare(
-    tvm_mod, params1, input_shape, dtype="float32", target="llvm", gpu_preprocess=None
+    tvm_mod,
+    params1,
+    input_shape,
+    dtype="float32",
+    target="llvm",
+    gpu_preprocess=None,
+    stat_file=None,
 ):
 
     if "TVM_TRACKER_HOST" in os.environ and "TVM_TRACKER_PORT" in os.environ:
@@ -63,10 +70,18 @@ def build_run_compare(
     else:
         tvm_mod_nchwc = tvm_mod
 
-    with relay.build_config(opt_level=3):
-        graph, lib, params = relay.build(
-            tvm_mod_nchwc, target_host=target_host, target=target, params=params1
-        )
+    if stat_file is not None:
+        with autotvm.apply_history_best(stat_file):
+            with tvm.transform.PassContext(opt_level=3):
+                graph, lib, params = relay.build(
+                    tvm_mod_nchwc, target_host=target_host, target=target, params=params1
+                )
+    else:
+        with tvm.transform.PassContext(opt_level=3):
+            graph, lib, params = relay.build(
+                tvm_mod_nchwc, target_host=target_host, target=target, params=params1
+            )
+
     if run_on_host:
         ctx = tvm.opencl()
         m = graph_runtime.create(graph, lib, ctx)