You are viewing a plain-text version of this content; the link to its canonical version was not preserved in the plain-text conversion.
Posted to commits@tvm.apache.org by co...@apache.org on 2021/01/19 06:18:37 UTC
[tvm] branch main updated: [AutoScheduler] Bug fix & Custom sketch support (#7260)
This is an automated email from the ASF dual-hosted git repository.
comaniac pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 5d95105 [AutoScheduler] Bug fix & Custom sketch support (#7260)
5d95105 is described below
commit 5d95105a553bccb75c0cd428025d7904d876da0d
Author: Chenfan <ch...@alibaba-inc.com>
AuthorDate: Tue Jan 19 14:18:19 2021 +0800
[AutoScheduler] Bug fix & Custom sketch support (#7260)
---
python/tvm/auto_scheduler/__init__.py | 7 +++-
python/tvm/auto_scheduler/search_policy.py | 35 ++++++++++++++++-
python/tvm/auto_scheduler/search_task.py | 5 ++-
python/tvm/auto_scheduler/task_scheduler.py | 14 ++++++-
src/auto_scheduler/search_policy/sketch_policy.cc | 25 ++++++++++++
src/auto_scheduler/search_policy/sketch_policy.h | 34 ++++++++++++++++
.../search_policy/sketch_policy_rules.cc | 27 +++++++++++++
.../search_policy/sketch_policy_rules.h | 23 +++++++++++
.../unittest/test_auto_scheduler_search_policy.py | 27 +++++++++++++
.../test_auto_scheduler_sketch_generation.py | 45 +++++++++++++++++++++-
10 files changed, 234 insertions(+), 8 deletions(-)
diff --git a/python/tvm/auto_scheduler/__init__.py b/python/tvm/auto_scheduler/__init__.py
index a03e156..57e5830 100644
--- a/python/tvm/auto_scheduler/__init__.py
+++ b/python/tvm/auto_scheduler/__init__.py
@@ -50,6 +50,11 @@ from .relay_integration import (
is_auto_scheduler_enabled,
)
from .search_task import SearchTask, TuningOptions, HardwareParams, create_task, auto_schedule
-from .search_policy import EmptyPolicy, SketchPolicy, PreloadMeasuredStates
+from .search_policy import (
+ EmptyPolicy,
+ SketchPolicy,
+ PreloadMeasuredStates,
+ PreloadCustomSketchRule,
+)
from .task_scheduler import TaskScheduler
from .workload_registry import register_workload, make_workload_key
diff --git a/python/tvm/auto_scheduler/search_policy.py b/python/tvm/auto_scheduler/search_policy.py
index 5b15a48..f0388a8 100644
--- a/python/tvm/auto_scheduler/search_policy.py
+++ b/python/tvm/auto_scheduler/search_policy.py
@@ -61,6 +61,39 @@ class PreloadMeasuredStates(SearchCallback):
self.__init_handle_by_constructor__(_ffi_api.PreloadMeasuredStates, filename)
+@tvm._ffi.register_object("auto_scheduler.PreloadCustomSketchRule")
+class PreloadCustomSketchRule(SearchCallback):
+ """
+ A SearchCallback for SketchSearchPolicy that allows users to add
+ custom sketch rule.
+
+ Notes
+ -----
+ This is an advanced feature. Make sure you're clear how it works and this should only be used
+ in SketchSearchPolicy.
+
+ Parameters
+ ----------
+ meet_condition_func: Callable
+ A function with `(policy, state, stage_id) -> int`. Should return one of the result
+ enumeration.
+ apply_func: Callable
+ A function with `(policy, state, stage_id) -> [[State, int], ...]`.
+ rule_name: str = "CustomSketchRule"
+ The name of this custom sketch rule.
+ """
+
+ # Result enumeration of the condition function.
+ PASS = 0 # Skip this rule and continue to try the next rules.
+ APPLY = 1 # Apply this rule and continue to try the next rules.
+ APPLY_AND_SKIP_REST = 2 # Apply this rule and skip the rest rules.
+
+ def __init__(self, meet_condition_func, apply_func, rule_name="CustomSketchRule"):
+ self.__init_handle_by_constructor__(
+ _ffi_api.PreloadCustomSketchRule, meet_condition_func, apply_func, rule_name
+ )
+
+
@tvm._ffi.register_object("auto_scheduler.SearchPolicy")
class SearchPolicy(Object):
""" The base class of search policies. """
@@ -141,8 +174,6 @@ class SketchPolicy(SearchPolicy):
- auto_scheduler.PreloadMeasuredStates
- auto_scheduler.PreloadCustomSketchRule
-
- TODO(jcf94): Add these search callback implementations.
"""
DEFAULT_PARAMS = {
diff --git a/python/tvm/auto_scheduler/search_task.py b/python/tvm/auto_scheduler/search_task.py
index bfa596a..d985ed1 100644
--- a/python/tvm/auto_scheduler/search_task.py
+++ b/python/tvm/auto_scheduler/search_task.py
@@ -228,6 +228,9 @@ class SearchTask(Object):
if isinstance(target_host, str):
target_host = Target(target_host)
+ if layout_rewrite_option is None:
+ layout_rewrite_option = LayoutRewriteOption.get_target_default(target)
+
self.__init_handle_by_constructor__(
_ffi_api.SearchTask,
compute_dag,
@@ -235,7 +238,7 @@ class SearchTask(Object):
target,
target_host,
hardware_params,
- layout_rewrite_option or LayoutRewriteOption.get_target_default(target),
+ layout_rewrite_option,
)
def tune(self, tuning_options, search_policy=None):
diff --git a/python/tvm/auto_scheduler/task_scheduler.py b/python/tvm/auto_scheduler/task_scheduler.py
index 975306f..420b5f7 100644
--- a/python/tvm/auto_scheduler/task_scheduler.py
+++ b/python/tvm/auto_scheduler/task_scheduler.py
@@ -72,7 +72,7 @@ def make_search_policies(
Load measurement records from this file. If it is not None, the status of the
task scheduler, search policies and cost models will be restored according to this file.
adapative_training: bool = False
- Option used for XGBModel, which will reduce the model training frequency when there're too
+ Option used by XGBModel to reduce the model training frequency when there're too
many logs.
Returns
@@ -275,7 +275,13 @@ class TaskScheduler:
self.group_task_ids.append([])
self.group_task_ids[self.tag_to_group_id[tag]].append(i)
- def tune(self, tune_option, search_policy="default", search_policy_params=None):
+ def tune(
+ self,
+ tune_option,
+ search_policy="default",
+ search_policy_params=None,
+ adapative_training=False,
+ ):
"""Tune a batch of tasks together.
Parameters
@@ -290,6 +296,9 @@ class TaskScheduler:
"sketch.random" for SketchPolicy + RandomModel.
search_policy_params : Optional[Dict[str, Any]]
The parameters of the search policy
+ adapative_training : bool = False
+ Option used by XGBModel to reduce the model training frequency when there're
+ too many logs.
"""
# init members
self.tune_option = tune_option
@@ -324,6 +333,7 @@ class TaskScheduler:
tune_option.verbose,
self.load_model_file,
self.load_log_file,
+ adapative_training,
)
# do a round robin first to warm up
diff --git a/src/auto_scheduler/search_policy/sketch_policy.cc b/src/auto_scheduler/search_policy/sketch_policy.cc
index 1e20b0f..91721af 100644
--- a/src/auto_scheduler/search_policy/sketch_policy.cc
+++ b/src/auto_scheduler/search_policy/sketch_policy.cc
@@ -671,6 +671,26 @@ Array<MeasureInput> SketchPolicyNode::PickStatesWithEpsGreedy(const Array<State>
return inputs;
}
+/********** PreloadCustomSketchRule **********/
+TVM_REGISTER_OBJECT_TYPE(PreloadCustomSketchRuleNode);
+
+PreloadCustomSketchRule::PreloadCustomSketchRule(PackedFunc meet_condition_func,
+ PackedFunc apply_func, String rule_name) {
+ auto node = make_object<PreloadCustomSketchRuleNode>();
+ node->meet_condition_func = std::move(meet_condition_func);
+ node->apply_func = std::move(apply_func);
+ node->rule_name = std::move(rule_name);
+ data_ = std::move(node);
+}
+
+void PreloadCustomSketchRuleNode::Callback(SearchPolicyNode* policy) {
+ CHECK(policy->IsInstance<SketchPolicyNode>());
+ auto sketch_policy = dynamic_cast<SketchPolicyNode*>(policy);
+ sketch_policy->sketch_rules.push_back(
+ new RuleCustomSketch(meet_condition_func, apply_func, rule_name));
+ StdCout(policy->verbose) << "Custom sketch rule \"" << rule_name << "\" added." << std::endl;
+}
+
TVM_REGISTER_GLOBAL("auto_scheduler.SketchPolicy")
.set_body_typed([](SearchTask task, CostModel program_cost_model, Map<String, ObjectRef> params,
int seed, int verbose,
@@ -699,5 +719,10 @@ TVM_REGISTER_GLOBAL("auto_scheduler.PrintTitle").set_body_typed([](std::string t
PrintTitle(title, 1);
});
+TVM_REGISTER_GLOBAL("auto_scheduler.PreloadCustomSketchRule")
+ .set_body_typed([](PackedFunc meet_condition_func, PackedFunc apply_func, String rule_name) {
+ return PreloadCustomSketchRule(meet_condition_func, apply_func, rule_name);
+ });
+
} // namespace auto_scheduler
} // namespace tvm
diff --git a/src/auto_scheduler/search_policy/sketch_policy.h b/src/auto_scheduler/search_policy/sketch_policy.h
index 4886349..faf058b 100644
--- a/src/auto_scheduler/search_policy/sketch_policy.h
+++ b/src/auto_scheduler/search_policy/sketch_policy.h
@@ -197,6 +197,40 @@ class SketchPolicy : public SearchPolicy {
TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(SketchPolicy, SearchPolicy, SketchPolicyNode);
};
+/*! \brief Pre-search callback function to load custom rules for sketch generation */
+class PreloadCustomSketchRuleNode : public SearchCallbackNode {
+ public:
+ /*! \brief The condition check function of this rule. */
+ PackedFunc meet_condition_func;
+ /*! \brief The apply function of this rule. */
+ PackedFunc apply_func;
+ /*! \brief The name of this rule. */
+ String rule_name;
+
+ void Callback(SearchPolicyNode* policy) final;
+
+ static constexpr const char* _type_key = "auto_scheduler.PreloadCustomSketchRule";
+ TVM_DECLARE_FINAL_OBJECT_INFO(PreloadCustomSketchRuleNode, SearchCallbackNode);
+};
+
+/*!
+ * \brief Managed reference to PreloadCustomSketchRuleNode.
+ * \sa PreloadCustomSketchRuleNode
+ */
+class PreloadCustomSketchRule : public SearchCallback {
+ public:
+ /*!
+ * \brief The constructor.
+ * \param meet_condition_func The condition check function of this rule.
+ * \param apply_func The apply function of this rule.
+ * \param rule_name The name of this rule.
+ */
+ PreloadCustomSketchRule(PackedFunc meet_condition_func, PackedFunc apply_func, String rule_name);
+
+ TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(PreloadCustomSketchRule, SearchCallback,
+ PreloadCustomSketchRuleNode);
+};
+
} // namespace auto_scheduler
} // namespace tvm
diff --git a/src/auto_scheduler/search_policy/sketch_policy_rules.cc b/src/auto_scheduler/search_policy/sketch_policy_rules.cc
index f704fe9..110be6b 100644
--- a/src/auto_scheduler/search_policy/sketch_policy_rules.cc
+++ b/src/auto_scheduler/search_policy/sketch_policy_rules.cc
@@ -461,6 +461,33 @@ std::vector<std::pair<State, int>> RuleSpecialComputeLocationGPU::Apply(
return {std::make_pair(std::move(tmp_s), stage_id - 1)};
}
+/********** RuleCustomSketch **********/
+
+SketchGenerationRule::ConditionKind RuleCustomSketch::MeetCondition(const SketchPolicyNode& policy,
+ const State& state,
+ int stage_id) const {
+ auto ret = meet_condition_func_(tvm::runtime::GetRef<SketchPolicy>(&policy), state, stage_id);
+ if (ret.type_code() == 0) {
+ return ConditionKind(static_cast<int>(ret));
+ } else {
+ LOG(WARNING) << "Wrong rule condition value. Apply the rule and skip the rest";
+ return ConditionKind::kApplyAndSkipRest;
+ }
+}
+
+std::vector<std::pair<State, int>> RuleCustomSketch::Apply(const SketchPolicyNode& policy,
+ const State& state, int stage_id) const {
+ Array<Array<ObjectRef>> apply_ret =
+ apply_func_(tvm::runtime::GetRef<SketchPolicy>(&policy), state, stage_id);
+ std::vector<std::pair<State, int>> ret;
+ for (const auto& item : apply_ret) {
+ CHECK_EQ(item.size(), 2);
+ auto next = item[1].as<IntImmNode>();
+ ret.emplace_back(Downcast<State>(item[0]), next->value);
+ }
+ return ret;
+}
+
/********** Init Population **********/
PopulationGenerationRule::ResultKind InitFillTileSize::Apply(SketchPolicyNode* policy, State* state,
diff --git a/src/auto_scheduler/search_policy/sketch_policy_rules.h b/src/auto_scheduler/search_policy/sketch_policy_rules.h
index 046f036..fc1916b 100644
--- a/src/auto_scheduler/search_policy/sketch_policy_rules.h
+++ b/src/auto_scheduler/search_policy/sketch_policy_rules.h
@@ -131,6 +131,29 @@ DEFINE_SKETCH_GENERATION_RULE(RuleCrossThreadReduction);
* location of the producers of compute ops that perform "fake reduction" with const tensors. */
DEFINE_SKETCH_GENERATION_RULE(RuleSpecialComputeLocationGPU);
+/*! \brief The rule that allows users to generate custom sketches. */
+class RuleCustomSketch : public SketchGenerationRule {
+ public:
+ RuleCustomSketch(PackedFunc meet_condition_func, PackedFunc apply_func,
+ String rule_name = "CustomSketchRule")
+ : meet_condition_func_(std::move(meet_condition_func)),
+ apply_func_(std::move(apply_func)),
+ rule_name_(std::move(rule_name)) {}
+
+ ConditionKind MeetCondition(const SketchPolicyNode& policy, const State& state,
+ int stage_id) const final;
+
+ std::vector<std::pair<State, int>> Apply(const SketchPolicyNode& policy, const State& state,
+ int stage_id) const final;
+
+ std::string GetRuleName() const final { return rule_name_; }
+
+ private:
+ PackedFunc meet_condition_func_;
+ PackedFunc apply_func_;
+ String rule_name_;
+};
+
/********** Init Population **********/
/*! \brief The base class for rules used to annotate the sketches to get the initial population. */
diff --git a/tests/python/unittest/test_auto_scheduler_search_policy.py b/tests/python/unittest/test_auto_scheduler_search_policy.py
index c96dc63..30aafbd 100644
--- a/tests/python/unittest/test_auto_scheduler_search_policy.py
+++ b/tests/python/unittest/test_auto_scheduler_search_policy.py
@@ -183,6 +183,32 @@ def test_sketch_search_policy_zero_rank():
search_common(task, runner=measure_ctx.runner)
+@tvm.testing.requires_llvm
+def test_sketch_search_policy_custom_sketch():
+ def meet_condition_func(search_policy, state, stage_id):
+ return auto_scheduler.PreloadCustomSketchRule.APPLY_AND_SKIP_REST
+
+ def apply_func(search_policy, state, stage_id):
+ ret = []
+ state = auto_scheduler.loop_state.State(state, search_policy.search_task.compute_dag)
+ C = state.stage_ops[2]
+
+ ret.append([state.state_object, -1])
+
+ s1 = state.copy()
+ i, _, _ = s1[C].iters
+ s1.split(C, i, [8])
+ ret.append([s1.state_object, -1])
+ return ret
+
+ search_common(
+ cost_model=auto_scheduler.XGBModel(),
+ init_search_callbacks=[
+ auto_scheduler.PreloadCustomSketchRule(meet_condition_func, apply_func)
+ ],
+ )
+
+
if __name__ == "__main__":
test_workload_registry_empty_policy()
test_sketch_search_policy_basic()
@@ -191,3 +217,4 @@ if __name__ == "__main__":
test_sketch_search_policy_cuda_rpc_runner()
test_sketch_search_policy_cuda_xgbmodel_rpc_runner()
test_sketch_search_policy_zero_rank()
+ test_sketch_search_policy_custom_sketch()
diff --git a/tests/python/unittest/test_auto_scheduler_sketch_generation.py b/tests/python/unittest/test_auto_scheduler_sketch_generation.py
index ddff6dd..f3be6c0 100644
--- a/tests/python/unittest/test_auto_scheduler_sketch_generation.py
+++ b/tests/python/unittest/test_auto_scheduler_sketch_generation.py
@@ -36,9 +36,13 @@ from test_auto_scheduler_common import (
)
-def generate_sketches(workload_func, args, target, print_for_debug=False):
+def generate_sketches(
+ workload_func, args, target, print_for_debug=False, init_search_callbacks=None
+):
task = auto_scheduler.SearchTask(func=workload_func, args=args, target=target)
- policy = auto_scheduler.SketchPolicy(task, verbose=0)
+ policy = auto_scheduler.SketchPolicy(
+ task, verbose=0, init_search_callbacks=init_search_callbacks
+ )
return policy.generate_sketches(print_for_debug)
@@ -259,6 +263,42 @@ def test_cpu_zero_rank_sketch():
assert len(sketches) == 3
+def test_cpu_custom_sketch():
+ def meet_condition_func(search_policy, state, stage_id):
+ return auto_scheduler.PreloadCustomSketchRule.APPLY_AND_SKIP_REST
+
+ def apply_func(search_policy, state, stage_id):
+ ret = []
+ state = auto_scheduler.loop_state.State(state, search_policy.search_task.compute_dag)
+ C = state.stage_ops[2]
+
+ ret.append([state.state_object, -1])
+
+ s1 = state.copy()
+ i, _, _ = s1[C].iters
+ s1.split(C, i, [8, 2])
+ ret.append([s1.state_object, -1])
+ return ret
+
+ sketches = generate_sketches(
+ matmul_auto_scheduler_test,
+ (512, 512, 512),
+ "llvm",
+ init_search_callbacks=[
+ auto_scheduler.PreloadCustomSketchRule(meet_condition_func, apply_func)
+ ],
+ )
+ assert len(sketches) == 2
+ assert sketches[0].stages[2].iters[0].range.extent == 512
+ assert sketches[0].stages[2].iters[1].range.extent == 512
+ assert sketches[0].stages[2].iters[2].range.extent == 512
+ assert sketches[1].stages[2].iters[0].range.extent == 32
+ assert sketches[1].stages[2].iters[1].range.extent == 8
+ assert sketches[1].stages[2].iters[2].range.extent == 2
+ assert sketches[1].stages[2].iters[3].range.extent == 512
+ assert sketches[1].stages[2].iters[4].range.extent == 512
+
+
@tvm.testing.requires_cuda
def test_cuda_matmul_sketch():
sketches = generate_sketches(matmul_auto_scheduler_test, (512, 512, 512), "cuda")
@@ -407,6 +447,7 @@ if __name__ == "__main__":
test_cpu_softmax_sketch()
test_cpu_conv2d_winograd_sketch()
test_cpu_zero_rank_sketch()
+ test_cpu_custom_sketch()
test_cuda_matmul_sketch()
test_cuda_conv2d_bn_relu_sketch()
test_cuda_max_pool2d_sketch()