Posted to commits@tvm.apache.org by ju...@apache.org on 2022/06/07 03:02:23 UTC
[tvm] branch main updated: [MetaSchedule] Evo Independence from TaskScheduler (#11590)
This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 68dcecc926 [MetaSchedule] Evo Independence from TaskScheduler (#11590)
68dcecc926 is described below
commit 68dcecc926f890429a8f2cba9ce55eab6a18fa6e
Author: Junru Shao <ju...@gmail.com>
AuthorDate: Mon Jun 6 20:02:18 2022 -0700
[MetaSchedule] Evo Independence from TaskScheduler (#11590)
Per discussion with @Kathryn-cat, we realized that the current API
design is needlessly verbose when only a single task is being tuned: a
dummy task scheduler still has to be constructed just to supply
`EvolutionarySearch` with a proper `CostModel` and `Database`. This PR
fixes the UX issue by passing the database and cost model directly to
`PreTuning`, making search strategies independent of the task scheduler.
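Below is a minimal sketch of the resulting single-task flow, adapted from
the updated unit tests in this commit. `Matmul`, `_schedule_matmul_small`,
and `DummyMutator` are placeholders taken from the test suite for the
user's own module, schedule function, and mutator, and the exact import
paths are assumptions rather than a prescribed API:

    import tvm
    from tvm import meta_schedule as ms
    from tvm.meta_schedule import TuneContext
    from tvm.meta_schedule.search_strategy import EvolutionarySearch
    from tvm.meta_schedule.space_generator import ScheduleFn
    from tvm.meta_schedule.testing import DummyDatabase

    # A single-task TuneContext; note that no TaskScheduler is constructed.
    context = TuneContext(
        mod=Matmul,  # placeholder: the IRModule to tune
        space_generator=ScheduleFn(sch_fn=_schedule_matmul_small),  # placeholder schedule fn
        search_strategy=EvolutionarySearch(
            num_trials_per_iter=10,
            max_trials_per_task=100,
            population_size=5,
            init_measured_ratio=0.1,
            init_min_unmeasured=50,
            genetic_num_iters=3,
            genetic_mutate_prob=0.5,
            genetic_max_fail_count=10,
            eps_greedy=0.9,
        ),
        mutator_probs={DummyMutator(): 1.0},  # placeholder mutator
        target=tvm.target.Target("llvm"),
        num_threads=1,
    )
    context.initialize()
    strategy = context.search_strategy
    # The database and cost model now go straight to `pre_tuning`;
    # previously they had to be routed through a dummy task scheduler.
    strategy.pre_tuning(
        context.space_generator.generate_design_space(context.mod),
        database=DummyDatabase(),
        cost_model=ms.cost_model.RandomModel(),
    )

Compare this with the removed test code in the diff below, which had to
build a `RoundRobin` scheduler with a builder, runner, and database solely
to wire `EvolutionarySearch` up for `pre_tuning`.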
---
include/tvm/meta_schedule/search_strategy.h | 17 ++-
include/tvm/meta_schedule/task_scheduler.h | 20 +--
include/tvm/meta_schedule/tune_context.h | 2 -
.../search_strategy/search_strategy.py | 24 +++-
.../meta_schedule/task_scheduler/gradient_based.py | 10 +-
.../meta_schedule/task_scheduler/round_robin.py | 10 +-
.../meta_schedule/task_scheduler/task_scheduler.py | 10 +-
.../measure_callback/add_to_database.cc | 5 +-
.../search_strategy/evolutionary_search.cc | 148 +++++++++++----------
src/meta_schedule/search_strategy/replay_func.cc | 48 ++++---
src/meta_schedule/search_strategy/replay_trace.cc | 63 +++++----
.../search_strategy/search_strategy.cc | 7 +
src/meta_schedule/task_scheduler/gradient_based.cc | 7 +-
src/meta_schedule/task_scheduler/round_robin.cc | 7 +-
src/meta_schedule/task_scheduler/task_scheduler.cc | 6 +-
.../test_meta_schedule_measure_callback.py | 22 ++-
.../unittest/test_meta_schedule_search_strategy.py | 93 ++++++-------
.../unittest/test_meta_schedule_task_scheduler.py | 60 ++++-----
18 files changed, 298 insertions(+), 261 deletions(-)
diff --git a/include/tvm/meta_schedule/search_strategy.h b/include/tvm/meta_schedule/search_strategy.h
index 6895673a04..139de7c99d 100644
--- a/include/tvm/meta_schedule/search_strategy.h
+++ b/include/tvm/meta_schedule/search_strategy.h
@@ -113,12 +113,16 @@ class SearchStrategyNode : public runtime::Object {
/*!
* \brief Pre-tuning for the search strategy.
- * \param design_spaces The design spaces for pre-tuning.
+ * \param design_spaces The design spaces used during the tuning process.
+ * \param database The database used during the tuning process.
+ * \param cost_model The cost model used during the tuning process.
* \note Pre-tuning is supposed to be called before the tuning process and after the
* initialization. Because the search strategy is stateful, we can always call pretuning
* and reset the search strategy.
*/
- virtual void PreTuning(const Array<tir::Schedule>& design_spaces) = 0;
+ virtual void PreTuning(const Array<tir::Schedule>& design_spaces,
+ const Optional<Database>& database,
+ const Optional<CostModel>& cost_model) = 0;
/*!
* \brief Post-tuning for the search strategy.
@@ -159,7 +163,8 @@ class PySearchStrategyNode : public SearchStrategyNode {
* \brief The function type of `PreTuning` method.
* \param design_spaces The design spaces for pre-tuning.
*/
- using FPreTuning = runtime::TypedPackedFunc<void(const Array<tir::Schedule>&)>;
+ using FPreTuning = runtime::TypedPackedFunc<void(
+ const Array<tir::Schedule>&, const Optional<Database>&, const Optional<CostModel>&)>;
/*! \brief The function type of `PostTuning` method. */
using FPostTuning = runtime::TypedPackedFunc<void()>;
/*!
@@ -199,10 +204,8 @@ class PySearchStrategyNode : public SearchStrategyNode {
this->f_initialize_with_tune_context(context);
}
- void PreTuning(const Array<tir::Schedule>& design_spaces) final {
- ICHECK(f_pre_tuning != nullptr) << "PySearchStrategy's PreTuning method not implemented!";
- this->f_pre_tuning(design_spaces);
- }
+ void PreTuning(const Array<tir::Schedule>& design_spaces, const Optional<Database>& database,
+ const Optional<CostModel>& cost_model) final;
void PostTuning() final {
ICHECK(f_post_tuning != nullptr) << "PySearchStrategy's PostTuning method not implemented!";
diff --git a/include/tvm/meta_schedule/task_scheduler.h b/include/tvm/meta_schedule/task_scheduler.h
index 7453c2b484..5953a2c3e4 100644
--- a/include/tvm/meta_schedule/task_scheduler.h
+++ b/include/tvm/meta_schedule/task_scheduler.h
@@ -74,13 +74,13 @@ class TaskSchedulerNode : public runtime::Object {
/*! \brief The runner of the scheduler. */
Runner runner{nullptr};
/*! \brief The database of the scheduler. */
- Database database{nullptr};
- /*! \brief The maximum number of trials allowed. */
- int max_trials;
+ Optional<Database> database;
/*! \brief The cost model of the scheduler. */
Optional<CostModel> cost_model;
/*! \brief The list of measure callbacks of the scheduler. */
Array<MeasureCallback> measure_callbacks;
+ /*! \brief The maximum number of trials allowed. */
+ int max_trials;
/*! \brief The number of trials already conducted. */
int num_trials_already;
/*! \brief The tuning task's logging function. */
@@ -94,9 +94,9 @@ class TaskSchedulerNode : public runtime::Object {
v->Visit("builder", &builder);
v->Visit("runner", &runner);
v->Visit("database", &database);
- v->Visit("max_trials", &max_trials);
v->Visit("cost_model", &cost_model);
v->Visit("measure_callbacks", &measure_callbacks);
+ v->Visit("max_trials", &max_trials);
v->Visit("num_trials_already", &num_trials_already);
// `logging_func` is not visited
}
@@ -243,10 +243,10 @@ class TaskScheduler : public runtime::ObjectRef {
TVM_DLL static TaskScheduler RoundRobin(Array<TuneContext> tasks, //
Builder builder, //
Runner runner, //
- Database database, //
- int max_trials, //
+ Optional<Database> database, //
Optional<CostModel> cost_model, //
Optional<Array<MeasureCallback>> measure_callbacks, //
+ int max_trials, //
PackedFunc logging_func);
/*!
* \brief Create a task scheduler that fetches tasks in a gradient based fashion.
@@ -268,10 +268,10 @@ class TaskScheduler : public runtime::ObjectRef {
Array<FloatImm> task_weights, //
Builder builder, //
Runner runner, //
- Database database, //
- int max_trials, //
+ Optional<Database> database, //
Optional<CostModel> cost_model, //
Optional<Array<MeasureCallback>> measure_callbacks, //
+ int max_trials, //
PackedFunc logging_func, //
double alpha, //
int window_size, //
@@ -297,10 +297,10 @@ class TaskScheduler : public runtime::ObjectRef {
Array<TuneContext> tasks, //
Builder builder, //
Runner runner, //
- Database database, //
- int max_trials, //
+ Optional<Database> database, //
Optional<CostModel> cost_model, //
Optional<Array<MeasureCallback>> measure_callbacks, //
+ int max_trials, //
PackedFunc logging_func, //
PyTaskSchedulerNode::FTune f_tune, //
PyTaskSchedulerNode::FInitializeTask f_initialize_task, //
diff --git a/include/tvm/meta_schedule/tune_context.h b/include/tvm/meta_schedule/tune_context.h
index faa24fc99f..d63fb819f3 100644
--- a/include/tvm/meta_schedule/tune_context.h
+++ b/include/tvm/meta_schedule/tune_context.h
@@ -61,8 +61,6 @@ class TuneContextNode : public runtime::Object {
/*! \brief The number of threads to be used. */
int num_threads;
- /*! \brief The task scheduler that owns the tune context */
- const TaskSchedulerNode* task_scheduler;
/*! \brief Whether the tuning task has been stopped or finished. */
bool is_terminated;
/*! \brief The measure candidates. */
diff --git a/python/tvm/meta_schedule/search_strategy/search_strategy.py b/python/tvm/meta_schedule/search_strategy/search_strategy.py
index 07c47f01d1..14b46a0785 100644
--- a/python/tvm/meta_schedule/search_strategy/search_strategy.py
+++ b/python/tvm/meta_schedule/search_strategy/search_strategy.py
@@ -18,7 +18,7 @@
Meta Schedule search strategy that generates the measure
candidates for measurement.
"""
-from typing import Callable, List, Optional, TYPE_CHECKING
+from typing import TYPE_CHECKING, Callable, List, Optional
from tvm._ffi import register_object
from tvm.runtime import Object
@@ -29,6 +29,8 @@ from ..arg_info import ArgInfo
from ..runner import RunnerResult
if TYPE_CHECKING:
+ from ..cost_model import CostModel
+ from ..database import Database
from ..tune_context import TuneContext
@@ -87,15 +89,29 @@ class SearchStrategy(Object):
self, context
)
- def pre_tuning(self, design_spaces: List[Schedule]) -> None:
+ def pre_tuning(
+ self,
+ design_spaces: List[Schedule],
+ database: Optional["Database"] = None,
+ cost_model: Optional["CostModel"] = None,
+ ) -> None:
"""Pre-tuning for the search strategy.
Parameters
----------
design_spaces : List[Schedule]
- The design spaces for pre-tuning.
+ The design spaces used during the tuning process.
+ database : Optional[Database] = None
+ The database used during the tuning process.
+ cost_model : Optional[CostModel] = None
+ The cost model used during the tuning process.
"""
- _ffi_api.SearchStrategyPreTuning(self, design_spaces) # type: ignore # pylint: disable=no-member
+ _ffi_api.SearchStrategyPreTuning( # type: ignore # pylint: disable=no-member
+ self,
+ design_spaces,
+ database,
+ cost_model,
+ )
def post_tuning(self) -> None:
"""Post-tuning for the search strategy."""
diff --git a/python/tvm/meta_schedule/task_scheduler/gradient_based.py b/python/tvm/meta_schedule/task_scheduler/gradient_based.py
index 6234449bf0..20d32dd1c5 100644
--- a/python/tvm/meta_schedule/task_scheduler/gradient_based.py
+++ b/python/tvm/meta_schedule/task_scheduler/gradient_based.py
@@ -45,11 +45,11 @@ class GradientBased(TaskScheduler):
task_weights: List[float],
builder: Builder,
runner: Runner,
- database: Database,
- max_trials: int,
*,
+ database: Database,
cost_model: Optional[CostModel] = None,
measure_callbacks: Optional[List[MeasureCallback]] = None,
+ max_trials: int,
alpha: float = 0.2,
window_size: int = 3,
seed: int = -1,
@@ -68,12 +68,12 @@ class GradientBased(TaskScheduler):
The runner.
database : Database
The database.
- max_trials : int
- The maximum number of trials to run.
cost_model : CostModel, default None.
The cost model of the scheduler.
measure_callbacks : Optional[List[MeasureCallback]] = None
The list of measure callbacks of the scheduler.
+ max_trials : int
+ The maximum number of trials to run.
alpha : float = 0.2
The parameter alpha in gradient computation.
window_size : int = 3
@@ -88,9 +88,9 @@ class GradientBased(TaskScheduler):
builder,
runner,
database,
- max_trials,
cost_model,
measure_callbacks,
+ max_trials,
make_logging_func(logger),
alpha,
window_size,
diff --git a/python/tvm/meta_schedule/task_scheduler/round_robin.py b/python/tvm/meta_schedule/task_scheduler/round_robin.py
index a461358283..ed395643bb 100644
--- a/python/tvm/meta_schedule/task_scheduler/round_robin.py
+++ b/python/tvm/meta_schedule/task_scheduler/round_robin.py
@@ -60,11 +60,11 @@ class RoundRobin(TaskScheduler):
task_weights: List[float],
builder: Builder,
runner: Runner,
- database: Database,
- max_trials: int,
*,
+ database: Database,
cost_model: Optional[CostModel] = None,
measure_callbacks: Optional[List[MeasureCallback]] = None,
+ max_trials: int,
) -> None:
"""Constructor.
@@ -80,12 +80,12 @@ class RoundRobin(TaskScheduler):
The runner.
database : Database
The database.
- max_trials : int
- The maximum number of trials.
cost_model : Optional[CostModel]
The cost model.
measure_callbacks: Optional[List[MeasureCallback]]
The list of measure callbacks of the scheduler.
+ max_trials : int
+ The maximum number of trials.
"""
del task_weights
self.__init_handle_by_constructor__(
@@ -94,8 +94,8 @@ class RoundRobin(TaskScheduler):
builder,
runner,
database,
- max_trials,
cost_model,
measure_callbacks,
+ max_trials,
make_logging_func(logger),
)
diff --git a/python/tvm/meta_schedule/task_scheduler/task_scheduler.py b/python/tvm/meta_schedule/task_scheduler/task_scheduler.py
index 4454078a6f..3d57a6b01b 100644
--- a/python/tvm/meta_schedule/task_scheduler/task_scheduler.py
+++ b/python/tvm/meta_schedule/task_scheduler/task_scheduler.py
@@ -31,7 +31,6 @@ from ..runner import Runner, RunnerResult
from ..tune_context import TuneContext
from ..utils import make_logging_func
-
logger = logging.getLogger(__name__) # pylint: disable=invalid-name
@@ -177,9 +176,9 @@ class PyTaskScheduler:
"builder",
"runner",
"database",
- "max_trials",
"cost_model",
"measure_callbacks",
+ "max_trials",
],
"methods": [
"tune",
@@ -195,18 +194,19 @@ class PyTaskScheduler:
tasks: List[TuneContext],
builder: Builder,
runner: Runner,
- database: Database,
- max_trials: int,
+ *,
+ database: Optional[Database] = None,
cost_model: Optional[CostModel] = None,
measure_callbacks: Optional[List[MeasureCallback]] = None,
+ max_trials: int,
):
self.tasks = tasks
self.builder = builder
self.runner = runner
self.database = database
- self.max_trials = max_trials
self.cost_model = cost_model
self.measure_callbacks = measure_callbacks
+ self.max_trials = max_trials
def tune(self) -> None:
"""Auto-tuning."""
diff --git a/src/meta_schedule/measure_callback/add_to_database.cc b/src/meta_schedule/measure_callback/add_to_database.cc
index 20581f4630..0988da0414 100644
--- a/src/meta_schedule/measure_callback/add_to_database.cc
+++ b/src/meta_schedule/measure_callback/add_to_database.cc
@@ -27,8 +27,11 @@ class AddToDatabaseNode : public MeasureCallbackNode {
const Array<MeasureCandidate>& measure_candidates,
const Array<BuilderResult>& builder_results,
const Array<RunnerResult>& runner_results) final {
+ if (!task_scheduler->database.defined()) {
+ return;
+ }
TuneContext task = task_scheduler->tasks[task_id];
- Database database = task_scheduler->database;
+ Database database = task_scheduler->database.value();
Workload workload = database->CommitWorkload(task->mod.value());
Target target = task->target.value();
ICHECK_EQ(runner_results.size(), measure_candidates.size());
diff --git a/src/meta_schedule/search_strategy/evolutionary_search.cc b/src/meta_schedule/search_strategy/evolutionary_search.cc
index bdef26ef87..8b36a95217 100644
--- a/src/meta_schedule/search_strategy/evolutionary_search.cc
+++ b/src/meta_schedule/search_strategy/evolutionary_search.cc
@@ -246,13 +246,41 @@ class EvolutionarySearchNode : public SearchStrategyNode {
int ed;
/*! \brief The counter of returning empty results. */
int num_empty_iters;
-
- explicit State(EvolutionarySearchNode* self, Array<tir::Trace> design_spaces)
+ /*! \brief The metadata of the function arguments. */
+ Array<ArgInfo> args_info_{nullptr};
+ /*! \brief Per-thread data including the module to be tuned and random state. */
+ std::vector<PerThreadData> per_thread_data_;
+ /*!
+ * \brief The workloads that are already measured.
+ * TODO(junrushao1994): add records from the database to avoid re-measuring.
+ */
+ IRModuleSet measured_workloads_;
+ /*! \brief A Database for selecting useful candidates. */
+ Database database_{nullptr};
+ /*! \brief A cost model helping to explore the search space. */
+ CostModel cost_model_{nullptr};
+ /*! \brief The token registered for the given workload in database. */
+ Workload token_{nullptr};
+
+ explicit State(EvolutionarySearchNode* self, Array<tir::Trace> design_spaces, Database database,
+ CostModel cost_model)
: self(self),
design_spaces(design_spaces),
st(0),
ed(self->num_trials_per_iter),
- num_empty_iters(0) {}
+ num_empty_iters(0) {
+ const TuneContextNode* ctx = self->context_;
+ IRModule mod = ctx->mod.value();
+ this->args_info_ = ArgInfo::FromPrimFunc(FindEntryFunc(mod));
+ this->per_thread_data_.resize(ctx->num_threads);
+ for (PerThreadData& data : this->per_thread_data_) {
+ data.mod = DeepCopyIRModule(mod);
+ data.rand_state = ForkSeed(&self->rand_state_);
+ }
+ this->database_ = database;
+ this->cost_model_ = cost_model;
+ this->token_ = database->CommitWorkload(mod);
+ }
/*!
* \brief Pick up best candidates from database.
@@ -293,33 +321,10 @@ class EvolutionarySearchNode : public SearchStrategyNode {
/*! \brief The tuning context of the evolutionary search strategy. */
const TuneContextNode* context_{nullptr};
- /*! \brief The target for the workload. */
- Target target_{nullptr};
- /*! \brief The metadata of the function arguments. */
- Array<ArgInfo> args_info_{nullptr};
- /*! \brief A Database for selecting useful candidates. */
- Database database_{nullptr};
- /*! \brief A cost model helping to explore the search space */
- CostModel cost_model_{nullptr};
- /*! \brief The postprocessors. */
- Array<Postproc> postprocs_{nullptr};
- /*! \brief Mutators and their probability mass */
- Map<Mutator, FloatImm> mutator_probs_{nullptr};
- /*! \brief The number of threads to use. To be initialized with TuneContext. */
- int num_threads_;
/*! \brief The random state. To be initialized with TuneContext. */
TRandState rand_state_;
- /*! \brief Pre thread data including module to be tuned and random state. */
- std::vector<PerThreadData> per_thread_data_;
/*! \brief The state of the search strategy. */
std::unique_ptr<State> state_ = nullptr;
- /*! \brief The token registered for the given workload in database. */
- Workload token_{nullptr};
- /*!
- * \brief The workloads that are already measured.
- * TODO(junrushao1994): add records from the database to avoid re-measuring.
- * */
- IRModuleSet measured_workloads_;
/*** Configuration: global ***/
/*! \brief The number of trials per iteration. */
@@ -351,15 +356,7 @@ class EvolutionarySearchNode : public SearchStrategyNode {
void VisitAttrs(tvm::AttrVisitor* v) {
// `context_` is not visited
- // `target_` is not visited
- // `args_info_` is not visited
- // `database` is not visited
- // `cost_model` is not visited
- // `postprocs` is not visited
- // `mutator_probs_` is not visited
- // `num_threads` is not visited
// `rand_state_` is not visited
- // `per_thread_data_` is not visited
// `state_` is not visited
/*** Configuration: global ***/
@@ -386,39 +383,41 @@ class EvolutionarySearchNode : public SearchStrategyNode {
CHECK(context->num_threads > 0) << "Number of threads has to be larger than 0.";
CHECK(context->target.defined()) << "Target must be defined!";
this->context_ = context.get();
- this->target_ = context->target.value();
- this->args_info_ = ArgInfo::FromPrimFunc(FindEntryFunc(context->mod.value()));
- this->mutator_probs_ = context->mutator_probs;
- this->postprocs_ = context->postprocs;
- this->num_threads_ = context->num_threads;
this->rand_state_ = ForkSeed(&context->rand_state);
- CHECK(context->task_scheduler != nullptr)
- << "ValueError: TaskScheduler is not defined in TuneContext";
- this->cost_model_ = context->task_scheduler->cost_model.value();
- this->database_ = context->task_scheduler->database;
- this->token_ = this->database_->CommitWorkload(context->mod.value());
- this->per_thread_data_.resize(this->num_threads_);
- for (const auto& kv : this->mutator_probs_) {
+ for (const auto& kv : context->mutator_probs) {
double mass = kv.second->value;
TVM_META_SCHEDULE_CHECK_PROB_RANGE(mass, "mutator_probs");
}
- for (PerThreadData& data : this->per_thread_data_) {
- data.mod = DeepCopyIRModule(context->mod.value());
- data.rand_state = ForkSeed(&this->rand_state_);
- }
this->state_.reset();
}
- void PreTuning(const Array<Schedule>& design_spaces) final {
+ void PreTuning(const Array<Schedule>& design_spaces, const Optional<Database>& database,
+ const Optional<CostModel>& cost_model) final {
ICHECK(!design_spaces.empty());
+ CHECK(this->context_ != nullptr) << "ValueError: Did you forget to initialize the TuneContext?";
+ CHECK(database.defined())
+ << "ValueError: Database is not supplied in PreTuning. Evolutionary "
+ "search algorithm requires a database to be present, so that it "
+ "can sample from the previously-explored population. If you do not "
+ "intend to store data on disk, please use `tvm.meta_schedule.testing.DummyDatabase`";
+ CHECK(cost_model.defined())
+ << "ValueError: CostModel is not supplied in PreTuning. Evolutionary search "
+ "algorithm expects a cost model to filter out potentially less efficient kernels. If "
+ "you do not expect a cost model to help, please use "
+ "`tvm.meta_schedule.cost_model.RandomModel`";
+ if (this->state_ != nullptr) {
+ TVM_PY_LOG(WARNING, this->context_->logging_func)
+ << "EvolutionarySearch is already initialized.";
+ this->state_.reset();
+ }
ICHECK(this->state_ == nullptr);
- // Change to traces
Array<tir::Trace> design_space_traces;
design_space_traces.reserve(design_spaces.size());
for (const Schedule& space : design_spaces) {
design_space_traces.push_back(space->trace().value()->Simplified(true));
}
- this->state_ = std::make_unique<State>(this, design_space_traces);
+ this->state_ =
+ std::make_unique<State>(this, design_space_traces, database.value(), cost_model.value());
}
void PostTuning() final {
@@ -442,16 +441,16 @@ class EvolutionarySearchNode : public SearchStrategyNode {
std::vector<Schedule> EvolutionarySearchNode::State::PickBestFromDatabase(int num) {
std::vector<tir::Trace> measured_traces;
measured_traces.reserve(num);
- Array<TuningRecord> top_records = self->database_->GetTopK(self->token_, num);
+ Array<TuningRecord> top_records = this->database_->GetTopK(this->token_, num);
for (TuningRecord record : top_records) {
measured_traces.push_back(record->trace);
}
int actual_num = measured_traces.size();
- ThreadedTraceApply pp(self->postprocs_);
+ ThreadedTraceApply pp(self->context_->postprocs);
std::vector<Schedule> results(actual_num, Schedule{nullptr});
auto f_proc_measured = [this, &measured_traces, &results, &pp](int thread_id,
int trace_id) -> void {
- PerThreadData& data = self->per_thread_data_.at(thread_id);
+ PerThreadData& data = this->per_thread_data_.at(thread_id);
TRandState* rand_state = &data.rand_state;
const IRModule& mod = data.mod;
tir::Trace trace = measured_traces.at(trace_id);
@@ -464,17 +463,17 @@ std::vector<Schedule> EvolutionarySearchNode::State::PickBestFromDatabase(int nu
throw;
}
};
- support::parallel_for_dynamic(0, actual_num, self->num_threads_, f_proc_measured);
+ support::parallel_for_dynamic(0, actual_num, self->context_->num_threads, f_proc_measured);
return results;
}
std::vector<Schedule> EvolutionarySearchNode::State::SampleInitPopulation(int num) {
- ThreadedTraceApply pp(self->postprocs_);
+ ThreadedTraceApply pp(self->context_->postprocs);
std::vector<Schedule> out_schs;
while (static_cast<int>(out_schs.size()) < self->init_min_unmeasured) {
std::vector<Schedule> results(num, Schedule{nullptr});
auto f_proc_unmeasured = [this, &results, &pp](int thread_id, int trace_id) -> void {
- PerThreadData& data = self->per_thread_data_.at(thread_id);
+ PerThreadData& data = this->per_thread_data_.at(thread_id);
TRandState* rand_state = &data.rand_state;
const IRModule& mod = data.mod;
Schedule& result = results.at(trace_id);
@@ -485,7 +484,7 @@ std::vector<Schedule> EvolutionarySearchNode::State::SampleInitPopulation(int nu
result = sch.value();
}
};
- support::parallel_for_dynamic(0, num, self->num_threads_, f_proc_unmeasured);
+ support::parallel_for_dynamic(0, num, self->context_->num_threads, f_proc_unmeasured);
for (int i = 0; i < num; i++) {
if (results[i].defined()) {
out_schs.push_back(results[i]);
@@ -501,14 +500,14 @@ std::vector<Schedule> EvolutionarySearchNode::State::EvolveWithCostModel(
std::vector<Schedule> population, int num) {
ICHECK_GT(num, 0);
// The heap to record best schedule, we do not consider schedules that are already measured
- IRModuleSet exists = self->measured_workloads_;
+ IRModuleSet exists = this->measured_workloads_;
SizedHeap heap(num);
for (int iter = 0;; ++iter) {
// Predict normalized score with the cost model,
std::vector<double> scores = PredictNormalizedScore(population, //
GetRef<TuneContext>(self->context_), //
- self->cost_model_, //
- self->args_info_);
+ this->cost_model_, //
+ this->args_info_);
ICHECK_EQ(scores.size(), population.size());
for (int i = 0, n = population.size(); i < n; ++i) {
Schedule sch = population.at(i);
@@ -524,18 +523,18 @@ std::vector<Schedule> EvolutionarySearchNode::State::EvolveWithCostModel(
if (iter == self->genetic_num_iters) {
break;
}
- // Set threaded samplers, with probability from predicated normalized throughputs
- for (PerThreadData& data : self->per_thread_data_) {
- data.Set(scores, self->genetic_mutate_prob, self->mutator_probs_);
+ // Set threaded samplers, with probability from predicted normalized throughput
+ for (PerThreadData& data : this->per_thread_data_) {
+ data.Set(scores, self->genetic_mutate_prob, self->context_->mutator_probs);
}
- ThreadedTraceApply pp(self->postprocs_);
+ ThreadedTraceApply pp(self->context_->postprocs);
ConcurrentBitmask cbmask(self->population_size);
std::vector<Schedule> next_population(self->population_size, Schedule{nullptr});
// The worker function
auto f_find_candidate = [&cbmask, &population, &next_population, &pp, this](int thread_id,
int trace_id) {
// Prepare samplers
- PerThreadData& data = self->per_thread_data_.at(thread_id);
+ PerThreadData& data = this->per_thread_data_.at(thread_id);
TRandState* rand_state = &data.rand_state;
const IRModule& mod = data.mod;
std::function<int()>& trace_sampler = data.trace_sampler;
@@ -567,7 +566,8 @@ std::vector<Schedule> EvolutionarySearchNode::State::EvolveWithCostModel(
result = population.at(sampled_trace_id);
}
};
- support::parallel_for_dynamic(0, self->population_size, self->num_threads_, f_find_candidate);
+ support::parallel_for_dynamic(0, self->population_size, self->context_->num_threads,
+ f_find_candidate);
population.swap(next_population);
TVM_PY_LOG(INFO, self->context_->logging_func) << "Evolve iter #" << iter << " done. Summary:\n"
<< pp.SummarizeFailures();
@@ -607,7 +607,7 @@ std::vector<Schedule> EvolutionarySearchNode::State::PickWithEpsGreedy(
tir::SampleWithoutReplacement(&self->rand_state_, unmeasured.size(), unmeasured.size());
std::vector<Schedule> results;
results.reserve(num);
- IRModuleSet& measured_workloads = self->measured_workloads_;
+ IRModuleSet& measured_workloads = this->measured_workloads_;
for (int i = 0, i_bests = 0, i_rands = 0; i < num; ++i) {
bool has_best = i_bests < static_cast<int>(bests.size());
bool has_rand = i_rands < static_cast<int>(rands.size());
@@ -677,7 +677,7 @@ Optional<Array<MeasureCandidate>> EvolutionarySearchNode::State::GenerateMeasure
return NullOpt;
}
}
- return AssembleCandidates(picks, self->args_info_);
+ return AssembleCandidates(picks, this->args_info_);
}
void EvolutionarySearchNode::State::NotifyRunnerResults(
@@ -713,6 +713,12 @@ SearchStrategy SearchStrategy::EvolutionarySearch(int num_trials_per_iter, /
return SearchStrategy(n);
}
+class EvolutionarySearch : public SearchStrategy {
+ public:
+ TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(EvolutionarySearch, SearchStrategy,
+ EvolutionarySearchNode);
+};
+
TVM_REGISTER_NODE_TYPE(EvolutionarySearchNode);
TVM_REGISTER_GLOBAL("meta_schedule.SearchStrategyEvolutionarySearch")
.set_body_typed(SearchStrategy::EvolutionarySearch);
diff --git a/src/meta_schedule/search_strategy/replay_func.cc b/src/meta_schedule/search_strategy/replay_func.cc
index 878c872a65..1aaaaa09e8 100644
--- a/src/meta_schedule/search_strategy/replay_func.cc
+++ b/src/meta_schedule/search_strategy/replay_func.cc
@@ -32,8 +32,14 @@ class ReplayFuncNode : public SearchStrategyNode {
int st;
/*! \brief `[st, ed)` are the indices of the next batch of candidates. */
int ed;
+ /*! \brief The metadata of the function arguments. */
+ Array<ArgInfo> args_info_{nullptr};
- explicit State(ReplayFuncNode* self) : self(self), st(0), ed(self->num_trials_per_iter) {}
+ explicit State(ReplayFuncNode* self) : self(self), st(0), ed(self->num_trials_per_iter) {
+ const TuneContextNode* ctx = self->context_;
+ ICHECK(ctx);
+ this->args_info_ = ArgInfo::FromPrimFunc(FindEntryFunc(ctx->mod.value()));
+ }
inline Optional<Array<MeasureCandidate>> GenerateMeasureCandidates();
inline void NotifyRunnerResults(const Array<RunnerResult>& results);
@@ -44,14 +50,8 @@ class ReplayFuncNode : public SearchStrategyNode {
/*! \brief The number of total trials. */
int max_trials_per_task;
- /*! \brief The module to be tuned. */
- IRModule mod_{nullptr};
- /*! \brief The metadata of the function arguments. */
- Array<ArgInfo> args_info_{nullptr};
- /*! \brief The post processors */
- Array<Postproc> postprocs_{nullptr};
- /*! \brief The space generator for measure candidates generation. */
- SpaceGenerator space_generator_{nullptr};
+ /*! \brief The tuning context of the search strategy. */
+ const TuneContextNode* context_{nullptr};
/*! \brief The random state. -1 means using random number. */
TRandState rand_state_ = -1;
/*! \brief The state of the search strategy. */
@@ -60,10 +60,7 @@ class ReplayFuncNode : public SearchStrategyNode {
void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("num_trials_per_iter", &num_trials_per_iter);
v->Visit("max_trials_per_task", &max_trials_per_task);
- // `space_generator_` is not visited
- // `mod_` is not visited
- // `args_info_` is not visited
- // `num_threads_` is not visited
+ // `context_` is not visited.
// `rand_state_` is not visited
// `state_` is not visited
}
@@ -72,15 +69,21 @@ class ReplayFuncNode : public SearchStrategyNode {
TVM_DECLARE_FINAL_OBJECT_INFO(ReplayFuncNode, SearchStrategyNode);
void InitializeWithTuneContext(const TuneContext& context) final {
- this->space_generator_ = context->space_generator.value();
- this->mod_ = context->mod.value();
- this->args_info_ = ArgInfo::FromPrimFunc(FindEntryFunc(context->mod.value()));
- this->postprocs_ = context->postprocs;
+ CHECK(context->space_generator.defined())
+ << "ValueError: TuneContext.space_generator is not defined";
+ CHECK(context->mod.defined()) << "ValueError: TuneContext.mod is not defined";
+ this->context_ = context.get();
this->rand_state_ = ForkSeed(&context->rand_state);
this->state_.reset();
}
- void PreTuning(const Array<tir::Schedule>& design_spaces) final {
+ void PreTuning(const Array<tir::Schedule>& design_spaces, const Optional<Database>& database,
+ const Optional<CostModel>& cost_model) final {
+ CHECK(this->context_ != nullptr) << "ValueError: Did you forget to initialize the TuneContext?";
+ if (this->state_ != nullptr) {
+ TVM_PY_LOG(WARNING, this->context_->logging_func) << "ReplayFunc is already initialized.";
+ this->state_.reset();
+ }
ICHECK(this->state_ == nullptr);
this->state_ = std::make_unique<State>(this);
}
@@ -109,21 +112,24 @@ inline Optional<Array<MeasureCandidate>> ReplayFuncNode::State::GenerateMeasureC
}
ed = std::min(ed, self->max_trials_per_task);
Array<MeasureCandidate> result;
+ const TuneContextNode* ctx = self->context_;
+ ICHECK(ctx);
+ IRModule mod = ctx->mod.value();
for (int i = st; i < ed; i++) {
for (;;) {
- Array<tir::Schedule> schs = self->space_generator_->GenerateDesignSpace(self->mod_);
+ Array<tir::Schedule> schs = ctx->space_generator.value()->GenerateDesignSpace(mod);
int design_space_index = tir::SampleInt(&self->rand_state_, 0, schs.size());
tir::Schedule sch = schs[design_space_index];
sch->EnterPostproc();
bool failed = false;
- for (const Postproc& proc : self->postprocs_) {
+ for (const Postproc& proc : ctx->postprocs) {
if (!proc->Apply(sch)) {
failed = true;
break;
}
}
if (!failed) {
- result.push_back(MeasureCandidate(sch, self->args_info_));
+ result.push_back(MeasureCandidate(sch, this->args_info_));
break;
}
}
diff --git a/src/meta_schedule/search_strategy/replay_trace.cc b/src/meta_schedule/search_strategy/replay_trace.cc
index f17c5d6c4e..13f32a744e 100644
--- a/src/meta_schedule/search_strategy/replay_trace.cc
+++ b/src/meta_schedule/search_strategy/replay_trace.cc
@@ -35,8 +35,22 @@ class ReplayTraceNode : public SearchStrategyNode {
/*! \brief `[st, ed)` are the indices of the next batch of candidates. */
int ed;
+ /*! \brief Per-thread copies of the module to be tuned. */
+ Array<IRModule> per_thread_mod_{nullptr};
+ /*! \brief The metadata of the function arguments. */
+ Array<ArgInfo> args_info_{nullptr};
+
explicit State(ReplayTraceNode* self, Array<tir::Trace> design_spaces)
- : self(self), design_spaces(design_spaces), st(0), ed(self->num_trials_per_iter) {}
+ : self(self), design_spaces(design_spaces), st(0), ed(self->num_trials_per_iter) {
+ const TuneContextNode* ctx = self->context_;
+ ICHECK(ctx);
+ IRModule mod = ctx->mod.value();
+ this->per_thread_mod_.reserve(ctx->num_threads);
+ for (int i = 0; i < ctx->num_threads; i++) {
+ this->per_thread_mod_.push_back(DeepCopyIRModule(mod));
+ }
+ this->args_info_ = ArgInfo::FromPrimFunc(FindEntryFunc(mod));
+ }
inline Optional<Array<MeasureCandidate>> GenerateMeasureCandidates();
inline void NotifyRunnerResults(const Array<RunnerResult>& results);
@@ -47,14 +61,8 @@ class ReplayTraceNode : public SearchStrategyNode {
/*! \brief The number of total trials. */
int max_trials_per_task;
- /*! \brief The module to be tuned. */
- Array<IRModule> per_thread_mod_{nullptr};
- /*! \brief The metadata of the function arguments. */
- Array<ArgInfo> args_info_{nullptr};
- /*! \brief The post processors */
- Array<Postproc> postprocs_{nullptr};
- /*! \brief The number of threads to use. -1 means using logical cpu number. */
- int num_threads_ = -1;
+ /*! \brief The tuning context of the search strategy. */
+ const TuneContextNode* context_{nullptr};
/*! \brief The random state. -1 means using random number. */
TRandState rand_state_ = -1;
/*! \brief The state of the search strategy. */
@@ -63,10 +71,7 @@ class ReplayTraceNode : public SearchStrategyNode {
void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("num_trials_per_iter", &num_trials_per_iter);
v->Visit("max_trials_per_task", &max_trials_per_task);
- // `per_thread_mod_` is not visited
- // `args_info_` is not visited
- // `postprocs_` is not visited
- // `num_threads_` is not visited
+ // `context_` is not visited.
// `rand_state_` is not visited
// `state_` is not visited
}
@@ -75,22 +80,20 @@ class ReplayTraceNode : public SearchStrategyNode {
TVM_DECLARE_FINAL_OBJECT_INFO(ReplayTraceNode, SearchStrategyNode);
void InitializeWithTuneContext(const TuneContext& context) final {
- CHECK(context->num_threads > 0) << "Number of threads has to be larger than 0.";
- this->num_threads_ = context->num_threads;
-
- this->per_thread_mod_.reserve(this->num_threads_);
- for (int i = 0; i < this->num_threads_; i++) {
- this->per_thread_mod_.push_back(DeepCopyIRModule(context->mod.value()));
- }
-
- this->args_info_ = ArgInfo::FromPrimFunc(FindEntryFunc(context->mod.value()));
- this->postprocs_ = context->postprocs;
+ CHECK(context->mod.defined()) << "ValueError: TuneContext.mod is not defined";
+ this->context_ = context.get();
this->rand_state_ = ForkSeed(&context->rand_state);
this->state_.reset();
}
- void PreTuning(const Array<tir::Schedule>& design_spaces) final {
+ void PreTuning(const Array<tir::Schedule>& design_spaces, const Optional<Database>& database,
+ const Optional<CostModel>& cost_model) final {
ICHECK(!design_spaces.empty());
+ CHECK(this->context_ != nullptr) << "ValueError: Did you forget to initialize the TuneContext?";
+ if (this->state_ != nullptr) {
+ TVM_PY_LOG(WARNING, this->context_->logging_func) << "ReplayTrace is already initialized.";
+ this->state_.reset();
+ }
ICHECK(this->state_ == nullptr);
Array<tir::Trace> design_space_traces;
design_space_traces.reserve(design_spaces.size());
@@ -124,24 +127,26 @@ inline Optional<Array<MeasureCandidate>> ReplayTraceNode::State::GenerateMeasure
}
ed = std::min(ed, self->max_trials_per_task);
ICHECK_LT(st, ed);
- std::vector<TRandState> per_thread_rand_state = ForkSeed(&self->rand_state_, self->num_threads_);
+ const TuneContextNode* ctx = self->context_;
+ ICHECK(ctx);
+ std::vector<TRandState> per_thread_rand_state = ForkSeed(&self->rand_state_, ctx->num_threads);
Array<MeasureCandidate> per_task_result(ed - st, MeasureCandidate{nullptr});
- ThreadedTraceApply pp(self->postprocs_);
+ ThreadedTraceApply pp(ctx->postprocs);
auto f_worker = [this, &per_thread_rand_state, &per_task_result, &pp](int thread_id,
int task_id) -> void {
TRandState& rand_state = per_thread_rand_state[thread_id];
- IRModule mod = self->per_thread_mod_[thread_id];
+ IRModule mod = this->per_thread_mod_[thread_id];
for (;;) {
int design_space_index = tir::SampleInt(&rand_state, 0, design_spaces.size());
tir::Trace trace = design_spaces[design_space_index];
tir::Trace new_trace = tir::Trace(trace->insts, {});
if (Optional<tir::Schedule> sch = pp.Apply(mod, new_trace, &rand_state)) {
- per_task_result.Set(task_id, MeasureCandidate(sch.value(), self->args_info_));
+ per_task_result.Set(task_id, MeasureCandidate(sch.value(), this->args_info_));
break;
}
}
};
- support::parallel_for_dynamic(0, ed - st, self->num_threads_, f_worker);
+ support::parallel_for_dynamic(0, ed - st, ctx->num_threads, f_worker);
return per_task_result;
}
diff --git a/src/meta_schedule/search_strategy/search_strategy.cc b/src/meta_schedule/search_strategy/search_strategy.cc
index fefe8dfce7..a6a1100ceb 100644
--- a/src/meta_schedule/search_strategy/search_strategy.cc
+++ b/src/meta_schedule/search_strategy/search_strategy.cc
@@ -28,6 +28,13 @@ MeasureCandidate::MeasureCandidate(tir::Schedule sch, Array<ArgInfo> args_info)
data_ = std::move(n);
}
+void PySearchStrategyNode::PreTuning(const Array<tir::Schedule>& design_spaces,
+ const Optional<Database>& database,
+ const Optional<CostModel>& cost_model) {
+ ICHECK(f_pre_tuning != nullptr) << "PySearchStrategy's PreTuning method not implemented!";
+ this->f_pre_tuning(design_spaces, database, cost_model);
+}
+
SearchStrategy SearchStrategy::PySearchStrategy(
PySearchStrategyNode::FInitializeWithTuneContext f_initialize_with_tune_context, //
PySearchStrategyNode::FPreTuning f_pre_tuning, //
diff --git a/src/meta_schedule/task_scheduler/gradient_based.cc b/src/meta_schedule/task_scheduler/gradient_based.cc
index f8cc9d5514..73d191f593 100644
--- a/src/meta_schedule/task_scheduler/gradient_based.cc
+++ b/src/meta_schedule/task_scheduler/gradient_based.cc
@@ -189,10 +189,10 @@ TaskScheduler TaskScheduler::GradientBased(Array<TuneContext> tasks,
Array<FloatImm> task_weights, //
Builder builder, //
Runner runner, //
- Database database, //
- int max_trials, //
+ Optional<Database> database, //
Optional<CostModel> cost_model, //
Optional<Array<MeasureCallback>> measure_callbacks, //
+ int max_trials, //
PackedFunc logging_func, //
double alpha, //
int window_size, //
@@ -227,9 +227,6 @@ TaskScheduler TaskScheduler::GradientBased(Array<TuneContext> tasks,
n->best_time_cost_per_task_ = std::vector<double>(n_tasks, 1e100);
n->num_rounds_already_ = 0;
support::LinearCongruentialEngine(&n->rand_state_).Seed(seed);
- for (const TuneContext& task : tasks) {
- task->task_scheduler = n.get();
- }
return TaskScheduler(n);
}
diff --git a/src/meta_schedule/task_scheduler/round_robin.cc b/src/meta_schedule/task_scheduler/round_robin.cc
index 446b118379..ea22878840 100644
--- a/src/meta_schedule/task_scheduler/round_robin.cc
+++ b/src/meta_schedule/task_scheduler/round_robin.cc
@@ -58,10 +58,10 @@ class RoundRobinNode final : public TaskSchedulerNode {
TaskScheduler TaskScheduler::RoundRobin(Array<TuneContext> tasks, //
Builder builder, //
Runner runner, //
- Database database, //
- int max_trials, //
+ Optional<Database> database, //
Optional<CostModel> cost_model, //
Optional<Array<MeasureCallback>> measure_callbacks, //
+ int max_trials, //
PackedFunc logging_func) {
ObjectPtr<RoundRobinNode> n = make_object<RoundRobinNode>();
n->tasks = tasks;
@@ -74,9 +74,6 @@ TaskScheduler TaskScheduler::RoundRobin(Array<TuneContext> tasks,
n->logging_func = logging_func;
n->num_trials_already = 0;
n->task_id = -1;
- for (const TuneContext& task : tasks) {
- task->task_scheduler = n.get();
- }
return TaskScheduler(n);
}
diff --git a/src/meta_schedule/task_scheduler/task_scheduler.cc b/src/meta_schedule/task_scheduler/task_scheduler.cc
index fd1d95cd1f..25867fb4f3 100644
--- a/src/meta_schedule/task_scheduler/task_scheduler.cc
+++ b/src/meta_schedule/task_scheduler/task_scheduler.cc
@@ -117,7 +117,7 @@ void TaskSchedulerNode::InitializeTask(int task_id) {
<< tir::AsTVMScript(sch->mod()) << "\n"
<< Concat(trace->AsPython(false), "\n");
}
- task->search_strategy.value()->PreTuning(design_spaces);
+ task->search_strategy.value()->PreTuning(design_spaces, database, cost_model);
}
void TaskSchedulerNode::Tune() {
@@ -203,10 +203,10 @@ TaskScheduler TaskScheduler::PyTaskScheduler(
Array<TuneContext> tasks, //
Builder builder, //
Runner runner, //
- Database database, //
- int max_trials, //
+ Optional<Database> database, //
Optional<CostModel> cost_model, //
Optional<Array<MeasureCallback>> measure_callbacks, //
+ int max_trials, //
PackedFunc logging_func, //
PyTaskSchedulerNode::FTune f_tune, //
PyTaskSchedulerNode::FInitializeTask f_initialize_task, //
diff --git a/tests/python/unittest/test_meta_schedule_measure_callback.py b/tests/python/unittest/test_meta_schedule_measure_callback.py
index a1b188930f..298b51e015 100644
--- a/tests/python/unittest/test_meta_schedule_measure_callback.py
+++ b/tests/python/unittest/test_meta_schedule_measure_callback.py
@@ -16,12 +16,10 @@
# under the License.
# pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring
import re
-from random import random
from typing import List
import pytest
import tvm
-from tvm.ir import IRModule, assert_structural_equal
from tvm.meta_schedule.builder import BuilderResult
from tvm.meta_schedule.measure_callback import PyMeasureCallback
from tvm.meta_schedule.runner import RunnerResult
@@ -66,7 +64,7 @@ def test_meta_schedule_measure_callback():
results: List[RunnerResult],
) -> None:
assert len(measure_candidates) == 1
- assert_structural_equal(measure_candidates[0].sch.mod, Matmul)
+ tvm.ir.assert_structural_equal(measure_candidates[0].sch.mod, Matmul)
assert (
len(builds) == 1
and builds[0].error_msg is None
@@ -78,7 +76,14 @@ def test_meta_schedule_measure_callback():
measure_callback = FancyMeasureCallback()
measure_callback.apply(
- RoundRobin([], [], DummyBuilder(), DummyRunner(), DummyDatabase(), max_trials=1),
+ RoundRobin(
+ tasks=[],
+ task_weights=[],
+ builder=DummyBuilder(),
+ runner=DummyRunner(),
+ database=DummyDatabase(),
+ max_trials=1,
+ ),
0,
[MeasureCandidate(Schedule(Matmul), None)],
[BuilderResult("test_build", None)],
@@ -102,7 +107,14 @@ def test_meta_schedule_measure_callback_fail():
measure_callback = FailingMeasureCallback()
with pytest.raises(ValueError, match="test"):
measure_callback.apply(
- RoundRobin([], [], DummyBuilder(), DummyRunner(), DummyDatabase(), max_trials=1),
+ RoundRobin(
+ tasks=[],
+ task_weights=[],
+ builder=DummyBuilder(),
+ runner=DummyRunner(),
+ database=DummyDatabase(),
+ max_trials=1,
+ ),
0,
[MeasureCandidate(Schedule(Matmul), None)],
[BuilderResult("test_build", None)],
diff --git a/tests/python/unittest/test_meta_schedule_search_strategy.py b/tests/python/unittest/test_meta_schedule_search_strategy.py
index 94042dd753..4eb8aac5a3 100644
--- a/tests/python/unittest/test_meta_schedule_search_strategy.py
+++ b/tests/python/unittest/test_meta_schedule_search_strategy.py
@@ -123,43 +123,37 @@ def test_meta_schedule_evolutionary_search(): # pylint: disable = invalid-name]
num_trials_per_iter = 10
max_trials_per_task = 2000
+ (correct_sch,) = ScheduleFn(sch_fn=_schedule_matmul).generate_design_space(Matmul)
- strategy = EvolutionarySearch(
- num_trials_per_iter=num_trials_per_iter,
- max_trials_per_task=max_trials_per_task,
- population_size=5,
- init_measured_ratio=0.1,
- init_min_unmeasured=50,
- genetic_num_iters=3,
- genetic_mutate_prob=0.5,
- genetic_max_fail_count=10,
- eps_greedy=0.9,
- )
context = TuneContext(
mod=Matmul,
- space_generator=ScheduleFn(sch_fn=_schedule_matmul_small),
+ space_generator=ScheduleFn(
+ sch_fn=_schedule_matmul_small,
+ ),
+ search_strategy=EvolutionarySearch(
+ num_trials_per_iter=num_trials_per_iter,
+ max_trials_per_task=max_trials_per_task,
+ population_size=5,
+ init_measured_ratio=0.1,
+ init_min_unmeasured=50,
+ genetic_num_iters=3,
+ genetic_mutate_prob=0.5,
+ genetic_max_fail_count=10,
+ eps_greedy=0.9,
+ ),
mutator_probs={
DummyMutator(): 1.0,
},
target=tvm.target.Target("llvm"),
num_threads=1, # because we are using a mutator from the python side
)
- _scheduler = RoundRobin(
- tasks=[context],
- task_weights=[1.0],
- builder=ms.builder.LocalBuilder(),
- runner=ms.runner.LocalRunner(),
+ context.initialize()
+ strategy = context.search_strategy
+ strategy.pre_tuning(
+ context.space_generator.generate_design_space(context.mod),
database=DummyDatabase(),
cost_model=ms.cost_model.RandomModel(),
- measure_callbacks=[],
- max_trials=1,
)
- context.space_generator.initialize_with_tune_context(context)
- spaces = context.space_generator.generate_design_space(context.mod)
-
- strategy.initialize_with_tune_context(context)
- strategy.pre_tuning(spaces)
- (correct_sch,) = ScheduleFn(sch_fn=_schedule_matmul).generate_design_space(Matmul)
num_trials_each_iter: List[int] = []
candidates = strategy.generate_measure_candidates()
while candidates is not None:
@@ -177,52 +171,46 @@ def test_meta_schedule_evolutionary_search(): # pylint: disable = invalid-name]
strategy.post_tuning()
assert sum(num_trials_each_iter) == 25
assert num_trials_each_iter.count(0) < 5
- del _scheduler
def test_meta_schedule_evolutionary_search_early_stop(): # pylint: disable = invalid-name]
def _schedule_matmul_empty(sch: Schedule):
return sch
+ (correct_sch,) = ScheduleFn(sch_fn=_schedule_matmul).generate_design_space(Matmul)
+
num_trials_per_iter = 10
max_trials_per_task = 100
- strategy = EvolutionarySearch(
- num_trials_per_iter=num_trials_per_iter,
- max_trials_per_task=max_trials_per_task,
- population_size=5,
- init_measured_ratio=0.1,
- init_min_unmeasured=50,
- genetic_num_iters=3,
- genetic_mutate_prob=0.5,
- genetic_max_fail_count=10,
- eps_greedy=0.9,
- )
context = TuneContext(
mod=Matmul,
- space_generator=ScheduleFn(sch_fn=_schedule_matmul_empty),
+ search_strategy=EvolutionarySearch(
+ num_trials_per_iter=num_trials_per_iter,
+ max_trials_per_task=max_trials_per_task,
+ population_size=5,
+ init_measured_ratio=0.1,
+ init_min_unmeasured=50,
+ genetic_num_iters=3,
+ genetic_mutate_prob=0.5,
+ genetic_max_fail_count=10,
+ eps_greedy=0.9,
+ ),
+ space_generator=ScheduleFn(
+ sch_fn=_schedule_matmul_empty,
+ ),
mutator_probs={
DummyMutator(): 1.0,
},
target=tvm.target.Target("llvm"),
- num_threads=1, # because we are using a mutator from the python side
+ num_threads=1,
)
- _scheduler = RoundRobin(
- tasks=[context],
- task_weights=[1.0],
- builder=ms.builder.LocalBuilder(),
- runner=ms.runner.LocalRunner(),
+ context.initialize()
+ strategy = context.search_strategy
+ strategy.pre_tuning(
+ context.space_generator.generate_design_space(context.mod),
database=DummyDatabase(),
cost_model=ms.cost_model.RandomModel(),
- measure_callbacks=[],
- max_trials=1,
)
- context.space_generator.initialize_with_tune_context(context)
- spaces = context.space_generator.generate_design_space(context.mod)
-
- strategy.initialize_with_tune_context(context)
- strategy.pre_tuning(spaces)
- (correct_sch,) = ScheduleFn(sch_fn=_schedule_matmul).generate_design_space(Matmul)
num_trials_each_iter: List[int] = []
candidates = strategy.generate_measure_candidates()
while candidates is not None:
@@ -239,7 +227,6 @@ def test_meta_schedule_evolutionary_search_early_stop(): # pylint: disable = in
candidates = strategy.generate_measure_candidates()
strategy.post_tuning()
assert num_trials_each_iter == [1, 0, 0, 0, 0]
- del _scheduler
if __name__ == "__main__":
diff --git a/tests/python/unittest/test_meta_schedule_task_scheduler.py b/tests/python/unittest/test_meta_schedule_task_scheduler.py
index 025bbe4225..f24dc5fbbc 100644
--- a/tests/python/unittest/test_meta_schedule_task_scheduler.py
+++ b/tests/python/unittest/test_meta_schedule_task_scheduler.py
@@ -17,7 +17,6 @@
""" Test Meta Schedule Task Scheduler """
import random
-import sys
import weakref
from typing import Set
@@ -108,7 +107,6 @@ class BatchMatmulModule:
def _schedule_matmul(sch: Schedule):
block = sch.get_block("matmul")
i, j, k = sch.get_loops(block=block)
- # TODO(@zxybazh): Change to `sample_perfect_tile` after upstreaming
i_0, i_1, i_2, i_3 = sch.split(loop=i, factors=[2, 4, 64, 2])
j_0, j_1, j_2, j_3 = sch.split(loop=j, factors=[4, 64, 2, 2])
k_0, k_1 = sch.split(loop=k, factors=[32, 32])
@@ -118,7 +116,6 @@ def _schedule_matmul(sch: Schedule):
def _schedule_batch_matmul(sch: Schedule):
block = sch.get_block("matmul")
i, j, k, t = sch.get_loops(block=block)
- # TODO(@zxybazh): Change to `sample_perfect_tile` after upstreaming
i_0, i_1, i_2, i_3 = sch.split(loop=i, factors=[2, 2, 2, 2])
j_0, j_1, j_2, j_3 = sch.split(loop=j, factors=[2, 4, 64, 2])
k_0, k_1 = sch.split(loop=k, factors=[32, 32])
@@ -156,23 +153,22 @@ class MyTaskScheduler(PyTaskScheduler):
def test_meta_schedule_task_scheduler_single():
num_trials_per_iter = 3
max_trials_per_task = 10
- sch_fn = ScheduleFn(sch_fn=_schedule_matmul)
- replay = ReplayTrace(num_trials_per_iter, max_trials_per_task)
- task = TuneContext(
- MatmulModule,
- target=tvm.target.Target("llvm"),
- space_generator=sch_fn,
- search_strategy=replay,
- task_name="Test",
- rand_state=42,
- )
database = DummyDatabase()
round_robin = RoundRobin(
- [task],
+ [
+ TuneContext(
+ MatmulModule,
+ target=tvm.target.Target("llvm"),
+ space_generator=ScheduleFn(sch_fn=_schedule_matmul),
+ search_strategy=ReplayTrace(num_trials_per_iter, max_trials_per_task),
+ task_name="Test",
+ rand_state=42,
+ )
+ ],
[1.0],
- DummyBuilder(),
- DummyRunner(),
- database,
+ builder=DummyBuilder(),
+ runner=DummyRunner(),
+ database=database,
measure_callbacks=[measure_callback.AddToDatabase()],
max_trials=max_trials_per_task,
)
@@ -212,10 +208,10 @@ def test_meta_schedule_task_scheduler_multiple():
database = DummyDatabase()
round_robin = RoundRobin(
tasks,
- [1.0],
- DummyBuilder(),
- DummyRunner(),
- database,
+ [1.0, 1.0, 1.0],
+ builder=DummyBuilder(),
+ runner=DummyRunner(),
+ database=database,
measure_callbacks=[measure_callback.AddToDatabase()],
max_trials=max_trials_per_task * len(tasks),
)
@@ -239,18 +235,23 @@ def test_meta_schedule_task_scheduler_NIE(): # pylint: disable=invalid-name
pass
with pytest.raises(TVMError, match="PyTaskScheduler's NextTaskId method not implemented!"):
- scheduler = NIETaskScheduler([], DummyBuilder(), DummyRunner(), DummyDatabase(), 1)
+ scheduler = NIETaskScheduler(
+ tasks=[],
+ builder=DummyBuilder(),
+ runner=DummyRunner(),
+ database=DummyDatabase(),
+ max_trials=1,
+ )
scheduler.next_task_id()
def test_meta_schedule_task_scheduler_avoid_cyclic(): # pylint: disable=invalid-name
-
database = DummyDatabase()
scheduler = MyTaskScheduler(
[],
- DummyBuilder(),
- DummyRunner(),
- database,
+ builder=DummyBuilder(),
+ runner=DummyRunner(),
+ database=database,
measure_callbacks=[
measure_callback.AddToDatabase(),
],
@@ -262,7 +263,6 @@ def test_meta_schedule_task_scheduler_avoid_cyclic(): # pylint: disable=invalid
def test_meta_schedule_task_scheduler_override_next_task_id_only(): # pylint: disable=invalid-name
-
num_trials_per_iter = 6
max_trials_per_task = 101
tasks = [
@@ -294,9 +294,9 @@ def test_meta_schedule_task_scheduler_override_next_task_id_only(): # pylint: d
database = DummyDatabase()
scheduler = MyTaskScheduler(
tasks,
- DummyBuilder(),
- DummyRunner(),
- database,
+ builder=DummyBuilder(),
+ runner=DummyRunner(),
+ database=database,
measure_callbacks=[
measure_callback.AddToDatabase(),
],