You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by jr...@apache.org on 2021/05/11 17:44:24 UTC
[tvm] branch main updated: Remove minimum seed constraint on XGB
Tuner (#7992)
This is an automated email from the ASF dual-hosted git repository.
jroesch pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 0f41d47 Remove minimum seed constraint on XGB Tuner (#7992)
0f41d47 is described below
commit 0f41d47bb43ba7509771a6254b986dcecc144479
Author: anwang2009 <an...@gmail.com>
AuthorDate: Tue May 11 10:44:01 2021 -0700
Remove minimum seed constraint on XGB Tuner (#7992)
* remove minimum seed
* reset 3rdparty dep
* add items to 'visited', parametrize min seed records
* add comment
* fix lint
* add tests
---
python/tvm/autotvm/tuner/ga_tuner.py | 2 +-
python/tvm/autotvm/tuner/index_based_tuner.py | 2 +-
python/tvm/autotvm/tuner/model_based_tuner.py | 17 +++++++++++---
python/tvm/autotvm/tuner/tuner.py | 7 +++++-
python/tvm/autotvm/tuner/xgboost_cost_model.py | 4 ++--
.../python/unittest/test_autotvm_xgboost_model.py | 27 ++++++++++++++++++++--
6 files changed, 49 insertions(+), 10 deletions(-)
diff --git a/python/tvm/autotvm/tuner/ga_tuner.py b/python/tvm/autotvm/tuner/ga_tuner.py
index 5825199..2ecd120 100644
--- a/python/tvm/autotvm/tuner/ga_tuner.py
+++ b/python/tvm/autotvm/tuner/ga_tuner.py
@@ -143,5 +143,5 @@ class GATuner(Tuner):
def has_next(self):
return len(self.visited) - (len(self.genes) - self.trial_pt) < len(self.space)
- def load_history(self, data_set):
+ def load_history(self, data_set, min_seed_records=500):
pass
diff --git a/python/tvm/autotvm/tuner/index_based_tuner.py b/python/tvm/autotvm/tuner/index_based_tuner.py
index 945bcfd..972de65 100644
--- a/python/tvm/autotvm/tuner/index_based_tuner.py
+++ b/python/tvm/autotvm/tuner/index_based_tuner.py
@@ -53,7 +53,7 @@ class IndexBaseTuner(Tuner):
def has_next(self):
return self.counter < self.range_length
- def load_history(self, data_set):
+ def load_history(self, data_set, min_seed_records=500):
pass
diff --git a/python/tvm/autotvm/tuner/model_based_tuner.py b/python/tvm/autotvm/tuner/model_based_tuner.py
index 4d16339..f07e7fb 100644
--- a/python/tvm/autotvm/tuner/model_based_tuner.py
+++ b/python/tvm/autotvm/tuner/model_based_tuner.py
@@ -98,7 +98,7 @@ class CostModel(object):
"""
raise NotImplementedError()
- def fit_log(self, records, plan_size):
+ def fit_log(self, records, plan_size, min_seed_records=500):
"""Fit training data from log.
Parameters
@@ -107,6 +107,11 @@ class CostModel(object):
The tuning records
plan_size: int
The plan size of tuner
+ min_seed_records: int
+ Defaults to 500. Indicates the minimum number of records
+ required to train the tuner. If there are fewer than
+ `min_seed_records` records in `records`, no training of
+ the tuner will be done.
"""
raise NotImplementedError()
@@ -264,6 +269,12 @@ class ModelBasedTuner(Tuner):
else:
self.xs.append(index)
self.ys.append(0.0)
+ # Usually the update function is called during the tune loop
+ # after the index is already added to the visited set.
+ # However, adding the index to visited again here enables us
+ # to also use this update function to resume tuning progress in
+ # case of interruption.
+ self.visited.add(index)
# if we have enough new training samples
if len(self.xs) >= self.plan_size * (self.train_ct + 1) and self.flops_max > 1e-6:
@@ -285,13 +296,13 @@ class ModelBasedTuner(Tuner):
self.trial_pt = 0
self.train_ct += 1
- def load_history(self, data_set):
+ def load_history(self, data_set, min_seed_records=500):
# set in_tuning as True to make the feature extraction consistent
GLOBAL_SCOPE.in_tuning = True
# fit base model
base_model = self.cost_model.spawn_base_model()
- success = base_model.fit_log(data_set, self.plan_size)
+ success = base_model.fit_log(data_set, self.plan_size, min_seed_records)
if not success:
GLOBAL_SCOPE.in_tuning = False
diff --git a/python/tvm/autotvm/tuner/tuner.py b/python/tvm/autotvm/tuner/tuner.py
index fa60930..8ce6b74 100644
--- a/python/tvm/autotvm/tuner/tuner.py
+++ b/python/tvm/autotvm/tuner/tuner.py
@@ -200,12 +200,17 @@ class Tuner(object):
self.best_flops = 0
self.best_measure_pair = None
- def load_history(self, data_set):
+ def load_history(self, data_set, min_seed_records=500):
"""load history data for transfer learning
Parameters
----------
data_set: Array of (autotvm.measure.MeasureInput, autotvm.measure.MeasureResult) pair
Previous tuning records
+ min_seed_records: int
+ Defaults to 500. Indicates the minimum number of records
+ required to train the tuner. If there are fewer than
+ `min_seed_records` records in `data_set`, no training of
+ the tuner will be done.
"""
raise NotImplementedError()
diff --git a/python/tvm/autotvm/tuner/xgboost_cost_model.py b/python/tvm/autotvm/tuner/xgboost_cost_model.py
index 4b0a45b..8190435 100644
--- a/python/tvm/autotvm/tuner/xgboost_cost_model.py
+++ b/python/tvm/autotvm/tuner/xgboost_cost_model.py
@@ -225,7 +225,7 @@ class XGBoostCostModel(CostModel):
self.feature_cache.size(self.fea_type),
)
- def fit_log(self, records, plan_size):
+ def fit_log(self, records, plan_size, min_seed_records=500):
tic = time.time()
# filter data, only pick the data with a same task
@@ -258,7 +258,7 @@ class XGBoostCostModel(CostModel):
xs.append(x)
ys.append(y)
- if len(xs) < 500: # no enough samples
+ if len(xs) < min_seed_records: # not enough samples
return False
xs, ys = np.array(xs), np.array(ys)
diff --git a/tests/python/unittest/test_autotvm_xgboost_model.py b/tests/python/unittest/test_autotvm_xgboost_model.py
index 58b2a4d..445cff8 100644
--- a/tests/python/unittest/test_autotvm_xgboost_model.py
+++ b/tests/python/unittest/test_autotvm_xgboost_model.py
@@ -59,13 +59,36 @@ def test_fit_spawn():
def test_tuner():
task, target = get_sample_task()
- records = get_sample_records(n=100)
+ records = get_sample_records(n=10)
tuner = autotvm.tuner.XGBTuner(task)
- tuner.load_history(records)
+ tuner.load_history(records, min_seed_records=10)
+ # Confirm that loading history successfully loaded a
+ # base_model.
+ assert tuner.cost_model.base_model is not None
+
+ tuner = autotvm.tuner.XGBTuner(task)
+ tuner.load_history(records, min_seed_records=11)
+ # Confirm that loading history did not load base_model
+ # when not enough records according to `min_seed_records`
+ # are provided
+ assert tuner.cost_model.base_model is None
+
+
+def test_update():
+ task, target = get_sample_task()
+ tuner = autotvm.tuner.XGBTuner(task)
+ n_records = 5
+ records = get_sample_records(n=n_records)
+ tuner.update([inp for inp, _ in records], [res for _, res in records])
+ assert len(tuner.xs) == n_records
+ assert len(tuner.ys) == n_records
+ assert len(tuner.visited) == n_records
+ assert all(x in tuner.visited for x in tuner.xs)
if __name__ == "__main__":
test_fit()
test_fit_spawn()
test_tuner()
+ test_update()