You are viewing a plain-text version of this content. The canonical link for it is here (the hyperlink was not preserved in this plain-text copy).
Posted to commits@tvm.apache.org by jr...@apache.org on 2021/05/11 17:44:24 UTC

[tvm] branch main updated: Remove minimum seed constraint on XGB Tuner (#7992)

This is an automated email from the ASF dual-hosted git repository.

jroesch pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 0f41d47  Remove minimum seed constraint on XGB Tuner (#7992)
0f41d47 is described below

commit 0f41d47bb43ba7509771a6254b986dcecc144479
Author: anwang2009 <an...@gmail.com>
AuthorDate: Tue May 11 10:44:01 2021 -0700

    Remove minimum seed constraint on XGB Tuner (#7992)
    
    * remove minimum seed
    
    * reset 3rdparty dep
    
    * add items to 'visited', parametrize min seed records
    
    * add comment
    
    * fix lint
    
    * add tests
---
 python/tvm/autotvm/tuner/ga_tuner.py               |  2 +-
 python/tvm/autotvm/tuner/index_based_tuner.py      |  2 +-
 python/tvm/autotvm/tuner/model_based_tuner.py      | 17 +++++++++++---
 python/tvm/autotvm/tuner/tuner.py                  |  7 +++++-
 python/tvm/autotvm/tuner/xgboost_cost_model.py     |  4 ++--
 .../python/unittest/test_autotvm_xgboost_model.py  | 27 ++++++++++++++++++++--
 6 files changed, 49 insertions(+), 10 deletions(-)

diff --git a/python/tvm/autotvm/tuner/ga_tuner.py b/python/tvm/autotvm/tuner/ga_tuner.py
index 5825199..2ecd120 100644
--- a/python/tvm/autotvm/tuner/ga_tuner.py
+++ b/python/tvm/autotvm/tuner/ga_tuner.py
@@ -143,5 +143,5 @@ class GATuner(Tuner):
     def has_next(self):
         return len(self.visited) - (len(self.genes) - self.trial_pt) < len(self.space)
 
-    def load_history(self, data_set):
+    def load_history(self, data_set, min_seed_records=500):
         pass
diff --git a/python/tvm/autotvm/tuner/index_based_tuner.py b/python/tvm/autotvm/tuner/index_based_tuner.py
index 945bcfd..972de65 100644
--- a/python/tvm/autotvm/tuner/index_based_tuner.py
+++ b/python/tvm/autotvm/tuner/index_based_tuner.py
@@ -53,7 +53,7 @@ class IndexBaseTuner(Tuner):
     def has_next(self):
         return self.counter < self.range_length
 
-    def load_history(self, data_set):
+    def load_history(self, data_set, min_seed_records=500):
         pass
 
 
diff --git a/python/tvm/autotvm/tuner/model_based_tuner.py b/python/tvm/autotvm/tuner/model_based_tuner.py
index 4d16339..f07e7fb 100644
--- a/python/tvm/autotvm/tuner/model_based_tuner.py
+++ b/python/tvm/autotvm/tuner/model_based_tuner.py
@@ -98,7 +98,7 @@ class CostModel(object):
         """
         raise NotImplementedError()
 
-    def fit_log(self, records, plan_size):
+    def fit_log(self, records, plan_size, min_seed_records=500):
         """Fit training data from log.
 
         Parameters
@@ -107,6 +107,11 @@ class CostModel(object):
             The tuning records
         plan_size: int
             The plan size of tuner
+        min_seed_records: int
+            Defaults to 500. Indicates the minimum number of records to
+            train the tuner with. If there are less than `min_seed_records`
+            number of records in `data_set`, no training of the tuner
+            will be done.
         """
         raise NotImplementedError()
 
@@ -264,6 +269,12 @@ class ModelBasedTuner(Tuner):
             else:
                 self.xs.append(index)
                 self.ys.append(0.0)
+            # Usually the update function is called during the tune loop
+            # after the index is already added to the visited set.
+            # However, adding the index to visited again here enables us
+            # to also use this update function to resume tuning progress in
+            # case of interruption.
+            self.visited.add(index)
 
         # if we have enough new training samples
         if len(self.xs) >= self.plan_size * (self.train_ct + 1) and self.flops_max > 1e-6:
@@ -285,13 +296,13 @@ class ModelBasedTuner(Tuner):
             self.trial_pt = 0
             self.train_ct += 1
 
-    def load_history(self, data_set):
+    def load_history(self, data_set, min_seed_records=500):
         # set in_tuning as True to make the feature extraction consistent
         GLOBAL_SCOPE.in_tuning = True
 
         # fit base model
         base_model = self.cost_model.spawn_base_model()
-        success = base_model.fit_log(data_set, self.plan_size)
+        success = base_model.fit_log(data_set, self.plan_size, min_seed_records)
 
         if not success:
             GLOBAL_SCOPE.in_tuning = False
diff --git a/python/tvm/autotvm/tuner/tuner.py b/python/tvm/autotvm/tuner/tuner.py
index fa60930..8ce6b74 100644
--- a/python/tvm/autotvm/tuner/tuner.py
+++ b/python/tvm/autotvm/tuner/tuner.py
@@ -200,12 +200,17 @@ class Tuner(object):
         self.best_flops = 0
         self.best_measure_pair = None
 
-    def load_history(self, data_set):
+    def load_history(self, data_set, min_seed_records=500):
         """load history data for transfer learning
 
         Parameters
         ----------
         data_set: Array of (autotvm.measure.MeasureInput, autotvm.measure.MeasureResult) pair
             Previous tuning records
+        min_seed_records: int
+            Defaults to 500. Indicates the minimum number of records to
+            train the tuner with. If there are less than `min_seed_records`
+            number of records in `data_set`, no training of the tuner
+            will be done.
         """
         raise NotImplementedError()
diff --git a/python/tvm/autotvm/tuner/xgboost_cost_model.py b/python/tvm/autotvm/tuner/xgboost_cost_model.py
index 4b0a45b..8190435 100644
--- a/python/tvm/autotvm/tuner/xgboost_cost_model.py
+++ b/python/tvm/autotvm/tuner/xgboost_cost_model.py
@@ -225,7 +225,7 @@ class XGBoostCostModel(CostModel):
             self.feature_cache.size(self.fea_type),
         )
 
-    def fit_log(self, records, plan_size):
+    def fit_log(self, records, plan_size, min_seed_records=500):
         tic = time.time()
 
         # filter data, only pick the data with a same task
@@ -258,7 +258,7 @@ class XGBoostCostModel(CostModel):
                 xs.append(x)
                 ys.append(y)
 
-        if len(xs) < 500:  # no enough samples
+        if len(xs) < min_seed_records:  # no enough samples
             return False
 
         xs, ys = np.array(xs), np.array(ys)
diff --git a/tests/python/unittest/test_autotvm_xgboost_model.py b/tests/python/unittest/test_autotvm_xgboost_model.py
index 58b2a4d..445cff8 100644
--- a/tests/python/unittest/test_autotvm_xgboost_model.py
+++ b/tests/python/unittest/test_autotvm_xgboost_model.py
@@ -59,13 +59,36 @@ def test_fit_spawn():
 
 def test_tuner():
     task, target = get_sample_task()
-    records = get_sample_records(n=100)
+    records = get_sample_records(n=10)
 
     tuner = autotvm.tuner.XGBTuner(task)
-    tuner.load_history(records)
+    tuner.load_history(records, min_seed_records=10)
+    # Confirm that loading history successfully loaded a
+    # base_model.
+    assert tuner.cost_model.base_model is not None
+
+    tuner = autotvm.tuner.XGBTuner(task)
+    tuner.load_history(records, min_seed_records=11)
+    # Confirm that loading history did not load base_model
+    # when not enough records according to `min_seed_records`
+    # are provided
+    assert tuner.cost_model.base_model is None
+
+
+def test_update():
+    task, target = get_sample_task()
+    tuner = autotvm.tuner.XGBTuner(task)
+    n_records = 5
+    records = get_sample_records(n=n_records)
+    tuner.update([inp for inp, _ in records], [res for _, res in records])
+    assert len(tuner.xs) == n_records
+    assert len(tuner.ys) == n_records
+    assert len(tuner.visited) == n_records
+    assert all(x in tuner.visited for x in tuner.xs)
 
 
 if __name__ == "__main__":
     test_fit()
     test_fit_spawn()
     test_tuner()
+    test_update()