You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by ju...@apache.org on 2022/06/24 05:10:11 UTC

[tvm] branch main updated: [MetaSchedule] Introduce ArgInfo::FromEntryFunc (#11866)

This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new d2cbdf381b [MetaSchedule] Introduce ArgInfo::FromEntryFunc (#11866)
d2cbdf381b is described below

commit d2cbdf381b68134951bfd7525c6a3a67838e5bdf
Author: Hongyi Jin <32...@qq.com>
AuthorDate: Fri Jun 24 13:10:05 2022 +0800

    [MetaSchedule] Introduce ArgInfo::FromEntryFunc (#11866)
---
 include/tvm/meta_schedule/arg_info.h               |  8 ++++++++
 python/tvm/meta_schedule/arg_info.py               | 19 +++++++++++++++++
 src/meta_schedule/arg_info.cc                      | 10 +++++++++
 src/meta_schedule/database/database.cc             |  8 +-------
 .../search_strategy/evolutionary_search.cc         | 24 ++++++++--------------
 src/meta_schedule/search_strategy/replay_func.cc   |  6 ++----
 src/meta_schedule/search_strategy/replay_trace.cc  |  9 ++++----
 7 files changed, 53 insertions(+), 31 deletions(-)

diff --git a/include/tvm/meta_schedule/arg_info.h b/include/tvm/meta_schedule/arg_info.h
index c7dd3c7f65..ccf0931262 100644
--- a/include/tvm/meta_schedule/arg_info.h
+++ b/include/tvm/meta_schedule/arg_info.h
@@ -19,6 +19,7 @@
 #ifndef TVM_META_SCHEDULE_ARG_INFO_H_
 #define TVM_META_SCHEDULE_ARG_INFO_H_
 
+#include <tvm/ir/module.h>
 #include <tvm/node/node.h>
 #include <tvm/node/reflection.h>
 #include <tvm/runtime/container/shape_tuple.h>
@@ -60,6 +61,13 @@ class ArgInfo : public runtime::ObjectRef {
    * \return An array of the argument information derived.
    */
   TVM_DLL static Array<ArgInfo, void> FromPrimFunc(const tir::PrimFunc& func);
+  /*!
+   * \brief Extract a list of the argument information from the entry func of an IRModule.
+   * \param mod The IRModule to extract argument information from.
+   * \param remove_preproc Whether to remove the preprocessing blocks.
+   * \return An array of the argument information derived.
+   */
+  TVM_DLL static Array<ArgInfo, void> FromEntryFunc(const IRModule& mod, bool remove_preproc);
 
   TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(ArgInfo, runtime::ObjectRef, ArgInfoNode);
 
diff --git a/python/tvm/meta_schedule/arg_info.py b/python/tvm/meta_schedule/arg_info.py
index a56ca86e8c..7390c544a5 100644
--- a/python/tvm/meta_schedule/arg_info.py
+++ b/python/tvm/meta_schedule/arg_info.py
@@ -18,6 +18,7 @@
 from typing import Any, List, Union
 
 from tvm._ffi import register_object
+from tvm.ir import IRModule
 from tvm.runtime import DataType, Object, ShapeTuple
 from tvm.tir import PrimFunc
 
@@ -65,6 +66,24 @@ class ArgInfo(Object):
         """
         return _ffi_api.ArgInfoFromPrimFunc(func)  # type: ignore # pylint: disable=no-member
 
+    @staticmethod
+    def from_entry_func(mod: IRModule, remove_preproc: bool = True) -> List["ArgInfo"]:
+        """Extract a list of the argument information from the entry func of an IRModule.
+
+        Parameters
+        ----------
+        mod : IRModule
+            The IRModule to get argument information from.
+        remove_preproc : bool
+            Whether to remove the preprocessing blocks.
+
+        Returns
+        -------
+        extracted : List[ArgInfo]
+            An array of the argument information derived.
+        """
+        return _ffi_api.ArgInfoFromEntryFunc(mod, remove_preproc)  # type: ignore # pylint: disable=no-member
+
 
 @register_object("meta_schedule.TensorInfo")
 class TensorInfo(ArgInfo):
diff --git a/src/meta_schedule/arg_info.cc b/src/meta_schedule/arg_info.cc
index 9b225e8bea..37897a5ac6 100644
--- a/src/meta_schedule/arg_info.cc
+++ b/src/meta_schedule/arg_info.cc
@@ -60,6 +60,15 @@ Array<ArgInfo> ArgInfo::FromPrimFunc(const tir::PrimFunc& func) {
   return result;
 }
 
+Array<ArgInfo> ArgInfo::FromEntryFunc(const IRModule& mod, bool remove_preproc) {
+  // TODO(@jinhongyii): add pass for layout rewrite
+  // if (remove_preproc) {
+  //   IRModule new_mod = tir::transform::RemoveWeightLayoutRewriteBlock()(mod);
+  //   return ArgInfo::FromPrimFunc(FindEntryFunc(new_mod));
+  // }
+  return ArgInfo::FromPrimFunc(FindEntryFunc(mod));
+}
+
 /******** TensorInfo ********/
 
 TensorInfo::TensorInfo(runtime::DataType dtype, runtime::ShapeTuple shape) {
@@ -112,6 +121,7 @@ TVM_REGISTER_NODE_TYPE(TensorInfoNode);
 
 TVM_REGISTER_GLOBAL("meta_schedule.ArgInfoAsJSON").set_body_method<ArgInfo>(&ArgInfoNode::AsJSON);
 TVM_REGISTER_GLOBAL("meta_schedule.ArgInfoFromPrimFunc").set_body_typed(ArgInfo::FromPrimFunc);
+TVM_REGISTER_GLOBAL("meta_schedule.ArgInfoFromEntryFunc").set_body_typed(ArgInfo::FromEntryFunc);
 TVM_REGISTER_GLOBAL("meta_schedule.ArgInfoFromJSON").set_body_typed(ArgInfo::FromJSON);
 TVM_REGISTER_GLOBAL("meta_schedule.TensorInfo")
     .set_body_typed([](runtime::DataType dtype, runtime::ShapeTuple shape) -> TensorInfo {
diff --git a/src/meta_schedule/database/database.cc b/src/meta_schedule/database/database.cc
index 5adff49984..4e180c4fab 100644
--- a/src/meta_schedule/database/database.cc
+++ b/src/meta_schedule/database/database.cc
@@ -89,13 +89,7 @@ MeasureCandidate TuningRecordNode::AsMeasureCandidate() const {
   tir::Schedule sch =
       tir::Schedule::Traced(workload->mod, -1, 0, tir::ScheduleErrorRenderLevel::kDetail);
   trace->ApplyToSchedule(sch, false, nullptr);
-  tir::PrimFunc func;
-  for (const auto& kv : sch->mod()->functions) {
-    func = Downcast<tir::PrimFunc>(kv.second);
-  }
-  Array<ArgInfo> args_info = ArgInfo::FromPrimFunc(func);
-  MeasureCandidate candidate = MeasureCandidate(sch, args_info);
-  return candidate;
+  return MeasureCandidate(sch, ArgInfo::FromEntryFunc(sch->mod(), /*remove_preproc=*/true));
 }
 
 ObjectRef TuningRecordNode::AsJSON() const {
diff --git a/src/meta_schedule/search_strategy/evolutionary_search.cc b/src/meta_schedule/search_strategy/evolutionary_search.cc
index 3b672639aa..c5ff9008ef 100644
--- a/src/meta_schedule/search_strategy/evolutionary_search.cc
+++ b/src/meta_schedule/search_strategy/evolutionary_search.cc
@@ -200,12 +200,12 @@ struct ConcurrentBitmask {
 * \param picks The picked candidate schedules.
  * \return The assembled measure candidates.
  */
-Array<MeasureCandidate> AssembleCandidates(const std::vector<Schedule>& picks,
-                                           const Array<ArgInfo>& args_info) {
+Array<MeasureCandidate> AssembleCandidates(const std::vector<Schedule>& picks) {
   Array<MeasureCandidate> measure_inputs;
   measure_inputs.reserve(picks.size());
   for (const Schedule& sch : picks) {
-    measure_inputs.push_back(MeasureCandidate(sch, args_info));
+    measure_inputs.push_back(
+        MeasureCandidate(sch, ArgInfo::FromEntryFunc(sch->mod(), /*remove_preproc=*/true)));
   }
   return measure_inputs;
 }
@@ -218,12 +218,11 @@ Array<MeasureCandidate> AssembleCandidates(const std::vector<Schedule>& picks,
 * \return The normalized score in the prediction.
  */
 std::vector<double> PredictNormalizedScore(const std::vector<Schedule>& candidates,
-                                           const TuneContext& context, const CostModel& cost_model,
-                                           const Array<ArgInfo>& args_info) {
+                                           const TuneContext& context,
+                                           const CostModel& cost_model) {
   auto _ = Profiler::TimedScope("EvoSearch/Evolve/PredictNormalizedScore");
   ICHECK(!candidates.empty()) << "Candidates given for score prediction can not be empty list!";
-  std::vector<double> scores =
-      cost_model->Predict(context, AssembleCandidates(candidates, args_info));
+  std::vector<double> scores = cost_model->Predict(context, AssembleCandidates(candidates));
   for (double& score : scores) {
     score = std::max(0.0, score);
   }
@@ -247,8 +246,6 @@ class EvolutionarySearchNode : public SearchStrategyNode {
     int ed;
     /*! \brief The counter of returning empty results. */
     int num_empty_iters;
-    /*! \brief The metadata of the function arguments. */
-    Array<ArgInfo> args_info_{nullptr};
     /*! \brief Pre thread data including module to be tuned and random state. */
     std::vector<PerThreadData> per_thread_data_;
     /*!
@@ -272,7 +269,6 @@ class EvolutionarySearchNode : public SearchStrategyNode {
           num_empty_iters(0) {
       const TuneContextNode* ctx = self->context_;
       IRModule mod = ctx->mod.value();
-      this->args_info_ = ArgInfo::FromPrimFunc(FindEntryFunc(mod));
       this->per_thread_data_.resize(ctx->num_threads);
       for (PerThreadData& data : this->per_thread_data_) {
         data.mod = DeepCopyIRModule(mod);
@@ -509,10 +505,8 @@ std::vector<Schedule> EvolutionarySearchNode::State::EvolveWithCostModel(
   SizedHeap heap(num);
   for (int iter = 0;; ++iter) {
     // Predict normalized score with the cost model,
-    std::vector<double> scores = PredictNormalizedScore(population,                           //
-                                                        GetRef<TuneContext>(self->context_),  //
-                                                        this->cost_model_,                    //
-                                                        this->args_info_);
+    std::vector<double> scores =
+        PredictNormalizedScore(population, GetRef<TuneContext>(self->context_), this->cost_model_);
 
     {
       auto _ = Profiler::TimedScope("EvoSearch/Evolve/Misc");
@@ -695,7 +689,7 @@ Optional<Array<MeasureCandidate>> EvolutionarySearchNode::State::GenerateMeasure
       return NullOpt;
     }
   }
-  return AssembleCandidates(picks, this->args_info_);
+  return AssembleCandidates(picks);
 }
 
 void EvolutionarySearchNode::State::NotifyRunnerResults(
diff --git a/src/meta_schedule/search_strategy/replay_func.cc b/src/meta_schedule/search_strategy/replay_func.cc
index 24bc38ae80..4574c1c817 100644
--- a/src/meta_schedule/search_strategy/replay_func.cc
+++ b/src/meta_schedule/search_strategy/replay_func.cc
@@ -32,13 +32,10 @@ class ReplayFuncNode : public SearchStrategyNode {
     int st;
     /*! \brief `[st, ed)` are the indices of the next batch of candidates. */
     int ed;
-    /*! \brief The metadata of the function arguments. */
-    Array<ArgInfo> args_info_{nullptr};
 
     explicit State(ReplayFuncNode* self) : self(self), st(0), ed(self->num_trials_per_iter) {
       const TuneContextNode* ctx = self->context_;
       ICHECK(ctx);
-      this->args_info_ = ArgInfo::FromPrimFunc(FindEntryFunc(ctx->mod.value()));
     }
 
     inline Optional<Array<MeasureCandidate>> GenerateMeasureCandidates();
@@ -128,7 +125,8 @@ inline Optional<Array<MeasureCandidate>> ReplayFuncNode::State::GenerateMeasureC
         }
       }
       if (!failed) {
-        result.push_back(MeasureCandidate(sch, this->args_info_));
+        Array<ArgInfo> args_info = ArgInfo::FromEntryFunc(sch->mod(), /*remove_preproc=*/true);
+        result.push_back(MeasureCandidate(sch, args_info));
         break;
       }
     }
diff --git a/src/meta_schedule/search_strategy/replay_trace.cc b/src/meta_schedule/search_strategy/replay_trace.cc
index b4b5ef8b31..64fc683943 100644
--- a/src/meta_schedule/search_strategy/replay_trace.cc
+++ b/src/meta_schedule/search_strategy/replay_trace.cc
@@ -37,8 +37,6 @@ class ReplayTraceNode : public SearchStrategyNode {
 
     /*! \brief The module to be tuned. */
     Array<IRModule> per_thread_mod_{nullptr};
-    /*! \brief The metadata of the function arguments. */
-    Array<ArgInfo> args_info_{nullptr};
 
     explicit State(ReplayTraceNode* self, Array<tir::Trace> design_spaces)
         : self(self), design_spaces(design_spaces), st(0), ed(self->num_trials_per_iter) {
@@ -49,7 +47,6 @@ class ReplayTraceNode : public SearchStrategyNode {
       for (int i = 0; i < ctx->num_threads; i++) {
         this->per_thread_mod_.push_back(DeepCopyIRModule(mod));
       }
-      this->args_info_ = ArgInfo::FromPrimFunc(FindEntryFunc(mod));
     }
 
     inline Optional<Array<MeasureCandidate>> GenerateMeasureCandidates();
@@ -143,8 +140,10 @@ inline Optional<Array<MeasureCandidate>> ReplayTraceNode::State::GenerateMeasure
       int design_space_index = tir::SampleInt(&rand_state, 0, design_spaces.size());
       tir::Trace trace = design_spaces[design_space_index];
       tir::Trace new_trace = tir::Trace(trace->insts, {});
-      if (Optional<tir::Schedule> sch = pp.Apply(mod, new_trace, &rand_state)) {
-        per_task_result.Set(task_id, MeasureCandidate(sch.value(), this->args_info_));
+      if (Optional<tir::Schedule> opt_sch = pp.Apply(mod, new_trace, &rand_state)) {
+        tir::Schedule sch = opt_sch.value();
+        Array<ArgInfo> args_info = ArgInfo::FromEntryFunc(sch->mod(), /*remove_preproc=*/true);
+        per_task_result.Set(task_id, MeasureCandidate(sch, args_info));
         break;
       }
     }