Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2022/07/21 06:01:11 UTC

[GitHub] [tvm] shingjan commented on a diff in pull request #12144: [Auto Scheduler] Upgrade autoscheduler xgboost callback

shingjan commented on code in PR #12144:
URL: https://github.com/apache/tvm/pull/12144#discussion_r926281877


##########
python/tvm/auto_scheduler/cost_model/xgb_model.py:
##########
@@ -539,125 +539,128 @@ def feval(preds, labels):
     return feval
 
 
-def custom_callback(
-    stopping_rounds,
-    metric,
-    fevals,
-    evals=(),
-    log_file=None,
-    maximize=False,
-    verbose_eval=True,
-    skip_every=2,
-):
-    """Callback function for xgboost to support multiple custom evaluation functions"""
-    # pylint: disable=import-outside-toplevel
-    from xgboost.core import EarlyStopException
-    from xgboost.callback import _fmt_metric
-
-    try:
-        from xgboost.training import aggcv
-    except ImportError:
-        from xgboost.callback import _aggcv as aggcv
-
-    state = {}
-    metric_shortname = metric.split("-")[1]
-
-    def init(env):
-        """internal function"""
-        bst = env.model
-
-        state["maximize_score"] = maximize
-        state["best_iteration"] = 0
-        if maximize:
-            state["best_score"] = float("-inf")
-        else:
-            state["best_score"] = float("inf")
+class CustomCallback(callback.TrainingCallback):

Review Comment:
   `TrainingCallback` could have a different implementation across xgboost versions, and backwards compatibility would be appreciated here in case `xgboost` < 1.6.0 is used. You can refer to the change [here](https://github.com/apache/tvm/pull/12141) in meta schedule.

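A rough sketch of the kind of version gate this suggestion points at, assuming the module keeps both the new `CustomCallback` class and the legacy function-style `custom_callback`; the helper name `make_early_stop_callback` and the use of `packaging` are illustrative assumptions, not the actual change made in PR #12141:

```python
# Illustrative only: choose between the new class-based TrainingCallback and
# the legacy function-style callback based on the installed xgboost version.
from packaging import version  # assumption: packaging is available

import xgboost as xgb


def make_early_stop_callback(**kwargs):
    """Return an early-stopping callback compatible with the installed xgboost."""
    if version.parse(xgb.__version__) >= version.parse("1.6.0"):
        # xgboost 1.6.0 removed EarlyStopException and _fmt_metric, so only the
        # TrainingCallback subclass (CustomCallback in this module) works here.
        return CustomCallback(**kwargs)
    # Older releases still accept the plain function callback kept around for
    # backwards compatibility (custom_callback in this module).
    return custom_callback(**kwargs)
```

Either flavour can then be passed to `xgb.train` through `callbacks=[...]`, so callers would not need to know which xgboost version is installed.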


##########
python/tvm/auto_scheduler/cost_model/xgb_model.py:
##########
@@ -539,125 +539,128 @@ def feval(preds, labels):
     return feval
 
 
-def custom_callback(
-    stopping_rounds,
-    metric,
-    fevals,
-    evals=(),
-    log_file=None,
-    maximize=False,
-    verbose_eval=True,
-    skip_every=2,
-):
-    """Callback function for xgboost to support multiple custom evaluation functions"""
-    # pylint: disable=import-outside-toplevel
-    from xgboost.core import EarlyStopException
-    from xgboost.callback import _fmt_metric
-
-    try:
-        from xgboost.training import aggcv
-    except ImportError:
-        from xgboost.callback import _aggcv as aggcv
-
-    state = {}
-    metric_shortname = metric.split("-")[1]
-
-    def init(env):
-        """internal function"""
-        bst = env.model
-
-        state["maximize_score"] = maximize
-        state["best_iteration"] = 0
-        if maximize:
-            state["best_score"] = float("-inf")
-        else:
-            state["best_score"] = float("inf")
+class CustomCallback(callback.TrainingCallback):
+    """
+    Callback function for xgboost.
+    Support custom evaluation function and early-stopping.
+    """
 
-        if bst is not None:
-            if bst.attr("best_score") is not None:
-                state["best_score"] = float(bst.attr("best_score"))
-                state["best_iteration"] = int(bst.attr("best_iteration"))
-                state["best_msg"] = bst.attr("best_msg")
+    def __init__(
+        self,
+        stopping_rounds,
+        metric,
+        fevals,
+        evals=(),
+        log_file=None,
+        maximize=False,
+        verbose_eval=True,
+        skip_every=2,
+    ):
+        """Init function"""
+        self.stopping_rounds = stopping_rounds
+        self.metric = metric
+        self.metric_shortname = metric.split("-")[1]
+        self.fevals = fevals
+        self.evals = evals
+        self.log_file = log_file
+        self.maximize = maximize
+        self.verbose_eval = verbose_eval
+        self.skip_every = skip_every
+        self.state = {}
+
+    def after_iteration(self, model, epoch, _evals_log):
+        """Run after each iteration.  Return True when training should stop."""
+        ##### init state #####
+        if not self.state:
+            self.state["maximize_score"] = self.maximize
+            self.state["best_iteration"] = 0
+            if self.maximize:
+                self.state["best_score"] = float("-inf")
             else:
-                bst.set_attr(best_iteration=str(state["best_iteration"]))
-                bst.set_attr(best_score=str(state["best_score"]))
-        else:
-            assert env.cvfolds is not None
-
-    def callback(env):
-        """internal function"""
-        if not state:
-            init(env)
-
-        bst = env.model
-        i = env.iteration
-        cvfolds = env.cvfolds
+                self.state["best_score"] = float("inf")
 
+            assert model is not None
+            if model.attr("best_score") is not None:
+                self.state["best_score"] = float(model.attr("best_score"))
+                self.state["best_iteration"] = int(model.attr("best_iteration"))
+                self.state["best_msg"] = model.attr("best_msg")
+            else:
+                model.set_attr(best_iteration=str(self.state["best_iteration"]))
+                model.set_attr(best_score=str(self.state["best_score"]))
         res_dict = {}
 
-        if i % skip_every == 1:
-            return
+        if epoch % self.skip_every == 1:
+            return False
 
         ##### evaluation #####
-        if cvfolds is not None:
-            for feval in fevals:
-                tmp = aggcv([f.eval(i, feval) for f in cvfolds])
-                for k, mean, std in tmp:
-                    res_dict[k] = [mean, std]
-        else:
-            for feval in fevals:
-                bst_eval = bst.eval_set(evals, i, feval)
-                res = [x.split(":") for x in bst_eval.split()]
-                for kv in res[1:]:
-                    res_dict[kv[0]] = [float(kv[1])]
+        for feval in self.fevals:
+            bst_eval = model.eval_set(self.evals, epoch, feval)
+            res = [x.split(":") for x in bst_eval.split()]
+            for kv in res[1:]:
+                res_dict[kv[0]] = [float(kv[1])]
 
         eval_res = []
         keys = list(res_dict.keys())
-        keys.sort(key=lambda x: x if metric_shortname not in x else "a" + x)
+        keys.sort(key=lambda x: x if self.metric_shortname not in x else "a" + x)
         for key in keys:
             v = res_dict[key]
             eval_res.append([key] + v)
 
         ##### print eval result #####
-        if not isinstance(verbose_eval, bool) and verbose_eval and i % verbose_eval == 0:
-            infos = ["XGB iter: %3d" % i]
+        if (
+            not isinstance(self.verbose_eval, bool)
+            and self.verbose_eval
+            and epoch % self.verbose_eval == 0
+        ):
+            infos = ["XGB iter: %3d" % epoch]
             for item in eval_res:
                 if "null" in item[0]:
                     continue
                 infos.append("%s: %.6f" % (item[0], item[1]))
 
             logger.debug("\t".join(infos))
-            if log_file:
-                with open(log_file, "a") as fout:
+            if self.log_file:
+                with open(self.log_file, "a") as fout:
                     fout.write("\t".join(infos) + "\n")
 
         ##### choose score and do early stopping #####
         score = None
         for item in eval_res:
-            if item[0] == metric:
+            if item[0] == self.metric:
                 score = item[1]
                 break
         assert score is not None
 
-        best_score = state["best_score"]
-        best_iteration = state["best_iteration"]
-        maximize_score = state["maximize_score"]
+        best_score = self.state["best_score"]
+        best_iteration = self.state["best_iteration"]
+        maximize_score = self.state["maximize_score"]
+
         if (maximize_score and score > best_score) or (not maximize_score and score < best_score):
-            msg = "[%d] %s" % (env.iteration, "\t".join([_fmt_metric(x) for x in eval_res]))
-            state["best_msg"] = msg
-            state["best_score"] = score
-            state["best_iteration"] = env.iteration
+            msg = "[%d] %s" % (epoch, "\t".join([self._fmt_metric(x) for x in eval_res]))
+            self.state["best_msg"] = msg
+            self.state["best_score"] = score
+            self.state["best_iteration"] = epoch
             # save the property to attributes, so they will occur in checkpoint.
-            if env.model is not None:
-                env.model.set_attr(
-                    best_score=str(state["best_score"]),
-                    best_iteration=str(state["best_iteration"]),
-                    best_msg=state["best_msg"],
+            if model is not None:
+                model.set_attr(
+                    best_score=str(self.state["best_score"]),
+                    best_iteration=str(self.state["best_iteration"]),
+                    best_msg=self.state["best_msg"],
                 )
-        elif env.iteration - best_iteration >= stopping_rounds:
-            best_msg = state["best_msg"]
-            if verbose_eval and env.rank == 0:
+        elif epoch - best_iteration >= self.stopping_rounds:
+            best_msg = self.state["best_msg"]
+            if self.verbose_eval:
                 logger.debug("XGB stopped. Best iteration: %s ", best_msg)
-            raise EarlyStopException(best_iteration)
-
-    return callback
+            return True
+
+        return False
+
+    def _fmt_metric(self, value, show_stdv=True):

Review Comment:
   `_fmt_metric` can be imported from xgboost < 1.6.0. It would be better to try importing it first and only define it locally if an `ImportError` is raised.

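A sketch of the try-import pattern being suggested here; the fallback body mirrors the pre-1.6.0 xgboost implementation of `_fmt_metric` as I understand it, and is an assumption rather than the PR's final code:

```python
# Illustrative only: prefer the helper bundled with xgboost < 1.6.0 and fall
# back to a local definition once it has been removed upstream.
try:
    from xgboost.callback import _fmt_metric  # removed in xgboost 1.6.0
except ImportError:

    def _fmt_metric(value, show_stdv=True):
        """Format an evaluation result tuple (name, mean[, std]) as a string."""
        if len(value) == 2:
            return "%s:%g" % (value[0], value[1])
        if len(value) == 3:
            if show_stdv:
                return "%s:%g+%g" % (value[0], value[1], value[2])
            return "%s:%g" % (value[0], value[1])
        raise ValueError("wrong metric value", value)
```

With that in place, the `self._fmt_metric(x)` call in `after_iteration` could simply delegate to the module-level `_fmt_metric`, whichever origin it has.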

