You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by wu...@apache.org on 2022/09/12 23:06:22 UTC
[tvm] branch main updated: [AutoTVM] Fix `None` feature in AutoTVM tuning (#12760)
This is an automated email from the ASF dual-hosted git repository.
wuwei pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 4d2766409f [AutoTVM] Fix `None` feature in AutoTVM tuning (#12760)
4d2766409f is described below
commit 4d2766409f1b95504aac171649367c2df2813029
Author: Junru Shao <ju...@gmail.com>
AuthorDate: Mon Sep 12 15:06:16 2022 -0800
[AutoTVM] Fix `None` feature in AutoTVM tuning (#12760)
This PR introduces a couple of fixes to make AutoTVM work more
robustly:
- Fixed a very rare case in which `None` could appear in AutoTVM features;
- Fixed a misuse of `ARGS` in the testing script (it was referenced inside `_parse_args` before being assigned);
- Fixed the cache filename so that it includes the layout.
---
python/tvm/autotvm/testing/tune_relay.py | 13 +++++++------
python/tvm/autotvm/tuner/xgboost_cost_model.py | 7 +++----
python/tvm/meta_schedule/testing/relay_workload.py | 2 +-
3 files changed, 11 insertions(+), 11 deletions(-)
diff --git a/python/tvm/autotvm/testing/tune_relay.py b/python/tvm/autotvm/testing/tune_relay.py
index e474596374..743127ec1d 100644
--- a/python/tvm/autotvm/testing/tune_relay.py
+++ b/python/tvm/autotvm/testing/tune_relay.py
@@ -139,12 +139,6 @@ def _parse_args():
tracker_key=parsed.rpc_key,
session_timeout_sec=600,
)
- if ARGS.target.kind.name != "llvm" and ARGS.graph_tuner:
- raise ValueError("GraphTuner only supports llvm target")
- if ARGS.target.kind.name != "llvm" and ARGS.cpu_flush:
- raise ValueError("cpu_flush only supports llvm target")
- if ARGS.target.kind.name == "llvm" and not ARGS.cpu_flush:
- warnings.warn("cpu_flush is not enabled for llvm target")
return parsed
@@ -152,6 +146,13 @@ ARGS = _parse_args()
def main():
+ if ARGS.target.kind.name != "llvm" and ARGS.graph_tuner:
+ raise ValueError("GraphTuner only supports llvm target")
+ if ARGS.target.kind.name != "llvm" and ARGS.cpu_flush:
+ raise ValueError("cpu_flush only supports llvm target")
+ if ARGS.target.kind.name == "llvm" and not ARGS.cpu_flush:
+ warnings.warn("cpu_flush is not enabled for llvm target")
+
log_file = os.path.join(ARGS.work_dir, f"{ARGS.workload}.json")
graph_opt_sch_file = os.path.join(ARGS.work_dir, f"{ARGS.workload}_graph_opt.log")
measure_option = autotvm.measure_option(
diff --git a/python/tvm/autotvm/tuner/xgboost_cost_model.py b/python/tvm/autotvm/tuner/xgboost_cost_model.py
index d4942ce6a4..6fa04f336f 100644
--- a/python/tvm/autotvm/tuner/xgboost_cost_model.py
+++ b/python/tvm/autotvm/tuner/xgboost_cost_model.py
@@ -21,12 +21,11 @@ import logging
import time
import numpy as np
-
from tvm.contrib.popen_pool import PopenPoolExecutor, StatusKind
from .. import feature
from ..utils import get_rank
-from .metric import max_curve, recall_curve, cover_curve
+from .metric import cover_curve, max_curve, recall_curve
from .model_based_tuner import CostModel, FeatureCache
xgb = None
@@ -346,7 +345,7 @@ class XGBoostCostModel(CostModel):
ret = np.empty((len(indexes), feature_len), dtype=np.float32)
for i, ii in enumerate(indexes):
t = fea_cache[ii]
- if t.shape[0] < feature_len:
+ if t is not None and t.shape[0] < feature_len:
t = np.pad(t, (0, feature_len - t.shape[0]))
ret[i, :] = t if t is not None else 0
return ret
@@ -449,8 +448,8 @@ def custom_callback(
):
"""callback function for xgboost to support multiple custom evaluation functions"""
# pylint: disable=import-outside-toplevel
- from xgboost.core import EarlyStopException
from xgboost.callback import _fmt_metric
+ from xgboost.core import EarlyStopException
try:
from xgboost.training import aggcv
diff --git a/python/tvm/meta_schedule/testing/relay_workload.py b/python/tvm/meta_schedule/testing/relay_workload.py
index f4f6336df3..98bb995120 100644
--- a/python/tvm/meta_schedule/testing/relay_workload.py
+++ b/python/tvm/meta_schedule/testing/relay_workload.py
@@ -230,7 +230,7 @@ def get_network(
inputs: Tuple[str, List[int], str]
params_bytearray: bytearray
- filename = f'relay-{name}-{",".join(str(i) for i in input_shape)}.json'
+ filename = f'relay-{name}-{layout}-{",".join(str(i) for i in input_shape)}.json'
cached = _load_cache(cache_dir, filename)
if cached is None:
with multiprocessing.Pool(processes=1) as pool: