You are viewing a plain-text version of this content. The canonical link to it is not included in this copy.
Posted to commits@tvm.apache.org by ju...@apache.org on 2022/01/02 01:23:03 UTC
[tvm] branch main updated: [M3c][MetaScheduler] Update TuneContext, TaskScheduler & Search Strategy Design (#9789)
This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 1c7d36f [M3c][MetaScheduler] Update TuneContext, TaskScheduler & Search Strategy Design (#9789)
1c7d36f is described below
commit 1c7d36ff271f64e460487900e52ff6d6e5f5d451
Author: Xiyou Zhou <xi...@octoml.ai>
AuthorDate: Sat Jan 1 17:22:28 2022 -0800
[M3c][MetaScheduler] Update TuneContext, TaskScheduler & Search Strategy Design (#9789)
* Modify TuneContext, TaskScheduler & SearchStrategy functions.
Co-authored-by: Junru Shao <ju...@gmail.com>
Co-authored-by: Bohan Hou <32...@users.noreply.github.com>
Co-authored-by: Ruihang Lai <la...@qq.com>
Co-authored-by: Hongyi Jin <32...@qq.com>
Co-authored-by: Wuwei Lin <wu...@apache.org>
Co-authored-by: Siyuan Feng <Hz...@sjtu.edu.cn>
* Minor fix.
Fix mypy.
Fix mypy.
* Retrigger CI.
* Minor fixes.
Co-authored-by: Junru Shao <ju...@gmail.com>
Co-authored-by: Bohan Hou <32...@users.noreply.github.com>
Co-authored-by: Ruihang Lai <la...@qq.com>
Co-authored-by: Hongyi Jin <32...@qq.com>
Co-authored-by: Wuwei Lin <wu...@apache.org>
Co-authored-by: Siyuan Feng <Hz...@sjtu.edu.cn>
---
include/tvm/meta_schedule/cost_model.h | 20 +--
include/tvm/meta_schedule/feature_extractor.h | 12 +-
include/tvm/meta_schedule/mutator.h | 146 ++++++++++++++++++
include/tvm/meta_schedule/postproc.h | 167 +++++++++++++++++++++
include/tvm/meta_schedule/search_strategy.h | 46 +++++-
include/tvm/meta_schedule/space_generator.h | 8 +-
include/tvm/meta_schedule/task_scheduler.h | 35 ++++-
include/tvm/meta_schedule/tune_context.h | 29 +++-
include/tvm/support/random_engine.h | 8 +
python/tvm/meta_schedule/cost_model/cost_model.py | 22 ++-
.../tvm/meta_schedule/cost_model/random_model.py | 8 +-
.../feature_extractor/feature_extractor.py | 10 +-
.../feature_extractor/random_feature_extractor.py | 2 +-
python/tvm/meta_schedule/mutator/__init__.py | 22 +++
python/tvm/meta_schedule/mutator/mutator.py | 88 +++++++++++
python/tvm/meta_schedule/postproc/__init__.py | 18 +++
python/tvm/meta_schedule/postproc/postproc.py | 90 +++++++++++
.../meta_schedule/schedule_rule/schedule_rule.py | 10 +-
.../meta_schedule/search_strategy/replay_trace.py | 13 +-
.../search_strategy/search_strategy.py | 41 +++--
.../meta_schedule/space_generator/schedule_fn.py | 4 +-
.../space_generator/space_generator.py | 10 +-
.../meta_schedule/task_scheduler/round_robin.py | 26 +++-
.../meta_schedule/task_scheduler/task_scheduler.py | 20 ++-
python/tvm/meta_schedule/tune_context.py | 27 +++-
src/meta_schedule/cost_model/cost_model.cc | 4 +-
src/meta_schedule/search_strategy/replay_trace.cc | 63 +++++---
src/meta_schedule/task_scheduler/round_robin.cc | 15 +-
src/meta_schedule/task_scheduler/task_scheduler.cc | 126 ++++++++--------
src/meta_schedule/tune_context.cc | 43 ++++--
src/meta_schedule/utils.h | 83 +++++++++-
src/tir/schedule/concrete_schedule.cc | 2 +-
src/tir/schedule/traced_schedule.cc | 2 +-
.../unittest/test_meta_schedule_cost_model.py | 12 +-
.../test_meta_schedule_feature_extractor.py | 4 +-
.../test_meta_schedule_post_order_apply.py | 10 +-
.../unittest/test_meta_schedule_search_strategy.py | 73 +++++----
.../unittest/test_meta_schedule_task_scheduler.py | 77 +++++++---
38 files changed, 1136 insertions(+), 260 deletions(-)
diff --git a/include/tvm/meta_schedule/cost_model.h b/include/tvm/meta_schedule/cost_model.h
index b05dc3c..6fadc2f 100644
--- a/include/tvm/meta_schedule/cost_model.h
+++ b/include/tvm/meta_schedule/cost_model.h
@@ -51,20 +51,20 @@ class CostModelNode : public runtime::Object {
/*!
* \brief Update the cost model given running results.
- * \param tune_context The tuning context.
+ * \param context The tuning context.
* \param candidates The measure candidates.
* \param results The running results of the measure candidates.
*/
- virtual void Update(const TuneContext& tune_context, const Array<MeasureCandidate>& candidates,
+ virtual void Update(const TuneContext& context, const Array<MeasureCandidate>& candidates,
const Array<RunnerResult>& results) = 0;
/*!
* \brief Predict the normalized score (the larger the better) of given measure candidates.
- * \param tune_context The tuning context.
+ * \param context The tuning context.
* \param candidates The measure candidates.
* \return The predicted normalized score.
*/
- virtual std::vector<double> Predict(const TuneContext& tune_context,
+ virtual std::vector<double> Predict(const TuneContext& context,
const Array<MeasureCandidate>& candidates) = 0;
static constexpr const char* _type_key = "meta_schedule.CostModel";
@@ -86,7 +86,7 @@ class PyCostModelNode : public CostModelNode {
using FSave = runtime::TypedPackedFunc<void(String)>;
/*!
* \brief Update the cost model given running results.
- * \param tune_context The tuning context.
+ * \param context The tuning context.
* \param candidates The measure candidates.
* \param results The running results of the measure candidates.
* \return Whether cost model was updated successfully.
@@ -95,7 +95,7 @@ class PyCostModelNode : public CostModelNode {
const Array<RunnerResult>&)>;
/*!
* \brief Predict the running results of given measure candidates.
- * \param tune_context The tuning context.
+ * \param context The tuning context.
* \param candidates The measure candidates.
* \param p_addr The address to save the the estimated running results.
*/
@@ -135,17 +135,17 @@ class PyCostModelNode : public CostModelNode {
ICHECK(f_save != nullptr) << "PyCostModel's Save method not implemented!";
f_save(path);
}
- void Update(const TuneContext& tune_context, const Array<MeasureCandidate>& candidates,
+ void Update(const TuneContext& context, const Array<MeasureCandidate>& candidates,
const Array<RunnerResult>& results) {
ICHECK(f_update != nullptr) << "PyCostModel's Update method not implemented!";
- f_update(tune_context, candidates, results);
+ f_update(context, candidates, results);
}
- std::vector<double> Predict(const TuneContext& tune_context,
+ std::vector<double> Predict(const TuneContext& context,
const Array<MeasureCandidate>& candidates) {
ICHECK(f_predict != nullptr) << "PyCostModel's Predict method not implemented!";
std::vector<double> result(candidates.size(), 0.0);
- f_predict(tune_context, candidates, result.data());
+ f_predict(context, candidates, result.data());
return result;
}
diff --git a/include/tvm/meta_schedule/feature_extractor.h b/include/tvm/meta_schedule/feature_extractor.h
index ee5d94c..c2ca2be 100644
--- a/include/tvm/meta_schedule/feature_extractor.h
+++ b/include/tvm/meta_schedule/feature_extractor.h
@@ -37,11 +37,11 @@ class FeatureExtractorNode : public runtime::Object {
/*!
* \brief Extract features from the given measure candidate.
- * \param tune_context The tuning context for feature extraction.
+ * \param context The tuning context for feature extraction.
* \param candidates The measure candidates to extract features from.
* \return The feature ndarray extracted.
*/
- virtual Array<tvm::runtime::NDArray> ExtractFrom(const TuneContext& tune_context,
+ virtual Array<tvm::runtime::NDArray> ExtractFrom(const TuneContext& context,
const Array<MeasureCandidate>& candidates) = 0;
static constexpr const char* _type_key = "meta_schedule.FeatureExtractor";
@@ -53,12 +53,12 @@ class PyFeatureExtractorNode : public FeatureExtractorNode {
public:
/*!
* \brief Extract features from the given measure candidate.
- * \param tune_context The tuning context for feature extraction.
+ * \param context The tuning context for feature extraction.
* \param candidates The measure candidates to extract features from.
* \return The feature ndarray extracted.
*/
using FExtractFrom = runtime::TypedPackedFunc<Array<tvm::runtime::NDArray>(
- const TuneContext& tune_context, const Array<MeasureCandidate>& candidates)>;
+ const TuneContext& context, const Array<MeasureCandidate>& candidates)>;
/*!
* \brief Get the feature extractor as string with name.
* \return The string of the feature extractor.
@@ -75,10 +75,10 @@ class PyFeatureExtractorNode : public FeatureExtractorNode {
// `f_as_string` is not visited
}
- Array<tvm::runtime::NDArray> ExtractFrom(const TuneContext& tune_context,
+ Array<tvm::runtime::NDArray> ExtractFrom(const TuneContext& context,
const Array<MeasureCandidate>& candidates) {
ICHECK(f_extract_from != nullptr) << "PyFeatureExtractor's ExtractFrom method not implemented!";
- return f_extract_from(tune_context, candidates);
+ return f_extract_from(context, candidates);
}
static constexpr const char* _type_key = "meta_schedule.PyFeatureExtractor";
diff --git a/include/tvm/meta_schedule/mutator.h b/include/tvm/meta_schedule/mutator.h
new file mode 100644
index 0000000..e3fa847
--- /dev/null
+++ b/include/tvm/meta_schedule/mutator.h
@@ -0,0 +1,146 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#ifndef TVM_META_SCHEDULE_MUTATOR_H_
+#define TVM_META_SCHEDULE_MUTATOR_H_
+
+#include <tvm/tir/schedule/schedule.h>
+
+namespace tvm {
+namespace meta_schedule {
+
+class TuneContext;
+
+/*! \brief Base class of mutators, which mutate a trace to explore the design space. */
+class MutatorNode : public runtime::Object {
+ public:
+ /*! \brief Virtual destructor. */
+ virtual ~MutatorNode() = default;
+
+ void VisitAttrs(tvm::AttrVisitor* v) {}
+
+ /*!
+ * \brief Initialize the mutator with tuning context.
+ * \param context The tuning context for initialization.
+ * \note This method is supposed to be called only once before every other method.
+ */
+ virtual void InitializeWithTuneContext(const TuneContext& context) = 0;
+
+ /*!
+ * \brief Apply the mutator to the given trace.
+ * \param trace The given trace for mutation.
+ * \param rand_state The random state for mutation (mutable pointer; presumably advanced).
+ * \return NullOpt if the mutation failed, otherwise the mutated trace.
+ */
+ virtual Optional<tir::Trace> Apply(const tir::Trace& trace,
+ support::LinearCongruentialEngine::TRandState* rand_state) = 0;
+
+ static constexpr const char* _type_key = "meta_schedule.Mutator";
+ TVM_DECLARE_BASE_OBJECT_INFO(MutatorNode, Object);
+};
+
+/*! \brief The mutator with customized methods on the python-side. */
+class PyMutatorNode : public MutatorNode {
+ public:
+ /*!
+ * \brief The function type of `InitializeWithTuneContext` method.
+ * \param context The tuning context for initialization.
+ */
+ using FInitializeWithTuneContext = runtime::TypedPackedFunc<void(const TuneContext&)>;
+ /*!
+ * \brief The function type of `Apply` method.
+ * \param trace The given trace for mutation; the second argument is a copy of the random state.
+ * \return NullOpt if the mutation failed, otherwise the mutated trace.
+ */
+ using FApply = runtime::TypedPackedFunc<Optional<tir::Trace>(
+ const tir::Trace&, support::LinearCongruentialEngine::TRandState rand_state)>;
+ /*!
+ * \brief The function type of `AsString` method.
+ * \return The string of the mutator.
+ */
+ using FAsString = runtime::TypedPackedFunc<String()>;
+
+ /*! \brief The packed function to the `InitializeWithTuneContext` function. */
+ FInitializeWithTuneContext f_initialize_with_tune_context;
+ /*! \brief The packed function to the `Apply` function. */
+ FApply f_apply;
+ /*! \brief The packed function to the `AsString` function. */
+ FAsString f_as_string;
+
+ void VisitAttrs(tvm::AttrVisitor* v) {
+ // `f_initialize_with_tune_context` is not visited
+ // `f_apply` is not visited
+ // `f_as_string` is not visited
+ }
+
+ void InitializeWithTuneContext(const TuneContext& context) final {
+ ICHECK(f_initialize_with_tune_context != nullptr)
+ << "PyMutator's InitializeWithTuneContext method not implemented!";
+ this->f_initialize_with_tune_context(context);
+ }
+
+ Optional<tir::Trace> Apply(const tir::Trace& trace,
+ support::LinearCongruentialEngine::TRandState* rand_state) final {
+ ICHECK(f_apply != nullptr) << "PyMutator's Apply method not implemented!";
+ return this->f_apply(trace, *rand_state); // NOTE(review): copy passed; *rand_state unchanged
+ }
+
+ static constexpr const char* _type_key = "meta_schedule.PyMutator";
+ TVM_DECLARE_FINAL_OBJECT_INFO(PyMutatorNode, MutatorNode);
+};
+
+/*!
+ * \brief Managed reference to MutatorNode.
+ * \sa MutatorNode
+ */
+class Mutator : public runtime::ObjectRef {
+ public:
+ /*! \brief Create a Mutator that mutates the tile size. */
+ TVM_DLL static Mutator MutateTileSize();
+ /*!
+ * \brief Create a Mutator that mutates the parallel extent.
+ * \param max_jobs_per_core The maximum number of parallel jobs per core.
+ * \return The created mutator.
+ */
+ TVM_DLL static Mutator MutateParallel(int64_t max_jobs_per_core);
+ /*! \brief Create a Mutator that mutates the auto-unroll step. */
+ TVM_DLL static Mutator MutateUnroll();
+ /*!
+ * \brief Create a Mutator that mutates the outcome of SampleComputeLocation.
+ * \return The mutator created.
+ */
+ TVM_DLL static Mutator MutateComputeLocation();
+ /*!
+ * \brief Create a mutator with customized methods on the python-side.
+ * \param f_initialize_with_tune_context The packed function of `InitializeWithTuneContext`.
+ * \param f_apply The packed function of `Apply`.
+ * \param f_as_string The packed function of `AsString`.
+ * \return The mutator created.
+ */
+ TVM_DLL static Mutator PyMutator(
+ PyMutatorNode::FInitializeWithTuneContext f_initialize_with_tune_context, //
+ PyMutatorNode::FApply f_apply, //
+ PyMutatorNode::FAsString f_as_string);
+ TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(Mutator, ObjectRef, MutatorNode);
+};
+
+} // namespace meta_schedule
+} // namespace tvm
+
+#endif // TVM_META_SCHEDULE_MUTATOR_H_
diff --git a/include/tvm/meta_schedule/postproc.h b/include/tvm/meta_schedule/postproc.h
new file mode 100644
index 0000000..93e8be0
--- /dev/null
+++ b/include/tvm/meta_schedule/postproc.h
@@ -0,0 +1,167 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#ifndef TVM_META_SCHEDULE_POSTPROC_H_
+#define TVM_META_SCHEDULE_POSTPROC_H_
+
+#include <tvm/tir/schedule/schedule.h>
+
+namespace tvm {
+namespace meta_schedule {
+
+class TuneContext;
+
+/*!
+ * \brief Base class of postprocessors, which are applied to a schedule via `Apply`.
+ */
+class PostprocNode : public runtime::Object {
+ public:
+ /*! \brief Virtual destructor. */
+ virtual ~PostprocNode() = default;
+
+ void VisitAttrs(tvm::AttrVisitor* v) {}
+
+ /*!
+ * \brief Initialize the postprocessor with tuning context.
+ * \param context The tuning context for initialization.
+ * \note This method is supposed to be called only once before every other method.
+ */
+ virtual void InitializeWithTuneContext(const TuneContext& context) = 0;
+
+ /*!
+ * \brief Apply the postprocessor to the given schedule.
+ * \param sch The schedule to be postprocessed.
+ * \return Whether the postprocessor was successfully applied.
+ */
+ virtual bool Apply(const tir::Schedule& sch) = 0;
+
+ static constexpr const char* _type_key = "meta_schedule.Postproc";
+ TVM_DECLARE_BASE_OBJECT_INFO(PostprocNode, Object);
+};
+
+/*! \brief The postprocessor with customized methods on the python-side. */
+class PyPostprocNode : public PostprocNode {
+ public:
+ /*!
+ * \brief The function type of `InitializeWithTuneContext` method.
+ * \param context The tuning context for initialization.
+ */
+ using FInitializeWithTuneContext = runtime::TypedPackedFunc<void(const TuneContext&)>;
+ /*!
+ * \brief The function type of `Apply` method.
+ * \param sch The schedule to be postprocessed.
+ * \return Whether the postprocessor was successfully applied.
+ */
+ using FApply = runtime::TypedPackedFunc<bool(const tir::Schedule&)>;
+ /*!
+ * \brief The function type of `AsString` method.
+ * \return The string representation of the postprocessor.
+ */
+ using FAsString = runtime::TypedPackedFunc<String()>;
+
+ /*! \brief The packed function to the `InitializeWithTuneContext` function. */
+ FInitializeWithTuneContext f_initialize_with_tune_context;
+ /*! \brief The packed function to the `Apply` function. */
+ FApply f_apply;
+ /*! \brief The packed function to the `AsString` function. */
+ FAsString f_as_string;
+
+ void VisitAttrs(tvm::AttrVisitor* v) {
+ // `f_initialize_with_tune_context` is not visited
+ // `f_apply` is not visited
+ // `f_as_string` is not visited
+ }
+
+ void InitializeWithTuneContext(const TuneContext& context) final {
+ ICHECK(f_initialize_with_tune_context != nullptr)
+ << "PyPostproc's InitializeWithTuneContext method not implemented!";
+ this->f_initialize_with_tune_context(context);
+ }
+
+ bool Apply(const tir::Schedule& sch) final {
+ ICHECK(f_apply != nullptr) << "PyPostproc's Apply method not implemented!";
+ return this->f_apply(sch);
+ }
+
+ static constexpr const char* _type_key = "meta_schedule.PyPostproc";
+ TVM_DECLARE_FINAL_OBJECT_INFO(PyPostprocNode, PostprocNode);
+};
+
+/*!
+ * \brief Managed reference to PostprocNode.
+ * \sa PostprocNode
+ */
+class Postproc : public runtime::ObjectRef {
+ public:
+ /*!
+ * \brief Create a postprocessor with customized methods on the python-side.
+ * \param f_initialize_with_tune_context The packed function of `InitializeWithTuneContext`.
+ * \param f_apply The packed function of `Apply`.
+ * \param f_as_string The packed function of `AsString`.
+ * \return The postprocessor created.
+ */
+ TVM_DLL static Postproc PyPostproc(
+ PyPostprocNode::FInitializeWithTuneContext f_initialize_with_tune_context, //
+ PyPostprocNode::FApply f_apply, //
+ PyPostprocNode::FAsString f_as_string);
+ /*!
+ * \brief Create a postprocessor that checks if all loops are static.
+ * \return The postprocessor created.
+ */
+ TVM_DLL static Postproc DisallowDynamicLoop();
+ /*!
+ * \brief Create a postprocessor that rewrites the cooperative fetch annotation to
+ * actual vectorized cooperative fetching in loop bindings.
+ * \return The postprocessor created.
+ */
+ TVM_DLL static Postproc RewriteCooperativeFetch();
+ /*!
+ * \brief Create a postprocessor that applies parallelization, vectorization and auto unrolling
+ * according to the annotation of each block.
+ * \return The postprocessor created.
+ */
+ TVM_DLL static Postproc RewriteParallelVectorizeUnroll();
+ /*!
+ * \brief Create a postprocessor that rewrites reduction blocks by moving the init block out.
+ * \return The postprocessor created.
+ */
+ TVM_DLL static Postproc RewriteReductionBlock();
+ /*!
+ * \brief Create a postprocessor that adds thread binding to unbound blocks.
+ * \return The postprocessor created.
+ */
+ TVM_DLL static Postproc RewriteUnboundBlock();
+ /*!
+ * \brief Create a postprocessor that tensorizes Tensor Core related components.
+ * \return The postprocessor created.
+ */
+ TVM_DLL static Postproc RewriteTensorCore();
+
+ /*!
+ * \brief Create a postprocessor that verifies if the GPU code is correct.
+ * \return The postprocessor created.
+ */
+ TVM_DLL static Postproc VerifyGPUCode();
+ TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(Postproc, ObjectRef, PostprocNode);
+};
+
+} // namespace meta_schedule
+} // namespace tvm
+
+#endif // TVM_META_SCHEDULE_POSTPROC_H_
diff --git a/include/tvm/meta_schedule/search_strategy.h b/include/tvm/meta_schedule/search_strategy.h
index e1c68c8..e8d2cbd 100644
--- a/include/tvm/meta_schedule/search_strategy.h
+++ b/include/tvm/meta_schedule/search_strategy.h
@@ -28,6 +28,8 @@ namespace meta_schedule {
// Forward declaration
class TuneContext;
+class CostModel;
+class Database;
/*! \brief The schedule (with input shapes) to be measured. */
class MeasureCandidateNode : public runtime::Object {
@@ -133,9 +135,13 @@ class SearchStrategyNode : public runtime::Object {
/*!
* \brief Update the search strategy with measurement results.
+ * \param context The tuning context.
+ * \param measure_candidates The candidates to be measured.
* \param results The measurement results from the runner.
*/
- virtual void NotifyRunnerResults(const Array<RunnerResult>& results) = 0;
+ virtual void NotifyRunnerResults(const TuneContext& context,
+ const Array<MeasureCandidate>& measure_candidates,
+ const Array<RunnerResult>& results) = 0;
static constexpr const char* _type_key = "meta_schedule.SearchStrategy";
TVM_DECLARE_BASE_OBJECT_INFO(SearchStrategyNode, Object);
@@ -165,7 +171,8 @@ class PySearchStrategyNode : public SearchStrategyNode {
* \brief The function type of `NotifyRunnerResults` method.
* \param results The measurement results from the runner.
*/
- using FNotifyRunnerResults = runtime::TypedPackedFunc<void(const Array<RunnerResult>&)>;
+ using FNotifyRunnerResults = runtime::TypedPackedFunc<void(
+ const TuneContext&, const Array<MeasureCandidate>&, const Array<RunnerResult>&)>;
/*! \brief The packed function to the `InitializeWithTuneContext` method. */
FInitializeWithTuneContext f_initialize_with_tune_context;
@@ -208,10 +215,12 @@ class PySearchStrategyNode : public SearchStrategyNode {
return this->f_generate_measure_candidates();
}
- void NotifyRunnerResults(const Array<RunnerResult>& results) final {
+ void NotifyRunnerResults(const TuneContext& context,
+ const Array<MeasureCandidate>& measure_candidates,
+ const Array<RunnerResult>& results) final {
ICHECK(f_notify_runner_results != nullptr)
<< "PySearchStrategy's NotifyRunnerResults method not implemented!";
- this->f_notify_runner_results(results);
+ this->f_notify_runner_results(context, measure_candidates, results);
}
static constexpr const char* _type_key = "meta_schedule.PySearchStrategy";
@@ -247,6 +256,35 @@ class SearchStrategy : public runtime::ObjectRef {
*/
TVM_DLL static SearchStrategy ReplayTrace(int num_trials_per_iter, int num_trials_total);
+ /*!
+ * \brief Constructor of the replay-func search strategy.
+ * \param num_trials_per_iter The number of trials per iteration, i.e., the batch size.
+ * \param num_trials_total The total number of trials for func replaying.
+ */
+ TVM_DLL static SearchStrategy ReplayFunc(int num_trials_per_iter, int num_trials_total);
+
+ /*!
+ * \brief Constructor of the evolutionary search strategy.
+ * \param num_trials_per_iter The number of trials per iteration, i.e., the batch size.
+ * \param num_trials_total The total number of trials for evolutionary search.
+ * \param population_size The size of the initial sample population.
+ * \param init_measured_ratio The ratio of measured samples in the initial population.
+ * \param init_max_fail_count The maximum number of failed trace replays allowed.
+ * \param genetic_num_iters The number of iterations to run the genetic algorithm.
+ * \param genetic_mutate_prob The probability of mutation.
+ * \param genetic_max_fail_count The maximum number of attempts to evolve the given trace.
+ * \param eps_greedy The ratio of samples selected greedily by their predicted score.
+ */
+ TVM_DLL static SearchStrategy EvolutionarySearch(int num_trials_per_iter, //
+ int num_trials_total, //
+ int population_size, //
+ double init_measured_ratio, //
+ int init_max_fail_count, //
+ int genetic_num_iters, //
+ double genetic_mutate_prob, //
+ int genetic_max_fail_count, //
+ double eps_greedy);
+
TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(SearchStrategy, ObjectRef, SearchStrategyNode);
};
diff --git a/include/tvm/meta_schedule/space_generator.h b/include/tvm/meta_schedule/space_generator.h
index 7aff683..3611870 100644
--- a/include/tvm/meta_schedule/space_generator.h
+++ b/include/tvm/meta_schedule/space_generator.h
@@ -139,13 +139,13 @@ class SpaceGenerator : public ObjectRef {
public:
/*!
* \brief Create a design space generator with customized methods on the python-side.
- * \param initialize_with_tune_context_func The packed function of `InitializeWithTuneContext`.
- * \param generate_design_space_func The packed function of `GenerateDesignSpace`.
+ * \param f_initialize_with_tune_context The packed function of `InitializeWithTuneContext`.
+ * \param f_generate_design_space The packed function of `GenerateDesignSpace`.
* \return The design space generator created.
*/
TVM_DLL static SpaceGenerator PySpaceGenerator(
- PySpaceGeneratorNode::FInitializeWithTuneContext initialize_with_tune_context_func,
- PySpaceGeneratorNode::FGenerateDesignSpace generate_design_space_func);
+ PySpaceGeneratorNode::FInitializeWithTuneContext f_initialize_with_tune_context,
+ PySpaceGeneratorNode::FGenerateDesignSpace f_generate_design_space);
/*!
* \brief Create a design space generator that is union of multiple design space generators.
diff --git a/include/tvm/meta_schedule/task_scheduler.h b/include/tvm/meta_schedule/task_scheduler.h
index f28c33d..ddd6f4c 100644
--- a/include/tvm/meta_schedule/task_scheduler.h
+++ b/include/tvm/meta_schedule/task_scheduler.h
@@ -20,7 +20,9 @@
#define TVM_META_SCHEDULE_TASK_SCHEDULER_H_
#include <tvm/meta_schedule/builder.h>
+#include <tvm/meta_schedule/cost_model.h>
#include <tvm/meta_schedule/database.h>
+#include <tvm/meta_schedule/measure_callback.h>
#include <tvm/meta_schedule/runner.h>
#include <tvm/meta_schedule/tune_context.h>
@@ -78,7 +80,7 @@ class TaskSchedulerNode : public runtime::Object {
/*! \brief The list of measure callbacks of the scheduler. */
Array<MeasureCallback> measure_callbacks;
- /*! \brief The default desctructor. */
+ /*! \brief The default destructor. */
virtual ~TaskSchedulerNode() = default;
void VisitAttrs(tvm::AttrVisitor* v) {
@@ -247,16 +249,39 @@ class TaskScheduler : public runtime::ObjectRef {
* \param builder The builder of the scheduler.
* \param runner The runner of the scheduler.
* \param database The database of the scheduler.
+ * \param cost_model The cost model of the scheduler.
+ * \param measure_callbacks The measure callbacks of the scheduler.
+ * \return The task scheduler created.
+ */
+ TVM_DLL static TaskScheduler RoundRobin(Array<TuneContext> tasks, //
+ Builder builder, //
+ Runner runner, //
+ Database database, //
+ Optional<CostModel> cost_model, //
+ Optional<Array<MeasureCallback>> measure_callbacks);
+ /*!
+ * \brief Create a task scheduler with customized methods on the python-side.
+ * \param tasks The tasks to be tuned.
+ * \param builder The builder of the scheduler.
+ * \param runner The runner of the scheduler.
+ * \param database The database of the scheduler.
+ * \param cost_model The cost model of the scheduler.
+ * \param measure_callbacks The measure callbacks of the scheduler.
+ * \param f_tune The packed function of `Tune`.
+ * \param f_initialize_task The packed function of `InitializeTask`.
+ * \param f_set_task_stopped The packed function of `SetTaskStopped`.
+ * \param f_is_task_running The packed function of `IsTaskRunning`.
+ * \param f_join_running_task The packed function of `JoinRunningTask`.
+ * \param f_next_task_id The packed function of `NextTaskId`.
+ * \return The task scheduler created.
*/
- TVM_DLL static TaskScheduler RoundRobin(Array<TuneContext> tasks, //
- Builder builder, //
- Runner runner, //
- Database database); //
TVM_DLL static TaskScheduler PyTaskScheduler(
Array<TuneContext> tasks, //
Builder builder, //
Runner runner, //
Database database, //
+ Optional<CostModel> cost_model, //
+ Optional<Array<MeasureCallback>> measure_callbacks, //
PyTaskSchedulerNode::FTune f_tune, //
PyTaskSchedulerNode::FInitializeTask f_initialize_task, //
PyTaskSchedulerNode::FSetTaskStopped f_set_task_stopped, //
diff --git a/include/tvm/meta_schedule/tune_context.h b/include/tvm/meta_schedule/tune_context.h
index 6eacd4d..428a2e8 100644
--- a/include/tvm/meta_schedule/tune_context.h
+++ b/include/tvm/meta_schedule/tune_context.h
@@ -20,7 +20,12 @@
#define TVM_META_SCHEDULE_TUNE_CONTEXT_H_
#include <tvm/ir/module.h>
+#include <tvm/meta_schedule/builder.h>
+#include <tvm/meta_schedule/mutator.h>
+#include <tvm/meta_schedule/postproc.h>
+#include <tvm/meta_schedule/runner.h>
#include <tvm/meta_schedule/schedule_rule.h>
+#include <tvm/meta_schedule/search_strategy.h>
#include <tvm/meta_schedule/space_generator.h>
#include <tvm/support/random_engine.h>
#include <tvm/target/target.h>
@@ -28,6 +33,8 @@
namespace tvm {
namespace meta_schedule {
+class TaskSchedulerNode;
+
/*! \brief The auto tuning context. */
class TuneContextNode : public runtime::Object {
public:
@@ -41,6 +48,10 @@ class TuneContextNode : public runtime::Object {
Optional<SearchStrategy> search_strategy;
/*! \brief The schedule rules. */
Array<ScheduleRule> sch_rules;
+ /*! \brief The postprocessors. */
+ Array<Postproc> postprocs;
+ /*! \brief The probability of using certain mutator. */
+ Map<Mutator, FloatImm> mutator_probs;
/*! \brief The name of the tuning task. */
Optional<String> task_name;
/*! \brief The random state. */
@@ -48,12 +59,16 @@ class TuneContextNode : public runtime::Object {
/*! \brief The number of threads to be used. */
int num_threads;
+ /*! \brief The task scheduler that owns the tune context */
+ const TaskSchedulerNode* task_scheduler;
/*! \brief Whether the tuning task has been stopped or finished. */
bool is_stopped;
- /*! \brief Packed functions to fetch the runner results asynchronously. */
- Optional<Array<RunnerFuture>> runner_futures;
/*! \brief The measure candidates. */
Optional<Array<MeasureCandidate>> measure_candidates;
+ /*! \brief The building results. */
+ Optional<Array<BuilderResult>> builder_results;
+ /*! \brief Packed functions to fetch the runner results asynchronously. */
+ Optional<Array<RunnerFuture>> runner_futures;
void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("mod", &mod);
@@ -61,14 +76,18 @@ class TuneContextNode : public runtime::Object {
v->Visit("space_generator", &space_generator);
v->Visit("search_strategy", &search_strategy);
v->Visit("sch_rules", &sch_rules);
+ v->Visit("postprocs", &postprocs);
+ v->Visit("mutator_probs", &mutator_probs);
v->Visit("task_name", &task_name);
v->Visit("rand_state", &rand_state);
v->Visit("num_threads", &num_threads);
v->Visit("is_stopped", &is_stopped);
- v->Visit("runner_futures", &runner_futures);
v->Visit("measure_candidates", &measure_candidates);
}
+ /*! \brief Initialize the members that need initialization with the tune context. */
+ void Initialize();
+
static constexpr const char* _type_key = "meta_schedule.TuneContext";
TVM_DECLARE_FINAL_OBJECT_INFO(TuneContextNode, Object);
};
@@ -86,6 +105,8 @@ class TuneContext : public runtime::ObjectRef {
* \param space_generator The design space generator.
* \param search_strategy The search strategy.
* \param sch_rules The schedule rules.
+ * \param postprocs The postprocessors.
+ * \param mutator_probs The probability of using certain mutator.
* \param task_name The name of the tuning task.
* \param rand_state The random state.
* \param num_threads The number of threads to be used.
@@ -95,6 +116,8 @@ class TuneContext : public runtime::ObjectRef {
Optional<SpaceGenerator> space_generator, //
Optional<SearchStrategy> search_strategy, //
Optional<Array<ScheduleRule>> sch_rules, //
+ Optional<Array<Postproc>> postprocs, //
+ Optional<Map<Mutator, FloatImm>> mutator_probs, //
Optional<String> task_name, //
support::LinearCongruentialEngine::TRandState rand_state, //
int num_threads);
diff --git a/include/tvm/support/random_engine.h b/include/tvm/support/random_engine.h
index fcd2326..89b1e91 100644
--- a/include/tvm/support/random_engine.h
+++ b/include/tvm/support/random_engine.h
@@ -29,6 +29,7 @@
#include <tvm/runtime/logging.h>
#include <cstdint> // for uint64_t
+#include <random>
namespace tvm {
namespace support {
@@ -74,6 +75,12 @@ class LinearCongruentialEngine {
static constexpr result_type max() { return modulus - 1; }
/*!
+ * \brief Get a random state drawn from std::random_device
+ * \return The random state
+ */
+ static TRandState DeviceRandom() { return (std::random_device()()) % modulus; }
+
+ /*!
* \brief Operator to move the random state to the next and return the new random state. According
* to definition of linear congruential engine, the new random state value is computed as
* new_random_state = (current_random_state * multiplier + increment) % modulus.
@@ -93,6 +100,7 @@ class LinearCongruentialEngine {
* \param rand_state The random state given in result_type.
*/
void Seed(TRandState rand_state = 1) {
+ ICHECK(rand_state != -1) << "The seed can't be -1 which should be changed to random seed!";
rand_state %= modulus; // Make sure the seed is within the range of modulus.
if (rand_state == 0)
rand_state = 1; // Avoid getting all 0 given the current parameter set.
diff --git a/python/tvm/meta_schedule/cost_model/cost_model.py b/python/tvm/meta_schedule/cost_model/cost_model.py
index f5bd601..f794b11 100644
--- a/python/tvm/meta_schedule/cost_model/cost_model.py
+++ b/python/tvm/meta_schedule/cost_model/cost_model.py
@@ -55,7 +55,7 @@ class CostModel(Object):
def update(
self,
- tune_context: TuneContext,
+ context: TuneContext,
candidates: List[MeasureCandidate],
results: List[RunnerResult],
) -> None:
@@ -63,21 +63,21 @@ class CostModel(Object):
Parameters
----------
- tune_context : TuneContext,
+ context : TuneContext,
The tuning context.
candidates : List[MeasureCandidate]
The measure candidates.
results : List[RunnerResult]
The running results of the measure candidates.
"""
- _ffi_api.CostModelUpdate(self, tune_context, candidates, results) # type: ignore # pylint: disable=no-member
+ _ffi_api.CostModelUpdate(self, context, candidates, results) # type: ignore # pylint: disable=no-member
- def predict(self, tune_context: TuneContext, candidates: List[MeasureCandidate]) -> np.ndarray:
+ def predict(self, context: TuneContext, candidates: List[MeasureCandidate]) -> np.ndarray:
"""Update the cost model given running results.
Parameters
----------
- tune_context : TuneContext,
+ context : TuneContext,
The tuning context.
candidates : List[MeasureCandidate]
The measure candidates.
@@ -91,7 +91,7 @@ class CostModel(Object):
results = np.zeros(shape=(n,), dtype="float64")
_ffi_api.CostModelPredict( # type: ignore # pylint: disable=no-member
self,
- tune_context,
+ context,
candidates,
results.ctypes.data_as(ctypes.c_void_p),
)
@@ -115,20 +115,18 @@ class PyCostModel(CostModel):
@check_override(self.__class__, CostModel)
def f_update(
- tune_context: TuneContext,
+ context: TuneContext,
candidates: List[MeasureCandidate],
results: List[RunnerResult],
) -> None:
- self.update(tune_context, candidates, results)
+ self.update(context, candidates, results)
@check_override(self.__class__, CostModel)
- def f_predict(
- tune_context: TuneContext, candidates: List[MeasureCandidate], return_ptr
- ) -> None:
+ def f_predict(context: TuneContext, candidates: List[MeasureCandidate], return_ptr) -> None:
n = len(candidates)
return_ptr = ctypes.cast(return_ptr, ctypes.POINTER(ctypes.c_double))
array_wrapper = np.ctypeslib.as_array(return_ptr, shape=(n,))
- array_wrapper[:] = self.predict(tune_context, candidates)
+ array_wrapper[:] = self.predict(context, candidates)
assert (
array_wrapper.dtype == "float64"
), "ValueError: Invalid data type returned from CostModel Predict!"
diff --git a/python/tvm/meta_schedule/cost_model/random_model.py b/python/tvm/meta_schedule/cost_model/random_model.py
index 23238d2..8808476 100644
--- a/python/tvm/meta_schedule/cost_model/random_model.py
+++ b/python/tvm/meta_schedule/cost_model/random_model.py
@@ -84,7 +84,7 @@ class RandomModel(PyCostModel):
def update(
self,
- tune_context: TuneContext,
+ context: TuneContext,
candidates: List[MeasureCandidate],
results: List[RunnerResult],
) -> None:
@@ -92,7 +92,7 @@ class RandomModel(PyCostModel):
Parameters
----------
- tune_context : TuneContext,
+ context : TuneContext,
The tuning context.
candidates : List[MeasureCandidate]
The measure candidates.
@@ -100,12 +100,12 @@ class RandomModel(PyCostModel):
The running results of the measure candidates.
"""
- def predict(self, tune_context: TuneContext, candidates: List[MeasureCandidate]) -> np.ndarray:
+ def predict(self, context: TuneContext, candidates: List[MeasureCandidate]) -> np.ndarray:
"""Update the cost model given running results.
Parameters
----------
- tune_context : TuneContext,
+ context : TuneContext,
The tuning context.
candidates : List[MeasureCandidate]
The measure candidates.
diff --git a/python/tvm/meta_schedule/feature_extractor/feature_extractor.py b/python/tvm/meta_schedule/feature_extractor/feature_extractor.py
index bd7656e..5043d4b 100644
--- a/python/tvm/meta_schedule/feature_extractor/feature_extractor.py
+++ b/python/tvm/meta_schedule/feature_extractor/feature_extractor.py
@@ -32,13 +32,13 @@ class FeatureExtractor(Object):
"""Extractor for features from measure candidates for use in cost model."""
def extract_from(
- self, tune_context: TuneContext, candidates: List[MeasureCandidate]
+ self, context: TuneContext, candidates: List[MeasureCandidate]
) -> List[NDArray]:
"""Extract features from the given measure candidate.
Parameters
----------
- tune_context : TuneContext
+ context : TuneContext
The tuning context for feature extraction.
candidates : List[MeasureCandidate]
The measure candidates to extract features from.
@@ -49,7 +49,7 @@ class FeatureExtractor(Object):
The feature numpy ndarray extracted.
"""
result = _ffi_api.FeatureExtractorExtractFrom( # type: ignore # pylint: disable=no-member
- self, tune_context, candidates
+ self, context, candidates
)
return result
@@ -63,9 +63,9 @@ class PyFeatureExtractor(FeatureExtractor):
@check_override(self.__class__, FeatureExtractor)
def f_extract_from(
- tune_context: TuneContext, candidates: List[MeasureCandidate]
+ context: TuneContext, candidates: List[MeasureCandidate]
) -> List[NDArray]:
- features = self.extract_from(tune_context, candidates)
+ features = self.extract_from(context, candidates)
return features
def f_as_string() -> str:
diff --git a/python/tvm/meta_schedule/feature_extractor/random_feature_extractor.py b/python/tvm/meta_schedule/feature_extractor/random_feature_extractor.py
index 7c72a25..d805648 100644
--- a/python/tvm/meta_schedule/feature_extractor/random_feature_extractor.py
+++ b/python/tvm/meta_schedule/feature_extractor/random_feature_extractor.py
@@ -51,7 +51,7 @@ class RandomFeatureExtractor(PyFeatureExtractor):
self.random_state = np.random.get_state()
def extract_from(
- self, tune_context: TuneContext, candidates: List[MeasureCandidate]
+ self, context: TuneContext, candidates: List[MeasureCandidate]
) -> List[NDArray]:
np.random.set_state(self.random_state)
result = [
diff --git a/python/tvm/meta_schedule/mutator/__init__.py b/python/tvm/meta_schedule/mutator/__init__.py
new file mode 100644
index 0000000..f88043b
--- /dev/null
+++ b/python/tvm/meta_schedule/mutator/__init__.py
@@ -0,0 +1,22 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""
+The tvm.meta_schedule.mutator package.
+Meta Schedule mutator that mutates the trace to explore the
+design space.
+"""
+from .mutator import Mutator, PyMutator
diff --git a/python/tvm/meta_schedule/mutator/mutator.py b/python/tvm/meta_schedule/mutator/mutator.py
new file mode 100644
index 0000000..80e0f66
--- /dev/null
+++ b/python/tvm/meta_schedule/mutator/mutator.py
@@ -0,0 +1,88 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Meta Schedule Mutator."""
+from typing import Optional, TYPE_CHECKING
+
+from tvm._ffi import register_object
+from tvm.runtime import Object
+from tvm.tir.schedule import Trace
+
+from .. import _ffi_api
+from ..utils import _get_hex_address, check_override
+
+if TYPE_CHECKING:
+ from ..tune_context import TuneContext
+
+
+class Mutator(Object):
+ """Mutator is designed to mutate the trace to explore the design space."""
+
+ def initialize_with_tune_context(self, context: "TuneContext") -> None:
+ """Initialize the mutator with a tune context.
+
+ Parameters
+ ----------
+ context : TuneContext
+ The tuning context for initializing the mutator.
+ """
+ _ffi_api.MutatorInitializeWithTuneContext( # type: ignore # pylint: disable=no-member
+ self, context
+ )
+
+ def apply(self, trace: Trace) -> Optional[Trace]:
+ """Apply the mutator function to the given trace.
+
+ Parameters
+ ----------
+ trace : Trace
+ The given trace for mutation.
+
+ Returns
+ -------
+ trace : Optional[Trace]
+ None if mutator failed, otherwise return the mutated trace.
+ """
+ return _ffi_api.MutatorApply(self, trace, -1) # type: ignore # pylint: disable=no-member
+
+
+@register_object("meta_schedule.PyMutator")
+class PyMutator(Mutator):
+ """An abstract mutator with customized methods on the python-side."""
+
+ def __init__(self):
+ """Constructor."""
+
+ @check_override(self.__class__, Mutator)
+ def f_initialize_with_tune_context(context: "TuneContext") -> None:
+ self.initialize_with_tune_context(context)
+
+ @check_override(self.__class__, Mutator)
+ def f_apply(trace: Trace, _) -> Optional[Trace]:
+ return self.apply(trace)
+
+ def f_as_string() -> str:
+ return str(self)
+
+ self.__init_handle_by_constructor__(
+ _ffi_api.MutatorPyMutator, # type: ignore # pylint: disable=no-member
+ f_initialize_with_tune_context,
+ f_apply,
+ f_as_string,
+ )
+
+ def __str__(self) -> str:
+ return f"{self.__class__.__name__}({_get_hex_address(self.handle)})"
diff --git a/python/tvm/meta_schedule/postproc/__init__.py b/python/tvm/meta_schedule/postproc/__init__.py
new file mode 100644
index 0000000..6ee052e
--- /dev/null
+++ b/python/tvm/meta_schedule/postproc/__init__.py
@@ -0,0 +1,18 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""The tvm.meta_schedule.postproc package."""
+from .postproc import Postproc, PyPostproc
diff --git a/python/tvm/meta_schedule/postproc/postproc.py b/python/tvm/meta_schedule/postproc/postproc.py
new file mode 100644
index 0000000..5f1180c
--- /dev/null
+++ b/python/tvm/meta_schedule/postproc/postproc.py
@@ -0,0 +1,90 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Meta Schedule Postproc."""
+
+from typing import TYPE_CHECKING
+
+from tvm._ffi import register_object
+from tvm.runtime import Object
+from tvm.tir.schedule import Schedule
+
+from .. import _ffi_api
+from ..utils import _get_hex_address, check_override
+
+if TYPE_CHECKING:
+ from ..tune_context import TuneContext
+
+
+@register_object("meta_schedule.Postproc")
+class Postproc(Object):
+ """Rules to apply a postprocessor to a schedule."""
+
+ def initialize_with_tune_context(self, context: "TuneContext") -> None:
+ """Initialize the postprocessor with a tune context.
+
+ Parameters
+ ----------
+ context : TuneContext
+ The tuning context for initializing the postprocessor.
+ """
+ _ffi_api.PostprocInitializeWithTuneContext( # type: ignore # pylint: disable=no-member
+ self, context
+ )
+
+ def apply(self, sch: Schedule) -> bool:
+ """Apply a postprocessor to the given schedule.
+
+ Parameters
+ ----------
+ sch : Schedule
+ The schedule to be post processed.
+
+ Returns
+ -------
+ result : bool
+ Whether the postprocessor was successfully applied.
+ """
+ return _ffi_api.PostprocApply(self, sch) # type: ignore # pylint: disable=no-member
+
+
+@register_object("meta_schedule.PyPostproc")
+class PyPostproc(Postproc):
+ """An abstract Postproc with customized methods on the python-side."""
+
+ def __init__(self):
+ """Constructor."""
+
+ @check_override(self.__class__, Postproc)
+ def f_initialize_with_tune_context(context: "TuneContext") -> None:
+ self.initialize_with_tune_context(context)
+
+ @check_override(self.__class__, Postproc)
+ def f_apply(sch: Schedule) -> bool:
+ return self.apply(sch)
+
+ def f_as_string() -> str:
+ return str(self)
+
+ self.__init_handle_by_constructor__(
+ _ffi_api.PostprocPyPostproc, # type: ignore # pylint: disable=no-member
+ f_initialize_with_tune_context,
+ f_apply,
+ f_as_string,
+ )
+
+ def __str__(self) -> str:
+ return f"{self.__class__.__name__}({_get_hex_address(self.handle)})"
diff --git a/python/tvm/meta_schedule/schedule_rule/schedule_rule.py b/python/tvm/meta_schedule/schedule_rule/schedule_rule.py
index b995e5a..ab142c0 100644
--- a/python/tvm/meta_schedule/schedule_rule/schedule_rule.py
+++ b/python/tvm/meta_schedule/schedule_rule/schedule_rule.py
@@ -35,16 +35,16 @@ if TYPE_CHECKING:
class ScheduleRule(Object):
"""Rules to modify a block in a schedule."""
- def initialize_with_tune_context(self, tune_context: "TuneContext") -> None:
+ def initialize_with_tune_context(self, context: "TuneContext") -> None:
"""Initialize the schedule rule with a tune context.
Parameters
----------
- tune_context : TuneContext
+ context : TuneContext
The tuning context for initializing the schedule rule.
"""
_ffi_api.ScheduleRuleInitializeWithTuneContext( # type: ignore # pylint: disable=no-member
- self, tune_context
+ self, context
)
def apply(self, sch: Schedule, block: BlockRV) -> List[Schedule]:
@@ -75,8 +75,8 @@ class PyScheduleRule(ScheduleRule):
"""Constructor."""
@check_override(self.__class__, ScheduleRule)
- def f_initialize_with_tune_context(tune_context: "TuneContext") -> None:
- self.initialize_with_tune_context(tune_context)
+ def f_initialize_with_tune_context(context: "TuneContext") -> None:
+ self.initialize_with_tune_context(context)
@check_override(self.__class__, ScheduleRule)
def f_apply(sch: Schedule, block: BlockRV) -> List[Schedule]:
diff --git a/python/tvm/meta_schedule/search_strategy/replay_trace.py b/python/tvm/meta_schedule/search_strategy/replay_trace.py
index 15f8295..5655038 100644
--- a/python/tvm/meta_schedule/search_strategy/replay_trace.py
+++ b/python/tvm/meta_schedule/search_strategy/replay_trace.py
@@ -15,6 +15,7 @@
# specific language governing permissions and limitations
# under the License.
"""Replay Trace Search Strategy"""
+from typing import NamedTuple
from tvm._ffi import register_object
from .search_strategy import SearchStrategy
@@ -41,7 +42,17 @@ class ReplayTrace(SearchStrategy):
def __init__(self, num_trials_per_iter: int, num_trials_total: int):
"""Constructor"""
self.__init_handle_by_constructor__(
- _ffi_api.ReplayTrace, # type: ignore # pylint: disable=no-member
+ _ffi_api.SearchStrategyReplayTrace, # type: ignore # pylint: disable=no-member
num_trials_per_iter,
num_trials_total,
)
+
+
+class ReplayTraceConfig(NamedTuple):
+ """Configuration for ReplayTrace"""
+
+ num_trials_per_iter: int
+ num_trials_total: int
+
+ def create_strategy(self) -> ReplayTrace:
+ return ReplayTrace(self.num_trials_per_iter, self.num_trials_total)
diff --git a/python/tvm/meta_schedule/search_strategy/search_strategy.py b/python/tvm/meta_schedule/search_strategy/search_strategy.py
index 6cee09e..411fecb 100644
--- a/python/tvm/meta_schedule/search_strategy/search_strategy.py
+++ b/python/tvm/meta_schedule/search_strategy/search_strategy.py
@@ -48,7 +48,11 @@ class MeasureCandidate(Object):
sch: Schedule
args_info: List[ArgInfo]
- def __init__(self, sch: Schedule, args_info: List[ArgInfo]) -> None:
+ def __init__(
+ self,
+ sch: Schedule,
+ args_info: List[ArgInfo],
+ ) -> None:
"""Constructor.
Parameters
@@ -72,19 +76,16 @@ class SearchStrategy(Object):
before usage and post-tuned after usage.
"""
- def initialize_with_tune_context(
- self,
- tune_context: "TuneContext",
- ) -> None:
+ def initialize_with_tune_context(self, context: "TuneContext") -> None:
"""Initialize the search strategy with tuning context.
Parameters
----------
- tune_context : TuneContext
+ context : TuneContext
The tuning context for initialization.
"""
_ffi_api.SearchStrategyInitializeWithTuneContext( # type: ignore # pylint: disable=no-member
- self, tune_context
+ self, context
)
def pre_tuning(self, design_spaces: List[Schedule]) -> None:
@@ -111,15 +112,29 @@ class SearchStrategy(Object):
"""
return _ffi_api.SearchStrategyGenerateMeasureCandidates(self) # type: ignore # pylint: disable=no-member
- def notify_runner_results(self, results: List[RunnerResult]) -> None:
+ def notify_runner_results(
+ self,
+ context: "TuneContext",
+ measure_candidates: List[MeasureCandidate],
+ results: List[RunnerResult],
+ ) -> None:
"""Update the search strategy with profiling results.
Parameters
----------
+ context : TuneContext
+ The tuning context for update.
+ measure_candidates : List[MeasureCandidate]
+ The measure candidates for update.
results : List[RunnerResult]
The profiling results from the runner.
"""
- _ffi_api.SearchStrategyNotifyRunnerResults(self, results) # type: ignore # pylint: disable=no-member
+ _ffi_api.SearchStrategyNotifyRunnerResults( # type: ignore # pylint: disable=no-member
+ self,
+ context,
+ measure_candidates,
+ results,
+ )
@register_object("meta_schedule.PySearchStrategy")
@@ -146,8 +161,12 @@ class PySearchStrategy(SearchStrategy):
return self.generate_measure_candidates()
@check_override(self.__class__, SearchStrategy)
- def f_notify_runner_results(results: List["RunnerResult"]) -> None:
- self.notify_runner_results(results)
+ def f_notify_runner_results(
+ context: "TuneContext",
+ measure_candidates: List[MeasureCandidate],
+ results: List["RunnerResult"],
+ ) -> None:
+ self.notify_runner_results(context, measure_candidates, results)
self.__init_handle_by_constructor__(
_ffi_api.SearchStrategyPySearchStrategy, # type: ignore # pylint: disable=no-member
diff --git a/python/tvm/meta_schedule/space_generator/schedule_fn.py b/python/tvm/meta_schedule/space_generator/schedule_fn.py
index 64edd9e..8a57a84 100644
--- a/python/tvm/meta_schedule/space_generator/schedule_fn.py
+++ b/python/tvm/meta_schedule/space_generator/schedule_fn.py
@@ -51,12 +51,12 @@ class ScheduleFn(PySpaceGenerator):
super().__init__()
self.sch_fn = sch_fn
- def initialize_with_tune_context(self, tune_context: "TuneContext") -> None:
+ def initialize_with_tune_context(self, context: "TuneContext") -> None:
"""Initialize the design space generator with tuning context.
Parameters
----------
- tune_context : TuneContext
+ context : TuneContext
The tuning context for initializing the design space generator.
"""
diff --git a/python/tvm/meta_schedule/space_generator/space_generator.py b/python/tvm/meta_schedule/space_generator/space_generator.py
index 2172613..e0b0ab2 100644
--- a/python/tvm/meta_schedule/space_generator/space_generator.py
+++ b/python/tvm/meta_schedule/space_generator/space_generator.py
@@ -36,16 +36,16 @@ if TYPE_CHECKING:
class SpaceGenerator(Object):
"""The abstract design space generator interface."""
- def initialize_with_tune_context(self, tune_context: "TuneContext") -> None:
+ def initialize_with_tune_context(self, context: "TuneContext") -> None:
"""Initialize the design space generator with tuning context.
Parameters
----------
- tune_context : TuneContext
+ context : TuneContext
The tuning context for initializing the design space generator.
"""
_ffi_api.SpaceGeneratorInitializeWithTuneContext( # type: ignore # pylint: disable=no-member
- self, tune_context
+ self, context
)
def generate_design_space(self, mod: IRModule) -> List[Schedule]:
@@ -72,8 +72,8 @@ class PySpaceGenerator(SpaceGenerator):
"""Constructor."""
@check_override(self.__class__, SpaceGenerator)
- def f_initialize_with_tune_context(tune_context: "TuneContext") -> None:
- self.initialize_with_tune_context(tune_context)
+ def f_initialize_with_tune_context(context: "TuneContext") -> None:
+ self.initialize_with_tune_context(context)
@check_override(self.__class__, SpaceGenerator)
def f_generate_design_space(mod: IRModule) -> List[Schedule]:
diff --git a/python/tvm/meta_schedule/task_scheduler/round_robin.py b/python/tvm/meta_schedule/task_scheduler/round_robin.py
index 391011b..a63d9a3 100644
--- a/python/tvm/meta_schedule/task_scheduler/round_robin.py
+++ b/python/tvm/meta_schedule/task_scheduler/round_robin.py
@@ -16,13 +16,15 @@
# under the License.
"""Round Robin Task Scheduler"""
-from typing import List, TYPE_CHECKING
+from typing import List, Optional, TYPE_CHECKING
from tvm._ffi import register_object
+from tvm.meta_schedule.measure_callback.measure_callback import MeasureCallback
from ..builder import Builder
from ..runner import Runner
from ..database import Database
+from ..cost_model import CostModel
from .task_scheduler import TaskScheduler
from .. import _ffi_api
@@ -33,7 +35,21 @@ if TYPE_CHECKING:
@register_object("meta_schedule.RoundRobin")
class RoundRobin(TaskScheduler):
- """Round Robin Task Scheduler"""
+ """Round Robin Task Scheduler
+
+ Parameters
+ ----------
+ tasks: List[TuneContext]
+ The list of tune contexts to process.
+ builder: Builder
+ The builder of the scheduler.
+ runner: Runner
+ The runner of the scheduler.
+ database: Database
+ The database of the scheduler.
+ measure_callbacks: Optional[List[MeasureCallback]] = None
+ The list of measure callbacks of the scheduler.
+ """
def __init__(
self,
@@ -41,6 +57,8 @@ class RoundRobin(TaskScheduler):
builder: Builder,
runner: Runner,
database: Database,
+ cost_model: Optional[CostModel] = None,
+ measure_callbacks: Optional[List[MeasureCallback]] = None,
) -> None:
"""Constructor.
@@ -54,6 +72,8 @@ class RoundRobin(TaskScheduler):
The runner.
database : Database
The database.
+ measure_callbacks: Optional[List[MeasureCallback]]
+ The list of measure callbacks of the scheduler.
"""
self.__init_handle_by_constructor__(
_ffi_api.TaskSchedulerRoundRobin, # type: ignore # pylint: disable=no-member
@@ -61,4 +81,6 @@ class RoundRobin(TaskScheduler):
builder,
runner,
database,
+ cost_model,
+ measure_callbacks,
)
diff --git a/python/tvm/meta_schedule/task_scheduler/task_scheduler.py b/python/tvm/meta_schedule/task_scheduler/task_scheduler.py
index aeea154..dd8e3fe 100644
--- a/python/tvm/meta_schedule/task_scheduler/task_scheduler.py
+++ b/python/tvm/meta_schedule/task_scheduler/task_scheduler.py
@@ -16,14 +16,16 @@
# under the License.
"""Auto-tuning Task Scheduler"""
-from typing import List
+from typing import List, Optional
from tvm._ffi import register_object
+from tvm.meta_schedule.measure_callback.measure_callback import MeasureCallback
from tvm.runtime import Object
from ..runner import Runner
from ..builder import Builder
from ..database import Database
+from ..cost_model import CostModel
from ..tune_context import TuneContext
from .. import _ffi_api
from ..utils import check_override
@@ -43,12 +45,16 @@ class TaskScheduler(Object):
The runner of the scheduler.
database: Database
The database of the scheduler.
+ measure_callbacks: List[MeasureCallback]
+ The list of measure callbacks of the scheduler.
"""
tasks: List[TuneContext]
builder: Builder
runner: Runner
database: Database
+ cost_model: Optional[CostModel]
+ measure_callbacks: List[MeasureCallback]
def tune(self) -> None:
"""Auto-tuning."""
@@ -59,7 +65,7 @@ class TaskScheduler(Object):
Returns
-------
- int
+ next_task_id : int
The next task id.
"""
return _ffi_api.TaskSchedulerNextTaskId(self) # type: ignore # pylint: disable=no-member
@@ -94,7 +100,7 @@ class TaskScheduler(Object):
Returns
-------
- bool
+ running : bool
Whether the task is running.
"""
return _ffi_api.TaskSchedulerIsTaskRunning(self, task_id) # type: ignore # pylint: disable=no-member
@@ -120,6 +126,8 @@ class PyTaskScheduler(TaskScheduler):
builder: Builder,
runner: Runner,
database: Database,
+ cost_model: Optional[CostModel] = None,
+ measure_callbacks: Optional[List[MeasureCallback]] = None,
):
"""Constructor.
@@ -133,6 +141,10 @@ class PyTaskScheduler(TaskScheduler):
The runner of the scheduler.
database: Database
The database of the scheduler.
+ cost_model: Optional[CostModel]
+ The cost model of the scheduler.
+ measure_callbacks: Optional[List[MeasureCallback]]
+ The list of measure callbacks of the scheduler.
"""
@check_override(self.__class__, TaskScheduler, required=False)
@@ -173,6 +185,8 @@ class PyTaskScheduler(TaskScheduler):
builder,
runner,
database,
+ cost_model,
+ measure_callbacks,
f_tune,
f_initialize_task,
f_set_task_stopped,
diff --git a/python/tvm/meta_schedule/tune_context.py b/python/tvm/meta_schedule/tune_context.py
index 99b8c7e..196b1c1 100644
--- a/python/tvm/meta_schedule/tune_context.py
+++ b/python/tvm/meta_schedule/tune_context.py
@@ -16,13 +16,14 @@
# under the License.
"""Meta Schedule tuning context."""
-from typing import Optional, List, TYPE_CHECKING
+from typing import Optional, List, Dict, TYPE_CHECKING
from tvm import IRModule
from tvm._ffi import register_object
from tvm.meta_schedule.utils import cpu_count
from tvm.runtime import Object
from tvm.target import Target
+from tvm.tir import PrimFunc
from . import _ffi_api
@@ -30,6 +31,8 @@ if TYPE_CHECKING:
from .space_generator import SpaceGenerator
from .search_strategy import SearchStrategy
from .schedule_rule import ScheduleRule
+ from .postproc import Postproc
+ from .mutator import Mutator
@register_object("meta_schedule.TuneContext")
@@ -53,6 +56,10 @@ class TuneContext(Object):
The search strategy.
sch_rules: Optional[List[ScheduleRule]] = None,
The schedule rules.
+ postprocs: Optional[List[Postproc]] = None,
+ The postprocessors.
+ mutator_probs: Optional[Dict[Mutator, float]] = None,
+ Mutators and their probability mass.
task_name : Optional[str] = None
The name of the tuning task.
rand_state : int = -1
@@ -71,23 +78,31 @@ class TuneContext(Object):
mod: Optional[IRModule]
target: Optional[Target]
- space_generator: "SpaceGenerator"
- search_strategy: "SearchStrategy"
- task_name: Optional[str]
+ space_generator: Optional["SpaceGenerator"]
+ search_strategy: Optional["SearchStrategy"]
+ sch_rules: List["ScheduleRule"]
+ postprocs: List["Postproc"]
+ mutator_probs: Optional[Dict["Mutator", float]]
+ task_name: str
rand_state: int
num_threads: int
def __init__(
self,
mod: Optional[IRModule] = None,
+ *,
target: Optional[Target] = None,
space_generator: Optional["SpaceGenerator"] = None,
search_strategy: Optional["SearchStrategy"] = None,
sch_rules: Optional[List["ScheduleRule"]] = None,
- task_name: Optional[str] = None,
+ postprocs: Optional[List["Postproc"]] = None,
+ mutator_probs: Optional[Dict["Mutator", float]] = None,
+ task_name: str = "main",
rand_state: int = -1,
num_threads: Optional[int] = None,
):
+ if isinstance(mod, PrimFunc):
+ mod = IRModule.from_expr(mod)
if num_threads is None:
num_threads = cpu_count()
@@ -98,6 +113,8 @@ class TuneContext(Object):
space_generator,
search_strategy,
sch_rules,
+ postprocs,
+ mutator_probs,
task_name,
rand_state,
num_threads,
diff --git a/src/meta_schedule/cost_model/cost_model.cc b/src/meta_schedule/cost_model/cost_model.cc
index 5cd32b0..c6efb54 100644
--- a/src/meta_schedule/cost_model/cost_model.cc
+++ b/src/meta_schedule/cost_model/cost_model.cc
@@ -53,10 +53,10 @@ TVM_REGISTER_GLOBAL("meta_schedule.CostModelUpdate")
.set_body_method<CostModel>(&CostModelNode::Update);
TVM_REGISTER_GLOBAL("meta_schedule.CostModelPredict")
.set_body_typed([](CostModel model, //
- const TuneContext& tune_context, //
+ const TuneContext& context, //
Array<MeasureCandidate> candidates, //
void* p_addr) -> void {
- std::vector<double> result = model->Predict(tune_context, candidates);
+ std::vector<double> result = model->Predict(context, candidates);
std::copy(result.begin(), result.end(), static_cast<double*>(p_addr));
});
TVM_REGISTER_GLOBAL("meta_schedule.CostModelPyCostModel").set_body_typed(CostModel::PyCostModel);
diff --git a/src/meta_schedule/search_strategy/replay_trace.cc b/src/meta_schedule/search_strategy/replay_trace.cc
index 200eca3..1eac10d 100644
--- a/src/meta_schedule/search_strategy/replay_trace.cc
+++ b/src/meta_schedule/search_strategy/replay_trace.cc
@@ -24,20 +24,18 @@ namespace meta_schedule {
/*! \brief A search strategy that generates measure candidates using trace and random decisions. */
class ReplayTraceNode : public SearchStrategyNode {
public:
- using TRandState = support::LinearCongruentialEngine::TRandState;
-
/*! \brief The state of the search strategy. */
struct State {
/*! \brief The search strategy itself */
ReplayTraceNode* self;
/*! \brief The design spaces. */
- Array<tir::Schedule> design_spaces;
+ Array<tir::Trace> design_spaces;
/*! \brief `[st, ed)` are the indices of the next batch of candidates. */
int st;
/*! \brief `[st, ed)` are the indices of the next batch of candidates. */
int ed;
- explicit State(ReplayTraceNode* self, Array<tir::Schedule> design_spaces)
+ explicit State(ReplayTraceNode* self, Array<tir::Trace> design_spaces)
: self(self), design_spaces(design_spaces), st(0), ed(self->num_trials_per_iter) {}
inline Optional<Array<MeasureCandidate>> GenerateMeasureCandidates();
@@ -50,9 +48,11 @@ class ReplayTraceNode : public SearchStrategyNode {
int num_trials_total;
/*! \brief The module to be tuned. */
- IRModule mod_{nullptr};
+ Array<IRModule> per_thread_mod_{nullptr};
/*! \brief The metadata of the function arguments. */
Array<ArgInfo> args_info_{nullptr};
+ /*! \brief The post processors */
+ Array<Postproc> postprocs_{nullptr};
/*! \brief The number of threads to use. -1 means using logical cpu number. */
int num_threads_ = -1;
/*! \brief The random state. -1 means using random number. */
@@ -63,8 +63,9 @@ class ReplayTraceNode : public SearchStrategyNode {
void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("num_trials_per_iter", &num_trials_per_iter);
v->Visit("num_trials_total", &num_trials_total);
- // `mod_` is not visited
+ // `per_thread_mod_` is not visited
// `args_info_` is not visited
+ // `postprocs_` is not visited
// `num_threads_` is not visited
// `rand_state_` is not visited
// `state_` is not visited
@@ -74,9 +75,16 @@ class ReplayTraceNode : public SearchStrategyNode {
TVM_DECLARE_FINAL_OBJECT_INFO(ReplayTraceNode, SearchStrategyNode);
void InitializeWithTuneContext(const TuneContext& context) final {
- this->mod_ = context->mod.value();
- this->args_info_ = ArgInfo::FromPrimFunc(FindEntryFunc(this->mod_));
+ CHECK(context->num_threads > 0) << "Number of threads has to be larger than 0.";
this->num_threads_ = context->num_threads;
+
+ this->per_thread_mod_.reserve(this->num_threads_);
+ for (int i = 0; i < this->num_threads_; i++) {
+ this->per_thread_mod_.push_back(DeepCopyIRModule(context->mod.value()));
+ }
+
+ this->args_info_ = ArgInfo::FromPrimFunc(FindEntryFunc(context->mod.value()));
+ this->postprocs_ = context->postprocs;
this->rand_state_ = ForkSeed(&context->rand_state);
this->state_.reset();
}
@@ -84,7 +92,12 @@ class ReplayTraceNode : public SearchStrategyNode {
void PreTuning(const Array<tir::Schedule>& design_spaces) final {
ICHECK(!design_spaces.empty());
ICHECK(this->state_ == nullptr);
- this->state_ = std::make_unique<State>(this, design_spaces);
+ Array<tir::Trace> design_space_traces;
+ design_space_traces.reserve(design_spaces.size());
+ for (const tir::Schedule& space : design_spaces) {
+ design_space_traces.push_back(space->trace().value()->Simplified(true));
+ }
+ this->state_ = std::make_unique<State>(this, design_space_traces);
}
void PostTuning() final {
@@ -97,7 +110,9 @@ class ReplayTraceNode : public SearchStrategyNode {
return this->state_->GenerateMeasureCandidates();
}
- void NotifyRunnerResults(const Array<RunnerResult>& results) final {
+ void NotifyRunnerResults(const TuneContext& context,
+ const Array<MeasureCandidate>& measure_candidates,
+ const Array<RunnerResult>& results) final {
ICHECK(this->state_ != nullptr);
this->state_->NotifyRunnerResults(results);
}
@@ -111,19 +126,20 @@ inline Optional<Array<MeasureCandidate>> ReplayTraceNode::State::GenerateMeasure
ICHECK_LT(st, ed);
std::vector<TRandState> per_thread_rand_state = ForkSeed(&self->rand_state_, self->num_threads_);
Array<MeasureCandidate> per_task_result(ed - st, MeasureCandidate{nullptr});
- auto f_worker = [this, &per_thread_rand_state, &per_task_result](int thread_id,
- int task_id) -> void {
+ ThreadedTraceApply pp(self->postprocs_);
+ auto f_worker = [this, &per_thread_rand_state, &per_task_result, &pp](int thread_id,
+ int task_id) -> void {
TRandState& rand_state = per_thread_rand_state[thread_id];
- int design_space_index = tir::SampleInt(&rand_state, 0, design_spaces.size());
- tir::Trace trace = design_spaces[design_space_index]->trace().value();
- tir::Trace new_trace = tir::Trace(trace->insts, {});
- tir::Schedule sch = tir::Schedule::Traced( //
- self->mod_, //
- /*rand_state=*/ForkSeed(&rand_state), //
- /*debug_mode=*/0, //
- /*error_render_level=*/tir::ScheduleErrorRenderLevel::kNone);
- new_trace->ApplyToSchedule(sch, /*remove_postproc=*/true);
- per_task_result.Set(task_id, MeasureCandidate(sch, self->args_info_));
+ IRModule mod = self->per_thread_mod_[thread_id];
+ for (;;) {
+ int design_space_index = tir::SampleInt(&rand_state, 0, design_spaces.size());
+ tir::Trace trace = design_spaces[design_space_index];
+ tir::Trace new_trace = tir::Trace(trace->insts, {});
+ if (Optional<tir::Schedule> sch = pp.Apply(mod, new_trace, &rand_state)) {
+ per_task_result.Set(task_id, MeasureCandidate(sch.value(), self->args_info_));
+ break;
+ }
+ }
};
support::parallel_for_dynamic(0, ed - st, self->num_threads_, f_worker);
return per_task_result;
@@ -142,7 +158,8 @@ SearchStrategy SearchStrategy::ReplayTrace(int num_trials_per_iter, int num_tria
}
TVM_REGISTER_NODE_TYPE(ReplayTraceNode);
-TVM_REGISTER_GLOBAL("meta_schedule.ReplayTrace").set_body_typed(SearchStrategy::ReplayTrace);
+TVM_REGISTER_GLOBAL("meta_schedule.SearchStrategyReplayTrace")
+ .set_body_typed(SearchStrategy::ReplayTrace);
} // namespace meta_schedule
} // namespace tvm
diff --git a/src/meta_schedule/task_scheduler/round_robin.cc b/src/meta_schedule/task_scheduler/round_robin.cc
index 3ef5026..72989a2 100644
--- a/src/meta_schedule/task_scheduler/round_robin.cc
+++ b/src/meta_schedule/task_scheduler/round_robin.cc
@@ -52,16 +52,23 @@ class RoundRobinNode final : public TaskSchedulerNode {
}
};
-TaskScheduler TaskScheduler::RoundRobin(Array<TuneContext> tasks, //
- Builder builder, //
- Runner runner, //
- Database database) {
+TaskScheduler TaskScheduler::RoundRobin(Array<TuneContext> tasks, //
+ Builder builder, //
+ Runner runner, //
+ Database database, //
+ Optional<CostModel> cost_model, //
+ Optional<Array<MeasureCallback>> measure_callbacks) {
ObjectPtr<RoundRobinNode> n = make_object<RoundRobinNode>();
n->tasks = tasks;
n->builder = builder;
n->runner = runner;
n->database = database;
+ n->cost_model = cost_model;
+ n->measure_callbacks = measure_callbacks.value_or({});
n->task_id = -1;
+ for (const TuneContext& task : tasks) {
+ task->task_scheduler = n.get();
+ }
return TaskScheduler(n);
}
diff --git a/src/meta_schedule/task_scheduler/task_scheduler.cc b/src/meta_schedule/task_scheduler/task_scheduler.cc
index 08f2b4f..1f3943d 100644
--- a/src/meta_schedule/task_scheduler/task_scheduler.cc
+++ b/src/meta_schedule/task_scheduler/task_scheduler.cc
@@ -16,7 +16,6 @@
* specific language governing permissions and limitations
* under the License.
*/
-
#include "../utils.h"
namespace tvm {
@@ -29,9 +28,9 @@ namespace meta_schedule {
* \param candidates The measure candidates.
* \return An array of the builder results.
*/
-Array<BuilderResult> SendToBuilder(const Builder& builder, //
- const TuneContext& context,
+Array<BuilderResult> SendToBuilder(const Builder& builder, const TuneContext& context,
const Array<MeasureCandidate>& candidates) {
+ LOG(INFO) << "Sending " << candidates.size() << " sample(s) to builder";
Target target = context->target.value();
Array<BuilderInput> inputs;
inputs.reserve(candidates.size());
@@ -45,14 +44,14 @@ Array<BuilderResult> SendToBuilder(const Builder& builder, //
* \brief Send the built measure candidates to runner.
* \param runner The runner to send the candidates to.
* \param context The tuning context.
- * \param candidates The mesure candidates.
+ * \param candidates The measure candidates.
* \param builder_results The builder results.
* \return An array of the runner results.
*/
-Array<RunnerFuture> SendToRunner(const Runner& runner, //
- const TuneContext& context,
+Array<RunnerFuture> SendToRunner(const Runner& runner, const TuneContext& context,
const Array<MeasureCandidate>& candidates,
const Array<BuilderResult>& builder_results) {
+ LOG(INFO) << "Sending " << candidates.size() << " sample(s) to runner";
Target target = context->target.value();
ICHECK_EQ(candidates.size(), builder_results.size());
int n = candidates.size();
@@ -94,54 +93,60 @@ Array<RunnerFuture> SendToRunner(const Runner& runner, //
void TaskSchedulerNode::InitializeTask(int task_id) {
TuneContext task = this->tasks[task_id];
- // Derive the values.
- IRModule mod = task->mod.value();
- SpaceGenerator space = task->space_generator.value();
- SearchStrategy strategy = task->search_strategy.value();
- // Initialize Modules.
- space->InitializeWithTuneContext(task);
- strategy->InitializeWithTuneContext(task);
+ LOG(INFO) << "Initializing task " << task_id << ": " << task->task_name << ", mod =\n"
+ << tir::AsTVMScript(task->mod);
+ this->tasks[task_id]->Initialize();
}
void TaskSchedulerNode::Tune() {
for (int i = 0; i < static_cast<int>(this->tasks.size()); i++) {
+ TuneContext task = tasks[i];
// Check Optional value validity.
- CHECK(tasks[i]->mod.defined()) << "ValueError: Require `context.mod`, but it is not defined";
- CHECK(tasks[i]->space_generator.defined())
+ CHECK(task->mod.defined()) << "ValueError: Require `context.mod`, but it is not defined";
+ CHECK(task->space_generator.defined())
<< "ValueError: Require `context.space_generator`, but it is not defined";
- CHECK(tasks[i]->search_strategy.defined())
+ CHECK(task->search_strategy.defined())
<< "ValueError: Require `context.search_strategy`, but it is not defined";
-
InitializeTask(i);
-
- tasks[i]->search_strategy.value()->PreTuning(
- tasks[i]->space_generator.value()->GenerateDesignSpace(tasks[i]->mod.value()));
+ Array<tir::Schedule> design_spaces =
+ task->space_generator.value()->GenerateDesignSpace(task->mod.value());
+ LOG(INFO) << "Total " << design_spaces.size() << " design space(s) generated";
+ for (int i = 0, n = design_spaces.size(); i < n; ++i) {
+ tir::Schedule sch = design_spaces[i];
+ tir::Trace trace = sch->trace().value();
+ trace = trace->Simplified(true);
+ LOG(INFO) << "Design space #" << i << ":\n"
+ << tir::AsTVMScript(sch->mod()) << "\n"
+ << Concat(trace->AsPython(false), "\n");
+ }
+ task->search_strategy.value()->PreTuning(design_spaces);
}
int running_tasks = tasks.size();
- while (running_tasks > 0) {
- for (int task_id; (task_id = NextTaskId()) != -1;) {
- TuneContext task = tasks[task_id];
- ICHECK(!task->is_stopped);
- ICHECK(!task->runner_futures.defined());
- SearchStrategy strategy = task->search_strategy.value();
- if ((task->measure_candidates = strategy->GenerateMeasureCandidates()).defined()) {
- Array<BuilderResult> builder_results =
- SendToBuilder(this->builder, task, task->measure_candidates.value());
- task->runner_futures =
- SendToRunner(this->runner, task, task->measure_candidates.value(), builder_results);
- } else {
- SetTaskStopped(task_id);
- --running_tasks;
- }
+ for (int task_id; (task_id = NextTaskId()) != -1;) {
+ LOG(INFO) << "Scheduler picks Task #" << task_id << ": " << tasks[task_id]->task_name;
+ TuneContext task = tasks[task_id];
+ ICHECK(!task->is_stopped);
+ ICHECK(!task->runner_futures.defined());
+ SearchStrategy strategy = task->search_strategy.value();
+ if ((task->measure_candidates = strategy->GenerateMeasureCandidates()).defined()) {
+ Array<BuilderResult> builder_results =
+ SendToBuilder(this->builder, task, task->measure_candidates.value());
+ task->builder_results = builder_results;
+ task->runner_futures =
+ SendToRunner(this->runner, task, task->measure_candidates.value(), builder_results);
+ } else {
+ SetTaskStopped(task_id);
+ --running_tasks;
+ LOG(INFO) << "Task #" << task_id << " has finished. Remaining task(s): " << running_tasks;
}
- int n_tasks = this->tasks.size();
- for (int task_id = 0; task_id < n_tasks; ++task_id)
- if (IsTaskRunning(task_id)) {
- TuneContext task = tasks[task_id];
- this->JoinRunningTask(task_id);
- task->search_strategy.value()->PostTuning();
- }
+ }
+ ICHECK_EQ(running_tasks, 0) << "Not all tasks are finished";
+ int n_tasks = this->tasks.size();
+ for (int task_id = 0; task_id < n_tasks; ++task_id) {
+ ICHECK(!IsTaskRunning(task_id)) << "Task #" << task_id << " is still running";
+ TuneContext task = tasks[task_id];
+ task->search_strategy.value()->PostTuning();
}
}
@@ -175,25 +180,20 @@ void TaskSchedulerNode::JoinRunningTask(int task_id) {
for (const RunnerFuture future : task->runner_futures.value()) {
results.push_back(future->Result());
}
- task->search_strategy.value()->NotifyRunnerResults(results);
- task->runner_futures = NullOpt;
- // Add to database
+ task->search_strategy.value()->NotifyRunnerResults(task, task->measure_candidates.value(),
+ results);
+ // Invoke the callbacks
ICHECK(task->measure_candidates.defined());
- ICHECK(results.size() == task->measure_candidates.value().size());
- int index = 0;
- for (const RunnerResult& result : results) {
- if (!result->error_msg.defined() && result->run_secs.defined()) {
- Optional<tir::Trace> trace = task->measure_candidates.value()[index]->sch->trace();
- ICHECK(trace.defined());
- this->database->CommitTuningRecord(TuningRecord(
- /*trace=*/trace.value(),
- /*run_secs=*/result->run_secs.value(),
- /*workload=*/this->database->CommitWorkload(task->mod.value()),
- /*target=*/task->target.value(),
- /*args_info=*/task->measure_candidates.value()[index]->args_info));
- }
- index++;
+ ICHECK(task->builder_results.defined());
+ ICHECK_EQ(results.size(), task->measure_candidates.value().size());
+ ICHECK_EQ(results.size(), task->builder_results.value().size());
+ for (const MeasureCallback& callback : this->measure_callbacks) {
+ callback->Apply(GetRef<TaskScheduler>(this), task_id, task->measure_candidates.value(),
+ task->builder_results.value(), results);
}
+ task->measure_candidates = NullOpt;
+ task->builder_results = NullOpt;
+ task->runner_futures = NullOpt;
}
TaskScheduler TaskScheduler::PyTaskScheduler(
@@ -201,6 +201,8 @@ TaskScheduler TaskScheduler::PyTaskScheduler(
Builder builder, //
Runner runner, //
Database database, //
+ Optional<CostModel> cost_model, //
+ Optional<Array<MeasureCallback>> measure_callbacks, //
PyTaskSchedulerNode::FTune f_tune, //
PyTaskSchedulerNode::FInitializeTask f_initialize_task, //
PyTaskSchedulerNode::FSetTaskStopped f_set_task_stopped, //
@@ -212,6 +214,12 @@ TaskScheduler TaskScheduler::PyTaskScheduler(
n->builder = builder;
n->runner = runner;
n->database = database;
+ n->cost_model = cost_model;
+ if (measure_callbacks.defined()) {
+ n->measure_callbacks = measure_callbacks.value();
+ } else {
+ n->measure_callbacks = {};
+ }
n->f_tune = f_tune;
n->f_initialize_task = f_initialize_task;
n->f_set_task_stopped = f_set_task_stopped;
diff --git a/src/meta_schedule/tune_context.cc b/src/meta_schedule/tune_context.cc
index ac85d43..f4595d3 100644
--- a/src/meta_schedule/tune_context.cc
+++ b/src/meta_schedule/tune_context.cc
@@ -16,7 +16,6 @@
* specific language governing permissions and limitations
* under the License.
*/
-#include <random>
#include <utility>
#include "./utils.h"
@@ -24,21 +23,13 @@
namespace tvm {
namespace meta_schedule {
-/*!
- * \brief Constructor function of TuneContext class.
- * \param mod The mod to be optimized.
- * \param target The target to be optimized for.
- * \param space_generator The design space generator.
- * \param task_name The name of the tuning task.
- * \param rand_state The random state.
- * \param num_threads The number of threads to be used.
- * \param verbose The verbosity level.
- */
TuneContext::TuneContext(Optional<IRModule> mod, //
Optional<Target> target, //
Optional<SpaceGenerator> space_generator, //
Optional<SearchStrategy> search_strategy, //
Optional<Array<ScheduleRule>> sch_rules, //
+ Optional<Array<Postproc>> postprocs, //
+ Optional<Map<Mutator, FloatImm>> mutator_probs, //
Optional<String> task_name, //
support::LinearCongruentialEngine::TRandState rand_state, //
int num_threads) {
@@ -48,9 +39,11 @@ TuneContext::TuneContext(Optional<IRModule> mod,
n->space_generator = space_generator;
n->search_strategy = search_strategy;
n->sch_rules = sch_rules.value_or({});
+ n->postprocs = postprocs.value_or({});
+ n->mutator_probs = mutator_probs.value_or({});
n->task_name = task_name;
if (rand_state == -1) {
- rand_state = std::random_device()();
+ rand_state = support::LinearCongruentialEngine::DeviceRandom();
}
support::LinearCongruentialEngine(&n->rand_state).Seed(rand_state);
n->num_threads = num_threads;
@@ -60,6 +53,26 @@ TuneContext::TuneContext(Optional<IRModule> mod,
data_ = std::move(n);
}
+void TuneContextNode::Initialize() {
+ if (this->space_generator.defined()) {
+ this->space_generator.value()->InitializeWithTuneContext(GetRef<TuneContext>(this));
+ }
+ if (this->search_strategy.defined()) {
+ this->search_strategy.value()->InitializeWithTuneContext(GetRef<TuneContext>(this));
+ }
+ for (const ScheduleRule& sch_rule : sch_rules) {
+ sch_rule->InitializeWithTuneContext(GetRef<TuneContext>(this));
+ }
+ for (const Postproc& postproc : postprocs) {
+ postproc->InitializeWithTuneContext(GetRef<TuneContext>(this));
+ }
+ if (mutator_probs.defined()) {
+ for (const auto& kv : mutator_probs) {
+ kv.first->InitializeWithTuneContext(GetRef<TuneContext>(this));
+ }
+ }
+}
+
TVM_REGISTER_NODE_TYPE(TuneContextNode);
TVM_REGISTER_GLOBAL("meta_schedule.TuneContext")
@@ -68,11 +81,13 @@ TVM_REGISTER_GLOBAL("meta_schedule.TuneContext")
Optional<SpaceGenerator> space_generator, //
Optional<SearchStrategy> search_strategy, //
Optional<Array<ScheduleRule>> sch_rules, //
+ Optional<Array<Postproc>> postprocs, //
+ Optional<Map<Mutator, FloatImm>> mutator_probs, //
Optional<String> task_name, //
support::LinearCongruentialEngine::TRandState rand_state, //
int num_threads) -> TuneContext {
- return TuneContext(mod, target, space_generator, search_strategy, sch_rules, task_name,
- rand_state, num_threads);
+ return TuneContext(mod, target, space_generator, search_strategy, sch_rules, postprocs,
+ mutator_probs, task_name, rand_state, num_threads);
});
} // namespace meta_schedule
} // namespace tvm
diff --git a/src/meta_schedule/utils.h b/src/meta_schedule/utils.h
index 0a9ce4a..3e989e4 100644
--- a/src/meta_schedule/utils.h
+++ b/src/meta_schedule/utils.h
@@ -200,7 +200,7 @@ inline support::LinearCongruentialEngine::TRandState ForkSeed(
/*!
* \brief Fork a random state into another ones, i.e. PRNG splitting.
- * The given random state is also mutated.
+ * The given random state is also mutated.
* \param rand_state The random state to be forked
* \param n The number of forks
* \return The forked random states
@@ -216,6 +216,15 @@ inline std::vector<support::LinearCongruentialEngine::TRandState> ForkSeed(
}
/*!
+ * \brief Get deep copy of an IRModule.
+ * \param mod The IRModule to make a deep copy.
+ * \return The deep copy of the IRModule.
+ */
+inline IRModule DeepCopyIRModule(IRModule mod) {
+ return Downcast<IRModule>(LoadJSON(SaveJSON(mod)));
+}
+
+/*!
* \brief Concatenate strings
* \param strs The strings to concatenate
* \param delim The delimiter
@@ -233,6 +242,78 @@ inline std::string Concat(const Array<String>& strs, const std::string& delim) {
return os.str();
}
+/*!
+ * \brief A helper data structure that replays a trace and collects failure counts
+ * for each postprocessor
+ */
+struct ThreadedTraceApply {
+ /*! \brief Constructor */
+ explicit ThreadedTraceApply(const Array<Postproc>& postprocs)
+ : n_(postprocs.size()), items_(new Item[n_]) {
+ for (int i = 0; i < n_; ++i) {
+ items_[i].postproc = postprocs[i];
+ items_[i].fail_counter = 0;
+ }
+ }
+
+ /*! \brief Destructor */
+ ~ThreadedTraceApply() { delete[] items_; }
+
+ /*!
+ * \brief Apply the trace and postprocessors to an IRModule
+ * \param mod The IRModule to be applied
+ * \param trace The trace to apply to the IRModule
+ * \param rand_state The random seed
+ * \return The schedule created, or NullOpt if any postprocessor fails
+ */
+ Optional<tir::Schedule> Apply(const IRModule& mod, const tir::Trace& trace,
+ TRandState* rand_state) {
+ tir::Schedule sch =
+ tir::Schedule::Traced(mod,
+ /*rand_state=*/ForkSeed(rand_state),
+ /*debug_mode=*/0,
+ /*error_render_level=*/tir::ScheduleErrorRenderLevel::kNone);
+ trace->ApplyToSchedule(sch, /*remove_postproc=*/true);
+ sch->EnterPostproc();
+ for (int i = 0; i < n_; ++i) {
+ Item& item = items_[i];
+ if (!item.postproc->Apply(sch)) {
+ ++item.fail_counter;
+ return NullOpt;
+ }
+ }
+ return sch;
+ }
+
+ /*! \brief Returns a string summarizing the failures on each postprocessor */
+ std::string SummarizeFailures() const {
+ std::ostringstream os;
+ for (int i = 0; i < n_; ++i) {
+ const Item& item = items_[i];
+ os << "Postproc #" << i << " [" << item.postproc //
+ << "]: " << item.fail_counter.load() << " failure(s)";
+ if (i != n_ - 1) {
+ os << "\n";
+ }
+ }
+ return os.str();
+ }
+
+ private:
+ /*! \brief A helper data structure that stores the fail count for each postprocessor. */
+ struct Item {
+ /*! \brief The postprocessor. */
+ Postproc postproc{nullptr};
+ /*! \brief The thread-safe postprocessor failure counter. */
+ std::atomic<int> fail_counter{0};
+ };
+
+ /*! \brief The number of total postprocessors. */
+ int n_;
+ /*! \brief The pointer to the list of postprocessor items. */
+ Item* items_;
+};
+
} // namespace meta_schedule
} // namespace tvm
diff --git a/src/tir/schedule/concrete_schedule.cc b/src/tir/schedule/concrete_schedule.cc
index 65886da..37d896a 100644
--- a/src/tir/schedule/concrete_schedule.cc
+++ b/src/tir/schedule/concrete_schedule.cc
@@ -30,7 +30,7 @@ Schedule Schedule::Concrete(IRModule mod, support::LinearCongruentialEngine::TRa
n->error_render_level_ = error_render_level;
n->symbol_table_ = {};
n->analyzer_ = std::make_unique<arith::Analyzer>();
- support::LinearCongruentialEngine(&n->rand_state_).Seed(seed);
+ n->Seed(seed);
return Schedule(std::move(n));
}
diff --git a/src/tir/schedule/traced_schedule.cc b/src/tir/schedule/traced_schedule.cc
index 6128366..b4d1ba0 100644
--- a/src/tir/schedule/traced_schedule.cc
+++ b/src/tir/schedule/traced_schedule.cc
@@ -29,7 +29,7 @@ Schedule Schedule::Traced(IRModule mod, support::LinearCongruentialEngine::TRand
n->symbol_table_ = {};
n->analyzer_ = std::make_unique<arith::Analyzer>();
n->trace_ = Trace();
- support::LinearCongruentialEngine(&n->rand_state_).Seed(seed);
+ n->Seed(seed);
return Schedule(std::move(n));
}
diff --git a/tests/python/unittest/test_meta_schedule_cost_model.py b/tests/python/unittest/test_meta_schedule_cost_model.py
index 3f98d71..5b409be 100644
--- a/tests/python/unittest/test_meta_schedule_cost_model.py
+++ b/tests/python/unittest/test_meta_schedule_cost_model.py
@@ -62,15 +62,13 @@ def test_meta_schedule_cost_model():
def update(
self,
- tune_context: TuneContext,
+ context: TuneContext,
candidates: List[MeasureCandidate],
results: List[RunnerResult],
) -> None:
pass
- def predict(
- self, tune_context: TuneContext, candidates: List[MeasureCandidate]
- ) -> np.ndarray:
+ def predict(self, context: TuneContext, candidates: List[MeasureCandidate]) -> np.ndarray:
return np.random.rand(10)
model = FancyCostModel()
@@ -91,15 +89,13 @@ def test_meta_schedule_cost_model_as_string():
def update(
self,
- tune_context: TuneContext,
+ context: TuneContext,
candidates: List[MeasureCandidate],
results: List[RunnerResult],
) -> None:
pass
- def predict(
- self, tune_context: TuneContext, candidates: List[MeasureCandidate]
- ) -> np.ndarray:
+ def predict(self, context: TuneContext, candidates: List[MeasureCandidate]) -> np.ndarray:
return np.random.rand(10)
cost_model = NotSoFancyCostModel()
diff --git a/tests/python/unittest/test_meta_schedule_feature_extractor.py b/tests/python/unittest/test_meta_schedule_feature_extractor.py
index 143d446..d95397b 100644
--- a/tests/python/unittest/test_meta_schedule_feature_extractor.py
+++ b/tests/python/unittest/test_meta_schedule_feature_extractor.py
@@ -28,7 +28,7 @@ def test_meta_schedule_feature_extractor():
class FancyFeatureExtractor(PyFeatureExtractor):
def extract_from(
self,
- tune_context: TuneContext, # pylint: disable = unused-argument
+ context: TuneContext, # pylint: disable = unused-argument
candidates: List[MeasureCandidate], # pylint: disable = unused-argument
) -> List[np.ndarray]:
return [np.random.rand(4, 5)]
@@ -43,7 +43,7 @@ def test_meta_schedule_feature_extractor_as_string():
class NotSoFancyFeatureExtractor(PyFeatureExtractor):
def extract_from(
self,
- tune_context: TuneContext, # pylint: disable = unused-argument
+ context: TuneContext, # pylint: disable = unused-argument
candidates: List[MeasureCandidate], # pylint: disable = unused-argument
) -> List[np.ndarray]:
return []
diff --git a/tests/python/unittest/test_meta_schedule_post_order_apply.py b/tests/python/unittest/test_meta_schedule_post_order_apply.py
index b78e678..78477e6 100644
--- a/tests/python/unittest/test_meta_schedule_post_order_apply.py
+++ b/tests/python/unittest/test_meta_schedule_post_order_apply.py
@@ -135,7 +135,7 @@ def _check_correct(schedule: Schedule):
class WowSoFancyScheduleRule(PyScheduleRule):
- def initialize_with_tune_context(self, tune_context: "TuneContext") -> None:
+ def initialize_with_tune_context(self, context: "TuneContext") -> None:
pass
def apply(self, sch: Schedule, block: BlockRV) -> List[Schedule]:
@@ -151,7 +151,7 @@ class WowSoFancyScheduleRule(PyScheduleRule):
class DoubleScheduleRule(PyScheduleRule):
- def initialize_with_tune_context(self, tune_context: "TuneContext") -> None:
+ def initialize_with_tune_context(self, context: "TuneContext") -> None:
pass
def apply(self, sch: Schedule, block: BlockRV) -> List[Schedule]:
@@ -175,7 +175,7 @@ class DoubleScheduleRule(PyScheduleRule):
class ReorderScheduleRule(PyScheduleRule):
- def initialize_with_tune_context(self, tune_context: "TuneContext") -> None:
+ def initialize_with_tune_context(self, context: "TuneContext") -> None:
pass
def apply(self, sch: Schedule, block: BlockRV) -> List[Schedule]:
@@ -262,7 +262,7 @@ def test_meta_schedule_post_order_apply_duplicate_matmul():
def test_meta_schedule_post_order_apply_remove_block():
class TrinityDouble(PyScheduleRule):
- def initialize_with_tune_context(self, tune_context: "TuneContext") -> None:
+ def initialize_with_tune_context(self, context: "TuneContext") -> None:
pass
def apply(self, sch: Schedule, block: BlockRV) -> List[Schedule]:
@@ -283,7 +283,7 @@ def test_meta_schedule_post_order_apply_remove_block():
return result
class RemoveBlock(PyScheduleRule):
- def initialize_with_tune_context(self, tune_context: "TuneContext") -> None:
+ def initialize_with_tune_context(self, context: "TuneContext") -> None:
pass
def apply(self, sch: Schedule, block: BlockRV) -> List[Schedule]:
diff --git a/tests/python/unittest/test_meta_schedule_search_strategy.py b/tests/python/unittest/test_meta_schedule_search_strategy.py
index 9b3ddfd..668fca9 100644
--- a/tests/python/unittest/test_meta_schedule_search_strategy.py
+++ b/tests/python/unittest/test_meta_schedule_search_strategy.py
@@ -16,25 +16,25 @@
# under the License.
""" Test Meta Schedule SearchStrategy """
# pylint: disable=missing-function-docstring
-from typing import List
-
import sys
-
import pytest
+from typing import List
import tvm
from tvm.meta_schedule import TuneContext
from tvm.meta_schedule.runner import RunnerResult
+from tvm.meta_schedule.search_strategy import (
+ ReplayTrace,
+ SearchStrategy,
+)
from tvm.meta_schedule.space_generator import ScheduleFn
-from tvm.meta_schedule.search_strategy import ReplayTrace
-
from tvm.script import tir as T
from tvm.tir.schedule import Schedule, Trace
MATMUL_M = 32
-# pylint: disable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument, unbalanced-tuple-unpacking
+# pylint: disable=missing-class-docstring,invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument, unbalanced-tuple-unpacking
# fmt: off
@tvm.script.ir_module
@@ -53,48 +53,57 @@ class Matmul:
C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]
# fmt: on
-# pylint: enable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument
+# pylint: enable=missing-class-docstring,invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument
-def _is_trace_equal(sch_1: Schedule, sch_2: Schedule) -> bool:
- trace_1 = Trace(sch_1.trace.insts, {})
- trace_2 = Trace(sch_2.trace.insts, {})
+def _is_trace_equal(sch_1: Schedule, sch_2: Schedule, remove_decisions=True) -> bool:
+ if remove_decisions:
+ trace_1 = Trace(sch_1.trace.insts, {})
+ trace_2 = Trace(sch_2.trace.insts, {})
+ else:
+ trace_1 = sch_1.trace
+ trace_2 = sch_2.trace
return str(trace_1) == str(trace_2)
def _schedule_matmul(sch: Schedule):
block = sch.get_block("matmul")
i, j, k = sch.get_loops(block=block)
- # TODO(@zxybazh): Change to `sample_perfect_tile` after upstreaming
- i_0, i_1, i_2, i_3 = sch.split(loop=i, factors=[2, 4, 64, 2])
- j_0, j_1, j_2, j_3 = sch.split(loop=j, factors=[4, 64, 2, 2])
- k_0, k_1 = sch.split(loop=k, factors=[32, 32])
+ i_0, i_1, i_2, i_3 = sch.split(i, sch.sample_perfect_tile(i, n=4))
+ j_0, j_1, j_2, j_3 = sch.split(j, sch.sample_perfect_tile(j, n=4))
+ k_0, k_1 = sch.split(k, sch.sample_perfect_tile(k, n=2))
sch.reorder(i_0, j_0, i_1, j_1, k_0, i_2, j_2, k_1, i_3, j_3)
-def test_meta_schedule_replay_trace():
+@pytest.mark.parametrize("TestClass", [ReplayTrace])
+def test_meta_schedule_replay_func(TestClass: SearchStrategy): # pylint: disable = invalid-name
num_trials_per_iter = 7
num_trials_total = 20
- (example_sch,) = ScheduleFn(sch_fn=_schedule_matmul).generate_design_space(Matmul)
- replay = ReplayTrace(num_trials_per_iter=num_trials_per_iter, num_trials_total=num_trials_total)
- tune_context = TuneContext(mod=Matmul)
- replay.initialize_with_tune_context(tune_context)
-
- num_trials_each_round: List[int] = []
- replay.pre_tuning([example_sch])
- while True:
- candidates = replay.generate_measure_candidates()
- if candidates is None:
- break
- num_trials_each_round.append(len(candidates))
+ strategy = TestClass(num_trials_per_iter=num_trials_per_iter, num_trials_total=num_trials_total)
+ context = TuneContext(mod=Matmul, space_generator=ScheduleFn(sch_fn=_schedule_matmul))
+ context.space_generator.initialize_with_tune_context(context)
+ spaces = context.space_generator.generate_design_space(context.mod)
+
+ strategy.initialize_with_tune_context(context)
+ strategy.pre_tuning(spaces)
+ (correct_sch,) = ScheduleFn(sch_fn=_schedule_matmul).generate_design_space(Matmul)
+ num_trials_each_iter: List[int] = []
+ candidates = strategy.generate_measure_candidates()
+ while candidates is not None:
+ num_trials_each_iter.append(len(candidates))
runner_results: List[RunnerResult] = []
for candidate in candidates:
- assert _is_trace_equal(candidate.sch, example_sch)
- runner_results.append(RunnerResult(run_secs=[0.5, 0.4, 0.3], error_msg=None))
- replay.notify_runner_results(runner_results)
- replay.post_tuning()
- assert num_trials_each_round == [7, 7, 6]
+ _is_trace_equal(
+ candidate.sch,
+ correct_sch,
+ remove_decisions=(isinstance(strategy, ReplayTrace)),
+ )
+ runner_results.append(RunnerResult(run_secs=[0.11, 0.41, 0.54], error_msg=None))
+ strategy.notify_runner_results(context, candidates, runner_results)
+ candidates = strategy.generate_measure_candidates()
+ strategy.post_tuning()
+ assert num_trials_each_iter == [7, 7, 6]
if __name__ == "__main__":
diff --git a/tests/python/unittest/test_meta_schedule_task_scheduler.py b/tests/python/unittest/test_meta_schedule_task_scheduler.py
index 7eb61ad..d3c4dbc 100644
--- a/tests/python/unittest/test_meta_schedule_task_scheduler.py
+++ b/tests/python/unittest/test_meta_schedule_task_scheduler.py
@@ -16,24 +16,22 @@
# under the License.
""" Test Meta Schedule Task Scheduler """
-from typing import List
-
-import sys
import random
+import sys
+from typing import List
import pytest
-
import tvm
-from tvm.script import tir as T
from tvm.ir import IRModule
-from tvm.tir import Schedule
-from tvm.meta_schedule import TuneContext
-from tvm.meta_schedule.space_generator import ScheduleFn
-from tvm.meta_schedule.search_strategy import ReplayTrace
-from tvm.meta_schedule.builder import PyBuilder, BuilderInput, BuilderResult
-from tvm.meta_schedule.runner import PyRunner, RunnerInput, RunnerFuture, RunnerResult
+from tvm.meta_schedule import TuneContext, measure_callback
+from tvm.meta_schedule.builder import BuilderInput, BuilderResult, PyBuilder
from tvm.meta_schedule.database import PyDatabase, TuningRecord, Workload
-from tvm.meta_schedule.task_scheduler import RoundRobin, PyTaskScheduler
+from tvm.meta_schedule.runner import PyRunner, RunnerFuture, RunnerInput, RunnerResult
+from tvm.meta_schedule.search_strategy import ReplayTrace
+from tvm.meta_schedule.space_generator import ScheduleFn
+from tvm.meta_schedule.task_scheduler import PyTaskScheduler, RoundRobin
+from tvm.script import tir as T
+from tvm.tir import Schedule
# pylint: disable=invalid-name,no-member,line-too-long,too-many-nested-blocks,missing-docstring
@@ -140,7 +138,10 @@ class DummyDatabase(PyDatabase):
self.records = []
self.workload_reg = []
def has_workload(self, mod: IRModule) -> bool:
+ for workload in self.workload_reg:  # True iff a registered workload's module is structurally equal to `mod`
+ if tvm.ir.structural_equal(workload.mod, mod):
+ return True
return False
def commit_tuning_record(self, record: TuningRecord) -> None:
@@ -183,7 +184,13 @@ def test_meta_schedule_task_scheduler_single():
rand_state=42,
)
database = DummyDatabase()
- round_robin = RoundRobin([task], DummyBuilder(), DummyRunner(), database)
+ round_robin = RoundRobin(
+ [task],
+ DummyBuilder(),
+ DummyRunner(),
+ database,
+ measure_callbacks=[measure_callback.AddToDatabase()],
+ )
round_robin.tune()
assert len(database) == num_trials_total
@@ -218,15 +225,29 @@ def test_meta_schedule_task_scheduler_multiple():
),
]
database = DummyDatabase()
- round_robin = RoundRobin(tasks, DummyBuilder(), DummyRunner(), database)
+ round_robin = RoundRobin(
+ tasks,
+ DummyBuilder(),
+ DummyRunner(),
+ database,
+ measure_callbacks=[measure_callback.AddToDatabase()],
+ )
round_robin.tune()
assert len(database) == num_trials_total * len(tasks)
print(database.workload_reg)
for task in tasks:
- assert len(database.get_top_k(database.commit_workload(task.mod), 1e9)) == num_trials_total
+ assert (
+ len(
+ database.get_top_k(
+ database.commit_workload(task.mod),
+ 100000,
+ )
+ )
+ == num_trials_total
+ )
-def test_meta_schedule_task_scheduler_NIE():
+def test_meta_schedule_task_scheduler_not_implemented_error(): # pylint: disable=invalid-name
class MyTaskScheduler(PyTaskScheduler):
pass
@@ -234,7 +255,7 @@ def test_meta_schedule_task_scheduler_NIE():
MyTaskScheduler([], DummyBuilder(), DummyRunner(), DummyDatabase())
-def test_meta_schedule_task_scheduler_override_next_task_id_only():
+def test_meta_schedule_task_scheduler_override_next_task_id_only(): # pylint: disable=invalid-name
class MyTaskScheduler(PyTaskScheduler):
done = set()
@@ -291,11 +312,27 @@ def test_meta_schedule_task_scheduler_override_next_task_id_only():
),
]
database = DummyDatabase()
- scheduler = MyTaskScheduler(tasks, DummyBuilder(), DummyRunner(), database)
+ scheduler = MyTaskScheduler(
+ tasks,
+ DummyBuilder(),
+ DummyRunner(),
+ database,
+ measure_callbacks=[
+ measure_callback.AddToDatabase(),
+ ],
+ )
scheduler.tune()
assert len(database) == num_trials_total * len(tasks)
for task in tasks:
- assert len(database.get_top_k(database.commit_workload(task.mod), 1e9)) == num_trials_total
+ assert (
+ len(
+ database.get_top_k(
+ database.commit_workload(task.mod),
+ 100000,
+ )
+ )
+ == num_trials_total
+ )
if __name__ == "__main__":