You are viewing a plain-text version of this content; the canonical link to it is not preserved in this copy.
Posted to commits@tvm.apache.org by ju...@apache.org on 2022/06/24 05:10:11 UTC
[tvm] branch main updated: [MetaSchedule] Introduce ArgInfo::FromEntryFunc (#11866)
This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new d2cbdf381b [MetaSchedule] Introduce ArgInfo::FromEntryFunc (#11866)
d2cbdf381b is described below
commit d2cbdf381b68134951bfd7525c6a3a67838e5bdf
Author: Hongyi Jin <32...@qq.com>
AuthorDate: Fri Jun 24 13:10:05 2022 +0800
[MetaSchedule] Introduce ArgInfo::FromEntryFunc (#11866)
---
include/tvm/meta_schedule/arg_info.h | 8 ++++++++
python/tvm/meta_schedule/arg_info.py | 19 +++++++++++++++++
src/meta_schedule/arg_info.cc | 10 +++++++++
src/meta_schedule/database/database.cc | 8 +-------
.../search_strategy/evolutionary_search.cc | 24 ++++++++--------------
src/meta_schedule/search_strategy/replay_func.cc | 6 ++----
src/meta_schedule/search_strategy/replay_trace.cc | 9 ++++----
7 files changed, 53 insertions(+), 31 deletions(-)
diff --git a/include/tvm/meta_schedule/arg_info.h b/include/tvm/meta_schedule/arg_info.h
index c7dd3c7f65..ccf0931262 100644
--- a/include/tvm/meta_schedule/arg_info.h
+++ b/include/tvm/meta_schedule/arg_info.h
@@ -19,6 +19,7 @@
#ifndef TVM_META_SCHEDULE_ARG_INFO_H_
#define TVM_META_SCHEDULE_ARG_INFO_H_
+#include <tvm/ir/module.h>
#include <tvm/node/node.h>
#include <tvm/node/reflection.h>
#include <tvm/runtime/container/shape_tuple.h>
@@ -60,6 +61,13 @@ class ArgInfo : public runtime::ObjectRef {
* \return An array of the argument information derived.
*/
TVM_DLL static Array<ArgInfo, void> FromPrimFunc(const tir::PrimFunc& func);
+ /*!
+ * \brief Extract a list of the argument information from the entry func of an IRModule
+ * \param mod The IRModule to extract argument information from.
+ * \param remove_preproc Whether to remove the preprocessing blocks.
+ * \return An array of the argument information derived.
+ */
+ TVM_DLL static Array<ArgInfo, void> FromEntryFunc(const IRModule& mod, bool remove_preproc);
TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(ArgInfo, runtime::ObjectRef, ArgInfoNode);
diff --git a/python/tvm/meta_schedule/arg_info.py b/python/tvm/meta_schedule/arg_info.py
index a56ca86e8c..7390c544a5 100644
--- a/python/tvm/meta_schedule/arg_info.py
+++ b/python/tvm/meta_schedule/arg_info.py
@@ -18,6 +18,7 @@
from typing import Any, List, Union
from tvm._ffi import register_object
+from tvm.ir import IRModule
from tvm.runtime import DataType, Object, ShapeTuple
from tvm.tir import PrimFunc
@@ -65,6 +66,24 @@ class ArgInfo(Object):
"""
return _ffi_api.ArgInfoFromPrimFunc(func) # type: ignore # pylint: disable=no-member
+ @staticmethod
+ def from_entry_func(mod: IRModule, remove_preproc: bool = True) -> List["ArgInfo"]:
+ """Extract a list of the argument information from the entry func of an IRModule.
+
+ Parameters
+ ----------
+ mod : IRModule
+ The IRModule to get argument information from.
+ remove_preproc : bool
+ Whether to remove the preprocessing blocks.
+
+ Returns
+ -------
+ extracted : List[ArgInfo]
+ An array of the argument information derived.
+ """
+ return _ffi_api.ArgInfoFromEntryFunc(mod, remove_preproc) # type: ignore # pylint: disable=no-member
+
@register_object("meta_schedule.TensorInfo")
class TensorInfo(ArgInfo):
diff --git a/src/meta_schedule/arg_info.cc b/src/meta_schedule/arg_info.cc
index 9b225e8bea..37897a5ac6 100644
--- a/src/meta_schedule/arg_info.cc
+++ b/src/meta_schedule/arg_info.cc
@@ -60,6 +60,15 @@ Array<ArgInfo> ArgInfo::FromPrimFunc(const tir::PrimFunc& func) {
return result;
}
+Array<ArgInfo> ArgInfo::FromEntryFunc(const IRModule& mod, bool remove_preproc) {
+ // TODO(@jinhongyii): add pass for layout rewrite
+ // if (remove_preproc) {
+ // IRModule new_mod = tir::transform::RemoveWeightLayoutRewriteBlock()(mod);
+ // return ArgInfo::FromPrimFunc(FindEntryFunc(new_mod));
+ // }
+ return ArgInfo::FromPrimFunc(FindEntryFunc(mod));
+}
+
/******** TensorInfo ********/
TensorInfo::TensorInfo(runtime::DataType dtype, runtime::ShapeTuple shape) {
@@ -112,6 +121,7 @@ TVM_REGISTER_NODE_TYPE(TensorInfoNode);
TVM_REGISTER_GLOBAL("meta_schedule.ArgInfoAsJSON").set_body_method<ArgInfo>(&ArgInfoNode::AsJSON);
TVM_REGISTER_GLOBAL("meta_schedule.ArgInfoFromPrimFunc").set_body_typed(ArgInfo::FromPrimFunc);
+TVM_REGISTER_GLOBAL("meta_schedule.ArgInfoFromEntryFunc").set_body_typed(ArgInfo::FromEntryFunc);
TVM_REGISTER_GLOBAL("meta_schedule.ArgInfoFromJSON").set_body_typed(ArgInfo::FromJSON);
TVM_REGISTER_GLOBAL("meta_schedule.TensorInfo")
.set_body_typed([](runtime::DataType dtype, runtime::ShapeTuple shape) -> TensorInfo {
diff --git a/src/meta_schedule/database/database.cc b/src/meta_schedule/database/database.cc
index 5adff49984..4e180c4fab 100644
--- a/src/meta_schedule/database/database.cc
+++ b/src/meta_schedule/database/database.cc
@@ -89,13 +89,7 @@ MeasureCandidate TuningRecordNode::AsMeasureCandidate() const {
tir::Schedule sch =
tir::Schedule::Traced(workload->mod, -1, 0, tir::ScheduleErrorRenderLevel::kDetail);
trace->ApplyToSchedule(sch, false, nullptr);
- tir::PrimFunc func;
- for (const auto& kv : sch->mod()->functions) {
- func = Downcast<tir::PrimFunc>(kv.second);
- }
- Array<ArgInfo> args_info = ArgInfo::FromPrimFunc(func);
- MeasureCandidate candidate = MeasureCandidate(sch, args_info);
- return candidate;
+ return MeasureCandidate(sch, ArgInfo::FromEntryFunc(sch->mod(), /*remove_preproc=*/true));
}
ObjectRef TuningRecordNode::AsJSON() const {
diff --git a/src/meta_schedule/search_strategy/evolutionary_search.cc b/src/meta_schedule/search_strategy/evolutionary_search.cc
index 3b672639aa..c5ff9008ef 100644
--- a/src/meta_schedule/search_strategy/evolutionary_search.cc
+++ b/src/meta_schedule/search_strategy/evolutionary_search.cc
@@ -200,12 +200,12 @@ struct ConcurrentBitmask {
* \param traces The picked candidate traces.
* \return The assembled measure candidates.
*/
-Array<MeasureCandidate> AssembleCandidates(const std::vector<Schedule>& picks,
- const Array<ArgInfo>& args_info) {
+Array<MeasureCandidate> AssembleCandidates(const std::vector<Schedule>& picks) {
Array<MeasureCandidate> measure_inputs;
measure_inputs.reserve(picks.size());
for (const Schedule& sch : picks) {
- measure_inputs.push_back(MeasureCandidate(sch, args_info));
+ measure_inputs.push_back(
+ MeasureCandidate(sch, ArgInfo::FromEntryFunc(sch->mod(), /*remove_preproc=*/true)));
}
return measure_inputs;
}
@@ -218,12 +218,11 @@ Array<MeasureCandidate> AssembleCandidates(const std::vector<Schedule>& picks,
* \return The normalized score in the prediction
*/
std::vector<double> PredictNormalizedScore(const std::vector<Schedule>& candidates,
- const TuneContext& context, const CostModel& cost_model,
- const Array<ArgInfo>& args_info) {
+ const TuneContext& context,
+ const CostModel& cost_model) {
auto _ = Profiler::TimedScope("EvoSearch/Evolve/PredictNormalizedScore");
ICHECK(!candidates.empty()) << "Candidates given for score prediction can not be empty list!";
- std::vector<double> scores =
- cost_model->Predict(context, AssembleCandidates(candidates, args_info));
+ std::vector<double> scores = cost_model->Predict(context, AssembleCandidates(candidates));
for (double& score : scores) {
score = std::max(0.0, score);
}
@@ -247,8 +246,6 @@ class EvolutionarySearchNode : public SearchStrategyNode {
int ed;
/*! \brief The counter of returning empty results. */
int num_empty_iters;
- /*! \brief The metadata of the function arguments. */
- Array<ArgInfo> args_info_{nullptr};
/*! \brief Pre thread data including module to be tuned and random state. */
std::vector<PerThreadData> per_thread_data_;
/*!
@@ -272,7 +269,6 @@ class EvolutionarySearchNode : public SearchStrategyNode {
num_empty_iters(0) {
const TuneContextNode* ctx = self->context_;
IRModule mod = ctx->mod.value();
- this->args_info_ = ArgInfo::FromPrimFunc(FindEntryFunc(mod));
this->per_thread_data_.resize(ctx->num_threads);
for (PerThreadData& data : this->per_thread_data_) {
data.mod = DeepCopyIRModule(mod);
@@ -509,10 +505,8 @@ std::vector<Schedule> EvolutionarySearchNode::State::EvolveWithCostModel(
SizedHeap heap(num);
for (int iter = 0;; ++iter) {
// Predict normalized score with the cost model,
- std::vector<double> scores = PredictNormalizedScore(population, //
- GetRef<TuneContext>(self->context_), //
- this->cost_model_, //
- this->args_info_);
+ std::vector<double> scores =
+ PredictNormalizedScore(population, GetRef<TuneContext>(self->context_), this->cost_model_);
{
auto _ = Profiler::TimedScope("EvoSearch/Evolve/Misc");
@@ -695,7 +689,7 @@ Optional<Array<MeasureCandidate>> EvolutionarySearchNode::State::GenerateMeasure
return NullOpt;
}
}
- return AssembleCandidates(picks, this->args_info_);
+ return AssembleCandidates(picks);
}
void EvolutionarySearchNode::State::NotifyRunnerResults(
diff --git a/src/meta_schedule/search_strategy/replay_func.cc b/src/meta_schedule/search_strategy/replay_func.cc
index 24bc38ae80..4574c1c817 100644
--- a/src/meta_schedule/search_strategy/replay_func.cc
+++ b/src/meta_schedule/search_strategy/replay_func.cc
@@ -32,13 +32,10 @@ class ReplayFuncNode : public SearchStrategyNode {
int st;
/*! \brief `[st, ed)` are the indices of the next batch of candidates. */
int ed;
- /*! \brief The metadata of the function arguments. */
- Array<ArgInfo> args_info_{nullptr};
explicit State(ReplayFuncNode* self) : self(self), st(0), ed(self->num_trials_per_iter) {
const TuneContextNode* ctx = self->context_;
ICHECK(ctx);
- this->args_info_ = ArgInfo::FromPrimFunc(FindEntryFunc(ctx->mod.value()));
}
inline Optional<Array<MeasureCandidate>> GenerateMeasureCandidates();
@@ -128,7 +125,8 @@ inline Optional<Array<MeasureCandidate>> ReplayFuncNode::State::GenerateMeasureC
}
}
if (!failed) {
- result.push_back(MeasureCandidate(sch, this->args_info_));
+ Array<ArgInfo> args_info = ArgInfo::FromEntryFunc(sch->mod(), /*remove_preproc=*/true);
+ result.push_back(MeasureCandidate(sch, args_info));
break;
}
}
diff --git a/src/meta_schedule/search_strategy/replay_trace.cc b/src/meta_schedule/search_strategy/replay_trace.cc
index b4b5ef8b31..64fc683943 100644
--- a/src/meta_schedule/search_strategy/replay_trace.cc
+++ b/src/meta_schedule/search_strategy/replay_trace.cc
@@ -37,8 +37,6 @@ class ReplayTraceNode : public SearchStrategyNode {
/*! \brief The module to be tuned. */
Array<IRModule> per_thread_mod_{nullptr};
- /*! \brief The metadata of the function arguments. */
- Array<ArgInfo> args_info_{nullptr};
explicit State(ReplayTraceNode* self, Array<tir::Trace> design_spaces)
: self(self), design_spaces(design_spaces), st(0), ed(self->num_trials_per_iter) {
@@ -49,7 +47,6 @@ class ReplayTraceNode : public SearchStrategyNode {
for (int i = 0; i < ctx->num_threads; i++) {
this->per_thread_mod_.push_back(DeepCopyIRModule(mod));
}
- this->args_info_ = ArgInfo::FromPrimFunc(FindEntryFunc(mod));
}
inline Optional<Array<MeasureCandidate>> GenerateMeasureCandidates();
@@ -143,8 +140,10 @@ inline Optional<Array<MeasureCandidate>> ReplayTraceNode::State::GenerateMeasure
int design_space_index = tir::SampleInt(&rand_state, 0, design_spaces.size());
tir::Trace trace = design_spaces[design_space_index];
tir::Trace new_trace = tir::Trace(trace->insts, {});
- if (Optional<tir::Schedule> sch = pp.Apply(mod, new_trace, &rand_state)) {
- per_task_result.Set(task_id, MeasureCandidate(sch.value(), this->args_info_));
+ if (Optional<tir::Schedule> opt_sch = pp.Apply(mod, new_trace, &rand_state)) {
+ tir::Schedule sch = opt_sch.value();
+ Array<ArgInfo> args_info = ArgInfo::FromEntryFunc(sch->mod(), /*remove_preproc=*/true);
+ per_task_result.Set(task_id, MeasureCandidate(sch, args_info));
break;
}
}