Posted to commits@tvm.apache.org by lm...@apache.org on 2020/10/19 06:58:25 UTC

[incubator-tvm] branch main updated: [AutoScheduler] Add task scheduler (#6663)

This is an automated email from the ASF dual-hosted git repository.

lmzheng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 94679b5  [AutoScheduler] Add task scheduler (#6663)
94679b5 is described below

commit 94679b5cf46bd89872ae611995ff43f6dae78786
Author: Lianmin Zheng <li...@gmail.com>
AuthorDate: Sun Oct 18 23:55:40 2020 -0700

    [AutoScheduler] Add task scheduler (#6663)
    
    * Add task scheduler
    
    * fix lint
    
    * fix tests
    
    * fix tests
    
    * fix tests
    
    * fix test cases
    
    * fix test cases
    
    * fix tests
    
    * address comments
---
 include/tvm/auto_scheduler/measure.h               |  22 +-
 include/tvm/auto_scheduler/search_policy.h         |  31 +-
 python/tvm/auto_scheduler/__init__.py              |   2 +
 python/tvm/auto_scheduler/auto_schedule.py         |  11 +-
 python/tvm/auto_scheduler/cost_model/xgb_model.py  |   6 +-
 python/tvm/auto_scheduler/measure.py               |  27 ++
 python/tvm/auto_scheduler/search_policy.py         |  51 ++-
 python/tvm/auto_scheduler/task_scheduler.py        | 422 +++++++++++++++++++++
 python/tvm/auto_scheduler/utils.py                 |  47 +++
 src/auto_scheduler/auto_schedule.cc                |   4 +-
 src/auto_scheduler/feature.cc                      |   4 +-
 src/auto_scheduler/measure.cc                      |  22 +-
 src/auto_scheduler/search_policy/empty_policy.cc   |  39 +-
 src/auto_scheduler/search_policy/empty_policy.h    |  14 +-
 src/auto_scheduler/search_policy/search_policy.cc  |   9 +-
 src/auto_scheduler/search_policy/sketch_policy.cc  |  67 +++-
 src/auto_scheduler/search_policy/sketch_policy.h   |  13 +-
 .../search_policy/sketch_policy_rules.cc           |   9 +-
 .../search_policy/sketch_policy_rules.h            |   3 +-
 .../unittest/test_auto_scheduler_task_scheduler.py | 112 ++++++
 tutorials/auto_scheduler/tune_conv2d_layer_cuda.py |   7 +-
 tutorials/auto_scheduler/tune_matmul_x86.py        |   7 +-
 22 files changed, 819 insertions(+), 110 deletions(-)

diff --git a/include/tvm/auto_scheduler/measure.h b/include/tvm/auto_scheduler/measure.h
index 349f4f8..339f428 100755
--- a/include/tvm/auto_scheduler/measure.h
+++ b/include/tvm/auto_scheduler/measure.h
@@ -423,7 +423,7 @@ class RPCRunner : public ProgramRunner {
 
 /*!
  * \brief Measurer that measures the time costs of tvm programs
- * This class combines ProgramBuilder and ProgramRunner and provides a simpler API */
+ * This class combines ProgramBuilder and ProgramRunner, and provides a simpler API */
 class ProgramMeasurerNode : public Object {
  public:
   /*! \brief Measured programs counter. */
@@ -444,7 +444,7 @@ class ProgramMeasurerNode : public Object {
   Optional<Array<MeasureCallback>> callbacks;
   /*! \brief Verbosity level. 0 for silent, 1 to output information during program measuring. */
   int verbose;
-  /*! \brief The number of max continuous error. */
+  /*! \brief The maximum number of allowed continuous errors before forcibly stopping the tuning */
   int max_continuous_error;
 
   /*! \brief Reset book keeping variables */
@@ -454,13 +454,12 @@ class ProgramMeasurerNode : public Object {
    * \brief Do measurement.
    * \param task The current SearchTask.
    * \param policy The current SearchPolicy.
-   * \param inputs The MeasureInputs.
-   * \param results A pointer to a MeasureResult Array, this is used as output.
+   * \param inputs The inputs of measurement.
    * \param batch_size Number of programs to be measured in one batch.
+   * \return results The results of measurement.
    */
-  void Measure(const SearchTask& task, const SearchPolicy& policy,
-               const Array<MeasureInput>& inputs, Array<MeasureResult>* results,
-               int batch_size = -1);
+  Array<MeasureResult> Measure(const SearchTask& task, const SearchPolicy& policy,
+                               const Array<MeasureInput>& inputs, int batch_size = -1);
   /*!
    * \brief Do measurement silently.
    * This API will not print the measure results to screen.
@@ -486,12 +485,13 @@ class ProgramMeasurer : public ObjectRef {
  public:
   /*!
    * \brief The constructor.
-   * \param builder The ProgramBuilder to build each program.
-   * \param runner The ProgramRunner to measure each program.
-   * \param callbacks MeasureCallback to be called after each measure batch.
+   * \param builder The ProgramBuilder to build programs.
+   * \param runner The ProgramRunner to measure programs.
+   * \param callbacks MeasureCallback to be called after each measurement batch.
    * \param verbose Verbosity level. 0 for silent, 1 to output information during program
    * measuring.
-   * \param max_continuous_error The number of allowed maximum continuous error.
+   * \param max_continuous_error The maximum number of allowed continuous errors before
+   * forcibly stopping the tuning.
    */
   ProgramMeasurer(ProgramBuilder builder, ProgramRunner runner,
                   Optional<Array<MeasureCallback>> callbacks, int verbose,
diff --git a/include/tvm/auto_scheduler/search_policy.h b/include/tvm/auto_scheduler/search_policy.h
index ddb0dd2..e433799 100755
--- a/include/tvm/auto_scheduler/search_policy.h
+++ b/include/tvm/auto_scheduler/search_policy.h
@@ -22,26 +22,6 @@
  * \brief The base class of search policies, including the abstract definition of search policy and
  * other supporting data structures.
  *
- * The basic schedule search process for the auto-scheduler is design to be:
- * `Program sampling` -> `Performance Tuning`.
- *
- * In `Program sampling`, we use some predefined precise or heuristic rules to generate several
- * initial schedules. Based on these initial starting points, we perform `Performance Tuning` which
- * uses cost model based evolutionary search to select schedules with the best performance.
- *
- * Candidate schedules are measured against the specific hardware target.
- *
- * We intend to introduce different level of automation on the schedule generation process:
- * - Level 0(the default level): For all kinds of ops/subgraphs, the search policy should be able
- *   to generate schedule automatically.
- * - Level 1: For some complicated ops/subgraphs(e.g. conv2d windograd), the default search space
- *   of level 0 may be too large to find a high performance schedule efficiently. We provide some
- *   op attributes to help reduce the total search space, see `SearchPolicyKey` below for more
- *   information.
- * - Level 2: For some further special ops/subgraphs, users may more likely to write their own
- *   template(just like AutoTVM). Search policy should be able to provide a flexible approach as
- *   well.
- *
  * \note How to add a new search policy.
  * In design, there's no need for users to implement their own search policy, our formal search
  * policy(will be brought later) should be enough to cover most use cases. Meanwhile, a custom rule
@@ -62,11 +42,13 @@
 #ifndef TVM_AUTO_SCHEDULER_SEARCH_POLICY_H_
 #define TVM_AUTO_SCHEDULER_SEARCH_POLICY_H_
 
+#include <tvm/auto_scheduler/measure.h>
 #include <tvm/auto_scheduler/search_task.h>
 #include <tvm/node/node.h>
 
 #include <string>
 #include <unordered_set>
+#include <utility>
 #include <vector>
 
 namespace tvm {
@@ -172,6 +154,15 @@ class SearchPolicyNode : public Object {
                        ProgramMeasurer measurer) = 0;
 
   /*!
+   * \brief Continue the search by doing an additional search round.
+   * \param num_measure The number of measurements
+   * \param measurer The measurer to measure programs
+   * \return The measurement records for measurements in this search round
+   */
+  virtual std::pair<Array<MeasureInput>, Array<MeasureResult>> ContinueSearchOneRound(
+      int num_measure, ProgramMeasurer measurer) = 0;
+
+  /*!
    * \brief Preload measured states from a log file to resume the state of the search policy.
    * \param log_file The name of the record log file.
    */
diff --git a/python/tvm/auto_scheduler/__init__.py b/python/tvm/auto_scheduler/__init__.py
index 6a395e7..99d96e8 100644
--- a/python/tvm/auto_scheduler/__init__.py
+++ b/python/tvm/auto_scheduler/__init__.py
@@ -24,6 +24,7 @@ from . import measure
 from . import measure_record
 from . import search_policy
 from . import search_task
+from . import task_scheduler
 from . import utils
 from . import workload_registry
 
@@ -42,4 +43,5 @@ from .measure import (
 from .measure_record import RecordToFile, RecordReader, load_best, load_records, save_records
 from .search_task import SearchTask
 from .search_policy import EmptyPolicy, SketchPolicy, PreloadMeasuredStates
+from .task_scheduler import TaskScheduler
 from .workload_registry import register_workload, make_workload_key
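For readers following the new API below, here is a hedged sketch (not part of the patch) of how the SearchTask objects consumed by the task scheduler are typically created; the matmul workload body mirrors the tune_matmul_x86.py tutorial touched by this commit, and the shapes are illustrative.

# Illustrative only: a registered workload and a SearchTask, as used by the
# task scheduler and tests added in this commit.
from tvm import te, auto_scheduler

@auto_scheduler.register_workload
def matmul(N, L, M, dtype):
    A = te.placeholder((N, L), name="A", dtype=dtype)
    B = te.placeholder((L, M), name="B", dtype=dtype)
    k = te.reduce_axis((0, L), name="k")
    C = te.compute((N, M), lambda i, j: te.sum(A[i, k] * B[k, j], axis=k), name="C")
    return [A, B, C]

task = auto_scheduler.create_task(matmul, (128, 128, 128, "float32"), "llvm")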
diff --git a/python/tvm/auto_scheduler/auto_schedule.py b/python/tvm/auto_scheduler/auto_schedule.py
index ca069bb..a53c29d 100644
--- a/python/tvm/auto_scheduler/auto_schedule.py
+++ b/python/tvm/auto_scheduler/auto_schedule.py
@@ -16,16 +16,7 @@
 # under the License.
 
 """
-User interface for TVM Auto-scheduler.
-
-The basic schedule search process for TVM Auto-scheduler is designed to be:
-`Program sampling` -> `Performance Tuning`.
-
-In `Program sampling`, we use some predefined precise or heuristic rules to generate several
-initial schedules. Based on these initial starting points, we perform `Performance Tuning` which
-uses cost model based evolutionary search to select schedules with the best performance.
-
-Candidate schedules are measured against the specific hardware target.
+The user interface and tuning options of the TVM auto-scheduler.
 """
 
 import tvm._ffi
diff --git a/python/tvm/auto_scheduler/cost_model/xgb_model.py b/python/tvm/auto_scheduler/cost_model/xgb_model.py
index 3eb64df..9a534aa 100644
--- a/python/tvm/auto_scheduler/cost_model/xgb_model.py
+++ b/python/tvm/auto_scheduler/cost_model/xgb_model.py
@@ -20,6 +20,7 @@
 import multiprocessing
 import logging
 from collections import defaultdict
+import time
 
 import numpy as np
 import xgboost as xgb
@@ -76,7 +77,7 @@ dmatrix_context = XGBDMatrixContext()
 class XGBModel(PythonBasedModel):
     """Train a XGBoost model to predict the normalized throughputs of programs.
     Let the normalized throughput be the score of a program (higher is better). We predict
-    the (approximiate) score of a program = the sum of the scores of all stages in this program.
+    the (approximate) score of a program = the sum of the scores of all stages in this program.
     i.e. score(P) = score_s0 + score_s1 + ... + score_sn,
     where score_si is the score of Stage i in Program P.
     We extract feature for each stage and let the xgboost predict the score for each stage.
@@ -128,6 +129,7 @@ class XGBModel(PythonBasedModel):
         if len(inputs) <= 0:
             return
         assert len(inputs) == len(results)
+        tic = time.time()
 
         self.inputs.extend(inputs)
         self.results.extend(results)
@@ -167,6 +169,8 @@ class XGBModel(PythonBasedModel):
             ],
         )
 
+        logger.info("XGBModel Training time: %.2f s", time.time() - tic)
+
     def predict(self, task, states):
         """Predict the scores of states
         Parameters
diff --git a/python/tvm/auto_scheduler/measure.py b/python/tvm/auto_scheduler/measure.py
index 81c314f..8a8b922 100644
--- a/python/tvm/auto_scheduler/measure.py
+++ b/python/tvm/auto_scheduler/measure.py
@@ -185,6 +185,33 @@ class ProgramRunner(Object):
         return _ffi_api.ProgramRunnerRun(self, measure_inputs, build_results, verbose)
 
 
+@tvm._ffi.register_object("auto_scheduler.ProgramMeasurer")
+class ProgramMeasurer(Object):
+    """
+    Measurer that measures the time costs of tvm programs
+    This class combines ProgramBuilder and ProgramRunner, and provides a simpler API.
+
+    Parameters
+    ----------
+    builder : ProgramBuilder
+        The ProgramBuilder to build programs
+    runner : ProgramRunner
+        The ProgramRunner to measure programs.
+    callbacks : List[MeasureCallback]
+        Callbacks to be called after each measurement batch
+    verbose : int
+        The verbosity level: 0 for silent, 1 to output information during program measuring.
+    max_continuous_error : Optional[int]
+        The maximum number of allowed continuous errors before stopping the tuning.
+    """
+
+    def __init__(self, builder, runner, callbacks, verbose, max_continuous_error=None):
+        max_continuous_error = max_continuous_error or -1  # -1 means using the default value
+        self.__init_handle_by_constructor__(
+            _ffi_api.ProgramMeasurer, builder, runner, callbacks, verbose, max_continuous_error
+        )
+
+
 @tvm._ffi.register_object("auto_scheduler.LocalBuilder")
 class LocalBuilder(ProgramBuilder):
     """LocalBuilder use local CPU cores to build programs in parallel.
diff --git a/python/tvm/auto_scheduler/search_policy.py b/python/tvm/auto_scheduler/search_policy.py
index 5533aec..f3d459e 100644
--- a/python/tvm/auto_scheduler/search_policy.py
+++ b/python/tvm/auto_scheduler/search_policy.py
@@ -16,15 +16,17 @@
 # under the License.
 
 """
-The search policies for TVM Auto-scheduler.
+The search policies of TVM auto-scheduler.
 
-This contains the strategies to generate a schedule automatically. We provide an EmptyPolicy
-which always returns an unchanged initial state, and a more advanced SketchPolicy which can
-deal with various ops/subgraphs on different target devices.
+The auto-scheduler constructs a search space according to the compute declaration.
+It then randomly samples programs from the search space and uses evolutionary search with a
+learned cost model to fine tune the sampled programs.
+The final optimized programs are sent to actual hardware for measurement.
+The above process is repeated until the auto-scheduler runs out of time budget.
 
 Reference:
 L. Zheng, C. Jia, M. Sun, Z. Wu, C. Yu, et al. "Ansor : Generating High-Performance Tensor
-Programs for Deep Learning." arXiv preprint arXiv:2006.06762 (2020).
+Programs for Deep Learning." (OSDI 2020).
 """
 
 import random
@@ -63,11 +65,42 @@ class PreloadMeasuredStates(SearchCallback):
 class SearchPolicy(Object):
     """ The base class of search policies. """
 
+    def continue_search_one_round(self, num_measure, measurer):
+        """
+        Continue the search by doing an additional search round.
+
+        Parameters
+        ----------
+        num_measure: int
+            The number of programs to measure in this round
+        measurer: ProgramMeasurer
+            The program measurer to measure programs
+
+        Returns
+        -------
+        inputs: List[MeasureInput]
+            The inputs of measurements in this search round
+        results: List[MeasureResult]
+            The results of measurements in this search round
+        """
+        return _ffi_api.SearchPolicyContinueSearchOneRound(self, num_measure, measurer)
+
+    def set_verbose(self, verbose):
+        """
+        Set the verbosity level of the search policy.
+
+        Parameters
+        ----------
+        verbose: int
+            The verbosity level
+        """
+        return _ffi_api.SearchPolicySetVerbose(self, verbose)
+
 
 @tvm._ffi.register_object("auto_scheduler.EmptyPolicy")
 class EmptyPolicy(SearchPolicy):
-    """This is an example empty search policy which will always generate
-    the init state of ComputeDAG.
+    """A simple example of the search policy which always returns
+    the initial naive schedule (state).
 
     Parameters
     ----------
@@ -195,15 +228,17 @@ class SketchPolicy(SearchPolicy):
         return states
 
     def evolutionary_search(self, init_populations, out_size):
-        """Evolutionary search.
+        """Perform evolutionary search.
         This python interface is mainly used for debugging and testing.
         The actual search is all done in c++.
+
         Parameters
         ----------
         init_populations: List[State]
             The initial population states
         out_size : int
             The size of generated states
+
         Returns
         -------
         states: List[State]
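A hedged sketch (not part of the patch) of the new continue_search_one_round API from the driver's point of view; `task` and `measurer` are assumed to be the SearchTask and ProgramMeasurer from the sketches above, and the trial count is illustrative.

# Illustrative only: advance one policy by a single search round, as the new
# TaskScheduler does internally, and read the measured latencies back.
from tvm import auto_scheduler
from tvm.auto_scheduler.cost_model import XGBModel
from tvm.auto_scheduler.utils import array_mean

policy = auto_scheduler.SketchPolicy(task, XGBModel(), verbose=1)
inputs, results = policy.continue_search_one_round(num_measure=8, measurer=measurer)
best_latency = min(array_mean(res.costs) for res in results)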
diff --git a/python/tvm/auto_scheduler/task_scheduler.py b/python/tvm/auto_scheduler/task_scheduler.py
new file mode 100644
index 0000000..e45573b
--- /dev/null
+++ b/python/tvm/auto_scheduler/task_scheduler.py
@@ -0,0 +1,422 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name
+
+""" The task scheduler that allocates the time resources when tuning multiple tasks together
+
+The details of the "gradient" strategy below can be found in the section 6 of this paper:
+L. Zheng, C. Jia, M. Sun, Z. Wu, C. Yu, et al. "Ansor : Generating High-Performance Tensor
+Programs for Deep Learning." (OSDI 2020).
+"""
+
+import time
+import math
+import logging
+
+import numpy as np
+
+from .search_policy import SearchPolicy, SketchPolicy
+from .cost_model import RandomModel, XGBModel
+from .utils import array_mean, to_str_round
+from .measure import ProgramMeasurer
+from .measure_record import RecordReader
+
+logger = logging.getLogger("auto_scheduler")
+
+
+def make_search_policies(
+    search_policy, tasks, num_measures_per_round, verbose, load_model_file=None, load_log_file=None
+):
+    """Make a list of search policies for a list of search tasks.
+    It creates one policy per task.
+
+    Parameters
+    ----------
+    search_policy: Union[str, List[SearchPolicy]]
+        The name of the search policy, or a list of already constructed SearchPolicy objects (one per task).
+    tasks: List[SearchTask]
+        The list of all tasks
+    num_measures_per_round: int
+        The number of schedules to be measured at each search round.
+        This should be the same as `TuningOptions.num_measures_per_round`
+    verbose: int
+        The verbosity level. 0 for silent.
+    load_model_file: Optional[str]
+        Load pre-trained model from this file. If this is None, the cost model will
+        be trained from scratch.
+    load_log_file: Optional[str]
+        Load measurement records from this file. If it is not None, the status of the
+        task scheduler, search policies and cost models will be restored according to this file.
+
+    Returns
+    -------
+    policies: List[SearchPolicy]
+        The list of search policies
+    """
+    if search_policy == "default":
+        search_policy = "sketch.xgb"
+
+    if isinstance(search_policy, str):
+        policy_type, model_type = search_policy.split(".")
+        if model_type == "xgb":
+            cost_model = XGBModel(num_warmup_sample=len(tasks) * num_measures_per_round)
+            if load_model_file:
+                logger.info("Load pretrained model...")
+                cost_model.load(load_model_file)
+            elif load_log_file:
+                cost_model.load_log_file(load_log_file)
+        elif model_type == "random":
+            cost_model = RandomModel()
+        else:
+            raise ValueError("Invalid search policy: " + search_policy)
+
+        if policy_type == "sketch":
+            search_policies = [SketchPolicy(task, cost_model, verbose=verbose) for task in tasks]
+        else:
+            raise ValueError("Invalid search policy: " + search_policy)
+    else:
+        # check type
+        assert isinstance(search_policy, (tuple, list))
+        for item in search_policy:
+            assert isinstance(item, SearchPolicy)
+        search_policies = search_policy
+
+    return search_policies
+
+
+def derive_similarity_tag(dag, log_base=1.618):
+    """Derive the tag for similarity check from one computational DAG.
+    The DAGs with the same tag are considered similar tasks.
+
+    The tag format is <op1-tag>_<op2-tag> ... <log(flop)>.
+
+    If the tag is "", then the task is not considered to be similar to any other tasks.
+
+    Parameters
+    ----------
+    dag: ComputeDAG
+        The input computational DAG
+    log_base: float = 1.618
+        The base of log to normalize FLOPS
+
+    Returns
+    -------
+    tag: str
+        The tag of this computational DAG.
+    """
+    ret = ""
+    for op in dag.ops:
+        tag = op.attrs.get("auto_scheduler_task_scheduler_tag", None)
+        if tag:
+            ret += op.attrs["auto_scheduler_task_scheduler_tag"] + "_"
+    if ret:
+        ret += "%d" % int(math.log(dag.flop_ct + 1, log_base))
+    return ret
+
+
+class TaskScheduler:
+    """
+    Allocate the time resources when tuning multiple tasks together.
+    This implements two strategies: "round-robin" and "gradient".
+
+    Parameters
+    ----------
+    tasks: List[SearchTask]
+        All tasks to tune
+    objective_func: Optional[Callable[List[float] -> float]]
+        The objective function to be minimized.
+        The objective function accepts the current latencies of all tasks and returns the
+        objective. If not provided, the objective defaults to the sum of the latencies of all tasks.
+    strategy: str = "gradient"
+        The scheduling strategy.
+        "round-robin": Tune tasks in round robin order.
+        "gradient" : Tune tasks with gradient descent.
+    load_model_file: Optional[str]
+        Load pre-trained model from this file. If this is None, the cost model will
+        be trained from scratch.
+    load_log_file: Optional[str]
+        Load measurement records from this file. If it is not None, the status of the
+        task scheduler, search policies and cost models will be restored according to this file.
+    verbose: int = 1
+        The level of verbosity. 0 means silent.
+    alpha: float = 0.2
+        The parameter used for the 'gradient' strategy: the weight of the backward gradient
+        relative to the forward gradient.
+    beta: float = 2
+        The parameter used for the 'gradient' strategy: the tolerance used when comparing a
+        task against the fastest task in its similarity group.
+    backward_window_size: int = 3
+        The parameter used for the 'gradient' strategy: the number of previous tuning rounds
+        used to estimate the backward gradient.
+    """
+
+    def __init__(
+        self,
+        tasks,
+        objective_func=None,
+        strategy="gradient",
+        load_model_file: str = None,
+        load_log_file: str = None,
+        verbose: int = 1,
+        alpha: float = 0.2,
+        beta: float = 2,
+        gamma: float = 0.5,
+        backward_window_size: int = 3,
+    ):
+        self.tasks = tasks
+        self.objective_func = objective_func or sum
+        self.strategy = strategy
+        self.verbose = verbose
+        self.load_log_file = load_log_file
+        self.load_model_file = load_model_file
+        self.alpha = alpha
+        self.beta = beta
+        self.gamma = gamma
+        self.backward_window_size = backward_window_size
+
+        assert len(self.tasks) != 0, "No tasks"
+        assert self.strategy in ["round-robin", "gradient"]
+
+        # task_cts[i] saves how many times task i is tuned
+        self.task_cts = [0 for _ in range(len(self.tasks))]
+
+        # task_costs_history[i] saves the latency history of task i
+        self.task_costs_history = [[] for _ in range(len(self.tasks))]
+
+        # best_costs[i] saves the best latency of task i
+        self.best_costs = 1e10 * np.ones(len(self.tasks))
+        self.cur_score = self._compute_score(self.best_costs)
+
+        self.tune_option = self.measurer = self.search_policies = self.ct = self.tic = None
+        self.num_measures_per_round = None
+        self.dead_tasks = set()
+
+        # Build similarity groups
+        self.task_tags = []  # task_id -> tag
+        self.tag_to_group_id = {}  # tag -> group_id
+        self.group_task_ids = []  # group_id -> all task ids in this group
+        self.flop_cts = []  # task_id -> the number of floating ops
+        for i, task in enumerate(self.tasks):
+            tag = derive_similarity_tag(task.compute_dag)
+            self.task_tags.append(tag)
+            self.flop_cts.append(task.compute_dag.flop_ct)
+            if not tag:
+                continue
+
+            if tag not in self.tag_to_group_id:
+                self.tag_to_group_id[tag] = len(self.tag_to_group_id)
+                self.group_task_ids.append([])
+            self.group_task_ids[self.tag_to_group_id[tag]].append(i)
+
+    def tune(self, tune_option, search_policy="default"):
+        """Tune a batch of tasks together.
+
+        Parameters
+        ----------
+        tune_option: TuningOptions
+            The options of tuning
+        search_policy: Union[str, List[SearchPolicy]]
+            The list of search policies.
+            If it is a string, it should be one of:
+            "sketch.xgb" for SketchPolicy + XGBModel
+            "sketch.random" for SketchPolicy + RandomModel
+        """
+        # init members
+        self.tune_option = tune_option
+        self.measurer = ProgramMeasurer(
+            tune_option.builder,
+            tune_option.runner,
+            tune_option.measure_callbacks,
+            tune_option.verbose,
+        )
+        self.ct = 0
+        self.tic = time.time()
+        # reset num_measures_per_round to make sure every task is tuned at least once
+        self.num_measures_per_round = min(
+            tune_option.num_measures_per_round, tune_option.num_measure_trials // len(self.tasks)
+        )
+        if self.num_measures_per_round <= 0:
+            raise ValueError("num_measure_trials is too small. Please set it to a higher value.")
+
+        # restore the status of the task scheduler from a log file
+        if self.load_log_file:
+            self._restore_status(self.load_log_file, self.num_measures_per_round)
+
+        # make one search policy for one task
+        self.search_policies = make_search_policies(
+            search_policy,
+            self.tasks,
+            self.num_measures_per_round,
+            tune_option.verbose,
+            self.load_model_file,
+            self.load_log_file,
+        )
+
+        # do a round robin first to warm up
+        for i in range(len(self.tasks)):
+            self._tune_task(i)
+
+        # use the specific strategy to choose workload to tune
+        task_idx = -1
+        while self.ct < tune_option.num_measure_trials and len(self.dead_tasks) < len(self.tasks):
+            if self.strategy == "round-robin":
+                task_idx = (task_idx + 1) % len(self.tasks)
+                while task_idx in self.dead_tasks:
+                    task_idx = (task_idx + 1) % len(self.tasks)
+            elif self.strategy == "gradient":
+                gradients = []
+                for i in range(len(self.tasks)):
+                    if i in self.dead_tasks:
+                        gradients.append(0)
+                        continue
+
+                    # compute gradient from chain rule : (delta f / delta g_i)
+                    delta = 1e-7
+                    new_costs = list(self.best_costs)
+                    new_costs[i] -= delta
+                    chain_grad = (
+                        self._compute_score(self.best_costs) - self._compute_score(new_costs)
+                    ) / delta
+
+                    # compute (g_i(t_i) - g(t_i - \Delta t)) / (\Delta t)
+                    if (
+                        self.task_cts[i] - 1 < len(self.task_costs_history[i])
+                        and self.task_cts[i] - 1 - self.backward_window_size >= 0
+                    ):
+                        backward_grad = (
+                            self.task_costs_history[i][self.task_cts[i] - 1]
+                            - self.task_costs_history[i][
+                                self.task_cts[i] - 1 - self.backward_window_size
+                            ]
+                        ) / self.backward_window_size
+                    else:
+                        backward_grad = 0
+
+                    # compute (g_i(t_i + \Delta t) - g(t_i)) / (\Delta t)
+                    g_next_1 = self.best_costs[i] - (self.best_costs[i] / self.task_cts[i])
+
+                    g_next_2 = self.beta * 1e30
+                    group_id = self.tag_to_group_id.get(self.task_tags[i], None)
+                    if group_id is not None and len(self.group_task_ids[group_id]) > 1:
+                        best_flops = max(
+                            [
+                                self.flop_cts[j] / self.best_costs[j]
+                                for j in self.group_task_ids[group_id]
+                            ]
+                        )
+                        g_next_2 = self.beta * self.flop_cts[i] / best_flops
+
+                    g_next = min(g_next_1, g_next_2)
+                    forward_grad = g_next - self.best_costs[i]
+
+                    # combine all grads
+                    grad = chain_grad * (
+                        self.alpha * backward_grad + (1 - self.alpha) * forward_grad
+                    )
+                    assert grad <= 0
+                    gradients.append(grad)
+
+                if max(gradients) == min(gradients):
+                    task_idx = np.random.choice(len(gradients))
+                else:
+                    task_idx = np.argmin(gradients)
+            else:
+                raise ValueError("Invalid strategy: " + self.strategy)
+
+            self._tune_task(task_idx)
+            self._adjust_similarity_group(task_idx)
+
+    def _tune_task(self, task_idx):
+        """Tune the select task for one round"""
+        if self.verbose >= 1:
+            logger.info("TaskScheduler: task id:\t%d", task_idx)
+        measure_inputs, measure_results = self.search_policies[task_idx].continue_search_one_round(
+            self.num_measures_per_round, self.measurer
+        )
+
+        for res in measure_results:
+            cost = array_mean(res.costs)
+            if cost < self.best_costs[task_idx]:
+                self.best_costs[task_idx] = cost
+
+        if len(measure_inputs) == 0:
+            self.dead_tasks.add(task_idx)
+
+        self.task_cts[task_idx] += 1
+        self.task_costs_history[task_idx].append(self.best_costs[task_idx])
+
+        self.ct += len(measure_inputs)
+        self.cur_score = self._compute_score(self.best_costs)
+
+        if self.verbose >= 1:
+            logger.info(
+                "TaskScheduler\tct: %d\testimated cost (ms): %.3f\ttime elapsed: %.2f\t"
+                "best_costs (ms): %s\ttask_ct: %s",
+                self.ct,
+                self.cur_score * 1e3,
+                time.time() - self.tic,
+                to_str_round(self.best_costs * 1e3, decimal=3),
+                self.task_cts,
+            )
+
+    def _compute_score(self, costs):
+        """compute the objective function"""
+        return self.objective_func(costs)
+
+    def _adjust_similarity_group(self, task_idx):
+        """adjust the similarity group for the selected task"""
+        group_id = self.tag_to_group_id.get(self.task_tags[task_idx], None)
+        if group_id is None or len(self.group_task_ids[group_id]) <= 1:
+            return
+
+        group_ids = self.group_task_ids[group_id]
+        best_group_flops = max([self.flop_cts[j] / self.best_costs[j] for j in group_ids])
+        cur_flops = self.flop_cts[task_idx] / self.best_costs[task_idx]
+
+        # If we have tuned a task many times but it still cannot achieve
+        # a speed similar to the fastest one in its group, this task
+        # is actually not similar to the other tasks in its group.
+        # So we remove it from its original group.
+        if cur_flops < best_group_flops / self.beta and self.task_cts[task_idx] > 5 + max(
+            self.task_cts[j] for j in group_ids if j != task_idx
+        ):
+            self.task_tags[task_idx] = None
+            group_ids.remove(task_idx)
+
+    def _restore_status(self, log_file, num_measures_per_round):
+        """restore task_cts and best_costs from a log file"""
+        str_target = str(self.tasks[0].target)
+        workload_key_to_task_id = {t.workload_key: i for i, t in enumerate(self.tasks)}
+        total_ct = -1
+
+        for total_ct, (inp, res) in enumerate(RecordReader(log_file)):
+            if str(inp.task.target) != str_target:
+                continue
+            task_idx = workload_key_to_task_id.get(inp.task.workload_key, None)
+            if task_idx is None:
+                continue
+
+            if res.error_no == 0:
+                self.best_costs[task_idx] = min(self.best_costs[task_idx], array_mean(res.costs))
+
+            self.task_cts[task_idx] += 1
+
+        for i in range(len(self.tasks)):
+            # The computation of task_cts is just an estimation.
+            # The estimation may not be accurate if the log file is changed externally or
+            # `num_measures_per_round` is different from the last tuning.
+            self.task_cts[i] = int(self.task_cts[i] / num_measures_per_round + 0.5)
+            self.task_costs_history[i].append(self.best_costs[i])
+
+        logger.info("TaskScheduler: Loaded %d measurement records from %s", total_ct + 1, log_file)
diff --git a/python/tvm/auto_scheduler/utils.py b/python/tvm/auto_scheduler/utils.py
index ff357c4..75fec9c 100644
--- a/python/tvm/auto_scheduler/utils.py
+++ b/python/tvm/auto_scheduler/utils.py
@@ -25,6 +25,8 @@ import signal
 import threading
 import os
 
+import numpy as np
+
 try:
     import psutil
 except ImportError:
@@ -264,3 +266,48 @@ def check_remote(device_key, host=None, port=None, priority=100, timeout=10):
     t.start()
     t.join(timeout)
     return not t.is_alive()
+
+
+def array_mean(arr):
+    """Compute mean of the elments in a TVM Array<PrimExpr>
+
+    Parameters
+    ----------
+    arr: Array
+        A TVM Array<PrimExpr>
+
+    Returns
+    -------
+    mean: float
+        The mean of the elements in the array
+    """
+    return sum(x.value for x in arr) / len(arr)
+
+
+def to_str_round(x, decimal=6):
+    """Convert an object to str and round float numbers
+
+    Parameters
+    ----------
+    x: Union[str, list, int, float, np.ndarray]
+        The input object
+    decimal: int
+        The number of decimal places to keep
+
+    Returns
+    -------
+    ret: str
+        The string format of these objects
+    """
+    if isinstance(x, str):
+        return x
+    if isinstance(x, (list, tuple, np.ndarray)):
+        return "[" + ", ".join([to_str_round(y, decimal=decimal) for y in x]) + "]"
+    if isinstance(x, dict):
+        return str({k: to_str_round(v) for k, v in x.items()})
+    if isinstance(x, int):
+        return str(x)
+    if isinstance(x, (np.float32, np.float64, float)):
+        format_str = "%%.%df" % decimal
+        return format_str % x
+    raise ValueError("Invalid value: " + str(x) + "\ttype: " + str(type(x)))
diff --git a/src/auto_scheduler/auto_schedule.cc b/src/auto_scheduler/auto_schedule.cc
index dd6b705..747aa01 100755
--- a/src/auto_scheduler/auto_schedule.cc
+++ b/src/auto_scheduler/auto_schedule.cc
@@ -19,9 +19,7 @@
 
 /*!
  * \file auto_scheduler/auto_schedule.cc
- * \brief The user interface of the TVM Auto-scheduler. This is the entry structure to get
- * schedule search requirements from upper level (Python API), and returns a high performance
- * schedule after search process.
+ * \brief The user interface and tuning options of the TVM auto-scheduler.
  */
 
 #include <tvm/auto_scheduler/auto_schedule.h>
diff --git a/src/auto_scheduler/feature.cc b/src/auto_scheduler/feature.cc
index 2744e0d..15066a9 100755
--- a/src/auto_scheduler/feature.cc
+++ b/src/auto_scheduler/feature.cc
@@ -871,7 +871,9 @@ class PerStoreFeatureExtractor : public StmtExprVisitor {
         stride = (i == static_cast<int>(for_loop_stack_.size()) - 1 ? stride : 0);
 
         float n_continuous = ele_bytes;
-        for (int i = static_cast<int>(tmp_region.size()) - 1; i >= 0; i--) {
+        for (int i = std::min(static_cast<int>(tmp_region.size()) - 1,
+                              static_cast<int>(int_shape.size()) - 1);
+             i >= 0; i--) {
           if (tmp_region[i] == int_shape[i]) {
             n_continuous *= tmp_region[i];
             break;
diff --git a/src/auto_scheduler/measure.cc b/src/auto_scheduler/measure.cc
index 70ea7ab..c3ee6a1 100755
--- a/src/auto_scheduler/measure.cc
+++ b/src/auto_scheduler/measure.cc
@@ -38,6 +38,7 @@ TVM_REGISTER_NODE_TYPE(MeasureResultNode);
 TVM_REGISTER_OBJECT_TYPE(MeasureCallbackNode);
 TVM_REGISTER_OBJECT_TYPE(ProgramRunnerNode);
 TVM_REGISTER_OBJECT_TYPE(ProgramBuilderNode);
+TVM_REGISTER_OBJECT_TYPE(ProgramMeasurerNode);
 TVM_REGISTER_OBJECT_TYPE(LocalBuilderNode);
 TVM_REGISTER_OBJECT_TYPE(LocalRunnerNode);
 TVM_REGISTER_OBJECT_TYPE(RPCRunnerNode);
@@ -204,11 +205,12 @@ void ProgramMeasurerNode::Reset() {
   best_state.clear();
 }
 
-void ProgramMeasurerNode::Measure(const SearchTask& task, const SearchPolicy& policy,
-                                  const Array<MeasureInput>& inputs, Array<MeasureResult>* results,
-                                  int batch_size) {
-  results->clear();
-  results->reserve(inputs.size());
+Array<MeasureResult> ProgramMeasurerNode::Measure(const SearchTask& task,
+                                                  const SearchPolicy& policy,
+                                                  const Array<MeasureInput>& inputs,
+                                                  int batch_size) {
+  Array<MeasureResult> results;
+  results.reserve(inputs.size());
 
   if (batch_size == -1) {
     // set default batch size
@@ -261,13 +263,15 @@ void ProgramMeasurerNode::Measure(const SearchTask& task, const SearchPolicy& po
 
     // Store result batch
     for (auto& res : result_batch) {
-      results->push_back(res);
+      results.push_back(res);
     }
 
     if (error_ct > max_continuous_error) {
       LOG(FATAL) << "Too many errors happened during tuning";
     }
   }
+
+  return results;
 }
 
 void ProgramMeasurerNode::SilentMeasure(const SearchTask& task, const Array<MeasureInput>& inputs,
@@ -343,6 +347,12 @@ TVM_REGISTER_GLOBAL("auto_scheduler.MeasureResult")
       return MeasureResult(costs, error_no, error_msg, all_cost, timestamp);
     });
 
+TVM_REGISTER_GLOBAL("auto_scheduler.ProgramMeasurer")
+    .set_body_typed([](ProgramBuilder builder, ProgramRunner runner,
+                       Array<MeasureCallback> callbacks, int verbose, int max_continuous_error) {
+      return ProgramMeasurer(builder, runner, callbacks, verbose, max_continuous_error);
+    });
+
 TVM_REGISTER_GLOBAL("auto_scheduler.ProgramBuilderBuild")
     .set_body_typed([](const ProgramBuilder& builder, const Array<MeasureInput>& inputs,
                        int verbose) { return builder->Build(inputs, verbose); });
diff --git a/src/auto_scheduler/search_policy/empty_policy.cc b/src/auto_scheduler/search_policy/empty_policy.cc
index 21a68ac..fba1ac2 100644
--- a/src/auto_scheduler/search_policy/empty_policy.cc
+++ b/src/auto_scheduler/search_policy/empty_policy.cc
@@ -19,7 +19,8 @@
 
 /*!
  * \file auto_scheduler/search_policy/empty_policy.cc
- * \brief This is an brief example of search policy.
+ * \brief A simple example of the search policy which always returns the initial naive schedule
+ * (state).
  */
 
 #include "empty_policy.h"
@@ -29,6 +30,8 @@
 
 #include <utility>
 
+#include "utils.h"
+
 namespace tvm {
 namespace auto_scheduler {
 
@@ -64,19 +67,18 @@ State EmptyPolicyNode::Search(int num_measure_trials, int early_stopping,
     measurer->Reset();
     int ct = 0;
     // In each round, we call SearchOneRound to get several candidate states,
-    // then use ProgramMeasurer to test their performance
+    // then use ProgramMeasurer to measure their performance.
     while (ct < num_measure_trials) {
       const auto& res = SearchOneRound();
       ct += res.size();
       // Build MeasureInputs for measuring
       inputs.clear();
       for (const auto& state : res) {
-        // The class members measured_states_set_ provided by SearchPolicy can be used to filter
-        // out the already measured states
         inputs.push_back(MeasureInput(search_task, state));
       }
+      // Perform measurement.
       // ProgramMeasurer will record the state with best performance during measure process
-      measurer->Measure(search_task, GetRef<SearchPolicy>(this), inputs, &results);
+      results = measurer->Measure(search_task, GetRef<SearchPolicy>(this), inputs);
     }
 
     // Return a state with best measured performance
@@ -84,18 +86,33 @@ State EmptyPolicyNode::Search(int num_measure_trials, int early_stopping,
   }
 }
 
+std::pair<Array<MeasureInput>, Array<MeasureResult>> EmptyPolicyNode::ContinueSearchOneRound(
+    int num_measure, ProgramMeasurer measurer) {
+  Array<State> best_states;
+  Array<MeasureInput> inputs;
+  Array<MeasureResult> results;
+
+  // Search one round to get promising states
+  PrintTitle("Search", verbose);
+  best_states = SearchOneRound();
+
+  // Measure these states
+  PrintTitle("Measure", verbose);
+  for (const auto& state : best_states) {
+    inputs.push_back(MeasureInput(search_task, state));
+  }
+  results = measurer->Measure(search_task, GetRef<SearchPolicy>(this), inputs);
+
+  return std::make_pair(std::move(inputs), std::move(results));
+}
+
 // As an example policy, EmptyPolicy always returns a init state
 Array<State> EmptyPolicyNode::SearchOneRound() {
   Array<State> res;
 
-  // 1. We will process `Program sampling` first to generate several initial schedules
+  // Simply return the initial naive schedule (state).
   res.push_back(search_task->compute_dag->init_state);
 
-  // 2. Then `Performance Tuning`: use cost model and evolutionary search to seek for the schedule
-  // with best performance
-  // Note: This example policy does not include this part
-
-  // 3. The returned candidate schedules will be measured in hardware
   return res;
 }
 
diff --git a/src/auto_scheduler/search_policy/empty_policy.h b/src/auto_scheduler/search_policy/empty_policy.h
index 3d13822..2219ebc 100644
--- a/src/auto_scheduler/search_policy/empty_policy.h
+++ b/src/auto_scheduler/search_policy/empty_policy.h
@@ -19,7 +19,7 @@
 
 /*!
  * \file auto_scheduler/search_policy/empty_policy.h
- * \brief A brief example of the search policy which always returns the initial naive schedule
+ * \brief A simple example of the search policy which always returns the initial naive schedule
  * (state).
  */
 
@@ -27,14 +27,17 @@
 #define TVM_AUTO_SCHEDULER_SEARCH_POLICY_EMPTY_POLICY_H_
 
 #include <tvm/auto_scheduler/loop_state.h>
+#include <tvm/auto_scheduler/measure.h>
 #include <tvm/auto_scheduler/search_policy.h>
 
+#include <utility>
+
 namespace tvm {
 namespace auto_scheduler {
 
 /*!
- * \brief A brief example of the search policy which always returns the initial naive schedule
- * (state), the formal search policy will continue to follow its design.
+ * \brief A simple example of the search policy which always returns the initial naive schedule
+ * (state).
  * The key implementation for this structure is `Search()`, check `empty_policy.cc` for more
  * details.
  */
@@ -43,13 +46,16 @@ class EmptyPolicyNode : public SearchPolicyNode {
   State Search(int num_measure_trials, int early_stopping, int num_measures_per_round,
                ProgramMeasurer measurer) final;
 
+  std::pair<Array<MeasureInput>, Array<MeasureResult>> ContinueSearchOneRound(
+      int num_measure, ProgramMeasurer measurer) final;
+
   static constexpr const char* _type_key = "auto_scheduler.EmptyPolicy";
   TVM_DECLARE_FINAL_OBJECT_INFO(EmptyPolicyNode, SearchPolicyNode);
 
  private:
   /*!
    * \brief Use a sub function to generate several candidate states in each search round.
-   * \returns Several generated states
+   * \returns The generated states
    */
   Array<State> SearchOneRound();
 };
diff --git a/src/auto_scheduler/search_policy/search_policy.cc b/src/auto_scheduler/search_policy/search_policy.cc
index d73bd91..8b6d22b 100644
--- a/src/auto_scheduler/search_policy/search_policy.cc
+++ b/src/auto_scheduler/search_policy/search_policy.cc
@@ -104,8 +104,13 @@ TVM_REGISTER_GLOBAL("auto_scheduler.SearchPolicyRunCallbacks")
       }
     });
 
-TVM_REGISTER_GLOBAL("auto_scheduler.SearchPolicySetTask")
-    .set_body_typed([](SearchPolicy policy, SearchTask task) { policy->search_task = task; });
+TVM_REGISTER_GLOBAL("auto_scheduler.SearchPolicyContinueSearchOneRound")
+    .set_body_typed([](SearchPolicy policy, int num_measure, ProgramMeasurer measurer) {
+      Array<MeasureInput> inputs;
+      Array<MeasureResult> results;
+      std::tie(inputs, results) = policy->ContinueSearchOneRound(num_measure, measurer);
+      return Array<ObjectRef>{inputs, results};
+    });
 
 TVM_REGISTER_GLOBAL("auto_scheduler.SearchPolicySetVerbose")
     .set_body_typed([](SearchPolicy policy, int verbose) { policy->verbose = verbose; });
diff --git a/src/auto_scheduler/search_policy/sketch_policy.cc b/src/auto_scheduler/search_policy/sketch_policy.cc
index a89fa4b..8de17a6 100644
--- a/src/auto_scheduler/search_policy/sketch_policy.cc
+++ b/src/auto_scheduler/search_policy/sketch_policy.cc
@@ -157,6 +157,7 @@ State SketchPolicyNode::Search(int n_trials, int early_stopping, int num_measure
 
     int ct = 0;
     int empty_retry_count = GetIntParam(params, SketchParamKey::empty_retry_count);
+    Array<State> best_states, random_states;
     Array<MeasureInput> inputs;
     Array<MeasureResult> results;
     while (ct < n_trials) {
@@ -168,8 +169,7 @@ State SketchPolicyNode::Search(int n_trials, int early_stopping, int num_measure
 
       // Search one round to get promising states
       PrintTitle("Search", verbose);
-      Array<State> random_states;
-      Array<State> best_states = SearchOneRound(num_random, &random_states);
+      best_states = SearchOneRound(num_random * 3, &random_states);
 
       // Infer bound. This is necessary for computing the correct ToStr() for redundancy check
       best_states = search_task->compute_dag.InferBound(best_states);
@@ -196,7 +196,7 @@ State SketchPolicyNode::Search(int n_trials, int early_stopping, int num_measure
 
       // Measure candidate states
       PrintTitle("Measure", verbose);
-      measurer->Measure(search_task, GetRef<SearchPolicy>(this), inputs, &results);
+      results = measurer->Measure(search_task, GetRef<SearchPolicy>(this), inputs);
       ct += inputs.size();
 
       // Check if reach the early stopping condition
@@ -218,15 +218,45 @@ State SketchPolicyNode::Search(int n_trials, int early_stopping, int num_measure
   }
 }
 
-Array<State> SketchPolicyNode::SearchOneRound(int num_random_states, Array<State>* random_states) {
-  // Temporal object to be used if the input pointer is nullptr
-  Array<State> temp_random_states;
-  if (random_states == nullptr) {
-    random_states = &temp_random_states;
-  } else {
-    random_states->clear();
+std::pair<Array<MeasureInput>, Array<MeasureResult>> SketchPolicyNode::ContinueSearchOneRound(
+    int num_measure, ProgramMeasurer measurer) {
+  num_measure_per_iter_ = num_measure;
+
+  Array<State> best_states, random_states;
+  Array<MeasureInput> inputs;
+  Array<MeasureResult> results;
+  int num_random = static_cast<int>(GetDoubleParam(params, "eps_greedy") * num_measure);
+
+  // Search one round to get promising states
+  PrintTitle("Search", verbose);
+  best_states = SearchOneRound(num_random * 3, &random_states);
+
+  // Infer bound. This is necessary for computing the correct ToStr() for redundancy check
+  best_states = search_task->compute_dag.InferBound(best_states);
+  random_states = search_task->compute_dag.InferBound(random_states);
+
+  // Pick `num_measure_per_iter` states to measure, check hash to remove already measured state
+  // Also pick some random states to do eps-greedy
+  inputs = PickStatesWithEpsGreedy(best_states, random_states, num_measure);
+
+  // Measure candidate states
+  PrintTitle("Measure", verbose);
+  results = measurer->Measure(search_task, GetRef<SearchPolicy>(this), inputs);
+
+  // Update measured states throughputs. These states will join the EvolutionarySearch in later
+  // search rounds.
+  for (const auto& res : results) {
+    measured_states_throughputs_.push_back(1.0 / FloatArrayMean(res->costs));
   }
 
+  // Update the cost model
+  PrintTitle("Train cost model", verbose);
+  program_cost_model->Update(inputs, results);
+
+  return std::make_pair(std::move(inputs), std::move(results));
+}
+
+Array<State> SketchPolicyNode::SearchOneRound(int num_random_states, Array<State>* random_states) {
   // Get parameters
   int population = GetIntParam(params, SketchParamKey::EvolutionarySearch::population);
   int num_use_measured =
@@ -245,8 +275,8 @@ Array<State> SketchPolicyNode::SearchOneRound(int num_random_states, Array<State
   Array<State> init_population = SampleInitPopulation(
       sketch_cache_, is_cost_model_reasonable ? population - num_use_measured : population);
 
-  // 3. If the cost model is useless (i.e. RandomCostModel), just random pick some generated
-  // states, else perform evolutionary search
+  // 3. Perform evolutionary search if a cost model is utilized. Otherwise,
+  // just return some random states.
   if (is_cost_model_reasonable) {
     // Also insert already measured good states to the initial population
     std::vector<int> indices = Argsort(measured_states_throughputs_);
@@ -254,11 +284,13 @@ Array<State> SketchPolicyNode::SearchOneRound(int num_random_states, Array<State
       init_population.push_back(measured_states_vector_[indices[i]]);
     }
     // Sample some random states for eps-greedy
-    *random_states = RandomSampleStates(init_population, &rand_gen, num_random_states * 3);
+    if (num_random_states > 0 && random_states != nullptr) {
+      *random_states = RandomSampleStates(init_population, &rand_gen, num_random_states);
+    }
     return EvolutionarySearch(init_population, num_measure_per_iter_ * 2);
   } else {
     PruneInvalidState(search_task, &init_population);
-    return RandomSampleStates(init_population, &rand_gen, num_measure_per_iter_ * 3);
+    return RandomSampleStates(init_population, &rand_gen, num_measure_per_iter_ * 2);
   }
 }
 
@@ -347,10 +379,7 @@ Array<State> SketchPolicyNode::SampleInitPopulation(const Array<State>& sketches
 
     support::parallel_for(0, out_size - out_states.size(),
                           [this, &temp_states, &sketches, &rand_gens](int index) {
-                            // Random choose a starting sketch
-                            // TODO(jcf94, merrymercy): Maybe choose sketches in different
-                            // possibility for they may have different potential on generating state
-                            // with better performance
+                            // Randomly choose a sketch
                             State tmp_s = sketches[(rand_gens[index])() % sketches.size()];
                             // Derivation rule based enumeration
                             bool valid = true;
@@ -472,6 +501,8 @@ Array<State> SketchPolicyNode::EvolutionarySearch(const Array<State>& init_popul
     // Compute selection probability
     ComputePrefixSumProb(pop_scores, &pop_selection_probs);
 
+    // TODO(merrymercy, comaniac): add crossover.
+
     // Do mutation
     while (pnext->size() < population) {
       State tmp_s = (*pnow)[RandomChoose(pop_selection_probs, &rand_gen)];
diff --git a/src/auto_scheduler/search_policy/sketch_policy.h b/src/auto_scheduler/search_policy/sketch_policy.h
index 21aaa6e..edaa89e 100644
--- a/src/auto_scheduler/search_policy/sketch_policy.h
+++ b/src/auto_scheduler/search_policy/sketch_policy.h
@@ -19,13 +19,15 @@
 
 /*!
  * \file auto_scheduler/search_policy/sketch_policy.h
- * \brief The search policy that searches in a hierarchical search space defined by sketches.
- * The policy randomly samples programs from the space defined by sketches and use evolutionary
- * search to fine-tune them.
+ * \brief This search policy constructs a search space according to the compute declaration.
+ * It then randomly samples programs from the search space and uses evolutionary search with a
+ * learned cost model to fine tune the sampled programs.
+ * The final optimized programs are sent to actual hardware for measurement.
+ * The above process is repeated until the auto-scheduler runs out of time budget.
  *
  * Reference:
  * L. Zheng, C. Jia, M. Sun, Z. Wu, C. Yu, et al. "Ansor : Generating High-Performance Tensor
- * Programs for Deep Learning." arXiv preprint arXiv:2006.06762 (2020).
+ * Programs for Deep Learning." (OSDI 2020).
  */
 
 #ifndef TVM_AUTO_SCHEDULER_SEARCH_POLICY_SKETCH_POLICY_H_
@@ -106,6 +108,9 @@ class SketchPolicyNode : public SearchPolicyNode {
   State Search(int num_measure_trials, int early_stopping, int num_measures_per_round,
                ProgramMeasurer measurer) final;
 
+  std::pair<Array<MeasureInput>, Array<MeasureResult>> ContinueSearchOneRound(
+      int num_measure, ProgramMeasurer measurer) final;
+
   /*!
    * \brief Generate sketches.
    * \return The generated sketches(states).
diff --git a/src/auto_scheduler/search_policy/sketch_policy_rules.cc b/src/auto_scheduler/search_policy/sketch_policy_rules.cc
index 99188d4..b6ad4d3 100644
--- a/src/auto_scheduler/search_policy/sketch_policy_rules.cc
+++ b/src/auto_scheduler/search_policy/sketch_policy_rules.cc
@@ -19,7 +19,8 @@
 
 /*!
  * \file auto_scheduler/search_policy/sketch_policy_rules.cc
- * \brief Rules defined to generate the sketches and initial sampled states in SketchPolicy.
+ * \brief Rules for generating the sketches, sampling the initial population, and mutating the
+ * population in SketchPolicy.
  */
 
 #include "sketch_policy_rules.h"
@@ -317,7 +318,7 @@ SketchGenerationRule::ConditionKind RuleCrossThreadReduction::MeetCondition(
     const SketchPolicyNode& policy, const State& state, int stage_id) const {
   CHECK(IsGPUTask(policy.search_task));
 
-  // If it is an intermidiate state created by RuleAddCacheWrite,
+  // If it is an intermediate state created by RuleAddCacheWrite,
   // we just skip it.
   if (HasCacheWriteStage(state, stage_id)) {
     return ConditionKind::kSkip;
@@ -1116,6 +1117,10 @@ PopulationGenerationRule::ResultKind MutateParallel::Apply(SketchPolicyNode* pol
     }
   }
 
+  if (max_fusable_iter_id == 0) {
+    return ResultKind::kInvalid;
+  }
+
   // Randomly pick one granularity
   int fuse_to_iter_id = (*rand_gen)() % max_fusable_iter_id + 1;
   Array<Integer> fused_ids;
diff --git a/src/auto_scheduler/search_policy/sketch_policy_rules.h b/src/auto_scheduler/search_policy/sketch_policy_rules.h
index 035dc89..046f036 100644
--- a/src/auto_scheduler/search_policy/sketch_policy_rules.h
+++ b/src/auto_scheduler/search_policy/sketch_policy_rules.h
@@ -19,7 +19,8 @@
 
 /*!
  * \file auto_scheduler/search_policy/sketch_policy_rules.h
- * \brief Rules defined to generate the sketches and initial sampled states in SketchPolicy.
+ * \brief Rules for generating the sketches, sampling the initial population, and mutating the
+ * population in SketchPolicy.
  */
 
 #ifndef TVM_AUTO_SCHEDULER_SEARCH_POLICY_SKETCH_POLICY_RULES_H_
diff --git a/tests/python/unittest/test_auto_scheduler_task_scheduler.py b/tests/python/unittest/test_auto_scheduler_task_scheduler.py
new file mode 100644
index 0000000..72b998a
--- /dev/null
+++ b/tests/python/unittest/test_auto_scheduler_task_scheduler.py
@@ -0,0 +1,112 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+""" Test task scheduler """
+
+import tempfile
+
+import numpy as np
+
+from tvm import auto_scheduler
+
+from test_auto_scheduler_common import matmul_auto_scheduler_test
+
+
+def test_task_scheduler_round_robin():
+    tasks = []
+    for n in [2, 4, 8]:
+        tasks.append(auto_scheduler.create_task(matmul_auto_scheduler_test, (n, n, n), "llvm"))
+
+    def objective_func(costs):
+        return sum(costs)
+
+    with tempfile.NamedTemporaryFile() as fp:
+        log_file = fp.name
+        num_trials_per_task = 2
+
+        # Tune all tasks
+        tune_option = auto_scheduler.TuningOptions(
+            num_measure_trials=num_trials_per_task * len(tasks),
+            num_measures_per_round=1,
+            measure_callbacks=[auto_scheduler.RecordToFile(log_file)],
+        )
+        task_scheduler = auto_scheduler.TaskScheduler(tasks, objective_func, strategy="round-robin")
+        task_scheduler.tune(tune_option, search_policy="sketch.random")
+
+        # Check the result of round robin
+        counters = {}
+        for task in tasks:
+            counters[task.workload_key] = 0
+
+        for inp, res in auto_scheduler.load_records(log_file):
+            counters[inp.task.workload_key] += 1
+
+        for task in tasks:
+            assert counters[task.workload_key] == num_trials_per_task
+
+        # test continuous tuning (restoring the status)
+        task_scheduler = auto_scheduler.TaskScheduler(
+            tasks, objective_func, strategy="round-robin", load_log_file=log_file
+        )
+        tune_option = auto_scheduler.TuningOptions(
+            num_measure_trials=len(tasks),
+            num_measures_per_round=1,
+        )
+        task_scheduler.tune(tune_option, search_policy="sketch.random")
+
+
+def test_task_scheduler_gradient():
+    tasks = []
+    for n in [2, 4]:
+        tasks.append(auto_scheduler.create_task(matmul_auto_scheduler_test, (n, n, n), "llvm"))
+
+    def objective_func(costs):
+        return costs[0]
+
+    with tempfile.NamedTemporaryFile() as fp:
+        log_file = fp.name
+
+        n_trials = 5
+
+        # Tune all tasks
+        tune_option = auto_scheduler.TuningOptions(
+            num_measure_trials=n_trials,
+            num_measures_per_round=1,
+            measure_callbacks=[auto_scheduler.RecordToFile(log_file)],
+        )
+        task_scheduler = auto_scheduler.TaskScheduler(tasks, objective_func)
+
+        # Forcibly overwrite the initial best costs.
+        # This makes the test more stable on slow CI machines.
+        task_scheduler.best_costs = np.array([1e2, 1e-8])
+
+        task_scheduler.tune(tune_option, search_policy="sketch.random")
+
+        # Check the allocation results
+        counters = {}
+        for task in tasks:
+            counters[task.workload_key] = 0
+
+        for inp, res in auto_scheduler.load_records(log_file):
+            counters[inp.task.workload_key] += 1
+
+        assert counters[tasks[0].workload_key] == n_trials - 1
+        assert counters[tasks[1].workload_key] == 1
+
+
+if __name__ == "__main__":
+    test_task_scheduler_round_robin()
+    test_task_scheduler_gradient()
diff --git a/tutorials/auto_scheduler/tune_conv2d_layer_cuda.py b/tutorials/auto_scheduler/tune_conv2d_layer_cuda.py
index 5004a5f..b800eb4 100644
--- a/tutorials/auto_scheduler/tune_conv2d_layer_cuda.py
+++ b/tutorials/auto_scheduler/tune_conv2d_layer_cuda.py
@@ -25,10 +25,9 @@ Auto-scheduling a convolution layer for GPU
 
 Different from the existing :ref:`autotvm <tutorials-autotvm-sec>` which relies on 
 manual templates to define the search space, the auto-scheduler does not require any templates.
-The auto-scheduler is template-free, so users only need to write the computation declaration without
-any schedule commands or templates.
-The auto-scheduler can automatically generate a large
-search space and find a good schedule in the space.
+Users only need to write the computation declaration without any schedule commands or templates.
+The auto-scheduler can automatically generate a large search space and
+find a good schedule in the space.
 
 We use a convolution layer as an example in this tutorial.
 """
diff --git a/tutorials/auto_scheduler/tune_matmul_x86.py b/tutorials/auto_scheduler/tune_matmul_x86.py
index e1e0115..35c4744 100644
--- a/tutorials/auto_scheduler/tune_matmul_x86.py
+++ b/tutorials/auto_scheduler/tune_matmul_x86.py
@@ -22,10 +22,9 @@ Auto-scheduling matrix multiplication for CPU
 
 Different from the existing :ref:`autotvm <tutorials-autotvm-sec>` which relies on 
 manual templates to define the search space, the auto-scheduler does not require any templates.
-The auto-scheduler is template-free, so users only need to write the computation declaration without
-any schedule commands or templates.
-The auto-scheduler can automatically generate a large
-search space and find a good schedule in the space.
+Users only need to write the computation declaration without any schedule commands or templates.
+The auto-scheduler can automatically generate a large search space and
+find a good schedule in the space.
 
 We use matrix multiplication as an example in this tutorial.
 """