You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by co...@apache.org on 2020/12/25 09:14:26 UTC
[tvm] branch main updated: [AutoScheduler] Fix the conflict of
thread pool in measurement (#7166)
This is an automated email from the ASF dual-hosted git repository.
comaniac pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 3d8fd2a [AutoScheduler] Fix the conflict of thread pool in measurement (#7166)
3d8fd2a is described below
commit 3d8fd2a2e7c39216482275d04e71aa131216e4c2
Author: Lianmin Zheng <li...@gmail.com>
AuthorDate: Fri Dec 25 01:14:09 2020 -0800
[AutoScheduler] Fix the conflict of thread pool in measurement (#7166)
---
python/tvm/auto_scheduler/measure.py | 19 ++---
python/tvm/auto_scheduler/utils.py | 51 +++++++++++--
python/tvm/testing.py | 18 -----
.../relay/test_auto_scheduler_layout_rewrite.py | 13 +---
.../unittest/test_auto_scheduler_search_policy.py | 88 ++++------------------
tutorials/auto_scheduler/tune_conv2d_layer_cuda.py | 34 +++++----
tutorials/auto_scheduler/tune_matmul_x86.py | 31 ++------
7 files changed, 96 insertions(+), 158 deletions(-)
diff --git a/python/tvm/auto_scheduler/measure.py b/python/tvm/auto_scheduler/measure.py
index 38a420d..24a7577 100644
--- a/python/tvm/auto_scheduler/measure.py
+++ b/python/tvm/auto_scheduler/measure.py
@@ -602,9 +602,9 @@ def _timed_func(inp_serialized, build_func, verbose):
if verbose >= 1:
if error_no == MeasureErrorNo.NO_ERROR:
- print(".", end="")
+ print(".", end="", flush=True)
else:
- print(".E", end="") # Build error
+ print(".E", end="", flush=True) # Build error
return filename, args, error_no, error_msg, time.time() - tic
@@ -634,11 +634,11 @@ def local_build_worker(args):
res = call_func_with_timeout(timeout, _timed_func, args=(inp, build_func, verbose))
if isinstance(res, TimeoutError):
if verbose >= 1:
- print(".T", end="") # Build timeout
+ print(".T", end="", flush=True) # Build timeout
res = None, [], MeasureErrorNo.BUILD_TIMEOUT, None, timeout
elif isinstance(res, Exception):
if verbose >= 1:
- print(".E", end="") # Build error
+ print(".E", end="", flush=True) # Build error
res = None, [], MeasureErrorNo.COMPILE_HOST, str(res), timeout
return res
@@ -751,9 +751,9 @@ def _timed_eval_func(
if verbose >= 1:
if error_no == MeasureErrorNo.NO_ERROR:
- print("*", end="")
+ print("*", end="", flush=True)
else:
- print("*E", end="") # Run error
+ print("*E", end="", flush=True) # Run error
return costs, error_no, error_msg, toc - tic + build_res.time_cost, toc
@@ -839,10 +839,11 @@ def local_run(
enable_cpu_cache_flush,
verbose,
),
+ add_thread_wrapper=True,
)
if isinstance(res, TimeoutError):
if verbose >= 1:
- print("*T", end="") # Run timeout
+ print("*T", end="", flush=True) # Run timeout
res = (
(MAX_FLOAT,),
MeasureErrorNo.RUN_TIMEOUT,
@@ -852,7 +853,7 @@ def local_run(
)
elif isinstance(res, Exception):
if verbose >= 1:
- print("*E", end="") # Run error
+ print("*E", end="", flush=True) # Run error
res = (
(MAX_FLOAT,),
MeasureErrorNo.RUNTIME_DEVICE,
@@ -864,7 +865,7 @@ def local_run(
measure_results.append(MeasureResult(*res))
if verbose >= 1:
- print("")
+ print("", flush=True)
return measure_results
diff --git a/python/tvm/auto_scheduler/utils.py b/python/tvm/auto_scheduler/utils.py
index 9a7c199..f3698fa 100644
--- a/python/tvm/auto_scheduler/utils.py
+++ b/python/tvm/auto_scheduler/utils.py
@@ -14,6 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+# pylint: disable=invalid-name
""" Common utilities for auto_scheduler. """
@@ -162,22 +163,56 @@ def make_traceback_info():
return info
-def _func_wrapper(que, func, args, kwargs):
+class PropagatingThread(threading.Thread):
+ """A thread that propagates the exception to the main thread"""
+
+ def run(self):
+ self.exc = None
+ try:
+ self.ret = self._target(*self._args, **self._kwargs)
+ except Exception as e: # pylint: disable=broad-except
+ self.exc = e
+
+ def join(self, timeout=None):
+ super(PropagatingThread, self).join(timeout)
+ if self.exc:
+ raise self.exc
+ return self.ret
+
+
+def call_func_with_thread(func, args, kwargs):
+ """Call a function within a new thread"""
+ res = []
+
+ def wrapper():
+ res.append(func(*args, **kwargs))
+
+ t = PropagatingThread(target=wrapper)
+ t.start()
+ t.join()
+ return res[0]
+
+
+def _func_wrapper(que, func, args, kwargs, add_thread_wrapper):
"""Call function and return the result over the queue."""
try:
- if kwargs:
- que.put(func(*args, **kwargs))
+ if add_thread_wrapper:
+ # Add a new layer of threading to avoid the conflict between
+ # python's multiprocessing and tvm's thread pool.
+ res = call_func_with_thread(func, args, kwargs)
else:
- que.put(func(*args))
- # pylint: disable=broad-except
- except Exception:
+ res = func(*args, **kwargs)
+ que.put(res)
+ except Exception: # pylint: disable=broad-except
que.put(Exception(make_traceback_info()))
-def call_func_with_timeout(timeout, func, args=(), kwargs=None):
+def call_func_with_timeout(timeout, func, args=(), kwargs=None, add_thread_wrapper=False):
"""Call a function with timeout"""
que = multiprocessing.Queue(2)
- process = multiprocessing.Process(target=_func_wrapper, args=(que, func, args, kwargs))
+ process = multiprocessing.Process(
+ target=_func_wrapper, args=(que, func, args, kwargs or {}, add_thread_wrapper)
+ )
process.start()
try:
diff --git a/python/tvm/testing.py b/python/tvm/testing.py
index 32307a9..8311a63 100644
--- a/python/tvm/testing.py
+++ b/python/tvm/testing.py
@@ -58,7 +58,6 @@ import logging
import os
import sys
import time
-import threading
import pytest
import numpy as np
import tvm
@@ -743,21 +742,4 @@ def terminate_self():
sys.exit(-1)
-class PropagatingThread(threading.Thread):
- """A thread that propagates the exception to the main thread"""
-
- def run(self):
- self.exc = None
- try:
- self.ret = self._target(*self._args, **self._kwargs)
- except BaseException as e:
- self.exc = e
-
- def join(self, timeout=None):
- super(PropagatingThread, self).join(timeout)
- if self.exc:
- raise self.exc
- return self.ret
-
-
tvm._ffi._init_api("testing", __name__)
diff --git a/tests/python/relay/test_auto_scheduler_layout_rewrite.py b/tests/python/relay/test_auto_scheduler_layout_rewrite.py
index 66d40ba..577f6d6 100644
--- a/tests/python/relay/test_auto_scheduler_layout_rewrite.py
+++ b/tests/python/relay/test_auto_scheduler_layout_rewrite.py
@@ -23,7 +23,6 @@ import tvm
from tvm import relay, auto_scheduler
from tvm.contrib import graph_runtime
import tvm.testing
-from tvm.testing import PropagatingThread
def get_np_array(var, dtype):
@@ -139,23 +138,17 @@ def test_conv2d():
# wrap the search in a new thread to avoid the conflict
# between python's multiprocessing and tvm's thread pool
mod, data, weight = get_relay_conv2d(kh=1, kw=1)
- t = PropagatingThread(target=tune_and_check, args=(mod, data, weight))
- t.start()
- t.join()
+ tune_and_check(mod, data, weight)
def test_dense():
mod, data, weight = get_relay_dense()
- t = PropagatingThread(target=tune_and_check, args=(mod, data, weight))
- t.start()
- t.join()
+ tune_and_check(mod, data, weight)
def test_batch_matmul():
mod, data, weight = get_relay_batchmm()
- t = PropagatingThread(target=tune_and_check, args=(mod, data, weight))
- t.start()
- t.join()
+ tune_and_check(mod, data, weight)
if __name__ == "__main__":
diff --git a/tests/python/unittest/test_auto_scheduler_search_policy.py b/tests/python/unittest/test_auto_scheduler_search_policy.py
index 5bc7c2a..73ce0a1 100644
--- a/tests/python/unittest/test_auto_scheduler_search_policy.py
+++ b/tests/python/unittest/test_auto_scheduler_search_policy.py
@@ -24,7 +24,6 @@ import tempfile
import tvm
import tvm.testing
-from tvm.testing import PropagatingThread
from tvm import auto_scheduler
from test_auto_scheduler_common import matmul_auto_scheduler_test
@@ -78,18 +77,12 @@ def search_common(
num_measures_per_round=2,
early_stopping=1,
runner=runner,
- verbose=2,
measure_callbacks=[auto_scheduler.RecordToFile(log_file), CustomMeasureCallback()],
)
task.tune(tuning_options=tuning_options, search_policy=search_policy)
sch, args = task.apply_best(log_file)
- print("==== Python Code ====")
- print(task.print_best(log_file))
-
try:
- print("==== Lowered Stmt ====")
- print(tvm.lower(sch, args, simple_mode=True))
mod = tvm.build(sch, args, target)
ctx = tvm.context(str(target), 0)
@@ -99,52 +92,29 @@ def search_common(
c = tvm.nd.array(np.zeros((N, N), dtype=dtype), ctx)
mod(a, b, c)
tvm.testing.assert_allclose(c.asnumpy(), np.dot(a.asnumpy(), b.asnumpy()), rtol=1e-5)
- print("==== Verification passed ====")
except Exception:
raise Exception("Error encountered with seed: %d" % (seed))
- print()
@tvm.testing.requires_llvm
def test_workload_registry_search_basic():
- # wrap the search in a new thread to avoid the conflict
- # between python's multiprocessing and tvm's thread pool
- t = PropagatingThread(
- target=search_common, kwargs={"search_policy": "empty", "num_measure_trials": 2}
- )
- t.start()
- t.join()
-
- t = PropagatingThread(
- target=search_common,
- kwargs={
- "workload": "matmul_auto_scheduler_test",
- "num_measure_trials": 2,
- "search_policy": "empty",
- },
+ search_common(search_policy="empty", num_measure_trials=2)
+
+ search_common(
+ workload="matmul_auto_scheduler_test",
+ num_measure_trials=2,
+ search_policy="empty",
)
- t.start()
- t.join()
-
- t = PropagatingThread(
- target=search_common,
- kwargs={
- "workload": "matmul_auto_scheduler_test_rename_1",
- "num_measure_trials": 2,
- "search_policy": "empty",
- },
+ search_common(
+ workload="matmul_auto_scheduler_test_rename_1",
+ num_measure_trials=2,
+ search_policy="empty",
)
- t.start()
- t.join()
@tvm.testing.requires_llvm
def test_sketch_search_policy_basic():
- # wrap the search in a new thread to avoid the conflict
- # between python's multiprocessing and tvm's thread pool
- t = PropagatingThread(target=search_common)
- t.start()
- t.join()
+ search_common()
def sketch_search_policy_basic_spawn():
@@ -162,49 +132,19 @@ def test_sketch_search_policy_basic_spawn():
@tvm.testing.requires_llvm
def test_sketch_search_policy_xgbmodel():
- # wrap the search in a new thread to avoid the conflict
- # between python's multiprocessing and tvm's thread pool
- t = PropagatingThread(
- target=search_common,
- kwargs={
- "cost_model": auto_scheduler.XGBModel(),
- },
- )
- t.start()
- t.join()
+ search_common(cost_model=auto_scheduler.XGBModel())
@tvm.testing.requires_cuda
def test_sketch_search_policy_cuda_rpc_runner():
measure_ctx = auto_scheduler.LocalRPCMeasureContext()
- # wrap the search in a new thread to avoid the conflict
- # between python's multiprocessing and tvm's thread pool
- t = PropagatingThread(
- target=search_common,
- kwargs={
- "target": "cuda",
- "runner": measure_ctx.runner,
- },
- )
- t.start()
- t.join()
+ search_common(target="cuda", runner=measure_ctx.runner)
@tvm.testing.requires_cuda
def test_sketch_search_policy_cuda_xgbmodel_rpc_runner():
measure_ctx = auto_scheduler.LocalRPCMeasureContext()
- # wrap the search in a new thread to avoid the conflict
- # between python's multiprocessing and tvm's thread pool
- t = PropagatingThread(
- target=search_common,
- kwargs={
- "target": "cuda",
- "runner": measure_ctx.runner,
- "cost_model": auto_scheduler.XGBModel(),
- },
- )
- t.start()
- t.join()
+ search_common(target="cuda", runner=measure_ctx.runner, cost_model=auto_scheduler.XGBModel())
if __name__ == "__main__":
diff --git a/tutorials/auto_scheduler/tune_conv2d_layer_cuda.py b/tutorials/auto_scheduler/tune_conv2d_layer_cuda.py
index 103ceb4..396bdb0 100644
--- a/tutorials/auto_scheduler/tune_conv2d_layer_cuda.py
+++ b/tutorials/auto_scheduler/tune_conv2d_layer_cuda.py
@@ -186,18 +186,24 @@ print(task.print_best(log_file, print_mode="cuda"))
# and resume the status of search policy and cost model with the log file.
# In the example below we resume the status and do more 5 trials.
-cost_model = auto_scheduler.XGBModel()
-cost_model.update_from_file(log_file)
-search_policy = auto_scheduler.SketchPolicy(
- task, cost_model, init_search_callbacks=[auto_scheduler.PreloadMeasuredStates(log_file)]
-)
-measure_ctx = auto_scheduler.LocalRPCMeasureContext(min_repeat_ms=300)
-tune_option = auto_scheduler.TuningOptions(
- num_measure_trials=5,
- runner=measure_ctx.runner,
- measure_callbacks=[auto_scheduler.RecordToFile(log_file)],
-)
-task.tune(tune_option, search_policy=search_policy)
-# Kill the measurement process
-del measure_ctx
+def resume_search(task, log_file):
+ print("Resume search:")
+ cost_model = auto_scheduler.XGBModel()
+ cost_model.update_from_file(log_file)
+ search_policy = auto_scheduler.SketchPolicy(
+ task, cost_model, init_search_callbacks=[auto_scheduler.PreloadMeasuredStates(log_file)]
+ )
+ measure_ctx = auto_scheduler.LocalRPCMeasureContext(min_repeat_ms=300)
+ tune_option = auto_scheduler.TuningOptions(
+ num_measure_trials=5,
+ runner=measure_ctx.runner,
+ measure_callbacks=[auto_scheduler.RecordToFile(log_file)],
+ )
+ task.tune(tune_option, search_policy=search_policy)
+
+ # Kill the measurement process
+ del measure_ctx
+
+
+resume_search(task, log_file)
diff --git a/tutorials/auto_scheduler/tune_matmul_x86.py b/tutorials/auto_scheduler/tune_matmul_x86.py
index 9bc15ae..084f5ae 100644
--- a/tutorials/auto_scheduler/tune_matmul_x86.py
+++ b/tutorials/auto_scheduler/tune_matmul_x86.py
@@ -174,36 +174,17 @@ print(task.print_best(log_file))
# In the example below we resume the status and do more 5 trials.
-def resume_search(task, log_file_name):
+def resume_search(task, log_file):
+ print("Resume search:")
cost_model = auto_scheduler.XGBModel()
- cost_model.update_from_file(log_file_name)
+ cost_model.update_from_file(log_file)
search_policy = auto_scheduler.SketchPolicy(
- task,
- cost_model,
- init_search_callbacks=[auto_scheduler.PreloadMeasuredStates(log_file_name)],
+ task, cost_model, init_search_callbacks=[auto_scheduler.PreloadMeasuredStates(log_file)]
)
tune_option = auto_scheduler.TuningOptions(
- num_measure_trials=5, measure_callbacks=[auto_scheduler.RecordToFile(log_file_name)]
+ num_measure_trials=5, measure_callbacks=[auto_scheduler.RecordToFile(log_file)]
)
task.tune(tune_option, search_policy=search_policy)
-# resume_search(task, log_file)
-
-######################################################################
-# .. note::
-# We cannot run the line above because of the conflict between
-# python's multiprocessing and tvm's thread pool.
-# After running a tvm generated binary the python's multiprocessing library
-# will hang forever. You have to make sure that you don't run any tvm
- # generated binaries before calling auto-scheduler's search.
-# To run the function above, you should comment out all code in
-# "Check correctness and evaluate performance" section.
-#
-# You should be careful about this problem in your applications.
-# There are other workarounds for this problem.
-# For example, you can start a new thread/process (with the builtin python library
-# threading or multiprocessing) and run the tvm binaries in the new thread/process.
-# This provides an isolation and avoids the conflict in the main thread/process.
-# You can also use :any:`auto_scheduler.LocalRPCMeasureContext` for auto-scheduler,
-# as shown in the GPU tutorial (:ref:`auto-scheduler-conv-gpu`).
+resume_search(task, log_file)