You are viewing a plain-text version of this content; the hyperlink to its canonical version is not preserved in this rendering.
Posted to commits@tvm.apache.org by tq...@apache.org on 2021/03/05 18:55:17 UTC

[tvm] branch main updated: Fix autotuning, broken in #7337 (#7566)

This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 7344b66  Fix autotuning, broken in #7337 (#7566)
7344b66 is described below

commit 7344b6666e76bb69fa1bf727c25074071fa522fb
Author: Andrew Reusch <ar...@octoml.ai>
AuthorDate: Fri Mar 5 10:55:05 2021 -0800

    Fix autotuning, broken in #7337 (#7566)
    
    * Fix autotuning, broken in #7337
    
    * retrigger CI, because I don't understand how it passed
---
 python/tvm/autotvm/measure/measure_methods.py |  10 +-
 tests/python/integration/test_tuning.py       | 244 ++++++++++++++++----------
 2 files changed, 157 insertions(+), 97 deletions(-)

diff --git a/python/tvm/autotvm/measure/measure_methods.py b/python/tvm/autotvm/measure/measure_methods.py
index 62fd811..b68767b 100644
--- a/python/tvm/autotvm/measure/measure_methods.py
+++ b/python/tvm/autotvm/measure/measure_methods.py
@@ -280,7 +280,13 @@ class RPCRunner(Runner):
 
     def run(self, measure_inputs, build_results):
         results = []
-        remote_args = (self.key, self.host, self.port, self.priority, self.timeout)
+        remote_kwargs = dict(
+            device_key=self.key,
+            host=self.host,
+            port=self.port,
+            priority=self.priority,
+            timeout=self.timeout,
+        )
 
         for i in range(0, len(measure_inputs), self.n_parallel):
             futures = []
@@ -300,7 +306,7 @@ class RPCRunner(Runner):
                     self.repeat,
                     self.min_repeat_ms,
                     self.cooldown_interval,
-                    remote_args,
+                    remote_kwargs,
                     self.enable_cpu_cache_flush,
                     module_loader,
                 )
diff --git a/tests/python/integration/test_tuning.py b/tests/python/integration/test_tuning.py
index 64b2c16..813352c 100644
--- a/tests/python/integration/test_tuning.py
+++ b/tests/python/integration/test_tuning.py
@@ -18,9 +18,14 @@
 Test the tuner
 """
 import logging
+import sys
+import textwrap
 import time
 
+import pytest
+
 import tvm
+import tvm.relay
 from tvm import te
 
 from tvm import autotvm
@@ -29,94 +34,100 @@ from tvm.autotvm.tuner import RandomTuner
 import tvm.testing
 
 
-@autotvm.template("testing/conv2d_no_batching")
-def conv2d_no_batching(N, H, W, CI, CO, KH, KW):
-    """An example template for testing"""
-    assert N == 1, "Only consider batch_size = 1 in this template"
-
-    data = te.placeholder((N, CI, H, W), name="data")
-    kernel = te.placeholder((CO, CI, KH, KW), name="kernel")
-
-    rc = te.reduce_axis((0, CI), name="rc")
-    ry = te.reduce_axis((0, KH), name="ry")
-    rx = te.reduce_axis((0, KW), name="rx")
-
-    conv = te.compute(
-        (N, CO, H - KH + 1, W - KW + 1),
-        lambda nn, ff, yy, xx: te.sum(
-            data[nn, rc, yy + ry, xx + rx] * kernel[ff, rc, ry, rx], axis=[rc, ry, rx]
-        ),
-        tag="conv2d_nchw",
-    )
-
-    s = te.create_schedule([conv.op])
-
-    output = conv
-    OL = s.cache_write(conv, "local")
-
-    # create cache stage
-    AA = s.cache_read(data, "shared", [OL])
-    WW = s.cache_read(kernel, "shared", [OL])
-    AL = s.cache_read(AA, "local", [OL])
-    WL = s.cache_read(WW, "local", [OL])
-
-    # tile and bind spatial axes
-    n, f, y, x = s[output].op.axis
-    cfg = autotvm.get_config()
-    cfg.define_split("tile_f", cfg.axis(f), num_outputs=4)
-    cfg.define_split("tile_y", cfg.axis(y), num_outputs=4)
-    cfg.define_split("tile_x", cfg.axis(x), num_outputs=4)
-    bf, vf, tf, fi = cfg["tile_f"].apply(s, output, f)
-    by, vy, ty, yi = cfg["tile_y"].apply(s, output, y)
-    bx, vx, tx, xi = cfg["tile_x"].apply(s, output, x)
-    kernel_scope = n  # this is the scope to attach global config inside this kernel
-
-    s[output].bind(bf, te.thread_axis("blockIdx.z"))
-    s[output].bind(by, te.thread_axis("blockIdx.y"))
-    s[output].bind(bx, te.thread_axis("blockIdx.x"))
-    s[output].bind(vf, te.thread_axis("vthread"))
-    s[output].bind(vy, te.thread_axis("vthread"))
-    s[output].bind(vx, te.thread_axis("vthread"))
-    s[output].bind(tf, te.thread_axis("threadIdx.z"))
-    s[output].bind(ty, te.thread_axis("threadIdx.y"))
-    s[output].bind(tx, te.thread_axis("threadIdx.x"))
-    s[output].reorder(n, bf, by, bx, vf, vy, vx, tf, ty, tx, fi, yi, xi)
-    s[OL].compute_at(s[output], tx)
-
-    # tile and bind reduction axes
-    n, f, y, x = s[OL].op.axis
-    rc, ry, rx = s[OL].op.reduce_axis
-    cfg.define_split("tile_rc", cfg.axis(rc), num_outputs=3)
-    cfg.define_split("tile_ry", cfg.axis(ry), num_outputs=3)
-    cfg.define_split("tile_rx", cfg.axis(rx), num_outputs=3)
-    rco, rcm, rci = cfg["tile_rc"].apply(s, OL, rc)
-    ryo, rym, ryi = cfg["tile_rx"].apply(s, OL, ry)
-    rxo, rxm, rxi = cfg["tile_ry"].apply(s, OL, rx)
-    s[OL].reorder(rco, ryo, rxo, rcm, rym, rxm, rci, ryi, rxi, n, f, y, x)
-
-    s[AA].compute_at(s[OL], rxo)
-    s[WW].compute_at(s[OL], rxo)
-    s[AL].compute_at(s[OL], rxm)
-    s[WL].compute_at(s[OL], rxm)
-
-    # cooperative fetching
-    for load in [AA, WW]:
-        n, f, y, x = s[load].op.axis
-        fused = s[load].fuse(n, f, y, x)
-        tz, fused = s[load].split(fused, nparts=cfg["tile_f"].size[2])
-        ty, fused = s[load].split(fused, nparts=cfg["tile_y"].size[2])
-        tx, fused = s[load].split(fused, nparts=cfg["tile_x"].size[2])
-        s[load].bind(tz, te.thread_axis("threadIdx.z"))
-        s[load].bind(ty, te.thread_axis("threadIdx.y"))
-        s[load].bind(tx, te.thread_axis("threadIdx.x"))
-
-    # tune unroll
-    cfg.define_knob("auto_unroll_max_step", [0, 512, 1500])
-    cfg.define_knob("unroll_explicit", [0, 1])
-    s[output].pragma(kernel_scope, "auto_unroll_max_step", cfg["auto_unroll_max_step"].val)
-    s[output].pragma(kernel_scope, "unroll_explicit", cfg["unroll_explicit"].val)
-
-    return s, [data, kernel, conv]
+def setup_module():  # pytest hook; registers the template below into autotvm's TASK_TABLE
+    @autotvm.template("testing/conv2d_no_batching")
+    def conv2d_no_batching(N, H, W, CI, CO, KH, KW):
+        """Tunable GPU conv2d template (NCHW layout, batch size 1) used for testing."""
+        assert N == 1, "Only consider batch_size = 1 in this template"
+
+        data = te.placeholder((N, CI, H, W), name="data")
+        kernel = te.placeholder((CO, CI, KH, KW), name="kernel")
+
+        rc = te.reduce_axis((0, CI), name="rc")
+        ry = te.reduce_axis((0, KH), name="ry")
+        rx = te.reduce_axis((0, KW), name="rx")
+
+        conv = te.compute(
+            (N, CO, H - KH + 1, W - KW + 1),
+            lambda nn, ff, yy, xx: te.sum(
+                data[nn, rc, yy + ry, xx + rx] * kernel[ff, rc, ry, rx], axis=[rc, ry, rx]
+            ),
+            tag="conv2d_nchw",
+        )
+
+        s = te.create_schedule([conv.op])
+
+        output = conv
+        OL = s.cache_write(conv, "local")
+
+        # create cache stages: "shared" copies of both inputs, then "local" copies of those
+        AA = s.cache_read(data, "shared", [OL])
+        WW = s.cache_read(kernel, "shared", [OL])
+        AL = s.cache_read(AA, "local", [OL])
+        WL = s.cache_read(WW, "local", [OL])
+
+        # tile and bind spatial axes
+        n, f, y, x = s[output].op.axis
+        cfg = autotvm.get_config()
+        cfg.define_split("tile_f", cfg.axis(f), num_outputs=4)
+        cfg.define_split("tile_y", cfg.axis(y), num_outputs=4)
+        cfg.define_split("tile_x", cfg.axis(x), num_outputs=4)
+        bf, vf, tf, fi = cfg["tile_f"].apply(s, output, f)
+        by, vy, ty, yi = cfg["tile_y"].apply(s, output, y)
+        bx, vx, tx, xi = cfg["tile_x"].apply(s, output, x)
+        kernel_scope = n  # this is the scope to attach global config inside this kernel
+
+        s[output].bind(bf, te.thread_axis("blockIdx.z"))
+        s[output].bind(by, te.thread_axis("blockIdx.y"))
+        s[output].bind(bx, te.thread_axis("blockIdx.x"))
+        s[output].bind(vf, te.thread_axis("vthread"))
+        s[output].bind(vy, te.thread_axis("vthread"))
+        s[output].bind(vx, te.thread_axis("vthread"))
+        s[output].bind(tf, te.thread_axis("threadIdx.z"))
+        s[output].bind(ty, te.thread_axis("threadIdx.y"))
+        s[output].bind(tx, te.thread_axis("threadIdx.x"))
+        s[output].reorder(n, bf, by, bx, vf, vy, vx, tf, ty, tx, fi, yi, xi)
+        s[OL].compute_at(s[output], tx)
+
+        # tile reduction axes; each axis is split by the knob defined for that same axis
+        n, f, y, x = s[OL].op.axis
+        rc, ry, rx = s[OL].op.reduce_axis
+        cfg.define_split("tile_rc", cfg.axis(rc), num_outputs=3)
+        cfg.define_split("tile_ry", cfg.axis(ry), num_outputs=3)
+        cfg.define_split("tile_rx", cfg.axis(rx), num_outputs=3)
+        rco, rcm, rci = cfg["tile_rc"].apply(s, OL, rc)
+        ryo, rym, ryi = cfg["tile_ry"].apply(s, OL, ry)
+        rxo, rxm, rxi = cfg["tile_rx"].apply(s, OL, rx)
+        s[OL].reorder(rco, ryo, rxo, rcm, rym, rxm, rci, ryi, rxi, n, f, y, x)
+
+        s[AA].compute_at(s[OL], rxo)
+        s[WW].compute_at(s[OL], rxo)
+        s[AL].compute_at(s[OL], rxm)
+        s[WL].compute_at(s[OL], rxm)
+
+        # cooperative fetching
+        for load in [AA, WW]:
+            n, f, y, x = s[load].op.axis
+            fused = s[load].fuse(n, f, y, x)
+            tz, fused = s[load].split(fused, nparts=cfg["tile_f"].size[2])
+            ty, fused = s[load].split(fused, nparts=cfg["tile_y"].size[2])
+            tx, fused = s[load].split(fused, nparts=cfg["tile_x"].size[2])
+            s[load].bind(tz, te.thread_axis("threadIdx.z"))
+            s[load].bind(ty, te.thread_axis("threadIdx.y"))
+            s[load].bind(tx, te.thread_axis("threadIdx.x"))
+
+        # tune unroll
+        cfg.define_knob("auto_unroll_max_step", [0, 512, 1500])
+        cfg.define_knob("unroll_explicit", [0, 1])
+        s[output].pragma(kernel_scope, "auto_unroll_max_step", cfg["auto_unroll_max_step"].val)
+        s[output].pragma(kernel_scope, "unroll_explicit", cfg["unroll_explicit"].val)
+
+        return s, [data, kernel, conv]
+
+
+def teardown_module():
+    # TODO(areusch): Tasks should not be registered into a global.
+    del autotvm.task.task.TASK_TABLE["testing/conv2d_no_batching"]
 
 
 def get_sample_task(target=tvm.target.cuda(), target_host=None):
@@ -131,19 +142,62 @@ def get_sample_task(target=tvm.target.cuda(), target_host=None):
 
 
 @tvm.testing.parametrize_targets("cuda", "opencl")
-def test_tuning(target, ctx):
+def test_tuning_gpu(target, ctx):
     # init task
     task, target = get_sample_task(target, None)
-    logging.info("%s", task.config_space)
+    logging.info("task config space: %s", task.config_space)
 
     measure_option = autotvm.measure_option(autotvm.LocalBuilder(), autotvm.LocalRunner())
 
+    results = []
+
     tuner = RandomTuner(task)
-    tuner.tune(n_trial=20, measure_option=measure_option)
+    tuner.tune(
+        n_trial=20,
+        measure_option=measure_option,
+        callbacks=(lambda _tuner, _inputs, rs: results.extend(rs),),
+    )
 
+    assert len(results) == 20
 
-if __name__ == "__main__":
-    # only print log when invoked from main
-    logging.basicConfig(level=logging.DEBUG)
+    successful_results = [r for r in results if r.error_no == autotvm.MeasureErrorNo.NO_ERROR]
+    assert len(successful_results) > 0, f"No successful tuning runs: {results!r}"
+
+
+def test_tuning_cpu():  # relay conv2d -> one extracted autotvm task -> 20 random trials on llvm
+    ir_mod = tvm.parser.fromtext(
+        textwrap.dedent(
+            """
+        #[version = "0.0.5"]
+        def @main(%a : Tensor[(1, 3, 32, 32), float32], %b : Tensor[(3, 3, 5, 5), float32]) {
+               nn.conv2d(%a, %b, data_layout="NCHW", kernel_layout="OIHW")
+        }
+        """
+        )
+    )
+    tasks = autotvm.task.relay_integration.extract_from_program(
+        ir_mod, {}, tvm.target.Target("llvm")
+    )
+    assert len(tasks) == 1, f"Extracted != 1 task from program: {tasks!r}"
+
+    task = tasks[0]
+
+    measure_option = autotvm.measure_option(autotvm.LocalBuilder(), autotvm.LocalRunner())
+
+    results = []
+
+    tuner = RandomTuner(task)
+    tuner.tune(
+        n_trial=20,
+        measure_option=measure_option,
+        callbacks=(lambda _tuner, _inputs, rs: results.extend(rs),),
+    )
+
+    assert len(results) == 20
 
-    test_tuning()
+    successful_results = [r for r in results if r.error_no == autotvm.MeasureErrorNo.NO_ERROR]
+    assert len(successful_results) > 0, f"No successful tuning runs: {results!r}"
+
+
+if __name__ == "__main__":
+    sys.exit(pytest.main([__file__] + sys.argv[1:]))