Posted to commits@tvm.apache.org by ju...@apache.org on 2020/12/20 21:08:14 UTC

[tvm] branch main updated: [CONTRIB] PopenPoolExecutor (#6959)

This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 37af2d7  [CONTRIB] PopenPoolExecutor (#6959)
37af2d7 is described below

commit 37af2d741d3efb37ca6aba261db8b78583dbc1cd
Author: Tianqi Chen <tq...@users.noreply.github.com>
AuthorDate: Sun Dec 20 16:07:58 2020 -0500

    [CONTRIB] PopenPoolExecutor (#6959)
    
    PopenPoolExecutor implements a ProcessPoolExecutor backed by popen.
    
    - Only handles invoking functions in the tvm namespace.
    - Unlike multiprocessing, it does not require a __main__ block,
      which means it can run directly in a Jupyter notebook.
    - Comes with timeout and fault-tolerance support: it can time out
      long-running jobs and restart the worker process when an error happens.
    
    Recommended usage: create a pool once and reuse it in a long-running
    job (e.g. autotuning) so that worker processes are reused when
    possible, as in the sketch below.
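    A minimal usage sketch (the pool parameters and the square helper
    below are illustrative, not part of this commit):
    
        from tvm.contrib.popen_pool import PopenPoolExecutor, StatusKind
    
        def square(x):
            return x * x
    
        # Create the pool once and reuse it across submissions.
        pool = PopenPoolExecutor(max_workers=4, timeout=5.0)
    
        # submit() returns a concurrent.futures.Future.
        assert pool.submit(square, 3).result() == 9
    
        # map_with_error_catching() wraps each result in a MapResult,
        # so a timeout or an exception does not abort the whole iteration.
        for res in pool.map_with_error_catching(square, range(8)):
            if res.status == StatusKind.COMPLETE:
                print(res.value)
    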
---
 python/tvm/contrib/popen_pool.py        | 329 ++++++++++++++++++++++++++++++++
 python/tvm/exec/popen_worker.py         | 104 ++++++++++
 python/tvm/testing.py                   |  28 +++
 tests/python/contrib/test_popen_pool.py |  71 +++++++
 4 files changed, 532 insertions(+)

diff --git a/python/tvm/contrib/popen_pool.py b/python/tvm/contrib/popen_pool.py
new file mode 100644
index 0000000..bca0862
--- /dev/null
+++ b/python/tvm/contrib/popen_pool.py
@@ -0,0 +1,329 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name
+"""Multiprocessing via Popen.
+
+This module provides a multi-processing pool backed by Popen,
+with additional timeout support.
+"""
+import os
+import sys
+import struct
+import threading
+import subprocess
+import concurrent.futures
+from enum import IntEnum
+from collections import namedtuple
+import pickle
+
+
+def kill_child_processes(pid):
+    """Kill all child processes recursively for a given pid.
+
+    Parameters
+    ----------
+    pid : int
+        The given process id.
+    """
+    # pylint: disable=import-outside-toplevel
+    import psutil
+
+    try:
+        parent = psutil.Process(pid)
+    except psutil.NoSuchProcess:
+        return
+
+    for process in parent.children(recursive=True):
+        try:
+            process.kill()
+        except psutil.NoSuchProcess:
+            pass
+
+
+class StatusKind(IntEnum):
+    """Running and return value status."""
+
+    RUNNING = 0
+    COMPLETE = 1
+    EXCEPTION = 2
+    TIMEOUT = 3
+
+
+class MapResult(namedtuple("MapResult", ["status", "value"])):
+    """Result of map_with_error_catching.
+
+    Parameters
+    ----------
+    status : StatusKind
+        The status of the result.
+
+    value : Any
+        The result value.
+    """
+
+    __slots__ = []
+
+
+class PopenWorker:
+    """A subprocess worker via Popen.
+
+    PopenWorker provides a low-level
+    API to interact with a separate process via Popen.
+    """
+
+    def __init__(self):
+        self._proc = None
+
+    def __del__(self):
+        try:
+            self.kill()
+        except ImportError:
+            pass
+
+    def kill(self):
+        """Kill the current running process and cleanup.
+
+        Note
+        ----
+        The worker can start a new process when send is called again.
+        """
+        if self._proc is not None:
+            # allow graceful shutdown
+            try:
+                self._writer.close()
+            except IOError:
+                pass
+            try:
+                self._reader.close()
+            except IOError:
+                pass
+            # kill all child processes recursively
+            kill_child_processes(self._proc.pid)
+            try:
+                self._proc.kill()
+            except OSError:
+                pass
+            self._proc = None
+
+    def _start(self):
+        """Start a new subprocess if nothing is available"""
+        if self._proc is not None:
+            return
+
+        # connect subprocess with a pair of pipes
+        main_read, worker_write = os.pipe()
+        worker_read, main_write = os.pipe()
+
+        cmd = [sys.executable, "-m", "tvm.exec.popen_worker"]
+        if sys.platform == "win32":
+            # pylint: disable=import-outside-toplevel
+            import msvcrt
+
+            worker_read_handle = msvcrt.get_osfhandle(worker_read)
+            worker_write_handle = msvcrt.get_osfhandle(worker_write)
+            os.set_handle_inheritable(worker_read_handle, True)
+            os.set_handle_inheritable(worker_write_handle, True)
+            cmd += [str(worker_read_handle), str(worker_write_handle)]
+            self._proc = subprocess.Popen(cmd, close_fds=False)
+        else:
+            cmd += [str(worker_read), str(worker_write)]
+            self._proc = subprocess.Popen(cmd, pass_fds=(worker_read, worker_write))
+
+        # close worker side of the pipe
+        os.close(worker_read)
+        os.close(worker_write)
+        self._reader = os.fdopen(main_read, "rb")
+        self._writer = os.fdopen(main_write, "wb")
+
+    def send(self, fn, args=(), kwargs=None, timeout=None):
+        """Send a new function task fn(*args, **kwargs) to the subprocess.
+
+        Parameters
+        ----------
+        fn : function
+            The function to be invoked.
+
+        args : list
+            Positional arguments.
+
+        kwargs : dict
+            Keyword arguments.
+
+        timeout : float
+            Timeout value when executing the function.
+
+        Note
+        ----
+        The caller must call recv before calling the next send in
+        order to make sure that a timeout or a child-process exit
+        does not affect later requests.
+        """
+        # use cloud pickle
+        # pylint: disable=import-outside-toplevel
+        import cloudpickle
+
+        if self._proc is None:
+            self._start()
+        kwargs = {} if not kwargs else kwargs
+        data = cloudpickle.dumps((fn, args, kwargs, timeout), protocol=pickle.HIGHEST_PROTOCOL)
+        try:
+            self._writer.write(struct.pack("<i", len(data)))
+            self._writer.write(data)
+            self._writer.flush()
+        except IOError:
+            pass
+
+    def _child_process_error(self):
+        """Raise a child process error."""
+        # kill and lazily restart the process in the next send.
+        self.kill()
+        return ChildProcessError("Subprocess terminated")
+
+    def recv(self):
+        """Receive the result of the last send.
+
+        Returns
+        -------
+        result: object
+            The result of the last send.
+
+        Raises
+        ------
+        ChildProcessError: if the child process exited abnormally.
+        TimeoutError: if a timeout happens.
+        Exception: if any other exception happens during the execution.
+        """
+        # pylint: disable=import-outside-toplevel
+        import cloudpickle
+
+        try:
+            len_data = self._reader.read(4)
+        except IOError:
+            raise self._child_process_error()
+
+        if len(len_data) == 0:
+            raise self._child_process_error()
+
+        try:
+            recv_bytes = struct.unpack("<i", len_data)[0]
+            status, value = cloudpickle.loads(self._reader.read(recv_bytes))
+        except IOError:
+            raise self._child_process_error()
+
+        if status == StatusKind.COMPLETE:
+            return value
+        if status == StatusKind.EXCEPTION:
+            raise value
+        assert status == StatusKind.TIMEOUT
+        # kill and lazily restart the process in the next send.
+        self.kill()
+        raise TimeoutError()
+
+
+class PopenPoolExecutor:
+    """An parallel executor backed by Popen processes.
+
+    Parameters
+    ----------
+    max_workers : int
+        Maximum number of workers.
+
+    timeout : float
+        Timeout value for each function submit.
+    """
+
+    def __init__(self, max_workers, timeout=None):
+        # Use an internal thread pool to send to popen workers
+        self._threadpool = concurrent.futures.ThreadPoolExecutor(max_workers=max_workers)
+        self._timeout = timeout
+        self._worker_map = {}
+        self._lock = threading.Lock()
+
+    def __del__(self):
+        self._lock.acquire()
+        for worker in self._worker_map.values():
+            try:
+                worker.kill()
+            except ImportError:
+                pass
+        self._lock.release()
+        self._threadpool.shutdown()
+
+    def _worker_run(self, fn, args, kwargs):
+        """Internal thread runner."""
+        self._lock.acquire()
+        tid = threading.get_ident()
+        if tid not in self._worker_map:
+            proc = PopenWorker()
+            self._worker_map[tid] = proc
+        else:
+            proc = self._worker_map[tid]
+        self._lock.release()
+
+        proc.send(fn, args, kwargs, self._timeout)
+        return proc.recv()
+
+    def _worker_run_with_error_catching(self, fn, args, kwargs) -> MapResult:
+        # pylint: disable=broad-except
+        try:
+            return MapResult(status=StatusKind.COMPLETE, value=self._worker_run(fn, args, kwargs))
+        except TimeoutError as exception:
+            return MapResult(status=StatusKind.TIMEOUT, value=exception)
+        except Exception as exception:
+            return MapResult(status=StatusKind.EXCEPTION, value=exception)
+
+    def submit(self, fn, *args, **kwargs) -> concurrent.futures.Future:
+        """Submit a new function job to the pool
+
+        Parameters
+        ----------
+        fn : function
+            The function to be invoked.
+
+        args : list
+            Positional arguments.
+
+        kwargs : dict
+            Keyword arguments.
+
+        Returns
+        -------
+        future : concurrent.futures.Future
+            A future that can be used to access the result.
+        """
+        # pylint: disable=unnecessary-lambda
+        worker = lambda *args: self._worker_run(*args)
+        return self._threadpool.submit(worker, fn, args, kwargs)
+
+    def map_with_error_catching(self, fn, iterator):
+        """Same as map, but catches exceptions and return them instead.
+
+        Parameters
+        ----------
+        fn : function
+            The function to be invoked.
+
+        iterator : Iterator
+            Input iterator.
+
+        Returns
+        -------
+        out_iter : Iterator[MapResult]
+            The result iterator.
+        """
+        worker = lambda x: self._worker_run_with_error_catching(fn, (x,), None)
+        return self._threadpool.map(worker, iterator)
diff --git a/python/tvm/exec/popen_worker.py b/python/tvm/exec/popen_worker.py
new file mode 100644
index 0000000..b62cca5
--- /dev/null
+++ b/python/tvm/exec/popen_worker.py
@@ -0,0 +1,104 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name
+"""Internal PopenWorker for PopenPool."""
+import sys
+import os
+import struct
+import threading
+import traceback
+import pickle
+import cloudpickle
+
+from tvm.contrib.popen_pool import StatusKind
+
+
+class TimeoutStatus:
+    __slots__ = ["status"]
+
+    def __init__(self):
+        self.status = StatusKind.RUNNING
+
+
+def main():
+    """Main worker function"""
+    if len(sys.argv) != 3:
+        print("Usage: <read_fd> <write_fd>")
+        return
+    if sys.platform == "win32":
+        # pylint: disable=import-outside-toplevel
+        import msvcrt
+
+        reader = os.fdopen(msvcrt.open_osfhandle(int(sys.argv[1]), os.O_BINARY), "rb")
+        writer = os.fdopen(msvcrt.open_osfhandle(int(sys.argv[2]), os.O_BINARY), "wb")
+    else:
+        reader = os.fdopen(int(sys.argv[1]), "rb")
+        writer = os.fdopen(int(sys.argv[2]), "wb")
+
+    lock = threading.Lock()
+
+    def _respond(ret_value):
+        """Send data back to the client."""
+        data = cloudpickle.dumps(ret_value, protocol=pickle.HIGHEST_PROTOCOL)
+        writer.write(struct.pack("<i", len(data)))
+        writer.write(data)
+        writer.flush()
+
+    def _cancel_run(status):
+        lock.acquire()
+        if status.status == StatusKind.RUNNING:
+            _respond((StatusKind.TIMEOUT, TimeoutError()))
+            status.status = StatusKind.TIMEOUT
+        lock.release()
+
+    while True:
+        raw_bytes_size = reader.read(4)
+        if len(raw_bytes_size) != 4:
+            # the parent exited
+            return
+        bytes_size = struct.unpack("<i", raw_bytes_size)[0]
+        fn, args, kwargs, timeout = cloudpickle.loads(reader.read(bytes_size))
+        status = TimeoutStatus()
+
+        if timeout is not None:
+            watcher = threading.Timer(timeout, _cancel_run, [status])
+            watcher.daemon = True
+            watcher.start()
+
+        # pylint: disable=broad-except
+        try:
+            result = fn(*args, **kwargs)
+            ret_value = (StatusKind.COMPLETE, result)
+        except Exception as exception:
+            msg = traceback.format_exc()
+            ret_value = (StatusKind.EXCEPTION, type(exception)(msg))
+
+        if timeout is not None:
+            watcher.cancel()
+
+        lock.acquire()
+        if status.status == StatusKind.RUNNING:
+            _respond(ret_value)
+            status.status = StatusKind.COMPLETE
+        lock.release()
+
+
+if __name__ == "__main__":
+    try:
+        main()
+    except (KeyboardInterrupt, IOError):
+        pass
diff --git a/python/tvm/testing.py b/python/tvm/testing.py
index e5b17f3..8311a63 100644
--- a/python/tvm/testing.py
+++ b/python/tvm/testing.py
@@ -56,6 +56,8 @@ function in this module. Then targets using this node should be added to the
 """
 import logging
 import os
+import sys
+import time
 import pytest
 import numpy as np
 import tvm
@@ -714,4 +716,30 @@ def parametrize_targets(*args):
     return wrap(args)
 
 
+def identity_after(x, sleep):
+    """Testing function to return identity after sleep
+
+    Parameters
+    ----------
+    x : int
+        The input value.
+
+    sleep : float
+        The amount of time to sleep
+
+    Returns
+    -------
+    x : object
+        The original value
+    """
+    if sleep:
+        time.sleep(sleep)
+    return x
+
+
+def terminate_self():
+    """Testing function to terminate the process."""
+    sys.exit(-1)
+
+
 tvm._ffi._init_api("testing", __name__)
diff --git a/tests/python/contrib/test_popen_pool.py b/tests/python/contrib/test_popen_pool.py
new file mode 100644
index 0000000..6b5b367
--- /dev/null
+++ b/tests/python/contrib/test_popen_pool.py
@@ -0,0 +1,71 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Test PopenPoolExecutor."""
+import pytest
+import time
+from tvm.contrib.popen_pool import PopenWorker, PopenPoolExecutor
+from tvm.testing import identity_after, terminate_self
+
+
+def test_popen_worker():
+    proc = PopenWorker()
+
+    with pytest.raises(TimeoutError):
+        proc.send(identity_after, [1, 100], timeout=0.01)
+        proc.recv()
+
+    with pytest.raises(ChildProcessError):
+        proc.send(terminate_self)
+        proc.recv()
+
+    proc.send(identity_after, [2, 0])
+    assert proc.recv() == 2
+
+    proc.send(identity_after, [4, 0.0001])
+    assert proc.recv() == 4
+
+
+def test_popen_pool_executor():
+    import tvm
+
+    pool = PopenPoolExecutor(max_workers=2, timeout=0.01)
+    value1 = pool.submit(identity_after, 1, 100)
+    value2 = pool.submit(terminate_self)
+    value3 = pool.submit(identity_after, 3, 0)
+    value4 = pool.submit(tvm.runtime.String, "xyz")
+
+    with pytest.raises(TimeoutError):
+        value1.result()
+
+    with pytest.raises(ChildProcessError):
+        value2.result()
+
+    assert value3.result() == 3
+    value = value4.result()
+    assert isinstance(value, tvm.runtime.String)
+    assert value == "xyz"
+
+    pool = PopenPoolExecutor(max_workers=4, timeout=None)
+    values = pool.map_with_error_catching(lambda x: x, range(100))
+
+    for idx, val in enumerate(values):
+        assert val.value == idx
+
+
+if __name__ == "__main__":
+    test_popen_worker()
+    test_popen_pool_executor()