You are viewing a plain-text version of this content. The canonical link for it ("here") was a hyperlink that is not preserved in this plain-text copy.
Posted to commits@tvm.apache.org by co...@apache.org on 2020/12/25 09:14:26 UTC

[tvm] branch main updated: [AutoScheduler] Fix the conflict of thread pool in measurement (#7166)

This is an automated email from the ASF dual-hosted git repository.

comaniac pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 3d8fd2a  [AutoScheduler] Fix the conflict of thread pool in measurement (#7166)
3d8fd2a is described below

commit 3d8fd2a2e7c39216482275d04e71aa131216e4c2
Author: Lianmin Zheng <li...@gmail.com>
AuthorDate: Fri Dec 25 01:14:09 2020 -0800

    [AutoScheduler] Fix the conflict of thread pool in measurement (#7166)
---
 python/tvm/auto_scheduler/measure.py               | 19 ++---
 python/tvm/auto_scheduler/utils.py                 | 51 +++++++++++--
 python/tvm/testing.py                              | 18 -----
 .../relay/test_auto_scheduler_layout_rewrite.py    | 13 +---
 .../unittest/test_auto_scheduler_search_policy.py  | 88 ++++------------------
 tutorials/auto_scheduler/tune_conv2d_layer_cuda.py | 34 +++++----
 tutorials/auto_scheduler/tune_matmul_x86.py        | 31 ++------
 7 files changed, 96 insertions(+), 158 deletions(-)

diff --git a/python/tvm/auto_scheduler/measure.py b/python/tvm/auto_scheduler/measure.py
index 38a420d..24a7577 100644
--- a/python/tvm/auto_scheduler/measure.py
+++ b/python/tvm/auto_scheduler/measure.py
@@ -602,9 +602,9 @@ def _timed_func(inp_serialized, build_func, verbose):
 
     if verbose >= 1:
         if error_no == MeasureErrorNo.NO_ERROR:
-            print(".", end="")
+            print(".", end="", flush=True)
         else:
-            print(".E", end="")  # Build error
+            print(".E", end="", flush=True)  # Build error
 
     return filename, args, error_no, error_msg, time.time() - tic
 
@@ -634,11 +634,11 @@ def local_build_worker(args):
     res = call_func_with_timeout(timeout, _timed_func, args=(inp, build_func, verbose))
     if isinstance(res, TimeoutError):
         if verbose >= 1:
-            print(".T", end="")  # Build timeout
+            print(".T", end="", flush=True)  # Build timeout
         res = None, [], MeasureErrorNo.BUILD_TIMEOUT, None, timeout
     elif isinstance(res, Exception):
         if verbose >= 1:
-            print(".E", end="")  # Build error
+            print(".E", end="", flush=True)  # Build error
         res = None, [], MeasureErrorNo.COMPILE_HOST, str(res), timeout
 
     return res
@@ -751,9 +751,9 @@ def _timed_eval_func(
 
     if verbose >= 1:
         if error_no == MeasureErrorNo.NO_ERROR:
-            print("*", end="")
+            print("*", end="", flush=True)
         else:
-            print("*E", end="")  # Run error
+            print("*E", end="", flush=True)  # Run error
     return costs, error_no, error_msg, toc - tic + build_res.time_cost, toc
 
 
@@ -839,10 +839,11 @@ def local_run(
                     enable_cpu_cache_flush,
                     verbose,
                 ),
+                add_thread_wrapper=True,
             )
             if isinstance(res, TimeoutError):
                 if verbose >= 1:
-                    print("*T", end="")  # Run timeout
+                    print("*T", end="", flush=True)  # Run timeout
                 res = (
                     (MAX_FLOAT,),
                     MeasureErrorNo.RUN_TIMEOUT,
@@ -852,7 +853,7 @@ def local_run(
                 )
             elif isinstance(res, Exception):
                 if verbose >= 1:
-                    print("*E", end="")  # Run error
+                    print("*E", end="", flush=True)  # Run error
                 res = (
                     (MAX_FLOAT,),
                     MeasureErrorNo.RUNTIME_DEVICE,
@@ -864,7 +865,7 @@ def local_run(
         measure_results.append(MeasureResult(*res))
 
     if verbose >= 1:
-        print("")
+        print("", flush=True)
 
     return measure_results
 
diff --git a/python/tvm/auto_scheduler/utils.py b/python/tvm/auto_scheduler/utils.py
index 9a7c199..f3698fa 100644
--- a/python/tvm/auto_scheduler/utils.py
+++ b/python/tvm/auto_scheduler/utils.py
@@ -14,6 +14,7 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
+# pylint: disable=invalid-name
 
 """ Common utilities for auto_scheduler. """
 
@@ -162,22 +163,56 @@ def make_traceback_info():
     return info
 
 
-def _func_wrapper(que, func, args, kwargs):
+class PropagatingThread(threading.Thread):
+    """A thread that propagates the exception to the main thread"""
+
+    def run(self):
+        self.exc = None
+        try:
+            self.ret = self._target(*self._args, **self._kwargs)
+        except Exception as e:  # pylint: disable=broad-except
+            self.exc = e
+
+    def join(self, timeout=None):
+        super(PropagatingThread, self).join(timeout)
+        if self.exc:
+            raise self.exc
+        return self.ret
+
+
+def call_func_with_thread(func, args, kwargs):
+    """Call a function within a new thread"""
+    res = []
+
+    def wrapper():
+        res.append(func(*args, **kwargs))
+
+    t = PropagatingThread(target=wrapper)
+    t.start()
+    t.join()
+    return res[0]
+
+
+def _func_wrapper(que, func, args, kwargs, add_thread_wrapper):
     """Call function and return the result over the queue."""
     try:
-        if kwargs:
-            que.put(func(*args, **kwargs))
+        if add_thread_wrapper:
+            # Add a new layer of threadinng to avoid the conflict between
+            # python's multiprocessing and tvm's thread pool.
+            res = call_func_with_thread(func, args, kwargs)
         else:
-            que.put(func(*args))
-    # pylint: disable=broad-except
-    except Exception:
+            res = func(*args, **kwargs)
+        que.put(res)
+    except Exception:  # pylint: disable=broad-except
         que.put(Exception(make_traceback_info()))
 
 
-def call_func_with_timeout(timeout, func, args=(), kwargs=None):
+def call_func_with_timeout(timeout, func, args=(), kwargs=None, add_thread_wrapper=False):
     """Call a function with timeout"""
     que = multiprocessing.Queue(2)
-    process = multiprocessing.Process(target=_func_wrapper, args=(que, func, args, kwargs))
+    process = multiprocessing.Process(
+        target=_func_wrapper, args=(que, func, args, kwargs or {}, add_thread_wrapper)
+    )
     process.start()
 
     try:
diff --git a/python/tvm/testing.py b/python/tvm/testing.py
index 32307a9..8311a63 100644
--- a/python/tvm/testing.py
+++ b/python/tvm/testing.py
@@ -58,7 +58,6 @@ import logging
 import os
 import sys
 import time
-import threading
 import pytest
 import numpy as np
 import tvm
@@ -743,21 +742,4 @@ def terminate_self():
     sys.exit(-1)
 
 
-class PropagatingThread(threading.Thread):
-    """A thread that propagates the exection to the main thread"""
-
-    def run(self):
-        self.exc = None
-        try:
-            self.ret = self._target(*self._args, **self._kwargs)
-        except BaseException as e:
-            self.exc = e
-
-    def join(self, timeout=None):
-        super(PropagatingThread, self).join(timeout)
-        if self.exc:
-            raise self.exc
-        return self.ret
-
-
 tvm._ffi._init_api("testing", __name__)
diff --git a/tests/python/relay/test_auto_scheduler_layout_rewrite.py b/tests/python/relay/test_auto_scheduler_layout_rewrite.py
index 66d40ba..577f6d6 100644
--- a/tests/python/relay/test_auto_scheduler_layout_rewrite.py
+++ b/tests/python/relay/test_auto_scheduler_layout_rewrite.py
@@ -23,7 +23,6 @@ import tvm
 from tvm import relay, auto_scheduler
 from tvm.contrib import graph_runtime
 import tvm.testing
-from tvm.testing import PropagatingThread
 
 
 def get_np_array(var, dtype):
@@ -139,23 +138,17 @@ def test_conv2d():
     # wrap the search in a new thread to avoid the conflict
     # between python's multiprocessing and tvm's thread pool
     mod, data, weight = get_relay_conv2d(kh=1, kw=1)
-    t = PropagatingThread(target=tune_and_check, args=(mod, data, weight))
-    t.start()
-    t.join()
+    tune_and_check(mod, data, weight)
 
 
 def test_dense():
     mod, data, weight = get_relay_dense()
-    t = PropagatingThread(target=tune_and_check, args=(mod, data, weight))
-    t.start()
-    t.join()
+    tune_and_check(mod, data, weight)
 
 
 def test_batch_matmul():
     mod, data, weight = get_relay_batchmm()
-    t = PropagatingThread(target=tune_and_check, args=(mod, data, weight))
-    t.start()
-    t.join()
+    tune_and_check(mod, data, weight)
 
 
 if __name__ == "__main__":
diff --git a/tests/python/unittest/test_auto_scheduler_search_policy.py b/tests/python/unittest/test_auto_scheduler_search_policy.py
index 5bc7c2a..73ce0a1 100644
--- a/tests/python/unittest/test_auto_scheduler_search_policy.py
+++ b/tests/python/unittest/test_auto_scheduler_search_policy.py
@@ -24,7 +24,6 @@ import tempfile
 
 import tvm
 import tvm.testing
-from tvm.testing import PropagatingThread
 from tvm import auto_scheduler
 
 from test_auto_scheduler_common import matmul_auto_scheduler_test
@@ -78,18 +77,12 @@ def search_common(
             num_measures_per_round=2,
             early_stopping=1,
             runner=runner,
-            verbose=2,
             measure_callbacks=[auto_scheduler.RecordToFile(log_file), CustomMeasureCallback()],
         )
         task.tune(tuning_options=tuning_options, search_policy=search_policy)
         sch, args = task.apply_best(log_file)
 
-        print("==== Python Code ====")
-        print(task.print_best(log_file))
-
         try:
-            print("==== Lowered Stmt ====")
-            print(tvm.lower(sch, args, simple_mode=True))
             mod = tvm.build(sch, args, target)
 
             ctx = tvm.context(str(target), 0)
@@ -99,52 +92,29 @@ def search_common(
             c = tvm.nd.array(np.zeros((N, N), dtype=dtype), ctx)
             mod(a, b, c)
             tvm.testing.assert_allclose(c.asnumpy(), np.dot(a.asnumpy(), b.asnumpy()), rtol=1e-5)
-            print("==== Verification passed ====")
         except Exception:
             raise Exception("Error encountered with seed: %d" % (seed))
-    print()
 
 
 @tvm.testing.requires_llvm
 def test_workload_registry_search_basic():
-    # wrap the search in a new thread to avoid the conflict
-    # between python's multiprocessing and tvm's thread pool
-    t = PropagatingThread(
-        target=search_common, kwargs={"search_policy": "empty", "num_measure_trials": 2}
-    )
-    t.start()
-    t.join()
-
-    t = PropagatingThread(
-        target=search_common,
-        kwargs={
-            "workload": "matmul_auto_scheduler_test",
-            "num_measure_trials": 2,
-            "search_policy": "empty",
-        },
+    search_common(search_policy="empty", num_measure_trials=2)
+
+    search_common(
+        workload="matmul_auto_scheduler_test",
+        num_measure_trials=2,
+        search_policy="empty",
     )
-    t.start()
-    t.join()
-
-    t = PropagatingThread(
-        target=search_common,
-        kwargs={
-            "workload": "matmul_auto_scheduler_test_rename_1",
-            "num_measure_trials": 2,
-            "search_policy": "empty",
-        },
+    search_common(
+        workload="matmul_auto_scheduler_test_rename_1",
+        num_measure_trials=2,
+        search_policy="empty",
     )
-    t.start()
-    t.join()
 
 
 @tvm.testing.requires_llvm
 def test_sketch_search_policy_basic():
-    # wrap the search in a new thread to avoid the conflict
-    # between python's multiprocessing and tvm's thread pool
-    t = PropagatingThread(target=search_common)
-    t.start()
-    t.join()
+    search_common()
 
 
 def sketch_search_policy_basic_spawn():
@@ -162,49 +132,19 @@ def test_sketch_search_policy_basic_spawn():
 
 @tvm.testing.requires_llvm
 def test_sketch_search_policy_xgbmodel():
-    # wrap the search in a new thread to avoid the conflict
-    # between python's multiprocessing and tvm's thread pool
-    t = PropagatingThread(
-        target=search_common,
-        kwargs={
-            "cost_model": auto_scheduler.XGBModel(),
-        },
-    )
-    t.start()
-    t.join()
+    search_common(cost_model=auto_scheduler.XGBModel())
 
 
 @tvm.testing.requires_cuda
 def test_sketch_search_policy_cuda_rpc_runner():
     measure_ctx = auto_scheduler.LocalRPCMeasureContext()
-    # wrap the search in a new thread to avoid the conflict
-    # between python's multiprocessing and tvm's thread pool
-    t = PropagatingThread(
-        target=search_common,
-        kwargs={
-            "target": "cuda",
-            "runner": measure_ctx.runner,
-        },
-    )
-    t.start()
-    t.join()
+    search_common(target="cuda", runner=measure_ctx.runner)
 
 
 @tvm.testing.requires_cuda
 def test_sketch_search_policy_cuda_xgbmodel_rpc_runner():
     measure_ctx = auto_scheduler.LocalRPCMeasureContext()
-    # wrap the search in a new thread to avoid the conflict
-    # between python's multiprocessing and tvm's thread pool
-    t = PropagatingThread(
-        target=search_common,
-        kwargs={
-            "target": "cuda",
-            "runner": measure_ctx.runner,
-            "cost_model": auto_scheduler.XGBModel(),
-        },
-    )
-    t.start()
-    t.join()
+    search_common(target="cuda", runner=measure_ctx.runner, cost_model=auto_scheduler.XGBModel())
 
 
 if __name__ == "__main__":
diff --git a/tutorials/auto_scheduler/tune_conv2d_layer_cuda.py b/tutorials/auto_scheduler/tune_conv2d_layer_cuda.py
index 103ceb4..396bdb0 100644
--- a/tutorials/auto_scheduler/tune_conv2d_layer_cuda.py
+++ b/tutorials/auto_scheduler/tune_conv2d_layer_cuda.py
@@ -186,18 +186,24 @@ print(task.print_best(log_file, print_mode="cuda"))
 # and resume the status of search policy and cost model with the log file.
 # In the example below we resume the status and do more 5 trials.
 
-cost_model = auto_scheduler.XGBModel()
-cost_model.update_from_file(log_file)
-search_policy = auto_scheduler.SketchPolicy(
-    task, cost_model, init_search_callbacks=[auto_scheduler.PreloadMeasuredStates(log_file)]
-)
-measure_ctx = auto_scheduler.LocalRPCMeasureContext(min_repeat_ms=300)
-tune_option = auto_scheduler.TuningOptions(
-    num_measure_trials=5,
-    runner=measure_ctx.runner,
-    measure_callbacks=[auto_scheduler.RecordToFile(log_file)],
-)
-task.tune(tune_option, search_policy=search_policy)
 
-# Kill the measurement process
-del measure_ctx
+def resume_search(task, log_file):
+    print("Resume search:")
+    cost_model = auto_scheduler.XGBModel()
+    cost_model.update_from_file(log_file)
+    search_policy = auto_scheduler.SketchPolicy(
+        task, cost_model, init_search_callbacks=[auto_scheduler.PreloadMeasuredStates(log_file)]
+    )
+    measure_ctx = auto_scheduler.LocalRPCMeasureContext(min_repeat_ms=300)
+    tune_option = auto_scheduler.TuningOptions(
+        num_measure_trials=5,
+        runner=measure_ctx.runner,
+        measure_callbacks=[auto_scheduler.RecordToFile(log_file)],
+    )
+    task.tune(tune_option, search_policy=search_policy)
+
+    # Kill the measurement process
+    del measure_ctx
+
+
+resume_search(task, log_file)
diff --git a/tutorials/auto_scheduler/tune_matmul_x86.py b/tutorials/auto_scheduler/tune_matmul_x86.py
index 9bc15ae..084f5ae 100644
--- a/tutorials/auto_scheduler/tune_matmul_x86.py
+++ b/tutorials/auto_scheduler/tune_matmul_x86.py
@@ -174,36 +174,17 @@ print(task.print_best(log_file))
 # In the example below we resume the status and do more 5 trials.
 
 
-def resume_search(task, log_file_name):
+def resume_search(task, log_file):
+    print("Resume search:")
     cost_model = auto_scheduler.XGBModel()
-    cost_model.update_from_file(log_file_name)
+    cost_model.update_from_file(log_file)
     search_policy = auto_scheduler.SketchPolicy(
-        task,
-        cost_model,
-        init_search_callbacks=[auto_scheduler.PreloadMeasuredStates(log_file_name)],
+        task, cost_model, init_search_callbacks=[auto_scheduler.PreloadMeasuredStates(log_file)]
     )
     tune_option = auto_scheduler.TuningOptions(
-        num_measure_trials=5, measure_callbacks=[auto_scheduler.RecordToFile(log_file_name)]
+        num_measure_trials=5, measure_callbacks=[auto_scheduler.RecordToFile(log_file)]
     )
     task.tune(tune_option, search_policy=search_policy)
 
 
-# resume_search(task, log_file)
-
-######################################################################
-# .. note::
-#   We cannot run the line above because of the conflict between
-#   python's multiprocessing and tvm's thread pool.
-#   After running a tvm generated binary the python's multiprocessing library
-#   will hang forever. You have to make sure that you don't run any tvm
-#   generated binaries before calling auot-scheduler's search.
-#   To run the function above, you should comment out all code in
-#   "Check correctness and evaluate performance" section.
-#
-#   You should be careful about this problem in your applications.
-#   There are other workarounds for this problem.
-#   For example, you can start a new thread/process (with the builtin python library
-#   threading or multiprocessing) and run the tvm binaries in the new thread/process.
-#   This provides an isolation and avoids the conflict in the main thread/process.
-#   You can also use :any:`auto_scheduler.LocalRPCMeasureContext` for auto-scheduler,
-#   as shown in the GPU tutorial (:ref:`auto-scheduler-conv-gpu`).
+resume_search(task, log_file)