You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by ju...@apache.org on 2020/12/20 21:08:14 UTC
[tvm] branch main updated: [CONTRIB] PopenPoolExecutor (#6959)
This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 37af2d7 [CONTRIB] PopenPoolExecutor (#6959)
37af2d7 is described below
commit 37af2d741d3efb37ca6aba261db8b78583dbc1cd
Author: Tianqi Chen <tq...@users.noreply.github.com>
AuthorDate: Sun Dec 20 16:07:58 2020 -0500
[CONTRIB] PopenPoolExecutor (#6959)
PopenPoolExecutor implements a ProcessPoolExecutor backed by popen.
- Only handles invoking functions in tvm namespace.
- Unlike multiprocessing, does not require __main__ block,
which means it can directly run on jupyter notebook.
- Comes with timeout and fault-tolerance support: long-running jobs are
timed out, and the process is restarted when an error happens.
Recommended usage: create a pool once and reuse it throughout a
long-running job (e.g. autotuning) so that the worker processes
are reused when possible.
---
python/tvm/contrib/popen_pool.py | 329 ++++++++++++++++++++++++++++++++
python/tvm/exec/popen_worker.py | 104 ++++++++++
python/tvm/testing.py | 28 +++
tests/python/contrib/test_popen_pool.py | 71 +++++++
4 files changed, 532 insertions(+)
diff --git a/python/tvm/contrib/popen_pool.py b/python/tvm/contrib/popen_pool.py
new file mode 100644
index 0000000..bca0862
--- /dev/null
+++ b/python/tvm/contrib/popen_pool.py
@@ -0,0 +1,329 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name
+"""Multiprocessing via Popen.
+
+This module provides a multiprocessing pool backed by Popen,
+with additional timeout support.
+"""
+import os
+import sys
+import struct
+import threading
+import subprocess
+import concurrent.futures
+from enum import IntEnum
+from collections import namedtuple
+import pickle
+
+
def kill_child_processes(pid):
    """Kill all child processes recursively for a given pid.

    Only the descendants are killed; the process identified by ``pid``
    itself is not signalled by this function.

    Parameters
    ----------
    pid : int
        The id of the process whose descendants should be killed.
    """
    # pylint: disable=import-outside-toplevel
    import psutil

    try:
        # The parent can exit between the lookup and the enumeration of
        # its children, so both steps are guarded together.
        children = psutil.Process(pid).children(recursive=True)
    except psutil.NoSuchProcess:
        return

    for process in children:
        try:
            process.kill()
        except psutil.NoSuchProcess:
            # the child already exited on its own
            pass
+
+
class StatusKind(IntEnum):
    """Integer codes describing the state or outcome of a task."""

    # the task is still executing
    RUNNING = 0
    # the function returned a value
    COMPLETE = 1
    # the function raised an exception
    EXCEPTION = 2
    # the task exceeded its time budget
    TIMEOUT = 3
+
+
class MapResult(namedtuple("MapResult", ("status", "value"))):
    """One item produced by map_with_error_catching.

    Parameters
    ----------
    status : StatusKind
        Outcome of the call that produced this item.

    value : Any
        The returned value, or the exception object when the call
        did not complete.
    """

    # keep instances as light as plain tuples (no per-instance __dict__)
    __slots__ = ()
+
+
class PopenWorker:
    """A subprocess worker via Popen.

    PopenWorker provides a low-level API to interact with a separate
    process via Popen. The subprocess is started lazily by ``send`` and
    exchanges length-prefixed cloudpickle messages with this process over
    a pair of pipes.
    """

    def __init__(self):
        self._proc = None
        # pipe ends owned by this process; created by _start
        self._reader = None
        self._writer = None

    def __del__(self):
        try:
            self.kill()
        except ImportError:
            # kill() imports psutil lazily; presumably that import can fail
            # while the interpreter is shutting down.
            pass

    def kill(self):
        """Kill the current running process and cleanup.

        Note
        ----
        The worker can start a new process when send is called again.
        """
        if self._proc is None:
            return

        proc = self._proc
        # allow gracefully shutdown: closing our pipe ends lets the
        # worker observe EOF on its reader
        for stream in (self._writer, self._reader):
            if stream is None:
                continue
            try:
                stream.close()
            except IOError:
                pass

        try:
            # kill all child processes recursively
            kill_child_processes(proc.pid)
        finally:
            # The direct child is always killed and forgotten, even when
            # the recursive step above raised (e.g. ImportError).
            try:
                proc.kill()
            except OSError:
                pass
            try:
                # reap the child so it does not linger as a zombie
                proc.wait(timeout=1.0)
            except (OSError, subprocess.TimeoutExpired):
                pass
            self._proc = None

    def _start(self):
        """Start a new subprocess if nothing is available"""
        if self._proc is not None:
            return

        # connect subprocess with a pair of pipes
        main_read, worker_write = os.pipe()
        worker_read, main_write = os.pipe()

        cmd = [sys.executable, "-m", "tvm.exec.popen_worker"]
        try:
            if sys.platform == "win32":
                # pylint: disable=import-outside-toplevel
                import msvcrt

                # Windows passes OS handles (not fds) to the child
                worker_read_handle = msvcrt.get_osfhandle(worker_read)
                worker_write_handle = msvcrt.get_osfhandle(worker_write)
                os.set_handle_inheritable(worker_read_handle, True)
                os.set_handle_inheritable(worker_write_handle, True)
                cmd += [str(worker_read_handle), str(worker_write_handle)]
                self._proc = subprocess.Popen(cmd, close_fds=False)
            else:
                cmd += [str(worker_read), str(worker_write)]
                self._proc = subprocess.Popen(cmd, pass_fds=(worker_read, worker_write))
        except Exception:
            # do not leak the pipe descriptors when the launch fails
            for fd in (main_read, worker_write, worker_read, main_write):
                os.close(fd)
            raise

        # close worker side of the pipe
        os.close(worker_read)
        os.close(worker_write)
        self._reader = os.fdopen(main_read, "rb")
        self._writer = os.fdopen(main_write, "wb")

    def send(self, fn, args=(), kwargs=None, timeout=None):
        """Send a new function task fn(*args, **kwargs) to the subprocess.

        Parameters
        ----------
        fn : function
            The function to be invoked.

        args : list
            Positional argument.

        kwargs : dict
            Keyword arguments

        timeout : float
            Timeout value when executing the function

        Note
        ----
        The caller must call recv before calling the next send in
        order to make sure the timeout and child process exit
        won't affect the later requests.
        """
        # use cloud pickle
        # pylint: disable=import-outside-toplevel
        import cloudpickle

        if self._proc is None:
            self._start()
        kwargs = {} if not kwargs else kwargs
        data = cloudpickle.dumps((fn, args, kwargs, timeout), protocol=pickle.HIGHEST_PROTOCOL)
        try:
            # message = 4-byte little-endian length followed by the payload
            self._writer.write(struct.pack("<i", len(data)))
            self._writer.write(data)
            self._writer.flush()
        except IOError:
            # Deliberately ignored: a broken pipe is reported by the
            # following recv(), which raises ChildProcessError.
            pass

    def _child_process_error(self):
        """Kill the worker and return the error for the caller to raise.

        Returns
        -------
        error : ChildProcessError
            The exception describing the terminated subprocess.
        """
        # kill and lazily restart the process in the next send.
        self.kill()
        return ChildProcessError("Subprocess terminated")

    def recv(self):
        """Receive the result of the last send.

        Returns
        -------
        result: object
            The result of the last send.

        Raises
        ------
        ChildProcessError: if the child process exited abnormally.
        TimeoutError: if timeout happens
        Exception: if other exception happens during the execution.
        """
        try:
            len_data = self._reader.read(4)
        except IOError as err:
            raise self._child_process_error() from err

        # Anything shorter than the 4-byte header means the pipe was
        # closed, possibly in the middle of a write.
        if len(len_data) != 4:
            raise self._child_process_error()

        recv_bytes = struct.unpack("<i", len_data)[0]
        if recv_bytes < 0:
            raise self._child_process_error()

        try:
            payload = self._reader.read(recv_bytes)
        except IOError as err:
            raise self._child_process_error() from err

        # a truncated payload also means the worker died while responding
        if len(payload) != recv_bytes:
            raise self._child_process_error()

        # pylint: disable=import-outside-toplevel
        import cloudpickle

        status, value = cloudpickle.loads(payload)

        if status == StatusKind.COMPLETE:
            return value
        if status == StatusKind.EXCEPTION:
            raise value
        assert status == StatusKind.TIMEOUT
        # kill and lazily restart the process in the next send.
        self.kill()
        raise TimeoutError()
+
+
+class PopenPoolExecutor:
+ """An parallel executor backed by Popen processes.
+
+ Parameters
+ ----------
+ max_worker : int
+ Maximum number of workers
+
+ timeout : float
+ Timeout value for each function submit.
+ """
+
+ def __init__(self, max_workers, timeout=None):
+ # Use an internal thread pool to send to popen workers
+ self._threadpool = concurrent.futures.ThreadPoolExecutor(max_workers=max_workers)
+ self._timeout = timeout
+ self._worker_map = {}
+ self._lock = threading.Lock()
+
+ def __del__(self):
+ self._lock.acquire()
+ for worker in self._worker_map.values():
+ try:
+ worker.kill()
+ except ImportError:
+ pass
+ self._lock.release()
+ self._threadpool.shutdown()
+
+ def _worker_run(self, fn, args, kwargs):
+ """Internal thread runner."""
+ self._lock.acquire()
+ tid = threading.get_ident()
+ if tid not in self._worker_map:
+ proc = PopenWorker()
+ self._worker_map[tid] = proc
+ else:
+ proc = self._worker_map[tid]
+ self._lock.release()
+
+ proc.send(fn, args, kwargs, self._timeout)
+ return proc.recv()
+
+ def _worker_run_with_error_catching(self, fn, args, kwargs) -> MapResult:
+ # pylint: disable=broad-except
+ try:
+ return MapResult(status=StatusKind.COMPLETE, value=self._worker_run(fn, args, kwargs))
+ except TimeoutError as exception:
+ return MapResult(status=StatusKind.TIMEOUT, value=exception)
+ except Exception as exception:
+ return MapResult(status=StatusKind.EXCEPTION, value=exception)
+
+ def submit(self, fn, *args, **kwargs) -> concurrent.futures.Future:
+ """Submit a new function job to the pool
+
+ Parameters
+ ----------
+ fn : function
+ The function to be invoked.
+
+ args : list
+ Positional argument.
+
+ kwargs : dict
+ Keyword arguments
+
+ Returns
+ -------
+ future : concurrent.futures.Future
+ A future that can be used to access the result.
+ """
+ # pylint: disable=unnecessary-lambda
+ worker = lambda *args: self._worker_run(*args)
+ return self._threadpool.submit(worker, fn, args, kwargs)
+
+ def map_with_error_catching(self, fn, iterator):
+ """Same as map, but catches exceptions and return them instead.
+
+ Parameters
+ ----------
+ fn : function
+ The function to be invoked.
+
+ iterator : Iterator
+ Input iterator.
+
+ Returns
+ -------
+ out_iter : Iterator[MapResult]
+ The result iterator.
+ """
+ worker = lambda x: self._worker_run_with_error_catching(fn, (x,), None)
+ return self._threadpool.map(worker, iterator)
diff --git a/python/tvm/exec/popen_worker.py b/python/tvm/exec/popen_worker.py
new file mode 100644
index 0000000..b62cca5
--- /dev/null
+++ b/python/tvm/exec/popen_worker.py
@@ -0,0 +1,104 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name
+"""Internal PopenWorker for PopenPool."""
+import sys
+import os
+import struct
+import threading
+import traceback
+import pickle
+import cloudpickle
+
+from tvm.contrib.popen_pool import StatusKind
+
+
class TimeoutStatus:
    """Mutable holder for the status of the task being executed.

    A fresh instance starts as ``StatusKind.RUNNING``; the ``status``
    attribute is then overwritten once a response has been sent.
    """

    # only a single attribute is ever stored on an instance
    __slots__ = ["status"]

    def __init__(self):
        self.status = StatusKind.RUNNING
+
+
def _exception_with_traceback(exception):
    """Return an exception whose message is the formatted traceback.

    Must be called while ``exception`` is being handled. The original type
    is kept when it can be built from a single string; otherwise a
    RuntimeError carrying the same message is returned.
    """
    msg = traceback.format_exc()
    # pylint: disable=broad-except
    try:
        return type(exception)(msg)
    except Exception:
        # some exception types need more than one constructor argument
        return RuntimeError(msg)


def main():
    """Main worker function

    Expects two command line arguments, the read and write pipe
    descriptors (OS handles on Windows). It then serves requests in a
    loop: each request is a 4-byte little-endian length followed by a
    cloudpickle payload ``(fn, args, kwargs, timeout)``, and each response
    is a length-prefixed cloudpickle payload ``(status, value)``.
    """
    if len(sys.argv) != 3:
        print("Usage: <read_fd> <write_fd>")
        return
    if sys.platform == "win32":
        # pylint: disable=import-outside-toplevel
        import msvcrt

        reader = os.fdopen(msvcrt.open_osfhandle(int(sys.argv[1]), os.O_BINARY), "rb")
        writer = os.fdopen(msvcrt.open_osfhandle(int(sys.argv[2]), os.O_BINARY), "wb")
    else:
        reader = os.fdopen(int(sys.argv[1]), "rb")
        writer = os.fdopen(int(sys.argv[2]), "wb")

    # serializes the main loop and the timeout timer so that exactly one
    # response is written per request
    lock = threading.Lock()

    def _respond(ret_value):
        """Send data back to the client."""
        data = cloudpickle.dumps(ret_value, protocol=pickle.HIGHEST_PROTOCOL)
        writer.write(struct.pack("<i", len(data)))
        writer.write(data)
        writer.flush()

    def _cancel_run(status):
        """Timer callback: report a timeout unless a response was already sent."""
        with lock:
            if status.status == StatusKind.RUNNING:
                _respond((StatusKind.TIMEOUT, TimeoutError()))
                status.status = StatusKind.TIMEOUT

    while True:
        raw_bytes_size = reader.read(4)
        if len(raw_bytes_size) != 4:
            # the parent exited
            return
        bytes_size = struct.unpack("<i", raw_bytes_size)[0]
        payload = reader.read(bytes_size)
        if len(payload) != bytes_size:
            # the pipe was closed in the middle of a request
            return
        fn, args, kwargs, timeout = cloudpickle.loads(payload)
        status = TimeoutStatus()

        watcher = None
        if timeout is not None:
            watcher = threading.Timer(timeout, _cancel_run, [status])
            watcher.daemon = True
            watcher.start()

        # pylint: disable=broad-except
        try:
            result = fn(*args, **kwargs)
            ret_value = (StatusKind.COMPLETE, result)
        except Exception as exception:
            ret_value = (StatusKind.EXCEPTION, _exception_with_traceback(exception))

        if watcher is not None:
            watcher.cancel()

        with lock:
            # skip the response when the timer already reported a timeout
            if status.status == StatusKind.RUNNING:
                _respond(ret_value)
                status.status = StatusKind.COMPLETE
+
+
if __name__ == "__main__":
    # Script entry point: run the worker loop and exit silently when
    # interrupted or when a pipe operation fails.
    try:
        main()
    except (KeyboardInterrupt, IOError):
        pass
diff --git a/python/tvm/testing.py b/python/tvm/testing.py
index e5b17f3..8311a63 100644
--- a/python/tvm/testing.py
+++ b/python/tvm/testing.py
@@ -56,6 +56,8 @@ function in this module. Then targets using this node should be added to the
"""
import logging
import os
+import sys
+import time
import pytest
import numpy as np
import tvm
@@ -714,4 +716,30 @@ def parametrize_targets(*args):
return wrap(args)
def identity_after(x, sleep):
    """Testing helper that hands back its input, optionally after a pause.

    Parameters
    ----------
    x : object
        The value to return.

    sleep : float
        Number of seconds to wait before returning; a falsy value
        (``0`` or ``None``) skips the wait.

    Returns
    -------
    x : object
        The value that was passed in, unchanged.
    """
    if not sleep:
        return x
    time.sleep(sleep)
    return x
+
+
def terminate_self():
    """Testing helper that exits the current process with status -1."""
    raise SystemExit(-1)
+
+
tvm._ffi._init_api("testing", __name__)
diff --git a/tests/python/contrib/test_popen_pool.py b/tests/python/contrib/test_popen_pool.py
new file mode 100644
index 0000000..6b5b367
--- /dev/null
+++ b/tests/python/contrib/test_popen_pool.py
@@ -0,0 +1,71 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Test PopenPoolExecutor."""
+import pytest
+import time
+from tvm.contrib.popen_pool import PopenWorker, PopenPoolExecutor
+from tvm.testing import identity_after, terminate_self
+
+
def test_popen_worker():
    """Drive one PopenWorker through a timeout, a crash and normal calls."""
    worker = PopenWorker()

    # a 100 s sleep with a 10 ms budget must surface as TimeoutError
    with pytest.raises(TimeoutError):
        worker.send(identity_after, [1, 100], timeout=0.01)
        worker.recv()

    # a task that exits the subprocess must surface as ChildProcessError
    with pytest.raises(ChildProcessError):
        worker.send(terminate_self)
        worker.recv()

    # the same worker object keeps serving requests after both failures
    for expected, delay in ((2, 0), (4, 0.0001)):
        worker.send(identity_after, [expected, delay])
        assert worker.recv() == expected
+
+
def test_popen_pool_executor():
    """Check submit() and map_with_error_catching() on PopenPoolExecutor."""
    import tvm

    pool = PopenPoolExecutor(max_workers=2, timeout=0.01)
    timed_out = pool.submit(identity_after, 1, 100)
    crashed = pool.submit(terminate_self)
    plain = pool.submit(identity_after, 3, 0)
    tvm_string = pool.submit(tvm.runtime.String, "xyz")

    # failures are re-raised from the corresponding futures
    with pytest.raises(TimeoutError):
        timed_out.result()

    with pytest.raises(ChildProcessError):
        crashed.result()

    # successful calls return their values, including tvm objects
    assert plain.result() == 3
    returned = tvm_string.result()
    assert isinstance(returned, tvm.runtime.String)
    assert returned == "xyz"

    # without a timeout, results come back in input order
    pool = PopenPoolExecutor(max_workers=4, timeout=None)
    outcomes = pool.map_with_error_catching(lambda x: x, range(100))

    for expected, outcome in enumerate(outcomes):
        assert outcome.value == expected
+
+
if __name__ == "__main__":
    # allow running the tests directly, in definition order, without pytest collection
    for case in (test_popen_worker, test_popen_pool_executor):
        case()