You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by lm...@apache.org on 2020/10/19 06:58:25 UTC
[incubator-tvm] branch main updated: [AutoScheduler] Add task
scheduler (#6663)
This is an automated email from the ASF dual-hosted git repository.
lmzheng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 94679b5 [AutoScheduler] Add task scheduler (#6663)
94679b5 is described below
commit 94679b5cf46bd89872ae611995ff43f6dae78786
Author: Lianmin Zheng <li...@gmail.com>
AuthorDate: Sun Oct 18 23:55:40 2020 -0700
[AutoScheduler] Add task scheduler (#6663)
* Add task scheduler
* fix lint
* fix tests
* fix tests
* fix tests
* fix test cases
* fix test cases
* fix tests
* address comments
---
include/tvm/auto_scheduler/measure.h | 22 +-
include/tvm/auto_scheduler/search_policy.h | 31 +-
python/tvm/auto_scheduler/__init__.py | 2 +
python/tvm/auto_scheduler/auto_schedule.py | 11 +-
python/tvm/auto_scheduler/cost_model/xgb_model.py | 6 +-
python/tvm/auto_scheduler/measure.py | 27 ++
python/tvm/auto_scheduler/search_policy.py | 51 ++-
python/tvm/auto_scheduler/task_scheduler.py | 422 +++++++++++++++++++++
python/tvm/auto_scheduler/utils.py | 47 +++
src/auto_scheduler/auto_schedule.cc | 4 +-
src/auto_scheduler/feature.cc | 4 +-
src/auto_scheduler/measure.cc | 22 +-
src/auto_scheduler/search_policy/empty_policy.cc | 39 +-
src/auto_scheduler/search_policy/empty_policy.h | 14 +-
src/auto_scheduler/search_policy/search_policy.cc | 9 +-
src/auto_scheduler/search_policy/sketch_policy.cc | 67 +++-
src/auto_scheduler/search_policy/sketch_policy.h | 13 +-
.../search_policy/sketch_policy_rules.cc | 9 +-
.../search_policy/sketch_policy_rules.h | 3 +-
.../unittest/test_auto_scheduler_task_scheduler.py | 112 ++++++
tutorials/auto_scheduler/tune_conv2d_layer_cuda.py | 7 +-
tutorials/auto_scheduler/tune_matmul_x86.py | 7 +-
22 files changed, 819 insertions(+), 110 deletions(-)
diff --git a/include/tvm/auto_scheduler/measure.h b/include/tvm/auto_scheduler/measure.h
index 349f4f8..339f428 100755
--- a/include/tvm/auto_scheduler/measure.h
+++ b/include/tvm/auto_scheduler/measure.h
@@ -423,7 +423,7 @@ class RPCRunner : public ProgramRunner {
/*!
* \brief Measurer that measures the time costs of tvm programs
- * This class combines ProgramBuilder and ProgramRunner and provides a simpler API */
+ * This class combines ProgramBuilder and ProgramRunner, and provides a simpler API */
class ProgramMeasurerNode : public Object {
public:
/*! \brief Measured programs counter. */
@@ -444,7 +444,7 @@ class ProgramMeasurerNode : public Object {
Optional<Array<MeasureCallback>> callbacks;
/*! \brief Verbosity level. 0 for silent, 1 to output information during program measuring. */
int verbose;
- /*! \brief The number of max continuous error. */
+ /*! \brief The number of continuous errors allowed before forcibly stopping the tuning */
int max_continuous_error;
/*! \brief Reset book keeping variables */
@@ -454,13 +454,12 @@ class ProgramMeasurerNode : public Object {
* \brief Do measurement.
* \param task The current SearchTask.
* \param policy The current SearchPolicy.
- * \param inputs The MeasureInputs.
- * \param results A pointer to a MeasureResult Array, this is used as output.
+ * \param inputs The inputs of measurement.
* \param batch_size Number of programs to be measured in one batch.
+ * \return The results of measurement.
*/
- void Measure(const SearchTask& task, const SearchPolicy& policy,
- const Array<MeasureInput>& inputs, Array<MeasureResult>* results,
- int batch_size = -1);
+ Array<MeasureResult> Measure(const SearchTask& task, const SearchPolicy& policy,
+ const Array<MeasureInput>& inputs, int batch_size = -1);
/*!
* \brief Do measurement silently.
* This API will not print the measure results to screen.
@@ -486,12 +485,13 @@ class ProgramMeasurer : public ObjectRef {
public:
/*!
* \brief The constructor.
- * \param builder The ProgramBuilder to build each program.
- * \param runner The ProgramRunner to measure each program.
- * \param callbacks MeasureCallback to be called after each measure batch.
+ * \param builder The ProgramBuilder to build programs.
+ * \param runner The ProgramRunner to measure programs.
+ * \param callbacks MeasureCallback to be called after each measurement batch.
* \param verbose Verbosity level. 0 for silent, 1 to output information during program
* measuring.
- * \param max_continuous_error The number of allowed maximum continuous error.
+ * \param max_continuous_error The number of continuous errors allowed before
+ * forcibly stopping the tuning.
*/
ProgramMeasurer(ProgramBuilder builder, ProgramRunner runner,
Optional<Array<MeasureCallback>> callbacks, int verbose,
diff --git a/include/tvm/auto_scheduler/search_policy.h b/include/tvm/auto_scheduler/search_policy.h
index ddb0dd2..e433799 100755
--- a/include/tvm/auto_scheduler/search_policy.h
+++ b/include/tvm/auto_scheduler/search_policy.h
@@ -22,26 +22,6 @@
* \brief The base class of search policies, including the abstract definition of search policy and
* other supporting data structures.
*
- * The basic schedule search process for the auto-scheduler is design to be:
- * `Program sampling` -> `Performance Tuning`.
- *
- * In `Program sampling`, we use some predefined precise or heuristic rules to generate several
- * initial schedules. Based on these initial starting points, we perform `Performance Tuning` which
- * uses cost model based evolutionary search to select schedules with the best performance.
- *
- * Candidate schedules are measured against the specific hardware target.
- *
- * We intend to introduce different level of automation on the schedule generation process:
- * - Level 0(the default level): For all kinds of ops/subgraphs, the search policy should be able
- * to generate schedule automatically.
- * - Level 1: For some complicated ops/subgraphs(e.g. conv2d windograd), the default search space
- * of level 0 may be too large to find a high performance schedule efficiently. We provide some
- * op attributes to help reduce the total search space, see `SearchPolicyKey` below for more
- * information.
- * - Level 2: For some further special ops/subgraphs, users may more likely to write their own
- * template(just like AutoTVM). Search policy should be able to provide a flexible approach as
- * well.
- *
* \note How to add a new search policy.
* In design, there's no need for users to implement their own search policy, our formal search
* policy(will be brought later) should be enough to cover most use cases. Meanwhile, a custom rule
@@ -62,11 +42,13 @@
#ifndef TVM_AUTO_SCHEDULER_SEARCH_POLICY_H_
#define TVM_AUTO_SCHEDULER_SEARCH_POLICY_H_
+#include <tvm/auto_scheduler/measure.h>
#include <tvm/auto_scheduler/search_task.h>
#include <tvm/node/node.h>
#include <string>
#include <unordered_set>
+#include <utility>
#include <vector>
namespace tvm {
@@ -172,6 +154,15 @@ class SearchPolicyNode : public Object {
ProgramMeasurer measurer) = 0;
/*!
+ * \brief Continue the search by doing an additional search round.
+ * \param num_measure The number of measurements
+ * \param measurer The measurer to measure programs
+ * \return The measurement records for measurements in this search round
+ */
+ virtual std::pair<Array<MeasureInput>, Array<MeasureResult>> ContinueSearchOneRound(
+ int num_measure, ProgramMeasurer measurer) = 0;
+
+ /*!
* \brief Preload measured states from a log file to resume the state of the search policy.
* \param log_file The name of the record log file.
*/
diff --git a/python/tvm/auto_scheduler/__init__.py b/python/tvm/auto_scheduler/__init__.py
index 6a395e7..99d96e8 100644
--- a/python/tvm/auto_scheduler/__init__.py
+++ b/python/tvm/auto_scheduler/__init__.py
@@ -24,6 +24,7 @@ from . import measure
from . import measure_record
from . import search_policy
from . import search_task
+from . import task_scheduler
from . import utils
from . import workload_registry
@@ -42,4 +43,5 @@ from .measure import (
from .measure_record import RecordToFile, RecordReader, load_best, load_records, save_records
from .search_task import SearchTask
from .search_policy import EmptyPolicy, SketchPolicy, PreloadMeasuredStates
+from .task_scheduler import TaskScheduler
from .workload_registry import register_workload, make_workload_key
diff --git a/python/tvm/auto_scheduler/auto_schedule.py b/python/tvm/auto_scheduler/auto_schedule.py
index ca069bb..a53c29d 100644
--- a/python/tvm/auto_scheduler/auto_schedule.py
+++ b/python/tvm/auto_scheduler/auto_schedule.py
@@ -16,16 +16,7 @@
# under the License.
"""
-User interface for TVM Auto-scheduler.
-
-The basic schedule search process for TVM Auto-scheduler is designed to be:
-`Program sampling` -> `Performance Tuning`.
-
-In `Program sampling`, we use some predefined precise or heuristic rules to generate several
-initial schedules. Based on these initial starting points, we perform `Performance Tuning` which
-uses cost model based evolutionary search to select schedules with the best performance.
-
-Candidate schedules are measured against the specific hardware target.
+The user interface and tuning options of the TVM auto-scheduler.
"""
import tvm._ffi
diff --git a/python/tvm/auto_scheduler/cost_model/xgb_model.py b/python/tvm/auto_scheduler/cost_model/xgb_model.py
index 3eb64df..9a534aa 100644
--- a/python/tvm/auto_scheduler/cost_model/xgb_model.py
+++ b/python/tvm/auto_scheduler/cost_model/xgb_model.py
@@ -20,6 +20,7 @@
import multiprocessing
import logging
from collections import defaultdict
+import time
import numpy as np
import xgboost as xgb
@@ -76,7 +77,7 @@ dmatrix_context = XGBDMatrixContext()
class XGBModel(PythonBasedModel):
"""Train a XGBoost model to predict the normalized throughputs of programs.
Let the normalized throughput be the score of a program (higher is better). We predict
- the (approximiate) score of a program = the sum of the scores of all stages in this program.
+ the (approximate) score of a program = the sum of the scores of all stages in this program.
i.e. score(P) = score_s0 + score_s1 + ... + score_sn,
where score_si is the score of Stage i in Program P.
We extract feature for each stage and let the xgboost predict the score for each stage.
@@ -128,6 +129,7 @@ class XGBModel(PythonBasedModel):
if len(inputs) <= 0:
return
assert len(inputs) == len(results)
+ tic = time.time()
self.inputs.extend(inputs)
self.results.extend(results)
@@ -167,6 +169,8 @@ class XGBModel(PythonBasedModel):
],
)
+ logger.info("XGBModel Training time: %.2f s", time.time() - tic)
+
def predict(self, task, states):
"""Predict the scores of states
Parameters
diff --git a/python/tvm/auto_scheduler/measure.py b/python/tvm/auto_scheduler/measure.py
index 81c314f..8a8b922 100644
--- a/python/tvm/auto_scheduler/measure.py
+++ b/python/tvm/auto_scheduler/measure.py
@@ -185,6 +185,33 @@ class ProgramRunner(Object):
return _ffi_api.ProgramRunnerRun(self, measure_inputs, build_results, verbose)
+@tvm._ffi.register_object("auto_scheduler.ProgramMeasurer")
+class ProgramMeasurer(Object):
+ """
+ Measurer that measures the time costs of tvm programs
+ This class combines ProgramBuilder and ProgramRunner, and provides a simpler API.
+
+ Parameters
+ ----------
+ builder : ProgramBuilder
+ The ProgramBuilder to build programs
+ runner : ProgramRunner
+ The ProgramRunner to measure programs.
+ callbacks : List[MeasureCallback]
+ Callbacks to be called after each measurement batch
+ verbose : int
+ The verbosity level: 0 for silent, 1 to output information during program measuring.
+ max_continuous_error : Optional[int]
+ The number of continuous errors allowed before stopping the tuning (None: use the default).
+ """
+
+ def __init__(self, builder, runner, callbacks, verbose, max_continuous_error=None):
+ max_continuous_error = max_continuous_error or -1 # -1 means using the default value
+ self.__init_handle_by_constructor__(
+ _ffi_api.ProgramMeasurer, builder, runner, callbacks, verbose, max_continuous_error
+ )
+
+
@tvm._ffi.register_object("auto_scheduler.LocalBuilder")
class LocalBuilder(ProgramBuilder):
"""LocalBuilder use local CPU cores to build programs in parallel.
diff --git a/python/tvm/auto_scheduler/search_policy.py b/python/tvm/auto_scheduler/search_policy.py
index 5533aec..f3d459e 100644
--- a/python/tvm/auto_scheduler/search_policy.py
+++ b/python/tvm/auto_scheduler/search_policy.py
@@ -16,15 +16,17 @@
# under the License.
"""
-The search policies for TVM Auto-scheduler.
+The search policies of TVM auto-scheduler.
-This contains the strategies to generate a schedule automatically. We provide an EmptyPolicy
-which always returns an unchanged initial state, and a more advanced SketchPolicy which can
-deal with various ops/subgraphs on different target devices.
+The auto-scheduler constructs a search space according to the compute declaration.
+It then randomly samples programs from the search space and uses evolutionary search with a
+learned cost model to fine tune the sampled programs.
+The final optimized programs are sent to actual hardware for measurement.
+The above process is repeated until the auto-scheduler runs out of time budget.
Reference:
L. Zheng, C. Jia, M. Sun, Z. Wu, C. Yu, et al. "Ansor : Generating High-Performance Tensor
-Programs for Deep Learning." arXiv preprint arXiv:2006.06762 (2020).
+Programs for Deep Learning." (OSDI 2020).
"""
import random
@@ -63,11 +65,42 @@ class PreloadMeasuredStates(SearchCallback):
class SearchPolicy(Object):
""" The base class of search policies. """
+ def continue_search_one_round(self, num_measure, measurer):
+ """
+ Continue the search by doing an additional search round.
+
+ Parameters
+ ----------
+ num_measure: int
+ The number of programs to measure in this round
+ measurer: ProgramMeasurer
+ The program measurer to measure programs
+
+ Returns
+ -------
+ inputs: List[MeasureInput]
+ The inputs of measurements in this search round
+ results: List[MeasureResult]
+ The results of measurements in this search round
+ """
+ return _ffi_api.SearchPolicyContinueSearchOneRound(self, num_measure, measurer)
+
+ def set_verbose(self, verbose):
+ """
+ Set the verbosity level of the search policy.
+
+ Parameters
+ ----------
+ verbose: int
+ The verbosity level
+ """
+ return _ffi_api.SearchPolicySetVerbose(self, verbose)
+
@tvm._ffi.register_object("auto_scheduler.EmptyPolicy")
class EmptyPolicy(SearchPolicy):
- """This is an example empty search policy which will always generate
- the init state of ComputeDAG.
+ """A simple example of the search policy which always returns
+ the initial naive schedule (state).
Parameters
----------
@@ -195,15 +228,17 @@ class SketchPolicy(SearchPolicy):
return states
def evolutionary_search(self, init_populations, out_size):
- """Evolutionary search.
+ """Perform evolutionary search.
This python interface is mainly used for debugging and testing.
The actual search is all done in c++.
+
Parameters
----------
init_populations: List[State]
The initial population states
out_size : int
The size of generated states
+
Returns
-------
states: List[State]
diff --git a/python/tvm/auto_scheduler/task_scheduler.py b/python/tvm/auto_scheduler/task_scheduler.py
new file mode 100644
index 0000000..e45573b
--- /dev/null
+++ b/python/tvm/auto_scheduler/task_scheduler.py
@@ -0,0 +1,422 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name
+
+""" The task scheduler that allocates the time resources when tuning multiple tasks together
+
+The details of the "gradient" strategy below can be found in section 6 of this paper:
+L. Zheng, C. Jia, M. Sun, Z. Wu, C. Yu, et al. "Ansor : Generating High-Performance Tensor
+Programs for Deep Learning." (OSDI 2020).
+"""
+
+import time
+import math
+import logging
+
+import numpy as np
+
+from .search_policy import SearchPolicy, SketchPolicy
+from .cost_model import RandomModel, XGBModel
+from .utils import array_mean, to_str_round
+from .measure import ProgramMeasurer
+from .measure_record import RecordReader
+
+logger = logging.getLogger("auto_scheduler")
+
+
+def make_search_policies(
+ search_policy, tasks, num_measures_per_round, verbose, load_model_file=None, load_log_file=None
+):
+ """Make a list of search policies for a list of search tasks.
+ It creates one policy per task.
+
+ Parameters
+ ----------
+ search_policy: Union[str, List[SearchPolicy]]
+ The name of the search policy, or a list of search policies with one entry per task.
+ tasks: List[SearchTask]
+ The list of all tasks
+ num_measures_per_round: int
+ The number of schedules to be measured at each search round.
+ This should be the same as `TuningOptions.num_measures_per_round`
+ verbose: int
+ The verbosity level. 0 for silent.
+ load_model_file: Optional[str]
+ Load pre-trained model from this file. If this is None, the cost model will
+ be trained from scratch.
+ load_log_file: Optional[str]
+ Load measurement records from this file. If it is not None, the status of the
+ task scheduler, search policies and cost models will be restored according to this file.
+
+ Returns
+ -------
+ policies: List[SearchPolicy]
+ The list of search policies
+ """
+ if search_policy == "default":
+ search_policy = "sketch.xgb"
+
+ if isinstance(search_policy, str):
+ policy_type, model_type = search_policy.split(".")
+ if model_type == "xgb":
+ cost_model = XGBModel(num_warmup_sample=len(tasks) * num_measures_per_round)
+ if load_model_file:
+ logger.info("Load pretrained model...")
+ cost_model.load(load_model_file)
+ elif load_log_file:
+ cost_model.load_log_file(load_log_file)
+ elif model_type == "random":
+ cost_model = RandomModel()
+ else:
+ raise ValueError("Invalid search policy: " + search_policy)
+
+ if policy_type == "sketch":
+ search_policies = [SketchPolicy(task, cost_model, verbose=verbose) for task in tasks]
+ else:
+ raise ValueError("Invalid search policy: " + search_policy)
+ else:
+ # check type
+ assert isinstance(search_policy, (tuple, list))
+ for item in search_policy:
+ assert isinstance(item, SearchPolicy)
+ search_policies = search_policy
+
+ return search_policies
+
+
+def derive_similarity_tag(dag, log_base=1.618):
+ """Derive the tag for similarity check from one computational DAG.
+ The DAGs with the same tag are considered as similar tasks.
+
+ The tag format is <op1-tag>_<op2-tag> ... <log(flop)>.
+
+ If the tag is "", then the task is not considered to be similar to any other tasks.
+
+ Parameters
+ ----------
+ dag: ComputeDAG
+ The input computational DAG
+ log_base: float = 1.618
+ The base of log to normalize FLOPS
+
+ Returns
+ -------
+ tag: str
+ The tag of this computational DAG.
+ """
+ ret = ""
+ for op in dag.ops:
+ tag = op.attrs.get("auto_scheduler_task_scheduler_tag", None)
+ if tag:
+ ret += op.attrs["auto_scheduler_task_scheduler_tag"] + "_"
+ if ret:
+ ret += "%d" % int(math.log(dag.flop_ct + 1, log_base))
+ return ret
+
+
+class TaskScheduler:
+ """
+ Allocate the time resources when tuning multiple tasks together.
+ This implements two strategies: "round-robin" and "gradient".
+
+ Parameters
+ ----------
+ tasks: List[SearchTask]
+ All tasks to tune
+ objective_func: Optional[Callable[List[float] -> float]]
+ The objective function to be minimized.
+ The objective function accepts the current latencies of all tasks and returns the
+ objective. If not provided, the objective is the sum of the latencies of all tasks.
+ strategy: str = "gradient"
+ The scheduling strategy.
+ "round-robin": Tune tasks in round robin order.
+ "gradient" : Tune tasks with gradient descent.
+ load_model_file: Optional[str]
+ Load pre-trained model from this file. If this is None, the cost model will
+ be trained from scratch.
+ load_log_file: Optional[str]
+ Load measurement records from this file. If it is not None, the status of the
+ task scheduler, search policies and cost models will be restored according to this file.
+ verbose: int = 1
+ The level of verbosity. 0 means silent.
+ alpha: float = 0.2
+ The parameter used for 'gradient' strategy
+ beta: float = 2
+ The parameter used for 'gradient' strategy
+ backward_window_size: int = 3
+ The parameter used for 'gradient' strategy
+ """
+
+ def __init__(
+ self,
+ tasks,
+ objective_func=None,
+ strategy="gradient",
+ load_model_file: str = None,
+ load_log_file: str = None,
+ verbose: int = 1,
+ alpha: float = 0.2,
+ beta: float = 2,
+ gamma: float = 0.5,
+ backward_window_size: int = 3,
+ ):
+ self.tasks = tasks
+ self.objective_func = objective_func or sum
+ self.strategy = strategy
+ self.verbose = verbose
+ self.load_log_file = load_log_file
+ self.load_model_file = load_model_file
+ self.alpha = alpha
+ self.beta = beta
+ self.gamma = gamma
+ self.backward_window_size = backward_window_size
+
+ assert len(self.tasks) != 0, "No tasks"
+ assert self.strategy in ["round-robin", "gradient"]
+
+ # task_cts[i] saves how many times task i is tuned
+ self.task_cts = [0 for _ in range(len(self.tasks))]
+
+ # task_costs_history[i] saves the latency history of task i
+ self.task_costs_history = [[] for _ in range(len(self.tasks))]
+
+ # best_costs[i] saves the best latency of task i
+ self.best_costs = 1e10 * np.ones(len(self.tasks))
+ self.cur_score = self._compute_score(self.best_costs)
+
+ self.tune_option = self.measurer = self.search_policies = self.ct = self.tic = None
+ self.num_measures_per_round = None
+ self.dead_tasks = set()
+
+ # Build similarity groups
+ self.task_tags = [] # task_id -> tag
+ self.tag_to_group_id = {} # tag -> group_id
+ self.group_task_ids = [] # group_id -> all task ids in this group
+ self.flop_cts = [] # task_id -> the number of floating ops
+ for i, task in enumerate(self.tasks):
+ tag = derive_similarity_tag(task.compute_dag)
+ self.task_tags.append(tag)
+ self.flop_cts.append(task.compute_dag.flop_ct)
+ if not tag:
+ continue
+
+ if tag not in self.tag_to_group_id:
+ self.tag_to_group_id[tag] = len(self.tag_to_group_id)
+ self.group_task_ids.append([])
+ self.group_task_ids[self.tag_to_group_id[tag]].append(i)
+
+ def tune(self, tune_option, search_policy="default"):
+ """Tune a batch of tasks together.
+
+ Parameters
+ ----------
+ tune_option: TuningOptions
+ The options of tuning
+ search_policy: Union[str, List[SearchPolicy]]
+ The list of search policies.
+ If it is a str, use one of:
+ "sketch.xgb" for SketchPolicy + XGBModel
+ "sketch.random" for SketchPolicy + RandomModel
+ """
+ # init members
+ self.tune_option = tune_option
+ self.measurer = ProgramMeasurer(
+ tune_option.builder,
+ tune_option.runner,
+ tune_option.measure_callbacks,
+ tune_option.verbose,
+ )
+ self.ct = 0
+ self.tic = time.time()
+ # reset num_measures_per_round to make sure every task is tuned at least once
+ self.num_measures_per_round = min(
+ tune_option.num_measures_per_round, tune_option.num_measure_trials // len(self.tasks)
+ )
+ if self.num_measures_per_round <= 0:
+ raise ValueError("num_measure_trials is too small. Please set it to a higher value.")
+
+ # restore the status of the task scheduler from a log file
+ if self.load_log_file:
+ self._restore_status(self.load_log_file, self.num_measures_per_round)
+
+ # make one search policy for one task
+ self.search_policies = make_search_policies(
+ search_policy,
+ self.tasks,
+ self.num_measures_per_round,
+ tune_option.verbose,
+ self.load_model_file,
+ self.load_log_file,
+ )
+
+ # do a round robin first to warm up
+ for i in range(len(self.tasks)):
+ self._tune_task(i)
+
+ # use the specific strategy to choose workload to tune
+ task_idx = -1
+ while self.ct < tune_option.num_measure_trials and len(self.dead_tasks) < len(self.tasks):
+ if self.strategy == "round-robin":
+ task_idx = (task_idx + 1) % len(self.tasks)
+ while task_idx in self.dead_tasks:
+ task_idx = (task_idx + 1) % len(self.tasks)
+ elif self.strategy == "gradient":
+ gradients = []
+ for i in range(len(self.tasks)):
+ if i in self.dead_tasks:
+ gradients.append(0)
+ continue
+
+ # compute gradient from chain rule : (delta f / delta g_i)
+ delta = 1e-7
+ new_costs = list(self.best_costs)
+ new_costs[i] -= delta
+ chain_grad = (
+ self._compute_score(self.best_costs) - self._compute_score(new_costs)
+ ) / delta
+
+ # compute (g_i(t_i) - g_i(t_i - \Delta t)) / (\Delta t)
+ if (
+ self.task_cts[i] - 1 < len(self.task_costs_history[i])
+ and self.task_cts[i] - 1 - self.backward_window_size >= 0
+ ):
+ backward_grad = (
+ self.task_costs_history[i][self.task_cts[i] - 1]
+ - self.task_costs_history[i][
+ self.task_cts[i] - 1 - self.backward_window_size
+ ]
+ ) / self.backward_window_size
+ else:
+ backward_grad = 0
+
+ # compute (g_i(t_i + \Delta t) - g_i(t_i)) / (\Delta t)
+ g_next_1 = self.best_costs[i] - (self.best_costs[i] / self.task_cts[i])
+
+ g_next_2 = self.beta * 1e30
+ group_id = self.tag_to_group_id.get(self.task_tags[i], None)
+ if group_id is not None and len(self.group_task_ids[group_id]) > 1:
+ best_flops = max(
+ [
+ self.flop_cts[j] / self.best_costs[j]
+ for j in self.group_task_ids[group_id]
+ ]
+ )
+ g_next_2 = self.beta * self.flop_cts[i] / best_flops
+
+ g_next = min(g_next_1, g_next_2)
+ forward_grad = g_next - self.best_costs[i]
+
+ # combine all grads
+ grad = chain_grad * (
+ self.alpha * backward_grad + (1 - self.alpha) * forward_grad
+ )
+ assert grad <= 0
+ gradients.append(grad)
+
+ if max(gradients) == min(gradients):
+ task_idx = np.random.choice(len(gradients))
+ else:
+ task_idx = np.argmin(gradients)
+ else:
+ raise ValueError("Invalid strategy: " + self.strategy)
+
+ self._tune_task(task_idx)
+ self._adjust_similarity_group(task_idx)
+
+ def _tune_task(self, task_idx):
+ """Tune the selected task for one round"""
+ if self.verbose >= 1:
+ logger.info("TaskScheduler: task id:\t%d", task_idx)
+ measure_inputs, measure_results = self.search_policies[task_idx].continue_search_one_round(
+ self.num_measures_per_round, self.measurer
+ )
+
+ for res in measure_results:
+ cost = array_mean(res.costs)
+ if cost < self.best_costs[task_idx]:
+ self.best_costs[task_idx] = cost
+
+ if len(measure_inputs) == 0:
+ self.dead_tasks.add(task_idx)
+
+ self.task_cts[task_idx] += 1
+ self.task_costs_history[task_idx].append(self.best_costs[task_idx])
+
+ self.ct += len(measure_inputs)
+ self.cur_score = self._compute_score(self.best_costs)
+
+ if self.verbose >= 1:
+ logger.info(
+ "TaskScheduler\tct: %d\testimated cost (ms): %.3f\ttime elapsed: %.2f\t"
+ "best_costs (ms): %s\ttask_ct: %s",
+ self.ct,
+ self.cur_score * 1e3,
+ time.time() - self.tic,
+ to_str_round(self.best_costs * 1e3, decimal=3),
+ self.task_cts,
+ )
+
+ def _compute_score(self, costs):
+ """compute the objective function"""
+ return self.objective_func(costs)
+
+ def _adjust_similarity_group(self, task_idx):
+ """adjust the similarity group for the selected task"""
+ group_id = self.tag_to_group_id.get(self.task_tags[task_idx], None)
+ if group_id is None or len(self.group_task_ids[group_id]) <= 1:
+ return
+
+ group_ids = self.group_task_ids[group_id]
+ best_group_flops = max([self.flop_cts[j] / self.best_costs[j] for j in group_ids])
+ cur_flops = self.flop_cts[task_idx] / self.best_costs[task_idx]
+
+ # if we tune a task many times but it still cannot achieve
+ # a similar speed to the fastest one in its group, this means this task
+ # is actually not similar to other tasks in its group.
+ # So we will remove it from its original group.
+ if cur_flops < best_group_flops / self.beta and self.task_cts[task_idx] > 5 + max(
+ self.task_cts[j] for j in group_ids if j != task_idx
+ ):
+ self.task_tags[task_idx] = None
+ group_ids.remove(task_idx)
+
+ def _restore_status(self, log_file, num_measures_per_round):
+ """restore task_cts and best_costs from a log file"""
+ str_target = str(self.tasks[0].target)
+ workload_key_to_task_id = {t.workload_key: i for i, t in enumerate(self.tasks)}
+ total_ct = -1
+
+ for total_ct, (inp, res) in enumerate(RecordReader(log_file)):
+ if str(inp.task.target) != str_target:
+ continue
+ task_idx = workload_key_to_task_id.get(inp.task.workload_key, None)
+ if task_idx is None:
+ continue
+
+ if res.error_no == 0:
+ self.best_costs[task_idx] = min(self.best_costs[task_idx], array_mean(res.costs))
+
+ self.task_cts[task_idx] += 1
+
+ for i in range(len(self.tasks)):
+ # The computation of task_cts is just an estimation.
+ # The estimation may not be accurate if the log file is changed externally or
+ # `num_measures_per_round` is different from the last tuning.
+ self.task_cts[i] = int(self.task_cts[i] / num_measures_per_round + 0.5)
+ self.task_costs_history[i].append(self.best_costs[i])
+
+ logger.info("TaskScheduler: Loaded %d measurement records from %s", total_ct + 1, log_file)
diff --git a/python/tvm/auto_scheduler/utils.py b/python/tvm/auto_scheduler/utils.py
index ff357c4..75fec9c 100644
--- a/python/tvm/auto_scheduler/utils.py
+++ b/python/tvm/auto_scheduler/utils.py
@@ -25,6 +25,8 @@ import signal
import threading
import os
+import numpy as np
+
try:
import psutil
except ImportError:
@@ -264,3 +266,48 @@ def check_remote(device_key, host=None, port=None, priority=100, timeout=10):
t.start()
t.join(timeout)
return not t.is_alive()
+
+
+def array_mean(arr):
+ """Compute mean of the elements in a TVM Array<PrimExpr>
+
+ Parameters
+ ----------
+ arr: Array
+ A TVM Array<PrimExpr>
+
+ Returns
+ -------
+ mean: float
+ The mean of the elements in the array
+ """
+ return sum(x.value for x in arr) / len(arr)
+
+
+def to_str_round(x, decimal=6):
+ """Convert an object to str and round float numbers
+
+ Parameters
+ ----------
+ x: Union[str, list, tuple, dict, int, float, np.ndarray]
+ The input object
+ decimal: int
+ The precision of decimal fraction
+
+ Returns
+ -------
+ ret: str
+ The string format of these objects
+ """
+ if isinstance(x, str):
+ return x
+ if isinstance(x, (list, tuple, np.ndarray)):
+ return "[" + ", ".join([to_str_round(y, decimal=decimal) for y in x]) + "]"
+ if isinstance(x, dict):
+ return str({k: to_str_round(v) for k, v in x.items()})
+ if isinstance(x, int):
+ return str(x)
+ if isinstance(x, (np.float32, np.float64, float)):
+ format_str = "%%.%df" % decimal
+ return format_str % x
+ raise ValueError("Invalid value: " + str(x) + "\ttype: " + str(type(x)))
diff --git a/src/auto_scheduler/auto_schedule.cc b/src/auto_scheduler/auto_schedule.cc
index dd6b705..747aa01 100755
--- a/src/auto_scheduler/auto_schedule.cc
+++ b/src/auto_scheduler/auto_schedule.cc
@@ -19,9 +19,7 @@
/*!
* \file auto_scheduler/auto_schedule.cc
- * \brief The user interface of the TVM Auto-scheduler. This is the entry structure to get
- * schedule search requirements from upper level (Python API), and returns a high performance
- * schedule after search process.
+ * \brief The user interface and tuning options of the TVM auto-scheduler.
*/
#include <tvm/auto_scheduler/auto_schedule.h>
diff --git a/src/auto_scheduler/feature.cc b/src/auto_scheduler/feature.cc
index 2744e0d..15066a9 100755
--- a/src/auto_scheduler/feature.cc
+++ b/src/auto_scheduler/feature.cc
@@ -871,7 +871,9 @@ class PerStoreFeatureExtractor : public StmtExprVisitor {
stride = (i == static_cast<int>(for_loop_stack_.size()) - 1 ? stride : 0);
float n_continuous = ele_bytes;
- for (int i = static_cast<int>(tmp_region.size()) - 1; i >= 0; i--) {
+ for (int i = std::min(static_cast<int>(tmp_region.size()) - 1,
+ static_cast<int>(int_shape.size()) - 1);
+ i >= 0; i--) {
if (tmp_region[i] == int_shape[i]) {
n_continuous *= tmp_region[i];
break;
diff --git a/src/auto_scheduler/measure.cc b/src/auto_scheduler/measure.cc
index 70ea7ab..c3ee6a1 100755
--- a/src/auto_scheduler/measure.cc
+++ b/src/auto_scheduler/measure.cc
@@ -38,6 +38,7 @@ TVM_REGISTER_NODE_TYPE(MeasureResultNode);
TVM_REGISTER_OBJECT_TYPE(MeasureCallbackNode);
TVM_REGISTER_OBJECT_TYPE(ProgramRunnerNode);
TVM_REGISTER_OBJECT_TYPE(ProgramBuilderNode);
+TVM_REGISTER_OBJECT_TYPE(ProgramMeasurerNode);
TVM_REGISTER_OBJECT_TYPE(LocalBuilderNode);
TVM_REGISTER_OBJECT_TYPE(LocalRunnerNode);
TVM_REGISTER_OBJECT_TYPE(RPCRunnerNode);
@@ -204,11 +205,12 @@ void ProgramMeasurerNode::Reset() {
best_state.clear();
}
-void ProgramMeasurerNode::Measure(const SearchTask& task, const SearchPolicy& policy,
- const Array<MeasureInput>& inputs, Array<MeasureResult>* results,
- int batch_size) {
- results->clear();
- results->reserve(inputs.size());
+Array<MeasureResult> ProgramMeasurerNode::Measure(const SearchTask& task,
+ const SearchPolicy& policy,
+ const Array<MeasureInput>& inputs,
+ int batch_size) {
+ Array<MeasureResult> results;
+ results.reserve(inputs.size());
if (batch_size == -1) {
// set default batch size
@@ -261,13 +263,15 @@ void ProgramMeasurerNode::Measure(const SearchTask& task, const SearchPolicy& po
// Store result batch
for (auto& res : result_batch) {
- results->push_back(res);
+ results.push_back(res);
}
if (error_ct > max_continuous_error) {
LOG(FATAL) << "Too many errors happened during tuning";
}
}
+
+ return results;
}
void ProgramMeasurerNode::SilentMeasure(const SearchTask& task, const Array<MeasureInput>& inputs,
@@ -343,6 +347,12 @@ TVM_REGISTER_GLOBAL("auto_scheduler.MeasureResult")
return MeasureResult(costs, error_no, error_msg, all_cost, timestamp);
});
+TVM_REGISTER_GLOBAL("auto_scheduler.ProgramMeasurer")
+ .set_body_typed([](ProgramBuilder builder, ProgramRunner runner,
+ Array<MeasureCallback> callbacks, int verbose, int max_continuous_error) {
+ return ProgramMeasurer(builder, runner, callbacks, verbose, max_continuous_error);
+ });
+
TVM_REGISTER_GLOBAL("auto_scheduler.ProgramBuilderBuild")
.set_body_typed([](const ProgramBuilder& builder, const Array<MeasureInput>& inputs,
int verbose) { return builder->Build(inputs, verbose); });
diff --git a/src/auto_scheduler/search_policy/empty_policy.cc b/src/auto_scheduler/search_policy/empty_policy.cc
index 21a68ac..fba1ac2 100644
--- a/src/auto_scheduler/search_policy/empty_policy.cc
+++ b/src/auto_scheduler/search_policy/empty_policy.cc
@@ -19,7 +19,8 @@
/*!
* \file auto_scheduler/search_policy/empty_policy.cc
- * \brief This is an brief example of search policy.
+ * \brief A simple example of the search policy which always returns the initial naive schedule
+ * (state).
*/
#include "empty_policy.h"
@@ -29,6 +30,8 @@
#include <utility>
+#include "utils.h"
+
namespace tvm {
namespace auto_scheduler {
@@ -64,19 +67,18 @@ State EmptyPolicyNode::Search(int num_measure_trials, int early_stopping,
measurer->Reset();
int ct = 0;
// In each round, we call SearchOneRound to get several candidate states,
- // then use ProgramMeasurer to test their performance
+ // then use ProgramMeasurer to measure their performance.
while (ct < num_measure_trials) {
const auto& res = SearchOneRound();
ct += res.size();
// Build MeasureInputs for measuring
inputs.clear();
for (const auto& state : res) {
- // The class members measured_states_set_ provided by SearchPolicy can be used to filter
- // out the already measured states
inputs.push_back(MeasureInput(search_task, state));
}
+ // Perform measurement.
// ProgramMeasurer will record the state with best performance during measure process
- measurer->Measure(search_task, GetRef<SearchPolicy>(this), inputs, &results);
+ results = measurer->Measure(search_task, GetRef<SearchPolicy>(this), inputs);
}
// Return a state with best measured performance
@@ -84,18 +86,33 @@ State EmptyPolicyNode::Search(int num_measure_trials, int early_stopping,
}
}
+std::pair<Array<MeasureInput>, Array<MeasureResult>> EmptyPolicyNode::ContinueSearchOneRound(
+ int num_measure, ProgramMeasurer measurer) {
+ Array<State> best_states;
+ Array<MeasureInput> inputs;
+ Array<MeasureResult> results;
+
+ // Search one round to get promising states
+ PrintTitle("Search", verbose);
+ best_states = SearchOneRound();
+
+ // Measure these states
+ PrintTitle("Measure", verbose);
+ for (const auto& state : best_states) {
+ inputs.push_back(MeasureInput(search_task, state));
+ }
+ results = measurer->Measure(search_task, GetRef<SearchPolicy>(this), inputs);
+
+ return std::make_pair(std::move(inputs), std::move(results));
+}
+
// As an example policy, EmptyPolicy always returns a init state
Array<State> EmptyPolicyNode::SearchOneRound() {
Array<State> res;
- // 1. We will process `Program sampling` first to generate several initial schedules
+ // Simply return the initial naive schedule (state).
res.push_back(search_task->compute_dag->init_state);
- // 2. Then `Performance Tuning`: use cost model and evolutionary search to seek for the schedule
- // with best performance
- // Note: This example policy does not include this part
-
- // 3. The returned candidate schedules will be measured in hardware
return res;
}
diff --git a/src/auto_scheduler/search_policy/empty_policy.h b/src/auto_scheduler/search_policy/empty_policy.h
index 3d13822..2219ebc 100644
--- a/src/auto_scheduler/search_policy/empty_policy.h
+++ b/src/auto_scheduler/search_policy/empty_policy.h
@@ -19,7 +19,7 @@
/*!
* \file auto_scheduler/search_policy/empty_policy.h
- * \brief A brief example of the search policy which always returns the initial naive schedule
+ * \brief A simple example of the search policy which always returns the initial naive schedule
* (state).
*/
@@ -27,14 +27,17 @@
#define TVM_AUTO_SCHEDULER_SEARCH_POLICY_EMPTY_POLICY_H_
#include <tvm/auto_scheduler/loop_state.h>
+#include <tvm/auto_scheduler/measure.h>
#include <tvm/auto_scheduler/search_policy.h>
+#include <utility>
+
namespace tvm {
namespace auto_scheduler {
/*!
- * \brief A brief example of the search policy which always returns the initial naive schedule
- * (state), the formal search policy will continue to follow its design.
+ * \brief A simple example of the search policy which always returns the initial naive schedule
+ * (state).
* The key implementation for this structure is `Search()`, check `empty_policy.cc` for more
* details.
*/
@@ -43,13 +46,16 @@ class EmptyPolicyNode : public SearchPolicyNode {
State Search(int num_measure_trials, int early_stopping, int num_measures_per_round,
ProgramMeasurer measurer) final;
+ std::pair<Array<MeasureInput>, Array<MeasureResult>> ContinueSearchOneRound(
+ int num_measure, ProgramMeasurer measurer) final;
+
static constexpr const char* _type_key = "auto_scheduler.EmptyPolicy";
TVM_DECLARE_FINAL_OBJECT_INFO(EmptyPolicyNode, SearchPolicyNode);
private:
/*!
* \brief Use a sub function to generate several candidate states in each search round.
- * \returns Several generated states
+ * \returns The generated states
*/
Array<State> SearchOneRound();
};
diff --git a/src/auto_scheduler/search_policy/search_policy.cc b/src/auto_scheduler/search_policy/search_policy.cc
index d73bd91..8b6d22b 100644
--- a/src/auto_scheduler/search_policy/search_policy.cc
+++ b/src/auto_scheduler/search_policy/search_policy.cc
@@ -104,8 +104,13 @@ TVM_REGISTER_GLOBAL("auto_scheduler.SearchPolicyRunCallbacks")
}
});
-TVM_REGISTER_GLOBAL("auto_scheduler.SearchPolicySetTask")
- .set_body_typed([](SearchPolicy policy, SearchTask task) { policy->search_task = task; });
+TVM_REGISTER_GLOBAL("auto_scheduler.SearchPolicyContinueSearchOneRound")
+ .set_body_typed([](SearchPolicy policy, int num_measure, ProgramMeasurer measurer) {
+ Array<MeasureInput> inputs;
+ Array<MeasureResult> results;
+ std::tie(inputs, results) = policy->ContinueSearchOneRound(num_measure, measurer);
+ return Array<ObjectRef>{inputs, results};
+ });
TVM_REGISTER_GLOBAL("auto_scheduler.SearchPolicySetVerbose")
.set_body_typed([](SearchPolicy policy, int verbose) { policy->verbose = verbose; });
diff --git a/src/auto_scheduler/search_policy/sketch_policy.cc b/src/auto_scheduler/search_policy/sketch_policy.cc
index a89fa4b..8de17a6 100644
--- a/src/auto_scheduler/search_policy/sketch_policy.cc
+++ b/src/auto_scheduler/search_policy/sketch_policy.cc
@@ -157,6 +157,7 @@ State SketchPolicyNode::Search(int n_trials, int early_stopping, int num_measure
int ct = 0;
int empty_retry_count = GetIntParam(params, SketchParamKey::empty_retry_count);
+ Array<State> best_states, random_states;
Array<MeasureInput> inputs;
Array<MeasureResult> results;
while (ct < n_trials) {
@@ -168,8 +169,7 @@ State SketchPolicyNode::Search(int n_trials, int early_stopping, int num_measure
// Search one round to get promising states
PrintTitle("Search", verbose);
- Array<State> random_states;
- Array<State> best_states = SearchOneRound(num_random, &random_states);
+ best_states = SearchOneRound(num_random * 3, &random_states);
// Infer bound. This is necessary for computing the correct ToStr() for redundancy check
best_states = search_task->compute_dag.InferBound(best_states);
@@ -196,7 +196,7 @@ State SketchPolicyNode::Search(int n_trials, int early_stopping, int num_measure
// Measure candidate states
PrintTitle("Measure", verbose);
- measurer->Measure(search_task, GetRef<SearchPolicy>(this), inputs, &results);
+ results = measurer->Measure(search_task, GetRef<SearchPolicy>(this), inputs);
ct += inputs.size();
// Check if reach the early stopping condition
@@ -218,15 +218,45 @@ State SketchPolicyNode::Search(int n_trials, int early_stopping, int num_measure
}
}
-Array<State> SketchPolicyNode::SearchOneRound(int num_random_states, Array<State>* random_states) {
- // Temporal object to be used if the input pointer is nullptr
- Array<State> temp_random_states;
- if (random_states == nullptr) {
- random_states = &temp_random_states;
- } else {
- random_states->clear();
+std::pair<Array<MeasureInput>, Array<MeasureResult>> SketchPolicyNode::ContinueSearchOneRound(
+ int num_measure, ProgramMeasurer measurer) {
+ num_measure_per_iter_ = num_measure;
+
+ Array<State> best_states, random_states;
+ Array<MeasureInput> inputs;
+ Array<MeasureResult> results;
+ int num_random = static_cast<int>(GetDoubleParam(params, "eps_greedy") * num_measure);
+
+ // Search one round to get promising states
+ PrintTitle("Search", verbose);
+ best_states = SearchOneRound(num_random * 3, &random_states);
+
+ // Infer bound. This is necessary for computing the correct ToStr() for redundancy check
+ best_states = search_task->compute_dag.InferBound(best_states);
+ random_states = search_task->compute_dag.InferBound(random_states);
+
+ // Pick `num_measure` states to measure, check hash to remove already measured states
+ // Also pick some random states to do eps-greedy
+ inputs = PickStatesWithEpsGreedy(best_states, random_states, num_measure);
+
+ // Measure candidate states
+ PrintTitle("Measure", verbose);
+ results = measurer->Measure(search_task, GetRef<SearchPolicy>(this), inputs);
+
+ // Update the throughputs of measured states. These states will join the EvolutionarySearch in
+ // later search rounds.
+ for (const auto& res : results) {
+ measured_states_throughputs_.push_back(1.0 / FloatArrayMean(res->costs));
}
+ // Update the cost model
+ PrintTitle("Train cost model", verbose);
+ program_cost_model->Update(inputs, results);
+
+ return std::make_pair(std::move(inputs), std::move(results));
+}
+
+Array<State> SketchPolicyNode::SearchOneRound(int num_random_states, Array<State>* random_states) {
// Get parameters
int population = GetIntParam(params, SketchParamKey::EvolutionarySearch::population);
int num_use_measured =
@@ -245,8 +275,8 @@ Array<State> SketchPolicyNode::SearchOneRound(int num_random_states, Array<State
Array<State> init_population = SampleInitPopulation(
sketch_cache_, is_cost_model_reasonable ? population - num_use_measured : population);
- // 3. If the cost model is useless (i.e. RandomCostModel), just random pick some generated
- // states, else perform evolutionary search
+ // 3. Perform evolutionary search if a cost model is utilized. Otherwise,
+ // just return some random states.
if (is_cost_model_reasonable) {
// Also insert already measured good states to the initial population
std::vector<int> indices = Argsort(measured_states_throughputs_);
@@ -254,11 +284,13 @@ Array<State> SketchPolicyNode::SearchOneRound(int num_random_states, Array<State
init_population.push_back(measured_states_vector_[indices[i]]);
}
// Sample some random states for eps-greedy
- *random_states = RandomSampleStates(init_population, &rand_gen, num_random_states * 3);
+ if (num_random_states > 0 && random_states != nullptr) {
+ *random_states = RandomSampleStates(init_population, &rand_gen, num_random_states);
+ }
return EvolutionarySearch(init_population, num_measure_per_iter_ * 2);
} else {
PruneInvalidState(search_task, &init_population);
- return RandomSampleStates(init_population, &rand_gen, num_measure_per_iter_ * 3);
+ return RandomSampleStates(init_population, &rand_gen, num_measure_per_iter_ * 2);
}
}
@@ -347,10 +379,7 @@ Array<State> SketchPolicyNode::SampleInitPopulation(const Array<State>& sketches
support::parallel_for(0, out_size - out_states.size(),
[this, &temp_states, &sketches, &rand_gens](int index) {
- // Random choose a starting sketch
- // TODO(jcf94, merrymercy): Maybe choose sketches in different
- // possibility for they may have different potential on generating state
- // with better performance
+ // Randomly choose a sketch
State tmp_s = sketches[(rand_gens[index])() % sketches.size()];
// Derivation rule based enumeration
bool valid = true;
@@ -472,6 +501,8 @@ Array<State> SketchPolicyNode::EvolutionarySearch(const Array<State>& init_popul
// Compute selection probability
ComputePrefixSumProb(pop_scores, &pop_selection_probs);
+ // TODO(merrymercy, comaniac): add crossover.
+
// Do mutation
while (pnext->size() < population) {
State tmp_s = (*pnow)[RandomChoose(pop_selection_probs, &rand_gen)];
diff --git a/src/auto_scheduler/search_policy/sketch_policy.h b/src/auto_scheduler/search_policy/sketch_policy.h
index 21aaa6e..edaa89e 100644
--- a/src/auto_scheduler/search_policy/sketch_policy.h
+++ b/src/auto_scheduler/search_policy/sketch_policy.h
@@ -19,13 +19,15 @@
/*!
* \file auto_scheduler/search_policy/sketch_policy.h
- * \brief The search policy that searches in a hierarchical search space defined by sketches.
- * The policy randomly samples programs from the space defined by sketches and use evolutionary
- * search to fine-tune them.
+ * \brief This search policy constructs a search space according to the compute declaration.
+ * It then randomly samples programs from the search space and uses evolutionary search with a
+ * learned cost model to fine-tune the sampled programs.
+ * The final optimized programs are sent to actual hardware for measurement.
+ * The above process is repeated until the auto-scheduler runs out of time budget.
*
* Reference:
* L. Zheng, C. Jia, M. Sun, Z. Wu, C. Yu, et al. "Ansor : Generating High-Performance Tensor
- * Programs for Deep Learning." arXiv preprint arXiv:2006.06762 (2020).
+ * Programs for Deep Learning." (OSDI 2020).
*/
#ifndef TVM_AUTO_SCHEDULER_SEARCH_POLICY_SKETCH_POLICY_H_
@@ -106,6 +108,9 @@ class SketchPolicyNode : public SearchPolicyNode {
State Search(int num_measure_trials, int early_stopping, int num_measures_per_round,
ProgramMeasurer measurer) final;
+ std::pair<Array<MeasureInput>, Array<MeasureResult>> ContinueSearchOneRound(
+ int num_measure, ProgramMeasurer measurer) final;
+
/*!
* \brief Generate sketches.
* \return The generated sketches(states).
diff --git a/src/auto_scheduler/search_policy/sketch_policy_rules.cc b/src/auto_scheduler/search_policy/sketch_policy_rules.cc
index 99188d4..b6ad4d3 100644
--- a/src/auto_scheduler/search_policy/sketch_policy_rules.cc
+++ b/src/auto_scheduler/search_policy/sketch_policy_rules.cc
@@ -19,7 +19,8 @@
/*!
* \file auto_scheduler/search_policy/sketch_policy_rules.cc
- * \brief Rules defined to generate the sketches and initial sampled states in SketchPolicy.
+ * \brief Rules for generating the sketches, sampling the initial population, and mutating the
+ * population in SketchPolicy.
*/
#include "sketch_policy_rules.h"
@@ -317,7 +318,7 @@ SketchGenerationRule::ConditionKind RuleCrossThreadReduction::MeetCondition(
const SketchPolicyNode& policy, const State& state, int stage_id) const {
CHECK(IsGPUTask(policy.search_task));
- // If it is an intermidiate state created by RuleAddCacheWrite,
+ // If it is an intermediate state created by RuleAddCacheWrite,
// we just skip it.
if (HasCacheWriteStage(state, stage_id)) {
return ConditionKind::kSkip;
@@ -1116,6 +1117,10 @@ PopulationGenerationRule::ResultKind MutateParallel::Apply(SketchPolicyNode* pol
}
}
+ if (max_fusable_iter_id == 0) {
+ return ResultKind::kInvalid;
+ }
+
// Randomly pick one granularity
int fuse_to_iter_id = (*rand_gen)() % max_fusable_iter_id + 1;
Array<Integer> fused_ids;
diff --git a/src/auto_scheduler/search_policy/sketch_policy_rules.h b/src/auto_scheduler/search_policy/sketch_policy_rules.h
index 035dc89..046f036 100644
--- a/src/auto_scheduler/search_policy/sketch_policy_rules.h
+++ b/src/auto_scheduler/search_policy/sketch_policy_rules.h
@@ -19,7 +19,8 @@
/*!
* \file auto_scheduler/search_policy/sketch_policy_rules.h
- * \brief Rules defined to generate the sketches and initial sampled states in SketchPolicy.
+ * \brief Rules for generating the sketches, sampling the initial population, and mutating the
+ * population in SketchPolicy.
*/
#ifndef TVM_AUTO_SCHEDULER_SEARCH_POLICY_SKETCH_POLICY_RULES_H_
diff --git a/tests/python/unittest/test_auto_scheduler_task_scheduler.py b/tests/python/unittest/test_auto_scheduler_task_scheduler.py
new file mode 100644
index 0000000..72b998a
--- /dev/null
+++ b/tests/python/unittest/test_auto_scheduler_task_scheduler.py
@@ -0,0 +1,112 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+""" Test task scheduler """
+
+import tempfile
+
+import numpy as np
+
+from tvm import auto_scheduler
+
+from test_auto_scheduler_common import matmul_auto_scheduler_test
+
+
+def test_task_scheduler_round_robin():
+ tasks = []
+ for n in [2, 4, 8]:
+ tasks.append(auto_scheduler.create_task(matmul_auto_scheduler_test, (n, n, n), "llvm"))
+
+ def objective_func(costs):
+ return sum(costs)
+
+ with tempfile.NamedTemporaryFile() as fp:
+ log_file = fp.name
+ num_trials_per_task = 2
+
+ # Tune all tasks
+ tune_option = auto_scheduler.TuningOptions(
+ num_measure_trials=num_trials_per_task * len(tasks),
+ num_measures_per_round=1,
+ measure_callbacks=[auto_scheduler.RecordToFile(log_file)],
+ )
+ task_scheduler = auto_scheduler.TaskScheduler(tasks, objective_func, strategy="round-robin")
+ task_scheduler.tune(tune_option, search_policy="sketch.random")
+
+ # Check the result of round robin
+ counters = {}
+ for task in tasks:
+ counters[task.workload_key] = 0
+
+ for inp, res in auto_scheduler.load_records(log_file):
+ counters[inp.task.workload_key] += 1
+
+ for task in tasks:
+ assert counters[task.workload_key] == num_trials_per_task
+
+ # test continuous tuning (restoring the status)
+ task_scheduler = auto_scheduler.TaskScheduler(
+ tasks, objective_func, strategy="round-robin", load_log_file=log_file
+ )
+ tune_option = auto_scheduler.TuningOptions(
+ num_measure_trials=len(tasks),
+ num_measures_per_round=1,
+ )
+ task_scheduler.tune(tune_option, search_policy="sketch.random")
+
+
+def test_task_scheduler_gradient():
+ tasks = []
+ for n in [2, 4]:
+ tasks.append(auto_scheduler.create_task(matmul_auto_scheduler_test, (n, n, n), "llvm"))
+
+ def objective_func(costs):
+ return costs[0]
+
+ with tempfile.NamedTemporaryFile() as fp:
+ log_file = fp.name
+
+ n_trials = 5
+
+ # Tune all tasks
+ tune_option = auto_scheduler.TuningOptions(
+ num_measure_trials=n_trials,
+ num_measures_per_round=1,
+ measure_callbacks=[auto_scheduler.RecordToFile(log_file)],
+ )
+ task_scheduler = auto_scheduler.TaskScheduler(tasks, objective_func)
+
+ # Forcibly overwrite the initial values.
+ # This makes the test more stable on slow CI machines
+ task_scheduler.best_costs = np.array([1e2, 1e-8])
+
+ task_scheduler.tune(tune_option, search_policy="sketch.random")
+
+ # Check the allocation results
+ counters = {}
+ for task in tasks:
+ counters[task.workload_key] = 0
+
+ for inp, res in auto_scheduler.load_records(log_file):
+ counters[inp.task.workload_key] += 1
+
+ assert counters[tasks[0].workload_key] == n_trials - 1
+ assert counters[tasks[1].workload_key] == 1
+
+
+if __name__ == "__main__":
+ test_task_scheduler_round_robin()
+ test_task_scheduler_gradient()
diff --git a/tutorials/auto_scheduler/tune_conv2d_layer_cuda.py b/tutorials/auto_scheduler/tune_conv2d_layer_cuda.py
index 5004a5f..b800eb4 100644
--- a/tutorials/auto_scheduler/tune_conv2d_layer_cuda.py
+++ b/tutorials/auto_scheduler/tune_conv2d_layer_cuda.py
@@ -25,10 +25,9 @@ Auto-scheduling a convolution layer for GPU
Different from the existing :ref:`autotvm <tutorials-autotvm-sec>` which relies on
manual templates to define the search space, the auto-scheduler does not require any templates.
-The auto-scheduler is template-free, so users only need to write the computation declaration without
-any schedule commands or templates.
-The auto-scheduler can automatically generate a large
-search space and find a good schedule in the space.
+Users only need to write the computation declaration without any schedule commands or templates.
+The auto-scheduler can automatically generate a large search space and
+find a good schedule in the space.
We use a convolution layer as an example in this tutorial.
"""
diff --git a/tutorials/auto_scheduler/tune_matmul_x86.py b/tutorials/auto_scheduler/tune_matmul_x86.py
index e1e0115..35c4744 100644
--- a/tutorials/auto_scheduler/tune_matmul_x86.py
+++ b/tutorials/auto_scheduler/tune_matmul_x86.py
@@ -22,10 +22,9 @@ Auto-scheduling matrix multiplication for CPU
Different from the existing :ref:`autotvm <tutorials-autotvm-sec>` which relies on
manual templates to define the search space, the auto-scheduler does not require any templates.
-The auto-scheduler is template-free, so users only need to write the computation declaration without
-any schedule commands or templates.
-The auto-scheduler can automatically generate a large
-search space and find a good schedule in the space.
+Users only need to write the computation declaration without any schedule commands or templates.
+The auto-scheduler can automatically generate a large search space and
+find a good schedule in the space.
We use matrix multiplication as an example in this tutorial.
"""