You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by ju...@apache.org on 2022/09/30 05:54:08 UTC

[tvm] branch main updated: [MetaSchedule] Fix XGBoost Import Issue (#12936)

This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 4e4089edda [MetaSchedule] Fix XGBoost Import Issue (#12936)
4e4089edda is described below

commit 4e4089edda7f3cd888178f4ad325d7824717ce8e
Author: Xiyou Zhou <xi...@octoml.ai>
AuthorDate: Thu Sep 29 22:54:01 2022 -0700

    [MetaSchedule] Fix XGBoost Import Issue (#12936)
    
    The previous upgrade introduced an import of xgboost in meta_schedule; this version removes it by using a function that returns the callback class.
    
    We recently introduced an XGBoost model upgrade to support the callback class of newer xgboost versions in https://github.com/apache/tvm/pull/12141. However, that PR uses a function called `optional_xgboost_callback` to avoid a compatibility issue (xgboost 1.5.2 vs. 1.6.0). This function tries to import the newly introduced xgboost callback class and creates a new class that uses it as a base class. As a result, xgboost was imported whenever meta_schedule was imported, whi [...]
---
 python/tvm/meta_schedule/cost_model/xgb_model.py   | 348 +++++++++++----------
 .../unittest/test_meta_schedule_cost_model.py      |  45 ++-
 2 files changed, 214 insertions(+), 179 deletions(-)

diff --git a/python/tvm/meta_schedule/cost_model/xgb_model.py b/python/tvm/meta_schedule/cost_model/xgb_model.py
index 1171e081b9..59774b534e 100644
--- a/python/tvm/meta_schedule/cost_model/xgb_model.py
+++ b/python/tvm/meta_schedule/cost_model/xgb_model.py
@@ -22,7 +22,7 @@ import os
 import tempfile
 from collections import OrderedDict
 from itertools import chain as itertools_chain
-from typing import TYPE_CHECKING, Any, Callable, Dict, List, NamedTuple, Optional, Tuple
+from typing import TYPE_CHECKING, Any, Dict, List, NamedTuple, Optional, Tuple, Callable
 
 import numpy as np  # type: ignore
 
@@ -36,26 +36,10 @@ from ..utils import cpu_count, derived_object, shash2hex
 from .metric import max_curve
 
 
-def optional_xgboost_callback(cls):
-    """Decorator for importing TraningCallback from xgboost"""
-    # pylint:disable = import-outside-toplevel
-    try:
-        from xgboost.callback import TrainingCallback  # type: ignore
-    # pylint:enable = import-outside-toplevel
-    except ImportError:
-
-        class TrainingCallback:  # type: ignore
-            pass
-
-    class OptXGBoostCustomCallback(cls, TrainingCallback):  # type: ignore
-        pass
-
-    return OptXGBoostCustomCallback
-
-
 if TYPE_CHECKING:
 
     import xgboost as xgb  # type: ignore
+    from xgboost.callback import TrainingCallback  # type: ignore
 
     from ..tune_context import TuneContext
 
@@ -346,7 +330,7 @@ class XGBModel(PyCostModel):
         extractor: FeatureExtractor,
         # xgboost model config
         config: XGBConfig = XGBConfig(),
-        # behavior of randomness
+        # random result before enough samples
         num_warmup_samples: int = 100,
         # evaluation
         early_stopping_rounds: int = 50,
@@ -598,7 +582,7 @@ class XGBModel(PyCostModel):
             num_boost_round=10000,
             obj=obj,
             callbacks=[
-                XGBoostCustomCallback(
+                _get_custom_call_back(
                     early_stopping_rounds=self.early_stopping_rounds,
                     verbose_eval=self.verbose_eval,
                     fevals=[rmse, avg_peak_score],
@@ -657,158 +641,194 @@ class XGBModel(PyCostModel):
         return eval_result
 
 
-@optional_xgboost_callback
-class XGBoostCustomCallback:
-    """Custom callback class for xgboost to support multiple custom evaluation functions"""
-
-    def __init__(
-        self,
-        early_stopping_rounds: int,
-        verbose_eval: int,
-        fevals: List[Callable],
-        evals: List[Tuple["xgb.DMatrix", str]],
-        focused_metric: str = "tr-p-rmse",
-        cvfolds: List["xgb.training.CVPack"] = None,
-    ):
-        self.early_stopping_rounds = early_stopping_rounds
-        self.verbose_eval = verbose_eval
-        self.fevals = fevals
-        self.evals = evals
-        self.state: Dict[str, Any] = {}
-        self.focused_metric = focused_metric
-        self.sort_key = make_metric_sorter(focused_metric=focused_metric)
-        self.cvfolds = cvfolds
-        if cvfolds is not None:
-            self.aggregated_cv = None
-
-    def __call__(self, env: "xgb.core.CallbackEnv"):
-        # Compatibility with xgboost < 1.3
-        return self.after_iteration(env.model, env.iteration, env.evaluation_result_list)
-
-    def init(self, model: "xgb.Booster"):
-        """Internal function for intialization"""
-        booster: "xgb.Booster" = model
-        self.state["best_iteration"] = 0
-        self.state["best_score"] = float("inf")
-        if booster is None:
-            assert self.cvfolds is not None
-            return
-        if booster.attr("best_score") is not None:
-            self.state["best_score"] = float(booster.attr("best_score"))
-            self.state["best_iteration"] = int(booster.attr("best_iteration"))
-            self.state["best_msg"] = booster.attr("best_msg")
-        else:
-            booster.set_attr(best_iteration=str(self.state["best_iteration"]))
-            booster.set_attr(best_score=str(self.state["best_score"]))
+def _get_custom_call_back(
+    early_stopping_rounds: int,
+    verbose_eval: int,
+    fevals: List[Callable],
+    evals: List[Tuple["xgb.DMatrix", str]],
+    focused_metric: str = "tr-p-rmse",
+    cvfolds: List["xgb.training.CVPack"] = None,
+) -> "TrainingCallback":
+    """Get a customized callback function for XGBoost. Work around xgboost import."""
 
-    def after_iteration(
-        self, model: "xgb.Booster", epoch: int, evals_log: Dict
-    ):  # pylint: disable = unused-argument
-        """Internal function for after_iteration"""
+    def optional_xgboost_callback(cls):
+        """Decorator for importing TraningCallback from xgboost"""
         # pylint:disable = import-outside-toplevel
         try:
-            from xgboost.callback import _fmt_metric  # type: ignore
+            from xgboost.callback import TrainingCallback  # type: ignore
+        # pylint:enable = import-outside-toplevel
         except ImportError:
-            # Compatibility with xgboost >= 1.6
 
-            def _fmt_metric(value, show_stdv=True):
-                if len(value) == 2:
-                    return f"{value[0]}:{value[1]:.5f}"
-                if len(value) == 3:
-                    if show_stdv:
-                        return f"{value[0]}:{value[1]:.5f}+{value[2]:.5f}"
-                    return f"{value[0]}:{value[1]:.5f}"
-                raise ValueError("wrong metric value", value)
+            class TrainingCallback:  # type: ignore
+                pass
 
-        import xgboost as xgb
-        from xgboost import rabit  # type: ignore
+        class OptXGBoostCustomCallback(cls, TrainingCallback):  # type: ignore
+            pass
 
-        try:
-            from xgboost.training import aggcv  # type: ignore
-        except ImportError:
-            from xgboost.callback import _aggcv as aggcv  # type: ignore
+        return OptXGBoostCustomCallback
 
-        # pylint:enable = import-outside-toplevel
-        if not self.state:
-            self.init(model)
-        booster: xgb.Booster = model
-        iteration: int = epoch
-        cvfolds: List[xgb.training.CVPack] = self.cvfolds
-        ##### Evaluation #####
-        # `eval_result` is a list of (key, score)
-        eval_result: List[Tuple[str, float]] = []
-        if cvfolds is None:
-            eval_result = list(
-                itertools_chain.from_iterable(
-                    [
-                        (key, float(value))
-                        for key, value in map(
-                            lambda x: x.split(":"),
-                            booster.eval_set(
-                                evals=self.evals,
-                                iteration=iteration,
-                                feval=feval,
-                            ).split()[1:],
-                        )
-                    ]
-                    for feval in self.fevals
-                )
-            )
-        else:
-            eval_result = list(
-                itertools_chain.from_iterable(
-                    [
-                        (key, score)
-                        for key, score, _std in aggcv(
-                            fold.eval(
-                                iteration=iteration,
-                                feval=feval,
+    @optional_xgboost_callback
+    class XGBoostCustomCallback:
+        """Custom callback class for xgboost to support multiple custom evaluation functions"""
+
+        def __init__(
+            self,
+            early_stopping_rounds: int,
+            verbose_eval: int,
+            fevals: List[Callable],
+            evals: List[Tuple["xgb.DMatrix", str]],
+            focused_metric: str = "tr-p-rmse",
+            cvfolds: List["xgb.training.CVPack"] = None,
+        ):
+            self.early_stopping_rounds = early_stopping_rounds
+            self.verbose_eval = verbose_eval
+            self.fevals = fevals
+            self.evals = evals
+            self.state: Dict[str, Any] = {}
+            self.focused_metric = focused_metric
+            self.sort_key = make_metric_sorter(focused_metric=focused_metric)
+            self.cvfolds = cvfolds
+            if cvfolds is not None:
+                self.aggregated_cv = None
+
+        def __call__(self, env: "xgb.core.CallbackEnv"):
+            # Compatibility with xgboost < 1.3
+            return self.after_iteration(env.model, env.iteration, env.evaluation_result_list)
+
+        def init(self, model: "xgb.Booster"):
+            """Internal function for intialization"""
+            booster: "xgb.Booster" = model
+            self.state["best_iteration"] = 0
+            self.state["best_score"] = float("inf")
+            if booster is None:
+                assert self.cvfolds is not None
+                return
+            if booster.attr("best_score") is not None:
+                self.state["best_score"] = float(booster.attr("best_score"))
+                self.state["best_iteration"] = int(booster.attr("best_iteration"))
+                self.state["best_msg"] = booster.attr("best_msg")
+            else:
+                booster.set_attr(best_iteration=str(self.state["best_iteration"]))
+                booster.set_attr(best_score=str(self.state["best_score"]))
+
+        def after_iteration(
+            self, model: "xgb.Booster", epoch: int, evals_log: Dict
+        ):  # pylint: disable = unused-argument
+            """Internal function for after_iteration"""
+            # pylint:disable = import-outside-toplevel
+            try:
+                from xgboost.callback import _fmt_metric  # type: ignore
+            except ImportError:
+                # Compatibility with xgboost >= 1.6
+
+                def _fmt_metric(value, show_stdv=True):
+                    if len(value) == 2:
+                        return f"{value[0]}:{value[1]:.5f}"
+                    if len(value) == 3:
+                        if show_stdv:
+                            return f"{value[0]}:{value[1]:.5f}+{value[2]:.5f}"
+                        return f"{value[0]}:{value[1]:.5f}"
+                    raise ValueError("wrong metric value", value)
+
+            import xgboost as xgb
+            from xgboost import rabit  # type: ignore
+
+            try:
+                from xgboost.training import aggcv  # type: ignore
+            except ImportError:
+                from xgboost.callback import _aggcv as aggcv  # type: ignore
+
+            # pylint:enable = import-outside-toplevel
+            if not self.state:
+                self.init(model)
+            booster: xgb.Booster = model
+            iteration: int = epoch
+            cvfolds: List[xgb.training.CVPack] = self.cvfolds
+            ##### Evaluation #####
+            # `eval_result` is a list of (key, score)
+            eval_result: List[Tuple[str, float]] = []
+            if cvfolds is None:
+                eval_result = list(
+                    itertools_chain.from_iterable(
+                        [
+                            (key, float(value))
+                            for key, value in map(
+                                lambda x: x.split(":"),
+                                booster.eval_set(
+                                    evals=self.evals,
+                                    iteration=iteration,
+                                    feval=feval,
+                                ).split()[1:],
                             )
-                            for fold in cvfolds
-                        )
-                    ]
-                    for feval in self.fevals
+                        ]
+                        for feval in self.fevals
+                    )
                 )
-            )
-        eval_result = list(eval_result)
-        eval_result.sort(key=self.sort_key)
-
-        ##### Print eval result #####
-        if self.verbose_eval and iteration % self.verbose_eval == 0:
-            info = []
-            for key, score in eval_result:
-                if "null" not in key:
-                    info.append(f"{key}: {score:.6f}")
-            logger.debug("XGB iter %3d: %s", iteration, "\t".join(info))
-
-        ##### Choose score and do early stopping #####
-        score = None
-        for key, _score in eval_result:
-            if key == self.focused_metric:
-                score = _score
-                break
-        assert score is not None
-
-        best_score = self.state["best_score"]
-        best_iteration = self.state["best_iteration"]
-        if score < best_score:
-            tab = "\t"  # to work with f-string
-            msg = f"[{epoch}] {tab.join([_fmt_metric(x) for x in eval_result])}"
-            self.state["best_msg"] = msg
-            self.state["best_score"] = score
-            self.state["best_iteration"] = epoch
-            # save the property to attributes, so they will occur in checkpoint.
-            if model is not None:
-                model.set_attr(
-                    best_score=str(self.state["best_score"]),
-                    best_iteration=str(self.state["best_iteration"]),
-                    best_msg=self.state["best_msg"],
+            else:
+                eval_result = list(
+                    itertools_chain.from_iterable(
+                        [
+                            (key, score)
+                            for key, score, _std in aggcv(
+                                fold.eval(
+                                    iteration=iteration,
+                                    feval=feval,
+                                )
+                                for fold in cvfolds
+                            )
+                        ]
+                        for feval in self.fevals
+                    )
                 )
-        elif epoch - best_iteration >= self.early_stopping_rounds:
-            best_msg = self.state["best_msg"]
-
-            if self.verbose_eval and rabit.get_rank() == 0:
-                logger.debug("XGB stopped. Best iteration: %s ", best_msg)
-            return True  # instead of raising EarlyStopException, returning True to end the training
-        # False to indicate training should not stop.
-        return False
+            eval_result = list(eval_result)
+            eval_result.sort(key=self.sort_key)
+
+            ##### Print eval result #####
+            if self.verbose_eval and iteration % self.verbose_eval == 0:
+                info = []
+                for key, score in eval_result:
+                    if "null" not in key:
+                        info.append(f"{key}: {score:.6f}")
+                logger.debug("XGB iter %3d: %s", iteration, "\t".join(info))
+
+            ##### Choose score and do early stopping #####
+            score = None
+            for key, _score in eval_result:
+                if key == self.focused_metric:
+                    score = _score
+                    break
+            assert score is not None
+
+            best_score = self.state["best_score"]
+            best_iteration = self.state["best_iteration"]
+            if score < best_score:
+                tab = "\t"  # to work with f-string
+                msg = f"[{epoch}] {tab.join([_fmt_metric(x) for x in eval_result])}"
+                self.state["best_msg"] = msg
+                self.state["best_score"] = score
+                self.state["best_iteration"] = epoch
+                # save the property to attributes, so they will occur in checkpoint.
+                if model is not None:
+                    model.set_attr(
+                        best_score=str(self.state["best_score"]),
+                        best_iteration=str(self.state["best_iteration"]),
+                        best_msg=self.state["best_msg"],
+                    )
+            elif epoch - best_iteration >= self.early_stopping_rounds:
+                best_msg = self.state["best_msg"]
+
+                if self.verbose_eval and rabit.get_rank() == 0:
+                    logger.debug("XGB stopped. Best iteration: %s ", best_msg)
+                # instead of raising EarlyStopException, returning True to end the training
+                return True
+            # False to indicate training should not stop.
+            return False
+
+    return XGBoostCustomCallback(
+        early_stopping_rounds=early_stopping_rounds,
+        verbose_eval=verbose_eval,
+        fevals=fevals,
+        evals=evals,
+        focused_metric=focused_metric,
+        cvfolds=cvfolds,
+    )
diff --git a/tests/python/unittest/test_meta_schedule_cost_model.py b/tests/python/unittest/test_meta_schedule_cost_model.py
index 94b7bce246..c47897eabb 100644
--- a/tests/python/unittest/test_meta_schedule_cost_model.py
+++ b/tests/python/unittest/test_meta_schedule_cost_model.py
@@ -15,27 +15,27 @@
 # specific language governing permissions and limitations
 # under the License.
 # pylint: disable=missing-docstring
+from typing import List
+
 import os
 import re
 import shutil
-import sys
 import tempfile
-from typing import List
-
+from functools import partial
+import unittest
 import numpy as np
-import pytest
+
 import tvm
 import tvm.testing
+from tvm.script import tir as T
+from tvm.tir.schedule.schedule import Schedule
 from tvm.meta_schedule.cost_model import PyCostModel, RandomModel, XGBModel
-from tvm.meta_schedule.cost_model.xgb_model import XGBoostCustomCallback, PackSum
+from tvm.meta_schedule.cost_model.xgb_model import _get_custom_call_back, PackSum
 from tvm.meta_schedule.feature_extractor import RandomFeatureExtractor
 from tvm.meta_schedule.runner import RunnerResult
 from tvm.meta_schedule.search_strategy import MeasureCandidate
 from tvm.meta_schedule.tune_context import TuneContext
 from tvm.meta_schedule.utils import derived_object
-from tvm.script import tir as T
-from tvm.tir.schedule.schedule import Schedule
-
 
 # pylint: disable=invalid-name,no-member,line-too-long,too-many-nested-blocks,missing-docstring
 @tvm.script.ir_module
@@ -196,13 +196,15 @@ def test_meta_schedule_xgb_model_reload():
     assert (res1 == res2).all()
     assert old_data_size == new_data_size
     assert len(old_data) == len(new_data)
-    for (k1, g1), (k2, g2) in zip(old_data.items(), new_data.items()):
+    for (k1, g1), (k2, g2) in zip(  # pylint: disable=invalid-name
+        old_data.items(), new_data.items()
+    ):
         assert k1 == k2
         assert k1 == g1.group_hash
         assert k2 == g2.group_hash
         assert (g1.costs == g2.costs).all()
         assert len(g1.features) == len(g2.features)
-        for f1, f2 in zip(g1.features, g2.features):
+        for f1, f2 in zip(g1.features, g2.features):  # pylint: disable=invalid-name
             assert (f1 == f2).all()
 
 
@@ -229,10 +231,23 @@ def test_meta_schedule_xgb_model_reupdate():
     model.predict(TuneContext(), [_dummy_candidate() for i in range(predict_sample_count)])
 
 
-def test_meta_schedule_xgb_model_callback():
+def xgb_version_check():
+
+    # pylint: disable=import-outside-toplevel
+    import xgboost as xgb
+    from packaging import version
+
+    # pylint: enable=import-outside-toplevel
+    return version.parse(xgb.__version__) >= version.parse("1.6.0")
+
+
+@unittest.skipIf(xgb_version_check(), "test not supported for xgboost version after 1.6.0")
+def test_meta_schedule_xgb_model_callback_as_function():
+    # pylint: disable=import-outside-toplevel
     import xgboost as xgb
     from itertools import chain as itertools_chain
-    from functools import partial
+
+    # pylint: enable=import-outside-toplevel
 
     extractor = RandomFeatureExtractor()
     model = XGBModel(extractor=extractor, num_warmup_samples=10)
@@ -252,7 +267,7 @@ def test_meta_schedule_xgb_model_callback():
         model.save(path.name)
 
         old_booster = model.booster
-        xs = [
+        xs = [  # pylint: disable=invalid-name
             x.numpy().astype("float32")
             for x in extractor.extract_from(
                 TuneContext(),
@@ -289,7 +304,7 @@ def test_meta_schedule_xgb_model_callback():
             obj=obj,
             callbacks=[
                 partial(
-                    XGBoostCustomCallback(
+                    _get_custom_call_back(
                         early_stopping_rounds=model.early_stopping_rounds,
                         verbose_eval=model.verbose_eval,
                         fevals=[rmse, avg_peak_score],
@@ -300,7 +315,7 @@ def test_meta_schedule_xgb_model_callback():
             ],
         )
 
-        xs = [
+        xs = [  # pylint: disable=invalid-name
             x.numpy().astype("float32")
             for x in extractor.extract_from(
                 TuneContext(),