You are viewing a plain-text version of this content. The canonical link to it was a hyperlink that is not preserved in this plain-text rendering.
Posted to commits@tvm.apache.org by jc...@apache.org on 2021/06/28 04:32:38 UTC

[tvm] branch main updated: [AutoScheduler]Simplify the code (#8351)

This is an automated email from the ASF dual-hosted git repository.

jcf94 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new c586834  [AutoScheduler]Simplify the code (#8351)
c586834 is described below

commit c58683445000145a3b97322bfdd5d6d2487ef3c5
Author: Swift.Sun <su...@yeah.net>
AuthorDate: Mon Jun 28 12:32:19 2021 +0800

    [AutoScheduler]Simplify the code (#8351)
---
 src/auto_scheduler/search_policy/utils.cc | 101 +++++++++++++-----------------
 1 file changed, 42 insertions(+), 59 deletions(-)

diff --git a/src/auto_scheduler/search_policy/utils.cc b/src/auto_scheduler/search_policy/utils.cc
index ce8dc39..ac1cf2d 100644
--- a/src/auto_scheduler/search_policy/utils.cc
+++ b/src/auto_scheduler/search_policy/utils.cc
@@ -153,24 +153,21 @@ State DoMultiLevelTiling(const State& state, int stage_id, const std::string& fo
   if (spatial_split_step_ids == nullptr) {
     spatial_split_step_ids = &temp_split_step_ids;
   }
+  spatial_split_step_ids->clear();
+
   std::vector<std::vector<Iterator>> space_levels;
   std::vector<std::vector<Iterator>> reduce_levels;
   std::vector<Iterator> space_outer, space_inner, reduce_outer, reduce_inner;
-  Array<Iterator> split_res;
 
-  for (const auto c : format) {
-    if (tolower(c) == 's') {
-      space_levels.emplace_back();
-    } else if (tolower(c) == 'r') {
-      reduce_levels.emplace_back();
-    } else {
-      LOG(FATAL) << "Invalid multi-level tiling format: " << format;
-    }
+  size_t n_space =
+      std::count(format.begin(), format.end(), 's') + std::count(format.begin(), format.end(), 'S');
+  size_t n_reduce =
+      std::count(format.begin(), format.end(), 'r') + std::count(format.begin(), format.end(), 'R');
+  if (n_space + n_reduce != format.size()) {
+    LOG(FATAL) << "Invalid multi-level tiling format: " << format;
   }
-  size_t n_space = space_levels.size();
-  size_t n_reduce = reduce_levels.size();
-
-  spatial_split_step_ids->clear();
+  space_levels.resize(n_space);
+  reduce_levels.resize(n_reduce);
 
   State tmp_s = state;
   const Stage& stage = state->stages[stage_id];
@@ -179,31 +176,28 @@ State DoMultiLevelTiling(const State& state, int stage_id, const std::string& fo
           ? GetIterNameSetParam(stage->op->attrs, SearchPolicyKey::no_split_at_inner)
           : std::set<std::string>();
 
+  auto sr_levels = [&](int size, const Iterator& iter, std::vector<std::vector<Iterator>>& levels) {
+    ICHECK_GE(size, 1);
+    if (size == 1) {
+      levels[0].push_back(iter);
+    } else {
+      Array<Iterator> split_res =
+          tmp_s.split(stage_id, iter, Array<Optional<Integer>>(size - 1, NullOpt));
+      for (int i = 0; i < size; i++) {
+        levels[i].push_back(split_res[i]);
+      }
+      if (iter->iter_kind == IteratorKind::kSpatial) {
+        spatial_split_step_ids->push_back(tmp_s->transform_steps.size() - 1);
+      }
+    }
+  };
+
   for (const auto& iter : state->stages[stage_id]->iters) {
     if (!no_split_at_inner_name_set.count(iter->name)) {
       if (iter->iter_kind == IteratorKind::kSpatial) {
-        ICHECK_GE(n_space, 1);
-
-        if (n_space == 1) {
-          space_levels[0].push_back(iter);
-        } else {
-          split_res = tmp_s.split(stage_id, iter, Array<Optional<Integer>>(n_space - 1, NullOpt));
-          for (size_t i = 0; i < n_space; i++) {
-            space_levels[i].push_back(split_res[i]);
-          }
-          spatial_split_step_ids->push_back(tmp_s->transform_steps.size() - 1);
-        }
+        sr_levels(n_space, iter, space_levels);
       } else if (iter->iter_kind == IteratorKind::kReduction) {
-        ICHECK_GE(n_reduce, 1);
-
-        if (n_reduce == 1) {
-          reduce_levels[0].push_back(iter);
-        } else {
-          split_res = tmp_s.split(stage_id, iter, Array<Optional<Integer>>(n_reduce - 1, NullOpt));
-          for (size_t i = 0; i < n_reduce; i++) {
-            reduce_levels[i].push_back(split_res[i]);
-          }
-        }
+        sr_levels(n_reduce, iter, reduce_levels);
       } else {
         LOG(FATAL) << "Invalid iter type: " << int(iter->iter_kind);
       }
@@ -218,40 +212,29 @@ State DoMultiLevelTiling(const State& state, int stage_id, const std::string& fo
     }
   }
 
-  if (!space_outer.empty()) {
-    ICHECK(!space_levels.empty());
-    space_levels.front().insert(space_levels.front().begin(),
-                                std::make_move_iterator(space_outer.begin()),
-                                std::make_move_iterator(space_outer.end()));
-  }
-  if (!space_inner.empty()) {
-    ICHECK(!space_levels.empty());
-    space_levels.back().insert(space_levels.back().begin(),
-                               std::make_move_iterator(space_inner.begin()),
-                               std::make_move_iterator(space_inner.end()));
-  }
-
-  if (!reduce_outer.empty()) {
-    ICHECK(!reduce_levels.empty());
-    reduce_levels.front().insert(reduce_levels.front().begin(),
-                                 std::make_move_iterator(reduce_outer.begin()),
-                                 std::make_move_iterator(reduce_outer.end()));
+  auto fill_levels = [&](std::vector<Iterator>& levels_iter, std::vector<Iterator>& fill) {
+    if (!fill.empty()) {
+      levels_iter.insert(levels_iter.begin(), std::make_move_iterator(fill.begin()),
+                         std::make_move_iterator(fill.end()));
+    }
+  };
+  if (!space_levels.empty()) {
+    fill_levels(space_levels.front(), space_outer);
+    fill_levels(space_levels.back(), space_inner);
   }
-  if (!reduce_inner.empty()) {
-    ICHECK(!reduce_levels.empty());
-    reduce_levels.back().insert(reduce_levels.back().begin(),
-                                std::make_move_iterator(reduce_inner.begin()),
-                                std::make_move_iterator(reduce_inner.end()));
+  if (!reduce_levels.empty()) {
+    fill_levels(reduce_levels.front(), reduce_outer);
+    fill_levels(reduce_levels.back(), reduce_inner);
   }
 
   Array<Iterator> order;
   int space_ct = 0, reduce_ct = 0;
   for (const auto c : format) {
-    if (tolower(c) == 's') {
+    if (c == 's' || c == 'S') {
       order.insert(order.end(), std::make_move_iterator(space_levels[space_ct].begin()),
                    std::make_move_iterator(space_levels[space_ct].end()));
       space_ct++;
-    } else if (tolower(c) == 'r') {
+    } else if (c == 'r' || c == 'R') {
       order.insert(order.end(), std::make_move_iterator(reduce_levels[reduce_ct].begin()),
                    std::make_move_iterator(reduce_levels[reduce_ct].end()));
       reduce_ct++;