You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by xi...@apache.org on 2022/07/07 19:44:24 UTC
[tvm] branch main updated: [MetaSchedule][Minor] Stability Improvements (#12014)
This is an automated email from the ASF dual-hosted git repository.
xiyou pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 013d5e8fcb [MetaSchedule][Minor] Stability Improvements (#12014)
013d5e8fcb is described below
commit 013d5e8fcbd94fb3a0c5c0cdcaea03af43c464aa
Author: Xiyou Zhou <xi...@octoml.ai>
AuthorDate: Thu Jul 7 12:44:18 2022 -0700
[MetaSchedule][Minor] Stability Improvements (#12014)
* Fix tuning util for uint8.
* Change to check runner_result.
* Revert change to let cost model learn.
---
python/tvm/meta_schedule/testing/tune_utils.py | 26 +++++++++++++-------------
1 file changed, 13 insertions(+), 13 deletions(-)
diff --git a/python/tvm/meta_schedule/testing/tune_utils.py b/python/tvm/meta_schedule/testing/tune_utils.py
index aad8496a46..fe0984d51c 100644
--- a/python/tvm/meta_schedule/testing/tune_utils.py
+++ b/python/tvm/meta_schedule/testing/tune_utils.py
@@ -48,21 +48,21 @@ def generate_input_data(
"""
if input_dtype.startswith("float"):
return np.random.uniform(size=input_shape).astype(input_dtype)
- if input_dtype in ["uint8", "int8"]:
- return np.random.randint(
- low=0,
- high=127,
- size=input_shape,
- dtype="int32", # TODO(zxybazh): fix the datatype when int8 / uint8 is supported better
+ if low is None or high is None:
+ warnings.warn(
+ f"Model input value range for shape {input_shape} of {input_dtype} is not set!"
)
- if input_dtype in ["int32", "int64"]:
- if low is None or high is None:
- warnings.warn(
- "Model input value range for shape {input_shape} of {input_dtype} is not set!"
- )
+ range_map = {
+ "uint8": (0, 255),
+ "int8": (-128, 127),
+ "int32": (0, 10000),
+ "int64": (0, 10000),
+ }
+ if input_dtype in range_map:
+ _low, _high = range_map[input_dtype]
return np.random.randint(
- low=0 if low is None else low,
- high=10000 if high is None else high,
+ low=_low if low is None else low,
+ high=_high if high is None else high,
size=input_shape,
dtype=input_dtype,
)