Posted to commits@tvm.apache.org by tq...@apache.org on 2021/03/05 18:55:17 UTC
[tvm] branch main updated: Fix autotuning, broken in #7337 (#7566)
This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 7344b66 Fix autotuning, broken in #7337 (#7566)
7344b66 is described below
commit 7344b6666e76bb69fa1bf727c25074071fa522fb
Author: Andrew Reusch <ar...@octoml.ai>
AuthorDate: Fri Mar 5 10:55:05 2021 -0800
Fix autotuning, broken in #7337 (#7566)
* Fix autotuning, broken in #7337
* retrigger CI, because I don't understand how it passed
---
python/tvm/autotvm/measure/measure_methods.py | 10 +-
tests/python/integration/test_tuning.py | 244 ++++++++++++++++----------
2 files changed, 157 insertions(+), 97 deletions(-)
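For context, a minimal sketch of the failure mode this commit addresses (hypothetical helper and values, not TVM's actual internals): #7337 reworked the downstream measurement call, and a positional tuple of connection settings binds values by position, so any change to the callee's parameter list misassigns them silently, while a keyword dict keeps each value bound by name.

def connect_remote(device_key, host, port, priority=1, timeout=10):
    # Stand-in for an RPC connection helper; the parameter names mirror the diff below.
    return "%s@%s:%d (priority=%d, timeout=%d)" % (device_key, host, port, priority, timeout)

# Old style: a positional tuple breaks silently if the callee's signature changes.
remote_args = ("1080ti", "127.0.0.1", 9190, 1, 10)
print(connect_remote(*remote_args))

# New style, as in this commit: each value stays bound to its parameter by name.
remote_kwargs = dict(device_key="1080ti", host="127.0.0.1", port=9190, priority=1, timeout=10)
print(connect_remote(**remote_kwargs))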
diff --git a/python/tvm/autotvm/measure/measure_methods.py b/python/tvm/autotvm/measure/measure_methods.py
index 62fd811..b68767b 100644
--- a/python/tvm/autotvm/measure/measure_methods.py
+++ b/python/tvm/autotvm/measure/measure_methods.py
@@ -280,7 +280,13 @@ class RPCRunner(Runner):
def run(self, measure_inputs, build_results):
results = []
- remote_args = (self.key, self.host, self.port, self.priority, self.timeout)
+ remote_kwargs = dict(
+ device_key=self.key,
+ host=self.host,
+ port=self.port,
+ priority=self.priority,
+ timeout=self.timeout,
+ )
for i in range(0, len(measure_inputs), self.n_parallel):
futures = []
@@ -300,7 +306,7 @@ class RPCRunner(Runner):
self.repeat,
self.min_repeat_ms,
self.cooldown_interval,
- remote_args,
+ remote_kwargs,
self.enable_cpu_cache_flush,
module_loader,
)
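On the consuming side (an assumption based on the parameter names, not shown in this diff), the dict built in RPCRunner.run() can be unpacked directly into the RPC session helper, which is what makes the keyword form robust:

# Sketch only: request_remote() in tvm.autotvm.measure.measure_methods accepts
# device_key, host, port, priority and timeout by name, matching the keys of
# remote_kwargs above; the exact call site is assumed, not taken from this diff.
from tvm.autotvm.measure.measure_methods import request_remote

def open_session(remote_kwargs):
    # Unpacking by keyword keeps each value correctly bound even if the
    # callee gains new parameters later.
    return request_remote(**remote_kwargs)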
diff --git a/tests/python/integration/test_tuning.py b/tests/python/integration/test_tuning.py
index 64b2c16..813352c 100644
--- a/tests/python/integration/test_tuning.py
+++ b/tests/python/integration/test_tuning.py
@@ -18,9 +18,14 @@
Test the tuner
"""
import logging
+import sys
+import textwrap
import time
+import pytest
+
import tvm
+import tvm.relay
from tvm import te
from tvm import autotvm
@@ -29,94 +34,100 @@ from tvm.autotvm.tuner import RandomTuner
import tvm.testing
-@autotvm.template("testing/conv2d_no_batching")
-def conv2d_no_batching(N, H, W, CI, CO, KH, KW):
- """An example template for testing"""
- assert N == 1, "Only consider batch_size = 1 in this template"
-
- data = te.placeholder((N, CI, H, W), name="data")
- kernel = te.placeholder((CO, CI, KH, KW), name="kernel")
-
- rc = te.reduce_axis((0, CI), name="rc")
- ry = te.reduce_axis((0, KH), name="ry")
- rx = te.reduce_axis((0, KW), name="rx")
-
- conv = te.compute(
- (N, CO, H - KH + 1, W - KW + 1),
- lambda nn, ff, yy, xx: te.sum(
- data[nn, rc, yy + ry, xx + rx] * kernel[ff, rc, ry, rx], axis=[rc, ry, rx]
- ),
- tag="conv2d_nchw",
- )
-
- s = te.create_schedule([conv.op])
-
- output = conv
- OL = s.cache_write(conv, "local")
-
- # create cache stage
- AA = s.cache_read(data, "shared", [OL])
- WW = s.cache_read(kernel, "shared", [OL])
- AL = s.cache_read(AA, "local", [OL])
- WL = s.cache_read(WW, "local", [OL])
-
- # tile and bind spatial axes
- n, f, y, x = s[output].op.axis
- cfg = autotvm.get_config()
- cfg.define_split("tile_f", cfg.axis(f), num_outputs=4)
- cfg.define_split("tile_y", cfg.axis(y), num_outputs=4)
- cfg.define_split("tile_x", cfg.axis(x), num_outputs=4)
- bf, vf, tf, fi = cfg["tile_f"].apply(s, output, f)
- by, vy, ty, yi = cfg["tile_y"].apply(s, output, y)
- bx, vx, tx, xi = cfg["tile_x"].apply(s, output, x)
- kernel_scope = n # this is the scope to attach global config inside this kernel
-
- s[output].bind(bf, te.thread_axis("blockIdx.z"))
- s[output].bind(by, te.thread_axis("blockIdx.y"))
- s[output].bind(bx, te.thread_axis("blockIdx.x"))
- s[output].bind(vf, te.thread_axis("vthread"))
- s[output].bind(vy, te.thread_axis("vthread"))
- s[output].bind(vx, te.thread_axis("vthread"))
- s[output].bind(tf, te.thread_axis("threadIdx.z"))
- s[output].bind(ty, te.thread_axis("threadIdx.y"))
- s[output].bind(tx, te.thread_axis("threadIdx.x"))
- s[output].reorder(n, bf, by, bx, vf, vy, vx, tf, ty, tx, fi, yi, xi)
- s[OL].compute_at(s[output], tx)
-
- # tile and bind reduction axes
- n, f, y, x = s[OL].op.axis
- rc, ry, rx = s[OL].op.reduce_axis
- cfg.define_split("tile_rc", cfg.axis(rc), num_outputs=3)
- cfg.define_split("tile_ry", cfg.axis(ry), num_outputs=3)
- cfg.define_split("tile_rx", cfg.axis(rx), num_outputs=3)
- rco, rcm, rci = cfg["tile_rc"].apply(s, OL, rc)
- ryo, rym, ryi = cfg["tile_rx"].apply(s, OL, ry)
- rxo, rxm, rxi = cfg["tile_ry"].apply(s, OL, rx)
- s[OL].reorder(rco, ryo, rxo, rcm, rym, rxm, rci, ryi, rxi, n, f, y, x)
-
- s[AA].compute_at(s[OL], rxo)
- s[WW].compute_at(s[OL], rxo)
- s[AL].compute_at(s[OL], rxm)
- s[WL].compute_at(s[OL], rxm)
-
- # cooperative fetching
- for load in [AA, WW]:
- n, f, y, x = s[load].op.axis
- fused = s[load].fuse(n, f, y, x)
- tz, fused = s[load].split(fused, nparts=cfg["tile_f"].size[2])
- ty, fused = s[load].split(fused, nparts=cfg["tile_y"].size[2])
- tx, fused = s[load].split(fused, nparts=cfg["tile_x"].size[2])
- s[load].bind(tz, te.thread_axis("threadIdx.z"))
- s[load].bind(ty, te.thread_axis("threadIdx.y"))
- s[load].bind(tx, te.thread_axis("threadIdx.x"))
-
- # tune unroll
- cfg.define_knob("auto_unroll_max_step", [0, 512, 1500])
- cfg.define_knob("unroll_explicit", [0, 1])
- s[output].pragma(kernel_scope, "auto_unroll_max_step", cfg["auto_unroll_max_step"].val)
- s[output].pragma(kernel_scope, "unroll_explicit", cfg["unroll_explicit"].val)
-
- return s, [data, kernel, conv]
+def setup_module():
+ @autotvm.template("testing/conv2d_no_batching")
+ def conv2d_no_batching(N, H, W, CI, CO, KH, KW):
+ """An example template for testing"""
+ assert N == 1, "Only consider batch_size = 1 in this template"
+
+ data = te.placeholder((N, CI, H, W), name="data")
+ kernel = te.placeholder((CO, CI, KH, KW), name="kernel")
+
+ rc = te.reduce_axis((0, CI), name="rc")
+ ry = te.reduce_axis((0, KH), name="ry")
+ rx = te.reduce_axis((0, KW), name="rx")
+
+ conv = te.compute(
+ (N, CO, H - KH + 1, W - KW + 1),
+ lambda nn, ff, yy, xx: te.sum(
+ data[nn, rc, yy + ry, xx + rx] * kernel[ff, rc, ry, rx], axis=[rc, ry, rx]
+ ),
+ tag="conv2d_nchw",
+ )
+
+ s = te.create_schedule([conv.op])
+
+ output = conv
+ OL = s.cache_write(conv, "local")
+
+ # create cache stage
+ AA = s.cache_read(data, "shared", [OL])
+ WW = s.cache_read(kernel, "shared", [OL])
+ AL = s.cache_read(AA, "local", [OL])
+ WL = s.cache_read(WW, "local", [OL])
+
+ # tile and bind spatial axes
+ n, f, y, x = s[output].op.axis
+ cfg = autotvm.get_config()
+ cfg.define_split("tile_f", cfg.axis(f), num_outputs=4)
+ cfg.define_split("tile_y", cfg.axis(y), num_outputs=4)
+ cfg.define_split("tile_x", cfg.axis(x), num_outputs=4)
+ bf, vf, tf, fi = cfg["tile_f"].apply(s, output, f)
+ by, vy, ty, yi = cfg["tile_y"].apply(s, output, y)
+ bx, vx, tx, xi = cfg["tile_x"].apply(s, output, x)
+ kernel_scope = n # this is the scope to attach global config inside this kernel
+
+ s[output].bind(bf, te.thread_axis("blockIdx.z"))
+ s[output].bind(by, te.thread_axis("blockIdx.y"))
+ s[output].bind(bx, te.thread_axis("blockIdx.x"))
+ s[output].bind(vf, te.thread_axis("vthread"))
+ s[output].bind(vy, te.thread_axis("vthread"))
+ s[output].bind(vx, te.thread_axis("vthread"))
+ s[output].bind(tf, te.thread_axis("threadIdx.z"))
+ s[output].bind(ty, te.thread_axis("threadIdx.y"))
+ s[output].bind(tx, te.thread_axis("threadIdx.x"))
+ s[output].reorder(n, bf, by, bx, vf, vy, vx, tf, ty, tx, fi, yi, xi)
+ s[OL].compute_at(s[output], tx)
+
+ # tile and bind reduction axes
+ n, f, y, x = s[OL].op.axis
+ rc, ry, rx = s[OL].op.reduce_axis
+ cfg.define_split("tile_rc", cfg.axis(rc), num_outputs=3)
+ cfg.define_split("tile_ry", cfg.axis(ry), num_outputs=3)
+ cfg.define_split("tile_rx", cfg.axis(rx), num_outputs=3)
+ rco, rcm, rci = cfg["tile_rc"].apply(s, OL, rc)
+ ryo, rym, ryi = cfg["tile_rx"].apply(s, OL, ry)
+ rxo, rxm, rxi = cfg["tile_ry"].apply(s, OL, rx)
+ s[OL].reorder(rco, ryo, rxo, rcm, rym, rxm, rci, ryi, rxi, n, f, y, x)
+
+ s[AA].compute_at(s[OL], rxo)
+ s[WW].compute_at(s[OL], rxo)
+ s[AL].compute_at(s[OL], rxm)
+ s[WL].compute_at(s[OL], rxm)
+
+ # cooperative fetching
+ for load in [AA, WW]:
+ n, f, y, x = s[load].op.axis
+ fused = s[load].fuse(n, f, y, x)
+ tz, fused = s[load].split(fused, nparts=cfg["tile_f"].size[2])
+ ty, fused = s[load].split(fused, nparts=cfg["tile_y"].size[2])
+ tx, fused = s[load].split(fused, nparts=cfg["tile_x"].size[2])
+ s[load].bind(tz, te.thread_axis("threadIdx.z"))
+ s[load].bind(ty, te.thread_axis("threadIdx.y"))
+ s[load].bind(tx, te.thread_axis("threadIdx.x"))
+
+ # tune unroll
+ cfg.define_knob("auto_unroll_max_step", [0, 512, 1500])
+ cfg.define_knob("unroll_explicit", [0, 1])
+ s[output].pragma(kernel_scope, "auto_unroll_max_step", cfg["auto_unroll_max_step"].val)
+ s[output].pragma(kernel_scope, "unroll_explicit", cfg["unroll_explicit"].val)
+
+ return s, [data, kernel, conv]
+
+
+def teardown_module():
+ # TODO(areusch): Tasks should not be registered into a global.
+ del autotvm.task.task.TASK_TABLE["testing/conv2d_no_batching"]
def get_sample_task(target=tvm.target.cuda(), target_host=None):
@@ -131,19 +142,62 @@ def get_sample_task(target=tvm.target.cuda(), target_host=None):
@tvm.testing.parametrize_targets("cuda", "opencl")
-def test_tuning(target, ctx):
+def test_tuning_gpu(target, ctx):
# init task
task, target = get_sample_task(target, None)
- logging.info("%s", task.config_space)
+ logging.info("task config space: %s", task.config_space)
measure_option = autotvm.measure_option(autotvm.LocalBuilder(), autotvm.LocalRunner())
+ results = []
+
tuner = RandomTuner(task)
- tuner.tune(n_trial=20, measure_option=measure_option)
+ tuner.tune(
+ n_trial=20,
+ measure_option=measure_option,
+ callbacks=(lambda _tuner, _inputs, rs: results.extend(rs),),
+ )
+ assert len(results) == 20
-if __name__ == "__main__":
- # only print log when invoked from main
- logging.basicConfig(level=logging.DEBUG)
+ successful_results = [r for r in results if r.error_no == autotvm.MeasureErrorNo.NO_ERROR]
+ assert len(successful_results) > 0, f"No successful tuning runs: {results!r}"
+
+
+def test_tuning_cpu():
+ ir_mod = tvm.parser.fromtext(
+ textwrap.dedent(
+ """
+ #[version = "0.0.5"]
+ def @main(%a : Tensor[(1, 3, 32, 32), float32], %b : Tensor[(3, 3, 5, 5), float32]) {
+ nn.conv2d(%a, %b, data_layout="NCHW", kernel_layout="OIHW")
+ }
+ """
+ )
+ )
+ tasks = autotvm.task.relay_integration.extract_from_program(
+ ir_mod, {}, tvm.target.create("llvm")
+ )
+ assert len(tasks) == 1, f"Extracted != 1 task from program: {tasks!r}"
+
+ task = tasks[0]
+
+ measure_option = autotvm.measure_option(autotvm.LocalBuilder(), autotvm.LocalRunner())
+
+ results = []
+
+ tuner = RandomTuner(task)
+ tuner.tune(
+ n_trial=20,
+ measure_option=measure_option,
+ callbacks=(lambda _tuner, _inputs, rs: results.extend(rs),),
+ )
+
+ assert len(results) == 20
- test_tuning()
+ successful_results = [r for r in results if r.error_no == autotvm.MeasureErrorNo.NO_ERROR]
+ assert len(successful_results) > 0, f"No successful tuning runs: {results!r}"
+
+
+if __name__ == "__main__":
+ sys.exit(pytest.main([__file__] + sys.argv[1:]))
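The Relay text program in test_tuning_cpu above can also be built with the Relay Python API; a rough equivalent (assuming the standard relay.var / relay.nn.conv2d / IRModule.from_expr APIs of this TVM version) looks like:

import tvm
from tvm import relay

def build_conv2d_module():
    # Same shapes and layouts as the parsed text module in the test.
    a = relay.var("a", shape=(1, 3, 32, 32), dtype="float32")
    b = relay.var("b", shape=(3, 3, 5, 5), dtype="float32")
    out = relay.nn.conv2d(a, b, data_layout="NCHW", kernel_layout="OIHW")
    return tvm.IRModule.from_expr(relay.Function([a, b], out))

Either form can then be handed to autotvm.task.relay_integration.extract_from_program as in the test.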