You are viewing a plain-text version of this content; the hyperlink to the canonical version was not preserved in this copy.
Posted to commits@tvm.apache.org by tq...@apache.org on 2020/11/08 13:45:18 UTC

[incubator-tvm] branch main updated: [AutoScheduler] Fix the occasional crash caused by split memo (#6883)

This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new d5afc39  [AutoScheduler] Fix the occasional crash caused by split memo (#6883)
d5afc39 is described below

commit d5afc393782c534ab24b3f3691c28a9a6ec75c0a
Author: Lianmin Zheng <li...@gmail.com>
AuthorDate: Sun Nov 8 05:45:05 2020 -0800

    [AutoScheduler] Fix the occasional crash caused by split memo (#6883)
---
 .../search_policy/sketch_policy_rules.cc           |  6 ++--
 src/auto_scheduler/search_policy/utils.cc          | 42 ++--------------------
 src/auto_scheduler/search_policy/utils.h           | 27 --------------
 3 files changed, 7 insertions(+), 68 deletions(-)

diff --git a/src/auto_scheduler/search_policy/sketch_policy_rules.cc b/src/auto_scheduler/search_policy/sketch_policy_rules.cc
index 692ace1..1c69397 100644
--- a/src/auto_scheduler/search_policy/sketch_policy_rules.cc
+++ b/src/auto_scheduler/search_policy/sketch_policy_rules.cc
@@ -450,6 +450,7 @@ std::vector<std::pair<State, int>> RuleSpecialComputeLocationGPU::Apply(
 
 PopulationGenerationRule::ResultKind InitFillTileSize::Apply(SketchPolicyNode* policy, State* state,
                                                              std::mt19937* rand_gen) const {
+  SplitFactorizationMemo split_memo;
   int max_innermost_split_factor =
       GetIntParam(policy->params, SketchParamKey::max_innermost_split_factor);
 
@@ -470,8 +471,9 @@ PopulationGenerationRule::ResultKind InitFillTileSize::Apply(SketchPolicyNode* p
 
       ICHECK(ps->extent);
       int extent = GetIntImm(ps->extent.value());
-      const auto& candidate_lens = policy->split_memo.GetFactorizationSchemes(
-          extent, ps->lengths.size(), max_innermost_split_factor);
+      const auto& candidate_lens = split_memo.GetFactorizationSchemes(extent, ps->lengths.size(),
+                                                                      max_innermost_split_factor);
+      ICHECK(!candidate_lens.empty());
       const auto& candidate_lengths = candidate_lens[(*rand_gen)() % candidate_lens.size()];
 
       pstate->transform_steps.Set(
diff --git a/src/auto_scheduler/search_policy/utils.cc b/src/auto_scheduler/search_policy/utils.cc
index 3e2f7aa..d59df69 100644
--- a/src/auto_scheduler/search_policy/utils.cc
+++ b/src/auto_scheduler/search_policy/utils.cc
@@ -413,55 +413,19 @@ void PruneInvalidState(const SearchTask& task, Array<State>* states) {
 }
 
 /********** SplitFactorizationMemo **********/
-
-void SplitFactorizationMemo::ReadWriteLock::GetRead() {
-  std::unique_lock<std::mutex> lock(cv_mutex_);
-  // Wake up and get the mutex lock if there's no writing thread
-  cv_.wait(lock, [this]() { return !this->is_writing_; });
-  read_count_++;
-}
-
-void SplitFactorizationMemo::ReadWriteLock::GetWrite() {
-  std::unique_lock<std::mutex> lock(cv_mutex_);
-  // Wake up and get the mutex lock if there's no reading or writing threads
-  cv_.wait(lock, [this]() { return this->read_count_ == 0 && !this->is_writing_; });
-  is_writing_ = true;
-}
-
-void SplitFactorizationMemo::ReadWriteLock::UnlockRead() {
-  std::lock_guard<std::mutex> lock(cv_mutex_);
-  read_count_--;
-  // Notify the other blocked threads if this is the last reading thread
-  if (read_count_ == 0) {
-    cv_.notify_one();
-  }
-}
-
-void SplitFactorizationMemo::ReadWriteLock::UnlockWrite() {
-  std::lock_guard<std::mutex> lock(cv_mutex_);
-  is_writing_ = false;
-  // Notify the other blocked threads
-  cv_.notify_one();
-}
-
 const Array<Array<Integer>>& SplitFactorizationMemo::GetFactorizationSchemes(
     int extent, int n_lengths, int max_innermost_factor) {
   QueryKey key = std::make_tuple(extent, n_lengths, max_innermost_factor);
-  const auto& const_memory = memory_;
-  lock_.GetRead();
-  const auto& it = const_memory.find(key);
-  const auto& memory_end = const_memory.end();
-  lock_.UnlockRead();
-  if (it != memory_end) {
+  const auto& it = memory_.find(key);
+  if (it != memory_.end()) {
     return it->second;
   }
 
-  lock_.GetWrite();
   tmp_stack_ = Array<Integer>(n_lengths, Integer());
   results_ = &memory_[key];
   n_lengths_ = n_lengths;
+
   DfsEnumerate(0, extent, max_innermost_factor);
-  lock_.UnlockWrite();
 
   return *results_;
 }
diff --git a/src/auto_scheduler/search_policy/utils.h b/src/auto_scheduler/search_policy/utils.h
index f0c4cbc..ecc46af 100644
--- a/src/auto_scheduler/search_policy/utils.h
+++ b/src/auto_scheduler/search_policy/utils.h
@@ -677,33 +677,6 @@ class SplitFactorizationMemo {
  private:
   void DfsEnumerate(int now, int remaining_length, int max_innermost_factor);
 
-  /*!
-   * \brief A simple implementation of read-write lock.
-   * The guarded block can be read by multiple threads at the same time, while other operations will
-   * be blocked if one thread is writing.
-   * \note Writing threads will wait until all reading threads have finshed. If there're multiple
-   * writing threads, the process order of them is not guaranteed.
-   */
-  class ReadWriteLock {
-   public:
-    /*! \brief The method to get the read lock. One thread can process read if there's on other
-     * writing threads. */
-    void GetRead();
-    /*! \brief The method to get the write lock. One thread can process write if there's on other
-     * reading or writing threads. */
-    void GetWrite();
-    /*! \brief The method to release the read lock. */
-    void UnlockRead();
-    /*! \brief The method to release the write lock. */
-    void UnlockWrite();
-
-   private:
-    uint32_t read_count_ = 0;
-    bool is_writing_ = false;
-    std::mutex cv_mutex_;
-    std::condition_variable cv_;
-  } lock_;
-
   std::unordered_map<QueryKey, Array<Array<Integer>>> memory_;
 
   int n_lengths_;