You are viewing a plain-text version of this content. The canonical link to it was not preserved in this plain-text copy.
Posted to commits@tvm.apache.org by ju...@apache.org on 2022/09/16 05:01:40 UTC
[tvm] branch main updated: [MetaSchedule] Enable Clone Function for Task-Level Classes (#12796)
This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 02c2eae510 [MetaSchedule] Enable Clone Function for Task-Level Classes (#12796)
02c2eae510 is described below
commit 02c2eae510d6d6c15189427c97819f7ce05f002d
Author: Xiyou Zhou <xi...@octoml.ai>
AuthorDate: Thu Sep 15 22:01:33 2022 -0700
[MetaSchedule] Enable Clone Function for Task-Level Classes (#12796)
This PR introduces a `Clone` function for each of the task-level MetaSchedule classes to make deep copying of these classes convenient.
- [x] ScheduleRule
- [x] Postproc
- [x] Mutator
- [x] SpaceGenerator
- [x] SearchStrategy
- [x] TuneContext
---
include/tvm/meta_schedule/mutator.h | 88 ++++++++++------
include/tvm/meta_schedule/postproc.h | 86 ++++++++++------
include/tvm/meta_schedule/schedule_rule.h | 86 ++++++++++------
include/tvm/meta_schedule/search_strategy.h | 114 +++++++++++++--------
include/tvm/meta_schedule/space_generator.h | 78 ++++++++------
include/tvm/meta_schedule/tune_context.h | 6 ++
python/tvm/meta_schedule/mutator/mutator.py | 24 ++++-
python/tvm/meta_schedule/postproc/postproc.py | 24 ++++-
.../meta_schedule/schedule_rule/schedule_rule.py | 32 ++++--
.../search_strategy/search_strategy.py | 23 +++++
.../space_generator/space_generator.py | 24 ++++-
python/tvm/meta_schedule/testing/dummy_object.py | 3 +
python/tvm/meta_schedule/tune_context.py | 10 ++
.../mutator/mutate_compute_location.cc | 5 +
src/meta_schedule/mutator/mutate_parallel.cc | 5 +
src/meta_schedule/mutator/mutate_thread_binding.cc | 5 +
src/meta_schedule/mutator/mutate_tile_size.cc | 5 +
src/meta_schedule/mutator/mutate_unroll.cc | 5 +
src/meta_schedule/mutator/mutator.cc | 8 ++
.../postproc/disallow_dynamic_loop.cc | 5 +
src/meta_schedule/postproc/postproc.cc | 8 ++
.../postproc/rewrite_cooperative_fetch.cc | 5 +
src/meta_schedule/postproc/rewrite_layout.cc | 5 +
.../postproc/rewrite_parallel_vectorize_unroll.cc | 6 ++
.../postproc/rewrite_reduction_block.cc | 5 +
src/meta_schedule/postproc/rewrite_tensorize.cc | 5 +
.../postproc/rewrite_unbound_block.cc | 5 +
src/meta_schedule/postproc/verify_gpu_code.cc | 6 ++
src/meta_schedule/schedule_rule/add_rfactor.cc | 6 ++
src/meta_schedule/schedule_rule/auto_bind.cc | 6 ++
src/meta_schedule/schedule_rule/auto_inline.cc | 6 ++
.../schedule_rule/cross_thread_reduction.cc | 6 ++
.../schedule_rule/multi_level_tiling.cc | 6 ++
.../schedule_rule/multi_level_tiling.h | 3 +
.../multi_level_tiling_tensor_core.cc | 7 ++
.../multi_level_tiling_with_intrin.cc | 7 ++
.../schedule_rule/parallel_vectorize_unroll.cc | 7 ++
.../schedule_rule/random_compute_location.cc | 6 ++
src/meta_schedule/schedule_rule/schedule_rule.cc | 9 ++
.../search_strategy/evolutionary_search.cc | 18 ++++
src/meta_schedule/search_strategy/replay_func.cc | 10 ++
src/meta_schedule/search_strategy/replay_trace.cc | 11 ++
.../search_strategy/search_strategy.cc | 11 +-
.../space_generator/post_order_apply.cc | 9 ++
src/meta_schedule/space_generator/schedule_fn.cc | 5 +
.../space_generator/space_generator.cc | 12 ++-
.../space_generator/space_generator_union.cc | 9 ++
src/meta_schedule/tune_context.cc | 26 +++++
48 files changed, 675 insertions(+), 186 deletions(-)
diff --git a/include/tvm/meta_schedule/mutator.h b/include/tvm/meta_schedule/mutator.h
index 566cc82e97..2b580e75e0 100644
--- a/include/tvm/meta_schedule/mutator.h
+++ b/include/tvm/meta_schedule/mutator.h
@@ -32,6 +32,7 @@ namespace tvm {
namespace meta_schedule {
class TuneContext;
+class Mutator;
/*! \brief Mutator is designed to mutate the trace to explore the design space. */
class MutatorNode : public runtime::Object {
@@ -57,12 +58,21 @@ class MutatorNode : public runtime::Object {
virtual Optional<tir::Trace> Apply(const tir::Trace& trace,
support::LinearCongruentialEngine::TRandState* rand_state) = 0;
+ /*!
+ * \brief Clone the mutator.
+ * \return The cloned mutator.
+ */
+ virtual Mutator Clone() const = 0;
+
static constexpr const char* _type_key = "meta_schedule.Mutator";
TVM_DECLARE_BASE_OBJECT_INFO(MutatorNode, Object);
};
-/*! \brief The mutator with customized methods on the python-side. */
-class PyMutatorNode : public MutatorNode {
+/*!
+ * \brief Managed reference to MutatorNode
+ * \sa MutatorNode
+ */
+class Mutator : public runtime::ObjectRef {
public:
/*!
* \brief The function type of `InitializeWithTuneContext` method.
@@ -76,39 +86,16 @@ class PyMutatorNode : public MutatorNode {
*/
using FApply = runtime::TypedPackedFunc<Optional<tir::Trace>(
const tir::Trace&, support::LinearCongruentialEngine::TRandState rand_state)>;
+ /*!
+ * \brief The function type of `Clone` method.
+ * \return The cloned mutator.
+ */
+ using FClone = runtime::TypedPackedFunc<Mutator()>;
/*!
* \brief Get the mutator as string with name.
* \return The string of the mutator.
*/
using FAsString = runtime::TypedPackedFunc<String()>;
-
- /*! \brief The packed function to the `InitializeWithTuneContext` function. */
- FInitializeWithTuneContext f_initialize_with_tune_context;
- /*! \brief The packed function to the `Apply` function. */
- FApply f_apply;
- /*! \brief The packed function to the `AsString` function. */
- FAsString f_as_string;
-
- void VisitAttrs(tvm::AttrVisitor* v) {
- // `f_initialize_with_tune_context` is not visited
- // `f_apply` is not visited
- // `f_as_string` is not visited
- }
-
- void InitializeWithTuneContext(const TuneContext& context) final;
- Optional<tir::Trace> Apply(const tir::Trace& trace,
- support::LinearCongruentialEngine::TRandState* rand_state) final;
-
- static constexpr const char* _type_key = "meta_schedule.PyMutator";
- TVM_DECLARE_FINAL_OBJECT_INFO(PyMutatorNode, MutatorNode);
-};
-
-/*!
- * \brief Managed reference to MutatorNode
- * \sa MutatorNode
- */
-class Mutator : public runtime::ObjectRef {
- public:
/*! \brief Create a Mutator that mutates the decision of instruction Sample-Perfect-Tile */
TVM_DLL static Mutator MutateTileSize();
/*!
@@ -136,16 +123,49 @@ class Mutator : public runtime::ObjectRef {
* \brief Create a mutator with customized methods on the python-side.
* \param f_initialize_with_tune_context The packed function of `InitializeWithTuneContext`.
* \param f_apply The packed function of `Apply`.
+ * \param f_clone The packed function of `Clone`.
* \param f_as_string The packed function of `AsString`.
* \return The mutator created.
*/
- TVM_DLL static Mutator PyMutator(
- PyMutatorNode::FInitializeWithTuneContext f_initialize_with_tune_context, //
- PyMutatorNode::FApply f_apply, //
- PyMutatorNode::FAsString f_as_string);
+ TVM_DLL static Mutator PyMutator(FInitializeWithTuneContext f_initialize_with_tune_context, //
+ FApply f_apply, //
+ FClone f_clone, //
+ FAsString f_as_string);
TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(Mutator, ObjectRef, MutatorNode);
};
+/*! \brief The mutator with customized methods on the python-side. */
+class PyMutatorNode : public MutatorNode {
+ public:
+ using FInitializeWithTuneContext = Mutator::FInitializeWithTuneContext;
+ using FApply = Mutator::FApply;
+ using FClone = Mutator::FClone;
+ using FAsString = Mutator::FAsString;
+ /*! \brief The packed function to the `InitializeWithTuneContext` function. */
+ FInitializeWithTuneContext f_initialize_with_tune_context;
+ /*! \brief The packed function to the `Apply` function. */
+ FApply f_apply;
+ /*! \brief The packed function to the `Clone` function. */
+ FClone f_clone;
+ /*! \brief The packed function to the `AsString` function. */
+ FAsString f_as_string;
+
+ void VisitAttrs(tvm::AttrVisitor* v) {
+ // `f_initialize_with_tune_context` is not visited
+ // `f_apply` is not visited
+ // `f_clone` is not visited
+ // `f_as_string` is not visited
+ }
+
+ void InitializeWithTuneContext(const TuneContext& context) final;
+ Optional<tir::Trace> Apply(const tir::Trace& trace,
+ support::LinearCongruentialEngine::TRandState* rand_state) final;
+ Mutator Clone() const final;
+
+ static constexpr const char* _type_key = "meta_schedule.PyMutator";
+ TVM_DECLARE_FINAL_OBJECT_INFO(PyMutatorNode, MutatorNode);
+};
+
} // namespace meta_schedule
} // namespace tvm
diff --git a/include/tvm/meta_schedule/postproc.h b/include/tvm/meta_schedule/postproc.h
index 5d99f68454..4fafb95576 100644
--- a/include/tvm/meta_schedule/postproc.h
+++ b/include/tvm/meta_schedule/postproc.h
@@ -29,6 +29,7 @@ namespace tvm {
namespace meta_schedule {
class TuneContext;
+class Postproc;
/*!
* \brief Rules to apply a postprocessor to a schedule.
@@ -54,12 +55,21 @@ class PostprocNode : public runtime::Object {
*/
virtual bool Apply(const tir::Schedule& sch) = 0;
+ /*!
+ * \brief Clone the postprocessor.
+ * \return The cloned postprocessor.
+ */
+ virtual Postproc Clone() const = 0;
+
static constexpr const char* _type_key = "meta_schedule.Postproc";
TVM_DECLARE_BASE_OBJECT_INFO(PostprocNode, Object);
};
-/*! \brief The postprocessor with customized methods on the python-side. */
-class PyPostprocNode : public PostprocNode {
+/*!
+ * \brief Managed reference to PostprocNode
+ * \sa PostprocNode
+ */
+class Postproc : public runtime::ObjectRef {
public:
/*!
* \brief The function type of `InitializeWithTuneContext` method.
@@ -72,49 +82,28 @@ class PyPostprocNode : public PostprocNode {
* \return Whether the postprocessor was successfully applied.
*/
using FApply = runtime::TypedPackedFunc<bool(const tir::Schedule&)>;
+ /*!
+ * \brief The function type of `Clone` method.
+ * \return The cloned postprocessor.
+ */
+ using FClone = runtime::TypedPackedFunc<Postproc()>;
/*!
* \brief Get the postprocessor function as string with name.
* \return The string of the postprocessor function.
*/
using FAsString = runtime::TypedPackedFunc<String()>;
-
- /*! \brief The packed function to the `InitializeWithTuneContext` function. */
- FInitializeWithTuneContext f_initialize_with_tune_context;
- /*! \brief The packed function to the `Apply` function. */
- FApply f_apply;
- /*! \brief The packed function to the `AsString` function. */
- FAsString f_as_string;
-
- void VisitAttrs(tvm::AttrVisitor* v) {
- // `f_initialize_with_tune_context` is not visited
- // `f_apply` is not visited
- // `f_as_string` is not visited
- }
-
- void InitializeWithTuneContext(const TuneContext& context) final;
- bool Apply(const tir::Schedule& sch) final;
-
- static constexpr const char* _type_key = "meta_schedule.PyPostproc";
- TVM_DECLARE_FINAL_OBJECT_INFO(PyPostprocNode, PostprocNode);
-};
-
-/*!
- * \brief Managed reference to PostprocNode
- * \sa PostprocNode
- */
-class Postproc : public runtime::ObjectRef {
- public:
/*!
* \brief Create a postprocessor with customized methods on the python-side.
* \param f_initialize_with_tune_context The packed function of `InitializeWithTuneContext`.
* \param f_apply The packed function of `Apply`.
+ * \param f_clone The packed function of `Clone`.
* \param f_as_string The packed function of `AsString`.
* \return The postprocessor created.
*/
- TVM_DLL static Postproc PyPostproc(
- PyPostprocNode::FInitializeWithTuneContext f_initialize_with_tune_context, //
- PyPostprocNode::FApply f_apply, //
- PyPostprocNode::FAsString f_as_string);
+ TVM_DLL static Postproc PyPostproc(FInitializeWithTuneContext f_initialize_with_tune_context, //
+ FApply f_apply, //
+ FClone f_clone, //
+ FAsString f_as_string);
/*!
* \brief Create a postprocessor that checks if all loops are static
* \return The postprocessor created
@@ -164,6 +153,37 @@ class Postproc : public runtime::ObjectRef {
TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(Postproc, ObjectRef, PostprocNode);
};
+/*! \brief The postprocessor with customized methods on the python-side. */
+class PyPostprocNode : public PostprocNode {
+ public:
+ using FInitializeWithTuneContext = Postproc::FInitializeWithTuneContext;
+ using FApply = Postproc::FApply;
+ using FClone = Postproc::FClone;
+ using FAsString = Postproc::FAsString;
+ /*! \brief The packed function to the `InitializeWithTuneContext` function. */
+ FInitializeWithTuneContext f_initialize_with_tune_context;
+ /*! \brief The packed function to the `Apply` function. */
+ FApply f_apply;
+ /*! \brief The packed function to the `Clone` function. */
+ FClone f_clone;
+ /*! \brief The packed function to the `AsString` function. */
+ FAsString f_as_string;
+
+ void VisitAttrs(tvm::AttrVisitor* v) {
+ // `f_initialize_with_tune_context` is not visited
+ // `f_apply` is not visited
+ // `f_clone` is not visited
+ // `f_as_string` is not visited
+ }
+
+ void InitializeWithTuneContext(const TuneContext& context) final;
+ bool Apply(const tir::Schedule& sch) final;
+ Postproc Clone() const final;
+
+ static constexpr const char* _type_key = "meta_schedule.PyPostproc";
+ TVM_DECLARE_FINAL_OBJECT_INFO(PyPostprocNode, PostprocNode);
+};
+
} // namespace meta_schedule
} // namespace tvm
diff --git a/include/tvm/meta_schedule/schedule_rule.h b/include/tvm/meta_schedule/schedule_rule.h
index 2da441c95e..55704cf4a9 100644
--- a/include/tvm/meta_schedule/schedule_rule.h
+++ b/include/tvm/meta_schedule/schedule_rule.h
@@ -34,6 +34,7 @@ namespace tvm {
namespace meta_schedule {
class TuneContext;
+class ScheduleRule;
/*! \brief Rules to modify a block in a schedule. */
class ScheduleRuleNode : public runtime::Object {
@@ -59,12 +60,21 @@ class ScheduleRuleNode : public runtime::Object {
virtual runtime::Array<tir::Schedule> Apply(const tir::Schedule& sch,
const tir::BlockRV& block) = 0;
+ /*!
+ * \brief Deep clone the schedule rule.
+ * \return The cloned schedule rule.
+ */
+ virtual ScheduleRule Clone() const = 0;
+
static constexpr const char* _type_key = "meta_schedule.ScheduleRule";
TVM_DECLARE_BASE_OBJECT_INFO(ScheduleRuleNode, Object);
};
-/*! \brief The schedule rule with customized methods on the python-side. */
-class PyScheduleRuleNode : public ScheduleRuleNode {
+/*!
+ * \brief Managed reference to ScheduleRuleNode
+ * \sa ScheduleRuleNode
+ */
+class ScheduleRule : public runtime::ObjectRef {
public:
/*!
* \brief The function type of `InitializeWithTuneContext` method.
@@ -84,33 +94,11 @@ class PyScheduleRuleNode : public ScheduleRuleNode {
* \return The string of the schedule rule.
*/
using FAsString = runtime::TypedPackedFunc<String()>;
-
- /*! \brief The packed function to the `InitializeWithTuneContext` function. */
- FInitializeWithTuneContext f_initialize_with_tune_context;
- /*! \brief The packed function to the `Apply` function. */
- FApply f_apply;
- /*! \brief The packed function to the `AsString` function. */
- FAsString f_as_string;
-
- void VisitAttrs(tvm::AttrVisitor* v) {
- // `f_initialize_with_tune_context` is not visited
- // `f_apply` is not visited
- // `f_as_string` is not visited
- }
-
- void InitializeWithTuneContext(const TuneContext& context) final;
- Array<tir::Schedule> Apply(const tir::Schedule& sch, const tir::BlockRV& block) final;
-
- static constexpr const char* _type_key = "meta_schedule.PyScheduleRule";
- TVM_DECLARE_FINAL_OBJECT_INFO(PyScheduleRuleNode, ScheduleRuleNode);
-};
-
-/*!
- * \brief Managed reference to ScheduleRuleNode
- * \sa ScheduleRuleNode
- */
-class ScheduleRule : public runtime::ObjectRef {
- public:
+ /*!
+ * \brief The function type of `Clone` method.
+ * \return The cloned schedule rule.
+ */
+ using FClone = runtime::TypedPackedFunc<ScheduleRule()>;
/*!
* \brief Create an auto-inline rule that inlines spatial blocks if it satisfies some conditions
* \param into_producer If allows to inline a block into its producer
@@ -249,16 +237,50 @@ class ScheduleRule : public runtime::ObjectRef {
* \brief Create a schedule rule with customized methods on the python-side.
* \param f_initialize_with_tune_context The packed function of `InitializeWithTuneContext`.
* \param f_apply The packed function of `Apply`.
+ * \param f_clone The packed function of `Clone`.
* \param f_as_string The packed function of `AsString`.
* \return The schedule rule created.
*/
TVM_DLL static ScheduleRule PyScheduleRule(
- PyScheduleRuleNode::FInitializeWithTuneContext f_initialize_with_tune_context, //
- PyScheduleRuleNode::FApply f_apply, //
- PyScheduleRuleNode::FAsString f_as_string);
+ FInitializeWithTuneContext f_initialize_with_tune_context, //
+ FApply f_apply, //
+ FClone f_clone, //
+ FAsString f_as_string);
TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(ScheduleRule, ObjectRef, ScheduleRuleNode);
};
+/*! \brief The schedule rule with customized methods on the python-side. */
+class PyScheduleRuleNode : public ScheduleRuleNode {
+ public:
+ using FInitializeWithTuneContext = ScheduleRule::FInitializeWithTuneContext;
+ using FApply = ScheduleRule::FApply;
+ using FClone = ScheduleRule::FClone;
+ using FAsString = ScheduleRule::FAsString;
+
+ /*! \brief The packed function to the `InitializeWithTuneContext` function. */
+ FInitializeWithTuneContext f_initialize_with_tune_context;
+ /*! \brief The packed function to the `Apply` function. */
+ FApply f_apply;
+ /*! \brief The packed function to the `AsString` function. */
+ FAsString f_as_string;
+ /*! \brief The packed function to the `Clone` function. */
+ FClone f_clone;
+
+ void VisitAttrs(tvm::AttrVisitor* v) {
+ // `f_initialize_with_tune_context` is not visited
+ // `f_apply` is not visited
+ // `f_as_string` is not visited
+ // `f_clone` is not visited
+ }
+
+ void InitializeWithTuneContext(const TuneContext& context) final;
+ Array<tir::Schedule> Apply(const tir::Schedule& sch, const tir::BlockRV& block) final;
+ ScheduleRule Clone() const final;
+
+ static constexpr const char* _type_key = "meta_schedule.PyScheduleRule";
+ TVM_DECLARE_FINAL_OBJECT_INFO(PyScheduleRuleNode, ScheduleRuleNode);
+};
+
} // namespace meta_schedule
} // namespace tvm
diff --git a/include/tvm/meta_schedule/search_strategy.h b/include/tvm/meta_schedule/search_strategy.h
index a75a4cd8ae..efd3dc2452 100644
--- a/include/tvm/meta_schedule/search_strategy.h
+++ b/include/tvm/meta_schedule/search_strategy.h
@@ -36,6 +36,7 @@ namespace meta_schedule {
// Forward declaration
class TuneContext;
+class SearchStrategy;
/*!
* \brief The search strategy for measure candidates generation.
@@ -119,12 +120,21 @@ class SearchStrategyNode : public runtime::Object {
virtual void NotifyRunnerResults(const Array<MeasureCandidate>& measure_candidates,
const Array<RunnerResult>& results) = 0;
+ /*!
+ * \brief Clone the search strategy.
+ * \return The cloned search strategy.
+ */
+ virtual SearchStrategy Clone() const = 0;
+
static constexpr const char* _type_key = "meta_schedule.SearchStrategy";
TVM_DECLARE_BASE_OBJECT_INFO(SearchStrategyNode, Object);
};
-/*! \brief The python side customizable class for measure candidate generation */
-class PySearchStrategyNode : public SearchStrategyNode {
+/*!
+ * \brief Managed reference to SearchStrategyNode.
+ * \sa SearchStrategyNode
+ */
+class SearchStrategy : public runtime::ObjectRef {
public:
/*!
* \brief The function type of `InitializeWithTuneContext` method.
@@ -150,44 +160,11 @@ class PySearchStrategyNode : public SearchStrategyNode {
*/
using FNotifyRunnerResults =
runtime::TypedPackedFunc<void(const Array<MeasureCandidate>&, const Array<RunnerResult>&)>;
-
- /*! \brief The packed function to the `InitializeWithTuneContext` method. */
- FInitializeWithTuneContext f_initialize_with_tune_context;
- /*! \brief The packed function to the `PreTuning` method. */
- FPreTuning f_pre_tuning;
- /*! \brief The packed function to the `PostTuning` method. */
- FPostTuning f_post_tuning;
- /*! \brief The packed function to the `GenerateMeasureCandidates` method. */
- FGenerateMeasureCandidates f_generate_measure_candidates;
- /*! \brief The packed function to the `NotifyRunnerResults` method. */
- FNotifyRunnerResults f_notify_runner_results;
-
- void VisitAttrs(tvm::AttrVisitor* v) {
- // `f_initialize_with_tune_context` is not visited
- // `f_pre_tuning` is not visited
- // `f_post_tuning` is not visited
- // `f_generate_measure_candidates` is not visited
- // `f_notify_runner_results` is not visited
- }
-
- void InitializeWithTuneContext(const TuneContext& context) final;
- void PreTuning(const Array<tir::Schedule>& design_spaces, const Optional<Database>& database,
- const Optional<CostModel>& cost_model) final;
- void PostTuning() final;
- Optional<Array<MeasureCandidate>> GenerateMeasureCandidates() final;
- void NotifyRunnerResults(const Array<MeasureCandidate>& measure_candidates,
- const Array<RunnerResult>& results);
-
- static constexpr const char* _type_key = "meta_schedule.PySearchStrategy";
- TVM_DECLARE_FINAL_OBJECT_INFO(PySearchStrategyNode, SearchStrategyNode);
-};
-
-/*!
- * \brief Managed reference to SearchStrategyNode.
- * \sa SearchStrategyNode
- */
-class SearchStrategy : public runtime::ObjectRef {
- public:
+ /*!
+ * \brief The function type of `Clone` method.
+ * \return The cloned search strategy.
+ */
+ using FClone = runtime::TypedPackedFunc<SearchStrategy()>;
/*!
* \brief Create a search strategy with customized methods on the python-side.
* \param f_initialize_with_tune_context The packed function of `InitializeWithTuneContext`.
@@ -195,14 +172,16 @@ class SearchStrategy : public runtime::ObjectRef {
* \param f_post_tuning The packed function of `PostTuning`.
* \param f_generate_measure_candidates The packed function of `GenerateMeasureCandidates`.
* \param f_notify_runner_results The packed function of `NotifyRunnerResults`.
+ * \param f_clone The packed function of `Clone`.
* \return The search strategy created.
*/
TVM_DLL static SearchStrategy PySearchStrategy(
- PySearchStrategyNode::FInitializeWithTuneContext f_initialize_with_tune_context, //
- PySearchStrategyNode::FPreTuning f_pre_tuning, //
- PySearchStrategyNode::FPostTuning f_post_tuning, //
- PySearchStrategyNode::FGenerateMeasureCandidates f_generate_measure_candidates, //
- PySearchStrategyNode::FNotifyRunnerResults f_notify_runner_results);
+ FInitializeWithTuneContext f_initialize_with_tune_context, //
+ FPreTuning f_pre_tuning, //
+ FPostTuning f_post_tuning, //
+ FGenerateMeasureCandidates f_generate_measure_candidates, //
+ FNotifyRunnerResults f_notify_runner_results, //
+ FClone f_clone);
/*!
* \brief Constructor of replay trace search strategy.
@@ -245,6 +224,51 @@ class SearchStrategy : public runtime::ObjectRef {
TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(SearchStrategy, ObjectRef, SearchStrategyNode);
};
+/*! \brief The python side customizable class for measure candidate generation */
+class PySearchStrategyNode : public SearchStrategyNode {
+ public:
+ using FInitializeWithTuneContext = SearchStrategy::FInitializeWithTuneContext;
+ using FPreTuning = SearchStrategy::FPreTuning;
+ using FPostTuning = SearchStrategy::FPostTuning;
+ using FGenerateMeasureCandidates = SearchStrategy::FGenerateMeasureCandidates;
+ using FNotifyRunnerResults = SearchStrategy::FNotifyRunnerResults;
+ using FClone = SearchStrategy::FClone;
+
+ /*! \brief The packed function to the `InitializeWithTuneContext` method. */
+ FInitializeWithTuneContext f_initialize_with_tune_context;
+ /*! \brief The packed function to the `PreTuning` method. */
+ FPreTuning f_pre_tuning;
+ /*! \brief The packed function to the `PostTuning` method. */
+ FPostTuning f_post_tuning;
+ /*! \brief The packed function to the `GenerateMeasureCandidates` method. */
+ FGenerateMeasureCandidates f_generate_measure_candidates;
+ /*! \brief The packed function to the `NotifyRunnerResults` method. */
+ FNotifyRunnerResults f_notify_runner_results;
+ /*! \brief The packed function to the `Clone` method. */
+ FClone f_clone;
+
+ void VisitAttrs(tvm::AttrVisitor* v) {
+ // `f_initialize_with_tune_context` is not visited
+ // `f_pre_tuning` is not visited
+ // `f_post_tuning` is not visited
+ // `f_generate_measure_candidates` is not visited
+ // `f_notify_runner_results` is not visited
+ // `f_clone` is not visited
+ }
+
+ void InitializeWithTuneContext(const TuneContext& context) final;
+ void PreTuning(const Array<tir::Schedule>& design_spaces, const Optional<Database>& database,
+ const Optional<CostModel>& cost_model) final;
+ void PostTuning() final;
+ Optional<Array<MeasureCandidate>> GenerateMeasureCandidates() final;
+ void NotifyRunnerResults(const Array<MeasureCandidate>& measure_candidates,
+ const Array<RunnerResult>& results);
+ SearchStrategy Clone() const final;
+
+ static constexpr const char* _type_key = "meta_schedule.PySearchStrategy";
+ TVM_DECLARE_FINAL_OBJECT_INFO(PySearchStrategyNode, SearchStrategyNode);
+};
+
} // namespace meta_schedule
} // namespace tvm
diff --git a/include/tvm/meta_schedule/space_generator.h b/include/tvm/meta_schedule/space_generator.h
index 2c1b2d4e4d..1e29e757a1 100644
--- a/include/tvm/meta_schedule/space_generator.h
+++ b/include/tvm/meta_schedule/space_generator.h
@@ -31,6 +31,7 @@ namespace meta_schedule {
// Forward declaration
class TuneContext;
+class SpaceGenerator;
/*!
* \brief The abstract class for design space generation.
@@ -87,12 +88,21 @@ class SpaceGeneratorNode : public runtime::Object {
*/
virtual Array<tir::Schedule> GenerateDesignSpace(const IRModule& mod) = 0;
+ /*!
+ * \brief Clone the space generator.
+ * \return The cloned space generator.
+ */
+ virtual SpaceGenerator Clone() const = 0;
+
static constexpr const char* _type_key = "meta_schedule.SpaceGenerator";
TVM_DECLARE_BASE_OBJECT_INFO(SpaceGeneratorNode, Object);
};
-/*! \brief The design space generator with customized methods on the python-side. */
-class PySpaceGeneratorNode : public SpaceGeneratorNode {
+/*!
+ * \brief Managed reference to SpaceGeneratorNode.
+ * \sa SpaceGeneratorNode
+ */
+class SpaceGenerator : public runtime::ObjectRef {
public:
/*!
* \brief The function type of `InitializeWithTuneContext` method.
@@ -105,29 +115,12 @@ class PySpaceGeneratorNode : public SpaceGeneratorNode {
* \return The generated design spaces, i.e., schedules.
*/
using FGenerateDesignSpace = runtime::TypedPackedFunc<Array<tir::Schedule>(const IRModule&)>;
+ /*!
+ * \brief The function type of `Clone` method.
+ * \return The cloned space generator.
+ */
+ using FClone = runtime::TypedPackedFunc<SpaceGenerator()>;
- /*! \brief The packed function to the `InitializeWithTuneContext` function. */
- FInitializeWithTuneContext f_initialize_with_tune_context;
- /*! \brief The packed function to the `GenerateDesignSpace` function. */
- FGenerateDesignSpace f_generate_design_space;
-
- void VisitAttrs(tvm::AttrVisitor* v) {
- // `f_initialize_with_tune_context` is not visited
- // `f_generate_design_space` is not visited
- }
-
- void InitializeWithTuneContext(const TuneContext& context) final;
- Array<tir::Schedule> GenerateDesignSpace(const IRModule& mod) final;
-
- static constexpr const char* _type_key = "meta_schedule.PySpaceGenerator";
- TVM_DECLARE_FINAL_OBJECT_INFO(PySpaceGeneratorNode, SpaceGeneratorNode);
-};
-
-/*!
- * \brief Managed reference to SpaceGeneratorNode.
- * \sa SpaceGeneratorNode
- */
-class SpaceGenerator : public runtime::ObjectRef {
protected:
SpaceGenerator() = default;
@@ -136,11 +129,12 @@ class SpaceGenerator : public runtime::ObjectRef {
* \brief Create a design space generator with customized methods on the python-side.
* \param f_initialize_with_tune_context The packed function of `InitializeWithTuneContext`.
* \param f_generate_design_space The packed function of `GenerateDesignSpace`.
+ * \param f_clone The packed function of `Clone`.
* \return The design space generator created.
*/
TVM_DLL static SpaceGenerator PySpaceGenerator(
- PySpaceGeneratorNode::FInitializeWithTuneContext f_initialize_with_tune_context,
- PySpaceGeneratorNode::FGenerateDesignSpace f_generate_design_space);
+ FInitializeWithTuneContext f_initialize_with_tune_context,
+ FGenerateDesignSpace f_generate_design_space, FClone f_clone);
/*!
* \brief Create a design space generator with customized schedule function.
* \param schedule_fn The schedule function, which can have the following signatures:
@@ -156,14 +150,40 @@ class SpaceGenerator : public runtime::ObjectRef {
*/
TVM_DLL static SpaceGenerator SpaceGeneratorUnion(Array<SpaceGenerator, void> space_generators);
/*!
- * \brief Create a design space generator that generates design spaces by applying schedule rules
- * to blocks in post-DFS order.
- * \return The design space generator created.
+ * \brief Create a design space generator applying schedule rules to blocks in post-DFS order.
+ * \return The design space generator created.
*/
TVM_DLL static SpaceGenerator PostOrderApply(runtime::PackedFunc f_block_filter = nullptr);
TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(SpaceGenerator, ObjectRef, SpaceGeneratorNode);
};
+/*! \brief The design space generator with customized methods on the python-side. */
+class PySpaceGeneratorNode : public SpaceGeneratorNode {
+ public:
+ using FInitializeWithTuneContext = SpaceGenerator::FInitializeWithTuneContext;
+ using FGenerateDesignSpace = SpaceGenerator::FGenerateDesignSpace;
+ using FClone = SpaceGenerator::FClone;
+ /*! \brief The packed function to the `InitializeWithTuneContext` function. */
+ FInitializeWithTuneContext f_initialize_with_tune_context;
+ /*! \brief The packed function to the `GenerateDesignSpace` function. */
+ FGenerateDesignSpace f_generate_design_space;
+ /*! \brief The packed function to the `Clone` function. */
+ FClone f_clone;
+
+ void VisitAttrs(tvm::AttrVisitor* v) {
+ // `f_initialize_with_tune_context` is not visited
+ // `f_generate_design_space` is not visited
+ // `f_clone` is not visited
+ }
+
+ void InitializeWithTuneContext(const TuneContext& context) final;
+ Array<tir::Schedule> GenerateDesignSpace(const IRModule& mod) final;
+ SpaceGenerator Clone() const final;
+
+ static constexpr const char* _type_key = "meta_schedule.PySpaceGenerator";
+ TVM_DECLARE_FINAL_OBJECT_INFO(PySpaceGeneratorNode, SpaceGeneratorNode);
+};
+
} // namespace meta_schedule
} // namespace tvm
diff --git a/include/tvm/meta_schedule/tune_context.h b/include/tvm/meta_schedule/tune_context.h
index 3d732e7fbd..4e2f00fb5a 100644
--- a/include/tvm/meta_schedule/tune_context.h
+++ b/include/tvm/meta_schedule/tune_context.h
@@ -43,6 +43,7 @@ namespace meta_schedule {
class TaskSchedulerNode;
class MeasureCallback;
+class TuneContext;
/*! \brief The auto tuning context. */
class TuneContextNode : public runtime::Object {
@@ -99,6 +100,11 @@ class TuneContextNode : public runtime::Object {
/*! \brief Initialize members that needs initialization with tune context. */
void Initialize();
+ /*!
+ * \brief Clone the tune context.
+ * \return The cloned tune context.
+ */
+ TuneContext Clone() const;
/*! \brief Set the measure candidates from the SearchStrategy */
void _SetMeasureCandidates(const Array<MeasureCandidate>& candidates);
/*!
diff --git a/python/tvm/meta_schedule/mutator/mutator.py b/python/tvm/meta_schedule/mutator/mutator.py
index 0c8de96680..c5286aced7 100644
--- a/python/tvm/meta_schedule/mutator/mutator.py
+++ b/python/tvm/meta_schedule/mutator/mutator.py
@@ -58,6 +58,16 @@ class Mutator(Object):
"""
return _ffi_api.MutatorApply(self, trace, -1) # type: ignore # pylint: disable=no-member
+ def clone(self) -> "Mutator":
+ """Clone the mutator.
+
+ Returns
+ -------
+ mutator : Mutator
+ The cloned mutator.
+ """
+ return _ffi_api.MutatorClone(self) # type: ignore # pylint: disable=no-member
+
@register_object("meta_schedule.PyMutator")
class _PyMutator(Mutator):
@@ -72,6 +82,7 @@ class _PyMutator(Mutator):
self,
f_initialize_with_tune_context: Callable = None,
f_apply: Callable = None,
+ f_clone: Callable = None,
f_as_string: Callable = None,
):
"""Constructor."""
@@ -80,6 +91,7 @@ class _PyMutator(Mutator):
_ffi_api.MutatorPyMutator, # type: ignore # pylint: disable=no-member
f_initialize_with_tune_context,
f_apply,
+ f_clone,
f_as_string,
)
@@ -94,7 +106,7 @@ class PyMutator:
_tvm_metadata = {
"cls": _PyMutator,
- "methods": ["_initialize_with_tune_context", "apply", "__str__"],
+ "methods": ["_initialize_with_tune_context", "apply", "clone", "__str__"],
}
def _initialize_with_tune_context(self, context: "TuneContext") -> None:
@@ -122,6 +134,16 @@ class PyMutator:
"""
raise NotImplementedError
+ def clone(self) -> Mutator:
+ """Clone the mutator.
+
+ Returns
+ -------
+ mutator : Mutator
+ The cloned mutator.
+ """
+ raise NotImplementedError
+
def __str__(self) -> str:
"""Get the mutator as string with name.
diff --git a/python/tvm/meta_schedule/postproc/postproc.py b/python/tvm/meta_schedule/postproc/postproc.py
index e37666bd1c..6eec2965ce 100644
--- a/python/tvm/meta_schedule/postproc/postproc.py
+++ b/python/tvm/meta_schedule/postproc/postproc.py
@@ -60,6 +60,16 @@ class Postproc(Object):
"""
return _ffi_api.PostprocApply(self, sch) # type: ignore # pylint: disable=no-member
+ def clone(self) -> "Postproc":
+ """Clone the postprocessor.
+
+ Returns
+ -------
+ cloned_postproc : Postproc
+ The cloned postprocessor.
+ """
+ return _ffi_api.PostprocClone(self) # type: ignore # pylint: disable=no-member
+
@register_object("meta_schedule.PyPostproc")
class _PyPostproc(Postproc):
@@ -74,6 +84,7 @@ class _PyPostproc(Postproc):
self,
f_initialize_with_tune_context: Callable = None,
f_apply: Callable = None,
+ f_clone: Callable = None,
f_as_string: Callable = None,
):
"""Constructor."""
@@ -82,6 +93,7 @@ class _PyPostproc(Postproc):
_ffi_api.PostprocPyPostproc, # type: ignore # pylint: disable=no-member
f_initialize_with_tune_context,
f_apply,
+ f_clone,
f_as_string,
)
@@ -96,7 +108,7 @@ class PyPostproc:
_tvm_metadata = {
"cls": _PyPostproc,
- "methods": ["_initialize_with_tune_context", "apply", "__str__"],
+ "methods": ["_initialize_with_tune_context", "apply", "clone", "__str__"],
}
def _initialize_with_tune_context(self, context: "TuneContext") -> None:
@@ -124,6 +136,16 @@ class PyPostproc:
"""
raise NotImplementedError
+ def clone(self) -> Postproc:
+ """Clone the postprocessor.
+
+ Returns
+ -------
+ cloned_postproc : Postproc
+ The cloned postprocessor.
+ """
+ raise NotImplementedError
+
def __str__(self) -> str:
"""Get the post processor as string with name.
diff --git a/python/tvm/meta_schedule/schedule_rule/schedule_rule.py b/python/tvm/meta_schedule/schedule_rule/schedule_rule.py
index 481444341b..2c8e223611 100644
--- a/python/tvm/meta_schedule/schedule_rule/schedule_rule.py
+++ b/python/tvm/meta_schedule/schedule_rule/schedule_rule.py
@@ -66,6 +66,16 @@ class ScheduleRule(Object):
self, sch, block
)
+ def clone(self) -> "ScheduleRule":
+ """Deep clone the schedule rule.
+
+ Returns
+ -------
+ cloned_rule : ScheduleRule
+ The cloned schedule rule.
+ """
+ return _ffi_api.ScheduleRuleClone(self) # type: ignore # pylint: disable=no-member
+
@register_object("meta_schedule.PyScheduleRule")
class _PyScheduleRule(ScheduleRule):
@@ -80,6 +90,7 @@ class _PyScheduleRule(ScheduleRule):
self,
f_initialize_with_tune_context: Callable = None,
f_apply: Callable = None,
+ f_clone: Callable = None,
f_as_string: Callable = None,
):
"""Constructor."""
@@ -88,6 +99,7 @@ class _PyScheduleRule(ScheduleRule):
_ffi_api.ScheduleRulePyScheduleRule, # type: ignore # pylint: disable=no-member
f_initialize_with_tune_context,
f_apply,
+ f_clone,
f_as_string,
)
@@ -102,7 +114,7 @@ class PyScheduleRule:
_tvm_metadata = {
"cls": _PyScheduleRule,
- "methods": ["_initialize_with_tune_context", "apply", "__str__"],
+ "methods": ["_initialize_with_tune_context", "apply", "clone", "__str__"],
}
def _initialize_with_tune_context(self, context: "TuneContext") -> None:
@@ -113,9 +125,7 @@ class PyScheduleRule:
context : TuneContext
The tuning context for initializing the schedule rule.
"""
- _ffi_api.ScheduleRuleInitializeWithTuneContext( # type: ignore # pylint: disable=no-member
- self, context
- )
+ raise NotImplementedError
def apply(self, sch: Schedule, block: BlockRV) -> List[Schedule]:
"""Apply a schedule rule to the specific block in the given schedule.
@@ -132,9 +142,17 @@ class PyScheduleRule:
design_spaces : List[Schedule]
The list of schedules generated by applying the schedule rule.
"""
- return _ffi_api.ScheduleRuleApply( # type: ignore # pylint: disable=no-member
- self, sch, block
- )
+ raise NotImplementedError
+
+ def clone(self) -> ScheduleRule:
+ """Deep clone the schedule rule.
+
+ Returns
+ -------
+ cloned_rule : ScheduleRule
+ The cloned schedule rule.
+ """
+ raise NotImplementedError
def __str__(self) -> str:
"""Get the schedule rule as string with name.
diff --git a/python/tvm/meta_schedule/search_strategy/search_strategy.py b/python/tvm/meta_schedule/search_strategy/search_strategy.py
index e88cdf825a..276e657133 100644
--- a/python/tvm/meta_schedule/search_strategy/search_strategy.py
+++ b/python/tvm/meta_schedule/search_strategy/search_strategy.py
@@ -151,6 +151,16 @@ class SearchStrategy(Object):
results,
)
+ def clone(self) -> "SearchStrategy":
+ """Clone the search strategy.
+
+ Returns
+ -------
+ cloned : SearchStrategy
+ The cloned search strategy.
+ """
+ return _ffi_api.SearchStrategyClone(self) # type: ignore # pylint: disable=no-member
+
@register_object("meta_schedule.PySearchStrategy")
class _PySearchStrategy(SearchStrategy):
@@ -168,6 +178,7 @@ class _PySearchStrategy(SearchStrategy):
f_post_tuning: Callable = None,
f_generate_measure_candidates: Callable = None,
f_notify_runner_results: Callable = None,
+ f_clone: Callable = None,
):
"""Constructor."""
@@ -178,6 +189,7 @@ class _PySearchStrategy(SearchStrategy):
f_post_tuning,
f_generate_measure_candidates,
f_notify_runner_results,
+ f_clone,
)
@@ -197,6 +209,7 @@ class PySearchStrategy:
"post_tuning",
"generate_measure_candidates",
"notify_runner_results",
+ "clone",
],
}
@@ -250,6 +263,16 @@ class PySearchStrategy:
"""
raise NotImplementedError
+ def clone(self) -> SearchStrategy:
+ """Clone the search strategy.
+
+ Returns
+ -------
+ strategy : SearchStrategy
+ The cloned search strategy.
+ """
+ raise NotImplementedError
+
def create( # pylint: disable=keyword-arg-before-vararg
kind: Literal[
diff --git a/python/tvm/meta_schedule/space_generator/space_generator.py b/python/tvm/meta_schedule/space_generator/space_generator.py
index 9d7ebf3bae..23c0361645 100644
--- a/python/tvm/meta_schedule/space_generator/space_generator.py
+++ b/python/tvm/meta_schedule/space_generator/space_generator.py
@@ -72,6 +72,16 @@ class SpaceGenerator(Object):
"""
return _ffi_api.SpaceGeneratorGenerateDesignSpace(self, mod) # type: ignore # pylint: disable=no-member
+ def clone(self) -> "SpaceGenerator":
+ """Clone the design space generator.
+
+ Returns
+ -------
+ cloned_sg : SpaceGenerator
+ The cloned design space generator.
+ """
+ return _ffi_api.SpaceGeneratorClone(self) # type: ignore # pylint: disable=no-member
+
ScheduleFnType = SpaceGenerator.ScheduleFnType
@@ -89,6 +99,7 @@ class _PySpaceGenerator(SpaceGenerator):
self,
f_initialize_with_tune_context: Optional[Callable] = None,
f_generate_design_space: Optional[Callable] = None,
+ f_clone: Optional[Callable] = None,
):
"""Constructor."""
@@ -96,6 +107,7 @@ class _PySpaceGenerator(SpaceGenerator):
_ffi_api.SpaceGeneratorPySpaceGenerator, # type: ignore # pylint: disable=no-member
f_initialize_with_tune_context,
f_generate_design_space,
+ f_clone,
)
@@ -109,7 +121,7 @@ class PySpaceGenerator:
_tvm_metadata = {
"cls": _PySpaceGenerator,
- "methods": ["_initialize_with_tune_context", "generate_design_space"],
+ "methods": ["_initialize_with_tune_context", "generate_design_space", "clone"],
}
def _initialize_with_tune_context(self, context: "TuneContext") -> None:
@@ -137,6 +149,16 @@ class PySpaceGenerator:
"""
raise NotImplementedError
+ def clone(self) -> SpaceGenerator:
+ """Clone the design space generator.
+
+ Returns
+ -------
+ cloned_sg : SpaceGenerator
+ The cloned design space generator.
+ """
+ raise NotImplementedError
+
def create( # pylint: disable=keyword-arg-before-vararg
kind: Union[
diff --git a/python/tvm/meta_schedule/testing/dummy_object.py b/python/tvm/meta_schedule/testing/dummy_object.py
index 50ae974df5..bb22945449 100644
--- a/python/tvm/meta_schedule/testing/dummy_object.py
+++ b/python/tvm/meta_schedule/testing/dummy_object.py
@@ -58,3 +58,6 @@ class DummyMutator(PyMutator):
def apply(self, trace: Trace, _) -> Optional[Trace]:
return Trace(trace.insts, {})
+
+ def clone(self):
+ return DummyMutator()
diff --git a/python/tvm/meta_schedule/tune_context.py b/python/tvm/meta_schedule/tune_context.py
index 17acad8d4a..29cd94110c 100644
--- a/python/tvm/meta_schedule/tune_context.py
+++ b/python/tvm/meta_schedule/tune_context.py
@@ -331,3 +331,13 @@ class TuneContext(Object):
"Please construct TuneContext with search_strategy"
)
return self.search_strategy.notify_runner_results(measure_candidates, results)
+
+ def clone(self) -> "TuneContext":
+ """Clone the TuneContext.
+
+ Returns
+ -------
+ cloned_context : TuneContext
+ The cloned TuneContext.
+ """
+ return _ffi_api.TuneContextClone(self) # type: ignore # pylint: disable=no-member
diff --git a/src/meta_schedule/mutator/mutate_compute_location.cc b/src/meta_schedule/mutator/mutate_compute_location.cc
index 9d6d69ba35..2a31d2da9b 100644
--- a/src/meta_schedule/mutator/mutate_compute_location.cc
+++ b/src/meta_schedule/mutator/mutate_compute_location.cc
@@ -42,6 +42,11 @@ class MutateComputeLocationNode : public MutatorNode {
}
// Inherit from `MutatorNode`
Optional<Trace> Apply(const Trace& trace, TRandState* rand_state) final;
+ // Inherit from `MutatorNode`
+ Mutator Clone() const final {
+ ObjectPtr<MutateComputeLocationNode> n = make_object<MutateComputeLocationNode>(*this);
+ return Mutator(n);
+ }
private:
struct Candidate {
diff --git a/src/meta_schedule/mutator/mutate_parallel.cc b/src/meta_schedule/mutator/mutate_parallel.cc
index 82b91da682..9feb4747d8 100644
--- a/src/meta_schedule/mutator/mutate_parallel.cc
+++ b/src/meta_schedule/mutator/mutate_parallel.cc
@@ -188,6 +188,11 @@ class MutateParallelNode : public MutatorNode {
}
// Inherit from `MutatorNode`
Optional<Trace> Apply(const Trace& trace, TRandState* rand_state) final;
+ // Inherit from `MutatorNode`
+ Mutator Clone() const final {
+ ObjectPtr<MutateParallelNode> n = make_object<MutateParallelNode>(*this);
+ return Mutator(n);
+ }
};
/*! \brief The candidate to be mutated */
diff --git a/src/meta_schedule/mutator/mutate_thread_binding.cc b/src/meta_schedule/mutator/mutate_thread_binding.cc
index de780b53e2..f5d89a8509 100644
--- a/src/meta_schedule/mutator/mutate_thread_binding.cc
+++ b/src/meta_schedule/mutator/mutate_thread_binding.cc
@@ -42,6 +42,11 @@ class MutateThreadBindingNode : public MutatorNode {
}
// Inherit from `MutatorNode`
Optional<Trace> Apply(const Trace& trace, TRandState* rand_state) final;
+ // Inherit from `MutatorNode`
+ Mutator Clone() const final {
+ ObjectPtr<MutateThreadBindingNode> n = make_object<MutateThreadBindingNode>(*this);
+ return Mutator(n);
+ }
private:
struct Candidate {
diff --git a/src/meta_schedule/mutator/mutate_tile_size.cc b/src/meta_schedule/mutator/mutate_tile_size.cc
index 4a3bfda8a4..8fb83147ea 100644
--- a/src/meta_schedule/mutator/mutate_tile_size.cc
+++ b/src/meta_schedule/mutator/mutate_tile_size.cc
@@ -63,6 +63,11 @@ class MutateTileSizeNode : public MutatorNode {
void InitializeWithTuneContext(const TuneContext& context) final {}
// Inherit from `MutatorNode`
Optional<Trace> Apply(const Trace& trace, TRandState* rand_state) final;
+ // Inherit from `MutatorNode`
+ Mutator Clone() const final {
+ ObjectPtr<MutateTileSizeNode> n = make_object<MutateTileSizeNode>(*this);
+ return Mutator(n);
+ }
};
/*!
diff --git a/src/meta_schedule/mutator/mutate_unroll.cc b/src/meta_schedule/mutator/mutate_unroll.cc
index c282a171c3..7bbf00343a 100644
--- a/src/meta_schedule/mutator/mutate_unroll.cc
+++ b/src/meta_schedule/mutator/mutate_unroll.cc
@@ -60,6 +60,11 @@ class MutateUnrollNode : public MutatorNode {
void InitializeWithTuneContext(const TuneContext& context) final {}
// Inherit from `MutatorNode`
Optional<Trace> Apply(const Trace& trace, TRandState* rand_state) final;
+ // Inherit from `MutatorNode`
+ Mutator Clone() const final {
+ ObjectPtr<MutateUnrollNode> n = make_object<MutateUnrollNode>(*this);
+ return Mutator(n);
+ }
};
/*! \brief A candidate to be mutated */
diff --git a/src/meta_schedule/mutator/mutator.cc b/src/meta_schedule/mutator/mutator.cc
index 43b95000c7..25312ab61f 100644
--- a/src/meta_schedule/mutator/mutator.cc
+++ b/src/meta_schedule/mutator/mutator.cc
@@ -33,13 +33,20 @@ Optional<tir::Trace> PyMutatorNode::Apply(
return f_apply(trace, *rand_state);
}
+Mutator PyMutatorNode::Clone() const {
+ ICHECK(f_clone != nullptr) << "PyMutator's Clone method not implemented!";
+ return f_clone();
+}
+
Mutator Mutator::PyMutator(
PyMutatorNode::FInitializeWithTuneContext f_initialize_with_tune_context, //
PyMutatorNode::FApply f_apply, //
+ PyMutatorNode::FClone f_clone, //
PyMutatorNode::FAsString f_as_string) {
ObjectPtr<PyMutatorNode> n = make_object<PyMutatorNode>();
n->f_initialize_with_tune_context = std::move(f_initialize_with_tune_context);
n->f_apply = std::move(f_apply);
+ n->f_clone = std::move(f_clone);
n->f_as_string = std::move(f_as_string);
return Mutator(n);
}
@@ -63,6 +70,7 @@ TVM_REGISTER_GLOBAL("meta_schedule.MutatorApply")
TRandState seed_ = (seed != -1) ? seed : support::LinearCongruentialEngine::DeviceRandom();
return self->Apply(trace, &seed_);
});
+TVM_REGISTER_GLOBAL("meta_schedule.MutatorClone").set_body_method<Mutator>(&MutatorNode::Clone);
TVM_REGISTER_GLOBAL("meta_schedule.MutatorPyMutator").set_body_typed(Mutator::PyMutator);
} // namespace meta_schedule
diff --git a/src/meta_schedule/postproc/disallow_dynamic_loop.cc b/src/meta_schedule/postproc/disallow_dynamic_loop.cc
index 85a81f10fd..8362da552e 100644
--- a/src/meta_schedule/postproc/disallow_dynamic_loop.cc
+++ b/src/meta_schedule/postproc/disallow_dynamic_loop.cc
@@ -67,6 +67,11 @@ class DisallowDynamicLoopNode : public PostprocNode {
void InitializeWithTuneContext(const TuneContext& context) final {}
// Inherited from PostprocNode
bool Apply(const tir::Schedule& sch) final { return !tir::DynamicExtentFinder::Find(sch->mod()); }
+ // Inherited from PostprocNode
+ Postproc Clone() const final {
+ ObjectPtr<DisallowDynamicLoopNode> n = make_object<DisallowDynamicLoopNode>(*this);
+ return Postproc(n);
+ }
static constexpr const char* _type_key = "meta_schedule.DisallowDynamicLoop";
TVM_DECLARE_FINAL_OBJECT_INFO(DisallowDynamicLoopNode, PostprocNode);
diff --git a/src/meta_schedule/postproc/postproc.cc b/src/meta_schedule/postproc/postproc.cc
index 0f4f1b1192..957d6e7364 100644
--- a/src/meta_schedule/postproc/postproc.cc
+++ b/src/meta_schedule/postproc/postproc.cc
@@ -32,13 +32,20 @@ bool PyPostprocNode::Apply(const tir::Schedule& sch) {
return f_apply(sch);
}
+Postproc PyPostprocNode::Clone() const {
+ ICHECK(f_clone != nullptr) << "PyPostproc's Clone method not implemented!";
+ return f_clone();
+}
+
Postproc Postproc::PyPostproc(
PyPostprocNode::FInitializeWithTuneContext f_initialize_with_tune_context, //
PyPostprocNode::FApply f_apply, //
+ PyPostprocNode::FClone f_clone, //
PyPostprocNode::FAsString f_as_string) {
ObjectPtr<PyPostprocNode> n = make_object<PyPostprocNode>();
n->f_initialize_with_tune_context = std::move(f_initialize_with_tune_context);
n->f_apply = std::move(f_apply);
+ n->f_clone = std::move(f_clone);
n->f_as_string = std::move(f_as_string);
return Postproc(n);
}
@@ -58,6 +65,7 @@ TVM_REGISTER_NODE_TYPE(PyPostprocNode);
TVM_REGISTER_GLOBAL("meta_schedule.PostprocInitializeWithTuneContext")
.set_body_method<Postproc>(&PostprocNode::InitializeWithTuneContext);
TVM_REGISTER_GLOBAL("meta_schedule.PostprocApply").set_body_method<Postproc>(&PostprocNode::Apply);
+TVM_REGISTER_GLOBAL("meta_schedule.PostprocClone").set_body_method<Postproc>(&PostprocNode::Clone);
TVM_REGISTER_GLOBAL("meta_schedule.PostprocPyPostproc").set_body_typed(Postproc::PyPostproc);
} // namespace meta_schedule
diff --git a/src/meta_schedule/postproc/rewrite_cooperative_fetch.cc b/src/meta_schedule/postproc/rewrite_cooperative_fetch.cc
index d111bdb42a..ac9f45ca8e 100644
--- a/src/meta_schedule/postproc/rewrite_cooperative_fetch.cc
+++ b/src/meta_schedule/postproc/rewrite_cooperative_fetch.cc
@@ -104,6 +104,11 @@ class RewriteCooperativeFetchNode : public PostprocNode {
// Inherited from PostprocNode
bool Apply(const tir::Schedule& sch) final;
+ Postproc Clone() const final {
+ ObjectPtr<RewriteCooperativeFetchNode> n = make_object<RewriteCooperativeFetchNode>(*this);
+ return Postproc(n);
+ }
+
void VisitAttrs(tvm::AttrVisitor* v) {}
static constexpr const char* _type_key = "meta_schedule.RewriteCooperativeFetch";
diff --git a/src/meta_schedule/postproc/rewrite_layout.cc b/src/meta_schedule/postproc/rewrite_layout.cc
index f4cbdfe737..6ff9958c79 100644
--- a/src/meta_schedule/postproc/rewrite_layout.cc
+++ b/src/meta_schedule/postproc/rewrite_layout.cc
@@ -167,6 +167,11 @@ class RewriteLayoutNode : public PostprocNode {
// Inherited from PostprocNode
bool Apply(const tir::Schedule& sch) final { return tir::RewriteLayout(sch); }
+ Postproc Clone() const final {
+ ObjectPtr<RewriteLayoutNode> n = make_object<RewriteLayoutNode>(*this);
+ return Postproc(n);
+ }
+
static constexpr const char* _type_key = "meta_schedule.RewriteLayout";
TVM_DECLARE_FINAL_OBJECT_INFO(RewriteLayoutNode, PostprocNode);
};
diff --git a/src/meta_schedule/postproc/rewrite_parallel_vectorize_unroll.cc b/src/meta_schedule/postproc/rewrite_parallel_vectorize_unroll.cc
index 08d25d0178..c3cc0ef601 100644
--- a/src/meta_schedule/postproc/rewrite_parallel_vectorize_unroll.cc
+++ b/src/meta_schedule/postproc/rewrite_parallel_vectorize_unroll.cc
@@ -384,6 +384,12 @@ class RewriteParallelVectorizeUnrollNode : public PostprocNode {
return true;
}
+ Postproc Clone() const final {
+ ObjectPtr<RewriteParallelVectorizeUnrollNode> n =
+ make_object<RewriteParallelVectorizeUnrollNode>(*this);
+ return Postproc(n);
+ }
+
static constexpr const char* _type_key = "meta_schedule.RewriteParallelVectorizeUnroll";
TVM_DECLARE_FINAL_OBJECT_INFO(RewriteParallelVectorizeUnrollNode, PostprocNode);
};
diff --git a/src/meta_schedule/postproc/rewrite_reduction_block.cc b/src/meta_schedule/postproc/rewrite_reduction_block.cc
index ea204e3061..05a7640f04 100644
--- a/src/meta_schedule/postproc/rewrite_reduction_block.cc
+++ b/src/meta_schedule/postproc/rewrite_reduction_block.cc
@@ -114,6 +114,11 @@ class RewriteReductionBlockNode : public PostprocNode {
// Inherited from PostprocNode
bool Apply(const tir::Schedule& sch) final;
+ Postproc Clone() const final {
+ ObjectPtr<RewriteReductionBlockNode> n = make_object<RewriteReductionBlockNode>(*this);
+ return Postproc(n);
+ }
+
void VisitAttrs(tvm::AttrVisitor* v) {}
static constexpr const char* _type_key = "meta_schedule.RewriteReductionBlock";
diff --git a/src/meta_schedule/postproc/rewrite_tensorize.cc b/src/meta_schedule/postproc/rewrite_tensorize.cc
index 3b6c438d02..4f8e0fb213 100644
--- a/src/meta_schedule/postproc/rewrite_tensorize.cc
+++ b/src/meta_schedule/postproc/rewrite_tensorize.cc
@@ -68,6 +68,11 @@ class RewriteTensorizeNode : public PostprocNode {
void VisitAttrs(tvm::AttrVisitor* v) {}
+ Postproc Clone() const final {
+ ObjectPtr<RewriteTensorizeNode> n = make_object<RewriteTensorizeNode>(*this);
+ return Postproc(n);
+ }
+
bool vectorize_init_loop = false;
static constexpr const char* _type_key = "meta_schedule.RewriteTensorize";
diff --git a/src/meta_schedule/postproc/rewrite_unbound_block.cc b/src/meta_schedule/postproc/rewrite_unbound_block.cc
index eb57e90f82..1ba68538ea 100644
--- a/src/meta_schedule/postproc/rewrite_unbound_block.cc
+++ b/src/meta_schedule/postproc/rewrite_unbound_block.cc
@@ -97,6 +97,11 @@ class RewriteUnboundBlockNode : public PostprocNode {
// Inherited from PostprocNode
bool Apply(const tir::Schedule& sch) final;
+ Postproc Clone() const final {
+ ObjectPtr<RewriteUnboundBlockNode> n = make_object<RewriteUnboundBlockNode>(*this);
+ return Postproc(n);
+ }
+
public:
/*! \brief The max number of threads per block from Target */
int max_threads_per_block_ = -1;
diff --git a/src/meta_schedule/postproc/verify_gpu_code.cc b/src/meta_schedule/postproc/verify_gpu_code.cc
index dfe2c5a06a..0828ee5384 100644
--- a/src/meta_schedule/postproc/verify_gpu_code.cc
+++ b/src/meta_schedule/postproc/verify_gpu_code.cc
@@ -196,6 +196,12 @@ class VerifyGPUCodeNode : public PostprocNode {
return true;
}
+ Postproc Clone() const final {
+ ObjectPtr<VerifyGPUCodeNode> n = make_object<VerifyGPUCodeNode>(*this);
+ // target_constraints_ is already copied by the copy constructor above.
+ return Postproc(n);
+ }
+
static constexpr const char* _type_key = "meta_schedule.VerifyGPUCode";
TVM_DECLARE_FINAL_OBJECT_INFO(VerifyGPUCodeNode, PostprocNode);
};
diff --git a/src/meta_schedule/schedule_rule/add_rfactor.cc b/src/meta_schedule/schedule_rule/add_rfactor.cc
index cf87f24ac2..2fc1352677 100644
--- a/src/meta_schedule/schedule_rule/add_rfactor.cc
+++ b/src/meta_schedule/schedule_rule/add_rfactor.cc
@@ -36,6 +36,12 @@ class AddRFactorNode : public ScheduleRuleNode {
// Inherited from ScheduleRuleNode
Array<tir::Schedule> Apply(const tir::Schedule& sch, const tir::BlockRV& block_rv);
+ // Inherited from ScheduleRuleNode
+ ScheduleRule Clone() const final {
+ ObjectPtr<AddRFactorNode> n = make_object<AddRFactorNode>(*this);
+ return ScheduleRule(n);
+ }
+
public:
/*!
* \brief The maximum number of jobs to be launched per core.
diff --git a/src/meta_schedule/schedule_rule/auto_bind.cc b/src/meta_schedule/schedule_rule/auto_bind.cc
index d8f52fa8e1..7af1418d8f 100644
--- a/src/meta_schedule/schedule_rule/auto_bind.cc
+++ b/src/meta_schedule/schedule_rule/auto_bind.cc
@@ -177,6 +177,12 @@ class AutoBindNode : public ScheduleRuleNode {
// Inherited from ScheduleRuleNode
Array<tir::Schedule> Apply(const tir::Schedule& sch, const tir::BlockRV& block_rv) final;
+ // Inherited from ScheduleRuleNode
+ ScheduleRule Clone() const final {
+ ObjectPtr<AutoBindNode> n = make_object<AutoBindNode>(*this);
+ return ScheduleRule(n);
+ }
+
public:
/*! \brief The max number of threads per block from Target */
int64_t max_threads_per_block_ = -1;
diff --git a/src/meta_schedule/schedule_rule/auto_inline.cc b/src/meta_schedule/schedule_rule/auto_inline.cc
index 446c8ead7e..dcdc83f95c 100644
--- a/src/meta_schedule/schedule_rule/auto_inline.cc
+++ b/src/meta_schedule/schedule_rule/auto_inline.cc
@@ -60,6 +60,12 @@ class AutoInlineNode : public ScheduleRuleNode {
return {sch};
}
+ // Inherited from ScheduleRuleNode
+ ScheduleRule Clone() const final {
+ ObjectPtr<AutoInlineNode> n = make_object<AutoInlineNode>(*this);
+ return ScheduleRule(n);
+ }
+
public:
/*! \brief If allows to inline a block into its producer */
bool into_producer;
diff --git a/src/meta_schedule/schedule_rule/cross_thread_reduction.cc b/src/meta_schedule/schedule_rule/cross_thread_reduction.cc
index 35be33f72e..f2fc67f74c 100644
--- a/src/meta_schedule/schedule_rule/cross_thread_reduction.cc
+++ b/src/meta_schedule/schedule_rule/cross_thread_reduction.cc
@@ -113,6 +113,12 @@ class CrossThreadReductionNode : public ScheduleRuleNode {
return {tmp_sch, sch};
}
+ // Inherited from ScheduleRuleNode
+ ScheduleRule Clone() const final {
+ ObjectPtr<CrossThreadReductionNode> n = make_object<CrossThreadReductionNode>(*this);
+ return ScheduleRule(n);
+ }
+
private:
/*!
* \brief Check whether the input block is in thread scope, i.e., some of its outer loop is
diff --git a/src/meta_schedule/schedule_rule/multi_level_tiling.cc b/src/meta_schedule/schedule_rule/multi_level_tiling.cc
index c126c85446..1625a27b9a 100644
--- a/src/meta_schedule/schedule_rule/multi_level_tiling.cc
+++ b/src/meta_schedule/schedule_rule/multi_level_tiling.cc
@@ -104,6 +104,12 @@ Array<Schedule> MultiLevelTilingNode::Apply(const Schedule& sch, const BlockRV&
return results;
}
+// Inherited from ScheduleRuleNode
+ScheduleRule MultiLevelTilingNode::Clone() const {
+ ObjectPtr<MultiLevelTilingNode> n = make_object<MultiLevelTilingNode>(*this);
+ return ScheduleRule(n);
+}
+
std::vector<State> MultiLevelTilingNode::ApplySubRules(std::vector<State> states) {
states = SubRule(std::move(states), [&](State state) { return TileLoopNest(std::move(state)); });
states = SubRule(std::move(states), [&](State state) { return AddWriteReuse(std::move(state)); });
diff --git a/src/meta_schedule/schedule_rule/multi_level_tiling.h b/src/meta_schedule/schedule_rule/multi_level_tiling.h
index 9161a972c1..47da878c3b 100644
--- a/src/meta_schedule/schedule_rule/multi_level_tiling.h
+++ b/src/meta_schedule/schedule_rule/multi_level_tiling.h
@@ -155,6 +155,9 @@ class MultiLevelTilingNode : public ScheduleRuleNode {
// Entry of the mega rule; Inherited from ScheduleRuleNode
Array<tir::Schedule> Apply(const tir::Schedule& sch, const tir::BlockRV& block_rv) override;
+ // Inherited from ScheduleRuleNode
+ ScheduleRule Clone() const override;
+
protected:
virtual std::vector<State> ApplySubRules(std::vector<State> states);
diff --git a/src/meta_schedule/schedule_rule/multi_level_tiling_tensor_core.cc b/src/meta_schedule/schedule_rule/multi_level_tiling_tensor_core.cc
index 7ddda9b263..13b00fa7de 100644
--- a/src/meta_schedule/schedule_rule/multi_level_tiling_tensor_core.cc
+++ b/src/meta_schedule/schedule_rule/multi_level_tiling_tensor_core.cc
@@ -137,6 +137,13 @@ class MultiLevelTilingTensorCoreNode : public MultiLevelTilingNode {
// Override Apply to apply tensorization-specific analysis before applying sub-rules
Array<Schedule> Apply(const Schedule& sch, const BlockRV& block_rv) final;
+ // Inherited from ScheduleRuleNode
+ ScheduleRule Clone() const final {
+ ObjectPtr<MultiLevelTilingTensorCoreNode> n =
+ make_object<MultiLevelTilingTensorCoreNode>(*this);
+ return ScheduleRule(n);
+ }
+
/*!
* \brief Transform and tensorize with the given tensor intrin
* \param state The state of the meta schedule rule
diff --git a/src/meta_schedule/schedule_rule/multi_level_tiling_with_intrin.cc b/src/meta_schedule/schedule_rule/multi_level_tiling_with_intrin.cc
index 3a299ed041..b953d1ad4b 100644
--- a/src/meta_schedule/schedule_rule/multi_level_tiling_with_intrin.cc
+++ b/src/meta_schedule/schedule_rule/multi_level_tiling_with_intrin.cc
@@ -63,6 +63,13 @@ class MultiLevelTilingWithIntrinNode : public MultiLevelTilingNode {
return res;
}
+ // Inherited from ScheduleRuleNode
+ ScheduleRule Clone() const final {
+ ObjectPtr<MultiLevelTilingWithIntrinNode> n =
+ make_object<MultiLevelTilingWithIntrinNode>(*this);
+ return ScheduleRule(n);
+ }
+
// Override ApplySubRules to tile the inner loops according to the given tensor intrinsic, then
// tile the outerloops.
virtual std::vector<State> ApplySubRules(std::vector<State> states) {
diff --git a/src/meta_schedule/schedule_rule/parallel_vectorize_unroll.cc b/src/meta_schedule/schedule_rule/parallel_vectorize_unroll.cc
index 19758996e6..045aa85b73 100644
--- a/src/meta_schedule/schedule_rule/parallel_vectorize_unroll.cc
+++ b/src/meta_schedule/schedule_rule/parallel_vectorize_unroll.cc
@@ -79,6 +79,13 @@ class ParallelizeVectorizeUnrollNode : public ScheduleRuleNode {
return {sch};
}
+ // Inherited from ScheduleRuleNode
+ ScheduleRule Clone() const final {
+ ObjectPtr<ParallelizeVectorizeUnrollNode> n =
+ make_object<ParallelizeVectorizeUnrollNode>(*this);
+ return ScheduleRule(n);
+ }
+
public:
/*!
* \brief The maximum number of jobs to be launched per CPU core. It sets the
diff --git a/src/meta_schedule/schedule_rule/random_compute_location.cc b/src/meta_schedule/schedule_rule/random_compute_location.cc
index 65988dfd56..7796eddd44 100644
--- a/src/meta_schedule/schedule_rule/random_compute_location.cc
+++ b/src/meta_schedule/schedule_rule/random_compute_location.cc
@@ -57,6 +57,12 @@ class RandomComputeLocationNode : public ScheduleRuleNode {
return {res};
}
+ // Inherited from ScheduleRuleNode
+ ScheduleRule Clone() const final {
+ ObjectPtr<RandomComputeLocationNode> n = make_object<RandomComputeLocationNode>(*this);
+ return ScheduleRule(n);
+ }
+
private:
bool CheckConditions(const tir::Schedule sch, const tir::BlockRV& block_rv) const {
tir::StmtSRef block_sref = sch->GetSRef(block_rv);
diff --git a/src/meta_schedule/schedule_rule/schedule_rule.cc b/src/meta_schedule/schedule_rule/schedule_rule.cc
index 80f8725b0c..416b43f46d 100644
--- a/src/meta_schedule/schedule_rule/schedule_rule.cc
+++ b/src/meta_schedule/schedule_rule/schedule_rule.cc
@@ -33,13 +33,20 @@ Array<tir::Schedule> PyScheduleRuleNode::Apply(const tir::Schedule& sch,
return f_apply(sch, block);
}
+ScheduleRule PyScheduleRuleNode::Clone() const {
+ ICHECK(f_clone != nullptr) << "PyScheduleRule's Clone method not implemented!";
+ return f_clone();
+}
+
ScheduleRule ScheduleRule::PyScheduleRule(
PyScheduleRuleNode::FInitializeWithTuneContext f_initialize_with_tune_context, //
PyScheduleRuleNode::FApply f_apply, //
+ PyScheduleRuleNode::FClone f_clone, //
PyScheduleRuleNode::FAsString f_as_string) {
ObjectPtr<PyScheduleRuleNode> n = make_object<PyScheduleRuleNode>();
n->f_initialize_with_tune_context = std::move(f_initialize_with_tune_context);
n->f_apply = std::move(f_apply);
+ n->f_clone = std::move(f_clone);
n->f_as_string = std::move(f_as_string);
return ScheduleRule(n);
}
@@ -60,6 +67,8 @@ TVM_REGISTER_GLOBAL("meta_schedule.ScheduleRuleInitializeWithTuneContext")
.set_body_method<ScheduleRule>(&ScheduleRuleNode::InitializeWithTuneContext);
TVM_REGISTER_GLOBAL("meta_schedule.ScheduleRuleApply")
.set_body_method<ScheduleRule>(&ScheduleRuleNode::Apply);
+TVM_REGISTER_GLOBAL("meta_schedule.ScheduleRuleClone")
+ .set_body_method<ScheduleRule>(&ScheduleRuleNode::Clone);
TVM_REGISTER_GLOBAL("meta_schedule.ScheduleRulePyScheduleRule")
.set_body_typed(ScheduleRule::PyScheduleRule);
diff --git a/src/meta_schedule/search_strategy/evolutionary_search.cc b/src/meta_schedule/search_strategy/evolutionary_search.cc
index c5ff9008ef..5930704eb0 100644
--- a/src/meta_schedule/search_strategy/evolutionary_search.cc
+++ b/src/meta_schedule/search_strategy/evolutionary_search.cc
@@ -431,6 +431,24 @@ class EvolutionarySearchNode : public SearchStrategyNode {
ICHECK(this->state_ != nullptr);
this->state_->NotifyRunnerResults(measure_candidates, results);
}
+
+ SearchStrategy Clone() const final {
+ ObjectPtr<EvolutionarySearchNode> n = make_object<EvolutionarySearchNode>();
+ n->max_trials_per_task = this->max_trials_per_task;
+ n->num_trials_per_iter = this->num_trials_per_iter;
+ n->population_size = this->population_size;
+ n->num_empty_iters_before_early_stop = this->num_empty_iters_before_early_stop;
+ n->init_measured_ratio = this->init_measured_ratio;
+ n->init_min_unmeasured = this->init_min_unmeasured;
+ n->genetic_num_iters = this->genetic_num_iters;
+ n->genetic_mutate_prob = this->genetic_mutate_prob;
+ n->genetic_max_fail_count = this->genetic_max_fail_count;
+ n->eps_greedy = this->eps_greedy;
+ n->context_ = this->context_;
+ n->rand_state_ = this->rand_state_;
+ n->state_ = nullptr; // cleared the state
+ return SearchStrategy(n);
+ }
};
std::vector<Schedule> EvolutionarySearchNode::State::PickBestFromDatabase(int num) {
diff --git a/src/meta_schedule/search_strategy/replay_func.cc b/src/meta_schedule/search_strategy/replay_func.cc
index 4574c1c817..6914ab2f0f 100644
--- a/src/meta_schedule/search_strategy/replay_func.cc
+++ b/src/meta_schedule/search_strategy/replay_func.cc
@@ -100,6 +100,16 @@ class ReplayFuncNode : public SearchStrategyNode {
ICHECK(this->state_ != nullptr);
this->state_->NotifyRunnerResults(results);
}
+
+ SearchStrategy Clone() const final {
+ ObjectPtr<ReplayFuncNode> n = make_object<ReplayFuncNode>();
+ n->num_trials_per_iter = this->num_trials_per_iter;
+ n->max_trials_per_task = this->max_trials_per_task;
+ n->context_ = this->context_;
+ n->rand_state_ = this->rand_state_;
+ n->state_ = nullptr; // cleared the state
+ return SearchStrategy(n);
+ }
};
inline Optional<Array<MeasureCandidate>> ReplayFuncNode::State::GenerateMeasureCandidates() {
diff --git a/src/meta_schedule/search_strategy/replay_trace.cc b/src/meta_schedule/search_strategy/replay_trace.cc
index 64fc683943..bd553bf037 100644
--- a/src/meta_schedule/search_strategy/replay_trace.cc
+++ b/src/meta_schedule/search_strategy/replay_trace.cc
@@ -118,6 +118,17 @@ class ReplayTraceNode : public SearchStrategyNode {
ICHECK(this->state_ != nullptr);
this->state_->NotifyRunnerResults(results);
}
+
+ SearchStrategy Clone() const final {  // copies configuration fields; state_ is reset, not shared
+ ObjectPtr<ReplayTraceNode> cloned = make_object<ReplayTraceNode>();
+ cloned->max_trials_per_task = max_trials_per_task;
+ cloned->num_trials_per_iter = num_trials_per_iter;
+ cloned->max_fail_count = max_fail_count;
+ cloned->rand_state_ = rand_state_;  // the clone starts from the same random state value
+ cloned->context_ = context_;  // copied as-is; a later InitializeWithTuneContext may rebind it
+ cloned->state_ = nullptr;  // no in-flight search state is carried over
+ return SearchStrategy(cloned);
+ }
};
inline Optional<Array<MeasureCandidate>> ReplayTraceNode::State::GenerateMeasureCandidates() {
diff --git a/src/meta_schedule/search_strategy/search_strategy.cc b/src/meta_schedule/search_strategy/search_strategy.cc
index 5865fc8422..81c7fda315 100644
--- a/src/meta_schedule/search_strategy/search_strategy.cc
+++ b/src/meta_schedule/search_strategy/search_strategy.cc
@@ -59,18 +59,25 @@ void PySearchStrategyNode::NotifyRunnerResults(const Array<MeasureCandidate>& me
f_notify_runner_results(measure_candidates, results);
}
+SearchStrategy PySearchStrategyNode::Clone() const {  // delegates to the Python-provided callback
+ ICHECK(this->f_clone != nullptr) << "PySearchStrategy's Clone method not implemented!";
+ return this->f_clone();  // whatever the callback returns is handed back as the clone
+}
+
SearchStrategy SearchStrategy::PySearchStrategy(
PySearchStrategyNode::FInitializeWithTuneContext f_initialize_with_tune_context, //
PySearchStrategyNode::FPreTuning f_pre_tuning, //
PySearchStrategyNode::FPostTuning f_post_tuning, //
PySearchStrategyNode::FGenerateMeasureCandidates f_generate_measure_candidates, //
- PySearchStrategyNode::FNotifyRunnerResults f_notify_runner_results) {
+ PySearchStrategyNode::FNotifyRunnerResults f_notify_runner_results, //
+ PySearchStrategyNode::FClone f_clone) {
ObjectPtr<PySearchStrategyNode> n = make_object<PySearchStrategyNode>();
n->f_initialize_with_tune_context = f_initialize_with_tune_context;
n->f_pre_tuning = f_pre_tuning;
n->f_post_tuning = f_post_tuning;
n->f_generate_measure_candidates = f_generate_measure_candidates;
n->f_notify_runner_results = f_notify_runner_results;
+ n->f_clone = f_clone;  // may be null; PySearchStrategyNode::Clone ICHECKs it before calling
return SearchStrategy(n);
}
@@ -94,6 +101,8 @@ TVM_REGISTER_GLOBAL("meta_schedule.SearchStrategyGenerateMeasureCandidates")
.set_body_method<SearchStrategy>(&SearchStrategyNode::GenerateMeasureCandidates);
TVM_REGISTER_GLOBAL("meta_schedule.SearchStrategyNotifyRunnerResults")
.set_body_method<SearchStrategy>(&SearchStrategyNode::NotifyRunnerResults);
+TVM_REGISTER_GLOBAL("meta_schedule.SearchStrategyClone")  // global-registry entry for SearchStrategyNode::Clone
+ .set_body_method<SearchStrategy>(&SearchStrategyNode::Clone);
} // namespace meta_schedule
} // namespace tvm
diff --git a/src/meta_schedule/space_generator/post_order_apply.cc b/src/meta_schedule/space_generator/post_order_apply.cc
index 9be89e2d9c..991e4fa080 100644
--- a/src/meta_schedule/space_generator/post_order_apply.cc
+++ b/src/meta_schedule/space_generator/post_order_apply.cc
@@ -188,6 +188,15 @@ class PostOrderApplyNode : public SpaceGeneratorNode {
}
return result;
}
+
+ SpaceGenerator Clone() const final {
+ // Copy-construct every field first, then replace the shared rule array with per-rule clones.
+ ObjectPtr<PostOrderApplyNode> cloned = make_object<PostOrderApplyNode>(*this);
+ Array<ScheduleRule> cloned_rules;
+ for (const ScheduleRule& rule : sch_rules_) cloned_rules.push_back(rule->Clone());
+ cloned->sch_rules_ = cloned_rules;
+ return SpaceGenerator(cloned);
+ }
static constexpr const char* _type_key = "meta_schedule.PostOrderApply";
TVM_DECLARE_FINAL_OBJECT_INFO(PostOrderApplyNode, SpaceGeneratorNode);
};
diff --git a/src/meta_schedule/space_generator/schedule_fn.cc b/src/meta_schedule/space_generator/schedule_fn.cc
index 70559fbcf1..adea139b1c 100644
--- a/src/meta_schedule/space_generator/schedule_fn.cc
+++ b/src/meta_schedule/space_generator/schedule_fn.cc
@@ -72,6 +72,11 @@ class ScheduleFnNode : public SpaceGeneratorNode {
throw;
}
+ SpaceGenerator Clone() const final {
+ // Copy-constructs this node; no member is cloned individually.
+ return SpaceGenerator(make_object<ScheduleFnNode>(*this));
+ }
+
static constexpr const char* _type_key = "meta_schedule.ScheduleFn";
TVM_DECLARE_FINAL_OBJECT_INFO(ScheduleFnNode, SpaceGeneratorNode);
};
diff --git a/src/meta_schedule/space_generator/space_generator.cc b/src/meta_schedule/space_generator/space_generator.cc
index 5c5ab6ebba..6fc31ed896 100644
--- a/src/meta_schedule/space_generator/space_generator.cc
+++ b/src/meta_schedule/space_generator/space_generator.cc
@@ -33,12 +33,18 @@ Array<tir::Schedule> PySpaceGeneratorNode::GenerateDesignSpace(const IRModule& m
return f_generate_design_space(mod);
}
+SpaceGenerator PySpaceGeneratorNode::Clone() const {  // delegates to the Python-provided callback
+ ICHECK(this->f_clone != nullptr) << "PySpaceGenerator's Clone method not implemented!";
+ return this->f_clone();  // whatever the callback returns is handed back as the clone
+}
+
SpaceGenerator SpaceGenerator::PySpaceGenerator(
- PySpaceGeneratorNode::FInitializeWithTuneContext f_initialize_with_tune_context,
- PySpaceGeneratorNode::FGenerateDesignSpace f_generate_design_space) {
+ FInitializeWithTuneContext f_initialize_with_tune_context,  // NOTE(review): unqualified, unlike PySearchStrategy's params; presumably aliases declared in SpaceGenerator — confirm
+ FGenerateDesignSpace f_generate_design_space, FClone f_clone) {
ObjectPtr<PySpaceGeneratorNode> n = make_object<PySpaceGeneratorNode>();
n->f_initialize_with_tune_context = std::move(f_initialize_with_tune_context);
n->f_generate_design_space = std::move(f_generate_design_space);
+ n->f_clone = std::move(f_clone);  // may be null; PySpaceGeneratorNode::Clone ICHECKs it before calling
return SpaceGenerator(n);
}
@@ -51,6 +57,8 @@ TVM_REGISTER_GLOBAL("meta_schedule.SpaceGeneratorGenerateDesignSpace")
.set_body_method<SpaceGenerator>(&SpaceGeneratorNode::GenerateDesignSpace);
TVM_REGISTER_GLOBAL("meta_schedule.SpaceGeneratorPySpaceGenerator")
.set_body_typed(SpaceGenerator::PySpaceGenerator);
+TVM_REGISTER_GLOBAL("meta_schedule.SpaceGeneratorClone")  // global-registry entry for SpaceGeneratorNode::Clone
+ .set_body_method<SpaceGenerator>(&SpaceGeneratorNode::Clone);
} // namespace meta_schedule
} // namespace tvm
diff --git a/src/meta_schedule/space_generator/space_generator_union.cc b/src/meta_schedule/space_generator/space_generator_union.cc
index 6ea61824f9..771d0c187f 100644
--- a/src/meta_schedule/space_generator/space_generator_union.cc
+++ b/src/meta_schedule/space_generator/space_generator_union.cc
@@ -47,6 +47,15 @@ class SpaceGeneratorUnionNode : public SpaceGeneratorNode {
return design_spaces;
}
+ SpaceGenerator Clone() const final {
+ // Copy-construct every field first, then replace the shared array with per-generator clones.
+ ObjectPtr<SpaceGeneratorUnionNode> cloned = make_object<SpaceGeneratorUnionNode>(*this);
+ Array<SpaceGenerator> cloned_generators;
+ for (const SpaceGenerator& gen : space_generators) cloned_generators.push_back(gen->Clone());
+ cloned->space_generators = cloned_generators;
+ return SpaceGenerator(cloned);
+ }
+
static constexpr const char* _type_key = "meta_schedule.SpaceGeneratorUnion";
TVM_DECLARE_FINAL_OBJECT_INFO(SpaceGeneratorUnionNode, SpaceGeneratorNode);
};
diff --git a/src/meta_schedule/tune_context.cc b/src/meta_schedule/tune_context.cc
index 57b2344c6f..3650c0374d 100644
--- a/src/meta_schedule/tune_context.cc
+++ b/src/meta_schedule/tune_context.cc
@@ -52,6 +52,32 @@ TuneContext::TuneContext(Optional<IRModule> mod,
data_ = std::move(n);
}
+TuneContext TuneContextNode::Clone() const {
+ // Start from a field-wise copy, then give the copy its own instance of every task-level component.
+ ObjectPtr<TuneContextNode> cloned = make_object<TuneContextNode>(*this);
+ if (sch_rules.defined()) {
+ Array<ScheduleRule> rules;
+ for (const ScheduleRule& rule : sch_rules) rules.push_back(rule->Clone());
+ cloned->sch_rules = rules;
+ }
+ if (postprocs.defined()) {
+ Array<Postproc> procs;
+ for (const Postproc& proc : postprocs) procs.push_back(proc->Clone());
+ cloned->postprocs = procs;
+ }
+ if (mutator_probs.defined()) {
+ Map<Mutator, FloatImm> probs;  // mutators are cloned; their probability values are reused
+ for (const auto& kv : mutator_probs) probs.Set(kv.first->Clone(), kv.second);
+ cloned->mutator_probs = probs;
+ }
+ // The space generator and search strategy are optional; clone them only when present.
+ if (space_generator.defined()) cloned->space_generator = space_generator.value()->Clone();
+ if (search_strategy.defined()) cloned->search_strategy = search_strategy.value()->Clone();
+ // Re-initialize the cloned components against this new context (see TuneContextNode::Initialize).
+ cloned->Initialize();
+ return TuneContext(cloned);
+}
+
void TuneContextNode::Initialize() {
if (this->space_generator.defined()) {
this->space_generator.value()->InitializeWithTuneContext(GetRef<TuneContext>(this));