Posted to commits@tvm.apache.org by ju...@apache.org on 2022/06/07 03:02:23 UTC

[tvm] branch main updated: [MetaSchedule] Evo Independence from TaskScheduler (#11590)

This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 68dcecc926 [MetaSchedule] Evo Independence from TaskScheduler (#11590)
68dcecc926 is described below

commit 68dcecc926f890429a8f2cba9ce55eab6a18fa6e
Author: Junru Shao <ju...@gmail.com>
AuthorDate: Mon Jun 6 20:02:18 2022 -0700

    [MetaSchedule] Evo Independence from TaskScheduler (#11590)
    
    Per discussion with @Kathryn-cat, we realized that the current API
    design is unnecessarily verbose when tuning a single task: a dummy task
    scheduler still has to be constructed solely to supply
    `EvolutionarySearch` with a proper `CostModel` and `Database`. This PR
    fixes this UX issue.
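    
    With this change, a single task can be tuned by handing the database
    and cost model directly to the search strategy. Below is a minimal
    sketch of the new flow, adapted from the unit tests updated in this
    commit; `Matmul` and `_schedule_matmul` are placeholders for a
    user-provided module and schedule function, the import paths follow
    the test files and may differ elsewhere, and the tests additionally
    configure `mutator_probs`, which is omitted here:
    
        import tvm
        from tvm import meta_schedule as ms
        from tvm.meta_schedule import TuneContext
        from tvm.meta_schedule.search_strategy import EvolutionarySearch
        from tvm.meta_schedule.space_generator import ScheduleFn
        from tvm.meta_schedule.testing import DummyDatabase
    
        context = TuneContext(
            mod=Matmul,  # the IRModule to tune (placeholder)
            space_generator=ScheduleFn(sch_fn=_schedule_matmul),
            search_strategy=EvolutionarySearch(
                num_trials_per_iter=10,
                max_trials_per_task=100,
                population_size=5,
                init_measured_ratio=0.1,
                init_min_unmeasured=50,
                genetic_num_iters=3,
                genetic_mutate_prob=0.5,
                genetic_max_fail_count=10,
                eps_greedy=0.9,
            ),
            target=tvm.target.Target("llvm"),
            num_threads=1,
        )
        context.initialize()
        strategy = context.search_strategy
        # The database and cost model now go straight to `pre_tuning`;
        # no dummy TaskScheduler is needed anymore.
        strategy.pre_tuning(
            context.space_generator.generate_design_space(context.mod),
            database=DummyDatabase(),
            cost_model=ms.cost_model.RandomModel(),
        )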
---
 include/tvm/meta_schedule/search_strategy.h        |  17 ++-
 include/tvm/meta_schedule/task_scheduler.h         |  20 +--
 include/tvm/meta_schedule/tune_context.h           |   2 -
 .../search_strategy/search_strategy.py             |  24 +++-
 .../meta_schedule/task_scheduler/gradient_based.py |  10 +-
 .../meta_schedule/task_scheduler/round_robin.py    |  10 +-
 .../meta_schedule/task_scheduler/task_scheduler.py |  10 +-
 .../measure_callback/add_to_database.cc            |   5 +-
 .../search_strategy/evolutionary_search.cc         | 148 +++++++++++----------
 src/meta_schedule/search_strategy/replay_func.cc   |  48 ++++---
 src/meta_schedule/search_strategy/replay_trace.cc  |  63 +++++----
 .../search_strategy/search_strategy.cc             |   7 +
 src/meta_schedule/task_scheduler/gradient_based.cc |   7 +-
 src/meta_schedule/task_scheduler/round_robin.cc    |   7 +-
 src/meta_schedule/task_scheduler/task_scheduler.cc |   6 +-
 .../test_meta_schedule_measure_callback.py         |  22 ++-
 .../unittest/test_meta_schedule_search_strategy.py |  93 ++++++-------
 .../unittest/test_meta_schedule_task_scheduler.py  |  60 ++++-----
 18 files changed, 298 insertions(+), 261 deletions(-)
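
Note for downstream users: the task-scheduler constructors now take
`database`, `cost_model`, `measure_callbacks`, and `max_trials` as
keyword-only arguments, with `max_trials` moved after `measure_callbacks`.
A sketch of the updated `RoundRobin` call, following the rewritten unit
tests below (`DummyBuilder`, `DummyRunner`, and `DummyDatabase` are the
test doubles defined in or imported by those tests, and `task` is a
previously constructed `TuneContext`):

    from tvm.meta_schedule import measure_callback
    from tvm.meta_schedule.task_scheduler import RoundRobin

    round_robin = RoundRobin(
        [task],  # tasks: List[TuneContext]
        [1.0],   # task_weights, one per task
        builder=DummyBuilder(),    # test double
        runner=DummyRunner(),      # test double
        database=DummyDatabase(),  # in-memory test database
        measure_callbacks=[measure_callback.AddToDatabase()],
        max_trials=10,
    )
    round_robin.tune()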

diff --git a/include/tvm/meta_schedule/search_strategy.h b/include/tvm/meta_schedule/search_strategy.h
index 6895673a04..139de7c99d 100644
--- a/include/tvm/meta_schedule/search_strategy.h
+++ b/include/tvm/meta_schedule/search_strategy.h
@@ -113,12 +113,16 @@ class SearchStrategyNode : public runtime::Object {
 
   /*!
    * \brief Pre-tuning for the search strategy.
-   * \param design_spaces The design spaces for pre-tuning.
+   * \param design_spaces The design spaces used during the tuning process.
+   * \param database The database used during the tuning process.
+   * \param cost_model The cost model used during the tuning process.
    * \note Pre-tuning is supposed to be called before the tuning process and after the
    *  initialization. Because the search strategy is stateful, we can always call pretuning
    *  and reset the search strategy.
    */
-  virtual void PreTuning(const Array<tir::Schedule>& design_spaces) = 0;
+  virtual void PreTuning(const Array<tir::Schedule>& design_spaces,
+                         const Optional<Database>& database,
+                         const Optional<CostModel>& cost_model) = 0;
 
   /*!
    * \brief Post-tuning for the search strategy.
@@ -159,7 +163,8 @@ class PySearchStrategyNode : public SearchStrategyNode {
    * \brief The function type of `PreTuning` method.
   * \param design_spaces The design spaces for pre-tuning.
   * \param database The database used during the tuning process.
   * \param cost_model The cost model used during the tuning process.
    */
-  using FPreTuning = runtime::TypedPackedFunc<void(const Array<tir::Schedule>&)>;
+  using FPreTuning = runtime::TypedPackedFunc<void(
+      const Array<tir::Schedule>&, const Optional<Database>&, const Optional<CostModel>&)>;
   /*! \brief The function type of `PostTuning` method. */
   using FPostTuning = runtime::TypedPackedFunc<void()>;
   /*!
@@ -199,10 +204,8 @@ class PySearchStrategyNode : public SearchStrategyNode {
     this->f_initialize_with_tune_context(context);
   }
 
-  void PreTuning(const Array<tir::Schedule>& design_spaces) final {
-    ICHECK(f_pre_tuning != nullptr) << "PySearchStrategy's PreTuning method not implemented!";
-    this->f_pre_tuning(design_spaces);
-  }
+  void PreTuning(const Array<tir::Schedule>& design_spaces, const Optional<Database>& database,
+                 const Optional<CostModel>& cost_model) final;
 
   void PostTuning() final {
     ICHECK(f_post_tuning != nullptr) << "PySearchStrategy's PostTuning method not implemented!";
diff --git a/include/tvm/meta_schedule/task_scheduler.h b/include/tvm/meta_schedule/task_scheduler.h
index 7453c2b484..5953a2c3e4 100644
--- a/include/tvm/meta_schedule/task_scheduler.h
+++ b/include/tvm/meta_schedule/task_scheduler.h
@@ -74,13 +74,13 @@ class TaskSchedulerNode : public runtime::Object {
   /*! \brief The runner of the scheduler. */
   Runner runner{nullptr};
   /*! \brief The database of the scheduler. */
-  Database database{nullptr};
-  /*! \brief The maximum number of trials allowed. */
-  int max_trials;
+  Optional<Database> database;
   /*! \brief The cost model of the scheduler. */
   Optional<CostModel> cost_model;
   /*! \brief The list of measure callbacks of the scheduler. */
   Array<MeasureCallback> measure_callbacks;
+  /*! \brief The maximum number of trials allowed. */
+  int max_trials;
   /*! \brief The number of trials already conducted. */
   int num_trials_already;
   /*! \brief The tuning task's logging function. */
@@ -94,9 +94,9 @@ class TaskSchedulerNode : public runtime::Object {
     v->Visit("builder", &builder);
     v->Visit("runner", &runner);
     v->Visit("database", &database);
-    v->Visit("max_trials", &max_trials);
     v->Visit("cost_model", &cost_model);
     v->Visit("measure_callbacks", &measure_callbacks);
+    v->Visit("max_trials", &max_trials);
     v->Visit("num_trials_already", &num_trials_already);
     // `logging_func` is not visited
   }
@@ -243,10 +243,10 @@ class TaskScheduler : public runtime::ObjectRef {
   TVM_DLL static TaskScheduler RoundRobin(Array<TuneContext> tasks,                            //
                                           Builder builder,                                     //
                                           Runner runner,                                       //
-                                          Database database,                                   //
-                                          int max_trials,                                      //
+                                          Optional<Database> database,                         //
                                           Optional<CostModel> cost_model,                      //
                                           Optional<Array<MeasureCallback>> measure_callbacks,  //
+                                          int max_trials,                                      //
                                           PackedFunc logging_func);
   /*!
    * \brief Create a task scheduler that fetches tasks in a gradient based fashion.
@@ -268,10 +268,10 @@ class TaskScheduler : public runtime::ObjectRef {
                                              Array<FloatImm> task_weights,                        //
                                              Builder builder,                                     //
                                              Runner runner,                                       //
-                                             Database database,                                   //
-                                             int max_trials,                                      //
+                                             Optional<Database> database,                         //
                                              Optional<CostModel> cost_model,                      //
                                              Optional<Array<MeasureCallback>> measure_callbacks,  //
+                                             int max_trials,                                      //
                                              PackedFunc logging_func,                             //
                                              double alpha,                                        //
                                              int window_size,                                     //
@@ -297,10 +297,10 @@ class TaskScheduler : public runtime::ObjectRef {
       Array<TuneContext> tasks,                                   //
       Builder builder,                                            //
       Runner runner,                                              //
-      Database database,                                          //
-      int max_trials,                                             //
+      Optional<Database> database,                                //
       Optional<CostModel> cost_model,                             //
       Optional<Array<MeasureCallback>> measure_callbacks,         //
+      int max_trials,                                             //
       PackedFunc logging_func,                                    //
       PyTaskSchedulerNode::FTune f_tune,                          //
       PyTaskSchedulerNode::FInitializeTask f_initialize_task,     //
diff --git a/include/tvm/meta_schedule/tune_context.h b/include/tvm/meta_schedule/tune_context.h
index faa24fc99f..d63fb819f3 100644
--- a/include/tvm/meta_schedule/tune_context.h
+++ b/include/tvm/meta_schedule/tune_context.h
@@ -61,8 +61,6 @@ class TuneContextNode : public runtime::Object {
   /*! \brief The number of threads to be used. */
   int num_threads;
 
-  /*! \brief The task scheduler that owns the tune context */
-  const TaskSchedulerNode* task_scheduler;
   /*! \brief Whether the tuning task has been stopped or finished. */
   bool is_terminated;
   /*! \brief The measure candidates. */
diff --git a/python/tvm/meta_schedule/search_strategy/search_strategy.py b/python/tvm/meta_schedule/search_strategy/search_strategy.py
index 07c47f01d1..14b46a0785 100644
--- a/python/tvm/meta_schedule/search_strategy/search_strategy.py
+++ b/python/tvm/meta_schedule/search_strategy/search_strategy.py
@@ -18,7 +18,7 @@
 Meta Schedule search strategy that generates the measure
 candidates for measurement.
 """
-from typing import Callable, List, Optional, TYPE_CHECKING
+from typing import TYPE_CHECKING, Callable, List, Optional
 
 from tvm._ffi import register_object
 from tvm.runtime import Object
@@ -29,6 +29,8 @@ from ..arg_info import ArgInfo
 from ..runner import RunnerResult
 
 if TYPE_CHECKING:
+    from ..cost_model import CostModel
+    from ..database import Database
     from ..tune_context import TuneContext
 
 
@@ -87,15 +89,29 @@ class SearchStrategy(Object):
             self, context
         )
 
-    def pre_tuning(self, design_spaces: List[Schedule]) -> None:
+    def pre_tuning(
+        self,
+        design_spaces: List[Schedule],
+        database: Optional["Database"] = None,
+        cost_model: Optional["CostModel"] = None,
+    ) -> None:
         """Pre-tuning for the search strategy.
 
         Parameters
         ----------
         design_spaces : List[Schedule]
-            The design spaces for pre-tuning.
+            The design spaces used during the tuning process.
+        database : Optional[Database] = None
+            The database used during the tuning process.
+        cost_model : Optional[CostModel] = None
+            The cost model used during the tuning process.
         """
-        _ffi_api.SearchStrategyPreTuning(self, design_spaces)  # type: ignore # pylint: disable=no-member
+        _ffi_api.SearchStrategyPreTuning(  # type: ignore # pylint: disable=no-member
+            self,
+            design_spaces,
+            database,
+            cost_model,
+        )
 
     def post_tuning(self) -> None:
         """Post-tuning for the search strategy."""
diff --git a/python/tvm/meta_schedule/task_scheduler/gradient_based.py b/python/tvm/meta_schedule/task_scheduler/gradient_based.py
index 6234449bf0..20d32dd1c5 100644
--- a/python/tvm/meta_schedule/task_scheduler/gradient_based.py
+++ b/python/tvm/meta_schedule/task_scheduler/gradient_based.py
@@ -45,11 +45,11 @@ class GradientBased(TaskScheduler):
         task_weights: List[float],
         builder: Builder,
         runner: Runner,
-        database: Database,
-        max_trials: int,
         *,
+        database: Database,
         cost_model: Optional[CostModel] = None,
         measure_callbacks: Optional[List[MeasureCallback]] = None,
+        max_trials: int,
         alpha: float = 0.2,
         window_size: int = 3,
         seed: int = -1,
@@ -68,12 +68,12 @@ class GradientBased(TaskScheduler):
             The runner.
         database : Database
             The database.
-        max_trials : int
-            The maximum number of trials to run.
         cost_model : CostModel, default None.
             The cost model of the scheduler.
         measure_callbacks : Optional[List[MeasureCallback]] = None
             The list of measure callbacks of the scheduler.
+        max_trials : int
+            The maximum number of trials to run.
         alpha : float = 0.2
             The parameter alpha in gradient computation.
         window_size : int = 3
@@ -88,9 +88,9 @@ class GradientBased(TaskScheduler):
             builder,
             runner,
             database,
-            max_trials,
             cost_model,
             measure_callbacks,
+            max_trials,
             make_logging_func(logger),
             alpha,
             window_size,
diff --git a/python/tvm/meta_schedule/task_scheduler/round_robin.py b/python/tvm/meta_schedule/task_scheduler/round_robin.py
index a461358283..ed395643bb 100644
--- a/python/tvm/meta_schedule/task_scheduler/round_robin.py
+++ b/python/tvm/meta_schedule/task_scheduler/round_robin.py
@@ -60,11 +60,11 @@ class RoundRobin(TaskScheduler):
         task_weights: List[float],
         builder: Builder,
         runner: Runner,
-        database: Database,
-        max_trials: int,
         *,
+        database: Database,
         cost_model: Optional[CostModel] = None,
         measure_callbacks: Optional[List[MeasureCallback]] = None,
+        max_trials: int,
     ) -> None:
         """Constructor.
 
@@ -80,12 +80,12 @@ class RoundRobin(TaskScheduler):
             The runner.
         database : Database
             The database.
-        max_trials : int
-            The maximum number of trials.
         cost_model : Optional[CostModel]
             The cost model.
         measure_callbacks: Optional[List[MeasureCallback]]
             The list of measure callbacks of the scheduler.
+        max_trials : int
+            The maximum number of trials.
         """
         del task_weights
         self.__init_handle_by_constructor__(
@@ -94,8 +94,8 @@ class RoundRobin(TaskScheduler):
             builder,
             runner,
             database,
-            max_trials,
             cost_model,
             measure_callbacks,
+            max_trials,
             make_logging_func(logger),
         )
diff --git a/python/tvm/meta_schedule/task_scheduler/task_scheduler.py b/python/tvm/meta_schedule/task_scheduler/task_scheduler.py
index 4454078a6f..3d57a6b01b 100644
--- a/python/tvm/meta_schedule/task_scheduler/task_scheduler.py
+++ b/python/tvm/meta_schedule/task_scheduler/task_scheduler.py
@@ -31,7 +31,6 @@ from ..runner import Runner, RunnerResult
 from ..tune_context import TuneContext
 from ..utils import make_logging_func
 
-
 logger = logging.getLogger(__name__)  # pylint: disable=invalid-name
 
 
@@ -177,9 +176,9 @@ class PyTaskScheduler:
             "builder",
             "runner",
             "database",
-            "max_trials",
             "cost_model",
             "measure_callbacks",
+            "max_trials",
         ],
         "methods": [
             "tune",
@@ -195,18 +194,19 @@ class PyTaskScheduler:
         tasks: List[TuneContext],
         builder: Builder,
         runner: Runner,
-        database: Database,
-        max_trials: int,
+        *,
+        database: Optional[Database] = None,
         cost_model: Optional[CostModel] = None,
         measure_callbacks: Optional[List[MeasureCallback]] = None,
+        max_trials: int,
     ):
         self.tasks = tasks
         self.builder = builder
         self.runner = runner
         self.database = database
-        self.max_trials = max_trials
         self.cost_model = cost_model
         self.measure_callbacks = measure_callbacks
+        self.max_trials = max_trials
 
     def tune(self) -> None:
         """Auto-tuning."""
diff --git a/src/meta_schedule/measure_callback/add_to_database.cc b/src/meta_schedule/measure_callback/add_to_database.cc
index 20581f4630..0988da0414 100644
--- a/src/meta_schedule/measure_callback/add_to_database.cc
+++ b/src/meta_schedule/measure_callback/add_to_database.cc
@@ -27,8 +27,11 @@ class AddToDatabaseNode : public MeasureCallbackNode {
              const Array<MeasureCandidate>& measure_candidates,
              const Array<BuilderResult>& builder_results,
              const Array<RunnerResult>& runner_results) final {
+    if (!task_scheduler->database.defined()) {
+      return;
+    }
     TuneContext task = task_scheduler->tasks[task_id];
-    Database database = task_scheduler->database;
+    Database database = task_scheduler->database.value();
     Workload workload = database->CommitWorkload(task->mod.value());
     Target target = task->target.value();
     ICHECK_EQ(runner_results.size(), measure_candidates.size());
diff --git a/src/meta_schedule/search_strategy/evolutionary_search.cc b/src/meta_schedule/search_strategy/evolutionary_search.cc
index bdef26ef87..8b36a95217 100644
--- a/src/meta_schedule/search_strategy/evolutionary_search.cc
+++ b/src/meta_schedule/search_strategy/evolutionary_search.cc
@@ -246,13 +246,41 @@ class EvolutionarySearchNode : public SearchStrategyNode {
     int ed;
     /*! \brief The counter of returning empty results. */
     int num_empty_iters;
-
-    explicit State(EvolutionarySearchNode* self, Array<tir::Trace> design_spaces)
+    /*! \brief The metadata of the function arguments. */
+    Array<ArgInfo> args_info_{nullptr};
+    /*! \brief Per-thread data, including the module to be tuned and the random state. */
+    std::vector<PerThreadData> per_thread_data_;
+    /*!
+     * \brief The workloads that are already measured.
+     * TODO(junrushao1994): add records from the database to avoid re-measuring.
+     */
+    IRModuleSet measured_workloads_;
+    /*! \brief A Database for selecting useful candidates. */
+    Database database_{nullptr};
+    /*! \brief A cost model helping to explore the search space. */
+    CostModel cost_model_{nullptr};
+    /*! \brief The token registered for the given workload in the database. */
+    Workload token_{nullptr};
+
+    explicit State(EvolutionarySearchNode* self, Array<tir::Trace> design_spaces, Database database,
+                   CostModel cost_model)
         : self(self),
           design_spaces(design_spaces),
           st(0),
           ed(self->num_trials_per_iter),
-          num_empty_iters(0) {}
+          num_empty_iters(0) {
+      const TuneContextNode* ctx = self->context_;
+      IRModule mod = ctx->mod.value();
+      this->args_info_ = ArgInfo::FromPrimFunc(FindEntryFunc(mod));
+      this->per_thread_data_.resize(ctx->num_threads);
+      for (PerThreadData& data : this->per_thread_data_) {
+        data.mod = DeepCopyIRModule(mod);
+        data.rand_state = ForkSeed(&self->rand_state_);
+      }
+      this->database_ = database;
+      this->cost_model_ = cost_model;
+      this->token_ = database->CommitWorkload(mod);
+    }
 
     /*!
      * \brief Pick up best candidates from database.
@@ -293,33 +321,10 @@ class EvolutionarySearchNode : public SearchStrategyNode {
 
   /*! \brief The tuning context of the evolutionary search strategy. */
   const TuneContextNode* context_{nullptr};
-  /*! \brief The target for the workload. */
-  Target target_{nullptr};
-  /*! \brief The metadata of the function arguments. */
-  Array<ArgInfo> args_info_{nullptr};
-  /*! \brief A Database for selecting useful candidates. */
-  Database database_{nullptr};
-  /*! \brief A cost model helping to explore the search space */
-  CostModel cost_model_{nullptr};
-  /*! \brief The postprocessors. */
-  Array<Postproc> postprocs_{nullptr};
-  /*! \brief Mutators and their probability mass */
-  Map<Mutator, FloatImm> mutator_probs_{nullptr};
-  /*! \brief The number of threads to use. To be initialized with TuneContext. */
-  int num_threads_;
   /*! \brief The random state. To be initialized with TuneContext. */
   TRandState rand_state_;
-  /*! \brief Pre thread data including module to be tuned and random state. */
-  std::vector<PerThreadData> per_thread_data_;
   /*! \brief The state of the search strategy. */
   std::unique_ptr<State> state_ = nullptr;
-  /*! \brief The token registered for the given workload in database. */
-  Workload token_{nullptr};
-  /*!
-   * \brief The workloads that are already measured.
-   * TODO(junrushao1994): add records from the database to avoid re-measuring.
-   * */
-  IRModuleSet measured_workloads_;
 
   /*** Configuration: global ***/
   /*! \brief The number of trials per iteration. */
@@ -351,15 +356,7 @@ class EvolutionarySearchNode : public SearchStrategyNode {
 
   void VisitAttrs(tvm::AttrVisitor* v) {
     // `context_` is not visited
-    // `target_` is not visited
-    // `args_info_` is not visited
-    // `database` is not visited
-    // `cost_model` is not visited
-    // `postprocs` is not visited
-    // `mutator_probs_` is not visited
-    // `num_threads` is not visited
     // `rand_state_` is not visited
-    // `per_thread_data_` is not visited
     // `state_` is not visited
 
     /*** Configuration: global ***/
@@ -386,39 +383,41 @@ class EvolutionarySearchNode : public SearchStrategyNode {
     CHECK(context->num_threads > 0) << "Number of threads has to be larger than 0.";
     CHECK(context->target.defined()) << "Target must be defined!";
     this->context_ = context.get();
-    this->target_ = context->target.value();
-    this->args_info_ = ArgInfo::FromPrimFunc(FindEntryFunc(context->mod.value()));
-    this->mutator_probs_ = context->mutator_probs;
-    this->postprocs_ = context->postprocs;
-    this->num_threads_ = context->num_threads;
     this->rand_state_ = ForkSeed(&context->rand_state);
-    CHECK(context->task_scheduler != nullptr)
-        << "ValueError: TaskScheduler is not defined in TuneContext";
-    this->cost_model_ = context->task_scheduler->cost_model.value();
-    this->database_ = context->task_scheduler->database;
-    this->token_ = this->database_->CommitWorkload(context->mod.value());
-    this->per_thread_data_.resize(this->num_threads_);
-    for (const auto& kv : this->mutator_probs_) {
+    for (const auto& kv : context->mutator_probs) {
       double mass = kv.second->value;
       TVM_META_SCHEDULE_CHECK_PROB_RANGE(mass, "mutator_probs");
     }
-    for (PerThreadData& data : this->per_thread_data_) {
-      data.mod = DeepCopyIRModule(context->mod.value());
-      data.rand_state = ForkSeed(&this->rand_state_);
-    }
     this->state_.reset();
   }
 
-  void PreTuning(const Array<Schedule>& design_spaces) final {
+  void PreTuning(const Array<Schedule>& design_spaces, const Optional<Database>& database,
+                 const Optional<CostModel>& cost_model) final {
     ICHECK(!design_spaces.empty());
+    CHECK(this->context_ != nullptr) << "ValueError: Did you forget to initialize the TuneContext?";
+    CHECK(database.defined())
+        << "ValueError: Database is not supplied in PreTuning. The evolutionary "
+           "search algorithm requires a database so that it can sample from the "
+           "previously explored population. If you do not intend to store data on "
+           "disk, please use `tvm.meta_schedule.testing.DummyDatabase`";
+    CHECK(cost_model.defined())
+        << "ValueError: CostModel is not supplied in PreTuning. The evolutionary "
+           "search algorithm expects a cost model to filter out potentially less "
+           "efficient kernels. If you do not expect a cost model to help, please "
+           "use `tvm.meta_schedule.cost_model.RandomModel`";
+    if (this->state_ != nullptr) {
+      TVM_PY_LOG(WARNING, this->context_->logging_func)
+          << "EvolutionarySearch is already initialized.";
+      this->state_.reset();
+    }
     ICHECK(this->state_ == nullptr);
-    // Change to traces
     Array<tir::Trace> design_space_traces;
     design_space_traces.reserve(design_spaces.size());
     for (const Schedule& space : design_spaces) {
       design_space_traces.push_back(space->trace().value()->Simplified(true));
     }
-    this->state_ = std::make_unique<State>(this, design_space_traces);
+    this->state_ =
+        std::make_unique<State>(this, design_space_traces, database.value(), cost_model.value());
   }
 
   void PostTuning() final {
@@ -442,16 +441,16 @@ class EvolutionarySearchNode : public SearchStrategyNode {
 std::vector<Schedule> EvolutionarySearchNode::State::PickBestFromDatabase(int num) {
   std::vector<tir::Trace> measured_traces;
   measured_traces.reserve(num);
-  Array<TuningRecord> top_records = self->database_->GetTopK(self->token_, num);
+  Array<TuningRecord> top_records = this->database_->GetTopK(this->token_, num);
   for (TuningRecord record : top_records) {
     measured_traces.push_back(record->trace);
   }
   int actual_num = measured_traces.size();
-  ThreadedTraceApply pp(self->postprocs_);
+  ThreadedTraceApply pp(self->context_->postprocs);
   std::vector<Schedule> results(actual_num, Schedule{nullptr});
   auto f_proc_measured = [this, &measured_traces, &results, &pp](int thread_id,
                                                                  int trace_id) -> void {
-    PerThreadData& data = self->per_thread_data_.at(thread_id);
+    PerThreadData& data = this->per_thread_data_.at(thread_id);
     TRandState* rand_state = &data.rand_state;
     const IRModule& mod = data.mod;
     tir::Trace trace = measured_traces.at(trace_id);
@@ -464,17 +463,17 @@ std::vector<Schedule> EvolutionarySearchNode::State::PickBestFromDatabase(int nu
       throw;
     }
   };
-  support::parallel_for_dynamic(0, actual_num, self->num_threads_, f_proc_measured);
+  support::parallel_for_dynamic(0, actual_num, self->context_->num_threads, f_proc_measured);
   return results;
 }
 
 std::vector<Schedule> EvolutionarySearchNode::State::SampleInitPopulation(int num) {
-  ThreadedTraceApply pp(self->postprocs_);
+  ThreadedTraceApply pp(self->context_->postprocs);
   std::vector<Schedule> out_schs;
   while (static_cast<int>(out_schs.size()) < self->init_min_unmeasured) {
     std::vector<Schedule> results(num, Schedule{nullptr});
     auto f_proc_unmeasured = [this, &results, &pp](int thread_id, int trace_id) -> void {
-      PerThreadData& data = self->per_thread_data_.at(thread_id);
+      PerThreadData& data = this->per_thread_data_.at(thread_id);
       TRandState* rand_state = &data.rand_state;
       const IRModule& mod = data.mod;
       Schedule& result = results.at(trace_id);
@@ -485,7 +484,7 @@ std::vector<Schedule> EvolutionarySearchNode::State::SampleInitPopulation(int nu
         result = sch.value();
       }
     };
-    support::parallel_for_dynamic(0, num, self->num_threads_, f_proc_unmeasured);
+    support::parallel_for_dynamic(0, num, self->context_->num_threads, f_proc_unmeasured);
     for (int i = 0; i < num; i++) {
       if (results[i].defined()) {
         out_schs.push_back(results[i]);
@@ -501,14 +500,14 @@ std::vector<Schedule> EvolutionarySearchNode::State::EvolveWithCostModel(
     std::vector<Schedule> population, int num) {
   ICHECK_GT(num, 0);
   // The heap to record best schedule, we do not consider schedules that are already measured
-  IRModuleSet exists = self->measured_workloads_;
+  IRModuleSet exists = this->measured_workloads_;
   SizedHeap heap(num);
   for (int iter = 0;; ++iter) {
     // Predict normalized score with the cost model,
     std::vector<double> scores = PredictNormalizedScore(population,                           //
                                                         GetRef<TuneContext>(self->context_),  //
-                                                        self->cost_model_,                    //
-                                                        self->args_info_);
+                                                        this->cost_model_,                    //
+                                                        this->args_info_);
     ICHECK_EQ(scores.size(), population.size());
     for (int i = 0, n = population.size(); i < n; ++i) {
       Schedule sch = population.at(i);
@@ -524,18 +523,18 @@ std::vector<Schedule> EvolutionarySearchNode::State::EvolveWithCostModel(
     if (iter == self->genetic_num_iters) {
       break;
     }
-    // Set threaded samplers, with probability from predicated normalized throughputs
-    for (PerThreadData& data : self->per_thread_data_) {
-      data.Set(scores, self->genetic_mutate_prob, self->mutator_probs_);
+    // Set threaded samplers, with probabilities from the predicted normalized throughputs
+    for (PerThreadData& data : this->per_thread_data_) {
+      data.Set(scores, self->genetic_mutate_prob, self->context_->mutator_probs);
     }
-    ThreadedTraceApply pp(self->postprocs_);
+    ThreadedTraceApply pp(self->context_->postprocs);
     ConcurrentBitmask cbmask(self->population_size);
     std::vector<Schedule> next_population(self->population_size, Schedule{nullptr});
     // The worker function
     auto f_find_candidate = [&cbmask, &population, &next_population, &pp, this](int thread_id,
                                                                                 int trace_id) {
       // Prepare samplers
-      PerThreadData& data = self->per_thread_data_.at(thread_id);
+      PerThreadData& data = this->per_thread_data_.at(thread_id);
       TRandState* rand_state = &data.rand_state;
       const IRModule& mod = data.mod;
       std::function<int()>& trace_sampler = data.trace_sampler;
@@ -567,7 +566,8 @@ std::vector<Schedule> EvolutionarySearchNode::State::EvolveWithCostModel(
         result = population.at(sampled_trace_id);
       }
     };
-    support::parallel_for_dynamic(0, self->population_size, self->num_threads_, f_find_candidate);
+    support::parallel_for_dynamic(0, self->population_size, self->context_->num_threads,
+                                  f_find_candidate);
     population.swap(next_population);
     TVM_PY_LOG(INFO, self->context_->logging_func) << "Evolve iter #" << iter << " done. Summary:\n"
                                                    << pp.SummarizeFailures();
@@ -607,7 +607,7 @@ std::vector<Schedule> EvolutionarySearchNode::State::PickWithEpsGreedy(
       tir::SampleWithoutReplacement(&self->rand_state_, unmeasured.size(), unmeasured.size());
   std::vector<Schedule> results;
   results.reserve(num);
-  IRModuleSet& measured_workloads = self->measured_workloads_;
+  IRModuleSet& measured_workloads = this->measured_workloads_;
   for (int i = 0, i_bests = 0, i_rands = 0; i < num; ++i) {
     bool has_best = i_bests < static_cast<int>(bests.size());
     bool has_rand = i_rands < static_cast<int>(rands.size());
@@ -677,7 +677,7 @@ Optional<Array<MeasureCandidate>> EvolutionarySearchNode::State::GenerateMeasure
       return NullOpt;
     }
   }
-  return AssembleCandidates(picks, self->args_info_);
+  return AssembleCandidates(picks, this->args_info_);
 }
 
 void EvolutionarySearchNode::State::NotifyRunnerResults(
@@ -713,6 +713,12 @@ SearchStrategy SearchStrategy::EvolutionarySearch(int num_trials_per_iter,     /
   return SearchStrategy(n);
 }
 
+class EvolutionarySearch : public SearchStrategy {
+ public:
+  TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(EvolutionarySearch, SearchStrategy,
+                                                    EvolutionarySearchNode);
+};
+
 TVM_REGISTER_NODE_TYPE(EvolutionarySearchNode);
 TVM_REGISTER_GLOBAL("meta_schedule.SearchStrategyEvolutionarySearch")
     .set_body_typed(SearchStrategy::EvolutionarySearch);
diff --git a/src/meta_schedule/search_strategy/replay_func.cc b/src/meta_schedule/search_strategy/replay_func.cc
index 878c872a65..1aaaaa09e8 100644
--- a/src/meta_schedule/search_strategy/replay_func.cc
+++ b/src/meta_schedule/search_strategy/replay_func.cc
@@ -32,8 +32,14 @@ class ReplayFuncNode : public SearchStrategyNode {
     int st;
     /*! \brief `[st, ed)` are the indices of the next batch of candidates. */
     int ed;
+    /*! \brief The metadata of the function arguments. */
+    Array<ArgInfo> args_info_{nullptr};
 
-    explicit State(ReplayFuncNode* self) : self(self), st(0), ed(self->num_trials_per_iter) {}
+    explicit State(ReplayFuncNode* self) : self(self), st(0), ed(self->num_trials_per_iter) {
+      const TuneContextNode* ctx = self->context_;
+      ICHECK(ctx);
+      this->args_info_ = ArgInfo::FromPrimFunc(FindEntryFunc(ctx->mod.value()));
+    }
 
     inline Optional<Array<MeasureCandidate>> GenerateMeasureCandidates();
     inline void NotifyRunnerResults(const Array<RunnerResult>& results);
@@ -44,14 +50,8 @@ class ReplayFuncNode : public SearchStrategyNode {
   /*! \brief The number of total trials. */
   int max_trials_per_task;
 
-  /*! \brief The module to be tuned. */
-  IRModule mod_{nullptr};
-  /*! \brief The metadata of the function arguments. */
-  Array<ArgInfo> args_info_{nullptr};
-  /*! \brief The post processors */
-  Array<Postproc> postprocs_{nullptr};
-  /*! \brief The space generator for measure candidates generation. */
-  SpaceGenerator space_generator_{nullptr};
+  /*! \brief The tuning context of the search strategy. */
+  const TuneContextNode* context_{nullptr};
   /*! \brief The random state. -1 means using random number. */
   TRandState rand_state_ = -1;
   /*! \brief The state of the search strategy. */
@@ -60,10 +60,7 @@ class ReplayFuncNode : public SearchStrategyNode {
   void VisitAttrs(tvm::AttrVisitor* v) {
     v->Visit("num_trials_per_iter", &num_trials_per_iter);
     v->Visit("max_trials_per_task", &max_trials_per_task);
-    // `space_generator_` is not visited
-    // `mod_` is not visited
-    // `args_info_` is not visited
-    // `num_threads_` is not visited
+    // `context_` is not visited.
     // `rand_state_` is not visited
     // `state_` is not visited
   }
@@ -72,15 +69,21 @@ class ReplayFuncNode : public SearchStrategyNode {
   TVM_DECLARE_FINAL_OBJECT_INFO(ReplayFuncNode, SearchStrategyNode);
 
   void InitializeWithTuneContext(const TuneContext& context) final {
-    this->space_generator_ = context->space_generator.value();
-    this->mod_ = context->mod.value();
-    this->args_info_ = ArgInfo::FromPrimFunc(FindEntryFunc(context->mod.value()));
-    this->postprocs_ = context->postprocs;
+    CHECK(context->space_generator.defined())
+        << "ValueError: TuneContext.space_generator is not defined";
+    CHECK(context->mod.defined()) << "ValueError: TuneContext.mod is not defined";
+    this->context_ = context.get();
     this->rand_state_ = ForkSeed(&context->rand_state);
     this->state_.reset();
   }
 
-  void PreTuning(const Array<tir::Schedule>& design_spaces) final {
+  void PreTuning(const Array<tir::Schedule>& design_spaces, const Optional<Database>& database,
+                 const Optional<CostModel>& cost_model) final {
+    CHECK(this->context_ != nullptr) << "ValueError: Did you forget to initialize the TuneContext?";
+    if (this->state_ != nullptr) {
+      TVM_PY_LOG(WARNING, this->context_->logging_func) << "ReplayFunc is already initialized.";
+      this->state_.reset();
+    }
     ICHECK(this->state_ == nullptr);
     this->state_ = std::make_unique<State>(this);
   }
@@ -109,21 +112,24 @@ inline Optional<Array<MeasureCandidate>> ReplayFuncNode::State::GenerateMeasureC
   }
   ed = std::min(ed, self->max_trials_per_task);
   Array<MeasureCandidate> result;
+  const TuneContextNode* ctx = self->context_;
+  ICHECK(ctx);
+  IRModule mod = ctx->mod.value();
   for (int i = st; i < ed; i++) {
     for (;;) {
-      Array<tir::Schedule> schs = self->space_generator_->GenerateDesignSpace(self->mod_);
+      Array<tir::Schedule> schs = ctx->space_generator.value()->GenerateDesignSpace(mod);
       int design_space_index = tir::SampleInt(&self->rand_state_, 0, schs.size());
       tir::Schedule sch = schs[design_space_index];
       sch->EnterPostproc();
       bool failed = false;
-      for (const Postproc& proc : self->postprocs_) {
+      for (const Postproc& proc : ctx->postprocs) {
         if (!proc->Apply(sch)) {
           failed = true;
           break;
         }
       }
       if (!failed) {
-        result.push_back(MeasureCandidate(sch, self->args_info_));
+        result.push_back(MeasureCandidate(sch, this->args_info_));
         break;
       }
     }
diff --git a/src/meta_schedule/search_strategy/replay_trace.cc b/src/meta_schedule/search_strategy/replay_trace.cc
index f17c5d6c4e..13f32a744e 100644
--- a/src/meta_schedule/search_strategy/replay_trace.cc
+++ b/src/meta_schedule/search_strategy/replay_trace.cc
@@ -35,8 +35,22 @@ class ReplayTraceNode : public SearchStrategyNode {
     /*! \brief `[st, ed)` are the indices of the next batch of candidates. */
     int ed;
 
+    /*! \brief Per-thread copies of the module to be tuned. */
+    Array<IRModule> per_thread_mod_{nullptr};
+    /*! \brief The metadata of the function arguments. */
+    Array<ArgInfo> args_info_{nullptr};
+
     explicit State(ReplayTraceNode* self, Array<tir::Trace> design_spaces)
-        : self(self), design_spaces(design_spaces), st(0), ed(self->num_trials_per_iter) {}
+        : self(self), design_spaces(design_spaces), st(0), ed(self->num_trials_per_iter) {
+      const TuneContextNode* ctx = self->context_;
+      ICHECK(ctx);
+      IRModule mod = ctx->mod.value();
+      this->per_thread_mod_.reserve(ctx->num_threads);
+      for (int i = 0; i < ctx->num_threads; i++) {
+        this->per_thread_mod_.push_back(DeepCopyIRModule(mod));
+      }
+      this->args_info_ = ArgInfo::FromPrimFunc(FindEntryFunc(mod));
+    }
 
     inline Optional<Array<MeasureCandidate>> GenerateMeasureCandidates();
     inline void NotifyRunnerResults(const Array<RunnerResult>& results);
@@ -47,14 +61,8 @@ class ReplayTraceNode : public SearchStrategyNode {
   /*! \brief The number of total trials. */
   int max_trials_per_task;
 
-  /*! \brief The module to be tuned. */
-  Array<IRModule> per_thread_mod_{nullptr};
-  /*! \brief The metadata of the function arguments. */
-  Array<ArgInfo> args_info_{nullptr};
-  /*! \brief The post processors */
-  Array<Postproc> postprocs_{nullptr};
-  /*! \brief The number of threads to use. -1 means using logical cpu number. */
-  int num_threads_ = -1;
+  /*! \brief The tuning context of the search strategy. */
+  const TuneContextNode* context_{nullptr};
   /*! \brief The random state. -1 means using random number. */
   TRandState rand_state_ = -1;
   /*! \brief The state of the search strategy. */
@@ -63,10 +71,7 @@ class ReplayTraceNode : public SearchStrategyNode {
   void VisitAttrs(tvm::AttrVisitor* v) {
     v->Visit("num_trials_per_iter", &num_trials_per_iter);
     v->Visit("max_trials_per_task", &max_trials_per_task);
-    // `per_thread_mod_` is not visited
-    // `args_info_` is not visited
-    // `postprocs_` is not visited
-    // `num_threads_` is not visited
+    // `context_` is not visited.
     // `rand_state_` is not visited
     // `state_` is not visited
   }
@@ -75,22 +80,20 @@ class ReplayTraceNode : public SearchStrategyNode {
   TVM_DECLARE_FINAL_OBJECT_INFO(ReplayTraceNode, SearchStrategyNode);
 
   void InitializeWithTuneContext(const TuneContext& context) final {
-    CHECK(context->num_threads > 0) << "Number of threads has to be larger than 0.";
-    this->num_threads_ = context->num_threads;
-
-    this->per_thread_mod_.reserve(this->num_threads_);
-    for (int i = 0; i < this->num_threads_; i++) {
-      this->per_thread_mod_.push_back(DeepCopyIRModule(context->mod.value()));
-    }
-
-    this->args_info_ = ArgInfo::FromPrimFunc(FindEntryFunc(context->mod.value()));
-    this->postprocs_ = context->postprocs;
+    CHECK(context->mod.defined()) << "ValueError: TuneContext.mod is not defined";
+    this->context_ = context.get();
     this->rand_state_ = ForkSeed(&context->rand_state);
     this->state_.reset();
   }
 
-  void PreTuning(const Array<tir::Schedule>& design_spaces) final {
+  void PreTuning(const Array<tir::Schedule>& design_spaces, const Optional<Database>& database,
+                 const Optional<CostModel>& cost_model) final {
     ICHECK(!design_spaces.empty());
+    CHECK(this->context_ != nullptr) << "ValueError: Did you forget to initialize the TuneContext?";
+    if (this->state_ != nullptr) {
+      TVM_PY_LOG(WARNING, this->context_->logging_func) << "ReplayTrace is already initialized.";
+      this->state_.reset();
+    }
     ICHECK(this->state_ == nullptr);
     Array<tir::Trace> design_space_traces;
     design_space_traces.reserve(design_spaces.size());
@@ -124,24 +127,26 @@ inline Optional<Array<MeasureCandidate>> ReplayTraceNode::State::GenerateMeasure
   }
   ed = std::min(ed, self->max_trials_per_task);
   ICHECK_LT(st, ed);
-  std::vector<TRandState> per_thread_rand_state = ForkSeed(&self->rand_state_, self->num_threads_);
+  const TuneContextNode* ctx = self->context_;
+  ICHECK(ctx);
+  std::vector<TRandState> per_thread_rand_state = ForkSeed(&self->rand_state_, ctx->num_threads);
   Array<MeasureCandidate> per_task_result(ed - st, MeasureCandidate{nullptr});
-  ThreadedTraceApply pp(self->postprocs_);
+  ThreadedTraceApply pp(ctx->postprocs);
   auto f_worker = [this, &per_thread_rand_state, &per_task_result, &pp](int thread_id,
                                                                         int task_id) -> void {
     TRandState& rand_state = per_thread_rand_state[thread_id];
-    IRModule mod = self->per_thread_mod_[thread_id];
+    IRModule mod = this->per_thread_mod_[thread_id];
     for (;;) {
       int design_space_index = tir::SampleInt(&rand_state, 0, design_spaces.size());
       tir::Trace trace = design_spaces[design_space_index];
       tir::Trace new_trace = tir::Trace(trace->insts, {});
       if (Optional<tir::Schedule> sch = pp.Apply(mod, new_trace, &rand_state)) {
-        per_task_result.Set(task_id, MeasureCandidate(sch.value(), self->args_info_));
+        per_task_result.Set(task_id, MeasureCandidate(sch.value(), this->args_info_));
         break;
       }
     }
   };
-  support::parallel_for_dynamic(0, ed - st, self->num_threads_, f_worker);
+  support::parallel_for_dynamic(0, ed - st, ctx->num_threads, f_worker);
   return per_task_result;
 }
 
diff --git a/src/meta_schedule/search_strategy/search_strategy.cc b/src/meta_schedule/search_strategy/search_strategy.cc
index fefe8dfce7..a6a1100ceb 100644
--- a/src/meta_schedule/search_strategy/search_strategy.cc
+++ b/src/meta_schedule/search_strategy/search_strategy.cc
@@ -28,6 +28,13 @@ MeasureCandidate::MeasureCandidate(tir::Schedule sch, Array<ArgInfo> args_info)
   data_ = std::move(n);
 }
 
+void PySearchStrategyNode::PreTuning(const Array<tir::Schedule>& design_spaces,
+                                     const Optional<Database>& database,
+                                     const Optional<CostModel>& cost_model) {
+  ICHECK(f_pre_tuning != nullptr) << "PySearchStrategy's PreTuning method not implemented!";
+  this->f_pre_tuning(design_spaces, database, cost_model);
+}
+
 SearchStrategy SearchStrategy::PySearchStrategy(
     PySearchStrategyNode::FInitializeWithTuneContext f_initialize_with_tune_context,  //
     PySearchStrategyNode::FPreTuning f_pre_tuning,                                    //
diff --git a/src/meta_schedule/task_scheduler/gradient_based.cc b/src/meta_schedule/task_scheduler/gradient_based.cc
index f8cc9d5514..73d191f593 100644
--- a/src/meta_schedule/task_scheduler/gradient_based.cc
+++ b/src/meta_schedule/task_scheduler/gradient_based.cc
@@ -189,10 +189,10 @@ TaskScheduler TaskScheduler::GradientBased(Array<TuneContext> tasks,
                                            Array<FloatImm> task_weights,                        //
                                            Builder builder,                                     //
                                            Runner runner,                                       //
-                                           Database database,                                   //
-                                           int max_trials,                                      //
+                                           Optional<Database> database,                         //
                                            Optional<CostModel> cost_model,                      //
                                            Optional<Array<MeasureCallback>> measure_callbacks,  //
+                                           int max_trials,                                      //
                                            PackedFunc logging_func,                             //
                                            double alpha,                                        //
                                            int window_size,                                     //
@@ -227,9 +227,6 @@ TaskScheduler TaskScheduler::GradientBased(Array<TuneContext> tasks,
   n->best_time_cost_per_task_ = std::vector<double>(n_tasks, 1e100);
   n->num_rounds_already_ = 0;
   support::LinearCongruentialEngine(&n->rand_state_).Seed(seed);
-  for (const TuneContext& task : tasks) {
-    task->task_scheduler = n.get();
-  }
   return TaskScheduler(n);
 }
 
diff --git a/src/meta_schedule/task_scheduler/round_robin.cc b/src/meta_schedule/task_scheduler/round_robin.cc
index 446b118379..ea22878840 100644
--- a/src/meta_schedule/task_scheduler/round_robin.cc
+++ b/src/meta_schedule/task_scheduler/round_robin.cc
@@ -58,10 +58,10 @@ class RoundRobinNode final : public TaskSchedulerNode {
 TaskScheduler TaskScheduler::RoundRobin(Array<TuneContext> tasks,                            //
                                         Builder builder,                                     //
                                         Runner runner,                                       //
-                                        Database database,                                   //
-                                        int max_trials,                                      //
+                                        Optional<Database> database,                         //
                                         Optional<CostModel> cost_model,                      //
                                         Optional<Array<MeasureCallback>> measure_callbacks,  //
+                                        int max_trials,                                      //
                                         PackedFunc logging_func) {
   ObjectPtr<RoundRobinNode> n = make_object<RoundRobinNode>();
   n->tasks = tasks;
@@ -74,9 +74,6 @@ TaskScheduler TaskScheduler::RoundRobin(Array<TuneContext> tasks,
   n->logging_func = logging_func;
   n->num_trials_already = 0;
   n->task_id = -1;
-  for (const TuneContext& task : tasks) {
-    task->task_scheduler = n.get();
-  }
   return TaskScheduler(n);
 }
 
diff --git a/src/meta_schedule/task_scheduler/task_scheduler.cc b/src/meta_schedule/task_scheduler/task_scheduler.cc
index fd1d95cd1f..25867fb4f3 100644
--- a/src/meta_schedule/task_scheduler/task_scheduler.cc
+++ b/src/meta_schedule/task_scheduler/task_scheduler.cc
@@ -117,7 +117,7 @@ void TaskSchedulerNode::InitializeTask(int task_id) {
                                          << tir::AsTVMScript(sch->mod()) << "\n"
                                          << Concat(trace->AsPython(false), "\n");
   }
-  task->search_strategy.value()->PreTuning(design_spaces);
+  task->search_strategy.value()->PreTuning(design_spaces, database, cost_model);
 }
 
 void TaskSchedulerNode::Tune() {
@@ -203,10 +203,10 @@ TaskScheduler TaskScheduler::PyTaskScheduler(
     Array<TuneContext> tasks,                                   //
     Builder builder,                                            //
     Runner runner,                                              //
-    Database database,                                          //
-    int max_trials,                                             //
+    Optional<Database> database,                                //
     Optional<CostModel> cost_model,                             //
     Optional<Array<MeasureCallback>> measure_callbacks,         //
+    int max_trials,                                             //
     PackedFunc logging_func,                                    //
     PyTaskSchedulerNode::FTune f_tune,                          //
     PyTaskSchedulerNode::FInitializeTask f_initialize_task,     //
diff --git a/tests/python/unittest/test_meta_schedule_measure_callback.py b/tests/python/unittest/test_meta_schedule_measure_callback.py
index a1b188930f..298b51e015 100644
--- a/tests/python/unittest/test_meta_schedule_measure_callback.py
+++ b/tests/python/unittest/test_meta_schedule_measure_callback.py
@@ -16,12 +16,10 @@
 # under the License.
 # pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring
 import re
-from random import random
 from typing import List
 
 import pytest
 import tvm
-from tvm.ir import IRModule, assert_structural_equal
 from tvm.meta_schedule.builder import BuilderResult
 from tvm.meta_schedule.measure_callback import PyMeasureCallback
 from tvm.meta_schedule.runner import RunnerResult
@@ -66,7 +64,7 @@ def test_meta_schedule_measure_callback():
             results: List[RunnerResult],
         ) -> None:
             assert len(measure_candidates) == 1
-            assert_structural_equal(measure_candidates[0].sch.mod, Matmul)
+            tvm.ir.assert_structural_equal(measure_candidates[0].sch.mod, Matmul)
             assert (
                 len(builds) == 1
                 and builds[0].error_msg is None
@@ -78,7 +76,14 @@ def test_meta_schedule_measure_callback():
 
     measure_callback = FancyMeasureCallback()
     measure_callback.apply(
-        RoundRobin([], [], DummyBuilder(), DummyRunner(), DummyDatabase(), max_trials=1),
+        RoundRobin(
+            tasks=[],
+            task_weights=[],
+            builder=DummyBuilder(),
+            runner=DummyRunner(),
+            database=DummyDatabase(),
+            max_trials=1,
+        ),
         0,
         [MeasureCandidate(Schedule(Matmul), None)],
         [BuilderResult("test_build", None)],
@@ -102,7 +107,14 @@ def test_meta_schedule_measure_callback_fail():
     measure_callback = FailingMeasureCallback()
     with pytest.raises(ValueError, match="test"):
         measure_callback.apply(
-            RoundRobin([], [], DummyBuilder(), DummyRunner(), DummyDatabase(), max_trials=1),
+            RoundRobin(
+                tasks=[],
+                task_weights=[],
+                builder=DummyBuilder(),
+                runner=DummyRunner(),
+                database=DummyDatabase(),
+                max_trials=1,
+            ),
             0,
             [MeasureCandidate(Schedule(Matmul), None)],
             [BuilderResult("test_build", None)],
diff --git a/tests/python/unittest/test_meta_schedule_search_strategy.py b/tests/python/unittest/test_meta_schedule_search_strategy.py
index 94042dd753..4eb8aac5a3 100644
--- a/tests/python/unittest/test_meta_schedule_search_strategy.py
+++ b/tests/python/unittest/test_meta_schedule_search_strategy.py
@@ -123,43 +123,37 @@ def test_meta_schedule_evolutionary_search():  # pylint: disable = invalid-name
 
     num_trials_per_iter = 10
     max_trials_per_task = 2000
+    (correct_sch,) = ScheduleFn(sch_fn=_schedule_matmul).generate_design_space(Matmul)
 
-    strategy = EvolutionarySearch(
-        num_trials_per_iter=num_trials_per_iter,
-        max_trials_per_task=max_trials_per_task,
-        population_size=5,
-        init_measured_ratio=0.1,
-        init_min_unmeasured=50,
-        genetic_num_iters=3,
-        genetic_mutate_prob=0.5,
-        genetic_max_fail_count=10,
-        eps_greedy=0.9,
-    )
     context = TuneContext(
         mod=Matmul,
-        space_generator=ScheduleFn(sch_fn=_schedule_matmul_small),
+        space_generator=ScheduleFn(
+            sch_fn=_schedule_matmul_small,
+        ),
+        search_strategy=EvolutionarySearch(
+            num_trials_per_iter=num_trials_per_iter,
+            max_trials_per_task=max_trials_per_task,
+            population_size=5,
+            init_measured_ratio=0.1,
+            init_min_unmeasured=50,
+            genetic_num_iters=3,
+            genetic_mutate_prob=0.5,
+            genetic_max_fail_count=10,
+            eps_greedy=0.9,
+        ),
         mutator_probs={
             DummyMutator(): 1.0,
         },
         target=tvm.target.Target("llvm"),
         num_threads=1,  # because we are using a mutator from the python side
     )
-    _scheduler = RoundRobin(
-        tasks=[context],
-        task_weights=[1.0],
-        builder=ms.builder.LocalBuilder(),
-        runner=ms.runner.LocalRunner(),
+    context.initialize()
+    strategy = context.search_strategy
+    strategy.pre_tuning(
+        context.space_generator.generate_design_space(context.mod),
         database=DummyDatabase(),
         cost_model=ms.cost_model.RandomModel(),
-        measure_callbacks=[],
-        max_trials=1,
     )
-    context.space_generator.initialize_with_tune_context(context)
-    spaces = context.space_generator.generate_design_space(context.mod)
-
-    strategy.initialize_with_tune_context(context)
-    strategy.pre_tuning(spaces)
-    (correct_sch,) = ScheduleFn(sch_fn=_schedule_matmul).generate_design_space(Matmul)
     num_trials_each_iter: List[int] = []
     candidates = strategy.generate_measure_candidates()
     while candidates is not None:
@@ -177,52 +171,46 @@ def test_meta_schedule_evolutionary_search():  # pylint: disable = invalid-name
     strategy.post_tuning()
     assert sum(num_trials_each_iter) == 25
     assert num_trials_each_iter.count(0) < 5
-    del _scheduler
 
 
 def test_meta_schedule_evolutionary_search_early_stop():  # pylint: disable = invalid-name
     def _schedule_matmul_empty(sch: Schedule):
         return sch
 
+    (correct_sch,) = ScheduleFn(sch_fn=_schedule_matmul).generate_design_space(Matmul)
+
     num_trials_per_iter = 10
     max_trials_per_task = 100
 
-    strategy = EvolutionarySearch(
-        num_trials_per_iter=num_trials_per_iter,
-        max_trials_per_task=max_trials_per_task,
-        population_size=5,
-        init_measured_ratio=0.1,
-        init_min_unmeasured=50,
-        genetic_num_iters=3,
-        genetic_mutate_prob=0.5,
-        genetic_max_fail_count=10,
-        eps_greedy=0.9,
-    )
     context = TuneContext(
         mod=Matmul,
-        space_generator=ScheduleFn(sch_fn=_schedule_matmul_empty),
+        search_strategy=EvolutionarySearch(
+            num_trials_per_iter=num_trials_per_iter,
+            max_trials_per_task=max_trials_per_task,
+            population_size=5,
+            init_measured_ratio=0.1,
+            init_min_unmeasured=50,
+            genetic_num_iters=3,
+            genetic_mutate_prob=0.5,
+            genetic_max_fail_count=10,
+            eps_greedy=0.9,
+        ),
+        space_generator=ScheduleFn(
+            sch_fn=_schedule_matmul_empty,
+        ),
         mutator_probs={
             DummyMutator(): 1.0,
         },
         target=tvm.target.Target("llvm"),
-        num_threads=1,  # because we are using a mutator from the python side
+        num_threads=1,
     )
-    _scheduler = RoundRobin(
-        tasks=[context],
-        task_weights=[1.0],
-        builder=ms.builder.LocalBuilder(),
-        runner=ms.runner.LocalRunner(),
+    context.initialize()
+    strategy = context.search_strategy
+    strategy.pre_tuning(
+        context.space_generator.generate_design_space(context.mod),
         database=DummyDatabase(),
         cost_model=ms.cost_model.RandomModel(),
-        measure_callbacks=[],
-        max_trials=1,
     )
-    context.space_generator.initialize_with_tune_context(context)
-    spaces = context.space_generator.generate_design_space(context.mod)
-
-    strategy.initialize_with_tune_context(context)
-    strategy.pre_tuning(spaces)
-    (correct_sch,) = ScheduleFn(sch_fn=_schedule_matmul).generate_design_space(Matmul)
     num_trials_each_iter: List[int] = []
     candidates = strategy.generate_measure_candidates()
     while candidates is not None:
@@ -239,7 +227,6 @@ def test_meta_schedule_evolutionary_search_early_stop():  # pylint: disable = invalid-name
         candidates = strategy.generate_measure_candidates()
     strategy.post_tuning()
     assert num_trials_each_iter == [1, 0, 0, 0, 0]
-    del _scheduler
 
 
 if __name__ == "__main__":
diff --git a/tests/python/unittest/test_meta_schedule_task_scheduler.py b/tests/python/unittest/test_meta_schedule_task_scheduler.py
index 025bbe4225..f24dc5fbbc 100644
--- a/tests/python/unittest/test_meta_schedule_task_scheduler.py
+++ b/tests/python/unittest/test_meta_schedule_task_scheduler.py
@@ -17,7 +17,6 @@
 """ Test Meta Schedule Task Scheduler """
 
 import random
-import sys
 import weakref
 from typing import Set
 
@@ -108,7 +107,6 @@ class BatchMatmulModule:
 def _schedule_matmul(sch: Schedule):
     block = sch.get_block("matmul")
     i, j, k = sch.get_loops(block=block)
-    # TODO(@zxybazh): Change to `sample_perfect_tile` after upstreaming
     i_0, i_1, i_2, i_3 = sch.split(loop=i, factors=[2, 4, 64, 2])
     j_0, j_1, j_2, j_3 = sch.split(loop=j, factors=[4, 64, 2, 2])
     k_0, k_1 = sch.split(loop=k, factors=[32, 32])
@@ -118,7 +116,6 @@ def _schedule_matmul(sch: Schedule):
 def _schedule_batch_matmul(sch: Schedule):
     block = sch.get_block("matmul")
     i, j, k, t = sch.get_loops(block=block)
-    # TODO(@zxybazh): Change to `sample_perfect_tile` after upstreaming
     i_0, i_1, i_2, i_3 = sch.split(loop=i, factors=[2, 2, 2, 2])
     j_0, j_1, j_2, j_3 = sch.split(loop=j, factors=[2, 4, 64, 2])
     k_0, k_1 = sch.split(loop=k, factors=[32, 32])
@@ -156,23 +153,22 @@ class MyTaskScheduler(PyTaskScheduler):
 def test_meta_schedule_task_scheduler_single():
     num_trials_per_iter = 3
     max_trials_per_task = 10
-    sch_fn = ScheduleFn(sch_fn=_schedule_matmul)
-    replay = ReplayTrace(num_trials_per_iter, max_trials_per_task)
-    task = TuneContext(
-        MatmulModule,
-        target=tvm.target.Target("llvm"),
-        space_generator=sch_fn,
-        search_strategy=replay,
-        task_name="Test",
-        rand_state=42,
-    )
     database = DummyDatabase()
     round_robin = RoundRobin(
-        [task],
+        [
+            TuneContext(
+                MatmulModule,
+                target=tvm.target.Target("llvm"),
+                space_generator=ScheduleFn(sch_fn=_schedule_matmul),
+                search_strategy=ReplayTrace(num_trials_per_iter, max_trials_per_task),
+                task_name="Test",
+                rand_state=42,
+            )
+        ],
         [1.0],
-        DummyBuilder(),
-        DummyRunner(),
-        database,
+        builder=DummyBuilder(),
+        runner=DummyRunner(),
+        database=database,
         measure_callbacks=[measure_callback.AddToDatabase()],
         max_trials=max_trials_per_task,
     )
@@ -212,10 +208,10 @@ def test_meta_schedule_task_scheduler_multiple():
     database = DummyDatabase()
     round_robin = RoundRobin(
         tasks,
-        [1.0],
-        DummyBuilder(),
-        DummyRunner(),
-        database,
+        [1.0, 1.0, 1.0],
+        builder=DummyBuilder(),
+        runner=DummyRunner(),
+        database=database,
         measure_callbacks=[measure_callback.AddToDatabase()],
         max_trials=max_trials_per_task * len(tasks),
     )
@@ -239,18 +235,23 @@ def test_meta_schedule_task_scheduler_NIE():  # pylint: disable=invalid-name
         pass
 
     with pytest.raises(TVMError, match="PyTaskScheduler's NextTaskId method not implemented!"):
-        scheduler = NIETaskScheduler([], DummyBuilder(), DummyRunner(), DummyDatabase(), 1)
+        scheduler = NIETaskScheduler(
+            tasks=[],
+            builder=DummyBuilder(),
+            runner=DummyRunner(),
+            database=DummyDatabase(),
+            max_trials=1,
+        )
         scheduler.next_task_id()
 
 
 def test_meta_schedule_task_scheduler_avoid_cyclic():  # pylint: disable=invalid-name
-
     database = DummyDatabase()
     scheduler = MyTaskScheduler(
         [],
-        DummyBuilder(),
-        DummyRunner(),
-        database,
+        builder=DummyBuilder(),
+        runner=DummyRunner(),
+        database=database,
         measure_callbacks=[
             measure_callback.AddToDatabase(),
         ],
@@ -262,7 +263,6 @@ def test_meta_schedule_task_scheduler_avoid_cyclic():  # pylint: disable=invalid-name
 
 
 def test_meta_schedule_task_scheduler_override_next_task_id_only():  # pylint: disable=invalid-name
-
     num_trials_per_iter = 6
     max_trials_per_task = 101
     tasks = [
@@ -294,9 +294,9 @@ def test_meta_schedule_task_scheduler_override_next_task_id_only():  # pylint: disable=invalid-name
     database = DummyDatabase()
     scheduler = MyTaskScheduler(
         tasks,
-        DummyBuilder(),
-        DummyRunner(),
-        database,
+        builder=DummyBuilder(),
+        runner=DummyRunner(),
+        database=database,
         measure_callbacks=[
             measure_callback.AddToDatabase(),
         ],