You are viewing a plain-text version of this content. (The link to its canonical version was lost in this plain-text rendering.)
Posted to commits@tvm.apache.org by co...@apache.org on 2021/01/06 02:04:51 UTC

[tvm] branch main updated: [Fix][Autoscheduler] Costmodel enhancement & bug fix for graph debug runtime (#7197)

This is an automated email from the ASF dual-hosted git repository.

comaniac pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 040afb0  [Fix][Autoscheduler] Costmodel enhancement & bug fix for graph debug runtime (#7197)
040afb0 is described below

commit 040afb0245526e1cc71dc0ada6c3c5787394a5c6
Author: Chenfan <ch...@alibaba-inc.com>
AuthorDate: Wed Jan 6 10:04:34 2021 +0800

    [Fix][Autoscheduler] Costmodel enhancement & bug fix for graph debug runtime (#7197)
    
    * Enhancement for autoscheduler cost model
    
    * Bug fix for graph_runtime_debug
    
    * Update
    
    * Lint fix
    
    * Update
    
    * Update
    
    * Add file exist check for cost model load
    
    * Update
    
    * Update
    
    * Lint fix
    
    * Update
    
    * Bug fix
---
 python/tvm/auto_scheduler/cost_model/xgb_model.py | 25 ++++++++++++++++++++++-
 python/tvm/auto_scheduler/task_scheduler.py       | 13 ++++++++++--
 src/auto_scheduler/feature.cc                     | 18 ++++++++++------
 3 files changed, 47 insertions(+), 9 deletions(-)

diff --git a/python/tvm/auto_scheduler/cost_model/xgb_model.py b/python/tvm/auto_scheduler/cost_model/xgb_model.py
index eb14dff..f426482 100644
--- a/python/tvm/auto_scheduler/cost_model/xgb_model.py
+++ b/python/tvm/auto_scheduler/cost_model/xgb_model.py
@@ -88,7 +88,14 @@ class XGBModel(PythonBasedModel):
     their predictions.
     """
 
-    def __init__(self, verbose_eval=25, num_warmup_sample=100, seed=None):
+    def __init__(
+        self,
+        verbose_eval=25,
+        num_warmup_sample=100,
+        seed=None,
+        model_file=None,
+        adapative_training=False,
+    ):
         global xgb
         try:
             if xgb is None:
@@ -116,12 +123,15 @@ class XGBModel(PythonBasedModel):
         self.plan_size = 32
         self.num_warmup_sample = num_warmup_sample
         self.verbose_eval = verbose_eval
+        self.model_file = model_file
+        self.adapative_training = adapative_training
 
         super().__init__()
 
         # cache measurement input/result pairs and extracted features
         self.inputs = []
         self.results = []
+        self.last_train_length = 0
         self.inputs_feature_cache = []
 
     def update(self, inputs, results):
@@ -141,6 +151,15 @@ class XGBModel(PythonBasedModel):
         self.inputs.extend(inputs)
         self.results.extend(results)
 
+        if (
+            self.adapative_training
+            and len(self.inputs) - self.last_train_length < self.last_train_length / 5
+        ):
+            # Set a training threshold related to `last_train_length` to reduce the training
+            # overhead when there're too many logs
+            return
+        self.last_train_length = len(self.inputs)
+
         # extract feature
         n_cached = len(self.inputs_feature_cache)
         features, normalized_throughputs, task_ids = get_per_store_features_from_measure_pairs(
@@ -176,6 +195,10 @@ class XGBModel(PythonBasedModel):
             ],
         )
 
+        # Update the model file if it has been set
+        if self.model_file:
+            self.save(self.model_file)
+
     def predict(self, task, states):
         """Predict the scores of states
         Parameters
diff --git a/python/tvm/auto_scheduler/task_scheduler.py b/python/tvm/auto_scheduler/task_scheduler.py
index ab83ff4..975306f 100644
--- a/python/tvm/auto_scheduler/task_scheduler.py
+++ b/python/tvm/auto_scheduler/task_scheduler.py
@@ -47,6 +47,7 @@ def make_search_policies(
     verbose,
     load_model_file=None,
     load_log_file=None,
+    adapative_training=False,
 ):
     """Make a list of search policies for a list of search tasks.
     It creates one policy per task.
@@ -70,6 +71,9 @@ def make_search_policies(
     load_log_file: Optional[str]
         Load measurement records from this file. If it is not None, the status of the
         task scheduler, search policies and cost models will be restored according to this file.
+    adapative_training: bool = False
+        Option used for XGBModel, which will reduce the model training frequency when there're too
+        many logs.
 
     Returns
     -------
@@ -82,11 +86,16 @@ def make_search_policies(
     if isinstance(search_policy, str):
         policy_type, model_type = search_policy.split(".")
         if model_type == "xgb":
-            cost_model = XGBModel(num_warmup_sample=len(tasks) * num_measures_per_round)
-            if load_model_file:
+            cost_model = XGBModel(
+                num_warmup_sample=len(tasks) * num_measures_per_round,
+                model_file=load_model_file,
+                adapative_training=adapative_training,
+            )
+            if load_model_file and os.path.isfile(load_model_file):
                 logger.info("TaskScheduler: Load pretrained model...")
                 cost_model.load(load_model_file)
             elif load_log_file:
+                logger.info("TaskScheduler: Reload measured states and train the model...")
                 cost_model.update_from_file(load_log_file)
         elif model_type == "random":
             cost_model = RandomModel()
diff --git a/src/auto_scheduler/feature.cc b/src/auto_scheduler/feature.cc
index 47b9fb6..a5d4958 100755
--- a/src/auto_scheduler/feature.cc
+++ b/src/auto_scheduler/feature.cc
@@ -1462,12 +1462,18 @@ void GetPerStoreFeaturesFromMeasurePairs(const Array<MeasureInput>& inputs,
     if (find_res == task_cache.end()) {
       if (inputs[i]->task->compute_dag.defined()) {  // the measure input is complete
         task = inputs[i]->task;
-      } else {  // the measure input is incomplete
-        // rebuild task for incomplete measure pairs read from file
-        Array<te::Tensor> tensors = (*workload_key_to_tensors)(workload_key);
-        task = SearchTask(ComputeDAG(tensors), workload_key, inputs[i]->task->target,
-                          inputs[i]->task->target_host, inputs[i]->task->hardware_params,
-                          inputs[i]->task->layout_rewrite_option);
+      } else {
+        // The measure input is incomplete, rebuild task for incomplete measure pairs read from file
+        try {
+          Array<te::Tensor> tensors = (*workload_key_to_tensors)(workload_key);
+          task = SearchTask(ComputeDAG(tensors), workload_key, inputs[i]->task->target,
+                            inputs[i]->task->target_host, inputs[i]->task->hardware_params,
+                            inputs[i]->task->layout_rewrite_option);
+        } catch (std::exception& e) {
+          // Cannot build ComputeDAG from workload key, the task may have not been registered in
+          // this search round
+          continue;
+        }
       }
       task_id = task_cache.size();