Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2022/07/25 20:52:04 UTC

[GitHub] [tvm] zxybazh commented on a diff in pull request #12141: [Meta Schedule][XGBoost] Update the custom callback function of xgboost in meta schedule

zxybazh commented on code in PR #12141:
URL: https://github.com/apache/tvm/pull/12141#discussion_r928046516


##########
python/tvm/meta_schedule/cost_model/xgb_model.py:
##########
@@ -35,6 +35,14 @@
 from ..utils import cpu_count, derived_object, shash2hex
 from .metric import max_curve
 
+try:
+    from xgboost.callback import TrainingCallback  # type: ignore

Review Comment:
   Please put this under `if TYPE_CHECKING:` because xgboost is not a dependency of tvm.
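   A minimal sketch of the suggested arrangement (the fallback base class and alias names here are illustrative, not from the PR): keep the import visible to type checkers only, and guard the runtime import so tvm does not gain a hard xgboost dependency.

   ```python
   from typing import TYPE_CHECKING

   if TYPE_CHECKING:
       # Seen only by static type checkers; never executed at runtime.
       from xgboost.callback import TrainingCallback

   try:
       # Runtime import, guarded because xgboost is optional.
       from xgboost.callback import TrainingCallback as _TrainingCallbackBase  # type: ignore
   except ImportError:
       class _TrainingCallbackBase:  # minimal stand-in when xgboost is absent
           """Fallback base so the module still imports without xgboost."""


   class XGBoostCallback(_TrainingCallbackBase):
       """Base class for XGBoost callbacks."""
   ```

   With this shape the module imports cleanly whether or not xgboost is installed, and annotations can still name `TrainingCallback` under `TYPE_CHECKING`.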



##########
python/tvm/meta_schedule/cost_model/xgb_model.py:
##########
@@ -763,3 +768,162 @@ def callback(env: "xgb.core.CallbackEnv"):
             raise EarlyStopException(best_iteration)
 
     return callback
+
+
+class XGBoostCallback(TrainingCallback):
+    """Base class for XGBoost callbacks."""
+
+    def __call__(self, env: "xgb.core.CallbackEnv"):
+        # Compatibility with xgboost < 1.3
+        return self.after_iteration(env.model, env.iteration, env.evaluation_result_list)
+
+    def after_iteration(self, model: "xgb.Booster", epoch: int, evals_log: Dict):
+        raise NotImplementedError
+
+
+class XGBoostCustomCallback(XGBoostCallback):
+    """Custom callback class for xgboost to support multiple custom evaluation functions"""
+
+    def __init__(
+        self,
+        early_stopping_rounds: int,
+        verbose_eval: int,
+        fevals: List[Callable],
+        evals: List[Tuple["xgb.DMatrix", str]],
+        focused_metric: str = "tr-p-rmse",
+        cvfolds: Optional[List["xgb.training.CVPack"]] = None,
+    ):
+        self.early_stopping_rounds = early_stopping_rounds
+        self.verbose_eval = verbose_eval
+        self.fevals = fevals
+        self.evals = evals
+        self.state: Dict[str, Any] = {}
+        self.focused_metric = focused_metric
+        self.sort_key = make_metric_sorter(focused_metric=focused_metric)
+        self.cvfolds = cvfolds
+        if cvfolds is not None:
+            self.aggregated_cv = None
+
+    def init(self, model: "xgb.Booster"):
+        """Internal function for initialization"""
+        booster: "xgb.Booster" = model
+        self.state["best_iteration"] = 0
+        self.state["best_score"] = float("inf")
+        if booster is None:
+            assert self.cvfolds is not None
+            return
+        if booster.attr("best_score") is not None:
+            self.state["best_score"] = float(booster.attr("best_score"))
+            self.state["best_iteration"] = int(booster.attr("best_iteration"))
+            self.state["best_msg"] = booster.attr("best_msg")
+        else:
+            booster.set_attr(best_iteration=str(self.state["best_iteration"]))
+            booster.set_attr(best_score=str(self.state["best_score"]))
+
+    def after_iteration(self, model: "xgb.Booster", epoch: int, evals_log: Dict):
+        """Internal function for after_iteration"""
+        # pylint:disable = import-outside-toplevel

Review Comment:
   Please re-enable this lint check later.



##########
python/tvm/meta_schedule/cost_model/xgb_model.py:
##########
@@ -573,22 +581,19 @@ def rmse(ys_pred: np.ndarray, d_train: "xgb.DMatrix"):  # type: ignore # pylint:
         def avg_peak_score(ys_pred: np.ndarray, d_train: "xgb.DMatrix"):  # type: ignore # pylint: disable = unused-argument
             return self.d_train.average_peak_score(ys_pred, self.average_peak_n)
 
+        xgb_custom_callback = XGBoostCustomCallback(
+            early_stopping_rounds=self.early_stopping_rounds,
+            verbose_eval=self.verbose_eval,
+            fevals=[rmse, avg_peak_score],
+            evals=[(self.d_train.dmatrix, "tr")],
+            cvfolds=None,
+        )
         self.booster = xgb.train(
             self.config.to_dict(),
             self.d_train.dmatrix,
             num_boost_round=10000,
             obj=obj,
-            callbacks=[
-                custom_callback(

Review Comment:
   Shall we also remove this function?
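   For context, the core early-stopping bookkeeping that both the old `custom_callback` function and the new `XGBoostCustomCallback` carry out (tracking `best_score`/`best_iteration` and stopping after `early_stopping_rounds` stagnant iterations) can be sketched as follows; class and method names are illustrative, not from the PR:

   ```python
   class EarlyStopper:
       """Minimal sketch of early-stopping state tracking."""

       def __init__(self, early_stopping_rounds: int):
           self.rounds = early_stopping_rounds
           self.best_score = float("inf")
           self.best_iteration = 0

       def update(self, iteration: int, score: float) -> bool:
           """Record one iteration's score; return True when training should stop."""
           if score < self.best_score:  # lower is better, e.g. rmse
               self.best_score = score
               self.best_iteration = iteration
               return False
           return iteration - self.best_iteration >= self.rounds
   ```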



##########
python/tvm/meta_schedule/cost_model/xgb_model.py:
##########
@@ -763,3 +768,162 @@ def callback(env: "xgb.core.CallbackEnv"):
             raise EarlyStopException(best_iteration)
 
     return callback
+
+
+class XGBoostCallback(TrainingCallback):
+    """Base class for XGBoost callbacks."""
+
+    def __call__(self, env: "xgb.core.CallbackEnv"):
+        # Compatibility with xgboost < 1.3
+        return self.after_iteration(env.model, env.iteration, env.evaluation_result_list)

Review Comment:
   Thanks for adding the compatibility support here. It would be great if we could:
   1. Add one unit test in CI that specifically exercises this legacy API, and another that specifically exercises the new API.
   2. Throw a warning when the installed xgboost version is too low, noting that this API will be deprecated in a future release.
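   For point 2, a hedged sketch of such a version warning (the 1.3 threshold comes from the compatibility comment in the diff; the helper name and message wording are illustrative):

   ```python
   import warnings


   def warn_if_old_xgboost(version: str) -> bool:
       """Warn and return True when the given xgboost version predates 1.3."""
       major, minor = (int(part) for part in version.split(".")[:2])
       if (major, minor) < (1, 3):
           warnings.warn(
               "xgboost < 1.3 uses the legacy callback API; "
               "support for it will be deprecated in a future release.",
               DeprecationWarning,
           )
           return True
       return False
   ```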



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org