You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by lm...@apache.org on 2020/11/11 02:20:44 UTC
[incubator-tvm] branch main updated: [AutoScheduler] Improve tuning with random cost model (#6835)
This is an automated email from the ASF dual-hosted git repository.
lmzheng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git
The following commit(s) were added to refs/heads/main by this push:
new d03c0c0 [AutoScheduler] Improve tuning with random cost model (#6835)
d03c0c0 is described below
commit d03c0c0637979026ad547a11e5f14198c14a43c3
Author: Cody Yu <co...@gmail.com>
AuthorDate: Tue Nov 10 18:20:33 2020 -0800
[AutoScheduler] Improve tuning with random cost model (#6835)
* fix
* more fix
* fix
* revert
* format
* Update sketch_policy.cc
* increase measure trial to avoid flaky
---
python/tvm/auto_scheduler/relay_integration.py | 4 +-
src/auto_scheduler/search_policy/sketch_policy.cc | 58 +++++++++++++---------
.../unittest/test_auto_scheduler_search_policy.py | 2 +-
3 files changed, 37 insertions(+), 27 deletions(-)
diff --git a/python/tvm/auto_scheduler/relay_integration.py b/python/tvm/auto_scheduler/relay_integration.py
index 24a4c44..c8a4ed5 100644
--- a/python/tvm/auto_scheduler/relay_integration.py
+++ b/python/tvm/auto_scheduler/relay_integration.py
@@ -72,9 +72,9 @@ def extract_tasks(mod, params, target, target_host=None, hardware_params=None):
from tvm import relay
if isinstance(target, str):
- target = Target(target)
+ target = tvm.target.Target(target)
if isinstance(target_host, str):
- target_host = Target(target_host)
+ target_host = tvm.target.Target(target_host)
# Run the compiler to collect all TOPI calls during compilation.
env = TracingEnvironment(TracingMode.EXTRACT_TASK)
diff --git a/src/auto_scheduler/search_policy/sketch_policy.cc b/src/auto_scheduler/search_policy/sketch_policy.cc
index e4e186b..6360c72 100644
--- a/src/auto_scheduler/search_policy/sketch_policy.cc
+++ b/src/auto_scheduler/search_policy/sketch_policy.cc
@@ -264,7 +264,6 @@ Array<State> SketchPolicyNode::SearchOneRound(int num_random_states, Array<State
static_cast<int>(
GetDoubleParam(params, SketchParamKey::SampleInitPopulation::use_measured_ratio) *
population));
- bool is_cost_model_reasonable = !program_cost_model->IsInstance<RandomModelNode>();
// 1. Generate sketches
if (sketch_cache_.empty()) {
@@ -274,23 +273,17 @@ Array<State> SketchPolicyNode::SearchOneRound(int num_random_states, Array<State
// 2. Sample the init population
Array<State> init_population = SampleInitPopulation(sketch_cache_);
- // 3. Perform evolutionary search if a cost model is utilized. Otherwise,
- // just return some random states.
- if (is_cost_model_reasonable) {
- // Also insert already measured good states to the initial population
- std::vector<int> indices = Argsort(measured_states_throughputs_);
- for (int i = 0; i < num_use_measured; i++) {
- init_population.push_back(measured_states_vector_[indices[i]]);
- }
- // Sample some random states for eps-greedy
- if (num_random_states > 0 && random_states != nullptr) {
- *random_states = RandomSampleStates(init_population, &rand_gen, num_random_states);
- }
- return EvolutionarySearch(init_population, num_measure_per_iter_ * 2);
- } else {
- PruneInvalidState(search_task, &init_population);
- return RandomSampleStates(init_population, &rand_gen, num_measure_per_iter_ * 2);
+ // 3. Perform evolutionary search.
+ // Also insert already measured good states to the initial population
+ std::vector<int> indices = Argsort(measured_states_throughputs_);
+ for (int i = 0; i < num_use_measured; i++) {
+ init_population.push_back(measured_states_vector_[indices[i]]);
+ }
+ // Sample some random states for eps-greedy
+ if (num_random_states > 0 && random_states != nullptr) {
+ *random_states = RandomSampleStates(init_population, &rand_gen, num_random_states);
}
+ return EvolutionarySearch(init_population, num_measure_per_iter_ * 2);
}
Array<State> SketchPolicyNode::GenerateSketches() {
@@ -378,6 +371,7 @@ Array<State> SketchPolicyNode::SampleInitPopulation(const Array<State>& sketches
}
auto tic_begin = std::chrono::high_resolution_clock::now();
+ std::unordered_set<std::string> explored_state_strs;
size_t iter = 1;
size_t target_size = min_population;
size_t unchange_cnt = 0;
@@ -421,10 +415,13 @@ Array<State> SketchPolicyNode::SampleInitPopulation(const Array<State>& sketches
std::vector<float> pop_scores;
pop_scores.reserve(cand_states.size());
cand_states = search_task->compute_dag.InferBound(cand_states);
+ PruneInvalidState(search_task, &cand_states);
program_cost_model->Predict(search_task, cand_states, &pop_scores);
for (size_t i = 0; i < cand_states.size(); i++) {
- if (pop_scores[i] > -1e10) {
+ const auto state_str = cand_states[i].ToStr();
+ if (pop_scores[i] > -1e10 && explored_state_strs.count(state_str) == 0) {
+ explored_state_strs.insert(state_str);
out_states.push_back(std::move(cand_states[i]));
unchange_cnt = 0; // Reset the counter once we found a valid state
} else {
@@ -449,7 +446,7 @@ Array<State> SketchPolicyNode::SampleInitPopulation(const Array<State>& sketches
if (target_size > 1) {
target_size /= 2;
StdCout(verbose) << "#Target has been reduced to " << target_size
- << " due to too many failures";
+ << " due to too many failures or duplications" << std::endl;
}
unchange_cnt = 0;
}
@@ -471,8 +468,15 @@ Array<State> SketchPolicyNode::EvolutionarySearch(const Array<State>& init_popul
auto tic_begin = std::chrono::high_resolution_clock::now();
size_t population = GetIntParam(params, SketchParamKey::EvolutionarySearch::population);
- int num_iters = GetIntParam(params, SketchParamKey::EvolutionarySearch::num_iters);
double mutation_prob = GetDoubleParam(params, SketchParamKey::EvolutionarySearch::mutation_prob);
+ int num_iters = GetIntParam(params, SketchParamKey::EvolutionarySearch::num_iters);
+
+ bool is_cost_model_reasonable = !program_cost_model->IsInstance<RandomModelNode>();
+ if (!is_cost_model_reasonable && num_iters > 3) {
+ num_iters = 3;
+ StdCout(verbose) << "GA iteration number has been adjusted to " << num_iters
+ << " due to random cost model" << std::endl;
+ }
// Two ping pong buffers to avoid copy.
Array<State> states_buf1{init_population}, states_buf2;
@@ -493,7 +497,7 @@ Array<State> SketchPolicyNode::EvolutionarySearch(const Array<State>& init_popul
// auxiliary global variables
std::vector<float> pop_scores;
std::vector<double> pop_selection_probs;
- float max_score = 0.0;
+ float max_score = -1e-10;
pop_scores.reserve(population);
pop_selection_probs.reserve(population);
std::uniform_real_distribution<> dis(0.0, 1.0);
@@ -541,9 +545,15 @@ Array<State> SketchPolicyNode::EvolutionarySearch(const Array<State>& init_popul
// Print statistical information
if (k % 5 == 0 || k == num_iters) {
- StdCout(verbose) << "GA Iter: " << k << std::fixed << std::setprecision(4)
- << "\tMax score: " << max_score << "\tMin score: " << heap.front().second
- << "\t#Pop: " << pnow->size() << "\t#M+: " << mutation_success_ct / (k + 1)
+ StdCout(verbose) << "GA Iter: " << k;
+ if (!heap.empty()) {
+ StdCout(verbose) << std::fixed << std::setprecision(4) << "\tMax score: " << max_score
+ << std::fixed << std::setprecision(4)
+ << "\tMin score: " << heap.front().second;
+ } else {
+ StdCout(verbose) << "\tMax score: N/A\tMin score: N/A";
+ }
+ StdCout(verbose) << "\t#Pop: " << heap.size() << "\t#M+: " << mutation_success_ct / (k + 1)
<< "\t#M-: " << mutation_fail_ct / (k + 1) << std::endl;
}
if (k == num_iters) {
diff --git a/tests/python/unittest/test_auto_scheduler_search_policy.py b/tests/python/unittest/test_auto_scheduler_search_policy.py
index 5329f3d..6493c24 100644
--- a/tests/python/unittest/test_auto_scheduler_search_policy.py
+++ b/tests/python/unittest/test_auto_scheduler_search_policy.py
@@ -37,7 +37,7 @@ def search_common(
seed=random.randint(1, 1 << 30),
runner="local",
cost_model=auto_scheduler.RandomModel(),
- num_measure_trials=2,
+ num_measure_trials=10,
init_search_callbacks=None,
):
print("Test %s schedule search with the default search policy" % (target))