Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2021/02/05 16:50:46 UTC

[GitHub] [tvm] merrymercy commented on a change in pull request #7376: [AutoScheduler] Add sampling to dispatcher

merrymercy commented on a change in pull request #7376:
URL: https://github.com/apache/tvm/pull/7376#discussion_r571104145



##########
File path: python/tvm/auto_scheduler/dispatcher.py
##########
@@ -301,6 +306,85 @@ def update(self, target, workload_key, state):
             entry[workload_args] = (state, 1)
 
 
+class ApplyHistoryBestOrSample(ApplyHistoryBest):
+    """
+    Apply the history best config, or sample a valid schedule if no config is found.
+
+    Parameters
+    ----------
+    records : str or iterator of (auto_scheduler.measure.MeasureInput,\
+                                  auto_scheduler.measure.MeasureResult)
+        Collection of tuning records.
+        If it is a str, it should be the filename of a records log file,
+        in which each row is an encoded record pair; otherwise, it is an iterator.
+    sample_simple_workloads: bool
+        When False, sampling is not applied to simple workloads (those without reduction).
+    cost_model_file: str
+        The filename of a pre-trained XGBoost cost model. If not provided, a random
+        model will be used.
+    """
+
+    def __init__(self, records, sample_simple_workloads=False, cost_model_file=None):
+        self.sample_simple_workloads = sample_simple_workloads
+        self.log_dir = tempdir()
+        if cost_model_file is None:
+            self.cost_model = RandomModel()
+        else:
+            self.cost_model = XGBModel(num_warmup_sample=1, model_file=cost_model_file)

Review comment:
   The argument `model_file` does not actually load the model; it is only used to save the model. I think this argument is badly designed, and we can revisit it later.
   To load the model correctly, use
   ```
   self.cost_model = XGBModel()
   self.cost_model.load(cost_model_file)
   ```
   `num_warmup_sample` will be set correctly by `XGBModel.load`.
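   For context, a minimal sketch of the constructor branch with that fix applied (assuming `XGBModel.load` restores the trained model and `num_warmup_sample` from the given file, as noted above):
   ```
   if cost_model_file is None:
       self.cost_model = RandomModel()
   else:
       # XGBModel(model_file=...) only records where the model will be
       # saved; load() is what actually restores a pre-trained model.
       self.cost_model = XGBModel()
       self.cost_model.load(cost_model_file)
   ```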

##########
File path: python/tvm/auto_scheduler/dispatcher.py
##########
@@ -301,6 +306,85 @@ def update(self, target, workload_key, state):
             entry[workload_args] = (state, 1)
 
 
+class ApplyHistoryBestOrSample(ApplyHistoryBest):
+    """
+    Apply the history best config, or sample a valid schedule if no config is found.
+
+    Parameters
+    ----------
+    records : str or iterator of (auto_scheduler.measure.MeasureInput,\
+                                  auto_scheduler.measure.MeasureResult)
+        Collection of tuning records.
+        If it is a str, it should be the filename of a records log file,
+        in which each row is an encoded record pair; otherwise, it is an iterator.
+    sample_simple_workloads: bool
+        When False, sampling is not applied to simple workloads (those without reduction).
+    cost_model_file: str
+        The filename of a pre-trained XGBoost cost model. If not provided, a random
+        model will be used.
+    """
+
+    def __init__(self, records, sample_simple_workloads=False, cost_model_file=None):
+        self.sample_simple_workloads = sample_simple_workloads
+        self.log_dir = tempdir()
+        if cost_model_file is None:
+            self.cost_model = RandomModel()
+        else:
+            self.cost_model = XGBModel(num_warmup_sample=1, model_file=cost_model_file)
+
+        super(ApplyHistoryBestOrSample, self).__init__(
+            records, n_lines=None, include_compatible=True
+        )
+
+    def query(self, target, workload_key, has_complex_op, dag):
+        if has_complex_op or self.sample_simple_workloads:
+            ret = self._query_inside(target, workload_key)
+        else:
+            ret = super(ApplyHistoryBestOrSample, self)._query_inside(target, workload_key)
+
+        if ret is None:
+            ret = self._old_ctx.query(target, workload_key, has_complex_op, dag)
+        return ret
+
+    def _query_inside(self, target, workload_key):
+        ret = super(ApplyHistoryBestOrSample, self)._query_inside(target, workload_key)
+        if ret is not None:
+            return ret
+
+        # Sample valid schedules when no existing records can be used.
+        task = SearchTask(workload_key=workload_key, target=target)
+        measure_ctx = LocalRPCMeasureContext(min_repeat_ms=300)
+
+        log_file = self.log_dir.relpath("%s.log" % decode_workload_key(workload_key)[0])
+
+        while ret is None:
+            tune_option = TuningOptions(
+                num_measure_trials=2,

Review comment:
   This will run measurements. Is this expected? We could expose `num_measure_trials` as an argument so users can set it to -1 if they do not want to run measurements.
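   A sketch of that suggestion (illustrative only; the runner/callback arguments are assumptions here, since the hunk above is truncated before the rest of the `TuningOptions` call):
   ```
   # In __init__, accept and store the trial count, e.g.
   # def __init__(self, records, sample_simple_workloads=False,
   #              cost_model_file=None, num_measure_trials=2):
   #     self.num_measure_trials = num_measure_trials

   # Then in _query_inside, thread it through to the tuning options:
   tune_option = TuningOptions(
       num_measure_trials=self.num_measure_trials,
       runner=measure_ctx.runner,
       measure_callbacks=[RecordToFile(log_file)],
   )
   ```
   With `num_measure_trials=-1`, the search would then skip hardware measurements entirely, as suggested above.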




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org