You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2020/10/26 12:30:28 UTC

[GitHub] [incubator-tvm] merrymercy commented on a change in pull request #6671: [FIX,AUTOSCHEDULER] Fix auto_scheduler to run with multiprocessing's spawn start method

merrymercy commented on a change in pull request #6671:
URL: https://github.com/apache/incubator-tvm/pull/6671#discussion_r511899520



##########
File path: python/tvm/auto_scheduler/measure.py
##########
@@ -51,19 +51,52 @@
 from .loop_state import StateObject
 from .utils import (
     get_const_tuple,
-    NoDaemonPool,
     call_func_with_timeout,
     request_remote,
     check_remote,
 )
+from .compute_dag import ComputeDAG
+from .search_task import SearchTask
+from .workload_registry import workload_name, get_workload
 
 # The maximum length of error message
 MAX_ERROR_MSG_LEN = 512
 
-# We use fork and a global variable to copy arguments between processes.
-# This can avoid expensive serialization of TVM IR when using multiprocessing.Pool
-GLOBAL_BUILD_ARGUMENTS = None
-GLOBAL_RUN_ARGUMENTS = None
+
+def recover_measure_input(inp, rebuild_state=False):
+    """
+    Recover a deserialized MeasureInput by rebuilding the missing fields.
+    1. Rebuid the compute_dag in inp.task
+    2. (Optional) Rebuild the stages in inp.state
+
+    Parameters
+    ----------
+    inp: MeasureInput
+        The deserialized MeasureInput
+    rebuild_state: bool = False
+        Whether rebuild the stages in MeasureInput.State
+
+    Returns
+    -------
+    new_input: MeasureInput
+        The fully recovered MeasureInput with all fields rebuilt.
+    """
+    task = inp.task
+    print(task.hardware_params)

Review comment:
       Please delete this debug print.

##########
File path: python/tvm/auto_scheduler/measure.py
##########
@@ -950,26 +1017,30 @@ def rpc_runner_run(
     res : List[MeasureResult]
         The measure results of these MeasureInputs.
     """
-    global GLOBAL_RUN_ARGUMENTS
-    GLOBAL_RUN_ARGUMENTS = (
-        inputs,
-        build_results,
-        key,
-        host,
-        port,
-        priority,
-        timeout,
-        number,
-        repeat,
-        min_repeat_ms,
-        cooldown_interval,
-        enable_cpu_cache_flush,
-        verbose,
-    )
-
     assert len(inputs) == len(build_results), "Measure input size should be equal to build results"
-    pool = NoDaemonPool(n_parallel)
-    tuple_res = pool.map(rpc_run_worker, range(len(build_results)))
+    # This pool is not doing computationally intensive work, so we can use threads

Review comment:
       Did you benchmark the speed of a process pool vs. `ThreadPool`?
   For the comment, do you mean "not doing computationally intensive work" or "not doing computationally intensive work in Python"?

##########
File path: python/tvm/auto_scheduler/utils.py
##########
@@ -169,17 +143,18 @@ def kill_child_processes(parent_pid, sig=signal.SIGTERM):
             return
 
 
+def _func_wrapper(que, func, args, kwargs):

Review comment:
       Could you add a docstring for this function?

##########
File path: python/tvm/auto_scheduler/measure.py
##########
@@ -590,14 +634,20 @@ def local_builder_build(inputs, timeout, n_parallel, build_func="default", verbo
     res : List[BuildResult]
         The build results of these MeasureInputs.
     """
-    # We use fork and a global variable to copy arguments between processes.
-    # This can avoid expensive serialization of TVM IR when using multiprocessing.Pool
-    global GLOBAL_BUILD_ARGUMENTS
-
-    GLOBAL_BUILD_ARGUMENTS = (inputs, build_func, timeout, verbose)
-
-    pool = NoDaemonPool(n_parallel)
-    tuple_res = pool.map(local_build_worker, range(len(inputs)))
+    # This pool is not doing computationally intensive work, so we can use threads
+    pool = multiprocessing.pool.ThreadPool(n_parallel)

Review comment:
       Do we still need serialization with `ThreadPool`?
   I'd expect all these arguments to be shared in memory and passed by reference.

##########
File path: python/tvm/auto_scheduler/measure.py
##########
@@ -87,6 +120,24 @@ def __init__(self, task, state):
         state = state if isinstance(state, StateObject) else state.state_object
         self.__init_handle_by_constructor__(_ffi_api.MeasureInput, task, state)
 
+    def serialize(self):

Review comment:
       Can we use `__getstate__` instead?

##########
File path: python/tvm/autotvm/task/task.py
##########
@@ -173,6 +172,8 @@ def __getstate__(self):
         # some unpickable local task functions.
         # So we only pickle the name of the function
         # and restore the function by name when unpickling it.
+        import cloudpickle  # pylint: disable=import-outside-toplevel

Review comment:
       Why not import it at the top level?

##########
File path: python/tvm/autotvm/tuner/xgboost_cost_model.py
##########
@@ -321,10 +316,11 @@ def _get_feature(self, indexes):
 
         indexes = np.array(indexes)
         need_extract = [x for x in indexes if x not in fea_cache]
+        args = [(self.space.get(x), self.target, self.task) for x in need_extract]

Review comment:
       This still requires serializing a lot of objects. Did you compare the performance before and after this change?

##########
File path: src/auto_scheduler/measure_record.cc
##########
@@ -107,18 +107,68 @@ struct Handler<::tvm::auto_scheduler::StateNode> {
   }
 };
 
+template <>
+struct Handler<::tvm::auto_scheduler::HardwareParamsNode> {
+  inline static void Write(dmlc::JSONWriter* writer,
+                           const ::tvm::auto_scheduler::HardwareParamsNode& data) {
+    writer->BeginArray(false);
+    writer->WriteArrayItem(data.num_cores);
+    writer->WriteArrayItem(data.vector_unit_bytes);
+    writer->WriteArrayItem(data.cache_line_bytes);
+    writer->WriteArrayItem(data.max_shared_memory_per_block);
+    writer->WriteArrayItem(data.max_registers_per_block);
+    writer->WriteArrayItem(data.max_threads_per_block);
+    writer->WriteArrayItem(data.max_vthread_extent);
+    writer->WriteArrayItem(data.warp_size);
+    writer->EndArray();
+  }
+  inline static void Read(dmlc::JSONReader* reader,
+                          ::tvm::auto_scheduler::HardwareParamsNode* data) {
+    bool s;
+    reader->BeginArray();
+    s = reader->NextArrayItem();
+    CHECK(s);
+    reader->Read(&data->num_cores);
+    s = reader->NextArrayItem();
+    CHECK(s);
+    reader->Read(&data->vector_unit_bytes);
+    s = reader->NextArrayItem();
+    CHECK(s);
+    reader->Read(&data->cache_line_bytes);
+    s = reader->NextArrayItem();
+    CHECK(s);
+    reader->Read(&data->max_shared_memory_per_block);
+    s = reader->NextArrayItem();
+    CHECK(s);
+    reader->Read(&data->max_registers_per_block);
+    s = reader->NextArrayItem();
+    CHECK(s);
+    reader->Read(&data->max_threads_per_block);
+    s = reader->NextArrayItem();
+    CHECK(s);
+    reader->Read(&data->max_vthread_extent);
+    s = reader->NextArrayItem();
+    CHECK(s);
+    reader->Read(&data->warp_size);
+    s = reader->NextArrayItem();
+    CHECK(!s);
+  }
+};
+
 template <>
 struct Handler<::tvm::auto_scheduler::SearchTaskNode> {
   inline static void Write(dmlc::JSONWriter* writer,
                            const ::tvm::auto_scheduler::SearchTaskNode& data) {
     writer->BeginArray(false);
     writer->WriteArrayItem(std::string(data.workload_key));
     writer->WriteArrayItem(data.target->str());
+    writer->WriteArrayItem(*data.hardware_params.get());

Review comment:
       Since you changed the format, please update the version number to `v0.3`:
   https://github.com/apache/incubator-tvm/blob/c6f18250e176a0f107481d489651f6abdfe00976/src/auto_scheduler/measure_record.cc#L219




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org