You are viewing a plain-text version of this content; the canonical link to the original was not preserved in this copy.
Posted to commits@tvm.apache.org by lm...@apache.org on 2020/11/11 02:20:44 UTC

[incubator-tvm] branch main updated: [AutoScheduler] Improve tuning with random cost model (#6835)

This is an automated email from the ASF dual-hosted git repository.

lmzheng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new d03c0c0  [AutoScheduler] Improve tuning with random cost model (#6835)
d03c0c0 is described below

commit d03c0c0637979026ad547a11e5f14198c14a43c3
Author: Cody Yu <co...@gmail.com>
AuthorDate: Tue Nov 10 18:20:33 2020 -0800

    [AutoScheduler] Improve tuning with random cost model (#6835)
    
    * fix
    
    * more fix
    
    * fix
    
    * revert
    
    * format
    
    * Update sketch_policy.cc
    
    * increase measure trials to avoid flaky tests
---
 python/tvm/auto_scheduler/relay_integration.py     |  4 +-
 src/auto_scheduler/search_policy/sketch_policy.cc  | 58 +++++++++++++---------
 .../unittest/test_auto_scheduler_search_policy.py  |  2 +-
 3 files changed, 37 insertions(+), 27 deletions(-)

diff --git a/python/tvm/auto_scheduler/relay_integration.py b/python/tvm/auto_scheduler/relay_integration.py
index 24a4c44..c8a4ed5 100644
--- a/python/tvm/auto_scheduler/relay_integration.py
+++ b/python/tvm/auto_scheduler/relay_integration.py
@@ -72,9 +72,9 @@ def extract_tasks(mod, params, target, target_host=None, hardware_params=None):
     from tvm import relay
 
     if isinstance(target, str):
-        target = Target(target)
+        target = tvm.target.Target(target)
     if isinstance(target_host, str):
-        target_host = Target(target_host)
+        target_host = tvm.target.Target(target_host)
 
     # Run the compiler to collect all TOPI calls during compilation.
     env = TracingEnvironment(TracingMode.EXTRACT_TASK)
diff --git a/src/auto_scheduler/search_policy/sketch_policy.cc b/src/auto_scheduler/search_policy/sketch_policy.cc
index e4e186b..6360c72 100644
--- a/src/auto_scheduler/search_policy/sketch_policy.cc
+++ b/src/auto_scheduler/search_policy/sketch_policy.cc
@@ -264,7 +264,6 @@ Array<State> SketchPolicyNode::SearchOneRound(int num_random_states, Array<State
       static_cast<int>(
           GetDoubleParam(params, SketchParamKey::SampleInitPopulation::use_measured_ratio) *
           population));
-  bool is_cost_model_reasonable = !program_cost_model->IsInstance<RandomModelNode>();
 
   // 1. Generate sketches
   if (sketch_cache_.empty()) {
@@ -274,23 +273,17 @@ Array<State> SketchPolicyNode::SearchOneRound(int num_random_states, Array<State
   // 2. Sample the init population
   Array<State> init_population = SampleInitPopulation(sketch_cache_);
 
-  // 3. Perform evolutionary search if a cost model is utilized. Otherwise,
-  // just return some random states.
-  if (is_cost_model_reasonable) {
-    // Also insert already measured good states to the initial population
-    std::vector<int> indices = Argsort(measured_states_throughputs_);
-    for (int i = 0; i < num_use_measured; i++) {
-      init_population.push_back(measured_states_vector_[indices[i]]);
-    }
-    // Sample some random states for eps-greedy
-    if (num_random_states > 0 && random_states != nullptr) {
-      *random_states = RandomSampleStates(init_population, &rand_gen, num_random_states);
-    }
-    return EvolutionarySearch(init_population, num_measure_per_iter_ * 2);
-  } else {
-    PruneInvalidState(search_task, &init_population);
-    return RandomSampleStates(init_population, &rand_gen, num_measure_per_iter_ * 2);
+  // 3. Perform evolutionary search.
+  // Also insert already measured good states to the initial population
+  std::vector<int> indices = Argsort(measured_states_throughputs_);
+  for (int i = 0; i < num_use_measured; i++) {
+    init_population.push_back(measured_states_vector_[indices[i]]);
+  }
+  // Sample some random states for eps-greedy
+  if (num_random_states > 0 && random_states != nullptr) {
+    *random_states = RandomSampleStates(init_population, &rand_gen, num_random_states);
   }
+  return EvolutionarySearch(init_population, num_measure_per_iter_ * 2);
 }
 
 Array<State> SketchPolicyNode::GenerateSketches() {
@@ -378,6 +371,7 @@ Array<State> SketchPolicyNode::SampleInitPopulation(const Array<State>& sketches
   }
   auto tic_begin = std::chrono::high_resolution_clock::now();
 
+  std::unordered_set<std::string> explored_state_strs;
   size_t iter = 1;
   size_t target_size = min_population;
   size_t unchange_cnt = 0;
@@ -421,10 +415,13 @@ Array<State> SketchPolicyNode::SampleInitPopulation(const Array<State>& sketches
       std::vector<float> pop_scores;
       pop_scores.reserve(cand_states.size());
       cand_states = search_task->compute_dag.InferBound(cand_states);
+      PruneInvalidState(search_task, &cand_states);
       program_cost_model->Predict(search_task, cand_states, &pop_scores);
 
       for (size_t i = 0; i < cand_states.size(); i++) {
-        if (pop_scores[i] > -1e10) {
+        const auto state_str = cand_states[i].ToStr();
+        if (pop_scores[i] > -1e10 && explored_state_strs.count(state_str) == 0) {
+          explored_state_strs.insert(state_str);
           out_states.push_back(std::move(cand_states[i]));
           unchange_cnt = 0;  // Reset the counter once we found a valid state
         } else {
@@ -449,7 +446,7 @@ Array<State> SketchPolicyNode::SampleInitPopulation(const Array<State>& sketches
       if (target_size > 1) {
         target_size /= 2;
         StdCout(verbose) << "#Target has been reduced to " << target_size
-                         << " due to too many failures";
+                         << " due to too many failures or duplications" << std::endl;
       }
       unchange_cnt = 0;
     }
@@ -471,8 +468,15 @@ Array<State> SketchPolicyNode::EvolutionarySearch(const Array<State>& init_popul
   auto tic_begin = std::chrono::high_resolution_clock::now();
 
   size_t population = GetIntParam(params, SketchParamKey::EvolutionarySearch::population);
-  int num_iters = GetIntParam(params, SketchParamKey::EvolutionarySearch::num_iters);
   double mutation_prob = GetDoubleParam(params, SketchParamKey::EvolutionarySearch::mutation_prob);
+  int num_iters = GetIntParam(params, SketchParamKey::EvolutionarySearch::num_iters);
+
+  bool is_cost_model_reasonable = !program_cost_model->IsInstance<RandomModelNode>();
+  if (!is_cost_model_reasonable && num_iters > 3) {
+    num_iters = 3;
+    StdCout(verbose) << "GA iteration number has been adjusted to " << num_iters
+                     << " due to random cost model" << std::endl;
+  }
 
   // Two ping pong buffers to avoid copy.
   Array<State> states_buf1{init_population}, states_buf2;
@@ -493,7 +497,7 @@ Array<State> SketchPolicyNode::EvolutionarySearch(const Array<State>& init_popul
   // auxiliary global variables
   std::vector<float> pop_scores;
   std::vector<double> pop_selection_probs;
-  float max_score = 0.0;
+  float max_score = -1e-10;
   pop_scores.reserve(population);
   pop_selection_probs.reserve(population);
   std::uniform_real_distribution<> dis(0.0, 1.0);
@@ -541,9 +545,15 @@ Array<State> SketchPolicyNode::EvolutionarySearch(const Array<State>& init_popul
 
     // Print statistical information
     if (k % 5 == 0 || k == num_iters) {
-      StdCout(verbose) << "GA Iter: " << k << std::fixed << std::setprecision(4)
-                       << "\tMax score: " << max_score << "\tMin score: " << heap.front().second
-                       << "\t#Pop: " << pnow->size() << "\t#M+: " << mutation_success_ct / (k + 1)
+      StdCout(verbose) << "GA Iter: " << k;
+      if (!heap.empty()) {
+        StdCout(verbose) << std::fixed << std::setprecision(4) << "\tMax score: " << max_score
+                         << std::fixed << std::setprecision(4)
+                         << "\tMin score: " << heap.front().second;
+      } else {
+        StdCout(verbose) << "\tMax score: N/A\tMin score: N/A";
+      }
+      StdCout(verbose) << "\t#Pop: " << heap.size() << "\t#M+: " << mutation_success_ct / (k + 1)
                        << "\t#M-: " << mutation_fail_ct / (k + 1) << std::endl;
     }
     if (k == num_iters) {
diff --git a/tests/python/unittest/test_auto_scheduler_search_policy.py b/tests/python/unittest/test_auto_scheduler_search_policy.py
index 5329f3d..6493c24 100644
--- a/tests/python/unittest/test_auto_scheduler_search_policy.py
+++ b/tests/python/unittest/test_auto_scheduler_search_policy.py
@@ -37,7 +37,7 @@ def search_common(
     seed=random.randint(1, 1 << 30),
     runner="local",
     cost_model=auto_scheduler.RandomModel(),
-    num_measure_trials=2,
+    num_measure_trials=10,
     init_search_callbacks=None,
 ):
     print("Test %s schedule search with the default search policy" % (target))