Posted to commits@tvm.apache.org by ju...@apache.org on 2022/01/02 01:23:03 UTC

[tvm] branch main updated: [M3c][MetaScheduler] Update TuneContext, TaskScheduler & Search Strategy Design (#9789)

This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 1c7d36f  [M3c][MetaScheduler] Update TuneContext, TaskScheduler & Search Strategy Design (#9789)
1c7d36f is described below

commit 1c7d36ff271f64e460487900e52ff6d6e5f5d451
Author: Xiyou Zhou <xi...@octoml.ai>
AuthorDate: Sat Jan 1 17:22:28 2022 -0800

    [M3c][MetaScheduler] Update TuneContext, TaskScheduler & Search Strategy Design (#9789)
    
    * Modify TuneContext, TaskScheduler & SearchStrategy functions.
    
    Co-authored-by: Junru Shao <ju...@gmail.com>
    Co-authored-by: Bohan Hou <32...@users.noreply.github.com>
    Co-authored-by: Ruihang Lai <la...@qq.com>
    Co-authored-by: Hongyi Jin <32...@qq.com>
    Co-authored-by: Wuwei Lin <wu...@apache.org>
    Co-authored-by: Siyuan Feng <Hz...@sjtu.edu.cn>
    
    * Minor fix.
    
    Fix mypy.
    
    Fix mypy.
    
    * Retrigger CI.
    
    * Minor fixes.
    
    Co-authored-by: Junru Shao <ju...@gmail.com>
    Co-authored-by: Bohan Hou <32...@users.noreply.github.com>
    Co-authored-by: Ruihang Lai <la...@qq.com>
    Co-authored-by: Hongyi Jin <32...@qq.com>
    Co-authored-by: Wuwei Lin <wu...@apache.org>
    Co-authored-by: Siyuan Feng <Hz...@sjtu.edu.cn>
---
 include/tvm/meta_schedule/cost_model.h             |  20 +--
 include/tvm/meta_schedule/feature_extractor.h      |  12 +-
 include/tvm/meta_schedule/mutator.h                | 146 ++++++++++++++++++
 include/tvm/meta_schedule/postproc.h               | 167 +++++++++++++++++++++
 include/tvm/meta_schedule/search_strategy.h        |  46 +++++-
 include/tvm/meta_schedule/space_generator.h        |   8 +-
 include/tvm/meta_schedule/task_scheduler.h         |  35 ++++-
 include/tvm/meta_schedule/tune_context.h           |  29 +++-
 include/tvm/support/random_engine.h                |   8 +
 python/tvm/meta_schedule/cost_model/cost_model.py  |  22 ++-
 .../tvm/meta_schedule/cost_model/random_model.py   |   8 +-
 .../feature_extractor/feature_extractor.py         |  10 +-
 .../feature_extractor/random_feature_extractor.py  |   2 +-
 python/tvm/meta_schedule/mutator/__init__.py       |  22 +++
 python/tvm/meta_schedule/mutator/mutator.py        |  88 +++++++++++
 python/tvm/meta_schedule/postproc/__init__.py      |  18 +++
 python/tvm/meta_schedule/postproc/postproc.py      |  90 +++++++++++
 .../meta_schedule/schedule_rule/schedule_rule.py   |  10 +-
 .../meta_schedule/search_strategy/replay_trace.py  |  13 +-
 .../search_strategy/search_strategy.py             |  41 +++--
 .../meta_schedule/space_generator/schedule_fn.py   |   4 +-
 .../space_generator/space_generator.py             |  10 +-
 .../meta_schedule/task_scheduler/round_robin.py    |  26 +++-
 .../meta_schedule/task_scheduler/task_scheduler.py |  20 ++-
 python/tvm/meta_schedule/tune_context.py           |  27 +++-
 src/meta_schedule/cost_model/cost_model.cc         |   4 +-
 src/meta_schedule/search_strategy/replay_trace.cc  |  63 +++++---
 src/meta_schedule/task_scheduler/round_robin.cc    |  15 +-
 src/meta_schedule/task_scheduler/task_scheduler.cc | 126 ++++++++--------
 src/meta_schedule/tune_context.cc                  |  43 ++++--
 src/meta_schedule/utils.h                          |  83 +++++++++-
 src/tir/schedule/concrete_schedule.cc              |   2 +-
 src/tir/schedule/traced_schedule.cc                |   2 +-
 .../unittest/test_meta_schedule_cost_model.py      |  12 +-
 .../test_meta_schedule_feature_extractor.py        |   4 +-
 .../test_meta_schedule_post_order_apply.py         |  10 +-
 .../unittest/test_meta_schedule_search_strategy.py |  73 +++++----
 .../unittest/test_meta_schedule_task_scheduler.py  |  77 +++++++---
 38 files changed, 1136 insertions(+), 260 deletions(-)

diff --git a/include/tvm/meta_schedule/cost_model.h b/include/tvm/meta_schedule/cost_model.h
index b05dc3c..6fadc2f 100644
--- a/include/tvm/meta_schedule/cost_model.h
+++ b/include/tvm/meta_schedule/cost_model.h
@@ -51,20 +51,20 @@ class CostModelNode : public runtime::Object {
 
   /*!
    * \brief Update the cost model given running results.
-   * \param tune_context The tuning context.
+   * \param context The tuning context.
    * \param candidates The measure candidates.
    * \param results The running results of the measure candidates.
    */
-  virtual void Update(const TuneContext& tune_context, const Array<MeasureCandidate>& candidates,
+  virtual void Update(const TuneContext& context, const Array<MeasureCandidate>& candidates,
                       const Array<RunnerResult>& results) = 0;
 
   /*!
    * \brief Predict the normalized score (the larger the better) of given measure candidates.
-   * \param tune_context The tuning context.
+   * \param context The tuning context.
    * \param candidates The measure candidates.
    * \return The predicted normalized score.
    */
-  virtual std::vector<double> Predict(const TuneContext& tune_context,
+  virtual std::vector<double> Predict(const TuneContext& context,
                                       const Array<MeasureCandidate>& candidates) = 0;
 
   static constexpr const char* _type_key = "meta_schedule.CostModel";
@@ -86,7 +86,7 @@ class PyCostModelNode : public CostModelNode {
   using FSave = runtime::TypedPackedFunc<void(String)>;
   /*!
    * \brief Update the cost model given running results.
-   * \param tune_context The tuning context.
+   * \param context The tuning context.
    * \param candidates The measure candidates.
    * \param results The running results of the measure candidates.
    * \return Whether cost model was updated successfully.
@@ -95,7 +95,7 @@ class PyCostModelNode : public CostModelNode {
                                                 const Array<RunnerResult>&)>;
   /*!
    * \brief Predict the running results of given measure candidates.
-   * \param tune_context The tuning context.
+   * \param context The tuning context.
    * \param candidates The measure candidates.
   * \param p_addr The address to save the estimated running results.
    */
@@ -135,17 +135,17 @@ class PyCostModelNode : public CostModelNode {
     ICHECK(f_save != nullptr) << "PyCostModel's Save method not implemented!";
     f_save(path);
   }
-  void Update(const TuneContext& tune_context, const Array<MeasureCandidate>& candidates,
+  void Update(const TuneContext& context, const Array<MeasureCandidate>& candidates,
               const Array<RunnerResult>& results) {
     ICHECK(f_update != nullptr) << "PyCostModel's Update method not implemented!";
-    f_update(tune_context, candidates, results);
+    f_update(context, candidates, results);
   }
 
-  std::vector<double> Predict(const TuneContext& tune_context,
+  std::vector<double> Predict(const TuneContext& context,
                               const Array<MeasureCandidate>& candidates) {
     ICHECK(f_predict != nullptr) << "PyCostModel's Predict method not implemented!";
     std::vector<double> result(candidates.size(), 0.0);
-    f_predict(tune_context, candidates, result.data());
+    f_predict(context, candidates, result.data());
     return result;
   }
 
diff --git a/include/tvm/meta_schedule/feature_extractor.h b/include/tvm/meta_schedule/feature_extractor.h
index ee5d94c..c2ca2be 100644
--- a/include/tvm/meta_schedule/feature_extractor.h
+++ b/include/tvm/meta_schedule/feature_extractor.h
@@ -37,11 +37,11 @@ class FeatureExtractorNode : public runtime::Object {
 
   /*!
    * \brief Extract features from the given measure candidate.
-   * \param tune_context The tuning context for feature extraction.
+   * \param context The tuning context for feature extraction.
    * \param candidates The measure candidates to extract features from.
    * \return The feature ndarray extracted.
    */
-  virtual Array<tvm::runtime::NDArray> ExtractFrom(const TuneContext& tune_context,
+  virtual Array<tvm::runtime::NDArray> ExtractFrom(const TuneContext& context,
                                                    const Array<MeasureCandidate>& candidates) = 0;
 
   static constexpr const char* _type_key = "meta_schedule.FeatureExtractor";
@@ -53,12 +53,12 @@ class PyFeatureExtractorNode : public FeatureExtractorNode {
  public:
   /*!
    * \brief Extract features from the given measure candidate.
-   * \param tune_context The tuning context for feature extraction.
+   * \param context The tuning context for feature extraction.
    * \param candidates The measure candidates to extract features from.
    * \return The feature ndarray extracted.
    */
   using FExtractFrom = runtime::TypedPackedFunc<Array<tvm::runtime::NDArray>(
-      const TuneContext& tune_context, const Array<MeasureCandidate>& candidates)>;
+      const TuneContext& context, const Array<MeasureCandidate>& candidates)>;
   /*!
    * \brief Get the feature extractor as string with name.
    * \return The string of the feature extractor.
@@ -75,10 +75,10 @@ class PyFeatureExtractorNode : public FeatureExtractorNode {
     // `f_as_string` is not visited
   }
 
-  Array<tvm::runtime::NDArray> ExtractFrom(const TuneContext& tune_context,
+  Array<tvm::runtime::NDArray> ExtractFrom(const TuneContext& context,
                                            const Array<MeasureCandidate>& candidates) {
     ICHECK(f_extract_from != nullptr) << "PyFeatureExtractor's ExtractFrom method not implemented!";
-    return f_extract_from(tune_context, candidates);
+    return f_extract_from(context, candidates);
   }
 
   static constexpr const char* _type_key = "meta_schedule.PyFeatureExtractor";
diff --git a/include/tvm/meta_schedule/mutator.h b/include/tvm/meta_schedule/mutator.h
new file mode 100644
index 0000000..e3fa847
--- /dev/null
+++ b/include/tvm/meta_schedule/mutator.h
@@ -0,0 +1,146 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#ifndef TVM_META_SCHEDULE_MUTATOR_H_
+#define TVM_META_SCHEDULE_MUTATOR_H_
+
+#include <tvm/tir/schedule/schedule.h>
+
+namespace tvm {
+namespace meta_schedule {
+
+class TuneContext;
+
+/*! \brief Mutator is designed to mutate the trace to explore the design space. */
+class MutatorNode : public runtime::Object {
+ public:
+  /*! \brief Virtual destructor. */
+  virtual ~MutatorNode() = default;
+
+  void VisitAttrs(tvm::AttrVisitor* v) {}
+
+  /*!
+   * \brief Initialize the mutator with the tuning context.
+   * \param context The tuning context for initialization.
+   * \note This method is supposed to be called only once before every other method.
+   */
+  virtual void InitializeWithTuneContext(const TuneContext& context) = 0;
+
+  /*!
+   * \brief Apply the mutator function to the given trace.
+   * \param trace The given trace for mutation.
+   * \param rand_state The random state for mutation.
+   * \return None if mutator failed, otherwise return the mutated trace.
+   */
+  virtual Optional<tir::Trace> Apply(const tir::Trace& trace,
+                                     support::LinearCongruentialEngine::TRandState* rand_state) = 0;
+
+  static constexpr const char* _type_key = "meta_schedule.Mutator";
+  TVM_DECLARE_BASE_OBJECT_INFO(MutatorNode, Object);
+};
+
+/*! \brief The mutator with customized methods on the python-side. */
+class PyMutatorNode : public MutatorNode {
+ public:
+  /*!
+   * \brief The function type of `InitializeWithTuneContext` method.
+   * \param context The tuning context for initialization.
+   */
+  using FInitializeWithTuneContext = runtime::TypedPackedFunc<void(const TuneContext&)>;
+  /*!
+   * \brief Apply the mutator function to the given trace.
+   * \param trace The given trace for mutation.
+   * \return None if mutator failed, otherwise return the mutated trace.
+   */
+  using FApply = runtime::TypedPackedFunc<Optional<tir::Trace>(
+      const tir::Trace&, support::LinearCongruentialEngine::TRandState rand_state)>;
+  /*!
+   * \brief Get the mutator as string with name.
+   * \return The string of the mutator.
+   */
+  using FAsString = runtime::TypedPackedFunc<String()>;
+
+  /*! \brief The packed function to the `InitializeWithTuneContext` function. */
+  FInitializeWithTuneContext f_initialize_with_tune_context;
+  /*! \brief The packed function to the `Apply` function. */
+  FApply f_apply;
+  /*! \brief The packed function to the `AsString` function. */
+  FAsString f_as_string;
+
+  void VisitAttrs(tvm::AttrVisitor* v) {
+    // `f_initialize_with_tune_context` is not visited
+    // `f_apply` is not visited
+    // `f_as_string` is not visited
+  }
+
+  void InitializeWithTuneContext(const TuneContext& context) final {
+    ICHECK(f_initialize_with_tune_context != nullptr)
+        << "PyMutator's InitializeWithTuneContext method not implemented!";
+    this->f_initialize_with_tune_context(context);
+  }
+
+  Optional<tir::Trace> Apply(const tir::Trace& trace,
+                             support::LinearCongruentialEngine::TRandState* rand_state) final {
+    ICHECK(f_apply != nullptr) << "PyMutator's Apply method not implemented!";
+    return this->f_apply(trace, *rand_state);
+  }
+
+  static constexpr const char* _type_key = "meta_schedule.PyMutator";
+  TVM_DECLARE_FINAL_OBJECT_INFO(PyMutatorNode, MutatorNode);
+};
+
+/*!
+ * \brief Managed reference to MutatorNode
+ * \sa MutatorNode
+ */
+class Mutator : public runtime::ObjectRef {
+ public:
+  /*! \brief Create a Mutator that mutates the tile size. */
+  TVM_DLL static Mutator MutateTileSize();
+  /*!
+   * \brief Create a Mutator that mutates the parallel extent
+   * \param max_jobs_per_core The maximum number of parallel jobs per core.
+   * \return The created mutator.
+   */
+  TVM_DLL static Mutator MutateParallel(int64_t max_jobs_per_core);
+  /*! \brief Create a Mutator that mutates the auto unroll step. */
+  TVM_DLL static Mutator MutateUnroll();
+  /*!
+   * \brief Create a Mutator that mutates the outcome of SampleComputeLocation
+   * \return The mutator created
+   */
+  TVM_DLL static Mutator MutateComputeLocation();
+  /*!
+   * \brief Create a mutator with customized methods on the python-side.
+   * \param f_initialize_with_tune_context The packed function of `InitializeWithTuneContext`.
+   * \param f_apply The packed function of `Apply`.
+   * \param f_as_string The packed function of `AsString`.
+   * \return The mutator created.
+   */
+  TVM_DLL static Mutator PyMutator(
+      PyMutatorNode::FInitializeWithTuneContext f_initialize_with_tune_context,  //
+      PyMutatorNode::FApply f_apply,                                             //
+      PyMutatorNode::FAsString f_as_string);
+  TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(Mutator, ObjectRef, MutatorNode);
+};
+
+}  // namespace meta_schedule
+}  // namespace tvm
+
+#endif  // TVM_META_SCHEDULE_MUTATOR_H_
diff --git a/include/tvm/meta_schedule/postproc.h b/include/tvm/meta_schedule/postproc.h
new file mode 100644
index 0000000..93e8be0
--- /dev/null
+++ b/include/tvm/meta_schedule/postproc.h
@@ -0,0 +1,167 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#ifndef TVM_META_SCHEDULE_POSTPROC_H_
+#define TVM_META_SCHEDULE_POSTPROC_H_
+
+#include <tvm/tir/schedule/schedule.h>
+
+namespace tvm {
+namespace meta_schedule {
+
+class TuneContext;
+
+/*!
+ * \brief Rules to apply a postprocessor to a schedule.
+ */
+class PostprocNode : public runtime::Object {
+ public:
+  /*! \brief Virtual destructor. */
+  virtual ~PostprocNode() = default;
+
+  void VisitAttrs(tvm::AttrVisitor* v) {}
+
+  /*!
+   * \brief Initialize the postprocessor with the tuning context.
+   * \param context The tuning context for initialization.
+   * \note This method is supposed to be called only once before every other method.
+   */
+  virtual void InitializeWithTuneContext(const TuneContext& context) = 0;
+
+  /*!
+   * \brief Apply a postprocessor to the given schedule.
+   * \param sch The schedule to be post processed.
+   * \return Whether the postprocessor was successfully applied.
+   */
+  virtual bool Apply(const tir::Schedule& sch) = 0;
+
+  static constexpr const char* _type_key = "meta_schedule.Postproc";
+  TVM_DECLARE_BASE_OBJECT_INFO(PostprocNode, Object);
+};
+
+/*! \brief The postprocessor with customized methods on the python-side. */
+class PyPostprocNode : public PostprocNode {
+ public:
+  /*!
+   * \brief The function type of `InitializeWithTuneContext` method.
+   * \param context The tuning context for initialization.
+   */
+  using FInitializeWithTuneContext = runtime::TypedPackedFunc<void(const TuneContext&)>;
+  /*!
+   * \brief Apply a postprocessor to the given schedule.
+   * \param sch The schedule to be post processed.
+   * \return Whether the postprocessor was successfully applied.
+   */
+  using FApply = runtime::TypedPackedFunc<bool(const tir::Schedule&)>;
+  /*!
+   * \brief Get the postprocessor function as string with name.
+   * \return The string of the postprocessor function.
+   */
+  using FAsString = runtime::TypedPackedFunc<String()>;
+
+  /*! \brief The packed function to the `InitializeWithTuneContext` function. */
+  FInitializeWithTuneContext f_initialize_with_tune_context;
+  /*! \brief The packed function to the `Apply` function. */
+  FApply f_apply;
+  /*! \brief The packed function to the `AsString` function. */
+  FAsString f_as_string;
+
+  void VisitAttrs(tvm::AttrVisitor* v) {
+    // `f_initialize_with_tune_context` is not visited
+    // `f_apply` is not visited
+    // `f_as_string` is not visited
+  }
+
+  void InitializeWithTuneContext(const TuneContext& context) final {
+    ICHECK(f_initialize_with_tune_context != nullptr)
+        << "PyPostproc's InitializeWithTuneContext method not implemented!";
+    this->f_initialize_with_tune_context(context);
+  }
+
+  bool Apply(const tir::Schedule& sch) final {
+    ICHECK(f_apply != nullptr) << "PyPostproc's Apply method not implemented!";
+    return this->f_apply(sch);
+  }
+
+  static constexpr const char* _type_key = "meta_schedule.PyPostproc";
+  TVM_DECLARE_FINAL_OBJECT_INFO(PyPostprocNode, PostprocNode);
+};
+
+/*!
+ * \brief Managed reference to PostprocNode
+ * \sa PostprocNode
+ */
+class Postproc : public runtime::ObjectRef {
+ public:
+  /*!
+   * \brief Create a postprocessor with customized methods on the python-side.
+   * \param f_initialize_with_tune_context The packed function of `InitializeWithTuneContext`.
+   * \param f_apply The packed function of `Apply`.
+   * \param f_as_string The packed function of `AsString`.
+   * \return The postprocessor created.
+   */
+  TVM_DLL static Postproc PyPostproc(
+      PyPostprocNode::FInitializeWithTuneContext f_initialize_with_tune_context,  //
+      PyPostprocNode::FApply f_apply,                                             //
+      PyPostprocNode::FAsString f_as_string);
+  /*!
+   * \brief Create a postprocessor that checks if all loops are static
+   * \return The postprocessor created
+   */
+  TVM_DLL static Postproc DisallowDynamicLoop();
+  /*!
+   * \brief Create a postprocessor that rewrites the cooperative fetch annotation to
+   * actual vectorized cooperative fetching in loop bindings.
+   * \return The postprocessor created.
+   */
+  TVM_DLL static Postproc RewriteCooperativeFetch();
+  /*!
+   * \brief Creates a postprocessor that applies parallelization, vectorization and auto unrolling
+   * according to the annotation of each block
+   * \return The postprocessor created
+   */
+  TVM_DLL static Postproc RewriteParallelVectorizeUnroll();
+  /*!
+   * \brief Create a postprocessor that rewrites reduction block by moving the init block out.
+   * \return The postprocessor created.
+   */
+  TVM_DLL static Postproc RewriteReductionBlock();
+  /*!
+   * \brief Create a postprocessor that adds thread binding to unbound blocks
+   * \return The postprocessor created.
+   */
+  TVM_DLL static Postproc RewriteUnboundBlock();
+  /*!
+   * \brief Create a postprocessor that tensorizes Tensor Core related components.
+   * \return The postprocessor created.
+   */
+  TVM_DLL static Postproc RewriteTensorCore();
+
+  /*!
+   * \brief Creates a postprocessor that verifies if the GPU code is correct
+   * \return The postprocessor created
+   */
+  TVM_DLL static Postproc VerifyGPUCode();
+  TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(Postproc, ObjectRef, PostprocNode);
+};
+
+}  // namespace meta_schedule
+}  // namespace tvm
+
+#endif  // TVM_META_SCHEDULE_POSTPROC_H_
diff --git a/include/tvm/meta_schedule/search_strategy.h b/include/tvm/meta_schedule/search_strategy.h
index e1c68c8..e8d2cbd 100644
--- a/include/tvm/meta_schedule/search_strategy.h
+++ b/include/tvm/meta_schedule/search_strategy.h
@@ -28,6 +28,8 @@ namespace meta_schedule {
 
 // Forward declaration
 class TuneContext;
+class CostModel;
+class Database;
 
 /*! \brief The schedule (with input shapes) to be measured. */
 class MeasureCandidateNode : public runtime::Object {
@@ -133,9 +135,13 @@ class SearchStrategyNode : public runtime::Object {
 
   /*!
    * \brief Update the search strategy with measurement results.
+   * \param context The tuning context.
+   * \param measure_candidates The candidates to be measured.
    * \param results The measurement results from the runner.
    */
-  virtual void NotifyRunnerResults(const Array<RunnerResult>& results) = 0;
+  virtual void NotifyRunnerResults(const TuneContext& context,
+                                   const Array<MeasureCandidate>& measure_candidates,
+                                   const Array<RunnerResult>& results) = 0;
 
   static constexpr const char* _type_key = "meta_schedule.SearchStrategy";
   TVM_DECLARE_BASE_OBJECT_INFO(SearchStrategyNode, Object);
@@ -165,7 +171,8 @@ class PySearchStrategyNode : public SearchStrategyNode {
    * \brief The function type of `NotifyRunnerResults` method.
   * \param context The tuning context.
   * \param measure_candidates The measure candidates for update.
   * \param results The measurement results from the runner.
    */
-  using FNotifyRunnerResults = runtime::TypedPackedFunc<void(const Array<RunnerResult>&)>;
+  using FNotifyRunnerResults = runtime::TypedPackedFunc<void(
+      const TuneContext&, const Array<MeasureCandidate>&, const Array<RunnerResult>&)>;
 
   /*! \brief The packed function to the `InitializeWithTuneContext` method. */
   FInitializeWithTuneContext f_initialize_with_tune_context;
@@ -208,10 +215,12 @@ class PySearchStrategyNode : public SearchStrategyNode {
     return this->f_generate_measure_candidates();
   }
 
-  void NotifyRunnerResults(const Array<RunnerResult>& results) final {
+  void NotifyRunnerResults(const TuneContext& context,
+                           const Array<MeasureCandidate>& measure_candidates,
+                           const Array<RunnerResult>& results) final {
     ICHECK(f_notify_runner_results != nullptr)
         << "PySearchStrategy's NotifyRunnerResults method not implemented!";
-    this->f_notify_runner_results(results);
+    this->f_notify_runner_results(context, measure_candidates, results);
   }
 
   static constexpr const char* _type_key = "meta_schedule.PySearchStrategy";
@@ -247,6 +256,35 @@ class SearchStrategy : public runtime::ObjectRef {
    */
   TVM_DLL static SearchStrategy ReplayTrace(int num_trials_per_iter, int num_trials_total);
 
+  /*!
+   * \brief Constructor of replay func search strategy.
+   * \param num_trials_per_iter The number of trials per iteration, i.e., the batch size.
+   * \param num_trials_total The total number of trials for func replaying.
+   */
+  TVM_DLL static SearchStrategy ReplayFunc(int num_trials_per_iter, int num_trials_total);
+
+  /*!
+   * \brief Constructor of evolutionary search strategy.
+   * \param num_trials_per_iter The number of trials per iteration, i.e., the batch size.
+   * \param num_trials_total The total number of trials for evolutionary search.
+   * \param population_size The initial population size of the evolutionary search.
+   * \param init_measured_ratio The ratio of measured samples in the initial population.
+   * \param init_max_fail_count The maximum number of failed attempts allowed when replaying traces.
+   * \param genetic_num_iters The number of iterations to run the genetic algorithm.
+   * \param genetic_mutate_prob The probability of mutation.
+   * \param genetic_max_fail_count The maximum number of attempts to evolve a given trace.
+   * \param eps_greedy The ratio of samples selected greedily by their predicted score.
+   */
+  TVM_DLL static SearchStrategy EvolutionarySearch(int num_trials_per_iter,     //
+                                                   int num_trials_total,        //
+                                                   int population_size,         //
+                                                   double init_measured_ratio,  //
+                                                   int init_max_fail_count,     //
+                                                   int genetic_num_iters,       //
+                                                   double genetic_mutate_prob,  //
+                                                   int genetic_max_fail_count,  //
+                                                   double eps_greedy);
+
   TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(SearchStrategy, ObjectRef, SearchStrategyNode);
 };
 
diff --git a/include/tvm/meta_schedule/space_generator.h b/include/tvm/meta_schedule/space_generator.h
index 7aff683..3611870 100644
--- a/include/tvm/meta_schedule/space_generator.h
+++ b/include/tvm/meta_schedule/space_generator.h
@@ -139,13 +139,13 @@ class SpaceGenerator : public ObjectRef {
  public:
   /*!
    * \brief Create a design space generator with customized methods on the python-side.
-   * \param initialize_with_tune_context_func The packed function of `InitializeWithTuneContext`.
-   * \param generate_design_space_func The packed function of `GenerateDesignSpace`.
+   * \param f_initialize_with_tune_context The packed function of `InitializeWithTuneContext`.
+   * \param f_generate_design_space The packed function of `GenerateDesignSpace`.
    * \return The design space generator created.
    */
   TVM_DLL static SpaceGenerator PySpaceGenerator(
-      PySpaceGeneratorNode::FInitializeWithTuneContext initialize_with_tune_context_func,
-      PySpaceGeneratorNode::FGenerateDesignSpace generate_design_space_func);
+      PySpaceGeneratorNode::FInitializeWithTuneContext f_initialize_with_tune_context,
+      PySpaceGeneratorNode::FGenerateDesignSpace f_generate_design_space);
 
   /*!
    * \brief Create a design space generator that is union of multiple design space generators.
diff --git a/include/tvm/meta_schedule/task_scheduler.h b/include/tvm/meta_schedule/task_scheduler.h
index f28c33d..ddd6f4c 100644
--- a/include/tvm/meta_schedule/task_scheduler.h
+++ b/include/tvm/meta_schedule/task_scheduler.h
@@ -20,7 +20,9 @@
 #define TVM_META_SCHEDULE_TASK_SCHEDULER_H_
 
 #include <tvm/meta_schedule/builder.h>
+#include <tvm/meta_schedule/cost_model.h>
 #include <tvm/meta_schedule/database.h>
+#include <tvm/meta_schedule/measure_callback.h>
 #include <tvm/meta_schedule/runner.h>
 #include <tvm/meta_schedule/tune_context.h>
 
@@ -78,7 +80,7 @@ class TaskSchedulerNode : public runtime::Object {
   /*! \brief The list of measure callbacks of the scheduler. */
   Array<MeasureCallback> measure_callbacks;
 
-  /*! \brief The default desctructor. */
+  /*! \brief The default destructor. */
   virtual ~TaskSchedulerNode() = default;
 
   void VisitAttrs(tvm::AttrVisitor* v) {
@@ -247,16 +249,39 @@ class TaskScheduler : public runtime::ObjectRef {
    * \param builder The builder of the scheduler.
    * \param runner The runner of the scheduler.
    * \param database The database of the scheduler.
+   * \param cost_model The cost model of the scheduler.
+   * \param measure_callbacks The measure callbacks of the scheduler.
+   * \return The task scheduler created.
+   */
+  TVM_DLL static TaskScheduler RoundRobin(Array<TuneContext> tasks,        //
+                                          Builder builder,                 //
+                                          Runner runner,                   //
+                                          Database database,               //
+                                          Optional<CostModel> cost_model,  //
+                                          Optional<Array<MeasureCallback>> measure_callbacks);
+  /*!
+   * \brief Create a task scheduler with customized methods on the python-side.
+   * \param tasks The tasks to be tuned.
+   * \param builder The builder of the scheduler.
+   * \param runner The runner of the scheduler.
+   * \param database The database of the scheduler.
+   * \param cost_model The cost model of the scheduler.
+   * \param measure_callbacks The measure callbacks of the scheduler.
+   * \param f_tune The packed function of `Tune`.
+   * \param f_initialize_task The packed function of `InitializeTask`.
+   * \param f_set_task_stopped The packed function of `SetTaskStopped`.
+   * \param f_is_task_running The packed function of `IsTaskRunning`.
+   * \param f_join_running_task The packed function of `JoinRunningTask`.
+   * \param f_next_task_id The packed function of `NextTaskId`.
+   * \return The task scheduler created.
    */
-  TVM_DLL static TaskScheduler RoundRobin(Array<TuneContext> tasks,  //
-                                          Builder builder,           //
-                                          Runner runner,             //
-                                          Database database);        //
   TVM_DLL static TaskScheduler PyTaskScheduler(
       Array<TuneContext> tasks,                                   //
       Builder builder,                                            //
       Runner runner,                                              //
       Database database,                                          //
+      Optional<CostModel> cost_model,                             //
+      Optional<Array<MeasureCallback>> measure_callbacks,         //
       PyTaskSchedulerNode::FTune f_tune,                          //
       PyTaskSchedulerNode::FInitializeTask f_initialize_task,     //
       PyTaskSchedulerNode::FSetTaskStopped f_set_task_stopped,    //
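
For readers tracking this interface change: RoundRobin now threads an optional
cost model and optional measure callbacks through to the scheduler. Below is a
minimal sketch, assuming the Python binding in
python/tvm/meta_schedule/task_scheduler/round_robin.py mirrors the C++
signature above; tasks, builder, runner, and database are placeholders the
caller must supply.

    # Hedged sketch: assumes RoundRobin's Python constructor mirrors the
    # C++ signature above; all arguments are caller-supplied placeholders.
    from tvm.meta_schedule.task_scheduler import RoundRobin

    def make_round_robin(tasks, builder, runner, database, cost_model=None):
        return RoundRobin(
            tasks,
            builder,
            runner,
            database,
            cost_model=cost_model,   # new optional argument
            measure_callbacks=None,  # new optional argument
        )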
diff --git a/include/tvm/meta_schedule/tune_context.h b/include/tvm/meta_schedule/tune_context.h
index 6eacd4d..428a2e8 100644
--- a/include/tvm/meta_schedule/tune_context.h
+++ b/include/tvm/meta_schedule/tune_context.h
@@ -20,7 +20,12 @@
 #define TVM_META_SCHEDULE_TUNE_CONTEXT_H_
 
 #include <tvm/ir/module.h>
+#include <tvm/meta_schedule/builder.h>
+#include <tvm/meta_schedule/mutator.h>
+#include <tvm/meta_schedule/postproc.h>
+#include <tvm/meta_schedule/runner.h>
 #include <tvm/meta_schedule/schedule_rule.h>
+#include <tvm/meta_schedule/search_strategy.h>
 #include <tvm/meta_schedule/space_generator.h>
 #include <tvm/support/random_engine.h>
 #include <tvm/target/target.h>
@@ -28,6 +33,8 @@
 namespace tvm {
 namespace meta_schedule {
 
+class TaskSchedulerNode;
+
 /*! \brief The auto tuning context. */
 class TuneContextNode : public runtime::Object {
  public:
@@ -41,6 +48,10 @@ class TuneContextNode : public runtime::Object {
   Optional<SearchStrategy> search_strategy;
   /*! \brief The schedule rules. */
   Array<ScheduleRule> sch_rules;
+  /*! \brief The postprocessors. */
+  Array<Postproc> postprocs;
+  /*! \brief The probability of using each mutator. */
+  Map<Mutator, FloatImm> mutator_probs;
   /*! \brief The name of the tuning task. */
   Optional<String> task_name;
   /*! \brief The random state. */
@@ -48,12 +59,16 @@ class TuneContextNode : public runtime::Object {
   /*! \brief The number of threads to be used. */
   int num_threads;
 
+  /*! \brief The task scheduler that owns the tune context. */
+  const TaskSchedulerNode* task_scheduler;
   /*! \brief Whether the tuning task has been stopped or finished. */
   bool is_stopped;
-  /*! \brief Packed functions to fetch the runner results asynchronously. */
-  Optional<Array<RunnerFuture>> runner_futures;
   /*! \brief The measure candidates. */
   Optional<Array<MeasureCandidate>> measure_candidates;
+  /*! \brief The building results. */
+  Optional<Array<BuilderResult>> builder_results;
+  /*! \brief Packed functions to fetch the runner results asynchronously. */
+  Optional<Array<RunnerFuture>> runner_futures;
 
   void VisitAttrs(tvm::AttrVisitor* v) {
     v->Visit("mod", &mod);
@@ -61,14 +76,18 @@ class TuneContextNode : public runtime::Object {
     v->Visit("space_generator", &space_generator);
     v->Visit("search_strategy", &search_strategy);
     v->Visit("sch_rules", &sch_rules);
+    v->Visit("postprocs", &postprocs);
+    v->Visit("mutator_probs", &mutator_probs);
     v->Visit("task_name", &task_name);
     v->Visit("rand_state", &rand_state);
     v->Visit("num_threads", &num_threads);
     v->Visit("is_stopped", &is_stopped);
-    v->Visit("runner_futures", &runner_futures);
     v->Visit("measure_candidates", &measure_candidates);
   }
 
+  /*! \brief Initialize members that need initialization with the tune context. */
+  void Initialize();
+
   static constexpr const char* _type_key = "meta_schedule.TuneContext";
   TVM_DECLARE_FINAL_OBJECT_INFO(TuneContextNode, Object);
 };
@@ -86,6 +105,8 @@ class TuneContext : public runtime::ObjectRef {
    * \param space_generator The design space generator.
    * \param search_strategy The search strategy.
    * \param sch_rules The schedule rules.
+   * \param postprocs The postprocessors.
+   * \param mutator_probs The probability of using each mutator.
    * \param task_name The name of the tuning task.
    * \param rand_state The random state.
    * \param num_threads The number of threads to be used.
@@ -95,6 +116,8 @@ class TuneContext : public runtime::ObjectRef {
                                Optional<SpaceGenerator> space_generator,                  //
                                Optional<SearchStrategy> search_strategy,                  //
                                Optional<Array<ScheduleRule>> sch_rules,                   //
+                               Optional<Array<Postproc>> postprocs,                       //
+                               Optional<Map<Mutator, FloatImm>> mutator_probs,            //
                                Optional<String> task_name,                                //
                                support::LinearCongruentialEngine::TRandState rand_state,  //
                                int num_threads);
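
TuneContext now carries postprocessors and per-mutator probabilities alongside
the existing fields. A hedged sketch of constructing one through the Python
binding, assuming it mirrors the C++ constructor above; mod, rules, posts, and
probs are placeholders from the caller.

    # Hedged sketch: assumes the Python TuneContext constructor mirrors the
    # C++ signature above; all inputs are caller-supplied placeholders.
    from tvm.meta_schedule.tune_context import TuneContext
    from tvm.target import Target

    def make_context(mod, rules, posts, probs):
        return TuneContext(
            mod=mod,
            target=Target("llvm"),
            sch_rules=rules,      # schedule rules
            postprocs=posts,      # new: postprocessors
            mutator_probs=probs,  # new: Mutator -> probability map
            task_name="example",
            rand_state=-1,        # sentinel: replaced by a random seed
            num_threads=1,
        )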
diff --git a/include/tvm/support/random_engine.h b/include/tvm/support/random_engine.h
index fcd2326..89b1e91 100644
--- a/include/tvm/support/random_engine.h
+++ b/include/tvm/support/random_engine.h
@@ -29,6 +29,7 @@
 #include <tvm/runtime/logging.h>
 
 #include <cstdint>  // for uint64_t
+#include <random>
 
 namespace tvm {
 namespace support {
@@ -74,6 +75,12 @@ class LinearCongruentialEngine {
   static constexpr result_type max() { return modulus - 1; }
 
   /*!
+   * \brief Get a device random state
+   * \return The random state
+   */
+  static TRandState DeviceRandom() { return (std::random_device()()) % modulus; }
+
+  /*!
    * \brief Operator to move the random state to the next and return the new random state. According
    *  to definition of linear congruential engine, the new random state value is computed as
    *  new_random_state = (current_random_state * multiplier + increment) % modulus.
@@ -93,6 +100,7 @@ class LinearCongruentialEngine {
    * \param rand_state The random state given in result_type.
    */
   void Seed(TRandState rand_state = 1) {
+    ICHECK(rand_state != -1) << "The seed can't be -1; it should have been converted to a random seed!";
     rand_state %= modulus;  // Make sure the seed is within the range of modulus.
     if (rand_state == 0)
       rand_state = 1;  // Avoid getting all 0 given the current parameter set.
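
The state transition documented above is the textbook linear congruential
update, new_state = (state * multiplier + increment) % modulus. A toy Python
rendering with illustrative MINSTD-style constants, not necessarily the
constants TVM uses:

    # Toy model of the linear congruential update; the constants are the
    # classic MINSTD parameters, chosen for illustration only.
    MODULUS = 2**31 - 1
    MULTIPLIER = 48271
    INCREMENT = 0

    def next_state(state: int) -> int:
        return (state * MULTIPLIER + INCREMENT) % MODULUS

    state = 1
    for _ in range(3):
        state = next_state(state)
        print(state)  # deterministic sequence for a fixed seed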
diff --git a/python/tvm/meta_schedule/cost_model/cost_model.py b/python/tvm/meta_schedule/cost_model/cost_model.py
index f5bd601..f794b11 100644
--- a/python/tvm/meta_schedule/cost_model/cost_model.py
+++ b/python/tvm/meta_schedule/cost_model/cost_model.py
@@ -55,7 +55,7 @@ class CostModel(Object):
 
     def update(
         self,
-        tune_context: TuneContext,
+        context: TuneContext,
         candidates: List[MeasureCandidate],
         results: List[RunnerResult],
     ) -> None:
@@ -63,21 +63,21 @@ class CostModel(Object):
 
         Parameters
         ----------
-        tune_context : TuneContext,
+        context : TuneContext,
             The tuning context.
         candidates : List[MeasureCandidate]
             The measure candidates.
         results : List[RunnerResult]
             The running results of the measure candidates.
         """
-        _ffi_api.CostModelUpdate(self, tune_context, candidates, results)  # type: ignore # pylint: disable=no-member
+        _ffi_api.CostModelUpdate(self, context, candidates, results)  # type: ignore # pylint: disable=no-member
 
-    def predict(self, tune_context: TuneContext, candidates: List[MeasureCandidate]) -> np.ndarray:
+    def predict(self, context: TuneContext, candidates: List[MeasureCandidate]) -> np.ndarray:
         """Update the cost model given running results.
 
         Parameters
         ----------
-        tune_context : TuneContext,
+        context : TuneContext,
             The tuning context.
         candidates : List[MeasureCandidate]
             The measure candidates.
@@ -91,7 +91,7 @@ class CostModel(Object):
         results = np.zeros(shape=(n,), dtype="float64")
         _ffi_api.CostModelPredict(  # type: ignore # pylint: disable=no-member
             self,
-            tune_context,
+            context,
             candidates,
             results.ctypes.data_as(ctypes.c_void_p),
         )
@@ -115,20 +115,18 @@ class PyCostModel(CostModel):
 
         @check_override(self.__class__, CostModel)
         def f_update(
-            tune_context: TuneContext,
+            context: TuneContext,
             candidates: List[MeasureCandidate],
             results: List[RunnerResult],
         ) -> None:
-            self.update(tune_context, candidates, results)
+            self.update(context, candidates, results)
 
         @check_override(self.__class__, CostModel)
-        def f_predict(
-            tune_context: TuneContext, candidates: List[MeasureCandidate], return_ptr
-        ) -> None:
+        def f_predict(context: TuneContext, candidates: List[MeasureCandidate], return_ptr) -> None:
             n = len(candidates)
             return_ptr = ctypes.cast(return_ptr, ctypes.POINTER(ctypes.c_double))
             array_wrapper = np.ctypeslib.as_array(return_ptr, shape=(n,))
-            array_wrapper[:] = self.predict(tune_context, candidates)
+            array_wrapper[:] = self.predict(context, candidates)
             assert (
                 array_wrapper.dtype == "float64"
             ), "ValueError: Invalid data type returned from CostModel Predict!"
diff --git a/python/tvm/meta_schedule/cost_model/random_model.py b/python/tvm/meta_schedule/cost_model/random_model.py
index 23238d2..8808476 100644
--- a/python/tvm/meta_schedule/cost_model/random_model.py
+++ b/python/tvm/meta_schedule/cost_model/random_model.py
@@ -84,7 +84,7 @@ class RandomModel(PyCostModel):
 
     def update(
         self,
-        tune_context: TuneContext,
+        context: TuneContext,
         candidates: List[MeasureCandidate],
         results: List[RunnerResult],
     ) -> None:
@@ -92,7 +92,7 @@ class RandomModel(PyCostModel):
 
         Parameters
         ----------
-        tune_context : TuneContext,
+        context : TuneContext,
             The tuning context.
         candidates : List[MeasureCandidate]
             The measure candidates.
@@ -100,12 +100,12 @@ class RandomModel(PyCostModel):
             The running results of the measure candidates.
         """
 
-    def predict(self, tune_context: TuneContext, candidates: List[MeasureCandidate]) -> np.ndarray:
+    def predict(self, context: TuneContext, candidates: List[MeasureCandidate]) -> np.ndarray:
         """Update the cost model given running results.
 
         Parameters
         ----------
-        tune_context : TuneContext,
+        context : TuneContext,
             The tuning context.
         candidates : List[MeasureCandidate]
             The measure candidates.
diff --git a/python/tvm/meta_schedule/feature_extractor/feature_extractor.py b/python/tvm/meta_schedule/feature_extractor/feature_extractor.py
index bd7656e..5043d4b 100644
--- a/python/tvm/meta_schedule/feature_extractor/feature_extractor.py
+++ b/python/tvm/meta_schedule/feature_extractor/feature_extractor.py
@@ -32,13 +32,13 @@ class FeatureExtractor(Object):
     """Extractor for features from measure candidates for use in cost model."""
 
     def extract_from(
-        self, tune_context: TuneContext, candidates: List[MeasureCandidate]
+        self, context: TuneContext, candidates: List[MeasureCandidate]
     ) -> List[NDArray]:
         """Extract features from the given measure candidate.
 
         Parameters
         ----------
-        tune_context : TuneContext
+        context : TuneContext
             The tuning context for feature extraction.
         candidates : List[MeasureCandidate]
             The measure candidates to extract features from.
@@ -49,7 +49,7 @@ class FeatureExtractor(Object):
             The feature numpy ndarray extracted.
         """
         result = _ffi_api.FeatureExtractorExtractFrom(  # type: ignore # pylint: disable=no-member
-            self, tune_context, candidates
+            self, context, candidates
         )
         return result
 
@@ -63,9 +63,9 @@ class PyFeatureExtractor(FeatureExtractor):
 
         @check_override(self.__class__, FeatureExtractor)
         def f_extract_from(
-            tune_context: TuneContext, candidates: List[MeasureCandidate]
+            context: TuneContext, candidates: List[MeasureCandidate]
         ) -> List[NDArray]:
-            features = self.extract_from(tune_context, candidates)
+            features = self.extract_from(context, candidates)
             return features
 
         def f_as_string() -> str:
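
The analogous sketch for a PyFeatureExtractor subclass under the renamed
parameter; it emits one fixed-width feature vector per candidate, wrapped in
tvm.nd.array since the interface returns NDArray features.

    # Minimal sketch of a PyFeatureExtractor subclass using the renamed
    # `context` parameter.
    import numpy as np
    import tvm
    from tvm.meta_schedule.feature_extractor import PyFeatureExtractor

    class ConstantFeatureExtractor(PyFeatureExtractor):
        def extract_from(self, context, candidates):
            features = np.ones((4,), dtype="float32")
            return [tvm.nd.array(features) for _ in candidates]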
diff --git a/python/tvm/meta_schedule/feature_extractor/random_feature_extractor.py b/python/tvm/meta_schedule/feature_extractor/random_feature_extractor.py
index 7c72a25..d805648 100644
--- a/python/tvm/meta_schedule/feature_extractor/random_feature_extractor.py
+++ b/python/tvm/meta_schedule/feature_extractor/random_feature_extractor.py
@@ -51,7 +51,7 @@ class RandomFeatureExtractor(PyFeatureExtractor):
         self.random_state = np.random.get_state()
 
     def extract_from(
-        self, tune_context: TuneContext, candidates: List[MeasureCandidate]
+        self, context: TuneContext, candidates: List[MeasureCandidate]
     ) -> List[NDArray]:
         np.random.set_state(self.random_state)
         result = [
diff --git a/python/tvm/meta_schedule/mutator/__init__.py b/python/tvm/meta_schedule/mutator/__init__.py
new file mode 100644
index 0000000..f88043b
--- /dev/null
+++ b/python/tvm/meta_schedule/mutator/__init__.py
@@ -0,0 +1,22 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""
+The tvm.meta_schedule.mutator package.
+Meta Schedule mutator that mutates the trace to explore the
+design space.
+"""
+from .mutator import Mutator, PyMutator
diff --git a/python/tvm/meta_schedule/mutator/mutator.py b/python/tvm/meta_schedule/mutator/mutator.py
new file mode 100644
index 0000000..80e0f66
--- /dev/null
+++ b/python/tvm/meta_schedule/mutator/mutator.py
@@ -0,0 +1,88 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Meta Schedule Mutator."""
+from typing import Optional, TYPE_CHECKING
+
+from tvm._ffi import register_object
+from tvm.runtime import Object
+from tvm.tir.schedule import Trace
+
+from .. import _ffi_api
+from ..utils import _get_hex_address, check_override
+
+if TYPE_CHECKING:
+    from ..tune_context import TuneContext
+
+
+class Mutator(Object):
+    """Mutator is designed to mutate the trace to explore the design space."""
+
+    def initialize_with_tune_context(self, context: "TuneContext") -> None:
+        """Initialize the mutator with a tune context.
+
+        Parameters
+        ----------
+        context : TuneContext
+            The tuning context for initializing the mutator.
+        """
+        _ffi_api.MutatorInitializeWithTuneContext(  # type: ignore # pylint: disable=no-member
+            self, context
+        )
+
+    def apply(self, trace: Trace) -> Optional[Trace]:
+        """Apply the mutator function to the given trace.
+
+        Parameters
+        ----------
+        trace : Trace
+            The given trace for mutation.
+
+        Returns
+        -------
+        trace : Optional[Trace]
+            None if mutator failed, otherwise return the mutated trace.
+        """
+        return _ffi_api.MutatorApply(self, trace, -1)  # type: ignore # pylint: disable=no-member
+
+
+@register_object("meta_schedule.PyMutator")
+class PyMutator(Mutator):
+    """An abstract mutator with customized methods on the python-side."""
+
+    def __init__(self):
+        """Constructor."""
+
+        @check_override(self.__class__, Mutator)
+        def f_initialize_with_tune_context(context: "TuneContext") -> None:
+            self.initialize_with_tune_context(context)
+
+        @check_override(self.__class__, Mutator)
+        def f_apply(trace: Trace, _) -> Optional[Trace]:
+            return self.apply(trace)
+
+        def f_as_string() -> str:
+            return str(self)
+
+        self.__init_handle_by_constructor__(
+            _ffi_api.MutatorPyMutator,  # type: ignore # pylint: disable=no-member
+            f_initialize_with_tune_context,
+            f_apply,
+            f_as_string,
+        )
+
+    def __str__(self) -> str:
+        return f"{self.__class__.__name__}({_get_hex_address(self.handle)})"
diff --git a/python/tvm/meta_schedule/postproc/__init__.py b/python/tvm/meta_schedule/postproc/__init__.py
new file mode 100644
index 0000000..6ee052e
--- /dev/null
+++ b/python/tvm/meta_schedule/postproc/__init__.py
@@ -0,0 +1,18 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""The tvm.meta_schedule.postproc package."""
+from .postproc import Postproc, PyPostproc
diff --git a/python/tvm/meta_schedule/postproc/postproc.py b/python/tvm/meta_schedule/postproc/postproc.py
new file mode 100644
index 0000000..5f1180c
--- /dev/null
+++ b/python/tvm/meta_schedule/postproc/postproc.py
@@ -0,0 +1,90 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Meta Schedule Postproc."""
+
+from typing import TYPE_CHECKING
+
+from tvm._ffi import register_object
+from tvm.runtime import Object
+from tvm.tir.schedule import Schedule
+
+from .. import _ffi_api
+from ..utils import _get_hex_address, check_override
+
+if TYPE_CHECKING:
+    from ..tune_context import TuneContext
+
+
+@register_object("meta_schedule.Postproc")
+class Postproc(Object):
+    """Rules to apply a postprocessor to a schedule."""
+
+    def initialize_with_tune_context(self, context: "TuneContext") -> None:
+        """Initialize the postprocessor with a tune context.
+
+        Parameters
+        ----------
+        context : TuneContext
+            The tuning context for initializing the postprocessor.
+        """
+        _ffi_api.PostprocInitializeWithTuneContext(  # type: ignore # pylint: disable=no-member
+            self, context
+        )
+
+    def apply(self, sch: Schedule) -> bool:
+        """Apply a postprocessor to the given schedule.
+
+        Parameters
+        ----------
+        sch : Schedule
+            The schedule to be post processed.
+
+        Returns
+        -------
+        result : bool
+            Whether the postprocessor was successfully applied.
+        """
+        return _ffi_api.PostprocApply(self, sch)  # type: ignore # pylint: disable=no-member
+
+
+@register_object("meta_schedule.PyPostproc")
+class PyPostproc(Postproc):
+    """An abstract Postproc with customized methods on the python-side."""
+
+    def __init__(self):
+        """Constructor."""
+
+        @check_override(self.__class__, Postproc)
+        def f_initialize_with_tune_context(context: "TuneContext") -> None:
+            self.initialize_with_tune_context(context)
+
+        @check_override(self.__class__, Postproc)
+        def f_apply(sch: Schedule) -> bool:
+            return self.apply(sch)
+
+        def f_as_string() -> str:
+            return str(self)
+
+        self.__init_handle_by_constructor__(
+            _ffi_api.PostprocPyPostproc,  # type: ignore # pylint: disable=no-member
+            f_initialize_with_tune_context,
+            f_apply,
+            f_as_string,
+        )
+
+    def __str__(self) -> str:
+        return f"{self.__class__.__name__}({_get_hex_address(self.handle)})"
diff --git a/python/tvm/meta_schedule/schedule_rule/schedule_rule.py b/python/tvm/meta_schedule/schedule_rule/schedule_rule.py
index b995e5a..ab142c0 100644
--- a/python/tvm/meta_schedule/schedule_rule/schedule_rule.py
+++ b/python/tvm/meta_schedule/schedule_rule/schedule_rule.py
@@ -35,16 +35,16 @@ if TYPE_CHECKING:
 class ScheduleRule(Object):
     """Rules to modify a block in a schedule."""
 
-    def initialize_with_tune_context(self, tune_context: "TuneContext") -> None:
+    def initialize_with_tune_context(self, context: "TuneContext") -> None:
         """Initialize the schedule rule with a tune context.
 
         Parameters
         ----------
-        tune_context : TuneContext
+        context : TuneContext
             The tuning context for initializing the schedule rule.
         """
         _ffi_api.ScheduleRuleInitializeWithTuneContext(  # type: ignore # pylint: disable=no-member
-            self, tune_context
+            self, context
         )
 
     def apply(self, sch: Schedule, block: BlockRV) -> List[Schedule]:
@@ -75,8 +75,8 @@ class PyScheduleRule(ScheduleRule):
         """Constructor."""
 
         @check_override(self.__class__, ScheduleRule)
-        def f_initialize_with_tune_context(tune_context: "TuneContext") -> None:
-            self.initialize_with_tune_context(tune_context)
+        def f_initialize_with_tune_context(context: "TuneContext") -> None:
+            self.initialize_with_tune_context(context)
 
         @check_override(self.__class__, ScheduleRule)
         def f_apply(sch: Schedule, block: BlockRV) -> List[Schedule]:
diff --git a/python/tvm/meta_schedule/search_strategy/replay_trace.py b/python/tvm/meta_schedule/search_strategy/replay_trace.py
index 15f8295..5655038 100644
--- a/python/tvm/meta_schedule/search_strategy/replay_trace.py
+++ b/python/tvm/meta_schedule/search_strategy/replay_trace.py
@@ -15,6 +15,7 @@
 # specific language governing permissions and limitations
 # under the License.
 """Replay Trace Search Strategy"""
+from typing import NamedTuple
 
 from tvm._ffi import register_object
 from .search_strategy import SearchStrategy
@@ -41,7 +42,17 @@ class ReplayTrace(SearchStrategy):
     def __init__(self, num_trials_per_iter: int, num_trials_total: int):
         """Constructor"""
         self.__init_handle_by_constructor__(
-            _ffi_api.ReplayTrace,  # type: ignore # pylint: disable=no-member
+            _ffi_api.SearchStrategyReplayTrace,  # type: ignore # pylint: disable=no-member
             num_trials_per_iter,
             num_trials_total,
         )
+
+
+class ReplayTraceConfig(NamedTuple):
+    """Configuration for ReplayTrace"""
+
+    num_trials_per_iter: int
+    num_trials_total: int
+
+    def create_strategy(self) -> ReplayTrace:
+        return ReplayTrace(self.num_trials_per_iter, self.num_trials_total)
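
Usage of the new ReplayTraceConfig wrapper is a two-liner: the NamedTuple
carries the trial counts and create_strategy() materializes the strategy.

    # Build a ReplayTrace strategy from the new config wrapper.
    from tvm.meta_schedule.search_strategy.replay_trace import ReplayTraceConfig

    config = ReplayTraceConfig(num_trials_per_iter=32, num_trials_total=128)
    strategy = config.create_strategy()  # -> ReplayTrace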
diff --git a/python/tvm/meta_schedule/search_strategy/search_strategy.py b/python/tvm/meta_schedule/search_strategy/search_strategy.py
index 6cee09e..411fecb 100644
--- a/python/tvm/meta_schedule/search_strategy/search_strategy.py
+++ b/python/tvm/meta_schedule/search_strategy/search_strategy.py
@@ -48,7 +48,11 @@ class MeasureCandidate(Object):
     sch: Schedule
     args_info: List[ArgInfo]
 
-    def __init__(self, sch: Schedule, args_info: List[ArgInfo]) -> None:
+    def __init__(
+        self,
+        sch: Schedule,
+        args_info: List[ArgInfo],
+    ) -> None:
         """Constructor.
 
         Parameters
@@ -72,19 +76,16 @@ class SearchStrategy(Object):
     before usage and post-tuned after usage.
     """
 
-    def initialize_with_tune_context(
-        self,
-        tune_context: "TuneContext",
-    ) -> None:
+    def initialize_with_tune_context(self, context: "TuneContext") -> None:
         """Initialize the search strategy with tuning context.
 
         Parameters
         ----------
-        tune_context : TuneContext
+        context : TuneContext
             The tuning context for initialization.
         """
         _ffi_api.SearchStrategyInitializeWithTuneContext(  # type: ignore # pylint: disable=no-member
-            self, tune_context
+            self, context
         )
 
     def pre_tuning(self, design_spaces: List[Schedule]) -> None:
@@ -111,15 +112,29 @@ class SearchStrategy(Object):
         """
         return _ffi_api.SearchStrategyGenerateMeasureCandidates(self)  # type: ignore # pylint: disable=no-member
 
-    def notify_runner_results(self, results: List[RunnerResult]) -> None:
+    def notify_runner_results(
+        self,
+        context: "TuneContext",
+        measure_candidates: List[MeasureCandidate],
+        results: List[RunnerResult],
+    ) -> None:
         """Update the search strategy with profiling results.
 
         Parameters
         ----------
+        context : TuneContext
+            The tuning context.
+        measure_candidates : List[MeasureCandidate]
+            The measure candidates that were profiled.
         results : List[RunnerResult]
             The profiling results from the runner.
         """
-        _ffi_api.SearchStrategyNotifyRunnerResults(self, results)  # type: ignore # pylint: disable=no-member
+        _ffi_api.SearchStrategyNotifyRunnerResults(  # type: ignore # pylint: disable=no-member
+            self,
+            context,
+            measure_candidates,
+            results,
+        )
 
 
 @register_object("meta_schedule.PySearchStrategy")
@@ -146,8 +161,12 @@ class PySearchStrategy(SearchStrategy):
             return self.generate_measure_candidates()
 
         @check_override(self.__class__, SearchStrategy)
-        def f_notify_runner_results(results: List["RunnerResult"]) -> None:
-            self.notify_runner_results(results)
+        def f_notify_runner_results(
+            context: "TuneContext",
+            measure_candidates: List[MeasureCandidate],
+            results: List["RunnerResult"],
+        ) -> None:
+            self.notify_runner_results(context, measure_candidates, results)
 
         self.__init_handle_by_constructor__(
             _ffi_api.SearchStrategyPySearchStrategy,  # type: ignore # pylint: disable=no-member
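
With the widened `notify_runner_results` signature, a Python-side strategy now receives the tuning context and the measured candidates along with the runner results. A minimal sketch of a custom strategy under this interface; everything below is illustrative and not part of this commit:

    from typing import List, Optional

    from tvm.meta_schedule import TuneContext
    from tvm.meta_schedule.runner import RunnerResult
    from tvm.meta_schedule.search_strategy import MeasureCandidate, PySearchStrategy
    from tvm.tir.schedule import Schedule


    class NullStrategy(PySearchStrategy):
        """Hypothetical strategy: proposes no candidates, only logs feedback."""

        def initialize_with_tune_context(self, context: TuneContext) -> None:
            pass

        def pre_tuning(self, design_spaces: List[Schedule]) -> None:
            pass

        def post_tuning(self) -> None:
            pass

        def generate_measure_candidates(self) -> Optional[List[MeasureCandidate]]:
            return None  # None tells the task scheduler this task is done

        def notify_runner_results(
            self,
            context: TuneContext,
            measure_candidates: List[MeasureCandidate],
            results: List[RunnerResult],
        ) -> None:
            for candidate, result in zip(measure_candidates, results):
                print(candidate.sch, result.run_secs)
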
diff --git a/python/tvm/meta_schedule/space_generator/schedule_fn.py b/python/tvm/meta_schedule/space_generator/schedule_fn.py
index 64edd9e..8a57a84 100644
--- a/python/tvm/meta_schedule/space_generator/schedule_fn.py
+++ b/python/tvm/meta_schedule/space_generator/schedule_fn.py
@@ -51,12 +51,12 @@ class ScheduleFn(PySpaceGenerator):
         super().__init__()
         self.sch_fn = sch_fn
 
-    def initialize_with_tune_context(self, tune_context: "TuneContext") -> None:
+    def initialize_with_tune_context(self, context: "TuneContext") -> None:
         """Initialize the design space generator with tuning context.
 
         Parameters
         ----------
-        tune_context : TuneContext
+        context : TuneContext
             The tuning context for initializing the design space generator.
         """
 
diff --git a/python/tvm/meta_schedule/space_generator/space_generator.py b/python/tvm/meta_schedule/space_generator/space_generator.py
index 2172613..e0b0ab2 100644
--- a/python/tvm/meta_schedule/space_generator/space_generator.py
+++ b/python/tvm/meta_schedule/space_generator/space_generator.py
@@ -36,16 +36,16 @@ if TYPE_CHECKING:
 class SpaceGenerator(Object):
     """The abstract design space generator interface."""
 
-    def initialize_with_tune_context(self, tune_context: "TuneContext") -> None:
+    def initialize_with_tune_context(self, context: "TuneContext") -> None:
         """Initialize the design space generator with tuning context.
 
         Parameters
         ----------
-        tune_context : TuneContext
+        context : TuneContext
             The tuning context for initializing the design space generator.
         """
         _ffi_api.SpaceGeneratorInitializeWithTuneContext(  # type: ignore # pylint: disable=no-member
-            self, tune_context
+            self, context
         )
 
     def generate_design_space(self, mod: IRModule) -> List[Schedule]:
@@ -72,8 +72,8 @@ class PySpaceGenerator(SpaceGenerator):
         """Constructor."""
 
         @check_override(self.__class__, SpaceGenerator)
-        def f_initialize_with_tune_context(tune_context: "TuneContext") -> None:
-            self.initialize_with_tune_context(tune_context)
+        def f_initialize_with_tune_context(context: "TuneContext") -> None:
+            self.initialize_with_tune_context(context)
 
         @check_override(self.__class__, SpaceGenerator)
         def f_generate_design_space(mod: IRModule) -> List[Schedule]:
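
For context, the `ScheduleFn` subclass in the sibling module remains the simplest way to supply a design space from Python: wrap a schedule function and pass it to the generator. A minimal sketch, assuming a workload containing a block named "matmul" (the schedule itself is illustrative):

    from tvm.meta_schedule.space_generator import ScheduleFn
    from tvm.tir.schedule import Schedule


    def my_schedule_fn(sch: Schedule) -> None:
        # Tile the two outer loops of the "matmul" block.
        block = sch.get_block("matmul")
        i, j, k = sch.get_loops(block=block)
        i_0, i_1 = sch.split(i, factors=[None, 8])
        j_0, j_1 = sch.split(j, factors=[None, 8])
        sch.reorder(i_0, j_0, i_1, j_1, k)


    space_gen = ScheduleFn(sch_fn=my_schedule_fn)
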
diff --git a/python/tvm/meta_schedule/task_scheduler/round_robin.py b/python/tvm/meta_schedule/task_scheduler/round_robin.py
index 391011b..a63d9a3 100644
--- a/python/tvm/meta_schedule/task_scheduler/round_robin.py
+++ b/python/tvm/meta_schedule/task_scheduler/round_robin.py
@@ -16,13 +16,15 @@
 # under the License.
 """Round Robin Task Scheduler"""
 
-from typing import List, TYPE_CHECKING
+from typing import List, Optional, TYPE_CHECKING
 
 from tvm._ffi import register_object
+from ..measure_callback import MeasureCallback
 
 from ..builder import Builder
 from ..runner import Runner
 from ..database import Database
+from ..cost_model import CostModel
 from .task_scheduler import TaskScheduler
 
 from .. import _ffi_api
@@ -33,7 +35,21 @@ if TYPE_CHECKING:
 
 @register_object("meta_schedule.RoundRobin")
 class RoundRobin(TaskScheduler):
-    """Round Robin Task Scheduler"""
+    """Round Robin Task Scheduler
+
+    Parameters
+    ----------
+    tasks: List[TuneContext]
+        The list of tuning contexts to process.
+    builder: Builder
+        The builder of the scheduler.
+    runner: Runner
+        The runner of the scheduler.
+    database: Database
+        The database of the scheduler.
+    cost_model: Optional[CostModel] = None
+        The cost model of the scheduler.
+    measure_callbacks: Optional[List[MeasureCallback]] = None
+        The list of measure callbacks of the scheduler.
+    """
 
     def __init__(
         self,
@@ -41,6 +57,8 @@ class RoundRobin(TaskScheduler):
         builder: Builder,
         runner: Runner,
         database: Database,
+        cost_model: Optional[CostModel] = None,
+        measure_callbacks: Optional[List[MeasureCallback]] = None,
     ) -> None:
         """Constructor.
 
@@ -54,6 +72,8 @@ class RoundRobin(TaskScheduler):
             The runner.
         database : Database
             The database.
+        cost_model : Optional[CostModel]
+            The cost model of the scheduler.
+        measure_callbacks : Optional[List[MeasureCallback]]
+            The list of measure callbacks of the scheduler.
         """
         self.__init_handle_by_constructor__(
             _ffi_api.TaskSchedulerRoundRobin,  # type: ignore # pylint: disable=no-member
@@ -61,4 +81,6 @@ class RoundRobin(TaskScheduler):
             builder,
             runner,
             database,
+            cost_model,
+            measure_callbacks,
         )
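
Putting the new arguments together, a minimal sketch of wiring up the scheduler. The `AddToDatabase` callback mirrors the database commit that used to be hard-coded in `JoinRunningTask` (see the C++ change below); the helper function and its arguments are placeholders:

    from typing import List

    from tvm.meta_schedule import TuneContext, measure_callback
    from tvm.meta_schedule.builder import Builder
    from tvm.meta_schedule.database import Database
    from tvm.meta_schedule.runner import Runner
    from tvm.meta_schedule.task_scheduler import RoundRobin


    def make_round_robin(
        tasks: List[TuneContext],
        builder: Builder,
        runner: Runner,
        database: Database,
    ) -> RoundRobin:
        # cost_model may stay None; results are persisted via the callback.
        return RoundRobin(
            tasks,
            builder,
            runner,
            database,
            cost_model=None,
            measure_callbacks=[measure_callback.AddToDatabase()],
        )
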
diff --git a/python/tvm/meta_schedule/task_scheduler/task_scheduler.py b/python/tvm/meta_schedule/task_scheduler/task_scheduler.py
index aeea154..dd8e3fe 100644
--- a/python/tvm/meta_schedule/task_scheduler/task_scheduler.py
+++ b/python/tvm/meta_schedule/task_scheduler/task_scheduler.py
@@ -16,14 +16,16 @@
 # under the License.
 """Auto-tuning Task Scheduler"""
 
-from typing import List
+from typing import List, Optional
 
 from tvm._ffi import register_object
+from ..measure_callback import MeasureCallback
 from tvm.runtime import Object
 
 from ..runner import Runner
 from ..builder import Builder
 from ..database import Database
+from ..cost_model import CostModel
 from ..tune_context import TuneContext
 from .. import _ffi_api
 from ..utils import check_override
@@ -43,12 +45,16 @@ class TaskScheduler(Object):
         The runner of the scheduler.
     database: Database
         The database of the scheduler.
+    cost_model: Optional[CostModel] = None
+        The cost model of the scheduler.
+    measure_callbacks: Optional[List[MeasureCallback]] = None
+        The list of measure callbacks of the scheduler.
     """
 
     tasks: List[TuneContext]
     builder: Builder
     runner: Runner
     database: Database
+    cost_model: Optional[CostModel]
+    measure_callbacks: List[MeasureCallback]
 
     def tune(self) -> None:
         """Auto-tuning."""
@@ -59,7 +65,7 @@ class TaskScheduler(Object):
 
         Returns
         -------
-        int
+        next_task_id : int
             The next task id.
         """
         return _ffi_api.TaskSchedulerNextTaskId(self)  # type: ignore # pylint: disable=no-member
@@ -94,7 +100,7 @@ class TaskScheduler(Object):
 
         Returns
         -------
-        bool
+        running : bool
             Whether the task is running.
         """
         return _ffi_api.TaskSchedulerIsTaskRunning(self, task_id)  # type: ignore # pylint: disable=no-member
@@ -120,6 +126,8 @@ class PyTaskScheduler(TaskScheduler):
         builder: Builder,
         runner: Runner,
         database: Database,
+        cost_model: Optional[CostModel] = None,
+        measure_callbacks: Optional[List[MeasureCallback]] = None,
     ):
         """Constructor.
 
@@ -133,6 +141,10 @@ class PyTaskScheduler(TaskScheduler):
             The runner of the scheduler.
         database: Database
             The database of the scheduler.
+        cost_model: Optional[CostModel]
+            The cost model of the scheduler.
+        measure_callbacks: Optional[List[MeasureCallback]]
+            The list of measure callbacks of the scheduler.
         """
 
         @check_override(self.__class__, TaskScheduler, required=False)
@@ -173,6 +185,8 @@ class PyTaskScheduler(TaskScheduler):
             builder,
             runner,
             database,
+            cost_model,
+            measure_callbacks,
             f_tune,
             f_initialize_task,
             f_set_task_stopped,
diff --git a/python/tvm/meta_schedule/tune_context.py b/python/tvm/meta_schedule/tune_context.py
index 99b8c7e..196b1c1 100644
--- a/python/tvm/meta_schedule/tune_context.py
+++ b/python/tvm/meta_schedule/tune_context.py
@@ -16,13 +16,14 @@
 # under the License.
 """Meta Schedule tuning context."""
 
-from typing import Optional, List, TYPE_CHECKING
+from typing import Optional, List, Dict, TYPE_CHECKING
 
 from tvm import IRModule
 from tvm._ffi import register_object
 from tvm.meta_schedule.utils import cpu_count
 from tvm.runtime import Object
 from tvm.target import Target
+from tvm.tir import PrimFunc
 
 from . import _ffi_api
 
@@ -30,6 +31,8 @@ if TYPE_CHECKING:
     from .space_generator import SpaceGenerator
     from .search_strategy import SearchStrategy
     from .schedule_rule import ScheduleRule
+    from .postproc import Postproc
+    from .mutator import Mutator
 
 
 @register_object("meta_schedule.TuneContext")
@@ -53,6 +56,10 @@ class TuneContext(Object):
         The search strategy.
     sch_rules: Optional[List[ScheduleRule]] = None,
         The schedule rules.
+    postprocs: Optional[List[Postproc]] = None,
+        The postprocessors.
+    mutator_probs: Optional[Dict[Mutator, float]] = None,
+        Mutators and their probability mass.
-    task_name : Optional[str] = None
+    task_name : str = "main"
         The name of the tuning task.
     rand_state : int = -1
@@ -71,23 +78,31 @@ class TuneContext(Object):
 
     mod: Optional[IRModule]
     target: Optional[Target]
-    space_generator: "SpaceGenerator"
-    search_strategy: "SearchStrategy"
-    task_name: Optional[str]
+    space_generator: Optional["SpaceGenerator"]
+    search_strategy: Optional["SearchStrategy"]
+    sch_rules: List["ScheduleRule"]
+    postprocs: List["Postproc"]
+    mutator_probs: Dict["Mutator", float]
+    task_name: str
     rand_state: int
     num_threads: int
 
     def __init__(
         self,
         mod: Optional[IRModule] = None,
+        *,
         target: Optional[Target] = None,
         space_generator: Optional["SpaceGenerator"] = None,
         search_strategy: Optional["SearchStrategy"] = None,
         sch_rules: Optional[List["ScheduleRule"]] = None,
-        task_name: Optional[str] = None,
+        postprocs: Optional[List["Postproc"]] = None,
+        mutator_probs: Optional[Dict["Mutator", float]] = None,
+        task_name: str = "main",
         rand_state: int = -1,
         num_threads: Optional[int] = None,
     ):
+        if isinstance(mod, PrimFunc):
+            mod = IRModule.from_expr(mod)
         if num_threads is None:
             num_threads = cpu_count()
 
@@ -98,6 +113,8 @@ class TuneContext(Object):
             space_generator,
             search_strategy,
             sch_rules,
+            postprocs,
+            mutator_probs,
             task_name,
             rand_state,
             num_threads,
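
With the keyword-only arguments above, constructing a context looks like the following minimal sketch; all component slots may be left empty, and empty collections are filled in on the C++ side:

    from tvm.meta_schedule import TuneContext
    from tvm.target import Target

    context = TuneContext(
        mod=None,            # an IRModule, or a PrimFunc (wrapped automatically)
        target=Target("llvm"),
        sch_rules=None,      # schedule rules
        postprocs=None,      # postprocessors
        mutator_probs=None,  # Dict[Mutator, float] probability mass
        task_name="main",
        rand_state=42,
    )
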
diff --git a/src/meta_schedule/cost_model/cost_model.cc b/src/meta_schedule/cost_model/cost_model.cc
index 5cd32b0..c6efb54 100644
--- a/src/meta_schedule/cost_model/cost_model.cc
+++ b/src/meta_schedule/cost_model/cost_model.cc
@@ -53,10 +53,10 @@ TVM_REGISTER_GLOBAL("meta_schedule.CostModelUpdate")
     .set_body_method<CostModel>(&CostModelNode::Update);
 TVM_REGISTER_GLOBAL("meta_schedule.CostModelPredict")
     .set_body_typed([](CostModel model,                     //
-                       const TuneContext& tune_context,     //
+                       const TuneContext& context,          //
                        Array<MeasureCandidate> candidates,  //
                        void* p_addr) -> void {
-      std::vector<double> result = model->Predict(tune_context, candidates);
+      std::vector<double> result = model->Predict(context, candidates);
       std::copy(result.begin(), result.end(), static_cast<double*>(p_addr));
     });
 TVM_REGISTER_GLOBAL("meta_schedule.CostModelPyCostModel").set_body_typed(CostModel::PyCostModel);
diff --git a/src/meta_schedule/search_strategy/replay_trace.cc b/src/meta_schedule/search_strategy/replay_trace.cc
index 200eca3..1eac10d 100644
--- a/src/meta_schedule/search_strategy/replay_trace.cc
+++ b/src/meta_schedule/search_strategy/replay_trace.cc
@@ -24,20 +24,18 @@ namespace meta_schedule {
 /*! \brief A search strategy that generates measure candidates using trace and random decisions. */
 class ReplayTraceNode : public SearchStrategyNode {
  public:
-  using TRandState = support::LinearCongruentialEngine::TRandState;
-
   /*! \brief The state of the search strategy. */
   struct State {
     /*! \brief The search strategy itself */
     ReplayTraceNode* self;
     /*! \brief The design spaces. */
-    Array<tir::Schedule> design_spaces;
+    Array<tir::Trace> design_spaces;
     /*! \brief `[st, ed)` are the indices of the next batch of candidates. */
     int st;
     /*! \brief `[st, ed)` are the indices of the next batch of candidates. */
     int ed;
 
-    explicit State(ReplayTraceNode* self, Array<tir::Schedule> design_spaces)
+    explicit State(ReplayTraceNode* self, Array<tir::Trace> design_spaces)
         : self(self), design_spaces(design_spaces), st(0), ed(self->num_trials_per_iter) {}
 
     inline Optional<Array<MeasureCandidate>> GenerateMeasureCandidates();
@@ -50,9 +48,11 @@ class ReplayTraceNode : public SearchStrategyNode {
   int num_trials_total;
 
   /*! \brief The module to be tuned. */
-  IRModule mod_{nullptr};
+  Array<IRModule> per_thread_mod_{nullptr};
   /*! \brief The metadata of the function arguments. */
   Array<ArgInfo> args_info_{nullptr};
+  /*! \brief The post processors */
+  Array<Postproc> postprocs_{nullptr};
   /*! \brief The number of threads to use. -1 means using logical cpu number. */
   int num_threads_ = -1;
   /*! \brief The random state. -1 means using random number. */
@@ -63,8 +63,9 @@ class ReplayTraceNode : public SearchStrategyNode {
   void VisitAttrs(tvm::AttrVisitor* v) {
     v->Visit("num_trials_per_iter", &num_trials_per_iter);
     v->Visit("num_trials_total", &num_trials_total);
-    // `mod_` is not visited
+    // `per_thread_mod_` is not visited
     // `args_info_` is not visited
+    // `postprocs_` is not visited
     // `num_threads_` is not visited
     // `rand_state_` is not visited
     // `state_` is not visited
@@ -74,9 +75,16 @@ class ReplayTraceNode : public SearchStrategyNode {
   TVM_DECLARE_FINAL_OBJECT_INFO(ReplayTraceNode, SearchStrategyNode);
 
   void InitializeWithTuneContext(const TuneContext& context) final {
-    this->mod_ = context->mod.value();
-    this->args_info_ = ArgInfo::FromPrimFunc(FindEntryFunc(this->mod_));
+    CHECK(context->num_threads > 0) << "Number of threads has to be larger than 0.";
     this->num_threads_ = context->num_threads;
+
+    this->per_thread_mod_.reserve(this->num_threads_);
+    for (int i = 0; i < this->num_threads_; i++) {
+      this->per_thread_mod_.push_back(DeepCopyIRModule(context->mod.value()));
+    }
+
+    this->args_info_ = ArgInfo::FromPrimFunc(FindEntryFunc(context->mod.value()));
+    this->postprocs_ = context->postprocs;
     this->rand_state_ = ForkSeed(&context->rand_state);
     this->state_.reset();
   }
@@ -84,7 +92,12 @@ class ReplayTraceNode : public SearchStrategyNode {
   void PreTuning(const Array<tir::Schedule>& design_spaces) final {
     ICHECK(!design_spaces.empty());
     ICHECK(this->state_ == nullptr);
-    this->state_ = std::make_unique<State>(this, design_spaces);
+    Array<tir::Trace> design_space_traces;
+    design_space_traces.reserve(design_spaces.size());
+    for (const tir::Schedule& space : design_spaces) {
+      design_space_traces.push_back(space->trace().value()->Simplified(true));
+    }
+    this->state_ = std::make_unique<State>(this, design_space_traces);
   }
 
   void PostTuning() final {
@@ -97,7 +110,9 @@ class ReplayTraceNode : public SearchStrategyNode {
     return this->state_->GenerateMeasureCandidates();
   }
 
-  void NotifyRunnerResults(const Array<RunnerResult>& results) final {
+  void NotifyRunnerResults(const TuneContext& context,
+                           const Array<MeasureCandidate>& measure_candidates,
+                           const Array<RunnerResult>& results) final {
     ICHECK(this->state_ != nullptr);
     this->state_->NotifyRunnerResults(results);
   }
@@ -111,19 +126,20 @@ inline Optional<Array<MeasureCandidate>> ReplayTraceNode::State::GenerateMeasure
   ICHECK_LT(st, ed);
   std::vector<TRandState> per_thread_rand_state = ForkSeed(&self->rand_state_, self->num_threads_);
   Array<MeasureCandidate> per_task_result(ed - st, MeasureCandidate{nullptr});
-  auto f_worker = [this, &per_thread_rand_state, &per_task_result](int thread_id,
-                                                                   int task_id) -> void {
+  ThreadedTraceApply pp(self->postprocs_);
+  auto f_worker = [this, &per_thread_rand_state, &per_task_result, &pp](int thread_id,
+                                                                        int task_id) -> void {
     TRandState& rand_state = per_thread_rand_state[thread_id];
-    int design_space_index = tir::SampleInt(&rand_state, 0, design_spaces.size());
-    tir::Trace trace = design_spaces[design_space_index]->trace().value();
-    tir::Trace new_trace = tir::Trace(trace->insts, {});
-    tir::Schedule sch = tir::Schedule::Traced(  //
-        self->mod_,                             //
-        /*rand_state=*/ForkSeed(&rand_state),   //
-        /*debug_mode=*/0,                       //
-        /*error_render_level=*/tir::ScheduleErrorRenderLevel::kNone);
-    new_trace->ApplyToSchedule(sch, /*remove_postproc=*/true);
-    per_task_result.Set(task_id, MeasureCandidate(sch, self->args_info_));
+    IRModule mod = self->per_thread_mod_[thread_id];
+    for (;;) {
+      int design_space_index = tir::SampleInt(&rand_state, 0, design_spaces.size());
+      tir::Trace trace = design_spaces[design_space_index];
+      tir::Trace new_trace = tir::Trace(trace->insts, {});
+      if (Optional<tir::Schedule> sch = pp.Apply(mod, new_trace, &rand_state)) {
+        per_task_result.Set(task_id, MeasureCandidate(sch.value(), self->args_info_));
+        break;
+      }
+    }
   };
   support::parallel_for_dynamic(0, ed - st, self->num_threads_, f_worker);
   return per_task_result;
@@ -142,7 +158,8 @@ SearchStrategy SearchStrategy::ReplayTrace(int num_trials_per_iter, int num_tria
 }
 
 TVM_REGISTER_NODE_TYPE(ReplayTraceNode);
-TVM_REGISTER_GLOBAL("meta_schedule.ReplayTrace").set_body_typed(SearchStrategy::ReplayTrace);
+TVM_REGISTER_GLOBAL("meta_schedule.SearchStrategyReplayTrace")
+    .set_body_typed(SearchStrategy::ReplayTrace);
 
 }  // namespace meta_schedule
 }  // namespace tvm
diff --git a/src/meta_schedule/task_scheduler/round_robin.cc b/src/meta_schedule/task_scheduler/round_robin.cc
index 3ef5026..72989a2 100644
--- a/src/meta_schedule/task_scheduler/round_robin.cc
+++ b/src/meta_schedule/task_scheduler/round_robin.cc
@@ -52,16 +52,23 @@ class RoundRobinNode final : public TaskSchedulerNode {
   }
 };
 
-TaskScheduler TaskScheduler::RoundRobin(Array<TuneContext> tasks,  //
-                                        Builder builder,           //
-                                        Runner runner,             //
-                                        Database database) {
+TaskScheduler TaskScheduler::RoundRobin(Array<TuneContext> tasks,        //
+                                        Builder builder,                 //
+                                        Runner runner,                   //
+                                        Database database,               //
+                                        Optional<CostModel> cost_model,  //
+                                        Optional<Array<MeasureCallback>> measure_callbacks) {
   ObjectPtr<RoundRobinNode> n = make_object<RoundRobinNode>();
   n->tasks = tasks;
   n->builder = builder;
   n->runner = runner;
   n->database = database;
+  n->cost_model = cost_model;
+  n->measure_callbacks = measure_callbacks.value_or({});
   n->task_id = -1;
+  for (const TuneContext& task : tasks) {
+    task->task_scheduler = n.get();
+  }
   return TaskScheduler(n);
 }
 
diff --git a/src/meta_schedule/task_scheduler/task_scheduler.cc b/src/meta_schedule/task_scheduler/task_scheduler.cc
index 08f2b4f..1f3943d 100644
--- a/src/meta_schedule/task_scheduler/task_scheduler.cc
+++ b/src/meta_schedule/task_scheduler/task_scheduler.cc
@@ -16,7 +16,6 @@
  * specific language governing permissions and limitations
  * under the License.
  */
-
 #include "../utils.h"
 
 namespace tvm {
@@ -29,9 +28,9 @@ namespace meta_schedule {
  * \param candidates The measure candidates.
  * \return An array of the builder results.
  */
-Array<BuilderResult> SendToBuilder(const Builder& builder,  //
-                                   const TuneContext& context,
+Array<BuilderResult> SendToBuilder(const Builder& builder, const TuneContext& context,
                                    const Array<MeasureCandidate>& candidates) {
+  LOG(INFO) << "Sending " << candidates.size() << " sample(s) to builder";
   Target target = context->target.value();
   Array<BuilderInput> inputs;
   inputs.reserve(candidates.size());
@@ -45,14 +44,14 @@ Array<BuilderResult> SendToBuilder(const Builder& builder,  //
  * \brief Send the built measure candidates to runner.
  * \param runner The runner to send the candidates to.
  * \param context The tuning context.
- * \param candidates The mesure candidates.
+ * \param candidates The measure candidates.
  * \param builder_results The builder results.
  * \return An array of the runner results.
  */
-Array<RunnerFuture> SendToRunner(const Runner& runner,  //
-                                 const TuneContext& context,
+Array<RunnerFuture> SendToRunner(const Runner& runner, const TuneContext& context,
                                  const Array<MeasureCandidate>& candidates,
                                  const Array<BuilderResult>& builder_results) {
+  LOG(INFO) << "Sending " << candidates.size() << " sample(s) to runner";
   Target target = context->target.value();
   ICHECK_EQ(candidates.size(), builder_results.size());
   int n = candidates.size();
@@ -94,54 +93,60 @@ Array<RunnerFuture> SendToRunner(const Runner& runner,  //
 
 void TaskSchedulerNode::InitializeTask(int task_id) {
   TuneContext task = this->tasks[task_id];
-  // Derive the values.
-  IRModule mod = task->mod.value();
-  SpaceGenerator space = task->space_generator.value();
-  SearchStrategy strategy = task->search_strategy.value();
-  // Initialize Modules.
-  space->InitializeWithTuneContext(task);
-  strategy->InitializeWithTuneContext(task);
+  LOG(INFO) << "Initializing task " << task_id << ": " << task->task_name << ", mod =\n"
+            << tir::AsTVMScript(task->mod);
+  this->tasks[task_id]->Initialize();
 }
 
 void TaskSchedulerNode::Tune() {
   for (int i = 0; i < static_cast<int>(this->tasks.size()); i++) {
+    TuneContext task = tasks[i];
     // Check Optional value validity.
-    CHECK(tasks[i]->mod.defined()) << "ValueError: Require `context.mod`, but it is not defined";
-    CHECK(tasks[i]->space_generator.defined())
+    CHECK(task->mod.defined()) << "ValueError: Require `context.mod`, but it is not defined";
+    CHECK(task->space_generator.defined())
         << "ValueError: Require `context.space_generator`, but it is not defined";
-    CHECK(tasks[i]->search_strategy.defined())
+    CHECK(task->search_strategy.defined())
         << "ValueError: Require `context.search_strategy`, but it is not defined";
-
     InitializeTask(i);
-
-    tasks[i]->search_strategy.value()->PreTuning(
-        tasks[i]->space_generator.value()->GenerateDesignSpace(tasks[i]->mod.value()));
+    Array<tir::Schedule> design_spaces =
+        task->space_generator.value()->GenerateDesignSpace(task->mod.value());
+    LOG(INFO) << "Total " << design_spaces.size() << " design space(s) generated";
+    for (int j = 0, n = design_spaces.size(); j < n; ++j) {
+      tir::Schedule sch = design_spaces[j];
+      tir::Trace trace = sch->trace().value();
+      trace = trace->Simplified(true);
+      LOG(INFO) << "Design space #" << j << ":\n"
+                << tir::AsTVMScript(sch->mod()) << "\n"
+                << Concat(trace->AsPython(false), "\n");
+    }
+    task->search_strategy.value()->PreTuning(design_spaces);
   }
 
   int running_tasks = tasks.size();
-  while (running_tasks > 0) {
-    for (int task_id; (task_id = NextTaskId()) != -1;) {
-      TuneContext task = tasks[task_id];
-      ICHECK(!task->is_stopped);
-      ICHECK(!task->runner_futures.defined());
-      SearchStrategy strategy = task->search_strategy.value();
-      if ((task->measure_candidates = strategy->GenerateMeasureCandidates()).defined()) {
-        Array<BuilderResult> builder_results =
-            SendToBuilder(this->builder, task, task->measure_candidates.value());
-        task->runner_futures =
-            SendToRunner(this->runner, task, task->measure_candidates.value(), builder_results);
-      } else {
-        SetTaskStopped(task_id);
-        --running_tasks;
-      }
+  for (int task_id; (task_id = NextTaskId()) != -1;) {
+    LOG(INFO) << "Scheduler picks Task #" << task_id << ": " << tasks[task_id]->task_name;
+    TuneContext task = tasks[task_id];
+    ICHECK(!task->is_stopped);
+    ICHECK(!task->runner_futures.defined());
+    SearchStrategy strategy = task->search_strategy.value();
+    if ((task->measure_candidates = strategy->GenerateMeasureCandidates()).defined()) {
+      Array<BuilderResult> builder_results =
+          SendToBuilder(this->builder, task, task->measure_candidates.value());
+      task->builder_results = builder_results;
+      task->runner_futures =
+          SendToRunner(this->runner, task, task->measure_candidates.value(), builder_results);
+    } else {
+      SetTaskStopped(task_id);
+      --running_tasks;
+      LOG(INFO) << "Task #" << task_id << " has finished. Remaining task(s): " << running_tasks;
     }
-    int n_tasks = this->tasks.size();
-    for (int task_id = 0; task_id < n_tasks; ++task_id)
-      if (IsTaskRunning(task_id)) {
-        TuneContext task = tasks[task_id];
-        this->JoinRunningTask(task_id);
-        task->search_strategy.value()->PostTuning();
-      }
+  }
+  ICHECK_EQ(running_tasks, 0) << "Not all tasks are finished";
+  int n_tasks = this->tasks.size();
+  for (int task_id = 0; task_id < n_tasks; ++task_id) {
+    ICHECK(!IsTaskRunning(task_id)) << "Task #" << task_id << " is still running";
+    TuneContext task = tasks[task_id];
+    task->search_strategy.value()->PostTuning();
   }
 }
 
@@ -175,25 +180,20 @@ void TaskSchedulerNode::JoinRunningTask(int task_id) {
   for (const RunnerFuture future : task->runner_futures.value()) {
     results.push_back(future->Result());
   }
-  task->search_strategy.value()->NotifyRunnerResults(results);
-  task->runner_futures = NullOpt;
-  // Add to database
+  task->search_strategy.value()->NotifyRunnerResults(task, task->measure_candidates.value(),
+                                                     results);
+  // Invoke the callbacks
   ICHECK(task->measure_candidates.defined());
-  ICHECK(results.size() == task->measure_candidates.value().size());
-  int index = 0;
-  for (const RunnerResult& result : results) {
-    if (!result->error_msg.defined() && result->run_secs.defined()) {
-      Optional<tir::Trace> trace = task->measure_candidates.value()[index]->sch->trace();
-      ICHECK(trace.defined());
-      this->database->CommitTuningRecord(TuningRecord(
-          /*trace=*/trace.value(),
-          /*run_secs=*/result->run_secs.value(),
-          /*workload=*/this->database->CommitWorkload(task->mod.value()),
-          /*target=*/task->target.value(),
-          /*args_info=*/task->measure_candidates.value()[index]->args_info));
-    }
-    index++;
+  ICHECK(task->builder_results.defined());
+  ICHECK_EQ(results.size(), task->measure_candidates.value().size());
+  ICHECK_EQ(results.size(), task->builder_results.value().size());
+  for (const MeasureCallback& callback : this->measure_callbacks) {
+    callback->Apply(GetRef<TaskScheduler>(this), task_id, task->measure_candidates.value(),
+                    task->builder_results.value(), results);
   }
+  task->measure_candidates = NullOpt;
+  task->builder_results = NullOpt;
+  task->runner_futures = NullOpt;
 }
 
 TaskScheduler TaskScheduler::PyTaskScheduler(
@@ -201,6 +201,8 @@ TaskScheduler TaskScheduler::PyTaskScheduler(
     Builder builder,                                            //
     Runner runner,                                              //
     Database database,                                          //
+    Optional<CostModel> cost_model,                             //
+    Optional<Array<MeasureCallback>> measure_callbacks,         //
     PyTaskSchedulerNode::FTune f_tune,                          //
     PyTaskSchedulerNode::FInitializeTask f_initialize_task,     //
     PyTaskSchedulerNode::FSetTaskStopped f_set_task_stopped,    //
@@ -212,6 +214,12 @@ TaskScheduler TaskScheduler::PyTaskScheduler(
   n->builder = builder;
   n->runner = runner;
   n->database = database;
+  n->cost_model = cost_model;
+  n->measure_callbacks = measure_callbacks.value_or({});
   n->f_tune = f_tune;
   n->f_initialize_task = f_initialize_task;
   n->f_set_task_stopped = f_set_task_stopped;
diff --git a/src/meta_schedule/tune_context.cc b/src/meta_schedule/tune_context.cc
index ac85d43..f4595d3 100644
--- a/src/meta_schedule/tune_context.cc
+++ b/src/meta_schedule/tune_context.cc
@@ -16,7 +16,6 @@
  * specific language governing permissions and limitations
  * under the License.
  */
-#include <random>
 #include <utility>
 
 #include "./utils.h"
@@ -24,21 +23,13 @@
 namespace tvm {
 namespace meta_schedule {
 
-/*!
- * \brief Constructor function of TuneContext class.
- * \param mod The mod to be optimized.
- * \param target The target to be optimized for.
- * \param space_generator The design space generator.
- * \param task_name The name of the tuning task.
- * \param rand_state The random state.
- * \param num_threads The number of threads to be used.
- * \param verbose The verbosity level.
- */
 TuneContext::TuneContext(Optional<IRModule> mod,                                    //
                          Optional<Target> target,                                   //
                          Optional<SpaceGenerator> space_generator,                  //
                          Optional<SearchStrategy> search_strategy,                  //
                          Optional<Array<ScheduleRule>> sch_rules,                   //
+                         Optional<Array<Postproc>> postprocs,                       //
+                         Optional<Map<Mutator, FloatImm>> mutator_probs,            //
                          Optional<String> task_name,                                //
                          support::LinearCongruentialEngine::TRandState rand_state,  //
                          int num_threads) {
@@ -48,9 +39,11 @@ TuneContext::TuneContext(Optional<IRModule> mod,
   n->space_generator = space_generator;
   n->search_strategy = search_strategy;
   n->sch_rules = sch_rules.value_or({});
+  n->postprocs = postprocs.value_or({});
+  n->mutator_probs = mutator_probs.value_or({});
   n->task_name = task_name;
   if (rand_state == -1) {
-    rand_state = std::random_device()();
+    rand_state = support::LinearCongruentialEngine::DeviceRandom();
   }
   support::LinearCongruentialEngine(&n->rand_state).Seed(rand_state);
   n->num_threads = num_threads;
@@ -60,6 +53,26 @@ TuneContext::TuneContext(Optional<IRModule> mod,
   data_ = std::move(n);
 }
 
+void TuneContextNode::Initialize() {
+  if (this->space_generator.defined()) {
+    this->space_generator.value()->InitializeWithTuneContext(GetRef<TuneContext>(this));
+  }
+  if (this->search_strategy.defined()) {
+    this->search_strategy.value()->InitializeWithTuneContext(GetRef<TuneContext>(this));
+  }
+  for (const ScheduleRule& sch_rule : sch_rules) {
+    sch_rule->InitializeWithTuneContext(GetRef<TuneContext>(this));
+  }
+  for (const Postproc& postproc : postprocs) {
+    postproc->InitializeWithTuneContext(GetRef<TuneContext>(this));
+  }
+  if (mutator_probs.defined()) {
+    for (const auto& kv : mutator_probs) {
+      kv.first->InitializeWithTuneContext(GetRef<TuneContext>(this));
+    }
+  }
+}
+
 TVM_REGISTER_NODE_TYPE(TuneContextNode);
 
 TVM_REGISTER_GLOBAL("meta_schedule.TuneContext")
@@ -68,11 +81,13 @@ TVM_REGISTER_GLOBAL("meta_schedule.TuneContext")
                        Optional<SpaceGenerator> space_generator,                  //
                        Optional<SearchStrategy> search_strategy,                  //
                        Optional<Array<ScheduleRule>> sch_rules,                   //
+                       Optional<Array<Postproc>> postprocs,                       //
+                       Optional<Map<Mutator, FloatImm>> mutator_probs,            //
                        Optional<String> task_name,                                //
                        support::LinearCongruentialEngine::TRandState rand_state,  //
                        int num_threads) -> TuneContext {
-      return TuneContext(mod, target, space_generator, search_strategy, sch_rules, task_name,
-                         rand_state, num_threads);
+      return TuneContext(mod, target, space_generator, search_strategy, sch_rules, postprocs,
+                         mutator_probs, task_name, rand_state, num_threads);
     });
 }  // namespace meta_schedule
 }  // namespace tvm
diff --git a/src/meta_schedule/utils.h b/src/meta_schedule/utils.h
index 0a9ce4a..3e989e4 100644
--- a/src/meta_schedule/utils.h
+++ b/src/meta_schedule/utils.h
@@ -200,7 +200,7 @@ inline support::LinearCongruentialEngine::TRandState ForkSeed(
 
 /*!
- * \brief Fork a random state into another ones, i.e. PRNG splitting.
- * The given random state is also mutated.
+ * \brief Fork a random state into multiple new ones, i.e. PRNG splitting.
+ *  The given random state is also mutated.
  * \param rand_state The random state to be forked
  * \param n The number of forks
  * \return The forked random states
@@ -216,6 +216,15 @@ inline std::vector<support::LinearCongruentialEngine::TRandState> ForkSeed(
 }
 
 /*!
+ * \brief Get a deep copy of an IRModule.
+ * \param mod The IRModule to be deep-copied.
+ * \return The deep copy of the IRModule.
+ */
+inline IRModule DeepCopyIRModule(IRModule mod) {
+  return Downcast<IRModule>(LoadJSON(SaveJSON(mod)));
+}
+
+/*!
  * \brief Concatenate strings
  * \param strs The strings to concatenate
  * \param delim The delimiter
@@ -233,6 +242,78 @@ inline std::string Concat(const Array<String>& strs, const std::string& delim) {
   return os.str();
 }
 
+/*!
+ * \brief A helper data structure that replays a trace and collects failure counts
+ * for each postprocessor
+ */
+struct ThreadedTraceApply {
+  /*! \brief Constructor */
+  explicit ThreadedTraceApply(const Array<Postproc>& postprocs)
+      : n_(postprocs.size()), items_(new Item[n_]) {
+    for (int i = 0; i < n_; ++i) {
+      items_[i].postproc = postprocs[i];
+      items_[i].fail_counter = 0;
+    }
+  }
+
+  /*! \brief Destructor */
+  ~ThreadedTraceApply() { delete[] items_; }
+
+  /*!
+   * \brief Apply the trace and postprocessors to an IRModule
+   * \param mod The IRModule the trace is applied to
+   * \param trace The trace to apply to the IRModule
+   * \param rand_state The random state for seeding the new schedule
+   * \return The schedule created, or NullOpt if any postprocessor fails
+   */
+  Optional<tir::Schedule> Apply(const IRModule& mod, const tir::Trace& trace,
+                                TRandState* rand_state) {
+    tir::Schedule sch =
+        tir::Schedule::Traced(mod,
+                              /*rand_state=*/ForkSeed(rand_state),
+                              /*debug_mode=*/0,
+                              /*error_render_level=*/tir::ScheduleErrorRenderLevel::kNone);
+    trace->ApplyToSchedule(sch, /*remove_postproc=*/true);
+    sch->EnterPostproc();
+    for (int i = 0; i < n_; ++i) {
+      Item& item = items_[i];
+      if (!item.postproc->Apply(sch)) {
+        ++item.fail_counter;
+        return NullOpt;
+      }
+    }
+    return sch;
+  }
+
+  /*! \brief Returns a string summarizing the failures on each postprocessor */
+  std::string SummarizeFailures() const {
+    std::ostringstream os;
+    for (int i = 0; i < n_; ++i) {
+      const Item& item = items_[i];
+      os << "Postproc #" << i << " [" << item.postproc  //
+         << "]: " << item.fail_counter.load() << " failure(s)";
+      if (i != n_ - 1) {
+        os << "\n";
+      }
+    }
+    return os.str();
+  }
+
+ private:
+  /*! \brief A helper data structure that stores the fail count for each postprocessor. */
+  struct Item {
+    /*! \brief The postprocessor. */
+    Postproc postproc{nullptr};
+    /*! \brief The thread-safe postprocessor failure counter. */
+    std::atomic<int> fail_counter{0};
+  };
+
+  /*! \brief The number of total postprocessors. */
+  int n_;
+  /*! \brief The pointer to the list of postprocessor items. */
+  Item* items_;
+};
+
 }  // namespace meta_schedule
 }  // namespace tvm
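
For reference, `DeepCopyIRModule` above is just the JSON save/load round-trip, which is equally reachable from Python; a minimal sketch:

    import tvm


    def deep_copy_ir_module(mod: tvm.IRModule) -> tvm.IRModule:
        # Same round-trip as the C++ helper in utils.h.
        return tvm.ir.load_json(tvm.ir.save_json(mod))
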
 
diff --git a/src/tir/schedule/concrete_schedule.cc b/src/tir/schedule/concrete_schedule.cc
index 65886da..37d896a 100644
--- a/src/tir/schedule/concrete_schedule.cc
+++ b/src/tir/schedule/concrete_schedule.cc
@@ -30,7 +30,7 @@ Schedule Schedule::Concrete(IRModule mod, support::LinearCongruentialEngine::TRa
   n->error_render_level_ = error_render_level;
   n->symbol_table_ = {};
   n->analyzer_ = std::make_unique<arith::Analyzer>();
-  support::LinearCongruentialEngine(&n->rand_state_).Seed(seed);
+  n->Seed(seed);
   return Schedule(std::move(n));
 }
 
diff --git a/src/tir/schedule/traced_schedule.cc b/src/tir/schedule/traced_schedule.cc
index 6128366..b4d1ba0 100644
--- a/src/tir/schedule/traced_schedule.cc
+++ b/src/tir/schedule/traced_schedule.cc
@@ -29,7 +29,7 @@ Schedule Schedule::Traced(IRModule mod, support::LinearCongruentialEngine::TRand
   n->symbol_table_ = {};
   n->analyzer_ = std::make_unique<arith::Analyzer>();
   n->trace_ = Trace();
-  support::LinearCongruentialEngine(&n->rand_state_).Seed(seed);
+  n->Seed(seed);
   return Schedule(std::move(n));
 }
 
diff --git a/tests/python/unittest/test_meta_schedule_cost_model.py b/tests/python/unittest/test_meta_schedule_cost_model.py
index 3f98d71..5b409be 100644
--- a/tests/python/unittest/test_meta_schedule_cost_model.py
+++ b/tests/python/unittest/test_meta_schedule_cost_model.py
@@ -62,15 +62,13 @@ def test_meta_schedule_cost_model():
 
         def update(
             self,
-            tune_context: TuneContext,
+            context: TuneContext,
             candidates: List[MeasureCandidate],
             results: List[RunnerResult],
         ) -> None:
             pass
 
-        def predict(
-            self, tune_context: TuneContext, candidates: List[MeasureCandidate]
-        ) -> np.ndarray:
+        def predict(self, context: TuneContext, candidates: List[MeasureCandidate]) -> np.ndarray:
             return np.random.rand(10)
 
     model = FancyCostModel()
@@ -91,15 +89,13 @@ def test_meta_schedule_cost_model_as_string():
 
         def update(
             self,
-            tune_context: TuneContext,
+            context: TuneContext,
             candidates: List[MeasureCandidate],
             results: List[RunnerResult],
         ) -> None:
             pass
 
-        def predict(
-            self, tune_context: TuneContext, candidates: List[MeasureCandidate]
-        ) -> np.ndarray:
+        def predict(self, context: TuneContext, candidates: List[MeasureCandidate]) -> np.ndarray:
             return np.random.rand(10)
 
     cost_model = NotSoFancyCostModel()
diff --git a/tests/python/unittest/test_meta_schedule_feature_extractor.py b/tests/python/unittest/test_meta_schedule_feature_extractor.py
index 143d446..d95397b 100644
--- a/tests/python/unittest/test_meta_schedule_feature_extractor.py
+++ b/tests/python/unittest/test_meta_schedule_feature_extractor.py
@@ -28,7 +28,7 @@ def test_meta_schedule_feature_extractor():
     class FancyFeatureExtractor(PyFeatureExtractor):
         def extract_from(
             self,
-            tune_context: TuneContext,  # pylint: disable = unused-argument
+            context: TuneContext,  # pylint: disable = unused-argument
             candidates: List[MeasureCandidate],  # pylint: disable = unused-argument
         ) -> List[np.ndarray]:
             return [np.random.rand(4, 5)]
@@ -43,7 +43,7 @@ def test_meta_schedule_feature_extractor_as_string():
     class NotSoFancyFeatureExtractor(PyFeatureExtractor):
         def extract_from(
             self,
-            tune_context: TuneContext,  # pylint: disable = unused-argument
+            context: TuneContext,  # pylint: disable = unused-argument
             candidates: List[MeasureCandidate],  # pylint: disable = unused-argument
         ) -> List[np.ndarray]:
             return []
diff --git a/tests/python/unittest/test_meta_schedule_post_order_apply.py b/tests/python/unittest/test_meta_schedule_post_order_apply.py
index b78e678..78477e6 100644
--- a/tests/python/unittest/test_meta_schedule_post_order_apply.py
+++ b/tests/python/unittest/test_meta_schedule_post_order_apply.py
@@ -135,7 +135,7 @@ def _check_correct(schedule: Schedule):
 
 
 class WowSoFancyScheduleRule(PyScheduleRule):
-    def initialize_with_tune_context(self, tune_context: "TuneContext") -> None:
+    def initialize_with_tune_context(self, context: "TuneContext") -> None:
         pass
 
     def apply(self, sch: Schedule, block: BlockRV) -> List[Schedule]:
@@ -151,7 +151,7 @@ class WowSoFancyScheduleRule(PyScheduleRule):
 
 
 class DoubleScheduleRule(PyScheduleRule):
-    def initialize_with_tune_context(self, tune_context: "TuneContext") -> None:
+    def initialize_with_tune_context(self, context: "TuneContext") -> None:
         pass
 
     def apply(self, sch: Schedule, block: BlockRV) -> List[Schedule]:
@@ -175,7 +175,7 @@ class DoubleScheduleRule(PyScheduleRule):
 
 
 class ReorderScheduleRule(PyScheduleRule):
-    def initialize_with_tune_context(self, tune_context: "TuneContext") -> None:
+    def initialize_with_tune_context(self, context: "TuneContext") -> None:
         pass
 
     def apply(self, sch: Schedule, block: BlockRV) -> List[Schedule]:
@@ -262,7 +262,7 @@ def test_meta_schedule_post_order_apply_duplicate_matmul():
 
 def test_meta_schedule_post_order_apply_remove_block():
     class TrinityDouble(PyScheduleRule):
-        def initialize_with_tune_context(self, tune_context: "TuneContext") -> None:
+        def initialize_with_tune_context(self, context: "TuneContext") -> None:
             pass
 
         def apply(self, sch: Schedule, block: BlockRV) -> List[Schedule]:
@@ -283,7 +283,7 @@ def test_meta_schedule_post_order_apply_remove_block():
             return result
 
     class RemoveBlock(PyScheduleRule):
-        def initialize_with_tune_context(self, tune_context: "TuneContext") -> None:
+        def initialize_with_tune_context(self, context: "TuneContext") -> None:
             pass
 
         def apply(self, sch: Schedule, block: BlockRV) -> List[Schedule]:
diff --git a/tests/python/unittest/test_meta_schedule_search_strategy.py b/tests/python/unittest/test_meta_schedule_search_strategy.py
index 9b3ddfd..668fca9 100644
--- a/tests/python/unittest/test_meta_schedule_search_strategy.py
+++ b/tests/python/unittest/test_meta_schedule_search_strategy.py
@@ -16,25 +16,25 @@
 # under the License.
 """ Test Meta Schedule SearchStrategy """
 # pylint: disable=missing-function-docstring
-from typing import List
-
 import sys
-
+from typing import List
+
 import pytest
 
 import tvm
 from tvm.meta_schedule import TuneContext
 from tvm.meta_schedule.runner import RunnerResult
+from tvm.meta_schedule.search_strategy import (
+    ReplayTrace,
+    SearchStrategy,
+)
 from tvm.meta_schedule.space_generator import ScheduleFn
-from tvm.meta_schedule.search_strategy import ReplayTrace
-
 from tvm.script import tir as T
 from tvm.tir.schedule import Schedule, Trace
 
 
 MATMUL_M = 32
 
-# pylint: disable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument, unbalanced-tuple-unpacking
+# pylint: disable=missing-class-docstring,invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument, unbalanced-tuple-unpacking
 # fmt: off
 
 @tvm.script.ir_module
@@ -53,48 +53,57 @@ class Matmul:
                 C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]
 
 # fmt: on
-# pylint: enable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument
+# pylint: enable=missing-class-docstring,invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument
 
 
-def _is_trace_equal(sch_1: Schedule, sch_2: Schedule) -> bool:
-    trace_1 = Trace(sch_1.trace.insts, {})
-    trace_2 = Trace(sch_2.trace.insts, {})
+def _is_trace_equal(sch_1: Schedule, sch_2: Schedule, remove_decisions=True) -> bool:
+    if remove_decisions:
+        trace_1 = Trace(sch_1.trace.insts, {})
+        trace_2 = Trace(sch_2.trace.insts, {})
+    else:
+        trace_1 = sch_1.trace
+        trace_2 = sch_2.trace
     return str(trace_1) == str(trace_2)
 
 
 def _schedule_matmul(sch: Schedule):
     block = sch.get_block("matmul")
     i, j, k = sch.get_loops(block=block)
-    # TODO(@zxybazh): Change to `sample_perfect_tile` after upstreaming
-    i_0, i_1, i_2, i_3 = sch.split(loop=i, factors=[2, 4, 64, 2])
-    j_0, j_1, j_2, j_3 = sch.split(loop=j, factors=[4, 64, 2, 2])
-    k_0, k_1 = sch.split(loop=k, factors=[32, 32])
+    i_0, i_1, i_2, i_3 = sch.split(i, sch.sample_perfect_tile(i, n=4))
+    j_0, j_1, j_2, j_3 = sch.split(j, sch.sample_perfect_tile(j, n=4))
+    k_0, k_1 = sch.split(k, sch.sample_perfect_tile(k, n=2))
     sch.reorder(i_0, j_0, i_1, j_1, k_0, i_2, j_2, k_1, i_3, j_3)
 
 
-def test_meta_schedule_replay_trace():
+@pytest.mark.parametrize("TestClass", [ReplayTrace])
+def test_meta_schedule_replay_func(TestClass: SearchStrategy):  # pylint: disable = invalid-name
     num_trials_per_iter = 7
     num_trials_total = 20
 
-    (example_sch,) = ScheduleFn(sch_fn=_schedule_matmul).generate_design_space(Matmul)
-    replay = ReplayTrace(num_trials_per_iter=num_trials_per_iter, num_trials_total=num_trials_total)
-    tune_context = TuneContext(mod=Matmul)
-    replay.initialize_with_tune_context(tune_context)
-
-    num_trials_each_round: List[int] = []
-    replay.pre_tuning([example_sch])
-    while True:
-        candidates = replay.generate_measure_candidates()
-        if candidates is None:
-            break
-        num_trials_each_round.append(len(candidates))
+    strategy = TestClass(num_trials_per_iter=num_trials_per_iter, num_trials_total=num_trials_total)
+    context = TuneContext(mod=Matmul, space_generator=ScheduleFn(sch_fn=_schedule_matmul))
+    context.space_generator.initialize_with_tune_context(context)
+    spaces = context.space_generator.generate_design_space(context.mod)
+
+    strategy.initialize_with_tune_context(context)
+    strategy.pre_tuning(spaces)
+    (correct_sch,) = ScheduleFn(sch_fn=_schedule_matmul).generate_design_space(Matmul)
+    num_trials_each_iter: List[int] = []
+    candidates = strategy.generate_measure_candidates()
+    while candidates is not None:
+        num_trials_each_iter.append(len(candidates))
         runner_results: List[RunnerResult] = []
         for candidate in candidates:
-            assert _is_trace_equal(candidate.sch, example_sch)
-            runner_results.append(RunnerResult(run_secs=[0.5, 0.4, 0.3], error_msg=None))
-        replay.notify_runner_results(runner_results)
-    replay.post_tuning()
-    assert num_trials_each_round == [7, 7, 6]
+            assert _is_trace_equal(
+                candidate.sch,
+                correct_sch,
+                remove_decisions=(isinstance(strategy, ReplayTrace)),
+            )
+            runner_results.append(RunnerResult(run_secs=[0.11, 0.41, 0.54], error_msg=None))
+        strategy.notify_runner_results(context, candidates, runner_results)
+        candidates = strategy.generate_measure_candidates()
+    strategy.post_tuning()
+    assert num_trials_each_iter == [7, 7, 6]
 
 
 if __name__ == "__main__":
diff --git a/tests/python/unittest/test_meta_schedule_task_scheduler.py b/tests/python/unittest/test_meta_schedule_task_scheduler.py
index 7eb61ad..d3c4dbc 100644
--- a/tests/python/unittest/test_meta_schedule_task_scheduler.py
+++ b/tests/python/unittest/test_meta_schedule_task_scheduler.py
@@ -16,24 +16,22 @@
 # under the License.
 """ Test Meta Schedule Task Scheduler """
 
-from typing import List
-
-import sys
 import random
+import sys
+from typing import List
 
 import pytest
-
 import tvm
-from tvm.script import tir as T
 from tvm.ir import IRModule
-from tvm.tir import Schedule
-from tvm.meta_schedule import TuneContext
-from tvm.meta_schedule.space_generator import ScheduleFn
-from tvm.meta_schedule.search_strategy import ReplayTrace
-from tvm.meta_schedule.builder import PyBuilder, BuilderInput, BuilderResult
-from tvm.meta_schedule.runner import PyRunner, RunnerInput, RunnerFuture, RunnerResult
+from tvm.meta_schedule import TuneContext, measure_callback
+from tvm.meta_schedule.builder import BuilderInput, BuilderResult, PyBuilder
 from tvm.meta_schedule.database import PyDatabase, TuningRecord, Workload
-from tvm.meta_schedule.task_scheduler import RoundRobin, PyTaskScheduler
+from tvm.meta_schedule.runner import PyRunner, RunnerFuture, RunnerInput, RunnerResult
+from tvm.meta_schedule.search_strategy import ReplayTrace
+from tvm.meta_schedule.space_generator import ScheduleFn
+from tvm.meta_schedule.task_scheduler import PyTaskScheduler, RoundRobin
+from tvm.script import tir as T
+from tvm.tir import Schedule
 
 
 # pylint: disable=invalid-name,no-member,line-too-long,too-many-nested-blocks,missing-docstring
@@ -140,7 +138,10 @@ class DummyDatabase(PyDatabase):
         self.records = []
         self.workload_reg = []
 
     def has_workload(self, mod: IRModule) -> bool:
+        for workload in self.workload_reg:
+            if tvm.ir.structural_equal(workload.mod, mod):
+                return True
         return False
 
     def commit_tuning_record(self, record: TuningRecord) -> None:
@@ -183,7 +184,13 @@ def test_meta_schedule_task_scheduler_single():
         rand_state=42,
     )
     database = DummyDatabase()
-    round_robin = RoundRobin([task], DummyBuilder(), DummyRunner(), database)
+    round_robin = RoundRobin(
+        [task],
+        DummyBuilder(),
+        DummyRunner(),
+        database,
+        measure_callbacks=[measure_callback.AddToDatabase()],
+    )
     round_robin.tune()
     assert len(database) == num_trials_total
 
@@ -218,15 +225,29 @@ def test_meta_schedule_task_scheduler_multiple():
         ),
     ]
     database = DummyDatabase()
-    round_robin = RoundRobin(tasks, DummyBuilder(), DummyRunner(), database)
+    round_robin = RoundRobin(
+        tasks,
+        DummyBuilder(),
+        DummyRunner(),
+        database,
+        measure_callbacks=[measure_callback.AddToDatabase()],
+    )
     round_robin.tune()
     assert len(database) == num_trials_total * len(tasks)
     print(database.workload_reg)
     for task in tasks:
-        assert len(database.get_top_k(database.commit_workload(task.mod), 1e9)) == num_trials_total
+        assert (
+            len(
+                database.get_top_k(
+                    database.commit_workload(task.mod),
+                    100000,
+                )
+            )
+            == num_trials_total
+        )
 
 
-def test_meta_schedule_task_scheduler_NIE():
+def test_meta_schedule_task_scheduler_not_implemented_error():  # pylint: disable=invalid-name
     class MyTaskScheduler(PyTaskScheduler):
         pass
 
@@ -234,7 +255,7 @@ def test_meta_schedule_task_scheduler_NIE():
         MyTaskScheduler([], DummyBuilder(), DummyRunner(), DummyDatabase())
 
 
-def test_meta_schedule_task_scheduler_override_next_task_id_only():
+def test_meta_schedule_task_scheduler_override_next_task_id_only():  # pylint: disable=invalid-name
     class MyTaskScheduler(PyTaskScheduler):
         done = set()
 
@@ -291,11 +312,27 @@ def test_meta_schedule_task_scheduler_override_next_task_id_only():
         ),
     ]
     database = DummyDatabase()
-    scheduler = MyTaskScheduler(tasks, DummyBuilder(), DummyRunner(), database)
+    scheduler = MyTaskScheduler(
+        tasks,
+        DummyBuilder(),
+        DummyRunner(),
+        database,
+        measure_callbacks=[
+            measure_callback.AddToDatabase(),
+        ],
+    )
     scheduler.tune()
     assert len(database) == num_trials_total * len(tasks)
     for task in tasks:
-        assert len(database.get_top_k(database.commit_workload(task.mod), 1e9)) == num_trials_total
+        assert (
+            len(
+                database.get_top_k(
+                    database.commit_workload(task.mod),
+                    100000,
+                )
+            )
+            == num_trials_total
+        )
 
 
 if __name__ == "__main__":