You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by co...@apache.org on 2021/01/19 06:18:37 UTC

[tvm] branch main updated: [AutoScheduler] Bug fix & Custom sketch support (#7260)

This is an automated email from the ASF dual-hosted git repository.

comaniac pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 5d95105  [AutoScheduler] Bug fix & Custom sketch support (#7260)
5d95105 is described below

commit 5d95105a553bccb75c0cd428025d7904d876da0d
Author: Chenfan <ch...@alibaba-inc.com>
AuthorDate: Tue Jan 19 14:18:19 2021 +0800

    [AutoScheduler] Bug fix & Custom sketch support (#7260)
---
 python/tvm/auto_scheduler/__init__.py              |  7 +++-
 python/tvm/auto_scheduler/search_policy.py         | 35 ++++++++++++++++-
 python/tvm/auto_scheduler/search_task.py           |  5 ++-
 python/tvm/auto_scheduler/task_scheduler.py        | 14 ++++++-
 src/auto_scheduler/search_policy/sketch_policy.cc  | 25 ++++++++++++
 src/auto_scheduler/search_policy/sketch_policy.h   | 34 ++++++++++++++++
 .../search_policy/sketch_policy_rules.cc           | 27 +++++++++++++
 .../search_policy/sketch_policy_rules.h            | 23 +++++++++++
 .../unittest/test_auto_scheduler_search_policy.py  | 27 +++++++++++++
 .../test_auto_scheduler_sketch_generation.py       | 45 +++++++++++++++++++++-
 10 files changed, 234 insertions(+), 8 deletions(-)

diff --git a/python/tvm/auto_scheduler/__init__.py b/python/tvm/auto_scheduler/__init__.py
index a03e156..57e5830 100644
--- a/python/tvm/auto_scheduler/__init__.py
+++ b/python/tvm/auto_scheduler/__init__.py
@@ -50,6 +50,11 @@ from .relay_integration import (
     is_auto_scheduler_enabled,
 )
 from .search_task import SearchTask, TuningOptions, HardwareParams, create_task, auto_schedule
-from .search_policy import EmptyPolicy, SketchPolicy, PreloadMeasuredStates
+from .search_policy import (
+    EmptyPolicy,
+    SketchPolicy,
+    PreloadMeasuredStates,
+    PreloadCustomSketchRule,
+)
 from .task_scheduler import TaskScheduler
 from .workload_registry import register_workload, make_workload_key
diff --git a/python/tvm/auto_scheduler/search_policy.py b/python/tvm/auto_scheduler/search_policy.py
index 5b15a48..f0388a8 100644
--- a/python/tvm/auto_scheduler/search_policy.py
+++ b/python/tvm/auto_scheduler/search_policy.py
@@ -61,6 +61,39 @@ class PreloadMeasuredStates(SearchCallback):
         self.__init_handle_by_constructor__(_ffi_api.PreloadMeasuredStates, filename)
 
 
+@tvm._ffi.register_object("auto_scheduler.PreloadCustomSketchRule")
+class PreloadCustomSketchRule(SearchCallback):
+    """
+    A SearchCallback for SketchSearchPolicy that allows users to add
+    a custom sketch rule.
+
+    Notes
+    -----
+    This is an advanced feature. Make sure you understand how it works. It should only be used
+    with SketchSearchPolicy.
+
+    Parameters
+    ----------
+    meet_condition_func: Callable
+        A function with signature `(policy, state, stage_id) -> int`. It should return one of
+        the result enumeration values: PASS, APPLY or APPLY_AND_SKIP_REST.
+    apply_func: Callable
+        A function with signature `(policy, state, stage_id) -> [[State, int], ...]`.
+    rule_name: str = "CustomSketchRule"
+        The name of this custom sketch rule.
+    """
+
+    # Result enumeration of the condition function.
+    PASS = 0  # Skip this rule and continue to try the next rules.
+    APPLY = 1  # Apply this rule and continue to try the next rules.
+    APPLY_AND_SKIP_REST = 2  # Apply this rule and skip the rest rules.
+
+    def __init__(self, meet_condition_func, apply_func, rule_name="CustomSketchRule"):
+        self.__init_handle_by_constructor__(
+            _ffi_api.PreloadCustomSketchRule, meet_condition_func, apply_func, rule_name
+        )
+
+
 @tvm._ffi.register_object("auto_scheduler.SearchPolicy")
 class SearchPolicy(Object):
     """ The base class of search policies. """
@@ -141,8 +174,6 @@ class SketchPolicy(SearchPolicy):
 
           - auto_scheduler.PreloadMeasuredStates
           - auto_scheduler.PreloadCustomSketchRule
-
-        TODO(jcf94): Add these search callback implementations.
     """
 
     DEFAULT_PARAMS = {
diff --git a/python/tvm/auto_scheduler/search_task.py b/python/tvm/auto_scheduler/search_task.py
index bfa596a..d985ed1 100644
--- a/python/tvm/auto_scheduler/search_task.py
+++ b/python/tvm/auto_scheduler/search_task.py
@@ -228,6 +228,9 @@ class SearchTask(Object):
         if isinstance(target_host, str):
             target_host = Target(target_host)
 
+        if layout_rewrite_option is None:
+            layout_rewrite_option = LayoutRewriteOption.get_target_default(target)
+
         self.__init_handle_by_constructor__(
             _ffi_api.SearchTask,
             compute_dag,
@@ -235,7 +238,7 @@ class SearchTask(Object):
             target,
             target_host,
             hardware_params,
-            layout_rewrite_option or LayoutRewriteOption.get_target_default(target),
+            layout_rewrite_option,
         )
 
     def tune(self, tuning_options, search_policy=None):
diff --git a/python/tvm/auto_scheduler/task_scheduler.py b/python/tvm/auto_scheduler/task_scheduler.py
index 975306f..420b5f7 100644
--- a/python/tvm/auto_scheduler/task_scheduler.py
+++ b/python/tvm/auto_scheduler/task_scheduler.py
@@ -72,7 +72,7 @@ def make_search_policies(
         Load measurement records from this file. If it is not None, the status of the
         task scheduler, search policies and cost models will be restored according to this file.
     adapative_training: bool = False
-        Option used for XGBModel, which will reduce the model training frequency when there're too
+        Option used by XGBModel to reduce the model training frequency when there are too
         many logs.
 
     Returns
@@ -275,7 +275,13 @@ class TaskScheduler:
                 self.group_task_ids.append([])
             self.group_task_ids[self.tag_to_group_id[tag]].append(i)
 
-    def tune(self, tune_option, search_policy="default", search_policy_params=None):
+    def tune(
+        self,
+        tune_option,
+        search_policy="default",
+        search_policy_params=None,
+        adapative_training=False,
+    ):
         """Tune a batch of tasks together.
 
         Parameters
@@ -290,6 +296,9 @@ class TaskScheduler:
             "sketch.random" for SketchPolicy + RandomModel.
         search_policy_params : Optional[Dict[str, Any]]
             The parameters of the search policy
+        adapative_training : bool = False
+            Option used by XGBModel to reduce the model training frequency when there are
+            too many logs.
         """
         # init members
         self.tune_option = tune_option
@@ -324,6 +333,7 @@ class TaskScheduler:
             tune_option.verbose,
             self.load_model_file,
             self.load_log_file,
+            adapative_training,
         )
 
         # do a round robin first to warm up
diff --git a/src/auto_scheduler/search_policy/sketch_policy.cc b/src/auto_scheduler/search_policy/sketch_policy.cc
index 1e20b0f..91721af 100644
--- a/src/auto_scheduler/search_policy/sketch_policy.cc
+++ b/src/auto_scheduler/search_policy/sketch_policy.cc
@@ -671,6 +671,26 @@ Array<MeasureInput> SketchPolicyNode::PickStatesWithEpsGreedy(const Array<State>
   return inputs;
 }
 
+/********** PreloadCustomSketchRule **********/
+TVM_REGISTER_OBJECT_TYPE(PreloadCustomSketchRuleNode);
+
+PreloadCustomSketchRule::PreloadCustomSketchRule(PackedFunc meet_condition_func,
+                                                 PackedFunc apply_func, String rule_name) {
+  auto node = make_object<PreloadCustomSketchRuleNode>();
+  node->meet_condition_func = std::move(meet_condition_func);
+  node->apply_func = std::move(apply_func);
+  node->rule_name = std::move(rule_name);
+  data_ = std::move(node);
+}
+
+void PreloadCustomSketchRuleNode::Callback(SearchPolicyNode* policy) {
+  CHECK(policy->IsInstance<SketchPolicyNode>());
+  auto sketch_policy = dynamic_cast<SketchPolicyNode*>(policy);
+  sketch_policy->sketch_rules.push_back(
+      new RuleCustomSketch(meet_condition_func, apply_func, rule_name));
+  StdCout(policy->verbose) << "Custom sketch rule \"" << rule_name << "\" added." << std::endl;
+}
+
 TVM_REGISTER_GLOBAL("auto_scheduler.SketchPolicy")
     .set_body_typed([](SearchTask task, CostModel program_cost_model, Map<String, ObjectRef> params,
                        int seed, int verbose,
@@ -699,5 +719,10 @@ TVM_REGISTER_GLOBAL("auto_scheduler.PrintTitle").set_body_typed([](std::string t
   PrintTitle(title, 1);
 });
 
+TVM_REGISTER_GLOBAL("auto_scheduler.PreloadCustomSketchRule")
+    .set_body_typed([](PackedFunc meet_condition_func, PackedFunc apply_func, String rule_name) {
+      return PreloadCustomSketchRule(meet_condition_func, apply_func, rule_name);
+    });
+
 }  // namespace auto_scheduler
 }  // namespace tvm
diff --git a/src/auto_scheduler/search_policy/sketch_policy.h b/src/auto_scheduler/search_policy/sketch_policy.h
index 4886349..faf058b 100644
--- a/src/auto_scheduler/search_policy/sketch_policy.h
+++ b/src/auto_scheduler/search_policy/sketch_policy.h
@@ -197,6 +197,40 @@ class SketchPolicy : public SearchPolicy {
   TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(SketchPolicy, SearchPolicy, SketchPolicyNode);
 };
 
+/*! \brief Pre-search callback function to load custom rules for sketch generation */
+class PreloadCustomSketchRuleNode : public SearchCallbackNode {
+ public:
+  /*! \brief The condition check function of this rule. */
+  PackedFunc meet_condition_func;
+  /*! \brief The apply function of this rule. */
+  PackedFunc apply_func;
+  /*! \brief The name of this rule. */
+  String rule_name;
+
+  void Callback(SearchPolicyNode* policy) final;
+
+  static constexpr const char* _type_key = "auto_scheduler.PreloadCustomSketchRule";
+  TVM_DECLARE_FINAL_OBJECT_INFO(PreloadCustomSketchRuleNode, SearchCallbackNode);
+};
+
+/*!
+ * \brief Managed reference to PreloadCustomSketchRuleNode.
+ * \sa PreloadCustomSketchRuleNode
+ */
+class PreloadCustomSketchRule : public SearchCallback {
+ public:
+  /*!
+   * \brief The constructor.
+   * \param meet_condition_func The condition check function of this rule.
+   * \param apply_func The apply function of this rule.
+   * \param rule_name The name of this rule.
+   */
+  PreloadCustomSketchRule(PackedFunc meet_condition_func, PackedFunc apply_func, String rule_name);
+
+  TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(PreloadCustomSketchRule, SearchCallback,
+                                        PreloadCustomSketchRuleNode);
+};
+
 }  // namespace auto_scheduler
 }  // namespace tvm
 
diff --git a/src/auto_scheduler/search_policy/sketch_policy_rules.cc b/src/auto_scheduler/search_policy/sketch_policy_rules.cc
index f704fe9..110be6b 100644
--- a/src/auto_scheduler/search_policy/sketch_policy_rules.cc
+++ b/src/auto_scheduler/search_policy/sketch_policy_rules.cc
@@ -461,6 +461,33 @@ std::vector<std::pair<State, int>> RuleSpecialComputeLocationGPU::Apply(
   return {std::make_pair(std::move(tmp_s), stage_id - 1)};
 }
 
+/********** RuleCustomSketch **********/
+
+SketchGenerationRule::ConditionKind RuleCustomSketch::MeetCondition(const SketchPolicyNode& policy,
+                                                                    const State& state,
+                                                                    int stage_id) const {
+  auto ret = meet_condition_func_(tvm::runtime::GetRef<SketchPolicy>(&policy), state, stage_id);
+  if (ret.type_code() == 0) {
+    return ConditionKind(static_cast<int>(ret));
+  } else {
+    LOG(WARNING) << "Wrong rule condition value. Apply the rule and skip the rest";
+    return ConditionKind::kApplyAndSkipRest;
+  }
+}
+
+std::vector<std::pair<State, int>> RuleCustomSketch::Apply(const SketchPolicyNode& policy,
+                                                           const State& state, int stage_id) const {
+  Array<Array<ObjectRef>> apply_ret =
+      apply_func_(tvm::runtime::GetRef<SketchPolicy>(&policy), state, stage_id);
+  std::vector<std::pair<State, int>> ret;
+  for (const auto& item : apply_ret) {
+    CHECK_EQ(item.size(), 2);
+    auto next = item[1].as<IntImmNode>();
+    ret.emplace_back(Downcast<State>(item[0]), next->value);
+  }
+  return ret;
+}
+
 /********** Init Population **********/
 
 PopulationGenerationRule::ResultKind InitFillTileSize::Apply(SketchPolicyNode* policy, State* state,
diff --git a/src/auto_scheduler/search_policy/sketch_policy_rules.h b/src/auto_scheduler/search_policy/sketch_policy_rules.h
index 046f036..fc1916b 100644
--- a/src/auto_scheduler/search_policy/sketch_policy_rules.h
+++ b/src/auto_scheduler/search_policy/sketch_policy_rules.h
@@ -131,6 +131,29 @@ DEFINE_SKETCH_GENERATION_RULE(RuleCrossThreadReduction);
  * location of the producers of compute ops that perform "fake reduction" with const tensors. */
 DEFINE_SKETCH_GENERATION_RULE(RuleSpecialComputeLocationGPU);
 
+/*! \brief The rule that allows users to generate custom sketches. */
+class RuleCustomSketch : public SketchGenerationRule {
+ public:
+  RuleCustomSketch(PackedFunc meet_condition_func, PackedFunc apply_func,
+                   String rule_name = "CustomSketchRule")
+      : meet_condition_func_(std::move(meet_condition_func)),
+        apply_func_(std::move(apply_func)),
+        rule_name_(std::move(rule_name)) {}
+
+  ConditionKind MeetCondition(const SketchPolicyNode& policy, const State& state,
+                              int stage_id) const final;
+
+  std::vector<std::pair<State, int>> Apply(const SketchPolicyNode& policy, const State& state,
+                                           int stage_id) const final;
+
+  std::string GetRuleName() const final { return rule_name_; }
+
+ private:
+  PackedFunc meet_condition_func_;
+  PackedFunc apply_func_;
+  String rule_name_;
+};
+
 /********** Init Population **********/
 
 /*! \brief The base class for rules used to annotate the sketches to get the initial population. */
diff --git a/tests/python/unittest/test_auto_scheduler_search_policy.py b/tests/python/unittest/test_auto_scheduler_search_policy.py
index c96dc63..30aafbd 100644
--- a/tests/python/unittest/test_auto_scheduler_search_policy.py
+++ b/tests/python/unittest/test_auto_scheduler_search_policy.py
@@ -183,6 +183,32 @@ def test_sketch_search_policy_zero_rank():
         search_common(task, runner=measure_ctx.runner)
 
 
+@tvm.testing.requires_llvm
+def test_sketch_search_policy_custom_sketch():
+    def meet_condition_func(search_policy, state, stage_id):
+        return auto_scheduler.PreloadCustomSketchRule.APPLY_AND_SKIP_REST
+
+    def apply_func(search_policy, state, stage_id):
+        ret = []
+        state = auto_scheduler.loop_state.State(state, search_policy.search_task.compute_dag)
+        C = state.stage_ops[2]
+
+        ret.append([state.state_object, -1])
+
+        s1 = state.copy()
+        i, _, _ = s1[C].iters
+        s1.split(C, i, [8])
+        ret.append([s1.state_object, -1])
+        return ret
+
+    search_common(
+        cost_model=auto_scheduler.XGBModel(),
+        init_search_callbacks=[
+            auto_scheduler.PreloadCustomSketchRule(meet_condition_func, apply_func)
+        ],
+    )
+
+
 if __name__ == "__main__":
     test_workload_registry_empty_policy()
     test_sketch_search_policy_basic()
@@ -191,3 +217,4 @@ if __name__ == "__main__":
     test_sketch_search_policy_cuda_rpc_runner()
     test_sketch_search_policy_cuda_xgbmodel_rpc_runner()
     test_sketch_search_policy_zero_rank()
+    test_sketch_search_policy_custom_sketch()
diff --git a/tests/python/unittest/test_auto_scheduler_sketch_generation.py b/tests/python/unittest/test_auto_scheduler_sketch_generation.py
index ddff6dd..f3be6c0 100644
--- a/tests/python/unittest/test_auto_scheduler_sketch_generation.py
+++ b/tests/python/unittest/test_auto_scheduler_sketch_generation.py
@@ -36,9 +36,13 @@ from test_auto_scheduler_common import (
 )
 
 
-def generate_sketches(workload_func, args, target, print_for_debug=False):
+def generate_sketches(
+    workload_func, args, target, print_for_debug=False, init_search_callbacks=None
+):
     task = auto_scheduler.SearchTask(func=workload_func, args=args, target=target)
-    policy = auto_scheduler.SketchPolicy(task, verbose=0)
+    policy = auto_scheduler.SketchPolicy(
+        task, verbose=0, init_search_callbacks=init_search_callbacks
+    )
     return policy.generate_sketches(print_for_debug)
 
 
@@ -259,6 +263,42 @@ def test_cpu_zero_rank_sketch():
     assert len(sketches) == 3
 
 
+def test_cpu_custom_sketch():
+    def meet_condition_func(search_policy, state, stage_id):
+        return auto_scheduler.PreloadCustomSketchRule.APPLY_AND_SKIP_REST
+
+    def apply_func(search_policy, state, stage_id):
+        ret = []
+        state = auto_scheduler.loop_state.State(state, search_policy.search_task.compute_dag)
+        C = state.stage_ops[2]
+
+        ret.append([state.state_object, -1])
+
+        s1 = state.copy()
+        i, _, _ = s1[C].iters
+        s1.split(C, i, [8, 2])
+        ret.append([s1.state_object, -1])
+        return ret
+
+    sketches = generate_sketches(
+        matmul_auto_scheduler_test,
+        (512, 512, 512),
+        "llvm",
+        init_search_callbacks=[
+            auto_scheduler.PreloadCustomSketchRule(meet_condition_func, apply_func)
+        ],
+    )
+    assert len(sketches) == 2
+    assert sketches[0].stages[2].iters[0].range.extent == 512
+    assert sketches[0].stages[2].iters[1].range.extent == 512
+    assert sketches[0].stages[2].iters[2].range.extent == 512
+    assert sketches[1].stages[2].iters[0].range.extent == 32
+    assert sketches[1].stages[2].iters[1].range.extent == 8
+    assert sketches[1].stages[2].iters[2].range.extent == 2
+    assert sketches[1].stages[2].iters[3].range.extent == 512
+    assert sketches[1].stages[2].iters[4].range.extent == 512
+
+
 @tvm.testing.requires_cuda
 def test_cuda_matmul_sketch():
     sketches = generate_sketches(matmul_auto_scheduler_test, (512, 512, 512), "cuda")
@@ -407,6 +447,7 @@ if __name__ == "__main__":
     test_cpu_softmax_sketch()
     test_cpu_conv2d_winograd_sketch()
     test_cpu_zero_rank_sketch()
+    test_cpu_custom_sketch()
     test_cuda_matmul_sketch()
     test_cuda_conv2d_bn_relu_sketch()
     test_cuda_max_pool2d_sketch()