You are viewing a plain-text version of this content. The canonical link for it is here (the hyperlink itself is not preserved in this plain-text copy).
Posted to commits@tvm.apache.org by jc...@apache.org on 2021/06/28 04:32:38 UTC
[tvm] branch main updated: [AutoScheduler]Simplify the code (#8351)
This is an automated email from the ASF dual-hosted git repository.
jcf94 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new c586834 [AutoScheduler]Simplify the code (#8351)
c586834 is described below
commit c58683445000145a3b97322bfdd5d6d2487ef3c5
Author: Swift.Sun <su...@yeah.net>
AuthorDate: Mon Jun 28 12:32:19 2021 +0800
[AutoScheduler]Simplify the code (#8351)
---
src/auto_scheduler/search_policy/utils.cc | 101 +++++++++++++-----------------
1 file changed, 42 insertions(+), 59 deletions(-)
diff --git a/src/auto_scheduler/search_policy/utils.cc b/src/auto_scheduler/search_policy/utils.cc
index ce8dc39..ac1cf2d 100644
--- a/src/auto_scheduler/search_policy/utils.cc
+++ b/src/auto_scheduler/search_policy/utils.cc
@@ -153,24 +153,21 @@ State DoMultiLevelTiling(const State& state, int stage_id, const std::string& fo
if (spatial_split_step_ids == nullptr) {
spatial_split_step_ids = &temp_split_step_ids;
}
+ spatial_split_step_ids->clear();
+
std::vector<std::vector<Iterator>> space_levels;
std::vector<std::vector<Iterator>> reduce_levels;
std::vector<Iterator> space_outer, space_inner, reduce_outer, reduce_inner;
- Array<Iterator> split_res;
- for (const auto c : format) {
- if (tolower(c) == 's') {
- space_levels.emplace_back();
- } else if (tolower(c) == 'r') {
- reduce_levels.emplace_back();
- } else {
- LOG(FATAL) << "Invalid multi-level tiling format: " << format;
- }
+ size_t n_space =
+ std::count(format.begin(), format.end(), 's') + std::count(format.begin(), format.end(), 'S');
+ size_t n_reduce =
+ std::count(format.begin(), format.end(), 'r') + std::count(format.begin(), format.end(), 'R');
+ if (n_space + n_reduce != format.size()) {
+ LOG(FATAL) << "Invalid multi-level tiling format: " << format;
}
- size_t n_space = space_levels.size();
- size_t n_reduce = reduce_levels.size();
-
- spatial_split_step_ids->clear();
+ space_levels.resize(n_space);
+ reduce_levels.resize(n_reduce);
State tmp_s = state;
const Stage& stage = state->stages[stage_id];
@@ -179,31 +176,28 @@ State DoMultiLevelTiling(const State& state, int stage_id, const std::string& fo
? GetIterNameSetParam(stage->op->attrs, SearchPolicyKey::no_split_at_inner)
: std::set<std::string>();
+ auto sr_levels = [&](int size, const Iterator& iter, std::vector<std::vector<Iterator>>& levels) {
+ ICHECK_GE(size, 1);
+ if (size == 1) {
+ levels[0].push_back(iter);
+ } else {
+ Array<Iterator> split_res =
+ tmp_s.split(stage_id, iter, Array<Optional<Integer>>(size - 1, NullOpt));
+ for (int i = 0; i < size; i++) {
+ levels[i].push_back(split_res[i]);
+ }
+ if (iter->iter_kind == IteratorKind::kSpatial) {
+ spatial_split_step_ids->push_back(tmp_s->transform_steps.size() - 1);
+ }
+ }
+ };
+
for (const auto& iter : state->stages[stage_id]->iters) {
if (!no_split_at_inner_name_set.count(iter->name)) {
if (iter->iter_kind == IteratorKind::kSpatial) {
- ICHECK_GE(n_space, 1);
-
- if (n_space == 1) {
- space_levels[0].push_back(iter);
- } else {
- split_res = tmp_s.split(stage_id, iter, Array<Optional<Integer>>(n_space - 1, NullOpt));
- for (size_t i = 0; i < n_space; i++) {
- space_levels[i].push_back(split_res[i]);
- }
- spatial_split_step_ids->push_back(tmp_s->transform_steps.size() - 1);
- }
+ sr_levels(n_space, iter, space_levels);
} else if (iter->iter_kind == IteratorKind::kReduction) {
- ICHECK_GE(n_reduce, 1);
-
- if (n_reduce == 1) {
- reduce_levels[0].push_back(iter);
- } else {
- split_res = tmp_s.split(stage_id, iter, Array<Optional<Integer>>(n_reduce - 1, NullOpt));
- for (size_t i = 0; i < n_reduce; i++) {
- reduce_levels[i].push_back(split_res[i]);
- }
- }
+ sr_levels(n_reduce, iter, reduce_levels);
} else {
LOG(FATAL) << "Invalid iter type: " << int(iter->iter_kind);
}
@@ -218,40 +212,29 @@ State DoMultiLevelTiling(const State& state, int stage_id, const std::string& fo
}
}
- if (!space_outer.empty()) {
- ICHECK(!space_levels.empty());
- space_levels.front().insert(space_levels.front().begin(),
- std::make_move_iterator(space_outer.begin()),
- std::make_move_iterator(space_outer.end()));
- }
- if (!space_inner.empty()) {
- ICHECK(!space_levels.empty());
- space_levels.back().insert(space_levels.back().begin(),
- std::make_move_iterator(space_inner.begin()),
- std::make_move_iterator(space_inner.end()));
- }
-
- if (!reduce_outer.empty()) {
- ICHECK(!reduce_levels.empty());
- reduce_levels.front().insert(reduce_levels.front().begin(),
- std::make_move_iterator(reduce_outer.begin()),
- std::make_move_iterator(reduce_outer.end()));
+ auto fill_levels = [&](std::vector<Iterator>& levels_iter, std::vector<Iterator>& fill) {
+ if (!fill.empty()) {
+ levels_iter.insert(levels_iter.begin(), std::make_move_iterator(fill.begin()),
+ std::make_move_iterator(fill.end()));
+ }
+ };
+ if (!space_levels.empty()) {
+ fill_levels(space_levels.front(), space_outer);
+ fill_levels(space_levels.back(), space_inner);
}
- if (!reduce_inner.empty()) {
- ICHECK(!reduce_levels.empty());
- reduce_levels.back().insert(reduce_levels.back().begin(),
- std::make_move_iterator(reduce_inner.begin()),
- std::make_move_iterator(reduce_inner.end()));
+ if (!reduce_levels.empty()) {
+ fill_levels(reduce_levels.front(), reduce_outer);
+ fill_levels(reduce_levels.back(), reduce_inner);
}
Array<Iterator> order;
int space_ct = 0, reduce_ct = 0;
for (const auto c : format) {
- if (tolower(c) == 's') {
+ if (c == 's' || c == 'S') {
order.insert(order.end(), std::make_move_iterator(space_levels[space_ct].begin()),
std::make_move_iterator(space_levels[space_ct].end()));
space_ct++;
- } else if (tolower(c) == 'r') {
+ } else if (c == 'r' || c == 'R') {
order.insert(order.end(), std::make_move_iterator(reduce_levels[reduce_ct].begin()),
std::make_move_iterator(reduce_levels[reduce_ct].end()));
reduce_ct++;