Posted to commits@tvm.apache.org by ju...@apache.org on 2022/06/28 18:04:20 UTC

[tvm] branch main updated: [MetaSchedule] Enable Adaptive Training For XGBoost Cost Model (#11892)

This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 0e23122846 [MetaSchedule] Enable Adaptive Training For XGBoost Cost Model (#11892)
0e23122846 is described below

commit 0e23122846aa3b7a5350102d8c06fa21695d34be
Author: Xiyou Zhou <xi...@octoml.ai>
AuthorDate: Tue Jun 28 11:04:13 2022 -0700

    [MetaSchedule] Enable Adaptive Training For XGBoost Cost Model (#11892)
    
    Cost model retraining is a time-consuming part of MetaSchedule tuning. As with AutoScheduler, we can alleviate it by adaptively increasing the waiting period between retrainings. This PR introduces an argument called `adaptive_training` in `TuneConfig` and in the constructor of `XGBModel` to enable this capability. The testing tuning scripts are updated accordingly.
---
 python/tvm/auto_scheduler/search_task.py         |  4 ++--
 python/tvm/auto_scheduler/testing/tune_onnx.py   | 14 +++++++++++---
 python/tvm/auto_scheduler/testing/tune_relay.py  | 14 +++++++++++---
 python/tvm/auto_scheduler/testing/tune_te.py     | 12 ++++++++++--
 python/tvm/meta_schedule/cost_model/xgb_model.py | 18 ++++++++++++++++++
 python/tvm/meta_schedule/default_config.py       |  6 +++++-
 python/tvm/meta_schedule/testing/tune_onnx.py    | 12 ++++++++++--
 python/tvm/meta_schedule/testing/tune_relay.py   | 12 ++++++++++--
 python/tvm/meta_schedule/testing/tune_te.py      | 10 +++++++++-
 python/tvm/meta_schedule/tune.py                 |  5 ++++-
 10 files changed, 90 insertions(+), 17 deletions(-)
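
The core of the change is a simple gating rule in `XGBModel.update`: skip retraining until the training set has grown enough since the last run. A minimal standalone sketch of that rule (names are illustrative; the real check appears in the xgb_model.py hunk below):

    def should_retrain(data_size: int, last_train_size: int, adaptive: bool = True) -> bool:
        """Retrain only once the dataset has grown by at least 20%
        since the last training run; always retrain when adaptive
        training is disabled."""
        if not adaptive:
            return True
        return data_size - last_train_size >= last_train_size / 5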

diff --git a/python/tvm/auto_scheduler/search_task.py b/python/tvm/auto_scheduler/search_task.py
index 56dcb56abc..ab03ff9f8e 100644
--- a/python/tvm/auto_scheduler/search_task.py
+++ b/python/tvm/auto_scheduler/search_task.py
@@ -481,7 +481,7 @@ class SearchTask(Object):
             desc,
         )
 
-    def tune(self, tuning_options, search_policy=None):
+    def tune(self, tuning_options, search_policy=None, adaptive_training=False):
         """Run auto scheduling search for a task
 
         Parameters
@@ -492,7 +492,7 @@ class SearchTask(Object):
             The search policy to be used for schedule search.
         """
         if search_policy is None:
-            cost_model = XGBModel()
+            cost_model = XGBModel(adaptive_training=adaptive_training)
             search_policy = SketchPolicy(self, cost_model)
 
         _ffi_api.AutoSchedule(search_policy, tuning_options)
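
With the new keyword argument, a caller can control adaptive training per task. A hypothetical call site (task construction elided):

    # Assumes: from tvm import auto_scheduler, and `task` is an
    # existing auto_scheduler.SearchTask.
    tune_option = auto_scheduler.TuningOptions(num_measure_trials=200)
    task.tune(tune_option, adaptive_training=True)  # XGBModel retrains adaptively
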
diff --git a/python/tvm/auto_scheduler/testing/tune_onnx.py b/python/tvm/auto_scheduler/testing/tune_onnx.py
index 1998f3d2c3..5444794cf1 100644
--- a/python/tvm/auto_scheduler/testing/tune_onnx.py
+++ b/python/tvm/auto_scheduler/testing/tune_onnx.py
@@ -99,7 +99,14 @@ def _parse_args():
         "--cpu-flush",
         type=lambda x: bool(strtobool(x)),
         required=True,
-        help="example: `True / False",
+        help="example: True / False",
+    )
+    args.add_argument(
+        "--adaptive-training",
+        type=lambda x: bool(strtobool(x)),
+        required=False,
+        help="example: True / False",
+        default=True,
     )
     parsed = args.parse_args()
     parsed.target = tvm.target.Target(parsed.target)
@@ -108,7 +115,7 @@ def _parse_args():
         tracker_host=parsed.rpc_host,
         tracker_port=parsed.rpc_port,
         tracker_key=parsed.rpc_key,
-        session_timeout_sec=3600,
+        session_timeout_sec=600,
     )
     return parsed
 
@@ -179,7 +186,8 @@ def main():
             measure_callbacks=[
                 auto_scheduler.RecordToFile(log_file),
             ],
-        )
+        ),
+        adaptive_training=ARGS.adaptive_training,
     )
 
     with auto_scheduler.ApplyHistoryBest(log_file):
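
The boolean flags in these testing scripts are parsed with `distutils.util.strtobool` (via the `type=lambda x: bool(strtobool(x))` pattern above), so any of its accepted spellings work on the command line:

    from distutils.util import strtobool

    bool(strtobool("True"))   # True  (also "yes", "t", "on", "1")
    bool(strtobool("False"))  # False (also "no", "f", "off", "0")
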
diff --git a/python/tvm/auto_scheduler/testing/tune_relay.py b/python/tvm/auto_scheduler/testing/tune_relay.py
index 1a79b894bc..fedb27281a 100644
--- a/python/tvm/auto_scheduler/testing/tune_relay.py
+++ b/python/tvm/auto_scheduler/testing/tune_relay.py
@@ -97,7 +97,14 @@ def _parse_args():
         "--cpu-flush",
         type=lambda x: bool(strtobool(x)),
         required=True,
-        help="example: `True / False",
+        help="example: True / False",
+    )
+    args.add_argument(
+        "--adaptive-training",
+        type=lambda x: bool(strtobool(x)),
+        required=False,
+        help="example: True / False",
+        default=True,
     )
     parsed = args.parse_args()
     parsed.target = tvm.target.Target(parsed.target)
@@ -106,7 +113,7 @@ def _parse_args():
         tracker_host=parsed.rpc_host,
         tracker_port=parsed.rpc_port,
         tracker_key=parsed.rpc_key,
-        session_timeout_sec=3600,
+        session_timeout_sec=600,
     )
     return parsed
 
@@ -180,7 +187,8 @@ def main():
             measure_callbacks=[
                 auto_scheduler.RecordToFile(log_file),
             ],
-        )
+        ),
+        adaptive_training=ARGS.adaptive_training,
     )
 
     with auto_scheduler.ApplyHistoryBest(log_file):
diff --git a/python/tvm/auto_scheduler/testing/tune_te.py b/python/tvm/auto_scheduler/testing/tune_te.py
index c844bb9bf6..c6a5ab27cf 100644
--- a/python/tvm/auto_scheduler/testing/tune_te.py
+++ b/python/tvm/auto_scheduler/testing/tune_te.py
@@ -82,7 +82,14 @@ def _parse_args():
         "--cpu-flush",
         type=lambda x: bool(strtobool(x)),
         required=True,
-        help="example: `True / False",
+        help="example: True / False",
+    )
+    args.add_argument(
+        "--adaptive-training",
+        type=lambda x: bool(strtobool(x)),
+        required=False,
+        help="example: True / False",
+        default=True,
     )
     parsed = args.parse_args()
     parsed.target = tvm.target.Target(parsed.target)
@@ -135,6 +142,7 @@ def main():
         repeat=ARGS.repeat,
         min_repeat_ms=ARGS.min_repeat_ms,
         enable_cpu_cache_flush=ARGS.cpu_flush,
+        # TODO(zxybazh): set session timeout to 60, same as MS
     )
 
     # Inspect the computational graph
@@ -147,7 +155,7 @@ def main():
         runner=runner,
     )
     print("Running AutoTuning:")
-    task.tune(tune_option)
+    task.tune(tune_option, adaptive_training=ARGS.adaptive_training)
     print("History Best:")
     print(task.print_best(log_file))
     sch, args = task.apply_best(log_file)
diff --git a/python/tvm/meta_schedule/cost_model/xgb_model.py b/python/tvm/meta_schedule/cost_model/xgb_model.py
index 910c4ec2d3..8de034758b 100644
--- a/python/tvm/meta_schedule/cost_model/xgb_model.py
+++ b/python/tvm/meta_schedule/cost_model/xgb_model.py
@@ -298,6 +298,8 @@ class XGBModel(PyCostModel):
         The verbose level when doing evaluation.
     average_peak_n : int
         The number to calculate average peak score.
+    adaptive_training : bool
+        Whether to use adaptive training to reduce tuning time.
     """
 
     # feature extractor
@@ -314,6 +316,9 @@ class XGBModel(PyCostModel):
     data: Dict[str, FeatureGroup]
     data_size: int
     booster: Optional["xgb.Booster"]
+    # adaptive training
+    adaptive_training: bool
+    last_train_size: int
 
     def __init__(
         self,
@@ -328,6 +333,7 @@ class XGBModel(PyCostModel):
         early_stopping_rounds: int = 50,
         verbose_eval: int = 25,
         average_peak_n: int = 32,
+        adaptive_training: bool = True,
     ):
         super().__init__()
         # feature extractor
@@ -347,6 +353,9 @@ class XGBModel(PyCostModel):
         self.data = OrderedDict()
         self.data_size = 0
         self.booster = None
+        # adaptive training
+        self.adaptive_training = adaptive_training
+        self.last_train_size = 0
 
     def load(self, path: str) -> None:
         """Load the cost model from given file location.
@@ -491,6 +500,15 @@ class XGBModel(PyCostModel):
         self.data[new_group_hash] = group
         self.data_size += len(new_features)
 
+        if (
+            self.adaptive_training
+            and self.data_size - self.last_train_size < self.last_train_size / 5
+        ):
+            # Set a training threshold related to `last_train_size` to reduce the
+            # training overhead when there are too many results
+            return
+        self.last_train_size = self.data_size
+
         # Step 5. Re-train the model
         self._train(
             xs=list(itertools_chain.from_iterable([g.features for g in self.data.values()])),
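
Concretely, the early return makes the wait between retrainings grow with the data: retraining happens only once `data_size` has grown by at least 20% over `last_train_size`. A toy simulation of the retraining points, assuming one new record arrives per update (illustrative only):

    last, retrain_at = 0, []
    for size in range(1, 300):
        if size - last >= last / 5:  # mirrors the check in XGBModel.update
            last = size
            retrain_at.append(size)
    # retrain_at -> [1, 2, 3, 4, 5, 6, 8, 10, 12, 15, 18, 22, 27, 33, ...]
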
diff --git a/python/tvm/meta_schedule/default_config.py b/python/tvm/meta_schedule/default_config.py
index ff01205381..e99dd1383a 100644
--- a/python/tvm/meta_schedule/default_config.py
+++ b/python/tvm/meta_schedule/default_config.py
@@ -141,10 +141,14 @@ def callbacks(  # pylint: disable=redefined-outer-name
 
 def cost_model(
     cost_model: Optional[CostModel],  # pylint: disable=redefined-outer-name
+    adaptive_training: Optional[bool],
 ) -> CostModel:
     """Normalize the input to tvm.meta_schedule.CostModel"""
     if cost_model is None:
-        return XGBModel(extractor=PerStoreFeature())  # type: ignore
+        return XGBModel(  # type: ignore
+            extractor=PerStoreFeature(),
+            adaptive_training=adaptive_training is None or adaptive_training,
+        )
     if not isinstance(cost_model, CostModel):
         raise TypeError(f"Expected `cost_model` to be CostModel, but gets: {cost_model}")
     return cost_model
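
Note the defaulting behavior: `adaptive_training is None or adaptive_training` is `True` when the option is unset (`None`) or explicitly `True`, and `False` only when the caller passes an explicit `False`, so adaptive training is on by default.
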
diff --git a/python/tvm/meta_schedule/testing/tune_onnx.py b/python/tvm/meta_schedule/testing/tune_onnx.py
index f8a3f1f0ca..8ae9ab1ed0 100644
--- a/python/tvm/meta_schedule/testing/tune_onnx.py
+++ b/python/tvm/meta_schedule/testing/tune_onnx.py
@@ -96,7 +96,14 @@ def _parse_args():
         "--cpu-flush",
         type=lambda x: bool(strtobool(x)),
         required=True,
-        help="example: `True / False",
+        help="example: True / False",
+    )
+    args.add_argument(
+        "--adaptive-training",
+        type=lambda x: bool(strtobool(x)),
+        required=False,
+        help="example: True / False",
+        default=True,
     )
     parsed = args.parse_args()
     parsed.target = tvm.target.Target(parsed.target)
@@ -105,7 +112,7 @@ def _parse_args():
         tracker_host=parsed.rpc_host,
         tracker_port=parsed.rpc_port,
         tracker_key=parsed.rpc_key,
-        session_timeout_sec=3600,
+        session_timeout_sec=600,
     )
     return parsed
 
@@ -147,6 +154,7 @@ def main():
                 num_trials_per_iter=64,
                 max_trials_per_task=ARGS.num_trials,
                 max_trials_global=ARGS.num_trials,
+                adaptive_training=ARGS.adaptive_training,
             ),
             runner=runner,  # type: ignore
             work_dir=ARGS.work_dir,
diff --git a/python/tvm/meta_schedule/testing/tune_relay.py b/python/tvm/meta_schedule/testing/tune_relay.py
index bd235cf03d..daef48daa2 100644
--- a/python/tvm/meta_schedule/testing/tune_relay.py
+++ b/python/tvm/meta_schedule/testing/tune_relay.py
@@ -94,7 +94,14 @@ def _parse_args():
         "--cpu-flush",
         type=lambda x: bool(strtobool(x)),
         required=True,
-        help="example: `True / False",
+        help="example: True / False",
+    )
+    args.add_argument(
+        "--adaptive-training",
+        type=lambda x: bool(strtobool(x)),
+        required=False,
+        help="example: True / False",
+        default=True,
     )
     parsed = args.parse_args()
     parsed.target = tvm.target.Target(parsed.target)
@@ -103,7 +110,7 @@ def _parse_args():
         tracker_host=parsed.rpc_host,
         tracker_port=parsed.rpc_port,
         tracker_key=parsed.rpc_key,
-        session_timeout_sec=3600,
+        session_timeout_sec=600,
     )
     return parsed
 
@@ -148,6 +155,7 @@ def main():
                 num_trials_per_iter=64,
                 max_trials_per_task=ARGS.num_trials,
                 max_trials_global=ARGS.num_trials,
+                adaptive_training=ARGS.adaptive_training,
             ),
             runner=runner,  # type: ignore
             work_dir=ARGS.work_dir,
diff --git a/python/tvm/meta_schedule/testing/tune_te.py b/python/tvm/meta_schedule/testing/tune_te.py
index bd0a1d9b68..e579c561ad 100644
--- a/python/tvm/meta_schedule/testing/tune_te.py
+++ b/python/tvm/meta_schedule/testing/tune_te.py
@@ -83,7 +83,14 @@ def _parse_args():
         "--cpu-flush",
         type=lambda x: bool(strtobool(x)),
         required=True,
-        help="example: `True / False",
+        help="example: True / False",
+    )
+    args.add_argument(
+        "--adaptive-training",
+        type=lambda x: bool(strtobool(x)),
+        required=False,
+        help="example: True / False",
+        default=True,
     )
     parsed = args.parse_args()
     parsed.target = tvm.target.Target(parsed.target)
@@ -125,6 +132,7 @@ def main():
                 num_trials_per_iter=64,
                 max_trials_per_task=ARGS.num_trials,
                 max_trials_global=ARGS.num_trials,
+                adaptive_training=ARGS.adaptive_training,
             ),
             runner=runner,  # type: ignore
             task_name=ARGS.workload,
diff --git a/python/tvm/meta_schedule/tune.py b/python/tvm/meta_schedule/tune.py
index d3c09b4129..fabf14ab23 100644
--- a/python/tvm/meta_schedule/tune.py
+++ b/python/tvm/meta_schedule/tune.py
@@ -78,6 +78,8 @@ class TuneConfig(NamedTuple):
         Configuration for search strategy.
     logger_config: Optional[Dict[str, Any]] = None
         Configuration for logger.
+    adaptive_training: Optional[bool] = None
+        Whether adaptive training is enabled for the cost model.
     """
 
     max_trials_global: int
@@ -88,6 +90,7 @@ class TuneConfig(NamedTuple):
     task_scheduler_config: Optional[Dict[str, Any]] = None
     search_strategy_config: Optional[Dict[str, Any]] = None
     logger_config: Optional[Dict[str, Any]] = None
+    adaptive_training: Optional[bool] = None
 
     def create_strategy(self):
         """Create search strategy from configuration"""
@@ -310,7 +313,7 @@ def tune_extracted_tasks(
     database = default_config.database(database, work_dir)
     builder = default_config.builder(builder)
     runner = default_config.runner(runner)
-    cost_model = default_config.cost_model(cost_model)
+    cost_model = default_config.cost_model(cost_model, config.adaptive_training)
     measure_callbacks = default_config.callbacks(measure_callbacks)
     # parse the tuning contexts
     tune_contexts = []
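
A hypothetical config showing the new field next to the ones used in the testing scripts above (values are illustrative; assumes the usual `tvm.meta_schedule` export of `TuneConfig`):

    from tvm.meta_schedule import TuneConfig

    config = TuneConfig(
        num_trials_per_iter=64,
        max_trials_per_task=1000,
        max_trials_global=1000,
        adaptive_training=False,  # opt out of adaptive cost-model retraining
    )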