You are viewing a plain-text version of this content; the link to its canonical version is not preserved in this format.
Posted to commits@tvm.apache.org by ju...@apache.org on 2022/09/16 05:01:40 UTC

[tvm] branch main updated: [MetaSchedule] Enable Clone Function for Task-Level Classes (#12796)

This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 02c2eae510 [MetaSchedule] Enable Clone Function for Task-Level Classes (#12796)
02c2eae510 is described below

commit 02c2eae510d6d6c15189427c97819f7ce05f002d
Author: Xiyou Zhou <xi...@octoml.ai>
AuthorDate: Thu Sep 15 22:01:33 2022 -0700

    [MetaSchedule] Enable Clone Function for Task-Level Classes (#12796)
    
    This PR introduces a clone function for each of the task-level MetaSchedule classes to make deep-copying their instances convenient.
    
    - [x] ScheduleRule
    - [x] Postproc
    - [x] Mutator
    - [x] SpaceGenerator
    - [x] SearchStrategy
    - [x] TuneContext
---
 include/tvm/meta_schedule/mutator.h                |  88 ++++++++++------
 include/tvm/meta_schedule/postproc.h               |  86 ++++++++++------
 include/tvm/meta_schedule/schedule_rule.h          |  86 ++++++++++------
 include/tvm/meta_schedule/search_strategy.h        | 114 +++++++++++++--------
 include/tvm/meta_schedule/space_generator.h        |  78 ++++++++------
 include/tvm/meta_schedule/tune_context.h           |   6 ++
 python/tvm/meta_schedule/mutator/mutator.py        |  24 ++++-
 python/tvm/meta_schedule/postproc/postproc.py      |  24 ++++-
 .../meta_schedule/schedule_rule/schedule_rule.py   |  32 ++++--
 .../search_strategy/search_strategy.py             |  23 +++++
 .../space_generator/space_generator.py             |  24 ++++-
 python/tvm/meta_schedule/testing/dummy_object.py   |   3 +
 python/tvm/meta_schedule/tune_context.py           |  10 ++
 .../mutator/mutate_compute_location.cc             |   5 +
 src/meta_schedule/mutator/mutate_parallel.cc       |   5 +
 src/meta_schedule/mutator/mutate_thread_binding.cc |   5 +
 src/meta_schedule/mutator/mutate_tile_size.cc      |   5 +
 src/meta_schedule/mutator/mutate_unroll.cc         |   5 +
 src/meta_schedule/mutator/mutator.cc               |   8 ++
 .../postproc/disallow_dynamic_loop.cc              |   5 +
 src/meta_schedule/postproc/postproc.cc             |   8 ++
 .../postproc/rewrite_cooperative_fetch.cc          |   5 +
 src/meta_schedule/postproc/rewrite_layout.cc       |   5 +
 .../postproc/rewrite_parallel_vectorize_unroll.cc  |   6 ++
 .../postproc/rewrite_reduction_block.cc            |   5 +
 src/meta_schedule/postproc/rewrite_tensorize.cc    |   5 +
 .../postproc/rewrite_unbound_block.cc              |   5 +
 src/meta_schedule/postproc/verify_gpu_code.cc      |   6 ++
 src/meta_schedule/schedule_rule/add_rfactor.cc     |   6 ++
 src/meta_schedule/schedule_rule/auto_bind.cc       |   6 ++
 src/meta_schedule/schedule_rule/auto_inline.cc     |   6 ++
 .../schedule_rule/cross_thread_reduction.cc        |   6 ++
 .../schedule_rule/multi_level_tiling.cc            |   6 ++
 .../schedule_rule/multi_level_tiling.h             |   3 +
 .../multi_level_tiling_tensor_core.cc              |   7 ++
 .../multi_level_tiling_with_intrin.cc              |   7 ++
 .../schedule_rule/parallel_vectorize_unroll.cc     |   7 ++
 .../schedule_rule/random_compute_location.cc       |   6 ++
 src/meta_schedule/schedule_rule/schedule_rule.cc   |   9 ++
 .../search_strategy/evolutionary_search.cc         |  18 ++++
 src/meta_schedule/search_strategy/replay_func.cc   |  10 ++
 src/meta_schedule/search_strategy/replay_trace.cc  |  11 ++
 .../search_strategy/search_strategy.cc             |  11 +-
 .../space_generator/post_order_apply.cc            |   9 ++
 src/meta_schedule/space_generator/schedule_fn.cc   |   5 +
 .../space_generator/space_generator.cc             |  12 ++-
 .../space_generator/space_generator_union.cc       |   9 ++
 src/meta_schedule/tune_context.cc                  |  26 +++++
 48 files changed, 675 insertions(+), 186 deletions(-)

diff --git a/include/tvm/meta_schedule/mutator.h b/include/tvm/meta_schedule/mutator.h
index 566cc82e97..2b580e75e0 100644
--- a/include/tvm/meta_schedule/mutator.h
+++ b/include/tvm/meta_schedule/mutator.h
@@ -32,6 +32,7 @@ namespace tvm {
 namespace meta_schedule {
 
 class TuneContext;
+class Mutator;
 
 /*! \brief Mutator is designed to mutate the trace to explore the design space. */
 class MutatorNode : public runtime::Object {
@@ -57,12 +58,21 @@ class MutatorNode : public runtime::Object {
   virtual Optional<tir::Trace> Apply(const tir::Trace& trace,
                                      support::LinearCongruentialEngine::TRandState* rand_state) = 0;
 
+  /*!
+   * \brief Clone the mutator.
+   * \return The cloned mutator.
+   */
+  virtual Mutator Clone() const = 0;
+
   static constexpr const char* _type_key = "meta_schedule.Mutator";
   TVM_DECLARE_BASE_OBJECT_INFO(MutatorNode, Object);
 };
 
-/*! \brief The mutator with customized methods on the python-side. */
-class PyMutatorNode : public MutatorNode {
+/*!
+ * \brief Managed reference to MutatorNode
+ * \sa MutatorNode
+ */
+class Mutator : public runtime::ObjectRef {
  public:
   /*!
    * \brief The function type of `InitializeWithTuneContext` method.
@@ -76,39 +86,16 @@ class PyMutatorNode : public MutatorNode {
    */
   using FApply = runtime::TypedPackedFunc<Optional<tir::Trace>(
       const tir::Trace&, support::LinearCongruentialEngine::TRandState rand_state)>;
+  /*!
+   * \brief The function type of `Clone` method.
+   * \return The cloned mutator.
+   */
+  using FClone = runtime::TypedPackedFunc<Mutator()>;
   /*!
    * \brief Get the mutator as string with name.
    * \return The string of the mutator.
    */
   using FAsString = runtime::TypedPackedFunc<String()>;
-
-  /*! \brief The packed function to the `InitializeWithTuneContext` function. */
-  FInitializeWithTuneContext f_initialize_with_tune_context;
-  /*! \brief The packed function to the `Apply` function. */
-  FApply f_apply;
-  /*! \brief The packed function to the `AsString` function. */
-  FAsString f_as_string;
-
-  void VisitAttrs(tvm::AttrVisitor* v) {
-    // `f_initialize_with_tune_context` is not visited
-    // `f_apply` is not visited
-    // `f_as_string` is not visited
-  }
-
-  void InitializeWithTuneContext(const TuneContext& context) final;
-  Optional<tir::Trace> Apply(const tir::Trace& trace,
-                             support::LinearCongruentialEngine::TRandState* rand_state) final;
-
-  static constexpr const char* _type_key = "meta_schedule.PyMutator";
-  TVM_DECLARE_FINAL_OBJECT_INFO(PyMutatorNode, MutatorNode);
-};
-
-/*!
- * \brief Managed reference to MutatorNode
- * \sa MutatorNode
- */
-class Mutator : public runtime::ObjectRef {
- public:
   /*! \brief Create a Mutator that mutates the decision of instruction Sample-Perfect-Tile */
   TVM_DLL static Mutator MutateTileSize();
   /*!
@@ -136,16 +123,49 @@ class Mutator : public runtime::ObjectRef {
    * \brief Create a mutator with customized methods on the python-side.
    * \param f_initialize_with_tune_context The packed function of `InitializeWithTuneContext`.
    * \param f_apply The packed function of `Apply`.
+   * \param f_clone The packed function of `Clone`.
    * \param f_as_string The packed function of `AsString`.
    * \return The mutator created.
    */
-  TVM_DLL static Mutator PyMutator(
-      PyMutatorNode::FInitializeWithTuneContext f_initialize_with_tune_context,  //
-      PyMutatorNode::FApply f_apply,                                             //
-      PyMutatorNode::FAsString f_as_string);
+  TVM_DLL static Mutator PyMutator(FInitializeWithTuneContext f_initialize_with_tune_context,  //
+                                   FApply f_apply,                                             //
+                                   FClone f_clone,                                             //
+                                   FAsString f_as_string);
   TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(Mutator, ObjectRef, MutatorNode);
 };
 
+/*! \brief The mutator with customized methods on the python-side. */
+class PyMutatorNode : public MutatorNode {
+ public:
+  using FInitializeWithTuneContext = Mutator::FInitializeWithTuneContext;
+  using FApply = Mutator::FApply;
+  using FClone = Mutator::FClone;
+  using FAsString = Mutator::FAsString;
+  /*! \brief The packed function to the `InitializeWithTuneContext` function. */
+  FInitializeWithTuneContext f_initialize_with_tune_context;
+  /*! \brief The packed function to the `Apply` function. */
+  FApply f_apply;
+  /*! \brief The packed function to the `Clone` function. */
+  FClone f_clone;
+  /*! \brief The packed function to the `AsString` function. */
+  FAsString f_as_string;
+
+  void VisitAttrs(tvm::AttrVisitor* v) {
+    // `f_initialize_with_tune_context` is not visited
+    // `f_apply` is not visited
+    // `f_clone` is not visited
+    // `f_as_string` is not visited
+  }
+
+  void InitializeWithTuneContext(const TuneContext& context) final;
+  Optional<tir::Trace> Apply(const tir::Trace& trace,
+                             support::LinearCongruentialEngine::TRandState* rand_state) final;
+  Mutator Clone() const final;
+
+  static constexpr const char* _type_key = "meta_schedule.PyMutator";
+  TVM_DECLARE_FINAL_OBJECT_INFO(PyMutatorNode, MutatorNode);
+};
+
 }  // namespace meta_schedule
 }  // namespace tvm
 
diff --git a/include/tvm/meta_schedule/postproc.h b/include/tvm/meta_schedule/postproc.h
index 5d99f68454..4fafb95576 100644
--- a/include/tvm/meta_schedule/postproc.h
+++ b/include/tvm/meta_schedule/postproc.h
@@ -29,6 +29,7 @@ namespace tvm {
 namespace meta_schedule {
 
 class TuneContext;
+class Postproc;
 
 /*!
  * \brief Rules to apply a postprocessor to a schedule.
@@ -54,12 +55,21 @@ class PostprocNode : public runtime::Object {
    */
   virtual bool Apply(const tir::Schedule& sch) = 0;
 
+  /*!
+   * \brief Clone the postprocessor.
+   * \return The cloned postprocessor.
+   */
+  virtual Postproc Clone() const = 0;
+
   static constexpr const char* _type_key = "meta_schedule.Postproc";
   TVM_DECLARE_BASE_OBJECT_INFO(PostprocNode, Object);
 };
 
-/*! \brief The postprocessor with customized methods on the python-side. */
-class PyPostprocNode : public PostprocNode {
+/*!
+ * \brief Managed reference to PostprocNode
+ * \sa PostprocNode
+ */
+class Postproc : public runtime::ObjectRef {
  public:
   /*!
    * \brief The function type of `InitializeWithTuneContext` method.
@@ -72,49 +82,28 @@ class PyPostprocNode : public PostprocNode {
    * \return Whether the postprocessor was successfully applied.
    */
   using FApply = runtime::TypedPackedFunc<bool(const tir::Schedule&)>;
+  /*!
+   * \brief Clone the postprocessor.
+   * \return The cloned postprocessor.
+   */
+  using FClone = runtime::TypedPackedFunc<Postproc()>;
   /*!
    * \brief Get the postprocessor function as string with name.
    * \return The string of the postprocessor function.
    */
   using FAsString = runtime::TypedPackedFunc<String()>;
-
-  /*! \brief The packed function to the `InitializeWithTuneContext` function. */
-  FInitializeWithTuneContext f_initialize_with_tune_context;
-  /*! \brief The packed function to the `Apply` function. */
-  FApply f_apply;
-  /*! \brief The packed function to the `AsString` function. */
-  FAsString f_as_string;
-
-  void VisitAttrs(tvm::AttrVisitor* v) {
-    // `f_initialize_with_tune_context` is not visited
-    // `f_apply` is not visited
-    // `f_as_string` is not visited
-  }
-
-  void InitializeWithTuneContext(const TuneContext& context) final;
-  bool Apply(const tir::Schedule& sch) final;
-
-  static constexpr const char* _type_key = "meta_schedule.PyPostproc";
-  TVM_DECLARE_FINAL_OBJECT_INFO(PyPostprocNode, PostprocNode);
-};
-
-/*!
- * \brief Managed reference to PostprocNode
- * \sa PostprocNode
- */
-class Postproc : public runtime::ObjectRef {
- public:
   /*!
    * \brief Create a postprocessor with customized methods on the python-side.
    * \param f_initialize_with_tune_context The packed function of `InitializeWithTuneContext`.
    * \param f_apply The packed function of `Apply`.
+   * \param f_clone The packed function of `Clone`.
    * \param f_as_string The packed function of `AsString`.
    * \return The postprocessor created.
    */
-  TVM_DLL static Postproc PyPostproc(
-      PyPostprocNode::FInitializeWithTuneContext f_initialize_with_tune_context,  //
-      PyPostprocNode::FApply f_apply,                                             //
-      PyPostprocNode::FAsString f_as_string);
+  TVM_DLL static Postproc PyPostproc(FInitializeWithTuneContext f_initialize_with_tune_context,  //
+                                     FApply f_apply,                                             //
+                                     FClone f_clone,                                             //
+                                     FAsString f_as_string);
   /*!
    * \brief Create a postprocessor that checks if all loops are static
    * \return The postprocessor created
@@ -164,6 +153,37 @@ class Postproc : public runtime::ObjectRef {
   TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(Postproc, ObjectRef, PostprocNode);
 };
 
+/*! \brief The postprocessor with customized methods on the python-side. */
+class PyPostprocNode : public PostprocNode {
+ public:
+  using FInitializeWithTuneContext = Postproc::FInitializeWithTuneContext;
+  using FApply = Postproc::FApply;
+  using FClone = Postproc::FClone;
+  using FAsString = Postproc::FAsString;
+  /*! \brief The packed function to the `InitializeWithTuneContext` function. */
+  FInitializeWithTuneContext f_initialize_with_tune_context;
+  /*! \brief The packed function to the `Apply` function. */
+  FApply f_apply;
+  /*! \brief The packed function to the `Clone` function. */
+  FClone f_clone;
+  /*! \brief The packed function to the `AsString` function. */
+  FAsString f_as_string;
+
+  void VisitAttrs(tvm::AttrVisitor* v) {
+    // `f_initialize_with_tune_context` is not visited
+    // `f_apply` is not visited
+    // `f_clone` is not visited
+    // `f_as_string` is not visited
+  }
+
+  void InitializeWithTuneContext(const TuneContext& context) final;
+  bool Apply(const tir::Schedule& sch) final;
+  Postproc Clone() const final;
+
+  static constexpr const char* _type_key = "meta_schedule.PyPostproc";
+  TVM_DECLARE_FINAL_OBJECT_INFO(PyPostprocNode, PostprocNode);
+};
+
 }  // namespace meta_schedule
 }  // namespace tvm
 
diff --git a/include/tvm/meta_schedule/schedule_rule.h b/include/tvm/meta_schedule/schedule_rule.h
index 2da441c95e..55704cf4a9 100644
--- a/include/tvm/meta_schedule/schedule_rule.h
+++ b/include/tvm/meta_schedule/schedule_rule.h
@@ -34,6 +34,7 @@ namespace tvm {
 namespace meta_schedule {
 
 class TuneContext;
+class ScheduleRule;
 
 /*! \brief Rules to modify a block in a schedule. */
 class ScheduleRuleNode : public runtime::Object {
@@ -59,12 +60,21 @@ class ScheduleRuleNode : public runtime::Object {
   virtual runtime::Array<tir::Schedule> Apply(const tir::Schedule& sch,
                                               const tir::BlockRV& block) = 0;
 
+  /*!
+   * \brief Deep clone the schedule rule.
+   * \return The cloned schedule rule.
+   */
+  virtual ScheduleRule Clone() const = 0;
+
   static constexpr const char* _type_key = "meta_schedule.ScheduleRule";
   TVM_DECLARE_BASE_OBJECT_INFO(ScheduleRuleNode, Object);
 };
 
-/*! \brief The schedule rule with customized methods on the python-side. */
-class PyScheduleRuleNode : public ScheduleRuleNode {
+/*!
+ * \brief Managed reference to ScheduleRuleNode
+ * \sa ScheduleRuleNode
+ */
+class ScheduleRule : public runtime::ObjectRef {
  public:
   /*!
    * \brief The function type of `InitializeWithTuneContext` method.
@@ -84,33 +94,11 @@ class PyScheduleRuleNode : public ScheduleRuleNode {
    * \return The string of the schedule rule.
    */
   using FAsString = runtime::TypedPackedFunc<String()>;
-
-  /*! \brief The packed function to the `InitializeWithTuneContext` function. */
-  FInitializeWithTuneContext f_initialize_with_tune_context;
-  /*! \brief The packed function to the `Apply` function. */
-  FApply f_apply;
-  /*! \brief The packed function to the `AsString` function. */
-  FAsString f_as_string;
-
-  void VisitAttrs(tvm::AttrVisitor* v) {
-    // `f_initialize_with_tune_context` is not visited
-    // `f_apply` is not visited
-    // `f_as_string` is not visited
-  }
-
-  void InitializeWithTuneContext(const TuneContext& context) final;
-  Array<tir::Schedule> Apply(const tir::Schedule& sch, const tir::BlockRV& block) final;
-
-  static constexpr const char* _type_key = "meta_schedule.PyScheduleRule";
-  TVM_DECLARE_FINAL_OBJECT_INFO(PyScheduleRuleNode, ScheduleRuleNode);
-};
-
-/*!
- * \brief Managed reference to ScheduleRuleNode
- * \sa ScheduleRuleNode
- */
-class ScheduleRule : public runtime::ObjectRef {
- public:
+  /*!
+   * \brief The function type of `Clone` method.
+   * \return The cloned schedule rule.
+   */
+  using FClone = runtime::TypedPackedFunc<ScheduleRule()>;
   /*!
    * \brief Create an auto-inline rule that inlines spatial blocks if it satisfies some conditions
    * \param into_producer If allows to inline a block into its producer
@@ -249,16 +237,50 @@ class ScheduleRule : public runtime::ObjectRef {
    * \brief Create a schedule rule with customized methods on the python-side.
    * \param f_initialize_with_tune_context The packed function of `InitializeWithTuneContext`.
    * \param f_apply The packed function of `Apply`.
+   * \param f_clone The packed function of `Clone`.
    * \param f_as_string The packed function of `AsString`.
    * \return The schedule rule created.
    */
   TVM_DLL static ScheduleRule PyScheduleRule(
-      PyScheduleRuleNode::FInitializeWithTuneContext f_initialize_with_tune_context,  //
-      PyScheduleRuleNode::FApply f_apply,                                             //
-      PyScheduleRuleNode::FAsString f_as_string);
+      FInitializeWithTuneContext f_initialize_with_tune_context,  //
+      FApply f_apply,                                             //
+      FClone f_clone,                                             //
+      FAsString f_as_string);
   TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(ScheduleRule, ObjectRef, ScheduleRuleNode);
 };
 
+/*! \brief The schedule rule with customized methods on the python-side. */
+class PyScheduleRuleNode : public ScheduleRuleNode {
+ public:
+  using FInitializeWithTuneContext = ScheduleRule::FInitializeWithTuneContext;
+  using FApply = ScheduleRule::FApply;
+  using FClone = ScheduleRule::FClone;
+  using FAsString = ScheduleRule::FAsString;
+
+  /*! \brief The packed function to the `InitializeWithTuneContext` function. */
+  FInitializeWithTuneContext f_initialize_with_tune_context;
+  /*! \brief The packed function to the `Apply` function. */
+  FApply f_apply;
+  /*! \brief The packed function to the `AsString` function. */
+  FAsString f_as_string;
+  /*! \brief The packed function to the `Clone` function. */
+  FClone f_clone;
+
+  void VisitAttrs(tvm::AttrVisitor* v) {
+    // `f_initialize_with_tune_context` is not visited
+    // `f_apply` is not visited
+    // `f_as_string` is not visited
+    // `f_clone` is not visited
+  }
+
+  void InitializeWithTuneContext(const TuneContext& context) final;
+  Array<tir::Schedule> Apply(const tir::Schedule& sch, const tir::BlockRV& block) final;
+  ScheduleRule Clone() const final;
+
+  static constexpr const char* _type_key = "meta_schedule.PyScheduleRule";
+  TVM_DECLARE_FINAL_OBJECT_INFO(PyScheduleRuleNode, ScheduleRuleNode);
+};
+
 }  // namespace meta_schedule
 }  // namespace tvm
 
diff --git a/include/tvm/meta_schedule/search_strategy.h b/include/tvm/meta_schedule/search_strategy.h
index a75a4cd8ae..efd3dc2452 100644
--- a/include/tvm/meta_schedule/search_strategy.h
+++ b/include/tvm/meta_schedule/search_strategy.h
@@ -36,6 +36,7 @@ namespace meta_schedule {
 
 // Forward declaration
 class TuneContext;
+class SearchStrategy;
 
 /*!
  * \brief The search strategy for measure candidates generation.
@@ -119,12 +120,21 @@ class SearchStrategyNode : public runtime::Object {
   virtual void NotifyRunnerResults(const Array<MeasureCandidate>& measure_candidates,
                                    const Array<RunnerResult>& results) = 0;
 
+  /*!
+   * \brief Clone the search strategy.
+   * \return The cloned search strategy.
+   */
+  virtual SearchStrategy Clone() const = 0;
+
   static constexpr const char* _type_key = "meta_schedule.SearchStrategy";
   TVM_DECLARE_BASE_OBJECT_INFO(SearchStrategyNode, Object);
 };
 
-/*! \brief The python side customizable class for measure candidate generation */
-class PySearchStrategyNode : public SearchStrategyNode {
+/*!
+ * \brief Managed reference to SearchStrategyNode.
+ * \sa SearchStrategyNode
+ */
+class SearchStrategy : public runtime::ObjectRef {
  public:
   /*!
    * \brief The function type of `InitializeWithTuneContext` method.
@@ -150,44 +160,11 @@ class PySearchStrategyNode : public SearchStrategyNode {
    */
   using FNotifyRunnerResults =
       runtime::TypedPackedFunc<void(const Array<MeasureCandidate>&, const Array<RunnerResult>&)>;
-
-  /*! \brief The packed function to the `InitializeWithTuneContext` method. */
-  FInitializeWithTuneContext f_initialize_with_tune_context;
-  /*! \brief The packed function to the `PreTuning` method. */
-  FPreTuning f_pre_tuning;
-  /*! \brief The packed function to the `PostTuning` method. */
-  FPostTuning f_post_tuning;
-  /*! \brief The packed function to the `GenerateMeasureCandidates` method. */
-  FGenerateMeasureCandidates f_generate_measure_candidates;
-  /*! \brief The packed function to the `NotifyRunnerResults` method. */
-  FNotifyRunnerResults f_notify_runner_results;
-
-  void VisitAttrs(tvm::AttrVisitor* v) {
-    // `f_initialize_with_tune_context` is not visited
-    // `f_pre_tuning` is not visited
-    // `f_post_tuning` is not visited
-    // `f_generate_measure_candidates` is not visited
-    // `f_notify_runner_results` is not visited
-  }
-
-  void InitializeWithTuneContext(const TuneContext& context) final;
-  void PreTuning(const Array<tir::Schedule>& design_spaces, const Optional<Database>& database,
-                 const Optional<CostModel>& cost_model) final;
-  void PostTuning() final;
-  Optional<Array<MeasureCandidate>> GenerateMeasureCandidates() final;
-  void NotifyRunnerResults(const Array<MeasureCandidate>& measure_candidates,
-                           const Array<RunnerResult>& results);
-
-  static constexpr const char* _type_key = "meta_schedule.PySearchStrategy";
-  TVM_DECLARE_FINAL_OBJECT_INFO(PySearchStrategyNode, SearchStrategyNode);
-};
-
-/*!
- * \brief Managed reference to SearchStrategyNode.
- * \sa SearchStrategyNode
- */
-class SearchStrategy : public runtime::ObjectRef {
- public:
+  /*!
+   * \brief The function type of `Clone` method.
+   * \return The cloned search strategy.
+   */
+  using FClone = runtime::TypedPackedFunc<SearchStrategy()>;
   /*!
    * \brief Create a search strategy with customized methods on the python-side.
    * \param f_initialize_with_tune_context The packed function of `InitializeWithTuneContext`.
@@ -195,14 +172,16 @@ class SearchStrategy : public runtime::ObjectRef {
    * \param f_post_tuning The packed function of `PostTuning`.
    * \param f_generate_measure_candidates The packed function of `GenerateMeasureCandidates`.
    * \param f_notify_runner_results The packed function of `NotifyRunnerResults`.
+   * \param f_clone The packed function of `Clone`.
    * \return The search strategy created.
    */
   TVM_DLL static SearchStrategy PySearchStrategy(
-      PySearchStrategyNode::FInitializeWithTuneContext f_initialize_with_tune_context,  //
-      PySearchStrategyNode::FPreTuning f_pre_tuning,                                    //
-      PySearchStrategyNode::FPostTuning f_post_tuning,                                  //
-      PySearchStrategyNode::FGenerateMeasureCandidates f_generate_measure_candidates,   //
-      PySearchStrategyNode::FNotifyRunnerResults f_notify_runner_results);
+      FInitializeWithTuneContext f_initialize_with_tune_context,  //
+      FPreTuning f_pre_tuning,                                    //
+      FPostTuning f_post_tuning,                                  //
+      FGenerateMeasureCandidates f_generate_measure_candidates,   //
+      FNotifyRunnerResults f_notify_runner_results,               //
+      FClone f_clone);
 
   /*!
    * \brief Constructor of replay trace search strategy.
@@ -245,6 +224,51 @@ class SearchStrategy : public runtime::ObjectRef {
   TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(SearchStrategy, ObjectRef, SearchStrategyNode);
 };
 
+/*! \brief The python side customizable class for measure candidate generation */
+class PySearchStrategyNode : public SearchStrategyNode {
+ public:
+  using FInitializeWithTuneContext = SearchStrategy::FInitializeWithTuneContext;
+  using FPreTuning = SearchStrategy::FPreTuning;
+  using FPostTuning = SearchStrategy::FPostTuning;
+  using FGenerateMeasureCandidates = SearchStrategy::FGenerateMeasureCandidates;
+  using FNotifyRunnerResults = SearchStrategy::FNotifyRunnerResults;
+  using FClone = SearchStrategy::FClone;
+
+  /*! \brief The packed function to the `InitializeWithTuneContext` method. */
+  FInitializeWithTuneContext f_initialize_with_tune_context;
+  /*! \brief The packed function to the `PreTuning` method. */
+  FPreTuning f_pre_tuning;
+  /*! \brief The packed function to the `PostTuning` method. */
+  FPostTuning f_post_tuning;
+  /*! \brief The packed function to the `GenerateMeasureCandidates` method. */
+  FGenerateMeasureCandidates f_generate_measure_candidates;
+  /*! \brief The packed function to the `NotifyRunnerResults` method. */
+  FNotifyRunnerResults f_notify_runner_results;
+  /*! \brief The packed function to the `Clone` method. */
+  FClone f_clone;
+
+  void VisitAttrs(tvm::AttrVisitor* v) {
+    // `f_initialize_with_tune_context` is not visited
+    // `f_pre_tuning` is not visited
+    // `f_post_tuning` is not visited
+    // `f_generate_measure_candidates` is not visited
+    // `f_notify_runner_results` is not visited
+    // `f_clone` is not visited
+  }
+
+  void InitializeWithTuneContext(const TuneContext& context) final;
+  void PreTuning(const Array<tir::Schedule>& design_spaces, const Optional<Database>& database,
+                 const Optional<CostModel>& cost_model) final;
+  void PostTuning() final;
+  Optional<Array<MeasureCandidate>> GenerateMeasureCandidates() final;
+  void NotifyRunnerResults(const Array<MeasureCandidate>& measure_candidates,
+                           const Array<RunnerResult>& results);
+  SearchStrategy Clone() const final;
+
+  static constexpr const char* _type_key = "meta_schedule.PySearchStrategy";
+  TVM_DECLARE_FINAL_OBJECT_INFO(PySearchStrategyNode, SearchStrategyNode);
+};
+
 }  // namespace meta_schedule
 }  // namespace tvm
 
diff --git a/include/tvm/meta_schedule/space_generator.h b/include/tvm/meta_schedule/space_generator.h
index 2c1b2d4e4d..1e29e757a1 100644
--- a/include/tvm/meta_schedule/space_generator.h
+++ b/include/tvm/meta_schedule/space_generator.h
@@ -31,6 +31,7 @@ namespace meta_schedule {
 
 // Forward declaration
 class TuneContext;
+class SpaceGenerator;
 
 /*!
  * \brief The abstract class for design space generation.
@@ -87,12 +88,21 @@ class SpaceGeneratorNode : public runtime::Object {
    */
   virtual Array<tir::Schedule> GenerateDesignSpace(const IRModule& mod) = 0;
 
+  /*!
+   * \brief Clone the space generator.
+   * \return The cloned space generator.
+   */
+  virtual SpaceGenerator Clone() const = 0;
+
   static constexpr const char* _type_key = "meta_schedule.SpaceGenerator";
   TVM_DECLARE_BASE_OBJECT_INFO(SpaceGeneratorNode, Object);
 };
 
-/*! \brief The design space generator with customized methods on the python-side. */
-class PySpaceGeneratorNode : public SpaceGeneratorNode {
+/*!
+ * \brief Managed reference to SpaceGeneratorNode.
+ * \sa SpaceGeneratorNode
+ */
+class SpaceGenerator : public runtime::ObjectRef {
  public:
   /*!
    * \brief The function type of `InitializeWithTuneContext` method.
@@ -105,29 +115,12 @@ class PySpaceGeneratorNode : public SpaceGeneratorNode {
    * \return The generated design spaces, i.e., schedules.
    */
   using FGenerateDesignSpace = runtime::TypedPackedFunc<Array<tir::Schedule>(const IRModule&)>;
+  /*!
+   * \brief The function type of `Clone` method.
+   * \return The cloned space generator.
+   */
+  using FClone = runtime::TypedPackedFunc<SpaceGenerator()>;
 
-  /*! \brief The packed function to the `InitializeWithTuneContext` function. */
-  FInitializeWithTuneContext f_initialize_with_tune_context;
-  /*! \brief The packed function to the `GenerateDesignSpace` function. */
-  FGenerateDesignSpace f_generate_design_space;
-
-  void VisitAttrs(tvm::AttrVisitor* v) {
-    // `f_initialize_with_tune_context` is not visited
-    // `f_generate_design_space` is not visited
-  }
-
-  void InitializeWithTuneContext(const TuneContext& context) final;
-  Array<tir::Schedule> GenerateDesignSpace(const IRModule& mod) final;
-
-  static constexpr const char* _type_key = "meta_schedule.PySpaceGenerator";
-  TVM_DECLARE_FINAL_OBJECT_INFO(PySpaceGeneratorNode, SpaceGeneratorNode);
-};
-
-/*!
- * \brief Managed reference to SpaceGeneratorNode.
- * \sa SpaceGeneratorNode
- */
-class SpaceGenerator : public runtime::ObjectRef {
  protected:
   SpaceGenerator() = default;
 
@@ -136,11 +129,12 @@ class SpaceGenerator : public runtime::ObjectRef {
    * \brief Create a design space generator with customized methods on the python-side.
    * \param f_initialize_with_tune_context The packed function of `InitializeWithTuneContext`.
    * \param f_generate_design_space The packed function of `GenerateDesignSpace`.
+   * \param f_clone The packed function of `Clone`.
    * \return The design space generator created.
    */
   TVM_DLL static SpaceGenerator PySpaceGenerator(
-      PySpaceGeneratorNode::FInitializeWithTuneContext f_initialize_with_tune_context,
-      PySpaceGeneratorNode::FGenerateDesignSpace f_generate_design_space);
+      FInitializeWithTuneContext f_initialize_with_tune_context,
+      FGenerateDesignSpace f_generate_design_space, FClone f_clone);
   /*!
    * \brief Create a design space generator with customized schedule function.
    * \param schedule_fn The schedule function, which can have the following signatures:
@@ -156,14 +150,40 @@ class SpaceGenerator : public runtime::ObjectRef {
    */
   TVM_DLL static SpaceGenerator SpaceGeneratorUnion(Array<SpaceGenerator, void> space_generators);
   /*!
-   * \brief Create a design space generator that generates design spaces by applying schedule rules
-   *  to blocks in post-DFS order.
-   * \return The design space generator created.
+   * \brief Create a design space generator applying schedule rules to blocks in post-DFS order.
+   * \return The design space generator created.
    */
   TVM_DLL static SpaceGenerator PostOrderApply(runtime::PackedFunc f_block_filter = nullptr);
   TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(SpaceGenerator, ObjectRef, SpaceGeneratorNode);
 };
 
+/*! \brief The design space generator with customized methods on the python-side. */
+class PySpaceGeneratorNode : public SpaceGeneratorNode {
+ public:
+  using FInitializeWithTuneContext = SpaceGenerator::FInitializeWithTuneContext;
+  using FGenerateDesignSpace = SpaceGenerator::FGenerateDesignSpace;
+  using FClone = SpaceGenerator::FClone;
+  /*! \brief The packed function to the `InitializeWithTuneContext` function. */
+  FInitializeWithTuneContext f_initialize_with_tune_context;
+  /*! \brief The packed function to the `GenerateDesignSpace` function. */
+  FGenerateDesignSpace f_generate_design_space;
+  /*! \brief The packed function to the `Clone` function. */
+  FClone f_clone;
+
+  void VisitAttrs(tvm::AttrVisitor* v) {
+    // `f_initialize_with_tune_context` is not visited
+    // `f_generate_design_space` is not visited
+    // `f_clone` is not visited
+  }
+
+  void InitializeWithTuneContext(const TuneContext& context) final;
+  Array<tir::Schedule> GenerateDesignSpace(const IRModule& mod) final;
+  SpaceGenerator Clone() const final;
+
+  static constexpr const char* _type_key = "meta_schedule.PySpaceGenerator";
+  TVM_DECLARE_FINAL_OBJECT_INFO(PySpaceGeneratorNode, SpaceGeneratorNode);
+};
+
 }  // namespace meta_schedule
 }  // namespace tvm
 
diff --git a/include/tvm/meta_schedule/tune_context.h b/include/tvm/meta_schedule/tune_context.h
index 3d732e7fbd..4e2f00fb5a 100644
--- a/include/tvm/meta_schedule/tune_context.h
+++ b/include/tvm/meta_schedule/tune_context.h
@@ -43,6 +43,7 @@ namespace meta_schedule {
 
 class TaskSchedulerNode;
 class MeasureCallback;
+class TuneContext;
 
 /*! \brief The auto tuning context. */
 class TuneContextNode : public runtime::Object {
@@ -99,6 +100,11 @@ class TuneContextNode : public runtime::Object {
 
   /*! \brief Initialize members that needs initialization with tune context. */
   void Initialize();
+  /*!
+   * \brief Clone the tune context.
+   * \return The cloned tune context.
+   */
+  TuneContext Clone() const;
   /*! \brief Set the measure candidates from the SearchStrategy */
   void _SetMeasureCandidates(const Array<MeasureCandidate>& candidates);
   /*!
diff --git a/python/tvm/meta_schedule/mutator/mutator.py b/python/tvm/meta_schedule/mutator/mutator.py
index 0c8de96680..c5286aced7 100644
--- a/python/tvm/meta_schedule/mutator/mutator.py
+++ b/python/tvm/meta_schedule/mutator/mutator.py
@@ -58,6 +58,16 @@ class Mutator(Object):
         """
         return _ffi_api.MutatorApply(self, trace, -1)  # type: ignore # pylint: disable=no-member
 
+    def clone(self) -> "Mutator":
+        """Clone the mutator.
+
+        Returns
+        -------
+        mutator : Mutator
+            The cloned mutator.
+        """
+        return _ffi_api.MutatorClone(self)  # type: ignore # pylint: disable=no-member
+
 
 @register_object("meta_schedule.PyMutator")
 class _PyMutator(Mutator):
@@ -72,6 +82,7 @@ class _PyMutator(Mutator):
         self,
         f_initialize_with_tune_context: Callable = None,
         f_apply: Callable = None,
+        f_clone: Callable = None,
         f_as_string: Callable = None,
     ):
         """Constructor."""
@@ -80,6 +91,7 @@ class _PyMutator(Mutator):
             _ffi_api.MutatorPyMutator,  # type: ignore # pylint: disable=no-member
             f_initialize_with_tune_context,
             f_apply,
+            f_clone,
             f_as_string,
         )
 
@@ -94,7 +106,7 @@ class PyMutator:
 
     _tvm_metadata = {
         "cls": _PyMutator,
-        "methods": ["_initialize_with_tune_context", "apply", "__str__"],
+        "methods": ["_initialize_with_tune_context", "apply", "clone", "__str__"],
     }
 
     def _initialize_with_tune_context(self, context: "TuneContext") -> None:
@@ -122,6 +134,16 @@ class PyMutator:
         """
         raise NotImplementedError
 
+    def clone(self) -> Mutator:
+        """Clone the mutator.
+
+        Returns
+        -------
+        mutator : Mutator
+            The cloned mutator.
+        """
+        raise NotImplementedError
+
     def __str__(self) -> str:
         """Get the mutator as string with name.
 
diff --git a/python/tvm/meta_schedule/postproc/postproc.py b/python/tvm/meta_schedule/postproc/postproc.py
index e37666bd1c..6eec2965ce 100644
--- a/python/tvm/meta_schedule/postproc/postproc.py
+++ b/python/tvm/meta_schedule/postproc/postproc.py
@@ -60,6 +60,16 @@ class Postproc(Object):
         """
         return _ffi_api.PostprocApply(self, sch)  # type: ignore # pylint: disable=no-member
 
+    def clone(self) -> "Postproc":
+        """Clone the postprocessor.
+
+        Returns
+        -------
+        cloned_postproc : Postproc
+            The cloned postprocessor.
+        """
+        return _ffi_api.PostprocClone(self)  # type: ignore # pylint: disable=no-member
+
 
 @register_object("meta_schedule.PyPostproc")
 class _PyPostproc(Postproc):
@@ -74,6 +84,7 @@ class _PyPostproc(Postproc):
         self,
         f_initialize_with_tune_context: Callable = None,
         f_apply: Callable = None,
+        f_clone: Callable = None,
         f_as_string: Callable = None,
     ):
         """Constructor."""
@@ -82,6 +93,7 @@ class _PyPostproc(Postproc):
             _ffi_api.PostprocPyPostproc,  # type: ignore # pylint: disable=no-member
             f_initialize_with_tune_context,
             f_apply,
+            f_clone,
             f_as_string,
         )
 
@@ -96,7 +108,7 @@ class PyPostproc:
 
     _tvm_metadata = {
         "cls": _PyPostproc,
-        "methods": ["_initialize_with_tune_context", "apply", "__str__"],
+        "methods": ["_initialize_with_tune_context", "apply", "clone", "__str__"],
     }
 
     def _initialize_with_tune_context(self, context: "TuneContext") -> None:
@@ -124,6 +136,16 @@ class PyPostproc:
         """
         raise NotImplementedError
 
+    def clone(self) -> Postproc:
+        """Clone the postprocessor.
+
+        Returns
+        -------
+        cloned_postproc : Postproc
+            The cloned postprocessor.
+        """
+        raise NotImplementedError
+
     def __str__(self) -> str:
         """Get the post processor as string with name.
 
diff --git a/python/tvm/meta_schedule/schedule_rule/schedule_rule.py b/python/tvm/meta_schedule/schedule_rule/schedule_rule.py
index 481444341b..2c8e223611 100644
--- a/python/tvm/meta_schedule/schedule_rule/schedule_rule.py
+++ b/python/tvm/meta_schedule/schedule_rule/schedule_rule.py
@@ -66,6 +66,16 @@ class ScheduleRule(Object):
             self, sch, block
         )
 
+    def clone(self) -> "ScheduleRule":
+        """Deep clone the schedule rule.
+
+        Returns
+        -------
+        cloned_rule : ScheduleRule
+            The cloned schedule rule.
+        """
+        return _ffi_api.ScheduleRuleClone(self)  # type: ignore # pylint: disable=no-member
+
 
 @register_object("meta_schedule.PyScheduleRule")
 class _PyScheduleRule(ScheduleRule):
@@ -80,6 +90,7 @@ class _PyScheduleRule(ScheduleRule):
         self,
         f_initialize_with_tune_context: Callable = None,
         f_apply: Callable = None,
+        f_clone: Callable = None,
         f_as_string: Callable = None,
     ):
         """Constructor."""
@@ -88,6 +99,7 @@ class _PyScheduleRule(ScheduleRule):
             _ffi_api.ScheduleRulePyScheduleRule,  # type: ignore # pylint: disable=no-member
             f_initialize_with_tune_context,
             f_apply,
+            f_clone,
             f_as_string,
         )
 
@@ -102,7 +114,7 @@ class PyScheduleRule:
 
     _tvm_metadata = {
         "cls": _PyScheduleRule,
-        "methods": ["_initialize_with_tune_context", "apply", "__str__"],
+        "methods": ["_initialize_with_tune_context", "apply", "clone", "__str__"],
     }
 
     def _initialize_with_tune_context(self, context: "TuneContext") -> None:
@@ -113,9 +125,7 @@ class PyScheduleRule:
         context : TuneContext
             The tuning context for initializing the schedule rule.
         """
-        _ffi_api.ScheduleRuleInitializeWithTuneContext(  # type: ignore # pylint: disable=no-member
-            self, context
-        )
+        raise NotImplementedError
 
     def apply(self, sch: Schedule, block: BlockRV) -> List[Schedule]:
         """Apply a schedule rule to the specific block in the given schedule.
@@ -132,9 +142,17 @@ class PyScheduleRule:
         design_spaces : List[Schedule]
             The list of schedules generated by applying the schedule rule.
         """
-        return _ffi_api.ScheduleRuleApply(  #  type: ignore # pylint: disable=no-member
-            self, sch, block
-        )
+        raise NotImplementedError
+
+    def clone(self) -> ScheduleRule:
+        """Deep clone the schedule rule.
+
+        Returns
+        -------
+        cloned_rule : ScheduleRule
+            The cloned schedule rule.
+        """
+        raise NotImplementedError
 
     def __str__(self) -> str:
         """Get the schedule rule as string with name.
diff --git a/python/tvm/meta_schedule/search_strategy/search_strategy.py b/python/tvm/meta_schedule/search_strategy/search_strategy.py
index e88cdf825a..276e657133 100644
--- a/python/tvm/meta_schedule/search_strategy/search_strategy.py
+++ b/python/tvm/meta_schedule/search_strategy/search_strategy.py
@@ -151,6 +151,16 @@ class SearchStrategy(Object):
             results,
         )
 
+    def clone(self) -> "SearchStrategy":
+        """Clone the search strategy.
+
+        Returns
+        -------
+        cloned : SearchStrategy
+            The cloned search strategy.
+        """
+        return _ffi_api.SearchStrategyClone(self)  # type: ignore # pylint: disable=no-member
+
 
 @register_object("meta_schedule.PySearchStrategy")
 class _PySearchStrategy(SearchStrategy):
@@ -168,6 +178,7 @@ class _PySearchStrategy(SearchStrategy):
         f_post_tuning: Callable = None,
         f_generate_measure_candidates: Callable = None,
         f_notify_runner_results: Callable = None,
+        f_clone: Callable = None,
     ):
         """Constructor."""
 
@@ -178,6 +189,7 @@ class _PySearchStrategy(SearchStrategy):
             f_post_tuning,
             f_generate_measure_candidates,
             f_notify_runner_results,
+            f_clone,
         )
 
 
@@ -197,6 +209,7 @@ class PySearchStrategy:
             "post_tuning",
             "generate_measure_candidates",
             "notify_runner_results",
+            "clone",
         ],
     }
 
@@ -250,6 +263,16 @@ class PySearchStrategy:
         """
         raise NotImplementedError
 
+    def clone(self) -> SearchStrategy:
+        """Clone the search strategy.
+
+        Returns
+        -------
+        strategy : SearchStrategy
+            The cloned search strategy.
+        """
+        raise NotImplementedError
+
 
 def create(  # pylint: disable=keyword-arg-before-vararg
     kind: Literal[
diff --git a/python/tvm/meta_schedule/space_generator/space_generator.py b/python/tvm/meta_schedule/space_generator/space_generator.py
index 9d7ebf3bae..23c0361645 100644
--- a/python/tvm/meta_schedule/space_generator/space_generator.py
+++ b/python/tvm/meta_schedule/space_generator/space_generator.py
@@ -72,6 +72,16 @@ class SpaceGenerator(Object):
         """
         return _ffi_api.SpaceGeneratorGenerateDesignSpace(self, mod)  # type: ignore # pylint: disable=no-member
 
+    def clone(self) -> "SpaceGenerator":
+        """Clone the design space generator.
+
+        Returns
+        -------
+        cloned_sg : SpaceGenerator
+            The cloned design space generator.
+        """
+        return _ffi_api.SpaceGeneratorClone(self)  # type: ignore # pylint: disable=no-member
+
 
 ScheduleFnType = SpaceGenerator.ScheduleFnType
 
@@ -89,6 +99,7 @@ class _PySpaceGenerator(SpaceGenerator):
         self,
         f_initialize_with_tune_context: Optional[Callable] = None,
         f_generate_design_space: Optional[Callable] = None,
+        f_clone: Optional[Callable] = None,
     ):
         """Constructor."""
 
@@ -96,6 +107,7 @@ class _PySpaceGenerator(SpaceGenerator):
             _ffi_api.SpaceGeneratorPySpaceGenerator,  # type: ignore # pylint: disable=no-member
             f_initialize_with_tune_context,
             f_generate_design_space,
+            f_clone,
         )
 
 
@@ -109,7 +121,7 @@ class PySpaceGenerator:
 
     _tvm_metadata = {
         "cls": _PySpaceGenerator,
-        "methods": ["_initialize_with_tune_context", "generate_design_space"],
+        "methods": ["_initialize_with_tune_context", "generate_design_space", "clone"],
     }
 
     def _initialize_with_tune_context(self, context: "TuneContext") -> None:
@@ -137,6 +149,16 @@ class PySpaceGenerator:
         """
         raise NotImplementedError
 
+    def clone(self) -> SpaceGenerator:
+        """Clone the design space generator.
+
+        Returns
+        -------
+        cloned_sg : SpaceGenerator
+            The cloned design space generator.
+        """
+        raise NotImplementedError
+
 
 def create(  # pylint: disable=keyword-arg-before-vararg
     kind: Union[
diff --git a/python/tvm/meta_schedule/testing/dummy_object.py b/python/tvm/meta_schedule/testing/dummy_object.py
index 50ae974df5..bb22945449 100644
--- a/python/tvm/meta_schedule/testing/dummy_object.py
+++ b/python/tvm/meta_schedule/testing/dummy_object.py
@@ -58,3 +58,6 @@ class DummyMutator(PyMutator):
 
     def apply(self, trace: Trace, _) -> Optional[Trace]:
         return Trace(trace.insts, {})
+
+    def clone(self):
+        return DummyMutator()
diff --git a/python/tvm/meta_schedule/tune_context.py b/python/tvm/meta_schedule/tune_context.py
index 17acad8d4a..29cd94110c 100644
--- a/python/tvm/meta_schedule/tune_context.py
+++ b/python/tvm/meta_schedule/tune_context.py
@@ -331,3 +331,13 @@ class TuneContext(Object):
                 "Please construct TuneContext with search_strategy"
             )
         return self.search_strategy.notify_runner_results(measure_candidates, results)
+
+    def clone(self) -> "TuneContext":
+        """Clone the TuneContext.
+
+        Returns
+        -------
+        cloned_context : TuneContext
+            The cloned TuneContext.
+        """
+        return _ffi_api.TuneContextClone(self)  # type: ignore # pylint: disable=no-member
diff --git a/src/meta_schedule/mutator/mutate_compute_location.cc b/src/meta_schedule/mutator/mutate_compute_location.cc
index 9d6d69ba35..2a31d2da9b 100644
--- a/src/meta_schedule/mutator/mutate_compute_location.cc
+++ b/src/meta_schedule/mutator/mutate_compute_location.cc
@@ -42,6 +42,11 @@ class MutateComputeLocationNode : public MutatorNode {
   }
   // Inherit from `MutatorNode`
   Optional<Trace> Apply(const Trace& trace, TRandState* rand_state) final;
+  // Inherit from `MutatorNode`
+  Mutator Clone() const final {
+    ObjectPtr<MutateComputeLocationNode> n = make_object<MutateComputeLocationNode>(*this);
+    return Mutator(n);
+  }
 
  private:
   struct Candidate {
diff --git a/src/meta_schedule/mutator/mutate_parallel.cc b/src/meta_schedule/mutator/mutate_parallel.cc
index 82b91da682..9feb4747d8 100644
--- a/src/meta_schedule/mutator/mutate_parallel.cc
+++ b/src/meta_schedule/mutator/mutate_parallel.cc
@@ -188,6 +188,11 @@ class MutateParallelNode : public MutatorNode {
   }
   // Inherit from `MutatorNode`
   Optional<Trace> Apply(const Trace& trace, TRandState* rand_state) final;
+  // Inherit from `MutatorNode`
+  Mutator Clone() const final {
+    ObjectPtr<MutateParallelNode> n = make_object<MutateParallelNode>(*this);
+    return Mutator(n);
+  }
 };
 
 /*! \brief The candidate to be mutated */
diff --git a/src/meta_schedule/mutator/mutate_thread_binding.cc b/src/meta_schedule/mutator/mutate_thread_binding.cc
index de780b53e2..f5d89a8509 100644
--- a/src/meta_schedule/mutator/mutate_thread_binding.cc
+++ b/src/meta_schedule/mutator/mutate_thread_binding.cc
@@ -42,6 +42,11 @@ class MutateThreadBindingNode : public MutatorNode {
   }
   // Inherit from `MutatorNode`
   Optional<Trace> Apply(const Trace& trace, TRandState* rand_state) final;
+  // Inherit from `MutatorNode`
+  Mutator Clone() const final {
+    ObjectPtr<MutateThreadBindingNode> n = make_object<MutateThreadBindingNode>(*this);
+    return Mutator(n);
+  }
 
  private:
   struct Candidate {
diff --git a/src/meta_schedule/mutator/mutate_tile_size.cc b/src/meta_schedule/mutator/mutate_tile_size.cc
index 4a3bfda8a4..8fb83147ea 100644
--- a/src/meta_schedule/mutator/mutate_tile_size.cc
+++ b/src/meta_schedule/mutator/mutate_tile_size.cc
@@ -63,6 +63,11 @@ class MutateTileSizeNode : public MutatorNode {
   void InitializeWithTuneContext(const TuneContext& context) final {}
   // Inherit from `MutatorNode`
   Optional<Trace> Apply(const Trace& trace, TRandState* rand_state) final;
+  // Inherit from `MutatorNode`
+  Mutator Clone() const final {
+    ObjectPtr<MutateTileSizeNode> n = make_object<MutateTileSizeNode>(*this);
+    return Mutator(n);
+  }
 };
 
 /*!
diff --git a/src/meta_schedule/mutator/mutate_unroll.cc b/src/meta_schedule/mutator/mutate_unroll.cc
index c282a171c3..7bbf00343a 100644
--- a/src/meta_schedule/mutator/mutate_unroll.cc
+++ b/src/meta_schedule/mutator/mutate_unroll.cc
@@ -60,6 +60,11 @@ class MutateUnrollNode : public MutatorNode {
   void InitializeWithTuneContext(const TuneContext& context) final {}
   // Inherit from `MutatorNode`
   Optional<Trace> Apply(const Trace& trace, TRandState* rand_state) final;
+  // Inherit from `MutatorNode`
+  Mutator Clone() const final {
+    ObjectPtr<MutateUnrollNode> n = make_object<MutateUnrollNode>(*this);
+    return Mutator(n);
+  }
 };
 
 /*! \brief A candidate to be mutated */
diff --git a/src/meta_schedule/mutator/mutator.cc b/src/meta_schedule/mutator/mutator.cc
index 43b95000c7..25312ab61f 100644
--- a/src/meta_schedule/mutator/mutator.cc
+++ b/src/meta_schedule/mutator/mutator.cc
@@ -33,13 +33,20 @@ Optional<tir::Trace> PyMutatorNode::Apply(
   return f_apply(trace, *rand_state);
 }
 
+Mutator PyMutatorNode::Clone() const {
+  ICHECK(f_clone != nullptr) << "PyMutator's Clone method not implemented!";
+  return f_clone();
+}
+
 Mutator Mutator::PyMutator(
     PyMutatorNode::FInitializeWithTuneContext f_initialize_with_tune_context,  //
     PyMutatorNode::FApply f_apply,                                             //
+    PyMutatorNode::FClone f_clone,                                             //
     PyMutatorNode::FAsString f_as_string) {
   ObjectPtr<PyMutatorNode> n = make_object<PyMutatorNode>();
   n->f_initialize_with_tune_context = std::move(f_initialize_with_tune_context);
   n->f_apply = std::move(f_apply);
+  n->f_clone = std::move(f_clone);
   n->f_as_string = std::move(f_as_string);
   return Mutator(n);
 }
@@ -63,6 +70,7 @@ TVM_REGISTER_GLOBAL("meta_schedule.MutatorApply")
       TRandState seed_ = (seed != -1) ? seed : support::LinearCongruentialEngine::DeviceRandom();
       return self->Apply(trace, &seed_);
     });
+TVM_REGISTER_GLOBAL("meta_schedule.MutatorClone").set_body_method<Mutator>(&MutatorNode::Clone);
 TVM_REGISTER_GLOBAL("meta_schedule.MutatorPyMutator").set_body_typed(Mutator::PyMutator);
 
 }  // namespace meta_schedule
diff --git a/src/meta_schedule/postproc/disallow_dynamic_loop.cc b/src/meta_schedule/postproc/disallow_dynamic_loop.cc
index 85a81f10fd..8362da552e 100644
--- a/src/meta_schedule/postproc/disallow_dynamic_loop.cc
+++ b/src/meta_schedule/postproc/disallow_dynamic_loop.cc
@@ -67,6 +67,11 @@ class DisallowDynamicLoopNode : public PostprocNode {
   void InitializeWithTuneContext(const TuneContext& context) final {}
   // Inherited from PostprocNode
   bool Apply(const tir::Schedule& sch) final { return !tir::DynamicExtentFinder::Find(sch->mod()); }
+  // Inherited from PostprocNode
+  Postproc Clone() const final {
+    ObjectPtr<DisallowDynamicLoopNode> n = make_object<DisallowDynamicLoopNode>(*this);
+    return Postproc(n);
+  }
 
   static constexpr const char* _type_key = "meta_schedule.DisallowDynamicLoop";
   TVM_DECLARE_FINAL_OBJECT_INFO(DisallowDynamicLoopNode, PostprocNode);
diff --git a/src/meta_schedule/postproc/postproc.cc b/src/meta_schedule/postproc/postproc.cc
index 0f4f1b1192..957d6e7364 100644
--- a/src/meta_schedule/postproc/postproc.cc
+++ b/src/meta_schedule/postproc/postproc.cc
@@ -32,13 +32,20 @@ bool PyPostprocNode::Apply(const tir::Schedule& sch) {
   return f_apply(sch);
 }
 
+Postproc PyPostprocNode::Clone() const {
+  ICHECK(f_clone != nullptr) << "PyPostproc's Clone method not implemented!";
+  return f_clone();
+}
+
 Postproc Postproc::PyPostproc(
     PyPostprocNode::FInitializeWithTuneContext f_initialize_with_tune_context,  //
     PyPostprocNode::FApply f_apply,                                             //
+    PyPostprocNode::FClone f_clone,                                             //
     PyPostprocNode::FAsString f_as_string) {
   ObjectPtr<PyPostprocNode> n = make_object<PyPostprocNode>();
   n->f_initialize_with_tune_context = std::move(f_initialize_with_tune_context);
   n->f_apply = std::move(f_apply);
+  n->f_clone = std::move(f_clone);
   n->f_as_string = std::move(f_as_string);
   return Postproc(n);
 }
@@ -58,6 +65,7 @@ TVM_REGISTER_NODE_TYPE(PyPostprocNode);
 TVM_REGISTER_GLOBAL("meta_schedule.PostprocInitializeWithTuneContext")
     .set_body_method<Postproc>(&PostprocNode::InitializeWithTuneContext);
 TVM_REGISTER_GLOBAL("meta_schedule.PostprocApply").set_body_method<Postproc>(&PostprocNode::Apply);
+TVM_REGISTER_GLOBAL("meta_schedule.PostprocClone").set_body_method<Postproc>(&PostprocNode::Clone);
 TVM_REGISTER_GLOBAL("meta_schedule.PostprocPyPostproc").set_body_typed(Postproc::PyPostproc);
 
 }  // namespace meta_schedule
diff --git a/src/meta_schedule/postproc/rewrite_cooperative_fetch.cc b/src/meta_schedule/postproc/rewrite_cooperative_fetch.cc
index d111bdb42a..ac9f45ca8e 100644
--- a/src/meta_schedule/postproc/rewrite_cooperative_fetch.cc
+++ b/src/meta_schedule/postproc/rewrite_cooperative_fetch.cc
@@ -104,6 +104,11 @@ class RewriteCooperativeFetchNode : public PostprocNode {
   // Inherited from PostprocNode
   bool Apply(const tir::Schedule& sch) final;
 
+  Postproc Clone() const final {
+    ObjectPtr<RewriteCooperativeFetchNode> n = make_object<RewriteCooperativeFetchNode>(*this);
+    return Postproc(n);
+  }
+
   void VisitAttrs(tvm::AttrVisitor* v) {}
 
   static constexpr const char* _type_key = "meta_schedule.RewriteCooperativeFetch";
diff --git a/src/meta_schedule/postproc/rewrite_layout.cc b/src/meta_schedule/postproc/rewrite_layout.cc
index f4cbdfe737..6ff9958c79 100644
--- a/src/meta_schedule/postproc/rewrite_layout.cc
+++ b/src/meta_schedule/postproc/rewrite_layout.cc
@@ -167,6 +167,11 @@ class RewriteLayoutNode : public PostprocNode {
   // Inherited from PostprocNode
   bool Apply(const tir::Schedule& sch) final { return tir::RewriteLayout(sch); }
 
+  Postproc Clone() const final {
+    ObjectPtr<RewriteLayoutNode> n = make_object<RewriteLayoutNode>(*this);
+    return Postproc(n);
+  }
+
   static constexpr const char* _type_key = "meta_schedule.RewriteLayout";
   TVM_DECLARE_FINAL_OBJECT_INFO(RewriteLayoutNode, PostprocNode);
 };
diff --git a/src/meta_schedule/postproc/rewrite_parallel_vectorize_unroll.cc b/src/meta_schedule/postproc/rewrite_parallel_vectorize_unroll.cc
index 08d25d0178..c3cc0ef601 100644
--- a/src/meta_schedule/postproc/rewrite_parallel_vectorize_unroll.cc
+++ b/src/meta_schedule/postproc/rewrite_parallel_vectorize_unroll.cc
@@ -384,6 +384,12 @@ class RewriteParallelVectorizeUnrollNode : public PostprocNode {
     return true;
   }
 
+  Postproc Clone() const final {
+    ObjectPtr<RewriteParallelVectorizeUnrollNode> n =
+        make_object<RewriteParallelVectorizeUnrollNode>(*this);
+    return Postproc(n);
+  }
+
   static constexpr const char* _type_key = "meta_schedule.RewriteParallelVectorizeUnroll";
   TVM_DECLARE_FINAL_OBJECT_INFO(RewriteParallelVectorizeUnrollNode, PostprocNode);
 };
diff --git a/src/meta_schedule/postproc/rewrite_reduction_block.cc b/src/meta_schedule/postproc/rewrite_reduction_block.cc
index ea204e3061..05a7640f04 100644
--- a/src/meta_schedule/postproc/rewrite_reduction_block.cc
+++ b/src/meta_schedule/postproc/rewrite_reduction_block.cc
@@ -114,6 +114,11 @@ class RewriteReductionBlockNode : public PostprocNode {
   // Inherited from PostprocNode
   bool Apply(const tir::Schedule& sch) final;
 
+  Postproc Clone() const final {
+    ObjectPtr<RewriteReductionBlockNode> n = make_object<RewriteReductionBlockNode>(*this);
+    return Postproc(n);
+  }
+
   void VisitAttrs(tvm::AttrVisitor* v) {}
 
   static constexpr const char* _type_key = "meta_schedule.RewriteReductionBlock";
diff --git a/src/meta_schedule/postproc/rewrite_tensorize.cc b/src/meta_schedule/postproc/rewrite_tensorize.cc
index 3b6c438d02..4f8e0fb213 100644
--- a/src/meta_schedule/postproc/rewrite_tensorize.cc
+++ b/src/meta_schedule/postproc/rewrite_tensorize.cc
@@ -68,6 +68,11 @@ class RewriteTensorizeNode : public PostprocNode {
 
   void VisitAttrs(tvm::AttrVisitor* v) {}
 
+  Postproc Clone() const final {
+    ObjectPtr<RewriteTensorizeNode> n = make_object<RewriteTensorizeNode>(*this);
+    return Postproc(n);
+  }
+
   bool vectorize_init_loop = false;
 
   static constexpr const char* _type_key = "meta_schedule.RewriteTensorize";
diff --git a/src/meta_schedule/postproc/rewrite_unbound_block.cc b/src/meta_schedule/postproc/rewrite_unbound_block.cc
index eb57e90f82..1ba68538ea 100644
--- a/src/meta_schedule/postproc/rewrite_unbound_block.cc
+++ b/src/meta_schedule/postproc/rewrite_unbound_block.cc
@@ -97,6 +97,11 @@ class RewriteUnboundBlockNode : public PostprocNode {
   // Inherited from PostprocNode
   bool Apply(const tir::Schedule& sch) final;
 
+  Postproc Clone() const final {
+    ObjectPtr<RewriteUnboundBlockNode> n = make_object<RewriteUnboundBlockNode>(*this);
+    return Postproc(n);
+  }
+
  public:
   /*! \brief The max number of threads per block from Target */
   int max_threads_per_block_ = -1;
diff --git a/src/meta_schedule/postproc/verify_gpu_code.cc b/src/meta_schedule/postproc/verify_gpu_code.cc
index dfe2c5a06a..0828ee5384 100644
--- a/src/meta_schedule/postproc/verify_gpu_code.cc
+++ b/src/meta_schedule/postproc/verify_gpu_code.cc
@@ -196,6 +196,12 @@ class VerifyGPUCodeNode : public PostprocNode {
     return true;
   }
 
+  Postproc Clone() const final {
+    // Copy-constructing from `*this` already copies `target_constraints_`; no re-assignment.
+    ObjectPtr<VerifyGPUCodeNode> n = make_object<VerifyGPUCodeNode>(*this);
+    return Postproc(n);
+  }
+
   static constexpr const char* _type_key = "meta_schedule.VerifyGPUCode";
   TVM_DECLARE_FINAL_OBJECT_INFO(VerifyGPUCodeNode, PostprocNode);
 };
diff --git a/src/meta_schedule/schedule_rule/add_rfactor.cc b/src/meta_schedule/schedule_rule/add_rfactor.cc
index cf87f24ac2..2fc1352677 100644
--- a/src/meta_schedule/schedule_rule/add_rfactor.cc
+++ b/src/meta_schedule/schedule_rule/add_rfactor.cc
@@ -36,6 +36,12 @@ class AddRFactorNode : public ScheduleRuleNode {
   // Inherited from ScheduleRuleNode
   Array<tir::Schedule> Apply(const tir::Schedule& sch, const tir::BlockRV& block_rv);
 
+  // Inherited from ScheduleRuleNode
+  ScheduleRule Clone() const final {
+    ObjectPtr<AddRFactorNode> n = make_object<AddRFactorNode>(*this);
+    return ScheduleRule(n);
+  }
+
  public:
   /*!
    * \brief The maximum number of jobs to be launched per core.
diff --git a/src/meta_schedule/schedule_rule/auto_bind.cc b/src/meta_schedule/schedule_rule/auto_bind.cc
index d8f52fa8e1..7af1418d8f 100644
--- a/src/meta_schedule/schedule_rule/auto_bind.cc
+++ b/src/meta_schedule/schedule_rule/auto_bind.cc
@@ -177,6 +177,12 @@ class AutoBindNode : public ScheduleRuleNode {
   // Inherited from ScheduleRuleNode
   Array<tir::Schedule> Apply(const tir::Schedule& sch, const tir::BlockRV& block_rv) final;
 
+  // Inherited from ScheduleRuleNode
+  ScheduleRule Clone() const final {
+    ObjectPtr<AutoBindNode> n = make_object<AutoBindNode>(*this);
+    return ScheduleRule(n);
+  }
+
  public:
   /*! \brief The max number of threads per block from Target */
   int64_t max_threads_per_block_ = -1;
diff --git a/src/meta_schedule/schedule_rule/auto_inline.cc b/src/meta_schedule/schedule_rule/auto_inline.cc
index 446c8ead7e..dcdc83f95c 100644
--- a/src/meta_schedule/schedule_rule/auto_inline.cc
+++ b/src/meta_schedule/schedule_rule/auto_inline.cc
@@ -60,6 +60,12 @@ class AutoInlineNode : public ScheduleRuleNode {
     return {sch};
   }
 
+  // Inherited from ScheduleRuleNode
+  ScheduleRule Clone() const final {
+    ObjectPtr<AutoInlineNode> n = make_object<AutoInlineNode>(*this);
+    return ScheduleRule(n);
+  }
+
  public:
   /*! \brief If allows to inline a block into its producer */
   bool into_producer;
diff --git a/src/meta_schedule/schedule_rule/cross_thread_reduction.cc b/src/meta_schedule/schedule_rule/cross_thread_reduction.cc
index 35be33f72e..f2fc67f74c 100644
--- a/src/meta_schedule/schedule_rule/cross_thread_reduction.cc
+++ b/src/meta_schedule/schedule_rule/cross_thread_reduction.cc
@@ -113,6 +113,12 @@ class CrossThreadReductionNode : public ScheduleRuleNode {
     return {tmp_sch, sch};
   }
 
+  // Inherited from ScheduleRuleNode
+  ScheduleRule Clone() const final {
+    ObjectPtr<CrossThreadReductionNode> n = make_object<CrossThreadReductionNode>(*this);
+    return ScheduleRule(n);
+  }
+
  private:
   /*!
    * \brief Check whether the input block is in thread scope, i.e., some of its outer loop is
diff --git a/src/meta_schedule/schedule_rule/multi_level_tiling.cc b/src/meta_schedule/schedule_rule/multi_level_tiling.cc
index c126c85446..1625a27b9a 100644
--- a/src/meta_schedule/schedule_rule/multi_level_tiling.cc
+++ b/src/meta_schedule/schedule_rule/multi_level_tiling.cc
@@ -104,6 +104,12 @@ Array<Schedule> MultiLevelTilingNode::Apply(const Schedule& sch, const BlockRV&
   return results;
 }
 
+// Inherited from ScheduleRuleNode
+ScheduleRule MultiLevelTilingNode::Clone() const {
+  ObjectPtr<MultiLevelTilingNode> n = make_object<MultiLevelTilingNode>(*this);
+  return ScheduleRule(n);
+}
+
 std::vector<State> MultiLevelTilingNode::ApplySubRules(std::vector<State> states) {
   states = SubRule(std::move(states), [&](State state) { return TileLoopNest(std::move(state)); });
   states = SubRule(std::move(states), [&](State state) { return AddWriteReuse(std::move(state)); });
diff --git a/src/meta_schedule/schedule_rule/multi_level_tiling.h b/src/meta_schedule/schedule_rule/multi_level_tiling.h
index 9161a972c1..47da878c3b 100644
--- a/src/meta_schedule/schedule_rule/multi_level_tiling.h
+++ b/src/meta_schedule/schedule_rule/multi_level_tiling.h
@@ -155,6 +155,9 @@ class MultiLevelTilingNode : public ScheduleRuleNode {
   // Entry of the mega rule; Inherited from ScheduleRuleNode
   Array<tir::Schedule> Apply(const tir::Schedule& sch, const tir::BlockRV& block_rv) override;
 
+  // Inherited from ScheduleRuleNode
+  ScheduleRule Clone() const override;
+
  protected:
   virtual std::vector<State> ApplySubRules(std::vector<State> states);
 
diff --git a/src/meta_schedule/schedule_rule/multi_level_tiling_tensor_core.cc b/src/meta_schedule/schedule_rule/multi_level_tiling_tensor_core.cc
index 7ddda9b263..13b00fa7de 100644
--- a/src/meta_schedule/schedule_rule/multi_level_tiling_tensor_core.cc
+++ b/src/meta_schedule/schedule_rule/multi_level_tiling_tensor_core.cc
@@ -137,6 +137,13 @@ class MultiLevelTilingTensorCoreNode : public MultiLevelTilingNode {
   // Override Apply to apply tensorization-specific analysis before applying sub-rules
   Array<Schedule> Apply(const Schedule& sch, const BlockRV& block_rv) final;
 
+  // Inherited from ScheduleRuleNode; overridden so the clone keeps this derived type
+  ScheduleRule Clone() const final {
+    ObjectPtr<MultiLevelTilingTensorCoreNode> n =
+        make_object<MultiLevelTilingTensorCoreNode>(*this);
+    return ScheduleRule(n);
+  }
+
   /*!
    * \brief Transform and tensorize with the given tensor intrin
    * \param state The state of the meta schedule rule
diff --git a/src/meta_schedule/schedule_rule/multi_level_tiling_with_intrin.cc b/src/meta_schedule/schedule_rule/multi_level_tiling_with_intrin.cc
index 3a299ed041..b953d1ad4b 100644
--- a/src/meta_schedule/schedule_rule/multi_level_tiling_with_intrin.cc
+++ b/src/meta_schedule/schedule_rule/multi_level_tiling_with_intrin.cc
@@ -63,6 +63,13 @@ class MultiLevelTilingWithIntrinNode : public MultiLevelTilingNode {
     return res;
   }
 
+  // Inherited from ScheduleRuleNode; overridden so the clone keeps this derived type
+  ScheduleRule Clone() const final {
+    ObjectPtr<MultiLevelTilingWithIntrinNode> n =
+        make_object<MultiLevelTilingWithIntrinNode>(*this);
+    return ScheduleRule(n);
+  }
+
   // Override ApplySubRules to tile the inner loops according to the given tensor intrinsic, then
   // tile the outerloops.
   virtual std::vector<State> ApplySubRules(std::vector<State> states) {
diff --git a/src/meta_schedule/schedule_rule/parallel_vectorize_unroll.cc b/src/meta_schedule/schedule_rule/parallel_vectorize_unroll.cc
index 19758996e6..045aa85b73 100644
--- a/src/meta_schedule/schedule_rule/parallel_vectorize_unroll.cc
+++ b/src/meta_schedule/schedule_rule/parallel_vectorize_unroll.cc
@@ -79,6 +79,13 @@ class ParallelizeVectorizeUnrollNode : public ScheduleRuleNode {
     return {sch};
   }
 
+  // Inherited from ScheduleRuleNode; returns a member-wise copy of this rule
+  ScheduleRule Clone() const final {
+    ObjectPtr<ParallelizeVectorizeUnrollNode> n =
+        make_object<ParallelizeVectorizeUnrollNode>(*this);
+    return ScheduleRule(n);
+  }
+
  public:
   /*!
    * \brief The maximum number of jobs to be launched per CPU core. It sets the
diff --git a/src/meta_schedule/schedule_rule/random_compute_location.cc b/src/meta_schedule/schedule_rule/random_compute_location.cc
index 65988dfd56..7796eddd44 100644
--- a/src/meta_schedule/schedule_rule/random_compute_location.cc
+++ b/src/meta_schedule/schedule_rule/random_compute_location.cc
@@ -57,6 +57,12 @@ class RandomComputeLocationNode : public ScheduleRuleNode {
     return {res};
   }
 
+  // Inherited from ScheduleRuleNode; returns a member-wise copy of this rule
+  ScheduleRule Clone() const final {
+    ObjectPtr<RandomComputeLocationNode> n = make_object<RandomComputeLocationNode>(*this);
+    return ScheduleRule(n);
+  }
+
  private:
   bool CheckConditions(const tir::Schedule sch, const tir::BlockRV& block_rv) const {
     tir::StmtSRef block_sref = sch->GetSRef(block_rv);
diff --git a/src/meta_schedule/schedule_rule/schedule_rule.cc b/src/meta_schedule/schedule_rule/schedule_rule.cc
index 80f8725b0c..416b43f46d 100644
--- a/src/meta_schedule/schedule_rule/schedule_rule.cc
+++ b/src/meta_schedule/schedule_rule/schedule_rule.cc
@@ -33,13 +33,20 @@ Array<tir::Schedule> PyScheduleRuleNode::Apply(const tir::Schedule& sch,
   return f_apply(sch, block);
 }
 
+ScheduleRule PyScheduleRuleNode::Clone() const {  // delegates to f_clone; ICHECK-fails if unset
+  ICHECK(f_clone != nullptr) << "PyScheduleRule's Clone method not implemented!";
+  return f_clone();
+}
+
 ScheduleRule ScheduleRule::PyScheduleRule(
     PyScheduleRuleNode::FInitializeWithTuneContext f_initialize_with_tune_context,  //
     PyScheduleRuleNode::FApply f_apply,                                             //
+    PyScheduleRuleNode::FClone f_clone,                                             //
     PyScheduleRuleNode::FAsString f_as_string) {
   ObjectPtr<PyScheduleRuleNode> n = make_object<PyScheduleRuleNode>();
   n->f_initialize_with_tune_context = std::move(f_initialize_with_tune_context);
   n->f_apply = std::move(f_apply);
+  n->f_clone = std::move(f_clone);
   n->f_as_string = std::move(f_as_string);
   return ScheduleRule(n);
 }
@@ -60,6 +67,8 @@ TVM_REGISTER_GLOBAL("meta_schedule.ScheduleRuleInitializeWithTuneContext")
     .set_body_method<ScheduleRule>(&ScheduleRuleNode::InitializeWithTuneContext);
 TVM_REGISTER_GLOBAL("meta_schedule.ScheduleRuleApply")
     .set_body_method<ScheduleRule>(&ScheduleRuleNode::Apply);
+TVM_REGISTER_GLOBAL("meta_schedule.ScheduleRuleClone")  // dispatches to the virtual Clone()
+    .set_body_method<ScheduleRule>(&ScheduleRuleNode::Clone);
 TVM_REGISTER_GLOBAL("meta_schedule.ScheduleRulePyScheduleRule")
     .set_body_typed(ScheduleRule::PyScheduleRule);
 
diff --git a/src/meta_schedule/search_strategy/evolutionary_search.cc b/src/meta_schedule/search_strategy/evolutionary_search.cc
index c5ff9008ef..5930704eb0 100644
--- a/src/meta_schedule/search_strategy/evolutionary_search.cc
+++ b/src/meta_schedule/search_strategy/evolutionary_search.cc
@@ -431,6 +431,24 @@ class EvolutionarySearchNode : public SearchStrategyNode {
     ICHECK(this->state_ != nullptr);
     this->state_->NotifyRunnerResults(measure_candidates, results);
   }
+
+  SearchStrategy Clone() const final {  // copies configuration only, never search state
+    ObjectPtr<EvolutionarySearchNode> n = make_object<EvolutionarySearchNode>();
+    n->max_trials_per_task = this->max_trials_per_task;
+    n->num_trials_per_iter = this->num_trials_per_iter;
+    n->population_size = this->population_size;
+    n->num_empty_iters_before_early_stop = this->num_empty_iters_before_early_stop;
+    n->init_measured_ratio = this->init_measured_ratio;
+    n->init_min_unmeasured = this->init_min_unmeasured;
+    n->genetic_num_iters = this->genetic_num_iters;
+    n->genetic_mutate_prob = this->genetic_mutate_prob;
+    n->genetic_max_fail_count = this->genetic_max_fail_count;
+    n->eps_greedy = this->eps_greedy;
+    n->context_ = this->context_;  // shared, not cloned; presumably rebound on re-init
+    n->rand_state_ = this->rand_state_;  // clone continues from the same random state
+    n->state_ = nullptr;  // search state is not copied; the clone starts without one
+    return SearchStrategy(n);
+  }
 };
 
 std::vector<Schedule> EvolutionarySearchNode::State::PickBestFromDatabase(int num) {
diff --git a/src/meta_schedule/search_strategy/replay_func.cc b/src/meta_schedule/search_strategy/replay_func.cc
index 4574c1c817..6914ab2f0f 100644
--- a/src/meta_schedule/search_strategy/replay_func.cc
+++ b/src/meta_schedule/search_strategy/replay_func.cc
@@ -100,6 +100,16 @@ class ReplayFuncNode : public SearchStrategyNode {
     ICHECK(this->state_ != nullptr);
     this->state_->NotifyRunnerResults(results);
   }
+
+  SearchStrategy Clone() const final {  // copies configuration only, never search state
+    ObjectPtr<ReplayFuncNode> n = make_object<ReplayFuncNode>();
+    n->num_trials_per_iter = this->num_trials_per_iter;
+    n->max_trials_per_task = this->max_trials_per_task;
+    n->context_ = this->context_;  // shared, not cloned; presumably rebound on re-init
+    n->rand_state_ = this->rand_state_;  // clone continues from the same random state
+    n->state_ = nullptr;  // search state is not copied; the clone starts without one
+    return SearchStrategy(n);
+  }
 };
 
 inline Optional<Array<MeasureCandidate>> ReplayFuncNode::State::GenerateMeasureCandidates() {
diff --git a/src/meta_schedule/search_strategy/replay_trace.cc b/src/meta_schedule/search_strategy/replay_trace.cc
index 64fc683943..bd553bf037 100644
--- a/src/meta_schedule/search_strategy/replay_trace.cc
+++ b/src/meta_schedule/search_strategy/replay_trace.cc
@@ -118,6 +118,17 @@ class ReplayTraceNode : public SearchStrategyNode {
     ICHECK(this->state_ != nullptr);
     this->state_->NotifyRunnerResults(results);
   }
+
+  SearchStrategy Clone() const final {  // copies configuration only, never search state
+    ObjectPtr<ReplayTraceNode> n = make_object<ReplayTraceNode>();
+    n->num_trials_per_iter = this->num_trials_per_iter;
+    n->max_trials_per_task = this->max_trials_per_task;
+    n->max_fail_count = this->max_fail_count;
+    n->context_ = this->context_;  // shared, not cloned; presumably rebound on re-init
+    n->rand_state_ = this->rand_state_;  // clone continues from the same random state
+    n->state_ = nullptr;  // search state is not copied; the clone starts without one
+    return SearchStrategy(n);
+  }
 };
 
 inline Optional<Array<MeasureCandidate>> ReplayTraceNode::State::GenerateMeasureCandidates() {
diff --git a/src/meta_schedule/search_strategy/search_strategy.cc b/src/meta_schedule/search_strategy/search_strategy.cc
index 5865fc8422..81c7fda315 100644
--- a/src/meta_schedule/search_strategy/search_strategy.cc
+++ b/src/meta_schedule/search_strategy/search_strategy.cc
@@ -59,18 +59,25 @@ void PySearchStrategyNode::NotifyRunnerResults(const Array<MeasureCandidate>& me
   f_notify_runner_results(measure_candidates, results);
 }
 
+SearchStrategy PySearchStrategyNode::Clone() const {  // delegates to f_clone; ICHECK-fails if unset
+  ICHECK(f_clone != nullptr) << "PySearchStrategy's Clone method not implemented!";
+  return f_clone();
+}
+
 SearchStrategy SearchStrategy::PySearchStrategy(
     PySearchStrategyNode::FInitializeWithTuneContext f_initialize_with_tune_context,  //
     PySearchStrategyNode::FPreTuning f_pre_tuning,                                    //
     PySearchStrategyNode::FPostTuning f_post_tuning,                                  //
     PySearchStrategyNode::FGenerateMeasureCandidates f_generate_measure_candidates,   //
-    PySearchStrategyNode::FNotifyRunnerResults f_notify_runner_results) {
+    PySearchStrategyNode::FNotifyRunnerResults f_notify_runner_results,               //
+    PySearchStrategyNode::FClone f_clone) {  // appended last; earlier argument positions unchanged
   ObjectPtr<PySearchStrategyNode> n = make_object<PySearchStrategyNode>();
   n->f_initialize_with_tune_context = f_initialize_with_tune_context;
   n->f_pre_tuning = f_pre_tuning;
   n->f_post_tuning = f_post_tuning;
   n->f_generate_measure_candidates = f_generate_measure_candidates;
   n->f_notify_runner_results = f_notify_runner_results;
+  n->f_clone = f_clone;
   return SearchStrategy(n);
 }
 
@@ -94,6 +101,8 @@ TVM_REGISTER_GLOBAL("meta_schedule.SearchStrategyGenerateMeasureCandidates")
     .set_body_method<SearchStrategy>(&SearchStrategyNode::GenerateMeasureCandidates);
 TVM_REGISTER_GLOBAL("meta_schedule.SearchStrategyNotifyRunnerResults")
     .set_body_method<SearchStrategy>(&SearchStrategyNode::NotifyRunnerResults);
+TVM_REGISTER_GLOBAL("meta_schedule.SearchStrategyClone")  // dispatches to the virtual Clone()
+    .set_body_method<SearchStrategy>(&SearchStrategyNode::Clone);
 
 }  // namespace meta_schedule
 }  // namespace tvm
diff --git a/src/meta_schedule/space_generator/post_order_apply.cc b/src/meta_schedule/space_generator/post_order_apply.cc
index 9be89e2d9c..991e4fa080 100644
--- a/src/meta_schedule/space_generator/post_order_apply.cc
+++ b/src/meta_schedule/space_generator/post_order_apply.cc
@@ -188,6 +188,15 @@ class PostOrderApplyNode : public SpaceGeneratorNode {
     }
     return result;
   }
+
+  SpaceGenerator Clone() const final {  // schedule rules are deep-cloned, not shared
+    ObjectPtr<PostOrderApplyNode> n = make_object<PostOrderApplyNode>(*this);
+    n->sch_rules_ = Array<ScheduleRule>();  // discard the copied list; refill with clones
+    for (const ScheduleRule& sch_rule : this->sch_rules_) {
+      n->sch_rules_.push_back(sch_rule->Clone());
+    }
+    return SpaceGenerator(n);
+  }
   static constexpr const char* _type_key = "meta_schedule.PostOrderApply";
   TVM_DECLARE_FINAL_OBJECT_INFO(PostOrderApplyNode, SpaceGeneratorNode);
 };
diff --git a/src/meta_schedule/space_generator/schedule_fn.cc b/src/meta_schedule/space_generator/schedule_fn.cc
index 70559fbcf1..adea139b1c 100644
--- a/src/meta_schedule/space_generator/schedule_fn.cc
+++ b/src/meta_schedule/space_generator/schedule_fn.cc
@@ -72,6 +72,11 @@ class ScheduleFnNode : public SpaceGeneratorNode {
     throw;
   }
 
+  SpaceGenerator Clone() const final {  // member-wise copy of this generator
+    ObjectPtr<ScheduleFnNode> n = make_object<ScheduleFnNode>(*this);
+    return SpaceGenerator(n);
+  }
+
   static constexpr const char* _type_key = "meta_schedule.ScheduleFn";
   TVM_DECLARE_FINAL_OBJECT_INFO(ScheduleFnNode, SpaceGeneratorNode);
 };
diff --git a/src/meta_schedule/space_generator/space_generator.cc b/src/meta_schedule/space_generator/space_generator.cc
index 5c5ab6ebba..6fc31ed896 100644
--- a/src/meta_schedule/space_generator/space_generator.cc
+++ b/src/meta_schedule/space_generator/space_generator.cc
@@ -33,12 +33,18 @@ Array<tir::Schedule> PySpaceGeneratorNode::GenerateDesignSpace(const IRModule& m
   return f_generate_design_space(mod);
 }
 
+SpaceGenerator PySpaceGeneratorNode::Clone() const {  // delegates to f_clone; ICHECK-fails if unset
+  ICHECK(f_clone != nullptr) << "PySpaceGenerator's Clone method not implemented!";
+  return f_clone();
+}
+
 SpaceGenerator SpaceGenerator::PySpaceGenerator(
-    PySpaceGeneratorNode::FInitializeWithTuneContext f_initialize_with_tune_context,
-    PySpaceGeneratorNode::FGenerateDesignSpace f_generate_design_space) {
+    FInitializeWithTuneContext f_initialize_with_tune_context,  // NOTE(review): unqualified; confirm alias in SpaceGenerator
+    FGenerateDesignSpace f_generate_design_space, FClone f_clone) {
   ObjectPtr<PySpaceGeneratorNode> n = make_object<PySpaceGeneratorNode>();
   n->f_initialize_with_tune_context = std::move(f_initialize_with_tune_context);
   n->f_generate_design_space = std::move(f_generate_design_space);
+  n->f_clone = std::move(f_clone);
   return SpaceGenerator(n);
 }
 
@@ -51,6 +57,8 @@ TVM_REGISTER_GLOBAL("meta_schedule.SpaceGeneratorGenerateDesignSpace")
     .set_body_method<SpaceGenerator>(&SpaceGeneratorNode::GenerateDesignSpace);
 TVM_REGISTER_GLOBAL("meta_schedule.SpaceGeneratorPySpaceGenerator")
     .set_body_typed(SpaceGenerator::PySpaceGenerator);
+TVM_REGISTER_GLOBAL("meta_schedule.SpaceGeneratorClone")  // dispatches to the virtual Clone()
+    .set_body_method<SpaceGenerator>(&SpaceGeneratorNode::Clone);
 
 }  // namespace meta_schedule
 }  // namespace tvm
diff --git a/src/meta_schedule/space_generator/space_generator_union.cc b/src/meta_schedule/space_generator/space_generator_union.cc
index 6ea61824f9..771d0c187f 100644
--- a/src/meta_schedule/space_generator/space_generator_union.cc
+++ b/src/meta_schedule/space_generator/space_generator_union.cc
@@ -47,6 +47,15 @@ class SpaceGeneratorUnionNode : public SpaceGeneratorNode {
     return design_spaces;
   }
 
+  SpaceGenerator Clone() const final {  // child generators are deep-cloned, not shared
+    ObjectPtr<SpaceGeneratorUnionNode> n = make_object<SpaceGeneratorUnionNode>(*this);
+    n->space_generators = Array<SpaceGenerator>();  // discard the copied list; refill with clones
+    for (const SpaceGenerator& space_generator : this->space_generators) {
+      n->space_generators.push_back(space_generator->Clone());
+    }
+    return SpaceGenerator(n);
+  }
+
   static constexpr const char* _type_key = "meta_schedule.SpaceGeneratorUnion";
   TVM_DECLARE_FINAL_OBJECT_INFO(SpaceGeneratorUnionNode, SpaceGeneratorNode);
 };
diff --git a/src/meta_schedule/tune_context.cc b/src/meta_schedule/tune_context.cc
index 57b2344c6f..3650c0374d 100644
--- a/src/meta_schedule/tune_context.cc
+++ b/src/meta_schedule/tune_context.cc
@@ -52,6 +52,32 @@ TuneContext::TuneContext(Optional<IRModule> mod,
   data_ = std::move(n);
 }
 
+TuneContext TuneContextNode::Clone() const {  // clones each tuning component; other fields copied
+  ObjectPtr<TuneContextNode> n = make_object<TuneContextNode>(*this);  // start from a field-wise copy
+  if (this->sch_rules.defined()) {
+    n->sch_rules = Array<ScheduleRule>();
+    for (const ScheduleRule& sch_rule : this->sch_rules) {
+      n->sch_rules.push_back(sch_rule->Clone());
+    }
+  }
+  if (this->postprocs.defined()) {
+    n->postprocs = Array<Postproc>();
+    for (const Postproc& postproc : this->postprocs) {
+      n->postprocs.push_back(postproc->Clone());
+    }
+  }
+  if (this->mutator_probs.defined()) {
+    n->mutator_probs = Map<Mutator, FloatImm>();  // keys are cloned, probabilities reused
+    for (const auto& kv : this->mutator_probs) {
+      n->mutator_probs.Set(kv.first->Clone(), kv.second);
+    }
+  }
+  if (this->space_generator.defined()) n->space_generator = this->space_generator.value()->Clone();
+  if (this->search_strategy.defined()) n->search_strategy = this->search_strategy.value()->Clone();
+  n->Initialize();  // re-initialize components against the clone, not the original
+  return TuneContext(n);
+}
+
 void TuneContextNode::Initialize() {
   if (this->space_generator.defined()) {
     this->space_generator.value()->InitializeWithTuneContext(GetRef<TuneContext>(this));