You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2020/12/25 08:24:46 UTC

[GitHub] [tvm] jcf94 commented on a change in pull request #7166: [AutoScheduler] Fix the conflict of thread pool in measurement

jcf94 commented on a change in pull request #7166:
URL: https://github.com/apache/tvm/pull/7166#discussion_r548831804



##########
File path: python/tvm/auto_scheduler/utils.py
##########
@@ -162,22 +163,56 @@ def make_traceback_info():
     return info
 
 
-def _func_wrapper(que, func, args, kwargs):
+class PropagatingThread(threading.Thread):
+    """A thread that propagates the exception to the main thread"""
+
+    def run(self):
+        self.exc = None
+        try:
+            self.ret = self._target(*self._args, **self._kwargs)
+        except Exception as e:  # pylint: disable=broad-except
+            self.exc = e
+
+    def join(self, timeout=None):
+        super(PropagatingThread, self).join(timeout)
+        if self.exc:
+            raise self.exc
+        return self.ret
+
+
+def call_func_with_thread(func, args, kwargs):
+    """Call a function within a new thread"""
+    res = []
+
+    def wrapper():
+        res.append(func(*args, **kwargs))
+
+    t = PropagatingThread(target=wrapper)
+    t.start()
+    t.join()
+    return res[0]
+
+
+def _func_wrapper(que, func, args, kwargs, add_thread_wrapper):
     """Call function and return the result over the queue."""
     try:
-        if kwargs:
-            que.put(func(*args, **kwargs))
+        if add_thread_wrapper:
+            # Add a new layer of threadinng to avoid the conflict between
+            # python's multiprocessing and tvm's thread pool.
+            res = call_func_with_thread(func, args, kwargs)
         else:
-            que.put(func(*args))
-    # pylint: disable=broad-except
-    except Exception:
+            res = func(*args, **kwargs)
+        que.put(res)
+    except Exception:  # pylint: disable=broad-except
         que.put(Exception(make_traceback_info()))
 
 
-def call_func_with_timeout(timeout, func, args=(), kwargs=None):
+def call_func_with_timeout(timeout, func, args=(), kwargs=None, add_thread_wrapper=False):
     """Call a function with timeout"""
     que = multiprocessing.Queue(2)
-    process = multiprocessing.Process(target=_func_wrapper, args=(que, func, args, kwargs))
+    process = multiprocessing.Process(

Review comment:
      I'm wondering whether the bug was introduced by `multiprocessing` itself. Maybe using a thread directly here (instead of a `multiprocessing.Process`) would solve the problem... I'm not sure; it's just a guess.




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org